OpenDAS / apex · Commit 3c88451a

Unverified commit, authored Mar 25, 2022 by yjk21; committed by GitHub, Mar 25, 2022.

    update fmha (#1344)

parent a0ed4151
Showing 20 changed files with 860 additions and 577 deletions (+860 −577).
apex/contrib/csrc/fmha/fmha_api.cpp                                +15 −108
apex/contrib/csrc/fmha/src/fmha.h                                  +43 −10
apex/contrib/csrc/fmha/src/fmha/gemm.h                              +0 −3
apex/contrib/csrc/fmha/src/fmha/gmem_tile.h                        +38 −10
apex/contrib/csrc/fmha/src/fmha/kernel_traits.h                     +4 −2
apex/contrib/csrc/fmha/src/fmha/mask.h                              +5 −0
apex/contrib/csrc/fmha/src/fmha/smem_tile.h                         +0 −2
apex/contrib/csrc/fmha/src/fmha/softmax.h                         +144 −227
apex/contrib/csrc/fmha/src/fmha/utils.h                            +85 −0
apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_128_64_kernel.sm80.cu    +1 −1
apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_256_64_kernel.sm80.cu    +1 −1
apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_384_64_kernel.sm80.cu    +1 −1
apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_512_64_kernel.sm80.cu    +1 −1
apex/contrib/csrc/fmha/src/fmha_dgrad_kernel_1xN_reload.h           +3 −3
apex/contrib/csrc/fmha/src/fmha_dgrad_kernel_1xN_reload_nl.h        +5 −7
apex/contrib/csrc/fmha/src/fmha_fprop_fp16_128_64_kernel.sm80.cu   +42 −16
apex/contrib/csrc/fmha/src/fmha_fprop_fp16_256_64_kernel.sm80.cu   +42 −16
apex/contrib/csrc/fmha/src/fmha_fprop_fp16_384_64_kernel.sm80.cu   +43 −16
apex/contrib/csrc/fmha/src/fmha_fprop_fp16_512_64_kernel.sm80.cu   +86 −47
apex/contrib/csrc/fmha/src/fmha_fprop_kernel_1xN.h                +301 −106
apex/contrib/csrc/fmha/fmha_api.cpp

@@ -72,7 +72,7 @@ void set_params(Fused_multihead_attention_fprop_params &params,
     constexpr float scale_softmax = 1.f;
     constexpr float scale_bmm2 = 1.f;

-    set_alpha(params.scale_bmm1, scale_bmm1, acc_type);
+    set_alpha(params.scale_bmm1, scale_bmm1, data_type);
     set_alpha(params.scale_softmax, scale_softmax, acc_type);
     set_alpha(params.scale_bmm2, scale_bmm2, data_type);
@@ -83,16 +83,21 @@ void set_params(Fused_multihead_attention_fprop_params &params,
     set_alpha(params.scale_dropout, params.rp_dropout, data_type);
 }

 std::vector<at::Tensor>
 mha_fwd(const at::Tensor &qkv,         // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i
         const at::Tensor &cu_seqlens,  // b+1
         const float p_dropout,
         const int max_seq_len,
         const bool is_training,
+        const bool is_nl,
         const bool zero_tensors,
         c10::optional<at::Generator> gen_) {

     auto dprops = at::cuda::getCurrentDeviceProperties();
     TORCH_CHECK(dprops->major == 8 && dprops->minor == 0);
+    auto stream = at::cuda::getCurrentCUDAStream().stream();
+    Launch_params<Fused_multihead_attention_fprop_params> launch_params(dprops, stream, is_training, is_nl);

     int seq_len = 512;
     auto launch = &run_fmha_fp16_512_64_sm80;
     if( max_seq_len <= 128 ) {
@@ -111,18 +116,6 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \
         TORCH_CHECK(false);
     }

-    constexpr int warps_m = 1;
-    constexpr int warps_n = 4;  // this leads to an upper bound
-    const int mmas_m = seq_len / 16 / warps_m;
-    const int mmas_n = seq_len / 16 / warps_n;
-    const int elts_per_thread = 8 * mmas_m * mmas_n;
-
-    auto stream = at::cuda::getCurrentCUDAStream().stream();
-    TORCH_CHECK(qkv.dtype() == torch::kFloat16);
-    TORCH_CHECK(cu_seqlens.dtype() == torch::kInt32);

     TORCH_CHECK(qkv.is_cuda())
     TORCH_CHECK(cu_seqlens.is_cuda())
@@ -156,9 +149,8 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \
     auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
         gen_, at::cuda::detail::getDefaultCUDAGenerator());

-    Fused_multihead_attention_fprop_params params;
-    set_params(params,
+    set_params(launch_params.params,
                batch_size,
                seq_len,
                num_heads,
@@ -169,22 +161,24 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \
                s.data_ptr(),
                p_dropout);

+    launch(launch_params, /*configure=*/true);
-    // number of times random will be generated per thread, to offset philox counter in the random
+    // number of times random will be generated per thread, to offset philox counter in thc random
     // state
-    int64_t counter_offset = elts_per_thread;
+    int64_t counter_offset = launch_params.elts_per_thread;
     at::PhiloxCudaState rng_engine_inputs;

     if( is_training ) {
         // See Note [Acquire lock when using random generators]
         std::lock_guard<std::mutex> lock(gen->mutex_);
-        params.philox_args = gen->philox_cuda_state(counter_offset);
+        launch_params.params.philox_args = gen->philox_cuda_state(counter_offset);
     }

-    launch(params, is_training, stream);
+    launch(launch_params, /*configure=*/false);

     return { ctx, s };
 }

 std::vector<at::Tensor>
 mha_bwd(const at::Tensor &dout,  // total x num_heads, x head_size
         const at::Tensor &qkv,   // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i
@@ -270,92 +264,6 @@ mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size
     return { dqkv, softmax };
 }

-std::vector<at::Tensor>
-mha_fwd_nl(const at::Tensor &qkv,         // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i
-           const at::Tensor &cu_seqlens,  // b+1
-           const float p_dropout,
-           const int max_seq_len,
-           const bool is_training,
-           const bool zero_tensors,
-           c10::optional<at::Generator> gen_) {
-    int seq_len = 512;
-    auto launch = &run_fmha_fp16_512_64_sm80_nl;
-    TORCH_CHECK(max_seq_len == seq_len);
-
-    constexpr int warps_m = 1;
-    constexpr int warps_n = 4;  // this leads to an upper bound
-    const int mmas_m = seq_len / 16 / warps_m;
-    const int mmas_n = seq_len / 16 / warps_n;
-    // static_assert( mmas_m == 32 );
-    // static_assert( mmas_n == 4 );
-    const int elts_per_thread = 8 * mmas_m * mmas_n;
-
-    auto stream = at::cuda::getCurrentCUDAStream().stream();
-
-    TORCH_CHECK(qkv.is_cuda())
-    TORCH_CHECK(cu_seqlens.is_cuda())
-
-    TORCH_CHECK(qkv.is_contiguous())
-    TORCH_CHECK(cu_seqlens.is_contiguous())
-
-    TORCH_CHECK(cu_seqlens.dim() == 1);
-    TORCH_CHECK(qkv.dim() == 4);
-
-    const auto sizes = qkv.sizes();
-    TORCH_CHECK(sizes[THREE_DIM] == 3);
-
-    const int batch_size = cu_seqlens.numel() - 1;
-    const int total = sizes[TOTAL_DIM];
-    const int num_heads = sizes[H_DIM];
-    const int head_size = sizes[D_DIM];
-    TORCH_CHECK(batch_size > 0);
-    TORCH_CHECK(head_size == 64);
-
-    auto opts = qkv.options();
-    auto ctx = torch::empty({ total, num_heads, head_size }, opts);
-    auto s = torch::empty({ batch_size, num_heads, seq_len, seq_len }, opts);
-    if( zero_tensors ) {
-        ctx.zero_();
-        s.zero_();
-    }
-
-    auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
-        gen_, at::cuda::detail::getDefaultCUDAGenerator());
-
-    Fused_multihead_attention_fprop_params params;
-    set_params(params,
-               batch_size,
-               seq_len,
-               num_heads,
-               head_size,
-               qkv.data_ptr(),
-               cu_seqlens.data_ptr(),
-               ctx.data_ptr(),
-               s.data_ptr(),
-               p_dropout);
-
-    // number of times random will be generated per thread, to offset philox counter in the random
-    // state
-    int64_t counter_offset = elts_per_thread;
-    at::PhiloxCudaState rng_engine_inputs;
-
-    if( is_training ) {
-        // See Note [Acquire lock when using random generators]
-        std::lock_guard<std::mutex> lock(gen->mutex_);
-        params.philox_args = gen->philox_cuda_state(counter_offset);
-    }
-
-    int num_chunks = 3;
-    if( batch_size == 3 ) {
-        num_chunks = 2;
-    }
-
-    launch(params, is_training, num_chunks, stream);
-
-    return { ctx, s };
-}
-
 std::vector<at::Tensor>
 mha_bwd_nl(const at::Tensor &dout,  // total x num_heads, x head_size
            const at::Tensor &qkv,   // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i
            at::Tensor &softmax,     // b x h x s x s softmax and dmask - will be overwritten with dP
@@ -449,6 +357,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.doc() = "Fused Multi-head Self-attention for BERT";
     m.def("fwd", &mha_fwd, "Forward pass");
     m.def("bwd", &mha_bwd, "Backward pass");
-    m.def("fwd_nl", &mha_fwd_nl, "Forward pass (small-batch)");
     m.def("bwd_nl", &mha_bwd_nl, "Backward pass (small-batch)");
 }
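The visible shape of mha_fwd above is a two-phase calling convention: the runner is invoked once with /*configure=*/true, which only computes the work distribution and fills Launch_params (including elts_per_thread, used to offset the philox RNG counter), and once with /*configure=*/false to do the real kernel launch. Below is a self-contained toy sketch of that pattern; ToyLaunchParams, run_toy, and the numbers inside are illustrative stand-ins, not apex's Launch_params or fmha::work_dist.

#include <cstdint>
#include <cstdio>

struct ToyLaunchParams {
    int total_ctas = 0;       // filled in by the configure pass
    int elts_per_thread = 0;  // consumed by the caller to offset the RNG counter
};

void run_toy(ToyLaunchParams &lp, bool configure) {
    if (configure) {
        // Phase 1: size the grid (apex queries occupancy here) and record
        // how many random values each thread will draw. No kernel launch.
        lp.total_ctas = 108;        // e.g. one CTA per SM on an A100 (assumption)
        lp.elts_per_thread = 2048;  // hypothetical per-thread workload
        return;
    }
    // Phase 2: the real launch reuses the cached configuration.
    std::printf("launching %d CTAs, %d elts/thread\n", lp.total_ctas, lp.elts_per_thread);
}

int main() {
    ToyLaunchParams lp;
    run_toy(lp, /*configure=*/true);              // fill lp only
    int64_t counter_offset = lp.elts_per_thread;  // philox offset is now exact
    run_toy(lp, /*configure=*/false);             // actual launch
    std::printf("counter_offset = %lld\n", (long long)counter_offset);
    return 0;
}

The payoff of the split is that the RNG state advances by exactly the number of draws the configured launch will make, instead of a bound derived from fixed warp counts as in the deleted code.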
apex/contrib/csrc/fmha/src/fmha.h

@@ -50,7 +50,7 @@ constexpr int D_DIM = 3;
 struct Qkv_params {
     // The QKV matrices.
-    void *qkv_ptr;
+    void *__restrict__ qkv_ptr;

     // The stride between rows of the Q, K and V matrices.
     size_t qkv_stride_in_bytes;
@@ -64,19 +64,19 @@ struct Qkv_params {
 struct Fused_multihead_attention_fprop_params : public Qkv_params {

     // The dQKV matrices.
-    void *dqkv_ptr;
+    void *__restrict__ dqkv_ptr;

     // Temporary for dKV.
-    void *dkv_ptr;
+    void *__restrict__ dkv_ptr;

     // The O matrix (output).
-    void *o_ptr;
+    void *__restrict__ o_ptr;

     // The stride between rows of O.
     int64_t o_stride_in_bytes;

     // The pointer to the S matrix, overwritten by the dP matrix (bwd).
-    void *s_ptr;
+    void *__restrict__ s_ptr;

     // The stride between rows of the S matrix.
     int64_t s_stride_in_bytes;
@@ -87,7 +87,7 @@ struct Fused_multihead_attention_fprop_params : public Qkv_params {
     uint32_t scale_bmm1, scale_softmax, scale_bmm2;

     // array of length b+1 holding starting offset of each sequence.
-    int *cu_seqlens;
+    int *__restrict__ cu_seqlens;

     // The dropout probability (probability of keeping an activation).
     float p_dropout;
@@ -104,10 +104,43 @@ struct Fused_multihead_attention_fprop_params : public Qkv_params {

 ////////////////////////////////////////////////////////////////////////////////////////////////////

-void run_fmha_fp16_128_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream);
-void run_fmha_fp16_256_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream);
-void run_fmha_fp16_384_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream);
-void run_fmha_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream);
+template<typename Kernel_params>
+struct Launch_params {
+    Launch_params(cudaDeviceProp *props_,
+                  cudaStream_t stream_,
+                  bool is_training_,
+                  bool is_nl_)
+        : elts_per_thread(0)
+        , props(props_)
+        , stream(stream_)
+        , is_training(is_training_)
+        , is_nl(is_nl_) {
+    }
+
+    size_t elts_per_thread;
+
+    cudaDeviceProp *props;
+    cudaStream_t stream;
+    bool is_training;
+
+    Kernel_params params;
+    int num_full_heads;
+    int num_main_groups;
+    int heads_last_wave;
+    int main_steps;
+    int rest_steps;
+    bool is_nl;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+void run_fmha_fp16_128_64_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, const bool configure);
+void run_fmha_fp16_256_64_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, const bool configure);
+void run_fmha_fp16_384_64_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, const bool configure);
+void run_fmha_fp16_512_64_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, const bool configure);

 void run_fmha_dgrad_fp16_128_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream);
 void run_fmha_dgrad_fp16_256_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream);
apex/contrib/csrc/fmha/src/fmha/gemm.h

@@ -210,9 +210,6 @@ struct Clear_accumulator<float, WARPS_K> {
     }
 };

 ////////////////////////////////////////////////////////////////////////////////////////////////////

-////////////////////////////////////////////////////////////////////////////////////////////////////
-
 template<typename Acc, typename A, typename B, int M, int N>
apex/contrib/csrc/fmha/src/fmha/gmem_tile.h

@@ -60,7 +60,7 @@ struct Gmem_tile_qkv {
     // Ctor.
     template<typename Params, typename BInfo>
-    inline __device__ Gmem_tile_qkv(const Params &params, int qkv_offset, const BInfo &binfo, int tidx)
+    inline __device__ Gmem_tile_qkv(const Params &params, const int qkv_offset, const BInfo &binfo, const int tidx)
         : params_qkv_stride_in_bytes_(params.qkv_stride_in_bytes)
         , actual_seqlen(binfo.actual_seqlen)
         , qkv_ptr_(reinterpret_cast<char *>(params.qkv_ptr)) {
@@ -125,6 +125,11 @@ struct Gmem_tile_qkv {
         actual_seqlen -= ROWS;
     }

+    inline __device__ void move(int steps) {
+        qkv_ptr_ += (int64_t)ROWS * params_qkv_stride_in_bytes_ * steps;
+        actual_seqlen -= ROWS * steps;
+    }
+
     // The stride between rows for the QKV matrice.
     int64_t params_qkv_stride_in_bytes_;
     // The pointer.
@@ -224,6 +229,11 @@ struct Gmem_tile_o {
         o_ptr_ += (int64_t)ROWS * params_o_stride_in_bytes_;
     }

+    inline __device__ void move(const int steps) {
+        row_ += ROWS * steps;
+        o_ptr_ += (int64_t)ROWS * params_o_stride_in_bytes_ * steps;
+    }
+
     // The stride between rows for the QKV matrice.
     int64_t params_o_stride_in_bytes_;
     // The pointer.
@@ -270,13 +280,9 @@ struct Gmem_tile_mma_sd {
     // Ctor.
     template<typename Params>
-    inline __device__ Gmem_tile_mma_sd(void *ptr, const Params &params, const int tidx)
+    inline __device__ Gmem_tile_mma_sd(void *ptr, const Params &params, const int bidb, const int bidh, const int tidx)
         : ptr_(static_cast<char *>(ptr)) {

-        // The block index for the batch.
-        const int bidb = blockIdx.y;
-        // The block index for the head.
-        const int bidh = blockIdx.x;
         // The block index.
         size_t bidx = bidb * params.h + bidh;
@@ -300,6 +306,9 @@ struct Gmem_tile_mma_sd {
     inline __device__ void move() {
         ptr_ += LOOP_STRIDE_BYTES;
     }
+    inline __device__ void move(const int steps) {
+        ptr_ += LOOP_STRIDE_BYTES * steps;
+    }

     // The pointer in global memory.
     char *ptr_;
@@ -318,9 +327,9 @@ struct Gmem_tile_mma_s : public Base {
     using Type = typename Base::Type;

     // Ctor.
-    template<typename Params>
-    inline __device__ Gmem_tile_mma_s(void *ptr, const Params &params, const int tidx)
-        : Base(ptr, params, tidx) {
+    template<typename Params, typename Block_info>
+    inline __device__ Gmem_tile_mma_s(const Params &params, const Block_info &binfo, const int tidx)
+        : Base(params.s_ptr, params, binfo.bidb, binfo.bidh, tidx) {
     }

     // Store to global memory.
@@ -353,6 +362,25 @@ struct Gmem_tile_mma_s : public Base {
         }
     }

+    // Store to global memory.
+    template<typename Mask, typename Fragment>
+    inline __device__ void store(const Fragment (&frag)[N][M], const Mask &mask){
+        #pragma unroll
+        for( int mi = 0; mi < M; mi++ ) {
+            #pragma unroll
+            for( int ni = 0; ni < N; ni++ ) {
+                uint4 dst;
+                dst.x = frag[ni][mi].reg(0);
+                dst.y = frag[ni][mi].reg(2);
+                dst.z = frag[ni][mi].reg(1);
+                dst.w = frag[ni][mi].reg(3);
+                if( mask.any_valid(mi, ni) ) {
+                    Base::store(dst, mi, ni);
+                }
+            }
+        }
+    }
+
     // Load from global memory.
     template<typename Mask>
     inline __device__ void load(uint4 (&regs)[M][N], const Mask &mask) {
@@ -361,7 +389,7 @@ struct Gmem_tile_mma_s : public Base {
         #pragma unroll
         for( int ni = 0; ni < N; ni++ ) {
             regs[mi][ni] = make_uint4(0, 0, 0, 0);
-            if( mask.is_valid(mi, ni, 0, 0) ) {
+            if( mask.any_valid(mi, ni) ) {
                 Base::load(regs[mi][ni], mi, ni);
             }
         }
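The new move(steps) overloads above let the no-loop ("nl") path jump a tile loader several tiles at once instead of iterating single-step moves. A compilable toy model of the pointer arithmetic follows; ToyTile and its ROWS/stride values are invented for illustration and are not apex's tile classes.

#include <cassert>
#include <cstdint>

struct ToyTile {
    static constexpr int64_t ROWS = 16;   // rows consumed per step (assumption)
    int64_t stride_in_bytes;
    char *ptr;
    int actual_seqlen;

    // Single-step advance: the pre-existing API.
    void move() { ptr += ROWS * stride_in_bytes; actual_seqlen -= ROWS; }

    // Multi-step advance, as added in this commit: one call replaces
    // `steps` single moves, letting a CTA land directly on its chunk.
    void move(int steps) {
        ptr += ROWS * stride_in_bytes * steps;
        actual_seqlen -= ROWS * steps;
    }
};

int main() {
    char buf[16 * 128 * 8] = {};
    ToyTile a{128, buf, 512}, b{128, buf, 512};
    for (int i = 0; i < 3; ++i) a.move();  // three single steps...
    b.move(3);                             // ...equal one 3-step jump
    assert(a.ptr == b.ptr && a.actual_seqlen == b.actual_seqlen);
    return 0;
}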
apex/contrib/csrc/fmha/src/fmha/kernel_traits.h

@@ -29,7 +29,7 @@
 ////////////////////////////////////////////////////////////////////////////////////////////////////

-template<int S, int D, int STEP, int WARPS_M, int WARPS_N, uint32_t FLAGS = 0x8u>
+template<int S, int D, int STEP, int WARPS_M, int WARPS_N, uint32_t FLAGS = 0x08u>
 struct FMHA_kernel_traits {

     // The CTA description for the 1st GEMM.
@@ -38,7 +38,9 @@ struct FMHA_kernel_traits {
     using Cta_tile_o = fmha::Cta_tile_extd<STEP, D, S, WARPS_M, 1, WARPS_N>;

     // Do we use one buffer for K and V.
-    enum { SHARE_SMEM_FOR_K_AND_V = (FLAGS & 0x8u) != 0u };
+    enum { SHARE_SMEM_FOR_K_AND_V = (FLAGS & 0x08u) != 0u };
+    // Do we keep K in registers.
+    enum { K_IN_REGS = (FLAGS & 0x10u) == 0u };

     // The global memory tile to load Q.
     using Gmem_tile_q = fmha::Gmem_tile_qkv<Cta_tile_p, fmha::BITS_PER_ELEMENT_A, STEP, D>;
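The FLAGS word now carries two independent bits: 0x08 (share one shared-memory buffer for K and V, as before) and the new 0x10 (when set, K is reloaded rather than kept in registers). A small host-side probe of the decoding, using a stand-in traits struct rather than apex's FMHA_kernel_traits:

#include <cstdint>
#include <cstdio>

template<uint32_t FLAGS = 0x08u>
struct ToyTraits {
    // Bit 3: share one smem buffer for K and V.
    enum { SHARE_SMEM_FOR_K_AND_V = (FLAGS & 0x08u) != 0u };
    // Bit 4: when clear, K stays in registers (new in this commit).
    enum { K_IN_REGS = (FLAGS & 0x10u) == 0u };
};

int main() {
    // 0x08u: share smem for K/V, keep K in registers (the common config).
    std::printf("%d %d\n", ToyTraits<0x08u>::SHARE_SMEM_FOR_K_AND_V,
                           ToyTraits<0x08u>::K_IN_REGS);  // prints: 1 1
    // 0x18u (used by the 384 fprop kernel below): share smem, reload K.
    std::printf("%d %d\n", ToyTraits<0x18u>::SHARE_SMEM_FOR_K_AND_V,
                           ToyTraits<0x18u>::K_IN_REGS);  // prints: 1 0
    return 0;
}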
apex/contrib/csrc/fmha/src/fmha/mask.h

@@ -63,6 +63,11 @@ struct Mask {
         // return row_valid && col_valid;
     }

+    //BERT Mask: if upper left is invalid, none are valid
+    inline __device__ bool any_valid(int mi, int ni) const {
+        return is_valid(mi, ni, 0, 0);
+    }
+
     inline __device__ void load(int it) {
         row_offset = it * Cta_tile::M + row;
     }
apex/contrib/csrc/fmha/src/fmha/smem_tile.h

@@ -1266,8 +1266,6 @@ struct Smem_tile_mma_epilogue : public Base {
         }
     }

     template<int M, int N>
     inline __device__ void store(const uint4 (&regs)[M][N]) {
         for( int mi = 0; mi < M; mi++ ) {
apex/contrib/csrc/fmha/src/fmha/softmax.h

@@ -55,6 +55,88 @@ inline __device__ float apply_exp_(float x, float max) {
 ////////////////////////////////////////////////////////////////////////////////////////////////////

+template<int COLS> struct ReadType {};
+template<> struct ReadType<4> { using T = float; };
+template<> struct ReadType<8> { using T = float2; };
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename Cta_tile, typename Kernel_traits>
+struct Smem_tile_reduce {
+    // Helper class to distribute MMA tiles reduced over rows per warp over quads.
+
+    // The Mma tile.
+    using Mma_tile = fmha::Hmma_tile<Cta_tile>;
+
+    // The number of MMAs in M/N dimensions.
+    enum { MMAS_M = Mma_tile::MMAS_M };
+    enum { MMAS_N = Mma_tile::MMAS_N };
+
+    enum { WARPS_M = Cta_tile::WARPS_M };
+    enum { WARPS_N = Cta_tile::WARPS_N };
+
+    static constexpr int ROWS = WARPS_M * MMAS_M * 16;
+    static constexpr int COLS = WARPS_N;
+    static_assert(COLS == 4 || COLS == 8);
+    static constexpr int ROWS_PER_XOR_PATTERN = (COLS == 8) ? 4 : 8;
+    static constexpr int BYTES_PER_TILE = ROWS * COLS * sizeof(float);
+    static constexpr int ELTS_PER_TILE = ROWS * COLS;
+
+    static constexpr int THREADS_PER_GROUP = Kernel_traits::Gmem_tile_o::THREADS_PER_ROW;
+    static_assert(THREADS_PER_GROUP == 16); // DEBUG
+    static constexpr int ROWS_PER_WARP = 32 / THREADS_PER_GROUP;
+    static constexpr int LOOPS = Kernel_traits::Gmem_tile_o::LOOPS;
+    static_assert(LOOPS == 1);
+
+    using read_t = typename ReadType<COLS>::T;
+
+    __device__ inline Smem_tile_reduce(float *smem_, const int tidx) {
+
+        int lane = tidx % 32;
+        int warp = tidx / 32;
+
+        int warp_m = warp % WARPS_M;
+        int warp_n = warp / WARPS_M;
+
+        qid_ = lane % 4;
+        int qp = lane / 4;
+
+        // Swizzle the column to avoid 2-fold bank conflicts when we have 8 warps.
+        // This won't affect reading as we assume commutative reduction ops.
+        const int col = warp_n ^ (qp / ROWS_PER_XOR_PATTERN);
+        smem_write_ = &smem_[warp_m * 16 * MMAS_M * WARPS_N + qp * WARPS_N + col];
+        smem_read_ = &reinterpret_cast<read_t *>(smem_)[warp_m * 16 * MMAS_M * 4 + qp * 4 + qid_];
+    }
+
+    __device__ inline void store(float (&frag)[2 * MMAS_M]) {
+        if( qid_ == 0 ) {
+            #pragma unroll
+            for( int mi = 0; mi < MMAS_M; mi++ ) {
+                int offset = mi * 16 * WARPS_N;
+                smem_write_[offset + 0 * 8 * WARPS_N] = frag[mi * 2 + 0];
+                smem_write_[offset + 1 * 8 * WARPS_N] = frag[mi * 2 + 1];
+            }
+        }
+    }
+
+    __device__ inline void load(read_t (&frag)[2 * MMAS_M]) {
+        #pragma unroll
+        for( int mi = 0; mi < MMAS_M; mi++ ) {
+            int offset = mi * 16 * 4;
+            frag[mi * 2 + 0] = smem_read_[offset + 0 * 8 * 4];
+            frag[mi * 2 + 1] = smem_read_[offset + 1 * 8 * 4];
+        }
+    }
+
+    int qid_;
+    float *smem_write_;
+    read_t *smem_read_;
+};
+
 template<typename Cta_tile, typename Kernel_traits>
 struct Softmax_base {
@@ -136,201 +218,6 @@ struct Softmax_base {
         }
     }

-    // Do a CTA-wide reduction.
-    template<typename Functor>
-    inline __device__ void reduce_1x4(float (&dst)[MMAS_M * 2]) {
-#if defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE)
-        if( Functor::IS_SUM ) {
-            // Apply the summation inside the thread.
-            float tmp[MMAS_M * 2][2];
-            #pragma unroll
-            for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
-                tmp[mi][0] = 0.f;
-                tmp[mi][1] = 0.f;
-                #pragma unroll
-                for( int ni = 0; ni < MMAS_N; ++ni ) {
-                    tmp[mi][0] += elt_[mi][4 * ni + 0];
-                    tmp[mi][0] += elt_[mi][4 * ni + 1];
-                    tmp[mi][1] += elt_[mi][4 * ni + 2];
-                    tmp[mi][1] += elt_[mi][4 * ni + 3];
-                }
-                dst[mi] = tmp[mi][0] + tmp[mi][1];
-            }
-        } else
-#endif  // defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE)
-        {
-            // Apply the functor for each row inside a thread.
-            #pragma unroll
-            for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
-                dst[mi] = elt_[mi][0];
-                #pragma unroll
-                for( int ni = 1; ni < MMAS_N * 4; ++ni ) {
-                    dst[mi] = Functor::apply(dst[mi], elt_[mi][ni]);
-                }
-            }
-        }
-
-        // Apply the functor for each row inside each group of 4 threads.
-        #pragma unroll
-        for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
-            dst[mi] = Functor::apply(dst[mi], __shfl_xor_sync(uint32_t(-1), dst[mi], 1));
-            __syncwarp();
-            dst[mi] = Functor::apply(dst[mi], __shfl_xor_sync(uint32_t(-1), dst[mi], 2));
-            __syncwarp();
-        }
-
-        // Store the different values.
-        #pragma unroll
-        for( int mi = 0; mi < MMAS_M; ++mi ) {
-            if( tidx_ % 4 == 0 ) {
-                smem_write_[(mi * Mma_tile::M_PER_MMA_PER_CTA + 0) * ELEMENTS_PER_ROW] = dst[2 * mi + 0];
-                smem_write_[(mi * Mma_tile::M_PER_MMA_PER_CTA + 8) * ELEMENTS_PER_ROW] = dst[2 * mi + 1];
-            }
-        }
-
-        // Make sure the values are in shared memory.
-        __syncthreads();
-
-        // Load 8 values (one for each warp). The /8 corresponds to /(4*2) where 4 is from the float4.
-        float4 tmp[1];
-        if( tidx_ < Cta_tile::M ) {
-            tmp[0] = reinterpret_cast<const float4 *>(&smem_[0 * ELEMENTS / 2])[tidx_];
-        }
-
-        // Compute the reduction of those 8 values in a binary-tree fashion.
-        tmp[0].x = Functor::apply(tmp[0].x, tmp[0].y);
-        tmp[0].z = Functor::apply(tmp[0].z, tmp[0].w);
-        tmp[0].x = Functor::apply(tmp[0].x, tmp[0].z);
-
-        // Make sure we can write to shared memory.
-        __syncthreads();
-
-        // Store the value back to shared memory.
-        if( tidx_ < Cta_tile::M ) {
-            smem_[tidx_] = tmp[0].x;
-        }
-
-        // Make sure the data is in shared memory.
-        __syncthreads();
-
-        // Finally read the values.
-        #pragma unroll
-        for( int mi = 0; mi < MMAS_M; ++mi ) {
-            dst[2 * mi + 0] = smem_read_[mi * Mma_tile::M_PER_MMA_PER_CTA + 0];
-            dst[2 * mi + 1] = smem_read_[mi * Mma_tile::M_PER_MMA_PER_CTA + 8];
-        }
-    }
-
-    // Do a CTA-wide reduction.
-    template<typename Functor>
-    inline __device__ void reduce_1x8(float (&dst)[MMAS_M * 2]) {
-#if defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE)
-        if( Functor::IS_SUM ) {
-            // Apply the summation inside the thread.
-            float tmp[MMAS_M * 2][2];
-            #pragma unroll
-            for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
-                tmp[mi][0] = 0.f;
-                tmp[mi][1] = 0.f;
-                #pragma unroll
-                for( int ni = 0; ni < MMAS_N; ++ni ) {
-                    tmp[mi][0] += elt_[mi][4 * ni + 0];
-                    tmp[mi][0] += elt_[mi][4 * ni + 1];
-                    tmp[mi][1] += elt_[mi][4 * ni + 2];
-                    tmp[mi][1] += elt_[mi][4 * ni + 3];
-                }
-                dst[mi] = tmp[mi][0] + tmp[mi][1];
-            }
-        } else
-#endif  // defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE)
-        {
-            // Apply the functor for each row inside a thread.
-            #pragma unroll
-            for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
-                dst[mi] = elt_[mi][0];
-                #pragma unroll
-                for( int ni = 1; ni < MMAS_N * 4; ++ni ) {
-                    dst[mi] = Functor::apply(dst[mi], elt_[mi][ni]);
-                }
-            }
-        }
-
-        // Apply the functor for each row inside each group of 4 threads.
-        #pragma unroll
-        for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
-            dst[mi] = Functor::apply(dst[mi], __shfl_xor_sync(uint32_t(-1), dst[mi], 1));
-            __syncwarp();
-            dst[mi] = Functor::apply(dst[mi], __shfl_xor_sync(uint32_t(-1), dst[mi], 2));
-            __syncwarp();
-        }
-
-        // Store the different values.
-        #pragma unroll
-        for( int mi = 0; mi < MMAS_M; ++mi ) {
-            if( tidx_ % 4 == 0 ) {
-                smem_write_[(mi * Mma_tile::M_PER_MMA_PER_CTA + 0) * ELEMENTS_PER_ROW] = dst[2 * mi + 0];
-                smem_write_[(mi * Mma_tile::M_PER_MMA_PER_CTA + 8) * ELEMENTS_PER_ROW] = dst[2 * mi + 1];
-            }
-        }
-
-        // Make sure the values are in shared memory.
-        __syncthreads();
-
-        // Load 8 values (one for each warp). The /8 corresponds to /(4*2) where 4 is from the float4.
-        float4 tmp[2];
-        if( tidx_ < Cta_tile::M ) {
-            tmp[0] = reinterpret_cast<const float4 *>(&smem_[0 * ELEMENTS / 2])[tidx_];
-            tmp[1] = reinterpret_cast<const float4 *>(&smem_[1 * ELEMENTS / 2])[tidx_];
-        }
-
-        // Compute the reduction of those 8 values in a binary-tree fashion.
-        tmp[0].x = Functor::apply(tmp[0].x, tmp[0].y);
-        tmp[0].z = Functor::apply(tmp[0].z, tmp[0].w);
-        tmp[1].x = Functor::apply(tmp[1].x, tmp[1].y);
-        tmp[1].z = Functor::apply(tmp[1].z, tmp[1].w);
-        tmp[0].x = Functor::apply(tmp[0].x, tmp[0].z);
-        tmp[1].x = Functor::apply(tmp[1].x, tmp[1].z);
-        tmp[0].x = Functor::apply(tmp[0].x, tmp[1].x);
-
-        // Make sure we can write to shared memory.
-        __syncthreads();
-
-        // Store the value back to shared memory.
-        if( tidx_ < Cta_tile::M ) {
-            smem_[tidx_] = tmp[0].x;
-        }
-
-        // Make sure the data is in shared memory.
-        __syncthreads();
-
-        // Finally read the values.
-        #pragma unroll
-        for( int mi = 0; mi < MMAS_M; ++mi ) {
-            dst[2 * mi + 0] = smem_read_[mi * Mma_tile::M_PER_MMA_PER_CTA + 0];
-            dst[2 * mi + 1] = smem_read_[mi * Mma_tile::M_PER_MMA_PER_CTA + 8];
-        }
-    }
-
-    // Do a CTA-wide reduction.
-    template<typename Functor>
-    inline __device__ void reduce(float (&dst)[MMAS_M * 2]) {
-        static_assert(Cta_tile::WARPS_M == 1 && (Cta_tile::WARPS_N == 4 || Cta_tile::WARPS_N == 8));
-        if( Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 4 ) {
-            reduce_1x4<Functor>(dst);
-        } else if( Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 8 ) {
-            reduce_1x8<Functor>(dst);
-        } else {
-            assert(false);
-        }
-
-        // Make sure we are done reading from shared memory.
-        __syncthreads();
-    }
-
     // Scale all the elements.
     inline __device__ void scale(const float (&sum)[MMAS_M * 2]) {
         // Precompute the inverse sum to normalize. Without -use_fast_math, it makes a huge deal.
@@ -372,6 +259,8 @@ struct Softmax : public Softmax_base<Cta_tile, Kernel_traits> {
     static_assert(Fragment_a::NUM_REGS == 4);

+    enum { WARPS_M = Cta_tile::WARPS_M };
+    enum { WARPS_N = Cta_tile::WARPS_N };
     // The MMAs.
     enum { MMAS_M = Base::MMAS_M };
     enum { MMAS_N = Base::MMAS_N };
@@ -383,41 +272,15 @@ struct Softmax : public Softmax_base<Cta_tile, Kernel_traits> {
     static_assert(std::is_same<Accumulator::Data_type, float>::value);

+    using Smem_tile_red = Smem_tile_reduce<Cta_tile, Kernel_traits>;
+    static_assert(Smem_tile_red::ELTS_PER_TILE == Cta_tile::M * WARPS_N);
     // Ctor.
     template<typename Params>
     inline __device__ Softmax(const Params &params, void *smem, int bidb, int tidx)
-        : Base(params, smem, bidb, tidx), params_scale_bmm1_(params.scale_bmm1) {
+        : Base(params, smem, bidb, tidx)
+        , params_scale_bmm1_(params.scale_bmm1)
+        , smem_sum_(static_cast<float *>(smem), tidx)
+        , smem_max_(static_cast<float *>(smem) + Smem_tile_red::ELTS_PER_TILE, tidx) {
     }

-    // Store the tile after softmax.
-    template<typename Gmem_tile>
-    inline __device__ void store(Gmem_tile &gmem_tile) {
-        Accumulator_out acc[MMAS_M][MMAS_N];
-        #pragma unroll
-        for( int mi = 0; mi < MMAS_M; ++mi ) {
-            #pragma unroll
-            for( int ni = 0; ni < MMAS_N; ++ni ) {
-                // The elements.
-                float tmp_00 = this->elt_[2 * mi + 0][4 * ni + 0];
-                float tmp_01 = this->elt_[2 * mi + 0][4 * ni + 1];
-                float tmp_02 = this->elt_[2 * mi + 0][4 * ni + 2];
-                float tmp_03 = this->elt_[2 * mi + 0][4 * ni + 3];
-                float tmp_10 = this->elt_[2 * mi + 1][4 * ni + 0];
-                float tmp_11 = this->elt_[2 * mi + 1][4 * ni + 1];
-                float tmp_12 = this->elt_[2 * mi + 1][4 * ni + 2];
-                float tmp_13 = this->elt_[2 * mi + 1][4 * ni + 3];
-                // Transform to accumulators.
-                acc[mi][ni].reg(0) = fmha::float2_to_half2(tmp_00, tmp_01);
-                acc[mi][ni].reg(1) = fmha::float2_to_half2(tmp_10, tmp_11);
-                acc[mi][ni].reg(2) = fmha::float2_to_half2(tmp_02, tmp_03);
-                acc[mi][ni].reg(3) = fmha::float2_to_half2(tmp_12, tmp_13);
-            }
-        }
-        // Delegate to the gmem tile to store.
-        gmem_tile.store(acc);
-    }
-
     // Pack the data to a fragment for the next GEMM.
@@ -470,7 +333,61 @@ struct Softmax : public Softmax_base<Cta_tile, Kernel_traits> {
         }
     }
 }

+    // Scale FP32 fragments
+    inline __device__ void unpack_noscale(const Accumulator (&acc)[MMAS_M][MMAS_N]) {
+        #pragma unroll
+        for( int mi = 0; mi < MMAS_M; ++mi ) {
+            #pragma unroll
+            for( int ni = 0; ni < MMAS_N; ++ni ) {
+                // 1st row - 4 elements per row.
+                this->elt_[2 * mi + 0][4 * ni + 0] = acc[mi][ni].elt(0);
+                this->elt_[2 * mi + 0][4 * ni + 1] = acc[mi][ni].elt(1);
+                this->elt_[2 * mi + 0][4 * ni + 2] = acc[mi][ni].elt(4);
+                this->elt_[2 * mi + 0][4 * ni + 3] = acc[mi][ni].elt(5);
+                // 2nd row - 4 elements per row.
+                this->elt_[2 * mi + 1][4 * ni + 0] = acc[mi][ni].elt(2);
+                this->elt_[2 * mi + 1][4 * ni + 1] = acc[mi][ni].elt(3);
+                this->elt_[2 * mi + 1][4 * ni + 2] = acc[mi][ni].elt(6);
+                this->elt_[2 * mi + 1][4 * ni + 3] = acc[mi][ni].elt(7);
+            }
+        }
+    }
+
+    template<typename Operator>
+    __device__ inline void reduce_(float (&frag)[2 * MMAS_M], Operator &op, Smem_tile_red &smem_red) {
+        for( int mi = 0; mi < 2 * MMAS_M; mi++ ) {
+            frag[mi] = this->elt_[mi][0];
+            for( int ni = 1; ni < 4 * MMAS_N; ni++ ) {
+                frag[mi] = op(frag[mi], this->elt_[mi][ni]);
+            }
+        }
+        quad_reduce(frag, frag, op);
+        smem_red.store(frag);
+        __syncthreads();
+        typename Smem_tile_red::read_t tmp[2 * MMAS_M];
+        smem_red.load(tmp);
+        quad_allreduce(frag, tmp, op);
+    }
+
+    __device__ inline void reduce_max(float (&frag)[2 * MMAS_M]){
+        MaxOp<float> max;
+        reduce_(frag, max, smem_max_);
+    }
+
+    __device__ inline void reduce_sum(float (&frag)[2 * MMAS_M]){
+        SumOp<float> sum;
+        reduce_(frag, sum, smem_sum_);
+    }
+
     const uint32_t params_scale_bmm1_;
+    Smem_tile_red smem_max_;
+    Smem_tile_red smem_sum_;
 };

 ////////////////////////////////////////////////////////////////////////////////////////////////////
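One subtle line in the new Smem_tile_reduce is the write-column swizzle col = warp_n ^ (qp / ROWS_PER_XOR_PATTERN), which staggers the columns that different quad-pair rows write so that eight warps do not hit two-fold shared-memory bank conflicts. A host-side visualization of the resulting column pattern; this only prints the indexing, it is not the real tile:

#include <cstdio>

int main() {
    const int WARPS_N = 8, ROWS_PER_XOR_PATTERN = 4;  // the COLS == 8 case
    for (int qp = 0; qp < 8; ++qp) {                  // quad-pair row (lane / 4)
        for (int warp_n = 0; warp_n < WARPS_N; ++warp_n) {
            int col = warp_n ^ (qp / ROWS_PER_XOR_PATTERN);
            std::printf("%d ", col);                  // column warp_n writes
        }
        std::printf("\n");  // rows 0-3 keep identity; rows 4-7 swap column pairs
    }
    return 0;
}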
apex/contrib/csrc/fmha/src/fmha/utils.h

@@ -950,4 +950,89 @@ inline __device__ void sts(uint32_t (&ptrs)[N], const uint4 (&data)[N]) {
 ////////////////////////////////////////////////////////////////////////////////////////////////////

+template<typename T>
+struct MaxOp {
+    __device__ inline T operator()(T const &x, T const &y) { return x > y ? x : y; }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename T>
+struct SumOp {
+    __device__ inline T operator()(T const &x, T const &y) { return x + y; }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<int THREADS>
+struct Allreduce {
+    static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
+    template<typename T, typename Operator>
+    static __device__ inline T run(T x, Operator &op) {
+        constexpr int OFFSET = THREADS / 2;
+        x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
+        return Allreduce<OFFSET>::run(x, op);
+    }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<>
+struct Allreduce<2> {
+    template<typename T, typename Operator>
+    static __device__ inline T run(T x, Operator &op) {
+        x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
+        return x;
+    }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename Operator, int M>
+__device__ inline void quad_reduce(float (&dst)[M], float (&src)[M], Operator &op) {
+    #pragma unroll
+    for( int mi = 0; mi < M; mi++ ){
+        dst[mi] = src[mi];
+        dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 2));
+        dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 1));
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename Operator, int M>
+__device__ inline void quad_reduce(float (&dst)[M], float2 (&src)[M], Operator &op) {
+    float tmp[M];
+    #pragma unroll
+    for( int mi = 0; mi < M; mi++ ){
+        tmp[mi] = op(src[mi].x, src[mi].y);
+    }
+    quad_reduce(dst, tmp, op);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename Operator, int M>
+__device__ inline void quad_allreduce(float (&dst)[M], float (&src)[M], Operator &op) {
+    #pragma unroll
+    for( int mi = 0; mi < M; mi++ ){
+        dst[mi] = src[mi];
+        dst[mi] = Allreduce<4>::run(dst[mi], op);
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename Operator, int M>
+__device__ inline void quad_allreduce(float (&dst)[M], float2 (&src)[M], Operator &op) {
+    float tmp[M];
+    #pragma unroll
+    for( int mi = 0; mi < M; mi++ ){
+        tmp[mi] = op(src[mi].x, src[mi].y);
+    }
+    quad_allreduce(dst, tmp, op);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
 }  // namespace fmha
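The Allreduce and quad_* helpers above implement butterfly reductions over the four threads of a quad with warp shuffles; reduce_max/reduce_sum in softmax.h are built on them. A standalone CUDA sketch of the same dataflow follows, with a simplified max operator and a hand-unrolled quad allreduce re-declared locally so the file compiles on its own; it is not apex's exact templates.

#include <cstdint>
#include <cstdio>

struct MaxOpF { __device__ float operator()(float x, float y) { return x > y ? x : y; } };

// After the two XOR shuffles, all 4 threads of a quad hold the max of
// their 4 inputs (the same butterfly Allreduce<4> performs recursively).
__device__ float quad_allreduce_max(float x) {
    MaxOpF op;
    x = op(x, __shfl_xor_sync(uint32_t(-1), x, 2));
    x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
    return x;
}

__global__ void quad_max_kernel(const float *in, float *out) {
    float v = in[threadIdx.x];
    out[threadIdx.x] = quad_allreduce_max(v);  // identical across each quad
}

int main() {
    float h_in[32], h_out[32];
    for (int i = 0; i < 32; ++i) h_in[i] = (float)((i * 7) % 13);
    float *d_in, *d_out;
    cudaMalloc(&d_in, sizeof(h_in));
    cudaMalloc(&d_out, sizeof(h_out));
    cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
    quad_max_kernel<<<1, 32>>>(d_in, d_out);
    cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
    for (int i = 0; i < 32; i += 4)
        std::printf("quad %d: max = %g\n", i / 4, h_out[i]);
    cudaFree(d_in);
    cudaFree(d_out);
    return 0;
}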
apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_128_64_kernel.sm80.cu

@@ -28,7 +28,7 @@
 #include "fmha.h"
 #include "fmha_dgrad_kernel_1xN_reload.h"

-using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x8u>;
+using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;

 extern "C" __global__ void fmha_dgrad_fp16_128_64_sm80_kernel(Fused_multihead_attention_fprop_params params) {
     fmha::compute_dv_1xN<Kernel_traits>(params);
apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_256_64_kernel.sm80.cu

@@ -28,7 +28,7 @@
 #include "fmha.h"
 #include "fmha_dgrad_kernel_1xN_reload.h"

-using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x8u>;
+using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;

 extern "C" __global__ void fmha_dgrad_fp16_256_64_sm80_kernel(Fused_multihead_attention_fprop_params params) {
     fmha::compute_dv_1xN<Kernel_traits>(params);
apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_384_64_kernel.sm80.cu

@@ -28,7 +28,7 @@
 #include "fmha.h"
 #include "fmha_dgrad_kernel_1xN_reload.h"

-using Kernel_traits = FMHA_kernel_traits<384, 64, 16, 1, 8, 0x8u>;
+using Kernel_traits = FMHA_kernel_traits<384, 64, 16, 1, 8, 0x08u>;

 extern "C" __global__ void fmha_dgrad_fp16_384_64_sm80_kernel(Fused_multihead_attention_fprop_params params) {
     fmha::compute_dv_1xN<Kernel_traits>(params);
apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_512_64_kernel.sm80.cu

@@ -29,7 +29,7 @@
 #include "fmha_dgrad_kernel_1xN_reload.h"
 #include "fmha_dgrad_kernel_1xN_reload_nl.h"

-using Kernel_traits = FMHA_kernel_traits<512, 64, 16, 1, 8, 0x8u>;
+using Kernel_traits = FMHA_kernel_traits<512, 64, 16, 1, 8, 0x08u>;

 extern "C" __global__ void fmha_dgrad_fp16_512_64_sm80_kernel(Fused_multihead_attention_fprop_params params) {
     fmha::compute_dv_1xN<Kernel_traits>(params);
apex/contrib/csrc/fmha/src/fmha_dgrad_kernel_1xN_reload.h

@@ -141,7 +141,7 @@ inline __device__ void compute_dv_1xN(const Params &params) {
     enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 };

-    Gmem_tile_s gmem_s(params.s_ptr, params, tidx);
+    Gmem_tile_s gmem_s(params, binfo, tidx);

     // Create the object to do the softmax.
     using Softmax = fmha::Softmax<Cta_tile_p, Kernel_traits>;
@@ -231,7 +231,7 @@ inline __device__ void compute_dv_1xN(const Params &params) {
     }

     float p_sum[2 * M];
-    softmax.template reduce<fmha::Sum_>(p_sum);
+    softmax.reduce_sum(p_sum);
     const float scalef = reinterpret_cast<const float &>(params.scale_softmax);

     #pragma unroll
@@ -406,7 +406,7 @@ inline __device__ void compute_dq_dk_1xN(const Params &params) {
     // Trigger the loads for K.
     gmem_k.load(smem_k);

-    Gmem_tile_s gmem_s(params.s_ptr, params, tidx);
+    Gmem_tile_s gmem_s(params, binfo, tidx);

     // Load dP
     uint4 s_regs[M][N];
     gmem_s.load(s_regs, mask);
apex/contrib/csrc/fmha/src/fmha_dgrad_kernel_1xN_reload_nl.h

@@ -114,11 +114,11 @@ inline __device__ void compute_dv_1xN_nl(const Params &params) {
     // Allocate the shared memory tile loader for K.
     Smem_tile_k smem_k(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx);

-    Gmem_tile_s gmem_s(params.s_ptr, params, tidx);
+    Gmem_tile_s gmem_s(params, binfo, tidx);

     using Noloop = Noloop_traits<CHUNKS, Cta_tile_p>;

-    Noloop nl_traits(bidc);
+    Noloop nl_traits(bidc, binfo);

     nl_traits.move_all(gmem_q, gmem_s);

     // Trigger the loads for Q.
@@ -163,8 +163,6 @@ inline __device__ void compute_dv_1xN_nl(const Params &params) {
     // Load over the entire sequence length.
     for( int l = 0; l < nl_traits.num_steps_; l++ ) {
-        const int loop = nl_traits.offset_loop_count(l);
-        if( loop >= binfo.actual_seqlen ) break;

         uint4 s_regs[M][N];
         gmem_s.load(s_regs, mask);
@@ -230,7 +228,7 @@ inline __device__ void compute_dv_1xN_nl(const Params &params) {
     }

     float p_sum[2 * M];
-    softmax.template reduce<fmha::Sum_>(p_sum);
+    softmax.reduce_sum(p_sum);
     const float scalef = reinterpret_cast<const float &>(params.scale_softmax);

     #pragma unroll
@@ -400,7 +398,7 @@ inline __device__ void compute_dq_dk_1xN_nl(const Params &params) {
     // Allocate the shared memory tile loader for Q (as B).
     Smem_tile_qt smem_qt(&smem_[0], tidx);
     // Allocate the global memory tile loader for dP.
-    Gmem_tile_s gmem_s(params.s_ptr, params, tidx);
+    Gmem_tile_s gmem_s(params, binfo, tidx);
     // Allocate the shared memory tile loader for dP.
     Smem_tile_st smem_s(&smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE], tidx);
@@ -414,7 +412,7 @@ inline __device__ void compute_dq_dk_1xN_nl(const Params &params) {
     // Allocate the shared memory tile loader for O. We use the same as K so be careful!!!
     Smem_tile_o smem_o(&smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE], tidx);

-    Noloop nl_traits(bidc);
+    Noloop nl_traits(bidc, binfo);

     nl_traits.move_all(gmem_q, gmem_o, gmem_s);
apex/contrib/csrc/fmha/src/fmha_fprop_fp16_128_64_kernel.sm80.cu

@@ -28,31 +28,57 @@
 #include "fmha.h"
 #include "fmha_fprop_kernel_1xN.h"

 using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;

-extern "C" __global__ void fmha_fprop_fp16_128_64_sm80_train_kernel(Fused_multihead_attention_fprop_params params) {
-    fmha::device_1xN<Kernel_traits, true>(params);
-}
-
-extern "C" __global__ void fmha_fprop_fp16_128_64_sm80_predict_kernel(Fused_multihead_attention_fprop_params params) {
-    fmha::device_1xN<Kernel_traits, false>(params);
-}
+template<bool Is_training>
+__global__ void fmha_fprop_fp16_128_64_sm80_kernel(
+    Fused_multihead_attention_fprop_params params, const int num_full_heads, const int num_main_groups,
+    const int main_group_size, const int main_steps, const int rest_steps) {
+    fmha::device_1xN<Kernel_traits, Is_training>(
+        params, num_full_heads, num_main_groups, main_group_size, main_steps, rest_steps);
+}

-void run_fmha_fp16_128_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream) {
-    auto kernel = is_training ? &fmha_fprop_fp16_128_64_sm80_train_kernel : &fmha_fprop_fp16_128_64_sm80_predict_kernel;
-    constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
-    constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
-    constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
-    constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;
-    constexpr int smem_size = smem_size_q + std::max(smem_size_v, smem_size_o + smem_size_softmax);
+void run_fmha_fp16_128_64_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, const bool configure) {
+    auto kernel = launch_params.is_training ? &fmha_fprop_fp16_128_64_sm80_kernel<true>
+                                            : &fmha_fprop_fp16_128_64_sm80_kernel<false>;
+    constexpr int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>();
     if( smem_size >= 48 * 1024 ) {
         FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
     }
-    dim3 grid(params.h, params.b);
-    kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
+    const int sm_count = launch_params.props->multiProcessorCount;
+    int ctas_per_sm;
+    FMHA_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size));
+    int total_ctas = sm_count * ctas_per_sm;
+
+    if( configure ) {
+        const int heads_total = launch_params.params.b * launch_params.params.h;
+        std::tie(launch_params.num_full_heads,
+                 launch_params.num_main_groups,
+                 launch_params.heads_last_wave,
+                 launch_params.main_steps,
+                 launch_params.rest_steps,
+                 launch_params.elts_per_thread) = fmha::work_dist<Kernel_traits>(total_ctas, heads_total);
+        return;
+    }
+
+    dim3 grid(total_ctas);
+    kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(
+        launch_params.params,
+        launch_params.num_full_heads,
+        launch_params.num_main_groups,
+        launch_params.heads_last_wave,
+        launch_params.main_steps,
+        launch_params.rest_steps);
+    FMHA_CHECK_CUDA(cudaPeekAtLastError());
 }
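The rewritten runner sizes a persistent grid from occupancy instead of launching one CTA per (head, batch) pair; the configure pass then splits batch*heads of work over that fixed CTA count. A minimal standalone sketch of the sizing step; dummy_kernel and the thread count stand in for the templated fmha kernel:

#include <cstdio>

__global__ void dummy_kernel() {}

int main() {
    cudaDeviceProp props;
    cudaGetDeviceProperties(&props, 0);
    int ctas_per_sm = 0;
    // Same query the runner makes: max co-resident CTAs per SM for this
    // kernel at 256 threads and 0 bytes of dynamic shared memory.
    cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, dummy_kernel, 256, 0);
    int total_ctas = props.multiProcessorCount * ctas_per_sm;
    // The configure pass would now distribute batch*heads work over total_ctas.
    std::printf("%d SMs x %d CTAs/SM = %d persistent CTAs\n",
                props.multiProcessorCount, ctas_per_sm, total_ctas);
    return 0;
}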
apex/contrib/csrc/fmha/src/fmha_fprop_fp16_256_64_kernel.sm80.cu
View file @ 3c88451a
...
@@ -28,31 +28,57 @@
 #include "fmha.h"
 #include "fmha_fprop_kernel_1xN.h"
 using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
-extern "C" __global__ void fmha_fprop_fp16_256_64_sm80_train_kernel(Fused_multihead_attention_fprop_params params) {
-    fmha::device_1xN<Kernel_traits, true>(params);
-}
-extern "C" __global__ void fmha_fprop_fp16_256_64_sm80_predict_kernel(Fused_multihead_attention_fprop_params params) {
-    fmha::device_1xN<Kernel_traits, false>(params);
-}
+template<bool Is_training>
+__global__
+void fmha_fprop_fp16_256_64_sm80_kernel(Fused_multihead_attention_fprop_params params,
+                                        const int num_full_heads,
+                                        const int num_main_groups,
+                                        const int main_group_size,
+                                        const int main_steps,
+                                        const int rest_steps) {
+    fmha::device_1xN<Kernel_traits, Is_training>(
+        params, num_full_heads, num_main_groups, main_group_size, main_steps, rest_steps);
+}
-void run_fmha_fp16_256_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream) {
-    auto kernel = is_training ? &fmha_fprop_fp16_256_64_sm80_train_kernel : &fmha_fprop_fp16_256_64_sm80_predict_kernel;
-    constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
-    constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
-    constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
-    constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;
-    constexpr int smem_size = smem_size_q + std::max(smem_size_v, smem_size_o + smem_size_softmax);
+void run_fmha_fp16_256_64_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, const bool configure) {
+    auto kernel = launch_params.is_training ? &fmha_fprop_fp16_256_64_sm80_kernel<true>
+                                            : &fmha_fprop_fp16_256_64_sm80_kernel<false>;
+    constexpr int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>();
     if( smem_size >= 48 * 1024 ) {
         FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
     }
-    dim3 grid(params.h, params.b);
-    kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
+    const int sm_count = launch_params.props->multiProcessorCount;
+    int ctas_per_sm;
+    FMHA_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size));
+    int total_ctas = sm_count * ctas_per_sm;
+    if( configure ) {
+        const int heads_total = launch_params.params.b * launch_params.params.h;
+        std::tie(launch_params.num_full_heads,
+                 launch_params.num_main_groups,
+                 launch_params.heads_last_wave,
+                 launch_params.main_steps,
+                 launch_params.rest_steps,
+                 launch_params.elts_per_thread) = fmha::work_dist<Kernel_traits>(total_ctas, heads_total);
+        return;
+    }
+    dim3 grid(total_ctas);
+    kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(
+        launch_params.params, launch_params.num_full_heads, launch_params.num_main_groups,
+        launch_params.heads_last_wave, launch_params.main_steps, launch_params.rest_steps);
+    FMHA_CHECK_CUDA(cudaPeekAtLastError());
 }
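Every launcher in this commit now gets its dynamic shared-memory size from a single helper, fmha::get_dynamic_smem_size<Kernel_traits>(), instead of repeating the per-tile arithmetic inline. Judging from the expression it replaces in this file, a plausible reconstruction (the actual helper was added to fmha/utils.h in this commit and may differ) is:

// Inferred from the deleted arithmetic: Q stays resident for the whole
// tile, while the V tile shares its storage with the O tile plus the
// per-row softmax scratch, so only the larger of the two is counted.
template<typename Kernel_traits>
constexpr int get_dynamic_smem_size() {
    constexpr int smem_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
    constexpr int smem_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
    constexpr int smem_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
    constexpr int smem_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;
    return smem_q + (smem_v > smem_o + smem_softmax ? smem_v : smem_o + smem_softmax);
}

Note that the old 384 path below summed V + O + softmax without keeping Q resident, so the real helper presumably branches on the kernel traits rather than using one fixed formula.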
apex/contrib/csrc/fmha/src/fmha_fprop_fp16_384_64_kernel.sm80.cu
View file @ 3c88451a
...
@@ -26,32 +26,59 @@
 ******************************************************************************/
 #include "fmha.h"
-#include "fmha_fprop_kernel_1xN_reload_v.h"
+#include "fmha_fprop_kernel_1xN.h"
-using Kernel_traits = FMHA_kernel_traits<384, 64, 16, 1, 4, 0x08u>;
+using Kernel_traits = FMHA_kernel_traits<384, 64, 16, 1, 4, 0x18u>;
-extern "C" __global__ void fmha_fprop_fp16_384_64_sm80_train_kernel(Fused_multihead_attention_fprop_params params) {
-    fmha::device_1xN<Kernel_traits, true>(params);
-}
-extern "C" __global__ void fmha_fprop_fp16_384_64_sm80_predict_kernel(Fused_multihead_attention_fprop_params params) {
-    fmha::device_1xN<Kernel_traits, false>(params);
-}
+template<bool Is_training>
+__global__
+void fmha_fprop_fp16_384_64_sm80_kernel(Fused_multihead_attention_fprop_params params,
+                                        const int num_full_heads,
+                                        const int num_main_groups,
+                                        const int main_group_size,
+                                        const int main_steps,
+                                        const int rest_steps) {
+    fmha::device_1xN<Kernel_traits, Is_training>(
+        params, num_full_heads, num_main_groups, main_group_size, main_steps, rest_steps);
+}
-void run_fmha_fp16_384_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream) {
-    auto kernel = is_training ? &fmha_fprop_fp16_384_64_sm80_train_kernel : &fmha_fprop_fp16_384_64_sm80_predict_kernel;
-    constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
-    constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
-    constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;
-    constexpr int smem_size = smem_size_v + smem_size_o + smem_size_softmax;
+void run_fmha_fp16_384_64_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, const bool configure) {
+    auto kernel = launch_params.is_training ? &fmha_fprop_fp16_384_64_sm80_kernel<true>
+                                            : &fmha_fprop_fp16_384_64_sm80_kernel<false>;
+    constexpr int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>();
     if( smem_size >= 48 * 1024 ) {
         FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
     }
-    dim3 grid(params.h, params.b);
-    kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
+    const int sm_count = launch_params.props->multiProcessorCount;
+    int ctas_per_sm;
+    FMHA_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size));
+    int total_ctas = sm_count * ctas_per_sm;
+    if( configure ) {
+        const int heads_total = launch_params.params.b * launch_params.params.h;
+        std::tie(launch_params.num_full_heads,
+                 launch_params.num_main_groups,
+                 launch_params.heads_last_wave,
+                 launch_params.main_steps,
+                 launch_params.rest_steps,
+                 launch_params.elts_per_thread) = fmha::work_dist<Kernel_traits>(total_ctas, heads_total);
+        return;
+    }
+    dim3 grid(total_ctas);
+    kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(
+        launch_params.params, launch_params.num_full_heads, launch_params.num_main_groups,
+        launch_params.heads_last_wave, launch_params.main_steps, launch_params.rest_steps);
+    FMHA_CHECK_CUDA(cudaPeekAtLastError());
 }
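Two things change for the 384 path besides the common launcher rewrite: the dedicated fmha_fprop_kernel_1xN_reload_v.h header is gone, folded into the common fmha_fprop_kernel_1xN.h, and the traits flags move from 0x08u to 0x18u, which presumably now encodes the reload-V behavior as a flag bit. The occupancy-driven grid sizing that all of these launchers share is a standard CUDA pattern; a self-contained illustration (generic names, not apex's) looks like this:

#include <cuda_runtime.h>
#include <cstdio>

__global__ void dummy_kernel(int *out) { out[blockIdx.x] = (int)threadIdx.x; }

int main() {
    cudaDeviceProp props;
    cudaGetDeviceProperties(&props, 0);
    int ctas_per_sm = 0;
    // Same API the new launchers use: max resident CTAs per SM for this
    // kernel at 256 threads/CTA and 0 bytes of dynamic shared memory.
    cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, dummy_kernel, 256, 0);
    int total_ctas = props.multiProcessorCount * ctas_per_sm;
    printf("persistent grid: %d CTAs\n", total_ctas);
    return 0;
}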
apex/contrib/csrc/fmha/src/fmha_fprop_fp16_512_64_kernel.sm80.cu
View file @ 3c88451a
...
@@ -27,72 +27,111 @@
 #include "fmha.h"
 #include "fmha_fprop_kernel_1xN.h"
-#include "fmha_fprop_kernel_1xN_nl.h"
-using Kernel_traits = FMHA_kernel_traits<512, 64, 16, 1, 8, 0x08u>;
+using Kernel_traits = FMHA_kernel_traits<512, 64, 16, 1, 8, 0x00u>;
-extern "C" __global__ void fmha_fprop_fp16_512_64_sm80_train_kernel(Fused_multihead_attention_fprop_params params) {
-    fmha::device_1xN<Kernel_traits, true>(params);
-}
-extern "C" __global__ void fmha_fprop_fp16_512_64_sm80_predict_kernel(Fused_multihead_attention_fprop_params params) {
-    fmha::device_1xN<Kernel_traits, false>(params);
-}
-template<int CHUNKS>
-__global__ void fmha_fprop_fp16_512_64_sm80_train_nl_kernel(Fused_multihead_attention_fprop_params params) {
-    fmha::device_1xN_nl<CHUNKS, Kernel_traits, true>(params);
-}
-template<int CHUNKS>
-__global__ void fmha_fprop_fp16_512_64_sm80_predict_nl_kernel(Fused_multihead_attention_fprop_params params) {
-    fmha::device_1xN_nl<CHUNKS, Kernel_traits, false>(params);
-}
+template<bool Is_training>
+__global__
+void fmha_fprop_fp16_512_64_sm80_kernel(Fused_multihead_attention_fprop_params params, const int total_heads) {
+    fmha::device_1xN<Kernel_traits, Is_training>(params, total_heads);
+}
+template<bool Is_training>
+__global__
+void fmha_fprop_fp16_512_64_sm80_kernel_nl(Fused_multihead_attention_fprop_params params,
+                                           const int num_full_heads,
+                                           const int num_main_groups,
+                                           const int main_group_size,
+                                           const int main_steps,
+                                           const int rest_steps) {
+    fmha::device_1xN<Kernel_traits, Is_training>(
+        params, num_full_heads, num_main_groups, main_group_size, main_steps, rest_steps);
+}
-void run_fmha_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream) {
-    auto kernel = is_training ? &fmha_fprop_fp16_512_64_sm80_train_kernel : &fmha_fprop_fp16_512_64_sm80_predict_kernel;
-    constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
-    constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
-    constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
-    constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;
-    constexpr int smem_size = smem_size_q + std::max(smem_size_v, smem_size_o + smem_size_softmax);
-    if( smem_size >= 48 * 1024 ) {
-        FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
-    }
-    dim3 grid(params.h, params.b);
-    kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
-}
-void run_fmha_fp16_512_64_sm80_nl(const Fused_multihead_attention_fprop_params &params, const bool is_training, const int num_chunks, cudaStream_t stream) {
-    auto kernel = is_training ? &fmha_fprop_fp16_512_64_sm80_train_nl_kernel<2> : &fmha_fprop_fp16_512_64_sm80_predict_nl_kernel<2>;
-    if( num_chunks == 2 ) {
-        kernel = is_training ? &fmha_fprop_fp16_512_64_sm80_train_nl_kernel<2> : &fmha_fprop_fp16_512_64_sm80_predict_nl_kernel<2>;
-    } else if( num_chunks == 3 ) {
-        kernel = is_training ? &fmha_fprop_fp16_512_64_sm80_train_nl_kernel<3> : &fmha_fprop_fp16_512_64_sm80_predict_nl_kernel<3>;
-    } else if( num_chunks == 4 ) {
-        kernel = is_training ? &fmha_fprop_fp16_512_64_sm80_train_nl_kernel<4> : &fmha_fprop_fp16_512_64_sm80_predict_nl_kernel<4>;
-    } else {
-        assert(false && "Unsupported num_chunks");
-    }
-    constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
-    constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
-    constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
-    constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;
-    constexpr int smem_size = smem_size_q + std::max(smem_size_v, smem_size_o + smem_size_softmax);
-    if( smem_size >= 48 * 1024 ) {
-        FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
-    }
-    dim3 grid(params.h, params.b, num_chunks);
-    kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
-}
+void run_fmha_fp16_512_64_sm80_(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, const bool configure) {
+    auto kernel = launch_params.is_training ? &fmha_fprop_fp16_512_64_sm80_kernel<true>
+                                            : &fmha_fprop_fp16_512_64_sm80_kernel<false>;
+    constexpr int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>();
+    if( smem_size >= 48 * 1024 ) {
+        FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+    }
+    const int sm_count = launch_params.props->multiProcessorCount;
+    int ctas_per_sm;
+    FMHA_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size));
+    int total_ctas = sm_count * ctas_per_sm;
+    const int heads_total = launch_params.params.b * launch_params.params.h;
+    if( configure ) {
+        using Mma_tile_p = fmha::Hmma_tile<typename Kernel_traits::Cta_tile_p>;
+        constexpr size_t STEPS = Kernel_traits::Cta_tile_p::N / Kernel_traits::Cta_tile_p::M;
+        constexpr size_t MMAS_M = Mma_tile_p::MMAS_M;
+        constexpr size_t MMAS_N = Mma_tile_p::MMAS_N;
+        size_t heads_per_cta = ((heads_total + total_ctas - 1) / total_ctas);
+        size_t elts_per_head = STEPS * MMAS_M * MMAS_N * 8;
+        launch_params.elts_per_thread = heads_per_cta * elts_per_head;
+        return;
+    }
+    dim3 grid(total_ctas);
+    kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(
+        launch_params.params,
+        heads_total);
+    FMHA_CHECK_CUDA(cudaPeekAtLastError());
+}
+void run_fmha_fp16_512_64_sm80_nl_(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, const bool configure) {
+    auto kernel = launch_params.is_training ? &fmha_fprop_fp16_512_64_sm80_kernel_nl<true>
+                                            : &fmha_fprop_fp16_512_64_sm80_kernel_nl<false>;
+    constexpr int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>();
+    if( smem_size >= 48 * 1024 ) {
+        FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+    }
+    const int sm_count = launch_params.props->multiProcessorCount;
+    int ctas_per_sm;
+    FMHA_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size));
+    int total_ctas = sm_count * ctas_per_sm;
+    if( configure ) {
+        const int heads_total = launch_params.params.b * launch_params.params.h;
+        std::tie(launch_params.num_full_heads,
+                 launch_params.num_main_groups,
+                 launch_params.heads_last_wave,
+                 launch_params.main_steps,
+                 launch_params.rest_steps,
+                 launch_params.elts_per_thread) = fmha::work_dist<Kernel_traits>(total_ctas, heads_total);
+        return;
+    }
+    dim3 grid(total_ctas);
+    kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(
+        launch_params.params, launch_params.num_full_heads, launch_params.num_main_groups,
+        launch_params.heads_last_wave, launch_params.main_steps, launch_params.rest_steps);
+    FMHA_CHECK_CUDA(cudaPeekAtLastError());
+}
+void run_fmha_fp16_512_64_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, const bool configure) {
+    if( launch_params.is_nl ) {
+        run_fmha_fp16_512_64_sm80_nl_(launch_params, configure);
+    } else {
+        run_fmha_fp16_512_64_sm80_(launch_params, configure);
+    }
+}
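For the 512 sequence length, the non-looping path does not call fmha::work_dist in its configure branch; because each CTA there walks whole heads, it computes launch_params.elts_per_thread directly from the tile shape. That value presumably lets the host size per-thread buffers (for example the dropout mask) before the real launch. A sketch of the arithmetic as a plain function, with the tile constants pulled out as parameters:

#include <cstddef>

// STEPS, MMAS_M and MMAS_N come from Kernel_traits in the real code.
// The factor 8 is read here as the per-thread share of a 16x16 MMA
// accumulator tile (256 elements over a 32-thread warp); that reading
// is an inference from the code, not a documented constant.
size_t fprop_elts_per_thread(size_t heads_total, size_t total_ctas,
                             size_t steps, size_t mmas_m, size_t mmas_n) {
    size_t heads_per_cta = (heads_total + total_ctas - 1) / total_ctas;  // ceil-divide
    size_t elts_per_head = steps * mmas_m * mmas_n * 8;
    return heads_per_cta * elts_per_head;
}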
apex/contrib/csrc/fmha/src/fmha_fprop_kernel_1xN.h
View file @ 3c88451a
This diff is collapsed.