Unverified commit 5d01ef21 authored by Phuong Nguyen, committed by GitHub

[JAX] GroupedDense v.2 without dynamic shape (#1721)



* Implemented GroupedDense and TestGroupedDense for BF16, FP16, and FP8
* Fixed a GroupedGemmFFI cuBLAS workspace alignment bug
Signed-off-by: Hua Huang <huah@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent c293d3a8
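
For orientation, here is a minimal sketch of the call surface this commit introduces. It is illustrative only: the shapes and group sizes are made up, and only `grouped_dense`, its signature, and the group-size convention (rows of `x` partitioned by `group_sizes`, kernels stacked along a leading group axis) are taken from the diff below.

import jax
import jax.numpy as jnp
from transformer_engine.jax.dense import grouped_dense

G, M, K, N = 4, 256, 128, 64
kx, kk = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.uniform(kx, (M, K), dtype=jnp.bfloat16)          # flattened activations
kernel = jax.random.uniform(kk, (G, K, N), dtype=jnp.bfloat16)  # one kernel per group
group_sizes = jnp.array([64, 64, 64, 64], jnp.int32)            # sums to M, no dynamic shapes

# Row block i of x is multiplied by kernel[i]; the output shape is static (M, N).
out = grouped_dense(x, kernel, group_sizes, contracting_dims=((1,), (1,)))
assert out.shape == (M, N)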
@@ -40,10 +40,11 @@ from transformer_engine.jax.quantize import (
     ScalingMode,
     QuantizerFactory,
     QuantizeLayout,
+    noop_quantizer_set,
 )
 from transformer_engine.jax.quantize import helper
 from transformer_engine.jax.activation import activation
-from transformer_engine.jax.dense import dense
+from transformer_engine.jax.dense import dense, grouped_dense
 from transformer_engine.jax.layernorm_dense import layernorm_dense

 GEMM_CASES = [
@@ -1204,24 +1205,6 @@ class TestFusedDense:
         assert_allclose(prim_x_grad, ref_x_grad, dtype=q_dtype)

-# This function is modified from transformer_engine/jax/cpp_extensions/gemm.py::_jax_gemm()
-def _quantize_gemm_pair(lhs, rhs, contracting_dims, lhs_quantizer, rhs_quantizer):
-    ((lhs_contract_dim,), (rhs_contract_dim,)) = contracting_dims
-    lhs_is_rowwise = lhs_contract_dim == lhs.ndim - 1
-    rhs_is_rowwise = rhs_contract_dim == rhs.ndim - 1
-    lhs_q = lhs_quantizer.quantize(
-        lhs,
-        is_rowwise=lhs_is_rowwise,
-        is_colwise=not lhs_is_rowwise,
-    )
-    rhs_q = rhs_quantizer.quantize(
-        rhs,
-        is_rowwise=rhs_is_rowwise,
-        is_colwise=not rhs_is_rowwise,
-    )
-    return lhs_q, rhs_q

 # E5M2 * E5M2 is not supported
 fwd_bwd_dtypes = [
     [jnp.float8_e4m3fn, jnp.float8_e4m3fn],
@@ -1229,219 +1212,194 @@ fwd_bwd_dtypes = [
     [jnp.float8_e5m2, jnp.float8_e4m3fn],
 ]

-"""
-@pytest_parametrize_wrapper(
-    "shape_list", [[(512, 128, 256), (256, 128, 256), (256, 128, 128), (512, 256, 128)]]
-)
+GROUPED_DENSE_INPUT_SHAPES = [
+    # (n_groups, m, n, k); the actual m will be multiplied by 32
+    (5, 32, 128, 64),  # Test the case where n_groups is not a multiple of 4
+    (8, 64, 32, 128),
+    (8, 64, 128, 256),
+]
+
+
+@pytest_parametrize_wrapper("input_shape", GROUPED_DENSE_INPUT_SHAPES)
 class TestGroupedDense:
-    def _ref_grouped_gemm_with_jnp_dot(self, lhs_list, rhs_list, contracting_dims_list):
-        ref_out_list = []
-        for lhs, rhs, contracting_dims in zip(lhs_list, rhs_list, contracting_dims_list):
-            dim_nums = (contracting_dims, ((), ()))
-            ref_out_list.append(jax.lax.dot_general(lhs, rhs, dim_nums))
-        return ref_out_list
+    def _ref_grouped_dense(self, lhs, rhs, bias, group_sizes, contracting_dims):
+        lhs_contract_dim, _ = contracting_dims
+        assert len(lhs_contract_dim) == 1 and lhs.ndim == 2 and rhs.ndim == 3
+        if bias is None:
+            bias = jnp.zeros((rhs.shape[0], rhs.shape[2]), dtype=lhs.dtype)
+        else:
+            assert bias.ndim == 2 and bias.shape == (rhs.shape[0], rhs.shape[2])
+        remaining_axis = (set(range(lhs.ndim)) - set(lhs_contract_dim)).pop()
+        lhs = jnp.split(lhs, jnp.cumulative_sum(group_sizes)[:-1], axis=remaining_axis)
+        rhs = jnp.split(rhs, rhs.shape[0], axis=0)
+        bias = jnp.split(bias, bias.shape[0], axis=0)
+        ref_out = []
+        dim_num = (contracting_dims, ((), ()))
+        for lhs_i, rhs_i, bias_i in zip(lhs, rhs, bias):
+            out_i = jax.lax.dot_general(lhs_i, rhs_i, dim_num) + jnp.expand_dims(bias_i, axis=0)
+            ref_out.append(jnp.squeeze(out_i))
+        return ref_out

-    def _generate_grouped_gemm_input(self, dtype, shape_list, layout_list):
+    def _generate_grouped_dense_input(self, dtype, input_shape, data_layout="NN", with_bias=False):
         key = jax.random.PRNGKey(0)
-        subkeys = jax.random.split(key, len(shape_list) * 2)
-
-        lhs_list, rhs_list, contracting_dims_list = [], [], []
-        for i, ((m, n, k), data_layout) in enumerate(zip(shape_list, layout_list)):
-            lhs = jax.random.uniform(
-                subkeys[2 * i],
-                (m if data_layout[0] == "N" else k, k if data_layout[0] == "N" else m),
-                dtype=dtype,
-            )
-            rhs = jax.random.uniform(
-                subkeys[2 * i + 1],
-                (k if data_layout[1] == "N" else n, n if data_layout[1] == "N" else k),
-                dtype=dtype,
-            )
-            lhs_contracting_dim = (1,) if data_layout[0] == "N" else (0,)
-            rhs_contracting_dim = (0,) if data_layout[1] == "N" else (1,)
-            contracting_dims = (lhs_contracting_dim, rhs_contracting_dim)
-            lhs_list.append(lhs)
-            rhs_list.append(rhs)
-            contracting_dims_list.append(contracting_dims)
-
-        return lhs_list, rhs_list, contracting_dims_list
+        subkeys = jax.random.split(key, 4)
+        n_groups, m, n, k = input_shape
+
+        group_sizes = jnp.sort(jax.random.randint(subkeys[0], (n_groups - 1,), 0, m))
+        group_sizes = jnp.concatenate([jnp.array([0]), group_sizes, jnp.array([m])])
+        group_sizes = jnp.diff(group_sizes)
+        assert group_sizes.sum() == m
+        # *32 to make sure that the input shape works for MXFP8
+        group_sizes = group_sizes * 32
+        m = m * 32
+
+        lhs_shape = (m if data_layout[0] == "N" else k, k if data_layout[0] == "N" else m)
+        rhs_shape = (n_groups, k if data_layout[1] == "N" else n, n if data_layout[1] == "N" else k)
+        bias_shape = (n_groups, n)
+
+        lhs = jax.random.uniform(subkeys[1], lhs_shape, dtype=dtype)
+        rhs = jax.random.uniform(subkeys[2], rhs_shape, dtype=dtype)
+        bias = jax.random.uniform(subkeys[3], bias_shape, dtype=dtype) if with_bias else None
+
+        lhs_contracting_dim = (1,) if data_layout[0] == "N" else (0,)
+        rhs_contracting_dim = (1,) if data_layout[1] == "N" else (2,)
+        contracting_dims = (lhs_contracting_dim, rhs_contracting_dim)
+
+        return lhs, rhs, group_sizes, contracting_dims, bias
+
+    def _assert_grouped_gemm_output(self, out, group_sizes, ref_list, dtype):
+        assert out.dtype == ref_list[0].dtype
+        out_list = jnp.split(out, jnp.cumulative_sum(group_sizes)[:-1], axis=0)
+        for i in range(len(ref_list)):
+            assert_allclose(out_list[i], ref_list[i], dtype=dtype)

     @pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16])
-    @pytest_parametrize_wrapper("layout_list", [["NN", "TN", "NT", "TT"]])
-    def test_grouped_gemm_fp16(self, dtype, shape_list, layout_list):
-        lhs_list, rhs_list, contracting_dims_list = self._generate_grouped_gemm_input(
-            dtype, shape_list, layout_list
+    @pytest_parametrize_wrapper("layout", ["NN"])
+    def test_grouped_gemm_fp16(self, dtype, input_shape, layout):
+        lhs, rhs, group_sizes, contracting_dims, _ = self._generate_grouped_dense_input(
+            dtype, input_shape, layout
         )
-        ref_out = self._ref_grouped_gemm_with_jnp_dot(lhs_list, rhs_list, contracting_dims_list)
-        primitive_out = tex.grouped_gemm(lhs_list, rhs_list, contracting_dims_list)
-        for i in range(len(shape_list)):
-            assert_allclose(primitive_out[i], ref_out[i], dtype=dtype)
+        ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims)
+        prim_out = tex.grouped_gemm(lhs, rhs, group_sizes, contracting_dims)
+        self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, dtype)

     @pytest.mark.skipif(not is_fp8_supported, reason=reason)
     @pytest.mark.parametrize("fwd_bwd_dtype", fwd_bwd_dtypes)
     @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
-    @pytest_parametrize_wrapper("layout_list", [["NN", "TN", "NT", "TT"]])
-    def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, shape_list, layout_list):
+    @pytest_parametrize_wrapper("layout", ["NN"])
+    def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape, layout):
+        if scaling_mode == ScalingMode.MXFP8_1D_SCALING:
+            pytest.skip("MXFP8 is not supported in grouped_gemm yet")
         fwd_dtype, bwd_dtype = fwd_bwd_dtype
         quantizer_set = QuantizerFactory.create_set(
-            scaling_mode=scaling_mode, fwd_dtype=fwd_dtype, bwd_dtype=bwd_dtype, is_2x2x=False
+            scaling_mode=scaling_mode,
+            fwd_dtype=fwd_dtype,
+            bwd_dtype=bwd_dtype,
+            is_2x2x=False,
+            n_groups=input_shape[0],
         )
+        # quantizer_set.{x, kernel} has fwd_dtype, while quantizer_set.grad has bwd_dtype.
+        # We want to test E4M3 * E5M2, so manually set quantizer_set.kernel.q_dtype to bwd_dtype.
+        quantizer_set.kernel.q_dtype = bwd_dtype
+        for quantizer in quantizer_set.kernel.quantizers:
+            quantizer.q_dtype = bwd_dtype
         out_dtype = jnp.bfloat16
-        lhs_list, rhs_list, contracting_dims_list = self._generate_grouped_gemm_input(
-            out_dtype, shape_list, layout_list
+        lhs, rhs, group_sizes, contracting_dims, _ = self._generate_grouped_dense_input(
+            out_dtype, input_shape, layout
         )
-        q_lhs_list = []
-        q_rhs_list = []
-        for lhs, rhs, contracting_dims in zip(lhs_list, rhs_list, contracting_dims_list):
-            # quantizer_set.x and quantizer_set.kernel have the same q_dtype, we want to
-            # test the case where lhs and rhs have different q_dtypes
-            q_lhs, q_rhs = _quantize_gemm_pair(
-                lhs, rhs, contracting_dims, quantizer_set.x, quantizer_set.dgrad
-            )
-            q_lhs_list.append(q_lhs)
-            q_rhs_list.append(q_rhs)
-        ref_out = self._ref_grouped_gemm_with_jnp_dot(lhs_list, rhs_list, contracting_dims_list)
-        primitive_out = tex.grouped_gemm(q_lhs_list, q_rhs_list, contracting_dims_list)
+        ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims)
+        prim_out = tex.grouped_gemm(
+            lhs, rhs, group_sizes, contracting_dims, quantizer_set=quantizer_set
+        )
         allclose_dtype = jnp.float8_e4m3fn
-        if fwd_dtype == jnp.float8_e5m2 or bwd_dtype == jnp.float8_e5m2:
+        if jnp.float8_e5m2 in fwd_bwd_dtype:
             allclose_dtype = jnp.float8_e5m2
-        for i in range(len(shape_list)):
-            assert_allclose(primitive_out[i], ref_out[i], dtype=allclose_dtype)
+        self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, allclose_dtype)

-    @pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16])
-    def test_grouped_dense_grad_fp16(self, dtype, shape_list):
-        group_size = len(shape_list)
-        layout_list = ["NN" for _ in range(group_size)]
-        x_list, kernel_list, contracting_dims_list = self._generate_grouped_gemm_input(
-            dtype, shape_list, layout_list
-        )
-        bias_list = []
-        key = jax.random.PRNGKey(1)
-        for shape in shape_list:
-            n = shape[1]
-            bias = jax.random.uniform(key, n, dtype=dtype)
-            bias_list.append(bias)
-
-        def ref_func(x_list, kernel_list, bias_list, contracting_dims_list):
-            out_list = []
-            for i in range(len(x_list)):
-                out_list.append(
-                    dense(
-                        x_list[i],
-                        kernel_list[i],
-                        bias_list[i],
-                        contracting_dims=contracting_dims_list[i],
-                    )
-                )
-            # Note: we use jnp.sum instead of jnp.mean to make the gradients larger
-            # and prevent them from being clamped to zero
-            out_sum_list = [jnp.sum(out) for out in out_list]
-            return jnp.sum(jnp.asarray(out_sum_list))
+    def _ref_sum_grouped_dense(self, x, kernel, bias, group_sizes, contracting_dims):
+        out_list = self._ref_grouped_dense(x, kernel, bias, group_sizes, contracting_dims)
+        # Note: we use jnp.sum instead of jnp.mean to make the gradients larger
+        # and prevent them from being clamped to zero
+        out_sum_list = [jnp.sum(out) for out in out_list]
+        return jnp.sum(jnp.asarray(out_sum_list))

-        def primitive_func(x_list, kernel_list, bias_list, contracting_dims_list):
-            out_list = grouped_dense(x_list, kernel_list, bias_list, contracting_dims_list)
-            out_sum_list = [jnp.sum(out) for out in out_list]
-            return jnp.sum(jnp.asarray(out_sum_list))
+    def _primitive_sum_grouped_dense(
+        self, x, kernel, bias, group_sizes, contracting_dims, quantizer_set=noop_quantizer_set
+    ):
+        out = grouped_dense(
+            x, kernel, group_sizes, contracting_dims, bias=bias, quantizer_set=quantizer_set
+        )
+        return jnp.sum(jnp.asarray(out))

-        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2))
-        value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2))
+    @pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16])
+    def test_grouped_dense_grad_fp16(self, dtype, input_shape):
+        x, kernel, group_sizes, contracting_dims, bias = self._generate_grouped_dense_input(
+            dtype,
+            input_shape,
+            with_bias=True,
+        )
+        value_n_grad_ref_func = value_and_grad(self._ref_sum_grouped_dense, (0, 1, 2))
+        value_n_grad_prim_func = value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2))

-        ref_out_mean, (ref_dgrad_list, ref_wgrad_list, ref_dbias_list) = value_n_grad_ref_func(
-            x_list, kernel_list, bias_list, contracting_dims_list
+        ref_out_sum, (ref_dgrad, ref_wgrad, ref_dbias) = value_n_grad_ref_func(
+            x, kernel, bias, group_sizes, contracting_dims
         )
-        primitive_out_mean, (primitive_dgrad_list, primitive_wgrad_list, primitive_dbias_list) = (
-            value_n_grad_primitive_func(x_list, kernel_list, bias_list, contracting_dims_list)
+        prim_out_sum, (prim_dgrad, prim_wgrad, prim_dbias) = value_n_grad_prim_func(
+            x, kernel, bias, group_sizes, contracting_dims
         )
-        assert_allclose(primitive_out_mean, ref_out_mean, dtype=dtype)
-        for i in range(group_size):
-            assert_allclose(primitive_dgrad_list[i], ref_dgrad_list[i], dtype=dtype)
-            assert_allclose(primitive_wgrad_list[i], ref_wgrad_list[i], dtype=dtype)
-            assert_allclose(primitive_dbias_list[i], ref_dbias_list[i], dtype=dtype)
+        assert_allclose(prim_out_sum, ref_out_sum, dtype=dtype)
+        assert_allclose(prim_dgrad, ref_dgrad, dtype=dtype)
+        assert_allclose(prim_wgrad, ref_wgrad, dtype=dtype)
+        assert_allclose(prim_dbias, ref_dbias, dtype=dtype)

     @pytest.mark.skipif(not is_fp8_supported, reason=reason)
-    @pytest.mark.parametrize("fwd_bwd_dtype", fwd_bwd_dtypes)
-    @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
-    def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, shape_list):
-        group_size = len(shape_list)
-        layout_list = ["NN" for _ in range(group_size)]
-        fwd_dtype, bwd_dtype = fwd_bwd_dtype
-        if fwd_dtype == jnp.float8_e5m2:
-            pytest.skip("We never use E5M2 for fwd_dtype in training")
-        # Question: should we use different quantizers for different groups?
-        ref_quantizer_set_list = []
-        quantizer_set_list = []
-        for _ in range(group_size):
-            ref_quantizer_set = QuantizerFactory.create_set(
-                scaling_mode=scaling_mode, fwd_dtype=fwd_dtype, bwd_dtype=bwd_dtype, is_2x2x=True
-            )
-            ref_quantizer_set_list.append(ref_quantizer_set)
-            quantizer_set = QuantizerFactory.create_set(
-                scaling_mode=scaling_mode, fwd_dtype=fwd_dtype, bwd_dtype=bwd_dtype, is_2x2x=True
-            )
-            quantizer_set_list.append(quantizer_set)
-
-        out_dtype = jnp.bfloat16
-        x_list, kernel_list, contracting_dims_list = self._generate_grouped_gemm_input(
-            out_dtype, shape_list, layout_list
-        )
-        bias_list = []
-        key = jax.random.PRNGKey(1)
-        for shape in shape_list:
-            n = shape[1]
-            bias = jax.random.uniform(key, n, dtype=out_dtype)
-            bias_list.append(bias)
-
-        def ref_func(x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list):
-            out_list = []
-            for i in range(len(x_list)):
-                out_list.append(
-                    dense(
-                        x_list[i],
-                        kernel_list[i],
-                        bias_list[i],
-                        contracting_dims=contracting_dims_list[i],
-                        quantizer_set=quantizer_set_list[i],
-                    )
-                )
-            # Note: we use jnp.sum instead of jnp.mean to make the gradients larger
-            # and prevent them from being clamped to zero
-            out_sum_list = [jnp.sum(out) for out in out_list]
-            return jnp.sum(jnp.asarray(out_sum_list))
-
-        def primitive_func(
-            x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list
-        ):
-            out_list = grouped_dense(
-                x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list
-            )
-            out_sum_list = [jnp.sum(out) for out in out_list]
-            return jnp.sum(jnp.asarray(out_sum_list))
-
-        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2))
-        value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2))
-
-        ref_out_mean, (ref_dgrad_list, ref_wgrad_list, ref_dbias_list) = value_n_grad_ref_func(
-            x_list, kernel_list, bias_list, contracting_dims_list, ref_quantizer_set_list
-        )
-        primitive_out_mean, (primitive_dgrad_list, primitive_wgrad_list, primitive_dbias_list) = (
-            value_n_grad_primitive_func(
-                x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list
-            )
-        )
-        allclose_dtype = jnp.float8_e4m3fn
-        if fwd_dtype == jnp.float8_e5m2 or bwd_dtype == jnp.float8_e5m2:
-            allclose_dtype = jnp.float8_e5m2
-        assert_allclose(primitive_out_mean, ref_out_mean, dtype=allclose_dtype)
-        for i in range(group_size):
-            assert_allclose(primitive_dgrad_list[i], ref_dgrad_list[i], dtype=allclose_dtype)
-            assert_allclose(primitive_wgrad_list[i], ref_wgrad_list[i], dtype=allclose_dtype)
-            assert_allclose(primitive_dbias_list[i], ref_dbias_list[i], dtype=allclose_dtype)
-"""
+    @pytest.mark.parametrize(
+        "fwd_bwd_dtype",
+        [(jnp.float8_e4m3fn, jnp.float8_e4m3fn), (jnp.float8_e4m3fn, jnp.float8_e5m2)],
+    )
+    @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
+    def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape):
+        if scaling_mode == ScalingMode.MXFP8_1D_SCALING:
+            pytest.skip("MXFP8 is not supported in grouped_dense yet")
+
+        fwd_dtype, bwd_dtype = fwd_bwd_dtype
+        dtype = jnp.bfloat16
+        x, kernel, group_sizes, contracting_dims, bias = self._generate_grouped_dense_input(
+            dtype,
+            input_shape,
+            with_bias=True,
+        )
+        quantizer_set = QuantizerFactory.create_set(
+            scaling_mode=scaling_mode,
+            fwd_dtype=fwd_dtype,
+            bwd_dtype=bwd_dtype,
+            is_2x2x=True,
+            n_groups=group_sizes.size,
+        )
+        value_n_grad_ref_func = value_and_grad(self._ref_sum_grouped_dense, (0, 1, 2))
+        value_n_grad_prim_func = value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2))
+
+        ref_out_sum, (ref_dgrad, ref_wgrad, ref_dbias) = value_n_grad_ref_func(
+            x,
+            kernel,
+            bias,
+            group_sizes,
+            contracting_dims,
+        )
+        prim_out_sum, (prim_dgrad, prim_wgrad, prim_dbias) = value_n_grad_prim_func(
+            x, kernel, bias, group_sizes, contracting_dims, quantizer_set=quantizer_set
+        )
+        assert_allclose(prim_out_sum, ref_out_sum, dtype=fwd_dtype)
+        assert_allclose(prim_dgrad, ref_dgrad, dtype=bwd_dtype)
+        assert_allclose(prim_wgrad, ref_wgrad, dtype=bwd_dtype)
+        assert_allclose(prim_dbias, ref_dbias, dtype=dtype)
@@ -525,6 +525,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
   const auto B_alignment = _getAlignment(reinterpret_cast<uintptr_t>(param.B));
   const auto C_alignment = _getAlignment(reinterpret_cast<uintptr_t>(C));
   const auto D_alignment = _getAlignment(reinterpret_cast<uintptr_t>(D));
+  const auto workspace_alignment = _getAlignment(reinterpret_cast<uintptr_t>(workspace));
   NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
       preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES, &A_alignment, sizeof(A_alignment)));
   NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
@@ -533,6 +534,8 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
       preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES, &C_alignment, sizeof(C_alignment)));
   NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
       preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES, &D_alignment, sizeof(D_alignment)));
+  NVTE_CHECK(workspace_alignment % 256 == 0,
+             "cuBLAS workspace pointer must be aligned to 256 bytes, got ", workspace_alignment);
   const auto status =
       cublasLtMatmulAlgoGetHeuristic(handle, operationDesc, Adesc, Bdesc, Cdesc, Ddesc, preference,
...
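
The new checks assume cuBLASLt's documented requirement that the workspace pointer be 256-byte aligned. `_getAlignment` itself is not shown in this hunk; a hedged Python sketch of what such a helper plausibly computes:

def get_alignment(ptr: int, max_alignment: int = 256) -> int:
    # Largest power-of-two divisor of the address, capped at max_alignment.
    # The C++ check above passes iff this returns a multiple of 256.
    return min(ptr & -ptr, max_alignment)

assert get_alignment(0x7F0000000100) == 256  # OK: 256-byte aligned
assert get_alignment(0x7F0000000140) == 64   # would trip the new NVTE_CHECK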
@@ -6,22 +6,28 @@
 from typing import Tuple, Sequence, Union, Dict
 from functools import partial, reduce
 import operator
+import math

 import jax
 import jax.numpy as jnp

 from transformer_engine_jax import get_device_compute_capability

 from .base import BasePrimitive, register_primitive
+from .quantization import grouped_quantize
 from ..quantize import (
     ScaledTensor,
+    GroupedScaledTensor1x,
     ScalingMode,
     Quantizer,
+    GroupedQuantizer,
     QuantizeConfig,
+    QuantizerSet,
+    QuantizeLayout,
     noop_quantizer_set,
 )

-__all__ = ["gemm"]
+__all__ = ["gemm", "grouped_gemm", "is_gemm_with_all_layouts_supported"]


 num_cublas_streams = 4

@@ -34,6 +40,11 @@ def get_cublas_workspace_size_bytes() -> None:
     return 4_194_304


+def is_gemm_with_all_layouts_supported() -> bool:
+    """Return True if running on Blackwell (sm100+), False otherwise."""
+    return get_device_compute_capability(0) >= 100
+
+
 class GroupedGemmPrimitive(BasePrimitive):
     """
     Primitive for grouped GEMM
@@ -41,73 +52,139 @@ class GroupedGemmPrimitive(BasePrimitive):

     name = "te_grouped_gemm_ffi"
     multiple_results = True
-    impl_static_args = ()
+    impl_static_args = (7, 8, 9, 10, 11, 12, 13, 14, 15)
     inner_primitive = None
     outer_primitive = None

     @staticmethod
-    def abstract(*args, num_gemms, scaling_mode, out_dtype, has_bias):
+    def abstract(
+        lhs_data_aval,
+        lhs_scale_inv_aval,
+        rhs_data_aval,
+        rhs_scale_inv_aval,
+        bias_aval,
+        group_sizes_aval,
+        group_offset_aval,
+        *,
+        M,
+        N,
+        K,
+        lhs_is_trans,
+        rhs_is_trans,
+        scaling_mode,
+        out_dtype,
+        has_bias,
+        is_grouped_dense_wgrad,
+    ):
         """
+        Grouped GEMM operation.
+
         Args:
-            *args: Size num_gemms * 4 or num_gemms * 5 depending on has_bias:
-                args[0           :   num_gemms] are the lhs tensors,
-                args[  num_gemms : 2*num_gemms] are the rhs tensors,
-                args[2*num_gemms : 3*num_gemms] are the lhs scale_inv tensors,
-                args[3*num_gemms : 4*num_gemms] are the rhs scale_inv tensors,
-                args[4*num_gemms : 5*num_gemms] are the bias tensors if has_bias is True.
-            num_gemms: Number of GEMM operations to perform.
-            scaling_mode: Scaling mode for the GEMM operations.
-            out_dtype: Data type of the output tensors.
-            has_bias: Boolean indicating if bias tensors are provided.
+            lhs_data: Left-hand-side input matrix data, 1D flattened array
+            lhs_scale_inv: Left-hand-side input scale_inv matrix, 1D flattened array
+            rhs_data: Right-hand-side input matrix data, 1D flattened array
+            rhs_scale_inv: Right-hand-side input scale_inv matrix, 1D flattened array
+            bias: Bias matrix of shape (G, N)
+            group_sizes: 1D array containing the sizes of each group
+            group_offset: 1D array containing offsets for each group (not yet implemented)
+            M: Number of rows in the output matrix
+            N: Number of columns in the output matrix
+            K: Number of columns in the left-hand-side matrix
+            lhs_is_trans: Boolean indicating if the left-hand-side matrix is transposed
+            rhs_is_trans: Boolean indicating if the right-hand-side matrix is transposed
+            scaling_mode: Scaling mode for the GEMM operations
+            out_dtype: Data type of the output tensors
+            has_bias: Boolean indicating if bias tensors are provided
+            is_grouped_dense_wgrad: Boolean indicating if this is a grouped dense wgrad operation
+                where both lhs and rhs are 2D matrices and the output is (G, M, N)

         Returns:
-            A tuple of ShapedArray objects of size num_gemms + 1:
-                ret[0 : num_gemms]: GEMM output tensors,
-                ret[num_gemms]: workspace tensor.
+            A jnp.ndarray containing the result of the grouped GEMM operation
         """
-        del scaling_mode
-        expected_num_args = 5 * num_gemms if has_bias else 4 * num_gemms
-        assert (
-            len(args) == expected_num_args
-        ), f"Expected {expected_num_args} input arguments, but got {len(args)}"
-        A_list = args[0:num_gemms]
-        B_list = args[num_gemms : 2 * num_gemms]
-        # A and B have shapes [1, m, k] and [1, n, k]
-        out_list_aval = tuple(
-            jax.core.ShapedArray((A.shape[1], B.shape[1]), dtype=out_dtype)
-            for A, B in zip(A_list, B_list)
-        )
+        del lhs_data_aval, rhs_data_aval, bias_aval, group_offset_aval
+        del K, lhs_is_trans, rhs_is_trans, scaling_mode, has_bias
+        # TODO(Phuong): move some shape checks from Cpp to here
         workspace_size = get_cublas_workspace_size_bytes() * num_cublas_streams
+        workspace_size += lhs_scale_inv_aval.size + rhs_scale_inv_aval.size
         workspace_aval = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8)
-        return (*out_list_aval, workspace_aval)
+        out_shape = (M, N)
+        if is_grouped_dense_wgrad:
+            out_shape = (group_sizes_aval.size, M, N)
+        out_aval = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype)
+        return (out_aval, workspace_aval)

     @staticmethod
     def outer_abstract(*args, **kwargs):
         (out_aval, _) = GroupedGemmPrimitive.abstract(*args, **kwargs)
-        return out_aval
+        return (out_aval,)

     @staticmethod
-    def lowering(ctx, *args, num_gemms, scaling_mode, out_dtype, has_bias):
+    def lowering(
+        ctx,
+        *args,
+        M,
+        N,
+        K,
+        lhs_is_trans,
+        rhs_is_trans,
+        scaling_mode,
+        out_dtype,
+        has_bias,
+        is_grouped_dense_wgrad,
+    ):
         del out_dtype
         return jax.ffi.ffi_lowering(GroupedGemmPrimitive.name)(
             ctx,
             *args,
-            num_gemms=num_gemms,
-            scaling_mode=int(scaling_mode),
+            M=M,
+            N=N,
+            K=K,
+            lhs_is_trans=lhs_is_trans,
+            rhs_is_trans=rhs_is_trans,
+            scaling_mode=scaling_mode.value,
             has_bias=has_bias,
+            is_grouped_dense_wgrad=is_grouped_dense_wgrad,
         )

     @staticmethod
-    def impl(*args, num_gemms, scaling_mode, out_dtype, has_bias):
+    def impl(
+        lhs_data,
+        lhs_scale_inv,
+        rhs_data,
+        rhs_scale_inv,
+        bias,
+        group_sizes,
+        group_offset,
+        M,
+        N,
+        K,
+        lhs_is_trans,
+        rhs_is_trans,
+        scaling_mode,
+        out_dtype,
+        has_bias,
+        is_grouped_dense_wgrad,
+    ):
         assert GroupedGemmPrimitive.inner_primitive is not None
-        out = GroupedGemmPrimitive.inner_primitive.bind(
-            *args,
-            num_gemms=num_gemms,
-            scaling_mode=scaling_mode.value,
+        (out, _) = GroupedGemmPrimitive.inner_primitive.bind(
+            lhs_data,
+            lhs_scale_inv,
+            rhs_data,
+            rhs_scale_inv,
+            bias,
+            group_sizes,
+            group_offset,
+            M=M,
+            N=N,
+            K=K,
+            lhs_is_trans=lhs_is_trans,
+            rhs_is_trans=rhs_is_trans,
+            scaling_mode=scaling_mode,
             out_dtype=out_dtype,
             has_bias=has_bias,
+            is_grouped_dense_wgrad=is_grouped_dense_wgrad,
         )
-        return out[:-1]  # out is [out_list, workspace]; only return out_list
+        return (out,)


 register_primitive(GroupedGemmPrimitive)
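
A note on the new calling convention, worked out from the signatures above: the first seven positions are traced array operands and the rest are compile-time attributes, which is exactly what `impl_static_args = (7, ..., 15)` encodes. A small self-check (illustrative, not part of the module):

array_operands = (
    "lhs_data", "lhs_scale_inv", "rhs_data", "rhs_scale_inv",
    "bias", "group_sizes", "group_offset",
)
static_attrs = (
    "M", "N", "K", "lhs_is_trans", "rhs_is_trans",
    "scaling_mode", "out_dtype", "has_bias", "is_grouped_dense_wgrad",
)
# Static args start right after the last array operand.
assert tuple(range(len(array_operands), len(array_operands) + len(static_attrs))) == (
    7, 8, 9, 10, 11, 12, 13, 14, 15,
)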
@@ -285,7 +362,7 @@ def gemm(
     lhs: Union[jnp.ndarray, ScaledTensor],
     rhs: Union[jnp.ndarray, ScaledTensor],
     contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)),
-    quantizer_set: Dict["str", Quantizer] = noop_quantizer_set,
+    quantizer_set: QuantizerSet = noop_quantizer_set,
 ) -> jnp.ndarray:
     """General matrix multiplication with optional quantization.

@@ -310,130 +387,190 @@ def gemm(
     return _jax_gemm(lhs, rhs, contracting_dims, quantizer_set)


-"""
-def swizzled_scale(scales):
-    # Swizzle the scale tensor for FP8 GEMM
-    assert scales.ndim == 2
-    rows, cols = scales.shape
-    scales = scales.reshape(rows // 128, 4, 32, cols // 4, 4)
-    scales = jnp.transpose(scales, (0, 3, 2, 1, 4))
-    scales = scales.reshape(rows, cols)
-    return scales
-
-
-def grouped_gemm(
-    lhs_list: List[Union[jnp.ndarray, ScaledTensor]],
-    rhs_list: List[Union[jnp.ndarray, ScaledTensor]],
-    contracting_dims_list: List[Tuple[Sequence[int], Sequence[int]]],
-    bias_list: List[jnp.ndarray] = None,
-) -> List[jnp.ndarray]:
-    # Grouped GEMM for multiple pairs of tensors.
-    assert (
-        len(lhs_list) == len(rhs_list) == len(contracting_dims_list)
-    ), "lhs_list, rhs_list, contracting_dims_list must have the same length"
-    num_gemms = len(lhs_list)
-    lhs_list_ = []
-    rhs_list_ = []
-    lhs_sinv_list_ = []
-    rhs_sinv_list_ = []
-    bias_list_ = []
-    for i in range(num_gemms):
-        lhs = lhs_list[i]
-        rhs = rhs_list[i]
-        contracting_dims = contracting_dims_list[i]
-        dim_nums = (contracting_dims, ((), ()))
-        if isinstance(lhs, ScaledTensor) and isinstance(rhs, ScaledTensor):
-            scaling_mode = lhs.scaling_mode
-            lhs_shape = lhs.data.shape
-            rhs_shape = rhs.data.shape
-            out_dtype = lhs.dq_dtype
-            # For ScaledTensors and DELAYED_TENSOR_SCALING, need to handle internal data_layout
-            if lhs.scaling_mode.is_tensor_scaling():
-                lhs_data = lhs.data
-                rhs_data = rhs.data
-                lhs_scale_inv = lhs.scale_inv
-                rhs_scale_inv = rhs.scale_inv
-                assert lhs.scaling_mode == rhs.scaling_mode
-                scaling_mode = lhs.scaling_mode
-            assert not (
-                lhs.data.dtype == jnp.float8_e5m2 and rhs.data.dtype == jnp.float8_e5m2
-            ), "FP8 GEMM does not support E5M2 * E5M2"
-            ((lhs_contract_dim,), (rhs_contract_dim,)) = contracting_dims
-            if lhs.data_layout == "T":
-                lhs_contract_dim = (lhs_contract_dim - 1) % lhs.data.ndim
-            if rhs.data_layout == "T":
-                rhs_contract_dim = (rhs_contract_dim - 1) % rhs.data.ndim
-            dim_nums = ((lhs_contract_dim,), (rhs_contract_dim,)), ((), ())
-        else:
-            # For jnp.ndarray, only consider contracting_dims, data_layout is always NN
-            scaling_mode = ScalingMode.NO_SCALING
-            lhs_shape = lhs.shape
-            rhs_shape = rhs.shape
-            out_dtype = lhs.dtype
-
-        (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums
-        lhs_dn = (lhs_contract, lhs_batch)
-        rhs_dn = (rhs_contract, rhs_batch)
-
-        lhs_remain_shape = _calculate_remaining_shape(lhs_shape, lhs_contract)
-        rhs_remain_shape = _calculate_remaining_shape(rhs_shape, rhs_contract)
-
-        # Note: do not squeeze() for {lhs, rhs}_3d, it will trigger a D2D memcpy
-        if scaling_mode == ScalingMode.NO_SCALING:
-            lhs_3d = _shape_normalization(lhs, lhs_dn)
-            rhs_3d = _shape_normalization(rhs, rhs_dn)
-        elif scaling_mode.is_tensor_scaling():
-            lhs_3d = _shape_normalization(lhs.data, lhs_dn, lhs.data_layout == "N")
-            rhs_3d = _shape_normalization(rhs.data, rhs_dn, rhs.data_layout == "T")
-        elif scaling_mode == ScalingMode.MXFP8_1D_SCALING:
-            lhs_3d = _shape_normalization(lhs.data, lhs_dn)
-            rhs_3d = _shape_normalization(rhs.data, rhs_dn)
-            lhs_scale_inv = _shape_normalization(lhs.scale_inv, lhs_dn)
-            rhs_scale_inv = _shape_normalization(rhs.scale_inv, rhs_dn)
-            # swizzled_scale requires a matrix
-            lhs_scale_inv = swizzled_scale(lhs_scale_inv.squeeze())
-            rhs_scale_inv = swizzled_scale(rhs_scale_inv.squeeze())
-        else:
-            raise NotImplementedError("Unsupported ScalingMode: {scaling_mode}")
-
-        # Note: already_transposed doesn't matter for the output shape
-        # x.shape = [B, D1, D2]
-        #   contracting_dims = (2,)   --> output.shape = [1, B * D1, D2]
-        #   contracting_dims = (0, 1) --> output.shape = [1, D2, B * D1]
-        # x.shape = [D1, D2]
-        #   contracting_dims = (1,)   --> output.shape = [1, D1, D2]
-        #   contracting_dims = (0,)   --> output.shape = [1, D2, D1]
-        bm = lhs_remain_shape[0]
-        bn = rhs_remain_shape[0]
-        kl = lhs_3d.shape[-1]
-        kr = rhs_3d.shape[-1]
-        assert kl == kr, f"After shape normalization, contracting dim size mismatch: {kl} != {kr}"
-        if (bm % 16 != 0) or (bn % 16 != 0) or (kl % 16 != 0):
-            print("grouped_gemm input pair {i} has invalid problem shape for lowering: ")
-            print(f"m = {bm}, n = {bn}, k = {kl}; ")
-            print("cuBLAS requires the problem shapes being multiples of 16")
-        assert (bm % 16 == 0) and (bn % 16 == 0) and (kl % 16 == 0)
-
-        lhs_list_.append(lhs_3d)
-        rhs_list_.append(rhs_3d)
-        if scaling_mode == ScalingMode.NO_SCALING:
-            lhs_sinv_list_.append(jnp.ones(1, dtype=jnp.float32))
-            rhs_sinv_list_.append(jnp.ones(1, dtype=jnp.float32))
-        if scaling_mode.is_tensor_scaling():
-            lhs_sinv_list_.append(lhs.scale_inv)
-            rhs_sinv_list_.append(rhs.scale_inv)
-        if scaling_mode == ScalingMode.MXFP8_1D_SCALING:
-            lhs_sinv_list_.append(lhs_scale_inv)
-            rhs_sinv_list_.append(rhs_scale_inv)
-        if bias_list is not None:
-            bias_list_.append(bias_list[i])
-
-    out_list = GroupedGemmPrimitive.outer_primitive.bind(
-        *lhs_list_,
-        *rhs_list_,
-        *lhs_sinv_list_,
-        *rhs_sinv_list_,
-        *bias_list_,
-        num_gemms=num_gemms,
-        scaling_mode=scaling_mode,
-        out_dtype=out_dtype,
-        has_bias=1 if bias_list is not None else 0,
-    )
-
-    return out_list
-"""
+def grouped_gemm(
+    lhs: Union[jnp.ndarray, GroupedScaledTensor1x],
+    rhs: Union[jnp.ndarray, GroupedScaledTensor1x],
+    group_sizes: jnp.ndarray,
+    contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (2,)),
+    bias: jnp.ndarray = None,
+    precision: jax.lax.Precision = jax.lax.Precision.DEFAULT,
+    preferred_element_type: jnp.dtype = None,
+    group_offset: jnp.array = None,
+    quantizer_set: QuantizerSet = noop_quantizer_set,
+) -> jnp.ndarray:
+    """
+    Grouped GEMM operation.
+
+    Args:
+        lhs: Left-hand-side input matrix, can be a jnp.ndarray or GroupedScaledTensor1x
+        rhs: Right-hand-side input matrix, can be a jnp.ndarray or GroupedScaledTensor1x
+        group_sizes: 1D array containing the sizes of each group
+        contracting_dims: Tuple of two sequences representing the contracting dimensions
+        bias: Bias tensor of shape (G, N)
+        precision: JAX precision for the GEMM operation
+        preferred_element_type: Preferred data type for the output tensor
+        group_offset: 1D array containing offsets for each group (not yet implemented)
+        quantizer_set: Set of quantizers for FP8 quantization of the input and output
+
+    Returns:
+        A jnp.ndarray containing the result of the grouped GEMM operation
+
+    Note:
+        Tested shapes:
+            lhs: [M, K] or [K, M]
+            rhs: [G, N, K] or [G, K, N] or [G * K, N] or [N, G * K]
+    """
+    # TODO(Phuong): implement the group_offset
+    group_offset = group_offset or jnp.zeros((1,), jnp.int32)
+
+    # TODO(Phuong): implement the precision
+    del precision
+
+    if isinstance(lhs, jnp.ndarray):
+        assert isinstance(rhs, jnp.ndarray)
+        out_dtype = lhs.dtype
+        lhs_shape = lhs.shape
+        rhs_shape = rhs.shape
+        lhs_data = lhs
+        rhs_data = rhs
+        lhs_scale_inv = rhs_scale_inv = jnp.empty((0,), jnp.float32)
+        scaling_mode = ScalingMode.NO_SCALING
+    elif isinstance(lhs, GroupedScaledTensor1x):
+        assert isinstance(rhs, GroupedScaledTensor1x)
+        out_dtype = lhs.dq_dtype
+        lhs_shape = lhs.original_shape
+        rhs_shape = rhs.original_shape
+        lhs_data = lhs.data
+        rhs_data = rhs.data
+        lhs_scale_inv = lhs.scale_inv
+        rhs_scale_inv = rhs.scale_inv
+        assert lhs.scaling_mode == rhs.scaling_mode
+        scaling_mode = lhs.scaling_mode
+    else:
+        raise TypeError("Unsupported lhs type object!")
+
+    out_dtype = preferred_element_type or out_dtype
+
+    lhs_contract_dim, rhs_contract_dim = contracting_dims
+    lhs_is_trans = lhs_contract_dim[-1] != len(lhs_shape) - 1
+    lhs_flatten_axis = len(lhs_contract_dim) * (1 if lhs_is_trans else -1)
+
+    # rhs_shape [G, K, N]
+    rhs_is_trans = rhs_contract_dim[0] != 1
+    rhs_flatten_axis = -len(rhs_contract_dim) if rhs_is_trans else 1 + len(rhs_contract_dim)
+
+    is_grouped_dense_wgrad = False
+    if len(rhs_shape) == 2:
+        rhs_is_trans = rhs_contract_dim[0] != 0
+        is_grouped_dense_wgrad = True
+
+    # TODO(Hua): these are for fp16 dense wgrad, any better way to handle this?
+    if (
+        is_grouped_dense_wgrad
+        and not isinstance(lhs, ScaledTensor)
+        and not isinstance(rhs, ScaledTensor)
+    ):
+        lhs_is_trans = True
+        rhs_is_trans = False
+        lhs_flatten_axis = 1
+        rhs_flatten_axis = 1
+
+    if (
+        not isinstance(lhs, ScaledTensor)
+        and not isinstance(rhs, ScaledTensor)
+        and quantizer_set != noop_quantizer_set
+    ):
+        assert isinstance(quantizer_set.x, GroupedQuantizer)
+        assert type(quantizer_set.x) is type(quantizer_set.kernel)
+        scaling_mode = quantizer_set.x.scaling_mode
+        if (
+            # TODO(Phuong): we force Blackwell to also use NT layout for now, need to fix later
+            # scaling_mode.is_tensor_scaling()
+            # and is_gemm_with_all_layouts_supported()
+            scaling_mode.is_1d_block_scaling()
+        ):
+            lhs_is_rowwise = rhs_is_rowwise = True
+        else:
+            lhs_is_rowwise = not lhs_is_trans
+            rhs_is_rowwise = lhs_is_trans
+        quantizer_set.x.q_layout = (
+            QuantizeLayout.ROWWISE if lhs_is_rowwise else QuantizeLayout.COLWISE
+        )
+        quantizer_set.kernel.q_layout = (
+            QuantizeLayout.ROWWISE if rhs_is_rowwise else QuantizeLayout.COLWISE
+        )
+        lhs_q = grouped_quantize(lhs, quantizer_set.x, group_sizes, lhs_flatten_axis)
+        rhs_q = grouped_quantize(
+            rhs, quantizer_set.kernel, group_sizes=None, flatten_axis=rhs_flatten_axis
+        )
+        lhs_data = lhs_q.data
+        rhs_data = rhs_q.data
+        lhs_scale_inv = lhs_q.scale_inv
+        rhs_scale_inv = rhs_q.scale_inv
+
+    assert not (
+        lhs_data.dtype == jnp.float8_e5m2 and rhs_data.dtype == jnp.float8_e5m2
+    ), "FP8 GEMM does not support E5M2 * E5M2"
+
+    # Only FP8 GEMM with NT layout is supported on Hopper and earlier GPUs,
+    # thus an additional transpose is required.
+    # TODO(Phuong): we force Blackwell to also use NT layout for now, need to fix later
+    # (the condition should become: scaling_mode.is_tensor_scaling()
+    #  and not is_gemm_with_all_layouts_supported())
+    if scaling_mode.is_tensor_scaling():
+        lhs_is_trans = False
+        rhs_is_trans = True
+        if isinstance(lhs, ScaledTensor) and isinstance(rhs, ScaledTensor):
+            lhs_layout_is_T = lhs.data_layout == "T"
+            rhs_layout_is_T = rhs.data_layout == "T"
+        else:
+            lhs_layout_is_T = lhs_q.data_layout == "T"
+            rhs_layout_is_T = rhs_q.data_layout == "T"
+        lhs_ndim = len(lhs_shape)
+        rhs_ndim = len(rhs_shape)
+        if lhs_layout_is_T:
+            lhs_contract_dim = tuple((lhs_ndim - 1 - i) % lhs_ndim for i in lhs_contract_dim)
+        if rhs_layout_is_T:
+            rhs_contract_dim = tuple((rhs_ndim - 1 - i) % rhs_ndim for i in rhs_contract_dim)
+        lhs_data = _shape_normalization(lhs_data, (lhs_contract_dim, ()), not lhs_layout_is_T)
+        rhs_data = _shape_normalization(rhs_data, (rhs_contract_dim, ()), rhs_layout_is_T)
+
+    # Calling the GroupedGEMM custom call
+    K_lhs = math.prod(lhs_shape[i] for i in lhs_contract_dim)
+    K_rhs = math.prod(rhs_shape[i] for i in rhs_contract_dim)
+    assert K_lhs == K_rhs
+    M = math.prod(_calculate_remaining_shape(lhs_shape, lhs_contract_dim))
+    N = math.prod(_calculate_remaining_shape(rhs_shape, rhs_contract_dim)[1:])  # Exclude G
+
+    if is_grouped_dense_wgrad:
+        N = math.prod(_calculate_remaining_shape(rhs_shape, rhs_contract_dim))
+    else:
+        assert group_sizes.size == rhs_shape[0]
+
+    assert group_offset.size == 1
+
+    has_bias = bias is not None
+    assert not has_bias or bias.shape == (group_sizes.size, N)
+    bias = jnp.empty((), jnp.float32) if bias is None else bias
+
+    # TODO(Phuong): support MXFP8_1D_SCALING
+    assert scaling_mode != ScalingMode.MXFP8_1D_SCALING, "MXFP8_1D_SCALING is not yet supported"
+
+    (out,) = GroupedGemmPrimitive.outer_primitive.bind(
+        lhs_data,
+        lhs_scale_inv,
+        rhs_data,
+        rhs_scale_inv,
+        bias,
+        group_sizes,
+        group_offset,
+        M=M,
+        N=N,
+        K=K_lhs,
+        lhs_is_trans=lhs_is_trans,
+        rhs_is_trans=rhs_is_trans,
+        scaling_mode=scaling_mode.value,
+        out_dtype=out_dtype,
+        has_bias=has_bias,
+        is_grouped_dense_wgrad=is_grouped_dense_wgrad,
+    )
+    return out
@@ -47,7 +47,7 @@ else:
     from jax.extend import ffi  # pylint: disable=ungrouped-imports


-__all__ = ["quantize", "quantize_dbias", "grouped_quantize"]
+__all__ = ["quantize", "quantize_dbias", "grouped_quantize", "grouped_dbias"]


 class BaseDBiasQuantizePrimitive(BasePrimitive):

@@ -1032,3 +1032,24 @@ def grouped_quantize(
         group_axis=group_axis,
     )
     return out
+
+
+def grouped_dbias(grad: jnp.ndarray, group_sizes: jnp.ndarray) -> jnp.ndarray:
+    """
+    Compute the grouped bias gradient.
+
+    Args:
+        grad: jnp.ndarray of shape (M, N)
+        group_sizes: jnp.ndarray of shape (num_groups,), with sum(group_sizes) == M
+
+    Returns:
+        dbias: jnp.ndarray of shape (num_groups, N)
+    """
+    assert grad.ndim == 2, "Input grad must be a 2D tensor."
+    assert group_sizes.ndim == 1, "group_sizes must be a 1D tensor."
+    segment_ids = jnp.repeat(jnp.arange(group_sizes.shape[0]), group_sizes)
+    grad_fp32 = grad.astype(jnp.float32)
+    dbias_fp32 = jax.ops.segment_sum(grad_fp32, segment_ids, num_segments=group_sizes.shape[0])
+    dbias = dbias_fp32.astype(grad.dtype)
+    return dbias
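
A small worked example of the `grouped_dbias` semantics: rows of `grad` are segment-summed according to `group_sizes` (values chosen by hand; the import path is assumed from the file shown above):

import jax.numpy as jnp
from transformer_engine.jax.cpp_extensions.quantization import grouped_dbias

grad = jnp.arange(8.0, dtype=jnp.float32).reshape(4, 2)  # rows: [0,1],[2,3],[4,5],[6,7]
group_sizes = jnp.array([1, 3])                          # sums to M = 4

dbias = grouped_dbias(grad, group_sizes)
# Group 0 sums row 0, group 1 sums rows 1..3:
# [[0, 1], [12, 15]]
assert dbias.shape == (2, 2)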
...@@ -13,43 +13,127 @@ ...@@ -13,43 +13,127 @@
#include "transformer_engine/multi_stream.h" #include "transformer_engine/multi_stream.h"
#include "xla/ffi/api/c_api.h" #include "xla/ffi/api/c_api.h"
#define MXFP8_BLOCK_SIZE 32
namespace transformer_engine { namespace transformer_engine {
namespace jax { namespace jax {
Error_Type GroupedGemmFFI(cudaStream_t stream, Variadic_Buffer_Type input_list, Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv,
Variadic_Result_Type output_list, int64_t num_gemms, Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias,
JAXX_Scaling_Mode scaling_mode, int64_t has_bias) { Buffer_Type group_sizes, Buffer_Type group_offset, Result_Type output,
Result_Type workspace, size_t m, size_t n, size_t k, bool lhs_is_trans,
bool rhs_is_trans, JAXX_Scaling_Mode scaling_mode, bool has_bias,
bool is_grouped_dense_wgrad) {
// Notes on matrix layouts and transpose: // Notes on matrix layouts and transpose:
// Jax uses row-major data_layout, on entering this function, each input matrix pair: // Jax uses row-major data_layout, on entering this function, each input matrix pair:
// A: row-major with size [m, k], // A: row-major [m, k] for N - [k, m] for T
// B: row-major with size [n, k], needs transpose, // B: row-major [k, n] for N - [n, k] for T
// on exiting this function, JAX expect: // on exiting this function, JAX expect:
// C: row-major with size [m, n]. // C: row-major with size [m, n].
// cuBLAS uses column-major data_layout, in this view, each input matrix pair: // cuBLAS uses column-major data_layout, in this view, each input matrix pair:
// A: column-major with size [k, m], needs transpose, // A: column-major with size [k, m] for T - [m, k] for N
// B: column-major with size [k, n]. // B: column-major with size [n, k] for T - [k, n] for N
//
// If we call cuBLAS GEMM for A * B, the output will be: // If we call cuBLAS GEMM for A * B, the output will be:
// C: column-major with size [m, n] --> row-major with size [n, m]. // C: column-major with size [m, n] --> row-major with size [n, m].
// To make the output compatible with JAX, we need to swap A and B in cuBLAS GEMM call. // To make the output compatible with JAX, we need to swap A and B in cuBLAS GEMM call.
if (num_gemms <= 0) { int num_streams = nvte_get_num_compute_streams();
return ffi_with_cuda_error_check();
// Inputs
auto lhs_ptr = reinterpret_cast<uint8_t *>(lhs_data.untyped_data());
auto rhs_ptr = reinterpret_cast<uint8_t *>(rhs_data.untyped_data());
auto lhs_sinv_ptr = reinterpret_cast<uint8_t *>(lhs_sinv.untyped_data());
auto rhs_sinv_ptr = reinterpret_cast<uint8_t *>(rhs_sinv.untyped_data());
auto lhs_dtype = convert_ffi_datatype_to_te_dtype(lhs_data.element_type());
auto rhs_dtype = convert_ffi_datatype_to_te_dtype(rhs_data.element_type());
auto lhs_sinv_dtype = convert_ffi_datatype_to_te_dtype(lhs_sinv.element_type());
auto rhs_sinv_dtype = convert_ffi_datatype_to_te_dtype(rhs_sinv.element_type());
auto bias_ptr = has_bias ? reinterpret_cast<uint8_t *>(bias.untyped_data()) : nullptr;
auto bias_dtype = convert_ffi_datatype_to_te_dtype(bias.element_type());
NVTE_CHECK(group_sizes.dimensions().size() == 1);
size_t num_gemms = group_sizes.dimensions()[0];
// Outputs
auto out_ptr = reinterpret_cast<uint8_t *>(output->untyped_data());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output->element_type());
auto workspace_ptr = reinterpret_cast<uint8_t *>(workspace->untyped_data());
auto workspace_total_size = product(workspace->dimensions());
auto lhs_sinv_size = product(lhs_sinv.dimensions());
auto rhs_sinv_size = product(rhs_sinv.dimensions());
auto workspace_size = (workspace_total_size - lhs_sinv_size - rhs_sinv_size) / num_streams;
auto swizzled_lhs_sinv_ptr = workspace_ptr + workspace_size * num_streams;
auto swizzled_rhs_sinv_ptr = swizzled_lhs_sinv_ptr + lhs_sinv_size;
size_t lhs_dtype_bytes = te_dtype_bytes(lhs_dtype);
size_t rhs_dtype_bytes = te_dtype_bytes(rhs_dtype);
size_t lhs_sinv_dtype_bytes = te_dtype_bytes(lhs_sinv_dtype);
size_t rhs_sinv_dtype_bytes = te_dtype_bytes(rhs_sinv_dtype);
size_t bias_dtype_bytes = te_dtype_bytes(bias_dtype);
size_t out_dtype_bytes = te_dtype_bytes(out_dtype);
NVTE_CHECK(lhs_dtype_bytes == rhs_dtype_bytes, "sizeof(lhs_dtype) != sizeof(rhs_dtype)");
NVTE_CHECK(lhs_sinv_dtype_bytes == rhs_sinv_dtype_bytes,
"sizeof(lhs_sinv_dtype) != sizeof(rhs_sinv_dtype)");
size_t expected_lhs_size = m * k;
size_t expected_rhs_size = is_grouped_dense_wgrad ? (k * n) : (num_gemms * k * n);
size_t expected_out_size = is_grouped_dense_wgrad ? (num_gemms * m * n) : (m * n);
size_t actual_lhs_size = product(lhs_data.dimensions());
size_t actual_rhs_size = product(rhs_data.dimensions());
size_t actual_out_size = product(output->dimensions());
NVTE_CHECK(expected_lhs_size == actual_lhs_size, "Unexpected lhs size! Expect ",
expected_lhs_size, ", got ", actual_lhs_size);
if (!is_grouped_dense_wgrad) {
NVTE_CHECK(expected_rhs_size == actual_rhs_size,
"Unexpected rhs size! Expect num_gemms * n * k = ", num_gemms, " * ", n, " * ", k,
" = ", expected_rhs_size, ", got ", actual_rhs_size);
NVTE_CHECK(expected_out_size == actual_out_size, "Unexpected output size! Expect m * n = ", m,
" * ", n, " = ", expected_out_size, ", got ", actual_out_size);
} else {
NVTE_CHECK(expected_rhs_size == actual_rhs_size, "Unexpected rhs size! Expect k * n = ", k,
" * ", n, " = ", expected_rhs_size, ", got ", actual_rhs_size);
NVTE_CHECK(expected_out_size == actual_out_size,
"Unexpected output size! Expect num_gemms * m * n = ", num_gemms, " * ", m, " * ", n,
" = ", expected_out_size, ", got ", actual_out_size);
}
size_t dim_list_bytes = sizeof(int32_t) * num_gemms;
std::vector<int32_t> dim_list_host(num_gemms);
auto dim_list_ptr = reinterpret_cast<int32_t *>(group_sizes.untyped_data());
cudaMemcpyAsync(dim_list_host.data(), dim_list_ptr, dim_list_bytes, cudaMemcpyDeviceToHost,
stream);
// Note: This may break cudaGraph.
cudaStreamSynchronize(stream);
size_t sum_group_sizes = std::accumulate(dim_list_host.begin(), dim_list_host.end(), 0);
if (!is_grouped_dense_wgrad) {
NVTE_CHECK(m == sum_group_sizes, "Unexpected group_sizes! M = ", m,
", got sum(group_sizes)=", sum_group_sizes);
} else {
NVTE_CHECK(k == sum_group_sizes, "Unexpected group_sizes! K = ", k,
", got sum(group_sizes)=", sum_group_sizes);
} }
size_t expected_input_size = has_bias ? 5 * num_gemms : 4 * num_gemms;
size_t expected_output_size = num_gemms + 1;
size_t actual_input_size = input_list.size();
size_t actual_output_size = output_list.size();
NVTE_CHECK(actual_input_size == expected_input_size, "Expected %zu input tensors, got %zu",
expected_input_size, actual_input_size);
NVTE_CHECK(actual_output_size == expected_output_size, "Expected %zu output tensors, got %zu",
expected_output_size, actual_output_size);
bool trans_lhs = true;
bool trans_rhs = false;
auto num_math_sm = cuda::sm_count() - getenv<int>("NVTE_EXT_MARGIN_SM", 0); auto num_math_sm = cuda::sm_count() - getenv<int>("NVTE_EXT_MARGIN_SM", 0);
bool grad = false; bool grad = false;
bool accumulate = false; bool accumulate = false;
bool use_split_accumulator = false; bool use_split_accumulator = false;
auto bias_shape = std::vector<size_t>{has_bias ? n : 0};
const int arch = cuda::sm_arch();
// It is weird that TE/Common GEMM only use colwise for MXFP8
const bool is_fp8_gemm = is_fp8_dtype(lhs_dtype);
const bool is_mxfp8_scaling = scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING;
const bool rhs_use_colwise = is_mxfp8_scaling && !rhs_is_trans;
const bool lhs_use_colwise = is_mxfp8_scaling && lhs_is_trans;
if (arch < 100 && is_fp8_gemm) {
NVTE_CHECK(!lhs_is_trans && rhs_is_trans,
"For SM90 or older archs and FP8 input, only NT (row-major) GEMM is supported, ",
"got lhs_is_trans=", lhs_is_trans, ", rhs_is_trans=", rhs_is_trans);
}
// These lists are to keep the TensorWrapper objects alive // These lists are to keep the TensorWrapper objects alive
std::vector<TensorWrapper> lhs_wrapper_list; std::vector<TensorWrapper> lhs_wrapper_list;
...@@ -67,96 +151,83 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Variadic_Buffer_Type input_list, ...@@ -67,96 +151,83 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Variadic_Buffer_Type input_list,
std::vector<NVTETensor> out_list; std::vector<NVTETensor> out_list;
std::vector<NVTETensor> workspace_list; std::vector<NVTETensor> workspace_list;
int lhs_list_offset = 0; for (size_t i = 0; i < num_gemms; i++) {
int rhs_list_offset = num_gemms; // Matrix data shapes
int lhs_sinv_list_offset = 2 * num_gemms; size_t m_i = dim_list_host[i];
int rhs_sinv_list_offset = 3 * num_gemms; auto lhs_shape = std::vector<size_t>{m_i, k};
int bias_list_offset = 4 * num_gemms; auto rhs_shape = std::vector<size_t>{rhs_is_trans ? n : k, rhs_is_trans ? k : n};
int out_list_offset = 0; auto out_shape = std::vector<size_t>{m_i, n};
for (int i = 0; i < num_gemms; i++) { if (is_grouped_dense_wgrad) {
Buffer_Type lhs_i = input_list.get<Buffer_Type>(lhs_list_offset + i).value(); size_t k_i = dim_list_host[i];
Buffer_Type rhs_i = input_list.get<Buffer_Type>(rhs_list_offset + i).value(); lhs_shape[0] = lhs_is_trans ? k_i : m;
Buffer_Type lhs_sinv_i = input_list.get<Buffer_Type>(lhs_sinv_list_offset + i).value(); lhs_shape[1] = lhs_is_trans ? m : k_i;
Buffer_Type rhs_sinv_i = input_list.get<Buffer_Type>(rhs_sinv_list_offset + i).value(); rhs_shape[0] = rhs_is_trans ? n : k_i;
Result_Type out_i = output_list.get<Buffer_Type>(out_list_offset + i).value(); rhs_shape[1] = rhs_is_trans ? k_i : n;
out_shape[0] = m;
DType lhs_dtype = convert_ffi_datatype_to_te_dtype(lhs_i.element_type()); out_shape[1] = n;
DType rhs_dtype = convert_ffi_datatype_to_te_dtype(rhs_i.element_type());
DType out_dtype = convert_ffi_datatype_to_te_dtype(out_i->element_type());
void *lhs_ptr = lhs_i.untyped_data();
void *rhs_ptr = rhs_i.untyped_data();
void *lhs_sinv_ptr = lhs_sinv_i.untyped_data();
void *rhs_sinv_ptr = rhs_sinv_i.untyped_data();
void *out_ptr = out_i->untyped_data();
// Placeholder for bias since it can be empty
DType bias_dtype = DType::kFloat32;
void *bias_ptr = nullptr;
auto lhs_shape_ = lhs_i.dimensions();
auto rhs_shape_ = rhs_i.dimensions();
// lhs and rhs has shape [1, m, k] and [1, n, k]
size_t m = lhs_shape_[1];
size_t n = rhs_shape_[1];
size_t k = lhs_shape_[2];
auto lhs_shape = std::vector<size_t>{m, k};
auto rhs_shape = std::vector<size_t>{n, k};
auto out_shape = std::vector<size_t>{n, m};
auto lhs_sinv_shape = std::vector<size_t>{1, 1};
auto rhs_sinv_shape = std::vector<size_t>{1, 1};
if (scaling_mode == JAXX_Scaling_Mode::NO_SCALING ||
scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ||
scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING) {
float *amax_dptr = nullptr;
float *scale_dptr = nullptr;
auto lhs_i_ = TensorWrapper(lhs_ptr, lhs_shape, lhs_dtype, amax_dptr, scale_dptr,
reinterpret_cast<float *>(lhs_sinv_ptr));
auto rhs_i_ = TensorWrapper(rhs_ptr, rhs_shape, rhs_dtype, amax_dptr, scale_dptr,
reinterpret_cast<float *>(rhs_sinv_ptr));
lhs_wrapper_list.push_back(std::move(lhs_i_));
rhs_wrapper_list.push_back(std::move(rhs_i_));
} else if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) {
// Note: the scale_inv array should have been swizzled in Python before lowering
auto lhs_sinv_shape_ = lhs_sinv_i.dimensions();
auto rhs_sinv_shape_ = rhs_sinv_i.dimensions();
for (int i = 0; i < 2; i++) {
lhs_sinv_shape[i] = lhs_sinv_shape_[i];
rhs_sinv_shape[i] = rhs_sinv_shape_[i];
} }
NVTEScalingMode nvte_scaling_mode = get_nvte_scaling_mode(scaling_mode); // Set matrix data pointers
TensorWrapper lhs_i_(nvte_scaling_mode); auto lhs_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
TensorWrapper rhs_i_(nvte_scaling_mode); auto rhs_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
lhs_i_.set_rowwise_data(lhs_ptr, lhs_dtype, lhs_shape); auto out_i = TensorWrapper(static_cast<void *>(out_ptr), out_shape, out_dtype);
rhs_i_.set_rowwise_data(rhs_ptr, rhs_dtype, rhs_shape); void *lhs_vptr = static_cast<void *>(lhs_ptr);
lhs_i_.set_rowwise_scale_inv(lhs_sinv_ptr, DType::kFloat8E8M0, lhs_sinv_shape); void *rhs_vptr = static_cast<void *>(rhs_ptr);
rhs_i_.set_rowwise_scale_inv(rhs_sinv_ptr, DType::kFloat8E8M0, rhs_sinv_shape); if (rhs_use_colwise) // MatA to enter cuBLAS
rhs_i.set_columnwise_data(rhs_vptr, rhs_dtype, rhs_shape);
else
rhs_i.set_rowwise_data(rhs_vptr, rhs_dtype, rhs_shape);
if (lhs_use_colwise) // MatB to enter cuBLAS
lhs_i.set_columnwise_data(lhs_vptr, lhs_dtype, lhs_shape);
else
lhs_i.set_rowwise_data(lhs_vptr, lhs_dtype, lhs_shape);
lhs_wrapper_list.push_back(std::move(lhs_i_)); // Scale_inv shapes
rhs_wrapper_list.push_back(std::move(rhs_i_)); auto lhs_sinv_size = std::vector<size_t>{1};
} else { auto rhs_sinv_size = std::vector<size_t>{1};
NVTE_ERROR("Unsupported scaling mode: ", static_cast<int>(scaling_mode)); if (is_mxfp8_scaling) {
NVTE_CHECK(k % MXFP8_BLOCK_SIZE == 0, "MXFP8 K-dim being divisble by %d (got %d)",
MXFP8_BLOCK_SIZE, k);
size_t scale_k = k / MXFP8_BLOCK_SIZE;
lhs_sinv_size[0] = m_i * scale_k;
rhs_sinv_size[0] = n * scale_k;
// Need to add swizzle here
} }
auto out_i_ = TensorWrapper(out_ptr, out_shape, out_dtype); // Set scale_inv pointers
void *pre_gelu_ptr = nullptr; void *rhs_sinv_vptr = static_cast<void *>(rhs_sinv_ptr);
auto bias_shape = std::vector<size_t>{0}; void *lhs_sinv_vptr = static_cast<void *>(lhs_sinv_ptr);
auto pre_gelu_shape = std::vector<size_t>{0}; if (is_fp8_gemm) {
if (has_bias) { if (rhs_use_colwise) // MatA to enter cuBLAS
auto bias_i_get = input_list.get<Buffer_Type>(bias_list_offset + i); rhs_i.set_columnwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_size);
Buffer_Type bias_i = bias_i_get.value(); else
bias_ptr = bias_i.untyped_data(); rhs_i.set_rowwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_size);
bias_dtype = convert_ffi_datatype_to_te_dtype(bias_i.element_type()); if (lhs_use_colwise) // MatB to enter cuBLAS
bias_shape[0] = n; lhs_i.set_columnwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_size);
else
lhs_i.set_rowwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_size);
} else {
NVTE_CHECK(scaling_mode == JAXX_Scaling_Mode::NO_SCALING,
"Unsupported scaling mode: ", static_cast<int>(scaling_mode));
} }
auto bias_i = TensorWrapper(bias_ptr, bias_shape, bias_dtype); auto bias_i = TensorWrapper(bias_ptr, bias_shape, bias_dtype);
auto pre_gelu_i = TensorWrapper(pre_gelu_ptr, pre_gelu_shape, out_dtype); auto pre_gelu_i = TensorWrapper(nullptr, std::vector<size_t>{0}, out_dtype);
// Update pointer for the next GEMM pair
lhs_ptr += lhs_shape[0] * lhs_shape[1] * lhs_dtype_bytes;
rhs_ptr += rhs_shape[0] * rhs_shape[1] * rhs_dtype_bytes;
out_ptr += out_shape[0] * out_shape[1] * out_dtype_bytes;
if (is_fp8_gemm) {
lhs_sinv_ptr += lhs_sinv_size[0] * lhs_sinv_dtype_bytes;
rhs_sinv_ptr += rhs_sinv_size[0] * rhs_sinv_dtype_bytes;
}
if (has_bias) bias_ptr += n * bias_dtype_bytes;
out_wrapper_list.push_back(std::move(out_i_)); // Move objects to the lists to keep them alive
lhs_wrapper_list.push_back(std::move(lhs_i));
rhs_wrapper_list.push_back(std::move(rhs_i));
out_wrapper_list.push_back(std::move(out_i));
bias_wrapper_list.push_back(std::move(bias_i)); bias_wrapper_list.push_back(std::move(bias_i));
pre_gelu_wrapper_list.push_back(std::move(pre_gelu_i)); pre_gelu_wrapper_list.push_back(std::move(pre_gelu_i));
...@@ -167,11 +238,6 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Variadic_Buffer_Type input_list, ...@@ -167,11 +238,6 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Variadic_Buffer_Type input_list,
out_list.push_back(out_wrapper_list.back().data()); out_list.push_back(out_wrapper_list.back().data());
} }
auto workspace_get = output_list.get<Buffer_Type>(num_gemms);
Result_Type workspace = workspace_get.value();
uint8_t *workspace_ptr = reinterpret_cast<uint8_t *>(workspace->untyped_data());
auto num_streams = nvte_get_num_compute_streams();
size_t workspace_size = workspace->dimensions()[0] / num_streams;
auto workspace_shape = std::vector<size_t>{workspace_size}; auto workspace_shape = std::vector<size_t>{workspace_size};
for (int i = 0; i < num_streams; i++) { for (int i = 0; i < num_streams; i++) {
auto workspace_i = auto workspace_i =
...@@ -182,7 +248,7 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Variadic_Buffer_Type input_list, ...@@ -182,7 +248,7 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Variadic_Buffer_Type input_list,
     }
     nvte_multi_stream_cublas_gemm(rhs_list.data(), lhs_list.data(), out_list.data(), bias_list.data(),
-                                  pre_gelu_list.data(), num_gemms, trans_lhs, trans_rhs, grad,
+                                  pre_gelu_list.data(), num_gemms, rhs_is_trans, lhs_is_trans, grad,
                                   workspace_list.data(), accumulate, use_split_accumulator,
                                   num_math_sm, stream);
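Why `rhs_list` is passed where cuBLAS expects MatA (see the MatA/MatB comments earlier): cuBLAS is column-major, so a row-major product C = A B is computed as the column-major product B^T A^T, whose column-major result is exactly the row-major C. A quick JAX check of that identity:

import jax
import jax.numpy as jnp

a = jax.random.normal(jax.random.PRNGKey(0), (4, 3))
b = jax.random.normal(jax.random.PRNGKey(1), (3, 5))
assert jnp.allclose(a @ b, (b.T @ a.T).T, atol=1e-6)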
@@ -192,11 +258,23 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Variadic_Buffer_Type input_list,
 XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI,
                               FFI::Bind()
                                   .Ctx<FFI_Stream_Type>()  // stream
-                                  .RemainingArgs()         // input list
-                                  .RemainingRets()         // output list
-                                  .Attr<int64_t>("num_gemms")
+                                  .Arg<Buffer_Type>()  // lhs_data
+                                  .Arg<Buffer_Type>()  // lhs_sinv
+                                  .Arg<Buffer_Type>()  // rhs_data
+                                  .Arg<Buffer_Type>()  // rhs_sinv
+                                  .Arg<Buffer_Type>()  // bias
+                                  .Arg<Buffer_Type>()  // group_sizes
+                                  .Arg<Buffer_Type>()  // group_offset
+                                  .Ret<Buffer_Type>()  // output
+                                  .Ret<Buffer_Type>()  // workspace
+                                  .Attr<int64_t>("M")
+                                  .Attr<int64_t>("N")
+                                  .Attr<int64_t>("K")
+                                  .Attr<bool>("lhs_is_trans")
+                                  .Attr<bool>("rhs_is_trans")
                                   .Attr<JAXX_Scaling_Mode>("scaling_mode")
-                                  .Attr<int64_t>("has_bias"),
+                                  .Attr<bool>("has_bias")
+                                  .Attr<bool>("is_grouped_dense_wgrad"),
                               FFI_CudaGraph_Traits);

 }  // namespace jax
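The binding now takes fixed buffers plus static M/N/K attributes instead of `RemainingArgs` and `num_gemms`; the shapes are known at compile time, which is what avoids dynamic shapes. A hypothetical checker for the contract implied by the new signature (the helper name and the no-transpose layout are assumptions, not part of the diff):

def check_grouped_gemm_args(lhs, rhs, group_sizes, out, M, N, K):
    # non-transposed layouts assumed: lhs (M, K), rhs (G, K, N), out (M, N)
    assert lhs.shape == (M, K)
    assert rhs.shape == (group_sizes.shape[0], K, N)
    assert out.shape == (M, N)
    assert int(group_sizes.sum()) == M  # groups partition the M rows of lhs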
@@ -153,28 +153,28 @@ def _dense_bwd_rule(
     # GEMM NT
     # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim
-    g_constracting_dim = tuple(
+    g_contracting_dim = tuple(
         range(grad.ndim - len(kernel_shape) + len(fwd_k_contracting_dims), grad.ndim)
     )
     # k_non_contracting_dims
-    k_constracting_dim = tuple(
+    k_contracting_dim = tuple(
         dim for dim in range(len(kernel_shape)) if dim not in fwd_k_contracting_dims
     )
     dgrad = tex.gemm(
         casted_grad.get_rowwise_tensor(),
         rowwise_casted_kernel,
-        (g_constracting_dim, k_constracting_dim),
+        (g_contracting_dim, k_contracting_dim),
     )
     dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes)

     # GEMM TN
     # x_non_contracting_dims
-    g_constracting_dim = x_constracting_dim = tuple(
+    g_contracting_dim = x_contracting_dim = tuple(
         range(0, len(x_shape) - len(fwd_x_contracting_dims))
     )
     wgrad = tex.gemm(
-        colwise_casted_x, casted_grad.get_colwise_tensor(), (x_constracting_dim, g_constracting_dim)
+        colwise_casted_x, casted_grad.get_colwise_tensor(), (x_contracting_dim, g_contracting_dim)
     )
     wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes)
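This hunk is a pure rename (`constracting` to `contracting`); the dimension arithmetic is unchanged. It can be sanity-checked against autodiff for a plain 2D case with forward contracting dims x:(1,), kernel:(0,), where the formulas give g_contracting_dim = (1,), k_contracting_dim = (1,) for dgrad and (0,), (0,) for wgrad:

import jax
import jax.numpy as jnp

x = jax.random.normal(jax.random.PRNGKey(0), (4, 3))
w = jax.random.normal(jax.random.PRNGKey(1), (3, 5))
grad = jnp.ones((4, 5))  # cotangent of x @ w

# GEMM NT: contract grad dim 1 with kernel dim 1 -> (4, 3)
dgrad = jax.lax.dot_general(grad, w, (((1,), (1,)), ((), ())))
# GEMM TN: contract x dim 0 with grad dim 0 -> (3, 5)
wgrad = jax.lax.dot_general(x, grad, (((0,), (0,)), ((), ())))

ref_dgrad, ref_wgrad = jax.grad(lambda x, w: jnp.sum(x @ w), argnums=(0, 1))(x, w)
assert jnp.allclose(dgrad, ref_dgrad) and jnp.allclose(wgrad, ref_wgrad)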
@@ -184,135 +184,240 @@ def _dense_bwd_rule(
 _dense.defvjp(_dense_fwd_rule, _dense_bwd_rule)
"""
def grouped_dense( def grouped_dense(
x_list, x: jnp.ndarray,
kernel_list, kernel: jnp.ndarray,
bias_list, group_sizes: jnp.ndarray,
contracting_dims_list, contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (1,)),
quantizer_set_list=None, bias: jnp.ndarray = None,
precision: jax.lax.Precision = jax.lax.Precision.DEFAULT,
preferred_element_type: jnp.dtype = None,
group_offset: jnp.array = None,
quantizer_set: QuantizerSet = noop_quantizer_set,
): ):
# Perform grouped_dense layer transformation with optional quantization. """
Perform grouped dense (linear) layer transformation with optional quantization.
output_list = _grouped_dense( Args:
x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list x: Input tensor of shape (M, K)
kernel: Weight matrix of shape (G, K, N)
group_sizes: 1D array of shape (G,) specifying the size of each group
contracting_dims: Tuple of sequences specifying which dimensions to contract
(currently only supports ((1,), (1,)))
bias: Bias tensor of shape (G, N)
precision: JAX precision for the GEMM operation
preferred_element_type: Preferred data type for the output tensor
group_offset: 1D array containing offsets for each group (not yet implemented)
quantizer_set: Set of quantizers for FP8 quantization of the input and output
Returns:
A jnp.ndarray containing the result of the grouped linear operation
"""
output = _grouped_dense(
x,
kernel,
group_sizes,
contracting_dims,
bias,
precision,
preferred_element_type,
group_offset,
quantizer_set,
) )
return output_list return output
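A minimal usage sketch of the new API (shapes follow the docstring; the values are illustrative). With the default `noop_quantizer_set` this exercises the unquantized path:

import jax
import jax.numpy as jnp
from transformer_engine.jax.dense import grouped_dense

G, M, K, N = 4, 64, 32, 16
x = jax.random.normal(jax.random.PRNGKey(0), (M, K), jnp.bfloat16)
kernel = jax.random.normal(jax.random.PRNGKey(1), (G, K, N), jnp.bfloat16)
bias = jnp.zeros((G, N), jnp.bfloat16)
group_sizes = jnp.array([16, 16, 24, 8])  # must sum to M

out = grouped_dense(x, kernel, group_sizes, bias=bias)  # -> (M, N)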
-@partial(jax.custom_vjp, nondiff_argnums=(3,))
-def _grouped_dense(x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list):
-    output_list, _ = _grouped_dense_fwd_rule(
-        x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list
-    )
-    return output_list
+@partial(jax.custom_vjp, nondiff_argnums=(3, 5, 6, 7))
+def _grouped_dense(
+    x,
+    kernel,
+    group_sizes,
+    contracting_dims,
+    bias,
+    precision,
+    preferred_element_type,
+    group_offset,
+    quantizer_set,
+):
+    output, _ = _grouped_dense_fwd_rule(
+        x,
+        kernel,
+        group_sizes,
+        contracting_dims,
+        bias,
+        precision,
+        preferred_element_type,
+        group_offset,
+        quantizer_set,
+    )
+    return output
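On `nondiff_argnums=(3, 5, 6, 7)`: `contracting_dims`, `precision`, `preferred_element_type`, and `group_offset` are treated as static, and JAX passes them to the bwd rule ahead of the residuals, which is why `_grouped_dense_bwd_rule` below takes them first. A minimal sketch of the pattern, independent of this file:

import jax
from functools import partial

@partial(jax.custom_vjp, nondiff_argnums=(1,))
def scale(x, factor):
    return x * factor

def scale_fwd(x, factor):
    return scale(x, factor), None  # no residuals needed

def scale_bwd(factor, residuals, g):
    # non-differentiable args arrive first, then residuals, then the cotangent;
    # gradients are returned only for the differentiable args
    return (g * factor,)

scale.defvjp(scale_fwd, scale_bwd)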
-def _grouped_dense_fwd_rule(
-    x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list
-):
-    use_bias = bias_list is not None
-    output_list = []
-    x_rowwise_list = []
-    x_colwise_list = []
-    kernel_colwise_list = []
-    kernel_rowwise_list = []
-    x_shape_list = []
-    kernel_shape_list = []
-    if quantizer_set_list is None:
-        x_rowwise_list = x_list
-        x_colwise_list = x_list
-        kernel_colwise_list = kernel_list
-        kernel_rowwise_list = kernel_list
-        x_shape_list = [x.shape for x in x_list]
-        kernel_shape_list = [kernel.shape for kernel in kernel_list]
-    else:
-        for i in range(len(x_list)):  # pylint: disable=consider-using-enumerate
-            q_x = tex.quantize(x_list[i], quantizer_set_list[i].x)
-            q_kernel = tex.quantize(kernel_list[i], quantizer_set_list[i].kernel)
-            x_rowwise_list.append(q_x.get_rowwise_tensor())
-            x_colwise_list.append(q_x.get_colwise_tensor())
-            kernel_colwise_list.append(q_kernel.get_colwise_tensor())
-            kernel_rowwise_list.append(q_kernel.get_rowwise_tensor())
-            x_shape_list.append(x_rowwise_list[-1].data.shape)
-            kernel_shape_list.append(kernel_rowwise_list[-1].data.shape)
-
-    output_list = tex.grouped_gemm(
-        x_rowwise_list, kernel_colwise_list, contracting_dims_list, bias_list
-    )
-
-    ctx = (
-        x_colwise_list,
-        kernel_rowwise_list,
-        x_shape_list,
-        kernel_shape_list,
-        use_bias,
-        quantizer_set_list,
-    )
-    return output_list, ctx
+def _grouped_dense_fwd_rule(
+    x,
+    kernel,
+    group_sizes,
+    contracting_dims,
+    bias,
+    precision,
+    preferred_element_type,
+    group_offset,
+    quantizer_set,
+):
+    use_bias = bias is not None
+    is_noop_quantizer_set = quantizer_set == noop_quantizer_set
+
+    if is_noop_quantizer_set:
+        grouped_gemm_x = x
+        grouped_gemm_kernel = kernel
+        ctx_x = x
+        ctx_kernel = kernel
+        flatten_axis_k = None
+    else:
+        x_contracting_dims, k_contracting_dims = contracting_dims
+        flatten_axis_x = -len(x_contracting_dims)
+        flatten_axis_k = len(k_contracting_dims) - len(kernel.shape) + 1  # +1 for G axis
+
+        assert x.ndim == 2, "Grouped dense expects a 2D input tensor of shape (M, K)"
+        assert kernel.ndim == 3, "Grouped dense expects a 3D kernel tensor of shape (G, K, N)"
+        # Expected k_contracting_dims == (1,), need to tweak it for grouped_gemm FP8 extra transpose
+        # TODO(Hua): Do we have a better way for this? What if is_gemm_with_all_layouts_supported()?
+        assert x_contracting_dims == (1,) and k_contracting_dims == (1,), (
+            "grouped_dense for FP8 can only handle x_contracting_dims=(1,) "
+            "and k_contracting_dims=(1,) for now, "
+            f"got {x_contracting_dims=} and {k_contracting_dims=}"
+        )
+        k_contracting_dims = (0,)
+
+        casted_x = tex.grouped_quantize(
+            x, quantizer_set.x, group_sizes, flatten_axis=flatten_axis_x
+        )
+        casted_kernel = tex.grouped_quantize(
+            kernel, quantizer_set.kernel, flatten_axis=flatten_axis_k
+        )
+        contracting_dims = (x_contracting_dims, k_contracting_dims)
+
+        # For x_contracting_dims == (1,) and k_contracting_dims == (1,), we should have
+        # rowwise_casted_x.original_shape == (M, K)
+        # colwise_casted_kernel.original_shape == (G, N, K)
+        grouped_gemm_x = casted_x.get_rowwise_tensor()
+        grouped_gemm_kernel = casted_kernel.get_colwise_tensor()
+        # TODO(Hua): Shall we give warning/error if not quantizer_set.x.is_2x2x()?
+        ctx_x = casted_x.get_colwise_tensor() if quantizer_set.x.is_2x2x() else None
+        ctx_kernel = casted_kernel.get_rowwise_tensor() if quantizer_set.kernel.is_2x2x() else None
+
+    output = tex.grouped_gemm(
+        grouped_gemm_x,
+        grouped_gemm_kernel,
+        group_sizes,
+        contracting_dims,
+        bias,
+        precision,
+        preferred_element_type,
+        group_offset,
+    )
+
+    ctx = (
+        group_sizes,
+        ctx_x,
+        ctx_kernel,
+        x.shape,
+        kernel.shape,
+        use_bias,
+        is_noop_quantizer_set,
+        quantizer_set,
+        flatten_axis_k,
+    )
+    return output, ctx
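For the unquantized path, the semantics match `jax.lax.ragged_dot` (assuming a JAX version that provides it): each group of rows of `x` is multiplied by its own `(K, N)` slice of `kernel`. A reference sketch, with the per-group bias applied by repeating bias rows over each group's rows:

import jax.numpy as jnp
from jax import lax

def ref_grouped_dense(x, kernel, group_sizes, bias=None):
    out = lax.ragged_dot(x, kernel, group_sizes)  # (M, K) x (G, K, N) -> (M, N)
    if bias is not None:
        out = out + jnp.repeat(
            bias, group_sizes, axis=0, total_repeat_length=x.shape[0]
        )
    return out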
-def _grouped_dense_bwd_rule(contracting_dims_list, ctx, grad_list):
-    (
-        colwise_x_list,
-        rowwise_kernel_list,
-        x_shape_list,
-        kernel_shape_list,
-        use_bias,
-        quantizer_set_list,
-    ) = ctx
-    group_size = len(grad_list)
-    dbias_list = []
-    grad_rowwise_list = []
-    grad_colwise_list = []
-    dgrad_contracting_dims_list = []
-    wgrad_contracting_dims_list = []
-    for i in range(group_size):
-        grad = grad_list[i]
-        x_shape = x_shape_list[i]
-        kernel_shape = kernel_shape_list[i]
-        fwd_contracting_dims = contracting_dims_list[i]
-        if quantizer_set_list is None:
-            casted_grad = grad
-            dbias = tex.quantization._jax_dbias(grad)
-            grad_rowwise_list.append(grad)
-            grad_colwise_list.append(grad)
-        else:
-            quantizer_set = quantizer_set_list[i]
-            casted_grad, dbias = tex.quantize_dbias(
-                grad, is_dbias=use_bias, quantizer=quantizer_set.dgrad
-            )
-            grad_rowwise_list.append(casted_grad.get_rowwise_tensor())
-            grad_colwise_list.append(casted_grad.get_colwise_tensor())
-        dbias_list.append(dbias)
-
-        # GEMM NT
-        fwd_x_contracting_dims, fwd_k_contracting_dims = fwd_contracting_dims
-        g_contracting_dim = tuple(
-            range(grad.ndim - len(kernel_shape) + len(fwd_k_contracting_dims), grad.ndim)
-        )
-        k_contracting_dim = tuple(
-            dim for dim in range(len(kernel_shape)) if dim not in fwd_k_contracting_dims
-        )
-        dgrad_contracting_dims = (g_contracting_dim, k_contracting_dim)
-        dgrad_contracting_dims_list.append(dgrad_contracting_dims)
-
-        # GEMM TN
-        g_contracting_dim = x_contracting_dim = tuple(
-            range(0, len(x_shape) - len(fwd_x_contracting_dims))
-        )
-        wgrad_contracting_dims = (x_contracting_dim, g_contracting_dim)
-        wgrad_contracting_dims_list.append(wgrad_contracting_dims)
-
-    dgrad_list = tex.grouped_gemm(
-        grad_rowwise_list, rowwise_kernel_list, dgrad_contracting_dims_list
-    )
-    wgrad_list = tex.grouped_gemm(colwise_x_list, grad_colwise_list, wgrad_contracting_dims_list)
-
-    return dgrad_list, wgrad_list, dbias_list, quantizer_set_list
+def _grouped_dense_bwd_rule(
+    contracting_dims, precision, preferred_element_type, group_offset, ctx, grad
+):
+    fwd_x_contracting_dims, fwd_k_contracting_dims = contracting_dims
+    (
+        group_sizes,
+        ctx_x,
+        ctx_kernel,
+        x_shape,
+        kernel_shape,
+        use_bias,
+        is_noop_quantizer_set,
+        quantizer_set,
+        flatten_axis_k,
+    ) = ctx
+
+    if is_noop_quantizer_set:
+        # The 1 in range is for excluding the group dimension (shall we use the hardcoded results below?)
+        # g_contracting_dim = (1, )
+        # k_contracting_dim = (2, )
+        g_contracting_dim = tuple(
+            range(1 + grad.ndim - len(kernel_shape) + len(fwd_k_contracting_dims), grad.ndim)
+        )
+        k_contracting_dim = tuple(
+            dim for dim in range(1, len(kernel_shape)) if dim not in fwd_k_contracting_dims
+        )
+        dgrad_contracting_dims = (g_contracting_dim, k_contracting_dim)
+        dgrad_grad = grad
+        dgrad_kernel_T = ctx_kernel
+
+        # g_contracting_dim = (0, )
+        # x_contracting_dim = (0, )
+        g_contracting_dim = x_contracting_dim = tuple(
+            range(0, len(x_shape) - len(fwd_x_contracting_dims))
+        )
+        wgrad_contracting_dims = (x_contracting_dim, g_contracting_dim)
+        wgrad_x_T = ctx_x
+        wgrad_grad = grad
+    else:
+        casted_grad = tex.grouped_quantize(
+            grad, quantizer_set.dgrad, group_sizes, flatten_axis=flatten_axis_k
+        )
+
+        # For x_contracting_dims == (1,) and k_contracting_dims == (1,), we need to use
+        # g_contracting_dim = (1,) and k_contracting_dim = (2,) to make it work after the
+        # extra transpose for FP8 in grouped_gemm
+        # TODO(Hua): Do we have a better way for this? What if is_gemm_with_all_layouts_supported()?
+        g_contracting_dim = (1,)
+        k_contracting_dim = (2,)
+        dgrad_contracting_dims = (g_contracting_dim, k_contracting_dim)
+        dgrad_grad = casted_grad.get_rowwise_tensor()
+        dgrad_kernel_T = ctx_kernel
+
+        # We need to use g_contracting_dim = (0,) and x_contracting_dim = (1,) to make it work
+        # after the extra transpose for FP8 in grouped_gemm
+        # TODO(Hua): Do we have a better way for this? What if is_gemm_with_all_layouts_supported()?
+        g_contracting_dim = (0,)
+        x_contracting_dim = (1,)
+        wgrad_contracting_dims = (x_contracting_dim, g_contracting_dim)
+        wgrad_x_T = ctx_x
+        wgrad_grad = casted_grad.get_colwise_tensor()
+
+    dgrad = tex.grouped_gemm(
+        dgrad_grad,
+        dgrad_kernel_T,
+        group_sizes,
+        dgrad_contracting_dims,
+        precision=precision,
+        preferred_element_type=preferred_element_type,
+        group_offset=group_offset,
+    )
+
+    wgrad = tex.grouped_gemm(
+        wgrad_x_T,
+        wgrad_grad,
+        group_sizes,
+        wgrad_contracting_dims,
+        precision=precision,
+        preferred_element_type=preferred_element_type,
+        group_offset=group_offset,
+    )
+
+    group_sizes_grad = None
+    dbias = tex.grouped_dbias(grad, group_sizes) if use_bias else None
+
+    return dgrad, wgrad, group_sizes_grad, dbias, quantizer_set
 _grouped_dense.defvjp(_grouped_dense_fwd_rule, _grouped_dense_bwd_rule)
"""
@@ -127,14 +127,16 @@ class BlockScaleDequantizer(Dequantizer):
     def dequantize(scaled_tensor):
         """Dequantize a tensor using block scaling.
-
-        This function dequantizes a tensor that was quantized using block scaling
-        by applying the inverse scaling factor to each block of data.

         Args:
-            scaled_tensor: The quantized tensor to dequantize
+            data: The quantized tensor data
+            scale_inv: The inverse scaling factors
+            dq_dtype: The data type for dequantized values
+            scaling_mode: The scaling mode used for quantization
+            is_colwise: Whether the scaling is column-wise
+            flatten_axis: The axis along which the tensor could be flattened to 2D

         Returns:
-            The dequantized tensor in the specified data type
+            The dequantized tensor
         """
         return BlockScaleDequantizer._dequantize_func(
             scaled_tensor.data,
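A conceptual sketch of what a rowwise block-scale dequantization computes (the helper name and the 32-wide block are assumptions for illustration; the real `_dequantize_func` also handles colwise layouts and `flatten_axis`):

import jax.numpy as jnp

def block_dequantize_rowwise(data, scale_inv, dq_dtype, block_size=32):
    *lead, last = data.shape
    blocks = data.astype(dq_dtype).reshape(*lead, last // block_size, block_size)
    # one inverse scale per block, broadcast over that block's elements
    return (blocks * scale_inv[..., None].astype(dq_dtype)).reshape(*lead, last)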