Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
5c3843dc
Unverified
Commit
5c3843dc
authored
Dec 21, 2021
by
shenggan
Committed by
GitHub
Dec 21, 2021
Browse files
add colossalai kernel module (#55)
parent
648f8063
Changes
43
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
5401 additions
and
0 deletions
+5401
-0
colossalai/kernel/cuda_native/csrc/kernels/include/strided_batch_gemm.h
...nel/cuda_native/csrc/kernels/include/strided_batch_gemm.h
+99
-0
colossalai/kernel/cuda_native/csrc/kernels/normalize_kernels.cu
...alai/kernel/cuda_native/csrc/kernels/normalize_kernels.cu
+1160
-0
colossalai/kernel/cuda_native/csrc/kernels/softmax_kernels.cu
...ssalai/kernel/cuda_native/csrc/kernels/softmax_kernels.cu
+366
-0
colossalai/kernel/cuda_native/csrc/kernels/transform_kernels.cu
...alai/kernel/cuda_native/csrc/kernels/transform_kernels.cu
+314
-0
colossalai/kernel/cuda_native/csrc/layer_norm_cuda.cpp
colossalai/kernel/cuda_native/csrc/layer_norm_cuda.cpp
+185
-0
colossalai/kernel/cuda_native/csrc/layer_norm_cuda_kernel.cu
colossalai/kernel/cuda_native/csrc/layer_norm_cuda_kernel.cu
+813
-0
colossalai/kernel/cuda_native/csrc/multihead_attention_1d.cpp
...ssalai/kernel/cuda_native/csrc/multihead_attention_1d.cpp
+364
-0
colossalai/kernel/cuda_native/csrc/multihead_attention_1d.h
colossalai/kernel/cuda_native/csrc/multihead_attention_1d.h
+153
-0
colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.cpp
colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.cpp
+84
-0
colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.h
colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.h
+492
-0
colossalai/kernel/cuda_native/csrc/scaled_masked_softmax_cuda.cu
...lai/kernel/cuda_native/csrc/scaled_masked_softmax_cuda.cu
+104
-0
colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax.cpp
...l/cuda_native/csrc/scaled_upper_triang_masked_softmax.cpp
+59
-0
colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax.h
...nel/cuda_native/csrc/scaled_upper_triang_masked_softmax.h
+500
-0
colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax_cuda.cu
...da_native/csrc/scaled_upper_triang_masked_softmax_cuda.cu
+85
-0
colossalai/kernel/cuda_native/csrc/type_shim.h
colossalai/kernel/cuda_native/csrc/type_shim.h
+73
-0
colossalai/kernel/cuda_native/layer_norm.py
colossalai/kernel/cuda_native/layer_norm.py
+69
-0
colossalai/kernel/cuda_native/multihead_attention.py
colossalai/kernel/cuda_native/multihead_attention.py
+270
-0
colossalai/kernel/cuda_native/scaled_softmax.py
colossalai/kernel/cuda_native/scaled_softmax.py
+184
-0
colossalai/kernel/jit/__init__.py
colossalai/kernel/jit/__init__.py
+3
-0
colossalai/kernel/jit/bias_dropout_add.py
colossalai/kernel/jit/bias_dropout_add.py
+24
-0
No files found.
colossalai/kernel/cuda_native/csrc/kernels/include/strided_batch_gemm.h
0 → 100644
View file @
5c3843dc
/* Copyright 2021 The LightSeq Team
Copyright Microsoft DeepSpeed
This file is adapted from Microsoft DeepSpeed
*/
#pragma once
#include <cuda.h>
#include <cuda_fp16.h>
#include <stdio.h>
#include <array>
#include "cublas_wrappers.h"
template
<
typename
T
>
class
StridedBatchGemm
{
public:
struct
Config
{
int
m
;
int
n
;
int
k
;
float
alpha
;
float
beta
;
cublasOperation_t
op_A
;
cublasOperation_t
op_B
;
std
::
array
<
int
,
3
>
gemm_algos
;
Config
(
float
param_alpha
,
float
param_beta
,
cublasOperation_t
opA
,
cublasOperation_t
opB
)
:
alpha
(
param_alpha
),
beta
(
param_beta
),
op_A
(
opA
),
op_B
(
opB
),
gemm_algos
(
std
::
array
<
int
,
3
>
({
99
,
99
,
99
}))
{}
void
SetConfig
(
int
mm
,
int
nn
,
int
kk
)
{
m
=
mm
;
n
=
nn
;
k
=
kk
;
}
};
StridedBatchGemm
(
const
Config
&
config
)
:
_config
(
config
)
{}
virtual
~
StridedBatchGemm
()
{}
void
Forward
(
int
bsz
,
T
*
output
,
const
T
*
_buffer_a
,
const
T
*
_buffer_b
,
cublasHandle_t
handle
)
{
int
stride_a
=
_config
.
m
*
_config
.
k
;
int
stride_b
=
_config
.
n
*
_config
.
k
;
int
stride_c
=
_config
.
m
*
_config
.
n
;
cublas_strided_batched_gemm
(
handle
,
_config
.
m
,
_config
.
n
,
_config
.
k
,
&
_config
.
alpha
,
&
_config
.
beta
,
_buffer_a
,
_buffer_b
,
output
,
_config
.
op_A
,
_config
.
op_B
,
stride_a
,
stride_b
,
stride_c
,
bsz
,
cublasGemmAlgo_t
(
_config
.
gemm_algos
[
0
]));
}
void
Backward
(
int
bsz
,
const
T
*
d_output
,
const
T
*
_buffer_a
,
const
T
*
_buffer_b
,
cublasHandle_t
handle
,
T
*
inpGradA
=
nullptr
,
T
*
inpGradB
=
nullptr
)
{
int
mb
=
(
_config
.
op_A
==
CUBLAS_OP_T
?
_config
.
k
:
_config
.
m
);
int
kb
=
(
_config
.
op_A
==
CUBLAS_OP_T
?
_config
.
m
:
_config
.
k
);
int
stride_a
=
mb
*
_config
.
n
;
int
stride_b
=
_config
.
n
*
kb
;
int
stride_c
=
_config
.
m
*
_config
.
k
;
// B need to transpose.
cublasOperation_t
op_b
=
(
_config
.
op_B
==
CUBLAS_OP_T
?
CUBLAS_OP_N
:
CUBLAS_OP_T
);
// Calculate d_A.
cublas_strided_batched_gemm
(
handle
,
mb
,
kb
,
_config
.
n
,
&
_config
.
alpha
,
&
_config
.
beta
,
(
_config
.
op_A
==
CUBLAS_OP_T
?
_buffer_b
:
d_output
),
(
_config
.
op_A
==
CUBLAS_OP_T
?
d_output
:
_buffer_b
),
inpGradA
,
CUBLAS_OP_N
,
op_b
,
stride_a
,
stride_b
,
stride_c
,
bsz
,
cublasGemmAlgo_t
(
_config
.
gemm_algos
[
1
]));
// A need to transpose.
cublasOperation_t
op_a
=
(
_config
.
op_A
==
CUBLAS_OP_T
?
CUBLAS_OP_N
:
CUBLAS_OP_T
);
stride_a
=
_config
.
m
*
_config
.
k
;
stride_b
=
_config
.
m
*
_config
.
n
;
stride_c
=
_config
.
n
*
_config
.
k
;
// Calculate d_B.
cublas_strided_batched_gemm
(
handle
,
_config
.
k
,
_config
.
n
,
_config
.
m
,
&
_config
.
alpha
,
&
_config
.
beta
,
_buffer_a
,
d_output
,
inpGradB
,
op_a
,
CUBLAS_OP_N
,
stride_a
,
stride_b
,
stride_c
,
bsz
,
cublasGemmAlgo_t
(
_config
.
gemm_algos
[
2
]));
}
inline
void
SetConfig
(
int
m
,
int
n
,
int
k
)
{
_config
.
SetConfig
(
m
,
n
,
k
);
}
private:
Config
_config
;
};
colossalai/kernel/cuda_native/csrc/kernels/normalize_kernels.cu
0 → 100644
View file @
5c3843dc
#include "block_reduce.h"
#include "kernels.h"
#include <cooperative_groups.h>
namespace
cg
=
cooperative_groups
;
const
float
LN_EPSILON
=
1e-8
f
;
#define TILE_DIM 32
template
<
typename
T
>
__forceinline__
__device__
T
add_eps
(
T
x
)
{
return
fabsf
(
x
)
>
LN_EPSILON
?
x
:
(
x
<
0
?
-
LN_EPSILON
:
LN_EPSILON
);
}
/**
@brief: ker_layer_norm
Standard layer normalization.
It will not only output the layer norm result,
but also outputs variance.
may also output means, depends on whether
the means argument is nullptr
@thread
gridDim.x = batch_size * seq_len
blockDim.x = hidden_size
@param
ln_res: [batch_size* seq_len, hidden_size], ln result.
vars: [batch_size* seq_len], variance per token
means: [batch_size* seq_len], means per token, can be nullput
inp: [batch_size * seq_len, hidden_size], ln input.
scale: [hidden_size], ln scale
bias: [hidden_size], ln bias
*/
template
<
typename
T
>
__global__
void
ker_layer_norm
(
T
*
ln_res
,
T
*
vars
,
T
*
means
,
const
T
*
inp
,
const
T
*
scale
,
const
T
*
bias
,
int
hidden_size
)
{
// step 0. compute local sum
float
l_sum
=
0
;
float
l_square_sum
=
0
;
const
float4
*
inp_f4
=
(
const
float4
*
)
inp
+
blockIdx
.
x
*
hidden_size
;
for
(
uint
idx
=
threadIdx
.
x
;
idx
<
hidden_size
;
idx
+=
blockDim
.
x
)
{
float4
val
=
inp_f4
[
idx
];
l_sum
+=
val
.
x
+
val
.
y
+
val
.
z
+
val
.
w
;
l_square_sum
+=
val
.
x
*
val
.
x
+
val
.
y
*
val
.
y
+
val
.
z
*
val
.
z
+
val
.
w
*
val
.
w
;
}
// step 1. compute reduce sum
float
mean_dim
=
float
(
hidden_size
)
*
4.
f
;
float
reduce_val
[
2
]
=
{
l_sum
,
l_square_sum
};
blockReduce
<
ReduceType
::
kSum
,
2
>
(
reduce_val
);
__shared__
float
s_mean
,
s_var
;
if
(
threadIdx
.
x
==
0
)
{
s_mean
=
reduce_val
[
0
]
/
mean_dim
;
if
(
means
!=
nullptr
)
{
means
[
blockIdx
.
x
]
=
s_mean
;
}
s_var
=
reduce_val
[
1
]
/
mean_dim
-
s_mean
*
s_mean
+
LN_EPSILON
;
vars
[
blockIdx
.
x
]
=
s_var
;
s_var
=
rsqrtf
(
s_var
);
}
__syncthreads
();
// step 2. layer norm result
float4
*
output_f4
=
(
float4
*
)
ln_res
+
blockIdx
.
x
*
hidden_size
;
for
(
uint
idx
=
threadIdx
.
x
;
idx
<
hidden_size
;
idx
+=
blockDim
.
x
)
{
float4
vscale
=
__ldg
((
const
float4
*
)
scale
+
idx
);
float4
vbias
=
__ldg
((
const
float4
*
)
bias
+
idx
);
float4
val
=
inp_f4
[
idx
];
val
.
x
=
(
val
.
x
-
s_mean
)
*
s_var
*
vscale
.
x
+
vbias
.
x
;
val
.
y
=
(
val
.
y
-
s_mean
)
*
s_var
*
vscale
.
y
+
vbias
.
y
;
val
.
z
=
(
val
.
z
-
s_mean
)
*
s_var
*
vscale
.
z
+
vbias
.
z
;
val
.
w
=
(
val
.
w
-
s_mean
)
*
s_var
*
vscale
.
w
+
vbias
.
w
;
output_f4
[
idx
]
=
val
;
}
}
template
<
>
__global__
void
ker_layer_norm
<
__half
>
(
__half
*
ln_res
,
__half
*
vars
,
__half
*
means
,
const
__half
*
inp
,
const
__half
*
scale
,
const
__half
*
bias
,
int
hidden_size
)
{
// step 0. compute local sum
float
l_sum
=
0
;
float
l_square_sum
=
0
;
const
float4
*
inp_f4
=
(
const
float4
*
)
inp
+
blockIdx
.
x
*
hidden_size
;
for
(
uint
idx
=
threadIdx
.
x
;
idx
<
hidden_size
;
idx
+=
blockDim
.
x
)
{
float4
val_f4
=
inp_f4
[
idx
];
__half2
*
val_h2
=
(
__half2
*
)(
&
val_f4
);
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
float2
val_f2
=
__half22float2
(
val_h2
[
i
]);
l_sum
+=
val_f2
.
x
+
val_f2
.
y
;
l_square_sum
+=
val_f2
.
x
*
val_f2
.
x
+
val_f2
.
y
*
val_f2
.
y
;
}
}
// step 1. compute reduce sum
float
mean_dim
=
float
(
hidden_size
)
*
8.
f
;
float
reduce_val
[
2
]
=
{
l_sum
,
l_square_sum
};
blockReduce
<
ReduceType
::
kSum
,
2
>
(
reduce_val
);
__shared__
float
s_mean
,
s_var
;
if
(
threadIdx
.
x
==
0
)
{
s_mean
=
reduce_val
[
0
]
/
mean_dim
;
if
(
means
!=
nullptr
)
{
means
[
blockIdx
.
x
]
=
s_mean
;
}
s_var
=
reduce_val
[
1
]
/
mean_dim
-
s_mean
*
s_mean
+
LN_EPSILON
;
vars
[
blockIdx
.
x
]
=
s_var
;
s_var
=
rsqrtf
(
s_var
);
}
__syncthreads
();
// step 2. layer norm result
float4
*
output_f4
=
(
float4
*
)
ln_res
+
blockIdx
.
x
*
hidden_size
;
for
(
uint
idx
=
threadIdx
.
x
;
idx
<
hidden_size
;
idx
+=
blockDim
.
x
)
{
// load scale, bias, input
float4
scale_f4
=
__ldg
((
const
float4
*
)
scale
+
idx
);
__half2
*
scale_h2
=
(
__half2
*
)(
&
scale_f4
);
float4
bias_f4
=
__ldg
((
const
float4
*
)
bias
+
idx
);
__half2
*
bias_h2
=
(
__half2
*
)(
&
bias_f4
);
float4
val_f4
=
inp_f4
[
idx
];
__half2
*
val_h2
=
(
__half2
*
)(
&
val_f4
);
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
float2
scale_f2
=
__half22float2
(
scale_h2
[
i
]);
float2
bias_f2
=
__half22float2
(
bias_h2
[
i
]);
float2
val_f2
=
__half22float2
(
val_h2
[
i
]);
val_f2
.
x
=
(
val_f2
.
x
-
s_mean
)
*
s_var
*
scale_f2
.
x
+
bias_f2
.
x
;
val_f2
.
y
=
(
val_f2
.
y
-
s_mean
)
*
s_var
*
scale_f2
.
y
+
bias_f2
.
y
;
val_h2
[
i
]
=
__float22half2_rn
(
val_f2
);
}
output_f4
[
idx
]
=
val_f4
;
}
}
// __global__ void ker_layer_norm_x2(__half *ln_res, __half *vars,
// __half *means, const __half *inp,
// const __half *scale, const __half *bias,
// int hidden_size) {
// // step 0. compute local sum
// float l_sum = 0;
// float l_square_sum = 0;
// const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * 2 * hidden_size;
// for (uint idx = 2 * threadIdx.x; idx < hidden_size * 2; idx += blockDim.x * 2) {
// float4 val_f4 = inp_f4[idx];
// float4 val_f4_1 = inp_f4[idx+1];
// __half2 *val_h2 = (__half2 *)(&val_f4);
// __half2 *val_h2_1 = (__half2 *)(&val_f4_1);
// #pragma unroll
// for (int i = 0; i < 4; i++) {
// float2 val_f2 = __half22float2(val_h2[i]);
// float2 val_f2_1 = __half22float2(val_h2_1[i]);
// l_sum += val_f2.x + val_f2.y + val_f2_1.x + val_f2_1.y;
// l_square_sum += val_f2.x * val_f2.x + val_f2.y * val_f2.y + val_f2_1.x * val_f2_1.x + val_f2_1.y * val_f2_1.y;
// }
// }
// // step 1. compute reduce sum
// float mean_dim = float(hidden_size) * 8.f * 2;
// float reduce_val[2] = {l_sum, l_square_sum};
// blockReduce<ReduceType::kSum, 2>(reduce_val);
// __shared__ float s_mean, s_var;
// if (threadIdx.x == 0) {
// s_mean = reduce_val[0] / mean_dim;
// if (means != nullptr) {
// means[blockIdx.x] = s_mean;
// }
// s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON;
// vars[blockIdx.x] = s_var;
// s_var = rsqrtf(s_var);
// }
// __syncthreads();
// // step 2. layer norm result
// float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size * 2;
// for (uint idx = 2 * threadIdx.x; idx < hidden_size * 2; idx += blockDim.x * 2) {
// // load scale, bias, input
// float4 scale_f4 = __ldg((const float4 *)scale + idx);
// __half2 *scale_h2 = (__half2 *)(&scale_f4);
// float4 scale_f4_1 = __ldg((const float4 *)scale + idx + 1);
// __half2 *scale_h2_1 = (__half2 *)(&scale_f4_1);
// float4 bias_f4 = __ldg((const float4 *)bias + idx);
// __half2 *bias_h2 = (__half2 *)(&bias_f4);
// float4 bias_f4_1 = __ldg((const float4 *)bias + idx + 1);
// __half2 *bias_h2_1 = (__half2 *)(&bias_f4_1);
// float4 val_f4 = inp_f4[idx];
// __half2 *val_h2 = (__half2 *)(&val_f4);
// float4 val_f4_1 = inp_f4[idx+1];
// __half2 *val_h2_1 = (__half2 *)(&val_f4_1);
// #pragma unroll
// for (int i = 0; i < 4; i++) {
// float2 scale_f2 = __half22float2(scale_h2[i]);
// float2 scale_f2_1 = __half22float2(scale_h2_1[i]);
// float2 bias_f2 = __half22float2(bias_h2[i]);
// float2 bias_f2_1 = __half22float2(bias_h2_1[i]);
// float2 val_f2 = __half22float2(val_h2[i]);
// float2 val_f2_1 = __half22float2(val_h2_1[i]);
// val_f2.x = (val_f2.x - s_mean) * s_var * scale_f2.x + bias_f2.x;
// val_f2.y = (val_f2.y - s_mean) * s_var * scale_f2.y + bias_f2.y;
// val_h2[i] = __float22half2_rn(val_f2);
// val_f2_1.x = (val_f2_1.x - s_mean) * s_var * scale_f2_1.x + bias_f2_1.x;
// val_f2_1.y = (val_f2_1.y - s_mean) * s_var * scale_f2_1.y + bias_f2_1.y;
// val_h2_1[i] = __float22half2_rn(val_f2_1);
// }
// output_f4[idx] = val_f4;
// output_f4[idx+1] = val_f4_1;
// }
// }
// __global__ void ker_layer_norm_x4(__half *ln_res, __half *vars,
// __half *means, const __half *inp,
// const __half *scale, const __half *bias,
// int hidden_size) {
// // step 0. compute local sum
// float l_sum = 0;
// float l_square_sum = 0;
// const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * hidden_size * 4;
// for (uint idx = 4 * threadIdx.x; idx < hidden_size * 4; idx += blockDim.x * 4) {
// float4 val_f4 = inp_f4[idx];
// float4 val_f4_1 = inp_f4[idx+1];
// float4 val_f4_2 = inp_f4[idx+2];
// float4 val_f4_3 = inp_f4[idx+3];
// __half2 *val_h2 = (__half2 *)(&val_f4);
// __half2 *val_h2_1 = (__half2 *)(&val_f4_1);
// __half2 *val_h2_2 = (__half2 *)(&val_f4_2);
// __half2 *val_h2_3 = (__half2 *)(&val_f4_3);
// #pragma unroll
// for (int i = 0; i < 4; i++) {
// float2 val_f2 = __half22float2(val_h2[i]);
// float2 val_f2_1 = __half22float2(val_h2_1[i]);
// float2 val_f2_2 = __half22float2(val_h2_2[i]);
// float2 val_f2_3 = __half22float2(val_h2_3[i]);
// l_sum += val_f2.x + val_f2.y + val_f2_1.x + val_f2_1.y + val_f2_2.x + val_f2_2.y + val_f2_3.x + val_f2_3.y;
// l_square_sum += val_f2.x * val_f2.x + val_f2.y * val_f2.y;
// l_square_sum += val_f2_1.x * val_f2_1.x + val_f2_1.y * val_f2_1.y;
// l_square_sum += val_f2_2.x * val_f2_2.x + val_f2_2.y * val_f2_2.y;
// l_square_sum += val_f2_3.x * val_f2_3.x + val_f2_3.y * val_f2_3.y;
// }
// }
// // step 1. compute reduce sum
// float mean_dim = float(hidden_size) * 8.f * 4;
// float reduce_val[2] = {l_sum, l_square_sum};
// blockReduce<ReduceType::kSum, 2>(reduce_val);
// __shared__ float s_mean, s_var;
// if (threadIdx.x == 0) {
// s_mean = reduce_val[0] / mean_dim;
// if (means != nullptr) {
// means[blockIdx.x] = s_mean;
// }
// s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON;
// vars[blockIdx.x] = s_var;
// s_var = rsqrtf(s_var);
// }
// __syncthreads();
// // step 2. layer norm result
// float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size * 4;
// for (uint idx = 4 * threadIdx.x; idx < hidden_size * 4; idx += blockDim.x * 4) {
// // load scale, bias, input
// float4 scale_f4 = __ldg((const float4 *)scale + idx);
// __half2 *scale_h2 = (__half2 *)(&scale_f4);
// float4 scale_f4_1 = __ldg((const float4 *)scale + idx + 1);
// __half2 *scale_h2_1 = (__half2 *)(&scale_f4_1);
// float4 scale_f4_2 = __ldg((const float4 *)scale + idx + 2);
// __half2 *scale_h2_2 = (__half2 *)(&scale_f4_2);
// float4 scale_f4_3 = __ldg((const float4 *)scale + idx + 3);
// __half2 *scale_h2_3 = (__half2 *)(&scale_f4_3);
// float4 bias_f4 = __ldg((const float4 *)bias + idx);
// __half2 *bias_h2 = (__half2 *)(&bias_f4);
// float4 bias_f4_1 = __ldg((const float4 *)bias + idx + 1);
// __half2 *bias_h2_1 = (__half2 *)(&bias_f4_1);
// float4 bias_f4_2 = __ldg((const float4 *)bias + idx + 2);
// __half2 *bias_h2_2 = (__half2 *)(&bias_f4_2);
// float4 bias_f4_3 = __ldg((const float4 *)bias + idx + 3);
// __half2 *bias_h2_3 = (__half2 *)(&bias_f4_3);
// float4 val_f4 = inp_f4[idx];
// __half2 *val_h2 = (__half2 *)(&val_f4);
// float4 val_f4_1 = inp_f4[idx+1];
// __half2 *val_h2_1 = (__half2 *)(&val_f4_1);
// float4 val_f4_2 = inp_f4[idx+2];
// __half2 *val_h2_2 = (__half2 *)(&val_f4_2);
// float4 val_f4_3 = inp_f4[idx+3];
// __half2 *val_h2_3 = (__half2 *)(&val_f4_3);
// #pragma unroll
// for (int i = 0; i < 4; i++) {
// float2 scale_f2 = __half22float2(scale_h2[i]);
// float2 scale_f2_1 = __half22float2(scale_h2_1[i]);
// float2 scale_f2_2 = __half22float2(scale_h2_2[i]);
// float2 scale_f2_3 = __half22float2(scale_h2_3[i]);
// float2 bias_f2 = __half22float2(bias_h2[i]);
// float2 bias_f2_1 = __half22float2(bias_h2_1[i]);
// float2 bias_f2_2 = __half22float2(bias_h2_2[i]);
// float2 bias_f2_3 = __half22float2(bias_h2_3[i]);
// float2 val_f2 = __half22float2(val_h2[i]);
// float2 val_f2_1 = __half22float2(val_h2_1[i]);
// float2 val_f2_2 = __half22float2(val_h2_2[i]);
// float2 val_f2_3 = __half22float2(val_h2_3[i]);
// val_f2.x = (val_f2.x - s_mean) * s_var * scale_f2.x + bias_f2.x;
// val_f2.y = (val_f2.y - s_mean) * s_var * scale_f2.y + bias_f2.y;
// val_f2_1.x = (val_f2_1.x - s_mean) * s_var * scale_f2_1.x + bias_f2_1.x;
// val_f2_1.y = (val_f2_1.y - s_mean) * s_var * scale_f2_1.y + bias_f2_1.y;
// val_f2_2.x = (val_f2_2.x - s_mean) * s_var * scale_f2_2.x + bias_f2_2.x;
// val_f2_2.y = (val_f2_2.y - s_mean) * s_var * scale_f2_2.y + bias_f2_2.y;
// val_f2_3.x = (val_f2_3.x - s_mean) * s_var * scale_f2_3.x + bias_f2_3.x;
// val_f2_3.y = (val_f2_3.y - s_mean) * s_var * scale_f2_3.y + bias_f2_3.y;
// val_h2[i] = __float22half2_rn(val_f2);
// val_h2_1[i] = __float22half2_rn(val_f2_1);
// val_h2_2[i] = __float22half2_rn(val_f2_2);
// val_h2_3[i] = __float22half2_rn(val_f2_3);
// }
// output_f4[idx] = val_f4;
// output_f4[idx+1] = val_f4_1;
// output_f4[idx+2] = val_f4_2;
// output_f4[idx+3] = val_f4_3;
// }
// }
template
<
>
void
launch_layer_norm
<
float
>
(
float
*
ln_res
,
float
*
vars
,
float
*
means
,
const
float
*
inp
,
const
float
*
scale
,
const
float
*
bias
,
int
batch_size
,
int
hidden_dim
,
cudaStream_t
stream
)
{
if
(
hidden_dim
%
4
!=
0
)
{
throw
std
::
runtime_error
(
"violate hidden_dim % 4 = 0"
);
}
hidden_dim
>>=
2
;
int
nthread
=
min
(((
hidden_dim
+
31
)
/
32
)
*
32
,
MAX_THREADS
);
dim3
grid_dim
(
batch_size
);
dim3
block_dim
(
nthread
);
ker_layer_norm
<
float
><<<
grid_dim
,
block_dim
,
0
,
stream
>>>
(
ln_res
,
vars
,
means
,
inp
,
scale
,
bias
,
hidden_dim
);
}
template
<
>
void
launch_layer_norm
<
__half
>
(
__half
*
ln_res
,
__half
*
vars
,
__half
*
means
,
const
__half
*
inp
,
const
__half
*
scale
,
const
__half
*
bias
,
int
batch_size
,
int
hidden_dim
,
cudaStream_t
stream
)
{
if
(
hidden_dim
%
8
!=
0
)
{
throw
std
::
runtime_error
(
"violate hidden_dim % 8 = 0"
);
}
hidden_dim
>>=
3
;
int
nthread
=
min
(((
hidden_dim
+
31
)
/
32
)
*
32
,
MAX_THREADS
);
dim3
grid_dim
(
batch_size
);
dim3
block_dim
(
nthread
);
ker_layer_norm
<
__half
><<<
grid_dim
,
block_dim
,
0
,
stream
>>>
(
ln_res
,
vars
,
means
,
inp
,
scale
,
bias
,
hidden_dim
);
// if (hidden_dim % 8 != 0) {
// throw std::runtime_error("violate hidden_dim % 8 = 0");
// }
// hidden_dim >>= 3;
// if (hidden_dim * 8 < 8192) {
// int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS);
// dim3 grid_dim(batch_size);
// dim3 block_dim(nthread);
// ker_layer_norm<__half><<<grid_dim, block_dim, 0, stream>>>(
// ln_res, vars, means, inp, scale, bias, hidden_dim);
// } else if (hidden_dim * 8 >= 8192 && hidden_dim * 8 <= 8192 * 2) {
// hidden_dim >>= 1;
// int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS);
// dim3 grid_dim(batch_size);
// dim3 block_dim(nthread);
// ker_layer_norm_x2<<<grid_dim, block_dim, 0, stream>>>(
// ln_res, vars, means, inp, scale, bias, hidden_dim);
// } else if (hidden_dim * 8 > 8192 * 2 && hidden_dim * 8 <= 8192 * 4) {
// hidden_dim >>= 2;
// int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS);
// dim3 grid_dim(batch_size);
// dim3 block_dim(nthread);
// ker_layer_norm_x4<<<grid_dim, block_dim, 0, stream>>>(
// ln_res, vars, means, inp, scale, bias, hidden_dim);
// } else {
// throw std::runtime_error("hidden_dim % 4 != 0 || hidden_dim > 32768");
// }
}
/**
@brief: ker_ln_bw_dgamma_dbetta
Layer norm backword kernel, compute the gradient of gamma and betta.
dbetta = sum(dout, dim=0)
dgamma = sum(xhat * dout, dim=0)
xhat = (input - mean) * rsqrt(var) or
(output - betta) / gamma
@thread
gridDim.x = hidden_size / 32
blockDim.x = 32
blockDim.y = 32
@param
gamma_grad: [hidden_size], gradient of gamma
betta_grad: [hidden_size], gradient of betta
out_grad: [batch_size * seq_len, hidden_size], gradient of betta ln output
inp_or_out: [batch_size * seq_len, hidden_size], ln output if means is nullptr
ln input if means is not nullptr
gamma: [hidden_size], gamma of ln,
used to compute xhat, maybe nullptr
betta: [hidden_size], betta of ln,
used to compute xhat, maybe nullptr
vars: [batch_size * seq_len], variance of ln forward,
used to compute xhat, maybe nullptr
means: [batch_size * seq_len], mean of ln forward,
used to compute xhat, maybe nullptr
(gamma && betta) ^ (vars && means) should be true
*/
template
<
typename
T
>
__global__
void
ker_ln_bw_dgamma_dbetta
(
T
*
gamma_grad
,
T
*
betta_grad
,
const
T
*
out_grad
,
const
T
*
inp_or_out
,
const
T
*
gamma
,
const
T
*
betta
,
const
T
*
vars
,
const
T
*
means
,
int
rows
,
int
width
)
{
__shared__
float
betta_buffer
[
TILE_DIM
][
TILE_DIM
];
__shared__
float
gamma_buffer
[
TILE_DIM
][
TILE_DIM
];
cg
::
thread_block
b
=
cg
::
this_thread_block
();
cg
::
thread_block_tile
<
TILE_DIM
>
g
=
cg
::
tiled_partition
<
TILE_DIM
>
(
b
);
int
idx
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
int
offset
=
threadIdx
.
y
*
width
+
idx
;
int
y_stride
=
width
*
TILE_DIM
;
// Loop across inp height
float
dbetta
=
0
;
float
dgamma
=
0
;
float
dout
,
val
;
if
(
idx
<
width
)
{
if
(
means
==
nullptr
)
{
float
vbetta
=
(
float
)
betta
[
idx
];
float
vgamma
=
(
float
)
gamma
[
idx
];
for
(
int
r
=
threadIdx
.
y
;
r
<
rows
;
r
+=
TILE_DIM
)
{
dout
=
(
float
)
out_grad
[
offset
];
// inp_or_out is output
val
=
(
float
)
inp_or_out
[
offset
];
dbetta
+=
dout
;
dgamma
+=
((
val
-
vbetta
)
/
add_eps
(
vgamma
)
*
dout
);
offset
+=
y_stride
;
}
}
else
{
for
(
int
r
=
threadIdx
.
y
;
r
<
rows
;
r
+=
TILE_DIM
)
{
dout
=
(
float
)
out_grad
[
offset
];
// inp_or_out is input
val
=
(
float
)
inp_or_out
[
offset
];
dbetta
+=
dout
;
dgamma
+=
((
val
-
(
float
)
means
[
r
])
*
rsqrtf
((
float
)
vars
[
r
]
+
LN_EPSILON
)
*
dout
);
offset
+=
y_stride
;
}
}
}
// Sum the shared buffer.
betta_buffer
[
threadIdx
.
x
][
threadIdx
.
y
]
=
dbetta
;
gamma_buffer
[
threadIdx
.
x
][
threadIdx
.
y
]
=
dgamma
;
__syncthreads
();
float
s1
=
betta_buffer
[
threadIdx
.
y
][
threadIdx
.
x
];
float
s2
=
gamma_buffer
[
threadIdx
.
y
][
threadIdx
.
x
];
__syncthreads
();
for
(
int
i
=
1
;
i
<
TILE_DIM
;
i
<<=
1
)
{
s1
+=
g
.
shfl_down
(
s1
,
i
);
s2
+=
g
.
shfl_down
(
s2
,
i
);
}
int
pos
=
blockIdx
.
x
*
TILE_DIM
+
threadIdx
.
y
;
if
(
threadIdx
.
x
==
0
&&
idx
<
width
)
{
betta_grad
[
pos
]
=
s1
;
gamma_grad
[
pos
]
=
s2
;
}
}
/**
@brief: ker_ln_bw_dinp
Layer norm backword kernel, compute the gradient of input.
dinp = (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / hidden_dim)
* rsqrt(var)
xhat = (input - mean) * rsqrt(var) if mean is not nullptr
(output - betta) / gamma if mean is nullptr
dxhat = dout * gamma
@thread
gridDim.x = batch_size * seq_len
blockDim.x = hidden_size
@param
inp_grad: [batch_size * seq_len, hidden_size], gradient of betta ln output
out_grad: [batch_size * seq_len, hidden_size], gradient of betta ln output
residual_grad: [batch_size * seq_len, hidden_size], gradient of residual input,
usually appear in pre-layer-norm for transformer layer, maybe nullptr
inp_or_out: [batch_size * seq_len, hidden_size], ln output if means is nullptr
ln input if means is not nullptr
gamma: [hidden_size], gamma of ln,
used to compute xhat and dxhat
betta: [hidden_size], betta of ln,
used to compute xhat, maybe nullptr
vars: [batch_size * seq_len], variance of ln forward,
used to compute xhat and dinp
means: [batch_size * seq_len], mean of ln forward,
used to compute xhat, maybe nullptr
*/
template
<
typename
T
>
__global__
void
ker_ln_bw_dinp
(
T
*
inp_grad
,
const
T
*
out_grad
,
const
T
*
residual_grad
,
const
T
*
inp_or_out
,
const
T
*
gamma
,
const
T
*
betta
,
const
T
*
vars
,
const
T
*
means
,
int
hidden_dim
)
{
int
offset
=
blockIdx
.
x
*
hidden_dim
+
threadIdx
.
x
;
float4
dxhat
,
xhat
;
float
var_rsqrt
;
if
(
threadIdx
.
x
<
hidden_dim
)
{
// step 0. dxhat = dout * gamma
dxhat
=
((
const
float4
*
)
out_grad
)[
offset
];
float4
vgamma
=
((
const
float4
*
)
gamma
)[
threadIdx
.
x
];
dxhat
.
x
*=
vgamma
.
x
;
dxhat
.
y
*=
vgamma
.
y
;
dxhat
.
z
*=
vgamma
.
z
;
dxhat
.
w
*=
vgamma
.
w
;
/*
step 1. xhat = (output - betta) / gamma or
(input - mean) * rsqrtf(var)
*/
xhat
=
((
const
float4
*
)
inp_or_out
)[
offset
];
var_rsqrt
=
rsqrtf
((
float
)
vars
[
blockIdx
.
x
]
+
LN_EPSILON
);
if
(
means
==
nullptr
)
{
// inp_or_out is output, xhat = (output - betta) / gamma
float4
vbetta
=
((
const
float4
*
)
betta
)[
threadIdx
.
x
];
xhat
.
x
=
(
xhat
.
x
-
vbetta
.
x
)
/
add_eps
(
vgamma
.
x
);
xhat
.
y
=
(
xhat
.
y
-
vbetta
.
y
)
/
add_eps
(
vgamma
.
y
);
xhat
.
z
=
(
xhat
.
z
-
vbetta
.
z
)
/
add_eps
(
vgamma
.
z
);
xhat
.
w
=
(
xhat
.
w
-
vbetta
.
w
)
/
add_eps
(
vgamma
.
w
);
}
else
{
// inp_or_out is input, xhat = (input - mean) * rsqrtf(var)
float
fmean
=
(
float
)
means
[
blockIdx
.
x
];
xhat
.
x
=
(
xhat
.
x
-
fmean
)
*
var_rsqrt
;
xhat
.
y
=
(
xhat
.
y
-
fmean
)
*
var_rsqrt
;
xhat
.
z
=
(
xhat
.
z
-
fmean
)
*
var_rsqrt
;
xhat
.
w
=
(
xhat
.
w
-
fmean
)
*
var_rsqrt
;
}
}
/* step2. block reduce sum for dxhat and dxhat*xhat */
float
reduce_val
[
2
]
=
{
0.
f
,
0.
f
};
if
(
threadIdx
.
x
<
hidden_dim
)
{
reduce_val
[
0
]
=
dxhat
.
x
+
dxhat
.
y
+
dxhat
.
z
+
dxhat
.
w
;
reduce_val
[
1
]
=
dxhat
.
x
*
xhat
.
x
+
dxhat
.
y
*
xhat
.
y
+
dxhat
.
z
*
xhat
.
z
+
dxhat
.
w
*
xhat
.
w
;
}
blockReduce
<
ReduceType
::
kSum
,
2
>
(
reduce_val
);
__shared__
float
s_sum_dxhat
,
s_sum_dxhat_xhat
;
if
(
threadIdx
.
x
==
0
)
{
float
mean_dim
=
hidden_dim
*
4
;
s_sum_dxhat
=
reduce_val
[
0
]
/
mean_dim
;
s_sum_dxhat_xhat
=
reduce_val
[
1
]
/
mean_dim
;
}
__syncthreads
();
/*
step3. compute input gradient
(dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var)
*/
if
(
threadIdx
.
x
>=
hidden_dim
)
{
return
;
}
dxhat
.
x
=
(
dxhat
.
x
-
s_sum_dxhat
-
xhat
.
x
*
s_sum_dxhat_xhat
)
*
var_rsqrt
;
dxhat
.
y
=
(
dxhat
.
y
-
s_sum_dxhat
-
xhat
.
y
*
s_sum_dxhat_xhat
)
*
var_rsqrt
;
dxhat
.
z
=
(
dxhat
.
z
-
s_sum_dxhat
-
xhat
.
z
*
s_sum_dxhat_xhat
)
*
var_rsqrt
;
dxhat
.
w
=
(
dxhat
.
w
-
s_sum_dxhat
-
xhat
.
w
*
s_sum_dxhat_xhat
)
*
var_rsqrt
;
if
(
residual_grad
)
{
// Add the residual grad,
// usually in pre-layer-norm for transformer layer
float4
dresidual
=
((
const
float4
*
)
residual_grad
)[
offset
];
dxhat
.
x
+=
dresidual
.
x
;
dxhat
.
y
+=
dresidual
.
y
;
dxhat
.
z
+=
dresidual
.
z
;
dxhat
.
w
+=
dresidual
.
w
;
}
((
float4
*
)
inp_grad
)[
offset
]
=
dxhat
;
}
template
<
>
__global__
void
ker_ln_bw_dinp
<
__half
>
(
__half
*
inp_grad
,
const
__half
*
out_grad
,
const
__half
*
residual_grad
,
const
__half
*
inp_or_out
,
const
__half
*
gamma
,
const
__half
*
betta
,
const
__half
*
vars
,
const
__half
*
means
,
int
hidden_dim
)
{
int
offset
=
blockIdx
.
x
*
hidden_dim
+
threadIdx
.
x
;
float2
dxhat
[
4
],
xhat
[
4
];
float
var_rsqrt
;
float4
vtmp
;
__half2
*
tmp_h2
;
float
reduce_val
[
2
]
=
{
0.
f
,
0.
f
};
if
(
threadIdx
.
x
<
hidden_dim
)
{
// step 0. dxhat = dout * gamma
vtmp
=
((
const
float4
*
)
out_grad
)[
offset
];
tmp_h2
=
reinterpret_cast
<
__half2
*>
(
&
vtmp
);
float4
gamma_f4
=
((
const
float4
*
)
gamma
)[
threadIdx
.
x
];
__half2
*
gamma_h2
=
reinterpret_cast
<
__half2
*>
(
&
gamma_f4
);
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
float2
vdout
=
__half22float2
(
tmp_h2
[
i
]);
float2
vgamma
=
__half22float2
(
gamma_h2
[
i
]);
dxhat
[
i
].
x
=
vdout
.
x
*
vgamma
.
x
;
dxhat
[
i
].
y
=
vdout
.
y
*
vgamma
.
y
;
reduce_val
[
0
]
+=
dxhat
[
i
].
x
+
dxhat
[
i
].
y
;
}
/*
step 1. xhat = (output - betta) / gamma or
(input - mean) * rsqrtf(var)
*/
vtmp
=
((
const
float4
*
)
inp_or_out
)[
offset
];
var_rsqrt
=
rsqrtf
((
float
)
vars
[
blockIdx
.
x
]
+
LN_EPSILON
);
if
(
means
==
nullptr
)
{
// inp_or_out is output, xhat = (output - betta) / gamma
float4
vbetta
=
((
const
float4
*
)
betta
)[
threadIdx
.
x
];
__half2
*
betta_h2
=
reinterpret_cast
<
__half2
*>
(
&
vbetta
);
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
float2
vout
=
__half22float2
(
tmp_h2
[
i
]);
float2
vgamma
=
__half22float2
(
gamma_h2
[
i
]);
float2
vbetta
=
__half22float2
(
betta_h2
[
i
]);
xhat
[
i
].
x
=
(
vout
.
x
-
vbetta
.
x
)
/
add_eps
(
vgamma
.
x
);
xhat
[
i
].
y
=
(
vout
.
y
-
vbetta
.
y
)
/
add_eps
(
vgamma
.
y
);
reduce_val
[
1
]
+=
xhat
[
i
].
x
*
dxhat
[
i
].
x
+
xhat
[
i
].
y
*
dxhat
[
i
].
y
;
}
}
else
{
// inp_or_out is input, xhat = (input - mean) * rsqrtf(var)
float
fmean
=
(
float
)
means
[
blockIdx
.
x
];
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
float2
vinp
=
__half22float2
(
tmp_h2
[
i
]);
xhat
[
i
].
x
=
(
vinp
.
x
-
fmean
)
*
var_rsqrt
;
xhat
[
i
].
y
=
(
vinp
.
y
-
fmean
)
*
var_rsqrt
;
reduce_val
[
1
]
+=
xhat
[
i
].
x
*
dxhat
[
i
].
x
+
xhat
[
i
].
y
*
dxhat
[
i
].
y
;
}
}
}
/* step2. block reduce sum for dxhat and dxhat*xhat */
blockReduce
<
ReduceType
::
kSum
,
2
>
(
reduce_val
);
__shared__
float
s_sum_dxhat
,
s_sum_dxhat_xhat
;
if
(
threadIdx
.
x
==
0
)
{
float
mean_dim
=
hidden_dim
*
8
;
s_sum_dxhat
=
reduce_val
[
0
]
/
mean_dim
;
s_sum_dxhat_xhat
=
reduce_val
[
1
]
/
mean_dim
;
}
__syncthreads
();
/*
step3. compute input gradient
(dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var)
*/
if
(
threadIdx
.
x
>=
hidden_dim
)
{
return
;
}
if
(
residual_grad
)
{
// Add the residual grad,
// usually in pre-layer-norm for transformer layer
float4
dresidual
=
((
const
float4
*
)
residual_grad
)[
offset
];
__half
*
hdres
=
reinterpret_cast
<
__half
*>
(
&
dresidual
);
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
tmp_h2
[
i
].
x
=
__float2half
(
(
dxhat
[
i
].
x
-
s_sum_dxhat
-
xhat
[
i
].
x
*
s_sum_dxhat_xhat
)
*
var_rsqrt
+
__half2float
(
hdres
[
2
*
i
]));
tmp_h2
[
i
].
y
=
__float2half
(
(
dxhat
[
i
].
y
-
s_sum_dxhat
-
xhat
[
i
].
y
*
s_sum_dxhat_xhat
)
*
var_rsqrt
+
__half2float
(
hdres
[
2
*
i
+
1
]));
}
}
else
{
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
tmp_h2
[
i
].
x
=
__float2half
(
(
dxhat
[
i
].
x
-
s_sum_dxhat
-
xhat
[
i
].
x
*
s_sum_dxhat_xhat
)
*
var_rsqrt
);
tmp_h2
[
i
].
y
=
__float2half
(
(
dxhat
[
i
].
y
-
s_sum_dxhat
-
xhat
[
i
].
y
*
s_sum_dxhat_xhat
)
*
var_rsqrt
);
}
}
((
float4
*
)
inp_grad
)[
offset
]
=
vtmp
;
}
__global__
void
ker_ln_bw_dinp_x2
(
__half
*
inp_grad
,
const
__half
*
out_grad
,
const
__half
*
residual_grad
,
const
__half
*
inp_or_out
,
const
__half
*
gamma
,
const
__half
*
betta
,
const
__half
*
vars
,
const
__half
*
means
,
int
hidden_dim
)
{
int
offset
=
blockIdx
.
x
*
hidden_dim
*
2
+
threadIdx
.
x
*
2
;
float2
dxhat
[
4
],
xhat
[
4
];
float2
dxhat_1
[
4
],
xhat_1
[
4
];
float
var_rsqrt
;
float4
vtmp
,
vtmp_1
;
__half2
*
tmp_h2
;
__half2
*
tmp_h2_1
;
float
reduce_val
[
2
]
=
{
0.
f
,
0.
f
};
if
(
threadIdx
.
x
<
hidden_dim
)
{
// step 0. dxhat = dout * gamma
vtmp
=
((
const
float4
*
)
out_grad
)[
offset
];
vtmp_1
=
((
const
float4
*
)
out_grad
)[
offset
+
1
];
tmp_h2
=
reinterpret_cast
<
__half2
*>
(
&
vtmp
);
tmp_h2_1
=
reinterpret_cast
<
__half2
*>
(
&
vtmp_1
);
float4
gamma_f4
=
((
const
float4
*
)
gamma
)[
threadIdx
.
x
*
2
];
float4
gamma_f4_1
=
((
const
float4
*
)
gamma
)[
threadIdx
.
x
*
2
+
1
];
__half2
*
gamma_h2
=
reinterpret_cast
<
__half2
*>
(
&
gamma_f4
);
__half2
*
gamma_h2_1
=
reinterpret_cast
<
__half2
*>
(
&
gamma_f4_1
);
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
float2
vdout
=
__half22float2
(
tmp_h2
[
i
]);
float2
vdout_1
=
__half22float2
(
tmp_h2_1
[
i
]);
float2
vgamma
=
__half22float2
(
gamma_h2
[
i
]);
float2
vgamma_1
=
__half22float2
(
gamma_h2_1
[
i
]);
dxhat
[
i
].
x
=
vdout
.
x
*
vgamma
.
x
;
dxhat
[
i
].
y
=
vdout
.
y
*
vgamma
.
y
;
dxhat_1
[
i
].
x
=
vdout_1
.
x
*
vgamma_1
.
x
;
dxhat_1
[
i
].
y
=
vdout_1
.
y
*
vgamma_1
.
y
;
reduce_val
[
0
]
+=
dxhat
[
i
].
x
+
dxhat
[
i
].
y
+
dxhat_1
[
i
].
x
+
dxhat_1
[
i
].
y
;
}
/*
step 1. xhat = (output - betta) / gamma or
(input - mean) * rsqrtf(var)
*/
vtmp
=
((
const
float4
*
)
inp_or_out
)[
offset
];
vtmp_1
=
((
const
float4
*
)
inp_or_out
)[
offset
+
1
];
var_rsqrt
=
rsqrtf
((
float
)
vars
[
blockIdx
.
x
]
+
LN_EPSILON
);
if
(
means
==
nullptr
)
{
// inp_or_out is output, xhat = (output - betta) / gamma
float4
vbetta
=
((
const
float4
*
)
betta
)[
2
*
threadIdx
.
x
];
float4
vbetta_1
=
((
const
float4
*
)
betta
)[
2
*
threadIdx
.
x
+
1
];
__half2
*
betta_h2
=
reinterpret_cast
<
__half2
*>
(
&
vbetta
);
__half2
*
betta_h2_1
=
reinterpret_cast
<
__half2
*>
(
&
vbetta_1
);
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
float2
vout
=
__half22float2
(
tmp_h2
[
i
]);
float2
vout_1
=
__half22float2
(
tmp_h2_1
[
i
]);
float2
vgamma
=
__half22float2
(
gamma_h2
[
i
]);
float2
vgamma_1
=
__half22float2
(
gamma_h2_1
[
i
]);
float2
vbetta
=
__half22float2
(
betta_h2
[
i
]);
float2
vbetta_1
=
__half22float2
(
betta_h2_1
[
i
]);
xhat
[
i
].
x
=
(
vout
.
x
-
vbetta
.
x
)
/
add_eps
(
vgamma
.
x
);
xhat_1
[
i
].
x
=
(
vout_1
.
x
-
vbetta_1
.
x
)
/
add_eps
(
vgamma_1
.
x
);
xhat
[
i
].
y
=
(
vout
.
y
-
vbetta
.
y
)
/
add_eps
(
vgamma
.
y
);
xhat_1
[
i
].
y
=
(
vout_1
.
y
-
vbetta_1
.
y
)
/
add_eps
(
vgamma_1
.
y
);
reduce_val
[
1
]
+=
xhat
[
i
].
x
*
dxhat
[
i
].
x
+
xhat
[
i
].
y
*
dxhat
[
i
].
y
;
reduce_val
[
1
]
+=
xhat_1
[
i
].
x
*
dxhat_1
[
i
].
x
+
xhat_1
[
i
].
y
*
dxhat_1
[
i
].
y
;
}
}
else
{
// inp_or_out is input, xhat = (input - mean) * rsqrtf(var)
float
fmean
=
(
float
)
means
[
blockIdx
.
x
];
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
float2
vinp
=
__half22float2
(
tmp_h2
[
i
]);
float2
vinp_1
=
__half22float2
(
tmp_h2_1
[
i
]);
xhat
[
i
].
x
=
(
vinp
.
x
-
fmean
)
*
var_rsqrt
;
xhat_1
[
i
].
x
=
(
vinp_1
.
x
-
fmean
)
*
var_rsqrt
;
xhat
[
i
].
y
=
(
vinp
.
y
-
fmean
)
*
var_rsqrt
;
xhat_1
[
i
].
y
=
(
vinp_1
.
y
-
fmean
)
*
var_rsqrt
;
reduce_val
[
1
]
+=
xhat
[
i
].
x
*
dxhat
[
i
].
x
+
xhat
[
i
].
y
*
dxhat
[
i
].
y
;
reduce_val
[
1
]
+=
xhat_1
[
i
].
x
*
dxhat_1
[
i
].
x
+
xhat_1
[
i
].
y
*
dxhat_1
[
i
].
y
;
}
}
}
/* step2. block reduce sum for dxhat and dxhat*xhat */
blockReduce
<
ReduceType
::
kSum
,
2
>
(
reduce_val
);
__shared__
float
s_sum_dxhat
,
s_sum_dxhat_xhat
;
if
(
threadIdx
.
x
==
0
)
{
float
mean_dim
=
hidden_dim
*
8
*
2
;
s_sum_dxhat
=
reduce_val
[
0
]
/
mean_dim
;
s_sum_dxhat_xhat
=
reduce_val
[
1
]
/
mean_dim
;
}
__syncthreads
();
/*
step3. compute input gradient
(dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var)
*/
if
(
threadIdx
.
x
>=
hidden_dim
)
{
return
;
}
if
(
residual_grad
)
{
// Add the residual grad,
// usually in pre-layer-norm for transformer layer
float4
dresidual
=
((
const
float4
*
)
residual_grad
)[
offset
];
float4
dresidual_1
=
((
const
float4
*
)
residual_grad
)[
offset
+
1
];
__half
*
hdres
=
reinterpret_cast
<
__half
*>
(
&
dresidual
);
__half
*
hdres_1
=
reinterpret_cast
<
__half
*>
(
&
dresidual_1
);
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
tmp_h2
[
i
].
x
=
__float2half
(
(
dxhat
[
i
].
x
-
s_sum_dxhat
-
xhat
[
i
].
x
*
s_sum_dxhat_xhat
)
*
var_rsqrt
+
__half2float
(
hdres
[
2
*
i
]));
tmp_h2_1
[
i
].
x
=
__float2half
(
(
dxhat_1
[
i
].
x
-
s_sum_dxhat
-
xhat_1
[
i
].
x
*
s_sum_dxhat_xhat
)
*
var_rsqrt
+
__half2float
(
hdres_1
[
2
*
i
]));
tmp_h2
[
i
].
y
=
__float2half
(
(
dxhat
[
i
].
y
-
s_sum_dxhat
-
xhat
[
i
].
y
*
s_sum_dxhat_xhat
)
*
var_rsqrt
+
__half2float
(
hdres
[
2
*
i
+
1
]));
tmp_h2_1
[
i
].
y
=
__float2half
(
(
dxhat_1
[
i
].
y
-
s_sum_dxhat
-
xhat_1
[
i
].
y
*
s_sum_dxhat_xhat
)
*
var_rsqrt
+
__half2float
(
hdres_1
[
2
*
i
+
1
]));
}
}
else
{
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
tmp_h2
[
i
].
x
=
__float2half
(
(
dxhat
[
i
].
x
-
s_sum_dxhat
-
xhat
[
i
].
x
*
s_sum_dxhat_xhat
)
*
var_rsqrt
);
tmp_h2_1
[
i
].
x
=
__float2half
(
(
dxhat_1
[
i
].
x
-
s_sum_dxhat
-
xhat_1
[
i
].
x
*
s_sum_dxhat_xhat
)
*
var_rsqrt
);
tmp_h2
[
i
].
y
=
__float2half
(
(
dxhat
[
i
].
y
-
s_sum_dxhat
-
xhat
[
i
].
y
*
s_sum_dxhat_xhat
)
*
var_rsqrt
);
tmp_h2_1
[
i
].
y
=
__float2half
(
(
dxhat_1
[
i
].
y
-
s_sum_dxhat
-
xhat_1
[
i
].
y
*
s_sum_dxhat_xhat
)
*
var_rsqrt
);
}
}
((
float4
*
)
inp_grad
)[
offset
]
=
vtmp
;
((
float4
*
)
inp_grad
)[
offset
+
1
]
=
vtmp_1
;
}
__global__
void
ker_ln_bw_dinp_x4
(
__half
*
inp_grad
,
const
__half
*
out_grad
,
const
__half
*
residual_grad
,
const
__half
*
inp_or_out
,
const
__half
*
gamma
,
const
__half
*
betta
,
const
__half
*
vars
,
const
__half
*
means
,
int
hidden_dim
)
{
int
offset
=
blockIdx
.
x
*
hidden_dim
*
4
+
threadIdx
.
x
*
4
;
float2
dxhat
[
4
],
xhat
[
4
];
float2
dxhat_1
[
4
],
xhat_1
[
4
];
float2
dxhat_2
[
4
],
xhat_2
[
4
];
float2
dxhat_3
[
4
],
xhat_3
[
4
];
float
var_rsqrt
;
float4
vtmp
,
vtmp_1
,
vtmp_2
,
vtmp_3
;
__half2
*
tmp_h2
;
__half2
*
tmp_h2_1
;
__half2
*
tmp_h2_2
;
__half2
*
tmp_h2_3
;
float
reduce_val
[
2
]
=
{
0.
f
,
0.
f
};
if
(
threadIdx
.
x
<
hidden_dim
)
{
// step 0. dxhat = dout * gamma
vtmp
=
((
const
float4
*
)
out_grad
)[
offset
];
vtmp_1
=
((
const
float4
*
)
out_grad
)[
offset
+
1
];
vtmp_2
=
((
const
float4
*
)
out_grad
)[
offset
+
2
];
vtmp_3
=
((
const
float4
*
)
out_grad
)[
offset
+
3
];
tmp_h2
=
reinterpret_cast
<
__half2
*>
(
&
vtmp
);
tmp_h2_1
=
reinterpret_cast
<
__half2
*>
(
&
vtmp_1
);
tmp_h2_2
=
reinterpret_cast
<
__half2
*>
(
&
vtmp_2
);
tmp_h2_3
=
reinterpret_cast
<
__half2
*>
(
&
vtmp_3
);
float4
gamma_f4
=
((
const
float4
*
)
gamma
)[
threadIdx
.
x
*
4
];
float4
gamma_f4_1
=
((
const
float4
*
)
gamma
)[
threadIdx
.
x
*
4
+
1
];
float4
gamma_f4_2
=
((
const
float4
*
)
gamma
)[
threadIdx
.
x
*
4
+
2
];
float4
gamma_f4_3
=
((
const
float4
*
)
gamma
)[
threadIdx
.
x
*
4
+
3
];
__half2
*
gamma_h2
=
reinterpret_cast
<
__half2
*>
(
&
gamma_f4
);
__half2
*
gamma_h2_1
=
reinterpret_cast
<
__half2
*>
(
&
gamma_f4_1
);
__half2
*
gamma_h2_2
=
reinterpret_cast
<
__half2
*>
(
&
gamma_f4_2
);
__half2
*
gamma_h2_3
=
reinterpret_cast
<
__half2
*>
(
&
gamma_f4_3
);
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
float2
vdout
=
__half22float2
(
tmp_h2
[
i
]);
float2
vdout_1
=
__half22float2
(
tmp_h2_1
[
i
]);
float2
vdout_2
=
__half22float2
(
tmp_h2_2
[
i
]);
float2
vdout_3
=
__half22float2
(
tmp_h2_3
[
i
]);
float2
vgamma
=
__half22float2
(
gamma_h2
[
i
]);
float2
vgamma_1
=
__half22float2
(
gamma_h2_1
[
i
]);
float2
vgamma_2
=
__half22float2
(
gamma_h2_2
[
i
]);
float2
vgamma_3
=
__half22float2
(
gamma_h2_3
[
i
]);
dxhat
[
i
].
x
=
vdout
.
x
*
vgamma
.
x
;
dxhat
[
i
].
y
=
vdout
.
y
*
vgamma
.
y
;
dxhat_1
[
i
].
x
=
vdout_1
.
x
*
vgamma_1
.
x
;
dxhat_1
[
i
].
y
=
vdout_1
.
y
*
vgamma_1
.
y
;
dxhat_2
[
i
].
x
=
vdout_2
.
x
*
vgamma_2
.
x
;
dxhat_2
[
i
].
y
=
vdout_2
.
y
*
vgamma_2
.
y
;
dxhat_3
[
i
].
x
=
vdout_3
.
x
*
vgamma_3
.
x
;
dxhat_3
[
i
].
y
=
vdout_3
.
y
*
vgamma_3
.
y
;
reduce_val
[
0
]
+=
dxhat
[
i
].
x
+
dxhat
[
i
].
y
+
dxhat_1
[
i
].
x
+
dxhat_1
[
i
].
y
+
dxhat_2
[
i
].
x
+
dxhat_2
[
i
].
y
+
dxhat_3
[
i
].
x
+
dxhat_3
[
i
].
y
;
}
/*
step 1. xhat = (output - betta) / gamma or
(input - mean) * rsqrtf(var)
*/
vtmp
=
((
const
float4
*
)
inp_or_out
)[
offset
];
vtmp_1
=
((
const
float4
*
)
inp_or_out
)[
offset
+
1
];
vtmp_2
=
((
const
float4
*
)
inp_or_out
)[
offset
+
2
];
vtmp_3
=
((
const
float4
*
)
inp_or_out
)[
offset
+
3
];
var_rsqrt
=
rsqrtf
((
float
)
vars
[
blockIdx
.
x
]
+
LN_EPSILON
);
if
(
means
==
nullptr
)
{
// inp_or_out is output, xhat = (output - betta) / gamma
float4
vbetta
=
((
const
float4
*
)
betta
)[
4
*
threadIdx
.
x
];
float4
vbetta_1
=
((
const
float4
*
)
betta
)[
4
*
threadIdx
.
x
+
1
];
float4
vbetta_2
=
((
const
float4
*
)
betta
)[
4
*
threadIdx
.
x
+
2
];
float4
vbetta_3
=
((
const
float4
*
)
betta
)[
4
*
threadIdx
.
x
+
3
];
__half2
*
betta_h2
=
reinterpret_cast
<
__half2
*>
(
&
vbetta
);
__half2
*
betta_h2_1
=
reinterpret_cast
<
__half2
*>
(
&
vbetta_1
);
__half2
*
betta_h2_2
=
reinterpret_cast
<
__half2
*>
(
&
vbetta_2
);
__half2
*
betta_h2_3
=
reinterpret_cast
<
__half2
*>
(
&
vbetta_3
);
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
float2
vout
=
__half22float2
(
tmp_h2
[
i
]);
float2
vout_1
=
__half22float2
(
tmp_h2_1
[
i
]);
float2
vout_2
=
__half22float2
(
tmp_h2_2
[
i
]);
float2
vout_3
=
__half22float2
(
tmp_h2_3
[
i
]);
float2
vgamma
=
__half22float2
(
gamma_h2
[
i
]);
float2
vgamma_1
=
__half22float2
(
gamma_h2_1
[
i
]);
float2
vgamma_2
=
__half22float2
(
gamma_h2_2
[
i
]);
float2
vgamma_3
=
__half22float2
(
gamma_h2_3
[
i
]);
float2
vbetta
=
__half22float2
(
betta_h2
[
i
]);
float2
vbetta_1
=
__half22float2
(
betta_h2_1
[
i
]);
float2
vbetta_2
=
__half22float2
(
betta_h2_2
[
i
]);
float2
vbetta_3
=
__half22float2
(
betta_h2_3
[
i
]);
xhat
[
i
].
x
=
(
vout
.
x
-
vbetta
.
x
)
/
add_eps
(
vgamma
.
x
);
xhat_1
[
i
].
x
=
(
vout_1
.
x
-
vbetta_1
.
x
)
/
add_eps
(
vgamma_1
.
x
);
xhat_2
[
i
].
x
=
(
vout_2
.
x
-
vbetta_2
.
x
)
/
add_eps
(
vgamma_2
.
x
);
xhat_3
[
i
].
x
=
(
vout_3
.
x
-
vbetta_3
.
x
)
/
add_eps
(
vgamma_3
.
x
);
xhat
[
i
].
y
=
(
vout
.
y
-
vbetta
.
y
)
/
add_eps
(
vgamma
.
y
);
xhat_1
[
i
].
y
=
(
vout_1
.
y
-
vbetta_1
.
y
)
/
add_eps
(
vgamma_1
.
y
);
xhat_2
[
i
].
y
=
(
vout_2
.
y
-
vbetta_2
.
y
)
/
add_eps
(
vgamma_2
.
y
);
xhat_3
[
i
].
y
=
(
vout_3
.
y
-
vbetta_3
.
y
)
/
add_eps
(
vgamma_3
.
y
);
reduce_val
[
1
]
+=
xhat
[
i
].
x
*
dxhat
[
i
].
x
+
xhat
[
i
].
y
*
dxhat
[
i
].
y
;
reduce_val
[
1
]
+=
xhat_1
[
i
].
x
*
dxhat_1
[
i
].
x
+
xhat_1
[
i
].
y
*
dxhat_1
[
i
].
y
;
reduce_val
[
1
]
+=
xhat_2
[
i
].
x
*
dxhat_2
[
i
].
x
+
xhat_2
[
i
].
y
*
dxhat_2
[
i
].
y
;
reduce_val
[
1
]
+=
xhat_3
[
i
].
x
*
dxhat_3
[
i
].
x
+
xhat_3
[
i
].
y
*
dxhat_3
[
i
].
y
;
}
}
else
{
// inp_or_out is input, xhat = (input - mean) * rsqrtf(var)
float
fmean
=
(
float
)
means
[
blockIdx
.
x
];
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
float2
vinp
=
__half22float2
(
tmp_h2
[
i
]);
float2
vinp_1
=
__half22float2
(
tmp_h2_1
[
i
]);
float2
vinp_2
=
__half22float2
(
tmp_h2_2
[
i
]);
float2
vinp_3
=
__half22float2
(
tmp_h2_3
[
i
]);
xhat
[
i
].
x
=
(
vinp
.
x
-
fmean
)
*
var_rsqrt
;
xhat_1
[
i
].
x
=
(
vinp_1
.
x
-
fmean
)
*
var_rsqrt
;
xhat_2
[
i
].
x
=
(
vinp_2
.
x
-
fmean
)
*
var_rsqrt
;
xhat_3
[
i
].
x
=
(
vinp_3
.
x
-
fmean
)
*
var_rsqrt
;
xhat
[
i
].
y
=
(
vinp
.
y
-
fmean
)
*
var_rsqrt
;
xhat_1
[
i
].
y
=
(
vinp_1
.
y
-
fmean
)
*
var_rsqrt
;
xhat_2
[
i
].
y
=
(
vinp_2
.
y
-
fmean
)
*
var_rsqrt
;
xhat_3
[
i
].
y
=
(
vinp_3
.
y
-
fmean
)
*
var_rsqrt
;
reduce_val
[
1
]
+=
xhat
[
i
].
x
*
dxhat
[
i
].
x
+
xhat
[
i
].
y
*
dxhat
[
i
].
y
;
reduce_val
[
1
]
+=
xhat_1
[
i
].
x
*
dxhat_1
[
i
].
x
+
xhat_1
[
i
].
y
*
dxhat_1
[
i
].
y
;
reduce_val
[
1
]
+=
xhat_2
[
i
].
x
*
dxhat_2
[
i
].
x
+
xhat_2
[
i
].
y
*
dxhat_2
[
i
].
y
;
reduce_val
[
1
]
+=
xhat_3
[
i
].
x
*
dxhat_3
[
i
].
x
+
xhat_3
[
i
].
y
*
dxhat_3
[
i
].
y
;
}
}
}
/* step2. block reduce sum for dxhat and dxhat*xhat */
blockReduce
<
ReduceType
::
kSum
,
2
>
(
reduce_val
);
__shared__
float
s_sum_dxhat
,
s_sum_dxhat_xhat
;
if
(
threadIdx
.
x
==
0
)
{
float
mean_dim
=
hidden_dim
*
8
*
4
;
s_sum_dxhat
=
reduce_val
[
0
]
/
mean_dim
;
s_sum_dxhat_xhat
=
reduce_val
[
1
]
/
mean_dim
;
}
__syncthreads
();
/*
step3. compute input gradient
(dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var)
*/
if
(
threadIdx
.
x
>=
hidden_dim
)
{
return
;
}
if
(
residual_grad
)
{
// Add the residual grad,
// usually in pre-layer-norm for transformer layer
float4
dresidual
=
((
const
float4
*
)
residual_grad
)[
offset
];
float4
dresidual_1
=
((
const
float4
*
)
residual_grad
)[
offset
+
1
];
float4
dresidual_2
=
((
const
float4
*
)
residual_grad
)[
offset
+
2
];
float4
dresidual_3
=
((
const
float4
*
)
residual_grad
)[
offset
+
3
];
__half
*
hdres
=
reinterpret_cast
<
__half
*>
(
&
dresidual
);
__half
*
hdres_1
=
reinterpret_cast
<
__half
*>
(
&
dresidual_1
);
__half
*
hdres_2
=
reinterpret_cast
<
__half
*>
(
&
dresidual_2
);
__half
*
hdres_3
=
reinterpret_cast
<
__half
*>
(
&
dresidual_3
);
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
tmp_h2
[
i
].
x
=
__float2half
(
(
dxhat
[
i
].
x
-
s_sum_dxhat
-
xhat
[
i
].
x
*
s_sum_dxhat_xhat
)
*
var_rsqrt
+
__half2float
(
hdres
[
2
*
i
]));
tmp_h2_1
[
i
].
x
=
__float2half
(
(
dxhat_1
[
i
].
x
-
s_sum_dxhat
-
xhat_1
[
i
].
x
*
s_sum_dxhat_xhat
)
*
var_rsqrt
+
__half2float
(
hdres_1
[
2
*
i
]));
tmp_h2_2
[
i
].
x
=
__float2half
(
(
dxhat_2
[
i
].
x
-
s_sum_dxhat
-
xhat_2
[
i
].
x
*
s_sum_dxhat_xhat
)
*
var_rsqrt
+
__half2float
(
hdres_2
[
2
*
i
]));
tmp_h2_3
[
i
].
x
=
__float2half
(
(
dxhat_3
[
i
].
x
-
s_sum_dxhat
-
xhat_3
[
i
].
x
*
s_sum_dxhat_xhat
)
*
var_rsqrt
+
__half2float
(
hdres_3
[
2
*
i
]));
tmp_h2
[
i
].
y
=
__float2half
(
(
dxhat
[
i
].
y
-
s_sum_dxhat
-
xhat
[
i
].
y
*
s_sum_dxhat_xhat
)
*
var_rsqrt
+
__half2float
(
hdres
[
2
*
i
+
1
]));
tmp_h2_1
[
i
].
y
=
__float2half
(
(
dxhat_1
[
i
].
y
-
s_sum_dxhat
-
xhat_1
[
i
].
y
*
s_sum_dxhat_xhat
)
*
var_rsqrt
+
__half2float
(
hdres_1
[
2
*
i
+
1
]));
tmp_h2_2
[
i
].
y
=
__float2half
(
(
dxhat_2
[
i
].
y
-
s_sum_dxhat
-
xhat_2
[
i
].
y
*
s_sum_dxhat_xhat
)
*
var_rsqrt
+
__half2float
(
hdres_1
[
2
*
i
+
1
]));
tmp_h2_3
[
i
].
y
=
__float2half
(
(
dxhat_3
[
i
].
y
-
s_sum_dxhat
-
xhat_3
[
i
].
y
*
s_sum_dxhat_xhat
)
*
var_rsqrt
+
__half2float
(
hdres_1
[
2
*
i
+
1
]));
}
}
else
{
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
tmp_h2
[
i
].
x
=
__float2half
(
(
dxhat
[
i
].
x
-
s_sum_dxhat
-
xhat
[
i
].
x
*
s_sum_dxhat_xhat
)
*
var_rsqrt
);
tmp_h2_1
[
i
].
x
=
__float2half
(
(
dxhat_1
[
i
].
x
-
s_sum_dxhat
-
xhat_1
[
i
].
x
*
s_sum_dxhat_xhat
)
*
var_rsqrt
);
tmp_h2_2
[
i
].
x
=
__float2half
(
(
dxhat_2
[
i
].
x
-
s_sum_dxhat
-
xhat_2
[
i
].
x
*
s_sum_dxhat_xhat
)
*
var_rsqrt
);
tmp_h2_3
[
i
].
x
=
__float2half
(
(
dxhat_3
[
i
].
x
-
s_sum_dxhat
-
xhat_3
[
i
].
x
*
s_sum_dxhat_xhat
)
*
var_rsqrt
);
tmp_h2
[
i
].
y
=
__float2half
(
(
dxhat
[
i
].
y
-
s_sum_dxhat
-
xhat
[
i
].
y
*
s_sum_dxhat_xhat
)
*
var_rsqrt
);
tmp_h2_1
[
i
].
y
=
__float2half
(
(
dxhat_1
[
i
].
y
-
s_sum_dxhat
-
xhat_1
[
i
].
y
*
s_sum_dxhat_xhat
)
*
var_rsqrt
);
tmp_h2_2
[
i
].
y
=
__float2half
(
(
dxhat_2
[
i
].
y
-
s_sum_dxhat
-
xhat_2
[
i
].
y
*
s_sum_dxhat_xhat
)
*
var_rsqrt
);
tmp_h2_3
[
i
].
y
=
__float2half
(
(
dxhat_3
[
i
].
y
-
s_sum_dxhat
-
xhat_3
[
i
].
y
*
s_sum_dxhat_xhat
)
*
var_rsqrt
);
}
}
((
float4
*
)
inp_grad
)[
offset
]
=
vtmp
;
((
float4
*
)
inp_grad
)[
offset
+
1
]
=
vtmp_1
;
((
float4
*
)
inp_grad
)[
offset
+
2
]
=
vtmp_2
;
((
float4
*
)
inp_grad
)[
offset
+
3
]
=
vtmp_3
;
}
/**
Layer norm backword,
compute the gradient of gamma, betta and input.
dbetta = sum(dout, dim=0)
xhat = (input - mean) * rsqrt(var) if mean is not nullptr
(output - betta) / gamma if mean is nullptr
dgamma = sum(xhat * dout, dim=0)
dxhat = dout * gamma
dinp = (dxhat - (sum(dxhat, 1) + xhat * sum(dxhat * xhat, 1)) / hidden_dim)
* rsqrt(var)
residual_grad, means, betta can be nullptr.
residual_grad will be added to dinp if it is not nullptr
which is useful in transformer layer when pre-ln
means and betta are only used to compute xhat,
(means == nullptr) ^ (betta == nullptr) should be true
*/
template
<
>
void
launch_ln_bw
<
float
>
(
float
*
gamma_grad
,
float
*
betta_grad
,
float
*
inp_grad
,
const
float
*
out_grad
,
const
float
*
residual_grad
,
const
float
*
inp_or_out
,
const
float
*
gamma
,
const
float
*
betta
,
const
float
*
vars
,
const
float
*
means
,
int
batch
,
int
hidden_dim
,
cudaStream_t
stream
[
2
])
{
// compute grad of gamma and betta
dim3
grid_dim
(((
hidden_dim
+
TILE_DIM
-
1
)
/
TILE_DIM
)
*
TILE_DIM
);
dim3
block_dim
(
TILE_DIM
,
TILE_DIM
);
ker_ln_bw_dgamma_dbetta
<
float
><<<
grid_dim
,
block_dim
,
0
,
stream
[
0
]
>>>
(
gamma_grad
,
betta_grad
,
out_grad
,
inp_or_out
,
gamma
,
betta
,
vars
,
means
,
batch
,
hidden_dim
);
// compute grad of input
if
(
hidden_dim
%
4
!=
0
||
hidden_dim
>
4096
)
{
throw
std
::
runtime_error
(
"hidden_dim % 4 != 0 || hidden_dim > 4096"
);
}
hidden_dim
>>=
2
;
int
nthread
=
min
(((
hidden_dim
+
31
)
/
32
)
*
32
,
MAX_THREADS
);
ker_ln_bw_dinp
<<<
batch
,
nthread
,
0
,
stream
[
1
]
>>>
(
inp_grad
,
out_grad
,
residual_grad
,
inp_or_out
,
gamma
,
betta
,
vars
,
means
,
hidden_dim
);
}
template
<
>
void
launch_ln_bw
<
__half
>
(
__half
*
gamma_grad
,
__half
*
betta_grad
,
__half
*
inp_grad
,
const
__half
*
out_grad
,
const
__half
*
residual_grad
,
const
__half
*
inp_or_out
,
const
__half
*
gamma
,
const
__half
*
betta
,
const
__half
*
vars
,
const
__half
*
means
,
int
batch
,
int
hidden_dim
,
cudaStream_t
stream
[
2
])
{
// compute grad of gamma and betta
dim3
grid_dim
(((
hidden_dim
+
TILE_DIM
-
1
)
/
TILE_DIM
)
*
TILE_DIM
);
dim3
block_dim
(
TILE_DIM
,
TILE_DIM
);
ker_ln_bw_dgamma_dbetta
<
__half
><<<
grid_dim
,
block_dim
,
0
,
stream
[
0
]
>>>
(
gamma_grad
,
betta_grad
,
out_grad
,
inp_or_out
,
gamma
,
betta
,
vars
,
means
,
batch
,
hidden_dim
);
// compute grad of input
if
(
hidden_dim
%
8
!=
0
)
{
throw
std
::
runtime_error
(
"hidden_dim % 8 != 0"
);
}
hidden_dim
>>=
3
;
if
(
hidden_dim
*
8
<=
8192
)
{
int
nthread
=
min
(((
hidden_dim
+
31
)
/
32
)
*
32
,
MAX_THREADS
);
ker_ln_bw_dinp
<<<
batch
,
nthread
,
0
,
stream
[
1
]
>>>
(
inp_grad
,
out_grad
,
residual_grad
,
inp_or_out
,
gamma
,
betta
,
vars
,
means
,
hidden_dim
);
}
else
if
(
hidden_dim
*
8
>
8192
&&
hidden_dim
*
8
<=
8192
*
2
)
{
hidden_dim
>>=
1
;
int
nthread
=
min
(((
hidden_dim
+
31
)
/
32
)
*
32
,
MAX_THREADS
);
ker_ln_bw_dinp_x2
<<<
batch
,
nthread
,
0
,
stream
[
1
]
>>>
(
inp_grad
,
out_grad
,
residual_grad
,
inp_or_out
,
gamma
,
betta
,
vars
,
means
,
hidden_dim
);
}
else
if
(
hidden_dim
*
8
>
2
*
8192
&&
hidden_dim
*
8
<=
8192
*
4
)
{
hidden_dim
>>=
2
;
int
nthread
=
min
(((
hidden_dim
+
31
)
/
32
)
*
32
,
MAX_THREADS
);
ker_ln_bw_dinp_x4
<<<
batch
,
nthread
,
0
,
stream
[
1
]
>>>
(
inp_grad
,
out_grad
,
residual_grad
,
inp_or_out
,
gamma
,
betta
,
vars
,
means
,
hidden_dim
);
}
else
{
throw
std
::
runtime_error
(
"hidden_dim % 4 != 0 || hidden_dim > 32768"
);
}
}
colossalai/kernel/cuda_native/csrc/kernels/softmax_kernels.cu
0 → 100644
View file @
5c3843dc
#include <math.h>
#include <cub/block/block_load.cuh>
#include <cub/cub.cuh>
#include "block_reduce.h"
#include "kernels.h"
#include <cooperative_groups.h>
namespace
cg
=
cooperative_groups
;
const
float
EPSILON
=
1e-8
f
;
/**
@brief: softmax_kernel
Softmax forward kernel for
enc-self-attn, dec-self-attn, encdec-attn
@thread
gridDim.x = dynamic
gridDim.y = batch_size
gridDim.z = nhead
blockDim.x = from_len
@param
inp: [batch_size, nhead, from_len, to_len], softmax input.
attn_mask: [batch_size, to_len], padding tokens are -inf,
non padding tokens are 0.
attn_mask!=nullptr for enc-self-attn and enc-dec-attn
attn_mask=nullptr and mask_future=ture for dec-self-attn training
attn_mask=nullptr and mask_future=false for dec-self-attn infer
*/
template
<
typename
T
,
int
block_dim
,
int
ele_per_thread
>
__global__
void
ker_attn_softmax
(
T
*
inp
,
const
T
*
attn_mask
,
int
from_len
,
int
to_len
,
bool
mask_future
)
{
int
batch_id
=
blockIdx
.
y
;
int
head_id
=
blockIdx
.
z
;
const
int
nhead
=
gridDim
.
z
;
const
int
token_per_reduce
=
1
;
typedef
cub
::
BlockLoad
<
T
,
block_dim
,
ele_per_thread
,
cub
::
BLOCK_LOAD_VECTORIZE
>
BlockLoad
;
__shared__
typename
BlockLoad
::
TempStorage
ts_load
;
typedef
cub
::
BlockStore
<
T
,
block_dim
,
ele_per_thread
,
cub
::
BLOCK_STORE_VECTORIZE
>
BlockStore
;
__shared__
typename
BlockStore
::
TempStorage
ts_store
;
T
mval
[
ele_per_thread
];
if
(
attn_mask
)
{
attn_mask
+=
batch_id
*
to_len
;
BlockLoad
(
ts_load
).
Load
(
attn_mask
,
mval
,
to_len
,
REDUCE_FLOAT_INF_NEG
);
}
inp
+=
flat_3dim
(
batch_id
,
head_id
,
0
,
nhead
,
from_len
*
to_len
);
for
(
int
token_id
=
blockIdx
.
x
*
token_per_reduce
;
token_id
<
from_len
;
token_id
+=
gridDim
.
x
*
token_per_reduce
)
{
T
inp_val
[
token_per_reduce
][
ele_per_thread
];
for
(
int
i
=
0
;
i
<
token_per_reduce
&&
(
token_id
+
i
)
<
from_len
;
i
++
)
{
BlockLoad
(
ts_load
).
Load
(
inp
+
(
token_id
+
i
)
*
to_len
,
inp_val
[
i
],
to_len
,
REDUCE_FLOAT_INF_NEG
);
}
/* step 1. compute max */
// thread local max
float
val
[
token_per_reduce
][
ele_per_thread
];
float
l_max
[
token_per_reduce
];
for
(
int
i
=
0
;
i
<
token_per_reduce
;
i
++
)
{
l_max
[
i
]
=
REDUCE_FLOAT_INF_NEG
;
for
(
int
j
=
0
;
j
<
ele_per_thread
;
j
++
)
{
if
(
attn_mask
)
{
val
[
i
][
j
]
=
(
float
)
inp_val
[
i
][
j
]
+
(
float
)
mval
[
j
];
}
else
{
if
(
mask_future
&&
ele_per_thread
*
threadIdx
.
x
+
j
>
token_id
+
i
)
{
val
[
i
][
j
]
=
REDUCE_FLOAT_INF_NEG
;
}
else
{
val
[
i
][
j
]
=
(
float
)
inp_val
[
i
][
j
];
}
}
l_max
[
i
]
=
fmaxf
(
l_max
[
i
],
val
[
i
][
j
]);
}
}
// block reduce max
blockReduce
<
ReduceType
::
kMax
,
token_per_reduce
>
(
l_max
);
// write shared
__shared__
float
s_max
[
token_per_reduce
];
if
(
threadIdx
.
x
==
0
)
{
for
(
int
i
=
0
;
i
<
token_per_reduce
;
i
++
)
{
s_max
[
i
]
=
l_max
[
i
];
}
}
__syncthreads
();
/* step 2. compute sum */
// thread local sum
float
l_sum
[
token_per_reduce
];
for
(
int
i
=
0
;
i
<
token_per_reduce
;
i
++
)
{
l_sum
[
i
]
=
0.
f
;
for
(
int
j
=
0
;
j
<
ele_per_thread
;
j
++
)
{
val
[
i
][
j
]
=
__expf
(
val
[
i
][
j
]
-
s_max
[
i
]);
l_sum
[
i
]
+=
val
[
i
][
j
];
}
}
// block reduce sum
blockReduce
<
ReduceType
::
kSum
,
token_per_reduce
>
(
l_sum
);
// write shared
__shared__
float
s_sum
[
token_per_reduce
];
if
(
threadIdx
.
x
==
0
)
{
for
(
int
i
=
0
;
i
<
token_per_reduce
;
i
++
)
{
s_sum
[
i
]
=
__fdividef
(
1.0
f
,
l_sum
[
i
]
+
EPSILON
);
}
}
__syncthreads
();
/* step 3. compute final result */
for
(
int
i
=
0
;
i
<
token_per_reduce
&&
(
token_id
+
i
)
<
from_len
;
i
++
)
{
for
(
int
j
=
0
;
j
<
ele_per_thread
;
j
++
)
{
inp_val
[
i
][
j
]
=
(
T
)(
val
[
i
][
j
]
*
s_sum
[
i
]);
}
BlockStore
(
ts_store
).
Store
(
inp
+
(
token_id
+
i
)
*
to_len
,
inp_val
[
i
],
to_len
);
}
}
// blockIdx.x
}
template <typename T, int block_dim, int ele_per_thread>
__global__ void ker_attn_softmax_lt32(T *inp, const T *attn_mask, int from_len,
                                      int to_len, bool mask_future) {
  int batch_id = blockIdx.y;
  int head_id = blockIdx.z;
  const int nhead = gridDim.z;
  const int token_per_reduce = 1;
  typedef cub::BlockLoad<T, block_dim, ele_per_thread,
                         cub::BLOCK_LOAD_VECTORIZE>
      BlockLoad;
  __shared__ typename BlockLoad::TempStorage ts_load;
  typedef cub::BlockStore<T, block_dim, ele_per_thread,
                          cub::BLOCK_STORE_VECTORIZE>
      BlockStore;
  __shared__ typename BlockStore::TempStorage ts_store;

  T mval[ele_per_thread];
  if (attn_mask) {
    attn_mask += batch_id * to_len;
    BlockLoad(ts_load).Load(attn_mask, mval, to_len, REDUCE_FLOAT_INF_NEG);
  }

  inp += flat_3dim(batch_id, head_id, 0, nhead, from_len * to_len);
  for (int token_id = blockIdx.x * token_per_reduce; token_id < from_len;
       token_id += gridDim.x * token_per_reduce) {
    T inp_val[token_per_reduce][ele_per_thread];
    for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) {
      BlockLoad(ts_load).Load(inp + (token_id + i) * to_len, inp_val[i], to_len,
                              REDUCE_FLOAT_INF_NEG);
    }

    /* step 1. compute max */
    // thread local max
    float val[token_per_reduce][ele_per_thread];
    float l_max[token_per_reduce];
    for (int i = 0; i < token_per_reduce; i++) {
      l_max[i] = REDUCE_FLOAT_INF_NEG;
      for (int j = 0; j < ele_per_thread; j++) {
        if (attn_mask) {
          val[i][j] = (float)inp_val[i][j] + (float)mval[j];
        } else {
          if (mask_future && ele_per_thread * threadIdx.x + j > token_id + i) {
            val[i][j] = REDUCE_FLOAT_INF_NEG;
          } else {
            val[i][j] = (float)inp_val[i][j];
          }
        }
        l_max[i] = fmaxf(l_max[i], val[i][j]);
      }
    }
    // warp reduce max
    warpReduce<ReduceType::kMax, token_per_reduce>(l_max);

    /* step 2. compute sum */
    // thread local sum
    float l_sum[token_per_reduce];
    for (int i = 0; i < token_per_reduce; i++) {
      l_sum[i] = 0.f;
      for (int j = 0; j < ele_per_thread; j++) {
        val[i][j] = __expf(val[i][j] - l_max[i]);
        l_sum[i] += val[i][j];
      }
    }
    // warp reduce sum
    warpReduce<ReduceType::kSum, token_per_reduce>(l_sum);

    /* step 3. compute final result */
    for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) {
      l_sum[i] = __fdividef(1.0f, l_sum[i] + EPSILON);
      for (int j = 0; j < ele_per_thread; j++) {
        inp_val[i][j] = (T)(val[i][j] * l_sum[i]);
      }
      BlockStore(ts_store).Store(inp + (token_id + i) * to_len, inp_val[i],
                                 to_len);
    }
  }  // blockIdx.x
}
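// Editor's note (illustrative sketch, not part of the original file): both
// kernels above implement the standard numerically stable softmax in three
// steps: take the row max, exponentiate and sum, then scale by 1/sum. A scalar
// host reference along these lines is convenient when unit-testing them; the
// function name and the plain-float interface are assumptions for the example.
#include <cmath>

inline void reference_attn_softmax(float *scores, const float *mask, int rows,
                                   int to_len, float eps = 1e-8f) {
  for (int r = 0; r < rows; ++r) {
    float *row = scores + r * to_len;
    float row_max = -1e30f;
    for (int j = 0; j < to_len; ++j) {
      if (mask) row[j] += mask[j];            // additive attention mask
      row_max = row[j] > row_max ? row[j] : row_max;
    }
    float sum = 0.f;
    for (int j = 0; j < to_len; ++j) {
      row[j] = std::exp(row[j] - row_max);    // subtract max for stability
      sum += row[j];
    }
    for (int j = 0; j < to_len; ++j) row[j] /= (sum + eps);
  }
}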
/*
attn_mask!=nullptr for enc-self-attn and enc-dec-attn
attn_mask=nullptr and mask_future=true for dec-self-attn training
attn_mask=nullptr and mask_future=false for dec-self-attn infer
*/
template <>
void launch_attn_softmax<float>(float *inp, const float *attn_mask,
                                int batch_size, int nhead, int from_len,
                                int to_len, bool mask_future,
                                cudaStream_t stream) {
  dim3 grid_dim(1, batch_size, nhead);
  if (to_len <= 32) {
    ker_attn_softmax_lt32<float, 32, 1><<<grid_dim, 32, 0, stream>>>(
        inp, attn_mask, from_len, to_len, mask_future);
  } else if (to_len <= 64) {
    ker_attn_softmax_lt32<float, 32, 2><<<grid_dim, 32, 0, stream>>>(
        inp, attn_mask, from_len, to_len, mask_future);
  } else if (to_len <= 128) {
    grid_dim.x = 16;
    ker_attn_softmax<float, 64, 2><<<grid_dim, 64, 0, stream>>>(
        inp, attn_mask, from_len, to_len, mask_future);
  } else if (to_len <= 256) {
    grid_dim.x = 32;
    ker_attn_softmax<float, 128, 2><<<grid_dim, 128, 0, stream>>>(
        inp, attn_mask, from_len, to_len, mask_future);
  } else if (to_len <= 512) {
    grid_dim.x = 64;
    ker_attn_softmax<float, 256, 2><<<grid_dim, 256, 0, stream>>>(
        inp, attn_mask, from_len, to_len, mask_future);
  } else {
    throw std::runtime_error(
        "Sequence length greater than 512 is currently not supported");
  }
}
template <>
void launch_attn_softmax<__half>(__half *inp, const __half *attn_mask,
                                 int batch_size, int nhead, int from_len,
                                 int to_len, bool mask_future,
                                 cudaStream_t stream) {
  dim3 grid_dim(1, batch_size, nhead);
  if (to_len <= 32) {
    ker_attn_softmax_lt32<__half, 32, 1><<<grid_dim, 32, 0, stream>>>(
        inp, attn_mask, from_len, to_len, mask_future);
  } else if (to_len <= 64) {
    ker_attn_softmax_lt32<__half, 32, 2><<<grid_dim, 32, 0, stream>>>(
        inp, attn_mask, from_len, to_len, mask_future);
  } else if (to_len <= 128) {
    grid_dim.x = 8;
    ker_attn_softmax<__half, 64, 2><<<grid_dim, 64, 0, stream>>>(
        inp, attn_mask, from_len, to_len, mask_future);
  } else if (to_len <= 256) {
    grid_dim.x = 16;
    ker_attn_softmax<__half, 128, 2><<<grid_dim, 128, 0, stream>>>(
        inp, attn_mask, from_len, to_len, mask_future);
  } else if (to_len <= 512) {
    grid_dim.x = 32;
    ker_attn_softmax<__half, 256, 2><<<grid_dim, 256, 0, stream>>>(
        inp, attn_mask, from_len, to_len, mask_future);
  } else {
    throw std::runtime_error(
        "Sequence length greater than 512 is currently not supported");
  }
}
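// Editor's note (illustrative sketch, not part of the original file): both
// launchers above pick the kernel configuration so that block_dim *
// ele_per_thread covers one attention row of length to_len, since
// cub::BlockLoad hands each of the block_dim threads ele_per_thread
// consecutive elements. A host-side summary of that mapping; the struct and
// function names are assumptions for the example.
struct SoftmaxLaunchCfg {
  int block_dim;       // threads per block
  int ele_per_thread;  // elements handled by each thread
  bool use_lt32;       // true -> ker_attn_softmax_lt32 (warp-level reduce)
};

inline SoftmaxLaunchCfg pick_softmax_cfg(int to_len) {
  if (to_len <= 32) return {32, 1, true};
  if (to_len <= 64) return {32, 2, true};
  if (to_len <= 128) return {64, 2, false};
  if (to_len <= 256) return {128, 2, false};
  if (to_len <= 512) return {256, 2, false};
  return {-1, -1, false};  // longer rows are rejected by the launchers above
}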
/**
@brief: ker_attn_softmax_bw
Softmax backward in self attention.
@thread
gridDim.x = batch_size * nhead * seq_len / warps_per_block
blockDim.x = WARP_SIZE
blockDim.y = warps_per_block
@param
grad: [batch_size, nhead, seq_len, seq_len], output grad.
output: [batch_size, nhead, seq_len, seq_len], output of softmax forward.
*/
template <typename T, int ITERATIONS>
__global__ void ker_attn_softmax_bw(T *grad, const T *inp,
                                    int softmax_length) {
  int batch_idx = blockIdx.x * blockDim.y + threadIdx.y;
  int offset = batch_idx * softmax_length + threadIdx.x;

  grad += offset;
  inp += offset;

  T grad_reg[ITERATIONS];
  T inp_reg[ITERATIONS];
  float sum = 0.0;

#pragma unroll
  for (int i = 0; i < ITERATIONS; ++i) {
    int curr_idx = threadIdx.x + i * WARP_SIZE;
    if (curr_idx < softmax_length) {
      grad_reg[i] = grad[i * WARP_SIZE];
      inp_reg[i] = inp[i * WARP_SIZE];
      sum += (float)grad_reg[i] * (float)inp_reg[i];
    }
  }

  cg::thread_block b = cg::this_thread_block();
  cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

  for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_xor(sum, i);

#pragma unroll
  for (int i = 0; i < ITERATIONS; ++i) {
    int curr_idx = threadIdx.x + i * WARP_SIZE;
    if (curr_idx < softmax_length)
      grad[i * WARP_SIZE] =
          (T)((float)inp_reg[i] * ((float)grad_reg[i] - sum));
  }
}
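// Editor's note (illustrative sketch, not part of the original file):
// ker_attn_softmax_bw applies the softmax backward identity
//   d_x[j] = y[j] * (d_y[j] - sum_k d_y[k] * y[k])
// where y is the softmax output saved from the forward pass. A scalar host
// reference for one row, useful for checking the kernel; the function name is
// an assumption for the example.
inline void reference_attn_softmax_bw(float *grad, const float *softmax_out,
                                      int softmax_length) {
  float dot = 0.f;
  for (int j = 0; j < softmax_length; ++j) dot += grad[j] * softmax_out[j];
  for (int j = 0; j < softmax_length; ++j)
    grad[j] = softmax_out[j] * (grad[j] - dot);  // overwrite grad in place
}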
template <typename T>
void launch_attn_softmax_bw(T *out_grad, const T *soft_inp, int rows,
                            int softmax_len, cudaStream_t stream) {
  const int warps_per_block = 4;
  // rows = batch_size * nhead * from_len
  dim3 grid_dim(rows / warps_per_block);
  dim3 block_dim(WARP_SIZE, warps_per_block);

  if (softmax_len <= 32)
    ker_attn_softmax_bw<T, 1>
        <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
  else if (softmax_len <= 64)
    ker_attn_softmax_bw<T, 2>
        <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
  else if (softmax_len <= 128)
    ker_attn_softmax_bw<T, 4>
        <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
  else if (softmax_len <= 256)
    ker_attn_softmax_bw<T, 8>
        <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
  else if (softmax_len <= 384)
    ker_attn_softmax_bw<T, 12>
        <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
  else if (softmax_len <= 512)
    ker_attn_softmax_bw<T, 16>
        <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
  else if (softmax_len <= 768)
    ker_attn_softmax_bw<T, 24>
        <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
  else if (softmax_len <= 1024)
    ker_attn_softmax_bw<T, 32>
        <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
  else if (softmax_len <= 2048)
    ker_attn_softmax_bw<T, 64>
        <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
  else
    throw std::runtime_error(
        std::string(
            "Special sequence length found in softmax backward, seq_len: ") +
        std::to_string(softmax_len));
}

template void launch_attn_softmax_bw<__half>(__half *out_grad,
                                             const __half *soft_inp, int rows,
                                             int softmax_len,
                                             cudaStream_t stream);
template void launch_attn_softmax_bw<float>(float *out_grad,
                                            const float *soft_inp, int rows,
                                            int softmax_len,
                                            cudaStream_t stream);
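// Editor's note (illustrative sketch, not part of the original file):
// launch_attn_softmax_bw launches rows / warps_per_block blocks with integer
// division, so it implicitly assumes rows (= batch_size * nhead * from_len)
// is a multiple of warps_per_block (4); otherwise the last few rows would not
// be covered. A caller-side sanity check along these lines may be useful; the
// helper name is an assumption for the example.
#include <cassert>

inline void check_softmax_bw_rows(int batch_size, int nhead, int from_len) {
  const int warps_per_block = 4;  // must match launch_attn_softmax_bw above
  int rows = batch_size * nhead * from_len;
  assert(rows % warps_per_block == 0 &&
         "softmax backward expects rows divisible by warps_per_block");
}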
colossalai/kernel/cuda_native/csrc/kernels/transform_kernels.cu
0 → 100644
#include <cub/block/block_load.cuh>
#include <cub/block/block_scan.cuh>
#include <cub/block/block_store.cuh>
#include "kernels.h"
using namespace cub;
/**
@brief: transform_0213
Split the attention heads and reshape input
during the backward pass of encoder self-attention
@thread
gridDim.x = batch_size
gridDim.y = seq_len
blockDim.x = min(hidden_dim, MAX_THREADS)
@param
input: [batch_size, seq_len, hidden_dim]
output: [batch_size, nhead, seq_len, head_dim]
batch_size: the size of the current batch
seq_len: the sequence length of the current batch
hidden_dim: dim of the hidden tensor
nhead: number of attention heads
*/
template <typename T>
__global__ void transform_0213(T *output, const T *input, int hidden_dim,
                               int head_dim);

template <>
__global__ void transform_0213<float>(float *output, const float *input,
                                      int hidden_dim, int head_dim) {
  int batch_id = blockIdx.x;
  int token_id = blockIdx.y;
  int seq_len = gridDim.y;
  int nhead = hidden_dim / head_dim;

  // [b, s, h]
  int src_offset = flat_3dim(batch_id, token_id, 0, seq_len, hidden_dim);
  // [b, nh, s, ad]
  int trg_offset =
      flat_4dim(batch_id, 0, token_id, 0, nhead, seq_len, head_dim);

  const float4 *input4 = reinterpret_cast<const float4 *>(input);
  float4 *res4 = reinterpret_cast<float4 *>(output);
  float4 vinput4;

  for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) {
    vinput4 = input4[src_offset + i];
    int head_id = i / head_dim;
    int dim_id = i % head_dim;
    int cur_trg_offset = flat_3dim(head_id, 0, dim_id, seq_len, head_dim);
    res4[trg_offset + cur_trg_offset] = vinput4;
  }
}
template <>
__global__ void transform_0213<__half>(__half *output, const __half *input,
                                       int hidden_dim, int head_dim) {
  int batch_id = blockIdx.x;
  int token_id = blockIdx.y;
  int seq_len = gridDim.y;
  int nhead = hidden_dim / head_dim;

  // [b, s, h]
  int src_offset = flat_3dim(batch_id, token_id, 0, seq_len, hidden_dim);
  // [b, nh, s, ad]
  int trg_offset =
      flat_4dim(batch_id, 0, token_id, 0, nhead, seq_len, head_dim);

  const float4 *input4 = reinterpret_cast<const float4 *>(input);
  float4 *res4 = reinterpret_cast<float4 *>(output);
  float4 vinput4;

  for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) {
    vinput4 = input4[src_offset + i];
    int head_id = i / head_dim;
    int dim_id = i % head_dim;
    int cur_trg_offset = flat_3dim(head_id, 0, dim_id, seq_len, head_dim);
    res4[trg_offset + cur_trg_offset] = vinput4;
  }
}
// [b, s, h] -> [b, nh, s, ad]
template <>
void launch_transform_0213<float>(float *output, const float *input,
                                  int batch_size, int seq_len, int hidden_dim,
                                  int nhead, cudaStream_t stream) {
  hidden_dim >>= 2;
  int head_dim = hidden_dim / nhead;

  dim3 grid_dim(batch_size, seq_len);
  dim3 block_dim(min(hidden_dim, MAX_THREADS));

  transform_0213<float><<<grid_dim, block_dim, 0, stream>>>(
      output, input, hidden_dim, head_dim);
}

template <>
void launch_transform_0213<__half>(__half *output, const __half *input,
                                   int batch_size, int seq_len, int hidden_dim,
                                   int nhead, cudaStream_t stream) {
  hidden_dim >>= 3;
  int head_dim = hidden_dim / nhead;

  dim3 grid_dim(batch_size, seq_len);
  dim3 block_dim(min(hidden_dim, MAX_THREADS));

  transform_0213<__half><<<grid_dim, block_dim, 0, stream>>>(
      output, input, hidden_dim, head_dim);
}
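// Editor's note (illustrative sketch, not part of the original file):
// transform_0213 is a pure index permutation from [batch, seq, hidden] to
// [batch, nhead, seq, head_dim], with float4/__half2 vectorization layered on
// top. A scalar host reference for validating the permutation; the function
// name is an assumption for the example.
inline void reference_transform_0213(const float *in, float *out, int batch,
                                     int seq, int nhead, int head_dim) {
  int hidden = nhead * head_dim;
  for (int b = 0; b < batch; ++b)
    for (int s = 0; s < seq; ++s)
      for (int h = 0; h < nhead; ++h)
        for (int d = 0; d < head_dim; ++d) {
          int src = (b * seq + s) * hidden + h * head_dim + d;            // [b, s, h]
          int dst = ((b * nhead + h) * seq + s) * head_dim + d;           // [b, nh, s, ad]
          out[dst] = in[src];
        }
}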
/**
@brief: bias_add_transform_20314
Add bias to input, transform from
[0, 1, 2, 3, 4] to [2, 0, 3, 1, 4]
@thread
gridDim.x = dim_0
gridDim.y = dim_1
gridDim.z = dim_2
blockDim.x = min(dim_3 * dim_4, MAX_THREADS)
@param
input: [dim_0, dim_1, dim_2, dim_3, dim_4]
bias: [dim_2, dim_3, dim_4]
output: [dim_2, dim_0, dim_3, dim_1, dim_4]
*/
template <typename T>
__global__ void bias_add_transform_20314(T *output, const T *input,
                                         const T *bias, int dim_3, int dim_4);

template <>
__global__ void bias_add_transform_20314<float>(float *output,
                                                const float *input,
                                                const float *bias, int dim_3,
                                                int dim_4) {
  int id0 = blockIdx.x;
  int id1 = blockIdx.y;
  int id2 = blockIdx.z;
  int dim_0 = gridDim.x;
  int dim_1 = gridDim.y;
  int dim_2 = gridDim.z;
  int dim_34 = dim_3 * dim_4;

  int src_offset = flat_4dim(id0, id1, id2, 0, dim_1, dim_2, dim_34);
  int trg_offset = flat_5dim(id2, id0, 0, id1, 0, dim_0, dim_3, dim_1, dim_4);
  int bias_offset = flat_2dim(id2, 0, dim_34);

  const float4 *qkv4 = reinterpret_cast<const float4 *>(input);
  const float4 *bias4 = reinterpret_cast<const float4 *>(bias);
  float4 *res4 = reinterpret_cast<float4 *>(output);
  float4 vqkv4;
  float4 vbias4;
  float4 vres4;

  for (std::size_t i = threadIdx.x; i < dim_34; i += blockDim.x) {
    vqkv4 = qkv4[src_offset + i];
    vbias4 = bias4[bias_offset + i];
    vres4.x = vqkv4.x + vbias4.x;
    vres4.y = vqkv4.y + vbias4.y;
    vres4.z = vqkv4.z + vbias4.z;
    vres4.w = vqkv4.w + vbias4.w;

    int id3 = i / dim_4;
    int id4 = i % dim_4;
    int cur_trg_offset = flat_3dim(id3, 0, id4, dim_1, dim_4);
    res4[trg_offset + cur_trg_offset] = vres4;
  }
}

template <>
__global__ void bias_add_transform_20314<__half>(__half *output,
                                                 const __half *input,
                                                 const __half *bias, int dim_3,
                                                 int dim_4) {
  int id0 = blockIdx.x;
  int id1 = blockIdx.y;
  int id2 = blockIdx.z;
  int dim_0 = gridDim.x;
  int dim_1 = gridDim.y;
  int dim_2 = gridDim.z;
  int dim_34 = dim_3 * dim_4;

  int src_offset = flat_4dim(id0, id1, id2, 0, dim_1, dim_2, dim_34);
  int trg_offset = flat_5dim(id2, id0, 0, id1, 0, dim_0, dim_3, dim_1, dim_4);
  int bias_offset = flat_2dim(id2, 0, dim_34);

  const float4 *qkv4 = reinterpret_cast<const float4 *>(input);
  const float4 *bias4 = reinterpret_cast<const float4 *>(bias);
  float4 *res4 = reinterpret_cast<float4 *>(output);
  float4 vqkv4;
  float4 vbias4;
  float4 vres4;

  __half2 *h2_qkv = reinterpret_cast<__half2 *>(&vqkv4);
  __half2 *h2_bias = reinterpret_cast<__half2 *>(&vbias4);
  __half2 *h2_res = reinterpret_cast<__half2 *>(&vres4);

  for (std::size_t i = threadIdx.x; i < dim_34; i += blockDim.x) {
    vqkv4 = qkv4[src_offset + i];
    vbias4 = bias4[bias_offset + i];
    h2_res[0] = __hadd2(h2_qkv[0], h2_bias[0]);
    h2_res[1] = __hadd2(h2_qkv[1], h2_bias[1]);
    h2_res[2] = __hadd2(h2_qkv[2], h2_bias[2]);
    h2_res[3] = __hadd2(h2_qkv[3], h2_bias[3]);

    int id3 = i / dim_4;
    int id4 = i % dim_4;
    int cur_trg_offset = flat_3dim(id3, 0, id4, dim_1, dim_4);
    res4[trg_offset + cur_trg_offset] = vres4;
  }
}
// [b, s, 3, h] -> [3, b, nh, s, ad]
template <>
void launch_bias_add_transform_20314<float>(float *output, const float *input,
                                            const float *bias, int dim_0,
                                            int dim_1, int dim_2, int dim_3,
                                            int dim_4, cudaStream_t stream) {
  dim_4 >>= 2;

  dim3 grid_dim(dim_0, dim_1, dim_2);
  dim3 block_dim(min(dim_3 * dim_4, MAX_THREADS));

  bias_add_transform_20314<float><<<grid_dim, block_dim, 0, stream>>>(
      output, input, bias, dim_3, dim_4);
}

template <>
void launch_bias_add_transform_20314<__half>(__half *output,
                                             const __half *input,
                                             const __half *bias, int dim_0,
                                             int dim_1, int dim_2, int dim_3,
                                             int dim_4, cudaStream_t stream) {
  dim_4 >>= 3;

  dim3 grid_dim(dim_0, dim_1, dim_2);
  dim3 block_dim(min(dim_3 * dim_4, MAX_THREADS));

  bias_add_transform_20314<__half><<<grid_dim, block_dim, 0, stream>>>(
      output, input, bias, dim_3, dim_4);
}
/**
@brief: transform4d_0213
Reshape the input matrix to merge the heads
@thread
gridDim.x = (num_all + max_block_thread - 1) / max_block_thread
blockDim.x = max_block_thread
@param
input: [trans_count, batch_size, nhead, seq_len, head_dim]
output: [batch_size, seq_len, trans_count, nhead, head_dim]
batch_size: the size of the current batch
seq_len: the sequence length of the current batch
hidden_dim: dim of the hidden tensor
nhead: number of attention heads
trans_count: 1 or 3, the number of matrices to be transformed
*/
template <typename T>
__global__ void transform4d_0213(T *output, const T *input, int batch_size,
                                 int seq_len, int trans_count, int nhead,
                                 int head_dim, int num_all) {
  int offset = blockIdx.x * blockDim.x + threadIdx.x;
  if (offset >= num_all) {
    return;
  }
  int trans_id, batch_id, head_id, token_id, dim_id;
  decompose_5dim(offset, batch_size, nhead, seq_len, head_dim, &trans_id,
                 &batch_id, &head_id, &token_id, &dim_id);
  // [b, s, tc, nh, ad]
  int trg_offset = flat_5dim(batch_id, token_id, trans_id, head_id, dim_id,
                             seq_len, trans_count, nhead, head_dim);

  const float4 *input4 = reinterpret_cast<const float4 *>(input);
  float4 *res4 = reinterpret_cast<float4 *>(output);
  res4[trg_offset] = input4[offset];
}
// [tc, b, nh, s, ad] -> [b, s, tc, nh, ad]
template <>
void launch_transform4d_0213<float>(float *output, const float *input,
                                    int batch_size, int seq_len,
                                    int hidden_dim, int nhead, int trans_count,
                                    cudaStream_t stream) {
  hidden_dim >>= 2;
  int head_dim = hidden_dim / nhead;
  int num_all = batch_size * seq_len * trans_count * hidden_dim;
  int nblock = (num_all + MAX_THREADS - 1) / MAX_THREADS;

  transform4d_0213<float><<<nblock, MAX_THREADS, 0, stream>>>(
      output, input, batch_size, seq_len, trans_count, nhead, head_dim,
      num_all);
}

template <>
void launch_transform4d_0213<__half>(__half *output, const __half *input,
                                     int batch_size, int seq_len,
                                     int hidden_dim, int nhead,
                                     int trans_count, cudaStream_t stream) {
  hidden_dim >>= 3;
  int head_dim = hidden_dim / nhead;
  int num_all = batch_size * seq_len * trans_count * hidden_dim;
  int nblock = (num_all + MAX_THREADS - 1) / MAX_THREADS;

  transform4d_0213<__half><<<nblock, MAX_THREADS, 0, stream>>>(
      output, input, batch_size, seq_len, trans_count, nhead, head_dim,
      num_all);
}
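// Editor's note (illustrative sketch, not part of the original file): every
// launcher in this file reinterprets its buffers as float4 (fp32) or packed
// __half2 (fp16) and shifts hidden_dim / dim_4 right by 2 or 3 accordingly, so
// the hidden size must be divisible by 4 (fp32) or 8 (fp16), and the per-head
// size must remain an integer after packing. A caller-side check along these
// lines may be useful; the helper name is an assumption for the example.
#include <cassert>

inline void check_transform_vectorization(int hidden_size, int nhead,
                                          bool is_fp16) {
  int width = is_fp16 ? 8 : 4;                 // elements per vectorized load
  assert(hidden_size % width == 0);            // required by the >>= 2 / >>= 3
  assert((hidden_size / width) % nhead == 0);  // packed head_dim must be exact
}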
colossalai/kernel/cuda_native/csrc/layer_norm_cuda.cpp
0 → 100644
/*This code from NVIDIA apex:
* https://github.com/NVIDIA/apex
* with minor changes. */
#include <torch/extension.h>
#include <vector>
#include <cassert>
#include "compat.h"
namespace {

void compute_n1_n2(at::Tensor input, at::IntArrayRef normalized_shape, int &n1,
                   int &n2) {
  int idiff = input.ndimension() - normalized_shape.size();
  n2 = 1;
  for (int i = 0; i < (int)normalized_shape.size(); ++i) {
    assert(input.sizes()[i + idiff] == normalized_shape[i]);
    n2 *= normalized_shape[i];
  }
  n1 = 1;
  for (int i = 0; i < idiff; ++i) {
    n1 *= input.sizes()[i];
  }
}
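// Editor's note: compute_n1_n2 folds the input into a 2-D view, with
// n1 = product of the leading (non-normalized) dims and n2 = product of
// normalized_shape. For example, an input of shape [8, 512, 1024] with
// normalized_shape = [1024] gives n1 = 8 * 512 = 4096 rows of n2 = 1024
// elements, and the CUDA kernels normalize each of the n1 rows independently.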
void check_args(at::IntArrayRef normalized_shape, at::Tensor gamma,
                at::Tensor beta) {
  TORCH_CHECK(!gamma.defined() || gamma.sizes().equals(normalized_shape));
  TORCH_CHECK(!beta.defined() || beta.sizes().equals(normalized_shape));
}

void check_args(at::Tensor input, at::IntArrayRef normalized_shape, int &n1,
                int &n2) {
  int64_t normalized_ndim = normalized_shape.size();

  if (normalized_ndim < 1) {
    std::stringstream ss;
    ss << "Expected normalized_shape to be at least 1-dimensional, i.e., "
       << "containing at least one element, but got normalized_shape="
       << normalized_shape;
    throw std::runtime_error(ss.str());
  }

  auto input_shape = input.sizes();
  auto input_ndim = input.dim();

  if (input_ndim < normalized_ndim ||
      !input_shape.slice(input_ndim - normalized_ndim)
           .equals(normalized_shape)) {
    std::stringstream ss;
    ss << "Given normalized_shape=" << normalized_shape
       << ", expected input with shape [*";
    for (auto size : normalized_shape) {
      ss << ", " << size;
    }
    ss << "], but got input of size" << input_shape;
    throw std::runtime_error(ss.str());
  }

  compute_n1_n2(input, normalized_shape, n1, n2);
}

void check_args(at::Tensor input, at::IntArrayRef normalized_shape,
                at::Tensor gamma, at::Tensor beta, int &n1, int &n2) {
  check_args(input, normalized_shape, n1, n2);
  check_args(normalized_shape, gamma, beta);
}

}  // namespace
void cuda_layer_norm(at::Tensor *output, at::Tensor *mean, at::Tensor *invvar,
                     at::Tensor *input, int n1, int n2,
                     at::IntArrayRef normalized_shape, at::Tensor *gamma,
                     at::Tensor *beta, double epsilon);
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
std::vector<at::Tensor> layer_norm_affine(at::Tensor input,
                                          at::IntArrayRef normalized_shape,
                                          at::Tensor gamma, at::Tensor beta,
                                          double epsilon) {
  CHECK_INPUT(input);
  CHECK_INPUT(gamma);
  CHECK_INPUT(beta);
  int n1, n2;
  check_args(input, normalized_shape, gamma, beta, n1, n2);

  at::Tensor output =
      at::empty_like(input, gamma.options().dtype(gamma.scalar_type()));
  at::Tensor mean =
      at::empty({n1}, input.options().dtype(at::ScalarType::Float));
  at::Tensor invvar = at::empty_like(mean);

  cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2, normalized_shape,
                  &gamma, &beta, epsilon);

  return {output, mean, invvar};
}
void cuda_layer_norm_gradient(at::Tensor *dout, at::Tensor *mean,
                              at::Tensor *invvar, at::Tensor *input, int n1,
                              int n2, at::IntArrayRef normalized_shape,
                              at::Tensor *gamma, at::Tensor *beta,
                              double epsilon, at::Tensor *grad_input,
                              at::Tensor *grad_gamma, at::Tensor *grad_beta);

std::vector<at::Tensor> layer_norm_gradient_affine(
    at::Tensor dout, at::Tensor mean, at::Tensor invvar, at::Tensor input,
    at::IntArrayRef normalized_shape, at::Tensor gamma, at::Tensor beta,
    double epsilon) {
  CHECK_INPUT(dout);
  CHECK_INPUT(mean);
  CHECK_INPUT(invvar);
  CHECK_INPUT(input);
  CHECK_INPUT(gamma);
  CHECK_INPUT(beta);
  int n1, n2;
  check_args(input, normalized_shape, gamma, beta, n1, n2);

  at::Tensor grad_input = at::empty_like(input);
  at::Tensor grad_gamma = at::empty_like(gamma);
  at::Tensor grad_beta = at::empty_like(beta);

  cuda_layer_norm_gradient(&dout, &mean, &invvar, &input, n1, n2,
                           normalized_shape, &gamma, &beta, epsilon,
                           &grad_input, &grad_gamma, &grad_beta);

  return {grad_input, grad_gamma, grad_beta};
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward_affine", &layer_norm_affine, "LayerNorm forward (CUDA)");
  m.def("backward_affine", &layer_norm_gradient_affine,
        "LayerNorm backward (CUDA)");
}
\ No newline at end of file
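// Editor's note (illustrative sketch, not part of the extension above):
// forward_affine returns {output, mean, invvar}, and backward_affine consumes
// the saved mean and inverse standard deviation. A scalar reference of the
// forward computation is useful when checking the CUDA path; the function
// name and plain-float interface are assumptions for the example.
#include <cmath>

inline void reference_layer_norm_affine(const float *x, int n1, int n2,
                                        const float *gamma, const float *beta,
                                        float eps, float *y, float *mean,
                                        float *invvar) {
  for (int i = 0; i < n1; ++i) {
    const float *row = x + i * n2;
    float mu = 0.f, var = 0.f;
    for (int j = 0; j < n2; ++j) mu += row[j];
    mu /= n2;
    for (int j = 0; j < n2; ++j) var += (row[j] - mu) * (row[j] - mu);
    var /= n2;                              // biased variance, as in the kernel
    float rstd = 1.f / std::sqrt(var + eps);
    mean[i] = mu;
    invvar[i] = rstd;
    for (int j = 0; j < n2; ++j)
      y[i * n2 + j] = (row[j] - mu) * rstd * gamma[j] + beta[j];
  }
}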
colossalai/kernel/cuda_native/csrc/layer_norm_cuda_kernel.cu
0 → 100644
/*This code from NVIDIA apex:
* https://github.com/NVIDIA/apex
* with minor changes. */
#include "ATen/ATen.h"
#include "ATen/AccumulateType.h"
#include "ATen/cuda/CUDAContext.h"
#include <THC/THCDeviceUtils.cuh>
#include <cuda.h>
#include <cuda_runtime.h>
#include "type_shim.h"
template
<
typename
U
>
__device__
void
cuWelfordOnlineSum
(
const
U
curr
,
U
&
mu
,
U
&
sigma2
,
U
&
count
)
{
count
=
count
+
U
(
1
);
U
delta
=
curr
-
mu
;
U
lmean
=
mu
+
delta
/
count
;
mu
=
lmean
;
U
delta2
=
curr
-
lmean
;
sigma2
=
sigma2
+
delta
*
delta2
;
}
template
<
typename
U
>
__device__
void
cuChanOnlineSum
(
const
U
muB
,
const
U
sigma2B
,
const
U
countB
,
U
&
mu
,
U
&
sigma2
,
U
&
count
)
{
U
delta
=
muB
-
mu
;
U
nA
=
count
;
U
nB
=
countB
;
count
=
count
+
countB
;
U
nX
=
count
;
if
(
nX
>
U
(
0
))
{
nA
=
nA
/
nX
;
nB
=
nB
/
nX
;
mu
=
nA
*
mu
+
nB
*
muB
;
sigma2
=
sigma2
+
sigma2B
+
delta
*
delta
*
nA
*
nB
*
nX
;
}
else
{
mu
=
U
(
0
);
sigma2
=
U
(
0
);
}
}
template
<
typename
T
,
typename
U
>
__device__
void
cuWelfordMuSigma2
(
const
T
*
__restrict__
vals
,
const
int
n1
,
const
int
n2
,
const
int
i1
,
U
&
mu
,
U
&
sigma2
,
U
*
buf
)
{
// Assumptions:
// 1) blockDim.x == warpSize
// 2) Tensor is contiguous
// 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available.
//
// compute variance and mean over n2
U
count
=
U
(
0
);
mu
=
U
(
0
);
sigma2
=
U
(
0
);
if
(
i1
<
n1
)
{
// one warp normalizes one n1 index,
// synchronization is implicit
// initialize with standard Welford algorithm
const
int
numx
=
blockDim
.
x
*
blockDim
.
y
;
const
int
thrx
=
threadIdx
.
x
+
threadIdx
.
y
*
blockDim
.
x
;
const
T
*
lvals
=
vals
+
i1
*
n2
;
int
l
=
4
*
thrx
;
for
(;
l
+
3
<
n2
;
l
+=
4
*
numx
)
{
for
(
int
k
=
0
;
k
<
4
;
++
k
)
{
U
curr
=
static_cast
<
U
>
(
lvals
[
l
+
k
]);
cuWelfordOnlineSum
<
U
>
(
curr
,
mu
,
sigma2
,
count
);
}
}
for
(;
l
<
n2
;
++
l
)
{
U
curr
=
static_cast
<
U
>
(
lvals
[
l
]);
cuWelfordOnlineSum
<
U
>
(
curr
,
mu
,
sigma2
,
count
);
}
// intra-warp reductions
for
(
int
l
=
0
;
l
<=
4
;
++
l
)
{
int
srcLaneB
=
(
threadIdx
.
x
+
(
1
<<
l
))
&
31
;
U
muB
=
WARP_SHFL
(
mu
,
srcLaneB
);
U
countB
=
WARP_SHFL
(
count
,
srcLaneB
);
U
sigma2B
=
WARP_SHFL
(
sigma2
,
srcLaneB
);
cuChanOnlineSum
<
U
>
(
muB
,
sigma2B
,
countB
,
mu
,
sigma2
,
count
);
}
// threadIdx.x == 0 has correct values for each warp
// inter-warp reductions
if
(
blockDim
.
y
>
1
)
{
U
*
ubuf
=
(
U
*
)
buf
;
U
*
ibuf
=
(
U
*
)(
ubuf
+
blockDim
.
y
);
for
(
int
offset
=
blockDim
.
y
/
2
;
offset
>
0
;
offset
/=
2
)
{
// upper half of warps write to shared
if
(
threadIdx
.
x
==
0
&&
threadIdx
.
y
>=
offset
&&
threadIdx
.
y
<
2
*
offset
)
{
const
int
wrt_y
=
threadIdx
.
y
-
offset
;
ubuf
[
2
*
wrt_y
]
=
mu
;
ubuf
[
2
*
wrt_y
+
1
]
=
sigma2
;
ibuf
[
wrt_y
]
=
count
;
}
__syncthreads
();
// lower half merges
if
(
threadIdx
.
x
==
0
&&
threadIdx
.
y
<
offset
)
{
U
muB
=
ubuf
[
2
*
threadIdx
.
y
];
U
sigma2B
=
ubuf
[
2
*
threadIdx
.
y
+
1
];
U
countB
=
ibuf
[
threadIdx
.
y
];
cuChanOnlineSum
<
U
>
(
muB
,
sigma2B
,
countB
,
mu
,
sigma2
,
count
);
}
__syncthreads
();
}
// threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values
if
(
threadIdx
.
x
==
0
&&
threadIdx
.
y
==
0
)
{
ubuf
[
0
]
=
mu
;
ubuf
[
1
]
=
sigma2
;
}
__syncthreads
();
mu
=
ubuf
[
0
];
sigma2
=
ubuf
[
1
]
/
U
(
n2
);
// don't care about final value of count, we know count == n2
}
else
{
mu
=
WARP_SHFL
(
mu
,
0
);
sigma2
=
WARP_SHFL
(
sigma2
/
U
(
n2
),
0
);
}
}
}
template
<
>
__device__
void
cuWelfordMuSigma2
(
const
at
::
Half
*
__restrict__
vals
,
const
int
n1
,
const
int
n2
,
const
int
i1
,
float
&
mu
,
float
&
sigma2
,
float
*
buf
)
{
// Assumptions:
// 1) blockDim.x == warpSize
// 2) Tensor is contiguous
// 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available.
//
// compute variance and mean over n2
float
count
=
0.0
f
;
mu
=
float
(
0
);
sigma2
=
float
(
0
);
if
(
i1
<
n1
)
{
// one warp normalizes one n1 index,
// synchronization is implicit
// initialize with standard Welford algorithm
const
int
numx
=
blockDim
.
x
*
blockDim
.
y
;
const
int
thrx
=
threadIdx
.
x
+
threadIdx
.
y
*
blockDim
.
x
;
const
at
::
Half
*
lvals
=
vals
+
i1
*
n2
;
int
l
=
8
*
thrx
;
if
((((
size_t
)
lvals
)
&
3
)
!=
0
)
{
// 16 bit alignment
// first thread consumes first point
if
(
thrx
==
0
)
{
float
curr
=
static_cast
<
float
>
(
lvals
[
0
]);
cuWelfordOnlineSum
(
curr
,
mu
,
sigma2
,
count
);
}
++
l
;
}
// at this point, lvals[l] are 32 bit aligned for all threads.
for
(;
l
+
7
<
n2
;
l
+=
8
*
numx
)
{
for
(
int
k
=
0
;
k
<
8
;
k
+=
2
)
{
float2
curr
=
__half22float2
(
*
((
__half2
*
)(
lvals
+
l
+
k
)));
cuWelfordOnlineSum
(
curr
.
x
,
mu
,
sigma2
,
count
);
cuWelfordOnlineSum
(
curr
.
y
,
mu
,
sigma2
,
count
);
}
}
for
(;
l
<
n2
;
++
l
)
{
float
curr
=
static_cast
<
float
>
(
lvals
[
l
]);
cuWelfordOnlineSum
(
curr
,
mu
,
sigma2
,
count
);
}
// intra-warp reductions
for
(
int
l
=
0
;
l
<=
4
;
++
l
)
{
int
srcLaneB
=
(
threadIdx
.
x
+
(
1
<<
l
))
&
31
;
float
muB
=
WARP_SHFL
(
mu
,
srcLaneB
);
float
countB
=
WARP_SHFL
(
count
,
srcLaneB
);
float
sigma2B
=
WARP_SHFL
(
sigma2
,
srcLaneB
);
cuChanOnlineSum
(
muB
,
sigma2B
,
countB
,
mu
,
sigma2
,
count
);
}
// threadIdx.x == 0 has correct values for each warp
// inter-warp reductions
if
(
blockDim
.
y
>
1
)
{
float
*
ubuf
=
(
float
*
)
buf
;
float
*
ibuf
=
(
float
*
)(
ubuf
+
blockDim
.
y
);
for
(
int
offset
=
blockDim
.
y
/
2
;
offset
>
0
;
offset
/=
2
)
{
// upper half of warps write to shared
if
(
threadIdx
.
x
==
0
&&
threadIdx
.
y
>=
offset
&&
threadIdx
.
y
<
2
*
offset
)
{
const
int
wrt_y
=
threadIdx
.
y
-
offset
;
ubuf
[
2
*
wrt_y
]
=
mu
;
ubuf
[
2
*
wrt_y
+
1
]
=
sigma2
;
ibuf
[
wrt_y
]
=
count
;
}
__syncthreads
();
// lower half merges
if
(
threadIdx
.
x
==
0
&&
threadIdx
.
y
<
offset
)
{
float
muB
=
ubuf
[
2
*
threadIdx
.
y
];
float
sigma2B
=
ubuf
[
2
*
threadIdx
.
y
+
1
];
float
countB
=
ibuf
[
threadIdx
.
y
];
cuChanOnlineSum
(
muB
,
sigma2B
,
countB
,
mu
,
sigma2
,
count
);
}
__syncthreads
();
}
// threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values
if
(
threadIdx
.
x
==
0
&&
threadIdx
.
y
==
0
)
{
ubuf
[
0
]
=
mu
;
ubuf
[
1
]
=
sigma2
;
}
__syncthreads
();
mu
=
ubuf
[
0
];
sigma2
=
ubuf
[
1
]
/
float
(
n2
);
// don't care about final value of count, we know count == n2
}
else
{
mu
=
WARP_SHFL
(
mu
,
0
);
sigma2
=
WARP_SHFL
(
sigma2
/
float
(
n2
),
0
);
}
}
}
template
<
typename
U
>
U
rsqrt
(
U
v
)
{
return
U
(
1
)
/
sqrt
(
v
);
}
template
<
>
float
rsqrt
(
float
v
)
{
return
rsqrtf
(
v
);
}
template
<
>
double
rsqrt
(
double
v
)
{
return
rsqrt
(
v
);
}
namespace
{
// This is the un-specialized struct. Note that we prevent instantiation of this
// struct by putting an undefined symbol in the function body so it won't compile.
// template <typename T>
// struct SharedMemory
// {
// // Ensure that we won't compile any un-specialized types
// __device__ T *getPointer()
// {
// extern __device__ void error(void);
// error();
// return NULL;
// }
// };
// https://github.com/NVIDIA/apex/issues/246
template
<
typename
T
>
struct
SharedMemory
;
template
<
>
struct
SharedMemory
<
float
>
{
__device__
float
*
getPointer
()
{
extern
__shared__
float
s_float
[];
return
s_float
;
}
};
}
template
<
typename
T
,
typename
U
,
typename
V
>
__global__
void
cuApplyLayerNorm
(
V
*
__restrict__
output_vals
,
U
*
__restrict__
mean
,
U
*
__restrict__
invvar
,
const
T
*
__restrict__
vals
,
const
int
n1
,
const
int
n2
,
const
U
epsilon
,
const
V
*
__restrict__
gamma
,
const
V
*
__restrict__
beta
)
{
// Assumptions:
// 1) blockDim.x == warpSize
// 2) Tensors are contiguous
//
for
(
auto
i1
=
blockIdx
.
y
;
i1
<
n1
;
i1
+=
gridDim
.
y
)
{
SharedMemory
<
U
>
shared
;
U
*
buf
=
shared
.
getPointer
();
U
mu
,
sigma2
;
cuWelfordMuSigma2
(
vals
,
n1
,
n2
,
i1
,
mu
,
sigma2
,
buf
);
const
T
*
lvals
=
vals
+
i1
*
n2
;
V
*
ovals
=
output_vals
+
i1
*
n2
;
U
c_invvar
=
rsqrt
(
sigma2
+
epsilon
);
const
int
numx
=
blockDim
.
x
*
blockDim
.
y
;
const
int
thrx
=
threadIdx
.
x
+
threadIdx
.
y
*
blockDim
.
x
;
if
(
gamma
!=
NULL
&&
beta
!=
NULL
)
{
for
(
int
i
=
thrx
;
i
<
n2
;
i
+=
numx
)
{
U
curr
=
static_cast
<
U
>
(
lvals
[
i
]);
ovals
[
i
]
=
gamma
[
i
]
*
static_cast
<
V
>
(
c_invvar
*
(
curr
-
mu
))
+
beta
[
i
];
}
}
else
{
for
(
int
i
=
thrx
;
i
<
n2
;
i
+=
numx
)
{
U
curr
=
static_cast
<
U
>
(
lvals
[
i
]);
ovals
[
i
]
=
static_cast
<
V
>
(
c_invvar
*
(
curr
-
mu
));
}
}
if
(
threadIdx
.
x
==
0
&&
threadIdx
.
y
==
0
)
{
mean
[
i1
]
=
mu
;
invvar
[
i1
]
=
c_invvar
;
}
}
}
template
<
typename
T
,
typename
U
,
typename
V
>
__device__
void
cuLoadWriteStridedInputs
(
const
int
i1_block
,
const
int
thr_load_row_off
,
const
int
thr_load_col_off
,
const
int
i2_off
,
const
int
row_stride
,
U
*
warp_buf1
,
U
*
warp_buf2
,
const
T
*
input
,
const
V
*
dout
,
const
int
i1_end
,
const
int
n2
,
const
U
*
__restrict__
mean
,
const
U
*
__restrict__
invvar
)
{
int
i1
=
i1_block
+
thr_load_row_off
;
if
(
i1
<
i1_end
)
{
U
curr_mean
=
mean
[
i1
];
U
curr_invvar
=
invvar
[
i1
];
for
(
int
k
=
0
;
k
<
blockDim
.
y
;
++
k
)
{
int
i2
=
i2_off
+
k
;
int
load_idx
=
i1
*
n2
+
i2
;
int
write_idx
=
thr_load_row_off
*
row_stride
+
thr_load_col_off
+
k
;
if
(
i2
<
n2
)
{
U
curr_input
=
static_cast
<
U
>
(
input
[
load_idx
]);
U
curr_dout
=
static_cast
<
U
>
(
dout
[
load_idx
]);
warp_buf1
[
write_idx
]
=
curr_dout
;
warp_buf2
[
write_idx
]
=
curr_dout
*
(
curr_input
-
curr_mean
)
*
curr_invvar
;
}
else
{
warp_buf1
[
write_idx
]
=
U
(
0
);
warp_buf2
[
write_idx
]
=
U
(
0
);
}
}
}
else
{
for
(
int
k
=
0
;
k
<
blockDim
.
y
;
++
k
)
{
int
write_idx
=
thr_load_row_off
*
row_stride
+
thr_load_col_off
+
k
;
warp_buf1
[
write_idx
]
=
U
(
0
);
warp_buf2
[
write_idx
]
=
U
(
0
);
}
}
}
template
<
typename
T
,
typename
U
,
typename
V
>
__device__
void
cuLoadAddStridedInputs
(
const
int
i1_block
,
const
int
thr_load_row_off
,
const
int
thr_load_col_off
,
const
int
i2_off
,
const
int
row_stride
,
U
*
warp_buf1
,
U
*
warp_buf2
,
const
T
*
input
,
const
V
*
dout
,
const
int
i1_end
,
const
int
n2
,
const
U
*
__restrict__
mean
,
const
U
*
__restrict__
invvar
)
{
int
i1
=
i1_block
+
thr_load_row_off
;
if
(
i1
<
i1_end
)
{
U
curr_mean
=
mean
[
i1
];
U
curr_invvar
=
invvar
[
i1
];
for
(
int
k
=
0
;
k
<
blockDim
.
y
;
++
k
)
{
int
i2
=
i2_off
+
k
;
int
load_idx
=
i1
*
n2
+
i2
;
int
write_idx
=
thr_load_row_off
*
row_stride
+
thr_load_col_off
+
k
;
if
(
i2
<
n2
)
{
U
curr_input
=
static_cast
<
U
>
(
input
[
load_idx
]);
U
curr_dout
=
static_cast
<
U
>
(
dout
[
load_idx
]);
warp_buf1
[
write_idx
]
+=
curr_dout
;
warp_buf2
[
write_idx
]
+=
curr_dout
*
(
curr_input
-
curr_mean
)
*
curr_invvar
;
}
}
}
}
template
<
typename
T
,
typename
U
,
typename
V
>
__global__
void
cuComputePartGradGammaBeta
(
const
V
*
__restrict__
dout
,
const
T
*
__restrict__
input
,
const
int
n1
,
const
int
n2
,
const
U
*
__restrict__
mean
,
const
U
*
__restrict__
invvar
,
U
epsilon
,
U
*
part_grad_gamma
,
U
*
part_grad_beta
)
{
const
int
numsegs_n1
=
(
n1
+
blockDim
.
y
*
blockDim
.
y
-
1
)
/
(
blockDim
.
y
*
blockDim
.
y
);
const
int
segs_per_block
=
(
numsegs_n1
+
gridDim
.
y
-
1
)
/
gridDim
.
y
;
const
int
i1_beg
=
blockIdx
.
y
*
segs_per_block
*
blockDim
.
y
*
blockDim
.
y
;
const
int
i1_beg_plus_one
=
(
blockIdx
.
y
+
1
)
*
segs_per_block
*
blockDim
.
y
*
blockDim
.
y
;
const
int
i1_end
=
i1_beg_plus_one
<
n1
?
i1_beg_plus_one
:
n1
;
const
int
row_stride
=
blockDim
.
x
+
1
;
const
int
thr_load_col_off
=
(
threadIdx
.
x
*
blockDim
.
y
)
&
(
blockDim
.
x
-
1
);
const
int
thr_load_row_off
=
(
threadIdx
.
x
*
blockDim
.
y
)
/
blockDim
.
x
+
threadIdx
.
y
*
blockDim
.
y
;
const
int
i2_off
=
blockIdx
.
x
*
blockDim
.
x
+
thr_load_col_off
;
SharedMemory
<
U
>
shared
;
U
*
buf
=
shared
.
getPointer
();
// buf has at least blockDim.x * blockDim.y * blockDim.y + (blockDim.y - 1)*(blockDim.x/blockDim.y) elements
U
*
warp_buf1
=
(
U
*
)
buf
;
U
*
warp_buf2
=
warp_buf1
+
blockDim
.
y
*
blockDim
.
y
*
row_stride
;
// compute partial sums from strided inputs
// do this to increase number of loads in flight
cuLoadWriteStridedInputs
(
i1_beg
,
thr_load_row_off
,
thr_load_col_off
,
i2_off
,
row_stride
,
warp_buf1
,
warp_buf2
,
input
,
dout
,
i1_end
,
n2
,
mean
,
invvar
);
for
(
int
i1_block
=
i1_beg
+
blockDim
.
y
*
blockDim
.
y
;
i1_block
<
i1_end
;
i1_block
+=
blockDim
.
y
*
blockDim
.
y
)
{
cuLoadAddStridedInputs
(
i1_block
,
thr_load_row_off
,
thr_load_col_off
,
i2_off
,
row_stride
,
warp_buf1
,
warp_buf2
,
input
,
dout
,
i1_end
,
n2
,
mean
,
invvar
);
}
__syncthreads
();
// inter-warp reductions
// sum within each warp
U
acc1
=
U
(
0
);
U
acc2
=
U
(
0
);
for
(
int
k
=
0
;
k
<
blockDim
.
y
;
++
k
)
{
int
row1
=
threadIdx
.
y
+
k
*
blockDim
.
y
;
int
idx1
=
row1
*
row_stride
+
threadIdx
.
x
;
acc1
+=
warp_buf1
[
idx1
];
acc2
+=
warp_buf2
[
idx1
];
}
warp_buf1
[
threadIdx
.
y
*
row_stride
+
threadIdx
.
x
]
=
acc1
;
warp_buf2
[
threadIdx
.
y
*
row_stride
+
threadIdx
.
x
]
=
acc2
;
__syncthreads
();
// sum all warps
for
(
int
offset
=
blockDim
.
y
/
2
;
offset
>
1
;
offset
/=
2
)
{
if
(
threadIdx
.
y
<
offset
)
{
int
row1
=
threadIdx
.
y
;
int
row2
=
threadIdx
.
y
+
offset
;
int
idx1
=
row1
*
row_stride
+
threadIdx
.
x
;
int
idx2
=
row2
*
row_stride
+
threadIdx
.
x
;
warp_buf1
[
idx1
]
+=
warp_buf1
[
idx2
];
warp_buf2
[
idx1
]
+=
warp_buf2
[
idx2
];
}
__syncthreads
();
}
int
i2
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
threadIdx
.
y
==
0
&&
i2
<
n2
)
{
int
row1
=
threadIdx
.
y
;
int
row2
=
threadIdx
.
y
+
1
;
int
idx1
=
row1
*
row_stride
+
threadIdx
.
x
;
int
idx2
=
row2
*
row_stride
+
threadIdx
.
x
;
part_grad_beta
[
blockIdx
.
y
*
n2
+
i2
]
=
warp_buf1
[
idx1
]
+
warp_buf1
[
idx2
];
part_grad_gamma
[
blockIdx
.
y
*
n2
+
i2
]
=
warp_buf2
[
idx1
]
+
warp_buf2
[
idx2
];
}
}
template
<
typename
U
,
typename
V
>
__global__
void
cuComputeGradGammaBeta
(
const
U
*
part_grad_gamma
,
const
U
*
part_grad_beta
,
const
int
part_size
,
const
int
n1
,
const
int
n2
,
V
*
grad_gamma
,
V
*
grad_beta
)
{
// sum partial gradients for gamma and beta
SharedMemory
<
U
>
shared
;
U
*
buf
=
shared
.
getPointer
();
int
i2
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
i2
<
n2
)
{
// each warp does sequential reductions until reduced part_size is num_warps
int
num_warp_reductions
=
part_size
/
blockDim
.
y
;
U
sum_gamma
=
U
(
0
);
U
sum_beta
=
U
(
0
);
const
U
*
part_grad_gamma_ptr
=
part_grad_gamma
+
threadIdx
.
y
*
num_warp_reductions
*
n2
+
i2
;
const
U
*
part_grad_beta_ptr
=
part_grad_beta
+
threadIdx
.
y
*
num_warp_reductions
*
n2
+
i2
;
for
(
int
warp_offset
=
0
;
warp_offset
<
num_warp_reductions
;
++
warp_offset
)
{
sum_gamma
+=
part_grad_gamma_ptr
[
warp_offset
*
n2
];
sum_beta
+=
part_grad_beta_ptr
[
warp_offset
*
n2
];
}
// inter-warp reductions
const
int
nbsize3
=
blockDim
.
x
*
blockDim
.
y
/
2
;
for
(
int
offset
=
blockDim
.
y
/
2
;
offset
>=
1
;
offset
/=
2
)
{
// top half write to shared memory
if
(
threadIdx
.
y
>=
offset
&&
threadIdx
.
y
<
2
*
offset
)
{
const
int
write_idx
=
(
threadIdx
.
y
-
offset
)
*
blockDim
.
x
+
threadIdx
.
x
;
buf
[
write_idx
]
=
sum_gamma
;
buf
[
write_idx
+
nbsize3
]
=
sum_beta
;
}
__syncthreads
();
// bottom half sums
if
(
threadIdx
.
y
<
offset
)
{
const
int
read_idx
=
threadIdx
.
y
*
blockDim
.
x
+
threadIdx
.
x
;
sum_gamma
+=
buf
[
read_idx
];
sum_beta
+=
buf
[
read_idx
+
nbsize3
];
}
__syncthreads
();
}
// write out fully summed gradients
if
(
threadIdx
.
y
==
0
)
{
grad_gamma
[
i2
]
=
sum_gamma
;
grad_beta
[
i2
]
=
sum_beta
;
}
}
}
template
<
typename
T
,
typename
U
,
typename
V
>
__global__
void
cuComputeGradInput
(
const
V
*
__restrict__
dout
,
const
T
*
__restrict__
input
,
const
int
n1
,
const
int
n2
,
const
U
*
__restrict__
mean
,
const
U
*
__restrict__
invvar
,
U
epsilon
,
const
V
*
gamma
,
T
*
grad_input
)
{
for
(
auto
i1
=
blockIdx
.
y
;
i1
<
n1
;
i1
+=
gridDim
.
y
)
{
U
sum_loss1
=
U
(
0
);
U
sum_loss2
=
U
(
0
);
const
U
c_mean
=
mean
[
i1
];
const
U
c_invvar
=
invvar
[
i1
];
const
T
*
k_input
=
input
+
i1
*
n2
;
const
V
*
k_dout
=
dout
+
i1
*
n2
;
const
int
numx
=
blockDim
.
x
*
blockDim
.
y
;
const
int
thrx
=
threadIdx
.
x
+
threadIdx
.
y
*
blockDim
.
x
;
if
(
gamma
!=
NULL
)
{
int
l
=
4
*
thrx
;
for
(;
l
+
3
<
n2
;
l
+=
4
*
numx
)
{
for
(
int
k
=
0
;
k
<
4
;
++
k
)
{
const
U
c_h
=
static_cast
<
U
>
(
k_input
[
l
+
k
]);
const
U
c_loss
=
static_cast
<
U
>
(
k_dout
[
l
+
k
]);
sum_loss1
+=
c_loss
*
gamma
[
l
+
k
];
sum_loss2
+=
c_loss
*
gamma
[
l
+
k
]
*
(
c_h
-
c_mean
)
*
c_invvar
;
}
}
for
(;
l
<
n2
;
++
l
)
{
const
U
c_h
=
static_cast
<
U
>
(
k_input
[
l
]);
const
U
c_loss
=
static_cast
<
U
>
(
k_dout
[
l
]);
sum_loss1
+=
c_loss
*
gamma
[
l
];
sum_loss2
+=
c_loss
*
gamma
[
l
]
*
(
c_h
-
c_mean
)
*
c_invvar
;
}
}
else
{
int
l
=
4
*
thrx
;
for
(;
l
+
3
<
n2
;
l
+=
4
*
numx
)
{
for
(
int
k
=
0
;
k
<
4
;
++
k
)
{
const
U
c_h
=
static_cast
<
U
>
(
k_input
[
l
+
k
]);
const
U
c_loss
=
static_cast
<
U
>
(
k_dout
[
l
+
k
]);
sum_loss1
+=
c_loss
;
sum_loss2
+=
c_loss
*
(
c_h
-
c_mean
)
*
c_invvar
;
}
}
for
(;
l
<
n2
;
++
l
)
{
const
U
c_h
=
static_cast
<
U
>
(
k_input
[
l
]);
const
U
c_loss
=
static_cast
<
U
>
(
k_dout
[
l
]);
sum_loss1
+=
c_loss
;
sum_loss2
+=
c_loss
*
(
c_h
-
c_mean
)
*
c_invvar
;
}
}
// intra-warp reductions
for
(
int
mask
=
blockDim
.
x
/
2
;
mask
>
0
;
mask
/=
2
)
{
sum_loss1
+=
WARP_SHFL_XOR
(
sum_loss1
,
mask
);
sum_loss2
+=
WARP_SHFL_XOR
(
sum_loss2
,
mask
);
}
// inter-warp reductions
if
(
blockDim
.
y
>
1
)
{
SharedMemory
<
U
>
shared
;
U
*
buf
=
shared
.
getPointer
();
for
(
int
offset
=
blockDim
.
y
/
2
;
offset
>
0
;
offset
/=
2
)
{
// upper half of warps write to shared
if
(
threadIdx
.
y
>=
offset
&&
threadIdx
.
y
<
2
*
offset
)
{
const
int
wrt_i
=
(
threadIdx
.
y
-
offset
)
*
blockDim
.
x
+
threadIdx
.
x
;
buf
[
2
*
wrt_i
]
=
sum_loss1
;
buf
[
2
*
wrt_i
+
1
]
=
sum_loss2
;
}
__syncthreads
();
// lower half merges
if
(
threadIdx
.
y
<
offset
)
{
const
int
read_i
=
threadIdx
.
y
*
blockDim
.
x
+
threadIdx
.
x
;
sum_loss1
+=
buf
[
2
*
read_i
];
sum_loss2
+=
buf
[
2
*
read_i
+
1
];
}
__syncthreads
();
}
if
(
threadIdx
.
y
==
0
)
{
buf
[
2
*
threadIdx
.
x
]
=
sum_loss1
;
buf
[
2
*
threadIdx
.
x
+
1
]
=
sum_loss2
;
}
__syncthreads
();
if
(
threadIdx
.
y
!=
0
)
{
sum_loss1
=
buf
[
2
*
threadIdx
.
x
];
sum_loss2
=
buf
[
2
*
threadIdx
.
x
+
1
];
}
}
// all threads now have the two sums over l
U
fH
=
(
U
)
n2
;
U
term1
=
(
U
(
1
)
/
fH
)
*
c_invvar
;
T
*
k_grad_input
=
grad_input
+
i1
*
n2
;
if
(
gamma
!=
NULL
)
{
for
(
int
l
=
thrx
;
l
<
n2
;
l
+=
numx
)
{
const
U
c_h
=
static_cast
<
U
>
(
k_input
[
l
]);
const
U
c_loss
=
static_cast
<
U
>
(
k_dout
[
l
]);
U
f_grad_input
=
fH
*
c_loss
*
gamma
[
l
];
f_grad_input
-=
sum_loss1
;
f_grad_input
-=
(
c_h
-
c_mean
)
*
c_invvar
*
sum_loss2
;
f_grad_input
*=
term1
;
k_grad_input
[
l
]
=
static_cast
<
T
>
(
f_grad_input
);
}
}
else
{
for
(
int
l
=
thrx
;
l
<
n2
;
l
+=
numx
)
{
const
U
c_h
=
static_cast
<
U
>
(
k_input
[
l
]);
const
U
c_loss
=
static_cast
<
U
>
(
k_dout
[
l
]);
U
f_grad_input
=
fH
*
c_loss
;
f_grad_input
-=
sum_loss1
;
f_grad_input
-=
(
c_h
-
c_mean
)
*
c_invvar
*
sum_loss2
;
f_grad_input
*=
term1
;
k_grad_input
[
l
]
=
static_cast
<
T
>
(
f_grad_input
);
}
}
}
}
template
<
typename
T
,
typename
U
,
typename
V
>
void
HostApplyLayerNorm
(
V
*
output
,
U
*
mean
,
U
*
invvar
,
const
T
*
input
,
int
n1
,
int
n2
,
double
epsilon
,
const
V
*
gamma
,
const
V
*
beta
)
{
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
().
stream
();
const
dim3
threads
(
32
,
4
,
1
);
const
uint64_t
maxGridY
=
at
::
cuda
::
getCurrentDeviceProperties
()
->
maxGridSize
[
1
];
const
dim3
blocks
(
1
,
std
::
min
((
uint64_t
)
n1
,
maxGridY
),
1
);
int
nshared
=
threads
.
y
>
1
?
threads
.
y
*
sizeof
(
U
)
+
(
threads
.
y
/
2
)
*
sizeof
(
U
)
:
0
;
cuApplyLayerNorm
<<<
blocks
,
threads
,
nshared
,
stream
>>>
(
output
,
mean
,
invvar
,
input
,
n1
,
n2
,
U
(
epsilon
),
gamma
,
beta
);
}
void
cuda_layer_norm
(
at
::
Tensor
*
output
,
at
::
Tensor
*
mean
,
at
::
Tensor
*
invvar
,
at
::
Tensor
*
input
,
int
n1
,
int
n2
,
#ifdef VERSION_GE_1_1
at
::
IntArrayRef
normalized_shape
,
#else
at
::
IntList
normalized_shape
,
#endif
at
::
Tensor
*
gamma
,
at
::
Tensor
*
beta
,
double
epsilon
)
{
using
namespace
at
;
DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES
(
input
->
scalar_type
(),
output
->
scalar_type
(),
"cuda_layer_norm_kernel"
,
HostApplyLayerNorm
(
output
->
DATA_PTR
<
scalar_t_out
>
(),
mean
->
DATA_PTR
<
float
>
(),
invvar
->
DATA_PTR
<
float
>
(),
input
->
DATA_PTR
<
scalar_t_in
>
(),
n1
,
n2
,
epsilon
,
gamma
!=
NULL
?
gamma
->
DATA_PTR
<
scalar_t_out
>
()
:
NULL
,
beta
!=
NULL
?
beta
->
DATA_PTR
<
scalar_t_out
>
()
:
NULL
);
)
}
template
<
typename
T
,
typename
U
,
typename
V
>
void
HostLayerNormGradient
(
const
V
*
dout
,
const
U
*
mean
,
const
U
*
invvar
,
at
::
Tensor
*
input
,
int
n1
,
int
n2
,
const
V
*
gamma
,
const
V
*
beta
,
double
epsilon
,
T
*
grad_input
,
V
*
grad_gamma
,
V
*
grad_beta
)
{
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
().
stream
();
if
(
gamma
!=
NULL
&&
beta
!=
NULL
)
{
// compute grad_gamma(j) and grad_beta(j)
const
int
part_size
=
16
;
const
dim3
threads2
(
32
,
4
,
1
);
const
dim3
blocks2
((
n2
+
threads2
.
x
-
1
)
/
threads2
.
x
,
part_size
,
1
);
const
int
nshared2_a
=
2
*
sizeof
(
U
)
*
threads2
.
y
*
threads2
.
y
*
(
threads2
.
x
+
1
);
const
int
nshared2_b
=
threads2
.
x
*
threads2
.
y
*
sizeof
(
U
);
const
int
nshared2
=
nshared2_a
>
nshared2_b
?
nshared2_a
:
nshared2_b
;
at
::
Tensor
part_grad_gamma
=
at
::
empty
(
{
part_size
,
n2
},
input
->
options
().
dtype
(
at
::
ScalarType
::
Float
));
at
::
Tensor
part_grad_beta
=
at
::
empty_like
(
part_grad_gamma
);
cuComputePartGradGammaBeta
<<<
blocks2
,
threads2
,
nshared2
,
stream
>>>
(
dout
,
input
->
DATA_PTR
<
T
>
(),
n1
,
n2
,
mean
,
invvar
,
U
(
epsilon
),
part_grad_gamma
.
DATA_PTR
<
U
>
(),
part_grad_beta
.
DATA_PTR
<
U
>
());
const
dim3
threads3
(
32
,
8
,
1
);
const
dim3
blocks3
((
n2
+
threads2
.
x
-
1
)
/
threads2
.
x
,
1
,
1
);
const
int
nshared3
=
threads3
.
x
*
threads3
.
y
*
sizeof
(
U
);
cuComputeGradGammaBeta
<<<
blocks3
,
threads3
,
nshared3
,
stream
>>>
(
part_grad_gamma
.
DATA_PTR
<
U
>
(),
part_grad_beta
.
DATA_PTR
<
U
>
(),
part_size
,
n1
,
n2
,
grad_gamma
,
grad_beta
);
}
// compute grad_input
const
uint64_t
maxGridY
=
at
::
cuda
::
getCurrentDeviceProperties
()
->
maxGridSize
[
1
];
const
dim3
blocks1
(
1
,
std
::
min
((
uint64_t
)
n1
,
maxGridY
),
1
);
const
dim3
threads1
(
32
,
4
,
1
);
int
nshared
=
threads1
.
y
>
1
?
threads1
.
y
*
threads1
.
x
*
sizeof
(
U
)
:
0
;
cuComputeGradInput
<<<
blocks1
,
threads1
,
nshared
,
stream
>>>
(
dout
,
input
->
DATA_PTR
<
T
>
(),
n1
,
n2
,
mean
,
invvar
,
U
(
epsilon
),
gamma
,
grad_input
);
}
void
cuda_layer_norm_gradient
(
at
::
Tensor
*
dout
,
at
::
Tensor
*
mean
,
at
::
Tensor
*
invvar
,
at
::
Tensor
*
input
,
int
n1
,
int
n2
,
#ifdef VERSION_GE_1_1
at
::
IntArrayRef
normalized_shape
,
#else
at
::
IntList
normalized_shape
,
#endif
at
::
Tensor
*
gamma
,
at
::
Tensor
*
beta
,
double
epsilon
,
at
::
Tensor
*
grad_input
,
at
::
Tensor
*
grad_gamma
,
at
::
Tensor
*
grad_beta
)
{
using
namespace
at
;
DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES
(
input
->
scalar_type
(),
gamma
->
scalar_type
(),
"cuda_layer_norm_gradient_kernel"
,
HostLayerNormGradient
(
dout
->
DATA_PTR
<
scalar_t_out
>
(),
mean
->
DATA_PTR
<
float
>
(),
invvar
->
DATA_PTR
<
float
>
(),
input
,
n1
,
n2
,
// TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta
// if gamma Tensor is NULL on input.
gamma
!=
NULL
?
gamma
->
DATA_PTR
<
scalar_t_out
>
()
:
NULL
,
gamma
!=
NULL
?
beta
->
DATA_PTR
<
scalar_t_out
>
()
:
NULL
,
epsilon
,
grad_input
->
DATA_PTR
<
scalar_t_in
>
(),
gamma
!=
NULL
?
grad_gamma
->
DATA_PTR
<
scalar_t_out
>
()
:
NULL
,
gamma
!=
NULL
?
grad_beta
->
DATA_PTR
<
scalar_t_out
>
()
:
NULL
);
)
}
\ No newline at end of file
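// Editor's note (illustrative sketch, not part of the kernel file above):
// layer_norm_cuda_kernel.cu accumulates mean and variance with Welford's
// online update (cuWelfordOnlineSum) and merges per-thread partials with
// Chan's parallel formula (cuChanOnlineSum). A scalar reference of the two
// updates, handy for testing; the struct and function names are assumptions
// for the example.
struct WelfordState {
  double mean = 0.0, m2 = 0.0, count = 0.0;  // m2 = sum of squared deviations
};

inline void welford_add(WelfordState &s, double x) {
  s.count += 1.0;
  double delta = x - s.mean;
  s.mean += delta / s.count;
  s.m2 += delta * (x - s.mean);
}

inline void welford_merge(WelfordState &a, const WelfordState &b) {
  double n = a.count + b.count;
  if (n <= 0.0) return;
  double delta = b.mean - a.mean;
  a.mean = (a.count * a.mean + b.count * b.mean) / n;
  a.m2 += b.m2 + delta * delta * a.count * b.count / n;
  a.count = n;
}
// Variance over n2 elements is state.m2 / n2; the kernel divides by n2 at the
// end of cuWelfordMuSigma2 in the same way.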
colossalai/kernel/cuda_native/csrc/multihead_attention_1d.cpp
0 → 100644
#include "multihead_attention_1d.h"
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <c10d/Types.hpp>
#include <iostream>
#include "context.h"
#include "kernels.h"
template
<
typename
T
>
MultiHeadAttention
<
T
>::
MultiHeadAttention
(
int
layer_id
,
int
max_batch_tokens
,
int
max_seq_len
,
int
hidden_size
,
int
num_heads
,
float
attn_prob_dropout_ratio
,
float
hidden_output_dropout_ratio
,
bool
pre_or_postLayerNorm
)
:
_layer_id
(
layer_id
),
_max_batch_tokens
(
max_batch_tokens
),
_max_seq_len
(
max_seq_len
),
_hidden_size
(
hidden_size
),
_heads
(
num_heads
),
_training
(
true
),
_pre_or_postLayerNorm
(
pre_or_postLayerNorm
),
_qkv_linear
(
typename
FeedForward
<
T
>::
Config
(
3
*
hidden_size
,
hidden_size
)),
_attn_out_linear
(
typename
FeedForward
<
T
>::
Config
(
hidden_size
,
hidden_size
)),
_attn_ln
(
typename
Normalize_Layer
<
T
>::
Config
(
hidden_size
,
false
),
_max_batch_tokens
),
_softmax
(
typename
Softmax
<
T
>::
Config
(
num_heads
)),
_attn_prob_dropout
(
typename
Dropout
<
T
>::
Config
(
attn_prob_dropout_ratio
),
_max_batch_tokens
*
_heads
*
_max_seq_len
),
_attn_dropout
(
typename
Dropout
<
T
>::
Config
(
hidden_output_dropout_ratio
),
_max_batch_tokens
*
_hidden_size
),
_attn_scores
(
typename
StridedBatchGemm
<
T
>::
Config
((
T
(
1.0
)
/
T
(
sqrt
(
_hidden_size
/
_heads
))),
T
(
0.0
),
CUBLAS_OP_T
,
CUBLAS_OP_N
)),
_attn_context
(
typename
StridedBatchGemm
<
T
>::
Config
(
T
(
1.0
),
T
(
0.0
),
CUBLAS_OP_N
,
CUBLAS_OP_N
))
{
assert
(
_hidden_size
%
_heads
==
0
);
}
template
<
typename
T
>
MultiHeadAttention
<
T
>::~
MultiHeadAttention
()
{
free_mem_buffer
();
}
template
<
typename
T
>
void
MultiHeadAttention
<
T
>::
attn_layer_fw
(
const
T
*
input_ptr
,
const
T
*
input_mask_ptr
,
T
*
output_ptr
,
T
*
buffer
)
{
T
*
q_tf_ptr
=
_qkv_ptr
;
T
*
k_tf_ptr
=
q_tf_ptr
+
_batch_dim
/
pg_size
;
T
*
v_tf_ptr
=
k_tf_ptr
+
_batch_dim
/
pg_size
;
if
(
_pre_or_postLayerNorm
)
{
_attn_ln
.
Forward
(
_gemmQKV_inp_ptr
,
input_ptr
,
_attn_nw_ptr
,
_attn_nb_ptr
,
_batch_tokens
,
_stream
);
}
const
T
*
gemmQKV_inp_ptr
=
_pre_or_postLayerNorm
?
_gemmQKV_inp_ptr
:
input_ptr
;
_qkv_linear
.
reset_size
(
3
*
_hidden_size
/
pg_size
,
_hidden_size
);
_qkv_linear
.
Forward
(
_batch_tokens
,
gemmQKV_inp_ptr
,
_attn_qkvw_ptr
,
buffer
,
_cublasHandle
);
launch_bias_add_transform_20314
<
T
>
(
q_tf_ptr
,
buffer
,
_attn_qkvb_ptr
,
_batch_size
,
_seq_len
,
3
,
_heads
/
pg_size
,
_hidden_size
/
_heads
,
_stream
);
// attention scores, q*k
_attn_scores
.
Forward
(
_batch_heads
,
_soft_out_ptr
,
k_tf_ptr
,
q_tf_ptr
,
_cublasHandle
);
// Softmax + Mask
_softmax
.
reset_size
(
_heads
/
pg_size
);
_softmax
.
Forward
(
_soft_out_ptr
,
input_mask_ptr
,
_batch_size
,
_seq_len
,
_seq_len
,
_stream
,
true
);
// attn prob dropout.
_attn_prob_dropout
.
dropout
(
_ctx_bufB_ptr
,
_soft_out_ptr
,
_batch_heads
*
_seq_len
*
_seq_len
,
_stream
);
// attention context, score * v
_attn_context
.
Forward
(
_batch_heads
,
buffer
,
v_tf_ptr
,
_ctx_bufB_ptr
,
_cublasHandle
);
// [b, nh, s, ad] -> [b, s, nh, ad]
launch_transform4d_0213
<
T
>
(
_attn_o_inp_ptr
,
buffer
,
_batch_size
,
_seq_len
,
_hidden_size
/
pg_size
,
_heads
/
pg_size
,
1
,
_stream
);
_attn_out_linear
.
reset_size
(
_hidden_size
,
_hidden_size
/
pg_size
);
_attn_out_linear
.
Forward
(
_batch_tokens
,
_attn_o_inp_ptr
,
_attn_ow_ptr
,
output_ptr
,
_cublasHandle
);
// allreduce
if
(
pg
==
c10
::
detail
::
UniqueVoidPtr
()
||
pg
->
getSize
()
==
1
)
{
}
else
{
auto
data_type
=
torch
::
kFloat
;
if
(
typeid
(
T
)
!=
typeid
(
float
))
{
data_type
=
torch
::
kHalf
;
}
auto
output_tensor
=
torch
::
from_blob
(
output_ptr
,
{
int
(
_batch_size
),
int
(
_seq_len
),
int
(
_hidden_size
)},
torch
::
TensorOptions
(
torch
::
kCUDA
).
dtype
(
data_type
));
std
::
vector
<
torch
::
Tensor
>
allreduce_tensors
=
{
output_tensor
};
auto
work
=
pg
->
allreduce
(
allreduce_tensors
,
c10d
::
AllreduceOptions
());
work
->
wait
();
}
_attn_dropout
.
bias_dropout_residual
(
output_ptr
,
output_ptr
,
input_ptr
,
_attn_ob_ptr
,
_batch_tokens
,
_hidden_size
,
_stream
);
if
(
!
_pre_or_postLayerNorm
)
{
// in-place ln since ln-input will not be used in post-ln mode
_attn_ln
.
Forward
(
output_ptr
,
output_ptr
,
_attn_nw_ptr
,
_attn_nb_ptr
,
_batch_tokens
,
_stream
);
}
}
template
<
typename
T
>
void
MultiHeadAttention
<
T
>::
Forward
(
const
T
*
input_ptr
,
const
T
*
input_mask_ptr
,
T
*
out_ptr
)
{
_stream
=
Context
::
Instance
().
get_stream
();
_cublasHandle
=
Context
::
Instance
().
get_cublashandle
();
T
*
attn_buffer
=
_shared_mem_ptr
;
// 3 * _batch_dim
attn_layer_fw
(
input_ptr
,
input_mask_ptr
,
out_ptr
,
attn_buffer
);
}
template
<
typename
T
>
void
MultiHeadAttention
<
T
>::
attn_layer_bw
(
const
T
*
input_ptr
,
const
T
*
input_mask_ptr
,
const
T
*
output_ptr
,
const
T
*
grad_output_ptr
,
T
*
grad_input_ptr
,
T
*
buffer
)
{
cudaStream_t
streams
[
2
]
=
{
_stream
,
_stream
};
const
T
*
q_tf_ptr
=
_qkv_ptr
;
const
T
*
k_tf_ptr
=
q_tf_ptr
+
_batch_dim
/
pg_size
;
const
T
*
v_tf_ptr
=
k_tf_ptr
+
_batch_dim
/
pg_size
;
// batch_dim = batch_size * seq_len * hidden_size
// buffer size: batch_dim * 3 + max(batch_dim * 3,
// batch_size * head_num * seq_len * seq_len)
T
*
grad_residual_ptr
=
buffer
;
buffer
+=
_batch_dim
;
T
*
grad_input_buf_ptr
=
buffer
;
// batch_dim
T
*
grad_qkv_5d_ptr
=
buffer
;
// batch_dim * 3
buffer
+=
3
*
_batch_dim
/
pg_size
;
T
*
grad_qkv_4d_ptr
=
buffer
;
// batch_dim * 3
T
*
grad_softmax_ptr
=
buffer
;
// batch_size * head_num * seq_len * seq_len
// buffer += max(3 * _batch_dim,
// batch_size * head_num * seq_len * seq_len);
if
(
_pre_or_postLayerNorm
)
{
_attn_dropout
.
d_bias_dropout_residual
(
grad_input_ptr
,
_grad_attn_ob_ptr
,
grad_output_ptr
,
_batch_tokens
,
_hidden_size
,
_stream
);
}
else
{
_attn_ln
.
Backward
(
_grad_attn_nw_ptr
,
_grad_attn_nb_ptr
,
grad_residual_ptr
,
grad_output_ptr
,
nullptr
,
output_ptr
,
_attn_nw_ptr
,
_attn_nb_ptr
,
_batch_tokens
,
streams
);
_attn_dropout
.
d_bias_dropout_residual
(
grad_input_ptr
,
_grad_attn_ob_ptr
,
grad_residual_ptr
,
_batch_tokens
,
_hidden_size
,
_stream
);
}
// bw of output project
_attn_out_linear
.
reset_size
(
_hidden_size
,
_hidden_size
/
pg_size
);
_attn_out_linear
.
Backward
(
_batch_tokens
,
grad_input_ptr
,
_attn_o_inp_ptr
,
_attn_ow_ptr
,
_grad_attn_ow_ptr
,
_grad_attn_ob_ptr
,
_cublasHandle
,
_stream
,
grad_input_buf_ptr
,
nullptr
,
false
);
launch_transform_0213
<
T
>
(
grad_input_ptr
,
grad_input_buf_ptr
,
_batch_size
,
_seq_len
,
_hidden_size
/
pg_size
,
_heads
/
pg_size
,
_stream
);
// bw of score * v
_attn_context
.
Backward
(
_batch_heads
,
grad_input_ptr
,
v_tf_ptr
,
_ctx_bufB_ptr
,
_cublasHandle
,
grad_qkv_5d_ptr
+
2
*
_batch_dim
/
pg_size
,
grad_softmax_ptr
);
_attn_prob_dropout
.
d_dropout
(
grad_softmax_ptr
,
_batch_heads
*
_seq_len
*
_seq_len
,
_stream
);
_softmax
.
reset_size
(
_heads
/
pg_size
);
_softmax
.
Backward
(
grad_softmax_ptr
,
_soft_out_ptr
,
_batch_size
,
_seq_len
,
_seq_len
,
_stream
);
// bw of q * k
_attn_scores
.
Backward
(
_batch_heads
,
grad_softmax_ptr
,
k_tf_ptr
,
q_tf_ptr
,
_cublasHandle
,
grad_qkv_5d_ptr
+
_batch_dim
/
pg_size
,
grad_qkv_5d_ptr
);
// [3, b, nh, s, ad] -> [b, s, 3, h]
launch_transform4d_0213
<
T
>
(
grad_qkv_4d_ptr
,
grad_qkv_5d_ptr
,
_batch_size
,
_seq_len
,
_hidden_size
/
pg_size
,
_heads
/
pg_size
,
3
,
_stream
);
const
T
*
gemmQKV_inp_ptr
=
_pre_or_postLayerNorm
?
_gemmQKV_inp_ptr
:
input_ptr
;
_qkv_linear
.
reset_size
(
3
*
_hidden_size
/
pg_size
,
_hidden_size
);
_qkv_linear
.
Backward
(
_batch_tokens
,
grad_qkv_4d_ptr
,
gemmQKV_inp_ptr
,
_attn_qkvw_ptr
,
_grad_attn_qkvw_ptr
,
_grad_attn_qkvb_ptr
,
_cublasHandle
,
_stream
,
grad_input_buf_ptr
,
nullptr
,
true
);
// allreduce
if
(
pg
==
c10
::
detail
::
UniqueVoidPtr
()
||
pg
->
getSize
()
==
1
)
{
}
else
{
auto
data_type
=
torch
::
kFloat
;
if
(
typeid
(
T
)
!=
typeid
(
float
))
{
data_type
=
torch
::
kHalf
;
}
auto
grad_input_tensor
=
torch
::
from_blob
(
grad_input_buf_ptr
,
{
int
(
_batch_size
),
int
(
_seq_len
),
int
(
_hidden_size
)},
torch
::
TensorOptions
(
torch
::
kCUDA
).
dtype
(
data_type
));
std
::
vector
<
torch
::
Tensor
>
allreduce_tensors
=
{
grad_input_tensor
};
auto
work
=
pg
->
allreduce
(
allreduce_tensors
,
c10d
::
AllreduceOptions
());
work
->
wait
();
}
if
(
_pre_or_postLayerNorm
)
{
_attn_ln
.
Backward
(
_grad_attn_nw_ptr
,
_grad_attn_nb_ptr
,
grad_input_ptr
,
grad_input_buf_ptr
,
grad_output_ptr
,
gemmQKV_inp_ptr
,
_attn_nw_ptr
,
_attn_nb_ptr
,
_batch_tokens
,
streams
);
}
else
{
// FIXME later
launch_fused_add2
<
T
>
(
grad_input_ptr
,
grad_input_buf_ptr
,
grad_residual_ptr
,
_batch_size
,
_seq_len
,
_hidden_size
,
_stream
);
}
}
template
<
typename
T
>
void
MultiHeadAttention
<
T
>::
Backward
(
const
T
*
grad_output_ptr
,
const
T
*
input_ptr
,
const
T
*
output_ptr
,
const
T
*
input_mask_ptr
,
T
*
grad_input_ptr
)
{
_stream
=
Context
::
Instance
().
get_stream
();
_cublasHandle
=
Context
::
Instance
().
get_cublashandle
();
T
*
buffer
=
_shared_mem_ptr
;
/*
buffer size needed by attn bw:
4 * _batch_dim + max(3 * _batch_dim,
_batch_size * _head_num * _seq_len * _seq_len);
*/
attn_layer_bw
(
input_ptr
,
input_mask_ptr
,
output_ptr
,
grad_output_ptr
,
grad_input_ptr
,
buffer
);
}
template
<
typename
T
>
void
MultiHeadAttention
<
T
>::
SetTrainingMode
(
bool
training
)
{
// Dropout will be skipped when not in training model.
_attn_prob_dropout
.
SetTrainingMode
(
training
);
_attn_dropout
.
SetTrainingMode
(
training
);
}
template
<
typename
T
>
T
*
MultiHeadAttention
<
T
>::
_shared_mem_ptr
=
nullptr
;
template
class
MultiHeadAttention
<
float
>;
template
class
MultiHeadAttention
<
__half
>;
// x is torch::Tensor
#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
static
std
::
unordered_map
<
int
,
std
::
shared_ptr
<
void
>>
s_multihead_attention
;
template
<
typename
T
>
int
create_multihead_attention
(
int
layer_id
,
int
max_batch_tokens
,
int
max_seq_len
,
int
hidden_dim
,
int
num_heads
,
float
attn_prob_dropout_ratio
,
float
hidden_dropout_ratio
,
bool
pre_or_postLayerNorm
,
c10
::
intrusive_ptr
<
c10d
::
ProcessGroup
>
pg_
)
{
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
Context
::
Instance
().
set_stream
(
stream
);
auto
layer
=
std
::
make_shared
<
MultiHeadAttention
<
T
>>
(
layer_id
,
max_batch_tokens
,
max_seq_len
,
hidden_dim
,
num_heads
,
attn_prob_dropout_ratio
,
hidden_dropout_ratio
,
pre_or_postLayerNorm
);
layer
->
SetPG
(
pg_
);
s_multihead_attention
[
layer_id
]
=
layer
;
std
::
string
dtype
=
(
std
::
is_same
<
T
,
__half
>::
value
)
?
"half"
:
"float"
;
return
0
;
}
template <typename T>
std::vector<torch::Tensor> multihead_attention_fw(
    int layer_id, const torch::Tensor &input, const torch::Tensor &input_mask,
    const torch::Tensor &in_proj_weight, const torch::Tensor &in_proj_bias,
    const torch::Tensor &out_proj_weight, const torch::Tensor &out_proj_bias,
    const torch::Tensor &norm_weight, const torch::Tensor &norm_bias, bool training_mode,
    bool prelayernorm) {
  CHECK_INPUT(input);
  CHECK_INPUT(input_mask);

  const T *input_ptr = (const T *)input.data_ptr();
  const T *input_mask_ptr = (const T *)input_mask.data_ptr();

  auto output = torch::empty_like(input);
  T *out_ptr = (T *)output.data_ptr();

  std::shared_ptr<MultiHeadAttention<T>> layer =
      std::static_pointer_cast<MultiHeadAttention<T>>(s_multihead_attention[layer_id]);
  layer->set_cur_batch_shape(input.size(0), input.size(1));
  layer->SetTrainingMode(training_mode);

  layer->_attn_qkvw_ptr = (const T *)in_proj_weight.data_ptr();
  layer->_attn_qkvb_ptr = (const T *)in_proj_bias.data_ptr();
  layer->_attn_ow_ptr = (const T *)out_proj_weight.data_ptr();
  layer->_attn_ob_ptr = (const T *)out_proj_bias.data_ptr();
  layer->_attn_nw_ptr = (const T *)norm_weight.data_ptr();
  layer->_attn_nb_ptr = (const T *)norm_bias.data_ptr();

  layer->Forward(input_ptr, input_mask_ptr, out_ptr);

  return {output};
}
template <typename T>
std::vector<torch::Tensor> multihead_attention_bw(
    int layer_id, const torch::Tensor &grad_dec_output, const torch::Tensor &output,
    const torch::Tensor &input, const torch::Tensor &input_mask,
    const torch::Tensor &in_proj_weight, const torch::Tensor &in_proj_bias,
    const torch::Tensor &out_proj_weight, const torch::Tensor &out_proj_bias,
    const torch::Tensor &norm_weight, const torch::Tensor &norm_bias) {
  auto g_output = grad_dec_output.contiguous();
  CHECK_INPUT(g_output);
  CHECK_INPUT(output);
  CHECK_INPUT(input);
  CHECK_INPUT(input_mask);

  auto grad_input = torch::empty_like(input);
  auto grad_in_proj_weight = torch::empty_like(in_proj_weight);
  auto grad_in_proj_bias = torch::empty_like(in_proj_bias);
  auto grad_out_proj_weight = torch::empty_like(out_proj_weight);
  auto grad_out_proj_bias = torch::empty_like(out_proj_bias);
  auto grad_norm_weight = torch::empty_like(norm_weight);
  auto grad_norm_bias = torch::empty_like(norm_bias);

  // inputs.
  const T *grad_dec_output_ptr = (const T *)g_output.data_ptr();
  const T *input_ptr = (const T *)input.data_ptr();
  const T *output_ptr = (const T *)output.data_ptr();
  const T *input_mask_ptr = (const T *)input_mask.data_ptr();

  // outputs.
  T *grad_input_ptr = (T *)grad_input.data_ptr();

  std::shared_ptr<MultiHeadAttention<T>> layer =
      std::static_pointer_cast<MultiHeadAttention<T>>(s_multihead_attention[layer_id]);
  layer->set_cur_batch_shape(g_output.size(0), g_output.size(1));

  layer->_grad_attn_qkvw_ptr = (T *)grad_in_proj_weight.data_ptr();
  layer->_grad_attn_qkvb_ptr = (T *)grad_in_proj_bias.data_ptr();
  layer->_grad_attn_ow_ptr = (T *)grad_out_proj_weight.data_ptr();
  layer->_grad_attn_ob_ptr = (T *)grad_out_proj_bias.data_ptr();
  layer->_grad_attn_nw_ptr = (T *)grad_norm_weight.data_ptr();
  layer->_grad_attn_nb_ptr = (T *)grad_norm_bias.data_ptr();

  layer->Backward(grad_dec_output_ptr, input_ptr, output_ptr, input_mask_ptr, grad_input_ptr);

  return {grad_input,         grad_in_proj_weight,  grad_in_proj_bias, grad_out_proj_weight,
          grad_out_proj_bias, grad_norm_weight,     grad_norm_bias};
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("multihead_attention_fw_fp32", &multihead_attention_fw<float>,
        "Multi-head Attention forward with fp32 (CUDA)");
  m.def("multihead_attention_fw_fp16", &multihead_attention_fw<__half>,
        "Multi-head Attention forward with fp16 (CUDA)");
  m.def("multihead_attention_bw_fp32", &multihead_attention_bw<float>,
        "Multi-head Attention backward with fp32 (CUDA)");
  m.def("multihead_attention_bw_fp16", &multihead_attention_bw<__half>,
        "Multi-head Attention backward with fp16 (CUDA)");
  m.def("create_multihead_attention_fp32", &create_multihead_attention<float>,
        "Create Multi-head Attention with fp32 (CUDA)");
  m.def("create_multihead_attention_fp16", &create_multihead_attention<__half>,
        "Create Multi-head Attention with fp16 (CUDA)");
}
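For orientation, a minimal host-side call sequence against the three bindings above might look like the sketch below; the layer id, shapes and dropout ratios are made-up values, and the process group `pg` and the CUDA tensors are assumed to be prepared by the caller.

// Sketch only: all names and sizes below are assumptions, not part of the commit.
create_multihead_attention<float>(/*layer_id=*/0, /*max_batch_tokens=*/4096, /*max_seq_len=*/512,
                                  /*hidden_dim=*/1024, /*num_heads=*/16,
                                  /*attn_prob_dropout_ratio=*/0.1f, /*hidden_dropout_ratio=*/0.1f,
                                  /*pre_or_postLayerNorm=*/true, pg);
auto out = multihead_attention_fw<float>(0, input, input_mask, in_proj_weight, in_proj_bias,
                                         out_proj_weight, out_proj_bias, norm_weight, norm_bias,
                                         /*training_mode=*/true, /*prelayernorm=*/true);
auto grads = multihead_attention_bw<float>(0, grad_out, out[0], input, input_mask, in_proj_weight,
                                           in_proj_bias, out_proj_weight, out_proj_bias,
                                           norm_weight, norm_bias);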
colossalai/kernel/cuda_native/csrc/multihead_attention_1d.h
#pragma once

#include <c10/util/intrusive_ptr.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime_api.h>

#include <c10d/ProcessGroup.hpp>
#include <string>
#include <type_traits>

#include "cuda_util.h"
#include "dropout.h"
#include "feed_forward.h"
#include "normalize_layer.h"
#include "softmax.h"
#include "strided_batch_gemm.h"

template <typename T>
class MultiHeadAttention {
 public:
  MultiHeadAttention(int layer_id, int max_batch_tokens, int _max_seq_len, int hidden_size,
                     int num_heads, float attn_dropout_ratio, float hidden_output_dropout_ratio,
                     bool pre_or_postLayerNorm);

  virtual ~MultiHeadAttention();

  void Forward(const T *input_ptr, const T *input_mask_ptr, T *out_ptr);

  void Backward(const T *grad_output_ptr, const T *input_ptr, const T *output_ptr,
                const T *input_mask_ptr, T *grad_input_ptr);

  void attn_layer_fw(const T *input_ptr, const T *input_mask_ptr, T *output_ptr, T *buffer);

  void attn_layer_bw(const T *input_ptr, const T *input_mask_ptr, const T *output_ptr,
                     const T *grad_output_ptr, T *grad_input_attn_layer_bwptr, T *buffer);

  void set_cur_batch_shape(int batch_size, int seq_len) {
    _batch_size = batch_size;
    _seq_len = seq_len;
    _batch_tokens = batch_size * seq_len;
    _batch_heads = batch_size * _heads / pg_size;
    _batch_dim = _batch_tokens * _hidden_size;
    _attn_scores.SetConfig(_seq_len, _seq_len, _hidden_size / _heads);
    _attn_context.SetConfig(_hidden_size / _heads, _seq_len, _seq_len);
  }

  void SetTrainingMode(bool training);
  inline bool IsTrainingMode() const { return _training; }

  void SetPG(c10::intrusive_ptr<c10d::ProcessGroup> pg_) {
    pg = pg_;
    pg_size = 1;
    if (pg != c10::detail::UniqueVoidPtr()) {
      pg_size = pg->getSize();
    }
    allocate_mem_buffer();
  }

  // weights ptr
  const T *_attn_qkvw_ptr;
  const T *_attn_qkvb_ptr;
  const T *_attn_ow_ptr;
  const T *_attn_ob_ptr;
  const T *_attn_nw_ptr;
  const T *_attn_nb_ptr;

  // grads ptr
  T *_grad_attn_qkvw_ptr;
  T *_grad_attn_qkvb_ptr;
  T *_grad_attn_ow_ptr;
  T *_grad_attn_ob_ptr;
  T *_grad_attn_nw_ptr;
  T *_grad_attn_nb_ptr;
 private:
  void allocate_mem_buffer() {
    // allocate local gpu memory
    if (_pre_or_postLayerNorm) {
      _gemmQKV_inp_ptr = cuda_malloc<T>(_max_batch_tokens * _hidden_size);
    } else {
      _gemmQKV_inp_ptr = nullptr;
    }

    _qkv_ptr = cuda_malloc<T>(_max_batch_tokens * _hidden_size * 3);
    _soft_out_ptr = cuda_malloc<T>(_max_batch_tokens * _heads / pg_size * _max_seq_len);
    _ctx_bufB_ptr = cuda_malloc<T>(_max_batch_tokens * _heads / pg_size * _max_seq_len);
    _attn_o_inp_ptr = cuda_malloc<T>(_max_batch_tokens * _hidden_size);

    // buffer size needed by attn bw
    size_t smem_size = 4 * _max_batch_tokens * _hidden_size / pg_size +
                       std::max(3 * _max_batch_tokens * _hidden_size / pg_size,
                                _max_batch_tokens * _heads / pg_size * _max_seq_len);

    if (!_shared_mem_ptr) {
      cuda_free(_shared_mem_ptr);
      _shared_mem_ptr = cuda_malloc<T>(smem_size);
    }
  }

  void free_mem_buffer() {
    // free local gpu memory
    cuda_free(_gemmQKV_inp_ptr);
    cuda_free(_qkv_ptr);
    cuda_free(_soft_out_ptr);
    cuda_free(_ctx_bufB_ptr);
    cuda_free(_attn_o_inp_ptr);

    // free shared gpu memory between layers
    cuda_free(_shared_mem_ptr);
    _shared_mem_ptr = nullptr;
  }

  // const parameter between batch
  const size_t _layer_id;
  const size_t _hidden_size;
  const size_t _heads;
  const size_t _max_batch_tokens;
  const size_t _max_seq_len;
  const bool _pre_or_postLayerNorm;

  // dynamic parameter between batch
  size_t _batch_size;
  size_t _seq_len;
  size_t _batch_tokens;
  size_t _batch_heads;
  size_t _batch_dim;
  bool _training;

  cublasHandle_t _cublasHandle;
  cudaStream_t _stream;

  // layers
  FeedForward<T> _qkv_linear;
  FeedForward<T> _attn_out_linear;
  Normalize_Layer<T> _attn_ln;
  Softmax<T> _softmax;
  Dropout<T> _attn_prob_dropout;
  Dropout<T> _attn_dropout;
  StridedBatchGemm<T> _attn_scores;
  StridedBatchGemm<T> _attn_context;

  // local GPU memory
  T *_gemmQKV_inp_ptr;
  T *_qkv_ptr;
  T *_soft_out_ptr;
  T *_ctx_bufB_ptr;
  T *_attn_o_inp_ptr;
  // shared GPU memory between layer
  static T *_shared_mem_ptr;

  c10::intrusive_ptr<c10d::ProcessGroup> pg;
  int pg_size;
};
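A note on the pg_size divisions above (illustrative arithmetic with assumed sizes, reflecting the tensor-parallel split of heads across the process group):

// Hypothetical sizes, for illustration only: _hidden_size = 1024, _heads = 16, pg_size = 2,
// batch_size = 8, seq_len = 512.
//   _batch_tokens = 8 * 512      = 4096
//   _batch_heads  = 8 * 16 / 2   = 64          // each rank keeps 8 of the 16 heads
//   _batch_dim    = 4096 * 1024  = 4,194,304
// so the per-rank score/context buffers sized in allocate_mem_buffer() cover 64 (batch, head)
// pairs instead of the 128 a single-GPU layer would need.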
colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.cpp
/* This code is from NVIDIA Megatron:
 *  with minor changes. */

#include <cuda_fp16.h>
#include <torch/extension.h>
#include <vector>

namespace multihead_attn {
namespace fused_softmax {
namespace scaled_masked_softmax {

torch::Tensor fwd_cuda(torch::Tensor const &input, torch::Tensor const &mask, float scale_factor);

torch::Tensor bwd_cuda(torch::Tensor const &output_grads, torch::Tensor const &softmax_results,
                       float scale_factor);

int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, int attn_heads);

torch::Tensor fwd(torch::Tensor const &input, torch::Tensor const &mask, float scale_factor) {
  AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
  AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
                 (input.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");
  AT_ASSERTM(mask.dim() == 4, "expected 4D tensor");

  return fwd_cuda(input, mask, scale_factor);
}

torch::Tensor bwd(torch::Tensor const &output_grads, torch::Tensor const &softmax_results,
                  float scale_factor) {
  AT_ASSERTM(output_grads.dim() == 4, "expected 4D tensor");
  AT_ASSERTM(softmax_results.dim() == 4, "expected 4D tensor");
  AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
                 (output_grads.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");
  AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
                 (softmax_results.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");

  return bwd_cuda(output_grads, softmax_results, scale_factor);
}

int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int attn_heads) {
  return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches, attn_heads);
}

}  // end namespace scaled_masked_softmax
}  // end namespace fused_softmax
}  // end namespace multihead_attn

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &multihead_attn::fused_softmax::scaled_masked_softmax::fwd,
        "Self Multihead Attention scaled, time masked softmax -- Forward.");
  m.def("backward", &multihead_attn::fused_softmax::scaled_masked_softmax::bwd,
        "Self Multihead Attention scaled, time masked softmax -- Backward.");
  m.def("get_batch_per_block",
        &multihead_attn::fused_softmax::scaled_masked_softmax::get_batch_per_block,
        "Return Batch per block size.");
}
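As a usage sketch (tensor names and sizes are assumed for illustration, not from this commit), the forward binding expects 4D fp16/bf16 attention scores and a 4D uint8 padding mask broadcast over heads:

// Sketch: scores are [batches, attn_heads, query_seq_len, key_seq_len];
// the mask is [batches (or 1), 1, query_seq_len, key_seq_len], with 1 meaning "masked out".
auto scores = torch::randn({8, 16, 512, 512}, torch::dtype(torch::kHalf).device(torch::kCUDA));
auto mask = torch::zeros({8, 1, 512, 512}, torch::dtype(torch::kUInt8).device(torch::kCUDA));
auto probs = multihead_attn::fused_softmax::scaled_masked_softmax::fwd(scores, mask, 1.0f);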
colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.h
/* This code is from NVIDIA Megatron:
 *  with minor changes. */

#pragma once

#include <assert.h>
#include <cuda_fp16.h>
#include <cfloat>
#include <limits>
#include <stdint.h>
#include <c10/macros/Macros.h>

namespace {

template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);

template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 1>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; }

template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 4>(c10::BFloat16 *dst, const c10::BFloat16 *src) {
  *((float2 *)dst) = *((float2 *)src);
}

template <>
__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst, const c10::Half *src) { *dst = *src; }

template <>
__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst, const c10::Half *src) {
  *((float2 *)dst) = *((float2 *)src);
}

template <>
__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; }

template <>
__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *src) {
  *((half2 *)dst) = *((half2 *)src);
}

int log2_ceil(int value) {
  int log2_value = 0;
  while ((1 << log2_value) < value) ++log2_value;
  return log2_value;
}
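// Illustrative values: log2_ceil(1) == 0, log2_ceil(32) == 5, log2_ceil(600) == 10,
// log2_ceil(2048) == 11 -- i.e. the exponent of the next power of two at or above the input.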
template <typename T>
struct Add {
  __device__ __forceinline__ T operator()(T a, T b) const { return a + b; }
};

template <typename T>
struct Max {
  __device__ __forceinline__ T operator()(T a, T b) const { return a < b ? b : a; }
};

template <typename T>
__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize,
                                                  unsigned int mask = 0xffffffff) {
#if CUDA_VERSION >= 9000
  return __shfl_xor_sync(mask, value, laneMask, width);
#else
  return __shfl_xor(value, laneMask, width);
#endif
}

template <typename acc_t, int WARP_BATCH, int WARP_SIZE, template <typename> class ReduceOp>
__device__ __forceinline__ void warp_reduce(acc_t *sum) {
  ReduceOp<acc_t> r;
#pragma unroll
  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
      acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE);
      sum[i] = r(sum[i], b);
    }
  }
}
/*
 * Extended softmax (from native aten pytorch) with following additional features
 * 1) input scaling
 * 2) Explicit masking
 */
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_masked_softmax_warp_forward(output_t *dst, const input_t *src,
                                                   const uint8_t *mask, const acc_t scale,
                                                   int micro_batch_size, int element_count,
                                                   int pad_batches) {
  // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
  // warp_size of method warp_softmax_forward_kernel.
  constexpr int next_power_of_two = 1 << log2_elements;
  constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
  constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
  constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
  constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;

  // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
  // gridDim/blockIdx = (seq_len, attn_heads, batches)
  int first_batch =
      (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z)) + threadIdx.y) *
      WARP_BATCH;
  int pad_first_batch = 0;
  if (pad_batches != 1) {
    // bert style
    pad_first_batch = (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) * WARP_BATCH;
  } else {
    // gpt2 style
    pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
  }

  // micro_batch_size might not be a multiple of WARP_BATCH. Check how
  // many batches have to be computed within this WARP.
  int local_batches = micro_batch_size - first_batch;
  if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;

  // there might be multiple batches per warp. compute the index within the batch
  int local_idx = threadIdx.x;

  src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
  dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
  mask += pad_first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;

  // load data from global memory
  acc_t elements[WARP_BATCH][WARP_ITERATIONS];
  input_t temp_data[ELEMENTS_PER_LDG_STG];
  uint8_t temp_mask[ELEMENTS_PER_LDG_STG];
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    int batch_element_count = (i >= local_batches) ? 0 : element_count;
#pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
      if (element_index < batch_element_count) {
        int itr_idx = i * element_count + it * WARP_SIZE;
        copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + itr_idx);
        copy_vector<uint8_t, ELEMENTS_PER_LDG_STG>(temp_mask, mask + itr_idx);
#pragma unroll
        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
          if (temp_mask[element] != 1) {
            elements[i][it + element] = (acc_t)temp_data[element] * scale;
          } else {
            elements[i][it + element] = -10000.0;
          }
        }
      } else {
#pragma unroll
        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
          elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
        }
      }
    }
  }

  // compute max_value
  acc_t max_value[WARP_BATCH];
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    max_value[i] = elements[i][0];
#pragma unroll
    for (int it = 1; it < WARP_ITERATIONS; ++it) {
      max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
    }
  }
  warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);

  acc_t sum[WARP_BATCH]{0.0f};
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
#pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; ++it) {
      elements[i][it] = std::exp((elements[i][it] - max_value[i]));
      sum[i] += elements[i][it];
    }
  }
  warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);

  // store result
  output_t out[ELEMENTS_PER_LDG_STG];
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    if (i >= local_batches) break;
#pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
      if (element_index < element_count) {
#pragma unroll
        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
          out[element] = elements[i][it + element] / sum[i];
        }
        copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);
      } else {
        break;
      }
    }
  }
}
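// Note on the mask convention in the kernel above (illustrative values): a mask byte equal to 1
// excludes the position by substituting -10000 before the max/softmax, so it contributes (almost)
// zero probability; positions beyond element_count get -infinity. For example, with scale = 0.125
// and a row [2, 4, masked, 6], the softmax is taken over [0.25, 0.5, -10000, 0.75].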
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_masked_softmax_warp_backward(output_t *gradInput, input_t *grad,
                                                    const input_t *output, acc_t scale,
                                                    int micro_batch_size, int element_count) {
  // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
  // warp_size of method warp_softmax_backward_kernel.
  constexpr int next_power_of_two = 1 << log2_elements;
  constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
  constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
  constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
  constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;

  // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
  // gridDim/blockIdx = (seq_len, attn_heads, batches)
  int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;

  // micro_batch_size might not be a multiple of WARP_BATCH. Check how
  // many batches have to be computed within this WARP.
  int local_batches = micro_batch_size - first_batch;
  if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;

  // there might be multiple batches per warp. compute the index within the batch
  int local_idx = threadIdx.x;

  // the first element to process by the current thread
  int thread_offset = first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
  grad += thread_offset;
  output += thread_offset;
  gradInput += thread_offset;

  // load data from global memory
  acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f};
  acc_t output_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f};
  input_t temp_grad[ELEMENTS_PER_LDG_STG];
  input_t temp_output[ELEMENTS_PER_LDG_STG];
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    int batch_element_count = (i >= local_batches) ? 0 : element_count;
#pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
      if (element_index < batch_element_count) {
        copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_grad, grad + i * element_count + it * WARP_SIZE);
        copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_output, output + i * element_count + it * WARP_SIZE);
#pragma unroll
        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
          output_reg[i][it + element] = (acc_t)temp_output[element];
        }
#pragma unroll
        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
          grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element];
        }
      }
    }
  }

  acc_t sum[WARP_BATCH];
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    sum[i] = grad_reg[i][0];
#pragma unroll
    for (int it = 1; it < WARP_ITERATIONS; ++it) {
      sum[i] += grad_reg[i][it];
    }
  }
  warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);

  // store result
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    if (i >= local_batches) break;
#pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
      if (element_index < element_count) {
        // compute gradients
        output_t out[ELEMENTS_PER_LDG_STG];
#pragma unroll
        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
          out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i]));
        }
        copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + i * element_count + it * WARP_SIZE, out);
      }
    }
  }
}
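// The backward kernel above is the usual softmax Jacobian-vector product with the extra input
// scaling: per row, grad_input_j = scale * (y_j * grad_j - y_j * sum_k(y_k * grad_k)), where y is
// the saved softmax output. grad_reg already holds y_j * grad_j, so the stored value is
// out = scale * (grad_reg - output_reg * sum).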
}  // end of anonymous namespace
int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int attn_heads) {
  int log2_elements = log2_ceil(key_seq_len);
  const int next_power_of_two = 1 << log2_elements;

  int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
  int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

  constexpr int threads_per_block = 128;
  int warps_per_block = (threads_per_block / warp_size);
  int batches_per_block = warps_per_block * batches_per_warp;

  return batches_per_block;
}
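// Worked example (assumed key_seq_len, for clarity): key_seq_len = 600 gives log2_elements = 10,
// so next_power_of_two = 1024 > 128 => batches_per_warp = 1, warp_size = C10_WARP_SIZE = 32,
// warps_per_block = 128 / 32 = 4, and get_batch_per_block returns 4 * 1 = 4.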
template <typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_masked_softmax_forward(output_t *dst, const input_t *src, const uint8_t *mask,
                                            const input_t scale, int query_seq_len, int key_seq_len,
                                            int batches, int attn_heads, int pad_batches) {
  TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048);
  if (key_seq_len == 0) {
    return;
  } else {
    int log2_elements = log2_ceil(key_seq_len);
    const int next_power_of_two = 1 << log2_elements;
    int batch_count = batches * attn_heads * query_seq_len;

    // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
    int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.
    int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
    // use 128 threads per block to maximize gpu utilization
    constexpr int threads_per_block = 128;

    int warps_per_block = (threads_per_block / warp_size);
    int batches_per_block = warps_per_block * batches_per_warp;
    TORCH_INTERNAL_ASSERT(query_seq_len % batches_per_block == 0);
    dim3 blocks(query_seq_len / batches_per_block, attn_heads, batches);
    dim3 threads(warp_size, warps_per_block, 1);
    auto stream = at::cuda::getCurrentCUDAStream();
    // Launch code would be more elegant if C++ supported FOR CONSTEXPR
    switch (log2_elements) {
      case 0:  // 1
        scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 0><<<blocks, threads, 0, stream>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); break;
      case 1:  // 2
        scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 1><<<blocks, threads, 0, stream>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); break;
      case 2:  // 4
        scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 2><<<blocks, threads, 0, stream>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); break;
      case 3:  // 8
        scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 3><<<blocks, threads, 0, stream>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); break;
      case 4:  // 16
        scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 4><<<blocks, threads, 0, stream>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); break;
      case 5:  // 32
        scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 5><<<blocks, threads, 0, stream>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); break;
      case 6:  // 64
        scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 6><<<blocks, threads, 0, stream>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); break;
      case 7:  // 128
        scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 7><<<blocks, threads, 0, stream>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); break;
      case 8:  // 256
        scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 8><<<blocks, threads, 0, stream>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); break;
      case 9:  // 512
        scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 9><<<blocks, threads, 0, stream>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); break;
      case 10:  // 1024
        scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 10><<<blocks, threads, 0, stream>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); break;
      case 11:  // 2048
        scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 11><<<blocks, threads, 0, stream>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); break;
      default:
        break;
    }
  }
}
template <typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_masked_softmax_backward(output_t *grad_input, input_t *grad,
                                             const input_t *output, const acc_t scale,
                                             int query_seq_len, int key_seq_len, int batches,
                                             int attn_heads) {
  TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048);
  if (key_seq_len == 0) {
    return;
  } else {
    int log2_elements = log2_ceil(key_seq_len);
    const int next_power_of_two = 1 << log2_elements;
    int batch_count = batches * attn_heads * query_seq_len;

    // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.
    int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward.
    int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
    // use 128 threads per block to maximize gpu utilization
    constexpr int threads_per_block = 128;

    int warps_per_block = (threads_per_block / warp_size);
    int batches_per_block = warps_per_block * batches_per_warp;
    int blocks = batch_count / batches_per_block;
    dim3 threads(warp_size, warps_per_block, 1);
    auto stream = at::cuda::getCurrentCUDAStream();
    // Launch code would be more elegant if C++ supported FOR CONSTEXPR
    switch (log2_elements) {
      case 0:  // 1
        scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 0><<<blocks, threads, 0, stream>>>(grad_input, grad, output, scale, batch_count, key_seq_len); break;
      case 1:  // 2
        scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 1><<<blocks, threads, 0, stream>>>(grad_input, grad, output, scale, batch_count, key_seq_len); break;
      case 2:  // 4
        scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 2><<<blocks, threads, 0, stream>>>(grad_input, grad, output, scale, batch_count, key_seq_len); break;
      case 3:  // 8
        scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 3><<<blocks, threads, 0, stream>>>(grad_input, grad, output, scale, batch_count, key_seq_len); break;
      case 4:  // 16
        scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 4><<<blocks, threads, 0, stream>>>(grad_input, grad, output, scale, batch_count, key_seq_len); break;
      case 5:  // 32
        scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 5><<<blocks, threads, 0, stream>>>(grad_input, grad, output, scale, batch_count, key_seq_len); break;
      case 6:  // 64
        scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 6><<<blocks, threads, 0, stream>>>(grad_input, grad, output, scale, batch_count, key_seq_len); break;
      case 7:  // 128
        scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 7><<<blocks, threads, 0, stream>>>(grad_input, grad, output, scale, batch_count, key_seq_len); break;
      case 8:  // 256
        scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 8><<<blocks, threads, 0, stream>>>(grad_input, grad, output, scale, batch_count, key_seq_len); break;
      case 9:  // 512
        scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 9><<<blocks, threads, 0, stream>>>(grad_input, grad, output, scale, batch_count, key_seq_len); break;
      case 10:  // 1024
        scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 10><<<blocks, threads, 0, stream>>>(grad_input, grad, output, scale, batch_count, key_seq_len); break;
      case 11:  // 2048
        scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 11><<<blocks, threads, 0, stream>>>(grad_input, grad, output, scale, batch_count, key_seq_len); break;
      default:
        break;
    }
  }
}
colossalai/kernel/cuda_native/csrc/scaled_masked_softmax_cuda.cu
/* This code is from NVIDIA Megatron:
 *  with minor changes. */

#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

#include "scaled_masked_softmax.h"
#include "type_shim.h"

namespace multihead_attn {
namespace fused_softmax {
namespace scaled_masked_softmax {

int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, int attn_heads) {
  return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads);
}

torch::Tensor fwd_cuda(torch::Tensor const &input, torch::Tensor const &mask, float scale_factor) {
  // input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
  const int batches = input.size(0);
  const int pad_batches = mask.size(0);
  const int attn_heads = input.size(1);
  const int query_seq_len = input.size(2);
  const int key_seq_len = input.size(3);
  TORCH_INTERNAL_ASSERT(key_seq_len <= 2048);
  TORCH_INTERNAL_ASSERT(query_seq_len > 1);
  TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches);
  TORCH_INTERNAL_ASSERT(mask.size(1) == 1);
  TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len);
  TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len);

  // Output
  auto act_options = input.options().requires_grad(false);
  torch::Tensor softmax_results =
      torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);

  // Softmax Intermediate Result Ptr
  void *input_ptr = static_cast<void *>(input.data_ptr());
  void *mask_ptr = static_cast<void *>(mask.data_ptr());
  void *softmax_results_ptr = static_cast<void *>(softmax_results.data_ptr());

  DISPATCH_HALF_AND_BFLOAT(
      input.scalar_type(), "dispatch_scaled_masked_softmax_forward",
      dispatch_scaled_masked_softmax_forward<scalar_t, scalar_t, float>(
          reinterpret_cast<scalar_t *>(softmax_results_ptr),
          reinterpret_cast<const scalar_t *>(input_ptr),
          reinterpret_cast<const uint8_t *>(mask_ptr), scale_factor, query_seq_len, key_seq_len,
          batches, attn_heads, pad_batches););

  return softmax_results;
}

torch::Tensor bwd_cuda(torch::Tensor const &output_grads_, torch::Tensor const &softmax_results_,
                       float scale_factor) {
  auto output_grads = output_grads_.contiguous();
  auto softmax_results = softmax_results_.contiguous();

  // output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
  const int batches = output_grads.size(0);
  const int attn_heads = output_grads.size(1);
  const int query_seq_len = output_grads.size(2);
  const int key_seq_len = output_grads.size(3);

  void *output_grads_ptr = static_cast<void *>(output_grads.data_ptr());

  // Softmax Grad
  DISPATCH_HALF_AND_BFLOAT(
      output_grads_.scalar_type(), "dispatch_scaled_masked_softmax_backward",
      dispatch_scaled_masked_softmax_backward<scalar_t, scalar_t, float>(
          reinterpret_cast<scalar_t *>(output_grads_ptr),
          reinterpret_cast<scalar_t *>(output_grads_ptr),
          reinterpret_cast<scalar_t const *>(softmax_results.data_ptr()), scale_factor,
          query_seq_len, key_seq_len, batches, attn_heads););

  // backward pass is completely in-place
  return output_grads;
}

}  // namespace scaled_masked_softmax
}  // namespace fused_softmax
}  // namespace multihead_attn
colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax.cpp
/* This code is from NVIDIA Megatron:
 *  with minor changes. */

#include <cuda_fp16.h>
#include <torch/extension.h>
#include <vector>

namespace multihead_attn {
namespace fused_softmax {
namespace scaled_upper_triang_masked_softmax {

torch::Tensor fwd_cuda(torch::Tensor const &input, float scale_factor);

torch::Tensor bwd_cuda(torch::Tensor const &output_grads, torch::Tensor const &softmax_results,
                       float scale_factor);

torch::Tensor fwd(torch::Tensor const &input, float scale_factor) {
  AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
  AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
                 (input.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");

  return fwd_cuda(input, scale_factor);
}

torch::Tensor bwd(torch::Tensor const &output_grads, torch::Tensor const &softmax_results,
                  float scale_factor) {
  AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
                 (output_grads.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");
  AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
                 (softmax_results.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");

  return bwd_cuda(output_grads, softmax_results, scale_factor);
}

}  // end namespace scaled_upper_triang_masked_softmax
}  // end namespace fused_softmax
}  // end namespace multihead_attn

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd,
        "Self Multihead Attention scaled, time masked softmax -- Forward.");
  m.def("backward", &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd,
        "Self Multihead Attention scaled, time masked softmax -- Backward.");
}
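A usage sketch for this variant (sizes invented for illustration): the upper-triangular softmax takes no explicit mask, since causal masking is implied by the row index, and the input is 3D.

// Sketch: scores are [attn_batches, seq_len, seq_len] in fp16/bf16; attn_batches must be a
// multiple of the per-block batch count computed by the dispatch function.
auto scores = torch::randn({128, 512, 512}, torch::dtype(torch::kHalf).device(torch::kCUDA));
auto probs = multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd(scores, 1.0f);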
colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax.h
/* This code is from NVIDIA Megatron:
 *  with minor changes. */

#pragma once

#include <assert.h>
#include <cuda_fp16.h>
#include <cfloat>
#include <limits>
#include <stdint.h>
#include <c10/macros/Macros.h>

namespace {

template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);

template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 1>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; }

template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 4>(c10::BFloat16 *dst, const c10::BFloat16 *src) {
  *((float2 *)dst) = *((float2 *)src);
}

template <>
__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst, const c10::Half *src) { *dst = *src; }

template <>
__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst, const c10::Half *src) {
  *((float2 *)dst) = *((float2 *)src);
}

template <>
__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; }

template <>
__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *src) {
  *((half2 *)dst) = *((half2 *)src);
}

template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_zero_vector(Datatype *dst);

template <>
__device__ __inline__ void copy_zero_vector<c10::BFloat16, 1>(c10::BFloat16 *dst) { *dst = 0.0; }

template <>
__device__ __inline__ void copy_zero_vector<c10::BFloat16, 4>(c10::BFloat16 *dst) {
  *((float2 *)dst) = make_float2(0.0f, 0.0f);
}

template <>
__device__ __inline__ void copy_zero_vector<c10::Half, 1>(c10::Half *dst) { *dst = 0.0; }

template <>
__device__ __inline__ void copy_zero_vector<c10::Half, 4>(c10::Half *dst) {
  *((float2 *)dst) = make_float2(0.0f, 0.0f);
}

int log2_ceil(int value) {
  int log2_value = 0;
  while ((1 << log2_value) < value) ++log2_value;
  return log2_value;
}
template <typename T>
struct Add {
  __device__ __forceinline__ T operator()(T a, T b) const { return a + b; }
};

template <typename T>
struct Max {
  __device__ __forceinline__ T operator()(T a, T b) const { return a < b ? b : a; }
};

template <typename T>
__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize,
                                                  unsigned int mask = 0xffffffff) {
#if CUDA_VERSION >= 9000
  return __shfl_xor_sync(mask, value, laneMask, width);
#else
  return __shfl_xor(value, laneMask, width);
#endif
}

template <typename acc_t, int WARP_BATCH, int WARP_SIZE, template <typename> class ReduceOp>
__device__ __forceinline__ void warp_reduce(acc_t *sum) {
  ReduceOp<acc_t> r;
#pragma unroll
  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
      acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE);
      sum[i] = r(sum[i], b);
    }
  }
}
/*
 * Extended softmax (from native aten pytorch) with following additional features
 * 1) input scaling
 * 2) Implicit time (diagonal masking)
 */
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_upper_triang_masked_softmax_warp_forward(output_t *dst, const input_t *src,
                                                                const acc_t scale,
                                                                int micro_batch_size, int stride,
                                                                int element_count) {
  // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
  // warp_size of method warp_softmax_forward_kernel.
  constexpr int next_power_of_two = 1 << log2_elements;
  constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
  constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
  constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
  constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;

  int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
  int local_seq = blockIdx.x + 1;
  int warp_iteration_limit = (local_seq + ELEMENTS_PER_LDG_STG * WARP_SIZE - 1) / WARP_SIZE;

  // micro_batch_size might not be a multiple of WARP_BATCH. Check how
  // many batches have to be computed within this WARP.
  int local_batches = micro_batch_size - first_batch;
  if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;

  // there might be multiple batches per warp. compute the index within the batch
  int local_idx = threadIdx.x;

  src += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
  dst += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;

  // load data from global memory
  acc_t elements[WARP_BATCH][WARP_ITERATIONS];
  input_t temp_data[ELEMENTS_PER_LDG_STG];
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    int batch_element_count = (i >= local_batches) ? 0 : local_seq;
#pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
      if (element_index < batch_element_count) {
        copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + i * element_count * stride + it * WARP_SIZE);
#pragma unroll
        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
          if ((element_index + element) < batch_element_count) {
            elements[i][it + element] = (acc_t)temp_data[element] * scale;
          } else {
            elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
          }
        }
      } else {
#pragma unroll
        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
          elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
        }
      }
    }
  }

  // compute max_value
  acc_t max_value[WARP_BATCH];
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    max_value[i] = elements[i][0];
#pragma unroll
    for (int it = 1; it < WARP_ITERATIONS; ++it) {
      max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
    }
  }
  warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);

  acc_t sum[WARP_BATCH]{0.0f};
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
#pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; ++it) {
      if (it < warp_iteration_limit) {
        elements[i][it] = std::exp((elements[i][it] - max_value[i]));
        sum[i] += elements[i][it];
      }
    }
  }
  warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);

  // store result
  output_t out[ELEMENTS_PER_LDG_STG];
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    if (i >= local_batches) break;
#pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
      if (element_index < local_seq) {
#pragma unroll
        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
          if (element_index + element < local_seq) {
            out[element] = elements[i][it + element] / sum[i];
          } else {
            out[element] = 0;
          }
        }
        copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count * stride + it * WARP_SIZE, out);
      } else if (element_index < element_count) {
        copy_zero_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count * stride + it * WARP_SIZE);
      } else {
        break;
      }
    }
  }
}
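// Note on the kernel above: local_seq = blockIdx.x + 1, so the row handled by this block only
// loads its first (row index + 1) keys. Columns past the current row (the upper triangle) are
// never read; they are treated as -infinity for the softmax and written out as exact zeros via
// copy_zero_vector, which is what makes the causal mask implicit.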
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_upper_triang_masked_softmax_warp_backward(output_t *gradInput, input_t *grad,
                                                                 const input_t *output, acc_t scale,
                                                                 int micro_batch_size, int stride,
                                                                 int element_count) {
  // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
  // warp_size of method warp_softmax_backward_kernel.
  constexpr int next_power_of_two = 1 << log2_elements;
  constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
  constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
  constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
  constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;

  int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
  int local_seq = blockIdx.x + 1;

  // micro_batch_size might not be a multiple of WARP_BATCH. Check how
  // many batches have to be computed within this WARP.
  int local_batches = micro_batch_size - first_batch;
  if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;

  // there might be multiple batches per warp. compute the index within the batch
  int local_idx = threadIdx.x;

  // the first element to process by the current thread
  int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
  grad += thread_offset;
  output += thread_offset;
  gradInput += thread_offset;

  // load data from global memory
  acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f};
  acc_t output_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f};
  input_t temp_grad[ELEMENTS_PER_LDG_STG];
  input_t temp_output[ELEMENTS_PER_LDG_STG];
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    int batch_element_count = (i >= local_batches) ? 0 : local_seq;
#pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
      if (element_index < batch_element_count) {
        copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_grad, grad + i * element_count * stride + it * WARP_SIZE);
        copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_output, output + i * element_count * stride + it * WARP_SIZE);
#pragma unroll
        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
          if (element_index + element < batch_element_count) {
            output_reg[i][it + element] = (acc_t)temp_output[element];
          }
        }
#pragma unroll
        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
          if (element_index + element < batch_element_count) {
            grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element];
          }
        }
      }
    }
  }

  acc_t sum[WARP_BATCH];
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    sum[i] = grad_reg[i][0];
#pragma unroll
    for (int it = 1; it < WARP_ITERATIONS; ++it) {
      sum[i] += grad_reg[i][it];
    }
  }
  warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);

  // store result
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    if (i >= local_batches) break;
#pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
      if (element_index < element_count) {
        // compute gradients
        output_t out[ELEMENTS_PER_LDG_STG];
#pragma unroll
        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
          out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i]));
        }
        copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + i * element_count * stride + it * WARP_SIZE, out);
      }
    }
  }
}

}  // end of anonymous namespace
template <typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_upper_triang_masked_softmax_forward(output_t *dst, const input_t *src,
                                                         const input_t scale, int softmax_elements,
                                                         int softmax_elements_stride,
                                                         int attn_batches) {
  TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048);
  if (softmax_elements == 0) {
    return;
  } else {
    int log2_elements = log2_ceil(softmax_elements);
    const int next_power_of_two = 1 << log2_elements;
    int seq_len = softmax_elements;
    int batch_count = attn_batches * seq_len;

    // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
    int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.
    int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
    // use 128 threads per block to maximize gpu utilization
    constexpr int threads_per_block = 128;

    int warps_per_block = (threads_per_block / warp_size);
    int batches_per_block = warps_per_block * batches_per_warp;
    TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);

    int blocks_per_seq = attn_batches / batches_per_block;
    dim3 blocks(seq_len, blocks_per_seq, 1);
    dim3 threads(warp_size, warps_per_block, 1);
    auto stream = at::cuda::getCurrentCUDAStream();
    // Launch code would be more elegant if C++ supported FOR CONSTEXPR
    switch (log2_elements) {
      case 0:  // 1
        scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 0><<<blocks, threads, 0, stream>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); break;
      case 1:  // 2
        scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 1><<<blocks, threads, 0, stream>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); break;
      case 2:  // 4
        scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 2><<<blocks, threads, 0, stream>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); break;
      case 3:  // 8
        scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 3><<<blocks, threads, 0, stream>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); break;
      case 4:  // 16
        scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 4><<<blocks, threads, 0, stream>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); break;
      case 5:  // 32
        scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 5><<<blocks, threads, 0, stream>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); break;
      case 6:  // 64
        scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 6><<<blocks, threads, 0, stream>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); break;
      case 7:  // 128
        scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 7><<<blocks, threads, 0, stream>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); break;
      case 8:  // 256
        scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 8><<<blocks, threads, 0, stream>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); break;
      case 9:  // 512
        scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 9><<<blocks, threads, 0, stream>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); break;
      case 10:  // 1024
        scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 10><<<blocks, threads, 0, stream>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); break;
      case 11:  // 2048
        scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 11><<<blocks, threads, 0, stream>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); break;
      default:
        break;
    }
  }
}
template <typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_upper_triang_masked_softmax_backward(output_t *grad_input, input_t *grad,
                                                          const input_t *output, const acc_t scale,
                                                          int softmax_elements,
                                                          int softmax_elements_stride,
                                                          int attn_batches) {
  TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048);
  if (softmax_elements == 0) {
    return;
  } else {
    int log2_elements = log2_ceil(softmax_elements);
    const int next_power_of_two = 1 << log2_elements;
    int seq_len = softmax_elements;
    int batch_count = attn_batches * seq_len;

    // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.
    int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward.
    int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
    // use 128 threads per block to maximize gpu utilization
    constexpr int threads_per_block = 128;

    int warps_per_block = (threads_per_block / warp_size);
    int batches_per_block = warps_per_block * batches_per_warp;
    TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);

    int blocks_per_seq = attn_batches / batches_per_block;
    dim3 blocks(seq_len, blocks_per_seq, 1);
    dim3 threads(warp_size, warps_per_block, 1);
    auto stream = at::cuda::getCurrentCUDAStream();
    // Launch code would be more elegant if C++ supported FOR CONSTEXPR
    switch (log2_elements) {
      case 0:  // 1
        scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 0><<<blocks, threads, 0, stream>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); break;
      case 1:  // 2
        scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 1><<<blocks, threads, 0, stream>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); break;
      case 2:  // 4
        scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 2><<<blocks, threads, 0, stream>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); break;
      case 3:  // 8
        scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 3><<<blocks, threads, 0, stream>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); break;
      case 4:  // 16
        scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 4><<<blocks, threads, 0, stream>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); break;
      case 5:  // 32
        scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 5><<<blocks, threads, 0, stream>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); break;
      case 6:  // 64
        scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 6><<<blocks, threads, 0, stream>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); break;
      case 7:  // 128
        scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 7><<<blocks, threads, 0, stream>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); break;
      case 8:  // 256
        scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 8><<<blocks, threads, 0, stream>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); break;
      case 9:  // 512
        scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 9><<<blocks, threads, 0, stream>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); break;
      case 10:  // 1024
        scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 10><<<blocks, threads, 0, stream>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); break;
      case 11:  // 2048
        scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 11><<<blocks, threads, 0, stream>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); break;
      default:
        break;
    }
  }
}
colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax_cuda.cu
/*This code from NVIDIA Megatron:
* with minor changes. */
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "scaled_upper_triang_masked_softmax.h"
#include "type_shim.h"
namespace multihead_attn {
namespace fused_softmax {
namespace scaled_upper_triang_masked_softmax {

torch::Tensor fwd_cuda(torch::Tensor const& input, float scale_factor) {
  // input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
  const int attn_batches = input.size(0);
  const int seq_len = input.size(1);
  TORCH_INTERNAL_ASSERT(seq_len <= 2048);

  // Output
  auto act_options = input.options().requires_grad(false);
  torch::Tensor softmax_results =
      torch::empty({attn_batches, seq_len, seq_len}, act_options);

  // Softmax Intermediate Result Ptr
  void* input_ptr = static_cast<void*>(input.data_ptr());
  void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());

  DISPATCH_HALF_AND_BFLOAT(
      input.scalar_type(),
      "dispatch_scaled_upper_triang_masked_softmax_forward",
      dispatch_scaled_upper_triang_masked_softmax_forward<scalar_t, scalar_t, float>(
          reinterpret_cast<scalar_t*>(softmax_results_ptr),
          reinterpret_cast<const scalar_t*>(input_ptr),
          scale_factor,
          seq_len,
          seq_len,
          attn_batches););

  return softmax_results;
}

torch::Tensor bwd_cuda(torch::Tensor const& output_grads_,
                       torch::Tensor const& softmax_results_,
                       float scale_factor) {
  auto output_grads = output_grads_.contiguous();
  auto softmax_results = softmax_results_.contiguous();

  // output grads is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
  const int attn_batches = output_grads.size(0);
  const int seq_len = output_grads.size(1);
  TORCH_INTERNAL_ASSERT(output_grads.size(1) == output_grads.size(2));

  void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());

  // Softmax Grad
  DISPATCH_HALF_AND_BFLOAT(
      output_grads_.scalar_type(),
      "dispatch_scaled_upper_triang_masked_softmax_backward",
      dispatch_scaled_upper_triang_masked_softmax_backward<scalar_t, scalar_t, float>(
          reinterpret_cast<scalar_t*>(output_grads_ptr),
          reinterpret_cast<scalar_t*>(output_grads_ptr),
          reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
          scale_factor,
          seq_len,
          seq_len,
          attn_batches););

  // backward pass is completely in-place
  return output_grads;
}
}  // namespace scaled_upper_triang_masked_softmax
}  // namespace fused_softmax
}  // namespace multihead_attn
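As a sanity check on what fwd_cuda computes, the fused path should agree with an eager PyTorch reference that scales the scores, applies the upper-triangular (causal) mask, and takes a softmax over the last dimension. A rough sketch, assuming the extension is built under the module name used by the Python wrapper later in this commit; the reference expression and tolerances are our own reading of the kernel semantics, not code from the source:

import torch
import colossal_scaled_upper_triang_masked_softmax as fused_softmax

attn_batches, seq_len, scale = 8, 128, 0.125
x = torch.randn(attn_batches, seq_len, seq_len, dtype=torch.half, device="cuda")
scale_t = torch.tensor([scale])

# Fused kernel: scale + upper-triangular mask + softmax in one pass.
fused_out = fused_softmax.forward(x, scale_t[0])

# Eager reference: positions above the diagonal get -inf before the softmax.
causal = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool, device="cuda"), diagonal=1)
ref = (x.float() * scale).masked_fill(causal, float("-inf")).softmax(dim=-1).half()

print(torch.allclose(fused_out, ref, atol=1e-2, rtol=1e-2))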
colossalai/kernel/cuda_native/csrc/type_shim.h
0 → 100644
View file @
5c3843dc
#include <ATen/ATen.h>
#include "compat.h"
#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Half: \
{ \
using scalar_t = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
switch(TYPEIN) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_in = float; \
switch(TYPEOUT) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_out = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
} \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_in = at::Half; \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t_in = at::BFloat16; \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
}
colossalai/kernel/cuda_native/layer_norm.py
0 → 100644
View file @
5c3843dc
"""This code is from NVIDIA apex:
https://github.com/NVIDIA/apex
with some changes. """
import numbers
import torch
from torch.nn.parameter import Parameter
from torch.nn import init
import importlib

global colossal_layer_norm_cuda
colossal_layer_norm_cuda = None


class FusedLayerNormAffineFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input, weight, bias, normalized_shape, eps):
        ctx.normalized_shape = normalized_shape
        ctx.eps = eps
        input_ = input.contiguous()
        weight_ = weight.contiguous()
        bias_ = bias.contiguous()
        output, mean, invvar = colossal_layer_norm_cuda.forward_affine(
            input_, ctx.normalized_shape, weight_, bias_, ctx.eps)
        ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input_, weight_, bias_, mean, invvar = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        grad_input, grad_weight, grad_bias \
            = colossal_layer_norm_cuda.backward_affine(
                grad_output.contiguous(), mean, invvar, input_,
                ctx.normalized_shape, weight_, bias_, ctx.eps)
        return grad_input, grad_weight, grad_bias, None, None


class MixedFusedLayerNorm(torch.nn.Module):

    def __init__(self, normalized_shape, eps=1e-5):
        super(MixedFusedLayerNorm, self).__init__()

        global colossal_layer_norm_cuda
        colossal_layer_norm_cuda = importlib.import_module("colossal_layer_norm_cuda")

        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = torch.Size(normalized_shape)
        self.eps = eps
        self.weight = Parameter(torch.Tensor(*normalized_shape))
        self.bias = Parameter(torch.Tensor(*normalized_shape))
        self.reset_parameters()

    def reset_parameters(self):
        init.ones_(self.weight)
        init.zeros_(self.bias)

    def forward(self, input):
        return FusedLayerNormAffineFunction.apply(input, self.weight, self.bias,
                                                  self.normalized_shape, self.eps)
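A minimal usage sketch for the wrapper above (our own example, assuming the colossal_layer_norm_cuda extension has been built and the package layout of this commit); the module is intended to behave like an affine torch.nn.LayerNorm:

import torch
from colossalai.kernel.cuda_native.layer_norm import MixedFusedLayerNorm

hidden = 1024
ln = MixedFusedLayerNorm(hidden, eps=1e-5).cuda()

x = torch.randn(4, 128, hidden, device="cuda", requires_grad=True)
y = ln(x)             # fused forward_affine kernel
y.sum().backward()    # fused backward_affine fills input/weight/bias grads

# Illustrative check against the eager implementation.
ref = torch.nn.functional.layer_norm(x, (hidden,), ln.weight, ln.bias, ln.eps)
print(torch.allclose(y, ref, atol=1e-5))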
colossalai/kernel/cuda_native/multihead_attention.py
0 → 100644
View file @
5c3843dc
import math
import importlib
from dataclasses import dataclass

import torch
from torch import nn
from torch.autograd import Function


def check_config(config):
    if config.hidden_size % config.nhead != 0:
        raise Exception(f"hidden_size % nhead != 0")

    factor = 8 if config.fp16 else 4
    upbound = factor * 1024 * 4
    if config.hidden_size > upbound:
        # as required by ln backward kernel currently
        raise Exception(f"hidden_size > {upbound}")

    head_dim = config.hidden_size // config.nhead
    if head_dim % factor != 0:
        # as required by reshape kernel
        raise Exception(f"head_dim({head_dim}) % {factor} != 0")


def calc_offset(sizes):
    offsets = [0]
    tmp = 0
    for x in sizes:
        tmp += x
        offsets.append(tmp)
    return offsets
colossal_multihead_attention = None


@dataclass
class Config:
    max_batch_tokens: int  # max batch token numbers
    max_seq_len: int  # max sequence length
    hidden_size: int  # size of transformer hidden layers
    nhead: int  # number of heads in attention
    attn_prob_dropout_ratio: float  # attention score dropout ratio
    hidden_dropout_ratio: float  # dropout ratio before residual
    norm_first: bool  # norm_first
    fp16: bool  # fp16 precision
class MultiHeadAttention1DFunc(Function):

    @staticmethod
    def forward(ctx, input, input_mask, in_proj_weight, in_proj_bias,
                out_proj_weight, out_proj_bias, norm_weight, norm_bias, config):
        cuda_module = colossal_multihead_attention
        forward_func = (cuda_module.multihead_attention_fw_fp16
                        if config.fp16 else cuda_module.multihead_attention_fw_fp32)
        if config.fp16:
            input = input.to(torch.half)
            input_mask = input_mask.to(torch.half)

        (output,) = forward_func(config.layer_id, input, input_mask, in_proj_weight,
                                 in_proj_bias, out_proj_weight, out_proj_bias,
                                 norm_weight, norm_bias, config.training, config.norm_first)

        if config.is_grad_enabled and config.training:
            ctx.save_for_backward(output, input, input_mask, in_proj_weight,
                                  in_proj_bias, out_proj_weight, out_proj_bias,
                                  norm_weight, norm_bias)
            ctx.config = config
        return output

    @staticmethod
    def backward(ctx, grad_output):
        assert ctx.config.training

        cuda_module = colossal_multihead_attention
        backward_func = (cuda_module.multihead_attention_bw_fp16
                         if ctx.config.fp16 else cuda_module.multihead_attention_bw_fp32)

        output, input, input_mask, in_proj_weight, in_proj_bias, out_proj_weight, \
            out_proj_bias, norm_weight, norm_bias = ctx.saved_tensors

        grad_input = None
        grad_in_proj_weight = None
        grad_in_proj_bias = None
        grad_out_proj_weight = None
        grad_out_proj_bias = None
        grad_norm_weight = None
        grad_norm_bias = None

        if ctx.config.fp16:
            grad_output = grad_output.to(torch.half)
            output = output.to(torch.half)
            input = input.to(torch.half)
            input_mask = input_mask.to(torch.half)
        grad_input, grad_in_proj_weight, grad_in_proj_bias, grad_out_proj_weight, \
            grad_out_proj_bias, grad_norm_weight, grad_norm_bias = backward_func(
                ctx.config.layer_id, grad_output, output, input, input_mask, in_proj_weight, \
                in_proj_bias, out_proj_weight, out_proj_bias, norm_weight, norm_bias)

        return (grad_input, None, grad_in_proj_weight, grad_in_proj_bias,
                grad_out_proj_weight, grad_out_proj_bias, grad_norm_weight,
                grad_norm_bias, None)
class MultiHeadAttention(nn.Module):
    """Initialize the MultiHeadAttention.

    Static variable:
        layer_id: The layer-index counter starting from 0 and incrementing by 1 every time a layer object is instantiated,
        e.g. if a model has 24 transformer layers, layer_id goes from 0 to 23.

    Arguments:
        hidden_size: Total dimension of hidden_size.
        nhead: Number of parallel attention heads.
        batch_size: Batch size for one forward pass.
        max_seq_len: Max length of input sequence
        dropout: Dropout probability
        norm_first: perform LayerNorms before attention
    """

    layer_id = 0

    def __init__(self, hidden_size, nhead, batch_size, max_seq_len, dropout=0.0,
                 norm_first=False, fp16=True, pg=None):
        super(MultiHeadAttention, self).__init__()

        self.config = Config(batch_size * max_seq_len, max_seq_len, hidden_size, nhead,
                             dropout, dropout, norm_first, fp16)
        check_config(self.config)
        self.pg = pg
        self.pg_size = 1
        if self.pg:
            self.pg_size = pg.size()
        self.config.layer_id = MultiHeadAttention.layer_id
        MultiHeadAttention.layer_id = MultiHeadAttention.layer_id + 1

        # Load cuda modules if needed
        global colossal_multihead_attention
        if colossal_multihead_attention is None:
            colossal_multihead_attention = importlib.import_module("colossal_multihead_attention")

        # create the layer in cuda kernels.
        cuda_module = colossal_multihead_attention
        create_layer_func = (cuda_module.create_multihead_attention_fp16
                             if self.config.fp16 else cuda_module.create_multihead_attention_fp32)

        create_layer_func(
            self.config.layer_id,
            self.config.max_batch_tokens,
            self.config.max_seq_len,
            self.config.hidden_size,
            self.config.nhead,
            self.config.attn_prob_dropout_ratio,
            self.config.hidden_dropout_ratio,
            self.config.norm_first,
            self.pg,
        )

        hs = self.config.hidden_size

        self.precision = torch.float32
        if self.config.fp16:
            self.precision = torch.half

        self.hs_per_rank = int(hs / self.pg_size)

        self.in_proj_weight = nn.Parameter(torch.Tensor(3, self.hs_per_rank, hs))
        self.in_proj_bias = nn.Parameter(torch.Tensor(3, self.hs_per_rank))
        self.out_proj_weight = nn.Parameter(torch.Tensor(hs, self.hs_per_rank))
        self.out_proj_bias = nn.Parameter(torch.Tensor(hs))
        self.norm_weight = nn.Parameter(torch.Tensor(hs))
        self.norm_bias = nn.Parameter(torch.Tensor(hs))

        self.reset_parameters()
        torch.cuda.empty_cache()

    def calc_bound(self, w):
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(w)
        bound = 1.0 / math.sqrt(fan_in)
        return bound

    def reset_parameters(self):
        hs = self.config.hidden_size

        nn.init.zeros_(self.out_proj_bias)

        nn.init.ones_(self.norm_weight)
        nn.init.zeros_(self.norm_bias)

        if self.pg_size > 1:
            rank_in_pg = torch.distributed.get_rank(self.pg)
            attn_qkvw_global = torch.empty(hs * 3, hs)
            attn_qkvb_global = torch.empty(hs * 3)
            nn.init.xavier_uniform_(attn_qkvw_global, 1.0 / math.sqrt(2.0))
            bound = self.calc_bound(attn_qkvw_global)
            nn.init.uniform_(attn_qkvb_global, -bound, bound)

            attn_qkvw_global = attn_qkvw_global.cuda()
            attn_qkvb_global = attn_qkvb_global.cuda()
            torch.distributed.broadcast(attn_qkvw_global, src=0, group=self.pg)
            torch.distributed.broadcast(attn_qkvb_global, src=0, group=self.pg)
            attn_qkvw_global = attn_qkvw_global.cpu()
            attn_qkvb_global = attn_qkvb_global.cpu()

            with torch.no_grad():
                self.in_proj_weight.copy_(
                    attn_qkvw_global.view(3, hs, hs)[
                        :, int(hs * rank_in_pg / self.pg_size):int(hs * (rank_in_pg + 1) / self.pg_size), :])
                self.in_proj_bias.copy_(
                    attn_qkvb_global.view(3, hs)[
                        :, int(hs * rank_in_pg / self.pg_size):int(hs * (rank_in_pg + 1) / self.pg_size)])

            attn_ow_global = torch.empty(hs, hs)
            nn.init.xavier_uniform_(attn_ow_global, 1.0)
            attn_ow_global = attn_ow_global.cuda()
            torch.distributed.broadcast(attn_ow_global, src=0, group=self.pg)
            attn_ow_global = attn_ow_global.cpu()
            with torch.no_grad():
                self.out_proj_weight.copy_(
                    attn_ow_global[
                        :, int(hs * rank_in_pg / self.pg_size):int(hs * (rank_in_pg + 1) / self.pg_size)])

        else:
            attn_qkvw = self.in_proj_weight.view(-1, hs)
            nn.init.xavier_uniform_(attn_qkvw, 1.0 / math.sqrt(2.0))
            bound = self.calc_bound(attn_qkvw)
            nn.init.uniform_(self.in_proj_bias, -bound, bound)

            nn.init.xavier_uniform_(self.out_proj_weight, 1.0)

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        destination = torch.nn.Module.state_dict(self, destination=destination,
                                                 prefix=prefix, keep_vars=keep_vars)
        return destination

    def forward(self, hidden_states, encoder_padding_mask):
        self.config.training = self.training
        self.config.is_grad_enabled = torch.is_grad_enabled()
        hidden_states = hidden_states.contiguous()
        encoder_padding_mask = ((encoder_padding_mask * -1e8).type_as(hidden_states).contiguous())

        bs, sl, dim = hidden_states.size()
        if bs * sl > self.config.max_batch_tokens:
            raise ValueError(f"Batch token numbers {bs * sl} exceeds the limit {self.config.max_batch_tokens}.")
        if sl > self.config.max_seq_len:
            raise ValueError(f"Sequence length {sl} exceeds the limit {self.config.max_seq_len}.")
        if len(encoder_padding_mask.size()) == 1:
            assert bs == 1 and sl == encoder_padding_mask.size(0)
        else:
            assert bs == encoder_padding_mask.size(0) and sl == encoder_padding_mask.size(1)

        output = MultiHeadAttention1DFunc.apply(hidden_states, encoder_padding_mask,
                                                self.in_proj_weight, self.in_proj_bias,
                                                self.out_proj_weight, self.out_proj_bias,
                                                self.norm_weight, self.norm_bias, self.config)

        return output.to(self.precision)
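A usage sketch for the fused layer (our own example, assuming the colossal_multihead_attention extension is built): hidden_states is [batch, seq_len, hidden_size] and encoder_padding_mask marks padded positions with 1 (scaled by -1e8 inside forward) and kept tokens with 0.

import torch
from colossalai.kernel.cuda_native.multihead_attention import MultiHeadAttention

# Illustrative sizes; fp16=True selects the fp16 kernels, so the module and
# inputs are moved to half precision here.
attn = MultiHeadAttention(hidden_size=1024, nhead=16, batch_size=8,
                          max_seq_len=256, dropout=0.1, norm_first=False,
                          fp16=True).cuda().half()

hidden_states = torch.randn(8, 256, 1024, device="cuda", dtype=torch.half)
padding_mask = torch.zeros(8, 256, device="cuda")   # no padded tokens in this example

out = attn(hidden_states, padding_mask)
print(out.shape)   # torch.Size([8, 256, 1024]), same shape as the input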
colossalai/kernel/cuda_native/scaled_softmax.py
0 → 100644
View file @
5c3843dc
"""This code from NVIDIA Megatron
with some changes. """
import
torch
import
torch.nn
as
nn
import
enum
class
AttnMaskType
(
enum
.
Enum
):
padding
=
1
causal
=
2
class
ScaledUpperTriangMaskedSoftmax
(
torch
.
autograd
.
Function
):
"""
Fused operation which performs following three operations in sequence
1. Scale the tensor.
2. Apply upper triangular mask (typically used in gpt models).
3. Perform softmax.
"""
@
staticmethod
def
forward
(
ctx
,
inputs
,
scale
):
import
colossal_scaled_upper_triang_masked_softmax
scale_t
=
torch
.
tensor
([
scale
])
softmax_results
=
colossal_scaled_upper_triang_masked_softmax
.
forward
(
inputs
,
scale_t
[
0
]
)
ctx
.
save_for_backward
(
softmax_results
,
scale_t
)
return
softmax_results
@
staticmethod
def
backward
(
ctx
,
output_grads
):
import
colossal_scaled_upper_triang_masked_softmax
softmax_results
,
scale_t
=
ctx
.
saved_tensors
input_grads
=
colossal_scaled_upper_triang_masked_softmax
.
backward
(
output_grads
,
softmax_results
,
scale_t
[
0
]
)
return
input_grads
,
None
class ScaledMaskedSoftmax(torch.autograd.Function):
    """
    Fused operation which performs following three operations in sequence
    1. Scale the tensor.
    2. Apply the mask.
    3. Perform softmax.
    """

    @staticmethod
    def forward(ctx, inputs, mask, scale):
        import colossal_scaled_masked_softmax

        scale_t = torch.tensor([scale])
        softmax_results = colossal_scaled_masked_softmax.forward(inputs, mask, scale_t[0])
        ctx.save_for_backward(softmax_results, scale_t)
        return softmax_results

    @staticmethod
    def backward(ctx, output_grads):
        import colossal_scaled_masked_softmax

        softmax_results, scale_t = ctx.saved_tensors
        input_grads = colossal_scaled_masked_softmax.backward(
            output_grads, softmax_results, scale_t[0])

        return input_grads, None, None
class FusedScaleMaskSoftmax(nn.Module):
    """
    fused operation: scaling + mask + softmax

    Arguments:
        input_in_fp16: flag to indicate if input in fp16 data format.
        input_in_bf16: flag to indicate if input in bf16 data format.
        attn_mask_type: attention mask type (pad or causal)
        scaled_masked_softmax_fusion: flag to indicate the user wants to use softmax fusion
        mask_func: mask function to be applied.
        softmax_in_fp32: if true, softmax is performed at fp32 precision.
        scale: scaling factor used in input tensor scaling.
    """

    def __init__(
        self,
        input_in_fp16,
        input_in_bf16,
        attn_mask_type,
        scaled_masked_softmax_fusion,
        mask_func,
        softmax_in_fp32,
        scale,
    ):
        super(FusedScaleMaskSoftmax, self).__init__()
        self.input_in_fp16 = input_in_fp16
        self.input_in_bf16 = input_in_bf16
        assert not (
            self.input_in_fp16 and self.input_in_bf16
        ), "both fp16 and bf16 flags cannot be active at the same time."
        self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
        self.attn_mask_type = attn_mask_type
        self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
        self.mask_func = mask_func
        self.softmax_in_fp32 = softmax_in_fp32
        self.scale = scale

        assert (self.scale is None or softmax_in_fp32), "softmax should be in fp32 when scaled"

    def forward(self, input, mask):
        # [b, np, sq, sk]
        assert input.dim() == 4

        if self.is_kernel_available(mask, *input.size()):
            return self.forward_fused_softmax(input, mask)
        else:
            return self.forward_torch_softmax(input, mask)

    def is_kernel_available(self, mask, b, np, sq, sk):
        attn_batches = b * np

        if (
            self.scaled_masked_softmax_fusion  # user wants to fuse
            and self.input_in_float16  # input must be fp16
            and mask is not None  # mask tensor must not be None
            and 16 < sk <= 2048  # sk must be 16 ~ 2048
            and sq % 4 == 0  # sq must be a multiple of 4
            and attn_batches % 4 == 0  # np * b must be a multiple of 4
        ):
            if 0 <= sk <= 2048:
                batch_per_block = self.get_batch_per_block(sq, sk, b, np)

                if self.attn_mask_type == AttnMaskType.causal:
                    if attn_batches % batch_per_block == 0:
                        return True
                else:
                    if sq % batch_per_block == 0:
                        return True
        return False

    def forward_fused_softmax(self, input, mask):
        b, np, sq, sk = input.size()
        scale = self.scale if self.scale is not None else 1.0

        if self.attn_mask_type == AttnMaskType.causal:
            assert sq == sk, "causal mask is only for self attention"

            # input is 3D tensor (attn_batches, sq, sk)
            input = input.view(-1, sq, sk)
            probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale)
            return probs.view(b, np, sq, sk)
        else:
            # input is 4D tensor (b, np, sq, sk)
            return ScaledMaskedSoftmax.apply(input, mask, scale)

    def forward_torch_softmax(self, input, mask):
        if self.input_in_float16 and self.softmax_in_fp32:
            input = input.float()

        if self.scale is not None:
            input = input * self.scale
        mask_output = self.mask_func(input, mask) if mask is not None else input
        probs = torch.nn.Softmax(dim=-1)(mask_output)

        if self.input_in_float16 and self.softmax_in_fp32:
            if self.input_in_fp16:
                probs = probs.half()
            else:
                probs = probs.bfloat16()

        return probs

    @staticmethod
    def get_batch_per_block(sq, sk, b, np):
        import colossal_scaled_masked_softmax

        return colossal_scaled_masked_softmax.get_batch_per_block(sq, sk, b, np)
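A sketch of driving FusedScaleMaskSoftmax from an attention block (our own example, assuming the fused extensions are importable); when the size checks in is_kernel_available fail, the module silently falls back to forward_torch_softmax, in which case mask_func is applied:

import torch
from colossalai.kernel.cuda_native.scaled_softmax import FusedScaleMaskSoftmax, AttnMaskType

def causal_mask_func(scores, mask):
    # Only used on the eager fallback path; the fused causal kernel masks internally.
    return scores.masked_fill(mask, -10000.0)

softmax = FusedScaleMaskSoftmax(
    input_in_fp16=True,
    input_in_bf16=False,
    attn_mask_type=AttnMaskType.causal,
    scaled_masked_softmax_fusion=True,
    mask_func=causal_mask_func,
    softmax_in_fp32=True,
    scale=1.0 / 8.0,   # e.g. 1/sqrt(head_dim) for head_dim == 64
)

# [batch, heads, query_len, key_len] attention scores.
scores = torch.randn(2, 16, 128, 128, device="cuda", dtype=torch.half)
causal = torch.triu(torch.ones(1, 1, 128, 128, device="cuda", dtype=torch.bool), diagonal=1)
probs = softmax(scores, causal)   # fused causal kernel when the checks pass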
colossalai/kernel/jit/__init__.py
0 → 100644
View file @
5c3843dc
from .option import _set_jit_fusion_options

_set_jit_fusion_options()
\ No newline at end of file
colossalai/kernel/jit/bias_dropout_add.py
0 → 100644
View file @
5c3843dc
import torch


def bias_dropout_add(x, bias, residual, prob, training):
    # type: (Tensor, Tensor, Tensor, float, bool) -> Tensor
    out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
    out = residual + out
    return out


@torch.jit.script
def bias_dropout_add_fused_train(x: torch.Tensor, bias: torch.Tensor,
                                 residual: torch.Tensor, prob: float) -> torch.Tensor:
    return bias_dropout_add(x, bias, residual, prob, True)


@torch.jit.script
def bias_dropout_add_fused_inference(x: torch.Tensor, bias: torch.Tensor,
                                     residual: torch.Tensor, prob: float) -> torch.Tensor:
    return bias_dropout_add(x, bias, residual, prob, False)
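A small sketch of calling the JIT-fused helpers inside a transformer residual branch (our own example; the shapes and the 0.1 dropout probability are illustrative):

import torch
from colossalai.kernel.jit.bias_dropout_add import (bias_dropout_add_fused_train,
                                                    bias_dropout_add_fused_inference)

x = torch.randn(8, 128, 1024, device="cuda")        # e.g. attention or MLP output
bias = torch.randn(1024, device="cuda")             # broadcast over batch and sequence
residual = torch.randn(8, 128, 1024, device="cuda")

train_out = bias_dropout_add_fused_train(x, bias, residual, 0.1)      # dropout active
eval_out = bias_dropout_add_fused_inference(x, bias, residual, 0.1)   # dropout disabled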