Unverified Commit 5d01ef21 authored by Phuong Nguyen's avatar Phuong Nguyen Committed by GitHub
Browse files

[JAX] GroupedDense v.2 without dynamic shape (#1721)



* Implemented GroupedDense and TestGroupedDense for BF16, FP16, and FP8 
* Fix GroupedGemmFFI cuBLAS workspace alignment bug
Signed-off-by: default avatarHua Huang <huah@nvidia.com>
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
parent c293d3a8
......@@ -40,10 +40,11 @@ from transformer_engine.jax.quantize import (
ScalingMode,
QuantizerFactory,
QuantizeLayout,
noop_quantizer_set,
)
from transformer_engine.jax.quantize import helper
from transformer_engine.jax.activation import activation
from transformer_engine.jax.dense import dense
from transformer_engine.jax.dense import dense, grouped_dense
from transformer_engine.jax.layernorm_dense import layernorm_dense
GEMM_CASES = [
......@@ -1204,24 +1205,6 @@ class TestFusedDense:
assert_allclose(prim_x_grad, ref_x_grad, dtype=q_dtype)
# This function is modified from transformer_engine/jax/cpp_extensions/gemm.py::_jax_gemm()
def _quantize_gemm_pair(lhs, rhs, contracting_dims, lhs_quantizer, rhs_quantizer):
((lhs_contract_dim,), (rhs_contract_dim,)) = contracting_dims
lhs_is_rowwise = lhs_contract_dim == lhs.ndim - 1
rhs_is_rowwise = rhs_contract_dim == rhs.ndim - 1
lhs_q = lhs_quantizer.quantize(
lhs,
is_rowwise=lhs_is_rowwise,
is_colwise=not lhs_is_rowwise,
)
rhs_q = rhs_quantizer.quantize(
rhs,
is_rowwise=rhs_is_rowwise,
is_colwise=not rhs_is_rowwise,
)
return lhs_q, rhs_q
# E5M2 * E5M2 is not supported
fwd_bwd_dtypes = [
[jnp.float8_e4m3fn, jnp.float8_e4m3fn],
......@@ -1229,219 +1212,194 @@ fwd_bwd_dtypes = [
[jnp.float8_e5m2, jnp.float8_e4m3fn],
]
"""
@pytest_parametrize_wrapper(
"shape_list", [[(512, 128, 256), (256, 128, 256), (256, 128, 128), (512, 256, 128)]]
)
GROUPED_DENSE_INPUT_SHAPES = [
# (n_groups, m, n, k), the actual m will be multiplied by 32
(5, 32, 128, 64), # Test the case where n_groups is not a multiple of 4
(8, 64, 32, 128),
(8, 64, 128, 256),
]
@pytest_parametrize_wrapper("input_shape", GROUPED_DENSE_INPUT_SHAPES)
class TestGroupedDense:
def _ref_grouped_gemm_with_jnp_dot(self, lhs_list, rhs_list, contracting_dims_list):
ref_out_list = []
for lhs, rhs, contracting_dims in zip(lhs_list, rhs_list, contracting_dims_list):
dim_nums = (contracting_dims, ((), ()))
ref_out_list.append(jax.lax.dot_general(lhs, rhs, dim_nums))
return ref_out_list
def _generate_grouped_gemm_input(self, dtype, shape_list, layout_list):
def _ref_grouped_dense(self, lhs, rhs, bias, group_sizes, contracting_dims):
lhs_contract_dim, _ = contracting_dims
assert len(lhs_contract_dim) == 1 and lhs.ndim == 2 and rhs.ndim == 3
if bias is None:
bias = jnp.zeros((rhs.shape[0], rhs.shape[2]), dtype=lhs.dtype)
else:
assert bias.ndim == 2 and bias.shape == (rhs.shape[0], rhs.shape[2])
remaining_axis = (set(range(lhs.ndim)) - set(lhs_contract_dim)).pop()
lhs = jnp.split(lhs, jnp.cumulative_sum(group_sizes)[:-1], axis=remaining_axis)
rhs = jnp.split(rhs, rhs.shape[0], axis=0)
bias = jnp.split(bias, bias.shape[0], axis=0)
ref_out = []
dim_num = (contracting_dims, ((), ()))
for lhs_i, rhs_i, bias_i in zip(lhs, rhs, bias):
out_i = jax.lax.dot_general(lhs_i, rhs_i, dim_num) + jnp.expand_dims(bias_i, axis=0)
ref_out.append(jnp.squeeze(out_i))
return ref_out
def _generate_grouped_dense_input(self, dtype, input_shape, data_layout="NN", with_bias=False):
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, len(shape_list) * 2)
subkeys = jax.random.split(key, 4)
n_groups, m, n, k = input_shape
group_sizes = jnp.sort(jax.random.randint(subkeys[0], (n_groups - 1,), 0, m))
group_sizes = jnp.concatenate([jnp.array([0]), group_sizes, jnp.array([m])])
group_sizes = jnp.diff(group_sizes)
assert group_sizes.sum() == m
# *32 to make sure that input shape works for MXFP8
group_sizes = group_sizes * 32
m = m * 32
lhs_shape = (m if data_layout[0] == "N" else k, k if data_layout[0] == "N" else m)
rhs_shape = (n_groups, k if data_layout[1] == "N" else n, n if data_layout[1] == "N" else k)
bias_shape = (n_groups, n)
lhs = jax.random.uniform(subkeys[1], lhs_shape, dtype=dtype)
rhs = jax.random.uniform(subkeys[2], rhs_shape, dtype=dtype)
bias = jax.random.uniform(subkeys[3], bias_shape, dtype=dtype) if with_bias else None
lhs_list, rhs_list, contracting_dims_list = [], [], []
for i, ((m, n, k), data_layout) in enumerate(zip(shape_list, layout_list)):
lhs = jax.random.uniform(
subkeys[2 * i],
(m if data_layout[0] == "N" else k, k if data_layout[0] == "N" else m),
dtype=dtype,
)
rhs = jax.random.uniform(
subkeys[2 * i + 1],
(k if data_layout[1] == "N" else n, n if data_layout[1] == "N" else k),
dtype=dtype,
)
lhs_contracting_dim = (1,) if data_layout[0] == "N" else (0,)
rhs_contracting_dim = (0,) if data_layout[1] == "N" else (1,)
rhs_contracting_dim = (1,) if data_layout[1] == "N" else (2,)
contracting_dims = (lhs_contracting_dim, rhs_contracting_dim)
lhs_list.append(lhs)
rhs_list.append(rhs)
contracting_dims_list.append(contracting_dims)
return lhs, rhs, group_sizes, contracting_dims, bias
return lhs_list, rhs_list, contracting_dims_list
def _assert_grouped_gemm_output(self, out, group_sizes, ref_list, dtype):
    """Split the flattened grouped-GEMM output by group_sizes and compare each
    chunk against the corresponding per-group reference."""
    assert out.dtype == ref_list[0].dtype
    split_points = jnp.cumsum(group_sizes)[:-1]
    for chunk, ref in zip(jnp.split(out, split_points, axis=0), ref_list):
        assert_allclose(chunk, ref, dtype=dtype)
@pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16])
@pytest_parametrize_wrapper("layout_list", [["NN", "TN", "NT", "TT"]])
def test_grouped_gemm_fp16(self, dtype, shape_list, layout_list):
lhs_list, rhs_list, contracting_dims_list = self._generate_grouped_gemm_input(
dtype, shape_list, layout_list
@pytest_parametrize_wrapper("layout", ["NN"])
def test_grouped_gemm_fp16(self, dtype, input_shape, layout):
lhs, rhs, group_sizes, contracting_dims, _ = self._generate_grouped_dense_input(
dtype, input_shape, layout
)
ref_out = self._ref_grouped_gemm_with_jnp_dot(lhs_list, rhs_list, contracting_dims_list)
primitive_out = tex.grouped_gemm(lhs_list, rhs_list, contracting_dims_list)
for i in range(len(shape_list)):
assert_allclose(primitive_out[i], ref_out[i], dtype=dtype)
ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims)
prim_out = tex.grouped_gemm(lhs, rhs, group_sizes, contracting_dims)
self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, dtype)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize("fwd_bwd_dtype", fwd_bwd_dtypes)
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@pytest_parametrize_wrapper("layout_list", [["NN", "TN", "NT", "TT"]])
def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, shape_list, layout_list):
@pytest_parametrize_wrapper("layout", ["NN"])
def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape, layout):
if scaling_mode == ScalingMode.MXFP8_1D_SCALING:
pytest.skip("MXFP8 is not supported in grouped_gemm yet")
fwd_dtype, bwd_dtype = fwd_bwd_dtype
quantizer_set = QuantizerFactory.create_set(
scaling_mode=scaling_mode, fwd_dtype=fwd_dtype, bwd_dtype=bwd_dtype, is_2x2x=False
scaling_mode=scaling_mode,
fwd_dtype=fwd_dtype,
bwd_dtype=bwd_dtype,
is_2x2x=False,
n_groups=input_shape[0],
)
# quantizer_set.{x, kernel} has fwd_dtype, while quantizer_set.grad has bwd_dtype
# We want to test E4M3 * E5M2, manually set the quantizer_set.kernel.q_dtype to bwd_dtype
quantizer_set.kernel.q_dtype = bwd_dtype
for quantizer in quantizer_set.kernel.quantizers:
quantizer.q_dtype = bwd_dtype
out_dtype = jnp.bfloat16
lhs_list, rhs_list, contracting_dims_list = self._generate_grouped_gemm_input(
out_dtype, shape_list, layout_list
lhs, rhs, group_sizes, contracting_dims, _ = self._generate_grouped_dense_input(
out_dtype, input_shape, layout
)
q_lhs_list = []
q_rhs_list = []
for lhs, rhs, contracting_dims in zip(lhs_list, rhs_list, contracting_dims_list):
# quantizer_set.x and quantizer_set.kernel have the same q_dtype, we want to
# test the case where lhs and rhs have different q_dtypes
q_lhs, q_rhs = _quantize_gemm_pair(
lhs, rhs, contracting_dims, quantizer_set.x, quantizer_set.dgrad
ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims)
prim_out = tex.grouped_gemm(
lhs, rhs, group_sizes, contracting_dims, quantizer_set=quantizer_set
)
q_lhs_list.append(q_lhs)
q_rhs_list.append(q_rhs)
ref_out = self._ref_grouped_gemm_with_jnp_dot(lhs_list, rhs_list, contracting_dims_list)
primitive_out = tex.grouped_gemm(q_lhs_list, q_rhs_list, contracting_dims_list)
allclose_dtype = jnp.float8_e4m3fn
if fwd_dtype == jnp.float8_e5m2 or bwd_dtype == jnp.float8_e5m2:
if jnp.float8_e5m2 in fwd_bwd_dtype:
allclose_dtype = jnp.float8_e5m2
for i in range(len(shape_list)):
assert_allclose(primitive_out[i], ref_out[i], dtype=allclose_dtype)
@pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16])
def test_grouped_dense_grad_fp16(self, dtype, shape_list):
group_size = len(shape_list)
layout_list = ["NN" for _ in range(group_size)]
self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, allclose_dtype)
x_list, kernel_list, contracting_dims_list = self._generate_grouped_gemm_input(
dtype, shape_list, layout_list
)
bias_list = []
key = jax.random.PRNGKey(1)
for shape in shape_list:
n = shape[1]
bias = jax.random.uniform(key, n, dtype=dtype)
bias_list.append(bias)
def ref_func(x_list, kernel_list, bias_list, contracting_dims_list):
out_list = []
for i in range(len(x_list)):
out_list.append(
dense(
x_list[i],
kernel_list[i],
bias_list[i],
contracting_dims=contracting_dims_list[i],
)
)
def _ref_sum_grouped_dense(self, x, kernel, bias, group_sizes, contracting_dims):
    """Scalar reduction of the reference grouped dense output, used as the
    objective for value_and_grad in the gradient tests."""
    group_outputs = self._ref_grouped_dense(x, kernel, bias, group_sizes, contracting_dims)
    # Note: we use jnp.sum instead of jnp.mean to make the gradient larger
    # and prevent them from being clamp to zero
    per_group_sums = jnp.asarray([jnp.sum(out) for out in group_outputs])
    return jnp.sum(per_group_sums)
def primitive_func(x_list, kernel_list, bias_list, contracting_dims_list):
out_list = grouped_dense(x_list, kernel_list, bias_list, contracting_dims_list)
out_sum_list = [jnp.sum(out) for out in out_list]
return jnp.sum(jnp.asarray(out_sum_list))
def _primitive_sum_grouped_dense(
    self, x, kernel, bias, group_sizes, contracting_dims, quantizer_set=noop_quantizer_set
):
    """Scalar reduction of the TE grouped_dense primitive output; mirrors
    _ref_sum_grouped_dense so gradients can be compared one-to-one."""
    outputs = grouped_dense(
        x, kernel, group_sizes, contracting_dims, bias=bias, quantizer_set=quantizer_set
    )
    return jnp.asarray(outputs).sum()
value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2))
value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2))
@pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16])
def test_grouped_dense_grad_fp16(self, dtype, input_shape):
x, kernel, group_sizes, contracting_dims, bias = self._generate_grouped_dense_input(
dtype,
input_shape,
with_bias=True,
)
value_n_grad_ref_func = value_and_grad(self._ref_sum_grouped_dense, (0, 1, 2))
value_n_grad_prim_func = value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2))
ref_out_mean, (ref_dgrad_list, ref_wgrad_list, ref_dbias_list) = value_n_grad_ref_func(
x_list, kernel_list, bias_list, contracting_dims_list
ref_out_sum, (ref_dgrad, ref_wgrad, ref_dbias) = value_n_grad_ref_func(
x, kernel, bias, group_sizes, contracting_dims
)
primitive_out_mean, (primitive_dgrad_list, primitive_wgrad_list, primitive_dbias_list) = (
value_n_grad_primitive_func(x_list, kernel_list, bias_list, contracting_dims_list)
prim_out_sum, (prim_dgrad, prim_wgrad, prim_dbias) = value_n_grad_prim_func(
x, kernel, bias, group_sizes, contracting_dims
)
assert_allclose(primitive_out_mean, ref_out_mean, dtype=dtype)
for i in range(group_size):
assert_allclose(primitive_dgrad_list[i], ref_dgrad_list[i], dtype=dtype)
assert_allclose(primitive_wgrad_list[i], ref_wgrad_list[i], dtype=dtype)
assert_allclose(primitive_dbias_list[i], ref_dbias_list[i], dtype=dtype)
assert_allclose(prim_out_sum, ref_out_sum, dtype=dtype)
assert_allclose(prim_dgrad, ref_dgrad, dtype=dtype)
assert_allclose(prim_wgrad, ref_wgrad, dtype=dtype)
assert_allclose(prim_dbias, ref_dbias, dtype=dtype)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize("fwd_bwd_dtype", fwd_bwd_dtypes)
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, shape_list):
group_size = len(shape_list)
layout_list = ["NN" for _ in range(group_size)]
fwd_dtype, bwd_dtype = fwd_bwd_dtype
if fwd_dtype == jnp.float8_e5m2:
pytest.skip("We never use E5M2 for fwd_dtype in training")
# Question: should we use different quantizers for different groups?
ref_quantizer_set_list = []
quantizer_set_list = []
for _ in range(group_size):
ref_quantizer_set = QuantizerFactory.create_set(
scaling_mode=scaling_mode, fwd_dtype=fwd_dtype, bwd_dtype=bwd_dtype, is_2x2x=True
)
ref_quantizer_set_list.append(ref_quantizer_set)
quantizer_set = QuantizerFactory.create_set(
scaling_mode=scaling_mode, fwd_dtype=fwd_dtype, bwd_dtype=bwd_dtype, is_2x2x=True
@pytest.mark.parametrize(
"fwd_bwd_dtype",
[(jnp.float8_e4m3fn, jnp.float8_e4m3fn), (jnp.float8_e4m3fn, jnp.float8_e5m2)],
)
quantizer_set_list.append(quantizer_set)
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape):
if scaling_mode == ScalingMode.MXFP8_1D_SCALING:
pytest.skip("MXFP8 is not supported in grouped_dense yet")
out_dtype = jnp.bfloat16
x_list, kernel_list, contracting_dims_list = self._generate_grouped_gemm_input(
out_dtype, shape_list, layout_list
)
bias_list = []
key = jax.random.PRNGKey(1)
for shape in shape_list:
n = shape[1]
bias = jax.random.uniform(key, n, dtype=out_dtype)
bias_list.append(bias)
def ref_func(x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list):
out_list = []
for i in range(len(x_list)):
out_list.append(
dense(
x_list[i],
kernel_list[i],
bias_list[i],
contracting_dims=contracting_dims_list[i],
quantizer_set=quantizer_set_list[i],
)
fwd_dtype, bwd_dtype = fwd_bwd_dtype
dtype = jnp.bfloat16
x, kernel, group_sizes, contracting_dims, bias = self._generate_grouped_dense_input(
dtype,
input_shape,
with_bias=True,
)
# Note: we use jnp.sum instead of jnp.mean to make the gradient larger
# and prevent them from being clamp to zero
out_sum_list = [jnp.sum(out) for out in out_list]
return jnp.sum(jnp.asarray(out_sum_list))
def primitive_func(
x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list
):
out_list = grouped_dense(
x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list
quantizer_set = QuantizerFactory.create_set(
scaling_mode=scaling_mode,
fwd_dtype=fwd_dtype,
bwd_dtype=bwd_dtype,
is_2x2x=True,
n_groups=group_sizes.size,
)
out_sum_list = [jnp.sum(out) for out in out_list]
return jnp.sum(jnp.asarray(out_sum_list))
value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2))
value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2))
value_n_grad_ref_func = value_and_grad(self._ref_sum_grouped_dense, (0, 1, 2))
value_n_grad_prim_func = value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2))
ref_out_mean, (ref_dgrad_list, ref_wgrad_list, ref_dbias_list) = value_n_grad_ref_func(
x_list, kernel_list, bias_list, contracting_dims_list, ref_quantizer_set_list
)
primitive_out_mean, (primitive_dgrad_list, primitive_wgrad_list, primitive_dbias_list) = (
value_n_grad_primitive_func(
x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list
ref_out_sum, (ref_dgrad, ref_wgrad, ref_dbias) = value_n_grad_ref_func(
x,
kernel,
bias,
group_sizes,
contracting_dims,
)
prim_out_sum, (prim_dgrad, prim_wgrad, prim_dbias) = value_n_grad_prim_func(
x, kernel, bias, group_sizes, contracting_dims, quantizer_set=quantizer_set
)
allclose_dtype = jnp.float8_e4m3fn
if fwd_dtype == jnp.float8_e5m2 or bwd_dtype == jnp.float8_e5m2:
allclose_dtype = jnp.float8_e5m2
assert_allclose(primitive_out_mean, ref_out_mean, dtype=allclose_dtype)
for i in range(group_size):
assert_allclose(primitive_dgrad_list[i], ref_dgrad_list[i], dtype=allclose_dtype)
assert_allclose(primitive_wgrad_list[i], ref_wgrad_list[i], dtype=allclose_dtype)
assert_allclose(primitive_dbias_list[i], ref_dbias_list[i], dtype=allclose_dtype)
"""
assert_allclose(prim_out_sum, ref_out_sum, dtype=fwd_dtype)
assert_allclose(prim_dgrad, ref_dgrad, dtype=bwd_dtype)
assert_allclose(prim_wgrad, ref_wgrad, dtype=bwd_dtype)
assert_allclose(prim_dbias, ref_dbias, dtype=dtype)
......@@ -525,6 +525,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
const auto B_alignment = _getAlignment(reinterpret_cast<uintptr_t>(param.B));
const auto C_alignment = _getAlignment(reinterpret_cast<uintptr_t>(C));
const auto D_alignment = _getAlignment(reinterpret_cast<uintptr_t>(D));
const auto workspace_alignment = _getAlignment(reinterpret_cast<uintptr_t>(workspace));
NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES, &A_alignment, sizeof(A_alignment)));
NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
......@@ -533,6 +534,8 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES, &C_alignment, sizeof(C_alignment)));
NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES, &D_alignment, sizeof(D_alignment)));
NVTE_CHECK(workspace_alignment % 256 == 0,
"cuBLAS workspace pointer must be aligned to 256 bytes, got ", workspace_alignment);
const auto status =
cublasLtMatmulAlgoGetHeuristic(handle, operationDesc, Adesc, Bdesc, Cdesc, Ddesc, preference,
......
......@@ -6,22 +6,28 @@
from typing import Tuple, Sequence, Union, Dict
from functools import partial, reduce
import operator
import math
import jax
import jax.numpy as jnp
from transformer_engine_jax import get_device_compute_capability
from .base import BasePrimitive, register_primitive
from .quantization import grouped_quantize
from ..quantize import (
ScaledTensor,
GroupedScaledTensor1x,
ScalingMode,
Quantizer,
GroupedQuantizer,
QuantizeConfig,
QuantizerSet,
QuantizeLayout,
noop_quantizer_set,
)
__all__ = ["gemm"]
__all__ = ["gemm", "grouped_gemm", "is_gemm_with_all_layouts_supported"]
num_cublas_streams = 4
......@@ -34,6 +40,11 @@ def get_cublas_workspace_size_bytes() -> None:
return 4_194_304
def is_gemm_with_all_layouts_supported() -> bool:
    """Return True if the device supports GEMM with all operand layouts.

    Compute capability 10.0+ (Blackwell) supports all layouts; earlier
    architectures do not. Queries device 0 only.
    """
    return get_device_compute_capability(0) >= 100
class GroupedGemmPrimitive(BasePrimitive):
"""
Primitive for grouped GEMM
......@@ -41,73 +52,139 @@ class GroupedGemmPrimitive(BasePrimitive):
name = "te_grouped_gemm_ffi"
multiple_results = True
impl_static_args = ()
impl_static_args = (7, 8, 9, 10, 11, 12, 13, 14, 15)
inner_primitive = None
outer_primitive = None
@staticmethod
def abstract(*args, num_gemms, scaling_mode, out_dtype, has_bias):
def abstract(
lhs_data_aval,
lhs_scale_inv_aval,
rhs_data_aval,
rhs_scale_inv_aval,
bias_aval,
group_sizes_aval,
group_offset_aval,
*,
M,
N,
K,
lhs_is_trans,
rhs_is_trans,
scaling_mode,
out_dtype,
has_bias,
is_grouped_dense_wgrad,
):
"""
Grouped GEMM operation.
Args:
*args: Size num_gemms * 4 or num_gemms * 5 depending on has_bias:
args[ 0 : num_gemms] are the lhs tensors,
args[ num_gemms : 2*num_gemms] are the rhs tensors,
args[2*num_gemms : 3*num_gemms] are the lhs scale_inv tensors,
args[3*num_gemms : 4*num_gemms] are the rhs scale_inv tensors,
args[4*num_gemms : 5*num_gemms] are the bias tensors if has_bias is True.
num_gemms: Number of GEMM operations to perform.
scaling_mode: Scaling mode for the GEMM operations.
out_dtype: Data type of the output tensors.
has_bias: Boolean indicating if bias tensors are provided.
lhs_data: Left-hand side input matrix data, 1D flattened array
lhs_scale_inv: Left-hand side input scale_inv matrix, 1D flattened array
rhs_data: Right-hand side input matrix data, 1D flattened array
rhs_scale_inv: Right-hand side input scale_inv matrix, 1D flattened array
bias: Bias matrix of shape (G, N)
group_sizes: 1D array containing the sizes of each group
group_offset: 1D array containing offsets for each group (not yet implemented)
M: Number of rows in the output matrix
N: Number of columns in the output matrix
K: Number of columns in the left-hand side matrix
lhs_is_trans: Boolean indicating if the left-hand side matrix is transposed
rhs_is_trans: Boolean indicating if the right-hand side matrix is transposed
scaling_mode: Scaling mode for the GEMM operations
out_dtype: Data type of the output tensors
has_bias: Boolean indicating if bias tensors are provided
is_grouped_dense_wgrad: Boolean indicating if this is a grouped dense wgrad operation
where both lhs and rhs are 2D matrices and output is (G, M, N)
Returns:
A tuple of ShapedArray objects of size num_gemms+1:
ret[0 : num_gemms]: GEMM output tensors,
ret[num_gemms]:workspace tensor.
A jnp.ndarray containing the result of the grouped GEMM operation
"""
del scaling_mode
expected_num_args = 5 * num_gemms if has_bias else 4 * num_gemms
assert (
len(args) == expected_num_args
), f"Expected {expected_num_args} input arguments, but got {len(args)}"
A_list = args[0:num_gemms]
B_list = args[num_gemms : 2 * num_gemms]
# A and B have shapes [1, m, k] and [1, n, k]
out_list_aval = tuple(
jax.core.ShapedArray((A.shape[1], B.shape[1]), dtype=out_dtype)
for A, B in zip(A_list, B_list)
)
del lhs_data_aval, rhs_data_aval, bias_aval, group_offset_aval
del K, lhs_is_trans, rhs_is_trans, scaling_mode, has_bias
# TODO(Phuong): move some shape checks from Cpp to here
workspace_size = get_cublas_workspace_size_bytes() * num_cublas_streams
workspace_size += lhs_scale_inv_aval.size + rhs_scale_inv_aval.size
workspace_aval = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8)
return (*out_list_aval, workspace_aval)
out_shape = (M, N)
if is_grouped_dense_wgrad:
out_shape = (group_sizes_aval.size, M, N)
out_aval = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype)
return (out_aval, workspace_aval)
@staticmethod
def outer_abstract(*args, **kwargs):
(out_aval, _) = GroupedGemmPrimitive.abstract(*args, **kwargs)
return out_aval
return (out_aval,)
@staticmethod
def lowering(ctx, *args, num_gemms, scaling_mode, out_dtype, has_bias):
def lowering(
ctx,
*args,
M,
N,
K,
lhs_is_trans,
rhs_is_trans,
scaling_mode,
out_dtype,
has_bias,
is_grouped_dense_wgrad,
):
del out_dtype
return jax.ffi.ffi_lowering(GroupedGemmPrimitive.name)(
ctx,
*args,
num_gemms=num_gemms,
scaling_mode=int(scaling_mode),
M=M,
N=N,
K=K,
lhs_is_trans=lhs_is_trans,
rhs_is_trans=rhs_is_trans,
scaling_mode=scaling_mode.value,
has_bias=has_bias,
is_grouped_dense_wgrad=is_grouped_dense_wgrad,
)
@staticmethod
def impl(*args, num_gemms, scaling_mode, out_dtype, has_bias):
def impl(
lhs_data,
lhs_scale_inv,
rhs_data,
rhs_scale_inv,
bias,
group_sizes,
group_offset,
M,
N,
K,
lhs_is_trans,
rhs_is_trans,
scaling_mode,
out_dtype,
has_bias,
is_grouped_dense_wgrad,
):
assert GroupedGemmPrimitive.inner_primitive is not None
out = GroupedGemmPrimitive.inner_primitive.bind(
*args,
num_gemms=num_gemms,
scaling_mode=scaling_mode.value,
(out, _) = GroupedGemmPrimitive.inner_primitive.bind(
lhs_data,
lhs_scale_inv,
rhs_data,
rhs_scale_inv,
bias,
group_sizes,
group_offset,
M=M,
N=N,
K=K,
lhs_is_trans=lhs_is_trans,
rhs_is_trans=rhs_is_trans,
scaling_mode=scaling_mode,
out_dtype=out_dtype,
has_bias=has_bias,
is_grouped_dense_wgrad=is_grouped_dense_wgrad,
)
return out[:-1] # out is [out_list, wkspace], only return out_list
return (out,)
register_primitive(GroupedGemmPrimitive)
......@@ -285,7 +362,7 @@ def gemm(
lhs: Union[jnp.ndarray, ScaledTensor],
rhs: Union[jnp.ndarray, ScaledTensor],
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)),
quantizer_set: Dict["str", Quantizer] = noop_quantizer_set,
quantizer_set: QuantizerSet = noop_quantizer_set,
) -> jnp.ndarray:
"""General matrix multiplication with optional quantization.
......@@ -310,130 +387,190 @@ def gemm(
return _jax_gemm(lhs, rhs, contracting_dims, quantizer_set)
"""
def swizzled_scale(scales):
    """Swizzle a 2D scale_inv tensor into the layout required for FP8 GEMM.

    Reorders 128x4 tiles of the input via a (rows//128, 4, 32, cols//4, 4)
    reshape and axis permutation; the output has the same shape and contains
    the same elements in swizzled order.

    Args:
        scales: 2D array with rows divisible by 128 and cols divisible by 4.

    Returns:
        The swizzled 2D array, same shape and dtype as the input.
    """
    assert scales.ndim == 2
    rows, cols = scales.shape
    # Fail fast with a clear message instead of an opaque reshape error.
    assert rows % 128 == 0 and cols % 4 == 0, (
        f"swizzled_scale expects rows % 128 == 0 and cols % 4 == 0, got {scales.shape}"
    )
    tiles = scales.reshape(rows // 128, 4, 32, cols // 4, 4)
    tiles = jnp.transpose(tiles, (0, 3, 2, 1, 4))
    return tiles.reshape(rows, cols)
def grouped_gemm(
lhs: Union[jnp.ndarray, GroupedScaledTensor1x],
rhs: Union[jnp.ndarray, GroupedScaledTensor1x],
group_sizes: jnp.ndarray,
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (2,)),
bias: jnp.ndarray = None,
precision: jax.lax.Precision = jax.lax.Precision.DEFAULT,
preferred_element_type: jnp.dtype = None,
group_offset: jnp.array = None,
quantizer_set: QuantizerSet = noop_quantizer_set,
) -> jnp.ndarray:
"""
Grouped GEMM operation.
Args:
lhs: Left-hand side input matrix, can be a jnp.ndarray or GroupedScaledTensor1x
rhs: Right-hand side input matrix, can be a jnp.ndarray or GroupedScaledTensor1x
group_sizes: 1D array containing the sizes of each group
contracting_dims: Tuple of two sequences representing the contracting dimensions
bias: Bias tensor of shape (G, N)
precision: JAX precision for the GEMM operation
preferred_element_type: Preferred data type for the output tensor
group_offset: 1D array containing offsets for each group (not yet implemented)
quantizer_set: Set of quantizers for FP8 quantization of the input and output
def grouped_gemm(
lhs_list: List[Union[jnp.ndarray, ScaledTensor]],
rhs_list: List[Union[jnp.ndarray, ScaledTensor]],
contracting_dims_list: List[Tuple[Sequence[int], Sequence[int]]],
bias_list: List[jnp.ndarray] = None,
) -> List[jnp.ndarray]:
# Grouped GEMM for multiple pairs of tensors.
assert (
len(lhs_list) == len(rhs_list) == len(contracting_dims_list)
), "lhs_list, rhs_list, contracting_dims_list must have the same length"
num_gemms = len(lhs_list)
lhs_list_ = []
rhs_list_ = []
lhs_sinv_list_ = []
rhs_sinv_list_ = []
bias_list_ = []
for i in range(num_gemms):
lhs = lhs_list[i]
rhs = rhs_list[i]
contracting_dims = contracting_dims_list[i]
dim_nums = (contracting_dims, ((), ()))
if isinstance(lhs, ScaledTensor) and isinstance(rhs, ScaledTensor):
scaling_mode = lhs.scaling_mode
lhs_shape = lhs.data.shape
rhs_shape = rhs.data.shape
Returns:
A jnp.ndarray containing the result of the grouped GEMM operation
Note:
Tested shapes:
lhs: [M, K] or [K, N]
rhs: [G, N, K] or [G, K, N] or [G * K, N] or [N, G * K]
"""
# TODO(Phuong): implement the group_offset
group_offset = group_offset or jnp.zeros((1,), jnp.int32)
# TODO(Phuong): implement the precision
del precision
if isinstance(lhs, jnp.ndarray):
assert isinstance(rhs, jnp.ndarray)
out_dtype = lhs.dtype
lhs_shape = lhs.shape
rhs_shape = rhs.shape
lhs_data = lhs
rhs_data = rhs
lhs_scale_inv = rhs_scale_inv = jnp.empty((0,), jnp.float32)
scaling_mode = ScalingMode.NO_SCALING
elif isinstance(lhs, GroupedScaledTensor1x):
assert isinstance(rhs, GroupedScaledTensor1x)
out_dtype = lhs.dq_dtype
# For ScaledTensors and DELAYED_TENSOR_SCALING, need to handle internal data_layout
if lhs.scaling_mode.is_tensor_scaling():
lhs_shape = lhs.original_shape
rhs_shape = rhs.original_shape
lhs_data = lhs.data
rhs_data = rhs.data
lhs_scale_inv = lhs.scale_inv
rhs_scale_inv = rhs.scale_inv
assert lhs.scaling_mode == rhs.scaling_mode
scaling_mode = lhs.scaling_mode
else:
raise TypeError("Unsupported lhs type object!")
out_dtype = preferred_element_type or out_dtype
lhs_contract_dim, rhs_contract_dim = contracting_dims
lhs_is_trans = lhs_contract_dim[-1] != len(lhs_shape) - 1
lhs_flatten_axis = len(lhs_contract_dim) * (1 if lhs_is_trans else -1)
# rhs_shape [G, K, N]
rhs_is_trans = rhs_contract_dim[0] != 1
rhs_flatten_axis = -len(rhs_contract_dim) if rhs_is_trans else 1 + len(rhs_contract_dim)
is_grouped_dense_wgrad = False
if len(rhs_shape) == 2:
rhs_is_trans = rhs_contract_dim[0] != 0
is_grouped_dense_wgrad = True
# TODO(Hua): these are for fp16 dense wgrad, any better way to handle this?
if (
is_grouped_dense_wgrad
and not isinstance(lhs, ScaledTensor)
and not isinstance(rhs, ScaledTensor)
):
lhs_is_trans = True
rhs_is_trans = False
lhs_flatten_axis = 1
rhs_flatten_axis = 1
if (
not isinstance(lhs, ScaledTensor)
and not isinstance(rhs, ScaledTensor)
and quantizer_set != noop_quantizer_set
):
assert isinstance(quantizer_set.x, GroupedQuantizer)
assert type(quantizer_set.x) is type(quantizer_set.kernel)
scaling_mode = quantizer_set.x.scaling_mode
if (
# TODO(Phuong): we force Blackwell to also use NT layout for now, need to fix later
# scaling_mode.is_tensor_scaling()
# and is_gemm_with_all_layouts_supported()
scaling_mode.is_1d_block_scaling()
):
lhs_is_rowwise = rhs_is_rowwise = True
else:
lhs_is_rowwise = not lhs_is_trans
rhs_is_rowwise = lhs_is_trans
quantizer_set.x.q_layout = (
QuantizeLayout.ROWWISE if lhs_is_rowwise else QuantizeLayout.COLWISE
)
quantizer_set.kernel.q_layout = (
QuantizeLayout.ROWWISE if rhs_is_rowwise else QuantizeLayout.COLWISE
)
lhs_q = grouped_quantize(lhs, quantizer_set.x, group_sizes, lhs_flatten_axis)
rhs_q = grouped_quantize(
rhs, quantizer_set.kernel, group_sizes=None, flatten_axis=rhs_flatten_axis
)
lhs_data = lhs_q.data
rhs_data = rhs_q.data
lhs_scale_inv = lhs_q.scale_inv
rhs_scale_inv = rhs_q.scale_inv
assert not (
lhs.data.dtype == jnp.float8_e5m2 and rhs.data.dtype == jnp.float8_e5m2
lhs_data.dtype == jnp.float8_e5m2 and rhs_data.dtype == jnp.float8_e5m2
), "FP8 GEMM does not support E5M2 * E5M2"
((lhs_contract_dim,), (rhs_contract_dim,)) = contracting_dims
if lhs.data_layout == "T":
lhs_contract_dim = (lhs_contract_dim - 1) % lhs.data.ndim
if rhs.data_layout == "T":
rhs_contract_dim = (rhs_contract_dim - 1) % rhs.data.ndim
dim_nums = ((lhs_contract_dim,), (rhs_contract_dim,)), ((), ())
else:
# For jnp.ndarray, only consider contracting_dims, data_layout is always NN
scaling_mode = ScalingMode.NO_SCALING
lhs_shape = lhs.shape
rhs_shape = rhs.shape
out_dtype = lhs.dtype
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums
lhs_dn = (lhs_contract, lhs_batch)
rhs_dn = (rhs_contract, rhs_batch)
lhs_remain_shape = _calculate_remaining_shape(lhs_shape, lhs_contract)
rhs_remain_shape = _calculate_remaining_shape(rhs_shape, rhs_contract)
# Note: do not squeeze() for {lhs, rhs}_3d, it will trigger a D2D memcpy
if scaling_mode == ScalingMode.NO_SCALING:
lhs_3d = _shape_normalization(lhs, lhs_dn)
rhs_3d = _shape_normalization(rhs, rhs_dn)
elif scaling_mode.is_tensor_scaling():
lhs_3d = _shape_normalization(lhs.data, lhs_dn, lhs.data_layout == "N")
rhs_3d = _shape_normalization(rhs.data, rhs_dn, rhs.data_layout == "T")
elif scaling_mode == ScalingMode.MXFP8_1D_SCALING:
lhs_3d = _shape_normalization(lhs.data, lhs_dn)
rhs_3d = _shape_normalization(rhs.data, rhs_dn)
lhs_scale_inv = _shape_normalization(lhs.scale_inv, lhs_dn)
rhs_scale_inv = _shape_normalization(rhs.scale_inv, rhs_dn)
# swizzled_scale requires a matrix
lhs_scale_inv = swizzled_scale(lhs_scale_inv.squeeze())
rhs_scale_inv = swizzled_scale(rhs_scale_inv.squeeze())
# Only support FP8 GEMM with NT layout on Hopper and other earlier GPUs
# thus additional transpose is required
# TODO(Phuong): we force Blackwell to also use NT layout for now, need to fix later
if scaling_mode.is_tensor_scaling(): # and not is_gemm_with_all_layouts_supported():
lhs_is_trans = False
rhs_is_trans = True
if isinstance(lhs, ScaledTensor) and isinstance(rhs, ScaledTensor):
lhs_layout_is_T = lhs.data_layout == "T"
rhs_layout_is_T = rhs.data_layout == "T"
else:
raise NotImplementedError("Unsupported ScalingMode: {scaling_mode}")
# Note: already_transposed doesn't matter for the output shape
# x.shape = [B, D1, D2]
# contracting_dims = (2, ) --> output.shape = [1, B * D1, D2]
# contracting_dims = (0, 1, ) --> output.shape = [1, D2, B * D1]
# x.shape = [D1, D2]
# contracting_dims = (1, ) --> output.shape = [1, D1, D2]
# contracting_dims = (0, ) --> output.shape = [1, D2, D1]
bm = lhs_remain_shape[0]
bn = rhs_remain_shape[0]
kl = lhs_3d.shape[-1]
kr = rhs_3d.shape[-1]
assert kl == kr, f"After shape normalization, contracting dim size mismatch: {kl} != {kr}"
if (bm % 16 != 0) or (bn % 16 != 0) or (kl % 16 != 0):
print("grouped_gemm input pair {i} has invalid problem shape for lowering: ")
print(f"m = {bm}, n = {bn}, k = {kl}; ")
print("cuBLAS requires the problem shapes being multiples of 16")
assert (bm % 16 == 0) and (bn % 16 == 0) and (kl % 16 == 0)
lhs_list_.append(lhs_3d)
rhs_list_.append(rhs_3d)
if scaling_mode == ScalingMode.NO_SCALING:
lhs_sinv_list_.append(jnp.ones(1, dtype=jnp.float32))
rhs_sinv_list_.append(jnp.ones(1, dtype=jnp.float32))
if scaling_mode.is_tensor_scaling():
lhs_sinv_list_.append(lhs.scale_inv)
rhs_sinv_list_.append(rhs.scale_inv)
if scaling_mode == ScalingMode.MXFP8_1D_SCALING:
lhs_sinv_list_.append(lhs_scale_inv)
rhs_sinv_list_.append(rhs_scale_inv)
if bias_list is not None:
bias_list_.append(bias_list[i])
out_list = GroupedGemmPrimitive.outer_primitive.bind(
*lhs_list_,
*rhs_list_,
*lhs_sinv_list_,
*rhs_sinv_list_,
*bias_list_,
num_gemms=num_gemms,
scaling_mode=scaling_mode,
lhs_layout_is_T = lhs_q.data_layout == "T"
rhs_layout_is_T = rhs_q.data_layout == "T"
lhs_ndim = len(lhs_shape)
rhs_ndim = len(rhs_shape)
if lhs_layout_is_T:
lhs_contract_dim = tuple((lhs_ndim - 1 - i) % lhs_ndim for i in lhs_contract_dim)
if rhs_layout_is_T:
rhs_contract_dim = tuple((rhs_ndim - 1 - i) % rhs_ndim for i in rhs_contract_dim)
lhs_data = _shape_normalization(lhs_data, (lhs_contract_dim, ()), not lhs_layout_is_T)
rhs_data = _shape_normalization(rhs_data, (rhs_contract_dim, ()), rhs_layout_is_T)
# Calling GroupedGEMM Custom Call
K_lhs = math.prod(lhs_shape[i] for i in lhs_contract_dim)
K_rhs = math.prod(rhs_shape[i] for i in rhs_contract_dim)
assert K_lhs == K_rhs
M = math.prod(_calculate_remaining_shape(lhs_shape, lhs_contract_dim))
N = math.prod(_calculate_remaining_shape(rhs_shape, rhs_contract_dim)[1:]) # Exclude G
if is_grouped_dense_wgrad:
N = math.prod(_calculate_remaining_shape(rhs_shape, rhs_contract_dim))
else:
assert group_sizes.size == rhs_shape[0]
assert group_offset.size == 1
has_bias = bias is not None
assert not has_bias or bias.shape == (group_sizes.size, N)
bias = jnp.empty((), jnp.float32) if bias is None else bias
# TODO(Phuong): support MXFP8_1D_SCALING
assert scaling_mode != ScalingMode.MXFP8_1D_SCALING, "MXFP8_1D_SCALING is not yet supported"
(out,) = GroupedGemmPrimitive.outer_primitive.bind(
lhs_data,
lhs_scale_inv,
rhs_data,
rhs_scale_inv,
bias,
group_sizes,
group_offset,
M=M,
N=N,
K=K_lhs,
lhs_is_trans=lhs_is_trans,
rhs_is_trans=rhs_is_trans,
scaling_mode=scaling_mode.value,
out_dtype=out_dtype,
has_bias=1 if bias_list is not None else 0,
has_bias=has_bias,
is_grouped_dense_wgrad=is_grouped_dense_wgrad,
)
return out_list
"""
return out
......@@ -47,7 +47,7 @@ else:
from jax.extend import ffi # pylint: disable=ungrouped-imports
__all__ = ["quantize", "quantize_dbias", "grouped_quantize"]
__all__ = ["quantize", "quantize_dbias", "grouped_quantize", "grouped_dbias"]
class BaseDBiasQuantizePrimitive(BasePrimitive):
......@@ -1032,3 +1032,24 @@ def grouped_quantize(
group_axis=group_axis,
)
return out
def grouped_dbias(grad: jnp.ndarray, group_sizes: jnp.ndarray) -> jnp.ndarray:
    """Compute the bias gradient for each group of rows in ``grad``.

    Rows of ``grad`` are partitioned into consecutive groups whose lengths are
    given by ``group_sizes``; the bias gradient of a group is the sum of its
    rows. The reduction is performed in float32 for accuracy and the result is
    cast back to the input dtype.

    Args:
        grad: jnp.ndarray of shape (M, N)
        group_sizes: jnp.ndarray of shape (num_groups,), with
            sum(group_sizes) == M

    Returns:
        dbias: jnp.ndarray of shape (num_groups, N), in grad.dtype
    """
    assert grad.ndim == 2, "Input grad must be a 2D tensor."
    assert group_sizes.ndim == 1, "group_sizes must be a 1D tensor."
    num_groups = group_sizes.shape[0]
    # Label each row of grad with the index of the group it belongs to.
    row_to_group = jnp.repeat(jnp.arange(num_groups), group_sizes)
    # Accumulate per-group row sums in float32, then cast back.
    summed_fp32 = jax.ops.segment_sum(
        grad.astype(jnp.float32), row_to_group, num_segments=num_groups
    )
    return summed_fp32.astype(grad.dtype)
......@@ -13,43 +13,127 @@
#include "transformer_engine/multi_stream.h"
#include "xla/ffi/api/c_api.h"
#define MXFP8_BLOCK_SIZE 32
namespace transformer_engine {
namespace jax {
Error_Type GroupedGemmFFI(cudaStream_t stream, Variadic_Buffer_Type input_list,
Variadic_Result_Type output_list, int64_t num_gemms,
JAXX_Scaling_Mode scaling_mode, int64_t has_bias) {
Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv,
Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias,
Buffer_Type group_sizes, Buffer_Type group_offset, Result_Type output,
Result_Type workspace, size_t m, size_t n, size_t k, bool lhs_is_trans,
bool rhs_is_trans, JAXX_Scaling_Mode scaling_mode, bool has_bias,
bool is_grouped_dense_wgrad) {
// Notes on matrix layouts and transpose:
// Jax uses row-major data_layout, on entering this function, each input matrix pair:
// A: row-major with size [m, k],
// B: row-major with size [n, k], needs transpose,
// A: row-major [m, k] for N - [k, m] for T
// B: row-major [k, n] for N - [n, k] for T
// on exiting this function, JAX expect:
// C: row-major with size [m, n].
// cuBLAS uses column-major data_layout, in this view, each input matrix pair:
// A: column-major with size [k, m], needs transpose,
// B: column-major with size [k, n].
// A: column-major with size [k, m] for T - [m, k] for N
// B: column-major with size [n, k] for T - [k, n] for N
//
// If we call cuBLAS GEMM for A * B, the output will be:
// C: column-major with size [m, n] --> row-major with size [n, m].
// To make the output compatible with JAX, we need to swap A and B in cuBLAS GEMM call.
if (num_gemms <= 0) {
return ffi_with_cuda_error_check();
int num_streams = nvte_get_num_compute_streams();
// Inputs
auto lhs_ptr = reinterpret_cast<uint8_t *>(lhs_data.untyped_data());
auto rhs_ptr = reinterpret_cast<uint8_t *>(rhs_data.untyped_data());
auto lhs_sinv_ptr = reinterpret_cast<uint8_t *>(lhs_sinv.untyped_data());
auto rhs_sinv_ptr = reinterpret_cast<uint8_t *>(rhs_sinv.untyped_data());
auto lhs_dtype = convert_ffi_datatype_to_te_dtype(lhs_data.element_type());
auto rhs_dtype = convert_ffi_datatype_to_te_dtype(rhs_data.element_type());
auto lhs_sinv_dtype = convert_ffi_datatype_to_te_dtype(lhs_sinv.element_type());
auto rhs_sinv_dtype = convert_ffi_datatype_to_te_dtype(rhs_sinv.element_type());
auto bias_ptr = has_bias ? reinterpret_cast<uint8_t *>(bias.untyped_data()) : nullptr;
auto bias_dtype = convert_ffi_datatype_to_te_dtype(bias.element_type());
NVTE_CHECK(group_sizes.dimensions().size() == 1);
size_t num_gemms = group_sizes.dimensions()[0];
// Outputs
auto out_ptr = reinterpret_cast<uint8_t *>(output->untyped_data());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output->element_type());
auto workspace_ptr = reinterpret_cast<uint8_t *>(workspace->untyped_data());
auto workspace_total_size = product(workspace->dimensions());
auto lhs_sinv_size = product(lhs_sinv.dimensions());
auto rhs_sinv_size = product(rhs_sinv.dimensions());
auto workspace_size = (workspace_total_size - lhs_sinv_size - rhs_sinv_size) / num_streams;
auto swizzled_lhs_sinv_ptr = workspace_ptr + workspace_size * num_streams;
auto swizzled_rhs_sinv_ptr = swizzled_lhs_sinv_ptr + lhs_sinv_size;
size_t lhs_dtype_bytes = te_dtype_bytes(lhs_dtype);
size_t rhs_dtype_bytes = te_dtype_bytes(rhs_dtype);
size_t lhs_sinv_dtype_bytes = te_dtype_bytes(lhs_sinv_dtype);
size_t rhs_sinv_dtype_bytes = te_dtype_bytes(rhs_sinv_dtype);
size_t bias_dtype_bytes = te_dtype_bytes(bias_dtype);
size_t out_dtype_bytes = te_dtype_bytes(out_dtype);
NVTE_CHECK(lhs_dtype_bytes == rhs_dtype_bytes, "sizeof(lhs_dtype) != sizeof(rhs_dtype)");
NVTE_CHECK(lhs_sinv_dtype_bytes == rhs_sinv_dtype_bytes,
"sizeof(lhs_sinv_dtype) != sizeof(rhs_sinv_dtype)");
size_t expected_lhs_size = m * k;
size_t expected_rhs_size = is_grouped_dense_wgrad ? (k * n) : (num_gemms * k * n);
size_t expected_out_size = is_grouped_dense_wgrad ? (num_gemms * m * n) : (m * n);
size_t actual_lhs_size = product(lhs_data.dimensions());
size_t actual_rhs_size = product(rhs_data.dimensions());
size_t actual_out_size = product(output->dimensions());
NVTE_CHECK(expected_lhs_size == actual_lhs_size, "Unexpected lhs size! Expect ",
expected_lhs_size, ", got ", actual_lhs_size);
if (!is_grouped_dense_wgrad) {
NVTE_CHECK(expected_rhs_size == actual_rhs_size,
"Unexpected rhs size! Expect num_gemms * n * k = ", num_gemms, " * ", n, " * ", k,
" = ", expected_rhs_size, ", got ", actual_rhs_size);
NVTE_CHECK(expected_out_size == actual_out_size, "Unexpected output size! Expect m * n = ", m,
" * ", n, " = ", expected_out_size, ", got ", actual_out_size);
} else {
NVTE_CHECK(expected_rhs_size == actual_rhs_size, "Unexpected rhs size! Expect k * n = ", k,
" * ", n, " = ", expected_rhs_size, ", got ", actual_rhs_size);
NVTE_CHECK(expected_out_size == actual_out_size,
"Unexpected output size! Expect num_gemms * m * n = ", num_gemms, " * ", m, " * ", n,
" = ", expected_out_size, ", got ", actual_out_size);
}
size_t dim_list_bytes = sizeof(int32_t) * num_gemms;
std::vector<int32_t> dim_list_host(num_gemms);
auto dim_list_ptr = reinterpret_cast<int32_t *>(group_sizes.untyped_data());
cudaMemcpyAsync(dim_list_host.data(), dim_list_ptr, dim_list_bytes, cudaMemcpyDeviceToHost,
stream);
// Note: This may break cudaGraph.
cudaStreamSynchronize(stream);
size_t sum_group_sizes = std::accumulate(dim_list_host.begin(), dim_list_host.end(), 0);
if (!is_grouped_dense_wgrad) {
NVTE_CHECK(m == sum_group_sizes, "Unexpected group_sizes! M = ", m,
", got sum(group_sizes)=", sum_group_sizes);
} else {
NVTE_CHECK(k == sum_group_sizes, "Unexpected group_sizes! K = ", k,
", got sum(group_sizes)=", sum_group_sizes);
}
size_t expected_input_size = has_bias ? 5 * num_gemms : 4 * num_gemms;
size_t expected_output_size = num_gemms + 1;
size_t actual_input_size = input_list.size();
size_t actual_output_size = output_list.size();
NVTE_CHECK(actual_input_size == expected_input_size, "Expected %zu input tensors, got %zu",
expected_input_size, actual_input_size);
NVTE_CHECK(actual_output_size == expected_output_size, "Expected %zu output tensors, got %zu",
expected_output_size, actual_output_size);
bool trans_lhs = true;
bool trans_rhs = false;
auto num_math_sm = cuda::sm_count() - getenv<int>("NVTE_EXT_MARGIN_SM", 0);
bool grad = false;
bool accumulate = false;
bool use_split_accumulator = false;
auto bias_shape = std::vector<size_t>{has_bias ? n : 0};
const int arch = cuda::sm_arch();
// It is weird that TE/Common GEMM only use colwise for MXFP8
const bool is_fp8_gemm = is_fp8_dtype(lhs_dtype);
const bool is_mxfp8_scaling = scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING;
const bool rhs_use_colwise = is_mxfp8_scaling && !rhs_is_trans;
const bool lhs_use_colwise = is_mxfp8_scaling && lhs_is_trans;
if (arch < 100 && is_fp8_gemm) {
NVTE_CHECK(!lhs_is_trans && rhs_is_trans,
"For SM90 or older archs and FP8 input, only NT (row-major) GEMM is supported, ",
"got lhs_is_trans=", lhs_is_trans, ", rhs_is_trans=", rhs_is_trans);
}
// These lists are to keep the TensorWrapper objects alive
std::vector<TensorWrapper> lhs_wrapper_list;
......@@ -67,96 +151,83 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Variadic_Buffer_Type input_list,
std::vector<NVTETensor> out_list;
std::vector<NVTETensor> workspace_list;
int lhs_list_offset = 0;
int rhs_list_offset = num_gemms;
int lhs_sinv_list_offset = 2 * num_gemms;
int rhs_sinv_list_offset = 3 * num_gemms;
int bias_list_offset = 4 * num_gemms;
int out_list_offset = 0;
for (int i = 0; i < num_gemms; i++) {
Buffer_Type lhs_i = input_list.get<Buffer_Type>(lhs_list_offset + i).value();
Buffer_Type rhs_i = input_list.get<Buffer_Type>(rhs_list_offset + i).value();
Buffer_Type lhs_sinv_i = input_list.get<Buffer_Type>(lhs_sinv_list_offset + i).value();
Buffer_Type rhs_sinv_i = input_list.get<Buffer_Type>(rhs_sinv_list_offset + i).value();
Result_Type out_i = output_list.get<Buffer_Type>(out_list_offset + i).value();
DType lhs_dtype = convert_ffi_datatype_to_te_dtype(lhs_i.element_type());
DType rhs_dtype = convert_ffi_datatype_to_te_dtype(rhs_i.element_type());
DType out_dtype = convert_ffi_datatype_to_te_dtype(out_i->element_type());
void *lhs_ptr = lhs_i.untyped_data();
void *rhs_ptr = rhs_i.untyped_data();
void *lhs_sinv_ptr = lhs_sinv_i.untyped_data();
void *rhs_sinv_ptr = rhs_sinv_i.untyped_data();
void *out_ptr = out_i->untyped_data();
// Placeholder for bias since it can be empty
DType bias_dtype = DType::kFloat32;
void *bias_ptr = nullptr;
auto lhs_shape_ = lhs_i.dimensions();
auto rhs_shape_ = rhs_i.dimensions();
// lhs and rhs has shape [1, m, k] and [1, n, k]
size_t m = lhs_shape_[1];
size_t n = rhs_shape_[1];
size_t k = lhs_shape_[2];
auto lhs_shape = std::vector<size_t>{m, k};
auto rhs_shape = std::vector<size_t>{n, k};
auto out_shape = std::vector<size_t>{n, m};
auto lhs_sinv_shape = std::vector<size_t>{1, 1};
auto rhs_sinv_shape = std::vector<size_t>{1, 1};
if (scaling_mode == JAXX_Scaling_Mode::NO_SCALING ||
scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ||
scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING) {
float *amax_dptr = nullptr;
float *scale_dptr = nullptr;
auto lhs_i_ = TensorWrapper(lhs_ptr, lhs_shape, lhs_dtype, amax_dptr, scale_dptr,
reinterpret_cast<float *>(lhs_sinv_ptr));
auto rhs_i_ = TensorWrapper(rhs_ptr, rhs_shape, rhs_dtype, amax_dptr, scale_dptr,
reinterpret_cast<float *>(rhs_sinv_ptr));
lhs_wrapper_list.push_back(std::move(lhs_i_));
rhs_wrapper_list.push_back(std::move(rhs_i_));
} else if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) {
// Note: the scale_inv array should have been swizzled in Python before lowering
auto lhs_sinv_shape_ = lhs_sinv_i.dimensions();
auto rhs_sinv_shape_ = rhs_sinv_i.dimensions();
for (int i = 0; i < 2; i++) {
lhs_sinv_shape[i] = lhs_sinv_shape_[i];
rhs_sinv_shape[i] = rhs_sinv_shape_[i];
for (size_t i = 0; i < num_gemms; i++) {
// Matrix data shapes
size_t m_i = dim_list_host[i];
auto lhs_shape = std::vector<size_t>{m_i, k};
auto rhs_shape = std::vector<size_t>{rhs_is_trans ? n : k, rhs_is_trans ? k : n};
auto out_shape = std::vector<size_t>{m_i, n};
if (is_grouped_dense_wgrad) {
size_t k_i = dim_list_host[i];
lhs_shape[0] = lhs_is_trans ? k_i : m;
lhs_shape[1] = lhs_is_trans ? m : k_i;
rhs_shape[0] = rhs_is_trans ? n : k_i;
rhs_shape[1] = rhs_is_trans ? k_i : n;
out_shape[0] = m;
out_shape[1] = n;
}
NVTEScalingMode nvte_scaling_mode = get_nvte_scaling_mode(scaling_mode);
TensorWrapper lhs_i_(nvte_scaling_mode);
TensorWrapper rhs_i_(nvte_scaling_mode);
lhs_i_.set_rowwise_data(lhs_ptr, lhs_dtype, lhs_shape);
rhs_i_.set_rowwise_data(rhs_ptr, rhs_dtype, rhs_shape);
lhs_i_.set_rowwise_scale_inv(lhs_sinv_ptr, DType::kFloat8E8M0, lhs_sinv_shape);
rhs_i_.set_rowwise_scale_inv(rhs_sinv_ptr, DType::kFloat8E8M0, rhs_sinv_shape);
// Set matrix data pointers
auto lhs_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
auto rhs_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
auto out_i = TensorWrapper(static_cast<void *>(out_ptr), out_shape, out_dtype);
void *lhs_vptr = static_cast<void *>(lhs_ptr);
void *rhs_vptr = static_cast<void *>(rhs_ptr);
if (rhs_use_colwise) // MatA to enter cuBLAS
rhs_i.set_columnwise_data(rhs_vptr, rhs_dtype, rhs_shape);
else
rhs_i.set_rowwise_data(rhs_vptr, rhs_dtype, rhs_shape);
if (lhs_use_colwise) // MatB to enter cuBLAS
lhs_i.set_columnwise_data(lhs_vptr, lhs_dtype, lhs_shape);
else
lhs_i.set_rowwise_data(lhs_vptr, lhs_dtype, lhs_shape);
lhs_wrapper_list.push_back(std::move(lhs_i_));
rhs_wrapper_list.push_back(std::move(rhs_i_));
} else {
NVTE_ERROR("Unsupported scaling mode: ", static_cast<int>(scaling_mode));
// Scale_inv shapes
auto lhs_sinv_size = std::vector<size_t>{1};
auto rhs_sinv_size = std::vector<size_t>{1};
if (is_mxfp8_scaling) {
NVTE_CHECK(k % MXFP8_BLOCK_SIZE == 0, "MXFP8 K-dim being divisble by %d (got %d)",
MXFP8_BLOCK_SIZE, k);
size_t scale_k = k / MXFP8_BLOCK_SIZE;
lhs_sinv_size[0] = m_i * scale_k;
rhs_sinv_size[0] = n * scale_k;
// Need to add swizzle here
}
auto out_i_ = TensorWrapper(out_ptr, out_shape, out_dtype);
void *pre_gelu_ptr = nullptr;
auto bias_shape = std::vector<size_t>{0};
auto pre_gelu_shape = std::vector<size_t>{0};
if (has_bias) {
auto bias_i_get = input_list.get<Buffer_Type>(bias_list_offset + i);
Buffer_Type bias_i = bias_i_get.value();
bias_ptr = bias_i.untyped_data();
bias_dtype = convert_ffi_datatype_to_te_dtype(bias_i.element_type());
bias_shape[0] = n;
// Set scale_inv pointers
void *rhs_sinv_vptr = static_cast<void *>(rhs_sinv_ptr);
void *lhs_sinv_vptr = static_cast<void *>(lhs_sinv_ptr);
if (is_fp8_gemm) {
if (rhs_use_colwise) // MatA to enter cuBLAS
rhs_i.set_columnwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_size);
else
rhs_i.set_rowwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_size);
if (lhs_use_colwise) // MatB to enter cuBLAS
lhs_i.set_columnwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_size);
else
lhs_i.set_rowwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_size);
} else {
NVTE_CHECK(scaling_mode == JAXX_Scaling_Mode::NO_SCALING,
"Unsupported scaling mode: ", static_cast<int>(scaling_mode));
}
auto bias_i = TensorWrapper(bias_ptr, bias_shape, bias_dtype);
auto pre_gelu_i = TensorWrapper(pre_gelu_ptr, pre_gelu_shape, out_dtype);
auto pre_gelu_i = TensorWrapper(nullptr, std::vector<size_t>{0}, out_dtype);
// Update pointer for the next GEMM pair
lhs_ptr += lhs_shape[0] * lhs_shape[1] * lhs_dtype_bytes;
rhs_ptr += rhs_shape[0] * rhs_shape[1] * rhs_dtype_bytes;
out_ptr += out_shape[0] * out_shape[1] * out_dtype_bytes;
if (is_fp8_gemm) {
lhs_sinv_ptr += lhs_sinv_size[0] * lhs_sinv_dtype_bytes;
rhs_sinv_ptr += rhs_sinv_size[0] * rhs_sinv_dtype_bytes;
}
if (has_bias) bias_ptr += n * bias_dtype_bytes;
out_wrapper_list.push_back(std::move(out_i_));
// Move objects to the lists to keep them alive
lhs_wrapper_list.push_back(std::move(lhs_i));
rhs_wrapper_list.push_back(std::move(rhs_i));
out_wrapper_list.push_back(std::move(out_i));
bias_wrapper_list.push_back(std::move(bias_i));
pre_gelu_wrapper_list.push_back(std::move(pre_gelu_i));
......@@ -167,11 +238,6 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Variadic_Buffer_Type input_list,
out_list.push_back(out_wrapper_list.back().data());
}
auto workspace_get = output_list.get<Buffer_Type>(num_gemms);
Result_Type workspace = workspace_get.value();
uint8_t *workspace_ptr = reinterpret_cast<uint8_t *>(workspace->untyped_data());
auto num_streams = nvte_get_num_compute_streams();
size_t workspace_size = workspace->dimensions()[0] / num_streams;
auto workspace_shape = std::vector<size_t>{workspace_size};
for (int i = 0; i < num_streams; i++) {
auto workspace_i =
......@@ -182,7 +248,7 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Variadic_Buffer_Type input_list,
}
nvte_multi_stream_cublas_gemm(rhs_list.data(), lhs_list.data(), out_list.data(), bias_list.data(),
pre_gelu_list.data(), num_gemms, trans_lhs, trans_rhs, grad,
pre_gelu_list.data(), num_gemms, rhs_is_trans, lhs_is_trans, grad,
workspace_list.data(), accumulate, use_split_accumulator,
num_math_sm, stream);
......@@ -192,11 +258,23 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Variadic_Buffer_Type input_list,
XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.RemainingArgs() // input list
.RemainingRets() // output list
.Attr<int64_t>("num_gemms")
.Arg<Buffer_Type>() // lhs_data
.Arg<Buffer_Type>() // lhs_sinv
.Arg<Buffer_Type>() // rhs_data
.Arg<Buffer_Type>() // rhs_sinv
.Arg<Buffer_Type>() // bias
.Arg<Buffer_Type>() // group_sizes
.Arg<Buffer_Type>() // group_offset
.Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // workspace
.Attr<int64_t>("M")
.Attr<int64_t>("N")
.Attr<int64_t>("K")
.Attr<bool>("lhs_is_trans")
.Attr<bool>("rhs_is_trans")
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<int64_t>("has_bias"),
.Attr<bool>("has_bias")
.Attr<bool>("is_grouped_dense_wgrad"),
FFI_CudaGraph_Traits);
} // namespace jax
......
......@@ -153,28 +153,28 @@ def _dense_bwd_rule(
# GEMM NT
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim
g_constracting_dim = tuple(
g_contracting_dim = tuple(
range(grad.ndim - len(kernel_shape) + len(fwd_k_contracting_dims), grad.ndim)
)
# k_non_contracting_dims
k_constracting_dim = tuple(
k_contracting_dim = tuple(
dim for dim in range(len(kernel_shape)) if dim not in fwd_k_contracting_dims
)
dgrad = tex.gemm(
casted_grad.get_rowwise_tensor(),
rowwise_casted_kernel,
(g_constracting_dim, k_constracting_dim),
(g_contracting_dim, k_contracting_dim),
)
dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes)
# GEMM TN
# x_non_contracting_dims
g_constracting_dim = x_constracting_dim = tuple(
g_contracting_dim = x_contracting_dim = tuple(
range(0, len(x_shape) - len(fwd_x_contracting_dims))
)
wgrad = tex.gemm(
colwise_casted_x, casted_grad.get_colwise_tensor(), (x_constracting_dim, g_constracting_dim)
colwise_casted_x, casted_grad.get_colwise_tensor(), (x_contracting_dim, g_contracting_dim)
)
wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes)
......@@ -184,135 +184,240 @@ def _dense_bwd_rule(
_dense.defvjp(_dense_fwd_rule, _dense_bwd_rule)
"""
def grouped_dense(
x_list,
kernel_list,
bias_list,
contracting_dims_list,
quantizer_set_list=None,
x: jnp.ndarray,
kernel: jnp.ndarray,
group_sizes: jnp.ndarray,
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (1,)),
bias: jnp.ndarray = None,
precision: jax.lax.Precision = jax.lax.Precision.DEFAULT,
preferred_element_type: jnp.dtype = None,
group_offset: jnp.array = None,
quantizer_set: QuantizerSet = noop_quantizer_set,
):
# Perform grouped_dense layer transformation with optional quantization.
"""
Perform grouped dense (linear) layer transformation with optional quantization.
output_list = _grouped_dense(
x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list
Args:
x: Input tensor of shape (M, K)
kernel: Weight matrix of shape (G, K, N)
group_sizes: 1D array of shape (G,) specifying the size of each group
contracting_dims: Tuple of sequences specifying which dimensions to contract
(currently only supports ((1,), (1,)))
bias: Bias tensor of shape (G, N)
precision: JAX precision for the GEMM operation
preferred_element_type: Preferred data type for the output tensor
group_offset: 1D array containing offsets for each group (not yet implemented)
quantizer_set: Set of quantizers for FP8 quantization of the input and output
Returns:
A jnp.ndarray containing the result of the grouped linear operation
"""
output = _grouped_dense(
x,
kernel,
group_sizes,
contracting_dims,
bias,
precision,
preferred_element_type,
group_offset,
quantizer_set,
)
return output_list
return output
@partial(jax.custom_vjp, nondiff_argnums=(3,))
def _grouped_dense(x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list):
output_list, _ = _grouped_dense_fwd_rule(
x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list
@partial(jax.custom_vjp, nondiff_argnums=(3, 5, 6, 7))
def _grouped_dense(
x,
kernel,
group_sizes,
contracting_dims,
bias,
precision,
preferred_element_type,
group_offset,
quantizer_set,
):
output, _ = _grouped_dense_fwd_rule(
x,
kernel,
group_sizes,
contracting_dims,
bias,
precision,
preferred_element_type,
group_offset,
quantizer_set,
)
return output_list
return output
def _grouped_dense_fwd_rule(
x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list
x,
kernel,
group_sizes,
contracting_dims,
bias,
precision,
preferred_element_type,
group_offset,
quantizer_set,
):
use_bias = bias_list is not None
output_list = []
x_rowwise_list = []
x_colwise_list = []
kernel_colwise_list = []
kernel_rowwise_list = []
x_shape_list = []
kernel_shape_list = []
if quantizer_set_list is None:
x_rowwise_list = x_list
x_colwise_list = x_list
kernel_colwise_list = kernel_list
kernel_rowwise_list = kernel_list
x_shape_list = [x.shape for x in x_list]
kernel_shape_list = [kernel.shape for kernel in kernel_list]
use_bias = bias is not None
is_noop_quantizer_set = quantizer_set == noop_quantizer_set
if is_noop_quantizer_set:
grouped_gemm_x = x
grouped_gemm_kernel = kernel
ctx_x = x
ctx_kernel = kernel
flatten_axis_k = None
else:
for i in range(len(x_list)): # pylint: disable=consider-using-enumerate
q_x = tex.quantize(x_list[i], quantizer_set_list[i].x)
q_kernel = tex.quantize(kernel_list[i], quantizer_set_list[i].kernel)
x_rowwise_list.append(q_x.get_rowwise_tensor())
x_colwise_list.append(q_x.get_colwise_tensor())
kernel_colwise_list.append(q_kernel.get_colwise_tensor())
kernel_rowwise_list.append(q_kernel.get_rowwise_tensor())
x_shape_list.append(x_rowwise_list[-1].data.shape)
kernel_shape_list.append(kernel_rowwise_list[-1].data.shape)
output_list = tex.grouped_gemm(
x_rowwise_list, kernel_colwise_list, contracting_dims_list, bias_list
x_contracting_dims, k_contracting_dims = contracting_dims
flatten_axis_x = -len(x_contracting_dims)
flatten_axis_k = len(k_contracting_dims) - len(kernel.shape) + 1 # +1 for G axis
assert x.ndim == 2, "Grouped dense expects a 2D input tensor of shape (M, K)"
assert kernel.ndim == 3, "Grouped dense expects a 3D kernel tensor of shape (G, K, N)"
# Expected k_contracting_dims == (1,), need to tweak it for grouped_gemm FP8 extra transpose
# TODO(Hua): Do we have a better way for this? What if is_gemm_with_all_layouts_supported()?
assert x_contracting_dims == (1,) and k_contracting_dims == (1,), (
"grouped_dense for FP8 can only handle x_contracting_dims=(1,) "
"and k_contracting_dims=(1,) for now, "
f"got {x_contracting_dims=} and {k_contracting_dims=}"
)
k_contracting_dims = (0,)
casted_x = tex.grouped_quantize(
x, quantizer_set.x, group_sizes, flatten_axis=flatten_axis_x
)
casted_kernel = tex.grouped_quantize(
kernel, quantizer_set.kernel, flatten_axis=flatten_axis_k
)
contracting_dims = (x_contracting_dims, k_contracting_dims)
# For x_contracting_dims == (1,) and k_contracting_dims == (1,), we should have
# rowwise_casted_x.original_shape == (M, K)
# colwise_casted_kernel.original_shape == (G, N, K)
grouped_gemm_x = casted_x.get_rowwise_tensor()
grouped_gemm_kernel = casted_kernel.get_colwise_tensor()
# TODO(Hua): Shall we give warning/error if not quantizer_set.x.is_2x2x()?
ctx_x = casted_x.get_colwise_tensor() if quantizer_set.x.is_2x2x() else None
ctx_kernel = casted_kernel.get_rowwise_tensor() if quantizer_set.kernel.is_2x2x() else None
output = tex.grouped_gemm(
grouped_gemm_x,
grouped_gemm_kernel,
group_sizes,
contracting_dims,
bias,
precision,
preferred_element_type,
group_offset,
)
ctx = (
x_colwise_list,
kernel_rowwise_list,
x_shape_list,
kernel_shape_list,
group_sizes,
ctx_x,
ctx_kernel,
x.shape,
kernel.shape,
use_bias,
quantizer_set_list,
is_noop_quantizer_set,
quantizer_set,
flatten_axis_k,
)
return output_list, ctx
return output, ctx
def _grouped_dense_bwd_rule(
contracting_dims, precision, preferred_element_type, group_offset, ctx, grad
):
fwd_x_contracting_dims, fwd_k_contracting_dims = contracting_dims
def _grouped_dense_bwd_rule(contracting_dims_list, ctx, grad_list):
(
colwise_x_list,
rowwise_kernel_list,
x_shape_list,
kernel_shape_list,
group_sizes,
ctx_x,
ctx_kernel,
x_shape,
kernel_shape,
use_bias,
quantizer_set_list,
is_noop_quantizer_set,
quantizer_set,
flatten_axis_k,
) = ctx
group_size = len(grad_list)
dbias_list = []
grad_rowwise_list = []
grad_colwise_list = []
dgrad_contracting_dims_list = []
wgrad_contracting_dims_list = []
for i in range(group_size):
grad = grad_list[i]
x_shape = x_shape_list[i]
kernel_shape = kernel_shape_list[i]
fwd_contracting_dims = contracting_dims_list[i]
if quantizer_set_list is None:
casted_grad = grad
dbias = tex.quantization._jax_dbias(grad)
grad_rowwise_list.append(grad)
grad_colwise_list.append(grad)
else:
quantizer_set = quantizer_set_list[i]
casted_grad, dbias = tex.quantize_dbias(
grad, is_dbias=use_bias, quantizer=quantizer_set.dgrad
)
grad_rowwise_list.append(casted_grad.get_rowwise_tensor())
grad_colwise_list.append(casted_grad.get_colwise_tensor())
dbias_list.append(dbias)
# GEMM NT
fwd_x_contracting_dims, fwd_k_contracting_dims = fwd_contracting_dims
if is_noop_quantizer_set:
# The 1 in range is for excluding the group dimension (shall we use the hardcoded results below?)
# g_contracting_dim = (1, )
# k_contracting_dim = (2, )
g_contracting_dim = tuple(
range(grad.ndim - len(kernel_shape) + len(fwd_k_contracting_dims), grad.ndim)
range(1 + grad.ndim - len(kernel_shape) + len(fwd_k_contracting_dims), grad.ndim)
)
k_contracting_dim = tuple(
dim for dim in range(len(kernel_shape)) if dim not in fwd_k_contracting_dims
dim for dim in range(1, len(kernel_shape)) if dim not in fwd_k_contracting_dims
)
dgrad_contracting_dims = (g_contracting_dim, k_contracting_dim)
dgrad_contracting_dims_list.append(dgrad_contracting_dims)
dgrad_grad = grad
dgrad_kernel_T = ctx_kernel
# GEMM TN
# g_contracting_dim = (0, )
# x_contracting_dim = (0, )
g_contracting_dim = x_contracting_dim = tuple(
range(0, len(x_shape) - len(fwd_x_contracting_dims))
)
wgrad_contracting_dims = (x_contracting_dim, g_contracting_dim)
wgrad_contracting_dims_list.append(wgrad_contracting_dims)
wgrad_x_T = ctx_x
wgrad_grad = grad
else:
casted_grad = tex.grouped_quantize(
grad, quantizer_set.dgrad, group_sizes, flatten_axis=flatten_axis_k
)
# For x_contracting_dims == (1,) and k_contracting_dims == (1,), we need to use
# g_contracting_dim = (1,) and k_contracting_dim = (2,) to make it work after the
# extra transpose for FP8 in grouped_gemm
# TODO(Hua): Do we have a better way for this? What if is_gemm_with_all_layouts_supported()?
g_contracting_dim = (1,)
k_contracting_dim = (2,)
dgrad_contracting_dims = (g_contracting_dim, k_contracting_dim)
dgrad_grad = casted_grad.get_rowwise_tensor()
dgrad_kernel_T = ctx_kernel
# We need to use g_contracting_dim = (0,) and x_contracting_dim = (1,) to make it work
# after the extra transpose for FP8 in grouped_gemm
# TODO(Hua): Do we have a better way for this? What if is_gemm_with_all_layouts_supported()?
g_contracting_dim = (0,)
x_contracting_dim = (1,)
wgrad_contracting_dims = (x_contracting_dim, g_contracting_dim)
wgrad_x_T = ctx_x
wgrad_grad = casted_grad.get_colwise_tensor()
dgrad = tex.grouped_gemm(
dgrad_grad,
dgrad_kernel_T,
group_sizes,
dgrad_contracting_dims,
precision=precision,
preferred_element_type=preferred_element_type,
group_offset=group_offset,
)
dgrad_list = tex.grouped_gemm(
grad_rowwise_list, rowwise_kernel_list, dgrad_contracting_dims_list
wgrad = tex.grouped_gemm(
wgrad_x_T,
wgrad_grad,
group_sizes,
wgrad_contracting_dims,
precision=precision,
preferred_element_type=preferred_element_type,
group_offset=group_offset,
)
wgrad_list = tex.grouped_gemm(colwise_x_list, grad_colwise_list, wgrad_contracting_dims_list)
return dgrad_list, wgrad_list, dbias_list, quantizer_set_list
group_sizes_grad = None
dbias = tex.grouped_dbias(grad, group_sizes) if use_bias else None
return dgrad, wgrad, group_sizes_grad, dbias, quantizer_set
_grouped_dense.defvjp(_grouped_dense_fwd_rule, _grouped_dense_bwd_rule)
"""
......@@ -127,14 +127,16 @@ class BlockScaleDequantizer(Dequantizer):
def dequantize(scaled_tensor):
"""Dequantize a tensor using block scaling.
This function dequantizes a tensor that was quantized using block scaling
by applying the inverse scaling factor to each block of data.
Args:
scaled_tensor: The quantized tensor to dequantize
data: The quantized tensor data
scale_inv: The inverse scaling factors
dq_dtype: The data type for dequantized values
scaling_mode: The scaling mode used for quantization
is_colwise: Whether the scaling is column-wise
flatten_axis: The axis along which the tensor could be flattened to 2D
Returns:
The dequantized tensor in the specified data type
The dequantized tensor
"""
return BlockScaleDequantizer._dequantize_func(
scaled_tensor.data,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment