Unverified Commit 8295d7a8 authored by Reza Yazdani, committed by GitHub

Fixing gelu_checkpointing memory issue (#812)

* fixing buffers in transformer kernel when gelu-checkpoint is enabled

* fixing the test issue for other memory optimization flags

* fixing a bug when attn_dropout_checkpoint is enabled
parent 937c5cee
@@ -31,7 +31,8 @@ size_t get_workspace_size(int maxBatchSize,
     if (training) {
         workSpacesize += ((std::max)((size_t(maxBatchSize) * seq_len * intermediate_size),
                                      2 * (size_t(maxBatchSize) * heads * seq_len * seq_len)));
-        if (gelu_checkpoint) workSpacesize += 2 * (size_t(maxBatchSize) * seq_len * hidden_size);
+        if (gelu_checkpoint)
+            workSpacesize += 2 * (size_t(maxBatchSize) * seq_len * intermediate_size);
     }
     return workSpacesize; // * sizeof(T);
 }
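The root of the memory issue: the gelu checkpoint stores the FFN intermediate activation, whose last dimension is intermediate_size (typically 4 * hidden_size in BERT-style models), so sizing it with hidden_size under-allocated the region by roughly 4x. A self-contained sketch of the corrected arithmetic; the dimensions below are illustrative assumptions, not values from the commit:

```cpp
#include <algorithm>
#include <cstdio>

// Sketch of the workspace sizing above (element counts, not bytes).
size_t workspace_elems(size_t bsz, size_t seq, size_t heads,
                       size_t inter, bool gelu_checkpoint)
{
    size_t n = std::max(bsz * seq * inter, 2 * bsz * heads * seq * seq);
    if (gelu_checkpoint) n += 2 * bsz * seq * inter;  // was hidden_size: ~4x too small
    return n;
}

int main()
{
    // Assumed BERT-large-like dims: bsz=8, seq=512, heads=16, inter=4096.
    std::printf("with gelu ckpt: %zu elements\n", workspace_elems(8, 512, 16, 4096, true));
    std::printf("without:        %zu elements\n", workspace_elems(8, 512, 16, 4096, false));
    return 0;
}
```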
@@ -178,9 +179,17 @@ void BertTransformerLayer<T>::Forward(int bsz,
     size_t small_buf_size = bsz * _seq_length * _hidden_size;
     T* buf_0 = workspace;
     T* buf_1 = buf_0 + small_buf_size;
+    T* buf_2 = buf_1;
-    if (_normalize_invertible) add_res_ptr = buf_1 + 3 * small_buf_size;
-    if (_attn_dropout_checkpoint) ctx_bufB_ptr = buf_1 + 4 * small_buf_size;
+    if (_normalize_invertible) {
+        add_res_ptr = buf_1 + 3 * small_buf_size;
+        buf_2 = add_res_ptr;
+    }
+    if (_gelu_checkpoint) buf_2 += small_buf_size;
+    if (_attn_dropout_checkpoint)
+        ctx_bufB_ptr =
+            (_gelu_checkpoint ? (buf_2 + (_intermediate_size / _hidden_size) * small_buf_size)
+                              : (buf_1 + 4 * small_buf_size));

     int bsz_seq = bsz * _seq_length;
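The carving now reserves a full intermediate-sized region for buf_2 before placing the attention-dropout checkpoint, so the two no longer overlap. A toy model of the offset arithmetic using element offsets instead of raw pointers; the names mirror the diff, but all dimension values are assumed for illustration:

```cpp
#include <cassert>
#include <cstddef>

// Toy model of the Forward() workspace carving above; dims are assumptions.
int main()
{
    const size_t bsz = 8, seq = 512, hidden = 1024, inter = 4096;
    const size_t small_buf_size = bsz * seq * hidden;
    const bool normalize_invertible = true, gelu_checkpoint = true;

    size_t buf_0 = 0;
    size_t buf_1 = buf_0 + small_buf_size;
    size_t buf_2 = buf_1;
    if (normalize_invertible) buf_2 = buf_1 + 3 * small_buf_size;  // add_res_ptr
    if (gelu_checkpoint) buf_2 += small_buf_size;

    // buf_2 must hold bsz*seq*inter elements = (inter/hidden) small buffers,
    // so the attention-dropout checkpoint is placed past that whole region.
    size_t ctx_bufB = gelu_checkpoint ? buf_2 + (inter / hidden) * small_buf_size
                                      : buf_1 + 4 * small_buf_size;
    assert(ctx_bufB - buf_2 >= bsz * seq * inter);  // no overlap with gelu ckpt
    return 0;
}
```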
@@ -257,14 +266,11 @@ void BertTransformerLayer<T>::Forward(int bsz,
     _gelu.ForwardWithBiasAdd(bsz_seq,
                              (_gelu_checkpoint ? ff2_inp_ptr : gelu_inp_ptr),
                              inter_b_ptr,
-                             (_gelu_checkpoint ? ctx_bufB_ptr : ff2_inp_ptr),
+                             (_gelu_checkpoint ? buf_2 : ff2_inp_ptr),
                              _stream);

-    _ff2.Forward(bsz_seq,
-                 (_gelu_checkpoint ? ctx_bufB_ptr : ff2_inp_ptr),
-                 output_w_ptr,
-                 out_ptr,
-                 _cublasHandle);
+    _ff2.Forward(
+        bsz_seq, (_gelu_checkpoint ? buf_2 : ff2_inp_ptr), output_w_ptr, out_ptr, _cublasHandle);

     // layer output dropout.
     if (_pre_or_postLayerNorm)
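Writing the checkpointed gelu output to the dedicated buf_2 region instead of ctx_bufB_ptr also removes the aliasing with the attention-dropout checkpoint, which reserves ctx_bufB_ptr for itself; this appears to be the attn_dropout_checkpoint bug mentioned in the commit message.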
@@ -336,7 +342,7 @@ void BertTransformerLayer<T>::Backward(int bsz,
     T* buf_2 = buf_1 + small_buf_size;
     T* buf_3 = buf_2 + small_buf_size;

-    T* ff2_buf = (_gelu_checkpoint ? buf_2 + (bsz * _seq_length * _intermediate_size)
+    T* ff2_buf = (_gelu_checkpoint ? buf_3 + (bsz * _seq_length * _intermediate_size)
                                    : buf_3 + small_buf_size);
     T* ctx_bufB_ptr_recomp = ff2_buf + (_seq_length * _seq_length * bsz * _heads);
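In Backward(), the intermediate-sized recomputation region starts at buf_3 rather than buf_2, so offsetting ff2_buf from buf_2 left it overlapping the last small_buf_size elements of that region whenever _gelu_checkpoint was enabled; basing the offset on buf_3 restores the gap. (This reading is inferred from the pointer arithmetic, not stated in the commit.)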