Unverified commit d76118d9, authored by Marks101 and committed by GitHub

[PyTorch] Linear: fix computation for wgrad if sequence_parallel=True (#531)



* [PyTorch] Linear: fix computation for wgrad if sequence_parallel=True
Signed-off-by: Markus Schnoes <markus.schnoes@gmx.de>

* Remove buggy gather_along_last_dim
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [PyTorch] Linear: fix line length
Signed-off-by: Markus Schnoes <markus.schnoes@gmx.de>

* Simplify logic for saving input tensor for Linear backward
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Markus Schnoes <markus.schnoes@gmx.de>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
parent 54e46e21
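
For context (not part of the patch itself): with sequence parallelism, each tensor-parallel rank holds only a slice of the layer input along the sequence dimension, so the weight gradient has to be computed against the all-gathered input. The sketch below restates that relationship in plain dense math with illustrative shapes and names; it is an assumption-laden toy, not the Transformer Engine kernels.

```python
import torch

# Illustrative shapes (hypothetical): sequence length 4 split across tp=2 ranks,
# batch folded into the sequence dimension.
s, h_in, h_out, tp = 4, 8, 16, 2
x_shards = [torch.randn(s // tp, h_in) for _ in range(tp)]   # per-rank input slices
weight = torch.randn(h_out, h_in, requires_grad=True)

# Forward on the full input (what the layer computes collectively): y = x @ W^T.
x_total = torch.cat(x_shards, dim=0)   # all-gather along the first (sequence) dim
y = x_total @ weight.t()

# Backward: wgrad = dy^T @ x_total, so the *gathered* input (or its transpose,
# for the FP8 TN GEMM) is what the weight-gradient GEMM must consume.
dy = torch.randn_like(y)
wgrad_manual = dy.t() @ x_total        # shape [h_out, h_in]

y.backward(dy)
assert torch.allclose(wgrad_manual, weight.grad)
```

The patch below makes the FP8 backward build exactly this gathered operand (and its transpose) instead of relying on a faulty last-dimension gather.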
@@ -473,29 +473,6 @@ def gather_along_first_dim(
     return output, handle


-def gather_along_last_dim(
-    input_: torch.Tensor, tp_group: dist_group_type, async_op: bool = False
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    """Gather tensors and concatinate along the last dimension."""
-
-    world_size = get_distributed_world_size(tp_group)
-    # Bypass the function if we are using only 1 GPU.
-    if world_size == 1:
-        return input_, None
-
-    dim_size = list(input_.size())
-    dim_size[-1] = dim_size[-1] * world_size
-
-    output = torch.empty(
-        dim_size, dtype=input_.dtype, device=torch.cuda.current_device()
-    )
-    handle = torch.distributed.all_gather_into_tensor(
-        output, input_.contiguous(), group=tp_group, async_op=async_op
-    )
-
-    return output, handle
-
-
 def allreduce(
     input_: torch.Tensor,
     tp_group: Optional[dist_group_type] = None,
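
This helper was unsound because torch.distributed.all_gather_into_tensor always writes each rank's contribution as a contiguous chunk of the (flattened) output, i.e. it concatenates along the leading dimension of the buffer regardless of how the output tensor was shaped. A single-process sketch of the resulting mismatch (illustrative only, no process group involved):

```python
import torch

# Two hypothetical "ranks", each holding a transposed input shard of shape [h, s/tp].
x_t_rank0 = torch.arange(6).reshape(2, 3)
x_t_rank1 = torch.arange(6, 12).reshape(2, 3)

# What gather_along_last_dim intended: concatenate the shards along the last dim.
intended = torch.cat([x_t_rank0, x_t_rank1], dim=-1)             # [[0,1,2,6,7,8], [3,4,5,9,10,11]]

# What all_gather_into_tensor effectively produces when handed an output that was
# merely allocated with a wider last dim: each rank's chunk is written back-to-back
# in memory, i.e. a first-dim concat reinterpreted with the wider shape.
actual = torch.cat([x_t_rank0, x_t_rank1], dim=0).reshape(2, 6)  # [[0,1,2,3,4,5], [6,7,8,9,10,11]]

print(torch.equal(intended, actual))   # False: the wgrad GEMM would read a scrambled operand
```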
@@ -35,7 +35,6 @@ from ..distributed import (
     initialize_affine_weight_gpu,
     reduce_scatter_along_first_dim,
     gather_along_first_dim,
-    gather_along_last_dim,
 )
 from ..cpp_extensions import (
     fp8_gemm,
@@ -102,35 +101,34 @@ class _Linear(torch.autograd.Function):
             ub_atomic_gemm_rs = False
         if ub_atomic_gemm_rs or ub_atomic_gemm_ag:
             assert fp8, "AtomicGemm overlap supported only for FP8 GEMM."

-        # Cast for native AMP
+        # Cast input to expected dtype
         inputmat = cast_if_needed(inputmat, activation_dtype)
+        inputmat_t = None
         inputmat_no_fp8 = inputmat
         if fp8:
             fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
-
-            if not fp8_meta["recipe"].override_linear_precision.wgrad:
-                if is_grad_enabled:
-                    inputmat, inputmat_t = fp8_cast_transpose_fused(
-                        inputmat,
-                        fp8_meta["scaling_fwd"],
-                        tex.FP8FwdTensors.GEMM1_INPUT,
-                        fp8_dtype_forward,
-                    )
-                else:
-                    inputmat = cast_to_fp8(
-                        inputmat,
-                        fp8_meta["scaling_fwd"],
-                        tex.FP8FwdTensors.GEMM1_INPUT,
-                        fp8_dtype_forward,
-                    )
+            if (
+                not fp8_meta["recipe"].override_linear_precision.wgrad
+                and is_grad_enabled
+                and weight.requires_grad
+                and not sequence_parallel
+            ):
+                # FP8 input for forward, FP8 input transpose for backward wgrad
+                inputmat, inputmat_t = fp8_cast_transpose_fused(
+                    inputmat,
+                    fp8_meta["scaling_fwd"],
+                    tex.FP8FwdTensors.GEMM1_INPUT,
+                    fp8_dtype_forward,
+                )
             else:
-                inputmat, inputmat_t = cast_to_fp8(
+                # FP8 input for forward
+                inputmat = cast_to_fp8(
                     inputmat,
                     fp8_meta["scaling_fwd"],
                     tex.FP8FwdTensors.GEMM1_INPUT,
                     fp8_dtype_forward,
-                ), None
+                )

         # Column Parallel Linear
         if parallel_mode == "column" and sequence_parallel:
@@ -260,10 +258,19 @@ class _Linear(torch.autograd.Function):
             )

         if is_grad_enabled:
-            fp8_wgrad = fp8 and not fp8_meta["recipe"].override_linear_precision.wgrad
+            saved_inputmat = None
+            saved_inputmat_t = None
+            if weight.requires_grad:
+                if fp8 and not fp8_meta["recipe"].override_linear_precision.wgrad:
+                    if inputmat_t is None:
+                        saved_inputmat = inputmat
+                    else:
+                        saved_inputmat_t = inputmat_t
+                else:
+                    saved_inputmat = inputmat_no_fp8
             ctx.save_for_backward(
-                inputmat_no_fp8 if weight.requires_grad and not fp8_wgrad else None,
-                inputmat_t if weight.requires_grad and fp8_wgrad else None,
+                saved_inputmat,
+                saved_inputmat_t,
                 weight,
                 weight_t_fp8 if fp8 else None,
                 fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None,
@@ -336,19 +343,16 @@ class _Linear(torch.autograd.Function):

         # Column Parallel Linear
         # Overlap input AG with dgrad
+        inputmat_total = None
+        inputmat_t_total = None
+        handle = None
         if weight.requires_grad and ctx.parallel_mode == "column" and ctx.sequence_parallel:
-            if ctx.fp8 and not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
-                inputmat_t_total, handle = gather_along_last_dim(
-                    inputmat_t, ctx.tp_group, async_op=ctx.requires_dgrad
-                )
-            else:
-                inputmat_total, handle = gather_along_first_dim(
-                    inputmat, ctx.tp_group, async_op=ctx.requires_dgrad
-                )
+            inputmat_total, handle = gather_along_first_dim(
+                inputmat, ctx.tp_group, async_op=ctx.requires_dgrad
+            )
         else:
-            inputmat_t_total = inputmat_t
             inputmat_total = inputmat
-            handle = None
+            inputmat_t_total = inputmat_t

         if ctx.is_first_microbatch is not None:
             accumulate_wgrad_into_param_main_grad = (
@@ -412,6 +416,8 @@ class _Linear(torch.autograd.Function):
                 if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
                     if ctx.ub_split_ag or ctx.ub_atomic_gemm_ag:
                         grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward)
+                    if inputmat_t_total is None:
+                        inputmat_t_total = tex.fp8_transpose(inputmat_total, fp8_dtype_backward)
                     wgrad, _ = fp8_gemm(
                         inputmat_t_total,
                         fwd_scale_inverses,
@@ -427,7 +433,6 @@ class _Linear(torch.autograd.Function):
                         out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
                         use_split_accumulator=_2X_ACC_WGRAD,
                     )
-                    clear_tensor_data(inputmat_t_total)
                 else:
                     wgrad, _, _ = gemm(
                         inputmat_total,
@@ -439,7 +444,6 @@ class _Linear(torch.autograd.Function):
                         accumulate=accumulate_wgrad_into_param_main_grad,
                         out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
                     )
-                    clear_tensor_data(inputmat_total)
             else:
                 # WGRAD
                 wgrad, grad_bias, _ = gemm(
@@ -453,7 +457,10 @@ class _Linear(torch.autograd.Function):
                     accumulate=accumulate_wgrad_into_param_main_grad,
                     out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
                 )
-                clear_tensor_data(inputmat_total)
+
+            # Deallocate input tensor
+            clear_tensor_data(inputmat_total)
+            clear_tensor_data(inputmat_t_total)

         # Column Parallel Linear
         if ctx.parallel_mode == "column" and ctx.tensor_parallel and handle is not None:
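
Taken together, the fixed backward wgrad path for sequence parallelism gathers the input along the first dimension and forms the transposed operand only after the gather (tex.fp8_transpose in the FP8 code above). A simplified, framework-free sketch of that ordering in plain PyTorch, with illustrative names and shapes:

```python
import torch

def wgrad_with_sequence_parallel(input_shards, grad_output):
    """Hypothetical helper: per-rank input slices [s/tp, h_in], full-sequence grad [s, h_out]."""
    # 1) All-gather the input along the sequence (first) dimension.
    inputmat_total = torch.cat(input_shards, dim=0)       # [s, h_in]

    # 2) Form the transposed operand only after the gather (the removed helper
    #    tried to gather pre-transposed shards and produced a scrambled tensor).
    inputmat_t_total = inputmat_total.t().contiguous()    # [h_in, s]

    # 3) wgrad = dy^T @ x_total = (x_total^T @ dy)^T, shape [h_out, h_in].
    return (inputmat_t_total @ grad_output).t()

# Tiny check against the single-device formula.
shards = [torch.randn(2, 3) for _ in range(2)]
dy = torch.randn(4, 5)
ref = dy.t() @ torch.cat(shards, dim=0)
assert torch.allclose(wgrad_with_sequence_parallel(shards, dy), ref)
```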
@@ -8,7 +8,7 @@ from typing import Any, Callable, Optional, Tuple
 import torch


-def clear_tensor_data(*tensors: Tuple[torch.Tensor, ...]) -> None:
+def clear_tensor_data(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None:
     """
     Trick to deallocate tensor memory when delete operation does not
     release the tensor due to PyTorch override.
@@ -16,8 +16,9 @@ def clear_tensor_data(*tensors: Tuple[torch.Tensor, ...]) -> None:
     Must be used carefully.
     """
     for t in tensors:
-        t.data = torch.Tensor()
-        del t
+        if t is not None:
+            t.data = torch.Tensor()
+            del t


 def get_device_compute_capability() -> Tuple[int, int]:
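
With the None guard in place, callers can pass optional tensors straight through. A tiny hypothetical usage example (the import path is an assumption):

```python
import torch
from transformer_engine.pytorch.utils import clear_tensor_data  # assumed import path

x = torch.empty(1024, 1024)
clear_tensor_data(x, None)   # None entries are now skipped instead of raising
assert x.numel() == 0        # storage was released by swapping in an empty .data
```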