"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "beb848e2b6cc888bd5039e6f6cac7c932c6c3225"
Unverified commit f0f2a702, authored by RezaYazdaniAminabadi, committed by GitHub

support dynamic sequence length in transformer kernels (#424)


Co-authored-by: Conglong Li <conglong.li@gmail.com>
parent 71f7df39
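
The heart of the change is on the host side: instead of baking the sequence length into every kernel configuration at construction time, the layer is re-pointed at the incoming tensor's sequence dimension on each call. A minimal sketch of that pattern (it mirrors ds_transformer_forward further down; `layer`, `input` and `bsz` are the binding-side names used there):

    // Sketch only: reconfigure the fused layer when the sequence length changes.
    int seq_len = layer->GetSeqLength();
    if (input.size(1) != seq_len) {
        seq_len = input.size(1);
        layer->SetSeqLength(seq_len, bsz);  // reshapes softmax, dropout and the attention GEMMs
    }
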
@@ -29,7 +29,7 @@
 for (size_t j = blockIdx.y * blockDim.y + threadIdx.y; j < (m); j += blockDim.y * gridDim.y)
 #define DS_CUDA_NUM_THREADS 512
-#define DS_MAXIMUM_NUM_BLOCKS 4096
+#define DS_MAXIMUM_NUM_BLOCKS 262144
 inline int DS_GET_BLOCKS(const int N)
 {
......
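
Raising DS_MAXIMUM_NUM_BLOCKS matters because DS_GET_BLOCKS clamps launch grids to it, and with row counts now equal to batch * sequence length the old 4096-block cap can become the binding constraint. The function body is cut off in this hunk, so the following self-contained sketch assumes the usual round-up-then-clamp formula:

    #include <algorithm>

    #define DS_CUDA_NUM_THREADS 512
    #define DS_MAXIMUM_NUM_BLOCKS 262144

    // Assumed body: one block per DS_CUDA_NUM_THREADS elements, clamped to the maximum.
    inline int DS_GET_BLOCKS(const int N)
    {
        return std::max(
            std::min((N + DS_CUDA_NUM_THREADS - 1) / DS_CUDA_NUM_THREADS, DS_MAXIMUM_NUM_BLOCKS),
            1);
    }
    // Old cap:  4096 * 512  ~  2.1M elements per launch before grid-stride loops take over.
    // New cap: 262144 * 512 ~ 134M elements per launch.
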
...@@ -29,7 +29,6 @@ void launch_bias_gelu(const T* input, ...@@ -29,7 +29,6 @@ void launch_bias_gelu(const T* input,
T* output, T* output,
int intermediate_size, int intermediate_size,
int batch_size, int batch_size,
int sequence_length,
cudaStream_t stream); cudaStream_t stream);
template <typename T> template <typename T>
...@@ -37,7 +36,6 @@ void launch_gelu(const T* input, ...@@ -37,7 +36,6 @@ void launch_gelu(const T* input,
T* output, T* output,
int intermediate_size, int intermediate_size,
int batch_size, int batch_size,
int sequence_length,
cudaStream_t stream); cudaStream_t stream);
template <typename T> template <typename T>
...@@ -46,7 +44,6 @@ void launch_d_gelu(T* d_output, ...@@ -46,7 +44,6 @@ void launch_d_gelu(T* d_output,
const T* bias, const T* bias,
int intermediate_size, int intermediate_size,
int batch_size, int batch_size,
int sequence_length,
cudaStream_t stream); cudaStream_t stream);
// Custom fused bias add with layer normalization // Custom fused bias add with layer normalization
...@@ -57,14 +54,12 @@ void launch_bias_residual_layer_norm(T* vals, ...@@ -57,14 +54,12 @@ void launch_bias_residual_layer_norm(T* vals,
const T* beta, const T* beta,
float epsilon, float epsilon,
int batch_size, int batch_size,
int sequence_length,
int hidden_dim, int hidden_dim,
cudaStream_t stream, cudaStream_t stream,
bool preLayerNorm, bool preLayerNorm,
bool training = false, bool training,
T* vars = nullptr, T* vars,
T* means = nullptr, T* means);
T* vals_hat = nullptr);
template <typename T> template <typename T>
void launch_bias_residual_layer_norm(T* vals, void launch_bias_residual_layer_norm(T* vals,
...@@ -73,14 +68,11 @@ void launch_bias_residual_layer_norm(T* vals, ...@@ -73,14 +68,11 @@ void launch_bias_residual_layer_norm(T* vals,
const T* beta, const T* beta,
float epsilon, float epsilon,
int batch_size, int batch_size,
int sequence_length,
int hidden_dim, int hidden_dim,
cudaStream_t stream, cudaStream_t stream,
bool preLayerNorm, bool preLayerNorm,
bool training = false, bool training,
T* vars = nullptr, T* vars);
T* vals_hat = nullptr,
bool save_vals = false);
template <typename T> template <typename T>
void launch_layerNorm_backward_fused_add(const T* out_grad1, void launch_layerNorm_backward_fused_add(const T* out_grad1,
...@@ -93,7 +85,6 @@ void launch_layerNorm_backward_fused_add(const T* out_grad1, ...@@ -93,7 +85,6 @@ void launch_layerNorm_backward_fused_add(const T* out_grad1,
T* betta_grad, T* betta_grad,
T* inp_grad, T* inp_grad,
int batch_size, int batch_size,
int sequence_length,
int hidden_dim, int hidden_dim,
cudaStream_t stream[2]); cudaStream_t stream[2]);
template <typename T> template <typename T>
...@@ -106,7 +97,6 @@ void launch_layerNorm_backward_fused_add(const T* out_grad1, ...@@ -106,7 +97,6 @@ void launch_layerNorm_backward_fused_add(const T* out_grad1,
T* betta_grad, T* betta_grad,
T* inp_grad, T* inp_grad,
int batch_size, int batch_size,
int sequence_length,
int hidden_dim, int hidden_dim,
cudaStream_t stream[2], cudaStream_t stream[2],
bool invertible = false, bool invertible = false,
...@@ -122,7 +112,6 @@ void launch_layerNorm_backward(const T* out_grad, ...@@ -122,7 +112,6 @@ void launch_layerNorm_backward(const T* out_grad,
T* betta_grad, T* betta_grad,
T* inp_grad, T* inp_grad,
int batch_size, int batch_size,
int sequence_length,
int hidden_dim, int hidden_dim,
cudaStream_t stream[2]); cudaStream_t stream[2]);
...@@ -135,7 +124,6 @@ void launch_layerNorm_backward(const T* out_grad, ...@@ -135,7 +124,6 @@ void launch_layerNorm_backward(const T* out_grad,
T* betta_grad, T* betta_grad,
T* inp_grad, T* inp_grad,
int batch_size, int batch_size,
int sequence_length,
int hidden_dim, int hidden_dim,
cudaStream_t stream[2], cudaStream_t stream[2],
bool invertible = false, bool invertible = false,
...@@ -153,7 +141,6 @@ void launch_layerNorm_backward_nreversible(const T* out_grad, ...@@ -153,7 +141,6 @@ void launch_layerNorm_backward_nreversible(const T* out_grad,
T* betta_grad, T* betta_grad,
T* inp_grad, T* inp_grad,
int batch_size, int batch_size,
int sequence_length,
int hidden_dim, int hidden_dim,
cudaStream_t stream[2]); cudaStream_t stream[2]);
......
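
All of these launchers drop their sequence_length parameter; the batch_size argument is now expected to already be the flattened row count (batch * sequence length), which is what lets one compiled kernel serve any sequence length. A hedged usage sketch (the device pointers and the stream are assumed to be set up elsewhere):

    // Sketch: callers fold batch and sequence length into one row count.
    int bsz_seq = batch_size * seq_len;  // rows of the [rows, intermediate_size] activation
    launch_bias_gelu<float>(input_buf, bias, output, intermediate_size, bsz_seq, stream);
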
@@ -9,15 +9,13 @@ class Dropout {
 public:
     struct Config {
         float ratio;
-        uint32_t batch, dim;
+        uint32_t dim;
         bool training;
-        Config(float r, uint32_t batch, uint32_t dim)
-            : ratio(r), batch(batch), dim(dim), training(true)
-        {
-        }
+        Config(float r, uint32_t d) : ratio(r), dim(d), training(true) {}
         float RATIO() const { return training ? ratio : 0.0; }
+        inline void SetDim(uint32_t d) { dim = d; }
     };
     Dropout(const Config& config) : _config(config), _mask(nullptr) {}
@@ -70,6 +68,8 @@ public:
     Config GetConfig() const { return _config; }
+    inline void SetDimension(uint32_t dim) { _config.SetDim(dim); }
 private:
     uint8_t* _mask;
     Config _config;
......
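
With the batch dimension gone from Dropout::Config, only the per-row width is fixed up front, and it can be retargeted at runtime. A small sketch of the intended use (the ratio and names are illustrative):

    // Sketch: construct once, then retarget the row width when the sequence length changes.
    Dropout<float>::Config drop_cfg(0.1f, seq_length);  // ratio, per-row dimension
    Dropout<float> attn_prob_dropout(drop_cfg);
    attn_prob_dropout.SetDimension(new_seq_length);      // forwards to Config::SetDim
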
...@@ -121,11 +121,17 @@ public: ...@@ -121,11 +121,17 @@ public:
void SetIntermediateBuffers(uint8_t* attn_prob_dropout_mask_ptr, void SetIntermediateBuffers(uint8_t* attn_prob_dropout_mask_ptr,
uint8_t* attn_output_dropout_mask_ptr, uint8_t* attn_output_dropout_mask_ptr,
uint8_t* layer_output_dropout_mask_ptr); uint8_t* layer_output_dropout_mask_ptr,
T* layer_norm_var,
T* layer_norm_mean,
T* attn_layer_norm_var,
T* attn_layer_norm_mean);
inline int GetBatchSize() const { return _batch_size; } inline int GetBatchSize() const { return _batch_size; }
inline int GetNumHeads() const { return _heads; } inline int GetNumHeads() const { return _heads; }
inline int GetSeqLength() const { return _seq_length; } inline int GetSeqLength() const { return _seq_length; }
void SetSeqLength(int seq_len, int bsz);
inline int GetHiddenSize() const { return _hidden_size; } inline int GetHiddenSize() const { return _hidden_size; }
void SetTrainingMode(bool training); void SetTrainingMode(bool training);
...@@ -150,8 +156,8 @@ private: ...@@ -150,8 +156,8 @@ private:
// layers // layers
FeedForward<T> _qkv_linear; FeedForward<T> _qkv_linear;
FeedForward<T> _attn_out_linear; FeedForward<T> _attn_out_linear;
Normalize_Layer<T> _norm_layer2; Normalize_Layer<T> _attn_layer_norm;
Normalize_Layer<T> _norm_layer3; Normalize_Layer<T> _layer_norm;
Normalize_Layer<T>* _last_normalize; Normalize_Layer<T>* _last_normalize;
FeedForward<T> _ff1, _ff2; FeedForward<T> _ff1, _ff2;
Softmax<T> _softmax; Softmax<T> _softmax;
......
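
On the host side the renamed norm layers no longer own their saved statistics; the binding hands the layer externally allocated variance/mean buffers together with the dropout masks. A sketch of the widened call (pointer names are illustrative; the real call sites appear in ds_transformer_cuda.cpp further down):

    // Sketch: per-call buffers for the masks and the layer-norm statistics.
    layer->SetIntermediateBuffers(attn_prob_mask_ptr,
                                  attn_out_mask_ptr,
                                  layer_out_mask_ptr,
                                  attn_ln_var_ptr,
                                  attn_ln_mean_ptr,
                                  ln_var_ptr,
                                  ln_mean_ptr);
    layer->SetSeqLength(seq_len, bsz);  // the new reconfiguration hook declared above
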
@@ -9,13 +9,8 @@ template <typename T>
 class Gelu {
 public:
     struct Config {
-        uint32_t batch_size;
-        uint32_t seq_length;
         uint32_t intermediate_size;
-        Config(uint32_t batch, uint32_t seq, uint32_t inter_size)
-            : batch_size(batch), seq_length(seq), intermediate_size(inter_size)
-        {
-        }
+        Config(uint32_t inter_size) : intermediate_size(inter_size) {}
     };
     Gelu(const Config& config) : _config(config) {}
@@ -28,14 +23,12 @@ public:
                             T* output,
                             cudaStream_t stream)
     {
-        launch_bias_gelu<T>(
-            input_buf, bias, output, _config.intermediate_size, bsz, _config.seq_length, stream);
+        launch_bias_gelu<T>(input_buf, bias, output, _config.intermediate_size, bsz, stream);
     }
     void Backward(int bsz, T* d_output, const T* input_buf, const T* bias, cudaStream_t stream)
     {
-        launch_d_gelu<T>(
-            d_output, input_buf, bias, _config.intermediate_size, bsz, _config.seq_length, stream);
+        launch_d_gelu<T>(d_output, input_buf, bias, _config.intermediate_size, bsz, stream);
     }
 private:
......
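
The Gelu block is now shape-agnostic: only the intermediate size is baked in, and each call carries the flattened row count. A usage sketch (buffer names are illustrative; the argument order follows the calls in ds_transformer_cuda.cpp below):

    // Sketch: batch * seq rows of width intermediate_size per call.
    Gelu<float> gelu(Gelu<float>::Config(intermediate_size));
    int bsz_seq = batch_size * seq_len;
    gelu.ForwardWithBiasAdd(bsz_seq, ff1_out, inter_bias, gelu_out, stream);
    gelu.Backward(bsz_seq, grad_buf, ff1_out, inter_bias, stream);
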
...@@ -16,57 +16,27 @@ public: ...@@ -16,57 +16,27 @@ public:
uint32_t seqLength; uint32_t seqLength;
uint32_t hiddenDim; uint32_t hiddenDim;
float epsilon; float epsilon;
bool training, save_vals; bool training;
bool allocateGrad;
bool useMean; bool useMean;
Config(uint32_t batch, Config(uint32_t batch, uint32_t seq, uint32_t h, bool training, bool useMean = true)
uint32_t seq,
uint32_t h,
bool training,
bool save_vals = true,
bool allocateGrad = true,
bool useMean = true)
: batchSize(batch), : batchSize(batch),
seqLength(seq), seqLength(seq),
hiddenDim(h), hiddenDim(h),
epsilon(1e-12), epsilon(1e-12),
training(training), training(training),
save_vals(save_vals),
allocateGrad(allocateGrad),
useMean(useMean) useMean(useMean)
{ {
} }
}; };
Normalize_Layer(Config config) : config_(config), vars(nullptr), vals_hat(nullptr) Normalize_Layer(Config config)
: config_(config), vars(nullptr), means(nullptr), vals_hat(nullptr)
{ {
if (config_.training) {
cudaMalloc((void**)&vars, config_.batchSize * config_.seqLength * sizeof(T));
if (config_.useMean)
cudaMalloc((void**)&means, config_.batchSize * config_.seqLength * sizeof(T));
if (config_.save_vals)
cudaMalloc((void**)&vals_hat,
config_.batchSize * config_.seqLength * config_.hiddenDim * sizeof(T));
if (config_.allocateGrad)
cudaMalloc((void**)&inp_grad,
config_.batchSize * config_.seqLength * config_.hiddenDim * sizeof(T));
}
} }
~Normalize_Layer() ~Normalize_Layer() {}
{
if (config_.training) {
cudaFree(vars);
if (config_.useMean) cudaFree(means);
if (config_.save_vals) cudaFree(vals_hat);
if (config_.allocateGrad) cudaFree(inp_grad);
}
}
void ForwardCheckpoint(int bsz, void ForwardCheckpoint(int bsz, // batch * seq
T* vals, T* vals,
const T* residual, const T* residual,
const T* gamma, const T* gamma,
...@@ -80,14 +50,12 @@ public: ...@@ -80,14 +50,12 @@ public:
betta, betta,
config_.epsilon, config_.epsilon,
bsz, bsz,
config_.seqLength,
config_.hiddenDim, config_.hiddenDim,
stream, stream,
preLayerNorm, preLayerNorm,
config_.training, config_.training,
vars, vars,
means, means);
vals_hat);
} }
void Forward(int bsz, void Forward(int bsz,
...@@ -104,14 +72,11 @@ public: ...@@ -104,14 +72,11 @@ public:
betta, betta,
config_.epsilon, config_.epsilon,
bsz, bsz,
config_.seqLength,
config_.hiddenDim, config_.hiddenDim,
stream, stream,
preLayerNorm, preLayerNorm,
config_.training, config_.training,
vars, vars);
vals_hat,
config_.save_vals);
} }
void Backward(int bsz, void Backward(int bsz,
...@@ -120,7 +85,7 @@ public: ...@@ -120,7 +85,7 @@ public:
T* gamma_grad, T* gamma_grad,
T* betta_grad, T* betta_grad,
cudaStream_t stream[2], cudaStream_t stream[2],
T* inp_grad_out = nullptr, T* inp_grad_out,
const T* norm_in = nullptr) const T* norm_in = nullptr)
{ {
launch_layerNorm_backward(out_grad, launch_layerNorm_backward(out_grad,
...@@ -130,9 +95,8 @@ public: ...@@ -130,9 +95,8 @@ public:
gamma, gamma,
gamma_grad, gamma_grad,
betta_grad, betta_grad,
(config_.allocateGrad ? inp_grad : inp_grad_out), inp_grad_out,
bsz, bsz,
config_.seqLength,
config_.hiddenDim, config_.hiddenDim,
stream); stream);
} }
...@@ -144,21 +108,20 @@ public: ...@@ -144,21 +108,20 @@ public:
T* gamma_grad, T* gamma_grad,
T* betta_grad, T* betta_grad,
cudaStream_t stream[2], cudaStream_t stream[2],
T* inp_grad_out = nullptr, T* inp_grad_out,
const T* norm_out = nullptr) const T* norm_out)
{ {
launch_layerNorm_backward(out_grad, launch_layerNorm_backward(out_grad,
(config_.save_vals ? vals_hat : norm_out), norm_out,
vars, vars,
gamma, gamma,
gamma_grad, gamma_grad,
betta_grad, betta_grad,
(config_.allocateGrad ? inp_grad : inp_grad_out), inp_grad_out,
bsz, bsz,
config_.seqLength,
config_.hiddenDim, config_.hiddenDim,
stream, stream,
config_.save_vals, !config_.useMean,
betta); betta);
} }
...@@ -169,7 +132,7 @@ public: ...@@ -169,7 +132,7 @@ public:
T* gamma_grad, T* gamma_grad,
T* betta_grad, T* betta_grad,
cudaStream_t stream[2], cudaStream_t stream[2],
T* inp_grad_out = nullptr, T* inp_grad_out,
const T* norm_in = nullptr) const T* norm_in = nullptr)
{ {
launch_layerNorm_backward_fused_add(out_grad1, launch_layerNorm_backward_fused_add(out_grad1,
...@@ -180,9 +143,8 @@ public: ...@@ -180,9 +143,8 @@ public:
gamma, gamma,
gamma_grad, gamma_grad,
betta_grad, betta_grad,
(config_.allocateGrad ? inp_grad : inp_grad_out), inp_grad_out,
bsz, bsz,
config_.seqLength,
config_.hiddenDim, config_.hiddenDim,
stream); stream);
} }
...@@ -195,33 +157,41 @@ public: ...@@ -195,33 +157,41 @@ public:
T* gamma_grad, T* gamma_grad,
T* betta_grad, T* betta_grad,
cudaStream_t stream[2], cudaStream_t stream[2],
T* inp_grad_out = nullptr, T* inp_grad_out,
const T* norm_out = nullptr) const T* norm_out)
{ {
launch_layerNorm_backward_fused_add(out_grad1, launch_layerNorm_backward_fused_add(out_grad1,
out_grad2, out_grad2,
(config_.save_vals ? vals_hat : norm_out), norm_out,
vars, vars,
gamma, gamma,
gamma_grad, gamma_grad,
betta_grad, betta_grad,
(config_.allocateGrad ? inp_grad : inp_grad_out), inp_grad_out,
bsz, bsz,
config_.seqLength,
config_.hiddenDim, config_.hiddenDim,
stream, stream,
config_.save_vals, !config_.useMean,
betta); betta);
} }
inline T* GetInputGrad() const { return inp_grad; }
inline bool UseMean() const { return config_.useMean; } inline bool UseMean() const { return config_.useMean; }
inline void SetVar(T* variance)
{
if (!variance) { throw std::runtime_error("Normalize variance is null."); }
vars = variance;
}
inline void SetMean(T* mean)
{
if (!mean) { throw std::runtime_error("Normalize mean is null."); }
means = mean;
}
private: private:
Config config_; Config config_;
T* vars; T* vars;
T* means; T* means;
T* vals_hat; T* vals_hat;
T* inp_grad;
}; };
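
Since Normalize_Layer no longer cudaMallocs (or frees) its saved statistics, the caller attaches per-call buffers of shape [batch * seq] before running Forward or Backward. A sketch under those assumptions (buffer pointers are illustrative):

    // Sketch: the statistics buffers are owned by the caller, not the layer.
    Normalize_Layer<float> layer_norm(
        Normalize_Layer<float>::Config(batch_size, seq_length, hidden_dim, /*training=*/true));
    layer_norm.SetVar(var_buf);    // throws std::runtime_error on a null pointer
    layer_norm.SetMean(mean_buf);  // idem; only consumed when useMean is true
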
...@@ -45,13 +45,15 @@ public: ...@@ -45,13 +45,15 @@ public:
out_grad, soft_out, bsz, config_.heads, config_.seq_length, stream); out_grad, soft_out, bsz, config_.heads, config_.seq_length, stream);
} }
inline int GetProbDepth() const { return config_.prob_depth; } inline size_t GetProbDepth() const { return config_.prob_depth; }
inline int GetBatchSize() const { return config_.batchSize; } inline size_t GetBatchSize() const { return config_.batchSize; }
inline int GetNumHeads() const { return config_.heads; } inline size_t GetNumHeads() const { return config_.heads; }
inline int GetSeqLength() const { return config_.seq_length; } inline size_t GetSeqLength() const { return config_.seq_length; }
inline void SetSeqLength(size_t seq_len) { config_.seq_length = seq_len; }
private: private:
Config config_; Config config_;
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include <cuda.h> #include <cuda.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include <stdio.h> #include <stdio.h>
#include "context.h"
template <typename T> template <typename T>
class StridedBatchGemm { class StridedBatchGemm {
...@@ -38,6 +39,12 @@ public: ...@@ -38,6 +39,12 @@ public:
gemm_algos(algos) gemm_algos(algos)
{ {
} }
void SetConfig(int mm, int nn, int kk)
{
m = mm;
n = nn;
k = kk;
}
}; };
StridedBatchGemm(const Config& config) : _config(config) {} StridedBatchGemm(const Config& config) : _config(config) {}
...@@ -163,6 +170,8 @@ public: ...@@ -163,6 +170,8 @@ public:
inline const T* GetBufferB() const { return q_buf; } inline const T* GetBufferB() const { return q_buf; }
inline void SetConfig(int m, int n, int k) { _config.SetConfig(m, n, k); }
private: private:
Config _config; Config _config;
const T* q_buf; const T* q_buf;
......
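
SetConfig is what lets the attention GEMMs be reshaped in place when the sequence length changes; BertTransformerLayer<T>::SetSeqLength (further down) calls it for both batched products. A sketch of that use (head_dim stands for hidden_size / num_heads):

    // Sketch: re-shape the two strided-batch GEMMs for a new sequence length.
    int head_dim = hidden_size / num_heads;
    _attn_scores.SetConfig(seq_len, seq_len, head_dim);   // m = n = seq_len, k = head_dim
    _attn_context.SetConfig(head_dim, seq_len, seq_len);  // m = head_dim, n = k = seq_len
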
...@@ -34,7 +34,12 @@ int cublas_gemm_ex(cublasHandle_t handle, ...@@ -34,7 +34,12 @@ int cublas_gemm_ex(cublasHandle_t handle,
algo); algo);
if (status != CUBLAS_STATUS_SUCCESS) { if (status != CUBLAS_STATUS_SUCCESS) {
fprintf(stderr, "!!!! kernel execution error.\n"); fprintf(stderr,
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
m,
n,
k,
(int)status);
return EXIT_FAILURE; return EXIT_FAILURE;
} }
return 0; return 0;
...@@ -74,7 +79,12 @@ int cublas_gemm_ex(cublasHandle_t handle, ...@@ -74,7 +79,12 @@ int cublas_gemm_ex(cublasHandle_t handle,
algo); algo);
if (status != CUBLAS_STATUS_SUCCESS) { if (status != CUBLAS_STATUS_SUCCESS) {
fprintf(stderr, "!!!! kernel execution error.\n"); fprintf(stderr,
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
m,
n,
k,
(int)status);
return EXIT_FAILURE; return EXIT_FAILURE;
} }
return 0; return 0;
...@@ -122,7 +132,12 @@ int cublas_strided_batched_gemm(cublasHandle_t handle, ...@@ -122,7 +132,12 @@ int cublas_strided_batched_gemm(cublasHandle_t handle,
algo); algo);
if (status != CUBLAS_STATUS_SUCCESS) { if (status != CUBLAS_STATUS_SUCCESS) {
fprintf(stderr, "!!!! kernel execution error.\n"); fprintf(stderr,
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
m,
n,
k,
(int)status);
return EXIT_FAILURE; return EXIT_FAILURE;
} }
return 0; return 0;
...@@ -170,7 +185,12 @@ int cublas_strided_batched_gemm(cublasHandle_t handle, ...@@ -170,7 +185,12 @@ int cublas_strided_batched_gemm(cublasHandle_t handle,
algo); algo);
if (status != CUBLAS_STATUS_SUCCESS) { if (status != CUBLAS_STATUS_SUCCESS) {
fprintf(stderr, "!!!! kernel execution error.\n"); fprintf(stderr,
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
m,
n,
k,
(int)status);
return EXIT_FAILURE; return EXIT_FAILURE;
} }
......
@@ -78,20 +78,16 @@ BertTransformerLayer<T>::BertTransformerLayer(int layer_id,
                                      hidden_size,
                                      hidden_size,
                                      gemm_algos[0])),
-      _norm_layer2(typename Normalize_Layer<T>::Config(batch_size,
-                                                       seq_length,
-                                                       hidden_size,
-                                                       true,
-                                                       false,
-                                                       false,
-                                                       !normalize_invertible)),
-      _norm_layer3(typename Normalize_Layer<T>::Config(batch_size,
-                                                       seq_length,
-                                                       hidden_size,
-                                                       true,
-                                                       false,
-                                                       false,
-                                                       !normalize_invertible)),
+      _attn_layer_norm(typename Normalize_Layer<T>::Config(batch_size,
+                                                           seq_length,
+                                                           hidden_size,
+                                                           true,
+                                                           !normalize_invertible)),
+      _layer_norm(typename Normalize_Layer<T>::Config(batch_size,
+                                                      seq_length,
+                                                      hidden_size,
+                                                      true,
+                                                      !normalize_invertible)),
       _ff1(typename FeedForward<T>::Config(batch_size * seq_length,
                                            _intermediate_size,
                                            hidden_size,
@@ -101,16 +97,10 @@ BertTransformerLayer<T>::BertTransformerLayer(int layer_id,
                                            _intermediate_size,
                                            gemm_algos[2])),
       _softmax(typename Softmax<T>::Config(batch_size, num_heads, seq_length)),
-      _gelu(typename Gelu<T>::Config(_batch_size, _seq_length, _intermediate_size)),
-      _attn_prob_dropout(typename Dropout<T>::Config(attn_prob_dropout_ratio,
-                                                     _batch_size * _heads * _seq_length,
-                                                     _seq_length)),
-      _attn_output_dropout(typename Dropout<T>::Config(hidden_output_dropout_ratio,
-                                                       _batch_size * _seq_length,
-                                                       _hidden_size)),
-      _layer_output_dropout(typename Dropout<T>::Config(hidden_output_dropout_ratio,
-                                                        _batch_size * _seq_length,
-                                                        _hidden_size)),
+      _gelu(typename Gelu<T>::Config(_intermediate_size)),
+      _attn_prob_dropout(typename Dropout<T>::Config(attn_prob_dropout_ratio, _seq_length)),
+      _attn_output_dropout(typename Dropout<T>::Config(hidden_output_dropout_ratio, _hidden_size)),
+      _layer_output_dropout(typename Dropout<T>::Config(hidden_output_dropout_ratio, _hidden_size)),
       _attn_scores(typename StridedBatchGemm<T>::Config(_batch_size * _heads,
                                                         _seq_length,
                                                         _seq_length,
...@@ -196,18 +186,18 @@ void BertTransformerLayer<T>::Forward(int bsz, ...@@ -196,18 +186,18 @@ void BertTransformerLayer<T>::Forward(int bsz,
if (_normalize_invertible) add_res_ptr = buf_1 + 3 * small_buf_size; if (_normalize_invertible) add_res_ptr = buf_1 + 3 * small_buf_size;
if (_attn_dropout_checkpoint) ctx_bufB_ptr = buf_1 + 4 * small_buf_size; if (_attn_dropout_checkpoint) ctx_bufB_ptr = buf_1 + 4 * small_buf_size;
int bsz_seq = bsz * _seq_length;
if (_pre_or_postLayerNorm) { if (_pre_or_postLayerNorm) {
if (_norm_layer3.UseMean()) if (_layer_norm.UseMean())
_norm_layer3.ForwardCheckpoint( _layer_norm.ForwardCheckpoint(
bsz, inp_norm_ptr, input_ptr, norm_w_ptr, norm_b_ptr, _stream, true); bsz_seq, inp_norm_ptr, input_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
else else
_norm_layer3.Forward( _layer_norm.Forward(
bsz, inp_norm_ptr, input_ptr, norm_w_ptr, norm_b_ptr, _stream, true); bsz_seq, inp_norm_ptr, input_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
} }
int bsz_seq = bsz * _seq_length;
if (_pre_or_postLayerNorm) if (_pre_or_postLayerNorm)
_qkv_linear.Forward(bsz_seq, inp_norm_ptr, attn_qkvw_ptr, buf_0, _cublasHandle); _qkv_linear.Forward(bsz_seq, inp_norm_ptr, attn_qkvw_ptr, buf_0, _cublasHandle);
else else
...@@ -247,19 +237,19 @@ void BertTransformerLayer<T>::Forward(int bsz, ...@@ -247,19 +237,19 @@ void BertTransformerLayer<T>::Forward(int bsz,
bsz_seq, add_res_ptr, ff1_inp_ptr, input_ptr, attn_ob_ptr, _stream); bsz_seq, add_res_ptr, ff1_inp_ptr, input_ptr, attn_ob_ptr, _stream);
if (_pre_or_postLayerNorm) { if (_pre_or_postLayerNorm) {
if (_norm_layer2.UseMean()) if (_attn_layer_norm.UseMean())
_norm_layer2.ForwardCheckpoint( _attn_layer_norm.ForwardCheckpoint(
bsz, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true); bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
else else
_norm_layer2.Forward( _attn_layer_norm.Forward(
bsz, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true); bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
} else { } else {
if (_norm_layer2.UseMean()) if (_attn_layer_norm.UseMean())
_norm_layer2.ForwardCheckpoint( _attn_layer_norm.ForwardCheckpoint(
bsz, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true); bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
else else
_norm_layer2.Forward( _attn_layer_norm.Forward(
bsz, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true); bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
} }
_ff1.Forward(bsz_seq, _ff1.Forward(bsz_seq,
...@@ -268,7 +258,7 @@ void BertTransformerLayer<T>::Forward(int bsz, ...@@ -268,7 +258,7 @@ void BertTransformerLayer<T>::Forward(int bsz,
(_gelu_checkpoint ? ff2_inp_ptr : gelu_inp_ptr), (_gelu_checkpoint ? ff2_inp_ptr : gelu_inp_ptr),
_cublasHandle); _cublasHandle);
_gelu.ForwardWithBiasAdd(bsz, _gelu.ForwardWithBiasAdd(bsz_seq,
(_gelu_checkpoint ? ff2_inp_ptr : gelu_inp_ptr), (_gelu_checkpoint ? ff2_inp_ptr : gelu_inp_ptr),
inter_b_ptr, inter_b_ptr,
(_gelu_checkpoint ? ctx_bufB_ptr : ff2_inp_ptr), (_gelu_checkpoint ? ctx_bufB_ptr : ff2_inp_ptr),
...@@ -289,11 +279,12 @@ void BertTransformerLayer<T>::Forward(int bsz, ...@@ -289,11 +279,12 @@ void BertTransformerLayer<T>::Forward(int bsz,
bsz_seq, inp_norm_ptr, out_ptr, ff1_inp_ptr, output_b_ptr, _stream); bsz_seq, inp_norm_ptr, out_ptr, ff1_inp_ptr, output_b_ptr, _stream);
if (!_pre_or_postLayerNorm) { if (!_pre_or_postLayerNorm) {
if (_norm_layer3.UseMean()) if (_layer_norm.UseMean())
_norm_layer3.ForwardCheckpoint( _layer_norm.ForwardCheckpoint(
bsz, out_ptr, inp_norm_ptr, norm_w_ptr, norm_b_ptr, _stream, true); bsz_seq, out_ptr, inp_norm_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
else else
_norm_layer3.Forward(bsz, out_ptr, inp_norm_ptr, norm_w_ptr, norm_b_ptr, _stream, true); _layer_norm.Forward(
bsz_seq, out_ptr, inp_norm_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
} }
} }
...@@ -359,26 +350,26 @@ void BertTransformerLayer<T>::Backward(int bsz, ...@@ -359,26 +350,26 @@ void BertTransformerLayer<T>::Backward(int bsz,
int bsz_heads = bsz * _heads; int bsz_heads = bsz * _heads;
if (!_pre_or_postLayerNorm) { if (!_pre_or_postLayerNorm) {
if (_norm_layer3.UseMean()) if (_layer_norm.UseMean())
_norm_layer3.Backward(bsz, _layer_norm.Backward(bsz_seq,
grad_output_ptr, grad_output_ptr,
norm_w_ptr, norm_w_ptr,
grad_norm_w_ptr, grad_norm_w_ptr,
grad_norm_b_ptr, grad_norm_b_ptr,
streams, streams,
buf_1, buf_1,
inp_norm_ptr); inp_norm_ptr);
else else
_norm_layer3.Backward(bsz, _layer_norm.Backward(bsz_seq,
grad_output_ptr, grad_output_ptr,
norm_w_ptr, norm_w_ptr,
norm_b_ptr, norm_b_ptr,
grad_norm_w_ptr, grad_norm_w_ptr,
grad_norm_b_ptr, grad_norm_b_ptr,
streams, streams,
buf_1, buf_1,
output_ptr); output_ptr);
} }
if (_pre_or_postLayerNorm) if (_pre_or_postLayerNorm)
...@@ -390,7 +381,8 @@ void BertTransformerLayer<T>::Backward(int bsz, ...@@ -390,7 +381,8 @@ void BertTransformerLayer<T>::Backward(int bsz,
? buf_0 ? buf_0
: (_pre_or_postLayerNorm ? grad_output_ptr : buf_1); : (_pre_or_postLayerNorm ? grad_output_ptr : buf_1);
if (_gelu_checkpoint) _gelu.ForwardWithBiasAdd(bsz, ff2_inp_ptr, inter_b_ptr, buf_2, _stream); if (_gelu_checkpoint)
_gelu.ForwardWithBiasAdd(bsz_seq, ff2_inp_ptr, inter_b_ptr, buf_2, _stream);
_ff2.Backward(bsz_seq, _ff2.Backward(bsz_seq,
layer_dropout_buf, layer_dropout_buf,
(_gelu_checkpoint ? buf_2 : ff2_inp_ptr), (_gelu_checkpoint ? buf_2 : ff2_inp_ptr),
...@@ -402,7 +394,7 @@ void BertTransformerLayer<T>::Backward(int bsz, ...@@ -402,7 +394,7 @@ void BertTransformerLayer<T>::Backward(int bsz,
ff2_buf); ff2_buf);
_gelu.Backward( _gelu.Backward(
bsz, ff2_buf, (_gelu_checkpoint ? ff2_inp_ptr : gelu_inp_ptr), inter_b_ptr, _stream); bsz_seq, ff2_buf, (_gelu_checkpoint ? ff2_inp_ptr : gelu_inp_ptr), inter_b_ptr, _stream);
_ff1.Backward(bsz_seq, _ff1.Backward(bsz_seq,
ff2_buf, ff2_buf,
...@@ -418,49 +410,49 @@ void BertTransformerLayer<T>::Backward(int bsz, ...@@ -418,49 +410,49 @@ void BertTransformerLayer<T>::Backward(int bsz,
launch_fused_add2<T>(buf_2, buf_3, buf_1, bsz, _seq_length, _hidden_size, _stream); launch_fused_add2<T>(buf_2, buf_3, buf_1, bsz, _seq_length, _hidden_size, _stream);
if (_pre_or_postLayerNorm) { if (_pre_or_postLayerNorm) {
if (_norm_layer2.UseMean()) if (_attn_layer_norm.UseMean())
_norm_layer2.BackwardFusedAdd(bsz, _attn_layer_norm.BackwardFusedAdd(bsz_seq,
buf_3, buf_3,
grad_output_ptr, grad_output_ptr,
attn_nw_ptr, attn_nw_ptr,
grad_attn_nw_ptr, grad_attn_nw_ptr,
grad_attn_nb_ptr, grad_attn_nb_ptr,
streams, streams,
buf_0, buf_0,
add_res_ptr); add_res_ptr);
else else
_norm_layer2.BackwardFusedAdd(bsz, _attn_layer_norm.BackwardFusedAdd(bsz_seq,
buf_3, buf_3,
grad_output_ptr, grad_output_ptr,
attn_nw_ptr, attn_nw_ptr,
attn_nb_ptr, attn_nb_ptr,
grad_attn_nw_ptr, grad_attn_nw_ptr,
grad_attn_nb_ptr, grad_attn_nb_ptr,
streams, streams,
buf_0, buf_0,
ff1_inp_ptr); ff1_inp_ptr);
} else { } else {
if (_norm_layer2.UseMean()) if (_attn_layer_norm.UseMean())
_norm_layer2.Backward(bsz, _attn_layer_norm.Backward(bsz_seq,
buf_2, buf_2,
attn_nw_ptr, attn_nw_ptr,
grad_attn_nw_ptr, grad_attn_nw_ptr,
grad_attn_nb_ptr, grad_attn_nb_ptr,
streams, streams,
buf_0, buf_0,
add_res_ptr); add_res_ptr);
else else
_norm_layer2.Backward(bsz, _attn_layer_norm.Backward(bsz_seq,
buf_2, buf_2,
attn_nw_ptr, attn_nw_ptr,
attn_nb_ptr, attn_nb_ptr,
grad_attn_nw_ptr, grad_attn_nw_ptr,
grad_attn_nb_ptr, grad_attn_nb_ptr,
streams, streams,
buf_0, buf_0,
ff1_inp_ptr); ff1_inp_ptr);
} }
_attn_output_dropout.Backward(bsz_seq, buf_2, buf_0, _stream); _attn_output_dropout.Backward(bsz_seq, buf_2, buf_0, _stream);
...@@ -525,28 +517,28 @@ void BertTransformerLayer<T>::Backward(int bsz, ...@@ -525,28 +517,28 @@ void BertTransformerLayer<T>::Backward(int bsz,
buf_2); buf_2);
if (_pre_or_postLayerNorm) { if (_pre_or_postLayerNorm) {
if (_norm_layer3.UseMean()) if (_layer_norm.UseMean())
_norm_layer3.BackwardFusedAdd(bsz, _layer_norm.BackwardFusedAdd(bsz_seq,
buf_2, buf_2,
buf_0, buf_0,
norm_w_ptr, norm_w_ptr,
grad_norm_w_ptr, grad_norm_w_ptr,
grad_norm_b_ptr, grad_norm_b_ptr,
streams, streams,
grad_input_ptr, grad_input_ptr,
input_ptr); input_ptr);
else else
_norm_layer3.BackwardFusedAdd(bsz, _layer_norm.BackwardFusedAdd(bsz_seq,
buf_2, buf_2,
buf_0, buf_0,
norm_w_ptr, norm_w_ptr,
norm_b_ptr, norm_b_ptr,
grad_norm_w_ptr, grad_norm_w_ptr,
grad_norm_b_ptr, grad_norm_b_ptr,
streams, streams,
grad_input_ptr, grad_input_ptr,
inp_norm_ptr); inp_norm_ptr);
} else } else
launch_fused_add2<T>(grad_input_ptr, buf_2, buf_0, bsz, _seq_length, _hidden_size, _stream); launch_fused_add2<T>(grad_input_ptr, buf_2, buf_0, bsz, _seq_length, _hidden_size, _stream);
} }
@@ -563,11 +555,34 @@ void BertTransformerLayer<T>::SetTrainingMode(bool training)
 template <typename T>
 void BertTransformerLayer<T>::SetIntermediateBuffers(uint8_t* attn_prob_dropout_mask_ptr,
                                                      uint8_t* attn_output_dropout_mask_ptr,
-                                                     uint8_t* layer_output_dropout_mask_ptr)
+                                                     uint8_t* layer_output_dropout_mask_ptr,
+                                                     T* attn_layer_norm_var,
+                                                     T* attn_layer_norm_mean,
+                                                     T* layer_norm_var,
+                                                     T* layer_norm_mean)
 {
     _attn_prob_dropout.SetMask(attn_prob_dropout_mask_ptr);
     _attn_output_dropout.SetMask(attn_output_dropout_mask_ptr);
     _layer_output_dropout.SetMask(layer_output_dropout_mask_ptr);
+    _attn_layer_norm.SetVar(attn_layer_norm_var);
+    _attn_layer_norm.SetMean(attn_layer_norm_mean);
+    _layer_norm.SetVar(layer_norm_var);
+    _layer_norm.SetMean(layer_norm_mean);
+}
+
+template <typename T>
+void BertTransformerLayer<T>::SetSeqLength(int seq_len, int bsz)
+{
+    _seq_length = seq_len;
+    _softmax.SetSeqLength(_seq_length);
+    _attn_prob_dropout.SetDimension(_seq_length);
+    _attn_scores.SetConfig(_seq_length, _seq_length, _hidden_size / _heads);
+    _attn_context.SetConfig(_hidden_size / _heads, _seq_length, _seq_length);
+    Context::Instance().GenWorkSpace(get_workspace_size<T>(
+        bsz, _seq_length, _hidden_size, _intermediate_size, _heads, _training, _gelu_checkpoint));
 }
 template <typename T>
@@ -688,54 +703,61 @@ std::vector<torch::Tensor> ds_transformer_forward(int layer_id,
     std::shared_ptr<BertTransformerLayer<T>> layer =
         std::static_pointer_cast<BertTransformerLayer<T>>(s_transformer_layers[layer_id]);
+    int seq_len = layer->GetSeqLength();
+    if (input.size(1) != seq_len) {
+        seq_len = input.size(1);
+        layer->SetSeqLength(seq_len, bsz);
+    }
     auto inp_norm = ((prelayernorm || !normalize_invertible) ? torch::empty_like(input) : output);
     auto add_res = (normalize_invertible ? inp_norm : torch::empty_like(input));
     auto attn_o_inp = torch::empty_like(input);
-    auto qkv_tf = torch::empty({(bsz * layer->GetSeqLength()), output_w.size(0) * 3}, options);
+    auto qkv_tf = torch::empty({(bsz * seq_len), output_w.size(0) * 3}, options);
     auto attn_prob_dropout_mask =
-        torch::empty({(bsz * layer->GetNumHeads() * layer->GetSeqLength()), layer->GetSeqLength()},
-                     uint8_options);
+        torch::empty({(bsz * layer->GetNumHeads() * seq_len), seq_len}, uint8_options);
     auto attn_output_dropout_mask =
-        torch::empty({(bsz * layer->GetSeqLength()), layer->GetHiddenSize()}, uint8_options);
+        torch::empty({(bsz * seq_len), layer->GetHiddenSize()}, uint8_options);
     auto layer_output_dropout_mask =
-        torch::empty({(bsz * layer->GetSeqLength()), layer->GetHiddenSize()}, uint8_options);
+        torch::empty({(bsz * seq_len), layer->GetHiddenSize()}, uint8_options);
+    auto attn_layer_norm_var = torch::empty({(bsz * seq_len)}, options);
+    auto attn_layer_norm_mean = torch::empty({(bsz * seq_len)}, options);
+    auto layer_norm_var = torch::empty({(bsz * seq_len)}, options);
+    auto layer_norm_mean = torch::empty({(bsz * seq_len)}, options);
     T* inp_norm_ptr = (T*)inp_norm.data_ptr();
     T* add_res_ptr = (T*)add_res.data_ptr();
     T* q_tf_ptr = (T*)qkv_tf.data_ptr();
-    T* k_tf_ptr =
-        q_tf_ptr + (bsz * layer->GetSeqLength() * output_w.size(0));  //(T*)k_tf.data_ptr();
-    T* v_tf_ptr =
-        k_tf_ptr + (bsz * layer->GetSeqLength() * output_w.size(0));  //(T*)v_tf.data_ptr();
+    T* k_tf_ptr = q_tf_ptr + (bsz * seq_len * output_w.size(0));  //(T*)k_tf.data_ptr();
+    T* v_tf_ptr = k_tf_ptr + (bsz * seq_len * output_w.size(0));  //(T*)v_tf.data_ptr();
     T* attn_o_inp_ptr = (T*)attn_o_inp.data_ptr();
-    torch::Tensor ff2_inp =
-        torch::empty({(bsz * layer->GetSeqLength()), output_w.size(1)}, options);
+    torch::Tensor ff2_inp = torch::empty({(bsz * seq_len), output_w.size(1)}, options);
     torch::Tensor gelu_inp =
-        (gelu_checkpoint
-             ? ff2_inp
-             : torch::empty({(bsz * layer->GetSeqLength()), output_w.size(1)}, options));
+        (gelu_checkpoint ? ff2_inp : torch::empty({(bsz * seq_len), output_w.size(1)}, options));
     auto ff1_inp = torch::empty_like(input);
     T* ff2_inp_ptr = (T*)ff2_inp.data_ptr();
     T* gelu_inp_ptr = (T*)gelu_inp.data_ptr();
     T* ff1_inp_ptr = (T*)ff1_inp.data_ptr();
-    torch::Tensor soft_out = torch::empty(
-        {(bsz * layer->GetNumHeads() * layer->GetSeqLength()), layer->GetSeqLength()}, options);
+    torch::Tensor soft_out =
+        torch::empty({(bsz * layer->GetNumHeads() * seq_len), seq_len}, options);
     torch::Tensor ctx_bufB =
         (attn_dropout_checkpoint
             ? soft_out
-            : torch::empty(
-                  {(bsz * layer->GetNumHeads() * layer->GetSeqLength()), layer->GetSeqLength()},
-                  options));
+            : torch::empty({(bsz * layer->GetNumHeads() * seq_len), seq_len}, options));
    T* soft_out_ptr = (T*)soft_out.data_ptr();
    T* ctx_bufB_ptr = (T*)ctx_bufB.data_ptr();
    layer->SetTrainingMode(training_mode);
    layer->SetIntermediateBuffers((uint8_t*)attn_prob_dropout_mask.data_ptr(),
                                  (uint8_t*)attn_output_dropout_mask.data_ptr(),
-                                 (uint8_t*)layer_output_dropout_mask.data_ptr());
+                                 (uint8_t*)layer_output_dropout_mask.data_ptr(),
+                                 (T*)attn_layer_norm_var.data_ptr(),
+                                 (T*)attn_layer_norm_mean.data_ptr(),
+                                 (T*)layer_norm_var.data_ptr(),
+                                 (T*)layer_norm_mean.data_ptr());
    layer->Forward(bsz,
                   input_ptr,
...@@ -777,7 +799,11 @@ std::vector<torch::Tensor> ds_transformer_forward(int layer_id, ...@@ -777,7 +799,11 @@ std::vector<torch::Tensor> ds_transformer_forward(int layer_id,
ff2_inp, ff2_inp,
attn_prob_dropout_mask, attn_prob_dropout_mask,
attn_output_dropout_mask, attn_output_dropout_mask,
layer_output_dropout_mask}; layer_output_dropout_mask,
attn_layer_norm_var,
attn_layer_norm_mean,
layer_norm_var,
layer_norm_mean};
} }
template <typename T> template <typename T>
...@@ -796,6 +822,10 @@ std::vector<torch::Tensor> ds_transformer_backward(int layer_id, ...@@ -796,6 +822,10 @@ std::vector<torch::Tensor> ds_transformer_backward(int layer_id,
const torch::Tensor& attn_prob_dropout_mask, const torch::Tensor& attn_prob_dropout_mask,
const torch::Tensor& attn_output_dropout_mask, const torch::Tensor& attn_output_dropout_mask,
const torch::Tensor& layer_output_dropout_mask, const torch::Tensor& layer_output_dropout_mask,
const torch::Tensor& attn_layer_norm_var,
const torch::Tensor& attn_layer_norm_mean,
const torch::Tensor& layer_norm_var,
const torch::Tensor& layer_norm_mean,
const torch::Tensor& input, const torch::Tensor& input,
const torch::Tensor& input_mask, const torch::Tensor& input_mask,
const torch::Tensor& attn_qkvw, const torch::Tensor& attn_qkvw,
...@@ -839,6 +869,7 @@ std::vector<torch::Tensor> ds_transformer_backward(int layer_id, ...@@ -839,6 +869,7 @@ std::vector<torch::Tensor> ds_transformer_backward(int layer_id,
CHECK_INPUT(norm_b); CHECK_INPUT(norm_b);
int bsz = g_output.size(0); int bsz = g_output.size(0);
std::shared_ptr<BertTransformerLayer<T>> layer = std::shared_ptr<BertTransformerLayer<T>> layer =
std::static_pointer_cast<BertTransformerLayer<T>>(s_transformer_layers[layer_id]); std::static_pointer_cast<BertTransformerLayer<T>>(s_transformer_layers[layer_id]);
...@@ -901,7 +932,11 @@ std::vector<torch::Tensor> ds_transformer_backward(int layer_id, ...@@ -901,7 +932,11 @@ std::vector<torch::Tensor> ds_transformer_backward(int layer_id,
layer->SetIntermediateBuffers((uint8_t*)attn_prob_dropout_mask.data_ptr(), layer->SetIntermediateBuffers((uint8_t*)attn_prob_dropout_mask.data_ptr(),
(uint8_t*)attn_output_dropout_mask.data_ptr(), (uint8_t*)attn_output_dropout_mask.data_ptr(),
(uint8_t*)layer_output_dropout_mask.data_ptr()); (uint8_t*)layer_output_dropout_mask.data_ptr(),
(T*)attn_layer_norm_var.data_ptr(),
(T*)attn_layer_norm_mean.data_ptr(),
(T*)layer_norm_var.data_ptr(),
(T*)layer_norm_mean.data_ptr());
layer->Backward(bsz, layer->Backward(bsz,
grad_output_ptr, grad_output_ptr,
......
...@@ -279,13 +279,12 @@ void launch_bias_gelu(const T* input, ...@@ -279,13 +279,12 @@ void launch_bias_gelu(const T* input,
T* output, T* output,
int intermediate_size, int intermediate_size,
int batch_size, int batch_size,
int sequence_length,
cudaStream_t stream) cudaStream_t stream)
{ {
int iterations = (intermediate_size + 1023) / 1024; int iterations = (intermediate_size + 1023) / 1024;
int threads = intermediate_size / iterations / 4; int threads = intermediate_size / iterations / 4;
dim3 block_dims(threads); dim3 block_dims(threads);
dim3 grid_dims(sequence_length * batch_size); dim3 grid_dims(batch_size);
fused_bias_gelu<<<grid_dims, block_dims, 0, stream>>>(input, bias, output, intermediate_size); fused_bias_gelu<<<grid_dims, block_dims, 0, stream>>>(input, bias, output, intermediate_size);
} }
...@@ -295,24 +294,26 @@ void launch_gelu(const T* input, ...@@ -295,24 +294,26 @@ void launch_gelu(const T* input,
T* output, T* output,
int intermediate_size, int intermediate_size,
int batch_size, int batch_size,
int sequence_length,
cudaStream_t stream) cudaStream_t stream)
{ {
int iterations = (intermediate_size + 1023) / 1024; int iterations = (intermediate_size + 1023) / 1024;
int threads = intermediate_size / iterations / 4; int threads = intermediate_size / iterations / 4;
dim3 block_dims(threads); dim3 block_dims(threads);
dim3 grid_dims(sequence_length * batch_size); dim3 grid_dims(batch_size);
gelu_kernel<<<grid_dims, block_dims, 0, stream>>>(input, output, intermediate_size); gelu_kernel<<<grid_dims, block_dims, 0, stream>>>(input, output, intermediate_size);
} }
template void template void launch_bias_gelu<float>(const float*, const float*, float*, int, int, cudaStream_t);
launch_bias_gelu<float>(const float*, const float*, float*, int, int, int, cudaStream_t); template void launch_bias_gelu<__half>(const __half*,
template void const __half*,
launch_bias_gelu<__half>(const __half*, const __half*, __half*, int, int, int, cudaStream_t); __half*,
int,
int,
cudaStream_t);
template void launch_gelu<float>(const float*, float*, int, int, int, cudaStream_t); template void launch_gelu<float>(const float*, float*, int, int, cudaStream_t);
template void launch_gelu<__half>(const __half*, __half*, int, int, int, cudaStream_t); template void launch_gelu<__half>(const __half*, __half*, int, int, cudaStream_t);
template <typename T> template <typename T>
void launch_d_gelu(T* d_output, void launch_d_gelu(T* d_output,
...@@ -320,17 +321,15 @@ void launch_d_gelu(T* d_output, ...@@ -320,17 +321,15 @@ void launch_d_gelu(T* d_output,
const T* bias, const T* bias,
int intermediate_size, int intermediate_size,
int batch_size, int batch_size,
int sequence_length,
cudaStream_t stream) cudaStream_t stream)
{ {
int iterations = (intermediate_size + 1023) / 1024; int iterations = (intermediate_size + 1023) / 1024;
int threads = intermediate_size / iterations / 4; int threads = intermediate_size / iterations / 4;
dim3 block_dims(threads); dim3 block_dims(threads);
dim3 grid_dims(sequence_length * batch_size); dim3 grid_dims(batch_size);
d_gelu_func<<<grid_dims, block_dims, 0, stream>>>(d_output, input, bias, intermediate_size); d_gelu_func<<<grid_dims, block_dims, 0, stream>>>(d_output, input, bias, intermediate_size);
} }
template void launch_d_gelu<float>(float*, const float*, const float*, int, int, int, cudaStream_t); template void launch_d_gelu<float>(float*, const float*, const float*, int, int, cudaStream_t);
template void template void launch_d_gelu<__half>(__half*, const __half*, const __half*, int, int, cudaStream_t);
launch_d_gelu<__half>(__half*, const __half*, const __half*, int, int, int, cudaStream_t);
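
The launch geometry is otherwise unchanged: the batch_size passed in is already rows = batch * seq, so one block still handles one token row. A worked example of the thread math with illustrative sizes (not taken from the commit):

    // intermediate_size = 4096, batch = 8, seq = 128
    int intermediate_size = 4096;
    int rows = 8 * 128;                                  // the batch_size argument (batch * seq)
    int iterations = (intermediate_size + 1023) / 1024;  // 4
    int threads = intermediate_size / iterations / 4;    // 4096 / 4 / 4 = 256 threads per block
    dim3 block_dims(threads);
    dim3 grid_dims(rows);                                // one block per row
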
...@@ -14,15 +14,18 @@ __global__ void column_sum_reduce(const T* __restrict__ inp, ...@@ -14,15 +14,18 @@ __global__ void column_sum_reduce(const T* __restrict__ inp,
cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b); cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
int idx = blockDim.x * blockIdx.x + threadIdx.x; int idx = blockDim.x * blockIdx.x + threadIdx.x;
int offset = threadIdx.y * width + idx;
int y_stride = width * TILE_DIM; int y_stride = width * TILE_DIM;
float localSum = 0; float localSum = 0;
// Loop across matrix height // Loop across matrix height
for (int r = threadIdx.y; r < rows; r += TILE_DIM) { if (idx < width) {
localSum += (float)inp[offset]; int offset = threadIdx.y * width + idx;
offset += y_stride; for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
localSum += (float)inp[offset];
offset += y_stride;
}
} }
tile[threadIdx.x][threadIdx.y] = localSum; tile[threadIdx.x][threadIdx.y] = localSum;
...@@ -40,7 +43,7 @@ __global__ void column_sum_reduce(const T* __restrict__ inp, ...@@ -40,7 +43,7 @@ __global__ void column_sum_reduce(const T* __restrict__ inp,
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
int pos = blockIdx.x * TILE_DIM + threadIdx.y; int pos = blockIdx.x * TILE_DIM + threadIdx.y;
out[pos] = sum; if (pos < (rows * width)) out[pos] = sum;
} }
} }
...@@ -58,10 +61,10 @@ void launch_fuse_transpose_bias_kernel<float>(const float* inp, ...@@ -58,10 +61,10 @@ void launch_fuse_transpose_bias_kernel<float>(const float* inp,
int cols, int cols,
cudaStream_t stream) cudaStream_t stream)
{ {
assert(rows % TILE_DIM == 0); // assert(rows % TILE_DIM == 0);
assert(cols % TILE_DIM == 0); // assert(cols % TILE_DIM == 0);
dim3 grid_dim(cols / TILE_DIM); dim3 grid_dim((cols - 1) / TILE_DIM + 1);
dim3 block_dim(TILE_DIM, TILE_DIM); dim3 block_dim(TILE_DIM, TILE_DIM);
column_sum_reduce<float><<<grid_dim, block_dim, 0, stream>>>(inp, out, rows, cols); column_sum_reduce<float><<<grid_dim, block_dim, 0, stream>>>(inp, out, rows, cols);
...@@ -74,10 +77,10 @@ void launch_fuse_transpose_bias_kernel<__half>(const __half* inp, ...@@ -74,10 +77,10 @@ void launch_fuse_transpose_bias_kernel<__half>(const __half* inp,
int cols, int cols,
cudaStream_t stream) cudaStream_t stream)
{ {
assert(rows % TILE_DIM == 0); // assert(rows % TILE_DIM == 0);
assert(cols % TILE_DIM == 0); // assert(cols % TILE_DIM == 0);
dim3 grid_dim(cols / TILE_DIM); dim3 grid_dim((cols - 1) / TILE_DIM + 1);
dim3 block_dim(TILE_DIM, TILE_DIM); dim3 block_dim(TILE_DIM, TILE_DIM);
column_sum_reduce<__half><<<grid_dim, block_dim, 0, stream>>>(inp, out, rows, cols); column_sum_reduce<__half><<<grid_dim, block_dim, 0, stream>>>(inp, out, rows, cols);
......
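
The bias-gradient reduction no longer requires rows and cols to be multiples of TILE_DIM: the kernel now masks out-of-range columns and the grid is rounded up instead of truncated. A small check of the new grid arithmetic (TILE_DIM = 32 is an assumption; its value is not shown in this hunk):

    const int TILE_DIM = 32;
    int cols = 1000;
    int grid_x_old = cols / TILE_DIM;            // 31 blocks: the last 8 columns were skipped
    int grid_x_new = (cols - 1) / TILE_DIM + 1;  // 32 blocks: the idx < width guard handles the tail
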
...@@ -27,10 +27,9 @@ __global__ void fused_bias_residual_layer_norm(float* vals, ...@@ -27,10 +27,9 @@ __global__ void fused_bias_residual_layer_norm(float* vals,
const float* beta, const float* beta,
float epsilon, float epsilon,
bool preLayerNorm, bool preLayerNorm,
bool training = false, bool training,
float* vars = nullptr, float* vars,
float* means = nullptr, float* means)
float* vals_hat = nullptr)
{ {
constexpr int iteration_stride = row_stride / iterations; constexpr int iteration_stride = row_stride / iterations;
...@@ -108,10 +107,9 @@ __global__ void fused_bias_residual_layer_norm(__half* vals, ...@@ -108,10 +107,9 @@ __global__ void fused_bias_residual_layer_norm(__half* vals,
const __half* beta, const __half* beta,
float epsilon, float epsilon,
bool preLayerNorm, bool preLayerNorm,
bool training = false, bool training,
__half* vars = nullptr, __half* vars,
__half* means = nullptr, __half* means)
__half* vals_hat = nullptr)
{ {
#if __CUDA_ARCH__ >= 700 #if __CUDA_ARCH__ >= 700
constexpr int iteration_stride = row_stride / iterations; constexpr int iteration_stride = row_stride / iterations;
...@@ -204,14 +202,12 @@ void launch_bias_residual_layer_norm(T* vals, ...@@ -204,14 +202,12 @@ void launch_bias_residual_layer_norm(T* vals,
const T* beta, const T* beta,
float epsilon, float epsilon,
int batch_size, int batch_size,
int sequence_length,
int hidden_dim, int hidden_dim,
cudaStream_t stream, cudaStream_t stream,
bool preLayerNorm, bool preLayerNorm,
bool training, bool training,
T* vars, T* vars,
T* means, T* means);
T* vals_hat);
template <> template <>
void launch_bias_residual_layer_norm<float>(float* vals, void launch_bias_residual_layer_norm<float>(float* vals,
...@@ -220,40 +216,38 @@ void launch_bias_residual_layer_norm<float>(float* vals, ...@@ -220,40 +216,38 @@ void launch_bias_residual_layer_norm<float>(float* vals,
const float* beta, const float* beta,
float epsilon, float epsilon,
int batch_size, int batch_size,
int sequence_length,
int hidden_dim, int hidden_dim,
cudaStream_t stream, cudaStream_t stream,
bool preLayerNorm, bool preLayerNorm,
bool training, bool training,
float* vars, float* vars,
float* means, float* means)
float* vals_hat)
{ {
constexpr int threads = THREADS; constexpr int threads = THREADS;
dim3 grid_dim(batch_size * sequence_length); dim3 grid_dim(batch_size);
dim3 block_dim(threads); dim3 block_dim(threads);
// There are some limitations to call below functions, now just enumerate the situations. // There are some limitations to call below functions, now just enumerate the situations.
if (hidden_dim == 768) if (hidden_dim == 768)
fused_bias_residual_layer_norm<768, 3><<<grid_dim, block_dim, 0, stream>>>( fused_bias_residual_layer_norm<768, 3><<<grid_dim, block_dim, 0, stream>>>(
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, vals_hat); vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means);
else if (hidden_dim == 512) else if (hidden_dim == 512)
fused_bias_residual_layer_norm<512, 2><<<grid_dim, block_dim, 0, stream>>>( fused_bias_residual_layer_norm<512, 2><<<grid_dim, block_dim, 0, stream>>>(
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, vals_hat); vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means);
else if (hidden_dim == 1024) else if (hidden_dim == 1024)
fused_bias_residual_layer_norm<1024, 4><<<grid_dim, block_dim, 0, stream>>>( fused_bias_residual_layer_norm<1024, 4><<<grid_dim, block_dim, 0, stream>>>(
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, vals_hat); vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means);
else if (hidden_dim == 1536) else if (hidden_dim == 1536)
fused_bias_residual_layer_norm<1536, 6><<<grid_dim, block_dim, 0, stream>>>( fused_bias_residual_layer_norm<1536, 6><<<grid_dim, block_dim, 0, stream>>>(
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, vals_hat); vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means);
else if (hidden_dim == 2048) else if (hidden_dim == 2048)
fused_bias_residual_layer_norm<2048, 8><<<grid_dim, block_dim, 0, stream>>>( fused_bias_residual_layer_norm<2048, 8><<<grid_dim, block_dim, 0, stream>>>(
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, vals_hat); vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means);
else if (hidden_dim == 2560) else if (hidden_dim == 2560)
fused_bias_residual_layer_norm<2560, 10><<<grid_dim, block_dim, 0, stream>>>( fused_bias_residual_layer_norm<2560, 10><<<grid_dim, block_dim, 0, stream>>>(
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, vals_hat); vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means);
else else
throw std::runtime_error("Unsupport hidden_dim."); throw std::runtime_error("Unsupport hidden_dim.");
} }
...@@ -265,39 +259,37 @@ void launch_bias_residual_layer_norm<__half>(__half* vals, ...@@ -265,39 +259,37 @@ void launch_bias_residual_layer_norm<__half>(__half* vals,
const __half* beta, const __half* beta,
float epsilon, float epsilon,
int batch_size, int batch_size,
int sequence_length,
int hidden_dim, int hidden_dim,
cudaStream_t stream, cudaStream_t stream,
bool preLayerNorm, bool preLayerNorm,
bool training, bool training,
__half* vars, __half* vars,
__half* means, __half* means)
__half* vals_hat)
{ {
constexpr int threads = 128; constexpr int threads = 128;
dim3 grid_dim(batch_size * sequence_length); dim3 grid_dim(batch_size);
dim3 block_dim(threads); dim3 block_dim(threads);
// There are some limitations to call below functions, now just enumerate the situations. // There are some limitations to call below functions, now just enumerate the situations.
if (hidden_dim == 768) if (hidden_dim == 768)
fused_bias_residual_layer_norm<384, 3><<<grid_dim, block_dim, 0, stream>>>( fused_bias_residual_layer_norm<384, 3><<<grid_dim, block_dim, 0, stream>>>(
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, vals_hat); vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means);
else if (hidden_dim == 512) else if (hidden_dim == 512)
fused_bias_residual_layer_norm<256, 2><<<grid_dim, block_dim, 0, stream>>>( fused_bias_residual_layer_norm<256, 2><<<grid_dim, block_dim, 0, stream>>>(
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, vals_hat); vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means);
else if (hidden_dim == 1024) else if (hidden_dim == 1024)
fused_bias_residual_layer_norm<512, 4><<<grid_dim, block_dim, 0, stream>>>( fused_bias_residual_layer_norm<512, 4><<<grid_dim, block_dim, 0, stream>>>(
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, vals_hat); vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means);
else if (hidden_dim == 1536) else if (hidden_dim == 1536)
fused_bias_residual_layer_norm<768, 6><<<grid_dim, block_dim, 0, stream>>>( fused_bias_residual_layer_norm<768, 6><<<grid_dim, block_dim, 0, stream>>>(
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, vals_hat); vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means);
else if (hidden_dim == 2048) else if (hidden_dim == 2048)
fused_bias_residual_layer_norm<1024, 8><<<grid_dim, block_dim, 0, stream>>>( fused_bias_residual_layer_norm<1024, 8><<<grid_dim, block_dim, 0, stream>>>(
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, vals_hat); vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means);
else if (hidden_dim == 2560) else if (hidden_dim == 2560)
fused_bias_residual_layer_norm<1280, 10><<<grid_dim, block_dim, 0, stream>>>( fused_bias_residual_layer_norm<1280, 10><<<grid_dim, block_dim, 0, stream>>>(
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, vals_hat); vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means);
else else
throw std::runtime_error("Unsupport hidden_dim."); throw std::runtime_error("Unsupport hidden_dim.");
} }
...@@ -309,10 +301,8 @@ __global__ void fused_bias_residual_layer_norm(float* vals, ...@@ -309,10 +301,8 @@ __global__ void fused_bias_residual_layer_norm(float* vals,
const float* beta, const float* beta,
float epsilon, float epsilon,
bool preLayerNorm, bool preLayerNorm,
bool training = false, bool training,
float* vars = nullptr, float* vars)
float* vals_hat = nullptr,
bool save_vals = false)
{ {
constexpr int iteration_stride = row_stride / iterations; constexpr int iteration_stride = row_stride / iterations;
...@@ -388,10 +378,8 @@ __global__ void fused_bias_residual_layer_norm(__half* vals, ...@@ -388,10 +378,8 @@ __global__ void fused_bias_residual_layer_norm(__half* vals,
const __half* beta, const __half* beta,
float epsilon, float epsilon,
bool preLayerNorm, bool preLayerNorm,
bool training = false, bool training,
__half* vars = nullptr, __half* vars)
__half* vals_hat = nullptr,
bool save_vals = false)
{ {
#if __CUDA_ARCH__ >= 700 #if __CUDA_ARCH__ >= 700
constexpr int iteration_stride = row_stride / iterations; constexpr int iteration_stride = row_stride / iterations;
...@@ -481,14 +469,11 @@ void launch_bias_residual_layer_norm(T* vals, ...@@ -481,14 +469,11 @@ void launch_bias_residual_layer_norm(T* vals,
const T* beta, const T* beta,
float epsilon, float epsilon,
int batch_size, int batch_size,
int sequence_length,
int hidden_dim, int hidden_dim,
cudaStream_t stream, cudaStream_t stream,
bool preLayerNorm, bool preLayerNorm,
bool training, bool training,
T* vars, T* vars);
T* vals_hat,
bool save_vals);
/* /*
To tune this launch the following restrictions must be met: To tune this launch the following restrictions must be met:
...@@ -512,88 +497,37 @@ void launch_bias_residual_layer_norm<float>(float* vals, ...@@ -512,88 +497,37 @@ void launch_bias_residual_layer_norm<float>(float* vals,
const float* beta, const float* beta,
float epsilon, float epsilon,
int batch_size, int batch_size,
int sequence_length,
int hidden_dim, int hidden_dim,
cudaStream_t stream, cudaStream_t stream,
bool preLayerNorm, bool preLayerNorm,
bool training, bool training,
float* vars, float* vars)
float* vals_hat,
bool save_vals)
{ {
constexpr int threads = THREADS; constexpr int threads = THREADS;
dim3 grid_dim(batch_size * sequence_length); dim3 grid_dim(batch_size);
dim3 block_dim(threads); dim3 block_dim(threads);
// There are some limitations to call below functions, now just enumerate the situations. // There are some limitations to call below functions, now just enumerate the situations.
if (hidden_dim == 768) if (hidden_dim == 768)
fused_bias_residual_layer_norm<768, 3><<<grid_dim, block_dim, 0, stream>>>(vals, fused_bias_residual_layer_norm<768, 3><<<grid_dim, block_dim, 0, stream>>>(
residual, vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars);
gamma,
beta,
epsilon,
preLayerNorm,
training,
vars,
vals_hat,
save_vals);
else if (hidden_dim == 512) else if (hidden_dim == 512)
fused_bias_residual_layer_norm<512, 2><<<grid_dim, block_dim, 0, stream>>>(vals, fused_bias_residual_layer_norm<512, 2><<<grid_dim, block_dim, 0, stream>>>(
residual, vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars);
gamma,
beta,
epsilon,
preLayerNorm,
training,
vars,
vals_hat,
save_vals);
else if (hidden_dim == 1024) else if (hidden_dim == 1024)
fused_bias_residual_layer_norm<1024, 4><<<grid_dim, block_dim, 0, stream>>>(vals, fused_bias_residual_layer_norm<1024, 4><<<grid_dim, block_dim, 0, stream>>>(
residual, vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars);
gamma,
beta,
epsilon,
preLayerNorm,
training,
vars,
vals_hat,
save_vals);
else if (hidden_dim == 1536) else if (hidden_dim == 1536)
fused_bias_residual_layer_norm<1536, 6><<<grid_dim, block_dim, 0, stream>>>(vals, fused_bias_residual_layer_norm<1536, 6><<<grid_dim, block_dim, 0, stream>>>(
residual, vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars);
gamma,
beta,
epsilon,
preLayerNorm,
training,
vars,
vals_hat,
save_vals);
else if (hidden_dim == 2048) else if (hidden_dim == 2048)
fused_bias_residual_layer_norm<2048, 8><<<grid_dim, block_dim, 0, stream>>>(vals, fused_bias_residual_layer_norm<2048, 8><<<grid_dim, block_dim, 0, stream>>>(
residual, vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars);
gamma,
beta,
epsilon,
preLayerNorm,
training,
vars,
vals_hat,
save_vals);
else if (hidden_dim == 2560) else if (hidden_dim == 2560)
fused_bias_residual_layer_norm<2560, 10><<<grid_dim, block_dim, 0, stream>>>(vals, fused_bias_residual_layer_norm<2560, 10><<<grid_dim, block_dim, 0, stream>>>(
residual, vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars);
gamma,
beta,
epsilon,
preLayerNorm,
training,
vars,
vals_hat,
save_vals);
else else
throw std::runtime_error("Unsupported hidden_dim."); throw std::runtime_error("Unsupported hidden_dim.");
} }
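A minimal sketch of the pattern behind the enumeration above (not part of the patch; the helper and its name are illustrative): the float path uses row_stride equal to hidden_dim and the __half path uses hidden_dim / 2 (two values packed per thread), while the iteration count is hidden_dim / 256 in both cases, which is why only these hidden sizes are supported.

#include <stdexcept>
#include <utility>

// Illustrative only: the real kernels take row_stride and iterations as
// compile-time template parameters, so the launcher enumerates the supported
// hidden sizes instead of computing them at runtime.
inline std::pair<int, int> layer_norm_launch_params(int hidden_dim, bool half_precision)
{
    if (hidden_dim % 256 != 0 || hidden_dim > 2560)
        throw std::runtime_error("Unsupported hidden_dim.");
    const int row_stride = half_precision ? hidden_dim / 2 : hidden_dim;
    const int iterations = hidden_dim / 256;  // e.g. 768 -> 3, 2560 -> 10
    return {row_stride, iterations};
}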
...@@ -605,87 +539,36 @@ void launch_bias_residual_layer_norm<__half>(__half* vals, ...@@ -605,87 +539,36 @@ void launch_bias_residual_layer_norm<__half>(__half* vals,
const __half* beta, const __half* beta,
float epsilon, float epsilon,
int batch_size, int batch_size,
int sequence_length,
int hidden_dim, int hidden_dim,
cudaStream_t stream, cudaStream_t stream,
bool preLayerNorm, bool preLayerNorm,
bool training, bool training,
__half* vars, __half* vars)
__half* vals_hat,
bool save_vals)
{ {
constexpr int threads = 128; constexpr int threads = 128;
dim3 grid_dim(batch_size * sequence_length); dim3 grid_dim(batch_size);
dim3 block_dim(threads); dim3 block_dim(threads);
// There are some limitations on calling the functions below; for now we just enumerate the supported cases. // There are some limitations on calling the functions below; for now we just enumerate the supported cases.
if (hidden_dim == 768) if (hidden_dim == 768)
fused_bias_residual_layer_norm<384, 3><<<grid_dim, block_dim, 0, stream>>>(vals, fused_bias_residual_layer_norm<384, 3><<<grid_dim, block_dim, 0, stream>>>(
residual, vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars);
gamma,
beta,
epsilon,
preLayerNorm,
training,
vars,
vals_hat,
save_vals);
else if (hidden_dim == 512) else if (hidden_dim == 512)
fused_bias_residual_layer_norm<256, 2><<<grid_dim, block_dim, 0, stream>>>(vals, fused_bias_residual_layer_norm<256, 2><<<grid_dim, block_dim, 0, stream>>>(
residual, vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars);
gamma,
beta,
epsilon,
preLayerNorm,
training,
vars,
vals_hat,
save_vals);
else if (hidden_dim == 1024) else if (hidden_dim == 1024)
fused_bias_residual_layer_norm<512, 4><<<grid_dim, block_dim, 0, stream>>>(vals, fused_bias_residual_layer_norm<512, 4><<<grid_dim, block_dim, 0, stream>>>(
residual, vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars);
gamma,
beta,
epsilon,
preLayerNorm,
training,
vars,
vals_hat,
save_vals);
else if (hidden_dim == 1536) else if (hidden_dim == 1536)
fused_bias_residual_layer_norm<768, 6><<<grid_dim, block_dim, 0, stream>>>(vals, fused_bias_residual_layer_norm<768, 6><<<grid_dim, block_dim, 0, stream>>>(
residual, vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars);
gamma,
beta,
epsilon,
preLayerNorm,
training,
vars,
vals_hat,
save_vals);
else if (hidden_dim == 2048) else if (hidden_dim == 2048)
fused_bias_residual_layer_norm<1024, 8><<<grid_dim, block_dim, 0, stream>>>(vals, fused_bias_residual_layer_norm<1024, 8><<<grid_dim, block_dim, 0, stream>>>(
residual, vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars);
gamma,
beta,
epsilon,
preLayerNorm,
training,
vars,
vals_hat,
save_vals);
else if (hidden_dim == 2560) else if (hidden_dim == 2560)
fused_bias_residual_layer_norm<1280, 10><<<grid_dim, block_dim, 0, stream>>>(vals, fused_bias_residual_layer_norm<1280, 10><<<grid_dim, block_dim, 0, stream>>>(
residual, vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars);
gamma,
beta,
epsilon,
preLayerNorm,
training,
vars,
vals_hat,
save_vals);
else else
throw std::runtime_error("Unsupported hidden_dim."); throw std::runtime_error("Unsupported hidden_dim.");
} }
...@@ -1037,15 +920,13 @@ void launch_layerNorm_backward<float>(const float* out_grad, ...@@ -1037,15 +920,13 @@ void launch_layerNorm_backward<float>(const float* out_grad,
float* gamma_grad, float* gamma_grad,
float* betta_grad, float* betta_grad,
float* inp_grad, float* inp_grad,
int batch_size, int batch,
int sequence_length,
int hidden_dim, int hidden_dim,
cudaStream_t stream[2], cudaStream_t stream[2],
bool invertible, bool invertible,
const float* betta) const float* betta)
{ {
constexpr int threads = THREADS; constexpr int threads = THREADS;
int batch = batch_size * sequence_length;
dim3 grid_dim(hidden_dim / TILE_DIM); dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM); dim3 block_dim(TILE_DIM, TILE_DIM);
...@@ -1086,15 +967,13 @@ void launch_layerNorm_backward<__half>(const __half* out_grad, ...@@ -1086,15 +967,13 @@ void launch_layerNorm_backward<__half>(const __half* out_grad,
__half* gamma_grad, __half* gamma_grad,
__half* betta_grad, __half* betta_grad,
__half* inp_grad, __half* inp_grad,
int batch_size, int batch,
int sequence_length,
int hidden_dim, int hidden_dim,
cudaStream_t stream[2], cudaStream_t stream[2],
bool invertible, bool invertible,
const __half* betta) const __half* betta)
{ {
constexpr int threads = THREADS; constexpr int threads = THREADS;
int batch = batch_size * sequence_length;
dim3 grid_dim(hidden_dim / TILE_DIM); dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM); dim3 block_dim(TILE_DIM, TILE_DIM);
...@@ -1336,13 +1215,11 @@ void launch_layerNorm_backward<float>(const float* out_grad, ...@@ -1336,13 +1215,11 @@ void launch_layerNorm_backward<float>(const float* out_grad,
float* gamma_grad, float* gamma_grad,
float* betta_grad, float* betta_grad,
float* inp_grad, float* inp_grad,
int batch_size, int batch,
int sequence_length,
int hidden_dim, int hidden_dim,
cudaStream_t stream[2]) cudaStream_t stream[2])
{ {
constexpr int threads = THREADS; constexpr int threads = THREADS;
int batch = batch_size * sequence_length;
dim3 grid_dim(hidden_dim / TILE_DIM); dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM); dim3 block_dim(TILE_DIM, TILE_DIM);
...@@ -1384,13 +1261,11 @@ void launch_layerNorm_backward<__half>(const __half* out_grad, ...@@ -1384,13 +1261,11 @@ void launch_layerNorm_backward<__half>(const __half* out_grad,
__half* gamma_grad, __half* gamma_grad,
__half* betta_grad, __half* betta_grad,
__half* inp_grad, __half* inp_grad,
int batch_size, int batch,
int sequence_length,
int hidden_dim, int hidden_dim,
cudaStream_t stream[2]) cudaStream_t stream[2])
{ {
constexpr int threads = THREADS; constexpr int threads = THREADS;
int batch = batch_size * sequence_length;
dim3 grid_dim(hidden_dim / TILE_DIM); dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM); dim3 block_dim(TILE_DIM, TILE_DIM);
...@@ -1759,15 +1634,13 @@ void launch_layerNorm_backward_fused_add<float>(const float* out_grad1, ...@@ -1759,15 +1634,13 @@ void launch_layerNorm_backward_fused_add<float>(const float* out_grad1,
float* gamma_grad, float* gamma_grad,
float* betta_grad, float* betta_grad,
float* inp_grad, float* inp_grad,
int batch_size, int batch,
int sequence_length,
int hidden_dim, int hidden_dim,
cudaStream_t stream[2], cudaStream_t stream[2],
bool invertible, bool invertible,
const float* betta) const float* betta)
{ {
constexpr int threads = THREADS; constexpr int threads = THREADS;
int batch = batch_size * sequence_length;
dim3 grid_dim(hidden_dim / TILE_DIM); dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM); dim3 block_dim(TILE_DIM, TILE_DIM);
...@@ -1808,15 +1681,13 @@ void launch_layerNorm_backward_fused_add<__half>(const __half* out_grad1, ...@@ -1808,15 +1681,13 @@ void launch_layerNorm_backward_fused_add<__half>(const __half* out_grad1,
__half* gamma_grad, __half* gamma_grad,
__half* betta_grad, __half* betta_grad,
__half* inp_grad, __half* inp_grad,
int batch_size, int batch,
int sequence_length,
int hidden_dim, int hidden_dim,
cudaStream_t stream[2], cudaStream_t stream[2],
bool invertible, bool invertible,
const __half* betta) const __half* betta)
{ {
constexpr int threads = THREADS; constexpr int threads = THREADS;
int batch = batch_size * sequence_length;
dim3 grid_dim(hidden_dim / TILE_DIM); dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM); dim3 block_dim(TILE_DIM, TILE_DIM);
...@@ -2070,13 +1941,11 @@ void launch_layerNorm_backward_fused_add<float>(const float* out_grad1, ...@@ -2070,13 +1941,11 @@ void launch_layerNorm_backward_fused_add<float>(const float* out_grad1,
float* gamma_grad, float* gamma_grad,
float* betta_grad, float* betta_grad,
float* inp_grad, float* inp_grad,
int batch_size, int batch,
int sequence_length,
int hidden_dim, int hidden_dim,
cudaStream_t stream[2]) cudaStream_t stream[2])
{ {
constexpr int threads = THREADS; constexpr int threads = THREADS;
int batch = batch_size * sequence_length;
dim3 grid_dim(hidden_dim / TILE_DIM); dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM); dim3 block_dim(TILE_DIM, TILE_DIM);
...@@ -2119,13 +1988,11 @@ void launch_layerNorm_backward_fused_add<__half>(const __half* out_grad1, ...@@ -2119,13 +1988,11 @@ void launch_layerNorm_backward_fused_add<__half>(const __half* out_grad1,
__half* gamma_grad, __half* gamma_grad,
__half* betta_grad, __half* betta_grad,
__half* inp_grad, __half* inp_grad,
int batch_size, int batch,
int sequence_length,
int hidden_dim, int hidden_dim,
cudaStream_t stream[2]) cudaStream_t stream[2])
{ {
constexpr int threads = THREADS; constexpr int threads = THREADS;
int batch = batch_size * sequence_length;
dim3 grid_dim(hidden_dim / TILE_DIM); dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM); dim3 block_dim(TILE_DIM, TILE_DIM);
......
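With sequence_length removed from every backward launcher above, the caller is assumed to fold it into the batch dimension before the call, so the kernels only ever see a flat [rows, hidden_dim] view. A minimal sketch of that flattening (illustrative, not taken from the patch):

#include <cstddef>

// Kernels now operate on rows of tokens; how the rows split into batch and
// sequence no longer matters to the launch interface.
struct TokenView {
    std::size_t rows;        // batch_size * sequence_length
    std::size_t hidden_dim;  // elements per row
};

inline TokenView flatten_tokens(std::size_t batch_size,
                                std::size_t sequence_length,
                                std::size_t hidden_dim)
{
    return {batch_size * sequence_length, hidden_dim};
}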
#include <math.h>
#include "custom_cuda_layers.h" #include "custom_cuda_layers.h"
#include "general_kernels.h" #include "general_kernels.h"
...@@ -282,7 +283,7 @@ __global__ void attn_softmax(__half* vals, ...@@ -282,7 +283,7 @@ __global__ void attn_softmax(__half* vals,
} }
template <typename T> template <typename T>
void launch_attn_softmax(T*, const T*, int, int, int, cudaStream_t, bool); void launch_attn_softmax(T*, const T*, int, int, int, cudaStream_t);
template <> template <>
void launch_attn_softmax<float>(float* vals, void launch_attn_softmax<float>(float* vals,
...@@ -294,11 +295,10 @@ void launch_attn_softmax<float>(float* vals, ...@@ -294,11 +295,10 @@ void launch_attn_softmax<float>(float* vals,
{ {
const int threads = 128; const int threads = 128;
int seq_length4 = sequence_length / 4; int seq_length4 = sequence_length / 4;
int seq2 = sequence_length * seq_length4;
int block_compute_size = int block_compute_size =
(seq_length4 < threads ? ((threads / seq_length4) * seq_length4) : seq_length4); (seq_length4 < threads ? (int)pow(2.0, floor(log2((float)(threads / seq_length4)))) : 1);
dim3 grid_dim(batch_size, heads * seq2 / block_compute_size); dim3 grid_dim(batch_size, heads * sequence_length / block_compute_size);
int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads; int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads;
...@@ -330,8 +330,9 @@ void launch_attn_softmax<float>(float* vals, ...@@ -330,8 +330,9 @@ void launch_attn_softmax<float>(float* vals,
else { else {
const int threads = 256; const int threads = 256;
block_compute_size = block_compute_size =
(seq_length4 < threads ? ((threads / seq_length4) * seq_length4) : seq_length4); (seq_length4 < threads ? (int)pow(2.0, floor(log2((float)(threads / seq_length4))))
dim3 grid_dim(batch_size, heads * seq2 / block_compute_size); : 1);
dim3 grid_dim(batch_size, heads * sequence_length / block_compute_size);
int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads; int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads;
...@@ -362,11 +363,10 @@ void launch_attn_softmax<__half>(__half* vals, ...@@ -362,11 +363,10 @@ void launch_attn_softmax<__half>(__half* vals,
{ {
const int threads = 128; const int threads = 128;
int seq_length4 = sequence_length / 4; int seq_length4 = sequence_length / 4;
int seq2 = sequence_length * seq_length4;
int block_compute_size = int block_compute_size =
(seq_length4 < threads ? ((threads / seq_length4) * seq_length4) : seq_length4); (seq_length4 < threads ? (int)pow(2.0, floor(log2((float)(threads / seq_length4)))) : 1);
dim3 grid_dim(batch_size, heads * seq2 / block_compute_size); dim3 grid_dim(batch_size, heads * sequence_length / block_compute_size);
int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads; int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads;
...@@ -399,8 +399,9 @@ void launch_attn_softmax<__half>(__half* vals, ...@@ -399,8 +399,9 @@ void launch_attn_softmax<__half>(__half* vals,
else { else {
const int threads = 256; const int threads = 256;
block_compute_size = block_compute_size =
(seq_length4 < threads ? ((threads / seq_length4) * seq_length4) : seq_length4); (seq_length4 < threads ? (int)pow(2.0, floor(log2((float)(threads / seq_length4))))
dim3 grid_dim(batch_size, heads * seq2 / block_compute_size); : 1);
dim3 grid_dim(batch_size, heads * sequence_length / block_compute_size);
int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads; int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads;
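The revised block sizing above picks the largest power of two not exceeding threads / seq_length4, falling back to 1 when a row does not fit in a single pass, which is what allows sequence lengths that are not powers of two. An integer-only sketch of the same computation (illustrative, not part of the patch):

// Equivalent to (int)pow(2.0, floor(log2((float)(threads / seq_length4))))
// when seq_length4 < threads, and 1 otherwise.
inline int rows_per_block(int threads, int seq_length4)
{
    if (seq_length4 >= threads) return 1;  // a full row exceeds the thread count; one row per block
    int p = 1;
    while (p * 2 <= threads / seq_length4) p *= 2;
    return p;
}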
...@@ -531,55 +532,41 @@ void launch_attn_softmax_backward_v2(T* out_grad, ...@@ -531,55 +532,41 @@ void launch_attn_softmax_backward_v2(T* out_grad,
int seq_length, int seq_length,
cudaStream_t stream) cudaStream_t stream)
{ {
if ((seq_length % WARP_SIZE) != 0 || seq_length > 2048)
throw std::runtime_error("Invalid sequence length found in softmax backward.");
const int warps_per_block = 4; const int warps_per_block = 4;
dim3 grid_dim(batch_size * heads * seq_length / warps_per_block); dim3 grid_dim(batch_size * heads * seq_length / warps_per_block);
dim3 block_dim(WARP_SIZE, warps_per_block); dim3 block_dim(WARP_SIZE, warps_per_block);
switch (seq_length) { if (seq_length <= 32)
case 32: softmax_backward_kernel_v2<T, 1>
softmax_backward_kernel_v2<T, 1> <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length); else if (seq_length <= 64)
break; softmax_backward_kernel_v2<T, 2>
case 64: <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
softmax_backward_kernel_v2<T, 2> else if (seq_length <= 128)
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length); softmax_backward_kernel_v2<T, 4>
break; <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
case 128: else if (seq_length <= 256)
softmax_backward_kernel_v2<T, 4> softmax_backward_kernel_v2<T, 8>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length); <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
break; else if (seq_length <= 384)
case 256: softmax_backward_kernel_v2<T, 12>
softmax_backward_kernel_v2<T, 8> <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length); else if (seq_length <= 512)
break; softmax_backward_kernel_v2<T, 16>
case 384: <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
softmax_backward_kernel_v2<T, 12> else if (seq_length <= 768)
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length); softmax_backward_kernel_v2<T, 24>
break; <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
case 512: else if (seq_length <= 1024)
softmax_backward_kernel_v2<T, 16> softmax_backward_kernel_v2<T, 32>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length); <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
break; else if (seq_length <= 2048)
case 768: softmax_backward_kernel_v2<T, 64>
softmax_backward_kernel_v2<T, 24> <<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length); else
break; throw std::runtime_error(
case 1024: std::string("Special sequence length found in softmax backward, seq_length: ") +
softmax_backward_kernel_v2<T, 32> std::to_string(seq_length));
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
break;
case 2048:
softmax_backward_kernel_v2<T, 64>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
break;
default:
throw std::runtime_error(
std::string("Special sequence length found in softmax backward, seq_length: ") +
std::to_string(seq_length));
}
} }
template void launch_attn_softmax_backward_v2<__half>(__half* out_grad, template void launch_attn_softmax_backward_v2<__half>(__half* out_grad,
......
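Above, the exact-match switch on seq_length is replaced by upper-bound ranges, so any length up to 2048 is served by the next larger kernel instantiation; the numeric template argument is simply that upper bound divided by the warp width. A compact sketch of the same dispatch rule (illustrative helper, assuming WARP_SIZE == 32):

#include <stdexcept>
#include <string>

// Returns the numeric template argument selected above: bound / 32 for the
// smallest supported upper bound >= seq_length.
inline int softmax_backward_iterations(int seq_length)
{
    const int bounds[] = {32, 64, 128, 256, 384, 512, 768, 1024, 2048};
    for (int bound : bounds)
        if (seq_length <= bound) return bound / 32;
    throw std::runtime_error(
        std::string("Special sequence length found in softmax backward, seq_length: ") +
        std::to_string(seq_length));
}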
...@@ -187,26 +187,30 @@ class DeepSpeedTransformerFunction(Function): ...@@ -187,26 +187,30 @@ class DeepSpeedTransformerFunction(Function):
ff2_inp, ff2_inp,
attn_prob_dropout_mask, attn_prob_dropout_mask,
attn_output_dropout_mask, attn_output_dropout_mask,
layer_output_dropout_mask) = forward_func(config.layer_id, layer_output_dropout_mask,
input, attn_layer_norm_var,
input_mask, attn_layer_norm_mean,
attn_qkvw, layer_norm_var,
attn_qkvb, layer_norm_mean) = forward_func(config.layer_id,
attn_ow, input,
attn_ob, input_mask,
attn_nw, attn_qkvw,
attn_nb, attn_qkvb,
inter_w, attn_ow,
inter_b, attn_ob,
output_w, attn_nw,
output_b, attn_nb,
norm_w, inter_w,
norm_b, inter_b,
config.training, output_w,
config.pre_layer_norm, output_b,
config.attn_dropout_checkpoint, norm_w,
config.normalize_invertible, norm_b,
config.gelu_checkpoint) config.training,
config.pre_layer_norm,
config.attn_dropout_checkpoint,
config.normalize_invertible,
config.gelu_checkpoint)
# For testing only. # For testing only.
if grads is not None: if grads is not None:
...@@ -283,6 +287,9 @@ class DeepSpeedTransformerFunction(Function): ...@@ -283,6 +287,9 @@ class DeepSpeedTransformerFunction(Function):
if not config.normalize_invertible: if not config.normalize_invertible:
ctx.add_res = add_res ctx.add_res = add_res
ctx.attn_layer_norm_mean = attn_layer_norm_mean
ctx.layer_norm_mean = layer_norm_mean
ctx.ff1_inp = ff1_inp ctx.ff1_inp = ff1_inp
if not config.gelu_checkpoint: if not config.gelu_checkpoint:
ctx.gelu_inp = gelu_inp ctx.gelu_inp = gelu_inp
...@@ -291,6 +298,8 @@ class DeepSpeedTransformerFunction(Function): ...@@ -291,6 +298,8 @@ class DeepSpeedTransformerFunction(Function):
ctx.attn_prob_dropout_mask = attn_prob_dropout_mask ctx.attn_prob_dropout_mask = attn_prob_dropout_mask
ctx.attn_output_dropout_mask = attn_output_dropout_mask ctx.attn_output_dropout_mask = attn_output_dropout_mask
ctx.layer_output_dropout_mask = layer_output_dropout_mask ctx.layer_output_dropout_mask = layer_output_dropout_mask
ctx.attn_layer_norm_var = attn_layer_norm_var
ctx.layer_norm_var = layer_norm_var
return output return output
...@@ -367,6 +376,10 @@ class DeepSpeedTransformerFunction(Function): ...@@ -367,6 +376,10 @@ class DeepSpeedTransformerFunction(Function):
ctx.attn_prob_dropout_mask, ctx.attn_prob_dropout_mask,
ctx.attn_output_dropout_mask, ctx.attn_output_dropout_mask,
ctx.layer_output_dropout_mask, ctx.layer_output_dropout_mask,
ctx.attn_layer_norm_var,
ctx.attn_layer_norm_mean,
ctx.layer_norm_var,
ctx.layer_norm_mean,
(ctx.inp_norm if (ctx.config.pre_layer_norm (ctx.inp_norm if (ctx.config.pre_layer_norm
and ctx.config.normalize_invertible) else input), and ctx.config.normalize_invertible) else input),
input_mask, input_mask,
......
...@@ -256,10 +256,10 @@ def run_backward(ds_config, atol=1e-2, verbose=False): ...@@ -256,10 +256,10 @@ def run_backward(ds_config, atol=1e-2, verbose=False):
@pytest.mark.parametrize('batch_size, hidden_size, seq_len, heads, num_layers, is_preln, use_fp16, atol', @pytest.mark.parametrize('batch_size, hidden_size, seq_len, heads, num_layers, is_preln, use_fp16, atol',
[ [
(3,1024,128,16,24,True,False, 0.05), (3,1024,120,16,24,True,False, 0.05),
(3,1024,128,16,24,True,True, 0.05), (3,1024,120,16,24,True,True, 0.05),
(3,1024,128,16,24,False,False, 0.1), (3,1024,56,16,24,False,False, 0.1),
(3,1024,128,16,24,False,True, 0.2), (3,1024,56,16,24,False,True, 0.2),
]) # yapf: disable ]) # yapf: disable
def test_backward(batch_size, def test_backward(batch_size,
hidden_size, hidden_size,
......
import argparse import argparse
import numpy as np import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import pytest import pytest
import json import json
import random import random
import time import time
import copy import copy
from torch import nn from torch import nn
from modelingpreln import BertEncoder as BertEncoderPreln from modelingpreln import BertEncoder as BertEncoderPreln
from modeling import BertEncoder as BertEncoderPostln from modeling import BertEncoder as BertEncoderPostln
from modeling import BertLayerNorm, BertConfig from modeling import BertLayerNorm, BertConfig
from deepspeed import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig from deepspeed import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig
import deepspeed import deepspeed
import sys import sys
if not deepspeed.ops.__installed_ops__['transformer']: if not deepspeed.ops.__installed_ops__['transformer']:
pytest.skip("transformer kernels are not installed", allow_module_level=True) pytest.skip("transformer kernels are not installed", allow_module_level=True)
def check_equal(first, second, atol=1e-2, verbose=False): def check_equal(first, second, atol=1e-2, verbose=False):
if verbose: if verbose:
print() print()
for i, (x, y) in enumerate(zip(first, second)): for i, (x, y) in enumerate(zip(first, second)):
x = x[0].cpu().detach().numpy() x = x[0].cpu().detach().numpy()
y = y[0].cpu().detach().numpy() y = y[0].cpu().detach().numpy()
if verbose: if verbose:
print("x = {}".format(x.flatten())) print("x = {}".format(x.flatten()))
print("y = {}".format(y.flatten())) print("y = {}".format(y.flatten()))
print('-' * 80) print('-' * 80)
np.testing.assert_allclose(x, y, err_msg="Index: {}".format(i), atol=atol) np.testing.assert_allclose(x, y, err_msg="Index: {}".format(i), atol=atol)
def zero_grad(variables): def zero_grad(variables):
for variable in variables: for variable in variables:
variable.grad.zero_() variable.grad.zero_()
device = torch.device("cuda") device = torch.device("cuda")
kwargs_fp32 = {'dtype': torch.float, 'device': device, 'requires_grad': True} kwargs_fp32 = {'dtype': torch.float, 'device': device, 'requires_grad': True}
kwargs_fp16 = {'dtype': torch.half, 'device': device, 'requires_grad': True} kwargs_fp16 = {'dtype': torch.half, 'device': device, 'requires_grad': True}
class DSEncoder(nn.Module): class DSEncoder(nn.Module):
def __init__(self, config, weights, biases): def __init__(self, config, weights, biases):
super(DSEncoder, self).__init__() super(DSEncoder, self).__init__()
self.FinalLayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) self.FinalLayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
self.layer = nn.ModuleList([ self.layer = nn.ModuleList([
copy.deepcopy(DeepSpeedTransformerLayer(i, copy.deepcopy(DeepSpeedTransformerLayer(i,
config, config,
weights, weights,
biases)) biases))
for i in range(config.num_hidden_layers) for i in range(config.num_hidden_layers)
]) ])
self.grads = [] self.grads = []
self.pre_or_post = config.pre_layer_norm self.pre_or_post = config.pre_layer_norm
def forward(self, def forward(self,
hidden_states, hidden_states,
attention_mask, attention_mask,
output_all_encoded_layers=True, output_all_encoded_layers=True,
checkpoint_activations=False): checkpoint_activations=False):
all_encoder_layers = [] all_encoder_layers = []
def custom(start, end): def custom(start, end):
def custom_forward(*inputs): def custom_forward(*inputs):
layers = self.layer[start:end] layers = self.layer[start:end]
x_ = inputs[0] x_ = inputs[0]
for layer in layers: for layer in layers:
x_ = layer(x_, inputs[1]) x_ = layer(x_, inputs[1])
return x_ return x_
return custom_forward return custom_forward
if checkpoint_activations: if checkpoint_activations:
l = 0 l = 0
num_layers = len(self.layer) num_layers = len(self.layer)
chunk_length = math.ceil(math.sqrt(num_layers)) chunk_length = math.ceil(math.sqrt(num_layers))
while l < num_layers: while l < num_layers:
hidden_states = checkpoint.checkpoint(custom(l, hidden_states = checkpoint.checkpoint(custom(l,
l + chunk_length), l + chunk_length),
hidden_states, hidden_states,
attention_mask * 1) attention_mask * 1)
l += chunk_length l += chunk_length
# decoder layers # decoder layers
else: else:
for i, layer_module in enumerate(self.layer): for i, layer_module in enumerate(self.layer):
hidden_states = layer_module(hidden_states, attention_mask) hidden_states = layer_module(hidden_states, attention_mask)
hidden_states.register_hook( hidden_states.register_hook(
lambda x, lambda x,
i=i, i=i,
self=self: self.grads.append([x, self=self: self.grads.append([x,
"hidden_state"])) "hidden_state"]))
if output_all_encoded_layers: if output_all_encoded_layers:
all_encoder_layers.append(hidden_states) all_encoder_layers.append(hidden_states)
if not output_all_encoded_layers or checkpoint_activations: if not output_all_encoded_layers or checkpoint_activations:
if (self.pre_or_post): if (self.pre_or_post):
hidden_states = self.FinalLayerNorm(hidden_states) hidden_states = self.FinalLayerNorm(hidden_states)
all_encoder_layers.append(hidden_states) all_encoder_layers.append(hidden_states)
return all_encoder_layers return all_encoder_layers
def get_grads(self): def get_grads(self):
return self.grads return self.grads
def create_models(ds_config): def create_models(ds_config):
bert_config = BertConfig(vocab_size_or_config_json_file=119547, bert_config = BertConfig(vocab_size_or_config_json_file=119547,
hidden_size=ds_config.hidden_size, hidden_size=ds_config.hidden_size,
num_hidden_layers=ds_config.num_hidden_layers, num_hidden_layers=ds_config.num_hidden_layers,
num_attention_heads=ds_config.heads, num_attention_heads=ds_config.heads,
batch_size=ds_config.batch_size, batch_size=ds_config.batch_size,
intermediate_size=ds_config.intermediate_size, intermediate_size=ds_config.intermediate_size,
hidden_act="gelu", hidden_act="gelu",
hidden_dropout_prob=ds_config.hidden_dropout_ratio, hidden_dropout_prob=ds_config.hidden_dropout_ratio,
attention_probs_dropout_prob=ds_config.attn_dropout_ratio, attention_probs_dropout_prob=ds_config.attn_dropout_ratio,
max_position_embeddings=ds_config.max_seq_length, max_position_embeddings=ds_config.max_seq_length,
type_vocab_size=2, type_vocab_size=2,
initializer_range=ds_config.initializer_range, initializer_range=ds_config.initializer_range,
fp16=ds_config.fp16) fp16=ds_config.fp16)
weights = [] weights = []
biases = [] biases = []
for i in range(4): for i in range(4):
weights.append( weights.append(
nn.Parameter(torch.Tensor(ds_config.hidden_size, nn.Parameter(torch.Tensor(ds_config.hidden_size,
ds_config.hidden_size))) ds_config.hidden_size)))
weights[i].data.normal_(mean=0.0, std=ds_config.initializer_range) weights[i].data.normal_(mean=0.0, std=ds_config.initializer_range)
weights.append(nn.Parameter(torch.Tensor(ds_config.hidden_size))) weights.append(nn.Parameter(torch.Tensor(ds_config.hidden_size)))
weights[4].data.fill_(1.0) weights[4].data.fill_(1.0)
weights.append( weights.append(
nn.Parameter(torch.Tensor(ds_config.intermediate_size, nn.Parameter(torch.Tensor(ds_config.intermediate_size,
ds_config.hidden_size))) ds_config.hidden_size)))
weights[5].data.normal_(mean=0.0, std=ds_config.initializer_range) weights[5].data.normal_(mean=0.0, std=ds_config.initializer_range)
weights.append( weights.append(
nn.Parameter(torch.Tensor(ds_config.hidden_size, nn.Parameter(torch.Tensor(ds_config.hidden_size,
ds_config.intermediate_size))) ds_config.intermediate_size)))
weights[6].data.normal_(mean=0.0, std=ds_config.initializer_range) weights[6].data.normal_(mean=0.0, std=ds_config.initializer_range)
weights.append(nn.Parameter(torch.Tensor(ds_config.hidden_size))) weights.append(nn.Parameter(torch.Tensor(ds_config.hidden_size)))
weights[7].data.fill_(1.0) weights[7].data.fill_(1.0)
biases.append(nn.Parameter(torch.Tensor(ds_config.hidden_size))) biases.append(nn.Parameter(torch.Tensor(ds_config.hidden_size)))
biases[0].data.zero_() biases[0].data.zero_()
for i in range(4): for i in range(4):
biases.append(nn.Parameter(torch.Tensor(ds_config.hidden_size))) biases.append(nn.Parameter(torch.Tensor(ds_config.hidden_size)))
biases[i + 1].data.zero_() biases[i + 1].data.zero_()
biases.append(nn.Parameter(torch.Tensor(ds_config.intermediate_size))) biases.append(nn.Parameter(torch.Tensor(ds_config.intermediate_size)))
biases[5].data.zero_() biases[5].data.zero_()
biases.append(nn.Parameter(torch.Tensor(ds_config.hidden_size))) biases.append(nn.Parameter(torch.Tensor(ds_config.hidden_size)))
biases[6].data.zero_() biases[6].data.zero_()
biases.append(nn.Parameter(torch.Tensor(ds_config.hidden_size))) biases.append(nn.Parameter(torch.Tensor(ds_config.hidden_size)))
biases[7].data.zero_() biases[7].data.zero_()
if (ds_config.pre_layer_norm): if (ds_config.pre_layer_norm):
bert_encoder = BertEncoderPreln(bert_config, weights, biases) bert_encoder = BertEncoderPreln(bert_config, weights, biases)
else: else:
bert_encoder = BertEncoderPostln(bert_config, weights, biases) bert_encoder = BertEncoderPostln(bert_config, weights, biases)
ds_encoder = DSEncoder(ds_config, weights, biases) ds_encoder = DSEncoder(ds_config, weights, biases)
if ds_config.fp16: if ds_config.fp16:
bert_encoder.half() bert_encoder.half()
ds_encoder.half() ds_encoder.half()
bert_encoder.cuda() bert_encoder.cuda()
ds_encoder.cuda() ds_encoder.cuda()
return bert_encoder, ds_encoder return bert_encoder, ds_encoder
def set_seed(seed): def set_seed(seed):
random.seed(seed) random.seed(seed)
np.random.seed(seed) np.random.seed(seed)
torch.manual_seed(seed) torch.manual_seed(seed)
def run_forward(ds_config, atol=1e-2, verbose=False, test_bsz=None): def run_forward(ds_config, seq_len, atol=1e-2, verbose=False, test_bsz=None):
set_seed(123) set_seed(123)
bert_encoder, ds_encoder = create_models(ds_config) bert_encoder, ds_encoder = create_models(ds_config)
bsz = ds_config.batch_size if test_bsz is None else test_bsz bsz = ds_config.batch_size if test_bsz is None else test_bsz
# prepare test data # prepare test data
kwargs = kwargs_fp16 if ds_config.fp16 else kwargs_fp32 kwargs = kwargs_fp16 if ds_config.fp16 else kwargs_fp32
hidden_states = torch.randn(bsz, hidden_states = torch.randn(bsz,
ds_config.max_seq_length, seq_len, #ds_config.max_seq_length,
ds_config.hidden_size, ds_config.hidden_size,
**kwargs) **kwargs)
input_mask = torch.randn(bsz, 1, 1, ds_config.max_seq_length, **kwargs) input_mask = torch.randn(bsz, 1, 1,
seq_len, #ds_config.max_seq_length,
# run baseline **kwargs)
base_results = bert_encoder(hidden_states,
input_mask, # run baseline
output_all_encoded_layers=False, base_results = bert_encoder(hidden_states,
checkpoint_activations=False) input_mask,
output_all_encoded_layers=False,
# run ds checkpoint_activations=False)
ds_results = ds_encoder(hidden_states,
input_mask, # run ds
output_all_encoded_layers=False, ds_results = ds_encoder(hidden_states,
checkpoint_activations=False) input_mask,
output_all_encoded_layers=False,
# check grads checkpoint_activations=False)
check_equal(base_results, ds_results, atol=atol, verbose=verbose)
# check grads
check_equal(base_results, ds_results, atol=atol, verbose=verbose)
# FP16 test cases can only run on devices that support FP16.
@pytest.mark.parametrize('batch_size, hidden_size, seq_len, heads, num_layers, is_preln, use_fp16',
[ # FP16 test cases can only run on devices that support FP16.
(64,1024,128,16,3,True,False), @pytest.mark.parametrize('batch_size, hidden_size, seq_len, heads, num_layers, is_preln, use_fp16',
(64,1024,128,16,3,True,True), [
(8,1024,384,16,3,True,False), (64,1024,128,16,3,True,False),
(8,1024,384,16,3,True,True), (64,1024,128,16,3,True,True),
(8,1024,512,16,3,True,False), (8,1024,384,16,3,True,False),
(8,1024,512,16,3,True,True), (8,1024,384,16,3,True,True),
(64,1024,128,16,3,False,False), (8,1024,384,16,3,True,True),
(64,1024,128,16,3,False,True), (8,1024,120,16,3,True,False),
(8,1024,384,16,3,False,False), (8,1024,120,16,3,True,True),
(8,1024,384,16,3,False,True), (8,1024,512,16,3,True,False),
(8,1024,512,16,3,False,False), (8,1024,512,16,3,True,True),
(8,1024,512,16,3,False,True), (64,1024,56,16,3,False,False),
(8,1536,128,24,3,False,False), (64,1024,56,16,3,False,True),
(8,1536,128,24,3,False,True), (64,1024,24,16,3,False,False),
(8,2048,128,32,3,False,False), (64,1024,24,16,3,False,True),
(8,2048,128,32,3,False,True), (8,1024,384,16,3,False,False),
(8,2560,128,40,3,False,False), (8,1024,384,16,3,False,True),
(8,2560,128,40,3,False,True), (8,1024,512,16,3,False,False),
]) # yapf: disable (8,1024,512,16,3,False,True),
def test_forward(batch_size, (8,1536,128,24,3,False,False),
hidden_size, (8,1536,128,24,3,False,True),
seq_len, (8,2048,128,32,3,False,False),
heads, (8,2048,128,32,3,False,True),
num_layers, (8,2560,128,40,3,False,False),
is_preln, (8,2560,128,40,3,False,True),
use_fp16): ]) # yapf: disable
# Only run fp16 test cases on devices with 7+ capability. def test_forward(batch_size,
major, _ = torch.cuda.get_device_capability() hidden_size,
if major < 7 and use_fp16 is True: seq_len,
return heads,
num_layers,
ds_config = DeepSpeedTransformerConfig() is_preln,
ds_config.layer_id = None use_fp16):
ds_config.batch_size = batch_size # Only run fp16 test cases on devices with 7+ capability.
ds_config.hidden_size = hidden_size major, _ = torch.cuda.get_device_capability()
ds_config.intermediate_size = 4 * hidden_size if major < 7 and use_fp16 is True:
ds_config.max_seq_length = seq_len return
ds_config.heads = heads
ds_config.attn_dropout_ratio = 0.0 ds_config = DeepSpeedTransformerConfig()
ds_config.hidden_dropout_ratio = 0.0 ds_config.layer_id = None
ds_config.num_hidden_layers = num_layers ds_config.batch_size = batch_size
ds_config.pre_layer_norm = is_preln ds_config.hidden_size = hidden_size
ds_config.initializer_range = 0.02 ds_config.max_seq_length = 128 #seq_len
ds_config.fp16 = use_fp16 ds_config.intermediate_size = 4 * hidden_size
ds_config.heads = heads
run_forward(ds_config, atol=2e-2) ds_config.attn_dropout_ratio = 0.0
ds_config.hidden_dropout_ratio = 0.0
ds_config.num_hidden_layers = num_layers
@pytest.mark.parametrize('batch_size, small_bsz, hidden_size, seq_len, heads, num_layers, is_preln, use_fp16', ds_config.pre_layer_norm = is_preln
[ ds_config.initializer_range = 0.02
(8,3,1024,512,16,3,True,False), ds_config.fp16 = use_fp16
(8,7,1024,512,16,3,True,True),
(8,3,1024,512,16,3,False,False), run_forward(ds_config, seq_len, atol=2e-2)
(8,7,1024,512,16,3,False,True),
]) # yapf: disable
def test_forward_with_small_bsz(batch_size, @pytest.mark.parametrize('batch_size, small_bsz, hidden_size, seq_len, heads, num_layers, is_preln, use_fp16',
small_bsz, [
hidden_size, (8,3,1024,512,16,3,True,False),
seq_len, (8,7,1024,512,16,3,True,True),
heads, (8,3,1024,512,16,3,False,False),
num_layers, (8,7,1024,512,16,3,False,True),
is_preln, ]) # yapf: disable
use_fp16): def test_forward_with_small_bsz(batch_size,
# Only run fp16 test cases on devices with 7+ capability. small_bsz,
major, _ = torch.cuda.get_device_capability() hidden_size,
if major < 7 and use_fp16 is True: seq_len,
return heads,
num_layers,
ds_config = DeepSpeedTransformerConfig() is_preln,
ds_config.layer_id = None use_fp16):
ds_config.batch_size = batch_size # Only run fp16 test cases on devices with 7+ capability.
ds_config.hidden_size = hidden_size major, _ = torch.cuda.get_device_capability()
ds_config.intermediate_size = 4 * hidden_size if major < 7 and use_fp16 is True:
ds_config.max_seq_length = seq_len return
ds_config.heads = heads
ds_config.attn_dropout_ratio = 0.0 ds_config = DeepSpeedTransformerConfig()
ds_config.hidden_dropout_ratio = 0.0 ds_config.layer_id = None
ds_config.num_hidden_layers = num_layers ds_config.batch_size = batch_size
ds_config.pre_layer_norm = is_preln ds_config.hidden_size = hidden_size
ds_config.initializer_range = 0.02 ds_config.intermediate_size = 4 * hidden_size
ds_config.fp16 = use_fp16 ds_config.max_seq_length = seq_len
ds_config.heads = heads
run_forward(ds_config, atol=2e-2, test_bsz=small_bsz) ds_config.attn_dropout_ratio = 0.0
ds_config.hidden_dropout_ratio = 0.0
@pytest.mark.parametrize('batch_size, hidden_size, seq_len, heads, num_layers, is_preln, use_fp16', ds_config.num_hidden_layers = num_layers
[ ds_config.pre_layer_norm = is_preln
(64,1024,128,16,3,True,False), ds_config.initializer_range = 0.02
(64,1024,128,16,3,True,True), ds_config.fp16 = use_fp16
(64,1024,128,16,3,False,False),
(64,1024,128,16,3,False,True), run_forward(ds_config, seq_len, atol=2e-2, test_bsz=small_bsz)
]) # yapf: disable
def test_forward_stochastic(batch_size, @pytest.mark.parametrize('batch_size, hidden_size, seq_len, heads, num_layers, is_preln, use_fp16',
hidden_size, [
seq_len, (64,1024,128,16,3,True,False),
heads, (64,1024,128,16,3,True,True),
num_layers, (64,1024,128,16,3,False,False),
is_preln, (64,1024,128,16,3,False,True),
use_fp16): ]) # yapf: disable
# Only run fp16 test cases on devices with 7+ capability. def test_forward_stochastic(batch_size,
major, _ = torch.cuda.get_device_capability() hidden_size,
if major < 7 and use_fp16 is True: seq_len,
return heads,
num_layers,
ds_config = DeepSpeedTransformerConfig() is_preln,
ds_config.layer_id = None use_fp16):
ds_config.batch_size = batch_size # Only run fp16 test cases on devices with 7+ capability.
ds_config.hidden_size = hidden_size major, _ = torch.cuda.get_device_capability()
ds_config.intermediate_size = 4 * hidden_size if major < 7 and use_fp16 is True:
ds_config.max_seq_length = seq_len return
ds_config.heads = heads
ds_config.attn_dropout_ratio = 0.0 ds_config = DeepSpeedTransformerConfig()
ds_config.hidden_dropout_ratio = 0.0 ds_config.layer_id = None
ds_config.num_hidden_layers = num_layers ds_config.batch_size = batch_size
ds_config.pre_layer_norm = is_preln ds_config.hidden_size = hidden_size
ds_config.initializer_range = 0.02 ds_config.intermediate_size = 4 * hidden_size
ds_config.fp16 = use_fp16 ds_config.max_seq_length = seq_len
ds_config.stochastic_mode = True ds_config.heads = heads
ds_config.attn_dropout_ratio = 0.0
run_forward(ds_config, atol=7e-2) ds_config.hidden_dropout_ratio = 0.0
ds_config.num_hidden_layers = num_layers
ds_config.pre_layer_norm = is_preln
ds_config.initializer_range = 0.02
ds_config.fp16 = use_fp16
ds_config.stochastic_mode = True
run_forward(ds_config, seq_len, atol=7e-2)