Unverified Commit ff884e20 authored by Phuong Nguyen's avatar Phuong Nguyen Committed by GitHub
Browse files

[JAX] Flatten_axis for quantization and Sharding propagation fixes (#1644)



* rename QuantizeAxis to QuantizeLayout, get_layout to get_data_layout, q_axis to q_layout

* add fatten_axis option

* added gated act to test encoder

* sharding constraint fixes

* fix padding when flattening first dim needs to be padded

* update test sizes so that padding is tested

* rm output sharding as it can be done in the flax module

* sharding scale_inv for mxfp8

---------
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
parent be1f647c
...@@ -57,13 +57,14 @@ class Net(nn.Module): ...@@ -57,13 +57,14 @@ class Net(nn.Module):
self_attn_mask_type="padding", self_attn_mask_type="padding",
enable_relative_embedding=False, enable_relative_embedding=False,
enable_sequence_parallel=self.enable_seq_paral, enable_sequence_parallel=self.enable_seq_paral,
mlp_activations=("gelu", "linear"),
) )
x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout) x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)
x = x.reshape(x.shape[0], -1) x = x.reshape(x.shape[0], -1)
if self.enable_seq_paral: if self.enable_seq_paral:
# Trigger all-gather to collect a complete tensor alone seqence on each device. # Trigger all-gather to collect a complete tensor alone sequence on each device.
x = jax.lax.with_sharding_constraint( x = jax.lax.with_sharding_constraint(
x, jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None) x, jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None)
) )
...@@ -459,7 +460,7 @@ class TestEncoder(unittest.TestCase): ...@@ -459,7 +460,7 @@ class TestEncoder(unittest.TestCase):
def test_te_bf16(self): def test_te_bf16(self):
"""Test Transformer Engine with BF16""" """Test Transformer Engine with BF16"""
actual = train_and_evaluate(self.args) actual = train_and_evaluate(self.args)
assert actual[0] < 0.50 and actual[1] > 0.76 assert actual[0] < 0.455 and actual[1] > 0.785
@unittest.skipIf(not is_fp8_supported, fp8_reason) @unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_delayed_scaling_fp8(self): def test_te_delayed_scaling_fp8(self):
...@@ -467,7 +468,7 @@ class TestEncoder(unittest.TestCase): ...@@ -467,7 +468,7 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True self.args.use_fp8 = True
self.args.fp8_recipe = "DelayedScaling" self.args.fp8_recipe = "DelayedScaling"
actual = train_and_evaluate(self.args) actual = train_and_evaluate(self.args)
assert actual[0] < 0.50 and actual[1] > 0.76 assert actual[0] < 0.455 and actual[1] > 0.785
@unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
def test_te_mxfp8(self): def test_te_mxfp8(self):
...@@ -475,14 +476,14 @@ class TestEncoder(unittest.TestCase): ...@@ -475,14 +476,14 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True self.args.use_fp8 = True
self.args.fp8_recipe = "MXFP8BlockScaling" self.args.fp8_recipe = "MXFP8BlockScaling"
actual = train_and_evaluate(self.args) actual = train_and_evaluate(self.args)
assert actual[0] < 0.50 and actual[1] > 0.76 assert actual[0] < 0.455 and actual[1] > 0.785
@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16_with_sp(self): def test_te_bf16_with_sp(self):
"""Test Transformer Engine with BF16 + SP""" """Test Transformer Engine with BF16 + SP"""
self.args.enable_sp = True self.args.enable_sp = True
actual = train_and_evaluate(self.args) actual = train_and_evaluate(self.args)
assert actual[0] < 0.50 and actual[1] > 0.76 assert actual[0] < 0.455 and actual[1] > 0.785
@unittest.skipIf(not is_fp8_supported, fp8_reason) @unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_delayed_scaling_fp8_with_sp(self): def test_te_delayed_scaling_fp8_with_sp(self):
...@@ -491,7 +492,7 @@ class TestEncoder(unittest.TestCase): ...@@ -491,7 +492,7 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True self.args.use_fp8 = True
self.args.fp8_recipe = "DelayedScaling" self.args.fp8_recipe = "DelayedScaling"
actual = train_and_evaluate(self.args) actual = train_and_evaluate(self.args)
assert actual[0] < 0.50 and actual[1] > 0.76 assert actual[0] < 0.455 and actual[1] > 0.785
@unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
def test_te_mxfp8_with_sp(self): def test_te_mxfp8_with_sp(self):
...@@ -500,7 +501,7 @@ class TestEncoder(unittest.TestCase): ...@@ -500,7 +501,7 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True self.args.use_fp8 = True
self.args.fp8_recipe = "MXFP8BlockScaling" self.args.fp8_recipe = "MXFP8BlockScaling"
actual = train_and_evaluate(self.args) actual = train_and_evaluate(self.args)
assert actual[0] < 0.50 and actual[1] > 0.76 assert actual[0] < 0.455 and actual[1] > 0.785
if __name__ == "__main__": if __name__ == "__main__":
......
This diff is collapsed.
...@@ -45,11 +45,17 @@ if is_mxfp8_supported: ...@@ -45,11 +45,17 @@ if is_mxfp8_supported:
SUPPORTED_RECIPES.append(pytest.param(recipe.MXFP8BlockScaling(), id="MXFP8BlockScaling")) SUPPORTED_RECIPES.append(pytest.param(recipe.MXFP8BlockScaling(), id="MXFP8BlockScaling"))
DTYPES = [jnp.bfloat16, jnp.float16] DTYPES = [jnp.bfloat16, jnp.float16]
INPUT_SHAPE = [[2, 64, 64]] # [batch, seqlen, hidden_in] INPUT_SHAPE = [[4, 64, 128]] # [batch, seqlen, hidden_in]
LAYERNORM_INPUT_AXES = (BATCH_AXES, SEQLEN_TP_AXES, HIDDEN_AXES) LAYERNORM_INPUT_AXES = (BATCH_AXES, SEQLEN_TP_AXES, HIDDEN_AXES)
DOT_1_INPUT_AXES = (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES) DOT_1_INPUT_AXES = (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)
DOT_2_INPUT_AXES = (BATCH_AXES, SEQLEN_AXES, HIDDEN_TP_AXES) DOT_2_INPUT_AXES = (BATCH_AXES, SEQLEN_AXES, HIDDEN_TP_AXES)
KERNEL_1_AXES = (W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES)
KERNEL_2_AXES = (W_TP_AXES, W_FSDP_AXES)
LN_SCALE_AXES = (W_NO_SHARD_AXES,)
LN_BIAS_AXES = (W_NO_SHARD_AXES,)
BIAS_1_AXES = (W_JOINED_AXES, W_TP_AXES)
BIAS_2_AXES = (W_NO_SHARD_AXES,)
INTERMEDIATE = 64 INTERMEDIATE = 64
...@@ -60,7 +66,6 @@ def generate_fsdp_and_tp_configs(): ...@@ -60,7 +66,6 @@ def generate_fsdp_and_tp_configs():
configs.append( configs.append(
[2, (1, 2), ("fsdp", "tp"), MeshResource(fsdp_resource="fsdp", tp_resource="tp")] [2, (1, 2), ("fsdp", "tp"), MeshResource(fsdp_resource="fsdp", tp_resource="tp")]
) )
if is_devices_enough(4): if is_devices_enough(4):
configs.append( configs.append(
[4, (2, 2), ("fsdp", "tp"), MeshResource(fsdp_resource="fsdp", tp_resource="tp")] [4, (2, 2), ("fsdp", "tp"), MeshResource(fsdp_resource="fsdp", tp_resource="tp")]
...@@ -80,13 +85,13 @@ class TestDistributedLayernormMLP: ...@@ -80,13 +85,13 @@ class TestDistributedLayernormMLP:
x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype) x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype)
gamma = jax.random.normal(subkeys[5], (hidden_in,), dtype=dtype) gamma = jax.random.normal(subkeys[5], (hidden_in,), dtype=dtype)
k1 = jax.random.normal( k1 = jax.random.normal(
subkeys[1], (hidden_in, len(activation_type) * INTERMEDIATE), dtype subkeys[1], (hidden_in, len(activation_type), INTERMEDIATE), dtype
) / jnp.sqrt(hidden_in) ) / jnp.sqrt(hidden_in)
k2 = jax.random.normal(subkeys[2], (INTERMEDIATE, hidden_out), dtype) / jnp.sqrt( k2 = jax.random.normal(subkeys[2], (INTERMEDIATE, hidden_out), dtype) / jnp.sqrt(
INTERMEDIATE INTERMEDIATE
) )
if use_bias: if use_bias:
b1 = jax.random.normal(subkeys[3], (len(activation_type) * INTERMEDIATE), dtype) b1 = jax.random.normal(subkeys[3], (len(activation_type), INTERMEDIATE), dtype)
b2 = jax.random.normal(subkeys[4], (hidden_out,), dtype) b2 = jax.random.normal(subkeys[4], (hidden_out,), dtype)
else: else:
b1 = None b1 = None
...@@ -111,10 +116,12 @@ class TestDistributedLayernormMLP: ...@@ -111,10 +116,12 @@ class TestDistributedLayernormMLP:
layernorm_input_axes = LAYERNORM_INPUT_AXES layernorm_input_axes = LAYERNORM_INPUT_AXES
dot_1_input_axes = DOT_1_INPUT_AXES dot_1_input_axes = DOT_1_INPUT_AXES
dot_2_input_axes = DOT_2_INPUT_AXES dot_2_input_axes = DOT_2_INPUT_AXES
kernel_1_axes = KERNEL_1_AXES
kernel_2_axes = KERNEL_2_AXES
else: else:
layernorm_input_axes = None layernorm_input_axes = None
dot_1_input_axes = None dot_1_input_axes = dot_2_input_axes = None
dot_2_input_axes = None kernel_1_axes = kernel_2_axes = None
quantizer_sets = QuantizerFactory.create_set(n_quantizer_sets=2) quantizer_sets = QuantizerFactory.create_set(n_quantizer_sets=2)
...@@ -130,6 +137,8 @@ class TestDistributedLayernormMLP: ...@@ -130,6 +137,8 @@ class TestDistributedLayernormMLP:
norm_input_axes=layernorm_input_axes, norm_input_axes=layernorm_input_axes,
dot_1_input_axes=dot_1_input_axes, dot_1_input_axes=dot_1_input_axes,
dot_2_input_axes=dot_2_input_axes, dot_2_input_axes=dot_2_input_axes,
kernel_1_axes=kernel_1_axes,
kernel_2_axes=kernel_2_axes,
activation_type=activation_type, activation_type=activation_type,
quantizer_sets=quantizer_sets, quantizer_sets=quantizer_sets,
) )
...@@ -142,7 +151,7 @@ class TestDistributedLayernormMLP: ...@@ -142,7 +151,7 @@ class TestDistributedLayernormMLP:
@pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("use_bias", [True, False]) @pytest_parametrize_wrapper("use_bias", [True, False])
@pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES) @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
def test_layernorm_fp8_mlp_primitive( def test_layernorm_mlp_grad(
self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe
): ):
device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config
...@@ -168,12 +177,12 @@ class TestDistributedLayernormMLP: ...@@ -168,12 +177,12 @@ class TestDistributedLayernormMLP:
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes) mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource): with mesh, fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource):
k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", "tp")) k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", None, "tp"))
k2_sharding = NamedSharding(mesh, PartitionSpec("tp", "fsdp")) k2_sharding = NamedSharding(mesh, PartitionSpec("tp", "fsdp"))
k1_ = jax.device_put(k1, k1_sharding) k1_ = jax.device_put(k1, k1_sharding)
k2_ = jax.device_put(k2, k2_sharding) k2_ = jax.device_put(k2, k2_sharding)
if use_bias: if use_bias:
b1_sharding = NamedSharding(mesh, PartitionSpec("tp")) b1_sharding = NamedSharding(mesh, PartitionSpec(None, "tp"))
b1_ = jax.device_put(b1, b1_sharding) b1_ = jax.device_put(b1, b1_sharding)
else: else:
b1_sharding = b1_ = None b1_sharding = b1_ = None
...@@ -267,16 +276,7 @@ class TestDistributedLayernormMLP: ...@@ -267,16 +276,7 @@ class TestDistributedLayernormMLP:
transpose_batch_sequence=False, # input: [batch, seqlen, hidden] transpose_batch_sequence=False, # input: [batch, seqlen, hidden]
intermediate_dim=INTERMEDIATE, intermediate_dim=INTERMEDIATE,
activations=activation_type, activations=activation_type,
scale_axes=(W_NO_SHARD_AXES,),
ln_bias_axes=(W_NO_SHARD_AXES,),
kernel_axes_1=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
kernel_axes_2=(W_TP_AXES, W_FSDP_AXES),
use_bias=use_bias, use_bias=use_bias,
bias_axes_1=(W_JOINED_AXES, W_TP_AXES),
bias_axes_2=(W_NO_SHARD_AXES,),
layernorm_input_axes=LAYERNORM_INPUT_AXES,
dot_1_input_axes=DOT_1_INPUT_AXES,
dot_2_input_axes=DOT_2_INPUT_AXES,
) )
params_single = ln_mlp_single.init(init_rngs, x, deterministic=True) params_single = ln_mlp_single.init(init_rngs, x, deterministic=True)
mlp_out_single, ln_out_single = ln_mlp_single.apply( mlp_out_single, ln_out_single = ln_mlp_single.apply(
...@@ -295,13 +295,13 @@ class TestDistributedLayernormMLP: ...@@ -295,13 +295,13 @@ class TestDistributedLayernormMLP:
transpose_batch_sequence=False, transpose_batch_sequence=False,
intermediate_dim=INTERMEDIATE, intermediate_dim=INTERMEDIATE,
activations=activation_type, activations=activation_type,
scale_axes=(W_NO_SHARD_AXES,), scale_axes=LN_SCALE_AXES,
ln_bias_axes=(W_NO_SHARD_AXES,), ln_bias_axes=LN_BIAS_AXES,
kernel_axes_1=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES), kernel_axes_1=KERNEL_1_AXES,
kernel_axes_2=(W_TP_AXES, W_FSDP_AXES), kernel_axes_2=KERNEL_2_AXES,
use_bias=use_bias, use_bias=use_bias,
bias_axes_1=(W_JOINED_AXES, W_TP_AXES), bias_axes_1=BIAS_1_AXES,
bias_axes_2=(W_NO_SHARD_AXES,), bias_axes_2=BIAS_2_AXES,
layernorm_input_axes=LAYERNORM_INPUT_AXES, layernorm_input_axes=LAYERNORM_INPUT_AXES,
dot_1_input_axes=DOT_1_INPUT_AXES, dot_1_input_axes=DOT_1_INPUT_AXES,
dot_2_input_axes=DOT_2_INPUT_AXES, dot_2_input_axes=DOT_2_INPUT_AXES,
...@@ -334,7 +334,7 @@ class TestDistributedLayernormMLP: ...@@ -334,7 +334,7 @@ class TestDistributedLayernormMLP:
@pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
@pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES) @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
def test_layernorm_fp8_mlp_layer( def test_layernorm_mlp_layer_fp8(
self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe
): ):
self._test_layernorm_mlp( self._test_layernorm_mlp(
......
...@@ -91,7 +91,6 @@ def _activation_bwd_rule(activation_type, ctx, g): ...@@ -91,7 +91,6 @@ def _activation_bwd_rule(activation_type, ctx, g):
(x, _) = ctx (x, _) = ctx
assert x.dtype == g.dtype assert x.dtype == g.dtype
dx = tex.dact_lu(g, x, activation_type) dx = tex.dact_lu(g, x, activation_type)
dx = jnp.reshape(dx, x.shape)
return (dx, None) return (dx, None)
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
from typing import Tuple, Sequence, Union, Dict, List from typing import Tuple, Sequence, Union, Dict, List
from functools import partial, reduce from functools import partial, reduce
import operator import operator
from transformer_engine_jax import get_device_compute_capability
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
from transformer_engine_jax import get_device_compute_capability
from .base import BasePrimitive, register_primitive from .base import BasePrimitive, register_primitive
...@@ -183,10 +183,9 @@ def __jitted_jax_gemm_delayed_scaling_fp8(lhs, rhs, lhs_dn, rhs_dn, precision): ...@@ -183,10 +183,9 @@ def __jitted_jax_gemm_delayed_scaling_fp8(lhs, rhs, lhs_dn, rhs_dn, precision):
# Reshape + Transpose # Reshape + Transpose
# [..., M, K] -> [B, M, K] # [..., M, K] -> [B, M, K]
# [..., K, M] -> [B, M, K] # [..., K, M] -> [B, M, K]
lhs_3d = _shape_normalization(lhs_dq, lhs_dn, lhs.layout == "N") lhs_3d = _shape_normalization(lhs_dq, lhs_dn, lhs.data_layout == "N")
rhs_3d = _shape_normalization(rhs_dq, rhs_dn, rhs.layout == "T") rhs_3d = _shape_normalization(rhs_dq, rhs_dn, rhs.data_layout == "T")
# _shape_normalization ensures contracting_dims=2 and batch_dims=0
dim_nums = (((2,), (2,)), ((0,), (0,))) dim_nums = (((2,), (2,)), ((0,), (0,)))
out_3d = jax.lax.dot_general( out_3d = jax.lax.dot_general(
lhs_3d, rhs_3d, dim_nums, precision=precision, preferred_element_type=lhs.dq_dtype lhs_3d, rhs_3d, dim_nums, precision=precision, preferred_element_type=lhs.dq_dtype
...@@ -203,9 +202,9 @@ def _jax_gemm_delayed_scaling_fp8( ...@@ -203,9 +202,9 @@ def _jax_gemm_delayed_scaling_fp8(
), "rhs does not have delayed tensor scaling mode" ), "rhs does not have delayed tensor scaling mode"
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums
if lhs.layout == "T": if lhs.data_layout == "T":
lhs_contract = tuple((lhs.data.ndim - 1 - i) % lhs.data.ndim for i in lhs_contract) lhs_contract = tuple((lhs.data.ndim - 1 - i) % lhs.data.ndim for i in lhs_contract)
if rhs.layout == "T": if rhs.data_layout == "T":
rhs_contract = tuple((rhs.data.ndim - 1 - i) % rhs.data.ndim for i in rhs_contract) rhs_contract = tuple((rhs.data.ndim - 1 - i) % rhs.data.ndim for i in rhs_contract)
lhs_dn = (lhs_contract, lhs_batch) lhs_dn = (lhs_contract, lhs_batch)
...@@ -403,19 +402,19 @@ def grouped_gemm( ...@@ -403,19 +402,19 @@ def grouped_gemm(
lhs_shape = lhs.data.shape lhs_shape = lhs.data.shape
rhs_shape = rhs.data.shape rhs_shape = rhs.data.shape
out_dtype = lhs.dq_dtype out_dtype = lhs.dq_dtype
# For ScaledTensors and NVTE_DELAYED_TENSOR_SCALING, need to handle internal layout # For ScaledTensors and NVTE_DELAYED_TENSOR_SCALING, need to handle internal data_layout
if lhs.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: if lhs.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
assert not ( assert not (
lhs.data.dtype == jnp.float8_e5m2 and rhs.data.dtype == jnp.float8_e5m2 lhs.data.dtype == jnp.float8_e5m2 and rhs.data.dtype == jnp.float8_e5m2
), "FP8 GEMM does not support E5M2 * E5M2" ), "FP8 GEMM does not support E5M2 * E5M2"
((lhs_contract_dim,), (rhs_contract_dim,)) = contracting_dims ((lhs_contract_dim,), (rhs_contract_dim,)) = contracting_dims
if lhs.layout == "T": if lhs.data_layout == "T":
lhs_contract_dim = (lhs_contract_dim - 1) % lhs.data.ndim lhs_contract_dim = (lhs_contract_dim - 1) % lhs.data.ndim
if rhs.layout == "T": if rhs.data_layout == "T":
rhs_contract_dim = (rhs_contract_dim - 1) % rhs.data.ndim rhs_contract_dim = (rhs_contract_dim - 1) % rhs.data.ndim
dim_nums = ((lhs_contract_dim,), (rhs_contract_dim,)), ((), ()) dim_nums = ((lhs_contract_dim,), (rhs_contract_dim,)), ((), ())
else: else:
# For jnp.ndarray, only consider contracting_dims, layout is always NN # For jnp.ndarray, only consider contracting_dims, data_layout is always NN
scaling_mode = ScalingMode.NVTE_NO_SCALING scaling_mode = ScalingMode.NVTE_NO_SCALING
lhs_shape = lhs.shape lhs_shape = lhs.shape
rhs_shape = rhs.shape rhs_shape = rhs.shape
...@@ -432,8 +431,8 @@ def grouped_gemm( ...@@ -432,8 +431,8 @@ def grouped_gemm(
lhs_3d = _shape_normalization(lhs, lhs_dn) lhs_3d = _shape_normalization(lhs, lhs_dn)
rhs_3d = _shape_normalization(rhs, rhs_dn) rhs_3d = _shape_normalization(rhs, rhs_dn)
elif scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: elif scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
lhs_3d = _shape_normalization(lhs.data, lhs_dn, lhs.layout == "N") lhs_3d = _shape_normalization(lhs.data, lhs_dn, lhs.data_layout == "N")
rhs_3d = _shape_normalization(rhs.data, rhs_dn, rhs.layout == "T") rhs_3d = _shape_normalization(rhs.data, rhs_dn, rhs.data_layout == "T")
elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING:
lhs_3d = _shape_normalization(lhs.data, lhs_dn) lhs_3d = _shape_normalization(lhs.data, lhs_dn)
rhs_3d = _shape_normalization(rhs.data, rhs_dn) rhs_3d = _shape_normalization(rhs.data, rhs_dn)
......
...@@ -19,7 +19,7 @@ from jax.interpreters.mlir import dtype_to_ir_type ...@@ -19,7 +19,7 @@ from jax.interpreters.mlir import dtype_to_ir_type
import transformer_engine_jax import transformer_engine_jax
from ..sharding import get_padded_spec as te_get_padded_spec from ..sharding import get_padded_spec as te_get_padded_spec
from ..quantize import ScalingMode, ScaledTensorFactory, QuantizeAxis from ..quantize import ScalingMode, ScaledTensorFactory, QuantizeLayout
TEDType = transformer_engine_jax.DType TEDType = transformer_engine_jax.DType
...@@ -107,37 +107,37 @@ def normalize_axis_boundary(axis, ndim): ...@@ -107,37 +107,37 @@ def normalize_axis_boundary(axis, ndim):
return axis if axis >= 0 else ndim + axis return axis if axis >= 0 else ndim + axis
def multidim_transpose(shape, static_axis_boundary=-1, transpose_axis_boundary=-1): def multidim_transpose(shape, static_axis_boundary=-1, transpose_axis=-1):
""" """
te_cast_transpose_p multi-dims transpose te_cast_transpose_p multi-dims transpose
static_axis_boundary: int, Indicate those axes <= static_axis_boundary would not be static_axis_boundary: int, Indicate those axes <= static_axis_boundary would not be
involved into transpose, -1 means all axes involve into transpose. involved into transpose, -1 means all axes involve into transpose.
transpose_axis_boundary: int, Indicate how to split multi-dimensions tensors to 2D matrix for transpose_axis: int, Indicate how to split multi-dimensions tensors to 2D matrix for
transpose. Note, transpose_axis_boundary should be greater than static_axis_boundary transpose. Note, transpose_axis should be greater than static_axis_boundary
examples: examples:
X in shape (dim0, dim1, dim2, dim3, dim4) X in shape (dim0, dim1, dim2, dim3, dim4)
static_axis_boundary == -1, transpose_axis_boundary == 2 static_axis_boundary == -1, transpose_axis == 2
Xt = (dim2, dim3, dim4, dim0, dim1) Xt = (dim2, dim3, dim4, dim0, dim1)
static_axis_boundary == 0, transpose_axis_boundary == 2 static_axis_boundary == 0, transpose_axis == 2
Xt = (dim0, dim2, dim3, dim4, dim1) Xt = (dim0, dim2, dim3, dim4, dim1)
static_axis_boundary == 0, transpose_axis_boundary == 3 static_axis_boundary == 0, transpose_axis == 3
Xt = (dim0, dim3, dim4, dim1. dim2) Xt = (dim0, dim3, dim4, dim1. dim2)
""" """
if static_axis_boundary < 0: if static_axis_boundary < 0:
static_axis_boundary = -1 # means no static axes static_axis_boundary = -1 # means no static axes
assert static_axis_boundary < len(shape) - 2 # at least 2 remaining for transpose. assert static_axis_boundary < len(shape) - 2 # at least 2 remaining for transpose.
transpose_start_idx = static_axis_boundary + 1 transpose_start_idx = static_axis_boundary + 1
transpose_axis_boundary = normalize_axis_boundary(transpose_axis_boundary, len(shape)) transpose_axis = normalize_axis_boundary(transpose_axis, len(shape))
assert transpose_start_idx < transpose_axis_boundary assert transpose_start_idx < transpose_axis
return ( return (
*shape[:transpose_start_idx], *shape[:transpose_start_idx],
*shape[transpose_axis_boundary:], *shape[transpose_axis:],
*shape[transpose_start_idx:transpose_axis_boundary], *shape[transpose_start_idx:transpose_axis],
) )
...@@ -195,13 +195,13 @@ def should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias: bool = False, quant ...@@ -195,13 +195,13 @@ def should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias: bool = False, quant
break break
return ( return (
quantizer is not None quantizer is not None
and quantizer.q_axis == QuantizeAxis.ROWWISE and quantizer.q_layout == QuantizeLayout.ROWWISE
and arch_l_100 and arch_l_100
and is_dbias and is_dbias
) )
def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, **kwargs): def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, flatten_axis=-1, **kwargs):
""" """
Applies a workaround for delayed scaling 2x and can be used when the TE common kernels do not yet support 2x delayed scaling. Applies a workaround for delayed scaling 2x and can be used when the TE common kernels do not yet support 2x delayed scaling.
It will call the given function 'f' with the given arguments and quantizer as 1x and calculate the colwise output by transposing result. It will call the given function 'f' with the given arguments and quantizer as 1x and calculate the colwise output by transposing result.
...@@ -224,14 +224,19 @@ def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, **kwargs): ...@@ -224,14 +224,19 @@ def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, **kwargs):
# 2x is not supported by TE kernels for delayed scaling # 2x is not supported by TE kernels for delayed scaling
# so revert to 1x and transpose in JAX # so revert to 1x and transpose in JAX
quantizer.q_axis = QuantizeAxis.ROWWISE quantizer.q_layout = QuantizeLayout.ROWWISE
rowwise = f(*args, **kwargs, quantizer=quantizer) rowwise = f(*args, **kwargs, quantizer=quantizer)
other_outputs = None other_outputs = None
if isinstance(rowwise, tuple): if isinstance(rowwise, tuple):
other_outputs = rowwise[1:] other_outputs = rowwise[1:]
rowwise = rowwise[0] rowwise = rowwise[0]
quantizer.q_axis = QuantizeAxis.ROWWISE_COLWISE quantizer.q_layout = QuantizeLayout.ROWWISE_COLWISE
colwise_data = jnp.transpose(rowwise.data, (-1, *range(rowwise.data.ndim - 1))) if flatten_axis < 0:
flatten_axis += rowwise.data.ndim
assert 0 < flatten_axis < rowwise.data.ndim, "flatten_axis is out of bounds"
colwise_data = jnp.transpose(
rowwise.data, (*range(flatten_axis, rowwise.data.ndim), *range(flatten_axis))
)
output_2x = ScaledTensorFactory.create( output_2x = ScaledTensorFactory.create(
data=rowwise.data, data=rowwise.data,
scale_inv=rowwise.scale_inv, scale_inv=rowwise.scale_inv,
...@@ -239,8 +244,9 @@ def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, **kwargs): ...@@ -239,8 +244,9 @@ def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, **kwargs):
colwise_scale_inv=rowwise.scale_inv, colwise_scale_inv=rowwise.scale_inv,
scaling_mode=quantizer.scaling_mode, scaling_mode=quantizer.scaling_mode,
dq_dtype=rowwise.dq_dtype, dq_dtype=rowwise.dq_dtype,
q_axis=QuantizeAxis.ROWWISE_COLWISE, q_layout=QuantizeLayout.ROWWISE_COLWISE,
layout=quantizer.get_layout(), data_layout=quantizer.get_data_layout(),
flatten_axis=flatten_axis,
) )
if other_outputs is not None: if other_outputs is not None:
return (output_2x,) + other_outputs return (output_2x,) + other_outputs
......
...@@ -30,7 +30,7 @@ from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_a ...@@ -30,7 +30,7 @@ from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_a
from ..quantize import ScaledTensor, ScaledTensorFactory from ..quantize import ScaledTensor, ScaledTensorFactory
from ..quantize import ( from ..quantize import (
Quantizer, Quantizer,
QuantizeAxis, QuantizeLayout,
DelayedScaleQuantizer, DelayedScaleQuantizer,
ScalingMode, ScalingMode,
) )
...@@ -277,13 +277,13 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -277,13 +277,13 @@ class NormFwdPrimitive(BasePrimitive):
rowwise_scale_inv_shape, colwise_scale_inv_shape = scaling_mode.get_scale_shape_2x( rowwise_scale_inv_shape, colwise_scale_inv_shape = scaling_mode.get_scale_shape_2x(
x.shape, is_padded=False x.shape, is_padded=False
) )
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: # slice out padding for mxfp8, noop for DelayedScaling
scale_inv = scale_inv.flatten()[ scale_inv = scale_inv.flatten()[: reduce(operator.mul, rowwise_scale_inv_shape, 1)].reshape(
: reduce(operator.mul, rowwise_scale_inv_shape) rowwise_scale_inv_shape
].reshape(rowwise_scale_inv_shape) )
if is_2x: if is_2x:
colwise_scale_inv = colwise_scale_inv.flatten()[ colwise_scale_inv = colwise_scale_inv.flatten()[
: reduce(operator.mul, colwise_scale_inv_shape) : reduce(operator.mul, colwise_scale_inv_shape, 1)
].reshape(colwise_scale_inv_shape) ].reshape(colwise_scale_inv_shape)
return ( return (
out, out,
...@@ -816,7 +816,7 @@ def layernorm_fwd( ...@@ -816,7 +816,7 @@ def layernorm_fwd(
return _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer) return _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer)
# TE/common does not support normalization with colwise only quantization yet # TE/common does not support normalization with colwise only quantization yet
if quantizer is not None and quantizer.q_axis == QuantizeAxis.COLWISE: if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE:
return _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer) return _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer)
scale = ( scale = (
...@@ -900,8 +900,8 @@ def layernorm_fwd( ...@@ -900,8 +900,8 @@ def layernorm_fwd(
colwise_scale_inv=colwise_scale_inv, colwise_scale_inv=colwise_scale_inv,
scaling_mode=quantizer.scaling_mode, scaling_mode=quantizer.scaling_mode,
dq_dtype=x.dtype, dq_dtype=x.dtype,
q_axis=quantizer.q_axis, q_layout=quantizer.q_layout,
layout=quantizer.get_layout(), data_layout=quantizer.get_data_layout(),
) )
return scaled_tensor, mu, rsigma return scaled_tensor, mu, rsigma
...@@ -997,7 +997,7 @@ def rmsnorm_fwd( ...@@ -997,7 +997,7 @@ def rmsnorm_fwd(
return _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer) return _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer)
# TE/common does not support normalization with colwise only quantization yet # TE/common does not support normalization with colwise only quantization yet
if quantizer is not None and quantizer.q_axis == QuantizeAxis.COLWISE: if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE:
return _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer) return _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer)
scale = ( scale = (
...@@ -1082,8 +1082,8 @@ def rmsnorm_fwd( ...@@ -1082,8 +1082,8 @@ def rmsnorm_fwd(
colwise_scale_inv=colwise_scale_inv, colwise_scale_inv=colwise_scale_inv,
scaling_mode=quantizer.scaling_mode, scaling_mode=quantizer.scaling_mode,
dq_dtype=x.dtype, dq_dtype=x.dtype,
q_axis=quantizer.q_axis, q_layout=quantizer.q_layout,
layout=quantizer.get_layout(), data_layout=quantizer.get_data_layout(),
) )
return scaled_tensor, rsigma return scaled_tensor, rsigma
......
...@@ -11,14 +11,6 @@ ...@@ -11,14 +11,6 @@
#include "transformer_engine/cast.h" #include "transformer_engine/cast.h"
#include "xla/ffi/api/c_api.h" #include "xla/ffi/api/c_api.h"
namespace {
bool is_gated(NVTE_Activation_Type act_type) {
return act_type == NVTE_Activation_Type::GEGLU || act_type == NVTE_Activation_Type::SWIGLU ||
act_type == NVTE_Activation_Type::REGLU || act_type == NVTE_Activation_Type::QGEGLU ||
act_type == NVTE_Activation_Type::SREGLU;
}
} // namespace
namespace transformer_engine { namespace transformer_engine {
namespace jax { namespace jax {
...@@ -44,38 +36,56 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal ...@@ -44,38 +36,56 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal
auto act_len = input_dims[input_dims.size() - 2]; auto act_len = input_dims[input_dims.size() - 2];
auto scaling_mode = static_cast<NVTEScalingMode>(scaling_mode_enum); auto scaling_mode = static_cast<NVTEScalingMode>(scaling_mode_enum);
auto is_2x = static_cast<bool>(is_2x_int); auto is_2x = static_cast<bool>(is_2x_int);
auto flatten_axis = output_buf->dimensions().size() - 1; // output does not have act axis
auto input_shape = std::vector<size_t>{m, act_len * n}; auto input_shape = std::vector<size_t>{m, act_len * n};
auto output_shape = std::vector<size_t>{m, n}; auto output_shape = std::vector<size_t>{m, n};
auto output_trans_shape = std::vector<size_t>{n, m};
auto input_tensor = TensorWrapper(input, input_shape, static_cast<DType>(in_dtype)); auto input_tensor = TensorWrapper(input, input_shape, static_cast<DType>(in_dtype));
auto output_tensor = TensorWrapper(scaling_mode); auto output_tensor = TensorWrapper(scaling_mode);
output_tensor.set_rowwise_data(output, static_cast<DType>(out_dtype), output_shape); output_tensor.set_rowwise_data(output, static_cast<DType>(out_dtype), output_shape);
if (is_fp8_dtype(out_dtype)) { if (is_fp8_dtype(out_dtype)) {
output_tensor.set_rowwise_scale_inv( if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
std::vector<size_t>{
product(scale_inv_buf->dimensions(), 0, scale_inv_buf->dimensions().size() - 1),
scale_inv_buf->dimensions().back()});
}
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING && is_fp8_dtype(out_dtype)) {
NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling"); NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling"); NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling");
cudaMemsetAsync(amax, 0, sizeof(float), stream); cudaMemsetAsync(amax, 0, sizeof(float), stream);
output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1}); output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1}); output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), std::vector<size_t>{1});
} else {
output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
std::vector<size_t>{product(scale_inv_buf->dimensions(), 0, flatten_axis),
product(scale_inv_buf->dimensions(), flatten_axis,
scale_inv_buf->dimensions().size())});
}
} }
if (is_2x) { if (is_2x) {
output_tensor.set_columnwise_data(colwise_output, static_cast<DType>(out_dtype), output_shape); auto &tmp_shape =
(scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? output_trans_shape : output_shape;
output_tensor.set_columnwise_data(colwise_output, out_dtype, tmp_shape);
if (is_fp8_dtype(out_dtype)) {
// For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling
auto &tmp_buf =
(scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? scale_inv_buf : colwise_scale_inv_buf;
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
output_tensor.set_columnwise_scale_inv(
tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()),
std::vector<size_t>{1});
} else {
output_tensor.set_columnwise_scale_inv( output_tensor.set_columnwise_scale_inv(
colwise_scale_inv_buf->untyped_data(), tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()),
convert_ffi_datatype_to_te_dtype(colwise_scale_inv_buf->element_type()), std::vector<size_t>{
std::vector<size_t>{product(colwise_scale_inv_buf->dimensions(), 0, product(tmp_buf->dimensions(), 0, flatten_axis),
colwise_scale_inv_buf->dimensions().size() - 1), product(tmp_buf->dimensions(), flatten_axis, tmp_buf->dimensions().size())});
colwise_scale_inv_buf->dimensions().back()}); }
}
} }
switch (act_type) { switch (act_type) {
...@@ -162,8 +172,10 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid ...@@ -162,8 +172,10 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid
} }
if (is_2x) { if (is_2x) {
output_tensor.set_columnwise_data(reinterpret_cast<void *>(&temp), out_dtype, auto &tmp_shape = scaling_mode == static_cast<int>(NVTE_DELAYED_TENSOR_SCALING)
output_trans_shape); ? output_trans_shape
: output_shape;
output_tensor.set_columnwise_data(reinterpret_cast<void *>(&temp), out_dtype, tmp_shape);
// Only the pointers will be checked for scale_inv, thus the shapes do not matter // Only the pointers will be checked for scale_inv, thus the shapes do not matter
if (is_fp8_dtype(out_dtype)) { if (is_fp8_dtype(out_dtype)) {
...@@ -190,9 +202,9 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid ...@@ -190,9 +202,9 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid
Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
Buffer_Type act_input_buf, Buffer_Type scale_buf, Buffer_Type act_input_buf, Buffer_Type scale_buf,
Result_Type output_buf, Result_Type output_trans_buf, Result_Type output_buf, Result_Type colwise_output_buf,
Result_Type scale_inv_buf, Result_Type trans_scale_inv_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
Result_Type amax_out_buf, Result_Type dbias_buf, Result_Type amax_buf, Result_Type dbias_buf,
Result_Type workspace_buf, int64_t scaling_mode_enum, bool is_2x, Result_Type workspace_buf, int64_t scaling_mode_enum, bool is_2x,
bool is_dbias, int64_t act_enum) { bool is_dbias, int64_t act_enum) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
...@@ -201,11 +213,15 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, ...@@ -201,11 +213,15 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
auto *input = input_buf.untyped_data(); auto *input = input_buf.untyped_data();
auto *act_input = act_input_buf.untyped_data(); auto *act_input = act_input_buf.untyped_data();
float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
float *amax = reinterpret_cast<float *>(amax_buf->untyped_data());
auto scaling_mode = static_cast<NVTEScalingMode>(scaling_mode_enum); auto scaling_mode = static_cast<NVTEScalingMode>(scaling_mode_enum);
auto act_type = static_cast<NVTE_Activation_Type>(act_enum);
auto flatten_axis = output_buf->dimensions().size() - 2; // output has act axis
auto *output = output_buf->untyped_data(); auto *output = output_buf->untyped_data();
auto *output_trans = output_trans_buf->untyped_data(); auto *colwise_output = colwise_output_buf->untyped_data();
auto *dbias = dbias_buf->untyped_data(); auto *dbias = dbias_buf->untyped_data();
void *workspace = workspace_buf->untyped_data(); void *workspace = workspace_buf->untyped_data();
...@@ -213,17 +229,18 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, ...@@ -213,17 +229,18 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
auto act_input_dims = act_input_buf.dimensions(); auto act_input_dims = act_input_buf.dimensions();
auto workspace_dims = workspace_buf->dimensions(); auto workspace_dims = workspace_buf->dimensions();
// m = x_batch_size = reduce(operator.mul, x_shape[:-2]), x_shape == act_input_dims // m = x_batch_size = reduce(operator.mul, x_shape[:-2]), x_shape == act_input_dims
// n = ir_dz_shape[-1], ir_dz_shape == input_dims // n = ir_dz_shape[-1] * act_len, ir_dz_shape == input_dims
auto input_ranks = input_dims.size(); auto act_len = act_input_dims[act_input_dims.size() - 2];
auto act_input_ranks = act_input_dims.size(); NVTE_CHECK(act_input_dims.back() == input_dims.back(),
auto m = product(act_input_dims, 0, act_input_dims.size() - 1); "Shape mismatch between activation input and gradient input");
// 'n' will be 2x the size of input_dims.back() if the dactivation is dgated auto m = product(act_input_dims, 0, act_input_dims.size() - 2);
auto n = act_input_dims.back(); auto n = input_dims.back();
auto input_shape = std::vector<size_t>{m, input_dims.back()};
auto act_input_shape = std::vector<size_t>{m, n}; auto input_shape = std::vector<size_t>{m, n};
auto output_shape = std::vector<size_t>{m, n}; auto act_input_shape = std::vector<size_t>{m, n * act_len};
auto output_trans_shape = std::vector<size_t>{m, n}; auto output_shape = std::vector<size_t>{m, n * act_len};
auto dbias_shape = std::vector<size_t>{n}; auto output_trans_shape = std::vector<size_t>{n * act_len, m};
auto dbias_shape = std::vector<size_t>{n * act_len};
std::vector<size_t> workspace_shape(workspace_dims.begin(), workspace_dims.end()); std::vector<size_t> workspace_shape(workspace_dims.begin(), workspace_dims.end());
auto input_tensor = TensorWrapper(input, input_shape, in_dtype); auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
...@@ -231,49 +248,55 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, ...@@ -231,49 +248,55 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
auto output_tensor = TensorWrapper(scaling_mode); auto output_tensor = TensorWrapper(scaling_mode);
output_tensor.set_rowwise_data(output, out_dtype, output_shape); output_tensor.set_rowwise_data(output, out_dtype, output_shape);
if (is_fp8_dtype(out_dtype)) { if (is_fp8_dtype(out_dtype)) {
output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
std::vector<size_t>{
product(scale_inv_buf->dimensions(), 0, scale_inv_buf->dimensions().size() - 1),
scale_inv_buf->dimensions().back()});
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
float *amax_out = reinterpret_cast<float *>(amax_out_buf->untyped_data());
NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling"); NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
NVTE_CHECK(amax_out != nullptr, "amax must be provided for delayed tensor scaling"); NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling");
cudaMemsetAsync(amax_out, 0, sizeof(float), stream); cudaMemsetAsync(amax, 0, sizeof(float), stream);
output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1}); output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_amax(amax_out, DType::kFloat32, std::vector<size_t>{1}); output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), std::vector<size_t>{1});
} else {
output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
std::vector<size_t>{product(scale_inv_buf->dimensions(), 0, flatten_axis),
product(scale_inv_buf->dimensions(), flatten_axis,
scale_inv_buf->dimensions().size())});
} }
} }
if (is_2x) { if (is_2x) {
output_tensor.set_columnwise_data(output_trans, out_dtype, output_trans_shape); auto &tmp_shape =
(scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? output_trans_shape : output_shape;
output_tensor.set_columnwise_data(colwise_output, out_dtype, tmp_shape);
if (is_fp8_dtype(out_dtype)) { if (is_fp8_dtype(out_dtype)) {
// For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling // For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling
auto &colwise_scale_inv_buf = auto &tmp_buf =
(scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? scale_inv_buf : trans_scale_inv_buf; (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? scale_inv_buf : colwise_scale_inv_buf;
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
output_tensor.set_columnwise_scale_inv( output_tensor.set_columnwise_scale_inv(
colwise_scale_inv_buf->untyped_data(), tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()),
convert_ffi_datatype_to_te_dtype(colwise_scale_inv_buf->element_type()), std::vector<size_t>{1});
std::vector<size_t>{product(colwise_scale_inv_buf->dimensions(), 0, } else {
colwise_scale_inv_buf->dimensions().size() - 1), output_tensor.set_columnwise_scale_inv(
colwise_scale_inv_buf->dimensions().back()}); tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()),
std::vector<size_t>{
product(tmp_buf->dimensions(), 0, flatten_axis),
product(tmp_buf->dimensions(), flatten_axis, tmp_buf->dimensions().size())});
}
} }
} }
auto dbias_tensor = TensorWrapper(dbias, dbias_shape, in_dtype); auto dbias_tensor = TensorWrapper(dbias, dbias_shape, in_dtype);
auto workspace_tensor = TensorWrapper(workspace, workspace_shape, workspace_dtype); auto workspace_tensor = TensorWrapper(workspace, workspace_shape, workspace_dtype);
auto act_type = static_cast<NVTE_Activation_Type>(act_enum);
// fused_dgated_dbias is not available, so we use dact_lu + quantize_dbias in Python instead // fused_dgated_dbias is not available, so we use dact_lu + quantize_dbias in Python instead
NVTE_CHECK(!(is_gated(act_type) && is_dbias), "Unsupported DGatedActedDBias Fusion!"); NVTE_CHECK(!(act_len == 2 && is_dbias), "Unsupported DGatedActedDBias Fusion!");
NVTE_CHECK(!(scaling_mode == NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING && is_2x && NVTE_CHECK(
is_gated(act_type)), !(scaling_mode == NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING && is_2x && act_len == 2),
"TE/common does not support delayed scaling for 2x with gated activations."); "TE/common does not support delayed scaling for 2x with gated activations.");
if (is_dbias) { if (is_dbias) {
......
...@@ -44,12 +44,12 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh ...@@ -44,12 +44,12 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh
cudaStreamSynchronize(stream); cudaStreamSynchronize(stream);
// Notes on matrix layouts and transpose: // Notes on matrix layouts and transpose:
// Jax uses row-major layout, on entering this function, each input matrix pair: // Jax uses row-major data_layout, on entering this function, each input matrix pair:
// A: row-major with size [m, k], // A: row-major with size [m, k],
// B: row-major with size [n, k], needs transpose, // B: row-major with size [n, k], needs transpose,
// on exiting this function, JAX expect: // on exiting this function, JAX expect:
// C: row-major with size [m, n]. // C: row-major with size [m, n].
// cuBLAS uses column-major layout, in this view, each input matrix pair: // cuBLAS uses column-major data_layout, in this view, each input matrix pair:
// A: column-major with size [k, m], needs transpose, // A: column-major with size [k, m], needs transpose,
// B: column-major with size [k, n]. // B: column-major with size [k, n].
// If we call cuBLAS GEMM for A * B, the output will be: // If we call cuBLAS GEMM for A * B, the output will be:
......
...@@ -34,7 +34,7 @@ inline size_t product(const std::vector<size_t> &shape) { ...@@ -34,7 +34,7 @@ inline size_t product(const std::vector<size_t> &shape) {
return ret; return ret;
} }
enum class QuantizeAxis { enum class QuantizeLayout {
ROWWISE, ROWWISE,
COLWISE, COLWISE,
ROWWISE_COLWISE, ROWWISE_COLWISE,
......
...@@ -144,11 +144,11 @@ PYBIND11_MODULE(transformer_engine_jax, m) { ...@@ -144,11 +144,11 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
.value("NVTE_INVALID_SCALING", NVTEScalingMode::NVTE_MXFP8_1D_SCALING) .value("NVTE_INVALID_SCALING", NVTEScalingMode::NVTE_MXFP8_1D_SCALING)
.export_values(); .export_values();
pybind11::enum_<transformer_engine::jax::QuantizeAxis>(m, "QuantizeAxis", pybind11::enum_<transformer_engine::jax::QuantizeLayout>(m, "QuantizeLayout",
pybind11::module_local()) pybind11::module_local())
.value("ROWWISE", transformer_engine::jax::QuantizeAxis::ROWWISE) .value("ROWWISE", transformer_engine::jax::QuantizeLayout::ROWWISE)
.value("COLWISE", transformer_engine::jax::QuantizeAxis::COLWISE) .value("COLWISE", transformer_engine::jax::QuantizeLayout::COLWISE)
.value("ROWWISE_COLWISE", transformer_engine::jax::QuantizeAxis::ROWWISE_COLWISE) .value("ROWWISE_COLWISE", transformer_engine::jax::QuantizeLayout::ROWWISE_COLWISE)
.export_values(); .export_values();
} }
......
...@@ -42,10 +42,10 @@ pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_ ...@@ -42,10 +42,10 @@ pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_
Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf, Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf,
Result_Type output_buf, Result_Type output_trans_buf, Result_Type output_buf, Result_Type output_trans_buf,
Result_Type scale_inv_buf, Result_Type trans_scale_inv_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
Result_Type amax_out_buf, Result_Type dbias_buf, Result_Type amax_buf, Result_Type dbias_buf, Result_Type workspace_buf,
Result_Type workspace_buf, int64_t scaling_mode_enum, int64_t scaling_mode_enum, int64_t quantize_layout_enum, bool is_dbias,
int64_t quantize_axis_enum, bool is_dbias) { int64_t flatten_axis) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type()); auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type());
...@@ -55,7 +55,7 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T ...@@ -55,7 +55,7 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
auto *input = input_buf.untyped_data(); auto *input = input_buf.untyped_data();
auto scaling_mode = static_cast<NVTEScalingMode>(scaling_mode_enum); auto scaling_mode = static_cast<NVTEScalingMode>(scaling_mode_enum);
auto const quantize_axis = static_cast<QuantizeAxis>(quantize_axis_enum); auto const quantize_layout = static_cast<QuantizeLayout>(quantize_layout_enum);
auto *output = output_buf->untyped_data(); auto *output = output_buf->untyped_data();
auto *output_trans = output_trans_buf->untyped_data(); auto *output_trans = output_trans_buf->untyped_data();
...@@ -63,9 +63,13 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T ...@@ -63,9 +63,13 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
void *workspace = workspace_buf->untyped_data(); void *workspace = workspace_buf->untyped_data();
auto input_dims = input_buf.dimensions(); auto input_dims = input_buf.dimensions();
int64_t input_ndim = input_dims.size();
if (flatten_axis < 0) flatten_axis += input_ndim;
NVTE_CHECK(flatten_axis < input_ndim && flatten_axis > 0, "flatten_axis is out of bounds!");
auto workspace_dims = workspace_buf->dimensions(); auto workspace_dims = workspace_buf->dimensions();
auto m = product(input_dims, 0, input_dims.size() - 1); auto m = product(input_dims, 0, flatten_axis);
auto n = input_dims.back(); auto n = product(input_dims, flatten_axis, input_ndim);
auto input_shape = std::vector<size_t>{m, n}; auto input_shape = std::vector<size_t>{m, n};
auto output_shape = std::vector<size_t>{m, n}; auto output_shape = std::vector<size_t>{m, n};
auto output_trans_shape = std::vector<size_t>{n, m}; auto output_trans_shape = std::vector<size_t>{n, m};
...@@ -75,37 +79,54 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T ...@@ -75,37 +79,54 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
auto input_tensor = TensorWrapper(input, input_shape, in_dtype); auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
auto output_tensor = TensorWrapper(scaling_mode); auto output_tensor = TensorWrapper(scaling_mode);
if (quantize_axis == QuantizeAxis::ROWWISE || quantize_axis == QuantizeAxis::ROWWISE_COLWISE) { if (quantize_layout == QuantizeLayout::ROWWISE ||
quantize_layout == QuantizeLayout::ROWWISE_COLWISE) {
output_tensor.set_rowwise_data(output, out_dtype, output_shape); output_tensor.set_rowwise_data(output, out_dtype, output_shape);
output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
std::vector<size_t>{
product(scale_inv_buf->dimensions(), 0, scale_inv_buf->dimensions().size() - 1),
scale_inv_buf->dimensions().back()});
}
if (is_fp8_dtype(out_dtype)) {
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
float *scale = reinterpret_cast<float *>(scale_buf.untyped_data()); float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
float *amax_out = reinterpret_cast<float *>(amax_out_buf->untyped_data()); float *amax = reinterpret_cast<float *>(amax_buf->untyped_data());
NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling"); NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
NVTE_CHECK(amax_out != nullptr, "amax must be provided for delayed tensor scaling"); NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling");
output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1}); output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
cudaMemsetAsync(amax_out, 0, sizeof(float), stream); cudaMemsetAsync(amax, 0, sizeof(float), stream);
output_tensor.set_amax(amax_out, DType::kFloat32, std::vector<size_t>{1}); output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
std::vector<size_t>{1});
} else {
output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
std::vector<size_t>{product(scale_inv_buf->dimensions(), 0, flatten_axis),
product(scale_inv_buf->dimensions(), flatten_axis,
scale_inv_buf->dimensions().size())});
}
}
} }
if (quantize_axis == QuantizeAxis::COLWISE || quantize_axis == QuantizeAxis::ROWWISE_COLWISE) { if (quantize_layout == QuantizeLayout::COLWISE ||
output_tensor.set_columnwise_data(output_trans, out_dtype, output_trans_shape); quantize_layout == QuantizeLayout::ROWWISE_COLWISE) {
auto &tmp_shape =
(scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? output_trans_shape : output_shape;
output_tensor.set_columnwise_data(output_trans, out_dtype, tmp_shape);
// For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling // For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling
auto &colwise_scale_inv_buf = auto &tmp_buf =
(scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? scale_inv_buf : trans_scale_inv_buf; (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? scale_inv_buf : colwise_scale_inv_buf;
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
output_tensor.set_columnwise_scale_inv(
tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()),
std::vector<size_t>{1});
} else {
output_tensor.set_columnwise_scale_inv( output_tensor.set_columnwise_scale_inv(
colwise_scale_inv_buf->untyped_data(), tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()),
convert_ffi_datatype_to_te_dtype(colwise_scale_inv_buf->element_type()), std::vector<size_t>{
std::vector<size_t>{product(colwise_scale_inv_buf->dimensions(), 0, product(tmp_buf->dimensions(), 0, flatten_axis),
colwise_scale_inv_buf->dimensions().size() - 1), product(tmp_buf->dimensions(), flatten_axis, tmp_buf->dimensions().size())});
colwise_scale_inv_buf->dimensions().back()}); }
} }
auto dbias_tensor = TensorWrapper(dbias, dbias_shape, in_dtype); auto dbias_tensor = TensorWrapper(dbias, dbias_shape, in_dtype);
...@@ -133,8 +154,9 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DBiasQuantizeHandler, DBiasQuantizeFFI, ...@@ -133,8 +154,9 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DBiasQuantizeHandler, DBiasQuantizeFFI,
.Ret<Buffer_Type>() // dbias .Ret<Buffer_Type>() // dbias
.Ret<Buffer_Type>() // wkspace .Ret<Buffer_Type>() // wkspace
.Attr<int64_t>("scaling_mode") .Attr<int64_t>("scaling_mode")
.Attr<int64_t>("q_axis") .Attr<int64_t>("q_layout")
.Attr<bool>("is_dbias"), .Attr<bool>("is_dbias")
.Attr<int64_t>("flatten_axis"),
FFI_CudaGraph_Traits); FFI_CudaGraph_Traits);
Error_Type DequantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf, Error_Type DequantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf,
......
...@@ -15,7 +15,11 @@ import jax ...@@ -15,7 +15,11 @@ import jax
import jax.numpy as jnp import jax.numpy as jnp
from . import cpp_extensions as tex from . import cpp_extensions as tex
from .quantize import QuantizerSet, noop_quantizer_set from .quantize import (
QuantizerSet,
noop_quantizer_set,
with_sharding_constraint_by_logical_axes,
)
def dense( def dense(
...@@ -23,6 +27,8 @@ def dense( ...@@ -23,6 +27,8 @@ def dense(
kernel: jnp.ndarray, kernel: jnp.ndarray,
bias: jnp.ndarray = None, bias: jnp.ndarray = None,
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)), contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)),
input_axes: Tuple[str, ...] = None,
kernel_axes: Tuple[str, ...] = None,
quantizer_set: QuantizerSet = noop_quantizer_set, quantizer_set: QuantizerSet = noop_quantizer_set,
): ):
"""Perform dense layer transformation with optional quantization. """Perform dense layer transformation with optional quantization.
...@@ -48,12 +54,12 @@ def dense( ...@@ -48,12 +54,12 @@ def dense(
bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape
output += jnp.reshape(bias, bias_new_shape) output += jnp.reshape(bias, bias_new_shape)
else: else:
output = _dense(x, kernel, bias, contracting_dims, quantizer_set) output = _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set)
return output return output
@partial(jax.custom_vjp, nondiff_argnums=(3,)) @partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5))
def _dense(x, kernel, bias, contracting_dims, quantizer_set): def _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set):
"""Internal implementation of dense layer transformation with custom VJP. """Internal implementation of dense layer transformation with custom VJP.
This function implements the core dense layer transformation logic with support This function implements the core dense layer transformation logic with support
...@@ -64,32 +70,37 @@ def _dense(x, kernel, bias, contracting_dims, quantizer_set): ...@@ -64,32 +70,37 @@ def _dense(x, kernel, bias, contracting_dims, quantizer_set):
kernel: Weight matrix kernel: Weight matrix
bias: Optional bias tensor bias: Optional bias tensor
contracting_dims: Contracting dimensions specification contracting_dims: Contracting dimensions specification
input_axes: Logical axes for sharding the activation input
kernel_axes: Logical axes for sharding the weight matrix
quantizer_set: QuantizerSet which contains quantizers for different tensor types quantizer_set: QuantizerSet which contains quantizers for different tensor types
Returns: Returns:
Transformed output tensor Transformed output tensor
""" """
output, _ = _dense_fwd_rule(x, kernel, bias, contracting_dims, quantizer_set) output, _ = _dense_fwd_rule(
x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set
)
return output return output
def _dense_fwd_rule(x, kernel, bias, contracting_dims, quantizer_set): def _dense_fwd_rule(x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set):
"""Forward pass rule for dense layer transformation. """Forward pass rule for dense layer transformation.
Args:
x: Input tensor
kernel: Weight matrix
bias: Optional bias tensor
contracting_dims: Contracting dimensions specification
quantizer_set: QuantizerSet which contains quantizers for different tensor types
Returns: Returns:
Tuple of (output, context) for backward pass Tuple of (output, context) for backward pass
""" """
x_contracting_dims, k_contracting_dims = contracting_dims x_contracting_dims, k_contracting_dims = contracting_dims
casted_x = tex.quantize(x, quantizer_set.x) flatten_axis_x = -len(x_contracting_dims)
casted_kernel = tex.quantize(kernel, quantizer_set.kernel) flatten_axis_k = len(k_contracting_dims) - len(kernel.shape)
casted_x = tex.quantize(x, flatten_axis=flatten_axis_x, quantizer=quantizer_set.x)
casted_x = with_sharding_constraint_by_logical_axes(casted_x, input_axes)
casted_kernel = tex.quantize(
kernel, flatten_axis=flatten_axis_k, quantizer=quantizer_set.kernel
)
casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes)
# GEMM NN # GEMM NN
output = tex.gemm( output = tex.gemm(
...@@ -97,6 +108,7 @@ def _dense_fwd_rule(x, kernel, bias, contracting_dims, quantizer_set): ...@@ -97,6 +108,7 @@ def _dense_fwd_rule(x, kernel, bias, contracting_dims, quantizer_set):
casted_kernel.get_colwise_tensor(), casted_kernel.get_colwise_tensor(),
(x_contracting_dims, k_contracting_dims), (x_contracting_dims, k_contracting_dims),
) )
use_bias = bias is not None use_bias = bias is not None
if use_bias: if use_bias:
bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape
...@@ -109,18 +121,16 @@ def _dense_fwd_rule(x, kernel, bias, contracting_dims, quantizer_set): ...@@ -109,18 +121,16 @@ def _dense_fwd_rule(x, kernel, bias, contracting_dims, quantizer_set):
kernel.shape, kernel.shape,
use_bias, use_bias,
quantizer_set, quantizer_set,
flatten_axis_k,
) )
return output, ctx return output, ctx
def _dense_bwd_rule(contracting_dims, ctx, grad): # pylint: disable=unused-argument def _dense_bwd_rule(
contracting_dims, input_axes, kernel_axes, ctx, grad
): # pylint: disable=unused-argument
"""Backward pass rule for dense layer transformation. """Backward pass rule for dense layer transformation.
Args:
contracting_dims: Contracting dimensions specification
ctx: Context from forward pass
grad: Gradient from upstream
Returns: Returns:
Tuple of gradients with respect to inputs Tuple of gradients with respect to inputs
""" """
...@@ -133,9 +143,12 @@ def _dense_bwd_rule(contracting_dims, ctx, grad): # pylint: disable=unused-argu ...@@ -133,9 +143,12 @@ def _dense_bwd_rule(contracting_dims, ctx, grad): # pylint: disable=unused-argu
kernel_shape, kernel_shape,
use_bias, use_bias,
quantizer_set, quantizer_set,
flatten_axis_k,
) = ctx ) = ctx
casted_grad, dbias = tex.quantize_dbias(grad, is_dbias=use_bias, quantizer=quantizer_set.dgrad) casted_grad, dbias = tex.quantize_dbias(
grad, is_dbias=use_bias, flatten_axis=flatten_axis_k, quantizer=quantizer_set.dgrad
)
# GEMM NT # GEMM NT
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim
...@@ -151,6 +164,7 @@ def _dense_bwd_rule(contracting_dims, ctx, grad): # pylint: disable=unused-argu ...@@ -151,6 +164,7 @@ def _dense_bwd_rule(contracting_dims, ctx, grad): # pylint: disable=unused-argu
rowwise_casted_kernel, rowwise_casted_kernel,
(g_constracting_dim, k_constracting_dim), (g_constracting_dim, k_constracting_dim),
) )
dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes)
# GEMM TN # GEMM TN
# x_non_contracting_dims # x_non_contracting_dims
...@@ -161,6 +175,7 @@ def _dense_bwd_rule(contracting_dims, ctx, grad): # pylint: disable=unused-argu ...@@ -161,6 +175,7 @@ def _dense_bwd_rule(contracting_dims, ctx, grad): # pylint: disable=unused-argu
wgrad = tex.gemm( wgrad = tex.gemm(
colwise_casted_x, casted_grad.get_colwise_tensor(), (x_constracting_dim, g_constracting_dim) colwise_casted_x, casted_grad.get_colwise_tensor(), (x_constracting_dim, g_constracting_dim)
) )
wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes)
return dgrad, wgrad, dbias, quantizer_set return dgrad, wgrad, dbias, quantizer_set
......
...@@ -28,6 +28,7 @@ from ..softmax import softmax, SoftmaxType ...@@ -28,6 +28,7 @@ from ..softmax import softmax, SoftmaxType
from ..sharding import with_sharding_constraint_by_logical_axes from ..sharding import with_sharding_constraint_by_logical_axes
from ..cpp_extensions import is_softmax_kernel_available from ..cpp_extensions import is_softmax_kernel_available
from ..quantize import QuantizerFactory, QuantizeConfig, QuantizeMeta, QuantizeMetaSet, ScalingMode from ..quantize import QuantizerFactory, QuantizeConfig, QuantizeMeta, QuantizeMetaSet, ScalingMode
from ..sharding import get_non_contracting_logical_axes
PRNGKey = Any PRNGKey = Any
Shape = Tuple[int, ...] Shape = Tuple[int, ...]
...@@ -406,6 +407,10 @@ class DenseGeneral(TransformerEngineBase): ...@@ -406,6 +407,10 @@ class DenseGeneral(TransformerEngineBase):
:math:`\frac{alpha}{rank} * lora_output`. None means no scaling. :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
axis: Union[Iterable[int], int], default = -1 axis: Union[Iterable[int], int], default = -1
An integer tuple with axes to apply the transformation on. An integer tuple with axes to apply the transformation on.
input_axes: Tuple[str, ...], default = None
Indicate the logical axes of sharding constraint to the input, like
(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
sharding constraint.
Optimization parameters Optimization parameters
----------------------- -----------------------
...@@ -429,6 +434,7 @@ class DenseGeneral(TransformerEngineBase): ...@@ -429,6 +434,7 @@ class DenseGeneral(TransformerEngineBase):
axis: Union[Iterable[int], int] = -1 axis: Union[Iterable[int], int] = -1
dtype: DType = jnp.float32 dtype: DType = jnp.float32
transpose_batch_sequence: bool = False transpose_batch_sequence: bool = False
input_axes: Tuple[str, ...] = ()
def __post_init__(self): def __post_init__(self):
if self.kernel_init is None: if self.kernel_init is None:
...@@ -460,29 +466,35 @@ class DenseGeneral(TransformerEngineBase): ...@@ -460,29 +466,35 @@ class DenseGeneral(TransformerEngineBase):
axis = _normalize_axes(axis, inputs.ndim) axis = _normalize_axes(axis, inputs.ndim)
kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features
if self.kernel_axes:
assert len(kernel_shape) == len(self.kernel_axes), (
"Expected len(kernel_shape) to match len(kernel_axes),"
f"got kernel_shape {kernel_shape} and kernel_axes {self.kernel_axes}"
)
kernel = nn_partitioning.param_with_axes( kernel = nn_partitioning.param_with_axes(
"kernel", self.kernel_init, kernel_shape, self.dtype, axes=self.kernel_axes "kernel", self.kernel_init, kernel_shape, self.dtype, axes=self.kernel_axes
) )
if not QuantizeConfig.is_fp8_enabled(): if not QuantizeConfig.is_fp8_enabled():
kernel = kernel.astype(input_dtype) kernel = kernel.astype(input_dtype)
kernel_compute_shape = (
reduce(operator.mul, [inputs.shape[ax] for ax in axis], 1),
reduce(operator.mul, features, 1),
)
kernel = jnp.reshape(kernel, kernel_compute_shape)
if self.use_bias: if self.use_bias:
bias = nn_partitioning.param_with_axes( bias = nn_partitioning.param_with_axes(
"bias", self.bias_init, features, self.dtype, axes=self.bias_axes "bias", self.bias_init, features, self.dtype, axes=self.bias_axes
) ).astype(input_dtype)
bias = bias.reshape(kernel_compute_shape[-1]).astype(input_dtype)
else: else:
bias = None bias = None
quantizer_set = self.generate_quantizer_set() quantizer_set = self.generate_quantizer_set()
contract_ind = tuple(range(0, len(axis))) contract_ind = tuple(range(0, len(axis)))
y = dense( y = dense(
inputs, kernel, contracting_dims=(axis, contract_ind), quantizer_set=quantizer_set inputs,
kernel,
contracting_dims=(axis, contract_ind),
input_axes=self.input_axes,
kernel_axes=self.kernel_axes,
quantizer_set=quantizer_set,
) )
if self.enable_low_rank_adaptation: if self.enable_low_rank_adaptation:
...@@ -491,20 +503,14 @@ class DenseGeneral(TransformerEngineBase): ...@@ -491,20 +503,14 @@ class DenseGeneral(TransformerEngineBase):
*features[:-1], *features[:-1],
self.low_rank_adaptation_dim, self.low_rank_adaptation_dim,
) )
lora_a_kernel_init_shape = ( lora_a_kernel_axes = (None,) * len(lora_a_kernel_shape)
kernel_compute_shape[0],
*features[:-1],
self.low_rank_adaptation_dim,
)
lora_a_kernel_axes = (None,) * len(lora_a_kernel_init_shape)
lora_a_kernel = nn_partitioning.param_with_axes( lora_a_kernel = nn_partitioning.param_with_axes(
"lora_a_kernel", "lora_a_kernel",
self.kernel_init, self.kernel_init,
lora_a_kernel_init_shape, lora_a_kernel_shape,
self.dtype, self.dtype,
axes=lora_a_kernel_axes, axes=lora_a_kernel_axes,
) )
lora_a_kernel = jnp.reshape(lora_a_kernel, lora_a_kernel_shape)
lora_a_kernel = lora_a_kernel.astype(input_dtype) lora_a_kernel = lora_a_kernel.astype(input_dtype)
lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1]) lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
...@@ -527,7 +533,6 @@ class DenseGeneral(TransformerEngineBase): ...@@ -527,7 +533,6 @@ class DenseGeneral(TransformerEngineBase):
y += jnp.reshape(bias, bias_shape) y += jnp.reshape(bias, bias_shape)
assert y.dtype == input_dtype assert y.dtype == input_dtype
y = y.reshape(*inputs.shape[: self.axis], *features)
return y return y
...@@ -678,6 +683,7 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -678,6 +683,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
The output tensors of layer normalization. The output tensors of layer normalization.
If :attr:`return_layernorm_output=False`, then this would be None. If :attr:`return_layernorm_output=False`, then this would be None.
""" """
assert self.axis == -1, "Only support axis = =-1 at this moment"
input_dtype = inputs.dtype input_dtype = inputs.dtype
ln_output = None ln_output = None
...@@ -692,10 +698,7 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -692,10 +698,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
if self.enable_layernorm: if self.enable_layernorm:
inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes) inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)
assert self.axis == -1 # Only support axis = =-1 at this moment
features = inputs.shape[-1] features = inputs.shape[-1]
scale, ln_bias = _create_layernorm_parameters( scale, ln_bias = _create_layernorm_parameters(
self.layernorm_type, self.layernorm_type,
(features,), (features,),
...@@ -731,17 +734,12 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -731,17 +734,12 @@ class LayerNormDenseGeneral(TransformerEngineBase):
axis = _normalize_axes(axis, y.ndim) axis = _normalize_axes(axis, y.ndim)
kernel_shape = tuple(y.shape[ax] for ax in axis) + features kernel_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features
kernel = nn_partitioning.param_with_axes( kernel = nn_partitioning.param_with_axes(
"kernel", self.kernel_init, kernel_shape, self.dtype, axes=self.kernel_axes "kernel", self.kernel_init, kernel_shape, self.dtype, axes=self.kernel_axes
) )
if not QuantizeConfig.is_fp8_enabled(): if not QuantizeConfig.is_fp8_enabled():
kernel = kernel.astype(input_dtype) kernel = kernel.astype(input_dtype)
kernel_compute_shape = (
reduce(operator.mul, [inputs.shape[ax] for ax in axis], 1),
reduce(operator.mul, features, 1),
)
kernel = jnp.reshape(kernel, kernel_compute_shape)
contract_ind = tuple(range(0, len(axis))) contract_ind = tuple(range(0, len(axis)))
...@@ -756,11 +754,19 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -756,11 +754,19 @@ class LayerNormDenseGeneral(TransformerEngineBase):
epsilon=self.epsilon, epsilon=self.epsilon,
layernorm_input_axes=self.layernorm_input_axes, layernorm_input_axes=self.layernorm_input_axes,
dot_input_axes=self.dot_input_axes, dot_input_axes=self.dot_input_axes,
kernel_axes=self.kernel_axes,
quantizer_set=quantizer_set, quantizer_set=quantizer_set,
) )
else: else:
y = with_sharding_constraint_by_logical_axes(y, self.dot_input_axes) y = with_sharding_constraint_by_logical_axes(y, self.dot_input_axes)
z = dense(y, kernel, contracting_dims=(axis, contract_ind), quantizer_set=quantizer_set) z = dense(
y,
kernel,
contracting_dims=(axis, contract_ind),
input_axes=self.dot_input_axes,
kernel_axes=self.kernel_axes,
quantizer_set=quantizer_set,
)
if self.enable_low_rank_adaptation: if self.enable_low_rank_adaptation:
lora_a_kernel_shape = ( lora_a_kernel_shape = (
...@@ -768,20 +774,14 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -768,20 +774,14 @@ class LayerNormDenseGeneral(TransformerEngineBase):
*features[:-1], *features[:-1],
self.low_rank_adaptation_dim, self.low_rank_adaptation_dim,
) )
lora_a_kernel_init_shape = ( lora_a_kernel_axes = (None,) * len(lora_a_kernel_shape)
kernel_compute_shape[0],
*features[:-1],
self.low_rank_adaptation_dim,
)
lora_a_kernel_axes = (None,) * len(lora_a_kernel_init_shape)
lora_a_kernel = nn_partitioning.param_with_axes( lora_a_kernel = nn_partitioning.param_with_axes(
"lora_a_kernel", "lora_a_kernel",
self.kernel_init, self.kernel_init,
lora_a_kernel_init_shape, lora_a_kernel_shape,
self.dtype, self.dtype,
axes=lora_a_kernel_axes, axes=lora_a_kernel_axes,
) )
lora_a_kernel = jnp.reshape(lora_a_kernel, lora_a_kernel_shape)
lora_a_kernel = lora_a_kernel.astype(input_dtype) lora_a_kernel = lora_a_kernel.astype(input_dtype)
lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1]) lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
...@@ -803,8 +803,7 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -803,8 +803,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
if self.use_bias: if self.use_bias:
bias = nn_partitioning.param_with_axes( bias = nn_partitioning.param_with_axes(
"bias", self.bias_init, features, self.dtype, axes=self.bias_axes "bias", self.bias_init, features, self.dtype, axes=self.bias_axes
) ).astype(input_dtype)
bias = bias.reshape(kernel_compute_shape[-1]).astype(input_dtype)
if bias is not None: if bias is not None:
bias_shape = (1,) * (z.ndim - bias.ndim) + bias.shape bias_shape = (1,) * (z.ndim - bias.ndim) + bias.shape
...@@ -814,7 +813,7 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -814,7 +813,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
z = z / self.depth_scaling z = z / self.depth_scaling
assert z.dtype == input_dtype, f"output_dtype={z.dtype}, input_dtype={input_dtype}" assert z.dtype == input_dtype, f"output_dtype={z.dtype}, input_dtype={input_dtype}"
z = z.reshape(*inputs.shape[: self.axis], *features) # z = z.reshape(*inputs.shape[: self.axis], *features)
return z, ln_output # dense_output, layer_norm_output return z, ln_output # dense_output, layer_norm_output
...@@ -989,6 +988,8 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -989,6 +988,8 @@ class LayerNormMLP(TransformerEngineBase):
The output tensors of layer normalization. The output tensors of layer normalization.
If :attr:`return_layernorm_output=False`, then this would be None. If :attr:`return_layernorm_output=False`, then this would be None.
""" """
assert self.axis == -1, "Only support axis == -1 at this moment"
ffn1_quantizer_set = self.generate_quantizer_set("_0") ffn1_quantizer_set = self.generate_quantizer_set("_0")
ffn2_quantizer_set = self.generate_quantizer_set("_1") ffn2_quantizer_set = self.generate_quantizer_set("_1")
...@@ -1027,7 +1028,6 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1027,7 +1028,6 @@ class LayerNormMLP(TransformerEngineBase):
) )
# LayerNorm # LayerNorm
if self.enable_layernorm: if self.enable_layernorm:
assert self.axis == -1 # Only support axis == -1 at this moment
inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes) inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)
features = inputs.shape[-1] features = inputs.shape[-1]
...@@ -1071,7 +1071,7 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1071,7 +1071,7 @@ class LayerNormMLP(TransformerEngineBase):
num_activations = len(normalized_acts) num_activations = len(normalized_acts)
axis = _canonicalize_tuple(self.axis) axis = _canonicalize_tuple(self.axis)
axis = _normalize_axes(axis, y.ndim) axis = _normalize_axes(axis, y.ndim)
kernel_1_each_shape = (np.prod([y.shape[ax] for ax in axis]), self.intermediate_dim) kernel_1_each_shape = (np.prod([inputs.shape[ax] for ax in axis]), self.intermediate_dim)
kernel_1 = nn_partitioning.param_with_axes( kernel_1 = nn_partitioning.param_with_axes(
"wi_kernel", "wi_kernel",
kernel_1_init, kernel_1_init,
...@@ -1081,17 +1081,10 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1081,17 +1081,10 @@ class LayerNormMLP(TransformerEngineBase):
self.dtype, self.dtype,
axes=self.kernel_axes_1, axes=self.kernel_axes_1,
) )
kernel_1_compute_shape = (
reduce(operator.mul, [y.shape[ax] for ax in axis], 1),
num_activations * self.intermediate_dim,
)
kernel_1 = jnp.reshape(kernel_1, kernel_1_compute_shape)
if not QuantizeConfig.is_fp8_enabled(): if not QuantizeConfig.is_fp8_enabled():
kernel_1 = kernel_1.astype(input_dtype) kernel_1 = kernel_1.astype(input_dtype)
if self.kernel_axes_1 is not None:
kernel_1 = with_sharding_constraint_by_logical_axes(
kernel_1, self.kernel_axes_1[:-2] + self.kernel_axes_1[-1:]
)
hidden_size = inputs.shape[-1] hidden_size = inputs.shape[-1]
hidden_size_tuple = _canonicalize_tuple(hidden_size) hidden_size_tuple = _canonicalize_tuple(hidden_size)
kernel_2_shape = (self.intermediate_dim,) + hidden_size_tuple kernel_2_shape = (self.intermediate_dim,) + hidden_size_tuple
...@@ -1102,27 +1095,20 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1102,27 +1095,20 @@ class LayerNormMLP(TransformerEngineBase):
self.dtype, self.dtype,
axes=self.kernel_axes_2, axes=self.kernel_axes_2,
) )
kernel_2_compute_shape = (
self.intermediate_dim,
reduce(operator.mul, hidden_size_tuple, 1),
)
kernel_2 = jnp.reshape(kernel_2, kernel_2_compute_shape)
if not QuantizeConfig.is_fp8_enabled(): if not QuantizeConfig.is_fp8_enabled():
kernel_2 = kernel_2.astype(input_dtype) kernel_2 = kernel_2.astype(input_dtype)
if self.kernel_axes_2 is not None:
kernel_2 = with_sharding_constraint_by_logical_axes(kernel_2, self.kernel_axes_2)
contract_ind = tuple(range(0, len(axis))) contract_ind = tuple(range(0, len(axis)))
if self.use_bias: if self.use_bias:
bias_1_shape = num_activations * self.intermediate_dim bias_1_shape = (num_activations, self.intermediate_dim)
bias_1 = nn_partitioning.param_with_axes( bias_1 = nn_partitioning.param_with_axes(
"wi_bias", "wi_bias",
self.bias_init, self.bias_init,
bias_1_shape, bias_1_shape,
self.dtype, self.dtype,
axes=self.bias_axes_1, axes=self.bias_axes_1,
) ).astype(input_dtype)
bias_1 = bias_1.reshape(kernel_1_compute_shape[-1]).astype(input_dtype)
bias_2_shape = (hidden_size,) bias_2_shape = (hidden_size,)
bias_2 = nn_partitioning.param_with_axes( bias_2 = nn_partitioning.param_with_axes(
...@@ -1131,8 +1117,7 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1131,8 +1117,7 @@ class LayerNormMLP(TransformerEngineBase):
bias_2_shape, bias_2_shape,
self.dtype, self.dtype,
axes=self.bias_axes_2, axes=self.bias_axes_2,
) ).astype(input_dtype)
bias_2 = bias_2.reshape(kernel_2_compute_shape[-1]).astype(input_dtype)
else: else:
bias_1 = None bias_1 = None
bias_2 = None bias_2 = None
...@@ -1141,8 +1126,6 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1141,8 +1126,6 @@ class LayerNormMLP(TransformerEngineBase):
ffn2_ckpt_name = "ffn2" ffn2_ckpt_name = "ffn2"
if use_fused_layernorm_mlp: if use_fused_layernorm_mlp:
assert self.axis == -1 # Only support axis = =-1 at this moment
out = layernorm_mlp( out = layernorm_mlp(
y, y,
scale, scale,
...@@ -1155,6 +1138,8 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1155,6 +1138,8 @@ class LayerNormMLP(TransformerEngineBase):
norm_input_axes=self.layernorm_input_axes, norm_input_axes=self.layernorm_input_axes,
dot_1_input_axes=self.dot_1_input_axes, dot_1_input_axes=self.dot_1_input_axes,
dot_2_input_axes=self.dot_2_input_axes, dot_2_input_axes=self.dot_2_input_axes,
kernel_1_axes=self.kernel_axes_1,
kernel_2_axes=self.kernel_axes_2,
ffn1_ckpt_name=ffn1_ckpt_name, ffn1_ckpt_name=ffn1_ckpt_name,
ffn2_ckpt_name=ffn2_ckpt_name, ffn2_ckpt_name=ffn2_ckpt_name,
activation_type=normalized_acts, activation_type=normalized_acts,
...@@ -1175,6 +1160,7 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1175,6 +1160,7 @@ class LayerNormMLP(TransformerEngineBase):
epsilon=self.epsilon, epsilon=self.epsilon,
layernorm_input_axes=self.layernorm_input_axes, layernorm_input_axes=self.layernorm_input_axes,
dot_input_axes=self.dot_1_input_axes, dot_input_axes=self.dot_1_input_axes,
kernel_axes=self.kernel_axes_1,
quantizer_set=ffn1_quantizer_set, quantizer_set=ffn1_quantizer_set,
) )
else: else:
...@@ -1183,35 +1169,31 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1183,35 +1169,31 @@ class LayerNormMLP(TransformerEngineBase):
y, y,
kernel_1, kernel_1,
contracting_dims=(axis, contract_ind), contracting_dims=(axis, contract_ind),
input_axes=self.dot_1_input_axes,
kernel_axes=self.kernel_axes_1,
quantizer_set=ffn1_quantizer_set, quantizer_set=ffn1_quantizer_set,
) )
dot_1_output_axes = (
*get_non_contracting_logical_axes(y.ndim, self.dot_1_input_axes, axis),
*get_non_contracting_logical_axes(kernel_1.ndim, self.kernel_axes_1, contract_ind),
)
x = with_sharding_constraint_by_logical_axes(x, dot_1_output_axes)
if self.enable_low_rank_adaptation: if self.enable_low_rank_adaptation:
wi_lora_a_kernel_shape = ( wi_lora_a_kernel_each_shape = (
kernel_1_compute_shape[0], kernel_1_each_shape[: len(axis)],
num_activations,
self.low_rank_adaptation_dim,
)
wi_lora_a_kernel_init_shape = (
kernel_1_each_shape[0],
num_activations,
self.low_rank_adaptation_dim, self.low_rank_adaptation_dim,
) )
wi_lora_a_kernel_init_each_shape = ( wi_lora_a_kernel_axes = (None,) * len(wi_lora_a_kernel_each_shape + 1)
kernel_1_each_shape[0],
self.low_rank_adaptation_dim,
)
wi_lora_a_kernel_axes = (None,) * len(wi_lora_a_kernel_init_shape)
wi_lora_a_kernel = nn_partitioning.param_with_axes( wi_lora_a_kernel = nn_partitioning.param_with_axes(
"wi_lora_a_kernel", "wi_lora_a_kernel",
kernel_1_init, kernel_1_init,
num_activations, num_activations,
-1, -2,
wi_lora_a_kernel_init_each_shape, wi_lora_a_kernel_each_shape,
self.dtype, self.dtype,
axes=wi_lora_a_kernel_axes, axes=wi_lora_a_kernel_axes,
) )
wi_lora_a_kernel = jnp.reshape(wi_lora_a_kernel, wi_lora_a_kernel_shape)
wi_lora_a_kernel = wi_lora_a_kernel.astype(input_dtype) wi_lora_a_kernel = wi_lora_a_kernel.astype(input_dtype)
wi_lora_b_kernel_shape = ( wi_lora_b_kernel_shape = (
...@@ -1232,7 +1214,7 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1232,7 +1214,7 @@ class LayerNormMLP(TransformerEngineBase):
x += _apply_low_rank_adaptation( x += _apply_low_rank_adaptation(
y, y,
axis, axis,
num_activations * self.intermediate_dim, (num_activations, self.intermediate_dim),
wi_lora_a_kernel, wi_lora_a_kernel,
wi_lora_b_kernel, wi_lora_b_kernel,
self.low_rank_adaptation_alpha, self.low_rank_adaptation_alpha,
...@@ -1246,11 +1228,12 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1246,11 +1228,12 @@ class LayerNormMLP(TransformerEngineBase):
z = activation(x, normalized_acts) z = activation(x, normalized_acts)
else: else:
activations = [] activations = []
x = jnp.split(x, num_activations, axis=-1) x = jnp.split(x, num_activations, axis=-2)
for idx, act_fn in enumerate(normalized_acts): for idx, act_fn in enumerate(normalized_acts):
x_i = _convert_to_activation_function(act_fn)(x[idx]) x_i = _convert_to_activation_function(act_fn)(x[idx])
activations.append(x_i) activations.append(x_i)
z = reduce(operator.mul, activations) z = reduce(operator.mul, activations)
z = jnp.squeeze(z, axis=-2)
z = z.astype(input_dtype) z = z.astype(input_dtype)
z = nn.Dropout( z = nn.Dropout(
...@@ -1264,7 +1247,12 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1264,7 +1247,12 @@ class LayerNormMLP(TransformerEngineBase):
# DenseGeneral 2 # DenseGeneral 2
out = dense( out = dense(
z, kernel_2, contracting_dims=(axis, contract_ind), quantizer_set=ffn2_quantizer_set z,
kernel_2,
contracting_dims=(axis, contract_ind),
input_axes=self.dot_2_input_axes,
kernel_axes=self.kernel_axes_2,
quantizer_set=ffn2_quantizer_set,
) )
if self.enable_low_rank_adaptation: if self.enable_low_rank_adaptation:
......
...@@ -33,10 +33,9 @@ def layernorm_dense( ...@@ -33,10 +33,9 @@ def layernorm_dense(
norm_type: str = "layernorm", norm_type: str = "layernorm",
zero_centered_gamma: bool = False, zero_centered_gamma: bool = False,
epsilon: float = 1e-6, epsilon: float = 1e-6,
# The logic axes of sharding constraint to the layernorm input.
layernorm_input_axes: Tuple[str, ...] = None, layernorm_input_axes: Tuple[str, ...] = None,
# The logic axes of sharding constraint to the dot input.
dot_input_axes: Tuple[str, ...] = None, dot_input_axes: Tuple[str, ...] = None,
kernel_axes: Tuple[str, ...] = None,
quantizer_set: QuantizerSet = noop_quantizer_set, quantizer_set: QuantizerSet = noop_quantizer_set,
) -> jnp.ndarray: ) -> jnp.ndarray:
"""Apply layer normalization followed by dense layer transformation. """Apply layer normalization followed by dense layer transformation.
...@@ -56,6 +55,7 @@ def layernorm_dense( ...@@ -56,6 +55,7 @@ def layernorm_dense(
epsilon: Small constant for numerical stability in normalization epsilon: Small constant for numerical stability in normalization
layernorm_input_axes: Logical axes for sharding the layernorm input layernorm_input_axes: Logical axes for sharding the layernorm input
dot_input_axes: Logical axes for sharding the matrix multiplication input dot_input_axes: Logical axes for sharding the matrix multiplication input
kernel_axes: Logical axes for sharding the weight matrix
quantizer_set: Set of quantizers for different tensor types quantizer_set: Set of quantizers for different tensor types
Returns: Returns:
...@@ -78,6 +78,7 @@ def layernorm_dense( ...@@ -78,6 +78,7 @@ def layernorm_dense(
epsilon, epsilon,
layernorm_input_axes, layernorm_input_axes,
dot_input_axes, dot_input_axes,
kernel_axes,
quantizer_set, quantizer_set,
) )
return output return output
...@@ -91,6 +92,7 @@ def layernorm_dense( ...@@ -91,6 +92,7 @@ def layernorm_dense(
7, 7,
8, 8,
9, 9,
10,
), ),
) )
def _layernorm_dense( def _layernorm_dense(
...@@ -104,6 +106,7 @@ def _layernorm_dense( ...@@ -104,6 +106,7 @@ def _layernorm_dense(
epsilon: float, epsilon: float,
layernorm_input_axes: Tuple[str, ...], layernorm_input_axes: Tuple[str, ...],
dot_input_axes: Tuple[str, ...], dot_input_axes: Tuple[str, ...],
kernel_axes: Tuple[str, ...],
quantizer_set, quantizer_set,
): ):
"""Internal implementation of layernorm_dense with custom VJP. """Internal implementation of layernorm_dense with custom VJP.
...@@ -139,6 +142,7 @@ def _layernorm_dense( ...@@ -139,6 +142,7 @@ def _layernorm_dense(
epsilon, epsilon,
layernorm_input_axes, layernorm_input_axes,
dot_input_axes, dot_input_axes,
kernel_axes,
quantizer_set, quantizer_set,
) )
return output return output
...@@ -155,6 +159,7 @@ def _layernorm_dense_fwd_rule( ...@@ -155,6 +159,7 @@ def _layernorm_dense_fwd_rule(
epsilon, epsilon,
layernorm_input_axes, layernorm_input_axes,
dot_input_axes, dot_input_axes,
kernel_axes,
quantizer_set, quantizer_set,
): ):
"""Forward pass rule for layernorm_dense. """Forward pass rule for layernorm_dense.
...@@ -171,7 +176,6 @@ def _layernorm_dense_fwd_rule( ...@@ -171,7 +176,6 @@ def _layernorm_dense_fwd_rule(
x_contracting_dims = (len(x.shape) - 1,) x_contracting_dims = (len(x.shape) - 1,)
k_contracting_dims = (0,) k_contracting_dims = (0,)
assert x.shape[-1] == kernel.shape[0] assert x.shape[-1] == kernel.shape[0]
assert len(kernel.shape) == 2 # Otherwise need to merge dims in quantize
x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes) x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes)
...@@ -184,11 +188,12 @@ def _layernorm_dense_fwd_rule( ...@@ -184,11 +188,12 @@ def _layernorm_dense_fwd_rule(
norm_type, norm_type,
quantizer_set.x, quantizer_set.x,
) )
casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_input_axes)
# Kernel in (hidden_in, hidden_out...) # Kernel in (hidden_in, hidden_out...)
casted_kernel = tex.quantize(kernel, quantizer_set.kernel) flatten_axis = 1 - len(kernel.shape)
casted_kernel = tex.quantize(kernel, flatten_axis=flatten_axis, quantizer=quantizer_set.kernel)
casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_input_axes) casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes)
# NN GEMM # NN GEMM
# (batch..., hidden_in) x (hidden_in, hidden_out...) # (batch..., hidden_in) x (hidden_in, hidden_out...)
...@@ -217,6 +222,7 @@ def _layernorm_dense_fwd_rule( ...@@ -217,6 +222,7 @@ def _layernorm_dense_fwd_rule(
k_contracting_dims, k_contracting_dims,
use_bias, use_bias,
quantizer_set, quantizer_set,
flatten_axis,
) )
return output, ctx return output, ctx
...@@ -228,6 +234,7 @@ def _layernorm_dense_bwd_rule( ...@@ -228,6 +234,7 @@ def _layernorm_dense_bwd_rule(
epsilon, epsilon,
layernorm_input_axes, layernorm_input_axes,
dot_input_axes, # pylint: disable=unused-argument dot_input_axes, # pylint: disable=unused-argument
kernel_axes,
ctx, ctx,
grad, grad,
): ):
...@@ -256,11 +263,12 @@ def _layernorm_dense_bwd_rule( ...@@ -256,11 +263,12 @@ def _layernorm_dense_bwd_rule(
k_contracting_dims_in_fwd, k_contracting_dims_in_fwd,
use_bias, use_bias,
quantizer_set, quantizer_set,
flatten_axis,
) = ctx ) = ctx
grad = with_sharding_constraint_by_logical_axes(grad, dot_input_axes) casted_grad, dbias = tex.quantize_dbias(
grad, is_dbias=use_bias, flatten_axis=flatten_axis, quantizer=quantizer_set.dgrad
casted_grad, dbias = tex.quantize_dbias(grad, is_dbias=use_bias, quantizer=quantizer_set.dgrad) )
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim
g_constracting_dim = tuple( g_constracting_dim = tuple(
...@@ -291,6 +299,8 @@ def _layernorm_dense_bwd_rule( ...@@ -291,6 +299,8 @@ def _layernorm_dense_bwd_rule(
(x_constracting_dim, g_constracting_dim), (x_constracting_dim, g_constracting_dim),
) )
wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes)
dx, dgamma, dbeta = tex.normalization_bwd( dx, dgamma, dbeta = tex.normalization_bwd(
dgrad, dgrad,
x, x,
......
...@@ -23,6 +23,7 @@ from jax.ad_checkpoint import checkpoint_name ...@@ -23,6 +23,7 @@ from jax.ad_checkpoint import checkpoint_name
from . import cpp_extensions as tex from . import cpp_extensions as tex
from .layernorm import canonicalize_norm_type from .layernorm import canonicalize_norm_type
from .quantize import with_sharding_constraint_by_logical_axes, QuantizerSet, noop_quantizer_set from .quantize import with_sharding_constraint_by_logical_axes, QuantizerSet, noop_quantizer_set
from .sharding import get_non_contracting_logical_axes
def layernorm_mlp( def layernorm_mlp(
...@@ -37,6 +38,8 @@ def layernorm_mlp( ...@@ -37,6 +38,8 @@ def layernorm_mlp(
norm_input_axes: Tuple[str, ...] = None, norm_input_axes: Tuple[str, ...] = None,
dot_1_input_axes: Tuple[str, ...] = None, dot_1_input_axes: Tuple[str, ...] = None,
dot_2_input_axes: Tuple[str, ...] = None, dot_2_input_axes: Tuple[str, ...] = None,
kernel_1_axes: Tuple[str, ...] = None,
kernel_2_axes: Tuple[str, ...] = None,
ffn1_ckpt_name: str = "ffn1", ffn1_ckpt_name: str = "ffn1",
ffn2_ckpt_name: str = "ffn2", ffn2_ckpt_name: str = "ffn2",
activation_type: Sequence[Union[str, Callable]] = ("gelu",), activation_type: Sequence[Union[str, Callable]] = ("gelu",),
...@@ -66,6 +69,8 @@ def layernorm_mlp( ...@@ -66,6 +69,8 @@ def layernorm_mlp(
norm_input_axes: Logical axes for sharding the layernorm input norm_input_axes: Logical axes for sharding the layernorm input
dot_1_input_axes: Logical axes for sharding the first matrix multiplication dot_1_input_axes: Logical axes for sharding the first matrix multiplication
dot_2_input_axes: Logical axes for sharding the second matrix multiplication dot_2_input_axes: Logical axes for sharding the second matrix multiplication
kernel_1_axes: Logical axes for sharding the first weight matrix
kernel_2_axes: Logical axes for sharding the second weight matrix
ffn1_ckpt_name: Name for checkpointing the first feed-forward network ffn1_ckpt_name: Name for checkpointing the first feed-forward network
ffn2_ckpt_name: Name for checkpointing the second feed-forward network ffn2_ckpt_name: Name for checkpointing the second feed-forward network
activation_type: Activation function(s) to apply after the first dense layer transformation activation_type: Activation function(s) to apply after the first dense layer transformation
...@@ -109,6 +114,8 @@ def layernorm_mlp( ...@@ -109,6 +114,8 @@ def layernorm_mlp(
norm_input_axes, norm_input_axes,
dot_1_input_axes, dot_1_input_axes,
dot_2_input_axes, dot_2_input_axes,
kernel_1_axes,
kernel_2_axes,
ffn1_ckpt_name, ffn1_ckpt_name,
ffn2_ckpt_name, ffn2_ckpt_name,
activation_type, activation_type,
...@@ -117,7 +124,7 @@ def layernorm_mlp( ...@@ -117,7 +124,7 @@ def layernorm_mlp(
return output return output
@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15)) @partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17))
def _layernorm_mlp( def _layernorm_mlp(
x: jnp.ndarray, x: jnp.ndarray,
gamma: jnp.ndarray, gamma: jnp.ndarray,
...@@ -132,6 +139,8 @@ def _layernorm_mlp( ...@@ -132,6 +139,8 @@ def _layernorm_mlp(
norm_input_axes: Tuple[str, ...], norm_input_axes: Tuple[str, ...],
dot_1_input_axes: Tuple[str, ...], dot_1_input_axes: Tuple[str, ...],
dot_2_input_axes: Tuple[str, ...], dot_2_input_axes: Tuple[str, ...],
kernel_1_axes: Tuple[str, ...],
kernel_2_axes: Tuple[str, ...],
ffn1_ckpt_name: str, ffn1_ckpt_name: str,
ffn2_ckpt_name: str, ffn2_ckpt_name: str,
activation_type: Sequence[Union[str, Callable]], activation_type: Sequence[Union[str, Callable]],
...@@ -179,6 +188,8 @@ def _layernorm_mlp( ...@@ -179,6 +188,8 @@ def _layernorm_mlp(
norm_input_axes, norm_input_axes,
dot_1_input_axes, dot_1_input_axes,
dot_2_input_axes, dot_2_input_axes,
kernel_1_axes,
kernel_2_axes,
ffn1_ckpt_name, ffn1_ckpt_name,
ffn2_ckpt_name, ffn2_ckpt_name,
activation_type, activation_type,
...@@ -201,6 +212,8 @@ def _layernorm_mlp_fwd_rule( ...@@ -201,6 +212,8 @@ def _layernorm_mlp_fwd_rule(
norm_input_axes, norm_input_axes,
dot_1_input_axes, dot_1_input_axes,
dot_2_input_axes, dot_2_input_axes,
kernel_1_axes,
kernel_2_axes,
ffn1_ckpt_name, ffn1_ckpt_name,
ffn2_ckpt_name, ffn2_ckpt_name,
activation_type, activation_type,
...@@ -220,20 +233,21 @@ def _layernorm_mlp_fwd_rule( ...@@ -220,20 +233,21 @@ def _layernorm_mlp_fwd_rule(
Returns: Returns:
Tuple of (output, context) for automatic differentiation Tuple of (output, context) for automatic differentiation
""" """
del kernel_2_axes
ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets
# x should be in shape of (batch..., hidden) # x should be in shape of (batch..., hidden)
# Kernel_1 should be in shape of (hidden_in, activation_len * intermediate) # Kernel_1 should be in shape of (hidden_in, activation_len, intermediate)
# Kernel_2 should be in shape of (intermediate, hidden_in) # Kernel_2 should be in shape of (intermediate, hidden_in)
assert len(kernel_1.shape) == 2 assert len(kernel_1.shape) == 3
assert len(kernel_2.shape) == 2 assert len(kernel_2.shape) == 2
assert kernel_1.shape[1] == kernel_2.shape[0] * len(activation_type) assert kernel_1.shape[-2] == len(activation_type)
x_contracting_dims = (len(x.shape) - 1,) x_contracting_dims = (len(x.shape) - 1,)
k_contracting_dims = (0,) k_contracting_dims = (0,)
assert x.shape[x_contracting_dims[0]] == kernel_1.shape[k_contracting_dims[0]] assert x.shape[x_contracting_dims[0]] == kernel_1.shape[k_contracting_dims[0]]
assert kernel_1.shape[1] == len(activation_type) * kernel_2.shape[0]
use_bias_1 = bias_1 is not None use_bias_1 = bias_1 is not None
use_bias_2 = bias_1 is not None use_bias_2 = bias_1 is not None
...@@ -249,11 +263,10 @@ def _layernorm_mlp_fwd_rule( ...@@ -249,11 +263,10 @@ def _layernorm_mlp_fwd_rule(
norm_type, norm_type,
quantizer=ffn1_quantizer_set.x, quantizer=ffn1_quantizer_set.x,
) )
casted_kernel_1 = tex.quantize(kernel_1, quantizer=ffn1_quantizer_set.kernel)
casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_1_input_axes) casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_1_input_axes)
casted_kernel_1 = tex.quantize(kernel_1, flatten_axis=-2, quantizer=ffn1_quantizer_set.kernel)
# NN GEMM # NN GEMM
# (batch..., hidden_in) x (hidden_in, hidden_out) # (batch..., hidden_in) x (hidden_in, hidden_out)
dot_1_output = tex.gemm( dot_1_output = tex.gemm(
...@@ -261,6 +274,13 @@ def _layernorm_mlp_fwd_rule( ...@@ -261,6 +274,13 @@ def _layernorm_mlp_fwd_rule(
casted_kernel_1.get_colwise_tensor(), casted_kernel_1.get_colwise_tensor(),
(x_contracting_dims, k_contracting_dims), (x_contracting_dims, k_contracting_dims),
) )
dot_1_output_axes = (
*get_non_contracting_logical_axes(x.ndim, dot_1_input_axes, x_contracting_dims),
*get_non_contracting_logical_axes(kernel_1.ndim, kernel_1_axes, k_contracting_dims),
)
dot_1_output = with_sharding_constraint_by_logical_axes(dot_1_output, dot_1_output_axes)
if use_bias_1: if use_bias_1:
bias_1_shape = bias_1.shape bias_1_shape = bias_1.shape
bias_1_new_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1_shape bias_1_new_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1_shape
...@@ -283,6 +303,12 @@ def _layernorm_mlp_fwd_rule( ...@@ -283,6 +303,12 @@ def _layernorm_mlp_fwd_rule(
(x_contracting_dims, k_contracting_dims), (x_contracting_dims, k_contracting_dims),
) )
dot_2_output_axes = (
*get_non_contracting_logical_axes(x.ndim, dot_2_input_axes, x_contracting_dims),
*get_non_contracting_logical_axes(kernel_2.ndim, None, k_contracting_dims),
)
dot_2_output = with_sharding_constraint_by_logical_axes(dot_2_output, dot_2_output_axes)
if use_bias_2: if use_bias_2:
bias_2_shape = bias_2.shape bias_2_shape = bias_2.shape
bias_2_new_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2_shape bias_2_new_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2_shape
...@@ -320,8 +346,10 @@ def _layernorm_mlp_bwd_rule( ...@@ -320,8 +346,10 @@ def _layernorm_mlp_bwd_rule(
norm_input_axes, norm_input_axes,
dot_1_input_axes, dot_1_input_axes,
dot_2_input_axes, dot_2_input_axes,
ffn1_ckpt_name, # pylint: disable=unused-argument kernel_1_axes,
ffn2_ckpt_name, # pylint: disable=unused-argument kernel_2_axes,
ffn1_ckpt_name,
ffn2_ckpt_name,
activation_type, activation_type,
ctx, ctx,
grad, grad,
...@@ -339,6 +367,7 @@ def _layernorm_mlp_bwd_rule( ...@@ -339,6 +367,7 @@ def _layernorm_mlp_bwd_rule(
Returns: Returns:
Tuple of gradients for all input parameters Tuple of gradients for all input parameters
""" """
del norm_input_axes, ffn1_ckpt_name, ffn2_ckpt_name
( (
x, x,
mu, mu,
...@@ -369,11 +398,11 @@ def _layernorm_mlp_bwd_rule( ...@@ -369,11 +398,11 @@ def _layernorm_mlp_bwd_rule(
) )
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim
g_constracting_dim_2 = tuple( g_contracting_dims_2 = tuple(
range(grad.ndim - len(kernel_2_shape) + len(k_contracting_dims_in_fwd), grad.ndim) range(grad.ndim - len(kernel_2_shape) + len(k_contracting_dims_in_fwd), grad.ndim)
) )
# k_non_contracting_dims # k_non_contracting_dims
k_constracting_dim_2 = tuple( k_contracting_dims_2 = tuple(
dim for dim in range(len(kernel_2_shape)) if dim not in k_contracting_dims_in_fwd dim for dim in range(len(kernel_2_shape)) if dim not in k_contracting_dims_in_fwd
) )
...@@ -382,12 +411,12 @@ def _layernorm_mlp_bwd_rule( ...@@ -382,12 +411,12 @@ def _layernorm_mlp_bwd_rule(
dgrad_2 = tex.gemm( dgrad_2 = tex.gemm(
casted_grad.get_rowwise_tensor(), casted_grad.get_rowwise_tensor(),
rowwise_casted_kernel_2, rowwise_casted_kernel_2,
(g_constracting_dim_2, k_constracting_dim_2), (g_contracting_dims_2, k_contracting_dims_2),
) )
dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes) dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes)
x_constracting_dim = g_constracting_dim = tuple( x_contracting_dims = g_contracting_dims = tuple(
range(0, len(x.shape) - len(x_contracting_dims_in_fwd)) range(0, len(x.shape) - len(x_contracting_dims_in_fwd))
) )
...@@ -396,8 +425,9 @@ def _layernorm_mlp_bwd_rule( ...@@ -396,8 +425,9 @@ def _layernorm_mlp_bwd_rule(
wgrad_2 = tex.gemm( wgrad_2 = tex.gemm(
colwise_casted_act_out, colwise_casted_act_out,
casted_grad.get_colwise_tensor(), casted_grad.get_colwise_tensor(),
(x_constracting_dim, g_constracting_dim), (x_contracting_dims, g_contracting_dims),
) )
wgrad_2 = with_sharding_constraint_by_logical_axes(wgrad_2, kernel_2_axes)
casted_dact_out, dbias_1 = tex.quantize_dact_dbias( casted_dact_out, dbias_1 = tex.quantize_dact_dbias(
dgrad_2, dgrad_2,
...@@ -408,11 +438,12 @@ def _layernorm_mlp_bwd_rule( ...@@ -408,11 +438,12 @@ def _layernorm_mlp_bwd_rule(
) )
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim
g_constracting_dim_1 = tuple( dact_out_ndim = casted_dact_out.get_rowwise_tensor().data.ndim
range(dgrad_2.ndim - len(kernel_1_shape) + len(k_contracting_dims_in_fwd), dgrad_2.ndim) g_contracting_dims_1 = tuple(
range(dact_out_ndim - len(kernel_1_shape) + len(k_contracting_dims_in_fwd), dact_out_ndim)
) )
# k_non_contracting_dims # k_non_contracting_dims
k_constracting_dim_1 = tuple( k_contracting_dims_1 = tuple(
dim for dim in range(len(kernel_1_shape)) if dim not in k_contracting_dims_in_fwd dim for dim in range(len(kernel_1_shape)) if dim not in k_contracting_dims_in_fwd
) )
...@@ -420,19 +451,21 @@ def _layernorm_mlp_bwd_rule( ...@@ -420,19 +451,21 @@ def _layernorm_mlp_bwd_rule(
dgrad_1 = tex.gemm( dgrad_1 = tex.gemm(
casted_dact_out.get_rowwise_tensor(), casted_dact_out.get_rowwise_tensor(),
rowwise_casted_kernel_1, rowwise_casted_kernel_1,
(g_constracting_dim_1, k_constracting_dim_1), (g_contracting_dims_1, k_contracting_dims_1),
) )
dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, norm_input_axes) dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, dot_1_input_axes)
# TN GEMM # TN GEMM
# (hidden, batch...) x (hidden, batch...) # (hidden, batch...) x (hidden, batch...)
wgrad_1 = tex.gemm( wgrad_1 = tex.gemm(
colwise_casted_ln_out, colwise_casted_ln_out,
casted_dact_out.get_colwise_tensor(), casted_dact_out.get_colwise_tensor(),
(x_constracting_dim, g_constracting_dim), (x_contracting_dims, g_contracting_dims),
) )
wgrad_1 = with_sharding_constraint_by_logical_axes(wgrad_1, kernel_1_axes)
dx, dgamma, dbeta = tex.normalization_bwd( dx, dgamma, dbeta = tex.normalization_bwd(
dgrad_1, dgrad_1,
x, x,
......
...@@ -57,18 +57,27 @@ class Dequantizer: ...@@ -57,18 +57,27 @@ class Dequantizer:
data = scaled_tensor.data.astype(jnp.float32) data = scaled_tensor.data.astype(jnp.float32)
data_shape = data.shape data_shape = data.shape
scale = scaled_tensor.scale_inv.view(jnp.uint8).astype(jnp.float32) scale = scaled_tensor.scale_inv.view(jnp.uint8).astype(jnp.float32)
flatten_axis = scaled_tensor.flatten_axis
flatten_axis = len(data_shape) + flatten_axis if flatten_axis < 0 else flatten_axis
assert (
0 < flatten_axis < len(data_shape)
), f"flatten_axis {flatten_axis} is out of bounds for shape {data_shape}"
scale_shape = scaled_tensor.scaling_mode.get_scale_shape( scale_shape = scaled_tensor.scaling_mode.get_scale_shape(
scaled_tensor.data.shape, scaled_tensor.is_colwise, is_padded=False data_shape, scaled_tensor.is_colwise, is_padded=False, flatten_axis=flatten_axis
) )
scale = jax.lax.slice(scale, [0] * len(scale_shape), scale_shape) # slice out the padding scale = jax.lax.slice(scale, [0] * len(scale_shape), scale_shape) # slice out the padding
data = data.reshape( data = data.reshape(
*data_shape[:-2], *data_shape[: flatten_axis - 1],
scale_shape[-2], scale_shape[flatten_axis - 1],
int(data_shape[-2] / scale_shape[-2]), int(data_shape[flatten_axis - 1] / scale_shape[flatten_axis - 1]),
*data_shape[flatten_axis:-1],
scale_shape[-1], scale_shape[-1],
int(data_shape[-1] / scale_shape[-1]), int(data_shape[-1] / scale_shape[-1]),
) )
scale = jnp.expand_dims(scale, axis=(-1, -3))
# E8M0 does not have a bit for sign. So 0 - 127 represent negative numbers.
scale = jnp.expand_dims(scale, axis=(flatten_axis + 2 - 2, -1))
# E8M0 does not have a bit for sign. So 0 - 127 represent negative numbers. # E8M0 does not have a bit for sign. So 0 - 127 represent negative numbers.
return jnp.asarray(data * jnp.power(2, scale - 127), scaled_tensor.dq_dtype).reshape( return jnp.asarray(data * jnp.power(2, scale - 127), scaled_tensor.dq_dtype).reshape(
data_shape data_shape
......
...@@ -14,7 +14,7 @@ from typing import Union, Optional ...@@ -14,7 +14,7 @@ from typing import Union, Optional
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class from jax.tree_util import register_pytree_node_class
from transformer_engine_jax import QuantizeAxis from transformer_engine_jax import QuantizeLayout
from .scaling_modes import ScalingMode from .scaling_modes import ScalingMode
from .tensor import ScaledTensor1x, ScaledTensor2x, ScaledTensorFactory from .tensor import ScaledTensor1x, ScaledTensor2x, ScaledTensorFactory
...@@ -24,7 +24,7 @@ from .helper import ( ...@@ -24,7 +24,7 @@ from .helper import (
) )
__all__ = [ __all__ = [
"QuantizeAxis", "QuantizeLayout",
"Quantizer", "Quantizer",
"QuantizerSet", "QuantizerSet",
"DelayedScaleQuantizer", "DelayedScaleQuantizer",
...@@ -45,12 +45,12 @@ class Quantizer(ABC): ...@@ -45,12 +45,12 @@ class Quantizer(ABC):
Attributes: Attributes:
q_dtype: The data type for quantized values q_dtype: The data type for quantized values
scaling_mode: The scaling mode to use for quantization scaling_mode: The scaling mode to use for quantization
q_axis: The quantization axis (row-wise, column-wise, or both) q_layout: The quantization axis (row-wise, column-wise, or both)
""" """
q_dtype: jnp.dtype q_dtype: jnp.dtype
scaling_mode: ScalingMode scaling_mode: ScalingMode
q_axis: QuantizeAxis q_layout: QuantizeLayout
def tree_flatten(self): def tree_flatten(self):
"""Flatten the quantizer for JAX tree operations. """Flatten the quantizer for JAX tree operations.
...@@ -59,7 +59,7 @@ class Quantizer(ABC): ...@@ -59,7 +59,7 @@ class Quantizer(ABC):
Tuple of (children, aux_data) for tree operations Tuple of (children, aux_data) for tree operations
""" """
children = () children = ()
aux_data = (self.q_dtype, self.scaling_mode, self.q_axis) aux_data = (self.q_dtype, self.scaling_mode, self.q_layout)
return (children, aux_data) return (children, aux_data)
@classmethod @classmethod
...@@ -85,30 +85,31 @@ class Quantizer(ABC): ...@@ -85,30 +85,31 @@ class Quantizer(ABC):
Returns: Returns:
True if using both row-wise and column-wise quantization True if using both row-wise and column-wise quantization
""" """
return self.q_axis == QuantizeAxis.ROWWISE_COLWISE return self.q_layout == QuantizeLayout.ROWWISE_COLWISE
@abstractmethod @abstractmethod
def get_layout(self) -> str: def get_data_layout(self) -> str:
"""Get the data layout. """Get the data data_layout.
Returns: Returns:
Data layout in string format Data data_layout in string format
""" """
@abstractmethod @abstractmethod
def _quantize_func(self, x, is_colwise=False, dq_dtype=None) -> ScaledTensor1x: def _quantize_func(self, x, is_colwise=False, dq_dtype=None, flatten_axis=-1) -> ScaledTensor1x:
"""Core quantization function to be implemented by subclasses. """Core quantization function to be implemented by subclasses.
Args: Args:
x: Input tensor to quantize x: Input tensor to quantize
is_colwise: Whether to use column-wise quantization is_colwise: Whether to use column-wise quantization
dq_dtype: Data type for dequantized values, default is x.dtype dq_dtype: Data type for dequantized values, default is x.dtype
flatten_axis: The quantization axis for the tensor
Returns: Returns:
A ScaledTensor1x containing the quantized data A ScaledTensor1x containing the quantized data
""" """
def quantize(self, x, is_rowwise=False, is_colwise=False, dq_dtype=None): def quantize(self, x, is_rowwise=False, is_colwise=False, dq_dtype=None, flatten_axis=-1):
"""Quantize a tensor using the internal _quantize_func(). """Quantize a tensor using the internal _quantize_func().
Args: Args:
...@@ -116,21 +117,26 @@ class Quantizer(ABC): ...@@ -116,21 +117,26 @@ class Quantizer(ABC):
is_rowwise: Whether to use row-wise quantization is_rowwise: Whether to use row-wise quantization
is_colwise: Whether to use column-wise quantization is_colwise: Whether to use column-wise quantization
dq_dtype: Data type for dequantized values dq_dtype: Data type for dequantized values
flatten_axis: The quantization axis for the tensor
Returns: Returns:
A ScaledTensor1x or ScaledTensor2x containing the quantized data A ScaledTensor1x or ScaledTensor2x containing the quantized data
""" """
if (is_rowwise and is_colwise) or self.is_2x2x(): if (is_rowwise and is_colwise) or self.is_2x2x():
rowwise_tensor = self._quantize_func(x, dq_dtype=dq_dtype) rowwise_tensor = self._quantize_func(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis)
colwise_tensor = self._quantize_func(x, is_colwise=True, dq_dtype=dq_dtype) colwise_tensor = self._quantize_func(
x, is_colwise=True, dq_dtype=dq_dtype, flatten_axis=flatten_axis
)
return ScaledTensor2x(rowwise_tensor, colwise_tensor) return ScaledTensor2x(rowwise_tensor, colwise_tensor)
if is_colwise: if is_colwise:
return self._quantize_func(x, is_colwise=True, dq_dtype=dq_dtype) return self._quantize_func(
x, is_colwise=True, dq_dtype=dq_dtype, flatten_axis=flatten_axis
)
return self._quantize_func(x, dq_dtype=dq_dtype) return self._quantize_func(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis)
def get_scale_shapes(self, data_shape, is_padded=True): def get_scale_shapes(self, data_shape, is_padded=True, flatten_axis=-1):
"""Get shapes for scale tensors. """Get shapes for scale tensors.
Args: Args:
...@@ -140,7 +146,7 @@ class Quantizer(ABC): ...@@ -140,7 +146,7 @@ class Quantizer(ABC):
Returns: Returns:
Tuple of (rowwise_scale_shape, colwise_scale_shape) Tuple of (rowwise_scale_shape, colwise_scale_shape)
""" """
return self.scaling_mode.get_scale_shape_2x(data_shape, is_padded) return self.scaling_mode.get_scale_shape_2x(data_shape, is_padded, flatten_axis)
def get_scale_dtype(self): def get_scale_dtype(self):
"""Get the data type for scale tensors. """Get the data type for scale tensors.
...@@ -161,13 +167,13 @@ class DelayedScaleQuantizer(Quantizer): ...@@ -161,13 +167,13 @@ class DelayedScaleQuantizer(Quantizer):
Attributes: Attributes:
scaling_mode: Set to NVTE_DELAYED_TENSOR_SCALING scaling_mode: Set to NVTE_DELAYED_TENSOR_SCALING
q_axis: Quantization axis (default: ROWWISE_COLWISE) q_layout: Quantization axis (default: ROWWISE_COLWISE)
scale: Current scaling factor scale: Current scaling factor
amax_history: History of maximum absolute values amax_history: History of maximum absolute values
""" """
scaling_mode: ScalingMode = ScalingMode.NVTE_DELAYED_TENSOR_SCALING scaling_mode: ScalingMode = ScalingMode.NVTE_DELAYED_TENSOR_SCALING
q_axis: QuantizeAxis = QuantizeAxis.ROWWISE_COLWISE q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE
scale: jnp.ndarray = field(default_factory=lambda: jnp.ones((1,), jnp.float32)) scale: jnp.ndarray = field(default_factory=lambda: jnp.ones((1,), jnp.float32))
amax_history: jnp.ndarray = field( amax_history: jnp.ndarray = field(
...@@ -181,35 +187,37 @@ class DelayedScaleQuantizer(Quantizer): ...@@ -181,35 +187,37 @@ class DelayedScaleQuantizer(Quantizer):
Tuple of (children, aux_data) for tree operations Tuple of (children, aux_data) for tree operations
""" """
children = (self.scale, self.amax_history) children = (self.scale, self.amax_history)
aux_data = (self.q_dtype, self.scaling_mode, self.q_axis) aux_data = (self.q_dtype, self.scaling_mode, self.q_layout)
return (children, aux_data) return (children, aux_data)
def get_layout(self) -> str: def get_data_layout(self) -> str:
"""Get the data layout string. """Get the data data_layout string.
Returns: Returns:
Data layout in string format Data data_layout in string format
Raises: Raises:
ValueError: If quantization axis is invalid ValueError: If quantization axis is invalid
""" """
layout = "NT" data_layout = "NT"
if self.q_axis == QuantizeAxis.ROWWISE_COLWISE: if self.q_layout == QuantizeLayout.ROWWISE_COLWISE:
return layout return data_layout
if self.q_axis == QuantizeAxis.ROWWISE: if self.q_layout == QuantizeLayout.ROWWISE:
return layout[0] return data_layout[0]
if self.q_axis == QuantizeAxis.COLWISE: if self.q_layout == QuantizeLayout.COLWISE:
return layout[1] return data_layout[1]
raise ValueError(f"Invalid q_axis: {self.q_axis}") raise ValueError(f"Invalid q_layout: {self.q_layout}")
def _quantize_func(self, x: jnp.ndarray, is_colwise=False, dq_dtype=None) -> ScaledTensor1x: def _quantize_func(
self, x: jnp.ndarray, is_colwise=False, dq_dtype=None, flatten_axis=-1
) -> ScaledTensor1x:
"""Quantize function helper for delayed scaling FP8. """Quantize function helper for delayed scaling FP8.
Args: Args:
x: Input tensor to quantize x: Input tensor to quantize
is_colwise: Whether to use column-wise quantization is_colwise: Whether to use column-wise quantization
dq_dtype: Data type for dequantized values dq_dtype: Data type for dequantized values
flatten_axis: The quantization axis for the tensor
Returns: Returns:
A ScaledTensor1x containing the quantized data A ScaledTensor1x containing the quantized data
""" """
...@@ -232,9 +240,12 @@ class DelayedScaleQuantizer(Quantizer): ...@@ -232,9 +240,12 @@ class DelayedScaleQuantizer(Quantizer):
scale_inv=scale_inv, scale_inv=scale_inv,
scaling_mode=self.scaling_mode, scaling_mode=self.scaling_mode,
dq_dtype=dq_dtype, dq_dtype=dq_dtype,
flatten_axis=flatten_axis,
) )
def quantize(self, x, is_rowwise: bool = None, is_colwise: bool = None, dq_dtype=None): def quantize(
self, x, is_rowwise: bool = None, is_colwise: bool = None, dq_dtype=None, flatten_axis=-1
):
"""Quantize a tensor using the internal _quantize_func(). """Quantize a tensor using the internal _quantize_func().
Args: Args:
...@@ -242,32 +253,40 @@ class DelayedScaleQuantizer(Quantizer): ...@@ -242,32 +253,40 @@ class DelayedScaleQuantizer(Quantizer):
is_rowwise: Whether to use row-wise quantization is_rowwise: Whether to use row-wise quantization
is_colwise: Whether to use column-wise quantization is_colwise: Whether to use column-wise quantization
dq_dtype: Data type for dequantized values dq_dtype: Data type for dequantized values
flatten_axis: The quantization axis for the tensor
Returns: Returns:
A ScaledTensor1x or ScaledTensor2x containing the quantized data A ScaledTensor1x or ScaledTensor2x containing the quantized data
""" """
dq_dtype = dq_dtype if dq_dtype is not None else x.dtype dq_dtype = dq_dtype if dq_dtype is not None else x.dtype
if flatten_axis < 0:
flatten_axis += x.ndim
assert 0 < flatten_axis < x.ndim, "flatten_axis is out of bounds!"
is_rowwise = ( is_rowwise = (
is_rowwise is_rowwise
if is_rowwise is not None if is_rowwise is not None
else (self.q_axis == QuantizeAxis.ROWWISE or self.is_2x2x()) else (self.q_layout == QuantizeLayout.ROWWISE or self.is_2x2x())
) )
is_colwise = ( is_colwise = (
is_colwise is_colwise
if is_colwise is not None if is_colwise is not None
else (self.q_axis == QuantizeAxis.COLWISE or self.is_2x2x()) else (self.q_layout == QuantizeLayout.COLWISE or self.is_2x2x())
) )
rowwise_tensor = self._quantize_func(x, dq_dtype=dq_dtype) rowwise_tensor = self._quantize_func(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis)
colwise_tensor = None colwise_tensor = None
if is_colwise: if is_colwise:
colwise_tensor = ScaledTensorFactory.create_1x( colwise_tensor = ScaledTensorFactory.create_1x(
data=jnp.transpose(rowwise_tensor.data, (-1, *range(rowwise_tensor.data.ndim - 1))), data=jnp.transpose(
rowwise_tensor.data, (*range(flatten_axis, x.ndim), *range(flatten_axis))
),
scale_inv=rowwise_tensor.scale_inv, scale_inv=rowwise_tensor.scale_inv,
scaling_mode=self.scaling_mode, scaling_mode=self.scaling_mode,
dq_dtype=dq_dtype, dq_dtype=dq_dtype,
is_colwise=True, is_colwise=True,
layout="T", data_layout="T",
flatten_axis=flatten_axis,
) )
if is_colwise and is_rowwise: if is_colwise and is_rowwise:
return ScaledTensor2x(rowwise_tensor, colwise_tensor) return ScaledTensor2x(rowwise_tensor, colwise_tensor)
...@@ -353,46 +372,56 @@ class BlockScaleQuantizer(Quantizer): ...@@ -353,46 +372,56 @@ class BlockScaleQuantizer(Quantizer):
Attributes: Attributes:
scaling_mode: Set to NVTE_MXFP8_1D_SCALING scaling_mode: Set to NVTE_MXFP8_1D_SCALING
q_axis: Quantization axis (default: ROWWISE_COLWISE) q_layout: Quantization axis (default: ROWWISE_COLWISE)
""" """
scaling_mode: ScalingMode = ScalingMode.NVTE_MXFP8_1D_SCALING scaling_mode: ScalingMode = ScalingMode.NVTE_MXFP8_1D_SCALING
q_axis: QuantizeAxis = QuantizeAxis.ROWWISE_COLWISE q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE
def get_layout(self) -> str: def get_data_layout(self) -> str:
"""Get the data layout string. """Get the data data_layout string.
Returns: Returns:
Data layout in string format Data data_layout in string format
""" """
if self.is_2x2x(): if self.is_2x2x():
return "NN" return "NN"
return "N" return "N"
def _quantize_func(self, x, is_colwise=False, dq_dtype=None) -> ScaledTensor1x: def _quantize_func(self, x, is_colwise=False, dq_dtype=None, flatten_axis=-1) -> ScaledTensor1x:
"""Quantize function helper for block scaling FP8. """Quantize function helper for block scaling FP8.
Args: Args:
x: Input tensor to quantize x: Input tensor to quantize
is_colwise: Whether to use column-wise quantization is_colwise: Whether to use column-wise quantization
dq_dtype: Data type for dequantized values dq_dtype: Data type for dequantized values
flatten_axis: The quantization axis for the tensor
Returns: Returns:
A ScaledTensor1x containing the quantized data A ScaledTensor1x containing the quantized data
""" """
# TODO(Phuong): use quantize_func from JAX # TODO(Phuong): use quantize_func from JAX
if flatten_axis < 0:
flatten_axis = x.ndim + flatten_axis
assert (
0 <= flatten_axis < x.ndim
), f"Invalid flatten_axis: {flatten_axis} for tensor of shape {x.shape}"
dq_dtype = dq_dtype if dq_dtype is not None else x.dtype dq_dtype = dq_dtype if dq_dtype is not None else x.dtype
x_shape = x.shape x_shape = x.shape
scale_shape = self.scaling_mode.get_scale_shape(x_shape, is_colwise, is_padded=False) scale_shape = self.scaling_mode.get_scale_shape(
x_shape, is_colwise, is_padded=False, flatten_axis=flatten_axis
)
scale_dtype = self.scaling_mode.get_scale_dtype() scale_dtype = self.scaling_mode.get_scale_dtype()
x = x.reshape( x = x.reshape(
*x_shape[:-2], *x_shape[: flatten_axis - 1],
scale_shape[-2], scale_shape[flatten_axis - 1],
int(x_shape[-2] / scale_shape[-2]), int(x_shape[flatten_axis - 1] / scale_shape[flatten_axis - 1]),
*x_shape[flatten_axis:-1],
scale_shape[-1], scale_shape[-1],
int(x_shape[-1] / scale_shape[-1]), int(x_shape[-1] / scale_shape[-1]),
) )
amax = jnp.max(jnp.abs(x), axis=(-3, -1), keepdims=True) amax = jnp.max(jnp.abs(x), axis=(flatten_axis + 2 - 2, -1), keepdims=True)
MAX = jnp.finfo(self.q_dtype).max.astype(jnp.float32) MAX = jnp.finfo(self.q_dtype).max.astype(jnp.float32)
scales = amax.astype(jnp.float32) / MAX scales = amax.astype(jnp.float32) / MAX
...@@ -409,6 +438,7 @@ class BlockScaleQuantizer(Quantizer): ...@@ -409,6 +438,7 @@ class BlockScaleQuantizer(Quantizer):
self.scaling_mode, self.scaling_mode,
is_colwise=is_colwise, is_colwise=is_colwise,
dq_dtype=dq_dtype, dq_dtype=dq_dtype,
flatten_axis=flatten_axis,
) )
def _cast_to_e8m0_with_rounding_up(self, scales): def _cast_to_e8m0_with_rounding_up(self, scales):
...@@ -509,7 +539,7 @@ class QuantizerFactory: ...@@ -509,7 +539,7 @@ class QuantizerFactory:
n_quantizers: int = 1, n_quantizers: int = 1,
scaling_mode: ScalingMode = None, scaling_mode: ScalingMode = None,
q_dtype: jnp.dtype = None, q_dtype: jnp.dtype = None,
q_axis: QuantizeAxis = None, q_layout: QuantizeLayout = None,
**kwargs, **kwargs,
) -> Quantizer: ) -> Quantizer:
"""Create one or more quantizers with specified parameters. """Create one or more quantizers with specified parameters.
...@@ -518,7 +548,8 @@ class QuantizerFactory: ...@@ -518,7 +548,8 @@ class QuantizerFactory:
n_quantizers: Number of quantizers to create n_quantizers: Number of quantizers to create
scaling_mode: Scaling mode to use scaling_mode: Scaling mode to use
q_dtype: Quantization data type q_dtype: Quantization data type
q_axis: Quantization axis q_layout: Quantization axis
flatten_axis: The quantization axis for the tensor
**kwargs: Additional arguments for quantizer initialization **kwargs: Additional arguments for quantizer initialization
Returns: Returns:
...@@ -534,7 +565,7 @@ class QuantizerFactory: ...@@ -534,7 +565,7 @@ class QuantizerFactory:
quantizer_type = QuantizerFactory.quantizer_type_map.get(scaling_mode) quantizer_type = QuantizerFactory.quantizer_type_map.get(scaling_mode)
quantizers.append( quantizers.append(
quantizer_type( quantizer_type(
q_dtype=q_dtype, scaling_mode=scaling_mode, q_axis=q_axis, **kwargs q_dtype=q_dtype, scaling_mode=scaling_mode, q_layout=q_layout, **kwargs
) )
) )
return quantizers[0] if len(quantizers) == 1 else tuple(quantizers) return quantizers[0] if len(quantizers) == 1 else tuple(quantizers)
...@@ -554,11 +585,11 @@ class QuantizerFactory: ...@@ -554,11 +585,11 @@ class QuantizerFactory:
A QuantizerSet instance A QuantizerSet instance
""" """
if is_2x2x: if is_2x2x:
q_axis_x = q_axis_kernel = q_axis_dgrad = QuantizeAxis.ROWWISE_COLWISE q_layout_x = q_layout_kernel = q_layout_dgrad = QuantizeLayout.ROWWISE_COLWISE
else: else:
q_axis_x = QuantizeAxis.ROWWISE q_layout_x = QuantizeLayout.ROWWISE
q_axis_kernel = QuantizeAxis.COLWISE q_layout_kernel = QuantizeLayout.COLWISE
q_axis_dgrad = None q_layout_dgrad = None
if "quantize_meta_set" in kwargs: if "quantize_meta_set" in kwargs:
quantize_meta_set = kwargs.get("quantize_meta_set") quantize_meta_set = kwargs.get("quantize_meta_set")
...@@ -577,9 +608,11 @@ class QuantizerFactory: ...@@ -577,9 +608,11 @@ class QuantizerFactory:
else: else:
args_x = args_kernel = args_grad = {} args_x = args_kernel = args_grad = {}
q_x = QuantizerFactory.create(1, scaling_mode, fwd_dtype, q_axis_x, **args_x) q_x = QuantizerFactory.create(1, scaling_mode, fwd_dtype, q_layout_x, **args_x)
q_kernel = QuantizerFactory.create(1, scaling_mode, fwd_dtype, q_axis_kernel, **args_kernel) q_kernel = QuantizerFactory.create(
q_dgrad = QuantizerFactory.create(1, scaling_mode, bwd_dtype, q_axis_dgrad, **args_grad) 1, scaling_mode, fwd_dtype, q_layout_kernel, **args_kernel
)
q_dgrad = QuantizerFactory.create(1, scaling_mode, bwd_dtype, q_layout_dgrad, **args_grad)
return QuantizerSet(x=q_x, kernel=q_kernel, dgrad=q_dgrad) return QuantizerSet(x=q_x, kernel=q_kernel, dgrad=q_dgrad)
@staticmethod @staticmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment