gaoqiong / flash-attention · Commits

commit d3e64409 (parent 0d854692)
Author: Tri Dao
Date:   Jun 11, 2022

    Implement bwd for head dim 128

Showing 7 changed files with 56 additions and 36 deletions (+56 -36).
README.md                                                 +2   -2
csrc/flash_attn/fmha_api.cpp                              +12  -20
csrc/flash_attn/src/fmha/smem_tile.h                      +20  -7
csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu   +4   -0
csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h          +16  -2
csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu        +2   -2
csrc/flash_attn/src/fmha_fprop_kernel_1xN.h               +0   -3
README.md

@@ -24,7 +24,7 @@ PYTHONPATH=$PWD python benchmarks/benchmark_flash_attention.py
 FlashAttention currently supports:
 1. Turing or Ampere GPUs (e.g., A100, RTX 3090, T4, RTX 2080).
 2. fp16.
-3. Head dimensions 16, 32, 64.
+3. Head dimensions 16, 32, 64, 128 (bwd requires A100).

 Our tentative roadmap:
 1. [Jun 2022] Make package pip-installable.

@@ -32,7 +32,7 @@ Our tentative roadmap:
 3. [Jun 2022] Refactor to use Cutlass.
 4. ~~[Jun 2022] Support SM75 GPUs (e.g. T4)~~[Done].
 5. [Jun 2022] Support bf16.
-6. [Jul 2022] Support head dimension 128.
+6. ~~[Jul 2022] Support head dimension 128~~[Done].
 7. [Jul 2022] Support SM70 GPUs (V100).
 8. [Aug 2022] Fuse rotary embedding.
 9. [Aug 2022] Support Attention linear bias (e.g. ALiBi).
csrc/flash_attn/fmha_api.cpp

@@ -144,9 +144,7 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, tot
     TORCH_CHECK(batch_size > 0);
     TORCH_CHECK(head_size == 16 || head_size == 32 || head_size == 64 || head_size == 128);
-    // int base_N = head_size == 16 ? 512 : (head_size == 128 ? 128 : 256);
     int base_N = ((head_size == 128 && (is_dropout || !is_sm80)) || (is_sm75 && head_size == 64 && is_dropout)) ? 128 : 256;
-    // int base_N = 256;
     int seq_len = 512;
     if (max_seq_len <= 128) {
         seq_len = 128;
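For readability, the forward-pass base_N selection above can be restated as a small pure function. The sketch below is illustrative only — the helper name choose_base_N_fwd and the test harness are not part of the repository — but the ternary expression is copied verbatim from the hunk above.

```cpp
#include <cstdio>

// Illustrative restatement of the forward-pass base_N selection shown above.
// base_N appears to be the tile size along the key sequence dimension: it drops
// to 128 for d=128 with dropout or on non-SM80 GPUs, and for SM75 at d=64 with
// dropout; otherwise it stays at 256.
int choose_base_N_fwd(int head_size, bool is_dropout, bool is_sm75, bool is_sm80) {
    return ((head_size == 128 && (is_dropout || !is_sm80)) ||
            (is_sm75 && head_size == 64 && is_dropout)) ? 128 : 256;
}

int main() {
    std::printf("%d\n", choose_base_N_fwd(128, false, false, true));  // A100, d=128, no dropout -> 256
    std::printf("%d\n", choose_base_N_fwd(128, true,  false, true));  // A100, d=128, dropout    -> 128
    std::printf("%d\n", choose_base_N_fwd(64,  true,  true,  false)); // T4,   d=64,  dropout    -> 128
    return 0;
}
```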
@@ -162,18 +160,13 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, tot
     auto ctx = torch::empty({total, num_heads, head_size}, opts);
     at::Tensor o_tmp;
-    if (loop) {
-        o_tmp = torch::empty({total, num_heads, head_size}, opts.dtype(at::kFloat));
-    }
+    if (loop) { o_tmp = torch::empty({total, num_heads, head_size}, opts.dtype(at::kFloat)); }
     auto softmax_lse = torch::empty({batch_size, num_heads, seq_len}, opts.dtype(at::kFloat));
     // auto softmax_lse = torch::full({batch_size, num_heads, seq_len}, -std::numeric_limits<float>::infinity(), opts.dtype(at::kFloat));
     at::Tensor s;
-    if (return_softmax) {
-        s = torch::empty({ batch_size, num_heads, seq_len, seq_len }, opts);
-        // s = torch::ones({ batch_size, num_heads, seq_len, seq_len }, opts) * 10000.0;
-    }
+    if (return_softmax) { s = torch::empty({ batch_size, num_heads, seq_len, seq_len }, opts); }
     if (zero_tensors) {
         ctx.zero_();
@@ -228,7 +221,7 @@ mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size
         const at::Tensor &qkv,            // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i
         const at::Tensor &out,            // total x num_heads x head_size
         at::Tensor &softmax,              // b x h x s x s softmax and dmask - will be overwritten with dP
-        const at::Tensor &softmax_lse,    // b x h x s softmax logsumexp
+        const at::Tensor &softmax_lse_,   // b x h x s softmax logsumexp
         const at::Tensor &cu_seqlens,     // b+1
         const float p_dropout,            // probability to drop
         const float softmax_scale,
@@ -239,6 +232,7 @@ mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size
 ) {
     auto dprops = at::cuda::getCurrentDeviceProperties();
     bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
+    bool is_sm80 = dprops->major == 8 && dprops->minor == 0;
     TORCH_CHECK((dprops->major == 8 && dprops->minor >= 0) || is_sm75);
     auto launch = &run_fmha_dgrad_fp16_sm80;
@@ -269,8 +263,10 @@ mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size
     const int head_size = sizes[D_DIM];
     TORCH_CHECK(batch_size > 0);
     TORCH_CHECK(head_size == 16 || head_size == 32 || head_size == 64 || head_size == 128);
+    if (head_size == 128) {  // TODO: eventually we should support SM86 and SM70 with d=128 as well
+        TORCH_CHECK(is_sm80);
+    }
-    // int base_N = head_size == 16 ? 512 : (head_size == 128 ? 128 : 256);
     int base_N = (head_size == 128 || (is_sm75 && head_size == 64)) ? 128 : 256;
     int seq_len = 512;
     if (max_seq_len <= 128) {
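The new guard above is what makes the d=128 backward A100-only for now (is_sm80 means compute capability exactly 8.0). As a rough, standalone illustration of the same capability check expressed directly against the CUDA runtime — not the extension's actual code path, and bwd_head_dim_supported is a made-up helper name:

```cpp
#include <cuda_runtime.h>
#include <cstdio>

// Mirrors the checks above: head dims 16/32/64/128 on SM75 or SM8x,
// with the d=128 backward additionally restricted to SM80 (A100).
bool bwd_head_dim_supported(const cudaDeviceProp& prop, int head_size) {
    const bool is_sm75 = prop.major == 7 && prop.minor == 5;
    const bool is_sm80 = prop.major == 8 && prop.minor == 0;
    const bool arch_ok = (prop.major == 8 && prop.minor >= 0) || is_sm75;
    const bool dim_ok  = head_size == 16 || head_size == 32 ||
                         head_size == 64 || head_size == 128;
    if (!arch_ok || !dim_ok) { return false; }
    return head_size != 128 || is_sm80;  // d=128 bwd requires A100 at this point
}

int main() {
    cudaDeviceProp prop{};
    if (cudaGetDeviceProperties(&prop, 0) != cudaSuccess) { return 1; }
    std::printf("d=128 bwd supported on this GPU (sm_%d%d): %s\n",
                prop.major, prop.minor,
                bwd_head_dim_supported(prop, 128) ? "yes" : "no");
    return 0;
}
```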
@@ -282,18 +278,14 @@ mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size
     }
     bool loop = seq_len > base_N;
+    // It's possible the softmax_lse_ from the fwd has a different length since base_N could be different.
+    auto softmax_lse = softmax_lse_.index({torch::indexing::Slice(), torch::indexing::Slice(), torch::indexing::Slice(torch::indexing::None, seq_len)}).contiguous();
     auto dqkv = torch::empty_like(qkv);
     auto opts = qkv.options();
-    // auto softmax_lse =
-    //     torch::empty({batch_size, num_heads, seq_len}, opts.dtype(at::kFloat));
     auto softmax_d = torch::empty({batch_size, num_heads, seq_len}, opts.dtype(at::kFloat));
-    // softmax.zero_();
-    // torch::nn::init::ones_(softmax);
-    // torch::nn::init::ones_(dqkv);
     at::Tensor dq_tmp;
-    if (loop) {
-        dq_tmp = torch::empty({total, num_heads, head_size}, opts.dtype(at::kFloat));
-    }
+    if (loop) { dq_tmp = torch::empty({total, num_heads, head_size}, opts.dtype(at::kFloat)); }
     if (zero_tensors) {
         dqkv.zero_();
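The new softmax_lse line above is the reason for the softmax_lse_ rename in the signature: as the added comment says, the forward pass may have sized its logsumexp tensor for a different sequence length (its base_N can differ from the backward's), so the backward slices it down to its own seq_len. A minimal, self-contained libtorch sketch of that indexing call — shapes and names here are invented for illustration:

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
    using torch::indexing::None;
    using torch::indexing::Slice;

    // Pretend the forward pass sized its logsumexp for seq_len = 256 ...
    auto softmax_lse_ = torch::randn({2 /*batch*/, 4 /*heads*/, 256 /*fwd seq_len*/});

    // ... while the backward pass only needs the first 128 positions.
    const int64_t seq_len = 128;
    auto softmax_lse = softmax_lse_
                           .index({Slice(), Slice(), Slice(None, seq_len)})
                           .contiguous();

    std::cout << softmax_lse.sizes() << std::endl;  // [2, 4, 128]
    return 0;
}
```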
@@ -324,7 +316,7 @@ mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size
         gen_, at::cuda::detail::getDefaultCUDAGenerator());
     // We're gonna reset the rng state in Python after this kernel, so the counter offset
-    // here doesn't matter at all. We just choose an arbitrary number;
+    // here doesn't matter at all. We just choose an arbitrary number.
     int64_t counter_offset = 4;
     if (is_dropout) {
csrc/flash_attn/src/fmha/smem_tile.h

@@ -847,6 +847,7 @@ struct Smem_tile_row_b : public Smem_tile_without_skews<Cta_tile,
         // The size in bytes of the data needed to compute an MMA per CTA.
         const int BYTES_PER_MMA_PER_CTA = Mma_tile::N_PER_MMA_PER_CTA * BITS_PER_ELT / 8;
+        // uint32_t smem_read_og = this->smem_ + this->smem_read_offset_;
         #pragma unroll
         for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) {
             // Prepare the offset.

@@ -872,6 +873,9 @@ struct Smem_tile_row_b : public Smem_tile_without_skews<Cta_tile,
                 lds(tmp.w, (ptr ^ 32) + 4 * Base::BYTES_PER_ROW_BEFORE_PACKING);
             }
+            // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
+            //     printf("BYTES_PER_MMA_PER_CTA=%d, ni = %d, smem_read diff = %d\n", BYTES_PER_MMA_PER_CTA, ni, ptr - smem_read_og);
+            // }
             // Store those values in the fragment.
             b[ni].reg(0) = tmp.x;
             b[ni].reg(1) = tmp.y;

@@ -885,6 +889,8 @@ struct Smem_tile_row_b : public Smem_tile_without_skews<Cta_tile,
                 this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA;
             } else if( BYTES_PER_MMA_PER_CTA == 64 ) {
                 // Nothing to do!
+            } else if( BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 8 ) {
+                this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 4 == 3 ? 14 : (ni % 2 == 1 ? 6 : 2));
             } else if( BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 4 ) {
                 this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 2 == 0 ? 2 : 6);
             } else if( BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 2 ) {
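The new MMAS_N == 8 branch above extends the read-offset swizzle to tiles with eight MMAs along N, which is the d=128 case this commit targets. A small standalone sketch of just that XOR pattern, assuming BYTES_PER_LDS = 16 (a typical value for 128-bit loads, not something read from this diff):

```cpp
#include <cstdio>

// Replays the XOR applied after each ni in the new MMAS_N == 8 branch:
//     offset ^= BYTES_PER_LDS * (ni % 4 == 3 ? 14 : (ni % 2 == 1 ? 6 : 2));
// Printing the running offset shows it visits eight distinct 32-byte slots
// (0, 32, 64, ..., 224) and then wraps back to 0.
int main() {
    const int BYTES_PER_LDS = 16;  // assumption, not taken from the diff
    int offset = 0;
    for (int ni = 0; ni < 8; ++ni) {
        std::printf("ni=%d reads at relative offset %3d\n", ni, offset);
        offset ^= BYTES_PER_LDS * (ni % 4 == 3 ? 14 : (ni % 2 == 1 ? 6 : 2));
    }
    return 0;
}
```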
@@ -1100,8 +1106,8 @@ struct Smem_tile_o {
             for( int jj = 0; jj < Cta_tile::WARPS_K; ++jj ) {
                 int imm = ii * ROWS_PER_LDS * BYTES_PER_ROW + jj * Cta_tile::N * BYTES_PER_ELEMENT;
                 uint32_t smem_read = this->smem_read_ + imm;
-                // TD [2022-06-05] Ugly fix for d=128, maybe there's a better way.
-                if ((Cta_tile::N == 128) && (ii % 2 == 1)) {
+                // TD [2022-06-05] Ugly fix for d=128 in the forward pass, maybe there's a better way.
+                if ((Cta_tile::N == 128) && (ROWS_PER_LDS == 4) && (ii % 2 == 1)) {
                     smem_read ^= 8 * BYTES_PER_LDS;
                 }
                 // if ((threadIdx.x == 8) && (blockIdx.x == 0) && (blockIdx.y == 0)) {

@@ -1232,16 +1238,17 @@ struct Smem_tile_mma {
         uint32_t smem_ = __nvvm_get_smem_pointer(smem);
         int write_col, write_row;
-        static_assert(WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8) || (WARPS_M == 4 || WARPS_N == 8) || WARPS_N == 1);
+        static_assert(WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8) || (WARPS_M == 4 || WARPS_M == 8) || WARPS_N == 1);
         if( WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8) ) {
             write_row = (tidx & 0x1c) / 4;
             write_col = (tidx & 0xe0) / 4 + (tidx & 0x03);
+            write_col ^= (write_row & 0x07) * 4;
         } else {
             write_row = (tidx & 0xe0) / 2 + (tidx & 0x1c) / 4;
             write_col = (tidx & 0x03);
+            // write_col ^= (write_row & (BYTES_PER_ROW == 32 ? 0x01 : (BYTES_PER_ROW == 64 ? 0x03 : (BYTES_PER_ROW == 128 ? 0x07 : 0x0f)))) * 4;
+            write_col ^= (write_row & (BYTES_PER_ROW == 32 ? 0x01 : (BYTES_PER_ROW == 64 ? 0x03 : (BYTES_PER_ROW == 128 ? 0x07 : 0x07)))) * 4;
         }
-        // TODO [TD] Only works for, D=16, D=32 or D=64
-        write_col ^= (write_row & (BYTES_PER_ROW == 32 ? 0x01 : (BYTES_PER_ROW == 64 ? 0x03 : 0x07))) * 4;
         // write_offset_ = write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS;
         smem_write_ = smem_ + write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS;
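In the Smem_tile_mma constructor above, the column swizzle now lives inside each branch: the first layout XORs the write column with the low three bits of the row, while the second keeps the BYTES_PER_ROW-dependent mask (with 128-byte rows also mapping to 0x07), replacing the old TODO that only handled D = 16/32/64. The sketch below visualizes the simpler read-path form of the same idea, col ^ (row & mask); the usual motivation for such XOR swizzles is avoiding shared-memory bank conflicts, which is background knowledge rather than something stated in this diff (the write path additionally scales the XOR by 4).

```cpp
#include <cstdio>

// Swizzled column = col ^ (row & mask), as in the read-path expressions above.
// With BYTES_PER_ROW == 128 the mask is 0x07, so eight consecutive rows map a
// given logical column to eight different physical columns before repeating.
int swizzle_col(int row, int col, int bytes_per_row) {
    int mask = bytes_per_row == 32 ? 0x01 : (bytes_per_row == 64 ? 0x03 : 0x07);
    return col ^ (row & mask);
}

int main() {
    const int bytes_per_row = 128;  // e.g. the wider rows that come with d=128 tiles
    for (int row = 0; row < 8; ++row) {
        std::printf("row %d: col 0 -> %d, col 1 -> %d\n",
                    row, swizzle_col(row, 0, bytes_per_row), swizzle_col(row, 1, bytes_per_row));
    }
    return 0;
}
```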
@@ -1309,7 +1316,8 @@ struct Smem_tile_mma_transposed : public Base {
         read_row = (tidx & 0x0f);
         read_col = (tidx & 0xe0) / 16 + (tidx & 0x1c) / 16;
-        read_col ^= (read_row & (Base::BYTES_PER_ROW == 32 ? 0x01 : (Base::BYTES_PER_ROW == 64 ? 0x03 : 0x07)));
+        // read_col ^= (read_row & (Base::BYTES_PER_ROW == 32 ? 0x01 : (Base::BYTES_PER_ROW == 64 ? 0x03 : (Base::BYTES_PER_ROW == 128 ? 0x07 : 0x0f))));
+        read_col ^= (read_row & 0x07);
         // read_offset_ = read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS;
         smem_read_ = smem_ + read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS;
     }

@@ -1357,7 +1365,9 @@ struct Smem_tile_mma_epilogue : public Base {
         uint32_t smem_ = __nvvm_get_smem_pointer(smem);
         const int read_row = tidx / THREADS_PER_ROW;
         int read_col = tidx % THREADS_PER_ROW;
-        read_col ^= (read_row & (Base::BYTES_PER_ROW == 32 ? 0x01 : (Base::BYTES_PER_ROW == 64 ? 0x03 : 0x07)));
+        // read_col ^= (read_row & (Base::BYTES_PER_ROW == 32 ? 0x01 : (Base::BYTES_PER_ROW == 64 ? 0x03 : 0x07)));
+        static_assert(Base::BYTES_PER_ROW == 32 || Base::BYTES_PER_ROW == 64 || Base::BYTES_PER_ROW == 128 || Base::BYTES_PER_ROW == 256);
+        read_col ^= (read_row & (Base::BYTES_PER_ROW == 32 ? 0x01 : (Base::BYTES_PER_ROW == 64 ? 0x03 : (Base::BYTES_PER_ROW == 128 ? 0x07 : 0x07))));
         // read_offset_ = read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS;
         smem_read_ = smem_ + read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS;
     }
@@ -1402,6 +1412,9 @@ struct Smem_tile_mma_epilogue : public Base {
             // fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, w);
             // size_t offset = (this->smem_write_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW;
             uint32_t offset = (this->smem_write_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW;
+            // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
+            //     printf("mi = %d, ni = %d, offset - smem_write_ = %d\n", mi, ni, offset - this->smem_write_);
+            // }
             fmha::sts(offset + 0 * BYTES_PER_ROW, x);
             fmha::sts(offset + 8 * BYTES_PER_ROW, z);
             offset ^= 4 * Base::BYTES_PER_STS;
csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu

@@ -120,4 +120,8 @@ void run_fmha_dgrad_fp16_sm80(const Fused_multihead_attention_fprop_params &params
     // }
     // }
     // }
+    // if (params.d == 128) {
+    //     using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 8, 0x100u>;
+    //     run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+    // }
 }
\ No newline at end of file
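The commented-out block above sketches how a dedicated d == 128 backward launch could be wired in (FMHA_kernel_traits<128, 128, 16, 1, 8, 0x100u>). The general launcher pattern it follows — dispatching a runtime head dimension to a compile-time kernel configuration — looks roughly like the self-contained sketch below; the trait fields and function names here are stand-ins, not the repository's actual definitions.

```cpp
#include <cstdio>

// Stand-in for a kernel-traits bundle chosen at compile time.
template<int kSeqBlock, int kHeadDim>
struct Kernel_traits_demo {
    static constexpr int kSmemBytes = kSeqBlock * kHeadDim * 2 * 2;  // toy fp16 estimate
};

// Stand-in for run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream).
template<typename Traits>
void launch_demo() {
    std::printf("launch with ~%d bytes of shared memory per tile\n", Traits::kSmemBytes);
}

// Runtime head dim -> compile-time traits, mirroring the if-chains in the launchers.
void run_dgrad_demo(int head_dim) {
    if (head_dim == 64) {
        launch_demo<Kernel_traits_demo<256, 64>>();
    } else if (head_dim == 128) {
        launch_demo<Kernel_traits_demo<128, 128>>();
    } else {
        std::printf("unsupported head dim %d\n", head_dim);
    }
}

int main() {
    run_dgrad_demo(64);
    run_dgrad_demo(128);
    return 0;
}
```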
csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h

@@ -512,6 +512,17 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
         fmha::gemm_cl(acc_dv, frag_s[(ki - 1)], frag_dot[(ki - 1) & 1]);
     }
+    // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
+    //     float2 tmp0 = __half22float2(reinterpret_cast<__half2 &>(frag_dot[0][0]));
+    //     printf("frag_dot[0][0]=%.6f, %.6f\n", tmp0.x, tmp0.y);
+    //     float2 tmp1 = __half22float2(reinterpret_cast<__half2 &>(frag_dot[0][1]));
+    //     printf("frag_dot[0][1]=%.6f, %.6f\n", tmp1.x, tmp1.y);
+    // }
+    // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
+    //     printf("l = %d, acc_dv[0][0]=%.6f, %.6f\n", l, acc_dv[0][0].elt(2), acc_dv[0][0].elt(3));
+    //     printf("l = %d, acc_dv[0][1]=%.6f, %.6f\n", l, acc_dv[0][1].elt(2), acc_dv[0][1].elt(3));
+    // }
     // __syncthreads();
     // Commit the values for Q and dO into shared memory.
     if (l < steps - 1) {

@@ -577,7 +588,9 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
     // if (Is_dropout) {
     //     dq_out[0] = fmha::fmul4(dq_out[0], params.rp_dropout);
     // }
-    dq_out[0] = fmha::fmul4(dq_out[0], params.scale_bmm1f);
+    for (int jj = 0; jj < Gmem_tile_dq::STGS_PER_LOOP; ++jj) {
+        dq_out[jj] = fmha::fmul4(dq_out[jj], params.scale_bmm1f);
+    }
     // Output the values.
     gmem_dq.store(dq_out, 0);
     // Move to the next part of the output.
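The change above scales every dq_out fragment by params.scale_bmm1f rather than just dq_out[0]; presumably, with d = 128 there is more than one store per loop iteration (Gmem_tile_dq::STGS_PER_LOOP > 1), so scaling only the first fragment would leave part of dQ unscaled. A rough standalone sketch of the same fix, with a plain struct standing in for the packed registers and fmul4 (whose exact definition is not part of this diff):

```cpp
#include <cstdio>

struct float4_t { float x, y, z, w; };

// Stand-in for fmha::fmul4: multiply all four packed floats by a scalar (assumed semantics).
float4_t fmul4(float4_t v, float s) { return {v.x * s, v.y * s, v.z * s, v.w * s}; }

int main() {
    const int STGS_PER_LOOP = 2;        // e.g. two stores per loop for a wider head dim
    const float scale_bmm1f = 0.125f;   // illustrative softmax scale
    float4_t dq_out[STGS_PER_LOOP] = {{1, 1, 1, 1}, {1, 1, 1, 1}};

    // Before: only dq_out[0] was scaled. After: every fragment is scaled.
    for (int jj = 0; jj < STGS_PER_LOOP; ++jj) {
        dq_out[jj] = fmul4(dq_out[jj], scale_bmm1f);
    }
    std::printf("dq_out[1].x = %f\n", dq_out[1].x);  // 0.125, now scaled as well
    return 0;
}
```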
@@ -614,7 +627,8 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
         }
     }
     // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
-    //     printf("l final, acc_dk=%.6f, %.6f\n", acc_dk[0][0].elt(0), acc_dk[0][0].elt(1));
+    //     printf("l final, acc_dv[0][0]=%.6f, %.6f\n", acc_dv[0][0].elt(2), acc_dv[0][0].elt(3));
+    //     printf("l final, acc_dv[0][1]=%.6f, %.6f\n", acc_dv[0][1].elt(2), acc_dv[0][1].elt(3));
     // }
     for( int mi = 0; mi < Mma_tile_dkv::MMAS_M; mi++ ) {
         for( int ni = 0; ni < Mma_tile_dkv::MMAS_N; ni++ ) {
csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu

@@ -126,7 +126,7 @@ void run_fmha_fp16_sm80(Launch_params<Fused_multihead_attention_fprop_params> &l
         run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
     } else {
         auto dprops = at::cuda::getCurrentDeviceProperties();
-        if (dprops->major == 8 && dprops->minor >= 0 && !is_dropout) {
+        if (dprops->major == 8 && dprops->minor >= 0 && !launch_params.is_dropout) {
             // TD [2022-06-05] Keep K in registers to reduce register spilling
             // Gives about 6% speedup compared to using block size 128.
             using Kernel_traits = FMHA_kernel_traits<256, 128, 16, 1, 4, 0x18u>;

@@ -170,7 +170,7 @@ void run_fmha_fp16_sm80(Launch_params<Fused_multihead_attention_fprop_params> &l
     // run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
     // } else {
     //     auto dprops = at::cuda::getCurrentDeviceProperties();
-    //     if (dprops->major == 8 && dprops->minor >= 0 && !is_dropout) {
+    //     if (dprops->major == 8 && dprops->minor >= 0 && !launch_params.is_dropout) {
     //         // TD [2022-06-05] Keep K in registers to reduce register spilling
     //         // Gives about 6% speedup compared to using block size 128.
     //         using Kernel_traits = FMHA_kernel_traits<256, 128, 16, 1, 4, 0x18u>;
csrc/flash_attn/src/fmha_fprop_kernel_1xN.h

@@ -382,8 +382,6 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
     // Apply the mask.
     softmax.apply_mask(mask);
-    // softmax.unpack_noscale_half_and_apply_mask(acc_p, mask);
     if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V && l == 0 ) {
         // if we share K and V, it could be that V was not fully read yet but we write into smem for reduction
         __syncthreads();

@@ -408,7 +406,6 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
         }
     }
-    // __half2 p_max[Mma_tile_p::MMAS_M];
     softmax.template reduce_max</*zero_init=*/Is_first>(p_max);
     // if ((threadIdx.x == 0) && (l == 38)) {