Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
flash-attention
Commits
9ee0ff1d
"deploy/hubserving/ocr_cls/module.py" did not exist on "f09d1f730b71a9b220cceec14a765c24cd9c1d6f"
Commit
9ee0ff1d
authored
Jul 20, 2023
by
Tri Dao
Browse files
Fix using dO stride for O, which can cause memory error in bwd
parent
2dd87d06
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
4 additions
and
4 deletions
+4
-4
csrc/flash_attn/src/flash_bwd_kernel.h
csrc/flash_attn/src/flash_bwd_kernel.h
+4
-4
No files found.
csrc/flash_attn/src/flash_bwd_kernel.h
View file @
9ee0ff1d
...
...
@@ -141,7 +141,7 @@ inline __device__ void compute_dot_do_o(const Params ¶ms) {
make_stride
(
params
.
do_row_stride
,
_1
{}));
Tensor
gO
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
Element
*>
(
params
.
o_ptr
)
+
row_offset_o
),
Shape
<
Int
<
kBlockM
>
,
Int
<
kHeadDim
>>
{},
make_stride
(
params
.
d
o_row_stride
,
_1
{}));
make_stride
(
params
.
o_row_stride
,
_1
{}));
Tensor
gdQaccum
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
ElementAccum
*>
(
params
.
dq_accum_ptr
)
+
row_offset_dq_accum
),
Shape
<
Int
<
kBlockM
>
,
Int
<
kHeadDim
>>
{},
Stride
<
Int
<
kHeadDim
>
,
_1
>
{});
Tensor
dP_sum
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
ElementAccum
*>
(
params
.
dsoftmax_sum
)
+
row_offset_dpsum
),
...
...
@@ -474,7 +474,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in
make_stride
(
params
.
do_row_stride
,
_1
{}));
Tensor
gO
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
Element
*>
(
params
.
o_ptr
)
+
row_offset_o
),
Shape
<
Int
<
kBlockM
>
,
Int
<
kHeadDim
>>
{},
make_stride
(
params
.
d
o_row_stride
,
_1
{}));
make_stride
(
params
.
o_row_stride
,
_1
{}));
Tensor
gdQ
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
Element
*>
(
params
.
dq_ptr
)
+
row_offset_dq
),
Shape
<
Int
<
kBlockM
>
,
Int
<
kHeadDim
>>
{},
make_stride
(
params
.
dq_row_stride
,
_1
{}));
...
...
@@ -1098,7 +1098,7 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params ¶ms, const in
const
index_t
row_offset_do
=
binfo
.
q_offset
(
params
.
do_batch_stride
,
params
.
do_row_stride
,
bidb
)
+
m_block
*
kBlockM
*
params
.
do_row_stride
+
bidh
*
params
.
do_head_stride
;
const
index_t
row_offset_o
=
binfo
.
q_offset
(
params
.
o_batch_stride
,
params
.
o_row_stride
,
bidb
)
+
m_block
*
kBlockM
*
params
.
d
o_row_stride
+
bidh
*
params
.
o_head_stride
;
+
m_block
*
kBlockM
*
params
.
o_row_stride
+
bidh
*
params
.
o_head_stride
;
// We'll advance gdKaccum and gdVaccum before the first write.
const
index_t
row_offset_dkv_accum
=
((
bidb
*
params
.
h_k
+
(
bidh
/
params
.
h_h_k_ratio
))
*
params
.
seqlen_k_rounded
+
n_block_max
*
kBlockN
)
*
params
.
d_rounded
;
...
...
@@ -1119,7 +1119,7 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params ¶ms, const in
make_stride
(
params
.
do_row_stride
,
_1
{}));
Tensor
gO
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
Element
*>
(
params
.
o_ptr
)
+
row_offset_o
),
Shape
<
Int
<
kBlockM
>
,
Int
<
kHeadDim
>>
{},
make_stride
(
params
.
d
o_row_stride
,
_1
{}));
make_stride
(
params
.
o_row_stride
,
_1
{}));
Tensor
gdKaccum
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
ElementAccum
*>
(
params
.
dk_accum_ptr
)
+
row_offset_dkv_accum
),
Shape
<
Int
<
kBlockN
>
,
Int
<
kHeadDim
>>
{},
Stride
<
Int
<
kHeadDim
>
,
_1
>
{});
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment