"...git@developer.sourcefind.cn:wangsen/paddle_dbnet.git" did not exist on "814e8a877398f77ab8b27d776049fa172ea471bc"
Commit 9ee0ff1d authored by Tri Dao

Fix using the dO stride for O, which can cause a memory error in bwd

parent 2dd87d06
@@ -141,7 +141,7 @@ inline __device__ void compute_dot_do_o(const Params &params) {
                              make_stride(params.do_row_stride, _1{}));
     Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o),
                             Shape<Int<kBlockM>, Int<kHeadDim>>{},
-                            make_stride(params.do_row_stride, _1{}));
+                            make_stride(params.o_row_stride, _1{}));
     Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dq_accum_ptr) + row_offset_dq_accum),
                                   Shape<Int<kBlockM>, Int<kHeadDim>>{}, Stride<Int<kHeadDim>, _1>{});
     Tensor dP_sum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dsoftmax_sum) + row_offset_dpsum),
@@ -474,7 +474,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
                              make_stride(params.do_row_stride, _1{}));
     Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o),
                             Shape<Int<kBlockM>, Int<kHeadDim>>{},
-                            make_stride(params.do_row_stride, _1{}));
+                            make_stride(params.o_row_stride, _1{}));
     Tensor gdQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dq_ptr) + row_offset_dq),
                             Shape<Int<kBlockM>, Int<kHeadDim>>{},
                             make_stride(params.dq_row_stride, _1{}));
@@ -1098,7 +1098,7 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
     const index_t row_offset_do = binfo.q_offset(params.do_batch_stride, params.do_row_stride, bidb)
         + m_block * kBlockM * params.do_row_stride + bidh * params.do_head_stride;
     const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
-        + m_block * kBlockM * params.do_row_stride + bidh * params.o_head_stride;
+        + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
     // We'll advance gdKaccum and gdVaccum before the first write.
     const index_t row_offset_dkv_accum = ((bidb * params.h_k + (bidh / params.h_h_k_ratio)) * params.seqlen_k_rounded
         + n_block_max * kBlockN) * params.d_rounded;
@@ -1119,7 +1119,7 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
                              make_stride(params.do_row_stride, _1{}));
     Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o),
                             Shape<Int<kBlockM>, Int<kHeadDim>>{},
-                            make_stride(params.do_row_stride, _1{}));
+                            make_stride(params.o_row_stride, _1{}));
     Tensor gdKaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dk_accum_ptr) + row_offset_dkv_accum),
                                   Shape<Int<kBlockN>, Int<kHeadDim>>{},
                                   Stride<Int<kHeadDim>, _1>{});
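The bug matters whenever O and dO are laid out with different row strides, for example when one of them is a non-contiguous view of a larger or padded buffer. Below is a minimal host-side sketch, not taken from the kernel and with purely hypothetical sizes and stride values, showing how indexing the O buffer with dO's row stride computes an address past the end of O's allocation.

// Minimal sketch with assumed sizes/strides (not kernel code): indexing O with
// dO's row stride can run past the end of the O buffer.
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t seqlen_q      = 4;
    const int64_t head_dim      = 8;
    const int64_t o_row_stride  = head_dim;       // O stored tightly packed
    const int64_t do_row_stride = head_dim + 16;  // dO is a view of a padded buffer (hypothetical)

    const int64_t o_buffer_elems = seqlen_q * o_row_stride;  // elements actually allocated for O

    // Offset of the last element of the last row of O, computed with each stride.
    const int64_t offset_correct = (seqlen_q - 1) * o_row_stride  + (head_dim - 1);
    const int64_t offset_buggy   = (seqlen_q - 1) * do_row_stride + (head_dim - 1);

    std::printf("O buffer holds %lld elements\n", (long long)o_buffer_elems);
    std::printf("with o_row_stride : offset %lld (in bounds)\n", (long long)offset_correct);
    std::printf("with do_row_stride: offset %lld (%s)\n", (long long)offset_buggy,
                offset_buggy < o_buffer_elems ? "in bounds" : "out of bounds -> memory error");
    return 0;
}

The diff above applies the same idea inside the kernel: every view of O, including the gO tensors and row_offset_o, now uses O's own row stride (params.o_row_stride), so reads of O stay within its allocation even when dO has a different layout.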