Unverified Commit 7022d50f authored by Evgeny Tsykunov, committed by GitHub

[PyTorch] Quantizer as API (#2039)



* Introduce QuantizerBase
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Expose as a first-class API
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Undo QuantizerBase
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Make Quantizer a base class without implementations
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Support CustomRecipe and CustomRecipeState
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Resolving comments: quantize impl, num_quantizers, defaults
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Quantizer factories
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Add tests
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* QuantizedTensorBase _get_quantizer() + quantize_()
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Experimental note + LayerNormMLP fix
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* tensor._internal -> tensor.base
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Expose
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Minor import fix
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Single quantizer factory with roles
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* More context for qfactory, fwd/bwd_roles
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Minor
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Rename *Base -> *Storage quantized tensors
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* make_quantizers() will take roles from the operation
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* Improve tests and fix missing imports
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Apply suggestions from code review
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

* Merge main followup
Signed-off-by: Evgeny <etsykunov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny Tsykunov <etsykunov@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent ce18bee7
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pytest
import torch
import transformer_engine as te
import transformer_engine_torch as tex
from transformer_engine.common import recipe
from transformer_engine.pytorch.fp8 import check_fp8_support, fp8_autocast
from transformer_engine.pytorch import Linear
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.module.layernorm_linear import LayerNormLinear
from transformer_engine.pytorch.module.layernorm_mlp import LayerNormMLP
from transformer_engine.pytorch.tensor.float8_tensor import (
Float8CurrentScalingQuantizer,
)
from transformer_engine.pytorch.module.grouped_linear import GroupedLinear


@pytest.mark.parametrize("module_type", ["Linear", "LayerNormLinear", "OpsLinear", "LayerNormMLP"])
def test_custom_recipe_sanity(module_type):
available, reason = check_fp8_support()
if not torch.cuda.is_available() or not available:
pytest.skip(f"FP8 unsupported on this device: {reason}")
torch.manual_seed(0)
# Simple linear layer with dims divisible by 16
in_features = 64
out_features = 64
batch = 32
if module_type == "Linear":
model = Linear(in_features, out_features, params_dtype=torch.bfloat16).cuda()
elif module_type == "LayerNormLinear":
model = LayerNormLinear(in_features, out_features, params_dtype=torch.bfloat16).cuda()
elif module_type == "LayerNormMLP":
# hidden_size == in_features == out_features for simplicity
model = LayerNormMLP(
hidden_size=in_features, ffn_hidden_size=out_features, params_dtype=torch.bfloat16
).cuda()
else:
# OpsLinear path
model = te_ops.Linear(in_features, out_features, device="cuda", dtype=torch.bfloat16)
inp = torch.randn(batch, in_features, device="cuda", dtype=torch.bfloat16, requires_grad=True)
# Single factory: map roles to quantizers
def quantizer_factory(role):
if role in ("linear_input", "linear_weight", "linear_output"):
return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda")
if role in ("linear_grad_output", "linear_grad_input"):
return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device="cuda")
return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda")
custom_recipe = recipe.CustomRecipe(qfactory=quantizer_factory)
# Execute with custom recipe
with fp8_autocast(enabled=True, fp8_recipe=custom_recipe):
out = model(inp)
loss = out.float().sum()
loss.backward()
# Basic sanity: gradients exist
assert inp.grad is not None


def test_custom_recipe_grouped_linear_sanity():
available, reason = check_fp8_support()
if not torch.cuda.is_available() or not available:
pytest.skip(f"FP8 unsupported on this device: {reason}")
torch.manual_seed(0)
num_gemms = 3
in_features = 64
out_features = 64
batch = 32
base = batch // num_gemms
rem = batch % num_gemms
m_splits = [base + (1 if i < rem else 0) for i in range(num_gemms)]
model = GroupedLinear(num_gemms, in_features, out_features, params_dtype=torch.bfloat16).cuda()
inp = torch.randn(batch, in_features, device="cuda", dtype=torch.bfloat16, requires_grad=True)
def quantizer_factory(role):
if role in ("linear_input", "linear_weight", "linear_output"):
return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda")
if role in ("linear_grad_output", "linear_grad_input"):
return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device="cuda")
return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda")
custom_recipe = recipe.CustomRecipe(qfactory=quantizer_factory)
with fp8_autocast(enabled=True, fp8_recipe=custom_recipe):
out = model(inp, m_splits)
loss = out.float().sum()
loss.backward()
assert inp.grad is not None


def test_custom_recipe_matches_current_scaling():
available, reason = check_fp8_support()
if not torch.cuda.is_available() or not available:
pytest.skip(f"FP8 unsupported on this device: {reason}")
torch.manual_seed(123)
in_features = 64
out_features = 64
batch = 32
# Create two identical models
model_ref = Linear(in_features, out_features, params_dtype=torch.bfloat16).cuda()
model_custom = Linear(in_features, out_features, params_dtype=torch.bfloat16).cuda()
model_custom.load_state_dict(model_ref.state_dict())
# Identical inputs for both paths
base_inp = torch.randn(batch, in_features, device="cuda", dtype=torch.bfloat16)
inp_ref = base_inp.clone().detach().requires_grad_(True)
inp_custom = base_inp.clone().detach().requires_grad_(True)
# Reference: use Float8CurrentScaling recipe
ref_recipe = recipe.Float8CurrentScaling()
with fp8_autocast(enabled=True, fp8_recipe=ref_recipe):
out_ref = model_ref(inp_ref)
# Assert dtypes for reference quantizers: HYBRID = E4M3 (fwd), E5M2 (bwd)
ref_fwd_in = model_ref.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
ref_fwd_w = model_ref.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
ref_fwd_out = model_ref.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT]
ref_bwd_go = model_ref.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1]
ref_bwd_gi = model_ref.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1]
assert ref_fwd_in.dtype == tex.DType.kFloat8E4M3
assert ref_fwd_w.dtype == tex.DType.kFloat8E4M3
assert ref_fwd_out.dtype == tex.DType.kFloat8E4M3
assert ref_bwd_go.dtype == tex.DType.kFloat8E5M2
assert ref_bwd_gi.dtype == tex.DType.kFloat8E5M2
# Stress dynamic range in grad_output
scale = torch.ones(out_features, device="cuda", dtype=torch.float32)
scale[0] = 1e8
scale[1] = 1e-8
loss_ref = (out_ref.float() * scale.view(1, -1)).sum()
loss_ref.backward()
# Custom: single factory returning quantizers per role to match Float8CurrentScaling
def quantizer_factory(role):
if role in ("linear_input", "linear_weight", "linear_output"):
return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda")
if role in ("linear_grad_output", "linear_grad_input"):
return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device="cuda")
return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda")
custom_recipe = recipe.CustomRecipe(qfactory=quantizer_factory)
with fp8_autocast(enabled=True, fp8_recipe=custom_recipe):
out_custom = model_custom(inp_custom)
# Assert dtypes for custom quantizers match reference mapping
cus_fwd_in = model_custom.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
cus_fwd_w = model_custom.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
cus_fwd_out = model_custom.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT]
cus_bwd_go = model_custom.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1]
cus_bwd_gi = model_custom.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1]
assert cus_fwd_in.dtype == tex.DType.kFloat8E4M3
assert cus_fwd_w.dtype == tex.DType.kFloat8E4M3
assert cus_fwd_out.dtype == tex.DType.kFloat8E4M3
assert cus_bwd_go.dtype == tex.DType.kFloat8E5M2
assert cus_bwd_gi.dtype == tex.DType.kFloat8E5M2
loss_custom = (out_custom.float() * scale.view(1, -1)).sum()
loss_custom.backward()
# Compare forward outputs (exact match expected)
assert torch.allclose(out_ref, out_custom, rtol=0.0, atol=0.0)
# Compare input gradients
assert inp_ref.grad is not None and inp_custom.grad is not None
assert torch.allclose(inp_ref.grad, inp_custom.grad, rtol=0.0, atol=0.0)
# Compare parameter gradients (weights and bias if present)
ref_params = dict(model_ref.named_parameters())
custom_params = dict(model_custom.named_parameters())
for name, p_ref in ref_params.items():
p_cus = custom_params[name]
assert p_ref.grad is not None and p_cus.grad is not None
assert torch.allclose(p_ref.grad, p_cus.grad, rtol=0.0, atol=0.0)


def test_custom_recipe_ops_linear_2_1_layout():
available, reason = check_fp8_support()
if not torch.cuda.is_available() or not available:
pytest.skip(f"FP8 unsupported on this device: {reason}")
torch.manual_seed(7)
in_features = 64
out_features = 64
batch = 16
# Use ops.Linear which consumes 2 forward quantizers and 1 backward quantizer
op = te_ops.Linear(in_features, out_features, device="cuda", dtype=torch.bfloat16)
inp = torch.randn(batch, in_features, device="cuda", dtype=torch.bfloat16, requires_grad=True)
def quantizer_factory(role):
if role in ("linear_input", "linear_weight", "linear_output"):
return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda")
if role in ("linear_grad_output", "linear_grad_input"):
return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device="cuda")
return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda")
custom = recipe.CustomRecipe(qfactory=quantizer_factory)
with fp8_autocast(enabled=True, fp8_recipe=custom):
out = op(inp)
loss = out.float().sum()
loss.backward()
assert inp.grad is not None


def test_custom_recipe_factory_invocation_counts_and_cycling():
available, reason = check_fp8_support()
if not torch.cuda.is_available() or not available:
pytest.skip(f"FP8 unsupported on this device: {reason}")
torch.manual_seed(13)
in_features = 64
out_features = 64
batch = 8
op = Linear(in_features, out_features, params_dtype=torch.bfloat16)
inp = torch.randn(batch, in_features, device="cuda", dtype=torch.bfloat16, requires_grad=True)
# Counters per role
counts = {
"linear_input": 0,
"linear_weight": 0,
"linear_output": 0,
"linear_grad_output": 0,
"linear_grad_input": 0,
}
def quantizer_factory(role):
if role in counts:
counts[role] += 1
if role in ("linear_input", "linear_weight", "linear_output"):
return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device=torch.device("cuda"))
if role in ("linear_grad_output", "linear_grad_input"):
return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device=torch.device("cuda"))
return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device=torch.device("cuda"))
custom = recipe.CustomRecipe(qfactory=quantizer_factory)
# Run fwd+bwd once; for a single GEMM, expect forward to build 3 quantizers (cycled from 1 factory),
# and backward to build 2 quantizers (cycled from 1 factory).
with fp8_autocast(enabled=True, fp8_recipe=custom):
out = op(inp)
loss = out.float().sum()
loss.backward()
# Single GEMM: forward should request input, weight, output; backward grad_output, grad_input
assert counts["linear_input"] == 1
assert counts["linear_weight"] == 1
assert counts["linear_output"] == 1
assert counts["linear_grad_output"] == 1
assert counts["linear_grad_input"] == 1


def test_factories_return_distinct_instances_and_buffers():
available, reason = check_fp8_support()
if not torch.cuda.is_available() or not available:
pytest.skip(f"FP8 unsupported on this device: {reason}")
# Two calls should produce distinct quantizer objects and distinct tensor buffers
def factory():
return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device=torch.device("cuda"))
q1 = factory()
q2 = factory()
assert q1 is not q2
assert q1.scale.data_ptr() != q2.scale.data_ptr()
assert q1.amax.data_ptr() != q2.amax.data_ptr()
# Mutating one should not affect the other
q1.scale.fill_(123.0)
assert not torch.equal(q1.scale, q2.scale)
......@@ -6,7 +6,8 @@
from __future__ import annotations
import os
from enum import Enum
from typing import Literal, Optional, Union, Callable, NamedTuple
from typing import Any, Literal, Optional, Union, Callable, NamedTuple
from dataclasses import field
from pydantic.dataclasses import dataclass
......@@ -111,6 +112,10 @@ class Recipe:
"""Whether the given recipe is float8 blockwise scaling."""
return isinstance(self, Float8BlockScaling)
def custom(self):
"""Whether the given recipe is custom."""
return isinstance(self, CustomRecipe)
@dataclass()
class DelayedScaling(Recipe):
......@@ -377,7 +382,6 @@ class Float8BlockScaling(Recipe):
)
@dataclass()
class NVFP4BlockScaling(Recipe):
"""
Use the NVFP4 scaling strategy.
......@@ -456,3 +460,37 @@ class NVFP4BlockScaling(Recipe):
f"fp4_quant_fwd_weight={self.fp4_quant_fwd_weight}, "
f"fp4_quant_bwd_grad={self.fp4_quant_bwd_grad}, "
)
@dataclass()
class CustomRecipe(Recipe):
"""
Custom recipe that allows users to provide a quantizer factory.
.. warning::
**EXPERIMENTAL**: Custom recipe is experimental, still under active development,
and the API is subject to change without notice. Use at your own risk.
Parameters
----------
qfactory : Callable
Factory callable that returns a quantizer instance for a
given semantic tensor role.
The callable is typically invoked as:
qfactory(
role: str,
)
Where `role` is one of the following strings, e.g. for te.Linear
(stable public contract):
- forward: "linear_input", "linear_weight", "linear_output"
- backward: "linear_grad_output", "linear_grad_input"
"""
qfactory: Callable[..., Any]
fp8_dpa: bool = False
fp8_mha: bool = False
def __repr__(self) -> str:
return f"recipe_type={self.__class__.__name__}, qfactory={self.qfactory}"
......@@ -15,8 +15,8 @@ import nvdlfw_inspect.api as debug_api
from transformer_engine.pytorch.tensor import QuantizedTensor, Quantizer
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor
from transformer_engine.pytorch.tensor._internal.float8_tensor_base import Float8TensorBase
from transformer_engine.pytorch.tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from transformer_engine.pytorch.tensor.storage.float8_tensor_storage import Float8TensorStorage
from transformer_engine.pytorch.tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage
from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS
from transformer_engine.debug.features.utils import next_enabled_iter, get_reduction_params
......@@ -123,7 +123,7 @@ class LogTensorStats(BaseLogTensorStats):
"""API call used to collect the data about the tensor before process_tensor()/quantization."""
assert (
type(tensor) not in [Float8Tensor, Float8TensorBase, MXFP8Tensor, MXFP8TensorBase]
type(tensor) not in [Float8Tensor, Float8TensorStorage, MXFP8Tensor, MXFP8TensorStorage]
and tensor.dtype != torch.uint8
), (
f"[NVTORCH INSPECT ERROR] Tensor {tensor_name} must be in high precision when using"
......
......@@ -18,7 +18,7 @@ from transformer_engine.common.recipe import Recipe
from transformer_engine.pytorch.tensor.quantized_tensor import (
QuantizedTensor,
Quantizer,
QuantizedTensorBase,
QuantizedTensorStorage,
prepare_for_saving,
restore_from_saved,
)
......@@ -557,7 +557,7 @@ class DebugQuantizer(Quantizer):
self._update_parent_quantizer_usage()
class DebugQuantizedTensor(QuantizedTensorBase):
class DebugQuantizedTensor(QuantizedTensorStorage):
"""
Class containing quantized tensors after debug. Depending on configuration
it can contain one or two different objects. These objects can be accessed by the method
......
......@@ -56,6 +56,21 @@ from transformer_engine.pytorch import ops
from transformer_engine.pytorch import optimizers
from transformer_engine.pytorch.export import onnx_export
from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy
from transformer_engine.pytorch.tensor import Quantizer
from transformer_engine.pytorch.tensor import Float8Quantizer
from transformer_engine.pytorch.tensor import Float8CurrentScalingQuantizer
from transformer_engine.pytorch.tensor import MXFP8Quantizer
from transformer_engine.pytorch.tensor import Float8BlockQuantizer
from transformer_engine.pytorch.tensor import QuantizedTensorStorage
from transformer_engine.pytorch.tensor import Float8TensorStorage
from transformer_engine.pytorch.tensor import MXFP8TensorStorage
from transformer_engine.pytorch.tensor import Float8BlockwiseQTensorStorage
from transformer_engine.pytorch.tensor import QuantizedTensor
from transformer_engine.pytorch.tensor import Float8Tensor
from transformer_engine.pytorch.tensor import MXFP8Tensor
from transformer_engine.pytorch.tensor import Float8BlockwiseQTensor
from transformer_engine.pytorch.tensor import prepare_for_saving
from transformer_engine.pytorch.tensor import restore_from_saved
try:
torch._dynamo.config.error_on_nested_jit_trace = False
......
......@@ -12,7 +12,7 @@ from ..constants import TE_DType
from ..utils import get_sm_count, _empty_tensor
from ..tensor.quantized_tensor import Quantizer
from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage
from ..tensor.utils import is_experimental
from ..experimental.gemm import experimental_gemm
from ...debug.pytorch.debug_quantization import DebugQuantizer
......@@ -107,9 +107,9 @@ def general_gemm(
# Use bfloat16 as default bias_dtype
bias_dtype = TE_DType[torch.bfloat16 if bias is None else bias.dtype]
if isinstance(A, Float8BlockwiseQTensorBase) or isinstance(B, Float8BlockwiseQTensorBase):
if isinstance(A, Float8BlockwiseQTensorStorage) or isinstance(B, Float8BlockwiseQTensorStorage):
# There is not use_split_accumulator == False
# implementation for Float8BlockwiseQTensorBase GEMM
# implementation for Float8BlockwiseQTensorStorage GEMM
use_split_accumulator = True
# Check that data format is supported
......
......@@ -10,7 +10,7 @@ from typing import Any, Dict, Optional
import torch
from transformer_engine.debug.pytorch.debug_state import TEDebugState
from .tensor.quantized_tensor import QuantizedTensorBase
from .tensor.quantized_tensor import QuantizedTensorStorage
from .tensor.float8_tensor import Float8Tensor
__all__ = ["get_cpu_offload_context"]
......@@ -34,7 +34,7 @@ def mark_activation_offload(*tensors):
if tensor is not None:
tensor.activation_offloading = True
# This is a hack to force clear the tensor after it is offloaded.
# It is needed, because .*TensorBase classes are saved in the ctx,
# It is needed, because .*TensorStorage classes are saved in the ctx,
# and they contain the reference to their data tensors.
tensor.needs_force_clear = True
......@@ -362,7 +362,7 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
),
)
is_quantized_tensor = isinstance(tensor, QuantizedTensorBase)
is_quantized_tensor = isinstance(tensor, QuantizedTensorStorage)
if not torch_stray_tensor:
......@@ -514,7 +514,7 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
if tensor_tag[0] == self.offloaded_group_count:
if hasattr(tensor_buf, "needs_force_clear"):
# Need to clear activation tensor - sometimes references persist in the code.
# This is the case for example with the Float8TensorBase class,
# This is the case for example with the Float8TensorStorage class,
# which is saved directly inside the ctx while its internal tensors are
# saved inside save_for_backward.
tensor_buf.data = torch.Tensor()
......
......@@ -314,7 +314,7 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_fp
// Construct FP8 block-wise tensors
py::handle Float8BlockwiseQTensorClass(
reinterpret_cast<PyObject *>(Float8BlockwiseQTensorBasePythonClass));
reinterpret_cast<PyObject *>(Float8BlockwiseQTensorStoragePythonClass));
for (size_t i = 0; i < num_tensors; ++i) {
// Create tensor objects with proper reference counting
py::object rowwise_data = rowwise_usage ? py::cast(rowwise_data_list[i]) : py::none();
......@@ -461,7 +461,7 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_mx
}
// Construct mxfp8 tensors
py::handle MXFP8TensorClass(reinterpret_cast<PyObject *>(MXFP8TensorBasePythonClass));
py::handle MXFP8TensorClass(reinterpret_cast<PyObject *>(MXFP8TensorStoragePythonClass));
for (size_t i = 0; i < num_tensors; ++i) {
// Create tensor objects with proper reference counting
py::object rowwise_data = rowwise_usage ? py::cast(rowwise_data_list[i]) : py::none();
......
......@@ -23,17 +23,17 @@
namespace transformer_engine::pytorch {
PyTypeObject *Float8TensorPythonClass = nullptr; /// TODO Remove
PyTypeObject *Float8TensorBasePythonClass = nullptr;
PyTypeObject *Float8TensorStoragePythonClass = nullptr;
PyTypeObject *Float8QuantizerClass = nullptr;
PyTypeObject *Float8CurrentScalingQuantizerClass = nullptr;
PyTypeObject *MXFP8TensorPythonClass = nullptr; /// TODO Remove
PyTypeObject *MXFP8TensorBasePythonClass = nullptr;
PyTypeObject *MXFP8TensorStoragePythonClass = nullptr;
PyTypeObject *MXFP8QuantizerClass = nullptr;
PyTypeObject *Float8BlockwiseQTensorPythonClass = nullptr;
PyTypeObject *Float8BlockwiseQTensorBasePythonClass = nullptr;
PyTypeObject *Float8BlockwiseQTensorStoragePythonClass = nullptr;
PyTypeObject *Float8BlockwiseQuantizerClass = nullptr;
PyTypeObject *NVFP4TensorPythonClass = nullptr;
PyTypeObject *NVFP4TensorBasePythonClass = nullptr;
PyTypeObject *NVFP4TensorStoragePythonClass = nullptr;
PyTypeObject *NVFP4QuantizerClass = nullptr;
void init_float8_extension() {
......@@ -46,9 +46,9 @@ void init_float8_extension() {
Float8TensorPythonClass =
reinterpret_cast<PyTypeObject *>(PyObject_GetAttrString(fp8_module.ptr(), "Float8Tensor"));
auto fp8_base_module =
py::module_::import("transformer_engine.pytorch.tensor._internal.float8_tensor_base");
Float8TensorBasePythonClass = reinterpret_cast<PyTypeObject *>(
PyObject_GetAttrString(fp8_base_module.ptr(), "Float8TensorBase"));
py::module_::import("transformer_engine.pytorch.tensor.storage.float8_tensor_storage");
Float8TensorStoragePythonClass = reinterpret_cast<PyTypeObject *>(
PyObject_GetAttrString(fp8_base_module.ptr(), "Float8TensorStorage"));
NVTE_CHECK(Float8TensorPythonClass != nullptr,
"Internal error: could not initialize pyTorch Float8 extension.");
}
......@@ -61,29 +61,29 @@ void init_mxfp8_extension() {
MXFP8TensorPythonClass =
reinterpret_cast<PyTypeObject *>(PyObject_GetAttrString(fp8_module.ptr(), "MXFP8Tensor"));
auto fp8_base_module =
py::module_::import("transformer_engine.pytorch.tensor._internal.mxfp8_tensor_base");
MXFP8TensorBasePythonClass = reinterpret_cast<PyTypeObject *>(
PyObject_GetAttrString(fp8_base_module.ptr(), "MXFP8TensorBase"));
py::module_::import("transformer_engine.pytorch.tensor.storage.mxfp8_tensor_storage");
MXFP8TensorStoragePythonClass = reinterpret_cast<PyTypeObject *>(
PyObject_GetAttrString(fp8_base_module.ptr(), "MXFP8TensorStorage"));
NVTE_CHECK(MXFP8TensorPythonClass != nullptr,
"Internal error: could not initialize pyTorch MXFP8 extension.");
}
void init_float8blockwise_extension() {
if (Float8BlockwiseQTensorBasePythonClass) return;
if (Float8BlockwiseQTensorStoragePythonClass) return;
auto fp8_module =
py::module_::import("transformer_engine.pytorch.tensor.float8_blockwise_tensor");
auto fp8_base_module = py::module_::import(
"transformer_engine.pytorch.tensor._internal.float8_blockwise_tensor_base");
"transformer_engine.pytorch.tensor.storage.float8_blockwise_tensor_storage");
Float8BlockwiseQuantizerClass = reinterpret_cast<PyTypeObject *>(
PyObject_GetAttrString(fp8_module.ptr(), "Float8BlockQuantizer"));
Float8BlockwiseQTensorBasePythonClass = reinterpret_cast<PyTypeObject *>(
PyObject_GetAttrString(fp8_base_module.ptr(), "Float8BlockwiseQTensorBase"));
Float8BlockwiseQTensorStoragePythonClass = reinterpret_cast<PyTypeObject *>(
PyObject_GetAttrString(fp8_base_module.ptr(), "Float8BlockwiseQTensorStorage"));
Float8BlockwiseQTensorPythonClass = reinterpret_cast<PyTypeObject *>(
PyObject_GetAttrString(fp8_module.ptr(), "Float8BlockwiseQTensor"));
NVTE_CHECK(Float8BlockwiseQuantizerClass != nullptr,
"Internal error: could not initialize pyTorch float8blockwise extension.");
NVTE_CHECK(Float8BlockwiseQTensorBasePythonClass != nullptr,
NVTE_CHECK(Float8BlockwiseQTensorStoragePythonClass != nullptr,
"Internal error: could not initialize pyTorch float8blockwise extension.");
NVTE_CHECK(Float8BlockwiseQTensorPythonClass != nullptr,
"Internal error: could not initialize pyTorch float8blockwise extension.");
......@@ -97,9 +97,9 @@ void init_nvfp4_extensions() {
NVFP4TensorPythonClass =
reinterpret_cast<PyTypeObject *>(PyObject_GetAttrString(nvfp4_module.ptr(), "NVFP4Tensor"));
auto nvfp4_base_module =
py::module_::import("transformer_engine.pytorch.tensor._internal.nvfp4_tensor_base");
NVFP4TensorBasePythonClass = reinterpret_cast<PyTypeObject *>(
PyObject_GetAttrString(nvfp4_base_module.ptr(), "NVFP4TensorBase"));
py::module_::import("transformer_engine.pytorch.tensor.storage.nvfp4_tensor_storage");
NVFP4TensorStoragePythonClass = reinterpret_cast<PyTypeObject *>(
PyObject_GetAttrString(nvfp4_base_module.ptr(), "NVFP4TensorStorage"));
NVTE_CHECK(NVFP4TensorPythonClass != nullptr,
"Internal error: could not initialize pyTorch NVFP4 extension.");
}
......
......@@ -31,17 +31,17 @@ namespace transformer_engine::pytorch {
} while (false);
extern PyTypeObject *Float8TensorPythonClass;
extern PyTypeObject *Float8TensorBasePythonClass;
extern PyTypeObject *Float8TensorStoragePythonClass;
extern PyTypeObject *Float8QuantizerClass;
extern PyTypeObject *Float8CurrentScalingQuantizerClass;
extern PyTypeObject *MXFP8TensorPythonClass;
extern PyTypeObject *MXFP8TensorBasePythonClass;
extern PyTypeObject *MXFP8TensorStoragePythonClass;
extern PyTypeObject *MXFP8QuantizerClass;
extern PyTypeObject *Float8BlockwiseQTensorPythonClass;
extern PyTypeObject *Float8BlockwiseQTensorBasePythonClass;
extern PyTypeObject *Float8BlockwiseQTensorStoragePythonClass;
extern PyTypeObject *Float8BlockwiseQuantizerClass;
extern PyTypeObject *NVFP4TensorPythonClass;
extern PyTypeObject *NVFP4TensorBasePythonClass;
extern PyTypeObject *NVFP4TensorStoragePythonClass;
extern PyTypeObject *NVFP4QuantizerClass;
void init_extension();
......@@ -55,13 +55,13 @@ inline bool IsFloat8CurrentScalingQuantizers(PyObject *obj) {
}
inline bool IsFloat8Tensor(PyObject *obj) {
return Py_TYPE(obj) == Float8TensorPythonClass || Py_TYPE(obj) == Float8TensorBasePythonClass;
return Py_TYPE(obj) == Float8TensorPythonClass || Py_TYPE(obj) == Float8TensorStoragePythonClass;
}
inline bool IsMXFP8Quantizers(PyObject *obj) { return Py_TYPE(obj) == MXFP8QuantizerClass; }
inline bool IsMXFP8Tensor(PyObject *obj) {
return Py_TYPE(obj) == MXFP8TensorPythonClass || Py_TYPE(obj) == MXFP8TensorBasePythonClass;
return Py_TYPE(obj) == MXFP8TensorPythonClass || Py_TYPE(obj) == MXFP8TensorStoragePythonClass;
}
inline bool IsFloat8BlockwiseQuantizers(PyObject *obj) {
......@@ -72,11 +72,11 @@ inline bool IsNVFP4Quantizers(PyObject *obj) { return Py_TYPE(obj) == NVFP4Quant
inline bool IsFloat8BlockwiseQTensor(PyObject *obj) {
return Py_TYPE(obj) == Float8BlockwiseQTensorPythonClass ||
Py_TYPE(obj) == Float8BlockwiseQTensorBasePythonClass;
Py_TYPE(obj) == Float8BlockwiseQTensorStoragePythonClass;
}
inline bool IsNVFP4Tensor(PyObject *obj) {
return Py_TYPE(obj) == NVFP4TensorPythonClass || Py_TYPE(obj) == NVFP4TensorBasePythonClass;
return Py_TYPE(obj) == NVFP4TensorPythonClass || Py_TYPE(obj) == NVFP4TensorStoragePythonClass;
}
TensorWrapper NVTETensorFromFloat8Tensor(py::handle tensor, Quantizer *quantizer);
......
......@@ -152,7 +152,7 @@ std::pair<TensorWrapper, py::object> Float8Quantizer::create_tensor(
// Construct Python FP8 tensor
py::object out_py;
if (internal) {
py::handle Float8TensorClass(reinterpret_cast<PyObject*>(Float8TensorBasePythonClass));
py::handle Float8TensorClass(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass));
out_py = Float8TensorClass("data"_a = data_py, "fp8_scale_inv"_a = *scale_inv,
"fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py,
"quantizer"_a = this->quantizer);
......@@ -357,7 +357,7 @@ std::pair<TensorWrapper, py::object> Float8CurrentScalingQuantizer::create_tenso
py::object data_py = with_data ? py::cast(data_tensor) : py::none();
py::object transpose_py = with_transpose ? py::cast(transpose_tensor) : py::none();
if (internal) {
py::handle Float8TensorClass(reinterpret_cast<PyObject*>(Float8TensorBasePythonClass));
py::handle Float8TensorClass(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass));
out_py = Float8TensorClass("data"_a = data_py, "fp8_scale_inv"_a = scale_inv_tensor,
"fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py,
"quantizer"_a = this->quantizer);
......@@ -630,7 +630,7 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
py::object ret;
if (internal) {
py::handle Float8BlockwiseQTensorClass(
reinterpret_cast<PyObject*>(Float8BlockwiseQTensorBasePythonClass));
reinterpret_cast<PyObject*>(Float8BlockwiseQTensorStoragePythonClass));
ret = Float8BlockwiseQTensorClass(
"rowwise_data"_a = data_rowwise, "columnwise_data"_a = data_colwise,
"rowwise_scale_inv"_a = scale_inv_rowwise, "columnwise_scale_inv"_a = scale_inv_colwise,
......@@ -950,7 +950,7 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(const std::ve
// Construct Python MXFP8 tensor
py::object out_py;
if (internal) {
py::handle MXFP8TensorClass(reinterpret_cast<PyObject*>(MXFP8TensorBasePythonClass));
py::handle MXFP8TensorClass(reinterpret_cast<PyObject*>(MXFP8TensorStoragePythonClass));
out_py = MXFP8TensorClass("rowwise_data"_a = rowwise_data_py,
"columnwise_data"_a = columnwise_data_py,
"rowwise_scale_inv"_a = rowwise_scale_inv_py,
......@@ -1230,7 +1230,7 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::create_tensor(const std::ve
// Construct Python NVFP4 tensor
py::object out_py;
if (internal) {
py::handle NVFP4TensorClass(reinterpret_cast<PyObject*>(NVFP4TensorBasePythonClass));
py::handle NVFP4TensorClass(reinterpret_cast<PyObject*>(NVFP4TensorStoragePythonClass));
out_py = NVFP4TensorClass(
"rowwise_data"_a = rowwise_data_py, "columnwise_data"_a = columnwise_data_py,
"rowwise_scale_inv"_a = rowwise_scale_inv_py,
......
......@@ -41,11 +41,11 @@ from .tensor.float8_tensor import Float8Quantizer, Float8Tensor, Float8CurrentSc
from .tensor.mxfp8_tensor import MXFP8Quantizer
from .tensor.nvfp4_tensor import NVFP4Quantizer
from .tensor.float8_blockwise_tensor import Float8BlockQuantizer
from .tensor.quantized_tensor import QuantizedTensorBase, QuantizedTensor, Quantizer
from .tensor._internal.float8_tensor_base import Float8TensorBase
from .tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from .tensor._internal.nvfp4_tensor_base import NVFP4TensorBase
from .tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from .tensor.quantized_tensor import QuantizedTensorStorage, QuantizedTensor, Quantizer
from .tensor.storage.float8_tensor_storage import Float8TensorStorage
from .tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage
from .tensor.storage.nvfp4_tensor_storage import NVFP4TensorStorage
from .tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage
from .triton.pad import pad_columnwise_scale_inv
from ..debug.pytorch.debug_quantization import DebugQuantizedTensor, DebugQuantizer
......@@ -907,7 +907,7 @@ def _all_gather_fp8(
async_op: bool = False,
quantizer: Optional[Quantizer] = None,
out_shape: Optional[list[int]] = None,
) -> tuple[Float8TensorBase, Optional[torch.distributed.Work]]:
) -> tuple[Float8TensorStorage, Optional[torch.distributed.Work]]:
"""All-gather FP8 tensor along first dimension."""
world_size = get_distributed_world_size(process_group)
......@@ -925,7 +925,7 @@ def _all_gather_fp8(
# Cast input tensor to FP8 if needed
# Note: We cannot directly all-gather the transposed FP8 tensor,
# so temporarily modify quantizer to avoid creating FP8 transpose.
if not isinstance(inp, Float8TensorBase):
if not isinstance(inp, Float8TensorStorage):
assert isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer))
# we cannot directly gather the transposed fp8 tensor
# so we need to disable columnwise usage for the quantizer
......@@ -940,7 +940,7 @@ def _all_gather_fp8(
)
# Construct output tensor
out: Float8TensorBase
out: Float8TensorStorage
if quantizer is not None:
dtype = torch.float32
device = "cuda"
......@@ -958,7 +958,7 @@ def _all_gather_fp8(
out._transpose = None
out._transpose_invalid = True
else:
raise RuntimeError("FP8TensorBase is not supported yet without Quantizer")
raise RuntimeError("Float8TensorStorage is not supported yet without Quantizer")
# Assume scaling factors are identical across ranks
out._scale_inv = inp._scale_inv
......@@ -1003,10 +1003,10 @@ def _set_quantizer_format(quantizer: Quantizer, compact: bool = False) -> None:
def _post_process_fp8_blockwise_gather(
out: Float8BlockwiseQTensorBase,
out: Float8BlockwiseQTensorStorage,
quantizer: Float8BlockQuantizer,
handle: Optional[torch.distributed.Work] = None,
) -> Float8BlockwiseQTensorBase:
) -> Float8BlockwiseQTensorStorage:
"""Post-process FP8 blockwise gather."""
if handle is not None:
handle.wait()
......@@ -1040,7 +1040,7 @@ def _post_process_fp8_blockwise_gather(
class _FP8BlockwiseAllGatherAsyncHandle:
"""Handle for asynchronous FP8 blockwise all-gather."""
tensor: Float8BlockwiseQTensorBase
tensor: Float8BlockwiseQTensorStorage
quantizer: Float8BlockQuantizer
async_handle: torch.distributed.Work
_synchronized: bool = False
......@@ -1078,18 +1078,18 @@ def _all_gather_fp8_blockwise(
if isinstance(inp, torch.Tensor):
device = inp.device
dtype = inp.dtype
elif isinstance(inp, Float8BlockwiseQTensorBase):
elif isinstance(inp, Float8BlockwiseQTensorStorage):
if inp._rowwise_data is not None:
device = inp._rowwise_data.device
elif inp._columnwise_data is not None:
device = inp._columnwise_data.device
else:
raise ValueError("Got Float8BlockwiseQTensorBase input tensor without any data")
raise ValueError("Got Float8BlockwiseQTensorStorage input tensor without any data")
dtype = torch.bfloat16 # Only has fp8 dtype. Guess BF16 for dequant.
else:
raise ValueError(
"Invalid type for input tensor (expected torch.Tensor or Float8BlockwiseQTensorBase, "
f"found {inp.__class__.__name__})"
"Invalid type for input tensor (expected torch.Tensor or"
f" Float8BlockwiseQTensorStorage, found {inp.__class__.__name__})"
)
world_size = get_distributed_world_size(process_group)
......@@ -1106,7 +1106,7 @@ def _all_gather_fp8_blockwise(
# Doing BF16 gather for now as baseline because it's simpler
if (
not isinstance(inp, Float8BlockwiseQTensorBase)
not isinstance(inp, Float8BlockwiseQTensorStorage)
and quantizer is not None
and not quantizer.is_quantizable(inp)
):
......@@ -1131,7 +1131,7 @@ def _all_gather_fp8_blockwise(
# Set to compact usage in case the quantizer is not correctly configured
orig_all_gather_usage = quantizer.all_gather_usage
quantizer.all_gather_usage = True
if not isinstance(inp, Float8BlockwiseQTensorBase):
if not isinstance(inp, Float8BlockwiseQTensorStorage):
inp = quantizer(inp)
elif (quantizer.rowwise_usage and inp._rowwise_data is None) or (
quantizer.columnwise_usage and inp._columnwise_data is None
......@@ -1228,12 +1228,12 @@ def _swap_first_dims(tensor: torch.Tensor, world_size: int):
def _post_process_nvfp4_gather(
out: NVFP4TensorBase,
out: NVFP4TensorStorage,
columnwise_data_interleaved: torch.Tensor,
columnwise_scale_inv_interleaved: torch.Tensor,
world_size: int,
handle: Optional[torch.distributed.Work] = None,
) -> NVFP4TensorBase:
) -> NVFP4TensorStorage:
"""Post-process FP8 blockwise gather."""
if handle is not None:
handle.wait()
......@@ -1251,7 +1251,7 @@ def _post_process_nvfp4_gather(
class _NVFP4AllGatherAsyncHandle:
"""Handle for asynchronous NVFP4 all-gather."""
output: NVFP4TensorBase
output: NVFP4TensorStorage
columnwise_data_interleaved: torch.Tensor
columnwise_scale_inv_interleaved: torch.Tensor
world_size: int
......@@ -1279,7 +1279,7 @@ def _all_gather_nvfp4(
async_op: bool = False,
quantizer: NVFP4Quantizer,
out_shape: Optional[list[int]] = None,
) -> tuple[NVFP4TensorBase, Optional[torch.distributed.Work]]:
) -> tuple[NVFP4TensorStorage, Optional[torch.distributed.Work]]:
"""All-gather NVFP4 tensor along first dimension."""
# Input tensor attributes
......@@ -1289,7 +1289,7 @@ def _all_gather_nvfp4(
dtype: torch.dtype
# Construct packed shapes for input and input_t.
if isinstance(inp, torch.Tensor) and not isinstance(inp, NVFP4TensorBase):
if isinstance(inp, torch.Tensor) and not isinstance(inp, NVFP4TensorStorage):
# High-precision tensor.
in_shape = NVFP4Quantizer.convert_shape_for_fp4(inp.size())
in_shape_t = NVFP4Quantizer.convert_shape_for_fp4(
......@@ -1297,7 +1297,7 @@ def _all_gather_nvfp4(
)
device = inp.device
dtype = inp.dtype
elif isinstance(inp, NVFP4TensorBase):
elif isinstance(inp, NVFP4TensorStorage):
if inp._rowwise_data is not None:
in_shape = inp._rowwise_data.size()
device = inp._rowwise_data.device
......@@ -1307,7 +1307,7 @@ def _all_gather_nvfp4(
dtype = torch.bfloat16
else:
raise ValueError(
"Invalid type for input tensor (expected torch.Tensor or NVFP4TensorBase, "
"Invalid type for input tensor (expected torch.Tensor or NVFP4TensorStorage, "
f"found {inp.__class__.__name__})"
)
......@@ -1321,7 +1321,7 @@ def _all_gather_nvfp4(
# For cases where inp has dimensions that cannot be quantized,
# we gather in high precision followed by a cast to NVFP4.
if (
not isinstance(inp, NVFP4TensorBase)
not isinstance(inp, NVFP4TensorStorage)
and quantizer is not None
and not quantizer.is_quantizable(inp)
):
......@@ -1336,7 +1336,7 @@ def _all_gather_nvfp4(
return out, None
# Cast input tensor to NVFP4 with required data
if not isinstance(inp, NVFP4TensorBase):
if not isinstance(inp, NVFP4TensorStorage):
inp = quantizer(inp)
elif (quantizer.rowwise_usage and inp._rowwise_data is None) or (
quantizer.columnwise_usage and inp._columnwise_data is None
......@@ -1453,7 +1453,7 @@ def _all_gather_mxfp8(
async_op: bool = False,
quantizer: MXFP8Quantizer,
out_shape: Optional[list[int]] = None,
) -> tuple[MXFP8TensorBase, Optional[torch.distributed.Work]]:
) -> tuple[MXFP8TensorStorage, Optional[torch.distributed.Work]]:
"""All-gather MXFP8 tensor along first dimension."""
# Input tensor attributes
......@@ -1464,7 +1464,7 @@ def _all_gather_mxfp8(
in_shape = inp.size()
device = inp.device
dtype = inp.dtype
elif isinstance(inp, MXFP8TensorBase):
elif isinstance(inp, MXFP8TensorStorage):
if inp._rowwise_data is not None:
in_shape = inp._rowwise_data.size()
device = inp._rowwise_data.device
......@@ -1476,7 +1476,7 @@ def _all_gather_mxfp8(
dtype = torch.bfloat16 # Guess high-precision dtype.
else:
raise ValueError(
"Invalid type for input tensor (expected torch.Tensor or MXFP8TensorBase, "
"Invalid type for input tensor (expected torch.Tensor or MXFP8TensorStorage, "
f"found {inp.__class__.__name__})"
)
......@@ -1488,7 +1488,7 @@ def _all_gather_mxfp8(
# For cases where inp has dimensions that cannot be quantized,
# we gather in high precision followed by a cast to FP8.
if (
not isinstance(inp, MXFP8TensorBase)
not isinstance(inp, MXFP8TensorStorage)
and quantizer is not None
and not quantizer.is_quantizable(inp)
):
......@@ -1503,7 +1503,7 @@ def _all_gather_mxfp8(
return out, None
# Cast input tensor to MXFP8 with required data
if not isinstance(inp, MXFP8TensorBase):
if not isinstance(inp, MXFP8TensorStorage):
inp = quantizer(inp)
elif (quantizer.rowwise_usage and inp._rowwise_data is None) or (
quantizer.columnwise_usage and inp._columnwise_data is None
......@@ -1587,7 +1587,7 @@ def gather_along_first_dim(
# Return immediately if no communication is required
world_size = get_distributed_world_size(process_group)
if world_size == 1:
if quantizer is not None and not isinstance(inp, QuantizedTensorBase):
if quantizer is not None and not isinstance(inp, QuantizedTensorStorage):
inp = quantizer(inp)
return inp, None
......@@ -1634,7 +1634,7 @@ def gather_along_first_dim(
out_shape[0] *= world_size
# FP8 case: delayed scaling or current scaling
if isinstance(inp, Float8TensorBase) or isinstance(
if isinstance(inp, Float8TensorStorage) or isinstance(
quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
):
return _all_gather_fp8(
......@@ -1646,7 +1646,9 @@ def gather_along_first_dim(
)
# FP8 block scaling case, block length = 128
if isinstance(inp, Float8BlockwiseQTensorBase) or isinstance(quantizer, Float8BlockQuantizer):
if isinstance(inp, Float8BlockwiseQTensorStorage) or isinstance(
quantizer, Float8BlockQuantizer
):
return _all_gather_fp8_blockwise(
inp,
process_group,
......@@ -1656,7 +1658,7 @@ def gather_along_first_dim(
)
# MXFP8 case
if isinstance(inp, MXFP8TensorBase) or isinstance(quantizer, MXFP8Quantizer):
if isinstance(inp, MXFP8TensorStorage) or isinstance(quantizer, MXFP8Quantizer):
assert isinstance(quantizer, MXFP8Quantizer)
return _all_gather_mxfp8(
inp,
......@@ -1667,7 +1669,7 @@ def gather_along_first_dim(
)
# NVFP4 case
if isinstance(inp, NVFP4TensorBase) or isinstance(quantizer, NVFP4Quantizer):
if isinstance(inp, NVFP4TensorStorage) or isinstance(quantizer, NVFP4Quantizer):
assert isinstance(quantizer, NVFP4Quantizer)
return _all_gather_nvfp4(
inp,
......@@ -1683,7 +1685,7 @@ def gather_along_first_dim(
"Attempting to all-gather an unsupported quantized tensor. "
"Falling back to high-precision all-gather."
)
if isinstance(inp, QuantizedTensorBase):
if isinstance(inp, QuantizedTensorStorage):
inp = inp.dequantize()
# Falling back to high-precision all-gather for Float8BlockQuantizer
# means that it should directly output GEMM_READY format
......@@ -1701,7 +1703,7 @@ def gather_along_first_dim(
return out, None
# Dequantize quantized tensor if not supported
if isinstance(inp, QuantizedTensorBase):
if isinstance(inp, QuantizedTensorStorage):
warnings.warn(
"Attempting to all-gather an unsupported quantized tensor. "
"Falling back to high-precision all-gather."
......
......@@ -13,7 +13,7 @@ from typing import Iterable, Optional, Tuple, Union
import torch
from transformer_engine.common.recipe import Recipe
from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensorBase, Quantizer
from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensorStorage, Quantizer
from transformer_engine.pytorch.experimental import utils
......@@ -36,7 +36,7 @@ class MMParams:
@dataclasses.dataclass
class ExperimentalQuantizedTensor(QuantizedTensorBase):
class ExperimentalQuantizedTensor(QuantizedTensorStorage):
"""Base class for experimental quantized tensor containers.
An experimental container to hold quantization result, including quantized tensor, optional
......@@ -187,7 +187,7 @@ class ExperimentalQuantizer(Quantizer):
*,
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
) -> QuantizedTensorBase:
) -> QuantizedTensorStorage:
raise NotImplementedError(
f"{self.__class__.__name__} class does not implement make_empty function"
)
......
......@@ -22,6 +22,7 @@ from transformer_engine.common.recipe import (
Float8CurrentScaling,
Float8BlockScaling,
NVFP4BlockScaling,
CustomRecipe,
)
from .constants import dist_group_type
......@@ -866,6 +867,8 @@ class RecipeState(abc.ABC):
cls = Float8BlockScalingRecipeState
elif recipe.nvfp4():
cls = NVFP4BlockScalingRecipeState
elif recipe.custom():
cls = CustomRecipeState
else:
raise ValueError(f"{recipe.__class__.__name__} is not supported")
return cls(
......@@ -1191,3 +1194,56 @@ class NVFP4BlockScalingRecipeState(RecipeState):
]
raise RuntimeError(f"Unexpected recipe mode ({self.mode})")
class CustomRecipeState(RecipeState):
"""State for CustomRecipe: produce quantizers per tensor."""
recipe: CustomRecipe
mode: str
num_quantizers: int
device: Optional[torch.device]
def __init__(
self,
recipe: CustomRecipe,
*,
mode: str,
num_quantizers: int = 1,
device: Optional[torch.device] = None,
) -> None:
self.recipe = recipe
self.mode = mode
self.num_quantizers = num_quantizers
if device is None:
device = torch.device("cuda")
self.device = device
if getattr(recipe, "qfactory", None) is None:
raise ValueError("CustomRecipe requires `qfactory`.")
def make_quantizers(self) -> list:
qfactory = self.recipe.qfactory
out = []
# TODO(negvet): make_quantizers() should take roles from the operation
# Hardcode linear-specific roles for now
roles: List[str]
if self.mode == "forward":
roles = [
("linear_input", "linear_weight", "linear_output")[i % 3]
for i in range(self.num_quantizers)
]
elif self.mode == "backward":
roles = [
("linear_grad_output", "linear_grad_input")[i % 2]
for i in range(self.num_quantizers)
]
else:
roles = ["unknown"] * self.num_quantizers
for i in range(self.num_quantizers):
# Get quantizer from the user defined factory
quantizer = qfactory(roles[i])
out.append(quantizer)
return out
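
As a worked illustration of the cycling above (the role lists follow directly from the modulo indexing; the larger num_quantizers value is a hypothetical request, e.g. from a module with several GEMMs):

# mode="forward",  num_quantizers=3 -> ["linear_input", "linear_weight", "linear_output"]
# mode="backward", num_quantizers=2 -> ["linear_grad_output", "linear_grad_input"]
# A hypothetical forward request for 6 quantizers repeats the same roles in order:
roles = [("linear_input", "linear_weight", "linear_output")[i % 3] for i in range(6)]
assert roles == ["linear_input", "linear_weight", "linear_output"] * 2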
......@@ -38,15 +38,15 @@ from ..distributed import (
_fsdp_gather_tensors,
)
from ..constants import dist_group_type
from ..tensor.quantized_tensor import QuantizedTensor, QuantizedTensorBase, Quantizer
from ..tensor.quantized_tensor import QuantizedTensor, QuantizedTensorStorage, Quantizer
from ..tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
from ..tensor.nvfp4_tensor import NVFP4Quantizer
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ..tensor._internal.float8_tensor_base import Float8TensorBase
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from ..tensor.storage.float8_tensor_storage import Float8TensorStorage
from ..tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage
from ..utils import is_non_tn_fp8_gemm_supported, torch_get_autocast_gpu_dtype
from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage
from ...common.recipe import DelayedScaling, Recipe
from ...debug.pytorch.debug_state import TEDebugState
from ...debug.pytorch.debug_quantization import DebugQuantizer, DebugQuantizedTensor
......@@ -505,7 +505,7 @@ def fill_userbuffers_buffer_for_all_gather(
local_tensor: torch.Tensor,
quantizer: Optional[Quantizer],
process_group,
) -> tuple[torch.Tensor | QuantizedTensorBase, torch.Tensor | QuantizedTensorBase]:
) -> tuple[torch.Tensor | QuantizedTensorStorage, torch.Tensor | QuantizedTensorStorage]:
"""Fill local shard of Userbuffers buffer with data for all-gather
Returns the full tensor and the local shard, both using the
......@@ -529,7 +529,7 @@ def fill_userbuffers_buffer_for_all_gather(
# Unquantized data
if quantizer is None:
if isinstance(local_tensor, QuantizedTensorBase):
if isinstance(local_tensor, QuantizedTensorStorage):
local_tensor = local_tensor.dequantize()
if comm.is_fp8_ubuf():
raise RuntimeError(
......@@ -542,8 +542,8 @@ def fill_userbuffers_buffer_for_all_gather(
# FP8 data
if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
if not isinstance(local_tensor, Float8TensorBase):
if isinstance(local_tensor, QuantizedTensorBase):
if not isinstance(local_tensor, Float8TensorStorage):
if isinstance(local_tensor, QuantizedTensorStorage):
local_tensor.dequantize()
quantizer.set_usage(rowwise=True, columnwise=False)
local_tensor = quantizer(local_tensor)
......@@ -554,7 +554,7 @@ def fill_userbuffers_buffer_for_all_gather(
)
comm.copy_into_buffer(local_tensor._data, local_chunk=True)
global_tensor_data = comm.get_buffer(shape=global_shape)
global_tensor = Float8TensorBase(
global_tensor = Float8TensorStorage(
data=global_tensor_data,
fp8_scale_inv=local_tensor._scale_inv,
fp8_dtype=local_tensor._fp8_dtype,
......@@ -566,8 +566,8 @@ def fill_userbuffers_buffer_for_all_gather(
if isinstance(quantizer, MXFP8Quantizer):
# Cast to MXFP8 if needed
if not isinstance(local_tensor, MXFP8TensorBase):
if isinstance(local_tensor, QuantizedTensorBase):
if not isinstance(local_tensor, MXFP8TensorStorage):
if isinstance(local_tensor, QuantizedTensorStorage):
local_tensor.dequantize()
local_tensor = quantizer(local_tensor)
if not comm.is_fp8_ubuf():
......@@ -622,7 +622,7 @@ def fill_userbuffers_buffer_for_all_gather(
rowwise_data, rowwise_scale_inv = global_data, global_scale_inv
else:
columnwise_data, columnwise_scale_inv = global_data, global_scale_inv
global_tensor = MXFP8TensorBase(
global_tensor = MXFP8TensorStorage(
rowwise_data=rowwise_data,
rowwise_scale_inv=rowwise_scale_inv,
columnwise_data=columnwise_data,
......@@ -786,10 +786,10 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
f"({len(weight_quantizers)}) must match"
)
for weight, quantizer in zip(weight_tensors, weight_quantizers):
if quantizer is not None and isinstance(weight, QuantizedTensorBase):
if quantizer is not None and isinstance(weight, QuantizedTensorStorage):
weight.update_quantizer(quantizer)
def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]:
def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage]]:
"""Get the weight tensors of the module."""
raise NotImplementedError(
f"{self.__class__.__name__} class does not implement _get_weight_tensors function"
......@@ -1038,6 +1038,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
self.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group()
# Set FP8_MAX per tensor according to recipe
if hasattr(self.fp8_meta["recipe"], "fp8_format"):
self.fp8_meta["fp8_max_fwd"] = self.fp8_meta["recipe"].fp8_format.value.max_fwd
self.fp8_meta["fp8_max_bwd"] = self.fp8_meta["recipe"].fp8_format.value.max_bwd
......@@ -1170,9 +1171,9 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
grad_output,
(
QuantizedTensor,
Float8TensorBase,
MXFP8TensorBase,
Float8BlockwiseQTensorBase,
Float8TensorStorage,
MXFP8TensorStorage,
Float8BlockwiseQTensorStorage,
),
):
grad_output = quantizer(grad_output)
......@@ -1201,9 +1202,9 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
grad_output_.get_tensor(True),
(
QuantizedTensor,
Float8TensorBase,
MXFP8TensorBase,
Float8BlockwiseQTensorBase,
Float8TensorStorage,
MXFP8TensorStorage,
Float8BlockwiseQTensorStorage,
),
)
and ctx.use_bias
......@@ -1219,7 +1220,12 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
if ctx.use_bias:
if isinstance(
grad_output,
(QuantizedTensor, Float8TensorBase, MXFP8TensorBase, Float8BlockwiseQTensorBase),
(
QuantizedTensor,
Float8TensorStorage,
MXFP8TensorStorage,
Float8BlockwiseQTensorStorage,
),
):
grad_bias = grad_output.dequantize().view(-1, grad_output.shape[-1]).sum(dim=0)
else:
......@@ -1229,7 +1235,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0)
else:
grad_bias, grad_output = tex.bgrad_quantize(grad_output, quantizer)
if not isinstance(grad_output, QuantizedTensorBase):
if not isinstance(grad_output, QuantizedTensorStorage):
grad_output = quantizer(grad_output)
return grad_output, grad_bias
......@@ -1383,14 +1389,14 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
# Reset cache if workspace is invalid
if out is not None and quantizer is not None:
reset_cache = False
if isinstance(out, Float8TensorBase):
if isinstance(out, Float8TensorStorage):
if (
not is_non_tn_fp8_gemm_supported()
and quantizer.columnwise_usage
and out._transpose is None
):
reset_cache = True
elif isinstance(out, MXFP8TensorBase):
elif isinstance(out, MXFP8TensorStorage):
if quantizer.rowwise_usage and out._rowwise_data is None:
reset_cache = True
elif quantizer.columnwise_usage and out._columnwise_data is None:
......@@ -1581,7 +1587,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
recipe = self.fp8_meta["recipe"]
weight_tensors = [getattr(self, name) for name in self.weight_names]
for i, tensor in enumerate(weight_tensors):
if isinstance(tensor, QuantizedTensorBase):
if isinstance(tensor, QuantizedTensorStorage):
quantizer = tensor._get_quantizer()
if quantizer is None:
continue
......
......@@ -44,7 +44,7 @@ from ..cpu_offload import is_cpu_offload_enabled
from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
from ..tensor.quantized_tensor import (
QuantizedTensorBase,
QuantizedTensorStorage,
Quantizer,
prepare_for_saving,
restore_from_saved,
......@@ -200,13 +200,13 @@ class _GroupedLinear(torch.autograd.Function):
inputmats[0] = inp
else:
for inputmat in inputmats:
if isinstance(inputmat, QuantizedTensorBase):
if isinstance(inputmat, QuantizedTensorStorage):
inputmat.update_usage(rowwise_usage=False, columnwise_usage=True)
else:
inputmats = [None] * num_gemms
if inp.requires_grad:
for weight in weights_fp8:
if isinstance(weight, QuantizedTensorBase):
if isinstance(weight, QuantizedTensorStorage):
weight.update_usage(columnwise_usage=True)
tensors_to_save, tensor_objects = prepare_for_saving(
......@@ -338,7 +338,7 @@ class _GroupedLinear(torch.autograd.Function):
)
for weight, quantizer in zip(weights, ctx.weight_quantizers):
if quantizer is not None and isinstance(weight, QuantizedTensorBase):
if quantizer is not None and isinstance(weight, QuantizedTensorStorage):
weight.update_usage(
rowwise_usage=quantizer.rowwise_usage,
columnwise_usage=quantizer.columnwise_usage,
......@@ -734,7 +734,7 @@ class GroupedLinear(TransformerEngineBaseModule):
produced)
"""
assert not isinstance(
inp, QuantizedTensorBase
inp, QuantizedTensorStorage
), "GroupedLinear doesn't support input tensor in FP8."
assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs."
......@@ -868,16 +868,17 @@ class GroupedLinear(TransformerEngineBaseModule):
self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["bwd"]
].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon
def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]:
def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage]]:
"""Get the weight tensors of the module."""
weight_tensors = [getattr(self, f"weight{i}") for i in range(self.num_gemms)]
if not self.fp8 and any(isinstance(w, QuantizedTensorBase) for w in weight_tensors):
if not self.fp8 and any(isinstance(w, QuantizedTensorStorage) for w in weight_tensors):
warnings.warn(
"You are using quantized weights without quantized compute. "
"Please make sure this is intentional."
)
weight_tensors = [
w.dequantize() if isinstance(w, QuantizedTensorBase) else w for w in weight_tensors
w.dequantize() if isinstance(w, QuantizedTensorStorage) else w
for w in weight_tensors
]
return weight_tensors
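The _get_weight_tensors hunk above dequantizes stored FP8 weights when quantized compute is disabled, after emitting a warning. A minimal standalone sketch of that fallback (helper name is illustrative; QuantizedTensorStorage assumed imported as in this file):

import warnings

def _dequantize_weights_if_needed(weight_tensors, fp8_enabled: bool):
    # When quantized compute is disabled, stored quantized weights are
    # dequantized to high precision so the module can still run.
    if not fp8_enabled and any(isinstance(w, QuantizedTensorStorage) for w in weight_tensors):
        warnings.warn(
            "You are using quantized weights without quantized compute. "
            "Please make sure this is intentional."
        )
        weight_tensors = [
            w.dequantize() if isinstance(w, QuantizedTensorStorage) else w
            for w in weight_tensors
        ]
    return weight_tensors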
......
......@@ -58,7 +58,7 @@ from ..graph import is_graph_capturing
from ._common import apply_normalization, noop_cat, WeightGradStore, get_module_quantizers
from ..tensor.quantized_tensor import (
QuantizedTensor,
QuantizedTensorBase,
QuantizedTensorStorage,
Quantizer,
prepare_for_saving,
restore_from_saved,
......@@ -66,8 +66,8 @@ from ..tensor.quantized_tensor import (
from ...debug.pytorch.debug_state import TEDebugState
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage
from ..tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage
from ..export import is_in_onnx_export_mode, assert_warmed_up
from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
......@@ -200,7 +200,7 @@ class _LayerNormLinear(torch.autograd.Function):
and not debug
and not return_layernorm_output
and not return_layernorm_output_gathered
and not experimental
and not experimental # TODO(negvet): and not FP8GlobalStateManager.get_fp8_recipe().custom()
)
# Apply normalization
......@@ -278,7 +278,7 @@ class _LayerNormLinear(torch.autograd.Function):
weightmat = weight
quantized_weight = False
if fp8 or debug:
quantized_weight = not isinstance(weight, QuantizedTensorBase)
quantized_weight = not isinstance(weight, QuantizedTensorStorage)
# Configure quantizer
if weight_quantizer is not None:
......@@ -403,18 +403,18 @@ class _LayerNormLinear(torch.autograd.Function):
# Input with column-wise usage is needed for wgrad GEMM.
if backward_needs_input:
if isinstance(ln_out, QuantizedTensorBase):
if isinstance(ln_out, QuantizedTensorStorage):
# For sequence parallel in vanilla FP8, rowwise data is needed
# to gather the input. For MXFP8, only columnwise data
# can be allgathered.
if (
isinstance(ln_out, (MXFP8TensorBase, Float8BlockwiseQTensorBase))
isinstance(ln_out, (MXFP8TensorStorage, Float8BlockwiseQTensorStorage))
or not ctx.ln_out_needs_gather
):
ln_out.update_usage(rowwise_usage=False)
# Weight with column-wise usage is needed for dgrad GEMM.
if isinstance(weightmat, QuantizedTensorBase):
if isinstance(weightmat, QuantizedTensorStorage):
weightmat.update_usage(columnwise_usage=True)
if cpu_offloading:
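The usage pruning above follows one rule: the saved input keeps only the copies that backward will actually consume, while the saved weight always needs a column-wise copy for the dgrad GEMM. A sketch of that logic pulled out of the hunk (helper name is illustrative; storage classes assumed imported as in this file):

def prune_saved_tensor_usages(ln_out, weightmat, needs_gather: bool):
    # The saved input's row-wise copy exists only so vanilla-FP8 inputs can be
    # all-gathered; MXFP8 and blockwise tensors all-gather column-wise data,
    # so in those cases (or when no gather is needed) the row-wise copy is dropped.
    if isinstance(ln_out, QuantizedTensorStorage):
        if (
            isinstance(ln_out, (MXFP8TensorStorage, Float8BlockwiseQTensorStorage))
            or not needs_gather
        ):
            ln_out.update_usage(rowwise_usage=False)
    # The saved weight needs a column-wise copy for the dgrad GEMM.
    if isinstance(weightmat, QuantizedTensorStorage):
        weightmat.update_usage(columnwise_usage=True)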
......@@ -685,9 +685,9 @@ class _LayerNormLinear(torch.autograd.Function):
# --------------------------------------------------
# Make sure required data is available
if isinstance(grad_output, QuantizedTensorBase):
if isinstance(grad_output, QuantizedTensorStorage):
grad_output.update_usage(rowwise_usage=True)
if ctx.weight_quantizer is not None and isinstance(weight, QuantizedTensorBase):
if ctx.weight_quantizer is not None and isinstance(weight, QuantizedTensorStorage):
weight.update_usage(columnwise_usage=True)
# Choose whether to use GEMM kernel with split accumulator
......@@ -806,14 +806,14 @@ class _LayerNormLinear(torch.autograd.Function):
ln_out_total_work.wait()
ln_out_total_work = None
if ctx.fp8 or ctx.debug:
if isinstance(ln_out_total, QuantizedTensorBase):
if isinstance(ln_out_total, QuantizedTensorStorage):
ln_out_total.update_usage(columnwise_usage=True)
else:
ctx.input_quantizer.set_usage(rowwise=False, columnwise=True)
ln_out_total = ctx.input_quantizer(ln_out_total)
if ctx.fp8 or ctx.debug:
if isinstance(grad_output, QuantizedTensorBase):
if isinstance(grad_output, QuantizedTensorStorage):
grad_output.update_usage(columnwise_usage=True)
else:
ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
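The backward pass repeats the same pattern in several places: if a tensor is already quantized, extend its usage in place with update_usage(); otherwise configure the quantizer for exactly the usage needed and quantize now. A small sketch of that pattern as a helper (name is illustrative):

def ensure_columnwise(tensor, quantizer):
    # Already-quantized tensors just materialize their column-wise copy;
    # plain tensors are quantized with column-wise-only usage for the wgrad GEMM.
    if isinstance(tensor, QuantizedTensorStorage):
        tensor.update_usage(columnwise_usage=True)
        return tensor
    quantizer.set_usage(rowwise=False, columnwise=True)
    return quantizer(tensor)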
......@@ -999,7 +999,7 @@ class _LayerNormLinear(torch.autograd.Function):
nvtx_range_pop(f"{nvtx_label}.reduce_and_update_fp8_tensors")
# Scatter fp8 weight buffers
# if ctx.fp8 and not isinstance(weight, QuantizedTensorBase):
# if ctx.fp8 and not isinstance(weight, QuantizedTensorStorage):
# _fsdp_scatter_tensors(ctx.fsdp_group, weight_fp8)
return (
......@@ -1790,7 +1790,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
tex.FP8BwdTensors.GRAD_OUTPUT1
].amax_reduction_group = self.tp_group
def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]:
def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage]]:
"""Get the weight tensors of the module."""
unfused_weights = [getattr(self, name) for name in self.weight_names]
if any(isinstance(w, QuantizedTensor) for w in unfused_weights):
......
......@@ -71,7 +71,7 @@ from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ._common import apply_normalization, WeightGradStore
from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ..tensor.quantized_tensor import (
QuantizedTensorBase,
QuantizedTensorStorage,
Quantizer,
prepare_for_saving,
restore_from_saved,
......@@ -116,9 +116,14 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None):
"swiglu": (tex.swiglu, tex.dswiglu, None),
}
# no activation fusion written yet
# Per-tensor current scaling or fp8 blockwise scaling: []
# Per-tensor current scaling or fp8 blockwise scaling or custom quantization: []
# TODO(ksivaman): Fuse nvfp4 act once kernel is available.
if recipe.float8_current_scaling() or recipe.float8_block_scaling() or recipe.nvfp4():
if (
recipe.float8_current_scaling()
or recipe.float8_block_scaling()
or recipe.nvfp4()
or recipe.custom()
):
return {
"gelu": (tex.gelu, tex.dgelu, None),
"geglu": (tex.geglu, tex.dgeglu, None),
......@@ -448,10 +453,18 @@ class _LayerNormMLP(torch.autograd.Function):
act_out = fc2_input_quantizer(act_out)
else:
fc1_out, *_ = fc1_outputs
if fp8 and FP8GlobalStateManager.get_fp8_recipe().float8_block_scaling():
# tex.quantize does not support GELU fusion for blockwise.
if fp8:
recipe = FP8GlobalStateManager.get_fp8_recipe()
if recipe.float8_block_scaling():
# tex.quantize does not support GELU fusion for blockwise
act_out = activation_func(fc1_out, None)
act_out = tex.quantize(act_out, fc2_input_quantizer)
elif recipe.custom():
# tex.quantize does not support custom quantizers
act_out = activation_func(fc1_out, None)
act_out = fc2_input_quantizer(act_out)
else:
act_out = activation_func(fc1_out, fc2_input_quantizer)
else:
if fp8_calibration:
act_out = activation_func(fc1_out, None)
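The new branch above distinguishes three quantization paths for FC1's activation output. A condensed sketch of that dispatch, assuming `activation_func` accepts an optional quantizer as its second argument and `tex` is the transformer_engine_torch extension module, as in this hunk (the wrapper name is illustrative):

def quantize_activation(activation_func, fc1_out, fc2_input_quantizer, recipe):
    if recipe.float8_block_scaling():
        # tex.quantize handles blockwise scaling but cannot fuse the activation.
        act_out = activation_func(fc1_out, None)
        return tex.quantize(act_out, fc2_input_quantizer)
    if recipe.custom():
        # Custom quantizers are applied through their __call__, not tex.quantize.
        act_out = activation_func(fc1_out, None)
        return fc2_input_quantizer(act_out)
    # Otherwise the activation kernel quantizes its output directly.
    return activation_func(fc1_out, fc2_input_quantizer)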
......@@ -522,9 +535,9 @@ class _LayerNormMLP(torch.autograd.Function):
if is_grad_enabled:
# Weight with column-wise usage is needed for dgrad GEMM.
if isinstance(fc1_weight_final, QuantizedTensorBase):
if isinstance(fc1_weight_final, QuantizedTensorStorage):
fc1_weight_final.update_usage(columnwise_usage=True)
if isinstance(fc2_weight_final, QuantizedTensorBase):
if isinstance(fc2_weight_final, QuantizedTensorStorage):
fc2_weight_final.update_usage(columnwise_usage=True)
if cpu_offloading:
......@@ -823,10 +836,10 @@ class _LayerNormMLP(torch.autograd.Function):
)
# Make sure required data is available
if isinstance(grad_output, QuantizedTensorBase):
if isinstance(grad_output, QuantizedTensorStorage):
grad_output.update_usage(rowwise_usage=True)
if ctx.fc2_weight_quantizer is not None and isinstance(
ctx.fc2_weight, QuantizedTensorBase
ctx.fc2_weight, QuantizedTensorStorage
):
ctx.fc2_weight.update_usage(columnwise_usage=True)
......@@ -908,14 +921,14 @@ class _LayerNormMLP(torch.autograd.Function):
# Note: Synchronize tensor-parallel communication and
# make sure required data is available
if ctx.fp8 or ctx.debug:
if isinstance(act_out, QuantizedTensorBase):
if isinstance(act_out, QuantizedTensorStorage):
act_out.update_usage(columnwise_usage=True)
else:
ctx.fc2_input_quantizer.set_usage(rowwise=False, columnwise=True)
act_out = ctx.fc2_input_quantizer(act_out)
if ctx.fp8 or ctx.debug:
if isinstance(grad_output, QuantizedTensorBase):
if isinstance(grad_output, QuantizedTensorStorage):
grad_output.update_usage(columnwise_usage=True)
else:
ctx.fc2_grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
......@@ -1023,10 +1036,13 @@ class _LayerNormMLP(torch.autograd.Function):
) # activation in high precision
if ctx.fp8:
# TODO float8 blockwise current scaling has no bgrad fusion for now
# TODO float8 blockwise current scaling (as well as custom quantizers) has no bgrad fusion for now
# TODO(ksivaman): Re-add fusion once kernel is available.
if isinstance(
if (
isinstance(
ctx.fc1_grad_output_quantizer, (Float8BlockQuantizer, NVFP4Quantizer)
)
or ctx.fp8_recipe.custom()
):
fc1_bias_grad = dact.view(-1, dact.shape[-1]).sum(dim=0)
dact = ctx.fc1_grad_output_quantizer(dact)
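With blockwise, NVFP4, or custom quantizers there is no fused bgrad+quantize kernel for FC1, so the hunk above reduces the bias gradient in high precision before quantizing the activation gradient. A minimal sketch of that fallback (function name is illustrative):

def fc1_bgrad_then_quantize(dact, fc1_grad_output_quantizer):
    # Reduce the bias gradient in high precision, then quantize dact for the
    # FC1 wgrad/dgrad GEMMs; the fused bgrad_quantize kernel is unavailable here.
    fc1_bias_grad = dact.view(-1, dact.shape[-1]).sum(dim=0)
    dact = fc1_grad_output_quantizer(dact)
    return fc1_bias_grad, dact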
......@@ -1072,7 +1088,7 @@ class _LayerNormMLP(torch.autograd.Function):
# Make sure required data is available
if ctx.fc1_weight_quantizer is not None and isinstance(
ctx.fc1_weight_quantizer, QuantizedTensorBase
ctx.fc1_weight_quantizer, QuantizedTensorStorage
):
ctx.fc1_weight.update_usage(columnwise_usage=True)
......@@ -1143,7 +1159,7 @@ class _LayerNormMLP(torch.autograd.Function):
ln_out_total_work.wait()
ln_out_total_work = None
if ctx.fp8 or ctx.debug:
if isinstance(ln_out_total, QuantizedTensorBase):
if isinstance(ln_out_total, QuantizedTensorStorage):
ln_out_total.update_usage(columnwise_usage=True)
else:
ctx.fc1_input_quantizer.set_usage(rowwise=False, columnwise=True)
......@@ -1153,7 +1169,7 @@ class _LayerNormMLP(torch.autograd.Function):
# Note: Synchronize tensor-parallel communication and
# make sure required data is available
if ctx.fp8 or ctx.debug:
if isinstance(dact, QuantizedTensorBase):
if isinstance(dact, QuantizedTensorStorage):
dact.update_usage(columnwise_usage=True)
else:
ctx.fc1_grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
......@@ -2153,7 +2169,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
tex.FP8BwdTensors.GRAD_OUTPUT2
].amax_reduction_group = self.tp_group
def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]:
def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage]]:
"""Get the weight tensors of the module."""
return [self.fc1_weight, self.fc2_weight]
......
......@@ -59,7 +59,7 @@ from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing
from ..tensor.quantized_tensor import (
QuantizedTensor,
QuantizedTensorBase,
QuantizedTensorStorage,
Quantizer,
prepare_for_saving,
restore_from_saved,
......@@ -178,7 +178,7 @@ class _Linear(torch.autograd.Function):
if fp8 or debug:
if input_quantizer is None:
raise ValueError("Missing quantizer for input tensor")
if not isinstance(inputmat, QuantizedTensorBase) and not experimental:
if not isinstance(inputmat, QuantizedTensorStorage) and not experimental:
own_quantized_input = True
input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input)
if isinstance(
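In _Linear's forward, a plain (non-quantized, non-experimental) input is quantized by the module itself: row-wise usage is always needed for the forward GEMM, and column-wise usage is requested up front only when backward will reuse the input for wgrad. A sketch of that decision, pulled from this hunk and the one below (helper name is illustrative):

def quantize_forward_input(inputmat, input_quantizer, backward_needs_input: bool):
    if isinstance(inputmat, QuantizedTensorStorage):
        # Already quantized by the caller; just make sure row-wise data exists.
        inputmat.update_usage(rowwise_usage=True)
        return inputmat
    # The forward GEMM consumes the row-wise copy; the column-wise copy is only
    # produced now if the wgrad GEMM will need this input in backward.
    input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input)
    return input_quantizer(inputmat)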
......@@ -216,7 +216,7 @@ class _Linear(torch.autograd.Function):
else: # Do not all-gather input tensor
if fp8 or debug:
if isinstance(inputmat, QuantizedTensorBase):
if isinstance(inputmat, QuantizedTensorStorage):
inputmat.update_usage(rowwise_usage=True)
else:
if input_quantizer is None:
......@@ -372,7 +372,7 @@ class _Linear(torch.autograd.Function):
if (
backward_needs_input
and own_quantized_input
and isinstance(inputmat, QuantizedTensorBase)
and isinstance(inputmat, QuantizedTensorStorage)
):
if (
ctx.backward_input_needs_gather
......@@ -391,7 +391,7 @@ class _Linear(torch.autograd.Function):
# Weight with column-wise usage is needed for dgrad GEMM.
if inp.requires_grad:
if isinstance(weightmat, QuantizedTensorBase):
if isinstance(weightmat, QuantizedTensorStorage):
weightmat.update_usage(columnwise_usage=True)
if cpu_offloading and saved_inputmat is not None:
......@@ -404,7 +404,7 @@ class _Linear(torch.autograd.Function):
ctx.fsdp_shapes = _fsdp_scatter_tensors(
fsdp_group,
saved_inputmat,
weightmat if fp8 and not isinstance(weight, QuantizedTensorBase) else None,
weightmat if fp8 and not isinstance(weight, QuantizedTensorStorage) else None,
)
nvtx_range_pop(f"{nvtx_label}.fsdp_scatter")
......@@ -613,7 +613,7 @@ class _Linear(torch.autograd.Function):
inputmat_total_work = None
if ctx.requires_wgrad:
if ctx.fp8 or ctx.debug:
if isinstance(inputmat, QuantizedTensorBase):
if isinstance(inputmat, QuantizedTensorStorage):
# Input tensor is already quantized
pass
elif ctx.debug or ctx.experimental:
......@@ -632,7 +632,7 @@ class _Linear(torch.autograd.Function):
quantizer.set_usage(rowwise=False, columnwise=True)
inputmat = quantizer(inputmat)
else:
if isinstance(inputmat, QuantizedTensorBase):
if isinstance(inputmat, QuantizedTensorStorage):
inputmat = inputmat.dequantize(dtype=ctx.activation_dtype)
else:
inputmat = cast_if_needed(inputmat, ctx.activation_dtype)
......@@ -677,9 +677,11 @@ class _Linear(torch.autograd.Function):
if ctx.requires_dgrad:
# Make sure required data is available
if isinstance(grad_output, QuantizedTensorBase):
if isinstance(grad_output, QuantizedTensorStorage):
grad_output.update_usage(rowwise_usage=True)
if ctx.weight_quantizer is not None and isinstance(weight_fp8, QuantizedTensorBase):
if ctx.weight_quantizer is not None and isinstance(
weight_fp8, QuantizedTensorStorage
):
weight_fp8.update_usage(columnwise_usage=True)
# Choose whether to use GEMM kernel with split accumulator
......@@ -763,7 +765,7 @@ class _Linear(torch.autograd.Function):
inputmat_total_work.wait()
inputmat_total_work = None
if ctx.fp8 or ctx.debug:
if isinstance(inputmat_total, QuantizedTensorBase):
if isinstance(inputmat_total, QuantizedTensorStorage):
inputmat_total.update_usage(columnwise_usage=True)
else:
ctx.input_quantizer.set_usage(rowwise=False, columnwise=True)
......@@ -805,7 +807,7 @@ class _Linear(torch.autograd.Function):
)
if ctx.fp8 or ctx.debug:
if isinstance(grad_output, QuantizedTensorBase):
if isinstance(grad_output, QuantizedTensorStorage):
grad_output.update_usage(columnwise_usage=True)
else:
ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
......@@ -958,7 +960,7 @@ class _Linear(torch.autograd.Function):
nvtx_range_pop(f"{nvtx_label}.reduce_and_update_fp8_tensors")
# Scatter fp8 weight buffers
if ctx.fp8 and not isinstance(weight, QuantizedTensorBase):
if ctx.fp8 and not isinstance(weight, QuantizedTensorStorage):
_fsdp_scatter_tensors(ctx.fsdp_group, weight_fp8)
return (
wgrad,
......@@ -1524,7 +1526,7 @@ class Linear(TransformerEngineBaseModule):
for name, q in zip(names, original_quantizers)
)
def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]:
def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage]]:
"""Get the weight tensors of the module."""
unfused_weights = [getattr(self, name) for name in self.weight_names]
if any(isinstance(w, QuantizedTensor) for w in unfused_weights):
......
......@@ -13,17 +13,17 @@ from transformer_engine_torch import FP8TensorMeta
from .. import torch_version
from ..fp8 import FP8GlobalStateManager
from ..tensor.float8_tensor import Float8Tensor
from ..tensor.quantized_tensor import QuantizedTensorBase
from ..tensor.quantized_tensor import QuantizedTensorStorage
from ..utils import canonicalize_dtype
def is_quantized_tensor(tensor: torch.Tensor | QuantizedTensorBase) -> bool:
def is_quantized_tensor(tensor: torch.Tensor | QuantizedTensorStorage) -> bool:
"""Check if tensor is a quantized tensor"""
return isinstance(tensor, QuantizedTensorBase)
return isinstance(tensor, QuantizedTensorStorage)
def maybe_dequantize(
tensor: torch.Tensor | QuantizedTensorBase, dtype: torch.dtype | None = None
tensor: torch.Tensor | QuantizedTensorStorage, dtype: torch.dtype | None = None
) -> torch.Tensor:
"""Dequantize tensor to given dtype or just convert if not a quantized tensor"""
if is_quantized_tensor(tensor):
......
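The _common helpers above reduce to an isinstance check against the renamed storage class plus a dequantize-or-cast. A short usage sketch, assuming is_quantized_tensor and maybe_dequantize are imported from that _common module (its exact package path is not visible in this hunk) and relying only on the docstring's behaviour for plain tensors:

import torch

x = torch.randn(8, 16, dtype=torch.bfloat16)

# Plain tensors pass straight through: no dequantization, only a dtype conversion.
assert not is_quantized_tensor(x)
y = maybe_dequantize(x, torch.float32)
assert y.dtype == torch.float32

# For a QuantizedTensorStorage (e.g. the output of any Quantizer.__call__),
# is_quantized_tensor() returns True and maybe_dequantize() materializes a
# high-precision torch.Tensor from the quantized data.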