Commit 460b006c authored by yuguo

[DCU] support delay_wgrad_compute in batchgemm

parent 196a213f
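For context, a minimal usage sketch of the new flag. This is hedged: the constructor arguments other than `delay_wgrad_compute` (feature sizes, import path) are placeholders not shown in this diff; only `delay_wgrad_compute` and `backward_dw()` come from the commit itself.

```python
import torch
from transformer_engine.pytorch import BatchedLinear  # import path assumed

# Hypothetical sizes; delay_wgrad_compute and backward_dw() are what this
# commit adds. With the flag set, loss.backward() computes dgrad only and
# stashes the wgrad batch-GEMM for later replay.
layer = BatchedLinear(
    in_features=256,   # placeholder argument
    out_features=256,  # placeholder argument
    num_gemms=2,
    device="cuda",
    delay_wgrad_compute=True,
)

x = torch.randn(8, 256, device="cuda", requires_grad=True)
loss = layer(x).sum()
loss.backward()      # dgrad runs; the wgrad GEMM is queued in the WeightGradStore
layer.backward_dw()  # replay the deferred wgrad (and bias-grad) computation
```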
@@ -171,7 +171,7 @@ def reset_global_fp8_state():
     FP8GlobalStateManager.reset()
 
 def _test_batched_linear_accuracy(
-    block, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation
+    block, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation, delay_wgrad_compute, batch_num
 ):
     reset_rng_states()
     if fp8:
@@ -202,9 +202,31 @@ def _test_batched_linear_accuracy(
     )
     loss = out.sum()
     loss.backward()
+    if delay_wgrad_compute:
+        if isinstance(block, BatchedLinear):
+            block.backward_dw()
+        else:
+            for i in range(num_gemms):
+                block[i].backward_dw()
     torch.cuda.synchronize()
 
     outputs = [out, inp_hidden_states.grad]
+    for p in block.parameters():
+        if p.requires_grad:
+            if isinstance(block, BatchedLinear):
+                if getattr(p, "main_grad", None) is not None:
+                    for j in range(batch_num):
+                        outputs.append(p.main_grad[p.main_grad.shape[0] // batch_num * j : p.main_grad.shape[0] // batch_num * (j + 1)])
+                    assert p.grad is None  # grad should be None if fuse_wgrad_accumulation is True
+                else:
+                    for j in range(batch_num):
+                        outputs.append(p.grad[p.grad.shape[0] // batch_num * j : p.grad.shape[0] // batch_num * (j + 1)])
+            else:
+                if getattr(p, "main_grad", None) is not None:
+                    outputs.append(p.main_grad)
+                    assert p.grad is None  # grad should be None if fuse_wgrad_accumulation is True
+                else:
+                    outputs.append(p.grad)
     return outputs
 
 @pytest.mark.parametrize("dtype", param_types)
@@ -215,6 +237,7 @@ def _test_batched_linear_accuracy(
 @pytest.mark.parametrize("recipe", fp8_recipes)
 @pytest.mark.parametrize("fp8_model_params", all_boolean)
 @pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean)
+@pytest.mark.parametrize("delay_wgrad_compute", all_boolean)
 def test_batched_linear_accuracy(
     dtype,
     num_gemms,
@@ -224,6 +247,7 @@ def test_batched_linear_accuracy(
     recipe,
     fp8_model_params,
     fuse_wgrad_accumulation,
+    delay_wgrad_compute,
     parallel_mode=None,
 ):
     batch_num = int(os.getenv("NVTE_MOE_BATCHCOUNT", "2"))
@@ -250,6 +274,7 @@ def test_batched_linear_accuracy(
         parallel_mode=parallel_mode,
         device="cuda",
         fuse_wgrad_accumulation=fuse_wgrad_accumulation,
+        delay_wgrad_compute=delay_wgrad_compute,
     ).eval()
     sequential_linear = torch.nn.ModuleList(
         [
@@ -281,10 +306,10 @@ def test_batched_linear_accuracy(
             sequential_linear[i * batch_num + j].weight.main_grad = weight_i.main_grad[weight_i.main_grad.shape[0] // batch_num * j : weight_i.main_grad.shape[0] // batch_num * (j + 1)].clone()
 
     outputs_ref = _test_batched_linear_accuracy(
-        sequential_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation
+        sequential_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation, delay_wgrad_compute, batch_num
     )
     outputs = _test_batched_linear_accuracy(
-        batched_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation
+        batched_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation, delay_wgrad_compute, batch_num
     )
 
     # Should be bit-wise match
@@ -292,4 +317,4 @@ def test_batched_linear_accuracy(
         torch.testing.assert_close(o, o_ref, rtol=6e-3, atol=6e-3)
 
 if __name__ == "__main__":
-    test_batched_linear_accuracy(torch.float32, 2, 1, "126m", False, recipe.Float8CurrentScaling(), True, True)
+    test_batched_linear_accuracy(torch.float32, 2, 1, "126m", False, recipe.Float8CurrentScaling(), True, True, True)
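The module changes below hinge on one pattern: the wgrad batch-GEMM call is packaged as a `functools.partial` and either run immediately or handed to a `WeightGradStore` for later replay. A minimal, self-contained sketch of that deferral pattern follows; all names here are illustrative stand-ins, not the library's API.

```python
import functools
from typing import Callable, List, Tuple

import torch

# Stand-in queue for WeightGradStore: (tensors, prepared GEMM callable) pairs.
_pending: List[Tuple[list, Callable]] = []

def toy_batchgemm(a_mats, b_mats, out_mats, *, accumulate=False):
    # Placeholder for the real batched GEMM: wgrad = grad_out^T @ inp per pair.
    for a, b, out in zip(a_mats, b_mats, out_mats):
        prod = b.T @ a
        if accumulate:
            out.add_(prod)
        else:
            out.copy_(prod)
    return None, [None] * len(out_mats), None  # mimic batchgemm's 3-tuple

def backward_wgrad(inputmats, grad_outs, wgrad_bufs, delay: bool):
    # Bind the configuration now; the tensors are supplied at call time.
    wgrad_fn = functools.partial(toy_batchgemm, accumulate=False)
    if delay:
        _pending.append(([inputmats, grad_outs, wgrad_bufs], wgrad_fn))
    else:
        wgrad_fn(inputmats, grad_outs, wgrad_bufs)

def backward_dw():
    # Replay every deferred wgrad GEMM (e.g. overlapped with other work).
    while _pending:
        tensors, fn = _pending.pop(0)
        fn(*tensors)

# Demo: with delay=True nothing is computed until backward_dw() runs.
inp = [torch.randn(4, 3)]
gout = [torch.randn(4, 5)]
wbuf = [torch.empty(5, 3)]
backward_wgrad(inp, gout, wbuf, delay=True)  # queued, not computed
backward_dw()                                # wgrad materialized here
```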
@@ -6,7 +6,7 @@
 import os
 import logging
 from typing import Any, Callable, Dict, Optional, Tuple, Union, List
+import functools
 
 import torch
 import transformer_engine_torch as tex
@@ -18,6 +18,7 @@ from .base import (
     _2X_ACC_DGRAD,
     _2X_ACC_WGRAD,
 )
+from ._common import WeightGradStore
 from ..fp8 import get_fp8_te_dtype, FP8GlobalStateManager
 from ..utils import (
     divide,
@@ -82,6 +83,7 @@ class _BatchLinear(torch.autograd.Function):
         is_first_microbatch: Union[bool, None],
         fp8: bool,
         fp8_calibration: bool,
+        wgrad_store: WeightGradStore,
         fp8_meta: Dict[str, Any],
         fuse_wgrad_accumulation: bool,
         cpu_offloading: bool,
@@ -183,6 +185,7 @@ class _BatchLinear(torch.autograd.Function):
         ctx.tp_size = tp_size
         ctx.requires_dgrad = inp.requires_grad
         ctx.reduce_and_update_bwd_fp8_tensors = False
+        ctx.wgrad_store = wgrad_store
 
         # [*, in_features] -> [*, out_features] except first dimension changes for SP
         return out.view(-1, *inp.shape[1:-1], out.shape[-1])
@@ -246,26 +249,30 @@ class _BatchLinear(torch.autograd.Function):
                 torch.empty(w.size(), dtype=ctx.activation_dtype, device=w.device)
                 for w in weights
             ]
-            # WGRAD
-            _, grad_biases, _ = batchgemm(
-                inputmats,
-                grad_output_mats,
-                wgrad_list,
-                ctx.activation_dtype,
-                get_multi_stream_cublas_batchgemm_workspace(),
+            batched_gemm_wgrad = functools.partial(
+                batchgemm,
+                dtype=ctx.activation_dtype,
+                workspaces=get_multi_stream_cublas_batchgemm_workspace(),
                 layout="NT",
                 grad=True,
                 use_bias=ctx.use_bias,
                 accumulate=accumulate_wgrad_into_param_main_grad,
             )
+            # WGRAD
+            if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
+                ctx.wgrad_store.put([inputmats, grad_output_mats, wgrad_list], batched_gemm_wgrad)
+            else:
+                _, grad_biases_, _ = batched_gemm_wgrad(inputmats, grad_output_mats, wgrad_list)
+                for i in range(ctx.num_gemms):
+                    if grad_biases[i] is None:
+                        grad_biases[i] = grad_biases_[i]
+                del grad_biases_
 
             # Deallocate input tensor
             clear_tensor_data(*inputmats)
             clear_tensor_data(*inputmats_t)
-        if not ctx.use_bias:
-            grad_biases = [None] * ctx.num_gemms
 
         def handle_custom_ddp_from_mcore(w, wgrad):
             if w.requires_grad:
                 if ctx.fuse_wgrad_accumulation and hasattr(w, "grad_added_to_main_grad"):
@@ -293,6 +300,18 @@ class _BatchLinear(torch.autograd.Function):
             wgrad_list = [
                 handle_custom_ddp_from_mcore(w, wgrad) for w, wgrad in zip(weights, wgrad_list)
             ]
+        else:
+            wgrad_list = [None] * ctx.num_gemms
+
+        if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
+            wgrad_list = [None] * ctx.num_gemms
+
+        if not ctx.use_bias or (
+            ctx.wgrad_store is not None
+            and ctx.wgrad_store.delay_wgrad_compute()
+            and not ctx.fp8
+        ):
+            grad_biases = [None] * ctx.num_gemms
 
         if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing():
             FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
@@ -304,6 +323,7 @@ class _BatchLinear(torch.autograd.Function):
             None,  # is_first_microbatch
             None,  # fp8
             None,  # fp8_calibration
+            None,  # wgrad_store
             None,  # fp8_meta
             None,  # fuse_wgrad_accumulation
             None,  # cpu_offloading
@@ -381,6 +401,8 @@ class BatchedLinear(TransformerEngineBaseModule):
                           it controls the type used to allocate the initial parameters. Useful when
                           the model is trained with lower precision and the original FP32 parameters
                           would not fit in GPU memory.
+    delay_wgrad_compute : bool, default = `False`
+                         whether to delay the weight gradient computation until `backward_dw()` is called.
     """
 
     def __init__(
@@ -403,6 +425,7 @@ class BatchedLinear(TransformerEngineBaseModule):
         ub_overlap_rs: bool = False,
         ub_overlap_ag: bool = False,
         ub_name: Optional[str] = None,
+        delay_wgrad_compute: bool = False,
     ) -> None:
         super().__init__()
@@ -424,6 +447,8 @@ class BatchedLinear(TransformerEngineBaseModule):
         self.get_rng_state_tracker = get_rng_state_tracker
         self.rng_tracker_name = rng_tracker_name
+        self.wgrad_store = WeightGradStore(delay_wgrad_compute)
 
         global _GEMM_INPUT, _GEMM_WEIGHT, _GEMM_OUTPUT
         _GEMM_INPUT, _GEMM_WEIGHT, _GEMM_OUTPUT = 0, self.num_gemms, 2 * self.num_gemms
@@ -588,6 +613,7 @@ class BatchedLinear(TransformerEngineBaseModule):
             is_first_microbatch,
             self.fp8,
             self.fp8_calibration,
+            self.wgrad_store,
             self.fp8_meta,
             self.fuse_wgrad_accumulation,
             CPUOffloadEnabled,
@@ -617,3 +643,27 @@ class BatchedLinear(TransformerEngineBaseModule):
         if self.return_bias:
             return out, [cast_if_needed(b, self.activation_dtype) for b in bias_tensors]
         return out
+
+    def backward_dw(self):
+        """
+        Execute the delayed weight gradient computation.
+        This method is called after the main backward pass to compute the weight gradients.
+        """
+        if self.wgrad_store is None or not self.wgrad_store.delay_wgrad_compute():
+            return
+        with torch.cuda.nvtx.range("_BatchedLinear_wgrad"):
+            (_, grad_biases_, _), tensor_list = self.wgrad_store.pop()
+            wgrad_list = tensor_list[2]
+            if not self.fuse_wgrad_accumulation:
+                for i in range(self.num_gemms):
+                    weight_param = getattr(self, f"weight{i}")
+                    if weight_param.grad is None:
+                        weight_param.grad = wgrad_list[i].to(weight_param.dtype)
+            if self.use_bias:
+                for i in range(self.num_gemms):
+                    bias_param = getattr(self, f"bias{i}")
+                    if bias_param.grad is None:
+                        bias_param.grad = grad_biases_[i].to(bias_param.dtype)
+            del grad_biases_
+            del wgrad_list
+            del tensor_list
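`WeightGradStore` itself (imported from `._common`) is not shown in this diff. From its call sites above, `delay_wgrad_compute()`, `put(tensor_list, fn)`, and `pop()` returning `(fn_outputs, tensor_list)`, a minimal sketch consistent with that usage might look like the following; this is an inference, not the real implementation.

```python
from collections import deque
from typing import Any, Callable, Deque, List, Tuple

class WeightGradStore:
    """Minimal sketch inferred from the call sites in this diff; the real
    class lives in the module's ._common package."""

    def __init__(self, delay_wgrad_compute: bool = False):
        self._delay = delay_wgrad_compute
        self._queue: Deque[Tuple[List[Any], Callable]] = deque()

    def delay_wgrad_compute(self) -> bool:
        # Whether callers should defer the wgrad GEMM instead of running it.
        return self._delay

    def put(self, tensor_list: List[Any], gemm_fn: Callable) -> None:
        # Stash the tensors and the prepared GEMM callable for later replay.
        self._queue.append((tensor_list, gemm_fn))

    def pop(self) -> Tuple[Any, List[Any]]:
        # Run the oldest deferred GEMM; return its outputs and the tensors it
        # consumed, matching the unpacking in backward_dw() above.
        tensor_list, gemm_fn = self._queue.popleft()
        return gemm_fn(*tensor_list), tensor_list
```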