Unverified Commit ff884e20 authored by Phuong Nguyen's avatar Phuong Nguyen Committed by GitHub
Browse files

[JAX] Flatten_axis for quantization and Sharding propagation fixes (#1644)



* rename QuantizeAxis to QuantizeLayout, get_layout to get_data_layout, q_axis to q_layout

* add fatten_axis option

* added gated act to test encoder

* sharding constraint fixes

* fix padding when flattening first dim needs to be padded

* update test sizes so that padding is tested

* rm output sharding as it can be done in the flax module

* sharding scale_inv for mxfp8

---------
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
parent be1f647c
......@@ -57,13 +57,14 @@ class Net(nn.Module):
self_attn_mask_type="padding",
enable_relative_embedding=False,
enable_sequence_parallel=self.enable_seq_paral,
mlp_activations=("gelu", "linear"),
)
x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)
x = x.reshape(x.shape[0], -1)
if self.enable_seq_paral:
# Trigger all-gather to collect a complete tensor alone seqence on each device.
# Trigger all-gather to collect a complete tensor alone sequence on each device.
x = jax.lax.with_sharding_constraint(
x, jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None)
)
......@@ -459,7 +460,7 @@ class TestEncoder(unittest.TestCase):
def test_te_bf16(self):
"""Test Transformer Engine with BF16"""
actual = train_and_evaluate(self.args)
assert actual[0] < 0.50 and actual[1] > 0.76
assert actual[0] < 0.455 and actual[1] > 0.785
@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_delayed_scaling_fp8(self):
......@@ -467,7 +468,7 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True
self.args.fp8_recipe = "DelayedScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.50 and actual[1] > 0.76
assert actual[0] < 0.455 and actual[1] > 0.785
@unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
def test_te_mxfp8(self):
......@@ -475,14 +476,14 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True
self.args.fp8_recipe = "MXFP8BlockScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.50 and actual[1] > 0.76
assert actual[0] < 0.455 and actual[1] > 0.785
@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16_with_sp(self):
"""Test Transformer Engine with BF16 + SP"""
self.args.enable_sp = True
actual = train_and_evaluate(self.args)
assert actual[0] < 0.50 and actual[1] > 0.76
assert actual[0] < 0.455 and actual[1] > 0.785
@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_delayed_scaling_fp8_with_sp(self):
......@@ -491,7 +492,7 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True
self.args.fp8_recipe = "DelayedScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.50 and actual[1] > 0.76
assert actual[0] < 0.455 and actual[1] > 0.785
@unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
def test_te_mxfp8_with_sp(self):
......@@ -500,7 +501,7 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True
self.args.fp8_recipe = "MXFP8BlockScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.50 and actual[1] > 0.76
assert actual[0] < 0.455 and actual[1] > 0.785
if __name__ == "__main__":
......
This diff is collapsed.
......@@ -45,11 +45,17 @@ if is_mxfp8_supported:
SUPPORTED_RECIPES.append(pytest.param(recipe.MXFP8BlockScaling(), id="MXFP8BlockScaling"))
DTYPES = [jnp.bfloat16, jnp.float16]
INPUT_SHAPE = [[2, 64, 64]] # [batch, seqlen, hidden_in]
INPUT_SHAPE = [[4, 64, 128]] # [batch, seqlen, hidden_in]
LAYERNORM_INPUT_AXES = (BATCH_AXES, SEQLEN_TP_AXES, HIDDEN_AXES)
DOT_1_INPUT_AXES = (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)
DOT_2_INPUT_AXES = (BATCH_AXES, SEQLEN_AXES, HIDDEN_TP_AXES)
KERNEL_1_AXES = (W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES)
KERNEL_2_AXES = (W_TP_AXES, W_FSDP_AXES)
LN_SCALE_AXES = (W_NO_SHARD_AXES,)
LN_BIAS_AXES = (W_NO_SHARD_AXES,)
BIAS_1_AXES = (W_JOINED_AXES, W_TP_AXES)
BIAS_2_AXES = (W_NO_SHARD_AXES,)
INTERMEDIATE = 64
......@@ -60,7 +66,6 @@ def generate_fsdp_and_tp_configs():
configs.append(
[2, (1, 2), ("fsdp", "tp"), MeshResource(fsdp_resource="fsdp", tp_resource="tp")]
)
if is_devices_enough(4):
configs.append(
[4, (2, 2), ("fsdp", "tp"), MeshResource(fsdp_resource="fsdp", tp_resource="tp")]
......@@ -80,13 +85,13 @@ class TestDistributedLayernormMLP:
x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype)
gamma = jax.random.normal(subkeys[5], (hidden_in,), dtype=dtype)
k1 = jax.random.normal(
subkeys[1], (hidden_in, len(activation_type) * INTERMEDIATE), dtype
subkeys[1], (hidden_in, len(activation_type), INTERMEDIATE), dtype
) / jnp.sqrt(hidden_in)
k2 = jax.random.normal(subkeys[2], (INTERMEDIATE, hidden_out), dtype) / jnp.sqrt(
INTERMEDIATE
)
if use_bias:
b1 = jax.random.normal(subkeys[3], (len(activation_type) * INTERMEDIATE), dtype)
b1 = jax.random.normal(subkeys[3], (len(activation_type), INTERMEDIATE), dtype)
b2 = jax.random.normal(subkeys[4], (hidden_out,), dtype)
else:
b1 = None
......@@ -111,10 +116,12 @@ class TestDistributedLayernormMLP:
layernorm_input_axes = LAYERNORM_INPUT_AXES
dot_1_input_axes = DOT_1_INPUT_AXES
dot_2_input_axes = DOT_2_INPUT_AXES
kernel_1_axes = KERNEL_1_AXES
kernel_2_axes = KERNEL_2_AXES
else:
layernorm_input_axes = None
dot_1_input_axes = None
dot_2_input_axes = None
dot_1_input_axes = dot_2_input_axes = None
kernel_1_axes = kernel_2_axes = None
quantizer_sets = QuantizerFactory.create_set(n_quantizer_sets=2)
......@@ -130,6 +137,8 @@ class TestDistributedLayernormMLP:
norm_input_axes=layernorm_input_axes,
dot_1_input_axes=dot_1_input_axes,
dot_2_input_axes=dot_2_input_axes,
kernel_1_axes=kernel_1_axes,
kernel_2_axes=kernel_2_axes,
activation_type=activation_type,
quantizer_sets=quantizer_sets,
)
......@@ -142,7 +151,7 @@ class TestDistributedLayernormMLP:
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("use_bias", [True, False])
@pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
def test_layernorm_fp8_mlp_primitive(
def test_layernorm_mlp_grad(
self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe
):
device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config
......@@ -168,12 +177,12 @@ class TestDistributedLayernormMLP:
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource):
k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", "tp"))
k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", None, "tp"))
k2_sharding = NamedSharding(mesh, PartitionSpec("tp", "fsdp"))
k1_ = jax.device_put(k1, k1_sharding)
k2_ = jax.device_put(k2, k2_sharding)
if use_bias:
b1_sharding = NamedSharding(mesh, PartitionSpec("tp"))
b1_sharding = NamedSharding(mesh, PartitionSpec(None, "tp"))
b1_ = jax.device_put(b1, b1_sharding)
else:
b1_sharding = b1_ = None
......@@ -267,16 +276,7 @@ class TestDistributedLayernormMLP:
transpose_batch_sequence=False, # input: [batch, seqlen, hidden]
intermediate_dim=INTERMEDIATE,
activations=activation_type,
scale_axes=(W_NO_SHARD_AXES,),
ln_bias_axes=(W_NO_SHARD_AXES,),
kernel_axes_1=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
kernel_axes_2=(W_TP_AXES, W_FSDP_AXES),
use_bias=use_bias,
bias_axes_1=(W_JOINED_AXES, W_TP_AXES),
bias_axes_2=(W_NO_SHARD_AXES,),
layernorm_input_axes=LAYERNORM_INPUT_AXES,
dot_1_input_axes=DOT_1_INPUT_AXES,
dot_2_input_axes=DOT_2_INPUT_AXES,
)
params_single = ln_mlp_single.init(init_rngs, x, deterministic=True)
mlp_out_single, ln_out_single = ln_mlp_single.apply(
......@@ -295,13 +295,13 @@ class TestDistributedLayernormMLP:
transpose_batch_sequence=False,
intermediate_dim=INTERMEDIATE,
activations=activation_type,
scale_axes=(W_NO_SHARD_AXES,),
ln_bias_axes=(W_NO_SHARD_AXES,),
kernel_axes_1=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
kernel_axes_2=(W_TP_AXES, W_FSDP_AXES),
scale_axes=LN_SCALE_AXES,
ln_bias_axes=LN_BIAS_AXES,
kernel_axes_1=KERNEL_1_AXES,
kernel_axes_2=KERNEL_2_AXES,
use_bias=use_bias,
bias_axes_1=(W_JOINED_AXES, W_TP_AXES),
bias_axes_2=(W_NO_SHARD_AXES,),
bias_axes_1=BIAS_1_AXES,
bias_axes_2=BIAS_2_AXES,
layernorm_input_axes=LAYERNORM_INPUT_AXES,
dot_1_input_axes=DOT_1_INPUT_AXES,
dot_2_input_axes=DOT_2_INPUT_AXES,
......@@ -334,7 +334,7 @@ class TestDistributedLayernormMLP:
@pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
def test_layernorm_fp8_mlp_layer(
def test_layernorm_mlp_layer_fp8(
self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe
):
self._test_layernorm_mlp(
......
......@@ -91,7 +91,6 @@ def _activation_bwd_rule(activation_type, ctx, g):
(x, _) = ctx
assert x.dtype == g.dtype
dx = tex.dact_lu(g, x, activation_type)
dx = jnp.reshape(dx, x.shape)
return (dx, None)
......
......@@ -6,9 +6,9 @@
from typing import Tuple, Sequence, Union, Dict, List
from functools import partial, reduce
import operator
from transformer_engine_jax import get_device_compute_capability
import jax
import jax.numpy as jnp
from transformer_engine_jax import get_device_compute_capability
from .base import BasePrimitive, register_primitive
......@@ -183,10 +183,9 @@ def __jitted_jax_gemm_delayed_scaling_fp8(lhs, rhs, lhs_dn, rhs_dn, precision):
# Reshape + Transpose
# [..., M, K] -> [B, M, K]
# [..., K, M] -> [B, M, K]
lhs_3d = _shape_normalization(lhs_dq, lhs_dn, lhs.layout == "N")
rhs_3d = _shape_normalization(rhs_dq, rhs_dn, rhs.layout == "T")
lhs_3d = _shape_normalization(lhs_dq, lhs_dn, lhs.data_layout == "N")
rhs_3d = _shape_normalization(rhs_dq, rhs_dn, rhs.data_layout == "T")
# _shape_normalization ensures contracting_dims=2 and batch_dims=0
dim_nums = (((2,), (2,)), ((0,), (0,)))
out_3d = jax.lax.dot_general(
lhs_3d, rhs_3d, dim_nums, precision=precision, preferred_element_type=lhs.dq_dtype
......@@ -203,9 +202,9 @@ def _jax_gemm_delayed_scaling_fp8(
), "rhs does not have delayed tensor scaling mode"
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums
if lhs.layout == "T":
if lhs.data_layout == "T":
lhs_contract = tuple((lhs.data.ndim - 1 - i) % lhs.data.ndim for i in lhs_contract)
if rhs.layout == "T":
if rhs.data_layout == "T":
rhs_contract = tuple((rhs.data.ndim - 1 - i) % rhs.data.ndim for i in rhs_contract)
lhs_dn = (lhs_contract, lhs_batch)
......@@ -403,19 +402,19 @@ def grouped_gemm(
lhs_shape = lhs.data.shape
rhs_shape = rhs.data.shape
out_dtype = lhs.dq_dtype
# For ScaledTensors and NVTE_DELAYED_TENSOR_SCALING, need to handle internal layout
# For ScaledTensors and NVTE_DELAYED_TENSOR_SCALING, need to handle internal data_layout
if lhs.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
assert not (
lhs.data.dtype == jnp.float8_e5m2 and rhs.data.dtype == jnp.float8_e5m2
), "FP8 GEMM does not support E5M2 * E5M2"
((lhs_contract_dim,), (rhs_contract_dim,)) = contracting_dims
if lhs.layout == "T":
if lhs.data_layout == "T":
lhs_contract_dim = (lhs_contract_dim - 1) % lhs.data.ndim
if rhs.layout == "T":
if rhs.data_layout == "T":
rhs_contract_dim = (rhs_contract_dim - 1) % rhs.data.ndim
dim_nums = ((lhs_contract_dim,), (rhs_contract_dim,)), ((), ())
else:
# For jnp.ndarray, only consider contracting_dims, layout is always NN
# For jnp.ndarray, only consider contracting_dims, data_layout is always NN
scaling_mode = ScalingMode.NVTE_NO_SCALING
lhs_shape = lhs.shape
rhs_shape = rhs.shape
......@@ -432,8 +431,8 @@ def grouped_gemm(
lhs_3d = _shape_normalization(lhs, lhs_dn)
rhs_3d = _shape_normalization(rhs, rhs_dn)
elif scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
lhs_3d = _shape_normalization(lhs.data, lhs_dn, lhs.layout == "N")
rhs_3d = _shape_normalization(rhs.data, rhs_dn, rhs.layout == "T")
lhs_3d = _shape_normalization(lhs.data, lhs_dn, lhs.data_layout == "N")
rhs_3d = _shape_normalization(rhs.data, rhs_dn, rhs.data_layout == "T")
elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING:
lhs_3d = _shape_normalization(lhs.data, lhs_dn)
rhs_3d = _shape_normalization(rhs.data, rhs_dn)
......
......@@ -19,7 +19,7 @@ from jax.interpreters.mlir import dtype_to_ir_type
import transformer_engine_jax
from ..sharding import get_padded_spec as te_get_padded_spec
from ..quantize import ScalingMode, ScaledTensorFactory, QuantizeAxis
from ..quantize import ScalingMode, ScaledTensorFactory, QuantizeLayout
TEDType = transformer_engine_jax.DType
......@@ -107,37 +107,37 @@ def normalize_axis_boundary(axis, ndim):
return axis if axis >= 0 else ndim + axis
def multidim_transpose(shape, static_axis_boundary=-1, transpose_axis_boundary=-1):
def multidim_transpose(shape, static_axis_boundary=-1, transpose_axis=-1):
"""
te_cast_transpose_p multi-dims transpose
static_axis_boundary: int, Indicate those axes <= static_axis_boundary would not be
involved into transpose, -1 means all axes involve into transpose.
transpose_axis_boundary: int, Indicate how to split multi-dimensions tensors to 2D matrix for
transpose. Note, transpose_axis_boundary should be greater than static_axis_boundary
transpose_axis: int, Indicate how to split multi-dimensions tensors to 2D matrix for
transpose. Note, transpose_axis should be greater than static_axis_boundary
examples:
X in shape (dim0, dim1, dim2, dim3, dim4)
static_axis_boundary == -1, transpose_axis_boundary == 2
static_axis_boundary == -1, transpose_axis == 2
Xt = (dim2, dim3, dim4, dim0, dim1)
static_axis_boundary == 0, transpose_axis_boundary == 2
static_axis_boundary == 0, transpose_axis == 2
Xt = (dim0, dim2, dim3, dim4, dim1)
static_axis_boundary == 0, transpose_axis_boundary == 3
static_axis_boundary == 0, transpose_axis == 3
Xt = (dim0, dim3, dim4, dim1. dim2)
"""
if static_axis_boundary < 0:
static_axis_boundary = -1 # means no static axes
assert static_axis_boundary < len(shape) - 2 # at least 2 remaining for transpose.
transpose_start_idx = static_axis_boundary + 1
transpose_axis_boundary = normalize_axis_boundary(transpose_axis_boundary, len(shape))
assert transpose_start_idx < transpose_axis_boundary
transpose_axis = normalize_axis_boundary(transpose_axis, len(shape))
assert transpose_start_idx < transpose_axis
return (
*shape[:transpose_start_idx],
*shape[transpose_axis_boundary:],
*shape[transpose_start_idx:transpose_axis_boundary],
*shape[transpose_axis:],
*shape[transpose_start_idx:transpose_axis],
)
......@@ -195,13 +195,13 @@ def should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias: bool = False, quant
break
return (
quantizer is not None
and quantizer.q_axis == QuantizeAxis.ROWWISE
and quantizer.q_layout == QuantizeLayout.ROWWISE
and arch_l_100
and is_dbias
)
def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, **kwargs):
def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, flatten_axis=-1, **kwargs):
"""
Applies a workaround for delayed scaling 2x and can be used when the TE common kernels do not yet support 2x delayed scaling.
It will call the given function 'f' with the given arguments and quantizer as 1x and calculate the colwise output by transposing result.
......@@ -224,14 +224,19 @@ def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, **kwargs):
# 2x is not supported by TE kernels for delayed scaling
# so revert to 1x and transpose in JAX
quantizer.q_axis = QuantizeAxis.ROWWISE
quantizer.q_layout = QuantizeLayout.ROWWISE
rowwise = f(*args, **kwargs, quantizer=quantizer)
other_outputs = None
if isinstance(rowwise, tuple):
other_outputs = rowwise[1:]
rowwise = rowwise[0]
quantizer.q_axis = QuantizeAxis.ROWWISE_COLWISE
colwise_data = jnp.transpose(rowwise.data, (-1, *range(rowwise.data.ndim - 1)))
quantizer.q_layout = QuantizeLayout.ROWWISE_COLWISE
if flatten_axis < 0:
flatten_axis += rowwise.data.ndim
assert 0 < flatten_axis < rowwise.data.ndim, "flatten_axis is out of bounds"
colwise_data = jnp.transpose(
rowwise.data, (*range(flatten_axis, rowwise.data.ndim), *range(flatten_axis))
)
output_2x = ScaledTensorFactory.create(
data=rowwise.data,
scale_inv=rowwise.scale_inv,
......@@ -239,8 +244,9 @@ def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, **kwargs):
colwise_scale_inv=rowwise.scale_inv,
scaling_mode=quantizer.scaling_mode,
dq_dtype=rowwise.dq_dtype,
q_axis=QuantizeAxis.ROWWISE_COLWISE,
layout=quantizer.get_layout(),
q_layout=QuantizeLayout.ROWWISE_COLWISE,
data_layout=quantizer.get_data_layout(),
flatten_axis=flatten_axis,
)
if other_outputs is not None:
return (output_2x,) + other_outputs
......
......@@ -30,7 +30,7 @@ from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_a
from ..quantize import ScaledTensor, ScaledTensorFactory
from ..quantize import (
Quantizer,
QuantizeAxis,
QuantizeLayout,
DelayedScaleQuantizer,
ScalingMode,
)
......@@ -277,14 +277,14 @@ class NormFwdPrimitive(BasePrimitive):
rowwise_scale_inv_shape, colwise_scale_inv_shape = scaling_mode.get_scale_shape_2x(
x.shape, is_padded=False
)
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING:
scale_inv = scale_inv.flatten()[
: reduce(operator.mul, rowwise_scale_inv_shape)
].reshape(rowwise_scale_inv_shape)
if is_2x:
colwise_scale_inv = colwise_scale_inv.flatten()[
: reduce(operator.mul, colwise_scale_inv_shape)
].reshape(colwise_scale_inv_shape)
# slice out padding for mxfp8, noop for DelayedScaling
scale_inv = scale_inv.flatten()[: reduce(operator.mul, rowwise_scale_inv_shape, 1)].reshape(
rowwise_scale_inv_shape
)
if is_2x:
colwise_scale_inv = colwise_scale_inv.flatten()[
: reduce(operator.mul, colwise_scale_inv_shape, 1)
].reshape(colwise_scale_inv_shape)
return (
out,
colwise_out,
......@@ -816,7 +816,7 @@ def layernorm_fwd(
return _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer)
# TE/common does not support normalization with colwise only quantization yet
if quantizer is not None and quantizer.q_axis == QuantizeAxis.COLWISE:
if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE:
return _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer)
scale = (
......@@ -900,8 +900,8 @@ def layernorm_fwd(
colwise_scale_inv=colwise_scale_inv,
scaling_mode=quantizer.scaling_mode,
dq_dtype=x.dtype,
q_axis=quantizer.q_axis,
layout=quantizer.get_layout(),
q_layout=quantizer.q_layout,
data_layout=quantizer.get_data_layout(),
)
return scaled_tensor, mu, rsigma
......@@ -997,7 +997,7 @@ def rmsnorm_fwd(
return _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer)
# TE/common does not support normalization with colwise only quantization yet
if quantizer is not None and quantizer.q_axis == QuantizeAxis.COLWISE:
if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE:
return _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer)
scale = (
......@@ -1082,8 +1082,8 @@ def rmsnorm_fwd(
colwise_scale_inv=colwise_scale_inv,
scaling_mode=quantizer.scaling_mode,
dq_dtype=x.dtype,
q_axis=quantizer.q_axis,
layout=quantizer.get_layout(),
q_layout=quantizer.q_layout,
data_layout=quantizer.get_data_layout(),
)
return scaled_tensor, rsigma
......
......@@ -11,14 +11,6 @@
#include "transformer_engine/cast.h"
#include "xla/ffi/api/c_api.h"
namespace {
bool is_gated(NVTE_Activation_Type act_type) {
return act_type == NVTE_Activation_Type::GEGLU || act_type == NVTE_Activation_Type::SWIGLU ||
act_type == NVTE_Activation_Type::REGLU || act_type == NVTE_Activation_Type::QGEGLU ||
act_type == NVTE_Activation_Type::SREGLU;
}
} // namespace
namespace transformer_engine {
namespace jax {
......@@ -44,38 +36,56 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal
auto act_len = input_dims[input_dims.size() - 2];
auto scaling_mode = static_cast<NVTEScalingMode>(scaling_mode_enum);
auto is_2x = static_cast<bool>(is_2x_int);
auto flatten_axis = output_buf->dimensions().size() - 1; // output does not have act axis
auto input_shape = std::vector<size_t>{m, act_len * n};
auto output_shape = std::vector<size_t>{m, n};
auto output_trans_shape = std::vector<size_t>{n, m};
auto input_tensor = TensorWrapper(input, input_shape, static_cast<DType>(in_dtype));
auto output_tensor = TensorWrapper(scaling_mode);
output_tensor.set_rowwise_data(output, static_cast<DType>(out_dtype), output_shape);
if (is_fp8_dtype(out_dtype)) {
output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
std::vector<size_t>{
product(scale_inv_buf->dimensions(), 0, scale_inv_buf->dimensions().size() - 1),
scale_inv_buf->dimensions().back()});
}
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING && is_fp8_dtype(out_dtype)) {
NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling");
cudaMemsetAsync(amax, 0, sizeof(float), stream);
output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling");
cudaMemsetAsync(amax, 0, sizeof(float), stream);
output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), std::vector<size_t>{1});
} else {
output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
std::vector<size_t>{product(scale_inv_buf->dimensions(), 0, flatten_axis),
product(scale_inv_buf->dimensions(), flatten_axis,
scale_inv_buf->dimensions().size())});
}
}
if (is_2x) {
output_tensor.set_columnwise_data(colwise_output, static_cast<DType>(out_dtype), output_shape);
output_tensor.set_columnwise_scale_inv(
colwise_scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(colwise_scale_inv_buf->element_type()),
std::vector<size_t>{product(colwise_scale_inv_buf->dimensions(), 0,
colwise_scale_inv_buf->dimensions().size() - 1),
colwise_scale_inv_buf->dimensions().back()});
auto &tmp_shape =
(scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? output_trans_shape : output_shape;
output_tensor.set_columnwise_data(colwise_output, out_dtype, tmp_shape);
if (is_fp8_dtype(out_dtype)) {
// For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling
auto &tmp_buf =
(scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? scale_inv_buf : colwise_scale_inv_buf;
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
output_tensor.set_columnwise_scale_inv(
tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()),
std::vector<size_t>{1});
} else {
output_tensor.set_columnwise_scale_inv(
tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()),
std::vector<size_t>{
product(tmp_buf->dimensions(), 0, flatten_axis),
product(tmp_buf->dimensions(), flatten_axis, tmp_buf->dimensions().size())});
}
}
}
switch (act_type) {
......@@ -162,8 +172,10 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid
}
if (is_2x) {
output_tensor.set_columnwise_data(reinterpret_cast<void *>(&temp), out_dtype,
output_trans_shape);
auto &tmp_shape = scaling_mode == static_cast<int>(NVTE_DELAYED_TENSOR_SCALING)
? output_trans_shape
: output_shape;
output_tensor.set_columnwise_data(reinterpret_cast<void *>(&temp), out_dtype, tmp_shape);
// Only the pointers will be checked for scale_inv, thus the shapes do not matter
if (is_fp8_dtype(out_dtype)) {
......@@ -190,9 +202,9 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid
Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
Buffer_Type act_input_buf, Buffer_Type scale_buf,
Result_Type output_buf, Result_Type output_trans_buf,
Result_Type scale_inv_buf, Result_Type trans_scale_inv_buf,
Result_Type amax_out_buf, Result_Type dbias_buf,
Result_Type output_buf, Result_Type colwise_output_buf,
Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
Result_Type amax_buf, Result_Type dbias_buf,
Result_Type workspace_buf, int64_t scaling_mode_enum, bool is_2x,
bool is_dbias, int64_t act_enum) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
......@@ -201,11 +213,15 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
auto *input = input_buf.untyped_data();
auto *act_input = act_input_buf.untyped_data();
float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
float *amax = reinterpret_cast<float *>(amax_buf->untyped_data());
auto scaling_mode = static_cast<NVTEScalingMode>(scaling_mode_enum);
auto act_type = static_cast<NVTE_Activation_Type>(act_enum);
auto flatten_axis = output_buf->dimensions().size() - 2; // output has act axis
auto *output = output_buf->untyped_data();
auto *output_trans = output_trans_buf->untyped_data();
auto *colwise_output = colwise_output_buf->untyped_data();
auto *dbias = dbias_buf->untyped_data();
void *workspace = workspace_buf->untyped_data();
......@@ -213,17 +229,18 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
auto act_input_dims = act_input_buf.dimensions();
auto workspace_dims = workspace_buf->dimensions();
// m = x_batch_size = reduce(operator.mul, x_shape[:-2]), x_shape == act_input_dims
// n = ir_dz_shape[-1], ir_dz_shape == input_dims
auto input_ranks = input_dims.size();
auto act_input_ranks = act_input_dims.size();
auto m = product(act_input_dims, 0, act_input_dims.size() - 1);
// 'n' will be 2x the size of input_dims.back() if the dactivation is dgated
auto n = act_input_dims.back();
auto input_shape = std::vector<size_t>{m, input_dims.back()};
auto act_input_shape = std::vector<size_t>{m, n};
auto output_shape = std::vector<size_t>{m, n};
auto output_trans_shape = std::vector<size_t>{m, n};
auto dbias_shape = std::vector<size_t>{n};
// n = ir_dz_shape[-1] * act_len, ir_dz_shape == input_dims
auto act_len = act_input_dims[act_input_dims.size() - 2];
NVTE_CHECK(act_input_dims.back() == input_dims.back(),
"Shape mismatch between activation input and gradient input");
auto m = product(act_input_dims, 0, act_input_dims.size() - 2);
auto n = input_dims.back();
auto input_shape = std::vector<size_t>{m, n};
auto act_input_shape = std::vector<size_t>{m, n * act_len};
auto output_shape = std::vector<size_t>{m, n * act_len};
auto output_trans_shape = std::vector<size_t>{n * act_len, m};
auto dbias_shape = std::vector<size_t>{n * act_len};
std::vector<size_t> workspace_shape(workspace_dims.begin(), workspace_dims.end());
auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
......@@ -231,50 +248,56 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
auto output_tensor = TensorWrapper(scaling_mode);
output_tensor.set_rowwise_data(output, out_dtype, output_shape);
if (is_fp8_dtype(out_dtype)) {
output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
std::vector<size_t>{
product(scale_inv_buf->dimensions(), 0, scale_inv_buf->dimensions().size() - 1),
scale_inv_buf->dimensions().back()});
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
float *amax_out = reinterpret_cast<float *>(amax_out_buf->untyped_data());
NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
NVTE_CHECK(amax_out != nullptr, "amax must be provided for delayed tensor scaling");
cudaMemsetAsync(amax_out, 0, sizeof(float), stream);
NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling");
cudaMemsetAsync(amax, 0, sizeof(float), stream);
output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_amax(amax_out, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), std::vector<size_t>{1});
} else {
output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
std::vector<size_t>{product(scale_inv_buf->dimensions(), 0, flatten_axis),
product(scale_inv_buf->dimensions(), flatten_axis,
scale_inv_buf->dimensions().size())});
}
}
if (is_2x) {
output_tensor.set_columnwise_data(output_trans, out_dtype, output_trans_shape);
auto &tmp_shape =
(scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? output_trans_shape : output_shape;
output_tensor.set_columnwise_data(colwise_output, out_dtype, tmp_shape);
if (is_fp8_dtype(out_dtype)) {
// For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling
auto &colwise_scale_inv_buf =
(scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? scale_inv_buf : trans_scale_inv_buf;
output_tensor.set_columnwise_scale_inv(
colwise_scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(colwise_scale_inv_buf->element_type()),
std::vector<size_t>{product(colwise_scale_inv_buf->dimensions(), 0,
colwise_scale_inv_buf->dimensions().size() - 1),
colwise_scale_inv_buf->dimensions().back()});
auto &tmp_buf =
(scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? scale_inv_buf : colwise_scale_inv_buf;
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
output_tensor.set_columnwise_scale_inv(
tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()),
std::vector<size_t>{1});
} else {
output_tensor.set_columnwise_scale_inv(
tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()),
std::vector<size_t>{
product(tmp_buf->dimensions(), 0, flatten_axis),
product(tmp_buf->dimensions(), flatten_axis, tmp_buf->dimensions().size())});
}
}
}
auto dbias_tensor = TensorWrapper(dbias, dbias_shape, in_dtype);
auto workspace_tensor = TensorWrapper(workspace, workspace_shape, workspace_dtype);
auto act_type = static_cast<NVTE_Activation_Type>(act_enum);
// fused_dgated_dbias is not available, so we use dact_lu + quantize_dbias in Python instead
NVTE_CHECK(!(is_gated(act_type) && is_dbias), "Unsupported DGatedActedDBias Fusion!");
NVTE_CHECK(!(scaling_mode == NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING && is_2x &&
is_gated(act_type)),
"TE/common does not support delayed scaling for 2x with gated activations.");
NVTE_CHECK(!(act_len == 2 && is_dbias), "Unsupported DGatedActedDBias Fusion!");
NVTE_CHECK(
!(scaling_mode == NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING && is_2x && act_len == 2),
"TE/common does not support delayed scaling for 2x with gated activations.");
if (is_dbias) {
switch (act_type) {
......
......@@ -44,12 +44,12 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh
cudaStreamSynchronize(stream);
// Notes on matrix layouts and transpose:
// Jax uses row-major layout, on entering this function, each input matrix pair:
// Jax uses row-major data_layout, on entering this function, each input matrix pair:
// A: row-major with size [m, k],
// B: row-major with size [n, k], needs transpose,
// on exiting this function, JAX expect:
// C: row-major with size [m, n].
// cuBLAS uses column-major layout, in this view, each input matrix pair:
// cuBLAS uses column-major data_layout, in this view, each input matrix pair:
// A: column-major with size [k, m], needs transpose,
// B: column-major with size [k, n].
// If we call cuBLAS GEMM for A * B, the output will be:
......
......@@ -34,7 +34,7 @@ inline size_t product(const std::vector<size_t> &shape) {
return ret;
}
enum class QuantizeAxis {
enum class QuantizeLayout {
ROWWISE,
COLWISE,
ROWWISE_COLWISE,
......
......@@ -144,11 +144,11 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
.value("NVTE_INVALID_SCALING", NVTEScalingMode::NVTE_MXFP8_1D_SCALING)
.export_values();
pybind11::enum_<transformer_engine::jax::QuantizeAxis>(m, "QuantizeAxis",
pybind11::module_local())
.value("ROWWISE", transformer_engine::jax::QuantizeAxis::ROWWISE)
.value("COLWISE", transformer_engine::jax::QuantizeAxis::COLWISE)
.value("ROWWISE_COLWISE", transformer_engine::jax::QuantizeAxis::ROWWISE_COLWISE)
pybind11::enum_<transformer_engine::jax::QuantizeLayout>(m, "QuantizeLayout",
pybind11::module_local())
.value("ROWWISE", transformer_engine::jax::QuantizeLayout::ROWWISE)
.value("COLWISE", transformer_engine::jax::QuantizeLayout::COLWISE)
.value("ROWWISE_COLWISE", transformer_engine::jax::QuantizeLayout::ROWWISE_COLWISE)
.export_values();
}
......
......@@ -42,10 +42,10 @@ pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_
Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf,
Result_Type output_buf, Result_Type output_trans_buf,
Result_Type scale_inv_buf, Result_Type trans_scale_inv_buf,
Result_Type amax_out_buf, Result_Type dbias_buf,
Result_Type workspace_buf, int64_t scaling_mode_enum,
int64_t quantize_axis_enum, bool is_dbias) {
Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
Result_Type amax_buf, Result_Type dbias_buf, Result_Type workspace_buf,
int64_t scaling_mode_enum, int64_t quantize_layout_enum, bool is_dbias,
int64_t flatten_axis) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type());
......@@ -55,7 +55,7 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
auto *input = input_buf.untyped_data();
auto scaling_mode = static_cast<NVTEScalingMode>(scaling_mode_enum);
auto const quantize_axis = static_cast<QuantizeAxis>(quantize_axis_enum);
auto const quantize_layout = static_cast<QuantizeLayout>(quantize_layout_enum);
auto *output = output_buf->untyped_data();
auto *output_trans = output_trans_buf->untyped_data();
......@@ -63,9 +63,13 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
void *workspace = workspace_buf->untyped_data();
auto input_dims = input_buf.dimensions();
int64_t input_ndim = input_dims.size();
if (flatten_axis < 0) flatten_axis += input_ndim;
NVTE_CHECK(flatten_axis < input_ndim && flatten_axis > 0, "flatten_axis is out of bounds!");
auto workspace_dims = workspace_buf->dimensions();
auto m = product(input_dims, 0, input_dims.size() - 1);
auto n = input_dims.back();
auto m = product(input_dims, 0, flatten_axis);
auto n = product(input_dims, flatten_axis, input_ndim);
auto input_shape = std::vector<size_t>{m, n};
auto output_shape = std::vector<size_t>{m, n};
auto output_trans_shape = std::vector<size_t>{n, m};
......@@ -75,37 +79,54 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
auto output_tensor = TensorWrapper(scaling_mode);
if (quantize_axis == QuantizeAxis::ROWWISE || quantize_axis == QuantizeAxis::ROWWISE_COLWISE) {
if (quantize_layout == QuantizeLayout::ROWWISE ||
quantize_layout == QuantizeLayout::ROWWISE_COLWISE) {
output_tensor.set_rowwise_data(output, out_dtype, output_shape);
output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
std::vector<size_t>{
product(scale_inv_buf->dimensions(), 0, scale_inv_buf->dimensions().size() - 1),
scale_inv_buf->dimensions().back()});
}
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
float *amax_out = reinterpret_cast<float *>(amax_out_buf->untyped_data());
NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
NVTE_CHECK(amax_out != nullptr, "amax must be provided for delayed tensor scaling");
output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
cudaMemsetAsync(amax_out, 0, sizeof(float), stream);
output_tensor.set_amax(amax_out, DType::kFloat32, std::vector<size_t>{1});
if (is_fp8_dtype(out_dtype)) {
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
float *amax = reinterpret_cast<float *>(amax_buf->untyped_data());
NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling");
output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
cudaMemsetAsync(amax, 0, sizeof(float), stream);
output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
std::vector<size_t>{1});
} else {
output_tensor.set_rowwise_scale_inv(
scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
std::vector<size_t>{product(scale_inv_buf->dimensions(), 0, flatten_axis),
product(scale_inv_buf->dimensions(), flatten_axis,
scale_inv_buf->dimensions().size())});
}
}
}
if (quantize_axis == QuantizeAxis::COLWISE || quantize_axis == QuantizeAxis::ROWWISE_COLWISE) {
output_tensor.set_columnwise_data(output_trans, out_dtype, output_trans_shape);
if (quantize_layout == QuantizeLayout::COLWISE ||
quantize_layout == QuantizeLayout::ROWWISE_COLWISE) {
auto &tmp_shape =
(scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? output_trans_shape : output_shape;
output_tensor.set_columnwise_data(output_trans, out_dtype, tmp_shape);
// For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling
auto &colwise_scale_inv_buf =
(scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? scale_inv_buf : trans_scale_inv_buf;
output_tensor.set_columnwise_scale_inv(
colwise_scale_inv_buf->untyped_data(),
convert_ffi_datatype_to_te_dtype(colwise_scale_inv_buf->element_type()),
std::vector<size_t>{product(colwise_scale_inv_buf->dimensions(), 0,
colwise_scale_inv_buf->dimensions().size() - 1),
colwise_scale_inv_buf->dimensions().back()});
auto &tmp_buf =
(scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? scale_inv_buf : colwise_scale_inv_buf;
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
output_tensor.set_columnwise_scale_inv(
tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()),
std::vector<size_t>{1});
} else {
output_tensor.set_columnwise_scale_inv(
tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()),
std::vector<size_t>{
product(tmp_buf->dimensions(), 0, flatten_axis),
product(tmp_buf->dimensions(), flatten_axis, tmp_buf->dimensions().size())});
}
}
auto dbias_tensor = TensorWrapper(dbias, dbias_shape, in_dtype);
......@@ -133,8 +154,9 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DBiasQuantizeHandler, DBiasQuantizeFFI,
.Ret<Buffer_Type>() // dbias
.Ret<Buffer_Type>() // wkspace
.Attr<int64_t>("scaling_mode")
.Attr<int64_t>("q_axis")
.Attr<bool>("is_dbias"),
.Attr<int64_t>("q_layout")
.Attr<bool>("is_dbias")
.Attr<int64_t>("flatten_axis"),
FFI_CudaGraph_Traits);
Error_Type DequantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf,
......
......@@ -15,7 +15,11 @@ import jax
import jax.numpy as jnp
from . import cpp_extensions as tex
from .quantize import QuantizerSet, noop_quantizer_set
from .quantize import (
QuantizerSet,
noop_quantizer_set,
with_sharding_constraint_by_logical_axes,
)
def dense(
......@@ -23,6 +27,8 @@ def dense(
kernel: jnp.ndarray,
bias: jnp.ndarray = None,
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)),
input_axes: Tuple[str, ...] = None,
kernel_axes: Tuple[str, ...] = None,
quantizer_set: QuantizerSet = noop_quantizer_set,
):
"""Perform dense layer transformation with optional quantization.
......@@ -48,12 +54,12 @@ def dense(
bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape
output += jnp.reshape(bias, bias_new_shape)
else:
output = _dense(x, kernel, bias, contracting_dims, quantizer_set)
output = _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set)
return output
@partial(jax.custom_vjp, nondiff_argnums=(3,))
def _dense(x, kernel, bias, contracting_dims, quantizer_set):
@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5))
def _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set):
"""Internal implementation of dense layer transformation with custom VJP.
This function implements the core dense layer transformation logic with support
......@@ -64,32 +70,37 @@ def _dense(x, kernel, bias, contracting_dims, quantizer_set):
kernel: Weight matrix
bias: Optional bias tensor
contracting_dims: Contracting dimensions specification
input_axes: Logical axes for sharding the activation input
kernel_axes: Logical axes for sharding the weight matrix
quantizer_set: QuantizerSet which contains quantizers for different tensor types
Returns:
Transformed output tensor
"""
output, _ = _dense_fwd_rule(x, kernel, bias, contracting_dims, quantizer_set)
output, _ = _dense_fwd_rule(
x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set
)
return output
def _dense_fwd_rule(x, kernel, bias, contracting_dims, quantizer_set):
def _dense_fwd_rule(x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set):
"""Forward pass rule for dense layer transformation.
Args:
x: Input tensor
kernel: Weight matrix
bias: Optional bias tensor
contracting_dims: Contracting dimensions specification
quantizer_set: QuantizerSet which contains quantizers for different tensor types
Returns:
Tuple of (output, context) for backward pass
"""
x_contracting_dims, k_contracting_dims = contracting_dims
casted_x = tex.quantize(x, quantizer_set.x)
casted_kernel = tex.quantize(kernel, quantizer_set.kernel)
flatten_axis_x = -len(x_contracting_dims)
flatten_axis_k = len(k_contracting_dims) - len(kernel.shape)
casted_x = tex.quantize(x, flatten_axis=flatten_axis_x, quantizer=quantizer_set.x)
casted_x = with_sharding_constraint_by_logical_axes(casted_x, input_axes)
casted_kernel = tex.quantize(
kernel, flatten_axis=flatten_axis_k, quantizer=quantizer_set.kernel
)
casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes)
# GEMM NN
output = tex.gemm(
......@@ -97,6 +108,7 @@ def _dense_fwd_rule(x, kernel, bias, contracting_dims, quantizer_set):
casted_kernel.get_colwise_tensor(),
(x_contracting_dims, k_contracting_dims),
)
use_bias = bias is not None
if use_bias:
bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape
......@@ -109,18 +121,16 @@ def _dense_fwd_rule(x, kernel, bias, contracting_dims, quantizer_set):
kernel.shape,
use_bias,
quantizer_set,
flatten_axis_k,
)
return output, ctx
def _dense_bwd_rule(contracting_dims, ctx, grad): # pylint: disable=unused-argument
def _dense_bwd_rule(
contracting_dims, input_axes, kernel_axes, ctx, grad
): # pylint: disable=unused-argument
"""Backward pass rule for dense layer transformation.
Args:
contracting_dims: Contracting dimensions specification
ctx: Context from forward pass
grad: Gradient from upstream
Returns:
Tuple of gradients with respect to inputs
"""
......@@ -133,9 +143,12 @@ def _dense_bwd_rule(contracting_dims, ctx, grad): # pylint: disable=unused-argu
kernel_shape,
use_bias,
quantizer_set,
flatten_axis_k,
) = ctx
casted_grad, dbias = tex.quantize_dbias(grad, is_dbias=use_bias, quantizer=quantizer_set.dgrad)
casted_grad, dbias = tex.quantize_dbias(
grad, is_dbias=use_bias, flatten_axis=flatten_axis_k, quantizer=quantizer_set.dgrad
)
# GEMM NT
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim
......@@ -151,6 +164,7 @@ def _dense_bwd_rule(contracting_dims, ctx, grad): # pylint: disable=unused-argu
rowwise_casted_kernel,
(g_constracting_dim, k_constracting_dim),
)
dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes)
# GEMM TN
# x_non_contracting_dims
......@@ -161,6 +175,7 @@ def _dense_bwd_rule(contracting_dims, ctx, grad): # pylint: disable=unused-argu
wgrad = tex.gemm(
colwise_casted_x, casted_grad.get_colwise_tensor(), (x_constracting_dim, g_constracting_dim)
)
wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes)
return dgrad, wgrad, dbias, quantizer_set
......
......@@ -28,6 +28,7 @@ from ..softmax import softmax, SoftmaxType
from ..sharding import with_sharding_constraint_by_logical_axes
from ..cpp_extensions import is_softmax_kernel_available
from ..quantize import QuantizerFactory, QuantizeConfig, QuantizeMeta, QuantizeMetaSet, ScalingMode
from ..sharding import get_non_contracting_logical_axes
PRNGKey = Any
Shape = Tuple[int, ...]
......@@ -406,6 +407,10 @@ class DenseGeneral(TransformerEngineBase):
:math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
axis: Union[Iterable[int], int], default = -1
An integer tuple with axes to apply the transformation on.
input_axes: Tuple[str, ...], default = None
Indicate the logical axes of sharding constraint to the input, like
(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
sharding constraint.
Optimization parameters
-----------------------
......@@ -429,6 +434,7 @@ class DenseGeneral(TransformerEngineBase):
axis: Union[Iterable[int], int] = -1
dtype: DType = jnp.float32
transpose_batch_sequence: bool = False
input_axes: Tuple[str, ...] = ()
def __post_init__(self):
if self.kernel_init is None:
......@@ -460,29 +466,35 @@ class DenseGeneral(TransformerEngineBase):
axis = _normalize_axes(axis, inputs.ndim)
kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features
if self.kernel_axes:
assert len(kernel_shape) == len(self.kernel_axes), (
"Expected len(kernel_shape) to match len(kernel_axes),"
f"got kernel_shape {kernel_shape} and kernel_axes {self.kernel_axes}"
)
kernel = nn_partitioning.param_with_axes(
"kernel", self.kernel_init, kernel_shape, self.dtype, axes=self.kernel_axes
)
if not QuantizeConfig.is_fp8_enabled():
kernel = kernel.astype(input_dtype)
kernel_compute_shape = (
reduce(operator.mul, [inputs.shape[ax] for ax in axis], 1),
reduce(operator.mul, features, 1),
)
kernel = jnp.reshape(kernel, kernel_compute_shape)
if self.use_bias:
bias = nn_partitioning.param_with_axes(
"bias", self.bias_init, features, self.dtype, axes=self.bias_axes
)
bias = bias.reshape(kernel_compute_shape[-1]).astype(input_dtype)
).astype(input_dtype)
else:
bias = None
quantizer_set = self.generate_quantizer_set()
contract_ind = tuple(range(0, len(axis)))
y = dense(
inputs, kernel, contracting_dims=(axis, contract_ind), quantizer_set=quantizer_set
inputs,
kernel,
contracting_dims=(axis, contract_ind),
input_axes=self.input_axes,
kernel_axes=self.kernel_axes,
quantizer_set=quantizer_set,
)
if self.enable_low_rank_adaptation:
......@@ -491,20 +503,14 @@ class DenseGeneral(TransformerEngineBase):
*features[:-1],
self.low_rank_adaptation_dim,
)
lora_a_kernel_init_shape = (
kernel_compute_shape[0],
*features[:-1],
self.low_rank_adaptation_dim,
)
lora_a_kernel_axes = (None,) * len(lora_a_kernel_init_shape)
lora_a_kernel_axes = (None,) * len(lora_a_kernel_shape)
lora_a_kernel = nn_partitioning.param_with_axes(
"lora_a_kernel",
self.kernel_init,
lora_a_kernel_init_shape,
lora_a_kernel_shape,
self.dtype,
axes=lora_a_kernel_axes,
)
lora_a_kernel = jnp.reshape(lora_a_kernel, lora_a_kernel_shape)
lora_a_kernel = lora_a_kernel.astype(input_dtype)
lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
......@@ -527,7 +533,6 @@ class DenseGeneral(TransformerEngineBase):
y += jnp.reshape(bias, bias_shape)
assert y.dtype == input_dtype
y = y.reshape(*inputs.shape[: self.axis], *features)
return y
......@@ -678,6 +683,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
The output tensors of layer normalization.
If :attr:`return_layernorm_output=False`, then this would be None.
"""
assert self.axis == -1, "Only support axis = =-1 at this moment"
input_dtype = inputs.dtype
ln_output = None
......@@ -692,10 +698,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
if self.enable_layernorm:
inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)
assert self.axis == -1 # Only support axis = =-1 at this moment
features = inputs.shape[-1]
scale, ln_bias = _create_layernorm_parameters(
self.layernorm_type,
(features,),
......@@ -731,17 +734,12 @@ class LayerNormDenseGeneral(TransformerEngineBase):
axis = _normalize_axes(axis, y.ndim)
kernel_shape = tuple(y.shape[ax] for ax in axis) + features
kernel_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features
kernel = nn_partitioning.param_with_axes(
"kernel", self.kernel_init, kernel_shape, self.dtype, axes=self.kernel_axes
)
if not QuantizeConfig.is_fp8_enabled():
kernel = kernel.astype(input_dtype)
kernel_compute_shape = (
reduce(operator.mul, [inputs.shape[ax] for ax in axis], 1),
reduce(operator.mul, features, 1),
)
kernel = jnp.reshape(kernel, kernel_compute_shape)
contract_ind = tuple(range(0, len(axis)))
......@@ -756,11 +754,19 @@ class LayerNormDenseGeneral(TransformerEngineBase):
epsilon=self.epsilon,
layernorm_input_axes=self.layernorm_input_axes,
dot_input_axes=self.dot_input_axes,
kernel_axes=self.kernel_axes,
quantizer_set=quantizer_set,
)
else:
y = with_sharding_constraint_by_logical_axes(y, self.dot_input_axes)
z = dense(y, kernel, contracting_dims=(axis, contract_ind), quantizer_set=quantizer_set)
z = dense(
y,
kernel,
contracting_dims=(axis, contract_ind),
input_axes=self.dot_input_axes,
kernel_axes=self.kernel_axes,
quantizer_set=quantizer_set,
)
if self.enable_low_rank_adaptation:
lora_a_kernel_shape = (
......@@ -768,20 +774,14 @@ class LayerNormDenseGeneral(TransformerEngineBase):
*features[:-1],
self.low_rank_adaptation_dim,
)
lora_a_kernel_init_shape = (
kernel_compute_shape[0],
*features[:-1],
self.low_rank_adaptation_dim,
)
lora_a_kernel_axes = (None,) * len(lora_a_kernel_init_shape)
lora_a_kernel_axes = (None,) * len(lora_a_kernel_shape)
lora_a_kernel = nn_partitioning.param_with_axes(
"lora_a_kernel",
self.kernel_init,
lora_a_kernel_init_shape,
lora_a_kernel_shape,
self.dtype,
axes=lora_a_kernel_axes,
)
lora_a_kernel = jnp.reshape(lora_a_kernel, lora_a_kernel_shape)
lora_a_kernel = lora_a_kernel.astype(input_dtype)
lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
......@@ -803,8 +803,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
if self.use_bias:
bias = nn_partitioning.param_with_axes(
"bias", self.bias_init, features, self.dtype, axes=self.bias_axes
)
bias = bias.reshape(kernel_compute_shape[-1]).astype(input_dtype)
).astype(input_dtype)
if bias is not None:
bias_shape = (1,) * (z.ndim - bias.ndim) + bias.shape
......@@ -814,7 +813,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
z = z / self.depth_scaling
assert z.dtype == input_dtype, f"output_dtype={z.dtype}, input_dtype={input_dtype}"
z = z.reshape(*inputs.shape[: self.axis], *features)
# z = z.reshape(*inputs.shape[: self.axis], *features)
return z, ln_output # dense_output, layer_norm_output
......@@ -989,6 +988,8 @@ class LayerNormMLP(TransformerEngineBase):
The output tensors of layer normalization.
If :attr:`return_layernorm_output=False`, then this would be None.
"""
assert self.axis == -1, "Only support axis == -1 at this moment"
ffn1_quantizer_set = self.generate_quantizer_set("_0")
ffn2_quantizer_set = self.generate_quantizer_set("_1")
......@@ -1027,7 +1028,6 @@ class LayerNormMLP(TransformerEngineBase):
)
# LayerNorm
if self.enable_layernorm:
assert self.axis == -1 # Only support axis == -1 at this moment
inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)
features = inputs.shape[-1]
......@@ -1071,7 +1071,7 @@ class LayerNormMLP(TransformerEngineBase):
num_activations = len(normalized_acts)
axis = _canonicalize_tuple(self.axis)
axis = _normalize_axes(axis, y.ndim)
kernel_1_each_shape = (np.prod([y.shape[ax] for ax in axis]), self.intermediate_dim)
kernel_1_each_shape = (np.prod([inputs.shape[ax] for ax in axis]), self.intermediate_dim)
kernel_1 = nn_partitioning.param_with_axes(
"wi_kernel",
kernel_1_init,
......@@ -1081,17 +1081,10 @@ class LayerNormMLP(TransformerEngineBase):
self.dtype,
axes=self.kernel_axes_1,
)
kernel_1_compute_shape = (
reduce(operator.mul, [y.shape[ax] for ax in axis], 1),
num_activations * self.intermediate_dim,
)
kernel_1 = jnp.reshape(kernel_1, kernel_1_compute_shape)
if not QuantizeConfig.is_fp8_enabled():
kernel_1 = kernel_1.astype(input_dtype)
if self.kernel_axes_1 is not None:
kernel_1 = with_sharding_constraint_by_logical_axes(
kernel_1, self.kernel_axes_1[:-2] + self.kernel_axes_1[-1:]
)
hidden_size = inputs.shape[-1]
hidden_size_tuple = _canonicalize_tuple(hidden_size)
kernel_2_shape = (self.intermediate_dim,) + hidden_size_tuple
......@@ -1102,27 +1095,20 @@ class LayerNormMLP(TransformerEngineBase):
self.dtype,
axes=self.kernel_axes_2,
)
kernel_2_compute_shape = (
self.intermediate_dim,
reduce(operator.mul, hidden_size_tuple, 1),
)
kernel_2 = jnp.reshape(kernel_2, kernel_2_compute_shape)
if not QuantizeConfig.is_fp8_enabled():
kernel_2 = kernel_2.astype(input_dtype)
if self.kernel_axes_2 is not None:
kernel_2 = with_sharding_constraint_by_logical_axes(kernel_2, self.kernel_axes_2)
contract_ind = tuple(range(0, len(axis)))
if self.use_bias:
bias_1_shape = num_activations * self.intermediate_dim
bias_1_shape = (num_activations, self.intermediate_dim)
bias_1 = nn_partitioning.param_with_axes(
"wi_bias",
self.bias_init,
bias_1_shape,
self.dtype,
axes=self.bias_axes_1,
)
bias_1 = bias_1.reshape(kernel_1_compute_shape[-1]).astype(input_dtype)
).astype(input_dtype)
bias_2_shape = (hidden_size,)
bias_2 = nn_partitioning.param_with_axes(
......@@ -1131,8 +1117,7 @@ class LayerNormMLP(TransformerEngineBase):
bias_2_shape,
self.dtype,
axes=self.bias_axes_2,
)
bias_2 = bias_2.reshape(kernel_2_compute_shape[-1]).astype(input_dtype)
).astype(input_dtype)
else:
bias_1 = None
bias_2 = None
......@@ -1141,8 +1126,6 @@ class LayerNormMLP(TransformerEngineBase):
ffn2_ckpt_name = "ffn2"
if use_fused_layernorm_mlp:
assert self.axis == -1 # Only support axis = =-1 at this moment
out = layernorm_mlp(
y,
scale,
......@@ -1155,6 +1138,8 @@ class LayerNormMLP(TransformerEngineBase):
norm_input_axes=self.layernorm_input_axes,
dot_1_input_axes=self.dot_1_input_axes,
dot_2_input_axes=self.dot_2_input_axes,
kernel_1_axes=self.kernel_axes_1,
kernel_2_axes=self.kernel_axes_2,
ffn1_ckpt_name=ffn1_ckpt_name,
ffn2_ckpt_name=ffn2_ckpt_name,
activation_type=normalized_acts,
......@@ -1175,6 +1160,7 @@ class LayerNormMLP(TransformerEngineBase):
epsilon=self.epsilon,
layernorm_input_axes=self.layernorm_input_axes,
dot_input_axes=self.dot_1_input_axes,
kernel_axes=self.kernel_axes_1,
quantizer_set=ffn1_quantizer_set,
)
else:
......@@ -1183,35 +1169,31 @@ class LayerNormMLP(TransformerEngineBase):
y,
kernel_1,
contracting_dims=(axis, contract_ind),
input_axes=self.dot_1_input_axes,
kernel_axes=self.kernel_axes_1,
quantizer_set=ffn1_quantizer_set,
)
dot_1_output_axes = (
*get_non_contracting_logical_axes(y.ndim, self.dot_1_input_axes, axis),
*get_non_contracting_logical_axes(kernel_1.ndim, self.kernel_axes_1, contract_ind),
)
x = with_sharding_constraint_by_logical_axes(x, dot_1_output_axes)
if self.enable_low_rank_adaptation:
wi_lora_a_kernel_shape = (
kernel_1_compute_shape[0],
num_activations,
self.low_rank_adaptation_dim,
)
wi_lora_a_kernel_init_shape = (
kernel_1_each_shape[0],
num_activations,
self.low_rank_adaptation_dim,
)
wi_lora_a_kernel_init_each_shape = (
kernel_1_each_shape[0],
wi_lora_a_kernel_each_shape = (
kernel_1_each_shape[: len(axis)],
self.low_rank_adaptation_dim,
)
wi_lora_a_kernel_axes = (None,) * len(wi_lora_a_kernel_init_shape)
wi_lora_a_kernel_axes = (None,) * len(wi_lora_a_kernel_each_shape + 1)
wi_lora_a_kernel = nn_partitioning.param_with_axes(
"wi_lora_a_kernel",
kernel_1_init,
num_activations,
-1,
wi_lora_a_kernel_init_each_shape,
-2,
wi_lora_a_kernel_each_shape,
self.dtype,
axes=wi_lora_a_kernel_axes,
)
wi_lora_a_kernel = jnp.reshape(wi_lora_a_kernel, wi_lora_a_kernel_shape)
wi_lora_a_kernel = wi_lora_a_kernel.astype(input_dtype)
wi_lora_b_kernel_shape = (
......@@ -1232,7 +1214,7 @@ class LayerNormMLP(TransformerEngineBase):
x += _apply_low_rank_adaptation(
y,
axis,
num_activations * self.intermediate_dim,
(num_activations, self.intermediate_dim),
wi_lora_a_kernel,
wi_lora_b_kernel,
self.low_rank_adaptation_alpha,
......@@ -1246,11 +1228,12 @@ class LayerNormMLP(TransformerEngineBase):
z = activation(x, normalized_acts)
else:
activations = []
x = jnp.split(x, num_activations, axis=-1)
x = jnp.split(x, num_activations, axis=-2)
for idx, act_fn in enumerate(normalized_acts):
x_i = _convert_to_activation_function(act_fn)(x[idx])
activations.append(x_i)
z = reduce(operator.mul, activations)
z = jnp.squeeze(z, axis=-2)
z = z.astype(input_dtype)
z = nn.Dropout(
......@@ -1264,7 +1247,12 @@ class LayerNormMLP(TransformerEngineBase):
# DenseGeneral 2
out = dense(
z, kernel_2, contracting_dims=(axis, contract_ind), quantizer_set=ffn2_quantizer_set
z,
kernel_2,
contracting_dims=(axis, contract_ind),
input_axes=self.dot_2_input_axes,
kernel_axes=self.kernel_axes_2,
quantizer_set=ffn2_quantizer_set,
)
if self.enable_low_rank_adaptation:
......
......@@ -33,10 +33,9 @@ def layernorm_dense(
norm_type: str = "layernorm",
zero_centered_gamma: bool = False,
epsilon: float = 1e-6,
# The logic axes of sharding constraint to the layernorm input.
layernorm_input_axes: Tuple[str, ...] = None,
# The logic axes of sharding constraint to the dot input.
dot_input_axes: Tuple[str, ...] = None,
kernel_axes: Tuple[str, ...] = None,
quantizer_set: QuantizerSet = noop_quantizer_set,
) -> jnp.ndarray:
"""Apply layer normalization followed by dense layer transformation.
......@@ -56,6 +55,7 @@ def layernorm_dense(
epsilon: Small constant for numerical stability in normalization
layernorm_input_axes: Logical axes for sharding the layernorm input
dot_input_axes: Logical axes for sharding the matrix multiplication input
kernel_axes: Logical axes for sharding the weight matrix
quantizer_set: Set of quantizers for different tensor types
Returns:
......@@ -78,6 +78,7 @@ def layernorm_dense(
epsilon,
layernorm_input_axes,
dot_input_axes,
kernel_axes,
quantizer_set,
)
return output
......@@ -91,6 +92,7 @@ def layernorm_dense(
7,
8,
9,
10,
),
)
def _layernorm_dense(
......@@ -104,6 +106,7 @@ def _layernorm_dense(
epsilon: float,
layernorm_input_axes: Tuple[str, ...],
dot_input_axes: Tuple[str, ...],
kernel_axes: Tuple[str, ...],
quantizer_set,
):
"""Internal implementation of layernorm_dense with custom VJP.
......@@ -139,6 +142,7 @@ def _layernorm_dense(
epsilon,
layernorm_input_axes,
dot_input_axes,
kernel_axes,
quantizer_set,
)
return output
......@@ -155,6 +159,7 @@ def _layernorm_dense_fwd_rule(
epsilon,
layernorm_input_axes,
dot_input_axes,
kernel_axes,
quantizer_set,
):
"""Forward pass rule for layernorm_dense.
......@@ -171,7 +176,6 @@ def _layernorm_dense_fwd_rule(
x_contracting_dims = (len(x.shape) - 1,)
k_contracting_dims = (0,)
assert x.shape[-1] == kernel.shape[0]
assert len(kernel.shape) == 2 # Otherwise need to merge dims in quantize
x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes)
......@@ -184,11 +188,12 @@ def _layernorm_dense_fwd_rule(
norm_type,
quantizer_set.x,
)
casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_input_axes)
# Kernel in (hidden_in, hidden_out...)
casted_kernel = tex.quantize(kernel, quantizer_set.kernel)
casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_input_axes)
flatten_axis = 1 - len(kernel.shape)
casted_kernel = tex.quantize(kernel, flatten_axis=flatten_axis, quantizer=quantizer_set.kernel)
casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes)
# NN GEMM
# (batch..., hidden_in) x (hidden_in, hidden_out...)
......@@ -217,6 +222,7 @@ def _layernorm_dense_fwd_rule(
k_contracting_dims,
use_bias,
quantizer_set,
flatten_axis,
)
return output, ctx
......@@ -228,6 +234,7 @@ def _layernorm_dense_bwd_rule(
epsilon,
layernorm_input_axes,
dot_input_axes, # pylint: disable=unused-argument
kernel_axes,
ctx,
grad,
):
......@@ -256,11 +263,12 @@ def _layernorm_dense_bwd_rule(
k_contracting_dims_in_fwd,
use_bias,
quantizer_set,
flatten_axis,
) = ctx
grad = with_sharding_constraint_by_logical_axes(grad, dot_input_axes)
casted_grad, dbias = tex.quantize_dbias(grad, is_dbias=use_bias, quantizer=quantizer_set.dgrad)
casted_grad, dbias = tex.quantize_dbias(
grad, is_dbias=use_bias, flatten_axis=flatten_axis, quantizer=quantizer_set.dgrad
)
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim
g_constracting_dim = tuple(
......@@ -291,6 +299,8 @@ def _layernorm_dense_bwd_rule(
(x_constracting_dim, g_constracting_dim),
)
wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes)
dx, dgamma, dbeta = tex.normalization_bwd(
dgrad,
x,
......
......@@ -23,6 +23,7 @@ from jax.ad_checkpoint import checkpoint_name
from . import cpp_extensions as tex
from .layernorm import canonicalize_norm_type
from .quantize import with_sharding_constraint_by_logical_axes, QuantizerSet, noop_quantizer_set
from .sharding import get_non_contracting_logical_axes
def layernorm_mlp(
......@@ -37,6 +38,8 @@ def layernorm_mlp(
norm_input_axes: Tuple[str, ...] = None,
dot_1_input_axes: Tuple[str, ...] = None,
dot_2_input_axes: Tuple[str, ...] = None,
kernel_1_axes: Tuple[str, ...] = None,
kernel_2_axes: Tuple[str, ...] = None,
ffn1_ckpt_name: str = "ffn1",
ffn2_ckpt_name: str = "ffn2",
activation_type: Sequence[Union[str, Callable]] = ("gelu",),
......@@ -66,6 +69,8 @@ def layernorm_mlp(
norm_input_axes: Logical axes for sharding the layernorm input
dot_1_input_axes: Logical axes for sharding the first matrix multiplication
dot_2_input_axes: Logical axes for sharding the second matrix multiplication
kernel_1_axes: Logical axes for sharding the first weight matrix
kernel_2_axes: Logical axes for sharding the second weight matrix
ffn1_ckpt_name: Name for checkpointing the first feed-forward network
ffn2_ckpt_name: Name for checkpointing the second feed-forward network
activation_type: Activation function(s) to apply after the first dense layer transformation
......@@ -109,6 +114,8 @@ def layernorm_mlp(
norm_input_axes,
dot_1_input_axes,
dot_2_input_axes,
kernel_1_axes,
kernel_2_axes,
ffn1_ckpt_name,
ffn2_ckpt_name,
activation_type,
......@@ -117,7 +124,7 @@ def layernorm_mlp(
return output
@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15))
@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17))
def _layernorm_mlp(
x: jnp.ndarray,
gamma: jnp.ndarray,
......@@ -132,6 +139,8 @@ def _layernorm_mlp(
norm_input_axes: Tuple[str, ...],
dot_1_input_axes: Tuple[str, ...],
dot_2_input_axes: Tuple[str, ...],
kernel_1_axes: Tuple[str, ...],
kernel_2_axes: Tuple[str, ...],
ffn1_ckpt_name: str,
ffn2_ckpt_name: str,
activation_type: Sequence[Union[str, Callable]],
......@@ -179,6 +188,8 @@ def _layernorm_mlp(
norm_input_axes,
dot_1_input_axes,
dot_2_input_axes,
kernel_1_axes,
kernel_2_axes,
ffn1_ckpt_name,
ffn2_ckpt_name,
activation_type,
......@@ -201,6 +212,8 @@ def _layernorm_mlp_fwd_rule(
norm_input_axes,
dot_1_input_axes,
dot_2_input_axes,
kernel_1_axes,
kernel_2_axes,
ffn1_ckpt_name,
ffn2_ckpt_name,
activation_type,
......@@ -220,20 +233,21 @@ def _layernorm_mlp_fwd_rule(
Returns:
Tuple of (output, context) for automatic differentiation
"""
del kernel_2_axes
ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets
# x should be in shape of (batch..., hidden)
# Kernel_1 should be in shape of (hidden_in, activation_len * intermediate)
# Kernel_1 should be in shape of (hidden_in, activation_len, intermediate)
# Kernel_2 should be in shape of (intermediate, hidden_in)
assert len(kernel_1.shape) == 2
assert len(kernel_1.shape) == 3
assert len(kernel_2.shape) == 2
assert kernel_1.shape[1] == kernel_2.shape[0] * len(activation_type)
assert kernel_1.shape[-2] == len(activation_type)
x_contracting_dims = (len(x.shape) - 1,)
k_contracting_dims = (0,)
assert x.shape[x_contracting_dims[0]] == kernel_1.shape[k_contracting_dims[0]]
assert kernel_1.shape[1] == len(activation_type) * kernel_2.shape[0]
use_bias_1 = bias_1 is not None
use_bias_2 = bias_1 is not None
......@@ -249,11 +263,10 @@ def _layernorm_mlp_fwd_rule(
norm_type,
quantizer=ffn1_quantizer_set.x,
)
casted_kernel_1 = tex.quantize(kernel_1, quantizer=ffn1_quantizer_set.kernel)
casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_1_input_axes)
casted_kernel_1 = tex.quantize(kernel_1, flatten_axis=-2, quantizer=ffn1_quantizer_set.kernel)
# NN GEMM
# (batch..., hidden_in) x (hidden_in, hidden_out)
dot_1_output = tex.gemm(
......@@ -261,6 +274,13 @@ def _layernorm_mlp_fwd_rule(
casted_kernel_1.get_colwise_tensor(),
(x_contracting_dims, k_contracting_dims),
)
dot_1_output_axes = (
*get_non_contracting_logical_axes(x.ndim, dot_1_input_axes, x_contracting_dims),
*get_non_contracting_logical_axes(kernel_1.ndim, kernel_1_axes, k_contracting_dims),
)
dot_1_output = with_sharding_constraint_by_logical_axes(dot_1_output, dot_1_output_axes)
if use_bias_1:
bias_1_shape = bias_1.shape
bias_1_new_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1_shape
......@@ -283,6 +303,12 @@ def _layernorm_mlp_fwd_rule(
(x_contracting_dims, k_contracting_dims),
)
dot_2_output_axes = (
*get_non_contracting_logical_axes(x.ndim, dot_2_input_axes, x_contracting_dims),
*get_non_contracting_logical_axes(kernel_2.ndim, None, k_contracting_dims),
)
dot_2_output = with_sharding_constraint_by_logical_axes(dot_2_output, dot_2_output_axes)
if use_bias_2:
bias_2_shape = bias_2.shape
bias_2_new_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2_shape
......@@ -320,8 +346,10 @@ def _layernorm_mlp_bwd_rule(
norm_input_axes,
dot_1_input_axes,
dot_2_input_axes,
ffn1_ckpt_name, # pylint: disable=unused-argument
ffn2_ckpt_name, # pylint: disable=unused-argument
kernel_1_axes,
kernel_2_axes,
ffn1_ckpt_name,
ffn2_ckpt_name,
activation_type,
ctx,
grad,
......@@ -339,6 +367,7 @@ def _layernorm_mlp_bwd_rule(
Returns:
Tuple of gradients for all input parameters
"""
del norm_input_axes, ffn1_ckpt_name, ffn2_ckpt_name
(
x,
mu,
......@@ -369,11 +398,11 @@ def _layernorm_mlp_bwd_rule(
)
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim
g_constracting_dim_2 = tuple(
g_contracting_dims_2 = tuple(
range(grad.ndim - len(kernel_2_shape) + len(k_contracting_dims_in_fwd), grad.ndim)
)
# k_non_contracting_dims
k_constracting_dim_2 = tuple(
k_contracting_dims_2 = tuple(
dim for dim in range(len(kernel_2_shape)) if dim not in k_contracting_dims_in_fwd
)
......@@ -382,12 +411,12 @@ def _layernorm_mlp_bwd_rule(
dgrad_2 = tex.gemm(
casted_grad.get_rowwise_tensor(),
rowwise_casted_kernel_2,
(g_constracting_dim_2, k_constracting_dim_2),
(g_contracting_dims_2, k_contracting_dims_2),
)
dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes)
x_constracting_dim = g_constracting_dim = tuple(
x_contracting_dims = g_contracting_dims = tuple(
range(0, len(x.shape) - len(x_contracting_dims_in_fwd))
)
......@@ -396,8 +425,9 @@ def _layernorm_mlp_bwd_rule(
wgrad_2 = tex.gemm(
colwise_casted_act_out,
casted_grad.get_colwise_tensor(),
(x_constracting_dim, g_constracting_dim),
(x_contracting_dims, g_contracting_dims),
)
wgrad_2 = with_sharding_constraint_by_logical_axes(wgrad_2, kernel_2_axes)
casted_dact_out, dbias_1 = tex.quantize_dact_dbias(
dgrad_2,
......@@ -408,11 +438,12 @@ def _layernorm_mlp_bwd_rule(
)
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim
g_constracting_dim_1 = tuple(
range(dgrad_2.ndim - len(kernel_1_shape) + len(k_contracting_dims_in_fwd), dgrad_2.ndim)
dact_out_ndim = casted_dact_out.get_rowwise_tensor().data.ndim
g_contracting_dims_1 = tuple(
range(dact_out_ndim - len(kernel_1_shape) + len(k_contracting_dims_in_fwd), dact_out_ndim)
)
# k_non_contracting_dims
k_constracting_dim_1 = tuple(
k_contracting_dims_1 = tuple(
dim for dim in range(len(kernel_1_shape)) if dim not in k_contracting_dims_in_fwd
)
......@@ -420,19 +451,21 @@ def _layernorm_mlp_bwd_rule(
dgrad_1 = tex.gemm(
casted_dact_out.get_rowwise_tensor(),
rowwise_casted_kernel_1,
(g_constracting_dim_1, k_constracting_dim_1),
(g_contracting_dims_1, k_contracting_dims_1),
)
dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, norm_input_axes)
dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, dot_1_input_axes)
# TN GEMM
# (hidden, batch...) x (hidden, batch...)
wgrad_1 = tex.gemm(
colwise_casted_ln_out,
casted_dact_out.get_colwise_tensor(),
(x_constracting_dim, g_constracting_dim),
(x_contracting_dims, g_contracting_dims),
)
wgrad_1 = with_sharding_constraint_by_logical_axes(wgrad_1, kernel_1_axes)
dx, dgamma, dbeta = tex.normalization_bwd(
dgrad_1,
x,
......
......@@ -57,18 +57,27 @@ class Dequantizer:
data = scaled_tensor.data.astype(jnp.float32)
data_shape = data.shape
scale = scaled_tensor.scale_inv.view(jnp.uint8).astype(jnp.float32)
flatten_axis = scaled_tensor.flatten_axis
flatten_axis = len(data_shape) + flatten_axis if flatten_axis < 0 else flatten_axis
assert (
0 < flatten_axis < len(data_shape)
), f"flatten_axis {flatten_axis} is out of bounds for shape {data_shape}"
scale_shape = scaled_tensor.scaling_mode.get_scale_shape(
scaled_tensor.data.shape, scaled_tensor.is_colwise, is_padded=False
data_shape, scaled_tensor.is_colwise, is_padded=False, flatten_axis=flatten_axis
)
scale = jax.lax.slice(scale, [0] * len(scale_shape), scale_shape) # slice out the padding
data = data.reshape(
*data_shape[:-2],
scale_shape[-2],
int(data_shape[-2] / scale_shape[-2]),
*data_shape[: flatten_axis - 1],
scale_shape[flatten_axis - 1],
int(data_shape[flatten_axis - 1] / scale_shape[flatten_axis - 1]),
*data_shape[flatten_axis:-1],
scale_shape[-1],
int(data_shape[-1] / scale_shape[-1]),
)
scale = jnp.expand_dims(scale, axis=(-1, -3))
# E8M0 does not have a bit for sign. So 0 - 127 represent negative numbers.
scale = jnp.expand_dims(scale, axis=(flatten_axis + 2 - 2, -1))
# E8M0 does not have a bit for sign. So 0 - 127 represent negative numbers.
return jnp.asarray(data * jnp.power(2, scale - 127), scaled_tensor.dq_dtype).reshape(
data_shape
......
......@@ -14,7 +14,7 @@ from typing import Union, Optional
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class
from transformer_engine_jax import QuantizeAxis
from transformer_engine_jax import QuantizeLayout
from .scaling_modes import ScalingMode
from .tensor import ScaledTensor1x, ScaledTensor2x, ScaledTensorFactory
......@@ -24,7 +24,7 @@ from .helper import (
)
__all__ = [
"QuantizeAxis",
"QuantizeLayout",
"Quantizer",
"QuantizerSet",
"DelayedScaleQuantizer",
......@@ -45,12 +45,12 @@ class Quantizer(ABC):
Attributes:
q_dtype: The data type for quantized values
scaling_mode: The scaling mode to use for quantization
q_axis: The quantization axis (row-wise, column-wise, or both)
q_layout: The quantization axis (row-wise, column-wise, or both)
"""
q_dtype: jnp.dtype
scaling_mode: ScalingMode
q_axis: QuantizeAxis
q_layout: QuantizeLayout
def tree_flatten(self):
"""Flatten the quantizer for JAX tree operations.
......@@ -59,7 +59,7 @@ class Quantizer(ABC):
Tuple of (children, aux_data) for tree operations
"""
children = ()
aux_data = (self.q_dtype, self.scaling_mode, self.q_axis)
aux_data = (self.q_dtype, self.scaling_mode, self.q_layout)
return (children, aux_data)
@classmethod
......@@ -85,30 +85,31 @@ class Quantizer(ABC):
Returns:
True if using both row-wise and column-wise quantization
"""
return self.q_axis == QuantizeAxis.ROWWISE_COLWISE
return self.q_layout == QuantizeLayout.ROWWISE_COLWISE
@abstractmethod
def get_layout(self) -> str:
"""Get the data layout.
def get_data_layout(self) -> str:
"""Get the data data_layout.
Returns:
Data layout in string format
Data data_layout in string format
"""
@abstractmethod
def _quantize_func(self, x, is_colwise=False, dq_dtype=None) -> ScaledTensor1x:
def _quantize_func(self, x, is_colwise=False, dq_dtype=None, flatten_axis=-1) -> ScaledTensor1x:
"""Core quantization function to be implemented by subclasses.
Args:
x: Input tensor to quantize
is_colwise: Whether to use column-wise quantization
dq_dtype: Data type for dequantized values, default is x.dtype
flatten_axis: The quantization axis for the tensor
Returns:
A ScaledTensor1x containing the quantized data
"""
def quantize(self, x, is_rowwise=False, is_colwise=False, dq_dtype=None):
def quantize(self, x, is_rowwise=False, is_colwise=False, dq_dtype=None, flatten_axis=-1):
"""Quantize a tensor using the internal _quantize_func().
Args:
......@@ -116,21 +117,26 @@ class Quantizer(ABC):
is_rowwise: Whether to use row-wise quantization
is_colwise: Whether to use column-wise quantization
dq_dtype: Data type for dequantized values
flatten_axis: The quantization axis for the tensor
Returns:
A ScaledTensor1x or ScaledTensor2x containing the quantized data
"""
if (is_rowwise and is_colwise) or self.is_2x2x():
rowwise_tensor = self._quantize_func(x, dq_dtype=dq_dtype)
colwise_tensor = self._quantize_func(x, is_colwise=True, dq_dtype=dq_dtype)
rowwise_tensor = self._quantize_func(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis)
colwise_tensor = self._quantize_func(
x, is_colwise=True, dq_dtype=dq_dtype, flatten_axis=flatten_axis
)
return ScaledTensor2x(rowwise_tensor, colwise_tensor)
if is_colwise:
return self._quantize_func(x, is_colwise=True, dq_dtype=dq_dtype)
return self._quantize_func(
x, is_colwise=True, dq_dtype=dq_dtype, flatten_axis=flatten_axis
)
return self._quantize_func(x, dq_dtype=dq_dtype)
return self._quantize_func(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis)
def get_scale_shapes(self, data_shape, is_padded=True):
def get_scale_shapes(self, data_shape, is_padded=True, flatten_axis=-1):
"""Get shapes for scale tensors.
Args:
......@@ -140,7 +146,7 @@ class Quantizer(ABC):
Returns:
Tuple of (rowwise_scale_shape, colwise_scale_shape)
"""
return self.scaling_mode.get_scale_shape_2x(data_shape, is_padded)
return self.scaling_mode.get_scale_shape_2x(data_shape, is_padded, flatten_axis)
def get_scale_dtype(self):
"""Get the data type for scale tensors.
......@@ -161,13 +167,13 @@ class DelayedScaleQuantizer(Quantizer):
Attributes:
scaling_mode: Set to NVTE_DELAYED_TENSOR_SCALING
q_axis: Quantization axis (default: ROWWISE_COLWISE)
q_layout: Quantization axis (default: ROWWISE_COLWISE)
scale: Current scaling factor
amax_history: History of maximum absolute values
"""
scaling_mode: ScalingMode = ScalingMode.NVTE_DELAYED_TENSOR_SCALING
q_axis: QuantizeAxis = QuantizeAxis.ROWWISE_COLWISE
q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE
scale: jnp.ndarray = field(default_factory=lambda: jnp.ones((1,), jnp.float32))
amax_history: jnp.ndarray = field(
......@@ -181,35 +187,37 @@ class DelayedScaleQuantizer(Quantizer):
Tuple of (children, aux_data) for tree operations
"""
children = (self.scale, self.amax_history)
aux_data = (self.q_dtype, self.scaling_mode, self.q_axis)
aux_data = (self.q_dtype, self.scaling_mode, self.q_layout)
return (children, aux_data)
def get_layout(self) -> str:
"""Get the data layout string.
def get_data_layout(self) -> str:
"""Get the data data_layout string.
Returns:
Data layout in string format
Data data_layout in string format
Raises:
ValueError: If quantization axis is invalid
"""
layout = "NT"
if self.q_axis == QuantizeAxis.ROWWISE_COLWISE:
return layout
if self.q_axis == QuantizeAxis.ROWWISE:
return layout[0]
if self.q_axis == QuantizeAxis.COLWISE:
return layout[1]
raise ValueError(f"Invalid q_axis: {self.q_axis}")
def _quantize_func(self, x: jnp.ndarray, is_colwise=False, dq_dtype=None) -> ScaledTensor1x:
data_layout = "NT"
if self.q_layout == QuantizeLayout.ROWWISE_COLWISE:
return data_layout
if self.q_layout == QuantizeLayout.ROWWISE:
return data_layout[0]
if self.q_layout == QuantizeLayout.COLWISE:
return data_layout[1]
raise ValueError(f"Invalid q_layout: {self.q_layout}")
def _quantize_func(
self, x: jnp.ndarray, is_colwise=False, dq_dtype=None, flatten_axis=-1
) -> ScaledTensor1x:
"""Quantize function helper for delayed scaling FP8.
Args:
x: Input tensor to quantize
is_colwise: Whether to use column-wise quantization
dq_dtype: Data type for dequantized values
flatten_axis: The quantization axis for the tensor
Returns:
A ScaledTensor1x containing the quantized data
"""
......@@ -232,9 +240,12 @@ class DelayedScaleQuantizer(Quantizer):
scale_inv=scale_inv,
scaling_mode=self.scaling_mode,
dq_dtype=dq_dtype,
flatten_axis=flatten_axis,
)
def quantize(self, x, is_rowwise: bool = None, is_colwise: bool = None, dq_dtype=None):
def quantize(
self, x, is_rowwise: bool = None, is_colwise: bool = None, dq_dtype=None, flatten_axis=-1
):
"""Quantize a tensor using the internal _quantize_func().
Args:
......@@ -242,32 +253,40 @@ class DelayedScaleQuantizer(Quantizer):
is_rowwise: Whether to use row-wise quantization
is_colwise: Whether to use column-wise quantization
dq_dtype: Data type for dequantized values
flatten_axis: The quantization axis for the tensor
Returns:
A ScaledTensor1x or ScaledTensor2x containing the quantized data
"""
dq_dtype = dq_dtype if dq_dtype is not None else x.dtype
if flatten_axis < 0:
flatten_axis += x.ndim
assert 0 < flatten_axis < x.ndim, "flatten_axis is out of bounds!"
is_rowwise = (
is_rowwise
if is_rowwise is not None
else (self.q_axis == QuantizeAxis.ROWWISE or self.is_2x2x())
else (self.q_layout == QuantizeLayout.ROWWISE or self.is_2x2x())
)
is_colwise = (
is_colwise
if is_colwise is not None
else (self.q_axis == QuantizeAxis.COLWISE or self.is_2x2x())
else (self.q_layout == QuantizeLayout.COLWISE or self.is_2x2x())
)
rowwise_tensor = self._quantize_func(x, dq_dtype=dq_dtype)
rowwise_tensor = self._quantize_func(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis)
colwise_tensor = None
if is_colwise:
colwise_tensor = ScaledTensorFactory.create_1x(
data=jnp.transpose(rowwise_tensor.data, (-1, *range(rowwise_tensor.data.ndim - 1))),
data=jnp.transpose(
rowwise_tensor.data, (*range(flatten_axis, x.ndim), *range(flatten_axis))
),
scale_inv=rowwise_tensor.scale_inv,
scaling_mode=self.scaling_mode,
dq_dtype=dq_dtype,
is_colwise=True,
layout="T",
data_layout="T",
flatten_axis=flatten_axis,
)
if is_colwise and is_rowwise:
return ScaledTensor2x(rowwise_tensor, colwise_tensor)
......@@ -353,46 +372,56 @@ class BlockScaleQuantizer(Quantizer):
Attributes:
scaling_mode: Set to NVTE_MXFP8_1D_SCALING
q_axis: Quantization axis (default: ROWWISE_COLWISE)
q_layout: Quantization axis (default: ROWWISE_COLWISE)
"""
scaling_mode: ScalingMode = ScalingMode.NVTE_MXFP8_1D_SCALING
q_axis: QuantizeAxis = QuantizeAxis.ROWWISE_COLWISE
q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE
def get_layout(self) -> str:
"""Get the data layout string.
def get_data_layout(self) -> str:
"""Get the data data_layout string.
Returns:
Data layout in string format
Data data_layout in string format
"""
if self.is_2x2x():
return "NN"
return "N"
def _quantize_func(self, x, is_colwise=False, dq_dtype=None) -> ScaledTensor1x:
def _quantize_func(self, x, is_colwise=False, dq_dtype=None, flatten_axis=-1) -> ScaledTensor1x:
"""Quantize function helper for block scaling FP8.
Args:
x: Input tensor to quantize
is_colwise: Whether to use column-wise quantization
dq_dtype: Data type for dequantized values
flatten_axis: The quantization axis for the tensor
Returns:
A ScaledTensor1x containing the quantized data
"""
# TODO(Phuong): use quantize_func from JAX
if flatten_axis < 0:
flatten_axis = x.ndim + flatten_axis
assert (
0 <= flatten_axis < x.ndim
), f"Invalid flatten_axis: {flatten_axis} for tensor of shape {x.shape}"
dq_dtype = dq_dtype if dq_dtype is not None else x.dtype
x_shape = x.shape
scale_shape = self.scaling_mode.get_scale_shape(x_shape, is_colwise, is_padded=False)
scale_shape = self.scaling_mode.get_scale_shape(
x_shape, is_colwise, is_padded=False, flatten_axis=flatten_axis
)
scale_dtype = self.scaling_mode.get_scale_dtype()
x = x.reshape(
*x_shape[:-2],
scale_shape[-2],
int(x_shape[-2] / scale_shape[-2]),
*x_shape[: flatten_axis - 1],
scale_shape[flatten_axis - 1],
int(x_shape[flatten_axis - 1] / scale_shape[flatten_axis - 1]),
*x_shape[flatten_axis:-1],
scale_shape[-1],
int(x_shape[-1] / scale_shape[-1]),
)
amax = jnp.max(jnp.abs(x), axis=(-3, -1), keepdims=True)
amax = jnp.max(jnp.abs(x), axis=(flatten_axis + 2 - 2, -1), keepdims=True)
MAX = jnp.finfo(self.q_dtype).max.astype(jnp.float32)
scales = amax.astype(jnp.float32) / MAX
......@@ -409,6 +438,7 @@ class BlockScaleQuantizer(Quantizer):
self.scaling_mode,
is_colwise=is_colwise,
dq_dtype=dq_dtype,
flatten_axis=flatten_axis,
)
def _cast_to_e8m0_with_rounding_up(self, scales):
......@@ -509,7 +539,7 @@ class QuantizerFactory:
n_quantizers: int = 1,
scaling_mode: ScalingMode = None,
q_dtype: jnp.dtype = None,
q_axis: QuantizeAxis = None,
q_layout: QuantizeLayout = None,
**kwargs,
) -> Quantizer:
"""Create one or more quantizers with specified parameters.
......@@ -518,7 +548,8 @@ class QuantizerFactory:
n_quantizers: Number of quantizers to create
scaling_mode: Scaling mode to use
q_dtype: Quantization data type
q_axis: Quantization axis
q_layout: Quantization axis
flatten_axis: The quantization axis for the tensor
**kwargs: Additional arguments for quantizer initialization
Returns:
......@@ -534,7 +565,7 @@ class QuantizerFactory:
quantizer_type = QuantizerFactory.quantizer_type_map.get(scaling_mode)
quantizers.append(
quantizer_type(
q_dtype=q_dtype, scaling_mode=scaling_mode, q_axis=q_axis, **kwargs
q_dtype=q_dtype, scaling_mode=scaling_mode, q_layout=q_layout, **kwargs
)
)
return quantizers[0] if len(quantizers) == 1 else tuple(quantizers)
......@@ -554,11 +585,11 @@ class QuantizerFactory:
A QuantizerSet instance
"""
if is_2x2x:
q_axis_x = q_axis_kernel = q_axis_dgrad = QuantizeAxis.ROWWISE_COLWISE
q_layout_x = q_layout_kernel = q_layout_dgrad = QuantizeLayout.ROWWISE_COLWISE
else:
q_axis_x = QuantizeAxis.ROWWISE
q_axis_kernel = QuantizeAxis.COLWISE
q_axis_dgrad = None
q_layout_x = QuantizeLayout.ROWWISE
q_layout_kernel = QuantizeLayout.COLWISE
q_layout_dgrad = None
if "quantize_meta_set" in kwargs:
quantize_meta_set = kwargs.get("quantize_meta_set")
......@@ -577,9 +608,11 @@ class QuantizerFactory:
else:
args_x = args_kernel = args_grad = {}
q_x = QuantizerFactory.create(1, scaling_mode, fwd_dtype, q_axis_x, **args_x)
q_kernel = QuantizerFactory.create(1, scaling_mode, fwd_dtype, q_axis_kernel, **args_kernel)
q_dgrad = QuantizerFactory.create(1, scaling_mode, bwd_dtype, q_axis_dgrad, **args_grad)
q_x = QuantizerFactory.create(1, scaling_mode, fwd_dtype, q_layout_x, **args_x)
q_kernel = QuantizerFactory.create(
1, scaling_mode, fwd_dtype, q_layout_kernel, **args_kernel
)
q_dgrad = QuantizerFactory.create(1, scaling_mode, bwd_dtype, q_layout_dgrad, **args_grad)
return QuantizerSet(x=q_x, kernel=q_kernel, dgrad=q_dgrad)
@staticmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment