"docs/vscode:/vscode.git/clone" did not exist on "1d7998e333fe98941d9858e8554631848cd7a4b2"
Unverified Commit 98b4c0d9 authored by Hua Huang's avatar Hua Huang Committed by GitHub
Browse files

[JAX] grouped_gemm() uses variadic arguments (#1658)



* New GroupedGemmPrimitive using variadic args

* Remove squeeze() to reduce D2D memcpy

* Revert to the list append fashion to simplify code

---------
Signed-off-by: Hua Huang <huah@nvidia.com>
Co-authored-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent c8e7cc02
......@@ -41,32 +41,45 @@ class GroupedGemmPrimitive(BasePrimitive):
name = "te_grouped_gemm_ffi"
multiple_results = True
impl_static_args = (6, 7, 8, 9)
impl_static_args = ()
inner_primitive = None
outer_primitive = None
@staticmethod
def abstract(
lhs_contig_aval,
lhs_scale_contig_aval,
rhs_contig_aval,
rhs_scale_contig_aval,
bias_contig_aval,
dim_list_aval,
*,
num_gemms,
scaling_mode,
out_dtype,
out_flat_size,
):
del lhs_contig_aval, lhs_scale_contig_aval
del rhs_contig_aval, rhs_scale_contig_aval
del bias_contig_aval, dim_list_aval
del num_gemms, scaling_mode
out_flat_aval = jax.core.ShapedArray(shape=(out_flat_size,), dtype=out_dtype)
wkspace_size = get_cublas_workspace_size_bytes() * num_cublas_streams
wkspace_aval = jax.core.ShapedArray(shape=(wkspace_size,), dtype=jnp.uint8)
return (out_flat_aval, wkspace_aval)
def abstract(*args, num_gemms, scaling_mode, out_dtype, has_bias):
"""
Args:
*args: Size num_gemms * 4 or num_gemms * 5 depending on has_bias:
args[ 0 : num_gemms] are the lhs tensors,
args[ num_gemms : 2*num_gemms] are the rhs tensors,
args[2*num_gemms : 3*num_gemms] are the lhs scale_inv tensors,
args[3*num_gemms : 4*num_gemms] are the rhs scale_inv tensors,
args[4*num_gemms : 5*num_gemms] are the bias tensors if has_bias is True.
num_gemms: Number of GEMM operations to perform.
scaling_mode: Scaling mode for the GEMM operations.
out_dtype: Data type of the output tensors.
has_bias: Boolean indicating if bias tensors are provided.
Returns:
A tuple of ShapedArray objects of size num_gemms+1:
ret[0 : num_gemms]: GEMM output tensors,
            ret[num_gemms]: workspace tensor.
"""
del scaling_mode
expected_num_args = 5 * num_gemms if has_bias else 4 * num_gemms
assert (
len(args) == expected_num_args
), f"Expected {expected_num_args} input arguments, but got {len(args)}"
A_list = args[0:num_gemms]
B_list = args[num_gemms : 2 * num_gemms]
# A and B have shapes [1, m, k] and [1, n, k]
out_list_aval = tuple(
jax.core.ShapedArray((A.shape[1], B.shape[1]), dtype=out_dtype)
for A, B in zip(A_list, B_list)
)
workspace_size = get_cublas_workspace_size_bytes() * num_cublas_streams
workspace_aval = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8)
return (*out_list_aval, workspace_aval)
@staticmethod
def outer_abstract(*args, **kwargs):
......@@ -74,60 +87,27 @@ class GroupedGemmPrimitive(BasePrimitive):
return out_aval
@staticmethod
def lowering(
ctx,
lhs_contig,
lhs_scale_inv_contig,
rhs_contig,
rhs_scale_inv_contig,
bias_contig,
dim_list,
*,
num_gemms,
scaling_mode,
out_dtype,
out_flat_size,
) -> jnp.ndarray:
del out_dtype, out_flat_size
def lowering(ctx, *args, num_gemms, scaling_mode, out_dtype, has_bias):
del out_dtype
return jax.ffi.ffi_lowering(GroupedGemmPrimitive.name)(
ctx,
lhs_contig,
lhs_scale_inv_contig,
rhs_contig,
rhs_scale_inv_contig,
bias_contig,
dim_list,
*args,
num_gemms=num_gemms,
scaling_mode=scaling_mode.value,
scaling_mode=int(scaling_mode),
has_bias=has_bias,
)
@staticmethod
def impl(
lhs_contig,
lhs_scale_inv_contig,
rhs_contig,
rhs_scale_inv_contig,
bias_contig,
dim_list,
num_gemms,
scaling_mode,
out_dtype,
out_flat_size,
) -> jnp.ndarray:
def impl(*args, num_gemms, scaling_mode, out_dtype, has_bias):
assert GroupedGemmPrimitive.inner_primitive is not None
out = GroupedGemmPrimitive.inner_primitive.bind(
lhs_contig,
lhs_scale_inv_contig,
rhs_contig,
rhs_scale_inv_contig,
bias_contig,
dim_list,
*args,
num_gemms=num_gemms,
scaling_mode=scaling_mode,
scaling_mode=scaling_mode.value,
out_dtype=out_dtype,
out_flat_size=out_flat_size,
has_bias=has_bias,
)
return out[0] # out is [out_flat, wkspace], only return out_flat
return out[:-1] # out is [out_list, wkspace], only return out_list
register_primitive(GroupedGemmPrimitive)
......@@ -366,6 +346,7 @@ def swizzled_scale(scales):
rows, cols = scales.shape
scales = scales.reshape(rows // 128, 4, 32, cols // 4, 4)
scales = jnp.transpose(scales, (0, 3, 2, 1, 4))
scales = scales.reshape(rows, cols)
return scales
......@@ -380,18 +361,12 @@ def grouped_gemm(
len(lhs_list) == len(rhs_list) == len(contracting_dims_list)
), "lhs_list, rhs_list, contracting_dims_list must have the same length"
# Flatten inputs and save their shapes
num_gemms = len(lhs_list)
out_flat_size = 0
dims = []
lhs_contig_ = []
rhs_contig_ = []
lhs_scale_inv_contig_ = []
rhs_scale_inv_contig_ = []
bias_contig_ = []
out_offsets = []
remain_shape_list = []
num_gemms = len(lhs_list)
lhs_list_ = []
rhs_list_ = []
lhs_sinv_list_ = []
rhs_sinv_list_ = []
bias_list_ = []
for i in range(num_gemms):
lhs = lhs_list[i]
rhs = rhs_list[i]
......@@ -402,7 +377,7 @@ def grouped_gemm(
lhs_shape = lhs.data.shape
rhs_shape = rhs.data.shape
out_dtype = lhs.dq_dtype
# For ScaledTensors and NVTE_DELAYED_TENSOR_SCALING, need to handle internal data_layout
# For ScaledTensors and DELAYED_TENSOR_SCALING, need to handle internal data_layout
if lhs.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
assert not (
lhs.data.dtype == jnp.float8_e5m2 and rhs.data.dtype == jnp.float8_e5m2
......@@ -427,6 +402,7 @@ def grouped_gemm(
lhs_remain_shape = _calculate_remaining_shape(lhs_shape, lhs_contract)
rhs_remain_shape = _calculate_remaining_shape(rhs_shape, rhs_contract)
# Note: do not squeeze() for {lhs, rhs}_3d, it will trigger a D2D memcpy
if scaling_mode == ScalingMode.NO_SCALING:
lhs_3d = _shape_normalization(lhs, lhs_dn)
rhs_3d = _shape_normalization(rhs, rhs_dn)
......@@ -438,13 +414,13 @@ def grouped_gemm(
rhs_3d = _shape_normalization(rhs.data, rhs_dn)
lhs_scale_inv = _shape_normalization(lhs.scale_inv, lhs_dn)
rhs_scale_inv = _shape_normalization(rhs.scale_inv, rhs_dn)
# swizzled_scale requires a matrix
lhs_scale_inv = swizzled_scale(lhs_scale_inv.squeeze())
rhs_scale_inv = swizzled_scale(rhs_scale_inv.squeeze())
else:
            raise NotImplementedError(f"Unsupported ScalingMode: {scaling_mode}")
# Note: if _shape_normalization() is updated to support non-TN, need to update here
# already_transposed doesn't matter for the output shape
# Note: already_transposed doesn't matter for the output shape
# x.shape = [B, D1, D2]
# contracting_dims = (2, ) --> output.shape = [1, B * D1, D2]
# contracting_dims = (0, 1, ) --> output.shape = [1, D2, B * D1]
......@@ -455,66 +431,37 @@ def grouped_gemm(
bn = rhs_remain_shape[0]
kl = lhs_3d.shape[-1]
kr = rhs_3d.shape[-1]
remain_shape_list.append(((bm,), (bn,)))
assert kl == kr, f"lhs_3d.shape[-1] ({kl}) != rhs_3d.shape[-1] ({kr})"
k = kl
if (bm % 16 != 0) or (bn % 16 != 0) or (k % 16 != 0):
print(f"grouped_gemm input pair {i} has invalid problem shape for lowering: ")
print(
f"m = {bm}, n = {bn}, k = {k}; cuBLAS requires the problem shapes being multiples"
" of 16"
)
assert bm % 16 == 0 and bn % 16 == 0 and k % 16 == 0
dims.append((bm, bn, k))
lhs_contig_.append(lhs_3d.reshape(-1))
rhs_contig_.append(rhs_3d.reshape(-1))
assert kl == kr, f"After shape normalization, contracting dim size mismatch: {kl} != {kr}"
if (bm % 16 != 0) or (bn % 16 != 0) or (kl % 16 != 0):
            print(f"grouped_gemm input pair {i} has invalid problem shape for lowering: ")
print(f"m = {bm}, n = {bn}, k = {kl}; ")
print("cuBLAS requires the problem shapes being multiples of 16")
assert (bm % 16 == 0) and (bn % 16 == 0) and (kl % 16 == 0)
lhs_list_.append(lhs_3d)
rhs_list_.append(rhs_3d)
if scaling_mode == ScalingMode.NO_SCALING:
lhs_scale_inv_contig_.append(jnp.ones(1, dtype=jnp.float32))
rhs_scale_inv_contig_.append(jnp.ones(1, dtype=jnp.float32))
lhs_sinv_list_.append(jnp.ones(1, dtype=jnp.float32))
rhs_sinv_list_.append(jnp.ones(1, dtype=jnp.float32))
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
lhs_scale_inv_contig_.append(lhs.scale_inv.reshape(-1))
rhs_scale_inv_contig_.append(rhs.scale_inv.reshape(-1))
lhs_sinv_list_.append(lhs.scale_inv)
rhs_sinv_list_.append(rhs.scale_inv)
if scaling_mode == ScalingMode.MXFP8_1D_SCALING:
lhs_scale_inv_contig_.append(lhs_scale_inv.reshape(-1))
rhs_scale_inv_contig_.append(rhs_scale_inv.reshape(-1))
lhs_sinv_list_.append(lhs_scale_inv)
rhs_sinv_list_.append(rhs_scale_inv)
if bias_list is not None:
bias_contig_.append(bias_list[i].reshape(-1))
out_flat_size += bm * bn
out_offsets.append(out_flat_size)
lhs_contig = jnp.concatenate(lhs_contig_)
rhs_contig = jnp.concatenate(rhs_contig_)
lhs_scale_inv_contig = jnp.concatenate(lhs_scale_inv_contig_)
rhs_scale_inv_contig = jnp.concatenate(rhs_scale_inv_contig_)
bias_contig = jnp.empty(0) if bias_list is None else jnp.concatenate(bias_contig_)
dim_list = jnp.array(dims, dtype=jnp.int32)
# TE/common does not support NVTE_NO_SCALING yet
# It expects NVTE_DELAYED_TENSOR_SCALING as default for FP32, BF16, FP16
if scaling_mode == ScalingMode.NO_SCALING:
scaling_mode = ScalingMode.DELAYED_TENSOR_SCALING
# Perform batched GEMM on flattened inputs
out_contig = GroupedGemmPrimitive.outer_primitive.bind(
lhs_contig,
lhs_scale_inv_contig,
rhs_contig,
rhs_scale_inv_contig,
bias_contig,
dim_list,
bias_list_.append(bias_list[i])
out_list = GroupedGemmPrimitive.outer_primitive.bind(
*lhs_list_,
*rhs_list_,
*lhs_sinv_list_,
*rhs_sinv_list_,
*bias_list_,
num_gemms=num_gemms,
scaling_mode=scaling_mode.value,
scaling_mode=scaling_mode,
out_dtype=out_dtype,
out_flat_size=out_flat_size,
has_bias=1 if bias_list is not None else 0,
)
# Split the output back into tensors
out_offsets = jnp.array(out_offsets)
out_flat_list = jnp.split(out_contig, out_offsets[:-1])
out_tensors = []
for out_flat, (lhs_remain_shape, rhs_remain_shape) in zip(out_flat_list, remain_shape_list):
out_tensors.append(out_flat.reshape(*lhs_remain_shape, *rhs_remain_shape))
return out_tensors
return out_list
......@@ -15,34 +15,9 @@
namespace transformer_engine {
namespace jax {
constexpr static size_t MXFP8_BLOCK_SIZE = 32;
// Note: we only support TN-GEMM for now (TN in cuBLASLt == NT in JAX)
Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lhs_sinv_ptr,
const DType &lhs_sinv_dtype, uint8_t *rhs_ptr, const DType &rhs_dtype,
uint8_t *rhs_sinv_ptr, const DType &rhs_sinv_dtype, uint8_t *bias_ptr,
const DType &bias_dtype, uint8_t *out_ptr, const DType &out_dtype,
uint8_t *workspace_ptr, const size_t workspace_size, size_t num_gemms,
int32_t *dim_list_ptr, const JAXX_Scaling_Mode scaling_mode,
cudaStream_t stream) {
size_t lhs_dtype_bytes = te_dtype_bytes(lhs_dtype);
size_t rhs_dtype_bytes = te_dtype_bytes(rhs_dtype);
size_t lhs_sinv_dtype_bytes = te_dtype_bytes(lhs_sinv_dtype);
size_t rhs_sinv_dtype_bytes = te_dtype_bytes(rhs_sinv_dtype);
size_t bias_dtype_bytes = te_dtype_bytes(bias_dtype);
size_t out_dtype_bytes = te_dtype_bytes(out_dtype);
NVTE_CHECK(lhs_dtype_bytes == rhs_dtype_bytes, "sizeof(lhs_dtype) != sizeof(rhs_dtype)");
NVTE_CHECK(lhs_sinv_dtype_bytes == rhs_sinv_dtype_bytes,
"sizeof(lhs_sinv_dtype) != sizeof(rhs_sinv_dtype)");
size_t dim_list_bytes = sizeof(int32_t) * 3 * num_gemms;
std::unique_ptr<int32_t[]> dim_list_host = std::make_unique<int32_t[]>(3 * num_gemms);
cudaMemcpyAsync(dim_list_host.get(), dim_list_ptr, dim_list_bytes, cudaMemcpyDeviceToHost,
stream);
// Note: This may break cudaGraph.
cudaStreamSynchronize(stream);
Error_Type GroupedGemmFFI(cudaStream_t stream, Variadic_Buffer_Type input_list,
Variadic_Result_Type output_list, int64_t num_gemms,
JAXX_Scaling_Mode scaling_mode, int64_t has_bias) {
// Notes on matrix layouts and transpose:
// Jax uses row-major data_layout, on entering this function, each input matrix pair:
// A: row-major with size [m, k],
......@@ -56,6 +31,18 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh
// C: column-major with size [m, n] --> row-major with size [n, m].
// To make the output compatible with JAX, we need to swap A and B in cuBLAS GEMM call.
if (num_gemms <= 0) {
return ffi_with_cuda_error_check();
}
size_t expected_input_size = has_bias ? 5 * num_gemms : 4 * num_gemms;
size_t expected_output_size = num_gemms + 1;
size_t actual_input_size = input_list.size();
size_t actual_output_size = output_list.size();
NVTE_CHECK(actual_input_size == expected_input_size, "Expected %zu input tensors, got %zu",
expected_input_size, actual_input_size);
NVTE_CHECK(actual_output_size == expected_output_size, "Expected %zu output tensors, got %zu",
expected_output_size, actual_output_size);
bool trans_lhs = true;
bool trans_rhs = false;
auto num_math_sm = cuda::sm_count() - getenv<int>("NVTE_EXT_MARGIN_SM", 0);
......@@ -79,10 +66,40 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh
std::vector<NVTETensor> out_list;
std::vector<NVTETensor> workspace_list;
int lhs_list_offset = 0;
int rhs_list_offset = num_gemms;
int lhs_sinv_list_offset = 2 * num_gemms;
int rhs_sinv_list_offset = 3 * num_gemms;
int bias_list_offset = 4 * num_gemms;
int out_list_offset = 0;
for (int i = 0; i < num_gemms; i++) {
size_t m = dim_list_host[i * 3];
size_t n = dim_list_host[i * 3 + 1];
size_t k = dim_list_host[i * 3 + 2];
Buffer_Type lhs_i = input_list.get<Buffer_Type>(lhs_list_offset + i).value();
Buffer_Type rhs_i = input_list.get<Buffer_Type>(rhs_list_offset + i).value();
Buffer_Type lhs_sinv_i = input_list.get<Buffer_Type>(lhs_sinv_list_offset + i).value();
Buffer_Type rhs_sinv_i = input_list.get<Buffer_Type>(rhs_sinv_list_offset + i).value();
Result_Type out_i = output_list.get<Buffer_Type>(out_list_offset + i).value();
DType lhs_dtype = convert_ffi_datatype_to_te_dtype(lhs_i.element_type());
DType rhs_dtype = convert_ffi_datatype_to_te_dtype(rhs_i.element_type());
DType out_dtype = convert_ffi_datatype_to_te_dtype(out_i->element_type());
void *lhs_ptr = lhs_i.untyped_data();
void *rhs_ptr = rhs_i.untyped_data();
void *lhs_sinv_ptr = lhs_sinv_i.untyped_data();
void *rhs_sinv_ptr = rhs_sinv_i.untyped_data();
void *out_ptr = out_i->untyped_data();
// Placeholder for bias since it can be empty
DType bias_dtype = DType::kFloat32;
void *bias_ptr = nullptr;
auto lhs_shape_ = lhs_i.dimensions();
auto rhs_shape_ = rhs_i.dimensions();
// lhs and rhs has shape [1, m, k] and [1, n, k]
size_t m = lhs_shape_[1];
size_t n = rhs_shape_[1];
size_t k = lhs_shape_[2];
auto lhs_shape = std::vector<size_t>{m, k};
auto rhs_shape = std::vector<size_t>{n, k};
......@@ -90,52 +107,54 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh
auto lhs_sinv_shape = std::vector<size_t>{1, 1};
auto rhs_sinv_shape = std::vector<size_t>{1, 1};
auto lhs_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
auto rhs_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
lhs_i.set_rowwise_data(static_cast<void *>(lhs_ptr), lhs_dtype, lhs_shape);
rhs_i.set_rowwise_data(static_cast<void *>(rhs_ptr), rhs_dtype, rhs_shape);
if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
lhs_i.set_rowwise_scale_inv(static_cast<void *>(lhs_sinv_ptr), DType::kFloat32,
std::vector<size_t>{1});
rhs_i.set_rowwise_scale_inv(static_cast<void *>(rhs_sinv_ptr), DType::kFloat32,
std::vector<size_t>{1});
if (scaling_mode == JAXX_Scaling_Mode::NO_SCALING ||
scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
float *amax_dptr = nullptr;
float *scale_dptr = nullptr;
auto lhs_i_ = TensorWrapper(lhs_ptr, lhs_shape, lhs_dtype, amax_dptr, scale_dptr,
reinterpret_cast<float *>(lhs_sinv_ptr));
auto rhs_i_ = TensorWrapper(rhs_ptr, rhs_shape, rhs_dtype, amax_dptr, scale_dptr,
reinterpret_cast<float *>(rhs_sinv_ptr));
lhs_wrapper_list.push_back(std::move(lhs_i_));
rhs_wrapper_list.push_back(std::move(rhs_i_));
} else if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) {
      NVTE_CHECK(k % MXFP8_BLOCK_SIZE == 0, "MXFP8 K-dim being divisible by %d (got %d)",
MXFP8_BLOCK_SIZE, k);
size_t sinv_k = k / MXFP8_BLOCK_SIZE;
lhs_sinv_shape[0] = m;
lhs_sinv_shape[1] = sinv_k;
rhs_sinv_shape[0] = n;
rhs_sinv_shape[1] = sinv_k;
// Note: the scale_inv array should have been swizzled in Python before lowering
lhs_i.set_rowwise_scale_inv(static_cast<void *>(lhs_sinv_ptr), DType::kFloat8E8M0,
lhs_sinv_shape);
rhs_i.set_rowwise_scale_inv(static_cast<void *>(rhs_sinv_ptr), DType::kFloat8E8M0,
rhs_sinv_shape);
auto lhs_sinv_shape_ = lhs_sinv_i.dimensions();
auto rhs_sinv_shape_ = rhs_sinv_i.dimensions();
for (int i = 0; i < 2; i++) {
lhs_sinv_shape[i] = lhs_sinv_shape_[i];
rhs_sinv_shape[i] = rhs_sinv_shape_[i];
}
NVTEScalingMode nvte_scaling_mode = get_nvte_scaling_mode(scaling_mode);
TensorWrapper lhs_i_(nvte_scaling_mode);
TensorWrapper rhs_i_(nvte_scaling_mode);
lhs_i_.set_rowwise_data(lhs_ptr, lhs_dtype, lhs_shape);
rhs_i_.set_rowwise_data(rhs_ptr, rhs_dtype, rhs_shape);
lhs_i_.set_rowwise_scale_inv(lhs_sinv_ptr, DType::kFloat8E8M0, lhs_sinv_shape);
rhs_i_.set_rowwise_scale_inv(rhs_sinv_ptr, DType::kFloat8E8M0, rhs_sinv_shape);
lhs_wrapper_list.push_back(std::move(lhs_i_));
rhs_wrapper_list.push_back(std::move(rhs_i_));
} else {
NVTE_ERROR("Unsupported scaling mode: ", static_cast<int>(scaling_mode));
}
lhs_wrapper_list.push_back(std::move(lhs_i));
rhs_wrapper_list.push_back(std::move(rhs_i));
auto out_i = TensorWrapper(static_cast<void *>(out_ptr), out_shape, out_dtype);
lhs_ptr += m * k * lhs_dtype_bytes;
rhs_ptr += n * k * rhs_dtype_bytes;
out_ptr += m * n * out_dtype_bytes;
lhs_sinv_ptr += lhs_sinv_shape[0] * lhs_sinv_shape[1] * lhs_sinv_dtype_bytes;
rhs_sinv_ptr += rhs_sinv_shape[0] * rhs_sinv_shape[1] * rhs_sinv_dtype_bytes;
auto out_i_ = TensorWrapper(out_ptr, out_shape, out_dtype);
void *pre_gelu_ptr = nullptr;
auto bias_shape = std::vector<size_t>{0};
auto pre_gelu_shape = std::vector<size_t>{0};
if (bias_ptr != nullptr) bias_shape[0] = n;
if (has_bias) {
auto bias_i_get = input_list.get<Buffer_Type>(bias_list_offset + i);
Buffer_Type bias_i = bias_i_get.value();
bias_ptr = bias_i.untyped_data();
bias_dtype = convert_ffi_datatype_to_te_dtype(bias_i.element_type());
bias_shape[0] = n;
}
auto bias_i = TensorWrapper(bias_ptr, bias_shape, bias_dtype);
if (bias_ptr != nullptr) bias_ptr += n * bias_dtype_bytes;
auto pre_gelu_i = TensorWrapper(pre_gelu_ptr, pre_gelu_shape, out_dtype);
out_wrapper_list.push_back(std::move(out_i));
out_wrapper_list.push_back(std::move(out_i_));
bias_wrapper_list.push_back(std::move(bias_i));
pre_gelu_wrapper_list.push_back(std::move(pre_gelu_i));
......@@ -146,6 +165,10 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh
out_list.push_back(out_wrapper_list.back().data());
}
auto workspace_get = output_list.get<Buffer_Type>(num_gemms);
Result_Type workspace = workspace_get.value();
uint8_t *workspace_ptr = reinterpret_cast<uint8_t *>(workspace->untyped_data());
size_t workspace_size = workspace->dimensions()[0] / num_streams;
auto workspace_shape = std::vector<size_t>{workspace_size};
for (int i = 0; i < num_streams; i++) {
auto workspace_i =
......@@ -163,50 +186,14 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh
return ffi_with_cuda_error_check();
}
Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_flatten,
Buffer_Type lhs_sinv_flatten, Buffer_Type rhs_flatten,
Buffer_Type rhs_sinv_flatten, Buffer_Type bias_flatten,
Buffer_Type dim_list, Result_Type out_flatten,
Result_Type workspace_flatten, int64_t num_gemms,
JAXX_Scaling_Mode scaling_mode) {
// Inputs
auto lhs_ptr = reinterpret_cast<uint8_t *>(lhs_flatten.untyped_data());
auto rhs_ptr = reinterpret_cast<uint8_t *>(rhs_flatten.untyped_data());
auto lhs_sinv_ptr = reinterpret_cast<uint8_t *>(lhs_sinv_flatten.untyped_data());
auto rhs_sinv_ptr = reinterpret_cast<uint8_t *>(rhs_sinv_flatten.untyped_data());
auto bias_ptr = reinterpret_cast<uint8_t *>(bias_flatten.untyped_data());
auto dim_list_ptr = reinterpret_cast<int32_t *>(dim_list.untyped_data());
auto lhs_dtype = convert_ffi_datatype_to_te_dtype(lhs_flatten.element_type());
auto rhs_dtype = convert_ffi_datatype_to_te_dtype(rhs_flatten.element_type());
auto lhs_sinv_dtype = convert_ffi_datatype_to_te_dtype(lhs_sinv_flatten.element_type());
auto rhs_sinv_dtype = convert_ffi_datatype_to_te_dtype(rhs_sinv_flatten.element_type());
auto bias_dtype = convert_ffi_datatype_to_te_dtype(bias_flatten.element_type());
// Outputs
auto out_ptr = reinterpret_cast<uint8_t *>(out_flatten->untyped_data());
auto out_dtype = convert_ffi_datatype_to_te_dtype(out_flatten->element_type());
auto workspace_ptr = reinterpret_cast<uint8_t *>(workspace_flatten->untyped_data());
auto workspace_size = workspace_flatten->dimensions().back() / num_streams;
return GroupedGemmImpl(lhs_ptr, lhs_dtype, lhs_sinv_ptr, lhs_sinv_dtype, rhs_ptr, rhs_dtype,
rhs_sinv_ptr, rhs_sinv_dtype, bias_ptr, bias_dtype, out_ptr, out_dtype,
workspace_ptr, workspace_size, num_gemms, dim_list_ptr, scaling_mode,
stream);
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // lhs_flatten
.Arg<Buffer_Type>() // lhs_sinv_flatten
.Arg<Buffer_Type>() // rhs_flatten
.Arg<Buffer_Type>() // rhs_sinv_flatten
.Arg<Buffer_Type>() // bias_flatten
.Arg<Buffer_Type>() // dim_list
.Ret<Buffer_Type>() // out_flatten
.Ret<Buffer_Type>() // workspace_flatten
.RemainingArgs() // input list
.RemainingRets() // output list
.Attr<int64_t>("num_gemms")
.Attr<JAXX_Scaling_Mode>("scaling_mode"),
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<int64_t>("has_bias"),
FFI_CudaGraph_Traits);
} // namespace jax
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment