Unverified Commit f0f2a702 authored by RezaYazdaniAminabadi's avatar RezaYazdaniAminabadi Committed by GitHub
Browse files

support dynamic sequence length in transformer kernels (#424)


Co-authored-by: default avatarConglong Li <conglong.li@gmail.com>
parent 71f7df39
......@@ -29,7 +29,7 @@
for (size_t j = blockIdx.y * blockDim.y + threadIdx.y; j < (m); j += blockDim.y * gridDim.y)
#define DS_CUDA_NUM_THREADS 512
#define DS_MAXIMUM_NUM_BLOCKS 4096
#define DS_MAXIMUM_NUM_BLOCKS 262144
inline int DS_GET_BLOCKS(const int N)
{
......
......@@ -29,7 +29,6 @@ void launch_bias_gelu(const T* input,
T* output,
int intermediate_size,
int batch_size,
int sequence_length,
cudaStream_t stream);
template <typename T>
......@@ -37,7 +36,6 @@ void launch_gelu(const T* input,
T* output,
int intermediate_size,
int batch_size,
int sequence_length,
cudaStream_t stream);
template <typename T>
......@@ -46,7 +44,6 @@ void launch_d_gelu(T* d_output,
const T* bias,
int intermediate_size,
int batch_size,
int sequence_length,
cudaStream_t stream);
// Custom fused bias add with layer normalization
......@@ -57,14 +54,12 @@ void launch_bias_residual_layer_norm(T* vals,
const T* beta,
float epsilon,
int batch_size,
int sequence_length,
int hidden_dim,
cudaStream_t stream,
bool preLayerNorm,
bool training = false,
T* vars = nullptr,
T* means = nullptr,
T* vals_hat = nullptr);
bool training,
T* vars,
T* means);
template <typename T>
void launch_bias_residual_layer_norm(T* vals,
......@@ -73,14 +68,11 @@ void launch_bias_residual_layer_norm(T* vals,
const T* beta,
float epsilon,
int batch_size,
int sequence_length,
int hidden_dim,
cudaStream_t stream,
bool preLayerNorm,
bool training = false,
T* vars = nullptr,
T* vals_hat = nullptr,
bool save_vals = false);
bool training,
T* vars);
template <typename T>
void launch_layerNorm_backward_fused_add(const T* out_grad1,
......@@ -93,7 +85,6 @@ void launch_layerNorm_backward_fused_add(const T* out_grad1,
T* betta_grad,
T* inp_grad,
int batch_size,
int sequence_length,
int hidden_dim,
cudaStream_t stream[2]);
template <typename T>
......@@ -106,7 +97,6 @@ void launch_layerNorm_backward_fused_add(const T* out_grad1,
T* betta_grad,
T* inp_grad,
int batch_size,
int sequence_length,
int hidden_dim,
cudaStream_t stream[2],
bool invertible = false,
......@@ -122,7 +112,6 @@ void launch_layerNorm_backward(const T* out_grad,
T* betta_grad,
T* inp_grad,
int batch_size,
int sequence_length,
int hidden_dim,
cudaStream_t stream[2]);
......@@ -135,7 +124,6 @@ void launch_layerNorm_backward(const T* out_grad,
T* betta_grad,
T* inp_grad,
int batch_size,
int sequence_length,
int hidden_dim,
cudaStream_t stream[2],
bool invertible = false,
......@@ -153,7 +141,6 @@ void launch_layerNorm_backward_nreversible(const T* out_grad,
T* betta_grad,
T* inp_grad,
int batch_size,
int sequence_length,
int hidden_dim,
cudaStream_t stream[2]);
......
......@@ -9,15 +9,13 @@ class Dropout {
public:
struct Config {
float ratio;
uint32_t batch, dim;
uint32_t dim;
bool training;
Config(float r, uint32_t batch, uint32_t dim)
: ratio(r), batch(batch), dim(dim), training(true)
{
}
Config(float r, uint32_t d) : ratio(r), dim(d), training(true) {}
float RATIO() const { return training ? ratio : 0.0; }
inline void SetDim(uint32_t d) { dim = d; }
};
Dropout(const Config& config) : _config(config), _mask(nullptr) {}
......@@ -70,6 +68,8 @@ public:
Config GetConfig() const { return _config; }
inline void SetDimension(uint32_t dim) { _config.SetDim(dim); }
private:
uint8_t* _mask;
Config _config;
......
......@@ -121,11 +121,17 @@ public:
void SetIntermediateBuffers(uint8_t* attn_prob_dropout_mask_ptr,
uint8_t* attn_output_dropout_mask_ptr,
uint8_t* layer_output_dropout_mask_ptr);
uint8_t* layer_output_dropout_mask_ptr,
T* layer_norm_var,
T* layer_norm_mean,
T* attn_layer_norm_var,
T* attn_layer_norm_mean);
inline int GetBatchSize() const { return _batch_size; }
inline int GetNumHeads() const { return _heads; }
inline int GetSeqLength() const { return _seq_length; }
void SetSeqLength(int seq_len, int bsz);
inline int GetHiddenSize() const { return _hidden_size; }
void SetTrainingMode(bool training);
......@@ -150,8 +156,8 @@ private:
// layers
FeedForward<T> _qkv_linear;
FeedForward<T> _attn_out_linear;
Normalize_Layer<T> _norm_layer2;
Normalize_Layer<T> _norm_layer3;
Normalize_Layer<T> _attn_layer_norm;
Normalize_Layer<T> _layer_norm;
Normalize_Layer<T>* _last_normalize;
FeedForward<T> _ff1, _ff2;
Softmax<T> _softmax;
......
......@@ -9,13 +9,8 @@ template <typename T>
class Gelu {
public:
struct Config {
uint32_t batch_size;
uint32_t seq_length;
uint32_t intermediate_size;
Config(uint32_t batch, uint32_t seq, uint32_t inter_size)
: batch_size(batch), seq_length(seq), intermediate_size(inter_size)
{
}
Config(uint32_t inter_size) : intermediate_size(inter_size) {}
};
Gelu(const Config& config) : _config(config) {}
......@@ -28,14 +23,12 @@ public:
T* output,
cudaStream_t stream)
{
launch_bias_gelu<T>(
input_buf, bias, output, _config.intermediate_size, bsz, _config.seq_length, stream);
launch_bias_gelu<T>(input_buf, bias, output, _config.intermediate_size, bsz, stream);
}
void Backward(int bsz, T* d_output, const T* input_buf, const T* bias, cudaStream_t stream)
{
launch_d_gelu<T>(
d_output, input_buf, bias, _config.intermediate_size, bsz, _config.seq_length, stream);
launch_d_gelu<T>(d_output, input_buf, bias, _config.intermediate_size, bsz, stream);
}
private:
......
......@@ -16,57 +16,27 @@ public:
uint32_t seqLength;
uint32_t hiddenDim;
float epsilon;
bool training, save_vals;
bool allocateGrad;
bool training;
bool useMean;
Config(uint32_t batch,
uint32_t seq,
uint32_t h,
bool training,
bool save_vals = true,
bool allocateGrad = true,
bool useMean = true)
Config(uint32_t batch, uint32_t seq, uint32_t h, bool training, bool useMean = true)
: batchSize(batch),
seqLength(seq),
hiddenDim(h),
epsilon(1e-12),
training(training),
save_vals(save_vals),
allocateGrad(allocateGrad),
useMean(useMean)
{
}
};
Normalize_Layer(Config config) : config_(config), vars(nullptr), vals_hat(nullptr)
Normalize_Layer(Config config)
: config_(config), vars(nullptr), means(nullptr), vals_hat(nullptr)
{
if (config_.training) {
cudaMalloc((void**)&vars, config_.batchSize * config_.seqLength * sizeof(T));
if (config_.useMean)
cudaMalloc((void**)&means, config_.batchSize * config_.seqLength * sizeof(T));
if (config_.save_vals)
cudaMalloc((void**)&vals_hat,
config_.batchSize * config_.seqLength * config_.hiddenDim * sizeof(T));
if (config_.allocateGrad)
cudaMalloc((void**)&inp_grad,
config_.batchSize * config_.seqLength * config_.hiddenDim * sizeof(T));
}
}
~Normalize_Layer()
{
if (config_.training) {
cudaFree(vars);
if (config_.useMean) cudaFree(means);
if (config_.save_vals) cudaFree(vals_hat);
if (config_.allocateGrad) cudaFree(inp_grad);
}
}
~Normalize_Layer() {}
void ForwardCheckpoint(int bsz,
void ForwardCheckpoint(int bsz, // batch * seq
T* vals,
const T* residual,
const T* gamma,
......@@ -80,14 +50,12 @@ public:
betta,
config_.epsilon,
bsz,
config_.seqLength,
config_.hiddenDim,
stream,
preLayerNorm,
config_.training,
vars,
means,
vals_hat);
means);
}
void Forward(int bsz,
......@@ -104,14 +72,11 @@ public:
betta,
config_.epsilon,
bsz,
config_.seqLength,
config_.hiddenDim,
stream,
preLayerNorm,
config_.training,
vars,
vals_hat,
config_.save_vals);
vars);
}
void Backward(int bsz,
......@@ -120,7 +85,7 @@ public:
T* gamma_grad,
T* betta_grad,
cudaStream_t stream[2],
T* inp_grad_out = nullptr,
T* inp_grad_out,
const T* norm_in = nullptr)
{
launch_layerNorm_backward(out_grad,
......@@ -130,9 +95,8 @@ public:
gamma,
gamma_grad,
betta_grad,
(config_.allocateGrad ? inp_grad : inp_grad_out),
inp_grad_out,
bsz,
config_.seqLength,
config_.hiddenDim,
stream);
}
......@@ -144,21 +108,20 @@ public:
T* gamma_grad,
T* betta_grad,
cudaStream_t stream[2],
T* inp_grad_out = nullptr,
const T* norm_out = nullptr)
T* inp_grad_out,
const T* norm_out)
{
launch_layerNorm_backward(out_grad,
(config_.save_vals ? vals_hat : norm_out),
norm_out,
vars,
gamma,
gamma_grad,
betta_grad,
(config_.allocateGrad ? inp_grad : inp_grad_out),
inp_grad_out,
bsz,
config_.seqLength,
config_.hiddenDim,
stream,
config_.save_vals,
!config_.useMean,
betta);
}
......@@ -169,7 +132,7 @@ public:
T* gamma_grad,
T* betta_grad,
cudaStream_t stream[2],
T* inp_grad_out = nullptr,
T* inp_grad_out,
const T* norm_in = nullptr)
{
launch_layerNorm_backward_fused_add(out_grad1,
......@@ -180,9 +143,8 @@ public:
gamma,
gamma_grad,
betta_grad,
(config_.allocateGrad ? inp_grad : inp_grad_out),
inp_grad_out,
bsz,
config_.seqLength,
config_.hiddenDim,
stream);
}
......@@ -195,33 +157,41 @@ public:
T* gamma_grad,
T* betta_grad,
cudaStream_t stream[2],
T* inp_grad_out = nullptr,
const T* norm_out = nullptr)
T* inp_grad_out,
const T* norm_out)
{
launch_layerNorm_backward_fused_add(out_grad1,
out_grad2,
(config_.save_vals ? vals_hat : norm_out),
norm_out,
vars,
gamma,
gamma_grad,
betta_grad,
(config_.allocateGrad ? inp_grad : inp_grad_out),
inp_grad_out,
bsz,
config_.seqLength,
config_.hiddenDim,
stream,
config_.save_vals,
!config_.useMean,
betta);
}
inline T* GetInputGrad() const { return inp_grad; }
inline bool UseMean() const { return config_.useMean; }
inline void SetVar(T* variance)
{
if (!variance) { throw std::runtime_error("Normalize variance is null."); }
vars = variance;
}
inline void SetMean(T* mean)
{
if (!mean) { throw std::runtime_error("Normalize mean is null."); }
means = mean;
}
private:
Config config_;
T* vars;
T* means;
T* vals_hat;
T* inp_grad;
};
......@@ -45,13 +45,15 @@ public:
out_grad, soft_out, bsz, config_.heads, config_.seq_length, stream);
}
inline int GetProbDepth() const { return config_.prob_depth; }
inline size_t GetProbDepth() const { return config_.prob_depth; }
inline int GetBatchSize() const { return config_.batchSize; }
inline size_t GetBatchSize() const { return config_.batchSize; }
inline int GetNumHeads() const { return config_.heads; }
inline size_t GetNumHeads() const { return config_.heads; }
inline int GetSeqLength() const { return config_.seq_length; }
inline size_t GetSeqLength() const { return config_.seq_length; }
inline void SetSeqLength(size_t seq_len) { config_.seq_length = seq_len; }
private:
Config config_;
......
......@@ -3,6 +3,7 @@
#include <cuda.h>
#include <cuda_fp16.h>
#include <stdio.h>
#include "context.h"
template <typename T>
class StridedBatchGemm {
......@@ -38,6 +39,12 @@ public:
gemm_algos(algos)
{
}
void SetConfig(int mm, int nn, int kk)
{
m = mm;
n = nn;
k = kk;
}
};
StridedBatchGemm(const Config& config) : _config(config) {}
......@@ -163,6 +170,8 @@ public:
inline const T* GetBufferB() const { return q_buf; }
inline void SetConfig(int m, int n, int k) { _config.SetConfig(m, n, k); }
private:
Config _config;
const T* q_buf;
......
......@@ -34,7 +34,12 @@ int cublas_gemm_ex(cublasHandle_t handle,
algo);
if (status != CUBLAS_STATUS_SUCCESS) {
fprintf(stderr, "!!!! kernel execution error.\n");
fprintf(stderr,
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
m,
n,
k,
(int)status);
return EXIT_FAILURE;
}
return 0;
......@@ -74,7 +79,12 @@ int cublas_gemm_ex(cublasHandle_t handle,
algo);
if (status != CUBLAS_STATUS_SUCCESS) {
fprintf(stderr, "!!!! kernel execution error.\n");
fprintf(stderr,
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
m,
n,
k,
(int)status);
return EXIT_FAILURE;
}
return 0;
......@@ -122,7 +132,12 @@ int cublas_strided_batched_gemm(cublasHandle_t handle,
algo);
if (status != CUBLAS_STATUS_SUCCESS) {
fprintf(stderr, "!!!! kernel execution error.\n");
fprintf(stderr,
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
m,
n,
k,
(int)status);
return EXIT_FAILURE;
}
return 0;
......@@ -170,7 +185,12 @@ int cublas_strided_batched_gemm(cublasHandle_t handle,
algo);
if (status != CUBLAS_STATUS_SUCCESS) {
fprintf(stderr, "!!!! kernel execution error.\n");
fprintf(stderr,
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
m,
n,
k,
(int)status);
return EXIT_FAILURE;
}
......
......@@ -78,19 +78,15 @@ BertTransformerLayer<T>::BertTransformerLayer(int layer_id,
hidden_size,
hidden_size,
gemm_algos[0])),
_norm_layer2(typename Normalize_Layer<T>::Config(batch_size,
_attn_layer_norm(typename Normalize_Layer<T>::Config(batch_size,
seq_length,
hidden_size,
true,
false,
false,
!normalize_invertible)),
_norm_layer3(typename Normalize_Layer<T>::Config(batch_size,
_layer_norm(typename Normalize_Layer<T>::Config(batch_size,
seq_length,
hidden_size,
true,
false,
false,
!normalize_invertible)),
_ff1(typename FeedForward<T>::Config(batch_size * seq_length,
_intermediate_size,
......@@ -101,16 +97,10 @@ BertTransformerLayer<T>::BertTransformerLayer(int layer_id,
_intermediate_size,
gemm_algos[2])),
_softmax(typename Softmax<T>::Config(batch_size, num_heads, seq_length)),
_gelu(typename Gelu<T>::Config(_batch_size, _seq_length, _intermediate_size)),
_attn_prob_dropout(typename Dropout<T>::Config(attn_prob_dropout_ratio,
_batch_size * _heads * _seq_length,
_seq_length)),
_attn_output_dropout(typename Dropout<T>::Config(hidden_output_dropout_ratio,
_batch_size * _seq_length,
_hidden_size)),
_layer_output_dropout(typename Dropout<T>::Config(hidden_output_dropout_ratio,
_batch_size * _seq_length,
_hidden_size)),
_gelu(typename Gelu<T>::Config(_intermediate_size)),
_attn_prob_dropout(typename Dropout<T>::Config(attn_prob_dropout_ratio, _seq_length)),
_attn_output_dropout(typename Dropout<T>::Config(hidden_output_dropout_ratio, _hidden_size)),
_layer_output_dropout(typename Dropout<T>::Config(hidden_output_dropout_ratio, _hidden_size)),
_attn_scores(typename StridedBatchGemm<T>::Config(_batch_size * _heads,
_seq_length,
_seq_length,
......@@ -196,18 +186,18 @@ void BertTransformerLayer<T>::Forward(int bsz,
if (_normalize_invertible) add_res_ptr = buf_1 + 3 * small_buf_size;
if (_attn_dropout_checkpoint) ctx_bufB_ptr = buf_1 + 4 * small_buf_size;
int bsz_seq = bsz * _seq_length;
if (_pre_or_postLayerNorm) {
if (_norm_layer3.UseMean())
_norm_layer3.ForwardCheckpoint(
bsz, inp_norm_ptr, input_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
if (_layer_norm.UseMean())
_layer_norm.ForwardCheckpoint(
bsz_seq, inp_norm_ptr, input_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
else
_norm_layer3.Forward(
bsz, inp_norm_ptr, input_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
_layer_norm.Forward(
bsz_seq, inp_norm_ptr, input_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
}
int bsz_seq = bsz * _seq_length;
if (_pre_or_postLayerNorm)
_qkv_linear.Forward(bsz_seq, inp_norm_ptr, attn_qkvw_ptr, buf_0, _cublasHandle);
else
......@@ -247,19 +237,19 @@ void BertTransformerLayer<T>::Forward(int bsz,
bsz_seq, add_res_ptr, ff1_inp_ptr, input_ptr, attn_ob_ptr, _stream);
if (_pre_or_postLayerNorm) {
if (_norm_layer2.UseMean())
_norm_layer2.ForwardCheckpoint(
bsz, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
if (_attn_layer_norm.UseMean())
_attn_layer_norm.ForwardCheckpoint(
bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
else
_norm_layer2.Forward(
bsz, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
_attn_layer_norm.Forward(
bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
} else {
if (_norm_layer2.UseMean())
_norm_layer2.ForwardCheckpoint(
bsz, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
if (_attn_layer_norm.UseMean())
_attn_layer_norm.ForwardCheckpoint(
bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
else
_norm_layer2.Forward(
bsz, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
_attn_layer_norm.Forward(
bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
}
_ff1.Forward(bsz_seq,
......@@ -268,7 +258,7 @@ void BertTransformerLayer<T>::Forward(int bsz,
(_gelu_checkpoint ? ff2_inp_ptr : gelu_inp_ptr),
_cublasHandle);
_gelu.ForwardWithBiasAdd(bsz,
_gelu.ForwardWithBiasAdd(bsz_seq,
(_gelu_checkpoint ? ff2_inp_ptr : gelu_inp_ptr),
inter_b_ptr,
(_gelu_checkpoint ? ctx_bufB_ptr : ff2_inp_ptr),
......@@ -289,11 +279,12 @@ void BertTransformerLayer<T>::Forward(int bsz,
bsz_seq, inp_norm_ptr, out_ptr, ff1_inp_ptr, output_b_ptr, _stream);
if (!_pre_or_postLayerNorm) {
if (_norm_layer3.UseMean())
_norm_layer3.ForwardCheckpoint(
bsz, out_ptr, inp_norm_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
if (_layer_norm.UseMean())
_layer_norm.ForwardCheckpoint(
bsz_seq, out_ptr, inp_norm_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
else
_norm_layer3.Forward(bsz, out_ptr, inp_norm_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
_layer_norm.Forward(
bsz_seq, out_ptr, inp_norm_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
}
}
......@@ -359,8 +350,8 @@ void BertTransformerLayer<T>::Backward(int bsz,
int bsz_heads = bsz * _heads;
if (!_pre_or_postLayerNorm) {
if (_norm_layer3.UseMean())
_norm_layer3.Backward(bsz,
if (_layer_norm.UseMean())
_layer_norm.Backward(bsz_seq,
grad_output_ptr,
norm_w_ptr,
grad_norm_w_ptr,
......@@ -370,7 +361,7 @@ void BertTransformerLayer<T>::Backward(int bsz,
inp_norm_ptr);
else
_norm_layer3.Backward(bsz,
_layer_norm.Backward(bsz_seq,
grad_output_ptr,
norm_w_ptr,
norm_b_ptr,
......@@ -390,7 +381,8 @@ void BertTransformerLayer<T>::Backward(int bsz,
? buf_0
: (_pre_or_postLayerNorm ? grad_output_ptr : buf_1);
if (_gelu_checkpoint) _gelu.ForwardWithBiasAdd(bsz, ff2_inp_ptr, inter_b_ptr, buf_2, _stream);
if (_gelu_checkpoint)
_gelu.ForwardWithBiasAdd(bsz_seq, ff2_inp_ptr, inter_b_ptr, buf_2, _stream);
_ff2.Backward(bsz_seq,
layer_dropout_buf,
(_gelu_checkpoint ? buf_2 : ff2_inp_ptr),
......@@ -402,7 +394,7 @@ void BertTransformerLayer<T>::Backward(int bsz,
ff2_buf);
_gelu.Backward(
bsz, ff2_buf, (_gelu_checkpoint ? ff2_inp_ptr : gelu_inp_ptr), inter_b_ptr, _stream);
bsz_seq, ff2_buf, (_gelu_checkpoint ? ff2_inp_ptr : gelu_inp_ptr), inter_b_ptr, _stream);
_ff1.Backward(bsz_seq,
ff2_buf,
......@@ -418,8 +410,8 @@ void BertTransformerLayer<T>::Backward(int bsz,
launch_fused_add2<T>(buf_2, buf_3, buf_1, bsz, _seq_length, _hidden_size, _stream);
if (_pre_or_postLayerNorm) {
if (_norm_layer2.UseMean())
_norm_layer2.BackwardFusedAdd(bsz,
if (_attn_layer_norm.UseMean())
_attn_layer_norm.BackwardFusedAdd(bsz_seq,
buf_3,
grad_output_ptr,
attn_nw_ptr,
......@@ -430,7 +422,7 @@ void BertTransformerLayer<T>::Backward(int bsz,
add_res_ptr);
else
_norm_layer2.BackwardFusedAdd(bsz,
_attn_layer_norm.BackwardFusedAdd(bsz_seq,
buf_3,
grad_output_ptr,
attn_nw_ptr,
......@@ -441,8 +433,8 @@ void BertTransformerLayer<T>::Backward(int bsz,
buf_0,
ff1_inp_ptr);
} else {
if (_norm_layer2.UseMean())
_norm_layer2.Backward(bsz,
if (_attn_layer_norm.UseMean())
_attn_layer_norm.Backward(bsz_seq,
buf_2,
attn_nw_ptr,
grad_attn_nw_ptr,
......@@ -452,7 +444,7 @@ void BertTransformerLayer<T>::Backward(int bsz,
add_res_ptr);
else
_norm_layer2.Backward(bsz,
_attn_layer_norm.Backward(bsz_seq,
buf_2,
attn_nw_ptr,
attn_nb_ptr,
......@@ -525,8 +517,8 @@ void BertTransformerLayer<T>::Backward(int bsz,
buf_2);
if (_pre_or_postLayerNorm) {
if (_norm_layer3.UseMean())
_norm_layer3.BackwardFusedAdd(bsz,
if (_layer_norm.UseMean())
_layer_norm.BackwardFusedAdd(bsz_seq,
buf_2,
buf_0,
norm_w_ptr,
......@@ -537,7 +529,7 @@ void BertTransformerLayer<T>::Backward(int bsz,
input_ptr);
else
_norm_layer3.BackwardFusedAdd(bsz,
_layer_norm.BackwardFusedAdd(bsz_seq,
buf_2,
buf_0,
norm_w_ptr,
......@@ -563,11 +555,34 @@ void BertTransformerLayer<T>::SetTrainingMode(bool training)
template <typename T>
void BertTransformerLayer<T>::SetIntermediateBuffers(uint8_t* attn_prob_dropout_mask_ptr,
uint8_t* attn_output_dropout_mask_ptr,
uint8_t* layer_output_dropout_mask_ptr)
uint8_t* layer_output_dropout_mask_ptr,
T* attn_layer_norm_var,
T* attn_layer_norm_mean,
T* layer_norm_var,
T* layer_norm_mean)
{
_attn_prob_dropout.SetMask(attn_prob_dropout_mask_ptr);
_attn_output_dropout.SetMask(attn_output_dropout_mask_ptr);
_layer_output_dropout.SetMask(layer_output_dropout_mask_ptr);
_attn_layer_norm.SetVar(attn_layer_norm_var);
_attn_layer_norm.SetMean(attn_layer_norm_mean);
_layer_norm.SetVar(layer_norm_var);
_layer_norm.SetMean(layer_norm_mean);
}
template <typename T>
void BertTransformerLayer<T>::SetSeqLength(int seq_len, int bsz)
{
_seq_length = seq_len;
_softmax.SetSeqLength(_seq_length);
_attn_prob_dropout.SetDimension(_seq_length);
_attn_scores.SetConfig(_seq_length, _seq_length, _hidden_size / _heads);
_attn_context.SetConfig(_hidden_size / _heads, _seq_length, _seq_length);
Context::Instance().GenWorkSpace(get_workspace_size<T>(
bsz, _seq_length, _hidden_size, _intermediate_size, _heads, _training, _gelu_checkpoint));
}
template <typename T>
......@@ -688,54 +703,61 @@ std::vector<torch::Tensor> ds_transformer_forward(int layer_id,
std::shared_ptr<BertTransformerLayer<T>> layer =
std::static_pointer_cast<BertTransformerLayer<T>>(s_transformer_layers[layer_id]);
int seq_len = layer->GetSeqLength();
if (input.size(1) != seq_len) {
seq_len = input.size(1);
layer->SetSeqLength(seq_len, bsz);
}
auto inp_norm = ((prelayernorm || !normalize_invertible) ? torch::empty_like(input) : output);
auto add_res = (normalize_invertible ? inp_norm : torch::empty_like(input));
auto attn_o_inp = torch::empty_like(input);
auto qkv_tf = torch::empty({(bsz * layer->GetSeqLength()), output_w.size(0) * 3}, options);
auto qkv_tf = torch::empty({(bsz * seq_len), output_w.size(0) * 3}, options);
auto attn_prob_dropout_mask =
torch::empty({(bsz * layer->GetNumHeads() * layer->GetSeqLength()), layer->GetSeqLength()},
uint8_options);
torch::empty({(bsz * layer->GetNumHeads() * seq_len), seq_len}, uint8_options);
auto attn_output_dropout_mask =
torch::empty({(bsz * layer->GetSeqLength()), layer->GetHiddenSize()}, uint8_options);
torch::empty({(bsz * seq_len), layer->GetHiddenSize()}, uint8_options);
auto layer_output_dropout_mask =
torch::empty({(bsz * layer->GetSeqLength()), layer->GetHiddenSize()}, uint8_options);
torch::empty({(bsz * seq_len), layer->GetHiddenSize()}, uint8_options);
auto attn_layer_norm_var = torch::empty({(bsz * seq_len)}, options);
auto attn_layer_norm_mean = torch::empty({(bsz * seq_len)}, options);
auto layer_norm_var = torch::empty({(bsz * seq_len)}, options);
auto layer_norm_mean = torch::empty({(bsz * seq_len)}, options);
T* inp_norm_ptr = (T*)inp_norm.data_ptr();
T* add_res_ptr = (T*)add_res.data_ptr();
T* q_tf_ptr = (T*)qkv_tf.data_ptr();
T* k_tf_ptr =
q_tf_ptr + (bsz * layer->GetSeqLength() * output_w.size(0)); //(T*)k_tf.data_ptr();
T* v_tf_ptr =
k_tf_ptr + (bsz * layer->GetSeqLength() * output_w.size(0)); //(T*)v_tf.data_ptr();
T* k_tf_ptr = q_tf_ptr + (bsz * seq_len * output_w.size(0)); //(T*)k_tf.data_ptr();
T* v_tf_ptr = k_tf_ptr + (bsz * seq_len * output_w.size(0)); //(T*)v_tf.data_ptr();
T* attn_o_inp_ptr = (T*)attn_o_inp.data_ptr();
torch::Tensor ff2_inp =
torch::empty({(bsz * layer->GetSeqLength()), output_w.size(1)}, options);
torch::Tensor ff2_inp = torch::empty({(bsz * seq_len), output_w.size(1)}, options);
torch::Tensor gelu_inp =
(gelu_checkpoint
? ff2_inp
: torch::empty({(bsz * layer->GetSeqLength()), output_w.size(1)}, options));
(gelu_checkpoint ? ff2_inp : torch::empty({(bsz * seq_len), output_w.size(1)}, options));
auto ff1_inp = torch::empty_like(input);
T* ff2_inp_ptr = (T*)ff2_inp.data_ptr();
T* gelu_inp_ptr = (T*)gelu_inp.data_ptr();
T* ff1_inp_ptr = (T*)ff1_inp.data_ptr();
torch::Tensor soft_out = torch::empty(
{(bsz * layer->GetNumHeads() * layer->GetSeqLength()), layer->GetSeqLength()}, options);
torch::Tensor soft_out =
torch::empty({(bsz * layer->GetNumHeads() * seq_len), seq_len}, options);
torch::Tensor ctx_bufB =
(attn_dropout_checkpoint
? soft_out
: torch::empty(
{(bsz * layer->GetNumHeads() * layer->GetSeqLength()), layer->GetSeqLength()},
options));
: torch::empty({(bsz * layer->GetNumHeads() * seq_len), seq_len}, options));
T* soft_out_ptr = (T*)soft_out.data_ptr();
T* ctx_bufB_ptr = (T*)ctx_bufB.data_ptr();
layer->SetTrainingMode(training_mode);
layer->SetIntermediateBuffers((uint8_t*)attn_prob_dropout_mask.data_ptr(),
(uint8_t*)attn_output_dropout_mask.data_ptr(),
(uint8_t*)layer_output_dropout_mask.data_ptr());
(uint8_t*)layer_output_dropout_mask.data_ptr(),
(T*)attn_layer_norm_var.data_ptr(),
(T*)attn_layer_norm_mean.data_ptr(),
(T*)layer_norm_var.data_ptr(),
(T*)layer_norm_mean.data_ptr());
layer->Forward(bsz,
input_ptr,
......@@ -777,7 +799,11 @@ std::vector<torch::Tensor> ds_transformer_forward(int layer_id,
ff2_inp,
attn_prob_dropout_mask,
attn_output_dropout_mask,
layer_output_dropout_mask};
layer_output_dropout_mask,
attn_layer_norm_var,
attn_layer_norm_mean,
layer_norm_var,
layer_norm_mean};
}
template <typename T>
......@@ -796,6 +822,10 @@ std::vector<torch::Tensor> ds_transformer_backward(int layer_id,
const torch::Tensor& attn_prob_dropout_mask,
const torch::Tensor& attn_output_dropout_mask,
const torch::Tensor& layer_output_dropout_mask,
const torch::Tensor& attn_layer_norm_var,
const torch::Tensor& attn_layer_norm_mean,
const torch::Tensor& layer_norm_var,
const torch::Tensor& layer_norm_mean,
const torch::Tensor& input,
const torch::Tensor& input_mask,
const torch::Tensor& attn_qkvw,
......@@ -839,6 +869,7 @@ std::vector<torch::Tensor> ds_transformer_backward(int layer_id,
CHECK_INPUT(norm_b);
int bsz = g_output.size(0);
std::shared_ptr<BertTransformerLayer<T>> layer =
std::static_pointer_cast<BertTransformerLayer<T>>(s_transformer_layers[layer_id]);
......@@ -901,7 +932,11 @@ std::vector<torch::Tensor> ds_transformer_backward(int layer_id,
layer->SetIntermediateBuffers((uint8_t*)attn_prob_dropout_mask.data_ptr(),
(uint8_t*)attn_output_dropout_mask.data_ptr(),
(uint8_t*)layer_output_dropout_mask.data_ptr());
(uint8_t*)layer_output_dropout_mask.data_ptr(),
(T*)attn_layer_norm_var.data_ptr(),
(T*)attn_layer_norm_mean.data_ptr(),
(T*)layer_norm_var.data_ptr(),
(T*)layer_norm_mean.data_ptr());
layer->Backward(bsz,
grad_output_ptr,
......
......@@ -279,13 +279,12 @@ void launch_bias_gelu(const T* input,
T* output,
int intermediate_size,
int batch_size,
int sequence_length,
cudaStream_t stream)
{
int iterations = (intermediate_size + 1023) / 1024;
int threads = intermediate_size / iterations / 4;
dim3 block_dims(threads);
dim3 grid_dims(sequence_length * batch_size);
dim3 grid_dims(batch_size);
fused_bias_gelu<<<grid_dims, block_dims, 0, stream>>>(input, bias, output, intermediate_size);
}
......@@ -295,24 +294,26 @@ void launch_gelu(const T* input,
T* output,
int intermediate_size,
int batch_size,
int sequence_length,
cudaStream_t stream)
{
int iterations = (intermediate_size + 1023) / 1024;
int threads = intermediate_size / iterations / 4;
dim3 block_dims(threads);
dim3 grid_dims(sequence_length * batch_size);
dim3 grid_dims(batch_size);
gelu_kernel<<<grid_dims, block_dims, 0, stream>>>(input, output, intermediate_size);
}
template void
launch_bias_gelu<float>(const float*, const float*, float*, int, int, int, cudaStream_t);
template void
launch_bias_gelu<__half>(const __half*, const __half*, __half*, int, int, int, cudaStream_t);
template void launch_bias_gelu<float>(const float*, const float*, float*, int, int, cudaStream_t);
template void launch_bias_gelu<__half>(const __half*,
const __half*,
__half*,
int,
int,
cudaStream_t);
template void launch_gelu<float>(const float*, float*, int, int, int, cudaStream_t);
template void launch_gelu<__half>(const __half*, __half*, int, int, int, cudaStream_t);
template void launch_gelu<float>(const float*, float*, int, int, cudaStream_t);
template void launch_gelu<__half>(const __half*, __half*, int, int, cudaStream_t);
template <typename T>
void launch_d_gelu(T* d_output,
......@@ -320,17 +321,15 @@ void launch_d_gelu(T* d_output,
const T* bias,
int intermediate_size,
int batch_size,
int sequence_length,
cudaStream_t stream)
{
int iterations = (intermediate_size + 1023) / 1024;
int threads = intermediate_size / iterations / 4;
dim3 block_dims(threads);
dim3 grid_dims(sequence_length * batch_size);
dim3 grid_dims(batch_size);
d_gelu_func<<<grid_dims, block_dims, 0, stream>>>(d_output, input, bias, intermediate_size);
}
template void launch_d_gelu<float>(float*, const float*, const float*, int, int, int, cudaStream_t);
template void
launch_d_gelu<__half>(__half*, const __half*, const __half*, int, int, int, cudaStream_t);
template void launch_d_gelu<float>(float*, const float*, const float*, int, int, cudaStream_t);
template void launch_d_gelu<__half>(__half*, const __half*, const __half*, int, int, cudaStream_t);
......@@ -14,16 +14,19 @@ __global__ void column_sum_reduce(const T* __restrict__ inp,
cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
int idx = blockDim.x * blockIdx.x + threadIdx.x;
int offset = threadIdx.y * width + idx;
int y_stride = width * TILE_DIM;
float localSum = 0;
// Loop across matrix height
if (idx < width) {
int offset = threadIdx.y * width + idx;
for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
localSum += (float)inp[offset];
offset += y_stride;
}
}
tile[threadIdx.x][threadIdx.y] = localSum;
......@@ -40,7 +43,7 @@ __global__ void column_sum_reduce(const T* __restrict__ inp,
if (threadIdx.x == 0) {
int pos = blockIdx.x * TILE_DIM + threadIdx.y;
out[pos] = sum;
if (pos < (rows * width)) out[pos] = sum;
}
}
......@@ -58,10 +61,10 @@ void launch_fuse_transpose_bias_kernel<float>(const float* inp,
int cols,
cudaStream_t stream)
{
assert(rows % TILE_DIM == 0);
assert(cols % TILE_DIM == 0);
// assert(rows % TILE_DIM == 0);
// assert(cols % TILE_DIM == 0);
dim3 grid_dim(cols / TILE_DIM);
dim3 grid_dim((cols - 1) / TILE_DIM + 1);
dim3 block_dim(TILE_DIM, TILE_DIM);
column_sum_reduce<float><<<grid_dim, block_dim, 0, stream>>>(inp, out, rows, cols);
......@@ -74,10 +77,10 @@ void launch_fuse_transpose_bias_kernel<__half>(const __half* inp,
int cols,
cudaStream_t stream)
{
assert(rows % TILE_DIM == 0);
assert(cols % TILE_DIM == 0);
// assert(rows % TILE_DIM == 0);
// assert(cols % TILE_DIM == 0);
dim3 grid_dim(cols / TILE_DIM);
dim3 grid_dim((cols - 1) / TILE_DIM + 1);
dim3 block_dim(TILE_DIM, TILE_DIM);
column_sum_reduce<__half><<<grid_dim, block_dim, 0, stream>>>(inp, out, rows, cols);
......
This diff is collapsed.
#include <math.h>
#include "custom_cuda_layers.h"
#include "general_kernels.h"
......@@ -282,7 +283,7 @@ __global__ void attn_softmax(__half* vals,
}
template <typename T>
void launch_attn_softmax(T*, const T*, int, int, int, cudaStream_t, bool);
void launch_attn_softmax(T*, const T*, int, int, int, cudaStream_t);
template <>
void launch_attn_softmax<float>(float* vals,
......@@ -294,11 +295,10 @@ void launch_attn_softmax<float>(float* vals,
{
const int threads = 128;
int seq_length4 = sequence_length / 4;
int seq2 = sequence_length * seq_length4;
int block_compute_size =
(seq_length4 < threads ? ((threads / seq_length4) * seq_length4) : seq_length4);
dim3 grid_dim(batch_size, heads * seq2 / block_compute_size);
(seq_length4 < threads ? (int)pow(2.0, floor(log2((float)(threads / seq_length4)))) : 1);
dim3 grid_dim(batch_size, heads * sequence_length / block_compute_size);
int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads;
......@@ -330,8 +330,9 @@ void launch_attn_softmax<float>(float* vals,
else {
const int threads = 256;
block_compute_size =
(seq_length4 < threads ? ((threads / seq_length4) * seq_length4) : seq_length4);
dim3 grid_dim(batch_size, heads * seq2 / block_compute_size);
(seq_length4 < threads ? (int)pow(2.0, floor(log2((float)(threads / seq_length4))))
: 1);
dim3 grid_dim(batch_size, heads * sequence_length / block_compute_size);
int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads;
......@@ -362,11 +363,10 @@ void launch_attn_softmax<__half>(__half* vals,
{
const int threads = 128;
int seq_length4 = sequence_length / 4;
int seq2 = sequence_length * seq_length4;
int block_compute_size =
(seq_length4 < threads ? ((threads / seq_length4) * seq_length4) : seq_length4);
dim3 grid_dim(batch_size, heads * seq2 / block_compute_size);
(seq_length4 < threads ? (int)pow(2.0, floor(log2((float)(threads / seq_length4)))) : 1);
dim3 grid_dim(batch_size, heads * sequence_length / block_compute_size);
int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads;
......@@ -399,8 +399,9 @@ void launch_attn_softmax<__half>(__half* vals,
else {
const int threads = 256;
block_compute_size =
(seq_length4 < threads ? ((threads / seq_length4) * seq_length4) : seq_length4);
dim3 grid_dim(batch_size, heads * seq2 / block_compute_size);
(seq_length4 < threads ? (int)pow(2.0, floor(log2((float)(threads / seq_length4))))
: 1);
dim3 grid_dim(batch_size, heads * sequence_length / block_compute_size);
int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads;
......@@ -531,55 +532,41 @@ void launch_attn_softmax_backward_v2(T* out_grad,
int seq_length,
cudaStream_t stream)
{
if ((seq_length % WARP_SIZE) != 0 || seq_length > 2048)
throw std::runtime_error("Invalid sequence length found in softmax backward.");
const int warps_per_block = 4;
dim3 grid_dim(batch_size * heads * seq_length / warps_per_block);
dim3 block_dim(WARP_SIZE, warps_per_block);
switch (seq_length) {
case 32:
if (seq_length <= 32)
softmax_backward_kernel_v2<T, 1>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
break;
case 64:
else if (seq_length <= 64)
softmax_backward_kernel_v2<T, 2>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
break;
case 128:
else if (seq_length <= 128)
softmax_backward_kernel_v2<T, 4>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
break;
case 256:
else if (seq_length <= 256)
softmax_backward_kernel_v2<T, 8>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
break;
case 384:
else if (seq_length <= 384)
softmax_backward_kernel_v2<T, 12>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
break;
case 512:
else if (seq_length <= 512)
softmax_backward_kernel_v2<T, 16>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
break;
case 768:
else if (seq_length <= 768)
softmax_backward_kernel_v2<T, 24>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
break;
case 1024:
else if (seq_length <= 1024)
softmax_backward_kernel_v2<T, 32>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
break;
case 2048:
else if (seq_length <= 2048)
softmax_backward_kernel_v2<T, 64>
<<<grid_dim, block_dim, 0, stream>>>(out_grad, soft_inp, seq_length);
break;
default:
else
throw std::runtime_error(
std::string("Special sequence length found in softmax backward, seq_length: ") +
std::to_string(seq_length));
}
}
template void launch_attn_softmax_backward_v2<__half>(__half* out_grad,
......
......@@ -187,7 +187,11 @@ class DeepSpeedTransformerFunction(Function):
ff2_inp,
attn_prob_dropout_mask,
attn_output_dropout_mask,
layer_output_dropout_mask) = forward_func(config.layer_id,
layer_output_dropout_mask,
attn_layer_norm_var,
attn_layer_norm_mean,
layer_norm_var,
layer_norm_mean) = forward_func(config.layer_id,
input,
input_mask,
attn_qkvw,
......@@ -283,6 +287,9 @@ class DeepSpeedTransformerFunction(Function):
if not config.normalize_invertible:
ctx.add_res = add_res
ctx.attn_layer_norm_mean = attn_layer_norm_mean
ctx.layer_norm_mean = layer_norm_mean
ctx.ff1_inp = ff1_inp
if not config.gelu_checkpoint:
ctx.gelu_inp = gelu_inp
......@@ -291,6 +298,8 @@ class DeepSpeedTransformerFunction(Function):
ctx.attn_prob_dropout_mask = attn_prob_dropout_mask
ctx.attn_output_dropout_mask = attn_output_dropout_mask
ctx.layer_output_dropout_mask = layer_output_dropout_mask
ctx.attn_layer_norm_var = attn_layer_norm_var
ctx.layer_norm_var = layer_norm_var
return output
......@@ -367,6 +376,10 @@ class DeepSpeedTransformerFunction(Function):
ctx.attn_prob_dropout_mask,
ctx.attn_output_dropout_mask,
ctx.layer_output_dropout_mask,
ctx.attn_layer_norm_var,
ctx.attn_layer_norm_mean,
ctx.layer_norm_var,
ctx.layer_norm_mean,
(ctx.inp_norm if (ctx.config.pre_layer_norm
and ctx.config.normalize_invertible) else input),
input_mask,
......
......@@ -256,10 +256,10 @@ def run_backward(ds_config, atol=1e-2, verbose=False):
@pytest.mark.parametrize('batch_size, hidden_size, seq_len, heads, num_layers, is_preln, use_fp16, atol',
[
(3,1024,128,16,24,True,False, 0.05),
(3,1024,128,16,24,True,True, 0.05),
(3,1024,128,16,24,False,False, 0.1),
(3,1024,128,16,24,False,True, 0.2),
(3,1024,120,16,24,True,False, 0.05),
(3,1024,120,16,24,True,True, 0.05),
(3,1024,56,16,24,False,False, 0.1),
(3,1024,56,16,24,False,True, 0.2),
]) # yapf: disable
def test_backward(batch_size,
hidden_size,
......
......@@ -178,7 +178,7 @@ def set_seed(seed):
torch.manual_seed(seed)
def run_forward(ds_config, atol=1e-2, verbose=False, test_bsz=None):
def run_forward(ds_config, seq_len, atol=1e-2, verbose=False, test_bsz=None):
set_seed(123)
bert_encoder, ds_encoder = create_models(ds_config)
......@@ -187,10 +187,12 @@ def run_forward(ds_config, atol=1e-2, verbose=False, test_bsz=None):
# prepare test data
kwargs = kwargs_fp16 if ds_config.fp16 else kwargs_fp32
hidden_states = torch.randn(bsz,
ds_config.max_seq_length,
seq_len, #ds_config.max_seq_length,
ds_config.hidden_size,
**kwargs)
input_mask = torch.randn(bsz, 1, 1, ds_config.max_seq_length, **kwargs)
input_mask = torch.randn(bsz, 1, 1,
seq_len, #ds_config.max_seq_length,
**kwargs)
# run baseline
base_results = bert_encoder(hidden_states,
......@@ -215,10 +217,15 @@ def run_forward(ds_config, atol=1e-2, verbose=False, test_bsz=None):
(64,1024,128,16,3,True,True),
(8,1024,384,16,3,True,False),
(8,1024,384,16,3,True,True),
(8,1024,384,16,3,True,True),
(8,1024,120,16,3,True,False),
(8,1024,120,16,3,True,True),
(8,1024,512,16,3,True,False),
(8,1024,512,16,3,True,True),
(64,1024,128,16,3,False,False),
(64,1024,128,16,3,False,True),
(64,1024,56,16,3,False,False),
(64,1024,56,16,3,False,True),
(64,1024,24,16,3,False,False),
(64,1024,24,16,3,False,True),
(8,1024,384,16,3,False,False),
(8,1024,384,16,3,False,True),
(8,1024,512,16,3,False,False),
......@@ -246,8 +253,8 @@ def test_forward(batch_size,
ds_config.layer_id = None
ds_config.batch_size = batch_size
ds_config.hidden_size = hidden_size
ds_config.max_seq_length = 128 #seq_len
ds_config.intermediate_size = 4 * hidden_size
ds_config.max_seq_length = seq_len
ds_config.heads = heads
ds_config.attn_dropout_ratio = 0.0
ds_config.hidden_dropout_ratio = 0.0
......@@ -256,7 +263,7 @@ def test_forward(batch_size,
ds_config.initializer_range = 0.02
ds_config.fp16 = use_fp16
run_forward(ds_config, atol=2e-2)
run_forward(ds_config, seq_len, atol=2e-2)
@pytest.mark.parametrize('batch_size, small_bsz, hidden_size, seq_len, heads, num_layers, is_preln, use_fp16',
......@@ -293,7 +300,7 @@ def test_forward_with_small_bsz(batch_size,
ds_config.initializer_range = 0.02
ds_config.fp16 = use_fp16
run_forward(ds_config, atol=2e-2, test_bsz=small_bsz)
run_forward(ds_config, seq_len, atol=2e-2, test_bsz=small_bsz)
@pytest.mark.parametrize('batch_size, hidden_size, seq_len, heads, num_layers, is_preln, use_fp16',
[
......@@ -329,4 +336,4 @@ def test_forward_stochastic(batch_size,
ds_config.fp16 = use_fp16
ds_config.stochastic_mode = True
run_forward(ds_config, atol=7e-2)
run_forward(ds_config, seq_len, atol=7e-2)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment