Unverified Commit 8435b2e0 authored by Robert Shaw's avatar Robert Shaw Committed by GitHub
Browse files

[ModelBash][DSV3] Add TRTLLM DSV3 Router GEMM kernel (6% B1 Speedup) (#34302)


Signed-off-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
parent b1b5e045
...@@ -1101,6 +1101,27 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") ...@@ -1101,6 +1101,27 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
message(STATUS "Not building Marlin MOE kernels as no compatible archs found" message(STATUS "Not building Marlin MOE kernels as no compatible archs found"
" in CUDA target architectures") " in CUDA target architectures")
endif() endif()
# DeepSeek V3 router GEMM kernel - requires SM90+
#
# Step 1: choose the candidate architecture list from the CUDA toolkit
# version. Toolkits >= 13.0 use the "f"-suffixed entries for 10.0 and 11.0;
# older toolkits enumerate the individual "a" targets instead.
# NOTE(review): cuda_archs_loose_intersection is defined elsewhere; it
# presumably stores in DSV3_ROUTER_GEMM_ARCHS the candidates that are also
# present in CUDA_ARCHS - confirm against the helper.
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(DSV3_ROUTER_GEMM_ARCHS "9.0a;10.0f;11.0f" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(DSV3_ROUTER_GEMM_ARCHS "9.0a;10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
endif()
# Step 2: the three router GEMM sources are added to the MoE extension only
# when the toolkit is >= 12.0 and at least one candidate architecture
# survived step 1; otherwise the kernel is left out and a status line is logged.
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND DSV3_ROUTER_GEMM_ARCHS)
set(DSV3_ROUTER_GEMM_SRC
"csrc/moe/dsv3_router_gemm_entry.cu"
"csrc/moe/dsv3_router_gemm_float_out.cu"
"csrc/moe/dsv3_router_gemm_bf16_out.cu")
# Attach the selected architectures to these sources only.
set_gencode_flags_for_srcs(
SRCS "${DSV3_ROUTER_GEMM_SRC}"
CUDA_ARCHS "${DSV3_ROUTER_GEMM_ARCHS}")
list(APPEND VLLM_MOE_EXT_SRC "${DSV3_ROUTER_GEMM_SRC}")
message(STATUS "Building DSV3 router GEMM kernel for archs: ${DSV3_ROUTER_GEMM_ARCHS}")
else()
message(STATUS "Not building DSV3 router GEMM kernel as no compatible archs found"
" (requires SM90+ and CUDA >= 12.0)")
endif()
endif() endif()
message(STATUS "Enabling moe extension.") message(STATUS "Enabling moe extension.")
......
/*
* Adapted from SGLang's sgl-kernel implementation, which was adapted from
* https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/dsv3MinLatencyKernels/dsv3RouterGemm.cu
* https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/thop/dsv3RouterGemmOp.cpp
*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include "dsv3_router_gemm_utils.h"
// Packed two-lane fused multiply-add issued through inline PTX.
//
// Emits a single `fma.rn.f32x2` instruction (round-to-nearest) whose operands
// are the float2 arguments: d receives the result computed from a, b and c.
// Each float2 is handed to the assembler as one 64-bit register (the "l"
// constraint) by reinterpreting it as uint64_t.
//
// NOTE(review): nothing visible in this translation unit calls this helper,
// and `fma.rn.f32x2` presumably requires a newer architecture than SM90 -
// confirm against the PTX ISA before using it on Hopper.
__device__ __forceinline__ void fma(float2& d, float2 const& a, float2 const& b,
float2 const& c) {
asm volatile("fma.rn.f32x2 %0, %1, %2, %3;\n"
: "=l"(reinterpret_cast<uint64_t&>(d))
: "l"(reinterpret_cast<uint64_t const&>(a)),
"l"(reinterpret_cast<uint64_t const&>(b)),
"l"(reinterpret_cast<uint64_t const&>(c)));
}
// Unpacks the bf16 values packed in a 16-byte uint4 vector into fp32.
//
// Template parameter:
//   VPT - number of bf16 elements to convert; at most 8 fit in one uint4.
// Parameters:
//   vec - packed source vector (read-only).
//   dst - destination array with room for at least VPT floats.
template <int VPT>
__device__ __forceinline__ void bf16_uint4_to_float8(uint4 const& vec,
                                                     float* dst) {
  // Reading more than sizeof(uint4) bytes would run past the end of `vec`.
  static_assert(VPT > 0 && VPT * sizeof(__nv_bfloat16) <= sizeof(uint4),
                "VPT bf16 values must fit inside one uint4");
  // The source is never modified, so view it through a pointer-to-const
  // instead of casting its constness away.
  __nv_bfloat16 const* bf16_ptr = reinterpret_cast<__nv_bfloat16 const*>(&vec);
#pragma unroll
  for (int i = 0; i < VPT; i++) {
    dst[i] = __bfloat162float(bf16_ptr[i]);
  }
}
// Small-batch router GEMM with bf16 output.
//
// Computes out[m][n] = sum_k mat_a[m][k] * mat_b[n][k] for every token
// m in [0, kNumTokens) and the single expert n = blockIdx.x. Products are
// accumulated in fp32 and rounded to bf16 only on the final store.
//
// Launch layout: 1-D grid with one block per expert, kBlockSize threads per
// block. Uses kNumTokens * kNumWarps floats of static shared memory and no
// dynamic shared memory.
//
// Preconditions (not checked on the device):
//   * mat_a is read as dense rows of kHiddenDim elements, one row per token.
//   * mat_b is read as dense rows of kHiddenDim elements, one row per expert.
//   * out is written as dense rows of kNumExperts elements, one row per token.
//   * All loads are 16-byte uint4 loads, so both input base pointers must be
//     16-byte aligned.
//
// On SM90+ the body is bracketed by griddepcontrol.wait /
// griddepcontrol.launch_dependents so the kernel can take part in
// programmatic dependent launch when the host enables it.
template <typename T, int kBlockSize, int VPT, int kNumTokens, int kNumExperts,
          int kHiddenDim>
__global__ __launch_bounds__(128, 1) void router_gemm_kernel_bf16_output(
    __nv_bfloat16* out, T const* mat_a, T const* mat_b) {
  constexpr int kWarpSize = 32;
  constexpr int kNumWarps = kBlockSize / kWarpSize;
  // Elements of the hidden dimension consumed by the whole block per step.
  constexpr int k_elems_per_k_iteration = VPT * kBlockSize;
  constexpr int k_iterations = kHiddenDim / k_elems_per_k_iteration;

  // Compile-time guards for assumptions that would otherwise fail silently.
  static_assert(
      kBlockSize > 0 && kBlockSize <= 128 && kBlockSize % kWarpSize == 0,
      "kBlockSize must be a multiple of the warp size and must not exceed "
      "the __launch_bounds__ limit of 128");
  static_assert(sizeof(T) == sizeof(__nv_bfloat16) &&
                    VPT * sizeof(T) == sizeof(uint4),
                "each thread must load exactly one uint4 of bf16-sized values");
  static_assert(kHiddenDim % k_elems_per_k_iteration == 0,
                "kHiddenDim must be a multiple of VPT * kBlockSize; a "
                "remainder would be silently dropped");

  // Each block handles one expert (one row of mat_b).
  int const n_idx = blockIdx.x;
  int const tid = threadIdx.x;
  int const warpId = tid / kWarpSize;
  int const laneId = tid % kWarpSize;

  // Per-thread fp32 partial dot products, one per token.
  float acc[kNumTokens] = {};
  // One partial sum per (token, warp) for the cross-warp reduction.
  __shared__ float sm_reduction[kNumTokens][kNumWarps];

  // Start of this expert's weight row.
  T const* b_col = mat_b + n_idx * kHiddenDim;

  // Pre-compute k_base values for each iteration to help compiler optimize.
  int k_bases[k_iterations];
#pragma unroll
  for (int ki = 0; ki < k_iterations; ki++) {
    k_bases[ki] = ki * k_elems_per_k_iteration + tid * VPT;
  }

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  // Do not touch the inputs before the prerequisite grid has completed.
  asm volatile("griddepcontrol.wait;");
#endif

  // Walk the hidden dimension; each thread owns VPT contiguous elements per
  // step.
  for (int ki = 0; ki < k_iterations; ki++) {
    int const k_base = k_bases[ki];

    // The weight slice is loaded once and reused for every token.
    uint4 b_vec = *reinterpret_cast<uint4 const*>(b_col + k_base);
    float b_float[VPT];
    bf16_uint4_to_float8<VPT>(b_vec, b_float);

#pragma unroll
    for (int m_idx = 0; m_idx < kNumTokens; m_idx++) {
      // Matching slice of this token's activation row.
      uint4 a_vec = *reinterpret_cast<uint4 const*>(
          mat_a + (m_idx * kHiddenDim) + k_base);
      float a_float[VPT];
      bf16_uint4_to_float8<VPT>(a_vec, a_float);

#pragma unroll
      for (int k = 0; k < VPT; k++) {
        acc[m_idx] += a_float[k] * b_float[k];
      }
    }
  }

  // Stage 1: butterfly reduction inside each warp. All 32 lanes take part,
  // hence the full participation mask.
#pragma unroll
  for (int m = 0; m < kNumTokens; m++) {
    float sum = acc[m];
    sum += __shfl_xor_sync(0xffffffff, sum, 16);
    sum += __shfl_xor_sync(0xffffffff, sum, 8);
    sum += __shfl_xor_sync(0xffffffff, sum, 4);
    sum += __shfl_xor_sync(0xffffffff, sum, 2);
    sum += __shfl_xor_sync(0xffffffff, sum, 1);
    // Only the first lane of each warp publishes the warp total.
    if (laneId == 0) {
      sm_reduction[m][warpId] = sum;
    }
  }
  // Reached by every thread of the block (outside any divergent branch).
  __syncthreads();

  // Stage 2: thread 0 folds the per-warp totals and writes the result.
  if (tid == 0) {
#pragma unroll
    for (int m = 0; m < kNumTokens; m++) {
      float final_sum = 0.0f;
#pragma unroll
      for (int w = 0; w < kNumWarps; w++) {
        final_sum += sm_reduction[m][w];
      }
      out[m * kNumExperts + n_idx] = __float2bfloat16(final_sum);
    }
  }

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  // Signal that dependent grids may be scheduled.
  asm volatile("griddepcontrol.launch_dependents;");
#endif
}
// Host launcher for router_gemm_kernel_bf16_output.
//
// Launches one block of 128 threads per expert on `stream`, with no dynamic
// shared memory. Programmatic stream serialization is requested only when
// getEnvEnablePDL() returns true.
//
// Parameters:
//   output - device buffer written as [kNumTokens, kNumExperts] in bf16.
//   mat_a  - device buffer read as [kNumTokens, kHiddenDim].
//   mat_b  - device buffer read as [kNumExperts, kHiddenDim].
//   stream - CUDA stream the kernel is enqueued on.
//
// Raises a c10::Error (through AT_CUDA_CHECK) if the launch is rejected.
template <typename T, int kNumTokens, int kNumExperts, int kHiddenDim>
void invokeRouterGemmBf16Output(__nv_bfloat16* output, T const* mat_a,
                                T const* mat_b, cudaStream_t stream) {
  // One 16-byte vector load per thread per step.
  constexpr int VPT = 16 / sizeof(T);
  constexpr int kBlockSize = 128;

  // Value-initialize so no member is left indeterminate.
  cudaLaunchAttribute attrs[1] = {};
  attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
  attrs[0].val.programmaticStreamSerializationAllowed = getEnvEnablePDL();

  cudaLaunchConfig_t config = {};
  config.gridDim = kNumExperts;
  config.blockDim = kBlockSize;
  config.dynamicSmemBytes = 0;
  config.stream = stream;
  config.numAttrs = 1;
  config.attrs = attrs;

  // Launch-configuration errors are reported only through the return value;
  // ignoring it would leave `output` untouched without any diagnostic.
  cudaError_t const launch_status = cudaLaunchKernelEx(
      &config,
      router_gemm_kernel_bf16_output<T, kBlockSize, VPT, kNumTokens,
                                     kNumExperts, kHiddenDim>,
      output, mat_a, mat_b);
  AT_CUDA_CHECK(launch_status);
}
// Explicit instantiations of invokeRouterGemmBf16Output.
//
// The launcher template is defined only in this translation unit, so every
// <token count, expert count, hidden dim> combination that other code may
// request has to be instantiated here. The lists below cover token counts
// 1 through 16 with a hidden dimension of 7168.

// 256 experts.
template void invokeRouterGemmBf16Output<__nv_bfloat16, 1, 256, 7168>(
__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 2, 256, 7168>(
__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 3, 256, 7168>(
__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 4, 256, 7168>(
__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 5, 256, 7168>(
__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 6, 256, 7168>(
__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 7, 256, 7168>(
__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 8, 256, 7168>(
__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 9, 256, 7168>(
__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 10, 256, 7168>(
__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 11, 256, 7168>(
__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 12, 256, 7168>(
__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 13, 256, 7168>(
__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 14, 256, 7168>(
__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 15, 256, 7168>(
__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 16, 256, 7168>(
__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
// 384 experts.
template void invokeRouterGemmBf16Output<__nv_bfloat16, 1, 384, 7168>(
__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 2, 384, 7168>(
__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 3, 384, 7168>(
__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 4, 384, 7168>(
__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 5, 384, 7168>(
__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 6, 384, 7168>(
__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 7, 384, 7168>(
__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 8, 384, 7168>(
__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 9, 384, 7168>(
__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 10, 384, 7168>(
__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 11, 384, 7168>(
__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 12, 384, 7168>(
__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 13, 384, 7168>(
__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 14, 384, 7168>(
__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 15, 384, 7168>(
__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 16, 384, 7168>(
__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
/*
* Adapted from SGLang's sgl-kernel implementation, which was adapted from
* https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/dsv3MinLatencyKernels/dsv3RouterGemm.cu
* https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/thop/dsv3RouterGemmOp.cpp
*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include "dsv3_router_gemm_utils.h"
// Problem sizes accepted by dsv3_router_gemm: two expert counts and a single
// hidden dimension.
static constexpr int DEFAULT_NUM_EXPERTS = 256;
static constexpr int KIMI_K2_NUM_EXPERTS = 384;
static constexpr int DEFAULT_HIDDEN_DIM = 7168;
// Launcher templates, declared here without a definition.
// NOTE(review): the definitions and the explicit instantiations this file
// links against presumably live in the per-output-dtype .cu files - confirm
// that every combination dispatched from this file is instantiated there.
template <typename T, int kNumTokens, int kNumExperts, int kHiddenDim>
void invokeRouterGemmFloatOutput(float* output, T const* mat_a, T const* mat_b,
cudaStream_t stream);
template <typename T, int kNumTokens, int kNumExperts, int kHiddenDim>
void invokeRouterGemmBf16Output(__nv_bfloat16* output, T const* mat_a,
T const* mat_b, cudaStream_t stream);
// Compile-time ladder that maps a runtime token count onto the launcher
// instantiation whose kNumTokens template argument equals it.
//
// Each rung handles num_tokens == kBegin and otherwise forwards to the rung
// for kBegin + 1; the partial specialization for kBegin == kEnd ends the
// recursion.
template <int kBegin, int kEnd, int kNumExperts, int kHiddenDim>
struct LoopUnroller {
  // Dispatches to the fp32-output launcher.
  static void unroll_float_output(int num_tokens, float* output,
                                  __nv_bfloat16 const* input,
                                  __nv_bfloat16 const* weights,
                                  cudaStream_t stream) {
    if (num_tokens != kBegin) {
      // Not this rung: try the next token count.
      using NextRung = LoopUnroller<kBegin + 1, kEnd, kNumExperts, kHiddenDim>;
      NextRung::unroll_float_output(num_tokens, output, input, weights, stream);
      return;
    }
    invokeRouterGemmFloatOutput<__nv_bfloat16, kBegin, kNumExperts, kHiddenDim>(
        output, input, weights, stream);
  }

  // Dispatches to the bf16-output launcher.
  static void unroll_bf16_output(int num_tokens, __nv_bfloat16* output,
                                 __nv_bfloat16 const* input,
                                 __nv_bfloat16 const* weights,
                                 cudaStream_t stream) {
    if (num_tokens != kBegin) {
      // Not this rung: try the next token count.
      using NextRung = LoopUnroller<kBegin + 1, kEnd, kNumExperts, kHiddenDim>;
      NextRung::unroll_bf16_output(num_tokens, output, input, weights, stream);
      return;
    }
    invokeRouterGemmBf16Output<__nv_bfloat16, kBegin, kNumExperts, kHiddenDim>(
        output, input, weights, stream);
  }
};
// Last rung of the dispatch ladder (kBegin == kEnd): either the token count
// matches kEnd, or no instantiation exists for it and std::invalid_argument
// is thrown.
template <int kEnd, int kNumExperts, int kHiddenDim>
struct LoopUnroller<kEnd, kEnd, kNumExperts, kHiddenDim> {
  // Terminal case for the fp32-output launcher.
  static void unroll_float_output(int num_tokens, float* output,
                                  __nv_bfloat16 const* input,
                                  __nv_bfloat16 const* weights,
                                  cudaStream_t stream) {
    if (num_tokens != kEnd) {
      throw std::invalid_argument("Invalid num_tokens, only supports 1 to 16");
    }
    invokeRouterGemmFloatOutput<__nv_bfloat16, kEnd, kNumExperts, kHiddenDim>(
        output, input, weights, stream);
  }

  // Terminal case for the bf16-output launcher.
  static void unroll_bf16_output(int num_tokens, __nv_bfloat16* output,
                                 __nv_bfloat16 const* input,
                                 __nv_bfloat16 const* weights,
                                 cudaStream_t stream) {
    if (num_tokens != kEnd) {
      throw std::invalid_argument("Invalid num_tokens, only supports 1 to 16");
    }
    invokeRouterGemmBf16Output<__nv_bfloat16, kEnd, kNumExperts, kHiddenDim>(
        output, input, weights, stream);
  }
};
// Entry point of the router GEMM: output = mat_a @ mat_b^T.
//
// Parameters:
//   output - preallocated [num_tokens, num_experts] tensor, fp32 or bf16;
//            written in place.
//   mat_a  - [num_tokens, hidden_dim] bf16 activations.
//   mat_b  - [num_experts, hidden_dim] bf16 router weights.
//
// Constraints, each enforced with TORCH_CHECK (raises c10::Error):
//   * hidden_dim == 7168, num_experts in {256, 384}, 1 <= num_tokens <= 16;
//   * all tensors are contiguous CUDA tensors on one device, because the
//     kernels index them as dense row-major buffers;
//   * mat_a and mat_b start on a 16-byte boundary, because the kernels read
//     them with 16-byte vector loads;
//   * the current device reports a compute capability between 9.0 and 10.3.
//
// The kernel is enqueued on the current CUDA stream; no synchronization is
// performed here.
void dsv3_router_gemm(at::Tensor& output,       // [num_tokens, num_experts]
                      const at::Tensor& mat_a,  // [num_tokens, hidden_dim]
                      const at::Tensor& mat_b   // [num_experts, hidden_dim]
) {
  TORCH_CHECK(output.dim() == 2 && mat_a.dim() == 2 && mat_b.dim() == 2);

  // Validate on the full 64-bit sizes; narrowing to int happens only after
  // the range checks have passed.
  const int64_t num_tokens = mat_a.size(0);
  const int64_t num_experts = mat_b.size(0);
  const int64_t hidden_dim = mat_a.size(1);

  TORCH_CHECK(mat_a.size(1) == mat_b.size(1),
              "mat_a and mat_b must have the same hidden_dim");
  TORCH_CHECK(hidden_dim == DEFAULT_HIDDEN_DIM,
              "Expected hidden_dim=", DEFAULT_HIDDEN_DIM,
              ", but got hidden_dim=", hidden_dim);
  TORCH_CHECK(
      num_experts == DEFAULT_NUM_EXPERTS || num_experts == KIMI_K2_NUM_EXPERTS,
      "Expected num_experts=", DEFAULT_NUM_EXPERTS,
      " or num_experts=", KIMI_K2_NUM_EXPERTS,
      ", but got num_experts=", num_experts);
  TORCH_CHECK(num_tokens >= 1 && num_tokens <= 16,
              "currently num_tokens must be less than or equal to 16 for "
              "router_gemm");
  TORCH_CHECK(mat_a.dtype() == at::kBFloat16, "mat_a must be bf16");
  TORCH_CHECK(mat_b.dtype() == at::kBFloat16, "mat_b must be bf16");
  TORCH_CHECK(output.dtype() == at::kFloat || output.dtype() == at::kBFloat16,
              "output must be float32 or bf16");

  // The kernel writes num_tokens * num_experts elements unconditionally, so a
  // smaller output buffer would be overrun.
  TORCH_CHECK(output.size(0) == num_tokens && output.size(1) == num_experts,
              "output must have shape [num_tokens, num_experts] = [",
              num_tokens, ", ", num_experts, "], but got [", output.size(0),
              ", ", output.size(1), "]");
  TORCH_CHECK(output.is_cuda() && mat_a.is_cuda() && mat_b.is_cuda(),
              "output, mat_a and mat_b must be CUDA tensors");
  TORCH_CHECK(
      mat_a.device() == mat_b.device() && mat_a.device() == output.device(),
      "output, mat_a and mat_b must be on the same device");
  // Strides are ignored by the kernel; anything but a dense row-major layout
  // would be read or written incorrectly.
  TORCH_CHECK(
      output.is_contiguous() && mat_a.is_contiguous() && mat_b.is_contiguous(),
      "output, mat_a and mat_b must be contiguous");

  auto const* a_ptr = reinterpret_cast<__nv_bfloat16 const*>(mat_a.data_ptr());
  auto const* b_ptr = reinterpret_cast<__nv_bfloat16 const*>(mat_b.data_ptr());
  // A view with a storage offset can break the alignment that the 16-byte
  // vector loads rely on.
  TORCH_CHECK(reinterpret_cast<uintptr_t>(a_ptr) % 16 == 0 &&
                  reinterpret_cast<uintptr_t>(b_ptr) % 16 == 0,
              "mat_a and mat_b must be 16-byte aligned");

  auto const sm = getSMVersion();
  TORCH_CHECK(sm >= 90 && sm <= 103, "required SM_103 >= CUDA ARCH >= SM_90");

  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  const int token_count = static_cast<int>(num_tokens);

  if (output.dtype() == at::kFloat) {
    auto* out_ptr = reinterpret_cast<float*>(output.mutable_data_ptr());
    if (num_experts == DEFAULT_NUM_EXPERTS) {
      LoopUnroller<1, 16, DEFAULT_NUM_EXPERTS, DEFAULT_HIDDEN_DIM>::
          unroll_float_output(token_count, out_ptr, a_ptr, b_ptr, stream);
    } else {
      // Only KIMI_K2_NUM_EXPERTS can reach this branch (validated above).
      LoopUnroller<1, 16, KIMI_K2_NUM_EXPERTS, DEFAULT_HIDDEN_DIM>::
          unroll_float_output(token_count, out_ptr, a_ptr, b_ptr, stream);
    }
  } else {
    // Only bf16 can reach this branch (validated above).
    auto* out_ptr = reinterpret_cast<__nv_bfloat16*>(output.mutable_data_ptr());
    if (num_experts == DEFAULT_NUM_EXPERTS) {
      LoopUnroller<1, 16, DEFAULT_NUM_EXPERTS, DEFAULT_HIDDEN_DIM>::
          unroll_bf16_output(token_count, out_ptr, a_ptr, b_ptr, stream);
    } else {
      LoopUnroller<1, 16, KIMI_K2_NUM_EXPERTS, DEFAULT_HIDDEN_DIM>::
          unroll_bf16_output(token_count, out_ptr, a_ptr, b_ptr, stream);
    }
  }
}
/*
* Adapted from SGLang's sgl-kernel implementation, which was adapted from
* https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/dsv3MinLatencyKernels/dsv3RouterGemm.cu
* https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/thop/dsv3RouterGemmOp.cpp
*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include "dsv3_router_gemm_utils.h"
// Packed two-lane fused multiply-add issued through inline PTX.
//
// Emits a single `fma.rn.f32x2` instruction (round-to-nearest) whose operands
// are the float2 arguments: d receives the result computed from a, b and c.
// Each float2 is handed to the assembler as one 64-bit register (the "l"
// constraint) by reinterpreting it as uint64_t.
//
// NOTE(review): nothing visible in this translation unit calls this helper,
// and `fma.rn.f32x2` presumably requires a newer architecture than SM90 -
// confirm against the PTX ISA before using it on Hopper.
__device__ __forceinline__ void fma(float2& d, float2 const& a, float2 const& b,
float2 const& c) {
asm volatile("fma.rn.f32x2 %0, %1, %2, %3;\n"
: "=l"(reinterpret_cast<uint64_t&>(d))
: "l"(reinterpret_cast<uint64_t const&>(a)),
"l"(reinterpret_cast<uint64_t const&>(b)),
"l"(reinterpret_cast<uint64_t const&>(c)));
}
// Unpacks the bf16 values packed in a 16-byte uint4 vector into fp32.
//
// Template parameter:
//   VPT - number of bf16 elements to convert; at most 8 fit in one uint4.
// Parameters:
//   vec - packed source vector (read-only).
//   dst - destination array with room for at least VPT floats.
template <int VPT>
__device__ __forceinline__ void bf16_uint4_to_float8(uint4 const& vec,
                                                     float* dst) {
  // Reading more than sizeof(uint4) bytes would run past the end of `vec`.
  static_assert(VPT > 0 && VPT * sizeof(__nv_bfloat16) <= sizeof(uint4),
                "VPT bf16 values must fit inside one uint4");
  // The source is never modified, so view it through a pointer-to-const
  // instead of casting its constness away.
  __nv_bfloat16 const* bf16_ptr = reinterpret_cast<__nv_bfloat16 const*>(&vec);
#pragma unroll
  for (int i = 0; i < VPT; i++) {
    dst[i] = __bfloat162float(bf16_ptr[i]);
  }
}
// Small-batch router GEMM with fp32 output.
//
// Computes out[m][n] = sum_k mat_a[m][k] * mat_b[n][k] for every token
// m in [0, kNumTokens) and the single expert n = blockIdx.x. Products are
// accumulated and stored in fp32.
//
// Launch layout: 1-D grid with one block per expert, kBlockSize threads per
// block. Uses kNumTokens * kNumWarps floats of static shared memory and no
// dynamic shared memory.
//
// Preconditions (not checked on the device):
//   * mat_a is read as dense rows of kHiddenDim elements, one row per token.
//   * mat_b is read as dense rows of kHiddenDim elements, one row per expert.
//   * out is written as dense rows of kNumExperts elements, one row per token.
//   * All loads are 16-byte uint4 loads, so both input base pointers must be
//     16-byte aligned.
//
// On SM90+ the body is bracketed by griddepcontrol.wait /
// griddepcontrol.launch_dependents so the kernel can take part in
// programmatic dependent launch when the host enables it.
template <typename T, int kBlockSize, int VPT, int kNumTokens, int kNumExperts,
          int kHiddenDim>
__global__ __launch_bounds__(128, 1) void router_gemm_kernel_float_output(
    float* out, T const* mat_a, T const* mat_b) {
  constexpr int kWarpSize = 32;
  constexpr int kNumWarps = kBlockSize / kWarpSize;
  // Elements of the hidden dimension consumed by the whole block per step.
  constexpr int k_elems_per_k_iteration = VPT * kBlockSize;
  constexpr int k_iterations = kHiddenDim / k_elems_per_k_iteration;

  // Compile-time guards for assumptions that would otherwise fail silently.
  static_assert(
      kBlockSize > 0 && kBlockSize <= 128 && kBlockSize % kWarpSize == 0,
      "kBlockSize must be a multiple of the warp size and must not exceed "
      "the __launch_bounds__ limit of 128");
  static_assert(sizeof(T) == sizeof(__nv_bfloat16) &&
                    VPT * sizeof(T) == sizeof(uint4),
                "each thread must load exactly one uint4 of bf16-sized values");
  static_assert(kHiddenDim % k_elems_per_k_iteration == 0,
                "kHiddenDim must be a multiple of VPT * kBlockSize; a "
                "remainder would be silently dropped");

  // Each block handles one expert (one row of mat_b).
  int const n_idx = blockIdx.x;
  int const tid = threadIdx.x;
  int const warpId = tid / kWarpSize;
  int const laneId = tid % kWarpSize;

  // Per-thread fp32 partial dot products, one per token.
  float acc[kNumTokens] = {};
  // One partial sum per (token, warp) for the cross-warp reduction.
  __shared__ float sm_reduction[kNumTokens][kNumWarps];

  // Start of this expert's weight row.
  T const* b_col = mat_b + n_idx * kHiddenDim;

  // Pre-compute k_base values for each iteration to help compiler optimize.
  int k_bases[k_iterations];
#pragma unroll
  for (int ki = 0; ki < k_iterations; ki++) {
    k_bases[ki] = ki * k_elems_per_k_iteration + tid * VPT;
  }

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  // Do not touch the inputs before the prerequisite grid has completed.
  asm volatile("griddepcontrol.wait;");
#endif

  // Walk the hidden dimension; each thread owns VPT contiguous elements per
  // step.
  for (int ki = 0; ki < k_iterations; ki++) {
    int const k_base = k_bases[ki];

    // The weight slice is loaded once and reused for every token.
    uint4 b_vec = *reinterpret_cast<uint4 const*>(b_col + k_base);
    float b_float[VPT];
    bf16_uint4_to_float8<VPT>(b_vec, b_float);

#pragma unroll
    for (int m_idx = 0; m_idx < kNumTokens; m_idx++) {
      // Matching slice of this token's activation row.
      uint4 a_vec = *reinterpret_cast<uint4 const*>(
          mat_a + (m_idx * kHiddenDim) + k_base);
      float a_float[VPT];
      bf16_uint4_to_float8<VPT>(a_vec, a_float);

#pragma unroll
      for (int k = 0; k < VPT; k++) {
        acc[m_idx] += a_float[k] * b_float[k];
      }
    }
  }

  // Stage 1: butterfly reduction inside each warp. All 32 lanes take part,
  // hence the full participation mask.
#pragma unroll
  for (int m = 0; m < kNumTokens; m++) {
    float sum = acc[m];
    sum += __shfl_xor_sync(0xffffffff, sum, 16);
    sum += __shfl_xor_sync(0xffffffff, sum, 8);
    sum += __shfl_xor_sync(0xffffffff, sum, 4);
    sum += __shfl_xor_sync(0xffffffff, sum, 2);
    sum += __shfl_xor_sync(0xffffffff, sum, 1);
    // Only the first lane of each warp publishes the warp total.
    if (laneId == 0) {
      sm_reduction[m][warpId] = sum;
    }
  }
  // Reached by every thread of the block (outside any divergent branch).
  __syncthreads();

  // Stage 2: thread 0 folds the per-warp totals and writes the result.
  if (tid == 0) {
#pragma unroll
    for (int m = 0; m < kNumTokens; m++) {
      float final_sum = 0.0f;
#pragma unroll
      for (int w = 0; w < kNumWarps; w++) {
        final_sum += sm_reduction[m][w];
      }
      out[m * kNumExperts + n_idx] = final_sum;
    }
  }

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  // Signal that dependent grids may be scheduled.
  asm volatile("griddepcontrol.launch_dependents;");
#endif
}
// Host launcher for router_gemm_kernel_float_output.
//
// Launches one block of 128 threads per expert on `stream`, with no dynamic
// shared memory. Programmatic stream serialization is requested only when
// getEnvEnablePDL() returns true.
//
// Parameters:
//   output - device buffer written as [kNumTokens, kNumExperts] in fp32.
//   mat_a  - device buffer read as [kNumTokens, kHiddenDim].
//   mat_b  - device buffer read as [kNumExperts, kHiddenDim].
//   stream - CUDA stream the kernel is enqueued on.
//
// Raises a c10::Error (through AT_CUDA_CHECK) if the launch is rejected.
template <typename T, int kNumTokens, int kNumExperts, int kHiddenDim>
void invokeRouterGemmFloatOutput(float* output, T const* mat_a, T const* mat_b,
                                 cudaStream_t stream) {
  // One 16-byte vector load per thread per step.
  constexpr int VPT = 16 / sizeof(T);
  constexpr int kBlockSize = 128;

  // Value-initialize so no member is left indeterminate.
  cudaLaunchAttribute attrs[1] = {};
  attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
  attrs[0].val.programmaticStreamSerializationAllowed = getEnvEnablePDL();

  cudaLaunchConfig_t config = {};
  config.gridDim = kNumExperts;
  config.blockDim = kBlockSize;
  config.dynamicSmemBytes = 0;
  config.stream = stream;
  config.numAttrs = 1;
  config.attrs = attrs;

  // Launch-configuration errors are reported only through the return value;
  // ignoring it would leave `output` untouched without any diagnostic.
  cudaError_t const launch_status = cudaLaunchKernelEx(
      &config,
      router_gemm_kernel_float_output<T, kBlockSize, VPT, kNumTokens,
                                      kNumExperts, kHiddenDim>,
      output, mat_a, mat_b);
  AT_CUDA_CHECK(launch_status);
}
// Explicit instantiations of invokeRouterGemmFloatOutput.
//
// The launcher template is defined only in this translation unit, so every
// <token count, expert count, hidden dim> combination that other code may
// request has to be instantiated here. The lists below cover token counts
// 1 through 16 with a hidden dimension of 7168.

// 256 experts.
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 1, 256, 7168>(
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 2, 256, 7168>(
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 3, 256, 7168>(
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 4, 256, 7168>(
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 5, 256, 7168>(
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 6, 256, 7168>(
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 7, 256, 7168>(
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 8, 256, 7168>(
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 9, 256, 7168>(
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 10, 256, 7168>(
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 11, 256, 7168>(
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 12, 256, 7168>(
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 13, 256, 7168>(
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 14, 256, 7168>(
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 15, 256, 7168>(
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 16, 256, 7168>(
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
// 384 experts.
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 1, 384, 7168>(
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 2, 384, 7168>(
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 3, 384, 7168>(
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 4, 384, 7168>(
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 5, 384, 7168>(
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 6, 384, 7168>(
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 7, 384, 7168>(
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 8, 384, 7168>(
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 9, 384, 7168>(
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 10, 384, 7168>(
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 11, 384, 7168>(
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 12, 384, 7168>(
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 13, 384, 7168>(
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 14, 384, 7168>(
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 15, 384, 7168>(
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 16, 384, 7168>(
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
/*
* Adapted from SGLang's sgl-kernel implementation, which was adapted from
* https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/dsv3MinLatencyKernels/dsv3RouterGemm.cu
* https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/thop/dsv3RouterGemmOp.cpp
*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <ATen/cuda/CUDAContext.h>
#include <cstdlib>
#include <mutex>
// Returns the compute capability reported by
// at::cuda::getCurrentDeviceProperties(), encoded as major * 10 + minor.
inline int getSMVersion() {
  auto const* const device_props = at::cuda::getCurrentDeviceProperties();
  int const major_part = device_props->major * 10;
  return major_part + device_props->minor;
}
// Reports whether programmatic dependent launch (PDL) was requested.
//
// The answer is computed once, on the first call, and cached for the rest of
// the process: it is true only when getSMVersion() is at least 90 and the
// environment variable TRTLLM_ENABLE_PDL is exactly the string "1".
inline bool getEnvEnablePDL() {
  // A function-local static is initialized exactly once, even when several
  // threads make the first call concurrently.
  static bool const pdl_requested = []() -> bool {
    if (getSMVersion() < 90) {
      return false;
    }
    char const* const raw_value = std::getenv("TRTLLM_ENABLE_PDL");
    if (raw_value == nullptr) {
      return false;
    }
    return raw_value[0] == '1' && raw_value[1] == '\0';
  }();
  return pdl_requested;
}
...@@ -55,4 +55,15 @@ bool moe_permute_unpermute_supported(); ...@@ -55,4 +55,15 @@ bool moe_permute_unpermute_supported();
void shuffle_rows(const torch::Tensor& input_tensor, void shuffle_rows(const torch::Tensor& input_tensor,
const torch::Tensor& dst2src_map, const torch::Tensor& dst2src_map,
torch::Tensor& output_tensor); torch::Tensor& output_tensor);
\ No newline at end of file
#ifndef USE_ROCM
// DeepSeek V3 optimized router GEMM kernel for SM90+
// Computes output = mat_a @ mat_b.T where:
// mat_a: [num_tokens, hidden_dim] in bf16
// mat_b: [num_experts, hidden_dim] in bf16
// output: [num_tokens, num_experts] in bf16 or fp32
// Supports num_tokens in [1, 16], num_experts in {256, 384}, hidden_dim = 7168
//
// `output` must be allocated by the caller; it is filled in place and nothing
// is returned. Unsupported sizes, dtypes or devices are rejected with
// TORCH_CHECK by the implementation.
void dsv3_router_gemm(torch::Tensor& output, const torch::Tensor& mat_a,
const torch::Tensor& mat_b);
#endif
...@@ -124,6 +124,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { ...@@ -124,6 +124,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
"routed_scaling_factor, Tensor bias, int scoring_func) -> (Tensor, " "routed_scaling_factor, Tensor bias, int scoring_func) -> (Tensor, "
"Tensor)"); "Tensor)");
m.impl("grouped_topk", torch::kCUDA, &grouped_topk); m.impl("grouped_topk", torch::kCUDA, &grouped_topk);
// DeepSeek V3 optimized router GEMM for SM90+
m.def("dsv3_router_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()");
m.impl("dsv3_router_gemm", torch::kCUDA, &dsv3_router_gemm);
#endif #endif
} }
......
...@@ -2190,6 +2190,21 @@ def moe_wna16_gemm( ...@@ -2190,6 +2190,21 @@ def moe_wna16_gemm(
) )
def dsv3_router_gemm(
    hidden_states: torch.Tensor,
    router_weight: torch.Tensor,
    output_dtype: torch.dtype,
) -> torch.Tensor:
    """Run the ``_moe_C.dsv3_router_gemm`` custom op and return its output.

    Args:
        hidden_states: Activations; the first dimension gives the number of
            output rows.
        router_weight: Router weights; the first dimension gives the number
            of output columns.
        output_dtype: Dtype of the freshly allocated output tensor.

    Returns:
        A new tensor of shape ``(hidden_states.shape[0],
        router_weight.shape[0])`` on the device of ``hidden_states``, filled
        in place by the custom op.
    """
    logits_shape = (hidden_states.shape[0], router_weight.shape[0])
    router_logits = torch.empty(
        logits_shape,
        device=hidden_states.device,
        dtype=output_dtype,
    )
    torch.ops._moe_C.dsv3_router_gemm(router_logits, hidden_states, router_weight)
    return router_logits
def topk_softmax( def topk_softmax(
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
......
...@@ -221,6 +221,73 @@ class DeepseekV2MLP(nn.Module): ...@@ -221,6 +221,73 @@ class DeepseekV2MLP(nn.Module):
return x return x
class DeepSeekV2Gate(ReplicatedLinear):
    """Unquantized MoE router projection with a small-batch fast path.

    Behaves like a bias-free ``ReplicatedLinear``. When the platform and the
    problem size match what the specialized router GEMM supports, ``forward``
    uses ``ops.dsv3_router_gemm`` for batches of 1 to 16 tokens; every other
    call goes through ``ReplicatedLinear.forward``.
    """

    def __init__(
        self,
        hidden_size: int,
        n_experts: int,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """Build the gate.

        Args:
            hidden_size: Input feature size.
            n_experts: Number of routed experts (output feature size).
            quant_config: Must be ``None``; only the unquantized gate is
                supported.
            prefix: Full name of this module. It is passed through unchanged:
                the caller visible alongside this class already supplies a
                name ending in ``.gate``, so appending ``.gate`` again would
                yield ``...gate.gate``.
        """
        assert quant_config is None
        super().__init__(
            hidden_size,
            n_experts,
            bias=False,
            quant_config=quant_config,
            prefix=prefix,
        )

        # Unquantized only, will be called "weight".
        assert hasattr(self, "weight")

        is_hopper_or_blackwell = current_platform.is_device_capability(
            (9, 0)
        ) or current_platform.is_device_capability_family(100)
        SUPPORTED_NUM_EXPERTS = [256, 384]
        SUPPORTED_HIDDEN_SIZES = [7168]

        # Static part of the fast-path decision; per-call conditions are
        # checked in forward().
        self.allow_dsv3_router_gemm = (
            current_platform.is_cuda()
            and is_hopper_or_blackwell
            and n_experts in SUPPORTED_NUM_EXPERTS
            and hidden_size in SUPPORTED_HIDDEN_SIZES
        )

        self._out_dtype: torch.dtype | None = None

    def set_out_dtype(self, out_dtype: torch.dtype) -> None:
        """
        Set out dtype for the router logits. This is needed after
        __init__, b/c we need to check if the trtllm kernel is
        selected before we decide between bf16 and fp32.

        Raises:
            ValueError: if the dtype has already been set.
        """
        if self._out_dtype is not None:
            raise ValueError("out_dtype has already been set")
        self._out_dtype = out_dtype

    @property
    def out_dtype(self) -> torch.dtype:
        """Dtype of the router logits produced by the fast path.

        Raises:
            ValueError: if ``set_out_dtype`` has not been called yet.
        """
        if self._out_dtype is None:
            raise ValueError("out_dtype has not been set yet")
        return self._out_dtype

    def forward(
        self,
        x: torch.Tensor,
    ) -> tuple[torch.Tensor, None]:
        """
        Use specialized GEMM for low batch size for DSV3 and KIMI.

        The specialized kernel accepts only 1 to 16 tokens of contiguous bf16
        input with bf16 weights. Anything else (including an empty batch)
        falls back to the regular linear layer instead of tripping the
        kernel's argument checks.
        """
        use_router_gemm = (
            self.allow_dsv3_router_gemm
            and x.dim() == 2
            and 1 <= x.shape[0] <= 16
            and x.dtype == torch.bfloat16
            and self.weight.dtype == torch.bfloat16
            and x.is_contiguous()
        )
        if use_router_gemm:
            return ops.dsv3_router_gemm(
                hidden_states=x,
                router_weight=self.weight,
                output_dtype=self.out_dtype,
            ), None
        return super().forward(x)
class DeepseekV2MoE(nn.Module): class DeepseekV2MoE(nn.Module):
def __init__( def __init__(
self, self,
...@@ -249,10 +316,9 @@ class DeepseekV2MoE(nn.Module): ...@@ -249,10 +316,9 @@ class DeepseekV2MoE(nn.Module):
"Only silu is supported for now." "Only silu is supported for now."
) )
self.gate = ReplicatedLinear( self.gate = DeepSeekV2Gate(
config.hidden_size, config.hidden_size,
config.n_routed_experts, config.n_routed_experts,
bias=False,
quant_config=None, quant_config=None,
prefix=f"{prefix}.gate", prefix=f"{prefix}.gate",
) )
...@@ -325,6 +391,13 @@ class DeepseekV2MoE(nn.Module): ...@@ -325,6 +391,13 @@ class DeepseekV2MoE(nn.Module):
else None, else None,
) )
# NOTE(rob): this is a hack until we finish off the PR for
# merging TRTLLM kernels into the MK framework. Then we can
# query the MonolithicMK for the expected router logits.
self.gate.set_out_dtype(
torch.float32 if self.experts.quant_method.is_monolithic else torch.bfloat16
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim) hidden_states = hidden_states.view(-1, hidden_dim)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment