Unverified Commit e0204fbb authored by Jan Bielak, committed by GitHub

Refactor `te.ops` (#1951)



* Refactor `_OperationFuserAutogradFunction.forward` to use fewer parameters
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
(cherry picked from commit f8f59b1bb184e89468058521df4cfff029ad909c)

* Rename `BackwardBiasActivation` to `BackwardActivationBias`
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
(cherry picked from commit 397c58fc296f801fe4ad600aadc2daff3b78be45)

* Use forward operation order in backward fused operations
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
(cherry picked from commit 2d37a9385069b066e6cdeff3eb9173c2079cb791)

* Rename `prev_op_grad_input_quantizer` to `prev_op_grad_output_quantizer`
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
(cherry picked from commit d7ab5dfb23e216866f7f4fc4d7a99f625d329f1e)

* Make OperationFuser persistent
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
(cherry picked from commit 77984d9715d31e87519dc6ea1e02c483a81355a7)

* Distribute extra inputs to and collect extra outputs from multiple module groups in Sequential (see the sketch below)
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
(cherry picked from commit 0716aaad542e59f2c1ac4620167965a0334bbf71)

* Take requires_grad into account when fusing operations
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Change get_quantizer to return None if no quantization recipe is used
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Refactor pre_first_forward
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Fix for failing `test_make_graphed_callables[fp8_recipe0-*-True-*-linear_op]`
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Fix linting errors
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Apply suggestions from code review
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Fix fp8 meta tensors in CUDA Graph capture
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

For more information, see https://pre-commit.ci

* Fix failing distributed userbuffers tests
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

---------
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent cb504cda
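The new extra-tensor handling in `te.ops.Sequential` is easiest to see in miniature. The Python sketch below is illustrative only and is not part of this diff; it assumes the `Bias`, `MakeExtraOutput`, and `AddInPlace` semantics exercised by the new `test_extra_tensors` test further down: extra inputs are passed positionally after the main input, and extra outputs are returned after the main output, even when a plain `torch.nn.Module` splits the container into several module groups.

```python
# Illustrative sketch only (not part of this commit): extra inputs/outputs
# flowing through a te.ops.Sequential, mirroring the new test_extra_tensors.
import torch
import transformer_engine.pytorch.ops as te_ops

size = 16
bias = te_ops.Bias(size=size, device="cpu")
with torch.no_grad():
    bias.bias.copy_(torch.rand((size,)))

model = te_ops.Sequential(          # | Inputs  | Outputs
    bias,                           # | x1      | h (= x1 + b)
    te_ops.MakeExtraOutput(),       # | h       | h [h]
    torch.nn.Identity(),            # | h       | h   (splits the module groups)
    te_ops.AddInPlace(),            # | h [x2]  | x2 (= x2 + h)
)

x1 = torch.rand((size,))
x2 = torch.rand((size,))
x2_orig = x2.clone()

# Extra inputs follow the main input; the main output comes first,
# followed by the collected extra outputs.
y, h = model(x1, x2)
assert y.data_ptr() == x2.data_ptr()        # in-place update, as in the test
torch.testing.assert_close(h, x1 + bias.bias)
torch.testing.assert_close(y, x2_orig + h)
```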
......@@ -20,7 +20,7 @@ import transformer_engine.pytorch as te
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.ops.fused import (
BackwardBiasActivation,
BackwardActivationBias,
BackwardLinearAdd,
ForwardLinearBiasActivation,
ForwardLinearBiasAdd,
......@@ -262,6 +262,65 @@ class TestSequentialContainer:
model(torch.zeros(1))
assert len(model._module_groups) == 6
def test_extra_tensors(self, size: int = 16) -> None:
"""Check that extra inputs are distributed properly between module groups
and that extra outputs are properly collected"""
# Construct sequential container
bias = te_ops.Bias(size=size, device="cpu")
with torch.no_grad():
bias.bias.copy_(torch.rand((size,)))
model = te_ops.Sequential( # | Inputs | Outputs
torch.nn.Identity(), # | x1 | x1
te_ops.MakeExtraOutput(), # | x1 | x1 [x1]
bias, # | x1 | h1 (= x1 + b)
te_ops.MakeExtraOutput(), # | h1 | h1 [h1]
te_ops.AddInPlace(), # | h1 [x2] | x2 (= x2 + h1)
te_ops.MakeExtraOutput(), # | x2 | x2 [x2]
torch.nn.Identity(), # | x2 | x2
bias, # | x2 | h2 (= x2 + b)
te_ops.AddInPlace(), # | h2 [x3] | x3 (= x3 + h2)
te_ops.MakeExtraOutput(), # | x3 | x3 [x3]
te_ops.AddInPlace(), # | x3 [x4] | x4 (= x4 + x3)
torch.nn.Identity(), # | x4 | x4
te_ops.Identity(), # | x4 | x4
te_ops.MakeExtraOutput(), # | x4 | x4 [x4]
te_ops.Identity(), # | x4 | x4
)
# Create input tensors
x1 = torch.rand((size,))
x2 = torch.rand((size,))
x3 = torch.rand((size,))
x4 = torch.rand((size,))
# Save original input tensor values
x1_orig = x1.clone()
x2_orig = x2.clone()
x3_orig = x3.clone()
x4_orig = x4.clone()
# Run forward
ys = model(x1, x2, x3, x4)
# Check whether outputs match (x4, x1, h1, x2, x3, x4)
assert len(ys) == 6
assert ys[0].data_ptr() == x4.data_ptr()
assert ys[1].data_ptr() == x1.data_ptr()
assert ys[2].data_ptr() not in [x.data_ptr() for x in (x1, x2, x3, x4)]
assert ys[3].data_ptr() == x2.data_ptr()
assert ys[4].data_ptr() == x3.data_ptr()
assert ys[5].data_ptr() == x4.data_ptr()
# Check whether tensors have correct values
b = bias.bias
h1 = ys[2]
torch.testing.assert_close(x1, x1_orig)
torch.testing.assert_close(h1, x1_orig + b)
torch.testing.assert_close(x2, x2_orig + h1)
torch.testing.assert_close(x3, x3_orig + x2 + b)
torch.testing.assert_close(x4, x4_orig + x3)
class TestFuser:
"""Tests for operation fusion infrastructure"""
......@@ -1870,7 +1929,7 @@ class TestFusedOps:
@pytest.mark.parametrize("out_shape", ((32, 32), (32, 1, 32), (8, 2, 2, 32)))
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("quantization", _quantization_list)
def test_backward_bias_activation(
def test_backward_activation_bias(
self,
*,
activation: str,
......@@ -1879,7 +1938,7 @@ class TestFusedOps:
device: torch.device = "cuda",
quantization: Optional[str],
) -> None:
"""Backward dbias + dact + quantize"""
"""Backward dact + dbias + quantize"""
# Tensor dimensions
in_shape = list(out_shape)
......@@ -1938,7 +1997,7 @@ class TestFusedOps:
backward_ops = model._module_groups[0]._backward_ops
if with_quantization and quantization in ["fp8_delayed_scaling", "mxfp8"]:
assert len(backward_ops) == 2
assert isinstance(backward_ops[0][0], BackwardBiasActivation)
assert isinstance(backward_ops[0][0], BackwardActivationBias)
assert isinstance(backward_ops[1][0], te_ops.Quantize)
else:
assert len(backward_ops) == 3
......@@ -2185,6 +2244,7 @@ class TestSequentialModules:
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
@pytest.mark.parametrize("requires_grad", (False, True))
@pytest.mark.parametrize("bias", (False, True))
@pytest.mark.parametrize("normalization", ("LayerNorm", "RMSNorm"))
@pytest.mark.parametrize("quantized_compute", (False, True))
......@@ -2194,6 +2254,7 @@ class TestSequentialModules:
def test_layernorm_mlp(
self,
*,
requires_grad: bool,
bias: bool,
normalization: str,
quantized_compute: bool,
......@@ -2234,6 +2295,7 @@ class TestSequentialModules:
quantization=quantization,
test_dtype=dtype,
test_device=device,
requires_grad=requires_grad,
)
_, dy_test = make_reference_and_test_tensors(
in_shape,
......
......@@ -191,12 +191,6 @@ class TestFP8Recipe:
amax_compute_algo=amax_compute_algo,
)
# Get FP8 meta tensors
with te.fp8_autocast(fp8_recipe=recipe):
x_fp8_meta = op.get_quantizer("forward", 0)
w_fp8_meta = op.get_quantizer("forward", 1)
dy_fp8_meta = op.get_quantizer("backward", 0)
# Perform training steps
x_history = []
w_history = []
......@@ -228,19 +222,30 @@ class TestFP8Recipe:
y = op(x)
y.backward(dy)
def check_amax_history(
fp8_meta: dict,
ref_amax_history: Iterable[float],
) -> None:
"""Check that amax history matches expected values"""
if len(ref_amax_history) > amax_history_len:
ref_amax_history = ref_amax_history[-amax_history_len:]
def check_metas(
test_scale: float,
test_amax_history: torch.Tensor,
ref_amax_history_list: list[float],
stage: str,
):
"""Check that meta tensors match expected values"""
# Compute amax
if len(ref_amax_history_list) > amax_history_len:
ref_amax_history_list = ref_amax_history_list[-(amax_history_len + 1) :]
ref_amax_history = torch.tensor(
ref_amax_history,
ref_amax_history_list,
dtype=torch.float32,
device=device,
)
test_amax_history = fp8_meta.amax_history[:, 0]
if amax_compute_algo == "max":
ref_amax = max(ref_amax_history_list)
elif amax_compute_algo == "most_recent":
ref_amax = ref_amax_history_list[-1]
else:
raise RuntimeError(f"{amax_compute_algo=} is not supported")
# Compare amax history
tols = dict(rtol=0, atol=0)
torch.testing.assert_close(
test_amax_history[-(step + 1) :],
......@@ -248,23 +253,6 @@ class TestFP8Recipe:
**tols,
)
def check_scale(
quantizer: Float8Quantizer,
ref_amax_history: Iterable[float],
stage: str,
):
"""Check that scale and scale reciprocal match expected values"""
# Compute amax
if len(ref_amax_history) > amax_history_len:
ref_amax_history = ref_amax_history[-(amax_history_len + 1) :]
if amax_compute_algo == "max":
ref_amax = max(ref_amax_history)
elif amax_compute_algo == "most_recent":
ref_amax = ref_amax_history[-1]
else:
raise RuntimeError(f"{amax_compute_algo=} is not supported")
# Compute scale
max_val = {
"forward": 448.0,
......@@ -272,16 +260,26 @@ class TestFP8Recipe:
}[stage]
ref_scale = (max_val / ref_amax) / (2**margin)
# Check values in FP8 meta tensors
# Compare scale
torch.testing.assert_close(
quantizer.scale.item(),
test_scale,
ref_scale,
)
# Get scaling factors
x_test_scale = op.get_quantizer("forward", 0).scale.item()
w_test_scale = op.get_quantizer("forward", 1).scale.item()
dy_test_scale = op.get_quantizer("backward", 0).scale.item()
# Get amax histories
x_test_history = op._fp8_metas["forward"][forward_key].amax_history[:, 0]
w_test_history = op._fp8_metas["forward"][forward_key].amax_history[:, 1]
dy_test_history = op._fp8_metas["backward"][backward_key].amax_history[:, 0]
# Check that results match expected values
check_scale(x_fp8_meta, x_history, "forward")
check_scale(w_fp8_meta, w_history, "forward")
check_scale(dy_fp8_meta, dy_history, "backward")
check_metas(x_test_scale, x_test_history, x_history, "forward")
check_metas(w_test_scale, w_test_history, w_history, "forward")
check_metas(dy_test_scale, dy_test_history, dy_history, "backward")
@pytest.mark.parametrize("amax_case", ["zero", "tiny", "normal", "inf", "nan"])
@pytest.mark.parametrize("fused_update", [True, False], ids=["fused", "non-fused"])
......
......@@ -21,6 +21,8 @@ from .fp8 import (
from .distributed import get_all_rng_states, graph_safe_rng_available
from .module.base import TransformerEngineBaseModule
from .ops.op import BasicOperation
from .ops import Sequential
from .ops.fuser import OperationFuser
from .utils import make_weak_ref
__all__ = ["make_graphed_callables"]
......@@ -44,7 +46,7 @@ def set_capture_end() -> None:
_IS_GRAPH_CAPTURING = False
def is_graph_capturing() -> None:
def is_graph_capturing() -> bool:
"""Return whether within `make_graphed_callables`."""
return _IS_GRAPH_CAPTURING
......@@ -338,6 +340,16 @@ def _make_graphed_callables(
def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument
if isinstance(module, TransformerEngineBaseModule):
visited_te_modules.add(module)
# If forward is called on a BasicOperation directly the hook will run
elif isinstance(module, BasicOperation):
visited_te_modules.add(module)
# If forward is called on a te.ops.Sequential it is not called on its constituent ops
elif isinstance(module, Sequential):
assert module._module_groups is not None, "Should have been initialized by warmup"
for module_group in module._module_groups:
if isinstance(module_group, OperationFuser):
for basic_op in module_group._basic_ops:
visited_te_modules.add(basic_op)
# Run warmup and do the above filtering.
with torch.cuda.stream(torch.cuda.Stream()):
......@@ -674,15 +686,13 @@ def _make_graphed_callables(
# run the graph, otherwise run the original forward method
if func.training == graph_training_state:
# Set the FP8 group from global amax reduction.
if FP8GlobalStateManager.is_fp8_enabled():
fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
for m in func.modules():
if (
isinstance(m, TransformerEngineBaseModule)
and FP8GlobalStateManager.is_fp8_enabled()
):
if m not in visited_te_modules:
# Only Set the FP8 meta for the modules included by forward
continue
fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
if isinstance(m, TransformerEngineBaseModule):
from transformer_engine.pytorch.attention.dot_product_attention import (
DotProductAttention,
)
......@@ -699,6 +709,12 @@ def _make_graphed_callables(
FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
m.fp8_meta,
)
elif isinstance(m, BasicOperation):
for mode in ("forward", "backward"):
if m.num_quantizers(mode):
FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
m._fp8_metas[mode],
)
return graphed(*user_args, **user_kwargs)
return orig_fwd(*user_args, **user_kwargs)
......@@ -721,7 +737,7 @@ def _make_graphed_callables(
def save_fp8_tensors(
modules: Iterable[torch.nn.Module],
fp8_recipe: Recipe,
fp8_recipe: Optional[Recipe],
) -> Optional[List[Any]]:
"""
Returns the FP8 tensors for all modules
......@@ -740,7 +756,7 @@ def save_fp8_tensors(
m.adjust_amax_history_length(fp8_recipe.amax_history_len)
module_tensors = m.get_fp8_meta_tensors()
elif isinstance(m, BasicOperation):
m.pre_first_forward(recipe=fp8_recipe)
m.reset_recipe_type(recipe=fp8_recipe)
module_tensors = m._save_fp8_metas()
fp8_tensors.append(module_tensors)
return fp8_tensors
......@@ -777,7 +793,7 @@ def make_graphed_callables(
sample_kwargs: Optional[SingleOrTuple[Dict[str, Any]]] = None,
fp8_enabled: bool = False,
fp8_calibrating: bool = False,
fp8_recipe: Optional[DelayedScaling] = None,
fp8_recipe: Optional[Recipe] = None,
fp8_group: Optional[dist_group_type] = None,
fp8_weight_caching: bool = False,
_order: Optional[List[int]] = None,
......@@ -828,7 +844,7 @@ def make_graphed_callables(
data of fp8 tensors even when executing without fp8 enabled. This is
useful for saving an inference ready fp8 checkpoint while training
using a higher precision.
fp8_recipe: recipe.DelayedScaling, default = `None`
fp8_recipe: Recipe, default = `None`
recipe used for FP8 training.
fp8_group: torch._C._distributed_c10d.ProcessGroup, default = `None`
distributed group over which amaxes for the fp8 tensors
......@@ -844,7 +860,10 @@ def make_graphed_callables(
"""
set_capture_start()
fp8_recipe = get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe
if fp8_enabled and fp8_recipe is None:
fp8_recipe = get_default_fp8_recipe()
elif not fp8_enabled:
fp8_recipe = None
# Handle single module.
just_one_callable = False
......
......@@ -11,7 +11,6 @@ from typing import Optional
import torch
import transformer_engine_torch as tex
from ...fp8 import FP8GlobalStateManager
from ...tensor.float8_tensor import Float8CurrentScalingQuantizer, Quantizer
from ...utils import clear_tensor_data
from ..op import BasicOperation, OperationContext
......@@ -71,7 +70,7 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op_grad_input_quantizer: Optional[Quantizer],
prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
) -> torch.Tensor:
......@@ -87,14 +86,8 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
# Check input tensor
x = maybe_dequantize(input_.contiguous(), dtype)
# Check if quantized compute is enabled
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
quantizer = None
if with_quantized_compute:
quantizer = next_op_input_quantizer
# Launch kernel
y = self._activation_forward_impl(x, quantizer)
y = self._activation_forward_impl(x, next_op_input_quantizer)
# Quantize input to FP8 before caching if needed
if self.cache_quantized_input:
......@@ -103,10 +96,10 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
x = input_quantizer(x)
# Save state for backward pass
if ctx.requires_grad:
ctx.save_for_backward(x)
ctx.with_quantized_compute = with_quantized_compute
ctx.dtype = dtype
ctx.prev_op_grad_input_quantizer = prev_op_grad_input_quantizer
ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer
return y
......@@ -125,13 +118,8 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
# Check grad output tensor
dy = maybe_dequantize(grad_output.contiguous(), x.dtype)
# Check if quantized compute is enabled
quantizer = None
if ctx.with_quantized_compute:
quantizer = ctx.prev_op_grad_input_quantizer
# Launch kernel
dx = self._activation_backward_impl(dy, x, quantizer)
dx = self._activation_backward_impl(dy, x, ctx.prev_op_grad_output_quantizer)
# Clear input tensor if possible
clear_tensor_data(x)
......
......@@ -59,7 +59,7 @@ class AddInPlace(BasicOperation):
input_: torch.Tensor,
*,
basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
prev_op_grad_input_quantizer: Optional[Quantizer],
prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
basic_op_kwargs: list[dict[str, Any]],
) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
......
......@@ -40,7 +40,7 @@ class AllGather(BasicOperation):
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op_grad_input_quantizer: Optional[Quantizer],
prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
) -> torch.Tensor:
out: torch.Tensor
......
......@@ -42,7 +42,7 @@ class AllReduce(BasicOperation):
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op_grad_input_quantizer: Optional[Quantizer],
prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
) -> torch.Tensor:
......
......@@ -22,7 +22,6 @@ from ...distributed import (
from ...fp8 import FP8GlobalStateManager, Recipe
from ...module.base import _2X_ACC_FPROP, _2X_ACC_DGRAD, _2X_ACC_WGRAD
from ...tensor import Quantizer
from ...tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
from ...tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ...tensor.mxfp8_tensor import MXFP8Quantizer
from ...tensor._internal.float8_tensor_base import Float8TensorBase
......@@ -291,6 +290,14 @@ class BasicLinear(BasicOperation):
# Quantize if needed
if self._with_quantized_weight:
quantizer = self.get_quantizer("forward", 1)
if quantizer is None:
raise RuntimeError(
"Tried to quantize weight with deferred initialization "
"due to meta device, but no quantizer was available. "
"This is most likely because fp8_model_init was called "
"with enabled=True and recipe=None, instead of providing "
"a recipe to use for quantization."
)
quantizer.set_usage(
rowwise=True,
columnwise=torch.is_grad_enabled(),
......@@ -303,62 +310,19 @@ class BasicLinear(BasicOperation):
weight = torch.nn.Parameter(weight)
self.weight = weight
def pre_first_forward(
self,
*,
recipe: Optional[Recipe],
) -> None:
super().pre_first_forward(recipe=recipe)
# Initialize weights if needed
weight = self.weight
if weight.device.type == "meta":
def pre_first_fuser_forward(self) -> None:
super().pre_first_fuser_forward()
if self.weight.device.type == "meta":
self.reset_parameters()
weight = self.weight
# Configure quantizers
if recipe is not None:
input_quantizer = self.get_quantizer("forward", 0)
weight_quantizer = self.get_quantizer("forward", 1)
grad_output_quantizer = self.get_quantizer("backward", 0)
def reset_recipe_type(self, *, recipe: Optional[Recipe]) -> None:
super().reset_recipe_type(recipe=recipe)
# Specify required tensor formats
input_quantizer.internal = True
weight_quantizer.internal = True
grad_output_quantizer.internal = True
# Recipe-specific configuration
if recipe.float8_current_scaling():
if any(
not isinstance(q, Float8CurrentScalingQuantizer)
for q in (input_quantizer, weight_quantizer, grad_output_quantizer)
):
raise RuntimeError(
"FP8 current-scaling recipe is enabled, "
f"but input quantizer is {input_quantizer.__class__.__name__}, "
f"weight quantizer is {weight_quantizer.__class__.__name__}, "
f"grad output quantizer is {grad_output_quantizer.__class__.__name__}"
)
input_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
input_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon
weight_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
weight_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon
grad_output_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
grad_output_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon
if self.sequence_parallel and self.tensor_parallel_mode == "column":
input_quantizer.with_amax_reduction = True
input_quantizer.amax_reduction_group = self.tensor_parallel_group
if self.sequence_parallel and self.tensor_parallel_mode == "row":
grad_output_quantizer.with_amax_reduction = True
grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group
# Make sure weight tensor has correct quantizer
# Note: Quantizer might have changed if quantization
# recipe changed
if isinstance(
weight_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
) and isinstance(weight, Float8TensorBase):
weight._quantizer = weight_quantizer
if recipe is not None and not FP8GlobalStateManager.with_fp8_parameters():
# Make quantizers use internal tensors
self.get_input_quantizer().internal = True
self.get_grad_output_quantizer().internal = True
self.get_quantizer("forward", 1).internal = True
@staticmethod
def _functional_forward(
......@@ -894,7 +858,7 @@ class BasicLinear(BasicOperation):
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op_grad_input_quantizer: Optional[Quantizer],
prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
) -> torch.Tensor:
......@@ -903,27 +867,34 @@ class BasicLinear(BasicOperation):
weight_requires_grad = ctx.requires_grad and self.weight.requires_grad
# FP8 metadata
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
input_quantizer = None
weight_quantizer = None
output_quantizer = None
grad_output_quantizer = None
grad_input_quantizer = None
if with_quantized_compute:
# Get quantizers
input_quantizer = self.get_quantizer("forward", 0)
weight_quantizer = self.get_quantizer("forward", 1)
output_quantizer = next_op_input_quantizer
grad_output_quantizer = self.get_quantizer("backward", 0)
grad_input_quantizer = prev_op_grad_input_quantizer
grad_input_quantizer = prev_op_grad_output_quantizer
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
if with_quantized_compute:
# Configure quantizers
# Note: We cache the quantized input for backward pass,
# but discard the quantized weights.
input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
weight_quantizer.set_usage(rowwise=True, columnwise=False)
recipe = FP8GlobalStateManager.get_fp8_recipe()
if recipe.float8_current_scaling():
input_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
input_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon
weight_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
weight_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon
grad_output_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
grad_output_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon
if self.sequence_parallel and self.tensor_parallel_mode == "column":
input_quantizer.with_amax_reduction = True
input_quantizer.amax_reduction_group = self.tensor_parallel_group
if self.sequence_parallel and self.tensor_parallel_mode == "row":
grad_output_quantizer.with_amax_reduction = True
grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group
# Get autocast dtype if needed
if torch.is_autocast_enabled():
dtype = torch.get_autocast_dtype("cuda")
......@@ -947,6 +918,7 @@ class BasicLinear(BasicOperation):
)
# Save state for backward pass
if ctx.requires_grad:
ctx.save_for_backward(x_local, w)
ctx.with_quantized_compute = with_quantized_compute
ctx.input_quantizer = input_quantizer
......
......@@ -18,7 +18,6 @@ from ...utils import (
canonicalize_device,
canonicalize_dtype,
)
from ...fp8 import FP8GlobalStateManager
from ...tensor import Quantizer
......@@ -114,8 +113,8 @@ class Bias(BasicOperation):
bias = torch.nn.Parameter(bias)
self.bias = bias
def pre_first_forward(self, *args, **kwargs) -> None:
super().pre_first_forward(*args, **kwargs)
def pre_first_fuser_forward(self) -> None:
super().pre_first_fuser_forward()
if self.bias.device.type == "meta":
self.reset_parameters()
......@@ -123,24 +122,14 @@ class Bias(BasicOperation):
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op_grad_input_quantizer: Optional[Quantizer],
prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
) -> torch.Tensor:
x = input_
b = self.bias.view([1] * (x.dim() - 1) + [self.local_size])
# Check if backward pass is needed
requires_grad = ctx.requires_grad
# Check if previous op quantizes its output's gradient
grad_input_quantizer = None
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
if with_quantized_compute:
grad_input_quantizer = prev_op_grad_input_quantizer
if requires_grad:
ctx.with_quantized_compute = with_quantized_compute
ctx.grad_input_quantizer = grad_input_quantizer
if ctx.requires_grad:
ctx.grad_input_quantizer = prev_op_grad_output_quantizer
return x + b
......@@ -152,7 +141,7 @@ class Bias(BasicOperation):
dy = grad_output
if dy.dim() > 1:
quantizer = ctx.grad_input_quantizer
if ctx.with_quantized_compute and quantizer is not None:
if quantizer is not None:
db, dy = tex.bgrad_quantize(dy, quantizer)
else:
db = dy.sum(tuple(range(dy.dim() - 1)))
......
......@@ -23,7 +23,7 @@ class Identity(BasicOperation):
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op_grad_input_quantizer: Optional[Quantizer],
prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
) -> torch.Tensor:
return input_
......
......@@ -74,7 +74,7 @@ class L2Normalization(BasicOperation):
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op_grad_input_quantizer: Optional[Quantizer],
prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
) -> torch.Tensor:
# Use input directly - torch.compile can handle multi-dimensional tensors
......
......@@ -13,7 +13,6 @@ from typing import Optional
import torch
from transformer_engine_torch import layernorm_bwd, layernorm_fwd
from ...fp8 import FP8GlobalStateManager
from ...constants import TE_DType
from ...utils import (
canonicalize_device,
......@@ -168,8 +167,8 @@ class LayerNorm(BasicOperation):
self.weight = weight
self.bias = bias
def pre_first_forward(self, *args, **kwargs) -> None:
super().pre_first_forward(*args, **kwargs)
def pre_first_fuser_forward(self) -> None:
super().pre_first_fuser_forward()
if self.weight.device.type == "meta" or self.bias.device.type == "meta":
self.reset_parameters()
......@@ -177,7 +176,7 @@ class LayerNorm(BasicOperation):
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op_grad_input_quantizer: Optional[Quantizer],
prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
) -> torch.Tensor:
if is_in_onnx_export_mode():
......@@ -200,31 +199,22 @@ class LayerNorm(BasicOperation):
w = maybe_dequantize(self.weight, dtype).view((inner_dim,))
b = maybe_dequantize(self.bias, dtype).view((inner_dim,))
# Check if backward pass is needed
requires_grad = ctx.requires_grad
# Check if output is quantized
output_quantizer = None
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
if with_quantized_compute:
output_quantizer = next_op_input_quantizer
# Compute layer norm
sm_margin = self._sm_margins["forward" if requires_grad else "inference"]
sm_margin = self._sm_margins["forward" if ctx.requires_grad else "inference"]
y, means, rstdevs = layernorm_fwd(
x,
w,
b,
self.eps,
None,
output_quantizer,
next_op_input_quantizer,
TE_DType[dtype],
sm_margin,
self.zero_centered_gamma,
)
# Save state for backward pass
if requires_grad:
if ctx.requires_grad:
ctx.save_for_backward(x, means, rstdevs)
ctx.dtype = dtype
......
......@@ -59,7 +59,7 @@ class MakeExtraOutput(BasicOperation):
input_: torch.Tensor,
*,
basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
prev_op_grad_input_quantizer: Optional[Quantizer],
prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
basic_op_kwargs: list[dict[str, Any]],
) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
......
......@@ -50,7 +50,7 @@ class Quantize(BasicOperation):
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op_grad_input_quantizer: Optional[Quantizer],
prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
) -> torch.Tensor:
......@@ -64,6 +64,7 @@ class Quantize(BasicOperation):
if quantize_forward and not is_quantized_tensor(out):
out = self.get_quantizer("forward", 0)(out)
if ctx.requires_grad:
ctx.quantize_backward = quantize_backward
return out
......
......@@ -40,7 +40,7 @@ class ReduceScatter(BasicOperation):
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op_grad_input_quantizer: Optional[Quantizer],
prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
) -> torch.Tensor:
......
......@@ -38,9 +38,10 @@ class Reshape(BasicOperation):
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op_grad_input_quantizer: Optional[Quantizer],
prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
) -> torch.Tensor:
if ctx.requires_grad:
ctx.input_shape = input_.size()
return input_.reshape(*self._shape)
......
......@@ -13,7 +13,6 @@ from typing import Optional
import torch
from transformer_engine_torch import rmsnorm_bwd, rmsnorm_fwd
from ...fp8 import FP8GlobalStateManager
from ...constants import TE_DType
from ...utils import (
canonicalize_device,
......@@ -151,8 +150,8 @@ class RMSNorm(BasicOperation):
weight = torch.nn.Parameter(weight)
self.weight = weight
def pre_first_forward(self, *args, **kwargs) -> None:
super().pre_first_forward(*args, **kwargs)
def pre_first_fuser_forward(self) -> None:
super().pre_first_fuser_forward()
if self.weight.device.type == "meta":
self.reset_parameters()
......@@ -160,7 +159,7 @@ class RMSNorm(BasicOperation):
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op_grad_input_quantizer: Optional[Quantizer],
prev_op_grad_output_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
) -> torch.Tensor:
if is_in_onnx_export_mode():
......@@ -182,30 +181,21 @@ class RMSNorm(BasicOperation):
x = maybe_dequantize(input_.contiguous(), dtype).view((-1, inner_dim))
w = maybe_dequantize(self.weight, dtype).view((inner_dim,))
# Check if backward pass is needed
requires_grad = ctx.requires_grad
# Check if output is quantized
output_quantizer = None
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
if with_quantized_compute:
output_quantizer = next_op_input_quantizer
# Compute RMSNorm
sm_margin = self._sm_margins["forward" if requires_grad else "inference"]
sm_margin = self._sm_margins["forward" if ctx.requires_grad else "inference"]
y, _, rstdevs = rmsnorm_fwd(
x,
w,
self.eps,
None,
output_quantizer,
next_op_input_quantizer,
TE_DType[dtype],
sm_margin,
self.zero_centered_gamma,
)
# Save state for backward pass
if requires_grad:
if ctx.requires_grad:
ctx.save_for_backward(x, rstdevs)
ctx.dtype = dtype
......
......@@ -4,9 +4,9 @@
"""Compound tensor operation supported by the operation fuser."""
from .backward_bias_activation import (
BackwardBiasActivation,
fuse_backward_bias_activation,
from .backward_activation_bias import (
BackwardActivationBias,
fuse_backward_activation_bias,
)
from .backward_linear_add import (
BackwardLinearAdd,
......
......@@ -2,7 +2,7 @@
#
# See LICENSE for license information.
"""Fused backward dbias + dact + quantize."""
"""Fused backward dact + dbias + quantize."""
from __future__ import annotations
from typing import Optional
......@@ -29,8 +29,8 @@ _fused_activations = {GELU: tex.dbias_dgelu, ReLU: tex.dbias_drelu}
_fusible_activations = tuple(_fused_activations.keys())
class BackwardBiasActivation(FusedOperation):
"""Fused backward dbias + dact + quantize
class BackwardActivationBias(FusedOperation):
"""Fused backward dact + dbias + quantize
Uses the next operation's input quantizer.
......@@ -66,15 +66,10 @@ class BackwardBiasActivation(FusedOperation):
dy = maybe_dequantize(grad_output.contiguous(), act_input.dtype)
# Get previous op quantizer
if not bias_op_ctx.with_quantized_compute:
raise RuntimeError(
"BackwardBiasActivation requires quantized compute, "
"but Bias context has it disabled"
)
quantizer = bias_op_ctx.grad_input_quantizer
if quantizer is None:
raise RuntimeError(
"BackwardBiasActivation requires previous op's grad output quantizer, "
"BackwardActivationBias requires previous op's grad output quantizer, "
"but Bias context has no quantizer"
)
......@@ -87,11 +82,11 @@ class BackwardBiasActivation(FusedOperation):
return dx, [(), (db,)], [(), ()]
def fuse_backward_bias_activation(
def fuse_backward_activation_bias(
ops: list[tuple[FusibleOperation, list[int]]],
recipe: Optional[Recipe],
) -> list[tuple[FusibleOperation, list[int]]]:
"""Fused backward dbias + dact + quantize
"""Fused backward dact + dbias + quantize
Parameters
----------
......@@ -138,7 +133,7 @@ def fuse_backward_bias_activation(
ops = ops[1:]
# Replace window with fused op
op = BackwardBiasActivation(
op = BackwardActivationBias(
activation=window[0][0],
bias=window[1][0],
)
......
......@@ -29,10 +29,10 @@ class BackwardLinearAdd(FusedOperation):
def __init__(
self,
*,
linear: BasicLinear,
backward_add: MakeExtraOutput,
linear: BasicLinear,
) -> None:
super().__init__((linear, backward_add))
super().__init__((backward_add, linear))
def fuser_backward(
self,
......@@ -47,7 +47,7 @@ class BackwardLinearAdd(FusedOperation):
]:
# Get basic operations
linear_op = self.basic_ops[0]
linear_op = self.basic_ops[1]
linear_op_ctx = basic_op_ctxs[0]
# Saved tensors from forward pass
......