OpenDAS / apex · commit ed719967 (unverified)

Adds small-batch kernels (#1126)

Authored Jul 17, 2021 by yjk21; committed by GitHub Jul 16, 2021. Parent: c1378e6f
Showing 15 changed files with 1584 additions and 200 deletions (+1584 −200):
apex/contrib/csrc/fmha/fmha_api.cpp                                 +172  −45
apex/contrib/csrc/fmha/src/fmha.h                                    +37   −4
apex/contrib/csrc/fmha/src/fmha/gmem_tile.h                           +4   −2
apex/contrib/csrc/fmha/src/fmha/smem_tile.h                          +36   −1
apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_512_64_kernel.sm80.cu     +45   −0
apex/contrib/csrc/fmha/src/fmha_dgrad_kernel_1xN_reload.h            +27  −68
apex/contrib/csrc/fmha/src/fmha_dgrad_kernel_1xN_reload_nl.h        +571   −0
apex/contrib/csrc/fmha/src/fmha_fprop_fp16_512_64_kernel.sm80.cu     +42   −0
apex/contrib/csrc/fmha/src/fmha_fprop_kernel_1xN.h                    +7   −9
apex/contrib/csrc/fmha/src/fmha_fprop_kernel_1xN_nl.h               +343   −0
apex/contrib/csrc/fmha/src/fmha_kernel.h                             +93  −57
apex/contrib/csrc/fmha/src/fmha_noloop_reduce.cu                    +177   −0
apex/contrib/fmha/fmha.py                                            +13   −6
apex/contrib/test/fmha/test_fmha.py                                  +16   −8
setup.py                                                              +1   −0
apex/contrib/csrc/fmha/fmha_api.cpp

@@ -30,28 +30,6 @@
 #include "fmha.h"
 
-void run_fmha_fp16_128_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream);
-void run_fmha_fp16_256_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream);
-void run_fmha_fp16_384_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream);
-void run_fmha_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream);
-
-void run_fmha_dgrad_fp16_128_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream);
-void run_fmha_dgrad_fp16_256_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream);
-void run_fmha_dgrad_fp16_384_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream);
-void run_fmha_dgrad_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream);
-
 void set_params(Fused_multihead_attention_fprop_params &params,
                 // sizes
                 const size_t b,

@@ -61,7 +39,6 @@ void set_params(Fused_multihead_attention_fprop_params &params,
                 // device pointers
                 void *qkv_packed_d,
                 void *cu_seqlens_d,
-                void *seqlens_d,
                 void *o_packed_d,
                 void *s_d,
                 float p_dropout) {

@@ -79,7 +56,6 @@ void set_params(Fused_multihead_attention_fprop_params &params,
     params.o_stride_in_bytes = get_size_in_bytes(h * d, data_type);
 
     params.cu_seqlens = static_cast<int *>(cu_seqlens_d);
-    params.seqlens = static_cast<int *>(seqlens_d);
 
     // S = softmax(P)
     params.s_ptr = s_d;

@@ -107,13 +83,9 @@ void set_params(Fused_multihead_attention_fprop_params &params,
     set_alpha(params.scale_dropout, params.rp_dropout, data_type);
 }
 
-constexpr uint32_t NUM_HEADS_DIM = 2;
-constexpr uint32_t THREE_DIM = 1;
-
 std::vector<at::Tensor>
 mha_fwd(const at::Tensor &qkv,         // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i
         const at::Tensor &cu_seqlens,  // b+1
-        const at::Tensor &seqlens,     // b
         const float p_dropout,
         const int max_seq_len,
         const bool is_training,

@@ -149,17 +121,14 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \
     TORCH_CHECK(qkv.dtype() == torch::kFloat16);
     TORCH_CHECK(cu_seqlens.dtype() == torch::kInt32);
-    TORCH_CHECK(seqlens.dtype() == torch::kInt32);
 
     TORCH_CHECK(qkv.is_cuda())
     TORCH_CHECK(cu_seqlens.is_cuda())
 
     TORCH_CHECK(qkv.is_contiguous())
     TORCH_CHECK(cu_seqlens.is_contiguous())
-    TORCH_CHECK(seqlens.is_contiguous())
 
     TORCH_CHECK(cu_seqlens.dim() == 1);
-    TORCH_CHECK(seqlens.dim() == 1);
     TORCH_CHECK(qkv.dim() == 4);
 
     const auto sizes = qkv.sizes();

@@ -167,10 +136,9 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \
     TORCH_CHECK(sizes[THREE_DIM] == 3);
 
     const int batch_size = cu_seqlens.numel() - 1;
-    TORCH_CHECK(seqlens.numel() == batch_size);
-    const int total = sizes[0];
-    const int num_heads = sizes[NUM_HEADS_DIM];
-    const int head_size = sizes[3];
+    const int total = sizes[TOTAL_DIM];
+    const int num_heads = sizes[H_DIM];
+    const int head_size = sizes[D_DIM];
     TORCH_CHECK(batch_size > 0);
     TORCH_CHECK(head_size == 64);
     auto opts = qkv.options();

@@ -191,7 +159,6 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \
                head_size,
                qkv.data_ptr(),
                cu_seqlens.data_ptr(),
-               seqlens.data_ptr(),
                ctx.data_ptr(),
                s.data_ptr(),
                p_dropout);

@@ -217,7 +184,6 @@ mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size
         const at::Tensor &qkv,         // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i
         at::Tensor &softmax,           // b x h x s x s softmax and dmask - will be overwritten with dP
         const at::Tensor &cu_seqlens,  // b+1
-        const at::Tensor &seqlens,     // b
         const float p_dropout,         // probability to drop
         const int max_seq_len          // max sequence length to choose the kernel
 ) {

@@ -247,17 +213,14 @@ mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size
     TORCH_CHECK(dout.dtype() == torch::kFloat16);
     TORCH_CHECK(softmax.dtype() == torch::kFloat16);
     TORCH_CHECK(cu_seqlens.dtype() == torch::kInt32);
-    TORCH_CHECK(seqlens.dtype() == torch::kInt32);
 
     TORCH_CHECK(qkv.is_cuda());
     TORCH_CHECK(cu_seqlens.is_cuda());
 
     TORCH_CHECK(qkv.is_contiguous());
     TORCH_CHECK(cu_seqlens.is_contiguous());
-    TORCH_CHECK(seqlens.is_contiguous());
 
     TORCH_CHECK(cu_seqlens.dim() == 1);
-    TORCH_CHECK(seqlens.dim() == 1);
     TORCH_CHECK(qkv.dim() == 4);
 
     const auto sizes = qkv.sizes();

@@ -265,9 +228,8 @@ mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size
     TORCH_CHECK(sizes[THREE_DIM] == 3);
 
     const int batch_size = cu_seqlens.numel() - 1;
-    TORCH_CHECK(seqlens.numel() == batch_size);
-    const int num_heads = sizes[NUM_HEADS_DIM];
-    const int head_size = sizes[3];
+    const int num_heads = sizes[H_DIM];
+    const int head_size = sizes[D_DIM];
     TORCH_CHECK(batch_size > 0);
     TORCH_CHECK(head_size == 64);

@@ -282,12 +244,11 @@ mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size
                head_size,
                qkv.data_ptr(),
                cu_seqlens.data_ptr(),
-               seqlens.data_ptr(),
                dout.data_ptr(),     // we set o_ptr to dout
                softmax.data_ptr(),  // softmax gets overwritten by dP!
                p_dropout);
 
     // we're re-using these scales
     Data_type acc_type = DATA_TYPE_FP32;
     set_alpha(params.scale_bmm1, 1.f, acc_type);
     set_alpha(params.scale_softmax, 1.f / sqrtf(head_size), acc_type);
@@ -298,8 +259,174 @@ mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size
     return { dqkv, softmax };
 }
 
+std::vector<at::Tensor>
+mha_fwd_nl(const at::Tensor &qkv,         // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i
+           const at::Tensor &cu_seqlens,  // b+1
+           const float p_dropout,
+           const int max_seq_len,
+           const bool is_training,
+           c10::optional<at::Generator> gen_) {
+
+    int seq_len = 512;
+    auto launch = &run_fmha_fp16_512_64_sm80_nl;
+    TORCH_CHECK(max_seq_len == seq_len);
+
+    constexpr int warps_m = 1;
+    constexpr int warps_n = 4;
+    // this leads to an upper bound
+    const int mmas_m = seq_len / 16 / warps_m;
+    const int mmas_n = seq_len / 16 / warps_n;
+    // static_assert( mmas_m == 32 );
+    // static_assert( mmas_n == 4 );
+    const int elts_per_thread = 8 * mmas_m * mmas_n;
+
+    auto stream = at::cuda::getCurrentCUDAStream().stream();
+
+    TORCH_CHECK(qkv.is_cuda())
+    TORCH_CHECK(cu_seqlens.is_cuda())
+
+    TORCH_CHECK(qkv.is_contiguous())
+    TORCH_CHECK(cu_seqlens.is_contiguous())
+
+    TORCH_CHECK(cu_seqlens.dim() == 1);
+    TORCH_CHECK(qkv.dim() == 4);
+
+    const auto sizes = qkv.sizes();
+    TORCH_CHECK(sizes[THREE_DIM] == 3);
+
+    const int batch_size = cu_seqlens.numel() - 1;
+    const int total = sizes[TOTAL_DIM];
+    const int num_heads = sizes[H_DIM];
+    const int head_size = sizes[D_DIM];
+    TORCH_CHECK(batch_size > 0);
+    TORCH_CHECK(head_size == 64);
+
+    auto opts = qkv.options();
+    auto ctx = torch::empty({total, num_heads, head_size}, opts);
+    auto s = torch::empty({batch_size, num_heads, seq_len, seq_len}, opts);
+
+    auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(gen_, at::cuda::detail::getDefaultCUDAGenerator());
+
+    Fused_multihead_attention_fprop_params params;
+
+    set_params(params,
+               batch_size,
+               seq_len,
+               num_heads,
+               head_size,
+               qkv.data_ptr(),
+               cu_seqlens.data_ptr(),
+               ctx.data_ptr(),
+               s.data_ptr(),
+               p_dropout);
+
+    // number of times random will be generated per thread, to offset philox counter in thc random
+    // state
+    int64_t counter_offset = elts_per_thread;
+    at::PhiloxCudaState rng_engine_inputs;
+
+    if( is_training ) {
+        // See Note [Acquire lock when using random generators]
+        std::lock_guard<std::mutex> lock(gen->mutex_);
+        params.philox_args = gen->philox_cuda_state(counter_offset);
+    }
+
+    int num_chunks = 3;
+    if( batch_size == 3 ) {
+        num_chunks = 2;
+    }
+
+    launch(params, is_training, num_chunks, stream);
+
+    return { ctx, s };
+}
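Note on the chunk count above: the small-batch path splits the 512-token sequence into num_chunks pieces and launches one CTA per (head, batch, chunk), presumably to keep the grid large enough to fill the GPU when b is tiny. A minimal sketch restating the selection logic from mha_fwd_nl above and mha_bwd_nl below (the helper names are ours for illustration; the commit inlines this logic in the function bodies):

// Sketch only: restates the num_chunks heuristics encoded in this diff.
static int fwd_num_chunks(int batch_size) {
    return batch_size == 3 ? 2 : 3;   // forward: 3 chunks by default, 2 when b == 3
}
static int bwd_num_chunks(int batch_size) {
    if (batch_size == 1) return 4;    // backward: smaller batches get more chunks
    if (batch_size == 2) return 3;
    return 2;
}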
+std::vector<at::Tensor>
+mha_bwd_nl(const at::Tensor &dout,        // total x num_heads, x head_size
+           const at::Tensor &qkv,         // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i
+           at::Tensor &softmax,           // b x h x s x s softmax and dmask - will be overwritten with dP
+           const at::Tensor &cu_seqlens,  // b+1
+           const float p_dropout,         // probability to drop
+           const int max_seq_len          // max sequence length to choose the kernel
+) {
+
+    auto stream = at::cuda::getCurrentCUDAStream().stream();
+
+    TORCH_CHECK(qkv.is_cuda())
+    TORCH_CHECK(cu_seqlens.is_cuda())
+
+    TORCH_CHECK(qkv.is_contiguous())
+    TORCH_CHECK(cu_seqlens.is_contiguous())
+
+    TORCH_CHECK(cu_seqlens.dim() == 1);
+    TORCH_CHECK(qkv.dim() == 4);
+
+    const auto sizes = qkv.sizes();
+    TORCH_CHECK(sizes[THREE_DIM] == 3);
+
+    const int batch_size = cu_seqlens.numel() - 1;
+    const int total = sizes[TOTAL_DIM];
+    const int num_heads = sizes[H_DIM];
+    const int head_size = sizes[D_DIM];
+    TORCH_CHECK(batch_size > 0);
+    TORCH_CHECK(head_size == 64);
+
+    int seq_len = 512;
+    auto launch = &run_fmha_dgrad_fp16_512_64_sm80_nl;
+
+    auto opts = qkv.options();
+    auto dqkv = torch::empty_like(qkv);
+
+    int num_chunks = 2;
+    if( batch_size == 1 ) {
+        num_chunks = 4;
+    } else if( batch_size == 2 ) {
+        num_chunks = 3;
+    }
+    auto dkv = torch::empty({total, num_chunks, 2, num_heads, head_size}, opts);
+
+    Fused_multihead_attention_fprop_params params;
+
+    set_params(params,
+               batch_size,
+               seq_len,
+               num_heads,
+               head_size,
+               qkv.data_ptr(),
+               cu_seqlens.data_ptr(),
+               dout.data_ptr(),     // o_ptr = dout
+               softmax.data_ptr(),  // softmax gets overwritten by dP!
+               p_dropout);
+
+    params.dkv_ptr = dkv.data_ptr();
+
+    Data_type acc_type = DATA_TYPE_FP32;
+    set_alpha(params.scale_bmm1, 1.f, acc_type);
+    set_alpha(params.scale_softmax, 1.f / sqrtf(head_size), acc_type);
+    set_alpha(params.scale_bmm2, 1.f, DATA_TYPE_FP16);
+    params.dqkv_ptr = dqkv.data_ptr();
+
+    launch(params, num_chunks, stream);
+
+    //SPLIT-K reduction of num_chunks dK, dV parts
+    // The equivalent of the following Pytorch code:
+    // using namespace torch::indexing;
+    // at::Tensor view_out = dqkv.index({Slice(), Slice(1, None, None)});
+    // torch::sum_out(view_out, dkv, 1);
+    const int hidden_size = num_heads * head_size;
+    fmha_run_noloop_reduce(
+        dqkv.data_ptr(), dkv.data_ptr(), cu_seqlens.data_ptr<int>(), hidden_size, batch_size, total, num_chunks, stream);
+
+    return { dqkv, softmax, dkv };
+}
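The split-K reduction at the end of mha_bwd_nl sums the per-chunk dK/dV partials into the dK and dV slots of dqkv, exactly as the quoted PyTorch equivalent describes. A host-side reference sketch of that math (dtypes and flattening simplified for illustration; this is not the CUDA kernel in fmha_noloop_reduce.cu):

// Illustration only. dqkv is [total, 3, num_heads, head_size] flattened,
// dkv is [total, num_chunks, 2, num_heads, head_size] flattened,
// hidden_size = num_heads * head_size.
void noloop_reduce_reference(float *dqkv, const float *dkv,
                             int hidden_size, int total, int num_chunks) {
    for (int t = 0; t < total; ++t) {
        for (int kv = 0; kv < 2; ++kv) {          // 0 = dK partials, 1 = dV partials
            for (int e = 0; e < hidden_size; ++e) {
                float acc = 0.f;
                for (int c = 0; c < num_chunks; ++c) {
                    acc += dkv[((t * num_chunks + c) * 2 + kv) * hidden_size + e];
                }
                // Slots 1 and 2 of the 3-axis hold dK and dV; slot 0 (dQ) is untouched.
                dqkv[(t * 3 + (kv + 1)) * hidden_size + e] = acc;
            }
        }
    }
}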
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.doc() = "Fused Multi-head Self-attention for BERT";
     m.def("fwd", &mha_fwd, "Forward pass");
     m.def("bwd", &mha_bwd, "Backward pass");
+    m.def("fwd_nl", &mha_fwd_nl, "Forward pass (small-batch)");
+    m.def("bwd_nl", &mha_bwd_nl, "Backward pass (small-batch)");
 }
apex/contrib/csrc/fmha/src/fmha.h

@@ -35,6 +35,12 @@
 #include <fmha_utils.h>
 
+constexpr int TOTAL_DIM = 0;
+constexpr int THREE_DIM = 1;
+constexpr int H_DIM = 2;
+constexpr int D_DIM = 3;
+
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 struct Qkv_params {

@@ -43,6 +49,9 @@ struct Qkv_params {
     // The stride between rows of the Q, K and V matrices.
     size_t qkv_stride_in_bytes;
+
+    // The number of heads.
+    int h;
 };
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////

@@ -52,6 +61,9 @@ struct Fused_multihead_attention_fprop_params : public Qkv_params {
     // The dQKV matrices.
     void *dqkv_ptr;
 
+    // Temporary for dKV.
+    void *dkv_ptr;
+
     // The O matrix (output).
     void *o_ptr;

@@ -64,7 +76,7 @@ struct Fused_multihead_attention_fprop_params : public Qkv_params {
     int64_t s_stride_in_bytes;
 
     // The dimensions.
-    int b, h, s, d;
+    int b, s, d;
 
     // The scaling factors for the kernel.
     uint32_t scale_bmm1, scale_softmax, scale_bmm2;

@@ -72,9 +84,6 @@ struct Fused_multihead_attention_fprop_params : public Qkv_params {
     // array of length b+1 holding starting offset of each sequence.
     int *cu_seqlens;
 
-    // array of length b holding the actual sequence lenghts.
-    int *seqlens;
-
     // The dropout probability (probability of keeping an activation).
     float p_dropout;

@@ -90,3 +99,27 @@ struct Fused_multihead_attention_fprop_params : public Qkv_params {
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
+void run_fmha_fp16_128_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream);
+void run_fmha_fp16_256_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream);
+void run_fmha_fp16_384_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream);
+void run_fmha_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream);
+
+void run_fmha_dgrad_fp16_128_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream);
+void run_fmha_dgrad_fp16_256_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream);
+void run_fmha_dgrad_fp16_384_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream);
+void run_fmha_dgrad_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream);
+
+void run_fmha_fp16_512_64_sm80_nl(const Fused_multihead_attention_fprop_params &params, const bool is_training, const int num_chunks, cudaStream_t stream);
+
+void run_fmha_dgrad_fp16_512_64_sm80_nl(const Fused_multihead_attention_fprop_params &params, const int num_chunks, cudaStream_t stream);
+
+void fmha_run_noloop_reduce(void *out, const void *in, const int *cu_seqlens, const int hidden_size, const int batch_size, const int total, const int num_chunks, cudaStream_t stream);
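fmha.h documents cu_seqlens as an "array of length b+1 holding starting offset of each sequence" (the separate seqlens array is dropped by this commit). For readers new to the varlen packing, a short sketch of the relationship:

#include <vector>

// cu_seqlens is the exclusive prefix sum of per-sequence lengths: sequence i
// occupies rows [cu_seqlens[i], cu_seqlens[i+1]) of the packed QKV tensor,
// and cu_seqlens.back() == total.
std::vector<int> make_cu_seqlens(const std::vector<int> &seqlens) {
    std::vector<int> cu(seqlens.size() + 1, 0);
    for (size_t i = 0; i < seqlens.size(); ++i) {
        cu[i + 1] = cu[i] + seqlens[i];
    }
    return cu;
}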
apex/contrib/csrc/fmha/src/fmha/gmem_tile.h

@@ -39,7 +39,9 @@ template<
     // The number of rows of Q, K or V loaded by this tile.
     int ROWS,
     // The number of columns.
-    int COLS
+    int COLS,
+    // The number of matrics.
+    int NUM_MATS = 3
 >
 struct Gmem_tile_qkv {

@@ -74,7 +76,7 @@ struct Gmem_tile_qkv {
         // The row offset in the batched GEMM. For each seq element, we store QKV in that order.
         int64_t row_offset = (int64_t)row * params.qkv_stride_in_bytes;
         // Add the block index.
-        row_offset += (int64_t)((binfo.sum_s * 3 + qkv_offset) * binfo.h + binfo.bidh) * BYTES_PER_ROW;
+        row_offset += (int64_t)((binfo.sum_s * NUM_MATS + qkv_offset) * binfo.h + binfo.bidh) * BYTES_PER_ROW;
 
         // Assemble the final pointer.
         qkv_ptr_ += row_offset + col * BYTES_PER_LDG;
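This change parameterizes the hard-coded interleave factor 3 as NUM_MATS, so the same tile can address buffers that interleave fewer matrices (e.g. the two-matrix dKV temporary). A standalone sketch of the offset arithmetic (the free function is ours for illustration; the diff does this inline in the constructor):

// Sketch of the pointer math in Gmem_tile_qkv: each sequence element stores
// NUM_MATS interleaved matrices (Q, K, V), each with h head-rows.
// qkv_offset selects Q (0), K (1) or V (2); bidh selects the head.
int64_t qkv_row_offset(int64_t row, int64_t qkv_stride_in_bytes,
                       int64_t sum_s, int64_t h, int64_t bidh,
                       int64_t qkv_offset, int64_t num_mats,
                       int64_t bytes_per_row) {
    int64_t off = row * qkv_stride_in_bytes;
    off += ((sum_s * num_mats + qkv_offset) * h + bidh) * bytes_per_row;
    return off;
}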
apex/contrib/csrc/fmha/src/fmha/smem_tile.h

@@ -1217,7 +1217,8 @@ struct Smem_tile_mma_epilogue : public Base {
     enum { WARPS_M = Base::WARPS_M };
     enum { WARPS_N = Base::WARPS_N };
     static_assert((WARPS_M == 4 || WARPS_N == 8) || WARPS_N == 1);
     using Fragment = typename Base::Fragment;
+    using Acc = fmha::Fragment_accumulator;
 
     inline __device__ Smem_tile_mma_epilogue(char *smem, int tidx) : Base(smem, tidx) {
         const int read_row = tidx / THREADS_PER_ROW;

@@ -1233,6 +1234,40 @@ struct Smem_tile_mma_epilogue : public Base {
         }
     }
 
+    template<int M, int N>
+    inline __device__ void store(const Acc (&acc)[M][N]) {
+#pragma unroll
+        for( int mi = 0; mi < M; mi++ ) {
+#pragma unroll
+            for( int ni = 0; ni < N; ni++ ) {
+                // 1st row - 4 elements per row.
+                float tmp00 = acc[mi][ni].elt(0);
+                float tmp01 = acc[mi][ni].elt(1);
+                float tmp02 = acc[mi][ni].elt(4);
+                float tmp03 = acc[mi][ni].elt(5);
+                // 2nd row - 4 elements per row.
+                float tmp10 = acc[mi][ni].elt(2);
+                float tmp11 = acc[mi][ni].elt(3);
+                float tmp12 = acc[mi][ni].elt(6);
+                float tmp13 = acc[mi][ni].elt(7);
+
+                uint32_t x = fmha::float2_to_half2(tmp00, tmp01);
+                uint32_t y = fmha::float2_to_half2(tmp02, tmp03);
+                uint32_t z = fmha::float2_to_half2(tmp10, tmp11);
+                uint32_t w = fmha::float2_to_half2(tmp12, tmp13);
+
+                size_t offset = (this->write_offset_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW;
+                fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, x);
+                fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, z);
+                offset ^= 4 * Base::BYTES_PER_STS;
+                fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, y);
+                fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, w);
+            }
+        }
+    }
+
     template<int M, int N>
     inline __device__ void store(const uint4 (&regs)[M][N]) {
         for( int mi = 0; mi < M; mi++ ) {
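The new store(acc) overload converts the fp32 accumulator fragments to packed half2 registers before the swizzled shared-memory stores, which is what lets the dgrad epilogues below pass accumulators straight to the tile. A plausible sketch of the conversion helper it relies on (fmha::float2_to_half2 is defined elsewhere in the contrib sources; this restates its likely semantics, and the name suffix marks it as our sketch):

#include <cuda_fp16.h>

// Assumed semantics of fmha::float2_to_half2: round two fp32 values to fp16
// and pack them into one 32-bit register (lo half = a, hi half = b).
static __device__ inline uint32_t float2_to_half2_sketch(float a, float b) {
    __half2 h2 = __floats2half2_rn(a, b);
    return reinterpret_cast<uint32_t &>(h2);
}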
apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_512_64_kernel.sm80.cu

@@ -27,6 +27,7 @@
 #include "fmha.h"
 #include "fmha_dgrad_kernel_1xN_reload.h"
+#include "fmha_dgrad_kernel_1xN_reload_nl.h"
 
 using Kernel_traits = FMHA_kernel_traits<512, 64, 16, 1, 8, 0x08u>;

@@ -35,6 +36,13 @@ extern "C" __global__ void fmha_dgrad_fp16_512_64_sm80_kernel(Fused_multihead_at
     fmha::compute_dq_dk_1xN<Kernel_traits>(params);
 }
 
+template<int CHUNKS>
+__global__ void fmha_dgrad_fp16_512_64_sm80_nl_kernel(Fused_multihead_attention_fprop_params params) {
+    fmha::compute_dv_1xN_nl<CHUNKS, Kernel_traits>(params);
+    fmha::compute_dq_dk_1xN_nl<CHUNKS, Kernel_traits>(params);
+}
+
 void run_fmha_dgrad_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream) {
 
     constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);

@@ -58,3 +66,40 @@ void run_fmha_dgrad_fp16_512_64_sm80(const Fused_multihead_attention_fprop_param
     dim3 grid(params.h, params.b);
     fmha_dgrad_fp16_512_64_sm80_kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
 }
+
+void run_fmha_dgrad_fp16_512_64_sm80_nl(const Fused_multihead_attention_fprop_params &params, const int num_chunks, cudaStream_t stream) {
+
+    constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
+    constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
+    constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
+    constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;
+
+    using Smem_tile_s = fmha::Smem_tile_mma_transposed<Kernel_traits::Cta_tile_p>;
+    constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE;
+    static_assert(smem_size_s == 16 * 512 * 2);
+    static_assert(smem_size_o == 16 * 64 * 4 * Kernel_traits::Cta_tile_p::WARPS_N);
+
+    constexpr int smem_size_dv = smem_size_s + 2 * smem_size_q + smem_size_v + smem_size_softmax;
+    constexpr int smem_size_dq_dk = smem_size_s + smem_size_o + smem_size_q + smem_size_v;
+    constexpr int smem_size = std::max(smem_size_dv, smem_size_dq_dk);
+
+    auto kernel = fmha_dgrad_fp16_512_64_sm80_nl_kernel<2>;
+    if( num_chunks == 2 ) {
+        kernel = fmha_dgrad_fp16_512_64_sm80_nl_kernel<2>;
+    } else if( num_chunks == 3 ) {
+        kernel = fmha_dgrad_fp16_512_64_sm80_nl_kernel<3>;
+    } else {
+        assert(false && "Unsupperted number of chunks");
+    }
+
+    if( smem_size >= 48 * 1024 ) {
+        FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+    }
+    dim3 grid(params.h, params.b, num_chunks);
+
+    kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
+
+    FMHA_CHECK_CUDA(cudaPeekAtLastError());
+}
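The launcher above uses the standard opt-in for dynamic shared memory beyond the 48 KB default, which SM80 requires for the combined S/Q/V/O tiles. The pattern in isolation (a sketch with a dummy kernel; names are ours):

#include <cuda_runtime.h>

__global__ void demo_kernel(float *out) {
    extern __shared__ float smem[];      // dynamic shared memory
    smem[threadIdx.x] = (float)threadIdx.x;
    __syncthreads();
    if (threadIdx.x == 0) out[blockIdx.x] = smem[0];
}

// On SM80 a CTA may use well over 48 KB of dynamic smem, but any request
// above 48 KB must first be enabled per-kernel via cudaFuncSetAttribute.
void launch_with_large_smem(float *out, int smem_size, cudaStream_t stream) {
    if (smem_size >= 48 * 1024) {
        cudaFuncSetAttribute(demo_kernel,
                             cudaFuncAttributeMaxDynamicSharedMemorySize,
                             smem_size);
    }
    demo_kernel<<<1, 128, smem_size, stream>>>(out);
}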
apex/contrib/csrc/fmha/src/fmha_dgrad_kernel_1xN_reload.h

@@ -156,8 +156,10 @@ inline __device__ void compute_dv_1xN(const Params &params) {
     fmha::Fragment_accumulator acc_dv[Mma_tile_dv::MMAS_M][Mma_tile_dv::MMAS_N];
     fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_dv::WARPS_K>::apply(acc_dv);
 
+    enum { STEPS = Cta_tile_p::N / Cta_tile_p::M };
+
     // Load over the entire sequence length.
-    for( int loop = 0, outer = 0; loop < Cta_tile_p::N; loop += Cta_tile_p::M, outer++ ) {
+    for( int l = 0; l < STEPS; l++ ) {
+        const int loop = l * Cta_tile_p::M;
         if( loop >= binfo.actual_seqlen )
             break;

@@ -185,6 +187,13 @@ inline __device__ void compute_dv_1xN(const Params &params) {
             int ki = Mma_tile_p::MMAS_K;
             fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]);
         }
 
+        // Trigger the load for the next Q values. We're using double buffering, so reading qt is safe
+        if( l < STEPS - 1 ) {
+            smem_q.move_to_next_write_buffer();
+            gmem_q.move();
+            gmem_q.load(smem_q);
+        }
+
         // Convert from the accumulator type to FP32 for Softmax.
         softmax.unpack(acc_p);

@@ -203,8 +212,6 @@ inline __device__ void compute_dv_1xN(const Params &params) {
             }
         }
 
-        float d_s[2 * M][4 * N];
-
 #pragma unroll
         for( int mi = 0; mi < M; mi++ ) {
 #pragma unroll

@@ -213,10 +220,11 @@ inline __device__ void compute_dv_1xN(const Params &params) {
             for( int ni = 0; ni < N; ni++ ) {
 #pragma unroll
                 for( int jj = 0; jj < 4; jj++ ) {
-                    const float s_dmask = s_mat[2 * mi + ii][4 * ni + jj];
+                    float &s_dmask = s_mat[2 * mi + ii][4 * ni + jj];
                     const bool drop = reinterpret_cast<const uint32_t &>(s_dmask) & 0x80000000;
-                    d_s[2 * mi + ii][4 * ni + jj] = drop ? 0.f : softmax.elt_[2 * mi + ii][4 * ni + jj] * params.rp_dropout;
-                    softmax.elt_[2 * mi + ii][4 * ni + jj] = d_s[2 * mi + ii][4 * ni + jj] * fabsf(s_dmask);
+                    const float d_s = drop ? 0.f : softmax.elt_[2 * mi + ii][4 * ni + jj] * params.rp_dropout;
+                    s_dmask = fabsf(s_dmask);
+                    softmax.elt_[2 * mi + ii][4 * ni + jj] = d_s * fabsf(s_dmask);
                 }
             }
         }

@@ -225,6 +233,7 @@ inline __device__ void compute_dv_1xN(const Params &params) {
         float p_sum[2 * M];
         softmax.template reduce<fmha::Sum_>(p_sum);
 
+        const float scalef = reinterpret_cast<const float &>(params.scale_softmax);
 #pragma unroll
         for( int mi = 0; mi < M; mi++ ) {
 #pragma unroll

@@ -233,20 +242,12 @@ inline __device__ void compute_dv_1xN(const Params &params) {
             for( int ni = 0; ni < N; ni++ ) {
 #pragma unroll
                 for( int jj = 0; jj < 4; jj++ ) {
-                    const float scalef = reinterpret_cast<const float &>(params.scale_softmax);
-                    softmax.elt_[2 * mi + ii][4 * ni + jj] = (d_s[2 * mi + ii][4 * ni + jj] - p_sum[2 * mi + ii]) *
-                                                             fabsf(s_mat[2 * mi + ii][4 * ni + jj]) * scalef;
+                    softmax.elt_[2 * mi + ii][4 * ni + jj] -= p_sum[2 * mi + ii] * (s_mat[2 * mi + ii][4 * ni + jj]);
+                    softmax.elt_[2 * mi + ii][4 * ni + jj] *= scalef;
                 }
             }
         }
 
-        // Trigger the load for the next Q values. We're using double buffering, so reading qt is safe
-        if( loop + Cta_tile_p::M < Cta_tile_p::N ) {
-            smem_q.move_to_next_write_buffer();
-            gmem_q.move();
-            gmem_q.load(smem_q);
-        }
-
         typename Smem_tile_st::Fragment frag_s[Mma_tile_dv::MMAS_K][Mma_tile_dv::MMAS_M];
         smem_s.load(frag_s);
         for( int ki = 0; ki < Mma_tile_dv::MMAS_K; ki++ ) {

@@ -275,7 +276,7 @@ inline __device__ void compute_dv_1xN(const Params &params) {
             fmha::gemm(acc_dv, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]);
         }
 
         // Commit the values for Q into shared memory.
-        if( loop + Cta_tile_p::M < Cta_tile_p::N ) {
+        if( l < STEPS - 1 ) {
             gmem_q.commit(smem_q);
         }

@@ -295,36 +296,15 @@ inline __device__ void compute_dv_1xN(const Params &params) {
     // Epilogue swizzle for dV
     Smem_tile_dv smem_dv(&smem_[Kernel_traits::Smem_tile_q::BYTES_PER_TILE], tidx);
-    uint4 dv[Mma_tile_dv::MMAS_M][Mma_tile_dv::MMAS_N];
-#pragma unroll
-    for( int mi = 0; mi < Mma_tile_dv::MMAS_M; ++mi ) {
-#pragma unroll
-        for( int ni = 0; ni < Mma_tile_dv::MMAS_N; ++ni ) {
-            // 1st row - 4 elements per row.
-            float tmp00 = acc_dv[mi][ni].elt(0);
-            float tmp01 = acc_dv[mi][ni].elt(1);
-            float tmp02 = acc_dv[mi][ni].elt(4);
-            float tmp03 = acc_dv[mi][ni].elt(5);
-            // 2nd row - 4 elements per row.
-            float tmp10 = acc_dv[mi][ni].elt(2);
-            float tmp11 = acc_dv[mi][ni].elt(3);
-            float tmp12 = acc_dv[mi][ni].elt(6);
-            float tmp13 = acc_dv[mi][ni].elt(7);
-
-            dv[mi][ni].x = fmha::float2_to_half2(tmp00, tmp01);
-            dv[mi][ni].y = fmha::float2_to_half2(tmp02, tmp03);
-            dv[mi][ni].z = fmha::float2_to_half2(tmp10, tmp11);
-            dv[mi][ni].w = fmha::float2_to_half2(tmp12, tmp13);
-        }
-    }
-    smem_dv.store(dv);
+    smem_dv.store(acc_dv);
     __syncthreads();
     uint4 dv_out[Smem_tile_dv::NUM_LDS];
     smem_dv.load(dv_out);
     Qkv_params dv_params;
     dv_params.qkv_ptr = params.dqkv_ptr;
     dv_params.qkv_stride_in_bytes = params.qkv_stride_in_bytes;
+    dv_params.h = params.h;
     Gmem_tile_dv gmem_dv(dv_params, 2, binfo, tidx);
     gmem_dv.store(dv_out);
 }

@@ -447,13 +427,15 @@ inline __device__ void compute_dq_dk_1xN(const Params &params) {
     enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 };
 
     enum { THREADS_PER_ROW = 32 };
+    enum { STEPS = Cta_tile_p::N / Cta_tile_p::M };
 
     // Declare the accumulators for the 2nd gemm.
     fmha::Fragment_accumulator acc_dk[Mma_tile_dk::MMAS_M][Mma_tile_dk::MMAS_N];
     fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_dk::WARPS_K>::apply(acc_dk);
 
     // Load over the entire sequence length.
-    for( int loop = 0, outer = 0; loop < Cta_tile_p::N; loop += Cta_tile_p::M, outer++ ) {
+    for( int l = 0; l < STEPS; l++ ) {
+        const int loop = l * Cta_tile_p::M;
         if( loop >= binfo.actual_seqlen )
             break;

@@ -492,7 +474,7 @@ inline __device__ void compute_dq_dk_1xN(const Params &params) {
         // Store dP to smem for transpose
         smem_s.store(s_regs);
 
-        if( loop + Cta_tile_p::M < Cta_tile_p::N ) {
+        if( l < STEPS - 1 ) {
             // Load next part of S
             gmem_s.load(s_regs, mask);
             gmem_s.move();

@@ -544,7 +526,7 @@ inline __device__ void compute_dq_dk_1xN(const Params &params) {
         }
 
         // Commit the values for Q into shared memory.
-        if( loop + Cta_tile_p::M < Cta_tile_p::N ) {
+        if( l < STEPS - 1 ) {
             gmem_q.commit(smem_q);
         }

@@ -559,37 +541,14 @@ inline __device__ void compute_dq_dk_1xN(const Params &params) {
     // Epilogue swizzle for dK
     Smem_tile_dk smem_dk(&smem_[0], tidx);
-    uint4 dk[Mma_tile_dk::MMAS_M][Mma_tile_dk::MMAS_N];
-#pragma unroll
-    for( int mi = 0; mi < Mma_tile_dk::MMAS_M; ++mi ) {
-#pragma unroll
-        for( int ni = 0; ni < Mma_tile_dk::MMAS_N; ++ni ) {
-            // 1st row - 4 elements per row.
-            float tmp00 = acc_dk[mi][ni].elt(0);
-            float tmp01 = acc_dk[mi][ni].elt(1);
-            float tmp02 = acc_dk[mi][ni].elt(4);
-            float tmp03 = acc_dk[mi][ni].elt(5);
-            // 2nd row - 4 elements per row.
-            float tmp10 = acc_dk[mi][ni].elt(2);
-            float tmp11 = acc_dk[mi][ni].elt(3);
-            float tmp12 = acc_dk[mi][ni].elt(6);
-            float tmp13 = acc_dk[mi][ni].elt(7);
-
-            dk[mi][ni].x = fmha::float2_to_half2(tmp00, tmp01);
-            dk[mi][ni].y = fmha::float2_to_half2(tmp02, tmp03);
-            dk[mi][ni].z = fmha::float2_to_half2(tmp10, tmp11);
-            dk[mi][ni].w = fmha::float2_to_half2(tmp12, tmp13);
-        }
-    }
-    smem_dk.store(dk);
+    smem_dk.store(acc_dk);
     __syncthreads();
     uint4 dk_out[Smem_tile_dk::NUM_LDS];
     smem_dk.load(dk_out);
     Qkv_params dk_params;
     dk_params.qkv_ptr = params.dqkv_ptr;
     dk_params.qkv_stride_in_bytes = params.qkv_stride_in_bytes;
+    dk_params.h = params.h;
     Gmem_tile_dk gmem_dk(dk_params, 1, binfo, tidx);
     gmem_dk.store(dk_out);
 }
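The rewritten inner loop in compute_dv_1xN decodes the dropout mask that the forward pass hides in the sign bit of the stored softmax values. The trick in isolation (a sketch on scalars; the diff operates on whole fragments, and the helper name is ours):

// The forward pass stores softmax(P) with dropped elements negated (see
// encode_dropout in fmha_fprop_kernel_1xN_nl.h below). Softmax outputs are
// non-negative, so the sign bit is free to carry the keep/drop flag.
static __device__ inline float decode_dropout(float s_dmask, float grad,
                                              float rp_dropout) {
    const bool drop = reinterpret_cast<const uint32_t &>(s_dmask) & 0x80000000;
    const float d_s = drop ? 0.f : grad * rp_dropout;  // rescale kept elements
    return d_s * fabsf(s_dmask);                       // |s_dmask| restores softmax(P)
}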
apex/contrib/csrc/fmha/src/fmha_dgrad_kernel_1xN_reload_nl.h (new file, 0 → 100644; diff collapsed in this view)
apex/contrib/csrc/fmha/src/fmha_fprop_fp16_512_64_kernel.sm80.cu

@@ -27,6 +27,7 @@
 #include "fmha.h"
 #include "fmha_fprop_kernel_1xN.h"
+#include "fmha_fprop_kernel_1xN_nl.h"
 
 using Kernel_traits = FMHA_kernel_traits<512, 64, 16, 1, 8, 0x08u>;

@@ -38,6 +39,17 @@ extern "C" __global__ void fmha_fprop_fp16_512_64_sm80_predict_kernel(Fused_mult
     fmha::device_1xN<Kernel_traits, false>(params);
 }
 
+template<int CHUNKS>
+__global__ void fmha_fprop_fp16_512_64_sm80_train_nl_kernel(Fused_multihead_attention_fprop_params params) {
+    fmha::device_1xN_nl<CHUNKS, Kernel_traits, true>(params);
+}
+
+template<int CHUNKS>
+__global__ void fmha_fprop_fp16_512_64_sm80_predict_nl_kernel(Fused_multihead_attention_fprop_params params) {
+    fmha::device_1xN_nl<CHUNKS, Kernel_traits, false>(params);
+}
+
 void run_fmha_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream) {
 
     auto kernel = is_training ? &fmha_fprop_fp16_512_64_sm80_train_kernel : &fmha_fprop_fp16_512_64_sm80_predict_kernel;

@@ -54,3 +66,33 @@ void run_fmha_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params &par
     dim3 grid(params.h, params.b);
     kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
 }
+
+void run_fmha_fp16_512_64_sm80_nl(const Fused_multihead_attention_fprop_params &params, const bool is_training, const int num_chunks, cudaStream_t stream) {
+
+    auto kernel = is_training ? &fmha_fprop_fp16_512_64_sm80_train_nl_kernel<2> : &fmha_fprop_fp16_512_64_sm80_predict_nl_kernel<2>;
+    if( num_chunks == 2 ) {
+        kernel = is_training ? &fmha_fprop_fp16_512_64_sm80_train_nl_kernel<2> : &fmha_fprop_fp16_512_64_sm80_predict_nl_kernel<2>;
+    } else if( num_chunks == 3 ) {
+        kernel = is_training ? &fmha_fprop_fp16_512_64_sm80_train_nl_kernel<3> : &fmha_fprop_fp16_512_64_sm80_predict_nl_kernel<3>;
+    } else if( num_chunks == 4 ) {
+        kernel = is_training ? &fmha_fprop_fp16_512_64_sm80_train_nl_kernel<4> : &fmha_fprop_fp16_512_64_sm80_predict_nl_kernel<4>;
+    } else {
+        assert(false && "Unsupported num_chunks");
+    }
+
+    constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
+    constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
+    constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
+    constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;
+
+    constexpr int smem_size = smem_size_q + std::max(smem_size_v, smem_size_o + smem_size_softmax);
+
+    if( smem_size >= 48 * 1024 ) {
+        FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+    }
+
+    dim3 grid(params.h, params.b, num_chunks);
+    kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
+}
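Both _nl launchers map the runtime num_chunks onto compile-time template instantiations, so the chunk count is a constant inside the kernel. The dispatch idiom in isolation (a sketch with a hypothetical kernel):

// CHUNKS is a compile-time constant inside the kernel, so loops over it can
// be fully unrolled and per-chunk bounds fold away.
template<int CHUNKS>
__global__ void work_kernel(int *out) {
    if (threadIdx.x == 0) out[blockIdx.x] = CHUNKS;
}

// Pick the instantiation that matches the runtime chunk count.
void launch_work(int *out, int num_chunks, cudaStream_t stream) {
    auto kernel = work_kernel<2>;
    if (num_chunks == 3)      kernel = work_kernel<3>;
    else if (num_chunks == 4) kernel = work_kernel<4>;
    kernel<<<1, 32, 0, stream>>>(out);
}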
apex/contrib/csrc/fmha/src/fmha_fprop_kernel_1xN.h

@@ -174,9 +174,11 @@ template <typename Kernel_traits, bool Is_training, typename Params> inline __de
     Softmax softmax(params, &smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE], bidb, tidx);
 
     enum { THREADS_PER_ROW = 32 };
+    enum { STEPS = Cta_tile_p::N / Cta_tile_p::M };
 
     // Load over the entire sequence length.
-    for( int loop = 0, outer = 0; loop < Cta_tile_p::N; loop += Cta_tile_p::M, outer++ ) {
+    for( int l = 0; l < STEPS; l++ ) {
+        const int loop = l * Cta_tile_p::M;
         if( loop >= binfo.actual_seqlen )
             break;

@@ -200,12 +202,8 @@ template <typename Kernel_traits, bool Is_training, typename Params> inline __de
             fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);
         }
 
-        // Store the P matrix.
-#if defined(STORE_P)
-        gmem_p.store(acc_p);
-#endif
-
         // Load the mask for that iteration.
-        mask.load(outer);
+        mask.load(l);
 
         // Convert from the accumulator type to FP32 for Softmax.
         softmax.unpack(acc_p);

@@ -213,7 +211,7 @@ template <typename Kernel_traits, bool Is_training, typename Params> inline __de
         // Apply the mask.
         softmax.apply_mask(mask);
 
-        if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V && loop == 0 ) {
+        if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V && l == 0 ) {
             // if we share K and V, it could be that V was not fully read yet but we write into smem for reduction
             __syncthreads();
         }

@@ -261,7 +259,7 @@ template <typename Kernel_traits, bool Is_training, typename Params> inline __de
         }
 
         // Trigger the load for the next Q values.
-        if( loop + Cta_tile_p::M < Cta_tile_p::N ) {
+        if( l < STEPS - 1 ) {
             smem_q.move_to_next_write_buffer();
             gmem_q.move();
             gmem_q.load(smem_q);

@@ -320,7 +318,7 @@ template <typename Kernel_traits, bool Is_training, typename Params> inline __de
         gmem_o.move();
 
         // Commit the values for Q into shared memory.
-        if( loop + Cta_tile_p::M < Cta_tile_p::N ) {
+        if( l < STEPS - 1 ) {
             gmem_q.commit(smem_q);
         }
apex/contrib/csrc/fmha/src/fmha_fprop_kernel_1xN_nl.h (new file, 0 → 100644)
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include "fmha.h"
#include <fmha/kernel_traits.h>
#include <fmha/gemm.h>
namespace
fmha
{
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
int
CHUNKS
,
typename
Kernel_traits
,
bool
Is_training
,
typename
Params
>
inline
__device__
void
device_1xN_nl
(
const
Params
&
params
)
{
// The description of the CTA tile for the 1st batched GEMM.
using
Cta_tile_p
=
typename
Kernel_traits
::
Cta_tile_p
;
// The description of the CTA tile for the 2nd batched GEMM.
using
Cta_tile_o
=
typename
Kernel_traits
::
Cta_tile_o
;
// The MMA tile for the 1st GEMM.
using
Mma_tile_p
=
fmha
::
Hmma_tile
<
Cta_tile_p
>
;
// The MMA tile for the 2nd GEMM.
using
Mma_tile_o
=
fmha
::
Hmma_tile
<
Cta_tile_o
>
;
// The global memory tile to load Q.
using
Gmem_tile_q
=
typename
Kernel_traits
::
Gmem_tile_q
;
// The shared memory tile to swizzle Q.
using
Smem_tile_q
=
typename
Kernel_traits
::
Smem_tile_q
;
// The global memory tile to load K.
using
Gmem_tile_k
=
typename
Kernel_traits
::
Gmem_tile_k
;
// The shared memory tile to swizzle K.
using
Smem_tile_k
=
typename
Kernel_traits
::
Smem_tile_k
;
// The global memory tile to load V.
using
Gmem_tile_v
=
typename
Kernel_traits
::
Gmem_tile_v
;
// The shared memory tile to swizzle V.
using
Smem_tile_v
=
typename
Kernel_traits
::
Smem_tile_v
;
// The global memory tile to store O.
using
Gmem_tile_o
=
typename
Kernel_traits
::
Gmem_tile_o
;
// The shared memory tile to swizzle O.
using
Smem_tile_o
=
typename
Kernel_traits
::
Smem_tile_o
;
// The global memory tile to store S/D.
using
Gmem_tile_s
=
typename
Kernel_traits
::
Gmem_tile_s
;
using
Noloop
=
Noloop_traits
<
CHUNKS
,
Cta_tile_p
>
;
// Shared memory.
extern
__shared__
char
smem_
[];
const
int
bidc
=
blockIdx
.
z
;
// The block index for the batch.
const
int
bidb
=
blockIdx
.
y
;
// The block index for the head.
const
int
bidh
=
blockIdx
.
x
;
// The thread index.
const
int
tidx
=
threadIdx
.
x
;
Noloop
nl_traits
(
bidc
);
const
BlockInfoPadded
<
Kernel_traits
::
THREADS
>
binfo
(
params
,
bidb
,
bidh
,
tidx
);
if
(
binfo
.
stop_early
()
)
return
;
auto
seeds
=
at
::
cuda
::
philox
::
unpack
(
params
.
philox_args
);
Philox
ph
(
std
::
get
<
0
>
(
seeds
),
binfo
.
tidx_global
,
std
::
get
<
1
>
(
seeds
));
fmha
::
Mask
<
Cta_tile_p
>
mask
(
params
,
binfo
,
tidx
);
// Allocate the global memory tile loader for Q.
Gmem_tile_q
gmem_q
(
params
,
0
,
binfo
,
tidx
);
// Allocate the shared memory tile loader for Q.
Smem_tile_q
smem_q
(
&
smem_
[
0
],
tidx
);
// Allocate the global memory tile loader for K.
Gmem_tile_k
gmem_k
(
params
,
1
,
binfo
,
tidx
);
// Allocate the shared memory tile loader for K.
Smem_tile_k
smem_k
(
&
smem_
[
Smem_tile_q
::
BYTES_PER_TILE
],
tidx
);
// Allocate the global memory tile loader for V.
Gmem_tile_v
gmem_v
(
params
,
2
,
binfo
,
tidx
);
// The base pointer of smem_v;
char
*
smem_v_
=
nullptr
;
if
(
Kernel_traits
::
SHARE_SMEM_FOR_K_AND_V
)
{
smem_v_
=
&
smem_
[
Smem_tile_q
::
BYTES_PER_TILE
];
}
else
{
smem_v_
=
&
smem_
[
Smem_tile_q
::
BYTES_PER_TILE
+
Smem_tile_k
::
BYTES_PER_TILE
];
}
// Allocate the shared memory tile loader for V. We use the same as K so be careful!!!
Smem_tile_v
smem_v
(
smem_v_
,
tidx
);
// Allocate the global memory tile loader for O.
Gmem_tile_o
gmem_o
(
params
,
binfo
,
tidx
);
// Allocate the shared memory tile loader for O. We use the same as K so be careful!!!
Smem_tile_o
smem_o
(
&
smem_
[
Smem_tile_q
::
BYTES_PER_TILE
],
tidx
);
Gmem_tile_s
gmem_s
(
params
.
s_ptr
,
params
,
tidx
);
nl_traits
.
move_all
(
gmem_q
,
gmem_o
,
gmem_s
);
// Trigger the loads for Q.
gmem_q
.
load
(
smem_q
);
// Trigger the loads for K.
gmem_k
.
load
(
smem_k
);
// Trigger the loads for K.
gmem_v
.
load
(
smem_v
);
// Commit the data for Q and K to shared memory.
gmem_q
.
commit
(
smem_q
);
gmem_k
.
commit
(
smem_k
);
// Commit the data for V to shared memory.
if
(
!
Kernel_traits
::
SHARE_SMEM_FOR_K_AND_V
)
{
gmem_v
.
commit
(
smem_v
);
}
// Make sure the data is in shared memory.
__syncthreads
();
// Load the fragments for Q.
typename
Smem_tile_q
::
Fragment
frag_q
[
2
][
Mma_tile_p
::
MMAS_M
];
smem_q
.
load
(
frag_q
[
0
],
0
);
// Load the fragments for K. We keep the data in registers during the entire kernel.
typename
Smem_tile_k
::
Fragment
frag_k
[
Mma_tile_p
::
MMAS_K
][
Mma_tile_p
::
MMAS_N
];
#pragma unroll
for
(
int
ki
=
0
;
ki
<
Mma_tile_p
::
MMAS_K
;
++
ki
)
{
smem_k
.
load
(
frag_k
[
ki
],
ki
);
}
// Commit the data for V to shared memory if it has not been done already.
if
(
Kernel_traits
::
SHARE_SMEM_FOR_K_AND_V
)
{
// Make sure we are done loading the fragments for K.
__syncthreads
();
// Commit the data to shared memory for V.
gmem_v
.
commit
(
smem_v
);
// Make sure the data is in shared memory.
__syncthreads
();
}
// Load the fragments for V. We keep the data in registers during the entire kernel.
typename
Smem_tile_v
::
Fragment
frag_v
[
Mma_tile_o
::
MMAS_K
][
Mma_tile_o
::
MMAS_N
];
#pragma unroll
for
(
int
ki
=
0
;
ki
<
Mma_tile_o
::
MMAS_K
;
++
ki
)
{
smem_v
.
load
(
frag_v
[
ki
],
ki
);
}
enum
{
BITS_PER_ELT_S
=
sizeof
(
fmha
::
A_type
)
*
8
};
// Create the object to do the softmax.
using
Softmax
=
fmha
::
Softmax
<
Cta_tile_p
,
Kernel_traits
>
;
Softmax
softmax
(
params
,
&
smem_
[
Smem_tile_q
::
BYTES_PER_TILE
+
Smem_tile_o
::
BYTES_PER_TILE
],
bidb
,
tidx
);
// The number of threads per row.
enum
{
THREADS_PER_ROW
=
32
};
// Load over the entire sequence length.
for
(
int
l
=
0
;
l
<
nl_traits
.
num_steps_
;
l
++
)
{
// Declare the accumulators for the 1st gemm.
fmha
::
Fragment_accumulator
acc_p
[
Mma_tile_p
::
MMAS_M
][
Mma_tile_p
::
MMAS_N
];
fmha
::
Clear_accumulator
<
typename
fmha
::
Accumulator_type
,
Cta_tile_p
::
WARPS_K
>::
apply
(
acc_p
);
// Do this part of P^T = (Q * K^T)^T.
#pragma unroll
for
(
int
ki
=
1
;
ki
<
Mma_tile_p
::
MMAS_K
;
++
ki
)
{
// Trigger the load from shared memory for the next series of Q values.
smem_q
.
load
(
frag_q
[
ki
&
1
],
ki
);
// Do the math for the values already in registers.
fmha
::
gemm
(
acc_p
,
frag_q
[(
ki
-
1
)
&
1
],
frag_k
[(
ki
-
1
)]);
}
// Do the final stage of math.
{
int
ki
=
Mma_tile_p
::
MMAS_K
;
fmha
::
gemm
(
acc_p
,
frag_q
[(
ki
-
1
)
&
1
],
frag_k
[(
ki
-
1
)]);
}
// Trigger the load for the next Q values.
if
(
l
<
nl_traits
.
num_steps_
-
1
)
{
smem_q
.
move_to_next_write_buffer
();
gmem_q
.
move
();
gmem_q
.
load
(
smem_q
);
}
// Load the mask for that iteration.
mask
.
load
(
nl_traits
.
loop_offset_
+
l
);
// Convert from the accumulator type to FP32 for Softmax.
softmax
.
unpack
(
acc_p
);
// Apply the mask.
softmax
.
apply_mask
(
mask
);
if
(
Kernel_traits
::
SHARE_SMEM_FOR_K_AND_V
&&
l
==
0
)
{
// if we share K and V, it could be that V was not fully read yet but we write into smem for reduction
__syncthreads
();
}
// Compute the max.
float
p_max
[
Mma_tile_p
::
MMAS_M
*
2
];
softmax
.
template
reduce
<
fmha
::
Max_
>(
p_max
);
// Make sure we are done reading shared memory.
__syncthreads
();
// Compute the exponential value.
softmax
.
apply_exp
(
p_max
);
// Compute the sum.
float
p_sum
[
Mma_tile_p
::
MMAS_M
*
2
];
softmax
.
template
reduce
<
fmha
::
Sum_
>(
p_sum
);
// Finalize softmax on the accumulators of P^T.
softmax
.
scale
(
p_sum
);
        if( Is_training ) {
            auto encode_dropout = [](bool keep, float val) { return keep ? val : -val; };
            #pragma unroll
            for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) {
                #pragma unroll
                for( int ii = 0; ii < 2; ii++ ) {
                    #pragma unroll
                    for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) {
                        float4 tmp = uniform4(ph());
                        // We encode the dropout pattern in the sign bit of the non-negative softmax to distinguish it from pre-existing zeros.
                        softmax.elt_[2 * mi + ii][4 * ni + 0] = encode_dropout(tmp.x <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 0]);
                        softmax.elt_[2 * mi + ii][4 * ni + 1] = encode_dropout(tmp.y <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 1]);
                        softmax.elt_[2 * mi + ii][4 * ni + 2] = encode_dropout(tmp.z <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 2]);
                        softmax.elt_[2 * mi + ii][4 * ni + 3] = encode_dropout(tmp.w <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 3]);
                    }
                }
            }
            gmem_s.store(softmax.elt_, mask);
            gmem_s.move();
        }

        using Frag_p = fmha::Fragment_a<fmha::Row>;
        Frag_p frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M];
        softmax.pack(frag_p);

        #pragma unroll
        for( int ki = 0; ki < Mma_tile_o::MMAS_K; ki++ ) {
            #pragma unroll
            for( int mi = 0; mi < Mma_tile_o::MMAS_M; mi++ ) {
                #pragma unroll
                for( int ii = 0; ii < Frag_p::NUM_REGS; ii++ ) {
                    // "Apply" the dropout: rescale the kept values, then let the ReLU zero the negative (dropped) entries.
                    frag_p[ki][mi].reg(ii) = fmha::hmul2(frag_p[ki][mi].reg(ii), params.scale_dropout);
                    frag_p[ki][mi].reg(ii) = fmha::hrelu2(frag_p[ki][mi].reg(ii));
                }
            }
        }
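The trick here is that the softmax outputs are non-negative, so the keep/drop decision can be stored in the sign bit: encode_dropout negates dropped elements, and the later scale-then-ReLU (hmul2 followed by hrelu2) zeroes exactly those while rescaling the kept ones. A scalar C++ sketch of encode and apply (the scale value 1/(1-p) is the usual inverted-dropout convention, assumed here to match params.scale_dropout):

    #include <algorithm>
    #include <cassert>

    // Encode: keep the value, or flip its sign to mark "dropped".
    static float encode_dropout(bool keep, float val) { return keep ? val : -val; }

    // Apply: scale, then ReLU -- dropped (negative) entries become 0,
    // mirroring fmha::hmul2 followed by fmha::hrelu2.
    static float apply_dropout(float encoded, float scale) {
        return std::max(0.0f, encoded * scale);
    }

    int main() {
        const float p = 0.1f, scale = 1.0f / (1.0f - p);
        assert(apply_dropout(encode_dropout(true,  0.5f), scale) > 0.0f);   // kept, rescaled
        assert(apply_dropout(encode_dropout(false, 0.5f), scale) == 0.0f);  // dropped, zeroed
        return 0;
    }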
        // Declare the accumulators for the 2nd gemm.
        fmha::Fragment_accumulator acc_o[Mma_tile_o::MMAS_M][Mma_tile_o::MMAS_N];
        fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_o::WARPS_K>::apply(acc_o);

        // Do this part of O = P^T * V^T.
        #pragma unroll
        for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {
            fmha::gemm(acc_o, frag_p[ki], frag_v[ki]);
        }
        // Loop over MMAS_M.
        #pragma unroll
        for( int ii = 0; ii < Gmem_tile_o::LOOPS; ++ii ) {
            // Swizzle the elements and do the final reduction.
            smem_o.store(acc_o, ii);

            // Make sure the data is in shared memory.
            __syncthreads();

            // Load from shared memory.
            uint4 out[Gmem_tile_o::STGS_PER_LOOP];
            smem_o.load(out);

            // Make sure the data was read from shared memory.
            if( ii < Gmem_tile_o::LOOPS - 1 ) {
                __syncthreads();
            }

            // Output the values.
            gmem_o.store(out, ii);
        }

        // Move to the next part of the output.
        gmem_o.move();

        // Commit the values for Q into shared memory.
        if( l < nl_traits.num_steps_ - 1 ) {
            gmem_q.commit(smem_q);
            __syncthreads();
            smem_q.load(frag_q[0], 0);
        }

    }  // Outer loop over the sequence length.
}

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace fmha
apex/contrib/csrc/fmha/src/fmha_kernel.h
...
@@ -40,94 +40,130 @@ namespace fmha {

////////////////////////////////////////////////////////////////////////////////////////////////////
This hunk drops the FMHA_VERSION-keyed block-info types: the empty `template<int FMHA_VERSION> struct BlockInfo {}` primary template, the fixed-length `BlockInfo<1>` (which set `sum_s = params.b * params.s`, `actual_seqlen = params.s`, `bidx = bidb * params.h + bidh` and never stopped early) and the variable-length `BlockInfo<2>` (which derived `sum_s` and `actual_seqlen` from `params.cu_seqlens`). The padded variant below is now the only block-info type; it also records the head count and a global thread index:

template<int THREADS_PER_CTA>
struct BlockInfoPadded {

    template<typename Params>
    __device__ BlockInfoPadded(const Params &params,
                               const int bidb,
                               const int bidh,
                               const int tidx)
        : bidb(bidb), bidh(bidh), h(params.h) {

        // The block index.
        sum_s = params.cu_seqlens[bidb];
        actual_seqlen = params.seqlens[bidb];
        bidx = sum_s * params.h + bidh;

        tidx_global = (bidb * params.h + bidh) * THREADS_PER_CTA + tidx;
    }

    __device__ bool stop_early() const {
        return actual_seqlen == 0;
    }

    int actual_seqlen;
    int bidx;
    int sum_s;
    int bidh;
    int bidb;
    int tidx_global;
    int h;
};

////////////////////////////////////////////////////////////////////////////////////////////////////
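BlockInfoPadded turns the running prefix sums in cu_seqlens into per-CTA bookkeeping: sum_s is where this batch element's tokens start in the packed sequence, actual_seqlen its token count (read from params.seqlens in the struct; derived from the prefix sums below for illustration), and bidx = sum_s * h + bidh the packed row offset of this head. A host-side C++ sketch with assumed toy sizes:

    #include <cstdio>

    int main() {
        // Assumed packed batch: three sequences of lengths 3, 0 and 5.
        const int cu_seqlens[] = {0, 3, 3, 8};   // tokens of sequence b live in [cu[b], cu[b+1])
        const int h = 16;                        // number of heads
        for (int bidb = 0; bidb < 3; ++bidb) {
            int sum_s = cu_seqlens[bidb];
            int actual_seqlen = cu_seqlens[bidb + 1] - sum_s;
            bool stop_early = (actual_seqlen == 0);   // empty sequences exit immediately
            for (int bidh = 0; bidh < h; ++bidh) {
                int bidx = sum_s * h + bidh;          // packed row offset of (sequence, head)
                (void)bidx;
            }
            std::printf("b=%d start=%d len=%d skip=%d\n", bidb, sum_s, actual_seqlen, (int)stop_early);
        }
        return 0;
    }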
template<int CHUNKS, typename Cta_tile>
struct Noloop_traits {
    // Interpretation of Cta_tile dims, i.e. Cta_tile_p:
    enum { STEP = Cta_tile::M };
    enum { SEQLEN = Cta_tile::N };
    // The size of the subsequence this CTA is processing.
    enum { SUBSEQ = SEQLEN / CHUNKS };
    static_assert(SUBSEQ * CHUNKS == SEQLEN);
    // The number of steps to process the subsequence.
    enum { NUM_STEPS = SUBSEQ / STEP };
    static_assert(NUM_STEPS * Cta_tile::M == SUBSEQ);

    inline __device__ Noloop_traits(const int bidc)
        : loop_offset_(NUM_STEPS * bidc), bidc_(bidc) {
    }

    template<typename... Tiles>
    inline __device__ void move_all(Tiles &... tiles) const {
        using expand_type = int[];
        for( int s = 0; s < loop_offset_; s++ ) {
            expand_type{ (tiles.move(), 0)... };
        }
    }

    inline __device__ int get_idx_dk() const {
        //return bidc_;
        return bidc_ * 2 + 0;
    }

    inline __device__ int get_idx_dv() const {
        //return CHUNKS + bidc_;
        return bidc_ * 2 + 1;
    }

    inline __device__ int offset_loop_count(const int l) {
        // Convert the loop counter to a position in the outer sequence.
        return (loop_offset_ + l) * STEP;
    }

    const int loop_offset_;
    const uint32_t bidc_;
    const int num_steps_ = NUM_STEPS;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Specialization for CHUNKS = 3: 512 / 16 = 32 steps do not divide evenly by 3,
// so the first two chunks take 11 steps and the last one takes 10.
template<typename Cta_tile>
struct Noloop_traits<3, Cta_tile> {
    // Interpretation of Cta_tile dims, i.e. Cta_tile_p:
    enum { STEP = Cta_tile::M };
    enum { SEQLEN = Cta_tile::N };
    static_assert(STEP == 16 && SEQLEN == 512);

    inline __device__ Noloop_traits(const int bidc)
        : bidc_(bidc)
        , num_steps_(bidc < 2 ? 11 : 10)
        , loop_offset_(bidc * 11) {
    }

    template<typename... Tiles>
    inline __device__ void move_all(Tiles &... tiles) const {
        using expand_type = int[];
        for( int s = 0; s < loop_offset_; s++ ) {
            expand_type{ (tiles.move(), 0)... };
        }
    }

    inline __device__ int get_idx_dk() const {
        //return bidc_;
        return bidc_ * 2 + 0;
    }

    inline __device__ int get_idx_dv() const {
        //return CHUNKS + bidc_;
        return bidc_ * 2 + 1;
    }

    inline __device__ int offset_loop_count(const int l) {
        // Convert the loop counter to a position in the outer sequence.
        return (loop_offset_ + l) * STEP;
    }

    const int loop_offset_;
    const uint32_t bidc_;
    const int num_steps_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace fmha
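Noloop_traits<CHUNKS, Cta_tile> splits the SEQLEN = Cta_tile::N outer sequence across CHUNKS CTAs: each takes NUM_STEPS = (SEQLEN / CHUNKS) / STEP iterations starting at loop_offset_ = NUM_STEPS * bidc. The CHUNKS = 3 specialization handles the uneven split 11 + 11 + 10. A host-side C++ check of that arithmetic:

    #include <cstdio>

    int main() {
        // Generic case: CHUNKS = 2 over SEQLEN = 512 with STEP = 16 rows per iteration.
        constexpr int STEP = 16, SEQLEN = 512, CHUNKS = 2;
        constexpr int SUBSEQ = SEQLEN / CHUNKS;       // 256 rows per CTA
        constexpr int NUM_STEPS = SUBSEQ / STEP;      // 16 iterations per CTA
        static_assert(SUBSEQ * CHUNKS == SEQLEN, "");
        static_assert(NUM_STEPS * STEP == SUBSEQ, "");
        for (int bidc = 0; bidc < CHUNKS; ++bidc)
            std::printf("chunk %d: rows [%d, %d)\n", bidc,
                        bidc * NUM_STEPS * STEP, (bidc + 1) * NUM_STEPS * STEP);

        // CHUNKS = 3 does not divide 512 / 16 = 32 steps evenly: 11 + 11 + 10.
        int covered = 0;
        for (int bidc = 0; bidc < 3; ++bidc) {
            int num_steps = bidc < 2 ? 11 : 10;
            int loop_offset = bidc * 11;
            std::printf("chunk %d: steps [%d, %d)\n", bidc, loop_offset, loop_offset + num_steps);
            covered += num_steps;
        }
        std::printf("total steps = %d (== 512 / 16)\n", covered);
        return 0;
    }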
apex/contrib/csrc/fmha/src/fmha_noloop_reduce.cu
0 → 100644
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#include "fmha.h"
inline __device__ float4 ldg128(const void *ptr) {
    return *static_cast<const float4 *>(ptr);
}

inline __device__ void stg128(void *ptr, const float4 &data) {
    *static_cast<float4 *>(ptr) = data;
}

template<typename T, int THREADS, int HIDDEN_SIZE, int CHUNKS>
__global__ __launch_bounds__(THREADS) void fmha_noloop_reduce_kernel(void *__restrict__ out,
                                                                     const void *__restrict__ in,
                                                                     const int *__restrict__ cu_seqlens,
                                                                     const int batch_size) {

    enum { BYTES_PER_LDG = 16 };
    enum { NUM_ELTS = BYTES_PER_LDG / sizeof(T) };

    // One CTA hidden vector for K and V.
    enum { BYTES_PER_ROW = HIDDEN_SIZE * sizeof(T) * 2 };
    // The stride in bytes in dQKV.
    enum { OUT_STRIDE_BYTES = 3 * HIDDEN_SIZE * sizeof(T) };
    // The offset in bytes in dQKV to the dKV part for non-interleaved heads.
    enum { OUT_OFFSET_KV_BYTES = HIDDEN_SIZE * sizeof(T) };
    static_assert(BYTES_PER_ROW == HIDDEN_SIZE * 2 * sizeof(T));

    // Size in bytes of the input tile.
    enum { BYTES_PER_TILE = CHUNKS * BYTES_PER_ROW };
    enum { BYTES_PER_CTA = THREADS * BYTES_PER_LDG };
    enum { LDGS = BYTES_PER_ROW / BYTES_PER_CTA };
    static_assert(BYTES_PER_CTA * LDGS == BYTES_PER_ROW);

    union Vec_t {
        float4 raw;
        T elt[NUM_ELTS];
    };

    // ZERO-OUT invalid positions in dQKV.
    const int total = cu_seqlens[batch_size];
    if( blockIdx.x >= total ) {
        enum { BYTES_PER_QKV_ROW = 3 * HIDDEN_SIZE * sizeof(T) };
        enum { STGS = BYTES_PER_QKV_ROW / BYTES_PER_LDG };
        const float4 zeros = make_float4(0.f, 0.f, 0.f, 0.f);
        char *base_ptr = static_cast<char *>(out) + blockIdx.x * OUT_STRIDE_BYTES;
        for( int tidx = threadIdx.x; tidx < STGS; tidx += THREADS ) {
            stg128(base_ptr + tidx * BYTES_PER_LDG, zeros);
        }
        return;
    }

    // SETUP
    const int offset_in = blockIdx.x * BYTES_PER_TILE + threadIdx.x * BYTES_PER_LDG;
    const char *ptr_in = static_cast<const char *>(in) + offset_in;

    const int offset_out = blockIdx.x * OUT_STRIDE_BYTES + threadIdx.x * BYTES_PER_LDG;
    char *ptr_out = static_cast<char *>(out) + OUT_OFFSET_KV_BYTES + offset_out;

    // LOAD
    Vec_t local_in[CHUNKS][LDGS];
    #pragma unroll
    for( int c = 0; c < CHUNKS; c++ ) {
        #pragma unroll
        for( int l = 0; l < LDGS; l++ ) {
            int offset = c * BYTES_PER_ROW + l * BYTES_PER_CTA;
            local_in[c][l].raw = ldg128(ptr_in + offset);
        }
    }

    // UNPACK
    float acc[LDGS][NUM_ELTS];
    #pragma unroll
    for( int l = 0; l < LDGS; l++ ) {
        #pragma unroll
        for( int e = 0; e < NUM_ELTS; e++ ) {
            acc[l][e] = float(local_in[0][l].elt[e]);
        }
    }

    // COMPUTE
    #pragma unroll
    for( int c = 1; c < CHUNKS; c++ ) {
        #pragma unroll
        for( int l = 0; l < LDGS; l++ ) {
            #pragma unroll
            for( int e = 0; e < NUM_ELTS; e++ ) {
                acc[l][e] += float(local_in[c][l].elt[e]);
            }
        }
    }

    // PACK
    Vec_t local_out[LDGS];
    #pragma unroll
    for( int l = 0; l < LDGS; l++ ) {
        #pragma unroll
        for( int e = 0; e < NUM_ELTS; e++ ) {
            local_out[l].elt[e] = T(acc[l][e]);
        }
    }

    // STORE
    #pragma unroll
    for( int l = 0; l < LDGS; l++ ) {
        const int offset = l * BYTES_PER_CTA;
        stg128(ptr_out + offset, local_out[l].raw);
    }
}
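Two things are easy to sanity-check on the host: the derived constants for the shipped configuration (half precision, HIDDEN_SIZE = 1024, THREADS = 256, so each thread issues exactly one 16-byte load per K/V row pair), and the reduction semantics, i.e. summing the per-chunk partial dK/dV rows into the dKV slice of the packed dQKV tensor (row stride 3 * hidden, dKV at offset hidden). A plain C++ sketch with toy sizes and float in place of half; all names are illustrative:

    #include <cstdio>
    #include <vector>

    // Derived constants for T = half (2 bytes), THREADS = 256, HIDDEN_SIZE = 1024.
    constexpr int BYTES_PER_LDG = 16;
    constexpr int SIZEOF_HALF   = 2;
    constexpr int NUM_ELTS      = BYTES_PER_LDG / SIZEOF_HALF;   // 8 halves per 128-bit access
    constexpr int BYTES_PER_ROW = 1024 * SIZEOF_HALF * 2;        // one K row + one V row = 4096 B
    constexpr int BYTES_PER_CTA = 256 * BYTES_PER_LDG;           // 4096 B per CTA pass
    static_assert(NUM_ELTS == 8, "");
    static_assert(BYTES_PER_ROW / BYTES_PER_CTA == 1, "LDGS == 1: one load per thread");

    int main() {
        // Reference reduction with toy sizes.
        const int HIDDEN = 4, CHUNKS = 2, ROWS = 3;
        // in : [ROWS][CHUNKS][2 * HIDDEN]  -- per-chunk partial dK|dV rows
        // out: [ROWS][3 * HIDDEN]          -- packed dQKV; the dKV slice starts at offset HIDDEN
        std::vector<float> in(ROWS * CHUNKS * 2 * HIDDEN, 1.0f);
        std::vector<float> out(ROWS * 3 * HIDDEN, 0.0f);
        for (int r = 0; r < ROWS; ++r) {
            for (int e = 0; e < 2 * HIDDEN; ++e) {
                float acc = 0.f;
                for (int c = 0; c < CHUNKS; ++c) {   // sum the CHUNKS partial results
                    acc += in[(r * CHUNKS + c) * 2 * HIDDEN + e];
                }
                out[r * 3 * HIDDEN + HIDDEN + e] = acc;
            }
        }
        std::printf("out[0][HIDDEN] = %.1f (expect %d.0)\n", out[HIDDEN], CHUNKS);
        return 0;
    }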
void fmha_run_noloop_reduce(void *out,
                            const void *in,
                            const int *cu_seqlens,
                            const int hidden_size,
                            const int batch_size,
                            const int total,
                            const int num_chunks,
                            cudaStream_t stream) {

    const int blocks = total;

    if( hidden_size == 1024 ) {
        constexpr int HIDDEN_SIZE = 1024;
        constexpr int THREADS = 256;
        if( num_chunks == 2 ) {
            fmha_noloop_reduce_kernel<half, THREADS, HIDDEN_SIZE, 2>
                <<<blocks, THREADS, 0, stream>>>(out, in, cu_seqlens, batch_size);
        } else if( num_chunks == 3 ) {
            fmha_noloop_reduce_kernel<half, THREADS, HIDDEN_SIZE, 3>
                <<<blocks, THREADS, 0, stream>>>(out, in, cu_seqlens, batch_size);
        } else {
            assert(false && "Unsupported num_chunks");
        }
    } else {
        assert(false && "Unsupported hidden_size");
    }

    FMHA_CHECK_CUDA(cudaPeekAtLastError());
}
apex/contrib/fmha/fmha.py
...
@@ -32,11 +32,14 @@ import fmhalib as mha
class FMHAFun(torch.autograd.Function):
    @staticmethod
    def forward(ctx, qkv, cu_seqlens, p_dropout, max_s, is_training):
        # The seqlens argument of the old signature is gone; small batches are
        # routed to the no-loop ("_nl") kernels, which split the sequence across CTAs.
        batch_size = cu_seqlens.numel() - 1
        if batch_size < 4:
            context, S_dmask = mha.fwd_nl(qkv, cu_seqlens, p_dropout, max_s, is_training, None)
        else:
            context, S_dmask = mha.fwd(qkv, cu_seqlens, p_dropout, max_s, is_training, None)
        ctx.save_for_backward(qkv, S_dmask)
        ctx.cu_seqlens = cu_seqlens
        ctx.p_dropout = p_dropout
        ctx.max_s = max_s
        return context
...
@@ -44,7 +47,11 @@ class FMHAFun(torch.autograd.Function):
    @staticmethod
    def backward(ctx, dout):
        qkv, S_dmask = ctx.saved_tensors
        batch_size = ctx.cu_seqlens.numel() - 1
        if batch_size < 4:
            # bwd_nl also returns an intermediate buffer that is unused here.
            dqkv, dp, _ = mha.bwd_nl(dout, qkv, S_dmask, ctx.cu_seqlens, ctx.p_dropout, ctx.max_s)
        else:
            dqkv, dp = mha.bwd(dout, qkv, S_dmask, ctx.cu_seqlens, ctx.p_dropout, ctx.max_s)
        return dqkv, None, None, None, None, None, None
...
@@ -60,8 +67,8 @@ class FMHA(torch.nn.Module):
        self.d = self.hidden_size // self.h
        assert self.d * self.h == self.hidden_size, "Invalid hidden size/num_heads"

    def forward(self, qkv, cu_seqlens, max_s, is_training=True):
        ctx = FMHAFun.apply(qkv.view(-1, 3, self.h, self.d), cu_seqlens, self.p_dropout, max_s, is_training)
        return ctx.view(-1, self.hidden_size)
apex/contrib/test/fmha/test_fmha.py
...
@@ -51,7 +51,8 @@ def py_mha(qkv, amask, b, s, h, d):
class TestFMHA(unittest.TestCase):

    def run_test(self, s, b):
        print(f'Test s={s} b={b}')
        torch.manual_seed(1234)
        torch.cuda.manual_seed(1234)
...
@@ -59,7 +60,6 @@ class TestFMHA(unittest.TestCase):
        dtype = torch.float16
        device = torch.device('cuda')
        h = 16
        d = 64
@@ -76,7 +76,10 @@ class TestFMHA(unittest.TestCase):
...
@@ -76,7 +76,10 @@ class TestFMHA(unittest.TestCase):
qkv
.
requires_grad
=
True
qkv
.
requires_grad
=
True
ctx
,
S_
=
mha
.
fwd
(
qkv_vs
,
cu_seqlens
,
seqlens
,
0.0
,
s
,
True
,
None
)
if
b
<
4
:
ctx
,
S_
=
mha
.
fwd_nl
(
qkv_vs
,
cu_seqlens
,
0.0
,
s
,
True
,
None
)
else
:
ctx
,
S_
=
mha
.
fwd
(
qkv_vs
,
cu_seqlens
,
0.0
,
s
,
True
,
None
)
ctx
=
ctx
.
view
(
b
,
s
,
h
,
d
)
ctx
=
ctx
.
view
(
b
,
s
,
h
,
d
)
ctx_ref
=
py_mha
(
qkv
,
amask
,
b
,
s
,
h
,
d
)
ctx_ref
=
py_mha
(
qkv
,
amask
,
b
,
s
,
h
,
d
)
...
@@ -91,23 +94,28 @@ class TestFMHA(unittest.TestCase):
        dw2 = dw.permute(0, 2, 1, 3).clone().detach().contiguous()
        if b < 4:
            dqkv2, _, _ = mha.bwd_nl(dw2, qkv_vs, S_, cu_seqlens, 0.0, s)
        else:
            dqkv2, _ = mha.bwd(dw2, qkv_vs, S_, cu_seqlens, 0.0, s)
        dqkv2 = dqkv2.permute(0, 2, 1, 3).view(b, s, h, 3, d)

        self.assertTrue(torch.allclose(qkv.grad.float(), dqkv2.float(), atol=1e-3))

    def test_128(self):
        self.run_test(128, 32)

    def test_256(self):
        self.run_test(256, 32)

    def test_384(self):
        self.run_test(384, 32)

    def test_512(self):
        # The small batches (b = 2, 3) exercise the new no-loop kernels.
        self.run_test(512, 32)
        self.run_test(512, 2)
        self.run_test(512, 3)


if __name__ == '__main__':
    unittest.main()
setup.py
...
@@ -349,6 +349,7 @@ if "--fmha" in sys.argv:
    CUDAExtension(name='fmhalib',
                  sources=[
                      'apex/contrib/csrc/fmha/fmha_api.cpp',
                      'apex/contrib/csrc/fmha/src/fmha_noloop_reduce.cu',
                      'apex/contrib/csrc/fmha/src/fmha_fprop_fp16_128_64_kernel.sm80.cu',
                      'apex/contrib/csrc/fmha/src/fmha_fprop_fp16_256_64_kernel.sm80.cu',
                      'apex/contrib/csrc/fmha/src/fmha_fprop_fp16_384_64_kernel.sm80.cu',
...