gaoqiong / flash-attention

Commit 47ad0761: Remove backward pass
Authored Mar 28, 2024 by Woosuk Kwon
Parent: 06c23892
Showing 20 changed files with 2 additions and 2256 deletions (+2, -2256):
csrc/flash_attn/flash_api.cpp                         +2   -563
csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu    +0   -10
csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu    +0   -10
csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu    +0   -10
csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu    +0   -10
csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu    +0   -10
csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu    +0   -10
csrc/flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu    +0   -10
csrc/flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu    +0   -10
csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu    +0   -10
csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu    +0   -10
csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu     +0   -10
csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu     +0   -10
csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu     +0   -10
csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu     +0   -10
csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu     +0   -10
csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu     +0   -10
csrc/flash_attn/src/flash_bwd_kernel.h                +0   -826
csrc/flash_attn/src/flash_bwd_launch_template.h       +0   -330
csrc/flash_attn/src/flash_bwd_preprocess_kernel.h     +0   -377

csrc/flash_attn/flash_api.cpp  (+2, -563)

(This diff is collapsed on the commit page; its contents are not shown here.)
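
Since the flash_api.cpp diff itself is not expanded on this page, here is a minimal, hedged sketch of the kind of top-level backward dispatch that routes into the per-head-dimension instantiations deleted below. It is not the removed code: the FP16_SWITCH / HEADDIM_SWITCH names, the elem_type / kHeadDim identifiers they are assumed to introduce, and the exact run_mha_bwd signature are assumptions based on the repository's static_switch.h conventions.

// Hedged sketch only, assuming the switch macros from "static_switch.h" and the
// Flash_bwd_params definition from "flash.h"; not a quote of the removed flash_api.cpp.
void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) {
    FP16_SWITCH(!params.is_bf16, [&] {         // assumed to bind elem_type to half_t or bfloat16_t
        HEADDIM_SWITCH(params.d, [&] {         // assumed to bind kHeadDim to a supported bucket (32..256)
            run_mha_bwd_<elem_type, kHeadDim>(params, stream);
        });
    });
}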

csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu  (deleted, 100644 → 0)

// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

template<>
void run_mha_bwd_<cutlass::bfloat16_t, 128>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim128<cutlass::bfloat16_t>(params, stream);
}
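
Each of these auto-generated .cu files holds exactly one explicit specialization of run_mha_bwd_ for a (data type, head dimension) pair, so each specialization compiles as its own translation unit. For orientation, a hedged sketch of the primary template they specialize, which presumably lives in a shared header such as flash.h (the exact header and declaration are assumptions, not quoted from the repository):

// Assumed primary template declaration (sketch only):
template<typename T, int Headdim>
void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);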

csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu  (deleted, 100644 → 0)

// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

template<>
void run_mha_bwd_<cutlass::half_t, 128>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim128<cutlass::half_t>(params, stream);
}

csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu  (deleted, 100644 → 0)

// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

template<>
void run_mha_bwd_<cutlass::bfloat16_t, 160>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim160<cutlass::bfloat16_t>(params, stream);
}

csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu  (deleted, 100644 → 0)

// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

template<>
void run_mha_bwd_<cutlass::half_t, 160>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim160<cutlass::half_t>(params, stream);
}

csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu  (deleted, 100644 → 0)

// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

template<>
void run_mha_bwd_<cutlass::bfloat16_t, 192>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim192<cutlass::bfloat16_t>(params, stream);
}

csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu  (deleted, 100644 → 0)

// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

template<>
void run_mha_bwd_<cutlass::half_t, 192>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim192<cutlass::half_t>(params, stream);
}

csrc/flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu  (deleted, 100644 → 0)

// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

template<>
void run_mha_bwd_<cutlass::bfloat16_t, 224>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim224<cutlass::bfloat16_t>(params, stream);
}

csrc/flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu  (deleted, 100644 → 0)

// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

template<>
void run_mha_bwd_<cutlass::half_t, 224>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim224<cutlass::half_t>(params, stream);
}

csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu  (deleted, 100644 → 0)

// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

template<>
void run_mha_bwd_<cutlass::bfloat16_t, 256>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim256<cutlass::bfloat16_t>(params, stream);
}

csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu  (deleted, 100644 → 0)

// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

template<>
void run_mha_bwd_<cutlass::half_t, 256>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim256<cutlass::half_t>(params, stream);
}

csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu  (deleted, 100644 → 0)

// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

template<>
void run_mha_bwd_<cutlass::bfloat16_t, 32>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim32<cutlass::bfloat16_t>(params, stream);
}

csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu  (deleted, 100644 → 0)

// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

template<>
void run_mha_bwd_<cutlass::half_t, 32>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim32<cutlass::half_t>(params, stream);
}

csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu  (deleted, 100644 → 0)

// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

template<>
void run_mha_bwd_<cutlass::bfloat16_t, 64>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim64<cutlass::bfloat16_t>(params, stream);
}

csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu  (deleted, 100644 → 0)

// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

template<>
void run_mha_bwd_<cutlass::half_t, 64>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim64<cutlass::half_t>(params, stream);
}

csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu  (deleted, 100644 → 0)

// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

template<>
void run_mha_bwd_<cutlass::bfloat16_t, 96>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim96<cutlass::bfloat16_t>(params, stream);
}

csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu  (deleted, 100644 → 0)

// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

template<>
void run_mha_bwd_<cutlass::half_t, 96>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim96<cutlass::half_t>(params, stream);
}

csrc/flash_attn/src/flash_bwd_kernel.h  (deleted, 100644 → 0)

(This diff is collapsed on the commit page; its contents are not shown here.)

csrc/flash_attn/src/flash_bwd_launch_template.h  (deleted, 100644 → 0)

/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/

#pragma once

#include <ATen/cuda/CUDAContext.h>

#include "static_switch.h"
#include "flash.h"
#include "flash_bwd_preprocess_kernel.h"
#include "flash_bwd_kernel.h"

// Determine if the architecture supports FLASH and define a macro to handle parameter modifiers
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#define ARCH_SUPPORTS_FLASH
#define KERNEL_PARAM_MODIFIER __grid_constant__
#else
#define KERNEL_PARAM_MODIFIER
#endif

// Define a macro for unsupported architecture handling to centralize the error message
#define FLASH_UNSUPPORTED_ARCH printf("FATAL: FlashAttention requires building with sm version sm80-sm90, but was built for < 8.0!");

// Use a macro to clean up kernel definitions
#define DEFINE_FLASH_BACKWARD_KERNEL(kernelName, ...) \
template<typename Kernel_traits, __VA_ARGS__> \
__global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_bwd_params params)
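
// Illustrative expansion (added note, not part of the original file): an invocation such as
//     DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_kernel, bool Is_dropout, bool Is_causal)
// expands to the kernel signature
//     template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
//     __global__ void flash_bwd_dq_dk_dv_loop_kernel(KERNEL_PARAM_MODIFIER const Flash_bwd_params params)
// where KERNEL_PARAM_MODIFIER is __grid_constant__ when compiling device code for sm80+ and empty
// otherwise; the function body then follows the macro call, as in the definitions below.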

DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_kernel, bool Is_dropout, bool Is_causal, bool Has_alibi, bool Is_even_M, bool Is_even_K) {
    #if defined(ARCH_SUPPORTS_FLASH)
        flash::compute_dq_dk_dv<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_M, Is_even_K>(params);
    #else
        FLASH_UNSUPPORTED_ARCH
    #endif
}

DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K) {
    #if defined(ARCH_SUPPORTS_FLASH)
        static_assert(!(Is_causal && Is_local));  // If Is_local is true, Is_causal should be false
        flash::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K>(params);
    #else
        FLASH_UNSUPPORTED_ARCH
    #endif
}

template<bool Clear_dQaccum=true, typename Kernel_traits>
__global__ void flash_bwd_dot_do_o_kernel(const Flash_bwd_params params) {
    flash::compute_dot_do_o<Clear_dQaccum, Kernel_traits>(params);
}

template<typename Kernel_traits>
__global__ void flash_bwd_clear_dkvaccum_kernel(const Flash_bwd_params params) {
    flash::clear_dKVaccum<Kernel_traits>(params);
}

template<typename Kernel_traits>
__global__ void flash_bwd_convert_dq_kernel(const Flash_bwd_params params, const int nsplits) {
    flash::convert_dQ<Kernel_traits>(params, nsplits);
}

template<typename Kernel_traits>
__global__ void flash_bwd_convert_dkv_kernel(const Flash_bwd_params params) {
    flash::convert_dKV<Kernel_traits>(params);
}

template<typename Kernel_traits, bool Is_dropout>
void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream) {
    const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
    dim3 grid_m(num_m_block, params.b, params.h);
    const int num_n_block = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;
    int gridDimx = num_n_block;
    if (params.deterministic) {
        auto dprops = at::cuda::getCurrentDeviceProperties();
        gridDimx = (dprops->multiProcessorCount + params.b * params.h - 1) / (params.b * params.h);
    }
    dim3 grid_n(gridDimx, params.b, params.h);
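
    // Illustrative numbers (added note, not part of the original file): with seqlen_q = seqlen_k = 1024,
    // kBlockM = kBlockN = 128, batch size b = 2 and h = 16 heads, num_m_block = num_n_block = (1024 + 127) / 128 = 8,
    // so grid_m = grid_n = (8, 2, 16) in the default path. In deterministic mode gridDimx is instead
    // (multiProcessorCount + b*h - 1) / (b*h), e.g. (108 + 31) / 32 = 4 on a 108-SM device, which fixes the
    // number of partial dQ accumulators per (batch, head); that count is what gets passed as nsplits to the
    // dQ conversion kernel at the end of this function.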

    if (!params.deterministic) {
        flash_bwd_dot_do_o_kernel<true, Kernel_traits><<<grid_m, Kernel_traits::kNThreads, 0, stream>>>(params);
    } else {
        flash_bwd_dot_do_o_kernel<false, Kernel_traits><<<grid_m, Kernel_traits::kNThreads, 0, stream>>>(params);
    }
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    // We want to specialize to is_even_MN and not just is_even_M, since in the case where N is not
    // a multiple of kBlockN, we'll need to apply mask in the loop.
    const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr
        && params.seqlen_q % Kernel_traits::kBlockM == 0 && params.seqlen_k % Kernel_traits::kBlockN == 0;
    const bool is_even_K = params.d == Kernel_traits::kHeadDim;
    constexpr int smem_size_dq_dk_dv = Kernel_traits::kSmemSize1colblock;
    // printf("smem_size_dq_dk_dv = %d\n", smem_size_dq_dk_dv);
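
    // Worked example (added note, not part of the original file): with kBlockM = kBlockN = 128,
    // seqlen_q = 2048 and seqlen_k = 2000 (and no cu_seqlens), is_even_MN is false because
    // 2000 % 128 == 80, so the last key block is partial and the kernel has to mask it inside the loop;
    // with seqlen_k = 2048 both remainders are 0 and the even-MN specialization becomes eligible below.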

    BOOL_SWITCH(params.is_causal, Is_causal, [&] {
        BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
            EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
                LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal, Is_local, [&] {
                    ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
                        // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
                        // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
                        // If Is_local, set Is_causal to false
                        auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst>;
                        // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, false, Is_causal, false, false, true, true>;
                        if (smem_size_dq_dk_dv >= 48 * 1024) {
                            C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
                        }
                        kernel<<<grid_n, Kernel_traits::kNThreads, smem_size_dq_dk_dv, stream>>>(params);
                        C10_CUDA_KERNEL_LAUNCH_CHECK();
                    });
                });
            });
        });
    });
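
    // Added note (not part of the original file): CUDA kernels may request at most 48 KB of dynamic
    // shared memory by default; asking for more requires a per-kernel opt-in via
    // cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, bytes), which is why
    // the launch above (and the dQ-conversion launch below) guard on the 48 * 1024 threshold.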

    auto kernel_dq = &flash_bwd_convert_dq_kernel<Kernel_traits>;
    if (Kernel_traits::kSmemdQSize >= 48 * 1024) {
        C10_CUDA_CHECK(cudaFuncSetAttribute(kernel_dq, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdQSize));
    }
    kernel_dq<<<grid_m, Kernel_traits::kNThreads, Kernel_traits::kSmemdQSize, stream>>>(params, !params.deterministic ? 1 : gridDimx);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
}

template<typename Kernel_traits, bool Is_dropout>
void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
#ifndef FLASHATTENTION_DISABLE_BACKWARD
    run_flash_bwd_seqk_parallel<Kernel_traits, Is_dropout>(params, stream);
#endif
}
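
// Added note (not part of the original file): if the build defines FLASHATTENTION_DISABLE_BACKWARD
// (for example by passing -DFLASHATTENTION_DISABLE_BACKWARD in the extension's compile flags),
// run_flash_bwd compiles to an empty body, so the run_mha_bwd_hdim* entry points below stop
// launching any backward kernels at build time.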

template<typename T>
void run_mha_bwd_hdim32(Flash_bwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 32;
    int device;
    cudaGetDevice(&device);
    int max_smem_per_block;
    cudaError status_ = cudaDeviceGetAttribute(
        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
    if (status_ != cudaSuccess) {
      C10_CUDA_CHECK(status_);
    }
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        if (max_smem_per_block >= 2 * ((3 * 128 + 2 * 128) * Headdim + 2 * 128 * 128)) {  // 104 KB
            if constexpr(!Is_dropout) {  // We can afford more registers to keep V in registers
                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream);
            } else {
                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
            }
        } else {  // 96 KB
            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream);
        }
    });
}
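
// Worked arithmetic (added note, not part of the original file) for the hdim32 threshold above:
// with Headdim = 32, 2 * ((3 * 128 + 2 * 128) * Headdim + 2 * 128 * 128)
//   = 2 * (640 * 32 + 32768) = 2 * 53248 = 106496 bytes = exactly 104 KB,
// matching the "104 KB" comment; devices whose opt-in shared-memory-per-block limit is below that
// fall through to the configuration annotated as 96 KB.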

template<typename T>
void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 64;
    int device;
    cudaGetDevice(&device);
    int max_smem_per_block;
    cudaError status_ = cudaDeviceGetAttribute(
        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
    if (status_ != cudaSuccess) {
      C10_CUDA_CHECK(status_);
    }
    // printf("max_smem_per_block = %d\n", max_smem_per_block);
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        // Changing AtomLayoutMdQ from 2 to 4 takes the same time
        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>>(params, stream);
        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, true, false, T>>(params, stream);
        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 2, 4, 4, false, false, T>>(params, stream);
        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>, Is_dropout>(params, stream);
        // This is slightly faster. We want to split M more so we need fewer registers to store LSE.
        if (max_smem_per_block >= 144 * 1024) {
            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
            // This has a lot of register spilling
            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream);
        } else {
            // if (params.h == params.h_k) {
                // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_dropout>(params, stream);
                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream);
                // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>, Is_dropout>(params, stream);
                // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, true, false, T>, Is_dropout>(params, stream);
            // } else {
            // }
        }
    });
    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, true, false, T>>(params, stream);
    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 4, 2, 2, 2, true, false, T>>(params, stream);
    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 128, 4, 1, 4, 1, false, false, T>>(params, stream);
    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 16, 128, 4, 1, 4, 1, false, false, T>>(params, stream);
    // M=128, N=64 is quite slow, I think because we need to read/write dQaccum twice as many times
    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 2, 2, 2, false, T>>(params, stream);
    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, false, T>>(params, stream);
    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 4, false, T>>(params, stream);
    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 4, 4, 2, 4, false, false, T>>(params, stream);
}
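
// Added note (not part of the original file), quoting nominal hardware figures from memory:
// cudaDevAttrMaxSharedMemoryPerBlockOptin reports roughly 163 KB on sm80 (A100-class), 99 KB on
// sm86/sm89 (consumer Ampere/Ada), and 227 KB on sm90 (H100-class), so the 144 * 1024 check above
// selects the 128 x 128 tiling on data-center parts and the 64 x 128 fallback elsewhere.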

template<typename T>
void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 96;
    int device;
    cudaGetDevice(&device);
    int max_smem_per_block;
    cudaError status_ = cudaDeviceGetAttribute(
        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
    if (status_ != cudaSuccess) {
      C10_CUDA_CHECK(status_);
    }
    // printf("max_smem_per_block = %d\n", max_smem_per_block);
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        if (max_smem_per_block >= 116 * 1024) {
            if constexpr(!Is_dropout) {  // 92KB
                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream);
            } else {  // 116 KB
                // This is faster for dropout since we don't have many registers to spare
                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_dropout>(params, stream);
            }
        } else {
            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream);
        }
    });
}

template<typename T>
void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 128;
    int device;
    cudaGetDevice(&device);
    int max_smem_per_block;
    cudaError status_ = cudaDeviceGetAttribute(
        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
    if (status_ != cudaSuccess) {
      C10_CUDA_CHECK(status_);
    }
    // printf("max_smem_per_block = %d\n", max_smem_per_block);
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 128, 8, 2, 2, 2, false, false, T>>(params, stream);
        // This is faster, in the case of sequence-parallel bwd (where we need fewer registers).
        // Out of these three, the 2nd one is slightly faster (2% faster than the first). Idk why.
        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 2, 2, false, false, T>>(params, stream);
        if (max_smem_per_block >= 144 * 1024) {
            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>, Is_dropout>(params, stream);
            // run_flash_bwd_seqk_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
            // run_flash_bwd_seqk_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, true, T>, Is_dropout>(params, stream);
            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, true, false, T>, Is_dropout>(params, stream);
            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream);
            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 2, true, false, T>, Is_dropout>(params, stream);
        } else {
            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream);
            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, false, T>, Is_dropout>(params, stream);
        }
        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>>(params, stream);
        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 4, 4, false, false, T>>(params, stream);
    });
}

template<typename T>
void run_mha_bwd_hdim160(Flash_bwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 160;
    int device;
    cudaGetDevice(&device);
    int max_smem_per_block;
    cudaError status_ = cudaDeviceGetAttribute(
        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
    if (status_ != cudaSuccess) {
      C10_CUDA_CHECK(status_);
    }
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        if (max_smem_per_block >= 116 * 1024) {
            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
        } else {
            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, true, T>, Is_dropout>(params, stream);
        }
    });
}

template<typename T>
void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 192;
    int device;
    cudaGetDevice(&device);
    int max_smem_per_block;
    cudaError status_ = cudaDeviceGetAttribute(
        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
    if (status_ != cudaSuccess) {
      C10_CUDA_CHECK(status_);
    }
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        if (max_smem_per_block >= 136 * 1024) {
            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream);
        } else {
            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, true, T>, Is_dropout>(params, stream);
        }
    });
}

template<typename T>
void run_mha_bwd_hdim224(Flash_bwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 224;
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
    });
}

template<typename T>
void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 256;
    int device;
    cudaGetDevice(&device);
    int max_smem_per_block;
    cudaError status_ = cudaDeviceGetAttribute(
        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
    if (status_ != cudaSuccess) {
      C10_CUDA_CHECK(status_);
    }
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        if (max_smem_per_block >= 176 * 1024) {  // H100
            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream);
        } else if (max_smem_per_block >= 144 * 1024) {  // A100, we don't do double buffering to save smem
            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, true, T>, Is_dropout>(params, stream);
        } else {
            // sm86 and sm89, max smem is 99 KB. Only works without dropout. V in regs and no double buffering.
            if constexpr (!Is_dropout) {
                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 32, 8, 4, 1, 2, true, true, T>, false>(params, stream);
            }
        }
    });
}

csrc/flash_attn/src/flash_bwd_preprocess_kernel.h  (deleted, 100644 → 0)

(This diff is collapsed on the commit page; its contents are not shown here.)