Unverified Commit 01a504c4 authored by Hua Huang, committed by GitHub

[JAX] Grouped GEMM & Dense support MXFP8 and handle empty matrices (#1871)



* Support MXFP8 and handle empty matrices
Signed-off-by: Hua Huang <huah@nvidia.com>

---------
Signed-off-by: Hua Huang <huah@nvidia.com>
parent a69692ac
......@@ -1250,6 +1250,9 @@ class TestGroupedDense:
group_sizes = jnp.sort(jax.random.randint(subkeys[0], (n_groups - 1,), 0, m))
group_sizes = jnp.concatenate([jnp.array([0]), group_sizes, jnp.array([m])])
group_sizes = jnp.diff(group_sizes)
# Make one lhs input group empty to test empty-GEMM handling
group_sizes = group_sizes.at[0].set(group_sizes[0] + group_sizes[1])
group_sizes = group_sizes.at[1].set(0)
assert group_sizes.sum() == m
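# A worked example of the construction above (hypothetical sizes): with m=16 and a
# random split of [3, 5, 2, 6], the two .at[] updates merge groups 0 and 1 into
# [8, 0, 2, 6], so group 1 becomes an empty GEMM while the sizes still sum to m.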
# *32 to make sure that input shape works for MXFP8
......@@ -1301,9 +1304,6 @@ class TestGroupedDense:
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@pytest_parametrize_wrapper("layout", ["NN"])
def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape, layout):
if scaling_mode == ScalingMode.MXFP8_1D_SCALING:
pytest.skip("MXFP8 is not supported in grouped_gemm yet")
fwd_dtype, bwd_dtype = fwd_bwd_dtype
quantizer_set = QuantizerFactory.create_set(
scaling_mode=scaling_mode,
......@@ -1388,9 +1388,6 @@ class TestGroupedDense:
)
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape):
if scaling_mode == ScalingMode.MXFP8_1D_SCALING:
pytest.skip("MXFP8 is not supported in grouped_dense yet")
fwd_dtype, bwd_dtype = fwd_bwd_dtype
dtype = jnp.bfloat16
x, kernel, group_sizes, contracting_dims, bias = self._generate_grouped_dense_input(
......
......@@ -11,6 +11,8 @@
#ifndef TRANSFORMER_ENGINE_MULTI_STREAM_H
#define TRANSFORMER_ENGINE_MULTI_STREAM_H
#include "cuda_runtime.h"
#ifdef __cplusplus
extern "C" {
#endif
......@@ -18,6 +20,26 @@ extern "C" {
/*! \brief Number of CUDA streams to use in multi-stream operations */
int nvte_get_num_compute_streams();
/*! \brief Get a CUDA stream for compute operations.
*
* \param[in] idx Index of the stream to retrieve.
* \return A cudaStream_t.
*
* This function returns a CUDA stream that can be used for compute operations.
* The index should be in the range [0, nvte_get_num_compute_streams() - 1].
*/
cudaStream_t nvte_get_compute_stream(const int idx);
/*! \brief Get a CUDA event for compute operations.
*
* \param[in] idx Index of the event to retrieve.
* \return A cudaEvent_t.
*
* This function returns a CUDA event that can be used to synchronize compute operations.
* The index should be in the range [0, nvte_get_num_compute_streams() - 1].
*/
cudaEvent_t nvte_get_compute_stream_event(const int idx);
#ifdef __cplusplus
} // extern "C"
#endif
......
......@@ -58,4 +58,12 @@ int get_num_compute_streams() {
int nvte_get_num_compute_streams() { return transformer_engine::detail::get_num_compute_streams(); }
cudaStream_t nvte_get_compute_stream(const int idx) {
return transformer_engine::detail::get_compute_stream(idx);
}
cudaEvent_t nvte_get_compute_stream_event(const int idx) {
return transformer_engine::detail::get_compute_stream_event(idx);
}
#endif // TRANSFORMER_ENGINE_UTIL_MULTI_STREAM_H_
......@@ -103,14 +103,15 @@ class GroupedGemmPrimitive(BasePrimitive):
"""
del lhs_data_aval, rhs_data_aval, bias_aval, group_offset_aval
del K, lhs_is_trans, rhs_is_trans, scaling_mode, has_bias
del lhs_scale_inv_aval, rhs_scale_inv_aval
# TODO(Phuong): move some shape checks from Cpp to here
workspace_size = get_cublas_workspace_size_bytes() * num_cublas_streams
# JAX buffer pointers are 128-aligned
# 255 is added to the workspace size to ensure workspace ptr is 256-aligned
workspace_size += 255
# cuBLAS workspace pointers must be 256-byte aligned, but JAX buffers are not
# necessarily 256-byte aligned, so we add padding to ensure alignment.
# We also pad the swizzled scale_inv buffer sizes for 256-byte alignment.
workspace_size += 256
workspace_size += lhs_scale_inv_aval.size + 256
workspace_size += rhs_scale_inv_aval.size + 256
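# Resulting workspace layout, carved out on the C++ side (a sketch):
#   [align pad][num_cublas_streams x per-stream cuBLAS workspace]
#   [align pad][swizzled lhs scale_inv][align pad][swizzled rhs scale_inv]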
workspace_aval = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8)
# TODO(phuong): We should make separate tmp buffers for swizzled scales to avoid unaligned-by-256 workspace ptr issue
out_shape = (M, N)
if is_grouped_dense_wgrad:
......@@ -495,7 +496,8 @@ def grouped_gemm(
# and is_gemm_with_all_layouts_supported()
scaling_mode.is_1d_block_scaling()
):
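# For MXFP8 (1D block scaling) with the NN layout, cuBLAS consumes the RHS
# column-wise and the LHS row-wise, so the RHS is quantized column-wise here
# (mirrored by rhs_use_colwise / lhs_use_colwise in the C++ FFI).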
lhs_is_rowwise = rhs_is_rowwise = True
lhs_is_rowwise = True
rhs_is_rowwise = False
else:
lhs_is_rowwise = not lhs_is_trans
rhs_is_rowwise = lhs_is_trans
......@@ -557,9 +559,6 @@ def grouped_gemm(
assert not has_bias or bias.shape == (group_sizes.size, N)
bias = jnp.empty((), jnp.float32) if bias is None else bias
# TODO(Phuong): support MXFP8_1D_SCALING
assert scaling_mode != ScalingMode.MXFP8_1D_SCALING, "MXFP8_1D_SCALING is not yet supported"
(out,) = GroupedGemmPrimitive.outer_primitive.bind(
lhs_data,
lhs_scale_inv,
......
......@@ -10,6 +10,8 @@
#include "../extensions.h"
#include "common/util/cuda_runtime.h"
#include "common/util/system.h"
#include "transformer_engine/multi_stream.h"
#include "transformer_engine/swizzle.h"
#include "xla/ffi/api/c_api.h"
#define MXFP8_BLOCK_SIZE 32
......@@ -17,6 +19,12 @@
namespace transformer_engine {
namespace jax {
static uint8_t *move_ptr_to_next_256B_aligned(uint8_t *ptr) {
// Move the pointer to the next 256B aligned address
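// Example: an address ending in ...0x41 is rounded up to ...0x100; addresses
// that are already 256B-aligned are left unchanged.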
return reinterpret_cast<uint8_t *>((reinterpret_cast<uintptr_t>(ptr) + 255) &
~static_cast<uintptr_t>(255));
}
Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv,
Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias,
Buffer_Type group_sizes, Buffer_Type group_offset, Result_Type output,
......@@ -58,11 +66,18 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type
auto out_ptr = reinterpret_cast<uint8_t *>(output->untyped_data());
auto out_dtype = convert_ffi_datatype_to_te_dtype(output->element_type());
// Here we clear the lower 8 bits of the buffer address to ensure the buffer is 256-aligned
auto workspace_ptr =
reinterpret_cast<uint8_t *>((reinterpret_cast<uintptr_t>(workspace->untyped_data()) + 255) &
~static_cast<uintptr_t>(255));
auto workspace_total_size = product(workspace->dimensions()) - 255;
auto workspace_size = workspace_total_size / num_streams;
auto workspace_ptr = reinterpret_cast<uint8_t *>(workspace->untyped_data());
workspace_ptr = move_ptr_to_next_256B_aligned(workspace_ptr);
auto workspace_total_size = product(workspace->dimensions());
auto lhs_sinv_size = product(lhs_sinv.dimensions());
auto rhs_sinv_size = product(rhs_sinv.dimensions());
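// Workspace buffer layout (allocated on the JAX side):
//   [align pad][num_streams x per-stream cuBLAS workspace]
//   [align pad][swizzled lhs scale_inv][align pad][swizzled rhs scale_inv]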
auto workspace_size =
(workspace_total_size - lhs_sinv_size - rhs_sinv_size - 3 * 256) / num_streams;
auto swizzled_lhs_sinv_ptr = workspace_ptr + workspace_size * num_streams;
swizzled_lhs_sinv_ptr = move_ptr_to_next_256B_aligned(swizzled_lhs_sinv_ptr);
auto swizzled_rhs_sinv_ptr = swizzled_lhs_sinv_ptr + lhs_sinv_size;
swizzled_rhs_sinv_ptr = move_ptr_to_next_256B_aligned(swizzled_rhs_sinv_ptr);
size_t lhs_dtype_bytes = te_dtype_bytes(lhs_dtype);
size_t rhs_dtype_bytes = te_dtype_bytes(rhs_dtype);
......@@ -122,6 +137,8 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type
// It is weird that TE/Common GEMM only uses colwise for MXFP8
const bool is_fp8_gemm = is_fp8_dtype(lhs_dtype);
const bool is_tensor_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ||
scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING;
const bool is_mxfp8_scaling = scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING;
const bool rhs_use_colwise = is_mxfp8_scaling && !rhs_is_trans;
const bool lhs_use_colwise = is_mxfp8_scaling && lhs_is_trans;
......@@ -135,6 +152,8 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type
// These lists are to keep the TensorWrapper objects alive
std::vector<TensorWrapper> lhs_wrapper_list;
std::vector<TensorWrapper> rhs_wrapper_list;
std::vector<TensorWrapper> lhs_swizzle_wrapper_list; // For MXFP8 scale_inv swizzling
std::vector<TensorWrapper> rhs_swizzle_wrapper_list;
std::vector<TensorWrapper> bias_wrapper_list;
std::vector<TensorWrapper> pre_gelu_wrapper_list;
std::vector<TensorWrapper> out_wrapper_list;
......@@ -143,66 +162,119 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type
// These lists are the actual NVTETensor (void *) lists for multi-stream GEMM
std::vector<NVTETensor> lhs_list;
std::vector<NVTETensor> rhs_list;
std::vector<NVTETensor> lhs_swizzle_list;
std::vector<NVTETensor> rhs_swizzle_list;
std::vector<NVTETensor> bias_list;
std::vector<NVTETensor> pre_gelu_list;
std::vector<NVTETensor> out_list;
std::vector<NVTETensor> workspace_list;
size_t lhs_sinv_total_size = 0;
size_t rhs_sinv_total_size = 0;
std::vector<void *> zero_out_dptr_list;
std::vector<size_t> zero_out_size_list;
for (size_t i = 0; i < num_gemms; i++) {
// Matrix data shapes
size_t m_i = dim_list_host[i];
auto lhs_shape = std::vector<size_t>{m_i, k};
auto rhs_shape = std::vector<size_t>{rhs_is_trans ? n : k, rhs_is_trans ? k : n};
auto out_shape = std::vector<size_t>{m_i, n};
auto lhs_shape_i = std::vector<size_t>{m_i, k};
auto rhs_shape_i = std::vector<size_t>{rhs_is_trans ? n : k, rhs_is_trans ? k : n};
auto out_shape_i = std::vector<size_t>{m_i, n};
if (is_grouped_dense_wgrad) {
size_t k_i = dim_list_host[i];
lhs_shape[0] = lhs_is_trans ? k_i : m;
lhs_shape[1] = lhs_is_trans ? m : k_i;
rhs_shape[0] = rhs_is_trans ? n : k_i;
rhs_shape[1] = rhs_is_trans ? k_i : n;
out_shape[0] = m;
out_shape[1] = n;
lhs_shape_i[0] = lhs_is_trans ? k_i : m;
lhs_shape_i[1] = lhs_is_trans ? m : k_i;
rhs_shape_i[0] = rhs_is_trans ? n : k_i;
rhs_shape_i[1] = rhs_is_trans ? k_i : n;
out_shape_i[0] = m;
out_shape_i[1] = n;
}
size_t lhs_size = lhs_shape_i[0] * lhs_shape_i[1];
size_t rhs_size = rhs_shape_i[0] * rhs_shape_i[1];
size_t out_size = out_shape_i[0] * out_shape_i[1];
bool is_empty_gemm = lhs_size == 0 || rhs_size == 0;
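// Empty groups are skipped in the GEMM launches below, which would leave their
// output rows uninitialized; record them so they can be zeroed out with
// cudaMemsetAsync before the GEMMs are launched.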
if (is_empty_gemm && out_size > 0) {
zero_out_dptr_list.push_back(out_ptr);
zero_out_size_list.push_back(out_size * out_dtype_bytes);
}
// Set matrix data pointers
auto lhs_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
auto rhs_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
auto out_i = TensorWrapper(static_cast<void *>(out_ptr), out_shape, out_dtype);
auto out_i = TensorWrapper(static_cast<void *>(out_ptr), out_shape_i, out_dtype);
void *lhs_vptr = static_cast<void *>(lhs_ptr);
void *rhs_vptr = static_cast<void *>(rhs_ptr);
if (rhs_use_colwise) // MatA to enter cuBLAS
rhs_i.set_columnwise_data(rhs_vptr, rhs_dtype, rhs_shape);
rhs_i.set_columnwise_data(rhs_vptr, rhs_dtype, rhs_shape_i);
else
rhs_i.set_rowwise_data(rhs_vptr, rhs_dtype, rhs_shape);
rhs_i.set_rowwise_data(rhs_vptr, rhs_dtype, rhs_shape_i);
if (lhs_use_colwise) // MatB to enter cuBLAS
lhs_i.set_columnwise_data(lhs_vptr, lhs_dtype, lhs_shape);
lhs_i.set_columnwise_data(lhs_vptr, lhs_dtype, lhs_shape_i);
else
lhs_i.set_rowwise_data(lhs_vptr, lhs_dtype, lhs_shape);
lhs_i.set_rowwise_data(lhs_vptr, lhs_dtype, lhs_shape_i);
// Scale_inv shapes
auto lhs_sinv_size = std::vector<size_t>{1};
auto rhs_sinv_size = std::vector<size_t>{1};
if (is_mxfp8_scaling) {
NVTE_CHECK(k % MXFP8_BLOCK_SIZE == 0, "MXFP8 requires the K dim to be divisible by %d (got %d)",
MXFP8_BLOCK_SIZE, k);
size_t scale_k = k / MXFP8_BLOCK_SIZE;
lhs_sinv_size[0] = m_i * scale_k;
rhs_sinv_size[0] = n * scale_k;
// Need to add swizzle here
}
// Set scale_inv pointers
// Set scale_inv shapes and pointers
void *rhs_sinv_vptr = static_cast<void *>(rhs_sinv_ptr);
void *lhs_sinv_vptr = static_cast<void *>(lhs_sinv_ptr);
if (is_fp8_gemm) {
size_t lhs_sinv_size_i = 0;
size_t rhs_sinv_size_i = 0;
if (is_tensor_scaling) {
auto tensor_scaling_sinv_shape = std::vector<size_t>{1};
// If is_empty_gemm, scale_inv has no corresponding value, so do not advance the pointers
if (!is_empty_gemm) {
lhs_sinv_size_i = 1;
rhs_sinv_size_i = 1;
}
if (rhs_use_colwise) // MatA to enter cuBLAS
rhs_i.set_columnwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_size);
rhs_i.set_columnwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, tensor_scaling_sinv_shape);
else
rhs_i.set_rowwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_size);
rhs_i.set_rowwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, tensor_scaling_sinv_shape);
if (lhs_use_colwise) // MatB to enter cuBLAS
lhs_i.set_columnwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_size);
lhs_i.set_columnwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, tensor_scaling_sinv_shape);
else
lhs_i.set_rowwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_size);
lhs_i.set_rowwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, tensor_scaling_sinv_shape);
} else if (is_mxfp8_scaling) {
auto lhs_swizzle_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
auto rhs_swizzle_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
void *swizzled_lhs_sinv_vptr = static_cast<void *>(swizzled_lhs_sinv_ptr);
void *swizzled_rhs_sinv_vptr = static_cast<void *>(swizzled_rhs_sinv_ptr);
// {lhs, rhs}_swizzle_i point to the unswizzled scale_inv data as input, while {lhs, rhs}_i
// point to the swizzled scale_inv data (stored in the workspace, used only for GEMM).
// Note: even if is_empty_gemm is true, the scale_invs are still non-empty, so the pointers still need to be advanced
auto lhs_sinv_shape_i =
get_mxfp8_scale_shape(lhs_shape_i[0], lhs_shape_i[1], lhs_use_colwise);
auto rhs_sinv_shape_i =
get_mxfp8_scale_shape(rhs_shape_i[0], rhs_shape_i[1], rhs_use_colwise);
lhs_sinv_size_i = lhs_sinv_shape_i[0] * lhs_sinv_shape_i[1];
rhs_sinv_size_i = rhs_sinv_shape_i[0] * rhs_sinv_shape_i[1];
if (lhs_use_colwise) {
lhs_swizzle_i.set_columnwise_data(lhs_vptr, lhs_dtype, lhs_shape_i);
lhs_swizzle_i.set_columnwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_shape_i);
lhs_i.set_columnwise_scale_inv(swizzled_lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_shape_i);
} else {
lhs_swizzle_i.set_rowwise_data(lhs_vptr, lhs_dtype, lhs_shape_i);
lhs_swizzle_i.set_rowwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_shape_i);
lhs_i.set_rowwise_scale_inv(swizzled_lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_shape_i);
}
if (rhs_use_colwise) {
rhs_swizzle_i.set_columnwise_data(rhs_vptr, rhs_dtype, rhs_shape_i);
rhs_swizzle_i.set_columnwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_shape_i);
rhs_i.set_columnwise_scale_inv(swizzled_rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_shape_i);
} else {
rhs_swizzle_i.set_rowwise_data(rhs_vptr, rhs_dtype, rhs_shape_i);
rhs_swizzle_i.set_rowwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_shape_i);
rhs_i.set_rowwise_scale_inv(swizzled_rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_shape_i);
}
if (!is_empty_gemm) {
lhs_swizzle_wrapper_list.push_back(std::move(lhs_swizzle_i));
rhs_swizzle_wrapper_list.push_back(std::move(rhs_swizzle_i));
lhs_swizzle_list.push_back(lhs_swizzle_wrapper_list.back().data());
rhs_swizzle_list.push_back(rhs_swizzle_wrapper_list.back().data());
}
} else {
NVTE_CHECK(scaling_mode == JAXX_Scaling_Mode::NO_SCALING,
"Unsupported scaling mode: ", static_cast<int>(scaling_mode));
......@@ -212,16 +284,23 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type
auto pre_gelu_i = TensorWrapper(nullptr, std::vector<size_t>{0}, out_dtype);
// Update pointer for the next GEMM pair
lhs_ptr += lhs_shape[0] * lhs_shape[1] * lhs_dtype_bytes;
rhs_ptr += rhs_shape[0] * rhs_shape[1] * rhs_dtype_bytes;
out_ptr += out_shape[0] * out_shape[1] * out_dtype_bytes;
lhs_ptr += lhs_size * lhs_dtype_bytes;
rhs_ptr += rhs_size * rhs_dtype_bytes;
out_ptr += out_size * out_dtype_bytes;
if (is_fp8_gemm) {
lhs_sinv_ptr += lhs_sinv_size[0] * lhs_sinv_dtype_bytes;
rhs_sinv_ptr += rhs_sinv_size[0] * rhs_sinv_dtype_bytes;
lhs_sinv_ptr += lhs_sinv_size_i * lhs_sinv_dtype_bytes;
rhs_sinv_ptr += rhs_sinv_size_i * rhs_sinv_dtype_bytes;
lhs_sinv_total_size += lhs_sinv_size_i;
rhs_sinv_total_size += rhs_sinv_size_i;
if (is_mxfp8_scaling) {
swizzled_lhs_sinv_ptr += lhs_sinv_size_i * lhs_sinv_dtype_bytes;
swizzled_rhs_sinv_ptr += rhs_sinv_size_i * rhs_sinv_dtype_bytes;
}
}
if (has_bias) bias_ptr += n * bias_dtype_bytes;
// Move objects to the lists to keep them alive
if (is_empty_gemm) continue;
lhs_wrapper_list.push_back(std::move(lhs_i));
rhs_wrapper_list.push_back(std::move(rhs_i));
out_wrapper_list.push_back(std::move(out_i));
......@@ -244,10 +323,41 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type
workspace_ptr += workspace_size;
}
if (is_fp8_gemm) {
NVTE_CHECK(lhs_sinv_total_size <= lhs_sinv_size, "Actual total lhs_sinv size ",
lhs_sinv_total_size, " exceeds estimated upper bound ", lhs_sinv_size);
NVTE_CHECK(rhs_sinv_total_size <= rhs_sinv_size, "Actual total rhs_sinv size ",
rhs_sinv_total_size, " exceeds estimated upper bound ", rhs_sinv_size);
}
size_t num_non_empty_gemms = lhs_list.size();
if (is_mxfp8_scaling) {
for (int i = 0; i < num_non_empty_gemms; i++) {
// The i-th GEMM uses the (i % num_streams)-th stream for its computation, so
// swizzle the scaling factors on the same stream to make sure the swizzling
// is done before the GEMM computation starts.
int stream_id = i % num_streams;
cudaStream_t stream_i = nvte_get_compute_stream(stream_id);
nvte_swizzle_scaling_factors(lhs_swizzle_list[i], lhs_list[i], stream_i);
nvte_swizzle_scaling_factors(rhs_swizzle_list[i], rhs_list[i], stream_i);
}
}
// Launch the zero-out kernels before the GEMM calls so they can reuse the stream synchronization done by the multi-stream GEMM
size_t num_zero_outs = zero_out_dptr_list.size();
for (int i = 0; i < num_zero_outs; i++) {
int stream_id = i % num_streams;
cudaStream_t stream_i = nvte_get_compute_stream(stream_id);
void *dptr = zero_out_dptr_list[i];
size_t count = zero_out_size_list[i];
NVTE_CHECK_CUDA(cudaMemsetAsync(dptr, 0, count, stream_i));
}
nvte_multi_stream_cublas_gemm(rhs_list.data(), lhs_list.data(), out_list.data(), bias_list.data(),
pre_gelu_list.data(), num_gemms, rhs_is_trans, lhs_is_trans, grad,
workspace_list.data(), accumulate, use_split_accumulator,
num_math_sm, stream);
pre_gelu_list.data(), num_non_empty_gemms, rhs_is_trans,
lhs_is_trans, grad, workspace_list.data(), accumulate,
use_split_accumulator, num_math_sm, stream);
return ffi_with_cuda_error_check();
}
......
......@@ -287,7 +287,13 @@ def _grouped_dense_fwd_rule(
"and k_contracting_dims=(1,) for now, "
f"got {x_contracting_dims=} and {k_contracting_dims=}"
)
scaling_mode = quantizer_set.x.scaling_mode
if scaling_mode.is_tensor_scaling():
k_contracting_dims = (0,)
elif scaling_mode.is_1d_block_scaling():
k_contracting_dims = (1,)
else:
raise ValueError(f"Unsupported scaling mode {scaling_mode.value} for grouped_dense")
casted_x = tex.grouped_quantize(
x, quantizer_set.x, group_sizes, flatten_axis=flatten_axis_x
......@@ -385,7 +391,7 @@ def _grouped_dense_bwd_rule(
dgrad_grad = casted_grad.get_rowwise_tensor()
dgrad_kernel_T = ctx_kernel
# We need to use g_contracting_dim = (0,) and x_contracting_dim = (1,) to make it work
# We need to use g_contracting_dim = (0,) and x_contracting_dim = (0,) to make it work
# after the extra transpose for FP8 in grouped_gemm
# TODO(Hua): Do we have a better way for this? What if is_gemm_with_all_layouts_supported()?
g_contracting_dim = (0,)
......