Unverified commit 98b4c0d9, authored by Hua Huang and committed by GitHub

[JAX] grouped_gemm() uses variadic arguments (#1658)



* New GroupedGemmPrimitive using variadic args

* Remove squeeze() to reduce D2D memcpy

* Revert to the list-append style to simplify the code

---------
Signed-off-by: Hua Huang <huah@nvidia.com>
Co-authored-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent c8e7cc02
@@ -41,32 +41,45 @@ class GroupedGemmPrimitive(BasePrimitive):
     name = "te_grouped_gemm_ffi"
     multiple_results = True
-    impl_static_args = (6, 7, 8, 9)
+    impl_static_args = ()
     inner_primitive = None
     outer_primitive = None

     @staticmethod
-    def abstract(
-        lhs_contig_aval,
-        lhs_scale_contig_aval,
-        rhs_contig_aval,
-        rhs_scale_contig_aval,
-        bias_contig_aval,
-        dim_list_aval,
-        *,
-        num_gemms,
-        scaling_mode,
-        out_dtype,
-        out_flat_size,
-    ):
-        del lhs_contig_aval, lhs_scale_contig_aval
-        del rhs_contig_aval, rhs_scale_contig_aval
-        del bias_contig_aval, dim_list_aval
-        del num_gemms, scaling_mode
-        out_flat_aval = jax.core.ShapedArray(shape=(out_flat_size,), dtype=out_dtype)
-        wkspace_size = get_cublas_workspace_size_bytes() * num_cublas_streams
-        wkspace_aval = jax.core.ShapedArray(shape=(wkspace_size,), dtype=jnp.uint8)
-        return (out_flat_aval, wkspace_aval)
+    def abstract(*args, num_gemms, scaling_mode, out_dtype, has_bias):
+        """
+        Args:
+            *args: Size num_gemms * 4 or num_gemms * 5 depending on has_bias:
+                args[0           :   num_gemms] are the lhs tensors,
+                args[  num_gemms : 2*num_gemms] are the rhs tensors,
+                args[2*num_gemms : 3*num_gemms] are the lhs scale_inv tensors,
+                args[3*num_gemms : 4*num_gemms] are the rhs scale_inv tensors,
+                args[4*num_gemms : 5*num_gemms] are the bias tensors if has_bias is True.
+            num_gemms: Number of GEMM operations to perform.
+            scaling_mode: Scaling mode for the GEMM operations.
+            out_dtype: Data type of the output tensors.
+            has_bias: Boolean indicating if bias tensors are provided.
+
+        Returns:
+            A tuple of ShapedArray objects of size num_gemms + 1:
+                ret[0 : num_gemms]: GEMM output tensors,
+                ret[num_gemms]: workspace tensor.
+        """
+        del scaling_mode
+        expected_num_args = 5 * num_gemms if has_bias else 4 * num_gemms
+        assert (
+            len(args) == expected_num_args
+        ), f"Expected {expected_num_args} input arguments, but got {len(args)}"
+        A_list = args[0:num_gemms]
+        B_list = args[num_gemms : 2 * num_gemms]
+        # A and B have shapes [1, m, k] and [1, n, k]
+        out_list_aval = tuple(
+            jax.core.ShapedArray((A.shape[1], B.shape[1]), dtype=out_dtype)
+            for A, B in zip(A_list, B_list)
+        )
+        workspace_size = get_cublas_workspace_size_bytes() * num_cublas_streams
+        workspace_aval = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8)
+        return (*out_list_aval, workspace_aval)
     @staticmethod
     def outer_abstract(*args, **kwargs):
@@ -74,60 +87,27 @@ class GroupedGemmPrimitive(BasePrimitive):
         return out_aval

     @staticmethod
-    def lowering(
-        ctx,
-        lhs_contig,
-        lhs_scale_inv_contig,
-        rhs_contig,
-        rhs_scale_inv_contig,
-        bias_contig,
-        dim_list,
-        *,
-        num_gemms,
-        scaling_mode,
-        out_dtype,
-        out_flat_size,
-    ) -> jnp.ndarray:
-        del out_dtype, out_flat_size
+    def lowering(ctx, *args, num_gemms, scaling_mode, out_dtype, has_bias):
+        del out_dtype
         return jax.ffi.ffi_lowering(GroupedGemmPrimitive.name)(
             ctx,
-            lhs_contig,
-            lhs_scale_inv_contig,
-            rhs_contig,
-            rhs_scale_inv_contig,
-            bias_contig,
-            dim_list,
+            *args,
             num_gemms=num_gemms,
-            scaling_mode=scaling_mode.value,
+            scaling_mode=int(scaling_mode),
+            has_bias=has_bias,
         )

     @staticmethod
-    def impl(
-        lhs_contig,
-        lhs_scale_inv_contig,
-        rhs_contig,
-        rhs_scale_inv_contig,
-        bias_contig,
-        dim_list,
-        num_gemms,
-        scaling_mode,
-        out_dtype,
-        out_flat_size,
-    ) -> jnp.ndarray:
+    def impl(*args, num_gemms, scaling_mode, out_dtype, has_bias):
         assert GroupedGemmPrimitive.inner_primitive is not None
         out = GroupedGemmPrimitive.inner_primitive.bind(
-            lhs_contig,
-            lhs_scale_inv_contig,
-            rhs_contig,
-            rhs_scale_inv_contig,
-            bias_contig,
-            dim_list,
+            *args,
             num_gemms=num_gemms,
-            scaling_mode=scaling_mode,
+            scaling_mode=scaling_mode.value,
             out_dtype=out_dtype,
-            out_flat_size=out_flat_size,
+            has_bias=has_bias,
         )
-        return out[0]  # out is [out_flat, wkspace], only return out_flat
+        return out[:-1]  # out is [out_list, wkspace], only return out_list


 register_primitive(GroupedGemmPrimitive)
@@ -366,6 +346,7 @@ def swizzled_scale(scales):
     rows, cols = scales.shape
     scales = scales.reshape(rows // 128, 4, 32, cols // 4, 4)
     scales = jnp.transpose(scales, (0, 3, 2, 1, 4))
+    scales = scales.reshape(rows, cols)
     return scales
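The appended `reshape` matters because the transpose leaves the scales rank-5; a small shape check (illustrative only, using toy dimensions that satisfy the 128-row / 4-column blocking) confirms the round-trip:

```python
import jax.numpy as jnp

scales = jnp.arange(128 * 8, dtype=jnp.float32).reshape(128, 8)
rows, cols = scales.shape
t = scales.reshape(rows // 128, 4, 32, cols // 4, 4)  # rank-5 blocked view
t = jnp.transpose(t, (0, 3, 2, 1, 4))                 # shape (1, 2, 32, 4, 4)
assert t.shape == (1, 2, 32, 4, 4)
assert t.reshape(rows, cols).shape == (128, 8)        # back to a 2-D matrix
```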
@@ -380,18 +361,12 @@ def grouped_gemm(
         len(lhs_list) == len(rhs_list) == len(contracting_dims_list)
     ), "lhs_list, rhs_list, contracting_dims_list must have the same length"

-    # Flatten inputs and save their shapes
-    num_gemms = len(lhs_list)
-    out_flat_size = 0
-    dims = []
-    lhs_contig_ = []
-    rhs_contig_ = []
-    lhs_scale_inv_contig_ = []
-    rhs_scale_inv_contig_ = []
-    bias_contig_ = []
-    out_offsets = []
-    remain_shape_list = []
     num_gemms = len(lhs_list)
+    lhs_list_ = []
+    rhs_list_ = []
+    lhs_sinv_list_ = []
+    rhs_sinv_list_ = []
+    bias_list_ = []
     for i in range(num_gemms):
         lhs = lhs_list[i]
         rhs = rhs_list[i]
@@ -402,7 +377,7 @@ def grouped_gemm(
         lhs_shape = lhs.data.shape
         rhs_shape = rhs.data.shape
         out_dtype = lhs.dq_dtype
-        # For ScaledTensors and NVTE_DELAYED_TENSOR_SCALING, need to handle internal data_layout
+        # For ScaledTensors and DELAYED_TENSOR_SCALING, need to handle internal data_layout
         if lhs.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
             assert not (
                 lhs.data.dtype == jnp.float8_e5m2 and rhs.data.dtype == jnp.float8_e5m2
@@ -427,6 +402,7 @@ def grouped_gemm(
         lhs_remain_shape = _calculate_remaining_shape(lhs_shape, lhs_contract)
         rhs_remain_shape = _calculate_remaining_shape(rhs_shape, rhs_contract)

+        # Note: do not squeeze() for {lhs, rhs}_3d, it will trigger a D2D memcpy
         if scaling_mode == ScalingMode.NO_SCALING:
             lhs_3d = _shape_normalization(lhs, lhs_dn)
             rhs_3d = _shape_normalization(rhs, rhs_dn)
@@ -438,13 +414,13 @@ def grouped_gemm(
             rhs_3d = _shape_normalization(rhs.data, rhs_dn)
             lhs_scale_inv = _shape_normalization(lhs.scale_inv, lhs_dn)
             rhs_scale_inv = _shape_normalization(rhs.scale_inv, rhs_dn)
+            # swizzled_scale requires a matrix
             lhs_scale_inv = swizzled_scale(lhs_scale_inv.squeeze())
             rhs_scale_inv = swizzled_scale(rhs_scale_inv.squeeze())
         else:
-            raise NotImplementedError(f"Unsupported ScalingMode: {scaling_mode}")
+            raise NotImplementedError(f"Unsupported ScalingMode: {scaling_mode}")

-        # Note: if _shape_normalization() is updated to support non-TN, need to update here
-        # already_transposed doesn't matter for the output shape
+        # Note: already_transposed doesn't matter for the output shape
         # x.shape = [B, D1, D2]
         # contracting_dims = (2, ) --> output.shape = [1, B * D1, D2]
         # contracting_dims = (0, 1, ) --> output.shape = [1, D2, B * D1]
@@ -455,66 +431,37 @@ def grouped_gemm(
         bn = rhs_remain_shape[0]
         kl = lhs_3d.shape[-1]
         kr = rhs_3d.shape[-1]
-        remain_shape_list.append(((bm,), (bn,)))
-        assert kl == kr, f"lhs_3d.shape[-1] ({kl}) != rhs_3d.shape[-1] ({kr})"
-        k = kl
-
-        if (bm % 16 != 0) or (bn % 16 != 0) or (k % 16 != 0):
-            print(f"grouped_gemm input pair {i} has invalid problem shape for lowering: ")
-            print(
-                f"m = {bm}, n = {bn}, k = {k}; cuBLAS requires the problem shapes being multiples"
-                " of 16"
-            )
-            assert bm % 16 == 0 and bn % 16 == 0 and k % 16 == 0
+        assert kl == kr, f"After shape normalization, contracting dim size mismatch: {kl} != {kr}"
+        if (bm % 16 != 0) or (bn % 16 != 0) or (kl % 16 != 0):
+            print(f"grouped_gemm input pair {i} has invalid problem shape for lowering: ")
+            print(f"m = {bm}, n = {bn}, k = {kl}; ")
+            print("cuBLAS requires the problem shapes being multiples of 16")
+            assert (bm % 16 == 0) and (bn % 16 == 0) and (kl % 16 == 0)

-        dims.append((bm, bn, k))
-        lhs_contig_.append(lhs_3d.reshape(-1))
-        rhs_contig_.append(rhs_3d.reshape(-1))
+        lhs_list_.append(lhs_3d)
+        rhs_list_.append(rhs_3d)
         if scaling_mode == ScalingMode.NO_SCALING:
-            lhs_scale_inv_contig_.append(jnp.ones(1, dtype=jnp.float32))
-            rhs_scale_inv_contig_.append(jnp.ones(1, dtype=jnp.float32))
+            lhs_sinv_list_.append(jnp.ones(1, dtype=jnp.float32))
+            rhs_sinv_list_.append(jnp.ones(1, dtype=jnp.float32))
         if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
-            lhs_scale_inv_contig_.append(lhs.scale_inv.reshape(-1))
-            rhs_scale_inv_contig_.append(rhs.scale_inv.reshape(-1))
+            lhs_sinv_list_.append(lhs.scale_inv)
+            rhs_sinv_list_.append(rhs.scale_inv)
         if scaling_mode == ScalingMode.MXFP8_1D_SCALING:
-            lhs_scale_inv_contig_.append(lhs_scale_inv.reshape(-1))
-            rhs_scale_inv_contig_.append(rhs_scale_inv.reshape(-1))
+            lhs_sinv_list_.append(lhs_scale_inv)
+            rhs_sinv_list_.append(rhs_scale_inv)
         if bias_list is not None:
-            bias_contig_.append(bias_list[i].reshape(-1))
-        out_flat_size += bm * bn
-        out_offsets.append(out_flat_size)
-
-    lhs_contig = jnp.concatenate(lhs_contig_)
-    rhs_contig = jnp.concatenate(rhs_contig_)
-    lhs_scale_inv_contig = jnp.concatenate(lhs_scale_inv_contig_)
-    rhs_scale_inv_contig = jnp.concatenate(rhs_scale_inv_contig_)
-    bias_contig = jnp.empty(0) if bias_list is None else jnp.concatenate(bias_contig_)
-    dim_list = jnp.array(dims, dtype=jnp.int32)
-
-    # TE/common does not support NVTE_NO_SCALING yet
-    # It expects NVTE_DELAYED_TENSOR_SCALING as default for FP32, BF16, FP16
-    if scaling_mode == ScalingMode.NO_SCALING:
-        scaling_mode = ScalingMode.DELAYED_TENSOR_SCALING
-
-    # Perform batched GEMM on flattened inputs
-    out_contig = GroupedGemmPrimitive.outer_primitive.bind(
-        lhs_contig,
-        lhs_scale_inv_contig,
-        rhs_contig,
-        rhs_scale_inv_contig,
-        bias_contig,
-        dim_list,
+            bias_list_.append(bias_list[i])
+
+    out_list = GroupedGemmPrimitive.outer_primitive.bind(
+        *lhs_list_,
+        *rhs_list_,
+        *lhs_sinv_list_,
+        *rhs_sinv_list_,
+        *bias_list_,
         num_gemms=num_gemms,
-        scaling_mode=scaling_mode.value,
+        scaling_mode=scaling_mode,
         out_dtype=out_dtype,
-        out_flat_size=out_flat_size,
+        has_bias=1 if bias_list is not None else 0,
     )

-    # Split the output back into tensors
-    out_offsets = jnp.array(out_offsets)
-    out_flat_list = jnp.split(out_contig, out_offsets[:-1])
-    out_tensors = []
-    for out_flat, (lhs_remain_shape, rhs_remain_shape) in zip(out_flat_list, remain_shape_list):
-        out_tensors.append(out_flat.reshape(*lhs_remain_shape, *rhs_remain_shape))
-    return out_tensors
+    return out_list
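A hedged usage sketch of the updated entry point, assuming the NO_SCALING path with plain bfloat16 operands (shapes and the exact `contracting_dims_list` format are illustrative, and the plain-tensor handling lives in lines not shown in this hunk; all of m, n, k are multiples of 16 as required above):

```python
import jax.numpy as jnp

lhs_list = [jnp.zeros((32, 64), jnp.bfloat16), jnp.zeros((48, 64), jnp.bfloat16)]
rhs_list = [jnp.zeros((16, 64), jnp.bfloat16), jnp.zeros((80, 64), jnp.bfloat16)]
contracting_dims_list = [((1,), (1,))] * 2  # contract the k dim of each pair

# Each result now comes back as its own [m, n] tensor; there is no flat
# buffer to split and reshape on the caller side anymore.
out_list = grouped_gemm(lhs_list, rhs_list, contracting_dims_list)
assert [o.shape for o in out_list] == [(32, 16), (48, 80)]
```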
@@ -15,34 +15,9 @@
 namespace transformer_engine {
 namespace jax {

-constexpr static size_t MXFP8_BLOCK_SIZE = 32;
-
-// Note: we only support TN-GEMM for now (TN in cuBLASLt == NT in JAX)
-Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lhs_sinv_ptr,
-                           const DType &lhs_sinv_dtype, uint8_t *rhs_ptr, const DType &rhs_dtype,
-                           uint8_t *rhs_sinv_ptr, const DType &rhs_sinv_dtype, uint8_t *bias_ptr,
-                           const DType &bias_dtype, uint8_t *out_ptr, const DType &out_dtype,
-                           uint8_t *workspace_ptr, const size_t workspace_size, size_t num_gemms,
-                           int32_t *dim_list_ptr, const JAXX_Scaling_Mode scaling_mode,
-                           cudaStream_t stream) {
-  size_t lhs_dtype_bytes = te_dtype_bytes(lhs_dtype);
-  size_t rhs_dtype_bytes = te_dtype_bytes(rhs_dtype);
-  size_t lhs_sinv_dtype_bytes = te_dtype_bytes(lhs_sinv_dtype);
-  size_t rhs_sinv_dtype_bytes = te_dtype_bytes(rhs_sinv_dtype);
-  size_t bias_dtype_bytes = te_dtype_bytes(bias_dtype);
-  size_t out_dtype_bytes = te_dtype_bytes(out_dtype);
-  NVTE_CHECK(lhs_dtype_bytes == rhs_dtype_bytes, "sizeof(lhs_dtype) != sizeof(rhs_dtype)");
-  NVTE_CHECK(lhs_sinv_dtype_bytes == rhs_sinv_dtype_bytes,
-             "sizeof(lhs_sinv_dtype) != sizeof(rhs_sinv_dtype)");
-
-  size_t dim_list_bytes = sizeof(int32_t) * 3 * num_gemms;
-  std::unique_ptr<int32_t[]> dim_list_host = std::make_unique<int32_t[]>(3 * num_gemms);
-  cudaMemcpyAsync(dim_list_host.get(), dim_list_ptr, dim_list_bytes, cudaMemcpyDeviceToHost,
-                  stream);
-  // Note: This may break cudaGraph.
-  cudaStreamSynchronize(stream);
-
+Error_Type GroupedGemmFFI(cudaStream_t stream, Variadic_Buffer_Type input_list,
+                          Variadic_Result_Type output_list, int64_t num_gemms,
+                          JAXX_Scaling_Mode scaling_mode, int64_t has_bias) {
   // Notes on matrix layouts and transpose:
   // Jax uses row-major data_layout, on entering this function, each input matrix pair:
   //   A: row-major with size [m, k],
@@ -56,6 +31,18 @@
   //   C: column-major with size [m, n] --> row-major with size [n, m].
   // To make the output compatible with JAX, we need to swap A and B in cuBLAS GEMM call.

+  if (num_gemms <= 0) {
+    return ffi_with_cuda_error_check();
+  }
+
+  size_t expected_input_size = has_bias ? 5 * num_gemms : 4 * num_gemms;
+  size_t expected_output_size = num_gemms + 1;
+  size_t actual_input_size = input_list.size();
+  size_t actual_output_size = output_list.size();
+  NVTE_CHECK(actual_input_size == expected_input_size, "Expected ", expected_input_size,
+             " input tensors, got ", actual_input_size);
+  NVTE_CHECK(actual_output_size == expected_output_size, "Expected ", expected_output_size,
+             " output tensors, got ", actual_output_size);
+
   bool trans_lhs = true;
   bool trans_rhs = false;
   auto num_math_sm = cuda::sm_count() - getenv<int>("NVTE_EXT_MARGIN_SM", 0);
@@ -79,10 +66,40 @@
   std::vector<NVTETensor> out_list;
   std::vector<NVTETensor> workspace_list;

+  int lhs_list_offset = 0;
+  int rhs_list_offset = num_gemms;
+  int lhs_sinv_list_offset = 2 * num_gemms;
+  int rhs_sinv_list_offset = 3 * num_gemms;
+  int bias_list_offset = 4 * num_gemms;
+  int out_list_offset = 0;
   for (int i = 0; i < num_gemms; i++) {
-    size_t m = dim_list_host[i * 3];
-    size_t n = dim_list_host[i * 3 + 1];
-    size_t k = dim_list_host[i * 3 + 2];
+    Buffer_Type lhs_i = input_list.get<Buffer_Type>(lhs_list_offset + i).value();
+    Buffer_Type rhs_i = input_list.get<Buffer_Type>(rhs_list_offset + i).value();
+    Buffer_Type lhs_sinv_i = input_list.get<Buffer_Type>(lhs_sinv_list_offset + i).value();
+    Buffer_Type rhs_sinv_i = input_list.get<Buffer_Type>(rhs_sinv_list_offset + i).value();
+    Result_Type out_i = output_list.get<Buffer_Type>(out_list_offset + i).value();
+
+    DType lhs_dtype = convert_ffi_datatype_to_te_dtype(lhs_i.element_type());
+    DType rhs_dtype = convert_ffi_datatype_to_te_dtype(rhs_i.element_type());
+    DType out_dtype = convert_ffi_datatype_to_te_dtype(out_i->element_type());
+    void *lhs_ptr = lhs_i.untyped_data();
+    void *rhs_ptr = rhs_i.untyped_data();
+    void *lhs_sinv_ptr = lhs_sinv_i.untyped_data();
+    void *rhs_sinv_ptr = rhs_sinv_i.untyped_data();
+    void *out_ptr = out_i->untyped_data();
+
+    // Placeholder for bias since it can be empty
+    DType bias_dtype = DType::kFloat32;
+    void *bias_ptr = nullptr;
+
+    auto lhs_shape_ = lhs_i.dimensions();
+    auto rhs_shape_ = rhs_i.dimensions();
+    // lhs and rhs have shapes [1, m, k] and [1, n, k]
+    size_t m = lhs_shape_[1];
+    size_t n = rhs_shape_[1];
+    size_t k = lhs_shape_[2];
     auto lhs_shape = std::vector<size_t>{m, k};
     auto rhs_shape = std::vector<size_t>{n, k};
@@ -90,52 +107,54 @@
     auto lhs_sinv_shape = std::vector<size_t>{1, 1};
     auto rhs_sinv_shape = std::vector<size_t>{1, 1};

-    auto lhs_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
-    auto rhs_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
-    lhs_i.set_rowwise_data(static_cast<void *>(lhs_ptr), lhs_dtype, lhs_shape);
-    rhs_i.set_rowwise_data(static_cast<void *>(rhs_ptr), rhs_dtype, rhs_shape);
-
-    if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
-      lhs_i.set_rowwise_scale_inv(static_cast<void *>(lhs_sinv_ptr), DType::kFloat32,
-                                  std::vector<size_t>{1});
-      rhs_i.set_rowwise_scale_inv(static_cast<void *>(rhs_sinv_ptr), DType::kFloat32,
-                                  std::vector<size_t>{1});
+    if (scaling_mode == JAXX_Scaling_Mode::NO_SCALING ||
+        scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
+      float *amax_dptr = nullptr;
+      float *scale_dptr = nullptr;
+      auto lhs_i_ = TensorWrapper(lhs_ptr, lhs_shape, lhs_dtype, amax_dptr, scale_dptr,
+                                  reinterpret_cast<float *>(lhs_sinv_ptr));
+      auto rhs_i_ = TensorWrapper(rhs_ptr, rhs_shape, rhs_dtype, amax_dptr, scale_dptr,
+                                  reinterpret_cast<float *>(rhs_sinv_ptr));
+      lhs_wrapper_list.push_back(std::move(lhs_i_));
+      rhs_wrapper_list.push_back(std::move(rhs_i_));
     } else if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) {
-      NVTE_CHECK(k % MXFP8_BLOCK_SIZE == 0, "MXFP8 K-dim being divisble by %d (got %d)",
-                 MXFP8_BLOCK_SIZE, k);
-      size_t sinv_k = k / MXFP8_BLOCK_SIZE;
-      lhs_sinv_shape[0] = m;
-      lhs_sinv_shape[1] = sinv_k;
-      rhs_sinv_shape[0] = n;
-      rhs_sinv_shape[1] = sinv_k;
-
       // Note: the scale_inv array should have been swizzled in Python before lowering
-      lhs_i.set_rowwise_scale_inv(static_cast<void *>(lhs_sinv_ptr), DType::kFloat8E8M0,
-                                  lhs_sinv_shape);
-      rhs_i.set_rowwise_scale_inv(static_cast<void *>(rhs_sinv_ptr), DType::kFloat8E8M0,
-                                  rhs_sinv_shape);
+      auto lhs_sinv_shape_ = lhs_sinv_i.dimensions();
+      auto rhs_sinv_shape_ = rhs_sinv_i.dimensions();
+      for (int d = 0; d < 2; d++) {  // renamed from i to avoid shadowing the GEMM index
+        lhs_sinv_shape[d] = lhs_sinv_shape_[d];
+        rhs_sinv_shape[d] = rhs_sinv_shape_[d];
+      }
+      NVTEScalingMode nvte_scaling_mode = get_nvte_scaling_mode(scaling_mode);
+      TensorWrapper lhs_i_(nvte_scaling_mode);
+      TensorWrapper rhs_i_(nvte_scaling_mode);
+      lhs_i_.set_rowwise_data(lhs_ptr, lhs_dtype, lhs_shape);
+      rhs_i_.set_rowwise_data(rhs_ptr, rhs_dtype, rhs_shape);
+      lhs_i_.set_rowwise_scale_inv(lhs_sinv_ptr, DType::kFloat8E8M0, lhs_sinv_shape);
+      rhs_i_.set_rowwise_scale_inv(rhs_sinv_ptr, DType::kFloat8E8M0, rhs_sinv_shape);
+      lhs_wrapper_list.push_back(std::move(lhs_i_));
+      rhs_wrapper_list.push_back(std::move(rhs_i_));
     } else {
       NVTE_ERROR("Unsupported scaling mode: ", static_cast<int>(scaling_mode));
     }
-    lhs_wrapper_list.push_back(std::move(lhs_i));
-    rhs_wrapper_list.push_back(std::move(rhs_i));

-    auto out_i = TensorWrapper(static_cast<void *>(out_ptr), out_shape, out_dtype);
-    lhs_ptr += m * k * lhs_dtype_bytes;
-    rhs_ptr += n * k * rhs_dtype_bytes;
-    out_ptr += m * n * out_dtype_bytes;
-    lhs_sinv_ptr += lhs_sinv_shape[0] * lhs_sinv_shape[1] * lhs_sinv_dtype_bytes;
-    rhs_sinv_ptr += rhs_sinv_shape[0] * rhs_sinv_shape[1] * rhs_sinv_dtype_bytes;
+    auto out_i_ = TensorWrapper(out_ptr, out_shape, out_dtype);

     void *pre_gelu_ptr = nullptr;
     auto bias_shape = std::vector<size_t>{0};
     auto pre_gelu_shape = std::vector<size_t>{0};
-    if (bias_ptr != nullptr) bias_shape[0] = n;
+    if (has_bias) {
+      auto bias_i_get = input_list.get<Buffer_Type>(bias_list_offset + i);
+      Buffer_Type bias_i = bias_i_get.value();
+      bias_ptr = bias_i.untyped_data();
+      bias_dtype = convert_ffi_datatype_to_te_dtype(bias_i.element_type());
+      bias_shape[0] = n;
+    }
     auto bias_i = TensorWrapper(bias_ptr, bias_shape, bias_dtype);
-    if (bias_ptr != nullptr) bias_ptr += n * bias_dtype_bytes;
     auto pre_gelu_i = TensorWrapper(pre_gelu_ptr, pre_gelu_shape, out_dtype);

-    out_wrapper_list.push_back(std::move(out_i));
+    out_wrapper_list.push_back(std::move(out_i_));
     bias_wrapper_list.push_back(std::move(bias_i));
     pre_gelu_wrapper_list.push_back(std::move(pre_gelu_i));
@@ -146,6 +165,10 @@
     out_list.push_back(out_wrapper_list.back().data());
   }

+  auto workspace_get = output_list.get<Buffer_Type>(num_gemms);
+  Result_Type workspace = workspace_get.value();
+  uint8_t *workspace_ptr = reinterpret_cast<uint8_t *>(workspace->untyped_data());
+  size_t workspace_size = workspace->dimensions()[0] / num_streams;
   auto workspace_shape = std::vector<size_t>{workspace_size};
   for (int i = 0; i < num_streams; i++) {
     auto workspace_i =
@@ -163,50 +186,14 @@
   return ffi_with_cuda_error_check();
 }

-Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_flatten,
-                          Buffer_Type lhs_sinv_flatten, Buffer_Type rhs_flatten,
-                          Buffer_Type rhs_sinv_flatten, Buffer_Type bias_flatten,
-                          Buffer_Type dim_list, Result_Type out_flatten,
-                          Result_Type workspace_flatten, int64_t num_gemms,
-                          JAXX_Scaling_Mode scaling_mode) {
-  // Inputs
-  auto lhs_ptr = reinterpret_cast<uint8_t *>(lhs_flatten.untyped_data());
-  auto rhs_ptr = reinterpret_cast<uint8_t *>(rhs_flatten.untyped_data());
-  auto lhs_sinv_ptr = reinterpret_cast<uint8_t *>(lhs_sinv_flatten.untyped_data());
-  auto rhs_sinv_ptr = reinterpret_cast<uint8_t *>(rhs_sinv_flatten.untyped_data());
-  auto bias_ptr = reinterpret_cast<uint8_t *>(bias_flatten.untyped_data());
-  auto dim_list_ptr = reinterpret_cast<int32_t *>(dim_list.untyped_data());
-  auto lhs_dtype = convert_ffi_datatype_to_te_dtype(lhs_flatten.element_type());
-  auto rhs_dtype = convert_ffi_datatype_to_te_dtype(rhs_flatten.element_type());
-  auto lhs_sinv_dtype = convert_ffi_datatype_to_te_dtype(lhs_sinv_flatten.element_type());
-  auto rhs_sinv_dtype = convert_ffi_datatype_to_te_dtype(rhs_sinv_flatten.element_type());
-  auto bias_dtype = convert_ffi_datatype_to_te_dtype(bias_flatten.element_type());
-
-  // Outputs
-  auto out_ptr = reinterpret_cast<uint8_t *>(out_flatten->untyped_data());
-  auto out_dtype = convert_ffi_datatype_to_te_dtype(out_flatten->element_type());
-  auto workspace_ptr = reinterpret_cast<uint8_t *>(workspace_flatten->untyped_data());
-  auto workspace_size = workspace_flatten->dimensions().back() / num_streams;
-
-  return GroupedGemmImpl(lhs_ptr, lhs_dtype, lhs_sinv_ptr, lhs_sinv_dtype, rhs_ptr, rhs_dtype,
-                         rhs_sinv_ptr, rhs_sinv_dtype, bias_ptr, bias_dtype, out_ptr, out_dtype,
-                         workspace_ptr, workspace_size, num_gemms, dim_list_ptr, scaling_mode,
-                         stream);
-}
-
 XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI,
                               FFI::Bind()
                                   .Ctx<FFI_Stream_Type>()  // stream
-                                  .Arg<Buffer_Type>()      // lhs_flatten
-                                  .Arg<Buffer_Type>()      // lhs_sinv_flatten
-                                  .Arg<Buffer_Type>()      // rhs_flatten
-                                  .Arg<Buffer_Type>()      // rhs_sinv_flatten
-                                  .Arg<Buffer_Type>()      // bias_flatten
-                                  .Arg<Buffer_Type>()      // dim_list
-                                  .Ret<Buffer_Type>()      // out_flatten
-                                  .Ret<Buffer_Type>()      // workspace_flatten
+                                  .RemainingArgs()         // input list
+                                  .RemainingRets()         // output list
                                   .Attr<int64_t>("num_gemms")
-                                  .Attr<JAXX_Scaling_Mode>("scaling_mode"),
+                                  .Attr<JAXX_Scaling_Mode>("scaling_mode")
+                                  .Attr<int64_t>("has_bias"),
                               FFI_CudaGraph_Traits);

 }  // namespace jax
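For reference, the buffer-ordering contract that the `*_offset` variables above index into, restated as a plain-Python sketch (illustrative only; `pack_ffi_buffers` is a hypothetical helper, not part of the commit):

```python
def pack_ffi_buffers(lhs, rhs, lhs_sinv, rhs_sinv, bias=None):
    """Mirror the C++ offsets: lhs at 0, rhs at n, lhs_sinv at 2n, rhs_sinv at 3n, bias at 4n."""
    num_gemms = len(lhs)
    input_list = [*lhs, *rhs, *lhs_sinv, *rhs_sinv, *(bias if bias is not None else [])]
    assert len(input_list) == (5 if bias is not None else 4) * num_gemms
    num_outputs = num_gemms + 1  # one output per GEMM, plus one trailing workspace buffer
    return input_list, num_outputs
```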