Commit 3fb4b5fa authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.18.0' into v0.18.0-ori

parents bcf25339 89138b21
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
// Adapted from SGLang:
// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled.cu
#include <torch/all.h>
#include "cutlass_mxfp8_grouped_mm_launcher.cuh"
// Host entry point for the MXFP8 block-scaled grouped GEMM.
//
// Validates the arguments and dispatches on the dtype of `d` (bfloat16 or
// float16) to the CUTLASS launcher, using the current CUDA stream. The body is
// only compiled when CUTLASS_ARCH_MMA_SM100_SUPPORTED is defined; otherwise
// every call raises.
//
//   a                  : 2D (num_tokens, k), innermost stride 1
//   b                  : 3D (num_experts, k, n), stride 1 along dim 1
//   sfa, sfb           : scale factors for a / b (forwarded without checks)
//   d                  : output tensor, dtype kBFloat16 or kFloat16
//   problem_sizes      : int32, shape (num_experts, 3)
//   expert_offsets     : int32, 1D, one entry per expert
//   blockscale_offsets : int32, 1D, one entry per expert
//
// Raises (via TORCH_CHECK) when any of the above constraints is violated.
void cutlass_mxfp8_grouped_mm(const torch::Tensor& a, const torch::Tensor& b,
                              const torch::Tensor& sfa,
                              const torch::Tensor& sfb, torch::Tensor& d,
                              const torch::Tensor& problem_sizes,
                              const torch::Tensor& expert_offsets,
                              const torch::Tensor& blockscale_offsets) {
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
  TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");
  TORCH_CHECK(problem_sizes.size(1) == 3,
              "problem_sizes must have shape (num_experts, 3)");
  const auto num_experts = problem_sizes.size(0);

  // The pre-compute kernel indexes both offset tensors once per expert, so
  // they must be 1D and hold exactly one entry per row of problem_sizes.
  TORCH_CHECK(expert_offsets.dim() == 1, "expert_offsets must be 1D tensor");
  TORCH_CHECK(num_experts == expert_offsets.size(0),
              "Number of experts in problem_sizes must match expert_offsets");
  TORCH_CHECK(
      blockscale_offsets.dim() == 1 &&
          blockscale_offsets.size(0) == num_experts,
      "blockscale_offsets must be 1D and have size equal to the number of "
      "experts");

  TORCH_CHECK(problem_sizes.dtype() == torch::kInt32,
              "problem_sizes must be int32");
  TORCH_CHECK(expert_offsets.dtype() == torch::kInt32,
              "expert_offsets must be int32");
  TORCH_CHECK(blockscale_offsets.dtype() == torch::kInt32,
              "blockscale_offsets must be int32");

  TORCH_CHECK(a.dim() == 2, "a must be a 2D tensor of shape (num_tokens, k)");
  TORCH_CHECK(b.dim() == 3,
              "b must be a 3D tensor of shape (num_experts, k, n)");
  // Per-expert B pointers are derived from the expert index, so b must hold
  // at least one matrix per expert.
  TORCH_CHECK(b.size(0) >= num_experts,
              "b must contain one (k, n) matrix per expert");
  TORCH_CHECK(a.size(1) == b.size(1) && a.size(1) % 128 == 0,
              "k should align 128");
  TORCH_CHECK(b.size(2) % 128 == 0, "n should align 128");
  TORCH_CHECK(a.strides()[1] == 1, "a must be row major");
  TORCH_CHECK(b.strides()[1] == 1, "b must be column major");

  auto stream = at::cuda::getCurrentCUDAStream();
  if (d.dtype() == torch::kBFloat16) {
    expert_specialization::cutlass_mxfp8_grouped_mm_dispatch_out_dtype<
        cutlass::bfloat16_t>(a, b, sfa, sfb, d, problem_sizes, expert_offsets,
                             blockscale_offsets, stream);
  } else if (d.dtype() == torch::kFloat16) {
    expert_specialization::cutlass_mxfp8_grouped_mm_dispatch_out_dtype<
        cutlass::half_t>(a, b, sfa, sfb, d, problem_sizes, expert_offsets,
                         blockscale_offsets, stream);
  } else {
    TORCH_CHECK(false, "dtype must be kFloat16 or kBFloat16");
  }
#else
  TORCH_CHECK(false,
              "No implemented cutlass_mxfp8_grouped_mm for "
              "current device");
#endif
}
#include "core/registration.h"
// Binds the host function cutlass_mxfp8_grouped_mm as the CUDA implementation
// of the op named "cutlass_mxfp8_grouped_mm" in the TORCH_EXTENSION_NAME
// library.
// NOTE(review): the op schema itself is presumably declared elsewhere (not
// visible in this file) -- confirm it matches this function's signature.
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("cutlass_mxfp8_grouped_mm", cutlass_mxfp8_grouped_mm);
}
\ No newline at end of file
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
// Adapted from SGLang:
// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled_functor.cuh
#pragma once
#include <cuda.h>
#include "cute/tensor.hpp"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass_mxfp8_grouped_mm_traits.cuh"
namespace expert_specialization {
using namespace cute;
// Device functor that fills, for one expert, the per-expert pointer tables
// (A, B, scale factors of A and B, and D) consumed by the grouped GEMM.
// Every pointer is the matching base pointer advanced by an element offset
// derived from the expert's token offset, block-scale offset, or index.
template <typename GemmTraits>
struct CutlassMxfp8GroupedMmOffsetFunctor {
  using Gemm = typename GemmTraits::Gemm;
  using ElementA = typename Gemm::ElementA;
  using ElementB = typename Gemm::ElementB;
  using ElementSF = typename GemmTraits::ElementSF;
  using ElementD = typename GemmTraits::ElementOutput;

  // Inputs: per-expert starting rows for tokens and for scale-factor rows.
  int* expert_offsets = nullptr;
  int* blockscale_offsets = nullptr;

  // Base pointers of the full tensors.
  ElementA* a_base = nullptr;
  ElementB* b_base = nullptr;
  ElementSF* sfa_base = nullptr;
  ElementSF* sfb_base = nullptr;
  ElementD* d_base = nullptr;

  // Outputs: one pointer per expert, written by operator().
  ElementA** a_offsets = nullptr;
  ElementB** b_offsets = nullptr;
  ElementSF** sfa_offsets = nullptr;
  ElementSF** sfb_offsets = nullptr;
  ElementD** d_offsets = nullptr;

  CutlassMxfp8GroupedMmOffsetFunctor() = default;

  CutlassMxfp8GroupedMmOffsetFunctor(
      int* _expert_offsets, int* _blockscale_offsets, ElementA* _a_base,
      ElementB* _b_base, ElementSF* _sfa_base, ElementSF* _sfb_base,
      ElementD* _d_base, ElementA** _a_offsets, ElementB** _b_offsets,
      ElementSF** _sfa_offsets, ElementSF** _sfb_offsets, ElementD** _d_offsets)
      : expert_offsets{_expert_offsets},
        blockscale_offsets{_blockscale_offsets},
        a_base{_a_base},
        b_base{_b_base},
        sfa_base{_sfa_base},
        sfb_base{_sfb_base},
        d_base{_d_base},
        a_offsets{_a_offsets},
        b_offsets{_b_offsets},
        sfa_offsets{_sfa_offsets},
        sfb_offsets{_sfb_offsets},
        d_offsets{_d_offsets} {}

  // Writes the five per-expert pointers for `expert_id`. `m` is unused; all
  // offsets are computed in 64-bit arithmetic.
  void CUTE_DEVICE operator()(int64_t expert_id, int m, int n, int k) {
    const int64_t token_start = expert_offsets[expert_id];
    const int64_t scale_row_start = blockscale_offsets[expert_id];
    // One scale factor covers 32 consecutive elements along k.
    const int64_t scales_per_row = k / 32;

    a_offsets[expert_id] = a_base + token_start * k;
    b_offsets[expert_id] = b_base + expert_id * k * n;
    sfa_offsets[expert_id] = sfa_base + scale_row_start * scales_per_row;
    sfb_offsets[expert_id] = sfb_base + expert_id * n * scales_per_row;
    d_offsets[expert_id] = d_base + token_start * n;
  }
};
// Device functor that stores, for one expert, the scale-factor layouts of A
// and B produced by Sm1xxBlkScaledConfig for the problem shape (m, n, k, 1).
template <typename GemmTraits>
struct CutlassMxfp8GroupedMmLayoutFunctor {
  using Sm1xxBlkScaledConfig = typename GemmTraits::Sm1xxBlkScaledConfig;
  using LayoutSFA = typename GemmTraits::LayoutSFA;
  using LayoutSFB = typename GemmTraits::LayoutSFB;

  // Output arrays holding one layout object per expert.
  LayoutSFA* layout_sfa_base = nullptr;
  LayoutSFB* layout_sfb_base = nullptr;

  CutlassMxfp8GroupedMmLayoutFunctor() = default;

  CutlassMxfp8GroupedMmLayoutFunctor(LayoutSFA* _layout_sfa_base,
                                     LayoutSFB* _layout_sfb_base)
      : layout_sfa_base{_layout_sfa_base}, layout_sfb_base{_layout_sfb_base} {}

  // Computes both layouts for `expert_id` and writes them to slot expert_id.
  void CUTE_DEVICE operator()(int64_t expert_id, int m, int n, int k) {
    auto problem_shape = cute::make_shape(m, n, k, 1);
    layout_sfa_base[expert_id] =
        Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(problem_shape);
    layout_sfb_base[expert_id] =
        Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(problem_shape);
  }
};
// Device functor that stores, for one expert, the packed strides of A, B and
// D derived from the problem extents (m, n, k).
template <typename GemmTraits>
struct CutlassMxfp8GroupedMmStrideFunctor {
  using StrideA = typename GemmTraits::StrideA;
  using StrideB = typename GemmTraits::StrideB;
  using StrideD = typename GemmTraits::StrideD;

  // Output arrays holding one stride object per expert.
  StrideA* stride_A_base = nullptr;
  StrideB* stride_B_base = nullptr;
  StrideD* stride_D_base = nullptr;

  CutlassMxfp8GroupedMmStrideFunctor() = default;

  CutlassMxfp8GroupedMmStrideFunctor(StrideA* _stride_A_base,
                                     StrideB* _stride_B_base,
                                     StrideD* _stride_D_base)
      : stride_A_base{_stride_A_base},
        stride_B_base{_stride_B_base},
        stride_D_base{_stride_D_base} {}

  // Writes the packed strides for shapes (m, k, 1), (n, k, 1) and (m, n, 1)
  // into slot `expert_id` of the A, B and D stride arrays respectively.
  void CUTE_DEVICE operator()(int64_t expert_id, int m, int n, int k) {
    stride_A_base[expert_id] =
        cutlass::make_cute_packed_stride(StrideA{}, {m, k, 1});
    stride_B_base[expert_id] =
        cutlass::make_cute_packed_stride(StrideB{}, {n, k, 1});
    stride_D_base[expert_id] =
        cutlass::make_cute_packed_stride(StrideD{}, {m, n, 1});
  }
};
// Pre-compute kernel: thread `threadIdx.x` handles expert `threadIdx.x`.
// It reads that expert's (m, n, k) triple from `problem_sizes` and runs the
// three functors to fill the pointer, scale-factor-layout and stride tables.
// There is no bounds guard, so the launch must use exactly one thread per
// expert in a single block.
template <typename OffsetFunctor, typename LayoutFunctor,
          typename StrideFunctor>
__global__ void cutlassMxfp8GroupedMmPreComputeKernel(
    int* problem_sizes, OffsetFunctor offset_functor,
    LayoutFunctor layout_functor, StrideFunctor stride_functor) {
  const int64_t expert_id = static_cast<int64_t>(threadIdx.x);

  // Each expert owns three consecutive ints: m, n, k.
  const int* mnk = problem_sizes + expert_id * 3;
  const int m = mnk[0];
  const int n = mnk[1];
  const int k = mnk[2];

  offset_functor(expert_id, m, n, k);
  layout_functor(expert_id, m, n, k);
  stride_functor(expert_id, m, n, k);
}
} // namespace expert_specialization
\ No newline at end of file
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
// Adapted from SGLang:
// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled_launcher.cuh
#pragma once
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include <cassert>
#include <iostream>
#include <string>
#include "cute/tensor.hpp"
#include "cutlass_mxfp8_grouped_mm_functor.cuh"
#include "cutlass_mxfp8_grouped_mm_traits.cuh"
namespace expert_specialization {
// Launches the pre-compute kernel that fills the per-expert argument tables
// of the grouped GEMM on `stream`.
//
// Outputs (device buffers, written by the kernel):
//   a_ptrs, b_ptrs, sfa_ptrs, sfb_ptrs, d_ptrs : per-expert pointer tables
//   stride_a, stride_b, stride_d               : per-expert packed strides
//   layout_sfa, layout_sfb                     : per-expert scale-factor
//                                                layouts
// Inputs:
//   a, b, sfa, sfb, d  : tensors whose data pointers are the table bases
//   problem_sizes      : int32 (m, n, k) triples, one per expert
//   expert_offsets     : int32 per-expert token offsets
//   blockscale_offsets : int32 per-expert scale-factor row offsets
//
// The kernel runs as a single block with one thread per expert, hence the
// 1024-expert limit. Raises if that limit is exceeded or the launch fails.
template <typename GemmTraits>
void cutlass_mxfp8_grouped_mm_pre_compute(
    torch::Tensor& a_ptrs, torch::Tensor& b_ptrs, torch::Tensor& sfa_ptrs,
    torch::Tensor& sfb_ptrs, torch::Tensor& d_ptrs, torch::Tensor& stride_a,
    torch::Tensor& stride_b, torch::Tensor& stride_d, torch::Tensor& layout_sfa,
    torch::Tensor& layout_sfb, const torch::Tensor& a, const torch::Tensor& b,
    const torch::Tensor& sfa, const torch::Tensor& sfb, const torch::Tensor& d,
    const torch::Tensor& problem_sizes, const torch::Tensor& expert_offsets,
    const torch::Tensor& blockscale_offsets, cudaStream_t stream) {
  using OffsetFunctor = CutlassMxfp8GroupedMmOffsetFunctor<GemmTraits>;
  using ElementA = typename OffsetFunctor::ElementA;
  using ElementB = typename OffsetFunctor::ElementB;
  using ElementSF = typename OffsetFunctor::ElementSF;
  using ElementD = typename OffsetFunctor::ElementD;
  using LayoutFunctor = CutlassMxfp8GroupedMmLayoutFunctor<GemmTraits>;
  using LayoutSFA = typename LayoutFunctor::LayoutSFA;
  using LayoutSFB = typename LayoutFunctor::LayoutSFB;
  using StrideFunctor = CutlassMxfp8GroupedMmStrideFunctor<GemmTraits>;
  using StrideA = typename StrideFunctor::StrideA;
  using StrideB = typename StrideFunctor::StrideB;
  using StrideD = typename StrideFunctor::StrideD;

  int num_experts = (int)expert_offsets.size(0);
  TORCH_CHECK(num_experts <= 1024,
              "Number of experts cannot exceed 1024, the maximum number of "
              "threads per block.");
  // A zero-thread block is an invalid launch configuration; with no experts
  // there is nothing to pre-compute.
  if (num_experts == 0) {
    return;
  }

  OffsetFunctor offset_functor(
      reinterpret_cast<int*>(expert_offsets.data_ptr()),
      reinterpret_cast<int*>(blockscale_offsets.data_ptr()),
      reinterpret_cast<ElementA*>(a.data_ptr()),
      reinterpret_cast<ElementB*>(b.data_ptr()),
      reinterpret_cast<ElementSF*>(sfa.data_ptr()),
      reinterpret_cast<ElementSF*>(sfb.data_ptr()),
      reinterpret_cast<ElementD*>(d.data_ptr()),
      reinterpret_cast<ElementA**>(a_ptrs.data_ptr()),
      reinterpret_cast<ElementB**>(b_ptrs.data_ptr()),
      reinterpret_cast<ElementSF**>(sfa_ptrs.data_ptr()),
      reinterpret_cast<ElementSF**>(sfb_ptrs.data_ptr()),
      reinterpret_cast<ElementD**>(d_ptrs.data_ptr()));
  LayoutFunctor layout_functor(
      reinterpret_cast<LayoutSFA*>(layout_sfa.data_ptr()),
      reinterpret_cast<LayoutSFB*>(layout_sfb.data_ptr()));
  StrideFunctor stride_functor(reinterpret_cast<StrideA*>(stride_a.data_ptr()),
                               reinterpret_cast<StrideB*>(stride_b.data_ptr()),
                               reinterpret_cast<StrideD*>(stride_d.data_ptr()));

  cutlassMxfp8GroupedMmPreComputeKernel<<<1, num_experts, 0, stream>>>(
      static_cast<int*>(problem_sizes.data_ptr()), offset_functor,
      layout_functor, stride_functor);
  // Kernel launches do not report errors directly; surface launch
  // configuration failures here instead of at some unrelated later call.
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
// Builds the CUTLASS grouped-GEMM arguments from the pre-computed per-expert
// tables and runs the kernel on `stream`.
//
// All *_ptrs / stride_* / layout_* tensors are the device buffers filled by
// the pre-compute step; `problem_sizes` is reinterpreted as an array of
// per-expert problem shapes. A uint8 workspace of the size CUTLASS requests
// is allocated on the device of `d_ptrs`. Raises if CUTLASS rejects the
// arguments or if initialization or the run fails.
template <typename GemmTraits>
void cutlass_mxfp8_grouped_mm(
    const torch::Tensor& a_ptrs, const torch::Tensor& b_ptrs,
    const torch::Tensor& sfa_ptrs, const torch::Tensor& sfb_ptrs,
    const torch::Tensor& d_ptrs, const torch::Tensor& stride_a,
    const torch::Tensor& stride_b, const torch::Tensor& stride_d,
    const torch::Tensor& layout_sfa, const torch::Tensor& layout_sfb,
    const torch::Tensor& problem_sizes, cudaStream_t stream) {
  using Gemm = typename GemmTraits::Gemm;
  using ElementA = typename Gemm::ElementA;
  using ElementB = typename Gemm::ElementB;
  using ElementSF = typename GemmTraits::ElementSF;
  using ElementD = typename GemmTraits::ElementOutput;
  using StrideA = typename GemmTraits::StrideA;
  using StrideB = typename GemmTraits::StrideB;
  using StrideD = typename GemmTraits::StrideD;
  using LayoutSFA = typename GemmTraits::LayoutSFA;
  using LayoutSFB = typename GemmTraits::LayoutSFB;
  using UnderlyingProblemShape =
      typename GemmTraits::ProblemShape::UnderlyingProblemShape;

  // Describe the target device and the preferred / fallback cluster shapes.
  cutlass::KernelHardwareInfo hw_info;
  hw_info.device_id = c10::cuda::current_device();
  hw_info.sm_count =
      at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
  hw_info.cluster_shape = GemmTraits::MMAConfig::preferred_cluster;
  hw_info.cluster_shape_fallback = GemmTraits::MMAConfig::fallback_cluster;

  // Typed views of the pre-computed device tables.
  int num_experts = (int)problem_sizes.size(0);
  auto* problem_shapes =
      reinterpret_cast<UnderlyingProblemShape*>(problem_sizes.data_ptr());
  auto a_table = reinterpret_cast<const ElementA**>(a_ptrs.data_ptr());
  auto a_strides = reinterpret_cast<StrideA*>(stride_a.data_ptr());
  auto b_table = reinterpret_cast<const ElementB**>(b_ptrs.data_ptr());
  auto b_strides = reinterpret_cast<StrideB*>(stride_b.data_ptr());
  auto sfa_table = reinterpret_cast<const ElementSF**>(sfa_ptrs.data_ptr());
  auto sfa_layouts = reinterpret_cast<LayoutSFA*>(layout_sfa.data_ptr());
  auto sfb_table = reinterpret_cast<const ElementSF**>(sfb_ptrs.data_ptr());
  auto sfb_layouts = reinterpret_cast<LayoutSFB*>(layout_sfb.data_ptr());
  auto d_table = reinterpret_cast<ElementD**>(d_ptrs.data_ptr());
  auto d_strides = reinterpret_cast<StrideD*>(stride_d.data_ptr());

  typename Gemm::Arguments arguments = {
      cutlass::gemm::GemmUniversalMode::kGrouped,
      // Problem shapes: device array only, no host copy.
      {num_experts, problem_shapes, nullptr},
      // Mainloop: A, B and their scale factors.
      {a_table, a_strides, b_table, b_strides, sfa_table, sfa_layouts,
       sfb_table, sfb_layouts},
      // Epilogue: default fusion arguments, no C operand, D output.
      {{}, nullptr, nullptr, d_table, d_strides},
      hw_info,
      {}  // Scheduler
  };

  Gemm gemm;
  TORCH_CHECK(gemm.can_implement(arguments) == cutlass::Status::kSuccess,
              "Failed to implement GEMM");

  const auto workspace_options =
      torch::TensorOptions().dtype(torch::kUInt8).device(d_ptrs.device());
  size_t workspace_size = gemm.get_workspace_size(arguments);
  torch::Tensor workspace = torch::empty(workspace_size, workspace_options);

  auto status = gemm.initialize(arguments, workspace.data_ptr(), stream);
  TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM");

  status = gemm.run(stream, nullptr, true);  // Enable PDL
  TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
}
// Runs the grouped GEMM for output element type `OutType`.
//
// Allocates the per-expert scratch tables on the device of `a` (int64 tables
// for pointers and strides, int32 (num_experts, 5) tables for scale-factor
// layouts), fills them with the pre-compute kernel, then launches the CUTLASS
// grouped GEMM. Both steps are enqueued on `stream`.
template <typename OutType>
void cutlass_mxfp8_grouped_mm_dispatch_out_dtype(
    const torch::Tensor& a, const torch::Tensor& b, const torch::Tensor& sfa,
    const torch::Tensor& sfb, torch::Tensor& d,
    const torch::Tensor& problem_sizes, const torch::Tensor& expert_offsets,
    const torch::Tensor& blockscale_offsets, cudaStream_t stream) {
  using GemmTraits = CutlassMxfp8GroupedMmGemmTraits<MMA1SMConfig, OutType>;

  int num_experts = (int)problem_sizes.size(0);
  const auto int64_options =
      torch::TensorOptions().dtype(torch::kInt64).device(a.device());
  const auto int32_options =
      torch::TensorOptions().dtype(torch::kInt32).device(a.device());

  // One 64-bit slot per expert (pointer or packed stride).
  auto make_per_expert_table = [&]() {
    return torch::empty(num_experts, int64_options);
  };
  // Five 32-bit slots per expert (scale-factor layout).
  auto make_layout_table = [&]() {
    return torch::empty({num_experts, 5}, int32_options);
  };

  torch::Tensor a_ptrs = make_per_expert_table();
  torch::Tensor b_ptrs = make_per_expert_table();
  torch::Tensor sfa_ptrs = make_per_expert_table();
  torch::Tensor sfb_ptrs = make_per_expert_table();
  torch::Tensor d_ptrs = make_per_expert_table();
  torch::Tensor stride_a = make_per_expert_table();
  torch::Tensor stride_b = make_per_expert_table();
  torch::Tensor stride_d = make_per_expert_table();
  torch::Tensor layout_sfa = make_layout_table();
  torch::Tensor layout_sfb = make_layout_table();

  cutlass_mxfp8_grouped_mm_pre_compute<GemmTraits>(
      a_ptrs, b_ptrs, sfa_ptrs, sfb_ptrs, d_ptrs, stride_a, stride_b, stride_d,
      layout_sfa, layout_sfb, a, b, sfa, sfb, d, problem_sizes, expert_offsets,
      blockscale_offsets, stream);

  cutlass_mxfp8_grouped_mm<GemmTraits>(
      a_ptrs, b_ptrs, sfa_ptrs, sfb_ptrs, d_ptrs, stride_a, stride_b, stride_d,
      layout_sfa, layout_sfb, problem_sizes, stream);
}
} // namespace expert_specialization
\ No newline at end of file
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
// Adapted from SGLang:
// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled_traits.cuh
#pragma once
// Misc
#include "cute/tensor.hpp"
#include "cutlass/arch/arch.h"
#include "cutlass/arch/mma.h"
#include "cutlass/cutlass.h"
#include "cutlass/detail/sm100_blockscaled_layout.hpp"
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/layout/layout.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/numeric_size.h"
// Collective Builder
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp"
#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
// Integration
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
namespace expert_specialization {
using namespace cute;
// Different configs for 1SM and 2SM MMA kernel
// Kernel configuration for the 1-SM MMA variant: tile shape, mainloop and
// epilogue schedules, and the preferred / fallback cluster shapes handed to
// cutlass::KernelHardwareInfo by the launcher.
struct MMA1SMConfig {
  using MmaTileShape = Shape<_128, _128, _128>;
  using KernelSchedule =
      cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmMxf8f6f4Sm100;
  using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm;
  // Defined as C++17 inline variables: this is a header, and out-of-class
  // non-inline definitions would be emitted by every translation unit that
  // includes it, producing multiple-definition link errors.
  inline static const dim3 preferred_cluster = dim3(1, 4, 1);
  inline static const dim3 fallback_cluster = dim3(1, 2, 1);
};
// Bundles every CUTLASS type needed to instantiate the MXFP8 block-scaled
// grouped GEMM for one MMA configuration and one output element type.
//
// Template parameters:
//   _MMAConfig  : supplies MmaTileShape, KernelSchedule and EpilogueSchedule
//                 (e.g. MMA1SMConfig)
//   OutputDtype : element type of the output matrix D
template <typename _MMAConfig, typename OutputDtype>
struct CutlassMxfp8GroupedMmGemmTraits {
using MMAConfig = _MMAConfig;
// A and B both hold e4m3 values; the output element type is caller-chosen.
using ElementInput = cutlass::float_e4m3_t;
using ElementOutput = OutputDtype;
// Grouped problem: one (int, int, int) shape per group.
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int, int, int>>;
// A matrix configuration: MX-wrapped e4m3, row major.
using ElementA = cutlass::mx_float8_t<ElementInput>;
using LayoutA = cutlass::layout::RowMajor;
constexpr static int AlignmentA = 32;
// B matrix configuration: MX-wrapped e4m3, column major.
using ElementB = cutlass::mx_float8_t<ElementInput>;
using LayoutB = cutlass::layout::ColumnMajor;
constexpr static int AlignmentB = 32;
// C/D matrix configuration. ElementC is void -- presumably meaning there is
// no source C operand (the launcher passes nullptr for it); TODO confirm.
using ElementC = void;
using ElementD = ElementOutput;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;
// Alignments are expressed in elements: the number of ElementD values that
// fit in 128 bits.
constexpr static int AlignmentC = 128 / cutlass::sizeof_bits<ElementD>::value;
constexpr static int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
using ElementAccumulator = float;
static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
// Epilogue visitor tree: fetch the accumulator and apply the Identity
// activation, producing ElementD from ElementAccumulator with RoundStyle.
using CustomEVTIdentity = // acc
cutlass::epilogue::fusion::Sm90EVT<
cutlass::epilogue::fusion::Sm90Compute<
cutlass::epilogue::thread::Identity, ElementD, ElementAccumulator,
RoundStyle>,
cutlass::epilogue::fusion::Sm90AccFetch>;
// Core kernel configurations
using ArchTag = cutlass::arch::Sm100;
using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp;
using StageCountType = cutlass::gemm::collective::StageCountAuto;
// Runtime Cluster Shape: the first two modes are dynamic int32 values, the
// third is statically 1.
using ClusterShape = Shape<int32_t, int32_t, _1>;
// Define Epilogue. The pointer layout types (LayoutC*, LayoutD*) are what the
// builder is given for this pointer-array (grouped) kernel.
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass, typename MMAConfig::MmaTileShape,
ClusterShape, Shape<_64, _64>, ElementAccumulator, ElementAccumulator,
ElementC, LayoutC*, AlignmentC, ElementD, LayoutD*, AlignmentD,
typename MMAConfig::EpilogueSchedule,
CustomEVTIdentity>::CollectiveOp;
// Define Mainloop. The stage count is chosen automatically after carving out
// the bytes of the epilogue's SharedStorage.
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass, ElementA, LayoutA*, AlignmentA, ElementB,
LayoutB*, AlignmentB, ElementAccumulator,
typename MMAConfig::MmaTileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
typename MMAConfig::KernelSchedule>::CollectiveOp;
// Define GemmKernel
using GemmKernel =
cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop,
CollectiveEpilogue>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
// Types re-exported from the instantiated kernel; the functors and the
// launcher use them to type the per-expert device tables.
using ElementSF = typename Gemm::GemmKernel::ElementSF;
using StrideA = typename Gemm::GemmKernel::InternalStrideA;
using StrideB = typename Gemm::GemmKernel::InternalStrideB;
using StrideC = typename Gemm::GemmKernel::InternalStrideC;
using StrideD = typename Gemm::GemmKernel::InternalStrideD;
using LayoutSFA =
typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFA;
using LayoutSFB =
typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFB;
using Sm1xxBlkScaledConfig =
typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
};
} // namespace expert_specialization
\ No newline at end of file
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
// Adapted from SGLang:
// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled_group_quant.cu
#include <torch/all.h>
#include "mxfp8_experts_quant.cuh"
// Host entry point for per-expert MXFP8 quantization.
//
// Validates the arguments and dispatches on the dtype of `input` (bfloat16 or
// float16) to the quantization launcher. The body is only compiled when
// CUTLASS_ARCH_MMA_SM100_SUPPORTED is defined; otherwise every call raises.
//
//   input              : 2D, innermost stride 1, size(1) a multiple of 128
//   problem_sizes      : int32, shape (groups, 3)
//   expert_offsets     : int32, 1D, one entry per group
//   blockscale_offsets : int32, 1D, one entry per group
//   quant_output       : output buffer for the quantized values
//   scale_factor       : output buffer for the scale factors
//
// Raises (via TORCH_CHECK) when any of the checked constraints is violated.
void mxfp8_experts_quant(const torch::Tensor& input,
                         const torch::Tensor& problem_sizes,
                         const torch::Tensor& expert_offsets,
                         const torch::Tensor& blockscale_offsets,
                         torch::Tensor& quant_output,
                         torch::Tensor& scale_factor) {
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
  TORCH_CHECK(input.dim() == 2, "input must be 2D tensor");
  TORCH_CHECK(input.size(1) % 128 == 0, "k must align to 128");
  TORCH_CHECK(input.strides()[1] == 1, "input must be row major");
  TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");
  // The kernel reads three consecutive ints per group, so any other row width
  // would make it read the wrong values or run out of bounds.
  TORCH_CHECK(problem_sizes.size(1) == 3,
              "problem_sizes must have shape (num_experts, 3)");
  TORCH_CHECK(problem_sizes.dtype() == torch::kInt32,
              "problem_sizes must be int32");
  TORCH_CHECK(expert_offsets.dtype() == torch::kInt32,
              "expert_offsets must be int32");
  TORCH_CHECK(blockscale_offsets.dtype() == torch::kInt32,
              "blockscale_offsets must be int32");
  auto groups = problem_sizes.size(0);
  TORCH_CHECK(
      expert_offsets.dim() == 1 && expert_offsets.size(0) == groups,
      "expert_offsets must be 1D and have size equal to the number of groups");
  TORCH_CHECK(
      blockscale_offsets.dim() == 1 && blockscale_offsets.size(0) == groups,
      "blockscale_offsets must be 1D and have size equal to the number of "
      "groups");
  if (input.dtype() == torch::kBFloat16) {
    expert_specialization::launch_mxfp8_experts_quant<__nv_bfloat16>(
        input, problem_sizes, expert_offsets, blockscale_offsets, quant_output,
        scale_factor);
  } else if (input.dtype() == torch::kFloat16) {
    expert_specialization::launch_mxfp8_experts_quant<__half>(
        input, problem_sizes, expert_offsets, blockscale_offsets, quant_output,
        scale_factor);
  } else {
    TORCH_CHECK(false, "dtype must be kFloat16 or kBFloat16");
  }
#else
  TORCH_CHECK(false,
              "No implemented mxfp8_experts_quant for "
              "current device");
#endif
}
#include "core/registration.h"
// Binds the host function mxfp8_experts_quant as the CUDA implementation of
// the op named "mxfp8_experts_quant" in the TORCH_EXTENSION_NAME library.
// NOTE(review): the op schema itself is presumably declared elsewhere (not
// visible in this file) -- confirm it matches this function's signature.
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("mxfp8_experts_quant", mxfp8_experts_quant);
}
\ No newline at end of file
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
// Adapted from SGLang:
// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled_group_quant.cuh
#pragma once
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <torch/all.h>
#include <cuda/ptx>
#include "cute/tensor.hpp"
namespace expert_specialization {
using namespace cute;
// Launch / tiling constants for the quantization kernel.
// NOTE(review): the launch site is not visible in this chunk -- presumably the
// kernel runs with THREAD_BLOCK_SIZE threads per block; confirm.
constexpr uint32_t THREAD_BLOCK_SIZE = 128;
constexpr uint32_t WARP_SIZE = 32;
// Presumably the extent of one quantization tile (the tile routine
// static-asserts 128 x 128 tensors); TODO confirm where these are used.
constexpr int BLOCK_M = 128;
constexpr int BLOCK_K = 128;
// Thread arrangement 16 x 8 (128 threads), with 1 x 16 values per thread.
// NOTE(review): presumably combined into the 16 x 128 copy tiler that the
// tile routine static-asserts on; the TiledCopy construction is not visible.
using ThrLayout = Layout<Shape<_16, _8>, Stride<_8, _1>>;
using ValLayout = Layout<Shape<_1, _16>>;
// Thread arrangement 16 x 4 (64 threads), one value each, for the
// register-to-shared copy of scale factors.
using SfR2SThrLayout = Layout<Shape<_16, _4>, Stride<_4, _1>>;
using SfR2SValLayout = Layout<Shape<_1, _1>>;
// Layout of one scale-factor tile: ((32, 4), 4) with strides ((16, 4), 1),
// i.e. 512 elements in total.
using ScaleFactorTileLayout =
Layout<Shape<Shape<_32, _4>, _4>, Stride<Stride<_16, _4>, _1>>;
// Fast reciprocal.
// Returns an approximate 1 / a computed by the PTX instruction
// `rcp.approx.ftz.f32` (approximate reciprocal, flush-to-zero variant).
inline __device__ float reciprocal_approximate_ftz(float a) {
  float inverse;
  asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(inverse) : "f"(a));
  return inverse;
}
// Some code references TRT-LLM:
// https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/quantization.cuh
// Quantizes this thread's 16 half/bfloat16 values to e4m3 and returns the
// shared e8m0 scale-factor byte.
//
// The absolute maximum is taken over this thread's 16 values and the 16
// values of the lane whose id differs in bit 0 (xor-1 shuffle), so two
// adjacent lanes form one 32-element scaling block and return the same byte.
// The scale is max / 448 rounded towards +inf to e8m0; each value is
// multiplied by the reciprocal of that scale (or by 0 when the block is all
// zeros) before conversion.
//
//   fragment_s : register fragment with a static layout of 16 source values
//   fragment_d : register fragment with a static layout of 16
//                cutlass::float_e4m3_t outputs, overwritten
//
// Uses a full-mask warp shuffle, so all 32 lanes of the warp must call it.
template <typename FragmentS, typename FragmentD>
__inline__ __device__ uint8_t cvt_warp_fp16_to_mxfp8(FragmentS& fragment_s,
                                                     FragmentD& fragment_d) {
  using SrcLayout = typename FragmentS::layout_type;
  using DstLayout = typename FragmentD::layout_type;
  static_assert(is_static<SrcLayout>::value && size(SrcLayout{}) == 16);
  static_assert(is_static<DstLayout>::value && size(DstLayout{}) == 16);

  constexpr int kElems = 16;
  constexpr int kPairs = kElems / 2;

  using ValType = typename FragmentS::element_type;
  using VecType = std::conditional_t<std::is_same_v<ValType, __nv_bfloat16>,
                                     __nv_bfloat162, __half2>;

  // Pack consecutive source values into 2-wide vectors.
  VecType packed[kPairs];
#pragma unroll
  for (int p = 0; p < kPairs; ++p) {
    packed[p].x = fragment_s(2 * p);
    packed[p].y = fragment_s(2 * p + 1);
  }

  // Per-thread absolute maximum, kept 2-wide.
  auto pair_max = __habs2(packed[0]);
  for (int p = 1; p < kPairs; ++p) {
    pair_max = __hmax2(__habs2(packed[p]), pair_max);
  }
  // Combine with the neighbouring lane that holds the other half of the block.
  pair_max = __hmax2(__shfl_xor_sync(0xffffffffu, pair_max, 1), pair_max);

  // Collapse the two halves into the final absolute maximum as a float.
  float block_max(0.0f);
  if constexpr (std::is_same_v<ValType, __nv_bfloat16>) {
    block_max = __bfloat162float(__hmax(pair_max.x, pair_max.y));
  } else {
    block_max = __half2float(__hmax(pair_max.x, pair_max.y));
  }

  // Scale factor = block max / largest e4m3 magnitude, rounded up to e8m0.
  float scale = block_max * reciprocal_approximate_ftz(448.0f);
  __nv_fp8_e8m0 scale_e8m0;
  scale_e8m0.__x =
      __nv_cvt_float_to_e8m0(scale, __NV_SATFINITE, cudaRoundPosInf);
  scale = static_cast<float>(scale_e8m0);
  const uint8_t scale_byte = scale_e8m0.__x;

  // Multiplier applied to the inputs; an all-zero block maps to zeros.
  const float inv_scale =
      block_max != 0.f ? reciprocal_approximate_ftz(scale) : 0.0f;

  // Widen to float and apply the multiplier.
  float2 scaled[kPairs];
#pragma unroll
  for (int p = 0; p < kPairs; ++p) {
    if constexpr (std::is_same_v<ValType, __half>) {
      scaled[p] = __half22float2(packed[p]);
    } else {
      scaled[p] = __bfloat1622float2(packed[p]);
    }
    scaled[p].x *= inv_scale;
    scaled[p].y *= inv_scale;
  }

  // Convert pairs to e4m3 and view the result as raw bytes.
  union {
    uint8_t bytes[16];
    __nv_fp8x2_e4m3 elts[8];
  } u;
#pragma unroll
  for (int p = 0; p < kPairs; ++p) {
    u.elts[p] = __nv_fp8x2_e4m3(scaled[p]);
  }
#pragma unroll
  for (int e = 0; e < kElems; ++e) {
    fragment_d(e) = cutlass::float_e4m3_t::bitcast(u.bytes[e]);
  }

  return scale_byte;
}
// Quantizes one 128 x 128 tile and emits its 128 x 4 scale factors.
//
//   tensor_s         : source tile (128 x 128, unit stride along dim 1)
//   tensor_p         : predicate tile (128 x 128) guarding the copies
//   tensor_d         : destination tile for the e4m3 values (128 x 128)
//   tensor_shared_sf : shared-memory staging tile for scale factors (128 x 4)
//   tensor_sf        : global destination of the scale factors (128 x 4)
//   m                : number of rows to process; 16-row sub-tiles that start
//                      at or beyond m are skipped
//   tiled_copy_*     : global->register, register->global and
//                      register->shared tiled copies
//
// The tile is walked in 16-row sub-tiles. After the loop the whole 512-byte
// shared staging buffer is written to global memory by thread 0 with one
// bulk async copy. Contains block-wide barriers, so every thread of the block
// must call it with the same `m`.
template <typename TensorS, typename TensorP, typename TensorD,
typename TensorSharedSF, typename TensorSF, typename TiledCopyG2R,
typename TiledCopyR2G, typename TiledCopyR2S>
__inline__ __device__ void mxfp8_experts_quant_tile(
TensorS& tensor_s, TensorP& tensor_p, TensorD& tensor_d,
TensorSharedSF& tensor_shared_sf, TensorSF& tensor_sf, int m,
TiledCopyG2R& tiled_copy_g2r, TiledCopyR2G& tiled_copy_r2g,
TiledCopyR2S& tiled_copy_r2s) {
// Compile-time shape contract for the five tile tensors.
static_assert(size(get<0>(typename TensorS::layout_type{})) == 128 &&
size(get<1>(typename TensorS::layout_type{})) == 128 &&
stride(get<1>(typename TensorS::layout_type{})) == 1);
static_assert(size(get<0>(typename TensorD::layout_type{})) == 128 &&
size(get<1>(typename TensorD::layout_type{})) == 128 &&
stride(get<1>(typename TensorD::layout_type{})) == 1);
static_assert(size(get<0>(typename TensorP::layout_type{})) == 128 &&
size(get<1>(typename TensorP::layout_type{})) == 128);
static_assert(size(get<0>(typename TensorSharedSF::layout_type{})) == 128 &&
size(get<1>(typename TensorSharedSF::layout_type{})) == 4);
static_assert(size(get<0>(typename TensorSF::layout_type{})) == 128 &&
size(get<1>(typename TensorSF::layout_type{})) == 4);
// Split the data tiles into 16 x 128 sub-tiles (the copy's tiler).
using Tiler_MN = typename TiledCopyG2R::Tiler_MN;
auto tiler_mn = Tiler_MN{};
static_assert(size<0>(tiler_mn) == 16 && size<1>(tiler_mn) == 128);
auto tiled_tensor_s = tiled_divide(tensor_s, tiler_mn);
auto tiled_tensor_p = tiled_divide(tensor_p, tiler_mn);
auto tiled_tensor_d = tiled_divide(tensor_d, tiler_mn);
// Exactly one sub-tile along the column mode, so that mode is dropped below.
static_assert(size<2>(tiled_tensor_s) == 1);
static_assert(size<2>(tiled_tensor_p) == 1);
static_assert(size<2>(tiled_tensor_d) == 1);
auto squeeze_tiled_tensor_s = take<0, 2>(tiled_tensor_s);
auto squeeze_tiled_tensor_p = take<0, 2>(tiled_tensor_p);
auto squeeze_tiled_tensor_d = take<0, 2>(tiled_tensor_d);
// Same split for the scale-factor tiles, in 16 x 4 sub-tiles.
using SF_Tiler_MN = typename TiledCopyR2S::Tiler_MN;
auto sf_tiler_mn = SF_Tiler_MN{};
static_assert(size<0>(sf_tiler_mn) == 16 && size<1>(sf_tiler_mn) == 4);
auto tiled_tensor_sf = tiled_divide(tensor_sf, sf_tiler_mn);
auto tiled_tensor_shared_sf = tiled_divide(tensor_shared_sf, sf_tiler_mn);
auto squeeze_tiled_tensor_sf = take<0, 2>(tiled_tensor_sf);
auto squeeze_tiled_tensor_shared_sf = take<0, 2>(tiled_tensor_shared_sf);
constexpr int tile_loop_count = size<1>(tiled_tensor_s);
constexpr int rows_in_tile = 16;
// We don't need to clear shared memory
// clear(squeeze_tiled_tensor_shared_sf);
#pragma unroll 4
for (int t = 0; t < tile_loop_count; t++) {
// Stop once the sub-tile starts past the valid rows. `m` is the same for all
// threads of the block, so the barriers below are reached uniformly.
if (t * rows_in_tile >= m) {
break;
}
auto current_copy_tile_s = tensor<0>(squeeze_tiled_tensor_s(_, t));
auto current_copy_tile_p = tensor<0>(squeeze_tiled_tensor_p(_, t));
auto current_copy_tile_d = tensor<0>(squeeze_tiled_tensor_d(_, t));
auto current_copy_tile_sf = tensor<0>(squeeze_tiled_tensor_sf(_, t));
auto current_copy_tile_shared_sf =
tensor<0>(squeeze_tiled_tensor_shared_sf(_, t));
// Global to Register copy
auto thr_copy_g2r = tiled_copy_g2r.get_thread_slice(threadIdx.x);
auto thr_tile_g2r_s = thr_copy_g2r.partition_S(current_copy_tile_s);
auto thr_tile_g2r_p = thr_copy_g2r.partition_S(current_copy_tile_p);
auto input_fragment = make_fragment_like(thr_tile_g2r_s);
// Register to Global copy
auto thr_copy_r2g = tiled_copy_r2g.get_thread_slice(threadIdx.x);
auto thr_tile_r2g_d = thr_copy_r2g.partition_D(current_copy_tile_d);
auto thr_tile_r2g_p = thr_copy_r2g.partition_D(current_copy_tile_p);
auto output_fragment = make_fragment_like(thr_tile_r2g_d);
// Register to Shared copy. Two adjacent threads share one scale factor, so
// the slice index is threadIdx.x / 2 and only even threads store it below.
auto thr_copy_r2s = tiled_copy_r2s.get_thread_slice(threadIdx.x / 2);
auto thr_tile_r2s_shared_sf =
thr_copy_r2s.partition_D(current_copy_tile_shared_sf);
auto shared_sf_fragment = make_fragment_like(thr_tile_r2s_shared_sf);
// CopyG2R & convert & CopyR2G
copy_if(tiled_copy_g2r, thr_tile_g2r_p, thr_tile_g2r_s, input_fragment);
uint8_t fp8_sf_val =
cvt_warp_fp16_to_mxfp8(input_fragment, output_fragment);
copy_if(tiled_copy_r2g, thr_tile_r2g_p, output_fragment, thr_tile_r2g_d);
shared_sf_fragment[0] = fp8_sf_val;
// Before the first store to shared memory, thread 0 waits for the previously
// committed bulk copy group so the staging buffer is no longer being read.
if (t == 0 && threadIdx.x == 0) {
// Wait for the group to have completed reading from shared memory.
cuda::ptx::cp_async_bulk_wait_group_read(cuda::ptx::n32_t<0>());
}
__syncthreads();
if (threadIdx.x % 2 == 0) {
copy(tiled_copy_r2s, shared_sf_fragment, thr_tile_r2s_shared_sf);
}
__syncthreads();
}
// Wait for shared memory writes to be visible to TMA engine.
cuda::ptx::fence_proxy_async(cuda::ptx::space_shared); // b)
__syncthreads();
// Thread 0 issues one 512-byte bulk copy of the staging buffer from shared to
// global memory and commits it as a bulk async-group.
if (threadIdx.x == 0) {
cuda::ptx::cp_async_bulk(cuda::ptx::space_global, cuda::ptx::space_shared,
squeeze_tiled_tensor_sf.data().get(),
squeeze_tiled_tensor_shared_sf.data().get(), 512);
// Wait for TMA transfer to have finished reading shared memory.
// Create a "bulk async-group" out of the previous bulk copy operation.
cuda::ptx::cp_async_bulk_commit_group();
}
__syncthreads();
}
// Grouped MXFP8 quantization kernel: for every expert group, quantizes its
// (M, K) slice of `input` into `quant_output` and writes the per-block scale
// factors into `scale_factor`, one BLOCK_M x BLOCK_K tile at a time.
//
// Parameters:
//   input              - source activations; group g starts at row
//                        expert_offsets[g] (row length K).
//   problem_sizes      - 3 ints per group; element 0 is read as M and
//                        element 2 as K (element 1 is not used here).
//   expert_offsets     - per-group row offset into input / quant_output.
//   blockscale_offsets - per-group row offset into scale_factor, multiplied
//                        by K / 32 to form the element offset.
//   quant_output       - destination for the float_e4m3_t values, same
//                        (M, K) row-major addressing as `input`.
//   scale_factor       - destination for the uint8 scale factors, laid out
//                        by tiling ScaleFactorTileLayout over
//                        (ceil_div(M, 128) * 128, K / 32).
//   groups             - number of expert groups to process.
//   tiled_copy_*       - CuTe tiled-copy objects forwarded to the tile
//                        routine.
//
// Launch layout: 1-D grid of any size; each block starts at tile blockIdx.x
// and advances by gridDim.x. Uses 512 bytes of static shared memory as the
// staging buffer for one tile's scale factors.
//
// The body is only compiled for __CUDA_ARCH__ >= 1000; on older targets the
// kernel is an empty function.
// NOTE(review): K is assumed to be a multiple of 32 (K / 32 is an integer
// division) - confirm against the callers' shape checks.
template <typename T_IN, typename TiledCopyG2R, typename TiledCopyR2G,
          typename TiledCopyR2S>
__global__ void mxfp8_experts_quant_kernel(
    const T_IN* input, const int* problem_sizes, const int* expert_offsets,
    const int* blockscale_offsets, cutlass::float_e4m3_t* quant_output,
    uint8_t* scale_factor, int groups, TiledCopyG2R tiled_copy_g2r,
    TiledCopyR2G tiled_copy_r2g, TiledCopyR2S tiled_copy_r2s) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000
  // Staging buffer for one tile of scale factors; the tile routine copies
  // exactly 512 bytes out of it with a bulk async copy.
  __shared__ __align__(512) uint8_t shared_memory[512];
  ScaleFactorTileLayout scale_factor_tile_layout{};
  auto scale_factor_shared =
      make_tensor(make_smem_ptr(shared_memory),
                  scale_factor_tile_layout);  // ((_32,_4), _4):((_16,_4), _1)
  // TODO: Transform Groupwise Schedule into a more efficient Schedule
  for (int g = 0; g < groups; g++) {
    int m = problem_sizes[g * 3 + 0];
    int k = problem_sizes[g * 3 + 2];
    // Widen to 64 bits before multiplying by k so the pointer offsets below
    // are computed in int64.
    int64_t expert_offset = static_cast<int64_t>(expert_offsets[g]);
    int64_t blockscale_offset = static_cast<int64_t>(blockscale_offsets[g]);
    auto input_tensor = make_tensor(
        make_gmem_ptr(input + expert_offset * k),
        make_layout(make_shape(m, k),
                    LayoutRight{}));  // (M, K):(K, 1) half_t/bfloat16_t
    auto quant_output_tensor = make_tensor(
        make_gmem_ptr(quant_output + expert_offset * k),
        make_layout(make_shape(m, k),
                    LayoutRight{}));  // (M, K):(K, 1) cutlass::float_e4m3_t
    // Scale factors cover M rounded up to a multiple of 128 rows and one
    // column per 32 elements of K.
    auto scale_factor_shape = make_shape(ceil_div(m, 128) * 128, k / 32);
    auto scale_factor_layout = tile_to_shape(scale_factor_tile_layout,
                                             scale_factor_shape, LayoutRight{});
    // Modes of scale_factor_layout:
    //   layout<0>(layout<0>(.)): (_32,_4):(_16,_4)    -- static
    //   layout<1>(layout<0>(.)): M_align_128 / 128    -- dynamic shape,
    //                                                    dynamic stride
    //   layout<0>(layout<1>(.)): _4:_1                -- static
    //   layout<1>(layout<1>(.)): (K / 32) / 4 : _512  -- dynamic shape,
    //                                                    static stride
    // Reshape to zipped layout for 1D indexing:
    //   (((_32,_4),_4),(M_align_128 / 128,(K / 32) / 4))
    //     :(((_16,_4),_1),(?,_512))
    auto zipped_scale_factor_layout = make_layout(
        make_layout(layout<0>(layout<0>(scale_factor_layout)),
                    layout<0>(layout<1>(scale_factor_layout))),
        make_layout(layout<1>(layout<0>(scale_factor_layout)),
                    layout<1>(layout<1>(scale_factor_layout))));
    auto scale_factor_tensor =
        make_tensor(make_gmem_ptr(scale_factor + blockscale_offset * (k / 32)),
                    zipped_scale_factor_layout);
    // Used for cases where M is not divisible by 128 (most scenarios).
    // The predicate is true only for coordinates inside the real (M, K)
    // extent, so the tile routine can mask accesses in partial tiles.
    auto input_shape = shape(input_tensor);  // (M, K):(K, 1)
    auto identity_tensor = make_identity_tensor(input_shape);
    auto predict_tensor = cute::lazy::transform(
        identity_tensor, [&](auto c) { return elem_less(c, input_shape); });
    // (_128, _128)
    auto tiler = make_shape(Int<BLOCK_M>{}, Int<BLOCK_K>{});
    auto tiled_input_tensor = zipped_divide(
        input_tensor, tiler);  // ((128, 128), (cdiv(M, 128), cdiv(K, 128)))
    auto tiled_quant_output_tensor =
        zipped_divide(quant_output_tensor,
                      tiler);  // ((128, 128), (cdiv(M, 128), cdiv(K, 128)))
    auto tiled_predict_tensor = zipped_divide(
        predict_tensor, tiler);  // ((128, 128), (cdiv(M, 128), cdiv(K, 128)))
    auto total_tiles =
        size<1>(tiled_input_tensor);  // cdiv(M, 128) * cdiv(K, 128)
    // Each block walks the group's tiles with a stride of gridDim.x.
    decltype(total_tiles) blk_offset = blockIdx.x;
    while (blk_offset < total_tiles) {
      auto current_input_tile = tensor<0>(tiled_input_tensor(_, blk_offset));
      auto current_quant_output_tile =
          tensor<0>(tiled_quant_output_tensor(_, blk_offset));
      auto current_predict_tile =
          tensor<0>(tiled_predict_tensor(_, blk_offset));
      auto current_scale_factor_tile =
          tensor<0>(scale_factor_tensor(_, blk_offset));
      mxfp8_experts_quant_tile<
          decltype(current_input_tile), decltype(current_predict_tile),
          decltype(current_quant_output_tile), decltype(scale_factor_shared),
          decltype(current_scale_factor_tile), TiledCopyG2R, TiledCopyR2G,
          TiledCopyR2S>(current_input_tile, current_predict_tile,
                        current_quant_output_tile, scale_factor_shared,
                        current_scale_factor_tile, m, tiled_copy_g2r,
                        tiled_copy_r2g, tiled_copy_r2s);
      blk_offset += gridDim.x;
    }
  }
  // The tile routine commits its shared->global bulk copy of scale factors
  // but defers the wait to the start of the next tile. The last copy issued
  // by this block therefore has no matching wait; drain it here so the block
  // does not finish while that copy may still be reading shared_memory.
  // Thread 0 is the only thread that issues and commits bulk copies; with no
  // pending group the wait returns immediately.
  if (threadIdx.x == 0) {
    cuda::ptx::cp_async_bulk_wait_group_read(cuda::ptx::n32_t<0>());
  }
#endif
}
// Host-side launcher for mxfp8_experts_quant_kernel.
//
// Builds the three CuTe tiled copies the kernel needs (global->register for
// T_IN, register->global for float_e4m3_t, register->shared for the uint8
// scale factors), sizes a 1-D grid to
// multiProcessorCount * max-active-blocks-per-SM, and launches on the current
// CUDA stream.
//
// Parameters:
//   input              - activations, reinterpreted as T_IN.
//   problem_sizes      - int tensor; size(0) is taken as the number of
//                        experts.
//   expert_offsets     - int tensor of per-expert row offsets.
//   blockscale_offsets - int tensor of per-expert scale-factor offsets.
//   quant_output       - written as cutlass::float_e4m3_t.
//   scale_factor       - written as uint8_t.
//
// Raises (via TORCH_CHECK / AT_CUDA_CHECK) if the occupancy query fails,
// reports that no block can be resident, or the launch itself is rejected.
// NOTE(review): dtype/shape/device validation of the tensors is assumed to be
// done by the caller - nothing here verifies it.
template <typename T_IN>
void launch_mxfp8_experts_quant(const torch::Tensor& input,
                                const torch::Tensor& problem_sizes,
                                const torch::Tensor& expert_offsets,
                                const torch::Tensor& blockscale_offsets,
                                torch::Tensor& quant_output,
                                torch::Tensor& scale_factor) {
  int const num_experts = static_cast<int>(problem_sizes.size(0));
  if (num_experts == 0) {
    // The kernel iterates over expert groups only; nothing to do.
    return;
  }

  ThrLayout thr_layout{};
  ValLayout val_layout{};
  SfR2SThrLayout r2s_thr_layout{};
  SfR2SValLayout r2s_val_layout{};

  using CopyOpG2R =
      UniversalCopy<cutlass::AlignedArray<T_IN, size(val_layout)>>;
  using CopyAtomG2R = cute::Copy_Atom<CopyOpG2R, T_IN>;
  auto tiled_copy_g2r = cute::make_tiled_copy(
      CopyAtomG2R{}, thr_layout, val_layout);  // Tiler_MN: (16, 128)

  using CopyOpR2G = UniversalCopy<
      cutlass::AlignedArray<cutlass::float_e4m3_t, size(val_layout)>>;
  using CopyAtomR2G = cute::Copy_Atom<CopyOpR2G, cutlass::float_e4m3_t>;
  auto tiled_copy_r2g = cute::make_tiled_copy(
      CopyAtomR2G{}, thr_layout, val_layout);  // Tiler_MN: (16, 128)

  using CopyOpR2S =
      UniversalCopy<cutlass::AlignedArray<uint8_t, size(r2s_val_layout)>>;
  using CopyAtomR2S = cute::Copy_Atom<CopyOpR2S, uint8_t>;
  auto tiled_copy_r2s = cute::make_tiled_copy(
      CopyAtomR2S{}, r2s_thr_layout, r2s_val_layout);  // Tiler_MN: (16, 4)

  // Name the instantiation once so the occupancy query and the launch are
  // guaranteed to refer to the same kernel.
  auto kernel =
      mxfp8_experts_quant_kernel<T_IN, decltype(tiled_copy_g2r),
                                 decltype(tiled_copy_r2g),
                                 decltype(tiled_copy_r2s)>;

  int max_active_blocks_per_sm = -1;
  AT_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
      &max_active_blocks_per_sm, kernel, THREAD_BLOCK_SIZE, 0));
  // A result of 0 would produce a zero-sized grid and an opaque
  // "invalid configuration" launch failure; report it explicitly instead.
  TORCH_CHECK(max_active_blocks_per_sm > 0,
              "mxfp8_experts_quant_kernel cannot be resident on an SM with ",
              THREAD_BLOCK_SIZE, " threads per block");

  // One wave of blocks; the kernel strides over tiles by gridDim.x.
  dim3 grid(at::cuda::getCurrentDeviceProperties()->multiProcessorCount *
                max_active_blocks_per_sm,
            1, 1);
  dim3 block(THREAD_BLOCK_SIZE, 1, 1);
  auto stream = at::cuda::getCurrentCUDAStream();

  kernel<<<grid, block, 0, stream>>>(
      reinterpret_cast<const T_IN*>(input.data_ptr()),
      reinterpret_cast<const int*>(problem_sizes.data_ptr()),
      reinterpret_cast<const int*>(expert_offsets.data_ptr()),
      reinterpret_cast<const int*>(blockscale_offsets.data_ptr()),
      reinterpret_cast<cutlass::float_e4m3_t*>(quant_output.data_ptr()),
      reinterpret_cast<uint8_t*>(scale_factor.data_ptr()), num_experts,
      tiled_copy_g2r, tiled_copy_r2g, tiled_copy_r2s);
  // Kernel launches do not return a status; surface configuration errors now
  // rather than at some unrelated later synchronisation point.
  AT_CUDA_CHECK(cudaGetLastError());
}
} // namespace expert_specialization
\ No newline at end of file
...@@ -57,7 +57,7 @@ void sortAndScanExpert(const int* expert_for_source_row, const int* source_rows, ...@@ -57,7 +57,7 @@ void sortAndScanExpert(const int* expert_for_source_row, const int* source_rows,
template <typename T> template <typename T>
void expandInputRowsKernelLauncher( void expandInputRowsKernelLauncher(
T const* unpermuted_input, T* permuted_output, int* sorted_experts, T const* unpermuted_input, T* permuted_output,
int const* expanded_dest_row_to_expanded_source_row, int const* expanded_dest_row_to_expanded_source_row,
int* expanded_source_row_to_expanded_dest_row, int* permuted_idx, int* expanded_source_row_to_expanded_dest_row, int* permuted_idx,
int64_t const* expert_first_token_offset, int64_t const num_rows, int64_t const* expert_first_token_offset, int64_t const num_rows,
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
template <typename T, bool CHECK_SKIPPED> template <typename T, bool CHECK_SKIPPED>
__global__ void expandInputRowsKernel( __global__ void expandInputRowsKernel(
T const* unpermuted_input, T* permuted_output, int* sorted_experts, T const* unpermuted_input, T* permuted_output,
int const* expanded_dest_row_to_expanded_source_row, int const* expanded_dest_row_to_expanded_source_row,
int* expanded_source_row_to_expanded_dest_row, int* permuted_idx, int* expanded_source_row_to_expanded_dest_row, int* permuted_idx,
int64_t const* expert_first_token_offset, int64_t const num_rows, int64_t const* expert_first_token_offset, int64_t const num_rows,
...@@ -16,7 +16,6 @@ __global__ void expandInputRowsKernel( ...@@ -16,7 +16,6 @@ __global__ void expandInputRowsKernel(
int64_t expanded_dest_row = blockIdx.x; int64_t expanded_dest_row = blockIdx.x;
int64_t const expanded_source_row = int64_t const expanded_source_row =
expanded_dest_row_to_expanded_source_row[expanded_dest_row]; expanded_dest_row_to_expanded_source_row[expanded_dest_row];
int expert_id = sorted_experts[expanded_dest_row];
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
assert(expanded_dest_row <= INT32_MAX); assert(expanded_dest_row <= INT32_MAX);
...@@ -54,7 +53,7 @@ __global__ void expandInputRowsKernel( ...@@ -54,7 +53,7 @@ __global__ void expandInputRowsKernel(
template <typename T> template <typename T>
void expandInputRowsKernelLauncher( void expandInputRowsKernelLauncher(
T const* unpermuted_input, T* permuted_output, int* sorted_experts, T const* unpermuted_input, T* permuted_output,
int const* expanded_dest_row_to_expanded_source_row, int const* expanded_dest_row_to_expanded_source_row,
int* expanded_source_row_to_expanded_dest_row, int* permuted_idx, int* expanded_source_row_to_expanded_dest_row, int* permuted_idx,
int64_t const* expert_first_token_offset, int64_t const num_rows, int64_t const* expert_first_token_offset, int64_t const num_rows,
...@@ -70,12 +69,12 @@ void expandInputRowsKernelLauncher( ...@@ -70,12 +69,12 @@ void expandInputRowsKernelLauncher(
bool is_check_skip = num_valid_tokens_ptr != nullptr; bool is_check_skip = num_valid_tokens_ptr != nullptr;
auto func = func_map[is_check_skip]; auto func = func_map[is_check_skip];
func<<<blocks, threads, 0, stream>>>( func<<<blocks, threads, 0, stream>>>(unpermuted_input, permuted_output,
unpermuted_input, permuted_output, sorted_experts, expanded_dest_row_to_expanded_source_row,
expanded_dest_row_to_expanded_source_row, expanded_source_row_to_expanded_dest_row,
expanded_source_row_to_expanded_dest_row, permuted_idx, permuted_idx, expert_first_token_offset,
expert_first_token_offset, num_rows, num_valid_tokens_ptr, cols, k, num_rows, num_valid_tokens_ptr, cols, k,
num_local_experts); num_local_experts);
} }
template <class T, class U> template <class T, class U>
......
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
// bf16 x bf16 -> fp32 router GEMM via cuBLAS.
// Uses CUBLAS_COMPUTE_32F so bf16 operands accumulate into fp32,
// matching TRT-LLM's cuBLAS fallback behaviour in dsv3RouterGemmOp.
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <cublas_v2.h>
// cuBLAS is column-major, so each row-major PyTorch tensor is seen
// transposed:
//   weight[N,K] row-major, lda=K -> cuBLAS sees (K,N); CUBLAS_OP_T -> (N,K)
//   input [M,K] row-major, ldb=K -> cuBLAS sees (K,M); CUBLAS_OP_N -> (K,M)
//   out   [M,N] row-major, ldc=N -> cuBLAS sees (N,M), i.e. output^T
// cuBLAS: C(N,M) = weight(N,K) @ input(K,M) => C^T = output[M,N]
// params: m=N, n=M, k=K, lda=K (weight), ldb=K (input), ldc=N (output)
// Computes out[M,N] = input[M,K] @ weight[N,K]^T with bf16 operands and fp32
// accumulation/output, using cublasGemmEx on the current CUDA stream.
//
// Parameters:
//   input  - 2-D bfloat16 CUDA tensor of shape (M, K).
//   weight - 2-D bfloat16 CUDA tensor of shape (N, K), on the same device.
// Returns:
//   A new float32 tensor of shape (M, N) on input's device.
// Raises (TORCH_CHECK) on wrong dtype/rank/device, mismatched K, or
// dimensions that do not fit in a 32-bit int (cuBLAS takes int sizes).
//
// Non-contiguous operands are materialised with .contiguous() because the
// leading dimensions passed to cuBLAS assume densely packed rows.
// NOTE(review): the cuBLAS handle belongs to the *current* device; this
// assumes the caller has already made input's device current - confirm.
torch::Tensor router_gemm_bf16_fp32(torch::Tensor const& input,
                                    torch::Tensor const& weight) {
  TORCH_CHECK(input.dtype() == torch::kBFloat16,
              "router_gemm_bf16_fp32: input must be bfloat16");
  TORCH_CHECK(weight.dtype() == torch::kBFloat16,
              "router_gemm_bf16_fp32: weight must be bfloat16");
  TORCH_CHECK(input.dim() == 2 && weight.dim() == 2,
              "router_gemm_bf16_fp32: input and weight must be 2-D");
  TORCH_CHECK(input.size(1) == weight.size(1),
              "router_gemm_bf16_fp32: inner dimensions must match");
  TORCH_CHECK(input.is_cuda() && weight.is_cuda(),
              "router_gemm_bf16_fp32: input and weight must be CUDA tensors");
  TORCH_CHECK(input.device() == weight.device(),
              "router_gemm_bf16_fp32: input and weight must be on the same "
              "device");

  int64_t const M = input.size(0);
  int64_t const N = weight.size(0);
  int64_t const K = input.size(1);
  // cublasGemmEx takes int sizes; a round-trip through int detects values
  // that would be truncated by the casts below.
  TORCH_CHECK(static_cast<int>(M) == M && static_cast<int>(N) == N &&
                  static_cast<int>(K) == K,
              "router_gemm_bf16_fp32: dimensions must fit in a 32-bit int");

  auto out = torch::empty({M, N}, input.options().dtype(torch::kFloat32));

  // cuBLAS rejects leading dimensions of 0, so handle degenerate shapes
  // without calling it.
  if (M == 0 || N == 0) {
    return out;
  }
  if (K == 0) {
    // An empty reduction yields an all-zero product.
    out.zero_();
    return out;
  }

  // No-ops for already-contiguous tensors; otherwise guarantees row stride K.
  torch::Tensor const input_c = input.contiguous();
  torch::Tensor const weight_c = weight.contiguous();

  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  TORCH_CUDABLAS_CHECK(
      cublasSetStream(handle, at::cuda::getCurrentCUDAStream()));

  float const alpha = 1.0f;
  float const beta = 0.0f;
  int const m = static_cast<int>(N);
  int const n = static_cast<int>(M);
  int const k = static_cast<int>(K);

  // Column-major view: C(N,M) = weight^T-as-(N,K) @ input-as-(K,M), which is
  // exactly the row-major out[M,N].
  TORCH_CUDABLAS_CHECK(cublasGemmEx(
      handle, CUBLAS_OP_T, CUBLAS_OP_N, m, n, k, &alpha, weight_c.data_ptr(),
      CUDA_R_16BF, /*lda=*/k, input_c.data_ptr(), CUDA_R_16BF, /*ldb=*/k,
      &beta, out.data_ptr(), CUDA_R_32F, /*ldc=*/m, CUBLAS_COMPUTE_32F,
      CUBLAS_GEMM_DEFAULT));
  return out;
}
...@@ -124,6 +124,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { ...@@ -124,6 +124,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
"routed_scaling_factor, Tensor bias, int scoring_func) -> (Tensor, " "routed_scaling_factor, Tensor bias, int scoring_func) -> (Tensor, "
"Tensor)"); "Tensor)");
m.impl("grouped_topk", torch::kCUDA, &grouped_topk); m.impl("grouped_topk", torch::kCUDA, &grouped_topk);
// cuBLAS bf16 x bf16 -> fp32 router GEMM (fallback for non-SM90 / batch > 16)
m.def("router_gemm_bf16_fp32(Tensor input, Tensor weight) -> Tensor");
m.impl("router_gemm_bf16_fp32", torch::kCUDA, &router_gemm_bf16_fp32);
// DeepSeek V3 optimized router GEMM for SM90+
m.def("dsv3_router_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()");
// conditionally compiled so impl registration is in source file
#endif #endif
} }
......
...@@ -114,6 +114,10 @@ void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n, ...@@ -114,6 +114,10 @@ void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n,
int64_t numRows, int64_t stride0, int64_t stride1, int64_t numRows, int64_t stride0, int64_t stride1,
int64_t topK); int64_t topK);
void large_context_topk(const torch::Tensor& score, torch::Tensor& indices,
const torch::Tensor& lengths,
std::optional<torch::Tensor> row_starts_opt);
// void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input, // void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,
// torch::Tensor& weight, torch::Tensor& scale, // torch::Tensor& weight, torch::Tensor& scale,
// double epsilon); // double epsilon);
...@@ -265,13 +269,13 @@ void get_cutlass_moe_mm_problem_sizes_from_expert_offsets( ...@@ -265,13 +269,13 @@ void get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
const int64_t n, const int64_t k, const bool swap_ab); const int64_t n, const int64_t k, const bool swap_ab);
void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets, void get_cutlass_batched_moe_mm_data(torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2, torch::Tensor& problem_sizes2,
const torch::Tensor& expert_num_tokens, const torch::Tensor& expert_num_tokens,
const int64_t num_local_experts, const int64_t num_local_experts,
const int64_t padded_m, const int64_t n, const int64_t padded_m, const int64_t n,
const int64_t k); const int64_t k);
void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a, void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& b,
...@@ -291,10 +295,14 @@ void cutlass_scaled_sparse_mm(torch::Tensor& out, torch::Tensor const& a, ...@@ -291,10 +295,14 @@ void cutlass_scaled_sparse_mm(torch::Tensor& out, torch::Tensor const& a,
std::vector<torch::Tensor> cutlass_sparse_compress(torch::Tensor const& a); std::vector<torch::Tensor> cutlass_sparse_compress(torch::Tensor const& a);
void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input, std::tuple<torch::Tensor, torch::Tensor> scaled_fp4_quant_func(
torch::Tensor& output_scale, torch::Tensor const& input, torch::Tensor const& input_scale,
torch::Tensor const& input_scale, bool is_sf_swizzled_layout);
bool is_sf_swizzled_layout);
void scaled_fp4_quant_out(torch::Tensor const& input,
torch::Tensor const& input_scale,
bool is_sf_swizzled_layout, torch::Tensor& output,
torch::Tensor& output_scale);
void scaled_fp4_experts_quant( void scaled_fp4_experts_quant(
torch::Tensor& output, torch::Tensor& output_scale, torch::Tensor& output, torch::Tensor& output_scale,
...@@ -311,7 +319,9 @@ void silu_and_mul_scaled_fp4_experts_quant( ...@@ -311,7 +319,9 @@ void silu_and_mul_scaled_fp4_experts_quant(
void per_token_group_quant_fp8(const torch::Tensor& input, void per_token_group_quant_fp8(const torch::Tensor& input,
torch::Tensor& output_q, torch::Tensor& output_s, torch::Tensor& output_q, torch::Tensor& output_s,
int64_t group_size, double eps, double fp8_min, int64_t group_size, double eps, double fp8_min,
double fp8_max, bool scale_ue8m0); double fp8_max, bool scale_ue8m0,
bool dummy_is_scale_transposed,
bool dummy_is_tma_aligned);
void per_token_group_quant_int8(const torch::Tensor& input, void per_token_group_quant_int8(const torch::Tensor& input,
torch::Tensor& output_q, torch::Tensor& output_q,
...@@ -365,7 +375,9 @@ void selective_scan_fwd( ...@@ -365,7 +375,9 @@ void selective_scan_fwd(
const torch::Tensor& ssm_states, int64_t pad_slot_id, int64_t block_size, const torch::Tensor& ssm_states, int64_t pad_slot_id, int64_t block_size,
const std::optional<torch::Tensor>& block_idx_first_scheduled_token, const std::optional<torch::Tensor>& block_idx_first_scheduled_token,
const std::optional<torch::Tensor>& block_idx_last_scheduled_token, const std::optional<torch::Tensor>& block_idx_last_scheduled_token,
const std::optional<torch::Tensor>& initial_state_idx); const std::optional<torch::Tensor>& initial_state_idx,
const std::optional<torch::Tensor>& cu_chunk_seqlen,
const std::optional<torch::Tensor>& last_chunk_indices);
torch::Tensor dynamic_4bit_int_moe_cpu( torch::Tensor dynamic_4bit_int_moe_cpu(
torch::Tensor x, torch::Tensor topk_ids, torch::Tensor topk_weights, torch::Tensor x, torch::Tensor topk_ids, torch::Tensor topk_weights,
...@@ -404,3 +416,8 @@ void qr_all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, ...@@ -404,3 +416,8 @@ void qr_all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
int64_t quant_level, bool cast_bf2half = false); int64_t quant_level, bool cast_bf2half = false);
int64_t qr_max_size(); int64_t qr_max_size();
#endif #endif
#ifndef USE_ROCM
void dsv3_fused_a_gemm(torch::Tensor& output, torch::Tensor const& mat_a,
torch::Tensor const& mat_b);
#endif
\ No newline at end of file
...@@ -542,7 +542,7 @@ __global__ void silu_mul_fp8_quant_deep_gemm_kernel( ...@@ -542,7 +542,7 @@ __global__ void silu_mul_fp8_quant_deep_gemm_kernel(
if (!lane_id) { if (!lane_id) {
// Store scales. // Store scales.
if constexpr (std::is_same<scale_t, uint8_t>::value) { if constexpr (std::is_same<scale_t, uint8_t>::value) {
// Packed UE8MO format. Remove Mantissa. // Packed UE8M0 format. Remove Mantissa.
*y_s_ptr = reinterpret_cast<int16_t&>(y_s) >> 7; *y_s_ptr = reinterpret_cast<int16_t&>(y_s) >> 7;
bool const jump_pack = (current_group_id + 1) % 4 == 0; bool const jump_pack = (current_group_id + 1) % 4 == 0;
......
...@@ -39,12 +39,12 @@ namespace vllm { ...@@ -39,12 +39,12 @@ namespace vllm {
template <class Type, bool UE8M0_SF = false> template <class Type, bool UE8M0_SF = false>
__global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
silu_mul_cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, silu_mul_cvt_fp16_to_fp4(int32_t numRows, int32_t numCols,
int32_t num_padded_cols, int32_t num_packed_cols,
Type const* __restrict__ in, Type const* __restrict__ in,
float const* __restrict__ SFScale, float const* __restrict__ SFScale,
uint32_t* __restrict__ out, uint32_t* __restrict__ out,
uint32_t* __restrict__ SFout) { uint32_t* __restrict__ SFout) {
using PackedVec = vllm::PackedVec<Type>; using PackedVec = vllm::PackedVec<Type, CVT_FP4_PACK16>;
static constexpr int CVT_FP4_NUM_THREADS_PER_SF = static constexpr int CVT_FP4_NUM_THREADS_PER_SF =
(CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
...@@ -63,7 +63,7 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) ...@@ -63,7 +63,7 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
// Input tensor row/col loops. // Input tensor row/col loops.
for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) { for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) {
if (colIdx < num_padded_cols) { if (colIdx < num_packed_cols) {
PackedVec in_vec; PackedVec in_vec;
PackedVec in_vec2; PackedVec in_vec2;
int64_t inOffset = int64_t inOffset =
...@@ -73,19 +73,19 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) ...@@ -73,19 +73,19 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
bool valid = (rowIdx < numRows) && (elem_idx < numCols); bool valid = (rowIdx < numRows) && (elem_idx < numCols);
if constexpr (CVT_FP4_PACK16) { if constexpr (CVT_FP4_PACK16) {
ld256_or_zero_cg_u32<Type>( ld256_cg_or_zero(reinterpret_cast<u32x8_t&>(in_vec),
in_vec, &reinterpret_cast<const uint32_t*>(in)[inOffset * 8], &reinterpret_cast<const uint32_t*>(in)[inOffset * 8],
valid); valid);
ld256_or_zero_cg_u32<Type>( ld256_cg_or_zero(reinterpret_cast<u32x8_t&>(in_vec2),
in_vec2, &reinterpret_cast<const uint32_t*>(in)[inOffset2 * 8], &reinterpret_cast<const uint32_t*>(in)[inOffset2 * 8],
valid); valid);
} else { } else {
ld128_or_zero_cg_u32<Type>( ld128_cg_or_zero(reinterpret_cast<uint4&>(in_vec),
in_vec, &reinterpret_cast<const uint32_t*>(in)[inOffset * 4], &reinterpret_cast<const uint32_t*>(in)[inOffset * 4],
valid); valid);
ld128_or_zero_cg_u32<Type>( ld128_cg_or_zero(reinterpret_cast<uint4&>(in_vec2),
in_vec2, &reinterpret_cast<const uint32_t*>(in)[inOffset2 * 4], &reinterpret_cast<const uint32_t*>(in)[inOffset2 * 4],
valid); valid);
} }
// Compute silu and mul // Compute silu and mul
...@@ -107,7 +107,9 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) ...@@ -107,7 +107,9 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
(uint64_t(out_val.hi) << 32) | uint64_t(out_val.lo); (uint64_t(out_val.hi) << 32) | uint64_t(out_val.lo);
reinterpret_cast<uint64_t*>(out)[outOffset >> 1] = packed64; reinterpret_cast<uint64_t*>(out)[outOffset >> 1] = packed64;
} else { } else {
out[inOffset] = out_val; int64_t outOffset =
rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx;
out[outOffset] = out_val;
} }
} }
} }
...@@ -140,9 +142,9 @@ void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, // [..., d] ...@@ -140,9 +142,9 @@ void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, // [..., d]
int const numBlocksPerSM = int const numBlocksPerSM =
vllm_runtime_blocks_per_sm(static_cast<int>(block.x)); vllm_runtime_blocks_per_sm(static_cast<int>(block.x));
int sf_n_unpadded = int(n / CVT_FP4_SF_VEC_SIZE); int num_packed_cols = int(n / CVT_FP4_ELTS_PER_THREAD);
int grid_y = vllm::div_round_up(sf_n_unpadded, static_cast<int>(block.x)); int grid_y = vllm::div_round_up(num_packed_cols, static_cast<int>(block.x));
int grid_x = std::min( int grid_x = std::min(
int(m), std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y)); int(m), std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y));
dim3 grid(grid_x, grid_y); dim3 grid(grid_x, grid_y);
...@@ -152,7 +154,7 @@ void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, // [..., d] ...@@ -152,7 +154,7 @@ void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, // [..., d]
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type; using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
auto input_ptr = static_cast<cuda_type const*>(input.data_ptr()); auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
vllm::silu_mul_cvt_fp16_to_fp4<cuda_type><<<grid, block, 0, stream>>>( vllm::silu_mul_cvt_fp16_to_fp4<cuda_type><<<grid, block, 0, stream>>>(
m, n, sf_n_unpadded, input_ptr, input_sf_ptr, m, n, num_packed_cols, input_ptr, input_sf_ptr,
reinterpret_cast<uint32_t*>(output_ptr), reinterpret_cast<uint32_t*>(output_ptr),
reinterpret_cast<uint32_t*>(sf_out)); reinterpret_cast<uint32_t*>(sf_out));
}); });
......
...@@ -43,7 +43,7 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) ...@@ -43,7 +43,7 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
uint32_t* input_offset_by_experts, uint32_t* input_offset_by_experts,
uint32_t* output_scale_offset_by_experts, int n_experts, uint32_t* output_scale_offset_by_experts, int n_experts,
bool low_latency) { bool low_latency) {
using PackedVec = PackedVec<Type>; using PackedVec = PackedVec<Type, CVT_FP4_PACK16>;
static constexpr int CVT_FP4_NUM_THREADS_PER_SF = static constexpr int CVT_FP4_NUM_THREADS_PER_SF =
(CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
...@@ -155,7 +155,7 @@ __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024)) ...@@ -155,7 +155,7 @@ __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024))
float const* SFScale, uint32_t* out, uint32_t* SFout, float const* SFScale, uint32_t* out, uint32_t* SFout,
uint32_t* input_offset_by_experts, uint32_t* input_offset_by_experts,
uint32_t* output_scale_offset_by_experts, int n_experts) { uint32_t* output_scale_offset_by_experts, int n_experts) {
using PackedVec = PackedVec<Type>; using PackedVec = PackedVec<Type, CVT_FP4_PACK16>;
static constexpr int CVT_FP4_NUM_THREADS_PER_SF = static constexpr int CVT_FP4_NUM_THREADS_PER_SF =
(CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
......
...@@ -16,6 +16,8 @@ ...@@ -16,6 +16,8 @@
#include <torch/all.h> #include <torch/all.h>
#include "nvfp4_utils.cuh"
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ #if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120) (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
void scaled_fp4_quant_sm1xxa(torch::Tensor const& output, void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
...@@ -51,9 +53,10 @@ void silu_and_mul_scaled_fp4_experts_quant_sm1xxa( ...@@ -51,9 +53,10 @@ void silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
torch::Tensor const& output_scale_offset_by_experts); torch::Tensor const& output_scale_offset_by_experts);
#endif #endif
void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input, void scaled_fp4_quant_out(torch::Tensor const& input,
torch::Tensor& output_sf, torch::Tensor const& input_sf, torch::Tensor const& input_sf,
bool is_sf_swizzled_layout) { bool is_sf_swizzled_layout, torch::Tensor& output,
torch::Tensor& output_sf) {
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ #if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120) (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
return scaled_fp4_quant_sm1xxa(output, input, output_sf, input_sf, return scaled_fp4_quant_sm1xxa(output, input, output_sf, input_sf,
...@@ -62,6 +65,34 @@ void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input, ...@@ -62,6 +65,34 @@ void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input,
TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 quantization kernel"); TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 quantization kernel");
} }
// Allocating variant of scaled_fp4_quant_out: creates the packed-FP4 output
// and the scale-factor tensor on input's device, runs the quantization, and
// returns both.
//
// Parameters:
//   input                 - tensor whose last dimension is n; all leading
//                           dimensions are flattened into m = numel / n.
//   input_sf              - forwarded unchanged to scaled_fp4_quant_out.
//   is_sf_swizzled_layout - selects the scale-factor allocation: int32 with
//                           the shape from vllm::computeSwizzledSFShape(m, n)
//                           when true, uint8 of shape
//                           (m, n / CVT_FP4_SF_VEC_SIZE) otherwise.
// Returns:
//   {output, output_sf}, where output is uint8 of shape (m, n / 2).
// Raises (TORCH_CHECK) if the last dimension of input is empty.
std::tuple<torch::Tensor, torch::Tensor> scaled_fp4_quant_func(
    torch::Tensor const& input, torch::Tensor const& input_sf,
    bool is_sf_swizzled_layout) {
  int64_t n = input.size(-1);
  // Guard the division below: an empty last dimension would otherwise be an
  // integer division by zero.
  TORCH_CHECK(n > 0,
              "scaled_fp4_quant_func: last dimension of input must be "
              "non-empty");
  int64_t m = input.numel() / n;
  auto device = input.device();

  // Two fp4 values packed into a uint8
  auto output = torch::empty(
      {m, n / 2}, torch::TensorOptions().device(device).dtype(torch::kUInt8));

  torch::Tensor output_sf;
  if (is_sf_swizzled_layout) {
    auto [sf_m, sf_n] = vllm::computeSwizzledSFShape(m, n);
    output_sf = torch::empty(
        {sf_m, sf_n},
        torch::TensorOptions().device(device).dtype(torch::kInt32));
  } else {
    output_sf = torch::empty(
        {m, n / CVT_FP4_SF_VEC_SIZE},
        torch::TensorOptions().device(device).dtype(torch::kUInt8));
  }

  scaled_fp4_quant_out(input, input_sf, is_sf_swizzled_layout, output,
                       output_sf);
  return {output, output_sf};
}
void scaled_fp4_experts_quant( void scaled_fp4_experts_quant(
torch::Tensor& output, torch::Tensor& output_scale, torch::Tensor& output, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale, torch::Tensor const& input, torch::Tensor const& input_global_scale,
......
...@@ -42,7 +42,7 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) ...@@ -42,7 +42,7 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
Type const* __restrict__ in, Type const* __restrict__ in,
float const* __restrict__ SFScale, float const* __restrict__ SFScale,
uint32_t* __restrict__ out, uint32_t* __restrict__ SFout) { uint32_t* __restrict__ out, uint32_t* __restrict__ SFout) {
using PackedVec = vllm::PackedVec<Type>; using PackedVec = vllm::PackedVec<Type, CVT_FP4_PACK16>;
static constexpr int CVT_FP4_NUM_THREADS_PER_SF = static constexpr int CVT_FP4_NUM_THREADS_PER_SF =
(CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
...@@ -71,13 +71,13 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) ...@@ -71,13 +71,13 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
// If we are outside valid rows OR outside valid columns -> Use Zeros // If we are outside valid rows OR outside valid columns -> Use Zeros
bool valid = (rowIdx < numRows) && (elem_idx < numCols); bool valid = (rowIdx < numRows) && (elem_idx < numCols);
if constexpr (CVT_FP4_PACK16) { if constexpr (CVT_FP4_PACK16) {
ld256_or_zero_cg_u32<Type>( ld256_cg_or_zero(reinterpret_cast<u32x8_t&>(in_vec),
in_vec, &reinterpret_cast<const uint32_t*>(in)[inOffset * 8], &reinterpret_cast<const uint32_t*>(in)[inOffset * 8],
valid); valid);
} else { } else {
ld128_or_zero_cg_u32<Type>( ld128_cg_or_zero(reinterpret_cast<uint4&>(in_vec),
in_vec, &reinterpret_cast<const uint32_t*>(in)[inOffset * 4], &reinterpret_cast<const uint32_t*>(in)[inOffset * 4],
valid); valid);
} }
auto sf_out = auto sf_out =
...@@ -109,11 +109,12 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) ...@@ -109,11 +109,12 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
template <class Type, bool UE8M0_SF = false> template <class Type, bool UE8M0_SF = false>
__global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
cvt_fp16_to_fp4_sf_major(int32_t numRows, int32_t numCols, cvt_fp16_to_fp4_sf_major(int32_t numRows, int32_t numCols,
int32_t sf_n_unpadded, Type const* __restrict__ in, int32_t sf_n_unpadded, int32_t num_packed_cols,
Type const* __restrict__ in,
float const* __restrict__ SFScale, float const* __restrict__ SFScale,
uint32_t* __restrict__ out, uint32_t* __restrict__ out,
uint32_t* __restrict__ SFout) { uint32_t* __restrict__ SFout) {
using PackedVec = PackedVec<Type>; using PackedVec = PackedVec<Type, CVT_FP4_PACK16>;
static constexpr int CVT_FP4_NUM_THREADS_PER_SF = static constexpr int CVT_FP4_NUM_THREADS_PER_SF =
(CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
...@@ -131,20 +132,20 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) ...@@ -131,20 +132,20 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
// Iterate over all rows and cols including padded ones - // Iterate over all rows and cols including padded ones -
// ensures we visit every single scale factor address to initialize it. // ensures we visit every single scale factor address to initialize it.
for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) { for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) {
if (colIdx < sf_n_unpadded) { if (colIdx < num_packed_cols) {
PackedVec in_vec; PackedVec in_vec;
int64_t inOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx; int64_t inOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx;
// If we are outside valid rows OR outside valid columns -> Use Zeros // If we are outside valid rows OR outside valid columns -> Use Zeros
bool valid = (rowIdx < numRows) && (elem_idx < numCols); bool valid = (rowIdx < numRows) && (elem_idx < numCols);
if constexpr (CVT_FP4_PACK16) { if constexpr (CVT_FP4_PACK16) {
ld256_or_zero_cg_u32<Type>( ld256_cg_or_zero(reinterpret_cast<u32x8_t&>(in_vec),
in_vec, &reinterpret_cast<const uint32_t*>(in)[inOffset * 8], &reinterpret_cast<const uint32_t*>(in)[inOffset * 8],
valid); valid);
} else { } else {
ld128_or_zero_cg_u32<Type>( ld128_cg_or_zero(reinterpret_cast<uint4&>(in_vec),
in_vec, &reinterpret_cast<const uint32_t*>(in)[inOffset * 4], &reinterpret_cast<const uint32_t*>(in)[inOffset * 4],
valid); valid);
} }
auto sf_out = auto sf_out =
...@@ -222,7 +223,8 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output, ...@@ -222,7 +223,8 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
reinterpret_cast<uint32_t*>(sf_out)); reinterpret_cast<uint32_t*>(sf_out));
}); });
} else { } else {
int grid_y = vllm::div_round_up(sf_n_unpadded, static_cast<int>(block.x)); int num_packed_cols = n / CVT_FP4_ELTS_PER_THREAD;
int grid_y = vllm::div_round_up(num_packed_cols, static_cast<int>(block.x));
int grid_x = std::min( int grid_x = std::min(
m, std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y)); m, std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y));
dim3 grid(grid_x, grid_y); dim3 grid(grid_x, grid_y);
...@@ -232,8 +234,8 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output, ...@@ -232,8 +234,8 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
auto input_ptr = static_cast<cuda_type const*>(input.data_ptr()); auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
// NOTE: We don't support e8m0 scales at this moment. // NOTE: We don't support e8m0 scales at this moment.
vllm::cvt_fp16_to_fp4_sf_major<cuda_type, false> vllm::cvt_fp16_to_fp4_sf_major<cuda_type, false>
<<<grid, block, 0, stream>>>(m, n, sf_n_unpadded, input_ptr, <<<grid, block, 0, stream>>>(m, n, sf_n_unpadded, num_packed_cols,
input_sf_ptr, input_ptr, input_sf_ptr,
reinterpret_cast<uint32_t*>(output_ptr), reinterpret_cast<uint32_t*>(output_ptr),
reinterpret_cast<uint32_t*>(sf_out)); reinterpret_cast<uint32_t*>(sf_out));
}); });
......
...@@ -18,9 +18,12 @@ ...@@ -18,9 +18,12 @@
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <cuda_fp8.h> #include <cuda_fp8.h>
#include <utility>
#if (defined(NVFP4_ENABLE_ELTS16) && (CUDART_VERSION >= 12090) && \ #include "../../cuda_vec_utils.cuh"
defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100)
#if defined(NVFP4_ENABLE_ELTS16) && defined(CUDA_VERSION) && \
CUDA_VERSION >= 12090
#define ELTS_PER_THREAD 16 #define ELTS_PER_THREAD 16
constexpr int CVT_FP4_ELTS_PER_THREAD = 16; constexpr int CVT_FP4_ELTS_PER_THREAD = 16;
constexpr bool CVT_FP4_PACK16 = true; constexpr bool CVT_FP4_PACK16 = true;
...@@ -34,68 +37,6 @@ constexpr int CVT_FP4_SF_VEC_SIZE = 16; ...@@ -34,68 +37,6 @@ constexpr int CVT_FP4_SF_VEC_SIZE = 16;
namespace vllm { namespace vllm {
// Translates a PyTorch C++ scalar type into the CUDA device type used by the
// kernels. Any type without a dedicated specialization maps to itself.
template <class T>
struct CUDATypeConverter {
  using Type = T;
};

// at::BFloat16 -> __nv_bfloat16
template <>
struct CUDATypeConverter<at::BFloat16> {
  using Type = __nv_bfloat16;
};

// at::Half -> half
template <>
struct CUDATypeConverter<at::Half> {
  using Type = half;
};
// Two-way mapping between a scalar 16-bit float type and its packed
// two-element counterpart (half <-> half2, __nv_bfloat16 <-> __nv_bfloat162).
// The primary template is kept for generality and resolves to half2 for any
// type that has no explicit specialization below.
template <class T>
struct TypeConverter {
  using Type = half2;
};

// Scalar -> packed pair.
template <>
struct TypeConverter<half> {
  using Type = half2;
};

template <>
struct TypeConverter<__nv_bfloat16> {
  using Type = __nv_bfloat162;
};

// Packed pair -> scalar.
template <>
struct TypeConverter<half2> {
  using Type = half;
};

template <>
struct TypeConverter<__nv_bfloat162> {
  using Type = __nv_bfloat16;
};
// PackedVec<Type> is the per-thread vector that the quantization kernels load
// and convert in one step. Its width is selected at compile time: the wide
// variant is used only when NVFP4_ENABLE_ELTS16 is defined, the CUDA runtime is
// at least 12.9 (CUDART_VERSION >= 12090) and ENABLE_NVFP4_SM100 is set.
#if (defined(NVFP4_ENABLE_ELTS16) && (CUDART_VERSION >= 12090) && \
defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100)
// Define a 32 bytes packed data type.
// Holds eight packed pairs (TypeConverter<Type>::Type, e.g. half2 for half),
// i.e. 16 scalar elements, and is aligned to 32 bytes.
template <class Type>
struct alignas(32) PackedVec {
typename TypeConverter<Type>::Type elts[8];
};
#else
// Define a 16 bytes packed data type.
// Holds four packed pairs (TypeConverter<Type>::Type, e.g. half2 for half),
// i.e. 8 scalar elements, and is aligned to 16 bytes.
template <class Type>
struct alignas(16) PackedVec {
typename TypeConverter<Type>::Type elts[4];
};
#endif
// FP8 (e4m3) specialization: always eight packed fp8 pairs, independent of the
// width selected above.
// NOTE(review): unlike the primary templates this specialization carries no
// explicit alignas - confirm that is intentional.
template <>
struct PackedVec<__nv_fp8_e4m3> {
__nv_fp8x2_e4m3 elts[8];
};
template <typename Int> template <typename Int>
__host__ __device__ inline Int round_up(Int x, Int y) { __host__ __device__ inline Int round_up(Int x, Int y) {
static_assert(std::is_integral_v<Int>, static_assert(std::is_integral_v<Int>,
...@@ -114,6 +55,18 @@ inline int computeEffectiveRows(int m) { ...@@ -114,6 +55,18 @@ inline int computeEffectiveRows(int m) {
return round_up(m, ROW_TILE); return round_up(m, ROW_TILE);
} }
// Shape of the swizzled scale-factor (SF) output tensor for an (m, n) input.
//
// The result is the pair (padded_rows, padded_sf_cols / 4), with
//   padded_rows    = m rounded up to a multiple of 128
//   padded_sf_cols = (n / CVT_FP4_SF_VEC_SIZE) rounded up to a multiple of 4
// The division of n by CVT_FP4_SF_VEC_SIZE is an integer division.
inline std::pair<int64_t, int64_t> computeSwizzledSFShape(int64_t m,
                                                          int64_t n) {
  constexpr int64_t kRowMultiple = 128;
  constexpr int64_t kSFColMultiple = 4;

  int64_t const sf_cols = n / CVT_FP4_SF_VEC_SIZE;
  int64_t const padded_sf_cols = round_up(sf_cols, kSFColMultiple);
  int64_t const padded_rows = round_up(m, kRowMultiple);

  return {padded_rows, padded_sf_cols / kSFColMultiple};
}
// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t). // Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t).
inline __device__ uint32_t fp32_vec8_to_e2m1(float (&array)[8]) { inline __device__ uint32_t fp32_vec8_to_e2m1(float (&array)[8]) {
uint32_t val; uint32_t val;
...@@ -208,56 +161,6 @@ __device__ __forceinline__ float reciprocal_approximate_ftz(float a) { ...@@ -208,56 +161,6 @@ __device__ __forceinline__ float reciprocal_approximate_ftz(float a) {
return b; return b;
} }
// Predicated 128-bit global load.
//
// Writes four 32-bit words into the first 16 bytes of `out`: the words read
// from `ptr` when `pred` is true, or zeros when `pred` is false.
//
//   out  - destination vector, overwritten through a uint4 store.
//   ptr  - source address for the `ld.global` instruction. The v4.u32 vector
//          load presumably needs 16-byte alignment - confirm against the PTX
//          ISA.
//   pred - when false the load instruction is not executed, so `ptr` is not
//          dereferenced.
//
// The load uses the `.cg` cache operator (see the PTX ISA for its exact
// caching semantics).
template <class Type>
__device__ __forceinline__ void ld128_or_zero_cg_u32(PackedVec<Type>& out,
const void* ptr,
bool pred) {
uint32_t r0, r1, r2, r3;
// Operand map: %0-%3 are the output registers r0-r3, %4 is the predicate as
// an int, %5 is the address. The registers are zeroed unconditionally and then
// overwritten only if the predicate register `pr` is set (the `@pr` guard).
asm volatile(
"{\n"
" .reg .pred pr;\n"
" setp.ne.u32 pr, %4, 0;\n"
" mov.u32 %0, 0;\n"
" mov.u32 %1, 0;\n"
" mov.u32 %2, 0;\n"
" mov.u32 %3, 0;\n"
" @pr ld.global.cg.v4.u32 {%0,%1,%2,%3}, [%5];\n"
"}\n"
: "=r"(r0), "=r"(r1), "=r"(r2), "=r"(r3)
: "r"((int)pred), "l"(ptr));
// Reinterpret the destination as a uint4 and store all four words at once.
*reinterpret_cast<uint4*>(&out) = uint4{r0, r1, r2, r3};
}
// Predicated 256-bit global load.
//
// Writes eight 32-bit words into the first 32 bytes of `out`: the words read
// from `ptr` when `pred` is true, or zeros when `pred` is false.
//
//   out  - destination vector, overwritten through two uint4 stores.
//   ptr  - source address for the `ld.global` instruction. The v8.u32 vector
//          load presumably needs 32-byte alignment - confirm against the PTX
//          ISA.
//   pred - when false the load instruction is not executed, so `ptr` is not
//          dereferenced.
//
// NOTE(review): the 8-wide (v8.u32) vector load is likely limited to recent
// PTX versions / GPU architectures - confirm which build configurations may
// instantiate this helper.
template <class Type>
__device__ __forceinline__ void ld256_or_zero_cg_u32(PackedVec<Type>& out,
const void* ptr,
bool pred) {
uint32_t r0, r1, r2, r3, r4, r5, r6, r7;
// Operand map: %0-%7 are the output registers r0-r7, %8 is the predicate as
// an int, %9 is the address. The registers are zeroed unconditionally and then
// overwritten only if the predicate register `pr` is set (the `@pr` guard).
asm volatile(
"{\n"
" .reg .pred pr;\n"
" setp.ne.u32 pr, %8, 0;\n"
" mov.u32 %0, 0;\n"
" mov.u32 %1, 0;\n"
" mov.u32 %2, 0;\n"
" mov.u32 %3, 0;\n"
" mov.u32 %4, 0;\n"
" mov.u32 %5, 0;\n"
" mov.u32 %6, 0;\n"
" mov.u32 %7, 0;\n"
" @pr ld.global.cg.v8.u32 {%0,%1,%2,%3,%4,%5,%6,%7}, [%9];\n"
"}\n"
: "=r"(r0), "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4), "=r"(r5), "=r"(r6),
"=r"(r7)
: "r"((int)pred), "l"(ptr));
// Store the eight words as two consecutive uint4 halves of the destination.
reinterpret_cast<uint4*>(&out)[0] = uint4{r0, r1, r2, r3};
reinterpret_cast<uint4*>(&out)[1] = uint4{r4, r5, r6, r7};
}
// Compute SF output offset for swizzled tensor core layout. // Compute SF output offset for swizzled tensor core layout.
// SF layout: [numMTiles, numKTiles, 32, 4, 4] // SF layout: [numMTiles, numKTiles, 32, 4, 4]
// Caller must precompute: numKTiles = (numCols + 63) / 64 // Caller must precompute: numKTiles = (numCols + 63) / 64
...@@ -315,8 +218,8 @@ __device__ __forceinline__ uint8_t* sf_out_rowmajor_u8(int row, int pack, ...@@ -315,8 +218,8 @@ __device__ __forceinline__ uint8_t* sf_out_rowmajor_u8(int row, int pack,
// Quantizes the provided PackedVec into the uint32_t output // Quantizes the provided PackedVec into the uint32_t output
template <class Type, int CVT_FP4_NUM_THREADS_PER_SF, bool UE8M0_SF = false> template <class Type, int CVT_FP4_NUM_THREADS_PER_SF, bool UE8M0_SF = false>
__device__ __forceinline__ fp4_packed_t __device__ __forceinline__ fp4_packed_t cvt_warp_fp16_to_fp4(
cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal, uint8_t* SFout) { PackedVec<Type, CVT_FP4_PACK16>& vec, float SFScaleVal, uint8_t* SFout) {
// Get absolute maximum values among the local 8 values. // Get absolute maximum values among the local 8 values.
auto localMax = __habs2(vec.elts[0]); auto localMax = __habs2(vec.elts[0]);
...@@ -372,11 +275,7 @@ cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal, uint8_t* SFout) { ...@@ -372,11 +275,7 @@ cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal, uint8_t* SFout) {
#pragma unroll #pragma unroll
for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
if constexpr (std::is_same_v<Type, half>) { fp2Vals[i] = cast_to_float2(vec.elts[i]);
fp2Vals[i] = __half22float2(vec.elts[i]);
} else {
fp2Vals[i] = __bfloat1622float2(vec.elts[i]);
}
fp2Vals[i].x *= outputScale; fp2Vals[i].x *= outputScale;
fp2Vals[i].y *= outputScale; fp2Vals[i].y *= outputScale;
} }
...@@ -395,22 +294,19 @@ __device__ __forceinline__ float2 silu2(float2 x) { ...@@ -395,22 +294,19 @@ __device__ __forceinline__ float2 silu2(float2 x) {
} }
template <class Type> template <class Type>
__inline__ __device__ PackedVec<Type> compute_silu_mul( __inline__ __device__ PackedVec<Type, CVT_FP4_PACK16> compute_silu_mul(
const PackedVec<Type>& x_vec, const PackedVec<Type>& y_vec) { const PackedVec<Type, CVT_FP4_PACK16>& x_vec,
PackedVec<Type> result; const PackedVec<Type, CVT_FP4_PACK16>& y_vec) {
PackedVec<Type, CVT_FP4_PACK16> result;
#pragma unroll #pragma unroll
for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; ++i) { for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; ++i) {
// silu_mul in float32 // silu_mul in float32
if constexpr (std::is_same_v<Type, half>) { using packed_t = typename PackedTypeConverter<Type>::Type;
float2 silu_vec = silu2(__half22float2(x_vec.elts[i])); float2 silu_vec = silu2(cast_to_float2(x_vec.elts[i]));
result.elts[i] = __float22half2_rn( float2 y_f2 = cast_to_float2(y_vec.elts[i]);
__fmul2_rn(silu_vec, __half22float2(y_vec.elts[i]))); result.elts[i] = cast_to_packed<packed_t>(
} else { make_float2(silu_vec.x * y_f2.x, silu_vec.y * y_f2.y));
float2 silu_vec = silu2(__bfloat1622float2(x_vec.elts[i]));
result.elts[i] = __float22bfloat162_rn(
__fmul2_rn(silu_vec, __bfloat1622float2(y_vec.elts[i])));
}
} }
return result; return result;
} }
......
...@@ -29,31 +29,33 @@ __device__ void rms_norm_dynamic_per_token_quant_vec( ...@@ -29,31 +29,33 @@ __device__ void rms_norm_dynamic_per_token_quant_vec(
scalar_t const* __restrict__ input, // [..., hidden_size] scalar_t const* __restrict__ input, // [..., hidden_size]
scalar_t const* __restrict__ weight, // [hidden_size] scalar_t const* __restrict__ weight, // [hidden_size]
float const* scale_ub, float const var_epsilon, int32_t const hidden_size, float const* scale_ub, float const var_epsilon, int32_t const hidden_size,
scalar_t* __restrict__ residual = nullptr) { int32_t const input_stride, scalar_t* __restrict__ residual = nullptr) {
float rms = 0.0f; float rms = 0.0f;
float token_scale = 0.0f; float token_scale = 0.0f;
// Compute rms // Compute rms
vllm::vectorized::compute_rms<scalar_t, has_residual>( vllm::vectorized::compute_rms<scalar_t, has_residual>(
&rms, input, hidden_size, var_epsilon, residual); &rms, input, hidden_size, input_stride, var_epsilon, residual);
// Compute scale // Compute scale
vllm::vectorized::compute_dynamic_per_token_scales<scalar_t, scalar_out_t, vllm::vectorized::compute_dynamic_per_token_scales<scalar_t, scalar_out_t,
has_residual>( has_residual>(
&token_scale, scales, input, weight, rms, scale_ub, hidden_size, &token_scale, scales, input, weight, rms, scale_ub, hidden_size,
residual); input_stride, residual);
// RMS Norm + Quant // RMS Norm + Quant
if constexpr (std::is_same_v<scalar_out_t, int8_t>) { if constexpr (std::is_same_v<scalar_out_t, int8_t>) {
token_scale = 1.0f / token_scale; token_scale = 1.0f / token_scale;
vllm::vectorized::norm_and_quant<scalar_t, scalar_out_t, true, vllm::vectorized::norm_and_quant<scalar_t, scalar_out_t, true,
has_residual>( has_residual>(out, input, weight, rms,
out, input, weight, rms, &token_scale, hidden_size, residual); &token_scale, hidden_size,
input_stride, residual);
} else { } else {
// FP8 - Do not invert token_scale for exact match with FBGemm // FP8 - Do not invert token_scale for exact match with FBGemm
vllm::vectorized::norm_and_quant<scalar_t, scalar_out_t, false, vllm::vectorized::norm_and_quant<scalar_t, scalar_out_t, false,
has_residual>( has_residual>(out, input, weight, rms,
out, input, weight, rms, &token_scale, hidden_size, residual); &token_scale, hidden_size,
input_stride, residual);
} }
} }
...@@ -65,38 +67,40 @@ __global__ void rms_norm_dynamic_per_token_quant_kernel( ...@@ -65,38 +67,40 @@ __global__ void rms_norm_dynamic_per_token_quant_kernel(
scalar_t const* __restrict__ input, // [..., hidden_size] scalar_t const* __restrict__ input, // [..., hidden_size]
scalar_t const* __restrict__ weight, // [hidden_size] scalar_t const* __restrict__ weight, // [hidden_size]
float const* scale_ub, float const var_epsilon, int32_t const hidden_size, float const* scale_ub, float const var_epsilon, int32_t const hidden_size,
scalar_t* __restrict__ residual = nullptr) { int32_t const input_stride, scalar_t* __restrict__ residual = nullptr) {
// For vectorization, token_input and token_output pointers need to be // For vectorization, token_input and token_output pointers need to be
// aligned at 8-byte and 4-byte addresses respectively. // aligned at 8-byte and 4-byte addresses respectively.
bool const can_vectorize = hidden_size % 4 == 0; bool const can_vectorize = hidden_size % 4 == 0 and input_stride % 4 == 0;
if (can_vectorize) { if (can_vectorize) {
return rms_norm_dynamic_per_token_quant_vec<scalar_t, scalar_out_t, return rms_norm_dynamic_per_token_quant_vec<scalar_t, scalar_out_t,
has_residual>( has_residual>(
out, scales, input, weight, scale_ub, var_epsilon, hidden_size, out, scales, input, weight, scale_ub, var_epsilon, hidden_size,
residual); input_stride, residual);
} }
float rms = 0.0f; float rms = 0.0f;
float token_scale = 0.0f; float token_scale = 0.0f;
// Compute RMS // Compute RMS
vllm::compute_rms<scalar_t, has_residual>(&rms, input, hidden_size, vllm::compute_rms<scalar_t, has_residual>(
var_epsilon, residual); &rms, input, hidden_size, input_stride, var_epsilon, residual);
// Compute Scale // Compute Scale
vllm::compute_dynamic_per_token_scales<scalar_t, scalar_out_t, has_residual>( vllm::compute_dynamic_per_token_scales<scalar_t, scalar_out_t, has_residual>(
&token_scale, scales, input, weight, rms, scale_ub, hidden_size, &token_scale, scales, input, weight, rms, scale_ub, hidden_size,
residual); input_stride, residual);
// RMS Norm + Quant // RMS Norm + Quant
if constexpr (std::is_same_v<scalar_out_t, int8_t>) { if constexpr (std::is_same_v<scalar_out_t, int8_t>) {
token_scale = 1.0f / token_scale; token_scale = 1.0f / token_scale;
vllm::norm_and_quant<scalar_t, scalar_out_t, true, has_residual>( vllm::norm_and_quant<scalar_t, scalar_out_t, true, has_residual>(
out, input, weight, rms, &token_scale, hidden_size, residual); out, input, weight, rms, &token_scale, hidden_size, input_stride,
residual);
} else { } else {
// FP8 - Do not invert s_token_scale for exact match with FBGemm // FP8 - Do not invert s_token_scale for exact match with FBGemm
vllm::norm_and_quant<scalar_t, scalar_out_t, false, has_residual>( vllm::norm_and_quant<scalar_t, scalar_out_t, false, has_residual>(
out, input, weight, rms, &token_scale, hidden_size, residual); out, input, weight, rms, &token_scale, hidden_size, input_stride,
residual);
} }
} }
...@@ -111,18 +115,20 @@ __global__ void rms_norm_per_block_quant_kernel( ...@@ -111,18 +115,20 @@ __global__ void rms_norm_per_block_quant_kernel(
scalar_t const* __restrict__ input, // [..., hidden_size] scalar_t const* __restrict__ input, // [..., hidden_size]
scalar_t const* __restrict__ weight, // [hidden_size] scalar_t const* __restrict__ weight, // [hidden_size]
float const* scale_ub, float const var_epsilon, int32_t const hidden_size, float const* scale_ub, float const var_epsilon, int32_t const hidden_size,
scalar_t* __restrict__ residual = nullptr) { int32_t const input_stride, scalar_t* __restrict__ residual = nullptr,
int64_t outer_scale_stride = 1) {
float rms; float rms;
// Compute RMS // Compute RMS
// Always able to vectorize due to constraints on hidden_size // Always able to vectorize due to constraints on hidden_size
vllm::vectorized::compute_rms<scalar_t, has_residual>( vllm::vectorized::compute_rms<scalar_t, has_residual>(
&rms, input, hidden_size, var_epsilon, residual); &rms, input, hidden_size, input_stride, var_epsilon, residual);
// Compute Scale // Compute Scale
// Always able to vectorize due to constraints on hidden_size and group_size // Always able to vectorize due to constraints on hidden_size and group_size
vllm::vectorized::compute_dynamic_per_token_scales< vllm::vectorized::compute_dynamic_per_token_scales<
scalar_t, scalar_out_t, has_residual, is_scale_transposed, group_size>( scalar_t, scalar_out_t, has_residual, is_scale_transposed, group_size>(
nullptr, scales, input, weight, rms, scale_ub, hidden_size, residual); nullptr, scales, input, weight, rms, scale_ub, hidden_size, input_stride,
residual, outer_scale_stride);
// RMS Norm + Quant // RMS Norm + Quant
// Always able to vectorize due to constraints on hidden_size // Always able to vectorize due to constraints on hidden_size
...@@ -133,7 +139,8 @@ __global__ void rms_norm_per_block_quant_kernel( ...@@ -133,7 +139,8 @@ __global__ void rms_norm_per_block_quant_kernel(
vllm::vectorized::norm_and_quant< vllm::vectorized::norm_and_quant<
scalar_t, scalar_out_t, std::is_same_v<scalar_out_t, int8_t>, scalar_t, scalar_out_t, std::is_same_v<scalar_out_t, int8_t>,
has_residual, is_scale_transposed, group_size>( has_residual, is_scale_transposed, group_size>(
out, input, weight, rms, scales, hidden_size, residual); out, input, weight, rms, scales, hidden_size, input_stride, residual,
outer_scale_stride);
} }
} // namespace vllm } // namespace vllm
...@@ -149,6 +156,7 @@ void rms_norm_dynamic_per_token_quant_dispatch( ...@@ -149,6 +156,7 @@ void rms_norm_dynamic_per_token_quant_dispatch(
std::optional<at::Tensor> const& scale_ub, std::optional<at::Tensor> const& scale_ub,
std::optional<at::Tensor>& residual) { std::optional<at::Tensor>& residual) {
int32_t hidden_size = input.size(-1); int32_t hidden_size = input.size(-1);
int32_t input_stride = input.view({-1, hidden_size}).stride(0);
auto num_tokens = input.numel() / hidden_size; auto num_tokens = input.numel() / hidden_size;
dim3 grid(num_tokens); dim3 grid(num_tokens);
...@@ -165,7 +173,7 @@ void rms_norm_dynamic_per_token_quant_dispatch( ...@@ -165,7 +173,7 @@ void rms_norm_dynamic_per_token_quant_dispatch(
out.data_ptr<scalar_t>(), scales.data_ptr<float>(), out.data_ptr<scalar_t>(), scales.data_ptr<float>(),
input.data_ptr<scalar_in_t>(), weight.data_ptr<scalar_in_t>(), input.data_ptr<scalar_in_t>(), weight.data_ptr<scalar_in_t>(),
scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr, scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
var_epsilon, hidden_size, var_epsilon, hidden_size, input_stride,
has_residual ? residual->data_ptr<scalar_in_t>() : nullptr); has_residual ? residual->data_ptr<scalar_in_t>() : nullptr);
}); });
}); });
...@@ -182,7 +190,9 @@ void rms_norm_dynamic_per_token_quant( ...@@ -182,7 +190,9 @@ void rms_norm_dynamic_per_token_quant(
? c10::ScalarType::Float8_e4m3fn ? c10::ScalarType::Float8_e4m3fn
: c10::ScalarType::Float8_e4m3fnuz; : c10::ScalarType::Float8_e4m3fnuz;
TORCH_CHECK(out.dtype() == kFp8Type || out.dtype() == torch::kInt8); TORCH_CHECK(out.dtype() == kFp8Type || out.dtype() == torch::kInt8);
TORCH_CHECK(out.is_contiguous() && input.is_contiguous()); TORCH_CHECK(out.is_contiguous());
TORCH_CHECK(input.stride(-1) == 1,
"Input must be contiguous in the last dimension");
if (scale_ub.has_value()) { if (scale_ub.has_value()) {
TORCH_CHECK(out.dtype() == kFp8Type); TORCH_CHECK(out.dtype() == kFp8Type);
...@@ -191,6 +201,7 @@ void rms_norm_dynamic_per_token_quant( ...@@ -191,6 +201,7 @@ void rms_norm_dynamic_per_token_quant(
TORCH_CHECK(scales.dtype() == torch::kFloat32); TORCH_CHECK(scales.dtype() == torch::kFloat32);
if (residual) { if (residual) {
TORCH_CHECK(residual->scalar_type() == input.scalar_type()); TORCH_CHECK(residual->scalar_type() == input.scalar_type());
TORCH_CHECK(residual->is_contiguous());
} }
VLLM_DISPATCH_FLOATING_TYPES( VLLM_DISPATCH_FLOATING_TYPES(
...@@ -212,6 +223,15 @@ void rms_norm_per_block_quant_dispatch( ...@@ -212,6 +223,15 @@ void rms_norm_per_block_quant_dispatch(
std::optional<at::Tensor> const& scale_ub, std::optional<at::Tensor> const& scale_ub,
std::optional<at::Tensor>& residual, bool is_scale_transposed) { std::optional<at::Tensor>& residual, bool is_scale_transposed) {
int32_t hidden_size = input.size(-1); int32_t hidden_size = input.size(-1);
int32_t input_stride = input.view({-1, hidden_size}).stride(0);
TORCH_CHECK(hidden_size % 4 == 0,
"Hidden size must be divisible by 4 for vectorized access");
TORCH_CHECK(input_stride % 4 == 0,
"Input stride must be divisible by 4 for vectorized access");
TORCH_CHECK(group_size % 4 == 0,
"Group size must be divisible by 4 for vectorized access");
auto num_tokens = input.numel() / hidden_size; auto num_tokens = input.numel() / hidden_size;
dim3 grid(num_tokens); dim3 grid(num_tokens);
...@@ -237,9 +257,10 @@ void rms_norm_per_block_quant_dispatch( ...@@ -237,9 +257,10 @@ void rms_norm_per_block_quant_dispatch(
weight.data_ptr<scalar_in_t>(), weight.data_ptr<scalar_in_t>(),
scale_ub.has_value() ? scale_ub->data_ptr<float>() scale_ub.has_value() ? scale_ub->data_ptr<float>()
: nullptr, : nullptr,
var_epsilon, hidden_size, var_epsilon, hidden_size, input_stride,
has_residual ? residual->data_ptr<scalar_in_t>() has_residual ? residual->data_ptr<scalar_in_t>()
: nullptr); : nullptr,
scales.stride(1));
}); });
}); });
}); });
...@@ -257,7 +278,9 @@ void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input, ...@@ -257,7 +278,9 @@ void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input,
? c10::ScalarType::Float8_e4m3fn ? c10::ScalarType::Float8_e4m3fn
: c10::ScalarType::Float8_e4m3fnuz; : c10::ScalarType::Float8_e4m3fnuz;
TORCH_CHECK(out.dtype() == kFp8Type || out.dtype() == torch::kInt8); TORCH_CHECK(out.dtype() == kFp8Type || out.dtype() == torch::kInt8);
TORCH_CHECK(out.is_contiguous() && input.is_contiguous()); TORCH_CHECK(out.is_contiguous());
TORCH_CHECK(input.stride(-1) == 1,
"Input must be contiguous in the last dimension");
if (scale_ub.has_value()) { if (scale_ub.has_value()) {
TORCH_CHECK(out.dtype() == kFp8Type); TORCH_CHECK(out.dtype() == kFp8Type);
...@@ -266,11 +289,17 @@ void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input, ...@@ -266,11 +289,17 @@ void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input,
TORCH_CHECK(scales.dtype() == torch::kFloat32); TORCH_CHECK(scales.dtype() == torch::kFloat32);
if (residual) { if (residual) {
TORCH_CHECK(residual->scalar_type() == input.scalar_type()); TORCH_CHECK(residual->scalar_type() == input.scalar_type());
TORCH_CHECK(residual->is_contiguous());
} }
TORCH_CHECK(group_size == 128 || group_size == 64, TORCH_CHECK(group_size == 128 || group_size == 64,
"Unsupported group size: ", group_size); "Unsupported group size: ", group_size);
if (scales.stride(1) > 1) {
TORCH_CHECK(is_scale_transposed,
"Outer scale stride must be 1 when scales are not transposed");
}
rms_norm_per_block_quant_dispatch(out, input, weight, scales, group_size, rms_norm_per_block_quant_dispatch(out, input, weight, scales, group_size,
var_epsilon, scale_ub, residual, var_epsilon, scale_ub, residual,
is_scale_transposed); is_scale_transposed);
......
...@@ -16,14 +16,17 @@ namespace vllm { ...@@ -16,14 +16,17 @@ namespace vllm {
// has_residual must be true, if residual is not a nullptr // has_residual must be true, if residual is not a nullptr
template <typename scalar_t, bool has_residual = false> template <typename scalar_t, bool has_residual = false>
__device__ void compute_rms(float* rms, scalar_t const* __restrict__ input, __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input,
int32_t const hidden_size, float const epsilon, int32_t const hidden_size,
int32_t const input_stride, float const epsilon,
scalar_t const* __restrict__ residual = nullptr) { scalar_t const* __restrict__ residual = nullptr) {
int64_t const input_token_offset =
blockIdx.x * static_cast<int64_t>(input_stride);
int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size); int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
// sum of squares // sum of squares
float ss = 0.0f; float ss = 0.0f;
for (auto i = threadIdx.x; i < hidden_size; i += blockDim.x) { for (auto i = threadIdx.x; i < hidden_size; i += blockDim.x) {
float x = static_cast<float>(input[token_offset + i]); float x = static_cast<float>(input[input_token_offset + i]);
if constexpr (has_residual) { if constexpr (has_residual) {
x += static_cast<float>(residual[token_offset + i]); x += static_cast<float>(residual[token_offset + i]);
} }
...@@ -73,15 +76,20 @@ __device__ void compute_dynamic_per_token_scales( ...@@ -73,15 +76,20 @@ __device__ void compute_dynamic_per_token_scales(
float* __restrict__ token_scale, float* __restrict__ all_token_scales, float* __restrict__ token_scale, float* __restrict__ all_token_scales,
scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight, scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight,
float const rms, float const* __restrict__ scale_ub, float const rms, float const* __restrict__ scale_ub,
int32_t const hidden_size, scalar_t const* __restrict__ residual = nullptr, int32_t const hidden_size, int32_t const input_stride,
int32_t const group_size = 0) { scalar_t const* __restrict__ residual = nullptr,
int32_t const group_size = 0, int64_t outer_scale_stride = 1) {
float block_absmax_val_maybe = 0.0f; float block_absmax_val_maybe = 0.0f;
constexpr scalar_out_t qmax{quant_type_max_v<scalar_out_t>}; constexpr scalar_out_t qmax{quant_type_max_v<scalar_out_t>};
__syncthreads(); __syncthreads();
int64_t const input_token_offset =
blockIdx.x * static_cast<int64_t>(input_stride);
int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
if (group_size > 0) { if (group_size > 0) {
__shared__ float s_max_vals[1024];
int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
int64_t num_groups = hidden_size / group_size; int64_t num_groups = hidden_size / group_size;
__shared__ float s_max_vals[1024];
int64_t const threads_per_group = blockDim.x / num_groups; int64_t const threads_per_group = blockDim.x / num_groups;
int64_t const thread_in_group = threadIdx.x % threads_per_group; int64_t const thread_in_group = threadIdx.x % threads_per_group;
int64_t const group_offset = threadIdx.x / threads_per_group * group_size; int64_t const group_offset = threadIdx.x / threads_per_group * group_size;
...@@ -89,7 +97,7 @@ __device__ void compute_dynamic_per_token_scales( ...@@ -89,7 +97,7 @@ __device__ void compute_dynamic_per_token_scales(
int64_t const thread_end = int64_t const thread_end =
min(group_offset + group_size, static_cast<int64_t>(hidden_size)); min(group_offset + group_size, static_cast<int64_t>(hidden_size));
for (auto i = thread_offset; i < thread_end; i += threads_per_group) { for (auto i = thread_offset; i < thread_end; i += threads_per_group) {
float x = static_cast<float>(input[token_offset + i]); float x = static_cast<float>(input[input_token_offset + i]);
if constexpr (has_residual) { if constexpr (has_residual) {
x += static_cast<float>(residual[token_offset + i]); x += static_cast<float>(residual[token_offset + i]);
} }
...@@ -133,7 +141,9 @@ __device__ void compute_dynamic_per_token_scales( ...@@ -133,7 +141,9 @@ __device__ void compute_dynamic_per_token_scales(
scale = max(scale / qmax, min_scaling_factor<scalar_out_t>::val()); scale = max(scale / qmax, min_scaling_factor<scalar_out_t>::val());
// Global output store // Global output store
if constexpr (is_scale_transposed) { if constexpr (is_scale_transposed) {
all_token_scales[(threadIdx.x / threads_per_group) * gridDim.x + int64_t const scale_rows = (gridDim.x + outer_scale_stride - 1) /
outer_scale_stride * outer_scale_stride;
all_token_scales[(threadIdx.x / threads_per_group) * scale_rows +
blockIdx.x] = scale; blockIdx.x] = scale;
} else { } else {
all_token_scales[blockIdx.x * num_groups + all_token_scales[blockIdx.x * num_groups +
...@@ -142,10 +152,8 @@ __device__ void compute_dynamic_per_token_scales( ...@@ -142,10 +152,8 @@ __device__ void compute_dynamic_per_token_scales(
} }
__syncthreads(); __syncthreads();
} else { } else {
int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
for (auto i = threadIdx.x; i < hidden_size; i += blockDim.x) { for (auto i = threadIdx.x; i < hidden_size; i += blockDim.x) {
float x = static_cast<float>(input[token_offset + i]); float x = static_cast<float>(input[input_token_offset + i]);
if constexpr (has_residual) { if constexpr (has_residual) {
x += static_cast<float>(residual[token_offset + i]); x += static_cast<float>(residual[token_offset + i]);
} }
...@@ -180,17 +188,18 @@ __device__ void compute_dynamic_per_token_scales( ...@@ -180,17 +188,18 @@ __device__ void compute_dynamic_per_token_scales(
template <typename scalar_t, typename scalar_out_t, bool is_scale_inverted, template <typename scalar_t, typename scalar_out_t, bool is_scale_inverted,
bool has_residual = false, bool is_scale_transposed = false> bool has_residual = false, bool is_scale_transposed = false>
__device__ void norm_and_quant(scalar_out_t* __restrict__ output, __device__ void norm_and_quant(
scalar_t const* __restrict__ input, scalar_out_t* __restrict__ output, scalar_t const* __restrict__ input,
scalar_t const* __restrict__ weight, scalar_t const* __restrict__ weight, float const rms, float* const scale,
float const rms, float* const scale, int32_t const hidden_size, int32_t const input_stride,
int32_t const hidden_size, scalar_t* __restrict__ residual = nullptr, int32_t const group_size = 0,
scalar_t* __restrict__ residual = nullptr, int64_t outer_scale_stride = 1) {
int32_t const group_size = 0) { int64_t const input_token_offset =
blockIdx.x * static_cast<int64_t>(input_stride);
int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size); int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
for (auto i = threadIdx.x; i < hidden_size; i += blockDim.x) { for (auto i = threadIdx.x; i < hidden_size; i += blockDim.x) {
float x = static_cast<float>(input[token_offset + i]); float x = static_cast<float>(input[input_token_offset + i]);
if constexpr (has_residual) { if constexpr (has_residual) {
x += static_cast<float>(residual[token_offset + i]); x += static_cast<float>(residual[token_offset + i]);
residual[token_offset + i] = static_cast<scalar_t>(x); residual[token_offset + i] = static_cast<scalar_t>(x);
...@@ -202,7 +211,9 @@ __device__ void norm_and_quant(scalar_out_t* __restrict__ output, ...@@ -202,7 +211,9 @@ __device__ void norm_and_quant(scalar_out_t* __restrict__ output,
int64_t scale_idx = 0; int64_t scale_idx = 0;
if (group_size > 0) { if (group_size > 0) {
if constexpr (is_scale_transposed) { if constexpr (is_scale_transposed) {
scale_idx = (i / group_size) * gridDim.x + blockIdx.x; int64_t const scale_rows = (gridDim.x + outer_scale_stride - 1) /
outer_scale_stride * outer_scale_stride;
scale_idx = (i / group_size) * scale_rows + blockIdx.x;
} else { } else {
scale_idx = blockIdx.x * (hidden_size / group_size) + i / group_size; scale_idx = blockIdx.x * (hidden_size / group_size) + i / group_size;
} }
...@@ -222,13 +233,16 @@ namespace vectorized { ...@@ -222,13 +233,16 @@ namespace vectorized {
// hidden_size must be a multiple of 4 // hidden_size must be a multiple of 4
template <typename scalar_t, bool has_residual = false> template <typename scalar_t, bool has_residual = false>
__device__ void compute_rms(float* rms, scalar_t const* __restrict__ input, __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input,
int32_t const hidden_size, float const epsilon, int32_t const hidden_size,
int32_t const input_stride, float const epsilon,
scalar_t const* __restrict__ residual = nullptr) { scalar_t const* __restrict__ residual = nullptr) {
int64_t const input_token_offset =
blockIdx.x * static_cast<int64_t>(input_stride);
int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size); int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
// Vectorized input/output to better utilize memory bandwidth. // Vectorized input/output to better utilize memory bandwidth.
vec4_t<scalar_t> const* vec_input = vec4_t<scalar_t> const* vec_input =
reinterpret_cast<vec4_t<scalar_t> const*>(&input[token_offset]); reinterpret_cast<vec4_t<scalar_t> const*>(&input[input_token_offset]);
vec4_t<scalar_t> const* vec_residual = nullptr; vec4_t<scalar_t> const* vec_residual = nullptr;
if constexpr (has_residual) { if constexpr (has_residual) {
vec_residual = vec_residual =
...@@ -286,8 +300,9 @@ __device__ void compute_dynamic_per_token_scales( ...@@ -286,8 +300,9 @@ __device__ void compute_dynamic_per_token_scales(
float* __restrict__ token_scale, float* __restrict__ all_token_scales, float* __restrict__ token_scale, float* __restrict__ all_token_scales,
scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight, scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight,
float const rms, float const* __restrict__ scale_ub, float const rms, float const* __restrict__ scale_ub,
int32_t const hidden_size, int32_t const hidden_size, int32_t const input_stride,
scalar_t const* __restrict__ residual = nullptr) { scalar_t const* __restrict__ residual = nullptr,
int64_t outer_scale_stride = 1) {
constexpr scalar_out_t qmax{quant_type_max_v<scalar_out_t>}; constexpr scalar_out_t qmax{quant_type_max_v<scalar_out_t>};
const int VEC_SIZE = 4; const int VEC_SIZE = 4;
...@@ -298,10 +313,13 @@ __device__ void compute_dynamic_per_token_scales( ...@@ -298,10 +313,13 @@ __device__ void compute_dynamic_per_token_scales(
vec4_t<scalar_t> const* vec_weight = nullptr; vec4_t<scalar_t> const* vec_weight = nullptr;
vec4_t<scalar_t> const* vec_residual = nullptr; vec4_t<scalar_t> const* vec_residual = nullptr;
int64_t const input_token_offset =
blockIdx.x * static_cast<int64_t>(input_stride);
int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
if constexpr (group_size > 0) { if constexpr (group_size > 0) {
__shared__ float s_max_vals[1024]; __shared__ float s_max_vals[1024];
int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
int64_t const num_groups = hidden_size / group_size; int64_t const num_groups = hidden_size / group_size;
int64_t const threads_per_group = blockDim.x / num_groups; int64_t const threads_per_group = blockDim.x / num_groups;
int64_t const thread_in_group = threadIdx.x % threads_per_group; int64_t const thread_in_group = threadIdx.x % threads_per_group;
...@@ -310,7 +328,8 @@ __device__ void compute_dynamic_per_token_scales( ...@@ -310,7 +328,8 @@ __device__ void compute_dynamic_per_token_scales(
int64_t const thread_offset = group_offset + thread_in_group; int64_t const thread_offset = group_offset + thread_in_group;
int64_t const thread_end = min(group_offset + (group_size >> 2), int64_t const thread_end = min(group_offset + (group_size >> 2),
static_cast<int64_t>(hidden_size >> 2)); static_cast<int64_t>(hidden_size >> 2));
vec_input = reinterpret_cast<vec4_t<scalar_t> const*>(&input[token_offset]); vec_input =
reinterpret_cast<vec4_t<scalar_t> const*>(&input[input_token_offset]);
vec_weight = reinterpret_cast<vec4_t<scalar_t> const*>(weight); vec_weight = reinterpret_cast<vec4_t<scalar_t> const*>(weight);
if constexpr (has_residual) { if constexpr (has_residual) {
vec_residual = vec_residual =
...@@ -382,7 +401,9 @@ __device__ void compute_dynamic_per_token_scales( ...@@ -382,7 +401,9 @@ __device__ void compute_dynamic_per_token_scales(
scale = max(scale / qmax, min_scaling_factor<scalar_out_t>::val()); scale = max(scale / qmax, min_scaling_factor<scalar_out_t>::val());
// Global output store // Global output store
if constexpr (is_scale_transposed) { if constexpr (is_scale_transposed) {
all_token_scales[(threadIdx.x / threads_per_group) * gridDim.x + int64_t const scale_rows = (gridDim.x + outer_scale_stride - 1) /
outer_scale_stride * outer_scale_stride;
all_token_scales[(threadIdx.x / threads_per_group) * scale_rows +
blockIdx.x] = scale; blockIdx.x] = scale;
} else { } else {
all_token_scales[blockIdx.x * num_groups + all_token_scales[blockIdx.x * num_groups +
...@@ -392,8 +413,8 @@ __device__ void compute_dynamic_per_token_scales( ...@@ -392,8 +413,8 @@ __device__ void compute_dynamic_per_token_scales(
__syncthreads(); __syncthreads();
} else { } else {
int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size); vec_input =
vec_input = reinterpret_cast<vec4_t<scalar_t> const*>(&input[token_offset]); reinterpret_cast<vec4_t<scalar_t> const*>(&input[input_token_offset]);
vec_weight = reinterpret_cast<vec4_t<scalar_t> const*>(weight); vec_weight = reinterpret_cast<vec4_t<scalar_t> const*>(weight);
if constexpr (has_residual) { if constexpr (has_residual) {
vec_residual = vec_residual =
...@@ -458,17 +479,18 @@ __device__ void compute_dynamic_per_token_scales( ...@@ -458,17 +479,18 @@ __device__ void compute_dynamic_per_token_scales(
template <typename scalar_t, typename scalar_out_t, bool is_scale_inverted, template <typename scalar_t, typename scalar_out_t, bool is_scale_inverted,
bool has_residual = false, bool is_scale_transposed = false, bool has_residual = false, bool is_scale_transposed = false,
int32_t group_size = 0> int32_t group_size = 0>
__device__ void norm_and_quant(scalar_out_t* __restrict__ output, __device__ void norm_and_quant(
scalar_t const* __restrict__ input, scalar_out_t* __restrict__ output, scalar_t const* __restrict__ input,
scalar_t const* __restrict__ weight, scalar_t const* __restrict__ weight, float const rms, float* const scale,
float const rms, float* const scale, int32_t const hidden_size, int32_t const input_stride,
int32_t const hidden_size, scalar_t* __restrict__ residual = nullptr, int64_t outer_scale_stride = 1) {
scalar_t* __restrict__ residual = nullptr) { int64_t const input_token_offset =
blockIdx.x * static_cast<int64_t>(input_stride);
int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size); int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
// Vectorized input/output/weight/residual to better utilize memory bandwidth. // Vectorized input/output/weight/residual to better utilize memory bandwidth.
vec4_t<scalar_t> const* vec_input = vec4_t<scalar_t> const* vec_input =
reinterpret_cast<vec4_t<scalar_t> const*>(&input[token_offset]); reinterpret_cast<vec4_t<scalar_t> const*>(&input[input_token_offset]);
vec4_t<scalar_t> const* vec_weight = vec4_t<scalar_t> const* vec_weight =
reinterpret_cast<vec4_t<scalar_t> const*>(weight); reinterpret_cast<vec4_t<scalar_t> const*>(weight);
q8x4_t<scalar_out_t>* vec_output = q8x4_t<scalar_out_t>* vec_output =
...@@ -516,7 +538,9 @@ __device__ void norm_and_quant(scalar_out_t* __restrict__ output, ...@@ -516,7 +538,9 @@ __device__ void norm_and_quant(scalar_out_t* __restrict__ output,
int64_t const num_groups = hidden_size / group_size; int64_t const num_groups = hidden_size / group_size;
int64_t scale_idx = 0; int64_t scale_idx = 0;
if constexpr (is_scale_transposed) { if constexpr (is_scale_transposed) {
scale_idx = (i * VEC_SIZE / group_size) * gridDim.x + blockIdx.x; int64_t const scale_rows = (gridDim.x + outer_scale_stride - 1) /
outer_scale_stride * outer_scale_stride;
scale_idx = (i * VEC_SIZE / group_size) * scale_rows + blockIdx.x;
} else { } else {
scale_idx = blockIdx.x * num_groups + i * VEC_SIZE / group_size; scale_idx = blockIdx.x * num_groups + i * VEC_SIZE / group_size;
} }
......
...@@ -12,6 +12,68 @@ namespace vllm { ...@@ -12,6 +12,68 @@ namespace vllm {
using c3x::cutlass_gemm_caller; using c3x::cutlass_gemm_caller;
// SM120 CUTLASS 3.x GEMM type bundle whose epilogue tile is supplied by the
// caller (EpilogueTile) instead of being chosen inside the bundle. The tile is
// forwarded to the epilogue CollectiveBuilder, which lets small-M tile shapes
// pick a matching epilogue tile.
//
// Operand conventions fixed here: A is row-major, B is column-major, D is
// row-major, and there is no source operand C (ElementC is void).
template <typename ElementAB_, typename ElementD_,
          template <typename, typename, typename> typename Epilogue_,
          typename TileShape, typename ClusterShape, typename KernelSchedule,
          typename EpilogueSchedule, typename EpilogueTile>
struct cutlass_3x_gemm_sm120_custom {
  // ---- Element types ------------------------------------------------------
  using ElementAB = ElementAB_;
  using ElementC = void;
  using ElementD = ElementD_;

  // ---- Layouts ------------------------------------------------------------
  using LayoutA = cutlass::layout::RowMajor;
  using LayoutB = cutlass::layout::ColumnMajor;
  using LayoutC = cutlass::layout::RowMajor;
  using LayoutD = cutlass::layout::RowMajor;

  // ---- Alignments, in elements per 128 bits -------------------------------
  static constexpr int AlignmentA =
      128 / cutlass::sizeof_bits<ElementAB>::value;
  static constexpr int AlignmentB =
      128 / cutlass::sizeof_bits<ElementAB>::value;
  // C is void, so its alignment is derived from the output element type.
  static constexpr int AlignmentC =
      128 / cutlass::sizeof_bits<ElementD>::value;
  static constexpr int AlignmentD = AlignmentC;

  // Accumulator type handed to the epilogue template: int32 for int8 inputs,
  // float for everything else.
  using ElementAcc =
      std::conditional_t<std::is_same_v<ElementAB, int8_t>, int32_t, float>;

  using Epilogue = Epilogue_<ElementAcc, ElementD, TileShape>;

  // ---- MMA type -----------------------------------------------------------
  using ElementAccumulator = float;

  // ---- Epilogue types -----------------------------------------------------
  using ElementBias = cutlass::half_t;
  using ElementCompute = float;
  using ElementAux = ElementD;
  using LayoutAux = LayoutD;
  using ElementAmax = float;

  using EVTCompute = typename Epilogue::EVTCompute;

  // Epilogue collective, built with the caller-provided EpilogueTile.
  using CollectiveEpilogue =
      typename cutlass::epilogue::collective::CollectiveBuilder<
          cutlass::arch::Sm120, cutlass::arch::OpClassTensorOp, TileShape,
          ClusterShape, EpilogueTile, ElementAccumulator, ElementCompute,
          ElementC, LayoutC, AlignmentC, ElementD, LayoutD, AlignmentD,
          EpilogueSchedule, EVTCompute>::CollectiveOp;

  // Mainloop collective. The size of the epilogue's shared storage is passed
  // to StageCountAutoCarveout when selecting the stage count.
  using CollectiveMainloop =
      typename cutlass::gemm::collective::CollectiveBuilder<
          cutlass::arch::Sm120, cutlass::arch::OpClassTensorOp, ElementAB,
          LayoutA, AlignmentA, ElementAB, LayoutB, AlignmentB,
          ElementAccumulator, TileShape, ClusterShape,
          cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
              sizeof(typename CollectiveEpilogue::SharedStorage))>,
          KernelSchedule, void>::CollectiveOp;

  // Final kernel type, wrapped in enable_sm120_only.
  using GemmKernel = enable_sm120_only<cutlass::gemm::kernel::GemmUniversal<
      Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>>;
};
template <typename InType, typename OutType, template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue> template <typename, typename, typename> typename Epilogue>
struct sm120_fp8_config_default { struct sm120_fp8_config_default {
...@@ -25,6 +87,54 @@ struct sm120_fp8_config_default { ...@@ -25,6 +87,54 @@ struct sm120_fp8_config_default {
KernelSchedule, EpilogueSchedule>; KernelSchedule, EpilogueSchedule>;
}; };
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm120_fp8_config_M64 {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
// SM120 Cooperative kernel requires Tile M >= 128.
// For M=64 tile, we use Pingpong schedule which is more flexible with small
// tiles.
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong;
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
using TileShape = Shape<_64, _64, _128>;
// CUTLASS 3.x on SM120 currently restricts programmatic multicast (Cluster >
// 1) for certain schedules/types. Reverting to 1x1x1 to ensure compilation.
using ClusterShape = Shape<_1, _1, _1>;
using Cutlass3xGemm =
cutlass_3x_gemm_sm120<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>;
};
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm120_fp8_config_M32 {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong;
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
using TileShape = Shape<_32, _64, _128>;
using ClusterShape = Shape<_1, _1, _1>;
// Use custom gemm to specify EpilogueTile M=32
using Cutlass3xGemm =
cutlass_3x_gemm_sm120_custom<InType, OutType, Epilogue, TileShape,
ClusterShape, KernelSchedule,
EpilogueSchedule, Shape<_32, _32>>;
};
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm120_fp8_config_M16 {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong;
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
using TileShape = Shape<_16, _64, _128>;
using ClusterShape = Shape<_1, _1, _1>;
// Use custom gemm to specify EpilogueTile M=16
using Cutlass3xGemm =
cutlass_3x_gemm_sm120_custom<InType, OutType, Epilogue, TileShape,
ClusterShape, KernelSchedule,
EpilogueSchedule, Shape<_16, _32>>;
};
template <typename InType, typename OutType, template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue, template <typename, typename, typename> typename Epilogue,
typename... EpilogueArgs> typename... EpilogueArgs>
...@@ -36,6 +146,28 @@ inline void cutlass_gemm_sm120_fp8_dispatch(torch::Tensor& out, ...@@ -36,6 +146,28 @@ inline void cutlass_gemm_sm120_fp8_dispatch(torch::Tensor& out,
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
int M = a.size(0);
if (M <= 16) {
using Cutlass3xGemmM16 =
typename sm120_fp8_config_M16<InType, OutType, Epilogue>::Cutlass3xGemm;
return cutlass_gemm_caller<Cutlass3xGemmM16>(
out, a, b, std::forward<EpilogueArgs>(args)...);
}
if (M <= 32) {
using Cutlass3xGemmM32 =
typename sm120_fp8_config_M32<InType, OutType, Epilogue>::Cutlass3xGemm;
return cutlass_gemm_caller<Cutlass3xGemmM32>(
out, a, b, std::forward<EpilogueArgs>(args)...);
}
if (M <= 256) {
using Cutlass3xGemmM64 =
typename sm120_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
return cutlass_gemm_caller<Cutlass3xGemmM64>(
out, a, b, std::forward<EpilogueArgs>(args)...);
}
using Cutlass3xGemmDefault = using Cutlass3xGemmDefault =
typename sm120_fp8_config_default<InType, OutType, typename sm120_fp8_config_default<InType, OutType,
Epilogue>::Cutlass3xGemm; Epilogue>::Cutlass3xGemm;
...@@ -64,4 +196,4 @@ void cutlass_scaled_mm_sm120_fp8_epilogue(torch::Tensor& out, ...@@ -64,4 +196,4 @@ void cutlass_scaled_mm_sm120_fp8_epilogue(torch::Tensor& out,
} }
} }
} // namespace vllm } // namespace vllm
\ No newline at end of file
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment