gaoqiong / flash-attention · Commits

Commit d95ee1a9
Authored Nov 25, 2022 by Tri Dao
Speed up compilation by splitting into separate .cu files

Parent: b784ed73
Showing 12 changed files with 251 additions and 165 deletions.
Files changed:

  csrc/flash_attn/fmha_api.cpp                     +21   -2
  csrc/flash_attn/src/fmha.h                        +6   -2
  csrc/flash_attn/src/fmha_bwd_hdim128.cu          +13   -0
  csrc/flash_attn/src/fmha_bwd_hdim32.cu           +18   -0
  csrc/flash_attn/src/fmha_bwd_hdim64.cu           +31   -0
  csrc/flash_attn/src/fmha_bwd_launch_template.h   +17   -55
  csrc/flash_attn/src/fmha_fwd_hdim128.cu          +12   -0
  csrc/flash_attn/src/fmha_fwd_hdim32.cu           +17   -0
  csrc/flash_attn/src/fmha_fwd_hdim64.cu           +17   -0
  csrc/flash_attn/src/fmha_fwd_launch_template.h   +92   -0
  csrc/flash_attn/src/fmha_kernel.h                 +0   -103
  setup.py                                          +7   -3
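The change follows a common recipe for cutting nvcc compile times: the heavy kernel-launch templates move into shared headers (fmha_fwd_launch_template.h, fmha_bwd_launch_template.h), and each supported head dimension gets a thin .cu file that instantiates only its own kernels, so the build can compile the instantiations as independent, parallelizable translation units. A minimal, illustrative sketch of the pattern (file and function names below are schematic, not the repository's):

    // launch_template.h (schematic): the expensive templated launcher lives in a header.
    #pragma once

    template<int kHeadDim>
    void run_fwd() {
        // ... instantiate and launch the kernel specialized for kHeadDim ...
    }

    // fwd_hdim32.cu (schematic): one thin translation unit per head dimension,
    // so the compiler can build them in parallel instead of one large file serially.
    #include "launch_template.h"
    void run_fwd_hdim32() { run_fwd<32>(); }

    // fwd_hdim64.cu (schematic)
    #include "launch_template.h"
    void run_fwd_hdim64() { run_fwd<64>(); }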
csrc/flash_attn/fmha_api.cpp

@@ -176,6 +176,16 @@ void set_params_dgrad(FMHA_dgrad_params &params,
     params.dsoftmax_sum = dsoftmax_sum_d;
 }
 
+void run_fmha_fwd(Launch_params<FMHA_fprop_params> &launch_params) {
+    if (launch_params.params.d <= 32) {
+        run_fmha_fwd_hdim32(launch_params);
+    } else if (launch_params.params.d <= 64) {
+        run_fmha_fwd_hdim64(launch_params);
+    } else if (launch_params.params.d <= 128) {
+        run_fmha_fwd_hdim128(launch_params);
+    }
+}
+
 std::vector<at::Tensor>
 mha_fwd(const at::Tensor &q,         // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
         const at::Tensor &k,         // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i

@@ -307,13 +317,22 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
         launch_params.params.philox_args = gen->philox_cuda_state(counter_offset);
     }
 
-    run_fmha_fp16_sm80(launch_params);
+    run_fmha_fwd(launch_params);
 
     std::vector<at::Tensor> result = {softmax_lse};
     if (return_softmax) {result.push_back(s);}
     return result;
 }
 
+void run_fmha_bwd(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure) {
+    if (params.d <= 32) {
+        run_fmha_bwd_hdim32(params, stream, configure);
+    } else if (params.d <= 64) {
+        run_fmha_bwd_hdim64(params, stream, configure);
+    } else if (params.d <= 128) {
+        run_fmha_bwd_hdim128(params, stream, configure);
+    }
+}
+
 std::vector<at::Tensor>
 mha_bwd(const at::Tensor &dout,  // total_q x num_heads, x head_size

@@ -341,7 +360,7 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
     bool is_sm80 = dprops->major == 8 && dprops->minor == 0;
     bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
     TORCH_CHECK(is_sm8x || is_sm75);
-    auto launch = &run_fmha_dgrad_fp16_sm80;
+    auto launch = &run_fmha_bwd;
 
     bool is_dropout = p_dropout > 0.0;
     auto stream = at::cuda::getCurrentCUDAStream().stream();
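Note that run_fmha_fwd and run_fmha_bwd dispatch on the head dimension with <= comparisons, so a supported head size is routed to the next instantiated bucket of 32, 64, or 128. A small hedged illustration of that bucketing (the helper below is hypothetical, not part of the diff):

    // Hypothetical helper mirroring the dispatch above: head dims are bucketed
    // upward to the nearest instantiated kernel size (32, 64, or 128).
    #include <cstdio>

    static int fmha_hdim_bucket(int d) {
        if (d <= 32)  return 32;
        if (d <= 64)  return 64;
        if (d <= 128) return 128;
        return -1;  // unsupported in this version
    }

    int main() {
        // e.g. a model with head_size 40 would take the hdim64 code path
        std::printf("%d -> %d\n", 40, fmha_hdim_bucket(40));    // 40 -> 64
        std::printf("%d -> %d\n", 128, fmha_hdim_bucket(128));  // 128 -> 128
        return 0;
    }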
csrc/flash_attn/src/fmha.h

@@ -195,9 +195,13 @@ struct Launch_params{
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-void run_fmha_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params);
+void run_fmha_fwd_hdim32(Launch_params<FMHA_fprop_params> &launch_params);
+void run_fmha_fwd_hdim64(Launch_params<FMHA_fprop_params> &launch_params);
+void run_fmha_fwd_hdim128(Launch_params<FMHA_fprop_params> &launch_params);
 
-void run_fmha_dgrad_fp16_sm80(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure);
+void run_fmha_bwd_hdim32(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure);
+void run_fmha_bwd_hdim64(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure);
+void run_fmha_bwd_hdim128(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure);
 
 void run_fmha_block_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params, const bool configure);
csrc/flash_attn/src/fmha_bwd_hdim128.cu  (new file, mode 100644)

// Copyright (c) 2022, Tri Dao.

// Splitting the different head dimensions to different files to speed up compilation.

#include "fmha_bwd_launch_template.h"

void run_fmha_bwd_hdim128(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure) {
    // work around for MSVC issue
    FP16_SWITCH(params.is_bf16, [&] {
        using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 8, 0x100u, elem_type>;
        run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
    });
}
\ No newline at end of file
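Each of these thin .cu files relies on FP16_SWITCH (from fp16_switch.h) to stamp out both the fp16 and bf16 instantiations from a single lambda, with elem_type bound inside the macro. The repository's macro is not shown in this diff; below is a minimal sketch of how such a compile-time switch is typically written (an assumption, not the repo's exact code):

    // Sketch of an FP16/BF16 dispatch macro in the style of fp16_switch.h (assumed, not verbatim).
    // The lambda body is compiled twice, once with elem_type = __half and once with
    // elem_type = __nv_bfloat16; the runtime flag selects which instantiation runs.
    #include <cuda_fp16.h>
    #include <cuda_bf16.h>

    #define MY_FP16_SWITCH(BF16_COND, ...)        \
        [&] {                                     \
            if (BF16_COND) {                      \
                using elem_type = __nv_bfloat16;  \
                return __VA_ARGS__();             \
            } else {                              \
                using elem_type = __half;         \
                return __VA_ARGS__();             \
            }                                     \
        }()

    template<typename T>
    void launch_kernel_for() { /* instantiate and launch a kernel for element type T */ }

    void run(bool is_bf16) {
        // elem_type inside the caller's lambda resolves to the alias defined by the macro.
        MY_FP16_SWITCH(is_bf16, [&] { launch_kernel_for<elem_type>(); });
    }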
csrc/flash_attn/src/fmha_bwd_hdim32.cu  (new file, mode 100644)

// Copyright (c) 2022, Tri Dao.

// Splitting the different head dimensions to different files to speed up compilation.

#include "fmha_bwd_launch_template.h"

void run_fmha_bwd_hdim32(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure) {
    // work around for MSVC issue
    FP16_SWITCH(params.is_bf16, [&] {
        if (params.seqlen_k == 128) {
            using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 8, 0x08u, elem_type>;
            run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
        } else if (params.seqlen_k >= 256) {
            using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 8, 0x08u, elem_type>;
            run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
        }
    });
}
\ No newline at end of file
csrc/flash_attn/src/fmha_bwd_hdim64.cu  (new file, mode 100644)

// Copyright (c) 2022, Tri Dao.

// Splitting the different head dimensions to different files to speed up compilation.

#include "fmha_bwd_launch_template.h"

void run_fmha_bwd_hdim64(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure) {
    // work around for MSVC issue
    FP16_SWITCH(params.is_bf16, [&] {
        auto dprops = at::cuda::getCurrentDeviceProperties();
        if (params.seqlen_k == 128) {
            using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>;
            run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
        } else if (params.seqlen_k >= 256) {
            if (dprops->major == 8 && dprops->minor == 0) {
                // Don't share smem for K & V, and don't keep V in registers
                // This speeds things up by 2-3% by avoiding register spills, but it
                // uses more shared memory, which is fine on A100 but not other GPUs.
                // For other GPUs, we keep V in registers.
                using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u, elem_type>;
                run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
            } else if (dprops->major == 8 && dprops->minor > 0) {
                using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x08u, elem_type>;
                run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
            } else if (dprops->major == 7 && dprops->minor == 5) {
                using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>;
                run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
            }
        }
    });
}
\ No newline at end of file
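The A100-only branch above (major == 8 && minor == 0) trades extra shared memory for fewer register spills, which only pays off because SM80 has a much larger shared-memory budget per SM than SM86 or SM75 parts. A hedged sketch of how that budget can be inspected at runtime with standard CUDA device properties (illustrative, not code from the diff):

    // Query the shared-memory budget that motivates the per-architecture Kernel_traits choice above.
    // (Illustrative sketch; the diff itself only branches on dprops->major / dprops->minor.)
    #include <cstdio>
    #include <cuda_runtime.h>

    int main() {
        cudaDeviceProp prop{};
        cudaGetDeviceProperties(&prop, /*device=*/0);
        std::printf("SM %d.%d\n", prop.major, prop.minor);
        // A100 (SM80) exposes far more shared memory per SM than SM86 or SM75 GPUs,
        // which is why the larger-smem kernel configuration is only selected there.
        std::printf("smem per SM:            %zu bytes\n", prop.sharedMemPerMultiprocessor);
        std::printf("smem per block (optin): %zu bytes\n", prop.sharedMemPerBlockOptin);
        return 0;
    }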
csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu → csrc/flash_attn/src/fmha_bwd_launch_template.h

-/* Copyright (c) 2022, Tri Dao.
- */
+// Copyright (c) 2022, Tri Dao.
+
+#pragma once
 
 #include "static_switch.h"
 #include "fp16_switch.h"

@@ -9,7 +10,7 @@
 // Pick whether we should parallelize across seqlen_k (num_splits > 1) or not (num_splits=1).
 // Parallelizing will have better occupancy, but has some overhead due to having to zero out
 // dq_tmp and having to copy dq_tmp to dq.
-int num_splits_heuristic_bwd(int batch_nheads, int num_SMs, int ctas_per_sm, int seqlen,
+inline int num_splits_heuristic_bwd(int batch_nheads, int num_SMs, int ctas_per_sm, int seqlen,
                              int blocksize, bool is_causal) {
     float n_waves_1 = float(batch_nheads) / (num_SMs * ctas_per_sm);
     float eff_1 = n_waves_1 / ceil(n_waves_1);
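The comment above states the trade-off the heuristic encodes: more seqlen_k splits improve SM occupancy but add the cost of zeroing and copying dq_tmp. A standalone sketch of this style of occupancy heuristic, in the spirit of the forward-pass version described later in this diff (pick the smallest split count whose wave efficiency is close to the best); the function below is illustrative, not the repository's implementation:

    // Illustrative occupancy heuristic: choose the smallest number of splits whose
    // "wave efficiency" (how evenly full waves of CTAs cover the SMs) is within 95%
    // of the best achievable efficiency. Not the repository's exact implementation.
    #include <cmath>

    inline int num_splits_sketch(int batch_nheads, int num_SMs, int ctas_per_sm, int max_splits) {
        float best_eff = 0.f;
        // First pass: find the best efficiency over all candidate split counts.
        for (int splits = 1; splits <= max_splits; ++splits) {
            float n_waves = float(batch_nheads * splits) / (num_SMs * ctas_per_sm);
            float eff = n_waves / std::ceil(n_waves);
            if (eff > best_eff) { best_eff = eff; }
        }
        // Second pass: smallest split count that reaches 95% of the best efficiency,
        // since extra splits cost additional HBM traffic (e.g. zeroing/copying dq_tmp).
        for (int splits = 1; splits <= max_splits; ++splits) {
            float n_waves = float(batch_nheads * splits) / (num_SMs * ctas_per_sm);
            float eff = n_waves / std::ceil(n_waves);
            if (eff >= 0.95f * best_eff) { return splits; }
        }
        return 1;
    }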
@@ -29,22 +30,22 @@ int num_splits_heuristic_bwd(int batch_nheads, int num_SMs, int ctas_per_sm, int
 }
 
 template<typename Kernel_traits>
-__global__ void fmha_dgrad_dot_do_o_kernel(FMHA_dgrad_params params) {
+__global__ void fmha_bwd_dot_do_o_kernel(FMHA_dgrad_params params) {
     fmha::compute_dot_do_o<Kernel_traits>(params);
 }
 
 template<typename Kernel_traits, bool Is_dropout, bool Is_causal, int loop_steps=-1>
-__global__ void fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel(FMHA_dgrad_params params) {
+__global__ void fmha_bwd_dq_dk_dv_loop_kernel(FMHA_dgrad_params params) {
     fmha::compute_dq_dk_dv_1xN<Kernel_traits, Is_dropout, Is_causal, loop_steps>(params);
 }
 
 template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
-__global__ void fmha_dgrad_fp16_sm80_dq_dk_dv_loop_seqparallel_kernel(FMHA_dgrad_params params) {
+__global__ void fmha_bwd_q_dk_dv_loop_seqparallel_kernel(FMHA_dgrad_params params) {
     fmha::compute_dq_dk_dv_seqparallel<Kernel_traits, Is_dropout, Is_causal>(params);
 }
 
 template<typename Kernel_traits>
-void run_fmha_dgrad_fp16_sm80_loop_(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure) {
+void run_fmha_bwd_loop(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure) {
     constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
     constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
     constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;

@@ -63,20 +64,20 @@ void run_fmha_dgrad_fp16_sm80_loop_(FMHA_dgrad_params &params, cudaStream_t stream
     // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH.
     BOOL_SWITCH(is_dropout, IsDropoutConst, [&] {
         auto kernel = params.is_causal
-            ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, true>
-            : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, false>;
+            ? &fmha_bwd_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, true>
+            : &fmha_bwd_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, false>;
         if (params.seqlen_k == blocksize_c) {
             kernel = params.is_causal
-                ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, true, /*loop_steps=*/1>
-                : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, false, /*loop_steps=*/1>;
+                ? &fmha_bwd_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, true, /*loop_steps=*/1>
+                : &fmha_bwd_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, false, /*loop_steps=*/1>;
         } else if (params.seqlen_k == blocksize_c * 2) {
             kernel = params.is_causal
-                ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, true, /*loop_steps=*/2>
-                : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, false, /*loop_steps=*/2>;
+                ? &fmha_bwd_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, true, /*loop_steps=*/2>
+                : &fmha_bwd_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, false, /*loop_steps=*/2>;
         }
         auto kernel_seqparallel = params.is_causal
-            ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_seqparallel_kernel<Kernel_traits, IsDropoutConst, true>
-            : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_seqparallel_kernel<Kernel_traits, IsDropoutConst, false>;
+            ? &fmha_bwd_q_dk_dv_loop_seqparallel_kernel<Kernel_traits, IsDropoutConst, true>
+            : &fmha_bwd_q_dk_dv_loop_seqparallel_kernel<Kernel_traits, IsDropoutConst, false>;
         if (smem_size_dq_dk_dv >= 48 * 1024) {
             FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));

@@ -104,7 +105,7 @@ void run_fmha_dgrad_fp16_sm80_loop_(FMHA_dgrad_params &params, cudaStream_t stream
             kernel<<<grid, Kernel_traits::THREADS, smem_size_dq_dk_dv, stream>>>(params);
         } else {
             dim3 grid_dot(params.b, params.h, (params.seqlen_q + 128 - 1) / 128);
-            fmha_dgrad_dot_do_o_kernel<Kernel_traits><<<grid_dot, Kernel_traits::THREADS, 0, stream>>>(params);
+            fmha_bwd_dot_do_o_kernel<Kernel_traits><<<grid_dot, Kernel_traits::THREADS, 0, stream>>>(params);
             int num_splits = params.seqlen_k / blocksize_c;  // seqlen_k is divisible by blocksize_c
             dim3 grid(params.b, params.h, num_splits);
             kernel_seqparallel<<<grid, Kernel_traits::THREADS, smem_size_dq_dk_dv, stream>>>(params);

@@ -112,42 +113,3 @@ void run_fmha_dgrad_fp16_sm80_loop_(FMHA_dgrad_params &params, cudaStream_t stream
         FMHA_CHECK_CUDA(cudaPeekAtLastError());
     });
 }
-
-void run_fmha_dgrad_fp16_sm80(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure) {
-    // work around for MSVC issue
-    FP16_SWITCH(params.is_bf16, [&] {
-        auto dprops = at::cuda::getCurrentDeviceProperties();
-        if (params.d <= 32) {
-            if (params.seqlen_k == 128) {
-                using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 8, 0x08u, elem_type>;
-                run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream, configure);
-            } else if (params.seqlen_k >= 256) {
-                using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 8, 0x08u, elem_type>;
-                run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream, configure);
-            }
-        } else if (params.d <= 64) {
-            if (params.seqlen_k == 128) {
-                using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>;
-                run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream, configure);
-            } else if (params.seqlen_k >= 256) {
-                if (dprops->major == 8 && dprops->minor == 0) {
-                    // Don't share smem for K & V, and don't keep V in registers
-                    // This speeds things up by 2-3% by avoiding register spills, but it
-                    // uses more shared memory, which is fine on A100 but not other GPUs.
-                    // For other GPUs, we keep V in registers.
-                    using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u, elem_type>;
-                    run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream, configure);
-                } else if (dprops->major == 8 && dprops->minor > 0) {
-                    using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x08u, elem_type>;
-                    run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream, configure);
-                } else if (dprops->major == 7 && dprops->minor == 5) {
-                    using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>;
-                    run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream, configure);
-                }
-            }
-        } else if (params.d <= 128) {
-            using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 8, 0x100u, elem_type>;
-            run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream, configure);
-        }
-    });
-}
\ No newline at end of file
csrc/flash_attn/src/fmha_fwd_hdim128.cu  (new file, mode 100644)

// Copyright (c) 2022, Tri Dao.

// Splitting the different head dimensions to different files to speed up compilation.

#include "fmha_fwd_launch_template.h"

void run_fmha_fwd_hdim128(Launch_params<FMHA_fprop_params> &launch_params) {
    FP16_SWITCH(launch_params.params.is_bf16, [&] {
        using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>;
        run_fmha_fwd_loop<Kernel_traits>(launch_params);
    });
}
\ No newline at end of file
csrc/flash_attn/src/fmha_fwd_hdim32.cu  (new file, mode 100644)

// Copyright (c) 2022, Tri Dao.

// Splitting the different head dimensions to different files to speed up compilation.

#include "fmha_fwd_launch_template.h"

void run_fmha_fwd_hdim32(Launch_params<FMHA_fprop_params> &launch_params) {
    FP16_SWITCH(launch_params.params.is_bf16, [&] {
        if (launch_params.params.seqlen_k == 128) {
            using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 4, 0x08u, elem_type>;
            run_fmha_fwd_loop<Kernel_traits>(launch_params);
        } else if (launch_params.params.seqlen_k >= 256) {
            using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u, elem_type>;
            run_fmha_fwd_loop<Kernel_traits>(launch_params);
        }
    });
}
\ No newline at end of file
csrc/flash_attn/src/fmha_fwd_hdim64.cu  (new file, mode 100644)

// Copyright (c) 2022, Tri Dao.

// Splitting the different head dimensions to different files to speed up compilation.

#include "fmha_fwd_launch_template.h"

void run_fmha_fwd_hdim64(Launch_params<FMHA_fprop_params> &launch_params) {
    FP16_SWITCH(launch_params.params.is_bf16, [&] {
        if (launch_params.params.seqlen_k == 128) {
            using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
            run_fmha_fwd_loop<Kernel_traits>(launch_params);
        } else if (launch_params.params.seqlen_k >= 256) {
            using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>;
            run_fmha_fwd_loop<Kernel_traits>(launch_params);
        }
    });
}
csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu → csrc/flash_attn/src/fmha_fwd_launch_template.h

-/******************************************************************************
- * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.
- *
- * Redistribution and use in source and binary forms, with or without
- * modification, are permitted provided that the following conditions are met:
- *     * Redistributions of source code must retain the above copyright
- *       notice, this list of conditions and the following disclaimer.
- *     * Redistributions in binary form must reproduce the above copyright
- *       notice, this list of conditions and the following disclaimer in the
- *       documentation and/or other materials provided with the distribution.
- *     * Neither the name of the NVIDIA CORPORATION nor the
- *       names of its contributors may be used to endorse or promote products
- *       derived from this software without specific prior written permission.
- *
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
- * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
- * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
- * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
- * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
- * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
- * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
- * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
- * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
- * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
- *
- ******************************************************************************/
+// Copyright (c) 2022, Tri Dao.
+
+#pragma once
 
 #include <vector>
 
 #include <cuda_fp16.h>
 #include <cuda_bf16.h>

@@ -39,7 +18,8 @@
 // splits as that would incur more HBM reads/writes.
 // So we find the best efficiency, then find the smallest number of splits that gets 95%
 // of the best efficiency.
-int num_splits_heuristic_fwd(int batch_nheads, int num_SMs, int ctas_per_sm, int max_splits) {
+// [2022-11-25] TD: Mark this as "inline" otherwise we get "multiple definition" error.
+inline int num_splits_heuristic_fwd(int batch_nheads, int num_SMs, int ctas_per_sm, int max_splits) {
     float max_efficiency = 0.f;
     std::vector<float> efficiency;
     efficiency.reserve(max_splits);
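The new "[2022-11-25] TD" comment records why inline is needed here: once this launch template became a header included from several .cu files, a non-inline function with external linkage would be defined in every translation unit and the link would fail with "multiple definition". A minimal standalone illustration of that rule (hypothetical files, not from the repository):

    // util.h, included from more than one translation unit (hypothetical example).
    #pragma once

    // int answer() { return 42; }      // ODR violation: defined in every .cu/.cpp that
    //                                  // includes this header -> "multiple definition" at link time.

    inline int answer() { return 42; }  // OK: 'inline' permits identical definitions in
                                        // multiple translation units; the linker keeps one.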
@@ -60,12 +40,12 @@ int num_splits_heuristic_fwd(int batch_nheads, int num_SMs, int ctas_per_sm, int
 }
 
 template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Return_softmax>
-__global__ void fmha_fprop_fp16_sm80_loop_kernel(FMHA_fprop_params params) {
+__global__ void fmha_fwd_loop_kernel(FMHA_fprop_params params) {
     fmha::device_1xN_loop<Kernel_traits, Is_dropout, Is_causal, Return_softmax>(params);
 }
 
 template<typename Kernel_traits>
-void run_fmha_fp16_sm80_loop_(Launch_params<FMHA_fprop_params> &launch_params) {
+void run_fmha_fwd_loop(Launch_params<FMHA_fprop_params> &launch_params) {
     constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N;
     const int loop_steps = (launch_params.params.seqlen_k + blocksize_c - 1) / blocksize_c;

@@ -80,11 +60,11 @@ void run_fmha_fp16_sm80_loop_(Launch_params<FMHA_fprop_params> &launch_params) {
     BOOL_SWITCH(launch_params.is_dropout, IsDropoutConst, [&] {
         auto kernel = launch_params.params.is_causal
             ? (launch_params.return_softmax
-               ? &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, IsDropoutConst, true, true>
-               : &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, IsDropoutConst, true, false>)
+               ? &fmha_fwd_loop_kernel<Kernel_traits, IsDropoutConst, true, true>
+               : &fmha_fwd_loop_kernel<Kernel_traits, IsDropoutConst, true, false>)
             : (launch_params.return_softmax
-               ? &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, IsDropoutConst, false, true>
-               : &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, IsDropoutConst, false, false>);
+               ? &fmha_fwd_loop_kernel<Kernel_traits, IsDropoutConst, false, true>
+               : &fmha_fwd_loop_kernel<Kernel_traits, IsDropoutConst, false, false>);
         if (smem_size >= 48 * 1024) {
             FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));

@@ -110,44 +90,3 @@ void run_fmha_fp16_sm80_loop_(Launch_params<FMHA_fprop_params> &launch_params) {
         FMHA_CHECK_CUDA(cudaPeekAtLastError());
     });
 }
-
-void run_fmha_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params) {
-    FP16_SWITCH(launch_params.params.is_bf16, [&] {
-        auto dprops = at::cuda::getCurrentDeviceProperties();
-        if (launch_params.params.d <= 32) {
-            if (launch_params.params.seqlen_k == 128) {
-                using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 4, 0x08u, elem_type>;
-                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
-            } else if (launch_params.params.seqlen_k >= 256) {
-                using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u, elem_type>;
-                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
-            }
-        } else if (launch_params.params.d <= 64) {
-            if (launch_params.params.seqlen_k == 128) {
-                using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
-                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
-            } else if (launch_params.params.seqlen_k >= 256) {
-                using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>;
-                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
-            }
-        } else if (launch_params.params.d <= 128) {
-            // TD [2022-10-21]: Previously for SM80 we use block size 256 and keep K in shared memory
-            // to reduce register spilling. However, that increases the smem usage from ~41KB to ~105KB,
-            // reducing occupancy (only 1 kernel can be scheduled per SM instead of 2). This strategy gives
-            // some speedup (6-10%) for large batch size, but slows things down for small batch size.
-            // Now that we have better parallelism (over seqlen_q), block size 128 is faster for small
-            // batch size and only slightly slower (~3%) on large batch size.
-            // For causal=True, block size 128 seems always faster (for small & large batch size).
-            // So we're just gonna use block size 128 for simplicity.
-            using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>;
-            run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
-        }
-        // if (launch_params.params.d == 64) {
-        //     // using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
-        //     // using Kernel_traits = FMHA_kernel_traits<64, 64, 16, 1, 4, 0x08u, elem_type>;
-        //     // using Kernel_traits = FMHA_kernel_traits<512, 64, 16, 1, 8, 0x08u, elem_type>;
-        //     using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
-        //     run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
-        // }
-    });
-}
\ No newline at end of file
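The removed "TD [2022-10-21]" comment above reasons in terms of shared-memory-limited occupancy: roughly, the number of CTAs an SM can host is its shared-memory budget divided by the per-CTA smem footprint, so ~105 KB per CTA leaves room for only one CTA while ~41 KB leaves room for at least two. A small hedged sketch of that arithmetic (the 164 KB figure is an assumption about A100-class SMs, not stated in the diff):

    // Back-of-the-envelope occupancy check for the smem numbers quoted in the removed comment.
    // Assumes an A100-class SM with roughly 164 KB of shared memory available per SM.
    #include <cstdio>

    int main() {
        const int smem_per_sm_kb  = 164;  // assumed A100-class budget
        const int smem_block256_kb = 105; // ~105 KB/CTA with block size 256 (from the comment)
        const int smem_block128_kb = 41;  // ~41 KB/CTA with block size 128 (from the comment)
        std::printf("block 256: %d CTA(s) per SM (smem-limited)\n",
                    smem_per_sm_kb / smem_block256_kb);  // 1
        std::printf("block 128: %d CTA(s) per SM (smem-limited)\n",
                    smem_per_sm_kb / smem_block128_kb);  // 4, capped to ~2 by other limits
        return 0;
    }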
csrc/flash_attn/src/fmha_kernel.h

@@ -75,107 +75,4 @@ struct BlockInfoPadded {
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template<int CHUNKS, typename Cta_tile>
-struct Noloop_traits{
-    // Interpretation of Cta_tile dims, i.e. Cta_tile_p:
-    enum{ STEP = Cta_tile::M };
-    enum{ SEQLEN = Cta_tile::N };
-
-    template<typename Block_info>
-    inline __device__ Noloop_traits(const int bidc, const Block_info &binfo)
-        : bidc_(bidc) {
-        const int seqlen = binfo.actual_seqlen;
-        const int steps = (seqlen + STEP - 1) / STEP;
-        const int steps_per_chunk = (steps + CHUNKS - 1) / CHUNKS;
-
-        const int step_begin = bidc_ * steps_per_chunk;
-        const int step_end = min(steps, (bidc_ + 1) * steps_per_chunk);
-        const int actual_steps = max(0, step_end - step_begin);
-        loop_offset_ = step_begin;
-        num_steps_ = actual_steps;
-    }
-
-    template<typename... Tiles>
-    inline __device__ void move_all(Tiles &... tiles) const {
-        using expand_type = int[];
-        for( int s = 0; s < loop_offset_; s++ ) {
-            expand_type{ (tiles.move(), 0)... };
-        }
-    }
-
-    inline __device__ int get_idx_dk() const {
-        //return bidc_;
-        return bidc_ * 2 + 0;
-    }
-
-    inline __device__ int get_idx_dv() const {
-        //return CHUNKS + bidc_;
-        return bidc_ * 2 + 1;
-    }
-
-    inline __device__ int offset_loop_count(const int l) {
-        // convert loop counter to position in the outer sequence
-        return (loop_offset_ + l) * STEP;
-    }
-
-    const uint32_t bidc_;
-    int loop_offset_;
-    int num_steps_;
-};
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-template<typename Kernel_traits>
-std::tuple<int, int, int, int, int, int> work_dist(const int total_ctas, const int heads_total) {
-
-    constexpr int STEPS_PER_HEAD = Kernel_traits::Cta_tile_p::N / Kernel_traits::Cta_tile_p::M;
-
-    const int num_full_heads = heads_total / total_ctas;
-    const int heads_last_wave = heads_total % total_ctas;
-
-    int num_main_groups = 0;
-    int main_steps = 0;
-    int rest_steps = 0;
-    if( heads_last_wave > 0 ) {
-        // Number of CTA groups that process within heads.
-        num_main_groups = total_ctas / heads_last_wave;
-        // Remaining CTAs that process between heads.
-        const int rest_ctas = total_ctas - (heads_last_wave * num_main_groups);
-        if( rest_ctas == 0 ) {
-            // We have exactly "num_main_groups" CTAs to process each of the remaining heads.
-            main_steps = (STEPS_PER_HEAD + num_main_groups - 1) / num_main_groups;
-            num_main_groups = STEPS_PER_HEAD / main_steps;  // Here: main_step > 0
-            rest_steps = STEPS_PER_HEAD % main_steps;
-        } else {
-            // Ideal number of steps if we could load-balance as evenly as possible.
-            const int steps_ideal = (heads_last_wave * STEPS_PER_HEAD + total_ctas - 1) / total_ctas;
-            // Iterations that a "rest" CTA has to do at most.
-            const int max_rest_iters = (heads_last_wave + rest_ctas - 1) / rest_ctas;
-            // Find the first step distribution, s.t. the maximum work of the "rest" CTAs is less than the work of the main CTAs.
-            main_steps = steps_ideal;
-            rest_steps = STEPS_PER_HEAD - main_steps * num_main_groups;
-            for( ; main_steps * num_main_groups < STEPS_PER_HEAD; main_steps++ ) {
-                rest_steps = STEPS_PER_HEAD - main_steps * num_main_groups;
-                const int max_rest_total_steps = rest_steps * max_rest_iters;
-                if( max_rest_total_steps < main_steps )
-                    break;
-            }
-            rest_steps = STEPS_PER_HEAD - main_steps * num_main_groups;
-        }
-    }
-
-    using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
-    using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;
-
-    const int max_steps = STEPS_PER_HEAD * num_full_heads + std::max(main_steps, rest_steps);
-    const int elts_per_thread_per_step = Mma_tile_p::MMAS_M * Mma_tile_p::MMAS_N * 8;
-    const int elts_per_thread = max_steps * elts_per_thread_per_step;
-
-    return {num_full_heads, num_main_groups, heads_last_wave, main_steps, rest_steps, elts_per_thread};
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
 }  // namespace fmha
setup.py

@@ -119,8 +119,12 @@ ext_modules.append(
         name="flash_attn_cuda",
         sources=[
             "csrc/flash_attn/fmha_api.cpp",
-            "csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu",
-            "csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu",
+            "csrc/flash_attn/src/fmha_fwd_hdim32.cu",
+            "csrc/flash_attn/src/fmha_fwd_hdim64.cu",
+            "csrc/flash_attn/src/fmha_fwd_hdim128.cu",
+            "csrc/flash_attn/src/fmha_bwd_hdim32.cu",
+            "csrc/flash_attn/src/fmha_bwd_hdim64.cu",
+            "csrc/flash_attn/src/fmha_bwd_hdim128.cu",
             "csrc/flash_attn/src/fmha_block_fprop_fp16_kernel.sm80.cu",
             "csrc/flash_attn/src/fmha_block_dgrad_fp16_kernel_loop.sm80.cu",
         ],

@@ -152,7 +156,7 @@ ext_modules.append(
 setup(
     name="flash_attn",
-    version="0.2.1",
+    version="0.2.2",
     packages=find_packages(
         exclude=("build", "csrc", "include", "tests", "dist", "docs", "benchmarks", "flash_attn.egg-info",)
     ),