Unverified commit da3890e8 authored by SijiaYang, committed by GitHub

[1/n]: add cutlass W4A8 moe kernel for hopper architecture (#7772)


Signed-off-by: yangsijia.614 <yangsijia.614@bytedance.com>
Co-authored-by: yicwang <yichen.wang@bytedance.com>
parent cb432f17
@@ -249,6 +249,9 @@ set(SOURCES
"csrc/speculative/speculative_sampling.cu"
"csrc/grammar/apply_token_bitmask_inplace_cuda.cu"
"csrc/kvcacheio/transfer.cu"
"csrc/moe/cutlass_moe/w4a8/scaled_mm_entry.cu"
"csrc/moe/cutlass_moe/w4a8/w4a8_moe_data.cu"
"csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cu"
"csrc/common_extension.cc"
"csrc/moe/marlin_moe_wna16/ops.cu"
"csrc/moe/marlin_moe_wna16/gptq_marlin_repack.cu"
......
@@ -277,6 +277,25 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
"int num_layers) -> ()");
m.impl("transfer_kv_all_layer_mla_direct", torch::kCUDA, &transfer_kv_all_layer_mla_direct);
/*
* From csrc/moe/cutlass_moe/w4a8
*/
m.def(
"get_cutlass_w4a8_moe_mm_data(Tensor topk_ids, Tensor! expert_offsets, "
" Tensor! problem_sizes1, Tensor! problem_sizes2, "
" Tensor! input_permutation, "
" Tensor! output_permutation, int num_experts, "
" int n, int k) -> ()");
m.impl("get_cutlass_w4a8_moe_mm_data", torch::kCUDA, &get_cutlass_w4a8_moe_mm_data);
m.def(
"cutlass_w4a8_moe_mm(Tensor! d, Tensor a, Tensor b, "
" Tensor a_scales, Tensor b_scales, Tensor expert_offsets, "
" Tensor problem_sizes, Tensor a_strides, "
" Tensor b_strides, Tensor d_strides, Tensor s_strides,"
" int chunk_size, int topk) -> ()");
m.impl("cutlass_w4a8_moe_mm", torch::kCUDA, &cutlass_w4a8_moe_mm);
/*
* From FlashInfer
*/
......
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cute/arch/cluster_sm90.hpp"
#include "cute/tensor.hpp"
#include "cutlass/gemm/collective/builders/sm90_common.inl"
#include "cutlass/gemm/collective/collective_builder_decl.hpp"
#include "cutlass/gemm/collective/collective_mma_decl.hpp"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/pipeline/sm90_pipeline.hpp"
// SM90 Collective Builders should be used only starting CUDA 12.0
#if (__CUDACC_VER_MAJOR__ >= 12)
#define CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
#endif
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::collective {
/////////////////////////////////////////////////////////////////////////////////////////////////
// GMMA_TMA_WS_RS
template <
class ElementA_,
class GmemLayoutATag_,
int AlignmentA,
class ElementB_,
class GmemLayoutBTag_,
int AlignmentB,
class ElementAccumulator,
class TileShape_MNK,
class ClusterShape_MNK,
class StageCountType,
class KernelScheduleType>
struct CollectiveBuilderMixedInput<
arch::Sm90,
arch::OpClassTensorOp,
ElementA_,
GmemLayoutATag_,
AlignmentA,
ElementB_,
GmemLayoutBTag_,
AlignmentB,
ElementAccumulator,
TileShape_MNK,
ClusterShape_MNK,
StageCountType,
KernelScheduleType,
cute::enable_if_t<
(cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecialized> ||
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedPingpong> ||
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperative> ||
cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedCooperative> ||
cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedPingpong>) &&
(detail::is_use_rmem_A<ElementA_, GmemLayoutATag_, ElementB_, GmemLayoutBTag_>() ||
// ConvertAndScale and ConvertAndScaleWithZero
cute::is_tuple<ElementA_>::value || cute::is_tuple<ElementB_>::value ||
// DirectConvert
sizeof_bits<ElementA_>::value != sizeof_bits<ElementB_>::value)>> {
private:
using ScaleA = detail::deduce_mixed_width_dtype_t<1, ElementA_>;
using ScaleB = detail::deduce_mixed_width_dtype_t<1, ElementB_>;
using ZeroA = detail::deduce_mixed_width_dtype_t<2, ElementA_>;
using ZeroB = detail::deduce_mixed_width_dtype_t<2, ElementB_>;
static constexpr bool NeitherIsTuple = !cute::is_tuple<ElementA_>::value && !cute::is_tuple<ElementB_>::value;
// Determine if mixed input types.
static constexpr bool IsMixedInput = cute::sizeof_bits_v<detail::deduce_mixed_width_dtype_t<0, ElementA_>> !=
cute::sizeof_bits_v<detail::deduce_mixed_width_dtype_t<0, ElementB_>>;
static constexpr bool IsArrayOfPointersGemm = cute::is_any_of_v<
KernelScheduleType,
KernelPtrArrayTmaWarpSpecializedCooperative,
KernelPtrArrayTmaWarpSpecializedPingpong>;
static_assert(IsMixedInput || !IsArrayOfPointersGemm, "Only mixed input grouped RS GEMM is supported.");
public:
using ElementA = detail::deduce_mixed_width_dtype_t<0, ElementA_>;
using ElementB = detail::deduce_mixed_width_dtype_t<0, ElementB_>;
static_assert(
!IsMixedInput || (cute::is_tuple<ElementA_>::value ^ cute::is_tuple<ElementB_>::value ||
(NeitherIsTuple && (sizeof_bits<ElementA>::value != sizeof_bits<ElementB>::value))),
"Either A OR B must be a tuple or the widths of A and B must be different.");
static constexpr bool IsANarrow = sizeof_bits<ElementA>::value < sizeof_bits<ElementB>::value;
template <class T>
static auto get_stride(T const& t) {
if constexpr (not cute::is_layout<cute::remove_pointer_t<T>>::value) {
return t;
} else {
if constexpr (cute::is_pointer_v<T>) {
return &cute::stride(*t);
} else {
return cute::stride(t);
}
}
}
using GmemLayoutATag = decltype(get_stride(GmemLayoutATag_{}));
using GmemLayoutBTag = decltype(get_stride(GmemLayoutBTag_{}));
using ElementPairA =
cute::conditional_t<IsMixedInput && IsANarrow && NeitherIsTuple, cute::tuple<ElementA>, ElementA_>;
using ElementPairB =
cute::conditional_t<IsMixedInput && !IsANarrow && NeitherIsTuple, cute::tuple<ElementB>, ElementB_>;
static constexpr bool IsATransformed = cute::is_tuple<ElementPairA>::value;
using ElementScale = cute::conditional_t<IsATransformed, ScaleA, ScaleB>;
using ElementZero = cute::conditional_t<IsATransformed, ZeroA, ZeroB>;
static_assert(is_static<TileShape_MNK>::value);
static_assert(is_static<ClusterShape_MNK>::value);
static_assert(
detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(),
"Should meet TMA alignment requirement\n");
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
static_assert(cutlass::detail::dependent_false<ElementA>, "Unsupported Toolkit for SM90 Collective Builder\n");
#endif
static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_rs_tag_to_major_A<GmemLayoutATag>();
static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_rs_tag_to_major_B<GmemLayoutBTag>();
// If A is scaled, then we don't need to swap. Otherwise, we must ensure B goes to rmem and we must swap the
// operands.
static constexpr bool SwapAB =
IsMixedInput ? !IsATransformed : detail::is_swapAB<ElementA, GmemLayoutATag, ElementB, GmemLayoutBTag>();
static constexpr bool IsWarpSpecializedTransposeB =
detail::is_warpspecialized_transpose_B<ElementA, GmemLayoutATag, ElementB, GmemLayoutBTag, KernelScheduleType>();
static_assert(!IsMixedInput || !IsWarpSpecializedTransposeB, "Mixed input GEMM does not support WS transpose B.");
// When we relax the above assertion, we must handle setting the tile mma GmmaMajorB correctly.
static constexpr cute::GMMA::Major TiledMmaGmmaMajorB = SwapAB ? GmmaMajorA : GmmaMajorB;
// For fp32 types, map to tf32 MMA value type.
using ElementAMma = cute::conditional_t<cute::is_same_v<ElementA, float>, tfloat32_t, ElementA>;
using ElementBMma = cute::conditional_t<cute::is_same_v<ElementB, float>, tfloat32_t, ElementB>;
// Handle mixed dtypes and MMA.
using RealElementA = cute::conditional_t<SwapAB, ElementBMma, ElementAMma>;
using RealElementB = cute::conditional_t<SwapAB, ElementAMma, ElementBMma>;
using RealElementAMma = cute::conditional_t<IsMixedInput, RealElementB, RealElementA>;
// Always the same for element B.
using RealElementBMma = RealElementB;
static_assert(
!IsMixedInput || TiledMmaGmmaMajorB == GMMA::Major::K || sizeof_bits<RealElementB>::value == 16,
"Mixed input GEMM does not support MN major layout except for 16bit");
using AtomLayoutMNK = cute::conditional_t<
cute::is_any_of_v<
KernelScheduleType,
KernelTmaWarpSpecializedCooperative,
KernelPtrArrayTmaWarpSpecializedCooperative>,
Layout<Shape<_2, _1, _1>>,
Layout<Shape<_1, _1, _1>>>;
using TiledMma = decltype(cute::make_tiled_mma(
cute::GMMA::rs_op_selector<
RealElementAMma,
RealElementBMma,
ElementAccumulator,
TileShape_MNK,
GMMA::Major::K,
GMMA::Major::K>(),
AtomLayoutMNK{}));
using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{})));
using SmemLayoutAtomA = decltype(detail::rs_smem_selector<
GmmaMajorA,
ElementAMma,
decltype(cute::get<0>(TileShape_MNK{})),
decltype(cute::get<2>(TileShape_MNK{})),
IsWarpSpecializedTransposeB>());
using SmemLayoutAtomB = decltype(detail::rs_smem_selector<
GmmaMajorB,
ElementBMma,
decltype(cute::get<1>(TileShape_MNK{})),
decltype(cute::get<2>(TileShape_MNK{})),
IsWarpSpecializedTransposeB>());
static constexpr size_t SmemAlignmentA = cutlass::detail::alignment_for_swizzle(SmemLayoutAtomA{});
static constexpr size_t SmemAlignmentB = cutlass::detail::alignment_for_swizzle(SmemLayoutAtomB{});
static constexpr int SmemAlignment = static_cast<int>(cute::max(SmemAlignmentA, SmemAlignmentB));
// Handle mixed dtype array GEMM's size of tensor map storage.
static constexpr size_t TensorMapStorage = sizeof(cute::TmaDescriptor) * size_t(IsMixedInput) * 4;
static constexpr int KernelSmemCarveout = static_cast<int>(TensorMapStorage);
static constexpr int Sm90ReducedSmemCapacityBytes = detail::sm90_smem_capacity_bytes - KernelSmemCarveout;
static constexpr int PipelineStages =
IsMixedInput ? (IsArrayOfPointersGemm ? detail::compute_stage_count_or_override_single_affine_transformed_input<
Sm90ReducedSmemCapacityBytes,
RealElementA,
RealElementB,
ElementScale,
ElementZero,
TileShape_MNK,
StageCountType::bytes,
SmemAlignment>(StageCountType{})
: detail::compute_stage_count_or_override_single_affine_transformed_input<
detail::sm90_smem_capacity_bytes,
RealElementA,
RealElementB,
ElementScale,
ElementZero,
TileShape_MNK,
StageCountType::bytes,
SmemAlignment>(StageCountType{}))
: detail::compute_stage_count_or_override<
detail::sm90_smem_capacity_bytes,
ElementAMma,
ElementBMma,
TileShape_MNK,
StageCountType::bytes,
SmemAlignment>(StageCountType{});
using DispatchPolicy = cute::conditional_t<
IsMixedInput,
cute::conditional_t<
IsArrayOfPointersGemm,
MainloopSm90ArrayTmaGmmaWarpSpecializedMixedInput<PipelineStages, ClusterShape_MNK, KernelScheduleType>,
MainloopSm90TmaGmmaRmemAWarpSpecializedMixedInput<PipelineStages, ClusterShape_MNK, KernelScheduleType>>,
MainloopSm90TmaGmmaRmemAWarpSpecialized<PipelineStages, ClusterShape_MNK, KernelScheduleType>>;
using SmemCopyAtomA = cute::conditional_t<SwapAB, void, Copy_Atom<cute::AutoVectorizingCopy, ElementA>>;
using SmemCopyAtomB = cute::conditional_t<SwapAB, Copy_Atom<cute::AutoVectorizingCopy, ElementB>, void>;
// We pack the scale data with the operand that will be optionally scaled and converted before MMA.
using StrideA = cute::conditional_t<
cute::is_layout<cute::remove_pointer_t<GmemLayoutATag_>>::value,
GmemLayoutATag_,
TagToStrideA_t<GmemLayoutATag>>;
using StrideB = cute::conditional_t<
cute::is_layout<cute::remove_pointer_t<GmemLayoutBTag_>>::value,
GmemLayoutBTag_,
TagToStrideB_t<GmemLayoutBTag>>;
using CollectiveOp = CollectiveMmaArrayMixedInput<
DispatchPolicy,
TileShape_MNK,
ElementPairA,
StrideA,
ElementPairB,
StrideB,
TiledMma,
GmemTiledCopyA,
SmemLayoutAtomA,
SmemCopyAtomA,
cute::identity,
GmemTiledCopyB,
SmemLayoutAtomB,
SmemCopyAtomB,
cute::identity>;
static_assert(
SmemAlignment == static_cast<int>(cute::max(CollectiveOp::SmemAlignmentA, CollectiveOp::SmemAlignmentB)));
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::collective
/////////////////////////////////////////////////////////////////////////////////////////////////
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
/////////////////////////////////////////////////////////////////////////////////////////////////
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass_extensions/gemm/collective/collective_mma_array_mixed_input.hpp"
namespace cutlass::gemm::collective {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
class ArchTag,
class OpClass,
class ElementA,
class GmemLayoutA,
int AlignmentA,
class ElementB,
class GmemLayoutB,
int AlignmentB,
class ElementAccumulator,
class TileShape_MNK,
class ClusterShape_MNK,
class StageCountType,
class KernelScheduleType,
class Enable = void>
struct CollectiveBuilderMixedInput {
static_assert(sizeof(ElementA) == 0, "Could not build a collective for given parameters.");
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::collective
/////////////////////////////////////////////////////////////////////////////////////////////////
#include "cutlass_extensions/gemm/collective/builders/sm90_gmma_builder_mixed_input.inl"
/////////////////////////////////////////////////////////////////////////////////////////////////
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cutlass/detail/dependent_false.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::collective {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
class DispatchPolicy,
class TileShape,
class ElementA,
class StrideA,
class ElementB,
class StrideB,
class TiledMma,
class GmemTiledCopyA,
class SmemLayoutAtomA,
class SmemCopyAtomA,
class TransformA,
class GmemTiledCopyB,
class SmemLayoutAtomB,
class SmemCopyAtomB,
class TransformB>
struct CollectiveMmaArrayMixedInput {
static_assert(cutlass::detail::dependent_false<ElementA>, "Could not find a mainloop specialization.");
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::collective
/////////////////////////////////////////////////////////////////////////////////////////////////
#include "cutlass_extensions/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input_.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////
#include <c10/cuda/CUDAGuard.h>
#include <cudaTypedefs.h>
#include <torch/all.h>
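// Returns the compute capability of CUDA device 0 as major * 10 + minor (e.g. 90 on Hopper).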
int32_t get_sm_version_num() {
int32_t major_capability, minor_capability;
cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor, 0);
cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor, 0);
int32_t version_num = major_capability * 10 + minor_capability;
return version_num;
}
void cutlass_w4a8_moe_mm_sm90(
torch::Tensor& d_tensors,
torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes,
torch::Tensor const& a_strides,
torch::Tensor const& b_strides,
torch::Tensor const& d_strides,
torch::Tensor const& s_strides,
int64_t chunk_size,
int64_t topk);
void get_cutlass_w4a8_moe_mm_data_caller(
const torch::Tensor& topk_ids,
torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2,
torch::Tensor& input_permutation,
torch::Tensor& output_permutation,
const int64_t num_experts,
const int64_t n,
const int64_t k);
void cutlass_w4a8_moe_mm(
torch::Tensor& d_tensors,
torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes,
torch::Tensor const& a_strides,
torch::Tensor const& b_strides,
torch::Tensor const& d_strides,
torch::Tensor const& s_strides,
int64_t chunk_size,
int64_t topk) {
cutlass_w4a8_moe_mm_sm90(
d_tensors,
a_tensors,
b_tensors,
a_scales,
b_scales,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size,
topk);
return;
}
void get_cutlass_w4a8_moe_mm_data(
const torch::Tensor& topk_ids,
torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2,
torch::Tensor& input_permutation,
torch::Tensor& output_permutation,
const int64_t num_experts,
const int64_t n,
const int64_t k) {
get_cutlass_w4a8_moe_mm_data_caller(
topk_ids,
expert_offsets,
problem_sizes1,
problem_sizes2,
input_permutation,
output_permutation,
num_experts,
n,
k);
return;
}
#pragma once
#include <c10/cuda/CUDAStream.h>
#include <cuda.h>
#include <torch/all.h>
#include "cutlass/bfloat16.h"
#include "cutlass/float8.h"
template <typename ElementA, typename ElementB, typename ElementC, typename ElementAccumulator>
__global__ void int4_fp8_get_group_gemm_starts(
int32_t* expert_offsets,
ElementA** a_offsets,
ElementB** b_offsets,
ElementC** out_offsets,
ElementAccumulator** a_scales_offsets,
cutlass::bfloat16_t** b_scales_offsets,
ElementA* a_base_as_int,
ElementB* b_base_as_int,
ElementC* out_base_as_int,
ElementAccumulator* a_scales_base_as_int,
cutlass::bfloat16_t* b_scales_base_as_int,
int64_t n,
int64_t k,
bool per_act_token,
bool per_out_ch) {
int expert_id = threadIdx.x;
int32_t expert_offset = expert_offsets[expert_id];
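// Per-expert base pointers: A and D advance by this expert's starting row (expert_offset)
// times K / N elements; B advances by one expert's N * K/2 packed int4 bytes; A scales
// advance per token only when per-activation-token scaling is enabled; B scales advance by
// one expert's (K/512) x (N*4) packed bf16 block when per-output-channel scales are used.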
a_offsets[expert_id] = a_base_as_int + expert_offset * k;
b_offsets[expert_id] = b_base_as_int + expert_id * k * n / 2;
out_offsets[expert_id] = out_base_as_int + expert_offset * n;
a_scales_offsets[expert_id] = a_scales_base_as_int + (per_act_token ? expert_offset : 0);
b_scales_offsets[expert_id] = b_scales_base_as_int + (per_out_ch ? expert_id * n * 4 * k / 512 : expert_id);
}
#define __CALL_W4A8_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE) \
else if (out_tensors.dtype() == TENSOR_C_TYPE) { \
int4_fp8_get_group_gemm_starts<cutlass::float_e4m3_t, cutlass::int8_t, C_TYPE, float> \
<<<1, num_experts, 0, stream>>>( \
static_cast<int32_t*>(expert_offsets.data_ptr()), \
static_cast<cutlass::float_e4m3_t**>(a_ptrs.data_ptr()), \
static_cast<cutlass::int8_t**>(b_ptrs.data_ptr()), \
static_cast<C_TYPE**>(out_ptrs.data_ptr()), \
static_cast<float**>(a_scales_ptrs.data_ptr()), \
static_cast<cutlass::bfloat16_t**>(b_scales_ptrs.data_ptr()), \
static_cast<cutlass::float_e4m3_t*>(a_tensors.data_ptr()), \
static_cast<cutlass::int8_t*>(b_tensors.data_ptr()), \
static_cast<C_TYPE*>(out_tensors.data_ptr()), \
static_cast<float*>(a_scales.data_ptr()), \
static_cast<cutlass::bfloat16_t*>(b_scales.data_ptr()), \
out_tensors.size(1), \
a_tensors.size(1), \
per_act_token, \
per_out_ch); \
}
namespace {
void run_int4_fp8_get_group_gemm_starts(
torch::Tensor const& expert_offsets,
torch::Tensor& a_ptrs,
torch::Tensor& b_ptrs,
torch::Tensor& out_ptrs,
torch::Tensor& a_scales_ptrs,
torch::Tensor& b_scales_ptrs,
torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors,
torch::Tensor& out_tensors,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(b_tensors.dtype() == torch::kInt8);
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kBFloat16);
int num_experts = static_cast<int>(expert_offsets.size(0));
bool per_act_token = a_scales.numel() != 1;
bool per_out_ch = b_scales.numel() != num_experts;
auto stream = at::cuda::getCurrentCUDAStream(expert_offsets.device().index());
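// The empty `if (false)` lets the macro-generated `else if` branches below chain cleanly.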
if (false) {
}
__CALL_W4A8_GET_STARTS_KERNEL(torch::kBFloat16, cutlass::bfloat16_t)
__CALL_W4A8_GET_STARTS_KERNEL(torch::kFloat16, half)
else {
TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
}
}
} // namespace
#include <c10/cuda/CUDAGuard.h>
#include <cudaTypedefs.h>
#include <torch/all.h>
#include "cutlass/cutlass.h"
#include "w4a8_grouped_mm_c3x.cuh"
using namespace cute;
namespace {
#define JOIN_STRUCT_NAME(m, n, k, a, b, c) sm90_fp8_config##_##m##_##n##_##k##_##a##_##b##_##c
#define JOIN_STRUCT_NAME_CO(m, n, k, a, b, c) sm90_fp8_co_config##_##m##_##n##_##k##_##a##_##b##_##c
#define GENERATE_SM90_W4A8_PP_CONFIG(M, N, K, A, B, C) \
struct JOIN_STRUCT_NAME(M, N, K, A, B, C) { \
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; \
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; \
using TileShape = cute::Shape<cute::Int<M>, cute::Int<N>, cute::Int<K>>; \
using ClusterShape = cute::Shape<cute::Int<A>, cute::Int<B>, cute::Int<C>>; \
\
using Cutlass3xW4A8Gemm = cutlass_3x_w4a8_group_gemm<TileShape, ClusterShape, KernelSchedule, EpilogueSchedule>; \
};
#define GENERATE_SM90_W4A8_CO_CONFIG(M, N, K, A, B, C) \
struct JOIN_STRUCT_NAME_CO(M, N, K, A, B, C) { \
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; \
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; \
using TileShape = cute::Shape<cute::Int<M>, cute::Int<N>, cute::Int<K>>; \
using ClusterShape = cute::Shape<cute::Int<A>, cute::Int<B>, cute::Int<C>>; \
\
using Cutlass3xW4A8Gemm = cutlass_3x_w4a8_group_gemm<TileShape, ClusterShape, KernelSchedule, EpilogueSchedule>; \
};
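// Pre-instantiated tile-shape (M, N, K) / cluster-shape (A, B, C) configurations; the
// dispatch logic below picks one based on the GEMM's n, k, and per-expert m.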
GENERATE_SM90_W4A8_PP_CONFIG(64, 16, 512, 1, 1, 1)
GENERATE_SM90_W4A8_PP_CONFIG(64, 32, 512, 2, 1, 1)
GENERATE_SM90_W4A8_CO_CONFIG(128, 16, 512, 1, 1, 1)
GENERATE_SM90_W4A8_CO_CONFIG(128, 16, 512, 2, 1, 1)
GENERATE_SM90_W4A8_CO_CONFIG(128, 32, 512, 1, 1, 1)
GENERATE_SM90_W4A8_CO_CONFIG(128, 32, 512, 2, 1, 1)
GENERATE_SM90_W4A8_CO_CONFIG(128, 64, 512, 1, 1, 1)
void dispatch_w4a8_moe_mm_sm90(
torch::Tensor& d_tensors,
torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes,
torch::Tensor const& a_strides,
torch::Tensor const& b_strides,
torch::Tensor const& d_strides,
torch::Tensor const& s_strides,
int64_t chunk_size,
int64_t topk) {
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative;
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;
uint32_t const m = a_tensors.size(0) / topk;
uint32_t const n = d_tensors.size(1);
uint32_t const k = a_tensors.size(1);
if (n == 4096 && k == 7168) {
// group gemm 1
if (m <= 4) {
using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME(64, 32, 512, 2, 1, 1)::Cutlass3xW4A8Gemm;
cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
d_tensors,
a_tensors,
b_tensors,
a_scales,
b_scales,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size);
} else if (m <= 16) {
using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 16, 512, 2, 1, 1)::Cutlass3xW4A8Gemm;
cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
d_tensors,
a_tensors,
b_tensors,
a_scales,
b_scales,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size);
} else if (m <= 256) {
using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 16, 512, 1, 1, 1)::Cutlass3xW4A8Gemm;
cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
d_tensors,
a_tensors,
b_tensors,
a_scales,
b_scales,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size);
} else if (m <= 1024) {
using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 32, 512, 2, 1, 1)::Cutlass3xW4A8Gemm;
cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
d_tensors,
a_tensors,
b_tensors,
a_scales,
b_scales,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size);
} else {
using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 64, 512, 1, 1, 1)::Cutlass3xW4A8Gemm;
cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
d_tensors,
a_tensors,
b_tensors,
a_scales,
b_scales,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size);
}
} else if (n == 7168 && k == 2048) {
// group gemm 2
if (m <= 8) {
using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME(64, 16, 512, 1, 1, 1)::Cutlass3xW4A8Gemm;
cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
d_tensors,
a_tensors,
b_tensors,
a_scales,
b_scales,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size);
} else if (m <= 512) {
using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 32, 512, 1, 1, 1)::Cutlass3xW4A8Gemm;
cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
d_tensors,
a_tensors,
b_tensors,
a_scales,
b_scales,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size);
} else {
using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 64, 512, 1, 1, 1)::Cutlass3xW4A8Gemm;
cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
d_tensors,
a_tensors,
b_tensors,
a_scales,
b_scales,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size);
}
} else {
using Cutlass3xW4A8GemmSelected = typename JOIN_STRUCT_NAME_CO(128, 32, 512, 1, 1, 1)::Cutlass3xW4A8Gemm;
cutlass_w4a8_group_gemm_caller<Cutlass3xW4A8GemmSelected>(
d_tensors,
a_tensors,
b_tensors,
a_scales,
b_scales,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size);
}
}
} // namespace
void cutlass_w4a8_moe_mm_sm90(
torch::Tensor& d_tensors,
torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes,
torch::Tensor const& a_strides,
torch::Tensor const& b_strides,
torch::Tensor const& d_strides,
torch::Tensor const& s_strides,
int64_t chunk_size,
int64_t topk) {
dispatch_w4a8_moe_mm_sm90(
d_tensors,
a_tensors,
b_tensors,
a_scales,
b_scales,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size,
topk);
}
#pragma once
/**
* @file w4a8_grouped_mm_c3x.cuh
* @brief Implementation of grouped GEMM operation with int4 and fp8 mixed
* precision
*
* This file implements a grouped GEMM operation that multiplies FP8 matrices
* (A) with quantized INT4 matrices (B), applying per-block scaling factors.
* The implementation is optimized for NVIDIA Hopper GPUs, leveraging Tensor
* Cores for mixed precision arithmetic.
*
* Key features:
* - Supports grouped GEMM operations with multiple experts
* - Uses FP8 (e4m3) for matrix A
* - Uses INT4 quantization for matrix B with per-block scaling
* - Implements preprocessing for INT4 encoding and scale packing
* - Optimized for Hopper architecture with Tensor Core operations
*/
#include <ATen/cuda/CUDAContext.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include <torch/all.h>
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass_extensions/gemm/collective/collective_builder_mixed_input.hpp"
#include "w4a8_get_group_starts.cuh"
using namespace cute;
namespace {
// Type definitions
using MmaType = cutlass::float_e4m3_t; // FP8 e4m3 type
using QuantType = cutlass::int4b_t; // 4-bit integer type
using ElementAccumulator = float; // Accumulator type
using ElementScale = cutlass::bfloat16_t; // Scale type
using ElementScalePacked = cutlass::Array<ElementScale, 4>;
using ElementC = cutlass::half_t; // Default output type (FP16)
using ElementD = ElementC; // Default output type (FP16)
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int, int, int>>;
// Architecture-specific configurations
using ArchTag = cutlass::arch::Sm90;
using OperatorClass = cutlass::arch::OpClassTensorOp;
// constexpr int TileShapeK = 512;
// using TileShape = Shape<_128, _32, cute::Int<TileShapeK>>;
// using ClusterShape = Shape<_1, _1, _1>;
// Layout configurations
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = LayoutC;
// Transposed layouts
using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose<LayoutA>::type;
using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose<LayoutB>::type;
using LayoutC_Transpose = typename cutlass::layout::LayoutTranspose<LayoutC>::type;
using LayoutD_Transpose = typename cutlass::layout::LayoutTranspose<LayoutD>::type;
// Alignments
static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<MmaType>::value;
static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<QuantType>::value;
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
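// Note: the mainloop builder below is instantiated with the logical A/B operands swapped and
// with transposed layouts (effectively computing D^T = W^T * A^T), which places the int4
// weights and their packed scales in the operand slot that the mixed-input RS mainloop
// upconverts in registers before the MMA.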
template <typename TileShape, typename ClusterShape, typename KernelSchedule, typename EpilogueSchedule>
struct cutlass_3x_w4a8_group_gemm {
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OperatorClass,
TileShape,
ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator,
ElementAccumulator,
ElementC,
LayoutC_Transpose*,
AlignmentC,
ElementD,
LayoutD_Transpose*,
AlignmentD,
EpilogueSchedule>::CollectiveOp;
using CollectiveMainloopScaleOnly = typename cutlass::gemm::collective::CollectiveBuilderMixedInput<
ArchTag,
OperatorClass,
cute::tuple<QuantType, ElementScalePacked>,
LayoutB_Transpose*,
AlignmentB,
MmaType,
LayoutA_Transpose*,
AlignmentA,
ElementAccumulator,
TileShape,
ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule>::CollectiveOp;
// Define the final kernel and GEMM operation types
using GemmKernelScaleOnly =
cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloopScaleOnly, CollectiveEpilogue>;
using GemmScaleOnly = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelScaleOnly>;
using StrideA = cute::remove_pointer_t<cutlass::detail::TagToStrideA_t<LayoutA*>>;
using StrideB = cute::remove_pointer_t<cutlass::detail::TagToStrideB_t<LayoutB*>>;
using StrideC = typename GemmKernelScaleOnly::InternalStrideC;
using StrideD = typename GemmKernelScaleOnly::InternalStrideD;
using StrideS = typename CollectiveMainloopScaleOnly::StrideScale;
};
/**
* @brief Main function to run int4 * fp8 grouped GEMM from PyTorch
*
* This function performs multiple GEMM operations in parallel where each
* operation multiplies an FP8 matrix (A) with a quantized INT4 matrix (B),
* applying group-wise scaling factors. It is designed for efficient execution
* on NVIDIA Hopper GPUs, leveraging Tensor Cores for optimal performance with
* mixed precision arithmetic.
*
* The function includes preprocessing steps for both INT4 tensors and scale
* factors to ensure optimal performance and correct operation.
*
* @param d_tensors Output tensor D with shape [total_m, total_n]
* @param a_tensors Tensor containing all A matrices (fp8_e4m3) with shape
* [total_m, K]
* @param b_tensors Tensor containing all B matrices (int4 packed as int8) with
* shape [E, N, K/2]
* @param a_scales Tensor containing A matrix scale factors
* @param b_scales Tensor containing B matrix scale factors with shape [E,
* K//512, N*4]
* @param expert_offsets Tensor containing expert offsets for determining group
* boundaries (int32)
* @param problem_sizes Tensor containing problem sizes with shape [num_experts,
* 3], storing (N, M, K) for each group (int32)
* @param a_strides Stride information for A tensors
* @param b_strides Stride information for B tensors
* @param d_strides Stride information for D tensors
* @param s_strides Stride information for scale tensors
* @param chunk_size Size of each chunk for scales (K / number of scale chunks)
*/
// template <typename TileShape, typename ClusterShape, typename KernelSchedule, typename EpilogueSchedule>
template <typename Gemm>
void cutlass_w4a8_group_gemm_caller(
torch::Tensor& d_tensors,
torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes,
torch::Tensor const& a_strides,
torch::Tensor const& b_strides,
torch::Tensor const& d_strides,
torch::Tensor const& s_strides,
int64_t chunk_size) {
// using Gemm = cutlass_3x_w4a8_group_gemm<TileShape, ClusterShape, KernelSchedule, EpilogueSchedule>;
using Args = typename Gemm::GemmScaleOnly::Arguments;
int num_experts = static_cast<int>(expert_offsets.size(0));
bool per_act_token = a_scales.numel() != 1;
bool per_out_ch = b_scales.numel() != num_experts;
// Check inputs
TORCH_CHECK(a_tensors.dim() == 2, "A tensor must be 2D");
TORCH_CHECK(b_tensors.dim() == 3, "B tensor must be 3D [E, N, K/2]");
TORCH_CHECK(b_scales.dim() == 3, "Scale tensor must be 3D [E, K//512, N*4]");
TORCH_CHECK(a_scales.dim() == 1, "A Scale tensor must be 1D [1]");
TORCH_CHECK(expert_offsets.dim() == 1, "expert_offsets must be a 1D tensor");
TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");
// Check tensor shapes
TORCH_CHECK(problem_sizes.size(0) == num_experts, "problem_sizes must have num_experts rows");
TORCH_CHECK(problem_sizes.size(1) == 3, "problem_sizes must have 3 columns (N, M, K)");
TORCH_CHECK(b_tensors.size(0) == num_experts, "B tensor first dimension must match number of groups");
TORCH_CHECK(b_scales.size(0) == num_experts, "Scale tensor first dimension must match number of groups");
TORCH_CHECK(b_tensors.size(2) * 2 == a_tensors.size(1), "B tensor K/2 dimension must match A tensor K dimension");
TORCH_CHECK(b_scales.size(1) == a_tensors.size(1) / 512, "Scale tensor second dimension must be K//512");
TORCH_CHECK(b_scales.size(2) == 4 * b_tensors.size(1), "Scale tensor last dimension must be 4*N");
// Check tensor types
TORCH_CHECK(a_tensors.scalar_type() == torch::kFloat8_e4m3fn, "A tensor must be fp8 (float_e4m3_t) type");
TORCH_CHECK(b_tensors.scalar_type() == torch::kInt8, "B tensor must contain packed int4 values (stored as int8)");
TORCH_CHECK(expert_offsets.scalar_type() == torch::kInt32, "Expert offsets must be int32 type");
TORCH_CHECK(problem_sizes.scalar_type() == torch::kInt32, "Problem sizes must be int32 type");
auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
auto options_int = torch::TensorOptions().dtype(torch::kInt64).device(a_tensors.device());
torch::Tensor a_ptrs = torch::empty(num_experts, options_int);
torch::Tensor b_ptrs = torch::empty(num_experts, options_int);
torch::Tensor out_ptrs = torch::empty(num_experts, options_int);
torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int);
torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int);
cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = a_tensors.device().index();
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
Args arguments;
decltype(arguments.epilogue.thread) fusion_args;
fusion_args.alpha = 1.0f;
fusion_args.beta = 0;
fusion_args.alpha_ptr = a_scales.data_ptr<float>();
fusion_args.beta_ptr = nullptr;
fusion_args.alpha_ptr_array = nullptr;
fusion_args.beta_ptr_array = nullptr;
fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 0};
fusion_args.dBeta = {cute::_0{}, cute::_0{}, 0};
ProblemShape::UnderlyingProblemShape* problem_sizes_as_shapes =
static_cast<ProblemShape::UnderlyingProblemShape*>(problem_sizes.data_ptr());
run_int4_fp8_get_group_gemm_starts(
expert_offsets,
a_ptrs,
b_ptrs,
out_ptrs,
a_scales_ptrs,
b_scales_ptrs,
a_tensors,
b_tensors,
d_tensors,
a_scales,
b_scales);
arguments = Args{
cutlass::gemm::GemmUniversalMode::kGrouped,
{num_experts, problem_sizes_as_shapes, nullptr},
{static_cast<const QuantType**>(b_ptrs.data_ptr()),
static_cast<typename Gemm::StrideB*>(b_strides.data_ptr()),
static_cast<const MmaType**>(a_ptrs.data_ptr()),
static_cast<typename Gemm::StrideA*>(a_strides.data_ptr()),
static_cast<const ElementScalePacked**>(b_scales_ptrs.data_ptr()),
static_cast<typename Gemm::StrideS*>(s_strides.data_ptr()),
static_cast<int>(chunk_size)},
{fusion_args,
nullptr,
nullptr,
static_cast<ElementD**>(out_ptrs.data_ptr()),
static_cast<typename Gemm::StrideD*>(d_strides.data_ptr())},
hw_info};
// Instantiate and run GEMM
typename Gemm::GemmScaleOnly gemm;
size_t workspace_size = Gemm::GemmScaleOnly::get_workspace_size(arguments);
auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a_tensors.device());
auto workspace = torch::empty(workspace_size, workspace_options);
cutlass::Status status = gemm.can_implement(arguments);
if (status != cutlass::Status::kSuccess) {
TORCH_CHECK(false, "GEMM implementation not supported");
}
status = gemm.initialize(arguments, workspace.data_ptr(), stream);
if (status != cutlass::Status::kSuccess) {
TORCH_CHECK(false, "GEMM initialization failed");
}
status = gemm.run(stream);
if (status != cutlass::Status::kSuccess) {
TORCH_CHECK(false, "GEMM execution failed");
}
}
} // namespace
#include <c10/cuda/CUDAGuard.h>
#include <cudaTypedefs.h>
#include <torch/all.h>
#include <iostream>
constexpr uint64_t THREADS_PER_EXPERT = 512;
__global__ void compute_problem_sizes_w4a8(
const int32_t* __restrict__ topk_ids,
int32_t* problem_sizes1,
int32_t* problem_sizes2,
int32_t* atomic_buffer,
const int topk_length,
const int n,
const int k) {
int expert_id = blockIdx.x;
int occurrences = 0;
for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) {
occurrences += (topk_ids[i] == expert_id);
}
atomicAdd(&atomic_buffer[expert_id], occurrences);
__syncthreads();
if (threadIdx.x == 0) {
int final_occurrences = atomic_buffer[expert_id];
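// Problem sizes are stored as (N, M, K): the first grouped GEMM produces 2*n output
// columns per token (consistent with a fused gate/up projection in the MoE MLP) from k
// inputs; the second maps n intermediate columns back to k.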
problem_sizes1[expert_id * 3] = 2 * n;
problem_sizes1[expert_id * 3 + 1] = final_occurrences;
problem_sizes1[expert_id * 3 + 2] = k;
problem_sizes2[expert_id * 3] = k;
problem_sizes2[expert_id * 3 + 1] = final_occurrences;
problem_sizes2[expert_id * 3 + 2] = n;
}
}
__global__ void compute_expert_offsets_w4a8(
const int32_t* __restrict__ problem_sizes1,
int32_t* expert_offsets,
int32_t* atomic_buffer,
const int num_experts) {
int32_t tot_offset = 0;
expert_offsets[0] = 0;
for (int i = 0; i < num_experts; ++i) {
atomic_buffer[i] = tot_offset;
tot_offset += problem_sizes1[i * 3 + 1];
expert_offsets[i + 1] = tot_offset;
}
}
void get_cutlass_w4a8_moe_mm_data_caller(
const torch::Tensor& topk_ids,
torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2,
torch::Tensor& input_permutation,
torch::Tensor& output_permutation,
const int64_t num_experts,
const int64_t n,
const int64_t k) {
auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index());
auto options_int32 = torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device());
torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32);
int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
compute_problem_sizes_w4a8<<<num_experts, num_threads, 0, stream>>>(
static_cast<const int32_t*>(topk_ids.data_ptr()),
static_cast<int32_t*>(problem_sizes1.data_ptr()),
static_cast<int32_t*>(problem_sizes2.data_ptr()),
static_cast<int32_t*>(atomic_buffer.data_ptr()),
topk_ids.numel(),
n,
k);
compute_expert_offsets_w4a8<<<1, 1, 0, stream>>>(
static_cast<const int32_t*>(problem_sizes1.data_ptr()),
static_cast<int32_t*>(expert_offsets.data_ptr()),
static_cast<int32_t*>(atomic_buffer.data_ptr()),
num_experts);
}
@@ -467,6 +467,35 @@ void transfer_kv_all_layer_mla_direct(
int64_t page_size,
int64_t num_layers);
/*
* From csrc/moe/cutlass_moe/w4a8
*/
void get_cutlass_w4a8_moe_mm_data(
const torch::Tensor& topk_ids,
torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2,
torch::Tensor& input_permutation,
torch::Tensor& output_permutation,
const int64_t num_experts,
const int64_t n,
const int64_t k);
void cutlass_w4a8_moe_mm(
torch::Tensor& d_tensors,
torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes,
torch::Tensor const& a_strides,
torch::Tensor const& b_strides,
torch::Tensor const& d_strides,
torch::Tensor const& s_strides,
int64_t chunk_size,
int64_t topk);
/*
* From FlashInfer
*/
......
@@ -19,6 +19,7 @@ from sgl_kernel.attention import (
merge_state,
merge_state_v2,
)
from sgl_kernel.cutlass_moe import cutlass_w4a8_moe_mm, get_cutlass_w4a8_moe_mm_data
from sgl_kernel.elementwise import (
apply_rope_with_cos_sin_cache_inplace,
fused_add_rmsnorm,
......
import torch
def get_cutlass_w4a8_moe_mm_data(
topk_ids: torch.Tensor,
expert_offsets: torch.Tensor,
problem_sizes1: torch.Tensor,
problem_sizes2: torch.Tensor,
input_permutation: torch.Tensor,
output_permutation: torch.Tensor,
num_experts: int,
n: int,
k: int,
):
"""
Prepare data necessary to perform CUTLASS grouped matrix multiplications
used in CUTLASS-based fused MoE.
The function takes in topk_ids (token-expert mapping) and uses it to
compute:
- expert_offsets: Indices that mark at which token index each expert begins
its computation after the input is sorted with
input_permutation. The number of tokens computed with
expert E is expert_offsets[E + 1] - expert_offsets[E]
- problem_sizes1, problem_sizes2: MxNxK sizes of each expert's
multiplication in two grouped MMs used in
the fused MoE operation.
- input_permutation: Permutation that must be used to shuffle the input
before executing the MMs.
- output_permutation: Permutation that must be used to shuffle the output
after executing the MMs.
"""
torch.ops.sgl_kernel.get_cutlass_w4a8_moe_mm_data.default(
topk_ids,
expert_offsets,
problem_sizes1,
problem_sizes2,
input_permutation,
output_permutation,
num_experts,
n,
k,
)
def cutlass_w4a8_moe_mm(
d: torch.Tensor,
a: torch.Tensor,
b: torch.Tensor,
a_scales: torch.Tensor,
b_scales: torch.Tensor,
experts_offsets: torch.Tensor,
problem_sizes: torch.Tensor,
a_strides: torch.Tensor,
b_strides: torch.Tensor,
d_strides: torch.Tensor,
s_strides: torch.Tensor,
chunk_size: int = 128,
topk: int = 8,
):
"""
Perform grouped matrix multiplication between int4 weights and fp8 activations.
This function executes multiple GEMM operations in parallel, which is useful for
scenarios like Mixture of Experts (MoE) where different inputs go through different
experts. The implementation leverages NVIDIA Hopper architecture features for
optimal performance with quantized weights.
Args:
d: Output matrices of shape [total_m, total_n]
a: Activation matrices in FP8 (float_e4m3_t) format
Each tensor should be of shape [total_m, K] in row-major layout
b: Weight matrices in packed int4 format
Each tensor should be of shape [E, N, K//2] in column-major layout
where each byte contains two 4-bit integers
a_scales: Scale factors for the inputs
b_scales: Scale factors for the quantized weights
Each tensor should be of shape [E, K//512, N*4]
experts_offsets: Tensor containing expert offsets for determining group boundaries
problem_sizes: Problem sizes with shape [num_experts, 3], storing (N, M, K) for each group (int32)
a_strides: Strides information for A matrices
b_strides: Strides information for B matrices
d_strides: Strides information for D matrices
s_strides: Strides information for b_scales matrices
chunk_size: Number of K elements each scale value applies to; defaults to 128
Requirements:
- All tensors must be on a CUDA device
- Requires an NVIDIA Hopper GPU (H100)
- A tensors must be in float8_e4m3fn format
- B tensors must contain packed int4 values (stored as int8)
Note:
The function computes: D = (A * (B * scales))
for each group of tensors in parallel
"""
torch.ops.sgl_kernel.cutlass_w4a8_moe_mm.default(
d,
a,
b,
a_scales,
b_scales,
experts_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size,
topk,
)
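A minimal usage sketch of the data-preparation wrapper above, assuming the buffer shapes implied by the CUDA kernels in this diff: expert_offsets needs num_experts + 1 int32 entries, problem_sizes1/2 are [num_experts, 3] int32, and the permutation buffers are allocated with one entry per top-k assignment. The dimensions (n, k, topk) here are illustrative only; the grouped GEMM itself is exercised end to end in the tests that follow.

import torch
from sgl_kernel import get_cutlass_w4a8_moe_mm_data

num_experts, n, k, topk = 8, 1024, 512, 2
topk_ids = torch.randint(0, num_experts, (16, topk), dtype=torch.int32, device="cuda")

# Output buffers filled by the op (shapes assumed from the CUDA kernels above).
expert_offsets = torch.empty(num_experts + 1, dtype=torch.int32, device="cuda")
problem_sizes1 = torch.empty((num_experts, 3), dtype=torch.int32, device="cuda")
problem_sizes2 = torch.empty((num_experts, 3), dtype=torch.int32, device="cuda")
input_permutation = torch.empty(topk_ids.numel(), dtype=torch.int32, device="cuda")
output_permutation = torch.empty(topk_ids.numel(), dtype=torch.int32, device="cuda")

get_cutlass_w4a8_moe_mm_data(
topk_ids,
expert_offsets,
problem_sizes1,
problem_sizes2,
input_permutation,
output_permutation,
num_experts,
n,
k,
)
# Afterwards problem_sizes1[e] == (2*n, tokens_for_expert_e, k),
# problem_sizes2[e] == (k, tokens_for_expert_e, n), and expert_offsets holds the
# running prefix sum of per-expert token counts.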
import pytest
import torch
from sgl_kernel import cutlass_w4a8_moe_mm
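# Packs pairs of int4 values into int8: element 2i goes to the low nibble and
# element 2i+1 to the high nibble of each output byte.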
def pack_int4_values_to_int8(int4_values_interleaved: torch.Tensor) -> torch.Tensor:
if int4_values_interleaved.shape[-1] % 2 != 0:
raise ValueError(
"the last dim size of int4_values_interleaved tensor must be even."
)
input_tensor_int8 = int4_values_interleaved.to(torch.int8)
low_nibbles = input_tensor_int8[..., 0::2]
high_nibbles = input_tensor_int8[..., 1::2]
packed_tensor = (high_nibbles << 4) | (low_nibbles & 0x0F)
return packed_tensor.to(torch.int8)
def pack_interleave(num_experts, ref_weight, ref_scale):
n, k = ref_weight.shape[1], ref_weight.shape[2]
weight = pack_int4_values_to_int8(ref_weight.cpu()).cuda()
w_q = weight.view((num_experts, n, k // 2)).view(torch.int8)
w_q = w_q.contiguous()
scale_interleaved = ref_scale.reshape(
ref_scale.shape[0], ref_scale.shape[1], (ref_scale.shape[2] // 4), 4
) # [E, N, K//512, 4] (groups of 4 along the K//128 scale dimension)
scale_interleaved = scale_interleaved.permute(0, 2, 1, 3) # [E, K//512, N, 4]
scale_interleaved = scale_interleaved.reshape(
ref_scale.shape[0], ref_scale.shape[2] // 4, ref_scale.shape[1] * 4
) # [E, K//512, N*4]
w_scale = scale_interleaved.contiguous()
return w_q, w_scale
@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16])
def test_int4_fp8_grouped_gemm_single_expert(batch_size):
# Test parameters
num_experts = 1
m = batch_size # batch size
k = 512 # input dimension
n = 1024 # output dimension
torch.manual_seed(0)
dtype = torch.bfloat16
device = "cuda"
debug = False
print(f"\nTesting with batch_size={batch_size}")
# Create input tensors with ones
if debug:
a = torch.ones(m, k, dtype=torch.bfloat16, device=device)
ref_w = torch.ones(num_experts, n, k, dtype=torch.int8, device=device)
a_scale = torch.ones(1, dtype=torch.float, device=device)
ref_w_scale = torch.ones(num_experts, n, k // 128, dtype=dtype, device=device)
else:
a = torch.randn(m, k, dtype=dtype, device=device)
ref_w = torch.randint(
-8, 8, (num_experts, n, k), dtype=torch.int8, device=device
)
affine_coeff = 0.005
a_scale = torch.randn(1, dtype=torch.float32).cuda() * 0.02
ref_w_scale = (
torch.randn(num_experts, n, k // 128, dtype=dtype, device=device)
* affine_coeff
)
w, w_scale = pack_interleave(num_experts, ref_w, ref_w_scale)
# Create expert offsets and problem sizes
expert_offsets = torch.tensor([0, m], dtype=torch.int32, device=device)
problem_sizes = torch.tensor([[n, m, k]], dtype=torch.int32, device=device)
a_strides = torch.full((num_experts, 3), k, device=device, dtype=torch.int64)
c_strides = torch.full((num_experts, 3), n, device=device, dtype=torch.int64)
b_strides = a_strides
s_strides = c_strides
# Quantize input
a_q = torch.clamp((a / a_scale), -448.0, 448.0).to(torch.float8_e4m3fn).to(device)
# Create output tensor
c = torch.empty((m, n), dtype=torch.float16, device=device)
cutlass_w4a8_moe_mm(
c,
a_q,
w,
a_scale,
w_scale,
expert_offsets[:-1],
problem_sizes,
a_strides,
b_strides,
c_strides,
s_strides,
128,
8,
)
c = c.to(dtype)
# Reference implementation
experts_selection_result = torch.full((m,), 0)
c_ref = ref_grouped_gemm(
c, a, a_scale, ref_w, ref_w_scale, num_experts, experts_selection_result
)
# Compare results
try:
torch.testing.assert_close(c, c_ref, rtol=1e-2, atol=0.1)
except AssertionError as e:
# torch.set_printoptions(threshold=10_000)
print(f" FAILURE: tensors are NOT close.")
print(f" Ref tensor: {c_ref.flatten()}")
print(f" Cutlass tensor: {c.flatten()}")
print(
f" Max absolute difference: {torch.max(torch.abs(c.to(c_ref.dtype) - c_ref))}"
)
print(
f" Mean absolute difference: {torch.mean(torch.abs(c.to(c_ref.dtype) - c_ref))}"
)
print(f" AssertionError: {e}")
raise
@pytest.mark.parametrize("batch_size", [2, 4, 8, 16])
@pytest.mark.parametrize("k", [512, 1024])
@pytest.mark.parametrize("n", [1024, 2048])
@pytest.mark.parametrize("num_experts", [2, 4, 6, 8])
def test_int4_fp8_grouped_gemm_multi_experts(batch_size, k, n, num_experts):
torch.manual_seed(0)
dtype = torch.bfloat16
device = "cuda"
debug = False
print(
f"\nTesting with batch_size={batch_size}, k={k}, n={n}, num_experts={num_experts}"
)
if debug:
a = torch.ones(batch_size, k, dtype=torch.bfloat16, device=device)
ref_w = torch.ones(num_experts, n, k, dtype=torch.int8, device=device)
a_scale = torch.ones(1, dtype=torch.float, device=device)
ref_w_scale = torch.ones(num_experts, n, k // 128, dtype=dtype, device=device)
else:
a = torch.randn(batch_size, k, dtype=dtype, device=device)
ref_w = torch.randint(
-8, 8, (num_experts, n, k), dtype=torch.int8, device=device
)
affine_coeff = 0.005
a_scale = torch.randn(1, dtype=torch.float32).cuda() * 0.02
ref_w_scale = (
torch.randn(num_experts, n, k // 128, dtype=dtype, device=device)
* affine_coeff
)
w, w_scale = pack_interleave(num_experts, ref_w, ref_w_scale)
# random select experts
experts_selection_result = torch.randint(
0, num_experts, (batch_size,), device=device
)
permutation = torch.argsort(experts_selection_result)
expert_token_counts = torch.bincount(
experts_selection_result, minlength=num_experts
)
# Create problem sizes and offsets for active experts
problem_sizes = []
for i in range(num_experts):
problem_sizes.append([n, expert_token_counts[i].item(), k])
problem_sizes = torch.tensor(problem_sizes, dtype=torch.int32, device=device)
expert_offsets = []
offset = 0
for i in range(num_experts):
expert_offsets.append(offset)
offset += problem_sizes[i][1].item()
expert_offsets = torch.tensor(expert_offsets, dtype=torch.int32, device=device)
# Permute input and quantize
a_perm = a[permutation]
a_q_perm = (
torch.clamp((a_perm / a_scale), -448.0, 448.0)
.to(torch.float8_e4m3fn)
.to(device)
)
# Create stride tensors
a_strides = torch.full((num_experts, 3), k, device=device, dtype=torch.int64)
c_strides = torch.full((num_experts, 3), n, device=device, dtype=torch.int64)
b_strides = a_strides
s_strides = c_strides
c_perm = torch.empty((batch_size, n), dtype=torch.float16, device=device)
cutlass_w4a8_moe_mm(
c_perm,
a_q_perm,
w,
a_scale,
w_scale,
expert_offsets,
problem_sizes,
a_strides,
b_strides,
c_strides,
s_strides,
128,
8,
)
# Un-permute the result
c = torch.empty_like(c_perm)
c[permutation] = c_perm
c = c.to(dtype)
c_ref = ref_grouped_gemm(
c, a, a_scale, ref_w, ref_w_scale, num_experts, experts_selection_result
)
# Compare results
try:
torch.testing.assert_close(c, c_ref, rtol=1e-2, atol=0.1)
except AssertionError as e:
print(f" FAILURE: tensors are NOT close.")
print(
f" Max absolute difference: {torch.max(torch.abs(c.to(c_ref.dtype) - c_ref))}"
)
print(
f" Mean absolute difference: {torch.mean(torch.abs(c.to(c_ref.dtype) - c_ref))}"
)
print(f" AssertionError: {e}")
raise
def ref_grouped_gemm(c, a, a_scale, w, w_scale, num_experts, experts_selection_result):
dtype = torch.bfloat16
c_ref = torch.zeros_like(c)
a_q = torch.clamp((a / a_scale), -448.0, 448.0).to(torch.float8_e4m3fn)
for i in range(num_experts):
token_idx = torch.where(experts_selection_result == i)[0]
if len(token_idx) == 0:
continue
a = a_q[token_idx]
ref_w_scale_repeat = w_scale[i].repeat_interleave(128, dim=1).to(float)
ref_w = (w[i].to(float) * ref_w_scale_repeat).to(dtype)
c = torch.matmul(a.to(dtype), ref_w.t().to(dtype)) * a_scale
c = c.to(dtype)
c_ref[token_idx] = c.to(dtype)
return c_ref
if __name__ == "__main__":
pytest.main([__file__])