"backend/apps/vscode:/vscode.git/clone" did not exist on "984dbf13abb15a933fc044e6d11a31cf46860823"
Unverified Commit a9cfbfd3 authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

[PyTorch] Improve memory usage in backward of LayerNormLinear and LayerNormMLP (#509)



Improve PyTorch memory usage
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent bfaec644
...@@ -29,6 +29,7 @@ from ..utils import ( ...@@ -29,6 +29,7 @@ from ..utils import (
get_default_init_method, get_default_init_method,
cast_if_needed, cast_if_needed,
assert_dim_for_fp8_exec, assert_dim_for_fp8_exec,
clear_tensor_data,
) )
from ..distributed import ( from ..distributed import (
set_tensor_model_parallel_attributes, set_tensor_model_parallel_attributes,
...@@ -40,13 +41,13 @@ from ..distributed import ( ...@@ -40,13 +41,13 @@ from ..distributed import (
) )
from ..constants import GemmParallelModes, dist_group_type, TE_DType from ..constants import GemmParallelModes, dist_group_type, TE_DType
from ..jit import no_torch_dynamo from ..jit import no_torch_dynamo
from ._common import _apply_normalization from ._common import _apply_normalization
from ..float8_tensor import Float8Tensor from ..float8_tensor import Float8Tensor
__all__ = ["LayerNormLinear"] __all__ = ["LayerNormLinear"]
class _LayerNormLinear(torch.autograd.Function): class _LayerNormLinear(torch.autograd.Function):
"""LayerNormLinear semi-top level module """LayerNormLinear semi-top level module
Calls custom cuda extensions. Calls custom cuda extensions.
...@@ -355,7 +356,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -355,7 +356,7 @@ class _LayerNormLinear(torch.autograd.Function):
ub_obj_dgrad = get_ub(ctx.ub_name+"_wgrad") ub_obj_dgrad = get_ub(ctx.ub_name+"_wgrad")
dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output
else: else:
dgrad = torch.empty (dgrad_size, dtype=ctx.activation_dtype, device=weight.device) dgrad = torch.empty(dgrad_size, dtype=ctx.activation_dtype, device=weight.device)
if ctx.fp8: if ctx.fp8:
fp8_dtype_forward = get_fp8_te_dtype( fp8_dtype_forward = get_fp8_te_dtype(
...@@ -393,6 +394,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -393,6 +394,7 @@ class _LayerNormLinear(torch.autograd.Function):
fp8_meta_tensor = meta_tensor, fp8_meta_tensor = meta_tensor,
D_dtype = out_te_type, D_dtype = out_te_type,
) )
clear_tensor_data(grad_output_c)
else: else:
# DGRAD: Evaluated unconditionally to feed into Linear backward # DGRAD: Evaluated unconditionally to feed into Linear backward
_, _, _ = tex.gemm( _, _, _ = tex.gemm(
...@@ -453,6 +455,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -453,6 +455,7 @@ class _LayerNormLinear(torch.autograd.Function):
ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None, ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None,
extra_output_tensor=extra_output_tensor extra_output_tensor=extra_output_tensor
) )
clear_tensor_data(ln_out_total_t, grad_output_t)
else: else:
ln_out_total_c = tex.cast_from_fp8( ln_out_total_c = tex.cast_from_fp8(
ln_out_total, ln_out_total,
...@@ -475,6 +478,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -475,6 +478,7 @@ class _LayerNormLinear(torch.autograd.Function):
ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None, ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None,
extra_output_tensor=extra_output_tensor extra_output_tensor=extra_output_tensor
) )
clear_tensor_data(ln_out_total_c)
else: else:
# WGRAD # WGRAD
wgrad, grad_bias, _ = tex.gemm( wgrad, grad_bias, _ = tex.gemm(
...@@ -490,6 +494,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -490,6 +494,7 @@ class _LayerNormLinear(torch.autograd.Function):
ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None, ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None,
ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None
) )
clear_tensor_data(ln_out_total)
if ctx.ub_bulk_wgrad: if ctx.ub_bulk_wgrad:
dgrad = ub_obj_dgrad.get_ubuf_output(0) # Reduce-scatter output dgrad = ub_obj_dgrad.get_ubuf_output(0) # Reduce-scatter output
...@@ -501,25 +506,24 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -501,25 +506,24 @@ class _LayerNormLinear(torch.autograd.Function):
handle.wait() handle.wait()
# LayerNorm gradient # LayerNorm gradient
d_ln_out = dgrad.view(inputmat.shape) dgrad = dgrad.view(inputmat.shape)
# Residual gradient # Residual gradient
if ctx.return_layernorm_output: if ctx.return_layernorm_output:
d_ln_out = d_ln_out + grad_outputs[1].view_as(d_ln_out) dgrad = dgrad + grad_outputs[1].view_as(dgrad)
if ctx.normalization == "LayerNorm": if ctx.normalization == "LayerNorm":
dxmat, dgamma, dbeta = tex.layernorm_bwd( dgrad, dgamma, dbeta = tex.layernorm_bwd(
d_ln_out, inputmat, mu, rsigma, ln_weight, dgrad, inputmat, mu, rsigma, ln_weight,
ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma
) )
elif ctx.normalization == "RMSNorm": elif ctx.normalization == "RMSNorm":
dxmat, dgamma = tex.rmsnorm_bwd( dgrad, dgamma = tex.rmsnorm_bwd(
d_ln_out, inputmat, rsigma, ln_weight, dgrad, inputmat, rsigma, ln_weight,
ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma
) )
dbeta = None dbeta = None
if not ctx.use_bias: if not ctx.use_bias:
grad_bias = None grad_bias = None
...@@ -538,7 +542,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -538,7 +542,7 @@ class _LayerNormLinear(torch.autograd.Function):
wgrad = None wgrad = None
return ( return (
dxmat.view(ctx.inp_shape) if ctx.requires_dgrad else None, dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None,
dgamma, dgamma,
dbeta, dbeta,
wgrad, wgrad,
......
...@@ -32,6 +32,7 @@ from ..utils import ( ...@@ -32,6 +32,7 @@ from ..utils import (
get_default_init_method, get_default_init_method,
cast_if_needed, cast_if_needed,
assert_dim_for_fp8_exec, assert_dim_for_fp8_exec,
clear_tensor_data,
) )
from ..distributed import ( from ..distributed import (
set_tensor_model_parallel_attributes, set_tensor_model_parallel_attributes,
...@@ -276,6 +277,8 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -276,6 +277,8 @@ class _LayerNormMLP(torch.autograd.Function):
ub=ub_obj_lnout if ub_overlap_ag else None, ub=ub_obj_lnout if ub_overlap_ag else None,
extra_output_tensor=ln_out if ub_overlap_ag else None, extra_output_tensor=ln_out if ub_overlap_ag else None,
) )
if not is_grad_enabled:
clear_tensor_data(ln_out_total)
gelu_out = activation_func( gelu_out = activation_func(
fc1_out, fc1_out,
...@@ -283,6 +286,8 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -283,6 +286,8 @@ class _LayerNormMLP(torch.autograd.Function):
tex.FP8FwdTensors.GEMM2_INPUT, tex.FP8FwdTensors.GEMM2_INPUT,
fp8_dtype_forward, fp8_dtype_forward,
) )
if not is_grad_enabled:
clear_tensor_data(fc1_out)
fc2_out_index, fc2_meta_tensor, fc2_te_type, out_type = ( fc2_out_index, fc2_meta_tensor, fc2_te_type, out_type = (
None, None, None, activation_dtype) None, None, None, activation_dtype)
...@@ -329,6 +334,8 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -329,6 +334,8 @@ class _LayerNormMLP(torch.autograd.Function):
fp8_meta_tensor = fc2_meta_tensor, fp8_meta_tensor = fc2_meta_tensor,
D_dtype = fc2_te_type, D_dtype = fc2_te_type,
) )
if not is_grad_enabled:
clear_tensor_data(gelu_out)
else: else:
# Cast for native AMP # Cast for native AMP
fc1_weight = cast_if_needed(fc1_weight, activation_dtype) fc1_weight = cast_if_needed(fc1_weight, activation_dtype)
...@@ -360,6 +367,8 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -360,6 +367,8 @@ class _LayerNormMLP(torch.autograd.Function):
ub=ub_obj_lnout if ub_split_ag else None, ub=ub_obj_lnout if ub_split_ag else None,
extra_output_tensor=ln_out if ub_split_ag else None, extra_output_tensor=ln_out if ub_split_ag else None,
) )
if not is_grad_enabled:
clear_tensor_data(ln_out_total)
if bias_gelu_nvfusion: if bias_gelu_nvfusion:
fc1_out, _, _ = fc1_outputs fc1_out, _, _ = fc1_outputs
...@@ -373,6 +382,8 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -373,6 +382,8 @@ class _LayerNormMLP(torch.autograd.Function):
None, None,
tex.FP8FwdTensors.GEMM2_INPUT, tex.FP8FwdTensors.GEMM2_INPUT,
TE_DType[fc1_out.dtype]) TE_DType[fc1_out.dtype])
if not is_grad_enabled:
clear_tensor_data(fc1_out)
if fp8_calibration: if fp8_calibration:
# amax of fc2 input # amax of fc2 input
...@@ -405,6 +416,8 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -405,6 +416,8 @@ class _LayerNormMLP(torch.autograd.Function):
ub=ub_obj_fc2out if ub_split_rs else None, ub=ub_obj_fc2out if ub_split_rs else None,
extra_output_tensor=rs_out if ub_split_rs else None, extra_output_tensor=rs_out if ub_split_rs else None,
) )
if not is_grad_enabled:
clear_tensor_data(gelu_out)
if is_grad_enabled: if is_grad_enabled:
ctx.save_for_backward( ctx.save_for_backward(
...@@ -519,6 +532,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -519,6 +532,7 @@ class _LayerNormMLP(torch.autograd.Function):
) = TransformerEngineBaseModule.grad_output_preprocess( ) = TransformerEngineBaseModule.grad_output_preprocess(
ctx, grad_outputs[0], True ctx, grad_outputs[0], True
) )
if ctx.ub_bulk_wgrad: if ctx.ub_bulk_wgrad:
tp_world_size = get_distributed_world_size(ctx.tp_group) tp_world_size = get_distributed_world_size(ctx.tp_group)
if tp_world_size == 1: if tp_world_size == 1:
...@@ -571,10 +585,13 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -571,10 +585,13 @@ class _LayerNormMLP(torch.autograd.Function):
) )
if ub_overlap_ag: if ub_overlap_ag:
grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward) grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward)
clear_tensor_data(grad_output_c)
# FC2 WGRAD # FC2 WGRAD
if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad: if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
if fc2_weight.requires_grad: if fc2_weight.requires_grad:
gelu_out_t = tex.fp8_transpose(gelu_out, fp8_dtype_forward) gelu_out_t = tex.fp8_transpose(gelu_out, fp8_dtype_forward)
clear_tensor_data(gelu_out)
fc2_wgrad, _ = tex.fp8_gemm( fc2_wgrad, _ = tex.fp8_gemm(
gelu_out_t, gelu_out_t,
fwd_scale_inverses, fwd_scale_inverses,
...@@ -592,6 +609,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -592,6 +609,7 @@ class _LayerNormMLP(torch.autograd.Function):
else None, else None,
use_split_accumulator=_2X_ACC_WGRAD, use_split_accumulator=_2X_ACC_WGRAD,
) )
clear_tensor_data(gelu_out_t, grad_output_t)
if ctx.activation == 'gelu': if ctx.activation == 'gelu':
fc1_bias_grad, dgelu, dgelu_t = tex.fp8_cast_transpose_bgrad_dgelu_fused( fc1_bias_grad, dgelu, dgelu_t = tex.fp8_cast_transpose_bgrad_dgelu_fused(
...@@ -610,6 +628,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -610,6 +628,7 @@ class _LayerNormMLP(torch.autograd.Function):
tex.FP8BwdTensors.GRAD_OUTPUT2, tex.FP8BwdTensors.GRAD_OUTPUT2,
fp8_dtype_backward, fp8_dtype_backward,
) )
clear_tensor_data(fc1_out)
else: else:
if fc2_weight.requires_grad: if fc2_weight.requires_grad:
gelu_out_c = tex.cast_from_fp8( gelu_out_c = tex.cast_from_fp8(
...@@ -619,6 +638,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -619,6 +638,7 @@ class _LayerNormMLP(torch.autograd.Function):
fp8_dtype_forward, fp8_dtype_forward,
TE_DType[ctx.activation_dtype], TE_DType[ctx.activation_dtype],
) )
clear_tensor_data(gelu_out)
fc2_wgrad, _, _ = tex.gemm( fc2_wgrad, _, _ = tex.gemm(
gelu_out_c, gelu_out_c,
grad_output, grad_output,
...@@ -632,6 +652,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -632,6 +652,7 @@ class _LayerNormMLP(torch.autograd.Function):
if ctx.fuse_wgrad_accumulation if ctx.fuse_wgrad_accumulation
else None, else None,
) )
clear_tensor_data(gelu_out_c)
if ctx.activation == 'gelu': if ctx.activation == 'gelu':
fc1_bias_grad, dgelu_no_fp8 = bgrad_dgelu_fused( fc1_bias_grad, dgelu_no_fp8 = bgrad_dgelu_fused(
...@@ -642,6 +663,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -642,6 +663,7 @@ class _LayerNormMLP(torch.autograd.Function):
fc1_out, fc1_out,
TE_DType[fc2_dgrad.dtype]) TE_DType[fc2_dgrad.dtype])
fc1_bias_grad = dgelu_no_fp8.sum(dim=0) fc1_bias_grad = dgelu_no_fp8.sum(dim=0)
clear_tensor_data(fc1_out)
dgelu = tex.cast_to_fp8( dgelu = tex.cast_to_fp8(
dgelu_no_fp8, dgelu_no_fp8,
...@@ -716,21 +738,24 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -716,21 +738,24 @@ class _LayerNormMLP(torch.autograd.Function):
accumulate=accumulate_wgrad_into_param_main_grad, accumulate=accumulate_wgrad_into_param_main_grad,
out=fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None, out=fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
) )
clear_tensor_data(gelu_out)
if ctx.bias_gelu_nvfusion and ctx.activation == 'gelu': if ctx.bias_gelu_nvfusion and ctx.activation == 'gelu':
fc1_bias_grad, dgelu = bgrad_dgelu_fused(fc2_dgrad, fc1_out, fc1_bias) fc1_bias_grad, fc2_dgrad = bgrad_dgelu_fused(fc2_dgrad, fc1_out, fc1_bias)
else: else:
if ctx.activation == 'gelu': if ctx.activation != 'gelu':
dgelu = fc2_dgrad fc2_dgrad = activation_func(fc2_dgrad,
else: fc1_out,
dgelu = activation_func(fc2_dgrad, TE_DType[fc2_dgrad.dtype])
fc1_out,
TE_DType[fc2_dgrad.dtype])
# For non-fp8 execution, FC1 bias gradient is fused with FC1 wgrad GEMM # For non-fp8 execution, FC1 bias gradient is fused with FC1 wgrad GEMM
# and will not be calculated in case wgrad is not required. # and will not be calculated in case wgrad is not required.
if not fc1_weight.requires_grad: if not fc1_weight.requires_grad:
fc1_bias_grad = dgelu.sum(dim=0) fc1_bias_grad = fc2_dgrad.sum(dim=0)
# Overwrite data. Deleting the tensor does not release underlying memory.
clear_tensor_data(fc1_out)
dgelu = fc2_dgrad
fc1_dgrad_size = list(dgelu.size()) fc1_dgrad_size = list(dgelu.size())
fc1_dgrad_size[1] = fc1_weight.size(1) fc1_dgrad_size[1] = fc1_weight.size(1)
...@@ -741,6 +766,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -741,6 +766,7 @@ class _LayerNormMLP(torch.autograd.Function):
fc1_dgrad = torch.empty( fc1_dgrad = torch.empty(
fc1_dgrad_size, dtype=ctx.activation_dtype, device=fc1_weight.device fc1_dgrad_size, dtype=ctx.activation_dtype, device=fc1_weight.device
) )
# FC1 DGRAD: Unconditional # FC1 DGRAD: Unconditional
_ = tex.gemm( _ = tex.gemm(
fc1_weight, fc1_weight,
...@@ -802,6 +828,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -802,6 +828,7 @@ class _LayerNormMLP(torch.autograd.Function):
ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None, ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None,
extra_output_tensor=extra_output_tensor, extra_output_tensor=extra_output_tensor,
) )
clear_tensor_data(ln_out_total_t, dgelu_t)
else: else:
ln_out_total_c = tex.cast_from_fp8( ln_out_total_c = tex.cast_from_fp8(
ln_out_total, ln_out_total,
...@@ -826,6 +853,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -826,6 +853,7 @@ class _LayerNormMLP(torch.autograd.Function):
ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None, ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None,
extra_output_tensor=extra_output_tensor, extra_output_tensor=extra_output_tensor,
) )
clear_tensor_data(ln_out_total_c, dgelu_no_fp8)
else: else:
# FC1 WGRAD # FC1 WGRAD
fc1_wgrad_outputs = tex.gemm( fc1_wgrad_outputs = tex.gemm(
...@@ -841,6 +869,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -841,6 +869,7 @@ class _LayerNormMLP(torch.autograd.Function):
ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None, ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None,
ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None
) )
clear_tensor_data(ln_out_total, dgelu)
if ctx.bias_gelu_nvfusion: if ctx.bias_gelu_nvfusion:
fc1_wgrad, _, _ = fc1_wgrad_outputs fc1_wgrad, _, _ = fc1_wgrad_outputs
...@@ -857,20 +886,20 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -857,20 +886,20 @@ class _LayerNormMLP(torch.autograd.Function):
handle.wait() handle.wait()
# LayerNorm gradient # LayerNorm gradient
d_ln_out = fc1_dgrad.view(inputmat.shape) dgrad = fc1_dgrad.view(inputmat.shape)
# Residual gradient # Residual gradient
if ctx.return_layernorm_output: if ctx.return_layernorm_output:
d_ln_out = d_ln_out + grad_outputs[1].view_as(d_ln_out) dgrad = dgrad + grad_outputs[1].view_as(dgrad)
if ctx.normalization == "LayerNorm": if ctx.normalization == "LayerNorm":
dxmat, dgamma, dbeta = tex.layernorm_bwd( dgrad, dgamma, dbeta = tex.layernorm_bwd(
d_ln_out, inputmat, mu, rsigma, ln_weight, dgrad, inputmat, mu, rsigma, ln_weight,
ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma
) )
elif ctx.normalization == "RMSNorm": elif ctx.normalization == "RMSNorm":
dxmat, dgamma = tex.rmsnorm_bwd( dgrad, dgamma = tex.rmsnorm_bwd(
d_ln_out, inputmat, rsigma, ln_weight, dgrad, inputmat, rsigma, ln_weight,
ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma
) )
dbeta = None dbeta = None
...@@ -904,7 +933,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -904,7 +933,7 @@ class _LayerNormMLP(torch.autograd.Function):
fc2_wgrad = None fc2_wgrad = None
return ( return (
dxmat.view(ctx.inp_shape) if ctx.requires_dgrad else None, dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None,
dgamma, dgamma,
dbeta, dbeta,
fc1_wgrad, fc1_wgrad,
......
...@@ -26,6 +26,7 @@ from ..utils import ( ...@@ -26,6 +26,7 @@ from ..utils import (
get_default_init_method, get_default_init_method,
cast_if_needed, cast_if_needed,
assert_dim_for_fp8_exec, assert_dim_for_fp8_exec,
clear_tensor_data,
) )
from ..distributed import ( from ..distributed import (
set_tensor_model_parallel_attributes, set_tensor_model_parallel_attributes,
...@@ -431,6 +432,7 @@ class _Linear(torch.autograd.Function): ...@@ -431,6 +432,7 @@ class _Linear(torch.autograd.Function):
out=weight.main_grad if ctx.fuse_wgrad_accumulation else None, out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
use_split_accumulator=_2X_ACC_WGRAD, use_split_accumulator=_2X_ACC_WGRAD,
) )
clear_tensor_data(inputmat_t_total)
else: else:
wgrad, _, _ = gemm( wgrad, _, _ = gemm(
inputmat_total, inputmat_total,
...@@ -442,6 +444,7 @@ class _Linear(torch.autograd.Function): ...@@ -442,6 +444,7 @@ class _Linear(torch.autograd.Function):
accumulate=accumulate_wgrad_into_param_main_grad, accumulate=accumulate_wgrad_into_param_main_grad,
out=weight.main_grad if ctx.fuse_wgrad_accumulation else None, out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
) )
clear_tensor_data(inputmat_total)
else: else:
# WGRAD # WGRAD
wgrad, grad_bias, _ = gemm( wgrad, grad_bias, _ = gemm(
...@@ -455,6 +458,7 @@ class _Linear(torch.autograd.Function): ...@@ -455,6 +458,7 @@ class _Linear(torch.autograd.Function):
accumulate=accumulate_wgrad_into_param_main_grad, accumulate=accumulate_wgrad_into_param_main_grad,
out=weight.main_grad if ctx.fuse_wgrad_accumulation else None, out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
) )
clear_tensor_data(inputmat_total)
# Column Parallel Linear # Column Parallel Linear
if ctx.parallel_mode == "column" and ctx.tensor_parallel and handle is not None: if ctx.parallel_mode == "column" and ctx.tensor_parallel and handle is not None:
......
...@@ -8,6 +8,18 @@ from typing import Any, Callable, Optional, Tuple ...@@ -8,6 +8,18 @@ from typing import Any, Callable, Optional, Tuple
import torch import torch
def clear_tensor_data(*tensors: Optional[torch.Tensor]) -> None:
    """Release the underlying storage of the given tensors.

    Replaces each tensor's ``.data`` with an empty ``torch.Tensor`` so the
    original storage can be freed immediately. This is needed because simply
    ``del``-ing a tensor does not release its memory when PyTorch (e.g. via
    autograd or a tensor subclass override) still holds a reference to it.

    Must be used carefully: after this call the tensors are empty and any
    other live alias of them observes the cleared data.

    Parameters
    ----------
    *tensors : Optional[torch.Tensor]
        Tensors whose storage should be released. ``None`` entries are
        skipped, so callers may pass optional tensors unconditionally.
    """
    for t in tensors:
        # Skip optionals so call sites don't need their own None checks.
        if t is not None:
            t.data = torch.Tensor()
def get_device_compute_capability() -> Tuple[int, int]: def get_device_compute_capability() -> Tuple[int, int]:
"""CUDA compute capability of current GPU""" """CUDA compute capability of current GPU"""
props = torch.cuda.get_device_properties(torch.cuda.current_device()) props = torch.cuda.get_device_properties(torch.cuda.current_device())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment