"vscode:/vscode.git/clone" did not exist on "0cd4e391ab98a8ab8a735d8b2178b43ce6004ddc"
Unverified commit f0f2a702, authored by RezaYazdaniAminabadi and committed by GitHub

support dynamic sequence length in transformer kernels (#424)


Co-authored-by: Conglong Li <conglong.li@gmail.com>
parent 71f7df39
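
In short, the change lets an already-constructed transformer layer adapt its CUDA kernels to whatever sequence length the incoming batch carries, instead of staying fixed to the length it was built with: the binding compares input.size(1) against the cached length and, on a mismatch, calls the new BertTransformerLayer<T>::SetSeqLength, which re-dimensions the softmax, attention-probability dropout, and strided-batch-GEMM configs and regrows the shared workspace. The standalone C++ sketch below is only a toy model of that control flow (class and member names are illustrative, not the DeepSpeed API); the real logic is in the diff further down.

// Minimal standalone sketch of the control flow added in this commit: the cached
// sequence length is compared with the incoming batch and, on mismatch, the
// sub-layer configurations and workspace are re-sized before the kernels run.
// Names here are illustrative only.
#include <cstdio>

struct ToyTransformerLayer {
    int seq_length = 128;   // length the kernels were last configured for
    int hidden_size = 1024;
    int heads = 16;

    int GetSeqLength() const { return seq_length; }

    // Mirrors BertTransformerLayer<T>::SetSeqLength in the diff: the softmax,
    // the attention-probability dropout, and the two strided-batch GEMMs carry
    // per-sequence-length configuration, and the workspace must be regrown.
    void SetSeqLength(int seq_len, int bsz) {
        seq_length = seq_len;
        printf("reconfigure: softmax/dropout dim=%d, attn gemm %dx%dx%d, workspace for bsz=%d\n",
               seq_len, seq_len, seq_len, hidden_size / heads, bsz);
    }

    void Forward(int bsz, int input_seq_len) {
        if (input_seq_len != GetSeqLength()) {
            // Dynamic sequence length: adapt the kernels before launching them.
            SetSeqLength(input_seq_len, bsz);
        }
        printf("forward: bsz=%d seq=%d\n", bsz, seq_length);
    }
};

int main() {
    ToyTransformerLayer layer;
    layer.Forward(3, 128);  // matches the configured length: no reconfiguration
    layer.Forward(3, 120);  // shorter batch: kernels are re-dimensioned first
    return 0;
}
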
......@@ -29,7 +29,7 @@
for (size_t j = blockIdx.y * blockDim.y + threadIdx.y; j < (m); j += blockDim.y * gridDim.y)
#define DS_CUDA_NUM_THREADS 512
#define DS_MAXIMUM_NUM_BLOCKS 4096
#define DS_MAXIMUM_NUM_BLOCKS 262144
inline int DS_GET_BLOCKS(const int N)
{
......
......@@ -29,7 +29,6 @@ void launch_bias_gelu(const T* input,
T* output,
int intermediate_size,
int batch_size,
int sequence_length,
cudaStream_t stream);
template <typename T>
......@@ -37,7 +36,6 @@ void launch_gelu(const T* input,
T* output,
int intermediate_size,
int batch_size,
int sequence_length,
cudaStream_t stream);
template <typename T>
......@@ -46,7 +44,6 @@ void launch_d_gelu(T* d_output,
const T* bias,
int intermediate_size,
int batch_size,
int sequence_length,
cudaStream_t stream);
// Custom fused bias add with layer normalization
......@@ -57,14 +54,12 @@ void launch_bias_residual_layer_norm(T* vals,
const T* beta,
float epsilon,
int batch_size,
int sequence_length,
int hidden_dim,
cudaStream_t stream,
bool preLayerNorm,
bool training = false,
T* vars = nullptr,
T* means = nullptr,
T* vals_hat = nullptr);
bool training,
T* vars,
T* means);
template <typename T>
void launch_bias_residual_layer_norm(T* vals,
......@@ -73,14 +68,11 @@ void launch_bias_residual_layer_norm(T* vals,
const T* beta,
float epsilon,
int batch_size,
int sequence_length,
int hidden_dim,
cudaStream_t stream,
bool preLayerNorm,
bool training = false,
T* vars = nullptr,
T* vals_hat = nullptr,
bool save_vals = false);
bool training,
T* vars);
template <typename T>
void launch_layerNorm_backward_fused_add(const T* out_grad1,
......@@ -93,7 +85,6 @@ void launch_layerNorm_backward_fused_add(const T* out_grad1,
T* betta_grad,
T* inp_grad,
int batch_size,
int sequence_length,
int hidden_dim,
cudaStream_t stream[2]);
template <typename T>
......@@ -106,7 +97,6 @@ void launch_layerNorm_backward_fused_add(const T* out_grad1,
T* betta_grad,
T* inp_grad,
int batch_size,
int sequence_length,
int hidden_dim,
cudaStream_t stream[2],
bool invertible = false,
......@@ -122,7 +112,6 @@ void launch_layerNorm_backward(const T* out_grad,
T* betta_grad,
T* inp_grad,
int batch_size,
int sequence_length,
int hidden_dim,
cudaStream_t stream[2]);
......@@ -135,7 +124,6 @@ void launch_layerNorm_backward(const T* out_grad,
T* betta_grad,
T* inp_grad,
int batch_size,
int sequence_length,
int hidden_dim,
cudaStream_t stream[2],
bool invertible = false,
......@@ -153,7 +141,6 @@ void launch_layerNorm_backward_nreversible(const T* out_grad,
T* betta_grad,
T* inp_grad,
int batch_size,
int sequence_length,
int hidden_dim,
cudaStream_t stream[2]);
......
......@@ -9,15 +9,13 @@ class Dropout {
public:
struct Config {
float ratio;
uint32_t batch, dim;
uint32_t dim;
bool training;
Config(float r, uint32_t batch, uint32_t dim)
: ratio(r), batch(batch), dim(dim), training(true)
{
}
Config(float r, uint32_t d) : ratio(r), dim(d), training(true) {}
float RATIO() const { return training ? ratio : 0.0; }
inline void SetDim(uint32_t d) { dim = d; }
};
Dropout(const Config& config) : _config(config), _mask(nullptr) {}
......@@ -70,6 +68,8 @@ public:
Config GetConfig() const { return _config; }
inline void SetDimension(uint32_t dim) { _config.SetDim(dim); }
private:
uint8_t* _mask;
Config _config;
......
......@@ -121,11 +121,17 @@ public:
void SetIntermediateBuffers(uint8_t* attn_prob_dropout_mask_ptr,
uint8_t* attn_output_dropout_mask_ptr,
uint8_t* layer_output_dropout_mask_ptr);
uint8_t* layer_output_dropout_mask_ptr,
T* layer_norm_var,
T* layer_norm_mean,
T* attn_layer_norm_var,
T* attn_layer_norm_mean);
inline int GetBatchSize() const { return _batch_size; }
inline int GetNumHeads() const { return _heads; }
inline int GetSeqLength() const { return _seq_length; }
void SetSeqLength(int seq_len, int bsz);
inline int GetHiddenSize() const { return _hidden_size; }
void SetTrainingMode(bool training);
......@@ -150,8 +156,8 @@ private:
// layers
FeedForward<T> _qkv_linear;
FeedForward<T> _attn_out_linear;
Normalize_Layer<T> _norm_layer2;
Normalize_Layer<T> _norm_layer3;
Normalize_Layer<T> _attn_layer_norm;
Normalize_Layer<T> _layer_norm;
Normalize_Layer<T>* _last_normalize;
FeedForward<T> _ff1, _ff2;
Softmax<T> _softmax;
......
......@@ -9,13 +9,8 @@ template <typename T>
class Gelu {
public:
struct Config {
uint32_t batch_size;
uint32_t seq_length;
uint32_t intermediate_size;
Config(uint32_t batch, uint32_t seq, uint32_t inter_size)
: batch_size(batch), seq_length(seq), intermediate_size(inter_size)
{
}
Config(uint32_t inter_size) : intermediate_size(inter_size) {}
};
Gelu(const Config& config) : _config(config) {}
......@@ -28,14 +23,12 @@ public:
T* output,
cudaStream_t stream)
{
launch_bias_gelu<T>(
input_buf, bias, output, _config.intermediate_size, bsz, _config.seq_length, stream);
launch_bias_gelu<T>(input_buf, bias, output, _config.intermediate_size, bsz, stream);
}
void Backward(int bsz, T* d_output, const T* input_buf, const T* bias, cudaStream_t stream)
{
launch_d_gelu<T>(
d_output, input_buf, bias, _config.intermediate_size, bsz, _config.seq_length, stream);
launch_d_gelu<T>(d_output, input_buf, bias, _config.intermediate_size, bsz, stream);
}
private:
......
......@@ -16,57 +16,27 @@ public:
uint32_t seqLength;
uint32_t hiddenDim;
float epsilon;
bool training, save_vals;
bool allocateGrad;
bool training;
bool useMean;
Config(uint32_t batch,
uint32_t seq,
uint32_t h,
bool training,
bool save_vals = true,
bool allocateGrad = true,
bool useMean = true)
Config(uint32_t batch, uint32_t seq, uint32_t h, bool training, bool useMean = true)
: batchSize(batch),
seqLength(seq),
hiddenDim(h),
epsilon(1e-12),
training(training),
save_vals(save_vals),
allocateGrad(allocateGrad),
useMean(useMean)
{
}
};
Normalize_Layer(Config config) : config_(config), vars(nullptr), vals_hat(nullptr)
Normalize_Layer(Config config)
: config_(config), vars(nullptr), means(nullptr), vals_hat(nullptr)
{
if (config_.training) {
cudaMalloc((void**)&vars, config_.batchSize * config_.seqLength * sizeof(T));
if (config_.useMean)
cudaMalloc((void**)&means, config_.batchSize * config_.seqLength * sizeof(T));
if (config_.save_vals)
cudaMalloc((void**)&vals_hat,
config_.batchSize * config_.seqLength * config_.hiddenDim * sizeof(T));
if (config_.allocateGrad)
cudaMalloc((void**)&inp_grad,
config_.batchSize * config_.seqLength * config_.hiddenDim * sizeof(T));
}
}
~Normalize_Layer()
{
if (config_.training) {
cudaFree(vars);
if (config_.useMean) cudaFree(means);
if (config_.save_vals) cudaFree(vals_hat);
if (config_.allocateGrad) cudaFree(inp_grad);
}
}
~Normalize_Layer() {}
void ForwardCheckpoint(int bsz,
void ForwardCheckpoint(int bsz, // batch * seq
T* vals,
const T* residual,
const T* gamma,
......@@ -80,14 +50,12 @@ public:
betta,
config_.epsilon,
bsz,
config_.seqLength,
config_.hiddenDim,
stream,
preLayerNorm,
config_.training,
vars,
means,
vals_hat);
means);
}
void Forward(int bsz,
......@@ -104,14 +72,11 @@ public:
betta,
config_.epsilon,
bsz,
config_.seqLength,
config_.hiddenDim,
stream,
preLayerNorm,
config_.training,
vars,
vals_hat,
config_.save_vals);
vars);
}
void Backward(int bsz,
......@@ -120,7 +85,7 @@ public:
T* gamma_grad,
T* betta_grad,
cudaStream_t stream[2],
T* inp_grad_out = nullptr,
T* inp_grad_out,
const T* norm_in = nullptr)
{
launch_layerNorm_backward(out_grad,
......@@ -130,9 +95,8 @@ public:
gamma,
gamma_grad,
betta_grad,
(config_.allocateGrad ? inp_grad : inp_grad_out),
inp_grad_out,
bsz,
config_.seqLength,
config_.hiddenDim,
stream);
}
......@@ -144,21 +108,20 @@ public:
T* gamma_grad,
T* betta_grad,
cudaStream_t stream[2],
T* inp_grad_out = nullptr,
const T* norm_out = nullptr)
T* inp_grad_out,
const T* norm_out)
{
launch_layerNorm_backward(out_grad,
(config_.save_vals ? vals_hat : norm_out),
norm_out,
vars,
gamma,
gamma_grad,
betta_grad,
(config_.allocateGrad ? inp_grad : inp_grad_out),
inp_grad_out,
bsz,
config_.seqLength,
config_.hiddenDim,
stream,
config_.save_vals,
!config_.useMean,
betta);
}
......@@ -169,7 +132,7 @@ public:
T* gamma_grad,
T* betta_grad,
cudaStream_t stream[2],
T* inp_grad_out = nullptr,
T* inp_grad_out,
const T* norm_in = nullptr)
{
launch_layerNorm_backward_fused_add(out_grad1,
......@@ -180,9 +143,8 @@ public:
gamma,
gamma_grad,
betta_grad,
(config_.allocateGrad ? inp_grad : inp_grad_out),
inp_grad_out,
bsz,
config_.seqLength,
config_.hiddenDim,
stream);
}
......@@ -195,33 +157,41 @@ public:
T* gamma_grad,
T* betta_grad,
cudaStream_t stream[2],
T* inp_grad_out = nullptr,
const T* norm_out = nullptr)
T* inp_grad_out,
const T* norm_out)
{
launch_layerNorm_backward_fused_add(out_grad1,
out_grad2,
(config_.save_vals ? vals_hat : norm_out),
norm_out,
vars,
gamma,
gamma_grad,
betta_grad,
(config_.allocateGrad ? inp_grad : inp_grad_out),
inp_grad_out,
bsz,
config_.seqLength,
config_.hiddenDim,
stream,
config_.save_vals,
!config_.useMean,
betta);
}
inline T* GetInputGrad() const { return inp_grad; }
inline bool UseMean() const { return config_.useMean; }
inline void SetVar(T* variance)
{
if (!variance) { throw std::runtime_error("Normalize variance is null."); }
vars = variance;
}
inline void SetMean(T* mean)
{
if (!mean) { throw std::runtime_error("Normalize mean is null."); }
means = mean;
}
private:
Config config_;
T* vars;
T* means;
T* vals_hat;
T* inp_grad;
};
......@@ -45,13 +45,15 @@ public:
out_grad, soft_out, bsz, config_.heads, config_.seq_length, stream);
}
inline int GetProbDepth() const { return config_.prob_depth; }
inline size_t GetProbDepth() const { return config_.prob_depth; }
inline int GetBatchSize() const { return config_.batchSize; }
inline size_t GetBatchSize() const { return config_.batchSize; }
inline int GetNumHeads() const { return config_.heads; }
inline size_t GetNumHeads() const { return config_.heads; }
inline int GetSeqLength() const { return config_.seq_length; }
inline size_t GetSeqLength() const { return config_.seq_length; }
inline void SetSeqLength(size_t seq_len) { config_.seq_length = seq_len; }
private:
Config config_;
......
......@@ -3,6 +3,7 @@
#include <cuda.h>
#include <cuda_fp16.h>
#include <stdio.h>
#include "context.h"
template <typename T>
class StridedBatchGemm {
......@@ -38,6 +39,12 @@ public:
gemm_algos(algos)
{
}
void SetConfig(int mm, int nn, int kk)
{
m = mm;
n = nn;
k = kk;
}
};
StridedBatchGemm(const Config& config) : _config(config) {}
......@@ -163,6 +170,8 @@ public:
inline const T* GetBufferB() const { return q_buf; }
inline void SetConfig(int m, int n, int k) { _config.SetConfig(m, n, k); }
private:
Config _config;
const T* q_buf;
......
......@@ -34,7 +34,12 @@ int cublas_gemm_ex(cublasHandle_t handle,
algo);
if (status != CUBLAS_STATUS_SUCCESS) {
fprintf(stderr, "!!!! kernel execution error.\n");
fprintf(stderr,
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
m,
n,
k,
(int)status);
return EXIT_FAILURE;
}
return 0;
......@@ -74,7 +79,12 @@ int cublas_gemm_ex(cublasHandle_t handle,
algo);
if (status != CUBLAS_STATUS_SUCCESS) {
fprintf(stderr, "!!!! kernel execution error.\n");
fprintf(stderr,
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
m,
n,
k,
(int)status);
return EXIT_FAILURE;
}
return 0;
......@@ -122,7 +132,12 @@ int cublas_strided_batched_gemm(cublasHandle_t handle,
algo);
if (status != CUBLAS_STATUS_SUCCESS) {
fprintf(stderr, "!!!! kernel execution error.\n");
fprintf(stderr,
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
m,
n,
k,
(int)status);
return EXIT_FAILURE;
}
return 0;
......@@ -170,7 +185,12 @@ int cublas_strided_batched_gemm(cublasHandle_t handle,
algo);
if (status != CUBLAS_STATUS_SUCCESS) {
fprintf(stderr, "!!!! kernel execution error.\n");
fprintf(stderr,
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
m,
n,
k,
(int)status);
return EXIT_FAILURE;
}
......
......@@ -78,20 +78,16 @@ BertTransformerLayer<T>::BertTransformerLayer(int layer_id,
hidden_size,
hidden_size,
gemm_algos[0])),
_norm_layer2(typename Normalize_Layer<T>::Config(batch_size,
seq_length,
hidden_size,
true,
false,
false,
!normalize_invertible)),
_norm_layer3(typename Normalize_Layer<T>::Config(batch_size,
seq_length,
hidden_size,
true,
false,
false,
!normalize_invertible)),
_attn_layer_norm(typename Normalize_Layer<T>::Config(batch_size,
seq_length,
hidden_size,
true,
!normalize_invertible)),
_layer_norm(typename Normalize_Layer<T>::Config(batch_size,
seq_length,
hidden_size,
true,
!normalize_invertible)),
_ff1(typename FeedForward<T>::Config(batch_size * seq_length,
_intermediate_size,
hidden_size,
......@@ -101,16 +97,10 @@ BertTransformerLayer<T>::BertTransformerLayer(int layer_id,
_intermediate_size,
gemm_algos[2])),
_softmax(typename Softmax<T>::Config(batch_size, num_heads, seq_length)),
_gelu(typename Gelu<T>::Config(_batch_size, _seq_length, _intermediate_size)),
_attn_prob_dropout(typename Dropout<T>::Config(attn_prob_dropout_ratio,
_batch_size * _heads * _seq_length,
_seq_length)),
_attn_output_dropout(typename Dropout<T>::Config(hidden_output_dropout_ratio,
_batch_size * _seq_length,
_hidden_size)),
_layer_output_dropout(typename Dropout<T>::Config(hidden_output_dropout_ratio,
_batch_size * _seq_length,
_hidden_size)),
_gelu(typename Gelu<T>::Config(_intermediate_size)),
_attn_prob_dropout(typename Dropout<T>::Config(attn_prob_dropout_ratio, _seq_length)),
_attn_output_dropout(typename Dropout<T>::Config(hidden_output_dropout_ratio, _hidden_size)),
_layer_output_dropout(typename Dropout<T>::Config(hidden_output_dropout_ratio, _hidden_size)),
_attn_scores(typename StridedBatchGemm<T>::Config(_batch_size * _heads,
_seq_length,
_seq_length,
......@@ -196,18 +186,18 @@ void BertTransformerLayer<T>::Forward(int bsz,
if (_normalize_invertible) add_res_ptr = buf_1 + 3 * small_buf_size;
if (_attn_dropout_checkpoint) ctx_bufB_ptr = buf_1 + 4 * small_buf_size;
int bsz_seq = bsz * _seq_length;
if (_pre_or_postLayerNorm) {
if (_norm_layer3.UseMean())
_norm_layer3.ForwardCheckpoint(
bsz, inp_norm_ptr, input_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
if (_layer_norm.UseMean())
_layer_norm.ForwardCheckpoint(
bsz_seq, inp_norm_ptr, input_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
else
_norm_layer3.Forward(
bsz, inp_norm_ptr, input_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
_layer_norm.Forward(
bsz_seq, inp_norm_ptr, input_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
}
int bsz_seq = bsz * _seq_length;
if (_pre_or_postLayerNorm)
_qkv_linear.Forward(bsz_seq, inp_norm_ptr, attn_qkvw_ptr, buf_0, _cublasHandle);
else
......@@ -247,19 +237,19 @@ void BertTransformerLayer<T>::Forward(int bsz,
bsz_seq, add_res_ptr, ff1_inp_ptr, input_ptr, attn_ob_ptr, _stream);
if (_pre_or_postLayerNorm) {
if (_norm_layer2.UseMean())
_norm_layer2.ForwardCheckpoint(
bsz, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
if (_attn_layer_norm.UseMean())
_attn_layer_norm.ForwardCheckpoint(
bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
else
_norm_layer2.Forward(
bsz, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
_attn_layer_norm.Forward(
bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
} else {
if (_norm_layer2.UseMean())
_norm_layer2.ForwardCheckpoint(
bsz, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
if (_attn_layer_norm.UseMean())
_attn_layer_norm.ForwardCheckpoint(
bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
else
_norm_layer2.Forward(
bsz, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
_attn_layer_norm.Forward(
bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
}
_ff1.Forward(bsz_seq,
......@@ -268,7 +258,7 @@ void BertTransformerLayer<T>::Forward(int bsz,
(_gelu_checkpoint ? ff2_inp_ptr : gelu_inp_ptr),
_cublasHandle);
_gelu.ForwardWithBiasAdd(bsz,
_gelu.ForwardWithBiasAdd(bsz_seq,
(_gelu_checkpoint ? ff2_inp_ptr : gelu_inp_ptr),
inter_b_ptr,
(_gelu_checkpoint ? ctx_bufB_ptr : ff2_inp_ptr),
......@@ -289,11 +279,12 @@ void BertTransformerLayer<T>::Forward(int bsz,
bsz_seq, inp_norm_ptr, out_ptr, ff1_inp_ptr, output_b_ptr, _stream);
if (!_pre_or_postLayerNorm) {
if (_norm_layer3.UseMean())
_norm_layer3.ForwardCheckpoint(
bsz, out_ptr, inp_norm_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
if (_layer_norm.UseMean())
_layer_norm.ForwardCheckpoint(
bsz_seq, out_ptr, inp_norm_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
else
_norm_layer3.Forward(bsz, out_ptr, inp_norm_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
_layer_norm.Forward(
bsz_seq, out_ptr, inp_norm_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
}
}
......@@ -359,26 +350,26 @@ void BertTransformerLayer<T>::Backward(int bsz,
int bsz_heads = bsz * _heads;
if (!_pre_or_postLayerNorm) {
if (_norm_layer3.UseMean())
_norm_layer3.Backward(bsz,
grad_output_ptr,
norm_w_ptr,
grad_norm_w_ptr,
grad_norm_b_ptr,
streams,
buf_1,
inp_norm_ptr);
if (_layer_norm.UseMean())
_layer_norm.Backward(bsz_seq,
grad_output_ptr,
norm_w_ptr,
grad_norm_w_ptr,
grad_norm_b_ptr,
streams,
buf_1,
inp_norm_ptr);
else
_norm_layer3.Backward(bsz,
grad_output_ptr,
norm_w_ptr,
norm_b_ptr,
grad_norm_w_ptr,
grad_norm_b_ptr,
streams,
buf_1,
output_ptr);
_layer_norm.Backward(bsz_seq,
grad_output_ptr,
norm_w_ptr,
norm_b_ptr,
grad_norm_w_ptr,
grad_norm_b_ptr,
streams,
buf_1,
output_ptr);
}
if (_pre_or_postLayerNorm)
......@@ -390,7 +381,8 @@ void BertTransformerLayer<T>::Backward(int bsz,
? buf_0
: (_pre_or_postLayerNorm ? grad_output_ptr : buf_1);
if (_gelu_checkpoint) _gelu.ForwardWithBiasAdd(bsz, ff2_inp_ptr, inter_b_ptr, buf_2, _stream);
if (_gelu_checkpoint)
_gelu.ForwardWithBiasAdd(bsz_seq, ff2_inp_ptr, inter_b_ptr, buf_2, _stream);
_ff2.Backward(bsz_seq,
layer_dropout_buf,
(_gelu_checkpoint ? buf_2 : ff2_inp_ptr),
......@@ -402,7 +394,7 @@ void BertTransformerLayer<T>::Backward(int bsz,
ff2_buf);
_gelu.Backward(
bsz, ff2_buf, (_gelu_checkpoint ? ff2_inp_ptr : gelu_inp_ptr), inter_b_ptr, _stream);
bsz_seq, ff2_buf, (_gelu_checkpoint ? ff2_inp_ptr : gelu_inp_ptr), inter_b_ptr, _stream);
_ff1.Backward(bsz_seq,
ff2_buf,
......@@ -418,49 +410,49 @@ void BertTransformerLayer<T>::Backward(int bsz,
launch_fused_add2<T>(buf_2, buf_3, buf_1, bsz, _seq_length, _hidden_size, _stream);
if (_pre_or_postLayerNorm) {
if (_norm_layer2.UseMean())
_norm_layer2.BackwardFusedAdd(bsz,
buf_3,
grad_output_ptr,
attn_nw_ptr,
grad_attn_nw_ptr,
grad_attn_nb_ptr,
streams,
buf_0,
add_res_ptr);
if (_attn_layer_norm.UseMean())
_attn_layer_norm.BackwardFusedAdd(bsz_seq,
buf_3,
grad_output_ptr,
attn_nw_ptr,
grad_attn_nw_ptr,
grad_attn_nb_ptr,
streams,
buf_0,
add_res_ptr);
else
_norm_layer2.BackwardFusedAdd(bsz,
buf_3,
grad_output_ptr,
attn_nw_ptr,
attn_nb_ptr,
grad_attn_nw_ptr,
grad_attn_nb_ptr,
streams,
buf_0,
ff1_inp_ptr);
_attn_layer_norm.BackwardFusedAdd(bsz_seq,
buf_3,
grad_output_ptr,
attn_nw_ptr,
attn_nb_ptr,
grad_attn_nw_ptr,
grad_attn_nb_ptr,
streams,
buf_0,
ff1_inp_ptr);
} else {
if (_norm_layer2.UseMean())
_norm_layer2.Backward(bsz,
buf_2,
attn_nw_ptr,
grad_attn_nw_ptr,
grad_attn_nb_ptr,
streams,
buf_0,
add_res_ptr);
if (_attn_layer_norm.UseMean())
_attn_layer_norm.Backward(bsz_seq,
buf_2,
attn_nw_ptr,
grad_attn_nw_ptr,
grad_attn_nb_ptr,
streams,
buf_0,
add_res_ptr);
else
_norm_layer2.Backward(bsz,
buf_2,
attn_nw_ptr,
attn_nb_ptr,
grad_attn_nw_ptr,
grad_attn_nb_ptr,
streams,
buf_0,
ff1_inp_ptr);
_attn_layer_norm.Backward(bsz_seq,
buf_2,
attn_nw_ptr,
attn_nb_ptr,
grad_attn_nw_ptr,
grad_attn_nb_ptr,
streams,
buf_0,
ff1_inp_ptr);
}
_attn_output_dropout.Backward(bsz_seq, buf_2, buf_0, _stream);
......@@ -525,28 +517,28 @@ void BertTransformerLayer<T>::Backward(int bsz,
buf_2);
if (_pre_or_postLayerNorm) {
if (_norm_layer3.UseMean())
_norm_layer3.BackwardFusedAdd(bsz,
buf_2,
buf_0,
norm_w_ptr,
grad_norm_w_ptr,
grad_norm_b_ptr,
streams,
grad_input_ptr,
input_ptr);
if (_layer_norm.UseMean())
_layer_norm.BackwardFusedAdd(bsz_seq,
buf_2,
buf_0,
norm_w_ptr,
grad_norm_w_ptr,
grad_norm_b_ptr,
streams,
grad_input_ptr,
input_ptr);
else
_norm_layer3.BackwardFusedAdd(bsz,
buf_2,
buf_0,
norm_w_ptr,
norm_b_ptr,
grad_norm_w_ptr,
grad_norm_b_ptr,
streams,
grad_input_ptr,
inp_norm_ptr);
_layer_norm.BackwardFusedAdd(bsz_seq,
buf_2,
buf_0,
norm_w_ptr,
norm_b_ptr,
grad_norm_w_ptr,
grad_norm_b_ptr,
streams,
grad_input_ptr,
inp_norm_ptr);
} else
launch_fused_add2<T>(grad_input_ptr, buf_2, buf_0, bsz, _seq_length, _hidden_size, _stream);
}
......@@ -563,11 +555,34 @@ void BertTransformerLayer<T>::SetTrainingMode(bool training)
template <typename T>
void BertTransformerLayer<T>::SetIntermediateBuffers(uint8_t* attn_prob_dropout_mask_ptr,
uint8_t* attn_output_dropout_mask_ptr,
uint8_t* layer_output_dropout_mask_ptr)
uint8_t* layer_output_dropout_mask_ptr,
T* attn_layer_norm_var,
T* attn_layer_norm_mean,
T* layer_norm_var,
T* layer_norm_mean)
{
_attn_prob_dropout.SetMask(attn_prob_dropout_mask_ptr);
_attn_output_dropout.SetMask(attn_output_dropout_mask_ptr);
_layer_output_dropout.SetMask(layer_output_dropout_mask_ptr);
_attn_layer_norm.SetVar(attn_layer_norm_var);
_attn_layer_norm.SetMean(attn_layer_norm_mean);
_layer_norm.SetVar(layer_norm_var);
_layer_norm.SetMean(layer_norm_mean);
}
template <typename T>
void BertTransformerLayer<T>::SetSeqLength(int seq_len, int bsz)
{
_seq_length = seq_len;
_softmax.SetSeqLength(_seq_length);
_attn_prob_dropout.SetDimension(_seq_length);
_attn_scores.SetConfig(_seq_length, _seq_length, _hidden_size / _heads);
_attn_context.SetConfig(_hidden_size / _heads, _seq_length, _seq_length);
Context::Instance().GenWorkSpace(get_workspace_size<T>(
bsz, _seq_length, _hidden_size, _intermediate_size, _heads, _training, _gelu_checkpoint));
}
template <typename T>
......@@ -688,54 +703,61 @@ std::vector<torch::Tensor> ds_transformer_forward(int layer_id,
std::shared_ptr<BertTransformerLayer<T>> layer =
std::static_pointer_cast<BertTransformerLayer<T>>(s_transformer_layers[layer_id]);
int seq_len = layer->GetSeqLength();
if (input.size(1) != seq_len) {
seq_len = input.size(1);
layer->SetSeqLength(seq_len, bsz);
}
auto inp_norm = ((prelayernorm || !normalize_invertible) ? torch::empty_like(input) : output);
auto add_res = (normalize_invertible ? inp_norm : torch::empty_like(input));
auto attn_o_inp = torch::empty_like(input);
auto qkv_tf = torch::empty({(bsz * layer->GetSeqLength()), output_w.size(0) * 3}, options);
auto qkv_tf = torch::empty({(bsz * seq_len), output_w.size(0) * 3}, options);
auto attn_prob_dropout_mask =
torch::empty({(bsz * layer->GetNumHeads() * layer->GetSeqLength()), layer->GetSeqLength()},
uint8_options);
torch::empty({(bsz * layer->GetNumHeads() * seq_len), seq_len}, uint8_options);
auto attn_output_dropout_mask =
torch::empty({(bsz * layer->GetSeqLength()), layer->GetHiddenSize()}, uint8_options);
torch::empty({(bsz * seq_len), layer->GetHiddenSize()}, uint8_options);
auto layer_output_dropout_mask =
torch::empty({(bsz * layer->GetSeqLength()), layer->GetHiddenSize()}, uint8_options);
torch::empty({(bsz * seq_len), layer->GetHiddenSize()}, uint8_options);
auto attn_layer_norm_var = torch::empty({(bsz * seq_len)}, options);
auto attn_layer_norm_mean = torch::empty({(bsz * seq_len)}, options);
auto layer_norm_var = torch::empty({(bsz * seq_len)}, options);
auto layer_norm_mean = torch::empty({(bsz * seq_len)}, options);
T* inp_norm_ptr = (T*)inp_norm.data_ptr();
T* add_res_ptr = (T*)add_res.data_ptr();
T* q_tf_ptr = (T*)qkv_tf.data_ptr();
T* k_tf_ptr =
q_tf_ptr + (bsz * layer->GetSeqLength() * output_w.size(0)); //(T*)k_tf.data_ptr();
T* v_tf_ptr =
k_tf_ptr + (bsz * layer->GetSeqLength() * output_w.size(0)); //(T*)v_tf.data_ptr();
T* k_tf_ptr = q_tf_ptr + (bsz * seq_len * output_w.size(0)); //(T*)k_tf.data_ptr();
T* v_tf_ptr = k_tf_ptr + (bsz * seq_len * output_w.size(0)); //(T*)v_tf.data_ptr();
T* attn_o_inp_ptr = (T*)attn_o_inp.data_ptr();
torch::Tensor ff2_inp =
torch::empty({(bsz * layer->GetSeqLength()), output_w.size(1)}, options);
torch::Tensor ff2_inp = torch::empty({(bsz * seq_len), output_w.size(1)}, options);
torch::Tensor gelu_inp =
(gelu_checkpoint
? ff2_inp
: torch::empty({(bsz * layer->GetSeqLength()), output_w.size(1)}, options));
(gelu_checkpoint ? ff2_inp : torch::empty({(bsz * seq_len), output_w.size(1)}, options));
auto ff1_inp = torch::empty_like(input);
T* ff2_inp_ptr = (T*)ff2_inp.data_ptr();
T* gelu_inp_ptr = (T*)gelu_inp.data_ptr();
T* ff1_inp_ptr = (T*)ff1_inp.data_ptr();
torch::Tensor soft_out = torch::empty(
{(bsz * layer->GetNumHeads() * layer->GetSeqLength()), layer->GetSeqLength()}, options);
torch::Tensor soft_out =
torch::empty({(bsz * layer->GetNumHeads() * seq_len), seq_len}, options);
torch::Tensor ctx_bufB =
(attn_dropout_checkpoint
? soft_out
: torch::empty(
{(bsz * layer->GetNumHeads() * layer->GetSeqLength()), layer->GetSeqLength()},
options));
: torch::empty({(bsz * layer->GetNumHeads() * seq_len), seq_len}, options));
T* soft_out_ptr = (T*)soft_out.data_ptr();
T* ctx_bufB_ptr = (T*)ctx_bufB.data_ptr();
layer->SetTrainingMode(training_mode);
layer->SetIntermediateBuffers((uint8_t*)attn_prob_dropout_mask.data_ptr(),
(uint8_t*)attn_output_dropout_mask.data_ptr(),
(uint8_t*)layer_output_dropout_mask.data_ptr());
(uint8_t*)layer_output_dropout_mask.data_ptr(),
(T*)attn_layer_norm_var.data_ptr(),
(T*)attn_layer_norm_mean.data_ptr(),
(T*)layer_norm_var.data_ptr(),
(T*)layer_norm_mean.data_ptr());
layer->Forward(bsz,
input_ptr,
......@@ -777,7 +799,11 @@ std::vector<torch::Tensor> ds_transformer_forward(int layer_id,
ff2_inp,
attn_prob_dropout_mask,
attn_output_dropout_mask,
layer_output_dropout_mask};
layer_output_dropout_mask,
attn_layer_norm_var,
attn_layer_norm_mean,
layer_norm_var,
layer_norm_mean};
}
template <typename T>
......@@ -796,6 +822,10 @@ std::vector<torch::Tensor> ds_transformer_backward(int layer_id,
const torch::Tensor& attn_prob_dropout_mask,
const torch::Tensor& attn_output_dropout_mask,
const torch::Tensor& layer_output_dropout_mask,
const torch::Tensor& attn_layer_norm_var,
const torch::Tensor& attn_layer_norm_mean,
const torch::Tensor& layer_norm_var,
const torch::Tensor& layer_norm_mean,
const torch::Tensor& input,
const torch::Tensor& input_mask,
const torch::Tensor& attn_qkvw,
......@@ -839,6 +869,7 @@ std::vector<torch::Tensor> ds_transformer_backward(int layer_id,
CHECK_INPUT(norm_b);
int bsz = g_output.size(0);
std::shared_ptr<BertTransformerLayer<T>> layer =
std::static_pointer_cast<BertTransformerLayer<T>>(s_transformer_layers[layer_id]);
......@@ -901,7 +932,11 @@ std::vector<torch::Tensor> ds_transformer_backward(int layer_id,
layer->SetIntermediateBuffers((uint8_t*)attn_prob_dropout_mask.data_ptr(),
(uint8_t*)attn_output_dropout_mask.data_ptr(),
(uint8_t*)layer_output_dropout_mask.data_ptr());
(uint8_t*)layer_output_dropout_mask.data_ptr(),
(T*)attn_layer_norm_var.data_ptr(),
(T*)attn_layer_norm_mean.data_ptr(),
(T*)layer_norm_var.data_ptr(),
(T*)layer_norm_mean.data_ptr());
layer->Backward(bsz,
grad_output_ptr,
......
......@@ -279,13 +279,12 @@ void launch_bias_gelu(const T* input,
T* output,
int intermediate_size,
int batch_size,
int sequence_length,
cudaStream_t stream)
{
int iterations = (intermediate_size + 1023) / 1024;
int threads = intermediate_size / iterations / 4;
dim3 block_dims(threads);
dim3 grid_dims(sequence_length * batch_size);
dim3 grid_dims(batch_size);
fused_bias_gelu<<<grid_dims, block_dims, 0, stream>>>(input, bias, output, intermediate_size);
}
......@@ -295,24 +294,26 @@ void launch_gelu(const T* input,
T* output,
int intermediate_size,
int batch_size,
int sequence_length,
cudaStream_t stream)
{
int iterations = (intermediate_size + 1023) / 1024;
int threads = intermediate_size / iterations / 4;
dim3 block_dims(threads);
dim3 grid_dims(sequence_length * batch_size);
dim3 grid_dims(batch_size);
gelu_kernel<<<grid_dims, block_dims, 0, stream>>>(input, output, intermediate_size);
}
template void
launch_bias_gelu<float>(const float*, const float*, float*, int, int, int, cudaStream_t);
template void
launch_bias_gelu<__half>(const __half*, const __half*, __half*, int, int, int, cudaStream_t);
template void launch_bias_gelu<float>(const float*, const float*, float*, int, int, cudaStream_t);
template void launch_bias_gelu<__half>(const __half*,
const __half*,
__half*,
int,
int,
cudaStream_t);
template void launch_gelu<float>(const float*, float*, int, int, int, cudaStream_t);
template void launch_gelu<__half>(const __half*, __half*, int, int, int, cudaStream_t);
template void launch_gelu<float>(const float*, float*, int, int, cudaStream_t);
template void launch_gelu<__half>(const __half*, __half*, int, int, cudaStream_t);
template <typename T>
void launch_d_gelu(T* d_output,
......@@ -320,17 +321,15 @@ void launch_d_gelu(T* d_output,
const T* bias,
int intermediate_size,
int batch_size,
int sequence_length,
cudaStream_t stream)
{
int iterations = (intermediate_size + 1023) / 1024;
int threads = intermediate_size / iterations / 4;
dim3 block_dims(threads);
dim3 grid_dims(sequence_length * batch_size);
dim3 grid_dims(batch_size);
d_gelu_func<<<grid_dims, block_dims, 0, stream>>>(d_output, input, bias, intermediate_size);
}
template void launch_d_gelu<float>(float*, const float*, const float*, int, int, int, cudaStream_t);
template void
launch_d_gelu<__half>(__half*, const __half*, const __half*, int, int, int, cudaStream_t);
template void launch_d_gelu<float>(float*, const float*, const float*, int, int, cudaStream_t);
template void launch_d_gelu<__half>(__half*, const __half*, const __half*, int, int, cudaStream_t);
......@@ -14,15 +14,18 @@ __global__ void column_sum_reduce(const T* __restrict__ inp,
cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
int idx = blockDim.x * blockIdx.x + threadIdx.x;
int offset = threadIdx.y * width + idx;
int y_stride = width * TILE_DIM;
float localSum = 0;
// Loop across matrix height
for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
localSum += (float)inp[offset];
offset += y_stride;
if (idx < width) {
int offset = threadIdx.y * width + idx;
for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
localSum += (float)inp[offset];
offset += y_stride;
}
}
tile[threadIdx.x][threadIdx.y] = localSum;
......@@ -40,7 +43,7 @@ __global__ void column_sum_reduce(const T* __restrict__ inp,
if (threadIdx.x == 0) {
int pos = blockIdx.x * TILE_DIM + threadIdx.y;
out[pos] = sum;
if (pos < (rows * width)) out[pos] = sum;
}
}
......@@ -58,10 +61,10 @@ void launch_fuse_transpose_bias_kernel<float>(const float* inp,
int cols,
cudaStream_t stream)
{
assert(rows % TILE_DIM == 0);
assert(cols % TILE_DIM == 0);
// assert(rows % TILE_DIM == 0);
// assert(cols % TILE_DIM == 0);
dim3 grid_dim(cols / TILE_DIM);
dim3 grid_dim((cols - 1) / TILE_DIM + 1);
dim3 block_dim(TILE_DIM, TILE_DIM);
column_sum_reduce<float><<<grid_dim, block_dim, 0, stream>>>(inp, out, rows, cols);
......@@ -74,10 +77,10 @@ void launch_fuse_transpose_bias_kernel<__half>(const __half* inp,
int cols,
cudaStream_t stream)
{
assert(rows % TILE_DIM == 0);
assert(cols % TILE_DIM == 0);
// assert(rows % TILE_DIM == 0);
// assert(cols % TILE_DIM == 0);
dim3 grid_dim(cols / TILE_DIM);
dim3 grid_dim((cols - 1) / TILE_DIM + 1);
dim3 block_dim(TILE_DIM, TILE_DIM);
column_sum_reduce<__half><<<grid_dim, block_dim, 0, stream>>>(inp, out, rows, cols);
......
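
The bias-gradient reduction above is now launched with a ceiling division so that widths that are not multiples of the tile size still get a block, and the new if (idx < width) and if (pos < (rows * width)) guards in column_sum_reduce keep the extra threads from reading or writing out of bounds. A small standalone illustration of the arithmetic, assuming TILE_DIM is 32 (its actual value is defined elsewhere in the kernel headers):

// Ceil-division grid sizing used by the updated launch_fuse_transpose_bias_kernel.
// TILE_DIM = 32 and cols = 1000 are assumed values for illustration only.
#include <cstdio>

int main() {
    const int TILE_DIM = 32;                   // assumed tile width
    const int cols = 1000;                     // a width that is not a multiple of TILE_DIM
    int old_grid = cols / TILE_DIM;            // 31 blocks; the old code instead asserted cols % TILE_DIM == 0
    int new_grid = (cols - 1) / TILE_DIM + 1;  // 32 blocks == ceil(1000 / 32), last partial tile included
    printf("old grid: %d blocks, new grid: %d blocks\n", old_grid, new_grid);
    return 0;
}
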
......@@ -27,10 +27,9 @@ __global__ void fused_bias_residual_layer_norm(float* vals,
const float* beta,
float epsilon,
bool preLayerNorm,
bool training = false,
float* vars = nullptr,
float* means = nullptr,
float* vals_hat = nullptr)
bool training,
float* vars,
float* means)
{
constexpr int iteration_stride = row_stride / iterations;
......@@ -108,10 +107,9 @@ __global__ void fused_bias_residual_layer_norm(__half* vals,
const __half* beta,
float epsilon,
bool preLayerNorm,
bool training = false,
__half* vars = nullptr,
__half* means = nullptr,
__half* vals_hat = nullptr)
bool training,
__half* vars,
__half* means)
{
#if __CUDA_ARCH__ >= 700
constexpr int iteration_stride = row_stride / iterations;
......@@ -204,14 +202,12 @@ void launch_bias_residual_layer_norm(T* vals,
const T* beta,
float epsilon,
int batch_size,
int sequence_length,
int hidden_dim,
cudaStream_t stream,
bool preLayerNorm,
bool training,
T* vars,
T* means,
T* vals_hat);
T* means);
template <>
void launch_bias_residual_layer_norm<float>(float* vals,
......@@ -220,40 +216,38 @@ void launch_bias_residual_layer_norm<float>(float* vals,
const float* beta,
float epsilon,
int batch_size,
int sequence_length,
int hidden_dim,
cudaStream_t stream,
bool preLayerNorm,
bool training,
float* vars,
float* means,
float* vals_hat)
float* means)
{
constexpr int threads = THREADS;
dim3 grid_dim(batch_size * sequence_length);
dim3 grid_dim(batch_size);
dim3 block_dim(threads);
// There are some limitations to call below functions, now just enumerate the situations.
if (hidden_dim == 768)
fused_bias_residual_layer_norm<768, 3><<<grid_dim, block_dim, 0, stream>>>(
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, vals_hat);
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means);
else if (hidden_dim == 512)
fused_bias_residual_layer_norm<512, 2><<<grid_dim, block_dim, 0, stream>>>(
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, vals_hat);
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means);
else if (hidden_dim == 1024)
fused_bias_residual_layer_norm<1024, 4><<<grid_dim, block_dim, 0, stream>>>(
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, vals_hat);
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means);
else if (hidden_dim == 1536)
fused_bias_residual_layer_norm<1536, 6><<<grid_dim, block_dim, 0, stream>>>(
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, vals_hat);
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means);
else if (hidden_dim == 2048)
fused_bias_residual_layer_norm<2048, 8><<<grid_dim, block_dim, 0, stream>>>(
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, vals_hat);
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means);
else if (hidden_dim == 2560)
fused_bias_residual_layer_norm<2560, 10><<<grid_dim, block_dim, 0, stream>>>(
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, vals_hat);
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means);
else
throw std::runtime_error("Unsupport hidden_dim.");
}
......@@ -265,39 +259,37 @@ void launch_bias_residual_layer_norm<__half>(__half* vals,
const __half* beta,
float epsilon,
int batch_size,
int sequence_length,
int hidden_dim,
cudaStream_t stream,
bool preLayerNorm,
bool training,
__half* vars,
__half* means,
__half* vals_hat)
__half* means)
{
constexpr int threads = 128;
dim3 grid_dim(batch_size * sequence_length);
dim3 grid_dim(batch_size);
dim3 block_dim(threads);
// There are some limitations to call below functions, now just enumerate the situations.
if (hidden_dim == 768)
fused_bias_residual_layer_norm<384, 3><<<grid_dim, block_dim, 0, stream>>>(
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, vals_hat);
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means);
else if (hidden_dim == 512)
fused_bias_residual_layer_norm<256, 2><<<grid_dim, block_dim, 0, stream>>>(
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, vals_hat);
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means);
else if (hidden_dim == 1024)
fused_bias_residual_layer_norm<512, 4><<<grid_dim, block_dim, 0, stream>>>(
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, vals_hat);
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means);
else if (hidden_dim == 1536)
fused_bias_residual_layer_norm<768, 6><<<grid_dim, block_dim, 0, stream>>>(
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, vals_hat);
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means);
else if (hidden_dim == 2048)
fused_bias_residual_layer_norm<1024, 8><<<grid_dim, block_dim, 0, stream>>>(
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, vals_hat);
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means);
else if (hidden_dim == 2560)
fused_bias_residual_layer_norm<1280, 10><<<grid_dim, block_dim, 0, stream>>>(
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, vals_hat);
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means);
else
throw std::runtime_error("Unsupport hidden_dim.");
}
......@@ -309,10 +301,8 @@ __global__ void fused_bias_residual_layer_norm(float* vals,
const float* beta,
float epsilon,
bool preLayerNorm,
bool training = false,
float* vars = nullptr,
float* vals_hat = nullptr,
bool save_vals = false)
bool training,
float* vars)
{
constexpr int iteration_stride = row_stride / iterations;
......@@ -388,10 +378,8 @@ __global__ void fused_bias_residual_layer_norm(__half* vals,
const __half* beta,
float epsilon,
bool preLayerNorm,
bool training = false,
__half* vars = nullptr,
__half* vals_hat = nullptr,
bool save_vals = false)
bool training,
__half* vars)
{
#if __CUDA_ARCH__ >= 700
constexpr int iteration_stride = row_stride / iterations;
......@@ -481,14 +469,11 @@ void launch_bias_residual_layer_norm(T* vals,
const T* beta,
float epsilon,
int batch_size,
int sequence_length,
int hidden_dim,
cudaStream_t stream,
bool preLayerNorm,
bool training,
T* vars,
T* vals_hat,
bool save_vals);
T* vars);
/*
To tune this launch the following restrictions must be met:
......@@ -512,88 +497,37 @@ void launch_bias_residual_layer_norm<float>(float* vals,
const float* beta,
float epsilon,
int batch_size,
int sequence_length,
int hidden_dim,
cudaStream_t stream,
bool preLayerNorm,
bool training,
float* vars,
float* vals_hat,
bool save_vals)
float* vars)
{
constexpr int threads = THREADS;
dim3 grid_dim(batch_size * sequence_length);
dim3 grid_dim(batch_size);
dim3 block_dim(threads);
// There are some limitations to call below functions, now just enumerate the situations.
if (hidden_dim == 768)
fused_bias_residual_layer_norm<768, 3><<<grid_dim, block_dim, 0, stream>>>(vals,
residual,
gamma,
beta,
epsilon,
preLayerNorm,
training,
vars,
vals_hat,
save_vals);
fused_bias_residual_layer_norm<768, 3><<<grid_dim, block_dim, 0, stream>>>(
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars);
else if (hidden_dim == 512)
fused_bias_residual_layer_norm<512, 2><<<grid_dim, block_dim, 0, stream>>>(vals,
residual,
gamma,
beta,
epsilon,
preLayerNorm,
training,
vars,
vals_hat,
save_vals);
fused_bias_residual_layer_norm<512, 2><<<grid_dim, block_dim, 0, stream>>>(
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars);
else if (hidden_dim == 1024)
fused_bias_residual_layer_norm<1024, 4><<<grid_dim, block_dim, 0, stream>>>(vals,
residual,
gamma,
beta,
epsilon,
preLayerNorm,
training,
vars,
vals_hat,
save_vals);
fused_bias_residual_layer_norm<1024, 4><<<grid_dim, block_dim, 0, stream>>>(
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars);
else if (hidden_dim == 1536)
fused_bias_residual_layer_norm<1536, 6><<<grid_dim, block_dim, 0, stream>>>(vals,
residual,
gamma,
beta,
epsilon,
preLayerNorm,
training,
vars,
vals_hat,
save_vals);
fused_bias_residual_layer_norm<1536, 6><<<grid_dim, block_dim, 0, stream>>>(
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars);
else if (hidden_dim == 2048)
fused_bias_residual_layer_norm<2048, 8><<<grid_dim, block_dim, 0, stream>>>(vals,
residual,
gamma,
beta,
epsilon,
preLayerNorm,
training,
vars,
vals_hat,
save_vals);
fused_bias_residual_layer_norm<2048, 8><<<grid_dim, block_dim, 0, stream>>>(
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars);
else if (hidden_dim == 2560)
fused_bias_residual_layer_norm<2560, 10><<<grid_dim, block_dim, 0, stream>>>(vals,
residual,
gamma,
beta,
epsilon,
preLayerNorm,
training,
vars,
vals_hat,
save_vals);
fused_bias_residual_layer_norm<2560, 10><<<grid_dim, block_dim, 0, stream>>>(
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars);
else
throw std::runtime_error("Unsupport hidden_dim.");
}
......@@ -605,87 +539,36 @@ void launch_bias_residual_layer_norm<__half>(__half* vals,
const __half* beta,
float epsilon,
int batch_size,
int sequence_length,
int hidden_dim,
cudaStream_t stream,
bool preLayerNorm,
bool training,
__half* vars,
__half* vals_hat,
bool save_vals)
__half* vars)
{
constexpr int threads = 128;
dim3 grid_dim(batch_size * sequence_length);
dim3 grid_dim(batch_size);
dim3 block_dim(threads);
// There are some limitations to call below functions, now just enumerate the situations.
if (hidden_dim == 768)
fused_bias_residual_layer_norm<384, 3><<<grid_dim, block_dim, 0, stream>>>(vals,
residual,
gamma,
beta,
epsilon,
preLayerNorm,
training,
vars,
vals_hat,
save_vals);
fused_bias_residual_layer_norm<384, 3><<<grid_dim, block_dim, 0, stream>>>(
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars);
else if (hidden_dim == 512)
fused_bias_residual_layer_norm<256, 2><<<grid_dim, block_dim, 0, stream>>>(vals,
residual,
gamma,
beta,
epsilon,
preLayerNorm,
training,
vars,
vals_hat,
save_vals);
fused_bias_residual_layer_norm<256, 2><<<grid_dim, block_dim, 0, stream>>>(
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars);
else if (hidden_dim == 1024)
fused_bias_residual_layer_norm<512, 4><<<grid_dim, block_dim, 0, stream>>>(vals,
residual,
gamma,
beta,
epsilon,
preLayerNorm,
training,
vars,
vals_hat,
save_vals);
fused_bias_residual_layer_norm<512, 4><<<grid_dim, block_dim, 0, stream>>>(
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars);
else if (hidden_dim == 1536)
fused_bias_residual_layer_norm<768, 6><<<grid_dim, block_dim, 0, stream>>>(vals,
residual,
gamma,
beta,
epsilon,
preLayerNorm,
training,
vars,
vals_hat,
save_vals);
fused_bias_residual_layer_norm<768, 6><<<grid_dim, block_dim, 0, stream>>>(
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars);
else if (hidden_dim == 2048)
fused_bias_residual_layer_norm<1024, 8><<<grid_dim, block_dim, 0, stream>>>(vals,
residual,
gamma,
beta,
epsilon,
preLayerNorm,
training,
vars,
vals_hat,
save_vals);
fused_bias_residual_layer_norm<1024, 8><<<grid_dim, block_dim, 0, stream>>>(
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars);
else if (hidden_dim == 2560)
fused_bias_residual_layer_norm<1280, 10><<<grid_dim, block_dim, 0, stream>>>(vals,
residual,
gamma,
beta,
epsilon,
preLayerNorm,
training,
vars,
vals_hat,
save_vals);
fused_bias_residual_layer_norm<1280, 10><<<grid_dim, block_dim, 0, stream>>>(
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars);
else
throw std::runtime_error("Unsupport hidden_dim.");
}
......@@ -1037,15 +920,13 @@ void launch_layerNorm_backward<float>(const float* out_grad,
float* gamma_grad,
float* betta_grad,
float* inp_grad,
int batch_size,
int sequence_length,
int batch,
int hidden_dim,
cudaStream_t stream[2],
bool invertible,
const float* betta)
{
constexpr int threads = THREADS;
int batch = batch_size * sequence_length;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
......@@ -1086,15 +967,13 @@ void launch_layerNorm_backward<__half>(const __half* out_grad,
__half* gamma_grad,
__half* betta_grad,
__half* inp_grad,
int batch_size,
int sequence_length,
int batch,
int hidden_dim,
cudaStream_t stream[2],
bool invertible,
const __half* betta)
{
constexpr int threads = THREADS;
int batch = batch_size * sequence_length;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
......@@ -1336,13 +1215,11 @@ void launch_layerNorm_backward<float>(const float* out_grad,
float* gamma_grad,
float* betta_grad,
float* inp_grad,
int batch_size,
int sequence_length,
int batch,
int hidden_dim,
cudaStream_t stream[2])
{
constexpr int threads = THREADS;
int batch = batch_size * sequence_length;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
......@@ -1384,13 +1261,11 @@ void launch_layerNorm_backward<__half>(const __half* out_grad,
__half* gamma_grad,
__half* betta_grad,
__half* inp_grad,
int batch_size,
int sequence_length,
int batch,
int hidden_dim,
cudaStream_t stream[2])
{
constexpr int threads = THREADS;
int batch = batch_size * sequence_length;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
......@@ -1759,15 +1634,13 @@ void launch_layerNorm_backward_fused_add<float>(const float* out_grad1,
float* gamma_grad,
float* betta_grad,
float* inp_grad,
int batch_size,
int sequence_length,
int batch,
int hidden_dim,
cudaStream_t stream[2],
bool invertible,
const float* betta)
{
constexpr int threads = THREADS;
int batch = batch_size * sequence_length;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
......@@ -1808,15 +1681,13 @@ void launch_layerNorm_backward_fused_add<__half>(const __half* out_grad1,
__half* gamma_grad,
__half* betta_grad,
__half* inp_grad,
int batch_size,
int sequence_length,
int batch,
int hidden_dim,
cudaStream_t stream[2],
bool invertible,
const __half* betta)
{
constexpr int threads = THREADS;
int batch = batch_size * sequence_length;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
......@@ -2070,13 +1941,11 @@ void launch_layerNorm_backward_fused_add<float>(const float* out_grad1,
float* gamma_grad,
float* betta_grad,
float* inp_grad,
int batch_size,
int sequence_length,
int batch,
int hidden_dim,
cudaStream_t stream[2])
{
constexpr int threads = THREADS;
int batch = batch_size * sequence_length;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
......@@ -2119,13 +1988,11 @@ void launch_layerNorm_backward_fused_add<__half>(const __half* out_grad1,
__half* gamma_grad,
__half* betta_grad,
__half* inp_grad,
int batch_size,
int sequence_length,
int batch,
int hidden_dim,
cudaStream_t stream[2])
{
constexpr int threads = THREADS;
int batch = batch_size * sequence_length;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
......
#include <math.h>
#include "custom_cuda_layers.h"
#include "general_kernels.h"
......@@ -282,7 +283,7 @@ __global__ void attn_softmax(__half* vals,
}
template <typename T>
void launch_attn_softmax(T*, const T*, int, int, int, cudaStream_t, bool);
void launch_attn_softmax(T*, const T*, int, int, int, cudaStream_t);
template <>
void launch_attn_softmax<float>(float* vals,
......@@ -294,11 +295,10 @@ void launch_attn_softmax<float>(float* vals,
{
const int threads = 128;
int seq_length4 = sequence_length / 4;
int seq2 = sequence_length * seq_length4;
int block_compute_size =
(seq_length4 < threads ? ((threads / seq_length4) * seq_length4) : seq_length4);
dim3 grid_dim(batch_size, heads * seq2 / block_compute_size);
(seq_length4 < threads ? (int)pow(2.0, floor(log2((float)(threads / seq_length4)))) : 1);
dim3 grid_dim(batch_size, heads * sequence_length / block_compute_size);
int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads;
......@@ -330,8 +330,9 @@ void launch_attn_softmax<float>(float* vals,
else {
const int threads = 256;
block_compute_size =
(seq_length4 < threads ? ((threads / seq_length4) * seq_length4) : seq_length4);
dim3 grid_dim(batch_size, heads * seq2 / block_compute_size);
(seq_length4 < threads ? (int)pow(2.0, floor(log2((float)(threads / seq_length4))))
: 1);
dim3 grid_dim(batch_size, heads * sequence_length / block_compute_size);
int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads;
......@@ -362,11 +363,10 @@ void launch_attn_softmax<__half>(__half* vals,
{
const int threads = 128;
int seq_length4 = sequence_length / 4;
int seq2 = sequence_length * seq_length4;
int block_compute_size =
(seq_length4 < threads ? ((threads / seq_length4) * seq_length4) : seq_length4);
dim3 grid_dim(batch_size, heads * seq2 / block_compute_size);
(seq_length4 < threads ? (int)pow(2.0, floor(log2((float)(threads / seq_length4)))) : 1);
dim3 grid_dim(batch_size, heads * sequence_length / block_compute_size);
int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads;
......@@ -399,8 +399,9 @@ void launch_attn_softmax<__half>(__half* vals,
else {
const int threads = 256;
block_compute_size =
(seq_length4 < threads ? ((threads / seq_length4) * seq_length4) : seq_length4);
dim3 grid_dim(batch_size, heads * seq2 / block_compute_size);
(seq_length4 < threads ? (int)pow(2.0, floor(log2((float)(threads / seq_length4))))
: 1);
dim3 grid_dim(batch_size, heads * sequence_length / block_compute_size);
int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads;
......@@ -531,55 +532,41 @@ void launch_attn_softmax_backward_v2(T* out_grad,
int seq_length,
cudaStream_t stream)
{
if ((seq_length % WARP_SIZE) != 0 || seq_length > 2048)
throw std::runtime_error("Invalid sequence length found in softmax backward.");
const int warps_per_block = 4;
dim3 grid_dim(batch_size * heads * seq_length / warps_per_block);
dim3 block_dim(WARP_SIZE, warps_per_block);
switch (seq_length) {
case 32:
softmax_backward_kernel_v2<T, 1>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
break;
case 64:
softmax_backward_kernel_v2<T, 2>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
break;
case 128:
softmax_backward_kernel_v2<T, 4>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
break;
case 256:
softmax_backward_kernel_v2<T, 8>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
break;
case 384:
softmax_backward_kernel_v2<T, 12>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
break;
case 512:
softmax_backward_kernel_v2<T, 16>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
break;
case 768:
softmax_backward_kernel_v2<T, 24>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
break;
case 1024:
softmax_backward_kernel_v2<T, 32>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
break;
case 2048:
softmax_backward_kernel_v2<T, 64>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
break;
default:
throw std::runtime_error(
std::string("Special sequence length found in softmax backward, seq_length: ") +
std::to_string(seq_length));
}
if (seq_length <= 32)
softmax_backward_kernel_v2<T, 1>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
else if (seq_length <= 64)
softmax_backward_kernel_v2<T, 2>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
else if (seq_length <= 128)
softmax_backward_kernel_v2<T, 4>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
else if (seq_length <= 256)
softmax_backward_kernel_v2<T, 8>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
else if (seq_length <= 384)
softmax_backward_kernel_v2<T, 12>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
else if (seq_length <= 512)
softmax_backward_kernel_v2<T, 16>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
else if (seq_length <= 768)
softmax_backward_kernel_v2<T, 24>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
else if (seq_length <= 1024)
softmax_backward_kernel_v2<T, 32>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
else if (seq_length <= 2048)
softmax_backward_kernel_v2<T, 64>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
else
throw std::runtime_error(
std::string("Special sequence length found in softmax backward, seq_length: ") +
std::to_string(seq_length));
}
template void launch_attn_softmax_backward_v2<__half>(__half* out_grad,
......
......@@ -187,26 +187,30 @@ class DeepSpeedTransformerFunction(Function):
ff2_inp,
attn_prob_dropout_mask,
attn_output_dropout_mask,
layer_output_dropout_mask) = forward_func(config.layer_id,
input,
input_mask,
attn_qkvw,
attn_qkvb,
attn_ow,
attn_ob,
attn_nw,
attn_nb,
inter_w,
inter_b,
output_w,
output_b,
norm_w,
norm_b,
config.training,
config.pre_layer_norm,
config.attn_dropout_checkpoint,
config.normalize_invertible,
config.gelu_checkpoint)
layer_output_dropout_mask,
attn_layer_norm_var,
attn_layer_norm_mean,
layer_norm_var,
layer_norm_mean) = forward_func(config.layer_id,
input,
input_mask,
attn_qkvw,
attn_qkvb,
attn_ow,
attn_ob,
attn_nw,
attn_nb,
inter_w,
inter_b,
output_w,
output_b,
norm_w,
norm_b,
config.training,
config.pre_layer_norm,
config.attn_dropout_checkpoint,
config.normalize_invertible,
config.gelu_checkpoint)
# For testing only.
if grads is not None:
......@@ -283,6 +287,9 @@ class DeepSpeedTransformerFunction(Function):
if not config.normalize_invertible:
ctx.add_res = add_res
ctx.attn_layer_norm_mean = attn_layer_norm_mean
ctx.layer_norm_mean = layer_norm_mean
ctx.ff1_inp = ff1_inp
if not config.gelu_checkpoint:
ctx.gelu_inp = gelu_inp
......@@ -291,6 +298,8 @@ class DeepSpeedTransformerFunction(Function):
ctx.attn_prob_dropout_mask = attn_prob_dropout_mask
ctx.attn_output_dropout_mask = attn_output_dropout_mask
ctx.layer_output_dropout_mask = layer_output_dropout_mask
ctx.attn_layer_norm_var = attn_layer_norm_var
ctx.layer_norm_var = layer_norm_var
return output
......@@ -367,6 +376,10 @@ class DeepSpeedTransformerFunction(Function):
ctx.attn_prob_dropout_mask,
ctx.attn_output_dropout_mask,
ctx.layer_output_dropout_mask,
ctx.attn_layer_norm_var,
ctx.attn_layer_norm_mean,
ctx.layer_norm_var,
ctx.layer_norm_mean,
(ctx.inp_norm if (ctx.config.pre_layer_norm
and ctx.config.normalize_invertible) else input),
input_mask,
......
......@@ -256,10 +256,10 @@ def run_backward(ds_config, atol=1e-2, verbose=False):
@pytest.mark.parametrize('batch_size, hidden_size, seq_len, heads, num_layers, is_preln, use_fp16, atol',
[
(3,1024,128,16,24,True,False, 0.05),
(3,1024,128,16,24,True,True, 0.05),
(3,1024,128,16,24,False,False, 0.1),
(3,1024,128,16,24,False,True, 0.2),
(3,1024,120,16,24,True,False, 0.05),
(3,1024,120,16,24,True,True, 0.05),
(3,1024,56,16,24,False,False, 0.1),
(3,1024,56,16,24,False,True, 0.2),
]) # yapf: disable
def test_backward(batch_size,
hidden_size,
......
import argparse
import numpy as np
import torch
import torch.nn.functional as F
import pytest
import json
import random
import time
import copy
from torch import nn
from modelingpreln import BertEncoder as BertEncoderPreln
from modeling import BertEncoder as BertEncoderPostln
from modeling import BertLayerNorm, BertConfig
from deepspeed import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig
import deepspeed
import sys
if not deepspeed.ops.__installed_ops__['transformer']:
pytest.skip("transformer kernels are not installed", allow_module_level=True)
def check_equal(first, second, atol=1e-2, verbose=False):
if verbose:
print()
for i, (x, y) in enumerate(zip(first, second)):
x = x[0].cpu().detach().numpy()
y = y[0].cpu().detach().numpy()
if verbose:
print("x = {}".format(x.flatten()))
print("y = {}".format(y.flatten()))
print('-' * 80)
np.testing.assert_allclose(x, y, err_msg="Index: {}".format(i), atol=atol)
def zero_grad(variables):
for variable in variables:
variable.grad.zero_()
device = torch.device("cuda")
kwargs_fp32 = {'dtype': torch.float, 'device': device, 'requires_grad': True}
kwargs_fp16 = {'dtype': torch.half, 'device': device, 'requires_grad': True}
class DSEncoder(nn.Module):
def __init__(self, config, weights, biases):
super(DSEncoder, self).__init__()
self.FinalLayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
self.layer = nn.ModuleList([
copy.deepcopy(DeepSpeedTransformerLayer(i,
config,
weights,
biases))
for i in range(config.num_hidden_layers)
])
self.grads = []
self.pre_or_post = config.pre_layer_norm
def forward(self,
hidden_states,
attention_mask,
output_all_encoded_layers=True,
checkpoint_activations=False):
all_encoder_layers = []
def custom(start, end):
def custom_forward(*inputs):
layers = self.layer[start:end]
x_ = inputs[0]
for layer in layers:
x_ = layer(x_, inputs[1])
return x_
return custom_forward
if checkpoint_activations:
l = 0
num_layers = len(self.layer)
chunk_length = math.ceil(math.sqrt(num_layers))
while l < num_layers:
hidden_states = checkpoint.checkpoint(custom(l,
l + chunk_length),
hidden_states,
attention_mask * 1)
l += chunk_length
# decoder layers
else:
for i, layer_module in enumerate(self.layer):
hidden_states = layer_module(hidden_states, attention_mask)
hidden_states.register_hook(
lambda x,
i=i,
self=self: self.grads.append([x,
"hidden_state"]))
if output_all_encoded_layers:
all_encoder_layers.append(hidden_states)
if not output_all_encoded_layers or checkpoint_activations:
if (self.pre_or_post):
hidden_states = self.FinalLayerNorm(hidden_states)
all_encoder_layers.append(hidden_states)
return all_encoder_layers
def get_grads(self):
return self.grads
def create_models(ds_config):
bert_config = BertConfig(vocab_size_or_config_json_file=119547,
hidden_size=ds_config.hidden_size,
num_hidden_layers=ds_config.num_hidden_layers,
num_attention_heads=ds_config.heads,
batch_size=ds_config.batch_size,
intermediate_size=ds_config.intermediate_size,
hidden_act="gelu",
hidden_dropout_prob=ds_config.hidden_dropout_ratio,
attention_probs_dropout_prob=ds_config.attn_dropout_ratio,
max_position_embeddings=ds_config.max_seq_length,
type_vocab_size=2,
initializer_range=ds_config.initializer_range,
fp16=ds_config.fp16)
weights = []
biases = []
for i in range(4):
weights.append(
nn.Parameter(torch.Tensor(ds_config.hidden_size,
ds_config.hidden_size)))
weights[i].data.normal_(mean=0.0, std=ds_config.initializer_range)
weights.append(nn.Parameter(torch.Tensor(ds_config.hidden_size)))
weights[4].data.fill_(1.0)
weights.append(
nn.Parameter(torch.Tensor(ds_config.intermediate_size,
ds_config.hidden_size)))
weights[5].data.normal_(mean=0.0, std=ds_config.initializer_range)
weights.append(
nn.Parameter(torch.Tensor(ds_config.hidden_size,
ds_config.intermediate_size)))
weights[6].data.normal_(mean=0.0, std=ds_config.initializer_range)
weights.append(nn.Parameter(torch.Tensor(ds_config.hidden_size)))
weights[7].data.fill_(1.0)
biases.append(nn.Parameter(torch.Tensor(ds_config.hidden_size)))
biases[0].data.zero_()
for i in range(4):
biases.append(nn.Parameter(torch.Tensor(ds_config.hidden_size)))
biases[i + 1].data.zero_()
biases.append(nn.Parameter(torch.Tensor(ds_config.intermediate_size)))
biases[5].data.zero_()
biases.append(nn.Parameter(torch.Tensor(ds_config.hidden_size)))
biases[6].data.zero_()
biases.append(nn.Parameter(torch.Tensor(ds_config.hidden_size)))
biases[7].data.zero_()
if (ds_config.pre_layer_norm):
bert_encoder = BertEncoderPreln(bert_config, weights, biases)
else:
bert_encoder = BertEncoderPostln(bert_config, weights, biases)
ds_encoder = DSEncoder(ds_config, weights, biases)
if ds_config.fp16:
bert_encoder.half()
ds_encoder.half()
bert_encoder.cuda()
ds_encoder.cuda()
return bert_encoder, ds_encoder
def set_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
def run_forward(ds_config, atol=1e-2, verbose=False, test_bsz=None):
set_seed(123)
bert_encoder, ds_encoder = create_models(ds_config)
bsz = ds_config.batch_size if test_bsz is None else test_bsz
# prepare test data
kwargs = kwargs_fp16 if ds_config.fp16 else kwargs_fp32
hidden_states = torch.randn(bsz,
ds_config.max_seq_length,
ds_config.hidden_size,
**kwargs)
input_mask = torch.randn(bsz, 1, 1, ds_config.max_seq_length, **kwargs)
# run baseline
base_results = bert_encoder(hidden_states,
input_mask,
output_all_encoded_layers=False,
checkpoint_activations=False)
# run ds
ds_results = ds_encoder(hidden_states,
input_mask,
output_all_encoded_layers=False,
checkpoint_activations=False)
# check forward outputs
check_equal(base_results, ds_results, atol=atol, verbose=verbose)
# FP16 test cases can only run on devices that support FP16.
@pytest.mark.parametrize('batch_size, hidden_size, seq_len, heads, num_layers, is_preln, use_fp16',
[
(64,1024,128,16,3,True,False),
(64,1024,128,16,3,True,True),
(8,1024,384,16,3,True,False),
(8,1024,384,16,3,True,True),
(8,1024,512,16,3,True,False),
(8,1024,512,16,3,True,True),
(64,1024,128,16,3,False,False),
(64,1024,128,16,3,False,True),
(8,1024,384,16,3,False,False),
(8,1024,384,16,3,False,True),
(8,1024,512,16,3,False,False),
(8,1024,512,16,3,False,True),
(8,1536,128,24,3,False,False),
(8,1536,128,24,3,False,True),
(8,2048,128,32,3,False,False),
(8,2048,128,32,3,False,True),
(8,2560,128,40,3,False,False),
(8,2560,128,40,3,False,True),
]) # yapf: disable
def test_forward(batch_size,
hidden_size,
seq_len,
heads,
num_layers,
is_preln,
use_fp16):
# Only run fp16 test cases on devices with 7+ capability.
major, _ = torch.cuda.get_device_capability()
if major < 7 and use_fp16 is True:
return
ds_config = DeepSpeedTransformerConfig()
ds_config.layer_id = None
ds_config.batch_size = batch_size
ds_config.hidden_size = hidden_size
ds_config.intermediate_size = 4 * hidden_size
ds_config.max_seq_length = seq_len
ds_config.heads = heads
ds_config.attn_dropout_ratio = 0.0
ds_config.hidden_dropout_ratio = 0.0
ds_config.num_hidden_layers = num_layers
ds_config.pre_layer_norm = is_preln
ds_config.initializer_range = 0.02
ds_config.fp16 = use_fp16
run_forward(ds_config, atol=2e-2)
@pytest.mark.parametrize('batch_size, small_bsz, hidden_size, seq_len, heads, num_layers, is_preln, use_fp16',
[
(8,3,1024,512,16,3,True,False),
(8,7,1024,512,16,3,True,True),
(8,3,1024,512,16,3,False,False),
(8,7,1024,512,16,3,False,True),
]) # yapf: disable
def test_forward_with_small_bsz(batch_size,
small_bsz,
hidden_size,
seq_len,
heads,
num_layers,
is_preln,
use_fp16):
# Only run fp16 test cases on devices with 7+ capability.
major, _ = torch.cuda.get_device_capability()
if major < 7 and use_fp16 is True:
return
ds_config = DeepSpeedTransformerConfig()
ds_config.layer_id = None
ds_config.batch_size = batch_size
ds_config.hidden_size = hidden_size
ds_config.intermediate_size = 4 * hidden_size
ds_config.max_seq_length = seq_len
ds_config.heads = heads
ds_config.attn_dropout_ratio = 0.0
ds_config.hidden_dropout_ratio = 0.0
ds_config.num_hidden_layers = num_layers
ds_config.pre_layer_norm = is_preln
ds_config.initializer_range = 0.02
ds_config.fp16 = use_fp16
run_forward(ds_config, atol=2e-2, test_bsz=small_bsz)
@pytest.mark.parametrize('batch_size, hidden_size, seq_len, heads, num_layers, is_preln, use_fp16',
[
(64,1024,128,16,3,True,False),
(64,1024,128,16,3,True,True),
(64,1024,128,16,3,False,False),
(64,1024,128,16,3,False,True),
]) # yapf: disable
def test_forward_stochastic(batch_size,
hidden_size,
seq_len,
heads,
num_layers,
is_preln,
use_fp16):
# Only run fp16 test cases on devices with 7+ capability.
major, _ = torch.cuda.get_device_capability()
if major < 7 and use_fp16 is True:
return
ds_config = DeepSpeedTransformerConfig()
ds_config.layer_id = None
ds_config.batch_size = batch_size
ds_config.hidden_size = hidden_size
ds_config.intermediate_size = 4 * hidden_size
ds_config.max_seq_length = seq_len
ds_config.heads = heads
ds_config.attn_dropout_ratio = 0.0
ds_config.hidden_dropout_ratio = 0.0
ds_config.num_hidden_layers = num_layers
ds_config.pre_layer_norm = is_preln
ds_config.initializer_range = 0.02
ds_config.fp16 = use_fp16
ds_config.stochastic_mode = True
run_forward(ds_config, atol=7e-2)
import argparse
import numpy as np
import torch
import torch.nn.functional as F
import pytest
import json
import random
import time
import copy
import math
from torch import nn
from torch.utils import checkpoint
from modelingpreln import BertEncoder as BertEncoderPreln
from modeling import BertEncoder as BertEncoderPostln
from modeling import BertLayerNorm, BertConfig
from deepspeed import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig
import deepspeed
import sys
if not deepspeed.ops.__installed_ops__['transformer']:
pytest.skip("transformer kernels are not installed", allow_module_level=True)
def check_equal(first, second, atol=1e-2, verbose=False):
if verbose:
print()
for i, (x, y) in enumerate(zip(first, second)):
x = x[0].cpu().detach().numpy()
y = y[0].cpu().detach().numpy()
if verbose:
print("x = {}".format(x.flatten()))
print("y = {}".format(y.flatten()))
print('-' * 80)
np.testing.assert_allclose(x, y, err_msg="Index: {}".format(i), atol=atol)
def zero_grad(variables):
for variable in variables:
variable.grad.zero_()
device = torch.device("cuda")
kwargs_fp32 = {'dtype': torch.float, 'device': device, 'requires_grad': True}
kwargs_fp16 = {'dtype': torch.half, 'device': device, 'requires_grad': True}
class DSEncoder(nn.Module):
def __init__(self, config, weights, biases):
super(DSEncoder, self).__init__()
self.FinalLayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
self.layer = nn.ModuleList([
copy.deepcopy(DeepSpeedTransformerLayer(i,
config,
weights,
biases))
for i in range(config.num_hidden_layers)
])
self.grads = []
self.pre_or_post = config.pre_layer_norm
def forward(self,
hidden_states,
attention_mask,
output_all_encoded_layers=True,
checkpoint_activations=False):
all_encoder_layers = []
def custom(start, end):
def custom_forward(*inputs):
layers = self.layer[start:end]
x_ = inputs[0]
for layer in layers:
x_ = layer(x_, inputs[1])
return x_
return custom_forward
if checkpoint_activations:
l = 0
num_layers = len(self.layer)
chunk_length = math.ceil(math.sqrt(num_layers))
while l < num_layers:
hidden_states = checkpoint.checkpoint(custom(l,
l + chunk_length),
hidden_states,
attention_mask * 1)
l += chunk_length
# run the encoder layers sequentially (no activation checkpointing)
else:
for i, layer_module in enumerate(self.layer):
hidden_states = layer_module(hidden_states, attention_mask)
hidden_states.register_hook(
lambda x,
i=i,
self=self: self.grads.append([x,
"hidden_state"]))
if output_all_encoded_layers:
all_encoder_layers.append(hidden_states)
if not output_all_encoded_layers or checkpoint_activations:
if (self.pre_or_post):
hidden_states = self.FinalLayerNorm(hidden_states)
all_encoder_layers.append(hidden_states)
return all_encoder_layers
def get_grads(self):
return self.grads
def create_models(ds_config):
bert_config = BertConfig(vocab_size_or_config_json_file=119547,
hidden_size=ds_config.hidden_size,
num_hidden_layers=ds_config.num_hidden_layers,
num_attention_heads=ds_config.heads,
batch_size=ds_config.batch_size,
intermediate_size=ds_config.intermediate_size,
hidden_act="gelu",
hidden_dropout_prob=ds_config.hidden_dropout_ratio,
attention_probs_dropout_prob=ds_config.attn_dropout_ratio,
max_position_embeddings=ds_config.max_seq_length,
type_vocab_size=2,
initializer_range=ds_config.initializer_range,
fp16=ds_config.fp16)
weights = []
biases = []
for i in range(4):
weights.append(
nn.Parameter(torch.Tensor(ds_config.hidden_size,
ds_config.hidden_size)))
weights[i].data.normal_(mean=0.0, std=ds_config.initializer_range)
weights.append(nn.Parameter(torch.Tensor(ds_config.hidden_size)))
weights[4].data.fill_(1.0)
weights.append(
nn.Parameter(torch.Tensor(ds_config.intermediate_size,
ds_config.hidden_size)))
weights[5].data.normal_(mean=0.0, std=ds_config.initializer_range)
weights.append(
nn.Parameter(torch.Tensor(ds_config.hidden_size,
ds_config.intermediate_size)))
weights[6].data.normal_(mean=0.0, std=ds_config.initializer_range)
weights.append(nn.Parameter(torch.Tensor(ds_config.hidden_size)))
weights[7].data.fill_(1.0)
biases.append(nn.Parameter(torch.Tensor(ds_config.hidden_size)))
biases[0].data.zero_()
for i in range(4):
biases.append(nn.Parameter(torch.Tensor(ds_config.hidden_size)))
biases[i + 1].data.zero_()
biases.append(nn.Parameter(torch.Tensor(ds_config.intermediate_size)))
biases[5].data.zero_()
biases.append(nn.Parameter(torch.Tensor(ds_config.hidden_size)))
biases[6].data.zero_()
biases.append(nn.Parameter(torch.Tensor(ds_config.hidden_size)))
biases[7].data.zero_()
if (ds_config.pre_layer_norm):
bert_encoder = BertEncoderPreln(bert_config, weights, biases)
else:
bert_encoder = BertEncoderPostln(bert_config, weights, biases)
ds_encoder = DSEncoder(ds_config, weights, biases)
if ds_config.fp16:
bert_encoder.half()
ds_encoder.half()
bert_encoder.cuda()
ds_encoder.cuda()
return bert_encoder, ds_encoder
def set_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
def run_forward(ds_config, seq_len, atol=1e-2, verbose=False, test_bsz=None):
set_seed(123)
bert_encoder, ds_encoder = create_models(ds_config)
bsz = ds_config.batch_size if test_bsz is None else test_bsz
# prepare test data
kwargs = kwargs_fp16 if ds_config.fp16 else kwargs_fp32
hidden_states = torch.randn(bsz,
seq_len, #ds_config.max_seq_length,
ds_config.hidden_size,
**kwargs)
input_mask = torch.randn(bsz, 1, 1,
seq_len, #ds_config.max_seq_length,
**kwargs)
# run baseline
base_results = bert_encoder(hidden_states,
input_mask,
output_all_encoded_layers=False,
checkpoint_activations=False)
# run ds
ds_results = ds_encoder(hidden_states,
input_mask,
output_all_encoded_layers=False,
checkpoint_activations=False)
# check forward outputs
check_equal(base_results, ds_results, atol=atol, verbose=verbose)
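# Example usage of the harness above (a sketch, not a collected pytest case): the
# config's max_seq_length stays fixed while the runtime sequence length is passed
# separately through seq_len, which is what exercises the dynamic-length kernel path.
# All names come from this file; only the literal values below are chosen here.
def _run_forward_example():
    ds_config = DeepSpeedTransformerConfig()
    ds_config.layer_id = None
    ds_config.batch_size = 8
    ds_config.hidden_size = 1024
    ds_config.intermediate_size = 4 * 1024
    ds_config.max_seq_length = 128          # kernel buffers sized for up to 128 tokens
    ds_config.heads = 16
    ds_config.attn_dropout_ratio = 0.0
    ds_config.hidden_dropout_ratio = 0.0
    ds_config.num_hidden_layers = 3
    ds_config.pre_layer_norm = True
    ds_config.initializer_range = 0.02
    ds_config.fp16 = False
    run_forward(ds_config, seq_len=56, atol=2e-2)   # inputs actually use 56 tokens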
# FP16 test cases can only run on devices that support FP16.
@pytest.mark.parametrize('batch_size, hidden_size, seq_len, heads, num_layers, is_preln, use_fp16',
[
(64,1024,128,16,3,True,False),
(64,1024,128,16,3,True,True),
(8,1024,384,16,3,True,False),
(8,1024,384,16,3,True,True),
(8,1024,120,16,3,True,False),
(8,1024,120,16,3,True,True),
(8,1024,512,16,3,True,False),
(8,1024,512,16,3,True,True),
(64,1024,56,16,3,False,False),
(64,1024,56,16,3,False,True),
(64,1024,24,16,3,False,False),
(64,1024,24,16,3,False,True),
(8,1024,384,16,3,False,False),
(8,1024,384,16,3,False,True),
(8,1024,512,16,3,False,False),
(8,1024,512,16,3,False,True),
(8,1536,128,24,3,False,False),
(8,1536,128,24,3,False,True),
(8,2048,128,32,3,False,False),
(8,2048,128,32,3,False,True),
(8,2560,128,40,3,False,False),
(8,2560,128,40,3,False,True),
]) # yapf: disable
def test_forward(batch_size,
hidden_size,
seq_len,
heads,
num_layers,
is_preln,
use_fp16):
# Only run fp16 test cases on devices with 7+ capability.
major, _ = torch.cuda.get_device_capability()
if major < 7 and use_fp16 is True:
return
ds_config = DeepSpeedTransformerConfig()
ds_config.layer_id = None
ds_config.batch_size = batch_size
ds_config.hidden_size = hidden_size
ds_config.max_seq_length = 128 #seq_len
ds_config.intermediate_size = 4 * hidden_size
ds_config.heads = heads
ds_config.attn_dropout_ratio = 0.0
ds_config.hidden_dropout_ratio = 0.0
ds_config.num_hidden_layers = num_layers
ds_config.pre_layer_norm = is_preln
ds_config.initializer_range = 0.02
ds_config.fp16 = use_fp16
run_forward(ds_config, seq_len, atol=2e-2)
@pytest.mark.parametrize('batch_size, small_bsz, hidden_size, seq_len, heads, num_layers, is_preln, use_fp16',
[
(8,3,1024,512,16,3,True,False),
(8,7,1024,512,16,3,True,True),
(8,3,1024,512,16,3,False,False),
(8,7,1024,512,16,3,False,True),
]) # yapf: disable
def test_forward_with_small_bsz(batch_size,
small_bsz,
hidden_size,
seq_len,
heads,
num_layers,
is_preln,
use_fp16):
# Only run fp16 test cases on devices with 7+ capability.
major, _ = torch.cuda.get_device_capability()
if major < 7 and use_fp16 is True:
return
ds_config = DeepSpeedTransformerConfig()
ds_config.layer_id = None
ds_config.batch_size = batch_size
ds_config.hidden_size = hidden_size
ds_config.intermediate_size = 4 * hidden_size
ds_config.max_seq_length = seq_len
ds_config.heads = heads
ds_config.attn_dropout_ratio = 0.0
ds_config.hidden_dropout_ratio = 0.0
ds_config.num_hidden_layers = num_layers
ds_config.pre_layer_norm = is_preln
ds_config.initializer_range = 0.02
ds_config.fp16 = use_fp16
run_forward(ds_config, seq_len, atol=2e-2, test_bsz=small_bsz)
@pytest.mark.parametrize('batch_size, hidden_size, seq_len, heads, num_layers, is_preln, use_fp16',
[
(64,1024,128,16,3,True,False),
(64,1024,128,16,3,True,True),
(64,1024,128,16,3,False,False),
(64,1024,128,16,3,False,True),
]) # yapf: disable
def test_forward_stochastic(batch_size,
hidden_size,
seq_len,
heads,
num_layers,
is_preln,
use_fp16):
# Only run fp16 test cases on devices with 7+ capability.
major, _ = torch.cuda.get_device_capability()
if major < 7 and use_fp16 is True:
return
ds_config = DeepSpeedTransformerConfig()
ds_config.layer_id = None
ds_config.batch_size = batch_size
ds_config.hidden_size = hidden_size
ds_config.intermediate_size = 4 * hidden_size
ds_config.max_seq_length = seq_len
ds_config.heads = heads
ds_config.attn_dropout_ratio = 0.0
ds_config.hidden_dropout_ratio = 0.0
ds_config.num_hidden_layers = num_layers
ds_config.pre_layer_norm = is_preln
ds_config.initializer_range = 0.02
ds_config.fp16 = use_fp16
ds_config.stochastic_mode = True
run_forward(ds_config, seq_len, atol=7e-2)
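# A closing sketch (not collected by pytest) of the behavior this change enables:
# a single pair of encoders built for max_seq_length=128 is fed batches of several
# shorter runtime lengths, reusing the helpers defined above. The chosen lengths and
# sizes are illustrative values, not additional test requirements.
def _dynamic_seq_len_sketch():
    ds_config = DeepSpeedTransformerConfig()
    ds_config.layer_id = None
    ds_config.batch_size = 8
    ds_config.hidden_size = 1024
    ds_config.intermediate_size = 4 * 1024
    ds_config.max_seq_length = 128                  # workspace allocated once
    ds_config.heads = 16
    ds_config.attn_dropout_ratio = 0.0
    ds_config.hidden_dropout_ratio = 0.0
    ds_config.num_hidden_layers = 3
    ds_config.pre_layer_norm = True
    ds_config.initializer_range = 0.02
    ds_config.fp16 = False
    set_seed(123)
    bert_encoder, ds_encoder = create_models(ds_config)
    for seq_len in (24, 56, 120, 128):              # runtime length varies call to call
        hidden_states = torch.randn(ds_config.batch_size, seq_len,
                                    ds_config.hidden_size, **kwargs_fp32)
        input_mask = torch.randn(ds_config.batch_size, 1, 1, seq_len, **kwargs_fp32)
        base_results = bert_encoder(hidden_states, input_mask,
                                    output_all_encoded_layers=False,
                                    checkpoint_activations=False)
        ds_results = ds_encoder(hidden_states, input_mask,
                                output_all_encoded_layers=False,
                                checkpoint_activations=False)
        check_equal(base_results, ds_results, atol=2e-2)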