Unverified commit e0204fbb, authored by Jan Bielak, committed by GitHub

Refactor `te.ops` (#1951)



* Refactor _OperationFuserAutogradFunction.forward to use fewer parameters
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
(cherry picked from commit f8f59b1bb184e89468058521df4cfff029ad909c)

* Rename `BackwardBiasActivation` to `BackwardActivationBias`
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
(cherry picked from commit 397c58fc296f801fe4ad600aadc2daff3b78be45)

* Use forward operation order in backward fused operations
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
(cherry picked from commit 2d37a9385069b066e6cdeff3eb9173c2079cb791)

* Rename `prev_op_grad_input_quantizer` to `prev_op_grad_output_quantizer`
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
(cherry picked from commit d7ab5dfb23e216866f7f4fc4d7a99f625d329f1e)

* Make OperationFuser persistent
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
(cherry picked from commit 77984d9715d31e87519dc6ea1e02c483a81355a7)

* Distribute extra inputs to and collect extra outputs from multiple module groups in Sequential (a usage sketch follows the commit metadata below)
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
(cherry picked from commit 0716aaad542e59f2c1ac4620167965a0334bbf71)

* Take requires_grad into account when fusing operations
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Change get_quantizer to return None if no quantization recipe is used
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Refactor pre_first_forward
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Fix for failing `test_make_graphed_callables[fp8_recipe0-*-True-*-linear_op]`
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Fix linting errors
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Apply suggestions from code review
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Fix fp8 meta tensors in CUDA Graph capture
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix failing distributed userbuffers tests
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

---------

Signed-off-by: Jan Bielak <jbielak@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent cb504cda
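A minimal sketch of the extra-input/extra-output handling referenced in the commit message and exercised by the new test_extra_tensors test in the diff below. The op classes (te_ops.Sequential, MakeExtraOutput, Bias, AddInPlace) come from this PR; the size and the particular three-op chain are illustrative only, and the output ordering (main output first, then extra outputs) follows the test's assertions.

import torch
import transformer_engine.pytorch.ops as te_ops

# Extra positional inputs are routed to the ops that consume them (AddInPlace),
# and extra outputs (MakeExtraOutput) are appended after the main output.
model = te_ops.Sequential(
    te_ops.MakeExtraOutput(),            # emits its input as an extra output
    te_ops.Bias(size=16, device="cpu"),  # adds a learnable bias
    te_ops.AddInPlace(),                 # consumes one extra input, adds in place
)

x = torch.rand(16)
extra = torch.rand(16)
out, x_copy = model(x, extra)  # main output first, then the extra outputs
# `out` aliases `extra` (updated in place); `x_copy` aliases the original `x`.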
@@ -20,7 +20,7 @@ import transformer_engine.pytorch as te
 from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
 import transformer_engine.pytorch.ops as te_ops
 from transformer_engine.pytorch.ops.fused import (
-    BackwardBiasActivation,
+    BackwardActivationBias,
     BackwardLinearAdd,
     ForwardLinearBiasActivation,
     ForwardLinearBiasAdd,
@@ -262,6 +262,65 @@ class TestSequentialContainer:
         model(torch.zeros(1))
         assert len(model._module_groups) == 6
 
+    def test_extra_tensors(self, size: int = 16) -> None:
+        """Check that extra inputs are distributed properly between module groups
+        and that extra outputs are properly collected"""
+
+        # Construct sequential container
+        bias = te_ops.Bias(size=size, device="cpu")
+        with torch.no_grad():
+            bias.bias.copy_(torch.rand((size,)))
+        model = te_ops.Sequential(     # | Inputs   | Outputs
+            torch.nn.Identity(),       # | x1       | x1
+            te_ops.MakeExtraOutput(),  # | x1       | x1 [x1]
+            bias,                      # | x1       | h1 (= x1 + b)
+            te_ops.MakeExtraOutput(),  # | h1       | h1 [h1]
+            te_ops.AddInPlace(),       # | h1 [x2]  | x2 (= x2 + h1)
+            te_ops.MakeExtraOutput(),  # | x2       | x2 [x2]
+            torch.nn.Identity(),       # | x2       | x2
+            bias,                      # | x2       | h2 (= x2 + b)
+            te_ops.AddInPlace(),       # | h2 [x3]  | x3 (= x3 + h2)
+            te_ops.MakeExtraOutput(),  # | x3       | x3 [x3]
+            te_ops.AddInPlace(),       # | x3 [x4]  | x4 (= x4 + x3)
+            torch.nn.Identity(),       # | x4       | x4
+            te_ops.Identity(),         # | x4       | x4
+            te_ops.MakeExtraOutput(),  # | x4       | x4 [x4]
+            te_ops.Identity(),         # | x4       | x4
+        )
+
+        # Create input tensors
+        x1 = torch.rand((size,))
+        x2 = torch.rand((size,))
+        x3 = torch.rand((size,))
+        x4 = torch.rand((size,))
+
+        # Save original input tensor values
+        x1_orig = x1.clone()
+        x2_orig = x2.clone()
+        x3_orig = x3.clone()
+        x4_orig = x4.clone()
+
+        # Run forward
+        ys = model(x1, x2, x3, x4)
+
+        # Check whether outputs match (x4, x1, h1, x2, x3, x4)
+        assert len(ys) == 6
+        assert ys[0].data_ptr() == x4.data_ptr()
+        assert ys[1].data_ptr() == x1.data_ptr()
+        assert ys[2].data_ptr() not in [x.data_ptr() for x in (x1, x2, x3, x4)]
+        assert ys[3].data_ptr() == x2.data_ptr()
+        assert ys[4].data_ptr() == x3.data_ptr()
+        assert ys[5].data_ptr() == x4.data_ptr()
+
+        # Check whether tensors have correct values
+        b = bias.bias
+        h1 = ys[2]
+        torch.testing.assert_close(x1, x1_orig)
+        torch.testing.assert_close(h1, x1_orig + b)
+        torch.testing.assert_close(x2, x2_orig + h1)
+        torch.testing.assert_close(x3, x3_orig + x2 + b)
+        torch.testing.assert_close(x4, x4_orig + x3)
+
 
 class TestFuser:
     """Tests for operation fusion infrastructure"""
@@ -1870,7 +1929,7 @@ class TestFusedOps:
     @pytest.mark.parametrize("out_shape", ((32, 32), (32, 1, 32), (8, 2, 2, 32)))
     @pytest.mark.parametrize("dtype", _dtypes)
     @pytest.mark.parametrize("quantization", _quantization_list)
-    def test_backward_bias_activation(
+    def test_backward_activation_bias(
         self,
         *,
         activation: str,
@@ -1879,7 +1938,7 @@ class TestFusedOps:
         device: torch.device = "cuda",
         quantization: Optional[str],
     ) -> None:
-        """Backward dbias + dact + quantize"""
+        """Backward dact + dbias + quantize"""
 
         # Tensor dimensions
         in_shape = list(out_shape)
@@ -1938,7 +1997,7 @@ class TestFusedOps:
         backward_ops = model._module_groups[0]._backward_ops
         if with_quantization and quantization in ["fp8_delayed_scaling", "mxfp8"]:
             assert len(backward_ops) == 2
-            assert isinstance(backward_ops[0][0], BackwardBiasActivation)
+            assert isinstance(backward_ops[0][0], BackwardActivationBias)
             assert isinstance(backward_ops[1][0], te_ops.Quantize)
         else:
             assert len(backward_ops) == 3
@@ -2185,6 +2244,7 @@ class TestSequentialModules:
         torch.manual_seed(seed)
         torch.cuda.manual_seed(seed)
 
+    @pytest.mark.parametrize("requires_grad", (False, True))
     @pytest.mark.parametrize("bias", (False, True))
     @pytest.mark.parametrize("normalization", ("LayerNorm", "RMSNorm"))
     @pytest.mark.parametrize("quantized_compute", (False, True))
@@ -2194,6 +2254,7 @@ class TestSequentialModules:
     def test_layernorm_mlp(
         self,
         *,
+        requires_grad: bool,
         bias: bool,
         normalization: str,
         quantized_compute: bool,
@@ -2234,6 +2295,7 @@ class TestSequentialModules:
             quantization=quantization,
             test_dtype=dtype,
             test_device=device,
+            requires_grad=requires_grad,
         )
         _, dy_test = make_reference_and_test_tensors(
             in_shape,
......
@@ -191,12 +191,6 @@ class TestFP8Recipe:
             amax_compute_algo=amax_compute_algo,
         )
 
-        # Get FP8 meta tensors
-        with te.fp8_autocast(fp8_recipe=recipe):
-            x_fp8_meta = op.get_quantizer("forward", 0)
-            w_fp8_meta = op.get_quantizer("forward", 1)
-            dy_fp8_meta = op.get_quantizer("backward", 0)
-
         # Perform training steps
         x_history = []
         w_history = []
@@ -228,19 +222,30 @@ class TestFP8Recipe:
                 y = op(x)
             y.backward(dy)
 
-            def check_amax_history(
-                fp8_meta: dict,
-                ref_amax_history: Iterable[float],
-            ) -> None:
-                """Check that amax history matches expected values"""
-                if len(ref_amax_history) > amax_history_len:
-                    ref_amax_history = ref_amax_history[-amax_history_len:]
+            def check_metas(
+                test_scale: float,
+                test_amax_history: torch.Tensor,
+                ref_amax_history_list: list[float],
+                stage: str,
+            ):
+                """Check that meta tensors match expected values"""
+
+                # Compute amax
+                if len(ref_amax_history_list) > amax_history_len:
+                    ref_amax_history_list = ref_amax_history_list[-(amax_history_len + 1) :]
                 ref_amax_history = torch.tensor(
-                    ref_amax_history,
+                    ref_amax_history_list,
                     dtype=torch.float32,
                     device=device,
                 )
-                test_amax_history = fp8_meta.amax_history[:, 0]
+                if amax_compute_algo == "max":
+                    ref_amax = max(ref_amax_history_list)
+                elif amax_compute_algo == "most_recent":
+                    ref_amax = ref_amax_history_list[-1]
+                else:
+                    raise RuntimeError(f"{amax_compute_algo=} is not supported")
+
+                # Compare amax history
                 tols = dict(rtol=0, atol=0)
                 torch.testing.assert_close(
                     test_amax_history[-(step + 1) :],
@@ -248,23 +253,6 @@ class TestFP8Recipe:
                     **tols,
                 )
 
-            def check_scale(
-                quantizer: Float8Quantizer,
-                ref_amax_history: Iterable[float],
-                stage: str,
-            ):
-                """Check that scale and scale reciprocal match expected values"""
-
-                # Compute amax
-                if len(ref_amax_history) > amax_history_len:
-                    ref_amax_history = ref_amax_history[-(amax_history_len + 1) :]
-                if amax_compute_algo == "max":
-                    ref_amax = max(ref_amax_history)
-                elif amax_compute_algo == "most_recent":
-                    ref_amax = ref_amax_history[-1]
-                else:
-                    raise RuntimeError(f"{amax_compute_algo=} is not supported")
-
                 # Compute scale
                 max_val = {
                     "forward": 448.0,
@@ -272,16 +260,26 @@ class TestFP8Recipe:
                 }[stage]
                 ref_scale = (max_val / ref_amax) / (2**margin)
 
-                # Check values in FP8 meta tensors
+                # Compare scale
                 torch.testing.assert_close(
-                    quantizer.scale.item(),
+                    test_scale,
                     ref_scale,
                 )
 
+            # Get scaling factors
+            x_test_scale = op.get_quantizer("forward", 0).scale.item()
+            w_test_scale = op.get_quantizer("forward", 1).scale.item()
+            dy_test_scale = op.get_quantizer("backward", 0).scale.item()
+
+            # Get amax histories
+            x_test_history = op._fp8_metas["forward"][forward_key].amax_history[:, 0]
+            w_test_history = op._fp8_metas["forward"][forward_key].amax_history[:, 1]
+            dy_test_history = op._fp8_metas["backward"][backward_key].amax_history[:, 0]
+
             # Check that results match expected values
-            check_scale(x_fp8_meta, x_history, "forward")
-            check_scale(w_fp8_meta, w_history, "forward")
-            check_scale(dy_fp8_meta, dy_history, "backward")
+            check_metas(x_test_scale, x_test_history, x_history, "forward")
+            check_metas(w_test_scale, w_test_history, w_history, "forward")
+            check_metas(dy_test_scale, dy_test_history, dy_history, "backward")
 
     @pytest.mark.parametrize("amax_case", ["zero", "tiny", "normal", "inf", "nan"])
     @pytest.mark.parametrize("fused_update", [True, False], ids=["fused", "non-fused"])
......
@@ -21,6 +21,8 @@ from .fp8 import (
 from .distributed import get_all_rng_states, graph_safe_rng_available
 from .module.base import TransformerEngineBaseModule
 from .ops.op import BasicOperation
+from .ops import Sequential
+from .ops.fuser import OperationFuser
 from .utils import make_weak_ref
 
 __all__ = ["make_graphed_callables"]
@@ -44,7 +46,7 @@ def set_capture_end() -> None:
     _IS_GRAPH_CAPTURING = False
 
 
-def is_graph_capturing() -> None:
+def is_graph_capturing() -> bool:
    """Return whether within `make_graphed_callables`."""
    return _IS_GRAPH_CAPTURING
@@ -338,6 +340,16 @@ def _make_graphed_callables(
     def hook_fn(module, inputs, outputs):  # pylint: disable=unused-argument
         if isinstance(module, TransformerEngineBaseModule):
             visited_te_modules.add(module)
+        # If forward is called on a BasicOperation directly the hook will run
+        elif isinstance(module, BasicOperation):
+            visited_te_modules.add(module)
+        # If forward is called on a te.ops.Sequential it is not called on its constituent ops
+        elif isinstance(module, Sequential):
+            assert module._module_groups is not None, "Should have been initialized by warmup"
+            for module_group in module._module_groups:
+                if isinstance(module_group, OperationFuser):
+                    for basic_op in module_group._basic_ops:
+                        visited_te_modules.add(basic_op)
 
     # Run warmup and do the above filtering.
     with torch.cuda.stream(torch.cuda.Stream()):
@@ -674,15 +686,13 @@ def _make_graphed_callables(
             # run the graph, otherwise run the original forward method
             if func.training == graph_training_state:
                 # Set the FP8 group from global amax reduction.
-                for m in func.modules():
-                    if (
-                        isinstance(m, TransformerEngineBaseModule)
-                        and FP8GlobalStateManager.is_fp8_enabled()
-                    ):
+                if FP8GlobalStateManager.is_fp8_enabled():
+                    fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
+                    for m in func.modules():
                         if m not in visited_te_modules:
                             # Only Set the FP8 meta for the modules included by forward
                             continue
-                        fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
+                        if isinstance(m, TransformerEngineBaseModule):
                             from transformer_engine.pytorch.attention.dot_product_attention import (
                                 DotProductAttention,
                             )
@@ -699,6 +709,12 @@ def _make_graphed_callables(
                             FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
                                 m.fp8_meta,
                             )
+                        elif isinstance(m, BasicOperation):
+                            for mode in ("forward", "backward"):
+                                if m.num_quantizers(mode):
+                                    FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
+                                        m._fp8_metas[mode],
+                                    )
 
             return graphed(*user_args, **user_kwargs)
         return orig_fwd(*user_args, **user_kwargs)
@@ -721,7 +737,7 @@ def _make_graphed_callables(
 
 def save_fp8_tensors(
     modules: Iterable[torch.nn.Module],
-    fp8_recipe: Recipe,
+    fp8_recipe: Optional[Recipe],
 ) -> Optional[List[Any]]:
     """
     Returns the FP8 tensors for all modules
@@ -740,7 +756,7 @@ def save_fp8_tensors(
             m.adjust_amax_history_length(fp8_recipe.amax_history_len)
             module_tensors = m.get_fp8_meta_tensors()
         elif isinstance(m, BasicOperation):
-            m.pre_first_forward(recipe=fp8_recipe)
+            m.reset_recipe_type(recipe=fp8_recipe)
             module_tensors = m._save_fp8_metas()
         fp8_tensors.append(module_tensors)
     return fp8_tensors
@@ -777,7 +793,7 @@ def make_graphed_callables(
     sample_kwargs: Optional[SingleOrTuple[Dict[str, Any]]] = None,
     fp8_enabled: bool = False,
     fp8_calibrating: bool = False,
-    fp8_recipe: Optional[DelayedScaling] = None,
+    fp8_recipe: Optional[Recipe] = None,
     fp8_group: Optional[dist_group_type] = None,
     fp8_weight_caching: bool = False,
     _order: Optional[List[int]] = None,
@@ -828,7 +844,7 @@ def make_graphed_callables(
                      data of fp8 tensors even when executing without fp8 enabled. This is
                      useful for saving an inference ready fp8 checkpoint while training
                      using a higher precision.
-    fp8_recipe: recipe.DelayedScaling, default = `None`
+    fp8_recipe: Recipe, default = `None`
                recipe used for FP8 training.
    fp8_group: torch._C._distributed_c10d.ProcessGroup, default = `None`
               distributed group over which amaxes for the fp8 tensors
@@ -844,7 +860,10 @@ def make_graphed_callables(
     """
     set_capture_start()
 
-    fp8_recipe = get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe
+    if fp8_enabled and fp8_recipe is None:
+        fp8_recipe = get_default_fp8_recipe()
+    elif not fp8_enabled:
+        fp8_recipe = None
 
     # Handle single module.
     just_one_callable = False
......
@@ -11,7 +11,6 @@ from typing import Optional
 import torch
 import transformer_engine_torch as tex
 
-from ...fp8 import FP8GlobalStateManager
 from ...tensor.float8_tensor import Float8CurrentScalingQuantizer, Quantizer
 from ...utils import clear_tensor_data
 from ..op import BasicOperation, OperationContext
@@ -71,7 +70,7 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
         self,
         ctx: OperationContext,
         input_: torch.Tensor,
-        prev_op_grad_input_quantizer: Optional[Quantizer],
+        prev_op_grad_output_quantizer: Optional[Quantizer],
         next_op_input_quantizer: Optional[Quantizer],
     ) -> torch.Tensor:
@@ -87,14 +86,8 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
         # Check input tensor
         x = maybe_dequantize(input_.contiguous(), dtype)
 
-        # Check if quantized compute is enabled
-        with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
-        quantizer = None
-        if with_quantized_compute:
-            quantizer = next_op_input_quantizer
-
         # Launch kernel
-        y = self._activation_forward_impl(x, quantizer)
+        y = self._activation_forward_impl(x, next_op_input_quantizer)
 
         # Quantize input to FP8 before caching if needed
         if self.cache_quantized_input:
@@ -103,10 +96,10 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
             x = input_quantizer(x)
 
         # Save state for backward pass
-        ctx.save_for_backward(x)
-        ctx.with_quantized_compute = with_quantized_compute
-        ctx.dtype = dtype
-        ctx.prev_op_grad_input_quantizer = prev_op_grad_input_quantizer
+        if ctx.requires_grad:
+            ctx.save_for_backward(x)
+            ctx.dtype = dtype
+            ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer
 
         return y
@@ -125,13 +118,8 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
         # Check grad output tensor
         dy = maybe_dequantize(grad_output.contiguous(), x.dtype)
 
-        # Check if quantized compute is enabled
-        quantizer = None
-        if ctx.with_quantized_compute:
-            quantizer = ctx.prev_op_grad_input_quantizer
-
         # Launch kernel
-        dx = self._activation_backward_impl(dy, x, quantizer)
+        dx = self._activation_backward_impl(dy, x, ctx.prev_op_grad_output_quantizer)
 
         # Clear input tensor if possible
         clear_tensor_data(x)
......
@@ -59,7 +59,7 @@ class AddInPlace(BasicOperation):
         input_: torch.Tensor,
         *,
         basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
-        prev_op_grad_input_quantizer: Optional[Quantizer],
+        prev_op_grad_output_quantizer: Optional[Quantizer],
         next_op_input_quantizer: Optional[Quantizer],
         basic_op_kwargs: list[dict[str, Any]],
     ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
......
@@ -40,7 +40,7 @@ class AllGather(BasicOperation):
         self,
         ctx: OperationContext,
         input_: torch.Tensor,
-        prev_op_grad_input_quantizer: Optional[Quantizer],
+        prev_op_grad_output_quantizer: Optional[Quantizer],
         next_op_input_quantizer: Optional[Quantizer],
     ) -> torch.Tensor:
         out: torch.Tensor
......
@@ -42,7 +42,7 @@ class AllReduce(BasicOperation):
         self,
         ctx: OperationContext,
         input_: torch.Tensor,
-        prev_op_grad_input_quantizer: Optional[Quantizer],
+        prev_op_grad_output_quantizer: Optional[Quantizer],
         next_op_input_quantizer: Optional[Quantizer],
     ) -> torch.Tensor:
......
@@ -22,7 +22,6 @@ from ...distributed import (
 from ...fp8 import FP8GlobalStateManager, Recipe
 from ...module.base import _2X_ACC_FPROP, _2X_ACC_DGRAD, _2X_ACC_WGRAD
 from ...tensor import Quantizer
-from ...tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
 from ...tensor.float8_blockwise_tensor import Float8BlockQuantizer
 from ...tensor.mxfp8_tensor import MXFP8Quantizer
 from ...tensor._internal.float8_tensor_base import Float8TensorBase
@@ -291,6 +290,14 @@ class BasicLinear(BasicOperation):
         # Quantize if needed
         if self._with_quantized_weight:
             quantizer = self.get_quantizer("forward", 1)
+            if quantizer is None:
+                raise RuntimeError(
+                    "Tried to quantize weight with deferred initialization "
+                    "due to meta device, but no quantizer was available. "
+                    "This is most likely because fp8_model_init was called "
+                    "with enabled=True and recipe=None, instead of providing "
+                    "a recipe to use for quantization."
+                )
             quantizer.set_usage(
                 rowwise=True,
                 columnwise=torch.is_grad_enabled(),
@@ -303,62 +310,19 @@ class BasicLinear(BasicOperation):
             weight = torch.nn.Parameter(weight)
         self.weight = weight
 
-    def pre_first_forward(
-        self,
-        *,
-        recipe: Optional[Recipe],
-    ) -> None:
-        super().pre_first_forward(recipe=recipe)
-
-        # Initialize weights if needed
-        weight = self.weight
-        if weight.device.type == "meta":
+    def pre_first_fuser_forward(self) -> None:
+        super().pre_first_fuser_forward()
+        if self.weight.device.type == "meta":
             self.reset_parameters()
-            weight = self.weight
-
-        # Configure quantizers
-        if recipe is not None:
-            input_quantizer = self.get_quantizer("forward", 0)
-            weight_quantizer = self.get_quantizer("forward", 1)
-            grad_output_quantizer = self.get_quantizer("backward", 0)
-
-            # Specify required tensor formats
-            input_quantizer.internal = True
-            weight_quantizer.internal = True
-            grad_output_quantizer.internal = True
-
-            # Recipe-specific configuration
-            if recipe.float8_current_scaling():
-                if any(
-                    not isinstance(q, Float8CurrentScalingQuantizer)
-                    for q in (input_quantizer, weight_quantizer, grad_output_quantizer)
-                ):
-                    raise RuntimeError(
-                        "FP8 current-scaling recipe is enabled, "
-                        f"but input quantizer is {input_quantizer.__class__.__name__}, "
-                        f"weight quantizer is {weight_quantizer.__class__.__name__}, "
-                        f"grad output quantizer is {grad_output_quantizer.__class__.__name__}"
-                    )
-                input_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
-                input_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon
-                weight_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
-                weight_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon
-                grad_output_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
-                grad_output_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon
-                if self.sequence_parallel and self.tensor_parallel_mode == "column":
-                    input_quantizer.with_amax_reduction = True
-                    input_quantizer.amax_reduction_group = self.tensor_parallel_group
-                if self.sequence_parallel and self.tensor_parallel_mode == "row":
-                    grad_output_quantizer.with_amax_reduction = True
-                    grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group
-
-            # Make sure weight tensor has correct quantizer
-            # Note: Quantizer might have changed if quantization
-            # recipe changed
-            if isinstance(
-                weight_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
-            ) and isinstance(weight, Float8TensorBase):
-                weight._quantizer = weight_quantizer
+
+    def reset_recipe_type(self, *, recipe: Optional[Recipe]) -> None:
+        super().reset_recipe_type(recipe=recipe)
+
+        if recipe is not None and not FP8GlobalStateManager.with_fp8_parameters():
+            # Make quantizers use internal tensors
+            self.get_input_quantizer().internal = True
+            self.get_grad_output_quantizer().internal = True
+            self.get_quantizer("forward", 1).internal = True
 
     @staticmethod
     def _functional_forward(
@@ -894,7 +858,7 @@ class BasicLinear(BasicOperation):
         self,
         ctx: OperationContext,
         input_: torch.Tensor,
-        prev_op_grad_input_quantizer: Optional[Quantizer],
+        prev_op_grad_output_quantizer: Optional[Quantizer],
         next_op_input_quantizer: Optional[Quantizer],
     ) -> torch.Tensor:
@@ -903,27 +867,34 @@ class BasicLinear(BasicOperation):
         weight_requires_grad = ctx.requires_grad and self.weight.requires_grad
 
         # FP8 metadata
-        with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
-        input_quantizer = None
-        weight_quantizer = None
-        output_quantizer = None
-        grad_output_quantizer = None
-        grad_input_quantizer = None
-        if with_quantized_compute:
-            # Get quantizers
-            input_quantizer = self.get_quantizer("forward", 0)
-            weight_quantizer = self.get_quantizer("forward", 1)
-            output_quantizer = next_op_input_quantizer
-            grad_output_quantizer = self.get_quantizer("backward", 0)
-            grad_input_quantizer = prev_op_grad_input_quantizer
-
+        input_quantizer = self.get_quantizer("forward", 0)
+        weight_quantizer = self.get_quantizer("forward", 1)
+        output_quantizer = next_op_input_quantizer
+        grad_output_quantizer = self.get_quantizer("backward", 0)
+        grad_input_quantizer = prev_op_grad_output_quantizer
+        with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
+        if with_quantized_compute:
             # Configure quantizers
             # Note: We cache the quantized input for backward pass,
             # but discard the quantized weights.
             input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
             weight_quantizer.set_usage(rowwise=True, columnwise=False)
+            recipe = FP8GlobalStateManager.get_fp8_recipe()
+            if recipe.float8_current_scaling():
+                input_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
+                input_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon
+                weight_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
+                weight_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon
+                grad_output_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
+                grad_output_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon
+                if self.sequence_parallel and self.tensor_parallel_mode == "column":
+                    input_quantizer.with_amax_reduction = True
+                    input_quantizer.amax_reduction_group = self.tensor_parallel_group
+                if self.sequence_parallel and self.tensor_parallel_mode == "row":
+                    grad_output_quantizer.with_amax_reduction = True
+                    grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group
 
         # Get autocast dtype if needed
         if torch.is_autocast_enabled():
             dtype = torch.get_autocast_dtype("cuda")
@@ -947,6 +918,7 @@ class BasicLinear(BasicOperation):
         )
 
         # Save state for backward pass
+        if ctx.requires_grad:
             ctx.save_for_backward(x_local, w)
             ctx.with_quantized_compute = with_quantized_compute
             ctx.input_quantizer = input_quantizer
......
@@ -18,7 +18,6 @@ from ...utils import (
     canonicalize_device,
     canonicalize_dtype,
 )
-from ...fp8 import FP8GlobalStateManager
 from ...tensor import Quantizer
@@ -114,8 +113,8 @@ class Bias(BasicOperation):
             bias = torch.nn.Parameter(bias)
         self.bias = bias
 
-    def pre_first_forward(self, *args, **kwargs) -> None:
-        super().pre_first_forward(*args, **kwargs)
+    def pre_first_fuser_forward(self) -> None:
+        super().pre_first_fuser_forward()
         if self.bias.device.type == "meta":
             self.reset_parameters()
@@ -123,24 +122,14 @@ class Bias(BasicOperation):
         self,
         ctx: OperationContext,
         input_: torch.Tensor,
-        prev_op_grad_input_quantizer: Optional[Quantizer],
+        prev_op_grad_output_quantizer: Optional[Quantizer],
         next_op_input_quantizer: Optional[Quantizer],
     ) -> torch.Tensor:
         x = input_
         b = self.bias.view([1] * (x.dim() - 1) + [self.local_size])
 
-        # Check if backward pass is needed
-        requires_grad = ctx.requires_grad
-
-        # Check if previous op quantizes its output's gradient
-        grad_input_quantizer = None
-        with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
-        if with_quantized_compute:
-            grad_input_quantizer = prev_op_grad_input_quantizer
-
-        if requires_grad:
-            ctx.with_quantized_compute = with_quantized_compute
-            ctx.grad_input_quantizer = grad_input_quantizer
+        if ctx.requires_grad:
+            ctx.grad_input_quantizer = prev_op_grad_output_quantizer
 
         return x + b
@@ -152,7 +141,7 @@ class Bias(BasicOperation):
         dy = grad_output
         if dy.dim() > 1:
             quantizer = ctx.grad_input_quantizer
-            if ctx.with_quantized_compute and quantizer is not None:
+            if quantizer is not None:
                 db, dy = tex.bgrad_quantize(dy, quantizer)
             else:
                 db = dy.sum(tuple(range(dy.dim() - 1)))
......
@@ -23,7 +23,7 @@ class Identity(BasicOperation):
         self,
         ctx: OperationContext,
         input_: torch.Tensor,
-        prev_op_grad_input_quantizer: Optional[Quantizer],
+        prev_op_grad_output_quantizer: Optional[Quantizer],
         next_op_input_quantizer: Optional[Quantizer],
     ) -> torch.Tensor:
         return input_
......
@@ -74,7 +74,7 @@ class L2Normalization(BasicOperation):
         self,
         ctx: OperationContext,
         input_: torch.Tensor,
-        prev_op_grad_input_quantizer: Optional[Quantizer],
+        prev_op_grad_output_quantizer: Optional[Quantizer],
         next_op_input_quantizer: Optional[Quantizer],
     ) -> torch.Tensor:
         # Use input directly - torch.compile can handle multi-dimensional tensors
......
@@ -13,7 +13,6 @@ from typing import Optional
 import torch
 from transformer_engine_torch import layernorm_bwd, layernorm_fwd
 
-from ...fp8 import FP8GlobalStateManager
 from ...constants import TE_DType
 from ...utils import (
     canonicalize_device,
@@ -168,8 +167,8 @@ class LayerNorm(BasicOperation):
         self.weight = weight
         self.bias = bias
 
-    def pre_first_forward(self, *args, **kwargs) -> None:
-        super().pre_first_forward(*args, **kwargs)
+    def pre_first_fuser_forward(self) -> None:
+        super().pre_first_fuser_forward()
         if self.weight.device.type == "meta" or self.bias.device.type == "meta":
             self.reset_parameters()
@@ -177,7 +176,7 @@ class LayerNorm(BasicOperation):
         self,
         ctx: OperationContext,
         input_: torch.Tensor,
-        prev_op_grad_input_quantizer: Optional[Quantizer],
+        prev_op_grad_output_quantizer: Optional[Quantizer],
         next_op_input_quantizer: Optional[Quantizer],
     ) -> torch.Tensor:
         if is_in_onnx_export_mode():
@@ -200,31 +199,22 @@ class LayerNorm(BasicOperation):
         w = maybe_dequantize(self.weight, dtype).view((inner_dim,))
         b = maybe_dequantize(self.bias, dtype).view((inner_dim,))
 
-        # Check if backward pass is needed
-        requires_grad = ctx.requires_grad
-
-        # Check if output is quantized
-        output_quantizer = None
-        with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
-        if with_quantized_compute:
-            output_quantizer = next_op_input_quantizer
-
         # Compute layer norm
-        sm_margin = self._sm_margins["forward" if requires_grad else "inference"]
+        sm_margin = self._sm_margins["forward" if ctx.requires_grad else "inference"]
         y, means, rstdevs = layernorm_fwd(
             x,
             w,
             b,
             self.eps,
             None,
-            output_quantizer,
+            next_op_input_quantizer,
             TE_DType[dtype],
             sm_margin,
             self.zero_centered_gamma,
         )
 
         # Save state for backward pass
-        if requires_grad:
+        if ctx.requires_grad:
             ctx.save_for_backward(x, means, rstdevs)
             ctx.dtype = dtype
......
@@ -59,7 +59,7 @@ class MakeExtraOutput(BasicOperation):
         input_: torch.Tensor,
         *,
         basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
-        prev_op_grad_input_quantizer: Optional[Quantizer],
+        prev_op_grad_output_quantizer: Optional[Quantizer],
         next_op_input_quantizer: Optional[Quantizer],
         basic_op_kwargs: list[dict[str, Any]],
     ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
......
@@ -50,7 +50,7 @@ class Quantize(BasicOperation):
         self,
         ctx: OperationContext,
         input_: torch.Tensor,
-        prev_op_grad_input_quantizer: Optional[Quantizer],
+        prev_op_grad_output_quantizer: Optional[Quantizer],
         next_op_input_quantizer: Optional[Quantizer],
     ) -> torch.Tensor:
@@ -64,6 +64,7 @@ class Quantize(BasicOperation):
         if quantize_forward and not is_quantized_tensor(out):
             out = self.get_quantizer("forward", 0)(out)
+        if ctx.requires_grad:
             ctx.quantize_backward = quantize_backward
         return out
......
@@ -40,7 +40,7 @@ class ReduceScatter(BasicOperation):
         self,
         ctx: OperationContext,
         input_: torch.Tensor,
-        prev_op_grad_input_quantizer: Optional[Quantizer],
+        prev_op_grad_output_quantizer: Optional[Quantizer],
         next_op_input_quantizer: Optional[Quantizer],
     ) -> torch.Tensor:
......
@@ -38,9 +38,10 @@ class Reshape(BasicOperation):
         self,
         ctx: OperationContext,
         input_: torch.Tensor,
-        prev_op_grad_input_quantizer: Optional[Quantizer],
+        prev_op_grad_output_quantizer: Optional[Quantizer],
         next_op_input_quantizer: Optional[Quantizer],
     ) -> torch.Tensor:
+        if ctx.requires_grad:
             ctx.input_shape = input_.size()
         return input_.reshape(*self._shape)
......
@@ -13,7 +13,6 @@ from typing import Optional
 import torch
 from transformer_engine_torch import rmsnorm_bwd, rmsnorm_fwd
 
-from ...fp8 import FP8GlobalStateManager
 from ...constants import TE_DType
 from ...utils import (
     canonicalize_device,
@@ -151,8 +150,8 @@ class RMSNorm(BasicOperation):
             weight = torch.nn.Parameter(weight)
         self.weight = weight
 
-    def pre_first_forward(self, *args, **kwargs) -> None:
-        super().pre_first_forward(*args, **kwargs)
+    def pre_first_fuser_forward(self) -> None:
+        super().pre_first_fuser_forward()
         if self.weight.device.type == "meta":
             self.reset_parameters()
@@ -160,7 +159,7 @@ class RMSNorm(BasicOperation):
         self,
         ctx: OperationContext,
         input_: torch.Tensor,
-        prev_op_grad_input_quantizer: Optional[Quantizer],
+        prev_op_grad_output_quantizer: Optional[Quantizer],
         next_op_input_quantizer: Optional[Quantizer],
     ) -> torch.Tensor:
         if is_in_onnx_export_mode():
@@ -182,30 +181,21 @@ class RMSNorm(BasicOperation):
         x = maybe_dequantize(input_.contiguous(), dtype).view((-1, inner_dim))
         w = maybe_dequantize(self.weight, dtype).view((inner_dim,))
 
-        # Check if backward pass is needed
-        requires_grad = ctx.requires_grad
-
-        # Check if output is quantized
-        output_quantizer = None
-        with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
-        if with_quantized_compute:
-            output_quantizer = next_op_input_quantizer
-
         # Compute RMSNorm
-        sm_margin = self._sm_margins["forward" if requires_grad else "inference"]
+        sm_margin = self._sm_margins["forward" if ctx.requires_grad else "inference"]
         y, _, rstdevs = rmsnorm_fwd(
             x,
             w,
             self.eps,
             None,
-            output_quantizer,
+            next_op_input_quantizer,
             TE_DType[dtype],
             sm_margin,
             self.zero_centered_gamma,
         )
 
         # Save state for backward pass
-        if requires_grad:
+        if ctx.requires_grad:
             ctx.save_for_backward(x, rstdevs)
             ctx.dtype = dtype
......
@@ -4,9 +4,9 @@
 """Compound tensor operation supported by the operation fuser."""
 
-from .backward_bias_activation import (
-    BackwardBiasActivation,
-    fuse_backward_bias_activation,
+from .backward_activation_bias import (
+    BackwardActivationBias,
+    fuse_backward_activation_bias,
 )
 from .backward_linear_add import (
     BackwardLinearAdd,
......
@@ -2,7 +2,7 @@
 #
 # See LICENSE for license information.
 
-"""Fused backward dbias + dact + quantize."""
+"""Fused backward dact + dbias + quantize."""
 
 from __future__ import annotations
 from typing import Optional
@@ -29,8 +29,8 @@ _fused_activations = {GELU: tex.dbias_dgelu, ReLU: tex.dbias_drelu}
 _fusible_activations = tuple(_fused_activations.keys())
 
-class BackwardBiasActivation(FusedOperation):
-    """Fused backward dbias + dact + quantize
+class BackwardActivationBias(FusedOperation):
+    """Fused backward dact + dbias + quantize
 
     Uses the next operation's input quantizer.
@@ -66,15 +66,10 @@ class BackwardActivationBias(FusedOperation):
         dy = maybe_dequantize(grad_output.contiguous(), act_input.dtype)
 
         # Get previous op quantizer
-        if not bias_op_ctx.with_quantized_compute:
-            raise RuntimeError(
-                "BackwardBiasActivation requires quantized compute, "
-                "but Bias context has it disabled"
-            )
         quantizer = bias_op_ctx.grad_input_quantizer
         if quantizer is None:
             raise RuntimeError(
-                "BackwardBiasActivation requires previous op's grad output quantizer, "
+                "BackwardActivationBias requires previous op's grad output quantizer, "
                 "but Bias context has no quantizer"
             )
@@ -87,11 +82,11 @@ class BackwardActivationBias(FusedOperation):
         return dx, [(), (db,)], [(), ()]
 
-def fuse_backward_bias_activation(
+def fuse_backward_activation_bias(
     ops: list[tuple[FusibleOperation, list[int]]],
     recipe: Optional[Recipe],
 ) -> list[tuple[FusibleOperation, list[int]]]:
-    """Fused backward dbias + dact + quantize
+    """Fused backward dact + dbias + quantize
 
     Parameters
     ----------
@@ -138,7 +133,7 @@ def fuse_backward_activation_bias(
         ops = ops[1:]
 
         # Replace window with fused op
-        op = BackwardBiasActivation(
+        op = BackwardActivationBias(
             activation=window[0][0],
             bias=window[1][0],
         )
......
@@ -29,10 +29,10 @@ class BackwardLinearAdd(FusedOperation):
     def __init__(
         self,
         *,
-        linear: BasicLinear,
         backward_add: MakeExtraOutput,
+        linear: BasicLinear,
     ) -> None:
-        super().__init__((linear, backward_add))
+        super().__init__((backward_add, linear))
 
     def fuser_backward(
         self,
@@ -47,7 +47,7 @@ class BackwardLinearAdd(FusedOperation):
     ]:
 
         # Get basic operations
-        linear_op = self.basic_ops[0]
+        linear_op = self.basic_ops[1]
         linear_op_ctx = basic_op_ctxs[0]
 
         # Saved tensors from forward pass
......
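A hedged usage sketch for the graph-capture changes above: with the new logic, fp8_recipe stays None when fp8_enabled=False (instead of falling back to the default delayed-scaling recipe), and te.ops.Sequential modules are handled by the capture-time hooks. The module, shapes, and default warmup settings below are illustrative only; te.make_graphed_callables, te_ops.BasicLinear, and te_ops.Bias are the public entry points shown in this diff.

import torch
import transformer_engine.pytorch as te
import transformer_engine.pytorch.ops as te_ops

# A te.ops.Sequential whose constituent basic ops are registered as visited
# during warmup and whose FP8 metas are added to the global buffer when a
# recipe is active.
model = te_ops.Sequential(
    te_ops.BasicLinear(32, 32, device="cuda"),
    te_ops.Bias(size=32, device="cuda"),
)
sample_args = (torch.randn(32, 32, device="cuda"),)

# With fp8_enabled=False the recipe is now left as None.
graphed = te.make_graphed_callables(model, sample_args, fp8_enabled=False)
y = graphed(torch.randn(32, 32, device="cuda"))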