Unverified commit c3b7c2ae, authored by Phuong Nguyen, committed by GitHub

Revert "[JAX] GroupedDense v.2 without dynamic shape" (#1874)

Revert "[JAX] GroupedDense v.2 without dynamic shape (#1721)"

This reverts commit 5d01ef21.

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent 4d4f1edb
@@ -40,11 +40,10 @@ from transformer_engine.jax.quantize import (
     ScalingMode,
     QuantizerFactory,
     QuantizeLayout,
-    noop_quantizer_set,
 )
 from transformer_engine.jax.quantize import helper
 from transformer_engine.jax.activation import activation
-from transformer_engine.jax.dense import dense, grouped_dense
+from transformer_engine.jax.dense import dense
 from transformer_engine.jax.layernorm_dense import layernorm_dense
 
 GEMM_CASES = [
@@ -1205,6 +1204,24 @@ class TestFusedDense:
         assert_allclose(prim_x_grad, ref_x_grad, dtype=q_dtype)
 
 
+# This function is modified from transformer_engine/jax/cpp_extensions/gemm.py::_jax_gemm()
+def _quantize_gemm_pair(lhs, rhs, contracting_dims, lhs_quantizer, rhs_quantizer):
+    ((lhs_contract_dim,), (rhs_contract_dim,)) = contracting_dims
+    lhs_is_rowwise = lhs_contract_dim == lhs.ndim - 1
+    rhs_is_rowwise = rhs_contract_dim == rhs.ndim - 1
+    lhs_q = lhs_quantizer.quantize(
+        lhs,
+        is_rowwise=lhs_is_rowwise,
+        is_colwise=not lhs_is_rowwise,
+    )
+    rhs_q = rhs_quantizer.quantize(
+        rhs,
+        is_rowwise=rhs_is_rowwise,
+        is_colwise=not rhs_is_rowwise,
+    )
+    return lhs_q, rhs_q
+
+
 # E5M2 * E5M2 is not supported
 fwd_bwd_dtypes = [
     [jnp.float8_e4m3fn, jnp.float8_e4m3fn],
@@ -1212,194 +1229,219 @@ fwd_bwd_dtypes = [
     [jnp.float8_e5m2, jnp.float8_e4m3fn],
 ]
 
-GROUPED_DENSE_INPUT_SHAPES = [
-    # (n_groups, m, n, k), the actual m will be multiplied by 32
-    (5, 32, 128, 64),  # Test the case where n_groups is not a multiple of 4
-    (8, 64, 32, 128),
-    (8, 64, 128, 256),
-]
-
-
-@pytest_parametrize_wrapper("input_shape", GROUPED_DENSE_INPUT_SHAPES)
+"""
+@pytest_parametrize_wrapper(
+    "shape_list", [[(512, 128, 256), (256, 128, 256), (256, 128, 128), (512, 256, 128)]]
+)
 class TestGroupedDense:
-    def _ref_grouped_dense(self, lhs, rhs, bias, group_sizes, contracting_dims):
-        lhs_contract_dim, _ = contracting_dims
-        assert len(lhs_contract_dim) == 1 and lhs.ndim == 2 and rhs.ndim == 3
-        if bias is None:
-            bias = jnp.zeros((rhs.shape[0], rhs.shape[2]), dtype=lhs.dtype)
-        else:
-            assert bias.ndim == 2 and bias.shape == (rhs.shape[0], rhs.shape[2])
-        remaining_axis = (set(range(lhs.ndim)) - set(lhs_contract_dim)).pop()
-        lhs = jnp.split(lhs, jnp.cumulative_sum(group_sizes)[:-1], axis=remaining_axis)
-        rhs = jnp.split(rhs, rhs.shape[0], axis=0)
-        bias = jnp.split(bias, bias.shape[0], axis=0)
-        ref_out = []
-        dim_num = (contracting_dims, ((), ()))
-        for lhs_i, rhs_i, bias_i in zip(lhs, rhs, bias):
-            out_i = jax.lax.dot_general(lhs_i, rhs_i, dim_num) + jnp.expand_dims(bias_i, axis=0)
-            ref_out.append(jnp.squeeze(out_i))
-        return ref_out
-
-    def _generate_grouped_dense_input(self, dtype, input_shape, data_layout="NN", with_bias=False):
+    def _ref_grouped_gemm_with_jnp_dot(self, lhs_list, rhs_list, contracting_dims_list):
+        ref_out_list = []
+        for lhs, rhs, contracting_dims in zip(lhs_list, rhs_list, contracting_dims_list):
+            dim_nums = (contracting_dims, ((), ()))
+            ref_out_list.append(jax.lax.dot_general(lhs, rhs, dim_nums))
+        return ref_out_list
+
+    def _generate_grouped_gemm_input(self, dtype, shape_list, layout_list):
         key = jax.random.PRNGKey(0)
-        subkeys = jax.random.split(key, 4)
-        n_groups, m, n, k = input_shape
-
-        group_sizes = jnp.sort(jax.random.randint(subkeys[0], (n_groups - 1,), 0, m))
-        group_sizes = jnp.concatenate([jnp.array([0]), group_sizes, jnp.array([m])])
-        group_sizes = jnp.diff(group_sizes)
-        assert group_sizes.sum() == m
-
-        # *32 to make sure that input shape works for MXFP8
-        group_sizes = group_sizes * 32
-        m = m * 32
-
-        lhs_shape = (m if data_layout[0] == "N" else k, k if data_layout[0] == "N" else m)
-        rhs_shape = (n_groups, k if data_layout[1] == "N" else n, n if data_layout[1] == "N" else k)
-        bias_shape = (n_groups, n)
-
-        lhs = jax.random.uniform(subkeys[1], lhs_shape, dtype=dtype)
-        rhs = jax.random.uniform(subkeys[2], rhs_shape, dtype=dtype)
-        bias = jax.random.uniform(subkeys[3], bias_shape, dtype=dtype) if with_bias else None
-
-        lhs_contracting_dim = (1,) if data_layout[0] == "N" else (0,)
-        rhs_contracting_dim = (1,) if data_layout[1] == "N" else (2,)
-        contracting_dims = (lhs_contracting_dim, rhs_contracting_dim)
-
-        return lhs, rhs, group_sizes, contracting_dims, bias
-
-    def _assert_grouped_gemm_output(self, out, group_sizes, ref_list, dtype):
-        assert out.dtype == ref_list[0].dtype
-        out_list = jnp.split(out, jnp.cumulative_sum(group_sizes)[:-1], axis=0)
-        for i in range(len(ref_list)):
-            assert_allclose(out_list[i], ref_list[i], dtype=dtype)
+        subkeys = jax.random.split(key, len(shape_list) * 2)
+
+        lhs_list, rhs_list, contracting_dims_list = [], [], []
+        for i, ((m, n, k), data_layout) in enumerate(zip(shape_list, layout_list)):
+            lhs = jax.random.uniform(
+                subkeys[2 * i],
+                (m if data_layout[0] == "N" else k, k if data_layout[0] == "N" else m),
+                dtype=dtype,
+            )
+            rhs = jax.random.uniform(
+                subkeys[2 * i + 1],
+                (k if data_layout[1] == "N" else n, n if data_layout[1] == "N" else k),
+                dtype=dtype,
+            )
+            lhs_contracting_dim = (1,) if data_layout[0] == "N" else (0,)
+            rhs_contracting_dim = (0,) if data_layout[1] == "N" else (1,)
+            contracting_dims = (lhs_contracting_dim, rhs_contracting_dim)
+
+            lhs_list.append(lhs)
+            rhs_list.append(rhs)
+            contracting_dims_list.append(contracting_dims)
+
+        return lhs_list, rhs_list, contracting_dims_list
 
     @pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16])
-    @pytest_parametrize_wrapper("layout", ["NN"])
-    def test_grouped_gemm_fp16(self, dtype, input_shape, layout):
-        lhs, rhs, group_sizes, contracting_dims, _ = self._generate_grouped_dense_input(
-            dtype, input_shape, layout
+    @pytest_parametrize_wrapper("layout_list", [["NN", "TN", "NT", "TT"]])
+    def test_grouped_gemm_fp16(self, dtype, shape_list, layout_list):
+        lhs_list, rhs_list, contracting_dims_list = self._generate_grouped_gemm_input(
+            dtype, shape_list, layout_list
         )
-        ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims)
-        prim_out = tex.grouped_gemm(lhs, rhs, group_sizes, contracting_dims)
-        self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, dtype)
+        ref_out = self._ref_grouped_gemm_with_jnp_dot(lhs_list, rhs_list, contracting_dims_list)
+        primitive_out = tex.grouped_gemm(lhs_list, rhs_list, contracting_dims_list)
+        for i in range(len(shape_list)):
+            assert_allclose(primitive_out[i], ref_out[i], dtype=dtype)
 
     @pytest.mark.skipif(not is_fp8_supported, reason=reason)
     @pytest.mark.parametrize("fwd_bwd_dtype", fwd_bwd_dtypes)
     @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
-    @pytest_parametrize_wrapper("layout", ["NN"])
-    def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape, layout):
-        if scaling_mode == ScalingMode.MXFP8_1D_SCALING:
-            pytest.skip("MXFP8 is not supported in grouped_gemm yet")
+    @pytest_parametrize_wrapper("layout_list", [["NN", "TN", "NT", "TT"]])
+    def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, shape_list, layout_list):
         fwd_dtype, bwd_dtype = fwd_bwd_dtype
         quantizer_set = QuantizerFactory.create_set(
-            scaling_mode=scaling_mode,
-            fwd_dtype=fwd_dtype,
-            bwd_dtype=bwd_dtype,
-            is_2x2x=False,
-            n_groups=input_shape[0],
+            scaling_mode=scaling_mode, fwd_dtype=fwd_dtype, bwd_dtype=bwd_dtype, is_2x2x=False
         )
-        # quantizer_set.{x, kernel} has fwd_dtype, while quantizer_set.grad has bwd_dtype
-        # We want to test E4M3 * E5M2, manually set the quantizer_set.kernel.q_dtype to bwd_dtype
-        quantizer_set.kernel.q_dtype = bwd_dtype
-        for quantizer in quantizer_set.kernel.quantizers:
-            quantizer.q_dtype = bwd_dtype
         out_dtype = jnp.bfloat16
-        lhs, rhs, group_sizes, contracting_dims, _ = self._generate_grouped_dense_input(
-            out_dtype, input_shape, layout
-        )
-        ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims)
-        prim_out = tex.grouped_gemm(
-            lhs, rhs, group_sizes, contracting_dims, quantizer_set=quantizer_set
+        lhs_list, rhs_list, contracting_dims_list = self._generate_grouped_gemm_input(
+            out_dtype, shape_list, layout_list
         )
+        q_lhs_list = []
+        q_rhs_list = []
+        for lhs, rhs, contracting_dims in zip(lhs_list, rhs_list, contracting_dims_list):
+            # quantizer_set.x and quantizer_set.kernel have the same q_dtype, we want to
+            # test the case where lhs and rhs have different q_dtypes
+            q_lhs, q_rhs = _quantize_gemm_pair(
+                lhs, rhs, contracting_dims, quantizer_set.x, quantizer_set.dgrad
+            )
+            q_lhs_list.append(q_lhs)
+            q_rhs_list.append(q_rhs)
+        ref_out = self._ref_grouped_gemm_with_jnp_dot(lhs_list, rhs_list, contracting_dims_list)
+        primitive_out = tex.grouped_gemm(q_lhs_list, q_rhs_list, contracting_dims_list)
         allclose_dtype = jnp.float8_e4m3fn
-        if jnp.float8_e5m2 in fwd_bwd_dtype:
+        if fwd_dtype == jnp.float8_e5m2 or bwd_dtype == jnp.float8_e5m2:
             allclose_dtype = jnp.float8_e5m2
-
-        self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, allclose_dtype)
-
-    def _ref_sum_grouped_dense(self, x, kernel, bias, group_sizes, contracting_dims):
-        out_list = self._ref_grouped_dense(x, kernel, bias, group_sizes, contracting_dims)
-        # Note: we use jnp.sum instead of jnp.mean to make the gradient larger
-        # and prevent them from being clamp to zero
-        out_sum_list = [jnp.sum(out) for out in out_list]
-        return jnp.sum(jnp.asarray(out_sum_list))
-
-    def _primitive_sum_grouped_dense(
-        self, x, kernel, bias, group_sizes, contracting_dims, quantizer_set=noop_quantizer_set
-    ):
-        out = grouped_dense(
-            x, kernel, group_sizes, contracting_dims, bias=bias, quantizer_set=quantizer_set
-        )
-        return jnp.sum(jnp.asarray(out))
+        for i in range(len(shape_list)):
+            assert_allclose(primitive_out[i], ref_out[i], dtype=allclose_dtype)
 
     @pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16])
-    def test_grouped_dense_grad_fp16(self, dtype, input_shape):
-        x, kernel, group_sizes, contracting_dims, bias = self._generate_grouped_dense_input(
-            dtype,
-            input_shape,
-            with_bias=True,
-        )
+    def test_grouped_dense_grad_fp16(self, dtype, shape_list):
+        group_size = len(shape_list)
+        layout_list = ["NN" for _ in range(group_size)]
+
+        x_list, kernel_list, contracting_dims_list = self._generate_grouped_gemm_input(
+            dtype, shape_list, layout_list
+        )
+        bias_list = []
+        key = jax.random.PRNGKey(1)
+        for shape in shape_list:
+            n = shape[1]
+            bias = jax.random.uniform(key, n, dtype=dtype)
+            bias_list.append(bias)
+
+        def ref_func(x_list, kernel_list, bias_list, contracting_dims_list):
+            out_list = []
+            for i in range(len(x_list)):
+                out_list.append(
+                    dense(
+                        x_list[i],
+                        kernel_list[i],
+                        bias_list[i],
+                        contracting_dims=contracting_dims_list[i],
+                    )
+                )
+            # Note: we use jnp.sum instead of jnp.mean to make the gradient larger
+            # and prevent them from being clamp to zero
+            out_sum_list = [jnp.sum(out) for out in out_list]
+            return jnp.sum(jnp.asarray(out_sum_list))
+
+        def primitive_func(x_list, kernel_list, bias_list, contracting_dims_list):
+            out_list = grouped_dense(x_list, kernel_list, bias_list, contracting_dims_list)
+            out_sum_list = [jnp.sum(out) for out in out_list]
+            return jnp.sum(jnp.asarray(out_sum_list))
 
-        value_n_grad_ref_func = value_and_grad(self._ref_sum_grouped_dense, (0, 1, 2))
-        value_n_grad_prim_func = value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2))
+        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2))
+        value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2))
 
-        ref_out_sum, (ref_dgrad, ref_wgrad, ref_dbias) = value_n_grad_ref_func(
-            x, kernel, bias, group_sizes, contracting_dims
+        ref_out_mean, (ref_dgrad_list, ref_wgrad_list, ref_dbias_list) = value_n_grad_ref_func(
+            x_list, kernel_list, bias_list, contracting_dims_list
         )
-        prim_out_sum, (prim_dgrad, prim_wgrad, prim_dbias) = value_n_grad_prim_func(
-            x, kernel, bias, group_sizes, contracting_dims
+        primitive_out_mean, (primitive_dgrad_list, primitive_wgrad_list, primitive_dbias_list) = (
+            value_n_grad_primitive_func(x_list, kernel_list, bias_list, contracting_dims_list)
         )
 
-        assert_allclose(prim_out_sum, ref_out_sum, dtype=dtype)
-        assert_allclose(prim_dgrad, ref_dgrad, dtype=dtype)
-        assert_allclose(prim_wgrad, ref_wgrad, dtype=dtype)
-        assert_allclose(prim_dbias, ref_dbias, dtype=dtype)
+        assert_allclose(primitive_out_mean, ref_out_mean, dtype=dtype)
+        for i in range(group_size):
+            assert_allclose(primitive_dgrad_list[i], ref_dgrad_list[i], dtype=dtype)
+            assert_allclose(primitive_wgrad_list[i], ref_wgrad_list[i], dtype=dtype)
+            assert_allclose(primitive_dbias_list[i], ref_dbias_list[i], dtype=dtype)
 
     @pytest.mark.skipif(not is_fp8_supported, reason=reason)
-    @pytest.mark.parametrize(
-        "fwd_bwd_dtype",
-        [(jnp.float8_e4m3fn, jnp.float8_e4m3fn), (jnp.float8_e4m3fn, jnp.float8_e5m2)],
-    )
+    @pytest.mark.parametrize("fwd_bwd_dtype", fwd_bwd_dtypes)
     @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
-    def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape):
-        if scaling_mode == ScalingMode.MXFP8_1D_SCALING:
-            pytest.skip("MXFP8 is not supported in grouped_dense yet")
+    def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, shape_list):
+        group_size = len(shape_list)
+        layout_list = ["NN" for _ in range(group_size)]
         fwd_dtype, bwd_dtype = fwd_bwd_dtype
-        dtype = jnp.bfloat16
-        x, kernel, group_sizes, contracting_dims, bias = self._generate_grouped_dense_input(
-            dtype,
-            input_shape,
-            with_bias=True,
-        )
-
-        quantizer_set = QuantizerFactory.create_set(
-            scaling_mode=scaling_mode,
-            fwd_dtype=fwd_dtype,
-            bwd_dtype=bwd_dtype,
-            is_2x2x=True,
-            n_groups=group_sizes.size,
+        if fwd_dtype == jnp.float8_e5m2:
+            pytest.skip("We never use E5M2 for fwd_dtype in training")
+
+        # Question: should we use different quantizers for different groups?
+        ref_quantizer_set_list = []
+        quantizer_set_list = []
+        for _ in range(group_size):
+            ref_quantizer_set = QuantizerFactory.create_set(
+                scaling_mode=scaling_mode, fwd_dtype=fwd_dtype, bwd_dtype=bwd_dtype, is_2x2x=True
+            )
+            ref_quantizer_set_list.append(ref_quantizer_set)
+            quantizer_set = QuantizerFactory.create_set(
+                scaling_mode=scaling_mode, fwd_dtype=fwd_dtype, bwd_dtype=bwd_dtype, is_2x2x=True
+            )
+            quantizer_set_list.append(quantizer_set)
+
+        out_dtype = jnp.bfloat16
+        x_list, kernel_list, contracting_dims_list = self._generate_grouped_gemm_input(
+            out_dtype, shape_list, layout_list
         )
-
-        value_n_grad_ref_func = value_and_grad(self._ref_sum_grouped_dense, (0, 1, 2))
-        value_n_grad_prim_func = value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2))
-
-        ref_out_sum, (ref_dgrad, ref_wgrad, ref_dbias) = value_n_grad_ref_func(
-            x,
-            kernel,
-            bias,
-            group_sizes,
-            contracting_dims,
+        bias_list = []
+        key = jax.random.PRNGKey(1)
+        for shape in shape_list:
+            n = shape[1]
+            bias = jax.random.uniform(key, n, dtype=out_dtype)
+            bias_list.append(bias)
+
+        def ref_func(x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list):
+            out_list = []
+            for i in range(len(x_list)):
+                out_list.append(
+                    dense(
+                        x_list[i],
+                        kernel_list[i],
+                        bias_list[i],
+                        contracting_dims=contracting_dims_list[i],
+                        quantizer_set=quantizer_set_list[i],
+                    )
+                )
+            # Note: we use jnp.sum instead of jnp.mean to make the gradient larger
+            # and prevent them from being clamp to zero
+            out_sum_list = [jnp.sum(out) for out in out_list]
+            return jnp.sum(jnp.asarray(out_sum_list))
+
+        def primitive_func(
+            x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list
+        ):
+            out_list = grouped_dense(
+                x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list
+            )
+            out_sum_list = [jnp.sum(out) for out in out_list]
+            return jnp.sum(jnp.asarray(out_sum_list))
+
+        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2))
+        value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2))
+
+        ref_out_mean, (ref_dgrad_list, ref_wgrad_list, ref_dbias_list) = value_n_grad_ref_func(
+            x_list, kernel_list, bias_list, contracting_dims_list, ref_quantizer_set_list
         )
-        prim_out_sum, (prim_dgrad, prim_wgrad, prim_dbias) = value_n_grad_prim_func(
-            x, kernel, bias, group_sizes, contracting_dims, quantizer_set=quantizer_set
+        primitive_out_mean, (primitive_dgrad_list, primitive_wgrad_list, primitive_dbias_list) = (
+            value_n_grad_primitive_func(
+                x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list
+            )
         )
 
-        assert_allclose(prim_out_sum, ref_out_sum, dtype=fwd_dtype)
-        assert_allclose(prim_dgrad, ref_dgrad, dtype=bwd_dtype)
-        assert_allclose(prim_wgrad, ref_wgrad, dtype=bwd_dtype)
-        assert_allclose(prim_dbias, ref_dbias, dtype=dtype)
+        allclose_dtype = jnp.float8_e4m3fn
+        if fwd_dtype == jnp.float8_e5m2 or bwd_dtype == jnp.float8_e5m2:
+            allclose_dtype = jnp.float8_e5m2
+        assert_allclose(primitive_out_mean, ref_out_mean, dtype=allclose_dtype)
+        for i in range(group_size):
+            assert_allclose(primitive_dgrad_list[i], ref_dgrad_list[i], dtype=allclose_dtype)
+            assert_allclose(primitive_wgrad_list[i], ref_wgrad_list[i], dtype=allclose_dtype)
+            assert_allclose(primitive_dbias_list[i], ref_dbias_list[i], dtype=allclose_dtype)
+"""
@@ -525,7 +525,6 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
   const auto B_alignment = _getAlignment(reinterpret_cast<uintptr_t>(param.B));
   const auto C_alignment = _getAlignment(reinterpret_cast<uintptr_t>(C));
   const auto D_alignment = _getAlignment(reinterpret_cast<uintptr_t>(D));
-  const auto workspace_alignment = _getAlignment(reinterpret_cast<uintptr_t>(workspace));
   NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
       preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES, &A_alignment, sizeof(A_alignment)));
   NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
@@ -534,8 +533,6 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
       preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES, &C_alignment, sizeof(C_alignment)));
   NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
       preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES, &D_alignment, sizeof(D_alignment)));
-  NVTE_CHECK(workspace_alignment % 256 == 0,
-             "cuBLAS workspace pointer must be aligned to 256 bytes, got ", workspace_alignment);
 
   const auto status =
       cublasLtMatmulAlgoGetHeuristic(handle, operationDesc, Adesc, Bdesc, Cdesc, Ddesc, preference,
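The deleted check above guarded the cuBLAS workspace pointer's 256-byte alignment. For reference, a helper like `_getAlignment` (whose body is not shown in this diff, so this is an assumption about its behavior) can be sketched as computing the largest power of two, capped here at 256, that divides the address:

def get_alignment(addr: int, max_alignment: int = 256) -> int:
    # Hypothetical stand-in for _getAlignment: largest power-of-two divisor of addr
    alignment = 1
    while alignment < max_alignment and addr % (alignment * 2) == 0:
        alignment *= 2
    return alignment

assert get_alignment(0x1000) == 256  # 4 KiB-aligned address, capped at 256
assert get_alignment(0x1008) == 8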
@@ -47,7 +47,7 @@ else:
     from jax.extend import ffi  # pylint: disable=ungrouped-imports
 
-__all__ = ["quantize", "quantize_dbias", "grouped_quantize", "grouped_dbias"]
+__all__ = ["quantize", "quantize_dbias", "grouped_quantize"]
 
 
 class BaseDBiasQuantizePrimitive(BasePrimitive):
@@ -1032,24 +1032,3 @@ def grouped_quantize(
         group_axis=group_axis,
     )
     return out
-
-
-def grouped_dbias(grad: jnp.ndarray, group_sizes: jnp.ndarray) -> jnp.ndarray:
-    """
-    Compute the grouped bias gradient.
-
-    Args:
-        grad: jnp.ndarray of shape (M, N)
-        group_sizes: jnp.ndarray of shape(num_groups,), sum(group_sizes) == M
-
-    Returns:
-        dbias: jnp.ndarray of shape (num_groups, N)
-    """
-    assert grad.ndim == 2, "Input grad must be a 2D tensor."
-    assert group_sizes.ndim == 1, "group_sizes must be a 1D tensor."
-    segment_ids = jnp.repeat(jnp.arange(group_sizes.shape[0]), group_sizes)
-    grad_fp32 = grad.astype(jnp.float32)
-    dbias_fp32 = jax.ops.segment_sum(grad_fp32, segment_ids, num_segments=group_sizes.shape[0])
-    dbias = dbias_fp32.astype(grad.dtype)
-    return dbias
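The removed `grouped_dbias` reduces a packed (M, N) gradient to one bias-gradient row per group with a segment sum. A minimal self-contained usage sketch of that reduction (toy shapes, not part of the commit):

import jax
import jax.numpy as jnp

grad = jnp.arange(12.0).reshape(4, 3)  # (M, N) with M = 4
group_sizes = jnp.array([1, 3])        # sum(group_sizes) == M
segment_ids = jnp.repeat(jnp.arange(group_sizes.shape[0]), group_sizes,
                         total_repeat_length=grad.shape[0])
dbias = jax.ops.segment_sum(grad.astype(jnp.float32), segment_ids,
                            num_segments=group_sizes.shape[0]).astype(grad.dtype)
print(dbias.shape)  # (2, 3): row 0 sums 1 gradient row, row 1 sums the remaining 3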
@@ -13,127 +13,43 @@
 #include "transformer_engine/multi_stream.h"
 #include "xla/ffi/api/c_api.h"
 
-#define MXFP8_BLOCK_SIZE 32
-
 namespace transformer_engine {
 namespace jax {
 
-Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv,
-                          Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias,
-                          Buffer_Type group_sizes, Buffer_Type group_offset, Result_Type output,
-                          Result_Type workspace, size_t m, size_t n, size_t k, bool lhs_is_trans,
-                          bool rhs_is_trans, JAXX_Scaling_Mode scaling_mode, bool has_bias,
-                          bool is_grouped_dense_wgrad) {
+Error_Type GroupedGemmFFI(cudaStream_t stream, Variadic_Buffer_Type input_list,
+                          Variadic_Result_Type output_list, int64_t num_gemms,
+                          JAXX_Scaling_Mode scaling_mode, int64_t has_bias) {
   // Notes on matrix layouts and transpose:
   // Jax uses row-major data_layout, on entering this function, each input matrix pair:
-  //   A: row-major [m, k] for N - [k, m] for T
-  //   B: row-major [k, n] for N - [n, k] for T
+  //   A: row-major with size [m, k],
+  //   B: row-major with size [n, k], needs transpose,
   // on exiting this function, JAX expect:
   //   C: row-major with size [m, n].
   // cuBLAS uses column-major data_layout, in this view, each input matrix pair:
-  //   A: column-major with size [k, m] for T - [m, k] for N
-  //   B: column-major with size [n, k] for T - [k, n] for N
-  //
+  //   A: column-major with size [k, m], needs transpose,
+  //   B: column-major with size [k, n].
   // If we call cuBLAS GEMM for A * B, the output will be:
   //   C: column-major with size [m, n] --> row-major with size [n, m].
   // To make the output compatible with JAX, we need to swap A and B in cuBLAS GEMM call.
 
-  int num_streams = nvte_get_num_compute_streams();
-
-  // Inputs
-  auto lhs_ptr = reinterpret_cast<uint8_t *>(lhs_data.untyped_data());
-  auto rhs_ptr = reinterpret_cast<uint8_t *>(rhs_data.untyped_data());
-  auto lhs_sinv_ptr = reinterpret_cast<uint8_t *>(lhs_sinv.untyped_data());
-  auto rhs_sinv_ptr = reinterpret_cast<uint8_t *>(rhs_sinv.untyped_data());
-  auto lhs_dtype = convert_ffi_datatype_to_te_dtype(lhs_data.element_type());
-  auto rhs_dtype = convert_ffi_datatype_to_te_dtype(rhs_data.element_type());
-  auto lhs_sinv_dtype = convert_ffi_datatype_to_te_dtype(lhs_sinv.element_type());
-  auto rhs_sinv_dtype = convert_ffi_datatype_to_te_dtype(rhs_sinv.element_type());
-  auto bias_ptr = has_bias ? reinterpret_cast<uint8_t *>(bias.untyped_data()) : nullptr;
-  auto bias_dtype = convert_ffi_datatype_to_te_dtype(bias.element_type());
-  NVTE_CHECK(group_sizes.dimensions().size() == 1);
-  size_t num_gemms = group_sizes.dimensions()[0];
-
-  // Outputs
-  auto out_ptr = reinterpret_cast<uint8_t *>(output->untyped_data());
-  auto out_dtype = convert_ffi_datatype_to_te_dtype(output->element_type());
-  auto workspace_ptr = reinterpret_cast<uint8_t *>(workspace->untyped_data());
-  auto workspace_total_size = product(workspace->dimensions());
-
-  auto lhs_sinv_size = product(lhs_sinv.dimensions());
-  auto rhs_sinv_size = product(rhs_sinv.dimensions());
-  auto workspace_size = (workspace_total_size - lhs_sinv_size - rhs_sinv_size) / num_streams;
-  auto swizzled_lhs_sinv_ptr = workspace_ptr + workspace_size * num_streams;
-  auto swizzled_rhs_sinv_ptr = swizzled_lhs_sinv_ptr + lhs_sinv_size;
-
-  size_t lhs_dtype_bytes = te_dtype_bytes(lhs_dtype);
-  size_t rhs_dtype_bytes = te_dtype_bytes(rhs_dtype);
-  size_t lhs_sinv_dtype_bytes = te_dtype_bytes(lhs_sinv_dtype);
-  size_t rhs_sinv_dtype_bytes = te_dtype_bytes(rhs_sinv_dtype);
-  size_t bias_dtype_bytes = te_dtype_bytes(bias_dtype);
-  size_t out_dtype_bytes = te_dtype_bytes(out_dtype);
-  NVTE_CHECK(lhs_dtype_bytes == rhs_dtype_bytes, "sizeof(lhs_dtype) != sizeof(rhs_dtype)");
-  NVTE_CHECK(lhs_sinv_dtype_bytes == rhs_sinv_dtype_bytes,
-             "sizeof(lhs_sinv_dtype) != sizeof(rhs_sinv_dtype)");
-
-  size_t expected_lhs_size = m * k;
-  size_t expected_rhs_size = is_grouped_dense_wgrad ? (k * n) : (num_gemms * k * n);
-  size_t expected_out_size = is_grouped_dense_wgrad ? (num_gemms * m * n) : (m * n);
-  size_t actual_lhs_size = product(lhs_data.dimensions());
-  size_t actual_rhs_size = product(rhs_data.dimensions());
-  size_t actual_out_size = product(output->dimensions());
-  NVTE_CHECK(expected_lhs_size == actual_lhs_size, "Unexpected lhs size! Expect ",
-             expected_lhs_size, ", got ", actual_lhs_size);
-  if (!is_grouped_dense_wgrad) {
-    NVTE_CHECK(expected_rhs_size == actual_rhs_size,
-               "Unexpected rhs size! Expect num_gemms * n * k = ", num_gemms, " * ", n, " * ", k,
-               " = ", expected_rhs_size, ", got ", actual_rhs_size);
-    NVTE_CHECK(expected_out_size == actual_out_size, "Unexpected output size! Expect m * n = ", m,
-               " * ", n, " = ", expected_out_size, ", got ", actual_out_size);
-  } else {
-    NVTE_CHECK(expected_rhs_size == actual_rhs_size, "Unexpected rhs size! Expect k * n = ", k,
-               " * ", n, " = ", expected_rhs_size, ", got ", actual_rhs_size);
-    NVTE_CHECK(expected_out_size == actual_out_size,
-               "Unexpected output size! Expect num_gemms * m * n = ", num_gemms, " * ", m, " * ", n,
-               " = ", expected_out_size, ", got ", actual_out_size);
+  if (num_gemms <= 0) {
+    return ffi_with_cuda_error_check();
   }
 
-  size_t dim_list_bytes = sizeof(int32_t) * num_gemms;
-  std::vector<int32_t> dim_list_host(num_gemms);
-  auto dim_list_ptr = reinterpret_cast<int32_t *>(group_sizes.untyped_data());
-  cudaMemcpyAsync(dim_list_host.data(), dim_list_ptr, dim_list_bytes, cudaMemcpyDeviceToHost,
-                  stream);
-  // Note: This may break cudaGraph.
-  cudaStreamSynchronize(stream);
-  size_t sum_group_sizes = std::accumulate(dim_list_host.begin(), dim_list_host.end(), 0);
-  if (!is_grouped_dense_wgrad) {
-    NVTE_CHECK(m == sum_group_sizes, "Unexpected group_sizes! M = ", m,
-               ", got sum(group_sizes)=", sum_group_sizes);
-  } else {
-    NVTE_CHECK(k == sum_group_sizes, "Unexpected group_sizes! K = ", k,
-               ", got sum(group_sizes)=", sum_group_sizes);
-  }
+  size_t expected_input_size = has_bias ? 5 * num_gemms : 4 * num_gemms;
+  size_t expected_output_size = num_gemms + 1;
+  size_t actual_input_size = input_list.size();
+  size_t actual_output_size = output_list.size();
+  NVTE_CHECK(actual_input_size == expected_input_size, "Expected %zu input tensors, got %zu",
+             expected_input_size, actual_input_size);
+  NVTE_CHECK(actual_output_size == expected_output_size, "Expected %zu output tensors, got %zu",
+             expected_output_size, actual_output_size);
 
+  bool trans_lhs = true;
+  bool trans_rhs = false;
   auto num_math_sm = cuda::sm_count() - getenv<int>("NVTE_EXT_MARGIN_SM", 0);
 
   bool grad = false;
   bool accumulate = false;
   bool use_split_accumulator = false;
-  auto bias_shape = std::vector<size_t>{has_bias ? n : 0};
-
-  const int arch = cuda::sm_arch();
-  // It is weird that TE/Common GEMM only use colwise for MXFP8
-  const bool is_fp8_gemm = is_fp8_dtype(lhs_dtype);
-  const bool is_mxfp8_scaling = scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING;
-  const bool rhs_use_colwise = is_mxfp8_scaling && !rhs_is_trans;
-  const bool lhs_use_colwise = is_mxfp8_scaling && lhs_is_trans;
-  if (arch < 100 && is_fp8_gemm) {
-    NVTE_CHECK(!lhs_is_trans && rhs_is_trans,
-               "For SM90 or older archs and FP8 input, only NT (row-major) GEMM is supported, ",
-               "got lhs_is_trans=", lhs_is_trans, ", rhs_is_trans=", rhs_is_trans);
-  }
 
   // These lists are to keep the TensorWrapper objects alive
   std::vector<TensorWrapper> lhs_wrapper_list;
@@ -151,83 +67,96 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type
   std::vector<NVTETensor> out_list;
   std::vector<NVTETensor> workspace_list;
 
-  for (size_t i = 0; i < num_gemms; i++) {
-    // Matrix data shapes
-    size_t m_i = dim_list_host[i];
-    auto lhs_shape = std::vector<size_t>{m_i, k};
-    auto rhs_shape = std::vector<size_t>{rhs_is_trans ? n : k, rhs_is_trans ? k : n};
-    auto out_shape = std::vector<size_t>{m_i, n};
-    if (is_grouped_dense_wgrad) {
-      size_t k_i = dim_list_host[i];
-      lhs_shape[0] = lhs_is_trans ? k_i : m;
-      lhs_shape[1] = lhs_is_trans ? m : k_i;
-      rhs_shape[0] = rhs_is_trans ? n : k_i;
-      rhs_shape[1] = rhs_is_trans ? k_i : n;
-      out_shape[0] = m;
-      out_shape[1] = n;
-    }
-
-    // Set matrix data pointers
-    auto lhs_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
-    auto rhs_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
-    auto out_i = TensorWrapper(static_cast<void *>(out_ptr), out_shape, out_dtype);
-    void *lhs_vptr = static_cast<void *>(lhs_ptr);
-    void *rhs_vptr = static_cast<void *>(rhs_ptr);
-    if (rhs_use_colwise)  // MatA to enter cuBLAS
-      rhs_i.set_columnwise_data(rhs_vptr, rhs_dtype, rhs_shape);
-    else
-      rhs_i.set_rowwise_data(rhs_vptr, rhs_dtype, rhs_shape);
-    if (lhs_use_colwise)  // MatB to enter cuBLAS
-      lhs_i.set_columnwise_data(lhs_vptr, lhs_dtype, lhs_shape);
-    else
-      lhs_i.set_rowwise_data(lhs_vptr, lhs_dtype, lhs_shape);
-
-    // Scale_inv shapes
-    auto lhs_sinv_size = std::vector<size_t>{1};
-    auto rhs_sinv_size = std::vector<size_t>{1};
-    if (is_mxfp8_scaling) {
-      NVTE_CHECK(k % MXFP8_BLOCK_SIZE == 0, "MXFP8 K-dim being divisble by %d (got %d)",
-                 MXFP8_BLOCK_SIZE, k);
-      size_t scale_k = k / MXFP8_BLOCK_SIZE;
-      lhs_sinv_size[0] = m_i * scale_k;
-      rhs_sinv_size[0] = n * scale_k;
-      // Need to add swizzle here
-    }
-
-    // Set scale_inv pointers
-    void *rhs_sinv_vptr = static_cast<void *>(rhs_sinv_ptr);
-    void *lhs_sinv_vptr = static_cast<void *>(lhs_sinv_ptr);
-    if (is_fp8_gemm) {
-      if (rhs_use_colwise)  // MatA to enter cuBLAS
-        rhs_i.set_columnwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_size);
-      else
-        rhs_i.set_rowwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_size);
-      if (lhs_use_colwise)  // MatB to enter cuBLAS
-        lhs_i.set_columnwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_size);
-      else
-        lhs_i.set_rowwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_size);
+  int lhs_list_offset = 0;
+  int rhs_list_offset = num_gemms;
+  int lhs_sinv_list_offset = 2 * num_gemms;
+  int rhs_sinv_list_offset = 3 * num_gemms;
+  int bias_list_offset = 4 * num_gemms;
+  int out_list_offset = 0;
+  for (int i = 0; i < num_gemms; i++) {
+    Buffer_Type lhs_i = input_list.get<Buffer_Type>(lhs_list_offset + i).value();
+    Buffer_Type rhs_i = input_list.get<Buffer_Type>(rhs_list_offset + i).value();
+    Buffer_Type lhs_sinv_i = input_list.get<Buffer_Type>(lhs_sinv_list_offset + i).value();
+    Buffer_Type rhs_sinv_i = input_list.get<Buffer_Type>(rhs_sinv_list_offset + i).value();
+    Result_Type out_i = output_list.get<Buffer_Type>(out_list_offset + i).value();
+
+    DType lhs_dtype = convert_ffi_datatype_to_te_dtype(lhs_i.element_type());
+    DType rhs_dtype = convert_ffi_datatype_to_te_dtype(rhs_i.element_type());
+    DType out_dtype = convert_ffi_datatype_to_te_dtype(out_i->element_type());
+
+    void *lhs_ptr = lhs_i.untyped_data();
+    void *rhs_ptr = rhs_i.untyped_data();
+    void *lhs_sinv_ptr = lhs_sinv_i.untyped_data();
+    void *rhs_sinv_ptr = rhs_sinv_i.untyped_data();
+    void *out_ptr = out_i->untyped_data();
+
+    // Placeholder for bias since it can be empty
+    DType bias_dtype = DType::kFloat32;
+    void *bias_ptr = nullptr;
+
+    auto lhs_shape_ = lhs_i.dimensions();
+    auto rhs_shape_ = rhs_i.dimensions();
+
+    // lhs and rhs has shape [1, m, k] and [1, n, k]
+    size_t m = lhs_shape_[1];
+    size_t n = rhs_shape_[1];
+    size_t k = lhs_shape_[2];
+
+    auto lhs_shape = std::vector<size_t>{m, k};
+    auto rhs_shape = std::vector<size_t>{n, k};
+    auto out_shape = std::vector<size_t>{n, m};
+    auto lhs_sinv_shape = std::vector<size_t>{1, 1};
+    auto rhs_sinv_shape = std::vector<size_t>{1, 1};
+
+    if (scaling_mode == JAXX_Scaling_Mode::NO_SCALING ||
+        scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ||
+        scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING) {
+      float *amax_dptr = nullptr;
+      float *scale_dptr = nullptr;
+      auto lhs_i_ = TensorWrapper(lhs_ptr, lhs_shape, lhs_dtype, amax_dptr, scale_dptr,
+                                  reinterpret_cast<float *>(lhs_sinv_ptr));
+      auto rhs_i_ = TensorWrapper(rhs_ptr, rhs_shape, rhs_dtype, amax_dptr, scale_dptr,
                                   reinterpret_cast<float *>(rhs_sinv_ptr));
+      lhs_wrapper_list.push_back(std::move(lhs_i_));
+      rhs_wrapper_list.push_back(std::move(rhs_i_));
+    } else if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) {
+      // Note: the scale_inv array should have been swizzled in Python before lowering
+      auto lhs_sinv_shape_ = lhs_sinv_i.dimensions();
+      auto rhs_sinv_shape_ = rhs_sinv_i.dimensions();
+      for (int i = 0; i < 2; i++) {
+        lhs_sinv_shape[i] = lhs_sinv_shape_[i];
+        rhs_sinv_shape[i] = rhs_sinv_shape_[i];
+      }
+      NVTEScalingMode nvte_scaling_mode = get_nvte_scaling_mode(scaling_mode);
+      TensorWrapper lhs_i_(nvte_scaling_mode);
+      TensorWrapper rhs_i_(nvte_scaling_mode);
+      lhs_i_.set_rowwise_data(lhs_ptr, lhs_dtype, lhs_shape);
+      rhs_i_.set_rowwise_data(rhs_ptr, rhs_dtype, rhs_shape);
+      lhs_i_.set_rowwise_scale_inv(lhs_sinv_ptr, DType::kFloat8E8M0, lhs_sinv_shape);
+      rhs_i_.set_rowwise_scale_inv(rhs_sinv_ptr, DType::kFloat8E8M0, rhs_sinv_shape);
+      lhs_wrapper_list.push_back(std::move(lhs_i_));
+      rhs_wrapper_list.push_back(std::move(rhs_i_));
     } else {
-      NVTE_CHECK(scaling_mode == JAXX_Scaling_Mode::NO_SCALING,
-                 "Unsupported scaling mode: ", static_cast<int>(scaling_mode));
+      NVTE_ERROR("Unsupported scaling mode: ", static_cast<int>(scaling_mode));
     }
 
-    auto bias_i = TensorWrapper(bias_ptr, bias_shape, bias_dtype);
-    auto pre_gelu_i = TensorWrapper(nullptr, std::vector<size_t>{0}, out_dtype);
-
-    // Update pointer for the next GEMM pair
-    lhs_ptr += lhs_shape[0] * lhs_shape[1] * lhs_dtype_bytes;
-    rhs_ptr += rhs_shape[0] * rhs_shape[1] * rhs_dtype_bytes;
-    out_ptr += out_shape[0] * out_shape[1] * out_dtype_bytes;
-    if (is_fp8_gemm) {
-      lhs_sinv_ptr += lhs_sinv_size[0] * lhs_sinv_dtype_bytes;
-      rhs_sinv_ptr += rhs_sinv_size[0] * rhs_sinv_dtype_bytes;
-    }
-    if (has_bias) bias_ptr += n * bias_dtype_bytes;
-
-    // Move objects to the lists to keep them alive
-    lhs_wrapper_list.push_back(std::move(lhs_i));
-    rhs_wrapper_list.push_back(std::move(rhs_i));
-    out_wrapper_list.push_back(std::move(out_i));
+    auto out_i_ = TensorWrapper(out_ptr, out_shape, out_dtype);
+    void *pre_gelu_ptr = nullptr;
+    auto bias_shape = std::vector<size_t>{0};
+    auto pre_gelu_shape = std::vector<size_t>{0};
+    if (has_bias) {
+      auto bias_i_get = input_list.get<Buffer_Type>(bias_list_offset + i);
+      Buffer_Type bias_i = bias_i_get.value();
+      bias_ptr = bias_i.untyped_data();
+      bias_dtype = convert_ffi_datatype_to_te_dtype(bias_i.element_type());
+      bias_shape[0] = n;
+    }
+    auto bias_i = TensorWrapper(bias_ptr, bias_shape, bias_dtype);
+    auto pre_gelu_i = TensorWrapper(pre_gelu_ptr, pre_gelu_shape, out_dtype);
+
+    out_wrapper_list.push_back(std::move(out_i_));
     bias_wrapper_list.push_back(std::move(bias_i));
     pre_gelu_wrapper_list.push_back(std::move(pre_gelu_i));
@@ -238,6 +167,11 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type
     out_list.push_back(out_wrapper_list.back().data());
   }
 
+  auto workspace_get = output_list.get<Buffer_Type>(num_gemms);
+  Result_Type workspace = workspace_get.value();
+  uint8_t *workspace_ptr = reinterpret_cast<uint8_t *>(workspace->untyped_data());
+  auto num_streams = nvte_get_num_compute_streams();
+  size_t workspace_size = workspace->dimensions()[0] / num_streams;
   auto workspace_shape = std::vector<size_t>{workspace_size};
   for (int i = 0; i < num_streams; i++) {
     auto workspace_i =
@@ -248,7 +182,7 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type
   }
 
   nvte_multi_stream_cublas_gemm(rhs_list.data(), lhs_list.data(), out_list.data(), bias_list.data(),
-                                pre_gelu_list.data(), num_gemms, rhs_is_trans, lhs_is_trans, grad,
+                                pre_gelu_list.data(), num_gemms, trans_lhs, trans_rhs, grad,
                                 workspace_list.data(), accumulate, use_split_accumulator,
                                 num_math_sm, stream);
@@ -258,23 +192,11 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type
 XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI,
                               FFI::Bind()
                                   .Ctx<FFI_Stream_Type>()  // stream
-                                  .Arg<Buffer_Type>()      // lhs_data
-                                  .Arg<Buffer_Type>()      // lhs_sinv
-                                  .Arg<Buffer_Type>()      // rhs_data
-                                  .Arg<Buffer_Type>()      // rhs_sinv
-                                  .Arg<Buffer_Type>()      // bias
-                                  .Arg<Buffer_Type>()      // group_sizes
-                                  .Arg<Buffer_Type>()      // group_offset
-                                  .Ret<Buffer_Type>()      // output
-                                  .Ret<Buffer_Type>()      // workspace
-                                  .Attr<int64_t>("M")
-                                  .Attr<int64_t>("N")
-                                  .Attr<int64_t>("K")
-                                  .Attr<bool>("lhs_is_trans")
-                                  .Attr<bool>("rhs_is_trans")
+                                  .RemainingArgs()         // input list
+                                  .RemainingRets()         // output list
+                                  .Attr<int64_t>("num_gemms")
                                   .Attr<JAXX_Scaling_Mode>("scaling_mode")
-                                  .Attr<bool>("has_bias")
-                                  .Attr<bool>("is_grouped_dense_wgrad"),
+                                  .Attr<int64_t>("has_bias"),
                               FFI_CudaGraph_Traits);
 
 }  // namespace jax
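The layout comments in this hunk describe why the FFI swaps A and B before calling cuBLAS: a column-major GEMM reading row-major buffers effectively sees transposed matrices, so computing with swapped operands yields the row-major result JAX expects. A small JAX sketch of the identity being exploited, C = (B^T A^T)^T (illustrative only; cuBLAS gets the same effect without materializing transposes):

import jax.numpy as jnp

m, k, n = 4, 5, 3
A = jnp.arange(m * k, dtype=jnp.float32).reshape(m, k)  # row-major (m, k)
B = jnp.arange(k * n, dtype=jnp.float32).reshape(k, n)  # row-major (k, n)

C = A @ B                  # what JAX expects, row-major (m, n)
C_swapped = (B.T @ A.T).T  # column-major engine computing with swapped operands
assert jnp.allclose(C, C_swapped)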
@@ -153,28 +153,28 @@ def _dense_bwd_rule(
 
     # GEMM NT
     # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim
-    g_contracting_dim = tuple(
+    g_constracting_dim = tuple(
         range(grad.ndim - len(kernel_shape) + len(fwd_k_contracting_dims), grad.ndim)
     )
     # k_non_contracting_dims
-    k_contracting_dim = tuple(
+    k_constracting_dim = tuple(
         dim for dim in range(len(kernel_shape)) if dim not in fwd_k_contracting_dims
     )
     dgrad = tex.gemm(
         casted_grad.get_rowwise_tensor(),
         rowwise_casted_kernel,
-        (g_contracting_dim, k_contracting_dim),
+        (g_constracting_dim, k_constracting_dim),
     )
     dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes)
 
     # GEMM TN
     # x_non_contracting_dims
-    g_contracting_dim = x_contracting_dim = tuple(
+    g_constracting_dim = x_constracting_dim = tuple(
         range(0, len(x_shape) - len(fwd_x_contracting_dims))
     )
     wgrad = tex.gemm(
-        colwise_casted_x, casted_grad.get_colwise_tensor(), (x_contracting_dim, g_contracting_dim)
+        colwise_casted_x, casted_grad.get_colwise_tensor(), (x_constracting_dim, g_constracting_dim)
     )
 
     wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes)
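For the common 2-D case, the contracting-dimension arithmetic in `_dense_bwd_rule` reduces to the familiar dgrad/wgrad GEMMs. A minimal sketch with arbitrary shapes (illustrative only, not the TE API):

import jax
import jax.numpy as jnp

x = jnp.ones((4, 5))       # (M, K)
kernel = jnp.ones((5, 3))  # (K, N)
grad = jnp.ones((4, 3))    # (M, N), same shape as the forward output

# dgrad (GEMM NT): contract grad's N axis with kernel's N axis -> (M, K)
dgrad = jax.lax.dot_general(grad, kernel, (((1,), (1,)), ((), ())))
# wgrad (GEMM TN): contract x's M axis with grad's M axis -> (K, N)
wgrad = jax.lax.dot_general(x, grad, (((0,), (0,)), ((), ())))
assert dgrad.shape == x.shape and wgrad.shape == kernel.shape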
@@ -184,240 +184,135 @@ def _dense_bwd_rule(
 
 _dense.defvjp(_dense_fwd_rule, _dense_bwd_rule)
 
 
+"""
 def grouped_dense(
-    x: jnp.ndarray,
-    kernel: jnp.ndarray,
-    group_sizes: jnp.ndarray,
-    contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (1,)),
-    bias: jnp.ndarray = None,
-    precision: jax.lax.Precision = jax.lax.Precision.DEFAULT,
-    preferred_element_type: jnp.dtype = None,
-    group_offset: jnp.array = None,
-    quantizer_set: QuantizerSet = noop_quantizer_set,
+    x_list,
+    kernel_list,
+    bias_list,
+    contracting_dims_list,
+    quantizer_set_list=None,
 ):
-    """
-    Perform grouped dense (linear) layer transformation with optional quantization.
-
-    Args:
-        x: Input tensor of shape (M, K)
-        kernel: Weight matrix of shape (G, K, N)
-        group_sizes: 1D array of shape (G,) specifying the size of each group
-        contracting_dims: Tuple of sequences specifying which dimensions to contract
-                          (currently only supports ((1,), (1,)))
-        bias: Bias tensor of shape (G, N)
-        precision: JAX precision for the GEMM operation
-        preferred_element_type: Preferred data type for the output tensor
-        group_offset: 1D array containing offsets for each group (not yet implemented)
-        quantizer_set: Set of quantizers for FP8 quantization of the input and output
-
-    Returns:
-        A jnp.ndarray containing the result of the grouped linear operation
-    """
-    output = _grouped_dense(
-        x,
-        kernel,
-        group_sizes,
-        contracting_dims,
-        bias,
-        precision,
-        preferred_element_type,
-        group_offset,
-        quantizer_set,
+    # Perform grouped_dense layer transformation with optional quantization.
+    output_list = _grouped_dense(
+        x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list
     )
-    return output
+    return output_list
 
 
-@partial(jax.custom_vjp, nondiff_argnums=(3, 5, 6, 7))
-def _grouped_dense(
-    x,
-    kernel,
-    group_sizes,
-    contracting_dims,
-    bias,
-    precision,
-    preferred_element_type,
-    group_offset,
-    quantizer_set,
-):
-    output, _ = _grouped_dense_fwd_rule(
-        x,
-        kernel,
-        group_sizes,
-        contracting_dims,
-        bias,
-        precision,
-        preferred_element_type,
-        group_offset,
-        quantizer_set,
+@partial(jax.custom_vjp, nondiff_argnums=(3,))
+def _grouped_dense(x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list):
+    output_list, _ = _grouped_dense_fwd_rule(
+        x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list
     )
-    return output
+    return output_list
 
 
 def _grouped_dense_fwd_rule(
-    x,
-    kernel,
-    group_sizes,
-    contracting_dims,
-    bias,
-    precision,
-    preferred_element_type,
-    group_offset,
-    quantizer_set,
+    x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list
 ):
-    use_bias = bias is not None
-    is_noop_quantizer_set = quantizer_set == noop_quantizer_set
-
-    if is_noop_quantizer_set:
-        grouped_gemm_x = x
-        grouped_gemm_kernel = kernel
-        ctx_x = x
-        ctx_kernel = kernel
-        flatten_axis_k = None
+    use_bias = bias_list is not None
+    output_list = []
+    x_rowwise_list = []
+    x_colwise_list = []
+    kernel_colwise_list = []
+    kernel_rowwise_list = []
+    x_shape_list = []
+    kernel_shape_list = []
+    if quantizer_set_list is None:
+        x_rowwise_list = x_list
+        x_colwise_list = x_list
+        kernel_colwise_list = kernel_list
+        kernel_rowwise_list = kernel_list
+        x_shape_list = [x.shape for x in x_list]
+        kernel_shape_list = [kernel.shape for kernel in kernel_list]
     else:
-        x_contracting_dims, k_contracting_dims = contracting_dims
-        flatten_axis_x = -len(x_contracting_dims)
-        flatten_axis_k = len(k_contracting_dims) - len(kernel.shape) + 1  # +1 for G axis
-
-        assert x.ndim == 2, "Grouped dense expects a 2D input tensor of shape (M, K)"
-        assert kernel.ndim == 3, "Grouped dense expects a 3D kernel tensor of shape (G, K, N)"
-        # Expected k_contracting_dims == (1,), need to tweak it for grouped_gemm FP8 extra transpose
-        # TODO(Hua): Do we have a better way for this? What if is_gemm_with_all_layouts_supported()?
-        assert x_contracting_dims == (1,) and k_contracting_dims == (1,), (
-            "grouped_dense for FP8 can only handle x_contracting_dims=(1,) "
-            "and k_contracting_dims=(1,) for now, "
-            f"got {x_contracting_dims=} and {k_contracting_dims=}"
-        )
-        k_contracting_dims = (0,)
-        casted_x = tex.grouped_quantize(
-            x, quantizer_set.x, group_sizes, flatten_axis=flatten_axis_x
-        )
-        casted_kernel = tex.grouped_quantize(
-            kernel, quantizer_set.kernel, flatten_axis=flatten_axis_k
-        )
-        contracting_dims = (x_contracting_dims, k_contracting_dims)
-
-        # For x_contracting_dims == (1,) and k_contracting_dims == (1,), we should have
-        # rowwise_casted_x.original_shape == (M, K)
-        # colwise_casted_kernel.original_shape == (G, N, K)
-        grouped_gemm_x = casted_x.get_rowwise_tensor()
-        grouped_gemm_kernel = casted_kernel.get_colwise_tensor()
-
-        # TODO(Hua): Shall we give warning/error if not quantizer_set.x.is_2x2x()?
-        ctx_x = casted_x.get_colwise_tensor() if quantizer_set.x.is_2x2x() else None
-        ctx_kernel = casted_kernel.get_rowwise_tensor() if quantizer_set.kernel.is_2x2x() else None
-
-    output = tex.grouped_gemm(
-        grouped_gemm_x,
-        grouped_gemm_kernel,
-        group_sizes,
-        contracting_dims,
-        bias,
-        precision,
-        preferred_element_type,
-        group_offset,
+        for i in range(len(x_list)):  # pylint: disable=consider-using-enumerate
+            q_x = tex.quantize(x_list[i], quantizer_set_list[i].x)
+            q_kernel = tex.quantize(kernel_list[i], quantizer_set_list[i].kernel)
+            x_rowwise_list.append(q_x.get_rowwise_tensor())
+            x_colwise_list.append(q_x.get_colwise_tensor())
+            kernel_colwise_list.append(q_kernel.get_colwise_tensor())
+            kernel_rowwise_list.append(q_kernel.get_rowwise_tensor())
+            x_shape_list.append(x_rowwise_list[-1].data.shape)
+            kernel_shape_list.append(kernel_rowwise_list[-1].data.shape)
+
+    output_list = tex.grouped_gemm(
+        x_rowwise_list, kernel_colwise_list, contracting_dims_list, bias_list
     )
 
     ctx = (
-        group_sizes,
-        ctx_x,
-        ctx_kernel,
-        x.shape,
-        kernel.shape,
+        x_colwise_list,
+        kernel_rowwise_list,
+        x_shape_list,
+        kernel_shape_list,
         use_bias,
-        is_noop_quantizer_set,
-        quantizer_set,
-        flatten_axis_k,
+        quantizer_set_list,
     )
 
-    return output, ctx
+    return output_list, ctx
 
 
-def _grouped_dense_bwd_rule(
-    contracting_dims, precision, preferred_element_type, group_offset, ctx, grad
-):
-    fwd_x_contracting_dims, fwd_k_contracting_dims = contracting_dims
+def _grouped_dense_bwd_rule(contracting_dims_list, ctx, grad_list):
     (
-        group_sizes,
-        ctx_x,
-        ctx_kernel,
-        x_shape,
-        kernel_shape,
+        colwise_x_list,
+        rowwise_kernel_list,
+        x_shape_list,
+        kernel_shape_list,
         use_bias,
-        is_noop_quantizer_set,
-        quantizer_set,
-        flatten_axis_k,
+        quantizer_set_list,
     ) = ctx
 
-    if is_noop_quantizer_set:
-        # The 1 in range is for excluding the group dimension (shall we use the hardcoded results below?)
-        # g_contracting_dim = (1, )
-        # k_contracting_dim = (2, )
+    group_size = len(grad_list)
+    dbias_list = []
+    grad_rowwise_list = []
+    grad_colwise_list = []
+    dgrad_contracting_dims_list = []
+    wgrad_contracting_dims_list = []
+    for i in range(group_size):
+        grad = grad_list[i]
+        x_shape = x_shape_list[i]
+        kernel_shape = kernel_shape_list[i]
+        fwd_contracting_dims = contracting_dims_list[i]
+        if quantizer_set_list is None:
+            casted_grad = grad
+            dbias = tex.quantization._jax_dbias(grad)
+            grad_rowwise_list.append(grad)
+            grad_colwise_list.append(grad)
+        else:
+            quantizer_set = quantizer_set_list[i]
+            casted_grad, dbias = tex.quantize_dbias(
+                grad, is_dbias=use_bias, quantizer=quantizer_set.dgrad
+            )
+            grad_rowwise_list.append(casted_grad.get_rowwise_tensor())
+            grad_colwise_list.append(casted_grad.get_colwise_tensor())
+        dbias_list.append(dbias)
+
+        # GEMM NT
+        fwd_x_contracting_dims, fwd_k_contracting_dims = fwd_contracting_dims
         g_contracting_dim = tuple(
-            range(1 + grad.ndim - len(kernel_shape) + len(fwd_k_contracting_dims), grad.ndim)
+            range(grad.ndim - len(kernel_shape) + len(fwd_k_contracting_dims), grad.ndim)
         )
         k_contracting_dim = tuple(
-            dim for dim in range(1, len(kernel_shape)) if dim not in fwd_k_contracting_dims
+            dim for dim in range(len(kernel_shape)) if dim not in fwd_k_contracting_dims
        )
         dgrad_contracting_dims = (g_contracting_dim, k_contracting_dim)
-        dgrad_grad = grad
-        dgrad_kernel_T = ctx_kernel
+        dgrad_contracting_dims_list.append(dgrad_contracting_dims)
 
-        # g_contracting_dim = (0, )
-        # x_contracting_dim = (0, )
+        # GEMM TN
         g_contracting_dim = x_contracting_dim = tuple(
             range(0, len(x_shape) - len(fwd_x_contracting_dims))
         )
         wgrad_contracting_dims = (x_contracting_dim, g_contracting_dim)
-        wgrad_x_T = ctx_x
-        wgrad_grad = grad
-    else:
-        casted_grad = tex.grouped_quantize(
-            grad, quantizer_set.dgrad, group_sizes, flatten_axis=flatten_axis_k
-        )
-
-        # For x_contracting_dims == (1,) and k_contracting_dims == (1,), we need to use
-        # g_contracting_dim = (1,) and k_contracting_dim = (2,) to make it work after the
-        # extra transpose for FP8 in grouped_gemm
-        # TODO(Hua): Do we have a better way for this? What if is_gemm_with_all_layouts_supported()?
-        g_contracting_dim = (1,)
-        k_contracting_dim = (2,)
-        dgrad_contracting_dims = (g_contracting_dim, k_contracting_dim)
-        dgrad_grad = casted_grad.get_rowwise_tensor()
-        dgrad_kernel_T = ctx_kernel
-
-        # We need to use g_contracting_dim = (0,) and x_contracting_dim = (1,) to make it work
-        # after the extra transpose for FP8 in grouped_gemm
-        # TODO(Hua): Do we have a better way for this? What if is_gemm_with_all_layouts_supported()?
-        g_contracting_dim = (0,)
-        x_contracting_dim = (1,)
-        wgrad_contracting_dims = (x_contracting_dim, g_contracting_dim)
-        wgrad_x_T = ctx_x
-        wgrad_grad = casted_grad.get_colwise_tensor()
-
-    dgrad = tex.grouped_gemm(
-        dgrad_grad,
-        dgrad_kernel_T,
-        group_sizes,
-        dgrad_contracting_dims,
-        precision=precision,
-        preferred_element_type=preferred_element_type,
-        group_offset=group_offset,
-    )
+        wgrad_contracting_dims_list.append(wgrad_contracting_dims)
 
-    wgrad = tex.grouped_gemm(
-        wgrad_x_T,
-        wgrad_grad,
-        group_sizes,
-        wgrad_contracting_dims,
-        precision=precision,
-        preferred_element_type=preferred_element_type,
-        group_offset=group_offset,
+    dgrad_list = tex.grouped_gemm(
+        grad_rowwise_list, rowwise_kernel_list, dgrad_contracting_dims_list
     )
+    wgrad_list = tex.grouped_gemm(colwise_x_list, grad_colwise_list, wgrad_contracting_dims_list)
 
-    group_sizes_grad = None
-    dbias = tex.grouped_dbias(grad, group_sizes) if use_bias else None
-
-    return dgrad, wgrad, group_sizes_grad, dbias, quantizer_set
+    return dgrad_list, wgrad_list, dbias_list, quantizer_set_list
 
 
 _grouped_dense.defvjp(_grouped_dense_fwd_rule, _grouped_dense_bwd_rule)
+"""
@@ -127,16 +127,14 @@ class BlockScaleDequantizer(Dequantizer):
     def dequantize(scaled_tensor):
         """Dequantize a tensor using block scaling.
 
-        This function dequantizes a tensor that was quantized using block scaling
-        by applying the inverse scaling factor to each block of data.
-
         Args:
-            data: The quantized tensor data
-            scale_inv: The inverse scaling factors
-            dq_dtype: The data type for dequantized values
-            scaling_mode: The scaling mode used for quantization
-            is_colwise: Whether the scaling is column-wise
-            flatten_axis: The axis along which the tensor could be flattened to 2D
+            scaled_tensor: The quantized tensor to dequantize
 
         Returns:
-            The dequantized tensor
+            The dequantized tensor in the specified data type
         """
         return BlockScaleDequantizer._dequantize_func(
             scaled_tensor.data,
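Conceptually, the block-scale dequantization documented above is a broadcasted multiply: each block of 32 consecutive quantized values along the last axis shares one inverse scale. A hedged sketch with illustrative names (the real `_dequantize_func` also handles column-wise layouts and flatten axes, which this omits):

import jax.numpy as jnp

BLOCK = 32  # MXFP8 1D scaling block size

def dequantize_block_scaled(data, scale_inv, dq_dtype=jnp.bfloat16):
    # data: (M, K); scale_inv: (M, K // BLOCK), one inverse scale per block
    scales = jnp.repeat(scale_inv, BLOCK, axis=1)
    return (data.astype(jnp.float32) * scales).astype(dq_dtype)

out = dequantize_block_scaled(jnp.ones((4, 64)), jnp.full((4, 2), 0.5))
print(out.dtype, out[0, 0])  # bfloat16 0.5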