Unverified Commit 1258bbe0 authored by Hongbin Liu's avatar Hongbin Liu Committed by GitHub
Browse files

Manually launch wgrad accumulation and reduce in backward_dw() instead of backward() (#1976)



* disable wgrad accumulation and reduce in backward(), and manually launch it in backward_dw()
Signed-off-by: default avatarHongbin Liu <hongbinl@nvidia.com>

* format
Signed-off-by: default avatarHongbin Liu <hongbinl@nvidia.com>

* refactor
Signed-off-by: default avatarHongbin Liu <hongbinl@nvidia.com>

* refactor
Signed-off-by: default avatarHongbin Liu <hongbinl@nvidia.com>

* set skip_backward_post_hook to True only if delay_wgrad_compute is True
Signed-off-by: default avatarHongbin Liu <hongbinl@nvidia.com>

* format
Signed-off-by: default avatarHongbin Liu <hongbinl@nvidia.com>

---------
Signed-off-by: default avatarHongbin Liu <hongbinl@nvidia.com>
Co-authored-by: default avatarHongbin Liu <hongbinl@nvidia.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 8e2d37e9
...@@ -582,6 +582,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -582,6 +582,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
self.fsdp_group = None self.fsdp_group = None
self._fp8_workspaces: Dict[str, QuantizedTensor] = {} self._fp8_workspaces: Dict[str, QuantizedTensor] = {}
self.activation_dtype: Optional[torch.dtype] = None self.activation_dtype: Optional[torch.dtype] = None
self.wgrad_accumulation_and_reduce_hooks = []
if not TEDebugState.debug_enabled: if not TEDebugState.debug_enabled:
TEDebugState.initialize() TEDebugState.initialize()
...@@ -1383,6 +1384,16 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -1383,6 +1384,16 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
) )
def register_wgrad_accumulation_and_reduce_hooks(self, wgrad_accumulation_and_reduce_hook):
    """Register a callback for manual weight-gradient accumulation and reduction.

    Call this before backward(). Registered callbacks are not run during
    backward(); instead they are invoked, in registration order, at the end
    of backward_dw(), giving the caller manual control over when wgrad
    accumulation and reduction happen.
    """
    hooks = self.wgrad_accumulation_and_reduce_hooks
    hooks.append(wgrad_accumulation_and_reduce_hook)
def backward_dw(self): def backward_dw(self):
""" """
Execute the delayed weight gradient computation. Execute the delayed weight gradient computation.
...@@ -1393,14 +1404,17 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -1393,14 +1404,17 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
with torch.cuda.nvtx.range(f"_{self.__class__.__name__}_wgrad"): with torch.cuda.nvtx.range(f"_{self.__class__.__name__}_wgrad"):
(wgrad, bgrad), _ = self.wgrad_store.pop() (wgrad, bgrad), _ = self.wgrad_store.pop()
if not self.fuse_wgrad_accumulation: if not self.fuse_wgrad_accumulation:
unfused_weights = [getattr(self, name) for name in self.weight_names] weight_tensor = noop_cat(self._get_weight_tensors())
weight_tensor = noop_cat(unfused_weights)
if weight_tensor.grad is None: if weight_tensor.grad is None:
weight_tensor.grad = wgrad.to(weight_tensor.dtype) weight_tensor.grad = wgrad.to(weight_tensor.dtype)
if self.use_bias: if self.use_bias:
bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names]) bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names])
if bias_tensor.grad is None: if bias_tensor.grad is None:
bias_tensor.grad = bgrad.to(bias_tensor.dtype) bias_tensor.grad = bgrad.to(bias_tensor.dtype)
del wgrad
del bgrad
for wgrad_accumulation_and_reduce_hook in self.wgrad_accumulation_and_reduce_hooks:
wgrad_accumulation_and_reduce_hook()
def _validate_name(self): def _validate_name(self):
""" """
......
...@@ -662,6 +662,12 @@ class GroupedLinear(TransformerEngineBaseModule): ...@@ -662,6 +662,12 @@ class GroupedLinear(TransformerEngineBaseModule):
self.reset_parameters(defer_init=device == "meta") self.reset_parameters(defer_init=device == "meta")
if self.wgrad_store.delay_wgrad_compute():
for name, param in self.named_parameters():
for i in range(self.num_gemms):
if name in (f"weight{i}", f"bias{i}"):
param.skip_backward_post_hook = True
def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None:
"""Init scales and amaxes for fwd | bwd.""" """Init scales and amaxes for fwd | bwd."""
super().set_meta_tensor(fwd, recipe) super().set_meta_tensor(fwd, recipe)
...@@ -819,19 +825,21 @@ class GroupedLinear(TransformerEngineBaseModule): ...@@ -819,19 +825,21 @@ class GroupedLinear(TransformerEngineBaseModule):
with torch.cuda.nvtx.range("_GroupedLinear_wgrad"): with torch.cuda.nvtx.range("_GroupedLinear_wgrad"):
(_, grad_biases_, _), tensor_list = self.wgrad_store.pop() (_, grad_biases_, _), tensor_list = self.wgrad_store.pop()
wgrad_list = tensor_list[2] wgrad_list = tensor_list[2]
weight_params = [getattr(self, f"weight{i}") for i in range(self.num_gemms)]
bias_params = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]
if not self.fuse_wgrad_accumulation: if not self.fuse_wgrad_accumulation:
for i in range(self.num_gemms): for i in range(self.num_gemms):
weight_param = getattr(self, f"weight{i}") if weight_params[i].grad is None:
if weight_param.grad is None: weight_params[i].grad = wgrad_list[i].to(weight_params[i].dtype)
weight_param.grad = wgrad_list[i].to(weight_param.dtype)
if self.use_bias: if self.use_bias:
for i in range(self.num_gemms): for i in range(self.num_gemms):
bias_param = getattr(self, f"bias{i}") if bias_params[i].grad is None:
if bias_param.grad is None: bias_params[i].grad = grad_biases_[i].to(bias_params[i].dtype)
bias_param.grad = grad_biases_[i].to(bias_param.dtype)
del grad_biases_ del grad_biases_
del wgrad_list del wgrad_list
del tensor_list del tensor_list
for wgrad_accumulation_and_reduce_hook in self.wgrad_accumulation_and_reduce_hooks:
wgrad_accumulation_and_reduce_hook()
def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None: def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None:
"""Customize quantizers based on current scaling recipe + linear.""" """Customize quantizers based on current scaling recipe + linear."""
......
...@@ -1382,6 +1382,11 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1382,6 +1382,11 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0")) self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0"))
if self.wgrad_store.delay_wgrad_compute():
for name, param in self.named_parameters():
if name in self.weight_names or name in self.bias_names:
param.skip_backward_post_hook = True
def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None:
"""Init scales and amaxes for fwd | bwd.""" """Init scales and amaxes for fwd | bwd."""
super().set_meta_tensor(fwd, recipe) super().set_meta_tensor(fwd, recipe)
......
...@@ -1642,6 +1642,10 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1642,6 +1642,10 @@ class LayerNormMLP(TransformerEngineBaseModule):
warmup_jit_bias_gelu_all_dtypes( warmup_jit_bias_gelu_all_dtypes(
self.size_per_partition, seq_length, micro_batch_size self.size_per_partition, seq_length, micro_batch_size
) )
if self.wgrad_store.delay_wgrad_compute():
for name, param in self.named_parameters():
if name in ["fc1_weight", "fc2_weight", "fc1_bias", "fc2_bias"]:
param.skip_backward_post_hook = True
# These many SMs are subtracted from the total SM count when calling forward # These many SMs are subtracted from the total SM count when calling forward
# and backward LayerNorm C APIs. These envvars can be used to prevent the LN # and backward LayerNorm C APIs. These envvars can be used to prevent the LN
...@@ -2152,3 +2156,5 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -2152,3 +2156,5 @@ class LayerNormMLP(TransformerEngineBaseModule):
del fc2_wgrad del fc2_wgrad
del fc1_wgrad del fc1_wgrad
del fc1_bias_grad del fc1_bias_grad
for wgrad_accumulation_and_reduce_hook in self.wgrad_accumulation_and_reduce_hooks:
wgrad_accumulation_and_reduce_hook()
...@@ -1270,6 +1270,11 @@ class Linear(TransformerEngineBaseModule): ...@@ -1270,6 +1270,11 @@ class Linear(TransformerEngineBaseModule):
else: else:
self.gemm_bias_unfused_add = False self.gemm_bias_unfused_add = False
if self.wgrad_store.delay_wgrad_compute():
for name, param in self.named_parameters():
if name in self.weight_names or name in self.bias_names:
param.skip_backward_post_hook = True
def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None:
"""Init scales and amaxes for fwd | bwd.""" """Init scales and amaxes for fwd | bwd."""
super().set_meta_tensor(fwd, recipe) super().set_meta_tensor(fwd, recipe)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment