OpenDAS / apex · Commits

Commit 93f91cde, authored Mar 17, 2020 by Kexin Yu
Merge remote-tracking branch 'upstream/master'
Parents: 33082d2b, 80b90b9d

Changes: 37 files in the commit; showing 20 changed files with 5577 additions and 1 deletion (+5577, -1)
Changed files (additions / deletions):

  .gitmodules                                                               +4    -0
  apex/amp/_process_optimizer.py                                            +8    -0
  apex/amp/lists/functional_overrides.py                                    +2    -1
  apex/contrib/csrc/multihead_attn/cutlass                                  +1    -0
  apex/contrib/csrc/multihead_attn/dropout.h                                +292  -0
  apex/contrib/csrc/multihead_attn/encdec_multihead_attn.cpp                +156  -0
  apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu            +556  -0
  apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add.cpp       +198  -0
  apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu   +639  -0
  apex/contrib/csrc/multihead_attn/layer_norm.h                             +740  -0
  apex/contrib/csrc/multihead_attn/self_multihead_attn.cpp                  +132  -0
  apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu              +471  -0
  apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add.cpp         +173  -0
  apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu     +556  -0
  apex/contrib/csrc/multihead_attn/softmax.h                                +1069 -0
  apex/contrib/csrc/multihead_attn/strided_batched_gemm.h                   +405  -0
  apex/contrib/examples/multihead_attn/perf_test_multihead_attn.py          +115  -0
  apex/contrib/multihead_attn/MHA_bwd.png                                   +0    -0
  apex/contrib/multihead_attn/MHA_fwd.png                                   +0    -0
  apex/contrib/multihead_attn/README.md                                     +60   -0
.gitmodules (new file, mode 100644)

[submodule "apex/contrib/csrc/multihead_attn/cutlass"]
	path = apex/contrib/csrc/multihead_attn/cutlass
	url = https://github.com/NVIDIA/cutlass.git
	branch = v1.2.0
apex/amp/_process_optimizer.py

...
@@ -92,6 +92,14 @@ def lazy_init_with_master_weights(self):
 def post_backward_models_are_masters(scaler, params, stashed_grads, scale_override=None):
     grads_have_scale, stashed_have_scale, out_scale = scaler.loss_scale(), 1.0, 1.0
+
+    # not much to do if scale == 1.0 and static scaling
+    if scaler.loss_scale() == 1.0 and not scaler.dynamic:
+        # Clear the stash.
+        for i in range(len(stashed_grads)):
+            stashed_grads[i] = None
+        return
+
     if scale_override is not None:
         grads_have_scale, stashed_have_scale, out_scale = scale_override
...
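The added branch is a fast path for the static loss-scale case: when the scale is exactly 1.0 and not dynamic, the stashed gradients are simply cleared and the rest of the scale bookkeeping is skipped. A self-contained sketch of that control flow, driven by a stand-in scaler object (the _StubScaler name and its attributes are illustrative, not apex API):

# Illustrative stand-in for the amp scaler; only the attributes the new
# branch reads (loss_scale() and .dynamic) are modeled here.
class _StubScaler:
    def __init__(self, scale=1.0, dynamic=False):
        self._scale = scale
        self.dynamic = dynamic

    def loss_scale(self):
        return self._scale


def post_backward_models_are_masters(scaler, params, stashed_grads, scale_override=None):
    grads_have_scale, stashed_have_scale, out_scale = scaler.loss_scale(), 1.0, 1.0

    # not much to do if scale == 1.0 and static scaling
    if scaler.loss_scale() == 1.0 and not scaler.dynamic:
        # Clear the stash.
        for i in range(len(stashed_grads)):
            stashed_grads[i] = None
        return

    if scale_override is not None:
        grads_have_scale, stashed_have_scale, out_scale = scale_override
    # ... (the real function continues with gradient unscaling here)


stash = ["g0", "g1"]
post_backward_models_are_masters(_StubScaler(scale=1.0, dynamic=False), params=[], stashed_grads=stash)
print(stash)  # [None, None] -- the stash was cleared and the scale math was skipped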
apex/amp/lists/functional_overrides.py

...
@@ -63,7 +63,8 @@ FP32_FUNCS = [
     'binary_cross_entropy_with_logits',
     'smooth_l1_loss',
     'soft_margin_loss',
-    'triplet_margin_loss'
+    'triplet_margin_loss',
+    'ctc_loss'
 ]

 BANNED_FUNCS = [
...
apex/contrib/csrc/multihead_attn/cutlass @ ed2ed4d6 (new submodule)

Subproject commit ed2ed4d667ce95e1371bd62db32b6a114e774336
apex/contrib/csrc/multihead_attn/dropout.h (new file, mode 100644)

#include <ATen/ATen.h>
#include <ATen/CUDAGenerator.h>
#include <ATen/cuda/CUDAContext.h>
#include <curand_kernel.h>
#include <THC/THCGeneral.h>

const int UNROLL = 4;

template <typename scalar_t, typename accscalar_t, typename IndexType>
__global__ void apex_fused_dropout_kernel(scalar_t const *inputs,
                                          scalar_t       *outputs,
                                          uint8_t        *mask,
                                          IndexType       totalElements,
                                          accscalar_t     p,
                                          std::pair<uint64_t, uint64_t> seeds)
{
  accscalar_t pinv = accscalar_t(1) / p;
  IndexType   idx  = blockIdx.x * blockDim.x + threadIdx.x;
  curandStatePhilox4_32_10_t state;
  curand_init(seeds.first, idx, seeds.second, &state);

  IndexType rounded_size = ((totalElements - 1) / (blockDim.x * gridDim.x * UNROLL) + 1) *
                           blockDim.x * gridDim.x * UNROLL;
  for (IndexType linearIndex = idx;
       linearIndex < rounded_size;
       linearIndex += gridDim.x * blockDim.x * UNROLL) {
    float4 rand = curand_uniform4(&state);
    scalar_t src[UNROLL];
    rand.x = rand.x < p;
    rand.y = rand.y < p;
    rand.z = rand.z < p;
    rand.w = rand.w < p;
    for (int ii = 0; ii < UNROLL; ii++) {
      IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
      if (li < totalElements) {
        src[ii] = inputs[li];
      }
    }
    for (int ii = 0; ii < UNROLL; ii++) {
      IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
      if (li < totalElements) {
        outputs[li] = src[ii] * static_cast<scalar_t>((&rand.x)[ii] * pinv);
        mask[li]    = (uint8_t)(&rand.x)[ii];
      }
    }
    __syncthreads();
  }
}

template <typename scalar_t, typename accscalar_t, typename IndexType>
__global__ void apex_dropout_add_kernel(scalar_t const *inputs,
                                        scalar_t const *add_inputs,
                                        scalar_t       *outputs,
                                        uint8_t        *mask,
                                        IndexType       totalElements,
                                        accscalar_t     p,
                                        std::pair<uint64_t, uint64_t> seeds)
{
  accscalar_t pinv = accscalar_t(1) / p;
  IndexType   idx  = blockIdx.x * blockDim.x + threadIdx.x;
  curandStatePhilox4_32_10_t state;
  curand_init(seeds.first, idx, seeds.second, &state);

  IndexType rounded_size = ((totalElements - 1) / (blockDim.x * gridDim.x * UNROLL) + 1) *
                           blockDim.x * gridDim.x * UNROLL;
  for (IndexType linearIndex = idx;
       linearIndex < rounded_size;
       linearIndex += gridDim.x * blockDim.x * UNROLL) {
    float4 rand = curand_uniform4(&state);
    scalar_t src[UNROLL];
    scalar_t add_src[UNROLL];
    rand.x = rand.x < p;
    rand.y = rand.y < p;
    rand.z = rand.z < p;
    rand.w = rand.w < p;
    for (int ii = 0; ii < UNROLL; ii++) {
      IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
      if (li < totalElements) {
        src[ii]     = inputs[li];
        add_src[ii] = add_inputs[li];
      }
    }
    for (int ii = 0; ii < UNROLL; ii++) {
      IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
      if (li < totalElements) {
        accscalar_t int1 = static_cast<accscalar_t>((&rand.x)[ii]) * static_cast<accscalar_t>(src[ii]);
        accscalar_t int2 = int1 * static_cast<accscalar_t>(pinv);
        outputs[li] = static_cast<scalar_t>(static_cast<accscalar_t>(add_src[ii]) + int2);
        mask[li]    = (uint8_t)(&rand.x)[ii];
      }
    }
    __syncthreads();
  }
}

template <typename scalar_t, typename accscalar_t, typename IndexType>
__global__ void apex_add_kernel(scalar_t const *inputs,
                                scalar_t const *add_inputs,
                                scalar_t       *outputs,
                                IndexType       totalElements)
{
  IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;
  IndexType rounded_size = ((totalElements - 1) / (blockDim.x * gridDim.x * UNROLL) + 1) *
                           blockDim.x * gridDim.x * UNROLL;
  for (IndexType linearIndex = idx;
       linearIndex < rounded_size;
       linearIndex += gridDim.x * blockDim.x * UNROLL) {
    scalar_t src[UNROLL];
    scalar_t add_src[UNROLL];
    for (int ii = 0; ii < UNROLL; ii++) {
      IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
      if (li < totalElements) {
        src[ii]     = inputs[li];
        add_src[ii] = add_inputs[li];
      }
    }
    for (int ii = 0; ii < UNROLL; ii++) {
      IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
      if (li < totalElements) {
        outputs[li] = src[ii] + add_src[ii];
      }
    }
    __syncthreads();
  }
}

template <typename scalar_t, typename accscalar_t, typename IndexType>
__global__ void apex_masked_scale_kernel(scalar_t const *inputs,
                                         scalar_t       *outputs,
                                         uint8_t const  *mask,
                                         IndexType       totalElements,
                                         accscalar_t     scale)
{
  IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;
  IndexType rounded_size = ((totalElements - 1) / (blockDim.x * gridDim.x * UNROLL) + 1) *
                           blockDim.x * gridDim.x * UNROLL;
  for (IndexType linearIndex = idx;
       linearIndex < rounded_size;
       linearIndex += gridDim.x * blockDim.x * UNROLL) {
    scalar_t src[UNROLL];
    scalar_t msk[UNROLL];
    for (int ii = 0; ii < UNROLL; ii++) {
      IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
      if (li < totalElements) {
        src[ii] = static_cast<scalar_t>(inputs[li]);
        msk[ii] = static_cast<scalar_t>(mask[li]);
      }
    }
    for (int ii = 0; ii < UNROLL; ii++) {
      IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
      if (li < totalElements) {
        outputs[li] = static_cast<scalar_t>(src[ii] * static_cast<scalar_t>(scale)) * msk[ii];
      }
    }
  }
}

template <typename scalar_t, typename accscalar_t, typename IndexType>
void apex_fused_dropout_cuda(scalar_t const *inputs,
                             scalar_t       *outputs,
                             uint8_t        *mask,
                             IndexType       totalElements,
                             accscalar_t     p)
{
  auto gen = at::cuda::detail::getDefaultCUDAGenerator();

  int  block_size = 256;
  dim3 dim_block(block_size);
  dim3 grid((totalElements + block_size - 1) / block_size);
  unsigned int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor / block_size;
  grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, grid.x);

  //number of times random will be generated per thread, to offset philox counter in thc random state
  int64_t counter_offset = ((totalElements - 1) / (block_size * grid.x * UNROLL) + 1) * UNROLL;
  std::pair<uint64_t, uint64_t> rng_engine_inputs;
  {
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(gen->mutex_);
    rng_engine_inputs = gen->philox_engine_inputs(counter_offset);
  }

  apex_fused_dropout_kernel<scalar_t, accscalar_t, IndexType>
      <<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
          inputs, outputs, mask, totalElements, p, rng_engine_inputs);
  THCudaCheck(cudaGetLastError());
}

template <typename scalar_t, typename accscalar_t, typename IndexType>
void apex_dropout_add_cuda(scalar_t const *inputs,
                           scalar_t const *add_inputs,
                           scalar_t       *outputs,
                           uint8_t        *mask,
                           IndexType       totalElements,
                           accscalar_t     p)
{
  auto gen = at::cuda::detail::getDefaultCUDAGenerator();

  int  block_size = 256;
  dim3 dim_block(block_size);
  dim3 grid((totalElements + block_size - 1) / block_size);
  unsigned int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor / block_size;
  grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, grid.x);

  //number of times random will be generated per thread, to offset philox counter in thc random state
  int64_t counter_offset = ((totalElements - 1) / (block_size * grid.x * UNROLL) + 1) * UNROLL;
  std::pair<uint64_t, uint64_t> rng_engine_inputs;
  {
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(gen->mutex_);
    rng_engine_inputs = gen->philox_engine_inputs(counter_offset);
  }

  apex_dropout_add_kernel<scalar_t, accscalar_t, IndexType>
      <<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
          inputs, add_inputs, outputs, mask, totalElements, p, rng_engine_inputs);
  THCudaCheck(cudaGetLastError());
}

template <typename scalar_t, typename accscalar_t, typename IndexType>
void apex_add_cuda(scalar_t const *inputs,
                   scalar_t const *add_inputs,
                   scalar_t       *outputs,
                   IndexType       totalElements)
{
  int  block_size = 256;
  dim3 dim_block(block_size);
  dim3 grid((totalElements + block_size - 1) / block_size);
  unsigned int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor / block_size;
  grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, grid.x);

  apex_add_kernel<scalar_t, accscalar_t, IndexType>
      <<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
          inputs, add_inputs, outputs, totalElements);
  THCudaCheck(cudaGetLastError());
}

template <typename scalar_t, typename accscalar_t, typename IndexType>
void apex_masked_scale_cuda(scalar_t const *inputs,
                            scalar_t       *outputs,
                            uint8_t const  *mask,
                            IndexType       totalElements,
                            accscalar_t     scale)
{
  int  block_size = 256;
  dim3 dim_block(block_size);
  dim3 grid((totalElements + block_size - 1) / block_size);
  unsigned int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor / block_size;
  grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, grid.x);

  apex_masked_scale_kernel<scalar_t, accscalar_t, IndexType>
      <<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
          inputs, outputs, mask, totalElements, scale);
  THCudaCheck(cudaGetLastError());
}
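As a reading aid (not part of the commit), the element-wise math these kernels implement can be stated in a few lines of PyTorch; keep_prob stands for the (1.0f - dropout_prob) value the attention code below passes in, and these reference helpers are illustrative only:

# Reference semantics of the kernels above, in plain PyTorch (illustrative only).
import torch

def fused_dropout_reference(x, keep_prob):
    # apex_fused_dropout_kernel: mask[i] = (u < keep_prob); out[i] = x[i] * mask[i] / keep_prob
    mask = (torch.rand_like(x, dtype=torch.float32) < keep_prob).to(torch.uint8)
    out = x * mask.to(x.dtype) * (1.0 / keep_prob)
    return out, mask

def dropout_add_reference(x, residual, keep_prob):
    # apex_dropout_add_kernel: dropout(x) + residual, returning the mask as well
    out, mask = fused_dropout_reference(x, keep_prob)
    return out + residual, mask

def masked_scale_reference(grad, mask, keep_prob):
    # apex_masked_scale_kernel (used in backward): grad * mask / keep_prob
    return grad * mask.to(grad.dtype) * (1.0 / keep_prob)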
apex/contrib/csrc/multihead_attn/encdec_multihead_attn.cpp (new file, mode 100644)

#include <torch/extension.h>
#include <vector>

namespace multihead_attn {
namespace encdec {
namespace cublas_gemmex {

std::vector<torch::Tensor> fwd_cuda(
    bool use_time_mask, bool is_training, int heads,
    torch::Tensor const& inputs_q, torch::Tensor const& inputs_kv,
    torch::Tensor const& input_weights_q, torch::Tensor const& input_weights_kv,
    torch::Tensor const& output_weights,
    const uint8_t* pad_mask, float dropout_prob);

std::vector<torch::Tensor> bwd_cuda(
    int heads,
    torch::Tensor const& output_grads, torch::Tensor const& matmul2_results,
    torch::Tensor const& dropout_results, torch::Tensor const& softmax_results,
    torch::Tensor const& input_lin_q_results, torch::Tensor const& input_lin_kv_results,
    torch::Tensor const& inputs_q, torch::Tensor const& inputs_kv,
    torch::Tensor const& input_weights_q, torch::Tensor const& input_weights_kv,
    torch::Tensor const& output_weights,
    torch::Tensor const& dropout_mask, float dropout_prob);

// C++ interface
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

std::vector<torch::Tensor> fwd(
    bool use_mask, bool use_time_mask, bool is_training, int heads,
    torch::Tensor const& inputs_q, torch::Tensor const& inputs_kv,
    torch::Tensor const& input_weights_q, torch::Tensor const& input_weights_kv,
    torch::Tensor const& output_weights,
    torch::Tensor const& pad_mask, float dropout_prob)
{
  AT_ASSERTM(inputs_q.dim()         == 3, "expected 3D tensor");
  AT_ASSERTM(inputs_kv.dim()        == 3, "expected 3D tensor");
  AT_ASSERTM(input_weights_q.dim()  == 2, "expected 2D tensor");
  AT_ASSERTM(input_weights_kv.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(output_weights.dim()   == 2, "expected 2D tensor");

  AT_ASSERTM(inputs_q.type().scalarType()         == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(inputs_kv.type().scalarType()        == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(input_weights_q.type().scalarType()  == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(input_weights_kv.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType()   == at::ScalarType::Half, "Only HALF is supported");

  if (use_mask) {
    AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
    AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
  }

  return fwd_cuda(use_time_mask, is_training, heads,
                  inputs_q, inputs_kv,
                  input_weights_q, input_weights_kv, output_weights,
                  use_mask ? static_cast<const uint8_t*>(pad_mask.data_ptr()) : nullptr,
                  dropout_prob);
}

std::vector<torch::Tensor> bwd(
    int heads,
    torch::Tensor const& output_grads, torch::Tensor const& matmul2_results,
    torch::Tensor const& dropout_results, torch::Tensor const& softmax_results,
    torch::Tensor const& input_lin_q_results, torch::Tensor const& input_lin_kv_results,
    torch::Tensor const& inputs_q, torch::Tensor const& inputs_kv,
    torch::Tensor const& input_weights_q, torch::Tensor const& input_weights_kv,
    torch::Tensor const& output_weights,
    torch::Tensor const& dropout_mask, float dropout_prob)
{
  AT_ASSERTM(output_grads.dim()         == 3, "expected 3D tensor");
  AT_ASSERTM(matmul2_results.dim()      == 3, "expected 3D tensor");
  AT_ASSERTM(dropout_results.dim()      == 3, "expected 3D tensor");
  AT_ASSERTM(softmax_results.dim()      == 3, "expected 3D tensor");
  AT_ASSERTM(input_lin_q_results.dim()  == 3, "expected 3D tensor");
  AT_ASSERTM(input_lin_kv_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(inputs_q.dim()             == 3, "expected 3D tensor");
  AT_ASSERTM(inputs_kv.dim()            == 3, "expected 3D tensor");
  AT_ASSERTM(input_weights_q.dim()      == 2, "expected 2D tensor");
  AT_ASSERTM(input_weights_kv.dim()     == 2, "expected 2D tensor");
  AT_ASSERTM(output_weights.dim()       == 2, "expected 2D tensor");
  AT_ASSERTM(dropout_mask.dim()         == 3, "expected 3D tensor");

  AT_ASSERTM(output_grads.type().scalarType()         == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(matmul2_results.type().scalarType()      == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(dropout_results.type().scalarType()      == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(softmax_results.type().scalarType()      == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(input_lin_q_results.type().scalarType()  == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(input_lin_kv_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(inputs_q.type().scalarType()             == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(inputs_kv.type().scalarType()            == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(input_weights_q.type().scalarType()      == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(input_weights_kv.type().scalarType()     == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType()       == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(dropout_mask.type().scalarType()         == at::ScalarType::Byte, "Only BYTE is supported");

  return bwd_cuda(heads,
                  output_grads, matmul2_results, dropout_results, softmax_results,
                  input_lin_q_results, input_lin_kv_results,
                  inputs_q, inputs_kv,
                  input_weights_q, input_weights_kv, output_weights,
                  dropout_mask, dropout_prob);
}

} // end namespace cublas_gemmex
} // end namespace encdec
} // end namespace multihead_attn

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
  m.def("forward",  &multihead_attn::encdec::cublas_gemmex::fwd,  "Encdec Multihead Attention Forward.");
  m.def("backward", &multihead_attn::encdec::cublas_gemmex::bwd,  "Encdec Multihead Attention Backward.");
}
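A hypothetical sketch of building and calling this binding from Python with torch.utils.cpp_extension.load; the module name, source list, include path, and tensor shapes here are assumptions inferred from the dimension and dtype checks above, not something this commit specifies:

# Illustrative build-and-call sketch (assumed names and shapes, not from the commit).
import torch
from torch.utils.cpp_extension import load

encdec_attn = load(
    name="encdec_multihead_attn",
    sources=["apex/contrib/csrc/multihead_attn/encdec_multihead_attn.cpp",
             "apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu"],
    extra_include_paths=["apex/contrib/csrc/multihead_attn/cutlass"],
    verbose=True,
)

heads, embed_dim, q_len, k_len, bsz = 8, 512, 10, 12, 4
dev, half = "cuda", torch.float16
inputs_q   = torch.randn(q_len, bsz, embed_dim, device=dev, dtype=half)      # 3D, HALF
inputs_kv  = torch.randn(k_len, bsz, embed_dim, device=dev, dtype=half)
weights_q  = torch.randn(embed_dim, embed_dim, device=dev, dtype=half)       # 2D, HALF
weights_kv = torch.randn(2 * embed_dim, embed_dim, device=dev, dtype=half)   # K and V packed together
weights_o  = torch.randn(embed_dim, embed_dim, device=dev, dtype=half)
pad_mask   = torch.zeros(bsz, k_len, device=dev, dtype=torch.uint8)          # 2D BYTE, per the checks

results = encdec_attn.forward(
    True,    # use_mask
    False,   # use_time_mask
    True,    # is_training
    heads,
    inputs_q, inputs_kv,
    weights_q, weights_kv, weights_o,
    pad_mask,
    0.1,     # dropout_prob
)
# returns [input_lin_q_results, input_lin_kv_results, softmax_results,
#          dropout_results, dropout_mask, matmul2_results, outputs]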
apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu (new file, mode 100644)

#include <vector>
#include <iostream>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include "THC/THC.h"
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <math.h>

#include "strided_batched_gemm.h"
#include "softmax.h"
#include "dropout.h"
#include "layer_norm.h"

// symbol to be automatically resolved by PyTorch libs
extern THCState *state;

namespace multihead_attn {
namespace encdec {
namespace cublas_gemmex {

std::vector<torch::Tensor> fwd_cuda(
    bool use_time_mask, bool is_training, int heads,
    torch::Tensor const& inputs_q, torch::Tensor const& inputs_kv,
    torch::Tensor const& input_weights_q, torch::Tensor const& input_weights_kv,
    torch::Tensor const& output_weights,
    const uint8_t* pad_mask, float dropout_prob)
{
  const int embed_dim         = inputs_q.size(2);
  const int sequences         = inputs_q.size(1);
  const int q_seq_len         = inputs_q.size(0);
  const int k_seq_len         = inputs_kv.size(0);
  const int batches_q         = sequences * q_seq_len;
  const int batches_kv        = sequences * k_seq_len;
  const int head_dim          = embed_dim / heads;
  const int output_lin_q_dim  = embed_dim;
  const int output_lin_kv_dim = 2 * embed_dim;
  const int attn_batches      = heads * sequences;
  const int lead_dim_q        = attn_batches * head_dim;
  const int lead_dim_kv       = attn_batches * 2 * head_dim;
  const int batch_stride_q    = head_dim;
  const int batch_stride_kv   = 2 * head_dim;
  const int dropout_elems     = attn_batches * q_seq_len * k_seq_len;
  const float alpha           = 1.0;
  const float beta            = 0.0;
  const float scale           = 1.0 / sqrt(static_cast<float>(head_dim));

  // There is no reason to use more than one stream as every kernel is
  // sequentially dependent
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t   stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);

  // 3 Intermediate Results + Output (Note: dropout intermediates are generated by ATen library code)
  auto act_options  = inputs_q.options().requires_grad(false);
  auto mask_options = act_options.dtype(torch::kUInt8);

  torch::Tensor input_lin_q_results  = torch::empty({q_seq_len, sequences, output_lin_q_dim},  act_options);
  torch::Tensor input_lin_kv_results = torch::empty({k_seq_len, sequences, output_lin_kv_dim}, act_options);
  torch::Tensor softmax_results      = torch::empty({attn_batches, q_seq_len, k_seq_len},      act_options);
  torch::Tensor dropout_results      = torch::empty({attn_batches, q_seq_len, k_seq_len},      act_options);
  torch::Tensor dropout_mask         = torch::empty({attn_batches, q_seq_len, k_seq_len},      mask_options);
  torch::Tensor matmul2_results      = torch::empty({q_seq_len, attn_batches, head_dim},       act_options);
  torch::Tensor outputs              = torch::empty_like(inputs_q, act_options);

  // Input Linear Results Pointers to Q, K, and V of interviewed activations
  void* q_lin_results_ptr = static_cast<void*>(input_lin_q_results.data_ptr());
  void* k_lin_results_ptr = static_cast<void*>(input_lin_kv_results.data_ptr());
  void* v_lin_results_ptr = static_cast<void*>(static_cast<half*>(input_lin_kv_results.data_ptr()) + head_dim);

  // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
  void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());

  char a_layout_t{'t'};
  char a_layout_n{'n'};
  char b_layout_n{'n'};

  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));

  // Input Linear Q Fwd
  THCublasCheck(cublasGemmEx(handle,
                             CUBLAS_OP_T, CUBLAS_OP_N,
                             output_lin_q_dim, batches_q, embed_dim,
                             static_cast<const void*>(&alpha),
                             static_cast<const void*>(input_weights_q.data_ptr()), CUDA_R_16F, embed_dim,
                             static_cast<const void*>(inputs_q.data_ptr()),        CUDA_R_16F, embed_dim,
                             static_cast<const void*>(&beta),
                             q_lin_results_ptr, CUDA_R_16F, output_lin_q_dim,
                             CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));

  // Input Linear KV Fwd
  THCublasCheck(cublasGemmEx(handle,
                             CUBLAS_OP_T, CUBLAS_OP_N,
                             output_lin_kv_dim, batches_kv, embed_dim,
                             static_cast<const void*>(&alpha),
                             static_cast<const void*>(input_weights_kv.data_ptr()), CUDA_R_16F, embed_dim,
                             static_cast<const void*>(inputs_kv.data_ptr()),        CUDA_R_16F, embed_dim,
                             static_cast<const void*>(&beta),
                             k_lin_results_ptr, CUDA_R_16F, output_lin_kv_dim,
                             CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));

  // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
  gemm_switch_fp32accum(state, a_layout_t, b_layout_n,
                        k_seq_len, q_seq_len, head_dim,
                        scale,
                        static_cast<const half*>(k_lin_results_ptr), lead_dim_kv, batch_stride_kv,
                        static_cast<const half*>(q_lin_results_ptr), lead_dim_q,  batch_stride_q,
                        beta,
                        static_cast<half*>(softmax_results_ptr), k_seq_len, k_seq_len * q_seq_len,
                        attn_batches);

  // Padded Softmax
  bool softmax_success = false;
  if (pad_mask == nullptr) {
    softmax_success = dispatch_softmax<half, half, float>(
        reinterpret_cast<half*>(softmax_results_ptr),
        reinterpret_cast<const half*>(softmax_results_ptr),
        k_seq_len, k_seq_len, attn_batches * q_seq_len);
  } else {
    if (use_time_mask) {
      softmax_success = dispatch_time_masked_softmax<half, half, float>(
          reinterpret_cast<half*>(softmax_results_ptr),
          reinterpret_cast<const half*>(softmax_results_ptr),
          pad_mask,
          k_seq_len, k_seq_len, attn_batches * q_seq_len, q_seq_len);
    } else {
      softmax_success = dispatch_masked_softmax<half, half, float>(
          reinterpret_cast<half*>(softmax_results_ptr),
          reinterpret_cast<const half*>(softmax_results_ptr),
          pad_mask,
          k_seq_len, k_seq_len, attn_batches * q_seq_len, attn_batches * q_seq_len / sequences);
    }
  }
  assert(softmax_success);

  if (is_training) {
    apex_fused_dropout_cuda<half, float, uint32_t>(
        static_cast<half const*>(softmax_results.data_ptr()),
        static_cast<half*>(dropout_results.data_ptr()),
        static_cast<uint8_t*>(dropout_mask.data_ptr()),
        dropout_elems,
        (1.0f - dropout_prob));
  }

  // Matmul2
  gemm_switch_fp32accum(state, a_layout_n, b_layout_n,
                        head_dim, q_seq_len, k_seq_len,
                        alpha,
                        static_cast<const half*>(v_lin_results_ptr), lead_dim_kv, batch_stride_kv,
                        (is_training) ? static_cast<const half*>(dropout_results.data_ptr())
                                      : static_cast<const half*>(softmax_results.data_ptr()),
                        k_seq_len, k_seq_len * q_seq_len,
                        beta,
                        static_cast<half*>(matmul2_results.data_ptr()), head_dim * attn_batches, head_dim,
                        attn_batches);

  // Output Linear
  THCublasCheck(cublasGemmEx(handle,
                             CUBLAS_OP_T, CUBLAS_OP_N,
                             embed_dim, batches_q, embed_dim,
                             static_cast<const void*>(&alpha),
                             static_cast<const void*>(output_weights.data_ptr()),  CUDA_R_16F, embed_dim,
                             static_cast<const void*>(matmul2_results.data_ptr()), CUDA_R_16F, embed_dim,
                             static_cast<const void*>(&beta),
                             static_cast<void*>(outputs.data_ptr()), CUDA_R_16F, embed_dim,
                             CUDA_R_32F,
                             //CUBLAS_GEMM_ALGO1_TENSOR_OP));
                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));

  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));

  return {input_lin_q_results,
          input_lin_kv_results,
          softmax_results,
          dropout_results,
          dropout_mask,
          matmul2_results,
          outputs};
}

std::vector<torch::Tensor> bwd_cuda(
    int heads,
    torch::Tensor const& output_grads, torch::Tensor const& matmul2_results,
    torch::Tensor const& dropout_results, torch::Tensor const& softmax_results,
    torch::Tensor const& input_lin_q_results, torch::Tensor const& input_lin_kv_results,
    torch::Tensor const& inputs_q, torch::Tensor const& inputs_kv,
    torch::Tensor const& input_weights_q, torch::Tensor const& input_weights_kv,
    torch::Tensor const& output_weights,
    torch::Tensor const& dropout_mask, float dropout_prob)
{
  const int embed_dim         = inputs_q.size(2);
  const int sequences         = inputs_q.size(1);
  const int q_seq_len         = inputs_q.size(0);
  const int k_seq_len         = inputs_kv.size(0);
  const int batches_q         = sequences * q_seq_len;
  const int batches_kv        = sequences * k_seq_len;
  const int head_dim          = embed_dim / heads;
  const int output_lin_q_dim  = embed_dim;
  const int output_lin_kv_dim = 2 * embed_dim;
  const int attn_batches      = heads * sequences;
  const int lead_dim_q        = attn_batches * head_dim;
  const int lead_dim_kv       = attn_batches * 2 * head_dim;
  const int batch_stride_q    = head_dim;
  const int batch_stride_kv   = 2 * head_dim;
  const int dropout_elems     = attn_batches * q_seq_len * k_seq_len;
  const float alpha           = 1.0;
  const float beta            = 0.0;
  const float scale           = 1.0 / sqrt(static_cast<float>(head_dim));

  // TODO: Streams can be used in Backprop but I haven't added more than one
  // in my first attempt to create the code
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t   stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);

  // Output Tensor Allocations
  torch::Tensor input_q_grads         = torch::empty_like(inputs_q);
  torch::Tensor input_kv_grads        = torch::empty_like(inputs_kv);
  torch::Tensor input_weight_q_grads  = torch::empty_like(input_weights_q);
  torch::Tensor input_weight_kv_grads = torch::empty_like(input_weights_kv);
  torch::Tensor output_weight_grads   = torch::empty_like(output_weights);

  // Intermediate Tensor Allocations
  at::Tensor output_lin_grads          = torch::empty_like(matmul2_results);
  at::Tensor matmul2_grads             = torch::empty_like(dropout_results);
  at::Tensor input_lin_q_output_grads  = torch::empty_like(input_lin_q_results);
  at::Tensor input_lin_kv_output_grads = torch::empty_like(input_lin_kv_results);

  auto q_lin_results_ptr = static_cast<half*>(input_lin_q_results.data_ptr());
  auto k_lin_results_ptr = static_cast<half*>(input_lin_kv_results.data_ptr());
  auto v_lin_results_ptr = static_cast<half*>(input_lin_kv_results.data_ptr()) + head_dim;

  auto q_lin_grads_ptr = static_cast<half*>(input_lin_q_output_grads.data_ptr());
  auto k_lin_grads_ptr = static_cast<half*>(input_lin_kv_output_grads.data_ptr());
  auto v_lin_grads_ptr = static_cast<half*>(input_lin_kv_output_grads.data_ptr()) + head_dim;

  char a_layout_n{'n'};
  char a_layout_t{'t'};
  char b_layout_n{'n'};
  char b_layout_t{'t'};

  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));

  // Output Linear Dgrad
  THCublasCheck(cublasGemmEx(handle,
                             CUBLAS_OP_N, CUBLAS_OP_N,
                             embed_dim, batches_q, embed_dim,
                             static_cast<const void*>(&alpha),
                             static_cast<const void*>(output_weights.data_ptr()), CUDA_R_16F, embed_dim,
                             static_cast<const void*>(output_grads.data_ptr()),   CUDA_R_16F, embed_dim,
                             static_cast<const void*>(&beta),
                             static_cast<void*>(output_lin_grads.data_ptr()), CUDA_R_16F, embed_dim,
                             CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));

  // Output Linear Wgrad
  THCublasCheck(cublasGemmEx(handle,
                             CUBLAS_OP_N, CUBLAS_OP_T,
                             embed_dim, embed_dim, batches_q,
                             static_cast<const void*>(&alpha),
                             static_cast<const void*>(matmul2_results.data_ptr()), CUDA_R_16F, embed_dim,
                             static_cast<const void*>(output_grads.data_ptr()),    CUDA_R_16F, embed_dim,
                             static_cast<const void*>(&beta),
                             static_cast<void*>(output_weight_grads.data_ptr()), CUDA_R_16F, embed_dim,
                             CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));

  // MatMul2 Dgrad1
  gemm_switch_fp32accum(state, a_layout_t, b_layout_n,
                        k_seq_len, q_seq_len, head_dim,
                        alpha,
                        static_cast<const half*>(v_lin_results_ptr), lead_dim_kv, batch_stride_kv,
                        static_cast<const half*>(output_lin_grads.data_ptr()), head_dim * attn_batches, head_dim,
                        beta,
                        static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len,
                        attn_batches);

  // Matmul2 Dgrad2
  gemm_switch_fp32accum(state, a_layout_n, b_layout_t,
                        head_dim, k_seq_len, q_seq_len,
                        alpha,
                        static_cast<const half*>(output_lin_grads.data_ptr()), head_dim * attn_batches, head_dim,
                        static_cast<const half*>(dropout_results.data_ptr()), k_seq_len, k_seq_len * q_seq_len,
                        beta,
                        v_lin_grads_ptr, lead_dim_kv, batch_stride_kv,
                        attn_batches);

  // Apply Dropout Mask and Scale by Dropout Probability
  apex_masked_scale_cuda<half, float, uint32_t>(
      static_cast<half const*>(matmul2_grads.data_ptr()),
      static_cast<half*>(matmul2_grads.data_ptr()),
      static_cast<uint8_t const*>(dropout_mask.data_ptr()),
      dropout_elems,
      (1.0 / (1.0 - dropout_prob)));

  // Softmax Grad
  bool softmax_success = false;
  softmax_success = dispatch_softmax_backward<half, half, float>(
      static_cast<half*>(matmul2_grads.data_ptr()),
      static_cast<half*>(matmul2_grads.data_ptr()),
      reinterpret_cast<half const*>(softmax_results.data_ptr()),
      k_seq_len, k_seq_len, attn_batches * q_seq_len);
  assert(softmax_success);

  // Matmul1 Dgrad1
  gemm_switch_fp32accum(state, a_layout_n, b_layout_n,
                        head_dim, q_seq_len, k_seq_len,
                        scale,
                        k_lin_results_ptr, lead_dim_kv, batch_stride_kv,
                        static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len,
                        beta,
                        q_lin_grads_ptr, lead_dim_q, batch_stride_q,
                        attn_batches);

  // Matmul1 Dgrad2
  gemm_switch_fp32accum(state, a_layout_n, b_layout_t,
                        head_dim, k_seq_len, q_seq_len,
                        scale,
                        q_lin_results_ptr, lead_dim_q, batch_stride_q,
                        static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len,
                        beta,
                        k_lin_grads_ptr, lead_dim_kv, batch_stride_kv,
                        attn_batches);

  // Input Linear Q Dgrad
  THCublasCheck(cublasGemmEx(handle,
                             CUBLAS_OP_N, CUBLAS_OP_N,
                             embed_dim, batches_q, output_lin_q_dim,
                             static_cast<const void*>(&alpha),
                             static_cast<const void*>(input_weights_q.data_ptr()), CUDA_R_16F, embed_dim,
                             static_cast<const void*>(q_lin_grads_ptr),            CUDA_R_16F, output_lin_q_dim,
                             static_cast<const void*>(&beta),
                             static_cast<void*>(input_q_grads.data_ptr()), CUDA_R_16F, embed_dim,
                             CUDA_R_32F,
                             //CUBLAS_GEMM_ALGO10_TENSOR_OP));
                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));

  // Input Linear Q Wgrad
  THCublasCheck(cublasGemmEx(handle,
                             CUBLAS_OP_N, CUBLAS_OP_T,
                             embed_dim, output_lin_q_dim, batches_q,
                             static_cast<const void*>(&alpha),
                             static_cast<const void*>(inputs_q.data_ptr()), CUDA_R_16F, embed_dim,
                             static_cast<const void*>(q_lin_grads_ptr),     CUDA_R_16F, output_lin_q_dim,
                             static_cast<const void*>(&beta),
                             static_cast<void*>(input_weight_q_grads.data_ptr()), CUDA_R_16F, embed_dim,
                             CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));

  // Input Linear KV Dgrad
  THCublasCheck(cublasGemmEx(handle,
                             CUBLAS_OP_N, CUBLAS_OP_N,
                             embed_dim, batches_kv, output_lin_kv_dim,
                             static_cast<const void*>(&alpha),
                             static_cast<const void*>(input_weights_kv.data_ptr()), CUDA_R_16F, embed_dim,
                             static_cast<const void*>(k_lin_grads_ptr),             CUDA_R_16F, output_lin_kv_dim,
                             static_cast<const void*>(&beta),
                             static_cast<void*>(input_kv_grads.data_ptr()), CUDA_R_16F, embed_dim,
                             CUDA_R_32F,
                             //CUBLAS_GEMM_ALGO10_TENSOR_OP));
                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));

  // Input Linear KV Wgrad
  THCublasCheck(cublasGemmEx(handle,
                             CUBLAS_OP_N, CUBLAS_OP_T,
                             embed_dim, output_lin_kv_dim, batches_kv,
                             static_cast<const void*>(&alpha),
                             static_cast<const void*>(inputs_kv.data_ptr()), CUDA_R_16F, embed_dim,
                             static_cast<const void*>(k_lin_grads_ptr),      CUDA_R_16F, output_lin_kv_dim,
                             static_cast<const void*>(&beta),
                             static_cast<void*>(input_weight_kv_grads.data_ptr()), CUDA_R_16F, embed_dim,
                             CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));

  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));

  return {input_q_grads,
          input_kv_grads,
          input_weight_q_grads,
          input_weight_kv_grads,
          output_weight_grads};
}

} // end namespace cublas_gemmex
} // end namespace encdec
} // end namespace multihead_attn
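For orientation (this is not part of the commit), the chain of GEMMs and the softmax dispatched in fwd_cuda computes the same attention math as the unfused PyTorch sketch below. Tensors follow the [seq_len, batch, embed_dim] layout fwd_cuda expects, K and V are taken as packed per head (matching batch_stride_kv = 2 * head_dim), dropout is omitted, and no bit-exact match with the half-precision fused path is claimed:

# Unfused reference of fwd_cuda's math (illustrative only; dropout omitted).
import torch
import torch.nn.functional as F

def encdec_attn_reference(inputs_q, inputs_kv, weights_q, weights_kv, weights_o, heads):
    q_len, bsz, embed_dim = inputs_q.shape
    k_len = inputs_kv.shape[0]
    head_dim = embed_dim // heads
    scale = 1.0 / head_dim ** 0.5

    # "Input Linear Q Fwd" / "Input Linear KV Fwd" (weights are [out_features, in_features])
    q = F.linear(inputs_q, weights_q)                                 # [q_len, bsz, embed_dim]
    kv = F.linear(inputs_kv, weights_kv)                              # [k_len, bsz, 2*embed_dim]

    # Heads are laid out as attn_batches = heads * sequences, with K and V packed per head.
    q = q.view(q_len, bsz * heads, head_dim).transpose(0, 1)          # [bsz*heads, q_len, head_dim]
    kv = kv.view(k_len, bsz * heads, 2, head_dim)
    k = kv[:, :, 0, :].transpose(0, 1)                                # [bsz*heads, k_len, head_dim]
    v = kv[:, :, 1, :].transpose(0, 1)

    # "MatMul1 ... Plus scaling by 1/Sqrt(head size)" followed by the padded softmax
    attn = torch.bmm(q, k.transpose(1, 2)) * scale                    # [bsz*heads, q_len, k_len]
    attn = F.softmax(attn, dim=-1)

    # "Matmul2" and the output projection ("Output Linear")
    ctx = torch.bmm(attn, v)                                          # [bsz*heads, q_len, head_dim]
    ctx = ctx.transpose(0, 1).contiguous().view(q_len, bsz, embed_dim)
    return F.linear(ctx, weights_o)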
apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add.cpp (new file, mode 100644)

#include <torch/extension.h>
#include <vector>

namespace multihead_attn {
namespace encdec_norm_add {
namespace cublas_gemmex {

std::vector<torch::Tensor> fwd_cuda(
    bool use_time_mask, bool is_training, int heads,
    torch::Tensor const& inputs_q, torch::Tensor const& inputs_kv,
    torch::Tensor const& lyr_nrm_gamma_weights, torch::Tensor const& lyr_nrm_beta_weights,
    torch::Tensor const& input_weights_q, torch::Tensor const& input_weights_kv,
    torch::Tensor const& output_weights,
    const uint8_t* pad_mask, float dropout_prob);

std::vector<torch::Tensor> bwd_cuda(
    int heads,
    torch::Tensor const& output_grads, torch::Tensor const& matmul2_results,
    torch::Tensor const& dropout_results, torch::Tensor const& softmax_results,
    torch::Tensor const& input_lin_q_results, torch::Tensor const& input_lin_kv_results,
    torch::Tensor const& lyr_nrm_results, torch::Tensor const& lyr_nrm_mean,
    torch::Tensor const& lyr_nrm_invvar,
    torch::Tensor const& inputs_q, torch::Tensor const& inputs_kv,
    torch::Tensor const& lyr_nrm_gamma_weights, torch::Tensor const& lyr_nrm_beta_weights,
    torch::Tensor const& input_weights_q, torch::Tensor const& input_weights_kv,
    torch::Tensor const& output_weights,
    torch::Tensor const& dropout_mask, torch::Tensor const& dropout_add_mask,
    float dropout_prob);

// C++ interface
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

std::vector<torch::Tensor> fwd(
    bool use_mask, bool use_time_mask, bool is_training, int heads,
    torch::Tensor const& inputs_q, torch::Tensor const& inputs_kv,
    torch::Tensor const& lyr_nrm_gamma_weights, torch::Tensor const& lyr_nrm_beta_weights,
    torch::Tensor const& input_weights_q, torch::Tensor const& input_weights_kv,
    torch::Tensor const& output_weights,
    torch::Tensor const& pad_mask, float dropout_prob)
{
  AT_ASSERTM(inputs_q.dim()              == 3, "expected 3D tensor");
  AT_ASSERTM(inputs_kv.dim()             == 3, "expected 3D tensor");
  AT_ASSERTM(lyr_nrm_gamma_weights.dim() == 1, "expected 1D tensor");
  AT_ASSERTM(lyr_nrm_beta_weights.dim()  == 1, "expected 1D tensor");
  AT_ASSERTM(input_weights_q.dim()       == 2, "expected 2D tensor");
  AT_ASSERTM(input_weights_kv.dim()      == 2, "expected 2D tensor");
  AT_ASSERTM(output_weights.dim()        == 2, "expected 2D tensor");

  AT_ASSERTM(inputs_q.type().scalarType()              == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(inputs_kv.type().scalarType()             == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(lyr_nrm_gamma_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(lyr_nrm_beta_weights.type().scalarType()  == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(input_weights_q.type().scalarType()       == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(input_weights_kv.type().scalarType()      == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType()        == at::ScalarType::Half, "Only HALF is supported");

  if (use_mask) {
    AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
    AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
  }

  return fwd_cuda(use_time_mask, is_training, heads,
                  inputs_q, inputs_kv,
                  lyr_nrm_gamma_weights, lyr_nrm_beta_weights,
                  input_weights_q, input_weights_kv, output_weights,
                  use_mask ? static_cast<const uint8_t*>(pad_mask.data_ptr()) : nullptr,
                  dropout_prob);
}

std::vector<torch::Tensor> bwd(
    int heads,
    torch::Tensor const& output_grads, torch::Tensor const& matmul2_results,
    torch::Tensor const& dropout_results, torch::Tensor const& softmax_results,
    torch::Tensor const& input_lin_q_results, torch::Tensor const& input_lin_kv_results,
    torch::Tensor const& lyr_nrm_results, torch::Tensor const& lyr_nrm_mean,
    torch::Tensor const& lyr_nrm_invvar,
    torch::Tensor const& inputs_q, torch::Tensor const& inputs_kv,
    torch::Tensor const& lyr_nrm_gamma_weights, torch::Tensor const& lyr_nrm_beta_weights,
    torch::Tensor const& input_weights_q, torch::Tensor const& input_weights_kv,
    torch::Tensor const& output_weights,
    torch::Tensor const& dropout_mask, torch::Tensor const& dropout_add_mask,
    float dropout_prob)
{
  AT_ASSERTM(output_grads.dim()          == 3, "expected 3D tensor");
  AT_ASSERTM(matmul2_results.dim()       == 3, "expected 3D tensor");
  AT_ASSERTM(dropout_results.dim()       == 3, "expected 3D tensor");
  AT_ASSERTM(softmax_results.dim()       == 3, "expected 3D tensor");
  AT_ASSERTM(input_lin_q_results.dim()   == 3, "expected 3D tensor");
  AT_ASSERTM(input_lin_kv_results.dim()  == 3, "expected 3D tensor");
  AT_ASSERTM(lyr_nrm_results.dim()       == 3, "expected 3D tensor");
  AT_ASSERTM(lyr_nrm_mean.dim()          == 1, "expected 1D tensor");
  AT_ASSERTM(lyr_nrm_invvar.dim()        == 1, "expected 1D tensor");
  AT_ASSERTM(inputs_q.dim()              == 3, "expected 3D tensor");
  AT_ASSERTM(inputs_kv.dim()             == 3, "expected 3D tensor");
  AT_ASSERTM(lyr_nrm_gamma_weights.dim() == 1, "expected 1D tensor");
  AT_ASSERTM(lyr_nrm_beta_weights.dim()  == 1, "expected 1D tensor");
  AT_ASSERTM(input_weights_q.dim()       == 2, "expected 2D tensor");
  AT_ASSERTM(input_weights_kv.dim()      == 2, "expected 2D tensor");
  AT_ASSERTM(output_weights.dim()        == 2, "expected 2D tensor");
  AT_ASSERTM(dropout_mask.dim()          == 3, "expected 3D tensor");
  AT_ASSERTM(dropout_add_mask.dim()      == 3, "expected 3D tensor");

  AT_ASSERTM(output_grads.type().scalarType()          == at::ScalarType::Half,  "Only HALF is supported");
  AT_ASSERTM(matmul2_results.type().scalarType()       == at::ScalarType::Half,  "Only HALF is supported");
  AT_ASSERTM(dropout_results.type().scalarType()       == at::ScalarType::Half,  "Only HALF is supported");
  AT_ASSERTM(softmax_results.type().scalarType()       == at::ScalarType::Half,  "Only HALF is supported");
  AT_ASSERTM(input_lin_q_results.type().scalarType()   == at::ScalarType::Half,  "Only HALF is supported");
  AT_ASSERTM(input_lin_kv_results.type().scalarType()  == at::ScalarType::Half,  "Only HALF is supported");
  AT_ASSERTM(lyr_nrm_results.type().scalarType()       == at::ScalarType::Half,  "Only HALF is supported");
  AT_ASSERTM(lyr_nrm_mean.type().scalarType()          == at::ScalarType::Float, "Only FLOAT is supported");
  AT_ASSERTM(lyr_nrm_invvar.type().scalarType()        == at::ScalarType::Float, "Only FLOAT is supported");
  AT_ASSERTM(inputs_q.type().scalarType()              == at::ScalarType::Half,  "Only HALF is supported");
  AT_ASSERTM(inputs_kv.type().scalarType()             == at::ScalarType::Half,  "Only HALF is supported");
  AT_ASSERTM(lyr_nrm_gamma_weights.type().scalarType() == at::ScalarType::Half,  "Only HALF is supported");
  AT_ASSERTM(lyr_nrm_beta_weights.type().scalarType()  == at::ScalarType::Half,  "Only HALF is supported");
  AT_ASSERTM(input_weights_q.type().scalarType()       == at::ScalarType::Half,  "Only HALF is supported");
  AT_ASSERTM(input_weights_kv.type().scalarType()      == at::ScalarType::Half,  "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType()        == at::ScalarType::Half,  "Only HALF is supported");
  AT_ASSERTM(dropout_mask.type().scalarType()          == at::ScalarType::Byte,  "Only BYTE is supported");
  AT_ASSERTM(dropout_add_mask.type().scalarType()      == at::ScalarType::Byte,  "Only BYTE is supported");

  return bwd_cuda(heads,
                  output_grads, matmul2_results, dropout_results, softmax_results,
                  input_lin_q_results, input_lin_kv_results,
                  lyr_nrm_results, lyr_nrm_mean, lyr_nrm_invvar,
                  inputs_q, inputs_kv,
                  lyr_nrm_gamma_weights, lyr_nrm_beta_weights,
                  input_weights_q, input_weights_kv, output_weights,
                  dropout_mask, dropout_add_mask, dropout_prob);
}

} // end namespace cublas_gemmex
} // end namespace encdec_norm_add
} // end namespace multihead_attn

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
  m.def("forward",  &multihead_attn::encdec_norm_add::cublas_gemmex::fwd,
        "Encdec Multihead Attention Plus Layer Norm and Residual Add Forward.");
  m.def("backward", &multihead_attn::encdec_norm_add::cublas_gemmex::bwd,
        "Encdec Multihead Attention Plus Layer Norm and Residual Add Backward.");
}
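The *_norm_add variant folds the surrounding pre-LayerNorm and the end-of-block dropout plus residual add into the same fused call. In unfused PyTorch terms the block looks roughly like the sketch below (illustrative only; attn_block stands for any callable with the attention math shown earlier):

# What the fused norm-add path computes, stated in unfused PyTorch (illustrative).
import torch.nn.functional as F

def encdec_attn_norm_add_reference(x_q, x_kv, gamma, beta, attn_block, dropout_prob, training):
    # 1. pre-LayerNorm on the query stream (HostApplyLayerNorm in the .cu file, eps = 1e-5)
    normed = F.layer_norm(x_q, (x_q.size(-1),), weight=gamma, bias=beta, eps=1e-5)
    # 2. the encoder-decoder attention itself
    attn_out = attn_block(normed, x_kv)
    # 3. end-of-block dropout + residual add (apex_dropout_add_cuda / apex_add_cuda)
    return x_q + F.dropout(attn_out, p=dropout_prob, training=training)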
apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu (new file, mode 100644)

#include <vector>
#include <iostream>

#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include "THC/THC.h"
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <math.h>

#include "strided_batched_gemm.h"
#include "softmax.h"
#include "dropout.h"
#include "layer_norm.h"

// symbol to be automatically resolved by PyTorch libs
extern THCState *state;

namespace multihead_attn {
namespace encdec_norm_add {
namespace cublas_gemmex {

std::vector<torch::Tensor> fwd_cuda(
    bool use_time_mask, bool is_training, int heads,
    torch::Tensor const& inputs_q, torch::Tensor const& inputs_kv,
    torch::Tensor const& lyr_nrm_gamma_weights, torch::Tensor const& lyr_nrm_beta_weights,
    torch::Tensor const& input_weights_q, torch::Tensor const& input_weights_kv,
    torch::Tensor const& output_weights,
    const uint8_t* pad_mask, float dropout_prob)
{
  const int embed_dim         = inputs_q.size(2);
  const int sequences         = inputs_q.size(1);
  const int q_seq_len         = inputs_q.size(0);
  const int k_seq_len         = inputs_kv.size(0);
  const int batches_q         = sequences * q_seq_len;
  const int batches_kv        = sequences * k_seq_len;
  const int total_tokens_q    = batches_q * embed_dim;
  const int head_dim          = embed_dim / heads;
  const int output_lin_q_dim  = embed_dim;
  const int output_lin_kv_dim = 2 * embed_dim;
  const int attn_batches      = heads * sequences;
  const int lead_dim_q        = attn_batches * head_dim;
  const int lead_dim_kv       = attn_batches * 2 * head_dim;
  const int batch_stride_q    = head_dim;
  const int batch_stride_kv   = 2 * head_dim;
  const int dropout_elems     = attn_batches * q_seq_len * k_seq_len;
  const float alpha           = 1.0;
  const float beta            = 0.0;
  const float scale           = 1.0 / sqrt(static_cast<float>(head_dim));

  // There is no reason to use more than one stream as every kernel is
  // sequentially dependent
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t   stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);

  // 3 Intermediate Results + Output (Note: dropout intermediates are generated by ATen library code)
  auto act_options     = inputs_q.options().requires_grad(false);
  auto lyr_nrm_options = act_options.dtype(torch::kFloat32);
  auto mask_options    = act_options.dtype(torch::kUInt8);

  torch::Tensor lyr_nrm_mean         = torch::empty({batches_q}, lyr_nrm_options);
  torch::Tensor lyr_nrm_invvar       = torch::empty({batches_q}, lyr_nrm_options);
  torch::Tensor lyr_nrm_results      = torch::empty_like(inputs_q, act_options);
  torch::Tensor input_lin_q_results  = torch::empty({q_seq_len, sequences, output_lin_q_dim},  act_options);
  torch::Tensor input_lin_kv_results = torch::empty({k_seq_len, sequences, output_lin_kv_dim}, act_options);
  torch::Tensor softmax_results      = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_results      = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_mask         = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
  torch::Tensor matmul2_results      = torch::empty({q_seq_len, attn_batches, head_dim},  act_options);
  torch::Tensor output_lin_results   = torch::empty_like(inputs_q, act_options);
  torch::Tensor dropout_add_mask     = torch::empty_like(inputs_q, mask_options);
  torch::Tensor outputs              = torch::empty_like(inputs_q, act_options);

  // Input Linear Results Pointers to Q, K, and V of interviewed activations
  void* q_lin_results_ptr = static_cast<void*>(input_lin_q_results.data_ptr());
  void* k_lin_results_ptr = static_cast<void*>(input_lin_kv_results.data_ptr());
  void* v_lin_results_ptr = static_cast<void*>(static_cast<half*>(input_lin_kv_results.data_ptr()) + head_dim);

  // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
  void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());

  char a_layout_t{'t'};
  char a_layout_n{'n'};
  char b_layout_n{'n'};

  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));

  // Layer Norm
  HostApplyLayerNorm<at::Half, float>(
      static_cast<at::Half*>(lyr_nrm_results.data_ptr()),
      static_cast<float*>(lyr_nrm_mean.data_ptr()),
      static_cast<float*>(lyr_nrm_invvar.data_ptr()),
      static_cast<const at::Half*>(inputs_q.data_ptr()),
      static_cast<int>(batches_q),  // n1
      static_cast<int>(embed_dim),  // n2
      1.0e-5,
      static_cast<const at::Half*>(lyr_nrm_gamma_weights.data_ptr()),
      static_cast<const at::Half*>(lyr_nrm_beta_weights.data_ptr()));

  // Input Linear Q Fwd
  THCublasCheck(cublasGemmEx(handle,
                             CUBLAS_OP_T, CUBLAS_OP_N,
                             output_lin_q_dim, batches_q, embed_dim,
                             static_cast<const void*>(&alpha),
                             static_cast<const void*>(input_weights_q.data_ptr()), CUDA_R_16F, embed_dim,
                             //static_cast<const void*>(inputs_q.data_ptr()),
                             static_cast<const void*>(lyr_nrm_results.data_ptr()), CUDA_R_16F, embed_dim,
                             static_cast<const void*>(&beta),
                             q_lin_results_ptr, CUDA_R_16F, output_lin_q_dim,
                             CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));

  // Input Linear KV Fwd
  THCublasCheck(cublasGemmEx(handle,
                             CUBLAS_OP_T, CUBLAS_OP_N,
                             output_lin_kv_dim, batches_kv, embed_dim,
                             static_cast<const void*>(&alpha),
                             static_cast<const void*>(input_weights_kv.data_ptr()), CUDA_R_16F, embed_dim,
                             static_cast<const void*>(inputs_kv.data_ptr()),        CUDA_R_16F, embed_dim,
                             static_cast<const void*>(&beta),
                             k_lin_results_ptr, CUDA_R_16F, output_lin_kv_dim,
                             CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));

  // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
  gemm_switch_fp32accum(state, a_layout_t, b_layout_n,
                        k_seq_len, q_seq_len, head_dim,
                        scale,
                        static_cast<const half*>(k_lin_results_ptr), lead_dim_kv, batch_stride_kv,
                        static_cast<const half*>(q_lin_results_ptr), lead_dim_q,  batch_stride_q,
                        beta,
                        static_cast<half*>(softmax_results_ptr), k_seq_len, k_seq_len * q_seq_len,
                        attn_batches);

  // Padded Softmax
  bool softmax_success = false;
  if (pad_mask == nullptr) {
    softmax_success = dispatch_softmax<half, half, float>(
        reinterpret_cast<half*>(softmax_results_ptr),
        reinterpret_cast<const half*>(softmax_results_ptr),
        k_seq_len, k_seq_len, attn_batches * q_seq_len);
  } else {
    if (use_time_mask) {
      softmax_success = dispatch_time_masked_softmax<half, half, float>(
          reinterpret_cast<half*>(softmax_results_ptr),
          reinterpret_cast<const half*>(softmax_results_ptr),
          pad_mask,
          k_seq_len, k_seq_len, attn_batches * q_seq_len, q_seq_len);
    } else {
      softmax_success = dispatch_masked_softmax<half, half, float>(
          reinterpret_cast<half*>(softmax_results_ptr),
          reinterpret_cast<const half*>(softmax_results_ptr),
          pad_mask,
          k_seq_len, k_seq_len, attn_batches * q_seq_len, attn_batches * q_seq_len / sequences);
    }
  }
  assert(softmax_success);

  if (is_training) {
    apex_fused_dropout_cuda<half, float, uint32_t>(
        static_cast<half const*>(softmax_results.data_ptr()),
        static_cast<half*>(dropout_results.data_ptr()),
        static_cast<uint8_t*>(dropout_mask.data_ptr()),
        dropout_elems,
        (1.0f - dropout_prob));
  }

  // Matmul2
  gemm_switch_fp32accum(state, a_layout_n, b_layout_n,
                        head_dim, q_seq_len, k_seq_len,
                        alpha,
                        static_cast<const half*>(v_lin_results_ptr), lead_dim_kv, batch_stride_kv,
                        (is_training) ? static_cast<const half*>(dropout_results.data_ptr())
                                      : static_cast<const half*>(softmax_results.data_ptr()),
                        //static_cast<const half*>(dropout_results.data_ptr()),
                        k_seq_len, k_seq_len * q_seq_len,
                        beta,
                        static_cast<half*>(matmul2_results.data_ptr()), head_dim * attn_batches, head_dim,
                        attn_batches);

  // Output Linear
  THCublasCheck(cublasGemmEx(handle,
                             CUBLAS_OP_T, CUBLAS_OP_N,
                             embed_dim, batches_q, embed_dim,
                             static_cast<const void*>(&alpha),
                             static_cast<const void*>(output_weights.data_ptr()),  CUDA_R_16F, embed_dim,
                             static_cast<const void*>(matmul2_results.data_ptr()), CUDA_R_16F, embed_dim,
                             static_cast<const void*>(&beta),
                             static_cast<void*>(output_lin_results.data_ptr()), CUDA_R_16F, embed_dim,
                             CUDA_R_32F,
                             //CUBLAS_GEMM_ALGO1_TENSOR_OP));
                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));

  // End-of-block Dropout-Add
  if (is_training) {
    apex_dropout_add_cuda<half, float, uint32_t>(
        static_cast<half const*>(output_lin_results.data_ptr()),
        static_cast<half const*>(inputs_q.data_ptr()),
        static_cast<half*>(outputs.data_ptr()),
        static_cast<uint8_t*>(dropout_add_mask.data_ptr()),
        total_tokens_q,
        (1.0f - dropout_prob));
  } else {
    apex_add_cuda<half, float, uint32_t>(
        static_cast<half const*>(output_lin_results.data_ptr()),
        static_cast<half const*>(inputs_q.data_ptr()),
        static_cast<half*>(outputs.data_ptr()),
        total_tokens_q);
  }

  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));

  return {lyr_nrm_results,
          lyr_nrm_mean,
          lyr_nrm_invvar,
          input_lin_q_results,
          input_lin_kv_results,
          softmax_results,
          dropout_results,
          dropout_mask,
          matmul2_results,
          dropout_add_mask,
          outputs};
}

std::vector<torch::Tensor> bwd_cuda(
    int heads,
    torch::Tensor const& output_grads, torch::Tensor const& matmul2_results,
    torch::Tensor const& dropout_results, torch::Tensor const& softmax_results,
    torch::Tensor const& input_lin_q_results, torch::Tensor const& input_lin_kv_results,
    torch::Tensor const& lyr_nrm_results, torch::Tensor const& lyr_nrm_mean,
    torch::Tensor const& lyr_nrm_invvar,
    torch::Tensor const& inputs_q, torch::Tensor const& inputs_kv,
    torch::Tensor const& lyr_nrm_gamma_weights, torch::Tensor const& lyr_nrm_beta_weights,
    torch::Tensor const& input_weights_q, torch::Tensor const& input_weights_kv,
    torch::Tensor const& output_weights,
    torch::Tensor const& dropout_mask, torch::Tensor const& dropout_add_mask,
    float dropout_prob)
{
  const int embed_dim         = inputs_q.size(2);
  const int sequences         = inputs_q.size(1);
  const int q_seq_len         = inputs_q.size(0);
  const int k_seq_len         = inputs_kv.size(0);
  const int batches_q         = sequences * q_seq_len;
  const int batches_kv        = sequences * k_seq_len;
  const int total_tokens_q    = batches_q * embed_dim;
  const int head_dim          = embed_dim / heads;
  const int output_lin_q_dim  = embed_dim;
  const int output_lin_kv_dim = 2 * embed_dim;
  const int attn_batches      = heads * sequences;
  const int lead_dim_q        = attn_batches * head_dim;
  const int lead_dim_kv       = attn_batches * 2 * head_dim;
  const int batch_stride_q    = head_dim;
  const int batch_stride_kv   = 2 * head_dim;
  const int dropout_elems     = attn_batches * q_seq_len * k_seq_len;
  const float alpha           = 1.0;
  const float beta            = 0.0;
  const float scale           = 1.0 / sqrt(static_cast<float>(head_dim));

  // TODO: Streams can be used in Backprop but I haven't added more than one
  // in my first attempt to create the code
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t   stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);

  // Output Tensor Allocations
  torch::Tensor input_q_grads         = torch::empty_like(inputs_q);
  torch::Tensor input_kv_grads        = torch::empty_like(inputs_kv);
  torch::Tensor lyr_nrm_gamma_grads   = torch::empty_like(lyr_nrm_gamma_weights);
  torch::Tensor lyr_nrm_beta_grads    = torch::empty_like(lyr_nrm_beta_weights);
  torch::Tensor input_weight_q_grads  = torch::empty_like(input_weights_q);
  torch::Tensor input_weight_kv_grads = torch::empty_like(input_weights_kv);
  torch::Tensor output_weight_grads   = torch::empty_like(output_weights);

  // Intermediate Tensor Allocations
  at::Tensor output_lin_grads          = torch::empty_like(matmul2_results);
  at::Tensor matmul2_grads             = torch::empty_like(dropout_results);
  at::Tensor input_lin_q_output_grads  = torch::empty_like(input_lin_q_results);
  at::Tensor input_lin_kv_output_grads = torch::empty_like(input_lin_kv_results);
  at::Tensor input_lin_q_grads         = torch::empty_like(inputs_q);

  auto q_lin_results_ptr = static_cast<half*>(input_lin_q_results.data_ptr());
  auto k_lin_results_ptr = static_cast<half*>(input_lin_kv_results.data_ptr());
  auto v_lin_results_ptr = static_cast<half*>(input_lin_kv_results.data_ptr()) + head_dim;

  auto q_lin_grads_ptr = static_cast<half*>(input_lin_q_output_grads.data_ptr());
  auto k_lin_grads_ptr = static_cast<half*>(input_lin_kv_output_grads.data_ptr());
  auto v_lin_grads_ptr = static_cast<half*>(input_lin_kv_output_grads.data_ptr()) + head_dim;

  char a_layout_n{'n'};
  char a_layout_t{'t'};
  char b_layout_n{'n'};
  char b_layout_t{'t'};

  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));

  // Dropout Add Backward
  apex_masked_scale_cuda<half, float, uint32_t>(
      static_cast<half const*>(output_grads.data_ptr()),
      static_cast<half*>(output_grads.data_ptr()),
      static_cast<uint8_t const*>(dropout_add_mask.data_ptr()),
      total_tokens_q,
      (1.0 / (1.0 - dropout_prob)));

  // Output Linear Dgrad
  THCublasCheck(cublasGemmEx(handle,
                             CUBLAS_OP_N, CUBLAS_OP_N,
                             embed_dim, batches_q, embed_dim,
                             static_cast<const void*>(&alpha),
                             static_cast<const void*>(output_weights.data_ptr()), CUDA_R_16F, embed_dim,
                             static_cast<const void*>(output_grads.data_ptr()),   CUDA_R_16F, embed_dim,
                             static_cast<const void*>(&beta),
                             static_cast<void*>(output_lin_grads.data_ptr()), CUDA_R_16F, embed_dim,
                             CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));

  // Output Linear Wgrad
  THCublasCheck(cublasGemmEx(handle,
                             CUBLAS_OP_N, CUBLAS_OP_T,
                             embed_dim, embed_dim, batches_q,
                             static_cast<const void*>(&alpha),
                             static_cast<const void*>(matmul2_results.data_ptr()), CUDA_R_16F, embed_dim,
                             static_cast<const void*>(output_grads.data_ptr()),    CUDA_R_16F, embed_dim,
                             static_cast<const void*>(&beta),
                             static_cast<void*>(output_weight_grads.data_ptr()), CUDA_R_16F, embed_dim,
                             CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));

  // MatMul2 Dgrad1
  gemm_switch_fp32accum(state, a_layout_t, b_layout_n,
                        k_seq_len, q_seq_len, head_dim,
                        alpha,
                        static_cast<const half*>(v_lin_results_ptr), lead_dim_kv, batch_stride_kv,
                        static_cast<const half*>(output_lin_grads.data_ptr()), head_dim * attn_batches, head_dim,
                        beta,
                        static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len * q_seq_len,
                        attn_batches);

  // Matmul2 Dgrad2
  gemm_switch_fp32accum(state, a_layout_n, b_layout_t
,
head_dim
,
k_seq_len
,
q_seq_len
,
alpha
,
static_cast
<
const
half
*>
(
output_lin_grads
.
data_ptr
()),
head_dim
*
attn_batches
,
head_dim
,
static_cast
<
const
half
*>
(
dropout_results
.
data_ptr
()),
k_seq_len
,
k_seq_len
*
q_seq_len
,
beta
,
v_lin_grads_ptr
,
lead_dim_kv
,
batch_stride_kv
,
attn_batches
);
// Apply Dropout Mask and Scale by Dropout Probability
apex_masked_scale_cuda
<
half
,
float
,
uint32_t
>
(
static_cast
<
half
const
*>
(
matmul2_grads
.
data_ptr
()),
static_cast
<
half
*>
(
matmul2_grads
.
data_ptr
()),
static_cast
<
uint8_t
const
*>
(
dropout_mask
.
data_ptr
()),
dropout_elems
,
(
1.0
/
(
1.0
-
dropout_prob
)));
// Softmax Grad
bool
softmax_success
=
false
;
softmax_success
=
dispatch_softmax_backward
<
half
,
half
,
float
>
(
static_cast
<
half
*>
(
matmul2_grads
.
data_ptr
()),
static_cast
<
half
*>
(
matmul2_grads
.
data_ptr
()),
reinterpret_cast
<
half
const
*>
(
softmax_results
.
data_ptr
()),
k_seq_len
,
k_seq_len
,
attn_batches
*
q_seq_len
);
assert
(
softmax_success
);
// Matmul1 Dgrad1
gemm_switch_fp32accum
(
state
,
a_layout_n
,
b_layout_n
,
head_dim
,
q_seq_len
,
k_seq_len
,
scale
,
k_lin_results_ptr
,
lead_dim_kv
,
batch_stride_kv
,
static_cast
<
half
*>
(
matmul2_grads
.
data_ptr
()),
k_seq_len
,
k_seq_len
*
q_seq_len
,
beta
,
q_lin_grads_ptr
,
lead_dim_q
,
batch_stride_q
,
attn_batches
);
// Matmul1 Dgrad2
gemm_switch_fp32accum
(
state
,
a_layout_n
,
b_layout_t
,
head_dim
,
k_seq_len
,
q_seq_len
,
scale
,
q_lin_results_ptr
,
lead_dim_q
,
batch_stride_q
,
static_cast
<
half
*>
(
matmul2_grads
.
data_ptr
()),
k_seq_len
,
k_seq_len
*
q_seq_len
,
beta
,
k_lin_grads_ptr
,
lead_dim_kv
,
batch_stride_kv
,
attn_batches
);
// Input Linear Q Dgrad
THCublasCheck
(
cublasGemmEx
(
handle
,
CUBLAS_OP_N
,
CUBLAS_OP_N
,
embed_dim
,
batches_q
,
output_lin_q_dim
,
static_cast
<
const
void
*>
(
&
alpha
),
static_cast
<
const
void
*>
(
input_weights_q
.
data_ptr
()),
CUDA_R_16F
,
embed_dim
,
static_cast
<
const
void
*>
(
q_lin_grads_ptr
),
CUDA_R_16F
,
output_lin_q_dim
,
static_cast
<
const
void
*>
(
&
beta
),
//static_cast<void*>(input_q_grads.data_ptr()),
static_cast
<
void
*>
(
input_lin_q_grads
.
data_ptr
()),
CUDA_R_16F
,
embed_dim
,
CUDA_R_32F
,
//CUBLAS_GEMM_ALGO10_TENSOR_OP));
CUBLAS_GEMM_DEFAULT_TENSOR_OP
));
// Input Linear Q Wgrad
THCublasCheck
(
cublasGemmEx
(
handle
,
CUBLAS_OP_N
,
CUBLAS_OP_T
,
embed_dim
,
output_lin_q_dim
,
batches_q
,
static_cast
<
const
void
*>
(
&
alpha
),
static_cast
<
const
void
*>
(
inputs_q
.
data_ptr
()),
CUDA_R_16F
,
embed_dim
,
static_cast
<
const
void
*>
(
q_lin_grads_ptr
),
CUDA_R_16F
,
output_lin_q_dim
,
static_cast
<
const
void
*>
(
&
beta
),
static_cast
<
void
*>
(
input_weight_q_grads
.
data_ptr
()),
CUDA_R_16F
,
embed_dim
,
CUDA_R_32F
,
CUBLAS_GEMM_DEFAULT_TENSOR_OP
));
// Input Linear KV Dgrad
THCublasCheck
(
cublasGemmEx
(
handle
,
CUBLAS_OP_N
,
CUBLAS_OP_N
,
embed_dim
,
batches_kv
,
output_lin_kv_dim
,
static_cast
<
const
void
*>
(
&
alpha
),
static_cast
<
const
void
*>
(
input_weights_kv
.
data_ptr
()),
CUDA_R_16F
,
embed_dim
,
static_cast
<
const
void
*>
(
k_lin_grads_ptr
),
CUDA_R_16F
,
output_lin_kv_dim
,
static_cast
<
const
void
*>
(
&
beta
),
static_cast
<
void
*>
(
input_kv_grads
.
data_ptr
()),
CUDA_R_16F
,
embed_dim
,
CUDA_R_32F
,
//CUBLAS_GEMM_ALGO10_TENSOR_OP));
CUBLAS_GEMM_DEFAULT_TENSOR_OP
));
// Input Linear KV Wgrad
THCublasCheck
(
cublasGemmEx
(
handle
,
CUBLAS_OP_N
,
CUBLAS_OP_T
,
embed_dim
,
output_lin_kv_dim
,
batches_kv
,
static_cast
<
const
void
*>
(
&
alpha
),
static_cast
<
const
void
*>
(
inputs_kv
.
data_ptr
()),
CUDA_R_16F
,
embed_dim
,
static_cast
<
const
void
*>
(
k_lin_grads_ptr
),
CUDA_R_16F
,
output_lin_kv_dim
,
static_cast
<
const
void
*>
(
&
beta
),
static_cast
<
void
*>
(
input_weight_kv_grads
.
data_ptr
()),
CUDA_R_16F
,
embed_dim
,
CUDA_R_32F
,
CUBLAS_GEMM_DEFAULT_TENSOR_OP
));
// Fused Layer Norm Bwd with Residual Add
HostLayerNormGradient
<
half
,
float
>
(
static_cast
<
const
half
*>
(
input_lin_q_grads
.
data_ptr
()),
static_cast
<
half
const
*>
(
output_grads
.
data_ptr
()),
static_cast
<
const
float
*>
(
lyr_nrm_mean
.
data_ptr
()),
static_cast
<
const
float
*>
(
lyr_nrm_invvar
.
data_ptr
()),
inputs_q
,
static_cast
<
int
>
(
batches_q
),
// n1
static_cast
<
int
>
(
embed_dim
),
// n2
static_cast
<
const
half
*>
(
lyr_nrm_gamma_weights
.
data_ptr
()),
static_cast
<
const
half
*>
(
lyr_nrm_beta_weights
.
data_ptr
()),
1.0e-5
,
static_cast
<
half
*>
(
input_q_grads
.
data_ptr
()),
static_cast
<
half
*>
(
lyr_nrm_gamma_grads
.
data_ptr
()),
static_cast
<
half
*>
(
lyr_nrm_beta_grads
.
data_ptr
())
);
THCublasCheck
(
cublasSetMathMode
(
handle
,
CUBLAS_DEFAULT_MATH
));
return
{
input_q_grads
,
input_kv_grads
,
lyr_nrm_gamma_grads
,
lyr_nrm_beta_grads
,
input_weight_q_grads
,
input_weight_kv_grads
,
output_weight_grads
};
}
}
// end namespace cublas_gemmex
}
// end namespace encdec_norm_add
}
// end namespace multihead_attn
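For reference, the K/V projections above live in one interleaved buffer: input_lin_kv_results is addressed with lead_dim_kv = attn_batches * 2 * head_dim between sequence positions and batch_stride_kv = 2 * head_dim between attention batches, which is why the V pointer is just the K pointer plus head_dim. A minimal host-side C++ sketch of that addressing follows; the helper names (k_offset, v_offset) are illustrative only and are not part of the extension.

// Illustrative only -- shows how the K/V pointers above index the interleaved buffer.
#include <cstddef>

size_t k_offset(size_t s, size_t h, size_t attn_batches, size_t head_dim) {
  const size_t lead_dim_kv     = attn_batches * 2 * head_dim;  // stride between sequence positions
  const size_t batch_stride_kv = 2 * head_dim;                 // stride between attention batches
  return s * lead_dim_kv + h * batch_stride_kv;                // K starts each (K,V) pair
}
size_t v_offset(size_t s, size_t h, size_t attn_batches, size_t head_dim) {
  return k_offset(s, h, attn_batches, head_dim) + head_dim;    // V follows K by head_dim elements
}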
apex/contrib/csrc/multihead_attn/layer_norm.h
0 → 100644
#include "ATen/ATen.h"
#include <THC/THCDeviceUtils.cuh>
#include <cuda.h>
#include <cuda_runtime.h>
template
<
typename
U
>
__device__
void
cuWelfordOnlineSum
(
const
U
curr
,
U
&
mu
,
U
&
sigma2
,
U
&
count
)
{
count
=
count
+
U
(
1
);
U
delta
=
curr
-
mu
;
U
lmean
=
mu
+
delta
/
count
;
mu
=
lmean
;
U
delta2
=
curr
-
lmean
;
sigma2
=
sigma2
+
delta
*
delta2
;
}
template
<
typename
U
>
__device__
void
cuChanOnlineSum
(
const
U
muB
,
const
U
sigma2B
,
const
U
countB
,
U
&
mu
,
U
&
sigma2
,
U
&
count
)
{
U
delta
=
muB
-
mu
;
U
nA
=
count
;
U
nB
=
countB
;
count
=
count
+
countB
;
U
nX
=
count
;
if
(
nX
>
U
(
0
))
{
nA
=
nA
/
nX
;
nB
=
nB
/
nX
;
mu
=
nA
*
mu
+
nB
*
muB
;
sigma2
=
sigma2
+
sigma2B
+
delta
*
delta
*
nA
*
nB
*
nX
;
}
else
{
mu
=
U
(
0
);
sigma2
=
U
(
0
);
}
}
template
<
typename
T
,
typename
U
>
__device__
void
cuWelfordMuSigma2
(
const
T
*
__restrict__
vals
,
const
int
n1
,
const
int
n2
,
const
int
i1
,
U
&
mu
,
U
&
sigma2
,
U
*
buf
)
{
// Assumptions:
// 1) blockDim.x == warpSize
// 2) Tensor is contiguous
// 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available.
//
// compute variance and mean over n2
U
count
=
U
(
0
);
mu
=
U
(
0
);
sigma2
=
U
(
0
);
if
(
i1
<
n1
)
{
// one warp normalizes one n1 index,
// synchronization is implicit
// initialize with standard Welford algorithm
const
int
numx
=
blockDim
.
x
*
blockDim
.
y
;
const
int
thrx
=
threadIdx
.
x
+
threadIdx
.
y
*
blockDim
.
x
;
const
T
*
lvals
=
vals
+
i1
*
n2
;
int
l
=
4
*
thrx
;
for
(;
l
+
3
<
n2
;
l
+=
4
*
numx
)
{
for
(
int
k
=
0
;
k
<
4
;
++
k
)
{
U
curr
=
static_cast
<
U
>
(
lvals
[
l
+
k
]);
cuWelfordOnlineSum
<
U
>
(
curr
,
mu
,
sigma2
,
count
);
}
}
for
(;
l
<
n2
;
++
l
)
{
U
curr
=
static_cast
<
U
>
(
lvals
[
l
]);
cuWelfordOnlineSum
<
U
>
(
curr
,
mu
,
sigma2
,
count
);
}
// intra-warp reductions
for
(
int
l
=
0
;
l
<=
4
;
++
l
)
{
int
srcLaneB
=
(
threadIdx
.
x
+
(
1
<<
l
))
&
31
;
U
muB
=
WARP_SHFL
(
mu
,
srcLaneB
);
U
countB
=
WARP_SHFL
(
count
,
srcLaneB
);
U
sigma2B
=
WARP_SHFL
(
sigma2
,
srcLaneB
);
cuChanOnlineSum
<
U
>
(
muB
,
sigma2B
,
countB
,
mu
,
sigma2
,
count
);
}
// threadIdx.x == 0 has correct values for each warp
// inter-warp reductions
if
(
blockDim
.
y
>
1
)
{
U
*
ubuf
=
(
U
*
)
buf
;
U
*
ibuf
=
(
U
*
)(
ubuf
+
blockDim
.
y
);
for
(
int
offset
=
blockDim
.
y
/
2
;
offset
>
0
;
offset
/=
2
)
{
// upper half of warps write to shared
if
(
threadIdx
.
x
==
0
&&
threadIdx
.
y
>=
offset
&&
threadIdx
.
y
<
2
*
offset
)
{
const
int
wrt_y
=
threadIdx
.
y
-
offset
;
ubuf
[
2
*
wrt_y
]
=
mu
;
ubuf
[
2
*
wrt_y
+
1
]
=
sigma2
;
ibuf
[
wrt_y
]
=
count
;
}
__syncthreads
();
// lower half merges
if
(
threadIdx
.
x
==
0
&&
threadIdx
.
y
<
offset
)
{
U
muB
=
ubuf
[
2
*
threadIdx
.
y
];
U
sigma2B
=
ubuf
[
2
*
threadIdx
.
y
+
1
];
U
countB
=
ibuf
[
threadIdx
.
y
];
cuChanOnlineSum
<
U
>
(
muB
,
sigma2B
,
countB
,
mu
,
sigma2
,
count
);
}
__syncthreads
();
}
// threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values
if
(
threadIdx
.
x
==
0
&&
threadIdx
.
y
==
0
)
{
ubuf
[
0
]
=
mu
;
ubuf
[
1
]
=
sigma2
;
}
__syncthreads
();
mu
=
ubuf
[
0
];
sigma2
=
ubuf
[
1
]
/
U
(
n2
);
// don't care about final value of count, we know count == n2
}
else
{
mu
=
WARP_SHFL
(
mu
,
0
);
sigma2
=
WARP_SHFL
(
sigma2
/
U
(
n2
),
0
);
}
}
}
template
<
>
__device__
void
cuWelfordMuSigma2
(
const
at
::
Half
*
__restrict__
vals
,
const
int
n1
,
const
int
n2
,
const
int
i1
,
float
&
mu
,
float
&
sigma2
,
float
*
buf
)
{
// Assumptions:
// 1) blockDim.x == warpSize
// 2) Tensor is contiguous
// 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available.
//
// compute variance and mean over n2
float
count
=
0.0
f
;
mu
=
float
(
0
);
sigma2
=
float
(
0
);
if
(
i1
<
n1
)
{
// one warp normalizes one n1 index,
// synchronization is implicit
// initialize with standard Welford algorithm
const
int
numx
=
blockDim
.
x
*
blockDim
.
y
;
const
int
thrx
=
threadIdx
.
x
+
threadIdx
.
y
*
blockDim
.
x
;
const
at
::
Half
*
lvals
=
vals
+
i1
*
n2
;
int
l
=
8
*
thrx
;
if
((((
size_t
)
lvals
)
&
3
)
!=
0
)
{
// 16 bit alignment
// first thread consumes first point
if
(
thrx
==
0
)
{
float
curr
=
static_cast
<
float
>
(
lvals
[
0
]);
cuWelfordOnlineSum
(
curr
,
mu
,
sigma2
,
count
);
}
++
l
;
}
// at this point, lvals[l] are 32 bit aligned for all threads.
for
(;
l
+
7
<
n2
;
l
+=
8
*
numx
)
{
for
(
int
k
=
0
;
k
<
8
;
k
+=
2
)
{
float2
curr
=
__half22float2
(
*
((
__half2
*
)(
lvals
+
l
+
k
)));
cuWelfordOnlineSum
(
curr
.
x
,
mu
,
sigma2
,
count
);
cuWelfordOnlineSum
(
curr
.
y
,
mu
,
sigma2
,
count
);
}
}
for
(;
l
<
n2
;
++
l
)
{
float
curr
=
static_cast
<
float
>
(
lvals
[
l
]);
cuWelfordOnlineSum
(
curr
,
mu
,
sigma2
,
count
);
}
// intra-warp reductions
for
(
int
l
=
0
;
l
<=
4
;
++
l
)
{
int
srcLaneB
=
(
threadIdx
.
x
+
(
1
<<
l
))
&
31
;
float
muB
=
WARP_SHFL
(
mu
,
srcLaneB
);
float
countB
=
WARP_SHFL
(
count
,
srcLaneB
);
float
sigma2B
=
WARP_SHFL
(
sigma2
,
srcLaneB
);
cuChanOnlineSum
(
muB
,
sigma2B
,
countB
,
mu
,
sigma2
,
count
);
}
// threadIdx.x == 0 has correct values for each warp
// inter-warp reductions
if
(
blockDim
.
y
>
1
)
{
float
*
ubuf
=
(
float
*
)
buf
;
float
*
ibuf
=
(
float
*
)(
ubuf
+
blockDim
.
y
);
for
(
int
offset
=
blockDim
.
y
/
2
;
offset
>
0
;
offset
/=
2
)
{
// upper half of warps write to shared
if
(
threadIdx
.
x
==
0
&&
threadIdx
.
y
>=
offset
&&
threadIdx
.
y
<
2
*
offset
)
{
const
int
wrt_y
=
threadIdx
.
y
-
offset
;
ubuf
[
2
*
wrt_y
]
=
mu
;
ubuf
[
2
*
wrt_y
+
1
]
=
sigma2
;
ibuf
[
wrt_y
]
=
count
;
}
__syncthreads
();
// lower half merges
if
(
threadIdx
.
x
==
0
&&
threadIdx
.
y
<
offset
)
{
float
muB
=
ubuf
[
2
*
threadIdx
.
y
];
float
sigma2B
=
ubuf
[
2
*
threadIdx
.
y
+
1
];
float
countB
=
ibuf
[
threadIdx
.
y
];
cuChanOnlineSum
(
muB
,
sigma2B
,
countB
,
mu
,
sigma2
,
count
);
}
__syncthreads
();
}
// threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values
if
(
threadIdx
.
x
==
0
&&
threadIdx
.
y
==
0
)
{
ubuf
[
0
]
=
mu
;
ubuf
[
1
]
=
sigma2
;
}
__syncthreads
();
mu
=
ubuf
[
0
];
sigma2
=
ubuf
[
1
]
/
float
(
n2
);
// don't care about final value of count, we know count == n2
}
else
{
mu
=
WARP_SHFL
(
mu
,
0
);
sigma2
=
WARP_SHFL
(
sigma2
/
float
(
n2
),
0
);
}
}
}
template
<
typename
U
>
U
rsqrt
(
U
v
)
{
return
U
(
1
)
/
sqrt
(
v
);
}
template
<
>
float
rsqrt
(
float
v
)
{
return
rsqrtf
(
v
);
}
template
<
>
double
rsqrt
(
double
v
)
{
return
rsqrt
(
v
);
}
namespace
{
// This is the un-specialized struct. Note that we prevent instantiation of this
// struct by putting an undefined symbol in the function body so it won't compile.
// template <typename T>
// struct SharedMemory
// {
// // Ensure that we won't compile any un-specialized types
// __device__ T *getPointer()
// {
// extern __device__ void error(void);
// error();
// return NULL;
// }
// };
// https://github.com/NVIDIA/apex/issues/246
template
<
typename
T
>
struct
SharedMemory
;
template
<
>
struct
SharedMemory
<
float
>
{
__device__
float
*
getPointer
()
{
extern
__shared__
float
s_float
[];
return
s_float
;
}
};
template
<
>
struct
SharedMemory
<
double
>
{
__device__
double
*
getPointer
()
{
extern
__shared__
double
s_double
[];
return
s_double
;
}
};
}
template
<
typename
T
,
typename
U
>
__global__
void
cuApplyLayerNorm
(
T
*
__restrict__
output_vals
,
U
*
__restrict__
mean
,
U
*
__restrict__
invvar
,
const
T
*
__restrict__
vals
,
const
int
n1
,
const
int
n2
,
const
U
epsilon
,
const
T
*
__restrict__
gamma
,
const
T
*
__restrict__
beta
)
{
// Assumptions:
// 1) blockDim.x == warpSize
// 2) Tensors are contiguous
//
for
(
auto
i1
=
blockIdx
.
y
;
i1
<
n1
;
i1
+=
gridDim
.
y
)
{
SharedMemory
<
U
>
shared
;
U
*
buf
=
shared
.
getPointer
();
U
mu
,
sigma2
;
cuWelfordMuSigma2
(
vals
,
n1
,
n2
,
i1
,
mu
,
sigma2
,
buf
);
const
T
*
lvals
=
vals
+
i1
*
n2
;
T
*
ovals
=
output_vals
+
i1
*
n2
;
U
c_invvar
=
rsqrt
(
sigma2
+
epsilon
);
const
int
numx
=
blockDim
.
x
*
blockDim
.
y
;
const
int
thrx
=
threadIdx
.
x
+
threadIdx
.
y
*
blockDim
.
x
;
if
(
gamma
!=
NULL
&&
beta
!=
NULL
)
{
for
(
int
i
=
thrx
;
i
<
n2
;
i
+=
numx
)
{
U
curr
=
static_cast
<
U
>
(
lvals
[
i
]);
ovals
[
i
]
=
gamma
[
i
]
*
static_cast
<
T
>
(
c_invvar
*
(
curr
-
mu
))
+
beta
[
i
];
}
}
else
{
for
(
int
i
=
thrx
;
i
<
n2
;
i
+=
numx
)
{
U
curr
=
static_cast
<
U
>
(
lvals
[
i
]);
ovals
[
i
]
=
static_cast
<
T
>
(
c_invvar
*
(
curr
-
mu
));
}
}
if
(
threadIdx
.
x
==
0
&&
threadIdx
.
y
==
0
)
{
mean
[
i1
]
=
mu
;
invvar
[
i1
]
=
c_invvar
;
}
}
}
template
<
typename
T
,
typename
U
>
__device__
void
cuLoadWriteStridedInputs
(
const
int
i1_block
,
const
int
thr_load_row_off
,
const
int
thr_load_col_off
,
const
int
i2_off
,
const
int
row_stride
,
U
*
warp_buf1
,
U
*
warp_buf2
,
const
T
*
input
,
const
T
*
dout
,
const
int
i1_end
,
const
int
n2
,
const
U
*
__restrict__
mean
,
const
U
*
__restrict__
invvar
)
{
int
i1
=
i1_block
+
thr_load_row_off
;
if
(
i1
<
i1_end
)
{
U
curr_mean
=
mean
[
i1
];
U
curr_invvar
=
invvar
[
i1
];
for
(
int
k
=
0
;
k
<
blockDim
.
y
;
++
k
)
{
int
i2
=
i2_off
+
k
;
int
load_idx
=
i1
*
n2
+
i2
;
int
write_idx
=
thr_load_row_off
*
row_stride
+
thr_load_col_off
+
k
;
if
(
i2
<
n2
)
{
U
curr_input
=
static_cast
<
U
>
(
input
[
load_idx
]);
U
curr_dout
=
static_cast
<
U
>
(
dout
[
load_idx
]);
warp_buf1
[
write_idx
]
=
curr_dout
;
warp_buf2
[
write_idx
]
=
curr_dout
*
(
curr_input
-
curr_mean
)
*
curr_invvar
;
}
else
{
warp_buf1
[
write_idx
]
=
U
(
0
);
warp_buf2
[
write_idx
]
=
U
(
0
);
}
}
}
else
{
for
(
int
k
=
0
;
k
<
blockDim
.
y
;
++
k
)
{
int
write_idx
=
thr_load_row_off
*
row_stride
+
thr_load_col_off
+
k
;
warp_buf1
[
write_idx
]
=
U
(
0
);
warp_buf2
[
write_idx
]
=
U
(
0
);
}
}
}
template
<
typename
T
,
typename
U
>
__device__
void
cuLoadAddStridedInputs
(
const
int
i1_block
,
const
int
thr_load_row_off
,
const
int
thr_load_col_off
,
const
int
i2_off
,
const
int
row_stride
,
U
*
warp_buf1
,
U
*
warp_buf2
,
const
T
*
input
,
const
T
*
dout
,
const
int
i1_end
,
const
int
n2
,
const
U
*
__restrict__
mean
,
const
U
*
__restrict__
invvar
)
{
int
i1
=
i1_block
+
thr_load_row_off
;
if
(
i1
<
i1_end
)
{
U
curr_mean
=
mean
[
i1
];
U
curr_invvar
=
invvar
[
i1
];
for
(
int
k
=
0
;
k
<
blockDim
.
y
;
++
k
)
{
int
i2
=
i2_off
+
k
;
int
load_idx
=
i1
*
n2
+
i2
;
int
write_idx
=
thr_load_row_off
*
row_stride
+
thr_load_col_off
+
k
;
if
(
i2
<
n2
)
{
U
curr_input
=
static_cast
<
U
>
(
input
[
load_idx
]);
U
curr_dout
=
static_cast
<
U
>
(
dout
[
load_idx
]);
warp_buf1
[
write_idx
]
+=
curr_dout
;
warp_buf2
[
write_idx
]
+=
curr_dout
*
(
curr_input
-
curr_mean
)
*
curr_invvar
;
}
}
}
}
template
<
typename
T
,
typename
U
>
__global__
void
cuComputePartGradGammaBeta
(
const
T
*
__restrict__
dout
,
const
T
*
__restrict__
input
,
const
int
n1
,
const
int
n2
,
const
U
*
__restrict__
mean
,
const
U
*
__restrict__
invvar
,
U
epsilon
,
U
*
part_grad_gamma
,
U
*
part_grad_beta
)
{
const
int
numsegs_n1
=
(
n1
+
blockDim
.
y
*
blockDim
.
y
-
1
)
/
(
blockDim
.
y
*
blockDim
.
y
);
const
int
segs_per_block
=
(
numsegs_n1
+
gridDim
.
y
-
1
)
/
gridDim
.
y
;
const
int
i1_beg
=
blockIdx
.
y
*
segs_per_block
*
blockDim
.
y
*
blockDim
.
y
;
const
int
i1_beg_plus_one
=
(
blockIdx
.
y
+
1
)
*
segs_per_block
*
blockDim
.
y
*
blockDim
.
y
;
const
int
i1_end
=
i1_beg_plus_one
<
n1
?
i1_beg_plus_one
:
n1
;
const
int
row_stride
=
blockDim
.
x
+
1
;
const
int
thr_load_col_off
=
(
threadIdx
.
x
*
blockDim
.
y
)
&
(
blockDim
.
x
-
1
);
const
int
thr_load_row_off
=
(
threadIdx
.
x
*
blockDim
.
y
)
/
blockDim
.
x
+
threadIdx
.
y
*
blockDim
.
y
;
const
int
i2_off
=
blockIdx
.
x
*
blockDim
.
x
+
thr_load_col_off
;
SharedMemory
<
U
>
shared
;
U
*
buf
=
shared
.
getPointer
();
// buf has at least blockDim.x * blockDim.y * blockDim.y + (blockDim.y - 1)*(blockDim.x/blockDim.y) elements
U
*
warp_buf1
=
(
U
*
)
buf
;
U
*
warp_buf2
=
warp_buf1
+
blockDim
.
y
*
blockDim
.
y
*
row_stride
;
// compute partial sums from strided inputs
// do this to increase number of loads in flight
cuLoadWriteStridedInputs
(
i1_beg
,
thr_load_row_off
,
thr_load_col_off
,
i2_off
,
row_stride
,
warp_buf1
,
warp_buf2
,
input
,
dout
,
i1_end
,
n2
,
mean
,
invvar
);
for
(
int
i1_block
=
i1_beg
+
blockDim
.
y
*
blockDim
.
y
;
i1_block
<
i1_end
;
i1_block
+=
blockDim
.
y
*
blockDim
.
y
)
{
cuLoadAddStridedInputs
(
i1_block
,
thr_load_row_off
,
thr_load_col_off
,
i2_off
,
row_stride
,
warp_buf1
,
warp_buf2
,
input
,
dout
,
i1_end
,
n2
,
mean
,
invvar
);
}
__syncthreads
();
// inter-warp reductions
// sum within each warp
U
acc1
=
U
(
0
);
U
acc2
=
U
(
0
);
for
(
int
k
=
0
;
k
<
blockDim
.
y
;
++
k
)
{
int
row1
=
threadIdx
.
y
+
k
*
blockDim
.
y
;
int
idx1
=
row1
*
row_stride
+
threadIdx
.
x
;
acc1
+=
warp_buf1
[
idx1
];
acc2
+=
warp_buf2
[
idx1
];
}
warp_buf1
[
threadIdx
.
y
*
row_stride
+
threadIdx
.
x
]
=
acc1
;
warp_buf2
[
threadIdx
.
y
*
row_stride
+
threadIdx
.
x
]
=
acc2
;
__syncthreads
();
// sum all warps
for
(
int
offset
=
blockDim
.
y
/
2
;
offset
>
1
;
offset
/=
2
)
{
if
(
threadIdx
.
y
<
offset
)
{
int
row1
=
threadIdx
.
y
;
int
row2
=
threadIdx
.
y
+
offset
;
int
idx1
=
row1
*
row_stride
+
threadIdx
.
x
;
int
idx2
=
row2
*
row_stride
+
threadIdx
.
x
;
warp_buf1
[
idx1
]
+=
warp_buf1
[
idx2
];
warp_buf2
[
idx1
]
+=
warp_buf2
[
idx2
];
}
__syncthreads
();
}
int
i2
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
threadIdx
.
y
==
0
&&
i2
<
n2
)
{
int
row1
=
threadIdx
.
y
;
int
row2
=
threadIdx
.
y
+
1
;
int
idx1
=
row1
*
row_stride
+
threadIdx
.
x
;
int
idx2
=
row2
*
row_stride
+
threadIdx
.
x
;
part_grad_beta
[
blockIdx
.
y
*
n2
+
i2
]
=
warp_buf1
[
idx1
]
+
warp_buf1
[
idx2
];
part_grad_gamma
[
blockIdx
.
y
*
n2
+
i2
]
=
warp_buf2
[
idx1
]
+
warp_buf2
[
idx2
];
}
}
template
<
typename
T
,
typename
U
>
__global__
void
cuComputeGradGammaBeta
(
const
U
*
part_grad_gamma
,
const
U
*
part_grad_beta
,
const
int
part_size
,
const
int
n1
,
const
int
n2
,
T
*
grad_gamma
,
T
*
grad_beta
)
{
// sum partial gradients for gamma and beta
SharedMemory
<
U
>
shared
;
U
*
buf
=
shared
.
getPointer
();
int
i2
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
i2
<
n2
)
{
// each warp does sequential reductions until reduced part_size is num_warps
int
num_warp_reductions
=
part_size
/
blockDim
.
y
;
U
sum_gamma
=
U
(
0
);
U
sum_beta
=
U
(
0
);
const
U
*
part_grad_gamma_ptr
=
part_grad_gamma
+
threadIdx
.
y
*
num_warp_reductions
*
n2
+
i2
;
const
U
*
part_grad_beta_ptr
=
part_grad_beta
+
threadIdx
.
y
*
num_warp_reductions
*
n2
+
i2
;
for
(
int
warp_offset
=
0
;
warp_offset
<
num_warp_reductions
;
++
warp_offset
)
{
sum_gamma
+=
part_grad_gamma_ptr
[
warp_offset
*
n2
];
sum_beta
+=
part_grad_beta_ptr
[
warp_offset
*
n2
];
}
// inter-warp reductions
const
int
nbsize3
=
blockDim
.
x
*
blockDim
.
y
/
2
;
for
(
int
offset
=
blockDim
.
y
/
2
;
offset
>=
1
;
offset
/=
2
)
{
// top half write to shared memory
if
(
threadIdx
.
y
>=
offset
&&
threadIdx
.
y
<
2
*
offset
)
{
const
int
write_idx
=
(
threadIdx
.
y
-
offset
)
*
blockDim
.
x
+
threadIdx
.
x
;
buf
[
write_idx
]
=
sum_gamma
;
buf
[
write_idx
+
nbsize3
]
=
sum_beta
;
}
__syncthreads
();
// bottom half sums
if
(
threadIdx
.
y
<
offset
)
{
const
int
read_idx
=
threadIdx
.
y
*
blockDim
.
x
+
threadIdx
.
x
;
sum_gamma
+=
buf
[
read_idx
];
sum_beta
+=
buf
[
read_idx
+
nbsize3
];
}
__syncthreads
();
}
// write out fully summed gradients
if
(
threadIdx
.
y
==
0
)
{
grad_gamma
[
i2
]
=
sum_gamma
;
grad_beta
[
i2
]
=
sum_beta
;
}
}
}
template
<
typename
T
,
typename
U
>
__global__
void
cuComputeGradInput
(
const
T
*
__restrict__
dout
,
const
T
*
__restrict__
dout_resid
,
const
T
*
__restrict__
input
,
const
int
n1
,
const
int
n2
,
const
U
*
__restrict__
mean
,
const
U
*
__restrict__
invvar
,
U
epsilon
,
const
T
*
gamma
,
T
*
grad_input
)
{
for
(
auto
i1
=
blockIdx
.
y
;
i1
<
n1
;
i1
+=
gridDim
.
y
)
{
U
sum_loss1
=
U
(
0
);
U
sum_loss2
=
U
(
0
);
const
U
c_mean
=
mean
[
i1
];
const
U
c_invvar
=
invvar
[
i1
];
const
T
*
k_input
=
input
+
i1
*
n2
;
const
T
*
k_dout
=
dout
+
i1
*
n2
;
const
T
*
k_dout_resid
=
dout_resid
+
i1
*
n2
;
const
int
numx
=
blockDim
.
x
*
blockDim
.
y
;
const
int
thrx
=
threadIdx
.
x
+
threadIdx
.
y
*
blockDim
.
x
;
if
(
gamma
!=
NULL
)
{
int
l
=
4
*
thrx
;
for
(;
l
+
3
<
n2
;
l
+=
4
*
numx
)
{
for
(
int
k
=
0
;
k
<
4
;
++
k
)
{
const
U
c_h
=
static_cast
<
U
>
(
k_input
[
l
+
k
]);
const
U
c_loss
=
static_cast
<
U
>
(
k_dout
[
l
+
k
]);
sum_loss1
+=
c_loss
*
static_cast
<
U
>
(
gamma
[
l
+
k
]);
sum_loss2
+=
c_loss
*
static_cast
<
U
>
(
gamma
[
l
+
k
])
*
(
c_h
-
c_mean
)
*
c_invvar
;
}
}
for
(;
l
<
n2
;
++
l
)
{
const
U
c_h
=
static_cast
<
U
>
(
k_input
[
l
]);
const
U
c_loss
=
static_cast
<
U
>
(
k_dout
[
l
]);
sum_loss1
+=
c_loss
*
static_cast
<
U
>
(
gamma
[
l
]);
sum_loss2
+=
c_loss
*
static_cast
<
U
>
(
gamma
[
l
])
*
(
c_h
-
c_mean
)
*
c_invvar
;
}
}
else
{
int
l
=
4
*
thrx
;
for
(;
l
+
3
<
n2
;
l
+=
4
*
numx
)
{
for
(
int
k
=
0
;
k
<
4
;
++
k
)
{
const
U
c_h
=
static_cast
<
U
>
(
k_input
[
l
+
k
]);
const
U
c_loss
=
static_cast
<
U
>
(
k_dout
[
l
+
k
]);
sum_loss1
+=
c_loss
;
sum_loss2
+=
c_loss
*
(
c_h
-
c_mean
)
*
c_invvar
;
}
}
for
(;
l
<
n2
;
++
l
)
{
const
U
c_h
=
static_cast
<
U
>
(
k_input
[
l
]);
const
U
c_loss
=
static_cast
<
U
>
(
k_dout
[
l
]);
sum_loss1
+=
c_loss
;
sum_loss2
+=
c_loss
*
(
c_h
-
c_mean
)
*
c_invvar
;
}
}
// intra-warp reductions
for
(
int
mask
=
blockDim
.
x
/
2
;
mask
>
0
;
mask
/=
2
)
{
sum_loss1
+=
WARP_SHFL_XOR
(
sum_loss1
,
mask
);
sum_loss2
+=
WARP_SHFL_XOR
(
sum_loss2
,
mask
);
}
// inter-warp reductions
if
(
blockDim
.
y
>
1
)
{
SharedMemory
<
U
>
shared
;
U
*
buf
=
shared
.
getPointer
();
for
(
int
offset
=
blockDim
.
y
/
2
;
offset
>
0
;
offset
/=
2
)
{
// upper half of warps write to shared
if
(
threadIdx
.
y
>=
offset
&&
threadIdx
.
y
<
2
*
offset
)
{
const
int
wrt_i
=
(
threadIdx
.
y
-
offset
)
*
blockDim
.
x
+
threadIdx
.
x
;
buf
[
2
*
wrt_i
]
=
sum_loss1
;
buf
[
2
*
wrt_i
+
1
]
=
sum_loss2
;
}
__syncthreads
();
// lower half merges
if
(
threadIdx
.
y
<
offset
)
{
const
int
read_i
=
threadIdx
.
y
*
blockDim
.
x
+
threadIdx
.
x
;
sum_loss1
+=
buf
[
2
*
read_i
];
sum_loss2
+=
buf
[
2
*
read_i
+
1
];
}
__syncthreads
();
}
if
(
threadIdx
.
y
==
0
)
{
buf
[
2
*
threadIdx
.
x
]
=
sum_loss1
;
buf
[
2
*
threadIdx
.
x
+
1
]
=
sum_loss2
;
}
__syncthreads
();
if
(
threadIdx
.
y
!=
0
)
{
sum_loss1
=
buf
[
2
*
threadIdx
.
x
];
sum_loss2
=
buf
[
2
*
threadIdx
.
x
+
1
];
}
}
// all threads now have the two sums over l
U
fH
=
(
U
)
n2
;
U
term1
=
(
U
(
1
)
/
fH
)
*
c_invvar
;
T
*
k_grad_input
=
grad_input
+
i1
*
n2
;
if
(
gamma
!=
NULL
)
{
for
(
int
l
=
thrx
;
l
<
n2
;
l
+=
numx
)
{
const
U
c_h
=
static_cast
<
U
>
(
k_input
[
l
]);
const
U
c_loss
=
static_cast
<
U
>
(
k_dout
[
l
]);
const
T
c_resid
=
static_cast
<
T
>
(
k_dout_resid
[
l
]);
U
f_grad_input
=
fH
*
c_loss
*
static_cast
<
U
>
(
gamma
[
l
]);
f_grad_input
-=
sum_loss1
;
f_grad_input
-=
(
c_h
-
c_mean
)
*
c_invvar
*
sum_loss2
;
f_grad_input
*=
term1
;
k_grad_input
[
l
]
=
static_cast
<
T
>
(
f_grad_input
)
+
c_resid
;
}
}
else
{
for
(
int
l
=
thrx
;
l
<
n2
;
l
+=
numx
)
{
const
U
c_h
=
static_cast
<
U
>
(
k_input
[
l
]);
const
U
c_loss
=
static_cast
<
U
>
(
k_dout
[
l
]);
const
T
c_resid
=
static_cast
<
T
>
(
k_dout_resid
[
l
]);
U
f_grad_input
=
fH
*
c_loss
;
f_grad_input
-=
sum_loss1
;
f_grad_input
-=
(
c_h
-
c_mean
)
*
c_invvar
*
sum_loss2
;
f_grad_input
*=
term1
;
k_grad_input
[
l
]
=
static_cast
<
T
>
(
f_grad_input
)
+
c_resid
;
}
}
}
}
template
<
typename
T
,
typename
U
>
void
HostApplyLayerNorm
(
T
*
output
,
U
*
mean
,
U
*
invvar
,
const
T
*
input
,
int
n1
,
int
n2
,
double
epsilon
,
const
T
*
gamma
,
const
T
*
beta
)
{
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
().
stream
();
const
dim3
threads
(
32
,
4
,
1
);
const
uint64_t
maxGridY
=
at
::
cuda
::
getCurrentDeviceProperties
()
->
maxGridSize
[
1
];
const
dim3
blocks
(
1
,
std
::
min
((
uint64_t
)
n1
,
maxGridY
),
1
);
int
nshared
=
threads
.
y
>
1
?
threads
.
y
*
sizeof
(
U
)
+
(
threads
.
y
/
2
)
*
sizeof
(
U
)
:
0
;
cuApplyLayerNorm
<<<
blocks
,
threads
,
nshared
,
stream
>>>
(
output
,
mean
,
invvar
,
input
,
n1
,
n2
,
U
(
epsilon
),
gamma
,
beta
);
}
template
<
typename
T
,
typename
U
>
void
HostLayerNormGradient
(
const
T
*
dout
,
const
T
*
dout_resid
,
const
U
*
mean
,
const
U
*
invvar
,
const
at
::
Tensor
&
input
,
int
n1
,
int
n2
,
const
T
*
gamma
,
const
T
*
beta
,
double
epsilon
,
T
*
grad_input
,
T
*
grad_gamma
,
T
*
grad_beta
)
{
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
().
stream
();
if
(
gamma
!=
NULL
&&
beta
!=
NULL
)
{
// compute grad_gamma(j) and grad_beta(j)
const
int
part_size
=
16
;
const
dim3
threads2
(
32
,
4
,
1
);
const
dim3
blocks2
((
n2
+
threads2
.
x
-
1
)
/
threads2
.
x
,
part_size
,
1
);
const
int
nshared2_a
=
2
*
sizeof
(
U
)
*
threads2
.
y
*
threads2
.
y
*
(
threads2
.
x
+
1
);
const
int
nshared2_b
=
threads2
.
x
*
threads2
.
y
*
sizeof
(
U
);
const
int
nshared2
=
nshared2_a
>
nshared2_b
?
nshared2_a
:
nshared2_b
;
at
::
Tensor
part_grad_gamma
=
at
::
empty
({
part_size
,
n2
},
input
.
options
().
dtype
(
input
.
scalar_type
()
==
at
::
ScalarType
::
Half
?
at
::
ScalarType
::
Float
:
input
.
scalar_type
()));
at
::
Tensor
part_grad_beta
=
at
::
empty_like
(
part_grad_gamma
);
cuComputePartGradGammaBeta
<<<
blocks2
,
threads2
,
nshared2
,
stream
>>>
(
dout
,
static_cast
<
T
*>
(
input
.
data_ptr
()),
n1
,
n2
,
mean
,
invvar
,
U
(
epsilon
),
static_cast
<
U
*>
(
part_grad_gamma
.
data_ptr
()),
static_cast
<
U
*>
(
part_grad_beta
.
data_ptr
()));
const
dim3
threads3
(
32
,
8
,
1
);
const
dim3
blocks3
((
n2
+
threads2
.
x
-
1
)
/
threads2
.
x
,
1
,
1
);
const
int
nshared3
=
threads3
.
x
*
threads3
.
y
*
sizeof
(
U
);
cuComputeGradGammaBeta
<<<
blocks3
,
threads3
,
nshared3
,
stream
>>>
(
static_cast
<
U
*>
(
part_grad_gamma
.
data_ptr
()),
static_cast
<
U
*>
(
part_grad_beta
.
data_ptr
()),
part_size
,
n1
,
n2
,
grad_gamma
,
grad_beta
);
}
// compute grad_input
const
uint64_t
maxGridY
=
at
::
cuda
::
getCurrentDeviceProperties
()
->
maxGridSize
[
1
];
const
dim3
blocks1
(
1
,
std
::
min
((
uint64_t
)
n1
,
maxGridY
),
1
);
const
dim3
threads1
(
32
,
4
,
1
);
int
nshared
=
threads1
.
y
>
1
?
threads1
.
y
*
threads1
.
x
*
sizeof
(
U
)
:
0
;
cuComputeGradInput
<<<
blocks1
,
threads1
,
nshared
,
stream
>>>
(
dout
,
dout_resid
,
static_cast
<
T
*>
(
input
.
data_ptr
()),
n1
,
n2
,
mean
,
invvar
,
U
(
epsilon
),
gamma
,
grad_input
);
}
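The statistics in this header are computed with Welford's single-pass update (cuWelfordOnlineSum) and merged across lanes and warps with Chan's parallel combination (cuChanOnlineSum). For reference, a minimal CPU-side C++ sketch of the same arithmetic follows; the Acc struct and helper names are illustrative only, not part of the header.

// CPU-side illustration of the Welford update and Chan merge used by the kernels above.
struct Acc { double mu = 0.0, m2 = 0.0, count = 0.0; };

void welford_add(Acc& a, double x) {          // same update as cuWelfordOnlineSum
  a.count += 1.0;
  double delta = x - a.mu;
  a.mu += delta / a.count;
  a.m2 += delta * (x - a.mu);                 // m2 accumulates count * variance
}

void chan_merge(Acc& a, const Acc& b) {       // same merge as cuChanOnlineSum
  double n = a.count + b.count;
  if (n <= 0.0) { a = Acc{}; return; }
  double delta = b.mu - a.mu;
  double nA = a.count / n, nB = b.count / n;
  a.mu = nA * a.mu + nB * b.mu;
  a.m2 = a.m2 + b.m2 + delta * delta * nA * nB * n;
  a.count = n;
}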
apex/contrib/csrc/multihead_attn/self_multihead_attn.cpp
0 → 100644
#include <torch/extension.h>
#include <vector>

namespace multihead_attn {
namespace self {
namespace cublas_gemmex {

std::vector<torch::Tensor> fwd_cuda(
    bool use_time_mask, bool is_training, int heads,
    torch::Tensor const& inputs,
    torch::Tensor const& input_weights, torch::Tensor const& output_weights,
    const uint8_t* pad_mask, float dropout_prob);

std::vector<torch::Tensor> bwd_cuda(
    int heads,
    torch::Tensor const& output_grads, torch::Tensor const& matmul2_results,
    torch::Tensor const& dropout_results, torch::Tensor const& softmax_results,
    torch::Tensor const& input_lin_results, torch::Tensor const& inputs,
    torch::Tensor const& input_weights, torch::Tensor const& output_weights,
    torch::Tensor const& dropout_mask, float dropout_prob);

// C++ interface
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

std::vector<torch::Tensor> fwd(
    bool use_mask, bool use_time_mask, bool is_training, int heads,
    torch::Tensor const& inputs,
    torch::Tensor const& input_weights, torch::Tensor const& output_weights,
    torch::Tensor const& pad_mask, float dropout_prob)
{
  AT_ASSERTM(inputs.dim()         == 3, "expected 3D tensor");
  AT_ASSERTM(input_weights.dim()  == 2, "expected 2D tensor");
  AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");

  AT_ASSERTM(inputs.type().scalarType()         == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(input_weights.type().scalarType()  == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");

  if (use_mask) {
    AT_ASSERTM(pad_mask.dim()               == 2, "expected 2D tensor");
    AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
  }

  return fwd_cuda(use_time_mask, is_training, heads,
                  inputs, input_weights, output_weights,
                  use_mask ? static_cast<const uint8_t*>(pad_mask.data_ptr()) : nullptr,
                  dropout_prob);
}

std::vector<torch::Tensor> bwd(
    int heads,
    torch::Tensor const& output_grads, torch::Tensor const& matmul2_results,
    torch::Tensor const& dropout_results, torch::Tensor const& softmax_results,
    torch::Tensor const& input_lin_results, torch::Tensor const& inputs,
    torch::Tensor const& input_weights, torch::Tensor const& output_weights,
    torch::Tensor const& dropout_mask, float dropout_prob)
{
  AT_ASSERTM(output_grads.dim()      == 3, "expected 3D tensor");
  AT_ASSERTM(matmul2_results.dim()   == 3, "expected 3D tensor");
  AT_ASSERTM(dropout_results.dim()   == 3, "expected 3D tensor");
  AT_ASSERTM(softmax_results.dim()   == 3, "expected 3D tensor");
  AT_ASSERTM(input_lin_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(inputs.dim()            == 3, "expected 3D tensor");
  AT_ASSERTM(input_weights.dim()     == 2, "expected 2D tensor");
  AT_ASSERTM(output_weights.dim()    == 2, "expected 2D tensor");
  AT_ASSERTM(dropout_mask.dim()      == 3, "expected 3D tensor");

  AT_ASSERTM(output_grads.type().scalarType()      == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(matmul2_results.type().scalarType()   == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(dropout_results.type().scalarType()   == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(softmax_results.type().scalarType()   == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(inputs.type().scalarType()            == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(input_weights.type().scalarType()     == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType()    == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(dropout_mask.type().scalarType()      == at::ScalarType::Byte, "Only BYTE is supported");

  return bwd_cuda(heads,
                  output_grads, matmul2_results, dropout_results, softmax_results,
                  input_lin_results, inputs, input_weights, output_weights,
                  dropout_mask, dropout_prob);
}

} // end namespace cublas_gemmex
} // end namespace self
} // end namespace multihead_attn

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward",  &multihead_attn::self::cublas_gemmex::fwd,  "Self Multihead Attention Forward.");
  m.def("backward", &multihead_attn::self::cublas_gemmex::bwd, "Self Multihead Attention Backward.");
}
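The CUDA implementation these bindings call derives all of its GEMM geometry from the [q_seq_len, sequences, embed_dim] input shape and the head count. A short C++ sketch of that bookkeeping follows; Geometry and derive_geometry are illustrative names, not symbols from the extension.

// Illustration of the derived sizes used by self_multihead_attn_cuda.cu.
#include <cassert>

struct Geometry {
  int head_dim, output_lin_dim, attn_batches, lead_dim, batch_stride;
};

Geometry derive_geometry(int sequences, int embed_dim, int heads) {
  assert(embed_dim % heads == 0);                       // heads must divide the model dimension
  Geometry g;
  g.head_dim       = embed_dim / heads;
  g.output_lin_dim = 3 * embed_dim;                     // fused Q, K, V projection
  g.attn_batches   = heads * sequences;
  g.lead_dim       = g.attn_batches * 3 * g.head_dim;   // stride between sequence positions
  g.batch_stride   = 3 * g.head_dim;                    // stride between attention batches
  return g;
}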
apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu
0 → 100644
#include <vector>
#include <iostream>

#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include "THC/THC.h"
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <math.h>

#include "strided_batched_gemm.h"
#include "softmax.h"
#include "dropout.h"
#include "layer_norm.h"

// symbol to be automatically resolved by PyTorch libs
extern THCState *state;

namespace multihead_attn {
namespace self {
namespace cublas_gemmex {

std::vector<torch::Tensor> fwd_cuda(
    bool use_time_mask, bool is_training, int heads,
    torch::Tensor const& inputs,
    torch::Tensor const& input_weights, torch::Tensor const& output_weights,
    const uint8_t* pad_mask, float dropout_prob)
{
  const int   embed_dim      = inputs.size(2);
  const int   sequences      = inputs.size(1);
  const int   q_seq_len      = inputs.size(0);
  const int   k_seq_len      = q_seq_len;
  const int   batches        = sequences * q_seq_len;
  const int   head_dim       = embed_dim / heads;
  const int   output_lin_dim = 3 * embed_dim;
  const int   attn_batches   = heads * sequences;
  const int   lead_dim       = attn_batches * 3 * head_dim;
  const int   batch_stride   = 3 * head_dim;
  const int   dropout_elems  = attn_batches * q_seq_len * k_seq_len;
  const float alpha          = 1.0;
  const float beta           = 0.0;
  const float scale          = 1.0 / sqrt(static_cast<float>(head_dim));

  // There is no reason to use more than one stream as every kernel is
  // sequentially dependent
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t   stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);

  // 3 Intermediate Results + Output (Note: dropout intermediates are generated by ATen library code)
  auto act_options  = inputs.options().requires_grad(false);
  auto mask_options = act_options.dtype(torch::kUInt8);

  torch::Tensor input_lin_results = torch::empty({q_seq_len, sequences, output_lin_dim}, act_options);
  torch::Tensor softmax_results   = torch::empty({attn_batches, q_seq_len, k_seq_len},   act_options);
  torch::Tensor dropout_results   = torch::empty({attn_batches, q_seq_len, k_seq_len},   act_options);
  torch::Tensor dropout_mask      = torch::empty({attn_batches, q_seq_len, k_seq_len},   mask_options);
  torch::Tensor matmul2_results   = torch::empty({q_seq_len, attn_batches, head_dim},    act_options);
  torch::Tensor outputs           = torch::empty_like(inputs, act_options);

  // Input Linear Results Pointers to Q, K, and V of interleaved activations
  void* q_lin_results_ptr = static_cast<void*>(input_lin_results.data_ptr());
  void* k_lin_results_ptr = static_cast<void*>(static_cast<half*>(input_lin_results.data_ptr()) + head_dim);
  void* v_lin_results_ptr = static_cast<void*>(static_cast<half*>(input_lin_results.data_ptr()) + 2*head_dim);

  // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
  void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());

  char a_layout_t{'t'};
  char a_layout_n{'n'};
  char b_layout_n{'n'};

  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));

  // Input Linear Fwd
  THCublasCheck(cublasGemmEx(handle, CUBLAS_OP_T, CUBLAS_OP_N,
                             output_lin_dim, batches, embed_dim,
                             static_cast<const void*>(&alpha),
                             static_cast<const void*>(input_weights.data_ptr()), CUDA_R_16F, embed_dim,
                             static_cast<const void*>(inputs.data_ptr()),        CUDA_R_16F, embed_dim,
                             static_cast<const void*>(&beta),
                             q_lin_results_ptr,                                  CUDA_R_16F, output_lin_dim,
                             CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));

  // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
  gemm_switch_fp32accum(state, a_layout_t, b_layout_n,
                        k_seq_len, q_seq_len, head_dim, scale,
                        static_cast<const half*>(k_lin_results_ptr), lead_dim, batch_stride,
                        static_cast<const half*>(q_lin_results_ptr), lead_dim, batch_stride,
                        beta,
                        static_cast<half*>(softmax_results_ptr), k_seq_len, k_seq_len*q_seq_len,
                        attn_batches);

  // Padded Softmax
  bool softmax_success = false;
  if (pad_mask == nullptr) {
    softmax_success = dispatch_softmax<half, half, float>(
                          reinterpret_cast<half*>(softmax_results_ptr),
                          reinterpret_cast<const half*>(softmax_results_ptr),
                          k_seq_len, k_seq_len, attn_batches*q_seq_len);
  } else {
    if (use_time_mask) {
      softmax_success = dispatch_time_masked_softmax<half, half, float>(
                            reinterpret_cast<half*>(softmax_results_ptr),
                            reinterpret_cast<const half*>(softmax_results_ptr),
                            pad_mask,
                            k_seq_len, k_seq_len, attn_batches*q_seq_len,
                            q_seq_len);
    } else {
      softmax_success = dispatch_masked_softmax<half, half, float>(
                            reinterpret_cast<half*>(softmax_results_ptr),
                            reinterpret_cast<const half*>(softmax_results_ptr),
                            pad_mask,
                            k_seq_len, k_seq_len, attn_batches*q_seq_len,
                            attn_batches*q_seq_len/sequences);
    }
  }
  assert(softmax_success);

  if (is_training) {
    apex_fused_dropout_cuda<half,float,uint32_t>(
                            static_cast<half const*>(softmax_results.data_ptr()),
                            static_cast<half*>(dropout_results.data_ptr()),
                            static_cast<uint8_t*>(dropout_mask.data_ptr()),
                            dropout_elems,
                            (1.0f - dropout_prob));
  }

  // Matmul2
  gemm_switch_fp32accum(state, a_layout_n, b_layout_n,
                        head_dim, q_seq_len, k_seq_len, alpha,
                        static_cast<const half*>(v_lin_results_ptr), lead_dim, batch_stride,
                        (is_training) ? static_cast<const half*>(dropout_results.data_ptr())
                                      : static_cast<const half*>(softmax_results.data_ptr()),
                        k_seq_len, k_seq_len*q_seq_len,
                        beta,
                        static_cast<half*>(matmul2_results.data_ptr()), head_dim*attn_batches, head_dim,
                        attn_batches);

  // Output Linear
  THCublasCheck(cublasGemmEx(handle, CUBLAS_OP_T, CUBLAS_OP_N,
                             embed_dim, batches, embed_dim,
                             static_cast<const void*>(&alpha),
                             static_cast<const void*>(output_weights.data_ptr()),  CUDA_R_16F, embed_dim,
                             static_cast<const void*>(matmul2_results.data_ptr()), CUDA_R_16F, embed_dim,
                             static_cast<const void*>(&beta),
                             static_cast<void*>(outputs.data_ptr()),               CUDA_R_16F, embed_dim,
                             CUDA_R_32F,
                             //CUBLAS_GEMM_ALGO1_TENSOR_OP));
                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));

  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));

  return {input_lin_results, softmax_results, dropout_results, dropout_mask, matmul2_results, outputs};
}

std::vector<torch::Tensor> bwd_cuda(
    int heads,
    torch::Tensor const& output_grads, torch::Tensor const& matmul2_results,
    torch::Tensor const& dropout_results, torch::Tensor const& softmax_results,
    torch::Tensor const& input_lin_results, torch::Tensor const& inputs,
    torch::Tensor const& input_weights, torch::Tensor const& output_weights,
    torch::Tensor const& dropout_mask, float dropout_prob)
{
  const int   embed_dim      = inputs.size(2);
  const int   sequences      = inputs.size(1);
  const int   q_seq_len      = inputs.size(0);
  const int   k_seq_len      = q_seq_len;
  const int   batches        = sequences * q_seq_len;
  const int   head_dim       = embed_dim / heads;
  const int   output_lin_dim = 3 * embed_dim;
  const int   attn_batches   = heads * sequences;
  const int   lead_dim       = attn_batches * 3 * head_dim;
  const int   batch_stride   = 3 * head_dim;
  const int   dropout_elems  = attn_batches * q_seq_len * k_seq_len;
  const float alpha          = 1.0;
  const float beta           = 0.0;
  const float scale          = 1.0 / sqrt(static_cast<float>(head_dim));

  // TODO: Streams can be used in Backprop but I haven't added more than one
  // in my first attempt to create the code
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t   stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);

  // Output Tensor Allocations
  torch::Tensor input_grads         = torch::empty_like(inputs);
  torch::Tensor input_weight_grads  = torch::empty_like(input_weights);
  torch::Tensor output_weight_grads = torch::empty_like(output_weights);
  // Intermediate Tensor Allocations
  at::Tensor output_lin_grads       = torch::empty_like(matmul2_results);
  at::Tensor matmul2_grads          = torch::empty_like(dropout_results);
  at::Tensor input_lin_output_grads = torch::empty_like(input_lin_results);

  auto q_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr());
  auto k_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr()) + head_dim;
  auto v_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr()) + 2*head_dim;

  auto q_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr());
  auto k_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr()) + head_dim;
  auto v_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr()) + 2*head_dim;

  char a_layout_n{'n'};
  char a_layout_t{'t'};
  char b_layout_n{'n'};
  char b_layout_t{'t'};

  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));

  // Output Linear Dgrad
  THCublasCheck(cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                             embed_dim, batches, embed_dim,
                             static_cast<const void*>(&alpha),
                             static_cast<const void*>(output_weights.data_ptr()), CUDA_R_16F, embed_dim,
                             static_cast<const void*>(output_grads.data_ptr()),   CUDA_R_16F, embed_dim,
                             static_cast<const void*>(&beta),
                             static_cast<void*>(output_lin_grads.data_ptr()),     CUDA_R_16F, embed_dim,
                             CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));

  // Output Linear Wgrad
  THCublasCheck(cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_T,
                             embed_dim, embed_dim, batches,
                             static_cast<const void*>(&alpha),
                             static_cast<const void*>(matmul2_results.data_ptr()), CUDA_R_16F, embed_dim,
                             static_cast<const void*>(output_grads.data_ptr()),    CUDA_R_16F, embed_dim,
                             static_cast<const void*>(&beta),
                             static_cast<void*>(output_weight_grads.data_ptr()),   CUDA_R_16F, embed_dim,
                             CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));

  // MatMul2 Dgrad1
  gemm_switch_fp32accum(state, a_layout_t, b_layout_n,
                        k_seq_len, q_seq_len, head_dim, alpha,
                        static_cast<const half*>(v_lin_results_ptr), lead_dim, batch_stride,
                        static_cast<const half*>(output_lin_grads.data_ptr()), head_dim*attn_batches, head_dim,
                        beta,
                        static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len*q_seq_len,
                        attn_batches);

  // Matmul2 Dgrad2
  gemm_switch_fp32accum(state, a_layout_n, b_layout_t,
                        head_dim, k_seq_len, q_seq_len, alpha,
                        static_cast<const half*>(output_lin_grads.data_ptr()), head_dim*attn_batches, head_dim,
                        static_cast<const half*>(dropout_results.data_ptr()), k_seq_len, k_seq_len*q_seq_len,
                        beta,
                        v_lin_grads_ptr, lead_dim, batch_stride,
                        attn_batches);

  // Apply Dropout Mask and Scale by Dropout Probability
  apex_masked_scale_cuda<half,float,uint32_t>(
                        static_cast<half const*>(matmul2_grads.data_ptr()),
                        static_cast<half*>(matmul2_grads.data_ptr()),
                        static_cast<uint8_t const*>(dropout_mask.data_ptr()),
                        dropout_elems,
                        (1.0 / (1.0 - dropout_prob)));

  // Softmax Grad
  bool softmax_success = false;
  softmax_success = dispatch_softmax_backward<half, half, float>(
                        static_cast<half*>(matmul2_grads.data_ptr()),
                        static_cast<half*>(matmul2_grads.data_ptr()),
                        reinterpret_cast<half const*>(softmax_results.data_ptr()),
                        k_seq_len, k_seq_len, attn_batches*q_seq_len);
  assert(softmax_success);

  // Matmul1 Dgrad1
  gemm_switch_fp32accum(state, a_layout_n, b_layout_n,
                        head_dim, q_seq_len, k_seq_len, scale,
                        k_lin_results_ptr, lead_dim, batch_stride,
                        static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len*q_seq_len,
                        beta,
                        q_lin_grads_ptr, lead_dim, batch_stride,
                        attn_batches);

  // Matmul1 Dgrad2
  gemm_switch_fp32accum(state, a_layout_n, b_layout_t,
                        head_dim, k_seq_len, q_seq_len, scale,
                        q_lin_results_ptr, lead_dim, batch_stride,
                        static_cast<half*>(matmul2_grads.data_ptr()), k_seq_len, k_seq_len*q_seq_len,
                        beta,
                        k_lin_grads_ptr, lead_dim, batch_stride,
                        attn_batches);

  // Input Linear Dgrad
  THCublasCheck(cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                             embed_dim, batches, output_lin_dim,
                             static_cast<const void*>(&alpha),
                             static_cast<const void*>(input_weights.data_ptr()), CUDA_R_16F, embed_dim,
                             static_cast<const void*>(q_lin_grads_ptr),          CUDA_R_16F, output_lin_dim,
                             static_cast<const void*>(&beta),
                             static_cast<void*>(input_grads.data_ptr()),         CUDA_R_16F, embed_dim,
                             CUDA_R_32F,
                             //CUBLAS_GEMM_ALGO10_TENSOR_OP));
                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));

  // Input Linear Wgrad
  THCublasCheck(cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_T,
                             embed_dim, output_lin_dim, batches,
                             static_cast<const void*>(&alpha),
                             static_cast<const void*>(inputs.data_ptr()),       CUDA_R_16F, embed_dim,
                             static_cast<const void*>(q_lin_grads_ptr),         CUDA_R_16F, output_lin_dim,
                             static_cast<const void*>(&beta),
                             static_cast<void*>(input_weight_grads.data_ptr()), CUDA_R_16F, embed_dim,
                             CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));

  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));

  return {input_grads, input_weight_grads, output_weight_grads};
}

} // end namespace cublas_gemmex
} // end namespace self
} // end namespace multihead_attn
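The forward pass above saves a per-element byte dropout mask, and the backward pass rescales the surviving attention-probability gradients by 1/(1 - dropout_prob) before the softmax gradient. A CPU-side C++ analogue of that masked-scale step follows; it is an illustration only, and it assumes the mask stores 1 for kept elements and 0 for dropped ones.

// Illustrative CPU analogue of the masked-scale step used in the backward pass.
#include <cstdint>
#include <cstddef>

void masked_scale(const float* in, float* out, const uint8_t* mask, size_t n, float scale) {
  for (size_t i = 0; i < n; ++i)
    out[i] = mask[i] ? in[i] * scale : 0.0f;   // e.g. scale = 1.0f / (1.0f - dropout_prob)
}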
apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add.cpp
0 → 100644
#include <torch/extension.h>
#include <vector>

namespace multihead_attn {
namespace self_norm_add {
namespace cublas_gemmex {

std::vector<torch::Tensor> fwd_cuda(
    bool use_time_mask, bool is_training, int heads,
    torch::Tensor const& inputs,
    torch::Tensor const& lyr_nrm_gamma_weights, torch::Tensor const& lyr_nrm_beta_weights,
    torch::Tensor const& input_weights, torch::Tensor const& output_weights,
    const uint8_t* pad_mask, float dropout_prob);

std::vector<torch::Tensor> bwd_cuda(
    int heads,
    torch::Tensor const& output_grads, torch::Tensor const& matmul2_results,
    torch::Tensor const& dropout_results, torch::Tensor const& softmax_results,
    torch::Tensor const& input_lin_results, torch::Tensor const& lyr_nrm_results,
    torch::Tensor const& lyr_nrm_mean, torch::Tensor const& lyr_nrm_invvar,
    torch::Tensor const& inputs,
    torch::Tensor const& lyr_nrm_gamma_weights, torch::Tensor const& lyr_nrm_beta_weights,
    torch::Tensor const& input_weights, torch::Tensor const& output_weights,
    torch::Tensor const& dropout_mask, torch::Tensor const& dropout_add_mask,
    float dropout_prob);

// C++ interface
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

std::vector<torch::Tensor> fwd(
    bool use_mask, bool use_time_mask, bool is_training, int heads,
    torch::Tensor const& inputs,
    torch::Tensor const& lyr_nrm_gamma_weights, torch::Tensor const& lyr_nrm_beta_weights,
    torch::Tensor const& input_weights, torch::Tensor const& output_weights,
    torch::Tensor const& pad_mask, float dropout_prob)
{
  AT_ASSERTM(inputs.dim()                == 3, "expected 3D tensor");
  AT_ASSERTM(lyr_nrm_gamma_weights.dim() == 1, "expected 1D tensor");
  AT_ASSERTM(lyr_nrm_beta_weights.dim()  == 1, "expected 1D tensor");
  AT_ASSERTM(input_weights.dim()         == 2, "expected 2D tensor");
  AT_ASSERTM(output_weights.dim()        == 2, "expected 2D tensor");

  AT_ASSERTM(inputs.type().scalarType()                == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(lyr_nrm_gamma_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(lyr_nrm_beta_weights.type().scalarType()  == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(input_weights.type().scalarType()         == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType()        == at::ScalarType::Half, "Only HALF is supported");

  if (use_mask) {
    AT_ASSERTM(pad_mask.dim()               == 2, "expected 2D tensor");
    AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
  }

  return fwd_cuda(use_time_mask, is_training, heads, inputs,
                  lyr_nrm_gamma_weights, lyr_nrm_beta_weights,
                  input_weights, output_weights,
                  use_mask ? static_cast<const uint8_t*>(pad_mask.data_ptr()) : nullptr,
                  dropout_prob);
}

std::vector<torch::Tensor> bwd(
    int heads,
    torch::Tensor const& output_grads, torch::Tensor const& matmul2_results,
    torch::Tensor const& dropout_results, torch::Tensor const& softmax_results,
    torch::Tensor const& input_lin_results, torch::Tensor const& lyr_nrm_results,
    torch::Tensor const& lyr_nrm_mean, torch::Tensor const& lyr_nrm_invvar,
    torch::Tensor const& inputs,
    torch::Tensor const& lyr_nrm_gamma_weights, torch::Tensor const& lyr_nrm_beta_weights,
    torch::Tensor const& input_weights, torch::Tensor const& output_weights,
    torch::Tensor const& dropout_mask, torch::Tensor const& dropout_add_mask,
    float dropout_prob)
{
  AT_ASSERTM(output_grads.dim()          == 3, "expected 3D tensor");
  AT_ASSERTM(matmul2_results.dim()       == 3, "expected 3D tensor");
  AT_ASSERTM(dropout_results.dim()       == 3, "expected 3D tensor");
  AT_ASSERTM(softmax_results.dim()       == 3, "expected 3D tensor");
  AT_ASSERTM(input_lin_results.dim()     == 3, "expected 3D tensor");
  AT_ASSERTM(lyr_nrm_results.dim()       == 3, "expected 3D tensor");
  AT_ASSERTM(lyr_nrm_mean.dim()          == 1, "expected 1D tensor");
  AT_ASSERTM(lyr_nrm_invvar.dim()        == 1, "expected 1D tensor");
  AT_ASSERTM(inputs.dim()                == 3, "expected 3D tensor");
  AT_ASSERTM(lyr_nrm_gamma_weights.dim() == 1, "expected 1D tensor");
  AT_ASSERTM(lyr_nrm_beta_weights.dim()  == 1, "expected 1D tensor");
  AT_ASSERTM(input_weights.dim()         == 2, "expected 2D tensor");
  AT_ASSERTM(output_weights.dim()        == 2, "expected 2D tensor");
  AT_ASSERTM(dropout_mask.dim()          == 3, "expected 3D tensor");
  AT_ASSERTM(dropout_add_mask.dim()      == 3, "expected 3D tensor");

  AT_ASSERTM(output_grads.type().scalarType()          == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(matmul2_results.type().scalarType()       == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(dropout_results.type().scalarType()       == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(softmax_results.type().scalarType()       == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(input_lin_results.type().scalarType()     == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(lyr_nrm_results.type().scalarType()       == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(lyr_nrm_mean.type().scalarType()          == at::ScalarType::Float, "Only FLOAT is supported");
  AT_ASSERTM(lyr_nrm_invvar.type().scalarType()        == at::ScalarType::Float, "Only FLOAT is supported");
  AT_ASSERTM(inputs.type().scalarType()                == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(lyr_nrm_gamma_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(lyr_nrm_beta_weights.type().scalarType()  == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(input_weights.type().scalarType()         == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(output_weights.type().scalarType()        == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(dropout_mask.type().scalarType()          == at::ScalarType::Byte, "Only BYTE is supported");
  AT_ASSERTM(dropout_add_mask.type().scalarType()      == at::ScalarType::Byte, "Only BYTE is supported");

  return bwd_cuda(heads,
                  output_grads, matmul2_results, dropout_results, softmax_results,
                  input_lin_results, lyr_nrm_results, lyr_nrm_mean, lyr_nrm_invvar,
                  inputs, lyr_nrm_gamma_weights, lyr_nrm_beta_weights,
                  input_weights, output_weights,
                  dropout_mask, dropout_add_mask,
                  dropout_prob);
}

} // end namespace cublas_gemmex
} // end namespace self_norm_add
} // end namespace multihead_attn

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward",  &multihead_attn::self_norm_add::cublas_gemmex::fwd,  "Self Multihead Attention Plus Layer Norm and Residual Add Forward.");
  m.def("backward", &multihead_attn::self_norm_add::cublas_gemmex::bwd, "Self Multihead Attention Plus Layer Norm and Residual Add Backward.");
}
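Relative to the plain self-attention binding, this _norm_add variant also carries layer-norm statistics (lyr_nrm_mean, lyr_nrm_invvar) and a second byte mask (dropout_add_mask) because the fused block ends with a dropout on the attention output followed by a residual add. The sketch below is a CPU-side analogue of that final elementwise step, for illustration only; it assumes the usual inverted-dropout convention with mask[i] = 1 for kept elements.

// Illustrative CPU analogue of the fused end-of-block dropout-add.
#include <cstdint>
#include <cstddef>

void dropout_add(const float* attn_out, const float* residual, float* out,
                 const uint8_t* mask, size_t n, float keep_prob) {
  for (size_t i = 0; i < n; ++i) {
    float kept = mask[i] ? attn_out[i] / keep_prob : 0.0f;  // inverted dropout
    out[i] = kept + residual[i];                            // residual add back to the block input
  }
}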
apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu  0 → 100644
#include <vector>
#include <iostream>

#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include "THC/THC.h"
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <math.h>

#include "strided_batched_gemm.h"
#include "softmax.h"
#include "dropout.h"
#include "layer_norm.h"

// symbol to be automatically resolved by PyTorch libs
extern THCState* state;

namespace multihead_attn {
namespace self_norm_add {
namespace cublas_gemmex {

std::vector<torch::Tensor> fwd_cuda(
                               bool                 use_time_mask,
                               bool                 is_training,
                               int                  heads,
                               torch::Tensor const& inputs,
                               torch::Tensor const& lyr_nrm_gamma_weights,
                               torch::Tensor const& lyr_nrm_beta_weights,
                               torch::Tensor const& input_weights,
                               torch::Tensor const& output_weights,
                               const uint8_t*       pad_mask,
                               float                dropout_prob)
{
  const int   embed_dim      = inputs.size(2);
  const int   sequences      = inputs.size(1);
  const int   q_seq_len      = inputs.size(0);
  const int   k_seq_len      = q_seq_len;
  const int   batches        = sequences * q_seq_len;
  const int   total_tokens   = batches * embed_dim;
  const int   head_dim       = embed_dim / heads;
  const int   output_lin_dim = 3 * embed_dim;
  const int   attn_batches   = heads * sequences;
  const int   lead_dim       = attn_batches * 3 * head_dim;
  const int   batch_stride   = 3 * head_dim;
  const int   dropout_elems  = attn_batches * q_seq_len * k_seq_len;
  const float alpha          = 1.0;
  const float beta           = 0.0;
  const float scale          = 1.0 / sqrt(static_cast<float>(head_dim));

  // There is no reason to use more than one stream as every kernel is
  // sequentially dependent
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t   stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);

  // 3 Intermediate Results + Output (Note: dropout intermediates are generated by ATen library code)
  auto act_options     = inputs.options().requires_grad(false);
  auto lyr_nrm_options = act_options.dtype(torch::kFloat32);
  auto mask_options    = act_options.dtype(torch::kUInt8);

  torch::Tensor lyr_nrm_mean       = torch::empty({batches}, lyr_nrm_options);
  torch::Tensor lyr_nrm_invvar     = torch::empty({batches}, lyr_nrm_options);
  torch::Tensor lyr_nrm_results    = torch::empty_like(inputs, act_options);
  torch::Tensor input_lin_results  = torch::empty({q_seq_len, sequences, output_lin_dim}, act_options);
  torch::Tensor softmax_results    = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_results    = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
  torch::Tensor dropout_mask       = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
  torch::Tensor matmul2_results    = torch::empty({q_seq_len, attn_batches, head_dim}, act_options);
  torch::Tensor output_lin_results = torch::empty_like(inputs, act_options);
  torch::Tensor dropout_add_mask   = torch::empty_like(inputs, mask_options);
  torch::Tensor outputs            = torch::empty_like(inputs, act_options);

  // Input Linear Results Pointers to Q, K, and V of interleaved activations
  void* q_lin_results_ptr = static_cast<void*>(input_lin_results.data_ptr());
  void* k_lin_results_ptr = static_cast<void*>(static_cast<half*>(input_lin_results.data_ptr()) + head_dim);
  void* v_lin_results_ptr = static_cast<void*>(static_cast<half*>(input_lin_results.data_ptr()) + 2 * head_dim);

  // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
  void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());

  char a_layout_t{'t'};
  char a_layout_n{'n'};
  char b_layout_n{'n'};

  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));

  // Layer Norm
  HostApplyLayerNorm<at::Half, float>(
      static_cast<at::Half*>(lyr_nrm_results.data_ptr()),
      static_cast<float*>(lyr_nrm_mean.data_ptr()),
      static_cast<float*>(lyr_nrm_invvar.data_ptr()),
      static_cast<const at::Half*>(inputs.data_ptr()),
      static_cast<int>(batches),    // n1
      static_cast<int>(embed_dim),  // n2
      1.0e-5,
      static_cast<const at::Half*>(lyr_nrm_gamma_weights.data_ptr()),
      static_cast<const at::Half*>(lyr_nrm_beta_weights.data_ptr()));

  // Input Linear Fwd
  THCublasCheck(cublasGemmEx(handle,
                             CUBLAS_OP_T,
                             CUBLAS_OP_N,
                             output_lin_dim,
                             batches,
                             embed_dim,
                             static_cast<const void*>(&alpha),
                             static_cast<const void*>(input_weights.data_ptr()),
                             CUDA_R_16F,
                             embed_dim,
                             //static_cast<const void*>(inputs.data_ptr()),
                             static_cast<const void*>(lyr_nrm_results.data_ptr()),
                             CUDA_R_16F,
                             embed_dim,
                             static_cast<const void*>(&beta),
                             q_lin_results_ptr,
                             CUDA_R_16F,
                             output_lin_dim,
                             CUDA_R_32F,
                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));

  // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
  gemm_switch_fp32accum(state,
                        a_layout_t,
                        b_layout_n,
                        k_seq_len,
                        q_seq_len,
                        head_dim,
                        scale,
                        static_cast<const half*>(k_lin_results_ptr),
                        lead_dim,
                        batch_stride,
                        static_cast<const half*>(q_lin_results_ptr),
                        lead_dim,
                        batch_stride,
                        beta,
                        static_cast<half*>(softmax_results_ptr),
                        k_seq_len,
                        k_seq_len * q_seq_len,
                        attn_batches);

  // Padded Softmax
  bool softmax_success = false;
  if (pad_mask == nullptr) {
    softmax_success = dispatch_softmax<half, half, float>(
                          reinterpret_cast<half*>(softmax_results_ptr),
                          reinterpret_cast<const half*>(softmax_results_ptr),
                          k_seq_len,
                          k_seq_len,
                          attn_batches * q_seq_len);
  } else {
    if (use_time_mask) {
      softmax_success = dispatch_time_masked_softmax<half, half, float>(
                            reinterpret_cast<half*>(softmax_results_ptr),
                            reinterpret_cast<const half*>(softmax_results_ptr),
                            pad_mask,
                            k_seq_len,
                            k_seq_len,
                            attn_batches * q_seq_len,
                            q_seq_len);
    } else {
      softmax_success = dispatch_masked_softmax<half, half, float>(
                            reinterpret_cast<half*>(softmax_results_ptr),
                            reinterpret_cast<const half*>(softmax_results_ptr),
                            pad_mask,
                            k_seq_len,
                            k_seq_len,
                            attn_batches * q_seq_len,
                            attn_batches * q_seq_len / sequences);
    }
  }
  assert(softmax_success);

  if (is_training) {
    apex_fused_dropout_cuda<half, float, uint32_t>(
                          static_cast<half const*>(softmax_results.data_ptr()),
                          static_cast<half*>(dropout_results.data_ptr()),
                          static_cast<uint8_t*>(dropout_mask.data_ptr()),
                          dropout_elems,
                          (1.0f - dropout_prob));
  }

  // Matmul2
  gemm_switch_fp32accum(state,
                        a_layout_n,
                        b_layout_n,
                        head_dim,
                        q_seq_len,
                        k_seq_len,
                        alpha,
                        static_cast<const half*>(v_lin_results_ptr),
                        lead_dim,
                        batch_stride,
                        (is_training) ? static_cast<const half*>(dropout_results.data_ptr())
                                      : static_cast<const half*>(softmax_results.data_ptr()),
                        //static_cast<const half*>(dropout_results.data_ptr()),
                        k_seq_len,
                        k_seq_len * q_seq_len,
                        beta,
                        static_cast<half*>(matmul2_results.data_ptr()),
                        head_dim * attn_batches,
                        head_dim,
                        attn_batches);

  // Output Linear
  THCublasCheck(cublasGemmEx(handle,
                             CUBLAS_OP_T,
                             CUBLAS_OP_N,
                             embed_dim,
                             batches,
                             embed_dim,
                             static_cast<const void*>(&alpha),
                             static_cast<const void*>(output_weights.data_ptr()),
                             CUDA_R_16F,
                             embed_dim,
                             static_cast<const void*>(matmul2_results.data_ptr()),
                             CUDA_R_16F,
                             embed_dim,
                             static_cast<const void*>(&beta),
                             static_cast<void*>(output_lin_results.data_ptr()),
                             CUDA_R_16F,
                             embed_dim,
                             CUDA_R_32F,
                             //CUBLAS_GEMM_ALGO1_TENSOR_OP));
                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));

  // End-of-block Dropout-Add
  if (is_training) {
    apex_dropout_add_cuda<half, float, uint32_t>(
                          static_cast<half const*>(output_lin_results.data_ptr()),
                          static_cast<half const*>(inputs.data_ptr()),
                          static_cast<half*>(outputs.data_ptr()),
                          static_cast<uint8_t*>(dropout_add_mask.data_ptr()),
                          total_tokens,
                          (1.0f - dropout_prob));
  } else {
    apex_add_cuda<half, float, uint32_t>(
                          static_cast<half const*>(output_lin_results.data_ptr()),
                          static_cast<half const*>(inputs.data_ptr()),
                          static_cast<half*>(outputs.data_ptr()),
                          total_tokens);
  }

  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));

  return {lyr_nrm_results,
          lyr_nrm_mean,
          lyr_nrm_invvar,
          input_lin_results,
          softmax_results,
          dropout_results,
          dropout_mask,
          matmul2_results,
          dropout_add_mask,
          outputs};
}
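// Shape bookkeeping for the interleaved QKV buffer above (illustrative numbers, not
// values from this commit): with embed_dim = 1024 and heads = 16,
//   head_dim       = 1024 / 16           = 64
//   output_lin_dim = 3 * 1024            = 3072  (Q, K, V produced by one fused GEMM)
//   attn_batches   = 16 * sequences
//   batch_stride   = 3 * 64              = 192   (elements between consecutive heads)
//   lead_dim       = attn_batches * 192          (row stride of the packed QKV matrix)
// so k_lin_results_ptr = base + 64 and v_lin_results_ptr = base + 128 address the K
// and V slices inside each head's 192-element block.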
std::vector<torch::Tensor> bwd_cuda(
                               int                  heads,
                               torch::Tensor const& output_grads,
                               torch::Tensor const& matmul2_results,
                               torch::Tensor const& dropout_results,
                               torch::Tensor const& softmax_results,
                               torch::Tensor const& input_lin_results,
                               torch::Tensor const& lyr_nrm_results,
                               torch::Tensor const& lyr_nrm_mean,
                               torch::Tensor const& lyr_nrm_invvar,
                               torch::Tensor const& inputs,
                               torch::Tensor const& lyr_nrm_gamma_weights,
                               torch::Tensor const& lyr_nrm_beta_weights,
                               torch::Tensor const& input_weights,
                               torch::Tensor const& output_weights,
                               torch::Tensor const& dropout_mask,
                               torch::Tensor const& dropout_add_mask,
                               float                dropout_prob)
{
  const int   embed_dim      = inputs.size(2);
  const int   sequences      = inputs.size(1);
  const int   q_seq_len      = inputs.size(0);
  const int   k_seq_len      = q_seq_len;
  const int   batches        = sequences * q_seq_len;
  const int   total_tokens   = batches * embed_dim;
  const int   head_dim       = embed_dim / heads;
  const int   output_lin_dim = 3 * embed_dim;
  const int   attn_batches   = heads * sequences;
  const int   lead_dim       = attn_batches * 3 * head_dim;
  const int   batch_stride   = 3 * head_dim;
  const int   dropout_elems  = attn_batches * q_seq_len * k_seq_len;
  const float alpha          = 1.0;
  const float beta           = 0.0;
  const float scale          = 1.0 / sqrt(static_cast<float>(head_dim));

  // TODO: Streams can be used in Backprop but I haven't added more than one
  // in my first attempt to create the code
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t   stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);

  // Output Tensor Allocations
  torch::Tensor input_grads         = torch::empty_like(inputs);
  torch::Tensor lyr_nrm_gamma_grads = torch::empty_like(lyr_nrm_gamma_weights);
  torch::Tensor lyr_nrm_beta_grads  = torch::empty_like(lyr_nrm_beta_weights);
  torch::Tensor input_weight_grads  = torch::empty_like(input_weights);
  torch::Tensor output_weight_grads = torch::empty_like(output_weights);
  // Intermediate Tensor Allocations
  torch::Tensor output_lin_grads       = torch::empty_like(matmul2_results);
  torch::Tensor matmul2_grads          = torch::empty_like(dropout_results);
  torch::Tensor input_lin_output_grads = torch::empty_like(input_lin_results);
  torch::Tensor input_lin_grads        = torch::empty_like(inputs);

  auto q_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr());
  auto k_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr()) + head_dim;
  auto v_lin_results_ptr = static_cast<half*>(input_lin_results.data_ptr()) + 2 * head_dim;

  auto q_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr());
  auto k_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr()) + head_dim;
  auto v_lin_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr()) + 2 * head_dim;

  char a_layout_n{'n'};
  char a_layout_t{'t'};
  char b_layout_n{'n'};
  char b_layout_t{'t'};

  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));

  // Dropout Add Backward
  apex_masked_scale_cuda<half, float, uint32_t>(
                        static_cast<half const*>(output_grads.data_ptr()),
                        static_cast<half*>(output_grads.data_ptr()),
                        static_cast<uint8_t const*>(dropout_add_mask.data_ptr()),
                        total_tokens,
                        (1.0 / (1.0 - dropout_prob)));

  // Output Linear Dgrad
  THCublasCheck(cublasGemmEx(handle,
                             CUBLAS_OP_N,
                             CUBLAS_OP_N,
                             embed_dim,
                             batches,
                             embed_dim,
                             static_cast<const void*>(&alpha),
                             static_cast<const void*>(output_weights.data_ptr()),
                             CUDA_R_16F,
                             embed_dim,
                             static_cast<const void*>(output_grads.data_ptr()),
                             CUDA_R_16F,
                             embed_dim,
                             static_cast<const void*>(&beta),
                             static_cast<void*>(output_lin_grads.data_ptr()),
                             CUDA_R_16F,
                             embed_dim,
                             CUDA_R_32F,
                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));

  // Output Linear Wgrad
  THCublasCheck(cublasGemmEx(handle,
                             CUBLAS_OP_N,
                             CUBLAS_OP_T,
                             embed_dim,
                             embed_dim,
                             batches,
                             static_cast<const void*>(&alpha),
                             static_cast<const void*>(matmul2_results.data_ptr()),
                             CUDA_R_16F,
                             embed_dim,
                             static_cast<const void*>(output_grads.data_ptr()),
                             CUDA_R_16F,
                             embed_dim,
                             static_cast<const void*>(&beta),
                             static_cast<void*>(output_weight_grads.data_ptr()),
                             CUDA_R_16F,
                             embed_dim,
                             CUDA_R_32F,
                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));

  // MatMul2 Dgrad1
  gemm_switch_fp32accum(state,
                        a_layout_t,
                        b_layout_n,
                        k_seq_len,
                        q_seq_len,
                        head_dim,
                        alpha,
                        static_cast<const half*>(v_lin_results_ptr),
                        lead_dim,
                        batch_stride,
                        static_cast<const half*>(output_lin_grads.data_ptr()),
                        head_dim * attn_batches,
                        head_dim,
                        beta,
                        static_cast<half*>(matmul2_grads.data_ptr()),
                        k_seq_len,
                        k_seq_len * q_seq_len,
                        attn_batches);

  // Matmul2 Dgrad2
  gemm_switch_fp32accum(state,
                        a_layout_n,
                        b_layout_t,
                        head_dim,
                        k_seq_len,
                        q_seq_len,
                        alpha,
                        static_cast<const half*>(output_lin_grads.data_ptr()),
                        head_dim * attn_batches,
                        head_dim,
                        static_cast<const half*>(dropout_results.data_ptr()),
                        k_seq_len,
                        k_seq_len * q_seq_len,
                        beta,
                        v_lin_grads_ptr,
                        lead_dim,
                        batch_stride,
                        attn_batches);

  // Apply Dropout Mask and Scale by Dropout Probability
  apex_masked_scale_cuda<half, float, uint32_t>(
                        static_cast<half const*>(matmul2_grads.data_ptr()),
                        static_cast<half*>(matmul2_grads.data_ptr()),
                        static_cast<uint8_t const*>(dropout_mask.data_ptr()),
                        dropout_elems,
                        (1.0 / (1.0 - dropout_prob)));

  // Softmax Grad
  bool softmax_success = false;
  softmax_success = dispatch_softmax_backward<half, half, float>(
                        static_cast<half*>(matmul2_grads.data_ptr()),
                        static_cast<half*>(matmul2_grads.data_ptr()),
                        reinterpret_cast<half const*>(softmax_results.data_ptr()),
                        k_seq_len,
                        k_seq_len,
                        attn_batches * q_seq_len);
  assert(softmax_success);

  // Matmul1 Dgrad1
  gemm_switch_fp32accum(state,
                        a_layout_n,
                        b_layout_n,
                        head_dim,
                        q_seq_len,
                        k_seq_len,
                        scale,
                        k_lin_results_ptr,
                        lead_dim,
                        batch_stride,
                        static_cast<half*>(matmul2_grads.data_ptr()),
                        k_seq_len,
                        k_seq_len * q_seq_len,
                        beta,
                        q_lin_grads_ptr,
                        lead_dim,
                        batch_stride,
                        attn_batches);

  // Matmul1 Dgrad2
  gemm_switch_fp32accum(state,
                        a_layout_n,
                        b_layout_t,
                        head_dim,
                        k_seq_len,
                        q_seq_len,
                        scale,
                        q_lin_results_ptr,
                        lead_dim,
                        batch_stride,
                        static_cast<half*>(matmul2_grads.data_ptr()),
                        k_seq_len,
                        k_seq_len * q_seq_len,
                        beta,
                        k_lin_grads_ptr,
                        lead_dim,
                        batch_stride,
                        attn_batches);

  // Input Linear Dgrad
  THCublasCheck(cublasGemmEx(handle,
                             CUBLAS_OP_N,
                             CUBLAS_OP_N,
                             embed_dim,
                             batches,
                             output_lin_dim,
                             static_cast<const void*>(&alpha),
                             static_cast<const void*>(input_weights.data_ptr()),
                             CUDA_R_16F,
                             embed_dim,
                             static_cast<const void*>(q_lin_grads_ptr),
                             CUDA_R_16F,
                             output_lin_dim,
                             static_cast<const void*>(&beta),
                             //static_cast<void*>(input_grads.data_ptr()),
                             static_cast<void*>(input_lin_grads.data_ptr()),
                             CUDA_R_16F,
                             embed_dim,
                             CUDA_R_32F,
                             //CUBLAS_GEMM_ALGO10_TENSOR_OP));
                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));

  // Input Linear Wgrad
  THCublasCheck(cublasGemmEx(handle,
                             CUBLAS_OP_N,
                             CUBLAS_OP_T,
                             embed_dim,
                             output_lin_dim,
                             batches,
                             static_cast<const void*>(&alpha),
                             //static_cast<const void*>(inputs.data_ptr()),
                             static_cast<const void*>(lyr_nrm_results.data_ptr()),
                             CUDA_R_16F,
                             embed_dim,
                             static_cast<const void*>(q_lin_grads_ptr),
                             CUDA_R_16F,
                             output_lin_dim,
                             static_cast<const void*>(&beta),
                             static_cast<void*>(input_weight_grads.data_ptr()),
                             CUDA_R_16F,
                             embed_dim,
                             CUDA_R_32F,
                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));

  // Fused Layer Norm Bwd with Residual Add
  HostLayerNormGradient<half, float>(
                        static_cast<const half*>(input_lin_grads.data_ptr()),
                        static_cast<half const*>(output_grads.data_ptr()),
                        static_cast<const float*>(lyr_nrm_mean.data_ptr()),
                        static_cast<const float*>(lyr_nrm_invvar.data_ptr()),
                        inputs,
                        static_cast<int>(batches),    // n1
                        static_cast<int>(embed_dim),  // n2
                        static_cast<const half*>(lyr_nrm_gamma_weights.data_ptr()),
                        static_cast<const half*>(lyr_nrm_beta_weights.data_ptr()),
                        1.0e-5,
                        static_cast<half*>(input_grads.data_ptr()),
                        static_cast<half*>(lyr_nrm_gamma_grads.data_ptr()),
                        static_cast<half*>(lyr_nrm_beta_grads.data_ptr()));

  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));

  return {input_grads,
          lyr_nrm_gamma_grads,
          lyr_nrm_beta_grads,
          input_weight_grads,
          output_weight_grads};
}

} // end namespace cublas_gemmex
} // end namespace self_norm_add
} // end namespace multihead_attn
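A note on scaling (my summary, not text from the commit): both passes fold the standard attention scale into the `alpha` of the Q·Kᵀ batched GEMM rather than launching a separate scaling kernel, i.e.

$$\mathrm{Attn}(Q,K,V)=\mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_{\mathrm{head}}}}\right)V,\qquad \texttt{scale}=1/\sqrt{\texttt{head\_dim}},$$

with `scale` passed as the scalar multiplier to `gemm_switch_fp32accum` for Matmul1 in the forward pass and for both Matmul1 dgrads in the backward pass.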
apex/contrib/csrc/multihead_attn/softmax.h  0 → 100644
#pragma once

#include <assert.h>
#include <cfloat>
#include <limits>
#include <stdint.h>
#include <cuda_fp16.h>
#include <cmath>

namespace {

template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);

template <>
__device__ __inline__ void copy_vector<__half, 1>(__half *dst, const __half *src) { *dst = *src; }

template <>
__device__ __inline__ void copy_vector<float, 1>(float *dst, const float *src) { *dst = *src; }

template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void apply_mask(Datatype *dst, Datatype value, const uint8_t *src);

template <>
__device__ __inline__ void apply_mask<__half, 1>(__half *dst, __half value, const uint8_t *src) {
  if (*src == 1) { *dst = value; }
}

} // namespace anonymous
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Warp Softmax forward
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////

// WARP_BATCH           number of batches.
// WARP_ITERATIONS      The number of iterations required for one warp to iterate over all data.
// WARP_SIZE            number of elements working on a single batch, has to be a power of two.
// ELEMENTS_PER_LDG_STG has to be 1.
template <typename input_t, typename output_t, typename acc_t,
          int WARP_BATCH, int WARP_ITERATIONS, int WARP_SIZE = 32, int ELEMENTS_PER_LDG_STG = 1>
__global__ void softmax_warp_forward(input_t *dst, const output_t *src,
                                     int batch_size, int stride, int element_count)
{
  assert(ELEMENTS_PER_LDG_STG == 1);

  int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;

  // batch_size might not be a multiple of WARP_BATCH. Check how
  // many batches have to be computed within this WARP.
  int local_batches = batch_size - first_batch;
  if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;

  // there might be multiple batches per warp. compute the index within the batch
  int local_idx = threadIdx.x;

  src += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
  dst += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;

  // load data from global memory
  input_t elements_input[WARP_BATCH][WARP_ITERATIONS];
  for (int i = 0; i < WARP_BATCH; ++i) {
    int batch_element_count = (i >= local_batches) ? 0 : element_count;
    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
      #pragma unroll
      for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
        elements_input[i][it + element] = -std::numeric_limits<float>::infinity();
      }
      if (element_index < batch_element_count) {
        copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&elements_input[i][it],
                                                   src + i * element_count + it * WARP_SIZE);
      }
    }
  }

  // convert input_t to acc_t
  acc_t elements[WARP_BATCH][WARP_ITERATIONS];
  for (int i = 0; i < WARP_BATCH; ++i)
    for (int it = 0; it < WARP_ITERATIONS; ++it)
      elements[i][it] = elements_input[i][it];

  constexpr uint32_t FULL_MASK = 0xffffffff;

  // compute local max_value
  // take the max_value of the first element to avoid one max call
  acc_t max_value[WARP_BATCH];
  #pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i)
    max_value[i] = elements[i][0];
  #pragma unroll
  for (int it = 1; it < WARP_ITERATIONS; ++it)
    for (int i = 0; i < WARP_BATCH; ++i)
      max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];

  // reduction max_value
  #pragma unroll
  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
    float val[WARP_BATCH];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i)
      val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i)
      max_value[i] = max_value[i] > val[i] ? max_value[i] : val[i];
  }

  // compute local sum
  acc_t sum[WARP_BATCH] { 0.0f };
  #pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    for (int it = 0; it < WARP_ITERATIONS; ++it) {
      //elements[i][it] = expf(elements[i][it] - max_value[i]);
      elements[i][it] = std::exp(elements[i][it] - max_value[i]);
      sum[i] += elements[i][it];
    }
  }

  // reduction sum
  #pragma unroll
  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i)
      sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
  }

  // store result
  #pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    if (i >= local_batches) break;
    #pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
      if (element_index < element_count) {
        //dst[i * element_count + it * WARP_SIZE] = elements[i][it] / sum[i];
        output_t out[ELEMENTS_PER_LDG_STG];
        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element)
          out[element] = elements[i][it + element] / sum[i];
        copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);
      } else {
        break;
      }
    }
  }
}
// WARP_BATCH           number of batches.
// WARP_ITERATIONS      The number of iterations required for one warp to iterate over all data.
// WARP_SIZE            number of elements working on a single batch, has to be a power of two.
// ELEMENTS_PER_LDG_STG has to be 1.
template <typename input_t, typename output_t>
using softmax_forward_func = void (*)(input_t *dst, const output_t *src,
                                      int batch_size, int stride, int element_count);

template <typename input_t, typename output_t, typename acc_t>
bool warp_softmax_kernel(int log2_elements, int &warp_size, int &batches_per_warp,
                         softmax_forward_func<input_t, output_t> &kernel)
{
  // determine size of a warp
  const int next_power_of_two = 1 << log2_elements;
  warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;

  // determine how many batches a warp should process.
  batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

  switch (log2_elements) {
    case 0:  kernel = &softmax_warp_forward<input_t, output_t, acc_t, 2,  1,  1, 1>; break; // 1
    case 1:  kernel = &softmax_warp_forward<input_t, output_t, acc_t, 2,  1,  2, 1>; break; // 2
    case 2:  kernel = &softmax_warp_forward<input_t, output_t, acc_t, 2,  1,  4, 1>; break; // 4
    case 3:  kernel = &softmax_warp_forward<input_t, output_t, acc_t, 2,  1,  8, 1>; break; // 8
    case 4:  kernel = &softmax_warp_forward<input_t, output_t, acc_t, 2,  1, 16, 1>; break; // 16
    case 5:  kernel = &softmax_warp_forward<input_t, output_t, acc_t, 2,  1, 32, 1>; break; // 32
    case 6:  kernel = &softmax_warp_forward<input_t, output_t, acc_t, 2,  2, 32, 1>; break; // 64
    case 7:  kernel = &softmax_warp_forward<input_t, output_t, acc_t, 2,  4, 32, 1>; break; // 128
    case 8:  kernel = &softmax_warp_forward<input_t, output_t, acc_t, 1,  8, 32, 1>; break; // 256
    case 9:  kernel = &softmax_warp_forward<input_t, output_t, acc_t, 1, 16, 32, 1>; break; // 512
    case 10: kernel = &softmax_warp_forward<input_t, output_t, acc_t, 1, 32, 32, 1>; break; // 1024
    default: return false;
  }
  return true;
}

template <typename input_t, typename output_t, typename acc_t>
bool dispatch_softmax(output_t *dst, const input_t *src, int softmax_elements,
                      int softmax_elements_stride, int batch_count)
{
  if (softmax_elements == 0) {
    return true;
  } else if (softmax_elements <= 1024) {
    // compute function index. there's a function for each power of two size up to 1024.
    int log2_elements = 0;
    while ((1 << log2_elements) < softmax_elements) ++log2_elements;

    softmax_forward_func<input_t, output_t> kernel;
    int warp_size, batches_per_warp;
    if (!warp_softmax_kernel<input_t, output_t, acc_t>(log2_elements, warp_size, batches_per_warp, kernel)) {
      return false;
    }

    // use 128 threads per block to maximize gpu utilization
    constexpr int threads_per_block = 128;
    // compute warps per block.
    int warps_per_block = (threads_per_block / warp_size);
    // compute launch size
    int batches_per_block = warps_per_block * batches_per_warp;
    int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
    dim3 threads(warp_size, warps_per_block, 1);
    // launch
    kernel<<<blocks, threads>>>(dst, src, batch_count, softmax_elements_stride, softmax_elements);
    return true;
  }
  return false;
}
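// Worked launch example for dispatch_softmax (illustrative only, not part of the original header):
// for softmax_elements = 384 the loop above selects log2_elements = 9 (2^9 = 512 >= 384), which maps
// to softmax_warp_forward<..., 1, 16, 32, 1>, warp_size = 32 and batches_per_warp = 1. With
// threads_per_block = 128 this yields warps_per_block = 4, batches_per_block = 4, and
// blocks = ceil(batch_count / 4), launched as dim3 threads(32, 4, 1).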
// WARP_BATCH           number of batches.
// WARP_ITERATIONS      The number of iterations required for one warp to iterate over all data.
// WARP_SIZE            number of elements working on a single batch, has to be a power of two.
// ELEMENTS_PER_LDG_STG has to be 1.
template <typename input_t, typename output_t, typename acc_t,
          int WARP_BATCH, int WARP_ITERATIONS, int WARP_SIZE = 32, int ELEMENTS_PER_LDG_STG = 1>
__global__ void masked_softmax_warp_forward(input_t *dst, const output_t *src, const uint8_t *pad_mask,
                                            int batch_size, int stride, int element_count, int pad_batch_stride)
{
  assert(ELEMENTS_PER_LDG_STG == 1);

  int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;

  // batch_size might not be a multiple of WARP_BATCH. Check how
  // many batches have to be computed within this WARP.
  int local_batches = batch_size - first_batch;
  if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;

  // there might be multiple batches per warp. compute the index within the batch
  int local_idx = threadIdx.x;

  int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
  src += thread_offset;
  dst += thread_offset;

  // load data from global memory
  input_t elements_input[WARP_BATCH][WARP_ITERATIONS];
  for (int i = 0; i < WARP_BATCH; ++i) {
    int batch_element_count = (i >= local_batches) ? 0 : element_count;
    int pad_thread_offset = ((first_batch + i) / pad_batch_stride) * stride + ELEMENTS_PER_LDG_STG * local_idx;
    const uint8_t *curr_mask = pad_mask + pad_thread_offset;
    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
      #pragma unroll
      for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
        elements_input[i][it + element] = -std::numeric_limits<float>::infinity();
      }
      if (element_index < batch_element_count) {
        int itr_jmp = it * WARP_SIZE;
        int itr_idx = i * element_count + itr_jmp;
        copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&elements_input[i][it], src + itr_idx);
        apply_mask<input_t, ELEMENTS_PER_LDG_STG>(&elements_input[i][it],
                                                  (__half)-std::numeric_limits<float>::infinity(),
                                                  curr_mask + itr_jmp);
      }
    }
  }

  // convert input_t to acc_t
  acc_t elements[WARP_BATCH][WARP_ITERATIONS];
  for (int i = 0; i < WARP_BATCH; ++i)
    for (int it = 0; it < WARP_ITERATIONS; ++it)
      elements[i][it] = elements_input[i][it];

  constexpr uint32_t FULL_MASK = 0xffffffff;

  // compute local max_value
  // take the max_value of the first element to avoid one max call
  acc_t max_value[WARP_BATCH];
  #pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i)
    max_value[i] = elements[i][0];
  #pragma unroll
  for (int it = 1; it < WARP_ITERATIONS; ++it)
    for (int i = 0; i < WARP_BATCH; ++i)
      max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];

  // reduction max_value
  #pragma unroll
  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
    float val[WARP_BATCH];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i)
      val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i)
      max_value[i] = max_value[i] > val[i] ? max_value[i] : val[i];
  }

  // compute local sum
  acc_t sum[WARP_BATCH] { 0.0f };
  #pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    for (int it = 0; it < WARP_ITERATIONS; ++it) {
      //elements[i][it] = expf(elements[i][it] - max_value[i]);
      elements[i][it] = std::exp(elements[i][it] - max_value[i]);
      sum[i] += elements[i][it];
    }
  }

  // reduction sum
  #pragma unroll
  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i)
      sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
  }

  // store result
  #pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    if (i >= local_batches) break;
    #pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
      if (element_index < element_count) {
        //dst[i * element_count + it * WARP_SIZE] = elements[i][it] / sum[i];
        output_t out[ELEMENTS_PER_LDG_STG];
        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element)
          out[element] = elements[i][it + element] / sum[i];
        copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);
      } else {
        break;
      }
    }
  }
}
// WARP_BATCH           number of batches.
// WARP_ITERATIONS      The number of iterations required for one warp to iterate over all data.
// WARP_SIZE            number of elements working on a single batch, has to be a power of two.
// ELEMENTS_PER_LDG_STG has to be 1.
template <typename input_t, typename output_t>
using masked_softmax_forward_func = void (*)(input_t *dst, const output_t *src, const uint8_t *pad_mask,
                                             int batch_size, int stride, int element_count, int pad_batch_stride);

template <typename input_t, typename output_t, typename acc_t>
bool warp_masked_softmax_kernel(int log2_elements, int &warp_size, int &batches_per_warp,
                                masked_softmax_forward_func<input_t, output_t> &kernel)
{
  // determine size of a warp
  const int next_power_of_two = 1 << log2_elements;
  warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;

  // determine how many batches a warp should process.
  batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

  switch (log2_elements) {
    case 0:  kernel = &masked_softmax_warp_forward<input_t, output_t, acc_t, 2,  1,  1, 1>; break; // 1
    case 1:  kernel = &masked_softmax_warp_forward<input_t, output_t, acc_t, 2,  1,  2, 1>; break; // 2
    case 2:  kernel = &masked_softmax_warp_forward<input_t, output_t, acc_t, 2,  1,  4, 1>; break; // 4
    case 3:  kernel = &masked_softmax_warp_forward<input_t, output_t, acc_t, 2,  1,  8, 1>; break; // 8
    case 4:  kernel = &masked_softmax_warp_forward<input_t, output_t, acc_t, 2,  1, 16, 1>; break; // 16
    case 5:  kernel = &masked_softmax_warp_forward<input_t, output_t, acc_t, 2,  1, 32, 1>; break; // 32
    case 6:  kernel = &masked_softmax_warp_forward<input_t, output_t, acc_t, 2,  2, 32, 1>; break; // 64
    case 7:  kernel = &masked_softmax_warp_forward<input_t, output_t, acc_t, 2,  4, 32, 1>; break; // 128
    case 8:  kernel = &masked_softmax_warp_forward<input_t, output_t, acc_t, 1,  8, 32, 1>; break; // 256
    case 9:  kernel = &masked_softmax_warp_forward<input_t, output_t, acc_t, 1, 16, 32, 1>; break; // 512
    case 10: kernel = &masked_softmax_warp_forward<input_t, output_t, acc_t, 1, 32, 32, 1>; break; // 1024
    default: return false;
  }
  return true;
}

template <typename input_t, typename output_t, typename acc_t>
bool dispatch_masked_softmax(output_t *dst, const input_t *src, const uint8_t *pad_mask,
                             int softmax_elements, int softmax_elements_stride,
                             int batch_count, int pad_batch_stride)
{
  if (softmax_elements == 0) {
    return true;
  } else if (softmax_elements <= 1024) {
    // compute function index. there's a function for each power of two size up to 1024.
    int log2_elements = 0;
    while ((1 << log2_elements) < softmax_elements) ++log2_elements;

    masked_softmax_forward_func<input_t, output_t> kernel;
    int warp_size, batches_per_warp;
    if (!warp_masked_softmax_kernel<input_t, output_t, acc_t>(log2_elements, warp_size, batches_per_warp, kernel)) {
      return false;
    }

    // use 128 threads per block to maximize gpu utilization
    constexpr int threads_per_block = 128;
    // compute warps per block.
    int warps_per_block = (threads_per_block / warp_size);
    // compute launch size
    int batches_per_block = warps_per_block * batches_per_warp;
    int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
    dim3 threads(warp_size, warps_per_block, 1);
    // launch
    kernel<<<blocks, threads>>>(dst, src, pad_mask, batch_count, softmax_elements_stride,
                                softmax_elements, pad_batch_stride);
    return true;
  }
  return false;
}
// WARP_BATCH           number of batches.
// WARP_ITERATIONS      The number of iterations required for one warp to iterate over all data.
// WARP_SIZE            number of elements working on a single batch, has to be a power of two.
// ELEMENTS_PER_LDG_STG has to be 1.
template <typename input_t, typename output_t, typename acc_t,
          int WARP_BATCH, int WARP_ITERATIONS, int WARP_SIZE = 32, int ELEMENTS_PER_LDG_STG = 1>
__global__ void time_masked_softmax_warp_forward(input_t *dst, const output_t *src, const uint8_t *pad_mask,
                                                 int batch_size, int stride, int element_count, int mod_seq_len)
{
  assert(ELEMENTS_PER_LDG_STG == 1);

  int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;

  // batch_size might not be a multiple of WARP_BATCH. Check how
  // many batches have to be computed within this WARP.
  int local_batches = batch_size - first_batch;
  if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;

  // there might be multiple batches per warp. compute the index within the batch
  int local_idx = threadIdx.x;

  int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
  src += thread_offset;
  dst += thread_offset;

  // load data from global memory
  input_t elements_input[WARP_BATCH][WARP_ITERATIONS];
  for (int i = 0; i < WARP_BATCH; ++i) {
    int batch_element_count = (i >= local_batches) ? 0 : element_count;
    int pad_thread_offset = ((first_batch + i) % mod_seq_len) * stride + ELEMENTS_PER_LDG_STG * local_idx;
    const uint8_t *curr_mask = pad_mask + pad_thread_offset;
    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
      #pragma unroll
      for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
        elements_input[i][it + element] = -std::numeric_limits<float>::infinity();
      }
      if (element_index < batch_element_count) {
        int itr_jmp = it * WARP_SIZE;
        int itr_idx = i * element_count + itr_jmp;
        copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&elements_input[i][it], src + itr_idx);
        apply_mask<input_t, ELEMENTS_PER_LDG_STG>(&elements_input[i][it],
                                                  (__half)-std::numeric_limits<float>::infinity(),
                                                  curr_mask + itr_jmp);
      }
    }
  }

  // convert input_t to acc_t
  acc_t elements[WARP_BATCH][WARP_ITERATIONS];
  for (int i = 0; i < WARP_BATCH; ++i)
    for (int it = 0; it < WARP_ITERATIONS; ++it)
      elements[i][it] = elements_input[i][it];

  constexpr uint32_t FULL_MASK = 0xffffffff;

  // compute local max_value
  // take the max_value of the first element to avoid one max call
  acc_t max_value[WARP_BATCH];
  #pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i)
    max_value[i] = elements[i][0];
  #pragma unroll
  for (int it = 1; it < WARP_ITERATIONS; ++it)
    for (int i = 0; i < WARP_BATCH; ++i)
      max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];

  // reduction max_value
  #pragma unroll
  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
    float val[WARP_BATCH];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i)
      val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i)
      max_value[i] = max_value[i] > val[i] ? max_value[i] : val[i];
  }

  // compute local sum
  acc_t sum[WARP_BATCH] { 0.0f };
  #pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    for (int it = 0; it < WARP_ITERATIONS; ++it) {
      //elements[i][it] = expf(elements[i][it] - max_value[i]);
      elements[i][it] = std::exp(elements[i][it] - max_value[i]);
      sum[i] += elements[i][it];
    }
  }

  // reduction sum
  #pragma unroll
  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i)
      sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
  }

  // store result
  #pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    if (i >= local_batches) break;
    #pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
      if (element_index < element_count) {
        //dst[i * element_count + it * WARP_SIZE] = elements[i][it] / sum[i];
        output_t out[ELEMENTS_PER_LDG_STG];
        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element)
          out[element] = elements[i][it + element] / sum[i];
        copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);
      } else {
        break;
      }
    }
  }
}
// WARP_BATCH           number of batches.
// WARP_ITERATIONS      The number of iterations required for one warp to iterate over all data.
// WARP_SIZE            number of elements working on a single batch, has to be a power of two.
// ELEMENTS_PER_LDG_STG has to be 1.
template <typename input_t, typename output_t>
using time_masked_softmax_forward_func = void (*)(input_t *dst, const output_t *src, const uint8_t *pad_mask,
                                                  int batch_size, int stride, int element_count, int mod_seq_len);

template <typename input_t, typename output_t, typename acc_t>
bool warp_time_masked_softmax_kernel(int log2_elements, int &warp_size, int &batches_per_warp,
                                     time_masked_softmax_forward_func<input_t, output_t> &kernel)
{
  // determine size of a warp
  const int next_power_of_two = 1 << log2_elements;
  warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;

  // determine how many batches a warp should process.
  batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

  switch (log2_elements) {
    case 0:  kernel = &time_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,  1,  1, 1>; break; // 1
    case 1:  kernel = &time_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,  1,  2, 1>; break; // 2
    case 2:  kernel = &time_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,  1,  4, 1>; break; // 4
    case 3:  kernel = &time_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,  1,  8, 1>; break; // 8
    case 4:  kernel = &time_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,  1, 16, 1>; break; // 16
    case 5:  kernel = &time_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,  1, 32, 1>; break; // 32
    case 6:  kernel = &time_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,  2, 32, 1>; break; // 64
    case 7:  kernel = &time_masked_softmax_warp_forward<input_t, output_t, acc_t, 2,  4, 32, 1>; break; // 128
    case 8:  kernel = &time_masked_softmax_warp_forward<input_t, output_t, acc_t, 1,  8, 32, 1>; break; // 256
    case 9:  kernel = &time_masked_softmax_warp_forward<input_t, output_t, acc_t, 1, 16, 32, 1>; break; // 512
    case 10: kernel = &time_masked_softmax_warp_forward<input_t, output_t, acc_t, 1, 32, 32, 1>; break; // 1024
    default: return false;
  }
  return true;
}

template <typename input_t, typename output_t, typename acc_t>
bool dispatch_time_masked_softmax(output_t *dst, const input_t *src, const uint8_t *pad_mask,
                                  int softmax_elements, int softmax_elements_stride,
                                  int batch_count, int mod_seq_len)
{
  if (softmax_elements == 0) {
    return true;
  } else if (softmax_elements <= 1024) {
    // compute function index. there's a function for each power of two size up to 1024.
    int log2_elements = 0;
    while ((1 << log2_elements) < softmax_elements) ++log2_elements;

    time_masked_softmax_forward_func<input_t, output_t> kernel;
    int warp_size, batches_per_warp;
    if (!warp_time_masked_softmax_kernel<input_t, output_t, acc_t>(log2_elements, warp_size, batches_per_warp, kernel)) {
      return false;
    }

    // use 128 threads per block to maximize gpu utilization
    constexpr int threads_per_block = 128;
    // compute warps per block.
    int warps_per_block = (threads_per_block / warp_size);
    // compute launch size
    int batches_per_block = warps_per_block * batches_per_warp;
    int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
    dim3 threads(warp_size, warps_per_block, 1);
    // launch
    kernel<<<blocks, threads>>>(dst, src, pad_mask, batch_count, softmax_elements_stride,
                                softmax_elements, mod_seq_len);
    return true;
  }
  return false;
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Warp softmax backward
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename input_t, typename output_t, typename acc_t,
          int WARP_BATCH, int WARP_ITERATIONS, int WARP_SIZE = 32, int ELEMENTS_PER_LDG_STG = 1>
__global__ void softmax_warp_backward(__half *gradInput, const __half *grad, const __half *output,
                                      int batch_size, int stride, int element_count)
{
  int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;

  // batch_size might not be a multiple of WARP_BATCH. Check how
  // many batches have to be computed within this WARP.
  int local_batches = batch_size - first_batch;
  if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;

  // there might be multiple batches per warp. compute the index within the batch
  int local_idx = threadIdx.x;

  // the first element to process by the current thread
  int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
  grad      += thread_offset;
  output    += thread_offset;
  gradInput += thread_offset;

  // load data from global memory
  input_t grad_reg_input[WARP_BATCH][WARP_ITERATIONS]   = {0.0f};
  input_t output_reg_input[WARP_BATCH][WARP_ITERATIONS] = {0.0f};
  #pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    int batch_element_count = (i >= local_batches) ? 0 : element_count;
    #pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
      if (element_index < batch_element_count) {
        copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&grad_reg_input[i][it],
                                                   grad + i * element_count + it * WARP_SIZE);
        copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&output_reg_input[i][it],
                                                   output + i * element_count + it * WARP_SIZE);
      }
    }
  }

  // convert half to floating point
  acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS];
  acc_t output_reg[WARP_BATCH][WARP_ITERATIONS];
  #pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    for (int it = 0; it < WARP_ITERATIONS; ++it) {
      grad_reg[i][it]   = grad_reg_input[i][it];
      output_reg[i][it] = output_reg_input[i][it];
    }
  }

  // compute thread local sum
  acc_t sum[WARP_BATCH] = {0};
  #pragma unroll
  for (int it = 0; it < WARP_ITERATIONS; ++it)
    for (int i = 0; i < WARP_BATCH; ++i)
      sum[i] += grad_reg[i][it] * output_reg[i][it];

  // reduction sum
  constexpr uint32_t FULL_MASK = 0xffffffff;
  #pragma unroll
  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i)
      sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
  }

  // store result
  #pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    if (i >= local_batches) break;
    #pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
      if (element_index < element_count) {
        // compute gradients
        output_t out[ELEMENTS_PER_LDG_STG];
        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element)
          out[element] = (output_reg[i][it + element] * (grad_reg[i][it + element] - sum[i]));
        // store them in global memory
        copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + i * element_count + it * WARP_SIZE, out);
      }
    }
  }
}
// WARP_BATCH           number of batches.
// WARP_ITERATIONS      The number of iterations required for one warp to iterate over all data.
// WARP_SIZE            number of elements working on a single batch, has to be a power of two.
// ELEMENTS_PER_LDG_STG has to be 1.
template <typename input_t, typename output_t>
using softmax_backward_func = void (*)(output_t *gradInput, const input_t *grad, const input_t *output,
                                       int batch_size, int stride, int element_count);

template <typename input_t, typename output_t, typename acc_t>
bool warp_softmax_backward_kernel(int log2_elements, int &warp_size, int &batches_per_warp,
                                  softmax_backward_func<input_t, output_t> &kernel)
{
  // determine size of a warp
  const int next_power_of_two = 1 << log2_elements;
  warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;

  // determine how many batches a warp should process.
  batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

  switch (log2_elements) {
    case 0:  kernel = &softmax_warp_backward<input_t, output_t, acc_t, 2,  1,  1, 1>; break; // 1
    case 1:  kernel = &softmax_warp_backward<input_t, output_t, acc_t, 2,  1,  2, 1>; break; // 2
    case 2:  kernel = &softmax_warp_backward<input_t, output_t, acc_t, 2,  1,  4, 1>; break; // 4
    case 3:  kernel = &softmax_warp_backward<input_t, output_t, acc_t, 2,  1,  8, 1>; break; // 8
    case 4:  kernel = &softmax_warp_backward<input_t, output_t, acc_t, 2,  1, 16, 1>; break; // 16
    case 5:  kernel = &softmax_warp_backward<input_t, output_t, acc_t, 2,  1, 32, 1>; break; // 32
    case 6:  kernel = &softmax_warp_backward<input_t, output_t, acc_t, 2,  2, 32, 1>; break; // 64
    case 7:  kernel = &softmax_warp_backward<input_t, output_t, acc_t, 2,  4, 32, 1>; break; // 128
    case 8:  kernel = &softmax_warp_backward<input_t, output_t, acc_t, 1,  8, 32, 1>; break; // 256
    case 9:  kernel = &softmax_warp_backward<input_t, output_t, acc_t, 1, 16, 32, 1>; break; // 512
    case 10: kernel = &softmax_warp_backward<input_t, output_t, acc_t, 1, 32, 32, 1>; break; // 1024
    default: return false;
  }
  return true;
}

template <typename input_t, typename output_t, typename acc_t>
bool dispatch_softmax_backward(output_t *grad_input, const input_t *grad, const input_t *output,
                               int softmax_elements, int softmax_elements_stride, int batch_count)
{
  if (softmax_elements == 0) {
    return true;
  } else if (softmax_elements <= 1024) {
    // compute function index. there's a function for each power of two size up to 1024.
    int log2_elements = 0;
    while ((1 << log2_elements) < softmax_elements) ++log2_elements;

    softmax_backward_func<input_t, output_t> kernel;
    int warp_size, batches_per_warp;
    if (!warp_softmax_backward_kernel<input_t, output_t, acc_t>(log2_elements, warp_size, batches_per_warp, kernel)) {
      return false;
    }

    // use 128 threads per block to maximize gpu utilization
    constexpr int threads_per_block = 128;
    // compute warps per block.
    int warps_per_block = (threads_per_block / warp_size);
    // compute launch size
    int batches_per_block = warps_per_block * batches_per_warp;
    int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
    dim3 threads(warp_size, warps_per_block, 1);
    // launch
    kernel<<<blocks, threads>>>(grad_input, grad, output, batch_count,
                                softmax_elements_stride, softmax_elements);
    return true;
  }
  return false;
}
template <typename input_t, typename output_t, typename acc_t,
          int WARP_BATCH, int WARP_ITERATIONS, int WARP_SIZE = 32, int ELEMENTS_PER_LDG_STG = 1>
__global__ void masked_softmax_warp_backward(__half *gradInput, const __half *grad, const __half *output,
                                             const uint8_t *pad_mask, int batch_size, int stride,
                                             int element_count, int pad_batch_stride)
{
  int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;

  // batch_size might not be a multiple of WARP_BATCH. Check how
  // many batches have to be computed within this WARP.
  int local_batches = batch_size - first_batch;
  if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;

  // there might be multiple batches per warp. compute the index within the batch
  int local_idx = threadIdx.x;

  // the first element to process by the current thread
  int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
  grad      += thread_offset;
  output    += thread_offset;
  gradInput += thread_offset;

  // load data from global memory
  input_t grad_reg_input[WARP_BATCH][WARP_ITERATIONS]   = {0.0f};
  input_t output_reg_input[WARP_BATCH][WARP_ITERATIONS] = {0.0f};
  #pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    int batch_element_count = (i >= local_batches) ? 0 : element_count;
    #pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
      if (element_index < batch_element_count) {
        copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&grad_reg_input[i][it],
                                                   grad + i * element_count + it * WARP_SIZE);
        copy_vector<input_t, ELEMENTS_PER_LDG_STG>(&output_reg_input[i][it],
                                                   output + i * element_count + it * WARP_SIZE);
      }
    }
  }

  // convert half to floating point
  acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS];
  acc_t output_reg[WARP_BATCH][WARP_ITERATIONS];
  #pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    for (int it = 0; it < WARP_ITERATIONS; ++it) {
      grad_reg[i][it]   = grad_reg_input[i][it];
      output_reg[i][it] = output_reg_input[i][it];
    }
  }

  // compute thread local sum
  acc_t sum[WARP_BATCH] = {0};
  #pragma unroll
  for (int it = 0; it < WARP_ITERATIONS; ++it)
    for (int i = 0; i < WARP_BATCH; ++i)
      sum[i] += grad_reg[i][it] * output_reg[i][it];

  // reduction sum
  constexpr uint32_t FULL_MASK = 0xffffffff;
  #pragma unroll
  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i)
      sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
  }

  // store result
  #pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    if (i >= local_batches) break;
    int pad_thread_offset = ((first_batch + i) / pad_batch_stride) * stride + ELEMENTS_PER_LDG_STG * local_idx;
    const uint8_t *curr_mask = pad_mask + pad_thread_offset;
    #pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
      if (element_index < element_count) {
        // compute gradients
        output_t out[ELEMENTS_PER_LDG_STG];
        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element)
          out[element] = (output_reg[i][it + element] * (grad_reg[i][it + element] - sum[i]));
        // store them in global memory
        int itr_jmp = it * WARP_SIZE;
        int itr_idx = i * element_count + itr_jmp;
        // It is kind of unfortunate this has to be here to zero something out that is close to
        // zero in the first place
        apply_mask<input_t, ELEMENTS_PER_LDG_STG>(&out[0], 0.0, curr_mask + itr_jmp);
        copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + itr_idx, out);
      }
    }
  }
}
// WARP_BATCH           number of batches.
// WARP_ITERATIONS      The number of iterations required for one warp to iterate over all data.
// WARP_SIZE            number of elements working on a single batch, has to be a power of two.
// ELEMENTS_PER_LDG_STG has to be 1.
template <typename input_t, typename output_t>
using masked_softmax_backward_func = void (*)(output_t *gradInput, const input_t *grad, const input_t *output,
                                              const uint8_t *pad_mask, int batch_size, int stride,
                                              int element_count, int pad_batch_stride);

template <typename input_t, typename output_t, typename acc_t>
bool warp_masked_softmax_backward_kernel(int log2_elements, int &warp_size, int &batches_per_warp,
                                         masked_softmax_backward_func<input_t, output_t> &kernel)
{
  // determine size of a warp
  const int next_power_of_two = 1 << log2_elements;
  warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;

  // determine how many batches a warp should process.
  batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

  switch (log2_elements) {
    case 0:  kernel = &masked_softmax_warp_backward<input_t, output_t, acc_t, 2,  1,  1, 1>; break; // 1
    case 1:  kernel = &masked_softmax_warp_backward<input_t, output_t, acc_t, 2,  1,  2, 1>; break; // 2
    case 2:  kernel = &masked_softmax_warp_backward<input_t, output_t, acc_t, 2,  1,  4, 1>; break; // 4
    case 3:  kernel = &masked_softmax_warp_backward<input_t, output_t, acc_t, 2,  1,  8, 1>; break; // 8
    case 4:  kernel = &masked_softmax_warp_backward<input_t, output_t, acc_t, 2,  1, 16, 1>; break; // 16
    case 5:  kernel = &masked_softmax_warp_backward<input_t, output_t, acc_t, 2,  1, 32, 1>; break; // 32
    case 6:  kernel = &masked_softmax_warp_backward<input_t, output_t, acc_t, 2,  2, 32, 1>; break; // 64
    case 7:  kernel = &masked_softmax_warp_backward<input_t, output_t, acc_t, 2,  4, 32, 1>; break; // 128
    case 8:  kernel = &masked_softmax_warp_backward<input_t, output_t, acc_t, 1,  8, 32, 1>; break; // 256
    case 9:  kernel = &masked_softmax_warp_backward<input_t, output_t, acc_t, 1, 16, 32, 1>; break; // 512
    case 10: kernel = &masked_softmax_warp_backward<input_t, output_t, acc_t, 1, 32, 32, 1>; break; // 1024
    default: return false;
  }
  return true;
}

template <typename input_t, typename output_t, typename acc_t>
bool dispatch_masked_softmax_backward(output_t *grad_input, const input_t *grad, const input_t *output,
                                      const uint8_t *pad_mask, int softmax_elements,
                                      int softmax_elements_stride, int batch_count, int pad_batch_stride)
{
  if (softmax_elements == 0) {
    return true;
  } else if (softmax_elements <= 1024) {
    // compute function index. there's a function for each power of two size up to 1024.
    int log2_elements = 0;
    while ((1 << log2_elements) < softmax_elements) ++log2_elements;

    masked_softmax_backward_func<input_t, output_t> kernel;
    int warp_size, batches_per_warp;
    if (!warp_masked_softmax_backward_kernel<input_t, output_t, acc_t>(log2_elements, warp_size, batches_per_warp, kernel)) {
      return false;
    }

    // use 128 threads per block to maximize gpu utilization
    constexpr int threads_per_block = 128;
    // compute warps per block.
    int warps_per_block = (threads_per_block / warp_size);
    // compute launch size
    int batches_per_block = warps_per_block * batches_per_warp;
    int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
    dim3 threads(warp_size, warps_per_block, 1);
    // launch
    kernel<<<blocks, threads>>>(grad_input, grad, output, pad_mask, batch_count,
                                softmax_elements_stride, softmax_elements, pad_batch_stride);
    return true;
  }
  return false;
}
apex/contrib/csrc/multihead_attn/strided_batched_gemm.h
0 → 100644
#include <vector>
#include <iostream>
//#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <ATen/cuda/CUDAContext.h>
#include "THC/THC.h"

#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/wmma_gemm_traits.h"

// symbol to be automatically resolved by PyTorch libs
extern THCState *state;

cublasOperation_t convertTransToCublasOperation(char trans) {
  if (trans == 't') return CUBLAS_OP_T;
  else if (trans == 'n') return CUBLAS_OP_N;
  else if (trans == 'c') return CUBLAS_OP_C;
  else {
    THError("trans must be one of: t, n, c");
    return CUBLAS_OP_T;
  }
}
void CublasStridedBatchedGemm(THCState *state, char transa, char transb, long m, long n, long k, float alpha,
                              const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
                              float beta, half *c, long ldc, long strideC, long batchCount,
                              cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP) {
  cublasOperation_t opa = convertTransToCublasOperation(transa);
  cublasOperation_t opb = convertTransToCublasOperation(transb);

  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  float fAlpha = alpha;
  float fBeta  = beta;
  //THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
  THCublasCheck(cublasGemmStridedBatchedEx(handle,
                                           opa, opb, (int)m, (int)n, (int)k,
                                           (void*)&fAlpha, a, CUDA_R_16F, (int)lda, strideA,
                                           b, CUDA_R_16F, (int)ldb, strideB,
                                           (void*)&fBeta, c, CUDA_R_16F, (int)ldc, strideC,
                                           (int)batchCount, CUDA_R_32F, algo));
  //THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
}
template <cutlass::MatrixLayout::Kind A_LAYOUT, cutlass::MatrixLayout::Kind B_LAYOUT, int SRC_A, int SRC_B, int DST_C>
void CutlassGemm_FP32Accum(cudaStream_t stream, long m, long n, long k, float alpha,
                           const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
                           float beta, half *c, long ldc, long strideC, long batchCount) {
  //printf("CUTLASS-> %c%c M: %ld N: %ld K: %ld %d%d%d LDA: %ld LDB: %ld LDC: %ld strideA: %ld strideB: %ld strideC: %ld Alpha: %f Beta: %f\n", ((int)A_LAYOUT == 0 ? 'T' : 'N'), ((int)B_LAYOUT ==0 ? 'T' : 'N'), m, n, k, SRC_A,SRC_B,DST_C, lda, ldb, ldc, strideA, strideB, strideC, alpha, beta);
  typedef cutlass::gemm::WmmaGemmTraits<
      A_LAYOUT,
      B_LAYOUT,
      cutlass::Shape<32, 16, 16>,
      half,
      half,
      half,
      cutlass::gemm::LinearScaling<float>,
      float,
      typename cutlass::gemm::WmmaGemmAccumulatorsPerWarp<typename cutlass::Shape<32, 16, 16> >::Shape,
      typename cutlass::Shape<16, 16, 16>,
      SRC_A,      // kScalarsPerLdgA_
      SRC_B,      // kScalarsPerLdgB_
      SRC_A,      // KScalarsPerLdsA_
      SRC_B,      // KScalarsPerLdsB_
      DST_C,      // kScalarsPerLdgCAndStgD_
      DST_C / 2,  // kScalarsPerStsD_
      DST_C / 2   // kScalarsPerLdsD_
      >
      WmmaGemmTraits;

  typedef cutlass::gemm::Gemm<WmmaGemmTraits> Gemm;
  typename Gemm::Params params;

  int result = params.initialize(
      m,        // M dimension for each batch
      n,        // N dimension for each batch
      k,        // K dimension for each batch
      alpha,    // scalar alpha
      a, lda,
      strideA,  // distance in memory between the first element of neighboring batch
      b, ldb,
      strideB,  // distance in memory between the first element of neighboring batch
      beta,     // scalar beta
      c,        // source matrix C
      ldc,
      strideC,  // distance in memory between the first element of neighboring batch
      c,        // destination matrix C (may be different memory than source C matrix)
      ldc,
      strideC,  // distance in memory between the first element of neighboring batch
      batchCount);
  AT_ASSERTM(result == 0, "Failed to initialize CUTLASS Gemm::Params object.");

  // batchCount in cutlass batched GEMM kernels maps to gridDim.z, which is limited to 16 bits.
  // To implement batched GEMM with larger batch size, we fragment it into
  // smaller batched GEMMs of gridDim.z <= 64k
  long batchesLeft = batchCount;
  long iterBatchCount = std::min(batchesLeft, static_cast<long>((1 << 16) - 1));

  do {
    //printf("CUTLASS-> %c%c M: %ld N: %ld K: %ld %d%d%d LDA: %ld LDB: %ld LDC: %ld strideA: %ld strideB: %ld strideC: %ld Alpha: %f Beta: %f TotalBatches: %ld iterBatchCount %ld\n", ((int)A_LAYOUT == 0 ? 'T' : 'N'), ((int)B_LAYOUT ==0 ? 'T' : 'N'), m, n, k, SRC_A,SRC_B,DST_C, lda, ldb, ldc, strideA, strideB, strideC, alpha, beta, batchesLeft, iterBatchCount);
    int result = params.initialize(
        m,        // M dimension for each batch
        n,        // N dimension for each batch
        k,        // K dimension for each batch
        alpha,    // scalar alpha
        a, lda,
        strideA,  // distance in memory between the first element of neighboring batch
        b, ldb,
        strideB,  // distance in memory between the first element of neighboring batch
        beta,     // scalar beta
        c,        // source matrix C
        ldc,
        strideC,  // distance in memory between the first element of neighboring batch
        c,        // destination matrix C (may be different memory than source C matrix)
        ldc,
        strideC,  // distance in memory between the first element of neighboring batch
        iterBatchCount);
    AT_ASSERTM(result == 0, "Failed to initialize CUTLASS Gemm::Params object.");

    // Launch the CUTLASS GEMM kernel.
    THCudaCheck(Gemm::launch(params));

    // Update batched GEMM params based on completed work
    batchesLeft = batchesLeft - iterBatchCount;
    a += iterBatchCount * strideA;
    b += iterBatchCount * strideB;
    c += iterBatchCount * strideC;

    iterBatchCount = std::min(batchesLeft, static_cast<long>((1 << 16) - 1));
  } while (batchesLeft > 0);
}
void gemm_switch_fp32accum(THCState *state, char transa, char transb, long m, long n, long k, float alpha,
                           const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
                           float beta, half *c, long ldc, long strideC, long batchCount) {
  auto stream = c10::cuda::getCurrentCUDAStream();
  //printf("GEMM -> %c%c M: %i N: %i K: %i Alpha: %f Beta: %f\n", (transa == 't' ? 'T' : 'N'), (transb =='t' ? 'T' : 'N'), m, n, k, alpha, beta);
  if ( (transa == 't') && (transb == 'n') ) {
    if      (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { CublasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount, CUBLAS_GEMM_ALGO0_TENSOR_OP); }
    /*if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) {
      int m_rem = m % 64;
      int n_rem = n % 64;
      if ( (m_rem > 48) && ( m <= 192) && (n_rem > 48) && (n <= 192 ) ) {
        CublasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount, CUBLAS_GEMM_ALGO0_TENSOR_OP);
      } else if ( (m_rem > 32) && ( m > 192) && (n_rem > 32) && (n > 192) ) {
        CublasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount, CUBLAS_GEMM_ALGO0_TENSOR_OP);
      } else {
        CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,8,8,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);
      }
    }*/
    else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,8,8,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,8,8,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,8,4,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,8,4,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,8,4,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,8,2,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,8,2,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,8,2,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,4,8,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,4,8,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,4,8,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,4,4,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,4,4,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,4,4,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,4,2,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,4,2,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,4,2,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,2,8,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,2,8,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,2,8,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,2,4,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,2,4,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,2,4,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,2,2,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,2,2,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor,cutlass::MatrixLayout::kColumnMajor,2,2,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else { CublasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
  } else if ( (transa == 'n') && (transb == 'n') ) {
    if      (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { CublasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount, CUBLAS_GEMM_ALGO0_TENSOR_OP); }
    /*if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) {
      int m_rem = m % 64;
      int n_rem = n % 64;
      if ( (m_rem > 48) && ( m <= 192) && (n_rem > 48) && (n <= 192 ) ) {
        CublasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount, CUBLAS_GEMM_ALGO0_TENSOR_OP);
      } else if ( (m_rem > 32) && ( m > 192) && (n_rem > 32) && (n > 192) ) {
        CublasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount, CUBLAS_GEMM_ALGO0_TENSOR_OP);
      } else {
        CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,8,8,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);
      }
    }*/
    else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,8,8,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,8,8,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,8,4,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,8,4,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,8,4,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,8,2,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,8,2,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,8,2,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,4,8,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,4,8,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,4,8,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,4,4,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,4,4,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,4,4,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,4,2,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,4,2,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,4,2,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,2,8,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,2,8,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,2,8,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,2,4,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,2,4,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,2,4,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,2,2,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,2,2,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kColumnMajor,2,2,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else { CublasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
  } else if ( (transa == 'n') && (transb == 't') ) {
    if      (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { CublasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount, CUBLAS_GEMM_ALGO0_TENSOR_OP); }
    /*if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) {
      int m_rem = m % 64;
      int n_rem = n % 64;
      if ( (m_rem > 48) && ( m <= 192) && (n_rem > 48) && (n <= 192 ) ) {
        CublasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount, CUBLAS_GEMM_ALGO0_TENSOR_OP);
      } else if ( (m_rem > 32) && ( m > 192) && (n_rem > 32) && (n > 192) ) {
        CublasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount, CUBLAS_GEMM_ALGO0_TENSOR_OP);
      } else {
        CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,8,8,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);
      }
    }*/
    else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,8,8,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,8,8,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,8,4,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,8,4,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,8,4,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,8,2,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,8,2,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,8,2,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,4,8,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,4,8,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,4,8,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,4,4,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,4,4,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,4,4,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,4,2,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,4,2,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,4,2,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,2,8,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,2,8,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,2,8,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,2,4,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,2,4,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,2,4,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,2,2,8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,2,2,4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor,cutlass::MatrixLayout::kRowMajor,2,2,2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else { CublasStridedBatchedGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
  } else {
    AT_ASSERTM(false, "TransA and TransB are invalid");
  }
}
void adjustLdLevel3(char transa, char transb, int64_t m, int64_t n, int64_t k, int64_t *lda, int64_t *ldb, int64_t *ldc) {
  int transa_ = ((transa == 't') || (transa == 'T'));
  int transb_ = ((transb == 't') || (transb == 'T'));

  // Note: leading dimensions generally are checked that they are > 0 and at least as big as the result
  // requires (even if the value won't be used).
  if (n <= 1)
    *ldc = std::max<int64_t>(m, 1);

  if (transa_) {
    if (m <= 1)
      *lda = std::max<int64_t>(k, 1);
  } else {
    if (k <= 1)
      *lda = std::max<int64_t>(m, 1);
  }

  if (transb_) {
    if (k <= 1)
      *ldb = std::max<int64_t>(n, 1);
  } else {
    if (n <= 1)
      *ldb = std::max<int64_t>(k, 1);
  }
}
void HgemmStridedBatched(THCState *state, char transa, char transb, long m, long n, long k, float alpha,
                         const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
                         float beta, half *c, long ldc, long strideC, long batchCount) {
  if ((m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) ||
      (lda >= INT_MAX) || (ldb >= INT_MAX) || (ldc >= INT_MAX) ||
      (batchCount >= INT_MAX)) {
    THError("Cublas_SgemmStridedBatched only supports m, n, k, lda, ldb, ldc, batchCount"
            " with the bound [val] <= %d", INT_MAX);
  }

  adjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);

  //gemm_switch(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);
  gemm_switch_fp32accum(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);
}
/******
at::Tensor strided_batched_gemm_cuda(
float beta,
at::Tensor in_result,
float alpha,
at::Tensor batch1,
at::Tensor batch2) {
bool transpose_result;
char transpose_batch1, transpose_batch2;
int64_t lda, ldb, ldc;
at::Tensor result, input1, input2;
if (in_result.stride(1) == 1)
{
transpose_result = false;
result = in_result;
ldc = result.stride(2);
}
else if (in_result.stride(2) == 1)
{
transpose_result = true;
at::Tensor swap = batch2;
batch2 = batch1;
batch1 = swap;
result = in_result;
ldc = result.stride(1);
} else {
AT_ASSERTM(false, "result should be contiguous");
}
if (batch1.stride(transpose_result ? 2 : 1) == 1 &&
batch1.stride(transpose_result ? 1 : 2) != 0) {
transpose_batch1 = 'n';
input1 = batch1;
lda = input1.stride(transpose_result ? 1 : 2);
} else if (batch1.stride(transpose_result ? 1 : 2) == 1 &&
batch1.stride(transpose_result ? 2 : 1) != 0) {
transpose_batch1 = 't';
input1 = batch1;
lda = input1.stride(transpose_result ? 2 : 1);
} else {
AT_ASSERTM(false, "input1 should be contiguous");
}
if (batch2.stride(transpose_result ? 2 : 1) == 1 &&
batch2.stride(transpose_result ? 1 : 2) != 0) {
transpose_batch2 = 'n';
input2 = batch2;
ldb = input2.stride(transpose_result ? 1 : 2);
} else if (batch2.stride(transpose_result ? 1 : 2) == 1 &&
batch2.stride(transpose_result ? 2 : 1) != 0) {
transpose_batch2 = 't';
input2 = batch2;
ldb = input2.stride(transpose_result ? 2 : 1);
} else {
AT_ASSERTM(false, "input2 should be contiguous");
}
int64_t num_batches = result.size(0);
HgemmStridedBatched(
state,
transpose_batch1,
transpose_batch2,
result.size(transpose_result ? 2 : 1),
result.size(transpose_result ? 1 : 2),
input1.size(transpose_result ? 1 : 2),
alpha,
static_cast<const half*>(input1.data_ptr()), lda, input1.stride(0),
static_cast<const half*>(input2.data_ptr()), ldb, input2.stride(0),
beta,
static_cast<half*>(result.data_ptr()), ldc, result.stride(0),
num_batches);
return in_result;
}
***/
apex/contrib/examples/multihead_attn/perf_test_multihead_attn.py
0 → 100644
import torch
import torch.nn.functional as F
import argparse

from apex.contrib.multihead_attn import SelfMultiheadAttn
from apex.contrib.multihead_attn import EncdecMultiheadAttn

parser = argparse.ArgumentParser(description='Multihead Attention Standalone Test')
parser.add_argument('--seq-length', default=64, type=int, help='Sequence Length of Input')
parser.add_argument('--num-seqs-start', default=10, type=int, help='Start Range of Number of Sequences')
parser.add_argument('--num-seqs-stop', default=120, type=int, help='Stop Range of Number of Sequences')
parser.add_argument('--num-seqs-inc', default=5, type=int, help='Range Increment of Number of Sequences')
parser.add_argument('--trials', default=20, type=int, help='Number of Trials to Execute')
parser.add_argument('--warmup-trials', default=5, type=int, help='Warmup Trials to discard')
parser.add_argument('--layers', default=18, type=int, help='Attention Layers to Execute to Gain CPU/GPU Time Overlap')
parser.add_argument('--hidden-dim', default=1024, type=int, help='Multihead Attention hidden dimension')
parser.add_argument('--heads', default=16, type=int, help='Number of Multihead Attention heads')
parser.add_argument('--encdec-attn', action='store_true', help='Use Encoder-Decoder Attention instead of Self Attention.')
parser.add_argument('--norm-add', action='store_true', help='Include Layer Norm and Dropout-Add in Multihead Attention block.')
parser.add_argument('--ref', action='store_true', help='Reference implementation in python pytorch.')
parser.add_argument('--native', action='store_true', help='torch.nn.MultiheadAttention Version.')
parser.add_argument('--fwd', action='store_true', help='Only execute Fwd Pass.')
parser.add_argument('--biases', action='store_true', help='Execute multihead attention with Linear Biases.')

args = parser.parse_args()

if not torch.cuda.is_available():
    raise NotImplementedError('Running on CPU is not supported')
torch.cuda.set_device(0)

torch.manual_seed(111)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(111)

attn_layers = []
for idx in range(0, args.layers):
    if args.encdec_attn:
        if args.ref:
            attn_layers.append(EncdecMultiheadAttn(args.hidden_dim, args.heads, dropout=0.1, bias=args.biases, include_norm_add=False, impl='default'))
        else:
            attn_layers.append(EncdecMultiheadAttn(args.hidden_dim, args.heads, dropout=0.1, bias=args.biases, include_norm_add=args.norm_add, impl='fast'))
    else:
        if args.native:
            attn_layers.append(torch.nn.MultiheadAttention(args.hidden_dim, args.heads, dropout=0.1, bias=args.biases))
        elif args.ref:
            attn_layers.append(SelfMultiheadAttn(args.hidden_dim, args.heads, dropout=0.1, bias=args.biases, include_norm_add=args.norm_add, impl='default'))
        else:
            attn_layers.append(SelfMultiheadAttn(args.hidden_dim, args.heads, dropout=0.1, bias=args.biases, include_norm_add=args.norm_add, impl='fast'))
    attn_layers[idx].cuda()
    attn_layers[idx].half()
    if not args.native:
        attn_layers[idx].reset_parameters()

start_evt_fwd = []
start_evt_bwd = []
stop_evt_bwd = []
for recorded_trial in range(0, args.trials):
    start_evt_fwd.append(torch.cuda.Event(enable_timing=True))
    start_evt_bwd.append(torch.cuda.Event(enable_timing=True))
    stop_evt_bwd.append(torch.cuda.Event(enable_timing=True))

for sequences in range(args.num_seqs_start, args.num_seqs_stop + args.num_seqs_inc, args.num_seqs_inc):
    inputs = torch.randn(args.seq_length, sequences, args.hidden_dim, dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True)
    grads = torch.randn_like(inputs)

    for trial in range(0, args.trials + args.warmup_trials):
        layer_inputs = inputs
        evt_idx = trial - args.warmup_trials

        if evt_idx >= 0:
            start_evt_fwd[evt_idx].record()

        for lyr_idx in range(0, args.layers):
            if args.native:
                outputs, _ = attn_layers[lyr_idx].forward(layer_inputs,
                                                          layer_inputs,
                                                          layer_inputs,
                                                          key_padding_mask=None,
                                                          need_weights=False,
                                                          attn_mask=None)
            else:
                outputs, _ = attn_layers[lyr_idx].forward(layer_inputs,
                                                          layer_inputs,
                                                          layer_inputs,
                                                          key_padding_mask=None,
                                                          need_weights=False,
                                                          attn_mask=None,
                                                          is_training=True)
            layer_inputs = outputs

        if evt_idx >= 0:
            start_evt_bwd[evt_idx].record()

        if not args.fwd:
            layer_inputs.backward(grads)

        if evt_idx >= 0:
            stop_evt_bwd[evt_idx].record()

    torch.cuda.synchronize()
    elapsed_time_fwd = 0.0
    elapsed_time_bwd = 0.0
    for evt_idx in range(0, args.trials):
        elapsed_time_fwd += start_evt_fwd[evt_idx].elapsed_time(start_evt_bwd[evt_idx])
        elapsed_time_bwd += start_evt_bwd[evt_idx].elapsed_time(stop_evt_bwd[evt_idx])

    print("[ {} Attn {} ]Total Tokens: {:4d} Sequences: {:3d} Sequence Length: {:3d} Fwd Time / Layer: {:.3f} ms Bwd Time / Layer: {:.3f} ms".format(
        'Encdec' if args.encdec_attn else 'Self',
        'Norm&Add' if args.norm_add else '',
        sequences * args.seq_length,
        sequences,
        args.seq_length,
        elapsed_time_fwd / (args.trials * args.layers),
        elapsed_time_bwd / (args.trials * args.layers)))
apex/contrib/multihead_attn/MHA_bwd.png
0 → 100644
84.6 KB
apex/contrib/multihead_attn/MHA_fwd.png
0 → 100644
82.4 KB
apex/contrib/multihead_attn/README.md
0 → 100644
# Fast Multihead Attention
This implementation has two main features:
* A C++ implementation to avoid the CPU overheads of Pytorch found with smaller batch sizes.
* The removal of all copies and transposes found in standard implementations of Multihead Attention.
| | Python Version | C++ Version |
| :----------------------------------------- | :------------: | :---------: |
| Layer Norm and Residual Add Variant | X | X |
| Includes Linear Biases | X | |
| Reduces CPU Overheads | | X |
| Fuses masking with Softmax | | X |
| Removes Transposes and Copies | X | X |
| Includes Self and Encoder/Decoder Variants | X | X |
## How to Instantiate

`SelfMultiheadAttn(` _hidden dim_, _heads_, _dropout=prob_, _bias=bool_, _include_norm_add=bool_, _impl='fast'_ `)`

`EncdecMultiheadAttn(` _hidden dim_, _heads_, _dropout=prob_, _bias=bool_, _include_norm_add=bool_, _impl='fast'_ `)`

`impl` has two options:
* `fast` uses the C++ Version
* `default` uses the Python Version
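
A minimal usage sketch follows. It mirrors the calls made by the perf test script earlier in this commit; the sizes below and `bias=False` are arbitrary example values, and the `fast` path assumes a CUDA device with the extension built as described in the next section.
```
import torch
from apex.contrib.multihead_attn import SelfMultiheadAttn

# Example sizes only; the hidden dimension must be divisible by the head count.
seq_len, batch, hidden_dim, heads = 64, 10, 1024, 16

attn = SelfMultiheadAttn(hidden_dim, heads, dropout=0.1, bias=False,
                         include_norm_add=False, impl='fast')  # impl='default' selects the Python version
attn.cuda().half()
attn.reset_parameters()

# Inputs are laid out (seq_len, batch, hidden_dim), as in the perf test script.
inputs = torch.randn(seq_len, batch, hidden_dim, dtype=torch.float16,
                     device='cuda', requires_grad=True)

outputs, _ = attn.forward(inputs, inputs, inputs,
                          key_padding_mask=None, need_weights=False,
                          attn_mask=None, is_training=True)
outputs.backward(torch.randn_like(outputs))
```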
## Instructions to build on Linux
```
$ git clone https://github.com/NVIDIA/apex
$ cd apex
$ pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--fast_multihead_attn" ./
```
## Try Performance Tests Yourself!
The perf test script is found here:
```
cd contrib/examples/multihead_attn
```
#### Fast Multihead Attention (Python Reference Implementation)
```
python perf_test_multihead_attn.py --ref
```
#### Fast Multihead Attention with C++ Implementation
```
python perf_test_multihead_attn.py
```
#### Compare with `torch.nn.MultiheadAttention`
```
python perf_test_multihead_attn.py --native
```
#### Test your own range!
```
python perf_test_multihead_attn.py --seq-length 64 --num-seqs-start 10 --num-seqs-stop 120 --num-seqs-inc 5
```
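The script also exposes switches for the encoder-decoder variant, the fused norm/add variant, forward-only timing, and linear biases (see the argparse options in `perf_test_multihead_attn.py` above), for example:
```
python perf_test_multihead_attn.py --encdec-attn --norm-add
python perf_test_multihead_attn.py --fwd --biases
```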
## Performance Comparisons
* Performance was measured with 64 token sequence lengths on an NVIDIA TitanV card.
* Time is measured across multiple layers to simulate an in model scenario.

