Unverified commit 9166d4df authored by Jan Bielak, committed by GitHub

Call `pre_(first_)forward` only when global state changes (#1917)



* Change pre_forward to pre_first_forward
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Fix passing an invalid recipe when FP8 is disabled
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

---------
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent d26cc3a0
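
In brief: `pre_forward`, which previously re-read the global FP8 state on every call, becomes `pre_first_forward` and runs only once per fuser; `Sequential` then rebuilds its module groups only when the global state (FP8 enabled flag and recipe type) changes. A minimal sketch of this caching pattern, using hypothetical stand-in classes rather than the actual Transformer Engine types — only the control flow mirrors the diff:

```python
from typing import Optional


class Op:
    def pre_first_forward(self, *, recipe: Optional[str]) -> None:
        # recipe=None encodes "quantized compute disabled"
        print(f"one-time setup, recipe={recipe!r}")


class Fuser:
    def __init__(self, ops: list) -> None:
        self.ops = ops
        self._is_first_forward = True  # flag for one-time initialization

    def __call__(self, recipe: Optional[str]) -> None:
        # Per-op setup now runs once, not on every forward pass
        if self._is_first_forward:
            for op in self.ops:
                op.pre_first_forward(recipe=recipe)
            self._is_first_forward = False
        # ... run the (possibly fused) ops ...


class Container:
    """Stand-in for the Sequential container."""

    def __init__(self, ops: list) -> None:
        self.ops = ops
        self.fuser: Optional[Fuser] = None
        self._last_global_state: Optional[tuple] = None

    def __call__(self, fp8_enabled: bool, recipe: Optional[str]) -> None:
        # Rebuild the fuser (forcing pre_first_forward to run again)
        # only when the global state actually changes
        global_state = (fp8_enabled, type(recipe))
        if self._last_global_state != global_state:
            self.fuser = None
        self._last_global_state = global_state
        if self.fuser is None:
            self.fuser = Fuser(self.ops)
        self.fuser(recipe if fp8_enabled else None)


model = Container([Op(), Op()])
model(fp8_enabled=False, recipe=None)   # setup runs once
model(fp8_enabled=False, recipe=None)   # cached: no setup
model(fp8_enabled=True, recipe="fp8")   # state changed: setup runs again
```
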
@@ -593,7 +593,7 @@ def save_fp8_tensors(
                 m.adjust_amax_history_length(fp8_recipe.amax_history_len)
                 module_tensors = m.get_fp8_meta_tensors()
             elif isinstance(m, BasicOperation):
-                m.pre_forward(fp8_enabled=True, fp8_recipe=fp8_recipe)
+                m.pre_first_forward(recipe=fp8_recipe)
                 module_tensors = m._save_fp8_metas()
             fp8_tensors.append(module_tensors)
     return fp8_tensors
@@ -19,7 +19,7 @@ from ...distributed import (
     gather_along_first_dim,
     reduce_scatter_along_first_dim,
 )
-from ...fp8 import FP8GlobalStateManager
+from ...fp8 import FP8GlobalStateManager, Recipe
 from ...module.base import _2X_ACC_FPROP, _2X_ACC_DGRAD, _2X_ACC_WGRAD
 from ...tensor import Quantizer
 from ...tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
@@ -303,8 +303,12 @@ class BasicLinear(BasicOperation):
             weight = torch.nn.Parameter(weight)
         self.weight = weight

-    def pre_forward(self, *args, **kwargs) -> None:
-        super().pre_forward(*args, **kwargs)
+    def pre_first_forward(
+        self,
+        *,
+        recipe: Optional[Recipe],
+    ) -> None:
+        super().pre_first_forward(recipe=recipe)

         # Initialize weights if needed
         weight = self.weight
@@ -313,23 +317,17 @@ class BasicLinear(BasicOperation):
             weight = self.weight

         # Configure quantizers
-        if FP8GlobalStateManager.is_fp8_enabled():
+        if recipe is not None:
             input_quantizer = self.get_quantizer("forward", 0)
             weight_quantizer = self.get_quantizer("forward", 1)
             grad_output_quantizer = self.get_quantizer("backward", 0)

             # Specify required tensor formats
             is_grad_enabled = torch.is_grad_enabled()
             weight_requires_grad = is_grad_enabled and weight.requires_grad
             input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
             weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled)
             grad_output_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
             input_quantizer.internal = True
             weight_quantizer.internal = True
             grad_output_quantizer.internal = True

             # Recipe-specific configuration
-            recipe = FP8GlobalStateManager.get_fp8_recipe()
             if recipe.float8_current_scaling():
                 if any(
                     not isinstance(q, Float8CurrentScalingQuantizer)
......
@@ -112,8 +112,8 @@ class Bias(BasicOperation):
             bias = torch.nn.Parameter(bias)
         self.bias = bias

-    def pre_forward(self, *args, **kwargs) -> None:
-        super().pre_forward(*args, **kwargs)
+    def pre_first_forward(self, *args, **kwargs) -> None:
+        super().pre_first_forward(*args, **kwargs)
         if self.bias.device.type == "meta":
             self.reset_parameters()
@@ -167,8 +167,8 @@ class LayerNorm(BasicOperation):
         self.weight = weight
         self.bias = bias

-    def pre_forward(self, *args, **kwargs) -> None:
-        super().pre_forward(*args, **kwargs)
+    def pre_first_forward(self, *args, **kwargs) -> None:
+        super().pre_first_forward(*args, **kwargs)
         if self.weight.device.type == "meta" or self.bias.device.type == "meta":
             self.reset_parameters()
@@ -150,8 +150,8 @@ class RMSNorm(BasicOperation):
             weight = torch.nn.Parameter(weight)
         self.weight = weight

-    def pre_forward(self, *args, **kwargs) -> None:
-        super().pre_forward(*args, **kwargs)
+    def pre_first_forward(self, *args, **kwargs) -> None:
+        super().pre_first_forward(*args, **kwargs)
         if self.weight.device.type == "meta":
             self.reset_parameters()
@@ -344,6 +344,9 @@ class OperationFuser:
         self._forward_ops = [(op, (idx,)) for idx, op in enumerate(self._basic_ops)]
         self._backward_ops = list(reversed(self._forward_ops))

+        # Flag for checking if this is the first iteration
+        self._is_first_forward = True
+
         # Fuse ops if needed
         if fuse_ops:
             self.fuse_ops()
@@ -391,8 +394,12 @@ class OperationFuser:
             )

         # Initialization before forward pass
-        for op in self._basic_ops:
-            op.pre_forward()
+        if self._is_first_forward:
+            with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
+            recipe = FP8GlobalStateManager.get_fp8_recipe() if with_quantized_compute else None
+            for op in self._basic_ops:
+                op.pre_first_forward(recipe=recipe)
+            self._is_first_forward = False

         # Canonicalize op kwargs
         if basic_op_kwargs is None:
@@ -65,7 +65,11 @@ class FusibleOperation(torch.nn.Module, metaclass=abc.ABCMeta):
     def is_fused_op(self) -> bool:
         """Whether this op is the fusion of one or more basic ops"""

-    def pre_forward(self) -> None:
+    def pre_first_forward(
+        self,
+        *,
+        recipe: Optional[Recipe],
+    ) -> None:
         """Preprocessing before forward pass"""

     def get_input_quantizer(self) -> Optional[Quantizer]:
@@ -223,14 +227,10 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
     def _reset_quantization_recipe_state(
         self,
         *,
-        recipe: Optional[Recipe] = None,
+        recipe: Recipe,
     ) -> None:
         """Construct state for quantization recipe"""

-        # Quantization recipe
-        if recipe is None:
-            recipe = FP8GlobalStateManager.get_fp8_recipe()
-
         # Quantization recipe state for forward and backward pass
         self._fp8_metas = {"forward": None, "backward": None}
         self._quantizers = {"forward": [], "backward": []}
@@ -265,14 +265,10 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
     def _update_quantization_recipe_state(
         self,
         *,
-        recipe: Optional[Recipe] = None,
+        recipe: Recipe,
     ) -> None:
         """Make sure quantizer state matches quantization recipe"""

-        # Quantization recipe
-        if recipe is None:
-            recipe = FP8GlobalStateManager.get_fp8_recipe()
-
         # Reset quantization state if needed
         if self._fp8_metas is None or self._quantizers is None:
             self._reset_quantization_recipe_state(recipe=recipe)
@@ -346,7 +342,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
         """
         if self._quantizers is None:
-            self._reset_quantization_recipe_state()
+            self._reset_quantization_recipe_state(recipe=FP8GlobalStateManager.get_fp8_recipe())
         return self._quantizers[mode][index]

     @torch.no_grad()
@@ -397,19 +393,16 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
             self._fp8_metas[mode][fp8_meta_key].scale.copy_(scale)
             self._fp8_metas[mode][fp8_meta_key].amax_history.copy_(amax_history)

-    def pre_forward(
+    def pre_first_forward(
         self,
         *,
-        fp8_enabled: Optional[bool] = None,
-        fp8_recipe: Optional[Recipe] = None,
+        recipe: Optional[Recipe],
     ) -> None:
         """Preprocessing before forward pass"""

         # Initialize FP8 metadata if needed
-        if fp8_enabled is None:
-            fp8_enabled = FP8GlobalStateManager.is_fp8_enabled()
-        if fp8_enabled:
-            self._update_quantization_recipe_state(recipe=fp8_recipe)
+        if recipe is not None:
+            self._update_quantization_recipe_state(recipe=recipe)
         if not FP8GlobalStateManager.fp8_graph_capturing():
             if self.num_quantizers("forward"):
                 FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
@@ -647,7 +640,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
         # Get op's quantizer state, initializing if needed
         if self._fp8_metas is None or self._fp8_metas[mode] is None:
             with fp8_autocast(fp8_recipe=state[mode]["recipe"]):
-                self._reset_quantization_recipe_state()
+                self._reset_quantization_recipe_state(recipe=state[mode]["recipe"])
         fp8_meta = self._fp8_metas[mode]

         # Load extra items
@@ -728,10 +721,10 @@ class FusedOperation(FusibleOperation):
     def get_grad_input_quantizer(self) -> Optional[Quantizer]:
         return self.basic_ops[-1].get_grad_input_quantizer()

-    def pre_forward(self) -> None:
+    def pre_first_forward(self, *args, **kwargs) -> None:
         """Preprocessing before forward pass"""
         for op in self.basic_ops:
-            op.pre_forward()
+            op.pre_first_forward(*args, **kwargs)

     def forward(
         self,
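
The signature change in isolation: the optional `fp8_enabled`/`fp8_recipe` pair, which each op resolved against global state on every call, collapses into one required keyword-only argument, with `None` encoding "quantized compute disabled". A simplified before/after sketch (bodies elided, `Recipe` as a forward reference):

```python
from typing import Optional

# Before: optional flags, resolved against global state inside each op.
def pre_forward(
    self,
    *,
    fp8_enabled: Optional[bool] = None,
    fp8_recipe: Optional["Recipe"] = None,
) -> None: ...

# After: the fuser resolves global state once and passes the result down;
# recipe=None means quantized compute is disabled.
def pre_first_forward(self, *, recipe: Optional["Recipe"]) -> None: ...
```
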
@@ -10,6 +10,7 @@ from typing import Optional

 import torch

+from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
 from transformer_engine.pytorch.ops.op import FusibleOperation
 from transformer_engine.pytorch.ops.fuser import OperationFuser
@@ -37,6 +38,9 @@ class Sequential(torch.nn.Module):
         self._module_groups: Optional[list[OperationFuser | torch.nn.Module]]
         self._module_groups = None

+        # Global state of last iteration
+        self._last_global_state = None
+
         # Add modules
         if len(args) == 1 and isinstance(args[0], dict):
             for key, module in args[0].items():
@@ -185,6 +189,16 @@ class Sequential(torch.nn.Module):
     ) -> torch.Tensor | tuple[torch.Tensor, ...]:
         """Forward pass"""

+        # Get current global state
+        fp8_enabled = FP8GlobalStateManager.is_fp8_enabled()
+        fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8_enabled else None
+        global_state = (fp8_enabled, type(fp8_recipe))
+
+        # Reset module groups if global state changed
+        if self._last_global_state != global_state:
+            self._module_groups = None
+        self._last_global_state = global_state
+
         # Create module groups if needed
         if self._module_groups is None:
             self._module_groups = self._make_module_groups(self._modules.values())
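
A hypothetical usage sketch of the resulting behavior (assumes a CUDA device and the `te.ops` API touched by this diff; exact constructor arguments may differ):

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

model = te.ops.Sequential(te.ops.BasicLinear(256, 256), te.ops.Bias(256))
x = torch.randn(16, 256, device="cuda")

y = model(x)  # first call: module groups built, pre_first_forward(recipe=None)
y = model(x)  # same global state: cached groups, no re-initialization

with te.fp8_autocast(enabled=True, fp8_recipe=DelayedScaling()):
    y = model(x)  # state changed: groups reset, pre_first_forward with the recipe
```
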