Unverified Commit 3a298e6b authored by Phuong Nguyen, committed by GitHub

[JAX] TensorUsage + FP8 GEMM with all layouts handling on BW (#1844)



* TensorUsage + FP8 GEMM with all layouts handling on BW
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>


---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent ae572af0
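
In short, this change replaces the layout-specific get_rowwise_tensor()/get_colwise_tensor() accessors with a single get_tensor(usage=...) call driven by a new TensorUsage enum, and lets FP8 GEMM on Blackwell consume operands in any layout. A minimal sketch of the new call pattern follows; the quantizer setup, the DELAYED_TENSOR_SCALING mode name, and the shapes are illustrative assumptions, not taken from this diff.

import jax.numpy as jnp
from transformer_engine.jax.quantize import QuantizerFactory, ScalingMode, TensorUsage

# Hypothetical setup: DELAYED_TENSOR_SCALING is an assumed member name.
quantizer_set = QuantizerFactory.create_set(scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING)
x = jnp.ones((128, 256), dtype=jnp.bfloat16)
casted_x = quantizer_set.x.quantize(x)

# Before this PR: casted_x.get_rowwise_tensor() / casted_x.get_colwise_tensor().
# Now the scaling mode decides which quantized copy a given usage maps to,
# so callers only state how the tensor will be consumed by the GEMM.
lhs_fwd = casted_x.get_tensor(usage=TensorUsage.LHS)        # forward GEMM operand
lhs_bwd = casted_x.get_tensor(usage=TensorUsage.LHS_TRANS)  # saved in ctx for wgrad
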
......@@ -109,8 +109,8 @@ def assert_dequantized_scaled_tensor(a: ScaledTensor, b: jnp.ndarray):
else:
assert_allclose(a.dequantize(), b, dtype=a.data.dtype)
elif isinstance(a, ScaledTensor2x):
assert_dequantized_scaled_tensor(a.get_rowwise_tensor(), b)
assert_dequantized_scaled_tensor(a.get_colwise_tensor(), b)
assert_dequantized_scaled_tensor(a.rowwise_tensor, b)
assert_dequantized_scaled_tensor(a.colwise_tensor, b)
else:
pytest.fail("a must be a ScaledTensor object")
......@@ -139,10 +139,10 @@ def assert_dequantized_grouped_scaled_tensor(
dq_a_i = dq_a_i.reshape(b_i.shape)
assert_allclose(dq_a_i, b_i, dtype=a.data.dtype)
elif isinstance(a, ScaledTensor2x):
assert isinstance(a.get_rowwise_tensor(), GroupedScaledTensor1x)
assert isinstance(a.get_colwise_tensor(), GroupedScaledTensor1x)
assert_dequantized_grouped_scaled_tensor(a.get_rowwise_tensor(), b)
assert_dequantized_grouped_scaled_tensor(a.get_colwise_tensor(), b)
assert isinstance(a.rowwise_tensor, GroupedScaledTensor1x)
assert isinstance(a.colwise_tensor, GroupedScaledTensor1x)
assert_dequantized_grouped_scaled_tensor(a.rowwise_tensor, b)
assert_dequantized_grouped_scaled_tensor(a.colwise_tensor, b)
else:
pytest.fail("a must be a GroupedScaledTensor object")
......
......@@ -24,10 +24,11 @@ from ..quantize import (
QuantizerSet,
QuantizeLayout,
noop_quantizer_set,
is_fp8_gemm_with_all_layouts_supported,
)
__all__ = ["gemm", "grouped_gemm", "is_gemm_with_all_layouts_supported"]
__all__ = ["gemm", "grouped_gemm"]
num_cublas_streams = get_num_compute_streams()
......@@ -40,11 +41,6 @@ def get_cublas_workspace_size_bytes() -> None:
return 4_194_304
def is_gemm_with_all_layouts_supported() -> False:
"""Return True if using blackwell, False otherwise."""
return get_device_compute_capability(0) >= 100
class GroupedGemmPrimitive(BasePrimitive):
"""
Primitive for grouped GEMM
......@@ -338,10 +334,15 @@ def _jax_gemm(
if not isinstance(lhs, ScaledTensor) and not isinstance(rhs, ScaledTensor):
if quantizer_set != noop_quantizer_set:
assert type(quantizer_set.x) is type(quantizer_set.kernel)
if (
quantizer_set.x.scaling_mode.is_tensor_scaling()
and is_fp8_gemm_with_all_layouts_supported()
):
lhs_is_rowwise = rhs_is_rowwise = True
else:
(((lhs_contract_dim,), (rhs_contract_dim,)), _) = dim_nums
lhs_is_rowwise = lhs_contract_dim == lhs.ndim - 1
rhs_is_rowwise = rhs_contract_dim == rhs.ndim - 1
# Call JAX quantization so that XLA can do pattern matching (QDQ --> FP8 gemm)
lhs_q = quantizer_set.x.quantize(
lhs,
is_rowwise=lhs_is_rowwise,
......@@ -491,16 +492,13 @@ def grouped_gemm(
assert type(quantizer_set.x) is type(quantizer_set.kernel)
scaling_mode = quantizer_set.x.scaling_mode
if (
# TODO(Phuong): we force Blackwell to also use NT layout for now, need to fix later
# scaling_mode.is_tensor_scaling()
# and is_gemm_with_all_layouts_supported()
scaling_mode.is_1d_block_scaling()
quantizer_set.x.scaling_mode.is_tensor_scaling()
and is_fp8_gemm_with_all_layouts_supported()
):
lhs_is_rowwise = True
rhs_is_rowwise = False
lhs_is_rowwise = rhs_is_rowwise = True
else:
lhs_is_rowwise = not lhs_is_trans
rhs_is_rowwise = lhs_is_trans
rhs_is_rowwise = rhs_is_trans
quantizer_set.x.q_layout = (
QuantizeLayout.ROWWISE if lhs_is_rowwise else QuantizeLayout.COLWISE
)
......@@ -515,6 +513,8 @@ def grouped_gemm(
rhs_data = rhs_q.data
lhs_scale_inv = lhs_q.scale_inv
rhs_scale_inv = rhs_q.scale_inv
lhs_shape = lhs_q.original_shape
rhs_shape = rhs_q.original_shape
assert not (
lhs_data.dtype == jnp.float8_e5m2 and rhs_data.dtype == jnp.float8_e5m2
......@@ -522,24 +522,35 @@ def grouped_gemm(
# Only support FP8 GEMM with NT layout on Hopper and other earlier GPUs
# thus additional transpose is required
# TODO(Phuong): we force Blackwell to also use NT layout for now, need to fix later
if scaling_mode.is_tensor_scaling(): # and not is_gemm_with_all_layouts_supported():
lhs_is_trans = False
rhs_is_trans = True
if scaling_mode.is_tensor_scaling() and not is_fp8_gemm_with_all_layouts_supported():
if isinstance(lhs, ScaledTensor) and isinstance(rhs, ScaledTensor):
lhs_layout_is_T = lhs.data_layout == "T"
rhs_layout_is_T = rhs.data_layout == "T"
else:
lhs_layout_is_T = lhs_q.data_layout == "T"
rhs_layout_is_T = rhs_q.data_layout == "T"
# we can't apply _shape_normalization on the grouped input
# thus we need to ensure that lhs is in N and rhs is in T
assert (
lhs_is_trans == lhs_layout_is_T
), "lhs input must be transposed before calling grouped_gemm"
assert (
not rhs_is_trans == rhs_layout_is_T
), "rhs input must be transposed before calling grouped_gemm"
lhs_is_trans = False
rhs_is_trans = True
lhs_ndim = len(lhs_shape)
rhs_ndim = len(rhs_shape)
if lhs_layout_is_T:
lhs_contract_dim = tuple((lhs_ndim - 1 - i) % lhs_ndim for i in lhs_contract_dim)
if rhs_layout_is_T:
# For rhs [G, K, N], need to exclude the G dim from contract_dim
if group_sizes.size == rhs_shape[0]:
rhs_contract_dim = tuple(
(rhs_ndim - 1 - i) % (rhs_ndim - 1) + 1 for i in rhs_contract_dim
)
else:
rhs_contract_dim = tuple((rhs_ndim - 1 - i) % rhs_ndim for i in rhs_contract_dim)
lhs_data = _shape_normalization(lhs_data, (lhs_contract_dim, ()), not lhs_layout_is_T)
rhs_data = _shape_normalization(rhs_data, (rhs_contract_dim, ()), rhs_layout_is_T)
# Calling GroupedGEMM Custom Call
K_lhs = math.prod(lhs_shape[i] for i in lhs_contract_dim)
......
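
The contracting-dimension remapping above can be sanity-checked with a small worked example (shapes are illustrative): when a grouped operand is stored in the transposed "T" layout, its contracting axis moves to the mirrored position, with the group axis G of the rhs excluded from the mapping.

# Worked check of the remapping used in grouped_gemm above; shapes are illustrative.
rhs_ndim = 3                 # rhs logical shape (G, K, N); stored as (G, N, K) when data_layout == "T"
rhs_contract_dim = (1,)      # K sits on axis 1 of the logical shape
remapped_rhs = tuple((rhs_ndim - 1 - i) % (rhs_ndim - 1) + 1 for i in rhs_contract_dim)
assert remapped_rhs == (2,)  # in the stored "T" layout, K lives on axis 2 while G stays on axis 0

lhs_ndim = 2                 # lhs logical shape (M, K); stored as (K, M) when data_layout == "T"
lhs_contract_dim = (1,)
remapped_lhs = tuple((lhs_ndim - 1 - i) % lhs_ndim for i in lhs_contract_dim)
assert remapped_lhs == (0,)  # K moves to axis 0
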
......@@ -19,6 +19,7 @@ from .quantize import (
QuantizerSet,
noop_quantizer_set,
with_sharding_constraint_by_logical_axes,
TensorUsage,
)
......@@ -105,8 +106,8 @@ def _dense_fwd_rule(x, kernel, bias, contracting_dims, input_axes, kernel_axes,
# GEMM NN
output = tex.gemm(
casted_x.get_rowwise_tensor(),
casted_kernel.get_colwise_tensor(),
casted_x.get_tensor(usage=TensorUsage.LHS),
casted_kernel.get_tensor(usage=TensorUsage.RHS),
(x_contracting_dims, k_contracting_dims),
)
......@@ -116,8 +117,8 @@ def _dense_fwd_rule(x, kernel, bias, contracting_dims, input_axes, kernel_axes,
output += jnp.reshape(bias, bias_new_shape)
ctx = (
casted_x.get_colwise_tensor() if quantizer_set.x.is_2x2x() else None,
casted_kernel.get_rowwise_tensor() if quantizer_set.kernel.is_2x2x() else None,
casted_x.get_tensor(usage=TensorUsage.LHS_TRANS),
casted_kernel.get_tensor(usage=TensorUsage.RHS_TRANS),
x.shape,
kernel.shape,
use_bias,
......@@ -138,8 +139,8 @@ def _dense_bwd_rule(
fwd_x_contracting_dims, fwd_k_contracting_dims = contracting_dims
(
colwise_casted_x,
rowwise_casted_kernel,
casted_x_lhs,
casted_kernel_rhs,
x_shape,
kernel_shape,
use_bias,
......@@ -161,8 +162,8 @@ def _dense_bwd_rule(
dim for dim in range(len(kernel_shape)) if dim not in fwd_k_contracting_dims
)
dgrad = tex.gemm(
casted_grad.get_rowwise_tensor(),
rowwise_casted_kernel,
casted_grad.get_tensor(usage=TensorUsage.LHS),
casted_kernel_rhs,
(g_contracting_dim, k_contracting_dim),
)
dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes)
......@@ -174,7 +175,9 @@ def _dense_bwd_rule(
)
wgrad = tex.gemm(
colwise_casted_x, casted_grad.get_colwise_tensor(), (x_contracting_dim, g_contracting_dim)
casted_x_lhs,
casted_grad.get_tensor(usage=TensorUsage.RHS),
(x_contracting_dim, g_contracting_dim),
)
wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes)
......@@ -287,13 +290,6 @@ def _grouped_dense_fwd_rule(
"and k_contracting_dims=(1,) for now, "
f"got {x_contracting_dims=} and {k_contracting_dims=}"
)
scaling_mode = quantizer_set.x.scaling_mode
if scaling_mode.is_tensor_scaling():
k_contracting_dims = (0,)
elif scaling_mode.is_1d_block_scaling():
k_contracting_dims = (1,)
else:
raise ValueError(f"Unsupported scaling mode {scaling_mode.value} for grouped_dense")
casted_x = tex.grouped_quantize(
x, quantizer_set.x, group_sizes, flatten_axis=flatten_axis_x
......@@ -306,11 +302,10 @@ def _grouped_dense_fwd_rule(
# For x_contracting_dims == (1,) and k_contracting_dims == (1,), we should have
# rowwise_casted_x.original_shape == (M, K)
# colwise_casted_kernel.original_shape == (G, N, K)
grouped_gemm_x = casted_x.get_rowwise_tensor()
grouped_gemm_kernel = casted_kernel.get_colwise_tensor()
# TODO(Hua): Shall we give warning/error if not quantizer_set.x.is_2x2x()?
ctx_x = casted_x.get_colwise_tensor() if quantizer_set.x.is_2x2x() else None
ctx_kernel = casted_kernel.get_rowwise_tensor() if quantizer_set.kernel.is_2x2x() else None
grouped_gemm_x = casted_x.get_tensor(usage=TensorUsage.LHS)
grouped_gemm_kernel = casted_kernel.get_tensor(usage=TensorUsage.RHS)
ctx_x = casted_x.get_tensor(usage=TensorUsage.LHS_TRANS)
ctx_kernel = casted_kernel.get_tensor(usage=TensorUsage.RHS_TRANS)
output = tex.grouped_gemm(
grouped_gemm_x,
......@@ -388,7 +383,7 @@ def _grouped_dense_bwd_rule(
g_contracting_dim = (1,)
k_contracting_dim = (2,)
dgrad_contracting_dims = (g_contracting_dim, k_contracting_dim)
dgrad_grad = casted_grad.get_rowwise_tensor()
dgrad_grad = casted_grad.get_tensor(usage=TensorUsage.LHS)
dgrad_kernel_T = ctx_kernel
# We need to use g_contracting_dim = (0,) and x_contracting_dim = (0,) to make it work
......@@ -398,7 +393,7 @@ def _grouped_dense_bwd_rule(
x_contracting_dim = (0,)
wgrad_contracting_dims = (x_contracting_dim, g_contracting_dim)
wgrad_x_T = ctx_x
wgrad_grad = casted_grad.get_colwise_tensor()
wgrad_grad = casted_grad.get_tensor(usage=TensorUsage.RHS)
dgrad = tex.grouped_gemm(
dgrad_grad,
......
......@@ -21,6 +21,7 @@ from .quantize import (
QuantizerSet,
noop_quantizer_set,
with_sharding_constraint_by_logical_axes,
TensorUsage,
)
......@@ -198,8 +199,8 @@ def _layernorm_dense_fwd_rule(
# NN GEMM
# (batch..., hidden_in) x (hidden_in, hidden_out...)
output = tex.gemm(
casted_ln_out.get_rowwise_tensor(),
casted_kernel.get_colwise_tensor(),
casted_ln_out.get_tensor(TensorUsage.LHS),
casted_kernel.get_tensor(TensorUsage.RHS),
(x_contracting_dims, k_contracting_dims),
)
......@@ -209,8 +210,8 @@ def _layernorm_dense_fwd_rule(
output += jnp.reshape(bias, bias_new_shape)
ctx = (
casted_ln_out.get_colwise_tensor() if quantizer_set.x.is_2x2x() else None,
casted_kernel.get_rowwise_tensor() if quantizer_set.kernel.is_2x2x() else None,
casted_ln_out.get_tensor(TensorUsage.LHS_TRANS),
casted_kernel.get_tensor(TensorUsage.RHS_TRANS),
x.shape,
kernel.shape,
mu,
......@@ -250,8 +251,8 @@ def _layernorm_dense_bwd_rule(
Tuple of gradients for all input parameters
"""
(
colwise_casted_ln_out,
rowwise_casted_kernel,
casted_ln_out,
casted_kernel,
x_shape,
kernel_shape,
mu,
......@@ -281,8 +282,8 @@ def _layernorm_dense_bwd_rule(
# NT GEMM
dgrad = tex.gemm(
casted_grad.get_rowwise_tensor(),
rowwise_casted_kernel,
casted_grad.get_tensor(TensorUsage.LHS),
casted_kernel,
(g_constracting_dim, k_constracting_dim),
)
......@@ -294,8 +295,8 @@ def _layernorm_dense_bwd_rule(
# TN GEMM
wgrad = tex.gemm(
colwise_casted_ln_out,
casted_grad.get_colwise_tensor(),
casted_ln_out,
casted_grad.get_tensor(TensorUsage.RHS),
(x_constracting_dim, g_constracting_dim),
)
......
......@@ -22,7 +22,12 @@ from jax.ad_checkpoint import checkpoint_name
from . import cpp_extensions as tex
from .layernorm import canonicalize_norm_type
from .quantize import with_sharding_constraint_by_logical_axes, QuantizerSet, noop_quantizer_set
from .quantize import (
with_sharding_constraint_by_logical_axes,
QuantizerSet,
noop_quantizer_set,
TensorUsage,
)
from .sharding import get_non_contracting_logical_axes
......@@ -270,8 +275,8 @@ def _layernorm_mlp_fwd_rule(
# NN GEMM
# (batch..., hidden_in) x (hidden_in, hidden_out)
dot_1_output = tex.gemm(
casted_ln_out.get_rowwise_tensor(),
casted_kernel_1.get_colwise_tensor(),
casted_ln_out.get_tensor(TensorUsage.LHS),
casted_kernel_1.get_tensor(TensorUsage.RHS),
(x_contracting_dims, k_contracting_dims),
)
......@@ -299,8 +304,8 @@ def _layernorm_mlp_fwd_rule(
# NN GEMM
# (batch..., hidden_in) x (hidden_out, hidden_in)
dot_2_output = tex.gemm(
casted_act_out.get_rowwise_tensor(),
casted_kernel_2.get_colwise_tensor(),
casted_act_out.get_tensor(TensorUsage.LHS),
casted_kernel_2.get_tensor(TensorUsage.RHS),
(x_contracting_dims, k_contracting_dims),
)
......@@ -317,11 +322,11 @@ def _layernorm_mlp_fwd_rule(
rsigma,
gamma,
beta,
casted_ln_out.get_colwise_tensor(),
casted_kernel_1.get_rowwise_tensor(),
casted_ln_out.get_tensor(TensorUsage.LHS_TRANS),
casted_kernel_1.get_tensor(TensorUsage.RHS_TRANS),
dot_1_output,
casted_act_out.get_colwise_tensor(),
casted_kernel_2.get_rowwise_tensor(),
casted_act_out.get_tensor(TensorUsage.LHS_TRANS),
casted_kernel_2.get_tensor(TensorUsage.RHS_TRANS),
x_contracting_dims,
k_contracting_dims,
kernel_1.shape,
......@@ -369,11 +374,11 @@ def _layernorm_mlp_bwd_rule(
rsigma,
gamma,
beta,
colwise_casted_ln_out,
rowwise_casted_kernel_1,
casted_ln_out,
casted_kernel_1,
dot_1_output,
colwise_casted_act_out,
rowwise_casted_kernel_2,
casted_act_out,
casted_kernel_2,
x_contracting_dims_in_fwd,
k_contracting_dims_in_fwd,
kernel_1_shape,
......@@ -404,8 +409,8 @@ def _layernorm_mlp_bwd_rule(
# NT GEMM
# (batch..., hidden_out) x (hidden_in, hidden_out)
dgrad_2 = tex.gemm(
casted_grad.get_rowwise_tensor(),
rowwise_casted_kernel_2,
casted_grad.get_tensor(TensorUsage.LHS),
casted_kernel_2,
(g_contracting_dims_2, k_contracting_dims_2),
)
......@@ -418,8 +423,8 @@ def _layernorm_mlp_bwd_rule(
# TN GEMM
# (hidden, batch...,) x (hidden, batch...)
wgrad_2 = tex.gemm(
colwise_casted_act_out,
casted_grad.get_colwise_tensor(),
casted_act_out,
casted_grad.get_tensor(TensorUsage.RHS),
(x_contracting_dims, g_contracting_dims),
)
wgrad_2 = with_sharding_constraint_by_logical_axes(wgrad_2, kernel_2_axes)
......@@ -433,7 +438,7 @@ def _layernorm_mlp_bwd_rule(
)
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim
dact_out_ndim = casted_dact_out.get_rowwise_tensor().data.ndim
dact_out_ndim = casted_dact_out.get_tensor(TensorUsage.LHS).data.ndim
g_contracting_dims_1 = tuple(
range(dact_out_ndim - len(kernel_1_shape) + len(k_contracting_dims_in_fwd), dact_out_ndim)
)
......@@ -444,8 +449,8 @@ def _layernorm_mlp_bwd_rule(
# NT GEMM
dgrad_1 = tex.gemm(
casted_dact_out.get_rowwise_tensor(),
rowwise_casted_kernel_1,
casted_dact_out.get_tensor(TensorUsage.LHS),
casted_kernel_1,
(g_contracting_dims_1, k_contracting_dims_1),
)
......@@ -454,8 +459,8 @@ def _layernorm_mlp_bwd_rule(
# TN GEMM
# (hidden, batch...) x (hidden, batch...)
wgrad_1 = tex.gemm(
colwise_casted_ln_out,
casted_dact_out.get_colwise_tensor(),
casted_ln_out,
casted_dact_out.get_tensor(TensorUsage.RHS),
(x_contracting_dims, g_contracting_dims),
)
......
......@@ -15,3 +15,4 @@ from .dequantizer import *
from .scaling_modes import *
from .metadata import *
from .helper import *
from .device_utils import *
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Device utility functions for JAX quantization.
This module provides utility functions for checking device capabilities and compatibility
for quantization operations in JAX.
"""
import functools
import transformer_engine_jax
__all__ = [
"get_device_compute_capability",
"is_fp8_gemm_with_all_layouts_supported",
]
@functools.lru_cache(maxsize=None)
def get_device_compute_capability(gpu_id: int = 0) -> int:
"""
Get the compute capability of the device.
"""
return transformer_engine_jax.get_device_compute_capability(gpu_id)
@functools.lru_cache(maxsize=None)
def is_fp8_gemm_with_all_layouts_supported() -> bool:
"""Return True if using Blackwell architecture, False otherwise."""
compute_capability = get_device_compute_capability()
return 100 <= compute_capability < 120
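
A hedged sketch of how this new helper is meant to be queried; the import path follows the quantize/device_utils.py module added in this PR, and the branches only restate the intent described in the gemm.py comments.

from transformer_engine.jax.quantize import (
    get_device_compute_capability,
    is_fp8_gemm_with_all_layouts_supported,
)

cc = get_device_compute_capability(0)  # e.g. 90 on Hopper, 100 on Blackwell
if is_fp8_gemm_with_all_layouts_supported():
    # Blackwell (sm_100..sm_119): FP8 GEMM accepts all operand layouts, so no
    # extra column-wise (transposed) quantized copy is needed for tensor scaling.
    ...
else:
    # Hopper and earlier: only the NT layout is supported, so operands may need
    # a transposed copy before calling the FP8 GEMM.
    ...
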
......@@ -15,17 +15,13 @@ import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from transformer_engine_jax import DType
from transformer_engine_jax import get_cublasLt_version
from transformer_engine_jax import (
get_cuda_version,
get_device_compute_capability,
)
from transformer_engine_jax import DType, get_cublasLt_version, get_cuda_version
from transformer_engine.common import recipe
from transformer_engine.jax.sharding import global_shard_guard, MeshResource
from .scaling_modes import ScalingMode
from .. import cpp_extensions as tex
from .device_utils import get_device_compute_capability
__all__ = [
"QuantizeConfig",
......@@ -203,7 +199,7 @@ class QuantizeConfig:
FP8_2X_ACC_FPROP: Whether to use 2x accumulation for forward pass
FP8_2X_ACC_DGRAD: Whether to use 2x accumulation for data gradients
FP8_2X_ACC_WGRAD: Whether to use 2x accumulation for weight gradients
IF_QUANTIZE_2X: Whether 2x quantization is enabled
INFERENCE_MODE: Whether to enable optimization for inference
SCALING_MODE: Scaling mode
AMAX_HISTORY_LEN: Length of AMAX history for delayed scaling
AMAX_COMPUTE_ALGO: Algorithm for AMAX computation
......@@ -218,7 +214,7 @@ class QuantizeConfig:
FP8_2X_ACC_FPROP: bool = False
FP8_2X_ACC_DGRAD: bool = False
FP8_2X_ACC_WGRAD: bool = False
IF_QUANTIZE_2X: bool = False
INFERENCE_MODE: bool = False
SCALING_MODE: ScalingMode = ScalingMode.NO_SCALING
# DelayedScaling
......@@ -246,7 +242,6 @@ class QuantizeConfig:
cls.FP8_FORMAT = fp8_recipe.fp8_format
cls.FWD_DTYPE, cls.BWD_DTYPE = _format2dtypes(cls.FP8_FORMAT)
cls.SCALING_MODE = _get_scaling_mode(fp8_recipe)
cls.IF_QUANTIZE_2X = True
@classmethod
def finalize(cls) -> None:
......@@ -260,7 +255,7 @@ class QuantizeConfig:
cls.FP8_2X_ACC_DGRAD = False
cls.FP8_2X_ACC_WGRAD = False
cls.SCALING_MODE = ScalingMode.NO_SCALING
cls.IF_QUANTIZE_2X = False
cls.INFERENCE_MODE = False
# DelayedScaling
cls.AMAX_HISTORY_LEN = 1024
cls.AMAX_COMPUTE_ALGO = AmaxComputeAlgo.MAX
......
......@@ -23,6 +23,7 @@ from .helper import (
QuantizeConfig,
AmaxComputeAlgo,
)
from .device_utils import is_fp8_gemm_with_all_layouts_supported
__all__ = [
"QuantizeLayout",
......@@ -607,9 +608,10 @@ class GroupedQuantizer(Quantizer):
def __post_init__(self):
if self.quantizers[0] is None:
self.quantizers = QuantizerFactory.create(
quantizers = QuantizerFactory.create(
self.n_groups, self.scaling_mode, self.q_dtype, self.q_layout
)
self.quantizers = (quantizers,) if not isinstance(quantizers, tuple) else quantizers
self.data_layout = self.quantizers[0].data_layout
def _create_grouped_tensor_from_tensor_list(
......@@ -841,8 +843,10 @@ class QuantizerFactory:
if is_2x2x:
q_layout_x = q_layout_kernel = q_layout_dgrad = QuantizeLayout.ROWWISE_COLWISE
else:
q_layout_x = QuantizeLayout.ROWWISE
q_layout_x = q_layout_kernel = q_layout_dgrad = QuantizeLayout.ROWWISE
if scaling_mode.is_1d_block_scaling():
q_layout_kernel = QuantizeLayout.COLWISE
if QuantizeConfig.INFERENCE_MODE:
q_layout_dgrad = None
if "quantize_meta_set" in kwargs:
......@@ -898,7 +902,15 @@ class QuantizerFactory:
scaling_mode = scaling_mode or QuantizeConfig.SCALING_MODE
fwd_dtype = fwd_dtype or QuantizeConfig.FWD_DTYPE
bwd_dtype = bwd_dtype or QuantizeConfig.BWD_DTYPE
is_2x2x = is_2x2x or QuantizeConfig.IF_QUANTIZE_2X
if is_2x2x is None:
if scaling_mode.is_1d_block_scaling():
is_2x2x = True
elif scaling_mode.is_tensor_scaling():
is_2x2x = not is_fp8_gemm_with_all_layouts_supported()
else: # NO_SCALING ignores is_2x2x for now
is_2x2x = False
is_inference_mode = QuantizeConfig.INFERENCE_MODE
assert not is_inference_mode, "Inference mode is not supported yet!"
q_set = []
for _ in range(n_quantizer_sets):
......@@ -911,4 +923,4 @@ class QuantizerFactory:
return q_set[0] if len(q_set) == 1 else tuple(q_set)
noop_quantizer_set = QuantizerFactory.create_set(scaling_mode=ScalingMode.NO_SCALING)
noop_quantizer_set = QuantizerFactory.create_set(scaling_mode=ScalingMode.NO_SCALING, is_2x2x=False)
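
With the change above, is_2x2x is resolved from the scaling mode when not passed explicitly. A small hedged illustration follows; DELAYED_TENSOR_SCALING is an assumed member name, while MXFP8_1D_SCALING is referenced elsewhere in this diff.

from transformer_engine.jax.quantize import (
    QuantizerFactory,
    ScalingMode,
    is_fp8_gemm_with_all_layouts_supported,
)

# 1D block scaling (MXFP8) always quantizes 2x: row-wise and column-wise copies.
mx_set = QuantizerFactory.create_set(scaling_mode=ScalingMode.MXFP8_1D_SCALING)
print(mx_set.x.is_2x2x())  # True

# Tensor scaling quantizes 2x only on pre-Blackwell GPUs; on Blackwell a single
# row-wise copy suffices because FP8 GEMM accepts all layouts there.
ts_set = QuantizerFactory.create_set(scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING)
print(ts_set.x.is_2x2x(), is_fp8_gemm_with_all_layouts_supported())
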
......@@ -13,7 +13,7 @@ from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Tuple, Dict
from functools import reduce
from functools import reduce, lru_cache
import operator
import numpy as np
......@@ -21,10 +21,44 @@ from jax.experimental.custom_partitioning import CompoundFactor
from jax.tree_util import register_pytree_node_class
import jax.numpy as jnp
from transformer_engine_jax import JAXX_Scaling_Mode
from transformer_engine_jax import JAXX_Scaling_Mode, QuantizeLayout
from .device_utils import is_fp8_gemm_with_all_layouts_supported
__all__ = ["QuantizeShardyRules", "ScalingMode"]
__all__ = [
"QuantizeShardyRules",
"ScalingMode",
"TensorUsage",
]
class TensorUsage(Enum):
"""Enum indicating tensor usage in GEMM operations.
Given a GEMM operation: C = A * B in which A and B can be in the normal or transposed form.
The tensor usage can be:
- LHS: A is in the normal form
- LHS_TRANS: A is in the transposed form
- RHS: B is in the normal form
- RHS_TRANS: B is in the transposed form
The tensor usage is used in the ScaledTensor.get_tensor() method.
"""
# LHS: Left-hand side, RHS: Right-hand side
# LHS_TRANS: Left-hand side transposed, RHS_TRANS: Right-hand side transposed
LHS = 0
LHS_TRANS = 1
RHS = 2
RHS_TRANS = 3
def __eq__(self, other):
if not isinstance(other, TensorUsage):
return False
return self.value == other.value
def __hash__(self):
return hash(self.value)
def DIVUP(a, b):
......@@ -104,6 +138,18 @@ class ScalingModeMetadataImpl(ABC):
The shape for scale tensors
"""
@lru_cache(maxsize=4)
@abstractmethod
def get_quantize_layout(self, usage: TensorUsage) -> QuantizeLayout:
"""Get the quantize layout for the tensor usage.
Args:
usage: The usage of the tensor
Returns:
The quantize layout for the tensor usage
"""
@abstractmethod
def get_shardy_sharding_rules(
self, input_rank, unique_var, flatten_axis
......@@ -157,6 +203,23 @@ class CurrentScalingModeMetadataImpl(ScalingModeMetadataImpl):
return (0,)
return (1,)
@lru_cache(maxsize=4)
def get_quantize_layout(self, usage: TensorUsage) -> QuantizeLayout:
"""Get the quantize layout for the tensor usage.
Args:
usage: The usage of the tensor
Returns:
The quantize layout for the tensor usage
"""
if is_fp8_gemm_with_all_layouts_supported():
return QuantizeLayout.ROWWISE
if usage in (TensorUsage.LHS, TensorUsage.RHS_TRANS):
return QuantizeLayout.ROWWISE
return QuantizeLayout.COLWISE
def get_grouped_scale_shape(
self, data_shape, n_groups, group_axis, is_colwise, is_padded=True, flatten_axis=-1
) -> Tuple[int]:
......@@ -321,6 +384,27 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
return (*first_dim_scale_shape, *last_dim_scale_shape)
@lru_cache(maxsize=4)
def get_quantize_layout(self, usage: TensorUsage) -> QuantizeLayout:
"""Get the quantize layout for the tensor usage.
Args:
usage: The usage of the tensor
Returns:
The quantize layout for the tensor usage
"""
# If we need to support 1x1x for inference in the future
# if QuantizeConfig.INFERENCE_MODE:
# assert usage not in (TensorUsage.LHS_TRANS, TensorUsage.RHS_TRANS), (f"Invalid usage {usage} as we are in MXFP8_1D_SCALING 1x1x (FWD only) mode so no transposed usage is needed!")
# if usage == TensorUsage.LHS:
# return QuantizeLayout.ROWWISE
# return QuantizeLayout.COLWISE
if usage in (TensorUsage.LHS, TensorUsage.RHS_TRANS):
return QuantizeLayout.ROWWISE
return QuantizeLayout.COLWISE
def get_grouped_scale_shape(
self, data_shape, n_groups, group_axis, is_colwise, is_padded=True, flatten_axis=-1
) -> Tuple[int]:
......@@ -506,6 +590,17 @@ class ScalingMode(Enum):
"""
return self._get_impl().get_scale_shape(data_shape, is_colwise, is_padded, flatten_axis)
def get_quantize_layout(self, usage: TensorUsage) -> QuantizeLayout:
"""Get the quantize layout for the tensor usage.
Args:
usage: The usage of the tensor
Returns:
The quantize layout for the tensor usage
"""
return self._get_impl().get_quantize_layout(usage)
def get_shardy_sharding_rules(
self, input_rank, unique_var, flatten_axis=-1
) -> Tuple[Tuple[str]]:
......
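
A brief illustration of the usage-to-layout mapping introduced above; the MXFP8_1D_SCALING member name appears in this file's comments, and the expected outputs follow the get_quantize_layout implementations in this diff.

from transformer_engine.jax.quantize import ScalingMode, TensorUsage

mode = ScalingMode.MXFP8_1D_SCALING
# For 1D block scaling, LHS and RHS_TRANS usages map to ROWWISE,
# while LHS_TRANS and RHS map to COLWISE.
print(mode.get_quantize_layout(TensorUsage.LHS))  # QuantizeLayout.ROWWISE
print(mode.get_quantize_layout(TensorUsage.RHS))  # QuantizeLayout.COLWISE
# For tensor scaling, every usage maps to ROWWISE on Blackwell; on earlier GPUs
# the same LHS/RHS_TRANS -> ROWWISE and LHS_TRANS/RHS -> COLWISE split applies.
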
......@@ -17,13 +17,14 @@ from jax.tree_util import register_pytree_node_class
from transformer_engine_jax import QuantizeLayout
from .scaling_modes import ScalingMode
from .scaling_modes import ScalingMode, TensorUsage
from .dequantizer import ScalingModeToDequantizerMap
from ..sharding import (
with_sharding_constraint_by_logical_axes as original_with_sharding_constraint_by_logical_axes,
)
__all__ = [
"TensorUsage",
"ScaledTensor",
"ScaledTensor1x",
"ScaledTensor2x",
......@@ -64,25 +65,15 @@ class ScaledTensor(ABC):
"""
@abstractmethod
def get_rowwise_tensor(self):
"""Returns the row-wise component of the tensor.
def get_tensor(self, usage: TensorUsage):
"""Returns the appropriate tensor based on the tensor usage and the scaling mode.
If the tensor usage is not valid for the scaling mode, an error is raised.
Returns:
The row-wise tensor component
Raises:
ValueError: If called on a tensor that doesn't support row-wise access
"""
@abstractmethod
def get_colwise_tensor(self):
"""Returns the column-wise component of the tensor.
Args:
usage: The usage of the tensor
Returns:
The column-wise tensor component
Raises:
ValueError: If called on a tensor that doesn't support column-wise access
The tensor based on the usage
"""
@abstractmethod
......@@ -181,33 +172,19 @@ class ScaledTensor1x(ScaledTensor):
"""
return self._dq_func(self)
def get_rowwise_tensor(self):
"""Returns the tensor if it's row-wise quantized.
def get_tensor(self, usage: TensorUsage):
"""Returns the tensor based on the tensor usage."""
q_layout = self.scaling_mode.get_quantize_layout(usage)
colwise_usage_valid = q_layout == QuantizeLayout.COLWISE and self.is_colwise
rowwise_usage_valid = q_layout == QuantizeLayout.ROWWISE and not self.is_colwise
Returns:
The row-wise tensor
Raises:
ValueError: If called on a column-wise quantized tensor
"""
if not self.is_colwise:
if colwise_usage_valid or rowwise_usage_valid:
return self
raise ValueError("Calling get_rowwise_tensor() from a colwise ScaledTensor1x!")
def get_colwise_tensor(self):
"""Returns the tensor if it's column-wise quantized.
Returns:
The column-wise tensor
Raises:
ValueError: If called on a row-wise quantized tensor
"""
if self.is_colwise:
return self
raise ValueError("Calling get_colwise_tensor() from a rowwise ScaledTensor1x!")
raise ValueError(
f"Calling get_tensor() with usage {usage} is not valid for this tensor as"
f" self.is_colwise={self.is_colwise}!"
)
def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]):
"""Applies sharding constraints to a tensor based on logical axis names.
......@@ -378,22 +355,22 @@ class ScaledTensor2x(ScaledTensor):
"""
return self.rowwise_tensor.dequantize()
def get_rowwise_tensor(self):
"""Returns the row-wise quantized component.
def get_tensor(self, usage: TensorUsage):
"""Returns the tensor based on the tensor usage."""
q_layout_rowwise = self.rowwise_tensor.scaling_mode.get_quantize_layout(usage)
q_layout_colwise = self.colwise_tensor.scaling_mode.get_quantize_layout(usage)
Returns:
The row-wise tensor component
"""
if q_layout_rowwise == QuantizeLayout.ROWWISE:
return self.rowwise_tensor
def get_colwise_tensor(self):
"""Returns the column-wise quantized component.
Returns:
The column-wise tensor component
"""
if q_layout_colwise == QuantizeLayout.COLWISE:
return self.colwise_tensor
raise ValueError(
f"Calling get_tensor() with usage {usage} is not valid for this tensor as"
f" q_layout_rowwise={q_layout_rowwise} and q_layout_colwise={q_layout_colwise}!"
)
def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]):
"""Applies sharding constraints to a tensor based on logical axis names.
......