Unverified Commit 214e2a4a authored by Alp Dener, committed by GitHub

[JAX] GEMM custom op (#1855)



* added XLA FFI custom op for TE/common nvte_cublas_gemm
Signed-off-by: Alp Dener <adener@nvidia.com>

started GemmPrimitive, abstract done
Signed-off-by: Alp Dener <adener@nvidia.com>

gemm custom op working with BF16, needs testing for FP8/MXFP8
Signed-off-by: Alp Dener <adener@nvidia.com>

converted TE GEMM API to use ScaledTensor and added an OS environment variable flag to use TE GEMM under the general gemm() call
Signed-off-by: Alp Dener <adener@nvidia.com>

BF16 tests passing, FP8 tests should be passing but contracting_dims has a scoping issue
Signed-off-by: Alp Dener <adener@nvidia.com>

FP8 tests passing for E4M3, getting CUBLAS_STATUS_NOT_SUPPORTED for E5M2
Signed-off-by: Alp Dener <adener@nvidia.com>

updated GEMM API to use separate LHS and RHS quantizers instead of a QuantizerSet
Signed-off-by: Alp Dener <adener@nvidia.com>

new GemmPrimitive passing all Dense tests
Signed-off-by: Alp Dener <adener@nvidia.com>

import cleanup and reverted code chunk movement
Signed-off-by: Alp Dener <adener@nvidia.com>

removed unused .transpose() implementations from ScaledTensors
Signed-off-by: Alp Dener <adener@nvidia.com>

all custom call tests passing on Hopper, GEMM-related tests cover both GemmPrimitive and the native JAX impl
Signed-off-by: Alp Dener <adener@nvidia.com>

removed direct calls to GemmPrimitive.enabled() from outside of cpp_extensions
Signed-off-by: Alp Dener <adener@nvidia.com>

removed unused changes to ScaledTensor classes and debug prints
Signed-off-by: Alp Dener <adener@nvidia.com>

* minor unit test cleanup
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* FP8 tests passing on Blackwell but MXFP8 outputs NaN
Signed-off-by: Alp Dener <adener@nvidia.com>

* reverted dense and fuseddense changes, FP8 tests passing on Hopper and Blackwell, MXFP8 has issues with E5M2
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* MXFP8 issue traced to scale factor padding with NaNs instead of zeros
Signed-off-by: Alp Dener <adener@nvidia.com>

* padding scale with 2^-127 instead of NaNs
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
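The failure mode behind these two commits can be sketched with plain NumPy (the 4-wide block size and shapes here are illustrative, not TE's actual MXFP8 scale layout): padded scale entries multiply zero-filled data blocks during dequantization, so NaN padding poisons any reduction, while a tiny power-of-two value stays inert.

```python
import numpy as np

# Dequantization sketch: data blocks of width 4, one scale per block
# (block size and shapes are illustrative, not TE's real MXFP8 layout).
data = np.zeros((2, 8), dtype=np.float32)
data[:, :4] = 1.0  # only the first block holds real values
real_scales = np.ones((2, 1), dtype=np.float32)

totals = {}
for name, pad_value in (("nan", np.nan), ("tiny", 2.0**-127)):
    scales = np.concatenate(
        [real_scales, np.full((2, 1), pad_value, dtype=np.float32)], axis=1
    )
    dequantized = data * np.repeat(scales, 4, axis=1)  # broadcast scales over blocks
    totals[name] = dequantized.sum()

# NaN * 0.0 is NaN, so NaN padding poisons the reduction;
# 2^-127 * 0.0 == 0.0 leaves the result untouched.
assert np.isnan(totals["nan"]) and totals["tiny"] == 8.0
```

2^-127 is the smallest value an E8M0 scale can represent, which is why it was chosen over zero (E8M0 cannot encode zero at all).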

* fix bug in rhs_scale_inv usage
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* cleaned up the E8M0 type converter and used it in gemm.cpp
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* segfault fixed, passing all unit tests on Blackwell
Signed-off-by: Alp Dener <adener@nvidia.com>

* fix for fuseddense tests
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* fix workspace alignment
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fixed GemmPrimitive custom partitioning to match jax.nn.scaled_matmul
Signed-off-by: Alp Dener <adener@nvidia.com>

all unit tests passing on an H100x8 node
Signed-off-by: Alp Dener <adener@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

linting fixes
Signed-off-by: Alp Dener <adener@nvidia.com>

fixed batch dimension numbers
Signed-off-by: Alp Dener <adener@nvidia.com>

fixed FP8 scale sharding rule when there are no FP8 scales
Signed-off-by: Alp Dener <adener@nvidia.com>

added error message for unsupported Shardy partitioner
Signed-off-by: Alp Dener <adener@nvidia.com>

fixed test tolerances for FP8 cases
Signed-off-by: Alp Dener <adener@nvidia.com>

fixed Shardy test skip cases
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* moved the reshape of the encoder output in the encoder examples to make custom partitioning rules work correctly
Signed-off-by: Alp Dener <adener@nvidia.com>

* added helper functions for padding and unpadding block scales; changed GemmPrimitive to accept unpadded scales and pad them after sharding
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* updated Shardy rules for all custom ops to decouple block scale rules from their tensors
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fixed linting errors
Signed-off-by: Alp Dener <adener@nvidia.com>

* changed the unit tests' use_jax_gemm option into a context manager to preserve external custom op settings, tightened multi-GPU encoder test tolerances, and changed the gemm() API to use contracting_dims and batched_dims separately instead of dimension_numbers
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fixed typo in test utils
Signed-off-by: Alp Dener <adener@nvidia.com>

* added sequence-first input warnings
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* pinned the datasets package version for the JAX examples
Signed-off-by: Alp Dener <adener@nvidia.com>

* reverted modification to the force_1x_quantization decision
Signed-off-by: Alp Dener <adener@nvidia.com>

* corrected gemm function syntax in unit tests
Signed-off-by: Alp Dener <adener@nvidia.com>

---------
Signed-off-by: Alp Dener <adener@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent 397c4be6
@@ -13,6 +13,7 @@ The implementation supports various normalization types, activation functions,
 quantization, and distributed training through sharding constraints.
 """
+import warnings
 from typing import List, Tuple, Sequence, Union, Callable
 from functools import partial
@@ -31,6 +32,16 @@ from .quantize import (
 from .sharding import get_non_contracting_logical_axes
 
+LAYERNORM_MLP_BATCH_FIRST_WARNING_ISSUED = False
+
+
+def _issue_batch_first_warning(msg):
+    global LAYERNORM_MLP_BATCH_FIRST_WARNING_ISSUED
+    if not LAYERNORM_MLP_BATCH_FIRST_WARNING_ISSUED:
+        warnings.warn(msg, UserWarning)
+        LAYERNORM_MLP_BATCH_FIRST_WARNING_ISSUED = True
+
+
 def layernorm_mlp(
     x: jnp.ndarray,
     gamma: jnp.ndarray,
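The warn-once pattern added above can be exercised in isolation. This is a standalone sketch of the same latch idiom (module-level flag plus helper), not the TE module itself:

```python
import warnings

_WARNING_ISSUED = False  # module-level latch, as in the diff above


def issue_batch_first_warning(msg):
    """Emit the warning only on the first call; later calls are no-ops."""
    global _WARNING_ISSUED
    if not _WARNING_ISSUED:
        warnings.warn(msg, UserWarning)
        _WARNING_ISSUED = True


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    issue_batch_first_warning("sequence-first inputs are unsupported")
    issue_batch_first_warning("sequence-first inputs are unsupported")

assert len(caught) == 1  # the second call was suppressed by the latch
```

The latch keeps a repeatedly traced/called layer from flooding logs with the same warning on every step.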
@@ -48,6 +59,7 @@ def layernorm_mlp(
     ffn1_ckpt_name: str = "ffn1",
     ffn2_ckpt_name: str = "ffn2",
     activation_type: Sequence[Union[str, Callable]] = ("gelu",),
+    batch_first: bool = True,
     quantizer_sets: Tuple[QuantizerSet] = (noop_quantizer_set, noop_quantizer_set),
 ) -> jnp.ndarray:
     """Apply layer normalization followed by MLP block.
@@ -79,6 +91,7 @@ def layernorm_mlp(
         ffn1_ckpt_name: Name for checkpointing the first feed-forward network
         ffn2_ckpt_name: Name for checkpointing the second feed-forward network
         activation_type: Activation function(s) to apply after the first dense layer transformation
+        batch_first: Assume that X is batched in the first dimension if it has more than 2 dims.
         quantizer_sets: Tuple of two quantizer sets for the two dense layer transformations
 
     Returns:
@@ -124,12 +137,13 @@ def layernorm_mlp(
         ffn1_ckpt_name,
         ffn2_ckpt_name,
         activation_type,
+        batch_first,
         quantizer_sets,
     )
     return output
 
 
-@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17))
+@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18))
 def _layernorm_mlp(
     x: jnp.ndarray,
     gamma: jnp.ndarray,
@@ -149,6 +163,7 @@ def _layernorm_mlp(
     ffn1_ckpt_name: str,
     ffn2_ckpt_name: str,
     activation_type: Sequence[Union[str, Callable]],
+    batch_first: bool,
     quantizer_sets,
 ):
     """Internal implementation of layernorm_mlp with custom VJP.
@@ -174,6 +189,7 @@ def _layernorm_mlp(
         ffn1_ckpt_name: Name for first feed-forward network checkpointing
         ffn2_ckpt_name: Name for second feed-forward network checkpointing
         activation_type: Activation function(s)
+        batch_first: Assume that X is batched in the first dimension.
         quantizer_sets: Tuple of quantizer sets
 
     Returns:
@@ -198,6 +214,7 @@ def _layernorm_mlp(
         ffn1_ckpt_name,
         ffn2_ckpt_name,
         activation_type,
+        batch_first,
         quantizer_sets,
     )
     return output
@@ -222,6 +239,7 @@ def _layernorm_mlp_fwd_rule(
     ffn1_ckpt_name,
     ffn2_ckpt_name,
     activation_type,
+    batch_first,
     quantizer_sets,
 ):
     """Forward pass rule for layernorm_mlp.
@@ -254,6 +272,17 @@ def _layernorm_mlp_fwd_rule(
     assert x.shape[x_contracting_dims[0]] == kernel_1.shape[k_contracting_dims[0]]
 
+    x_bdim = None
+    if x.ndim > 2:
+        if not batch_first:
+            _issue_batch_first_warning(
+                "TE/JAX `layernorm_mlp()` fused-layer implementation does not officially "
+                "support sequence-first inputs and may produce incorrect results when "
+                "`batch_first=False` or `transpose_batch_sequence=True`. Use sequence-first "
+                "inputs at your own discretion."
+            )
+        x_bdim = 0 if batch_first else x.ndim - 2
+
     use_bias_1 = bias_1 is not None
     use_bias_2 = bias_1 is not None
@@ -267,17 +296,23 @@ def _layernorm_mlp_fwd_rule(
         epsilon,
         norm_type,
         quantizer=ffn1_quantizer_set.x,
+        noop_scaled_tensor=True,
     )
     casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_1_input_axes)
 
-    casted_kernel_1 = tex.quantize(kernel_1, flatten_axis=-2, quantizer=ffn1_quantizer_set.kernel)
+    casted_kernel_1 = tex.quantize(
+        kernel_1, flatten_axis=-2, quantizer=ffn1_quantizer_set.kernel, noop_scaled_tensor=True
+    )
 
     # NN GEMM
     # (batch..., hidden_in) x (hidden_in, hidden_out)
     dot_1_output = tex.gemm(
         casted_ln_out.get_tensor(TensorUsage.LHS),
         casted_kernel_1.get_tensor(TensorUsage.RHS),
-        (x_contracting_dims, k_contracting_dims),
+        contracting_dims=(x_contracting_dims, k_contracting_dims),
+        batched_dims=((x_bdim,), ()),
+        bias=bias_1 if not tex.gemm_uses_jax_dot() else None,
+        fuse_bias=use_bias_1 if not tex.gemm_uses_jax_dot() else False,
     )
 
     if dot_1_input_axes is not None and kernel_1_axes is not None:
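For reference, the `contracting_dims`/`batched_dims` pair passed to `tex.gemm` above follows `dot_general`-style dimension numbers. A NumPy sketch of the NN GEMM case, with illustrative shapes and `np.einsum` standing in for the actual GEMM (only the argument convention is taken from the diff):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 16, 8))    # (batch, seq, hidden_in)
kernel = rng.normal(size=(8, 32))  # (hidden_in, hidden_out)

# contracting_dims=((2,), (0,)): contract x's hidden_in axis with kernel's first axis.
# batched_dims=((0,), ()): x's batch axis has no kernel counterpart, so the
# kernel is broadcast across it, as in the NN GEMM above.
out = np.einsum("bsh,ho->bso", x, kernel)
assert out.shape == (4, 16, 32)
assert np.allclose(out, np.tensordot(x, kernel, axes=([2], [0])))
```

Splitting batched from contracting dimensions (instead of a single `dimension_numbers` argument) is what lets the wgrad calls below mark `x_bdim` as batched on both operands.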
@@ -287,7 +322,7 @@ def _layernorm_mlp_fwd_rule(
         )
         dot_1_output = with_sharding_constraint_by_logical_axes(dot_1_output, dot_1_output_axes)
 
-    if use_bias_1:
+    if use_bias_1 and tex.gemm_uses_jax_dot():
         bias_1_shape = bias_1.shape
         bias_1_new_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1_shape
         dot_1_output += jnp.reshape(bias_1, bias_1_new_shape)
@@ -295,21 +330,28 @@ def _layernorm_mlp_fwd_rule(
     dot_1_output = checkpoint_name(dot_1_output, ffn1_ckpt_name)
 
     # (batch..., hidden_in) -> (batch..., hidden)
-    casted_act_out = tex.act_lu(dot_1_output, activation_type, quantizer=ffn2_quantizer_set.x)
+    casted_act_out = tex.act_lu(
+        dot_1_output, activation_type, quantizer=ffn2_quantizer_set.x, noop_scaled_tensor=True
+    )
 
     casted_act_out = with_sharding_constraint_by_logical_axes(casted_act_out, dot_2_input_axes)
 
-    casted_kernel_2 = tex.quantize(kernel_2, quantizer=ffn2_quantizer_set.kernel)
+    casted_kernel_2 = tex.quantize(
+        kernel_2, quantizer=ffn2_quantizer_set.kernel, noop_scaled_tensor=True
+    )
 
     # NN GEMM
     # (batch..., hidden_in) x (hidden_out, hidden_in)
     dot_2_output = tex.gemm(
         casted_act_out.get_tensor(TensorUsage.LHS),
         casted_kernel_2.get_tensor(TensorUsage.RHS),
-        (x_contracting_dims, k_contracting_dims),
+        contracting_dims=(x_contracting_dims, k_contracting_dims),
+        batched_dims=((x_bdim,), ()),
+        bias=bias_2 if not tex.gemm_uses_jax_dot() else None,
+        fuse_bias=use_bias_2 if not tex.gemm_uses_jax_dot() else False,
     )
 
-    if use_bias_2:
+    if use_bias_2 and tex.gemm_uses_jax_dot():
         bias_2_shape = bias_2.shape
         bias_2_new_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2_shape
         dot_2_output += jnp.reshape(bias_2, bias_2_new_shape)
@@ -334,6 +376,7 @@ def _layernorm_mlp_fwd_rule(
         use_bias_1,
         use_bias_2,
         quantizer_sets,
+        x_bdim,
     )
 
     return dot_2_output, ctx
@@ -351,6 +394,7 @@ def _layernorm_mlp_bwd_rule(
     ffn1_ckpt_name,
     ffn2_ckpt_name,
     activation_type,
+    batch_first,
     ctx,
     grad,
 ):
@@ -367,7 +411,7 @@ def _layernorm_mlp_bwd_rule(
     Returns:
         Tuple of gradients for all input parameters
     """
-    del norm_input_axes, ffn1_ckpt_name, ffn2_ckpt_name
+    del norm_input_axes, ffn1_ckpt_name, ffn2_ckpt_name, batch_first
     (
         x,
         mu,
@@ -386,6 +430,7 @@ def _layernorm_mlp_bwd_rule(
         use_bias_1,
         use_bias_2,
         quantizer_sets,
+        x_bdim,
     ) = ctx
 
     ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets
@@ -394,7 +439,7 @@ def _layernorm_mlp_bwd_rule(
     grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes)
 
     casted_grad, dbias_2 = tex.quantize_dbias(
-        grad, is_dbias=use_bias_2, quantizer=ffn1_quantizer_set.dgrad
+        grad, is_dbias=use_bias_2, quantizer=ffn1_quantizer_set.dgrad, noop_scaled_tensor=True
     )
 
     # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim
@@ -411,7 +456,8 @@ def _layernorm_mlp_bwd_rule(
     dgrad_2 = tex.gemm(
         casted_grad.get_tensor(TensorUsage.LHS),
         casted_kernel_2,
-        (g_contracting_dims_2, k_contracting_dims_2),
+        contracting_dims=(g_contracting_dims_2, k_contracting_dims_2),
+        batched_dims=((x_bdim,), ()),
     )
 
     dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes)
@@ -425,7 +471,8 @@ def _layernorm_mlp_bwd_rule(
     wgrad_2 = tex.gemm(
         casted_act_out,
         casted_grad.get_tensor(TensorUsage.RHS),
-        (x_contracting_dims, g_contracting_dims),
+        contracting_dims=(x_contracting_dims, g_contracting_dims),
+        batched_dims=((x_bdim,), (x_bdim,)),
    )
 
     wgrad_2 = with_sharding_constraint_by_logical_axes(wgrad_2, kernel_2_axes)
@@ -435,6 +482,7 @@ def _layernorm_mlp_bwd_rule(
         activation_type=activation_type,
         is_dbias=use_bias_1,
         quantizer=ffn2_quantizer_set.dgrad,
+        noop_scaled_tensor=True,
     )
 
     # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim
@@ -451,7 +499,8 @@ def _layernorm_mlp_bwd_rule(
     dgrad_1 = tex.gemm(
         casted_dact_out.get_tensor(TensorUsage.LHS),
         casted_kernel_1,
-        (g_contracting_dims_1, k_contracting_dims_1),
+        contracting_dims=(g_contracting_dims_1, k_contracting_dims_1),
+        batched_dims=((x_bdim,), ()),
     )
 
     dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, dot_1_input_axes)
@@ -461,7 +510,8 @@ def _layernorm_mlp_bwd_rule(
     wgrad_1 = tex.gemm(
         casted_ln_out,
         casted_dact_out.get_tensor(TensorUsage.RHS),
-        (x_contracting_dims, g_contracting_dims),
+        contracting_dims=(x_contracting_dims, g_contracting_dims),
+        batched_dims=((x_bdim,), (x_bdim,)),
     )
 
     wgrad_1 = with_sharding_constraint_by_logical_axes(wgrad_1, kernel_1_axes)
...
@@ -36,6 +36,22 @@ class Dequantizer(ABC):
         """Dequantizing given tensor to higher precision."""
 
 
+@dataclass
+class NoopDequantizer(Dequantizer):
+    """No-op Dequantizer Class"""
+
+    @staticmethod
+    def _dequantize_func(data, *args, **kwargs):
+        """A no-op dequantize function that returns the data without any changes."""
+        del args, kwargs
+        return data
+
+    @staticmethod
+    def dequantize(scaled_tensor):
+        """A no-op dequantize function that simply returns the data array in the ScaledTensor."""
+        return scaled_tensor.data
+
+
 class TensorScaleDequantizer(Dequantizer):
     """
     TensorScaling Dequantizer Class
@@ -152,6 +168,7 @@ ScalingModeToDequantizerMap = {
     ScalingMode.DELAYED_TENSOR_SCALING: TensorScaleDequantizer,
     ScalingMode.CURRENT_TENSOR_SCALING: TensorScaleDequantizer,
     ScalingMode.MXFP8_1D_SCALING: BlockScaleDequantizer,
+    ScalingMode.NO_SCALING: NoopDequantizer,
 }
...
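The enum-to-class dispatch used by `ScalingModeToDequantizerMap` can be sketched standalone. Everything below is a simplified stand-in (hypothetical `ScalingMode` values and a pass-through tensor), showing only the design: each scaling mode maps to a dequantizer class, and `NO_SCALING` maps to a no-op.

```python
from abc import ABC, abstractmethod
from enum import Enum


class ScalingMode(Enum):
    """Hypothetical stand-in for TE/JAX's ScalingMode enum."""
    NO_SCALING = 0
    MXFP8_1D_SCALING = 1


class Dequantizer(ABC):
    """Base class: dequantize a scaled tensor back to higher precision."""

    @staticmethod
    @abstractmethod
    def dequantize(scaled_tensor):
        ...


class NoopDequantizer(Dequantizer):
    """No-op dequantizer: unscaled data passes through unchanged."""

    @staticmethod
    def dequantize(scaled_tensor):
        return scaled_tensor


# Scaling-mode -> dequantizer dispatch, mirroring ScalingModeToDequantizerMap.
dispatch = {ScalingMode.NO_SCALING: NoopDequantizer}
assert dispatch[ScalingMode.NO_SCALING].dequantize([1, 2, 3]) == [1, 2, 3]
```

The map keeps callers mode-agnostic: quantized tensors look up their own dequantizer instead of branching on the scaling mode.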
@@ -9,7 +9,9 @@ in JAX, including support for different scaling modes and datatypes.
 """
 from contextlib import contextmanager
 from enum import Enum
-from typing import Optional, Tuple, Dict, Union
+from typing import Optional, Tuple, Dict, Union, Sequence
+from functools import reduce
+import operator
 
 import jax
 import jax.numpy as jnp
@@ -29,6 +31,8 @@ __all__ = [
     "is_fp8_available",
     "update_collections",
     "get_delayed_scaling",
+    "apply_padding_to_scale_inv",
+    "remove_padding_from_scale_inv",
     "NVTE_FP8_COLLECTION_NAME",
 ]
@@ -471,4 +475,115 @@ def update_collections(new: Collection, original: Collection) -> Collection:
     return new_coll
 
 
+def remove_padding_from_scale_inv(
+    scale_inv: jax.Array,
+    scaling_mode: ScalingMode,
+    data_shape: Sequence[int],
+    is_colwise: bool = False,
+    flatten_axis: int = -1,
+):
+    """
+    Slice padding out of padded inverse scale factors.
+
+    Args:
+        scale_inv: Inverse scale factor.
+        scaling_mode: ScalingMode representing the quantization method.
+        data_shape: Shape of the quantized data the inverse scale belongs to.
+        is_colwise: Whether the data was quantized column-wise.
+        flatten_axis: The axis along which the data could be flattened to 2D.
+
+    Returns:
+        Inverse scale factor without padding.
+    """
+    # Get expected unpadded scale shape and check if inverse scale already matches
+    unpadded_scale_shape = scaling_mode.get_scale_shape(
+        data_shape, is_colwise=is_colwise, is_padded=False, flatten_axis=flatten_axis
+    )
+    if scaling_mode == ScalingMode.NO_SCALING or scale_inv.shape == unpadded_scale_shape:
+        return scale_inv
+
+    # Get the padded scale shape and make sure inverse scale matches
+    padded_scale_shape = scaling_mode.get_scale_shape(
+        data_shape,
+        is_colwise=is_colwise,
+        is_padded=True,
+        flatten_axis=flatten_axis,
+    )
+    assert scale_inv.shape == padded_scale_shape, (
+        f"Padded inverse scale factor has wrong shape, expected {padded_scale_shape} but got "
+        f"{scale_inv.shape} instead."
+    )
+
+    # Reshape scale inverse to 2D in two stages to preserve the flatten axis
+    padded_scale_shape_2d = (
+        reduce(operator.mul, padded_scale_shape[:flatten_axis]),
+        reduce(operator.mul, padded_scale_shape[flatten_axis:]),
+    )
+    scale_inv_2d = jnp.reshape(
+        jnp.reshape(scale_inv, (padded_scale_shape_2d[0], *scale_inv.shape[flatten_axis:])),
+        padded_scale_shape_2d,
+    )
+
+    # Slice reshaped 2D scale inverse using collapsed 2D unpadded_scale_shape
+    unpadded_scale_shape_2d = (
+        reduce(operator.mul, unpadded_scale_shape[:flatten_axis]),
+        reduce(operator.mul, unpadded_scale_shape[flatten_axis:]),
+    )
+    scale_inv_2d_unpadded = jnp.asarray(
+        scale_inv_2d[: unpadded_scale_shape_2d[0], : unpadded_scale_shape_2d[1]]
+    )
+
+    # Reshape 2D scale inverse back in two stages in order to preserve the flatten axis
+    scale_inv_unpadded = jnp.reshape(
+        jnp.reshape(
+            scale_inv_2d_unpadded,
+            (*unpadded_scale_shape[:flatten_axis], scale_inv_2d_unpadded.shape[1]),
+        ),
+        unpadded_scale_shape,
+    )
+    return scale_inv_unpadded
+
+
+def apply_padding_to_scale_inv(
+    scale_inv: jax.Array,
+    scaling_mode: ScalingMode,
+    data_shape: Sequence[int],
+    is_colwise: bool = False,
+    flatten_axis: int = -1,
+):
+    """
+    Pad the scale inverse to match the necessary padded shape for this scaling mode.
+
+    Args:
+        scale_inv: Inverse scale factor.
+        scaling_mode: ScalingMode representing the quantization method.
+        data_shape: Shape of the quantized data the inverse scale belongs to.
+        is_colwise: Whether the data was quantized column-wise.
+        flatten_axis: The axis along which the data could be flattened to 2D.
+
+    Returns:
+        Padded inverse scale factor.
+    """
+    # Get the expected padded scale shape and check if inverse scale already matches
+    padded_scale_shape = scaling_mode.get_scale_shape(
+        data_shape, is_colwise=is_colwise, is_padded=True, flatten_axis=flatten_axis
+    )
+    if scaling_mode == ScalingMode.NO_SCALING or scale_inv.shape == padded_scale_shape:
+        return scale_inv
+
+    # Get the expected unpadded scale shape and make sure inverse scales match
+    unpadded_scale_shape = scaling_mode.get_scale_shape(
+        data_shape, is_colwise=is_colwise, is_padded=False, flatten_axis=flatten_axis
+    )
+    assert scale_inv.shape == unpadded_scale_shape, (
+        f"Unpadded inverse scale factor has wrong shape, expected {unpadded_scale_shape} but got "
+        f"{scale_inv.shape}."
+    )
+
+    # Pad the scales with the lowest representable value (2^-127) and return
+    pad_width = tuple((0, a - b) for a, b in zip(padded_scale_shape, unpadded_scale_shape))
+    return jnp.pad(scale_inv, pad_width=pad_width, mode="constant", constant_values=2**-127)
+
+
 NVTE_FP8_COLLECTION_NAME = QuantizeConfig.COLLECTION_NAME
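The two helpers above are inverses on the unpadded region. A minimal NumPy roundtrip of the 2D case, under assumed padded/unpadded shapes (the real shapes come from `ScalingMode.get_scale_shape`):

```python
import numpy as np

unpadded_shape = (100, 3)  # assumed unpadded scale shape
padded_shape = (128, 4)    # assumed padded scale shape for the same data

scale_inv = np.ones(unpadded_shape, dtype=np.float32)

# Pad with 2^-127, the smallest value an E8M0 scale can represent.
pad_width = tuple((0, a - b) for a, b in zip(padded_shape, unpadded_shape))
padded = np.pad(scale_inv, pad_width=pad_width, mode="constant",
                constant_values=2.0**-127)
assert padded.shape == padded_shape

# Unpadding slices the original region back out.
restored = padded[: unpadded_shape[0], : unpadded_shape[1]]
assert (restored == scale_inv).all()
```

Keeping scales unpadded through sharding and padding only afterwards is what lets the per-shard padded shape stay consistent with each shard's data shape.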
@@ -17,7 +17,7 @@ from functools import reduce, lru_cache
 import operator
 
 import numpy as np
-from jax.experimental.custom_partitioning import CompoundFactor
+from jax.experimental.custom_partitioning import BATCHING
 from jax.tree_util import register_pytree_node_class
 import jax.numpy as jnp
@@ -252,8 +252,9 @@ class CurrentScalingModeMetadataImpl(ScalingModeMetadataImpl):
             The Shardy rules for the scaling mode
         """
         del flatten_axis
-        input_spec = tuple(f"x{i}" for i in range(input_rank))
-        return QuantizeShardyRules(input_spec, (unique_var,), (unique_var,), {})
+        input_spec = tuple(f"{unique_var}{i}" for i in range(input_rank))
+        scale_var = BATCHING + unique_var + "_scale_inv"
+        return QuantizeShardyRules(input_spec, (scale_var,), (scale_var,), {})
 
 
 class DelayedScalingModeMetadataImpl(CurrentScalingModeMetadataImpl):
@@ -488,31 +489,41 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
         Returns:
             The Shardy rules for the scaling mode
         """
-        input_spec = [f"x{i}" for i in range(input_rank)]
-
-        # We have to use two different factors in the two CompoundFactors because of Shardy
-        # verifier requirements, even though they are the same.
-        rowwise_var = unique_var
-        colwise_var = f"{unique_var}_"
-        input_spec[flatten_axis - 1] = CompoundFactor(colwise_var, "block_size_colwise")
-        input_spec[-1] = CompoundFactor(rowwise_var, "block_size_rowwise")
-
-        # The rowwise and colwise scale tensors should be sharded the same way as the input.
-        # However, we need to adjust the dimensions where the block scaling factor applies.
-        rowwise = input_spec.copy()
-        rowwise[-1] = rowwise_var
-
-        colwise = input_spec.copy()
-        colwise[flatten_axis - 1] = colwise_var
-
-        # This implementation needs to be updated for different block dims.
-        assert self._block_dims == (1, 32)
+        del flatten_axis
+        input_spec = [f"{unique_var}{i}" for i in range(input_rank)]
+        rowwise = [f"{unique_var}scale_inv_rowwise{i}" for i in range(input_rank)]
+        colwise = [f"{unique_var}scale_inv_colwise{i}" for i in range(input_rank)]
+
+        # NOTE (Alp): Padding the scales breaks the size relationship in CompoundFactors.
+        # Unfortunately, because Shardy rules are applied to the inner primitive, the
+        # only way to preserve the relationship is to lower unpadded scales to the
+        # underlying custom call and pad them in C++. Until that's implemented, the
+        # Shardy rules for block scales have to be completely disconnected from the
+        # Shardy rules for the tensor they belong to.
 
         return QuantizeShardyRules(
             tuple(input_spec),
             tuple(rowwise),
             tuple(colwise),
-            {"block_size_rowwise": 32, "block_size_colwise": 32},
+            {},  # {"block_size_rowwise": 32, "block_size_colwise": 32},
         )
...
@@ -17,6 +17,7 @@ from jax.tree_util import register_pytree_node_class
 from transformer_engine_jax import QuantizeLayout
 
+from .helper import apply_padding_to_scale_inv
 from .scaling_modes import ScalingMode, TensorUsage
 from .dequantizer import ScalingModeToDequantizerMap
 from ..sharding import (
@@ -56,6 +57,11 @@ class ScaledTensor(ABC):
         """
         return cls(*children, *aux_data)
 
+    @property
+    @abstractmethod
+    def ndim(self):
+        """Number of dimensions of the underlying quantized array."""
+
     @abstractmethod
     def dequantize(self):
         """Dequantizes the tensor back to its original precision.
@@ -127,24 +133,16 @@ class ScaledTensor1x(ScaledTensor):
             0 < self.flatten_axis < len(self.data.shape)
         ), f"flatten_axis {self.flatten_axis} is out of bounds for shape {self.data.shape}"
 
-        expected_scale_shape = self.scaling_mode.get_scale_shape(
-            self.data.shape, self.is_colwise, is_padded=True, flatten_axis=self.flatten_axis
-        )
-        expected_unpadded_scale_shape = self.scaling_mode.get_scale_shape(
-            self.data.shape, self.is_colwise, is_padded=False, flatten_axis=self.flatten_axis
-        )
-        if self.scale_inv.shape != expected_scale_shape:
-            assert self.scale_inv.shape == expected_unpadded_scale_shape, (
-                f"Unexpected scale_inv shape! \nExpect {expected_scale_shape} for padded"
-                f" scale_inv or {expected_unpadded_scale_shape} for unpadded scale_inv, got"
-                f" {self.scale_inv.shape}"
-            )
-            pad_width = tuple(
-                (0, a - b) for a, b in zip(expected_scale_shape, expected_unpadded_scale_shape)
-            )
-            # padding with the smallest number it can present
-            self.scale_inv = jnp.pad(
-                self.scale_inv, pad_width=pad_width, mode="constant", constant_values=2**-127
-            )
+        if self.scaling_mode == ScalingMode.NO_SCALING:
+            self.scale_inv = jnp.empty((0,), dtype=jnp.float32)
+        else:
+            self.scale_inv = apply_padding_to_scale_inv(
+                self.scale_inv,
+                self.scaling_mode,
+                self.data.shape,
+                is_colwise=self.is_colwise,
+                flatten_axis=self.flatten_axis,
+            )
 
     def tree_flatten(self):
@@ -164,6 +162,10 @@ class ScaledTensor1x(ScaledTensor):
         )
         return (children, aux_data)
 
+    @property
+    def ndim(self):
+        return self.data.ndim
+
     def dequantize(self):
         """Dequantizes the tensor using the stored dequantization function.
@@ -347,6 +349,11 @@ class ScaledTensor2x(ScaledTensor):
         aux_data = ()
         return (children, aux_data)
 
+    @property
+    def ndim(self):
+        """Number of dimensions of the underlying row-wise tensor."""
+        return self.rowwise_tensor.ndim
+
     def dequantize(self):
         """Dequantizes the tensor using the row-wise component's dequantization.
...