Commit 1258bbe0 (unverified) authored by Hongbin Liu, committed by GitHub

Manually launch wgrad accumulation and reduce in backward_dw() instead of backward() (#1976)



* disable wgrad accumulation and reduce in backward(), and manually launch them in backward_dw()
Signed-off-by: Hongbin Liu <hongbinl@nvidia.com>

* format
Signed-off-by: Hongbin Liu <hongbinl@nvidia.com>

* refactor
Signed-off-by: Hongbin Liu <hongbinl@nvidia.com>

* refactor
Signed-off-by: Hongbin Liu <hongbinl@nvidia.com>

* set skip_backward_post_hook to True only if delay_wgrad_compute is True
Signed-off-by: Hongbin Liu <hongbinl@nvidia.com>

* format
Signed-off-by: Hongbin Liu <hongbinl@nvidia.com>

---------
Signed-off-by: Hongbin Liu <hongbinl@nvidia.com>
Co-authored-by: Hongbin Liu <hongbinl@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 8e2d37e9
@@ -582,6 +582,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         self.fsdp_group = None
         self._fp8_workspaces: Dict[str, QuantizedTensor] = {}
         self.activation_dtype: Optional[torch.dtype] = None
+        self.wgrad_accumulation_and_reduce_hooks = []
         if not TEDebugState.debug_enabled:
             TEDebugState.initialize()
@@ -1383,6 +1384,16 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
             state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
         )

+    def register_wgrad_accumulation_and_reduce_hooks(self, wgrad_accumulation_and_reduce_hook):
+        """
+        Register a hook to manually control weight gradient accumulation and reduction.
+        This method must be called before backward().
+        Set skip_wgrad_accumulation_and_reduce to True to skip weight gradient
+        accumulation and reduction in backward(); the registered hook is then
+        called in backward_dw().
+        """
+        self.wgrad_accumulation_and_reduce_hooks.append(wgrad_accumulation_and_reduce_hook)
+
     def backward_dw(self):
         """
         Execute the delayed weight gradient computation.
@@ -1393,14 +1404,17 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         with torch.cuda.nvtx.range(f"_{self.__class__.__name__}_wgrad"):
             (wgrad, bgrad), _ = self.wgrad_store.pop()
             if not self.fuse_wgrad_accumulation:
-                unfused_weights = [getattr(self, name) for name in self.weight_names]
-                weight_tensor = noop_cat(unfused_weights)
+                weight_tensor = noop_cat(self._get_weight_tensors())
                 if weight_tensor.grad is None:
                     weight_tensor.grad = wgrad.to(weight_tensor.dtype)
             if self.use_bias:
                 bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names])
                 if bias_tensor.grad is None:
                     bias_tensor.grad = bgrad.to(bias_tensor.dtype)
             del wgrad
             del bgrad
+        for wgrad_accumulation_and_reduce_hook in self.wgrad_accumulation_and_reduce_hooks:
+            wgrad_accumulation_and_reduce_hook()

     def _validate_name(self):
         """
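Taken together, the two base-class additions form a small manual-control API: register hooks up front, run backward() with wgrad accumulation/reduce skipped, then launch both from backward_dw(). Below is a minimal sketch of the intended call sequence, assuming a build where `Linear` accepts `delay_wgrad_compute`; the hook body (an all-reduce over an assumed data-parallel setup) is illustrative, not part of this commit:

```python
import torch
import torch.distributed as dist
import transformer_engine.pytorch as te

# Assumption: delay_wgrad_compute=True routes wgrad through wgrad_store.
layer = te.Linear(1024, 1024, delay_wgrad_compute=True)

def wgrad_hook():
    # Hypothetical hook body: reduce the freshly written weight gradient
    # across data-parallel ranks once backward_dw() has produced it.
    if dist.is_initialized():
        dist.all_reduce(layer.weight.grad)

layer.register_wgrad_accumulation_and_reduce_hooks(wgrad_hook)

out = layer(torch.randn(8, 1024, device="cuda"))
out.sum().backward()  # dgrad only; wgrad accumulation/reduce is skipped here
layer.backward_dw()   # wgrad is computed, then the registered hooks fire
```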
@@ -662,6 +662,12 @@ class GroupedLinear(TransformerEngineBaseModule):
         self.reset_parameters(defer_init=device == "meta")

+        if self.wgrad_store.delay_wgrad_compute():
+            for name, param in self.named_parameters():
+                for i in range(self.num_gemms):
+                    if name in (f"weight{i}", f"bias{i}"):
+                        param.skip_backward_post_hook = True
+
     def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None:
         """Init scales and amaxes for fwd | bwd."""
         super().set_meta_tensor(fwd, recipe)
@@ -819,19 +825,21 @@
         with torch.cuda.nvtx.range("_GroupedLinear_wgrad"):
             (_, grad_biases_, _), tensor_list = self.wgrad_store.pop()
             wgrad_list = tensor_list[2]
+            weight_params = [getattr(self, f"weight{i}") for i in range(self.num_gemms)]
+            bias_params = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]
             if not self.fuse_wgrad_accumulation:
                 for i in range(self.num_gemms):
-                    weight_param = getattr(self, f"weight{i}")
-                    if weight_param.grad is None:
-                        weight_param.grad = wgrad_list[i].to(weight_param.dtype)
+                    if weight_params[i].grad is None:
+                        weight_params[i].grad = wgrad_list[i].to(weight_params[i].dtype)
             if self.use_bias:
                 for i in range(self.num_gemms):
-                    bias_param = getattr(self, f"bias{i}")
-                    if bias_param.grad is None:
-                        bias_param.grad = grad_biases_[i].to(bias_param.dtype)
+                    if bias_params[i].grad is None:
+                        bias_params[i].grad = grad_biases_[i].to(bias_params[i].dtype)
             del grad_biases_
             del wgrad_list
             del tensor_list
+        for wgrad_accumulation_and_reduce_hook in self.wgrad_accumulation_and_reduce_hooks:
+            wgrad_accumulation_and_reduce_hook()

     def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None:
         """Customize quantizers based on current scaling recipe + linear."""
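The `skip_backward_post_hook` flag set in `__init__` above is meant for the training framework to consult, so its own per-parameter post-backward hook does not accumulate or reduce a gradient that `backward_dw()` will handle later. A hedged sketch of that consumer side follows; `accumulate_and_reduce` is a hypothetical stand-in for a framework routine, and only the attribute name comes from this commit:

```python
import torch
import torch.distributed as dist
import transformer_engine.pytorch as te

# Hypothetical setup: a GroupedLinear with delayed wgrad compute, so its
# weight{i}/bias{i} params carry skip_backward_post_hook = True.
module = te.GroupedLinear(4, 1024, 1024, delay_wgrad_compute=True)

def accumulate_and_reduce(param):
    # Stand-in for a framework's gradient accumulation/reduce step.
    if dist.is_initialized():
        dist.all_reduce(param.grad)

def param_hook(param):
    # Defer to backward_dw() for params flagged by this commit.
    if getattr(param, "skip_backward_post_hook", False):
        return
    accumulate_and_reduce(param)

for param in module.parameters():
    param.register_post_accumulate_grad_hook(param_hook)
```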
@@ -1382,6 +1382,11 @@ class LayerNormLinear(TransformerEngineBaseModule):
         self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
         self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0"))

+        if self.wgrad_store.delay_wgrad_compute():
+            for name, param in self.named_parameters():
+                if name in self.weight_names or name in self.bias_names:
+                    param.skip_backward_post_hook = True
+
     def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None:
         """Init scales and amaxes for fwd | bwd."""
         super().set_meta_tensor(fwd, recipe)
@@ -1642,6 +1642,10 @@ class LayerNormMLP(TransformerEngineBaseModule):
                 warmup_jit_bias_gelu_all_dtypes(
                     self.size_per_partition, seq_length, micro_batch_size
                 )

+        if self.wgrad_store.delay_wgrad_compute():
+            for name, param in self.named_parameters():
+                if name in ["fc1_weight", "fc2_weight", "fc1_bias", "fc2_bias"]:
+                    param.skip_backward_post_hook = True
         # These many SMs are subtracted from the total SM count when calling forward
         # and backward LayerNorm C APIs. These envvars can be used to prevent the LN
@@ -2152,3 +2156,5 @@
             del fc2_wgrad
             del fc1_wgrad
             del fc1_bias_grad
+        for wgrad_accumulation_and_reduce_hook in self.wgrad_accumulation_and_reduce_hooks:
+            wgrad_accumulation_and_reduce_hook()
@@ -1270,6 +1270,11 @@ class Linear(TransformerEngineBaseModule):
         else:
             self.gemm_bias_unfused_add = False

+        if self.wgrad_store.delay_wgrad_compute():
+            for name, param in self.named_parameters():
+                if name in self.weight_names or name in self.bias_names:
+                    param.skip_backward_post_hook = True
+
     def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None:
         """Init scales and amaxes for fwd | bwd."""
         super().set_meta_tensor(fwd, recipe)
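With the same guard now present in `Linear`, `LayerNormLinear`, `LayerNormMLP`, and `GroupedLinear`, a schedule can run every dgrad pass first and batch the wgrad work, hooks included, afterwards. A rough sketch under the same `delay_wgrad_compute` assumption as above:

```python
import torch
import transformer_engine.pytorch as te

# Three delayed-wgrad layers; names and sizes are illustrative.
layers = [te.Linear(1024, 1024, delay_wgrad_compute=True) for _ in range(3)]

x = torch.randn(8, 1024, device="cuda")
y = x
for layer in layers:
    y = layer(y)
y.sum().backward()       # dgrad chain runs; wgrad inputs are stashed

for layer in reversed(layers):
    layer.backward_dw()  # per-layer wgrad, then its accumulation/reduce hooks
```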