Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
flash-attention
Commits
9486635c
Unverified
Commit
9486635c
authored
Jul 01, 2024
by
66RING
Committed by
GitHub
Jun 30, 2024
Browse files
Fix typos of comments about shape. (#837)
parent
0d810cfb
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
6 additions
and
6 deletions
+6
-6
csrc/flash_attn/src/alibi.h
csrc/flash_attn/src/alibi.h
+1
-1
csrc/flash_attn/src/flash_bwd_kernel.h
csrc/flash_attn/src/flash_bwd_kernel.h
+2
-2
csrc/flash_attn/src/mask.h
csrc/flash_attn/src/mask.h
+3
-3
No files found.
csrc/flash_attn/src/alibi.h
View file @
9486635c
...
@@ -31,7 +31,7 @@ struct Alibi {
...
@@ -31,7 +31,7 @@ struct Alibi {
const
int
col_idx_offset_
,
const
int
col_idx_offset_
,
const
int
row_idx_offset
,
const
int
row_idx_offset
,
const
int
warp_row_stride
)
{
const
int
warp_row_stride
)
{
// tensor has shape (n
col
=(2, MMA_M), n
row
=(2, MMA_N))
// tensor has shape (n
row
=(2, MMA_M), n
col
=(2, MMA_N))
static_assert
(
Layout
::
rank
==
2
,
"Only support 2D Tensor"
);
static_assert
(
Layout
::
rank
==
2
,
"Only support 2D Tensor"
);
const
int
lane_id
=
threadIdx
.
x
%
32
;
const
int
lane_id
=
threadIdx
.
x
%
32
;
const
int
col_idx_offset
=
col_idx_offset_
+
(
lane_id
%
4
)
*
2
;
const
int
col_idx_offset
=
col_idx_offset_
+
(
lane_id
%
4
)
*
2
;
...
...
csrc/flash_attn/src/flash_bwd_kernel.h
View file @
9486635c
...
@@ -471,7 +471,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in
...
@@ -471,7 +471,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in
flash
::
gemm
(
acc_s
,
tSrQ
,
tSrK
,
tSsQ
,
tSsK
,
tiled_mma_sdp
,
flash
::
gemm
(
acc_s
,
tSrQ
,
tSrK
,
tSsQ
,
tSsK
,
tiled_mma_sdp
,
smem_tiled_copy_QdO
,
smem_tiled_copy_KV
,
smem_thr_copy_QdO
,
smem_thr_copy_KV
);
smem_tiled_copy_QdO
,
smem_tiled_copy_KV
,
smem_thr_copy_QdO
,
smem_thr_copy_KV
);
// Reshape acc_s from (MMA=4, MMA_N, MMA_N) to (
col
=(2, MMA_N),
row
=(2, MMA_N))
// Reshape acc_s from (MMA=4, MMA_N, MMA_N) to (
row
=(2, MMA_N),
col
=(2, MMA_N))
Tensor
scores
=
make_tensor
(
acc_s
.
data
(),
flash
::
convert_layout_acc_rowcol
(
acc_s
.
layout
()));
Tensor
scores
=
make_tensor
(
acc_s
.
data
(),
flash
::
convert_layout_acc_rowcol
(
acc_s
.
layout
()));
// if (cute::thread(32, 0)) { print(scores); }
// if (cute::thread(32, 0)) { print(scores); }
...
@@ -565,7 +565,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in
...
@@ -565,7 +565,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in
smem_tiled_copy_QdO
,
smem_tiled_copy_KV
,
smem_thr_copy_QdO
,
smem_thr_copy_KV
smem_tiled_copy_QdO
,
smem_tiled_copy_KV
,
smem_thr_copy_QdO
,
smem_thr_copy_KV
);
);
// Reshape acc_dp from (MMA=4, MMA_N, MMA_N) to (
col
=(2, MMA_N),
row
=(2, MMA_N))
// Reshape acc_dp from (MMA=4, MMA_N, MMA_N) to (
row
=(2, MMA_N),
col
=(2, MMA_N))
Tensor
dS
=
make_tensor
(
acc_dp
.
data
(),
scores
.
layout
());
Tensor
dS
=
make_tensor
(
acc_dp
.
data
(),
scores
.
layout
());
auto
pointwise_mult
=
[](
float
p
,
float
dp
,
float
d
)
{
auto
pointwise_mult
=
[](
float
p
,
float
dp
,
float
d
)
{
return
p
*
(
!
Is_dropout
||
p
>=
0
?
dp
-
d
:
d
);
return
p
*
(
!
Is_dropout
||
p
>=
0
?
dp
-
d
:
d
);
...
...
csrc/flash_attn/src/mask.h
View file @
9486635c
...
@@ -13,7 +13,7 @@ using namespace cute;
...
@@ -13,7 +13,7 @@ using namespace cute;
template
<
typename
Engine
,
typename
Layout
>
template
<
typename
Engine
,
typename
Layout
>
__forceinline__
__device__
void
apply_mask
(
Tensor
<
Engine
,
Layout
>
&
tensor
,
const
int
max_seqlen_k
,
__forceinline__
__device__
void
apply_mask
(
Tensor
<
Engine
,
Layout
>
&
tensor
,
const
int
max_seqlen_k
,
const
int
col_idx_offset_
=
0
)
{
const
int
col_idx_offset_
=
0
)
{
// tensor has shape (n
col
=(2, MMA_M), n
row
=(2, MMA_N))
// tensor has shape (n
row
=(2, MMA_M), n
col
=(2, MMA_N))
static_assert
(
Layout
::
rank
==
2
,
"Only support 2D Tensor"
);
static_assert
(
Layout
::
rank
==
2
,
"Only support 2D Tensor"
);
const
int
lane_id
=
threadIdx
.
x
%
32
;
const
int
lane_id
=
threadIdx
.
x
%
32
;
const
int
col_idx_offset
=
col_idx_offset_
+
(
lane_id
%
4
)
*
2
;
const
int
col_idx_offset
=
col_idx_offset_
+
(
lane_id
%
4
)
*
2
;
...
@@ -39,7 +39,7 @@ __forceinline__ __device__ void apply_mask_local(Tensor<Engine, Layout> &tensor,
...
@@ -39,7 +39,7 @@ __forceinline__ __device__ void apply_mask_local(Tensor<Engine, Layout> &tensor,
const
int
max_seqlen_k
,
const
int
row_idx_offset
,
const
int
max_seqlen_k
,
const
int
row_idx_offset
,
const
int
max_seqlen_q
,
const
int
warp_row_stride
,
const
int
max_seqlen_q
,
const
int
warp_row_stride
,
const
int
window_size_left
,
const
int
window_size_right
)
{
const
int
window_size_left
,
const
int
window_size_right
)
{
// tensor has shape (n
col
=(2, MMA_M), n
row
=(2, MMA_N))
// tensor has shape (n
row
=(2, MMA_M), n
col
=(2, MMA_N))
static_assert
(
Layout
::
rank
==
2
,
"Only support 2D Tensor"
);
static_assert
(
Layout
::
rank
==
2
,
"Only support 2D Tensor"
);
const
int
lane_id
=
threadIdx
.
x
%
32
;
const
int
lane_id
=
threadIdx
.
x
%
32
;
const
int
col_idx_offset
=
col_idx_offset_
+
(
lane_id
%
4
)
*
2
;
const
int
col_idx_offset
=
col_idx_offset_
+
(
lane_id
%
4
)
*
2
;
...
@@ -85,7 +85,7 @@ __forceinline__ __device__ void apply_mask_causal_w_idx(
...
@@ -85,7 +85,7 @@ __forceinline__ __device__ void apply_mask_causal_w_idx(
Tensor
<
Engine0
,
Layout0
>
&
tensor
,
Tensor
<
Engine1
,
Layout1
>
const
&
idx_rowcol
,
Tensor
<
Engine0
,
Layout0
>
&
tensor
,
Tensor
<
Engine1
,
Layout1
>
const
&
idx_rowcol
,
const
int
col_idx_offset_
,
const
int
max_seqlen_k
,
const
int
row_idx_offset
)
const
int
col_idx_offset_
,
const
int
max_seqlen_k
,
const
int
row_idx_offset
)
{
{
// tensor has shape (n
col
=(2, MMA_M), n
row
=(2, MMA_N))
// tensor has shape (n
row
=(2, MMA_M), n
col
=(2, MMA_N))
static_assert
(
Layout0
::
rank
==
2
,
"Only support 2D Tensor"
);
static_assert
(
Layout0
::
rank
==
2
,
"Only support 2D Tensor"
);
static_assert
(
Layout1
::
rank
==
2
,
"Only support 2D Tensor"
);
static_assert
(
Layout1
::
rank
==
2
,
"Only support 2D Tensor"
);
CUTE_STATIC_ASSERT_V
(
size
<
0
>
(
tensor
)
==
size
<
0
>
(
idx_rowcol
));
CUTE_STATIC_ASSERT_V
(
size
<
0
>
(
tensor
)
==
size
<
0
>
(
idx_rowcol
));
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment