gaoqiong / flash-attention · Commits
"vscode:/vscode.git/clone" did not exist on "fce0a57dec5635022d4170ba4a15e32a432b9cc7"
Commit ed4959b2, authored Jan 20, 2024 by Tri Dao
Parent: 6f706eff

    Change inline to __forceinline__, use __grid_constant__ param

Showing 10 changed files with 56 additions and 56 deletions:
csrc/flash_attn/src/alibi.h                        +2  -2
csrc/flash_attn/src/block_info.h                   +2  -2
csrc/flash_attn/src/dropout.h                      +2  -2
csrc/flash_attn/src/flash_bwd_kernel.h             +1  -1
csrc/flash_attn/src/flash_bwd_launch_template.h    +6  -6
csrc/flash_attn/src/flash_fwd_launch_template.h    +3  -3
csrc/flash_attn/src/mask.h                         +4  -4
csrc/flash_attn/src/philox.cuh                     +3  -3
csrc/flash_attn/src/softmax.h                      +10 -10
csrc/flash_attn/src/utils.h                        +23 -23
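The commit makes two mechanical changes across these headers. First, every device helper declared `inline __device__` (or `__device__ inline`) becomes `__forceinline__ __device__`: in CUDA, plain `inline` is only a hint, while `__forceinline__` instructs nvcc to always inline the function. Second, the by-value kernel parameter structs gain `const`, and the hot-loop kernels additionally gain `__grid_constant__`, which keeps the struct in constant memory instead of having it copied into per-thread local space. A minimal standalone sketch of both patterns (toy names, not code from this commit; `__grid_constant__` requires CUDA 11.7+ and compute capability 7.0+):

```cuda
#include <cuda_runtime.h>

struct Params { const float* in; float* out; int n; };

// `inline` on a __device__ function is only a hint to the compiler;
// `__forceinline__` tells nvcc to always inline the call.
__forceinline__ __device__ float square(float x) { return x * x; }

// `__grid_constant__` on a const by-value kernel parameter keeps the whole
// struct in constant memory, shared read-only by every thread in the grid.
__global__ void square_kernel(const __grid_constant__ Params p) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < p.n) { p.out[i] = square(p.in[i]); }
}
```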
csrc/flash_attn/src/alibi.h

```diff
@@ -19,7 +19,7 @@ struct Alibi {
     const float alibi_slope;
     const int max_seqlen_k, max_seqlen_q;

-    inline __device__ Alibi(const float alibi_slope, const int max_seqlen_k, const int max_seqlen_q)
+    __forceinline__ __device__ Alibi(const float alibi_slope, const int max_seqlen_k, const int max_seqlen_q)
         : alibi_slope(alibi_slope)
         , max_seqlen_k(max_seqlen_k)
         , max_seqlen_q(max_seqlen_q) {
@@ -27,7 +27,7 @@ struct Alibi {
     template <typename Engine, typename Layout>
-    inline __device__ void apply_alibi(Tensor<Engine, Layout> &tensor,
+    __forceinline__ __device__ void apply_alibi(Tensor<Engine, Layout> &tensor,
                                        const int col_idx_offset_,
                                        const int row_idx_offset,
                                        const int warp_row_stride) {
```
csrc/flash_attn/src/block_info.h

```diff
@@ -24,12 +24,12 @@ struct BlockInfo {
     }

     template <typename index_t>
-    inline __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
+    __forceinline__ __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
         return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride;
     }

     template <typename index_t>
-    inline __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
+    __forceinline__ __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
         return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride;
     }
```
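The unchanged bodies above encode the variable-length batch layout: `sum_s_q == -1` means fixed-length (padded) batches addressed by `batch_stride`; otherwise sequences are packed back-to-back and addressed by their cumulative row offset. A hedged standalone sketch of the same selection logic (hypothetical name, not from this commit):

```cuda
// Sketch of the q_offset/k_offset address selection in block_info.h:
// padded layout -> batch b starts at b * batch_stride;
// packed varlen -> batch b starts at its cumulative sequence start * row_stride.
template <typename index_t>
__forceinline__ __device__ index_t seq_offset(const index_t batch_stride,
                                              const index_t row_stride,
                                              const int bidb,
                                              const int cum_seqlen /* -1 if padded */) {
    return cum_seqlen == -1 ? bidb * batch_stride : uint32_t(cum_seqlen) * row_stride;
}
```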
csrc/flash_attn/src/dropout.h

```diff
@@ -14,7 +14,7 @@ struct Dropout {
     const unsigned long long seed, offset;
     const uint8_t p_dropout_in_uint8_t;

-    inline __device__ Dropout(const unsigned long long seed, const unsigned long long offset,
+    __forceinline__ __device__ Dropout(const unsigned long long seed, const unsigned long long offset,
                               const uint8_t p_dropout_in_uint8_t,
                               const int bid, const int hid, const int tid, const int nheads)
         : seed(seed)
@@ -23,7 +23,7 @@ struct Dropout {
     }

    template <bool encode_dropout_in_sign_bit=false, typename Engine, typename Layout>
-    inline __device__ void apply_dropout(Tensor<Engine, Layout> &tensor_,
+    __forceinline__ __device__ void apply_dropout(Tensor<Engine, Layout> &tensor_,
                                          int block_row_start, int block_col_start, int block_row_stride) {
         // tensor_ has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
         Tensor tensor = make_tensor(tensor_.data(), flash::convert_layout_rowcol_dropout(tensor_.layout()));
```
csrc/flash_attn/src/flash_bwd_kernel.h

```diff
@@ -448,7 +448,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     clear(acc_dv);
     clear(acc_dk);

-    const float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
+    const float alibi_slope = !Has_alibi || params.alibi_slopes_ptr == nullptr ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
     flash::Alibi<Is_causal> alibi(alibi_slope, binfo.actual_seqlen_k, binfo.actual_seqlen_q);

     for (; m_block >= m_block_min; --m_block) {
```
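Note that this hunk is the one place the commit changes behavior rather than just annotations: the ALiBi slope read now also treats a null `alibi_slopes_ptr` as "no ALiBi" instead of relying on the `Has_alibi` flag alone. The guarded-read pattern, as a hedged standalone sketch (illustrative name):

```cuda
// Sketch of the hardened slope read: a disabled feature flag OR a missing
// pointer both yield a zero slope, so the kernel never dereferences an
// unset params field.
__forceinline__ __device__ float load_alibi_slope(const bool has_alibi,
                                                  const float* slopes,
                                                  const int idx,
                                                  const float softmax_scale) {
    return (!has_alibi || slopes == nullptr) ? 0.0f : slopes[idx] / softmax_scale;
}
```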
csrc/flash_attn/src/flash_bwd_launch_template.h

```diff
@@ -12,33 +12,33 @@
 #include "flash_bwd_kernel.h"

 template<bool Clear_dQaccum=true, typename Kernel_traits>
-__global__ void flash_bwd_dot_do_o_kernel(Flash_bwd_params params) {
+__global__ void flash_bwd_dot_do_o_kernel(const Flash_bwd_params params) {
     flash::compute_dot_do_o<Clear_dQaccum, Kernel_traits>(params);
 }

 template<typename Kernel_traits>
-__global__ void flash_bwd_clear_dkvaccum_kernel(Flash_bwd_params params) {
+__global__ void flash_bwd_clear_dkvaccum_kernel(const Flash_bwd_params params) {
     flash::clear_dKVaccum<Kernel_traits>(params);
 }

 template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Has_alibi, bool Is_even_M, bool Is_even_K>
-__global__ void flash_bwd_dq_dk_dv_loop_kernel(Flash_bwd_params params) {
+__global__ void flash_bwd_dq_dk_dv_loop_kernel(__grid_constant__ const Flash_bwd_params params) {
     flash::compute_dq_dk_dv<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_M, Is_even_K>(params);
 }

 template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K>
-__global__ void flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel(Flash_bwd_params params) {
+__global__ void flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel(__grid_constant__ const Flash_bwd_params params) {
     static_assert(!(Is_causal && Is_local));  // If Is_local is true, Is_causal should be false
     flash::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K>(params);
 }

 template<typename Kernel_traits>
-__global__ void flash_bwd_convert_dq_kernel(Flash_bwd_params params, const int nsplits) {
+__global__ void flash_bwd_convert_dq_kernel(const Flash_bwd_params params, const int nsplits) {
     flash::convert_dQ<Kernel_traits>(params, nsplits);
 }

 template<typename Kernel_traits>
-__global__ void flash_bwd_convert_dkv_kernel(Flash_bwd_params params) {
+__global__ void flash_bwd_convert_dkv_kernel(const Flash_bwd_params params) {
     flash::convert_dKV<Kernel_traits>(params);
 }
```
csrc/flash_attn/src/flash_fwd_launch_template.h

```diff
@@ -11,18 +11,18 @@
 #include "flash_fwd_kernel.h"

 template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax>
-__global__ void flash_fwd_kernel(Flash_fwd_params params) {
+__global__ void flash_fwd_kernel(__grid_constant__ const Flash_fwd_params params) {
     static_assert(!(Is_causal && Is_local));  // If Is_local is true, Is_causal should be false
     flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Return_softmax>(params);
 }

 template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV>
-__global__ void flash_fwd_splitkv_kernel(Flash_fwd_params params) {
+__global__ void flash_fwd_splitkv_kernel(__grid_constant__ const Flash_fwd_params params) {
     flash::compute_attn_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Split, Append_KV>(params);
 }

 template<typename Kernel_traits, int kBlockM, int Log_max_splits, bool Is_even_K>
-__global__ void flash_fwd_splitkv_combine_kernel(Flash_fwd_params params) {
+__global__ void flash_fwd_splitkv_combine_kernel(__grid_constant__ const Flash_fwd_params params) {
     static_assert(Log_max_splits >= 1);
     flash::combine_attn_seqk_parallel<Kernel_traits, kBlockM, Log_max_splits, Is_even_K>(params);
 }
```
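Nothing changes on the launch side for a `__grid_constant__` parameter: the struct is still passed by value with the usual `<<<...>>>` syntax, it just lands in constant memory instead of being spilled per-thread. A self-contained toy demo (assumed environment: CUDA 11.7+, sm_70+; names are illustrative only, not from this repo):

```cuda
#include <cuda_runtime.h>
#include <cstdio>

struct FillParams { float* out; int n; float value; };

// The params struct is read-only for the whole grid, so __grid_constant__
// lets every thread read it straight from constant memory.
__global__ void fill_kernel(const __grid_constant__ FillParams p) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < p.n) { p.out[i] = p.value; }
}

int main() {
    FillParams p{nullptr, 1024, 3.0f};
    cudaMalloc(&p.out, p.n * sizeof(float));
    fill_kernel<<<(p.n + 255) / 256, 256>>>(p);   // launch syntax is unchanged
    cudaDeviceSynchronize();
    float first;
    cudaMemcpy(&first, p.out, sizeof(float), cudaMemcpyDeviceToHost);
    printf("out[0] = %f\n", first);               // prints 3.000000
    cudaFree(p.out);
    return 0;
}
```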
csrc/flash_attn/src/mask.h

```diff
@@ -11,7 +11,7 @@ namespace flash {
 using namespace cute;

 template <typename Engine, typename Layout>
-inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const int max_seqlen_k,
+__forceinline__ __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const int max_seqlen_k,
                                   const int col_idx_offset_ = 0) {
     // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
     static_assert(Layout::rank == 2, "Only support 2D Tensor");
@@ -35,7 +35,7 @@ inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const int max_
 }

 template <bool HasWSLeft=true, typename Engine, typename Layout>
-inline __device__ void apply_mask_local(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
+__forceinline__ __device__ void apply_mask_local(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
                                         const int max_seqlen_k, const int row_idx_offset,
                                         const int max_seqlen_q, const int warp_row_stride,
                                         const int window_size_left, const int window_size_right) {
@@ -72,7 +72,7 @@ inline __device__ void apply_mask_local(Tensor<Engine, Layout> &tensor, const in
 }

 template <typename Engine, typename Layout>
-inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
+__forceinline__ __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
                                          const int max_seqlen_k, const int row_idx_offset,
                                          const int max_seqlen_q, const int warp_row_stride) {
     // Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0
@@ -81,7 +81,7 @@ inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const i
 }

 template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
-inline __device__ void apply_mask_causal_w_idx(
+__forceinline__ __device__ void apply_mask_causal_w_idx(
     Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &idx_rowcol,
     const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset)
 {
```
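The comment retained in the last hunk ("causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0") is worth spelling out. Ignoring the seqlen_q/seqlen_k diagonal offset the real code handles, the keep-predicate reduces to the following hedged simplification:

```cuda
#include <climits>

// Simplified local-attention keep-predicate: element (row, col) survives iff
// it lies inside the sliding window around the diagonal. With
// window_size_left = INT_MAX and window_size_right = 0 this degenerates to
// the causal condition col <= row.
__forceinline__ __device__ bool in_local_window(const int row, const int col,
                                                const int window_size_left,
                                                const int window_size_right) {
    return col >= row - window_size_left && col <= row + window_size_right;
}
```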
csrc/flash_attn/src/philox.cuh

```diff
@@ -9,7 +9,7 @@ struct ull2 {
     unsigned long long y;
 };

-inline __device__ uint2 mulhilo32(const unsigned int a, const unsigned int b) {
+__forceinline__ __device__ uint2 mulhilo32(const unsigned int a, const unsigned int b) {
     uint2 *res;
     unsigned long long tmp;
     asm("mul.wide.u32 %0, %1, %2;\n\t"
@@ -19,7 +19,7 @@ inline __device__ uint2 mulhilo32(const unsigned int a, const unsigned int b) {
     return *res;
 }

-inline __device__ uint4 philox_single_round(const uint4 ctr, const uint2 key) {
+__forceinline__ __device__ uint4 philox_single_round(const uint4 ctr, const uint2 key) {
     constexpr unsigned long kPhiloxSA = 0xD2511F53;
     constexpr unsigned long kPhiloxSB = 0xCD9E8D57;
     uint2 res0 = mulhilo32(kPhiloxSA, ctr.x);
@@ -28,7 +28,7 @@ inline __device__ uint4 philox_single_round(const uint4 ctr, const uint2 key) {
     return ret;
 }

-inline __device__ uint4 philox(unsigned long long seed,
+__forceinline__ __device__ uint4 philox(unsigned long long seed,
                                unsigned long long subsequence,
                                unsigned long long offset) {
     constexpr unsigned long kPhilox10A = 0x9E3779B9;
```
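`mulhilo32` above wraps a single PTX `mul.wide.u32`; for readers who don't speak PTX, a plain-CUDA equivalent of what it computes (a sketch for illustration, not the commit's code):

```cuda
// 32x32 -> 64-bit multiply, split into (lo, hi) 32-bit halves — the same
// result the inline-PTX mul.wide.u32 in philox.cuh produces on little-endian.
__forceinline__ __device__ uint2 mulhilo32_portable(const unsigned int a,
                                                    const unsigned int b) {
    const unsigned long long prod = static_cast<unsigned long long>(a) * b;
    return make_uint2(static_cast<unsigned int>(prod),          // low 32 bits
                      static_cast<unsigned int>(prod >> 32));   // high 32 bits
}
```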
csrc/flash_attn/src/softmax.h

```diff
@@ -20,7 +20,7 @@ using namespace cute;
 ////////////////////////////////////////////////////////////////////////////////////////////////////

 template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
-__device__ inline void thread_reduce_(Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
+__device__ __forceinline__ void thread_reduce_(Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
     static_assert(Layout0::rank == 2, "Only support 2D Tensor");
     static_assert(Layout1::rank == 1, "Only support 1D Tensor");
     CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));
@@ -35,7 +35,7 @@ __device__ inline void thread_reduce_(Tensor<Engine0, Layout0> const &tensor, Te
 }

 template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
-__device__ inline void quad_allreduce_(Tensor<Engine0, Layout0> &dst, Tensor<Engine1, Layout1> &src, Operator &op) {
+__device__ __forceinline__ void quad_allreduce_(Tensor<Engine0, Layout0> &dst, Tensor<Engine1, Layout1> &src, Operator &op) {
     CUTE_STATIC_ASSERT_V(size(dst) == size(src));
     #pragma unroll
     for (int i = 0; i < size(dst); i++){
@@ -44,26 +44,26 @@ __device__ inline void quad_allreduce_(Tensor<Engine0, Layout0> &dst, Tensor<Eng
 }

 template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
-__device__ inline void reduce_(Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
+__device__ __forceinline__ void reduce_(Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
     thread_reduce_<zero_init>(tensor, summary, op);
     quad_allreduce_(summary, summary, op);
 }

 template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
-__device__ inline void reduce_max(Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &max){
+__device__ __forceinline__ void reduce_max(Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &max){
     MaxOp<float> max_op;
     reduce_<zero_init>(tensor, max, max_op);
 }

 template<typename Engine0, typename Layout0, typename Engine1, typename Layout1>
-__device__ inline void reduce_sum(Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &sum){
+__device__ __forceinline__ void reduce_sum(Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &sum){
     SumOp<float> sum_op;
     reduce_(tensor, sum, sum_op);
 }

 // Apply the exp to all the elements.
 template <bool Scale_max=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
-inline __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &max, const float scale) {
+__forceinline__ __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &max, const float scale) {
     static_assert(Layout0::rank == 2, "Only support 2D Tensor");
     static_assert(Layout1::rank == 1, "Only support 1D Tensor");
     CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
@@ -85,7 +85,7 @@ inline __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, Tensor
 // Apply the exp to all the elements.
 template <bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
-inline __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> &max, Tensor<Engine1, Layout1> &sum, const float scale) {
+__forceinline__ __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> &max, Tensor<Engine1, Layout1> &sum, const float scale) {
     static_assert(Layout0::rank == 2, "Only support 2D Tensor");
     static_assert(Layout1::rank == 1, "Only support 1D Tensor");
     CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
@@ -123,10 +123,10 @@ struct Softmax {
     using TensorT = decltype(make_tensor<float>(Shape<Int<kNRows>>{}));
     TensorT row_max, row_sum;

-    inline __device__ Softmax() {};
+    __forceinline__ __device__ Softmax() {};

     template<bool Is_first, bool Check_inf=false, typename Tensor0, typename Tensor1>
-    inline __device__ void softmax_rescale_o(Tensor0 &acc_s, Tensor1 &acc_o, float softmax_scale_log2) {
+    __forceinline__ __device__ void softmax_rescale_o(Tensor0 &acc_s, Tensor1 &acc_o, float softmax_scale_log2) {
         // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
         Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
         static_assert(decltype(size<0>(scores))::value == kNRows);
@@ -160,7 +160,7 @@ struct Softmax {
     };

     template<bool Is_dropout=false, bool Split=false, typename Tensor0>
-    inline __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float rp_dropout=1.0) {
+    __forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float rp_dropout=1.0) {
         TensorT lse = make_fragment_like(row_sum);
         Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
         static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
```
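For context on `scale_apply_exp2`: subtracting the row max before exponentiating keeps the largest term at 2^0 = 1, and working in base 2 lets the compiler emit the single-instruction `exp2f`. A hedged sketch of that idea on a plain float row (not the CUTe-tensor version above):

```cuda
// Numerically stable exponentiation in base 2: the row max is folded into
// the exponent so nothing overflows, matching the scale_apply_exp2 idea.
__forceinline__ __device__ void scale_apply_exp2_row(float* row, const int n,
                                                     const float row_max,
                                                     const float scale) {
    const float max_scaled = row_max * scale;
    #pragma unroll 4
    for (int j = 0; j < n; ++j) {
        row[j] = exp2f(row[j] * scale - max_scaled);
    }
}
```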
csrc/flash_attn/src/utils.h

```diff
@@ -29,10 +29,10 @@ namespace flash {
 ////////////////////////////////////////////////////////////////////////////////////////////////////

 template<typename T>
-inline __device__ uint32_t relu2(const uint32_t x);
+__forceinline__ __device__ uint32_t relu2(const uint32_t x);

 template<>
-inline __device__ uint32_t relu2<cutlass::half_t>(const uint32_t x) {
+__forceinline__ __device__ uint32_t relu2<cutlass::half_t>(const uint32_t x) {
     uint32_t res;
     const uint32_t zero = 0u;
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
@@ -50,7 +50,7 @@ inline __device__ uint32_t relu2<cutlass::half_t>(const uint32_t x) {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
 template<>
-inline __device__ uint32_t relu2<cutlass::bfloat16_t>(const uint32_t x) {
+__forceinline__ __device__ uint32_t relu2<cutlass::bfloat16_t>(const uint32_t x) {
     uint32_t res;
     const uint32_t zero = 0u;
     asm volatile("max.bf16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero));
@@ -63,10 +63,10 @@ inline __device__ uint32_t relu2<cutlass::bfloat16_t>(const uint32_t x) {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
 template<typename T>
-inline __device__ uint32_t convert_relu2(const float2 x);
+__forceinline__ __device__ uint32_t convert_relu2(const float2 x);

 template<>
-inline __device__ uint32_t convert_relu2<cutlass::half_t>(const float2 x) {
+__forceinline__ __device__ uint32_t convert_relu2<cutlass::half_t>(const float2 x) {
     uint32_t res;
     const uint32_t a = reinterpret_cast<const uint32_t&>(x.x);
     const uint32_t b = reinterpret_cast<const uint32_t&>(x.y);
@@ -75,7 +75,7 @@ inline __device__ uint32_t convert_relu2<cutlass::half_t>(const float2 x) {
 }

 template<>
-inline __device__ uint32_t convert_relu2<cutlass::bfloat16_t>(const float2 x) {
+__forceinline__ __device__ uint32_t convert_relu2<cutlass::bfloat16_t>(const float2 x) {
     uint32_t res;
     const uint32_t a = reinterpret_cast<const uint32_t&>(x.x);
     const uint32_t b = reinterpret_cast<const uint32_t&>(x.y);
@@ -89,20 +89,20 @@ inline __device__ uint32_t convert_relu2<cutlass::bfloat16_t>(const float2 x) {
 template<typename T>
 struct MaxOp {
-__device__ inline T operator()(T const & x, T const & y) { return x > y ? x : y; }
+__device__ __forceinline__ T operator()(T const & x, T const & y) { return x > y ? x : y; }
 };

 template <>
 struct MaxOp<float> {
 // This is slightly faster
-__device__ inline float operator()(float const &x, float const &y) { return max(x, y); }
+__device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); }
 };

 ////////////////////////////////////////////////////////////////////////////////////////////////////

 template<typename T>
 struct SumOp {
-__device__ inline T operator()(T const & x, T const & y) { return x + y; }
+__device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; }
 };

 ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -111,7 +111,7 @@ template<int THREADS>
 struct Allreduce {
     static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
     template<typename T, typename Operator>
-    static __device__ inline T run(T x, Operator &op) {
+    static __device__ __forceinline__ T run(T x, Operator &op) {
         constexpr int OFFSET = THREADS / 2;
         x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
         return Allreduce<OFFSET>::run(x, op);
@@ -123,7 +123,7 @@ struct Allreduce {
 template<>
 struct Allreduce<2> {
 template<typename T, typename Operator>
-static __device__ inline T run(T x, Operator &op) {
+static __device__ __forceinline__ T run(T x, Operator &op) {
     x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
     return x;
 }
@@ -135,7 +135,7 @@ template<bool A_in_regs=false, bool B_in_regs=false, typename Tensor0, typename
          typename Tensor2, typename Tensor3, typename Tensor4,
          typename TiledMma, typename TiledCopyA, typename TiledCopyB,
          typename ThrCopyA, typename ThrCopyB>
-inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA,
+__forceinline__ __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA,
                             Tensor4 const& tCsB, TiledMma tiled_mma,
                             TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B,
                             ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B) {
@@ -162,7 +162,7 @@ inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3
 template<typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3,
          typename TiledMma, typename TiledCopy, typename ThrCopy>
-inline __device__ void gemm_rs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB,
+__forceinline__ __device__ void gemm_rs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB,
                                TiledMma tiled_mma, TiledCopy smem_tiled_copy_B,
                                ThrCopy smem_thr_copy_B) {
     CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc));  // MMA_M
@@ -184,7 +184,7 @@ inline __device__ void gemm_rs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tenso
 // Convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
 template<typename Layout>
-inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
+__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
     static_assert(decltype(size<0>(acc_layout))::value == 4);
     static_assert(decltype(rank(acc_layout))::value == 3);
     auto l = logical_divide(acc_layout, Shape<_2>{});  // ((2, 2), MMA_M, MMA_N)
@@ -196,7 +196,7 @@ inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
 // Convert rowcol_layout from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
 // if using m16n8k16, or to ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
 template<typename MMA_traits, typename Layout>
-inline __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout) {
+__forceinline__ __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout) {
     using X = Underscore;
     static_assert(decltype(size<0, 0>(rowcol_layout))::value == 2);
     static_assert(decltype(size<1, 0>(rowcol_layout))::value == 2);
@@ -213,7 +213,7 @@ inline __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout) {
 // Convert rowcol_layout from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
 template<typename Layout>
-inline __device__ auto convert_layout_rowcol_dropout(Layout rowcol_layout) {
+__forceinline__ __device__ auto convert_layout_rowcol_dropout(Layout rowcol_layout) {
     using X = Underscore;
     static_assert(decltype(size<0, 0>(rowcol_layout))::value == 2);
     static_assert(decltype(size<1, 0>(rowcol_layout))::value == 2);
@@ -226,7 +226,7 @@ inline __device__ auto convert_layout_rowcol_dropout(Layout rowcol_layout) {
 ////////////////////////////////////////////////////////////////////////////////////////////////////

 template <typename To_type, typename Engine, typename Layout>
-inline __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
+__forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
     using From_type = typename Engine::value_type;
     constexpr int numel = decltype(size(tensor))::value;
     cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
@@ -238,7 +238,7 @@ inline __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
 ////////////////////////////////////////////////////////////////////////////////////////////////////

 template <typename Engine, typename Layout>
-inline __device__ void relu_(Tensor<Engine, Layout> &tensor) {
+__forceinline__ __device__ void relu_(Tensor<Engine, Layout> &tensor) {
     constexpr int numel = decltype(size(tensor))::value;
     static_assert(numel % 2 == 0);
     using value_t = typename Engine::value_type;
@@ -254,7 +254,7 @@ inline __device__ void relu_(Tensor<Engine, Layout> &tensor) {
 // On SM80 and above, we can fuse fp32 -> fp16/bf16 conversion and relu into 1 instruction
 template <typename To_type, typename Engine, typename Layout>
-inline __device__ auto convert_type_relu(Tensor<Engine, Layout> const &tensor) {
+__forceinline__ __device__ auto convert_type_relu(Tensor<Engine, Layout> const &tensor) {
     using From_type = typename Engine::value_type;
     static_assert(std::is_same_v<To_type, cutlass::half_t> || std::is_same_v<To_type, cutlass::bfloat16_t>);
     static_assert(std::is_same_v<float, From_type>);
@@ -296,7 +296,7 @@ void cp_async_wait() {
 template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true,
           typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
           typename Engine2, typename Layout2, typename Engine3, typename Layout3>
-inline __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
+__forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
                             Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
                             Tensor<Engine3, Layout3> const &predicate_K, const int max_MN=0) {
     CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
@@ -365,7 +365,7 @@ inline __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const
 template <bool Is_even_K=true,
           typename Engine0, typename Layout0, typename Engine1, typename Layout1,
           typename Engine2, typename Layout2, typename Engine3, typename Layout3>
-inline __device__ void copy_w_min_idx(Tensor<Engine0, Layout0> const &S,
+__forceinline__ __device__ void copy_w_min_idx(Tensor<Engine0, Layout0> const &S,
                                       Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
                                       Tensor<Engine3, Layout3> const &predicate_K,
                                       const int max_MN=0, const int min_MN=0) {
@@ -395,7 +395,7 @@ inline __device__ void copy_w_min_idx(Tensor<Engine0, Layout0> const &S,
 template <bool Is_even_K=true, bool Clear_OOB_K=true,
           typename Engine0, typename Layout0, typename Engine1, typename Layout1,
           typename Engine2, typename Layout2, typename Engine3, typename Layout3>
-inline __device__ void copy_rotary_interleaved(Tensor<Engine0, Layout0> const &S,
+__forceinline__ __device__ void copy_rotary_interleaved(Tensor<Engine0, Layout0> const &S,
                                                Tensor<Engine1, Layout1> &D,
                                                Tensor<Engine2, Layout2> const &Cos,
                                                Tensor<Engine2, Layout2> const &Sin,
@@ -458,7 +458,7 @@ inline __device__ void copy_rotary_interleaved(Tensor<Engine0, Layout0> const &S
 template <bool Is_even_K=true, bool Clear_OOB_K=true,
           typename Engine0, typename Layout0, typename Engine1, typename Layout1,
           typename Engine2, typename Layout2, typename Engine3, typename Layout3>
-inline __device__ void copy_rotary_contiguous(Tensor<Engine0, Layout0> const &S,
+__forceinline__ __device__ void copy_rotary_contiguous(Tensor<Engine0, Layout0> const &S,
                                               Tensor<Engine1, Layout1> &D,
                                               Tensor<Engine2, Layout2> const &Cos,
                                               Tensor<Engine2, Layout2> const &Sin,
```
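The `Allreduce` hunks show the warp-level butterfly reduction these kernels rely on. As a self-contained illustration of the same pattern (assuming a full 32-lane warp; toy names, not code from this repo):

```cuda
#include <cuda_runtime.h>
#include <cstdio>

// Butterfly allreduce: each of the log2(32) rounds XOR-exchanges values
// across a halving offset, so every lane ends up with the full sum —
// the recursive-template Allreduce in utils.h computes the same thing.
__device__ __forceinline__ float warp_allreduce_sum(float x) {
    #pragma unroll
    for (int offset = 16; offset > 0; offset /= 2) {
        x += __shfl_xor_sync(0xffffffffu, x, offset);
    }
    return x;
}

__global__ void demo_kernel(float* out) {
    const float v = static_cast<float>(threadIdx.x);   // each lane holds its own id
    out[threadIdx.x] = warp_allreduce_sum(v);          // every lane gets the total
}

int main() {
    float* d = nullptr;
    cudaMalloc(&d, 32 * sizeof(float));
    demo_kernel<<<1, 32>>>(d);
    float h[32];
    cudaMemcpy(h, d, sizeof(h), cudaMemcpyDeviceToHost);
    printf("lane 0 sum = %f\n", h[0]);                 // 0 + 1 + ... + 31 = 496
    cudaFree(d);
    return 0;
}
```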