Commit 460b006c authored by yuguo

[DCU] support delay_wgrad_compute in batchgemm

parent 196a213f
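
A minimal usage sketch of the delayed-wgrad path added here (constructor arguments other than delay_wgrad_compute are illustrative, not the exact signature; the flag and backward_dw() appear in the diff below):

```python
# Hedged sketch: defer the weight-gradient GEMMs and flush them explicitly later.
import torch
from transformer_engine.pytorch import BatchedLinear  # import path assumed for this fork

layer = BatchedLinear(
    num_gemms=2,                 # sizes/arguments other than delay_wgrad_compute are illustrative
    in_features=1024,
    out_features=1024,
    device="cuda",
    delay_wgrad_compute=True,    # new flag introduced by this commit
)

inp = torch.randn(8, 1024, device="cuda", requires_grad=True)
out = layer(inp)
out.sum().backward()             # dgrad runs now; the wgrad GEMM is queued in the WeightGradStore
layer.backward_dw()              # pop the store and run the deferred weight (and bias) gradients
```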
@@ -171,7 +171,7 @@ def reset_global_fp8_state():
FP8GlobalStateManager.reset()
def _test_batched_linear_accuracy(
block, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation
block, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation, delay_wgrad_compute, batch_num
):
reset_rng_states()
if fp8:
@@ -202,9 +202,31 @@ def _test_batched_linear_accuracy(
)
loss = out.sum()
loss.backward()
if delay_wgrad_compute:
if isinstance(block, BatchedLinear):
block.backward_dw()
else:
for i in range(num_gemms):
block[i].backward_dw()
torch.cuda.synchronize()
outputs = [out, inp_hidden_states.grad]
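# Each BatchedLinear parameter stacks batch_num sub-weights along dim 0, so its (main_)grad
# is split into batch_num equal slices to line up with the flattened sequential reference.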
for p in block.parameters():
if p.requires_grad:
if isinstance(block, BatchedLinear):
if getattr(p, "main_grad", None) is not None:
for j in range(batch_num):
outputs.append(p.main_grad[p.main_grad.shape[0] // batch_num * j : p.main_grad.shape[0] // batch_num * (j + 1)])
assert p.grad is None # grad should be None if fuse_wgrad_accumulation is True
else:
for j in range(batch_num):
outputs.append(p.grad[p.grad.shape[0] // batch_num * j : p.grad.shape[0] // batch_num * (j + 1)])
else:
if getattr(p, "main_grad", None) is not None:
outputs.append(p.main_grad)
assert p.grad is None # grad should be None if fuse_wgrad_accumulation is True
else:
outputs.append(p.grad)
return outputs
@pytest.mark.parametrize("dtype", param_types)
@@ -215,6 +237,7 @@ def _test_batched_linear_accuracy(
@pytest.mark.parametrize("recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean)
@pytest.mark.parametrize("delay_wgrad_compute", all_boolean)
def test_batched_linear_accuracy(
dtype,
num_gemms,
@@ -224,6 +247,7 @@ def test_batched_linear_accuracy(
recipe,
fp8_model_params,
fuse_wgrad_accumulation,
delay_wgrad_compute,
parallel_mode=None,
):
batch_num = int(os.getenv("NVTE_MOE_BATCHCOUNT", "2"))
@@ -250,6 +274,7 @@ def test_batched_linear_accuracy(
parallel_mode=parallel_mode,
device="cuda",
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
delay_wgrad_compute=delay_wgrad_compute,
).eval()
sequential_linear = torch.nn.ModuleList(
[
@@ -281,10 +306,10 @@ def test_batched_linear_accuracy(
sequential_linear[i * batch_num + j].weight.main_grad = weight_i.main_grad[weight_i.main_grad.shape[0] // batch_num * j : weight_i.main_grad.shape[0] // batch_num * (j + 1)].clone()
outputs_ref = _test_batched_linear_accuracy(
sequential_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation
sequential_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation, delay_wgrad_compute, batch_num
)
outputs = _test_batched_linear_accuracy(
batched_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation
batched_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation, delay_wgrad_compute, batch_num
)
# Should be a bit-wise match
@@ -292,4 +317,4 @@ def test_batched_linear_accuracy(
torch.testing.assert_close(o, o_ref, rtol=6e-3, atol=6e-3)
if __name__ == "__main__":
test_batched_linear_accuracy(torch.float32, 2, 1, "126m", False, recipe.Float8CurrentScaling(), True, True)
test_batched_linear_accuracy(torch.float32, 2, 1, "126m", False, recipe.Float8CurrentScaling(), True, True, True)
@@ -6,7 +6,7 @@
import os
import logging
from typing import Any, Callable, Dict, Optional, Tuple, Union, List
import functools
import torch
import transformer_engine_torch as tex
@@ -18,6 +18,7 @@ from .base import (
_2X_ACC_DGRAD,
_2X_ACC_WGRAD,
)
from ._common import WeightGradStore
from ..fp8 import get_fp8_te_dtype, FP8GlobalStateManager
from ..utils import (
divide,
@@ -82,6 +83,7 @@ class _BatchLinear(torch.autograd.Function):
is_first_microbatch: Union[bool, None],
fp8: bool,
fp8_calibration: bool,
wgrad_store: WeightGradStore,
fp8_meta: Dict[str, Any],
fuse_wgrad_accumulation: bool,
cpu_offloading: bool,
@@ -183,6 +185,7 @@ class _BatchLinear(torch.autograd.Function):
ctx.tp_size = tp_size
ctx.requires_dgrad = inp.requires_grad
ctx.reduce_and_update_bwd_fp8_tensors = False
ctx.wgrad_store = wgrad_store
# [*, in_features] -> [*, out_features] except first dimension changes for SP
return out.view(-1, *inp.shape[1:-1], out.shape[-1])
@@ -246,53 +249,69 @@ class _BatchLinear(torch.autograd.Function):
torch.empty(w.size(), dtype=ctx.activation_dtype, device=w.device)
for w in weights
]
# WGRAD
_, grad_biases, _ = batchgemm(
inputmats,
grad_output_mats,
wgrad_list,
ctx.activation_dtype,
get_multi_stream_cublas_batchgemm_workspace(),
batched_gemm_wgrad = functools.partial(
batchgemm,
dtype=ctx.activation_dtype,
workspaces=get_multi_stream_cublas_batchgemm_workspace(),
layout="NT",
grad=True,
use_bias=ctx.use_bias,
accumulate=accumulate_wgrad_into_param_main_grad,
)
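# batched_gemm_wgrad binds the full wgrad GEMM configuration up front, so the same call
# can either run immediately below or be handed to the WeightGradStore and replayed later.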
# Deallocate input tensor
clear_tensor_data(*inputmats)
clear_tensor_data(*inputmats_t)
if not ctx.use_bias:
grad_biases = [None] * ctx.num_gemms
def handle_custom_ddp_from_mcore(w, wgrad):
if w.requires_grad:
if ctx.fuse_wgrad_accumulation and hasattr(w, "grad_added_to_main_grad"):
w.grad_added_to_main_grad = True
if getattr(w, "zero_out_wgrad", False):
wgrad = torch.zeros(
w.main_grad.shape,
dtype=w.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
# WGRAD
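# With delay_wgrad_compute enabled, queue the operands plus the bound GEMM for backward_dw();
# otherwise compute wgrad (and any still-missing bias grads) right away.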
if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
ctx.wgrad_store.put([inputmats, grad_output_mats, wgrad_list], batched_gemm_wgrad)
else:
_, grad_biases_, _ = batched_gemm_wgrad(inputmats, grad_output_mats, wgrad_list)
for i in range(ctx.num_gemms):
if grad_biases[i] is None:
grad_biases[i] = grad_biases_[i]
del grad_biases_
# Deallocate input tensor
clear_tensor_data(*inputmats)
clear_tensor_data(*inputmats_t)
def handle_custom_ddp_from_mcore(w, wgrad):
if w.requires_grad:
if ctx.fuse_wgrad_accumulation and hasattr(w, "grad_added_to_main_grad"):
w.grad_added_to_main_grad = True
if getattr(w, "zero_out_wgrad", False):
wgrad = torch.zeros(
w.main_grad.shape,
dtype=w.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
else:
wgrad = torch.empty(
w.main_grad.shape,
dtype=w.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
elif ctx.fuse_wgrad_accumulation:
wgrad = None
else:
wgrad = torch.empty(
w.main_grad.shape,
dtype=w.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
elif ctx.fuse_wgrad_accumulation:
wgrad = None
wgrad = None
return wgrad
wgrad_list = [
handle_custom_ddp_from_mcore(w, wgrad) for w, wgrad in zip(weights, wgrad_list)
]
else:
wgrad = None
return wgrad
wgrad_list = [None] * ctx.num_gemms
if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
wgrad_list = [None] * ctx.num_gemms
wgrad_list = [
handle_custom_ddp_from_mcore(w, wgrad) for w, wgrad in zip(weights, wgrad_list)
]
if not ctx.use_bias or (
ctx.wgrad_store is not None
and ctx.wgrad_store.delay_wgrad_compute()
and not ctx.fp8
):
grad_biases = [None] * ctx.num_gemms
if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing():
FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
@@ -304,6 +323,7 @@ class _BatchLinear(torch.autograd.Function):
None, # is_first_microbatch
None, # fp8
None, # fp8_calibration
None, # wgrad_store
None, # fp8_meta
None, # fuse_wgrad_accumulation
None, # cpu_offloading
@@ -381,6 +401,8 @@ class BatchedLinear(TransformerEngineBaseModule):
it controls the type used to allocate the initial parameters. Useful when
the model is trained with lower precision and the original FP32 parameters
would not fit in GPU memory.
delay_wgrad_compute : bool, default = `False`
Whether to delay the weight gradient computation until `backward_dw()` is called
"""
def __init__(
@@ -403,6 +425,7 @@ class BatchedLinear(TransformerEngineBaseModule):
ub_overlap_rs: bool = False,
ub_overlap_ag: bool = False,
ub_name: Optional[str] = None,
delay_wgrad_compute: bool = False,
) -> None:
super().__init__()
@@ -424,6 +447,8 @@ class BatchedLinear(TransformerEngineBaseModule):
self.get_rng_state_tracker = get_rng_state_tracker
self.rng_tracker_name = rng_tracker_name
self.wgrad_store = WeightGradStore(delay_wgrad_compute)
global _GEMM_INPUT, _GEMM_WEIGHT, _GEMM_OUTPUT
_GEMM_INPUT, _GEMM_WEIGHT, _GEMM_OUTPUT = 0, self.num_gemms, 2 * self.num_gemms
@@ -588,6 +613,7 @@ class BatchedLinear(TransformerEngineBaseModule):
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
self.fp8_meta,
self.fuse_wgrad_accumulation,
CPUOffloadEnabled,
@@ -617,3 +643,27 @@ class BatchedLinear(TransformerEngineBaseModule):
if self.return_bias:
return out, [cast_if_needed(b, self.activation_dtype) for b in bias_tensors]
return out
def backward_dw(self):
"""
Execute the delayed weight gradient computation.
This method is called after the main backward pass to compute weight gradients.
"""
if self.wgrad_store is None or not self.wgrad_store.delay_wgrad_compute():
return
with torch.cuda.nvtx.range("_BatchedLinear_wgrad"):
(_, grad_biases_, _), tensor_list = self.wgrad_store.pop()
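# pop() executes the stored batched wgrad GEMM and returns its outputs together with the
# tensor list passed to put(); slot 2 of that list holds the freshly written wgrad buffers.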
wgrad_list = tensor_list[2]
if not self.fuse_wgrad_accumulation:
for i in range(self.num_gemms):
weight_param = getattr(self, f"weight{i}")
if weight_param.grad is None:
weight_param.grad = wgrad_list[i].to(weight_param.dtype)
if self.use_bias:
for i in range(self.num_gemms):
bias_param = getattr(self, f"bias{i}")
if bias_param.grad is None:
bias_param.grad = grad_biases_[i].to(bias_param.dtype)
del grad_biases_
del wgrad_list
del tensor_list
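
For reference, the WeightGradStore contract this module relies on (put/pop/delay_wgrad_compute). The real class is imported from ._common; the sketch below is only an assumption of its behavior, not that implementation:

```python
# Hedged sketch of the assumed WeightGradStore contract (not the real ._common code).
class WeightGradStore:
    def __init__(self, delay_wgrad_compute: bool = False):
        self._delay = delay_wgrad_compute
        self._queue = []                       # pending (tensor_list, gemm_fn) pairs

    def delay_wgrad_compute(self) -> bool:
        return self._delay

    def put(self, tensor_list, gemm_fn):
        # Stash the GEMM operands and the bound wgrad GEMM for later execution.
        self._queue.append((tensor_list, gemm_fn))

    def pop(self):
        # Run the deferred wgrad GEMM and return its outputs plus the stored operands,
        # matching `(_, grad_biases_, _), tensor_list = self.wgrad_store.pop()` above.
        tensor_list, gemm_fn = self._queue.pop(0)
        return gemm_fn(*tensor_list), tensor_list
```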