gaoqiong / flash-attention

Commit 0d854692
authored Jun 05, 2022 by Tri Dao

    Implement fwd for head dim 128

parent 0a398dfc
Showing 4 changed files with 105 additions and 10 deletions (+105 -10):
    csrc/flash_attn/fmha_api.cpp                        +2  -1
    csrc/flash_attn/src/fmha/smem_tile.h                +60 -5
    csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu  +32 -2
    csrc/flash_attn/src/fmha_fprop_kernel_1xN.h         +11 -2
csrc/flash_attn/fmha_api.cpp
@@ -118,6 +118,7 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, tot
     auto dprops = at::cuda::getCurrentDeviceProperties();
     bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
+    bool is_sm80 = dprops->major == 8 && dprops->minor == 0;
     TORCH_CHECK((dprops->major == 8 && dprops->minor >= 0) || is_sm75);
     auto stream = at::cuda::getCurrentCUDAStream().stream();
     bool is_dropout = p_dropout > 0.0;
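
The new is_sm80 flag distinguishes A100-class devices (SM 8.0) from the other SM 8.x parts that the existing TORCH_CHECK already admits. As a minimal, self-contained sketch of the same gating (the wrapper function check_arch is hypothetical, not part of the source):

    #include <ATen/cuda/CUDAContext.h>
    #include <torch/extension.h>

    // Hypothetical standalone wrapper around the arch checks in mha_fwd.
    void check_arch() {
        auto dprops = at::cuda::getCurrentDeviceProperties();
        bool is_sm75 = dprops->major == 7 && dprops->minor == 5;  // Turing
        bool is_sm80 = dprops->major == 8 && dprops->minor == 0;  // A100-class
        // Any SM 8.x or SM 7.5 device is accepted; is_sm80 is used further
        // down to restrict the wider d=128 tile to A100, which has more
        // shared memory per SM than other SM 8.x parts.
        TORCH_CHECK((dprops->major == 8 && dprops->minor >= 0) || is_sm75);
    }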
@@ -144,7 +145,7 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, tot
     TORCH_CHECK(head_size == 16 || head_size == 32 || head_size == 64 || head_size == 128);
     // int base_N = head_size == 16 ? 512 : (head_size == 128 ? 128 : 256);
-    int base_N = (head_size == 128 || (is_sm75 && head_size == 64 && is_dropout)) ? 128 : 256;
+    int base_N = ((head_size == 128 && (is_dropout || !is_sm80)) || (is_sm75 && head_size == 64 && is_dropout)) ? 128 : 256;
     // int base_N = 256;
     int seq_len = 512;
     if( max_seq_len <= 128 ) {
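
Spelled out, the updated rule lets head dim 128 use the wider 256-column base tile only on SM 8.0 with dropout disabled; every other d=128 case, and the SM 7.5 d=64 dropout case, stays at 128. A small sketch restating the condition (the helper choose_base_N is hypothetical, not in the source):

    #include <cassert>

    // Hypothetical helper restating the base_N condition from mha_fwd.
    int choose_base_N(int head_size, bool is_dropout, bool is_sm75, bool is_sm80) {
        assert(head_size == 16 || head_size == 32 || head_size == 64 || head_size == 128);
        // d=128 only gets the wide tile on A100 with dropout disabled.
        if (head_size == 128 && (is_dropout || !is_sm80)) return 128;
        // Turing with d=64 plus dropout also needs the narrower tile.
        if (is_sm75 && head_size == 64 && is_dropout) return 128;
        return 256;
    }

    // e.g. choose_base_N(128, /*dropout=*/false, false, /*sm80=*/true) == 256,
    //      choose_base_N(128, /*dropout=*/true,  false, /*sm80=*/true) == 128.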
csrc/flash_attn/src/fmha/smem_tile.h
@@ -1054,6 +1054,14 @@ struct Smem_tile_o {
         constexpr int STS_PER_WARP = 16 * Mma_tile::MMAS_N / ELEMENTS_PER_STS;
         int write_col = warp * STS_PER_WARP + lane % STS_PER_WARP;
+        // if ((threadIdx.x == 16) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
+        //     printf("write_row = %d, write_col = %d\n", write_row, write_col);
+        // }
+        // if ((blockIdx.x == 0) && (blockIdx.y == 0) && (write_row == 0) && (write_col == 0)) {
+        //     printf("threadIdx.x = %d\n", threadIdx.x);
+        // }
         // Assemble the write pointer.
         smem_write_ = smem_ + write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS;
@@ -1062,9 +1070,15 @@ struct Smem_tile_o {
         int read_col = tidx % THREADS_PER_ROW;
         // Take the XOR pattern into account for the column.
-        read_col ^= 2 * (read_row % (Cta_tile::N == 16 ? 2 : (Cta_tile::N == 32 ? 4 : 8)));
+        read_col ^= 2 * (read_row % (Cta_tile::N == 16 ? 2 : (Cta_tile::N == 32 ? 4 : (Cta_tile::N == 128 ? 16 : 8))));
+        // if ((threadIdx.x == 8) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
+        //     printf("read_row = %d, read_col = %d\n", read_row, read_col);
+        // }
+        // if ((blockIdx.x == 0) && (blockIdx.y == 0) && (read_row == 0) && (read_col == 0)) {
+        //     printf("threadIdx.x = %d\n", threadIdx.x);
+        // }
         // Assemble the read pointer.
         this->smem_read_ = smem_ + read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS;
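
The only functional change in this hunk is the XOR period: with Cta_tile::N == 128 the read column now swizzles over 16 rows instead of 8. A host-side sketch (hypothetical, not in the source) that simply evaluates the new expression:

    #include <cstdio>

    // Evaluate the swizzled read column for a given row, mirroring the
    // ternary chain above. Illustration only.
    int swizzled_read_col(int read_row, int read_col, int N) {
        int period = N == 16 ? 2 : (N == 32 ? 4 : (N == 128 ? 16 : 8));
        return read_col ^ (2 * (read_row % period));
    }

    int main() {
        // For N=128 the XOR pattern now repeats every 16 rows, not 8.
        for (int row = 0; row < 18; row += 4)
            printf("row %2d: col 0 -> %2d\n", row, swizzled_read_col(row, 0, 128));
        return 0;
    }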
@@ -1085,16 +1099,31 @@ struct Smem_tile_o {
             #pragma unroll
             for( int jj = 0; jj < Cta_tile::WARPS_K; ++jj ) {
                 int imm = ii * ROWS_PER_LDS * BYTES_PER_ROW + jj * Cta_tile::N * BYTES_PER_ELEMENT;
+                uint32_t smem_read = this->smem_read_ + imm;
+                // TD [2022-06-05] Ugly fix for d=128, maybe there's a better way.
+                if ((Cta_tile::N == 128) && (ii % 2 == 1)) { smem_read ^= 8 * BYTES_PER_LDS; }
+                // if ((threadIdx.x == 8) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
+                //     printf("imm diff = %d\n", smem_read - this->smem_read_);
+                // }
                 if( !HAS_INCOMPLETE_LDS || (ii < LDS_PER_LOOP - 1 || this->is_active_for_last_lds_) ) {
-                    fmha::lds(tmp[jj], this->smem_read_ + imm);
+                    // fmha::lds(tmp[jj], this->smem_read_ + imm);
+                    fmha::lds(tmp[jj], smem_read);
                 }
             }

             // Perform the reduction.
             out[ii] = zero_init ? tmp[0] : fmha::fadd4(out[ii], tmp[0]);
+            // if ((threadIdx.x == 8) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
+            //     printf("out reduction: out = %.6f\n", reinterpret_cast<float (&)[4]>(out[ii])[0]);
+            // }

             #pragma unroll
             for( int jj = 1; jj < Cta_tile::WARPS_K; ++jj ) {
                 out[ii] = fmha::fadd4(out[ii], tmp[jj]);
+                // if ((threadIdx.x == 8) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
+                //     printf("out reduction tmp = %.6f, out = %.6f\n", reinterpret_cast<float (&)[4]>(tmp[jj])[0], reinterpret_cast<float (&)[4]>(out[ii])[0]);
+                // }
             }
         }
     }
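
The "ugly fix" XORs the read address by 8 * BYTES_PER_LDS on odd ii iterations, and only for the N=128 tile, presumably so the loads land in the half of the swizzled row where the store phase actually placed the data. A throwaway host sketch (hypothetical; BYTES_PER_LDS = 16, i.e. one 128-bit load per thread, is an assumption) that just tabulates the extra offset per iteration:

    #include <cstdint>
    #include <cstdio>

    int main() {
        const int BYTES_PER_LDS = 16;  // assumed: one 128-bit LDS per thread
        const int N = 128;             // Cta_tile::N for head dim 128
        for (int ii = 0; ii < 4; ++ii) {
            uint32_t extra = 0;
            if (N == 128 && ii % 2 == 1)
                extra ^= 8 * BYTES_PER_LDS;  // flip into the other 128B half
            printf("ii = %d -> extra XOR = %u bytes\n", ii, extra);
        }
        return 0;
    }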
@@ -1102,6 +1131,7 @@ struct Smem_tile_o {
     // Store the accumulators.
     template< int M, int N >
     inline __device__ void store(const Accumulator (&acc)[M][N], int mi) {
+        // uint32_t smem_write_og = this->smem_write_;
         static constexpr int M_PER_MMA = Mma_tile::M_PER_MMA_PER_CTA;
         #pragma unroll
         for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) {
@@ -1126,7 +1156,15 @@ struct Smem_tile_o {
                 fmha::sts(this->smem_write_ + row_0, tmp0);
                 fmha::sts(this->smem_write_ + row_1, tmp1);
             }
+            // if ((threadIdx.x == 16) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
+            //     printf("smem_write diff = %d\n", this->smem_write_ - smem_write_og);
+            // }
+            // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
+            //     uint4 read_tmp;
+            //     fmha::lds(read_tmp, this->smem_read_);
+            //     printf("smem_o = %.6f\n", reinterpret_cast<float (&)[4]>(read_tmp)[0]);
+            // }
             // Swizzle the write pointer using a XOR of 16B.
             this->smem_write_ ^= 32;
@@ -1148,8 +1186,25 @@ struct Smem_tile_o {
                 fmha::sts(this->smem_write_ + row_1, tmp1);
             }
+            // if ((threadIdx.x == 16) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
+            //     printf("smem_write diff = %d\n", this->smem_write_ - smem_write_og);
+            // }
             // Cancel the previous XOR of 1 + swizzle the write pointer using a XOR of 32B or 64B.
-            this->smem_write_ ^= (ni & 1) ? 7 * 32 : 3 * 32;
+            static_assert(Mma_tile::MMAS_N <= 8, "Not implemented");
+            if( Mma_tile::MMAS_N >= 8 && ni % 4 == 3 ) {
+                this->smem_write_ ^= 15 * 32;
+            } else if( Mma_tile::MMAS_N >= 4 && ni % 2 == 1 ) {
+                this->smem_write_ ^= 7 * 32;
+            } else if( Mma_tile::MMAS_N >= 2 ) {
+                this->smem_write_ ^= 3 * 32;
+            }
+            // this->smem_write_ ^= (ni & 1) ? 7 * 32 : 3 * 32;
+            // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
+            //     uint4 read_tmp;
+            //     fmha::lds(read_tmp, this->smem_read_);
+            //     printf("smem_o = %.6f\n", reinterpret_cast<float (&)[4]>(read_tmp)[0]);
+            // }
         }
     }
 };
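
The old two-term pattern (XOR by 7*32 or 3*32 depending on ni's parity) cycles the write pointer through only four 32B slots, which stops being enough once MMAS_N reaches 8, as it does for the N=128 tile. One way to check the new chain is to trace the cumulative swizzle offset on the host; the sketch below is hypothetical and models only the XOR chain above, for MMAS_N = 8:

    #include <cstdio>

    int main() {
        const int MMAS_N = 8;     // e.g. the d=128 tile
        unsigned offset = 0;      // track only the swizzle bits of smem_write_
        for (int ni = 0; ni < MMAS_N; ++ni) {
            printf("ni = %d: write offset = %3u bytes (slot %u)\n",
                   ni, offset, offset / 32);
            if (MMAS_N >= 8 && ni % 4 == 3)        offset ^= 15 * 32;
            else if (MMAS_N >= 4 && ni % 2 == 1)   offset ^= 7 * 32;
            else if (MMAS_N >= 2)                  offset ^= 3 * 32;
        }
        return 0;
    }

Running this, the pointer visits eight distinct 32B slots (0, 3, 4, 7, 8, 11, 12, 15) before wrapping back to 0, whereas the old pattern repeats the same four slots (0, 3, 4, 7) every four iterations.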
csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu
@@ -121,8 +121,21 @@ void run_fmha_fp16_sm80(Launch_params<Fused_multihead_attention_fprop_params> &l
             }
         }
     } else if (launch_params.params.d == 128) {
-        using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u>;
-        run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+        if( launch_params.params.s == 128 ) {
+            using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u>;
+            run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+        } else {
+            auto dprops = at::cuda::getCurrentDeviceProperties();
+            if (dprops->major == 8 && dprops->minor >= 0 && !is_dropout) {
+                // TD [2022-06-05] Keep K in registers to reduce register spilling
+                // Gives about 6% speedup compared to using block size 128.
+                using Kernel_traits = FMHA_kernel_traits<256, 128, 16, 1, 4, 0x18u>;
+                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+            } else {
+                // Need to use the same block size as backward
+                using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u>;
+                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+            }
+        }
     }
     // if (launch_params.params.d == 64) {
     //     // using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
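
The d=128 dispatch now picks between three launch configurations, where the first traits argument is the K/V tile length and the 0x18u flag word corresponds to the "keep K in registers" variant described in the comment. A hypothetical summary helper (choose_fwd128_config and its struct are not in the source; the values are taken from the dispatch above):

    // Sketch only: restates the three d=128 forward configurations.
    struct Fwd128Config { int s_tile; unsigned flags; };

    Fwd128Config choose_fwd128_config(int seqlen, bool is_sm80, bool is_dropout) {
        if (seqlen == 128)          return {128, 0x08u};  // short sequences
        if (is_sm80 && !is_dropout) return {256, 0x18u};  // ~6% faster, K in registers
        return {128, 0x08u};                              // match backward's block size
    }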
@@ -151,4 +164,21 @@ void run_fmha_fp16_sm80(Launch_params<Fused_multihead_attention_fprop_params> &l
     //         }
     //     }
     // }
+    // if (launch_params.params.d == 128) {
+    //     if( launch_params.params.s == 128 ) {
+    //         using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u>;
+    //         run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+    //     } else {
+    //         auto dprops = at::cuda::getCurrentDeviceProperties();
+    //         if (dprops->major == 8 && dprops->minor >= 0 && !is_dropout) {
+    //             // TD [2022-06-05] Keep K in registers to reduce register spilling
+    //             // Gives about 6% speedup compared to using block size 128.
+    //             using Kernel_traits = FMHA_kernel_traits<256, 128, 16, 1, 4, 0x18u>;
+    //             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+    //         } else { // Need to use the same block size as backward
+    //             using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u>;
+    //             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+    //         }
+    //     }
+    // }
 }
\ No newline at end of file
csrc/flash_attn/src/fmha_fprop_kernel_1xN.h
@@ -498,10 +498,19 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
     #pragma unroll
     for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {
         fmha::gemm_cl(acc_o, frag_p[ki], frag_v[ki]);
+        // if ((threadIdx.x == 4) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {
+        //     float2 tmp_p = __half22float2(reinterpret_cast<__half2 &>(frag_p[ki]));
+        //     float2 tmp_v = __half22float2(reinterpret_cast<__half2 &>(frag_v[ki]));
+        //     printf("Per warp, threadIdx.x = %d, frag_p = %.6f, %.6f, frag_v = %.6f, %.6f, acc_o=%.6f\n", threadIdx.x, tmp_p.x, tmp_p.y, tmp_v.x, tmp_v.y, acc_o[0][0].elt(0));
+        // }
     }
+    // if ((threadIdx.x % 32 == 16) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {
+    //     printf("Per warp, threadIdx.x = %d, acc_o=%.6f\n", threadIdx.x, acc_o[0][2].elt(0));
+    // }

-    // The mapping from tidx to rows changes between the softmax and the O-reduction.
-    // So we recalculate the max.
+    // The mapping from tidx to rows changes between the softmax and the
+    // O-reduction. So we recalculate the max.
     float p_max_o[Gmem_tile_o::STGS_PER_LOOP][Mma_tile_o::MMAS_M];
     // TODO: not sure if this is right for seqlen 128 or 256
     int rows[Gmem_tile_o::STGS_PER_LOOP];