Unverified Commit d76118d9 authored by Marks101's avatar Marks101 Committed by GitHub
Browse files

[PyTorch] Linear: fix computation for wgrad if sequence_parallel=True (#531)



* [PyTorch] Linear: fix computation for wgrad if sequence_parallel=True

Signed-off-by: Markus Schnoes <markus.schnoes@gmx.de>

* Remove buggy gather_along_last_dim

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [PyTorch] Linear: fix line length

Signed-off-by: Markus Schnoes <markus.schnoes@gmx.de>

* Simplify logic for saving input tensor for Linear backward

Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------

Signed-off-by: Markus Schnoes <markus.schnoes@gmx.de>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
parent 54e46e21
......@@ -473,29 +473,6 @@ def gather_along_first_dim(
return output, handle
def gather_along_last_dim(
    input_: torch.Tensor, tp_group: dist_group_type, async_op: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Gather tensors from the TP group and concatenate along the last dimension.

    Parameters
    ----------
    input_ : torch.Tensor
        Local shard. All ranks in ``tp_group`` must contribute tensors of
        identical shape and dtype.
    tp_group : dist_group_type
        Process group to gather over.
    async_op : bool, default = False
        Accepted for interface compatibility only. The gather is always
        completed synchronously, because the gathered chunks must be
        rearranged into last-dim order before the result is usable, so the
        returned handle is always ``None``.

    Returns
    -------
    Tuple of (output, handle). ``output`` has the same shape as ``input_``
    except the last dimension is multiplied by the group's world size;
    ``handle`` is always ``None``.
    """
    world_size = get_distributed_world_size(tp_group)
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_, None

    # BUG FIX: all_gather_into_tensor lays out rank chunks along the FIRST
    # dimension of the output buffer. Gathering straight into a buffer whose
    # last dim is (d * world_size) therefore does NOT produce a last-dim
    # concatenation. Gather into a (world_size, *input_shape) buffer instead
    # and rearrange below.
    dim_size = list(input_.size())
    output = torch.empty(
        [world_size] + dim_size, dtype=input_.dtype, device=torch.cuda.current_device()
    )
    torch.distributed.all_gather_into_tensor(
        output, input_.contiguous(), group=tp_group
    )

    # (world_size, ..., d) -> (..., world_size, d) -> (..., world_size * d),
    # which matches torch.cat([chunk_0, ..., chunk_{W-1}], dim=-1).
    output = (
        output.movedim(0, -2)
        .contiguous()
        .view(dim_size[:-1] + [dim_size[-1] * world_size])
    )
    return output, None
def allreduce(
input_: torch.Tensor,
tp_group: Optional[dist_group_type] = None,
......
......@@ -35,7 +35,6 @@ from ..distributed import (
initialize_affine_weight_gpu,
reduce_scatter_along_first_dim,
gather_along_first_dim,
gather_along_last_dim,
)
from ..cpp_extensions import (
fp8_gemm,
......@@ -102,35 +101,34 @@ class _Linear(torch.autograd.Function):
ub_atomic_gemm_rs = False
if ub_atomic_gemm_rs or ub_atomic_gemm_ag:
assert fp8, "AtomicGemm overlap supported only for FP8 GEMM."
# Cast for native AMP
# Cast input to expected dtype
inputmat = cast_if_needed(inputmat, activation_dtype)
inputmat_t = None
inputmat_no_fp8 = inputmat
if fp8:
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
if not fp8_meta["recipe"].override_linear_precision.wgrad:
if is_grad_enabled:
inputmat, inputmat_t = fp8_cast_transpose_fused(
inputmat,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
)
else:
inputmat = cast_to_fp8(
inputmat,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
)
if (
not fp8_meta["recipe"].override_linear_precision.wgrad
and is_grad_enabled
and weight.requires_grad
and not sequence_parallel
):
# FP8 input for forward, FP8 input transpose for backward wgrad
inputmat, inputmat_t = fp8_cast_transpose_fused(
inputmat,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
)
else:
inputmat, inputmat_t = cast_to_fp8(
# FP8 input for forward
inputmat = cast_to_fp8(
inputmat,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
), None
)
# Column Parallel Linear
if parallel_mode == "column" and sequence_parallel:
......@@ -260,10 +258,19 @@ class _Linear(torch.autograd.Function):
)
if is_grad_enabled:
fp8_wgrad = fp8 and not fp8_meta["recipe"].override_linear_precision.wgrad
saved_inputmat = None
saved_inputmat_t = None
if weight.requires_grad:
if fp8 and not fp8_meta["recipe"].override_linear_precision.wgrad:
if inputmat_t is None:
saved_inputmat = inputmat
else:
saved_inputmat_t = inputmat_t
else:
saved_inputmat = inputmat_no_fp8
ctx.save_for_backward(
inputmat_no_fp8 if weight.requires_grad and not fp8_wgrad else None,
inputmat_t if weight.requires_grad and fp8_wgrad else None,
saved_inputmat,
saved_inputmat_t,
weight,
weight_t_fp8 if fp8 else None,
fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None,
......@@ -336,19 +343,16 @@ class _Linear(torch.autograd.Function):
# Column Parallel Linear
# Overlap input AG with dgrad
inputmat_total = None
inputmat_t_total = None
handle = None
if weight.requires_grad and ctx.parallel_mode == "column" and ctx.sequence_parallel:
if ctx.fp8 and not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
inputmat_t_total, handle = gather_along_last_dim(
inputmat_t, ctx.tp_group, async_op=ctx.requires_dgrad
)
else:
inputmat_total, handle = gather_along_first_dim(
inputmat, ctx.tp_group, async_op=ctx.requires_dgrad
)
inputmat_total, handle = gather_along_first_dim(
inputmat, ctx.tp_group, async_op=ctx.requires_dgrad
)
else:
inputmat_t_total = inputmat_t
inputmat_total = inputmat
handle = None
inputmat_t_total = inputmat_t
if ctx.is_first_microbatch is not None:
accumulate_wgrad_into_param_main_grad = (
......@@ -412,6 +416,8 @@ class _Linear(torch.autograd.Function):
if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
if ctx.ub_split_ag or ctx.ub_atomic_gemm_ag:
grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward)
if inputmat_t_total is None:
inputmat_t_total = tex.fp8_transpose(inputmat_total, fp8_dtype_backward)
wgrad, _ = fp8_gemm(
inputmat_t_total,
fwd_scale_inverses,
......@@ -427,7 +433,6 @@ class _Linear(torch.autograd.Function):
out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
use_split_accumulator=_2X_ACC_WGRAD,
)
clear_tensor_data(inputmat_t_total)
else:
wgrad, _, _ = gemm(
inputmat_total,
......@@ -439,7 +444,6 @@ class _Linear(torch.autograd.Function):
accumulate=accumulate_wgrad_into_param_main_grad,
out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
)
clear_tensor_data(inputmat_total)
else:
# WGRAD
wgrad, grad_bias, _ = gemm(
......@@ -453,7 +457,10 @@ class _Linear(torch.autograd.Function):
accumulate=accumulate_wgrad_into_param_main_grad,
out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
)
clear_tensor_data(inputmat_total)
# Deallocate input tensor
clear_tensor_data(inputmat_total)
clear_tensor_data(inputmat_t_total)
# Column Parallel Linear
if ctx.parallel_mode == "column" and ctx.tensor_parallel and handle is not None:
......
......@@ -8,7 +8,7 @@ from typing import Any, Callable, Optional, Tuple
import torch
def clear_tensor_data(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None:
    """
    Trick to deallocate tensor memory when delete operation does not
    release the tensor due to PyTorch override.

    Must be used carefully.

    ``None`` entries are skipped, so callers may pass placeholder slots
    (e.g. tensors that were conditionally saved for backward).
    """
    for t in tensors:
        if t is not None:
            # Replacing .data with an empty tensor releases the underlying
            # storage even while the tensor object itself is still referenced
            # elsewhere (e.g. by an autograd ctx).
            t.data = torch.Tensor()
            del t
def get_device_compute_capability() -> Tuple[int, int]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment