Unverified Commit c3b7c2ae authored by Phuong Nguyen, committed by GitHub

Revert "[JAX] GroupedDense v.2 without dynamic shape" (#1874)

Revert "[JAX] GroupedDense v.2 without dynamic shape (#1721)"

This reverts commit 5d01ef21.

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent 4d4f1edb
@@ -40,11 +40,10 @@ from transformer_engine.jax.quantize import (
     ScalingMode,
     QuantizerFactory,
     QuantizeLayout,
-    noop_quantizer_set,
 )
 from transformer_engine.jax.quantize import helper
 from transformer_engine.jax.activation import activation
-from transformer_engine.jax.dense import dense, grouped_dense
+from transformer_engine.jax.dense import dense
 from transformer_engine.jax.layernorm_dense import layernorm_dense

 GEMM_CASES = [
@@ -1205,6 +1204,24 @@ class TestFusedDense:
         assert_allclose(prim_x_grad, ref_x_grad, dtype=q_dtype)


+# This function is modified from transformer_engine/jax/cpp_extensions/gemm.py::_jax_gemm()
+def _quantize_gemm_pair(lhs, rhs, contracting_dims, lhs_quantizer, rhs_quantizer):
+    ((lhs_contract_dim,), (rhs_contract_dim,)) = contracting_dims
+    lhs_is_rowwise = lhs_contract_dim == lhs.ndim - 1
+    rhs_is_rowwise = rhs_contract_dim == rhs.ndim - 1
+    lhs_q = lhs_quantizer.quantize(
+        lhs,
+        is_rowwise=lhs_is_rowwise,
+        is_colwise=not lhs_is_rowwise,
+    )
+    rhs_q = rhs_quantizer.quantize(
+        rhs,
+        is_rowwise=rhs_is_rowwise,
+        is_colwise=not rhs_is_rowwise,
+    )
+    return lhs_q, rhs_q
+
+
 # E5M2 * E5M2 is not supported
 fwd_bwd_dtypes = [
     [jnp.float8_e4m3fn, jnp.float8_e4m3fn],
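For context on the helper restored above: `_quantize_gemm_pair` picks each operand's quantization layout from its contracting dimension — an operand contracted along its last axis is quantized rowwise, otherwise colwise. A minimal sketch of that decision rule on its own (plain Python, no TE quantizers; the `pick_layouts` name is ours):

def pick_layouts(lhs_ndim, rhs_ndim, contracting_dims):
    # Single contracting dim per operand, as _quantize_gemm_pair assumes
    ((lhs_cdim,), (rhs_cdim,)) = contracting_dims
    lhs_is_rowwise = lhs_cdim == lhs_ndim - 1  # K is the innermost axis
    rhs_is_rowwise = rhs_cdim == rhs_ndim - 1
    return lhs_is_rowwise, rhs_is_rowwise

# "NN" data layout: lhs [M, K] contracts dim 1, rhs [K, N] contracts dim 0
assert pick_layouts(2, 2, ((1,), (0,))) == (True, False)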
@@ -1212,194 +1229,219 @@ fwd_bwd_dtypes = [
     [jnp.float8_e5m2, jnp.float8_e4m3fn],
 ]


-GROUPED_DENSE_INPUT_SHAPES = [
-    # (n_groups, m, n, k), the actual m will be multiplied by 32
-    (5, 32, 128, 64),  # Test the case where n_groups is not a multiple of 4
-    (8, 64, 32, 128),
-    (8, 64, 128, 256),
-]
-
-
-@pytest_parametrize_wrapper("input_shape", GROUPED_DENSE_INPUT_SHAPES)
+"""
+@pytest_parametrize_wrapper(
+    "shape_list", [[(512, 128, 256), (256, 128, 256), (256, 128, 128), (512, 256, 128)]]
+)
 class TestGroupedDense:
-    def _ref_grouped_dense(self, lhs, rhs, bias, group_sizes, contracting_dims):
-        lhs_contract_dim, _ = contracting_dims
-        assert len(lhs_contract_dim) == 1 and lhs.ndim == 2 and rhs.ndim == 3
-        if bias is None:
-            bias = jnp.zeros((rhs.shape[0], rhs.shape[2]), dtype=lhs.dtype)
-        else:
-            assert bias.ndim == 2 and bias.shape == (rhs.shape[0], rhs.shape[2])
-        remaining_axis = (set(range(lhs.ndim)) - set(lhs_contract_dim)).pop()
-        lhs = jnp.split(lhs, jnp.cumulative_sum(group_sizes)[:-1], axis=remaining_axis)
-        rhs = jnp.split(rhs, rhs.shape[0], axis=0)
-        bias = jnp.split(bias, bias.shape[0], axis=0)
-        ref_out = []
-        dim_num = (contracting_dims, ((), ()))
-        for lhs_i, rhs_i, bias_i in zip(lhs, rhs, bias):
-            out_i = jax.lax.dot_general(lhs_i, rhs_i, dim_num) + jnp.expand_dims(bias_i, axis=0)
-            ref_out.append(jnp.squeeze(out_i))
-        return ref_out
-
-    def _generate_grouped_dense_input(self, dtype, input_shape, data_layout="NN", with_bias=False):
+    def _ref_grouped_gemm_with_jnp_dot(self, lhs_list, rhs_list, contracting_dims_list):
+        ref_out_list = []
+        for lhs, rhs, contracting_dims in zip(lhs_list, rhs_list, contracting_dims_list):
+            dim_nums = (contracting_dims, ((), ()))
+            ref_out_list.append(jax.lax.dot_general(lhs, rhs, dim_nums))
+        return ref_out_list
+
+    def _generate_grouped_gemm_input(self, dtype, shape_list, layout_list):
         key = jax.random.PRNGKey(0)
-        subkeys = jax.random.split(key, 4)
-        n_groups, m, n, k = input_shape
-
-        group_sizes = jnp.sort(jax.random.randint(subkeys[0], (n_groups - 1,), 0, m))
-        group_sizes = jnp.concatenate([jnp.array([0]), group_sizes, jnp.array([m])])
-        group_sizes = jnp.diff(group_sizes)
-        assert group_sizes.sum() == m
-
-        # *32 to make sure that input shape works for MXFP8
-        group_sizes = group_sizes * 32
-        m = m * 32
-
-        lhs_shape = (m if data_layout[0] == "N" else k, k if data_layout[0] == "N" else m)
-        rhs_shape = (n_groups, k if data_layout[1] == "N" else n, n if data_layout[1] == "N" else k)
-        bias_shape = (n_groups, n)
-
-        lhs = jax.random.uniform(subkeys[1], lhs_shape, dtype=dtype)
-        rhs = jax.random.uniform(subkeys[2], rhs_shape, dtype=dtype)
-        bias = jax.random.uniform(subkeys[3], bias_shape, dtype=dtype) if with_bias else None
-
-        lhs_contracting_dim = (1,) if data_layout[0] == "N" else (0,)
-        rhs_contracting_dim = (1,) if data_layout[1] == "N" else (2,)
-        contracting_dims = (lhs_contracting_dim, rhs_contracting_dim)
-
-        return lhs, rhs, group_sizes, contracting_dims, bias
-
-    def _assert_grouped_gemm_output(self, out, group_sizes, ref_list, dtype):
-        assert out.dtype == ref_list[0].dtype
-        out_list = jnp.split(out, jnp.cumulative_sum(group_sizes)[:-1], axis=0)
-        for i in range(len(ref_list)):
-            assert_allclose(out_list[i], ref_list[i], dtype=dtype)
+        subkeys = jax.random.split(key, len(shape_list) * 2)
+        lhs_list, rhs_list, contracting_dims_list = [], [], []
+
+        for i, ((m, n, k), data_layout) in enumerate(zip(shape_list, layout_list)):
+            lhs = jax.random.uniform(
+                subkeys[2 * i],
+                (m if data_layout[0] == "N" else k, k if data_layout[0] == "N" else m),
+                dtype=dtype,
+            )
+            rhs = jax.random.uniform(
+                subkeys[2 * i + 1],
+                (k if data_layout[1] == "N" else n, n if data_layout[1] == "N" else k),
+                dtype=dtype,
+            )
+            lhs_contracting_dim = (1,) if data_layout[0] == "N" else (0,)
+            rhs_contracting_dim = (0,) if data_layout[1] == "N" else (1,)
+            contracting_dims = (lhs_contracting_dim, rhs_contracting_dim)
+
+            lhs_list.append(lhs)
+            rhs_list.append(rhs)
+            contracting_dims_list.append(contracting_dims)
+
+        return lhs_list, rhs_list, contracting_dims_list

     @pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16])
-    @pytest_parametrize_wrapper("layout", ["NN"])
-    def test_grouped_gemm_fp16(self, dtype, input_shape, layout):
-        lhs, rhs, group_sizes, contracting_dims, _ = self._generate_grouped_dense_input(
-            dtype, input_shape, layout
+    @pytest_parametrize_wrapper("layout_list", [["NN", "TN", "NT", "TT"]])
+    def test_grouped_gemm_fp16(self, dtype, shape_list, layout_list):
+        lhs_list, rhs_list, contracting_dims_list = self._generate_grouped_gemm_input(
+            dtype, shape_list, layout_list
         )
-        ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims)
-        prim_out = tex.grouped_gemm(lhs, rhs, group_sizes, contracting_dims)
-        self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, dtype)
+        ref_out = self._ref_grouped_gemm_with_jnp_dot(lhs_list, rhs_list, contracting_dims_list)
+        primitive_out = tex.grouped_gemm(lhs_list, rhs_list, contracting_dims_list)
+        for i in range(len(shape_list)):
+            assert_allclose(primitive_out[i], ref_out[i], dtype=dtype)

     @pytest.mark.skipif(not is_fp8_supported, reason=reason)
     @pytest.mark.parametrize("fwd_bwd_dtype", fwd_bwd_dtypes)
     @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
-    @pytest_parametrize_wrapper("layout", ["NN"])
-    def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape, layout):
-        if scaling_mode == ScalingMode.MXFP8_1D_SCALING:
-            pytest.skip("MXFP8 is not supported in grouped_gemm yet")
+    @pytest_parametrize_wrapper("layout_list", [["NN", "TN", "NT", "TT"]])
+    def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, shape_list, layout_list):
         fwd_dtype, bwd_dtype = fwd_bwd_dtype
         quantizer_set = QuantizerFactory.create_set(
-            scaling_mode=scaling_mode,
-            fwd_dtype=fwd_dtype,
-            bwd_dtype=bwd_dtype,
-            is_2x2x=False,
-            n_groups=input_shape[0],
+            scaling_mode=scaling_mode, fwd_dtype=fwd_dtype, bwd_dtype=bwd_dtype, is_2x2x=False
         )
-        # quantizer_set.{x, kernel} has fwd_dtype, while quantizer_set.grad has bwd_dtype
-        # We want to test E4M3 * E5M2, manually set the quantizer_set.kernel.q_dtype to bwd_dtype
-        quantizer_set.kernel.q_dtype = bwd_dtype
-        for quantizer in quantizer_set.kernel.quantizers:
-            quantizer.q_dtype = bwd_dtype
         out_dtype = jnp.bfloat16
-        lhs, rhs, group_sizes, contracting_dims, _ = self._generate_grouped_dense_input(
-            out_dtype, input_shape, layout
-        )
-        ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims)
-        prim_out = tex.grouped_gemm(
-            lhs, rhs, group_sizes, contracting_dims, quantizer_set=quantizer_set
-        )
+        lhs_list, rhs_list, contracting_dims_list = self._generate_grouped_gemm_input(
+            out_dtype, shape_list, layout_list
+        )
+        q_lhs_list = []
+        q_rhs_list = []
+        for lhs, rhs, contracting_dims in zip(lhs_list, rhs_list, contracting_dims_list):
+            # quantizer_set.x and quantizer_set.kernel have the same q_dtype, we want to
+            # test the case where lhs and rhs have different q_dtypes
+            q_lhs, q_rhs = _quantize_gemm_pair(
+                lhs, rhs, contracting_dims, quantizer_set.x, quantizer_set.dgrad
+            )
+            q_lhs_list.append(q_lhs)
+            q_rhs_list.append(q_rhs)
+        ref_out = self._ref_grouped_gemm_with_jnp_dot(lhs_list, rhs_list, contracting_dims_list)
+        primitive_out = tex.grouped_gemm(q_lhs_list, q_rhs_list, contracting_dims_list)
         allclose_dtype = jnp.float8_e4m3fn
-        if jnp.float8_e5m2 in fwd_bwd_dtype:
+        if fwd_dtype == jnp.float8_e5m2 or bwd_dtype == jnp.float8_e5m2:
             allclose_dtype = jnp.float8_e5m2
-
-        self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, allclose_dtype)
-
-    def _ref_sum_grouped_dense(self, x, kernel, bias, group_sizes, contracting_dims):
-        out_list = self._ref_grouped_dense(x, kernel, bias, group_sizes, contracting_dims)
-        # Note: we use jnp.sum instead of jnp.mean to make the gradient larger
-        # and prevent them from being clamp to zero
-        out_sum_list = [jnp.sum(out) for out in out_list]
-        return jnp.sum(jnp.asarray(out_sum_list))
-
-    def _primitive_sum_grouped_dense(
-        self, x, kernel, bias, group_sizes, contracting_dims, quantizer_set=noop_quantizer_set
-    ):
-        out = grouped_dense(
-            x, kernel, group_sizes, contracting_dims, bias=bias, quantizer_set=quantizer_set
-        )
-        return jnp.sum(jnp.asarray(out))
-
-    @pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16])
-    def test_grouped_dense_grad_fp16(self, dtype, input_shape):
-        x, kernel, group_sizes, contracting_dims, bias = self._generate_grouped_dense_input(
-            dtype,
-            input_shape,
-            with_bias=True,
-        )
-
-        value_n_grad_ref_func = value_and_grad(self._ref_sum_grouped_dense, (0, 1, 2))
-        value_n_grad_prim_func = value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2))
-
-        ref_out_sum, (ref_dgrad, ref_wgrad, ref_dbias) = value_n_grad_ref_func(
-            x, kernel, bias, group_sizes, contracting_dims
-        )
-        prim_out_sum, (prim_dgrad, prim_wgrad, prim_dbias) = value_n_grad_prim_func(
-            x, kernel, bias, group_sizes, contracting_dims
-        )
-
-        assert_allclose(prim_out_sum, ref_out_sum, dtype=dtype)
-        assert_allclose(prim_dgrad, ref_dgrad, dtype=dtype)
-        assert_allclose(prim_wgrad, ref_wgrad, dtype=dtype)
-        assert_allclose(prim_dbias, ref_dbias, dtype=dtype)
+        for i in range(len(shape_list)):
+            assert_allclose(primitive_out[i], ref_out[i], dtype=allclose_dtype)
+
+    @pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16])
+    def test_grouped_dense_grad_fp16(self, dtype, shape_list):
+        group_size = len(shape_list)
+        layout_list = ["NN" for _ in range(group_size)]
+
+        x_list, kernel_list, contracting_dims_list = self._generate_grouped_gemm_input(
+            dtype, shape_list, layout_list
+        )
+        bias_list = []
+        key = jax.random.PRNGKey(1)
+        for shape in shape_list:
+            n = shape[1]
+            bias = jax.random.uniform(key, n, dtype=dtype)
+            bias_list.append(bias)
+
+        def ref_func(x_list, kernel_list, bias_list, contracting_dims_list):
+            out_list = []
+            for i in range(len(x_list)):
+                out_list.append(
+                    dense(
+                        x_list[i],
+                        kernel_list[i],
+                        bias_list[i],
+                        contracting_dims=contracting_dims_list[i],
+                    )
+                )
+            # Note: we use jnp.sum instead of jnp.mean to make the gradient larger
+            # and prevent them from being clamp to zero
+            out_sum_list = [jnp.sum(out) for out in out_list]
+            return jnp.sum(jnp.asarray(out_sum_list))
+
+        def primitive_func(x_list, kernel_list, bias_list, contracting_dims_list):
+            out_list = grouped_dense(x_list, kernel_list, bias_list, contracting_dims_list)
+            out_sum_list = [jnp.sum(out) for out in out_list]
+            return jnp.sum(jnp.asarray(out_sum_list))
+
+        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2))
+        value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2))
+
+        ref_out_mean, (ref_dgrad_list, ref_wgrad_list, ref_dbias_list) = value_n_grad_ref_func(
+            x_list, kernel_list, bias_list, contracting_dims_list
+        )
+        primitive_out_mean, (primitive_dgrad_list, primitive_wgrad_list, primitive_dbias_list) = (
+            value_n_grad_primitive_func(x_list, kernel_list, bias_list, contracting_dims_list)
+        )
+
+        assert_allclose(primitive_out_mean, ref_out_mean, dtype=dtype)
+        for i in range(group_size):
+            assert_allclose(primitive_dgrad_list[i], ref_dgrad_list[i], dtype=dtype)
+            assert_allclose(primitive_wgrad_list[i], ref_wgrad_list[i], dtype=dtype)
+            assert_allclose(primitive_dbias_list[i], ref_dbias_list[i], dtype=dtype)

     @pytest.mark.skipif(not is_fp8_supported, reason=reason)
-    @pytest.mark.parametrize(
-        "fwd_bwd_dtype",
-        [(jnp.float8_e4m3fn, jnp.float8_e4m3fn), (jnp.float8_e4m3fn, jnp.float8_e5m2)],
-    )
+    @pytest.mark.parametrize("fwd_bwd_dtype", fwd_bwd_dtypes)
     @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
-    def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape):
-        if scaling_mode == ScalingMode.MXFP8_1D_SCALING:
-            pytest.skip("MXFP8 is not supported in grouped_dense yet")
+    def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, shape_list):
+        group_size = len(shape_list)
+        layout_list = ["NN" for _ in range(group_size)]
         fwd_dtype, bwd_dtype = fwd_bwd_dtype
-        dtype = jnp.bfloat16
-        x, kernel, group_sizes, contracting_dims, bias = self._generate_grouped_dense_input(
-            dtype,
-            input_shape,
-            with_bias=True,
-        )
-
-        quantizer_set = QuantizerFactory.create_set(
-            scaling_mode=scaling_mode,
-            fwd_dtype=fwd_dtype,
-            bwd_dtype=bwd_dtype,
-            is_2x2x=True,
-            n_groups=group_sizes.size,
-        )
-
-        value_n_grad_ref_func = value_and_grad(self._ref_sum_grouped_dense, (0, 1, 2))
-        value_n_grad_prim_func = value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2))
-
-        ref_out_sum, (ref_dgrad, ref_wgrad, ref_dbias) = value_n_grad_ref_func(
-            x,
-            kernel,
-            bias,
-            group_sizes,
-            contracting_dims,
-        )
-        prim_out_sum, (prim_dgrad, prim_wgrad, prim_dbias) = value_n_grad_prim_func(
-            x, kernel, bias, group_sizes, contracting_dims, quantizer_set=quantizer_set
-        )
-
-        assert_allclose(prim_out_sum, ref_out_sum, dtype=fwd_dtype)
-        assert_allclose(prim_dgrad, ref_dgrad, dtype=bwd_dtype)
-        assert_allclose(prim_wgrad, ref_wgrad, dtype=bwd_dtype)
-        assert_allclose(prim_dbias, ref_dbias, dtype=dtype)
+        if fwd_dtype == jnp.float8_e5m2:
+            pytest.skip("We never use E5M2 for fwd_dtype in training")
+
+        # Question: should we use different quantizers for different groups?
+        ref_quantizer_set_list = []
+        quantizer_set_list = []
+        for _ in range(group_size):
+            ref_quantizer_set = QuantizerFactory.create_set(
+                scaling_mode=scaling_mode, fwd_dtype=fwd_dtype, bwd_dtype=bwd_dtype, is_2x2x=True
+            )
+            ref_quantizer_set_list.append(ref_quantizer_set)
+            quantizer_set = QuantizerFactory.create_set(
+                scaling_mode=scaling_mode, fwd_dtype=fwd_dtype, bwd_dtype=bwd_dtype, is_2x2x=True
+            )
+            quantizer_set_list.append(quantizer_set)
+
+        out_dtype = jnp.bfloat16
+        x_list, kernel_list, contracting_dims_list = self._generate_grouped_gemm_input(
+            out_dtype, shape_list, layout_list
+        )
+        bias_list = []
+        key = jax.random.PRNGKey(1)
+        for shape in shape_list:
+            n = shape[1]
+            bias = jax.random.uniform(key, n, dtype=out_dtype)
+            bias_list.append(bias)
+
+        def ref_func(x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list):
+            out_list = []
+            for i in range(len(x_list)):
+                out_list.append(
+                    dense(
+                        x_list[i],
+                        kernel_list[i],
+                        bias_list[i],
+                        contracting_dims=contracting_dims_list[i],
+                        quantizer_set=quantizer_set_list[i],
+                    )
+                )
+            # Note: we use jnp.sum instead of jnp.mean to make the gradient larger
+            # and prevent them from being clamp to zero
+            out_sum_list = [jnp.sum(out) for out in out_list]
+            return jnp.sum(jnp.asarray(out_sum_list))
+
+        def primitive_func(
+            x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list
+        ):
+            out_list = grouped_dense(
+                x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list
+            )
+            out_sum_list = [jnp.sum(out) for out in out_list]
+            return jnp.sum(jnp.asarray(out_sum_list))
+
+        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2))
+        value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2))
+
+        ref_out_mean, (ref_dgrad_list, ref_wgrad_list, ref_dbias_list) = value_n_grad_ref_func(
+            x_list, kernel_list, bias_list, contracting_dims_list, ref_quantizer_set_list
+        )
+        primitive_out_mean, (primitive_dgrad_list, primitive_wgrad_list, primitive_dbias_list) = (
+            value_n_grad_primitive_func(
+                x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list
+            )
+        )
+
+        allclose_dtype = jnp.float8_e4m3fn
+        if fwd_dtype == jnp.float8_e5m2 or bwd_dtype == jnp.float8_e5m2:
+            allclose_dtype = jnp.float8_e5m2
+        assert_allclose(primitive_out_mean, ref_out_mean, dtype=allclose_dtype)
+        for i in range(group_size):
+            assert_allclose(primitive_dgrad_list[i], ref_dgrad_list[i], dtype=allclose_dtype)
+            assert_allclose(primitive_wgrad_list[i], ref_wgrad_list[i], dtype=allclose_dtype)
+            assert_allclose(primitive_dbias_list[i], ref_dbias_list[i], dtype=allclose_dtype)
+"""
@@ -525,7 +525,6 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
   const auto B_alignment = _getAlignment(reinterpret_cast<uintptr_t>(param.B));
   const auto C_alignment = _getAlignment(reinterpret_cast<uintptr_t>(C));
   const auto D_alignment = _getAlignment(reinterpret_cast<uintptr_t>(D));
-  const auto workspace_alignment = _getAlignment(reinterpret_cast<uintptr_t>(workspace));
   NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
       preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES, &A_alignment, sizeof(A_alignment)));
   NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
@@ -534,8 +533,6 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
       preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES, &C_alignment, sizeof(C_alignment)));
   NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
       preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES, &D_alignment, sizeof(D_alignment)));
-  NVTE_CHECK(workspace_alignment % 256 == 0,
-             "cuBLAS workspace pointer must be aligned to 256 bytes, got ", workspace_alignment);
   const auto status =
       cublasLtMatmulAlgoGetHeuristic(handle, operationDesc, Adesc, Bdesc, Cdesc, Ddesc, preference,
...
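The deleted check above rejects a cuBLAS workspace pointer whose alignment is below 256 bytes. `_getAlignment` is a C++ helper in that file; a rough Python model of the largest-power-of-two-divisor computation it performs (our sketch under that assumption, not the actual implementation):

def get_alignment(address, max_alignment=256):
    # Largest power-of-two divisor of the address, capped at max_alignment
    alignment = 1
    while alignment < max_alignment and address % (alignment * 2) == 0:
        alignment *= 2
    return alignment

assert get_alignment(0x1000) == 256  # 4 KiB-aligned pointer satisfies cuBLAS
assert get_alignment(0x1008) == 8    # such a workspace would fail a % 256 check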
@@ -6,28 +6,22 @@
 from typing import Tuple, Sequence, Union, Dict
 from functools import partial, reduce
 import operator
-import math

 import jax
 import jax.numpy as jnp

 from transformer_engine_jax import get_device_compute_capability

 from .base import BasePrimitive, register_primitive
-from .quantization import grouped_quantize
 from ..quantize import (
     ScaledTensor,
-    GroupedScaledTensor1x,
     ScalingMode,
     Quantizer,
-    GroupedQuantizer,
     QuantizeConfig,
-    QuantizerSet,
-    QuantizeLayout,
     noop_quantizer_set,
 )

-__all__ = ["gemm", "grouped_gemm", "is_gemm_with_all_layouts_supported"]
+__all__ = ["gemm"]


 num_cublas_streams = 4
@@ -40,11 +34,6 @@ def get_cublas_workspace_size_bytes() -> None:
     return 4_194_304


-def is_gemm_with_all_layouts_supported() -> False:
-    """Return True if using blackwell, False otherwise."""
-    return get_device_compute_capability(0) >= 100
-
-
 class GroupedGemmPrimitive(BasePrimitive):
     """
     Primitive for grouped GEMM
@@ -52,139 +41,73 @@ class GroupedGemmPrimitive(BasePrimitive):

     name = "te_grouped_gemm_ffi"
     multiple_results = True
-    impl_static_args = (7, 8, 9, 10, 11, 12, 13, 14, 15)
+    impl_static_args = ()
     inner_primitive = None
     outer_primitive = None

     @staticmethod
-    def abstract(
-        lhs_data_aval,
-        lhs_scale_inv_aval,
-        rhs_data_aval,
-        rhs_scale_inv_aval,
-        bias_aval,
-        group_sizes_aval,
-        group_offset_aval,
-        *,
-        M,
-        N,
-        K,
-        lhs_is_trans,
-        rhs_is_trans,
-        scaling_mode,
-        out_dtype,
-        has_bias,
-        is_grouped_dense_wgrad,
-    ):
+    def abstract(*args, num_gemms, scaling_mode, out_dtype, has_bias):
         """
-        Grouped GEMM operation.
-
         Args:
-            lhs_data: Left-hand side input matrix data, 1D flattened array
-            lhs_scale_inv: Left-hand side input scale_inv matrix, 1D flattened array
-            rhs_data: Right-hand side input matrix data, 1D flattened array
-            rhs_scale_inv: Right-hand side input scale_inv matrix, 1D flattened array
-            bias: Bias matrix of shape (G, N)
-            group_sizes: 1D array containing the sizes of each group
-            group_offset: 1D array containing offsets for each group (not yet implemented)
-            M: Number of rows in the output matrix
-            N: Number of columns in the output matrix
-            K: Number of columns in the left-hand side matrix
-            lhs_is_trans: Boolean indicating if the left-hand side matrix is transposed
-            rhs_is_trans: Boolean indicating if the right-hand side matrix is transposed
-            scaling_mode: Scaling mode for the GEMM operations
-            out_dtype: Data type of the output tensors
-            has_bias: Boolean indicating if bias tensors are provided
-            is_grouped_dense_wgrad: Boolean indicating if this is a grouped dense wgrad operation
-                where both lhs and rhs are 2D matrices and output is (G, M, N)
+            *args: Size num_gemms * 4 or num_gemms * 5 depending on has_bias:
+                args[0 : num_gemms] are the lhs tensors,
+                args[num_gemms : 2*num_gemms] are the rhs tensors,
+                args[2*num_gemms : 3*num_gemms] are the lhs scale_inv tensors,
+                args[3*num_gemms : 4*num_gemms] are the rhs scale_inv tensors,
+                args[4*num_gemms : 5*num_gemms] are the bias tensors if has_bias is True.
+            num_gemms: Number of GEMM operations to perform.
+            scaling_mode: Scaling mode for the GEMM operations.
+            out_dtype: Data type of the output tensors.
+            has_bias: Boolean indicating if bias tensors are provided.

         Returns:
-            A jnp.ndarray containing the result of the grouped GEMM operation
+            A tuple of ShapedArray objects of size num_gemms+1:
+                ret[0 : num_gemms]: GEMM output tensors,
+                ret[num_gemms]: workspace tensor.
         """
-        del lhs_data_aval, rhs_data_aval, bias_aval, group_offset_aval
-        del K, lhs_is_trans, rhs_is_trans, scaling_mode, has_bias
-        # TODO(Phuong): move some shape checks from Cpp to here
+        del scaling_mode
+        expected_num_args = 5 * num_gemms if has_bias else 4 * num_gemms
+        assert (
+            len(args) == expected_num_args
+        ), f"Expected {expected_num_args} input arguments, but got {len(args)}"
+        A_list = args[0:num_gemms]
+        B_list = args[num_gemms : 2 * num_gemms]
+        # A and B have shapes [1, m, k] and [1, n, k]
+        out_list_aval = tuple(
+            jax.core.ShapedArray((A.shape[1], B.shape[1]), dtype=out_dtype)
+            for A, B in zip(A_list, B_list)
+        )
         workspace_size = get_cublas_workspace_size_bytes() * num_cublas_streams
-        workspace_size += lhs_scale_inv_aval.size + rhs_scale_inv_aval.size
         workspace_aval = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8)
-        out_shape = (M, N)
-        if is_grouped_dense_wgrad:
-            out_shape = (group_sizes_aval.size, M, N)
-        out_aval = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype)
-        return (out_aval, workspace_aval)
+        return (*out_list_aval, workspace_aval)

     @staticmethod
     def outer_abstract(*args, **kwargs):
         (out_aval, _) = GroupedGemmPrimitive.abstract(*args, **kwargs)
-        return (out_aval,)
+        return out_aval

     @staticmethod
-    def lowering(
-        ctx,
-        *args,
-        M,
-        N,
-        K,
-        lhs_is_trans,
-        rhs_is_trans,
-        scaling_mode,
-        out_dtype,
-        has_bias,
-        is_grouped_dense_wgrad,
-    ):
+    def lowering(ctx, *args, num_gemms, scaling_mode, out_dtype, has_bias):
         del out_dtype
         return jax.ffi.ffi_lowering(GroupedGemmPrimitive.name)(
             ctx,
             *args,
-            M=M,
-            N=N,
-            K=K,
-            lhs_is_trans=lhs_is_trans,
-            rhs_is_trans=rhs_is_trans,
-            scaling_mode=scaling_mode.value,
+            num_gemms=num_gemms,
+            scaling_mode=int(scaling_mode),
             has_bias=has_bias,
-            is_grouped_dense_wgrad=is_grouped_dense_wgrad,
         )

     @staticmethod
-    def impl(
-        lhs_data,
-        lhs_scale_inv,
-        rhs_data,
-        rhs_scale_inv,
-        bias,
-        group_sizes,
-        group_offset,
-        M,
-        N,
-        K,
-        lhs_is_trans,
-        rhs_is_trans,
-        scaling_mode,
-        out_dtype,
-        has_bias,
-        is_grouped_dense_wgrad,
-    ):
+    def impl(*args, num_gemms, scaling_mode, out_dtype, has_bias):
         assert GroupedGemmPrimitive.inner_primitive is not None
-        (out, _) = GroupedGemmPrimitive.inner_primitive.bind(
-            lhs_data,
-            lhs_scale_inv,
-            rhs_data,
-            rhs_scale_inv,
-            bias,
-            group_sizes,
-            group_offset,
-            M=M,
-            N=N,
-            K=K,
-            lhs_is_trans=lhs_is_trans,
-            rhs_is_trans=rhs_is_trans,
-            scaling_mode=scaling_mode,
+        out = GroupedGemmPrimitive.inner_primitive.bind(
+            *args,
+            num_gemms=num_gemms,
+            scaling_mode=scaling_mode.value,
             out_dtype=out_dtype,
             has_bias=has_bias,
-            is_grouped_dense_wgrad=is_grouped_dense_wgrad,
         )
-        return (out,)
+        return out[:-1]  # out is [out_list, wkspace], only return out_list


 register_primitive(GroupedGemmPrimitive)
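The restored primitive replaces the fixed seven-buffer signature with one flat `*args` list laid out in num_gemms-sized blocks, as the abstract() docstring above describes. A small sketch of how a caller would pack and slice that layout (plain Python; the helper names are ours, not part of the library):

def pack_grouped_gemm_args(lhs_list, rhs_list, lhs_sinv_list, rhs_sinv_list, bias_list=None):
    # Flat layout expected by the primitive: num_gemms lhs tensors, then rhs,
    # then lhs scale_invs, then rhs scale_invs, then (optionally) biases.
    args = [*lhs_list, *rhs_list, *lhs_sinv_list, *rhs_sinv_list]
    if bias_list is not None:
        args += list(bias_list)
    return tuple(args)

def unpack_lhs_rhs(args, num_gemms):
    # Mirrors the slicing done in GroupedGemmPrimitive.abstract()
    return args[0:num_gemms], args[num_gemms : 2 * num_gemms]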
@@ -362,7 +285,7 @@ def gemm(
     lhs: Union[jnp.ndarray, ScaledTensor],
     rhs: Union[jnp.ndarray, ScaledTensor],
     contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)),
-    quantizer_set: QuantizerSet = noop_quantizer_set,
+    quantizer_set: Dict["str", Quantizer] = noop_quantizer_set,
 ) -> jnp.ndarray:
     """General matrix multiplication with optional quantization.
@@ -387,190 +310,130 @@ def gemm(
     return _jax_gemm(lhs, rhs, contracting_dims, quantizer_set)


-def grouped_gemm(
-    lhs: Union[jnp.ndarray, GroupedScaledTensor1x],
-    rhs: Union[jnp.ndarray, GroupedScaledTensor1x],
-    group_sizes: jnp.ndarray,
-    contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (2,)),
-    bias: jnp.ndarray = None,
-    precision: jax.lax.Precision = jax.lax.Precision.DEFAULT,
-    preferred_element_type: jnp.dtype = None,
-    group_offset: jnp.array = None,
-    quantizer_set: QuantizerSet = noop_quantizer_set,
-) -> jnp.ndarray:
-    """
-    Grouped GEMM operation.
-
-    Args:
-        lhs: Left-hand side input matrix, can be a jnp.ndarray or GroupedScaledTensor1x
-        rhs: Right-hand side input matrix, can be a jnp.ndarray or GroupedScaledTensor1x
-        group_sizes: 1D array containing the sizes of each group
-        contracting_dims: Tuple of two sequences representing the contracting dimensions
-        bias: Bias tensor of shape (G, N)
-        precision: JAX precision for the GEMM operation
-        preferred_element_type: Preferred data type for the output tensor
-        group_offset: 1D array containing offsets for each group (not yet implemented)
-        quantizer_set: Set of quantizers for FP8 quantization of the input and output
-
-    Returns:
-        A jnp.ndarray containing the result of the grouped GEMM operation
-
-    Note:
-        Tested shapes:
-        lhs: [M, K] or [K, N]
-        rhs: [G, N, K] or [G, K, N] or [G * K, N] or [N, G * K]
-    """
-    # TODO(Phuong): implement the group_offset
-    group_offset = group_offset or jnp.zeros((1,), jnp.int32)
-
-    # TODO(Phuong): implement the precision
-    del precision
-
-    if isinstance(lhs, jnp.ndarray):
-        assert isinstance(rhs, jnp.ndarray)
-        out_dtype = lhs.dtype
-        lhs_shape = lhs.shape
-        rhs_shape = rhs.shape
-        lhs_data = lhs
-        rhs_data = rhs
-        lhs_scale_inv = rhs_scale_inv = jnp.empty((0,), jnp.float32)
-        scaling_mode = ScalingMode.NO_SCALING
-    elif isinstance(lhs, GroupedScaledTensor1x):
-        assert isinstance(rhs, GroupedScaledTensor1x)
-        out_dtype = lhs.dq_dtype
-        lhs_shape = lhs.original_shape
-        rhs_shape = rhs.original_shape
-        lhs_data = lhs.data
-        rhs_data = rhs.data
-        lhs_scale_inv = lhs.scale_inv
-        rhs_scale_inv = rhs.scale_inv
-        assert lhs.scaling_mode == rhs.scaling_mode
-        scaling_mode = lhs.scaling_mode
-    else:
-        raise TypeError("Unsupported lhs type object!")
-
-    out_dtype = preferred_element_type or out_dtype
-
-    lhs_contract_dim, rhs_contract_dim = contracting_dims
-
-    lhs_is_trans = lhs_contract_dim[-1] != len(lhs_shape) - 1
-    lhs_flatten_axis = len(lhs_contract_dim) * (1 if lhs_is_trans else -1)
-
-    # rhs_shape [G, K, N]
-    rhs_is_trans = rhs_contract_dim[0] != 1
-    rhs_flatten_axis = -len(rhs_contract_dim) if rhs_is_trans else 1 + len(rhs_contract_dim)
-
-    is_grouped_dense_wgrad = False
-    if len(rhs_shape) == 2:
-        rhs_is_trans = rhs_contract_dim[0] != 0
-        is_grouped_dense_wgrad = True
-
-    # TODO(Hua): thses are for fp16 dense wgrad, any better way to handle this?
-    if (
-        is_grouped_dense_wgrad
-        and not isinstance(lhs, ScaledTensor)
-        and not isinstance(rhs, ScaledTensor)
-    ):
-        lhs_is_trans = True
-        rhs_is_trans = False
-        lhs_flatten_axis = 1
-        rhs_flatten_axis = 1
-
-    if (
-        not isinstance(lhs, ScaledTensor)
-        and not isinstance(rhs, ScaledTensor)
-        and quantizer_set != noop_quantizer_set
-    ):
-        assert isinstance(quantizer_set.x, GroupedQuantizer)
-        assert type(quantizer_set.x) is type(quantizer_set.kernel)
-        scaling_mode = quantizer_set.x.scaling_mode
-        if (
-            # TODO(Phuong): we force Blackwell to also use NT layout for now, need to fix later
-            # scaling_mode.is_tensor_scaling()
-            # and is_gemm_with_all_layouts_supported()
-            scaling_mode.is_1d_block_scaling()
-        ):
-            lhs_is_rowwise = rhs_is_rowwise = True
-        else:
-            lhs_is_rowwise = not lhs_is_trans
-            rhs_is_rowwise = lhs_is_trans
-        quantizer_set.x.q_layout = (
-            QuantizeLayout.ROWWISE if lhs_is_rowwise else QuantizeLayout.COLWISE
-        )
-        quantizer_set.kernel.q_layout = (
-            QuantizeLayout.ROWWISE if rhs_is_rowwise else QuantizeLayout.COLWISE
-        )
-        lhs_q = grouped_quantize(lhs, quantizer_set.x, group_sizes, lhs_flatten_axis)
-        rhs_q = grouped_quantize(
-            rhs, quantizer_set.kernel, group_sizes=None, flatten_axis=rhs_flatten_axis
-        )
-        lhs_data = lhs_q.data
-        rhs_data = rhs_q.data
-        lhs_scale_inv = lhs_q.scale_inv
-        rhs_scale_inv = rhs_q.scale_inv
-
-    assert not (
-        lhs_data.dtype == jnp.float8_e5m2 and rhs_data.dtype == jnp.float8_e5m2
-    ), "FP8 GEMM does not support E5M2 * E5M2"
-
-    # Only support FP8 GEMM with NT layout on Hopper and other earlier GPUs
-    # thus additional transpose is required
-    # TODO(Phuong): we force Blackwell to also use NT layout for now, need to fix later
-    if scaling_mode.is_tensor_scaling():  # and not is_gemm_with_all_layouts_supported():
-        lhs_is_trans = False
-        rhs_is_trans = True
-        if isinstance(lhs, ScaledTensor) and isinstance(rhs, ScaledTensor):
-            lhs_layout_is_T = lhs.data_layout == "T"
-            rhs_layout_is_T = rhs.data_layout == "T"
-        else:
-            lhs_layout_is_T = lhs_q.data_layout == "T"
-            rhs_layout_is_T = rhs_q.data_layout == "T"
-        lhs_ndim = len(lhs_shape)
-        rhs_ndim = len(rhs_shape)
-        if lhs_layout_is_T:
-            lhs_contract_dim = tuple((lhs_ndim - 1 - i) % lhs_ndim for i in lhs_contract_dim)
-        if rhs_layout_is_T:
-            rhs_contract_dim = tuple((rhs_ndim - 1 - i) % rhs_ndim for i in rhs_contract_dim)
-        lhs_data = _shape_normalization(lhs_data, (lhs_contract_dim, ()), not lhs_layout_is_T)
-        rhs_data = _shape_normalization(rhs_data, (rhs_contract_dim, ()), rhs_layout_is_T)
-
-    # Calling GroupedGEMM Custom Call
-    K_lhs = math.prod(lhs_shape[i] for i in lhs_contract_dim)
-    K_rhs = math.prod(rhs_shape[i] for i in rhs_contract_dim)
-    assert K_lhs == K_rhs
-    M = math.prod(_calculate_remaining_shape(lhs_shape, lhs_contract_dim))
-    N = math.prod(_calculate_remaining_shape(rhs_shape, rhs_contract_dim)[1:])  # Exclude G
-
-    if is_grouped_dense_wgrad:
-        N = math.prod(_calculate_remaining_shape(rhs_shape, rhs_contract_dim))
-    else:
-        assert group_sizes.size == rhs_shape[0]
-
-    assert group_offset.size == 1
-
-    has_bias = bias is not None
-    assert not has_bias or bias.shape == (group_sizes.size, N)
-    bias = jnp.empty((), jnp.float32) if bias is None else bias
-
-    # TODO(Phuong): support MXFP8_1D_SCALING
-    assert scaling_mode != ScalingMode.MXFP8_1D_SCALING, "MXFP8_1D_SCALING is not yet supported"
-
-    (out,) = GroupedGemmPrimitive.outer_primitive.bind(
-        lhs_data,
-        lhs_scale_inv,
-        rhs_data,
-        rhs_scale_inv,
-        bias,
-        group_sizes,
-        group_offset,
-        M=M,
-        N=N,
-        K=K_lhs,
-        lhs_is_trans=lhs_is_trans,
-        rhs_is_trans=rhs_is_trans,
-        scaling_mode=scaling_mode.value,
-        out_dtype=out_dtype,
-        has_bias=has_bias,
-        is_grouped_dense_wgrad=is_grouped_dense_wgrad,
-    )
-    return out
+"""
+def swizzled_scale(scales):
+    # Swizzle the scale tensor for FP8 GEMM
+    assert scales.ndim == 2
+    rows, cols = scales.shape
+    scales = scales.reshape(rows // 128, 4, 32, cols // 4, 4)
+    scales = jnp.transpose(scales, (0, 3, 2, 1, 4))
+    scales = scales.reshape(rows, cols)
+    return scales
+
+
+def grouped_gemm(
+    lhs_list: List[Union[jnp.ndarray, ScaledTensor]],
+    rhs_list: List[Union[jnp.ndarray, ScaledTensor]],
+    contracting_dims_list: List[Tuple[Sequence[int], Sequence[int]]],
+    bias_list: List[jnp.ndarray] = None,
+) -> List[jnp.ndarray]:
+    # Grouped GEMM for multiple pairs of tensors.
+    assert (
+        len(lhs_list) == len(rhs_list) == len(contracting_dims_list)
+    ), "lhs_list, rhs_list, contracting_dims_list must have the same length"
+
+    num_gemms = len(lhs_list)
+    lhs_list_ = []
+    rhs_list_ = []
+    lhs_sinv_list_ = []
+    rhs_sinv_list_ = []
+    bias_list_ = []
+
+    for i in range(num_gemms):
+        lhs = lhs_list[i]
+        rhs = rhs_list[i]
+        contracting_dims = contracting_dims_list[i]
+        dim_nums = (contracting_dims, ((), ()))
+        if isinstance(lhs, ScaledTensor) and isinstance(rhs, ScaledTensor):
+            scaling_mode = lhs.scaling_mode
+            lhs_shape = lhs.data.shape
+            rhs_shape = rhs.data.shape
+            out_dtype = lhs.dq_dtype
+            # For ScaledTensors and DELAYED_TENSOR_SCALING, need to handle internal data_layout
+            if lhs.scaling_mode.is_tensor_scaling():
+                assert not (
+                    lhs.data.dtype == jnp.float8_e5m2 and rhs.data.dtype == jnp.float8_e5m2
+                ), "FP8 GEMM does not support E5M2 * E5M2"
+                ((lhs_contract_dim,), (rhs_contract_dim,)) = contracting_dims
+                if lhs.data_layout == "T":
+                    lhs_contract_dim = (lhs_contract_dim - 1) % lhs.data.ndim
+                if rhs.data_layout == "T":
+                    rhs_contract_dim = (rhs_contract_dim - 1) % rhs.data.ndim
+                dim_nums = ((lhs_contract_dim,), (rhs_contract_dim,)), ((), ())
+        else:
+            # For jnp.ndarray, only consider contracting_dims, data_layout is always NN
+            scaling_mode = ScalingMode.NO_SCALING
+            lhs_shape = lhs.shape
+            rhs_shape = rhs.shape
+            out_dtype = lhs.dtype
+
+        (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums
+        lhs_dn = (lhs_contract, lhs_batch)
+        rhs_dn = (rhs_contract, rhs_batch)
+
+        lhs_remain_shape = _calculate_remaining_shape(lhs_shape, lhs_contract)
+        rhs_remain_shape = _calculate_remaining_shape(rhs_shape, rhs_contract)
+
+        # Note: do not squeeze() for {lhs, rhs}_3d, it will trigger a D2D memcpy
+        if scaling_mode == ScalingMode.NO_SCALING:
+            lhs_3d = _shape_normalization(lhs, lhs_dn)
+            rhs_3d = _shape_normalization(rhs, rhs_dn)
+        elif scaling_mode.is_tensor_scaling():
+            lhs_3d = _shape_normalization(lhs.data, lhs_dn, lhs.data_layout == "N")
+            rhs_3d = _shape_normalization(rhs.data, rhs_dn, rhs.data_layout == "T")
+        elif scaling_mode == ScalingMode.MXFP8_1D_SCALING:
+            lhs_3d = _shape_normalization(lhs.data, lhs_dn)
+            rhs_3d = _shape_normalization(rhs.data, rhs_dn)
+            lhs_scale_inv = _shape_normalization(lhs.scale_inv, lhs_dn)
+            rhs_scale_inv = _shape_normalization(rhs.scale_inv, rhs_dn)
+            # swizzled_scale requires a matrix
+            lhs_scale_inv = swizzled_scale(lhs_scale_inv.squeeze())
+            rhs_scale_inv = swizzled_scale(rhs_scale_inv.squeeze())
+        else:
+            raise NotImplementedError("Unsupported ScalingMode: {scaling_mode}")
+
+        # Note: already_transposed doesn't matter for the output shape
+        # x.shape = [B, D1, D2]
+        # contracting_dims = (2, )    --> output.shape = [1, B * D1, D2]
+        # contracting_dims = (0, 1, ) --> output.shape = [1, D2, B * D1]
+        # x.shape = [D1, D2]
+        # contracting_dims = (1, )    --> output.shape = [1, D1, D2]
+        # contracting_dims = (0, )    --> output.shape = [1, D2, D1]
+        bm = lhs_remain_shape[0]
+        bn = rhs_remain_shape[0]
+        kl = lhs_3d.shape[-1]
+        kr = rhs_3d.shape[-1]
+        assert kl == kr, f"After shape normalization, contracting dim size mismatch: {kl} != {kr}"
+        if (bm % 16 != 0) or (bn % 16 != 0) or (kl % 16 != 0):
+            print("grouped_gemm input pair {i} has invalid problem shape for lowering: ")
+            print(f"m = {bm}, n = {bn}, k = {kl}; ")
+            print("cuBLAS requires the problem shapes being multiples of 16")
+            assert (bm % 16 == 0) and (bn % 16 == 0) and (kl % 16 == 0)
+
+        lhs_list_.append(lhs_3d)
+        rhs_list_.append(rhs_3d)
+        if scaling_mode == ScalingMode.NO_SCALING:
+            lhs_sinv_list_.append(jnp.ones(1, dtype=jnp.float32))
+            rhs_sinv_list_.append(jnp.ones(1, dtype=jnp.float32))
+        if scaling_mode.is_tensor_scaling():
+            lhs_sinv_list_.append(lhs.scale_inv)
+            rhs_sinv_list_.append(rhs.scale_inv)
+        if scaling_mode == ScalingMode.MXFP8_1D_SCALING:
+            lhs_sinv_list_.append(lhs_scale_inv)
+            rhs_sinv_list_.append(rhs_scale_inv)
+        if bias_list is not None:
+            bias_list_.append(bias_list[i])
+
+    out_list = GroupedGemmPrimitive.outer_primitive.bind(
+        *lhs_list_,
+        *rhs_list_,
+        *lhs_sinv_list_,
+        *rhs_sinv_list_,
+        *bias_list_,
+        num_gemms=num_gemms,
+        scaling_mode=scaling_mode,
+        out_dtype=out_dtype,
+        has_bias=1 if bias_list is not None else 0,
+    )
+
+    return out_list
+"""
@@ -47,7 +47,7 @@ else:
     from jax.extend import ffi  # pylint: disable=ungrouped-imports


-__all__ = ["quantize", "quantize_dbias", "grouped_quantize", "grouped_dbias"]
+__all__ = ["quantize", "quantize_dbias", "grouped_quantize"]


 class BaseDBiasQuantizePrimitive(BasePrimitive):
@@ -1032,24 +1032,3 @@ def grouped_quantize(
         group_axis=group_axis,
     )
     return out
-
-
-def grouped_dbias(grad: jnp.ndarray, group_sizes: jnp.ndarray) -> jnp.ndarray:
-    """
-    Compute the grouped bias gradient.
-
-    Args:
-        grad: jnp.ndarray of shape (M, N)
-        group_sizes: jnp.ndarray of shape (num_groups,), sum(group_sizes) == M
-
-    Returns:
-        dbias: jnp.ndarray of shape (num_groups, N)
-    """
-    assert grad.ndim == 2, "Input grad must be a 2D tensor."
-    assert group_sizes.ndim == 1, "group_sizes must be a 1D tensor."
-    segment_ids = jnp.repeat(jnp.arange(group_sizes.shape[0]), group_sizes)
-    grad_fp32 = grad.astype(jnp.float32)
-    dbias_fp32 = jax.ops.segment_sum(grad_fp32, segment_ids, num_segments=group_sizes.shape[0])
-    dbias = dbias_fp32.astype(grad.dtype)
-    return dbias
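The removed `grouped_dbias` reduces each group's rows of `grad` with a segment sum. A standalone sketch of the same reduction with concrete sizes (no TE dependencies; `total_repeat_length` added so the repeat has a static shape):

import jax
import jax.numpy as jnp

grad = jnp.ones((6, 3), dtype=jnp.bfloat16)  # M=6 rows, N=3 features
group_sizes = jnp.array([1, 2, 3])           # sums to M
segment_ids = jnp.repeat(jnp.arange(3), group_sizes, total_repeat_length=6)
# Accumulate in fp32 and cast back, as the removed helper did
dbias = jax.ops.segment_sum(grad.astype(jnp.float32), segment_ids, num_segments=3)
dbias = dbias.astype(grad.dtype)             # shape (num_groups, N) == (3, 3)
assert dbias.shape == (3, 3)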
@@ -13,127 +13,43 @@
 #include "transformer_engine/multi_stream.h"
 #include "xla/ffi/api/c_api.h"

-#define MXFP8_BLOCK_SIZE 32
-
 namespace transformer_engine {
 namespace jax {

-Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv,
-                          Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias,
-                          Buffer_Type group_sizes, Buffer_Type group_offset, Result_Type output,
-                          Result_Type workspace, size_t m, size_t n, size_t k, bool lhs_is_trans,
-                          bool rhs_is_trans, JAXX_Scaling_Mode scaling_mode, bool has_bias,
-                          bool is_grouped_dense_wgrad) {
+Error_Type GroupedGemmFFI(cudaStream_t stream, Variadic_Buffer_Type input_list,
+                          Variadic_Result_Type output_list, int64_t num_gemms,
+                          JAXX_Scaling_Mode scaling_mode, int64_t has_bias) {
   // Notes on matrix layouts and transpose:
   // Jax uses row-major data_layout, on entering this function, each input matrix pair:
-  //     A: row-major [m, k] for N - [k, m] for T
-  //     B: row-major [k, n] for N - [n, k] for T
+  //     A: row-major with size [m, k],
+  //     B: row-major with size [n, k], needs transpose,
   // on exiting this function, JAX expect:
   //     C: row-major with size [m, n].
   // cuBLAS uses column-major data_layout, in this view, each input matrix pair:
-  //     A: column-major with size [k, m] for T - [m, k] for N
-  //     B: column-major with size [n, k] for T - [k, n] for N
-  //
+  //     A: column-major with size [k, m], needs transpose,
+  //     B: column-major with size [k, n].
   // If we call cuBLAS GEMM for A * B, the output will be:
   //     C: column-major with size [m, n] --> row-major with size [n, m].
   // To make the output compatible with JAX, we need to swap A and B in cuBLAS GEMM call.

-  int num_streams = nvte_get_num_compute_streams();
-
-  // Inputs
-  auto lhs_ptr = reinterpret_cast<uint8_t *>(lhs_data.untyped_data());
-  auto rhs_ptr = reinterpret_cast<uint8_t *>(rhs_data.untyped_data());
-  auto lhs_sinv_ptr = reinterpret_cast<uint8_t *>(lhs_sinv.untyped_data());
-  auto rhs_sinv_ptr = reinterpret_cast<uint8_t *>(rhs_sinv.untyped_data());
-  auto lhs_dtype = convert_ffi_datatype_to_te_dtype(lhs_data.element_type());
-  auto rhs_dtype = convert_ffi_datatype_to_te_dtype(rhs_data.element_type());
-  auto lhs_sinv_dtype = convert_ffi_datatype_to_te_dtype(lhs_sinv.element_type());
-  auto rhs_sinv_dtype = convert_ffi_datatype_to_te_dtype(rhs_sinv.element_type());
-  auto bias_ptr = has_bias ? reinterpret_cast<uint8_t *>(bias.untyped_data()) : nullptr;
-  auto bias_dtype = convert_ffi_datatype_to_te_dtype(bias.element_type());
-  NVTE_CHECK(group_sizes.dimensions().size() == 1);
-  size_t num_gemms = group_sizes.dimensions()[0];
-
-  // Outputs
-  auto out_ptr = reinterpret_cast<uint8_t *>(output->untyped_data());
-  auto out_dtype = convert_ffi_datatype_to_te_dtype(output->element_type());
-  auto workspace_ptr = reinterpret_cast<uint8_t *>(workspace->untyped_data());
-  auto workspace_total_size = product(workspace->dimensions());
-  auto lhs_sinv_size = product(lhs_sinv.dimensions());
-  auto rhs_sinv_size = product(rhs_sinv.dimensions());
-  auto workspace_size = (workspace_total_size - lhs_sinv_size - rhs_sinv_size) / num_streams;
-  auto swizzled_lhs_sinv_ptr = workspace_ptr + workspace_size * num_streams;
-  auto swizzled_rhs_sinv_ptr = swizzled_lhs_sinv_ptr + lhs_sinv_size;
-
-  size_t lhs_dtype_bytes = te_dtype_bytes(lhs_dtype);
-  size_t rhs_dtype_bytes = te_dtype_bytes(rhs_dtype);
-  size_t lhs_sinv_dtype_bytes = te_dtype_bytes(lhs_sinv_dtype);
-  size_t rhs_sinv_dtype_bytes = te_dtype_bytes(rhs_sinv_dtype);
-  size_t bias_dtype_bytes = te_dtype_bytes(bias_dtype);
-  size_t out_dtype_bytes = te_dtype_bytes(out_dtype);
-  NVTE_CHECK(lhs_dtype_bytes == rhs_dtype_bytes, "sizeof(lhs_dtype) != sizeof(rhs_dtype)");
-  NVTE_CHECK(lhs_sinv_dtype_bytes == rhs_sinv_dtype_bytes,
-             "sizeof(lhs_sinv_dtype) != sizeof(rhs_sinv_dtype)");
-
-  size_t expected_lhs_size = m * k;
-  size_t expected_rhs_size = is_grouped_dense_wgrad ? (k * n) : (num_gemms * k * n);
-  size_t expected_out_size = is_grouped_dense_wgrad ? (num_gemms * m * n) : (m * n);
-  size_t actual_lhs_size = product(lhs_data.dimensions());
-  size_t actual_rhs_size = product(rhs_data.dimensions());
-  size_t actual_out_size = product(output->dimensions());
-  NVTE_CHECK(expected_lhs_size == actual_lhs_size, "Unexpected lhs size! Expect ",
-             expected_lhs_size, ", got ", actual_lhs_size);
-  if (!is_grouped_dense_wgrad) {
-    NVTE_CHECK(expected_rhs_size == actual_rhs_size,
-               "Unexpected rhs size! Expect num_gemms * n * k = ", num_gemms, " * ", n, " * ", k,
-               " = ", expected_rhs_size, ", got ", actual_rhs_size);
-    NVTE_CHECK(expected_out_size == actual_out_size, "Unexpected output size! Expect m * n = ", m,
-               " * ", n, " = ", expected_out_size, ", got ", actual_out_size);
-  } else {
-    NVTE_CHECK(expected_rhs_size == actual_rhs_size, "Unexpected rhs size! Expect k * n = ", k,
-               " * ", n, " = ", expected_rhs_size, ", got ", actual_rhs_size);
-    NVTE_CHECK(expected_out_size == actual_out_size,
-               "Unexpected output size! Expect num_gemms * m * n = ", num_gemms, " * ", m, " * ", n,
-               " = ", expected_out_size, ", got ", actual_out_size);
-  }
-
-  size_t dim_list_bytes = sizeof(int32_t) * num_gemms;
-  std::vector<int32_t> dim_list_host(num_gemms);
-  auto dim_list_ptr = reinterpret_cast<int32_t *>(group_sizes.untyped_data());
-  cudaMemcpyAsync(dim_list_host.data(), dim_list_ptr, dim_list_bytes, cudaMemcpyDeviceToHost,
-                  stream);
-  // Note: This may break cudaGraph.
-  cudaStreamSynchronize(stream);
-  size_t sum_group_sizes = std::accumulate(dim_list_host.begin(), dim_list_host.end(), 0);
-  if (!is_grouped_dense_wgrad) {
-    NVTE_CHECK(m == sum_group_sizes, "Unexpected group_sizes! M = ", m,
-               ", got sum(group_sizes)=", sum_group_sizes);
-  } else {
-    NVTE_CHECK(k == sum_group_sizes, "Unexpected group_sizes! K = ", k,
-               ", got sum(group_sizes)=", sum_group_sizes);
-  }
+  if (num_gemms <= 0) {
+    return ffi_with_cuda_error_check();
+  }
+
+  size_t expected_input_size = has_bias ? 5 * num_gemms : 4 * num_gemms;
+  size_t expected_output_size = num_gemms + 1;
+  size_t actual_input_size = input_list.size();
+  size_t actual_output_size = output_list.size();
+  NVTE_CHECK(actual_input_size == expected_input_size, "Expected %zu input tensors, got %zu",
+             expected_input_size, actual_input_size);
+  NVTE_CHECK(actual_output_size == expected_output_size, "Expected %zu output tensors, got %zu",
+             expected_output_size, actual_output_size);
+
+  bool trans_lhs = true;
+  bool trans_rhs = false;

   auto num_math_sm = cuda::sm_count() - getenv<int>("NVTE_EXT_MARGIN_SM", 0);
   bool grad = false;
   bool accumulate = false;
   bool use_split_accumulator = false;
-  auto bias_shape = std::vector<size_t>{has_bias ? n : 0};
-
-  const int arch = cuda::sm_arch();
-  // It is weird that TE/Common GEMM only use colwise for MXFP8
-  const bool is_fp8_gemm = is_fp8_dtype(lhs_dtype);
-  const bool is_mxfp8_scaling = scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING;
-  const bool rhs_use_colwise = is_mxfp8_scaling && !rhs_is_trans;
-  const bool lhs_use_colwise = is_mxfp8_scaling && lhs_is_trans;
-  if (arch < 100 && is_fp8_gemm) {
-    NVTE_CHECK(!lhs_is_trans && rhs_is_trans,
-               "For SM90 or older archs and FP8 input, only NT (row-major) GEMM is supported, ",
-               "got lhs_is_trans=", lhs_is_trans, ", rhs_is_trans=", rhs_is_trans);
-  }

   // These lists are to keep the TensorWrapper objects alive
   std::vector<TensorWrapper> lhs_wrapper_list;
@@ -151,83 +67,96 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type
   std::vector<NVTETensor> out_list;
   std::vector<NVTETensor> workspace_list;

-  for (size_t i = 0; i < num_gemms; i++) {
-    // Matrix data shapes
-    size_t m_i = dim_list_host[i];
-    auto lhs_shape = std::vector<size_t>{m_i, k};
-    auto rhs_shape = std::vector<size_t>{rhs_is_trans ? n : k, rhs_is_trans ? k : n};
-    auto out_shape = std::vector<size_t>{m_i, n};
-    if (is_grouped_dense_wgrad) {
-      size_t k_i = dim_list_host[i];
-      lhs_shape[0] = lhs_is_trans ? k_i : m;
-      lhs_shape[1] = lhs_is_trans ? m : k_i;
-      rhs_shape[0] = rhs_is_trans ? n : k_i;
-      rhs_shape[1] = rhs_is_trans ? k_i : n;
-      out_shape[0] = m;
-      out_shape[1] = n;
-    }
-
-    // Set matrix data pointers
-    auto lhs_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
-    auto rhs_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
-    auto out_i = TensorWrapper(static_cast<void *>(out_ptr), out_shape, out_dtype);
-    void *lhs_vptr = static_cast<void *>(lhs_ptr);
-    void *rhs_vptr = static_cast<void *>(rhs_ptr);
-    if (rhs_use_colwise)  // MatA to enter cuBLAS
-      rhs_i.set_columnwise_data(rhs_vptr, rhs_dtype, rhs_shape);
-    else
-      rhs_i.set_rowwise_data(rhs_vptr, rhs_dtype, rhs_shape);
-    if (lhs_use_colwise)  // MatB to enter cuBLAS
-      lhs_i.set_columnwise_data(lhs_vptr, lhs_dtype, lhs_shape);
-    else
-      lhs_i.set_rowwise_data(lhs_vptr, lhs_dtype, lhs_shape);
-
-    // Scale_inv shapes
-    auto lhs_sinv_size = std::vector<size_t>{1};
-    auto rhs_sinv_size = std::vector<size_t>{1};
-    if (is_mxfp8_scaling) {
-      NVTE_CHECK(k % MXFP8_BLOCK_SIZE == 0, "MXFP8 K-dim being divisble by %d (got %d)",
-                 MXFP8_BLOCK_SIZE, k);
-      size_t scale_k = k / MXFP8_BLOCK_SIZE;
-      lhs_sinv_size[0] = m_i * scale_k;
-      rhs_sinv_size[0] = n * scale_k;
-      // Need to add swizzle here
-    }
-
-    // Set scale_inv pointers
-    void *rhs_sinv_vptr = static_cast<void *>(rhs_sinv_ptr);
-    void *lhs_sinv_vptr = static_cast<void *>(lhs_sinv_ptr);
-    if (is_fp8_gemm) {
-      if (rhs_use_colwise)  // MatA to enter cuBLAS
-        rhs_i.set_columnwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_size);
-      else
-        rhs_i.set_rowwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_size);
-      if (lhs_use_colwise)  // MatB to enter cuBLAS
-        lhs_i.set_columnwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_size);
-      else
-        lhs_i.set_rowwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_size);
-    } else {
-      NVTE_CHECK(scaling_mode == JAXX_Scaling_Mode::NO_SCALING,
-                 "Unsupported scaling mode: ", static_cast<int>(scaling_mode));
-    }
-
-    auto bias_i = TensorWrapper(bias_ptr, bias_shape, bias_dtype);
-    auto pre_gelu_i = TensorWrapper(nullptr, std::vector<size_t>{0}, out_dtype);
-
-    // Update pointer for the next GEMM pair
-    lhs_ptr += lhs_shape[0] * lhs_shape[1] * lhs_dtype_bytes;
-    rhs_ptr += rhs_shape[0] * rhs_shape[1] * rhs_dtype_bytes;
-    out_ptr += out_shape[0] * out_shape[1] * out_dtype_bytes;
-    if (is_fp8_gemm) {
-      lhs_sinv_ptr += lhs_sinv_size[0] * lhs_sinv_dtype_bytes;
-      rhs_sinv_ptr += rhs_sinv_size[0] * rhs_sinv_dtype_bytes;
-    }
-    if (has_bias) bias_ptr += n * bias_dtype_bytes;
-
-    // Move objects to the lists to keep them alive
-    lhs_wrapper_list.push_back(std::move(lhs_i));
-    rhs_wrapper_list.push_back(std::move(rhs_i));
-    out_wrapper_list.push_back(std::move(out_i));
+  int lhs_list_offset = 0;
+  int rhs_list_offset = num_gemms;
+  int lhs_sinv_list_offset = 2 * num_gemms;
+  int rhs_sinv_list_offset = 3 * num_gemms;
+  int bias_list_offset = 4 * num_gemms;
+  int out_list_offset = 0;
+  for (int i = 0; i < num_gemms; i++) {
+    Buffer_Type lhs_i = input_list.get<Buffer_Type>(lhs_list_offset + i).value();
+    Buffer_Type rhs_i = input_list.get<Buffer_Type>(rhs_list_offset + i).value();
+    Buffer_Type lhs_sinv_i = input_list.get<Buffer_Type>(lhs_sinv_list_offset + i).value();
+    Buffer_Type rhs_sinv_i = input_list.get<Buffer_Type>(rhs_sinv_list_offset + i).value();
+    Result_Type out_i = output_list.get<Buffer_Type>(out_list_offset + i).value();
+
+    DType lhs_dtype = convert_ffi_datatype_to_te_dtype(lhs_i.element_type());
+    DType rhs_dtype = convert_ffi_datatype_to_te_dtype(rhs_i.element_type());
+    DType out_dtype = convert_ffi_datatype_to_te_dtype(out_i->element_type());
+
+    void *lhs_ptr = lhs_i.untyped_data();
+    void *rhs_ptr = rhs_i.untyped_data();
+    void *lhs_sinv_ptr = lhs_sinv_i.untyped_data();
+    void *rhs_sinv_ptr = rhs_sinv_i.untyped_data();
+    void *out_ptr = out_i->untyped_data();
+
+    // Placeholder for bias since it can be empty
+    DType bias_dtype = DType::kFloat32;
+    void *bias_ptr = nullptr;
+
+    auto lhs_shape_ = lhs_i.dimensions();
+    auto rhs_shape_ = rhs_i.dimensions();
+
+    // lhs and rhs has shape [1, m, k] and [1, n, k]
+    size_t m = lhs_shape_[1];
+    size_t n = rhs_shape_[1];
+    size_t k = lhs_shape_[2];
+
+    auto lhs_shape = std::vector<size_t>{m, k};
+    auto rhs_shape = std::vector<size_t>{n, k};
+    auto out_shape = std::vector<size_t>{n, m};
+    auto lhs_sinv_shape = std::vector<size_t>{1, 1};
+    auto rhs_sinv_shape = std::vector<size_t>{1, 1};
+
+    if (scaling_mode == JAXX_Scaling_Mode::NO_SCALING ||
+        scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ||
+        scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING) {
+      float *amax_dptr = nullptr;
+      float *scale_dptr = nullptr;
+      auto lhs_i_ = TensorWrapper(lhs_ptr, lhs_shape, lhs_dtype, amax_dptr, scale_dptr,
+                                  reinterpret_cast<float *>(lhs_sinv_ptr));
+      auto rhs_i_ = TensorWrapper(rhs_ptr, rhs_shape, rhs_dtype, amax_dptr, scale_dptr,
+                                  reinterpret_cast<float *>(rhs_sinv_ptr));
+      lhs_wrapper_list.push_back(std::move(lhs_i_));
+      rhs_wrapper_list.push_back(std::move(rhs_i_));
+    } else if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) {
+      // Note: the scale_inv array should have been swizzled in Python before lowering
+      auto lhs_sinv_shape_ = lhs_sinv_i.dimensions();
+      auto rhs_sinv_shape_ = rhs_sinv_i.dimensions();
+      for (int i = 0; i < 2; i++) {
+        lhs_sinv_shape[i] = lhs_sinv_shape_[i];
+        rhs_sinv_shape[i] = rhs_sinv_shape_[i];
+      }
+      NVTEScalingMode nvte_scaling_mode = get_nvte_scaling_mode(scaling_mode);
+      TensorWrapper lhs_i_(nvte_scaling_mode);
+      TensorWrapper rhs_i_(nvte_scaling_mode);
+      lhs_i_.set_rowwise_data(lhs_ptr, lhs_dtype, lhs_shape);
+      rhs_i_.set_rowwise_data(rhs_ptr, rhs_dtype, rhs_shape);
+      lhs_i_.set_rowwise_scale_inv(lhs_sinv_ptr, DType::kFloat8E8M0, lhs_sinv_shape);
+      rhs_i_.set_rowwise_scale_inv(rhs_sinv_ptr, DType::kFloat8E8M0, rhs_sinv_shape);
+      lhs_wrapper_list.push_back(std::move(lhs_i_));
+      rhs_wrapper_list.push_back(std::move(rhs_i_));
+    } else {
+      NVTE_ERROR("Unsupported scaling mode: ", static_cast<int>(scaling_mode));
+    }
+
+    auto out_i_ = TensorWrapper(out_ptr, out_shape, out_dtype);
+    void *pre_gelu_ptr = nullptr;
+    auto bias_shape = std::vector<size_t>{0};
+    auto pre_gelu_shape = std::vector<size_t>{0};
+    if (has_bias) {
+      auto bias_i_get = input_list.get<Buffer_Type>(bias_list_offset + i);
+      Buffer_Type bias_i = bias_i_get.value();
+      bias_ptr = bias_i.untyped_data();
+      bias_dtype = convert_ffi_datatype_to_te_dtype(bias_i.element_type());
+      bias_shape[0] = n;
+    }
+    auto bias_i = TensorWrapper(bias_ptr, bias_shape, bias_dtype);
+    auto pre_gelu_i = TensorWrapper(pre_gelu_ptr, pre_gelu_shape, out_dtype);
+
+    out_wrapper_list.push_back(std::move(out_i_));
     bias_wrapper_list.push_back(std::move(bias_i));
     pre_gelu_wrapper_list.push_back(std::move(pre_gelu_i));
@@ -238,6 +167,11 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type
     out_list.push_back(out_wrapper_list.back().data());
   }

+  auto workspace_get = output_list.get<Buffer_Type>(num_gemms);
+  Result_Type workspace = workspace_get.value();
+  uint8_t *workspace_ptr = reinterpret_cast<uint8_t *>(workspace->untyped_data());
+  auto num_streams = nvte_get_num_compute_streams();
+  size_t workspace_size = workspace->dimensions()[0] / num_streams;
   auto workspace_shape = std::vector<size_t>{workspace_size};
   for (int i = 0; i < num_streams; i++) {
     auto workspace_i =
@@ -248,7 +182,7 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type
   }

   nvte_multi_stream_cublas_gemm(rhs_list.data(), lhs_list.data(), out_list.data(), bias_list.data(),
-                                pre_gelu_list.data(), num_gemms, rhs_is_trans, lhs_is_trans, grad,
+                                pre_gelu_list.data(), num_gemms, trans_lhs, trans_rhs, grad,
                                 workspace_list.data(), accumulate, use_split_accumulator,
                                 num_math_sm, stream);
@@ -258,23 +192,11 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type
 XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI,
                               FFI::Bind()
                                   .Ctx<FFI_Stream_Type>()  // stream
-                                  .Arg<Buffer_Type>()      // lhs_data
-                                  .Arg<Buffer_Type>()      // lhs_sinv
-                                  .Arg<Buffer_Type>()      // rhs_data
-                                  .Arg<Buffer_Type>()      // rhs_sinv
-                                  .Arg<Buffer_Type>()      // bias
-                                  .Arg<Buffer_Type>()      // group_sizes
-                                  .Arg<Buffer_Type>()      // group_offset
-                                  .Ret<Buffer_Type>()      // output
-                                  .Ret<Buffer_Type>()      // workspace
-                                  .Attr<int64_t>("M")
-                                  .Attr<int64_t>("N")
-                                  .Attr<int64_t>("K")
-                                  .Attr<bool>("lhs_is_trans")
-                                  .Attr<bool>("rhs_is_trans")
+                                  .RemainingArgs()         // input list
+                                  .RemainingRets()         // output list
+                                  .Attr<int64_t>("num_gemms")
                                   .Attr<JAXX_Scaling_Mode>("scaling_mode")
-                                  .Attr<bool>("has_bias")
-                                  .Attr<bool>("is_grouped_dense_wgrad"),
+                                  .Attr<int64_t>("has_bias"),
                               FFI_CudaGraph_Traits);
 }  // namespace jax
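For orientation: both the removed and the restored bindings ultimately perform one GEMM per group over contiguous row blocks. A minimal pure-JAX reference for that computation, assuming a flat (M, K) lhs, a (G, K, N) rhs, and host-side group_sizes summing to M. The function name is illustrative, and it runs eagerly only, since the slice sizes are data-dependent (the dynamic-shape limitation this revert relates to):

import jax.numpy as jnp

def ref_grouped_gemm(lhs, rhs, group_sizes):
    # lhs: (M, K), rows grouped contiguously; rhs: (G, K, N); sum(group_sizes) == M
    outs, start = [], 0
    for g, m_g in enumerate(int(s) for s in group_sizes):
        outs.append(lhs[start:start + m_g] @ rhs[g])  # (m_g, N) for group g
        start += m_g
    return jnp.concatenate(outs, axis=0)  # (M, N)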
@@ -153,28 +153,28 @@ def _dense_bwd_rule(
     # GEMM NT
     # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim
-    g_contracting_dim = tuple(
+    g_constracting_dim = tuple(
         range(grad.ndim - len(kernel_shape) + len(fwd_k_contracting_dims), grad.ndim)
     )
     # k_non_contracting_dims
-    k_contracting_dim = tuple(
+    k_constracting_dim = tuple(
         dim for dim in range(len(kernel_shape)) if dim not in fwd_k_contracting_dims
     )
     dgrad = tex.gemm(
         casted_grad.get_rowwise_tensor(),
         rowwise_casted_kernel,
-        (g_contracting_dim, k_contracting_dim),
+        (g_constracting_dim, k_constracting_dim),
     )
     dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes)

     # GEMM TN
     # x_non_contracting_dims
-    g_contracting_dim = x_contracting_dim = tuple(
+    g_constracting_dim = x_constracting_dim = tuple(
         range(0, len(x_shape) - len(fwd_x_contracting_dims))
     )
     wgrad = tex.gemm(
-        colwise_casted_x, casted_grad.get_colwise_tensor(), (x_contracting_dim, g_contracting_dim)
+        colwise_casted_x, casted_grad.get_colwise_tensor(), (x_constracting_dim, g_constracting_dim)
     )
     wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes)
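The index arithmetic in this hunk is easy to misread, so here is a worked check of the restored formulas for a plain 2D case. The shapes and forward contracting dims are illustrative only (x: (M, K), kernel: (K, N), grad: (M, N)):

grad_ndim, kernel_shape, x_shape = 2, (64, 32), (128, 64)   # grad: (M, N)
fwd_x_cdims, fwd_k_cdims = (1,), (0,)
g_cdim = tuple(range(grad_ndim - len(kernel_shape) + len(fwd_k_cdims), grad_ndim))
k_cdim = tuple(d for d in range(len(kernel_shape)) if d not in fwd_k_cdims)
assert (g_cdim, k_cdim) == ((1,), (1,))   # dgrad (GEMM NT) contracts N with N -> (M, K)
x_cdim = tuple(range(0, len(x_shape) - len(fwd_x_cdims)))
assert x_cdim == (0,)                     # wgrad (GEMM TN) contracts M with M -> (K, N)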
@@ -184,240 +184,135 @@ def _dense_bwd_rule(
 _dense.defvjp(_dense_fwd_rule, _dense_bwd_rule)

-def grouped_dense(
-    x: jnp.ndarray,
-    kernel: jnp.ndarray,
-    group_sizes: jnp.ndarray,
-    contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (1,)),
-    bias: jnp.ndarray = None,
-    precision: jax.lax.Precision = jax.lax.Precision.DEFAULT,
-    preferred_element_type: jnp.dtype = None,
-    group_offset: jnp.array = None,
-    quantizer_set: QuantizerSet = noop_quantizer_set,
-):
-    """
-    Perform grouped dense (linear) layer transformation with optional quantization.
-
-    Args:
-        x: Input tensor of shape (M, K)
-        kernel: Weight matrix of shape (G, K, N)
-        group_sizes: 1D array of shape (G,) specifying the size of each group
-        contracting_dims: Tuple of sequences specifying which dimensions to contract
-            (currently only supports ((1,), (1,)))
-        bias: Bias tensor of shape (G, N)
-        precision: JAX precision for the GEMM operation
-        preferred_element_type: Preferred data type for the output tensor
-        group_offset: 1D array containing offsets for each group (not yet implemented)
-        quantizer_set: Set of quantizers for FP8 quantization of the input and output
-
-    Returns:
-        A jnp.ndarray containing the result of the grouped linear operation
-    """
-    output = _grouped_dense(
-        x,
-        kernel,
-        group_sizes,
-        contracting_dims,
-        bias,
-        precision,
-        preferred_element_type,
-        group_offset,
-        quantizer_set,
-    )
-    return output
+"""
+def grouped_dense(
+    x_list,
+    kernel_list,
+    bias_list,
+    contracting_dims_list,
+    quantizer_set_list=None,
+):
+    # Perform grouped_dense layer transformation with optional quantization.
+    output_list = _grouped_dense(
+        x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list
+    )
+    return output_list
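For readers comparing the two APIs: the removed v.2 entry point took a flat input plus group_sizes, while the restored (string-disabled) version took per-group lists. A minimal usage sketch of the removed API, with illustrative shapes, no quantizer set (so the unquantized path runs), and note the import only exists on the pre-revert tree:

import jax, jax.numpy as jnp
from transformer_engine.jax.dense import grouped_dense   # removed by this revert

G, M, K, N = 4, 128, 64, 32
x = jax.random.normal(jax.random.PRNGKey(0), (M, K), jnp.bfloat16)        # flat tokens
kernel = jax.random.normal(jax.random.PRNGKey(1), (G, K, N), jnp.bfloat16)  # one weight per group
group_sizes = jnp.array([32, 32, 32, 32], dtype=jnp.int32)                # sums to M
out = grouped_dense(x, kernel, group_sizes)                               # -> (M, N)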
-@partial(jax.custom_vjp, nondiff_argnums=(3, 5, 6, 7))
-def _grouped_dense(
-    x,
-    kernel,
-    group_sizes,
-    contracting_dims,
-    bias,
-    precision,
-    preferred_element_type,
-    group_offset,
-    quantizer_set,
-):
-    output, _ = _grouped_dense_fwd_rule(
-        x,
-        kernel,
-        group_sizes,
-        contracting_dims,
-        bias,
-        precision,
-        preferred_element_type,
-        group_offset,
-        quantizer_set,
-    )
-    return output
+@partial(jax.custom_vjp, nondiff_argnums=(3,))
+def _grouped_dense(x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list):
+    output_list, _ = _grouped_dense_fwd_rule(
+        x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list
+    )
+    return output_list
 def _grouped_dense_fwd_rule(
-    x,
-    kernel,
-    group_sizes,
-    contracting_dims,
-    bias,
-    precision,
-    preferred_element_type,
-    group_offset,
-    quantizer_set,
+    x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list
 ):
-    use_bias = bias is not None
-    is_noop_quantizer_set = quantizer_set == noop_quantizer_set
-
-    if is_noop_quantizer_set:
-        grouped_gemm_x = x
-        grouped_gemm_kernel = kernel
-        ctx_x = x
-        ctx_kernel = kernel
-        flatten_axis_k = None
+    use_bias = bias_list is not None
+    output_list = []
+    x_rowwise_list = []
+    x_colwise_list = []
+    kernel_colwise_list = []
+    kernel_rowwise_list = []
+    x_shape_list = []
+    kernel_shape_list = []
+    if quantizer_set_list is None:
+        x_rowwise_list = x_list
+        x_colwise_list = x_list
+        kernel_colwise_list = kernel_list
+        kernel_rowwise_list = kernel_list
+        x_shape_list = [x.shape for x in x_list]
+        kernel_shape_list = [kernel.shape for kernel in kernel_list]
     else:
-        x_contracting_dims, k_contracting_dims = contracting_dims
-        flatten_axis_x = -len(x_contracting_dims)
-        flatten_axis_k = len(k_contracting_dims) - len(kernel.shape) + 1  # +1 for G axis
-
-        assert x.ndim == 2, "Grouped dense expects a 2D input tensor of shape (M, K)"
-        assert kernel.ndim == 3, "Grouped dense expects a 3D kernel tensor of shape (G, K, N)"
-        # Expected k_contracting_dims == (1,), need to tweak it for grouped_gemm FP8 extra transpose
-        # TODO(Hua): Do we have a better way for this? What if is_gemm_with_all_layouts_supported()?
-        assert x_contracting_dims == (1,) and k_contracting_dims == (1,), (
-            "grouped_dense for FP8 can only handle x_contracting_dims=(1,) "
-            "and k_contracting_dims=(1,) for now, "
-            f"got {x_contracting_dims=} and {k_contracting_dims=}"
-        )
-        k_contracting_dims = (0,)
-
-        casted_x = tex.grouped_quantize(
-            x, quantizer_set.x, group_sizes, flatten_axis=flatten_axis_x
-        )
-        casted_kernel = tex.grouped_quantize(
-            kernel, quantizer_set.kernel, flatten_axis=flatten_axis_k
-        )
-        contracting_dims = (x_contracting_dims, k_contracting_dims)
-
-        # For x_contracting_dims == (1,) and k_contracting_dims == (1,), we should have
-        # rowwise_casted_x.original_shape == (M, K)
-        # colwise_casted_kernel.original_shape == (G, N, K)
-        grouped_gemm_x = casted_x.get_rowwise_tensor()
-        grouped_gemm_kernel = casted_kernel.get_colwise_tensor()
-
-        # TODO(Hua): Shall we give warning/error if not quantizer_set.x.is_2x2x()?
-        ctx_x = casted_x.get_colwise_tensor() if quantizer_set.x.is_2x2x() else None
-        ctx_kernel = casted_kernel.get_rowwise_tensor() if quantizer_set.kernel.is_2x2x() else None
-
-    output = tex.grouped_gemm(
-        grouped_gemm_x,
-        grouped_gemm_kernel,
-        group_sizes,
-        contracting_dims,
-        bias,
-        precision,
-        preferred_element_type,
-        group_offset,
-    )
-
+        for i in range(len(x_list)):  # pylint: disable=consider-using-enumerate
+            q_x = tex.quantize(x_list[i], quantizer_set_list[i].x)
+            q_kernel = tex.quantize(kernel_list[i], quantizer_set_list[i].kernel)
+            x_rowwise_list.append(q_x.get_rowwise_tensor())
+            x_colwise_list.append(q_x.get_colwise_tensor())
+            kernel_colwise_list.append(q_kernel.get_colwise_tensor())
+            kernel_rowwise_list.append(q_kernel.get_rowwise_tensor())
+            x_shape_list.append(x_rowwise_list[-1].data.shape)
+            kernel_shape_list.append(kernel_rowwise_list[-1].data.shape)
+
+    output_list = tex.grouped_gemm(
+        x_rowwise_list, kernel_colwise_list, contracting_dims_list, bias_list
+    )
     ctx = (
-        group_sizes,
-        ctx_x,
-        ctx_kernel,
-        x.shape,
-        kernel.shape,
+        x_colwise_list,
+        kernel_rowwise_list,
+        x_shape_list,
+        kernel_shape_list,
         use_bias,
-        is_noop_quantizer_set,
-        quantizer_set,
-        flatten_axis_k,
+        quantizer_set_list,
     )
-    return output, ctx
+    return output_list, ctx
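A worked check of the removed fwd rule's flatten-axis bookkeeping for the only supported case (x: (M, K) contracting dim 1, kernel: (G, K, N) contracting dim 1). Plain arithmetic, illustrative only:

x_contracting_dims, k_contracting_dims = (1,), (1,)
kernel_ndim = 3                                              # (G, K, N)
flatten_axis_x = -len(x_contracting_dims)                    # -1
flatten_axis_k = len(k_contracting_dims) - kernel_ndim + 1   # 1 - 3 + 1 = -1, "+1 for G axis"
assert (flatten_axis_x, flatten_axis_k) == (-1, -1)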
-def _grouped_dense_bwd_rule(
-    contracting_dims, precision, preferred_element_type, group_offset, ctx, grad
-):
-    fwd_x_contracting_dims, fwd_k_contracting_dims = contracting_dims
-    (
-        group_sizes,
-        ctx_x,
-        ctx_kernel,
-        x_shape,
-        kernel_shape,
-        use_bias,
-        is_noop_quantizer_set,
-        quantizer_set,
-        flatten_axis_k,
-    ) = ctx
-
-    if is_noop_quantizer_set:
-        # The 1 in range is for excluding the group dimension (shall we use the hardcoded results below?)
-        # g_contracting_dim = (1, )
-        # k_contracting_dim = (2, )
-        g_contracting_dim = tuple(
-            range(1 + grad.ndim - len(kernel_shape) + len(fwd_k_contracting_dims), grad.ndim)
-        )
-        k_contracting_dim = tuple(
-            dim for dim in range(1, len(kernel_shape)) if dim not in fwd_k_contracting_dims
-        )
-        dgrad_contracting_dims = (g_contracting_dim, k_contracting_dim)
-        dgrad_grad = grad
-        dgrad_kernel_T = ctx_kernel
-
-        # g_contracting_dim = (0, )
-        # x_contracting_dim = (0, )
-        g_contracting_dim = x_contracting_dim = tuple(
-            range(0, len(x_shape) - len(fwd_x_contracting_dims))
-        )
-        wgrad_contracting_dims = (x_contracting_dim, g_contracting_dim)
-        wgrad_x_T = ctx_x
-        wgrad_grad = grad
-    else:
-        casted_grad = tex.grouped_quantize(
-            grad, quantizer_set.dgrad, group_sizes, flatten_axis=flatten_axis_k
-        )
-        # For x_contracting_dims == (1,) and k_contracting_dims == (1,), we need to use
-        # g_contracting_dim = (1,) and k_contracting_dim = (2,) to make it work after the
-        # extra transpose for FP8 in grouped_gemm
-        # TODO(Hua): Do we have a better way for this? What if is_gemm_with_all_layouts_supported()?
-        g_contracting_dim = (1,)
-        k_contracting_dim = (2,)
-        dgrad_contracting_dims = (g_contracting_dim, k_contracting_dim)
-        dgrad_grad = casted_grad.get_rowwise_tensor()
-        dgrad_kernel_T = ctx_kernel
-
-        # We need to use g_contracting_dim = (0,) and x_contracting_dim = (1,) to make it work
-        # after the extra transpose for FP8 in grouped_gemm
-        # TODO(Hua): Do we have a better way for this? What if is_gemm_with_all_layouts_supported()?
-        g_contracting_dim = (0,)
-        x_contracting_dim = (1,)
-        wgrad_contracting_dims = (x_contracting_dim, g_contracting_dim)
-        wgrad_x_T = ctx_x
-        wgrad_grad = casted_grad.get_colwise_tensor()
-
-    dgrad = tex.grouped_gemm(
-        dgrad_grad,
-        dgrad_kernel_T,
-        group_sizes,
-        dgrad_contracting_dims,
-        precision=precision,
-        preferred_element_type=preferred_element_type,
-        group_offset=group_offset,
-    )
-    wgrad = tex.grouped_gemm(
-        wgrad_x_T,
-        wgrad_grad,
-        group_sizes,
-        wgrad_contracting_dims,
-        precision=precision,
-        preferred_element_type=preferred_element_type,
-        group_offset=group_offset,
-    )
-
-    group_sizes_grad = None
-    dbias = tex.grouped_dbias(grad, group_sizes) if use_bias else None
-
-    return dgrad, wgrad, group_sizes_grad, dbias, quantizer_set
+def _grouped_dense_bwd_rule(contracting_dims_list, ctx, grad_list):
+    (
+        colwise_x_list,
+        rowwise_kernel_list,
+        x_shape_list,
+        kernel_shape_list,
+        use_bias,
+        quantizer_set_list,
+    ) = ctx
+    group_size = len(grad_list)
+    dbias_list = []
+    grad_rowwise_list = []
+    grad_colwise_list = []
+    dgrad_contracting_dims_list = []
+    wgrad_contracting_dims_list = []
+    for i in range(group_size):
+        grad = grad_list[i]
+        x_shape = x_shape_list[i]
+        kernel_shape = kernel_shape_list[i]
+        fwd_contracting_dims = contracting_dims_list[i]
+        if quantizer_set_list is None:
+            casted_grad = grad
+            dbias = tex.quantization._jax_dbias(grad)
+            grad_rowwise_list.append(grad)
+            grad_colwise_list.append(grad)
+        else:
+            quantizer_set = quantizer_set_list[i]
+            casted_grad, dbias = tex.quantize_dbias(
+                grad, is_dbias=use_bias, quantizer=quantizer_set.dgrad
+            )
+            grad_rowwise_list.append(casted_grad.get_rowwise_tensor())
+            grad_colwise_list.append(casted_grad.get_colwise_tensor())
+        dbias_list.append(dbias)
+
+        # GEMM NT
+        fwd_x_contracting_dims, fwd_k_contracting_dims = fwd_contracting_dims
+        g_contracting_dim = tuple(
+            range(grad.ndim - len(kernel_shape) + len(fwd_k_contracting_dims), grad.ndim)
+        )
+        k_contracting_dim = tuple(
+            dim for dim in range(len(kernel_shape)) if dim not in fwd_k_contracting_dims
+        )
+        dgrad_contracting_dims = (g_contracting_dim, k_contracting_dim)
+        dgrad_contracting_dims_list.append(dgrad_contracting_dims)
+
+        # GEMM TN
+        g_contracting_dim = x_contracting_dim = tuple(
+            range(0, len(x_shape) - len(fwd_x_contracting_dims))
+        )
+        wgrad_contracting_dims = (x_contracting_dim, g_contracting_dim)
+        wgrad_contracting_dims_list.append(wgrad_contracting_dims)

+    dgrad_list = tex.grouped_gemm(
+        grad_rowwise_list, rowwise_kernel_list, dgrad_contracting_dims_list
+    )
+    wgrad_list = tex.grouped_gemm(colwise_x_list, grad_colwise_list, wgrad_contracting_dims_list)
+
+    return dgrad_list, wgrad_list, dbias_list, quantizer_set_list

 _grouped_dense.defvjp(_grouped_dense_fwd_rule, _grouped_dense_bwd_rule)
"""
@@ -127,16 +127,14 @@ class BlockScaleDequantizer(Dequantizer):
     def dequantize(scaled_tensor):
         """Dequantize a tensor using block scaling.

-        This function dequantizes a tensor that was quantized using block scaling
-        by applying the inverse scaling factor to each block of data.
-
         Args:
-            data: The quantized tensor data
-            scale_inv: The inverse scaling factors
-            dq_dtype: The data type for dequantized values
-            scaling_mode: The scaling mode used for quantization
-            is_colwise: Whether the scaling is column-wise
-            flatten_axis: The axis along which the tensor could be flattened to 2D
+            scaled_tensor: The quantized tensor to dequantize

         Returns:
-            The dequantized tensor
+            The dequantized tensor in the specified data type
         """
         return BlockScaleDequantizer._dequantize_func(
             scaled_tensor.data,
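To make the dequantize docstring concrete: block-scale dequantization multiplies each block of the FP8 payload by its inverse scale. A minimal rowwise sketch, assuming 1x32 blocks (MXFP8-style); the block size and function name are illustrative, not the library's API:

import jax.numpy as jnp

def ref_block_dequantize(data, scale_inv, dq_dtype=jnp.bfloat16, block=32):
    m, k = data.shape                                  # FP8 payload, with k % block == 0
    d = data.astype(jnp.float32).reshape(m, k // block, block)
    s = scale_inv.astype(jnp.float32)[:, :, None]      # one inverse scale per (m, k // block) block
    return (d * s).reshape(m, k).astype(dq_dtype)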