gaoqiong / flash-attention · Commits

Commit b1fbbd83
Implement splitKV attention
Authored Aug 29, 2023 by Tri Dao
Parent: 7a983df7

Changes: 25
Showing 20 changed files with 825 additions and 11 deletions (+825, -11)
csrc/flash_attn/flash_api.cpp (+66, -1)
csrc/flash_attn/src/flash.h (+5, -0)
csrc/flash_attn/src/flash_bwd_launch_template.h (+6, -6)
csrc/flash_attn/src/flash_fwd_kernel.h (+569, -0)
csrc/flash_attn/src/flash_fwd_launch_template.h (+74, -4)
csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu (+7, -0)
csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu (+7, -0)
csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu (+7, -0)
csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu (+7, -0)
csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu (+7, -0)
csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu (+7, -0)
csrc/flash_attn/src/flash_fwd_split_hdim224_bf16_sm80.cu (+7, -0)
csrc/flash_attn/src/flash_fwd_split_hdim224_fp16_sm80.cu (+7, -0)
csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu (+7, -0)
csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu (+7, -0)
csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu (+7, -0)
csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu (+7, -0)
csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu (+7, -0)
csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu (+7, -0)
csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu (+7, -0)
csrc/flash_attn/flash_api.cpp

...
@@ -178,11 +178,57 @@ void set_params_dgrad(Flash_bwd_params &params,

void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
    FP16_SWITCH(!params.is_bf16, [&] {
        FWD_HEADDIM_SWITCH(params.d, [&] {
-           run_mha_fwd_<elem_type, kHeadDim>(params, stream);
+           if (params.num_splits <= 1) {  // If we don't set it num_splits == 0
+               run_mha_fwd_<elem_type, kHeadDim>(params, stream);
+           } else {
+               run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim>(params, stream);
+           }
        });
    });
}
// Find the number of splits that maximizes the occupancy. For example, if we have
// batch * n_heads = 48 and we have 108 SMs, having 2 splits (efficiency = 0.89) is
// better than having 3 splits (efficiency = 0.67). However, we also don't want too many
// splits as that would incur more HBM reads/writes.
// So we find the best efficiency, then find the smallest number of splits that gets 85%
// of the best efficiency.
inline int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs, int num_n_blocks, int max_splits) {
    // If we have enough to almost fill the SMs, then just use 1 split
    if (batch_nheads_mblocks >= 0.8f * num_SMs) { return 1; }
    max_splits = std::min({max_splits, num_SMs, num_n_blocks});
    float max_efficiency = 0.f;
    std::vector<float> efficiency;
    efficiency.reserve(max_splits);
    auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };
    // Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits,
    // we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks
    // (i.e. it's 11 splits anyway).
    // So we check if the number of blocks per split is the same as the previous num_splits.
    auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) {
        return num_splits == 1 || ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1);
    };
    for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
        if (!is_split_eligible(num_splits)) {
            efficiency.push_back(0.f);
        } else {
            float n_waves = float(batch_nheads_mblocks * num_splits) / num_SMs;
            float eff = n_waves / ceil(n_waves);
            // printf("num_splits = %d, eff = %f\n", num_splits, eff);
            if (eff > max_efficiency) { max_efficiency = eff; }
            efficiency.push_back(eff);
        }
    }
    for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
        if (!is_split_eligible(num_splits)) { continue; }
        if (efficiency[num_splits - 1] >= 0.85 * max_efficiency) {
            // printf("num_splits chosen = %d\n", num_splits);
            return num_splits;
        }
    }
    return 1;
}
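To make the wave-efficiency trade-off concrete, here is a small standalone sketch (hypothetical driver code, not part of the commit) reproducing the example from the comment above: with 48 blocks of work (batch * n_heads * m_blocks) on 108 SMs, 2 splits fill 96/108 of a wave (~0.89), while 3 splits spill into a second wave at ~0.67 efficiency.

// Hypothetical standalone sketch (not in the commit) of the wave-efficiency
// measure that num_splits_heuristic is built around.
#include <cmath>
#include <cstdio>

int main() {
    const int batch_nheads_mblocks = 48;  // example from the comment above
    const int num_SMs = 108;              // e.g. an A100
    for (int num_splits = 1; num_splits <= 4; ++num_splits) {
        float n_waves = float(batch_nheads_mblocks * num_splits) / num_SMs;
        float eff = n_waves / std::ceil(n_waves);
        std::printf("num_splits = %d -> efficiency = %.2f\n", num_splits, eff);
    }
    // Prints 0.44, 0.89, 0.67, 0.89: 2 splits already reach the ~0.89 peak, and the
    // heuristic prefers the smallest split count within 85% of the best efficiency.
    return 0;
}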
std::vector<at::Tensor>
mha_fwd(const at::Tensor &q,         // batch_size x seqlen_q x num_heads x head_size
        const at::Tensor &k,         // batch_size x seqlen_k x num_heads_k x head_size
...
@@ -294,6 +340,25 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
                     softmax_scale,
                     is_causal);

    // This needs to match with run_mha_fwd_splitkv_dispatch
    const int block_n = is_sm90 || is_sm8x
        ? (head_size <= 64 ? 256 : (head_size <= 160 ? 128 : 64))
        : (head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64));
    const int num_n_blocks = (seqlen_k + block_n - 1) / block_n;
    // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel.
    // In any case we don't expect seqlen_q to be larger than 64 for inference.
    const int num_m_blocks = (seqlen_q + 64 - 1) / 64;
    params.num_splits = 1;
    if (p_dropout == 0.0f) {  // SplitKV is not implemented for dropout
        params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount, num_n_blocks, 64);
        if (params.num_splits > 1) {
            at::Tensor softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
            at::Tensor out_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q, head_size_rounded}, opts.dtype(at::kFloat));
            params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
            params.oaccum_ptr = out_accum.data_ptr();
        }
    }

    // number of times random will be generated per thread, to offset philox counter in thc random
    // state
    // We use a custom RNG that increases the offset by batch_size * nheads * 32.
...
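As a hypothetical illustration of how these scheduling inputs come out for a decoding-style call (values chosen for illustration, not taken from the commit): seqlen_q = 1, seqlen_k = 8192, head_size = 128, batch_size = 1, num_heads = 32 on a 108-SM device.

// Hypothetical arithmetic, not part of the commit.
#include <cstdio>

int main() {
    const int seqlen_q = 1, seqlen_k = 8192, batch_size = 1, num_heads = 32;
    const int block_n = 128;                                       // head_size = 128 lands on 128 in either branch above
    const int num_n_blocks = (seqlen_k + block_n - 1) / block_n;   // 64 KV blocks available to split
    const int num_m_blocks = (seqlen_q + 64 - 1) / 64;             // 1
    const int work = batch_size * num_heads * num_m_blocks;        // 32 independent (batch, head, M-block) tiles
    std::printf("num_n_blocks=%d num_m_blocks=%d work=%d\n", num_n_blocks, num_m_blocks, work);
    // 32 < 0.8 * 108, so num_splits_heuristic(32, 108, 64, 64) is consulted and, for these
    // numbers, picks more than one split so the GPU sees more than 32 concurrent CTAs.
    return 0;
}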
csrc/flash_attn/src/flash.h

...
@@ -53,6 +53,7 @@ struct Flash_fwd_params : public Qkv_params {

    // The O matrix (output).
    void * __restrict__ o_ptr;
    void * __restrict__ oaccum_ptr;

    // The stride between rows of O.
    index_t o_batch_stride;
...
@@ -64,6 +65,7 @@ struct Flash_fwd_params : public Qkv_params {

    // The pointer to the softmax sum.
    void * __restrict__ softmax_lse_ptr;
    void * __restrict__ softmax_lseaccum_ptr;

    // The dimensions.
    int b, seqlen_q, seqlen_k, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded;
...
@@ -96,6 +98,8 @@ struct Flash_fwd_params : public Qkv_params {

    bool is_bf16;
    bool is_causal;

    int num_splits;  // For split-KV version
};

////////////////////////////////////////////////////////////////////////////////////////////////////
...
@@ -140,5 +144,6 @@ struct Flash_bwd_params : public Flash_fwd_params {

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T, int Headdim> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
template<typename T, int Headdim> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);

template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream, const bool configure);
csrc/flash_attn/src/flash_bwd_launch_template.h

...
@@ -64,7 +64,7 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream,

                BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
                    auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMNConst, IsEvenKConst>;
                    // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMNConst, true>;
-                   if (smem_size_dq_dk_dv >= 48 * 1024) {
+                   if constexpr (smem_size_dq_dk_dv >= 48 * 1024) {
                        C10_CUDA_CHECK(cudaFuncSetAttribute(
                            kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
                    }
...
@@ -75,7 +75,7 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream,

    });
    auto kernel_dq = &flash_bwd_convert_dq_kernel<Kernel_traits>;
-   if (Kernel_traits::kSmemdQSize >= 48 * 1024) {
+   if constexpr (Kernel_traits::kSmemdQSize >= 48 * 1024) {
        C10_CUDA_CHECK(cudaFuncSetAttribute(
            kernel_dq, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdQSize));
    }
...
@@ -103,7 +103,7 @@ void run_flash_bwd_seqq_parallel(Flash_bwd_params &params, cudaStream_t stream,

            BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
                auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenNConst, IsEvenKConst>;
                // auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel<Kernel_traits, false, false, IsEvenNConst, IsEvenKConst>;
-               if (smem_size_dq_dk_dv >= 48 * 1024) {
+               if constexpr (smem_size_dq_dk_dv >= 48 * 1024) {
                    C10_CUDA_CHECK(cudaFuncSetAttribute(
                        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
                }
...
@@ -114,7 +114,7 @@ void run_flash_bwd_seqq_parallel(Flash_bwd_params &params, cudaStream_t stream,

    });
    auto kernel_dkv = &flash_bwd_convert_dkv_kernel<Kernel_traits>;
-   if (Kernel_traits::kSmemKVSize >= 48 * 1024) {
+   if constexpr (Kernel_traits::kSmemKVSize >= 48 * 1024) {
        C10_CUDA_CHECK(cudaFuncSetAttribute(
            kernel_dkv, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemKVSize));
    }
...
@@ -147,7 +147,7 @@ void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream, const bool con

    // BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
    //     // auto kernel = &flash_bwd_dq_dk_dv_loop_kernel<Kernel_traits, Is_dropout, IsCausalConst>;
    //     auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMConst, IsEvenKConst>;
-   //     if (smem_size_dq_dk_dv >= 48 * 1024) {
+   //     if constexpr (smem_size_dq_dk_dv >= 48 * 1024) {
    //         C10_CUDA_CHECK(cudaFuncSetAttribute(
    //             kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
    //     }
...
@@ -159,7 +159,7 @@ void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream, const bool con

    // });
    // auto kernel_dq = &flash_bwd_convert_dq_kernel<Kernel_traits>;
-   // if (Kernel_traits::kSmemdQSize >= 48 * 1024) {
+   // if constexpr (Kernel_traits::kSmemdQSize >= 48 * 1024) {
    //     C10_CUDA_CHECK(cudaFuncSetAttribute(
    //         kernel_dq, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdQSize));
    //     }
...
csrc/flash_attn/src/flash_fwd_kernel.h

(This diff is collapsed in this view.)
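The bulk of the commit (+569 lines) is in this collapsed file: it adds flash::compute_attn_splitkv, which lets each split process one slice of the KV sequence and write a partial output plus its log-sum-exp into the oaccum / softmax_lseaccum buffers allocated above, and flash::combine_attn_seqk_parallel, which merges those partials. A minimal scalar host-side sketch of the standard log-sum-exp merge (an illustration only, not the kernel code):

// Hypothetical scalar sketch of merging per-split partial attention results;
// the real kernel does this in parallel over rows and head dimensions.
#include <algorithm>
#include <cmath>
#include <vector>

float combine_splits(const std::vector<float>& lse,      // [num_splits] per-split log-sum-exp
                     const std::vector<float>& o_accum,  // [num_splits] per-split partial output (one scalar here)
                     float* lse_out) {
    // Global log-sum-exp over splits, in the usual numerically stable form.
    float lse_max = -INFINITY;
    for (float x : lse) lse_max = std::max(lse_max, x);
    float sum = 0.f;
    for (float x : lse) sum += std::exp(x - lse_max);
    const float lse_total = lse_max + std::log(sum);
    // Re-weight each split's output by its share of the global softmax normalizer.
    float out = 0.f;
    for (size_t i = 0; i < lse.size(); ++i) out += std::exp(lse[i] - lse_total) * o_accum[i];
    *lse_out = lse_total;
    return out;
}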
csrc/flash_attn/src/flash_fwd_launch_template.h

...
@@ -15,6 +15,17 @@ __global__ void flash_fwd_kernel(Flash_fwd_params params) {

    flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K, Return_softmax>(params);
}

template<typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K>
__global__ void flash_fwd_splitkv_kernel(Flash_fwd_params params) {
    flash::compute_attn_splitkv<Kernel_traits, Is_causal, Is_even_MN, Is_even_K>(params);
}

template<typename Kernel_traits, int Log_max_splits, bool Is_even_K>
__global__ void flash_fwd_splitkv_combine_kernel(Flash_fwd_params params) {
    static_assert(Log_max_splits >= 1);
    flash::combine_attn_seqk_parallel<Kernel_traits, Log_max_splits, Is_even_K>(params);
}

template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr size_t smem_size = Kernel_traits::kSmemSize;
...
@@ -35,13 +46,13 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {

                // Will only return softmax if dropout, to reduce compilation time.
                auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenMNConst, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
                // auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenMNConst, true, ReturnSoftmaxConst && Is_dropout>;
-               if (smem_size >= 48 * 1024) {
+               if constexpr (smem_size >= 48 * 1024) {
                    C10_CUDA_CHECK(cudaFuncSetAttribute(
                        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
                }
-               int ctas_per_sm;
-               cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
-                   &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
+               // int ctas_per_sm;
+               // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+               //     &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
                // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
                kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
                C10_CUDA_KERNEL_LAUNCH_CHECK();
...
@@ -50,6 +61,65 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {

        });
    });
}

template<typename Kernel_traits>
void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr size_t smem_size = Kernel_traits::kSmemSize;
    const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
    dim3 grid(num_m_block, params.num_splits, params.b * params.h);
    const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr
        && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;
    const bool is_even_K = params.d == Kernel_traits::kHeadDim;
    // TODO: do we want to guarantee that seqlen_q <= seqlen_k? That would simplify the kernel a bit.
    BOOL_SWITCH(params.is_causal, Is_causal, [&] {
        BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
            BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
                auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, IsEvenMNConst, IsEvenKConst>;
                // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
                if constexpr (smem_size >= 48 * 1024) {
                    C10_CUDA_CHECK(cudaFuncSetAttribute(
                        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
                }
                kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
                C10_CUDA_KERNEL_LAUNCH_CHECK();
            });
        });
    });
    dim3 grid_combine((params.b * params.h * params.seqlen_q + 16 - 1) / 16);
    BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
        if (params.num_splits <= 2) {
            flash_fwd_splitkv_combine_kernel<Kernel_traits, 1, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
        } else if (params.num_splits <= 4) {
            flash_fwd_splitkv_combine_kernel<Kernel_traits, 2, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
        } else if (params.num_splits <= 8) {
            flash_fwd_splitkv_combine_kernel<Kernel_traits, 3, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
        } else if (params.num_splits <= 16) {
            flash_fwd_splitkv_combine_kernel<Kernel_traits, 4, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
        } else if (params.num_splits <= 32) {
            flash_fwd_splitkv_combine_kernel<Kernel_traits, 5, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
        } else if (params.num_splits <= 64) {
            flash_fwd_splitkv_combine_kernel<Kernel_traits, 6, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
        // } else if (params.num_splits <= 128) {
        //     flash_fwd_splitkv_combine_kernel<Kernel_traits, 7, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
        }
        C10_CUDA_KERNEL_LAUNCH_CHECK();
    });
}
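The else-if ladder above just turns the runtime params.num_splits into a compile-time upper bound for the combine kernel: Log_max_splits is the smallest power-of-two exponent covering num_splits, with a floor of 1. A hypothetical helper (not in the commit) spelling out that mapping:

// Hypothetical equivalent of the dispatch ladder above (illustration only).
#include <algorithm>

constexpr int log_max_splits(int num_splits) {
    int log2_ceil = 0;
    while ((1 << log2_ceil) < num_splits) { ++log2_ceil; }
    return std::max(log2_ceil, 1);  // the ladder starts at Log_max_splits = 1 (num_splits <= 2)
}
static_assert(log_max_splits(2) == 1, "2 splits fit in Log_max_splits = 1");
static_assert(log_max_splits(5) == 3, "5 splits need Log_max_splits = 3");
static_assert(log_max_splits(64) == 6, "64 splits need Log_max_splits = 6");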
template<typename T, int Headdim>
void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream) {
    auto dprops = at::cuda::getCurrentDeviceProperties();
    bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
    constexpr int kBlockM = 64;  // Fixed for all head dimensions
    if (!is_sm8x) {  // A100, H100
        // TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256,
        // and for headdim 192 with block size 64 x 128.
        constexpr int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 160 ? 128 : 64);
        run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>>(params, stream);
    } else {  // Only 99KB of smem, so we have to set kBlockN smaller for Headdim 160 and above
        constexpr int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64);
        run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>>(params, stream);
    }
}
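Spelled out, the two branches above pick the KV tile width per head dimension as follows: on A100/H100, kBlockN is 256 for head dims up to 64, 128 up to 160, and 64 beyond; on the other sm8x parts (with only ~99 KB of shared memory) the 128-wide tile is already capped at head dim 128. A hypothetical compile-time restatement (not in the commit):

// Hypothetical restatement of the kBlockN choice above (illustration only).
constexpr int splitkv_block_n(int headdim, bool is_sm86_or_sm89) {
    return !is_sm86_or_sm89
        ? (headdim <= 64 ? 256 : (headdim <= 160 ? 128 : 64))   // A100 / H100
        : (headdim <= 64 ? 256 : (headdim <= 128 ? 128 : 64));  // ~99 KB smem parts
}
static_assert(splitkv_block_n(128, false) == 128, "");
static_assert(splitkv_block_n(160, true)  == 64,  "");
static_assert(splitkv_block_n(256, false) == 64,  "");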
template<typename T>
void run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr int Headdim = 32;
...
csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu (new file, mode 100644)

// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 128>(Flash_fwd_params &params, cudaStream_t stream);

The remaining 14 new files (all mode 100644) follow the same auto-generated pattern and differ only in the instantiated type and head dimension:

csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu
template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 128>(Flash_fwd_params &params, cudaStream_t stream);

csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu
template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 160>(Flash_fwd_params &params, cudaStream_t stream);

csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu
template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 160>(Flash_fwd_params &params, cudaStream_t stream);

csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu
template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 192>(Flash_fwd_params &params, cudaStream_t stream);

csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu
template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 192>(Flash_fwd_params &params, cudaStream_t stream);

csrc/flash_attn/src/flash_fwd_split_hdim224_bf16_sm80.cu
template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 224>(Flash_fwd_params &params, cudaStream_t stream);

csrc/flash_attn/src/flash_fwd_split_hdim224_fp16_sm80.cu
template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 224>(Flash_fwd_params &params, cudaStream_t stream);

csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu
template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 256>(Flash_fwd_params &params, cudaStream_t stream);

csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu
template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 256>(Flash_fwd_params &params, cudaStream_t stream);

csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu
template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 32>(Flash_fwd_params &params, cudaStream_t stream);

csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu
template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 32>(Flash_fwd_params &params, cudaStream_t stream);

csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu
template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 64>(Flash_fwd_params &params, cudaStream_t stream);

csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu
template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 64>(Flash_fwd_params &params, cudaStream_t stream);

csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu
template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 96>(Flash_fwd_params &params, cudaStream_t stream);