zhangdong1 / Block-Sparse-Attention · Commits

Commit 4f83cf8f, authored Oct 10, 2024 by Junxian

[release] v0.0.1
Showing 20 changed files with 1245 additions and 0 deletions (+1245, -0)
csrc/block_sparse_attn/src/flash_fwd_launch_template.h            +471  -0
csrc/block_sparse_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu     +7  -0
csrc/block_sparse_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu     +7  -0
csrc/block_sparse_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu     +7  -0
csrc/block_sparse_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu     +7  -0
csrc/block_sparse_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu     +7  -0
csrc/block_sparse_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu     +7  -0
csrc/block_sparse_attn/src/flash_fwd_split_hdim224_bf16_sm80.cu     +7  -0
csrc/block_sparse_attn/src/flash_fwd_split_hdim224_fp16_sm80.cu     +7  -0
csrc/block_sparse_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu     +7  -0
csrc/block_sparse_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu     +7  -0
csrc/block_sparse_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu      +7  -0
csrc/block_sparse_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu      +7  -0
csrc/block_sparse_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu      +7  -0
csrc/block_sparse_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu      +7  -0
csrc/block_sparse_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu      +7  -0
csrc/block_sparse_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu      +7  -0
csrc/block_sparse_attn/src/generate_kernels.py                    +106  -0
csrc/block_sparse_attn/src/kernel_traits.h                        +397  -0
csrc/block_sparse_attn/src/kernel_traits_sm90.h                   +159  -0
csrc/block_sparse_attn/src/flash_fwd_launch_template.h (new file, mode 100644)
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/
/******************************************************************************
 * Adapted by Junxian Guo from https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/src/flash_fwd_launch_template.h
 ******************************************************************************/

#pragma once

#include <ATen/cuda/CUDAContext.h>

#include "static_switch.h"
#include "flash.h"
#include "flash_fwd_kernel.h"
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax>
__global__ void flash_fwd_kernel(Flash_fwd_params params) {
    static_assert(!(Is_causal && Is_local));  // If Is_local is true, Is_causal should be false
    flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Return_softmax>(params);
}

template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax, bool Is_exact_streaming>
__global__ void flash_fwd_block_kernel(Flash_fwd_params params) {
    static_assert(!(Is_causal && Is_local));  // If Is_local is true, Is_causal should be false
    flash::compute_block_attn<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Return_softmax, Is_exact_streaming>(params);
}

template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV>
__global__ void flash_fwd_splitkv_kernel(Flash_fwd_params params) {
    flash::compute_attn_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Split, Append_KV>(params);
}

template<typename Kernel_traits, int kBlockM, int Log_max_splits, bool Is_even_K>
__global__ void flash_fwd_splitkv_combine_kernel(Flash_fwd_params params) {
    static_assert(Log_max_splits >= 1);
    flash::combine_attn_seqk_parallel<Kernel_traits, kBlockM, Log_max_splits, Is_even_K>(params);
}
template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr size_t smem_size = Kernel_traits::kSmemSize;

    // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH.
    // https://github.com/kokkos/kokkos-kernels/issues/349
    // https://github.com/HazyResearch/flash-attention/issues/21
    const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
    dim3 grid(num_m_block, params.b, params.h);
    const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;
    const bool is_even_K = params.d == Kernel_traits::kHeadDim;
    const bool return_softmax = params.p_ptr != nullptr;
    BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
        BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
            BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
                BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
                    BOOL_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
                        // Will only return softmax if dropout, to reduce compilation time.
                        // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
                        // If return_softmax, set IsEvenMNConst to false to reduce number of templates
                        // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
                        // If Is_local, set Is_causal to false
                        auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
                        // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
                        // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
                        if (smem_size >= 48 * 1024) {
                            C10_CUDA_CHECK(cudaFuncSetAttribute(
                                kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
                        }
                        // int ctas_per_sm;
                        // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
                        //     &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
                        kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
                        C10_CUDA_KERNEL_LAUNCH_CHECK();
                    });
                });
            });
        });
    });
}
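// Note (illustration only, not part of this commit): the nested BOOL_SWITCH calls above come
// from "static_switch.h", which is assumed to provide the usual flash-attention-style macro.
// A minimal sketch of that assumed definition: it turns a runtime bool into a compile-time
// constant visible inside the lambda, so each combination instantiates a separate kernel.
// #define BOOL_SWITCH(COND, CONST_NAME, ...)            \
//     [&] {                                             \
//         if (COND) {                                   \
//             constexpr static bool CONST_NAME = true;  \
//             return __VA_ARGS__();                     \
//         } else {                                      \
//             constexpr static bool CONST_NAME = false; \
//             return __VA_ARGS__();                     \
//         }                                             \
//     }()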
// blocksparse
template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
void run_flash_fwd_block(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr size_t smem_size = Kernel_traits::kSmemSize;

    // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH.
    // https://github.com/kokkos/kokkos-kernels/issues/349
    // https://github.com/HazyResearch/flash-attention/issues/21
    const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
    dim3 grid(num_m_block, params.b, params.h);
    const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;
    const bool is_even_K = params.d == Kernel_traits::kHeadDim;
    const bool return_softmax = params.p_ptr != nullptr;
    BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
        BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
            BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
                BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
                    BOOL_SWITCH(params.is_exact_streaming, Is_exact_streaming, [&] {
                        // Will only return softmax if dropout, to reduce compilation time.
                        // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
                        // If return_softmax, set IsEvenMNConst to false to reduce number of templates
                        // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
                        // If Is_local, set Is_causal to false
                        auto kernel = &flash_fwd_block_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, false, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, ReturnSoftmaxConst && Is_dropout, Is_exact_streaming && Is_causal>;
                        if (smem_size >= 48 * 1024) {
                            C10_CUDA_CHECK(cudaFuncSetAttribute(
                                kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
                        }
                        // int ctas_per_sm;
                        // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
                        //     &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
                        kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
                        C10_CUDA_KERNEL_LAUNCH_CHECK();
                    });
                });
            });
        });
    });
}
template<typename Kernel_traits>
void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
    static_assert(!Kernel_traits::Is_Q_in_regs, "SplitKV implementation does not support Is_Q_in_regs");
    static_assert(!Kernel_traits::Share_Q_K_smem, "SplitKV implementation does not support Share_Q_K_smem");
    constexpr size_t smem_size = Kernel_traits::kSmemSize;
    const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
    dim3 grid(num_m_block, params.num_splits > 1 ? params.num_splits : params.b, params.num_splits > 1 ? params.b * params.h : params.h);
    const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;
    const bool is_even_K = params.d == Kernel_traits::kHeadDim;
    BOOL_SWITCH(params.is_causal, Is_causal, [&] {
        BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
            BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
                BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
                    BOOL_SWITCH(params.num_splits > 1, Split, [&] {
                        BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
                            BOOL_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
                                // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
                                // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
                                // If Is_local, set Is_causal to false
                                auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Split, Append_KV>;
                                // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, true, Split, Append_KV>;
                                // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
                                if (smem_size >= 48 * 1024) {
                                    C10_CUDA_CHECK(cudaFuncSetAttribute(
                                        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
                                }
                                kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
                                C10_CUDA_KERNEL_LAUNCH_CHECK();
                            });
                        });
                    });
                });
            });
        });
    });
    if (params.num_splits > 1) {
        // We want kBlockM to be as small as possible for more parallelism.
        // With 128 threads we can load 512 elements at a time, so if headdim is divisible by 128, kBlockM = 4.
        // If headdim is divisible by 64, then we set kBlockM = 8, etc.
        constexpr static int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 4 : (Kernel_traits::kHeadDim % 64 == 0 ? 8 : 16);
        dim3 grid_combine((params.b * params.h * params.seqlen_q + kBlockM - 1) / kBlockM);
        BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
            if (params.num_splits <= 2) {
                flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 1, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
            } else if (params.num_splits <= 4) {
                flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 2, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
            } else if (params.num_splits <= 8) {
                flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 3, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
            } else if (params.num_splits <= 16) {
                flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 4, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
            } else if (params.num_splits <= 32) {
                flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 5, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
            } else if (params.num_splits <= 64) {
                flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 6, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
            } else if (params.num_splits <= 128) {
                flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 7, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
            }
            C10_CUDA_KERNEL_LAUNCH_CHECK();
        });
    }
}
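// Illustration only (not part of this commit): the if/else ladder above selects
// Log_max_splits = ceil(log2(num_splits)), i.e. the smallest L with (1 << L) >= num_splits,
// which flash_fwd_splitkv_combine_kernel uses to size its per-split reduction.
// A small constexpr check of that mapping:
// constexpr int log_max_splits_for(int num_splits) {
//     int log = 1;
//     while ((1 << log) < num_splits) { ++log; }
//     return log;
// }
// static_assert(log_max_splits_for(2) == 1);
// static_assert(log_max_splits_for(5) == 3);
// static_assert(log_max_splits_for(128) == 7);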
template<typename T, int Headdim>
void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr static int kBlockM = 64;  // Fixed for all head dimensions
    // TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256,
    // and for headdim 192 with block size 64 x 128.
    // Also for headdim 160 with block size 64 x 128 after the rotary addition.
    constexpr static int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64);
    run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>>(params, stream);
}
template<typename T>
void run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 32;
    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
        });
    });
}

template<typename T>
void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 64;
    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
            if constexpr (!Is_dropout) {
                // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower
                // Using block size (64 x 256) is 27% slower for seqlen=2k
                // Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling
                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
            } else {
                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
            }
        });
    });
}

template<typename T>
void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 96;
    auto dprops = at::cuda::getCurrentDeviceProperties();
    bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
            // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
            if (is_sm8x) {
                if constexpr (!Is_causal) {
                    run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
                } else {
                    run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
                }
            } else {
                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
            }
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
            // These two are always slower
            // run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 128, 4, true, T>>(params, stream);
            // run_flash_fwd<Flash_fwd_kernel_traits<96, 64, 128, 4, true, T>>(params, stream);
        });
    });
}

template<typename T>
void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 128;
    auto dprops = at::cuda::getCurrentDeviceProperties();
    bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
            if constexpr (!Is_dropout) {
                // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
                // and 128 x 32 (48 KB smem) is the fastest for non-causal since we get 2 CTAs per SM.
                if (is_sm8x) {
                    if constexpr (!Is_causal) {
                        run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
                    } else {
                        run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
                    }
                } else {
                    run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
                }
                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
                // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k
                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
                // 1st ones are good for H100, A100
                // 2nd one is good for A6000 bc we get slightly better occupancy
            } else {
                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
            }
        });
    });
}

template<typename T>
void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 160;
    auto dprops = at::cuda::getCurrentDeviceProperties();
    bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
            // For A100, H100, 128 x 32 is the fastest.
            // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
            // and 128 x 64 with 8 warps is the fastest for non-causal.
            if (is_sm8x) {
                if constexpr (!Is_causal) {
                    run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
                } else {
                    run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
                }
            } else {
                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
            }
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, true, T>, Is_dropout, Is_causal>(params, stream);
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, T>>(params, stream);
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, T>>(params, stream);
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, T>>(params, stream);
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, T>>(params, stream);
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, T>>(params, stream);
        });
    });
}

template<typename T>
void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 192;
    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
            if constexpr (!Is_dropout) {
                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
            } else {
                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
            }
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, T>>(params, stream);
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, T>>(params, stream);
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, T>>(params, stream);
        });
    });
}

template<typename T>
void run_mha_fwd_hdim224(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 224;
    int device;
    cudaGetDevice(&device);
    int max_smem_per_block;
    cudaError status_ = cudaDeviceGetAttribute(
        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
    if (status_ != cudaSuccess) {
        C10_CUDA_CHECK(status_);
    }
    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
            if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64)) {  // 112 KB
                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
            } else {
                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
            }
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
            // We can't do 128 x 32 with 8 warps because with headdim 224, kBlockKSmem = 32.
            // If we have N = 32, there are only 1024 elements to load at once, where each load
            // is 8 elements. This means we can only use 128 threads and not 256 threads.
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
        });
    });
}
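// Worked example of the shared-memory check above (illustration only, not part of this commit):
// with Headdim = 224, kBlockM = 128 and kBlockN = 64, Q occupies 128 * 224 and the double-buffered
// K/V tiles occupy 2 * 64 * 224 two-byte elements, so the threshold evaluates to
// 2 * 224 * (128 + 2 * 64) = 114688 bytes = 112 KB, matching the "112 KB" comment.
// static_assert(2 * 224 * (128 + 2 * 64) == 114688);
// static_assert(114688 == 112 * 1024);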
template<typename T>
void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 256;
    int device;
    cudaGetDevice(&device);
    int max_smem_per_sm, max_smem_per_block;
    cudaError status_ = cudaDeviceGetAttribute(
        &max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, device);
    status_ = cudaDeviceGetAttribute(
        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
    if (status_ != cudaSuccess) {
        C10_CUDA_CHECK(status_);
    }
    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
            // For A100, we want to run with 128 x 64 (128KB smem).
            // For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM.
            if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) {
                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
            } else {
                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
            }
            // 64 KB
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
            // 96 KB
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
        });
    });
}
template<typename T>
void run_mha_fwd_block_hdim32(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 32;
    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
            run_flash_fwd_block<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
        });
    });
}

template<typename T>
void run_mha_fwd_block_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 64;
    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
            if constexpr (!Is_dropout) {
                // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower
                // Using block size (64 x 256) is 27% slower for seqlen=2k
                // Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling
                run_flash_fwd_block<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
            } else {
                run_flash_fwd_block<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
            }
        });
    });
}

template<typename T>
void run_mha_fwd_block_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 128;
    auto dprops = at::cuda::getCurrentDeviceProperties();
    bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
            if constexpr (!Is_dropout) {
                // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
                // and 128 x 32 (48 KB smem) is the fastest for non-causal since we get 2 CTAs per SM.
                if (is_sm8x) {
                    if constexpr (!Is_causal) {
                        run_flash_fwd_block<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
                    } else {
                        run_flash_fwd_block<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
                    }
                } else {
                    run_flash_fwd_block<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
                }
                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
                // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k
                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
                // 1st ones are good for H100, A100
                // 2nd one is good for A6000 bc we get slightly better occupancy
            } else {
                run_flash_fwd_block<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
            }
        });
    });
}
csrc/block_sparse_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu (new file, mode 100644)

// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 128>(Flash_fwd_params &params, cudaStream_t stream);

csrc/block_sparse_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu (new file, mode 100644)

// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 128>(Flash_fwd_params &params, cudaStream_t stream);

csrc/block_sparse_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu (new file, mode 100644)

// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 160>(Flash_fwd_params &params, cudaStream_t stream);

csrc/block_sparse_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu (new file, mode 100644)

// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 160>(Flash_fwd_params &params, cudaStream_t stream);

csrc/block_sparse_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu (new file, mode 100644)

// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 192>(Flash_fwd_params &params, cudaStream_t stream);

csrc/block_sparse_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu (new file, mode 100644)

// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 192>(Flash_fwd_params &params, cudaStream_t stream);

csrc/block_sparse_attn/src/flash_fwd_split_hdim224_bf16_sm80.cu (new file, mode 100644)

// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 224>(Flash_fwd_params &params, cudaStream_t stream);

csrc/block_sparse_attn/src/flash_fwd_split_hdim224_fp16_sm80.cu (new file, mode 100644)

// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 224>(Flash_fwd_params &params, cudaStream_t stream);

csrc/block_sparse_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu (new file, mode 100644)

// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 256>(Flash_fwd_params &params, cudaStream_t stream);

csrc/block_sparse_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu (new file, mode 100644)

// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 256>(Flash_fwd_params &params, cudaStream_t stream);

csrc/block_sparse_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu (new file, mode 100644)

// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 32>(Flash_fwd_params &params, cudaStream_t stream);

csrc/block_sparse_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu (new file, mode 100644)

// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 32>(Flash_fwd_params &params, cudaStream_t stream);

csrc/block_sparse_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu (new file, mode 100644)

// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 64>(Flash_fwd_params &params, cudaStream_t stream);

csrc/block_sparse_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu (new file, mode 100644)

// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 64>(Flash_fwd_params &params, cudaStream_t stream);

csrc/block_sparse_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu (new file, mode 100644)

// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 96>(Flash_fwd_params &params, cudaStream_t stream);

csrc/block_sparse_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu (new file, mode 100644)

// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 96>(Flash_fwd_params &params, cudaStream_t stream);
csrc/block_sparse_attn/src/generate_kernels.py (new file, mode 100644)

# Copied from Driss Guessous's PR in PyTorch: https://github.com/pytorch/pytorch/pull/105602
# This file is run to generate the kernel instantiations for the flash_attn kernels
# They are written to several files in order to speed up compilation

import argparse
import itertools
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional

DTYPE_MAP = {
    "fp16": "cutlass::half_t",
    "bf16": "cutlass::bfloat16_t",
}

SM = [80]  # Sm80 kernels support up to
HEAD_DIMENSIONS = [32, 64, 96, 128, 160, 192, 224, 256]
KERNEL_IMPL_TEMPLATE_FWD = """#include "flash_fwd_launch_template.h"

template<>
void run_mha_fwd_<{DTYPE}, {HEAD_DIM}>(Flash_fwd_params &params, cudaStream_t stream) {{
    run_mha_fwd_hdim{HEAD_DIM}<{DTYPE}>(params, stream);
}}
"""

KERNEL_IMPL_TEMPLATE_FWD_SPLIT = """#include "flash_fwd_launch_template.h"

template void run_mha_fwd_splitkv_dispatch<{DTYPE}, {HEAD_DIM}>(Flash_fwd_params &params, cudaStream_t stream);
"""

KERNEL_IMPL_TEMPLATE_BWD = """#include "flash_bwd_launch_template.h"

template<>
void run_mha_bwd_<{DTYPE}, {HEAD_DIM}>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {{
    run_mha_bwd_hdim{HEAD_DIM}<{DTYPE}>(params, stream, configure);
}}
"""


@dataclass
class Kernel:
    sm: int
    dtype: str
    head_dim: int
    direction: str

    @property
    def template(self) -> str:
        if self.direction == "fwd":
            return KERNEL_IMPL_TEMPLATE_FWD.format(
                DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim
            )
        elif self.direction == "bwd":
            return KERNEL_IMPL_TEMPLATE_BWD.format(
                DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim
            )
        else:
            return KERNEL_IMPL_TEMPLATE_FWD_SPLIT.format(
                DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim
            )

    @property
    def filename(self) -> str:
        return f"flash_{self.direction}_hdim{self.head_dim}_{self.dtype}_sm{self.sm}.cu"


def get_all_kernels() -> List[Kernel]:
    for dtype, head_dim, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, SM):
        for direction in ["fwd", "bwd", "fwd_split"]:
            yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, direction=direction)


def write_kernel(kernel: Kernel, autogen_dir: Path) -> None:
    prelude = """// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
\n"""
    (autogen_dir / kernel.filename).write_text(prelude + kernel.template)


def main(output_dir: Optional[str]) -> None:
    if output_dir is None:
        output_dir = Path(__file__).parent
    else:
        output_dir = Path(output_dir)

    for kernel in get_all_kernels():
        write_kernel(kernel, output_dir)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog="generate_kernels",
        description="Generate the flash_attention kernels template instantiations",
    )
    # Set an optional output directory
    parser.add_argument(
        "-o",
        "--output_dir",
        required=False,
        help="Where to generate the kernels "
        " will default to the current directory ",
    )
    args = parser.parse_args()
    main(args.output_dir)
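For reference, only the fwd_split instantiations generated by this script are part of this commit (the sixteen flash_fwd_split_hdim*_sm80.cu files above). A sketch of what the script's "fwd" direction would emit, filling KERNEL_IMPL_TEMPLATE_FWD with dtype fp16 and head dim 64 (the run_mha_fwd_ declaration is assumed to live in flash.h, as in upstream flash-attention):

#include "flash_fwd_launch_template.h"

template<>
void run_mha_fwd_<cutlass::half_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
    run_mha_fwd_hdim64<cutlass::half_t>(params, stream);
}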
csrc/block_sparse_attn/src/kernel_traits.h (new file, mode 100644)

/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

#pragma once

#include "cute/algorithm/copy.hpp"

#include "cutlass/cutlass.h"
#include "cutlass/layout/layout.h"
#include <cutlass/numeric_types.h>

using namespace cute;
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename elem_type=cutlass::half_t>
struct Flash_kernel_traits {

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    using Element = elem_type;
    static constexpr bool Has_cp_async = true;
#else
    using Element = cutlass::half_t;
    static constexpr bool Has_cp_async = false;
#endif

    using ElementAccum = float;
    using index_t = int64_t;

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    using MMA_Atom_Arch = std::conditional_t<
        std::is_same_v<elem_type, cutlass::half_t>,
        MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,
        MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>
    >;
    using ValLayoutMNK = Layout<Shape<_1, _2, _1>>;
#else
    using MMA_Atom_Arch = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
    using ValLayoutMNK = Layout<Shape<_1, _2, _2>>;
#endif

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
    using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, elem_type>;
    using SmemCopyAtomTransposed = Copy_Atom<SM75_U16x8_LDSM_T, elem_type>;
#else
    using SmemCopyAtom = Copy_Atom<DefaultCopy, elem_type>;
    using SmemCopyAtomTransposed = Copy_Atom<DefaultCopy, elem_type>;
#endif
};
// If Share_Q_K_smem is true, that forces Is_Q_in_regs to be true
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, bool Is_Q_in_regs_=false, bool Share_Q_K_smem_=false, typename elem_type=cutlass::half_t,
         typename Base=Flash_kernel_traits<kHeadDim_, kBlockM_, kBlockN_, kNWarps_, elem_type> >
struct Flash_fwd_kernel_traits : public Base {
    using Element = typename Base::Element;
    using ElementAccum = typename Base::ElementAccum;
    using index_t = typename Base::index_t;
    static constexpr bool Has_cp_async = Base::Has_cp_async;
    using SmemCopyAtom = typename Base::SmemCopyAtom;
    using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed;

    static constexpr bool Share_Q_K_smem = Share_Q_K_smem_;
    static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem;

    // The number of threads.
    static constexpr int kNWarps = kNWarps_;  // number of warps in a thread block
    static constexpr int kNThreads = kNWarps * 32;

    static constexpr int kBlockM = kBlockM_;
    static constexpr int kBlockN = kBlockN_;
    static constexpr int kHeadDim = kHeadDim_;
    static_assert(kHeadDim % 32 == 0);
    static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
    static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32);
    static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;

    using TiledMma = TiledMMA<
        typename Base::MMA_Atom_Arch,
        Layout<Shape<Int<kNWarps>, _1, _1>>,  // 4x1x1 or 8x1x1 thread group
        typename Base::ValLayoutMNK>;  // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM

    using SmemLayoutAtomQ = decltype(
        composition(Swizzle<kSwizzle, 3, 3>{},
                    // This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128
                    Layout<Shape<_8, Int<kBlockKSmem>>,
                           Stride<Int<kBlockKSmem>, _1>>{}));
    using SmemLayoutQ = decltype(tile_to_shape(
        SmemLayoutAtomQ{},
        Shape<Int<kBlockM>, Int<kHeadDim>>{}));

    using SmemLayoutKV = decltype(tile_to_shape(
        SmemLayoutAtomQ{},
        Shape<Int<kBlockN>, Int<kHeadDim>>{}));

    // This has to be kBlockN and not 8, otherwise we get wrong results for d=128
    using SmemLayoutAtomVtransposedNoSwizzle = Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
                                                      Stride<_1, Int<kBlockKSmem>>>;
    using SmemLayoutAtomVtransposed = decltype(
        composition(Swizzle<kSwizzle, 3, 3>{}, SmemLayoutAtomVtransposedNoSwizzle{}));
    using SmemLayoutVtransposed = decltype(tile_to_shape(
        SmemLayoutAtomVtransposed{},
        Shape<Int<kHeadDim>, Int<kBlockN>>{}));
    // Maybe the VtransposeNoSwizzle just needs to have the right shape
    // And the strides don't matter?
    using SmemLayoutVtransposedNoSwizzle = decltype(tile_to_shape(
        SmemLayoutAtomVtransposedNoSwizzle{},
        Shape<Int<kHeadDim>, Int<kBlockN>>{}));
    // using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn());

    using SmemLayoutAtomO = decltype(
        composition(Swizzle<kSwizzle, 3, 3>{},
                    Layout<Shape<Int<8>, Int<kBlockKSmem>>,
                           Stride<Int<kBlockKSmem>, _1>>{}));
    using SmemLayoutO = decltype(tile_to_shape(
        SmemLayoutAtomO{},
        Shape<Int<kBlockM>, Int<kHeadDim>>{}));
    using SmemCopyAtomO = Copy_Atom<DefaultCopy, Element>;
    using SmemCopyAtomOaccum = Copy_Atom<DefaultCopy, ElementAccum>;

    static constexpr int kSmemQCount = size(SmemLayoutQ{});
    static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2;
    static constexpr int kSmemQSize = kSmemQCount * sizeof(Element);
    static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element);
    static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize;

    static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
    static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
    // Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts.
    // For example, for d=128, smem is split into 2 "pages", each page takes care of columns
    // 0-63 and 64-127. If we have 16 threads per row for gmem read, when we write to smem,
    // thread 0 - 7 will write to the first page and thread 8 - 15 will write to the second page,
    // to the same banks.
    static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
    static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
    using GmemLayoutAtom = Layout<Shape<Int<kNThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
                                  Stride<Int<kGmemThreadsPerRow>, _1>>;

    // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading
    // from the same address by the same threadblock. This is slightly faster.
    using Gmem_copy_struct = std::conditional_t<
        Has_cp_async,
        SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
        DefaultCopy
    >;
    using GmemTiledCopyQKV = decltype(
        make_tiled_copy(Copy_Atom<Gmem_copy_struct, Element>{},
                        GmemLayoutAtom{},
                        Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per read
    using GmemTiledCopyO = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
                        GmemLayoutAtom{},
                        Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per store

    static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad;
    static_assert(kNThreads % kGmemThreadsPerRowP == 0, "kNThreads must be a multiple of kGmemThreadsPerRowP");
    using GmemLayoutAtomP = Layout<Shape<Int<kNThreads / kGmemThreadsPerRowP>, Int<kGmemThreadsPerRowP>>,
                                   Stride<Int<kGmemThreadsPerRowP>, _1>>;

    using GmemTiledCopyP = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
                        GmemLayoutAtomP{},
                        Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per store

    using GmemLayoutAtomOaccum = std::conditional_t<
        kBlockKSmem == 32,
        Layout<Shape<_16, _8>,  // Thread layout, 8 threads per row
               Stride<_8, _1>>,
        Layout<Shape<_8, _16>,  // Thread layout, 16 threads per row
               Stride<_16, _1>>
    >;
    using GmemTiledCopyOaccum = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
                        GmemLayoutAtomOaccum{},
                        Layout<Shape<_1, _4>>{}));  // Val layout, 4 vals per store
    using GmemLayoutAtomRotcossin = GmemLayoutAtom;
    using GmemTiledCopyRotcossin = decltype(
        make_tiled_copy(Copy_Atom<UniversalCopy<uint64_t>, Element>{},
                        GmemLayoutAtomRotcossin{},
                        Layout<Shape<_1, _4>>{}));  // Val layout, 4 vals per load
    using GmemTiledCopyRotcossinCont = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
                        GmemLayoutAtomRotcossin{},
                        Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per load
};
// Is_V_in_regs is an option to reduce smem usage, but will increase register pressue.
// No_double_buffer is another option to reduce smem usage, but will slow things down.
template
<
int
kHeadDim_
,
int
kBlockM_
,
int
kBlockN_
,
int
kNWarps_
,
int
AtomLayoutMSdP_
=
1
,
int
AtomLayoutNdKV
=
2
,
int
AtomLayoutMdQ
=
2
,
bool
Is_V_in_regs_
=
false
,
bool
No_double_buffer_
=
false
,
typename
elem_type
=
cutlass
::
half_t
,
typename
Base
=
Flash_kernel_traits
<
kHeadDim_
,
kBlockM_
,
kBlockN_
,
kNWarps_
,
elem_type
>
>
struct
Flash_bwd_kernel_traits
:
public
Base
{
using
Element
=
typename
Base
::
Element
;
using
ElementAccum
=
typename
Base
::
ElementAccum
;
using
index_t
=
typename
Base
::
index_t
;
static
constexpr
bool
Has_cp_async
=
Base
::
Has_cp_async
;
using
SmemCopyAtom
=
typename
Base
::
SmemCopyAtom
;
using
SmemCopyAtomTransposed
=
typename
Base
::
SmemCopyAtomTransposed
;
static
constexpr
bool
Is_V_in_regs
=
Is_V_in_regs_
;
static
constexpr
bool
No_double_buffer
=
No_double_buffer_
;
// The number of threads.
static
constexpr
int
kNWarps
=
kNWarps_
;
static
constexpr
int
kNThreads
=
kNWarps
*
32
;
static
constexpr
int
kBlockM
=
kBlockM_
;
static
constexpr
int
kBlockN
=
kBlockN_
;
static
constexpr
int
kHeadDim
=
kHeadDim_
;
static_assert
(
kHeadDim
%
32
==
0
);
static
constexpr
int
kBlockKSmem
=
kHeadDim
%
64
==
0
?
64
:
32
;
static
constexpr
int
kBlockKGmem
=
kHeadDim
%
128
==
0
?
128
:
(
kHeadDim
%
64
==
0
?
64
:
32
);
static
constexpr
int
kSwizzle
=
kBlockKSmem
==
32
?
2
:
3
;
static
constexpr
int
AtomLayoutMSdP
=
AtomLayoutMSdP_
;
static_assert
(
kNWarps
%
AtomLayoutMSdP
==
0
);
static_assert
(
kNWarps
%
AtomLayoutNdKV
==
0
);
static_assert
(
kNWarps
%
AtomLayoutMdQ
==
0
);
using
TiledMmaSdP
=
TiledMMA
<
typename
Base
::
MMA_Atom_Arch
,
Layout
<
Shape
<
Int
<
AtomLayoutMSdP
>
,
Int
<
kNWarps
/
AtomLayoutMSdP
>
,
_1
>>
,
typename
Base
::
ValLayoutMNK
>
;
// 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
using
TiledMmadKV
=
TiledMMA
<
typename
Base
::
MMA_Atom_Arch
,
Layout
<
Shape
<
Int
<
AtomLayoutNdKV
>
,
Int
<
kNWarps
/
AtomLayoutNdKV
>
,
_1
>>
,
typename
Base
::
ValLayoutMNK
>
;
// 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
using
TiledMmadQ
=
TiledMMA
<
typename
Base
::
MMA_Atom_Arch
,
Layout
<
Shape
<
Int
<
AtomLayoutMdQ
>
,
Int
<
kNWarps
/
AtomLayoutMdQ
>
,
_1
>>
,
// 2x4x1 or 4x2x1 thread group
typename
Base
::
ValLayoutMNK
>
;
// 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
using
SmemLayoutAtomQdO
=
decltype
(
composition
(
Swizzle
<
kSwizzle
,
3
,
3
>
{},
Layout
<
Shape
<
_8
,
Int
<
kBlockKSmem
>>
,
Stride
<
Int
<
kBlockKSmem
>
,
_1
>>
{}));
using
SmemLayoutQdO
=
decltype
(
tile_to_shape
(
SmemLayoutAtomQdO
{},
make_shape
(
Int
<
kBlockM
>
{},
Int
<
kHeadDim
>
{})));
using
SmemLayoutAtomKV
=
decltype
(
composition
(
Swizzle
<
kSwizzle
,
3
,
3
>
{},
Layout
<
Shape
<
Int
<
kBlockM
/
kNWarps
>
,
Int
<
kBlockKSmem
>>
,
Stride
<
Int
<
kBlockKSmem
>
,
_1
>>
{}));
using
SmemLayoutKV
=
decltype
(
tile_to_shape
(
// SmemLayoutAtomQdO{},
SmemLayoutAtomKV
{},
make_shape
(
Int
<
kBlockN
>
{},
Int
<
kHeadDim
>
{})));
using
SmemLayoutAtomKtransposedNoSwizzle
=
Layout
<
Shape
<
Int
<
kBlockKSmem
>
,
Int
<
kBlockN
>>
,
Stride
<
_1
,
Int
<
kBlockKSmem
>>>
;
using
SmemLayoutAtomKtransposed
=
decltype
(
composition
(
Swizzle
<
kSwizzle
,
3
,
3
>
{},
SmemLayoutAtomKtransposedNoSwizzle
{}));
using
SmemLayoutKtransposed
=
decltype
(
tile_to_shape
(
SmemLayoutAtomKtransposed
{},
make_shape
(
Int
<
kHeadDim
>
{},
Int
<
kBlockN
>
{})));
// Maybe the KtransposeNoSwizzle just needs to have the right shape
// And the strides don't matter?
using
SmemLayoutKtransposedNoSwizzle
=
decltype
(
tile_to_shape
(
SmemLayoutAtomKtransposedNoSwizzle
{},
make_shape
(
Int
<
kHeadDim
>
{},
Int
<
kBlockN
>
{})));
// using SmemLayoutKtransposedNoSwizzle = decltype(SmemLayoutKtransposed{}.layout_fn());
// TODO: generalize to other values of kBlockN
// TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2
// static constexpr int kPBlockN = kBlockN;
static_assert
(
kBlockN
>=
64
);
// TD [2023-03-19]: Idk why kPBlockN = 16 and kSwizzlePdS=3 is the fastest.
static
constexpr
int
kPBlockN
=
64
;
static_assert
(
kPBlockN
==
16
||
kPBlockN
==
32
||
kPBlockN
==
64
);
// static constexpr int kSwizzlePdS = kPBlockN == 16 ? 1 : (kPBlockN == 32 ? 2 : 3);
static
constexpr
int
kSwizzlePdS
=
3
;
using
SmemLayoutAtomPdS
=
decltype
(
composition
(
Swizzle
<
kSwizzlePdS
,
3
,
3
>
{},
Layout
<
Shape
<
Int
<
kBlockM
>
,
Int
<
kPBlockN
>>
,
Stride
<
Int
<
kPBlockN
>
,
_1
>>
{}));
using
SmemLayoutPdS
=
decltype
(
tile_to_shape
(
SmemLayoutAtomPdS
{},
make_shape
(
Int
<
kBlockM
>
{},
Int
<
kBlockN
>
{})));
using
SmemLayoutAtomPdStransposedNoSwizzle
=
Layout
<
Shape
<
Int
<
kPBlockN
>
,
Int
<
kBlockM
>>
,
Stride
<
_1
,
Int
<
kPBlockN
>>>
;
using
SmemLayoutAtomPdStransposed
=
decltype
(
composition
(
Swizzle
<
kSwizzlePdS
,
3
,
3
>
{},
SmemLayoutAtomPdStransposedNoSwizzle
{}));
using
SmemLayoutPdStransposed
=
decltype
(
tile_to_shape
(
SmemLayoutAtomPdStransposed
{},
make_shape
(
Int
<
kBlockN
>
{},
Int
<
kBlockM
>
{})));
using
SmemLayoutPdStransposedNoSwizzle
=
decltype
(
tile_to_shape
(
SmemLayoutAtomPdStransposedNoSwizzle
{},
make_shape
(
Int
<
kBlockN
>
{},
Int
<
kBlockM
>
{})));
// using SmemLayoutPdStransposedNoSwizzle = decltype(SmemLayoutPdStransposed{}.layout_fn());
using
SmemCopyAtomPdS
=
Copy_Atom
<
DefaultCopy
,
elem_type
>
;
using
SmemLayoutAtomQdOtransposedNoSwizzle
=
Layout
<
Shape
<
Int
<
kBlockKSmem
>
,
Int
<
kBlockM
>>
,
Stride
<
_1
,
Int
<
kBlockKSmem
>>>
;
using
SmemLayoutAtomQdOtransposed
=
decltype
(
composition
(
Swizzle
<
kSwizzle
,
3
,
3
>
{},
SmemLayoutAtomQdOtransposedNoSwizzle
{}));
using
SmemLayoutQdOtransposed
=
decltype
(
tile_to_shape
(
SmemLayoutAtomQdOtransposed
{},
make_shape
(
Int
<
kHeadDim
>
{},
Int
<
kBlockM
>
{})));
using
SmemLayoutQdOtransposedNoSwizzle
=
decltype
(
tile_to_shape
(
SmemLayoutAtomQdOtransposedNoSwizzle
{},
make_shape
(
Int
<
kHeadDim
>
{},
Int
<
kBlockM
>
{})));
// using SmemLayoutQdOtransposedNoSwizzle = decltype(SmemLayoutQdOtransposed{}.layout_fn());
using
SmemLayoutAtomdKV
=
decltype
(
composition
(
Swizzle
<
kSwizzle
,
3
,
3
>
{},
Layout
<
Shape
<
_8
,
Int
<
kBlockKSmem
>>
,
Stride
<
Int
<
kBlockKSmem
>
,
_1
>>
{}));
using
SmemLayoutdKV
=
decltype
(
tile_to_shape
(
SmemLayoutAtomdKV
{},
make_shape
(
Int
<
kBlockN
>
{},
Int
<
kHeadDim
>
{})));
using
SmemCopyAtomdKV
=
Copy_Atom
<
DefaultCopy
,
elem_type
>
;
using
SmemLayoutAtomdQ
=
decltype
(
composition
(
Swizzle
<
kSwizzle
,
3
,
3
>
{},
Layout
<
Shape
<
_8
,
Int
<
kBlockKSmem
>>
,
Stride
<
Int
<
kBlockKSmem
>
,
_1
>>
{}));
using
SmemLayoutdQ
=
decltype
(
tile_to_shape
(
SmemLayoutAtomdQ
{},
make_shape
(
Int
<
kBlockM
>
{},
Int
<
kHeadDim
>
{})));
using
SmemCopyAtomdQ
=
Copy_Atom
<
DefaultCopy
,
elem_type
>
;
    static constexpr int kSmemQdOCount = size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3);  // Double buffer for sQ
    static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2;
    static constexpr int kSmemdSCount = size(SmemLayoutPdS{});
    static constexpr int kSmemPCount = size(SmemLayoutPdS{});
    static constexpr int kSmemdQCount = size(SmemLayoutdQ{});
    static constexpr int kSmemQdOSize = kSmemQdOCount * sizeof(Element);
    static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element);
    static constexpr int kSmemdSSize = kSmemdSCount * sizeof(Element);
    static constexpr int kSmemPSize = kSmemPCount * sizeof(Element);
    static constexpr int kSmemdQSize = kSmemdQCount * sizeof(Element);
    static constexpr int kSmemSize = kSmemQdOSize
        + (!Is_V_in_regs
               ? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)
               : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)));
    static constexpr int kSmemSize1colblock = kSmemQdOSize
        + (!Is_V_in_regs
               ? kSmemKVSize + kSmemdSSize + kSmemPSize
               : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + kSmemPSize));
    static constexpr int kSmemSize1rowblock = kSmemQdOSize / 3 * 2 + kSmemKVSize / 2 * 3
        + kSmemdSSize + kSmemPSize;
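    // Illustrative only: assuming SmemLayoutQdO tiles a kBlockM x kHeadDim buffer and
    // SmemLayoutKV a kBlockN x kHeadDim buffer (both defined earlier in this struct), then for
    // kBlockM = kBlockN = kHeadDim = 64 with fp16 and double buffering enabled:
    // Q/dO = 64*64*3*2 B = 24 KB, K+V = 64*64*2*2 B = 16 KB, dS = P = dQ = 8 KB each,
    // so with !Is_V_in_regs, kSmemSize = 24 + 16 + 8 + max(8, 8) = 56 KB.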
    static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
    static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
    // Using kBlockKSmem instead of kHeadDim here to avoid bank conflicts, but doesn't seem
    // to affect speed in practice.
    static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
    static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
    using GmemLayoutAtom = Layout<Shape<Int<kNThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
                                  Stride<Int<kGmemThreadsPerRow>, _1>>;
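    // Illustrative partition: with fp16/bf16, kGmemElemsPerLoad = 16 B / 2 B = 8, so every thread
    // moves one 128-bit vector per access. For kBlockKSmem = 64 that gives kGmemThreadsPerRow = 8;
    // with, say, kNThreads = 256 the GmemLayoutAtom above is 32 rows x 8 threads, and together
    // with the 1x8 value layouts below one tiled-copy step covers a 32 x 64 element tile.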
    // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading
    // from the same address by the same threadblock. This is slightly faster.
    using Gmem_copy_struct = std::conditional_t<
        Has_cp_async,
        SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
        DefaultCopy
    >;
    using GmemTiledCopyQKV = decltype(
        make_tiled_copy(Copy_Atom<Gmem_copy_struct, elem_type>{},
                        GmemLayoutAtom{},
                        Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per read
    using GmemTiledCopydO = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
                        GmemLayoutAtom{},
                        Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per store
    using GmemTiledCopydKV = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
                        GmemLayoutAtom{},
                        Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per store
    using GmemTiledCopydQ = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
                        GmemLayoutAtom{},
                        Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per store
    using GmemLayoutAtomdQaccum = std::conditional_t<
        kBlockKSmem == 32,
        Layout<Shape<_32, _8>,   // Thread layout, 8 threads per row
               Stride<_8, _1>>,
        Layout<Shape<_16, _16>,  // Thread layout, 16 threads per row
               Stride<_16, _1>>
    >;
    using GmemTiledCopydQaccum = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
                        GmemLayoutAtomdQaccum{},
                        Layout<Shape<_1, _4>>{}));  // Val layout, 4 vals per store
    using GmemTiledCopydQaccumAtomicAdd = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
                        Layout<Shape<_8, _32>,  // Thread layout, 8 threads per row
                               Stride<_32, _1>>{},
                        Layout<Shape<_1, _1>>{}));  // Val layout, 1 val per store
};
////////////////////////////////////////////////////////////////////////////////////////////////////
csrc/block_sparse_attn/src/kernel_traits_sm90.h
0 → 100644
View file @ 4f83cf8f
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
#pragma once
#include "cute/algorithm/copy.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/layout/layout.h"
#include <cutlass/numeric_types.h>
using namespace cute;

template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename elem_type = cutlass::half_t>
struct Flash_kernel_traits_sm90 {

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    using Element = elem_type;
    static constexpr bool Has_cp_async = true;
#else
    using Element = cutlass::half_t;
    static constexpr bool Has_cp_async = false;
#endif

    using ElementAccum = float;
    using index_t = int64_t;

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    using MMA_Atom_Arch = std::conditional_t<
        std::is_same_v<elem_type, cutlass::half_t>,
        MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,
        MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>
    >;
    using ValLayoutMNK = Layout<Shape<_1, _2, _1>>;
#else
    using MMA_Atom_Arch = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
    using ValLayoutMNK = Layout<Shape<_1, _2, _2>>;
#endif

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
    using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, elem_type>;
    using SmemCopyAtomTransposed = Copy_Atom<SM75_U16x8_LDSM_T, elem_type>;
#else
    using SmemCopyAtom = Copy_Atom<DefaultCopy, elem_type>;
    using SmemCopyAtomTransposed = Copy_Atom<DefaultCopy, elem_type>;
#endif
};
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, bool Is_Q_in_regs_ = false, bool Share_Q_K_smem_ = false,
         typename elem_type = cutlass::half_t,
         typename Base = Flash_kernel_traits_sm90<kHeadDim_, kBlockM_, kBlockN_, kNWarps_, elem_type> >
struct Flash_fwd_kernel_traits : public Base {
    using Element = typename Base::Element;
    using ElementAccum = typename Base::ElementAccum;
    using index_t = typename Base::index_t;
    static constexpr bool Has_cp_async = Base::Has_cp_async;
    using SmemCopyAtom = typename Base::SmemCopyAtom;
    using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed;

    static constexpr bool Share_Q_K_smem = Share_Q_K_smem_;
    static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem;

    // The number of threads.
    static constexpr int kNWarps = kNWarps_;
    static constexpr int kNThreads = kNWarps * 32;

    static constexpr int kBlockM = kBlockM_;
    static constexpr int kBlockN = kBlockN_;
    static constexpr int kHeadDim = kHeadDim_;
    static_assert(kHeadDim % 32 == 0);
    static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
    static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32);
    static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
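    // For example (these values follow directly from the three definitions above):
    // kHeadDim = 128 gives kBlockKSmem = 64, kBlockKGmem = 128, kSwizzle = 3;
    // kHeadDim = 96 gives kBlockKSmem = 32, kBlockKGmem = 32, kSwizzle = 2.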
    using TiledMma = TiledMMA<
        typename Base::MMA_Atom_Arch,
        Layout<Shape<Int<kNWarps>, _1, _1>>,  // 4x1x1 or 8x1x1 thread group
        typename Base::ValLayoutMNK>;  // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM

    using SmemLayoutAtomQ = decltype(
        composition(Swizzle<kSwizzle, 3, 3>{},
                    // This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128
                    Layout<Shape<_8, Int<kBlockKSmem>>,
                           Stride<Int<kBlockKSmem>, _1>>{}));
    using SmemLayoutQ = decltype(tile_to_shape(
        SmemLayoutAtomQ{},
        Shape<Int<kBlockM>, Int<kHeadDim>>{}));

    using SmemLayoutKV = decltype(tile_to_shape(
        SmemLayoutAtomQ{},
        Shape<Int<kBlockN>, Int<kHeadDim>>{}));
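    // Rough sketch of what the swizzle buys (see the Swizzle<B, M, S> functor in cute/swizzle.hpp
    // for the exact definition): for kBlockKSmem = 64 the atom above is an 8 x 64 row-major fp16
    // tile, so bits 0-5 of the linear offset index the column and bits 6-8 the row.
    // Swizzle<3, 3, 3> XORs bits 6-8 into bits 3-5, so each of the 8 rows places its 16-byte
    // chunks in a different permutation of shared-memory banks, the usual pattern for keeping
    // the smem writes and ldmatrix reads free of bank conflicts.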
    using SmemLayoutAtomVtransposed = decltype(
        composition(Swizzle<kSwizzle, 3, 3>{},
                    // This has to be kBlockN and not 8, otherwise we get wrong results for d=128
                    Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
                           Stride<_1, Int<kBlockKSmem>>>{}));
    using SmemLayoutVtransposed = decltype(tile_to_shape(
        SmemLayoutAtomVtransposed{},
        Shape<Int<kHeadDim>, Int<kBlockN>>{}));
    // Maybe the VtransposeNoSwizzle just needs to have the right shape
    // And the strides don't matter?
    using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn());

    using SmemLayoutAtomO = decltype(
        composition(Swizzle<kSwizzle, 3, 3>{},
                    Layout<Shape<Int<8>, Int<kBlockKSmem>>,
                           Stride<Int<kBlockKSmem>, _1>>{}));
    using SmemLayoutO = decltype(tile_to_shape(
        SmemLayoutAtomO{},
        Shape<Int<kBlockM>, Int<kHeadDim>>{}));
    using SmemCopyAtomO = Copy_Atom<DefaultCopy, elem_type>;
    static constexpr int kSmemQCount = size(SmemLayoutQ{});
    static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2;
    static constexpr int kSmemQSize = kSmemQCount * sizeof(Element);
    static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element);
    static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize;
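    // Illustrative numbers: for kBlockM = 128, kBlockN = 64, kHeadDim = 128 with fp16,
    // kSmemQSize = 128*128*2 B = 32 KB and kSmemKVSize = 64*128*2*2 B = 32 KB, so kSmemSize
    // is 32 KB when Q and K share smem and 64 KB otherwise.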
    static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
    static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
    // Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts.
    // For example, for d=128, smem is split into 2 "pages", each page takes care of columns
    // 0-63 and 64-127. If we have 16 threads per row for gmem read, when we write to smem,
    // thread 0 - 7 will write to the first page and thread 8 - 15 will write to the second page,
    // to the same banks.
    static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
    static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
    using GmemLayoutAtom = Layout<Shape<Int<kNThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
                                  Stride<Int<kGmemThreadsPerRow>, _1>>;
    // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading
    // from the same address by the same threadblock. This is slightly faster.
    using Gmem_copy_struct = std::conditional_t<
        Has_cp_async,
        SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
        DefaultCopy
    >;
    using GmemTiledCopyQKV = decltype(
        make_tiled_copy(Copy_Atom<Gmem_copy_struct, elem_type>{},
                        GmemLayoutAtom{},
                        Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per read
    using GmemTiledCopyO = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
                        GmemLayoutAtom{},
                        Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per store
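    // Illustrative partition: when Has_cp_async is true, every thread issues one 16-byte cp.async
    // (with the CACHEGLOBAL hint) per copy step. For kBlockKSmem = 64 and kNThreads = 128 the
    // GmemLayoutAtom above is 16 rows x 8 threads, so together with the 1x8 value layout a single
    // GmemTiledCopyQKV step moves a 16 x 64 element tile.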
    static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad;
    static_assert(kNThreads % kGmemThreadsPerRowP == 0, "kNThreads must be a multiple of kGmemThreadsPerRowP");
    using GmemLayoutAtomP = Layout<Shape<Int<kNThreads / kGmemThreadsPerRowP>, Int<kGmemThreadsPerRowP>>,
                                   Stride<Int<kGmemThreadsPerRowP>, _1>>;

    using GmemTiledCopyP = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
                        GmemLayoutAtomP{},
                        Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per store

};
////////////////////////////////////////////////////////////////////////////////////////////////////
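A minimal usage sketch, not part of the commit: the tile sizes and warp count below are chosen only to exercise the definitions above, not taken from the actual launch code.

// Hypothetical instantiation of the SM90 forward traits: head dim 128, 128x64 tiles, 4 warps, fp16.
// The asserts just restate arithmetic from the struct definition.
using Traits_hdim128_example = Flash_fwd_kernel_traits<128, 128, 64, 4, false, false, cutlass::half_t>;
static_assert(Traits_hdim128_example::kNThreads == 128, "4 warps -> 128 threads");
static_assert(Traits_hdim128_example::kBlockKSmem == 64 && Traits_hdim128_example::kSwizzle == 3, "64-column smem pages");
static_assert(Traits_hdim128_example::kGmemThreadsPerRow == 8, "128-bit loads of fp16 -> 8 elements per thread");
// Without Share_Q_K_smem this configuration needs kSmemQSize + kSmemKVSize = 32 KB + 32 KB = 64 KB of shared memory.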