Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
flash-attention
Commits
dc4b9ad6
Unverified
Commit
dc4b9ad6
authored
Nov 19, 2023
by
Driss Guessous
Committed by
GitHub
Nov 19, 2023
Browse files
add checks (#640)
parent
01771645
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
27 additions
and
0 deletions
+27
-0
csrc/flash_attn/src/flash_bwd_launch_template.h
csrc/flash_attn/src/flash_bwd_launch_template.h
+21
-0
csrc/flash_attn/src/flash_fwd_launch_template.h
csrc/flash_attn/src/flash_fwd_launch_template.h
+6
-0
No files found.
csrc/flash_attn/src/flash_bwd_launch_template.h
View file @
dc4b9ad6
...
...
@@ -143,6 +143,9 @@ void run_mha_bwd_hdim32(Flash_bwd_params ¶ms, cudaStream_t stream, const boo
int
max_smem_per_block
;
cudaError
status_
=
cudaDeviceGetAttribute
(
&
max_smem_per_block
,
cudaDevAttrMaxSharedMemoryPerBlockOptin
,
device
);
if
(
status_
!=
cudaSuccess
)
{
C10_CUDA_CHECK
(
status_
);
}
BOOL_SWITCH
(
params
.
p_dropout
<
1.
f
,
Is_dropout
,
[
&
]
{
if
(
max_smem_per_block
>=
2
*
((
3
*
128
+
2
*
128
)
*
Headdim
+
2
*
128
*
128
))
{
// 104 KB
if
constexpr
(
!
Is_dropout
)
{
// We can afford more registers to keep V in registers
...
...
@@ -164,6 +167,9 @@ void run_mha_bwd_hdim64(Flash_bwd_params ¶ms, cudaStream_t stream, const boo
int
max_smem_per_block
;
cudaError
status_
=
cudaDeviceGetAttribute
(
&
max_smem_per_block
,
cudaDevAttrMaxSharedMemoryPerBlockOptin
,
device
);
if
(
status_
!=
cudaSuccess
)
{
C10_CUDA_CHECK
(
status_
);
}
// printf("max_smem_per_block = %d\n", max_smem_per_block);
BOOL_SWITCH
(
params
.
p_dropout
<
1.
f
,
Is_dropout
,
[
&
]
{
// Changing AtomLayoutMdQ from 2 to 4 takes the same time
...
...
@@ -207,6 +213,9 @@ void run_mha_bwd_hdim96(Flash_bwd_params ¶ms, cudaStream_t stream, const boo
int
max_smem_per_block
;
cudaError
status_
=
cudaDeviceGetAttribute
(
&
max_smem_per_block
,
cudaDevAttrMaxSharedMemoryPerBlockOptin
,
device
);
if
(
status_
!=
cudaSuccess
)
{
C10_CUDA_CHECK
(
status_
);
}
// printf("max_smem_per_block = %d\n", max_smem_per_block);
BOOL_SWITCH
(
params
.
p_dropout
<
1.
f
,
Is_dropout
,
[
&
]
{
// if (params.h == params.h_k) {
...
...
@@ -234,6 +243,9 @@ void run_mha_bwd_hdim128(Flash_bwd_params ¶ms, cudaStream_t stream, const bo
int
max_smem_per_block
;
cudaError
status_
=
cudaDeviceGetAttribute
(
&
max_smem_per_block
,
cudaDevAttrMaxSharedMemoryPerBlockOptin
,
device
);
if
(
status_
!=
cudaSuccess
)
{
C10_CUDA_CHECK
(
status_
);
}
// printf("max_smem_per_block = %d\n", max_smem_per_block);
BOOL_SWITCH
(
params
.
p_dropout
<
1.
f
,
Is_dropout
,
[
&
]
{
// if (params.h == params.h_k) {
...
...
@@ -270,6 +282,9 @@ void run_mha_bwd_hdim160(Flash_bwd_params ¶ms, cudaStream_t stream, const bo
int
max_smem_per_block
;
cudaError
status_
=
cudaDeviceGetAttribute
(
&
max_smem_per_block
,
cudaDevAttrMaxSharedMemoryPerBlockOptin
,
device
);
if
(
status_
!=
cudaSuccess
)
{
C10_CUDA_CHECK
(
status_
);
}
BOOL_SWITCH
(
params
.
p_dropout
<
1.
f
,
Is_dropout
,
[
&
]
{
if
(
max_smem_per_block
>=
116
*
1024
)
{
run_flash_bwd
<
Flash_bwd_kernel_traits
<
Headdim
,
64
,
64
,
8
,
4
,
4
,
4
,
false
,
false
,
T
>
,
Is_dropout
>
(
params
,
stream
,
configure
);
...
...
@@ -287,6 +302,9 @@ void run_mha_bwd_hdim192(Flash_bwd_params ¶ms, cudaStream_t stream, const bo
int
max_smem_per_block
;
cudaError
status_
=
cudaDeviceGetAttribute
(
&
max_smem_per_block
,
cudaDevAttrMaxSharedMemoryPerBlockOptin
,
device
);
if
(
status_
!=
cudaSuccess
)
{
C10_CUDA_CHECK
(
status_
);
}
BOOL_SWITCH
(
params
.
p_dropout
<
1.
f
,
Is_dropout
,
[
&
]
{
if
(
max_smem_per_block
>=
136
*
1024
)
{
run_flash_bwd
<
Flash_bwd_kernel_traits
<
Headdim
,
64
,
64
,
8
,
4
,
2
,
2
,
false
,
false
,
T
>
,
Is_dropout
>
(
params
,
stream
,
configure
);
...
...
@@ -312,6 +330,9 @@ void run_mha_bwd_hdim256(Flash_bwd_params ¶ms, cudaStream_t stream, const bo
int
max_smem_per_block
;
cudaError
status_
=
cudaDeviceGetAttribute
(
&
max_smem_per_block
,
cudaDevAttrMaxSharedMemoryPerBlockOptin
,
device
);
if
(
status_
!=
cudaSuccess
)
{
C10_CUDA_CHECK
(
status_
);
}
BOOL_SWITCH
(
params
.
p_dropout
<
1.
f
,
Is_dropout
,
[
&
]
{
if
(
max_smem_per_block
>=
176
*
1024
)
{
// H100
run_flash_bwd
<
Flash_bwd_kernel_traits
<
Headdim
,
64
,
64
,
8
,
4
,
2
,
2
,
false
,
false
,
T
>
,
Is_dropout
>
(
params
,
stream
,
configure
);
...
...
csrc/flash_attn/src/flash_fwd_launch_template.h
View file @
dc4b9ad6
...
...
@@ -289,6 +289,9 @@ void run_mha_fwd_hdim224(Flash_fwd_params ¶ms, cudaStream_t stream) {
int
max_smem_per_block
;
cudaError
status_
=
cudaDeviceGetAttribute
(
&
max_smem_per_block
,
cudaDevAttrMaxSharedMemoryPerBlockOptin
,
device
);
if
(
status_
!=
cudaSuccess
)
{
C10_CUDA_CHECK
(
status_
);
}
// printf("max_smem_per_block = %d\n", max_smem_per_block);
BOOL_SWITCH
(
params
.
p_dropout
<
1.
f
,
Is_dropout
,
[
&
]
{
BOOL_SWITCH
(
params
.
is_causal
,
Is_causal
,
[
&
]
{
...
...
@@ -317,6 +320,9 @@ void run_mha_fwd_hdim256(Flash_fwd_params ¶ms, cudaStream_t stream) {
&
max_smem_per_sm
,
cudaDevAttrMaxSharedMemoryPerMultiprocessor
,
device
);
status_
=
cudaDeviceGetAttribute
(
&
max_smem_per_block
,
cudaDevAttrMaxSharedMemoryPerBlockOptin
,
device
);
if
(
status_
!=
cudaSuccess
)
{
C10_CUDA_CHECK
(
status_
);
}
// printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block);
BOOL_SWITCH
(
params
.
p_dropout
<
1.
f
,
Is_dropout
,
[
&
]
{
BOOL_SWITCH
(
params
.
is_causal
,
Is_causal
,
[
&
]
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment