Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
flash-attention
Commits
5d07483b
Commit
5d07483b
authored
Jun 12, 2022
by
Tri Dao
Browse files
Refactor Gmem code to store q, k, v pointers separately
parent
d3e64409
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
93 additions
and
193 deletions
+93
-193
csrc/flash_attn/fmha_api.cpp
csrc/flash_attn/fmha_api.cpp
+11
-5
csrc/flash_attn/src/fmha.h
csrc/flash_attn/src/fmha.h
+12
-9
csrc/flash_attn/src/fmha/gmem_tile.h
csrc/flash_attn/src/fmha/gmem_tile.h
+37
-126
csrc/flash_attn/src/fmha/kernel_traits.h
csrc/flash_attn/src/fmha/kernel_traits.h
+1
-3
csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h
csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h
+10
-19
csrc/flash_attn/src/fmha_block_fprop_kernel_1xN.h
csrc/flash_attn/src/fmha_block_fprop_kernel_1xN.h
+7
-7
csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h
csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h
+10
-19
csrc/flash_attn/src/fmha_fprop_kernel_1xN.h
csrc/flash_attn/src/fmha_fprop_kernel_1xN.h
+5
-5
No files found.
csrc/flash_attn/fmha_api.cpp
View file @
5d07483b
...
@@ -56,12 +56,18 @@ void set_params(Fused_multihead_attention_fprop_params ¶ms,
...
@@ -56,12 +56,18 @@ void set_params(Fused_multihead_attention_fprop_params ¶ms,
memset
(
&
params
,
0
,
sizeof
(
params
));
memset
(
&
params
,
0
,
sizeof
(
params
));
// Set the pointers and strides.
// Set the pointers and strides.
params
.
qkv_ptr
=
qkv_packed_d
;
params
.
q_ptr
=
qkv_packed_d
;
params
.
qkv_stride_in_elts
=
h
*
3
*
d
;
params
.
k_ptr
=
qkv_packed_d
+
get_size_in_bytes
(
h
*
d
,
data_type
);
params
.
qkv_stride_in_bytes
=
get_size_in_bytes
(
h
*
3
*
d
,
data_type
);
params
.
v_ptr
=
qkv_packed_d
+
2
*
get_size_in_bytes
(
h
*
d
,
data_type
);
params
.
q_row_stride_in_elts
=
3
*
h
*
d
;
params
.
k_row_stride_in_elts
=
3
*
h
*
d
;
params
.
v_row_stride_in_elts
=
3
*
h
*
d
;
params
.
q_head_stride_in_elts
=
d
;
params
.
k_head_stride_in_elts
=
d
;
params
.
v_head_stride_in_elts
=
d
;
params
.
o_ptr
=
o_packed_d
;
params
.
o_ptr
=
o_packed_d
;
params
.
o_stride_in_elts
=
h
*
d
;
params
.
o_
row_
stride_in_elts
=
h
*
d
;
params
.
o_stride_in_
bytes
=
get_size_in_bytes
(
h
*
d
,
data_type
)
;
params
.
o_
head_
stride_in_
elts
=
d
;
params
.
do_ptr
=
do_packed_d
;
params
.
do_ptr
=
do_packed_d
;
params
.
o_tmp_ptr
=
o_tmp_d
;
params
.
o_tmp_ptr
=
o_tmp_d
;
...
...
csrc/flash_attn/src/fmha.h
View file @
5d07483b
...
@@ -50,15 +50,21 @@ constexpr int D_DIM = 3;
...
@@ -50,15 +50,21 @@ constexpr int D_DIM = 3;
struct
Qkv_params
{
struct
Qkv_params
{
// The QKV matrices.
// The QKV matrices.
void
*
__restrict__
qkv_ptr
;
void
*
__restrict__
q_ptr
;
void
*
__restrict__
k_ptr
;
void
*
__restrict__
v_ptr
;
// The stride between rows of the Q, K and V matrices.
// The stride between rows of the Q, K and V matrices.
// size_t qkv_stride_in_elts;
// size_t qkv_stride_in_elts;
// size_t qkv_stride_in_bytes;
// size_t qkv_stride_in_bytes;
// TD [2022-04-16]: We're using 32-bit indexing to save registers.
// TD [2022-04-16]: We're using 32-bit indexing to save registers.
// The code probably won't work for arrays larger than 2GB.
// The code probably won't work for arrays larger than 2GB.
uint32_t
qkv_stride_in_elts
;
uint32_t
q_row_stride_in_elts
;
uint32_t
qkv_stride_in_bytes
;
uint32_t
k_row_stride_in_elts
;
uint32_t
v_row_stride_in_elts
;
uint32_t
q_head_stride_in_elts
;
uint32_t
k_head_stride_in_elts
;
uint32_t
v_head_stride_in_elts
;
// The number of heads.
// The number of heads.
int
h
;
int
h
;
...
@@ -71,17 +77,14 @@ struct Fused_multihead_attention_fprop_params : public Qkv_params {
...
@@ -71,17 +77,14 @@ struct Fused_multihead_attention_fprop_params : public Qkv_params {
// The dQKV matrices.
// The dQKV matrices.
void
*
__restrict__
dqkv_ptr
;
void
*
__restrict__
dqkv_ptr
;
// Temporary for dKV.
void
*
__restrict__
dkv_ptr
;
// The O matrix (output).
// The O matrix (output).
void
*
__restrict__
o_ptr
;
void
*
__restrict__
o_ptr
;
// The stride between rows of O.
// The stride between rows of O.
// size_t o_stride_in_elts;
// size_t o_stride_in_elts;
// size_t o_stride_in_bytes;
// size_t o_stride_in_bytes;
uint32_t
o_stride_in_elts
;
uint32_t
o_
row_
stride_in_elts
;
uint32_t
o_stride_in_
byte
s
;
uint32_t
o_
head_
stride_in_
elt
s
;
// The pointer to the O_tmp matrix, which holds O intermediate value during
// The pointer to the O_tmp matrix, which holds O intermediate value during
// the loop;
// the loop;
...
@@ -171,4 +174,4 @@ void run_fmha_dgrad_fp16_sm80(const Fused_multihead_attention_fprop_params ¶
...
@@ -171,4 +174,4 @@ void run_fmha_dgrad_fp16_sm80(const Fused_multihead_attention_fprop_params ¶
void
run_fmha_block_fp16_sm80
(
Launch_params
<
Fused_multihead_attention_fprop_params
>
&
launch_params
,
const
bool
configure
);
void
run_fmha_block_fp16_sm80
(
Launch_params
<
Fused_multihead_attention_fprop_params
>
&
launch_params
,
const
bool
configure
);
void
run_fmha_block_dgrad_fp16_sm80
(
const
Fused_multihead_attention_fprop_params
&
params
,
cudaStream_t
stream
);
void
run_fmha_block_dgrad_fp16_sm80
(
const
Fused_multihead_attention_fprop_params
&
params
,
cudaStream_t
stream
);
\ No newline at end of file
csrc/flash_attn/src/fmha/gmem_tile.h
View file @
5d07483b
...
@@ -39,14 +39,13 @@ template<
...
@@ -39,14 +39,13 @@ template<
// The number of rows of Q, K or V loaded by this tile.
// The number of rows of Q, K or V loaded by this tile.
int
ROWS_
,
int
ROWS_
,
// The number of columns.
// The number of columns.
int
COLS
,
int
COLS
// The number of matrics.
int
NUM_MATS
=
3
>
>
struct
Gmem_tile_qkv
{
struct
Gmem_tile_qkv
{
using
Cta_tile
=
Cta_tile_
;
using
Cta_tile
=
Cta_tile_
;
static
constexpr
int
BYTES_PER_ELEMENT
=
BITS_PER_ELEMENT
/
8
;
// The size of each LDG.
// The size of each LDG.
static
constexpr
int
BYTES_PER_LDG
=
16
;
static
constexpr
int
BYTES_PER_LDG
=
16
;
// The size of a row in bytes.
// The size of a row in bytes.
...
@@ -62,11 +61,12 @@ struct Gmem_tile_qkv {
...
@@ -62,11 +61,12 @@ struct Gmem_tile_qkv {
static
constexpr
int
LDGS
=
DivUpConstexpr
(
ROWS
,
ROWS_PER_LDG
);
static
constexpr
int
LDGS
=
DivUpConstexpr
(
ROWS
,
ROWS_PER_LDG
);
// Ctor.
// Ctor.
template
<
typename
Params
,
typename
BInfo
>
template
<
typename
BInfo
>
inline
__device__
Gmem_tile_qkv
(
const
Params
&
params
,
const
int
qkv_offset
,
const
BInfo
&
binfo
,
const
int
tidx
)
inline
__device__
Gmem_tile_qkv
(
void
*
ptr_
,
const
uint32_t
row_stride_in_elts
,
:
params_qkv_stride_in_bytes_
(
params
.
qkv_stride_in_bytes
)
const
uint32_t
head_stride_in_elts
,
const
BInfo
&
binfo
,
const
int
tidx
)
:
row_stride_in_bytes
(
row_stride_in_elts
*
BYTES_PER_ELEMENT
)
,
actual_seqlen
(
binfo
.
actual_seqlen
)
,
actual_seqlen
(
binfo
.
actual_seqlen
)
,
qkv_
ptr
_
(
reinterpret_cast
<
char
*>
(
params
.
qkv_
ptr
))
,
ptr
(
reinterpret_cast
<
char
*>
(
ptr
_
))
,
tidx_
(
tidx
)
{
,
tidx_
(
tidx
)
{
// Compute the position in the sequence (within the CTA for the moment).
// Compute the position in the sequence (within the CTA for the moment).
...
@@ -80,13 +80,13 @@ struct Gmem_tile_qkv {
...
@@ -80,13 +80,13 @@ struct Gmem_tile_qkv {
// The row offset in the batched GEMM. For each seq element, we store QKV in that order.
// The row offset in the batched GEMM. For each seq element, we store QKV in that order.
// int64_t row_offset = (int64_t)row * params.qkv_stride_in_bytes;
// int64_t row_offset = (int64_t)row * params.qkv_stride_in_bytes;
uint32_t
row_offset
=
(
uint32_t
)
row
*
params
.
qkv
_stride_in_bytes
;
uint32_t
row_offset
=
(
uint32_t
)
((
binfo
.
sum_s
+
row
)
*
row
_stride_in_bytes
)
;
// Add the block index.
// Add the block index.
// row_offset += (int64_t)((binfo.sum_s * NUM_MATS + qkv_offset) * binfo.h + binfo.bidh) * BYTES_PER_ROW;
// row_offset += (int64_t)((binfo.sum_s * NUM_MATS + qkv_offset) * binfo.h + binfo.bidh) * BYTES_PER_ROW;
row_offset
+=
(
uint32_t
)(
(
binfo
.
sum_s
*
NUM_MATS
+
qkv_offset
)
*
binfo
.
h
+
binfo
.
bidh
)
*
BYTES_PER_
ROW
;
row_offset
+=
(
uint32_t
)(
binfo
.
bidh
*
head_stride_in_elts
*
BYTES_PER_
ELEMENT
)
;
// Assemble the final pointer.
// Assemble the final pointer.
qkv_
ptr
_
+=
row_offset
+
col
*
BYTES_PER_LDG
;
ptr
+=
row_offset
+
col
*
BYTES_PER_LDG
;
}
}
// Store data to shared memory.
// Store data to shared memory.
...
@@ -101,8 +101,8 @@ struct Gmem_tile_qkv {
...
@@ -101,8 +101,8 @@ struct Gmem_tile_qkv {
uint32_t
preds
[
LDGS
];
uint32_t
preds
[
LDGS
];
#pragma unroll
#pragma unroll
for
(
int
ii
=
0
;
ii
<
LDGS
;
++
ii
)
{
for
(
int
ii
=
0
;
ii
<
LDGS
;
++
ii
)
{
// ptrs[ii] =
qkv_
ptr
_
+ (int64_t)ii * ROWS_PER_LDG *
params_qkv
_stride_in_bytes
_
;
// ptrs[ii] = ptr + (int64_t)ii * ROWS_PER_LDG *
row
_stride_in_bytes;
ptrs
[
ii
]
=
qkv_
ptr
_
+
(
uint32_t
)
ii
*
ROWS_PER_LDG
*
params_qkv
_stride_in_bytes
_
;
ptrs
[
ii
]
=
ptr
+
(
uint32_t
)
ii
*
ROWS_PER_LDG
*
row
_stride_in_bytes
;
preds
[
ii
]
=
((
row_
+
ii
*
ROWS_PER_LDG
)
<
min
(
ROWS
,
actual_seqlen
));
preds
[
ii
]
=
((
row_
+
ii
*
ROWS_PER_LDG
)
<
min
(
ROWS
,
actual_seqlen
));
fetch_
[
ii
]
=
make_uint4
(
0
,
0
,
0
,
0
);
fetch_
[
ii
]
=
make_uint4
(
0
,
0
,
0
,
0
);
}
}
...
@@ -120,32 +120,25 @@ struct Gmem_tile_qkv {
...
@@ -120,32 +120,25 @@ struct Gmem_tile_qkv {
int
row_
=
tidx_
/
THREADS_PER_ROW
;
int
row_
=
tidx_
/
THREADS_PER_ROW
;
#pragma unroll
#pragma unroll
for
(
int
ii
=
0
;
ii
<
LDGS
;
++
ii
)
{
for
(
int
ii
=
0
;
ii
<
LDGS
;
++
ii
)
{
// char *ptr =
qkv_
ptr
_
+ (int64_t)ii * ROWS_PER_LDG *
params_qkv
_stride_in_bytes
_
;
// char *ptr
_
= ptr + (int64_t)ii * ROWS_PER_LDG *
row
_stride_in_bytes;
char
*
ptr
=
qkv_
ptr
_
+
(
uint32_t
)
ii
*
ROWS_PER_LDG
*
params_qkv
_stride_in_bytes
_
;
char
*
ptr
_
=
ptr
+
(
uint32_t
)
ii
*
ROWS_PER_LDG
*
row
_stride_in_bytes
;
if
(
(
row_
+
ii
*
ROWS_PER_LDG
)
<
min
(
ROWS
,
actual_seqlen
)
)
{
if
(
(
row_
+
ii
*
ROWS_PER_LDG
)
<
min
(
ROWS
,
actual_seqlen
)
)
{
fmha
::
stg
(
ptr
,
data
[
ii
]);
fmha
::
stg
(
ptr
_
,
data
[
ii
]);
}
}
}
}
}
}
// Move the pointer to the next location.
inline
__device__
void
move
(
const
int
steps
=
1
)
{
inline
__device__
void
move
()
{
// ptr += (int64_t)ROWS * row_stride_in_bytes * steps;
// qkv_ptr_ += (int64_t)ROWS * params_qkv_stride_in_bytes_;
ptr
+=
(
uint32_t
)
ROWS
*
row_stride_in_bytes
*
steps
;
qkv_ptr_
+=
(
uint32_t
)
ROWS
*
params_qkv_stride_in_bytes_
;
actual_seqlen
-=
ROWS
;
}
inline
__device__
void
move
(
int
steps
)
{
// qkv_ptr_ += (int64_t)ROWS * params_qkv_stride_in_bytes_ * steps;
qkv_ptr_
+=
(
uint32_t
)
ROWS
*
params_qkv_stride_in_bytes_
*
steps
;
actual_seqlen
-=
ROWS
*
steps
;
actual_seqlen
-=
ROWS
*
steps
;
}
}
// The stride between rows for the QKV matrice.
// The stride between rows for the QKV matrice.
// int64_t
params_qkv
_stride_in_bytes
_
;
// int64_t
row
_stride_in_bytes;
uint32_t
params_qkv
_stride_in_bytes
_
;
const
uint32_t
row
_stride_in_bytes
;
// The pointer.
// The pointer.
char
*
qkv_
ptr
_
;
char
*
ptr
;
// The fetch registers.
// The fetch registers.
uint4
fetch_
[
LDGS
];
uint4
fetch_
[
LDGS
];
// Keep track of the row the thread is processing as we move the tile.
// Keep track of the row the thread is processing as we move the tile.
...
@@ -196,10 +189,10 @@ struct Gmem_tile_o {
...
@@ -196,10 +189,10 @@ struct Gmem_tile_o {
// Ctor.
// Ctor.
template
<
typename
BInfo
>
template
<
typename
BInfo
>
// inline __device__ Gmem_tile_o(void *ptr, const size_t stride_in_elts, const BInfo &binfo, const int tidx)
// inline __device__ Gmem_tile_o(void *ptr, const size_t
row_
stride_in_elts, const BInfo &binfo, const int tidx)
inline
__device__
Gmem_tile_o
(
void
*
ptr
,
const
uint32_t
stride_in_elts
,
const
BInfo
&
binfo
,
const
int
tidx
)
inline
__device__
Gmem_tile_o
(
void
*
ptr
,
const
uint32_t
row_
stride_in_elts
,
:
stride_in_bytes_
(
stride_in_elts
*
BYTES_PER_ELEMENT
)
const
uint32_t
head_stride_in_elts
,
const
BInfo
&
binfo
,
const
int
tidx
)
,
actual_seqlen_
(
binfo
.
actual_seqlen
)
:
row_stride_in_bytes
(
row_stride_in_elts
*
BYTES_PER_ELEMENT
)
,
actual_seqlen
(
binfo
.
actual_seqlen
)
,
actual_seqlen
(
binfo
.
actual_seqlen
)
,
ptr_
(
reinterpret_cast
<
char
*>
(
ptr
))
,
ptr_
(
reinterpret_cast
<
char
*>
(
ptr
))
,
tidx_
(
tidx
)
{
,
tidx_
(
tidx
)
{
...
@@ -213,8 +206,9 @@ struct Gmem_tile_o {
...
@@ -213,8 +206,9 @@ struct Gmem_tile_o {
// row_ = row;
// row_ = row;
// The row offset in the batched GEMM.
// The row offset in the batched GEMM.
// int64_t row_offset = (int64_t)row * stride_in_bytes_ + binfo.bidx * BYTES_PER_ROW;
// int64_t row_offset = (int64_t)row * row_stride_in_bytes + binfo.bidx * BYTES_PER_ROW;
uint32_t
row_offset
=
(
uint32_t
)
row
*
stride_in_bytes_
+
binfo
.
bidx
*
BYTES_PER_ROW
;
uint32_t
row_offset
=
(
uint32_t
)((
binfo
.
sum_s
+
row
)
*
row_stride_in_bytes
);
row_offset
+=
(
uint32_t
)(
binfo
.
bidh
*
head_stride_in_elts
*
BYTES_PER_ELEMENT
);
// Assemble the final pointer.
// Assemble the final pointer.
ptr_
+=
row_offset
+
col
*
BYTES_PER_STG
;
ptr_
+=
row_offset
+
col
*
BYTES_PER_STG
;
...
@@ -224,25 +218,19 @@ struct Gmem_tile_o {
...
@@ -224,25 +218,19 @@ struct Gmem_tile_o {
}
}
}
}
template
<
typename
Params
,
typename
BInfo
>
inline
__device__
Gmem_tile_o
(
const
Params
&
params
,
const
BInfo
&
binfo
,
const
int
tidx
)
:
Gmem_tile_o
(
params
.
o_ptr
,
params
.
o_stride_in_elts
,
binfo
,
tidx
)
{}
// Store data to global memory.
// Store data to global memory.
inline
__device__
void
store
(
const
uint4
(
&
src
)[
STGS_PER_LOOP
],
int
mi
)
{
inline
__device__
void
store
(
const
uint4
(
&
src
)[
STGS_PER_LOOP
],
int
mi
)
{
int
row_
=
tidx_
/
THREADS_PER_ROW
;
int
row_
=
tidx_
/
THREADS_PER_ROW
;
#pragma unroll
#pragma unroll
for
(
int
ii
=
0
;
ii
<
STGS_PER_LOOP
;
++
ii
)
{
for
(
int
ii
=
0
;
ii
<
STGS_PER_LOOP
;
++
ii
)
{
int
jj
=
mi
*
STGS_PER_LOOP
+
ii
;
int
jj
=
mi
*
STGS_PER_LOOP
+
ii
;
// if( this->row_ + jj * ROWS_PER_STG >= this->actual_seqlen_ ) {
// break;
if
(
row_
+
jj
*
ROWS_PER_STG
>=
this
->
actual_seqlen
)
{
if
(
row_
+
jj
*
ROWS_PER_STG
>=
this
->
actual_seqlen
)
{
break
;
break
;
}
}
if
(
BYTES_PER_ELEMENT
==
4
)
{
if
(
BYTES_PER_ELEMENT
==
4
)
{
if
(
!
HAS_INCOMPLETE_STG
||
(
jj
<
STGS
-
1
||
this
->
is_active_for_last_stg_
)
)
{
if
(
!
HAS_INCOMPLETE_STG
||
(
jj
<
STGS
-
1
||
this
->
is_active_for_last_stg_
)
)
{
fmha
::
stg
(
this
->
ptr_
+
jj
*
ROWS_PER_STG
*
this
->
stride_in_bytes
_
,
src
[
ii
]);
fmha
::
stg
(
this
->
ptr_
+
jj
*
ROWS_PER_STG
*
this
->
row_
stride_in_bytes
,
src
[
ii
]);
}
}
}
else
if
(
BYTES_PER_ELEMENT
==
2
)
{
}
else
if
(
BYTES_PER_ELEMENT
==
2
)
{
float
x
=
reinterpret_cast
<
const
float
&>
(
src
[
ii
].
x
);
float
x
=
reinterpret_cast
<
const
float
&>
(
src
[
ii
].
x
);
...
@@ -251,7 +239,7 @@ struct Gmem_tile_o {
...
@@ -251,7 +239,7 @@ struct Gmem_tile_o {
float
w
=
reinterpret_cast
<
const
float
&>
(
src
[
ii
].
w
);
float
w
=
reinterpret_cast
<
const
float
&>
(
src
[
ii
].
w
);
uint2
out
=
float4_to_half4
(
x
,
y
,
z
,
w
);
uint2
out
=
float4_to_half4
(
x
,
y
,
z
,
w
);
if
(
!
HAS_INCOMPLETE_STG
||
(
jj
<
STGS
-
1
||
this
->
is_active_for_last_stg_
)
)
{
if
(
!
HAS_INCOMPLETE_STG
||
(
jj
<
STGS
-
1
||
this
->
is_active_for_last_stg_
)
)
{
fmha
::
stg
(
this
->
ptr_
+
jj
*
ROWS_PER_STG
*
this
->
stride_in_bytes
_
,
out
);
fmha
::
stg
(
this
->
ptr_
+
jj
*
ROWS_PER_STG
*
this
->
row_
stride_in_bytes
,
out
);
}
}
}
}
}
}
...
@@ -269,37 +257,26 @@ struct Gmem_tile_o {
...
@@ -269,37 +257,26 @@ struct Gmem_tile_o {
}
}
if
(
!
HAS_INCOMPLETE_STG
||
(
jj
<
STGS
-
1
||
this
->
is_active_for_last_stg_
)
)
{
if
(
!
HAS_INCOMPLETE_STG
||
(
jj
<
STGS
-
1
||
this
->
is_active_for_last_stg_
)
)
{
fmha
::
ldg
(
dst
[
ii
],
this
->
ptr_
+
jj
*
ROWS_PER_STG
*
this
->
stride_in_bytes
_
);
fmha
::
ldg
(
dst
[
ii
],
this
->
ptr_
+
jj
*
ROWS_PER_STG
*
this
->
row_
stride_in_bytes
);
}
}
}
}
}
}
// Move the pointer to the next location.
inline
__device__
void
move
(
const
int
steps
=
1
)
{
inline
__device__
void
move
()
{
// row_ += ROWS;
// ptr_ += (int64_t)ROWS * stride_in_bytes_;
ptr_
+=
(
uint32_t
)
ROWS
*
stride_in_bytes_
;
actual_seqlen
-=
ROWS
;
}
inline
__device__
void
move
(
const
int
steps
)
{
// row_ += ROWS * steps;
// row_ += ROWS * steps;
// ptr_ += (int64_t)ROWS * stride_in_bytes
_
* steps;
// ptr_ += (int64_t)ROWS *
row_
stride_in_bytes * steps;
ptr_
+=
(
uint32_t
)
ROWS
*
stride_in_bytes
_
*
steps
;
ptr_
+=
(
uint32_t
)
ROWS
*
row_
stride_in_bytes
*
steps
;
actual_seqlen
-=
ROWS
*
steps
;
actual_seqlen
-=
ROWS
*
steps
;
}
}
// The stride between rows for the QKV matrice.
// The stride between rows for the QKV matrice.
// int64_t stride_in_bytes
_
;
// int64_t
row_
stride_in_bytes;
uint32_t
stride_in_bytes
_
;
const
uint32_t
row_
stride_in_bytes
;
// The pointer.
// The pointer.
char
*
ptr_
;
char
*
ptr_
;
// Is the thread active for the last STG?
// Is the thread active for the last STG?
int
is_active_for_last_stg_
;
int
is_active_for_last_stg_
;
// Keep track of the row to disable loads.
// int row_;
// The length of the sequence loaded by that memory tile.
// The length of the sequence loaded by that memory tile.
const
int
actual_seqlen_
;
int
actual_seqlen
;
int
actual_seqlen
;
const
int
tidx_
;
const
int
tidx_
;
};
};
...
@@ -363,10 +340,7 @@ struct Gmem_tile_mma_sd {
...
@@ -363,10 +340,7 @@ struct Gmem_tile_mma_sd {
}
}
// Move to the next tile.
// Move to the next tile.
inline
__device__
void
move
()
{
inline
__device__
void
move
(
const
int
steps
=
1
)
{
ptr_
+=
LOOP_STRIDE_BYTES
;
}
inline
__device__
void
move
(
const
int
steps
)
{
ptr_
+=
LOOP_STRIDE_BYTES
*
steps
;
ptr_
+=
LOOP_STRIDE_BYTES
*
steps
;
}
}
...
@@ -459,69 +433,6 @@ struct Gmem_tile_mma_s : public Base {
...
@@ -459,69 +433,6 @@ struct Gmem_tile_mma_s : public Base {
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
// The dimensions of the tile computed by the CTA.
typename
Cta_tile
,
// The base class.
typename
Base
=
fmha
::
Gmem_tile_qkv
<
Cta_tile
,
fmha
::
BITS_PER_ELEMENT_A
,
Cta_tile
::
M
,
Cta_tile
::
K
>
>
struct
Gmem_tile_dout
:
public
Base
{
// Ctor.
template
<
typename
Params
,
typename
BInfo
>
inline
__device__
Gmem_tile_dout
(
void
*
ptr
,
const
Params
&
params
,
const
BInfo
&
binfo
,
int
tidx
)
:
Base
(
params
,
0
,
binfo
,
tidx
)
{
// this->qkv_ptr_ = reinterpret_cast<char *>(params.do_ptr);
this
->
qkv_ptr_
=
static_cast
<
char
*>
(
ptr
);
this
->
params_qkv_stride_in_bytes_
=
params
.
o_stride_in_bytes
;
// needed for move
// Compute the position in the sequence (within the CTA for the moment).
int
row
=
tidx
/
Base
::
THREADS_PER_ROW
;
// Compute the position of the thread in the row.
int
col
=
tidx
%
Base
::
THREADS_PER_ROW
;
// The row offset in the batched GEMM. For each seq element, we store O in that order.
// int64_t row_offset = (int64_t)this->row_ * params.o_stride_in_bytes + binfo.bidx * Base::BYTES_PER_ROW;
// int64_t row_offset = (int64_t)row * params.o_stride_in_bytes + binfo.bidx * Base::BYTES_PER_ROW;
uint32_t
row_offset
=
(
uint32_t
)
row
*
params
.
o_stride_in_bytes
+
binfo
.
bidx
*
Base
::
BYTES_PER_ROW
;
// Assemble the final pointer.
this
->
qkv_ptr_
+=
row_offset
+
col
*
Base
::
BYTES_PER_LDG
;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
typename
Cta_tile
,
typename
Base
=
fmha
::
Gmem_tile_o
<
Cta_tile
>
>
struct
Gmem_tile_dq
:
public
Base
{
// Ctor.
template
<
typename
Params
,
typename
BInfo
>
inline
__device__
Gmem_tile_dq
(
const
Params
&
params
,
const
int
qkv_offset
,
const
BInfo
&
binfo
,
int
tidx
)
:
Base
(
params
.
dqkv_ptr
,
params
.
qkv_stride_in_elts
,
binfo
,
tidx
)
{
this
->
ptr_
=
reinterpret_cast
<
char
*>
(
params
.
dqkv_ptr
);
// Compute the position in the sequence (within the CTA for the moment).
int
row
=
tidx
/
Base
::
THREADS_PER_ROW
;
// Compute the position of the thread in the row.
int
col
=
tidx
%
Base
::
THREADS_PER_ROW
;
// The row offset in the batched GEMM. For each seq element, we store O in that order.
// int64_t row_offset = (int64_t)this->row_ * params.qkv_stride_in_bytes +
// ((binfo.sum_s * 3 + qkv_offset) * binfo.h + binfo.bidh) * Base::BYTES_PER_ROW;
// int64_t row_offset = (int64_t)row * this->stride_in_bytes_ +
// ((binfo.sum_s * 3 + qkv_offset) * binfo.h + binfo.bidh) * Base::BYTES_PER_ROW;
uint32_t
row_offset
=
(
uint32_t
)
row
*
this
->
stride_in_bytes_
+
((
binfo
.
sum_s
*
3
+
qkv_offset
)
*
binfo
.
h
+
binfo
.
bidh
)
*
Base
::
BYTES_PER_ROW
;
// Assemble the final pointer.
this
->
ptr_
+=
row_offset
+
col
*
Base
::
BYTES_PER_STG
;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
template
<
// The dimensions of the tile computed by the CTA.
// The dimensions of the tile computed by the CTA.
typename
Cta_tile
typename
Cta_tile
...
...
csrc/flash_attn/src/fmha/kernel_traits.h
View file @
5d07483b
...
@@ -72,9 +72,7 @@ struct FMHA_kernel_traits {
...
@@ -72,9 +72,7 @@ struct FMHA_kernel_traits {
// The shared memory tile to transpose S.
// The shared memory tile to transpose S.
using
Smem_tile_st
=
fmha
::
Smem_tile_mma_transposed
<
Cta_tile_p
>
;
using
Smem_tile_st
=
fmha
::
Smem_tile_mma_transposed
<
Cta_tile_p
>
;
using
Gmem_tile_do
=
fmha
::
Gmem_tile_dout
<
Cta_tile_p
>
;
using
Gmem_tile_do
=
fmha
::
Gmem_tile_qkv
<
Cta_tile_p
,
fmha
::
BITS_PER_ELEMENT_A
,
STEP
,
D
>
;
using
Gmem_tile_dot
=
fmha
::
Gmem_tile_dout
<
Cta_tile_p
,
fmha
::
Gmem_tile_qkv
<
Cta_tile_p
,
fmha
::
BITS_PER_ELEMENT_B
,
S
,
D
>
>
;
// The global memory tile to store the softmax sum.
// The global memory tile to store the softmax sum.
using
Gmem_softmax_sum
=
fmha
::
Gmem_summary_stats
<
Cta_tile_p
>
;
using
Gmem_softmax_sum
=
fmha
::
Gmem_summary_stats
<
Cta_tile_p
>
;
...
...
csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h
View file @
5d07483b
...
@@ -77,8 +77,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms,
...
@@ -77,8 +77,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms,
using
Gmem_tile_o
=
Gmem_tile_do
;
using
Gmem_tile_o
=
Gmem_tile_do
;
// The global memory tile to store dQ.
// The global memory tile to store dQ.
// using Gmem_tile_dq = typename Kernel_traits::Gmem_tile_dq;
using
Gmem_tile_dq
=
typename
Kernel_traits
::
Gmem_tile_o
;
using
Gmem_tile_dq
=
fmha
::
Gmem_tile_dq
<
Cta_tile_dq
>
;
using
Gmem_tile_dq_tmp
=
fmha
::
Gmem_tile_o
<
Cta_tile_dq
,
4
>
;
using
Gmem_tile_dq_tmp
=
fmha
::
Gmem_tile_o
<
Cta_tile_dq
,
4
>
;
// The shared memory tile to swizzle dQ.
// The shared memory tile to swizzle dQ.
using
Smem_tile_dq
=
typename
Kernel_traits
::
Smem_tile_o
;
using
Smem_tile_dq
=
typename
Kernel_traits
::
Smem_tile_o
;
...
@@ -139,19 +138,19 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms,
...
@@ -139,19 +138,19 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms,
Gemm1
gemm_q_k
(
&
smem_
[
Smem_tile_do
::
BYTES_PER_TILE
],
tidx
);
Gemm1
gemm_q_k
(
&
smem_
[
Smem_tile_do
::
BYTES_PER_TILE
],
tidx
);
// Allocate the global memory tile loader for Q.
// Allocate the global memory tile loader for Q.
Gmem_tile_q
gmem_q
(
params
,
0
,
binfo
,
tidx
);
Gmem_tile_q
gmem_q
(
params
.
q_ptr
,
params
.
q_row_stride_in_elts
,
params
.
q_head_stride_in_elts
,
binfo
,
tidx
);
// Allocate the global memory tile loader for dQ.
// Allocate the global memory tile loader for dQ.
Gmem_tile_dq
gmem_dq
(
params
,
0
,
binfo
,
tidx
);
Gmem_tile_dq
gmem_dq
(
params
.
dqkv_ptr
,
params
.
q_row_stride_in_elts
,
params
.
q_head_stride_in_elts
,
binfo
,
tidx
);
Gmem_tile_dq_tmp
gmem_dq_tmp
(
params
.
o_tmp_ptr
,
params
.
o_stride_in_elts
,
binfo
,
tidx
);
Gmem_tile_dq_tmp
gmem_dq_tmp
(
params
.
o_tmp_ptr
,
params
.
o_
row_stride_in_elts
,
params
.
o_head_
stride_in_elts
,
binfo
,
tidx
);
// Allocate the global memory tile loader for S.
// Allocate the global memory tile loader for S.
Gmem_tile_s
gmem_s
(
params
,
binfo
,
tidx
);
Gmem_tile_s
gmem_s
(
params
,
binfo
,
tidx
);
fmha
::
Mask
<
Cta_tile_p
,
Is_causal
>
mask
(
binfo
,
tidx
,
loop_step_idx
);
fmha
::
Mask
<
Cta_tile_p
,
Is_causal
>
mask
(
binfo
,
tidx
,
loop_step_idx
);
// Allocate the global memory tile loader for K.
// Allocate the global memory tile loader for K.
Gmem_tile_k
gmem_k
(
params
,
1
,
binfo
,
tidx
);
Gmem_tile_k
gmem_k
(
params
.
k_ptr
,
params
.
k_row_stride_in_elts
,
params
.
k_head_stride_in_elts
,
binfo
,
tidx
);
// Allocate the global memory tile loader for V.
// Allocate the global memory tile loader for V.
Gmem_tile_v
gmem_v
(
params
,
2
,
binfo
,
tidx
);
Gmem_tile_v
gmem_v
(
params
.
v_ptr
,
params
.
v_row_stride_in_elts
,
params
.
v_head_stride_in_elts
,
binfo
,
tidx
);
// The base pointer of smem_v;
// The base pointer of smem_v;
char
*
smem_v_
=
&
smem_
[
Smem_tile_do
::
BYTES_PER_TILE
+
Gemm1
::
SMEM_OFFSET_V
];
char
*
smem_v_
=
&
smem_
[
Smem_tile_do
::
BYTES_PER_TILE
+
Gemm1
::
SMEM_OFFSET_V
];
...
@@ -161,7 +160,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms,
...
@@ -161,7 +160,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms,
Smem_tile_kt
smem_kt
(
&
smem_
[
Smem_tile_do
::
BYTES_PER_TILE
+
Gemm1
::
Smem_tile_q
::
BYTES_PER_TILE
],
tidx
);
Smem_tile_kt
smem_kt
(
&
smem_
[
Smem_tile_do
::
BYTES_PER_TILE
+
Gemm1
::
Smem_tile_q
::
BYTES_PER_TILE
],
tidx
);
// Allocate the global memory tile loader for dO.
// Allocate the global memory tile loader for dO.
Gmem_tile_do
gmem_do
(
params
.
do_ptr
,
params
,
binfo
,
tidx
);
Gmem_tile_do
gmem_do
(
params
.
do_ptr
,
params
.
o_row_stride_in_elts
,
params
.
o_head_stride_in_elts
,
binfo
,
tidx
);
// Allocate the shared memory tile loader for dO.
// Allocate the shared memory tile loader for dO.
Smem_tile_do
smem_do
(
&
smem_
[
0
],
tidx
);
Smem_tile_do
smem_do
(
&
smem_
[
0
],
tidx
);
Smem_tile_dot
smem_dot
(
&
smem_
[
0
],
tidx
);
Smem_tile_dot
smem_dot
(
&
smem_
[
0
],
tidx
);
...
@@ -173,7 +172,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms,
...
@@ -173,7 +172,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms,
Smem_tile_st
smem_dp
(
&
smem_
[
Smem_tile_do
::
BYTES_PER_TILE
+
Gemm1
::
SMEM_OFFSET_O
+
Smem_tile_dq
::
BYTES_PER_TILE
+
Smem_tile_st
::
BYTES_PER_TILE
],
tidx
);
Smem_tile_st
smem_dp
(
&
smem_
[
Smem_tile_do
::
BYTES_PER_TILE
+
Gemm1
::
SMEM_OFFSET_O
+
Smem_tile_dq
::
BYTES_PER_TILE
+
Smem_tile_st
::
BYTES_PER_TILE
],
tidx
);
// Allocate the global memory tile loader for O.
// Allocate the global memory tile loader for O.
Gmem_tile_o
gmem_o
(
params
.
o_ptr
,
params
,
binfo
,
tidx
);
Gmem_tile_o
gmem_o
(
params
.
o_ptr
,
params
.
o_row_stride_in_elts
,
params
.
o_head_stride_in_elts
,
binfo
,
tidx
);
// Allocate the shared memory tile loader for O. We use the same as K so be careful!!!
// Allocate the shared memory tile loader for O. We use the same as K so be careful!!!
Smem_tile_dq
smem_dq
(
&
smem_
[
Smem_tile_do
::
BYTES_PER_TILE
+
Gemm1
::
SMEM_OFFSET_O
],
tidx
);
Smem_tile_dq
smem_dq
(
&
smem_
[
Smem_tile_do
::
BYTES_PER_TILE
+
Gemm1
::
SMEM_OFFSET_O
],
tidx
);
...
@@ -703,11 +702,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms,
...
@@ -703,11 +702,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms,
__syncthreads
();
__syncthreads
();
uint4
dv_out
[
Smem_tile_dv
::
NUM_LDS
];
uint4
dv_out
[
Smem_tile_dv
::
NUM_LDS
];
smem_dv
.
load
(
dv_out
);
smem_dv
.
load
(
dv_out
);
Qkv_params
dv_params
;
Gmem_tile_dv
gmem_dv
(
params
.
dqkv_ptr
+
2
*
params
.
h
*
params
.
d
*
2
,
params
.
v_row_stride_in_elts
,
params
.
v_head_stride_in_elts
,
binfo
,
tidx
);
dv_params
.
qkv_ptr
=
params
.
dqkv_ptr
;
dv_params
.
qkv_stride_in_bytes
=
params
.
qkv_stride_in_bytes
;
dv_params
.
h
=
params
.
h
;
Gmem_tile_dv
gmem_dv
(
dv_params
,
2
,
binfo
,
tidx
);
if
(
!
Is_first
)
{
if
(
!
Is_first
)
{
gmem_dv
.
move
(
loop_step_idx
);
gmem_dv
.
move
(
loop_step_idx
);
}
}
...
@@ -718,11 +713,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms,
...
@@ -718,11 +713,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms,
// for (int ii = 0; ii < Smem_tile_dk::NUM_LDS; ++ii) {
// for (int ii = 0; ii < Smem_tile_dk::NUM_LDS; ++ii) {
// dk_out[ii] = fmha::fmul4(dk_out[ii], params.scale_bmm1f);
// dk_out[ii] = fmha::fmul4(dk_out[ii], params.scale_bmm1f);
// }
// }
Qkv_params
dk_params
;
Gmem_tile_dk
gmem_dk
(
params
.
dqkv_ptr
+
params
.
h
*
params
.
d
*
2
,
params
.
k_row_stride_in_elts
,
params
.
k_head_stride_in_elts
,
binfo
,
tidx
);
dk_params
.
qkv_ptr
=
params
.
dqkv_ptr
;
dk_params
.
qkv_stride_in_bytes
=
params
.
qkv_stride_in_bytes
;
dk_params
.
h
=
params
.
h
;
Gmem_tile_dk
gmem_dk
(
dk_params
,
1
,
binfo
,
tidx
);
if
(
!
Is_first
)
{
if
(
!
Is_first
)
{
gmem_dk
.
move
(
loop_step_idx
);
gmem_dk
.
move
(
loop_step_idx
);
}
}
...
...
csrc/flash_attn/src/fmha_block_fprop_kernel_1xN.h
View file @
5d07483b
...
@@ -97,10 +97,10 @@ inline __device__ void device_block_1xN_(const Params ¶ms, const int bidb, c
...
@@ -97,10 +97,10 @@ inline __device__ void device_block_1xN_(const Params ¶ms, const int bidb, c
Gemm1
gemm_q_k
(
smem_
,
tidx
);
Gemm1
gemm_q_k
(
smem_
,
tidx
);
// Allocate the global memory tile loader for Q.
// Allocate the global memory tile loader for Q.
Gmem_tile_q
gmem_q
(
params
,
0
,
binfo
,
tidx
);
Gmem_tile_q
gmem_q
(
params
.
q_ptr
,
params
.
q_row_stride_in_elts
,
params
.
q_head_stride_in_elts
,
binfo
,
tidx
);
// Allocate the global memory tile loader for O.
// Allocate the global memory tile loader for O.
Gmem_tile_o
gmem_o
(
params
,
binfo
,
tidx
);
Gmem_tile_o
gmem_o
(
params
.
o_ptr
,
params
.
o_row_stride_in_elts
,
params
.
o_head_stride_in_elts
,
binfo
,
tidx
);
Gmem_tile_o_tmp
gmem_o_tmp
(
params
.
o_tmp_ptr
,
params
.
o_stride_in_elts
,
binfo
,
tidx
);
Gmem_tile_o_tmp
gmem_o_tmp
(
params
.
o_tmp_ptr
,
params
.
o_
row_stride_in_elts
,
params
.
o_head_
stride_in_elts
,
binfo
,
tidx
);
// Allocate the global memory tile loader for S.
// Allocate the global memory tile loader for S.
Gmem_tile_s
gmem_s
(
params
,
binfo
,
tidx
);
Gmem_tile_s
gmem_s
(
params
,
binfo
,
tidx
);
Gmem_softmax_sum
gmem_softmax_lse
(
params
.
softmax_lse_ptr
,
params
,
tidx
);
Gmem_softmax_sum
gmem_softmax_lse
(
params
.
softmax_lse_ptr
,
params
,
tidx
);
...
@@ -122,12 +122,12 @@ inline __device__ void device_block_1xN_(const Params ¶ms, const int bidb, c
...
@@ -122,12 +122,12 @@ inline __device__ void device_block_1xN_(const Params ¶ms, const int bidb, c
fmha
::
Mask
<
Cta_tile_p
,
Is_causal
>
mask
(
binfo
,
tidx
,
loop_step_idx
);
fmha
::
Mask
<
Cta_tile_p
,
Is_causal
>
mask
(
binfo
,
tidx
,
loop_step_idx
);
// Allocate the global memory tile loader for K.
// Allocate the global memory tile loader for K.
Gmem_tile_k
gmem_k
(
params
,
1
,
binfo
,
tidx
);
Gmem_tile_k
gmem_k
(
params
.
k_ptr
,
params
.
k_row_stride_in_elts
,
params
.
k_head_stride_in_elts
,
binfo
,
tidx
);
// Allocate the global memory tile loader for V.
// Allocate the global memory tile loader for V.
Gmem_tile_v
gmem_v
(
params
,
2
,
binfo
,
tidx
);
Gmem_tile_v
gmem_v
(
params
.
v_ptr
,
params
.
v_row_stride_in_elts
,
params
.
v_head_stride_in_elts
,
binfo
,
tidx
);
// The base pointer of smem_v;
// The base pointer of smem_v;
char
*
smem_v_
=
&
smem_
[
Gemm1
::
SMEM_OFFSET_V
];
char
*
smem_v_
=
&
smem_
[
Gemm1
::
SMEM_OFFSET_V
];
// Allocate the shared memory tile loader for V. We use the same as K so be careful!!!
// Allocate the shared memory tile loader for V. We use the same as K so be careful!!!
Smem_tile_v
smem_v
(
smem_v_
,
tidx
);
Smem_tile_v
smem_v
(
smem_v_
,
tidx
);
...
@@ -193,7 +193,7 @@ inline __device__ void device_block_1xN_(const Params ¶ms, const int bidb, c
...
@@ -193,7 +193,7 @@ inline __device__ void device_block_1xN_(const Params ¶ms, const int bidb, c
__syncthreads
();
__syncthreads
();
}
}
// Load the fragments for K.
// Load the fragments for K.
gemm_q_k
.
load_k
();
gemm_q_k
.
load_k
();
// Create the object to do the softmax.
// Create the object to do the softmax.
...
...
csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h
View file @
5d07483b
...
@@ -80,8 +80,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng
...
@@ -80,8 +80,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng
using
Gmem_tile_o
=
Gmem_tile_do
;
using
Gmem_tile_o
=
Gmem_tile_do
;
// The global memory tile to store dQ.
// The global memory tile to store dQ.
// using Gmem_tile_dq = typename Kernel_traits::Gmem_tile_dq;
using
Gmem_tile_dq
=
typename
Kernel_traits
::
Gmem_tile_o
;
using
Gmem_tile_dq
=
fmha
::
Gmem_tile_dq
<
Cta_tile_dq
>
;
using
Gmem_tile_dq_tmp
=
fmha
::
Gmem_tile_o
<
Cta_tile_dq
,
4
>
;
using
Gmem_tile_dq_tmp
=
fmha
::
Gmem_tile_o
<
Cta_tile_dq
,
4
>
;
// The shared memory tile to swizzle dQ.
// The shared memory tile to swizzle dQ.
using
Smem_tile_dq
=
typename
Kernel_traits
::
Smem_tile_o
;
using
Smem_tile_dq
=
typename
Kernel_traits
::
Smem_tile_o
;
...
@@ -132,19 +131,19 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng
...
@@ -132,19 +131,19 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng
Gemm1
gemm_q_k
(
&
smem_
[
Smem_tile_do
::
BYTES_PER_TILE
],
tidx
);
Gemm1
gemm_q_k
(
&
smem_
[
Smem_tile_do
::
BYTES_PER_TILE
],
tidx
);
// Allocate the global memory tile loader for Q.
// Allocate the global memory tile loader for Q.
Gmem_tile_q
gmem_q
(
params
,
0
,
binfo
,
tidx
);
Gmem_tile_q
gmem_q
(
params
.
q_ptr
,
params
.
q_row_stride_in_elts
,
params
.
q_head_stride_in_elts
,
binfo
,
tidx
);
// Allocate the global memory tile loader for dQ.
// Allocate the global memory tile loader for dQ.
Gmem_tile_dq
gmem_dq
(
params
,
0
,
binfo
,
tidx
);
Gmem_tile_dq
gmem_dq
(
params
.
dqkv_ptr
,
params
.
q_row_stride_in_elts
,
params
.
q_head_stride_in_elts
,
binfo
,
tidx
);
Gmem_tile_dq_tmp
gmem_dq_tmp
(
params
.
o_tmp_ptr
,
params
.
o_stride_in_elts
,
binfo
,
tidx
);
Gmem_tile_dq_tmp
gmem_dq_tmp
(
params
.
o_tmp_ptr
,
params
.
o_
row_stride_in_elts
,
params
.
o_head_
stride_in_elts
,
binfo
,
tidx
);
// Allocate the global memory tile loader for S.
// Allocate the global memory tile loader for S.
Gmem_tile_s
gmem_s
(
params
,
binfo
,
tidx
);
Gmem_tile_s
gmem_s
(
params
,
binfo
,
tidx
);
fmha
::
Mask
<
Cta_tile_p
,
Is_causal
>
mask
(
binfo
,
tidx
,
loop_step_idx
);
fmha
::
Mask
<
Cta_tile_p
,
Is_causal
>
mask
(
binfo
,
tidx
,
loop_step_idx
);
// Allocate the global memory tile loader for K.
// Allocate the global memory tile loader for K.
Gmem_tile_k
gmem_k
(
params
,
1
,
binfo
,
tidx
);
Gmem_tile_k
gmem_k
(
params
.
k_ptr
,
params
.
k_row_stride_in_elts
,
params
.
k_head_stride_in_elts
,
binfo
,
tidx
);
// Allocate the global memory tile loader for V.
// Allocate the global memory tile loader for V.
Gmem_tile_v
gmem_v
(
params
,
2
,
binfo
,
tidx
);
Gmem_tile_v
gmem_v
(
params
.
v_ptr
,
params
.
v_row_stride_in_elts
,
params
.
v_head_stride_in_elts
,
binfo
,
tidx
);
// The base pointer of smem_v;
// The base pointer of smem_v;
char
*
smem_v_
=
&
smem_
[
Smem_tile_do
::
BYTES_PER_TILE
+
Gemm1
::
SMEM_OFFSET_V
];
char
*
smem_v_
=
&
smem_
[
Smem_tile_do
::
BYTES_PER_TILE
+
Gemm1
::
SMEM_OFFSET_V
];
...
@@ -154,7 +153,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng
...
@@ -154,7 +153,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng
Smem_tile_kt
smem_kt
(
&
smem_
[
Smem_tile_do
::
BYTES_PER_TILE
+
Gemm1
::
Smem_tile_q
::
BYTES_PER_TILE
],
tidx
);
Smem_tile_kt
smem_kt
(
&
smem_
[
Smem_tile_do
::
BYTES_PER_TILE
+
Gemm1
::
Smem_tile_q
::
BYTES_PER_TILE
],
tidx
);
// Allocate the global memory tile loader for dO.
// Allocate the global memory tile loader for dO.
Gmem_tile_do
gmem_do
(
params
.
do_ptr
,
params
,
binfo
,
tidx
);
Gmem_tile_do
gmem_do
(
params
.
do_ptr
,
params
.
o_row_stride_in_elts
,
params
.
o_head_stride_in_elts
,
binfo
,
tidx
);
// Allocate the shared memory tile loader for dO.
// Allocate the shared memory tile loader for dO.
Smem_tile_do
smem_do
(
&
smem_
[
0
],
tidx
);
Smem_tile_do
smem_do
(
&
smem_
[
0
],
tidx
);
Smem_tile_dot
smem_dot
(
&
smem_
[
0
],
tidx
);
Smem_tile_dot
smem_dot
(
&
smem_
[
0
],
tidx
);
...
@@ -166,7 +165,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng
...
@@ -166,7 +165,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng
Smem_tile_st
smem_dp
(
&
smem_
[
Smem_tile_do
::
BYTES_PER_TILE
+
Gemm1
::
SMEM_OFFSET_O
+
Smem_tile_dq
::
BYTES_PER_TILE
+
Smem_tile_st
::
BYTES_PER_TILE
],
tidx
);
Smem_tile_st
smem_dp
(
&
smem_
[
Smem_tile_do
::
BYTES_PER_TILE
+
Gemm1
::
SMEM_OFFSET_O
+
Smem_tile_dq
::
BYTES_PER_TILE
+
Smem_tile_st
::
BYTES_PER_TILE
],
tidx
);
// Allocate the global memory tile loader for O.
// Allocate the global memory tile loader for O.
Gmem_tile_o
gmem_o
(
params
.
o_ptr
,
params
,
binfo
,
tidx
);
Gmem_tile_o
gmem_o
(
params
.
o_ptr
,
params
.
o_row_stride_in_elts
,
params
.
o_head_stride_in_elts
,
binfo
,
tidx
);
// Allocate the shared memory tile loader for O. We use the same as K so be careful!!!
// Allocate the shared memory tile loader for O. We use the same as K so be careful!!!
Smem_tile_dq
smem_dq
(
&
smem_
[
Smem_tile_do
::
BYTES_PER_TILE
+
Gemm1
::
SMEM_OFFSET_O
],
tidx
);
Smem_tile_dq
smem_dq
(
&
smem_
[
Smem_tile_do
::
BYTES_PER_TILE
+
Gemm1
::
SMEM_OFFSET_O
],
tidx
);
...
@@ -654,11 +653,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng
...
@@ -654,11 +653,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng
__syncthreads
();
__syncthreads
();
uint4
dv_out
[
Smem_tile_dv
::
NUM_LDS
];
uint4
dv_out
[
Smem_tile_dv
::
NUM_LDS
];
smem_dv
.
load
(
dv_out
);
smem_dv
.
load
(
dv_out
);
Qkv_params
dv_params
;
Gmem_tile_dv
gmem_dv
(
params
.
dqkv_ptr
+
2
*
params
.
h
*
params
.
d
*
2
,
params
.
v_row_stride_in_elts
,
params
.
v_head_stride_in_elts
,
binfo
,
tidx
);
dv_params
.
qkv_ptr
=
params
.
dqkv_ptr
;
dv_params
.
qkv_stride_in_bytes
=
params
.
qkv_stride_in_bytes
;
dv_params
.
h
=
params
.
h
;
Gmem_tile_dv
gmem_dv
(
dv_params
,
2
,
binfo
,
tidx
);
if
(
!
Is_first
)
{
if
(
!
Is_first
)
{
gmem_dv
.
move
(
loop_step_idx
);
gmem_dv
.
move
(
loop_step_idx
);
}
}
...
@@ -669,11 +664,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng
...
@@ -669,11 +664,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng
// for (int ii = 0; ii < Smem_tile_dk::NUM_LDS; ++ii) {
// for (int ii = 0; ii < Smem_tile_dk::NUM_LDS; ++ii) {
// dk_out[ii] = fmha::fmul4(dk_out[ii], params.scale_bmm1f);
// dk_out[ii] = fmha::fmul4(dk_out[ii], params.scale_bmm1f);
// }
// }
Qkv_params
dk_params
;
Gmem_tile_dk
gmem_dk
(
params
.
dqkv_ptr
+
params
.
h
*
params
.
d
*
2
,
params
.
k_row_stride_in_elts
,
params
.
k_head_stride_in_elts
,
binfo
,
tidx
);
dk_params
.
qkv_ptr
=
params
.
dqkv_ptr
;
dk_params
.
qkv_stride_in_bytes
=
params
.
qkv_stride_in_bytes
;
dk_params
.
h
=
params
.
h
;
Gmem_tile_dk
gmem_dk
(
dk_params
,
1
,
binfo
,
tidx
);
if
(
!
Is_first
)
{
if
(
!
Is_first
)
{
gmem_dk
.
move
(
loop_step_idx
);
gmem_dk
.
move
(
loop_step_idx
);
}
}
...
...
csrc/flash_attn/src/fmha_fprop_kernel_1xN.h
View file @
5d07483b
...
@@ -247,10 +247,10 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i
...
@@ -247,10 +247,10 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i
Gemm1
gemm_q_k
(
smem_
,
tidx
);
Gemm1
gemm_q_k
(
smem_
,
tidx
);
// Allocate the global memory tile loader for Q.
// Allocate the global memory tile loader for Q.
Gmem_tile_q
gmem_q
(
params
,
0
,
binfo
,
tidx
);
Gmem_tile_q
gmem_q
(
params
.
q_ptr
,
params
.
q_row_stride_in_elts
,
params
.
q_head_stride_in_elts
,
binfo
,
tidx
);
// Allocate the global memory tile loader for O.
// Allocate the global memory tile loader for O.
Gmem_tile_o
gmem_o
(
params
,
binfo
,
tidx
);
Gmem_tile_o
gmem_o
(
params
.
o_ptr
,
params
.
o_row_stride_in_elts
,
params
.
o_head_stride_in_elts
,
binfo
,
tidx
);
Gmem_tile_o_tmp
gmem_o_tmp
(
params
.
o_tmp_ptr
,
params
.
o_stride_in_elts
,
binfo
,
tidx
);
Gmem_tile_o_tmp
gmem_o_tmp
(
params
.
o_tmp_ptr
,
params
.
o_
row_stride_in_elts
,
params
.
o_head_
stride_in_elts
,
binfo
,
tidx
);
// Allocate the global memory tile loader for S.
// Allocate the global memory tile loader for S.
Gmem_tile_s
gmem_s
(
params
,
binfo
,
tidx
);
Gmem_tile_s
gmem_s
(
params
,
binfo
,
tidx
);
Gmem_softmax_sum
gmem_softmax_lse
(
params
.
softmax_lse_ptr
,
params
,
tidx
);
Gmem_softmax_sum
gmem_softmax_lse
(
params
.
softmax_lse_ptr
,
params
,
tidx
);
...
@@ -273,9 +273,9 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i
...
@@ -273,9 +273,9 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i
fmha
::
Mask
<
Cta_tile_p
,
Is_causal
>
mask
(
binfo
,
tidx
,
loop_step_idx
);
fmha
::
Mask
<
Cta_tile_p
,
Is_causal
>
mask
(
binfo
,
tidx
,
loop_step_idx
);
// Allocate the global memory tile loader for K.
// Allocate the global memory tile loader for K.
Gmem_tile_k
gmem_k
(
params
,
1
,
binfo
,
tidx
);
Gmem_tile_k
gmem_k
(
params
.
k_ptr
,
params
.
k_row_stride_in_elts
,
params
.
k_head_stride_in_elts
,
binfo
,
tidx
);
// Allocate the global memory tile loader for V.
// Allocate the global memory tile loader for V.
Gmem_tile_v
gmem_v
(
params
,
2
,
binfo
,
tidx
);
Gmem_tile_v
gmem_v
(
params
.
v_ptr
,
params
.
v_row_stride_in_elts
,
params
.
v_head_stride_in_elts
,
binfo
,
tidx
);
// The base pointer of smem_v;
// The base pointer of smem_v;
char
*
smem_v_
=
&
smem_
[
Gemm1
::
SMEM_OFFSET_V
];
char
*
smem_v_
=
&
smem_
[
Gemm1
::
SMEM_OFFSET_V
];
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment