Unverified commit 93a67af8, authored by yuzhongw-nvidia and committed by GitHub
Browse files

Fix memory overhead of linear layer when all gather from sequence parallel (#2125)



* fix memory overhead of all gather from sequence parallel
Signed-off-by: Yuzhong Wang <yuzhongw@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Update transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

* quick fix for the errors with UB buffers
Signed-off-by: Yuzhong Wang <yuzhongw@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Update transformer_engine/pytorch/module/linear.py
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

* Avoid deallocating FP8 scale-invs since they are reused
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Yuzhong Wang <yuzhongw@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
parent 7042d7ae
......@@ -353,8 +353,11 @@ class _LayerNormLinear(torch.autograd.Function):
# Deallocate GEMM input tensor if no longer needed
if not weight.requires_grad and not return_layernorm_output:
ln_out = ln_out_total = None
clear_tensor_data(ln_out, ln_out_total)
ln_out = ln_out_total = None
elif with_input_all_gather and not return_layernorm_output_gathered:
clear_tensor_data(ln_out_total)
ln_out_total = None
# ------------------------------------------------------
# Prepare output tensor
......@@ -891,9 +894,19 @@ class _LayerNormLinear(torch.autograd.Function):
grad_bias = grad_bias_
del grad_bias_
# Deallocate input tensor if permitted
if not ctx.return_layernorm_output:
# Deallocate input tensors if permitted
if not ctx.return_layernorm_output and not ctx.return_layernorm_output_gathered:
# Input tensors have not been exposed externally
clear_tensor_data(ln_out)
elif ctx.ln_out_needs_gather and ctx.return_layernorm_output_gathered:
# Non-gathered input has not been exposed externally
clear_tensor_data(ln_out)
if ctx.ln_out_needs_gather:
# Gathered input is internal
clear_tensor_data(ln_out_total)
if ctx.parallel_mode == "row" and ctx.sequence_parallel:
# Gathered grad output tensor is internal
clear_tensor_data(grad_output)
# Update grad input if overlapping reduce-scatter with wgrad GEMM
if ctx.ub_bulk_wgrad:
......@@ -1169,7 +1182,9 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.return_bias = return_bias
self.apply_bias = self.use_bias and not return_bias
self.return_layernorm_output = return_layernorm_output
self.return_layernorm_output_gathered = return_layernorm_output_gathered
self.return_layernorm_output_gathered = (
return_layernorm_output_gathered if return_layernorm_output else False
)
self.zero_centered_gamma = zero_centered_gamma
self.symmetric_ar_type = symmetric_ar_type
......
......@@ -317,6 +317,13 @@ class _Linear(torch.autograd.Function):
# Finished forward GEMM...
# ------------------------------------------------------
# Deallocate GEMM input tensor if no longer needed
# TODO(yuzhongw, tmoon): Figure out why inputmat_total is not automatically
# deallocated by GC. Manually deallocating is a temporary hack.
if with_input_all_gather_nccl:
clear_tensor_data(inputmat_total)
inputmat_total = None
# ------------------------------------------------------
# Prepare output tensor
# Note: Perform tensor-parallel communication
......@@ -878,9 +885,16 @@ class _Linear(torch.autograd.Function):
grad_bias = grad_bias_
del grad_bias_
# Deallocate input tensor if permitted
# Deallocate tensors if permitted
if ctx.owns_input:
# Input tensor is internal
clear_tensor_data(inputmat_total)
elif ctx.backward_input_needs_gather:
# Gathered input tensor is internal
clear_tensor_data(inputmat_total)
if ctx.parallel_mode == "row" and ctx.sequence_parallel:
# Gathered grad output tensor is internal
clear_tensor_data(grad_output)
# Update grad input if overlapping reduce-scatter with wgrad GEMM
if ctx.ub_bulk_wgrad:
......
......@@ -349,9 +349,14 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
def _transpose_columnwise_data(self):
"""Plainly transpose the columnwise data and scale inv."""
if self._columnwise_data is not None:
# TODO(yuzhongw, tmoon): Figure out why _old_data is not automatically
# deallocated by GC. Manually deallocating is a temporary hack.
_old_data = self._columnwise_data
self._columnwise_data = tex.fp8_transpose(
self._columnwise_data, self._fp8_dtype, out=None
)
_old_data.data = _empty_tensor()
del _old_data
def __repr__(self):
if self._rowwise_data is not None:
......
......@@ -95,8 +95,13 @@ class Float8TensorBase(QuantizedTensorBase):
return instance
def clear(self):
"""Deallocate this tensor's memory. Typically not needed and must be used carefully."""
for t in (self._data, self._transpose, self._scale_inv):
"""Deallocate this tensor's memory. Typically not needed and must be used carefully.
Scale-inv tensor is not deallocated because it's often shared
between multiple FP8 tensors.
"""
for t in (self._data, self._transpose):
if t is not None:
t.data = _empty_tensor()
self._transpose_invalid = True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.