Unverified Commit ff884e20 authored by Phuong Nguyen, committed by GitHub

[JAX] Flatten_axis for quantization and Sharding propagation fixes (#1644)



* rename QuantizeAxis to QuantizeLayout, get_layout to get_data_layout, q_axis to q_layout

* add flatten_axis option for quantization (see the usage sketch after this list)

* added gated act to test encoder

* sharding constraint fixes

* fix padding when the flattened first dim needs to be padded

* update test sizes so that padding is tested

* rm output sharding as it can be done in the flax module

* sharding scale_inv for mxfp8
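
A minimal usage sketch of the renamed quantizer API and the new flatten_axis option, mirroring the test_qdq case below (the shape, dtype, and call pattern here are illustrative assumptions, not the full public API):

    from transformer_engine.jax.quantize import QuantizerFactory, QuantizeLayout, ScalingMode
    import jax
    import jax.numpy as jnp

    quantizer = QuantizerFactory.create(
        scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING,
        q_dtype=jnp.float8_e4m3fn,
        q_layout=QuantizeLayout.ROWWISE_COLWISE,  # was q_axis=QuantizeAxis.ROWWISE_COLWISE
    )
    x = jax.random.uniform(jax.random.PRNGKey(0), (32, 2, 128), jnp.bfloat16)
    # flatten_axis=-2: dims before -2 form the rows and the last two dims form the
    # columns of the 2D view used for quantization (and for scale_inv padding).
    scaled = quantizer.quantize(x, flatten_axis=-2)
    y = scaled.dequantize()  # approximately x, up to FP8 rounding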

---------
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
parent be1f647c
......@@ -57,13 +57,14 @@ class Net(nn.Module):
self_attn_mask_type="padding",
enable_relative_embedding=False,
enable_sequence_parallel=self.enable_seq_paral,
mlp_activations=("gelu", "linear"),
)
x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)
x = x.reshape(x.shape[0], -1)
if self.enable_seq_paral:
# Trigger all-gather to collect a complete tensor alone seqence on each device.
# Trigger all-gather to collect a complete tensor along the sequence dimension on each device.
x = jax.lax.with_sharding_constraint(
x, jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None)
)
......@@ -459,7 +460,7 @@ class TestEncoder(unittest.TestCase):
def test_te_bf16(self):
"""Test Transformer Engine with BF16"""
actual = train_and_evaluate(self.args)
assert actual[0] < 0.50 and actual[1] > 0.76
assert actual[0] < 0.455 and actual[1] > 0.785
@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_delayed_scaling_fp8(self):
......@@ -467,7 +468,7 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True
self.args.fp8_recipe = "DelayedScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.50 and actual[1] > 0.76
assert actual[0] < 0.455 and actual[1] > 0.785
@unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
def test_te_mxfp8(self):
......@@ -475,14 +476,14 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True
self.args.fp8_recipe = "MXFP8BlockScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.50 and actual[1] > 0.76
assert actual[0] < 0.455 and actual[1] > 0.785
@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16_with_sp(self):
"""Test Transformer Engine with BF16 + SP"""
self.args.enable_sp = True
actual = train_and_evaluate(self.args)
assert actual[0] < 0.50 and actual[1] > 0.76
assert actual[0] < 0.455 and actual[1] > 0.785
@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_delayed_scaling_fp8_with_sp(self):
......@@ -491,7 +492,7 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True
self.args.fp8_recipe = "DelayedScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.50 and actual[1] > 0.76
assert actual[0] < 0.455 and actual[1] > 0.785
@unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
def test_te_mxfp8_with_sp(self):
......@@ -500,7 +501,7 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True
self.args.fp8_recipe = "MXFP8BlockScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.50 and actual[1] > 0.76
assert actual[0] < 0.455 and actual[1] > 0.785
if __name__ == "__main__":
......
......@@ -29,7 +29,7 @@ from transformer_engine.jax.quantize import (
ScaledTensor,
ScalingMode,
QuantizerFactory,
QuantizeAxis,
QuantizeLayout,
)
from transformer_engine.jax.quantize import helper
from transformer_engine.jax.activation import activation
......@@ -82,8 +82,9 @@ def assert_bitwise_scaled_tensors(a: ScaledTensor, b: ScaledTensor):
def assert_dequantized_scaled_tensor(a: ScaledTensor, b: jnp.ndarray):
if isinstance(a, ScaledTensor1x):
if a.layout == "T":
b_transpose = jnp.transpose(b, (-1, *range(b.ndim - 1)))
if a.data_layout == "T":
flatten_axis = a.data.ndim - a.flatten_axis
b_transpose = jnp.transpose(b, (*range(flatten_axis, b.ndim), *range(flatten_axis)))
assert_allclose(a.dequantize(), b_transpose, dtype=a.data.dtype)
else:
assert_allclose(a.dequantize(), b, dtype=a.data.dtype)
......@@ -141,7 +142,8 @@ class TestActivation:
def test_act_grad(self, shape, activation_type):
key = jax.random.PRNGKey(0)
x = jax.random.uniform(key, shape, jnp.float32)
x = jnp.repeat(x, len(activation_type), axis=-1)
x = jnp.expand_dims(x, axis=-2)
x = jnp.repeat(x, len(activation_type), axis=-2)
value_n_grad_primitive_func = jit(
value_and_grad(self.primitive_func, (0,)), static_argnums=(1,)
......@@ -159,7 +161,8 @@ class TestActivation:
@pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])
def test_act_grad_with_delayed_scaling_fp8(self, random_inputs, activation_type, output_type):
x = random_inputs
x = jnp.repeat(x, len(activation_type), axis=-1)
x = jnp.expand_dims(x, axis=-2)
x = jnp.repeat(x, len(activation_type), axis=-2)
self.activation_type = activation_type
value_n_grad_primitive_func = jit(
......@@ -169,7 +172,7 @@ class TestActivation:
quantizer = QuantizerFactory.create(
scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING,
q_dtype=output_type,
q_axis=QuantizeAxis.ROWWISE,
q_layout=QuantizeLayout.ROWWISE,
)
prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, quantizer)
......@@ -182,19 +185,22 @@ class TestActivation:
@pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES)
@pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
@pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@pytest_parametrize_wrapper("q_axis", [QuantizeAxis.ROWWISE, QuantizeAxis.ROWWISE_COLWISE])
@pytest_parametrize_wrapper(
"q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
)
def test_act_forward_with_delayed_scaling_fp8(
self, random_inputs, activation_type, output_type, q_axis
self, random_inputs, activation_type, output_type, q_layout
):
x = random_inputs
x = jnp.repeat(x, len(activation_type), axis=-1)
x = jnp.expand_dims(x, axis=-2)
x = jnp.repeat(x, len(activation_type), axis=-2)
self.activation_type = activation_type
te_quantizer, jax_quantizer = QuantizerFactory.create(
n_quantizers=2,
scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING,
q_dtype=output_type,
q_axis=q_axis,
q_layout=q_layout,
)
te_output = tex.act_lu(x, activation_type, te_quantizer)
......@@ -203,19 +209,21 @@ class TestActivation:
assert_bitwise_scaled_tensors(te_output, jax_output)
@pytest.mark.skipif(not is_mxfp8_supported, reason=reason)
@pytest_parametrize_wrapper("shape", [(128, 128)])
@pytest_parametrize_wrapper("shape", [(2, 64, 1, 256)])
@pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
@pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@pytest_parametrize_wrapper("q_axis", [QuantizeAxis.ROWWISE, QuantizeAxis.ROWWISE_COLWISE])
@pytest_parametrize_wrapper(
"q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
)
def test_act_forward_with_block_scaling_fp8(
self, random_inputs, activation_type, output_type, q_axis
self, random_inputs, activation_type, output_type, q_layout
):
x = random_inputs
x = jnp.repeat(x, len(activation_type), axis=-1)
x = jnp.repeat(x, len(activation_type), axis=-2)
self.activation_type = activation_type
quantizer = QuantizerFactory.create(
scaling_mode=ScalingMode.NVTE_MXFP8_1D_SCALING, q_dtype=output_type, q_axis=q_axis
scaling_mode=ScalingMode.NVTE_MXFP8_1D_SCALING, q_dtype=output_type, q_layout=q_layout
)
output = tex.act_lu(x, activation_type, quantizer)
......@@ -324,9 +332,11 @@ class TestNorm:
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
# No Norm FWD E5M2 in TE backend
@pytest_parametrize_wrapper("out_dtype", [jnp.float8_e4m3fn])
@pytest_parametrize_wrapper("q_axis", [QuantizeAxis.ROWWISE, QuantizeAxis.ROWWISE_COLWISE])
@pytest_parametrize_wrapper(
"q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
)
def test_norm_grad_with_delayed_scaling_fp8(
self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, out_dtype, q_axis
self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, out_dtype, q_layout
):
"""
Test transformer_engine.jax.layernorm.layernorm
......@@ -335,7 +345,9 @@ class TestNorm:
pytest.skip("RMSNorm and zero_centered_gamma is not supported!")
quantizer = QuantizerFactory.create(
scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, q_dtype=out_dtype, q_axis=q_axis
scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING,
q_dtype=out_dtype,
q_layout=q_layout,
)
self._test_norm_grad(
n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, quantizer
......@@ -351,7 +363,7 @@ class TestNorm:
inp_dtype,
out_dtype,
scaling_mode,
q_axis,
q_layout,
):
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 3)
......@@ -363,7 +375,7 @@ class TestNorm:
gamma = jnp.asarray(gamma, inp_dtype)
quantizer, ref_quantizer = QuantizerFactory.create(
n_quantizers=2, scaling_mode=scaling_mode, q_dtype=out_dtype, q_axis=q_axis
n_quantizers=2, scaling_mode=scaling_mode, q_dtype=out_dtype, q_layout=q_layout
)
if norm_type == "layernorm":
beta = jax.random.uniform(subkeys[2], (hidden,), jnp.float32, -1, 1)
......@@ -391,9 +403,11 @@ class TestNorm:
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
# No Norm FWD E5M2 in TE backend
@pytest_parametrize_wrapper("out_dtype", [jnp.float8_e4m3fn])
@pytest_parametrize_wrapper("q_axis", [QuantizeAxis.ROWWISE, QuantizeAxis.ROWWISE_COLWISE])
@pytest_parametrize_wrapper(
"q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
)
def test_norm_forward_with_delayed_scaling_fp8(
self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, out_dtype, q_axis
self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, out_dtype, q_layout
):
if norm_type == "rmsnorm" and zero_centered_gamma is True:
pytest.skip("RMSNorm and zero_centered_gamma is not supported!")
......@@ -407,7 +421,7 @@ class TestNorm:
inp_dtype=inp_dtype,
out_dtype=out_dtype,
scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING,
q_axis=q_axis,
q_layout=q_layout,
)
@pytest.mark.skipif(not is_mxfp8_supported, reason=reason)
......@@ -424,7 +438,7 @@ class TestNorm:
inp_dtype=inp_dtype,
out_dtype=out_dtype,
scaling_mode=ScalingMode.NVTE_MXFP8_1D_SCALING,
q_axis=QuantizeAxis.ROWWISE_COLWISE,
q_layout=QuantizeLayout.ROWWISE_COLWISE,
)
......@@ -434,14 +448,14 @@ QUANTIZE_OUTPUT_DTYPES = {
}
ALL_QUANTIZE_TEST_SHAPES = [
(128, 128),
(4, 256, 512),
(32, 64),
(2, 64, 32),
]
QUANTIZE_TEST_SHAPES = {
"L0": [
(256, 128),
(64, 16, 2, 256),
(32, 256, 128),
(64, 32, 32, 256),
],
"L2": ALL_QUANTIZE_TEST_SHAPES,
}
......@@ -457,48 +471,52 @@ QUANTIZATION_INPUT_DTYPE = {
@pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@pytest_parametrize_wrapper("input_shape", ALL_QUANTIZE_TEST_SHAPES)
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@pytest_parametrize_wrapper("flatten_axis", [-1, -2])
@pytest_parametrize_wrapper(
"q_axis", [QuantizeAxis.ROWWISE, QuantizeAxis.COLWISE, QuantizeAxis.ROWWISE_COLWISE]
"q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE]
)
class TestQuantize:
"""
Purely quantization related tests that will always test on a wider set of types and shapes
"""
def test_qdq(self, in_dtype, input_shape, q_dtype, scaling_mode, q_axis):
def test_qdq(self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis):
key = jax.random.PRNGKey(0)
# Quantizer is created once as some quantization approaches use state from previous iterations (e.g. delayed scaling)
quantizer = QuantizerFactory.create(
scaling_mode=scaling_mode,
q_dtype=q_dtype,
q_axis=q_axis,
q_layout=q_layout,
)
# Add a dimension to test that padding is done correctly when flattening 3D to 2D
if flatten_axis == -2:
input_shape = input_shape[:-1] + (2,) + input_shape[-1:]
n_iterations = 3 if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING else 1
for _ in range(n_iterations):
x = jax.random.uniform(key, input_shape, in_dtype)
scaled_tensor = quantizer.quantize(x)
scaled_tensor = quantizer.quantize(x, flatten_axis=flatten_axis)
assert_dequantized_scaled_tensor(scaled_tensor, x)
def test_quantize_bitwise(self, in_dtype, input_shape, q_dtype, scaling_mode, q_axis):
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING and not is_shape_supported_by_mxfp8(
input_shape
def test_quantize_bitwise(
self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis
):
pytest.skip(f"Input shape {input_shape} is not supported by MXFP8")
key = jax.random.PRNGKey(0)
if flatten_axis == -2:
input_shape = input_shape[:-1] + (2,) + input_shape[-1:]
input = jax.random.uniform(key, input_shape, in_dtype)
te_quantizer, jax_quantizer = QuantizerFactory.create(
n_quantizers=2, q_dtype=q_dtype, scaling_mode=scaling_mode, q_axis=q_axis
n_quantizers=2, q_dtype=q_dtype, scaling_mode=scaling_mode, q_layout=q_layout
)
jax_output = _jax_quantize(input, quantizer=jax_quantizer)
jax_output = _jax_quantize(input, quantizer=jax_quantizer, flatten_axis=flatten_axis)
te_output = tex.quantize(input, quantizer=te_quantizer)
assert_bitwise_scaled_tensors(jax_output, te_output)
te_output = tex.quantize(input, quantizer=te_quantizer, flatten_axis=flatten_axis)
assert_bitwise_scaled_tensors(te_output, jax_output)
@pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE)
......@@ -508,9 +526,13 @@ class TestFusedQuantize:
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@pytest_parametrize_wrapper("input_shape", QUANTIZE_TEST_SHAPES)
@pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES)
@pytest_parametrize_wrapper("q_axis", [QuantizeAxis.ROWWISE, QuantizeAxis.ROWWISE_COLWISE])
def test_quantize_dbias(self, in_dtype, input_shape, out_dtype, scaling_mode, q_axis):
transpose_axis = -1
@pytest_parametrize_wrapper(
"q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
)
@pytest_parametrize_wrapper("flatten_axis", [-1, -2])
def test_quantize_dbias(
self, in_dtype, input_shape, out_dtype, scaling_mode, q_layout, flatten_axis
):
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING and not is_shape_supported_by_mxfp8(
input_shape
):
......@@ -520,35 +542,37 @@ class TestFusedQuantize:
input = jax.random.uniform(key, input_shape, in_dtype)
jax_quantizer, te_quantizer = QuantizerFactory.create(
n_quantizers=2, q_dtype=out_dtype, scaling_mode=scaling_mode, q_axis=q_axis
n_quantizers=2, q_dtype=out_dtype, scaling_mode=scaling_mode, q_layout=q_layout
)
te_output, te_dbias = jit(lambda input: tex.quantize_dbias(input, quantizer=te_quantizer))(
input
te_output, te_dbias = jit(
lambda input: tex.quantize_dbias(
input, quantizer=te_quantizer, flatten_axis=flatten_axis
)
)(input)
jax_output, jax_dbias = jit(
lambda input: _jax_quantize_dbias(
input,
quantizer=jax_quantizer,
input, quantizer=jax_quantizer, flatten_axis=flatten_axis
)
)(input)
assert_bitwise_scaled_tensors(jax_output, te_output)
assert_bitwise_scaled_tensors(te_output, jax_output)
assert_allclose(jax_dbias, te_dbias)
assert_allclose(te_dbias, jax_dbias)
def _test_quantize_dact_dbias(
self, in_dtype, input_shape, out_dtype, scaling_mode, activation_type, is_dbias, q_axis
self, in_dtype, input_shape, out_dtype, scaling_mode, activation_type, is_dbias, q_layout
):
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 2)
x = jax.random.uniform(subkeys[0], input_shape, in_dtype, -1, 1)
x = jnp.repeat(x, len(activation_type), axis=-1)
x = jnp.expand_dims(x, axis=-2)
x = jnp.repeat(x, len(activation_type), axis=-2)
dz = jax.random.uniform(subkeys[1], input_shape, in_dtype, -1, 1)
jax_quantizer, te_quantizer = QuantizerFactory.create(
n_quantizers=2, q_dtype=out_dtype, scaling_mode=scaling_mode, q_axis=q_axis
n_quantizers=2, q_dtype=out_dtype, scaling_mode=scaling_mode, q_layout=q_layout
)
is_casted_output = te_quantizer is not None
......@@ -573,12 +597,12 @@ class TestFusedQuantize:
)(dz, x)
if is_casted_output:
assert_bitwise_scaled_tensors(jax_output, te_output)
assert_bitwise_scaled_tensors(te_output, jax_output)
else:
assert_allclose(jax_output, te_output)
assert_allclose(te_output, jax_output)
if is_dbias:
assert_allclose(jax_dbias, te_dbias)
assert_allclose(te_dbias, jax_dbias)
@pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
@pytest_parametrize_wrapper("input_shape", ALL_ACTIVATION_SHAPES)
......@@ -597,7 +621,7 @@ class TestFusedQuantize:
scaling_mode=ScalingMode.NVTE_NO_SCALING,
activation_type=activation_type,
is_dbias=is_dbias,
q_axis=QuantizeAxis.ROWWISE,
q_layout=QuantizeLayout.ROWWISE,
)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
......@@ -605,9 +629,11 @@ class TestFusedQuantize:
@pytest_parametrize_wrapper("input_shape", ALL_ACTIVATION_SHAPES)
@pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES)
@pytest_parametrize_wrapper("is_dbias", [True, False])
@pytest_parametrize_wrapper("q_axis", [QuantizeAxis.COLWISE, QuantizeAxis.ROWWISE_COLWISE])
@pytest_parametrize_wrapper(
"q_layout", [QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE]
)
def test_quantize_dact_dbias_delayed_scaling(
self, in_dtype, input_shape, out_dtype, activation_type, is_dbias, q_axis
self, in_dtype, input_shape, out_dtype, activation_type, is_dbias, q_layout
):
self._test_quantize_dact_dbias(
in_dtype=in_dtype,
......@@ -616,7 +642,7 @@ class TestFusedQuantize:
scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING,
activation_type=activation_type,
is_dbias=is_dbias,
q_axis=q_axis,
q_layout=q_layout,
)
@pytest.mark.skipif(not is_mxfp8_supported, reason=reason)
......@@ -626,9 +652,11 @@ class TestFusedQuantize:
)
@pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES)
@pytest_parametrize_wrapper("is_dbias", [True, False])
@pytest_parametrize_wrapper("q_axis", [QuantizeAxis.COLWISE, QuantizeAxis.ROWWISE_COLWISE])
@pytest_parametrize_wrapper(
"q_layout", [QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE]
)
def test_quantize_dact_dbias_mxfp8_scaling(
self, in_dtype, input_shape, out_dtype, activation_type, is_dbias, q_axis
self, in_dtype, input_shape, out_dtype, activation_type, is_dbias, q_layout
):
if reduce(operator.mul, input_shape[:-1]) % 128 != 0 or input_shape[-1] % 128 != 0:
# TODO(Jeremy): Remove this if pulling in newer TE branch supports non-full-tile shapes.
......@@ -645,75 +673,75 @@ class TestFusedQuantize:
scaling_mode=ScalingMode.NVTE_MXFP8_1D_SCALING,
activation_type=activation_type,
is_dbias=is_dbias,
q_axis=q_axis,
q_layout=q_layout,
)
class TestDense:
def _ref_gemm_with_jnp_dot(self, a, b, layout):
if layout[0] == "T":
def _ref_gemm_with_jnp_dot(self, a, b, data_layout):
if data_layout[0] == "T":
a = jnp.swapaxes(a, -1, -2)
if layout[1] == "T":
if data_layout[1] == "T":
b = jnp.swapaxes(b, -1, -2)
return jnp.dot(a, b)
def _generate_gemm_input(self, m, n, k, layout):
def _generate_gemm_input(self, m, n, k, data_layout):
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 2)
x = jax.random.uniform(
subkeys[0],
(m if layout[0] == "N" else k, k if layout[0] == "N" else m),
(m if data_layout[0] == "N" else k, k if data_layout[0] == "N" else m),
dtype=jnp.bfloat16,
) / jnp.sqrt(k)
w = jax.random.uniform(
subkeys[1],
(k if layout[1] == "N" else n, n if layout[1] == "N" else k),
(k if data_layout[1] == "N" else n, n if data_layout[1] == "N" else k),
dtype=jnp.bfloat16,
) / jnp.sqrt(n)
lhs_contracting_dim = (1,) if layout[0] == "N" else (0,)
rhs_contracting_dim = (0,) if layout[1] == "N" else (1,)
lhs_contracting_dim = (1,) if data_layout[0] == "N" else (0,)
rhs_contracting_dim = (0,) if data_layout[1] == "N" else (1,)
contracting_dims = (lhs_contracting_dim, rhs_contracting_dim)
return (x, w, contracting_dims)
@pytest_parametrize_wrapper("m,n,k", [(512, 128, 256)])
@pytest_parametrize_wrapper("layout", ["TN", "NT", "NN", "TT"])
def test_gemm_bf16(self, m, n, k, layout):
x, w, contracting_dims = self._generate_gemm_input(m, n, k, layout)
@pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
@pytest_parametrize_wrapper("data_layout", ["TN", "NT", "NN", "TT"])
def test_gemm_bf16(self, m, n, k, data_layout):
x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)
primitive_out = tex.gemm(x, w, contracting_dims)
ref_out = self._ref_gemm_with_jnp_dot(x, w, layout)
ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout)
assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest_parametrize_wrapper("m,n,k", [(512, 128, 256)])
@pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
@pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@pytest_parametrize_wrapper("layout", ["TN", "NT", "NN", "TT"])
def test_gemm_fp8(self, m, n, k, q_dtype, scaling_mode, layout):
x, w, contracting_dims = self._generate_gemm_input(m, n, k, layout)
@pytest_parametrize_wrapper("data_layout", ["TN", "NT", "NN", "TT"])
def test_gemm_fp8(self, m, n, k, q_dtype, scaling_mode, data_layout):
x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)
quantizer_set = QuantizerFactory.create_set(
scaling_mode=scaling_mode, fwd_dtype=q_dtype, bwd_dtype=q_dtype, is_2x2x=False
)
primitive_out = tex.gemm(
x, w, contracting_dims=contracting_dims, quantizer_set=quantizer_set
)
ref_out = self._ref_gemm_with_jnp_dot(x, w, layout)
ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout)
assert_allclose(primitive_out, ref_out, dtype=q_dtype)
@pytest_parametrize_wrapper("m,n,k", [(512, 128, 256)])
@pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
def test_dense_grad_bf16(self, m, n, k):
layout = "NN"
x, w, contracting_dims = self._generate_gemm_input(m, n, k, layout)
data_layout = "NN"
x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)
def primitive_func(x, w, contracting_dims):
primitive_out = dense(x, w, contracting_dims=contracting_dims)
return jnp.mean(primitive_out)
def ref_func(x, w, layout):
return jnp.mean(self._ref_gemm_with_jnp_dot(x, w, layout))
def ref_func(x, w, data_layout):
return jnp.mean(self._ref_gemm_with_jnp_dot(x, w, data_layout))
value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1))
......@@ -722,19 +750,19 @@ class TestDense:
primitive_out, (primitive_x_grad, primitive_w_grad) = value_n_grad_primitive_func(
x, w, contracting_dims
)
ref_out, (ref_x_grad, ref_w_grad) = value_n_grad_ref_func(x, w, layout)
ref_out, (ref_x_grad, ref_w_grad) = value_n_grad_ref_func(x, w, data_layout)
assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)
assert_allclose(primitive_x_grad, ref_x_grad, dtype=jnp.bfloat16)
assert_allclose(primitive_w_grad, ref_w_grad, dtype=jnp.bfloat16)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest_parametrize_wrapper("m,n,k", [(512, 128, 256)])
@pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
@pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
def test_dense_grad_fp8(self, m, n, k, q_dtype, scaling_mode):
layout = "NN"
x, w, contracting_dims = self._generate_gemm_input(m, n, k, layout)
data_layout = "NN"
x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)
key = jax.random.PRNGKey(1)
bias = jax.random.uniform(key, n, dtype=jnp.bfloat16)
......@@ -745,9 +773,9 @@ class TestDense:
)
return jnp.mean(primitive_out)
def ref_func(x, w, bias, layout):
def ref_func(x, w, bias, data_layout):
return jnp.mean(
self._ref_gemm_with_jnp_dot(x, w, layout) + jnp.expand_dims(bias, axis=0)
self._ref_gemm_with_jnp_dot(x, w, data_layout) + jnp.expand_dims(bias, axis=0)
)
value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2))
......@@ -763,7 +791,9 @@ class TestDense:
value_n_grad_primitive_func(x, w, bias, contracting_dims, quantizer_set)
)
ref_out, (ref_x_grad, ref_w_grad, ref_bias_grad) = value_n_grad_ref_func(x, w, bias, layout)
ref_out, (ref_x_grad, ref_w_grad, ref_bias_grad) = value_n_grad_ref_func(
x, w, bias, data_layout
)
assert_allclose(primitive_out, ref_out, dtype=q_dtype)
assert_allclose(primitive_x_grad, ref_x_grad, dtype=q_dtype)
......@@ -791,7 +821,7 @@ def _ref_jax_norm_impl(x, gamma, beta, norm_type, zero_centered_gamma, eps, quan
class TestFusedDense:
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize("m,n,k", [(512, 128, 128)])
@pytest.mark.parametrize("m,n,k", [(64, 32, 64)])
@pytest.mark.parametrize("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@pytest.mark.parametrize("scaling_mode", supported_scaling_modes)
@pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"])
......@@ -873,7 +903,7 @@ class TestFusedDense:
assert_allclose(prim_beta_grad, ref_beta_grad, dtype=q_dtype)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize("m,n,k", [(512, 128, 256)])
@pytest.mark.parametrize("m,n,k", [(64, 32, 64)])
@pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear")])
@pytest.mark.parametrize("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@pytest.mark.parametrize("scaling_mode", supported_scaling_modes)
......@@ -898,13 +928,13 @@ class TestFusedDense:
x = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
kernel_1 = jax.random.normal(
subkeys[1], (k, len(activation_type) * n), jnp.bfloat16
subkeys[1], (k, len(activation_type), n), jnp.bfloat16
) / jnp.sqrt(k)
kernel_2 = jax.random.normal(subkeys[2], (n, k), jnp.bfloat16) / jnp.sqrt(n)
gamma = jax.random.normal(subkeys[5], (k,), jnp.bfloat16)
beta = None # was tested in TestNorm
if use_bias:
bias_1 = jax.random.normal(subkeys[3], (len(activation_type) * n), jnp.bfloat16)
bias_1 = jax.random.normal(subkeys[3], (len(activation_type), n), jnp.bfloat16)
bias_2 = jax.random.normal(subkeys[4], (k,), jnp.bfloat16)
else:
bias_1 = None
......@@ -1039,19 +1069,19 @@ class TestGroupedDense:
subkeys = jax.random.split(key, len(shape_list) * 2)
lhs_list, rhs_list, contracting_dims_list = [], [], []
for i, ((m, n, k), layout) in enumerate(zip(shape_list, layout_list)):
for i, ((m, n, k), data_layout) in enumerate(zip(shape_list, layout_list)):
lhs = jax.random.uniform(
subkeys[2 * i],
(m if layout[0] == "N" else k, k if layout[0] == "N" else m),
(m if data_layout[0] == "N" else k, k if data_layout[0] == "N" else m),
dtype=dtype,
)
rhs = jax.random.uniform(
subkeys[2 * i + 1],
(k if layout[1] == "N" else n, n if layout[1] == "N" else k),
(k if data_layout[1] == "N" else n, n if data_layout[1] == "N" else k),
dtype=dtype,
)
lhs_contracting_dim = (1,) if layout[0] == "N" else (0,)
rhs_contracting_dim = (0,) if layout[1] == "N" else (1,)
lhs_contracting_dim = (1,) if data_layout[0] == "N" else (0,)
rhs_contracting_dim = (0,) if data_layout[1] == "N" else (1,)
contracting_dims = (lhs_contracting_dim, rhs_contracting_dim)
lhs_list.append(lhs)
......
......@@ -45,11 +45,17 @@ if is_mxfp8_supported:
SUPPORTED_RECIPES.append(pytest.param(recipe.MXFP8BlockScaling(), id="MXFP8BlockScaling"))
DTYPES = [jnp.bfloat16, jnp.float16]
INPUT_SHAPE = [[2, 64, 64]] # [batch, seqlen, hidden_in]
INPUT_SHAPE = [[4, 64, 128]] # [batch, seqlen, hidden_in]
LAYERNORM_INPUT_AXES = (BATCH_AXES, SEQLEN_TP_AXES, HIDDEN_AXES)
DOT_1_INPUT_AXES = (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)
DOT_2_INPUT_AXES = (BATCH_AXES, SEQLEN_AXES, HIDDEN_TP_AXES)
KERNEL_1_AXES = (W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES)
KERNEL_2_AXES = (W_TP_AXES, W_FSDP_AXES)
LN_SCALE_AXES = (W_NO_SHARD_AXES,)
LN_BIAS_AXES = (W_NO_SHARD_AXES,)
BIAS_1_AXES = (W_JOINED_AXES, W_TP_AXES)
BIAS_2_AXES = (W_NO_SHARD_AXES,)
INTERMEDIATE = 64
......@@ -60,7 +66,6 @@ def generate_fsdp_and_tp_configs():
configs.append(
[2, (1, 2), ("fsdp", "tp"), MeshResource(fsdp_resource="fsdp", tp_resource="tp")]
)
if is_devices_enough(4):
configs.append(
[4, (2, 2), ("fsdp", "tp"), MeshResource(fsdp_resource="fsdp", tp_resource="tp")]
......@@ -80,13 +85,13 @@ class TestDistributedLayernormMLP:
x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype)
gamma = jax.random.normal(subkeys[5], (hidden_in,), dtype=dtype)
k1 = jax.random.normal(
subkeys[1], (hidden_in, len(activation_type) * INTERMEDIATE), dtype
subkeys[1], (hidden_in, len(activation_type), INTERMEDIATE), dtype
) / jnp.sqrt(hidden_in)
k2 = jax.random.normal(subkeys[2], (INTERMEDIATE, hidden_out), dtype) / jnp.sqrt(
INTERMEDIATE
)
if use_bias:
b1 = jax.random.normal(subkeys[3], (len(activation_type) * INTERMEDIATE), dtype)
b1 = jax.random.normal(subkeys[3], (len(activation_type), INTERMEDIATE), dtype)
b2 = jax.random.normal(subkeys[4], (hidden_out,), dtype)
else:
b1 = None
......@@ -111,10 +116,12 @@ class TestDistributedLayernormMLP:
layernorm_input_axes = LAYERNORM_INPUT_AXES
dot_1_input_axes = DOT_1_INPUT_AXES
dot_2_input_axes = DOT_2_INPUT_AXES
kernel_1_axes = KERNEL_1_AXES
kernel_2_axes = KERNEL_2_AXES
else:
layernorm_input_axes = None
dot_1_input_axes = None
dot_2_input_axes = None
dot_1_input_axes = dot_2_input_axes = None
kernel_1_axes = kernel_2_axes = None
quantizer_sets = QuantizerFactory.create_set(n_quantizer_sets=2)
......@@ -130,6 +137,8 @@ class TestDistributedLayernormMLP:
norm_input_axes=layernorm_input_axes,
dot_1_input_axes=dot_1_input_axes,
dot_2_input_axes=dot_2_input_axes,
kernel_1_axes=kernel_1_axes,
kernel_2_axes=kernel_2_axes,
activation_type=activation_type,
quantizer_sets=quantizer_sets,
)
......@@ -142,7 +151,7 @@ class TestDistributedLayernormMLP:
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("use_bias", [True, False])
@pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
def test_layernorm_fp8_mlp_primitive(
def test_layernorm_mlp_grad(
self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe
):
device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config
......@@ -168,12 +177,12 @@ class TestDistributedLayernormMLP:
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource):
k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", "tp"))
k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", None, "tp"))
k2_sharding = NamedSharding(mesh, PartitionSpec("tp", "fsdp"))
k1_ = jax.device_put(k1, k1_sharding)
k2_ = jax.device_put(k2, k2_sharding)
if use_bias:
b1_sharding = NamedSharding(mesh, PartitionSpec("tp"))
b1_sharding = NamedSharding(mesh, PartitionSpec(None, "tp"))
b1_ = jax.device_put(b1, b1_sharding)
else:
b1_sharding = b1_ = None
......@@ -267,16 +276,7 @@ class TestDistributedLayernormMLP:
transpose_batch_sequence=False, # input: [batch, seqlen, hidden]
intermediate_dim=INTERMEDIATE,
activations=activation_type,
scale_axes=(W_NO_SHARD_AXES,),
ln_bias_axes=(W_NO_SHARD_AXES,),
kernel_axes_1=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
kernel_axes_2=(W_TP_AXES, W_FSDP_AXES),
use_bias=use_bias,
bias_axes_1=(W_JOINED_AXES, W_TP_AXES),
bias_axes_2=(W_NO_SHARD_AXES,),
layernorm_input_axes=LAYERNORM_INPUT_AXES,
dot_1_input_axes=DOT_1_INPUT_AXES,
dot_2_input_axes=DOT_2_INPUT_AXES,
)
params_single = ln_mlp_single.init(init_rngs, x, deterministic=True)
mlp_out_single, ln_out_single = ln_mlp_single.apply(
......@@ -295,13 +295,13 @@ class TestDistributedLayernormMLP:
transpose_batch_sequence=False,
intermediate_dim=INTERMEDIATE,
activations=activation_type,
scale_axes=(W_NO_SHARD_AXES,),
ln_bias_axes=(W_NO_SHARD_AXES,),
kernel_axes_1=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
kernel_axes_2=(W_TP_AXES, W_FSDP_AXES),
scale_axes=LN_SCALE_AXES,
ln_bias_axes=LN_BIAS_AXES,
kernel_axes_1=KERNEL_1_AXES,
kernel_axes_2=KERNEL_2_AXES,
use_bias=use_bias,
bias_axes_1=(W_JOINED_AXES, W_TP_AXES),
bias_axes_2=(W_NO_SHARD_AXES,),
bias_axes_1=BIAS_1_AXES,
bias_axes_2=BIAS_2_AXES,
layernorm_input_axes=LAYERNORM_INPUT_AXES,
dot_1_input_axes=DOT_1_INPUT_AXES,
dot_2_input_axes=DOT_2_INPUT_AXES,
......@@ -334,7 +334,7 @@ class TestDistributedLayernormMLP:
@pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
def test_layernorm_fp8_mlp_layer(
def test_layernorm_mlp_layer_fp8(
self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe
):
self._test_layernorm_mlp(
......
......@@ -91,7 +91,6 @@ def _activation_bwd_rule(activation_type, ctx, g):
(x, _) = ctx
assert x.dtype == g.dtype
dx = tex.dact_lu(g, x, activation_type)
dx = jnp.reshape(dx, x.shape)
return (dx, None)
......
......@@ -26,12 +26,12 @@ from .misc import (
should_apply_1x_fused_dbias_war_for_arch_l_100,
NamedSharding,
)
from .quantization import _jax_quantize_dbias, _jax_dbias, quantize_dbias
from .quantization import _jax_dbias, _quantize_dbias_impl
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
from ..quantize import ScaledTensor, ScaledTensorFactory
from ..quantize import (
Quantizer,
QuantizeAxis,
QuantizeLayout,
DelayedScaleQuantizer,
ScalingMode,
)
......@@ -110,38 +110,28 @@ class ActLuPrimitive(BasePrimitive):
"""
te_act_lu_p abstract
"""
del act_enum, act_len, scale_shapes
del act_enum, scale_shapes
dtype = dtypes.canonicalize_dtype(x_aval.dtype)
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert scale_aval is None or scale_aval.dtype == jnp.float32
out_shape = (
*x_aval.shape[:-2],
1,
x_aval.shape[-1],
assert x_aval.shape[-2] == act_len, (
"activation input should be replicated by act_len in the -2 axis, got input shape"
f" {x_aval.shape} and act_len {act_len}"
)
out_shape = (*x_aval.shape[:-2], x_aval.shape[-1]) # Exclude act dim
out_aval = x_aval.update(shape=out_shape, dtype=out_dtype)
updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
scaling_mode
).get_scale_shape_2x(out_shape[:-2] + (out_shape[-1],), is_padded=not is_outer)
if len(rowwise_scale_inv_shape) > 1:
rowwise_scale_inv_shape = (
rowwise_scale_inv_shape[:-1] + (1,) + rowwise_scale_inv_shape[-1:]
)
if len(colwise_scale_inv_shape) > 1:
colwise_scale_inv_shape = (
colwise_scale_inv_shape[:-1] + (1,) + colwise_scale_inv_shape[-1:]
)
scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
colwise_out_aval = jax.core.ShapedArray(shape=(1,), dtype=out_dtype)
colwise_scale_inv_aval = jax.core.ShapedArray(shape=(1,), dtype=scale_dtype)
if is_2x:
).get_scale_shape_2x(out_shape, is_padded=not is_outer, flatten_axis=-1)
if not is_2x:
out_shape = (1,)
colwise_scale_inv_shape = (1,)
colwise_out_aval = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype)
scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
colwise_scale_inv_aval = jax.core.ShapedArray(
shape=colwise_scale_inv_shape, dtype=scale_dtype
)
......@@ -211,15 +201,8 @@ class ActLuPrimitive(BasePrimitive):
)
rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
scaling_mode
).get_scale_shape_2x(out.shape[:-2] + (out.shape[-1],), is_padded=False)
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
rowwise_scale_inv_shape = (
rowwise_scale_inv_shape[:-1] + (1,) + rowwise_scale_inv_shape[-1:]
)
if is_2x:
colwise_scale_inv_shape = (
colwise_scale_inv_shape[:-1] + (1,) + colwise_scale_inv_shape[-1:]
)
).get_scale_shape_2x(out.shape, is_padded=False, flatten_axis=-1)
# Slice out padding for MXFP8, noop for DelayedScaling
scale_inv = jax.lax.slice(
scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
)
......@@ -227,6 +210,7 @@ class ActLuPrimitive(BasePrimitive):
colwise_scale_inv = jax.lax.slice(
colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
)
return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax
@staticmethod
......@@ -292,11 +276,14 @@ class ActLuPrimitive(BasePrimitive):
is_outer,
) # Unused.
x_spec = get_padded_spec(arg_infos[0])
out_spec = (*x_spec[:-2], None, x_spec[-2])
scale_spec = get_padded_spec(arg_infos[1])
out_spec = (*x_spec[:-2], x_spec[-1])
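# The output drops the act (-2) axis of the input and keeps the hidden-axis sharding.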
out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out")
if is_2x:
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
colwise_out_spec = multidim_transpose(out_spec)
colwise_out_spec = multidim_transpose(out_spec, transpose_axis=-1)
else:
colwise_out_spec = out_spec
else:
......@@ -304,18 +291,24 @@ class ActLuPrimitive(BasePrimitive):
colwise_out_sharding = NamedSharding(
mesh, PartitionSpec(*colwise_out_spec), desc="ActLuPrimitive.colwise_out"
)
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*get_padded_spec(arg_infos[1])), desc="ActLuPrimitive.scale_inv"
)
amax_sharding = scale_inv_sharding.duplicate_with_new_description("ActLuPrimitive.amax")
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
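# scale_inv/amax are replicated by default; delayed scaling follows the scale input's
# spec, MXFP8 block scale_inv follows the rowwise output spec, and the colwise
# scale_inv mirrors the rowwise one when quantizing 2x.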
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
scale_inv_spec = amax_spec = scale_spec
elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
scale_inv_spec = out_spec
if is_2x:
colwise_scale_inv_spec = scale_inv_spec
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.scale_inv"
mesh, PartitionSpec(*scale_inv_spec), desc="ActLuPrimitive.scale_inv"
)
colwise_scale_inv_sharding = scale_inv_sharding.duplicate_with_new_description(
"ActLuPrimitive.colwise_scale_inv"
amax_sharding = NamedSharding(mesh, PartitionSpec(*amax_spec), desc="ActLuPrimitive.amax")
colwise_scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*colwise_scale_inv_spec), desc="ActLuPrimitive.colwise_scale_inv"
)
return (
out_sharding,
colwise_out_sharding,
......@@ -340,14 +333,14 @@ class ActLuPrimitive(BasePrimitive):
):
del result_infos, is_outer # Unused.
x_spec = get_padded_spec(arg_infos[0])
out_spec = (*x_spec[:-1], x_spec[-1])
if act_len == 2 and x_spec[-1] is None:
# Ensure last axis is partitioned and not the gating axis
out_spec = (*x_spec[:-2], None, x_spec[-2])
scale_spec = get_padded_spec(arg_infos[1])
out_spec = (*x_spec[:-2], x_spec[-1])
out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out")
if is_2x:
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
colwise_out_spec = multidim_transpose(out_spec)
colwise_out_spec = multidim_transpose(out_spec, transpose_axis=-1)
else:
colwise_out_spec = out_spec
else:
......@@ -355,21 +348,25 @@ class ActLuPrimitive(BasePrimitive):
colwise_out_sharding = NamedSharding(
mesh, PartitionSpec(*colwise_out_spec), desc="ActLuPrimitive.colwise_out"
)
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*get_padded_spec(arg_infos[1])), desc="ActLuPrimitive.scale_inv"
)
amax_sharding = scale_inv_sharding.duplicate_with_new_description("ActLuPrimitive.amax")
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
scale_inv_spec = amax_spec = scale_spec
elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
scale_inv_spec = out_spec
if is_2x:
colwise_scale_inv_spec = scale_inv_spec
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.scale_inv"
mesh, PartitionSpec(*scale_inv_spec), desc="ActLuPrimitive.scale_inv"
)
colwise_scale_inv_sharding = scale_inv_sharding.duplicate_with_new_description(
"ActLuPrimitive.colwise_scale_inv"
amax_sharding = NamedSharding(mesh, PartitionSpec(*amax_spec), desc="ActLuPrimitive.amax")
colwise_scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*colwise_scale_inv_spec), desc="ActLuPrimitive.colwise_scale_inv"
)
arg_shardings = list(arg_i.sharding for arg_i in arg_infos)
arg_shardings[0] = NamedSharding(mesh, PartitionSpec(*out_spec))
arg_shardings = tuple(arg_shardings)
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_shardings = (
out_sharding,
colwise_out_sharding,
......@@ -413,6 +410,7 @@ class ActLuPrimitive(BasePrimitive):
register_primitive(ActLuPrimitive)
# TODO(Jeremy): replace is_2x with q_layout
class DActLuDBiasQuantizePrimitive(BasePrimitive):
"""
DActLu DBias Cast Transpose Primitive
......@@ -445,42 +443,41 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
te_dact_dbias_quantize_p abstract
"""
del act_enum, scale_shapes
dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert x_aval.dtype == dtype
dz_dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
assert dz_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert x_aval.dtype == dz_dtype
assert x_aval.shape[-2] == act_len, (
"activation input should be replicated by act_len in the -2 axis, got input shape"
f" {x_aval.shape} and act_len {act_len}"
)
assert scale_aval.dtype == jnp.float32
ir_hidden_size = dz_aval.shape[-1]
gi_hidden_size = x_aval.shape[-1]
gi_hidden_size = act_len * x_aval.shape[-1]
assert act_len * ir_hidden_size == gi_hidden_size
out_shape = x_aval.shape
out_aval = x_aval.update(shape=out_shape, dtype=out_dtype)
updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
scaling_mode
).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer)
scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
colwise_out_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
colwise_scale_inv_aval = jax.core.ShapedArray(shape=(1,), dtype=scale_dtype)
dbias_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
wkspace_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=-2)
if is_2x:
# Don't transpose output for MXFP8
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
t_shape = out_shape
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
colwise_out_shape = multidim_transpose(out_shape, transpose_axis=-2)
else:
colwise_out_shape = out_shape
else:
t_shape = multidim_transpose(out_shape)
colwise_out_aval = x_aval.update(shape=t_shape, dtype=out_dtype)
colwise_out_shape = (1,)
colwise_scale_inv_shape = (1,)
colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype)
scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
colwise_scale_inv_aval = jax.core.ShapedArray(
shape=colwise_scale_inv_shape, dtype=scale_dtype
)
if is_dbias:
dbias_shape = gi_hidden_size
dbias_aval = x_aval.update(shape=dbias_shape, dtype=dtype)
dbias_shape = (act_len, ir_hidden_size)
(wkspace_info,) = transformer_engine_jax.get_dact_dbias_quantize_workspace_sizes(
x_aval.size // gi_hidden_size,
gi_hidden_size,
......@@ -489,9 +486,14 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
scaling_mode,
is_2x,
)
wkspace_aval = x_aval.update(
shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
)
wkspace_shape = wkspace_info[0]
wkspace_dtype = te_dtype_to_jax_dtype(wkspace_info[1])
else:
dbias_shape = (1,)
wkspace_shape = (1,)
wkspace_dtype = jnp.float32
dbias_aval = jax.core.ShapedArray(shape=dbias_shape, dtype=dz_dtype)
wkspace_aval = jax.core.ShapedArray(shape=wkspace_shape, dtype=wkspace_dtype)
return (
out_aval,
......@@ -587,8 +589,8 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
)
rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
scaling_mode
).get_scale_shape_2x(x.shape, is_padded=False)
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
).get_scale_shape_2x(x.shape, is_padded=False, flatten_axis=-2)
# Slice out padding for MXFP8, noop for DelayedScaling
scale_inv = jax.lax.slice(
scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
)
......@@ -596,14 +598,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
colwise_scale_inv = jax.lax.slice(
colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
)
return (
out,
colwise_out,
scale_inv,
colwise_scale_inv,
updated_amax,
dbias,
) # Exclude wkspace
return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias
@staticmethod
def batcher(
......@@ -670,15 +665,16 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
result_infos,
):
del out_dtype, result_infos, act_enum
del scale_dtype, scale_shapes, is_dbias, act_len, is_outer
del scale_dtype, scale_shapes, act_len, is_outer
x_spec = get_padded_spec(arg_infos[1])
scale_spec = get_padded_spec(arg_infos[2])
out_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec), desc="DActLuDBiasQuantizePrimitive.out"
)
if is_2x:
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
colwise_x_spec = multidim_transpose(x_spec)
colwise_x_spec = multidim_transpose(x_spec, transpose_axis=-2)
else:
colwise_x_spec = x_spec
else:
......@@ -687,23 +683,32 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
mesh, PartitionSpec(*colwise_x_spec), desc="DActLuDBiasQuantizePrimitive.colwise_out"
)
dbias_shaprding = NamedSharding(
dbias_spec = x_spec[-2:] if is_dbias else (None,)
dbias_sharding = NamedSharding(
mesh,
PartitionSpec(x_spec[-1]),
PartitionSpec(*dbias_spec),
desc="DActLuDBiasQuantizePrimitive.dbias",
)
scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
scale_inv_spec = amax_spec = scale_spec
elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
scale_inv_spec = x_spec
if is_2x:
colwise_scale_inv_spec = scale_inv_spec
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(None), desc="DActLuDBiasQuantizePrimitive.scale_inv"
mesh, PartitionSpec(*scale_inv_spec), desc="DActLuDBiasQuantizePrimitive.scale_inv"
)
amax_sharding = NamedSharding(
mesh, PartitionSpec(None), desc="DActLuDBiasQuantizePrimitive.amax"
mesh, PartitionSpec(*amax_spec), desc="DActLuDBiasQuantizePrimitive.amax"
)
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec), desc="DActLuDBiasQuantizePrimitive.scale_inv"
)
colwise_scale_inv_sharding = scale_inv_sharding.duplicate_with_new_description(
"DActLuDBiasQuantizePrimitive.colwise_scale_inv"
colwise_scale_inv_sharding = NamedSharding(
mesh,
PartitionSpec(*colwise_scale_inv_spec),
desc="DActLuDBiasQuantizePrimitive.colwise_scale_inv",
)
return (
out_sharding,
......@@ -711,7 +716,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
scale_inv_sharding,
colwise_scale_inv_sharding,
amax_sharding,
dbias_shaprding,
dbias_sharding,
)
@staticmethod
......@@ -731,10 +736,15 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
):
del result_infos, is_outer
x_spec = get_padded_spec(arg_infos[1])
out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec), desc="out")
scale_spec = get_padded_spec(arg_infos[2])
out_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec), desc="DActLuDBiasQuantizePrimitive.out"
)
if is_2x:
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
colwise_x_spec = multidim_transpose(x_spec)
colwise_x_spec = multidim_transpose(x_spec, transpose_axis=-2)
else:
colwise_x_spec = x_spec
else:
......@@ -743,38 +753,39 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
mesh, PartitionSpec(*colwise_x_spec), desc="DActLuDBiasQuantizePrimitive.colwise_out"
)
dbias_shaprding = NamedSharding(
dbias_spec = x_spec[-2:] if is_dbias else (None,)
dbias_sharding = NamedSharding(
mesh,
PartitionSpec(x_spec[-1]),
PartitionSpec(*dbias_spec),
desc="DActLuDBiasQuantizePrimitive.dbias",
)
scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
scale_inv_spec = amax_spec = scale_spec
elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
scale_inv_spec = x_spec
if is_2x:
colwise_scale_inv_spec = scale_inv_spec
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(None), desc="DActLuDBiasQuantizePrimitive.scale_inv"
)
amax_sharding = NamedSharding(
mesh, PartitionSpec(None), desc="DActLuDBiasQuantizePrimitive.amax"
)
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec), desc="DActLuDBiasQuantizePrimitive.scale_inv"
mesh, PartitionSpec(*scale_inv_spec), desc="ActLuPrimitive.scale_inv"
)
colwise_scale_inv_sharding = scale_inv_sharding.duplicate_with_new_description(
"DActLuDBiasQuantizePrimitive.colwise_scale_inv"
amax_sharding = NamedSharding(mesh, PartitionSpec(*amax_spec), desc="ActLuPrimitive.amax")
colwise_scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*colwise_scale_inv_spec), desc="ActLuPrimitive.colwise_scale_inv"
)
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
arg_shardings = (
arg_shardings[1],
arg_shardings[1],
*arg_shardings[2:],
) # dz and x are the same
out_shardings = (
out_sharding,
colwise_out_sharding,
scale_inv_sharding,
colwise_scale_inv_sharding,
amax_sharding,
dbias_shaprding,
dbias_sharding,
)
def sharded_impl(dz, x, scale):
......@@ -816,14 +827,21 @@ def _jax_act_lu(inputs, activation_type, quantizer=None) -> Union[jnp.ndarray, S
"""
JAX native activation implementation
"""
x = jnp.split(inputs, len(activation_type), axis=-1)
act_len = len(activation_type)
assert inputs.shape[-2] == act_len, (
"activation input should be replicated by act_len in the -2 axis, got input shape"
f" {inputs.shape} and act_len {act_len}"
)
x = jnp.split(inputs, act_len, axis=-2)
acts = []
for idx, act_fn in enumerate(activation_type):
x_i = _convert_to_activation_function(act_fn)(x[idx])
acts.append(x_i)
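# Gated activations (act_len == 2) multiply their branches elementwise; the squeeze
# below removes the act axis so the output shape is (..., K).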
x = reduce(operator.mul, acts)
x = jnp.squeeze(x, axis=-2)
if quantizer:
return quantizer.quantize(x)
return quantizer.quantize(x, flatten_axis=-1)
return x
......@@ -837,6 +855,12 @@ def _jax_quantize_dact_dbias(
"""
JAX implementation of dact_lu and dbias with optional quantization
"""
act_len = len(activation_type)
assert x.shape[-2] == act_len, (
"activation input should be replicated by act_len in the -2 axis, got input shape"
f" {x.shape} and act_len {act_len}"
)
_, vjp_func = jax.vjp(
partial(_jax_act_lu, activation_type=activation_type), x.astype(jnp.float32)
)
......@@ -844,10 +868,10 @@ def _jax_quantize_dact_dbias(
dbias = None
if is_dbias:
dbias = _jax_dbias(dx).astype(x.dtype)
dbias = _jax_dbias(dx, dtype=x.dtype, flatten_axis=-2)
if quantizer is not None:
dx = quantizer.quantize(dx, dq_dtype=x.dtype)
dx = quantizer.quantize(dx, dq_dtype=x.dtype, flatten_axis=-2)
else:
dx = dx.astype(x.dtype)
......@@ -863,6 +887,7 @@ def act_lu(
Args:
x: Input tensor to be processed.
Shape: (..., ACT_DIM, K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations
activation_type: Type of activation function to apply.
quantizer: Optional quantizer for FP8 quantization of the output.
......@@ -873,12 +898,17 @@ def act_lu(
A ScaledTensor containing the quantized activated input.
"""
act_type_id = ActivationEnum[activation_type].value
act_len = len(activation_type)
assert x.shape[-2] == act_len, (
"activation input should be replicated by act_len in the -2 axis, got input shape"
f" {x.shape} and act_len {act_len}"
)
if not ActLuPrimitive.enabled():
return _jax_act_lu(x, activation_type, quantizer)
# TE/common does not support colwise-only quantization yet
if quantizer is not None and quantizer.q_axis == QuantizeAxis.COLWISE:
if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE:
return _jax_act_lu(x, activation_type, quantizer)
# TE/common does not support 2x quantization for DelayedScaling yet
......@@ -889,16 +919,15 @@ def act_lu(
return war_output
scale = jnp.empty((1,), jnp.float32)
output_shape = (*x.shape[:-1], x.shape[-1] // len(activation_type))
output_shape = (*x.shape[:-2], x.shape[-1])
if quantizer is None:
x = x.reshape((-1, len(activation_type), x.shape[-1] // len(activation_type)))
out, _, _, _, _ = ActLuPrimitive.outer_primitive.bind(
x,
scale,
out_dtype=x.dtype,
act_enum=act_type_id,
act_len=len(activation_type),
act_len=act_len,
scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value,
is_2x=False,
scale_dtype=jnp.float32,
......@@ -911,7 +940,6 @@ def act_lu(
if isinstance(quantizer, DelayedScaleQuantizer):
scale = quantizer.scale
x = x.reshape((*x.shape[:-1], len(activation_type), x.shape[-1] // len(activation_type)))
(
rowwise_casted_output,
colwise_casted_output,
......@@ -923,25 +951,15 @@ def act_lu(
scale,
out_dtype=quantizer.q_dtype,
act_enum=act_type_id,
act_len=len(activation_type),
act_len=act_len,
scaling_mode=quantizer.scaling_mode.value,
is_2x=quantizer.is_2x2x(),
scale_dtype=quantizer.get_scale_dtype(),
scale_shapes=quantizer.get_scale_shapes(output_shape),
# output does not have act axis
scale_shapes=quantizer.get_scale_shapes(output_shape, flatten_axis=-1),
is_outer=True,
)
rowwise_casted_output = rowwise_casted_output.reshape(output_shape)
if len(rowwise_scale_inv.shape) > 1:
rowwise_scale_inv = jnp.squeeze(rowwise_scale_inv, axis=-2) # Remove act axis
if quantizer.q_axis in (QuantizeAxis.COLWISE, QuantizeAxis.ROWWISE_COLWISE):
colwise_output_shape = output_shape
if quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
colwise_output_shape = multidim_transpose(output_shape)
colwise_casted_output = colwise_casted_output.reshape(colwise_output_shape)
if len(colwise_scale_inv.shape) > 1:
colwise_scale_inv = jnp.squeeze(colwise_scale_inv, axis=-2) # Remove act axis
quantizer.update(updated_amax)
return ScaledTensorFactory.create(
......@@ -951,8 +969,8 @@ def act_lu(
colwise_scale_inv=colwise_scale_inv,
scaling_mode=quantizer.scaling_mode,
dq_dtype=x.dtype,
q_axis=quantizer.q_axis,
layout=quantizer.get_layout(),
q_layout=quantizer.q_layout,
data_layout=quantizer.get_data_layout(),
)
......@@ -968,7 +986,7 @@ def quantize_dact_dbias(
Args:
dz: Gradient of the output with respect to the activation output.
x: Input tensor that was processed by the forward pass.
Shape: (..., ACT_DIM * K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations
Shape: (..., ACT_DIM, K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations
activation_type: Type of activation function used in the forward pass. Defaults to ("gelu",).
is_dbias: If True, compute bias gradient. Defaults to True.
quantizer: Optional quantizer for FP8 quantization of the output.
......@@ -979,21 +997,25 @@ def quantize_dact_dbias(
- The gradient of the activation with respect to the bias.
"""
act_len = len(activation_type)
assert x.shape[-2] == act_len, (
"activation input should be replicated by act_len in the -2 axis, got input shape"
f" {x.shape} and act_len {act_len}"
)
if not DActLuDBiasQuantizePrimitive.enabled():
return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer)
# TE/common does not support colwise-only quantization yet
if quantizer is not None and quantizer.q_axis == QuantizeAxis.COLWISE:
if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE:
return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer)
# TE/common does not support 1x dact_dbias_quantize on arch < 100 yet
if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer):
out, _ = quantize_dact_dbias(
dz=dz, x=x, activation_type=activation_type, is_dbias=False, quantizer=None
)
return quantize_dbias(out, is_dbias=True, quantizer=quantizer)
out = dact_lu(dz, x, activation_type, quantizer=None)
return _quantize_dbias_impl(out, quantizer, is_dbias=True, flatten_axis=-2)
is_gated = len(activation_type) == 2
is_gated = act_len == 2
# TE/common does not support DelayedScaling2x for gated-act yet
if is_gated:
war_output = try_apply_delayed_scaling_2x_war(
......@@ -1003,6 +1025,7 @@ def quantize_dact_dbias(
activation_type=activation_type,
is_dbias=is_dbias,
quantizer=quantizer,
flatten_axis=-2,
)
if war_output is not None:
return war_output
......@@ -1025,12 +1048,12 @@ def quantize_dact_dbias(
scale_shapes=((), ()), # unused
is_dbias=False,
act_enum=act_type_id,
act_len=len(activation_type),
act_len=act_len,
is_outer=True,
)
dbias = None
if is_dbias:
dbias = _jax_dbias(output).astype(x.dtype)
dbias = _jax_dbias(output, dtype=x.dtype, flatten_axis=-2)
return output.astype(x.dtype), dbias
if isinstance(quantizer, DelayedScaleQuantizer):
......@@ -1041,15 +1064,8 @@ def quantize_dact_dbias(
dgated = dact_lu(
dz.astype(jnp.float32), x.astype(jnp.float32), activation_type=activation_type
)
# TODO(Jeremy): Debug - TE's quantize_dbias produced nans in this case for distributed layernorm_mlp tests
if quantizer.scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING:
out, dbias = _jax_quantize_dbias(dgated, quantizer=quantizer, dq_dtype=x.dtype)
else:
out, dbias = quantize_dbias(
dgated,
quantizer=quantizer,
is_dbias=True,
dq_dtype=x.dtype,
out, dbias = _quantize_dbias_impl(
dgated, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2
)
return out, dbias
......@@ -1070,10 +1086,11 @@ def quantize_dact_dbias(
scaling_mode=quantizer.scaling_mode.value,
is_2x=quantizer.is_2x2x(),
scale_dtype=quantizer.get_scale_dtype(),
scale_shapes=quantizer.get_scale_shapes(out_shape),
# output has act axis
scale_shapes=quantizer.get_scale_shapes(out_shape, flatten_axis=-2),
is_dbias=is_dbias,
act_enum=act_type_id,
act_len=len(activation_type),
act_len=act_len,
is_outer=True,
)
......@@ -1090,8 +1107,9 @@ def quantize_dact_dbias(
colwise_scale_inv=colwise_scale_inv,
scaling_mode=quantizer.scaling_mode,
dq_dtype=x.dtype,
q_axis=quantizer.q_axis,
layout=quantizer.get_layout(),
q_layout=quantizer.q_layout,
data_layout=quantizer.get_data_layout(),
flatten_axis=-2, # as output has act axis
)
return out, dbias
......
......@@ -6,9 +6,9 @@
from typing import Tuple, Sequence, Union, Dict, List
from functools import partial, reduce
import operator
from transformer_engine_jax import get_device_compute_capability
import jax
import jax.numpy as jnp
from transformer_engine_jax import get_device_compute_capability
from .base import BasePrimitive, register_primitive
......@@ -183,10 +183,9 @@ def __jitted_jax_gemm_delayed_scaling_fp8(lhs, rhs, lhs_dn, rhs_dn, precision):
# Reshape + Transpose
# [..., M, K] -> [B, M, K]
# [..., K, M] -> [B, M, K]
lhs_3d = _shape_normalization(lhs_dq, lhs_dn, lhs.layout == "N")
rhs_3d = _shape_normalization(rhs_dq, rhs_dn, rhs.layout == "T")
lhs_3d = _shape_normalization(lhs_dq, lhs_dn, lhs.data_layout == "N")
rhs_3d = _shape_normalization(rhs_dq, rhs_dn, rhs.data_layout == "T")
# _shape_normalization ensures contracting_dims=2 and batch_dims=0
dim_nums = (((2,), (2,)), ((0,), (0,)))
out_3d = jax.lax.dot_general(
lhs_3d, rhs_3d, dim_nums, precision=precision, preferred_element_type=lhs.dq_dtype
......@@ -203,9 +202,9 @@ def _jax_gemm_delayed_scaling_fp8(
), "rhs does not have delayed tensor scaling mode"
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums
if lhs.layout == "T":
if lhs.data_layout == "T":
lhs_contract = tuple((lhs.data.ndim - 1 - i) % lhs.data.ndim for i in lhs_contract)
if rhs.layout == "T":
if rhs.data_layout == "T":
rhs_contract = tuple((rhs.data.ndim - 1 - i) % rhs.data.ndim for i in rhs_contract)
lhs_dn = (lhs_contract, lhs_batch)
......@@ -403,19 +402,19 @@ def grouped_gemm(
lhs_shape = lhs.data.shape
rhs_shape = rhs.data.shape
out_dtype = lhs.dq_dtype
# For ScaledTensors and NVTE_DELAYED_TENSOR_SCALING, need to handle internal layout
# For ScaledTensors and NVTE_DELAYED_TENSOR_SCALING, need to handle internal data_layout
if lhs.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
assert not (
lhs.data.dtype == jnp.float8_e5m2 and rhs.data.dtype == jnp.float8_e5m2
), "FP8 GEMM does not support E5M2 * E5M2"
((lhs_contract_dim,), (rhs_contract_dim,)) = contracting_dims
if lhs.layout == "T":
if lhs.data_layout == "T":
lhs_contract_dim = (lhs_contract_dim - 1) % lhs.data.ndim
if rhs.layout == "T":
if rhs.data_layout == "T":
rhs_contract_dim = (rhs_contract_dim - 1) % rhs.data.ndim
dim_nums = ((lhs_contract_dim,), (rhs_contract_dim,)), ((), ())
else:
# For jnp.ndarray, only consider contracting_dims, layout is always NN
# For jnp.ndarray, only consider contracting_dims, data_layout is always NN
scaling_mode = ScalingMode.NVTE_NO_SCALING
lhs_shape = lhs.shape
rhs_shape = rhs.shape
......@@ -432,8 +431,8 @@ def grouped_gemm(
lhs_3d = _shape_normalization(lhs, lhs_dn)
rhs_3d = _shape_normalization(rhs, rhs_dn)
elif scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
lhs_3d = _shape_normalization(lhs.data, lhs_dn, lhs.layout == "N")
rhs_3d = _shape_normalization(rhs.data, rhs_dn, rhs.layout == "T")
lhs_3d = _shape_normalization(lhs.data, lhs_dn, lhs.data_layout == "N")
rhs_3d = _shape_normalization(rhs.data, rhs_dn, rhs.data_layout == "T")
elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING:
lhs_3d = _shape_normalization(lhs.data, lhs_dn)
rhs_3d = _shape_normalization(rhs.data, rhs_dn)
......
......@@ -19,7 +19,7 @@ from jax.interpreters.mlir import dtype_to_ir_type
import transformer_engine_jax
from ..sharding import get_padded_spec as te_get_padded_spec
from ..quantize import ScalingMode, ScaledTensorFactory, QuantizeAxis
from ..quantize import ScalingMode, ScaledTensorFactory, QuantizeLayout
TEDType = transformer_engine_jax.DType
......@@ -107,37 +107,37 @@ def normalize_axis_boundary(axis, ndim):
return axis if axis >= 0 else ndim + axis
def multidim_transpose(shape, static_axis_boundary=-1, transpose_axis_boundary=-1):
def multidim_transpose(shape, static_axis_boundary=-1, transpose_axis=-1):
"""
Multi-dimensional transpose helper for te_cast_transpose_p.
static_axis_boundary: int, axes <= static_axis_boundary are excluded from the transpose;
-1 means all axes participate in the transpose.
transpose_axis_boundary: int, Indicate how to split multi-dimensions tensors to 2D matrix for
transpose. Note, transpose_axis_boundary should be greater than static_axis_boundary
transpose_axis: int, the axis at which the multi-dimensional tensor is split into a 2D
matrix for the transpose. Note that transpose_axis must be greater than static_axis_boundary.
examples:
X in shape (dim0, dim1, dim2, dim3, dim4)
static_axis_boundary == -1, transpose_axis_boundary == 2
static_axis_boundary == -1, transpose_axis == 2
Xt = (dim2, dim3, dim4, dim0, dim1)
static_axis_boundary == 0, transpose_axis_boundary == 2
static_axis_boundary == 0, transpose_axis == 2
Xt = (dim0, dim2, dim3, dim4, dim1)
static_axis_boundary == 0, transpose_axis_boundary == 3
static_axis_boundary == 0, transpose_axis == 3
Xt = (dim0, dim3, dim4, dim1, dim2)
"""
if static_axis_boundary < 0:
static_axis_boundary = -1 # means no static axes
assert static_axis_boundary < len(shape) - 2 # at least 2 remaining for transpose.
transpose_start_idx = static_axis_boundary + 1
transpose_axis_boundary = normalize_axis_boundary(transpose_axis_boundary, len(shape))
assert transpose_start_idx < transpose_axis_boundary
transpose_axis = normalize_axis_boundary(transpose_axis, len(shape))
assert transpose_start_idx < transpose_axis
return (
*shape[:transpose_start_idx],
*shape[transpose_axis_boundary:],
*shape[transpose_start_idx:transpose_axis_boundary],
*shape[transpose_axis:],
*shape[transpose_start_idx:transpose_axis],
)
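A minimal usage sketch of the shape behaviour documented above (illustration only, assuming multidim_transpose is imported from this module; not part of the change):
shape = (2, 3, 4, 5, 6)  # (dim0, dim1, dim2, dim3, dim4)
assert multidim_transpose(shape, static_axis_boundary=-1, transpose_axis=2) == (4, 5, 6, 2, 3)
assert multidim_transpose(shape, static_axis_boundary=0, transpose_axis=2) == (2, 4, 5, 6, 3)
assert multidim_transpose(shape, static_axis_boundary=0, transpose_axis=3) == (2, 5, 6, 3, 4)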
......@@ -195,13 +195,13 @@ def should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias: bool = False, quant
break
return (
quantizer is not None
and quantizer.q_axis == QuantizeAxis.ROWWISE
and quantizer.q_layout == QuantizeLayout.ROWWISE
and arch_l_100
and is_dbias
)
def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, **kwargs):
def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, flatten_axis=-1, **kwargs):
"""
Applies a workaround for 2x delayed scaling; use it when the TE common kernels do not yet support 2x delayed scaling.
It calls the given function 'f' with the given arguments and the quantizer set to 1x, then computes the colwise output by transposing the rowwise result.
......@@ -224,14 +224,19 @@ def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, **kwargs):
# 2x is not supported by TE kernels for delayed scaling
# so revert to 1x and transpose in JAX
quantizer.q_axis = QuantizeAxis.ROWWISE
quantizer.q_layout = QuantizeLayout.ROWWISE
rowwise = f(*args, **kwargs, quantizer=quantizer)
other_outputs = None
if isinstance(rowwise, tuple):
other_outputs = rowwise[1:]
rowwise = rowwise[0]
quantizer.q_axis = QuantizeAxis.ROWWISE_COLWISE
colwise_data = jnp.transpose(rowwise.data, (-1, *range(rowwise.data.ndim - 1)))
quantizer.q_layout = QuantizeLayout.ROWWISE_COLWISE
if flatten_axis < 0:
flatten_axis += rowwise.data.ndim
assert 0 < flatten_axis < rowwise.data.ndim, "flatten_axis is out of bounds"
colwise_data = jnp.transpose(
rowwise.data, (*range(flatten_axis, rowwise.data.ndim), *range(flatten_axis))
)
output_2x = ScaledTensorFactory.create(
data=rowwise.data,
scale_inv=rowwise.scale_inv,
......@@ -239,8 +244,9 @@ def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, **kwargs):
colwise_scale_inv=rowwise.scale_inv,
scaling_mode=quantizer.scaling_mode,
dq_dtype=rowwise.dq_dtype,
q_axis=QuantizeAxis.ROWWISE_COLWISE,
layout=quantizer.get_layout(),
q_layout=QuantizeLayout.ROWWISE_COLWISE,
data_layout=quantizer.get_data_layout(),
flatten_axis=flatten_axis,
)
if other_outputs is not None:
return (output_2x,) + other_outputs
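For reference, a small self-contained sketch (hypothetical sizes, illustration only) of the flatten-aware transpose used above to build the colwise data:
import jax.numpy as jnp

x = jnp.zeros((4, 8, 2, 16))  # e.g. (batch, seq, act, hidden)
flatten_axis = 2              # negative values are normalized to ndim + flatten_axis, as above
xt = jnp.transpose(x, (*range(flatten_axis, x.ndim), *range(flatten_axis)))
assert xt.shape == (2, 16, 4, 8)  # the trailing (act, hidden) block moves to the front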
......
......@@ -30,7 +30,7 @@ from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_a
from ..quantize import ScaledTensor, ScaledTensorFactory
from ..quantize import (
Quantizer,
QuantizeAxis,
QuantizeLayout,
DelayedScaleQuantizer,
ScalingMode,
)
......@@ -277,13 +277,13 @@ class NormFwdPrimitive(BasePrimitive):
rowwise_scale_inv_shape, colwise_scale_inv_shape = scaling_mode.get_scale_shape_2x(
x.shape, is_padded=False
)
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING:
scale_inv = scale_inv.flatten()[
: reduce(operator.mul, rowwise_scale_inv_shape)
].reshape(rowwise_scale_inv_shape)
# slice out padding for mxfp8, noop for DelayedScaling
scale_inv = scale_inv.flatten()[: reduce(operator.mul, rowwise_scale_inv_shape, 1)].reshape(
rowwise_scale_inv_shape
)
if is_2x:
colwise_scale_inv = colwise_scale_inv.flatten()[
: reduce(operator.mul, colwise_scale_inv_shape)
: reduce(operator.mul, colwise_scale_inv_shape, 1)
].reshape(colwise_scale_inv_shape)
return (
out,
......@@ -816,7 +816,7 @@ def layernorm_fwd(
return _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer)
# TE/common does not support normalization with colwise only quantization yet
if quantizer is not None and quantizer.q_axis == QuantizeAxis.COLWISE:
if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE:
return _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer)
scale = (
......@@ -900,8 +900,8 @@ def layernorm_fwd(
colwise_scale_inv=colwise_scale_inv,
scaling_mode=quantizer.scaling_mode,
dq_dtype=x.dtype,
q_axis=quantizer.q_axis,
layout=quantizer.get_layout(),
q_layout=quantizer.q_layout,
data_layout=quantizer.get_data_layout(),
)
return scaled_tensor, mu, rsigma
......@@ -997,7 +997,7 @@ def rmsnorm_fwd(
return _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer)
# TE/common does not support normalization with colwise only quantization yet
if quantizer is not None and quantizer.q_axis == QuantizeAxis.COLWISE:
if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE:
return _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer)
scale = (
......@@ -1082,8 +1082,8 @@ def rmsnorm_fwd(
colwise_scale_inv=colwise_scale_inv,
scaling_mode=quantizer.scaling_mode,
dq_dtype=x.dtype,
q_axis=quantizer.q_axis,
layout=quantizer.get_layout(),
q_layout=quantizer.q_layout,
data_layout=quantizer.get_data_layout(),
)
return scaled_tensor, rsigma
......
......@@ -2,6 +2,8 @@
#
# See LICENSE for license information.
"""JAX/TE custom ops for quantization"""
import operator
from functools import reduce
from typing import Tuple, Optional
from packaging import version
......@@ -24,7 +26,7 @@ from .misc import (
)
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
from ..quantize import ScaledTensor2x, ScaledTensor, ScaledTensorFactory
from ..quantize import Quantizer, QuantizeAxis, DelayedScaleQuantizer, ScalingMode
from ..quantize import Quantizer, QuantizeLayout, DelayedScaleQuantizer, ScalingMode
if version.parse(jax.__version__) >= version.parse("0.5.0"):
from jax import ffi # pylint: disable=ungrouped-imports
......@@ -50,7 +52,8 @@ class DBiasQuantizePrimitive(BasePrimitive):
6,
7,
8,
) # out_dtype, scaling_mode, q_axis, scale_dtype, scale_shapes, is_dbias, is_outer
9,
) # out_dtype, scaling_mode, q_layout, flatten_axis, scale_dtype, scale_shapes, is_dbias, is_outer
inner_primitive = None
outer_primitive = None
......@@ -61,7 +64,8 @@ class DBiasQuantizePrimitive(BasePrimitive):
*,
out_dtype,
scaling_mode,
q_axis,
q_layout,
flatten_axis,
scale_dtype,
scale_shapes,
is_dbias,
......@@ -73,49 +77,52 @@ class DBiasQuantizePrimitive(BasePrimitive):
del scale_shapes
dtype = dtypes.canonicalize_dtype(x_aval.dtype)
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
out_shape = x_aval.shape
assert scale_aval is None or scale_aval.dtype == jnp.float32
rowwise_out_aval = jax.core.ShapedArray(shape=(1,), dtype=out_dtype)
if q_axis in (QuantizeAxis.ROWWISE.value, QuantizeAxis.ROWWISE_COLWISE.value):
rowwise_out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype)
if q_layout in (QuantizeLayout.ROWWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
rowwise_out_shape = out_shape
else:
rowwise_out_shape = (1,)
rowwise_out_aval = jax.core.ShapedArray(shape=rowwise_out_shape, dtype=out_dtype)
updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
scaling_mode
).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer)
).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=flatten_axis)
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
colwise_out_shape = multidim_transpose(out_shape, transpose_axis=flatten_axis)
else:
colwise_out_shape = out_shape
else:
colwise_out_shape = (1,)
colwise_scale_inv_shape = (1,)
colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype)
scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
colwise_out_aval = jax.core.ShapedArray(shape=(1,), dtype=out_dtype)
colwise_scale_inv_aval = jax.core.ShapedArray(shape=(1,), dtype=scale_dtype)
dbias_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
wkspace_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
if q_axis in (QuantizeAxis.COLWISE.value, QuantizeAxis.ROWWISE_COLWISE.value):
t_shape = multidim_transpose(x_aval.shape)
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
# Don't transpose output for MXFP8
t_shape = x_aval.shape
colwise_out_aval = x_aval.update(shape=t_shape, dtype=out_dtype)
colwise_scale_inv_aval = jax.core.ShapedArray(
shape=colwise_scale_inv_shape, dtype=scale_dtype
)
if is_dbias:
gi_hidden_size = x_aval.shape[-1]
dbias_shape = (gi_hidden_size,)
dbias_aval = x_aval.update(shape=dbias_shape, dtype=dtype)
dbias_shape = x_aval.shape[flatten_axis:]
gi_hidden_size = reduce(operator.mul, x_aval.shape[flatten_axis:], 1)
(wkspace_info,) = transformer_engine_jax.get_dbias_quantize_workspace_sizes(
x_aval.size // gi_hidden_size,
gi_hidden_size,
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(out_dtype),
)
wkspace_aval = x_aval.update(
shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
)
wkspace_shape = wkspace_info[0]
wkspace_dtype = te_dtype_to_jax_dtype(wkspace_info[1])
else:
dbias_shape = (1,)
wkspace_shape = (1,)
wkspace_dtype = jnp.float32
dbias_aval = jax.core.ShapedArray(shape=dbias_shape, dtype=dtype)
wkspace_aval = jax.core.ShapedArray(shape=wkspace_shape, dtype=wkspace_dtype)
return (
rowwise_out_aval,
......@@ -151,7 +158,8 @@ class DBiasQuantizePrimitive(BasePrimitive):
*,
out_dtype,
scaling_mode,
q_axis,
q_layout,
flatten_axis,
scale_dtype,
scale_shapes,
is_dbias,
......@@ -169,7 +177,8 @@ class DBiasQuantizePrimitive(BasePrimitive):
x,
scale,
scaling_mode=scaling_mode,
q_axis=q_axis,
q_layout=q_layout,
flatten_axis=flatten_axis,
is_dbias=is_dbias,
)
......@@ -179,7 +188,8 @@ class DBiasQuantizePrimitive(BasePrimitive):
scale,
out_dtype,
scaling_mode,
q_axis,
q_layout,
flatten_axis,
scale_dtype,
scale_shapes,
is_dbias,
......@@ -203,7 +213,8 @@ class DBiasQuantizePrimitive(BasePrimitive):
scale,
out_dtype=out_dtype,
scaling_mode=scaling_mode,
q_axis=q_axis,
q_layout=q_layout,
flatten_axis=flatten_axis,
scale_dtype=scale_dtype,
scale_shapes=scale_shapes,
is_dbias=is_dbias,
......@@ -211,13 +222,11 @@ class DBiasQuantizePrimitive(BasePrimitive):
)
rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
scaling_mode
).get_scale_shape_2x(x.shape, is_padded=False)
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
if q_axis in (QuantizeAxis.ROWWISE.value, QuantizeAxis.ROWWISE_COLWISE.value):
).get_scale_shape_2x(x.shape, is_padded=False, flatten_axis=flatten_axis)
scale_inv = jax.lax.slice(
scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
)
if q_axis in (QuantizeAxis.COLWISE.value, QuantizeAxis.ROWWISE_COLWISE.value):
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
colwise_scale_inv = jax.lax.slice(
colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
)
......@@ -237,7 +246,8 @@ class DBiasQuantizePrimitive(BasePrimitive):
*,
out_dtype,
scaling_mode,
q_axis,
q_layout,
flatten_axis,
scale_dtype,
scale_shapes,
is_dbias,
......@@ -260,7 +270,8 @@ class DBiasQuantizePrimitive(BasePrimitive):
scale,
out_dtype=out_dtype,
scaling_mode=scaling_mode,
q_axis=q_axis,
q_layout=q_layout,
flatten_axis=flatten_axis,
scale_dtype=scale_dtype,
scale_shapes=scale_shapes,
is_dbias=is_dbias,
......@@ -272,7 +283,8 @@ class DBiasQuantizePrimitive(BasePrimitive):
def infer_sharding_from_operands(
out_dtype,
scaling_mode,
q_axis,
q_layout,
flatten_axis,
scale_dtype,
scale_shapes,
is_dbias,
......@@ -281,16 +293,17 @@ class DBiasQuantizePrimitive(BasePrimitive):
arg_infos,
result_infos,
):
del (out_dtype, result_infos, scale_dtype, scale_shapes, is_dbias, is_outer) # Unused.
del (out_dtype, result_infos, scale_dtype, scale_shapes, is_outer) # Unused.
x_spec = get_padded_spec(arg_infos[0])
scale_spec = get_padded_spec(arg_infos[1])
out_sharding = NamedSharding(
mesh,
PartitionSpec(*x_spec[:-1], x_spec[-1]),
PartitionSpec(*x_spec),
desc="DBiasQuantizePrimitive.out_sharding",
)
if q_axis in (QuantizeAxis.COLWISE.value, QuantizeAxis.ROWWISE_COLWISE.value):
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
colwise_out_spec = multidim_transpose(x_spec)
colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis)
else:
colwise_out_spec = x_spec
else:
......@@ -300,26 +313,35 @@ class DBiasQuantizePrimitive(BasePrimitive):
PartitionSpec(*colwise_out_spec),
desc="DBiasQuantizePrimitive.colwise_out_sharding",
)
scale_inv_sharding = NamedSharding(
dbias_spec = x_spec[flatten_axis:] if is_dbias else (None,)
dbias_sharding = NamedSharding(
mesh,
PartitionSpec(*get_padded_spec(arg_infos[1])),
desc="DBiasQuantizePrimitive.scale_inv",
)
amax_sharding = scale_inv_sharding.duplicate_with_new_description(
desc="DBiasQuantizePrimitive.amax_sharding"
PartitionSpec(*dbias_spec),
desc="DBiasQuantizePrimitive.dbias_sharding",
)
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
scale_inv_spec = amax_spec = scale_spec
elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
scale_inv_spec = x_spec
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
colwise_scale_inv_spec = scale_inv_spec
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec), desc="DBiasQuantizePrimitive.scale_inv"
mesh, PartitionSpec(*scale_inv_spec), desc="DBiasQuantizePrimitive.scale_inv"
)
colwise_scale_inv_sharding = scale_inv_sharding.duplicate_with_new_description(
"DBiasQuantizePrimitive.colwise_scale_inv"
amax_sharding = NamedSharding(
mesh, PartitionSpec(*amax_spec), desc="DBiasQuantizePrimitive.amax"
)
dbias_sharding = NamedSharding(
colwise_scale_inv_sharding = NamedSharding(
mesh,
PartitionSpec(x_spec[-1]),
desc="DBiasQuantizePrimitive.dbias_sharding",
PartitionSpec(*colwise_scale_inv_spec),
desc="DBiasQuantizePrimitive.colwise_scale_inv",
)
return (
out_sharding,
colwise_out_sharding,
......@@ -333,7 +355,8 @@ class DBiasQuantizePrimitive(BasePrimitive):
def partition(
out_dtype,
scaling_mode,
q_axis,
q_layout,
flatten_axis,
scale_dtype,
scale_shapes,
is_dbias,
......@@ -344,14 +367,15 @@ class DBiasQuantizePrimitive(BasePrimitive):
):
del result_infos, is_outer
x_spec = get_padded_spec(arg_infos[0])
scale_spec = get_padded_spec(arg_infos[1])
out_sharding = NamedSharding(
mesh,
PartitionSpec(*x_spec[:-1], x_spec[-1]),
PartitionSpec(*x_spec),
desc="DBiasQuantizePrimitive.out_sharding",
)
if q_axis in (QuantizeAxis.COLWISE.value, QuantizeAxis.ROWWISE_COLWISE.value):
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
colwise_out_spec = multidim_transpose(x_spec)
colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis)
else:
colwise_out_spec = x_spec
else:
......@@ -361,26 +385,35 @@ class DBiasQuantizePrimitive(BasePrimitive):
PartitionSpec(*colwise_out_spec),
desc="DBiasQuantizePrimitive.colwise_out_sharding",
)
scale_inv_sharding = NamedSharding(
dbias_spec = x_spec[flatten_axis:] if is_dbias else (None,)
dbias_sharding = NamedSharding(
mesh,
PartitionSpec(*get_padded_spec(arg_infos[1])),
desc="DBiasQuantizePrimitive.scale_inv",
)
amax_sharding = scale_inv_sharding.duplicate_with_new_description(
desc="DBiasQuantizePrimitive.amax_sharding"
PartitionSpec(*dbias_spec),
desc="DBiasQuantizePrimitive.dbias_sharding",
)
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
scale_inv_spec = amax_spec = scale_spec
elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
scale_inv_spec = x_spec
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
colwise_scale_inv_spec = scale_inv_spec
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec), desc="DBiasQuantizePrimitive.scale_inv"
mesh, PartitionSpec(*scale_inv_spec), desc="DBiasQuantizePrimitive.scale_inv"
)
colwise_scale_inv_sharding = scale_inv_sharding.duplicate_with_new_description(
"DBiasQuantizePrimitive.colwise_scale_inv"
amax_sharding = NamedSharding(
mesh, PartitionSpec(*amax_spec), desc="DBiasQuantizePrimitive.amax"
)
dbias_sharding = NamedSharding(
colwise_scale_inv_sharding = NamedSharding(
mesh,
PartitionSpec(x_spec[-1]),
desc="DBiasQuantizePrimitive.dbias_sharding",
PartitionSpec(*colwise_scale_inv_spec),
desc="DBiasQuantizePrimitive.colwise_scale_inv",
)
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_shardings = (
out_sharding,
......@@ -404,7 +437,8 @@ class DBiasQuantizePrimitive(BasePrimitive):
scale,
out_dtype=out_dtype,
scaling_mode=scaling_mode,
q_axis=q_axis,
q_layout=q_layout,
flatten_axis=flatten_axis,
scale_dtype=scale_dtype,
scale_shapes=scale_shapes,
is_dbias=is_dbias,
......@@ -436,49 +470,45 @@ class DBiasQuantizePrimitive(BasePrimitive):
register_primitive(DBiasQuantizePrimitive)
def _jax_quantize(x, quantizer: Quantizer = None, dq_dtype: Optional[jnp.dtype] = None):
def _jax_quantize(
x, quantizer: Quantizer = None, dq_dtype: Optional[jnp.dtype] = None, flatten_axis: int = -1
):
if quantizer is None:
return x
return quantizer.quantize(x, dq_dtype=dq_dtype)
return quantizer.quantize(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis)
def _jax_dbias(dx: jnp.ndarray):
def _jax_dbias(dx: jnp.ndarray, dtype=None, flatten_axis: int = -1):
assert flatten_axis < 0
dtype = dtype or dx.dtype
dbias = jnp.sum(
dx,
axis=tuple(range(dx.ndim - 1)),
dx.astype(jnp.float32),
axis=tuple(range(dx.ndim + flatten_axis)),
keepdims=False,
)
dbias = dbias.ravel() # C++ function returns an 1D array for dbias
return dbias
return dbias.astype(dtype)
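A short sketch (hypothetical shapes, not part of the change) of how flatten_axis determines the dbias shape in the pure-JAX reference above:
import jax.numpy as jnp

dx = jnp.ones((4, 8, 2, 16))  # (batch, seq, act, hidden)
flatten_axis = -2
dbias = jnp.sum(dx.astype(jnp.float32), axis=tuple(range(dx.ndim + flatten_axis)))
assert dbias.shape == (2, 16)  # i.e. dx.shape[flatten_axis:]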
def _jax_quantize_dbias(
x,
quantizer: Quantizer = None,
dq_dtype: Optional[jnp.dtype] = None,
flatten_axis: int = -1,
):
if quantizer is None:
return x, None
return quantizer.quantize(x, dq_dtype=dq_dtype), _jax_dbias(x)
def _jax_dbias(
dx: jnp.ndarray,
):
dbias = jnp.sum(
dx.astype(jnp.float32),
axis=tuple(range(dx.ndim - 1)),
keepdims=False,
return (
quantizer.quantize(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis),
_jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis),
)
dbias = dbias.ravel() # C++ function returns an 1D array for dbias
return dbias.astype(dx.dtype)
def _quantize_impl(
def _quantize_dbias_impl(
x: jnp.ndarray,
quantizer: Quantizer,
is_dbias: bool = False,
dq_dtype: Optional[jnp.dtype] = None,
flatten_axis: int = -1,
) -> Tuple[ScaledTensor2x, jnp.ndarray]:
"""
Cast wrapper
......@@ -488,40 +518,51 @@ def _quantize_impl(
quantizer is not None
), "quantizer must be provided if dq_dtype is provided"
dq_dtype = dq_dtype or x.dtype
if not DBiasQuantizePrimitive.enabled():
if is_dbias:
return _jax_quantize_dbias(
x,
quantizer=quantizer,
dq_dtype=dq_dtype,
flatten_axis=flatten_axis,
)
return (
_jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype, flatten_axis=flatten_axis),
None,
)
return _jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype), None
# TE/common doesn't support colwise only quantization yet
if quantizer is not None and quantizer.q_axis == QuantizeAxis.COLWISE:
if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE:
if is_dbias:
return _jax_quantize_dbias(
x,
quantizer=quantizer,
dq_dtype=dq_dtype,
flatten_axis=flatten_axis,
)
return (
_jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype, flatten_axis=flatten_axis),
None,
)
return _jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype), None
scale = jnp.empty((), jnp.float32)
# TE/common dbias_quantize does not support 1x on arch < 100
if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer):
out, _ = _quantize_impl(
out, _ = _quantize_dbias_impl(
x=x,
is_dbias=False,
quantizer=quantizer,
dq_dtype=dq_dtype,
flatten_axis=flatten_axis,
)
dbias = _jax_dbias(x)
dbias = _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis)
return out, dbias
if quantizer is None:
if is_dbias:
return x, _jax_dbias(x)
return x, _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis)
return x, None
if isinstance(quantizer, DelayedScaleQuantizer):
......@@ -539,9 +580,10 @@ def _quantize_impl(
scale,
out_dtype=quantizer.q_dtype,
scaling_mode=quantizer.scaling_mode.value,
q_axis=quantizer.q_axis.value,
q_layout=quantizer.q_layout.value,
flatten_axis=flatten_axis,
scale_dtype=quantizer.get_scale_dtype(),
scale_shapes=quantizer.get_scale_shapes(x.shape),
scale_shapes=quantizer.get_scale_shapes(x.shape, flatten_axis=flatten_axis),
is_dbias=is_dbias,
is_outer=True,
)
......@@ -557,18 +599,18 @@ def _quantize_impl(
colwise_data=colwise_casted_output,
colwise_scale_inv=colwise_scale_inv,
scaling_mode=quantizer.scaling_mode,
dq_dtype=dq_dtype if dq_dtype is not None else x.dtype,
q_axis=quantizer.q_axis,
layout=quantizer.get_layout(),
dq_dtype=dq_dtype,
q_layout=quantizer.q_layout,
data_layout=quantizer.get_data_layout(),
flatten_axis=flatten_axis,
)
return out, dbias
return out, dbias.astype(dq_dtype)
# TODO(Phuong): do not expose dq_dtype to users
def quantize(
x: jnp.ndarray,
quantizer: Quantizer,
dq_dtype: Optional[jnp.dtype] = None,
flatten_axis: int = -1,
) -> Tuple[ScaledTensor]:
"""Quantize input tensor according to the quantizer.
......@@ -576,26 +618,25 @@ def quantize(
x: Input tensor to be quantized.
Shape: (..., K) where K is the hidden size.
quantizer: Quantizer for FP8 quantization of the output.
dq_dtype: Optional dtype for dequantization.
If None, uses the same dtype as the input tensor.
flatten_axis: The axis at which the input is flattened to a 2D matrix for quantization.
Defaults to -1.
Returns:
A ScaledTensor containing the quantized input tensor.
"""
out, _ = _quantize_impl(
out, _ = _quantize_dbias_impl(
x,
quantizer=quantizer,
dq_dtype=dq_dtype,
flatten_axis=flatten_axis,
)
return out
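As an illustration of the flatten_axis argument (assumed shapes only, not part of the change), the input is viewed as a 2D matrix split at flatten_axis when scale shapes are computed:
import numpy as np

shape, flatten_axis = (4, 64, 2, 16), -2  # e.g. a gated-activation tensor (B, S, ACT_DIM, H)
rows = int(np.prod(shape[:flatten_axis]))
cols = int(np.prod(shape[flatten_axis:]))
assert (rows, cols) == (256, 32)  # quantization sees a (B*S, ACT_DIM*H) matrix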
# TODO(Phuong): do not expose dq_dtype to users
def quantize_dbias(
dz: jnp.ndarray,
quantizer: Quantizer,
is_dbias: bool = True,
dq_dtype: Optional[jnp.dtype] = None,
flatten_axis: int = -1,
) -> Tuple[ScaledTensor2x, jnp.ndarray]:
"""Quantize input tensor and compute bias gradient.
......@@ -604,8 +645,8 @@ def quantize_dbias(
Shape: (..., K) where K is the hidden size.
quantizer: Quantizer for FP8 quantization of the output.
is_dbias: If True, compute bias gradient. Defaults to True.
dq_dtype: Optional dtype for dequantization.
If None, uses the same dtype as the input tensor.
flatten_axis: The axis at which the input is flattened to a 2D matrix for quantization.
Defaults to -1.
Returns:
A tuple containing:
......@@ -614,9 +655,6 @@ def quantize_dbias(
- The bias gradient tensor.
Shape: (K,) or empty if is_dbias is False.
"""
return _quantize_impl(
dz,
quantizer=quantizer,
is_dbias=is_dbias,
dq_dtype=dq_dtype,
return _quantize_dbias_impl(
dz, quantizer=quantizer, is_dbias=is_dbias, flatten_axis=flatten_axis
)
......@@ -11,14 +11,6 @@
#include "transformer_engine/cast.h"
#include "xla/ffi/api/c_api.h"
namespace {
bool is_gated(NVTE_Activation_Type act_type) {
return act_type == NVTE_Activation_Type::GEGLU || act_type == NVTE_Activation_Type::SWIGLU ||
act_type == NVTE_Activation_Type::REGLU || act_type == NVTE_Activation_Type::QGEGLU ||
act_type == NVTE_Activation_Type::SREGLU;
}
} // namespace
namespace transformer_engine {
namespace jax {
......@@ -44,38 +36,56 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal
auto act_len = input_dims[input_dims.size() - 2];
auto scaling_mode = static_cast<NVTEScalingMode>(scaling_mode_enum);
auto is_2x = static_cast<bool>(is_2x_int);
auto flatten_axis = output_buf->dimensions().size() - 1; // output does not have act axis
auto input_shape = std::vector<size_t>{m, act_len * n};
auto output_shape = std::vector<size_t>{m, n};
auto output_trans_shape = std::vector<size_t>{n, m};
auto input_tensor = TensorWrapper(input, input_shape, static_cast<DType>(in_dtype));
auto output_tensor = TensorWrapper(scaling_mode);
output_tensor.set_rowwise_data(output, static_cast<DType>(out_dtype), output_shape);
if (is_fp8_dtype(out_dtype)) {
output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
std::vector<size_t>{
product(scale_inv_buf->dimensions(), 0, scale_inv_buf->dimensions().size() - 1),
scale_inv_buf->dimensions().back()});
}
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING && is_fp8_dtype(out_dtype)) {
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling");
cudaMemsetAsync(amax, 0, sizeof(float), stream);
output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), std::vector<size_t>{1});
} else {
output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
std::vector<size_t>{product(scale_inv_buf->dimensions(), 0, flatten_axis),
product(scale_inv_buf->dimensions(), flatten_axis,
scale_inv_buf->dimensions().size())});
}
}
if (is_2x) {
output_tensor.set_columnwise_data(colwise_output, static_cast<DType>(out_dtype), output_shape);
auto &tmp_shape =
(scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? output_trans_shape : output_shape;
output_tensor.set_columnwise_data(colwise_output, out_dtype, tmp_shape);
if (is_fp8_dtype(out_dtype)) {
// For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling
auto &tmp_buf =
(scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? scale_inv_buf : colwise_scale_inv_buf;
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
output_tensor.set_columnwise_scale_inv(
tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()),
std::vector<size_t>{1});
} else {
output_tensor.set_columnwise_scale_inv(
colwise_scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(colwise_scale_inv_buf->element_type()),
std::vector<size_t>{product(colwise_scale_inv_buf->dimensions(), 0,
colwise_scale_inv_buf->dimensions().size() - 1),
colwise_scale_inv_buf->dimensions().back()});
tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()),
std::vector<size_t>{
product(tmp_buf->dimensions(), 0, flatten_axis),
product(tmp_buf->dimensions(), flatten_axis, tmp_buf->dimensions().size())});
}
}
}
switch (act_type) {
......@@ -162,8 +172,10 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid
}
if (is_2x) {
output_tensor.set_columnwise_data(reinterpret_cast<void *>(&temp), out_dtype,
output_trans_shape);
auto &tmp_shape = scaling_mode == static_cast<int>(NVTE_DELAYED_TENSOR_SCALING)
? output_trans_shape
: output_shape;
output_tensor.set_columnwise_data(reinterpret_cast<void *>(&temp), out_dtype, tmp_shape);
// Only the pointers will be checked for scale_inv, thus the shapes do not matter
if (is_fp8_dtype(out_dtype)) {
......@@ -190,9 +202,9 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid
Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
Buffer_Type act_input_buf, Buffer_Type scale_buf,
Result_Type output_buf, Result_Type output_trans_buf,
Result_Type scale_inv_buf, Result_Type trans_scale_inv_buf,
Result_Type amax_out_buf, Result_Type dbias_buf,
Result_Type output_buf, Result_Type colwise_output_buf,
Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
Result_Type amax_buf, Result_Type dbias_buf,
Result_Type workspace_buf, int64_t scaling_mode_enum, bool is_2x,
bool is_dbias, int64_t act_enum) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
......@@ -201,11 +213,15 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
auto *input = input_buf.untyped_data();
auto *act_input = act_input_buf.untyped_data();
float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
float *amax = reinterpret_cast<float *>(amax_buf->untyped_data());
auto scaling_mode = static_cast<NVTEScalingMode>(scaling_mode_enum);
auto act_type = static_cast<NVTE_Activation_Type>(act_enum);
auto flatten_axis = output_buf->dimensions().size() - 2; // output has act axis
auto *output = output_buf->untyped_data();
auto *output_trans = output_trans_buf->untyped_data();
auto *colwise_output = colwise_output_buf->untyped_data();
auto *dbias = dbias_buf->untyped_data();
void *workspace = workspace_buf->untyped_data();
......@@ -213,17 +229,18 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
auto act_input_dims = act_input_buf.dimensions();
auto workspace_dims = workspace_buf->dimensions();
// m = x_batch_size = reduce(operator.mul, x_shape[:-2]), x_shape == act_input_dims
// n = ir_dz_shape[-1], ir_dz_shape == input_dims
auto input_ranks = input_dims.size();
auto act_input_ranks = act_input_dims.size();
auto m = product(act_input_dims, 0, act_input_dims.size() - 1);
// 'n' will be 2x the size of input_dims.back() if the dactivation is dgated
auto n = act_input_dims.back();
auto input_shape = std::vector<size_t>{m, input_dims.back()};
auto act_input_shape = std::vector<size_t>{m, n};
auto output_shape = std::vector<size_t>{m, n};
auto output_trans_shape = std::vector<size_t>{m, n};
auto dbias_shape = std::vector<size_t>{n};
// n = ir_dz_shape[-1] * act_len, ir_dz_shape == input_dims
auto act_len = act_input_dims[act_input_dims.size() - 2];
NVTE_CHECK(act_input_dims.back() == input_dims.back(),
"Shape mismatch between activation input and gradient input");
auto m = product(act_input_dims, 0, act_input_dims.size() - 2);
auto n = input_dims.back();
auto input_shape = std::vector<size_t>{m, n};
auto act_input_shape = std::vector<size_t>{m, n * act_len};
auto output_shape = std::vector<size_t>{m, n * act_len};
auto output_trans_shape = std::vector<size_t>{n * act_len, m};
auto dbias_shape = std::vector<size_t>{n * act_len};
std::vector<size_t> workspace_shape(workspace_dims.begin(), workspace_dims.end());
auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
......@@ -231,49 +248,55 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
auto output_tensor = TensorWrapper(scaling_mode);
output_tensor.set_rowwise_data(output, out_dtype, output_shape);
if (is_fp8_dtype(out_dtype)) {
output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
std::vector<size_t>{
product(scale_inv_buf->dimensions(), 0, scale_inv_buf->dimensions().size() - 1),
scale_inv_buf->dimensions().back()});
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
float *amax_out = reinterpret_cast<float *>(amax_out_buf->untyped_data());
NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
NVTE_CHECK(amax_out != nullptr, "amax must be provided for delayed tensor scaling");
cudaMemsetAsync(amax_out, 0, sizeof(float), stream);
NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling");
cudaMemsetAsync(amax, 0, sizeof(float), stream);
output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_amax(amax_out, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), std::vector<size_t>{1});
} else {
output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
std::vector<size_t>{product(scale_inv_buf->dimensions(), 0, flatten_axis),
product(scale_inv_buf->dimensions(), flatten_axis,
scale_inv_buf->dimensions().size())});
}
}
if (is_2x) {
output_tensor.set_columnwise_data(output_trans, out_dtype, output_trans_shape);
auto &tmp_shape =
(scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? output_trans_shape : output_shape;
output_tensor.set_columnwise_data(colwise_output, out_dtype, tmp_shape);
if (is_fp8_dtype(out_dtype)) {
// For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling
auto &colwise_scale_inv_buf =
(scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? scale_inv_buf : trans_scale_inv_buf;
auto &tmp_buf =
(scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? scale_inv_buf : colwise_scale_inv_buf;
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
output_tensor.set_columnwise_scale_inv(
colwise_scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(colwise_scale_inv_buf->element_type()),
std::vector<size_t>{product(colwise_scale_inv_buf->dimensions(), 0,
colwise_scale_inv_buf->dimensions().size() - 1),
colwise_scale_inv_buf->dimensions().back()});
tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()),
std::vector<size_t>{1});
} else {
output_tensor.set_columnwise_scale_inv(
tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()),
std::vector<size_t>{
product(tmp_buf->dimensions(), 0, flatten_axis),
product(tmp_buf->dimensions(), flatten_axis, tmp_buf->dimensions().size())});
}
}
}
auto dbias_tensor = TensorWrapper(dbias, dbias_shape, in_dtype);
auto workspace_tensor = TensorWrapper(workspace, workspace_shape, workspace_dtype);
auto act_type = static_cast<NVTE_Activation_Type>(act_enum);
// fused_dgated_dbias is not available, so we use dact_lu + quantize_dbias in Python instead
NVTE_CHECK(!(is_gated(act_type) && is_dbias), "Unsupported DGatedActedDBias Fusion!");
NVTE_CHECK(!(scaling_mode == NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING && is_2x &&
is_gated(act_type)),
NVTE_CHECK(!(act_len == 2 && is_dbias), "Unsupported DGatedActedDBias Fusion!");
NVTE_CHECK(
!(scaling_mode == NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING && is_2x && act_len == 2),
"TE/common does not support delayed scaling for 2x with gated activations.");
if (is_dbias) {
......
......@@ -44,12 +44,12 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh
cudaStreamSynchronize(stream);
// Notes on matrix layouts and transpose:
// Jax uses row-major layout, on entering this function, each input matrix pair:
// Jax uses row-major data_layout, on entering this function, each input matrix pair:
// A: row-major with size [m, k],
// B: row-major with size [n, k], needs transpose,
// on exiting this function, JAX expect:
// C: row-major with size [m, n].
// cuBLAS uses column-major layout, in this view, each input matrix pair:
// cuBLAS uses column-major data_layout, in this view, each input matrix pair:
// A: column-major with size [k, m], needs transpose,
// B: column-major with size [k, n].
// If we call cuBLAS GEMM for A * B, the output will be:
......
......@@ -34,7 +34,7 @@ inline size_t product(const std::vector<size_t> &shape) {
return ret;
}
enum class QuantizeAxis {
enum class QuantizeLayout {
ROWWISE,
COLWISE,
ROWWISE_COLWISE,
......
......@@ -144,11 +144,11 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
.value("NVTE_INVALID_SCALING", NVTEScalingMode::NVTE_MXFP8_1D_SCALING)
.export_values();
pybind11::enum_<transformer_engine::jax::QuantizeAxis>(m, "QuantizeAxis",
pybind11::enum_<transformer_engine::jax::QuantizeLayout>(m, "QuantizeLayout",
pybind11::module_local())
.value("ROWWISE", transformer_engine::jax::QuantizeAxis::ROWWISE)
.value("COLWISE", transformer_engine::jax::QuantizeAxis::COLWISE)
.value("ROWWISE_COLWISE", transformer_engine::jax::QuantizeAxis::ROWWISE_COLWISE)
.value("ROWWISE", transformer_engine::jax::QuantizeLayout::ROWWISE)
.value("COLWISE", transformer_engine::jax::QuantizeLayout::COLWISE)
.value("ROWWISE_COLWISE", transformer_engine::jax::QuantizeLayout::ROWWISE_COLWISE)
.export_values();
}
......
......@@ -42,10 +42,10 @@ pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_
Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf,
Result_Type output_buf, Result_Type output_trans_buf,
Result_Type scale_inv_buf, Result_Type trans_scale_inv_buf,
Result_Type amax_out_buf, Result_Type dbias_buf,
Result_Type workspace_buf, int64_t scaling_mode_enum,
int64_t quantize_axis_enum, bool is_dbias) {
Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
Result_Type amax_buf, Result_Type dbias_buf, Result_Type workspace_buf,
int64_t scaling_mode_enum, int64_t quantize_layout_enum, bool is_dbias,
int64_t flatten_axis) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type());
......@@ -55,7 +55,7 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
auto *input = input_buf.untyped_data();
auto scaling_mode = static_cast<NVTEScalingMode>(scaling_mode_enum);
auto const quantize_axis = static_cast<QuantizeAxis>(quantize_axis_enum);
auto const quantize_layout = static_cast<QuantizeLayout>(quantize_layout_enum);
auto *output = output_buf->untyped_data();
auto *output_trans = output_trans_buf->untyped_data();
......@@ -63,9 +63,13 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
void *workspace = workspace_buf->untyped_data();
auto input_dims = input_buf.dimensions();
int64_t input_ndim = input_dims.size();
if (flatten_axis < 0) flatten_axis += input_ndim;
NVTE_CHECK(flatten_axis < input_ndim && flatten_axis > 0, "flatten_axis is out of bounds!");
auto workspace_dims = workspace_buf->dimensions();
auto m = product(input_dims, 0, input_dims.size() - 1);
auto n = input_dims.back();
auto m = product(input_dims, 0, flatten_axis);
auto n = product(input_dims, flatten_axis, input_ndim);
auto input_shape = std::vector<size_t>{m, n};
auto output_shape = std::vector<size_t>{m, n};
auto output_trans_shape = std::vector<size_t>{n, m};
......@@ -75,37 +79,54 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
auto output_tensor = TensorWrapper(scaling_mode);
if (quantize_axis == QuantizeAxis::ROWWISE || quantize_axis == QuantizeAxis::ROWWISE_COLWISE) {
if (quantize_layout == QuantizeLayout::ROWWISE ||
quantize_layout == QuantizeLayout::ROWWISE_COLWISE) {
output_tensor.set_rowwise_data(output, out_dtype, output_shape);
output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
std::vector<size_t>{
product(scale_inv_buf->dimensions(), 0, scale_inv_buf->dimensions().size() - 1),
scale_inv_buf->dimensions().back()});
}
if (is_fp8_dtype(out_dtype)) {
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
float *amax_out = reinterpret_cast<float *>(amax_out_buf->untyped_data());
float *amax = reinterpret_cast<float *>(amax_buf->untyped_data());
NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
NVTE_CHECK(amax_out != nullptr, "amax must be provided for delayed tensor scaling");
NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling");
output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
cudaMemsetAsync(amax_out, 0, sizeof(float), stream);
output_tensor.set_amax(amax_out, DType::kFloat32, std::vector<size_t>{1});
cudaMemsetAsync(amax, 0, sizeof(float), stream);
output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
std::vector<size_t>{1});
} else {
output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
std::vector<size_t>{product(scale_inv_buf->dimensions(), 0, flatten_axis),
product(scale_inv_buf->dimensions(), flatten_axis,
scale_inv_buf->dimensions().size())});
}
}
}
if (quantize_axis == QuantizeAxis::COLWISE || quantize_axis == QuantizeAxis::ROWWISE_COLWISE) {
output_tensor.set_columnwise_data(output_trans, out_dtype, output_trans_shape);
if (quantize_layout == QuantizeLayout::COLWISE ||
quantize_layout == QuantizeLayout::ROWWISE_COLWISE) {
auto &tmp_shape =
(scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? output_trans_shape : output_shape;
output_tensor.set_columnwise_data(output_trans, out_dtype, tmp_shape);
// For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling
auto &colwise_scale_inv_buf =
(scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? scale_inv_buf : trans_scale_inv_buf;
auto &tmp_buf =
(scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? scale_inv_buf : colwise_scale_inv_buf;
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
output_tensor.set_columnwise_scale_inv(
tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()),
std::vector<size_t>{1});
} else {
output_tensor.set_columnwise_scale_inv(
colwise_scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(colwise_scale_inv_buf->element_type()),
std::vector<size_t>{product(colwise_scale_inv_buf->dimensions(), 0,
colwise_scale_inv_buf->dimensions().size() - 1),
colwise_scale_inv_buf->dimensions().back()});
tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()),
std::vector<size_t>{
product(tmp_buf->dimensions(), 0, flatten_axis),
product(tmp_buf->dimensions(), flatten_axis, tmp_buf->dimensions().size())});
}
}
auto dbias_tensor = TensorWrapper(dbias, dbias_shape, in_dtype);
......@@ -133,8 +154,9 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DBiasQuantizeHandler, DBiasQuantizeFFI,
.Ret<Buffer_Type>() // dbias
.Ret<Buffer_Type>() // wkspace
.Attr<int64_t>("scaling_mode")
.Attr<int64_t>("q_axis")
.Attr<bool>("is_dbias"),
.Attr<int64_t>("q_layout")
.Attr<bool>("is_dbias")
.Attr<int64_t>("flatten_axis"),
FFI_CudaGraph_Traits);
Error_Type DequantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf,
......
......@@ -15,7 +15,11 @@ import jax
import jax.numpy as jnp
from . import cpp_extensions as tex
from .quantize import QuantizerSet, noop_quantizer_set
from .quantize import (
QuantizerSet,
noop_quantizer_set,
with_sharding_constraint_by_logical_axes,
)
def dense(
......@@ -23,6 +27,8 @@ def dense(
kernel: jnp.ndarray,
bias: jnp.ndarray = None,
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)),
input_axes: Tuple[str, ...] = None,
kernel_axes: Tuple[str, ...] = None,
quantizer_set: QuantizerSet = noop_quantizer_set,
):
"""Perform dense layer transformation with optional quantization.
......@@ -48,12 +54,12 @@ def dense(
bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape
output += jnp.reshape(bias, bias_new_shape)
else:
output = _dense(x, kernel, bias, contracting_dims, quantizer_set)
output = _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set)
return output
@partial(jax.custom_vjp, nondiff_argnums=(3,))
def _dense(x, kernel, bias, contracting_dims, quantizer_set):
@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5))
def _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set):
"""Internal implementation of dense layer transformation with custom VJP.
This function implements the core dense layer transformation logic with support
......@@ -64,32 +70,37 @@ def _dense(x, kernel, bias, contracting_dims, quantizer_set):
kernel: Weight matrix
bias: Optional bias tensor
contracting_dims: Contracting dimensions specification
input_axes: Logical axes for sharding the activation input
kernel_axes: Logical axes for sharding the weight matrix
quantizer_set: QuantizerSet which contains quantizers for different tensor types
Returns:
Transformed output tensor
"""
output, _ = _dense_fwd_rule(x, kernel, bias, contracting_dims, quantizer_set)
output, _ = _dense_fwd_rule(
x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set
)
return output
def _dense_fwd_rule(x, kernel, bias, contracting_dims, quantizer_set):
def _dense_fwd_rule(x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set):
"""Forward pass rule for dense layer transformation.
Args:
x: Input tensor
kernel: Weight matrix
bias: Optional bias tensor
contracting_dims: Contracting dimensions specification
quantizer_set: QuantizerSet which contains quantizers for different tensor types
Returns:
Tuple of (output, context) for backward pass
"""
x_contracting_dims, k_contracting_dims = contracting_dims
casted_x = tex.quantize(x, quantizer_set.x)
casted_kernel = tex.quantize(kernel, quantizer_set.kernel)
flatten_axis_x = -len(x_contracting_dims)
flatten_axis_k = len(k_contracting_dims) - len(kernel.shape)
casted_x = tex.quantize(x, flatten_axis=flatten_axis_x, quantizer=quantizer_set.x)
casted_x = with_sharding_constraint_by_logical_axes(casted_x, input_axes)
casted_kernel = tex.quantize(
kernel, flatten_axis=flatten_axis_k, quantizer=quantizer_set.kernel
)
casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes)
# GEMM NN
output = tex.gemm(
......@@ -97,6 +108,7 @@ def _dense_fwd_rule(x, kernel, bias, contracting_dims, quantizer_set):
casted_kernel.get_colwise_tensor(),
(x_contracting_dims, k_contracting_dims),
)
use_bias = bias is not None
if use_bias:
bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape
......@@ -109,18 +121,16 @@ def _dense_fwd_rule(x, kernel, bias, contracting_dims, quantizer_set):
kernel.shape,
use_bias,
quantizer_set,
flatten_axis_k,
)
return output, ctx
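A hedged worked example (hypothetical shapes, illustration only) of the flatten_axis values computed above:
x_contracting_dims, k_contracting_dims = (2,), (0,)
x_shape, kernel_shape = (4, 128, 16), (16, 4, 8)               # (B, S, H) and (H, F1, F2)
flatten_axis_x = -len(x_contracting_dims)                       # -1: x quantized as a (B*S, H) matrix
flatten_axis_k = len(k_contracting_dims) - len(kernel_shape)    # -2: kernel quantized as (H, F1*F2)
assert (flatten_axis_x, flatten_axis_k) == (-1, -2)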
def _dense_bwd_rule(contracting_dims, ctx, grad): # pylint: disable=unused-argument
def _dense_bwd_rule(
contracting_dims, input_axes, kernel_axes, ctx, grad
): # pylint: disable=unused-argument
"""Backward pass rule for dense layer transformation.
Args:
contracting_dims: Contracting dimensions specification
ctx: Context from forward pass
grad: Gradient from upstream
Returns:
Tuple of gradients with respect to inputs
"""
......@@ -133,9 +143,12 @@ def _dense_bwd_rule(contracting_dims, ctx, grad): # pylint: disable=unused-argu
kernel_shape,
use_bias,
quantizer_set,
flatten_axis_k,
) = ctx
casted_grad, dbias = tex.quantize_dbias(grad, is_dbias=use_bias, quantizer=quantizer_set.dgrad)
casted_grad, dbias = tex.quantize_dbias(
grad, is_dbias=use_bias, flatten_axis=flatten_axis_k, quantizer=quantizer_set.dgrad
)
# GEMM NT
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim
......@@ -151,6 +164,7 @@ def _dense_bwd_rule(contracting_dims, ctx, grad): # pylint: disable=unused-argu
rowwise_casted_kernel,
(g_constracting_dim, k_constracting_dim),
)
dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes)
# GEMM TN
# x_non_contracting_dims
......@@ -161,6 +175,7 @@ def _dense_bwd_rule(contracting_dims, ctx, grad): # pylint: disable=unused-argu
wgrad = tex.gemm(
colwise_casted_x, casted_grad.get_colwise_tensor(), (x_constracting_dim, g_constracting_dim)
)
wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes)
return dgrad, wgrad, dbias, quantizer_set
......
......@@ -28,6 +28,7 @@ from ..softmax import softmax, SoftmaxType
from ..sharding import with_sharding_constraint_by_logical_axes
from ..cpp_extensions import is_softmax_kernel_available
from ..quantize import QuantizerFactory, QuantizeConfig, QuantizeMeta, QuantizeMetaSet, ScalingMode
from ..sharding import get_non_contracting_logical_axes
PRNGKey = Any
Shape = Tuple[int, ...]
......@@ -406,6 +407,10 @@ class DenseGeneral(TransformerEngineBase):
:math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
axis: Union[Iterable[int], int], default = -1
An integer tuple with axes to apply the transformation on.
input_axes: Tuple[str, ...], default = None
Logical axes of the sharding constraint applied to the input, e.g.
(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means no
sharding constraint is inserted.
Optimization parameters
-----------------------
......@@ -429,6 +434,7 @@ class DenseGeneral(TransformerEngineBase):
axis: Union[Iterable[int], int] = -1
dtype: DType = jnp.float32
transpose_batch_sequence: bool = False
input_axes: Tuple[str, ...] = ()
def __post_init__(self):
if self.kernel_init is None:
......@@ -460,29 +466,35 @@ class DenseGeneral(TransformerEngineBase):
axis = _normalize_axes(axis, inputs.ndim)
kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features
if self.kernel_axes:
assert len(kernel_shape) == len(self.kernel_axes), (
"Expected len(kernel_shape) to match len(kernel_axes),"
f"got kernel_shape {kernel_shape} and kernel_axes {self.kernel_axes}"
)
kernel = nn_partitioning.param_with_axes(
"kernel", self.kernel_init, kernel_shape, self.dtype, axes=self.kernel_axes
)
if not QuantizeConfig.is_fp8_enabled():
kernel = kernel.astype(input_dtype)
kernel_compute_shape = (
reduce(operator.mul, [inputs.shape[ax] for ax in axis], 1),
reduce(operator.mul, features, 1),
)
kernel = jnp.reshape(kernel, kernel_compute_shape)
if self.use_bias:
bias = nn_partitioning.param_with_axes(
"bias", self.bias_init, features, self.dtype, axes=self.bias_axes
)
bias = bias.reshape(kernel_compute_shape[-1]).astype(input_dtype)
).astype(input_dtype)
else:
bias = None
quantizer_set = self.generate_quantizer_set()
contract_ind = tuple(range(0, len(axis)))
y = dense(
inputs, kernel, contracting_dims=(axis, contract_ind), quantizer_set=quantizer_set
inputs,
kernel,
contracting_dims=(axis, contract_ind),
input_axes=self.input_axes,
kernel_axes=self.kernel_axes,
quantizer_set=quantizer_set,
)
if self.enable_low_rank_adaptation:
......@@ -491,20 +503,14 @@ class DenseGeneral(TransformerEngineBase):
*features[:-1],
self.low_rank_adaptation_dim,
)
lora_a_kernel_init_shape = (
kernel_compute_shape[0],
*features[:-1],
self.low_rank_adaptation_dim,
)
lora_a_kernel_axes = (None,) * len(lora_a_kernel_init_shape)
lora_a_kernel_axes = (None,) * len(lora_a_kernel_shape)
lora_a_kernel = nn_partitioning.param_with_axes(
"lora_a_kernel",
self.kernel_init,
lora_a_kernel_init_shape,
lora_a_kernel_shape,
self.dtype,
axes=lora_a_kernel_axes,
)
lora_a_kernel = jnp.reshape(lora_a_kernel, lora_a_kernel_shape)
lora_a_kernel = lora_a_kernel.astype(input_dtype)
lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
......@@ -527,7 +533,6 @@ class DenseGeneral(TransformerEngineBase):
y += jnp.reshape(bias, bias_shape)
assert y.dtype == input_dtype
y = y.reshape(*inputs.shape[: self.axis], *features)
return y
......@@ -678,6 +683,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
The output tensors of layer normalization.
If :attr:`return_layernorm_output=False`, then this would be None.
"""
assert self.axis == -1, "Only support axis == -1 at this moment"
input_dtype = inputs.dtype
ln_output = None
......@@ -692,10 +698,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
if self.enable_layernorm:
inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)
assert self.axis == -1 # Only support axis = =-1 at this moment
features = inputs.shape[-1]
scale, ln_bias = _create_layernorm_parameters(
self.layernorm_type,
(features,),
......@@ -731,17 +734,12 @@ class LayerNormDenseGeneral(TransformerEngineBase):
axis = _normalize_axes(axis, y.ndim)
kernel_shape = tuple(y.shape[ax] for ax in axis) + features
kernel_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features
kernel = nn_partitioning.param_with_axes(
"kernel", self.kernel_init, kernel_shape, self.dtype, axes=self.kernel_axes
)
if not QuantizeConfig.is_fp8_enabled():
kernel = kernel.astype(input_dtype)
kernel_compute_shape = (
reduce(operator.mul, [inputs.shape[ax] for ax in axis], 1),
reduce(operator.mul, features, 1),
)
kernel = jnp.reshape(kernel, kernel_compute_shape)
contract_ind = tuple(range(0, len(axis)))
......@@ -756,11 +754,19 @@ class LayerNormDenseGeneral(TransformerEngineBase):
epsilon=self.epsilon,
layernorm_input_axes=self.layernorm_input_axes,
dot_input_axes=self.dot_input_axes,
kernel_axes=self.kernel_axes,
quantizer_set=quantizer_set,
)
else:
y = with_sharding_constraint_by_logical_axes(y, self.dot_input_axes)
z = dense(y, kernel, contracting_dims=(axis, contract_ind), quantizer_set=quantizer_set)
z = dense(
y,
kernel,
contracting_dims=(axis, contract_ind),
input_axes=self.dot_input_axes,
kernel_axes=self.kernel_axes,
quantizer_set=quantizer_set,
)
if self.enable_low_rank_adaptation:
lora_a_kernel_shape = (
......@@ -768,20 +774,14 @@ class LayerNormDenseGeneral(TransformerEngineBase):
*features[:-1],
self.low_rank_adaptation_dim,
)
lora_a_kernel_init_shape = (
kernel_compute_shape[0],
*features[:-1],
self.low_rank_adaptation_dim,
)
lora_a_kernel_axes = (None,) * len(lora_a_kernel_init_shape)
lora_a_kernel_axes = (None,) * len(lora_a_kernel_shape)
lora_a_kernel = nn_partitioning.param_with_axes(
"lora_a_kernel",
self.kernel_init,
lora_a_kernel_init_shape,
lora_a_kernel_shape,
self.dtype,
axes=lora_a_kernel_axes,
)
lora_a_kernel = jnp.reshape(lora_a_kernel, lora_a_kernel_shape)
lora_a_kernel = lora_a_kernel.astype(input_dtype)
lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
......@@ -803,8 +803,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
if self.use_bias:
bias = nn_partitioning.param_with_axes(
"bias", self.bias_init, features, self.dtype, axes=self.bias_axes
)
bias = bias.reshape(kernel_compute_shape[-1]).astype(input_dtype)
).astype(input_dtype)
if bias is not None:
bias_shape = (1,) * (z.ndim - bias.ndim) + bias.shape
......@@ -814,7 +813,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
z = z / self.depth_scaling
assert z.dtype == input_dtype, f"output_dtype={z.dtype}, input_dtype={input_dtype}"
z = z.reshape(*inputs.shape[: self.axis], *features)
# z = z.reshape(*inputs.shape[: self.axis], *features)
return z, ln_output # dense_output, layer_norm_output
......@@ -989,6 +988,8 @@ class LayerNormMLP(TransformerEngineBase):
The output tensors of layer normalization.
If :attr:`return_layernorm_output=False`, then this would be None.
"""
assert self.axis == -1, "Only support axis == -1 at this moment"
ffn1_quantizer_set = self.generate_quantizer_set("_0")
ffn2_quantizer_set = self.generate_quantizer_set("_1")
......@@ -1027,7 +1028,6 @@ class LayerNormMLP(TransformerEngineBase):
)
# LayerNorm
if self.enable_layernorm:
assert self.axis == -1 # Only support axis == -1 at this moment
inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)
features = inputs.shape[-1]
......@@ -1071,7 +1071,7 @@ class LayerNormMLP(TransformerEngineBase):
num_activations = len(normalized_acts)
axis = _canonicalize_tuple(self.axis)
axis = _normalize_axes(axis, y.ndim)
kernel_1_each_shape = (np.prod([y.shape[ax] for ax in axis]), self.intermediate_dim)
kernel_1_each_shape = (np.prod([inputs.shape[ax] for ax in axis]), self.intermediate_dim)
kernel_1 = nn_partitioning.param_with_axes(
"wi_kernel",
kernel_1_init,
......@@ -1081,17 +1081,10 @@ class LayerNormMLP(TransformerEngineBase):
self.dtype,
axes=self.kernel_axes_1,
)
kernel_1_compute_shape = (
reduce(operator.mul, [y.shape[ax] for ax in axis], 1),
num_activations * self.intermediate_dim,
)
kernel_1 = jnp.reshape(kernel_1, kernel_1_compute_shape)
if not QuantizeConfig.is_fp8_enabled():
kernel_1 = kernel_1.astype(input_dtype)
if self.kernel_axes_1 is not None:
kernel_1 = with_sharding_constraint_by_logical_axes(
kernel_1, self.kernel_axes_1[:-2] + self.kernel_axes_1[-1:]
)
hidden_size = inputs.shape[-1]
hidden_size_tuple = _canonicalize_tuple(hidden_size)
kernel_2_shape = (self.intermediate_dim,) + hidden_size_tuple
......@@ -1102,27 +1095,20 @@ class LayerNormMLP(TransformerEngineBase):
self.dtype,
axes=self.kernel_axes_2,
)
kernel_2_compute_shape = (
self.intermediate_dim,
reduce(operator.mul, hidden_size_tuple, 1),
)
kernel_2 = jnp.reshape(kernel_2, kernel_2_compute_shape)
if not QuantizeConfig.is_fp8_enabled():
kernel_2 = kernel_2.astype(input_dtype)
if self.kernel_axes_2 is not None:
kernel_2 = with_sharding_constraint_by_logical_axes(kernel_2, self.kernel_axes_2)
contract_ind = tuple(range(0, len(axis)))
if self.use_bias:
bias_1_shape = num_activations * self.intermediate_dim
bias_1_shape = (num_activations, self.intermediate_dim)
bias_1 = nn_partitioning.param_with_axes(
"wi_bias",
self.bias_init,
bias_1_shape,
self.dtype,
axes=self.bias_axes_1,
)
bias_1 = bias_1.reshape(kernel_1_compute_shape[-1]).astype(input_dtype)
).astype(input_dtype)
bias_2_shape = (hidden_size,)
bias_2 = nn_partitioning.param_with_axes(
......@@ -1131,8 +1117,7 @@ class LayerNormMLP(TransformerEngineBase):
bias_2_shape,
self.dtype,
axes=self.bias_axes_2,
)
bias_2 = bias_2.reshape(kernel_2_compute_shape[-1]).astype(input_dtype)
).astype(input_dtype)
else:
bias_1 = None
bias_2 = None
......@@ -1141,8 +1126,6 @@ class LayerNormMLP(TransformerEngineBase):
ffn2_ckpt_name = "ffn2"
if use_fused_layernorm_mlp:
assert self.axis == -1 # Only support axis = =-1 at this moment
out = layernorm_mlp(
y,
scale,
......@@ -1155,6 +1138,8 @@ class LayerNormMLP(TransformerEngineBase):
norm_input_axes=self.layernorm_input_axes,
dot_1_input_axes=self.dot_1_input_axes,
dot_2_input_axes=self.dot_2_input_axes,
kernel_1_axes=self.kernel_axes_1,
kernel_2_axes=self.kernel_axes_2,
ffn1_ckpt_name=ffn1_ckpt_name,
ffn2_ckpt_name=ffn2_ckpt_name,
activation_type=normalized_acts,
......@@ -1175,6 +1160,7 @@ class LayerNormMLP(TransformerEngineBase):
epsilon=self.epsilon,
layernorm_input_axes=self.layernorm_input_axes,
dot_input_axes=self.dot_1_input_axes,
kernel_axes=self.kernel_axes_1,
quantizer_set=ffn1_quantizer_set,
)
else:
......@@ -1183,35 +1169,31 @@ class LayerNormMLP(TransformerEngineBase):
y,
kernel_1,
contracting_dims=(axis, contract_ind),
input_axes=self.dot_1_input_axes,
kernel_axes=self.kernel_axes_1,
quantizer_set=ffn1_quantizer_set,
)
dot_1_output_axes = (
*get_non_contracting_logical_axes(y.ndim, self.dot_1_input_axes, axis),
*get_non_contracting_logical_axes(kernel_1.ndim, self.kernel_axes_1, contract_ind),
)
x = with_sharding_constraint_by_logical_axes(x, dot_1_output_axes)
if self.enable_low_rank_adaptation:
wi_lora_a_kernel_shape = (
kernel_1_compute_shape[0],
num_activations,
self.low_rank_adaptation_dim,
)
wi_lora_a_kernel_init_shape = (
kernel_1_each_shape[0],
num_activations,
wi_lora_a_kernel_each_shape = (
kernel_1_each_shape[: len(axis)],
self.low_rank_adaptation_dim,
)
wi_lora_a_kernel_init_each_shape = (
kernel_1_each_shape[0],
self.low_rank_adaptation_dim,
)
wi_lora_a_kernel_axes = (None,) * len(wi_lora_a_kernel_init_shape)
wi_lora_a_kernel_axes = (None,) * (len(wi_lora_a_kernel_each_shape) + 1)
wi_lora_a_kernel = nn_partitioning.param_with_axes(
"wi_lora_a_kernel",
kernel_1_init,
num_activations,
-1,
wi_lora_a_kernel_init_each_shape,
-2,
wi_lora_a_kernel_each_shape,
self.dtype,
axes=wi_lora_a_kernel_axes,
)
wi_lora_a_kernel = jnp.reshape(wi_lora_a_kernel, wi_lora_a_kernel_shape)
wi_lora_a_kernel = wi_lora_a_kernel.astype(input_dtype)
wi_lora_b_kernel_shape = (
......@@ -1232,7 +1214,7 @@ class LayerNormMLP(TransformerEngineBase):
x += _apply_low_rank_adaptation(
y,
axis,
num_activations * self.intermediate_dim,
(num_activations, self.intermediate_dim),
wi_lora_a_kernel,
wi_lora_b_kernel,
self.low_rank_adaptation_alpha,
......@@ -1246,11 +1228,12 @@ class LayerNormMLP(TransformerEngineBase):
z = activation(x, normalized_acts)
else:
activations = []
x = jnp.split(x, num_activations, axis=-1)
x = jnp.split(x, num_activations, axis=-2)
for idx, act_fn in enumerate(normalized_acts):
x_i = _convert_to_activation_function(act_fn)(x[idx])
activations.append(x_i)
z = reduce(operator.mul, activations)
z = jnp.squeeze(z, axis=-2)
z = z.astype(input_dtype)
z = nn.Dropout(
......@@ -1264,7 +1247,12 @@ class LayerNormMLP(TransformerEngineBase):
# DenseGeneral 2
out = dense(
z, kernel_2, contracting_dims=(axis, contract_ind), quantizer_set=ffn2_quantizer_set
z,
kernel_2,
contracting_dims=(axis, contract_ind),
input_axes=self.dot_2_input_axes,
kernel_axes=self.kernel_axes_2,
quantizer_set=ffn2_quantizer_set,
)
if self.enable_low_rank_adaptation:
......
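For reference, the activation split a few hunks above now happens along the dedicated activation axis (axis=-2) of the (..., num_activations, intermediate) layout instead of a flattened last dim. A minimal sketch of the gated ("gelu", "linear") case:

import jax
import jax.numpy as jnp

num_activations, intermediate = 2, 8
h = jnp.ones((4, num_activations, intermediate))   # dot_1 output: (batch, act, intermediate)

gelu_in, linear_in = jnp.split(h, num_activations, axis=-2)
z = jax.nn.gelu(gelu_in) * linear_in               # elementwise gate
z = jnp.squeeze(z, axis=-2)                        # drop the activation axis -> (4, 8)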
......@@ -33,10 +33,9 @@ def layernorm_dense(
norm_type: str = "layernorm",
zero_centered_gamma: bool = False,
epsilon: float = 1e-6,
# The logic axes of sharding constraint to the layernorm input.
layernorm_input_axes: Tuple[str, ...] = None,
# The logic axes of sharding constraint to the dot input.
dot_input_axes: Tuple[str, ...] = None,
kernel_axes: Tuple[str, ...] = None,
quantizer_set: QuantizerSet = noop_quantizer_set,
) -> jnp.ndarray:
"""Apply layer normalization followed by dense layer transformation.
......@@ -56,6 +55,7 @@ def layernorm_dense(
epsilon: Small constant for numerical stability in normalization
layernorm_input_axes: Logical axes for sharding the layernorm input
dot_input_axes: Logical axes for sharding the matrix multiplication input
kernel_axes: Logical axes for sharding the weight matrix
quantizer_set: Set of quantizers for different tensor types
Returns:
......@@ -78,6 +78,7 @@ def layernorm_dense(
epsilon,
layernorm_input_axes,
dot_input_axes,
kernel_axes,
quantizer_set,
)
return output
......@@ -91,6 +92,7 @@ def layernorm_dense(
7,
8,
9,
10,
),
)
def _layernorm_dense(
......@@ -104,6 +106,7 @@ def _layernorm_dense(
epsilon: float,
layernorm_input_axes: Tuple[str, ...],
dot_input_axes: Tuple[str, ...],
kernel_axes: Tuple[str, ...],
quantizer_set,
):
"""Internal implementation of layernorm_dense with custom VJP.
......@@ -139,6 +142,7 @@ def _layernorm_dense(
epsilon,
layernorm_input_axes,
dot_input_axes,
kernel_axes,
quantizer_set,
)
return output
......@@ -155,6 +159,7 @@ def _layernorm_dense_fwd_rule(
epsilon,
layernorm_input_axes,
dot_input_axes,
kernel_axes,
quantizer_set,
):
"""Forward pass rule for layernorm_dense.
......@@ -171,7 +176,6 @@ def _layernorm_dense_fwd_rule(
x_contracting_dims = (len(x.shape) - 1,)
k_contracting_dims = (0,)
assert x.shape[-1] == kernel.shape[0]
assert len(kernel.shape) == 2 # Otherwise need to merge dims in quantize
x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes)
......@@ -184,11 +188,12 @@ def _layernorm_dense_fwd_rule(
norm_type,
quantizer_set.x,
)
casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_input_axes)
# Kernel in (hidden_in, hidden_out...)
casted_kernel = tex.quantize(kernel, quantizer_set.kernel)
casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_input_axes)
flatten_axis = 1 - len(kernel.shape)
casted_kernel = tex.quantize(kernel, flatten_axis=flatten_axis, quantizer=quantizer_set.kernel)
casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes)
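The flatten_axis = 1 - len(kernel.shape) above marks where the kernel's output dims begin, so quantization can treat a multi-dimensional kernel as an effectively 2-D (hidden_in, prod(output_dims)) matrix. A small numeric sketch:

import numpy as np

kernel_shape = (256, 8, 64)              # (hidden_in, heads, head_dim)
flatten_axis = 1 - len(kernel_shape)     # -2: everything from this axis on is the "output" side
lead = int(np.prod(kernel_shape[:flatten_axis]))
tail = int(np.prod(kernel_shape[flatten_axis:]))
assert (lead, tail) == (256, 512)        # quantization effectively sees a 256 x 512 matrix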
# NN GEMM
# (batch..., hidden_in) x (hidden_in, hidden_out...)
......@@ -217,6 +222,7 @@ def _layernorm_dense_fwd_rule(
k_contracting_dims,
use_bias,
quantizer_set,
flatten_axis,
)
return output, ctx
......@@ -228,6 +234,7 @@ def _layernorm_dense_bwd_rule(
epsilon,
layernorm_input_axes,
dot_input_axes, # pylint: disable=unused-argument
kernel_axes,
ctx,
grad,
):
......@@ -256,11 +263,12 @@ def _layernorm_dense_bwd_rule(
k_contracting_dims_in_fwd,
use_bias,
quantizer_set,
flatten_axis,
) = ctx
grad = with_sharding_constraint_by_logical_axes(grad, dot_input_axes)
casted_grad, dbias = tex.quantize_dbias(grad, is_dbias=use_bias, quantizer=quantizer_set.dgrad)
casted_grad, dbias = tex.quantize_dbias(
grad, is_dbias=use_bias, flatten_axis=flatten_axis, quantizer=quantizer_set.dgrad
)
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim
g_constracting_dim = tuple(
......@@ -291,6 +299,8 @@ def _layernorm_dense_bwd_rule(
(x_constracting_dim, g_constracting_dim),
)
wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes)
dx, dgamma, dbeta = tex.normalization_bwd(
dgrad,
x,
......
......@@ -23,6 +23,7 @@ from jax.ad_checkpoint import checkpoint_name
from . import cpp_extensions as tex
from .layernorm import canonicalize_norm_type
from .quantize import with_sharding_constraint_by_logical_axes, QuantizerSet, noop_quantizer_set
from .sharding import get_non_contracting_logical_axes
def layernorm_mlp(
......@@ -37,6 +38,8 @@ def layernorm_mlp(
norm_input_axes: Tuple[str, ...] = None,
dot_1_input_axes: Tuple[str, ...] = None,
dot_2_input_axes: Tuple[str, ...] = None,
kernel_1_axes: Tuple[str, ...] = None,
kernel_2_axes: Tuple[str, ...] = None,
ffn1_ckpt_name: str = "ffn1",
ffn2_ckpt_name: str = "ffn2",
activation_type: Sequence[Union[str, Callable]] = ("gelu",),
......@@ -66,6 +69,8 @@ def layernorm_mlp(
norm_input_axes: Logical axes for sharding the layernorm input
dot_1_input_axes: Logical axes for sharding the first matrix multiplication
dot_2_input_axes: Logical axes for sharding the second matrix multiplication
kernel_1_axes: Logical axes for sharding the first weight matrix
kernel_2_axes: Logical axes for sharding the second weight matrix
ffn1_ckpt_name: Name for checkpointing the first feed-forward network
ffn2_ckpt_name: Name for checkpointing the second feed-forward network
activation_type: Activation function(s) to apply after the first dense layer transformation
......@@ -109,6 +114,8 @@ def layernorm_mlp(
norm_input_axes,
dot_1_input_axes,
dot_2_input_axes,
kernel_1_axes,
kernel_2_axes,
ffn1_ckpt_name,
ffn2_ckpt_name,
activation_type,
......@@ -117,7 +124,7 @@ def layernorm_mlp(
return output
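As a rough reference for what the fused path computes, here is an unquantized pure-JAX sketch of layernorm -> dense 1 -> gated activation -> dense 2, assuming a ("gelu", "linear") activation and a 3-D kernel_1 of shape (hidden, 2, ffn). This is only a conceptual restatement, not the TE implementation:

import jax
import jax.numpy as jnp

def layernorm_mlp_ref(x, gamma, beta, w1, b1, w2, b2, eps=1e-6):
    # Layernorm over the hidden dim.
    mu = x.mean(-1, keepdims=True)
    var = x.var(-1, keepdims=True)
    y = (x - mu) * jax.lax.rsqrt(var + eps) * gamma + beta
    # Dense 1 keeps the activation dim: (batch, seq, act, ffn).
    h = jnp.einsum("bsh,han->bsan", y, w1) + b1
    # Gated activation: gelu branch times linear branch.
    h = jax.nn.gelu(h[..., 0, :]) * h[..., 1, :]
    # Dense 2 projects back to hidden.
    return jnp.einsum("bsn,nh->bsh", h, w2) + b2

x = jnp.ones((2, 4, 16)); gamma = jnp.ones(16); beta = jnp.zeros(16)
w1 = jnp.ones((16, 2, 32)); b1 = jnp.zeros((2, 32)); w2 = jnp.ones((32, 16)); b2 = jnp.zeros(16)
out = layernorm_mlp_ref(x, gamma, beta, w1, b1, w2, b2)   # (2, 4, 16)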
@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15))
@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17))
def _layernorm_mlp(
x: jnp.ndarray,
gamma: jnp.ndarray,
......@@ -132,6 +139,8 @@ def _layernorm_mlp(
norm_input_axes: Tuple[str, ...],
dot_1_input_axes: Tuple[str, ...],
dot_2_input_axes: Tuple[str, ...],
kernel_1_axes: Tuple[str, ...],
kernel_2_axes: Tuple[str, ...],
ffn1_ckpt_name: str,
ffn2_ckpt_name: str,
activation_type: Sequence[Union[str, Callable]],
......@@ -179,6 +188,8 @@ def _layernorm_mlp(
norm_input_axes,
dot_1_input_axes,
dot_2_input_axes,
kernel_1_axes,
kernel_2_axes,
ffn1_ckpt_name,
ffn2_ckpt_name,
activation_type,
......@@ -201,6 +212,8 @@ def _layernorm_mlp_fwd_rule(
norm_input_axes,
dot_1_input_axes,
dot_2_input_axes,
kernel_1_axes,
kernel_2_axes,
ffn1_ckpt_name,
ffn2_ckpt_name,
activation_type,
......@@ -220,20 +233,21 @@ def _layernorm_mlp_fwd_rule(
Returns:
Tuple of (output, context) for automatic differentiation
"""
del kernel_2_axes
ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets
# x should be in shape of (batch..., hidden)
# Kernel_1 should be in shape of (hidden_in, activation_len * intermediate)
# Kernel_1 should be in shape of (hidden_in, activation_len, intermediate)
# Kernel_2 should be in shape of (intermediate, hidden_in)
assert len(kernel_1.shape) == 2
assert len(kernel_1.shape) == 3
assert len(kernel_2.shape) == 2
assert kernel_1.shape[1] == kernel_2.shape[0] * len(activation_type)
assert kernel_1.shape[-2] == len(activation_type)
x_contracting_dims = (len(x.shape) - 1,)
k_contracting_dims = (0,)
assert x.shape[x_contracting_dims[0]] == kernel_1.shape[k_contracting_dims[0]]
assert kernel_1.shape[1] == len(activation_type) * kernel_2.shape[0]
use_bias_1 = bias_1 is not None
use_bias_2 = bias_1 is not None
......@@ -249,11 +263,10 @@ def _layernorm_mlp_fwd_rule(
norm_type,
quantizer=ffn1_quantizer_set.x,
)
casted_kernel_1 = tex.quantize(kernel_1, quantizer=ffn1_quantizer_set.kernel)
casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_1_input_axes)
casted_kernel_1 = tex.quantize(kernel_1, flatten_axis=-2, quantizer=ffn1_quantizer_set.kernel)
# NN GEMM
# (batch..., hidden_in) x (hidden_in, hidden_out)
dot_1_output = tex.gemm(
......@@ -261,6 +274,13 @@ def _layernorm_mlp_fwd_rule(
casted_kernel_1.get_colwise_tensor(),
(x_contracting_dims, k_contracting_dims),
)
dot_1_output_axes = (
*get_non_contracting_logical_axes(x.ndim, dot_1_input_axes, x_contracting_dims),
*get_non_contracting_logical_axes(kernel_1.ndim, kernel_1_axes, k_contracting_dims),
)
dot_1_output = with_sharding_constraint_by_logical_axes(dot_1_output, dot_1_output_axes)
if use_bias_1:
bias_1_shape = bias_1.shape
bias_1_new_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1_shape
......@@ -283,6 +303,12 @@ def _layernorm_mlp_fwd_rule(
(x_contracting_dims, k_contracting_dims),
)
dot_2_output_axes = (
*get_non_contracting_logical_axes(x.ndim, dot_2_input_axes, x_contracting_dims),
*get_non_contracting_logical_axes(kernel_2.ndim, None, k_contracting_dims),
)
dot_2_output = with_sharding_constraint_by_logical_axes(dot_2_output, dot_2_output_axes)
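The dot-output sharding constraints above combine the non-contracting logical axes of the input with those of the kernel. A tiny sketch of the assumed semantics of get_non_contracting_logical_axes (the helper and axis names below are illustrative, not the TE implementation):

def non_contracting(ndim, logical_axes, contracting_dims):
    # Keep only the logical names of dims that are not contracted away by the GEMM.
    logical_axes = logical_axes if logical_axes is not None else (None,) * ndim
    return tuple(ax for i, ax in enumerate(logical_axes) if i not in contracting_dims)

x_axes = ("batch", "seq", "hidden")   # e.g. dot_1_input_axes
k_axes = ("hidden", "act", "mlp")     # e.g. kernel_1_axes
out_axes = (*non_contracting(3, x_axes, (2,)), *non_contracting(3, k_axes, (0,)))
assert out_axes == ("batch", "seq", "act", "mlp")   # constraint applied to dot_1_output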
if use_bias_2:
bias_2_shape = bias_2.shape
bias_2_new_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2_shape
......@@ -320,8 +346,10 @@ def _layernorm_mlp_bwd_rule(
norm_input_axes,
dot_1_input_axes,
dot_2_input_axes,
ffn1_ckpt_name, # pylint: disable=unused-argument
ffn2_ckpt_name, # pylint: disable=unused-argument
kernel_1_axes,
kernel_2_axes,
ffn1_ckpt_name,
ffn2_ckpt_name,
activation_type,
ctx,
grad,
......@@ -339,6 +367,7 @@ def _layernorm_mlp_bwd_rule(
Returns:
Tuple of gradients for all input parameters
"""
del norm_input_axes, ffn1_ckpt_name, ffn2_ckpt_name
(
x,
mu,
......@@ -369,11 +398,11 @@ def _layernorm_mlp_bwd_rule(
)
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim
g_constracting_dim_2 = tuple(
g_contracting_dims_2 = tuple(
range(grad.ndim - len(kernel_2_shape) + len(k_contracting_dims_in_fwd), grad.ndim)
)
# k_non_contracting_dims
k_constracting_dim_2 = tuple(
k_contracting_dims_2 = tuple(
dim for dim in range(len(kernel_2_shape)) if dim not in k_contracting_dims_in_fwd
)
......@@ -382,12 +411,12 @@ def _layernorm_mlp_bwd_rule(
dgrad_2 = tex.gemm(
casted_grad.get_rowwise_tensor(),
rowwise_casted_kernel_2,
(g_constracting_dim_2, k_constracting_dim_2),
(g_contracting_dims_2, k_contracting_dims_2),
)
dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes)
x_constracting_dim = g_constracting_dim = tuple(
x_contracting_dims = g_contracting_dims = tuple(
range(0, len(x.shape) - len(x_contracting_dims_in_fwd))
)
......@@ -396,8 +425,9 @@ def _layernorm_mlp_bwd_rule(
wgrad_2 = tex.gemm(
colwise_casted_act_out,
casted_grad.get_colwise_tensor(),
(x_constracting_dim, g_constracting_dim),
(x_contracting_dims, g_contracting_dims),
)
wgrad_2 = with_sharding_constraint_by_logical_axes(wgrad_2, kernel_2_axes)
casted_dact_out, dbias_1 = tex.quantize_dact_dbias(
dgrad_2,
......@@ -408,11 +438,12 @@ def _layernorm_mlp_bwd_rule(
)
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim
g_constracting_dim_1 = tuple(
range(dgrad_2.ndim - len(kernel_1_shape) + len(k_contracting_dims_in_fwd), dgrad_2.ndim)
dact_out_ndim = casted_dact_out.get_rowwise_tensor().data.ndim
g_contracting_dims_1 = tuple(
range(dact_out_ndim - len(kernel_1_shape) + len(k_contracting_dims_in_fwd), dact_out_ndim)
)
# k_non_contracting_dims
k_constracting_dim_1 = tuple(
k_contracting_dims_1 = tuple(
dim for dim in range(len(kernel_1_shape)) if dim not in k_contracting_dims_in_fwd
)
......@@ -420,19 +451,21 @@ def _layernorm_mlp_bwd_rule(
dgrad_1 = tex.gemm(
casted_dact_out.get_rowwise_tensor(),
rowwise_casted_kernel_1,
(g_constracting_dim_1, k_constracting_dim_1),
(g_contracting_dims_1, k_contracting_dims_1),
)
dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, norm_input_axes)
dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, dot_1_input_axes)
# TN GEMM
# (hidden, batch...) x (hidden, batch...)
wgrad_1 = tex.gemm(
colwise_casted_ln_out,
casted_dact_out.get_colwise_tensor(),
(x_constracting_dim, g_constracting_dim),
(x_contracting_dims, g_contracting_dims),
)
wgrad_1 = with_sharding_constraint_by_logical_axes(wgrad_1, kernel_1_axes)
dx, dgamma, dbeta = tex.normalization_bwd(
dgrad_1,
x,
......
......@@ -57,18 +57,27 @@ class Dequantizer:
data = scaled_tensor.data.astype(jnp.float32)
data_shape = data.shape
scale = scaled_tensor.scale_inv.view(jnp.uint8).astype(jnp.float32)
flatten_axis = scaled_tensor.flatten_axis
flatten_axis = len(data_shape) + flatten_axis if flatten_axis < 0 else flatten_axis
assert (
0 < flatten_axis < len(data_shape)
), f"flatten_axis {flatten_axis} is out of bounds for shape {data_shape}"
scale_shape = scaled_tensor.scaling_mode.get_scale_shape(
scaled_tensor.data.shape, scaled_tensor.is_colwise, is_padded=False
data_shape, scaled_tensor.is_colwise, is_padded=False, flatten_axis=flatten_axis
)
scale = jax.lax.slice(scale, [0] * len(scale_shape), scale_shape) # slice out the padding
data = data.reshape(
*data_shape[:-2],
scale_shape[-2],
int(data_shape[-2] / scale_shape[-2]),
*data_shape[: flatten_axis - 1],
scale_shape[flatten_axis - 1],
int(data_shape[flatten_axis - 1] / scale_shape[flatten_axis - 1]),
*data_shape[flatten_axis:-1],
scale_shape[-1],
int(data_shape[-1] / scale_shape[-1]),
)
scale = jnp.expand_dims(scale, axis=(-1, -3))
# E8M0 does not have a bit for sign. So 0 - 127 represent negative numbers.
scale = jnp.expand_dims(scale, axis=(flatten_axis + 2 - 2, -1))
# E8M0 does not have a bit for sign. So 0 - 127 represent negative numbers.
return jnp.asarray(data * jnp.power(2, scale - 127), scaled_tensor.dq_dtype).reshape(
data_shape
......
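To make the dequantizer reshape above concrete, here is a small sketch of 1-D block dequantization with exponent-only (E8M0) scales; the block size and shapes are illustrative, and the padding slice is omitted:

import jax.numpy as jnp

data = jnp.ones((4, 64), dtype=jnp.float32)   # FP8 payload already upcast to float32
scale_e8m0 = jnp.full((4, 2), 127.0)           # one exponent per 32-wide block
blocks = data.reshape(4, 2, 32)                # split the last dim into blocks
# E8M0 has no sign bit, so a stored value e decodes to 2**(e - 127).
out = blocks * jnp.power(2.0, scale_e8m0[..., None] - 127.0)
out = out.reshape(4, 64)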
......@@ -14,7 +14,7 @@ from typing import Union, Optional
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class
from transformer_engine_jax import QuantizeAxis
from transformer_engine_jax import QuantizeLayout
from .scaling_modes import ScalingMode
from .tensor import ScaledTensor1x, ScaledTensor2x, ScaledTensorFactory
......@@ -24,7 +24,7 @@ from .helper import (
)
__all__ = [
"QuantizeAxis",
"QuantizeLayout",
"Quantizer",
"QuantizerSet",
"DelayedScaleQuantizer",
......@@ -45,12 +45,12 @@ class Quantizer(ABC):
Attributes:
q_dtype: The data type for quantized values
scaling_mode: The scaling mode to use for quantization
q_axis: The quantization axis (row-wise, column-wise, or both)
q_layout: The quantization layout (row-wise, column-wise, or both)
"""
q_dtype: jnp.dtype
scaling_mode: ScalingMode
q_axis: QuantizeAxis
q_layout: QuantizeLayout
def tree_flatten(self):
"""Flatten the quantizer for JAX tree operations.
......@@ -59,7 +59,7 @@ class Quantizer(ABC):
Tuple of (children, aux_data) for tree operations
"""
children = ()
aux_data = (self.q_dtype, self.scaling_mode, self.q_axis)
aux_data = (self.q_dtype, self.scaling_mode, self.q_layout)
return (children, aux_data)
@classmethod
......@@ -85,30 +85,31 @@ class Quantizer(ABC):
Returns:
True if using both row-wise and column-wise quantization
"""
return self.q_axis == QuantizeAxis.ROWWISE_COLWISE
return self.q_layout == QuantizeLayout.ROWWISE_COLWISE
@abstractmethod
def get_layout(self) -> str:
"""Get the data layout.
def get_data_layout(self) -> str:
"""Get the data data_layout.
Returns:
Data layout in string format
Data layout in string format
"""
@abstractmethod
def _quantize_func(self, x, is_colwise=False, dq_dtype=None) -> ScaledTensor1x:
def _quantize_func(self, x, is_colwise=False, dq_dtype=None, flatten_axis=-1) -> ScaledTensor1x:
"""Core quantization function to be implemented by subclasses.
Args:
x: Input tensor to quantize
is_colwise: Whether to use column-wise quantization
dq_dtype: Data type for dequantized values, default is x.dtype
flatten_axis: The axis along which the tensor is flattened to 2D for quantization
Returns:
A ScaledTensor1x containing the quantized data
"""
def quantize(self, x, is_rowwise=False, is_colwise=False, dq_dtype=None):
def quantize(self, x, is_rowwise=False, is_colwise=False, dq_dtype=None, flatten_axis=-1):
"""Quantize a tensor using the internal _quantize_func().
Args:
......@@ -116,21 +117,26 @@ class Quantizer(ABC):
is_rowwise: Whether to use row-wise quantization
is_colwise: Whether to use column-wise quantization
dq_dtype: Data type for dequantized values
flatten_axis: The axis along which the tensor is flattened to 2D for quantization
Returns:
A ScaledTensor1x or ScaledTensor2x containing the quantized data
"""
if (is_rowwise and is_colwise) or self.is_2x2x():
rowwise_tensor = self._quantize_func(x, dq_dtype=dq_dtype)
colwise_tensor = self._quantize_func(x, is_colwise=True, dq_dtype=dq_dtype)
rowwise_tensor = self._quantize_func(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis)
colwise_tensor = self._quantize_func(
x, is_colwise=True, dq_dtype=dq_dtype, flatten_axis=flatten_axis
)
return ScaledTensor2x(rowwise_tensor, colwise_tensor)
if is_colwise:
return self._quantize_func(x, is_colwise=True, dq_dtype=dq_dtype)
return self._quantize_func(
x, is_colwise=True, dq_dtype=dq_dtype, flatten_axis=flatten_axis
)
return self._quantize_func(x, dq_dtype=dq_dtype)
return self._quantize_func(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis)
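A hedged usage sketch based only on the signatures visible in this diff; the installed package and its argument handling may differ:

import jax.numpy as jnp
from transformer_engine.jax.quantize import QuantizerFactory, QuantizeLayout, ScalingMode

quantizer = QuantizerFactory.create(
    n_quantizers=1,
    scaling_mode=ScalingMode.NVTE_MXFP8_1D_SCALING,
    q_dtype=jnp.float8_e4m3fn,
    q_layout=QuantizeLayout.ROWWISE_COLWISE,
)
w = jnp.ones((256, 2, 512), jnp.bfloat16)        # (hidden_in, act, intermediate)
# flatten_axis=-2 groups the trailing (act, intermediate) dims as the kernel's output side.
scaled = quantizer.quantize(w, flatten_axis=-2)  # ScaledTensor2x: rowwise + colwise views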
def get_scale_shapes(self, data_shape, is_padded=True):
def get_scale_shapes(self, data_shape, is_padded=True, flatten_axis=-1):
"""Get shapes for scale tensors.
Args:
......@@ -140,7 +146,7 @@ class Quantizer(ABC):
Returns:
Tuple of (rowwise_scale_shape, colwise_scale_shape)
"""
return self.scaling_mode.get_scale_shape_2x(data_shape, is_padded)
return self.scaling_mode.get_scale_shape_2x(data_shape, is_padded, flatten_axis)
def get_scale_dtype(self):
"""Get the data type for scale tensors.
......@@ -161,13 +167,13 @@ class DelayedScaleQuantizer(Quantizer):
Attributes:
scaling_mode: Set to NVTE_DELAYED_TENSOR_SCALING
q_axis: Quantization axis (default: ROWWISE_COLWISE)
q_layout: Quantization layout (default: ROWWISE_COLWISE)
scale: Current scaling factor
amax_history: History of maximum absolute values
"""
scaling_mode: ScalingMode = ScalingMode.NVTE_DELAYED_TENSOR_SCALING
q_axis: QuantizeAxis = QuantizeAxis.ROWWISE_COLWISE
q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE
scale: jnp.ndarray = field(default_factory=lambda: jnp.ones((1,), jnp.float32))
amax_history: jnp.ndarray = field(
......@@ -181,35 +187,37 @@ class DelayedScaleQuantizer(Quantizer):
Tuple of (children, aux_data) for tree operations
"""
children = (self.scale, self.amax_history)
aux_data = (self.q_dtype, self.scaling_mode, self.q_axis)
aux_data = (self.q_dtype, self.scaling_mode, self.q_layout)
return (children, aux_data)
def get_layout(self) -> str:
"""Get the data layout string.
def get_data_layout(self) -> str:
"""Get the data data_layout string.
Returns:
Data layout in string format
Data layout in string format
Raises:
ValueError: If quantization axis is invalid
"""
layout = "NT"
if self.q_axis == QuantizeAxis.ROWWISE_COLWISE:
return layout
if self.q_axis == QuantizeAxis.ROWWISE:
return layout[0]
if self.q_axis == QuantizeAxis.COLWISE:
return layout[1]
raise ValueError(f"Invalid q_axis: {self.q_axis}")
def _quantize_func(self, x: jnp.ndarray, is_colwise=False, dq_dtype=None) -> ScaledTensor1x:
data_layout = "NT"
if self.q_layout == QuantizeLayout.ROWWISE_COLWISE:
return data_layout
if self.q_layout == QuantizeLayout.ROWWISE:
return data_layout[0]
if self.q_layout == QuantizeLayout.COLWISE:
return data_layout[1]
raise ValueError(f"Invalid q_layout: {self.q_layout}")
def _quantize_func(
self, x: jnp.ndarray, is_colwise=False, dq_dtype=None, flatten_axis=-1
) -> ScaledTensor1x:
"""Quantize function helper for delayed scaling FP8.
Args:
x: Input tensor to quantize
is_colwise: Whether to use column-wise quantization
dq_dtype: Data type for dequantized values
flatten_axis: The axis along which the tensor is flattened to 2D for quantization
Returns:
A ScaledTensor1x containing the quantized data
"""
......@@ -232,9 +240,12 @@ class DelayedScaleQuantizer(Quantizer):
scale_inv=scale_inv,
scaling_mode=self.scaling_mode,
dq_dtype=dq_dtype,
flatten_axis=flatten_axis,
)
def quantize(self, x, is_rowwise: bool = None, is_colwise: bool = None, dq_dtype=None):
def quantize(
self, x, is_rowwise: bool = None, is_colwise: bool = None, dq_dtype=None, flatten_axis=-1
):
"""Quantize a tensor using the internal _quantize_func().
Args:
......@@ -242,32 +253,40 @@ class DelayedScaleQuantizer(Quantizer):
is_rowwise: Whether to use row-wise quantization
is_colwise: Whether to use column-wise quantization
dq_dtype: Data type for dequantized values
flatten_axis: The axis along which the tensor is flattened to 2D for quantization
Returns:
A ScaledTensor1x or ScaledTensor2x containing the quantized data
"""
dq_dtype = dq_dtype if dq_dtype is not None else x.dtype
if flatten_axis < 0:
flatten_axis += x.ndim
assert 0 < flatten_axis < x.ndim, "flatten_axis is out of bounds!"
is_rowwise = (
is_rowwise
if is_rowwise is not None
else (self.q_axis == QuantizeAxis.ROWWISE or self.is_2x2x())
else (self.q_layout == QuantizeLayout.ROWWISE or self.is_2x2x())
)
is_colwise = (
is_colwise
if is_colwise is not None
else (self.q_axis == QuantizeAxis.COLWISE or self.is_2x2x())
else (self.q_layout == QuantizeLayout.COLWISE or self.is_2x2x())
)
rowwise_tensor = self._quantize_func(x, dq_dtype=dq_dtype)
rowwise_tensor = self._quantize_func(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis)
colwise_tensor = None
if is_colwise:
colwise_tensor = ScaledTensorFactory.create_1x(
data=jnp.transpose(rowwise_tensor.data, (-1, *range(rowwise_tensor.data.ndim - 1))),
data=jnp.transpose(
rowwise_tensor.data, (*range(flatten_axis, x.ndim), *range(flatten_axis))
),
scale_inv=rowwise_tensor.scale_inv,
scaling_mode=self.scaling_mode,
dq_dtype=dq_dtype,
is_colwise=True,
layout="T",
data_layout="T",
flatten_axis=flatten_axis,
)
if is_colwise and is_rowwise:
return ScaledTensor2x(rowwise_tensor, colwise_tensor)
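The colwise tensor above is just the rowwise data transposed so the dims at and after flatten_axis come first; a quick sketch of that permutation:

import jax.numpy as jnp

x = jnp.arange(2 * 3 * 4).reshape(2, 3, 4)
flatten_axis = 2                                  # the last dim is the "output" side
colwise = jnp.transpose(x, (*range(flatten_axis, x.ndim), *range(flatten_axis)))
assert colwise.shape == (4, 2, 3)                 # trailing dims moved to the front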
......@@ -353,46 +372,56 @@ class BlockScaleQuantizer(Quantizer):
Attributes:
scaling_mode: Set to NVTE_MXFP8_1D_SCALING
q_axis: Quantization axis (default: ROWWISE_COLWISE)
q_layout: Quantization layout (default: ROWWISE_COLWISE)
"""
scaling_mode: ScalingMode = ScalingMode.NVTE_MXFP8_1D_SCALING
q_axis: QuantizeAxis = QuantizeAxis.ROWWISE_COLWISE
q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE
def get_layout(self) -> str:
"""Get the data layout string.
def get_data_layout(self) -> str:
"""Get the data data_layout string.
Returns:
Data layout in string format
Data layout in string format
"""
if self.is_2x2x():
return "NN"
return "N"
def _quantize_func(self, x, is_colwise=False, dq_dtype=None) -> ScaledTensor1x:
def _quantize_func(self, x, is_colwise=False, dq_dtype=None, flatten_axis=-1) -> ScaledTensor1x:
"""Quantize function helper for block scaling FP8.
Args:
x: Input tensor to quantize
is_colwise: Whether to use column-wise quantization
dq_dtype: Data type for dequantized values
flatten_axis: The axis along which the tensor is flattened to 2D for quantization
Returns:
A ScaledTensor1x containing the quantized data
"""
# TODO(Phuong): use quantize_func from JAX
if flatten_axis < 0:
flatten_axis = x.ndim + flatten_axis
assert (
0 <= flatten_axis < x.ndim
), f"Invalid flatten_axis: {flatten_axis} for tensor of shape {x.shape}"
dq_dtype = dq_dtype if dq_dtype is not None else x.dtype
x_shape = x.shape
scale_shape = self.scaling_mode.get_scale_shape(x_shape, is_colwise, is_padded=False)
scale_shape = self.scaling_mode.get_scale_shape(
x_shape, is_colwise, is_padded=False, flatten_axis=flatten_axis
)
scale_dtype = self.scaling_mode.get_scale_dtype()
x = x.reshape(
*x_shape[:-2],
scale_shape[-2],
int(x_shape[-2] / scale_shape[-2]),
*x_shape[: flatten_axis - 1],
scale_shape[flatten_axis - 1],
int(x_shape[flatten_axis - 1] / scale_shape[flatten_axis - 1]),
*x_shape[flatten_axis:-1],
scale_shape[-1],
int(x_shape[-1] / scale_shape[-1]),
)
amax = jnp.max(jnp.abs(x), axis=(-3, -1), keepdims=True)
amax = jnp.max(jnp.abs(x), axis=(flatten_axis + 2 - 2, -1), keepdims=True)
MAX = jnp.finfo(self.q_dtype).max.astype(jnp.float32)
scales = amax.astype(jnp.float32) / MAX
......@@ -409,6 +438,7 @@ class BlockScaleQuantizer(Quantizer):
self.scaling_mode,
is_colwise=is_colwise,
dq_dtype=dq_dtype,
flatten_axis=flatten_axis,
)
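For intuition, a simplified row-wise-only version of the block scaling above; the real code also rounds the scale up to an E8M0 exponent and handles the colwise and flatten_axis cases:

import jax.numpy as jnp

x = jnp.ones((64, 2, 128), jnp.float32)
blocks = x.reshape(64, 2, 4, 32)                          # 32-wide blocks along the last dim
amax = jnp.max(jnp.abs(blocks), axis=-1, keepdims=True)   # one amax per block
fp8_max = jnp.finfo(jnp.float8_e4m3fn).max.astype(jnp.float32)
scales = amax / fp8_max
q = (blocks / jnp.maximum(scales, 2.0**-127)).astype(jnp.float8_e4m3fn)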
def _cast_to_e8m0_with_rounding_up(self, scales):
......@@ -509,7 +539,7 @@ class QuantizerFactory:
n_quantizers: int = 1,
scaling_mode: ScalingMode = None,
q_dtype: jnp.dtype = None,
q_axis: QuantizeAxis = None,
q_layout: QuantizeLayout = None,
**kwargs,
) -> Quantizer:
"""Create one or more quantizers with specified parameters.
......@@ -518,7 +548,8 @@ class QuantizerFactory:
n_quantizers: Number of quantizers to create
scaling_mode: Scaling mode to use
q_dtype: Quantization data type
q_axis: Quantization axis
q_layout: Quantization layout
flatten_axis: The axis along which tensors are flattened to 2D for quantization
**kwargs: Additional arguments for quantizer initialization
Returns:
......@@ -534,7 +565,7 @@ class QuantizerFactory:
quantizer_type = QuantizerFactory.quantizer_type_map.get(scaling_mode)
quantizers.append(
quantizer_type(
q_dtype=q_dtype, scaling_mode=scaling_mode, q_axis=q_axis, **kwargs
q_dtype=q_dtype, scaling_mode=scaling_mode, q_layout=q_layout, **kwargs
)
)
return quantizers[0] if len(quantizers) == 1 else tuple(quantizers)
......@@ -554,11 +585,11 @@ class QuantizerFactory:
A QuantizerSet instance
"""
if is_2x2x:
q_axis_x = q_axis_kernel = q_axis_dgrad = QuantizeAxis.ROWWISE_COLWISE
q_layout_x = q_layout_kernel = q_layout_dgrad = QuantizeLayout.ROWWISE_COLWISE
else:
q_axis_x = QuantizeAxis.ROWWISE
q_axis_kernel = QuantizeAxis.COLWISE
q_axis_dgrad = None
q_layout_x = QuantizeLayout.ROWWISE
q_layout_kernel = QuantizeLayout.COLWISE
q_layout_dgrad = None
if "quantize_meta_set" in kwargs:
quantize_meta_set = kwargs.get("quantize_meta_set")
......@@ -577,9 +608,11 @@ class QuantizerFactory:
else:
args_x = args_kernel = args_grad = {}
q_x = QuantizerFactory.create(1, scaling_mode, fwd_dtype, q_axis_x, **args_x)
q_kernel = QuantizerFactory.create(1, scaling_mode, fwd_dtype, q_axis_kernel, **args_kernel)
q_dgrad = QuantizerFactory.create(1, scaling_mode, bwd_dtype, q_axis_dgrad, **args_grad)
q_x = QuantizerFactory.create(1, scaling_mode, fwd_dtype, q_layout_x, **args_x)
q_kernel = QuantizerFactory.create(
1, scaling_mode, fwd_dtype, q_layout_kernel, **args_kernel
)
q_dgrad = QuantizerFactory.create(1, scaling_mode, bwd_dtype, q_layout_dgrad, **args_grad)
return QuantizerSet(x=q_x, kernel=q_kernel, dgrad=q_dgrad)
@staticmethod
......