gaoqiong / flash-attention · Commit 2712aa4c

Support Turing mma instructions

Authored Jun 02, 2022 by Tri Dao
Parent 05087332

Showing 5 changed files with 82 additions and 18 deletions (+82 -18)
csrc/flash_attn/fmha_api.cpp                               +6   -4
csrc/flash_attn/src/fmha/gemm.h                            +61  -7
csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu    +3   -0
csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu         +9   -6
setup.py                                                   +3   -1
csrc/flash_attn/fmha_api.cpp
@@ -117,7 +117,8 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, tot
                  c10::optional<at::Generator> gen_) {
     auto dprops = at::cuda::getCurrentDeviceProperties();
-    TORCH_CHECK(dprops->major == 8 && dprops->minor >= 0);
+    bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
+    TORCH_CHECK((dprops->major == 8 && dprops->minor >= 0) || is_sm75);
     auto stream = at::cuda::getCurrentCUDAStream().stream();
     bool is_dropout = p_dropout > 0.0;
     Launch_params<Fused_multihead_attention_fprop_params> launch_params(dprops, stream, is_dropout, return_softmax);
@@ -143,7 +144,7 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, tot
     TORCH_CHECK(head_size == 16 || head_size == 32 || head_size == 64 || head_size == 128);
     // int base_N = head_size == 16 ? 512 : (head_size == 128 ? 128 : 256);
-    int base_N = head_size == 128 ? 128 : 256;
+    int base_N = (head_size == 128 || (is_sm75 && head_size == 64)) ? 128 : 256;
     // int base_N = 256;
     int seq_len = 512;
     if( max_seq_len <= 128 ) {
@@ -236,7 +237,8 @@ mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size
                  c10::optional<at::Generator> gen_
 ) {
     auto dprops = at::cuda::getCurrentDeviceProperties();
-    TORCH_CHECK(dprops->major == 8 && dprops->minor >= 0);
+    bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
+    TORCH_CHECK((dprops->major == 8 && dprops->minor >= 0) || is_sm75);
     auto launch = &run_fmha_dgrad_fp16_sm80;
     bool is_dropout = p_dropout > 0.0;
@@ -268,7 +270,7 @@ mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size
     TORCH_CHECK(head_size == 16 || head_size == 32 || head_size == 64 || head_size == 128);
     // int base_N = head_size == 16 ? 512 : (head_size == 128 ? 128 : 256);
-    int base_N = head_size == 128 ? 128 : 256;
+    int base_N = (head_size == 128 || (is_sm75 && head_size == 64)) ? 128 : 256;
     int seq_len = 512;
     if( max_seq_len <= 128 ) {
         seq_len = 128;
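Taken together, the four hunks above relax the device check from "SM 8.x only" to "SM 8.x or SM 7.5", and on Turing they also drop base_N to 128 when head_size == 64. The following condensed sketch is my own illustration of that gating, not code from the commit; DeviceProps and pick_base_N are hypothetical names, and the reason for the narrower Turing tile (presumably its smaller shared-memory and register budget) is an assumption.

// Hypothetical sketch of the capability gate and base_N choice added above.
#include <stdexcept>

struct DeviceProps { int major, minor; };  // stand-in for cudaDeviceProp

inline int pick_base_N(const DeviceProps &dprops, int head_size) {
    bool is_sm8x = dprops.major == 8 && dprops.minor >= 0;  // mirrors the TORCH_CHECK
    bool is_sm75 = dprops.major == 7 && dprops.minor == 5;
    if (!(is_sm8x || is_sm75)) {
        throw std::runtime_error("FlashAttention requires SM 7.5 or SM 8.x");
    }
    // Turing with head_size 64 uses the narrower 128-wide base tile.
    return (head_size == 128 || (is_sm75 && head_size == 64)) ? 128 : 256;
}

For example, on an RTX 2080 (SM 7.5) with head_size 64 this returns 128, matching the new base_N expression.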
csrc/flash_attn/src/fmha/gemm.h
@@ -257,7 +257,15 @@ inline __device__ void gemm(Acc (&acc)[M][N], const A (&a)[M], const B (&b)[N])
 template<typename Acc, typename A, typename B, int M, int N>
 inline __device__ void gemm_cl(Acc (&acc)[M][N], const A (&a)[M], const B (&b)[N]) {
     using Shape = cutlass::gemm::GemmShape<16 * M, 16 * N, 16>;
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
     using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
+#elif defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
+    using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
+#else
+    using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;
+    // TD [2022-06-02] We don't support Volta (SM70) yet.
+    assert(0);
+#endif
     using Element = cutlass::half_t;
     using ElementC = float;
     using LayoutA = cutlass::layout::RowMajor;
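The arch switch above exists because the FP16 tensor-core mma.sync shape differs by generation: SM80 exposes m16n8k16, while SM75 (Turing) only has m16n8k8, so a 16-deep K tile needs twice as many instruction steps on Turing. Below is a small standalone sketch, my own illustration rather than code from this commit, of the kIters/kRegs arithmetic that the next hunk relies on; the InstructionShape/report names here are hypothetical.

// Hypothetical host-side sketch: per-arch mma step counts for a 16-deep K tile.
#include <cstdio>

template <int kM, int kN, int kK>
struct InstructionShape { static constexpr int M = kM, N = kN, K = kK; };

template <typename Instr>
void report(const char *arch) {
    constexpr int tile_k = 16;                      // Shape::kK used by gemm_cl
    constexpr int m = Instr::M, n = Instr::N, k = Instr::K;
    constexpr int kIters = tile_k / k;              // mma steps per 16-deep K tile
    constexpr int kRegs  = (k == 16) ? 4 : 2;       // 32-bit A regs copied per thread per step
    std::printf("%s: mma.sync m%dn%dk%d -> kIters = %d, kRegs = %d\n",
                arch, m, n, k, kIters, kRegs);
}

int main() {
    report<InstructionShape<16, 8, 16>>("SM80 (Ampere)");  // kIters = 1
    report<InstructionShape<16, 8, 8>>("SM75 (Turing)");   // kIters = 2
    return 0;
}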
@@ -267,19 +275,65 @@ inline __device__ void gemm_cl(Acc (&acc)[M][N], const A (&a)[M], const B (&b)[N
                                                   Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC,
                                                   cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd, 1, true>::Type;
-    using FragmentA = typename WarpMma::FragmentA;
-    using FragmentB = typename WarpMma::FragmentB;
+    constexpr int kIters = Shape::kK / InstructionShape::kK;
+    // using FragmentA = typename WarpMma::FragmentA;
+    // using FragmentB = typename WarpMma::FragmentB;
+    using FragmentA = typename WarpMma::ArchMmaOperator::FragmentA;
+    using FragmentB = typename WarpMma::ArchMmaOperator::FragmentB;
     using FragmentC = typename WarpMma::FragmentC;
-    static_assert(FragmentA::kStorageElements == M * a[0].NUM_REGS);
-    static_assert(FragmentB::kStorageElements == N * b[0].NUM_REGS);
+    // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y) == 0) {
+    //     printf("FragmentA::kStorageElements = %d\n", FragmentA::kStorageElements);
+    //     printf("Archmma::FragmentA::kStorageElements = %d\n", WarpMma::ArchMmaOperator::FragmentA::kStorageElements);
+    //     printf("FragmentB::kStorageElements = %d\n", FragmentB::kStorageElements);
+    //     printf("Archmma::FragmentB::kStorageElements = %d\n", WarpMma::ArchMmaOperator::FragmentB::kStorageElements);
+    //     printf("FragmentC::kStorageElements = %d\n", FragmentC::kStorageElements);
+    //     printf("Archmma::FragmentC::kStorageElements = %d\n", WarpMma::ArchMmaOperator::FragmentC::kStorageElements);
+    // }
+    // static_assert(FragmentA::kStorageElements == M * a[0].NUM_REGS);
+    // static_assert(FragmentB::kStorageElements == N * b[0].NUM_REGS);
+    static_assert(FragmentA::kStorageElements * kIters == a[0].NUM_REGS);
+    static_assert(FragmentB::kStorageElements * kIters * 16 / InstructionShape::kN == b[0].NUM_REGS);
     static_assert(FragmentC::kStorageElements == M * N * acc[0][0].NUM_REGS);
-    const FragmentA a_cl = reinterpret_cast<const FragmentA (&)>(a);
-    const FragmentB b_cl = reinterpret_cast<const FragmentB (&)>(b);
+    // const FragmentA a_cl = reinterpret_cast<const FragmentA (&)>(a);
+    // const FragmentB b_cl = reinterpret_cast<const FragmentB (&)>(b);
     FragmentC c_cl = reinterpret_cast<FragmentC (&)>(acc);
+    FragmentA a_cl[kIters][M];
+    FragmentA b_cl[kIters][N];
+    constexpr int kRegs = InstructionShape::kK == 16 ? 4 : 2;
+    #pragma unroll
+    for (int iter = 0; iter < kIters; iter++) {
+        #pragma unroll
+        for (int mi = 0; mi < M; mi++) {
+            uint32_t *a_ptr = a_cl[iter][mi].raw_data();
+            #pragma unroll
+            for (int ki = 0; ki < kRegs; ki++) {
+                a_ptr[ki] = a[mi].regs_[iter * kRegs + ki];
+            }
+        }
+    }
+    #pragma unroll
+    for (int iter = 0; iter < kIters; iter++) {
+        #pragma unroll
+        for (int ni = 0; ni < N; ni++) {
+            uint32_t *b_ptr = b_cl[iter][ni].raw_data();
+            #pragma unroll
+            for (int ki = 0; ki < kRegs; ki++) {
+                // b_ptr[ki] = b[ni].regs_[iter * kRegs + ki];
+                // TD [2022-06-02] For some reason the order for frag_b is different.
+                b_ptr[ki] = b[ni].regs_[InstructionShape::kK == 16 ? iter * kRegs + ki : ki * kRegs + iter];
+            }
+        }
+    }
     WarpMma mma_op;
-    mma_op(c_cl, a_cl, b_cl, c_cl);
+    // mma_op(c_cl, a_cl, b_cl, c_cl);
+    #pragma unroll
+    for (int iter = 0; iter < kIters; iter++) {
+        mma_op(c_cl, reinterpret_cast<const typename WarpMma::FragmentA (&)>(a_cl[iter]),
+               reinterpret_cast<const typename WarpMma::FragmentB (&)>(b_cl[iter]), c_cl);
+    }
     // The modified c_cl is not copied back into acc, idk why
     #pragma unroll
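The repacking loops above split each fmha fragment's registers into kIters instruction-shaped fragments before calling the warp MMA once per K step; the B-side registers are consumed in a different order depending on the instruction shape, as the in-code comment notes. The following host-side sketch is my own illustration, not from the commit, and only prints the index mapping used for b_ptr under the two InstructionShape::kK values.

// Hypothetical sketch: source-register order when repacking the B fragment.
#include <cstdio>
#include <initializer_list>

int main() {
    for (int kK : {16, 8}) {                      // SM80 uses k16, SM75 uses k8
        const int kIters = 16 / kK;               // K steps per 16-deep tile
        const int kRegs  = (kK == 16) ? 4 : 2;    // regs copied per step
        std::printf("InstructionShape::kK == %d\n", kK);
        for (int iter = 0; iter < kIters; ++iter) {
            for (int ki = 0; ki < kRegs; ++ki) {
                // Contiguous on k16, transposed (ki * kRegs + iter) on k8.
                int src = (kK == 16) ? iter * kRegs + ki : ki * kRegs + iter;
                std::printf("  step %d: b_ptr[%d] <- b.regs_[%d]\n", iter, ki, src);
            }
        }
    }
    return 0;
}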
csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu
@@ -88,6 +88,9 @@ void run_fmha_dgrad_fp16_sm80(const Fused_multihead_attention_fprop_params &params
             } else if (dprops->major == 8 && dprops->minor > 0) {
                 using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x08u>;
                 run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+            } else if (dprops->major == 7 && dprops->minor == 5) {
+                using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u>;
+                run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
             }
         }
     } else if (params.d == 128) {
csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu
@@ -105,12 +105,15 @@ void run_fmha_fp16_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params
         if( launch_params.params.s == 128 ) {
             using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
-        } else if( launch_params.params.s == 256 ) {
-            using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
-            run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
-        } else {
-            using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
-            run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+        } else if( launch_params.params.s >= 256 ) {
+            auto dprops = at::cuda::getCurrentDeviceProperties();
+            if (dprops->major == 8 && dprops->minor >= 0) {
+                using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
+                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+            } else if (dprops->major == 7 && dprops->minor == 5) {
+                using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
+                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+            }
         }
     } else if (launch_params.params.d == 128) {
         using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u>;
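Both launcher files above follow the same pattern: query the device's compute capability and fall back from the 256-wide kernel traits to a 128-wide variant on SM 7.5 (the real dgrad launcher additionally distinguishes SM 8.0 from other SM 8.x parts, which is omitted here). The sketch below is a hedged illustration of that dispatch shape only; dispatch_by_arch and the callables are hypothetical stand-ins for the run_fmha_*_loop_<Kernel_traits> instantiations, not functions in the repository.

// Hypothetical sketch: compute-capability dispatch between two kernel variants.
#include <cuda_runtime.h>
#include <cstdio>

template <typename LaunchWide, typename LaunchNarrow>
void dispatch_by_arch(LaunchWide &&launch_sm8x, LaunchNarrow &&launch_sm75) {
    cudaDeviceProp props{};
    if (cudaGetDeviceProperties(&props, /*device=*/0) != cudaSuccess) {
        std::fprintf(stderr, "cudaGetDeviceProperties failed\n");
        return;
    }
    if (props.major == 8) {
        launch_sm8x();   // e.g. traits like FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>
    } else if (props.major == 7 && props.minor == 5) {
        launch_sm75();   // e.g. traits like FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>
    } else {
        std::fprintf(stderr, "unsupported compute capability %d.%d\n",
                     props.major, props.minor);
    }
}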
setup.py
@@ -107,7 +107,9 @@ raise_if_cuda_home_none("flash_attn")
 cc_flag = []
 _, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME)
 if int(bare_metal_major) < 11:
-    raise RuntimeError("--flashattn only supported on SM80+")
+    raise RuntimeError("FlashAttention is only supported on CUDA 11")
+cc_flag.append("-gencode")
+cc_flag.append("arch=compute_75,code=sm_75")
 cc_flag.append("-gencode")
 cc_flag.append("arch=compute_80,code=sm_80")