Commit ff884e20 (unverified), authored by Phuong Nguyen, committed by GitHub
Browse files

[JAX] Flatten_axis for quantization and Sharding propagation fixes (#1644)



* rename QuantizeAxis to QuantizeLayout, get_layout to get_data_layout, q_axis to q_layout

* add flatten_axis option

* added gated act to test encoder

* sharding constraint fixes

* fix padding when the flattened first dim needs to be padded

* update test sizes so that padding is tested

* rm output sharding as it can be done in the flax module

* sharding scale_inv for mxfp8

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent be1f647c
...@@ -57,13 +57,14 @@ class Net(nn.Module): ...@@ -57,13 +57,14 @@ class Net(nn.Module):
self_attn_mask_type="padding", self_attn_mask_type="padding",
enable_relative_embedding=False, enable_relative_embedding=False,
enable_sequence_parallel=self.enable_seq_paral, enable_sequence_parallel=self.enable_seq_paral,
mlp_activations=("gelu", "linear"),
) )
x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout) x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)
x = x.reshape(x.shape[0], -1) x = x.reshape(x.shape[0], -1)
if self.enable_seq_paral: if self.enable_seq_paral:
# Trigger all-gather to collect a complete tensor alone seqence on each device. # Trigger all-gather to collect a complete tensor alone sequence on each device.
x = jax.lax.with_sharding_constraint( x = jax.lax.with_sharding_constraint(
x, jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None) x, jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None)
) )
...@@ -459,7 +460,7 @@ class TestEncoder(unittest.TestCase): ...@@ -459,7 +460,7 @@ class TestEncoder(unittest.TestCase):
def test_te_bf16(self): def test_te_bf16(self):
"""Test Transformer Engine with BF16""" """Test Transformer Engine with BF16"""
actual = train_and_evaluate(self.args) actual = train_and_evaluate(self.args)
assert actual[0] < 0.50 and actual[1] > 0.76 assert actual[0] < 0.455 and actual[1] > 0.785
@unittest.skipIf(not is_fp8_supported, fp8_reason) @unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_delayed_scaling_fp8(self): def test_te_delayed_scaling_fp8(self):
...@@ -467,7 +468,7 @@ class TestEncoder(unittest.TestCase): ...@@ -467,7 +468,7 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True self.args.use_fp8 = True
self.args.fp8_recipe = "DelayedScaling" self.args.fp8_recipe = "DelayedScaling"
actual = train_and_evaluate(self.args) actual = train_and_evaluate(self.args)
assert actual[0] < 0.50 and actual[1] > 0.76 assert actual[0] < 0.455 and actual[1] > 0.785
@unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
def test_te_mxfp8(self): def test_te_mxfp8(self):
...@@ -475,14 +476,14 @@ class TestEncoder(unittest.TestCase): ...@@ -475,14 +476,14 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True self.args.use_fp8 = True
self.args.fp8_recipe = "MXFP8BlockScaling" self.args.fp8_recipe = "MXFP8BlockScaling"
actual = train_and_evaluate(self.args) actual = train_and_evaluate(self.args)
assert actual[0] < 0.50 and actual[1] > 0.76 assert actual[0] < 0.455 and actual[1] > 0.785
@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16_with_sp(self): def test_te_bf16_with_sp(self):
"""Test Transformer Engine with BF16 + SP""" """Test Transformer Engine with BF16 + SP"""
self.args.enable_sp = True self.args.enable_sp = True
actual = train_and_evaluate(self.args) actual = train_and_evaluate(self.args)
assert actual[0] < 0.50 and actual[1] > 0.76 assert actual[0] < 0.455 and actual[1] > 0.785
@unittest.skipIf(not is_fp8_supported, fp8_reason) @unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_delayed_scaling_fp8_with_sp(self): def test_te_delayed_scaling_fp8_with_sp(self):
...@@ -491,7 +492,7 @@ class TestEncoder(unittest.TestCase): ...@@ -491,7 +492,7 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True self.args.use_fp8 = True
self.args.fp8_recipe = "DelayedScaling" self.args.fp8_recipe = "DelayedScaling"
actual = train_and_evaluate(self.args) actual = train_and_evaluate(self.args)
assert actual[0] < 0.50 and actual[1] > 0.76 assert actual[0] < 0.455 and actual[1] > 0.785
@unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
def test_te_mxfp8_with_sp(self): def test_te_mxfp8_with_sp(self):
...@@ -500,7 +501,7 @@ class TestEncoder(unittest.TestCase): ...@@ -500,7 +501,7 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True self.args.use_fp8 = True
self.args.fp8_recipe = "MXFP8BlockScaling" self.args.fp8_recipe = "MXFP8BlockScaling"
actual = train_and_evaluate(self.args) actual = train_and_evaluate(self.args)
assert actual[0] < 0.50 and actual[1] > 0.76 assert actual[0] < 0.455 and actual[1] > 0.785
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -29,7 +29,7 @@ from transformer_engine.jax.quantize import ( ...@@ -29,7 +29,7 @@ from transformer_engine.jax.quantize import (
ScaledTensor, ScaledTensor,
ScalingMode, ScalingMode,
QuantizerFactory, QuantizerFactory,
QuantizeAxis, QuantizeLayout,
) )
from transformer_engine.jax.quantize import helper from transformer_engine.jax.quantize import helper
from transformer_engine.jax.activation import activation from transformer_engine.jax.activation import activation
...@@ -82,8 +82,9 @@ def assert_bitwise_scaled_tensors(a: ScaledTensor, b: ScaledTensor): ...@@ -82,8 +82,9 @@ def assert_bitwise_scaled_tensors(a: ScaledTensor, b: ScaledTensor):
def assert_dequantized_scaled_tensor(a: ScaledTensor, b: jnp.ndarray): def assert_dequantized_scaled_tensor(a: ScaledTensor, b: jnp.ndarray):
if isinstance(a, ScaledTensor1x): if isinstance(a, ScaledTensor1x):
if a.layout == "T": if a.data_layout == "T":
b_transpose = jnp.transpose(b, (-1, *range(b.ndim - 1))) flatten_axis = a.data.ndim - a.flatten_axis
b_transpose = jnp.transpose(b, (*range(flatten_axis, b.ndim), *range(flatten_axis)))
assert_allclose(a.dequantize(), b_transpose, dtype=a.data.dtype) assert_allclose(a.dequantize(), b_transpose, dtype=a.data.dtype)
else: else:
assert_allclose(a.dequantize(), b, dtype=a.data.dtype) assert_allclose(a.dequantize(), b, dtype=a.data.dtype)
...@@ -141,7 +142,8 @@ class TestActivation: ...@@ -141,7 +142,8 @@ class TestActivation:
def test_act_grad(self, shape, activation_type): def test_act_grad(self, shape, activation_type):
key = jax.random.PRNGKey(0) key = jax.random.PRNGKey(0)
x = jax.random.uniform(key, shape, jnp.float32) x = jax.random.uniform(key, shape, jnp.float32)
x = jnp.repeat(x, len(activation_type), axis=-1) x = jnp.expand_dims(x, axis=-2)
x = jnp.repeat(x, len(activation_type), axis=-2)
value_n_grad_primitive_func = jit( value_n_grad_primitive_func = jit(
value_and_grad(self.primitive_func, (0,)), static_argnums=(1,) value_and_grad(self.primitive_func, (0,)), static_argnums=(1,)
...@@ -159,7 +161,8 @@ class TestActivation: ...@@ -159,7 +161,8 @@ class TestActivation:
@pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2]) @pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])
def test_act_grad_with_delayed_scaling_fp8(self, random_inputs, activation_type, output_type): def test_act_grad_with_delayed_scaling_fp8(self, random_inputs, activation_type, output_type):
x = random_inputs x = random_inputs
x = jnp.repeat(x, len(activation_type), axis=-1) x = jnp.expand_dims(x, axis=-2)
x = jnp.repeat(x, len(activation_type), axis=-2)
self.activation_type = activation_type self.activation_type = activation_type
value_n_grad_primitive_func = jit( value_n_grad_primitive_func = jit(
...@@ -169,7 +172,7 @@ class TestActivation: ...@@ -169,7 +172,7 @@ class TestActivation:
quantizer = QuantizerFactory.create( quantizer = QuantizerFactory.create(
scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING,
q_dtype=output_type, q_dtype=output_type,
q_axis=QuantizeAxis.ROWWISE, q_layout=QuantizeLayout.ROWWISE,
) )
prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, quantizer) prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, quantizer)
...@@ -182,19 +185,22 @@ class TestActivation: ...@@ -182,19 +185,22 @@ class TestActivation:
@pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES) @pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES)
@pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES) @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
@pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2]) @pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@pytest_parametrize_wrapper("q_axis", [QuantizeAxis.ROWWISE, QuantizeAxis.ROWWISE_COLWISE]) @pytest_parametrize_wrapper(
"q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
)
def test_act_forward_with_delayed_scaling_fp8( def test_act_forward_with_delayed_scaling_fp8(
self, random_inputs, activation_type, output_type, q_axis self, random_inputs, activation_type, output_type, q_layout
): ):
x = random_inputs x = random_inputs
x = jnp.repeat(x, len(activation_type), axis=-1) x = jnp.expand_dims(x, axis=-2)
x = jnp.repeat(x, len(activation_type), axis=-2)
self.activation_type = activation_type self.activation_type = activation_type
te_quantizer, jax_quantizer = QuantizerFactory.create( te_quantizer, jax_quantizer = QuantizerFactory.create(
n_quantizers=2, n_quantizers=2,
scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING,
q_dtype=output_type, q_dtype=output_type,
q_axis=q_axis, q_layout=q_layout,
) )
te_output = tex.act_lu(x, activation_type, te_quantizer) te_output = tex.act_lu(x, activation_type, te_quantizer)
...@@ -203,19 +209,21 @@ class TestActivation: ...@@ -203,19 +209,21 @@ class TestActivation:
assert_bitwise_scaled_tensors(te_output, jax_output) assert_bitwise_scaled_tensors(te_output, jax_output)
@pytest.mark.skipif(not is_mxfp8_supported, reason=reason) @pytest.mark.skipif(not is_mxfp8_supported, reason=reason)
@pytest_parametrize_wrapper("shape", [(128, 128)]) @pytest_parametrize_wrapper("shape", [(2, 64, 1, 256)])
@pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES) @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
@pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2]) @pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@pytest_parametrize_wrapper("q_axis", [QuantizeAxis.ROWWISE, QuantizeAxis.ROWWISE_COLWISE]) @pytest_parametrize_wrapper(
"q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
)
def test_act_forward_with_block_scaling_fp8( def test_act_forward_with_block_scaling_fp8(
self, random_inputs, activation_type, output_type, q_axis self, random_inputs, activation_type, output_type, q_layout
): ):
x = random_inputs x = random_inputs
x = jnp.repeat(x, len(activation_type), axis=-1) x = jnp.repeat(x, len(activation_type), axis=-2)
self.activation_type = activation_type self.activation_type = activation_type
quantizer = QuantizerFactory.create( quantizer = QuantizerFactory.create(
scaling_mode=ScalingMode.NVTE_MXFP8_1D_SCALING, q_dtype=output_type, q_axis=q_axis scaling_mode=ScalingMode.NVTE_MXFP8_1D_SCALING, q_dtype=output_type, q_layout=q_layout
) )
output = tex.act_lu(x, activation_type, quantizer) output = tex.act_lu(x, activation_type, quantizer)
...@@ -324,9 +332,11 @@ class TestNorm: ...@@ -324,9 +332,11 @@ class TestNorm:
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
# No Norm FWD E5M2 in TE backend # No Norm FWD E5M2 in TE backend
@pytest_parametrize_wrapper("out_dtype", [jnp.float8_e4m3fn]) @pytest_parametrize_wrapper("out_dtype", [jnp.float8_e4m3fn])
@pytest_parametrize_wrapper("q_axis", [QuantizeAxis.ROWWISE, QuantizeAxis.ROWWISE_COLWISE]) @pytest_parametrize_wrapper(
"q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
)
def test_norm_grad_with_delayed_scaling_fp8( def test_norm_grad_with_delayed_scaling_fp8(
self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, out_dtype, q_axis self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, out_dtype, q_layout
): ):
""" """
Test transformer_engine.jax.layernorm.layernorm Test transformer_engine.jax.layernorm.layernorm
...@@ -335,7 +345,9 @@ class TestNorm: ...@@ -335,7 +345,9 @@ class TestNorm:
pytest.skip("RMSNorm and zero_centered_gamma is not supported!") pytest.skip("RMSNorm and zero_centered_gamma is not supported!")
quantizer = QuantizerFactory.create( quantizer = QuantizerFactory.create(
scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, q_dtype=out_dtype, q_axis=q_axis scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING,
q_dtype=out_dtype,
q_layout=q_layout,
) )
self._test_norm_grad( self._test_norm_grad(
n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, quantizer n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, quantizer
...@@ -351,7 +363,7 @@ class TestNorm: ...@@ -351,7 +363,7 @@ class TestNorm:
inp_dtype, inp_dtype,
out_dtype, out_dtype,
scaling_mode, scaling_mode,
q_axis, q_layout,
): ):
key = jax.random.PRNGKey(0) key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 3) subkeys = jax.random.split(key, 3)
...@@ -363,7 +375,7 @@ class TestNorm: ...@@ -363,7 +375,7 @@ class TestNorm:
gamma = jnp.asarray(gamma, inp_dtype) gamma = jnp.asarray(gamma, inp_dtype)
quantizer, ref_quantizer = QuantizerFactory.create( quantizer, ref_quantizer = QuantizerFactory.create(
n_quantizers=2, scaling_mode=scaling_mode, q_dtype=out_dtype, q_axis=q_axis n_quantizers=2, scaling_mode=scaling_mode, q_dtype=out_dtype, q_layout=q_layout
) )
if norm_type == "layernorm": if norm_type == "layernorm":
beta = jax.random.uniform(subkeys[2], (hidden,), jnp.float32, -1, 1) beta = jax.random.uniform(subkeys[2], (hidden,), jnp.float32, -1, 1)
...@@ -391,9 +403,11 @@ class TestNorm: ...@@ -391,9 +403,11 @@ class TestNorm:
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
# No Norm FWD E5M2 in TE backend # No Norm FWD E5M2 in TE backend
@pytest_parametrize_wrapper("out_dtype", [jnp.float8_e4m3fn]) @pytest_parametrize_wrapper("out_dtype", [jnp.float8_e4m3fn])
@pytest_parametrize_wrapper("q_axis", [QuantizeAxis.ROWWISE, QuantizeAxis.ROWWISE_COLWISE]) @pytest_parametrize_wrapper(
"q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
)
def test_norm_forward_with_delayed_scaling_fp8( def test_norm_forward_with_delayed_scaling_fp8(
self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, out_dtype, q_axis self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, out_dtype, q_layout
): ):
if norm_type == "rmsnorm" and zero_centered_gamma is True: if norm_type == "rmsnorm" and zero_centered_gamma is True:
pytest.skip("RMSNorm and zero_centered_gamma is not supported!") pytest.skip("RMSNorm and zero_centered_gamma is not supported!")
...@@ -407,7 +421,7 @@ class TestNorm: ...@@ -407,7 +421,7 @@ class TestNorm:
inp_dtype=inp_dtype, inp_dtype=inp_dtype,
out_dtype=out_dtype, out_dtype=out_dtype,
scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING,
q_axis=q_axis, q_layout=q_layout,
) )
@pytest.mark.skipif(not is_mxfp8_supported, reason=reason) @pytest.mark.skipif(not is_mxfp8_supported, reason=reason)
...@@ -424,7 +438,7 @@ class TestNorm: ...@@ -424,7 +438,7 @@ class TestNorm:
inp_dtype=inp_dtype, inp_dtype=inp_dtype,
out_dtype=out_dtype, out_dtype=out_dtype,
scaling_mode=ScalingMode.NVTE_MXFP8_1D_SCALING, scaling_mode=ScalingMode.NVTE_MXFP8_1D_SCALING,
q_axis=QuantizeAxis.ROWWISE_COLWISE, q_layout=QuantizeLayout.ROWWISE_COLWISE,
) )
...@@ -434,14 +448,14 @@ QUANTIZE_OUTPUT_DTYPES = { ...@@ -434,14 +448,14 @@ QUANTIZE_OUTPUT_DTYPES = {
} }
ALL_QUANTIZE_TEST_SHAPES = [ ALL_QUANTIZE_TEST_SHAPES = [
(128, 128), (32, 64),
(4, 256, 512), (2, 64, 32),
] ]
QUANTIZE_TEST_SHAPES = { QUANTIZE_TEST_SHAPES = {
"L0": [ "L0": [
(256, 128), (32, 256, 128),
(64, 16, 2, 256), (64, 32, 32, 256),
], ],
"L2": ALL_QUANTIZE_TEST_SHAPES, "L2": ALL_QUANTIZE_TEST_SHAPES,
} }
...@@ -457,48 +471,52 @@ QUANTIZATION_INPUT_DTYPE = { ...@@ -457,48 +471,52 @@ QUANTIZATION_INPUT_DTYPE = {
@pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2]) @pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@pytest_parametrize_wrapper("input_shape", ALL_QUANTIZE_TEST_SHAPES) @pytest_parametrize_wrapper("input_shape", ALL_QUANTIZE_TEST_SHAPES)
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@pytest_parametrize_wrapper("flatten_axis", [-1, -2])
@pytest_parametrize_wrapper( @pytest_parametrize_wrapper(
"q_axis", [QuantizeAxis.ROWWISE, QuantizeAxis.COLWISE, QuantizeAxis.ROWWISE_COLWISE] "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE]
) )
class TestQuantize: class TestQuantize:
""" """
Purely quantization related tests that will always test on a wider set of types and shapes Purely quantization related tests that will always test on a wider set of types and shapes
""" """
def test_qdq(self, in_dtype, input_shape, q_dtype, scaling_mode, q_axis): def test_qdq(self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis):
key = jax.random.PRNGKey(0) key = jax.random.PRNGKey(0)
# Quantizer is created once as some quantization approaches use state from previous iterations (e.g. delayed scaling) # Quantizer is created once as some quantization approaches use state from previous iterations (e.g. delayed scaling)
quantizer = QuantizerFactory.create( quantizer = QuantizerFactory.create(
scaling_mode=scaling_mode, scaling_mode=scaling_mode,
q_dtype=q_dtype, q_dtype=q_dtype,
q_axis=q_axis, q_layout=q_layout,
) )
# Adding dimension to test if padding is done correctly when flatten 3D to 2D
if flatten_axis == -2:
input_shape = input_shape[:-1] + (2,) + input_shape[-1:]
n_iterations = 3 if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING else 1 n_iterations = 3 if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING else 1
for _ in range(n_iterations): for _ in range(n_iterations):
x = jax.random.uniform(key, input_shape, in_dtype) x = jax.random.uniform(key, input_shape, in_dtype)
scaled_tensor = quantizer.quantize(x) scaled_tensor = quantizer.quantize(x, flatten_axis=flatten_axis)
assert_dequantized_scaled_tensor(scaled_tensor, x) assert_dequantized_scaled_tensor(scaled_tensor, x)
def test_quantize_bitwise(self, in_dtype, input_shape, q_dtype, scaling_mode, q_axis): def test_quantize_bitwise(
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING and not is_shape_supported_by_mxfp8( self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis
input_shape ):
):
pytest.skip(f"Input shape {input_shape} is not supported by MXFP8")
key = jax.random.PRNGKey(0) key = jax.random.PRNGKey(0)
if flatten_axis == -2:
input_shape = input_shape[:-1] + (2,) + input_shape[-1:]
input = jax.random.uniform(key, input_shape, in_dtype) input = jax.random.uniform(key, input_shape, in_dtype)
te_quantizer, jax_quantizer = QuantizerFactory.create( te_quantizer, jax_quantizer = QuantizerFactory.create(
n_quantizers=2, q_dtype=q_dtype, scaling_mode=scaling_mode, q_axis=q_axis n_quantizers=2, q_dtype=q_dtype, scaling_mode=scaling_mode, q_layout=q_layout
) )
jax_output = _jax_quantize(input, quantizer=jax_quantizer) jax_output = _jax_quantize(input, quantizer=jax_quantizer, flatten_axis=flatten_axis)
te_output = tex.quantize(input, quantizer=te_quantizer) te_output = tex.quantize(input, quantizer=te_quantizer, flatten_axis=flatten_axis)
assert_bitwise_scaled_tensors(jax_output, te_output) assert_bitwise_scaled_tensors(te_output, jax_output)
@pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE) @pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE)
...@@ -508,9 +526,13 @@ class TestFusedQuantize: ...@@ -508,9 +526,13 @@ class TestFusedQuantize:
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@pytest_parametrize_wrapper("input_shape", QUANTIZE_TEST_SHAPES) @pytest_parametrize_wrapper("input_shape", QUANTIZE_TEST_SHAPES)
@pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES) @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES)
@pytest_parametrize_wrapper("q_axis", [QuantizeAxis.ROWWISE, QuantizeAxis.ROWWISE_COLWISE]) @pytest_parametrize_wrapper(
def test_quantize_dbias(self, in_dtype, input_shape, out_dtype, scaling_mode, q_axis): "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
transpose_axis = -1 )
@pytest_parametrize_wrapper("flatten_axis", [-1, -2])
def test_quantize_dbias(
self, in_dtype, input_shape, out_dtype, scaling_mode, q_layout, flatten_axis
):
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING and not is_shape_supported_by_mxfp8( if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING and not is_shape_supported_by_mxfp8(
input_shape input_shape
): ):
...@@ -520,35 +542,37 @@ class TestFusedQuantize: ...@@ -520,35 +542,37 @@ class TestFusedQuantize:
input = jax.random.uniform(key, input_shape, in_dtype) input = jax.random.uniform(key, input_shape, in_dtype)
jax_quantizer, te_quantizer = QuantizerFactory.create( jax_quantizer, te_quantizer = QuantizerFactory.create(
n_quantizers=2, q_dtype=out_dtype, scaling_mode=scaling_mode, q_axis=q_axis n_quantizers=2, q_dtype=out_dtype, scaling_mode=scaling_mode, q_layout=q_layout
) )
te_output, te_dbias = jit(lambda input: tex.quantize_dbias(input, quantizer=te_quantizer))( te_output, te_dbias = jit(
input lambda input: tex.quantize_dbias(
) input, quantizer=te_quantizer, flatten_axis=flatten_axis
)
)(input)
jax_output, jax_dbias = jit( jax_output, jax_dbias = jit(
lambda input: _jax_quantize_dbias( lambda input: _jax_quantize_dbias(
input, input, quantizer=jax_quantizer, flatten_axis=flatten_axis
quantizer=jax_quantizer,
) )
)(input) )(input)
assert_bitwise_scaled_tensors(jax_output, te_output) assert_bitwise_scaled_tensors(te_output, jax_output)
assert_allclose(jax_dbias, te_dbias) assert_allclose(te_dbias, jax_dbias)
def _test_quantize_dact_dbias( def _test_quantize_dact_dbias(
self, in_dtype, input_shape, out_dtype, scaling_mode, activation_type, is_dbias, q_axis self, in_dtype, input_shape, out_dtype, scaling_mode, activation_type, is_dbias, q_layout
): ):
key = jax.random.PRNGKey(0) key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 2) subkeys = jax.random.split(key, 2)
x = jax.random.uniform(subkeys[0], input_shape, in_dtype, -1, 1) x = jax.random.uniform(subkeys[0], input_shape, in_dtype, -1, 1)
x = jnp.repeat(x, len(activation_type), axis=-1) x = jnp.expand_dims(x, axis=-2)
x = jnp.repeat(x, len(activation_type), axis=-2)
dz = jax.random.uniform(subkeys[1], input_shape, in_dtype, -1, 1) dz = jax.random.uniform(subkeys[1], input_shape, in_dtype, -1, 1)
jax_quantizer, te_quantizer = QuantizerFactory.create( jax_quantizer, te_quantizer = QuantizerFactory.create(
n_quantizers=2, q_dtype=out_dtype, scaling_mode=scaling_mode, q_axis=q_axis n_quantizers=2, q_dtype=out_dtype, scaling_mode=scaling_mode, q_layout=q_layout
) )
is_casted_output = te_quantizer is not None is_casted_output = te_quantizer is not None
...@@ -573,12 +597,12 @@ class TestFusedQuantize: ...@@ -573,12 +597,12 @@ class TestFusedQuantize:
)(dz, x) )(dz, x)
if is_casted_output: if is_casted_output:
assert_bitwise_scaled_tensors(jax_output, te_output) assert_bitwise_scaled_tensors(te_output, jax_output)
else: else:
assert_allclose(jax_output, te_output) assert_allclose(te_output, jax_output)
if is_dbias: if is_dbias:
assert_allclose(jax_dbias, te_dbias) assert_allclose(te_dbias, jax_dbias)
@pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES) @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
@pytest_parametrize_wrapper("input_shape", ALL_ACTIVATION_SHAPES) @pytest_parametrize_wrapper("input_shape", ALL_ACTIVATION_SHAPES)
...@@ -597,7 +621,7 @@ class TestFusedQuantize: ...@@ -597,7 +621,7 @@ class TestFusedQuantize:
scaling_mode=ScalingMode.NVTE_NO_SCALING, scaling_mode=ScalingMode.NVTE_NO_SCALING,
activation_type=activation_type, activation_type=activation_type,
is_dbias=is_dbias, is_dbias=is_dbias,
q_axis=QuantizeAxis.ROWWISE, q_layout=QuantizeLayout.ROWWISE,
) )
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
...@@ -605,9 +629,11 @@ class TestFusedQuantize: ...@@ -605,9 +629,11 @@ class TestFusedQuantize:
@pytest_parametrize_wrapper("input_shape", ALL_ACTIVATION_SHAPES) @pytest_parametrize_wrapper("input_shape", ALL_ACTIVATION_SHAPES)
@pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES) @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES)
@pytest_parametrize_wrapper("is_dbias", [True, False]) @pytest_parametrize_wrapper("is_dbias", [True, False])
@pytest_parametrize_wrapper("q_axis", [QuantizeAxis.COLWISE, QuantizeAxis.ROWWISE_COLWISE]) @pytest_parametrize_wrapper(
"q_layout", [QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE]
)
def test_quantize_dact_dbias_delayed_scaling( def test_quantize_dact_dbias_delayed_scaling(
self, in_dtype, input_shape, out_dtype, activation_type, is_dbias, q_axis self, in_dtype, input_shape, out_dtype, activation_type, is_dbias, q_layout
): ):
self._test_quantize_dact_dbias( self._test_quantize_dact_dbias(
in_dtype=in_dtype, in_dtype=in_dtype,
...@@ -616,7 +642,7 @@ class TestFusedQuantize: ...@@ -616,7 +642,7 @@ class TestFusedQuantize:
scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING,
activation_type=activation_type, activation_type=activation_type,
is_dbias=is_dbias, is_dbias=is_dbias,
q_axis=q_axis, q_layout=q_layout,
) )
@pytest.mark.skipif(not is_mxfp8_supported, reason=reason) @pytest.mark.skipif(not is_mxfp8_supported, reason=reason)
...@@ -626,9 +652,11 @@ class TestFusedQuantize: ...@@ -626,9 +652,11 @@ class TestFusedQuantize:
) )
@pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES) @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES)
@pytest_parametrize_wrapper("is_dbias", [True, False]) @pytest_parametrize_wrapper("is_dbias", [True, False])
@pytest_parametrize_wrapper("q_axis", [QuantizeAxis.COLWISE, QuantizeAxis.ROWWISE_COLWISE]) @pytest_parametrize_wrapper(
"q_layout", [QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE]
)
def test_quantize_dact_dbias_mxfp8_scaling( def test_quantize_dact_dbias_mxfp8_scaling(
self, in_dtype, input_shape, out_dtype, activation_type, is_dbias, q_axis self, in_dtype, input_shape, out_dtype, activation_type, is_dbias, q_layout
): ):
if reduce(operator.mul, input_shape[:-1]) % 128 != 0 or input_shape[-1] % 128 != 0: if reduce(operator.mul, input_shape[:-1]) % 128 != 0 or input_shape[-1] % 128 != 0:
# TODO(Jeremy): Remove this if pulling in newer TE branch supports non-full-tile shapes. # TODO(Jeremy): Remove this if pulling in newer TE branch supports non-full-tile shapes.
...@@ -645,75 +673,75 @@ class TestFusedQuantize: ...@@ -645,75 +673,75 @@ class TestFusedQuantize:
scaling_mode=ScalingMode.NVTE_MXFP8_1D_SCALING, scaling_mode=ScalingMode.NVTE_MXFP8_1D_SCALING,
activation_type=activation_type, activation_type=activation_type,
is_dbias=is_dbias, is_dbias=is_dbias,
q_axis=q_axis, q_layout=q_layout,
) )
class TestDense: class TestDense:
def _ref_gemm_with_jnp_dot(self, a, b, layout): def _ref_gemm_with_jnp_dot(self, a, b, data_layout):
if layout[0] == "T": if data_layout[0] == "T":
a = jnp.swapaxes(a, -1, -2) a = jnp.swapaxes(a, -1, -2)
if layout[1] == "T": if data_layout[1] == "T":
b = jnp.swapaxes(b, -1, -2) b = jnp.swapaxes(b, -1, -2)
return jnp.dot(a, b) return jnp.dot(a, b)
def _generate_gemm_input(self, m, n, k, layout): def _generate_gemm_input(self, m, n, k, data_layout):
key = jax.random.PRNGKey(0) key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 2) subkeys = jax.random.split(key, 2)
x = jax.random.uniform( x = jax.random.uniform(
subkeys[0], subkeys[0],
(m if layout[0] == "N" else k, k if layout[0] == "N" else m), (m if data_layout[0] == "N" else k, k if data_layout[0] == "N" else m),
dtype=jnp.bfloat16, dtype=jnp.bfloat16,
) / jnp.sqrt(k) ) / jnp.sqrt(k)
w = jax.random.uniform( w = jax.random.uniform(
subkeys[1], subkeys[1],
(k if layout[1] == "N" else n, n if layout[1] == "N" else k), (k if data_layout[1] == "N" else n, n if data_layout[1] == "N" else k),
dtype=jnp.bfloat16, dtype=jnp.bfloat16,
) / jnp.sqrt(n) ) / jnp.sqrt(n)
lhs_contracting_dim = (1,) if layout[0] == "N" else (0,) lhs_contracting_dim = (1,) if data_layout[0] == "N" else (0,)
rhs_contracting_dim = (0,) if layout[1] == "N" else (1,) rhs_contracting_dim = (0,) if data_layout[1] == "N" else (1,)
contracting_dims = (lhs_contracting_dim, rhs_contracting_dim) contracting_dims = (lhs_contracting_dim, rhs_contracting_dim)
return (x, w, contracting_dims) return (x, w, contracting_dims)
@pytest_parametrize_wrapper("m,n,k", [(512, 128, 256)]) @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
@pytest_parametrize_wrapper("layout", ["TN", "NT", "NN", "TT"]) @pytest_parametrize_wrapper("data_layout", ["TN", "NT", "NN", "TT"])
def test_gemm_bf16(self, m, n, k, layout): def test_gemm_bf16(self, m, n, k, data_layout):
x, w, contracting_dims = self._generate_gemm_input(m, n, k, layout) x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)
primitive_out = tex.gemm(x, w, contracting_dims) primitive_out = tex.gemm(x, w, contracting_dims)
ref_out = self._ref_gemm_with_jnp_dot(x, w, layout) ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout)
assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16) assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest_parametrize_wrapper("m,n,k", [(512, 128, 256)]) @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
@pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2]) @pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@pytest_parametrize_wrapper("layout", ["TN", "NT", "NN", "TT"]) @pytest_parametrize_wrapper("data_layout", ["TN", "NT", "NN", "TT"])
def test_gemm_fp8(self, m, n, k, q_dtype, scaling_mode, layout): def test_gemm_fp8(self, m, n, k, q_dtype, scaling_mode, data_layout):
x, w, contracting_dims = self._generate_gemm_input(m, n, k, layout) x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)
quantizer_set = QuantizerFactory.create_set( quantizer_set = QuantizerFactory.create_set(
scaling_mode=scaling_mode, fwd_dtype=q_dtype, bwd_dtype=q_dtype, is_2x2x=False scaling_mode=scaling_mode, fwd_dtype=q_dtype, bwd_dtype=q_dtype, is_2x2x=False
) )
primitive_out = tex.gemm( primitive_out = tex.gemm(
x, w, contracting_dims=contracting_dims, quantizer_set=quantizer_set x, w, contracting_dims=contracting_dims, quantizer_set=quantizer_set
) )
ref_out = self._ref_gemm_with_jnp_dot(x, w, layout) ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout)
assert_allclose(primitive_out, ref_out, dtype=q_dtype) assert_allclose(primitive_out, ref_out, dtype=q_dtype)
@pytest_parametrize_wrapper("m,n,k", [(512, 128, 256)]) @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
def test_dense_grad_bf16(self, m, n, k): def test_dense_grad_bf16(self, m, n, k):
layout = "NN" data_layout = "NN"
x, w, contracting_dims = self._generate_gemm_input(m, n, k, layout) x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)
def primitive_func(x, w, contracting_dims): def primitive_func(x, w, contracting_dims):
primitive_out = dense(x, w, contracting_dims=contracting_dims) primitive_out = dense(x, w, contracting_dims=contracting_dims)
return jnp.mean(primitive_out) return jnp.mean(primitive_out)
def ref_func(x, w, layout): def ref_func(x, w, data_layout):
return jnp.mean(self._ref_gemm_with_jnp_dot(x, w, layout)) return jnp.mean(self._ref_gemm_with_jnp_dot(x, w, data_layout))
value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1)) value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1))
...@@ -722,19 +750,19 @@ class TestDense: ...@@ -722,19 +750,19 @@ class TestDense:
primitive_out, (primitive_x_grad, primitive_w_grad) = value_n_grad_primitive_func( primitive_out, (primitive_x_grad, primitive_w_grad) = value_n_grad_primitive_func(
x, w, contracting_dims x, w, contracting_dims
) )
ref_out, (ref_x_grad, ref_w_grad) = value_n_grad_ref_func(x, w, layout) ref_out, (ref_x_grad, ref_w_grad) = value_n_grad_ref_func(x, w, data_layout)
assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16) assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)
assert_allclose(primitive_x_grad, ref_x_grad, dtype=jnp.bfloat16) assert_allclose(primitive_x_grad, ref_x_grad, dtype=jnp.bfloat16)
assert_allclose(primitive_w_grad, ref_w_grad, dtype=jnp.bfloat16) assert_allclose(primitive_w_grad, ref_w_grad, dtype=jnp.bfloat16)
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest_parametrize_wrapper("m,n,k", [(512, 128, 256)]) @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
@pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2]) @pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
def test_dense_grad_fp8(self, m, n, k, q_dtype, scaling_mode): def test_dense_grad_fp8(self, m, n, k, q_dtype, scaling_mode):
layout = "NN" data_layout = "NN"
x, w, contracting_dims = self._generate_gemm_input(m, n, k, layout) x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)
key = jax.random.PRNGKey(1) key = jax.random.PRNGKey(1)
bias = jax.random.uniform(key, n, dtype=jnp.bfloat16) bias = jax.random.uniform(key, n, dtype=jnp.bfloat16)
...@@ -745,9 +773,9 @@ class TestDense: ...@@ -745,9 +773,9 @@ class TestDense:
) )
return jnp.mean(primitive_out) return jnp.mean(primitive_out)
def ref_func(x, w, bias, layout): def ref_func(x, w, bias, data_layout):
return jnp.mean( return jnp.mean(
self._ref_gemm_with_jnp_dot(x, w, layout) + jnp.expand_dims(bias, axis=0) self._ref_gemm_with_jnp_dot(x, w, data_layout) + jnp.expand_dims(bias, axis=0)
) )
value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2)) value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2))
...@@ -763,7 +791,9 @@ class TestDense: ...@@ -763,7 +791,9 @@ class TestDense:
value_n_grad_primitive_func(x, w, bias, contracting_dims, quantizer_set) value_n_grad_primitive_func(x, w, bias, contracting_dims, quantizer_set)
) )
ref_out, (ref_x_grad, ref_w_grad, ref_bias_grad) = value_n_grad_ref_func(x, w, bias, layout) ref_out, (ref_x_grad, ref_w_grad, ref_bias_grad) = value_n_grad_ref_func(
x, w, bias, data_layout
)
assert_allclose(primitive_out, ref_out, dtype=q_dtype) assert_allclose(primitive_out, ref_out, dtype=q_dtype)
assert_allclose(primitive_x_grad, ref_x_grad, dtype=q_dtype) assert_allclose(primitive_x_grad, ref_x_grad, dtype=q_dtype)
...@@ -791,7 +821,7 @@ def _ref_jax_norm_impl(x, gamma, beta, norm_type, zero_centered_gamma, eps, quan ...@@ -791,7 +821,7 @@ def _ref_jax_norm_impl(x, gamma, beta, norm_type, zero_centered_gamma, eps, quan
class TestFusedDense: class TestFusedDense:
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize("m,n,k", [(512, 128, 128)]) @pytest.mark.parametrize("m,n,k", [(64, 32, 64)])
@pytest.mark.parametrize("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2]) @pytest.mark.parametrize("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@pytest.mark.parametrize("scaling_mode", supported_scaling_modes) @pytest.mark.parametrize("scaling_mode", supported_scaling_modes)
@pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"]) @pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"])
...@@ -873,7 +903,7 @@ class TestFusedDense: ...@@ -873,7 +903,7 @@ class TestFusedDense:
assert_allclose(prim_beta_grad, ref_beta_grad, dtype=q_dtype) assert_allclose(prim_beta_grad, ref_beta_grad, dtype=q_dtype)
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize("m,n,k", [(512, 128, 256)]) @pytest.mark.parametrize("m,n,k", [(64, 32, 64)])
@pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear")]) @pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear")])
@pytest.mark.parametrize("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2]) @pytest.mark.parametrize("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@pytest.mark.parametrize("scaling_mode", supported_scaling_modes) @pytest.mark.parametrize("scaling_mode", supported_scaling_modes)
...@@ -898,13 +928,13 @@ class TestFusedDense: ...@@ -898,13 +928,13 @@ class TestFusedDense:
x = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16) x = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
kernel_1 = jax.random.normal( kernel_1 = jax.random.normal(
subkeys[1], (k, len(activation_type) * n), jnp.bfloat16 subkeys[1], (k, len(activation_type), n), jnp.bfloat16
) / jnp.sqrt(k) ) / jnp.sqrt(k)
kernel_2 = jax.random.normal(subkeys[2], (n, k), jnp.bfloat16) / jnp.sqrt(n) kernel_2 = jax.random.normal(subkeys[2], (n, k), jnp.bfloat16) / jnp.sqrt(n)
gamma = jax.random.normal(subkeys[5], (k,), jnp.bfloat16) gamma = jax.random.normal(subkeys[5], (k,), jnp.bfloat16)
beta = None # was tested in TestNorm beta = None # was tested in TestNorm
if use_bias: if use_bias:
bias_1 = jax.random.normal(subkeys[3], (len(activation_type) * n), jnp.bfloat16) bias_1 = jax.random.normal(subkeys[3], (len(activation_type), n), jnp.bfloat16)
bias_2 = jax.random.normal(subkeys[4], (k,), jnp.bfloat16) bias_2 = jax.random.normal(subkeys[4], (k,), jnp.bfloat16)
else: else:
bias_1 = None bias_1 = None
...@@ -1039,19 +1069,19 @@ class TestGroupedDense: ...@@ -1039,19 +1069,19 @@ class TestGroupedDense:
subkeys = jax.random.split(key, len(shape_list) * 2) subkeys = jax.random.split(key, len(shape_list) * 2)
lhs_list, rhs_list, contracting_dims_list = [], [], [] lhs_list, rhs_list, contracting_dims_list = [], [], []
for i, ((m, n, k), layout) in enumerate(zip(shape_list, layout_list)): for i, ((m, n, k), data_layout) in enumerate(zip(shape_list, layout_list)):
lhs = jax.random.uniform( lhs = jax.random.uniform(
subkeys[2 * i], subkeys[2 * i],
(m if layout[0] == "N" else k, k if layout[0] == "N" else m), (m if data_layout[0] == "N" else k, k if data_layout[0] == "N" else m),
dtype=dtype, dtype=dtype,
) )
rhs = jax.random.uniform( rhs = jax.random.uniform(
subkeys[2 * i + 1], subkeys[2 * i + 1],
(k if layout[1] == "N" else n, n if layout[1] == "N" else k), (k if data_layout[1] == "N" else n, n if data_layout[1] == "N" else k),
dtype=dtype, dtype=dtype,
) )
lhs_contracting_dim = (1,) if layout[0] == "N" else (0,) lhs_contracting_dim = (1,) if data_layout[0] == "N" else (0,)
rhs_contracting_dim = (0,) if layout[1] == "N" else (1,) rhs_contracting_dim = (0,) if data_layout[1] == "N" else (1,)
contracting_dims = (lhs_contracting_dim, rhs_contracting_dim) contracting_dims = (lhs_contracting_dim, rhs_contracting_dim)
lhs_list.append(lhs) lhs_list.append(lhs)
......
...@@ -45,11 +45,17 @@ if is_mxfp8_supported: ...@@ -45,11 +45,17 @@ if is_mxfp8_supported:
SUPPORTED_RECIPES.append(pytest.param(recipe.MXFP8BlockScaling(), id="MXFP8BlockScaling")) SUPPORTED_RECIPES.append(pytest.param(recipe.MXFP8BlockScaling(), id="MXFP8BlockScaling"))
DTYPES = [jnp.bfloat16, jnp.float16] DTYPES = [jnp.bfloat16, jnp.float16]
INPUT_SHAPE = [[2, 64, 64]] # [batch, seqlen, hidden_in] INPUT_SHAPE = [[4, 64, 128]] # [batch, seqlen, hidden_in]
LAYERNORM_INPUT_AXES = (BATCH_AXES, SEQLEN_TP_AXES, HIDDEN_AXES) LAYERNORM_INPUT_AXES = (BATCH_AXES, SEQLEN_TP_AXES, HIDDEN_AXES)
DOT_1_INPUT_AXES = (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES) DOT_1_INPUT_AXES = (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)
DOT_2_INPUT_AXES = (BATCH_AXES, SEQLEN_AXES, HIDDEN_TP_AXES) DOT_2_INPUT_AXES = (BATCH_AXES, SEQLEN_AXES, HIDDEN_TP_AXES)
KERNEL_1_AXES = (W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES)
KERNEL_2_AXES = (W_TP_AXES, W_FSDP_AXES)
LN_SCALE_AXES = (W_NO_SHARD_AXES,)
LN_BIAS_AXES = (W_NO_SHARD_AXES,)
BIAS_1_AXES = (W_JOINED_AXES, W_TP_AXES)
BIAS_2_AXES = (W_NO_SHARD_AXES,)
INTERMEDIATE = 64 INTERMEDIATE = 64
...@@ -60,7 +66,6 @@ def generate_fsdp_and_tp_configs(): ...@@ -60,7 +66,6 @@ def generate_fsdp_and_tp_configs():
configs.append( configs.append(
[2, (1, 2), ("fsdp", "tp"), MeshResource(fsdp_resource="fsdp", tp_resource="tp")] [2, (1, 2), ("fsdp", "tp"), MeshResource(fsdp_resource="fsdp", tp_resource="tp")]
) )
if is_devices_enough(4): if is_devices_enough(4):
configs.append( configs.append(
[4, (2, 2), ("fsdp", "tp"), MeshResource(fsdp_resource="fsdp", tp_resource="tp")] [4, (2, 2), ("fsdp", "tp"), MeshResource(fsdp_resource="fsdp", tp_resource="tp")]
...@@ -80,13 +85,13 @@ class TestDistributedLayernormMLP: ...@@ -80,13 +85,13 @@ class TestDistributedLayernormMLP:
x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype) x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype)
gamma = jax.random.normal(subkeys[5], (hidden_in,), dtype=dtype) gamma = jax.random.normal(subkeys[5], (hidden_in,), dtype=dtype)
k1 = jax.random.normal( k1 = jax.random.normal(
subkeys[1], (hidden_in, len(activation_type) * INTERMEDIATE), dtype subkeys[1], (hidden_in, len(activation_type), INTERMEDIATE), dtype
) / jnp.sqrt(hidden_in) ) / jnp.sqrt(hidden_in)
k2 = jax.random.normal(subkeys[2], (INTERMEDIATE, hidden_out), dtype) / jnp.sqrt( k2 = jax.random.normal(subkeys[2], (INTERMEDIATE, hidden_out), dtype) / jnp.sqrt(
INTERMEDIATE INTERMEDIATE
) )
if use_bias: if use_bias:
b1 = jax.random.normal(subkeys[3], (len(activation_type) * INTERMEDIATE), dtype) b1 = jax.random.normal(subkeys[3], (len(activation_type), INTERMEDIATE), dtype)
b2 = jax.random.normal(subkeys[4], (hidden_out,), dtype) b2 = jax.random.normal(subkeys[4], (hidden_out,), dtype)
else: else:
b1 = None b1 = None
...@@ -111,10 +116,12 @@ class TestDistributedLayernormMLP: ...@@ -111,10 +116,12 @@ class TestDistributedLayernormMLP:
layernorm_input_axes = LAYERNORM_INPUT_AXES layernorm_input_axes = LAYERNORM_INPUT_AXES
dot_1_input_axes = DOT_1_INPUT_AXES dot_1_input_axes = DOT_1_INPUT_AXES
dot_2_input_axes = DOT_2_INPUT_AXES dot_2_input_axes = DOT_2_INPUT_AXES
kernel_1_axes = KERNEL_1_AXES
kernel_2_axes = KERNEL_2_AXES
else: else:
layernorm_input_axes = None layernorm_input_axes = None
dot_1_input_axes = None dot_1_input_axes = dot_2_input_axes = None
dot_2_input_axes = None kernel_1_axes = kernel_2_axes = None
quantizer_sets = QuantizerFactory.create_set(n_quantizer_sets=2) quantizer_sets = QuantizerFactory.create_set(n_quantizer_sets=2)
...@@ -130,6 +137,8 @@ class TestDistributedLayernormMLP: ...@@ -130,6 +137,8 @@ class TestDistributedLayernormMLP:
norm_input_axes=layernorm_input_axes, norm_input_axes=layernorm_input_axes,
dot_1_input_axes=dot_1_input_axes, dot_1_input_axes=dot_1_input_axes,
dot_2_input_axes=dot_2_input_axes, dot_2_input_axes=dot_2_input_axes,
kernel_1_axes=kernel_1_axes,
kernel_2_axes=kernel_2_axes,
activation_type=activation_type, activation_type=activation_type,
quantizer_sets=quantizer_sets, quantizer_sets=quantizer_sets,
) )
...@@ -142,7 +151,7 @@ class TestDistributedLayernormMLP: ...@@ -142,7 +151,7 @@ class TestDistributedLayernormMLP:
@pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("use_bias", [True, False]) @pytest_parametrize_wrapper("use_bias", [True, False])
@pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES) @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
def test_layernorm_fp8_mlp_primitive( def test_layernorm_mlp_grad(
self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe
): ):
device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config
...@@ -168,12 +177,12 @@ class TestDistributedLayernormMLP: ...@@ -168,12 +177,12 @@ class TestDistributedLayernormMLP:
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes) mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource): with mesh, fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource):
k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", "tp")) k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", None, "tp"))
k2_sharding = NamedSharding(mesh, PartitionSpec("tp", "fsdp")) k2_sharding = NamedSharding(mesh, PartitionSpec("tp", "fsdp"))
k1_ = jax.device_put(k1, k1_sharding) k1_ = jax.device_put(k1, k1_sharding)
k2_ = jax.device_put(k2, k2_sharding) k2_ = jax.device_put(k2, k2_sharding)
if use_bias: if use_bias:
b1_sharding = NamedSharding(mesh, PartitionSpec("tp")) b1_sharding = NamedSharding(mesh, PartitionSpec(None, "tp"))
b1_ = jax.device_put(b1, b1_sharding) b1_ = jax.device_put(b1, b1_sharding)
else: else:
b1_sharding = b1_ = None b1_sharding = b1_ = None
...@@ -267,16 +276,7 @@ class TestDistributedLayernormMLP: ...@@ -267,16 +276,7 @@ class TestDistributedLayernormMLP:
transpose_batch_sequence=False, # input: [batch, seqlen, hidden] transpose_batch_sequence=False, # input: [batch, seqlen, hidden]
intermediate_dim=INTERMEDIATE, intermediate_dim=INTERMEDIATE,
activations=activation_type, activations=activation_type,
scale_axes=(W_NO_SHARD_AXES,),
ln_bias_axes=(W_NO_SHARD_AXES,),
kernel_axes_1=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
kernel_axes_2=(W_TP_AXES, W_FSDP_AXES),
use_bias=use_bias, use_bias=use_bias,
bias_axes_1=(W_JOINED_AXES, W_TP_AXES),
bias_axes_2=(W_NO_SHARD_AXES,),
layernorm_input_axes=LAYERNORM_INPUT_AXES,
dot_1_input_axes=DOT_1_INPUT_AXES,
dot_2_input_axes=DOT_2_INPUT_AXES,
) )
params_single = ln_mlp_single.init(init_rngs, x, deterministic=True) params_single = ln_mlp_single.init(init_rngs, x, deterministic=True)
mlp_out_single, ln_out_single = ln_mlp_single.apply( mlp_out_single, ln_out_single = ln_mlp_single.apply(
...@@ -295,13 +295,13 @@ class TestDistributedLayernormMLP: ...@@ -295,13 +295,13 @@ class TestDistributedLayernormMLP:
transpose_batch_sequence=False, transpose_batch_sequence=False,
intermediate_dim=INTERMEDIATE, intermediate_dim=INTERMEDIATE,
activations=activation_type, activations=activation_type,
scale_axes=(W_NO_SHARD_AXES,), scale_axes=LN_SCALE_AXES,
ln_bias_axes=(W_NO_SHARD_AXES,), ln_bias_axes=LN_BIAS_AXES,
kernel_axes_1=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES), kernel_axes_1=KERNEL_1_AXES,
kernel_axes_2=(W_TP_AXES, W_FSDP_AXES), kernel_axes_2=KERNEL_2_AXES,
use_bias=use_bias, use_bias=use_bias,
bias_axes_1=(W_JOINED_AXES, W_TP_AXES), bias_axes_1=BIAS_1_AXES,
bias_axes_2=(W_NO_SHARD_AXES,), bias_axes_2=BIAS_2_AXES,
layernorm_input_axes=LAYERNORM_INPUT_AXES, layernorm_input_axes=LAYERNORM_INPUT_AXES,
dot_1_input_axes=DOT_1_INPUT_AXES, dot_1_input_axes=DOT_1_INPUT_AXES,
dot_2_input_axes=DOT_2_INPUT_AXES, dot_2_input_axes=DOT_2_INPUT_AXES,
...@@ -334,7 +334,7 @@ class TestDistributedLayernormMLP: ...@@ -334,7 +334,7 @@ class TestDistributedLayernormMLP:
@pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
@pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES) @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
def test_layernorm_fp8_mlp_layer( def test_layernorm_mlp_layer_fp8(
self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe
): ):
self._test_layernorm_mlp( self._test_layernorm_mlp(
......
...@@ -91,7 +91,6 @@ def _activation_bwd_rule(activation_type, ctx, g): ...@@ -91,7 +91,6 @@ def _activation_bwd_rule(activation_type, ctx, g):
(x, _) = ctx (x, _) = ctx
assert x.dtype == g.dtype assert x.dtype == g.dtype
dx = tex.dact_lu(g, x, activation_type) dx = tex.dact_lu(g, x, activation_type)
dx = jnp.reshape(dx, x.shape)
return (dx, None) return (dx, None)
......
...@@ -26,12 +26,12 @@ from .misc import ( ...@@ -26,12 +26,12 @@ from .misc import (
should_apply_1x_fused_dbias_war_for_arch_l_100, should_apply_1x_fused_dbias_war_for_arch_l_100,
NamedSharding, NamedSharding,
) )
from .quantization import _jax_quantize_dbias, _jax_dbias, quantize_dbias from .quantization import _jax_dbias, _quantize_dbias_impl
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
from ..quantize import ScaledTensor, ScaledTensorFactory from ..quantize import ScaledTensor, ScaledTensorFactory
from ..quantize import ( from ..quantize import (
Quantizer, Quantizer,
QuantizeAxis, QuantizeLayout,
DelayedScaleQuantizer, DelayedScaleQuantizer,
ScalingMode, ScalingMode,
) )
...@@ -110,41 +110,31 @@ class ActLuPrimitive(BasePrimitive): ...@@ -110,41 +110,31 @@ class ActLuPrimitive(BasePrimitive):
""" """
te_act_lu_p abstract te_act_lu_p abstract
""" """
del act_enum, act_len, scale_shapes del act_enum, scale_shapes
dtype = dtypes.canonicalize_dtype(x_aval.dtype) dtype = dtypes.canonicalize_dtype(x_aval.dtype)
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert scale_aval is None or scale_aval.dtype == jnp.float32 assert scale_aval is None or scale_aval.dtype == jnp.float32
assert x_aval.shape[-2] == act_len, (
out_shape = ( "activation input should be replicated by act_len in the -2 axis, got input shape"
*x_aval.shape[:-2], f" {x_aval.shape} and act_len {act_len}"
1,
x_aval.shape[-1],
) )
out_shape = (*x_aval.shape[:-2], x_aval.shape[-1]) # Exclude act dim
out_aval = x_aval.update(shape=out_shape, dtype=out_dtype) out_aval = x_aval.update(shape=out_shape, dtype=out_dtype)
updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32) updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
scaling_mode scaling_mode
).get_scale_shape_2x(out_shape[:-2] + (out_shape[-1],), is_padded=not is_outer) ).get_scale_shape_2x(out_shape, is_padded=not is_outer, flatten_axis=-1)
if not is_2x:
if len(rowwise_scale_inv_shape) > 1: out_shape = (1,)
rowwise_scale_inv_shape = ( colwise_scale_inv_shape = (1,)
rowwise_scale_inv_shape[:-1] + (1,) + rowwise_scale_inv_shape[-1:] colwise_out_aval = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype)
)
if len(colwise_scale_inv_shape) > 1:
colwise_scale_inv_shape = (
colwise_scale_inv_shape[:-1] + (1,) + colwise_scale_inv_shape[-1:]
)
scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype) scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
colwise_scale_inv_aval = jax.core.ShapedArray(
colwise_out_aval = jax.core.ShapedArray(shape=(1,), dtype=out_dtype) shape=colwise_scale_inv_shape, dtype=scale_dtype
colwise_scale_inv_aval = jax.core.ShapedArray(shape=(1,), dtype=scale_dtype) )
if is_2x:
colwise_out_aval = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype)
colwise_scale_inv_aval = jax.core.ShapedArray(
shape=colwise_scale_inv_shape, dtype=scale_dtype
)
return out_aval, colwise_out_aval, scale_inv_aval, colwise_scale_inv_aval, updated_amax_aval return out_aval, colwise_out_aval, scale_inv_aval, colwise_scale_inv_aval, updated_amax_aval
...@@ -211,15 +201,8 @@ class ActLuPrimitive(BasePrimitive): ...@@ -211,15 +201,8 @@ class ActLuPrimitive(BasePrimitive):
) )
rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
scaling_mode scaling_mode
).get_scale_shape_2x(out.shape[:-2] + (out.shape[-1],), is_padded=False) ).get_scale_shape_2x(out.shape, is_padded=False, flatten_axis=-1)
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: # Slice out padding for MXFP8, noop for DelayedScaling
rowwise_scale_inv_shape = (
rowwise_scale_inv_shape[:-1] + (1,) + rowwise_scale_inv_shape[-1:]
)
if is_2x:
colwise_scale_inv_shape = (
colwise_scale_inv_shape[:-1] + (1,) + colwise_scale_inv_shape[-1:]
)
scale_inv = jax.lax.slice( scale_inv = jax.lax.slice(
scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
) )
...@@ -227,6 +210,7 @@ class ActLuPrimitive(BasePrimitive): ...@@ -227,6 +210,7 @@ class ActLuPrimitive(BasePrimitive):
colwise_scale_inv = jax.lax.slice( colwise_scale_inv = jax.lax.slice(
colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
) )
return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax
@staticmethod @staticmethod
...@@ -292,11 +276,14 @@ class ActLuPrimitive(BasePrimitive): ...@@ -292,11 +276,14 @@ class ActLuPrimitive(BasePrimitive):
is_outer, is_outer,
) # Unused. ) # Unused.
x_spec = get_padded_spec(arg_infos[0]) x_spec = get_padded_spec(arg_infos[0])
out_spec = (*x_spec[:-2], None, x_spec[-2]) scale_spec = get_padded_spec(arg_infos[1])
out_spec = (*x_spec[:-2], x_spec[-1])
out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out") out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out")
if is_2x: if is_2x:
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
colwise_out_spec = multidim_transpose(out_spec) colwise_out_spec = multidim_transpose(out_spec, transpose_axis=-1)
else: else:
colwise_out_spec = out_spec colwise_out_spec = out_spec
else: else:
...@@ -304,18 +291,24 @@ class ActLuPrimitive(BasePrimitive): ...@@ -304,18 +291,24 @@ class ActLuPrimitive(BasePrimitive):
colwise_out_sharding = NamedSharding( colwise_out_sharding = NamedSharding(
mesh, PartitionSpec(*colwise_out_spec), desc="ActLuPrimitive.colwise_out" mesh, PartitionSpec(*colwise_out_spec), desc="ActLuPrimitive.colwise_out"
) )
scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
scale_inv_spec = amax_spec = scale_spec
elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
scale_inv_spec = out_spec
if is_2x:
colwise_scale_inv_spec = scale_inv_spec
scale_inv_sharding = NamedSharding( scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*get_padded_spec(arg_infos[1])), desc="ActLuPrimitive.scale_inv" mesh, PartitionSpec(*scale_inv_spec), desc="ActLuPrimitive.scale_inv"
) )
amax_sharding = scale_inv_sharding.duplicate_with_new_description("ActLuPrimitive.amax") amax_sharding = NamedSharding(mesh, PartitionSpec(*amax_spec), desc="ActLuPrimitive.amax")
colwise_scale_inv_sharding = NamedSharding(
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: mesh, PartitionSpec(*colwise_scale_inv_spec), desc="ActLuPrimitive.colwise_scale_inv"
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.scale_inv"
)
colwise_scale_inv_sharding = scale_inv_sharding.duplicate_with_new_description(
"ActLuPrimitive.colwise_scale_inv"
) )
return ( return (
out_sharding, out_sharding,
colwise_out_sharding, colwise_out_sharding,
...@@ -340,14 +333,14 @@ class ActLuPrimitive(BasePrimitive): ...@@ -340,14 +333,14 @@ class ActLuPrimitive(BasePrimitive):
): ):
del result_infos, is_outer # Unused. del result_infos, is_outer # Unused.
x_spec = get_padded_spec(arg_infos[0]) x_spec = get_padded_spec(arg_infos[0])
out_spec = (*x_spec[:-1], x_spec[-1]) scale_spec = get_padded_spec(arg_infos[1])
if act_len == 2 and x_spec[-1] is None:
# Ensure last axis is partitioned and not the gating axis out_spec = (*x_spec[:-2], x_spec[-1])
out_spec = (*x_spec[:-2], None, x_spec[-2])
out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out") out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out")
if is_2x: if is_2x:
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
colwise_out_spec = multidim_transpose(out_spec) colwise_out_spec = multidim_transpose(out_spec, transpose_axis=-1)
else: else:
colwise_out_spec = out_spec colwise_out_spec = out_spec
else: else:
...@@ -355,21 +348,25 @@ class ActLuPrimitive(BasePrimitive): ...@@ -355,21 +348,25 @@ class ActLuPrimitive(BasePrimitive):
colwise_out_sharding = NamedSharding( colwise_out_sharding = NamedSharding(
mesh, PartitionSpec(*colwise_out_spec), desc="ActLuPrimitive.colwise_out" mesh, PartitionSpec(*colwise_out_spec), desc="ActLuPrimitive.colwise_out"
) )
scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
scale_inv_spec = amax_spec = scale_spec
elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
scale_inv_spec = out_spec
if is_2x:
colwise_scale_inv_spec = scale_inv_spec
scale_inv_sharding = NamedSharding( scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*get_padded_spec(arg_infos[1])), desc="ActLuPrimitive.scale_inv" mesh, PartitionSpec(*scale_inv_spec), desc="ActLuPrimitive.scale_inv"
) )
amax_sharding = scale_inv_sharding.duplicate_with_new_description("ActLuPrimitive.amax") amax_sharding = NamedSharding(mesh, PartitionSpec(*amax_spec), desc="ActLuPrimitive.amax")
colwise_scale_inv_sharding = NamedSharding(
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: mesh, PartitionSpec(*colwise_scale_inv_spec), desc="ActLuPrimitive.colwise_scale_inv"
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.scale_inv"
)
colwise_scale_inv_sharding = scale_inv_sharding.duplicate_with_new_description(
"ActLuPrimitive.colwise_scale_inv"
) )
arg_shardings = list(arg_i.sharding for arg_i in arg_infos)
arg_shardings[0] = NamedSharding(mesh, PartitionSpec(*out_spec)) arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
arg_shardings = tuple(arg_shardings)
out_shardings = ( out_shardings = (
out_sharding, out_sharding,
colwise_out_sharding, colwise_out_sharding,
...@@ -413,6 +410,7 @@ class ActLuPrimitive(BasePrimitive): ...@@ -413,6 +410,7 @@ class ActLuPrimitive(BasePrimitive):
register_primitive(ActLuPrimitive) register_primitive(ActLuPrimitive)
# TODO(Jeremy): replace is_2x with q_layout
class DActLuDBiasQuantizePrimitive(BasePrimitive): class DActLuDBiasQuantizePrimitive(BasePrimitive):
""" """
DActLu DBias Cast Transpose Primitive DActLu DBias Cast Transpose Primitive
...@@ -445,42 +443,41 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -445,42 +443,41 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
te_dact_dbias_quantize_p abstract te_dact_dbias_quantize_p abstract
""" """
del act_enum, scale_shapes del act_enum, scale_shapes
dtype = dtypes.canonicalize_dtype(dz_aval.dtype) dz_dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert dz_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert x_aval.dtype == dtype assert x_aval.dtype == dz_dtype
assert x_aval.shape[-2] == act_len, (
"activation input should be replicated by act_len in the -2 axis, got input shape"
f" {x_aval.shape} and act_len {act_len}"
)
assert scale_aval.dtype == jnp.float32 assert scale_aval.dtype == jnp.float32
ir_hidden_size = dz_aval.shape[-1] ir_hidden_size = dz_aval.shape[-1]
gi_hidden_size = x_aval.shape[-1] gi_hidden_size = act_len * x_aval.shape[-1]
assert act_len * ir_hidden_size == gi_hidden_size assert act_len * ir_hidden_size == gi_hidden_size
out_shape = x_aval.shape out_shape = x_aval.shape
out_aval = x_aval.update(shape=out_shape, dtype=out_dtype) out_aval = x_aval.update(shape=out_shape, dtype=out_dtype)
updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32) updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
scaling_mode scaling_mode
).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer) ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=-2)
scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
colwise_out_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
colwise_scale_inv_aval = jax.core.ShapedArray(shape=(1,), dtype=scale_dtype)
dbias_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
wkspace_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
if is_2x: if is_2x:
# Don't transpose output for MXFP8 if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: colwise_out_shape = multidim_transpose(out_shape, transpose_axis=-2)
t_shape = out_shape
else: else:
t_shape = multidim_transpose(out_shape) colwise_out_shape = out_shape
colwise_out_aval = x_aval.update(shape=t_shape, dtype=out_dtype) else:
colwise_scale_inv_aval = jax.core.ShapedArray( colwise_out_shape = (1,)
shape=colwise_scale_inv_shape, dtype=scale_dtype colwise_scale_inv_shape = (1,)
) colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype)
scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
colwise_scale_inv_aval = jax.core.ShapedArray(
shape=colwise_scale_inv_shape, dtype=scale_dtype
)
if is_dbias: if is_dbias:
dbias_shape = gi_hidden_size dbias_shape = (act_len, ir_hidden_size)
dbias_aval = x_aval.update(shape=dbias_shape, dtype=dtype)
(wkspace_info,) = transformer_engine_jax.get_dact_dbias_quantize_workspace_sizes( (wkspace_info,) = transformer_engine_jax.get_dact_dbias_quantize_workspace_sizes(
x_aval.size // gi_hidden_size, x_aval.size // gi_hidden_size,
gi_hidden_size, gi_hidden_size,
...@@ -489,9 +486,14 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -489,9 +486,14 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
scaling_mode, scaling_mode,
is_2x, is_2x,
) )
wkspace_aval = x_aval.update( wkspace_shape = wkspace_info[0]
shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) wkspace_dtype = te_dtype_to_jax_dtype(wkspace_info[1])
) else:
dbias_shape = (1,)
wkspace_shape = (1,)
wkspace_dtype = jnp.float32
dbias_aval = jax.core.ShapedArray(shape=dbias_shape, dtype=dz_dtype)
wkspace_aval = jax.core.ShapedArray(shape=wkspace_shape, dtype=wkspace_dtype)
return ( return (
out_aval, out_aval,
...@@ -587,23 +589,16 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -587,23 +589,16 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
) )
rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
scaling_mode scaling_mode
).get_scale_shape_2x(x.shape, is_padded=False) ).get_scale_shape_2x(x.shape, is_padded=False, flatten_axis=-2)
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: # Slice out padding for MXFP8, noop for DelayedScaling
scale_inv = jax.lax.slice( scale_inv = jax.lax.slice(
scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
)
if is_2x:
colwise_scale_inv = jax.lax.slice(
colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
) )
if is_2x: return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias
colwise_scale_inv = jax.lax.slice(
colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
)
return (
out,
colwise_out,
scale_inv,
colwise_scale_inv,
updated_amax,
dbias,
) # Exclude wkspace
@staticmethod @staticmethod
def batcher( def batcher(
...@@ -670,15 +665,16 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -670,15 +665,16 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
result_infos, result_infos,
): ):
del out_dtype, result_infos, act_enum del out_dtype, result_infos, act_enum
del scale_dtype, scale_shapes, is_dbias, act_len, is_outer del scale_dtype, scale_shapes, act_len, is_outer
x_spec = get_padded_spec(arg_infos[1]) x_spec = get_padded_spec(arg_infos[1])
scale_spec = get_padded_spec(arg_infos[2])
out_sharding = NamedSharding( out_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec), desc="DActLuDBiasQuantizePrimitive.out" mesh, PartitionSpec(*x_spec), desc="DActLuDBiasQuantizePrimitive.out"
) )
if is_2x: if is_2x:
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
colwise_x_spec = multidim_transpose(x_spec) colwise_x_spec = multidim_transpose(x_spec, transpose_axis=-2)
else: else:
colwise_x_spec = x_spec colwise_x_spec = x_spec
else: else:
...@@ -687,23 +683,32 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -687,23 +683,32 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
mesh, PartitionSpec(*colwise_x_spec), desc="DActLuDBiasQuantizePrimitive.colwise_out" mesh, PartitionSpec(*colwise_x_spec), desc="DActLuDBiasQuantizePrimitive.colwise_out"
) )
dbias_shaprding = NamedSharding( dbias_spec = x_spec[-2:] if is_dbias else (None,)
dbias_sharding = NamedSharding(
mesh, mesh,
PartitionSpec(x_spec[-1]), PartitionSpec(*dbias_spec),
desc="DActLuDBiasQuantizePrimitive.dbias", desc="DActLuDBiasQuantizePrimitive.dbias",
) )
scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
scale_inv_spec = amax_spec = scale_spec
elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
scale_inv_spec = x_spec
if is_2x:
colwise_scale_inv_spec = scale_inv_spec
scale_inv_sharding = NamedSharding( scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(None), desc="DActLuDBiasQuantizePrimitive.scale_inv" mesh, PartitionSpec(*scale_inv_spec), desc="DActLuDBiasQuantizePrimitive.scale_inv"
) )
amax_sharding = NamedSharding( amax_sharding = NamedSharding(
mesh, PartitionSpec(None), desc="DActLuDBiasQuantizePrimitive.amax" mesh, PartitionSpec(*amax_spec), desc="DActLuDBiasQuantizePrimitive.amax"
) )
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: colwise_scale_inv_sharding = NamedSharding(
scale_inv_sharding = NamedSharding( mesh,
mesh, PartitionSpec(*x_spec), desc="DActLuDBiasQuantizePrimitive.scale_inv" PartitionSpec(*colwise_scale_inv_spec),
) desc="DActLuDBiasQuantizePrimitive.colwise_scale_inv",
colwise_scale_inv_sharding = scale_inv_sharding.duplicate_with_new_description(
"DActLuDBiasQuantizePrimitive.colwise_scale_inv"
) )
return ( return (
out_sharding, out_sharding,
...@@ -711,7 +716,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -711,7 +716,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
scale_inv_sharding, scale_inv_sharding,
colwise_scale_inv_sharding, colwise_scale_inv_sharding,
amax_sharding, amax_sharding,
dbias_shaprding, dbias_sharding,
) )
@staticmethod @staticmethod
...@@ -731,10 +736,15 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -731,10 +736,15 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
): ):
del result_infos, is_outer del result_infos, is_outer
x_spec = get_padded_spec(arg_infos[1]) x_spec = get_padded_spec(arg_infos[1])
out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec), desc="out") scale_spec = get_padded_spec(arg_infos[2])
out_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec), desc="DActLuDBiasQuantizePrimitive.out"
)
if is_2x: if is_2x:
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
colwise_x_spec = multidim_transpose(x_spec) colwise_x_spec = multidim_transpose(x_spec, transpose_axis=-2)
else: else:
colwise_x_spec = x_spec colwise_x_spec = x_spec
else: else:
...@@ -743,38 +753,39 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -743,38 +753,39 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
mesh, PartitionSpec(*colwise_x_spec), desc="DActLuDBiasQuantizePrimitive.colwise_out" mesh, PartitionSpec(*colwise_x_spec), desc="DActLuDBiasQuantizePrimitive.colwise_out"
) )
dbias_shaprding = NamedSharding( dbias_spec = x_spec[-2:] if is_dbias else (None,)
dbias_sharding = NamedSharding(
mesh, mesh,
PartitionSpec(x_spec[-1]), PartitionSpec(*dbias_spec),
desc="DActLuDBiasQuantizePrimitive.dbias", desc="DActLuDBiasQuantizePrimitive.dbias",
) )
scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
scale_inv_spec = amax_spec = scale_spec
elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
scale_inv_spec = x_spec
if is_2x:
colwise_scale_inv_spec = scale_inv_spec
scale_inv_sharding = NamedSharding( scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(None), desc="DActLuDBiasQuantizePrimitive.scale_inv" mesh, PartitionSpec(*scale_inv_spec), desc="ActLuPrimitive.scale_inv"
) )
amax_sharding = NamedSharding( amax_sharding = NamedSharding(mesh, PartitionSpec(*amax_spec), desc="ActLuPrimitive.amax")
mesh, PartitionSpec(None), desc="DActLuDBiasQuantizePrimitive.amax" colwise_scale_inv_sharding = NamedSharding(
) mesh, PartitionSpec(*colwise_scale_inv_spec), desc="ActLuPrimitive.colwise_scale_inv"
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec), desc="DActLuDBiasQuantizePrimitive.scale_inv"
)
colwise_scale_inv_sharding = scale_inv_sharding.duplicate_with_new_description(
"DActLuDBiasQuantizePrimitive.colwise_scale_inv"
) )
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
arg_shardings = (
arg_shardings[1],
arg_shardings[1],
*arg_shardings[2:],
) # dz and x are the same
out_shardings = ( out_shardings = (
out_sharding, out_sharding,
colwise_out_sharding, colwise_out_sharding,
scale_inv_sharding, scale_inv_sharding,
colwise_scale_inv_sharding, colwise_scale_inv_sharding,
amax_sharding, amax_sharding,
dbias_shaprding, dbias_sharding,
) )
def sharded_impl(dz, x, scale): def sharded_impl(dz, x, scale):
...@@ -816,14 +827,21 @@ def _jax_act_lu(inputs, activation_type, quantizer=None) -> Union[jnp.ndarray, S ...@@ -816,14 +827,21 @@ def _jax_act_lu(inputs, activation_type, quantizer=None) -> Union[jnp.ndarray, S
""" """
JAX native activation implementation JAX native activation implementation
""" """
x = jnp.split(inputs, len(activation_type), axis=-1) act_len = len(activation_type)
assert inputs.shape[-2] == act_len, (
"activation input should be replicated by act_len in the -2 axis, got input shape"
f" {inputs.shape} and act_len {act_len}"
)
x = jnp.split(inputs, act_len, axis=-2)
acts = [] acts = []
for idx, act_fn in enumerate(activation_type): for idx, act_fn in enumerate(activation_type):
x_i = _convert_to_activation_function(act_fn)(x[idx]) x_i = _convert_to_activation_function(act_fn)(x[idx])
acts.append(x_i) acts.append(x_i)
x = reduce(operator.mul, acts) x = reduce(operator.mul, acts)
x = jnp.squeeze(x, axis=-2)
if quantizer: if quantizer:
return quantizer.quantize(x) return quantizer.quantize(x, flatten_axis=-1)
return x return x
...@@ -837,6 +855,12 @@ def _jax_quantize_dact_dbias( ...@@ -837,6 +855,12 @@ def _jax_quantize_dact_dbias(
""" """
JAX implementation of dact_lu and dbias with optional quantization JAX implementation of dact_lu and dbias with optional quantization
""" """
act_len = len(activation_type)
assert x.shape[-2] == act_len, (
"activation input should be replicated by act_len in the -2 axis, got input shape"
f" {x.shape} and act_len {act_len}"
)
_, vjp_func = jax.vjp( _, vjp_func = jax.vjp(
partial(_jax_act_lu, activation_type=activation_type), x.astype(jnp.float32) partial(_jax_act_lu, activation_type=activation_type), x.astype(jnp.float32)
) )
...@@ -844,10 +868,10 @@ def _jax_quantize_dact_dbias( ...@@ -844,10 +868,10 @@ def _jax_quantize_dact_dbias(
dbias = None dbias = None
if is_dbias: if is_dbias:
dbias = _jax_dbias(dx).astype(x.dtype) dbias = _jax_dbias(dx, dtype=x.dtype, flatten_axis=-2)
if quantizer is not None: if quantizer is not None:
dx = quantizer.quantize(dx, dq_dtype=x.dtype) dx = quantizer.quantize(dx, dq_dtype=x.dtype, flatten_axis=-2)
else: else:
dx = dx.astype(x.dtype) dx = dx.astype(x.dtype)
...@@ -863,6 +887,7 @@ def act_lu( ...@@ -863,6 +887,7 @@ def act_lu(
Args: Args:
x: Input tensor to be processed. x: Input tensor to be processed.
Shape: (..., ACT_DIM, K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations
activation_type: Type of activation function to apply. activation_type: Type of activation function to apply.
quantizer: Optional quantizer for FP8 quantization of the output. quantizer: Optional quantizer for FP8 quantization of the output.
...@@ -873,12 +898,17 @@ def act_lu( ...@@ -873,12 +898,17 @@ def act_lu(
A ScaledTensor containing the quantized activated input. A ScaledTensor containing the quantized activated input.
""" """
act_type_id = ActivationEnum[activation_type].value act_type_id = ActivationEnum[activation_type].value
act_len = len(activation_type)
assert x.shape[-2] == act_len, (
"activation input should be replicated by act_len in the -2 axis, got input shape"
f" {x.shape} and act_len {act_len}"
)
if not ActLuPrimitive.enabled(): if not ActLuPrimitive.enabled():
return _jax_act_lu(x, activation_type, quantizer) return _jax_act_lu(x, activation_type, quantizer)
# TE/common does not support colwise-only quantization yet # TE/common does not support colwise-only quantization yet
if quantizer is not None and quantizer.q_axis == QuantizeAxis.COLWISE: if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE:
return _jax_act_lu(x, activation_type, quantizer) return _jax_act_lu(x, activation_type, quantizer)
# TE/common does not support 2x quantization for DelayedScaling yet # TE/common does not support 2x quantization for DelayedScaling yet
...@@ -889,16 +919,15 @@ def act_lu( ...@@ -889,16 +919,15 @@ def act_lu(
return war_output return war_output
scale = jnp.empty((1,), jnp.float32) scale = jnp.empty((1,), jnp.float32)
output_shape = (*x.shape[:-1], x.shape[-1] // len(activation_type)) output_shape = (*x.shape[:-2], x.shape[-1])
if quantizer is None: if quantizer is None:
x = x.reshape((-1, len(activation_type), x.shape[-1] // len(activation_type)))
out, _, _, _, _ = ActLuPrimitive.outer_primitive.bind( out, _, _, _, _ = ActLuPrimitive.outer_primitive.bind(
x, x,
scale, scale,
out_dtype=x.dtype, out_dtype=x.dtype,
act_enum=act_type_id, act_enum=act_type_id,
act_len=len(activation_type), act_len=act_len,
scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value, scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value,
is_2x=False, is_2x=False,
scale_dtype=jnp.float32, scale_dtype=jnp.float32,
...@@ -911,7 +940,6 @@ def act_lu( ...@@ -911,7 +940,6 @@ def act_lu(
if isinstance(quantizer, DelayedScaleQuantizer): if isinstance(quantizer, DelayedScaleQuantizer):
scale = quantizer.scale scale = quantizer.scale
x = x.reshape((*x.shape[:-1], len(activation_type), x.shape[-1] // len(activation_type)))
( (
rowwise_casted_output, rowwise_casted_output,
colwise_casted_output, colwise_casted_output,
...@@ -923,25 +951,15 @@ def act_lu( ...@@ -923,25 +951,15 @@ def act_lu(
scale, scale,
out_dtype=quantizer.q_dtype, out_dtype=quantizer.q_dtype,
act_enum=act_type_id, act_enum=act_type_id,
act_len=len(activation_type), act_len=act_len,
scaling_mode=quantizer.scaling_mode.value, scaling_mode=quantizer.scaling_mode.value,
is_2x=quantizer.is_2x2x(), is_2x=quantizer.is_2x2x(),
scale_dtype=quantizer.get_scale_dtype(), scale_dtype=quantizer.get_scale_dtype(),
scale_shapes=quantizer.get_scale_shapes(output_shape), # output does not have act axis
scale_shapes=quantizer.get_scale_shapes(output_shape, flatten_axis=-1),
is_outer=True, is_outer=True,
) )
rowwise_casted_output = rowwise_casted_output.reshape(output_shape)
if len(rowwise_scale_inv.shape) > 1:
rowwise_scale_inv = jnp.squeeze(rowwise_scale_inv, axis=-2) # Remove act axis
if quantizer.q_axis in (QuantizeAxis.COLWISE, QuantizeAxis.ROWWISE_COLWISE):
colwise_output_shape = output_shape
if quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
colwise_output_shape = multidim_transpose(output_shape)
colwise_casted_output = colwise_casted_output.reshape(colwise_output_shape)
if len(colwise_scale_inv.shape) > 1:
colwise_scale_inv = jnp.squeeze(colwise_scale_inv, axis=-2) # Remove act axis
quantizer.update(updated_amax) quantizer.update(updated_amax)
return ScaledTensorFactory.create( return ScaledTensorFactory.create(
...@@ -951,8 +969,8 @@ def act_lu( ...@@ -951,8 +969,8 @@ def act_lu(
colwise_scale_inv=colwise_scale_inv, colwise_scale_inv=colwise_scale_inv,
scaling_mode=quantizer.scaling_mode, scaling_mode=quantizer.scaling_mode,
dq_dtype=x.dtype, dq_dtype=x.dtype,
q_axis=quantizer.q_axis, q_layout=quantizer.q_layout,
layout=quantizer.get_layout(), data_layout=quantizer.get_data_layout(),
) )
...@@ -968,7 +986,7 @@ def quantize_dact_dbias( ...@@ -968,7 +986,7 @@ def quantize_dact_dbias(
Args: Args:
dz: Gradient of the output with respect to the activation output. dz: Gradient of the output with respect to the activation output.
x: Input tensor that was processed by the forward pass. x: Input tensor that was processed by the forward pass.
Shape: (..., ACT_DIM * K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations Shape: (..., ACT_DIM, K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations
activation_type: Type of activation function used in the forward pass. Defaults to ("gelu",). activation_type: Type of activation function used in the forward pass. Defaults to ("gelu",).
is_dbias: If True, compute bias gradient. Defaults to True. is_dbias: If True, compute bias gradient. Defaults to True.
quantizer: Optional quantizer for FP8 quantization of the output. quantizer: Optional quantizer for FP8 quantization of the output.
...@@ -979,21 +997,25 @@ def quantize_dact_dbias( ...@@ -979,21 +997,25 @@ def quantize_dact_dbias(
- The gradient of the activation with respect to the bias. - The gradient of the activation with respect to the bias.
""" """
act_len = len(activation_type)
assert x.shape[-2] == act_len, (
"activation input should be replicated by act_len in the -2 axis, got input shape"
f" {x.shape} and act_len {act_len}"
)
if not DActLuDBiasQuantizePrimitive.enabled(): if not DActLuDBiasQuantizePrimitive.enabled():
return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer) return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer)
# TE/common does not support colwise-only quantization yet # TE/common does not support colwise-only quantization yet
if quantizer is not None and quantizer.q_axis == QuantizeAxis.COLWISE: if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE:
return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer) return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer)
# TE/common does not support 1x dact_dbias_quantize on arch < 100 yet # TE/common does not support 1x dact_dbias_quantize on arch < 100 yet
if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer): if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer):
out, _ = quantize_dact_dbias( out = dact_lu(dz, x, activation_type, quantizer=None)
dz=dz, x=x, activation_type=activation_type, is_dbias=False, quantizer=None return _quantize_dbias_impl(out, quantizer, is_dbias=True, flatten_axis=-2)
)
return quantize_dbias(out, is_dbias=True, quantizer=quantizer)
is_gated = len(activation_type) == 2 is_gated = act_len == 2
# TE/common does not support DelayedScaling2x for gated-act yet # TE/common does not support DelayedScaling2x for gated-act yet
if is_gated: if is_gated:
war_output = try_apply_delayed_scaling_2x_war( war_output = try_apply_delayed_scaling_2x_war(
...@@ -1003,6 +1025,7 @@ def quantize_dact_dbias( ...@@ -1003,6 +1025,7 @@ def quantize_dact_dbias(
activation_type=activation_type, activation_type=activation_type,
is_dbias=is_dbias, is_dbias=is_dbias,
quantizer=quantizer, quantizer=quantizer,
flatten_axis=-2,
) )
if war_output is not None: if war_output is not None:
return war_output return war_output
...@@ -1025,12 +1048,12 @@ def quantize_dact_dbias( ...@@ -1025,12 +1048,12 @@ def quantize_dact_dbias(
scale_shapes=((), ()), # unused scale_shapes=((), ()), # unused
is_dbias=False, is_dbias=False,
act_enum=act_type_id, act_enum=act_type_id,
act_len=len(activation_type), act_len=act_len,
is_outer=True, is_outer=True,
) )
dbias = None dbias = None
if is_dbias: if is_dbias:
dbias = _jax_dbias(output).astype(x.dtype) dbias = _jax_dbias(output, dtype=x.dtype, flatten_axis=-2)
return output.astype(x.dtype), dbias return output.astype(x.dtype), dbias
if isinstance(quantizer, DelayedScaleQuantizer): if isinstance(quantizer, DelayedScaleQuantizer):
...@@ -1041,16 +1064,9 @@ def quantize_dact_dbias( ...@@ -1041,16 +1064,9 @@ def quantize_dact_dbias(
dgated = dact_lu( dgated = dact_lu(
dz.astype(jnp.float32), x.astype(jnp.float32), activation_type=activation_type dz.astype(jnp.float32), x.astype(jnp.float32), activation_type=activation_type
) )
# TODO(Jeremy): Debug - TE's quantize_dbias produced nans in this case for distributed layernorm_mlp tests out, dbias = _quantize_dbias_impl(
if quantizer.scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: dgated, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2
out, dbias = _jax_quantize_dbias(dgated, quantizer=quantizer, dq_dtype=x.dtype) )
else:
out, dbias = quantize_dbias(
dgated,
quantizer=quantizer,
is_dbias=True,
dq_dtype=x.dtype,
)
return out, dbias return out, dbias
out_shape = x.shape out_shape = x.shape
...@@ -1070,10 +1086,11 @@ def quantize_dact_dbias( ...@@ -1070,10 +1086,11 @@ def quantize_dact_dbias(
scaling_mode=quantizer.scaling_mode.value, scaling_mode=quantizer.scaling_mode.value,
is_2x=quantizer.is_2x2x(), is_2x=quantizer.is_2x2x(),
scale_dtype=quantizer.get_scale_dtype(), scale_dtype=quantizer.get_scale_dtype(),
scale_shapes=quantizer.get_scale_shapes(out_shape), # output has act axis
scale_shapes=quantizer.get_scale_shapes(out_shape, flatten_axis=-2),
is_dbias=is_dbias, is_dbias=is_dbias,
act_enum=act_type_id, act_enum=act_type_id,
act_len=len(activation_type), act_len=act_len,
is_outer=True, is_outer=True,
) )
...@@ -1090,8 +1107,9 @@ def quantize_dact_dbias( ...@@ -1090,8 +1107,9 @@ def quantize_dact_dbias(
colwise_scale_inv=colwise_scale_inv, colwise_scale_inv=colwise_scale_inv,
scaling_mode=quantizer.scaling_mode, scaling_mode=quantizer.scaling_mode,
dq_dtype=x.dtype, dq_dtype=x.dtype,
q_axis=quantizer.q_axis, q_layout=quantizer.q_layout,
layout=quantizer.get_layout(), data_layout=quantizer.get_data_layout(),
flatten_axis=-2, # as output has act axis
) )
return out, dbias return out, dbias
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
from typing import Tuple, Sequence, Union, Dict, List from typing import Tuple, Sequence, Union, Dict, List
from functools import partial, reduce from functools import partial, reduce
import operator import operator
from transformer_engine_jax import get_device_compute_capability
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
from transformer_engine_jax import get_device_compute_capability
from .base import BasePrimitive, register_primitive from .base import BasePrimitive, register_primitive
...@@ -183,10 +183,9 @@ def __jitted_jax_gemm_delayed_scaling_fp8(lhs, rhs, lhs_dn, rhs_dn, precision): ...@@ -183,10 +183,9 @@ def __jitted_jax_gemm_delayed_scaling_fp8(lhs, rhs, lhs_dn, rhs_dn, precision):
# Reshape + Transpose # Reshape + Transpose
# [..., M, K] -> [B, M, K] # [..., M, K] -> [B, M, K]
# [..., K, M] -> [B, M, K] # [..., K, M] -> [B, M, K]
lhs_3d = _shape_normalization(lhs_dq, lhs_dn, lhs.layout == "N") lhs_3d = _shape_normalization(lhs_dq, lhs_dn, lhs.data_layout == "N")
rhs_3d = _shape_normalization(rhs_dq, rhs_dn, rhs.layout == "T") rhs_3d = _shape_normalization(rhs_dq, rhs_dn, rhs.data_layout == "T")
# _shape_normalization ensures contracting_dims=2 and batch_dims=0
dim_nums = (((2,), (2,)), ((0,), (0,))) dim_nums = (((2,), (2,)), ((0,), (0,)))
out_3d = jax.lax.dot_general( out_3d = jax.lax.dot_general(
lhs_3d, rhs_3d, dim_nums, precision=precision, preferred_element_type=lhs.dq_dtype lhs_3d, rhs_3d, dim_nums, precision=precision, preferred_element_type=lhs.dq_dtype
...@@ -203,9 +202,9 @@ def _jax_gemm_delayed_scaling_fp8( ...@@ -203,9 +202,9 @@ def _jax_gemm_delayed_scaling_fp8(
), "rhs does not have delayed tensor scaling mode" ), "rhs does not have delayed tensor scaling mode"
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums
if lhs.layout == "T": if lhs.data_layout == "T":
lhs_contract = tuple((lhs.data.ndim - 1 - i) % lhs.data.ndim for i in lhs_contract) lhs_contract = tuple((lhs.data.ndim - 1 - i) % lhs.data.ndim for i in lhs_contract)
if rhs.layout == "T": if rhs.data_layout == "T":
rhs_contract = tuple((rhs.data.ndim - 1 - i) % rhs.data.ndim for i in rhs_contract) rhs_contract = tuple((rhs.data.ndim - 1 - i) % rhs.data.ndim for i in rhs_contract)
lhs_dn = (lhs_contract, lhs_batch) lhs_dn = (lhs_contract, lhs_batch)
...@@ -403,19 +402,19 @@ def grouped_gemm( ...@@ -403,19 +402,19 @@ def grouped_gemm(
lhs_shape = lhs.data.shape lhs_shape = lhs.data.shape
rhs_shape = rhs.data.shape rhs_shape = rhs.data.shape
out_dtype = lhs.dq_dtype out_dtype = lhs.dq_dtype
# For ScaledTensors and NVTE_DELAYED_TENSOR_SCALING, need to handle internal layout # For ScaledTensors and NVTE_DELAYED_TENSOR_SCALING, need to handle internal data_layout
if lhs.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: if lhs.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
assert not ( assert not (
lhs.data.dtype == jnp.float8_e5m2 and rhs.data.dtype == jnp.float8_e5m2 lhs.data.dtype == jnp.float8_e5m2 and rhs.data.dtype == jnp.float8_e5m2
), "FP8 GEMM does not support E5M2 * E5M2" ), "FP8 GEMM does not support E5M2 * E5M2"
((lhs_contract_dim,), (rhs_contract_dim,)) = contracting_dims ((lhs_contract_dim,), (rhs_contract_dim,)) = contracting_dims
if lhs.layout == "T": if lhs.data_layout == "T":
lhs_contract_dim = (lhs_contract_dim - 1) % lhs.data.ndim lhs_contract_dim = (lhs_contract_dim - 1) % lhs.data.ndim
if rhs.layout == "T": if rhs.data_layout == "T":
rhs_contract_dim = (rhs_contract_dim - 1) % rhs.data.ndim rhs_contract_dim = (rhs_contract_dim - 1) % rhs.data.ndim
dim_nums = ((lhs_contract_dim,), (rhs_contract_dim,)), ((), ()) dim_nums = ((lhs_contract_dim,), (rhs_contract_dim,)), ((), ())
else: else:
# For jnp.ndarray, only consider contracting_dims, layout is always NN # For jnp.ndarray, only consider contracting_dims, data_layout is always NN
scaling_mode = ScalingMode.NVTE_NO_SCALING scaling_mode = ScalingMode.NVTE_NO_SCALING
lhs_shape = lhs.shape lhs_shape = lhs.shape
rhs_shape = rhs.shape rhs_shape = rhs.shape
...@@ -432,8 +431,8 @@ def grouped_gemm( ...@@ -432,8 +431,8 @@ def grouped_gemm(
lhs_3d = _shape_normalization(lhs, lhs_dn) lhs_3d = _shape_normalization(lhs, lhs_dn)
rhs_3d = _shape_normalization(rhs, rhs_dn) rhs_3d = _shape_normalization(rhs, rhs_dn)
elif scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: elif scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
lhs_3d = _shape_normalization(lhs.data, lhs_dn, lhs.layout == "N") lhs_3d = _shape_normalization(lhs.data, lhs_dn, lhs.data_layout == "N")
rhs_3d = _shape_normalization(rhs.data, rhs_dn, rhs.layout == "T") rhs_3d = _shape_normalization(rhs.data, rhs_dn, rhs.data_layout == "T")
elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING:
lhs_3d = _shape_normalization(lhs.data, lhs_dn) lhs_3d = _shape_normalization(lhs.data, lhs_dn)
rhs_3d = _shape_normalization(rhs.data, rhs_dn) rhs_3d = _shape_normalization(rhs.data, rhs_dn)
......
...@@ -19,7 +19,7 @@ from jax.interpreters.mlir import dtype_to_ir_type ...@@ -19,7 +19,7 @@ from jax.interpreters.mlir import dtype_to_ir_type
import transformer_engine_jax import transformer_engine_jax
from ..sharding import get_padded_spec as te_get_padded_spec from ..sharding import get_padded_spec as te_get_padded_spec
from ..quantize import ScalingMode, ScaledTensorFactory, QuantizeAxis from ..quantize import ScalingMode, ScaledTensorFactory, QuantizeLayout
TEDType = transformer_engine_jax.DType TEDType = transformer_engine_jax.DType
...@@ -107,37 +107,37 @@ def normalize_axis_boundary(axis, ndim): ...@@ -107,37 +107,37 @@ def normalize_axis_boundary(axis, ndim):
return axis if axis >= 0 else ndim + axis return axis if axis >= 0 else ndim + axis
def multidim_transpose(shape, static_axis_boundary=-1, transpose_axis_boundary=-1): def multidim_transpose(shape, static_axis_boundary=-1, transpose_axis=-1):
""" """
te_cast_transpose_p multi-dims transpose te_cast_transpose_p multi-dims transpose
static_axis_boundary: int, indicates that axes <= static_axis_boundary are not static_axis_boundary: int, indicates that axes <= static_axis_boundary are not
involved in the transpose; -1 means all axes are involved in the transpose. involved in the transpose; -1 means all axes are involved in the transpose.
transpose_axis_boundary: int, Indicate how to split multi-dimensions tensors to 2D matrix for transpose_axis: int, Indicate how to split multi-dimensions tensors to 2D matrix for
transpose. Note, transpose_axis_boundary should be greater than static_axis_boundary transpose. Note, transpose_axis should be greater than static_axis_boundary
examples: examples:
X in shape (dim0, dim1, dim2, dim3, dim4) X in shape (dim0, dim1, dim2, dim3, dim4)
static_axis_boundary == -1, transpose_axis_boundary == 2 static_axis_boundary == -1, transpose_axis == 2
Xt = (dim2, dim3, dim4, dim0, dim1) Xt = (dim2, dim3, dim4, dim0, dim1)
static_axis_boundary == 0, transpose_axis_boundary == 2 static_axis_boundary == 0, transpose_axis == 2
Xt = (dim0, dim2, dim3, dim4, dim1) Xt = (dim0, dim2, dim3, dim4, dim1)
static_axis_boundary == 0, transpose_axis_boundary == 3 static_axis_boundary == 0, transpose_axis == 3
Xt = (dim0, dim3, dim4, dim1, dim2) Xt = (dim0, dim3, dim4, dim1, dim2)
""" """
if static_axis_boundary < 0: if static_axis_boundary < 0:
static_axis_boundary = -1 # means no static axes static_axis_boundary = -1 # means no static axes
assert static_axis_boundary < len(shape) - 2 # at least 2 remaining for transpose. assert static_axis_boundary < len(shape) - 2 # at least 2 remaining for transpose.
transpose_start_idx = static_axis_boundary + 1 transpose_start_idx = static_axis_boundary + 1
transpose_axis_boundary = normalize_axis_boundary(transpose_axis_boundary, len(shape)) transpose_axis = normalize_axis_boundary(transpose_axis, len(shape))
assert transpose_start_idx < transpose_axis_boundary assert transpose_start_idx < transpose_axis
return ( return (
*shape[:transpose_start_idx], *shape[:transpose_start_idx],
*shape[transpose_axis_boundary:], *shape[transpose_axis:],
*shape[transpose_start_idx:transpose_axis_boundary], *shape[transpose_start_idx:transpose_axis],
) )
...@@ -195,13 +195,13 @@ def should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias: bool = False, quant ...@@ -195,13 +195,13 @@ def should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias: bool = False, quant
break break
return ( return (
quantizer is not None quantizer is not None
and quantizer.q_axis == QuantizeAxis.ROWWISE and quantizer.q_layout == QuantizeLayout.ROWWISE
and arch_l_100 and arch_l_100
and is_dbias and is_dbias
) )
def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, **kwargs): def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, flatten_axis=-1, **kwargs):
""" """
Applies a workaround for delayed scaling 2x and can be used when the TE common kernels do not yet support 2x delayed scaling. Applies a workaround for delayed scaling 2x and can be used when the TE common kernels do not yet support 2x delayed scaling.
It will call the given function 'f' with the given arguments and the quantizer set to 1x, and calculate the colwise output by transposing the result. It will call the given function 'f' with the given arguments and the quantizer set to 1x, and calculate the colwise output by transposing the result.
...@@ -224,14 +224,19 @@ def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, **kwargs): ...@@ -224,14 +224,19 @@ def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, **kwargs):
# 2x is not supported by TE kernels for delayed scaling # 2x is not supported by TE kernels for delayed scaling
# so revert to 1x and transpose in JAX # so revert to 1x and transpose in JAX
quantizer.q_axis = QuantizeAxis.ROWWISE quantizer.q_layout = QuantizeLayout.ROWWISE
rowwise = f(*args, **kwargs, quantizer=quantizer) rowwise = f(*args, **kwargs, quantizer=quantizer)
other_outputs = None other_outputs = None
if isinstance(rowwise, tuple): if isinstance(rowwise, tuple):
other_outputs = rowwise[1:] other_outputs = rowwise[1:]
rowwise = rowwise[0] rowwise = rowwise[0]
quantizer.q_axis = QuantizeAxis.ROWWISE_COLWISE quantizer.q_layout = QuantizeLayout.ROWWISE_COLWISE
colwise_data = jnp.transpose(rowwise.data, (-1, *range(rowwise.data.ndim - 1))) if flatten_axis < 0:
flatten_axis += rowwise.data.ndim
assert 0 < flatten_axis < rowwise.data.ndim, "flatten_axis is out of bounds"
colwise_data = jnp.transpose(
rowwise.data, (*range(flatten_axis, rowwise.data.ndim), *range(flatten_axis))
)
output_2x = ScaledTensorFactory.create( output_2x = ScaledTensorFactory.create(
data=rowwise.data, data=rowwise.data,
scale_inv=rowwise.scale_inv, scale_inv=rowwise.scale_inv,
...@@ -239,8 +244,9 @@ def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, **kwargs): ...@@ -239,8 +244,9 @@ def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, **kwargs):
colwise_scale_inv=rowwise.scale_inv, colwise_scale_inv=rowwise.scale_inv,
scaling_mode=quantizer.scaling_mode, scaling_mode=quantizer.scaling_mode,
dq_dtype=rowwise.dq_dtype, dq_dtype=rowwise.dq_dtype,
q_axis=QuantizeAxis.ROWWISE_COLWISE, q_layout=QuantizeLayout.ROWWISE_COLWISE,
layout=quantizer.get_layout(), data_layout=quantizer.get_data_layout(),
flatten_axis=flatten_axis,
) )
if other_outputs is not None: if other_outputs is not None:
return (output_2x,) + other_outputs return (output_2x,) + other_outputs
......
...@@ -30,7 +30,7 @@ from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_a ...@@ -30,7 +30,7 @@ from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_a
from ..quantize import ScaledTensor, ScaledTensorFactory from ..quantize import ScaledTensor, ScaledTensorFactory
from ..quantize import ( from ..quantize import (
Quantizer, Quantizer,
QuantizeAxis, QuantizeLayout,
DelayedScaleQuantizer, DelayedScaleQuantizer,
ScalingMode, ScalingMode,
) )
...@@ -277,14 +277,14 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -277,14 +277,14 @@ class NormFwdPrimitive(BasePrimitive):
rowwise_scale_inv_shape, colwise_scale_inv_shape = scaling_mode.get_scale_shape_2x( rowwise_scale_inv_shape, colwise_scale_inv_shape = scaling_mode.get_scale_shape_2x(
x.shape, is_padded=False x.shape, is_padded=False
) )
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: # slice out padding for mxfp8, noop for DelayedScaling
scale_inv = scale_inv.flatten()[ scale_inv = scale_inv.flatten()[: reduce(operator.mul, rowwise_scale_inv_shape, 1)].reshape(
: reduce(operator.mul, rowwise_scale_inv_shape) rowwise_scale_inv_shape
].reshape(rowwise_scale_inv_shape) )
if is_2x: if is_2x:
colwise_scale_inv = colwise_scale_inv.flatten()[ colwise_scale_inv = colwise_scale_inv.flatten()[
: reduce(operator.mul, colwise_scale_inv_shape) : reduce(operator.mul, colwise_scale_inv_shape, 1)
].reshape(colwise_scale_inv_shape) ].reshape(colwise_scale_inv_shape)
return ( return (
out, out,
colwise_out, colwise_out,
...@@ -816,7 +816,7 @@ def layernorm_fwd( ...@@ -816,7 +816,7 @@ def layernorm_fwd(
return _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer) return _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer)
# TE/common does not support normalization with colwise only quantization yet # TE/common does not support normalization with colwise only quantization yet
if quantizer is not None and quantizer.q_axis == QuantizeAxis.COLWISE: if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE:
return _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer) return _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer)
scale = ( scale = (
...@@ -900,8 +900,8 @@ def layernorm_fwd( ...@@ -900,8 +900,8 @@ def layernorm_fwd(
colwise_scale_inv=colwise_scale_inv, colwise_scale_inv=colwise_scale_inv,
scaling_mode=quantizer.scaling_mode, scaling_mode=quantizer.scaling_mode,
dq_dtype=x.dtype, dq_dtype=x.dtype,
q_axis=quantizer.q_axis, q_layout=quantizer.q_layout,
layout=quantizer.get_layout(), data_layout=quantizer.get_data_layout(),
) )
return scaled_tensor, mu, rsigma return scaled_tensor, mu, rsigma
...@@ -997,7 +997,7 @@ def rmsnorm_fwd( ...@@ -997,7 +997,7 @@ def rmsnorm_fwd(
return _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer) return _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer)
# TE/common does not support normalization with colwise only quantization yet # TE/common does not support normalization with colwise only quantization yet
if quantizer is not None and quantizer.q_axis == QuantizeAxis.COLWISE: if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE:
return _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer) return _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer)
scale = ( scale = (
...@@ -1082,8 +1082,8 @@ def rmsnorm_fwd( ...@@ -1082,8 +1082,8 @@ def rmsnorm_fwd(
colwise_scale_inv=colwise_scale_inv, colwise_scale_inv=colwise_scale_inv,
scaling_mode=quantizer.scaling_mode, scaling_mode=quantizer.scaling_mode,
dq_dtype=x.dtype, dq_dtype=x.dtype,
q_axis=quantizer.q_axis, q_layout=quantizer.q_layout,
layout=quantizer.get_layout(), data_layout=quantizer.get_data_layout(),
) )
return scaled_tensor, rsigma return scaled_tensor, rsigma
......
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
# #
# See LICENSE for license information. # See LICENSE for license information.
"""JAX/TE custom ops for quantization""" """JAX/TE custom ops for quantization"""
import operator
from functools import reduce
from typing import Tuple, Optional from typing import Tuple, Optional
from packaging import version from packaging import version
...@@ -24,7 +26,7 @@ from .misc import ( ...@@ -24,7 +26,7 @@ from .misc import (
) )
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
from ..quantize import ScaledTensor2x, ScaledTensor, ScaledTensorFactory from ..quantize import ScaledTensor2x, ScaledTensor, ScaledTensorFactory
from ..quantize import Quantizer, QuantizeAxis, DelayedScaleQuantizer, ScalingMode from ..quantize import Quantizer, QuantizeLayout, DelayedScaleQuantizer, ScalingMode
if version.parse(jax.__version__) >= version.parse("0.5.0"): if version.parse(jax.__version__) >= version.parse("0.5.0"):
from jax import ffi # pylint: disable=ungrouped-imports from jax import ffi # pylint: disable=ungrouped-imports
...@@ -50,7 +52,8 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -50,7 +52,8 @@ class DBiasQuantizePrimitive(BasePrimitive):
6, 6,
7, 7,
8, 8,
) # out_dtype, scaling_mode, q_axis, scale_dtype, scale_shapes, is_dbias, is_outer 9,
) # out_dtype, scaling_mode, q_layout, flatten_axis, scale_dtype, scale_shapes, is_dbias, is_outer
inner_primitive = None inner_primitive = None
outer_primitive = None outer_primitive = None
...@@ -61,7 +64,8 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -61,7 +64,8 @@ class DBiasQuantizePrimitive(BasePrimitive):
*, *,
out_dtype, out_dtype,
scaling_mode, scaling_mode,
q_axis, q_layout,
flatten_axis,
scale_dtype, scale_dtype,
scale_shapes, scale_shapes,
is_dbias, is_dbias,
...@@ -73,49 +77,52 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -73,49 +77,52 @@ class DBiasQuantizePrimitive(BasePrimitive):
del scale_shapes del scale_shapes
dtype = dtypes.canonicalize_dtype(x_aval.dtype) dtype = dtypes.canonicalize_dtype(x_aval.dtype)
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
out_shape = x_aval.shape
assert scale_aval is None or scale_aval.dtype == jnp.float32 assert scale_aval is None or scale_aval.dtype == jnp.float32
rowwise_out_aval = jax.core.ShapedArray(shape=(1,), dtype=out_dtype) if q_layout in (QuantizeLayout.ROWWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
rowwise_out_shape = out_shape
if q_axis in (QuantizeAxis.ROWWISE.value, QuantizeAxis.ROWWISE_COLWISE.value): else:
rowwise_out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype) rowwise_out_shape = (1,)
rowwise_out_aval = jax.core.ShapedArray(shape=rowwise_out_shape, dtype=out_dtype)
updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32) updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
scaling_mode scaling_mode
).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer) ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=flatten_axis)
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
colwise_out_shape = multidim_transpose(out_shape, transpose_axis=flatten_axis)
else:
colwise_out_shape = out_shape
else:
colwise_out_shape = (1,)
colwise_scale_inv_shape = (1,)
colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype)
scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype) scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
colwise_scale_inv_aval = jax.core.ShapedArray(
colwise_out_aval = jax.core.ShapedArray(shape=(1,), dtype=out_dtype) shape=colwise_scale_inv_shape, dtype=scale_dtype
colwise_scale_inv_aval = jax.core.ShapedArray(shape=(1,), dtype=scale_dtype) )
dbias_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
wkspace_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
if q_axis in (QuantizeAxis.COLWISE.value, QuantizeAxis.ROWWISE_COLWISE.value):
t_shape = multidim_transpose(x_aval.shape)
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
# Don't transpose output for MXFP8
t_shape = x_aval.shape
colwise_out_aval = x_aval.update(shape=t_shape, dtype=out_dtype)
colwise_scale_inv_aval = jax.core.ShapedArray(
shape=colwise_scale_inv_shape, dtype=scale_dtype
)
if is_dbias: if is_dbias:
gi_hidden_size = x_aval.shape[-1] dbias_shape = x_aval.shape[flatten_axis:]
dbias_shape = (gi_hidden_size,) gi_hidden_size = reduce(operator.mul, x_aval.shape[flatten_axis:], 1)
dbias_aval = x_aval.update(shape=dbias_shape, dtype=dtype)
(wkspace_info,) = transformer_engine_jax.get_dbias_quantize_workspace_sizes( (wkspace_info,) = transformer_engine_jax.get_dbias_quantize_workspace_sizes(
x_aval.size // gi_hidden_size, x_aval.size // gi_hidden_size,
gi_hidden_size, gi_hidden_size,
jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(out_dtype), jax_dtype_to_te_dtype(out_dtype),
) )
wkspace_aval = x_aval.update( wkspace_shape = wkspace_info[0]
shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) wkspace_dtype = te_dtype_to_jax_dtype(wkspace_info[1])
) else:
dbias_shape = (1,)
wkspace_shape = (1,)
wkspace_dtype = jnp.float32
dbias_aval = jax.core.ShapedArray(shape=dbias_shape, dtype=dtype)
wkspace_aval = jax.core.ShapedArray(shape=wkspace_shape, dtype=wkspace_dtype)
return ( return (
rowwise_out_aval, rowwise_out_aval,
...@@ -151,7 +158,8 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -151,7 +158,8 @@ class DBiasQuantizePrimitive(BasePrimitive):
*, *,
out_dtype, out_dtype,
scaling_mode, scaling_mode,
q_axis, q_layout,
flatten_axis,
scale_dtype, scale_dtype,
scale_shapes, scale_shapes,
is_dbias, is_dbias,
...@@ -169,7 +177,8 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -169,7 +177,8 @@ class DBiasQuantizePrimitive(BasePrimitive):
x, x,
scale, scale,
scaling_mode=scaling_mode, scaling_mode=scaling_mode,
q_axis=q_axis, q_layout=q_layout,
flatten_axis=flatten_axis,
is_dbias=is_dbias, is_dbias=is_dbias,
) )
...@@ -179,7 +188,8 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -179,7 +188,8 @@ class DBiasQuantizePrimitive(BasePrimitive):
scale, scale,
out_dtype, out_dtype,
scaling_mode, scaling_mode,
q_axis, q_layout,
flatten_axis,
scale_dtype, scale_dtype,
scale_shapes, scale_shapes,
is_dbias, is_dbias,
...@@ -203,7 +213,8 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -203,7 +213,8 @@ class DBiasQuantizePrimitive(BasePrimitive):
scale, scale,
out_dtype=out_dtype, out_dtype=out_dtype,
scaling_mode=scaling_mode, scaling_mode=scaling_mode,
q_axis=q_axis, q_layout=q_layout,
flatten_axis=flatten_axis,
scale_dtype=scale_dtype, scale_dtype=scale_dtype,
scale_shapes=scale_shapes, scale_shapes=scale_shapes,
is_dbias=is_dbias, is_dbias=is_dbias,
...@@ -211,16 +222,14 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -211,16 +222,14 @@ class DBiasQuantizePrimitive(BasePrimitive):
) )
rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
scaling_mode scaling_mode
).get_scale_shape_2x(x.shape, is_padded=False) ).get_scale_shape_2x(x.shape, is_padded=False, flatten_axis=flatten_axis)
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: scale_inv = jax.lax.slice(
if q_axis in (QuantizeAxis.ROWWISE.value, QuantizeAxis.ROWWISE_COLWISE.value): scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
scale_inv = jax.lax.slice( )
scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
) colwise_scale_inv = jax.lax.slice(
if q_axis in (QuantizeAxis.COLWISE.value, QuantizeAxis.ROWWISE_COLWISE.value): colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
colwise_scale_inv = jax.lax.slice( )
colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
)
return ( return (
out, out,
colwise_out, colwise_out,
...@@ -237,7 +246,8 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -237,7 +246,8 @@ class DBiasQuantizePrimitive(BasePrimitive):
*, *,
out_dtype, out_dtype,
scaling_mode, scaling_mode,
q_axis, q_layout,
flatten_axis,
scale_dtype, scale_dtype,
scale_shapes, scale_shapes,
is_dbias, is_dbias,
...@@ -260,7 +270,8 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -260,7 +270,8 @@ class DBiasQuantizePrimitive(BasePrimitive):
scale, scale,
out_dtype=out_dtype, out_dtype=out_dtype,
scaling_mode=scaling_mode, scaling_mode=scaling_mode,
q_axis=q_axis, q_layout=q_layout,
flatten_axis=flatten_axis,
scale_dtype=scale_dtype, scale_dtype=scale_dtype,
scale_shapes=scale_shapes, scale_shapes=scale_shapes,
is_dbias=is_dbias, is_dbias=is_dbias,
...@@ -272,7 +283,8 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -272,7 +283,8 @@ class DBiasQuantizePrimitive(BasePrimitive):
def infer_sharding_from_operands( def infer_sharding_from_operands(
out_dtype, out_dtype,
scaling_mode, scaling_mode,
q_axis, q_layout,
flatten_axis,
scale_dtype, scale_dtype,
scale_shapes, scale_shapes,
is_dbias, is_dbias,
...@@ -281,16 +293,17 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -281,16 +293,17 @@ class DBiasQuantizePrimitive(BasePrimitive):
arg_infos, arg_infos,
result_infos, result_infos,
): ):
del (out_dtype, result_infos, scale_dtype, scale_shapes, is_dbias, is_outer) # Unused. del (out_dtype, result_infos, scale_dtype, scale_shapes, is_outer) # Unused.
x_spec = get_padded_spec(arg_infos[0]) x_spec = get_padded_spec(arg_infos[0])
scale_spec = get_padded_spec(arg_infos[1])
out_sharding = NamedSharding( out_sharding = NamedSharding(
mesh, mesh,
PartitionSpec(*x_spec[:-1], x_spec[-1]), PartitionSpec(*x_spec),
desc="DBiasQuantizePrimitive.out_sharding", desc="DBiasQuantizePrimitive.out_sharding",
) )
if q_axis in (QuantizeAxis.COLWISE.value, QuantizeAxis.ROWWISE_COLWISE.value): if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
colwise_out_spec = multidim_transpose(x_spec) colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis)
else: else:
colwise_out_spec = x_spec colwise_out_spec = x_spec
else: else:
...@@ -300,26 +313,35 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -300,26 +313,35 @@ class DBiasQuantizePrimitive(BasePrimitive):
PartitionSpec(*colwise_out_spec), PartitionSpec(*colwise_out_spec),
desc="DBiasQuantizePrimitive.colwise_out_sharding", desc="DBiasQuantizePrimitive.colwise_out_sharding",
) )
scale_inv_sharding = NamedSharding(
dbias_spec = x_spec[flatten_axis:] if is_dbias else (None,)
dbias_sharding = NamedSharding(
mesh, mesh,
PartitionSpec(*get_padded_spec(arg_infos[1])), PartitionSpec(*dbias_spec),
desc="DBiasQuantizePrimitive.scale_inv", desc="DBiasQuantizePrimitive.dbias_sharding",
) )
amax_sharding = scale_inv_sharding.duplicate_with_new_description(
desc="DBiasQuantizePrimitive.amax_sharding" scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
scale_inv_spec = amax_spec = scale_spec
elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
scale_inv_spec = x_spec
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
colwise_scale_inv_spec = scale_inv_spec
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*scale_inv_spec), desc="DBiasQuantizePrimitive.scale_inv"
) )
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: amax_sharding = NamedSharding(
scale_inv_sharding = NamedSharding( mesh, PartitionSpec(*amax_spec), desc="DBiasQuantizePrimitive.amax"
mesh, PartitionSpec(*x_spec), desc="DBiasQuantizePrimitive.scale_inv"
)
colwise_scale_inv_sharding = scale_inv_sharding.duplicate_with_new_description(
"DBiasQuantizePrimitive.colwise_scale_inv"
) )
dbias_sharding = NamedSharding( colwise_scale_inv_sharding = NamedSharding(
mesh, mesh,
PartitionSpec(x_spec[-1]), PartitionSpec(*colwise_scale_inv_spec),
desc="DBiasQuantizePrimitive.dbias_sharding", desc="DBiasQuantizePrimitive.colwise_scale_inv",
) )
return ( return (
out_sharding, out_sharding,
colwise_out_sharding, colwise_out_sharding,
...@@ -333,7 +355,8 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -333,7 +355,8 @@ class DBiasQuantizePrimitive(BasePrimitive):
def partition( def partition(
out_dtype, out_dtype,
scaling_mode, scaling_mode,
q_axis, q_layout,
flatten_axis,
scale_dtype, scale_dtype,
scale_shapes, scale_shapes,
is_dbias, is_dbias,
...@@ -344,14 +367,15 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -344,14 +367,15 @@ class DBiasQuantizePrimitive(BasePrimitive):
): ):
del result_infos, is_outer del result_infos, is_outer
x_spec = get_padded_spec(arg_infos[0]) x_spec = get_padded_spec(arg_infos[0])
scale_spec = get_padded_spec(arg_infos[1])
out_sharding = NamedSharding( out_sharding = NamedSharding(
mesh, mesh,
PartitionSpec(*x_spec[:-1], x_spec[-1]), PartitionSpec(*x_spec),
desc="DBiasQuantizePrimitive.out_sharding", desc="DBiasQuantizePrimitive.out_sharding",
) )
if q_axis in (QuantizeAxis.COLWISE.value, QuantizeAxis.ROWWISE_COLWISE.value): if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
colwise_out_spec = multidim_transpose(x_spec) colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis)
else: else:
colwise_out_spec = x_spec colwise_out_spec = x_spec
else: else:
...@@ -361,26 +385,35 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -361,26 +385,35 @@ class DBiasQuantizePrimitive(BasePrimitive):
PartitionSpec(*colwise_out_spec), PartitionSpec(*colwise_out_spec),
desc="DBiasQuantizePrimitive.colwise_out_sharding", desc="DBiasQuantizePrimitive.colwise_out_sharding",
) )
scale_inv_sharding = NamedSharding(
dbias_spec = x_spec[flatten_axis:] if is_dbias else (None,)
dbias_sharding = NamedSharding(
mesh, mesh,
PartitionSpec(*get_padded_spec(arg_infos[1])), PartitionSpec(*dbias_spec),
desc="DBiasQuantizePrimitive.scale_inv", desc="DBiasQuantizePrimitive.dbias_sharding",
) )
amax_sharding = scale_inv_sharding.duplicate_with_new_description(
desc="DBiasQuantizePrimitive.amax_sharding" scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
scale_inv_spec = amax_spec = scale_spec
elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
scale_inv_spec = x_spec
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
colwise_scale_inv_spec = scale_inv_spec
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*scale_inv_spec), desc="DBiasQuantizePrimitive.scale_inv"
) )
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: amax_sharding = NamedSharding(
scale_inv_sharding = NamedSharding( mesh, PartitionSpec(*amax_spec), desc="DBiasQuantizePrimitive.amax"
mesh, PartitionSpec(*x_spec), desc="DBiasQuantizePrimitive.scale_inv"
)
colwise_scale_inv_sharding = scale_inv_sharding.duplicate_with_new_description(
"DBiasQuantizePrimitive.colwise_scale_inv"
) )
dbias_sharding = NamedSharding( colwise_scale_inv_sharding = NamedSharding(
mesh, mesh,
PartitionSpec(x_spec[-1]), PartitionSpec(*colwise_scale_inv_spec),
desc="DBiasQuantizePrimitive.dbias_sharding", desc="DBiasQuantizePrimitive.colwise_scale_inv",
) )
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_shardings = ( out_shardings = (
out_sharding, out_sharding,
...@@ -404,7 +437,8 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -404,7 +437,8 @@ class DBiasQuantizePrimitive(BasePrimitive):
scale, scale,
out_dtype=out_dtype, out_dtype=out_dtype,
scaling_mode=scaling_mode, scaling_mode=scaling_mode,
q_axis=q_axis, q_layout=q_layout,
flatten_axis=flatten_axis,
scale_dtype=scale_dtype, scale_dtype=scale_dtype,
scale_shapes=scale_shapes, scale_shapes=scale_shapes,
is_dbias=is_dbias, is_dbias=is_dbias,
...@@ -436,49 +470,45 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -436,49 +470,45 @@ class DBiasQuantizePrimitive(BasePrimitive):
register_primitive(DBiasQuantizePrimitive) register_primitive(DBiasQuantizePrimitive)
def _jax_quantize(
    x,
    quantizer: Optional[Quantizer] = None,
    dq_dtype: Optional[jnp.dtype] = None,
    flatten_axis: int = -1,
):
    """Quantize ``x`` by delegating to ``quantizer.quantize``.

    Args:
        x: Input to quantize.
        quantizer: Object whose ``quantize`` method does the work. When None,
            no quantization happens.
        dq_dtype: Forwarded unchanged as the ``dq_dtype`` keyword.
        flatten_axis: Forwarded unchanged as the ``flatten_axis`` keyword.

    Returns:
        ``x`` itself when ``quantizer`` is None, otherwise the value returned
        by ``quantizer.quantize``.
    """
    # Pass-through: without a quantizer the input is handed back untouched.
    if quantizer is None:
        return x
    return quantizer.quantize(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis)
def _jax_dbias(dx: jnp.ndarray, dtype=None, flatten_axis: int = -1):
    """Reduce ``dx`` over its leading axes to form a bias gradient.

    The sum is accumulated in float32 and then cast to the requested dtype.

    Args:
        dx: Array to reduce.
        dtype: Dtype of the result. ``dx.dtype`` is used when this is None.
        flatten_axis: Index of the first axis that is kept. Every axis before
            it is summed away. Negative values count from the last axis;
            non-negative values are accepted as the equivalent position.

    Returns:
        The reduced array, with shape ``dx.shape[flatten_axis:]`` for an
        in-range ``flatten_axis``, cast to ``dtype``.
    """
    # Number of leading axes to sum. A negative flatten_axis is an offset from
    # the end, so -1 reduces everything except the last axis.
    if flatten_axis < 0:
        num_reduced = dx.ndim + flatten_axis
    else:
        num_reduced = flatten_axis

    # Test against None explicitly: a truthiness test (`dtype or dx.dtype`)
    # would throw away a caller-supplied dtype object that evaluates as falsy,
    # which unstructured NumPy dtype instances do.
    if dtype is None:
        dtype = dx.dtype

    dbias = jnp.sum(
        dx.astype(jnp.float32),
        axis=tuple(range(num_reduced)),
        keepdims=False,
    )
    return dbias.astype(dtype)
def _jax_quantize_dbias(
    x,
    quantizer: Optional[Quantizer] = None,
    dq_dtype: Optional[jnp.dtype] = None,
    flatten_axis: int = -1,
):
    """Quantize ``x`` and compute its bias gradient in one call.

    Args:
        x: Input to quantize and to reduce into a bias gradient.
        quantizer: Object whose ``quantize`` method does the quantization.
            When None, nothing is computed.
        dq_dtype: Forwarded to ``quantizer.quantize`` as ``dq_dtype`` and to
            ``_jax_dbias`` as the result dtype.
        flatten_axis: Forwarded to both ``quantizer.quantize`` and
            ``_jax_dbias``.

    Returns:
        ``(x, None)`` when ``quantizer`` is None, otherwise a pair of the
        ``quantizer.quantize`` result and the ``_jax_dbias`` result.
    """
    # No quantizer: hand the input back and report no bias gradient.
    if quantizer is None:
        return x, None

    quantized = quantizer.quantize(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis)
    dbias = _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis)
    return quantized, dbias
def _quantize_impl( def _quantize_dbias_impl(
x: jnp.ndarray, x: jnp.ndarray,
quantizer: Quantizer, quantizer: Quantizer,
is_dbias: bool = False, is_dbias: bool = False,
dq_dtype: Optional[jnp.dtype] = None, dq_dtype: Optional[jnp.dtype] = None,
flatten_axis: int = -1,
) -> Tuple[ScaledTensor2x, jnp.ndarray]: ) -> Tuple[ScaledTensor2x, jnp.ndarray]:
""" """
Cast wrapper Cast wrapper
...@@ -488,40 +518,51 @@ def _quantize_impl( ...@@ -488,40 +518,51 @@ def _quantize_impl(
quantizer is not None quantizer is not None
), "quantizer must be provided if dq_dtype is provided" ), "quantizer must be provided if dq_dtype is provided"
dq_dtype = dq_dtype or x.dtype
if not DBiasQuantizePrimitive.enabled(): if not DBiasQuantizePrimitive.enabled():
if is_dbias: if is_dbias:
return _jax_quantize_dbias( return _jax_quantize_dbias(
x, x,
quantizer=quantizer, quantizer=quantizer,
dq_dtype=dq_dtype, dq_dtype=dq_dtype,
flatten_axis=flatten_axis,
) )
return _jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype), None return (
_jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype, flatten_axis=flatten_axis),
None,
)
# TE/common doesn't support colwise only quantization yet # TE/common doesn't support colwise only quantization yet
if quantizer is not None and quantizer.q_axis == QuantizeAxis.COLWISE: if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE:
if is_dbias: if is_dbias:
return _jax_quantize_dbias( return _jax_quantize_dbias(
x, x,
quantizer=quantizer, quantizer=quantizer,
dq_dtype=dq_dtype, dq_dtype=dq_dtype,
flatten_axis=flatten_axis,
) )
return _jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype), None return (
_jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype, flatten_axis=flatten_axis),
None,
)
scale = jnp.empty((), jnp.float32) scale = jnp.empty((), jnp.float32)
# TE/common dbias_quantize does not support 1x on arch < 100 # TE/common dbias_quantize does not support 1x on arch < 100
if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer): if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer):
out, _ = _quantize_impl( out, _ = _quantize_dbias_impl(
x=x, x=x,
is_dbias=False, is_dbias=False,
quantizer=quantizer, quantizer=quantizer,
dq_dtype=dq_dtype, dq_dtype=dq_dtype,
flatten_axis=flatten_axis,
) )
dbias = _jax_dbias(x) dbias = _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis)
return out, dbias return out, dbias
if quantizer is None: if quantizer is None:
if is_dbias: if is_dbias:
return x, _jax_dbias(x) return x, _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis)
return x, None return x, None
if isinstance(quantizer, DelayedScaleQuantizer): if isinstance(quantizer, DelayedScaleQuantizer):
...@@ -539,9 +580,10 @@ def _quantize_impl( ...@@ -539,9 +580,10 @@ def _quantize_impl(
scale, scale,
out_dtype=quantizer.q_dtype, out_dtype=quantizer.q_dtype,
scaling_mode=quantizer.scaling_mode.value, scaling_mode=quantizer.scaling_mode.value,
q_axis=quantizer.q_axis.value, q_layout=quantizer.q_layout.value,
flatten_axis=flatten_axis,
scale_dtype=quantizer.get_scale_dtype(), scale_dtype=quantizer.get_scale_dtype(),
scale_shapes=quantizer.get_scale_shapes(x.shape), scale_shapes=quantizer.get_scale_shapes(x.shape, flatten_axis=flatten_axis),
is_dbias=is_dbias, is_dbias=is_dbias,
is_outer=True, is_outer=True,
) )
...@@ -557,18 +599,18 @@ def _quantize_impl( ...@@ -557,18 +599,18 @@ def _quantize_impl(
colwise_data=colwise_casted_output, colwise_data=colwise_casted_output,
colwise_scale_inv=colwise_scale_inv, colwise_scale_inv=colwise_scale_inv,
scaling_mode=quantizer.scaling_mode, scaling_mode=quantizer.scaling_mode,
dq_dtype=dq_dtype if dq_dtype is not None else x.dtype, dq_dtype=dq_dtype,
q_axis=quantizer.q_axis, q_layout=quantizer.q_layout,
layout=quantizer.get_layout(), data_layout=quantizer.get_data_layout(),
flatten_axis=flatten_axis,
) )
return out, dbias return out, dbias.astype(dq_dtype)
# TODO(Phuong): do not expose dq_dtype to users
def quantize( def quantize(
x: jnp.ndarray, x: jnp.ndarray,
quantizer: Quantizer, quantizer: Quantizer,
dq_dtype: Optional[jnp.dtype] = None, flatten_axis: int = -1,
) -> Tuple[ScaledTensor]: ) -> Tuple[ScaledTensor]:
"""Quantize input tensor according to the quantizer. """Quantize input tensor according to the quantizer.
...@@ -576,26 +618,25 @@ def quantize( ...@@ -576,26 +618,25 @@ def quantize(
x: Input tensor to be quantized. x: Input tensor to be quantized.
Shape: (..., K) where K is the hidden size. Shape: (..., K) where K is the hidden size.
quantizer: Quantizer for FP8 quantization of the output. quantizer: Quantizer for FP8 quantization of the output.
dq_dtype: Optional dtype for dequantization. flatten_axis: The quantization axis in which input data can be flattened to 2D for quantization.
If None, uses the same dtype as the input tensor. Defaults to -1.
Returns: Returns:
A ScaledTensor containing the quantized input tensor. A ScaledTensor containing the quantized input tensor.
""" """
out, _ = _quantize_impl( out, _ = _quantize_dbias_impl(
x, x,
quantizer=quantizer, quantizer=quantizer,
dq_dtype=dq_dtype, flatten_axis=flatten_axis,
) )
return out return out
# TODO(Phuong): do not expose dq_dtype to users
def quantize_dbias( def quantize_dbias(
dz: jnp.ndarray, dz: jnp.ndarray,
quantizer: Quantizer, quantizer: Quantizer,
is_dbias: bool = True, is_dbias: bool = True,
dq_dtype: Optional[jnp.dtype] = None, flatten_axis: int = -1,
) -> Tuple[ScaledTensor2x, jnp.ndarray]: ) -> Tuple[ScaledTensor2x, jnp.ndarray]:
"""Quantize input tensor and compute bias gradient. """Quantize input tensor and compute bias gradient.
...@@ -604,8 +645,8 @@ def quantize_dbias( ...@@ -604,8 +645,8 @@ def quantize_dbias(
Shape: (..., K) where K is the hidden size. Shape: (..., K) where K is the hidden size.
quantizer: Quantizer for FP8 quantization of the output. quantizer: Quantizer for FP8 quantization of the output.
is_dbias: If True, compute bias gradient. Defaults to True. is_dbias: If True, compute bias gradient. Defaults to True.
dq_dtype: Optional dtype for dequantization. flatten_axis: The quantization axis in which input data can be flattened to 2D for quantization.
If None, uses the same dtype as the input tensor. Defaults to -1.
Returns: Returns:
A tuple containing: A tuple containing:
...@@ -614,9 +655,6 @@ def quantize_dbias( ...@@ -614,9 +655,6 @@ def quantize_dbias(
- The bias gradient tensor. - The bias gradient tensor.
Shape: (K,) or empty if is_dbias is False. Shape: (K,) or empty if is_dbias is False.
""" """
return _quantize_impl( return _quantize_dbias_impl(
dz, dz, quantizer=quantizer, is_dbias=is_dbias, flatten_axis=flatten_axis
quantizer=quantizer,
is_dbias=is_dbias,
dq_dtype=dq_dtype,
) )
...@@ -11,14 +11,6 @@ ...@@ -11,14 +11,6 @@
#include "transformer_engine/cast.h" #include "transformer_engine/cast.h"
#include "xla/ffi/api/c_api.h" #include "xla/ffi/api/c_api.h"
namespace {
bool is_gated(NVTE_Activation_Type act_type) {
return act_type == NVTE_Activation_Type::GEGLU || act_type == NVTE_Activation_Type::SWIGLU ||
act_type == NVTE_Activation_Type::REGLU || act_type == NVTE_Activation_Type::QGEGLU ||
act_type == NVTE_Activation_Type::SREGLU;
}
} // namespace
namespace transformer_engine { namespace transformer_engine {
namespace jax { namespace jax {
...@@ -44,38 +36,56 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal ...@@ -44,38 +36,56 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal
auto act_len = input_dims[input_dims.size() - 2]; auto act_len = input_dims[input_dims.size() - 2];
auto scaling_mode = static_cast<NVTEScalingMode>(scaling_mode_enum); auto scaling_mode = static_cast<NVTEScalingMode>(scaling_mode_enum);
auto is_2x = static_cast<bool>(is_2x_int); auto is_2x = static_cast<bool>(is_2x_int);
auto flatten_axis = output_buf->dimensions().size() - 1; // output does not have act axis
auto input_shape = std::vector<size_t>{m, act_len * n}; auto input_shape = std::vector<size_t>{m, act_len * n};
auto output_shape = std::vector<size_t>{m, n}; auto output_shape = std::vector<size_t>{m, n};
auto output_trans_shape = std::vector<size_t>{n, m};
auto input_tensor = TensorWrapper(input, input_shape, static_cast<DType>(in_dtype)); auto input_tensor = TensorWrapper(input, input_shape, static_cast<DType>(in_dtype));
auto output_tensor = TensorWrapper(scaling_mode); auto output_tensor = TensorWrapper(scaling_mode);
output_tensor.set_rowwise_data(output, static_cast<DType>(out_dtype), output_shape); output_tensor.set_rowwise_data(output, static_cast<DType>(out_dtype), output_shape);
if (is_fp8_dtype(out_dtype)) { if (is_fp8_dtype(out_dtype)) {
output_tensor.set_rowwise_scale_inv( if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
scale_inv_buf->untyped_data(), NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling");
std::vector<size_t>{ cudaMemsetAsync(amax, 0, sizeof(float), stream);
product(scale_inv_buf->dimensions(), 0, scale_inv_buf->dimensions().size() - 1), output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
scale_inv_buf->dimensions().back()}); output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
} output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(),
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING && is_fp8_dtype(out_dtype)) { convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), std::vector<size_t>{1});
NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling"); } else {
NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling"); output_tensor.set_rowwise_scale_inv(
cudaMemsetAsync(amax, 0, sizeof(float), stream); scale_inv_buf->untyped_data(),
output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1}); convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1}); std::vector<size_t>{product(scale_inv_buf->dimensions(), 0, flatten_axis),
product(scale_inv_buf->dimensions(), flatten_axis,
scale_inv_buf->dimensions().size())});
}
} }
if (is_2x) { if (is_2x) {
output_tensor.set_columnwise_data(colwise_output, static_cast<DType>(out_dtype), output_shape); auto &tmp_shape =
output_tensor.set_columnwise_scale_inv( (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? output_trans_shape : output_shape;
colwise_scale_inv_buf->untyped_data(), output_tensor.set_columnwise_data(colwise_output, out_dtype, tmp_shape);
convert_ffi_datatype_to_te_dtype(colwise_scale_inv_buf->element_type()),
std::vector<size_t>{product(colwise_scale_inv_buf->dimensions(), 0, if (is_fp8_dtype(out_dtype)) {
colwise_scale_inv_buf->dimensions().size() - 1), // For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling
colwise_scale_inv_buf->dimensions().back()}); auto &tmp_buf =
(scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? scale_inv_buf : colwise_scale_inv_buf;
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
output_tensor.set_columnwise_scale_inv(
tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()),
std::vector<size_t>{1});
} else {
output_tensor.set_columnwise_scale_inv(
tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()),
std::vector<size_t>{
product(tmp_buf->dimensions(), 0, flatten_axis),
product(tmp_buf->dimensions(), flatten_axis, tmp_buf->dimensions().size())});
}
}
} }
switch (act_type) { switch (act_type) {
...@@ -162,8 +172,10 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid ...@@ -162,8 +172,10 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid
} }
if (is_2x) { if (is_2x) {
output_tensor.set_columnwise_data(reinterpret_cast<void *>(&temp), out_dtype, auto &tmp_shape = scaling_mode == static_cast<int>(NVTE_DELAYED_TENSOR_SCALING)
output_trans_shape); ? output_trans_shape
: output_shape;
output_tensor.set_columnwise_data(reinterpret_cast<void *>(&temp), out_dtype, tmp_shape);
// Only the pointers will be checked for scale_inv, thus the shapes do not matter // Only the pointers will be checked for scale_inv, thus the shapes do not matter
if (is_fp8_dtype(out_dtype)) { if (is_fp8_dtype(out_dtype)) {
...@@ -190,9 +202,9 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid ...@@ -190,9 +202,9 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid
Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
Buffer_Type act_input_buf, Buffer_Type scale_buf, Buffer_Type act_input_buf, Buffer_Type scale_buf,
Result_Type output_buf, Result_Type output_trans_buf, Result_Type output_buf, Result_Type colwise_output_buf,
Result_Type scale_inv_buf, Result_Type trans_scale_inv_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
Result_Type amax_out_buf, Result_Type dbias_buf, Result_Type amax_buf, Result_Type dbias_buf,
Result_Type workspace_buf, int64_t scaling_mode_enum, bool is_2x, Result_Type workspace_buf, int64_t scaling_mode_enum, bool is_2x,
bool is_dbias, int64_t act_enum) { bool is_dbias, int64_t act_enum) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
...@@ -201,11 +213,15 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, ...@@ -201,11 +213,15 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
auto *input = input_buf.untyped_data(); auto *input = input_buf.untyped_data();
auto *act_input = act_input_buf.untyped_data(); auto *act_input = act_input_buf.untyped_data();
float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
float *amax = reinterpret_cast<float *>(amax_buf->untyped_data());
auto scaling_mode = static_cast<NVTEScalingMode>(scaling_mode_enum); auto scaling_mode = static_cast<NVTEScalingMode>(scaling_mode_enum);
auto act_type = static_cast<NVTE_Activation_Type>(act_enum);
auto flatten_axis = output_buf->dimensions().size() - 2; // output has act axis
auto *output = output_buf->untyped_data(); auto *output = output_buf->untyped_data();
auto *output_trans = output_trans_buf->untyped_data(); auto *colwise_output = colwise_output_buf->untyped_data();
auto *dbias = dbias_buf->untyped_data(); auto *dbias = dbias_buf->untyped_data();
void *workspace = workspace_buf->untyped_data(); void *workspace = workspace_buf->untyped_data();
...@@ -213,17 +229,18 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, ...@@ -213,17 +229,18 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
auto act_input_dims = act_input_buf.dimensions(); auto act_input_dims = act_input_buf.dimensions();
auto workspace_dims = workspace_buf->dimensions(); auto workspace_dims = workspace_buf->dimensions();
// m = x_batch_size = reduce(operator.mul, x_shape[:-2]), x_shape == act_input_dims // m = x_batch_size = reduce(operator.mul, x_shape[:-2]), x_shape == act_input_dims
// n = ir_dz_shape[-1], ir_dz_shape == input_dims // n = ir_dz_shape[-1] * act_len, ir_dz_shape == input_dims
auto input_ranks = input_dims.size(); auto act_len = act_input_dims[act_input_dims.size() - 2];
auto act_input_ranks = act_input_dims.size(); NVTE_CHECK(act_input_dims.back() == input_dims.back(),
auto m = product(act_input_dims, 0, act_input_dims.size() - 1); "Shape mismatch between activation input and gradient input");
// 'n' will be 2x the size of input_dims.back() if the dactivation is dgated auto m = product(act_input_dims, 0, act_input_dims.size() - 2);
auto n = act_input_dims.back(); auto n = input_dims.back();
auto input_shape = std::vector<size_t>{m, input_dims.back()};
auto act_input_shape = std::vector<size_t>{m, n}; auto input_shape = std::vector<size_t>{m, n};
auto output_shape = std::vector<size_t>{m, n}; auto act_input_shape = std::vector<size_t>{m, n * act_len};
auto output_trans_shape = std::vector<size_t>{m, n}; auto output_shape = std::vector<size_t>{m, n * act_len};
auto dbias_shape = std::vector<size_t>{n}; auto output_trans_shape = std::vector<size_t>{n * act_len, m};
auto dbias_shape = std::vector<size_t>{n * act_len};
std::vector<size_t> workspace_shape(workspace_dims.begin(), workspace_dims.end()); std::vector<size_t> workspace_shape(workspace_dims.begin(), workspace_dims.end());
auto input_tensor = TensorWrapper(input, input_shape, in_dtype); auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
...@@ -231,50 +248,56 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, ...@@ -231,50 +248,56 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
auto output_tensor = TensorWrapper(scaling_mode); auto output_tensor = TensorWrapper(scaling_mode);
output_tensor.set_rowwise_data(output, out_dtype, output_shape); output_tensor.set_rowwise_data(output, out_dtype, output_shape);
if (is_fp8_dtype(out_dtype)) { if (is_fp8_dtype(out_dtype)) {
output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
std::vector<size_t>{
product(scale_inv_buf->dimensions(), 0, scale_inv_buf->dimensions().size() - 1),
scale_inv_buf->dimensions().back()});
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
float *amax_out = reinterpret_cast<float *>(amax_out_buf->untyped_data());
NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling"); NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
NVTE_CHECK(amax_out != nullptr, "amax must be provided for delayed tensor scaling"); NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling");
cudaMemsetAsync(amax_out, 0, sizeof(float), stream); cudaMemsetAsync(amax, 0, sizeof(float), stream);
output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1}); output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_amax(amax_out, DType::kFloat32, std::vector<size_t>{1}); output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), std::vector<size_t>{1});
} else {
output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
std::vector<size_t>{product(scale_inv_buf->dimensions(), 0, flatten_axis),
product(scale_inv_buf->dimensions(), flatten_axis,
scale_inv_buf->dimensions().size())});
} }
} }
if (is_2x) { if (is_2x) {
output_tensor.set_columnwise_data(output_trans, out_dtype, output_trans_shape); auto &tmp_shape =
(scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? output_trans_shape : output_shape;
output_tensor.set_columnwise_data(colwise_output, out_dtype, tmp_shape);
if (is_fp8_dtype(out_dtype)) { if (is_fp8_dtype(out_dtype)) {
// For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling // For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling
auto &colwise_scale_inv_buf = auto &tmp_buf =
(scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? scale_inv_buf : trans_scale_inv_buf; (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? scale_inv_buf : colwise_scale_inv_buf;
output_tensor.set_columnwise_scale_inv( if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
colwise_scale_inv_buf->untyped_data(), output_tensor.set_columnwise_scale_inv(
convert_ffi_datatype_to_te_dtype(colwise_scale_inv_buf->element_type()), tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()),
std::vector<size_t>{product(colwise_scale_inv_buf->dimensions(), 0, std::vector<size_t>{1});
colwise_scale_inv_buf->dimensions().size() - 1), } else {
colwise_scale_inv_buf->dimensions().back()}); output_tensor.set_columnwise_scale_inv(
tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()),
std::vector<size_t>{
product(tmp_buf->dimensions(), 0, flatten_axis),
product(tmp_buf->dimensions(), flatten_axis, tmp_buf->dimensions().size())});
}
} }
} }
auto dbias_tensor = TensorWrapper(dbias, dbias_shape, in_dtype); auto dbias_tensor = TensorWrapper(dbias, dbias_shape, in_dtype);
auto workspace_tensor = TensorWrapper(workspace, workspace_shape, workspace_dtype); auto workspace_tensor = TensorWrapper(workspace, workspace_shape, workspace_dtype);
auto act_type = static_cast<NVTE_Activation_Type>(act_enum);
// fused_dgated_dbias is not available, so we use dact_lu + quantize_dbias in Python instead // fused_dgated_dbias is not available, so we use dact_lu + quantize_dbias in Python instead
NVTE_CHECK(!(is_gated(act_type) && is_dbias), "Unsupported DGatedActedDBias Fusion!"); NVTE_CHECK(!(act_len == 2 && is_dbias), "Unsupported DGatedActedDBias Fusion!");
NVTE_CHECK(!(scaling_mode == NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING && is_2x && NVTE_CHECK(
is_gated(act_type)), !(scaling_mode == NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING && is_2x && act_len == 2),
"TE/common does not support delayed scaling for 2x with gated activations."); "TE/common does not support delayed scaling for 2x with gated activations.");
if (is_dbias) { if (is_dbias) {
switch (act_type) { switch (act_type) {
......
...@@ -44,12 +44,12 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh ...@@ -44,12 +44,12 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh
cudaStreamSynchronize(stream); cudaStreamSynchronize(stream);
// Notes on matrix layouts and transpose: // Notes on matrix layouts and transpose:
// Jax uses row-major layout, on entering this function, each input matrix pair: // Jax uses row-major data_layout, on entering this function, each input matrix pair:
// A: row-major with size [m, k], // A: row-major with size [m, k],
// B: row-major with size [n, k], needs transpose, // B: row-major with size [n, k], needs transpose,
// on exiting this function, JAX expects: // on exiting this function, JAX expects:
// C: row-major with size [m, n]. // C: row-major with size [m, n].
// cuBLAS uses column-major layout, in this view, each input matrix pair: // cuBLAS uses column-major data_layout, in this view, each input matrix pair:
// A: column-major with size [k, m], needs transpose, // A: column-major with size [k, m], needs transpose,
// B: column-major with size [k, n]. // B: column-major with size [k, n].
// If we call cuBLAS GEMM for A * B, the output will be: // If we call cuBLAS GEMM for A * B, the output will be:
......
...@@ -34,7 +34,7 @@ inline size_t product(const std::vector<size_t> &shape) { ...@@ -34,7 +34,7 @@ inline size_t product(const std::vector<size_t> &shape) {
return ret; return ret;
} }
enum class QuantizeAxis { enum class QuantizeLayout {
ROWWISE, ROWWISE,
COLWISE, COLWISE,
ROWWISE_COLWISE, ROWWISE_COLWISE,
......
...@@ -144,11 +144,11 @@ PYBIND11_MODULE(transformer_engine_jax, m) { ...@@ -144,11 +144,11 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
.value("NVTE_INVALID_SCALING", NVTEScalingMode::NVTE_MXFP8_1D_SCALING) .value("NVTE_INVALID_SCALING", NVTEScalingMode::NVTE_MXFP8_1D_SCALING)
.export_values(); .export_values();
pybind11::enum_<transformer_engine::jax::QuantizeAxis>(m, "QuantizeAxis", pybind11::enum_<transformer_engine::jax::QuantizeLayout>(m, "QuantizeLayout",
pybind11::module_local()) pybind11::module_local())
.value("ROWWISE", transformer_engine::jax::QuantizeAxis::ROWWISE) .value("ROWWISE", transformer_engine::jax::QuantizeLayout::ROWWISE)
.value("COLWISE", transformer_engine::jax::QuantizeAxis::COLWISE) .value("COLWISE", transformer_engine::jax::QuantizeLayout::COLWISE)
.value("ROWWISE_COLWISE", transformer_engine::jax::QuantizeAxis::ROWWISE_COLWISE) .value("ROWWISE_COLWISE", transformer_engine::jax::QuantizeLayout::ROWWISE_COLWISE)
.export_values(); .export_values();
} }
......
...@@ -42,10 +42,10 @@ pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_ ...@@ -42,10 +42,10 @@ pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_
Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf, Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf,
Result_Type output_buf, Result_Type output_trans_buf, Result_Type output_buf, Result_Type output_trans_buf,
Result_Type scale_inv_buf, Result_Type trans_scale_inv_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
Result_Type amax_out_buf, Result_Type dbias_buf, Result_Type amax_buf, Result_Type dbias_buf, Result_Type workspace_buf,
Result_Type workspace_buf, int64_t scaling_mode_enum, int64_t scaling_mode_enum, int64_t quantize_layout_enum, bool is_dbias,
int64_t quantize_axis_enum, bool is_dbias) { int64_t flatten_axis) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type()); auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type());
...@@ -55,7 +55,7 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T ...@@ -55,7 +55,7 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
auto *input = input_buf.untyped_data(); auto *input = input_buf.untyped_data();
auto scaling_mode = static_cast<NVTEScalingMode>(scaling_mode_enum); auto scaling_mode = static_cast<NVTEScalingMode>(scaling_mode_enum);
auto const quantize_axis = static_cast<QuantizeAxis>(quantize_axis_enum); auto const quantize_layout = static_cast<QuantizeLayout>(quantize_layout_enum);
auto *output = output_buf->untyped_data(); auto *output = output_buf->untyped_data();
auto *output_trans = output_trans_buf->untyped_data(); auto *output_trans = output_trans_buf->untyped_data();
...@@ -63,9 +63,13 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T ...@@ -63,9 +63,13 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
void *workspace = workspace_buf->untyped_data(); void *workspace = workspace_buf->untyped_data();
auto input_dims = input_buf.dimensions(); auto input_dims = input_buf.dimensions();
int64_t input_ndim = input_dims.size();
if (flatten_axis < 0) flatten_axis += input_ndim;
NVTE_CHECK(flatten_axis < input_ndim && flatten_axis > 0, "flatten_axis is out of bounds!");
auto workspace_dims = workspace_buf->dimensions(); auto workspace_dims = workspace_buf->dimensions();
auto m = product(input_dims, 0, input_dims.size() - 1); auto m = product(input_dims, 0, flatten_axis);
auto n = input_dims.back(); auto n = product(input_dims, flatten_axis, input_ndim);
auto input_shape = std::vector<size_t>{m, n}; auto input_shape = std::vector<size_t>{m, n};
auto output_shape = std::vector<size_t>{m, n}; auto output_shape = std::vector<size_t>{m, n};
auto output_trans_shape = std::vector<size_t>{n, m}; auto output_trans_shape = std::vector<size_t>{n, m};
...@@ -75,37 +79,54 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T ...@@ -75,37 +79,54 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
auto input_tensor = TensorWrapper(input, input_shape, in_dtype); auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
auto output_tensor = TensorWrapper(scaling_mode); auto output_tensor = TensorWrapper(scaling_mode);
if (quantize_axis == QuantizeAxis::ROWWISE || quantize_axis == QuantizeAxis::ROWWISE_COLWISE) { if (quantize_layout == QuantizeLayout::ROWWISE ||
quantize_layout == QuantizeLayout::ROWWISE_COLWISE) {
output_tensor.set_rowwise_data(output, out_dtype, output_shape); output_tensor.set_rowwise_data(output, out_dtype, output_shape);
output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
std::vector<size_t>{
product(scale_inv_buf->dimensions(), 0, scale_inv_buf->dimensions().size() - 1),
scale_inv_buf->dimensions().back()});
}
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { if (is_fp8_dtype(out_dtype)) {
float *scale = reinterpret_cast<float *>(scale_buf.untyped_data()); if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
float *amax_out = reinterpret_cast<float *>(amax_out_buf->untyped_data()); float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling"); float *amax = reinterpret_cast<float *>(amax_buf->untyped_data());
NVTE_CHECK(amax_out != nullptr, "amax must be provided for delayed tensor scaling"); NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1}); NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling");
cudaMemsetAsync(amax_out, 0, sizeof(float), stream); output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_amax(amax_out, DType::kFloat32, std::vector<size_t>{1}); cudaMemsetAsync(amax, 0, sizeof(float), stream);
output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
std::vector<size_t>{1});
} else {
output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
std::vector<size_t>{product(scale_inv_buf->dimensions(), 0, flatten_axis),
product(scale_inv_buf->dimensions(), flatten_axis,
scale_inv_buf->dimensions().size())});
}
}
} }
if (quantize_axis == QuantizeAxis::COLWISE || quantize_axis == QuantizeAxis::ROWWISE_COLWISE) { if (quantize_layout == QuantizeLayout::COLWISE ||
output_tensor.set_columnwise_data(output_trans, out_dtype, output_trans_shape); quantize_layout == QuantizeLayout::ROWWISE_COLWISE) {
auto &tmp_shape =
(scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? output_trans_shape : output_shape;
output_tensor.set_columnwise_data(output_trans, out_dtype, tmp_shape);
// For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling // For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling
auto &colwise_scale_inv_buf = auto &tmp_buf =
(scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? scale_inv_buf : trans_scale_inv_buf; (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? scale_inv_buf : colwise_scale_inv_buf;
output_tensor.set_columnwise_scale_inv(
colwise_scale_inv_buf->untyped_data(), if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
convert_ffi_datatype_to_te_dtype(colwise_scale_inv_buf->element_type()), output_tensor.set_columnwise_scale_inv(
std::vector<size_t>{product(colwise_scale_inv_buf->dimensions(), 0, tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()),
colwise_scale_inv_buf->dimensions().size() - 1), std::vector<size_t>{1});
colwise_scale_inv_buf->dimensions().back()}); } else {
output_tensor.set_columnwise_scale_inv(
tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()),
std::vector<size_t>{
product(tmp_buf->dimensions(), 0, flatten_axis),
product(tmp_buf->dimensions(), flatten_axis, tmp_buf->dimensions().size())});
}
} }
auto dbias_tensor = TensorWrapper(dbias, dbias_shape, in_dtype); auto dbias_tensor = TensorWrapper(dbias, dbias_shape, in_dtype);
...@@ -133,8 +154,9 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DBiasQuantizeHandler, DBiasQuantizeFFI, ...@@ -133,8 +154,9 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DBiasQuantizeHandler, DBiasQuantizeFFI,
.Ret<Buffer_Type>() // dbias .Ret<Buffer_Type>() // dbias
.Ret<Buffer_Type>() // wkspace .Ret<Buffer_Type>() // wkspace
.Attr<int64_t>("scaling_mode") .Attr<int64_t>("scaling_mode")
.Attr<int64_t>("q_axis") .Attr<int64_t>("q_layout")
.Attr<bool>("is_dbias"), .Attr<bool>("is_dbias")
.Attr<int64_t>("flatten_axis"),
FFI_CudaGraph_Traits); FFI_CudaGraph_Traits);
Error_Type DequantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf, Error_Type DequantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf,
......
...@@ -15,7 +15,11 @@ import jax ...@@ -15,7 +15,11 @@ import jax
import jax.numpy as jnp import jax.numpy as jnp
from . import cpp_extensions as tex from . import cpp_extensions as tex
from .quantize import QuantizerSet, noop_quantizer_set from .quantize import (
QuantizerSet,
noop_quantizer_set,
with_sharding_constraint_by_logical_axes,
)
def dense( def dense(
...@@ -23,6 +27,8 @@ def dense( ...@@ -23,6 +27,8 @@ def dense(
kernel: jnp.ndarray, kernel: jnp.ndarray,
bias: jnp.ndarray = None, bias: jnp.ndarray = None,
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)), contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)),
input_axes: Tuple[str, ...] = None,
kernel_axes: Tuple[str, ...] = None,
quantizer_set: QuantizerSet = noop_quantizer_set, quantizer_set: QuantizerSet = noop_quantizer_set,
): ):
"""Perform dense layer transformation with optional quantization. """Perform dense layer transformation with optional quantization.
...@@ -48,12 +54,12 @@ def dense( ...@@ -48,12 +54,12 @@ def dense(
bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape
output += jnp.reshape(bias, bias_new_shape) output += jnp.reshape(bias, bias_new_shape)
else: else:
output = _dense(x, kernel, bias, contracting_dims, quantizer_set) output = _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set)
return output return output
@partial(jax.custom_vjp, nondiff_argnums=(3,)) @partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5))
def _dense(x, kernel, bias, contracting_dims, quantizer_set): def _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set):
"""Internal implementation of dense layer transformation with custom VJP. """Internal implementation of dense layer transformation with custom VJP.
This function implements the core dense layer transformation logic with support This function implements the core dense layer transformation logic with support
...@@ -64,32 +70,37 @@ def _dense(x, kernel, bias, contracting_dims, quantizer_set): ...@@ -64,32 +70,37 @@ def _dense(x, kernel, bias, contracting_dims, quantizer_set):
kernel: Weight matrix kernel: Weight matrix
bias: Optional bias tensor bias: Optional bias tensor
contracting_dims: Contracting dimensions specification contracting_dims: Contracting dimensions specification
input_axes: Logical axes for sharding the activation input
kernel_axes: Logical axes for sharding the weight matrix
quantizer_set: QuantizerSet which contains quantizers for different tensor types quantizer_set: QuantizerSet which contains quantizers for different tensor types
Returns: Returns:
Transformed output tensor Transformed output tensor
""" """
output, _ = _dense_fwd_rule(x, kernel, bias, contracting_dims, quantizer_set) output, _ = _dense_fwd_rule(
x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set
)
return output return output
def _dense_fwd_rule(x, kernel, bias, contracting_dims, quantizer_set): def _dense_fwd_rule(x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set):
"""Forward pass rule for dense layer transformation. """Forward pass rule for dense layer transformation.
Args:
x: Input tensor
kernel: Weight matrix
bias: Optional bias tensor
contracting_dims: Contracting dimensions specification
quantizer_set: QuantizerSet which contains quantizers for different tensor types
Returns: Returns:
Tuple of (output, context) for backward pass Tuple of (output, context) for backward pass
""" """
x_contracting_dims, k_contracting_dims = contracting_dims x_contracting_dims, k_contracting_dims = contracting_dims
casted_x = tex.quantize(x, quantizer_set.x) flatten_axis_x = -len(x_contracting_dims)
casted_kernel = tex.quantize(kernel, quantizer_set.kernel) flatten_axis_k = len(k_contracting_dims) - len(kernel.shape)
casted_x = tex.quantize(x, flatten_axis=flatten_axis_x, quantizer=quantizer_set.x)
casted_x = with_sharding_constraint_by_logical_axes(casted_x, input_axes)
casted_kernel = tex.quantize(
kernel, flatten_axis=flatten_axis_k, quantizer=quantizer_set.kernel
)
casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes)
# GEMM NN # GEMM NN
output = tex.gemm( output = tex.gemm(
...@@ -97,6 +108,7 @@ def _dense_fwd_rule(x, kernel, bias, contracting_dims, quantizer_set): ...@@ -97,6 +108,7 @@ def _dense_fwd_rule(x, kernel, bias, contracting_dims, quantizer_set):
casted_kernel.get_colwise_tensor(), casted_kernel.get_colwise_tensor(),
(x_contracting_dims, k_contracting_dims), (x_contracting_dims, k_contracting_dims),
) )
use_bias = bias is not None use_bias = bias is not None
if use_bias: if use_bias:
bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape
...@@ -109,18 +121,16 @@ def _dense_fwd_rule(x, kernel, bias, contracting_dims, quantizer_set): ...@@ -109,18 +121,16 @@ def _dense_fwd_rule(x, kernel, bias, contracting_dims, quantizer_set):
kernel.shape, kernel.shape,
use_bias, use_bias,
quantizer_set, quantizer_set,
flatten_axis_k,
) )
return output, ctx return output, ctx
def _dense_bwd_rule(contracting_dims, ctx, grad): # pylint: disable=unused-argument def _dense_bwd_rule(
contracting_dims, input_axes, kernel_axes, ctx, grad
): # pylint: disable=unused-argument
"""Backward pass rule for dense layer transformation. """Backward pass rule for dense layer transformation.
Args:
contracting_dims: Contracting dimensions specification
ctx: Context from forward pass
grad: Gradient from upstream
Returns: Returns:
Tuple of gradients with respect to inputs Tuple of gradients with respect to inputs
""" """
...@@ -133,9 +143,12 @@ def _dense_bwd_rule(contracting_dims, ctx, grad): # pylint: disable=unused-argu ...@@ -133,9 +143,12 @@ def _dense_bwd_rule(contracting_dims, ctx, grad): # pylint: disable=unused-argu
kernel_shape, kernel_shape,
use_bias, use_bias,
quantizer_set, quantizer_set,
flatten_axis_k,
) = ctx ) = ctx
casted_grad, dbias = tex.quantize_dbias(grad, is_dbias=use_bias, quantizer=quantizer_set.dgrad) casted_grad, dbias = tex.quantize_dbias(
grad, is_dbias=use_bias, flatten_axis=flatten_axis_k, quantizer=quantizer_set.dgrad
)
# GEMM NT # GEMM NT
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim
...@@ -151,6 +164,7 @@ def _dense_bwd_rule(contracting_dims, ctx, grad): # pylint: disable=unused-argu ...@@ -151,6 +164,7 @@ def _dense_bwd_rule(contracting_dims, ctx, grad): # pylint: disable=unused-argu
rowwise_casted_kernel, rowwise_casted_kernel,
(g_constracting_dim, k_constracting_dim), (g_constracting_dim, k_constracting_dim),
) )
dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes)
# GEMM TN # GEMM TN
# x_non_contracting_dims # x_non_contracting_dims
...@@ -161,6 +175,7 @@ def _dense_bwd_rule(contracting_dims, ctx, grad): # pylint: disable=unused-argu ...@@ -161,6 +175,7 @@ def _dense_bwd_rule(contracting_dims, ctx, grad): # pylint: disable=unused-argu
wgrad = tex.gemm( wgrad = tex.gemm(
colwise_casted_x, casted_grad.get_colwise_tensor(), (x_constracting_dim, g_constracting_dim) colwise_casted_x, casted_grad.get_colwise_tensor(), (x_constracting_dim, g_constracting_dim)
) )
wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes)
return dgrad, wgrad, dbias, quantizer_set return dgrad, wgrad, dbias, quantizer_set
......
...@@ -28,6 +28,7 @@ from ..softmax import softmax, SoftmaxType ...@@ -28,6 +28,7 @@ from ..softmax import softmax, SoftmaxType
from ..sharding import with_sharding_constraint_by_logical_axes from ..sharding import with_sharding_constraint_by_logical_axes
from ..cpp_extensions import is_softmax_kernel_available from ..cpp_extensions import is_softmax_kernel_available
from ..quantize import QuantizerFactory, QuantizeConfig, QuantizeMeta, QuantizeMetaSet, ScalingMode from ..quantize import QuantizerFactory, QuantizeConfig, QuantizeMeta, QuantizeMetaSet, ScalingMode
from ..sharding import get_non_contracting_logical_axes
PRNGKey = Any PRNGKey = Any
Shape = Tuple[int, ...] Shape = Tuple[int, ...]
...@@ -406,6 +407,10 @@ class DenseGeneral(TransformerEngineBase): ...@@ -406,6 +407,10 @@ class DenseGeneral(TransformerEngineBase):
:math:`\frac{alpha}{rank} * lora_output`. None means no scaling. :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
axis: Union[Iterable[int], int], default = -1 axis: Union[Iterable[int], int], default = -1
An integer tuple with axes to apply the transformation on. An integer tuple with axes to apply the transformation on.
input_axes: Tuple[str, ...], default = ()
Indicate the logical axes of the sharding constraint applied to the input, like
(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is an empty tuple, which means not to insert
a sharding constraint.
Optimization parameters Optimization parameters
----------------------- -----------------------
...@@ -429,6 +434,7 @@ class DenseGeneral(TransformerEngineBase): ...@@ -429,6 +434,7 @@ class DenseGeneral(TransformerEngineBase):
axis: Union[Iterable[int], int] = -1 axis: Union[Iterable[int], int] = -1
dtype: DType = jnp.float32 dtype: DType = jnp.float32
transpose_batch_sequence: bool = False transpose_batch_sequence: bool = False
input_axes: Tuple[str, ...] = ()
def __post_init__(self): def __post_init__(self):
if self.kernel_init is None: if self.kernel_init is None:
...@@ -460,29 +466,35 @@ class DenseGeneral(TransformerEngineBase): ...@@ -460,29 +466,35 @@ class DenseGeneral(TransformerEngineBase):
axis = _normalize_axes(axis, inputs.ndim) axis = _normalize_axes(axis, inputs.ndim)
kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features
if self.kernel_axes:
assert len(kernel_shape) == len(self.kernel_axes), (
"Expected len(kernel_shape) to match len(kernel_axes),"
f"got kernel_shape {kernel_shape} and kernel_axes {self.kernel_axes}"
)
kernel = nn_partitioning.param_with_axes( kernel = nn_partitioning.param_with_axes(
"kernel", self.kernel_init, kernel_shape, self.dtype, axes=self.kernel_axes "kernel", self.kernel_init, kernel_shape, self.dtype, axes=self.kernel_axes
) )
if not QuantizeConfig.is_fp8_enabled(): if not QuantizeConfig.is_fp8_enabled():
kernel = kernel.astype(input_dtype) kernel = kernel.astype(input_dtype)
kernel_compute_shape = (
reduce(operator.mul, [inputs.shape[ax] for ax in axis], 1),
reduce(operator.mul, features, 1),
)
kernel = jnp.reshape(kernel, kernel_compute_shape)
if self.use_bias: if self.use_bias:
bias = nn_partitioning.param_with_axes( bias = nn_partitioning.param_with_axes(
"bias", self.bias_init, features, self.dtype, axes=self.bias_axes "bias", self.bias_init, features, self.dtype, axes=self.bias_axes
) ).astype(input_dtype)
bias = bias.reshape(kernel_compute_shape[-1]).astype(input_dtype)
else: else:
bias = None bias = None
quantizer_set = self.generate_quantizer_set() quantizer_set = self.generate_quantizer_set()
contract_ind = tuple(range(0, len(axis))) contract_ind = tuple(range(0, len(axis)))
y = dense( y = dense(
inputs, kernel, contracting_dims=(axis, contract_ind), quantizer_set=quantizer_set inputs,
kernel,
contracting_dims=(axis, contract_ind),
input_axes=self.input_axes,
kernel_axes=self.kernel_axes,
quantizer_set=quantizer_set,
) )
if self.enable_low_rank_adaptation: if self.enable_low_rank_adaptation:
...@@ -491,20 +503,14 @@ class DenseGeneral(TransformerEngineBase): ...@@ -491,20 +503,14 @@ class DenseGeneral(TransformerEngineBase):
*features[:-1], *features[:-1],
self.low_rank_adaptation_dim, self.low_rank_adaptation_dim,
) )
lora_a_kernel_init_shape = ( lora_a_kernel_axes = (None,) * len(lora_a_kernel_shape)
kernel_compute_shape[0],
*features[:-1],
self.low_rank_adaptation_dim,
)
lora_a_kernel_axes = (None,) * len(lora_a_kernel_init_shape)
lora_a_kernel = nn_partitioning.param_with_axes( lora_a_kernel = nn_partitioning.param_with_axes(
"lora_a_kernel", "lora_a_kernel",
self.kernel_init, self.kernel_init,
lora_a_kernel_init_shape, lora_a_kernel_shape,
self.dtype, self.dtype,
axes=lora_a_kernel_axes, axes=lora_a_kernel_axes,
) )
lora_a_kernel = jnp.reshape(lora_a_kernel, lora_a_kernel_shape)
lora_a_kernel = lora_a_kernel.astype(input_dtype) lora_a_kernel = lora_a_kernel.astype(input_dtype)
lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1]) lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
...@@ -527,7 +533,6 @@ class DenseGeneral(TransformerEngineBase): ...@@ -527,7 +533,6 @@ class DenseGeneral(TransformerEngineBase):
y += jnp.reshape(bias, bias_shape) y += jnp.reshape(bias, bias_shape)
assert y.dtype == input_dtype assert y.dtype == input_dtype
y = y.reshape(*inputs.shape[: self.axis], *features)
return y return y
...@@ -678,6 +683,7 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -678,6 +683,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
The output tensors of layer normalization. The output tensors of layer normalization.
If :attr:`return_layernorm_output=False`, then this would be None. If :attr:`return_layernorm_output=False`, then this would be None.
""" """
assert self.axis == -1, "Only support axis = =-1 at this moment"
input_dtype = inputs.dtype input_dtype = inputs.dtype
ln_output = None ln_output = None
...@@ -692,10 +698,7 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -692,10 +698,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
if self.enable_layernorm: if self.enable_layernorm:
inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes) inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)
assert self.axis == -1 # Only support axis = =-1 at this moment
features = inputs.shape[-1] features = inputs.shape[-1]
scale, ln_bias = _create_layernorm_parameters( scale, ln_bias = _create_layernorm_parameters(
self.layernorm_type, self.layernorm_type,
(features,), (features,),
...@@ -731,17 +734,12 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -731,17 +734,12 @@ class LayerNormDenseGeneral(TransformerEngineBase):
axis = _normalize_axes(axis, y.ndim) axis = _normalize_axes(axis, y.ndim)
kernel_shape = tuple(y.shape[ax] for ax in axis) + features kernel_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features
kernel = nn_partitioning.param_with_axes( kernel = nn_partitioning.param_with_axes(
"kernel", self.kernel_init, kernel_shape, self.dtype, axes=self.kernel_axes "kernel", self.kernel_init, kernel_shape, self.dtype, axes=self.kernel_axes
) )
if not QuantizeConfig.is_fp8_enabled(): if not QuantizeConfig.is_fp8_enabled():
kernel = kernel.astype(input_dtype) kernel = kernel.astype(input_dtype)
kernel_compute_shape = (
reduce(operator.mul, [inputs.shape[ax] for ax in axis], 1),
reduce(operator.mul, features, 1),
)
kernel = jnp.reshape(kernel, kernel_compute_shape)
contract_ind = tuple(range(0, len(axis))) contract_ind = tuple(range(0, len(axis)))
...@@ -756,11 +754,19 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -756,11 +754,19 @@ class LayerNormDenseGeneral(TransformerEngineBase):
epsilon=self.epsilon, epsilon=self.epsilon,
layernorm_input_axes=self.layernorm_input_axes, layernorm_input_axes=self.layernorm_input_axes,
dot_input_axes=self.dot_input_axes, dot_input_axes=self.dot_input_axes,
kernel_axes=self.kernel_axes,
quantizer_set=quantizer_set, quantizer_set=quantizer_set,
) )
else: else:
y = with_sharding_constraint_by_logical_axes(y, self.dot_input_axes) y = with_sharding_constraint_by_logical_axes(y, self.dot_input_axes)
z = dense(y, kernel, contracting_dims=(axis, contract_ind), quantizer_set=quantizer_set) z = dense(
y,
kernel,
contracting_dims=(axis, contract_ind),
input_axes=self.dot_input_axes,
kernel_axes=self.kernel_axes,
quantizer_set=quantizer_set,
)
if self.enable_low_rank_adaptation: if self.enable_low_rank_adaptation:
lora_a_kernel_shape = ( lora_a_kernel_shape = (
...@@ -768,20 +774,14 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -768,20 +774,14 @@ class LayerNormDenseGeneral(TransformerEngineBase):
*features[:-1], *features[:-1],
self.low_rank_adaptation_dim, self.low_rank_adaptation_dim,
) )
lora_a_kernel_init_shape = ( lora_a_kernel_axes = (None,) * len(lora_a_kernel_shape)
kernel_compute_shape[0],
*features[:-1],
self.low_rank_adaptation_dim,
)
lora_a_kernel_axes = (None,) * len(lora_a_kernel_init_shape)
lora_a_kernel = nn_partitioning.param_with_axes( lora_a_kernel = nn_partitioning.param_with_axes(
"lora_a_kernel", "lora_a_kernel",
self.kernel_init, self.kernel_init,
lora_a_kernel_init_shape, lora_a_kernel_shape,
self.dtype, self.dtype,
axes=lora_a_kernel_axes, axes=lora_a_kernel_axes,
) )
lora_a_kernel = jnp.reshape(lora_a_kernel, lora_a_kernel_shape)
lora_a_kernel = lora_a_kernel.astype(input_dtype) lora_a_kernel = lora_a_kernel.astype(input_dtype)
lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1]) lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
...@@ -803,8 +803,7 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -803,8 +803,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
if self.use_bias: if self.use_bias:
bias = nn_partitioning.param_with_axes( bias = nn_partitioning.param_with_axes(
"bias", self.bias_init, features, self.dtype, axes=self.bias_axes "bias", self.bias_init, features, self.dtype, axes=self.bias_axes
) ).astype(input_dtype)
bias = bias.reshape(kernel_compute_shape[-1]).astype(input_dtype)
if bias is not None: if bias is not None:
bias_shape = (1,) * (z.ndim - bias.ndim) + bias.shape bias_shape = (1,) * (z.ndim - bias.ndim) + bias.shape
...@@ -814,7 +813,7 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -814,7 +813,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
z = z / self.depth_scaling z = z / self.depth_scaling
assert z.dtype == input_dtype, f"output_dtype={z.dtype}, input_dtype={input_dtype}" assert z.dtype == input_dtype, f"output_dtype={z.dtype}, input_dtype={input_dtype}"
z = z.reshape(*inputs.shape[: self.axis], *features) # z = z.reshape(*inputs.shape[: self.axis], *features)
return z, ln_output # dense_output, layer_norm_output return z, ln_output # dense_output, layer_norm_output
...@@ -989,6 +988,8 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -989,6 +988,8 @@ class LayerNormMLP(TransformerEngineBase):
The output tensors of layer normalization. The output tensors of layer normalization.
If :attr:`return_layernorm_output=False`, then this would be None. If :attr:`return_layernorm_output=False`, then this would be None.
""" """
assert self.axis == -1, "Only support axis == -1 at this moment"
ffn1_quantizer_set = self.generate_quantizer_set("_0") ffn1_quantizer_set = self.generate_quantizer_set("_0")
ffn2_quantizer_set = self.generate_quantizer_set("_1") ffn2_quantizer_set = self.generate_quantizer_set("_1")
...@@ -1027,7 +1028,6 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1027,7 +1028,6 @@ class LayerNormMLP(TransformerEngineBase):
) )
# LayerNorm # LayerNorm
if self.enable_layernorm: if self.enable_layernorm:
assert self.axis == -1 # Only support axis == -1 at this moment
inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes) inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)
features = inputs.shape[-1] features = inputs.shape[-1]
...@@ -1071,7 +1071,7 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1071,7 +1071,7 @@ class LayerNormMLP(TransformerEngineBase):
num_activations = len(normalized_acts) num_activations = len(normalized_acts)
axis = _canonicalize_tuple(self.axis) axis = _canonicalize_tuple(self.axis)
axis = _normalize_axes(axis, y.ndim) axis = _normalize_axes(axis, y.ndim)
kernel_1_each_shape = (np.prod([y.shape[ax] for ax in axis]), self.intermediate_dim) kernel_1_each_shape = (np.prod([inputs.shape[ax] for ax in axis]), self.intermediate_dim)
kernel_1 = nn_partitioning.param_with_axes( kernel_1 = nn_partitioning.param_with_axes(
"wi_kernel", "wi_kernel",
kernel_1_init, kernel_1_init,
...@@ -1081,17 +1081,10 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1081,17 +1081,10 @@ class LayerNormMLP(TransformerEngineBase):
self.dtype, self.dtype,
axes=self.kernel_axes_1, axes=self.kernel_axes_1,
) )
kernel_1_compute_shape = (
reduce(operator.mul, [y.shape[ax] for ax in axis], 1),
num_activations * self.intermediate_dim,
)
kernel_1 = jnp.reshape(kernel_1, kernel_1_compute_shape)
if not QuantizeConfig.is_fp8_enabled(): if not QuantizeConfig.is_fp8_enabled():
kernel_1 = kernel_1.astype(input_dtype) kernel_1 = kernel_1.astype(input_dtype)
if self.kernel_axes_1 is not None:
kernel_1 = with_sharding_constraint_by_logical_axes(
kernel_1, self.kernel_axes_1[:-2] + self.kernel_axes_1[-1:]
)
hidden_size = inputs.shape[-1] hidden_size = inputs.shape[-1]
hidden_size_tuple = _canonicalize_tuple(hidden_size) hidden_size_tuple = _canonicalize_tuple(hidden_size)
kernel_2_shape = (self.intermediate_dim,) + hidden_size_tuple kernel_2_shape = (self.intermediate_dim,) + hidden_size_tuple
...@@ -1102,27 +1095,20 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1102,27 +1095,20 @@ class LayerNormMLP(TransformerEngineBase):
self.dtype, self.dtype,
axes=self.kernel_axes_2, axes=self.kernel_axes_2,
) )
kernel_2_compute_shape = (
self.intermediate_dim,
reduce(operator.mul, hidden_size_tuple, 1),
)
kernel_2 = jnp.reshape(kernel_2, kernel_2_compute_shape)
if not QuantizeConfig.is_fp8_enabled(): if not QuantizeConfig.is_fp8_enabled():
kernel_2 = kernel_2.astype(input_dtype) kernel_2 = kernel_2.astype(input_dtype)
if self.kernel_axes_2 is not None:
kernel_2 = with_sharding_constraint_by_logical_axes(kernel_2, self.kernel_axes_2)
contract_ind = tuple(range(0, len(axis))) contract_ind = tuple(range(0, len(axis)))
if self.use_bias: if self.use_bias:
bias_1_shape = num_activations * self.intermediate_dim bias_1_shape = (num_activations, self.intermediate_dim)
bias_1 = nn_partitioning.param_with_axes( bias_1 = nn_partitioning.param_with_axes(
"wi_bias", "wi_bias",
self.bias_init, self.bias_init,
bias_1_shape, bias_1_shape,
self.dtype, self.dtype,
axes=self.bias_axes_1, axes=self.bias_axes_1,
) ).astype(input_dtype)
bias_1 = bias_1.reshape(kernel_1_compute_shape[-1]).astype(input_dtype)
bias_2_shape = (hidden_size,) bias_2_shape = (hidden_size,)
bias_2 = nn_partitioning.param_with_axes( bias_2 = nn_partitioning.param_with_axes(
...@@ -1131,8 +1117,7 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1131,8 +1117,7 @@ class LayerNormMLP(TransformerEngineBase):
bias_2_shape, bias_2_shape,
self.dtype, self.dtype,
axes=self.bias_axes_2, axes=self.bias_axes_2,
) ).astype(input_dtype)
bias_2 = bias_2.reshape(kernel_2_compute_shape[-1]).astype(input_dtype)
else: else:
bias_1 = None bias_1 = None
bias_2 = None bias_2 = None
...@@ -1141,8 +1126,6 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1141,8 +1126,6 @@ class LayerNormMLP(TransformerEngineBase):
ffn2_ckpt_name = "ffn2" ffn2_ckpt_name = "ffn2"
if use_fused_layernorm_mlp: if use_fused_layernorm_mlp:
assert self.axis == -1 # Only support axis = =-1 at this moment
out = layernorm_mlp( out = layernorm_mlp(
y, y,
scale, scale,
...@@ -1155,6 +1138,8 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1155,6 +1138,8 @@ class LayerNormMLP(TransformerEngineBase):
norm_input_axes=self.layernorm_input_axes, norm_input_axes=self.layernorm_input_axes,
dot_1_input_axes=self.dot_1_input_axes, dot_1_input_axes=self.dot_1_input_axes,
dot_2_input_axes=self.dot_2_input_axes, dot_2_input_axes=self.dot_2_input_axes,
kernel_1_axes=self.kernel_axes_1,
kernel_2_axes=self.kernel_axes_2,
ffn1_ckpt_name=ffn1_ckpt_name, ffn1_ckpt_name=ffn1_ckpt_name,
ffn2_ckpt_name=ffn2_ckpt_name, ffn2_ckpt_name=ffn2_ckpt_name,
activation_type=normalized_acts, activation_type=normalized_acts,
...@@ -1175,6 +1160,7 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1175,6 +1160,7 @@ class LayerNormMLP(TransformerEngineBase):
epsilon=self.epsilon, epsilon=self.epsilon,
layernorm_input_axes=self.layernorm_input_axes, layernorm_input_axes=self.layernorm_input_axes,
dot_input_axes=self.dot_1_input_axes, dot_input_axes=self.dot_1_input_axes,
kernel_axes=self.kernel_axes_1,
quantizer_set=ffn1_quantizer_set, quantizer_set=ffn1_quantizer_set,
) )
else: else:
...@@ -1183,35 +1169,31 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1183,35 +1169,31 @@ class LayerNormMLP(TransformerEngineBase):
y, y,
kernel_1, kernel_1,
contracting_dims=(axis, contract_ind), contracting_dims=(axis, contract_ind),
input_axes=self.dot_1_input_axes,
kernel_axes=self.kernel_axes_1,
quantizer_set=ffn1_quantizer_set, quantizer_set=ffn1_quantizer_set,
) )
dot_1_output_axes = (
*get_non_contracting_logical_axes(y.ndim, self.dot_1_input_axes, axis),
*get_non_contracting_logical_axes(kernel_1.ndim, self.kernel_axes_1, contract_ind),
)
x = with_sharding_constraint_by_logical_axes(x, dot_1_output_axes)
if self.enable_low_rank_adaptation: if self.enable_low_rank_adaptation:
wi_lora_a_kernel_shape = ( wi_lora_a_kernel_each_shape = (
kernel_1_compute_shape[0], kernel_1_each_shape[: len(axis)],
num_activations,
self.low_rank_adaptation_dim,
)
wi_lora_a_kernel_init_shape = (
kernel_1_each_shape[0],
num_activations,
self.low_rank_adaptation_dim,
)
wi_lora_a_kernel_init_each_shape = (
kernel_1_each_shape[0],
self.low_rank_adaptation_dim, self.low_rank_adaptation_dim,
) )
wi_lora_a_kernel_axes = (None,) * len(wi_lora_a_kernel_init_shape) wi_lora_a_kernel_axes = (None,) * len(wi_lora_a_kernel_each_shape + 1)
wi_lora_a_kernel = nn_partitioning.param_with_axes( wi_lora_a_kernel = nn_partitioning.param_with_axes(
"wi_lora_a_kernel", "wi_lora_a_kernel",
kernel_1_init, kernel_1_init,
num_activations, num_activations,
-1, -2,
wi_lora_a_kernel_init_each_shape, wi_lora_a_kernel_each_shape,
self.dtype, self.dtype,
axes=wi_lora_a_kernel_axes, axes=wi_lora_a_kernel_axes,
) )
wi_lora_a_kernel = jnp.reshape(wi_lora_a_kernel, wi_lora_a_kernel_shape)
wi_lora_a_kernel = wi_lora_a_kernel.astype(input_dtype) wi_lora_a_kernel = wi_lora_a_kernel.astype(input_dtype)
wi_lora_b_kernel_shape = ( wi_lora_b_kernel_shape = (
...@@ -1232,7 +1214,7 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1232,7 +1214,7 @@ class LayerNormMLP(TransformerEngineBase):
x += _apply_low_rank_adaptation( x += _apply_low_rank_adaptation(
y, y,
axis, axis,
num_activations * self.intermediate_dim, (num_activations, self.intermediate_dim),
wi_lora_a_kernel, wi_lora_a_kernel,
wi_lora_b_kernel, wi_lora_b_kernel,
self.low_rank_adaptation_alpha, self.low_rank_adaptation_alpha,
...@@ -1246,11 +1228,12 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1246,11 +1228,12 @@ class LayerNormMLP(TransformerEngineBase):
z = activation(x, normalized_acts) z = activation(x, normalized_acts)
else: else:
activations = [] activations = []
x = jnp.split(x, num_activations, axis=-1) x = jnp.split(x, num_activations, axis=-2)
for idx, act_fn in enumerate(normalized_acts): for idx, act_fn in enumerate(normalized_acts):
x_i = _convert_to_activation_function(act_fn)(x[idx]) x_i = _convert_to_activation_function(act_fn)(x[idx])
activations.append(x_i) activations.append(x_i)
z = reduce(operator.mul, activations) z = reduce(operator.mul, activations)
z = jnp.squeeze(z, axis=-2)
z = z.astype(input_dtype) z = z.astype(input_dtype)
z = nn.Dropout( z = nn.Dropout(
...@@ -1264,7 +1247,12 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1264,7 +1247,12 @@ class LayerNormMLP(TransformerEngineBase):
# DenseGeneral 2 # DenseGeneral 2
out = dense( out = dense(
z, kernel_2, contracting_dims=(axis, contract_ind), quantizer_set=ffn2_quantizer_set z,
kernel_2,
contracting_dims=(axis, contract_ind),
input_axes=self.dot_2_input_axes,
kernel_axes=self.kernel_axes_2,
quantizer_set=ffn2_quantizer_set,
) )
if self.enable_low_rank_adaptation: if self.enable_low_rank_adaptation:
......
...@@ -33,10 +33,9 @@ def layernorm_dense( ...@@ -33,10 +33,9 @@ def layernorm_dense(
norm_type: str = "layernorm", norm_type: str = "layernorm",
zero_centered_gamma: bool = False, zero_centered_gamma: bool = False,
epsilon: float = 1e-6, epsilon: float = 1e-6,
# The logic axes of sharding constraint to the layernorm input.
layernorm_input_axes: Tuple[str, ...] = None, layernorm_input_axes: Tuple[str, ...] = None,
# The logic axes of sharding constraint to the dot input.
dot_input_axes: Tuple[str, ...] = None, dot_input_axes: Tuple[str, ...] = None,
kernel_axes: Tuple[str, ...] = None,
quantizer_set: QuantizerSet = noop_quantizer_set, quantizer_set: QuantizerSet = noop_quantizer_set,
) -> jnp.ndarray: ) -> jnp.ndarray:
"""Apply layer normalization followed by dense layer transformation. """Apply layer normalization followed by dense layer transformation.
...@@ -56,6 +55,7 @@ def layernorm_dense( ...@@ -56,6 +55,7 @@ def layernorm_dense(
epsilon: Small constant for numerical stability in normalization epsilon: Small constant for numerical stability in normalization
layernorm_input_axes: Logical axes for sharding the layernorm input layernorm_input_axes: Logical axes for sharding the layernorm input
dot_input_axes: Logical axes for sharding the matrix multiplication input dot_input_axes: Logical axes for sharding the matrix multiplication input
kernel_axes: Logical axes for sharding the weight matrix
quantizer_set: Set of quantizers for different tensor types quantizer_set: Set of quantizers for different tensor types
Returns: Returns:
...@@ -78,6 +78,7 @@ def layernorm_dense( ...@@ -78,6 +78,7 @@ def layernorm_dense(
epsilon, epsilon,
layernorm_input_axes, layernorm_input_axes,
dot_input_axes, dot_input_axes,
kernel_axes,
quantizer_set, quantizer_set,
) )
return output return output
...@@ -91,6 +92,7 @@ def layernorm_dense( ...@@ -91,6 +92,7 @@ def layernorm_dense(
7, 7,
8, 8,
9, 9,
10,
), ),
) )
def _layernorm_dense( def _layernorm_dense(
...@@ -104,6 +106,7 @@ def _layernorm_dense( ...@@ -104,6 +106,7 @@ def _layernorm_dense(
epsilon: float, epsilon: float,
layernorm_input_axes: Tuple[str, ...], layernorm_input_axes: Tuple[str, ...],
dot_input_axes: Tuple[str, ...], dot_input_axes: Tuple[str, ...],
kernel_axes: Tuple[str, ...],
quantizer_set, quantizer_set,
): ):
"""Internal implementation of layernorm_dense with custom VJP. """Internal implementation of layernorm_dense with custom VJP.
...@@ -139,6 +142,7 @@ def _layernorm_dense( ...@@ -139,6 +142,7 @@ def _layernorm_dense(
epsilon, epsilon,
layernorm_input_axes, layernorm_input_axes,
dot_input_axes, dot_input_axes,
kernel_axes,
quantizer_set, quantizer_set,
) )
return output return output
...@@ -155,6 +159,7 @@ def _layernorm_dense_fwd_rule( ...@@ -155,6 +159,7 @@ def _layernorm_dense_fwd_rule(
epsilon, epsilon,
layernorm_input_axes, layernorm_input_axes,
dot_input_axes, dot_input_axes,
kernel_axes,
quantizer_set, quantizer_set,
): ):
"""Forward pass rule for layernorm_dense. """Forward pass rule for layernorm_dense.
...@@ -171,7 +176,6 @@ def _layernorm_dense_fwd_rule( ...@@ -171,7 +176,6 @@ def _layernorm_dense_fwd_rule(
x_contracting_dims = (len(x.shape) - 1,) x_contracting_dims = (len(x.shape) - 1,)
k_contracting_dims = (0,) k_contracting_dims = (0,)
assert x.shape[-1] == kernel.shape[0] assert x.shape[-1] == kernel.shape[0]
assert len(kernel.shape) == 2 # Otherwise need to merge dims in quantize
x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes) x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes)
...@@ -184,11 +188,12 @@ def _layernorm_dense_fwd_rule( ...@@ -184,11 +188,12 @@ def _layernorm_dense_fwd_rule(
norm_type, norm_type,
quantizer_set.x, quantizer_set.x,
) )
casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_input_axes)
# Kernel in (hidden_in, hidden_out...) # Kernel in (hidden_in, hidden_out...)
casted_kernel = tex.quantize(kernel, quantizer_set.kernel) flatten_axis = 1 - len(kernel.shape)
casted_kernel = tex.quantize(kernel, flatten_axis=flatten_axis, quantizer=quantizer_set.kernel)
casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_input_axes) casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes)
# NN GEMM # NN GEMM
# (batch..., hidden_in) x (hidden_in, hidden_out...) # (batch..., hidden_in) x (hidden_in, hidden_out...)
...@@ -217,6 +222,7 @@ def _layernorm_dense_fwd_rule( ...@@ -217,6 +222,7 @@ def _layernorm_dense_fwd_rule(
k_contracting_dims, k_contracting_dims,
use_bias, use_bias,
quantizer_set, quantizer_set,
flatten_axis,
) )
return output, ctx return output, ctx
...@@ -228,6 +234,7 @@ def _layernorm_dense_bwd_rule( ...@@ -228,6 +234,7 @@ def _layernorm_dense_bwd_rule(
epsilon, epsilon,
layernorm_input_axes, layernorm_input_axes,
dot_input_axes, # pylint: disable=unused-argument dot_input_axes, # pylint: disable=unused-argument
kernel_axes,
ctx, ctx,
grad, grad,
): ):
...@@ -256,11 +263,12 @@ def _layernorm_dense_bwd_rule( ...@@ -256,11 +263,12 @@ def _layernorm_dense_bwd_rule(
k_contracting_dims_in_fwd, k_contracting_dims_in_fwd,
use_bias, use_bias,
quantizer_set, quantizer_set,
flatten_axis,
) = ctx ) = ctx
grad = with_sharding_constraint_by_logical_axes(grad, dot_input_axes) casted_grad, dbias = tex.quantize_dbias(
grad, is_dbias=use_bias, flatten_axis=flatten_axis, quantizer=quantizer_set.dgrad
casted_grad, dbias = tex.quantize_dbias(grad, is_dbias=use_bias, quantizer=quantizer_set.dgrad) )
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim
g_constracting_dim = tuple( g_constracting_dim = tuple(
...@@ -291,6 +299,8 @@ def _layernorm_dense_bwd_rule( ...@@ -291,6 +299,8 @@ def _layernorm_dense_bwd_rule(
(x_constracting_dim, g_constracting_dim), (x_constracting_dim, g_constracting_dim),
) )
wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes)
dx, dgamma, dbeta = tex.normalization_bwd( dx, dgamma, dbeta = tex.normalization_bwd(
dgrad, dgrad,
x, x,
......
...@@ -23,6 +23,7 @@ from jax.ad_checkpoint import checkpoint_name ...@@ -23,6 +23,7 @@ from jax.ad_checkpoint import checkpoint_name
from . import cpp_extensions as tex from . import cpp_extensions as tex
from .layernorm import canonicalize_norm_type from .layernorm import canonicalize_norm_type
from .quantize import with_sharding_constraint_by_logical_axes, QuantizerSet, noop_quantizer_set from .quantize import with_sharding_constraint_by_logical_axes, QuantizerSet, noop_quantizer_set
from .sharding import get_non_contracting_logical_axes
def layernorm_mlp( def layernorm_mlp(
...@@ -37,6 +38,8 @@ def layernorm_mlp( ...@@ -37,6 +38,8 @@ def layernorm_mlp(
norm_input_axes: Tuple[str, ...] = None, norm_input_axes: Tuple[str, ...] = None,
dot_1_input_axes: Tuple[str, ...] = None, dot_1_input_axes: Tuple[str, ...] = None,
dot_2_input_axes: Tuple[str, ...] = None, dot_2_input_axes: Tuple[str, ...] = None,
kernel_1_axes: Tuple[str, ...] = None,
kernel_2_axes: Tuple[str, ...] = None,
ffn1_ckpt_name: str = "ffn1", ffn1_ckpt_name: str = "ffn1",
ffn2_ckpt_name: str = "ffn2", ffn2_ckpt_name: str = "ffn2",
activation_type: Sequence[Union[str, Callable]] = ("gelu",), activation_type: Sequence[Union[str, Callable]] = ("gelu",),
...@@ -66,6 +69,8 @@ def layernorm_mlp( ...@@ -66,6 +69,8 @@ def layernorm_mlp(
norm_input_axes: Logical axes for sharding the layernorm input norm_input_axes: Logical axes for sharding the layernorm input
dot_1_input_axes: Logical axes for sharding the first matrix multiplication dot_1_input_axes: Logical axes for sharding the first matrix multiplication
dot_2_input_axes: Logical axes for sharding the second matrix multiplication dot_2_input_axes: Logical axes for sharding the second matrix multiplication
kernel_1_axes: Logical axes for sharding the first weight matrix
kernel_2_axes: Logical axes for sharding the second weight matrix
ffn1_ckpt_name: Name for checkpointing the first feed-forward network ffn1_ckpt_name: Name for checkpointing the first feed-forward network
ffn2_ckpt_name: Name for checkpointing the second feed-forward network ffn2_ckpt_name: Name for checkpointing the second feed-forward network
activation_type: Activation function(s) to apply after the first dense layer transformation activation_type: Activation function(s) to apply after the first dense layer transformation
...@@ -109,6 +114,8 @@ def layernorm_mlp( ...@@ -109,6 +114,8 @@ def layernorm_mlp(
norm_input_axes, norm_input_axes,
dot_1_input_axes, dot_1_input_axes,
dot_2_input_axes, dot_2_input_axes,
kernel_1_axes,
kernel_2_axes,
ffn1_ckpt_name, ffn1_ckpt_name,
ffn2_ckpt_name, ffn2_ckpt_name,
activation_type, activation_type,
...@@ -117,7 +124,7 @@ def layernorm_mlp( ...@@ -117,7 +124,7 @@ def layernorm_mlp(
return output return output
@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15)) @partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17))
def _layernorm_mlp( def _layernorm_mlp(
x: jnp.ndarray, x: jnp.ndarray,
gamma: jnp.ndarray, gamma: jnp.ndarray,
...@@ -132,6 +139,8 @@ def _layernorm_mlp( ...@@ -132,6 +139,8 @@ def _layernorm_mlp(
norm_input_axes: Tuple[str, ...], norm_input_axes: Tuple[str, ...],
dot_1_input_axes: Tuple[str, ...], dot_1_input_axes: Tuple[str, ...],
dot_2_input_axes: Tuple[str, ...], dot_2_input_axes: Tuple[str, ...],
kernel_1_axes: Tuple[str, ...],
kernel_2_axes: Tuple[str, ...],
ffn1_ckpt_name: str, ffn1_ckpt_name: str,
ffn2_ckpt_name: str, ffn2_ckpt_name: str,
activation_type: Sequence[Union[str, Callable]], activation_type: Sequence[Union[str, Callable]],
...@@ -179,6 +188,8 @@ def _layernorm_mlp( ...@@ -179,6 +188,8 @@ def _layernorm_mlp(
norm_input_axes, norm_input_axes,
dot_1_input_axes, dot_1_input_axes,
dot_2_input_axes, dot_2_input_axes,
kernel_1_axes,
kernel_2_axes,
ffn1_ckpt_name, ffn1_ckpt_name,
ffn2_ckpt_name, ffn2_ckpt_name,
activation_type, activation_type,
...@@ -201,6 +212,8 @@ def _layernorm_mlp_fwd_rule( ...@@ -201,6 +212,8 @@ def _layernorm_mlp_fwd_rule(
norm_input_axes, norm_input_axes,
dot_1_input_axes, dot_1_input_axes,
dot_2_input_axes, dot_2_input_axes,
kernel_1_axes,
kernel_2_axes,
ffn1_ckpt_name, ffn1_ckpt_name,
ffn2_ckpt_name, ffn2_ckpt_name,
activation_type, activation_type,
...@@ -220,20 +233,21 @@ def _layernorm_mlp_fwd_rule( ...@@ -220,20 +233,21 @@ def _layernorm_mlp_fwd_rule(
Returns: Returns:
Tuple of (output, context) for automatic differentiation Tuple of (output, context) for automatic differentiation
""" """
del kernel_2_axes
ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets
# x should be in shape of (batch..., hidden) # x should be in shape of (batch..., hidden)
# Kernel_1 should be in shape of (hidden_in, activation_len * intermediate) # Kernel_1 should be in shape of (hidden_in, activation_len, intermediate)
# Kernel_2 should be in shape of (intermediate, hidden_in) # Kernel_2 should be in shape of (intermediate, hidden_in)
assert len(kernel_1.shape) == 2 assert len(kernel_1.shape) == 3
assert len(kernel_2.shape) == 2 assert len(kernel_2.shape) == 2
assert kernel_1.shape[1] == kernel_2.shape[0] * len(activation_type) assert kernel_1.shape[-2] == len(activation_type)
x_contracting_dims = (len(x.shape) - 1,) x_contracting_dims = (len(x.shape) - 1,)
k_contracting_dims = (0,) k_contracting_dims = (0,)
assert x.shape[x_contracting_dims[0]] == kernel_1.shape[k_contracting_dims[0]] assert x.shape[x_contracting_dims[0]] == kernel_1.shape[k_contracting_dims[0]]
assert kernel_1.shape[1] == len(activation_type) * kernel_2.shape[0]
use_bias_1 = bias_1 is not None use_bias_1 = bias_1 is not None
use_bias_2 = bias_1 is not None use_bias_2 = bias_1 is not None
...@@ -249,11 +263,10 @@ def _layernorm_mlp_fwd_rule( ...@@ -249,11 +263,10 @@ def _layernorm_mlp_fwd_rule(
norm_type, norm_type,
quantizer=ffn1_quantizer_set.x, quantizer=ffn1_quantizer_set.x,
) )
casted_kernel_1 = tex.quantize(kernel_1, quantizer=ffn1_quantizer_set.kernel)
casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_1_input_axes) casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_1_input_axes)
casted_kernel_1 = tex.quantize(kernel_1, flatten_axis=-2, quantizer=ffn1_quantizer_set.kernel)
# NN GEMM # NN GEMM
# (batch..., hidden_in) x (hidden_in, hidden_out) # (batch..., hidden_in) x (hidden_in, hidden_out)
dot_1_output = tex.gemm( dot_1_output = tex.gemm(
...@@ -261,6 +274,13 @@ def _layernorm_mlp_fwd_rule( ...@@ -261,6 +274,13 @@ def _layernorm_mlp_fwd_rule(
casted_kernel_1.get_colwise_tensor(), casted_kernel_1.get_colwise_tensor(),
(x_contracting_dims, k_contracting_dims), (x_contracting_dims, k_contracting_dims),
) )
dot_1_output_axes = (
*get_non_contracting_logical_axes(x.ndim, dot_1_input_axes, x_contracting_dims),
*get_non_contracting_logical_axes(kernel_1.ndim, kernel_1_axes, k_contracting_dims),
)
dot_1_output = with_sharding_constraint_by_logical_axes(dot_1_output, dot_1_output_axes)
if use_bias_1: if use_bias_1:
bias_1_shape = bias_1.shape bias_1_shape = bias_1.shape
bias_1_new_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1_shape bias_1_new_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1_shape
...@@ -283,6 +303,12 @@ def _layernorm_mlp_fwd_rule( ...@@ -283,6 +303,12 @@ def _layernorm_mlp_fwd_rule(
(x_contracting_dims, k_contracting_dims), (x_contracting_dims, k_contracting_dims),
) )
dot_2_output_axes = (
*get_non_contracting_logical_axes(x.ndim, dot_2_input_axes, x_contracting_dims),
*get_non_contracting_logical_axes(kernel_2.ndim, None, k_contracting_dims),
)
dot_2_output = with_sharding_constraint_by_logical_axes(dot_2_output, dot_2_output_axes)
if use_bias_2: if use_bias_2:
bias_2_shape = bias_2.shape bias_2_shape = bias_2.shape
bias_2_new_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2_shape bias_2_new_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2_shape
...@@ -320,8 +346,10 @@ def _layernorm_mlp_bwd_rule( ...@@ -320,8 +346,10 @@ def _layernorm_mlp_bwd_rule(
norm_input_axes, norm_input_axes,
dot_1_input_axes, dot_1_input_axes,
dot_2_input_axes, dot_2_input_axes,
ffn1_ckpt_name, # pylint: disable=unused-argument kernel_1_axes,
ffn2_ckpt_name, # pylint: disable=unused-argument kernel_2_axes,
ffn1_ckpt_name,
ffn2_ckpt_name,
activation_type, activation_type,
ctx, ctx,
grad, grad,
...@@ -339,6 +367,7 @@ def _layernorm_mlp_bwd_rule( ...@@ -339,6 +367,7 @@ def _layernorm_mlp_bwd_rule(
Returns: Returns:
Tuple of gradients for all input parameters Tuple of gradients for all input parameters
""" """
del norm_input_axes, ffn1_ckpt_name, ffn2_ckpt_name
( (
x, x,
mu, mu,
...@@ -369,11 +398,11 @@ def _layernorm_mlp_bwd_rule( ...@@ -369,11 +398,11 @@ def _layernorm_mlp_bwd_rule(
) )
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim
g_constracting_dim_2 = tuple( g_contracting_dims_2 = tuple(
range(grad.ndim - len(kernel_2_shape) + len(k_contracting_dims_in_fwd), grad.ndim) range(grad.ndim - len(kernel_2_shape) + len(k_contracting_dims_in_fwd), grad.ndim)
) )
# k_non_contracting_dims # k_non_contracting_dims
k_constracting_dim_2 = tuple( k_contracting_dims_2 = tuple(
dim for dim in range(len(kernel_2_shape)) if dim not in k_contracting_dims_in_fwd dim for dim in range(len(kernel_2_shape)) if dim not in k_contracting_dims_in_fwd
) )
...@@ -382,12 +411,12 @@ def _layernorm_mlp_bwd_rule( ...@@ -382,12 +411,12 @@ def _layernorm_mlp_bwd_rule(
dgrad_2 = tex.gemm( dgrad_2 = tex.gemm(
casted_grad.get_rowwise_tensor(), casted_grad.get_rowwise_tensor(),
rowwise_casted_kernel_2, rowwise_casted_kernel_2,
(g_constracting_dim_2, k_constracting_dim_2), (g_contracting_dims_2, k_contracting_dims_2),
) )
dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes) dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes)
x_constracting_dim = g_constracting_dim = tuple( x_contracting_dims = g_contracting_dims = tuple(
range(0, len(x.shape) - len(x_contracting_dims_in_fwd)) range(0, len(x.shape) - len(x_contracting_dims_in_fwd))
) )
...@@ -396,8 +425,9 @@ def _layernorm_mlp_bwd_rule( ...@@ -396,8 +425,9 @@ def _layernorm_mlp_bwd_rule(
wgrad_2 = tex.gemm( wgrad_2 = tex.gemm(
colwise_casted_act_out, colwise_casted_act_out,
casted_grad.get_colwise_tensor(), casted_grad.get_colwise_tensor(),
(x_constracting_dim, g_constracting_dim), (x_contracting_dims, g_contracting_dims),
) )
wgrad_2 = with_sharding_constraint_by_logical_axes(wgrad_2, kernel_2_axes)
casted_dact_out, dbias_1 = tex.quantize_dact_dbias( casted_dact_out, dbias_1 = tex.quantize_dact_dbias(
dgrad_2, dgrad_2,
...@@ -408,11 +438,12 @@ def _layernorm_mlp_bwd_rule( ...@@ -408,11 +438,12 @@ def _layernorm_mlp_bwd_rule(
) )
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim
g_constracting_dim_1 = tuple( dact_out_ndim = casted_dact_out.get_rowwise_tensor().data.ndim
range(dgrad_2.ndim - len(kernel_1_shape) + len(k_contracting_dims_in_fwd), dgrad_2.ndim) g_contracting_dims_1 = tuple(
range(dact_out_ndim - len(kernel_1_shape) + len(k_contracting_dims_in_fwd), dact_out_ndim)
) )
# k_non_contracting_dims # k_non_contracting_dims
k_constracting_dim_1 = tuple( k_contracting_dims_1 = tuple(
dim for dim in range(len(kernel_1_shape)) if dim not in k_contracting_dims_in_fwd dim for dim in range(len(kernel_1_shape)) if dim not in k_contracting_dims_in_fwd
) )
...@@ -420,19 +451,21 @@ def _layernorm_mlp_bwd_rule( ...@@ -420,19 +451,21 @@ def _layernorm_mlp_bwd_rule(
dgrad_1 = tex.gemm( dgrad_1 = tex.gemm(
casted_dact_out.get_rowwise_tensor(), casted_dact_out.get_rowwise_tensor(),
rowwise_casted_kernel_1, rowwise_casted_kernel_1,
(g_constracting_dim_1, k_constracting_dim_1), (g_contracting_dims_1, k_contracting_dims_1),
) )
dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, norm_input_axes) dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, dot_1_input_axes)
# TN GEMM # TN GEMM
# (hidden, batch...) x (hidden, batch...) # (hidden, batch...) x (hidden, batch...)
wgrad_1 = tex.gemm( wgrad_1 = tex.gemm(
colwise_casted_ln_out, colwise_casted_ln_out,
casted_dact_out.get_colwise_tensor(), casted_dact_out.get_colwise_tensor(),
(x_constracting_dim, g_constracting_dim), (x_contracting_dims, g_contracting_dims),
) )
wgrad_1 = with_sharding_constraint_by_logical_axes(wgrad_1, kernel_1_axes)
dx, dgamma, dbeta = tex.normalization_bwd( dx, dgamma, dbeta = tex.normalization_bwd(
dgrad_1, dgrad_1,
x, x,
......
...@@ -57,18 +57,27 @@ class Dequantizer: ...@@ -57,18 +57,27 @@ class Dequantizer:
data = scaled_tensor.data.astype(jnp.float32) data = scaled_tensor.data.astype(jnp.float32)
data_shape = data.shape data_shape = data.shape
scale = scaled_tensor.scale_inv.view(jnp.uint8).astype(jnp.float32) scale = scaled_tensor.scale_inv.view(jnp.uint8).astype(jnp.float32)
flatten_axis = scaled_tensor.flatten_axis
flatten_axis = len(data_shape) + flatten_axis if flatten_axis < 0 else flatten_axis
assert (
0 < flatten_axis < len(data_shape)
), f"flatten_axis {flatten_axis} is out of bounds for shape {data_shape}"
scale_shape = scaled_tensor.scaling_mode.get_scale_shape( scale_shape = scaled_tensor.scaling_mode.get_scale_shape(
scaled_tensor.data.shape, scaled_tensor.is_colwise, is_padded=False data_shape, scaled_tensor.is_colwise, is_padded=False, flatten_axis=flatten_axis
) )
scale = jax.lax.slice(scale, [0] * len(scale_shape), scale_shape) # slice out the padding scale = jax.lax.slice(scale, [0] * len(scale_shape), scale_shape) # slice out the padding
data = data.reshape( data = data.reshape(
*data_shape[:-2], *data_shape[: flatten_axis - 1],
scale_shape[-2], scale_shape[flatten_axis - 1],
int(data_shape[-2] / scale_shape[-2]), int(data_shape[flatten_axis - 1] / scale_shape[flatten_axis - 1]),
*data_shape[flatten_axis:-1],
scale_shape[-1], scale_shape[-1],
int(data_shape[-1] / scale_shape[-1]), int(data_shape[-1] / scale_shape[-1]),
) )
scale = jnp.expand_dims(scale, axis=(-1, -3))
# E8M0 does not have a bit for sign. So 0 - 127 represent negative numbers.
scale = jnp.expand_dims(scale, axis=(flatten_axis + 2 - 2, -1))
# E8M0 does not have a bit for sign. So 0 - 127 represent negative numbers. # E8M0 does not have a bit for sign. So 0 - 127 represent negative numbers.
return jnp.asarray(data * jnp.power(2, scale - 127), scaled_tensor.dq_dtype).reshape( return jnp.asarray(data * jnp.power(2, scale - 127), scaled_tensor.dq_dtype).reshape(
data_shape data_shape
......
...@@ -14,7 +14,7 @@ from typing import Union, Optional ...@@ -14,7 +14,7 @@ from typing import Union, Optional
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class from jax.tree_util import register_pytree_node_class
from transformer_engine_jax import QuantizeAxis from transformer_engine_jax import QuantizeLayout
from .scaling_modes import ScalingMode from .scaling_modes import ScalingMode
from .tensor import ScaledTensor1x, ScaledTensor2x, ScaledTensorFactory from .tensor import ScaledTensor1x, ScaledTensor2x, ScaledTensorFactory
...@@ -24,7 +24,7 @@ from .helper import ( ...@@ -24,7 +24,7 @@ from .helper import (
) )
__all__ = [ __all__ = [
"QuantizeAxis", "QuantizeLayout",
"Quantizer", "Quantizer",
"QuantizerSet", "QuantizerSet",
"DelayedScaleQuantizer", "DelayedScaleQuantizer",
...@@ -45,12 +45,12 @@ class Quantizer(ABC): ...@@ -45,12 +45,12 @@ class Quantizer(ABC):
Attributes: Attributes:
q_dtype: The data type for quantized values q_dtype: The data type for quantized values
scaling_mode: The scaling mode to use for quantization scaling_mode: The scaling mode to use for quantization
q_axis: The quantization axis (row-wise, column-wise, or both) q_layout: The quantization axis (row-wise, column-wise, or both)
""" """
q_dtype: jnp.dtype q_dtype: jnp.dtype
scaling_mode: ScalingMode scaling_mode: ScalingMode
q_axis: QuantizeAxis q_layout: QuantizeLayout
def tree_flatten(self): def tree_flatten(self):
"""Flatten the quantizer for JAX tree operations. """Flatten the quantizer for JAX tree operations.
...@@ -59,7 +59,7 @@ class Quantizer(ABC): ...@@ -59,7 +59,7 @@ class Quantizer(ABC):
Tuple of (children, aux_data) for tree operations Tuple of (children, aux_data) for tree operations
""" """
children = () children = ()
aux_data = (self.q_dtype, self.scaling_mode, self.q_axis) aux_data = (self.q_dtype, self.scaling_mode, self.q_layout)
return (children, aux_data) return (children, aux_data)
@classmethod @classmethod
...@@ -85,30 +85,31 @@ class Quantizer(ABC): ...@@ -85,30 +85,31 @@ class Quantizer(ABC):
Returns: Returns:
True if using both row-wise and column-wise quantization True if using both row-wise and column-wise quantization
""" """
return self.q_axis == QuantizeAxis.ROWWISE_COLWISE return self.q_layout == QuantizeLayout.ROWWISE_COLWISE
@abstractmethod @abstractmethod
def get_layout(self) -> str: def get_data_layout(self) -> str:
"""Get the data layout. """Get the data data_layout.
Returns: Returns:
Data layout in string format Data data_layout in string format
""" """
@abstractmethod @abstractmethod
def _quantize_func(self, x, is_colwise=False, dq_dtype=None) -> ScaledTensor1x: def _quantize_func(self, x, is_colwise=False, dq_dtype=None, flatten_axis=-1) -> ScaledTensor1x:
"""Core quantization function to be implemented by subclasses. """Core quantization function to be implemented by subclasses.
Args: Args:
x: Input tensor to quantize x: Input tensor to quantize
is_colwise: Whether to use column-wise quantization is_colwise: Whether to use column-wise quantization
dq_dtype: Data type for dequantized values, default is x.dtype dq_dtype: Data type for dequantized values, default is x.dtype
flatten_axis: The quantization axis for the tensor
Returns: Returns:
A ScaledTensor1x containing the quantized data A ScaledTensor1x containing the quantized data
""" """
def quantize(self, x, is_rowwise=False, is_colwise=False, dq_dtype=None): def quantize(self, x, is_rowwise=False, is_colwise=False, dq_dtype=None, flatten_axis=-1):
"""Quantize a tensor using the internal _quantize_func(). """Quantize a tensor using the internal _quantize_func().
Args: Args:
...@@ -116,21 +117,26 @@ class Quantizer(ABC): ...@@ -116,21 +117,26 @@ class Quantizer(ABC):
is_rowwise: Whether to use row-wise quantization is_rowwise: Whether to use row-wise quantization
is_colwise: Whether to use column-wise quantization is_colwise: Whether to use column-wise quantization
dq_dtype: Data type for dequantized values dq_dtype: Data type for dequantized values
flatten_axis: The quantization axis for the tensor
Returns: Returns:
A ScaledTensor1x or ScaledTensor2x containing the quantized data A ScaledTensor1x or ScaledTensor2x containing the quantized data
""" """
if (is_rowwise and is_colwise) or self.is_2x2x(): if (is_rowwise and is_colwise) or self.is_2x2x():
rowwise_tensor = self._quantize_func(x, dq_dtype=dq_dtype) rowwise_tensor = self._quantize_func(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis)
colwise_tensor = self._quantize_func(x, is_colwise=True, dq_dtype=dq_dtype) colwise_tensor = self._quantize_func(
x, is_colwise=True, dq_dtype=dq_dtype, flatten_axis=flatten_axis
)
return ScaledTensor2x(rowwise_tensor, colwise_tensor) return ScaledTensor2x(rowwise_tensor, colwise_tensor)
if is_colwise: if is_colwise:
return self._quantize_func(x, is_colwise=True, dq_dtype=dq_dtype) return self._quantize_func(
x, is_colwise=True, dq_dtype=dq_dtype, flatten_axis=flatten_axis
)
return self._quantize_func(x, dq_dtype=dq_dtype) return self._quantize_func(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis)
def get_scale_shapes(self, data_shape, is_padded=True): def get_scale_shapes(self, data_shape, is_padded=True, flatten_axis=-1):
"""Get shapes for scale tensors. """Get shapes for scale tensors.
Args: Args:
...@@ -140,7 +146,7 @@ class Quantizer(ABC): ...@@ -140,7 +146,7 @@ class Quantizer(ABC):
Returns: Returns:
Tuple of (rowwise_scale_shape, colwise_scale_shape) Tuple of (rowwise_scale_shape, colwise_scale_shape)
""" """
return self.scaling_mode.get_scale_shape_2x(data_shape, is_padded) return self.scaling_mode.get_scale_shape_2x(data_shape, is_padded, flatten_axis)
def get_scale_dtype(self): def get_scale_dtype(self):
"""Get the data type for scale tensors. """Get the data type for scale tensors.
...@@ -161,13 +167,13 @@ class DelayedScaleQuantizer(Quantizer): ...@@ -161,13 +167,13 @@ class DelayedScaleQuantizer(Quantizer):
Attributes: Attributes:
scaling_mode: Set to NVTE_DELAYED_TENSOR_SCALING scaling_mode: Set to NVTE_DELAYED_TENSOR_SCALING
q_axis: Quantization axis (default: ROWWISE_COLWISE) q_layout: Quantization axis (default: ROWWISE_COLWISE)
scale: Current scaling factor scale: Current scaling factor
amax_history: History of maximum absolute values amax_history: History of maximum absolute values
""" """
scaling_mode: ScalingMode = ScalingMode.NVTE_DELAYED_TENSOR_SCALING scaling_mode: ScalingMode = ScalingMode.NVTE_DELAYED_TENSOR_SCALING
q_axis: QuantizeAxis = QuantizeAxis.ROWWISE_COLWISE q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE
scale: jnp.ndarray = field(default_factory=lambda: jnp.ones((1,), jnp.float32)) scale: jnp.ndarray = field(default_factory=lambda: jnp.ones((1,), jnp.float32))
amax_history: jnp.ndarray = field( amax_history: jnp.ndarray = field(
...@@ -181,35 +187,37 @@ class DelayedScaleQuantizer(Quantizer): ...@@ -181,35 +187,37 @@ class DelayedScaleQuantizer(Quantizer):
Tuple of (children, aux_data) for tree operations Tuple of (children, aux_data) for tree operations
""" """
children = (self.scale, self.amax_history) children = (self.scale, self.amax_history)
aux_data = (self.q_dtype, self.scaling_mode, self.q_axis) aux_data = (self.q_dtype, self.scaling_mode, self.q_layout)
return (children, aux_data) return (children, aux_data)
def get_layout(self) -> str: def get_data_layout(self) -> str:
"""Get the data layout string. """Get the data data_layout string.
Returns: Returns:
Data layout in string format Data data_layout in string format
Raises: Raises:
ValueError: If quantization axis is invalid ValueError: If quantization axis is invalid
""" """
layout = "NT" data_layout = "NT"
if self.q_axis == QuantizeAxis.ROWWISE_COLWISE: if self.q_layout == QuantizeLayout.ROWWISE_COLWISE:
return layout return data_layout
if self.q_axis == QuantizeAxis.ROWWISE: if self.q_layout == QuantizeLayout.ROWWISE:
return layout[0] return data_layout[0]
if self.q_axis == QuantizeAxis.COLWISE: if self.q_layout == QuantizeLayout.COLWISE:
return layout[1] return data_layout[1]
raise ValueError(f"Invalid q_axis: {self.q_axis}") raise ValueError(f"Invalid q_layout: {self.q_layout}")
def _quantize_func(self, x: jnp.ndarray, is_colwise=False, dq_dtype=None) -> ScaledTensor1x: def _quantize_func(
self, x: jnp.ndarray, is_colwise=False, dq_dtype=None, flatten_axis=-1
) -> ScaledTensor1x:
"""Quantize function helper for delayed scaling FP8. """Quantize function helper for delayed scaling FP8.
Args: Args:
x: Input tensor to quantize x: Input tensor to quantize
is_colwise: Whether to use column-wise quantization is_colwise: Whether to use column-wise quantization
dq_dtype: Data type for dequantized values dq_dtype: Data type for dequantized values
flatten_axis: The quantization axis for the tensor
Returns: Returns:
A ScaledTensor1x containing the quantized data A ScaledTensor1x containing the quantized data
""" """
...@@ -232,9 +240,12 @@ class DelayedScaleQuantizer(Quantizer): ...@@ -232,9 +240,12 @@ class DelayedScaleQuantizer(Quantizer):
scale_inv=scale_inv, scale_inv=scale_inv,
scaling_mode=self.scaling_mode, scaling_mode=self.scaling_mode,
dq_dtype=dq_dtype, dq_dtype=dq_dtype,
flatten_axis=flatten_axis,
) )
def quantize(self, x, is_rowwise: bool = None, is_colwise: bool = None, dq_dtype=None): def quantize(
self, x, is_rowwise: bool = None, is_colwise: bool = None, dq_dtype=None, flatten_axis=-1
):
"""Quantize a tensor using the internal _quantize_func(). """Quantize a tensor using the internal _quantize_func().
Args: Args:
...@@ -242,32 +253,40 @@ class DelayedScaleQuantizer(Quantizer): ...@@ -242,32 +253,40 @@ class DelayedScaleQuantizer(Quantizer):
is_rowwise: Whether to use row-wise quantization is_rowwise: Whether to use row-wise quantization
is_colwise: Whether to use column-wise quantization is_colwise: Whether to use column-wise quantization
dq_dtype: Data type for dequantized values dq_dtype: Data type for dequantized values
flatten_axis: The quantization axis for the tensor
Returns: Returns:
A ScaledTensor1x or ScaledTensor2x containing the quantized data A ScaledTensor1x or ScaledTensor2x containing the quantized data
""" """
dq_dtype = dq_dtype if dq_dtype is not None else x.dtype dq_dtype = dq_dtype if dq_dtype is not None else x.dtype
if flatten_axis < 0:
flatten_axis += x.ndim
assert 0 < flatten_axis < x.ndim, "flatten_axis is out of bounds!"
is_rowwise = ( is_rowwise = (
is_rowwise is_rowwise
if is_rowwise is not None if is_rowwise is not None
else (self.q_axis == QuantizeAxis.ROWWISE or self.is_2x2x()) else (self.q_layout == QuantizeLayout.ROWWISE or self.is_2x2x())
) )
is_colwise = ( is_colwise = (
is_colwise is_colwise
if is_colwise is not None if is_colwise is not None
else (self.q_axis == QuantizeAxis.COLWISE or self.is_2x2x()) else (self.q_layout == QuantizeLayout.COLWISE or self.is_2x2x())
) )
rowwise_tensor = self._quantize_func(x, dq_dtype=dq_dtype) rowwise_tensor = self._quantize_func(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis)
colwise_tensor = None colwise_tensor = None
if is_colwise: if is_colwise:
colwise_tensor = ScaledTensorFactory.create_1x( colwise_tensor = ScaledTensorFactory.create_1x(
data=jnp.transpose(rowwise_tensor.data, (-1, *range(rowwise_tensor.data.ndim - 1))), data=jnp.transpose(
rowwise_tensor.data, (*range(flatten_axis, x.ndim), *range(flatten_axis))
),
scale_inv=rowwise_tensor.scale_inv, scale_inv=rowwise_tensor.scale_inv,
scaling_mode=self.scaling_mode, scaling_mode=self.scaling_mode,
dq_dtype=dq_dtype, dq_dtype=dq_dtype,
is_colwise=True, is_colwise=True,
layout="T", data_layout="T",
flatten_axis=flatten_axis,
) )
if is_colwise and is_rowwise: if is_colwise and is_rowwise:
return ScaledTensor2x(rowwise_tensor, colwise_tensor) return ScaledTensor2x(rowwise_tensor, colwise_tensor)
...@@ -353,46 +372,56 @@ class BlockScaleQuantizer(Quantizer): ...@@ -353,46 +372,56 @@ class BlockScaleQuantizer(Quantizer):
Attributes: Attributes:
scaling_mode: Set to NVTE_MXFP8_1D_SCALING scaling_mode: Set to NVTE_MXFP8_1D_SCALING
q_axis: Quantization axis (default: ROWWISE_COLWISE) q_layout: Quantization axis (default: ROWWISE_COLWISE)
""" """
scaling_mode: ScalingMode = ScalingMode.NVTE_MXFP8_1D_SCALING scaling_mode: ScalingMode = ScalingMode.NVTE_MXFP8_1D_SCALING
q_axis: QuantizeAxis = QuantizeAxis.ROWWISE_COLWISE q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE
def get_layout(self) -> str: def get_data_layout(self) -> str:
"""Get the data layout string. """Get the data data_layout string.
Returns: Returns:
Data layout in string format Data data_layout in string format
""" """
if self.is_2x2x(): if self.is_2x2x():
return "NN" return "NN"
return "N" return "N"
def _quantize_func(self, x, is_colwise=False, dq_dtype=None) -> ScaledTensor1x: def _quantize_func(self, x, is_colwise=False, dq_dtype=None, flatten_axis=-1) -> ScaledTensor1x:
"""Quantize function helper for block scaling FP8. """Quantize function helper for block scaling FP8.
Args: Args:
x: Input tensor to quantize x: Input tensor to quantize
is_colwise: Whether to use column-wise quantization is_colwise: Whether to use column-wise quantization
dq_dtype: Data type for dequantized values dq_dtype: Data type for dequantized values
flatten_axis: The quantization axis for the tensor
Returns: Returns:
A ScaledTensor1x containing the quantized data A ScaledTensor1x containing the quantized data
""" """
# TODO(Phuong): use quantize_func from JAX # TODO(Phuong): use quantize_func from JAX
if flatten_axis < 0:
flatten_axis = x.ndim + flatten_axis
assert (
0 <= flatten_axis < x.ndim
), f"Invalid flatten_axis: {flatten_axis} for tensor of shape {x.shape}"
dq_dtype = dq_dtype if dq_dtype is not None else x.dtype dq_dtype = dq_dtype if dq_dtype is not None else x.dtype
x_shape = x.shape x_shape = x.shape
scale_shape = self.scaling_mode.get_scale_shape(x_shape, is_colwise, is_padded=False) scale_shape = self.scaling_mode.get_scale_shape(
x_shape, is_colwise, is_padded=False, flatten_axis=flatten_axis
)
scale_dtype = self.scaling_mode.get_scale_dtype() scale_dtype = self.scaling_mode.get_scale_dtype()
x = x.reshape( x = x.reshape(
*x_shape[:-2], *x_shape[: flatten_axis - 1],
scale_shape[-2], scale_shape[flatten_axis - 1],
int(x_shape[-2] / scale_shape[-2]), int(x_shape[flatten_axis - 1] / scale_shape[flatten_axis - 1]),
*x_shape[flatten_axis:-1],
scale_shape[-1], scale_shape[-1],
int(x_shape[-1] / scale_shape[-1]), int(x_shape[-1] / scale_shape[-1]),
) )
amax = jnp.max(jnp.abs(x), axis=(-3, -1), keepdims=True) amax = jnp.max(jnp.abs(x), axis=(flatten_axis + 2 - 2, -1), keepdims=True)
MAX = jnp.finfo(self.q_dtype).max.astype(jnp.float32) MAX = jnp.finfo(self.q_dtype).max.astype(jnp.float32)
scales = amax.astype(jnp.float32) / MAX scales = amax.astype(jnp.float32) / MAX
...@@ -409,6 +438,7 @@ class BlockScaleQuantizer(Quantizer): ...@@ -409,6 +438,7 @@ class BlockScaleQuantizer(Quantizer):
self.scaling_mode, self.scaling_mode,
is_colwise=is_colwise, is_colwise=is_colwise,
dq_dtype=dq_dtype, dq_dtype=dq_dtype,
flatten_axis=flatten_axis,
) )
def _cast_to_e8m0_with_rounding_up(self, scales): def _cast_to_e8m0_with_rounding_up(self, scales):
...@@ -509,7 +539,7 @@ class QuantizerFactory: ...@@ -509,7 +539,7 @@ class QuantizerFactory:
n_quantizers: int = 1, n_quantizers: int = 1,
scaling_mode: ScalingMode = None, scaling_mode: ScalingMode = None,
q_dtype: jnp.dtype = None, q_dtype: jnp.dtype = None,
q_axis: QuantizeAxis = None, q_layout: QuantizeLayout = None,
**kwargs, **kwargs,
) -> Quantizer: ) -> Quantizer:
"""Create one or more quantizers with specified parameters. """Create one or more quantizers with specified parameters.
...@@ -518,7 +548,8 @@ class QuantizerFactory: ...@@ -518,7 +548,8 @@ class QuantizerFactory:
n_quantizers: Number of quantizers to create n_quantizers: Number of quantizers to create
scaling_mode: Scaling mode to use scaling_mode: Scaling mode to use
q_dtype: Quantization data type q_dtype: Quantization data type
q_axis: Quantization axis q_layout: Quantization axis
flatten_axis: The quantization axis for the tensor
**kwargs: Additional arguments for quantizer initialization **kwargs: Additional arguments for quantizer initialization
Returns: Returns:
...@@ -534,7 +565,7 @@ class QuantizerFactory: ...@@ -534,7 +565,7 @@ class QuantizerFactory:
quantizer_type = QuantizerFactory.quantizer_type_map.get(scaling_mode) quantizer_type = QuantizerFactory.quantizer_type_map.get(scaling_mode)
quantizers.append( quantizers.append(
quantizer_type( quantizer_type(
q_dtype=q_dtype, scaling_mode=scaling_mode, q_axis=q_axis, **kwargs q_dtype=q_dtype, scaling_mode=scaling_mode, q_layout=q_layout, **kwargs
) )
) )
return quantizers[0] if len(quantizers) == 1 else tuple(quantizers) return quantizers[0] if len(quantizers) == 1 else tuple(quantizers)
...@@ -554,11 +585,11 @@ class QuantizerFactory: ...@@ -554,11 +585,11 @@ class QuantizerFactory:
A QuantizerSet instance A QuantizerSet instance
""" """
if is_2x2x: if is_2x2x:
q_axis_x = q_axis_kernel = q_axis_dgrad = QuantizeAxis.ROWWISE_COLWISE q_layout_x = q_layout_kernel = q_layout_dgrad = QuantizeLayout.ROWWISE_COLWISE
else: else:
q_axis_x = QuantizeAxis.ROWWISE q_layout_x = QuantizeLayout.ROWWISE
q_axis_kernel = QuantizeAxis.COLWISE q_layout_kernel = QuantizeLayout.COLWISE
q_axis_dgrad = None q_layout_dgrad = None
if "quantize_meta_set" in kwargs: if "quantize_meta_set" in kwargs:
quantize_meta_set = kwargs.get("quantize_meta_set") quantize_meta_set = kwargs.get("quantize_meta_set")
...@@ -577,9 +608,11 @@ class QuantizerFactory: ...@@ -577,9 +608,11 @@ class QuantizerFactory:
else: else:
args_x = args_kernel = args_grad = {} args_x = args_kernel = args_grad = {}
q_x = QuantizerFactory.create(1, scaling_mode, fwd_dtype, q_axis_x, **args_x) q_x = QuantizerFactory.create(1, scaling_mode, fwd_dtype, q_layout_x, **args_x)
q_kernel = QuantizerFactory.create(1, scaling_mode, fwd_dtype, q_axis_kernel, **args_kernel) q_kernel = QuantizerFactory.create(
q_dgrad = QuantizerFactory.create(1, scaling_mode, bwd_dtype, q_axis_dgrad, **args_grad) 1, scaling_mode, fwd_dtype, q_layout_kernel, **args_kernel
)
q_dgrad = QuantizerFactory.create(1, scaling_mode, bwd_dtype, q_layout_dgrad, **args_grad)
return QuantizerSet(x=q_x, kernel=q_kernel, dgrad=q_dgrad) return QuantizerSet(x=q_x, kernel=q_kernel, dgrad=q_dgrad)
@staticmethod @staticmethod
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment