OpenDAS / ColossalAI · Commits · Commit 8823cc48 (unverified)

Merge pull request #5310 from hpcaitech/feature/npu

Feature/npu

Authored by Frank Lee on Jan 29, 2024; committed via GitHub on Jan 29, 2024.
Parents: bce9499e, 73f4dc57
Changes: showing 20 changed files, with 127 additions and 3294 deletions (the commit touches 266 files in total).
colossalai/kernel/cuda_native/csrc/kernels/include/strided_batch_gemm.h  (+0, -100)
colossalai/kernel/cuda_native/csrc/kernels/normalize_kernels.cu  (+0, -1172)
colossalai/kernel/cuda_native/csrc/kernels/softmax_kernels.cu  (+0, -365)
colossalai/kernel/cuda_native/csrc/kernels/transform_kernels.cu  (+0, -314)
colossalai/kernel/cuda_native/csrc/multihead_attention_1d.cpp  (+0, -406)
colossalai/kernel/cuda_native/csrc/multihead_attention_1d.h  (+0, -167)
colossalai/kernel/cuda_native/csrc/smoothquant/binding.cpp  (+0, -8)
colossalai/kernel/cuda_native/csrc/smoothquant/linear.cu  (+0, -162)
colossalai/kernel/cuda_native/csrc/smoothquant/linear.h  (+0, -12)
colossalai/kernel/cuda_native/mha/__init__.py  (+0, -3)
colossalai/kernel/cuda_native/mha/flash_attn_2.py  (+0, -80)
colossalai/kernel/cuda_native/mha/mem_eff_attn.py  (+0, -70)
colossalai/kernel/cuda_native/mha/utils.py  (+0, -82)
colossalai/kernel/cuda_native/multihead_attention.py  (+0, -338)
colossalai/kernel/extensions  (+1, -0)
colossalai/kernel/jit/option.py  (+12, -10)
colossalai/kernel/kernel_loader.py  (+109, -0)
colossalai/kernel/op_builder  (+0, -1)
colossalai/legacy/amp/naive_amp/_fp16_optimizer.py  (+2, -2)
colossalai/legacy/amp/torch_amp/torch_amp.py  (+3, -2)
colossalai/kernel/cuda_native/csrc/kernels/include/strided_batch_gemm.h (deleted, mode 100644 → 0)
/* Copyright 2021 The LightSeq Team
Copyright Microsoft DeepSpeed
This file is adapted from Microsoft DeepSpeed
Licensed under the MIT License.
*/
#pragma once
#include <cuda.h>
#include <cuda_fp16.h>
#include <stdio.h>
#include <array>
#include "cublas_wrappers.h"
template <typename T>
class StridedBatchGemm {
 public:
  struct Config {
    int m;
    int n;
    int k;
    float alpha;
    float beta;
    cublasOperation_t op_A;
    cublasOperation_t op_B;
    std::array<int, 3> gemm_algos;

    Config(float param_alpha, float param_beta, cublasOperation_t opA,
           cublasOperation_t opB)
        : alpha(param_alpha),
          beta(param_beta),
          op_A(opA),
          op_B(opB),
          gemm_algos(std::array<int, 3>({99, 99, 99})) {}

    void SetConfig(int mm, int nn, int kk) {
      m = mm;
      n = nn;
      k = kk;
    }
  };

  StridedBatchGemm(const Config &config) : _config(config) {}

  virtual ~StridedBatchGemm() {}

  void Forward(int bsz, T *output, const T *_buffer_a, const T *_buffer_b,
               cublasHandle_t handle) {
    int stride_a = _config.m * _config.k;
    int stride_b = _config.n * _config.k;
    int stride_c = _config.m * _config.n;

    cublas_strided_batched_gemm(
        handle, _config.m, _config.n, _config.k, &_config.alpha, &_config.beta,
        _buffer_a, _buffer_b, output, _config.op_A, _config.op_B, stride_a,
        stride_b, stride_c, bsz, cublasGemmAlgo_t(_config.gemm_algos[0]));
  }

  void Backward(int bsz, const T *d_output, const T *_buffer_a,
                const T *_buffer_b, cublasHandle_t handle,
                T *inpGradA = nullptr, T *inpGradB = nullptr) {
    int mb = (_config.op_A == CUBLAS_OP_T ? _config.k : _config.m);
    int kb = (_config.op_A == CUBLAS_OP_T ? _config.m : _config.k);

    int stride_a = mb * _config.n;
    int stride_b = _config.n * kb;
    int stride_c = _config.m * _config.k;

    // B needs to be transposed.
    cublasOperation_t op_b =
        (_config.op_B == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T);

    // Calculate d_A.
    cublas_strided_batched_gemm(
        handle, mb, kb, _config.n, &_config.alpha, &_config.beta,
        (_config.op_A == CUBLAS_OP_T ? _buffer_b : d_output),
        (_config.op_A == CUBLAS_OP_T ? d_output : _buffer_b), inpGradA,
        CUBLAS_OP_N, op_b, stride_a, stride_b, stride_c, bsz,
        cublasGemmAlgo_t(_config.gemm_algos[1]));

    // A needs to be transposed.
    cublasOperation_t op_a =
        (_config.op_A == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T);

    stride_a = _config.m * _config.k;
    stride_b = _config.m * _config.n;
    stride_c = _config.n * _config.k;

    // Calculate d_B.
    cublas_strided_batched_gemm(
        handle, _config.k, _config.n, _config.m, &_config.alpha, &_config.beta,
        _buffer_a, d_output, inpGradB, op_a, CUBLAS_OP_N, stride_a, stride_b,
        stride_c, bsz, cublasGemmAlgo_t(_config.gemm_algos[2]));
  }

  inline void SetConfig(int m, int n, int k) { _config.SetConfig(m, n, k); }

 private:
  Config _config;
};
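For reference, the two GEMMs in Backward evaluate the standard matrix-product gradients per batch element; a sketch in conventional notation (the op_A/op_B flags and the operand swap in the code re-express the same identities under cuBLAS's transpose conventions):

$$C = A B, \qquad \frac{\partial L}{\partial A} = \frac{\partial L}{\partial C}\, B^{\top}, \qquad \frac{\partial L}{\partial B} = A^{\top}\, \frac{\partial L}{\partial C}$$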
colossalai/kernel/cuda_native/csrc/kernels/normalize_kernels.cu (deleted, mode 100644 → 0)
#include <cooperative_groups.h>
#include "block_reduce.h"
#include "kernels.h"
namespace cg = cooperative_groups;
const float LN_EPSILON = 1e-8f;
#define TILE_DIM 32

template <typename T>
__forceinline__ __device__ T add_eps(T x) {
  return fabsf(x) > LN_EPSILON ? x : (x < 0 ? -LN_EPSILON : LN_EPSILON);
}
/**
@brief: ker_layer_norm
Standard layer normalization.
Besides the layer norm result, it also outputs the per-token variance,
and optionally the per-token mean (when the means argument is not nullptr).

@thread
gridDim.x = batch_size * seq_len
blockDim.x = hidden_size

@param
ln_res: [batch_size * seq_len, hidden_size], ln result.
vars: [batch_size * seq_len], variance per token
means: [batch_size * seq_len], mean per token, may be nullptr
inp: [batch_size * seq_len, hidden_size], ln input.
scale: [hidden_size], ln scale
bias: [hidden_size], ln bias
*/
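/*
The single-pass mean/variance below relies on the identity
Var[x] = E[x^2] - (E[x])^2, so one blockReduce over the pair
{sum, square_sum} suffices. Illustrative check: x = {1, 2, 3, 4} gives
E[x] = 2.5, E[x^2] = 7.5, Var = 7.5 - 6.25 = 1.25.
*/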
template <typename T>
__global__ void ker_layer_norm(T *ln_res, T *vars, T *means, const T *inp,
                               const T *scale, const T *bias,
                               int hidden_size) {
  // step 0. compute local sum
  float l_sum = 0;
  float l_square_sum = 0;
  const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * hidden_size;
  for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float4 val = inp_f4[idx];
    l_sum += val.x + val.y + val.z + val.w;
    l_square_sum +=
        val.x * val.x + val.y * val.y + val.z * val.z + val.w * val.w;
  }

  // step 1. compute reduce sum
  float mean_dim = float(hidden_size) * 4.f;
  float reduce_val[2] = {l_sum, l_square_sum};
  blockReduce<ReduceType::kSum, 2>(reduce_val);
  __shared__ float s_mean, s_var;
  if (threadIdx.x == 0) {
    s_mean = reduce_val[0] / mean_dim;
    if (means != nullptr) {
      means[blockIdx.x] = s_mean;
    }
    s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON;
    vars[blockIdx.x] = s_var;
    s_var = rsqrtf(s_var);
  }
  __syncthreads();

  // step 2. layer norm result
  float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size;
  for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float4 vscale = __ldg((const float4 *)scale + idx);
    float4 vbias = __ldg((const float4 *)bias + idx);
    float4 val = inp_f4[idx];
    val.x = (val.x - s_mean) * s_var * vscale.x + vbias.x;
    val.y = (val.y - s_mean) * s_var * vscale.y + vbias.y;
    val.z = (val.z - s_mean) * s_var * vscale.z + vbias.z;
    val.w = (val.w - s_mean) * s_var * vscale.w + vbias.w;
    output_f4[idx] = val;
  }
}
template <>
__global__ void ker_layer_norm<__half>(__half *ln_res, __half *vars,
                                       __half *means, const __half *inp,
                                       const __half *scale, const __half *bias,
                                       int hidden_size) {
  // step 0. compute local sum
  float l_sum = 0;
  float l_square_sum = 0;
  const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * hidden_size;
  for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float4 val_f4 = inp_f4[idx];
    __half2 *val_h2 = (__half2 *)(&val_f4);
#pragma unroll
    for (int i = 0; i < 4; i++) {
      float2 val_f2 = __half22float2(val_h2[i]);
      l_sum += val_f2.x + val_f2.y;
      l_square_sum += val_f2.x * val_f2.x + val_f2.y * val_f2.y;
    }
  }

  // step 1. compute reduce sum
  float mean_dim = float(hidden_size) * 8.f;
  float reduce_val[2] = {l_sum, l_square_sum};
  blockReduce<ReduceType::kSum, 2>(reduce_val);
  __shared__ float s_mean, s_var;
  if (threadIdx.x == 0) {
    s_mean = reduce_val[0] / mean_dim;
    if (means != nullptr) {
      means[blockIdx.x] = s_mean;
    }
    s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON;
    vars[blockIdx.x] = s_var;
    s_var = rsqrtf(s_var);
  }
  __syncthreads();

  // step 2. layer norm result
  float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size;
  for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    // load scale, bias, input
    float4 scale_f4 = __ldg((const float4 *)scale + idx);
    __half2 *scale_h2 = (__half2 *)(&scale_f4);
    float4 bias_f4 = __ldg((const float4 *)bias + idx);
    __half2 *bias_h2 = (__half2 *)(&bias_f4);
    float4 val_f4 = inp_f4[idx];
    __half2 *val_h2 = (__half2 *)(&val_f4);
#pragma unroll
    for (int i = 0; i < 4; i++) {
      float2 scale_f2 = __half22float2(scale_h2[i]);
      float2 bias_f2 = __half22float2(bias_h2[i]);
      float2 val_f2 = __half22float2(val_h2[i]);
      val_f2.x = (val_f2.x - s_mean) * s_var * scale_f2.x + bias_f2.x;
      val_f2.y = (val_f2.y - s_mean) * s_var * scale_f2.y + bias_f2.y;
      val_h2[i] = __float22half2_rn(val_f2);
    }
    output_f4[idx] = val_f4;
  }
}
// __global__ void ker_layer_norm_x2(__half *ln_res, __half *vars,
// __half *means, const __half *inp,
// const __half *scale, const __half
// *bias, int hidden_size) {
// // step 0. compute local sum
// float l_sum = 0;
// float l_square_sum = 0;
// const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * 2 * hidden_size;
// for (uint idx = 2 * threadIdx.x; idx < hidden_size * 2; idx += blockDim.x *
// 2) {
// float4 val_f4 = inp_f4[idx];
// float4 val_f4_1 = inp_f4[idx+1];
// __half2 *val_h2 = (__half2 *)(&val_f4);
// __half2 *val_h2_1 = (__half2 *)(&val_f4_1);
// #pragma unroll
// for (int i = 0; i < 4; i++) {
// float2 val_f2 = __half22float2(val_h2[i]);
// float2 val_f2_1 = __half22float2(val_h2_1[i]);
// l_sum += val_f2.x + val_f2.y + val_f2_1.x + val_f2_1.y;
// l_square_sum += val_f2.x * val_f2.x + val_f2.y * val_f2.y + val_f2_1.x
// * val_f2_1.x + val_f2_1.y * val_f2_1.y;
// }
// }
// // step 1. compute reduce sum
// float mean_dim = float(hidden_size) * 8.f * 2;
// float reduce_val[2] = {l_sum, l_square_sum};
// blockReduce<ReduceType::kSum, 2>(reduce_val);
// __shared__ float s_mean, s_var;
// if (threadIdx.x == 0) {
// s_mean = reduce_val[0] / mean_dim;
// if (means != nullptr) {
// means[blockIdx.x] = s_mean;
// }
// s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON;
// vars[blockIdx.x] = s_var;
// s_var = rsqrtf(s_var);
// }
// __syncthreads();
// // step 2. layer norm result
// float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size * 2;
// for (uint idx = 2 * threadIdx.x; idx < hidden_size * 2; idx += blockDim.x *
// 2) {
// // load scale, bias, input
// float4 scale_f4 = __ldg((const float4 *)scale + idx);
// __half2 *scale_h2 = (__half2 *)(&scale_f4);
// float4 scale_f4_1 = __ldg((const float4 *)scale + idx + 1);
// __half2 *scale_h2_1 = (__half2 *)(&scale_f4_1);
// float4 bias_f4 = __ldg((const float4 *)bias + idx);
// __half2 *bias_h2 = (__half2 *)(&bias_f4);
// float4 bias_f4_1 = __ldg((const float4 *)bias + idx + 1);
// __half2 *bias_h2_1 = (__half2 *)(&bias_f4_1);
// float4 val_f4 = inp_f4[idx];
// __half2 *val_h2 = (__half2 *)(&val_f4);
// float4 val_f4_1 = inp_f4[idx+1];
// __half2 *val_h2_1 = (__half2 *)(&val_f4_1);
// #pragma unroll
// for (int i = 0; i < 4; i++) {
// float2 scale_f2 = __half22float2(scale_h2[i]);
// float2 scale_f2_1 = __half22float2(scale_h2_1[i]);
// float2 bias_f2 = __half22float2(bias_h2[i]);
// float2 bias_f2_1 = __half22float2(bias_h2_1[i]);
// float2 val_f2 = __half22float2(val_h2[i]);
// float2 val_f2_1 = __half22float2(val_h2_1[i]);
// val_f2.x = (val_f2.x - s_mean) * s_var * scale_f2.x + bias_f2.x;
// val_f2.y = (val_f2.y - s_mean) * s_var * scale_f2.y + bias_f2.y;
// val_h2[i] = __float22half2_rn(val_f2);
// val_f2_1.x = (val_f2_1.x - s_mean) * s_var * scale_f2_1.x +
// bias_f2_1.x; val_f2_1.y = (val_f2_1.y - s_mean) * s_var * scale_f2_1.y
// + bias_f2_1.y; val_h2_1[i] = __float22half2_rn(val_f2_1);
// }
// output_f4[idx] = val_f4;
// output_f4[idx+1] = val_f4_1;
// }
// }
// __global__ void ker_layer_norm_x4(__half *ln_res, __half *vars,
// __half *means, const __half *inp,
// const __half *scale, const __half
// *bias, int hidden_size) {
// // step 0. compute local sum
// float l_sum = 0;
// float l_square_sum = 0;
// const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * hidden_size * 4;
// for (uint idx = 4 * threadIdx.x; idx < hidden_size * 4; idx += blockDim.x *
// 4) {
// float4 val_f4 = inp_f4[idx];
// float4 val_f4_1 = inp_f4[idx+1];
// float4 val_f4_2 = inp_f4[idx+2];
// float4 val_f4_3 = inp_f4[idx+3];
// __half2 *val_h2 = (__half2 *)(&val_f4);
// __half2 *val_h2_1 = (__half2 *)(&val_f4_1);
// __half2 *val_h2_2 = (__half2 *)(&val_f4_2);
// __half2 *val_h2_3 = (__half2 *)(&val_f4_3);
// #pragma unroll
// for (int i = 0; i < 4; i++) {
// float2 val_f2 = __half22float2(val_h2[i]);
// float2 val_f2_1 = __half22float2(val_h2_1[i]);
// float2 val_f2_2 = __half22float2(val_h2_2[i]);
// float2 val_f2_3 = __half22float2(val_h2_3[i]);
// l_sum += val_f2.x + val_f2.y + val_f2_1.x + val_f2_1.y + val_f2_2.x +
// val_f2_2.y + val_f2_3.x + val_f2_3.y; l_square_sum += val_f2.x *
// val_f2.x + val_f2.y * val_f2.y; l_square_sum += val_f2_1.x * val_f2_1.x
// + val_f2_1.y * val_f2_1.y; l_square_sum += val_f2_2.x * val_f2_2.x +
// val_f2_2.y * val_f2_2.y; l_square_sum += val_f2_3.x * val_f2_3.x +
// val_f2_3.y * val_f2_3.y;
// }
// }
// // step 1. compute reduce sum
// float mean_dim = float(hidden_size) * 8.f * 4;
// float reduce_val[2] = {l_sum, l_square_sum};
// blockReduce<ReduceType::kSum, 2>(reduce_val);
// __shared__ float s_mean, s_var;
// if (threadIdx.x == 0) {
// s_mean = reduce_val[0] / mean_dim;
// if (means != nullptr) {
// means[blockIdx.x] = s_mean;
// }
// s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON;
// vars[blockIdx.x] = s_var;
// s_var = rsqrtf(s_var);
// }
// __syncthreads();
// // step 2. layer norm result
// float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size * 4;
// for (uint idx = 4 * threadIdx.x; idx < hidden_size * 4; idx += blockDim.x *
// 4) {
// // load scale, bias, input
// float4 scale_f4 = __ldg((const float4 *)scale + idx);
// __half2 *scale_h2 = (__half2 *)(&scale_f4);
// float4 scale_f4_1 = __ldg((const float4 *)scale + idx + 1);
// __half2 *scale_h2_1 = (__half2 *)(&scale_f4_1);
// float4 scale_f4_2 = __ldg((const float4 *)scale + idx + 2);
// __half2 *scale_h2_2 = (__half2 *)(&scale_f4_2);
// float4 scale_f4_3 = __ldg((const float4 *)scale + idx + 3);
// __half2 *scale_h2_3 = (__half2 *)(&scale_f4_3);
// float4 bias_f4 = __ldg((const float4 *)bias + idx);
// __half2 *bias_h2 = (__half2 *)(&bias_f4);
// float4 bias_f4_1 = __ldg((const float4 *)bias + idx + 1);
// __half2 *bias_h2_1 = (__half2 *)(&bias_f4_1);
// float4 bias_f4_2 = __ldg((const float4 *)bias + idx + 2);
// __half2 *bias_h2_2 = (__half2 *)(&bias_f4_2);
// float4 bias_f4_3 = __ldg((const float4 *)bias + idx + 3);
// __half2 *bias_h2_3 = (__half2 *)(&bias_f4_3);
// float4 val_f4 = inp_f4[idx];
// __half2 *val_h2 = (__half2 *)(&val_f4);
// float4 val_f4_1 = inp_f4[idx+1];
// __half2 *val_h2_1 = (__half2 *)(&val_f4_1);
// float4 val_f4_2 = inp_f4[idx+2];
// __half2 *val_h2_2 = (__half2 *)(&val_f4_2);
// float4 val_f4_3 = inp_f4[idx+3];
// __half2 *val_h2_3 = (__half2 *)(&val_f4_3);
// #pragma unroll
// for (int i = 0; i < 4; i++) {
// float2 scale_f2 = __half22float2(scale_h2[i]);
// float2 scale_f2_1 = __half22float2(scale_h2_1[i]);
// float2 scale_f2_2 = __half22float2(scale_h2_2[i]);
// float2 scale_f2_3 = __half22float2(scale_h2_3[i]);
// float2 bias_f2 = __half22float2(bias_h2[i]);
// float2 bias_f2_1 = __half22float2(bias_h2_1[i]);
// float2 bias_f2_2 = __half22float2(bias_h2_2[i]);
// float2 bias_f2_3 = __half22float2(bias_h2_3[i]);
// float2 val_f2 = __half22float2(val_h2[i]);
// float2 val_f2_1 = __half22float2(val_h2_1[i]);
// float2 val_f2_2 = __half22float2(val_h2_2[i]);
// float2 val_f2_3 = __half22float2(val_h2_3[i]);
// val_f2.x = (val_f2.x - s_mean) * s_var * scale_f2.x + bias_f2.x;
// val_f2.y = (val_f2.y - s_mean) * s_var * scale_f2.y + bias_f2.y;
// val_f2_1.x = (val_f2_1.x - s_mean) * s_var * scale_f2_1.x +
// bias_f2_1.x; val_f2_1.y = (val_f2_1.y - s_mean) * s_var * scale_f2_1.y
// + bias_f2_1.y; val_f2_2.x = (val_f2_2.x - s_mean) * s_var *
// scale_f2_2.x + bias_f2_2.x; val_f2_2.y = (val_f2_2.y - s_mean) * s_var
// * scale_f2_2.y + bias_f2_2.y; val_f2_3.x = (val_f2_3.x - s_mean) *
// s_var * scale_f2_3.x + bias_f2_3.x; val_f2_3.y = (val_f2_3.y - s_mean)
// * s_var * scale_f2_3.y + bias_f2_3.y; val_h2[i] =
// __float22half2_rn(val_f2); val_h2_1[i] = __float22half2_rn(val_f2_1);
// val_h2_2[i] = __float22half2_rn(val_f2_2);
// val_h2_3[i] = __float22half2_rn(val_f2_3);
// }
// output_f4[idx] = val_f4;
// output_f4[idx+1] = val_f4_1;
// output_f4[idx+2] = val_f4_2;
// output_f4[idx+3] = val_f4_3;
// }
// }
template <>
void launch_layer_norm<float>(float *ln_res, float *vars, float *means,
                              const float *inp, const float *scale,
                              const float *bias, int batch_size,
                              int hidden_dim, cudaStream_t stream) {
  if (hidden_dim % 4 != 0) {
    throw std::runtime_error("violate hidden_dim % 4 = 0");
  }
  hidden_dim >>= 2;
  int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS);
  dim3 grid_dim(batch_size);
  dim3 block_dim(nthread);

  ker_layer_norm<float><<<grid_dim, block_dim, 0, stream>>>(
      ln_res, vars, means, inp, scale, bias, hidden_dim);
}
template <>
void launch_layer_norm<__half>(__half *ln_res, __half *vars, __half *means,
                               const __half *inp, const __half *scale,
                               const __half *bias, int batch_size,
                               int hidden_dim, cudaStream_t stream) {
  if (hidden_dim % 8 != 0) {
    throw std::runtime_error("violate hidden_dim % 8 = 0");
  }
  hidden_dim >>= 3;
  int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS);
  dim3 grid_dim(batch_size);
  dim3 block_dim(nthread);

  ker_layer_norm<__half><<<grid_dim, block_dim, 0, stream>>>(
      ln_res, vars, means, inp, scale, bias, hidden_dim);
// if (hidden_dim % 8 != 0) {
// throw std::runtime_error("violate hidden_dim % 8 = 0");
// }
// hidden_dim >>= 3;
// if (hidden_dim * 8 < 8192) {
// int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS);
// dim3 grid_dim(batch_size);
// dim3 block_dim(nthread);
// ker_layer_norm<__half><<<grid_dim, block_dim, 0, stream>>>(
// ln_res, vars, means, inp, scale, bias, hidden_dim);
// } else if (hidden_dim * 8 >= 8192 && hidden_dim * 8 <= 8192 * 2) {
// hidden_dim >>= 1;
// int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS);
// dim3 grid_dim(batch_size);
// dim3 block_dim(nthread);
// ker_layer_norm_x2<<<grid_dim, block_dim, 0, stream>>>(
// ln_res, vars, means, inp, scale, bias, hidden_dim);
// } else if (hidden_dim * 8 > 8192 * 2 && hidden_dim * 8 <= 8192 * 4) {
// hidden_dim >>= 2;
// int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS);
// dim3 grid_dim(batch_size);
// dim3 block_dim(nthread);
// ker_layer_norm_x4<<<grid_dim, block_dim, 0, stream>>>(
// ln_res, vars, means, inp, scale, bias, hidden_dim);
// } else {
// throw std::runtime_error("hidden_dim % 4 != 0 || hidden_dim > 32768");
// }
}
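/*
Illustrative host-side sketch of driving the float launcher above. This is
an assumed driver, not from this file: the function name and buffer names
are placeholders; only launch_layer_norm itself is defined above.
*/
void example_layer_norm_forward(int batch, int hidden, cudaStream_t stream) {
  // hidden must be a multiple of 4 for the float path (checked above).
  float *d_out, *d_vars, *d_means, *d_inp, *d_scale, *d_bias;
  cudaMalloc(&d_out, sizeof(float) * batch * hidden);
  cudaMalloc(&d_vars, sizeof(float) * batch);
  cudaMalloc(&d_means, sizeof(float) * batch);
  cudaMalloc(&d_inp, sizeof(float) * batch * hidden);
  cudaMalloc(&d_scale, sizeof(float) * hidden);
  cudaMalloc(&d_bias, sizeof(float) * hidden);
  // ... fill d_inp, d_scale and d_bias on the device ...
  // means may be passed as nullptr when per-token means are not needed.
  launch_layer_norm<float>(d_out, d_vars, d_means, d_inp, d_scale, d_bias,
                           batch, hidden, stream);
  cudaStreamSynchronize(stream);
  cudaFree(d_out); cudaFree(d_vars); cudaFree(d_means);
  cudaFree(d_inp); cudaFree(d_scale); cudaFree(d_bias);
}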
/**
@brief: ker_ln_bw_dgamma_dbetta
Layer norm backward kernel; computes the gradients of gamma and betta.
dbetta = sum(dout, dim=0)
dgamma = sum(xhat * dout, dim=0)
xhat = (input - mean) * rsqrt(var) or
       (output - betta) / gamma

@thread
gridDim.x = hidden_size / 32
blockDim.x = 32
blockDim.y = 32

@param
gamma_grad: [hidden_size], gradient of gamma
betta_grad: [hidden_size], gradient of betta
out_grad: [batch_size * seq_len, hidden_size], gradient of ln output
inp_or_out: [batch_size * seq_len, hidden_size], ln output if means is nullptr,
            ln input if means is not nullptr
gamma: [hidden_size], gamma of ln, used to compute xhat, may be nullptr
betta: [hidden_size], betta of ln, used to compute xhat, may be nullptr
vars: [batch_size * seq_len], variance of ln forward,
      used to compute xhat, may be nullptr
means: [batch_size * seq_len], mean of ln forward,
       used to compute xhat, may be nullptr
(gamma && betta) ^ (vars && means) should be true
*/
template <typename T>
__global__ void ker_ln_bw_dgamma_dbetta(T *gamma_grad, T *betta_grad,
                                        const T *out_grad,
                                        const T *inp_or_out, const T *gamma,
                                        const T *betta, const T *vars,
                                        const T *means, int rows, int width) {
  __shared__ float betta_buffer[TILE_DIM][TILE_DIM];
  __shared__ float gamma_buffer[TILE_DIM][TILE_DIM];

  cg::thread_block b = cg::this_thread_block();
  cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);

  int idx = blockDim.x * blockIdx.x + threadIdx.x;
  int offset = threadIdx.y * width + idx;
  int y_stride = width * TILE_DIM;

  // Loop across inp height
  float dbetta = 0;
  float dgamma = 0;
  float dout, val;
  if (idx < width) {
    if (means == nullptr) {
      float vbetta = (float)betta[idx];
      float vgamma = (float)gamma[idx];
      for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
        dout = (float)out_grad[offset];
        // inp_or_out is output
        val = (float)inp_or_out[offset];
        dbetta += dout;
        dgamma += ((val - vbetta) / add_eps(vgamma) * dout);
        offset += y_stride;
      }
    } else {
      for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
        dout = (float)out_grad[offset];
        // inp_or_out is input
        val = (float)inp_or_out[offset];
        dbetta += dout;
        dgamma += ((val - (float)means[r]) *
                   rsqrtf((float)vars[r] + LN_EPSILON) * dout);
        offset += y_stride;
      }
    }
  }

  // Sum the shared buffer.
  betta_buffer[threadIdx.x][threadIdx.y] = dbetta;
  gamma_buffer[threadIdx.x][threadIdx.y] = dgamma;
  __syncthreads();
  float s1 = betta_buffer[threadIdx.y][threadIdx.x];
  float s2 = gamma_buffer[threadIdx.y][threadIdx.x];
  __syncthreads();

  for (int i = 1; i < TILE_DIM; i <<= 1) {
    s1 += g.shfl_down(s1, i);
    s2 += g.shfl_down(s2, i);
  }

  int pos = blockIdx.x * TILE_DIM + threadIdx.y;
  if (threadIdx.x == 0 && idx < width) {
    betta_grad[pos] = s1;
    gamma_grad[pos] = s2;
  }
}
/**
@brief: ker_ln_bw_dinp
Layer norm backward kernel; computes the gradient of the input.
dinp = (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / hidden_dim)
       * rsqrt(var)
xhat = (input - mean) * rsqrt(var) if mean is not nullptr
       (output - betta) / gamma    if mean is nullptr
dxhat = dout * gamma

@thread
gridDim.x = batch_size * seq_len
blockDim.x = hidden_size

@param
inp_grad: [batch_size * seq_len, hidden_size], gradient of ln input
out_grad: [batch_size * seq_len, hidden_size], gradient of ln output
residual_grad: [batch_size * seq_len, hidden_size], gradient of the residual
               input, usually present in pre-layer-norm transformer layers,
               may be nullptr
inp_or_out: [batch_size * seq_len, hidden_size], ln output if means is nullptr,
            ln input if means is not nullptr
gamma: [hidden_size], gamma of ln, used to compute xhat and dxhat
betta: [hidden_size], betta of ln, used to compute xhat, may be nullptr
vars: [batch_size * seq_len], variance of ln forward,
      used to compute xhat and dinp
means: [batch_size * seq_len], mean of ln forward,
       used to compute xhat, may be nullptr
*/
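/*
Derivation sketch (the standard layer-norm backward identity): with
y = gamma * xhat + betta and xhat = (x - mu) * rsqrt(var), differentiating
through mu and var gives, per token of width H,

  dxhat = dout * gamma
  dx    = (dxhat - mean(dxhat) - xhat * mean(dxhat * xhat)) * rsqrt(var)

where mean(.) averages over the H elements of the token. This is exactly the
formula above; the kernels below divide the two block-reduced sums by
mean_dim = H before applying it.
*/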
template <typename T>
__global__ void ker_ln_bw_dinp(T *inp_grad, const T *out_grad,
                               const T *residual_grad, const T *inp_or_out,
                               const T *gamma, const T *betta, const T *vars,
                               const T *means, int hidden_dim) {
  int offset = blockIdx.x * hidden_dim + threadIdx.x;
  float4 dxhat, xhat;
  float var_rsqrt;

  if (threadIdx.x < hidden_dim) {
    // step 0. dxhat = dout * gamma
    dxhat = ((const float4 *)out_grad)[offset];
    float4 vgamma = ((const float4 *)gamma)[threadIdx.x];
    dxhat.x *= vgamma.x;
    dxhat.y *= vgamma.y;
    dxhat.z *= vgamma.z;
    dxhat.w *= vgamma.w;

    /*
    step 1. xhat = (output - betta) / gamma or
    (input - mean) * rsqrtf(var)
    */
    xhat = ((const float4 *)inp_or_out)[offset];
    var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON);
    if (means == nullptr) {
      // inp_or_out is output, xhat = (output - betta) / gamma
      float4 vbetta = ((const float4 *)betta)[threadIdx.x];
      xhat.x = (xhat.x - vbetta.x) / add_eps(vgamma.x);
      xhat.y = (xhat.y - vbetta.y) / add_eps(vgamma.y);
      xhat.z = (xhat.z - vbetta.z) / add_eps(vgamma.z);
      xhat.w = (xhat.w - vbetta.w) / add_eps(vgamma.w);
    } else {
      // inp_or_out is input, xhat = (input - mean) * rsqrtf(var)
      float fmean = (float)means[blockIdx.x];
      xhat.x = (xhat.x - fmean) * var_rsqrt;
      xhat.y = (xhat.y - fmean) * var_rsqrt;
      xhat.z = (xhat.z - fmean) * var_rsqrt;
      xhat.w = (xhat.w - fmean) * var_rsqrt;
    }
  }

  /* step2. block reduce sum for dxhat and dxhat*xhat */
  float reduce_val[2] = {0.f, 0.f};
  if (threadIdx.x < hidden_dim) {
    reduce_val[0] = dxhat.x + dxhat.y + dxhat.z + dxhat.w;
    reduce_val[1] = dxhat.x * xhat.x + dxhat.y * xhat.y + dxhat.z * xhat.z +
                    dxhat.w * xhat.w;
  }
  blockReduce<ReduceType::kSum, 2>(reduce_val);
  __shared__ float s_sum_dxhat, s_sum_dxhat_xhat;
  if (threadIdx.x == 0) {
    float mean_dim = hidden_dim * 4;
    s_sum_dxhat = reduce_val[0] / mean_dim;
    s_sum_dxhat_xhat = reduce_val[1] / mean_dim;
  }
  __syncthreads();

  /*
  step3. compute input gradient
  (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var)
  */
  if (threadIdx.x >= hidden_dim) {
    return;
  }
  dxhat.x = (dxhat.x - s_sum_dxhat - xhat.x * s_sum_dxhat_xhat) * var_rsqrt;
  dxhat.y = (dxhat.y - s_sum_dxhat - xhat.y * s_sum_dxhat_xhat) * var_rsqrt;
  dxhat.z = (dxhat.z - s_sum_dxhat - xhat.z * s_sum_dxhat_xhat) * var_rsqrt;
  dxhat.w = (dxhat.w - s_sum_dxhat - xhat.w * s_sum_dxhat_xhat) * var_rsqrt;
  if (residual_grad) {
    // Add the residual grad,
    // usually in pre-layer-norm for transformer layer
    float4 dresidual = ((const float4 *)residual_grad)[offset];
    dxhat.x += dresidual.x;
    dxhat.y += dresidual.y;
    dxhat.z += dresidual.z;
    dxhat.w += dresidual.w;
  }
  ((float4 *)inp_grad)[offset] = dxhat;
}
template <>
__global__ void ker_ln_bw_dinp<__half>(__half *inp_grad,
                                       const __half *out_grad,
                                       const __half *residual_grad,
                                       const __half *inp_or_out,
                                       const __half *gamma,
                                       const __half *betta, const __half *vars,
                                       const __half *means, int hidden_dim) {
  int offset = blockIdx.x * hidden_dim + threadIdx.x;

  float2 dxhat[4], xhat[4];
  float var_rsqrt;
  float4 vtmp;
  __half2 *tmp_h2;
  float reduce_val[2] = {0.f, 0.f};

  if (threadIdx.x < hidden_dim) {
    // step 0. dxhat = dout * gamma
    vtmp = ((const float4 *)out_grad)[offset];
    tmp_h2 = reinterpret_cast<__half2 *>(&vtmp);
    float4 gamma_f4 = ((const float4 *)gamma)[threadIdx.x];
    __half2 *gamma_h2 = reinterpret_cast<__half2 *>(&gamma_f4);
#pragma unroll
    for (int i = 0; i < 4; i++) {
      float2 vdout = __half22float2(tmp_h2[i]);
      float2 vgamma = __half22float2(gamma_h2[i]);
      dxhat[i].x = vdout.x * vgamma.x;
      dxhat[i].y = vdout.y * vgamma.y;
      reduce_val[0] += dxhat[i].x + dxhat[i].y;
    }

    /*
    step 1. xhat = (output - betta) / gamma or
    (input - mean) * rsqrtf(var)
    */
    vtmp = ((const float4 *)inp_or_out)[offset];
    var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON);
    if (means == nullptr) {
      // inp_or_out is output, xhat = (output - betta) / gamma
      float4 vbetta = ((const float4 *)betta)[threadIdx.x];
      __half2 *betta_h2 = reinterpret_cast<__half2 *>(&vbetta);
#pragma unroll
      for (int i = 0; i < 4; i++) {
        float2 vout = __half22float2(tmp_h2[i]);
        float2 vgamma = __half22float2(gamma_h2[i]);
        float2 vbetta = __half22float2(betta_h2[i]);
        xhat[i].x = (vout.x - vbetta.x) / add_eps(vgamma.x);
        xhat[i].y = (vout.y - vbetta.y) / add_eps(vgamma.y);
        reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y;
      }
    } else {
      // inp_or_out is input, xhat = (input - mean) * rsqrtf(var)
      float fmean = (float)means[blockIdx.x];
#pragma unroll
      for (int i = 0; i < 4; i++) {
        float2 vinp = __half22float2(tmp_h2[i]);
        xhat[i].x = (vinp.x - fmean) * var_rsqrt;
        xhat[i].y = (vinp.y - fmean) * var_rsqrt;
        reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y;
      }
    }
  }

  /* step2. block reduce sum for dxhat and dxhat*xhat */
  blockReduce<ReduceType::kSum, 2>(reduce_val);
  __shared__ float s_sum_dxhat, s_sum_dxhat_xhat;
  if (threadIdx.x == 0) {
    float mean_dim = hidden_dim * 8;
    s_sum_dxhat = reduce_val[0] / mean_dim;
    s_sum_dxhat_xhat = reduce_val[1] / mean_dim;
  }
  __syncthreads();

  /*
  step3. compute input gradient
  (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var)
  */
  if (threadIdx.x >= hidden_dim) {
    return;
  }
  if (residual_grad) {
    // Add the residual grad,
    // usually in pre-layer-norm for transformer layer
    float4 dresidual = ((const float4 *)residual_grad)[offset];
    __half *hdres = reinterpret_cast<__half *>(&dresidual);
#pragma unroll
    for (int i = 0; i < 4; i++) {
      tmp_h2[i].x = __float2half(
          (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) *
              var_rsqrt +
          __half2float(hdres[2 * i]));
      tmp_h2[i].y = __float2half(
          (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) *
              var_rsqrt +
          __half2float(hdres[2 * i + 1]));
    }
  } else {
#pragma unroll
    for (int i = 0; i < 4; i++) {
      tmp_h2[i].x = __float2half(
          (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) *
          var_rsqrt);
      tmp_h2[i].y = __float2half(
          (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) *
          var_rsqrt);
    }
  }
  ((float4 *)inp_grad)[offset] = vtmp;
}
__global__ void ker_ln_bw_dinp_x2(__half *inp_grad, const __half *out_grad,
                                  const __half *residual_grad,
                                  const __half *inp_or_out,
                                  const __half *gamma, const __half *betta,
                                  const __half *vars, const __half *means,
                                  int hidden_dim) {
  int offset = blockIdx.x * hidden_dim * 2 + threadIdx.x * 2;

  float2 dxhat[4], xhat[4];
  float2 dxhat_1[4], xhat_1[4];
  float var_rsqrt;
  float4 vtmp, vtmp_1;
  __half2 *tmp_h2;
  __half2 *tmp_h2_1;
  float reduce_val[2] = {0.f, 0.f};

  if (threadIdx.x < hidden_dim) {
    // step 0. dxhat = dout * gamma
    vtmp = ((const float4 *)out_grad)[offset];
    vtmp_1 = ((const float4 *)out_grad)[offset + 1];
    tmp_h2 = reinterpret_cast<__half2 *>(&vtmp);
    tmp_h2_1 = reinterpret_cast<__half2 *>(&vtmp_1);
    float4 gamma_f4 = ((const float4 *)gamma)[threadIdx.x * 2];
    float4 gamma_f4_1 = ((const float4 *)gamma)[threadIdx.x * 2 + 1];
    __half2 *gamma_h2 = reinterpret_cast<__half2 *>(&gamma_f4);
    __half2 *gamma_h2_1 = reinterpret_cast<__half2 *>(&gamma_f4_1);
#pragma unroll
    for (int i = 0; i < 4; i++) {
      float2 vdout = __half22float2(tmp_h2[i]);
      float2 vdout_1 = __half22float2(tmp_h2_1[i]);
      float2 vgamma = __half22float2(gamma_h2[i]);
      float2 vgamma_1 = __half22float2(gamma_h2_1[i]);
      dxhat[i].x = vdout.x * vgamma.x;
      dxhat[i].y = vdout.y * vgamma.y;
      dxhat_1[i].x = vdout_1.x * vgamma_1.x;
      dxhat_1[i].y = vdout_1.y * vgamma_1.y;
      reduce_val[0] += dxhat[i].x + dxhat[i].y + dxhat_1[i].x + dxhat_1[i].y;
    }

    /*
    step 1. xhat = (output - betta) / gamma or
    (input - mean) * rsqrtf(var)
    */
    vtmp = ((const float4 *)inp_or_out)[offset];
    vtmp_1 = ((const float4 *)inp_or_out)[offset + 1];
    var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON);
    if (means == nullptr) {
      // inp_or_out is output, xhat = (output - betta) / gamma
      float4 vbetta = ((const float4 *)betta)[2 * threadIdx.x];
      float4 vbetta_1 = ((const float4 *)betta)[2 * threadIdx.x + 1];
      __half2 *betta_h2 = reinterpret_cast<__half2 *>(&vbetta);
      __half2 *betta_h2_1 = reinterpret_cast<__half2 *>(&vbetta_1);
#pragma unroll
      for (int i = 0; i < 4; i++) {
        float2 vout = __half22float2(tmp_h2[i]);
        float2 vout_1 = __half22float2(tmp_h2_1[i]);
        float2 vgamma = __half22float2(gamma_h2[i]);
        float2 vgamma_1 = __half22float2(gamma_h2_1[i]);
        float2 vbetta = __half22float2(betta_h2[i]);
        float2 vbetta_1 = __half22float2(betta_h2_1[i]);
        xhat[i].x = (vout.x - vbetta.x) / add_eps(vgamma.x);
        xhat_1[i].x = (vout_1.x - vbetta_1.x) / add_eps(vgamma_1.x);
        xhat[i].y = (vout.y - vbetta.y) / add_eps(vgamma.y);
        xhat_1[i].y = (vout_1.y - vbetta_1.y) / add_eps(vgamma_1.y);
        reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y;
        reduce_val[1] +=
            xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y;
      }
    } else {
      // inp_or_out is input, xhat = (input - mean) * rsqrtf(var)
      float fmean = (float)means[blockIdx.x];
#pragma unroll
      for (int i = 0; i < 4; i++) {
        float2 vinp = __half22float2(tmp_h2[i]);
        float2 vinp_1 = __half22float2(tmp_h2_1[i]);
        xhat[i].x = (vinp.x - fmean) * var_rsqrt;
        xhat_1[i].x = (vinp_1.x - fmean) * var_rsqrt;
        xhat[i].y = (vinp.y - fmean) * var_rsqrt;
        xhat_1[i].y = (vinp_1.y - fmean) * var_rsqrt;
        reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y;
        reduce_val[1] +=
            xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y;
      }
    }
  }

  /* step2. block reduce sum for dxhat and dxhat*xhat */
  blockReduce<ReduceType::kSum, 2>(reduce_val);
  __shared__ float s_sum_dxhat, s_sum_dxhat_xhat;
  if (threadIdx.x == 0) {
    float mean_dim = hidden_dim * 8 * 2;
    s_sum_dxhat = reduce_val[0] / mean_dim;
    s_sum_dxhat_xhat = reduce_val[1] / mean_dim;
  }
  __syncthreads();

  /*
  step3. compute input gradient
  (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var)
  */
  if (threadIdx.x >= hidden_dim) {
    return;
  }
  if (residual_grad) {
    // Add the residual grad,
    // usually in pre-layer-norm for transformer layer
    float4 dresidual = ((const float4 *)residual_grad)[offset];
    float4 dresidual_1 = ((const float4 *)residual_grad)[offset + 1];
    __half *hdres = reinterpret_cast<__half *>(&dresidual);
    __half *hdres_1 = reinterpret_cast<__half *>(&dresidual_1);
#pragma unroll
    for (int i = 0; i < 4; i++) {
      tmp_h2[i].x = __float2half(
          (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) *
              var_rsqrt +
          __half2float(hdres[2 * i]));
      tmp_h2_1[i].x = __float2half(
          (dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) *
              var_rsqrt +
          __half2float(hdres_1[2 * i]));
      tmp_h2[i].y = __float2half(
          (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) *
              var_rsqrt +
          __half2float(hdres[2 * i + 1]));
      tmp_h2_1[i].y = __float2half(
          (dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) *
              var_rsqrt +
          __half2float(hdres_1[2 * i + 1]));
    }
  } else {
#pragma unroll
    for (int i = 0; i < 4; i++) {
      tmp_h2[i].x = __float2half(
          (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) *
          var_rsqrt);
      tmp_h2_1[i].x = __float2half(
          (dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) *
          var_rsqrt);
      tmp_h2[i].y = __float2half(
          (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) *
          var_rsqrt);
      tmp_h2_1[i].y = __float2half(
          (dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) *
          var_rsqrt);
    }
  }
  ((float4 *)inp_grad)[offset] = vtmp;
  ((float4 *)inp_grad)[offset + 1] = vtmp_1;
}
__global__ void ker_ln_bw_dinp_x4(__half *inp_grad, const __half *out_grad,
                                  const __half *residual_grad,
                                  const __half *inp_or_out,
                                  const __half *gamma, const __half *betta,
                                  const __half *vars, const __half *means,
                                  int hidden_dim) {
  int offset = blockIdx.x * hidden_dim * 4 + threadIdx.x * 4;

  float2 dxhat[4], xhat[4];
  float2 dxhat_1[4], xhat_1[4];
  float2 dxhat_2[4], xhat_2[4];
  float2 dxhat_3[4], xhat_3[4];
  float var_rsqrt;
  float4 vtmp, vtmp_1, vtmp_2, vtmp_3;
  __half2 *tmp_h2;
  __half2 *tmp_h2_1;
  __half2 *tmp_h2_2;
  __half2 *tmp_h2_3;
  float reduce_val[2] = {0.f, 0.f};

  if (threadIdx.x < hidden_dim) {
    // step 0. dxhat = dout * gamma
    vtmp = ((const float4 *)out_grad)[offset];
    vtmp_1 = ((const float4 *)out_grad)[offset + 1];
    vtmp_2 = ((const float4 *)out_grad)[offset + 2];
    vtmp_3 = ((const float4 *)out_grad)[offset + 3];
    tmp_h2 = reinterpret_cast<__half2 *>(&vtmp);
    tmp_h2_1 = reinterpret_cast<__half2 *>(&vtmp_1);
    tmp_h2_2 = reinterpret_cast<__half2 *>(&vtmp_2);
    tmp_h2_3 = reinterpret_cast<__half2 *>(&vtmp_3);
    float4 gamma_f4 = ((const float4 *)gamma)[threadIdx.x * 4];
    float4 gamma_f4_1 = ((const float4 *)gamma)[threadIdx.x * 4 + 1];
    float4 gamma_f4_2 = ((const float4 *)gamma)[threadIdx.x * 4 + 2];
    float4 gamma_f4_3 = ((const float4 *)gamma)[threadIdx.x * 4 + 3];
    __half2 *gamma_h2 = reinterpret_cast<__half2 *>(&gamma_f4);
    __half2 *gamma_h2_1 = reinterpret_cast<__half2 *>(&gamma_f4_1);
    __half2 *gamma_h2_2 = reinterpret_cast<__half2 *>(&gamma_f4_2);
    __half2 *gamma_h2_3 = reinterpret_cast<__half2 *>(&gamma_f4_3);
#pragma unroll
    for (int i = 0; i < 4; i++) {
      float2 vdout = __half22float2(tmp_h2[i]);
      float2 vdout_1 = __half22float2(tmp_h2_1[i]);
      float2 vdout_2 = __half22float2(tmp_h2_2[i]);
      float2 vdout_3 = __half22float2(tmp_h2_3[i]);
      float2 vgamma = __half22float2(gamma_h2[i]);
      float2 vgamma_1 = __half22float2(gamma_h2_1[i]);
      float2 vgamma_2 = __half22float2(gamma_h2_2[i]);
      float2 vgamma_3 = __half22float2(gamma_h2_3[i]);
      dxhat[i].x = vdout.x * vgamma.x;
      dxhat[i].y = vdout.y * vgamma.y;
      dxhat_1[i].x = vdout_1.x * vgamma_1.x;
      dxhat_1[i].y = vdout_1.y * vgamma_1.y;
      dxhat_2[i].x = vdout_2.x * vgamma_2.x;
      dxhat_2[i].y = vdout_2.y * vgamma_2.y;
      dxhat_3[i].x = vdout_3.x * vgamma_3.x;
      dxhat_3[i].y = vdout_3.y * vgamma_3.y;
      reduce_val[0] += dxhat[i].x + dxhat[i].y + dxhat_1[i].x + dxhat_1[i].y +
                       dxhat_2[i].x + dxhat_2[i].y + dxhat_3[i].x +
                       dxhat_3[i].y;
    }

    /*
    step 1. xhat = (output - betta) / gamma or
    (input - mean) * rsqrtf(var)
    */
    vtmp = ((const float4 *)inp_or_out)[offset];
    vtmp_1 = ((const float4 *)inp_or_out)[offset + 1];
    vtmp_2 = ((const float4 *)inp_or_out)[offset + 2];
    vtmp_3 = ((const float4 *)inp_or_out)[offset + 3];
    var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON);
    if (means == nullptr) {
      // inp_or_out is output, xhat = (output - betta) / gamma
      float4 vbetta = ((const float4 *)betta)[4 * threadIdx.x];
      float4 vbetta_1 = ((const float4 *)betta)[4 * threadIdx.x + 1];
      float4 vbetta_2 = ((const float4 *)betta)[4 * threadIdx.x + 2];
      float4 vbetta_3 = ((const float4 *)betta)[4 * threadIdx.x + 3];
      __half2 *betta_h2 = reinterpret_cast<__half2 *>(&vbetta);
      __half2 *betta_h2_1 = reinterpret_cast<__half2 *>(&vbetta_1);
      __half2 *betta_h2_2 = reinterpret_cast<__half2 *>(&vbetta_2);
      __half2 *betta_h2_3 = reinterpret_cast<__half2 *>(&vbetta_3);
#pragma unroll
      for (int i = 0; i < 4; i++) {
        float2 vout = __half22float2(tmp_h2[i]);
        float2 vout_1 = __half22float2(tmp_h2_1[i]);
        float2 vout_2 = __half22float2(tmp_h2_2[i]);
        float2 vout_3 = __half22float2(tmp_h2_3[i]);
        float2 vgamma = __half22float2(gamma_h2[i]);
        float2 vgamma_1 = __half22float2(gamma_h2_1[i]);
        float2 vgamma_2 = __half22float2(gamma_h2_2[i]);
        float2 vgamma_3 = __half22float2(gamma_h2_3[i]);
        float2 vbetta = __half22float2(betta_h2[i]);
        float2 vbetta_1 = __half22float2(betta_h2_1[i]);
        float2 vbetta_2 = __half22float2(betta_h2_2[i]);
        float2 vbetta_3 = __half22float2(betta_h2_3[i]);
        xhat[i].x = (vout.x - vbetta.x) / add_eps(vgamma.x);
        xhat_1[i].x = (vout_1.x - vbetta_1.x) / add_eps(vgamma_1.x);
        xhat_2[i].x = (vout_2.x - vbetta_2.x) / add_eps(vgamma_2.x);
        xhat_3[i].x = (vout_3.x - vbetta_3.x) / add_eps(vgamma_3.x);
        xhat[i].y = (vout.y - vbetta.y) / add_eps(vgamma.y);
        xhat_1[i].y = (vout_1.y - vbetta_1.y) / add_eps(vgamma_1.y);
        xhat_2[i].y = (vout_2.y - vbetta_2.y) / add_eps(vgamma_2.y);
        xhat_3[i].y = (vout_3.y - vbetta_3.y) / add_eps(vgamma_3.y);
        reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y;
        reduce_val[1] +=
            xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y;
        reduce_val[1] +=
            xhat_2[i].x * dxhat_2[i].x + xhat_2[i].y * dxhat_2[i].y;
        reduce_val[1] +=
            xhat_3[i].x * dxhat_3[i].x + xhat_3[i].y * dxhat_3[i].y;
      }
    } else {
      // inp_or_out is input, xhat = (input - mean) * rsqrtf(var)
      float fmean = (float)means[blockIdx.x];
#pragma unroll
      for (int i = 0; i < 4; i++) {
        float2 vinp = __half22float2(tmp_h2[i]);
        float2 vinp_1 = __half22float2(tmp_h2_1[i]);
        float2 vinp_2 = __half22float2(tmp_h2_2[i]);
        float2 vinp_3 = __half22float2(tmp_h2_3[i]);
        xhat[i].x = (vinp.x - fmean) * var_rsqrt;
        xhat_1[i].x = (vinp_1.x - fmean) * var_rsqrt;
        xhat_2[i].x = (vinp_2.x - fmean) * var_rsqrt;
        xhat_3[i].x = (vinp_3.x - fmean) * var_rsqrt;
        xhat[i].y = (vinp.y - fmean) * var_rsqrt;
        xhat_1[i].y = (vinp_1.y - fmean) * var_rsqrt;
        xhat_2[i].y = (vinp_2.y - fmean) * var_rsqrt;
        xhat_3[i].y = (vinp_3.y - fmean) * var_rsqrt;
        reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y;
        reduce_val[1] +=
            xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y;
        reduce_val[1] +=
            xhat_2[i].x * dxhat_2[i].x + xhat_2[i].y * dxhat_2[i].y;
        reduce_val[1] +=
            xhat_3[i].x * dxhat_3[i].x + xhat_3[i].y * dxhat_3[i].y;
      }
    }
  }

  /* step2. block reduce sum for dxhat and dxhat*xhat */
  blockReduce<ReduceType::kSum, 2>(reduce_val);
  __shared__ float s_sum_dxhat, s_sum_dxhat_xhat;
  if (threadIdx.x == 0) {
    float mean_dim = hidden_dim * 8 * 4;
    s_sum_dxhat = reduce_val[0] / mean_dim;
    s_sum_dxhat_xhat = reduce_val[1] / mean_dim;
  }
  __syncthreads();

  /*
  step3. compute input gradient
  (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var)
  */
  if (threadIdx.x >= hidden_dim) {
    return;
  }
  if (residual_grad) {
    // Add the residual grad,
    // usually in pre-layer-norm for transformer layer
    float4 dresidual = ((const float4 *)residual_grad)[offset];
    float4 dresidual_1 = ((const float4 *)residual_grad)[offset + 1];
    float4 dresidual_2 = ((const float4 *)residual_grad)[offset + 2];
    float4 dresidual_3 = ((const float4 *)residual_grad)[offset + 3];
    __half *hdres = reinterpret_cast<__half *>(&dresidual);
    __half *hdres_1 = reinterpret_cast<__half *>(&dresidual_1);
    __half *hdres_2 = reinterpret_cast<__half *>(&dresidual_2);
    __half *hdres_3 = reinterpret_cast<__half *>(&dresidual_3);
#pragma unroll
    for (int i = 0; i < 4; i++) {
      tmp_h2[i].x = __float2half(
          (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) *
              var_rsqrt +
          __half2float(hdres[2 * i]));
      tmp_h2_1[i].x = __float2half(
          (dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) *
              var_rsqrt +
          __half2float(hdres_1[2 * i]));
      tmp_h2_2[i].x = __float2half(
          (dxhat_2[i].x - s_sum_dxhat - xhat_2[i].x * s_sum_dxhat_xhat) *
              var_rsqrt +
          __half2float(hdres_2[2 * i]));
      tmp_h2_3[i].x = __float2half(
          (dxhat_3[i].x - s_sum_dxhat - xhat_3[i].x * s_sum_dxhat_xhat) *
              var_rsqrt +
          __half2float(hdres_3[2 * i]));
      tmp_h2[i].y = __float2half(
          (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) *
              var_rsqrt +
          __half2float(hdres[2 * i + 1]));
      tmp_h2_1[i].y = __float2half(
          (dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) *
              var_rsqrt +
          __half2float(hdres_1[2 * i + 1]));
      // Each slice must read its own residual buffer here (hdres_2/hdres_3).
      tmp_h2_2[i].y = __float2half(
          (dxhat_2[i].y - s_sum_dxhat - xhat_2[i].y * s_sum_dxhat_xhat) *
              var_rsqrt +
          __half2float(hdres_2[2 * i + 1]));
      tmp_h2_3[i].y = __float2half(
          (dxhat_3[i].y - s_sum_dxhat - xhat_3[i].y * s_sum_dxhat_xhat) *
              var_rsqrt +
          __half2float(hdres_3[2 * i + 1]));
    }
  } else {
#pragma unroll
    for (int i = 0; i < 4; i++) {
      tmp_h2[i].x = __float2half(
          (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) *
          var_rsqrt);
      tmp_h2_1[i].x = __float2half(
          (dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) *
          var_rsqrt);
      tmp_h2_2[i].x = __float2half(
          (dxhat_2[i].x - s_sum_dxhat - xhat_2[i].x * s_sum_dxhat_xhat) *
          var_rsqrt);
      tmp_h2_3[i].x = __float2half(
          (dxhat_3[i].x - s_sum_dxhat - xhat_3[i].x * s_sum_dxhat_xhat) *
          var_rsqrt);
      tmp_h2[i].y = __float2half(
          (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) *
          var_rsqrt);
      tmp_h2_1[i].y = __float2half(
          (dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) *
          var_rsqrt);
      tmp_h2_2[i].y = __float2half(
          (dxhat_2[i].y - s_sum_dxhat - xhat_2[i].y * s_sum_dxhat_xhat) *
          var_rsqrt);
      tmp_h2_3[i].y = __float2half(
          (dxhat_3[i].y - s_sum_dxhat - xhat_3[i].y * s_sum_dxhat_xhat) *
          var_rsqrt);
    }
  }
  ((float4 *)inp_grad)[offset] = vtmp;
  ((float4 *)inp_grad)[offset + 1] = vtmp_1;
  ((float4 *)inp_grad)[offset + 2] = vtmp_2;
  ((float4 *)inp_grad)[offset + 3] = vtmp_3;
}
/**
Layer norm backward:
computes the gradients of gamma, betta and input.
dbetta = sum(dout, dim=0)
xhat = (input - mean) * rsqrt(var) if mean is not nullptr
       (output - betta) / gamma    if mean is nullptr
dgamma = sum(xhat * dout, dim=0)
dxhat = dout * gamma
dinp = (dxhat - (sum(dxhat, 1) + xhat * sum(dxhat * xhat, 1)) / hidden_dim)
       * rsqrt(var)

residual_grad, means and betta may be nullptr.
residual_grad, if not nullptr, is added to dinp,
which is useful in pre-ln transformer layers.
means and betta are only used to compute xhat;
(means == nullptr) ^ (betta == nullptr) should be true.
*/
template <>
void launch_ln_bw<float>(float *gamma_grad, float *betta_grad,
                         float *inp_grad, const float *out_grad,
                         const float *residual_grad, const float *inp_or_out,
                         const float *gamma, const float *betta,
                         const float *vars, const float *means, int batch,
                         int hidden_dim, cudaStream_t stream[2]) {
  // compute grad of gamma and betta
  dim3 grid_dim(((hidden_dim + TILE_DIM - 1) / TILE_DIM) * TILE_DIM);
  dim3 block_dim(TILE_DIM, TILE_DIM);
  ker_ln_bw_dgamma_dbetta<float><<<grid_dim, block_dim, 0, stream[0]>>>(
      gamma_grad, betta_grad, out_grad, inp_or_out, gamma, betta, vars, means,
      batch, hidden_dim);

  // compute grad of input
  if (hidden_dim % 4 != 0 || hidden_dim > 4096) {
    throw std::runtime_error("hidden_dim % 4 != 0 || hidden_dim > 4096");
  }
  hidden_dim >>= 2;
  int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS);
  ker_ln_bw_dinp<<<batch, nthread, 0, stream[1]>>>(
      inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars,
      means, hidden_dim);
}
template <>
void launch_ln_bw<__half>(__half *gamma_grad, __half *betta_grad,
                          __half *inp_grad, const __half *out_grad,
                          const __half *residual_grad,
                          const __half *inp_or_out, const __half *gamma,
                          const __half *betta, const __half *vars,
                          const __half *means, int batch, int hidden_dim,
                          cudaStream_t stream[2]) {
  // compute grad of gamma and betta
  dim3 grid_dim(((hidden_dim + TILE_DIM - 1) / TILE_DIM) * TILE_DIM);
  dim3 block_dim(TILE_DIM, TILE_DIM);
  ker_ln_bw_dgamma_dbetta<__half><<<grid_dim, block_dim, 0, stream[0]>>>(
      gamma_grad, betta_grad, out_grad, inp_or_out, gamma, betta, vars, means,
      batch, hidden_dim);

  // compute grad of input
  if (hidden_dim % 8 != 0) {
    throw std::runtime_error("hidden_dim % 8 != 0");
  }
  hidden_dim >>= 3;

  if (hidden_dim * 8 <= 8192) {
    int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS);
    ker_ln_bw_dinp<<<batch, nthread, 0, stream[1]>>>(
        inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars,
        means, hidden_dim);
  } else if (hidden_dim * 8 > 8192 && hidden_dim * 8 <= 8192 * 2) {
    hidden_dim >>= 1;
    int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS);
    ker_ln_bw_dinp_x2<<<batch, nthread, 0, stream[1]>>>(
        inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars,
        means, hidden_dim);
  } else if (hidden_dim * 8 > 2 * 8192 && hidden_dim * 8 <= 8192 * 4) {
    hidden_dim >>= 2;
    int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS);
    ker_ln_bw_dinp_x4<<<batch, nthread, 0, stream[1]>>>(
        inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars,
        means, hidden_dim);
  } else {
    throw std::runtime_error("hidden_dim % 4 != 0 || hidden_dim > 32768");
  }
}
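A minimal host-side sketch of how launch_ln_bw<float> above might be invoked. This is assumed driver code: the wrapper name and its buffer arguments are illustrative, not part of the commit; only the launch_ln_bw signature and the two-stream convention come from the file.

// Illustrative two-stream driver; buffers are assumed to be device pointers
// already filled by the caller (hidden % 4 == 0 and hidden <= 4096 for float).
void run_ln_bw_example(float *gamma_grad, float *betta_grad, float *inp_grad,
                       const float *out_grad, const float *ln_input,
                       const float *gamma, const float *vars,
                       const float *means, int batch, int hidden) {
  // The dgamma/dbetta kernel and the dinp kernel are independent,
  // so the launcher takes two streams and runs them concurrently.
  cudaStream_t stream[2];
  cudaStreamCreate(&stream[0]);
  cudaStreamCreate(&stream[1]);

  // Passing the ln input plus means, no residual grad, no betta,
  // satisfying the (means == nullptr) ^ (betta == nullptr) convention.
  launch_ln_bw<float>(gamma_grad, betta_grad, inp_grad, out_grad,
                      /*residual_grad=*/nullptr, ln_input, gamma,
                      /*betta=*/nullptr, vars, means, batch, hidden, stream);

  cudaStreamSynchronize(stream[0]);
  cudaStreamSynchronize(stream[1]);
  cudaStreamDestroy(stream[0]);
  cudaStreamDestroy(stream[1]);
}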
colossalai/kernel/cuda_native/csrc/kernels/softmax_kernels.cu (deleted, mode 100644 → 0)
#include <cooperative_groups.h>
#include <math.h>
#include <cub/block/block_load.cuh>
#include <cub/cub.cuh>
#include "block_reduce.h"
#include "kernels.h"
namespace cg = cooperative_groups;
const float EPSILON = 1e-8f;
/**
@brief: softmax_kernel
Softmax forward kernel for
enc-self-attn, dec-self-attn, encdec-attn

@thread
gridDim.x = dynamic
gridDim.y = batch_size
gridDim.z = nhead
blockDim.x = from_len

@param
inp: [batch_size, nhead, from_len, to_len], softmax input.
attn_mask: [batch_size, to_len], padding tokens are -inf,
           non-padding tokens are 0.
attn_mask != nullptr for enc-self-attn and enc-dec-attn
attn_mask == nullptr and mask_future == true for dec-self-attn training
attn_mask == nullptr and mask_future == false for dec-self-attn inference
*/
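/*
Numerically stable softmax, as implemented by the three steps in the kernels
below:
  softmax(x)_j = exp(x_j - max_k x_k) / (sum_k exp(x_k - max_k x_k) + EPSILON)
Subtracting the row max keeps exp() in range; EPSILON guards the division for
fully masked rows.
*/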
template <typename T, int block_dim, int ele_per_thread>
__global__ void ker_attn_softmax(T *inp, const T *attn_mask, int from_len,
                                 int to_len, bool mask_future) {
  int batch_id = blockIdx.y;
  int head_id = blockIdx.z;
  const int nhead = gridDim.z;
  const int token_per_reduce = 1;
  typedef cub::BlockLoad<T, block_dim, ele_per_thread,
                         cub::BLOCK_LOAD_VECTORIZE>
      BlockLoad;
  __shared__ typename BlockLoad::TempStorage ts_load;
  typedef cub::BlockStore<T, block_dim, ele_per_thread,
                          cub::BLOCK_STORE_VECTORIZE>
      BlockStore;
  __shared__ typename BlockStore::TempStorage ts_store;

  T mval[ele_per_thread];
  if (attn_mask) {
    attn_mask += batch_id * to_len;
    BlockLoad(ts_load).Load(attn_mask, mval, to_len, REDUCE_FLOAT_INF_NEG);
  }

  inp += flat_3dim(batch_id, head_id, 0, nhead, from_len * to_len);
  for (int token_id = blockIdx.x * token_per_reduce; token_id < from_len;
       token_id += gridDim.x * token_per_reduce) {
    T inp_val[token_per_reduce][ele_per_thread];
    for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) {
      BlockLoad(ts_load).Load(inp + (token_id + i) * to_len, inp_val[i],
                              to_len, REDUCE_FLOAT_INF_NEG);
    }

    /* step 1. compute max */
    // thread local max
    float val[token_per_reduce][ele_per_thread];
    float l_max[token_per_reduce];
    for (int i = 0; i < token_per_reduce; i++) {
      l_max[i] = REDUCE_FLOAT_INF_NEG;
      for (int j = 0; j < ele_per_thread; j++) {
        if (attn_mask) {
          val[i][j] = (float)inp_val[i][j] + (float)mval[j];
        } else {
          if (mask_future && ele_per_thread * threadIdx.x + j > token_id + i) {
            val[i][j] = REDUCE_FLOAT_INF_NEG;
          } else {
            val[i][j] = (float)inp_val[i][j];
          }
        }
        l_max[i] = fmaxf(l_max[i], val[i][j]);
      }
    }
    // block reduce max
    blockReduce<ReduceType::kMax, token_per_reduce>(l_max);
    // write shared
    __shared__ float s_max[token_per_reduce];
    if (threadIdx.x == 0) {
      for (int i = 0; i < token_per_reduce; i++) {
        s_max[i] = l_max[i];
      }
    }
    __syncthreads();

    /* step 2. compute sum */
    // thread local sum
    float l_sum[token_per_reduce];
    for (int i = 0; i < token_per_reduce; i++) {
      l_sum[i] = 0.f;
      for (int j = 0; j < ele_per_thread; j++) {
        val[i][j] = __expf(val[i][j] - s_max[i]);
        l_sum[i] += val[i][j];
      }
    }
    // block reduce sum
    blockReduce<ReduceType::kSum, token_per_reduce>(l_sum);
    // write shared
    __shared__ float s_sum[token_per_reduce];
    if (threadIdx.x == 0) {
      for (int i = 0; i < token_per_reduce; i++) {
        s_sum[i] = __fdividef(1.0f, l_sum[i] + EPSILON);
      }
    }
    __syncthreads();

    /* step 3. compute final result */
    for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) {
      for (int j = 0; j < ele_per_thread; j++) {
        inp_val[i][j] = (T)(val[i][j] * s_sum[i]);
      }
      BlockStore(ts_store).Store(inp + (token_id + i) * to_len, inp_val[i],
                                 to_len);
    }
  }  // blockIdx.x
}
template <typename T, int block_dim, int ele_per_thread>
__global__ void ker_attn_softmax_lt32(T *inp, const T *attn_mask,
                                      int from_len, int to_len,
                                      bool mask_future) {
  int batch_id = blockIdx.y;
  int head_id = blockIdx.z;
  const int nhead = gridDim.z;
  const int token_per_reduce = 1;
  typedef cub::BlockLoad<T, block_dim, ele_per_thread,
                         cub::BLOCK_LOAD_VECTORIZE>
      BlockLoad;
  __shared__ typename BlockLoad::TempStorage ts_load;
  typedef cub::BlockStore<T, block_dim, ele_per_thread,
                          cub::BLOCK_STORE_VECTORIZE>
      BlockStore;
  __shared__ typename BlockStore::TempStorage ts_store;

  T mval[ele_per_thread];
  if (attn_mask) {
    attn_mask += batch_id * to_len;
    BlockLoad(ts_load).Load(attn_mask, mval, to_len, REDUCE_FLOAT_INF_NEG);
  }

  inp += flat_3dim(batch_id, head_id, 0, nhead, from_len * to_len);
  for (int token_id = blockIdx.x * token_per_reduce; token_id < from_len;
       token_id += gridDim.x * token_per_reduce) {
    T inp_val[token_per_reduce][ele_per_thread];
    for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) {
      BlockLoad(ts_load).Load(inp + (token_id + i) * to_len, inp_val[i],
                              to_len, REDUCE_FLOAT_INF_NEG);
    }

    /* step 1. compute max */
    // thread local max
    float val[token_per_reduce][ele_per_thread];
    float l_max[token_per_reduce];
    for (int i = 0; i < token_per_reduce; i++) {
      l_max[i] = REDUCE_FLOAT_INF_NEG;
      for (int j = 0; j < ele_per_thread; j++) {
        if (attn_mask) {
          val[i][j] = (float)inp_val[i][j] + (float)mval[j];
        } else {
          if (mask_future && ele_per_thread * threadIdx.x + j > token_id + i) {
            val[i][j] = REDUCE_FLOAT_INF_NEG;
          } else {
            val[i][j] = (float)inp_val[i][j];
          }
        }
        l_max[i] = fmaxf(l_max[i], val[i][j]);
      }
    }
    // warp reduce max
    warpReduce<ReduceType::kMax, token_per_reduce>(l_max);

    /* step 2. compute sum */
    // thread local sum
    float l_sum[token_per_reduce];
    for (int i = 0; i < token_per_reduce; i++) {
      l_sum[i] = 0.f;
      for (int j = 0; j < ele_per_thread; j++) {
        val[i][j] = __expf(val[i][j] - l_max[i]);
        l_sum[i] += val[i][j];
      }
    }
    // warp reduce sum
    warpReduce<ReduceType::kSum, token_per_reduce>(l_sum);

    /* step 3. compute final result */
    for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) {
      l_sum[i] = __fdividef(1.0f, l_sum[i] + EPSILON);
      for (int j = 0; j < ele_per_thread; j++) {
        inp_val[i][j] = (T)(val[i][j] * l_sum[i]);
      }
      BlockStore(ts_store).Store(inp + (token_id + i) * to_len, inp_val[i],
                                 to_len);
    }
  }  // blockIdx.x
}
/*
  attn_mask != nullptr for enc-self-attn and enc-dec-attn
  attn_mask == nullptr and mask_future = true for dec-self-attn training
  attn_mask == nullptr and mask_future = false for dec-self-attn infer
*/
template <>
void launch_attn_softmax<float>(float *inp, const float *attn_mask,
                                int batch_size, int nhead, int from_len,
                                int to_len, bool mask_future,
                                cudaStream_t stream) {
  dim3 grid_dim(1, batch_size, nhead);
  if (to_len <= 32) {
    ker_attn_softmax_lt32<float, 32, 1><<<grid_dim, 32, 0, stream>>>(
        inp, attn_mask, from_len, to_len, mask_future);
  } else if (to_len <= 64) {
    ker_attn_softmax_lt32<float, 32, 2><<<grid_dim, 32, 0, stream>>>(
        inp, attn_mask, from_len, to_len, mask_future);
  } else if (to_len <= 128) {
    grid_dim.x = 16;
    ker_attn_softmax<float, 64, 2><<<grid_dim, 64, 0, stream>>>(
        inp, attn_mask, from_len, to_len, mask_future);
  } else if (to_len <= 256) {
    grid_dim.x = 32;
    ker_attn_softmax<float, 128, 2><<<grid_dim, 128, 0, stream>>>(
        inp, attn_mask, from_len, to_len, mask_future);
  } else if (to_len <= 512) {
    grid_dim.x = 64;
    ker_attn_softmax<float, 256, 2><<<grid_dim, 256, 0, stream>>>(
        inp, attn_mask, from_len, to_len, mask_future);
  } else {
    throw std::runtime_error(
        "Sequence length greater than 512 is currently not supported");
  }
}
template <>
void launch_attn_softmax<__half>(__half *inp, const __half *attn_mask,
                                 int batch_size, int nhead, int from_len,
                                 int to_len, bool mask_future,
                                 cudaStream_t stream) {
  dim3 grid_dim(1, batch_size, nhead);
  if (to_len <= 32) {
    ker_attn_softmax_lt32<__half, 32, 1><<<grid_dim, 32, 0, stream>>>(
        inp, attn_mask, from_len, to_len, mask_future);
  } else if (to_len <= 64) {
    ker_attn_softmax_lt32<__half, 32, 2><<<grid_dim, 32, 0, stream>>>(
        inp, attn_mask, from_len, to_len, mask_future);
  } else if (to_len <= 128) {
    grid_dim.x = 8;
    ker_attn_softmax<__half, 64, 2><<<grid_dim, 64, 0, stream>>>(
        inp, attn_mask, from_len, to_len, mask_future);
  } else if (to_len <= 256) {
    grid_dim.x = 16;
    ker_attn_softmax<__half, 128, 2><<<grid_dim, 128, 0, stream>>>(
        inp, attn_mask, from_len, to_len, mask_future);
  } else if (to_len <= 512) {
    grid_dim.x = 32;
    ker_attn_softmax<__half, 256, 2><<<grid_dim, 256, 0, stream>>>(
        inp, attn_mask, from_len, to_len, mask_future);
  } else {
    throw std::runtime_error(
        "Sequence length greater than 512 is currently not supported");
  }
}
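Taken together, the two launchers select a kernel variant by to_len; functionally the forward pass is an additively masked, optionally causal softmax over the last dimension. Below is a minimal PyTorch sketch of the same semantics, useful for unit-testing the kernels; the name attn_softmax_ref and the use of torch are illustrative and not part of this file:

import torch

def attn_softmax_ref(scores, attn_mask=None, mask_future=False):
    # scores: [batch, nhead, from_len, to_len]
    # attn_mask: [batch, to_len], additive (0 keeps, large negative drops)
    if attn_mask is not None:
        scores = scores + attn_mask[:, None, None, :]
    elif mask_future:
        from_len, to_len = scores.shape[-2:]
        future = torch.triu(torch.ones(from_len, to_len, dtype=torch.bool,
                                       device=scores.device), diagonal=1)
        scores = scores.masked_fill(future, float("-inf"))
    return torch.softmax(scores, dim=-1)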
/**
@brief: ker_attn_softmax_bw
Softmax backward in self attention.
@thread
gridDim.x = batch_size * nhead * seq_len / warps_per_block
blockDim.x = WARP_SIZE
blockDim.y = warps_per_block
@param
grad: [batch_size, nhead, seq_len, seq_len], output grad.
output: [batch_size, nhead, seq_len, seq_len], output of softmax forward.
*/
template <typename T, int ITERATIONS>
__global__ void ker_attn_softmax_bw(T *grad, const T *inp,
                                    int softmax_length) {
  int batch_idx = blockIdx.x * blockDim.y + threadIdx.y;
  int offset = batch_idx * softmax_length + threadIdx.x;

  grad += offset;
  inp += offset;

  T grad_reg[ITERATIONS];
  T inp_reg[ITERATIONS];
  float sum = 0.0;

#pragma unroll
  for (int i = 0; i < ITERATIONS; ++i) {
    int curr_idx = threadIdx.x + i * WARP_SIZE;
    if (curr_idx < softmax_length) {
      grad_reg[i] = grad[i * WARP_SIZE];
      inp_reg[i] = inp[i * WARP_SIZE];
      sum += (float)grad_reg[i] * (float)inp_reg[i];
    }
  }

  cg::thread_block b = cg::this_thread_block();
  cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

  for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_xor(sum, i);

#pragma unroll
  for (int i = 0; i < ITERATIONS; ++i) {
    int curr_idx = threadIdx.x + i * WARP_SIZE;
    if (curr_idx < softmax_length)
      grad[i * WARP_SIZE] =
          (T)((float)inp_reg[i] * ((float)grad_reg[i] - sum));
  }
}
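The loop above implements the standard softmax Jacobian-vector product. With y = softmax(x) the row-wise gradient is

\[ \frac{\partial L}{\partial x_i} = y_i \Bigl( g_i - \sum_j g_j\, y_j \Bigr), \]

where g is the incoming gradient: inp_reg holds y, grad_reg holds g, and the shfl_xor loop reduces the dot product sum = \(\sum_j g_j y_j\) across the warp before the final write.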
template <typename T>
void launch_attn_softmax_bw(T *out_grad, const T *soft_inp, int rows,
                            int softmax_len, cudaStream_t stream) {
  const int warps_per_block = 4;
  // rows = batch_size * nhead * from_len
  dim3 grid_dim(rows / warps_per_block);
  dim3 block_dim(WARP_SIZE, warps_per_block);

  if (softmax_len <= 32)
    ker_attn_softmax_bw<T, 1>
        <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
  else if (softmax_len <= 64)
    ker_attn_softmax_bw<T, 2>
        <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
  else if (softmax_len <= 128)
    ker_attn_softmax_bw<T, 4>
        <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
  else if (softmax_len <= 256)
    ker_attn_softmax_bw<T, 8>
        <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
  else if (softmax_len <= 384)
    ker_attn_softmax_bw<T, 12>
        <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
  else if (softmax_len <= 512)
    ker_attn_softmax_bw<T, 16>
        <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
  else if (softmax_len <= 768)
    ker_attn_softmax_bw<T, 24>
        <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
  else if (softmax_len <= 1024)
    ker_attn_softmax_bw<T, 32>
        <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
  else if (softmax_len <= 2048)
    ker_attn_softmax_bw<T, 64>
        <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, softmax_len);
  else
    throw std::runtime_error(
        std::string(
            "Special sequence length found in softmax backward, seq_len: ") +
        std::to_string(softmax_len));
}
template void launch_attn_softmax_bw<__half>(__half *out_grad,
                                             const __half *soft_inp, int rows,
                                             int softmax_len,
                                             cudaStream_t stream);
template void launch_attn_softmax_bw<float>(float *out_grad,
                                            const float *soft_inp, int rows,
                                            int softmax_len,
                                            cudaStream_t stream);
colossalai/kernel/cuda_native/csrc/kernels/transform_kernels.cu
deleted
100644 → 0
View file @
bce9499e
#include <cub/block/block_load.cuh>
#include <cub/block/block_scan.cuh>
#include <cub/block/block_store.cuh>
#include "kernels.h"
using namespace cub;
/**
@brief: transform_0213
Split the attention heads and reshape input
during the backward pass of encoder self-attention
@thread
gridDim.x = batch_size
gridDim.y = seq_len
blockDim.x = min(hidden_dim, MAX_THREADS)
@param
input: [batch_size, seq_len, hidden_dim]
output: [batch_size, nhead, seq_len, head_dim]
batch_size: the size of the current batch
seq_len: the sequence length of the current batch
hidden_dim: dim of the hidden tensor
nhead: number of attention heads
*/
template <typename T>
__global__ void transform_0213(T *output, const T *input, int hidden_dim,
                               int head_dim);

template <>
__global__ void transform_0213<float>(float *output, const float *input,
                                      int hidden_dim, int head_dim) {
  int batch_id = blockIdx.x;
  int token_id = blockIdx.y;
  int seq_len = gridDim.y;
  int nhead = hidden_dim / head_dim;

  // [b, s, h]
  int src_offset = flat_3dim(batch_id, token_id, 0, seq_len, hidden_dim);
  // [b, nh, s, ad]
  int trg_offset =
      flat_4dim(batch_id, 0, token_id, 0, nhead, seq_len, head_dim);

  const float4 *input4 = reinterpret_cast<const float4 *>(input);
  float4 *res4 = reinterpret_cast<float4 *>(output);
  float4 vinput4;

  for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) {
    vinput4 = input4[src_offset + i];

    int head_id = i / head_dim;
    int dim_id = i % head_dim;
    int cur_trg_offset = flat_3dim(head_id, 0, dim_id, seq_len, head_dim);
    res4[trg_offset + cur_trg_offset] = vinput4;
  }
}
template <>
__global__ void transform_0213<__half>(__half *output, const __half *input,
                                       int hidden_dim, int head_dim) {
  int batch_id = blockIdx.x;
  int token_id = blockIdx.y;
  int seq_len = gridDim.y;
  int nhead = hidden_dim / head_dim;

  // [b, s, h]
  int src_offset = flat_3dim(batch_id, token_id, 0, seq_len, hidden_dim);
  // [b, nh, s, ad]
  int trg_offset =
      flat_4dim(batch_id, 0, token_id, 0, nhead, seq_len, head_dim);

  const float4 *input4 = reinterpret_cast<const float4 *>(input);
  float4 *res4 = reinterpret_cast<float4 *>(output);
  float4 vinput4;

  for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) {
    vinput4 = input4[src_offset + i];

    int head_id = i / head_dim;
    int dim_id = i % head_dim;
    int cur_trg_offset = flat_3dim(head_id, 0, dim_id, seq_len, head_dim);
    res4[trg_offset + cur_trg_offset] = vinput4;
  }
}
// [b, s, h] -> [b, nh, s, ad]
template <>
void launch_transform_0213<float>(float *output, const float *input,
                                  int batch_size, int seq_len, int hidden_dim,
                                  int nhead, cudaStream_t stream) {
  hidden_dim >>= 2;
  int head_dim = hidden_dim / nhead;

  dim3 grid_dim(batch_size, seq_len);
  dim3 block_dim(min(hidden_dim, MAX_THREADS));

  transform_0213<float><<<grid_dim, block_dim, 0, stream>>>(
      output, input, hidden_dim, head_dim);
}

template <>
void launch_transform_0213<__half>(__half *output, const __half *input,
                                   int batch_size, int seq_len, int hidden_dim,
                                   int nhead, cudaStream_t stream) {
  hidden_dim >>= 3;
  int head_dim = hidden_dim / nhead;

  dim3 grid_dim(batch_size, seq_len);
  dim3 block_dim(min(hidden_dim, MAX_THREADS));

  transform_0213<__half><<<grid_dim, block_dim, 0, stream>>>(
      output, input, hidden_dim, head_dim);
}
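A host-side reference of the same reshape, handy for checking the kernel output; this PyTorch sketch (the helper name is ours) is illustrative only and not part of the original file:

import torch

def transform_0213_ref(x: torch.Tensor, nhead: int) -> torch.Tensor:
    # [b, s, h] -> [b, nh, s, h // nh], matching launch_transform_0213
    b, s, h = x.shape
    return x.view(b, s, nhead, h // nhead).permute(0, 2, 1, 3).contiguous()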
/**
@brief: bias_add_transform_20314
Add bias to input, transform from
[0, 1, 2, 3, 4] to [2, 0, 3, 1, 4]
@thread
gridDim.x = dim_0
gridDim.y = dim_1
gridDim.z = dim_2
blockDim.x = min(dim_3 * dim_4, MAX_THREADS)
@param
input: [dim_0, dim_1, dim_2, dim_3, dim_4]
bias: [dim_2, dim_3, dim_4]
output: [dim_2, dim_0, dim_3, dim_1, dim_4]
*/
template <typename T>
__global__ void bias_add_transform_20314(T *output, const T *input,
                                         const T *bias, int dim_3, int dim_4);

template <>
__global__ void bias_add_transform_20314<float>(float *output,
                                                const float *input,
                                                const float *bias, int dim_3,
                                                int dim_4) {
  int id0 = blockIdx.x;
  int id1 = blockIdx.y;
  int id2 = blockIdx.z;
  int dim_0 = gridDim.x;
  int dim_1 = gridDim.y;
  int dim_2 = gridDim.z;
  int dim_34 = dim_3 * dim_4;

  int src_offset = flat_4dim(id0, id1, id2, 0, dim_1, dim_2, dim_34);
  int trg_offset = flat_5dim(id2, id0, 0, id1, 0, dim_0, dim_3, dim_1, dim_4);
  int bias_offset = flat_2dim(id2, 0, dim_34);

  const float4 *qkv4 = reinterpret_cast<const float4 *>(input);
  const float4 *bias4 = reinterpret_cast<const float4 *>(bias);
  float4 *res4 = reinterpret_cast<float4 *>(output);
  float4 vqkv4;
  float4 vbias4;
  float4 vres4;

  for (std::size_t i = threadIdx.x; i < dim_34; i += blockDim.x) {
    vqkv4 = qkv4[src_offset + i];
    vbias4 = bias4[bias_offset + i];
    vres4.x = vqkv4.x + vbias4.x;
    vres4.y = vqkv4.y + vbias4.y;
    vres4.z = vqkv4.z + vbias4.z;
    vres4.w = vqkv4.w + vbias4.w;

    int id3 = i / dim_4;
    int id4 = i % dim_4;
    int cur_trg_offset = flat_3dim(id3, 0, id4, dim_1, dim_4);
    res4[trg_offset + cur_trg_offset] = vres4;
  }
}
template <>
__global__ void bias_add_transform_20314<__half>(__half *output,
                                                 const __half *input,
                                                 const __half *bias, int dim_3,
                                                 int dim_4) {
  int id0 = blockIdx.x;
  int id1 = blockIdx.y;
  int id2 = blockIdx.z;
  int dim_0 = gridDim.x;
  int dim_1 = gridDim.y;
  int dim_2 = gridDim.z;
  int dim_34 = dim_3 * dim_4;

  int src_offset = flat_4dim(id0, id1, id2, 0, dim_1, dim_2, dim_34);
  int trg_offset = flat_5dim(id2, id0, 0, id1, 0, dim_0, dim_3, dim_1, dim_4);
  int bias_offset = flat_2dim(id2, 0, dim_34);

  const float4 *qkv4 = reinterpret_cast<const float4 *>(input);
  const float4 *bias4 = reinterpret_cast<const float4 *>(bias);
  float4 *res4 = reinterpret_cast<float4 *>(output);
  float4 vqkv4;
  float4 vbias4;
  float4 vres4;
  __half2 *h2_qkv = reinterpret_cast<__half2 *>(&vqkv4);
  __half2 *h2_bias = reinterpret_cast<__half2 *>(&vbias4);
  __half2 *h2_res = reinterpret_cast<__half2 *>(&vres4);

  for (std::size_t i = threadIdx.x; i < dim_34; i += blockDim.x) {
    vqkv4 = qkv4[src_offset + i];
    vbias4 = bias4[bias_offset + i];
    h2_res[0] = __hadd2(h2_qkv[0], h2_bias[0]);
    h2_res[1] = __hadd2(h2_qkv[1], h2_bias[1]);
    h2_res[2] = __hadd2(h2_qkv[2], h2_bias[2]);
    h2_res[3] = __hadd2(h2_qkv[3], h2_bias[3]);

    int id3 = i / dim_4;
    int id4 = i % dim_4;
    int cur_trg_offset = flat_3dim(id3, 0, id4, dim_1, dim_4);
    res4[trg_offset + cur_trg_offset] = vres4;
  }
}
// [b, s, 3, h] -> [3, b, nh, s, ad]
template <>
void launch_bias_add_transform_20314<float>(float *output, const float *input,
                                            const float *bias, int dim_0,
                                            int dim_1, int dim_2, int dim_3,
                                            int dim_4, cudaStream_t stream) {
  dim_4 >>= 2;

  dim3 grid_dim(dim_0, dim_1, dim_2);
  dim3 block_dim(min(dim_3 * dim_4, MAX_THREADS));

  bias_add_transform_20314<float><<<grid_dim, block_dim, 0, stream>>>(
      output, input, bias, dim_3, dim_4);
}

template <>
void launch_bias_add_transform_20314<__half>(__half *output,
                                             const __half *input,
                                             const __half *bias, int dim_0,
                                             int dim_1, int dim_2, int dim_3,
                                             int dim_4, cudaStream_t stream) {
  dim_4 >>= 3;

  dim3 grid_dim(dim_0, dim_1, dim_2);
  dim3 block_dim(min(dim_3 * dim_4, MAX_THREADS));

  bias_add_transform_20314<__half><<<grid_dim, block_dim, 0, stream>>>(
      output, input, bias, dim_3, dim_4);
}
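The index juggling is easier to read as a permutation; an equivalent PyTorch sketch (illustrative only, not part of the original file) of the [0, 1, 2, 3, 4] -> [2, 0, 3, 1, 4] bias-add transform:

import torch

def bias_add_transform_20314_ref(x: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    # x: [d0, d1, d2, d3, d4], bias: [d2, d3, d4] -> out: [d2, d0, d3, d1, d4]
    return (x + bias).permute(2, 0, 3, 1, 4).contiguous()

For the QKV case the launchers serve, d0 = batch, d1 = seq_len, d2 = 3, d3 = nhead, d4 = head_dim, i.e. [b, s, 3, h] -> [3, b, nh, s, ad].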
/**
@brief: transform4d_0213
Reshape the input matrix to merge the heads
@thread
gridDim.x = (num_all + max_block_thread - 1) / max_block_thread
blockDim.x = max_block_thread
@param
input: [trans_count, batch_size, nhead, seq_len, head_dim]
output: [batch_size, seq_len, trans_count, nhead, head_dim]
batch_size: the size of the current batch
seq_len: the sequence length of the current batch
hidden_dim: dim of the hidden tensor
nhead: number of attention heads
trans_count: 1 or 3, the number of matrices that need to be transformed
*/
template <typename T>
__global__ void transform4d_0213(T *output, const T *input, int batch_size,
                                 int seq_len, int trans_count, int nhead,
                                 int head_dim, int num_all) {
  int offset = blockIdx.x * blockDim.x + threadIdx.x;
  if (offset >= num_all) {
    return;
  }
  int trans_id, batch_id, head_id, token_id, dim_id;
  decompose_5dim(offset, batch_size, nhead, seq_len, head_dim, &trans_id,
                 &batch_id, &head_id, &token_id, &dim_id);
  // [b, s, tc, nh, ad]
  int trg_offset = flat_5dim(batch_id, token_id, trans_id, head_id, dim_id,
                             seq_len, trans_count, nhead, head_dim);

  const float4 *input4 = reinterpret_cast<const float4 *>(input);
  float4 *res4 = reinterpret_cast<float4 *>(output);
  res4[trg_offset] = input4[offset];
}
// [tc, b, nh, s, ad] -> [b, s, tc, nh, ad]
template <>
void launch_transform4d_0213<float>(float *output, const float *input,
                                    int batch_size, int seq_len,
                                    int hidden_dim, int nhead, int trans_count,
                                    cudaStream_t stream) {
  hidden_dim >>= 2;
  int head_dim = hidden_dim / nhead;
  int num_all = batch_size * seq_len * trans_count * hidden_dim;
  int nblock = (num_all + MAX_THREADS - 1) / MAX_THREADS;

  transform4d_0213<float><<<nblock, MAX_THREADS, 0, stream>>>(
      output, input, batch_size, seq_len, trans_count, nhead, head_dim,
      num_all);
}

template <>
void launch_transform4d_0213<__half>(__half *output, const __half *input,
                                     int batch_size, int seq_len,
                                     int hidden_dim, int nhead,
                                     int trans_count, cudaStream_t stream) {
  hidden_dim >>= 3;
  int head_dim = hidden_dim / nhead;
  int num_all = batch_size * seq_len * trans_count * hidden_dim;
  int nblock = (num_all + MAX_THREADS - 1) / MAX_THREADS;

  transform4d_0213<__half><<<nblock, MAX_THREADS, 0, stream>>>(
      output, input, batch_size, seq_len, trans_count, nhead, head_dim,
      num_all);
}
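And the head-merging direction performed by transform4d_0213, again as an illustrative PyTorch reference (not part of the original file):

import torch

def transform4d_0213_ref(x: torch.Tensor) -> torch.Tensor:
    # [tc, b, nh, s, ad] -> [b, s, tc, nh, ad]
    return x.permute(1, 3, 0, 2, 4).contiguous()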
colossalai/kernel/cuda_native/csrc/multihead_attention_1d.cpp
deleted
100644 → 0
View file @
bce9499e
#include "multihead_attention_1d.h"
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <torch/torch.h>
#if TORCH_VERSION_MAJOR > 1 || \
(TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 13)
#include <torch/csrc/distributed/c10d/Types.hpp>
#else
#include <c10d/Types.hpp>
#endif
#include <iostream>
#include "context.h"
#include "kernels.h"
template <typename T>
MultiHeadAttention<T>::MultiHeadAttention(int layer_id, int max_batch_tokens,
                                          int max_seq_len, int hidden_size,
                                          int num_heads,
                                          float attn_prob_dropout_ratio,
                                          float hidden_output_dropout_ratio,
                                          bool pre_or_postLayerNorm)
    : _layer_id(layer_id),
      _max_batch_tokens(max_batch_tokens),
      _max_seq_len(max_seq_len),
      _hidden_size(hidden_size),
      _heads(num_heads),
      _training(true),
      _pre_or_postLayerNorm(pre_or_postLayerNorm),
      _qkv_linear(
          typename FeedForward<T>::Config(3 * hidden_size, hidden_size)),
      _attn_out_linear(
          typename FeedForward<T>::Config(hidden_size, hidden_size)),
      _attn_ln(typename Normalize_Layer<T>::Config(hidden_size, false),
               _max_batch_tokens),
      _softmax(typename Softmax<T>::Config(num_heads)),
      _attn_prob_dropout(typename Dropout<T>::Config(attn_prob_dropout_ratio),
                         _max_batch_tokens * _heads * _max_seq_len),
      _attn_dropout(typename Dropout<T>::Config(hidden_output_dropout_ratio),
                    _max_batch_tokens * _hidden_size),
      _attn_scores(typename StridedBatchGemm<T>::Config(
          (T(1.0) / T(sqrt(_hidden_size / _heads))), T(0.0), CUBLAS_OP_T,
          CUBLAS_OP_N)),
      _attn_context(typename StridedBatchGemm<T>::Config(
          T(1.0), T(0.0), CUBLAS_OP_N, CUBLAS_OP_N)) {
  assert(_hidden_size % _heads == 0);
}

template <typename T>
MultiHeadAttention<T>::~MultiHeadAttention() {
  free_mem_buffer();
}
template <typename T>
void MultiHeadAttention<T>::attn_layer_fw(const T *input_ptr,
                                          const T *input_mask_ptr,
                                          T *output_ptr, T *buffer) {
  T *q_tf_ptr = _qkv_ptr;
  T *k_tf_ptr = q_tf_ptr + _batch_dim / pg_size;
  T *v_tf_ptr = k_tf_ptr + _batch_dim / pg_size;

  if (_pre_or_postLayerNorm) {
    _attn_ln.Forward(_gemmQKV_inp_ptr, input_ptr, _attn_nw_ptr, _attn_nb_ptr,
                     _batch_tokens, _stream);
  }
  const T *gemmQKV_inp_ptr =
      _pre_or_postLayerNorm ? _gemmQKV_inp_ptr : input_ptr;
  _qkv_linear.reset_size(3 * _hidden_size / pg_size, _hidden_size);
  _qkv_linear.Forward(_batch_tokens, gemmQKV_inp_ptr, _attn_qkvw_ptr, buffer,
                      _cublasHandle);

  launch_bias_add_transform_20314<T>(q_tf_ptr, buffer, _attn_qkvb_ptr,
                                     _batch_size, _seq_len, 3,
                                     _heads / pg_size, _hidden_size / _heads,
                                     _stream);

  // attention scores, q*k
  _attn_scores.Forward(_batch_heads, _soft_out_ptr, k_tf_ptr, q_tf_ptr,
                       _cublasHandle);

  // Softmax + Mask
  _softmax.reset_size(_heads / pg_size);
  _softmax.Forward(_soft_out_ptr, input_mask_ptr, _batch_size, _seq_len,
                   _seq_len, _stream, true);

  // attn prob dropout.
  _attn_prob_dropout.dropout(_ctx_bufB_ptr, _soft_out_ptr,
                             _batch_heads * _seq_len * _seq_len, _stream);

  // attention context, score * v
  _attn_context.Forward(_batch_heads, buffer, v_tf_ptr, _ctx_bufB_ptr,
                        _cublasHandle);

  // [b, nh, s, ad] -> [b, s, nh, ad]
  launch_transform4d_0213<T>(_attn_o_inp_ptr, buffer, _batch_size, _seq_len,
                             _hidden_size / pg_size, _heads / pg_size, 1,
                             _stream);

  _attn_out_linear.reset_size(_hidden_size, _hidden_size / pg_size);
  _attn_out_linear.Forward(_batch_tokens, _attn_o_inp_ptr, _attn_ow_ptr,
                           output_ptr, _cublasHandle);

  // allreduce
  if (pg == c10::detail::UniqueVoidPtr() || pg->getSize() == 1) {
  } else {
    auto data_type = torch::kFloat;
    if (typeid(T) != typeid(float)) {
      data_type = torch::kHalf;
    }
    auto output_tensor = torch::from_blob(
        output_ptr, {int(_batch_size), int(_seq_len), int(_hidden_size)},
        torch::TensorOptions(torch::kCUDA).dtype(data_type));
    std::vector<torch::Tensor> allreduce_tensors = {output_tensor};
    auto work = pg->allreduce(allreduce_tensors, c10d::AllreduceOptions());
    work->wait();
  }

  _attn_dropout.bias_dropout_residual(output_ptr, output_ptr, input_ptr,
                                      _attn_ob_ptr, _batch_tokens,
                                      _hidden_size, _stream);
  if (!_pre_or_postLayerNorm) {
    // in-place ln since ln-input will not be used in post-ln mode
    _attn_ln.Forward(output_ptr, output_ptr, _attn_nw_ptr, _attn_nb_ptr,
                     _batch_tokens, _stream);
  }
}
template <typename T>
void MultiHeadAttention<T>::Forward(const T *input_ptr,
                                    const T *input_mask_ptr, T *out_ptr) {
  _stream = Context::Instance().get_stream();
  _cublasHandle = Context::Instance().get_cublashandle();
  T *attn_buffer = _shared_mem_ptr;  // 3 * _batch_dim
  attn_layer_fw(input_ptr, input_mask_ptr, out_ptr, attn_buffer);
}
template <typename T>
void MultiHeadAttention<T>::attn_layer_bw(const T *input_ptr,
                                          const T *input_mask_ptr,
                                          const T *output_ptr,
                                          const T *grad_output_ptr,
                                          T *grad_input_ptr, T *buffer) {
  cudaStream_t streams[2] = {_stream, _stream};

  const T *q_tf_ptr = _qkv_ptr;
  const T *k_tf_ptr = q_tf_ptr + _batch_dim / pg_size;
  const T *v_tf_ptr = k_tf_ptr + _batch_dim / pg_size;
  // batch_dim = batch_size * seq_len * hidden_size
  // buffer size: batch_dim * 3 + max(batch_dim * 3,
  //     batch_size * head_num * seq_len * seq_len)
  T *grad_residual_ptr = buffer;
  buffer += _batch_dim;

  T *grad_input_buf_ptr = buffer;  // batch_dim
  T *grad_qkv_5d_ptr = buffer;     // batch_dim * 3
  buffer += 3 * _batch_dim / pg_size;

  T *grad_qkv_4d_ptr = buffer;   // batch_dim * 3
  T *grad_softmax_ptr = buffer;  // batch_size * head_num * seq_len * seq_len
  // buffer += max(3 * _batch_dim,
  //   batch_size * head_num * seq_len * seq_len);

  if (_pre_or_postLayerNorm) {
    _attn_dropout.d_bias_dropout_residual(grad_input_ptr, _grad_attn_ob_ptr,
                                          grad_output_ptr, _batch_tokens,
                                          _hidden_size, _stream);
  } else {
    _attn_ln.Backward(_grad_attn_nw_ptr, _grad_attn_nb_ptr, grad_residual_ptr,
                      grad_output_ptr, nullptr, output_ptr, _attn_nw_ptr,
                      _attn_nb_ptr, _batch_tokens, streams);
    _attn_dropout.d_bias_dropout_residual(grad_input_ptr, _grad_attn_ob_ptr,
                                          grad_residual_ptr, _batch_tokens,
                                          _hidden_size, _stream);
  }

  // bw of output project
  _attn_out_linear.reset_size(_hidden_size, _hidden_size / pg_size);
  _attn_out_linear.Backward(_batch_tokens, grad_input_ptr, _attn_o_inp_ptr,
                            _attn_ow_ptr, _grad_attn_ow_ptr, _grad_attn_ob_ptr,
                            _cublasHandle, _stream, grad_input_buf_ptr,
                            nullptr, false);
  launch_transform_0213<T>(grad_input_ptr, grad_input_buf_ptr, _batch_size,
                           _seq_len, _hidden_size / pg_size, _heads / pg_size,
                           _stream);

  // bw of score * v
  _attn_context.Backward(
      _batch_heads, grad_input_ptr, v_tf_ptr, _ctx_bufB_ptr, _cublasHandle,
      grad_qkv_5d_ptr + 2 * _batch_dim / pg_size, grad_softmax_ptr);

  _attn_prob_dropout.d_dropout(grad_softmax_ptr,
                               _batch_heads * _seq_len * _seq_len, _stream);

  _softmax.reset_size(_heads / pg_size);
  _softmax.Backward(grad_softmax_ptr, _soft_out_ptr, _batch_size, _seq_len,
                    _seq_len, _stream);

  // bw of q * k
  _attn_scores.Backward(_batch_heads, grad_softmax_ptr, k_tf_ptr, q_tf_ptr,
                        _cublasHandle, grad_qkv_5d_ptr + _batch_dim / pg_size,
                        grad_qkv_5d_ptr);

  // [3, b, nh, s, ad] -> [b, s, 3, h]
  launch_transform4d_0213<T>(grad_qkv_4d_ptr, grad_qkv_5d_ptr, _batch_size,
                             _seq_len, _hidden_size / pg_size,
                             _heads / pg_size, 3, _stream);

  const T *gemmQKV_inp_ptr =
      _pre_or_postLayerNorm ? _gemmQKV_inp_ptr : input_ptr;
  _qkv_linear.reset_size(3 * _hidden_size / pg_size, _hidden_size);
  _qkv_linear.Backward(_batch_tokens, grad_qkv_4d_ptr, gemmQKV_inp_ptr,
                       _attn_qkvw_ptr, _grad_attn_qkvw_ptr,
                       _grad_attn_qkvb_ptr, _cublasHandle, _stream,
                       grad_input_buf_ptr, nullptr, true);

  // allreduce
  if (pg == c10::detail::UniqueVoidPtr() || pg->getSize() == 1) {
  } else {
    auto data_type = torch::kFloat;
    if (typeid(T) != typeid(float)) {
      data_type = torch::kHalf;
    }
    auto grad_input_tensor = torch::from_blob(
        grad_input_buf_ptr,
        {int(_batch_size), int(_seq_len), int(_hidden_size)},
        torch::TensorOptions(torch::kCUDA).dtype(data_type));
    std::vector<torch::Tensor> allreduce_tensors = {grad_input_tensor};
    auto work = pg->allreduce(allreduce_tensors, c10d::AllreduceOptions());
    work->wait();
  }

  if (_pre_or_postLayerNorm) {
    _attn_ln.Backward(_grad_attn_nw_ptr, _grad_attn_nb_ptr, grad_input_ptr,
                      grad_input_buf_ptr, grad_output_ptr, gemmQKV_inp_ptr,
                      _attn_nw_ptr, _attn_nb_ptr, _batch_tokens, streams);
  } else {
    // FIXME later
    launch_fused_add2<T>(grad_input_ptr, grad_input_buf_ptr,
                         grad_residual_ptr, _batch_size, _seq_len,
                         _hidden_size, _stream);
  }
}
template <typename T>
void MultiHeadAttention<T>::Backward(const T *grad_output_ptr,
                                     const T *input_ptr, const T *output_ptr,
                                     const T *input_mask_ptr,
                                     T *grad_input_ptr) {
  _stream = Context::Instance().get_stream();
  _cublasHandle = Context::Instance().get_cublashandle();
  T *buffer = _shared_mem_ptr;

  /*
  buffer size needed by attn bw:
      4 * _batch_dim + max(3 * _batch_dim,
      _batch_size * _head_num * _seq_len * _seq_len);
  */
  attn_layer_bw(input_ptr, input_mask_ptr, output_ptr, grad_output_ptr,
                grad_input_ptr, buffer);
}
template <typename T>
void MultiHeadAttention<T>::SetTrainingMode(bool training) {
  // Dropout will be skipped when not in training mode.
  _attn_prob_dropout.SetTrainingMode(training);
  _attn_dropout.SetTrainingMode(training);
}

template <typename T>
T *MultiHeadAttention<T>::_shared_mem_ptr = nullptr;

template class MultiHeadAttention<float>;
template class MultiHeadAttention<__half>;
// x is torch::Tensor
#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
static std::unordered_map<int, std::shared_ptr<void>> s_multihead_attention;
template <typename T>
int create_multihead_attention(int layer_id, int max_batch_tokens,
                               int max_seq_len, int hidden_dim, int num_heads,
                               float attn_prob_dropout_ratio,
                               float hidden_dropout_ratio,
                               bool pre_or_postLayerNorm,
                               c10::intrusive_ptr<c10d::ProcessGroup> pg_) {
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  Context::Instance().set_stream(stream);
  auto layer = std::make_shared<MultiHeadAttention<T>>(
      layer_id, max_batch_tokens, max_seq_len, hidden_dim, num_heads,
      attn_prob_dropout_ratio, hidden_dropout_ratio, pre_or_postLayerNorm);

  layer->SetPG(pg_);

  s_multihead_attention[layer_id] = layer;
  std::string dtype = (std::is_same<T, __half>::value) ? "half" : "float";

  return 0;
}
template <typename T>
std::vector<torch::Tensor> multihead_attention_fw(
    int layer_id, const torch::Tensor &input, const torch::Tensor &input_mask,
    const torch::Tensor &in_proj_weight, const torch::Tensor &in_proj_bias,
    const torch::Tensor &out_proj_weight, const torch::Tensor &out_proj_bias,
    const torch::Tensor &norm_weight, const torch::Tensor &norm_bias,
    bool training_mode, bool prelayernorm) {
  CHECK_INPUT(input);
  CHECK_INPUT(input_mask);

  const T *input_ptr = (const T *)input.data_ptr();
  const T *input_mask_ptr = (const T *)input_mask.data_ptr();

  auto output = torch::empty_like(input);
  T *out_ptr = (T *)output.data_ptr();

  std::shared_ptr<MultiHeadAttention<T>> layer =
      std::static_pointer_cast<MultiHeadAttention<T>>(
          s_multihead_attention[layer_id]);
  layer->set_cur_batch_shape(input.size(0), input.size(1));
  layer->SetTrainingMode(training_mode);

  layer->_attn_qkvw_ptr = (const T *)in_proj_weight.data_ptr();
  layer->_attn_qkvb_ptr = (const T *)in_proj_bias.data_ptr();
  layer->_attn_ow_ptr = (const T *)out_proj_weight.data_ptr();
  layer->_attn_ob_ptr = (const T *)out_proj_bias.data_ptr();
  layer->_attn_nw_ptr = (const T *)norm_weight.data_ptr();
  layer->_attn_nb_ptr = (const T *)norm_bias.data_ptr();

  layer->Forward(input_ptr, input_mask_ptr, out_ptr);

  return {output};
}
template <typename T>
std::vector<torch::Tensor> multihead_attention_bw(
    int layer_id, const torch::Tensor &grad_dec_output,
    const torch::Tensor &output, const torch::Tensor &input,
    const torch::Tensor &input_mask, const torch::Tensor &in_proj_weight,
    const torch::Tensor &in_proj_bias, const torch::Tensor &out_proj_weight,
    const torch::Tensor &out_proj_bias, const torch::Tensor &norm_weight,
    const torch::Tensor &norm_bias) {
  auto g_output = grad_dec_output.contiguous();
  CHECK_INPUT(g_output);
  CHECK_INPUT(output);
  CHECK_INPUT(input);
  CHECK_INPUT(input_mask);

  auto grad_input = torch::empty_like(input);
  auto grad_in_proj_weight = torch::empty_like(in_proj_weight);
  auto grad_in_proj_bias = torch::empty_like(in_proj_bias);
  auto grad_out_proj_weight = torch::empty_like(out_proj_weight);
  auto grad_out_proj_bias = torch::empty_like(out_proj_bias);
  auto grad_norm_weight = torch::empty_like(norm_weight);
  auto grad_norm_bias = torch::empty_like(norm_bias);

  // inputs.
  const T *grad_dec_output_ptr = (const T *)g_output.data_ptr();
  const T *input_ptr = (const T *)input.data_ptr();
  const T *output_ptr = (const T *)output.data_ptr();
  const T *input_mask_ptr = (const T *)input_mask.data_ptr();

  // outputs.
  T *grad_input_ptr = (T *)grad_input.data_ptr();

  std::shared_ptr<MultiHeadAttention<T>> layer =
      std::static_pointer_cast<MultiHeadAttention<T>>(
          s_multihead_attention[layer_id]);
  layer->set_cur_batch_shape(g_output.size(0), g_output.size(1));

  layer->_grad_attn_qkvw_ptr = (T *)grad_in_proj_weight.data_ptr();
  layer->_grad_attn_qkvb_ptr = (T *)grad_in_proj_bias.data_ptr();
  layer->_grad_attn_ow_ptr = (T *)grad_out_proj_weight.data_ptr();
  layer->_grad_attn_ob_ptr = (T *)grad_out_proj_bias.data_ptr();
  layer->_grad_attn_nw_ptr = (T *)grad_norm_weight.data_ptr();
  layer->_grad_attn_nb_ptr = (T *)grad_norm_bias.data_ptr();

  layer->Backward(grad_dec_output_ptr, input_ptr, output_ptr, input_mask_ptr,
                  grad_input_ptr);

  return {grad_input,           grad_in_proj_weight, grad_in_proj_bias,
          grad_out_proj_weight, grad_out_proj_bias,  grad_norm_weight,
          grad_norm_bias};
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("multihead_attention_fw_fp32", &multihead_attention_fw<float>,
        "Multi-head Attention forward with fp32 (CUDA)");
  m.def("multihead_attention_fw_fp16", &multihead_attention_fw<__half>,
        "Multi-head Attention forward with fp16 (CUDA)");
  m.def("multihead_attention_bw_fp32", &multihead_attention_bw<float>,
        "Multi-head Attention backward with fp32 (CUDA)");
  m.def("multihead_attention_bw_fp16", &multihead_attention_bw<__half>,
        "Multi-head Attention backward with fp16 (CUDA)");
  m.def("create_multihead_attention_fp32", &create_multihead_attention<float>,
        "Create Multi-head Attention with fp32 (CUDA)");
  m.def("create_multihead_attention_fp16",
        &create_multihead_attention<__half>,
        "Create Multi-head Attention with fp16 (CUDA)");
}
colossalai/kernel/cuda_native/csrc/multihead_attention_1d.h
deleted
100644 → 0
View file @
bce9499e
#pragma once
#include <c10/util/intrusive_ptr.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime_api.h>
#include <torch/torch.h>
#if TORCH_VERSION_MAJOR > 1 || \
(TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 13)
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#else
#include <c10d/ProcessGroup.hpp>
#endif
#include <string>
#include <type_traits>
#include "cuda_util.h"
#include "dropout.h"
#include "feed_forward.h"
#include "normalize_layer.h"
#include "softmax.h"
#include "strided_batch_gemm.h"
template <typename T>
class MultiHeadAttention {
 public:
  MultiHeadAttention(int layer_id, int max_batch_tokens, int _max_seq_len,
                     int hidden_size, int num_heads, float attn_dropout_ratio,
                     float hidden_output_dropout_ratio,
                     bool pre_or_postLayerNorm);

  virtual ~MultiHeadAttention();

  void Forward(const T *input_ptr, const T *input_mask_ptr, T *out_ptr);

  void Backward(const T *grad_output_ptr, const T *input_ptr,
                const T *output_ptr, const T *input_mask_ptr,
                T *grad_input_ptr);

  void attn_layer_fw(const T *input_ptr, const T *input_mask_ptr,
                     T *output_ptr, T *buffer);

  void attn_layer_bw(const T *input_ptr, const T *input_mask_ptr,
                     const T *output_ptr, const T *grad_output_ptr,
                     T *grad_input_attn_layer_bwptr, T *buffer);

  void set_cur_batch_shape(int batch_size, int seq_len) {
    _batch_size = batch_size;
    _seq_len = seq_len;
    _batch_tokens = batch_size * seq_len;
    _batch_heads = batch_size * _heads / pg_size;
    _batch_dim = _batch_tokens * _hidden_size;
    _attn_scores.SetConfig(_seq_len, _seq_len, _hidden_size / _heads);
    _attn_context.SetConfig(_hidden_size / _heads, _seq_len, _seq_len);
  }

  void SetTrainingMode(bool training);
  inline bool IsTrainingMode() const { return _training; }

  void SetPG(c10::intrusive_ptr<c10d::ProcessGroup> pg_) {
    pg = pg_;
    pg_size = 1;
    if (pg != c10::detail::UniqueVoidPtr()) {
      pg_size = pg->getSize();
    }
    allocate_mem_buffer();
  }

  // weights ptr
  const T *_attn_qkvw_ptr;
  const T *_attn_qkvb_ptr;
  const T *_attn_ow_ptr;
  const T *_attn_ob_ptr;
  const T *_attn_nw_ptr;
  const T *_attn_nb_ptr;

  // grads ptr
  T *_grad_attn_qkvw_ptr;
  T *_grad_attn_qkvb_ptr;
  T *_grad_attn_ow_ptr;
  T *_grad_attn_ob_ptr;
  T *_grad_attn_nw_ptr;
  T *_grad_attn_nb_ptr;

 private:
  void allocate_mem_buffer() {
    // allocate local gpu memory
    if (_pre_or_postLayerNorm) {
      _gemmQKV_inp_ptr = cuda_malloc<T>(_max_batch_tokens * _hidden_size);
    } else {
      _gemmQKV_inp_ptr = nullptr;
    }

    _qkv_ptr = cuda_malloc<T>(_max_batch_tokens * _hidden_size * 3);
    _soft_out_ptr =
        cuda_malloc<T>(_max_batch_tokens * _heads / pg_size * _max_seq_len);
    _ctx_bufB_ptr =
        cuda_malloc<T>(_max_batch_tokens * _heads / pg_size * _max_seq_len);
    _attn_o_inp_ptr = cuda_malloc<T>(_max_batch_tokens * _hidden_size);

    // buffer size needed by attn bw
    size_t smem_size =
        4 * _max_batch_tokens * _hidden_size / pg_size +
        std::max(3 * _max_batch_tokens * _hidden_size / pg_size,
                 _max_batch_tokens * _heads / pg_size * _max_seq_len);

    if (!_shared_mem_ptr) {
      cuda_free(_shared_mem_ptr);
      _shared_mem_ptr = cuda_malloc<T>(smem_size);
    }
  }

  void free_mem_buffer() {
    // free local gpu memory
    cuda_free(_gemmQKV_inp_ptr);
    cuda_free(_qkv_ptr);
    cuda_free(_soft_out_ptr);
    cuda_free(_ctx_bufB_ptr);
    cuda_free(_attn_o_inp_ptr);

    // free shared gpu memory between layers
    cuda_free(_shared_mem_ptr);
    _shared_mem_ptr = nullptr;
  }

  // const parameter between batch
  const size_t _layer_id;
  const size_t _hidden_size;
  const size_t _heads;
  const size_t _max_batch_tokens;
  const size_t _max_seq_len;
  const bool _pre_or_postLayerNorm;
  // dynamic parameter between batch
  size_t _batch_size;
  size_t _seq_len;
  size_t _batch_tokens;
  size_t _batch_heads;
  size_t _batch_dim;
  bool _training;

  cublasHandle_t _cublasHandle;
  cudaStream_t _stream;

  // layers
  FeedForward<T> _qkv_linear;
  FeedForward<T> _attn_out_linear;
  Normalize_Layer<T> _attn_ln;
  Softmax<T> _softmax;
  Dropout<T> _attn_prob_dropout;
  Dropout<T> _attn_dropout;
  StridedBatchGemm<T> _attn_scores;
  StridedBatchGemm<T> _attn_context;

  // local GPU memory
  T *_gemmQKV_inp_ptr;
  T *_qkv_ptr;
  T *_soft_out_ptr;
  T *_ctx_bufB_ptr;
  T *_attn_o_inp_ptr;
  // shared GPU memory between layer
  static T *_shared_mem_ptr;

  c10::intrusive_ptr<c10d::ProcessGroup> pg;
  int pg_size;
};
colossalai/kernel/cuda_native/csrc/smoothquant/binding.cpp
deleted
100644 → 0
View file @
bce9499e
#include <torch/extension.h>
#include "linear.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("linear_silu_a8_w8_bfp32_ofp32", &linear_silu_a8_w8_bfp32_ofp32,
        "Linear SiLU (INT8)");
}
colossalai/kernel/cuda_native/csrc/smoothquant/linear.cu
deleted
100644 → 0
View file @
bce9499e
// modified from https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/kernels/linear.cu
#include "linear.h"
#include <cutlass/core_io.h>
#include <cutlass/cutlass.h>
#include <cutlass/half.h>
#include <cutlass/gemm/device/gemm.h>
#include <cutlass/numeric_types.h>
#include <cutlass/util/host_tensor.h>
#include <cutlass/epilogue/thread/linear_combination_silu.h>
#include <cstdint>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <iostream>
#include <torch/torch.h>
torch::Tensor linear_silu_a8_w8_bfp32_ofp32(torch::Tensor input,   // INT8
                                            torch::Tensor weight,  // INT8
                                            torch::Tensor bias,    // FP32
                                            float alpha,           // FP32
                                            float beta             // FP32
) {
  auto M = input.size(0);
  auto N = weight.size(0);
  auto K = input.size(1);

  using ElementOutput = float;
  using ElementAccumulator = int32_t;
  using ElementComputeEpilogue = float;
  using ElementInputA = int8_t;  // <- data type of elements in input matrix A
  using ElementInputB = int8_t;  // <- data type of elements in input matrix B

  // The code section below describes matrix layout of input and output
  // matrices. Column Major for Matrix A, Row Major for Matrix B and Row Major
  // for Matrix C
  using LayoutInputA = cutlass::layout::RowMajor;
  using LayoutInputB = cutlass::layout::ColumnMajor;
  using LayoutOutput = cutlass::layout::RowMajor;

#if CUDA_ARCH >= 800
  using EpilogueOp = cutlass::epilogue::thread::LinearCombinationSilu<
      ElementOutput,  // <- data type of output matrix
      128 / cutlass::sizeof_bits<ElementOutput>::value,  // <- this is the
                                                         // number of elements
                                                         // per vectorized
                                                         // memory access. For
                                                         // half precision,
                                                         // it's 8 elements.
                                                         // This becomes the
                                                         // vector width of
                                                         // math instructions
                                                         // in epilogue too
      ElementAccumulator,       // <- data type of accumulator
      ElementComputeEpilogue>;  // <- data type for alpha in linear combination
                                // function

  using Gemm = cutlass::gemm::device::Gemm<
      int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor,
      ElementOutput, cutlass::layout::RowMajor, ElementAccumulator,
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<256, 128, 64>,
      cutlass::gemm::GemmShape<64, 64, 64>,
      cutlass::gemm::GemmShape<16, 8, 32>, EpilogueOp,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>;
#elif CUDA_ARCH >= 750
  using EpilogueOp = cutlass::epilogue::thread::LinearCombinationSilu<
      ElementOutput,  // <- data type of output matrix
      128 / cutlass::sizeof_bits<ElementOutput>::value,  // <- vectorized
                                                         // memory access width
      ElementAccumulator,       // <- data type of accumulator
      ElementComputeEpilogue>;  // <- data type for alpha in linear combination
                                // function

  using DefaultGemmCfg = cutlass::gemm::device::DefaultGemmConfiguration<
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, ElementInputA,
      ElementInputB, ElementOutput, ElementAccumulator>;
  using Gemm = cutlass::gemm::device::Gemm<
      int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor,
      ElementOutput, cutlass::layout::RowMajor, ElementAccumulator,
      cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75,
      DefaultGemmCfg::ThreadblockShape, DefaultGemmCfg::WarpShape,
      DefaultGemmCfg::InstructionShape, EpilogueOp>;
#elif CUDA_ARCH >= 700
  #define USE_TORCH_SILU
  using DefaultGemmCfg = cutlass::gemm::device::DefaultGemmConfiguration<
      cutlass::arch::OpClassSimt, cutlass::arch::Sm70, ElementInputA,
      ElementInputB, ElementOutput, ElementAccumulator>;
  using Gemm = cutlass::gemm::device::Gemm<
      int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor,
      ElementOutput, cutlass::layout::RowMajor, ElementAccumulator,
      cutlass::arch::OpClassSimt, cutlass::arch::Sm70,
      DefaultGemmCfg::ThreadblockShape, DefaultGemmCfg::WarpShape,
      DefaultGemmCfg::InstructionShape,
      cutlass::epilogue::thread::LinearCombination<
          ElementOutput, 1, ElementAccumulator, ElementComputeEpilogue>>;
#else
  #error "Unsupported cuda arch"
#endif

  auto input_size = cutlass::MatrixCoord(M, K);
  auto weight_size = cutlass::MatrixCoord(K, N);
  auto output_size = cutlass::MatrixCoord(M, N);
  auto device = input.device();
  // use the broadcasted bias as the output
  auto out = bias.to(device).view({1, -1}).repeat({M, 1});

  // constexpr int kSparse = Gemm::kSparse;
  // How many elements of A are covered per ElementE
  // constexpr int kElementsPerElementE = Gemm::kElementsPerElementE;
  // The size of individual meta data
  // constexpr int kMetaSizeInBits = Gemm::kMetaSizeInBits;
  cutlass::gemm::GemmCoord problem_size(M, N, K);

  cutlass::TensorRef<ElementInputA, LayoutInputA> input_ref(
      input.data_ptr<ElementInputA>(), LayoutInputA::packed(input_size));
  cutlass::TensorRef<ElementInputB, LayoutInputB> weight_ref(
      weight.data_ptr<ElementInputB>(), LayoutInputB::packed(weight_size));
  cutlass::TensorRef<ElementOutput, LayoutOutput> out_ref(
      out.data_ptr<ElementOutput>(), LayoutOutput::packed(output_size));

  typename Gemm::Arguments arguments{
      problem_size,  // <- problem size of matrix multiplication
      input_ref,     // <- reference to matrix A on device
      weight_ref,    // <- reference to matrix B on device
      out_ref,       // <- reference to matrix C on device
      out_ref,       // <- reference to matrix D on device
      {alpha, beta}, 1};
  Gemm gemm_op;

  // Using the arguments, query for extra workspace required for matrix
  // multiplication computation
  size_t workspace_size = Gemm::get_workspace_size(arguments);

  // Allocate workspace memory
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  // Check the problem size is supported or not
  cutlass::Status status = gemm_op.can_implement(arguments);
  if (status != cutlass::Status::kSuccess) {
    throw std::runtime_error("cutlass cannot implement");
  }

  // Initialize CUTLASS kernel with arguments and workspace pointer
  status = gemm_op.initialize(arguments, workspace.get());
  if (status != cutlass::Status::kSuccess) {
    throw std::runtime_error("cutlass cannot initialize");
  }

  status = gemm_op();
  if (status != cutlass::Status::kSuccess) {
    throw std::runtime_error("cutlass cannot run");
  }
#ifdef USE_TORCH_SILU
#undef USE_TORCH_SILU
  out = torch::silu(out);
#endif
  return out;
}
colossalai/kernel/cuda_native/csrc/smoothquant/linear.h
deleted
100644 → 0
View file @
bce9499e
#include <torch/torch.h>
#include <torch/types.h>
#include <cstdint>
#include <iostream>
torch::Tensor linear_silu_a8_w8_bfp32_ofp32(torch::Tensor input,   // INT8
                                            torch::Tensor weight,  // INT8
                                            torch::Tensor bias,    // FP32
                                            float alpha,           // FP32
                                            float beta             // FP32
);
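Numerically, the fused kernel computes SiLU(alpha * input @ weight.T + beta * bias) with int8 operands accumulated in int32 and an fp32 output. A dequantized PyTorch reference (the helper name is ours; illustrative, not part of the extension):

import torch

def linear_silu_ref(input_i8, weight_i8, bias_f32, alpha, beta):
    # input: [M, K] int8, weight: [N, K] int8, bias: [N] fp32 -> [M, N] fp32
    acc = input_i8.to(torch.int32).float() @ weight_i8.to(torch.int32).float().t()
    return torch.nn.functional.silu(alpha * acc + beta * bias_f32)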
colossalai/kernel/cuda_native/mha/__init__.py
deleted
100644 → 0
View file @
bce9499e
from .mha import ColoAttention

__all__ = ["ColoAttention"]
colossalai/kernel/cuda_native/mha/flash_attn_2.py
deleted
100644 → 0
View file @
bce9499e
import warnings
from typing import Optional

import torch


def is_ampere_or_better_gpu():
    if torch.cuda.is_available():
        device = torch.device("cuda")
        properties = torch.cuda.get_device_properties(device)
        if properties.major >= 8:  # Ampere GPUs or newer
            return True
    return False


# "Check Ampere GPUs or newer"
HAS_FLASH_ATTN = False
if is_ampere_or_better_gpu():
    HAS_FLASH_ATTN = True
else:
    warnings.warn("FlashAttention only supports Ampere GPUs or newer.")
    HAS_FLASH_ATTN = False
try:
    from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func

    HAS_FLASH_ATTN = True
except ImportError:
    warnings.warn("please install flash_attn from https://github.com/HazyResearch/flash-attention")
    HAS_FLASH_ATTN = False

if HAS_FLASH_ATTN:
    from .utils import SeqLenInfo

    def flash_attention(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        seq_len_info_q: SeqLenInfo,
        seq_len_info_kv: SeqLenInfo,
        bias: Optional[torch.Tensor] = None,
        dropout_p: float = 0.0,
        scale: float = None,
        causal: bool = False,
        padded: bool = False,
    ):
        """
        Arguments:
            q: (batch, q_seqlen, nheads, headdim)
            k: (batch, kv_seqlen, nheads, headdim)
            v: (batch, kv_seqlen, nheads, headdim)
            batch_size: int.
            seq_len: int.
            dropout_p: float. Dropout probability.
            sm_scale: float. The scaling of QK^T before applying softmax.
                Default to 1 / sqrt(headdim).
            causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        Return:
            attn_out: (batch, q_seqlen, nheads, headdim).
        """
        if padded:
            if seq_len_info_kv is None:
                seq_len_info_kv = seq_len_info_q

            attn_out = flash_attn_varlen_func(
                q,
                k,
                v,
                seq_len_info_q.cu_seqlens,
                seq_len_info_kv.cu_seqlens,
                seq_len_info_q.max_seqlen,
                seq_len_info_kv.max_seqlen,
                dropout_p,
                scale,
                causal,
            )
        else:
            attn_out = flash_attn_func(q, k, v, dropout_p=dropout_p, softmax_scale=scale, causal=causal)
        return attn_out
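A minimal call sketch for the dense (non-padded) path, assuming flash-attn is installed; the shapes follow the docstring, and the None placeholders are valid here only because the padded branch is skipped:

q = torch.randn(2, 128, 8, 64, dtype=torch.float16, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)
out = flash_attention(q, k, v, seq_len_info_q=None, seq_len_info_kv=None, causal=True)
# out: (2, 128, 8, 64)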
colossalai/kernel/cuda_native/mha/mem_eff_attn.py
deleted
100644 → 0
View file @
bce9499e
import warnings

HAS_MEM_EFF_ATTN = False
try:
    from xformers.ops.fmha import MemoryEfficientAttentionCutlassOp, memory_efficient_attention
    from xformers.ops.fmha.attn_bias import (
        BlockDiagonalCausalMask,
        BlockDiagonalMask,
        LowerTriangularMask,
        LowerTriangularMaskWithTensorBias,
    )

    HAS_MEM_EFF_ATTN = True
except ImportError:
    warnings.warn("please install xformers from https://github.com/facebookresearch/xformers")
    HAS_MEM_EFF_ATTN = False

if HAS_MEM_EFF_ATTN:
    """
    A general attention module using the flash attention kernels from xformers:
    https://github.com/facebookresearch/xformers/tree/main/xformers/ops/fmha
    """
    from typing import Optional

    import torch

    from .utils import SeqLenInfo

    allow_alibi = True
    for op in MemoryEfficientAttentionCutlassOp:
        allow_alibi = allow_alibi & (LowerTriangularMaskWithTensorBias in op.SUPPORTED_ATTN_BIAS_TYPES)

    def mem_eff_attention(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        seq_len_info_q: SeqLenInfo,
        seq_len_info_kv: SeqLenInfo,
        bias: Optional[torch.Tensor] = None,
        dropout_p: float = 0.0,
        scale: float = None,
        causal: bool = False,
        padded: bool = False,
    ):
        attn_bias = None
        if padded:  # bert style
            if not causal:
                attn_bias = BlockDiagonalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens)
            else:
                attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens)
        elif causal:  # gpt style
            attn_bias = LowerTriangularMask()

        if bias is not None:  # alibi / relative position embedding
            assert allow_alibi, "flash attention with bias is not supported in this system."
            assert causal, "attention with bias is only supported for causal attention so far."
            attn_bias = attn_bias.add_bias(bias)

        if padded:
            q = q.unsqueeze(0)
            k = k.unsqueeze(0)
            v = v.unsqueeze(0)

        out = memory_efficient_attention(q, k, v, attn_bias=attn_bias, p=dropout_p, scale=scale)

        # shape: (b*s, n, d)
        if padded:
            out = out.squeeze(0)

        return out
colossalai/kernel/cuda_native/mha/utils.py
deleted
100644 → 0
View file @
bce9499e
from dataclasses import dataclass
from typing import Iterable, Tuple

import torch
import torch.nn.functional as F
from einops import rearrange

from colossalai.utils.device import get_current_device


class Unpad(torch.autograd.Function):
    """
    Adapted from
    https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
    """

    @staticmethod
    def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor):
        ctx.save_for_backward(indices)
        # [b, s, ...]
        assert tensor.ndim >= 3
        ctx.bsz = tensor.shape[0]
        out = rearrange(tensor, "b s ... -> (b s) ...")
        ctx.shape = out.shape
        # [ntokens, ...]
        return out[indices]

    @staticmethod
    def backward(ctx, grad_output):
        (indices,) = ctx.saved_tensors
        # [ntokens, ...]
        grad = torch.zeros(ctx.shape, dtype=grad_output.dtype, device=grad_output.device)
        grad[indices] = grad_output
        grad = rearrange(grad, "(b s) ... -> b s ...", b=ctx.bsz)
        # [b, s, ...]
        return grad, None


class Repad(torch.autograd.Function):
    """
    Adapted from
    https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
    """

    @staticmethod
    def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int):
        ctx.save_for_backward(indices)
        # [ntokens, ...]
        out = torch.zeros((batch_size * seq_len, *tensor.shape[1:]), dtype=tensor.dtype, device=tensor.device)
        # [b*s, ...]
        out[indices] = tensor
        return out

    @staticmethod
    def backward(ctx, grad_output):
        (indices,) = ctx.saved_tensors
        # [b*s, ...]
        grad = grad_output[indices]
        # [ntokens, ...]
        return grad, None, None, None


@dataclass
class SeqLenInfo:
    seqlens: Iterable[int] = None
    indices: torch.Tensor = None
    max_seqlen: int = None
    cu_seqlens: torch.Tensor = None

    @staticmethod
    def materialize(attn_mask: torch.Tensor = None, size: Tuple[int] = None, device=get_current_device()):
        if attn_mask is not None:
            indices = torch.nonzero(attn_mask.flatten(), as_tuple=False).flatten().to(device)
            seqlens = attn_mask.sum(dim=-1, dtype=torch.int32).flatten()
        else:
            batch_size, tgt_len = size[0], size[1]
            indices = torch.arange(batch_size * tgt_len, dtype=torch.long, device=device)
            seqlens = torch.LongTensor([tgt_len] * batch_size, device=device)
        max_seqlen = max(seqlens)
        cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)).to(device)
        return SeqLenInfo(seqlens.tolist(), indices, max_seqlen, cu_seqlens)
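For example, for a batch with 3 and 5 valid tokens, materialize produces the cumulative-length layout the varlen kernels expect (expected values shown as comments; a sketch, not a doctest):

mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1]])
info = SeqLenInfo.materialize(attn_mask=mask)
# info.seqlens    -> [3, 5]
# info.max_seqlen -> tensor(5)
# info.cu_seqlens -> tensor([0, 3, 8], dtype=torch.int32)
# info.indices    -> flat positions of the 8 valid tokens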
colossalai/kernel/cuda_native/multihead_attention.py
deleted
100644 → 0
View file @
bce9499e
import
math
from
dataclasses
import
dataclass
import
torch
from
torch
import
nn
from
torch.autograd
import
Function
def
check_config
(
config
):
if
config
.
hidden_size
%
config
.
nhead
!=
0
:
raise
Exception
(
"hidden_size % nhead != 0"
)
factor
=
8
if
config
.
fp16
else
4
upbound
=
factor
*
1024
*
4
if
config
.
hidden_size
>
upbound
:
# as required by ln backward kernel currently
raise
Exception
(
f
"hidden_size >
{
upbound
}
"
)
head_dim
=
config
.
hidden_size
//
config
.
nhead
if
head_dim
%
factor
!=
0
:
# as required by reshape kernel
raise
Exception
(
f
"head_dim(
{
head_dim
}
) %
{
factor
}
!= 0"
)
def
calc_offset
(
sizes
):
offsets
=
[
0
]
tmp
=
0
for
x
in
sizes
:
tmp
+=
x
offsets
.
append
(
tmp
)
return
offsets
colossal_multihead_attention
=
None
@
dataclass
class
Config
:
max_batch_tokens
:
int
# max batch token numbers
max_seq_len
:
int
# max sequence length
hidden_size
:
int
# size of transformer hidden layers
nhead
:
int
# number of heads in attention
attn_prob_dropout_ratio
:
float
# attention score dropout ratio
hidden_dropout_ratio
:
float
# dropout ration before residual
norm_first
:
bool
# norm_first
fp16
:
bool
# fp16 precision
class
MultiHeadAttention1DFunc
(
Function
):
@
staticmethod
def
forward
(
ctx
,
input
,
input_mask
,
in_proj_weight
,
in_proj_bias
,
out_proj_weight
,
out_proj_bias
,
norm_weight
,
norm_bias
,
config
,
):
cuda_module
=
colossal_multihead_attention
forward_func
=
(
cuda_module
.
multihead_attention_fw_fp16
if
config
.
fp16
else
cuda_module
.
multihead_attention_fw_fp32
)
if
config
.
fp16
:
input
=
input
.
to
(
torch
.
half
)
input_mask
=
input_mask
.
to
(
torch
.
half
)
(
output
,)
=
forward_func
(
config
.
layer_id
,
input
,
input_mask
,
in_proj_weight
,
in_proj_bias
,
out_proj_weight
,
out_proj_bias
,
norm_weight
,
norm_bias
,
config
.
training
,
config
.
norm_first
,
)
if
config
.
is_grad_enabled
and
config
.
training
:
ctx
.
save_for_backward
(
output
,
input
,
input_mask
,
in_proj_weight
,
in_proj_bias
,
out_proj_weight
,
out_proj_bias
,
norm_weight
,
norm_bias
,
)
ctx
.
config
=
config
return
output
@
staticmethod
def
backward
(
ctx
,
grad_output
):
assert
ctx
.
config
.
training
cuda_module
=
colossal_multihead_attention
backward_func
=
(
cuda_module
.
multihead_attention_bw_fp16
if
ctx
.
config
.
fp16
else
cuda_module
.
multihead_attention_bw_fp32
)
(
output
,
input
,
input_mask
,
in_proj_weight
,
in_proj_bias
,
out_proj_weight
,
out_proj_bias
,
norm_weight
,
norm_bias
,
)
=
ctx
.
saved_tensors
grad_input
=
None
grad_in_proj_weight
=
None
grad_in_proj_bias
=
None
grad_out_proj_weight
=
None
grad_out_proj_bias
=
None
grad_norm_weight
=
None
grad_norm_bias
=
None
if
ctx
.
config
.
fp16
:
grad_output
=
grad_output
.
to
(
torch
.
half
)
output
=
output
.
to
(
torch
.
half
)
input
=
input
.
to
(
torch
.
half
)
input_mask
=
input_mask
.
to
(
torch
.
half
)
(
grad_input
,
grad_in_proj_weight
,
grad_in_proj_bias
,
grad_out_proj_weight
,
grad_out_proj_bias
,
grad_norm_weight
,
grad_norm_bias
,
)
=
backward_func
(
ctx
.
config
.
layer_id
,
grad_output
,
output
,
input
,
input_mask
,
in_proj_weight
,
in_proj_bias
,
out_proj_weight
,
out_proj_bias
,
norm_weight
,
norm_bias
,
)
return
(
grad_input
,
None
,
grad_in_proj_weight
,
grad_in_proj_bias
,
grad_out_proj_weight
,
grad_out_proj_bias
,
grad_norm_weight
,
grad_norm_bias
,
None
,
)
class
MultiHeadAttention
(
nn
.
Module
):
"""Initialize the MultiHeadAttention.
Static variable:
layer_id: The layer-index counter starting from 0 and incrementing by 1 every time a layer object is instantiated,
e.g. if a model has 24 transformer layers, layer_id goes from 0 to 23.
Arguments:
hidden_size: Total dimension of hidden_size.
nhead: Number of parallel attention heads.
batch_size: Batch Size for one forward
max_seq_len: Max length of input sequence
dropout: Dropout probability
norm_first: perform LayerNorms before attention
"""
    layer_id = 0

    def __init__(self, hidden_size, nhead, batch_size, max_seq_len, dropout=0.0, norm_first=False, fp16=True, pg=None):
        super(MultiHeadAttention, self).__init__()

        self.config = Config(
            batch_size * max_seq_len, max_seq_len, hidden_size, nhead, dropout, dropout, norm_first, fp16
        )
        check_config(self.config)
        self.pg = pg
        self.pg_size = 1
        if self.pg:
            self.pg_size = pg.size()
        self.config.layer_id = MultiHeadAttention.layer_id
        MultiHeadAttention.layer_id = MultiHeadAttention.layer_id + 1

        # Load cuda modules if needed
        global colossal_multihead_attention
        if colossal_multihead_attention is None:
            from colossalai.kernel.op_builder import MultiHeadAttnBuilder

            multihead_attention = MultiHeadAttnBuilder().load()
            colossal_multihead_attention = multihead_attention

        # create the layer in cuda kernels.
        cuda_module = colossal_multihead_attention
        create_layer_func = (
            cuda_module.create_multihead_attention_fp16
            if self.config.fp16
            else cuda_module.create_multihead_attention_fp32
        )

        create_layer_func(
            self.config.layer_id,
            self.config.max_batch_tokens,
            self.config.max_seq_len,
            self.config.hidden_size,
            self.config.nhead,
            self.config.attn_prob_dropout_ratio,
            self.config.hidden_dropout_ratio,
            self.config.norm_first,
            self.pg,
        )

        hs = self.config.hidden_size

        self.precision = torch.float32
        if self.config.fp16:
            self.precision = torch.half

        self.hs_per_rank = int(hs / self.pg_size)

        self.in_proj_weight = nn.Parameter(torch.Tensor(3, self.hs_per_rank, hs))
        self.in_proj_bias = nn.Parameter(torch.Tensor(3, self.hs_per_rank))
        self.out_proj_weight = nn.Parameter(torch.Tensor(hs, self.hs_per_rank))
        self.out_proj_bias = nn.Parameter(torch.Tensor(hs))
        self.norm_weight = nn.Parameter(torch.Tensor(hs))
        self.norm_bias = nn.Parameter(torch.Tensor(hs))

        self.reset_parameters()
        torch.cuda.empty_cache()
    def calc_bound(self, w):
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(w)
        bound = 1.0 / math.sqrt(fan_in)
        return bound
    def reset_parameters(self):
        hs = self.config.hidden_size

        nn.init.zeros_(self.out_proj_bias)

        nn.init.ones_(self.norm_weight)
        nn.init.zeros_(self.norm_bias)

        if self.pg_size > 1:
            rank_in_pg = torch.distributed.get_rank(self.pg)
            attn_qkvw_global = torch.empty(hs * 3, hs)
            attn_qkvb_global = torch.empty(hs * 3)
            nn.init.xavier_uniform_(attn_qkvw_global, 1.0 / math.sqrt(2.0))
            bound = self.calc_bound(attn_qkvw_global)
            nn.init.uniform_(attn_qkvb_global, -bound, bound)

            attn_qkvw_global = attn_qkvw_global.cuda()
            attn_qkvb_global = attn_qkvb_global.cuda()
            torch.distributed.broadcast(attn_qkvw_global, src=0, group=self.pg)
            torch.distributed.broadcast(attn_qkvb_global, src=0, group=self.pg)
            attn_qkvw_global = attn_qkvw_global.cpu()
            attn_qkvb_global = attn_qkvb_global.cpu()

            with torch.no_grad():
                self.in_proj_weight.copy_(
                    attn_qkvw_global.view(3, hs, hs)[
                        :, int(hs * rank_in_pg / self.pg_size) : int(hs * (rank_in_pg + 1) / self.pg_size), :
                    ]
                )
                self.in_proj_bias.copy_(
                    attn_qkvb_global.view(3, hs)[
                        :, int(hs * rank_in_pg / self.pg_size) : int(hs * (rank_in_pg + 1) / self.pg_size)
                    ]
                )

            attn_ow_global = torch.empty(hs, hs)
            nn.init.xavier_uniform_(attn_ow_global, 1.0)
            attn_ow_global = attn_ow_global.cuda()
            torch.distributed.broadcast(attn_ow_global, src=0, group=self.pg)
            attn_ow_global = attn_ow_global.cpu()
            with torch.no_grad():
                self.out_proj_weight.copy_(
                    attn_ow_global[:, int(hs * rank_in_pg / self.pg_size) : int(hs * (rank_in_pg + 1) / self.pg_size)]
                )

        else:
            attn_qkvw = self.in_proj_weight.view(-1, hs)
            nn.init.xavier_uniform_(attn_qkvw, 1.0 / math.sqrt(2.0))
            bound = self.calc_bound(attn_qkvw)
            nn.init.uniform_(self.in_proj_bias, -bound, bound)

            nn.init.xavier_uniform_(self.out_proj_weight, 1.0)
    def state_dict(self, destination=None, prefix="", keep_vars=False):
        destination = torch.nn.Module.state_dict(self, destination=destination, prefix=prefix, keep_vars=keep_vars)
        return destination
    def forward(self, hidden_states, encoder_padding_mask):
        self.config.training = self.training
        self.config.is_grad_enabled = torch.is_grad_enabled()
        hidden_states = hidden_states.contiguous()
        encoder_padding_mask = (encoder_padding_mask * -1e8).type_as(hidden_states).contiguous()

        bs, sl, dim = hidden_states.size()
        if bs * sl > self.config.max_batch_tokens:
            raise ValueError(f"Batch token numbers {bs * sl} exceeds the limit {self.config.max_batch_tokens}.")
        if sl > self.config.max_seq_len:
            raise ValueError(f"Sequence length {sl} exceeds the limit {self.config.max_seq_len}.")
        if len(encoder_padding_mask.size()) == 1:
            assert bs == 1 and sl == encoder_padding_mask.size(0)
        else:
            assert bs == encoder_padding_mask.size(0) and sl == encoder_padding_mask.size(1)

        output = MultiHeadAttention1DFunc.apply(
            hidden_states,
            encoder_padding_mask,
            self.in_proj_weight,
            self.in_proj_bias,
            self.out_proj_weight,
            self.out_proj_bias,
            self.norm_weight,
            self.norm_bias,
            self.config,
        )

        return output.to(self.precision)
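For orientation, a minimal usage sketch of the MultiHeadAttention module above. The shapes, device, and hyperparameters are illustrative assumptions, not taken from this diff, and the sketch assumes the fused CUDA kernel builds on the current machine:

# Usage sketch (hypothetical shapes; assumes a CUDA device and a buildable kernel).
import torch

attn = MultiHeadAttention(hidden_size=1024, nhead=16, batch_size=8, max_seq_len=256, fp16=True).cuda()
hidden_states = torch.rand(8, 256, 1024, device="cuda")  # (batch, seq_len, hidden)
padding_mask = torch.zeros(8, 256, device="cuda")        # nonzero entries mark padded positions
output = attn(hidden_states, padding_mask)               # same shape, cast to the configured precision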
colossalai/kernel/extensions
0 → 120000
View file @
8823cc48
../../extensions
\ No newline at end of file
colossalai/kernel/jit/option.py
View file @
8823cc48
 import torch

+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.nn.layer.colossalai_layer import Embedding, Linear
-from colossalai.utils import get_current_device

 from .bias_dropout_add import bias_dropout_add_fused_train
 from .bias_gelu import bias_gelu_impl
...
...
@@ -46,11 +46,13 @@ def warmup_jit_fusion(
 ):
     """Compile JIT functions before the main training steps"""

-    embed = Embedding(vocab_size, hidden_size).to(get_current_device())
-    linear_1 = Linear(hidden_size, hidden_size * 4, skip_bias_add=True).to(get_current_device())
-    linear_2 = Linear(hidden_size * 4, hidden_size, skip_bias_add=True).to(get_current_device())
+    embed = Embedding(vocab_size, hidden_size).to(get_accelerator().get_current_device())
+    linear_1 = Linear(hidden_size, hidden_size * 4, skip_bias_add=True).to(get_accelerator().get_current_device())
+    linear_2 = Linear(hidden_size * 4, hidden_size, skip_bias_add=True).to(get_accelerator().get_current_device())

-    x = torch.randint(vocab_size, (batch_size, seq_length), dtype=torch.long, device=get_current_device())
+    x = torch.randint(
+        vocab_size, (batch_size, seq_length), dtype=torch.long, device=get_accelerator().get_current_device()
+    )
     x = embed(x)
     y, y_bias = linear_1(x)
     z, z_bias = linear_2(y)
...
...
@@ -58,8 +60,8 @@ def warmup_jit_fusion(
     # prop and recomputation
     for bias_grad, input_grad in zip([True, True], [False, True]):
         for _ in range(10):
-            bias = torch.rand_like(y_bias, dtype=dtype, device=get_current_device())
-            input_ = torch.rand_like(y, dtype=dtype, device=get_current_device())
+            bias = torch.rand_like(y_bias, dtype=dtype, device=get_accelerator().get_current_device())
+            input_ = torch.rand_like(y, dtype=dtype, device=get_accelerator().get_current_device())
             bias.requires_grad, input_.requires_grad = bias_grad, input_grad
             bias_gelu_impl(input_, bias)
...
...
@@ -69,9 +71,9 @@ def warmup_jit_fusion(
     # prop and recomputation
     for input_grad, bias_grad, residual_grad in zip([False, True], [True, True], [True, True]):
         for _ in range(10):
-            input_ = torch.rand_like(z, dtype=dtype, device=get_current_device())
-            residual = torch.rand_like(x, dtype=dtype, device=get_current_device())
-            bias = torch.rand_like(z_bias, dtype=dtype, device=get_current_device())
+            input_ = torch.rand_like(z, dtype=dtype, device=get_accelerator().get_current_device())
+            residual = torch.rand_like(x, dtype=dtype, device=get_accelerator().get_current_device())
+            bias = torch.rand_like(z_bias, dtype=dtype, device=get_accelerator().get_current_device())
             input_.requires_grad = input_grad
             bias.requires_grad = bias_grad
             residual.requires_grad = residual_grad
...
...
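The three hunks above all make the same substitution. The resulting device-agnostic allocation pattern, shown standalone (the NPU case reflects this PR's intent and is otherwise an assumption):

# Accelerator-agnostic device resolution, as introduced across these hunks:
# get_accelerator() returns the active backend (CUDA, or NPU with this PR),
# so one call site serves every device type.
import torch
from colossalai.accelerator import get_accelerator

device = get_accelerator().get_current_device()
x = torch.empty(4, 4, device=device)  # allocated on whichever accelerator is active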
colossalai/kernel/kernel_loader.py
0 → 100644
View file @
8823cc48
import warnings
from typing import List

from .extensions import (
    CpuAdamArmExtension,
    CpuAdamX86Extension,
    FlashAttentionDaoCudaExtension,
    FlashAttentionNpuExtension,
    FlashAttentionXformersCudaExtension,
    FusedOptimizerCudaExtension,
    LayerNormCudaExtension,
    MoeCudaExtension,
    ScaledMaskedSoftmaxCudaExtension,
    ScaledUpperTriangleMaskedSoftmaxCudaExtension,
)
from .extensions.base_extension import _Extension

__all__ = [
    "KernelLoader",
    "CPUAdamLoader",
    "LayerNormLoader",
    "MoeLoader",
    "FusedOptimizerLoader",
    "ScaledMaskedSoftmaxLoader",
    "ScaledUpperTriangleMaskedSoftmaxLoader",
]


class KernelLoader:
    """
    An abstract class which offers encapsulation to the kernel loading process.

    Usage:
        kernel_loader = KernelLoader()
        kernel = kernel_loader.load()
    """

    REGISTRY: List[_Extension] = []

    @classmethod
    def register_extension(cls, extension: _Extension):
        """
        This classmethod is an extension point which allows users to register their customized
        kernel implementations to the loader.

        Args:
            extension (_Extension): the extension to be registered.
        """
        cls.REGISTRY.append(extension)

    def load(self, ext_name: str = None):
        """
        Load the kernel according to the current machine.

        Args:
            ext_name (str): the name of the extension to be loaded. If not specified, the loader
                will try to look for a kernel available on the current machine.
        """
        exts = [ext_cls() for ext_cls in self.__class__.REGISTRY]

        # look for exts which can be built/loaded on the current machine
        if ext_name:
            usable_exts = list(filter(lambda ext: ext.name == ext_name, exts))
        else:
            usable_exts = []
            for ext in exts:
                if ext.is_hardware_available():
                    # make sure the machine is compatible during kernel loading
                    ext.assert_hardware_compatible()
                    usable_exts.append(ext)

        assert len(usable_exts) != 0, f"No usable kernel found for {self.__class__.__name__} on the current machine."

        if len(usable_exts) > 1:
            # if more than one usable kernel is found, we will try to load the kernel with the highest priority
            usable_exts = sorted(usable_exts, key=lambda ext: ext.priority, reverse=True)
            warnings.warn(
                f"More than one kernel is available, loading the kernel with the highest priority - {usable_exts[0].__class__.__name__}"
            )
        return usable_exts[0].load()


class CPUAdamLoader(KernelLoader):
    REGISTRY = [CpuAdamX86Extension, CpuAdamArmExtension]


class LayerNormLoader(KernelLoader):
    REGISTRY = [LayerNormCudaExtension]


class MoeLoader(KernelLoader):
    REGISTRY = [MoeCudaExtension]


class FusedOptimizerLoader(KernelLoader):
    REGISTRY = [FusedOptimizerCudaExtension]


class ScaledMaskedSoftmaxLoader(KernelLoader):
    REGISTRY = [ScaledMaskedSoftmaxCudaExtension]


class ScaledUpperTriangleMaskedSoftmaxLoader(KernelLoader):
    REGISTRY = [ScaledUpperTriangleMaskedSoftmaxCudaExtension]


class FlashAttentionLoader(KernelLoader):
    REGISTRY = [
        FlashAttentionNpuExtension,
        FlashAttentionDaoCudaExtension,
        FlashAttentionXformersCudaExtension,
    ]
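A sketch of how a downstream user might use the register_extension() hook above. MyCustomExtension and the "my_custom_ext" name are hypothetical; a real implementation would subclass _Extension and provide name, priority, is_hardware_available(), assert_hardware_compatible(), and load():

# Hypothetical registration of a custom kernel with a loader.
FlashAttentionLoader.register_extension(MyCustomExtension)

kernel = FlashAttentionLoader().load()  # picks the highest-priority usable extension
# or select one explicitly by its registered name (name string is an assumption):
kernel = FlashAttentionLoader().load(ext_name="my_custom_ext")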
colossalai/kernel/op_builder
deleted
120000 → 0
View file @
bce9499e
../../op_builder
\ No newline at end of file
colossalai/legacy/amp/naive_amp/_fp16_optimizer.py
View file @
8823cc48
...
...
@@ -7,7 +7,7 @@ from torch.distributed import ProcessGroup
 from torch.optim import Optimizer

 from colossalai.amp.naive_amp.grad_scaler import BaseGradScaler
-from colossalai.kernel.op_builder import FusedOptimBuilder
+from colossalai.kernel.kernel_loader import FusedOptimizerLoader
 from colossalai.legacy.context import ParallelMode
 from colossalai.legacy.core import global_context as gpc
 from colossalai.legacy.utils import clip_grad_norm_fp32, copy_tensor_parallel_attributes
...
...
@@ -28,7 +28,7 @@ def load_fused_optim():
     global fused_optim

     if fused_optim is None:
-        fused_optim = FusedOptimBuilder().load()
+        fused_optim = FusedOptimizerLoader().load()


 def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None):
...
...
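The hunk above keeps the lazy, build-on-first-use pattern; in isolation it looks like this (a sketch mirroring the diff, with the surrounding module context assumed):

# Lazy kernel loading as in load_fused_optim() above: the extension is
# compiled/loaded only on first use, keeping module import cheap.
fused_optim = None

def load_fused_optim():
    global fused_optim
    if fused_optim is None:
        from colossalai.kernel.kernel_loader import FusedOptimizerLoader
        fused_optim = FusedOptimizerLoader().load()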
colossalai/legacy/amp/torch_amp/torch_amp.py
View file @
8823cc48
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-

-from colossalai.utils.device import autocast
 import torch.nn as nn
 from torch import Tensor
 from torch.nn.modules.loss import _Loss
 from torch.optim import Optimizer

+from colossalai.accelerator import get_accelerator
 from colossalai.interface import OptimizerWrapper
 from colossalai.legacy.utils import clip_grad_norm_fp32

 from ._grad_scaler import GradScaler

+autocast = get_accelerator().autocast


 class TorchAMPOptimizer(OptimizerWrapper):
     """A wrapper class which integrates PyTorch AMP with an optimizer
...
...
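Because autocast is now resolved from the accelerator, call sites keep the familiar context-manager form; a minimal sketch (model and batch are placeholders):

# Sketch: the rebound autocast dispatches to the active accelerator's AMP context.
from colossalai.accelerator import get_accelerator

autocast = get_accelerator().autocast

def forward_step(model, batch):
    with autocast():  # mixed-precision region on the current accelerator
        return model(batch)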