Unverified Commit 9f8aaddf authored by Hongbin Liu, committed by GitHub

Split wgrad&dgrad from backward() to support a2a overlap (#1653)



* split wgrad for GroupedLinear
Signed-off-by: Hongbin Liu <hongbinl@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Hongbin Liu <hongbinl@nvidia.com>

* support wgrad split for linear and ln_linear
Signed-off-by: Hongbin Liu <hongbinl@nvidia.com>

* add comments and fix WeightGradStore
Signed-off-by: Hongbin Liu <hongbinl@nvidia.com>

* support bias and fix unit tests
Signed-off-by: Hongbin Liu <hongbinl@nvidia.com>

* minor fix
Signed-off-by: Hongbin Liu <hongbinl@nvidia.com>

* support fuse_grad_accumulation=false
Signed-off-by: Hongbin Liu <hongbinl@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* add wgrad split for layernorm_mlp
Signed-off-by: Hongbin Liu <hongbinl@nvidia.com>

* minor fix
Signed-off-by: Hongbin Liu <hongbinl@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix unittest
Signed-off-by: Hongbin Liu <hongbinl@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add unittest for distributed interface apply Dener's suggestion
Signed-off-by: Hongbin Liu <hongbinl@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* minor fix
Signed-off-by: Hongbin Liu <hongbinl@nvidia.com>

* replace split_bw with delay_wgrad_compute
Signed-off-by: Hongbin Liu <hongbinl@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update transformer_engine/pytorch/module/layernorm_mlp.py
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Update transformer_engine/pytorch/module/linear.py
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Update transformer_engine/pytorch/module/layernorm_linear.py
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* remove comments
Signed-off-by: Hongbin Liu <hongbinl@nvidia.com>

---------
Signed-off-by: Hongbin Liu <hongbinl@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Hongbin Liu <hongbinl@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 1a6a6d7b
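For orientation, here is a minimal, illustrative sketch of the workflow this change enables (single GPU, made-up shapes; only `delay_wgrad_compute` and `backward_dw` come from this PR):

```python
import torch
import transformer_engine.pytorch as te

# With delay_wgrad_compute=True, backward() computes only the data gradient and
# queues the weight-gradient GEMM in the module's WeightGradStore.
linear = te.Linear(1024, 1024, params_dtype=torch.bfloat16, device="cuda",
                   delay_wgrad_compute=True)

inp = torch.randn(32, 1024, dtype=torch.bfloat16, device="cuda", requires_grad=True)
linear(inp).sum().backward()   # dgrad only; weight gradients are not computed yet

# ...a communication collective (e.g. the MoE all-to-all) can be launched here...

linear.backward_dw()           # run the stored wgrad (and bias-grad) GEMM now
# weight/bias gradients are populated by the delayed GEMM at this point
```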
@@ -300,6 +300,11 @@ def _loss_backward(output_single_node, output_distributed):
    LOSS_FN(output_distributed, target).backward()


+def _loss_backward_dw(model_single_node, model_distributed):
+    model_single_node.backward_dw()
+    model_distributed.backward_dw()
+
+
def _alloc_main_grad(model_single_node, model_distributed):
    for model in [model_single_node, model_distributed]:
        for param in model.parameters():
@@ -473,6 +478,10 @@ def _test_linear(parallel_mode=None, sequence_parallel=False, **kwargs):
    # Compute loss and backpropagate
    _loss_backward(output_single_node, output_distributed)

+    # Compute delayed weight gradient
+    if "delay_wgrad_compute" in kwargs:
+        _loss_backward_dw(model_single_node, model_distributed)
+
    # Validate outputs and gradients
    _check_outputs(output_single_node, output_distributed)
@@ -494,6 +503,7 @@ def test_linear():
        {"fuse_wgrad_accumulation": True},
        {"return_bias": True},
        {"params_dtype": torch.float16},
+        {"delay_wgrad_compute": True},
    ]
    for kwargs in kwargs_list:
        for parallel_mode in ["column", "row"]:
@@ -645,6 +655,10 @@ def _test_layernorm_linear(parallel_mode=None, sequence_parallel=False, **kwargs
    # Compute loss and backpropagate
    _loss_backward(output_single_node, output_distributed)

+    # Compute delayed weight gradient
+    if "delay_wgrad_compute" in kwargs:
+        _loss_backward_dw(model_single_node, model_distributed)
+
    # Validate outputs and gradients
    _check_outputs(output_single_node, output_distributed)
@@ -667,6 +681,7 @@ def test_layernorm_linear():
        {"params_dtype": torch.float16},
        {"zero_centered_gamma": False},
        {"return_layernorm_output": True},
+        {"delay_wgrad_compute": True},
    ]
    for kwargs in kwargs_list:
        for parallel_mode in ["column"]:
@@ -746,6 +761,9 @@ def _test_layernorm_mlp(set_parallel_mode=None, sequence_parallel=False, **kwarg
    # Compute loss and backpropagate
    _loss_backward(output_single_node, output_distributed)

+    if "delay_wgrad_compute" in kwargs:
+        _loss_backward_dw(model_single_node, model_distributed)
+
    # Validate outputs and gradients
    _check_outputs(output_single_node, output_distributed)
@@ -771,6 +789,7 @@ def test_layernorm_mlp():
        {"fuse_wgrad_accumulation": True},
        {"return_bias": True},
        {"return_layernorm_output": True},
+        {"delay_wgrad_compute": True},
    ]
    for kwargs in kwargs_list:
......
@@ -1036,7 +1036,7 @@ def test_mha_accuracy(dtype, bs, model, mask_type):
    assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype])


-def _test_granular_accuracy(block, bs, dtype, config):
+def _test_granular_accuracy(block, bs, dtype, config, delay_wgrad_compute=False):
    reset_rng_states()

    inp_hidden_states = torch.randn(
@@ -1052,12 +1052,18 @@ def _test_granular_accuracy(block, bs, dtype, config):
        out = out[0]
    loss = out.sum()
    loss.backward()
+    if delay_wgrad_compute:
+        block.backward_dw()

    torch.cuda.synchronize()
    outputs = [out, inp_hidden_states.grad]
    for p in block.parameters():
        if p.requires_grad:
-            outputs.append(p.grad)
+            if getattr(p, "main_grad", None) is not None:
+                outputs.append(p.main_grad)
+                assert p.grad is None  # grad should be None if fuse_wgrad_accumulation is True
+            else:
+                outputs.append(p.grad)
    return outputs


@@ -1191,6 +1197,54 @@ def test_linear_accuracy(dtype, bs, model, return_bias, bias):
    assert_allclose(te_output, torch_output, tolerance, rtol[dtype])
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("bias", all_boolean)
@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean)
def test_linear_accuracy_delay_wgrad_compute(dtype, bs, model, bias, fuse_wgrad_accumulation):
config = model_configs[model]
te_linear_ref = Linear(
config.hidden_size,
4 * config.hidden_size,
bias=bias,
params_dtype=dtype,
device="cuda",
delay_wgrad_compute=False,
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
).eval()
te_linear = Linear(
config.hidden_size,
4 * config.hidden_size,
bias=bias,
params_dtype=dtype,
device="cuda",
delay_wgrad_compute=True,
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
).eval()
# Share params
with torch.no_grad():
te_linear_ref.weight = Parameter(te_linear.weight.clone())
if bias:
te_linear_ref.bias = Parameter(te_linear.bias.clone())
if fuse_wgrad_accumulation:
weight = getattr(te_linear, f"weight")
weight.main_grad = torch.rand_like(weight, dtype=torch.float32)
te_linear_ref.weight.main_grad = weight.main_grad.clone()
te_outputs = _test_granular_accuracy(te_linear, bs, dtype, config, delay_wgrad_compute=True)
te_outputs_ref = _test_granular_accuracy(
te_linear_ref, bs, dtype, config, delay_wgrad_compute=False
)
    # Should be a bit-wise match
    for i, (o, o_ref) in enumerate(zip(te_outputs, te_outputs_ref)):
        torch.testing.assert_close(o, o_ref, rtol=0, atol=0)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
@@ -1376,6 +1430,67 @@ def test_layernorm_linear_accuracy(
    assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype])
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("bias", all_boolean)
@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean)
def test_layernorm_linear_accuracy_delay_wgrad_compute(
dtype, bs, model, normalization, zero_centered_gamma, bias, fuse_wgrad_accumulation
):
config = model_configs[model]
ln_linear_ref = LayerNormLinear(
config.hidden_size,
4 * config.hidden_size,
config.eps,
bias=bias,
normalization=normalization,
params_dtype=dtype,
zero_centered_gamma=zero_centered_gamma,
device="cuda",
delay_wgrad_compute=False,
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
).eval()
ln_linear = LayerNormLinear(
config.hidden_size,
4 * config.hidden_size,
config.eps,
bias=bias,
normalization=normalization,
params_dtype=dtype,
zero_centered_gamma=zero_centered_gamma,
device="cuda",
delay_wgrad_compute=True,
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
).eval()
# Share params
with torch.no_grad():
ln_linear_ref.layer_norm_weight = Parameter(ln_linear.layer_norm_weight.clone())
if normalization != "RMSNorm":
ln_linear_ref.layer_norm_bias = Parameter(ln_linear.layer_norm_bias.clone())
ln_linear_ref.weight = Parameter(ln_linear.weight.clone())
if bias:
ln_linear_ref.bias = Parameter(ln_linear.bias.clone())
if fuse_wgrad_accumulation:
weight = getattr(ln_linear, f"weight")
weight.main_grad = torch.rand_like(weight, dtype=torch.float32)
ln_linear_ref.weight.main_grad = weight.main_grad.clone()
te_outputs = _test_granular_accuracy(ln_linear, bs, dtype, config, delay_wgrad_compute=True)
te_outputs_ref = _test_granular_accuracy(
ln_linear_ref, bs, dtype, config, delay_wgrad_compute=False
)
    # Should be a bit-wise match
    for i, (o, o_ref) in enumerate(zip(te_outputs, te_outputs_ref)):
        torch.testing.assert_close(o, o_ref, rtol=0, atol=0)


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["small"])
@@ -1452,8 +1567,78 @@ def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization, ret
    assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype])
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("bias", all_boolean)
@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean)
def test_layernorm_mlp_accuracy_delay_wgrad_compute(
dtype, bs, model, activation, normalization, bias, fuse_wgrad_accumulation
):
config = model_configs[model]
ln_mlp = LayerNormMLP(
hidden_size=config.hidden_size,
ffn_hidden_size=4 * config.hidden_size,
eps=config.eps,
bias=bias,
normalization=normalization,
params_dtype=dtype,
device="cuda",
delay_wgrad_compute=True,
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
).eval()
ln_mlp_ref = LayerNormMLP(
hidden_size=config.hidden_size,
ffn_hidden_size=4 * config.hidden_size,
eps=config.eps,
bias=bias,
normalization=normalization,
params_dtype=dtype,
device="cuda",
delay_wgrad_compute=False,
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
).eval()
# Share params
with torch.no_grad():
ln_mlp_ref.layer_norm_weight = Parameter(ln_mlp.layer_norm_weight.clone())
if normalization != "RMSNorm":
ln_mlp_ref.layer_norm_bias = Parameter(ln_mlp.layer_norm_bias.clone())
ln_mlp_ref.fc1_weight = Parameter(ln_mlp.fc1_weight.clone())
ln_mlp_ref.fc2_weight = Parameter(ln_mlp.fc2_weight.clone())
if bias:
ln_mlp_ref.fc1_bias = Parameter(ln_mlp.fc1_bias.clone())
ln_mlp_ref.fc2_bias = Parameter(ln_mlp.fc2_bias.clone())
if fuse_wgrad_accumulation:
ln_mlp.fc1_weight.main_grad = torch.rand_like(ln_mlp.fc1_weight, dtype=torch.float32)
ln_mlp_ref.fc1_weight.main_grad = ln_mlp.fc1_weight.main_grad.clone()
ln_mlp.fc2_weight.main_grad = torch.rand_like(ln_mlp.fc2_weight, dtype=torch.float32)
ln_mlp_ref.fc2_weight.main_grad = ln_mlp.fc2_weight.main_grad.clone()
te_outputs = _test_granular_accuracy(ln_mlp, bs, dtype, config, delay_wgrad_compute=True)
te_outputs_ref = _test_granular_accuracy(
ln_mlp_ref, bs, dtype, config, delay_wgrad_compute=False
)
    # Should be a bit-wise match
    for i, (o, o_ref) in enumerate(zip(te_outputs, te_outputs_ref)):
        torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
def _test_grouped_linear_accuracy(
-    block, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation
+    block,
+    num_gemms,
+    bs,
+    dtype,
+    config,
+    recipe,
+    fp8,
+    fuse_wgrad_accumulation,
+    delay_wgrad_compute=False,
):
    reset_rng_states()
    if fp8:
@@ -1495,6 +1680,12 @@ def _test_grouped_linear_accuracy(
        )
    loss = out.sum()
    loss.backward()
+    if delay_wgrad_compute:
+        if isinstance(block, GroupedLinear):
+            block.backward_dw()
+        else:
+            for i in range(num_gemms):
+                block[i].backward_dw()

    torch.cuda.synchronize()
    outputs = [out, inp_hidden_states.grad]
@@ -1515,6 +1706,8 @@ def _test_grouped_linear_accuracy(
@pytest.mark.parametrize("recipe", fp8_recipes + [None])
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean)
+@pytest.mark.parametrize("bias", all_boolean)
+@pytest.mark.parametrize("delay_wgrad_compute", all_boolean)
def test_grouped_linear_accuracy(
    dtype,
    num_gemms,
@@ -1523,6 +1716,8 @@ def test_grouped_linear_accuracy(
    recipe,
    fp8_model_params,
    fuse_wgrad_accumulation,
+    bias,
+    delay_wgrad_compute,
    parallel_mode=None,
):
    fp8 = recipe is not None
@@ -1542,18 +1737,19 @@ def test_grouped_linear_accuracy(
        num_gemms,
        config.hidden_size,
        4 * config.hidden_size,
-        bias=True,
+        bias=bias,
        params_dtype=dtype,
        parallel_mode=parallel_mode,
        device="cuda",
        fuse_wgrad_accumulation=fuse_wgrad_accumulation,
+        delay_wgrad_compute=delay_wgrad_compute,
    ).eval()
    sequential_linear = torch.nn.ModuleList(
        [
            Linear(
                config.hidden_size,
                4 * config.hidden_size,
-                bias=True,
+                bias=bias,
                params_dtype=dtype,
                parallel_mode=parallel_mode,
                device="cuda",
@@ -1567,17 +1763,34 @@ def test_grouped_linear_accuracy(
    with torch.no_grad():
        for i in range(num_gemms):
            sequential_linear[i].weight = Parameter(getattr(grouped_linear, f"weight{i}").clone())
-            sequential_linear[i].bias = Parameter(getattr(grouped_linear, f"bias{i}").clone())
+            if bias:
+                sequential_linear[i].bias = Parameter(getattr(grouped_linear, f"bias{i}").clone())
            if fuse_wgrad_accumulation:
                weight_i = getattr(grouped_linear, f"weight{i}")
                weight_i.main_grad = torch.rand_like(weight_i, dtype=torch.float32)
                sequential_linear[i].weight.main_grad = weight_i.main_grad.clone()

    outputs_ref = _test_grouped_linear_accuracy(
-        sequential_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation
+        sequential_linear,
+        num_gemms,
+        bs,
+        dtype,
+        config,
+        recipe,
+        fp8,
+        fuse_wgrad_accumulation,
+        delay_wgrad_compute,
    )
    outputs = _test_grouped_linear_accuracy(
-        grouped_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation
+        grouped_linear,
+        num_gemms,
+        bs,
+        dtype,
+        config,
+        recipe,
+        fp8,
+        fuse_wgrad_accumulation,
+        delay_wgrad_compute,
    )

    # Should be a bit-wise match
@@ -1596,6 +1809,8 @@ def test_grouped_linear_accuracy_single_gemm(recipe):
        recipe=recipe,
        fp8_model_params=True,
        fuse_wgrad_accumulation=True,
+        bias=True,
+        delay_wgrad_compute=False,
    )
......
@@ -9,6 +9,7 @@ from dataclasses import dataclass
from functools import reduce
from operator import mul as multiply_op
import queue

import torch

from .. import cpp_extensions as tex
@@ -216,3 +217,79 @@ class _ParameterInitMeta:
        """Safeguard reference to the parameter's parent module and initialization function."""
        if self.init_fn is None:
            self.init_fn = get_default_init_method()


class WeightGradStore:
    """
    A class to manage weight gradient storage and computation in Transformer modules.
    This class enables split backward propagation for better memory efficiency.
    """

    def __init__(self, delay_wgrad_compute=False, ub_bulk_wgrad=False):
        """
        Initialize the WeightGradStore.

        Args:
            delay_wgrad_compute (bool): Whether to delay weight gradient computation
            ub_bulk_wgrad (bool): Whether to enable bulk weight gradient computation
        """
        if delay_wgrad_compute:
            self.context = queue.Queue()
            assert (
                ub_bulk_wgrad is False
            ), "ub_bulk_wgrad is not supported when enabling delay_wgrad_compute"
            self.enabled = delay_wgrad_compute
        else:
            self.context = None
            self.enabled = False

    def delay_wgrad_compute(self):
        """
        Get the current split backward propagation status.

        Returns:
            bool: True if split backward is enabled, False otherwise
        """
        return self.enabled

    def enable_delay_wgrad_compute(self):
        """Enable split backward propagation."""
        self.enabled = True

    def disable_delay_wgrad_compute(self):
        """Disable split backward propagation."""
        self.enabled = False

    def put(self, tensor_list, func):
        """
        Store tensors and computation function for later execution.

        Args:
            tensor_list (list): List of tensors needed for computation
            func (callable): Function to be executed with the tensors
        """
        assert self.enabled is True, "delay_wgrad_compute is not enabled"
        self.context.put([tensor_list, func])

    def pop(self):
        """
        Execute the stored computation with the stored tensors.
        Raises an exception if the queue is empty.
        """
        assert self.enabled is True, "delay_wgrad_compute is not enabled"
        if self.context.qsize() > 0:
            tensor_list, func = self.context.get()
            return func(*tensor_list), tensor_list
        if torch.distributed.is_initialized():
            rank = torch.distributed.get_rank()
            raise RuntimeError(f"Pop empty queue. rank {rank}")
        raise RuntimeError("Pop empty queue. No distributed environment detected.")

    def assert_empty(self):
        """
        Assert that the queue is empty.
        Used for debugging and ensuring proper cleanup.
        """
        assert self.enabled is True, "delay_wgrad_compute is not enabled"
        rank = torch.distributed.get_rank()
        assert self.context.empty(), f"Queue is not empty. rank {rank}"
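In plain terms, the store is a FIFO of `[tensor_list, closure]` pairs: `backward()` enqueues the wgrad GEMM as a `functools.partial` plus its inputs, and `backward_dw()` pops and runs it. A self-contained toy illustration of that contract (the tensor names and the plain matmul stand-in are made up, not TE internals):

```python
import queue
import torch

store = queue.Queue()            # what WeightGradStore wraps when delay_wgrad_compute=True

inp = torch.randn(8, 16)         # stand-in for the saved input activations
grad_out = torch.randn(8, 32)    # stand-in for the incoming grad_output

def wgrad_gemm(x, dy):
    # Stand-in for the functools.partial(general_gemm, ...) closures the modules enqueue;
    # the real ones also carry dtype, workspace and accumulate arguments.
    return dy.T @ x              # [out_features, in_features] weight gradient

store.put([[inp, grad_out], wgrad_gemm])   # what put() records during backward()

tensor_list, func = store.get()            # what pop() does during backward_dw()
wgrad = func(*tensor_list)                 # -> torch.Size([32, 16])
```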
@@ -19,7 +19,7 @@ import torch.nn.functional as F
import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe
-from ._common import _ParameterInitMeta
+from ._common import _ParameterInitMeta, noop_cat
from ..fp8 import (
    MXFP8BlockScalingRecipeState,
    DelayedScalingRecipeState,
@@ -1140,6 +1140,27 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )

    def backward_dw(self):
        """
        Execute the delayed weight gradient computation.
        This method is called after the main backward pass to compute weight gradients.
        """
        if self.wgrad_store is None or not self.wgrad_store.delay_wgrad_compute():
            return
        with torch.cuda.nvtx.range(f"_{self.__class__.__name__}_wgrad"):
            (wgrad, grad_bias_, _, _), _ = self.wgrad_store.pop()
            if not self.fuse_wgrad_accumulation:
                unfused_weights = [getattr(self, name) for name in self.weight_names]
                weight_tensor = noop_cat(unfused_weights)
                if weight_tensor.grad is None:
                    weight_tensor.grad = wgrad.to(weight_tensor.dtype)
            if self.use_bias:
                bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names])
                if bias_tensor.grad is None:
                    bias_tensor.grad = grad_bias_.to(bias_tensor.dtype)
            del grad_bias_
            del wgrad

    def _validate_name(self):
        """
        Validate name passed to the module.
......
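A note on the interplay with `fuse_wgrad_accumulation`, following the pattern the unit tests in this PR use: when accumulation is fused, the delayed GEMM writes straight into `param.main_grad` (a buffer the caller owns), and `backward_dw()` leaves `param.grad` alone. A hedged sketch of that setup (shapes and buffer initialization are illustrative):

```python
import torch
import transformer_engine.pytorch as te

layer = te.Linear(512, 512, params_dtype=torch.bfloat16, device="cuda",
                  delay_wgrad_compute=True, fuse_wgrad_accumulation=True)

# The caller owns the FP32 accumulation buffer, as the tests in this PR set it up.
layer.weight.main_grad = torch.zeros_like(layer.weight, dtype=torch.float32)

out = layer(torch.randn(16, 512, dtype=torch.bfloat16, device="cuda", requires_grad=True))
out.sum().backward()   # dgrad only
layer.backward_dw()    # wgrad is accumulated into weight.main_grad, not weight.grad
```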
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
"""GroupedLinear API""" """GroupedLinear API"""
from typing import Union, Optional, Callable, Tuple, List from typing import Union, Optional, Callable, Tuple, List
import functools
import torch import torch
import transformer_engine_torch as tex import transformer_engine_torch as tex
...@@ -17,6 +18,7 @@ from .base import ( ...@@ -17,6 +18,7 @@ from .base import (
_2X_ACC_DGRAD, _2X_ACC_DGRAD,
_2X_ACC_WGRAD, _2X_ACC_WGRAD,
) )
from ._common import WeightGradStore
from ..fp8 import FP8GlobalStateManager from ..fp8 import FP8GlobalStateManager
from ..utils import ( from ..utils import (
divide, divide,
...@@ -64,6 +66,7 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -64,6 +66,7 @@ class _GroupedLinear(torch.autograd.Function):
is_first_microbatch: Union[bool, None], is_first_microbatch: Union[bool, None],
fp8: bool, fp8: bool,
fp8_calibration: bool, fp8_calibration: bool,
wgrad_store: WeightGradStore,
input_quantizers: List[Quantizer], input_quantizers: List[Quantizer],
weight_quantizers: List[Quantizer], weight_quantizers: List[Quantizer],
output_quantizers: List[Quantizer], output_quantizers: List[Quantizer],
...@@ -220,6 +223,7 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -220,6 +223,7 @@ class _GroupedLinear(torch.autograd.Function):
ctx.reduce_and_update_bwd_fp8_tensors ctx.reduce_and_update_bwd_fp8_tensors
or FP8GlobalStateManager.is_first_fp8_module() or FP8GlobalStateManager.is_first_fp8_module()
) )
ctx.wgrad_store = wgrad_store
# [*, in_features] -> [*, out_features] except first dimension changes for SP # [*, in_features] -> [*, out_features] except first dimension changes for SP
return out.view(-1, *inp.shape[1:-1], out.shape[-1]) return out.view(-1, *inp.shape[1:-1], out.shape[-1])
@@ -328,13 +332,10 @@ class _GroupedLinear(torch.autograd.Function):
                torch.empty(w.size(), dtype=ctx.activation_dtype, device=ctx.device)
                for w in weights
            ]
-            # WGRAD
-            _, grad_biases_, _ = general_grouped_gemm(
-                inputmats,
-                grad_output,
-                wgrad_list,
-                ctx.activation_dtype,
-                get_multi_stream_cublas_workspace(),
+            grouped_gemm_wgrad = functools.partial(
+                general_grouped_gemm,
+                out_dtype=ctx.activation_dtype,
+                workspaces=get_multi_stream_cublas_workspace(),
                layout="NT",
                grad=True,
                m_splits=ctx.m_splits,
@@ -343,13 +344,19 @@ class _GroupedLinear(torch.autograd.Function):
                use_split_accumulator=wgrad_gemm_use_split_accumulator,
                accumulate=accumulate_wgrad_into_param_main_grad,
            )
-            for i in range(ctx.num_gemms):
-                if grad_biases[i] is None:
-                    grad_biases[i] = grad_biases_[i]
-            del grad_biases_
+            # WGRAD
+            if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
+                ctx.wgrad_store.put([inputmats, grad_output, wgrad_list], grouped_gemm_wgrad)
+            else:
+                _, grad_biases_, _ = grouped_gemm_wgrad(inputmats, grad_output, wgrad_list)
+                for i in range(ctx.num_gemms):
+                    if grad_biases[i] is None:
+                        grad_biases[i] = grad_biases_[i]
+                del grad_biases_

            # Deallocate input tensor
            clear_tensor_data(*inputmats)

        def handle_custom_ddp_from_mcore(weight, wgrad):
            if ctx.weights_requires_grad:
@@ -385,7 +392,14 @@ class _GroupedLinear(torch.autograd.Function):
        else:
            wgrad_list = [None] * ctx.num_gemms

-        if not ctx.use_bias:
+        if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
+            wgrad_list = [None] * ctx.num_gemms
+        if not ctx.use_bias or (
+            ctx.wgrad_store is not None
+            and ctx.wgrad_store.delay_wgrad_compute()
+            and not ctx.fp8
+        ):
            grad_biases = [None] * ctx.num_gemms
if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing(): if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing():
...@@ -408,6 +422,7 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -408,6 +422,7 @@ class _GroupedLinear(torch.autograd.Function):
None, None,
None, None,
None, None,
None,
*wgrad_list, *wgrad_list,
*grad_biases, *grad_biases,
) )
...@@ -456,6 +471,8 @@ class GroupedLinear(TransformerEngineBaseModule): ...@@ -456,6 +471,8 @@ class GroupedLinear(TransformerEngineBaseModule):
it controls the type used to allocate the initial parameters. Useful when it controls the type used to allocate the initial parameters. Useful when
the model is trained with lower precision and the original FP32 parameters the model is trained with lower precision and the original FP32 parameters
would not fit in GPU memory. would not fit in GPU memory.
delay_wgrad_compute : bool, default = `False`
Whether to delay weight gradient computation
Note: GroupedLinear doesn't really handle the TP communications inside. The `tp_size` and Note: GroupedLinear doesn't really handle the TP communications inside. The `tp_size` and
`parallel_mode` are used to determine the shapes of weights and biases. `parallel_mode` are used to determine the shapes of weights and biases.
...@@ -482,6 +499,7 @@ class GroupedLinear(TransformerEngineBaseModule): ...@@ -482,6 +499,7 @@ class GroupedLinear(TransformerEngineBaseModule):
ub_overlap_rs: bool = False, ub_overlap_rs: bool = False,
ub_overlap_ag: bool = False, ub_overlap_ag: bool = False,
ub_name: Optional[str] = None, ub_name: Optional[str] = None,
delay_wgrad_compute: bool = False,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -502,6 +520,8 @@ class GroupedLinear(TransformerEngineBaseModule): ...@@ -502,6 +520,8 @@ class GroupedLinear(TransformerEngineBaseModule):
self.get_rng_state_tracker = get_rng_state_tracker self.get_rng_state_tracker = get_rng_state_tracker
self.rng_tracker_name = rng_tracker_name self.rng_tracker_name = rng_tracker_name
self.wgrad_store = WeightGradStore(delay_wgrad_compute)
self._offsets = {"input": 0, "weight": 1, "output": 2, "grad_output": 0, "grad_input": 1} self._offsets = {"input": 0, "weight": 1, "output": 2, "grad_output": 0, "grad_input": 1}
self._num_fp8_tensors_per_gemm = { self._num_fp8_tensors_per_gemm = {
"fwd": 3, "fwd": 3,
...@@ -707,6 +727,7 @@ class GroupedLinear(TransformerEngineBaseModule): ...@@ -707,6 +727,7 @@ class GroupedLinear(TransformerEngineBaseModule):
is_first_microbatch, is_first_microbatch,
self.fp8, self.fp8,
self.fp8_calibration, self.fp8_calibration,
self.wgrad_store,
input_quantizers, input_quantizers,
weight_quantizers, weight_quantizers,
output_quantizers, output_quantizers,
...@@ -727,6 +748,30 @@ class GroupedLinear(TransformerEngineBaseModule): ...@@ -727,6 +748,30 @@ class GroupedLinear(TransformerEngineBaseModule):
return out, [cast_if_needed(b, self.activation_dtype) for b in bias_tensors] return out, [cast_if_needed(b, self.activation_dtype) for b in bias_tensors]
return out return out
def backward_dw(self):
"""
Execute the delayed weight gradient computation.
This method is called after the main backward pass to compute weight gradients.
"""
if self.wgrad_store is None or not self.wgrad_store.delay_wgrad_compute():
return
with torch.cuda.nvtx.range("_GroupedLinear_wgrad"):
(_, grad_biases_, _), tensor_list = self.wgrad_store.pop()
wgrad_list = tensor_list[2]
if not self.fuse_wgrad_accumulation:
for i in range(self.num_gemms):
weight_param = getattr(self, f"weight{i}")
if weight_param.grad is None:
weight_param.grad = wgrad_list[i].to(weight_param.dtype)
if self.use_bias:
for i in range(self.num_gemms):
bias_param = getattr(self, f"bias{i}")
if bias_param.grad is None:
bias_param.grad = grad_biases_[i].to(bias_param.dtype)
del grad_biases_
del wgrad_list
del tensor_list
def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None: def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None:
"""Customize quantizers based on current scaling recipe + linear.""" """Customize quantizers based on current scaling recipe + linear."""
assert ( assert (
......
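The same flag is available on `GroupedLinear`; a hedged usage sketch (the `(inp, m_splits)` call convention is assumed from the existing GroupedLinear tests, and the shapes are illustrative):

```python
import torch
import transformer_engine.pytorch as te

num_gemms = 4
grouped = te.GroupedLinear(num_gemms, 256, 256, params_dtype=torch.bfloat16,
                           device="cuda", delay_wgrad_compute=True)

m_splits = [32, 32, 32, 32]    # tokens routed to each of the 4 GEMMs
inp = torch.randn(sum(m_splits), 256, dtype=torch.bfloat16, device="cuda",
                  requires_grad=True)

grouped(inp, m_splits).sum().backward()   # dgrad; the grouped wgrad GEMM is queued
grouped.backward_dw()                     # runs it and fills the per-GEMM weight/bias grads
```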
...@@ -9,6 +9,7 @@ from typing import Callable, Dict, Optional, Tuple, Union ...@@ -9,6 +9,7 @@ from typing import Callable, Dict, Optional, Tuple, Union
from functools import reduce from functools import reduce
from operator import mul as multiply_op from operator import mul as multiply_op
import functools
import torch import torch
from torch.nn import init from torch.nn import init
...@@ -52,7 +53,7 @@ from ..distributed import ( ...@@ -52,7 +53,7 @@ from ..distributed import (
from ..constants import GemmParallelModes, dist_group_type from ..constants import GemmParallelModes, dist_group_type
from ..jit import no_torch_dynamo from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing from ..graph import is_graph_capturing
-from ._common import apply_normalization, noop_cat, _fix_gathered_fp8_transpose
+from ._common import apply_normalization, noop_cat, _fix_gathered_fp8_transpose, WeightGradStore
from ..tensor.quantized_tensor import ( from ..tensor.quantized_tensor import (
QuantizedTensor, QuantizedTensor,
Quantizer, Quantizer,
...@@ -91,6 +92,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -91,6 +92,7 @@ class _LayerNormLinear(torch.autograd.Function):
is_first_microbatch: Union[bool, None], is_first_microbatch: Union[bool, None],
fp8: bool, fp8: bool,
fp8_calibration: bool, fp8_calibration: bool,
wgrad_store: WeightGradStore,
fuse_wgrad_accumulation: bool, fuse_wgrad_accumulation: bool,
input_quantizer: Optional[Quantizer], input_quantizer: Optional[Quantizer],
weight_quantizer: Optional[Quantizer], weight_quantizer: Optional[Quantizer],
...@@ -438,6 +440,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -438,6 +440,7 @@ class _LayerNormLinear(torch.autograd.Function):
ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module()
if in_fp8_activation_recompute_phase(): if in_fp8_activation_recompute_phase():
FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module
ctx.wgrad_store = wgrad_store
ctx.debug = debug ctx.debug = debug
# Row Parallel Linear # Row Parallel Linear
@@ -752,15 +755,14 @@ class _LayerNormLinear(torch.autograd.Function):
            # wgrad GEMM
            # Note: Fuse with bgrad computation if needed
            nvtx_range_push(f"{nvtx_label}.wgrad_gemm")
-            wgrad, grad_bias_, *_, rs_out = general_gemm(
-                ln_out_total,
-                grad_output,
-                get_workspace(),
-                layout="NT",
-                grad=True,
+            general_gemm_wgrad = functools.partial(
+                general_gemm,
                out_dtype=(
                    main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype
                ),
+                workspace=get_workspace(),
+                layout="NT",
+                grad=True,
                bias=(bias if (grad_bias is None and not ctx.fp8) else None),
                out=main_grad if ctx.fuse_wgrad_accumulation else None,
                use_split_accumulator=use_split_accumulator,
@@ -771,6 +773,20 @@ class _LayerNormLinear(torch.autograd.Function):
                extra_output=rs_out,
                bulk_overlap=ctx.ub_bulk_wgrad,
            )
+            if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
+                ctx.wgrad_store.put([ln_out_total, grad_output], general_gemm_wgrad)
+            else:
+                wgrad, grad_bias_, _, rs_out = general_gemm_wgrad(ln_out_total, grad_output)
+                if grad_bias is None:
+                    grad_bias = grad_bias_
+                del grad_bias_
+
+                # Deallocate input tensor
+                if not ctx.return_layernorm_output:
+                    # TODO (pgadzinski) - deallocate transpose only  # pylint: disable=fixme
+                    clear_tensor_data(ln_out_total)
+
            nvtx_range_pop(f"{nvtx_label}.wgrad_gemm")

        if ctx.ub_bulk_wgrad:
@@ -779,16 +795,11 @@ class _LayerNormLinear(torch.autograd.Function):
            else:
                dgrad = ub_obj_wgrad.get_buffer(None, local_chunk=True)

-        if grad_bias is None:
-            grad_bias = grad_bias_
-        del grad_bias_
-
-        # Deallocate input tensor
-        if not ctx.return_layernorm_output:
-            # TODO (pgadzinski) - deallocate transpose only  # pylint: disable=fixme
-            clear_tensor_data(ln_out_total)
-
-        # Make sure all tensor-parallel communication is finished
+        # Don't return grad bias if not needed
+        if not ctx.use_bias:
+            grad_bias = None
+
+        # Synchronize tensor parallel communication
        if ln_out_total_work is not None:
            ln_out_total_work.wait()
            ln_out_total_work = None
...@@ -870,6 +881,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -870,6 +881,7 @@ class _LayerNormLinear(torch.autograd.Function):
None, # is_first_microbatch None, # is_first_microbatch
None, # fp8 None, # fp8
None, # fp8_calibration None, # fp8_calibration
None, # wgrad_store
None, # fuse_wgrad_accumulation None, # fuse_wgrad_accumulation
None, # input_quantizer None, # input_quantizer
None, # weight_quantizer None, # weight_quantizer
...@@ -992,6 +1004,10 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -992,6 +1004,10 @@ class LayerNormLinear(TransformerEngineBaseModule):
it controls the type used to allocate the initial parameters. Useful when it controls the type used to allocate the initial parameters. Useful when
the model is trained with lower precision and the original FP32 parameters the model is trained with lower precision and the original FP32 parameters
would not fit in GPU memory. would not fit in GPU memory.
delay_wgrad_compute : bool, default = `False`
Whether or not to delay weight gradient computation. If set to `True`,
it's the user's responsibility to call `module.backward_dw` to compute
weight gradients.
symmetric_ar_type : {None, 'multimem_all_reduce', 'two_shot', 'one_shot'}, default = None symmetric_ar_type : {None, 'multimem_all_reduce', 'two_shot', 'one_shot'}, default = None
Type of symmetric memory all-reduce to use during the forward pass. Type of symmetric memory all-reduce to use during the forward pass.
This can help in latency bound communication situations. This can help in latency bound communication situations.
...@@ -1026,6 +1042,7 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1026,6 +1042,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
ub_bulk_wgrad: bool = False, ub_bulk_wgrad: bool = False,
ub_bulk_dgrad: bool = False, ub_bulk_dgrad: bool = False,
ub_name: Optional[str] = None, ub_name: Optional[str] = None,
delay_wgrad_compute: bool = False,
symmetric_ar_type: Optional[str] = None, symmetric_ar_type: Optional[str] = None,
name: str = None, name: str = None,
) -> None: ) -> None:
...@@ -1045,6 +1062,7 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1045,6 +1062,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.zero_centered_gamma = zero_centered_gamma self.zero_centered_gamma = zero_centered_gamma
self.symmetric_ar_type = symmetric_ar_type self.symmetric_ar_type = symmetric_ar_type
self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad)
self.name = name self.name = name
if TEDebugState.debug_enabled: if TEDebugState.debug_enabled:
self._turn_off_unsupported_features_in_debug() # turn off userbuffers self._turn_off_unsupported_features_in_debug() # turn off userbuffers
...@@ -1423,6 +1441,7 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1423,6 +1441,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
is_first_microbatch, is_first_microbatch,
self.fp8, self.fp8,
self.fp8_calibration, self.fp8_calibration,
self.wgrad_store,
self.fuse_wgrad_accumulation, self.fuse_wgrad_accumulation,
input_quantizer, input_quantizer,
weight_quantizer, weight_quantizer,
......
...@@ -8,6 +8,7 @@ import warnings ...@@ -8,6 +8,7 @@ import warnings
from typing import Callable, Optional, Tuple, Union from typing import Callable, Optional, Tuple, Union
from functools import reduce from functools import reduce
from operator import mul as multiply_op from operator import mul as multiply_op
import functools
import torch import torch
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
...@@ -65,7 +66,7 @@ from ..tensor.float8_tensor import ( ...@@ -65,7 +66,7 @@ from ..tensor.float8_tensor import (
) )
from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
-from ._common import apply_normalization, _fix_gathered_fp8_transpose
+from ._common import apply_normalization, _fix_gathered_fp8_transpose, WeightGradStore
from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ..tensor.quantized_tensor import ( from ..tensor.quantized_tensor import (
QuantizedTensor, QuantizedTensor,
...@@ -155,6 +156,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -155,6 +156,7 @@ class _LayerNormMLP(torch.autograd.Function):
is_first_microbatch: Union[bool, None], is_first_microbatch: Union[bool, None],
fp8: bool, fp8: bool,
fp8_calibration: bool, fp8_calibration: bool,
wgrad_store: WeightGradStore,
fuse_wgrad_accumulation: bool, fuse_wgrad_accumulation: bool,
fc1_input_quantizer: Optional[Quantizer], fc1_input_quantizer: Optional[Quantizer],
fc1_weight_quantizer: Optional[Quantizer], fc1_weight_quantizer: Optional[Quantizer],
...@@ -587,6 +589,8 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -587,6 +589,8 @@ class _LayerNormMLP(torch.autograd.Function):
if in_fp8_activation_recompute_phase(): if in_fp8_activation_recompute_phase():
FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module
ctx.wgrad_store = wgrad_store
# Row Parallel Linear # Row Parallel Linear
if ub_overlap_rs: if ub_overlap_rs:
fc2_out = rs_out fc2_out = rs_out
@@ -820,15 +824,14 @@ class _LayerNormMLP(torch.autograd.Function):
            grad_arg = True
            if ctx.fp8 and ctx.fp8_recipe.float8_block_scaling():
                grad_arg = False
-            fc2_wgrad, fc2_bias_grad_, *_ = general_gemm(
-                act_out,
-                grad_output,
-                get_workspace(),
+            general_gemm_fc2_wgrad = functools.partial(
+                general_gemm,
                out_dtype=(
                    origin_fc2_weight.main_grad.dtype
                    if ctx.fuse_wgrad_accumulation
                    else ctx.activation_dtype
                ),
+                workspace=get_workspace(),
                quantization_params=ctx.fc2_grad_weight_quantizer,  # wgrad in high precision
                layout="NT",
                grad=grad_arg,
@@ -837,13 +840,28 @@ class _LayerNormMLP(torch.autograd.Function):
                use_split_accumulator=_2X_ACC_WGRAD,
                out=origin_fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
            )
-            if fc2_bias_grad is None:
-                if ctx.fp8 and ctx.fp8_recipe.float8_block_scaling() and fc2_bias is not None:
-                    # BGRAD not fused with GEMM for float8 blockwise gemm.
-                    fc2_bias_grad_ = act_out.view(-1, act_out.shape[-1]).sum(dim=0)
-                fc2_bias_grad = fc2_bias_grad_
-            del fc2_bias_grad_
-            clear_tensor_data(act_out)
+            if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
+                ctx.wgrad_store.put([act_out, grad_output], general_gemm_fc2_wgrad)
+                fc2_wgrad = None
+            else:
+                fc2_wgrad, fc2_bias_grad_, *_ = general_gemm_fc2_wgrad(
+                    act_out,
+                    grad_output,
+                )
+                if fc2_bias_grad is None:
+                    if (
+                        ctx.fp8
+                        and ctx.fp8_recipe.float8_block_scaling()
+                        and fc2_bias is not None
+                    ):
+                        # BGRAD not fused with GEMM for float8 blockwise gemm.
+                        fc2_bias_grad_ = act_out.view(-1, act_out.shape[-1]).sum(dim=0)
+                    fc2_bias_grad = fc2_bias_grad_
+                del fc2_bias_grad_
+            if ctx.wgrad_store is not None and not ctx.wgrad_store.delay_wgrad_compute():
+                clear_tensor_data(act_out)
# bias computation # bias computation
fc1_bias_grad = None fc1_bias_grad = None
@@ -1017,15 +1035,14 @@ class _LayerNormMLP(torch.autograd.Function):
            )

            # wgrad GEMM
-            fc1_wgrad_outputs = general_gemm(
-                ln_out_total,
-                dact,
-                get_workspace(),
+            general_gemm_fc1_wgrad = functools.partial(
+                general_gemm,
                out_dtype=(
                    origin_fc1_weight.main_grad.dtype
                    if ctx.fuse_wgrad_accumulation
                    else ctx.activation_dtype
                ),
+                workspace=get_workspace(),
                layout="NT",
                quantization_params=ctx.fc1_grad_weight_quantizer,
                grad=fuse_gemm_and_bias_fc1_wgrad,
@@ -1037,13 +1054,23 @@ class _LayerNormMLP(torch.autograd.Function):
                extra_output=fc1_dgrad_rs_out,
                bulk_overlap=ctx.ub_bulk_wgrad,
            )
+            if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
+                ctx.wgrad_store.put([ln_out_total, dact], general_gemm_fc1_wgrad)
+                fc1_wgrad = None
+                if fuse_gemm_and_bias_fc1_wgrad:
+                    fc1_bias_grad = None
+            else:
+                fc1_wgrad_outputs = general_gemm_fc1_wgrad(
+                    ln_out_total,
+                    dact,
+                )

                clear_tensor_data(ln_out_total, dact)

                if fuse_gemm_and_bias_fc1_wgrad:
                    fc1_wgrad, fc1_bias_grad, *_ = fc1_wgrad_outputs
                else:
                    fc1_wgrad, *_ = fc1_wgrad_outputs

        if ctx.ub_bulk_wgrad:
...@@ -1160,6 +1187,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -1160,6 +1187,7 @@ class _LayerNormMLP(torch.autograd.Function):
None, # is_first_microbatch None, # is_first_microbatch
None, # fp8 None, # fp8
None, # fp8_calibration None, # fp8_calibration
None, # wgrad_store
None, # fuse_wgrad_accumulation None, # fuse_wgrad_accumulation
None, # fc1_input_quantizer, None, # fc1_input_quantizer,
None, # fc1_weight_quantizer, None, # fc1_weight_quantizer,
...@@ -1296,6 +1324,10 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1296,6 +1324,10 @@ class LayerNormMLP(TransformerEngineBaseModule):
batch size per training step. Needed for JIT Warmup, a technique where jit batch size per training step. Needed for JIT Warmup, a technique where jit
fused functions are warmed up before training to ensure same kernels are fused functions are warmed up before training to ensure same kernels are
used for forward propogation and activation recompute phase. used for forward propogation and activation recompute phase.
delay_wgrad_compute : bool, default = `False`
Whether or not to delay weight gradient computation. If set to `True`,
it's the user's responsibility to call `module.backward_dw` to compute
weight gradients.
symmetric_ar_type : {None, 'multimem_all_reduce', 'two_shot', 'one_shot'}, default = None symmetric_ar_type : {None, 'multimem_all_reduce', 'two_shot', 'one_shot'}, default = None
Type of symmetric memory all-reduce to use during the forward pass. Type of symmetric memory all-reduce to use during the forward pass.
This can help in latency bound communication situations. This can help in latency bound communication situations.
...@@ -1333,6 +1365,7 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1333,6 +1365,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
ub_overlap_rs_dgrad: bool = False, ub_overlap_rs_dgrad: bool = False,
ub_bulk_dgrad: bool = False, ub_bulk_dgrad: bool = False,
ub_bulk_wgrad: bool = False, ub_bulk_wgrad: bool = False,
delay_wgrad_compute: bool = False,
symmetric_ar_type: Optional[str] = None, symmetric_ar_type: Optional[str] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -1365,6 +1398,8 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1365,6 +1398,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
if TEDebugState.debug_enabled: if TEDebugState.debug_enabled:
self._turn_off_unsupported_features_in_debug() # turn off userbuffers self._turn_off_unsupported_features_in_debug() # turn off userbuffers
self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad)
if tp_group is None: if tp_group is None:
self.tp_size = tp_size self.tp_size = tp_size
if tp_size == 1: if tp_size == 1:
...@@ -1636,6 +1671,7 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1636,6 +1671,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
is_first_microbatch, is_first_microbatch,
self.fp8, self.fp8,
self.fp8_calibration, self.fp8_calibration,
self.wgrad_store,
self.fuse_wgrad_accumulation, self.fuse_wgrad_accumulation,
fc1_input_quantizer, fc1_input_quantizer,
fc1_weight_quantizer, fc1_weight_quantizer,
...@@ -1835,3 +1871,41 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1835,3 +1871,41 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.quantizers["scaling_bwd"][ self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT1 tex.FP8BwdTensors.GRAD_OUTPUT1
].amax_reduction_group = self.tp_group ].amax_reduction_group = self.tp_group
def backward_dw(self):
"""
Execute the delayed weight gradient computation.
This method is called after the main backward pass to compute weight gradients.
"""
if self.wgrad_store is None or not self.wgrad_store.delay_wgrad_compute():
return
with torch.cuda.nvtx.range("_LayerNormMLP_wgrad"):
(fc2_wgrad, fc2_bias_grad_, *_), tensor_list_fc2 = self.wgrad_store.pop()
if self.use_bias and self.fc1_bias.grad is None:
(fc1_wgrad, fc1_bias_grad, *_), _ = self.wgrad_store.pop()
else:
(fc1_wgrad, *_), _ = self.wgrad_store.pop()
fc1_bias_grad = None
if self.use_bias:
if self.fc2_bias.grad is None:
if (
self.fp8
and FP8GlobalStateManager.get_fp8_recipe().float8_block_scaling()
and self.apply_bias
and not self.gemm_bias_unfused_add
):
act_out = tensor_list_fc2[0]
# BGRAD not fused with GEMM for float8 blockwise gemm.
fc2_bias_grad_ = act_out.view(-1, act_out.shape[-1]).sum(dim=0)
self.fc2_bias.grad = fc2_bias_grad_.to(self.fc2_bias.dtype)
if self.fc1_bias.grad is None:
self.fc1_bias.grad = fc1_bias_grad.to(self.fc1_bias.dtype)
if not self.fuse_wgrad_accumulation:
if self.fc2_weight.grad is None:
self.fc2_weight.grad = fc2_wgrad.to(self.fc2_weight.dtype)
if self.fc1_weight.grad is None:
self.fc1_weight.grad = fc1_wgrad.to(self.fc1_weight.dtype)
del fc2_bias_grad_
del fc2_wgrad
del fc1_wgrad
del fc1_bias_grad
...@@ -7,6 +7,7 @@ from typing import Callable, Dict, Optional, Tuple, Union ...@@ -7,6 +7,7 @@ from typing import Callable, Dict, Optional, Tuple, Union
from functools import reduce from functools import reduce
from operator import mul as multiply_op from operator import mul as multiply_op
import functools
import torch import torch
import transformer_engine_torch as tex import transformer_engine_torch as tex
...@@ -22,7 +23,7 @@ from .base import ( ...@@ -22,7 +23,7 @@ from .base import (
_2X_ACC_DGRAD, _2X_ACC_DGRAD,
_2X_ACC_WGRAD, _2X_ACC_WGRAD,
) )
-from ._common import noop_cat, _fix_gathered_fp8_transpose
+from ._common import noop_cat, _fix_gathered_fp8_transpose, WeightGradStore
from ..fp8 import FP8GlobalStateManager from ..fp8 import FP8GlobalStateManager
from ..utils import ( from ..utils import (
cast_if_needed, cast_if_needed,
...@@ -85,6 +86,7 @@ class _Linear(torch.autograd.Function): ...@@ -85,6 +86,7 @@ class _Linear(torch.autograd.Function):
is_first_microbatch: Union[bool, None], is_first_microbatch: Union[bool, None],
fp8: bool, fp8: bool,
fp8_calibration: bool, fp8_calibration: bool,
wgrad_store: WeightGradStore,
input_quantizer: Optional[Quantizer], input_quantizer: Optional[Quantizer],
weight_quantizer: Optional[Quantizer], weight_quantizer: Optional[Quantizer],
output_quantizer: Optional[Quantizer], output_quantizer: Optional[Quantizer],
...@@ -380,6 +382,7 @@ class _Linear(torch.autograd.Function): ...@@ -380,6 +382,7 @@ class _Linear(torch.autograd.Function):
ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module()
if in_fp8_activation_recompute_phase(): if in_fp8_activation_recompute_phase():
FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module
ctx.wgrad_store = wgrad_store
# Row Parallel Linear # Row Parallel Linear
if ub_overlap_rs_fprop: if ub_overlap_rs_fprop:
@@ -675,15 +678,14 @@ class _Linear(torch.autograd.Function):
            # wgrad GEMM
            # Note: Fuse with bgrad computation if needed
            nvtx_range_push(f"{nvtx_label}.wgrad_gemm")
-            wgrad, grad_bias_, _, rs_out = general_gemm(
-                inputmat_total,
-                grad_output,
-                get_workspace(),
-                layout="NT",
-                grad=True,
+            general_gemm_wgrad = functools.partial(
+                general_gemm,
                out_dtype=(
                    main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype
                ),
+                workspace=get_workspace(),
+                layout="NT",
+                grad=True,
                bias=(bias if (grad_bias is None and not ctx.fp8) else None),
                out=main_grad if ctx.fuse_wgrad_accumulation else None,
                use_split_accumulator=use_split_accumulator,
@@ -694,6 +696,19 @@ class _Linear(torch.autograd.Function):
                extra_output=rs_out,
                bulk_overlap=ctx.ub_bulk_wgrad,
            )
+            if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
+                ctx.wgrad_store.put([inputmat_total, grad_output], general_gemm_wgrad)
+            else:
+                wgrad, grad_bias_, _, rs_out = general_gemm_wgrad(inputmat_total, grad_output)
+                if grad_bias is None:
+                    grad_bias = grad_bias_
+                del grad_bias_
+
+                # Deallocate input tensor
+                if ctx.owns_input:
+                    clear_tensor_data(inputmat_total)
+
            nvtx_range_pop(f"{nvtx_label}.wgrad_gemm")

        if ctx.ub_bulk_wgrad:
@@ -702,14 +717,6 @@ class _Linear(torch.autograd.Function):
            else:
                dgrad = ub_obj_wgrad.get_buffer(ctx.grad_input_quantizer, local_chunk=True)

-        if grad_bias is None:
-            grad_bias = grad_bias_
-        del grad_bias_
-
-        # Deallocate input tensor
-        if ctx.owns_input:
-            clear_tensor_data(inputmat_total)
-
# Don't return grad bias if not needed # Don't return grad bias if not needed
if not ctx.use_bias: if not ctx.use_bias:
grad_bias = None grad_bias = None
...@@ -761,6 +768,7 @@ class _Linear(torch.autograd.Function): ...@@ -761,6 +768,7 @@ class _Linear(torch.autograd.Function):
None, # is_first_microbatch None, # is_first_microbatch
None, # fp8 None, # fp8
None, # fp8_calibration None, # fp8_calibration
None, # wgrad_store
None, # input_quantizer None, # input_quantizer
None, # weight_quantizer None, # weight_quantizer
None, # output_quantizer None, # output_quantizer
...@@ -861,6 +869,10 @@ class Linear(TransformerEngineBaseModule): ...@@ -861,6 +869,10 @@ class Linear(TransformerEngineBaseModule):
it controls the type used to allocate the initial parameters. Useful when it controls the type used to allocate the initial parameters. Useful when
the model is trained with lower precision and the original FP32 parameters the model is trained with lower precision and the original FP32 parameters
would not fit in GPU memory. would not fit in GPU memory.
delay_wgrad_compute : bool, default = `False`
Whether or not to delay weight gradient computation. If set to `True`,
it's the user's responsibility to call `module.backward_dw` to compute
weight gradients.
symmetric_ar_type : {None, 'multimem_all_reduce', 'two_shot', 'one_shot'}, default = None symmetric_ar_type : {None, 'multimem_all_reduce', 'two_shot', 'one_shot'}, default = None
Type of symmetric memory all-reduce to use during the forward pass. Type of symmetric memory all-reduce to use during the forward pass.
This can help in latency bound communication situations. This can help in latency bound communication situations.
...@@ -891,6 +903,7 @@ class Linear(TransformerEngineBaseModule): ...@@ -891,6 +903,7 @@ class Linear(TransformerEngineBaseModule):
ub_bulk_dgrad: bool = False, ub_bulk_dgrad: bool = False,
ub_bulk_wgrad: bool = False, ub_bulk_wgrad: bool = False,
ub_name: Optional[str] = None, ub_name: Optional[str] = None,
delay_wgrad_compute: bool = False,
symmetric_ar_type: Optional[str] = None, symmetric_ar_type: Optional[str] = None,
name: Optional[str] = None, name: Optional[str] = None,
) -> None: ) -> None:
...@@ -911,6 +924,8 @@ class Linear(TransformerEngineBaseModule): ...@@ -911,6 +924,8 @@ class Linear(TransformerEngineBaseModule):
if TEDebugState.debug_enabled: if TEDebugState.debug_enabled:
self._turn_off_unsupported_features_in_debug() # turn off userbuffers self._turn_off_unsupported_features_in_debug() # turn off userbuffers
self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad)
if device == "meta": if device == "meta":
assert parameters_split is None, "Cannot split module parameters on 'meta' device." assert parameters_split is None, "Cannot split module parameters on 'meta' device."
if tp_group is None: if tp_group is None:
...@@ -1241,6 +1256,7 @@ class Linear(TransformerEngineBaseModule): ...@@ -1241,6 +1256,7 @@ class Linear(TransformerEngineBaseModule):
is_first_microbatch, is_first_microbatch,
self.fp8, self.fp8,
self.fp8_calibration, self.fp8_calibration,
self.wgrad_store,
input_quantizer, input_quantizer,
weight_quantizer, weight_quantizer,
output_quantizer, output_quantizer,
......
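Finally, the point of the split named in the title: with wgrad deferred, the backward-pass exposure of an MoE all-to-all can be hidden behind the stored GEMMs. A hedged sketch of that overlap (process-group setup, buffers and the surrounding schedule are illustrative, not part of this PR):

```python
import torch
import torch.distributed as dist
import transformer_engine.pytorch as te

proj = te.Linear(1024, 1024, params_dtype=torch.bfloat16, device="cuda",
                 delay_wgrad_compute=True)

def backward_step(loss, a2a_in, a2a_out):
    loss.backward()                       # dgrad only; wgrad GEMMs are queued
    work = dist.all_to_all_single(a2a_out, a2a_in, async_op=True)
    proj.backward_dw()                    # wgrad GEMMs run while the all-to-all is in flight
    work.wait()
```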