Unverified commit 9f8aaddf, authored by Hongbin Liu and committed by GitHub

Split wgrad&dgrad from backward() to support a2a overlap (#1653)



* split wgrad for GroupedLinear
Signed-off-by: Hongbin Liu <hongbinl@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Hongbin Liu <hongbinl@nvidia.com>

* support wgrad split for linear and ln_linear
Signed-off-by: Hongbin Liu <hongbinl@nvidia.com>

* add comments and fix WeightGradStore
Signed-off-by: Hongbin Liu <hongbinl@nvidia.com>

* support bias and fix unit tests
Signed-off-by: Hongbin Liu <hongbinl@nvidia.com>

* minor fix
Signed-off-by: Hongbin Liu <hongbinl@nvidia.com>

* support fuse_grad_accumulation=false
Signed-off-by: Hongbin Liu <hongbinl@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* add wgrad split for layernorm_mlp
Signed-off-by: Hongbin Liu <hongbinl@nvidia.com>

* minor fix
Signed-off-by: Hongbin Liu <hongbinl@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix unittest
Signed-off-by: Hongbin Liu <hongbinl@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add unittest for distributed interface; apply Dener's suggestion
Signed-off-by: Hongbin Liu <hongbinl@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* minor fix
Signed-off-by: Hongbin Liu <hongbinl@nvidia.com>

* replace split_bw with delay_wgrad_compute
Signed-off-by: Hongbin Liu <hongbinl@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update transformer_engine/pytorch/module/layernorm_mlp.py
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Update transformer_engine/pytorch/module/linear.py
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Update transformer_engine/pytorch/module/layernorm_linear.py
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* remove comments
Signed-off-by: Hongbin Liu <hongbinl@nvidia.com>

---------
Signed-off-by: Hongbin Liu <hongbinl@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Hongbin Liu <hongbinl@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 1a6a6d7b
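For context, a minimal usage sketch of the interface this PR adds: with `delay_wgrad_compute=True` the wgrad GEMM is deferred during `backward()`, and the user calls `backward_dw()` afterwards (for example after launching the all-to-all traffic it should overlap with). The sizes and the training-loop placement below are hypothetical; only the constructor argument and the `backward_dw()` method come from this diff.

```python
import torch
import transformer_engine.pytorch as te

# Hypothetical shapes; any module gaining delay_wgrad_compute in this PR works the same way.
linear = te.Linear(
    1024, 4096, params_dtype=torch.bfloat16, device="cuda", delay_wgrad_compute=True
)

inp = torch.randn(32, 1024, dtype=torch.bfloat16, device="cuda", requires_grad=True)
loss = linear(inp).sum()

# backward() now computes only dgrad; the wgrad GEMM is queued in the module's
# WeightGradStore. A scheduler could issue all-to-all communication here and
# overlap it with the deferred GEMM.
loss.backward()

# Explicitly flush the queued wgrad computation and populate weight/bias grads.
linear.backward_dw()
```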
......@@ -300,6 +300,11 @@ def _loss_backward(output_single_node, output_distributed):
LOSS_FN(output_distributed, target).backward()
def _loss_backward_dw(model_single_node, model_distributed):
model_single_node.backward_dw()
model_distributed.backward_dw()
def _alloc_main_grad(model_single_node, model_distributed):
for model in [model_single_node, model_distributed]:
for param in model.parameters():
......@@ -473,6 +478,10 @@ def _test_linear(parallel_mode=None, sequence_parallel=False, **kwargs):
# Compute loss and backpropagate
_loss_backward(output_single_node, output_distributed)
# Compute delayed weight gradient
if "delay_wgrad_compute" in kwargs:
_loss_backward_dw(model_single_node, model_distributed)
# Validate outputs and gradients
_check_outputs(output_single_node, output_distributed)
......@@ -494,6 +503,7 @@ def test_linear():
{"fuse_wgrad_accumulation": True},
{"return_bias": True},
{"params_dtype": torch.float16},
{"delay_wgrad_compute": True},
]
for kwargs in kwargs_list:
for parallel_mode in ["column", "row"]:
......@@ -645,6 +655,10 @@ def _test_layernorm_linear(parallel_mode=None, sequence_parallel=False, **kwargs
# Compute loss and backpropagate
_loss_backward(output_single_node, output_distributed)
# Compute delayed weight gradient
if "delay_wgrad_compute" in kwargs:
_loss_backward_dw(model_single_node, model_distributed)
# Validate outputs and gradients
_check_outputs(output_single_node, output_distributed)
......@@ -667,6 +681,7 @@ def test_layernorm_linear():
{"params_dtype": torch.float16},
{"zero_centered_gamma": False},
{"return_layernorm_output": True},
{"delay_wgrad_compute": True},
]
for kwargs in kwargs_list:
for parallel_mode in ["column"]:
......@@ -746,6 +761,9 @@ def _test_layernorm_mlp(set_parallel_mode=None, sequence_parallel=False, **kwarg
# Compute loss and backpropagate
_loss_backward(output_single_node, output_distributed)
if "delay_wgrad_compute" in kwargs:
_loss_backward_dw(model_single_node, model_distributed)
# Validate outputs and gradients
_check_outputs(output_single_node, output_distributed)
......@@ -771,6 +789,7 @@ def test_layernorm_mlp():
{"fuse_wgrad_accumulation": True},
{"return_bias": True},
{"return_layernorm_output": True},
{"delay_wgrad_compute": True},
]
for kwargs in kwargs_list:
......
......@@ -1036,7 +1036,7 @@ def test_mha_accuracy(dtype, bs, model, mask_type):
assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype])
def _test_granular_accuracy(block, bs, dtype, config):
def _test_granular_accuracy(block, bs, dtype, config, delay_wgrad_compute=False):
reset_rng_states()
inp_hidden_states = torch.randn(
......@@ -1052,12 +1052,18 @@ def _test_granular_accuracy(block, bs, dtype, config):
out = out[0]
loss = out.sum()
loss.backward()
if delay_wgrad_compute:
block.backward_dw()
torch.cuda.synchronize()
outputs = [out, inp_hidden_states.grad]
for p in block.parameters():
if p.requires_grad:
outputs.append(p.grad)
if getattr(p, "main_grad", None) is not None:
outputs.append(p.main_grad)
assert p.grad is None # grad should be None if fuse_wgrad_accumulation is True
else:
outputs.append(p.grad)
return outputs
......@@ -1191,6 +1197,54 @@ def test_linear_accuracy(dtype, bs, model, return_bias, bias):
assert_allclose(te_output, torch_output, tolerance, rtol[dtype])
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("bias", all_boolean)
@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean)
def test_linear_accuracy_delay_wgrad_compute(dtype, bs, model, bias, fuse_wgrad_accumulation):
config = model_configs[model]
te_linear_ref = Linear(
config.hidden_size,
4 * config.hidden_size,
bias=bias,
params_dtype=dtype,
device="cuda",
delay_wgrad_compute=False,
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
).eval()
te_linear = Linear(
config.hidden_size,
4 * config.hidden_size,
bias=bias,
params_dtype=dtype,
device="cuda",
delay_wgrad_compute=True,
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
).eval()
# Share params
with torch.no_grad():
te_linear_ref.weight = Parameter(te_linear.weight.clone())
if bias:
te_linear_ref.bias = Parameter(te_linear.bias.clone())
if fuse_wgrad_accumulation:
weight = te_linear.weight
weight.main_grad = torch.rand_like(weight, dtype=torch.float32)
te_linear_ref.weight.main_grad = weight.main_grad.clone()
te_outputs = _test_granular_accuracy(te_linear, bs, dtype, config, delay_wgrad_compute=True)
te_outputs_ref = _test_granular_accuracy(
te_linear_ref, bs, dtype, config, delay_wgrad_compute=False
)
# Should be a bit-wise match
for i, (o, o_ref) in enumerate(zip(te_outputs, te_outputs_ref)):
torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
......@@ -1376,6 +1430,67 @@ def test_layernorm_linear_accuracy(
assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype])
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("bias", all_boolean)
@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean)
def test_layernorm_linear_accuracy_delay_wgrad_compute(
dtype, bs, model, normalization, zero_centered_gamma, bias, fuse_wgrad_accumulation
):
config = model_configs[model]
ln_linear_ref = LayerNormLinear(
config.hidden_size,
4 * config.hidden_size,
config.eps,
bias=bias,
normalization=normalization,
params_dtype=dtype,
zero_centered_gamma=zero_centered_gamma,
device="cuda",
delay_wgrad_compute=False,
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
).eval()
ln_linear = LayerNormLinear(
config.hidden_size,
4 * config.hidden_size,
config.eps,
bias=bias,
normalization=normalization,
params_dtype=dtype,
zero_centered_gamma=zero_centered_gamma,
device="cuda",
delay_wgrad_compute=True,
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
).eval()
# Share params
with torch.no_grad():
ln_linear_ref.layer_norm_weight = Parameter(ln_linear.layer_norm_weight.clone())
if normalization != "RMSNorm":
ln_linear_ref.layer_norm_bias = Parameter(ln_linear.layer_norm_bias.clone())
ln_linear_ref.weight = Parameter(ln_linear.weight.clone())
if bias:
ln_linear_ref.bias = Parameter(ln_linear.bias.clone())
if fuse_wgrad_accumulation:
weight = ln_linear.weight
weight.main_grad = torch.rand_like(weight, dtype=torch.float32)
ln_linear_ref.weight.main_grad = weight.main_grad.clone()
te_outputs = _test_granular_accuracy(ln_linear, bs, dtype, config, delay_wgrad_compute=True)
te_outputs_ref = _test_granular_accuracy(
ln_linear_ref, bs, dtype, config, delay_wgrad_compute=False
)
# Should be a bit-wise match
for i, (o, o_ref) in enumerate(zip(te_outputs, te_outputs_ref)):
torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["small"])
......@@ -1452,8 +1567,78 @@ def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization, ret
assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype])
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("bias", all_boolean)
@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean)
def test_layernorm_mlp_accuracy_delay_wgrad_compute(
dtype, bs, model, activation, normalization, bias, fuse_wgrad_accumulation
):
config = model_configs[model]
ln_mlp = LayerNormMLP(
hidden_size=config.hidden_size,
ffn_hidden_size=4 * config.hidden_size,
eps=config.eps,
bias=bias,
normalization=normalization,
params_dtype=dtype,
device="cuda",
delay_wgrad_compute=True,
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
).eval()
ln_mlp_ref = LayerNormMLP(
hidden_size=config.hidden_size,
ffn_hidden_size=4 * config.hidden_size,
eps=config.eps,
bias=bias,
normalization=normalization,
params_dtype=dtype,
device="cuda",
delay_wgrad_compute=False,
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
).eval()
# Share params
with torch.no_grad():
ln_mlp_ref.layer_norm_weight = Parameter(ln_mlp.layer_norm_weight.clone())
if normalization != "RMSNorm":
ln_mlp_ref.layer_norm_bias = Parameter(ln_mlp.layer_norm_bias.clone())
ln_mlp_ref.fc1_weight = Parameter(ln_mlp.fc1_weight.clone())
ln_mlp_ref.fc2_weight = Parameter(ln_mlp.fc2_weight.clone())
if bias:
ln_mlp_ref.fc1_bias = Parameter(ln_mlp.fc1_bias.clone())
ln_mlp_ref.fc2_bias = Parameter(ln_mlp.fc2_bias.clone())
if fuse_wgrad_accumulation:
ln_mlp.fc1_weight.main_grad = torch.rand_like(ln_mlp.fc1_weight, dtype=torch.float32)
ln_mlp_ref.fc1_weight.main_grad = ln_mlp.fc1_weight.main_grad.clone()
ln_mlp.fc2_weight.main_grad = torch.rand_like(ln_mlp.fc2_weight, dtype=torch.float32)
ln_mlp_ref.fc2_weight.main_grad = ln_mlp.fc2_weight.main_grad.clone()
te_outputs = _test_granular_accuracy(ln_mlp, bs, dtype, config, delay_wgrad_compute=True)
te_outputs_ref = _test_granular_accuracy(
ln_mlp_ref, bs, dtype, config, delay_wgrad_compute=False
)
# Should be a bit-wise match
for i, (o, o_ref) in enumerate(zip(te_outputs, te_outputs_ref)):
torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
def _test_grouped_linear_accuracy(
block, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation
block,
num_gemms,
bs,
dtype,
config,
recipe,
fp8,
fuse_wgrad_accumulation,
delay_wgrad_compute=False,
):
reset_rng_states()
if fp8:
......@@ -1495,6 +1680,12 @@ def _test_grouped_linear_accuracy(
)
loss = out.sum()
loss.backward()
if delay_wgrad_compute:
if isinstance(block, GroupedLinear):
block.backward_dw()
else:
for i in range(num_gemms):
block[i].backward_dw()
torch.cuda.synchronize()
outputs = [out, inp_hidden_states.grad]
......@@ -1515,6 +1706,8 @@ def _test_grouped_linear_accuracy(
@pytest.mark.parametrize("recipe", fp8_recipes + [None])
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean)
@pytest.mark.parametrize("bias", all_boolean)
@pytest.mark.parametrize("delay_wgrad_compute", all_boolean)
def test_grouped_linear_accuracy(
dtype,
num_gemms,
......@@ -1523,6 +1716,8 @@ def test_grouped_linear_accuracy(
recipe,
fp8_model_params,
fuse_wgrad_accumulation,
bias,
delay_wgrad_compute,
parallel_mode=None,
):
fp8 = recipe is not None
......@@ -1542,18 +1737,19 @@ def test_grouped_linear_accuracy(
num_gemms,
config.hidden_size,
4 * config.hidden_size,
bias=True,
bias=bias,
params_dtype=dtype,
parallel_mode=parallel_mode,
device="cuda",
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
delay_wgrad_compute=delay_wgrad_compute,
).eval()
sequential_linear = torch.nn.ModuleList(
[
Linear(
config.hidden_size,
4 * config.hidden_size,
bias=True,
bias=bias,
params_dtype=dtype,
parallel_mode=parallel_mode,
device="cuda",
......@@ -1567,17 +1763,34 @@ def test_grouped_linear_accuracy(
with torch.no_grad():
for i in range(num_gemms):
sequential_linear[i].weight = Parameter(getattr(grouped_linear, f"weight{i}").clone())
sequential_linear[i].bias = Parameter(getattr(grouped_linear, f"bias{i}").clone())
if bias:
sequential_linear[i].bias = Parameter(getattr(grouped_linear, f"bias{i}").clone())
if fuse_wgrad_accumulation:
weight_i = getattr(grouped_linear, f"weight{i}")
weight_i.main_grad = torch.rand_like(weight_i, dtype=torch.float32)
sequential_linear[i].weight.main_grad = weight_i.main_grad.clone()
outputs_ref = _test_grouped_linear_accuracy(
sequential_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation
sequential_linear,
num_gemms,
bs,
dtype,
config,
recipe,
fp8,
fuse_wgrad_accumulation,
delay_wgrad_compute,
)
outputs = _test_grouped_linear_accuracy(
grouped_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation
grouped_linear,
num_gemms,
bs,
dtype,
config,
recipe,
fp8,
fuse_wgrad_accumulation,
delay_wgrad_compute,
)
# Should be a bit-wise match
......@@ -1596,6 +1809,8 @@ def test_grouped_linear_accuracy_single_gemm(recipe):
recipe=recipe,
fp8_model_params=True,
fuse_wgrad_accumulation=True,
bias=True,
delay_wgrad_compute=False,
)
......
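The tests above also pin down how the delayed path interacts with `fuse_wgrad_accumulation`: a `main_grad` buffer is attached to the weight before backward, the deferred GEMM accumulates into it, and `param.grad` stays `None` (see the assertion in `_test_granular_accuracy`). A condensed sketch of that setup, with hypothetical sizes:

```python
import torch
import transformer_engine.pytorch as te

linear = te.Linear(
    1024, 4096, bias=False, params_dtype=torch.bfloat16, device="cuda",
    fuse_wgrad_accumulation=True, delay_wgrad_compute=True,
)

# Mimic Megatron-style gradient accumulation: the caller owns an FP32 buffer.
linear.weight.main_grad = torch.zeros_like(linear.weight, dtype=torch.float32)

inp = torch.randn(32, 1024, dtype=torch.bfloat16, device="cuda", requires_grad=True)
linear(inp).sum().backward()
linear.backward_dw()

# With fused accumulation the deferred GEMM writes into main_grad directly,
# and the parameter's .grad is left unset (as asserted in the test above).
assert linear.weight.grad is None
assert linear.weight.main_grad.abs().sum() > 0
```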
......@@ -9,6 +9,7 @@ from dataclasses import dataclass
from functools import reduce
from operator import mul as multiply_op
import queue
import torch
from .. import cpp_extensions as tex
......@@ -216,3 +217,79 @@ class _ParameterInitMeta:
"""Safeguard reference to the parameter's parent module and initialization function."""
if self.init_fn is None:
self.init_fn = get_default_init_method()
class WeightGradStore:
"""
A class to manage weight gradient storage and computation in Transformer modules.
This class enables split backward propagation for better memory efficiency.
"""
def __init__(self, delay_wgrad_compute=False, ub_bulk_wgrad=False):
"""
Initialize the WeightGradStore.
Args:
delay_wgrad_compute (bool): Whether to delay weight gradient computation
ub_bulk_wgrad (bool): Whether to enable bulk weight gradient computation
"""
if delay_wgrad_compute:
self.context = queue.Queue()
assert (
ub_bulk_wgrad is False
), "ub_bulk_wgrad is not supported when enabling delay_wgrad_compute"
self.enabled = delay_wgrad_compute
else:
self.context = None
self.enabled = False
def delay_wgrad_compute(self):
"""
Get the current split backward propagation status.
Returns:
bool: True if split backward is enabled, False otherwise
"""
return self.enabled
def enable_delay_wgrad_compute(self):
"""Enable split backward propagation."""
self.enabled = True
def disable_delay_wgrad_compute(self):
"""Disable split backward propagation."""
self.enabled = False
def put(self, tensor_list, func):
"""
Store tensors and computation function for later execution.
Args:
tensor_list (list): List of tensors needed for computation
func (callable): Function to be executed with the tensors
"""
assert self.enabled is True, "delay_wgrad_compute is not enabled"
self.context.put([tensor_list, func])
def pop(self):
"""
Execute the stored computation with the stored tensors.
Raises an exception if the queue is empty.
"""
assert self.enabled is True, "delay_wgrad_compute is not enabled"
if self.context.qsize() > 0:
tensor_list, func = self.context.get()
return func(*tensor_list), tensor_list
if torch.distributed.is_initialized():
rank = torch.distributed.get_rank()
raise RuntimeError(f"Pop empty queue. rank {rank}")
raise RuntimeError("Pop empty queue. No distributed environment detected.")
def assert_empty(self):
"""
Assert that the queue is empty.
Used for debugging and ensuring proper cleanup.
"""
assert self.enabled is True, "delay_wgrad_compute is not enabled"
rank = torch.distributed.get_rank()
assert self.context.empty(), f"Queue is not empty. rank {rank}"
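In short, `WeightGradStore` is a FIFO of `(tensor_list, callable)` pairs: the backward pass enqueues a `functools.partial`-wrapped GEMM, and `backward_dw()` dequeues and runs it. A self-contained illustration of that contract, using a stand-in `fake_wgrad_gemm` in place of the real GEMM and an import path assumed from the file this class is added to:

```python
import functools
import torch

# Assumed import path (transformer_engine/pytorch/module/_common.py per this diff).
from transformer_engine.pytorch.module._common import WeightGradStore

store = WeightGradStore(delay_wgrad_compute=True)

def fake_wgrad_gemm(inp, grad_out, scale=1.0):
    # Stand-in for the real wgrad GEMM (general_gemm / general_grouped_gemm);
    # returns a 4-tuple so unpacking mirrors the callers in this PR.
    return (grad_out.t() @ inp) * scale, None, None, None

inp = torch.randn(8, 16)
grad_out = torch.randn(8, 32)

# backward() side: freeze the non-tensor GEMM arguments with functools.partial
# and enqueue the tensors next to the callable.
store.put([inp, grad_out], functools.partial(fake_wgrad_gemm, scale=2.0))

# backward_dw() side: pop() executes func(*tensor_list) and also hands back the
# tensor list so callers can reach preallocated output buffers. Popping an empty
# queue raises a RuntimeError.
(wgrad, _, _, _), tensor_list = store.pop()
print(wgrad.shape)  # torch.Size([32, 16])
```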
......@@ -19,7 +19,7 @@ import torch.nn.functional as F
import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe
from ._common import _ParameterInitMeta
from ._common import _ParameterInitMeta, noop_cat
from ..fp8 import (
MXFP8BlockScalingRecipeState,
DelayedScalingRecipeState,
......@@ -1140,6 +1140,27 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
)
def backward_dw(self):
"""
Execute the delayed weight gradient computation.
This method is called after the main backward pass to compute weight gradients.
"""
if self.wgrad_store is None or not self.wgrad_store.delay_wgrad_compute():
return
with torch.cuda.nvtx.range(f"_{self.__class__.__name__}_wgrad"):
(wgrad, grad_bias_, _, _), _ = self.wgrad_store.pop()
if not self.fuse_wgrad_accumulation:
unfused_weights = [getattr(self, name) for name in self.weight_names]
weight_tensor = noop_cat(unfused_weights)
if weight_tensor.grad is None:
weight_tensor.grad = wgrad.to(weight_tensor.dtype)
if self.use_bias:
bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names])
if bias_tensor.grad is None:
bias_tensor.grad = grad_bias_.to(bias_tensor.dtype)
del grad_bias_
del wgrad
def _validate_name(self):
"""
Validate name passed to the module.
......
......@@ -5,6 +5,7 @@
"""GroupedLinear API"""
from typing import Union, Optional, Callable, Tuple, List
import functools
import torch
import transformer_engine_torch as tex
......@@ -17,6 +18,7 @@ from .base import (
_2X_ACC_DGRAD,
_2X_ACC_WGRAD,
)
from ._common import WeightGradStore
from ..fp8 import FP8GlobalStateManager
from ..utils import (
divide,
......@@ -64,6 +66,7 @@ class _GroupedLinear(torch.autograd.Function):
is_first_microbatch: Union[bool, None],
fp8: bool,
fp8_calibration: bool,
wgrad_store: WeightGradStore,
input_quantizers: List[Quantizer],
weight_quantizers: List[Quantizer],
output_quantizers: List[Quantizer],
......@@ -220,6 +223,7 @@ class _GroupedLinear(torch.autograd.Function):
ctx.reduce_and_update_bwd_fp8_tensors
or FP8GlobalStateManager.is_first_fp8_module()
)
ctx.wgrad_store = wgrad_store
# [*, in_features] -> [*, out_features] except first dimension changes for SP
return out.view(-1, *inp.shape[1:-1], out.shape[-1])
......@@ -328,13 +332,10 @@ class _GroupedLinear(torch.autograd.Function):
torch.empty(w.size(), dtype=ctx.activation_dtype, device=ctx.device)
for w in weights
]
# WGRAD
_, grad_biases_, _ = general_grouped_gemm(
inputmats,
grad_output,
wgrad_list,
ctx.activation_dtype,
get_multi_stream_cublas_workspace(),
grouped_gemm_wgrad = functools.partial(
general_grouped_gemm,
out_dtype=ctx.activation_dtype,
workspaces=get_multi_stream_cublas_workspace(),
layout="NT",
grad=True,
m_splits=ctx.m_splits,
......@@ -343,13 +344,19 @@ class _GroupedLinear(torch.autograd.Function):
use_split_accumulator=wgrad_gemm_use_split_accumulator,
accumulate=accumulate_wgrad_into_param_main_grad,
)
for i in range(ctx.num_gemms):
if grad_biases[i] is None:
grad_biases[i] = grad_biases_[i]
del grad_biases_
# WGRAD
if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
ctx.wgrad_store.put([inputmats, grad_output, wgrad_list], grouped_gemm_wgrad)
else:
_, grad_biases_, _ = grouped_gemm_wgrad(inputmats, grad_output, wgrad_list)
for i in range(ctx.num_gemms):
if grad_biases[i] is None:
grad_biases[i] = grad_biases_[i]
del grad_biases_
# Deallocate input tensor
clear_tensor_data(*inputmats)
# Deallocate input tensor
clear_tensor_data(*inputmats)
def handle_custom_ddp_from_mcore(weight, wgrad):
if ctx.weights_requires_grad:
......@@ -385,7 +392,14 @@ class _GroupedLinear(torch.autograd.Function):
else:
wgrad_list = [None] * ctx.num_gemms
if not ctx.use_bias:
if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
wgrad_list = [None] * ctx.num_gemms
if not ctx.use_bias or (
ctx.wgrad_store is not None
and ctx.wgrad_store.delay_wgrad_compute()
and not ctx.fp8
):
grad_biases = [None] * ctx.num_gemms
if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing():
......@@ -408,6 +422,7 @@ class _GroupedLinear(torch.autograd.Function):
None,
None,
None,
None,
*wgrad_list,
*grad_biases,
)
......@@ -456,6 +471,8 @@ class GroupedLinear(TransformerEngineBaseModule):
it controls the type used to allocate the initial parameters. Useful when
the model is trained with lower precision and the original FP32 parameters
would not fit in GPU memory.
delay_wgrad_compute : bool, default = `False`
Whether or not to delay weight gradient computation. If set to `True`,
it's the user's responsibility to call `module.backward_dw` to compute
weight gradients.
Note: GroupedLinear doesn't really handle the TP communications inside. The `tp_size` and
`parallel_mode` are used to determine the shapes of weights and biases.
......@@ -482,6 +499,7 @@ class GroupedLinear(TransformerEngineBaseModule):
ub_overlap_rs: bool = False,
ub_overlap_ag: bool = False,
ub_name: Optional[str] = None,
delay_wgrad_compute: bool = False,
) -> None:
super().__init__()
......@@ -502,6 +520,8 @@ class GroupedLinear(TransformerEngineBaseModule):
self.get_rng_state_tracker = get_rng_state_tracker
self.rng_tracker_name = rng_tracker_name
self.wgrad_store = WeightGradStore(delay_wgrad_compute)
self._offsets = {"input": 0, "weight": 1, "output": 2, "grad_output": 0, "grad_input": 1}
self._num_fp8_tensors_per_gemm = {
"fwd": 3,
......@@ -707,6 +727,7 @@ class GroupedLinear(TransformerEngineBaseModule):
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
input_quantizers,
weight_quantizers,
output_quantizers,
......@@ -727,6 +748,30 @@ class GroupedLinear(TransformerEngineBaseModule):
return out, [cast_if_needed(b, self.activation_dtype) for b in bias_tensors]
return out
def backward_dw(self):
"""
Execute the delayed weight gradient computation.
This method is called after the main backward pass to compute weight gradients.
"""
if self.wgrad_store is None or not self.wgrad_store.delay_wgrad_compute():
return
with torch.cuda.nvtx.range("_GroupedLinear_wgrad"):
(_, grad_biases_, _), tensor_list = self.wgrad_store.pop()
wgrad_list = tensor_list[2]
if not self.fuse_wgrad_accumulation:
for i in range(self.num_gemms):
weight_param = getattr(self, f"weight{i}")
if weight_param.grad is None:
weight_param.grad = wgrad_list[i].to(weight_param.dtype)
if self.use_bias:
for i in range(self.num_gemms):
bias_param = getattr(self, f"bias{i}")
if bias_param.grad is None:
bias_param.grad = grad_biases_[i].to(bias_param.dtype)
del grad_biases_
del wgrad_list
del tensor_list
def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None:
"""Customize quantizers based on current scaling recipe + linear."""
assert (
......
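For `GroupedLinear` the same two-step pattern applies: the whole group's wgrad is enqueued as a single grouped GEMM and flushed by one `backward_dw()` call. A sketch with hypothetical sizes (the `m_splits` forward argument is the pre-existing GroupedLinear API, not something introduced here):

```python
import torch
import transformer_engine.pytorch as te

num_gemms = 4
grouped = te.GroupedLinear(
    num_gemms, 1024, 4096, bias=True, params_dtype=torch.bfloat16, device="cuda",
    delay_wgrad_compute=True,
)

# Tokens routed to each expert/GEMM; their sum must match the input's first dim.
m_splits = [64] * num_gemms
inp = torch.randn(sum(m_splits), 1024, dtype=torch.bfloat16, device="cuda",
                  requires_grad=True)

out = grouped(inp, m_splits)
out.sum().backward()   # dgrad only; the grouped wgrad GEMM is queued in wgrad_store
grouped.backward_dw()  # run the queued grouped wgrad GEMM, fill weight{i}.grad / bias{i}.grad
```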
......@@ -9,6 +9,7 @@ from typing import Callable, Dict, Optional, Tuple, Union
from functools import reduce
from operator import mul as multiply_op
import functools
import torch
from torch.nn import init
......@@ -52,7 +53,7 @@ from ..distributed import (
from ..constants import GemmParallelModes, dist_group_type
from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing
from ._common import apply_normalization, noop_cat, _fix_gathered_fp8_transpose
from ._common import apply_normalization, noop_cat, _fix_gathered_fp8_transpose, WeightGradStore
from ..tensor.quantized_tensor import (
QuantizedTensor,
Quantizer,
......@@ -91,6 +92,7 @@ class _LayerNormLinear(torch.autograd.Function):
is_first_microbatch: Union[bool, None],
fp8: bool,
fp8_calibration: bool,
wgrad_store: WeightGradStore,
fuse_wgrad_accumulation: bool,
input_quantizer: Optional[Quantizer],
weight_quantizer: Optional[Quantizer],
......@@ -438,6 +440,7 @@ class _LayerNormLinear(torch.autograd.Function):
ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module()
if in_fp8_activation_recompute_phase():
FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module
ctx.wgrad_store = wgrad_store
ctx.debug = debug
# Row Parallel Linear
......@@ -752,15 +755,14 @@ class _LayerNormLinear(torch.autograd.Function):
# wgrad GEMM
# Note: Fuse with bgrad computation if needed
nvtx_range_push(f"{nvtx_label}.wgrad_gemm")
wgrad, grad_bias_, *_, rs_out = general_gemm(
ln_out_total,
grad_output,
get_workspace(),
layout="NT",
grad=True,
general_gemm_wgrad = functools.partial(
general_gemm,
out_dtype=(
main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype
),
workspace=get_workspace(),
layout="NT",
grad=True,
bias=(bias if (grad_bias is None and not ctx.fp8) else None),
out=main_grad if ctx.fuse_wgrad_accumulation else None,
use_split_accumulator=use_split_accumulator,
......@@ -771,6 +773,20 @@ class _LayerNormLinear(torch.autograd.Function):
extra_output=rs_out,
bulk_overlap=ctx.ub_bulk_wgrad,
)
if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
ctx.wgrad_store.put([ln_out_total, grad_output], general_gemm_wgrad)
else:
wgrad, grad_bias_, _, rs_out = general_gemm_wgrad(ln_out_total, grad_output)
if grad_bias is None:
grad_bias = grad_bias_
del grad_bias_
# Deallocate input tensor
if not ctx.return_layernorm_output:
# TODO (pgadzinski) - deallocate transpose only # pylint: disable=fixme
clear_tensor_data(ln_out_total)
nvtx_range_pop(f"{nvtx_label}.wgrad_gemm")
if ctx.ub_bulk_wgrad:
......@@ -779,16 +795,11 @@ class _LayerNormLinear(torch.autograd.Function):
else:
dgrad = ub_obj_wgrad.get_buffer(None, local_chunk=True)
if grad_bias is None:
grad_bias = grad_bias_
del grad_bias_
# Deallocate input tensor
if not ctx.return_layernorm_output:
# TODO (pgadzinski) - deallocate transpose only # pylint: disable=fixme
clear_tensor_data(ln_out_total)
# Don't return grad bias if not needed
if not ctx.use_bias:
grad_bias = None
# Make sure all tensor-parallel communication is finished
# Synchronize tensor parallel communication
if ln_out_total_work is not None:
ln_out_total_work.wait()
ln_out_total_work = None
......@@ -870,6 +881,7 @@ class _LayerNormLinear(torch.autograd.Function):
None, # is_first_microbatch
None, # fp8
None, # fp8_calibration
None, # wgrad_store
None, # fuse_wgrad_accumulation
None, # input_quantizer
None, # weight_quantizer
......@@ -992,6 +1004,10 @@ class LayerNormLinear(TransformerEngineBaseModule):
it controls the type used to allocate the initial parameters. Useful when
the model is trained with lower precision and the original FP32 parameters
would not fit in GPU memory.
delay_wgrad_compute : bool, default = `False`
Whether or not to delay weight gradient computation. If set to `True`,
it's the user's responsibility to call `module.backward_dw` to compute
weight gradients.
symmetric_ar_type : {None, 'multimem_all_reduce', 'two_shot', 'one_shot'}, default = None
Type of symmetric memory all-reduce to use during the forward pass.
This can help in latency bound communication situations.
......@@ -1026,6 +1042,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
ub_bulk_wgrad: bool = False,
ub_bulk_dgrad: bool = False,
ub_name: Optional[str] = None,
delay_wgrad_compute: bool = False,
symmetric_ar_type: Optional[str] = None,
name: str = None,
) -> None:
......@@ -1045,6 +1062,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.zero_centered_gamma = zero_centered_gamma
self.symmetric_ar_type = symmetric_ar_type
self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad)
self.name = name
if TEDebugState.debug_enabled:
self._turn_off_unsupported_features_in_debug() # turn off userbuffers
......@@ -1423,6 +1441,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
self.fuse_wgrad_accumulation,
input_quantizer,
weight_quantizer,
......
......@@ -8,6 +8,7 @@ import warnings
from typing import Callable, Optional, Tuple, Union
from functools import reduce
from operator import mul as multiply_op
import functools
import torch
from torch.nn.parameter import Parameter
......@@ -65,7 +66,7 @@ from ..tensor.float8_tensor import (
)
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ._common import apply_normalization, _fix_gathered_fp8_transpose
from ._common import apply_normalization, _fix_gathered_fp8_transpose, WeightGradStore
from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ..tensor.quantized_tensor import (
QuantizedTensor,
......@@ -155,6 +156,7 @@ class _LayerNormMLP(torch.autograd.Function):
is_first_microbatch: Union[bool, None],
fp8: bool,
fp8_calibration: bool,
wgrad_store: WeightGradStore,
fuse_wgrad_accumulation: bool,
fc1_input_quantizer: Optional[Quantizer],
fc1_weight_quantizer: Optional[Quantizer],
......@@ -587,6 +589,8 @@ class _LayerNormMLP(torch.autograd.Function):
if in_fp8_activation_recompute_phase():
FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module
ctx.wgrad_store = wgrad_store
# Row Parallel Linear
if ub_overlap_rs:
fc2_out = rs_out
......@@ -820,15 +824,14 @@ class _LayerNormMLP(torch.autograd.Function):
grad_arg = True
if ctx.fp8 and ctx.fp8_recipe.float8_block_scaling():
grad_arg = False
fc2_wgrad, fc2_bias_grad_, *_ = general_gemm(
act_out,
grad_output,
get_workspace(),
general_gemm_fc2_wgrad = functools.partial(
general_gemm,
out_dtype=(
origin_fc2_weight.main_grad.dtype
if ctx.fuse_wgrad_accumulation
else ctx.activation_dtype
),
workspace=get_workspace(),
quantization_params=ctx.fc2_grad_weight_quantizer, # wgrad in high precision
layout="NT",
grad=grad_arg,
......@@ -837,13 +840,28 @@ class _LayerNormMLP(torch.autograd.Function):
use_split_accumulator=_2X_ACC_WGRAD,
out=origin_fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
)
if fc2_bias_grad is None:
if ctx.fp8 and ctx.fp8_recipe.float8_block_scaling() and fc2_bias is not None:
# BGRAD not fused with GEMM for float8 blockwise gemm.
fc2_bias_grad_ = act_out.view(-1, act_out.shape[-1]).sum(dim=0)
fc2_bias_grad = fc2_bias_grad_
del fc2_bias_grad_
clear_tensor_data(act_out)
if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
ctx.wgrad_store.put([act_out, grad_output], general_gemm_fc2_wgrad)
fc2_wgrad = None
else:
fc2_wgrad, fc2_bias_grad_, *_ = general_gemm_fc2_wgrad(
act_out,
grad_output,
)
if fc2_bias_grad is None:
if (
ctx.fp8
and ctx.fp8_recipe.float8_block_scaling()
and fc2_bias is not None
):
# BGRAD not fused with GEMM for float8 blockwise gemm.
fc2_bias_grad_ = act_out.view(-1, act_out.shape[-1]).sum(dim=0)
fc2_bias_grad = fc2_bias_grad_
del fc2_bias_grad_
if ctx.wgrad_store is not None and not ctx.wgrad_store.delay_wgrad_compute():
clear_tensor_data(act_out)
# bias computation
fc1_bias_grad = None
......@@ -1017,15 +1035,14 @@ class _LayerNormMLP(torch.autograd.Function):
)
# wgrad GEMM
fc1_wgrad_outputs = general_gemm(
ln_out_total,
dact,
get_workspace(),
general_gemm_fc1_wgrad = functools.partial(
general_gemm,
out_dtype=(
origin_fc1_weight.main_grad.dtype
if ctx.fuse_wgrad_accumulation
else ctx.activation_dtype
),
workspace=get_workspace(),
layout="NT",
quantization_params=ctx.fc1_grad_weight_quantizer,
grad=fuse_gemm_and_bias_fc1_wgrad,
......@@ -1037,13 +1054,23 @@ class _LayerNormMLP(torch.autograd.Function):
extra_output=fc1_dgrad_rs_out,
bulk_overlap=ctx.ub_bulk_wgrad,
)
if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
ctx.wgrad_store.put([ln_out_total, dact], general_gemm_fc1_wgrad)
fc1_wgrad = None
if fuse_gemm_and_bias_fc1_wgrad:
fc1_bias_grad = None
else:
fc1_wgrad_outputs = general_gemm_fc1_wgrad(
ln_out_total,
dact,
)
clear_tensor_data(ln_out_total, dact)
clear_tensor_data(ln_out_total, dact)
if fuse_gemm_and_bias_fc1_wgrad:
fc1_wgrad, fc1_bias_grad, *_ = fc1_wgrad_outputs
else:
fc1_wgrad, *_ = fc1_wgrad_outputs
if fuse_gemm_and_bias_fc1_wgrad:
fc1_wgrad, fc1_bias_grad, *_ = fc1_wgrad_outputs
else:
fc1_wgrad, *_ = fc1_wgrad_outputs
if ctx.ub_bulk_wgrad:
if ub_obj_fc1_wgrad.is_fp8_ubuf():
......@@ -1160,6 +1187,7 @@ class _LayerNormMLP(torch.autograd.Function):
None, # is_first_microbatch
None, # fp8
None, # fp8_calibration
None, # wgrad_store
None, # fuse_wgrad_accumulation
None, # fc1_input_quantizer,
None, # fc1_weight_quantizer,
......@@ -1296,6 +1324,10 @@ class LayerNormMLP(TransformerEngineBaseModule):
batch size per training step. Needed for JIT Warmup, a technique where jit
fused functions are warmed up before training to ensure same kernels are
used for forward propagation and activation recompute phase.
delay_wgrad_compute : bool, default = `False`
Whether or not to delay weight gradient computation. If set to `True`,
it's the user's responsibility to call `module.backward_dw` to compute
weight gradients.
symmetric_ar_type : {None, 'multimem_all_reduce', 'two_shot', 'one_shot'}, default = None
Type of symmetric memory all-reduce to use during the forward pass.
This can help in latency bound communication situations.
......@@ -1333,6 +1365,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
ub_overlap_rs_dgrad: bool = False,
ub_bulk_dgrad: bool = False,
ub_bulk_wgrad: bool = False,
delay_wgrad_compute: bool = False,
symmetric_ar_type: Optional[str] = None,
) -> None:
super().__init__()
......@@ -1365,6 +1398,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
if TEDebugState.debug_enabled:
self._turn_off_unsupported_features_in_debug() # turn off userbuffers
self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad)
if tp_group is None:
self.tp_size = tp_size
if tp_size == 1:
......@@ -1636,6 +1671,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
self.fuse_wgrad_accumulation,
fc1_input_quantizer,
fc1_weight_quantizer,
......@@ -1835,3 +1871,41 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT1
].amax_reduction_group = self.tp_group
def backward_dw(self):
"""
Execute the delayed weight gradient computation.
This method is called after the main backward pass to compute weight gradients.
"""
if self.wgrad_store is None or not self.wgrad_store.delay_wgrad_compute():
return
with torch.cuda.nvtx.range("_LayerNormMLP_wgrad"):
(fc2_wgrad, fc2_bias_grad_, *_), tensor_list_fc2 = self.wgrad_store.pop()
if self.use_bias and self.fc1_bias.grad is None:
(fc1_wgrad, fc1_bias_grad, *_), _ = self.wgrad_store.pop()
else:
(fc1_wgrad, *_), _ = self.wgrad_store.pop()
fc1_bias_grad = None
if self.use_bias:
if self.fc2_bias.grad is None:
if (
self.fp8
and FP8GlobalStateManager.get_fp8_recipe().float8_block_scaling()
and self.apply_bias
and not self.gemm_bias_unfused_add
):
act_out = tensor_list_fc2[0]
# BGRAD not fused with GEMM for float8 blockwise gemm.
fc2_bias_grad_ = act_out.view(-1, act_out.shape[-1]).sum(dim=0)
self.fc2_bias.grad = fc2_bias_grad_.to(self.fc2_bias.dtype)
if self.fc1_bias.grad is None:
self.fc1_bias.grad = fc1_bias_grad.to(self.fc1_bias.dtype)
if not self.fuse_wgrad_accumulation:
if self.fc2_weight.grad is None:
self.fc2_weight.grad = fc2_wgrad.to(self.fc2_weight.dtype)
if self.fc1_weight.grad is None:
self.fc1_weight.grad = fc1_wgrad.to(self.fc1_weight.dtype)
del fc2_bias_grad_
del fc2_wgrad
del fc1_wgrad
del fc1_bias_grad
......@@ -7,6 +7,7 @@ from typing import Callable, Dict, Optional, Tuple, Union
from functools import reduce
from operator import mul as multiply_op
import functools
import torch
import transformer_engine_torch as tex
......@@ -22,7 +23,7 @@ from .base import (
_2X_ACC_DGRAD,
_2X_ACC_WGRAD,
)
from ._common import noop_cat, _fix_gathered_fp8_transpose
from ._common import noop_cat, _fix_gathered_fp8_transpose, WeightGradStore
from ..fp8 import FP8GlobalStateManager
from ..utils import (
cast_if_needed,
......@@ -85,6 +86,7 @@ class _Linear(torch.autograd.Function):
is_first_microbatch: Union[bool, None],
fp8: bool,
fp8_calibration: bool,
wgrad_store: WeightGradStore,
input_quantizer: Optional[Quantizer],
weight_quantizer: Optional[Quantizer],
output_quantizer: Optional[Quantizer],
......@@ -380,6 +382,7 @@ class _Linear(torch.autograd.Function):
ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module()
if in_fp8_activation_recompute_phase():
FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module
ctx.wgrad_store = wgrad_store
# Row Parallel Linear
if ub_overlap_rs_fprop:
......@@ -675,15 +678,14 @@ class _Linear(torch.autograd.Function):
# wgrad GEMM
# Note: Fuse with bgrad computation if needed
nvtx_range_push(f"{nvtx_label}.wgrad_gemm")
wgrad, grad_bias_, _, rs_out = general_gemm(
inputmat_total,
grad_output,
get_workspace(),
layout="NT",
grad=True,
general_gemm_wgrad = functools.partial(
general_gemm,
out_dtype=(
main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype
),
workspace=get_workspace(),
layout="NT",
grad=True,
bias=(bias if (grad_bias is None and not ctx.fp8) else None),
out=main_grad if ctx.fuse_wgrad_accumulation else None,
use_split_accumulator=use_split_accumulator,
......@@ -694,6 +696,19 @@ class _Linear(torch.autograd.Function):
extra_output=rs_out,
bulk_overlap=ctx.ub_bulk_wgrad,
)
if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
ctx.wgrad_store.put([inputmat_total, grad_output], general_gemm_wgrad)
else:
wgrad, grad_bias_, _, rs_out = general_gemm_wgrad(inputmat_total, grad_output)
if grad_bias is None:
grad_bias = grad_bias_
del grad_bias_
# Deallocate input tensor
if ctx.owns_input:
clear_tensor_data(inputmat_total)
nvtx_range_pop(f"{nvtx_label}.wgrad_gemm")
if ctx.ub_bulk_wgrad:
......@@ -702,14 +717,6 @@ class _Linear(torch.autograd.Function):
else:
dgrad = ub_obj_wgrad.get_buffer(ctx.grad_input_quantizer, local_chunk=True)
if grad_bias is None:
grad_bias = grad_bias_
del grad_bias_
# Deallocate input tensor
if ctx.owns_input:
clear_tensor_data(inputmat_total)
# Don't return grad bias if not needed
if not ctx.use_bias:
grad_bias = None
......@@ -761,6 +768,7 @@ class _Linear(torch.autograd.Function):
None, # is_first_microbatch
None, # fp8
None, # fp8_calibration
None, # wgrad_store
None, # input_quantizer
None, # weight_quantizer
None, # output_quantizer
......@@ -861,6 +869,10 @@ class Linear(TransformerEngineBaseModule):
it controls the type used to allocate the initial parameters. Useful when
the model is trained with lower precision and the original FP32 parameters
would not fit in GPU memory.
delay_wgrad_compute : bool, default = `False`
Whether or not to delay weight gradient computation. If set to `True`,
it's the user's responsibility to call `module.backward_dw` to compute
weight gradients.
symmetric_ar_type : {None, 'multimem_all_reduce', 'two_shot', 'one_shot'}, default = None
Type of symmetric memory all-reduce to use during the forward pass.
This can help in latency bound communication situations.
......@@ -891,6 +903,7 @@ class Linear(TransformerEngineBaseModule):
ub_bulk_dgrad: bool = False,
ub_bulk_wgrad: bool = False,
ub_name: Optional[str] = None,
delay_wgrad_compute: bool = False,
symmetric_ar_type: Optional[str] = None,
name: Optional[str] = None,
) -> None:
......@@ -911,6 +924,8 @@ class Linear(TransformerEngineBaseModule):
if TEDebugState.debug_enabled:
self._turn_off_unsupported_features_in_debug() # turn off userbuffers
self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad)
if device == "meta":
assert parameters_split is None, "Cannot split module parameters on 'meta' device."
if tp_group is None:
......@@ -1241,6 +1256,7 @@ class Linear(TransformerEngineBaseModule):
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
input_quantizer,
weight_quantizer,
output_quantizer,
......