Unverified Commit de51c96b authored by Zhongbo Zhu, committed by GitHub

[NVFP4][MOE] Bug Fix for NVFP4 Grouped Quant (#2564)



* fix
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* resolve review comments
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>

* Comment tweaks
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

---------
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent 702fc5ee
@@ -1125,8 +1125,9 @@ template <bool kEnableStochasticRounding, bool kEnableRHTColQuant, bool kEnableR
 void group_row_col_rht_gemm_ntt_w_sfc(int packed_sequence_length, int hidden_size, TA const *A,
                                       TB const *B, TQA *QA, TSFA *SFA,
                                       MultiAmaxHadamardCastFusionArgs &args,
-                                      const size_t *rng_state, uint32_t sm_count,
-                                      cudaStream_t stream, int k_tile_size = 1024) {
+                                      const size_t *rng_state, uint32_t *tile_scheduler_workspace,
+                                      uint32_t sm_count, cudaStream_t stream,
+                                      int k_tile_size = 1024) {
   using namespace cute;
   static int constexpr SFVecSize = 16;
   static int constexpr RhtTensorSize = 16;
@@ -1295,10 +1296,9 @@ void group_row_col_rht_gemm_ntt_w_sfc(int packed_sequence_length, int hidden_siz
   NVTE_CHECK_CUDA(
       cudaFuncSetAttribute(*kernel_ptr, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));

-  // Allocate workspace and set to zero
-  void *tile_scheduler_workspace = nullptr;
-  NVTE_CHECK_CUDA(cudaMallocAsync(&tile_scheduler_workspace, sizeof(uint32_t), stream));
-  NVTE_CHECK_CUDA(cudaMemsetAsync(tile_scheduler_workspace, 0, sizeof(uint32_t), stream));
+  // Set workspace to zero
+  NVTE_CHECK_CUDA(cudaMemsetAsync(reinterpret_cast<void *>(tile_scheduler_workspace), 0,
+                                  sizeof(uint32_t), stream));

   // Launch kernel
   cutlass::ClusterLaunchParams params = {dimGrid, dimBlock, dimCluster, smem_size, stream};
@@ -1308,8 +1308,6 @@ void group_row_col_rht_gemm_ntt_w_sfc(int packed_sequence_length, int hidden_siz
                                         tile_scheduler_workspace, mma, rng_state);
   NVTE_CHECK_CUDA(cudaGetLastError());
   NVTE_CHECK(status == cutlass::Status::kSuccess, "Kernel launch failed.");
-  NVTE_CHECK_CUDA(cudaFreeAsync(tile_scheduler_workspace, stream));
 }

 }  // namespace
@@ -1318,7 +1316,8 @@ void group_row_col_rht_gemm_ntt_w_sfc(int packed_sequence_length, int hidden_siz
 void group_hadamard_transform_cast_fusion(const Tensor &input_, std::vector<Tensor *> &output_list,
                                           const size_t *split_sections, size_t num_tensors,
                                           const Tensor &hadamard_matrix_,
-                                          QuantizationConfig &quant_config, cudaStream_t stream) {
+                                          QuantizationConfig &quant_config, Tensor &quant_workspace,
+                                          cudaStream_t stream) {
   NVTE_API_CALL(group_hadamard_transform_cast_fusion);

   using transformer_engine::detail::kMaxTensorsPerKernel;
@@ -1399,6 +1398,12 @@ void group_hadamard_transform_cast_fusion(const Tensor &input_, std::vector<Tens
     rng_state = reinterpret_cast<const size_t *>(rng_state_tensor.data.dptr);
   }

+  uint32_t *tile_scheduler_workspace = nullptr;
+  NVTE_CHECK(quant_workspace.data.dptr != nullptr, "Quantization workspace must be provided.");
+  NVTE_CHECK(quant_workspace.data.buffer_size_bytes() >= sizeof(uint32_t),
+             "Quantization workspace must be at least 4 bytes.");
+  tile_scheduler_workspace = reinterpret_cast<uint32_t *>(quant_workspace.data.dptr);
+
   // Template arguments
   using TA = cute::bfloat16_t;
   using TB = cute::bfloat16_t;
......@@ -1461,7 +1466,9 @@ void group_hadamard_transform_cast_fusion(const Tensor &input_, std::vector<Tens
/*QA=*/reinterpret_cast<TQA *>(rowwise_data_base_ptr),
/*SFA=*/reinterpret_cast<TSFA *>(rowwise_scale_inv_base_ptr),
/*args=*/kernel_args,
/*rng_state=*/rng_state, /*sm_count=*/sm_count,
/*rng_state=*/rng_state,
/*tile_scheduler_workspace=*/tile_scheduler_workspace,
/*sm_count=*/sm_count,
/*stream=*/stream, /*k_tile_size=*/k_tile_size);
} else {
NVTE_ERROR("Invalid kernel configuration (kEnableRHTColQuant=",
@@ -1478,7 +1485,7 @@ void nvte_group_hadamard_transform_cast_fusion(const NVTETensor input, NVTETenso
                                                const size_t *split_sections,
                                                const size_t num_tensors,
                                                const NVTEQuantizationConfig quant_config,
-                                               cudaStream_t stream) {
+                                               NVTETensor quant_workspace, cudaStream_t stream) {
   NVTE_API_CALL(nvte_group_hadamard_transform_cast_fusion);
   using namespace transformer_engine;
   NVTE_CHECK(num_tensors > 0, "Number of tensors should be greater than 0.");
@@ -1489,6 +1496,8 @@ void nvte_group_hadamard_transform_cast_fusion(const NVTETensor input, NVTETenso
     output_list[i] = convertNVTETensorCheck(outputs[i]);
   }

+  Tensor *quant_workspace_tensor = convertNVTETensorCheck(quant_workspace);
+
   QuantizationConfig quant_config_cpp;
   if (quant_config != nullptr) {
     quant_config_cpp = *reinterpret_cast<QuantizationConfig *>(quant_config);
@@ -1497,5 +1506,5 @@ void nvte_group_hadamard_transform_cast_fusion(const NVTETensor input, NVTETenso
   // Call the multi-tensor Hadamard transform amax implementation.
   group_hadamard_transform_cast_fusion(*input_tensor, output_list, split_sections, num_tensors,
                                        *convertNVTETensorCheck(hadamard_matrix), quant_config_cpp,
-                                       stream);
+                                       *quant_workspace_tensor, stream);
 }
@@ -115,13 +115,14 @@ void nvte_group_hadamard_transform_cast_fusion_columnwise(
  * \param[in]  split_sections   Array specifying splits in dimension 0 for each output tensor.
  * \param[in]  num_tensors      Number of output tensors, must be > 0.
  * \param[in]  quant_config     Quantization configuration.
+ * \param[in]  quant_workspace  Workspace buffer. Must be at least 4 bytes.
  * \param[in]  stream           CUDA stream used for the operation.
  */
 void nvte_group_hadamard_transform_cast_fusion(const NVTETensor input, NVTETensor* outputs,
                                                const NVTETensor hadamard_matrix,
                                                const size_t* split_sections, size_t num_tensors,
                                                const NVTEQuantizationConfig quant_config,
-                                               cudaStream_t stream);
+                                               NVTETensor quant_workspace, cudaStream_t stream);

 #ifdef __cplusplus
 }  // extern "C"
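
With the header change above, callers of the C API now own the 4-byte workspace buffer. A hedged usage sketch of the new signature follows; make_workspace_tensor is a hypothetical helper standing in for however a caller wraps device memory in an NVTETensor (the PyTorch integration in the next hunk does this with at::empty and makeTransformerEngineTensor).

// Hypothetical helper: wraps >= sizeof(uint32_t) bytes of device memory in an
// NVTETensor; real callers go through their framework's tensor bindings.
NVTETensor quant_workspace = make_workspace_tensor(sizeof(uint32_t));

nvte_group_hadamard_transform_cast_fusion(input, outputs, hadamard_matrix,
                                          split_sections, num_tensors,
                                          quant_config, quant_workspace, stream);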
@@ -872,10 +872,16 @@ void split_quantize_nvfp4_impl_with_rht_helper(const TensorWrapper &input,
   auto rht_matrix_nvte = makeTransformerEngineTensor(quantizer.rht_matrix);
   if (all_aligned_token_dim) {
+    // allocate a tile scheduler workspace
+    auto tile_scheduler_workspace_torch =
+        at::empty({1}, at::device(at::kCUDA).dtype(torch::kInt32));
+    auto nvte_tile_scheduler_workspace =
+        makeTransformerEngineTensor(tile_scheduler_workspace_torch);
     // call the fully-fused grouped kernel for rowwise quantization & colwise RHT quantization transpose
     nvte_group_hadamard_transform_cast_fusion(
         input.data(), reinterpret_cast<NVTETensor *>(nvte_tensor_output_list.data()),
-        rht_matrix_nvte.data(), split_sections.data(), num_tensors, quant_config_list[0], stream);
+        rht_matrix_nvte.data(), split_sections.data(), num_tensors, quant_config_list[0],
+        nvte_tile_scheduler_workspace.data(), stream);
   } else {
     // Separate quantization for rowwise usage and columnwise usage
     // Rowwise quantization fusion with grouped version