"tools/git@developer.sourcefind.cn:OpenDAS/vllm-omni.git" did not exist on "c1cacde61d61a2d1d933bf253f44151c8ea135fd"
Unverified commit a9cfbfd3, authored by Kirthi Shankar Sivamani and committed by GitHub

[PyTorch] Improve memory usage in backward of LayerNormLinear and LayerNormMLP (#509)



Improve PyTorch memory usage
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent bfaec644
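The pattern applied throughout this diff: tensors saved for backward (layernorm outputs, FC1/FC2 activations, FP8 transposes) are released as soon as their last GEMM has consumed them, instead of staying alive until backward returns. Below is a minimal sketch of the idea, with a toy autograd function standing in for the real modules; only clear_tensor_data matches the diff, the rest is illustrative, and freeing saved tensors this way assumes backward runs only once.

import torch

def clear_tensor_data(*tensors):
    # Swap in empty storage so the memory is returned to the allocator
    # even while other references to the tensor objects remain.
    for t in tensors:
        t.data = torch.Tensor()

class ToyMLP(torch.autograd.Function):
    """Toy stand-in: out = relu(inp @ w1.t()) @ w2.t()"""

    @staticmethod
    def forward(ctx, inp, w1, w2):
        hidden = torch.relu(inp @ w1.t())  # intermediate activation
        ctx.save_for_backward(inp, w1, w2, hidden)
        return hidden @ w2.t()

    @staticmethod
    def backward(ctx, grad_out):
        inp, w1, w2, hidden = ctx.saved_tensors
        w2_grad = grad_out.t() @ hidden            # uses the saved activation
        d_hidden = (grad_out @ w2) * (hidden > 0)  # last use of the activation
        clear_tensor_data(hidden)                  # free it before the remaining GEMMs
        w1_grad = d_hidden.t() @ inp
        dgrad = d_hidden @ w1
        return dgrad, w1_grad, w2_grad

Without the clear_tensor_data call, `hidden` would occupy memory until backward returns; with it, peak allocation during the two remaining GEMMs drops by one activation-sized buffer.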
@@ -29,6 +29,7 @@ from ..utils import (
get_default_init_method,
cast_if_needed,
assert_dim_for_fp8_exec,
+clear_tensor_data,
)
from ..distributed import (
set_tensor_model_parallel_attributes,
@@ -40,13 +41,13 @@ from ..distributed import (
)
from ..constants import GemmParallelModes, dist_group_type, TE_DType
from ..jit import no_torch_dynamo
from ._common import _apply_normalization
from ..float8_tensor import Float8Tensor
__all__ = ["LayerNormLinear"]
class _LayerNormLinear(torch.autograd.Function):
"""LayerNormLinear semi-top level module
Calls custom cuda extensions.
@@ -355,7 +356,7 @@ class _LayerNormLinear(torch.autograd.Function):
ub_obj_dgrad = get_ub(ctx.ub_name+"_wgrad")
dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output
else:
-dgrad = torch.empty (dgrad_size, dtype=ctx.activation_dtype, device=weight.device)
+dgrad = torch.empty(dgrad_size, dtype=ctx.activation_dtype, device=weight.device)
if ctx.fp8:
fp8_dtype_forward = get_fp8_te_dtype(
@@ -393,6 +394,7 @@ class _LayerNormLinear(torch.autograd.Function):
fp8_meta_tensor = meta_tensor,
D_dtype = out_te_type,
)
+clear_tensor_data(grad_output_c)
else:
# DGRAD: Evaluated unconditionally to feed into Linear backward
_, _, _ = tex.gemm(
@@ -453,6 +455,7 @@ class _LayerNormLinear(torch.autograd.Function):
ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None,
extra_output_tensor=extra_output_tensor
)
+clear_tensor_data(ln_out_total_t, grad_output_t)
else:
ln_out_total_c = tex.cast_from_fp8(
ln_out_total,
@@ -475,6 +478,7 @@ class _LayerNormLinear(torch.autograd.Function):
ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None,
extra_output_tensor=extra_output_tensor
)
+clear_tensor_data(ln_out_total_c)
else:
# WGRAD
wgrad, grad_bias, _ = tex.gemm(
@@ -490,6 +494,7 @@ class _LayerNormLinear(torch.autograd.Function):
ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None,
ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None
)
+clear_tensor_data(ln_out_total)
if ctx.ub_bulk_wgrad:
dgrad = ub_obj_dgrad.get_ubuf_output(0) # Reduce-scatter output
@@ -501,25 +506,24 @@ class _LayerNormLinear(torch.autograd.Function):
handle.wait()
# LayerNorm gradient
-d_ln_out = dgrad.view(inputmat.shape)
+dgrad = dgrad.view(inputmat.shape)
# Residual gradient
if ctx.return_layernorm_output:
-d_ln_out = d_ln_out + grad_outputs[1].view_as(d_ln_out)
+dgrad = dgrad + grad_outputs[1].view_as(dgrad)
if ctx.normalization == "LayerNorm":
-dxmat, dgamma, dbeta = tex.layernorm_bwd(
-d_ln_out, inputmat, mu, rsigma, ln_weight,
+dgrad, dgamma, dbeta = tex.layernorm_bwd(
+dgrad, inputmat, mu, rsigma, ln_weight,
ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma
)
elif ctx.normalization == "RMSNorm":
-dxmat, dgamma = tex.rmsnorm_bwd(
-d_ln_out, inputmat, rsigma, ln_weight,
+dgrad, dgamma = tex.rmsnorm_bwd(
+dgrad, inputmat, rsigma, ln_weight,
ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma
)
dbeta = None
if not ctx.use_bias:
grad_bias = None
@@ -538,7 +542,7 @@ class _LayerNormLinear(torch.autograd.Function):
wgrad = None
return (
-dxmat.view(ctx.inp_shape) if ctx.requires_dgrad else None,
+dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None,
dgamma,
dbeta,
wgrad,
......
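Note that the d_ln_out / dxmat → dgrad renames above are themselves part of the memory fix, not just cosmetic: rebinding a single name drops the last Python reference to the dgrad GEMM output as soon as layernorm_bwd/rmsnorm_bwd has consumed it, instead of keeping both that buffer and the newly allocated norm gradient alive until the function returns. Condensed (illustrative, simplified signatures):

# before: the GEMM output stays referenced via d_ln_out/dgrad until return
d_ln_out = dgrad.view(inputmat.shape)
dxmat, dgamma, dbeta = tex.layernorm_bwd(d_ln_out, inputmat, mu, rsigma, ln_weight)
return dxmat.view(ctx.inp_shape)

# after: rebinding `dgrad` lets the GEMM output be reclaimed right away
dgrad = dgrad.view(inputmat.shape)
dgrad, dgamma, dbeta = tex.layernorm_bwd(dgrad, inputmat, mu, rsigma, ln_weight)
return dgrad.view(ctx.inp_shape)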
@@ -32,6 +32,7 @@ from ..utils import (
get_default_init_method,
cast_if_needed,
assert_dim_for_fp8_exec,
+clear_tensor_data,
)
from ..distributed import (
set_tensor_model_parallel_attributes,
@@ -276,6 +277,8 @@ class _LayerNormMLP(torch.autograd.Function):
ub=ub_obj_lnout if ub_overlap_ag else None,
extra_output_tensor=ln_out if ub_overlap_ag else None,
)
+if not is_grad_enabled:
+clear_tensor_data(ln_out_total)
gelu_out = activation_func(
fc1_out,
@@ -283,6 +286,8 @@ class _LayerNormMLP(torch.autograd.Function):
tex.FP8FwdTensors.GEMM2_INPUT,
fp8_dtype_forward,
)
+if not is_grad_enabled:
+clear_tensor_data(fc1_out)
fc2_out_index, fc2_meta_tensor, fc2_te_type, out_type = (
None, None, None, activation_dtype)
@@ -329,6 +334,8 @@ class _LayerNormMLP(torch.autograd.Function):
fp8_meta_tensor = fc2_meta_tensor,
D_dtype = fc2_te_type,
)
+if not is_grad_enabled:
+clear_tensor_data(gelu_out)
else:
# Cast for native AMP
fc1_weight = cast_if_needed(fc1_weight, activation_dtype)
@@ -360,6 +367,8 @@ class _LayerNormMLP(torch.autograd.Function):
ub=ub_obj_lnout if ub_split_ag else None,
extra_output_tensor=ln_out if ub_split_ag else None,
)
+if not is_grad_enabled:
+clear_tensor_data(ln_out_total)
if bias_gelu_nvfusion:
fc1_out, _, _ = fc1_outputs
@@ -373,6 +382,8 @@ class _LayerNormMLP(torch.autograd.Function):
None,
tex.FP8FwdTensors.GEMM2_INPUT,
TE_DType[fc1_out.dtype])
+if not is_grad_enabled:
+clear_tensor_data(fc1_out)
if fp8_calibration:
# amax of fc2 input
@@ -405,6 +416,8 @@ class _LayerNormMLP(torch.autograd.Function):
ub=ub_obj_fc2out if ub_split_rs else None,
extra_output_tensor=rs_out if ub_split_rs else None,
)
+if not is_grad_enabled:
+clear_tensor_data(gelu_out)
if is_grad_enabled:
ctx.save_for_backward(
@@ -519,6 +532,7 @@ class _LayerNormMLP(torch.autograd.Function):
) = TransformerEngineBaseModule.grad_output_preprocess(
ctx, grad_outputs[0], True
)
if ctx.ub_bulk_wgrad:
tp_world_size = get_distributed_world_size(ctx.tp_group)
if tp_world_size == 1:
@@ -571,10 +585,13 @@ class _LayerNormMLP(torch.autograd.Function):
)
if ub_overlap_ag:
grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward)
+clear_tensor_data(grad_output_c)
# FC2 WGRAD
if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
if fc2_weight.requires_grad:
gelu_out_t = tex.fp8_transpose(gelu_out, fp8_dtype_forward)
+clear_tensor_data(gelu_out)
fc2_wgrad, _ = tex.fp8_gemm(
gelu_out_t,
fwd_scale_inverses,
@@ -592,6 +609,7 @@ class _LayerNormMLP(torch.autograd.Function):
else None,
use_split_accumulator=_2X_ACC_WGRAD,
)
+clear_tensor_data(gelu_out_t, grad_output_t)
if ctx.activation == 'gelu':
fc1_bias_grad, dgelu, dgelu_t = tex.fp8_cast_transpose_bgrad_dgelu_fused(
@@ -610,6 +628,7 @@ class _LayerNormMLP(torch.autograd.Function):
tex.FP8BwdTensors.GRAD_OUTPUT2,
fp8_dtype_backward,
)
+clear_tensor_data(fc1_out)
else:
if fc2_weight.requires_grad:
gelu_out_c = tex.cast_from_fp8(
@@ -619,6 +638,7 @@ class _LayerNormMLP(torch.autograd.Function):
fp8_dtype_forward,
TE_DType[ctx.activation_dtype],
)
+clear_tensor_data(gelu_out)
fc2_wgrad, _, _ = tex.gemm(
gelu_out_c,
grad_output,
@@ -632,6 +652,7 @@ class _LayerNormMLP(torch.autograd.Function):
if ctx.fuse_wgrad_accumulation
else None,
)
+clear_tensor_data(gelu_out_c)
if ctx.activation == 'gelu':
fc1_bias_grad, dgelu_no_fp8 = bgrad_dgelu_fused(
@@ -642,6 +663,7 @@ class _LayerNormMLP(torch.autograd.Function):
fc1_out,
TE_DType[fc2_dgrad.dtype])
fc1_bias_grad = dgelu_no_fp8.sum(dim=0)
+clear_tensor_data(fc1_out)
dgelu = tex.cast_to_fp8(
dgelu_no_fp8,
@@ -716,21 +738,24 @@ class _LayerNormMLP(torch.autograd.Function):
accumulate=accumulate_wgrad_into_param_main_grad,
out=fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
)
+clear_tensor_data(gelu_out)
if ctx.bias_gelu_nvfusion and ctx.activation == 'gelu':
-fc1_bias_grad, dgelu = bgrad_dgelu_fused(fc2_dgrad, fc1_out, fc1_bias)
+fc1_bias_grad, fc2_dgrad = bgrad_dgelu_fused(fc2_dgrad, fc1_out, fc1_bias)
else:
-if ctx.activation == 'gelu':
-dgelu = fc2_dgrad
-else:
-dgelu = activation_func(fc2_dgrad,
-fc1_out,
-TE_DType[fc2_dgrad.dtype])
+if ctx.activation != 'gelu':
+fc2_dgrad = activation_func(fc2_dgrad,
+fc1_out,
+TE_DType[fc2_dgrad.dtype])
# For non-fp8 execution, FC1 bias gradient is fused with FC1 wgrad GEMM
# and will not be calculated in case wgrad is not required.
if not fc1_weight.requires_grad:
-fc1_bias_grad = dgelu.sum(dim=0)
+fc1_bias_grad = fc2_dgrad.sum(dim=0)
+# Overwrite data. Deleting the tensor does not release underlying memory.
+clear_tensor_data(fc1_out)
+dgelu = fc2_dgrad
fc1_dgrad_size = list(dgelu.size())
fc1_dgrad_size[1] = fc1_weight.size(1)
@@ -741,6 +766,7 @@ class _LayerNormMLP(torch.autograd.Function):
fc1_dgrad = torch.empty(
fc1_dgrad_size, dtype=ctx.activation_dtype, device=fc1_weight.device
)
+# FC1 DGRAD: Unconditional
_ = tex.gemm(
fc1_weight,
@@ -802,6 +828,7 @@ class _LayerNormMLP(torch.autograd.Function):
ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None,
extra_output_tensor=extra_output_tensor,
)
+clear_tensor_data(ln_out_total_t, dgelu_t)
else:
ln_out_total_c = tex.cast_from_fp8(
ln_out_total,
@@ -826,6 +853,7 @@ class _LayerNormMLP(torch.autograd.Function):
ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None,
extra_output_tensor=extra_output_tensor,
)
+clear_tensor_data(ln_out_total_c, dgelu_no_fp8)
else:
# FC1 WGRAD
fc1_wgrad_outputs = tex.gemm(
@@ -841,6 +869,7 @@ class _LayerNormMLP(torch.autograd.Function):
ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None,
ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None
)
+clear_tensor_data(ln_out_total, dgelu)
if ctx.bias_gelu_nvfusion:
fc1_wgrad, _, _ = fc1_wgrad_outputs
@@ -857,20 +886,20 @@ class _LayerNormMLP(torch.autograd.Function):
handle.wait()
# LayerNorm gradient
-d_ln_out = fc1_dgrad.view(inputmat.shape)
+dgrad = fc1_dgrad.view(inputmat.shape)
# Residual gradient
if ctx.return_layernorm_output:
-d_ln_out = d_ln_out + grad_outputs[1].view_as(d_ln_out)
+dgrad = dgrad + grad_outputs[1].view_as(dgrad)
if ctx.normalization == "LayerNorm":
-dxmat, dgamma, dbeta = tex.layernorm_bwd(
-d_ln_out, inputmat, mu, rsigma, ln_weight,
+dgrad, dgamma, dbeta = tex.layernorm_bwd(
+dgrad, inputmat, mu, rsigma, ln_weight,
ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma
)
elif ctx.normalization == "RMSNorm":
-dxmat, dgamma = tex.rmsnorm_bwd(
-d_ln_out, inputmat, rsigma, ln_weight,
+dgrad, dgamma = tex.rmsnorm_bwd(
+dgrad, inputmat, rsigma, ln_weight,
ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma
)
dbeta = None
@@ -904,7 +933,7 @@ class _LayerNormMLP(torch.autograd.Function):
fc2_wgrad = None
return (
-dxmat.view(ctx.inp_shape) if ctx.requires_dgrad else None,
+dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None,
dgamma,
dbeta,
fc1_wgrad,
......
@@ -26,6 +26,7 @@ from ..utils import (
get_default_init_method,
cast_if_needed,
assert_dim_for_fp8_exec,
+clear_tensor_data,
)
from ..distributed import (
set_tensor_model_parallel_attributes,
@@ -431,6 +432,7 @@ class _Linear(torch.autograd.Function):
out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
use_split_accumulator=_2X_ACC_WGRAD,
)
+clear_tensor_data(inputmat_t_total)
else:
wgrad, _, _ = gemm(
inputmat_total,
@@ -442,6 +444,7 @@ class _Linear(torch.autograd.Function):
accumulate=accumulate_wgrad_into_param_main_grad,
out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
)
+clear_tensor_data(inputmat_total)
else:
# WGRAD
wgrad, grad_bias, _ = gemm(
@@ -455,6 +458,7 @@ class _Linear(torch.autograd.Function):
accumulate=accumulate_wgrad_into_param_main_grad,
out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
)
+clear_tensor_data(inputmat_total)
# Column Parallel Linear
if ctx.parallel_mode == "column" and ctx.tensor_parallel and handle is not None:
......
@@ -8,6 +8,18 @@ from typing import Any, Callable, Optional, Tuple
import torch
+def clear_tensor_data(*tensors: torch.Tensor) -> None:
+    """
+    Trick to deallocate tensor memory when `del` does not release the
+    tensor because other references (e.g. saved activations) keep it alive.
+    Must be used carefully.
+    """
+    for t in tensors:
+        # Swapping in empty storage frees the memory immediately,
+        # even while other references to the tensor object remain.
+        t.data = torch.Tensor()
+        del t
def get_device_compute_capability() -> Tuple[int, int]:
"""CUDA compute capability of current GPU"""
props = torch.cuda.get_device_properties(torch.cuda.current_device())
......
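Why the `.data` swap instead of plain `del`: deleting a name only removes one reference, and a tensor that is still referenced elsewhere (saved activations, aliases) keeps its storage alive. Replacing the storage frees the memory for every holder of the tensor object. A standalone illustration (hypothetical snippet, not part of the commit):

import torch

a = torch.randn(1024, 1024)
b = a                    # second reference to the same storage
del a                    # frees nothing: `b` still holds the storage
b.data = torch.Tensor()  # storage released; all references now see an empty tensor
print(b.shape)           # torch.Size([0])

Because every holder now sees an empty tensor, this is only safe after the tensor's last consumer has run, hence the "must be used carefully" note in the docstring.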