Unverified commit 282eb59f, authored by Baizhou Zhang, committed by GitHub

Add bf16 output option for dsv3_router_gemm kernel (#7999)

parent 4540a466
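This commit threads an `out_dtype` option through the Python wrapper, the registered torch op, and a new CUDA kernel so the dsv3 router GEMM can write its logits directly in bf16 as well as in float32. A minimal usage sketch (illustrative only; the shapes and dtypes mirror the checks enforced by the kernel entry point below: bf16 inputs, hidden_dim 7168, num_experts 256, 1 <= num_tokens <= 16, SM90+ GPU):

```python
# Usage sketch for the new out_dtype option; the batch size is illustrative,
# the other constraints mirror the TORCH_CHECKs in the C++ entry point.
import torch
from sgl_kernel import dsv3_router_gemm

hidden_states = torch.randn(4, 7168, dtype=torch.bfloat16, device="cuda")    # [num_tokens, hidden_dim]
router_weights = torch.randn(256, 7168, dtype=torch.bfloat16, device="cuda")  # [num_experts, hidden_dim]

logits_bf16 = dsv3_router_gemm(hidden_states, router_weights, out_dtype=torch.bfloat16)
logits_fp32 = dsv3_router_gemm(hidden_states, router_weights, out_dtype=torch.float32)
```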
@@ -222,7 +222,9 @@ set(SOURCES
"csrc/gemm/awq_kernel.cu"
"csrc/gemm/bmm_fp8.cu"
"csrc/gemm/dsv3_fused_a_gemm.cu"
"csrc/gemm/dsv3_router_gemm.cu"
"csrc/gemm/dsv3_router_gemm_bf16_out.cu"
"csrc/gemm/dsv3_router_gemm_entry.cu"
"csrc/gemm/dsv3_router_gemm_float_out.cu"
"csrc/gemm/fp8_blockwise_gemm_kernel.cu"
"csrc/gemm/fp8_gemm_kernel.cu"
"csrc/gemm/int8_gemm_kernel.cu"
......
@@ -7,6 +7,48 @@ import triton.testing
from sgl_kernel import dsv3_router_gemm


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["num_tokens"],
        x_vals=[i + 1 for i in range(16)],
        x_log=False,
        line_arg="impl",
        line_vals=["torch", "sgl-kernel"],
        line_names=["torch", "dsv3_router_gemm"],
        styles=[("blue", "-"), ("orange", "-")],
        ylabel="TFLOPs",
        plot_name="input-bf16-output-bf16 dsv3 router gemm throughput",
        args={},
    )
)
def benchmark_bf16_output(num_tokens, impl):
    # M: num_tokens, K: hidden_dim, N: num_experts
    M, K, N = num_tokens, 7168, 256
    mat_a = torch.randn((M, K), dtype=torch.bfloat16, device="cuda").contiguous()
    mat_b = torch.randn((N, K), dtype=torch.bfloat16, device="cuda").contiguous()
    quantiles = [0.5, 0.2, 0.8]

    if impl == "torch":

        def runner():
            F.linear(mat_a, mat_b)

    elif impl == "sgl-kernel":

        def runner():
            dsv3_router_gemm(mat_a, mat_b, out_dtype=torch.bfloat16)

    ms, min_ms, max_ms = triton.testing.do_bench(runner, quantiles=quantiles)

    def tflops(t_ms):
        flops = 2 * M * K * N
        return flops / (t_ms * 1e-3) / 1e12

    return tflops(ms), tflops(max_ms), tflops(min_ms)
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["num_tokens"],
@@ -21,7 +63,7 @@ from sgl_kernel import dsv3_router_gemm
        args={},
    )
)
def benchmark(num_tokens, impl):
def benchmark_float_output(num_tokens, impl):
    # M: num_tokens, K: hidden_dim, N: num_experts
    M, K, N = num_tokens, 7168, 256
@@ -38,7 +80,7 @@ def benchmark(num_tokens, impl):
    elif impl == "sgl-kernel":

        def runner():
            dsv3_router_gemm(mat_a, mat_b)
            dsv3_router_gemm(mat_a, mat_b, out_dtype=torch.float32)

    ms, min_ms, max_ms = triton.testing.do_bench(runner, quantiles=quantiles)
@@ -53,4 +95,9 @@ if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    args = parser.parse_args()
    benchmark.run(print_data=True, show_plots=True, save_path="bench_dsv3_router_gemm")
    benchmark_bf16_output.run(
        print_data=True, show_plots=True, save_path="bench_dsv3_router_gemm"
    )
    benchmark_float_output.run(
        print_data=True, show_plots=True, save_path="bench_dsv3_router_gemm"
    )
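For reference, the throughput that `tflops` reports is just the GEMM flop count 2 * M * K * N divided by the measured time (`do_bench` returns milliseconds, hence the `1e-3` conversion to seconds). A quick check of the arithmetic at the largest benchmarked size, using a purely hypothetical 0.01 ms timing:

```python
# Worked example of the tflops() formula above; the 0.01 ms timing is
# illustrative only, not a measured result.
M, K, N = 16, 7168, 256
flops = 2 * M * K * N            # = 58,720,256 (each multiply-add counted as 2 flops)
t_ms = 0.01                      # hypothetical do_bench result in milliseconds
tflops = flops / (t_ms * 1e-3) / 1e12
print(tflops)                    # ≈ 5.87 TFLOP/s
```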
csrc/gemm/dsv3_router_gemm_bf16_out.cu
/*
* Adapted from
* https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/dsv3MinLatencyKernels/dsv3RouterGemm.cu
* https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/thop/dsv3RouterGemmOp.cpp
*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include "cuda_bf16.h"
#include "cuda_runtime.h"
#include "utils.h"
// Custom FMA implementation using PTX assembly instructions
__device__ __forceinline__ void fma(float2& d, float2 const& a, float2 const& b, float2 const& c) {
asm volatile("fma.rn.f32x2 %0, %1, %2, %3;\n"
: "=l"(reinterpret_cast<uint64_t&>(d))
: "l"(reinterpret_cast<uint64_t const&>(a)),
"l"(reinterpret_cast<uint64_t const&>(b)),
"l"(reinterpret_cast<uint64_t const&>(c)));
}
// Convert 8 bfloat16 values from a uint4 to float array - optimized conversion
template <int VPT>
__device__ __forceinline__ void bf16_uint4_to_float8(uint4 const& vec, float* dst) {
__nv_bfloat16* bf16_ptr = reinterpret_cast<__nv_bfloat16*>(const_cast<uint4*>(&vec));
#pragma unroll
for (int i = 0; i < VPT; i++) {
dst[i] = __bfloat162float(bf16_ptr[i]);
}
}
template <typename T, int kBlockSize, int VPT, int kNumTokens, int kNumExperts, int kHiddenDim>
__global__
__launch_bounds__(128, 1) void router_gemm_kernel_bf16_output(__nv_bfloat16* out, T const* mat_a, T const* mat_b) {
// Each block handles one expert column
int const n_idx = blockIdx.x;
int const tid = threadIdx.x;
constexpr int kWarpSize = 32;
constexpr int kNumWarps = kBlockSize / kWarpSize;
// Constants for this kernel
constexpr int k_elems_per_k_iteration = VPT * kBlockSize;
constexpr int k_iterations = kHiddenDim / k_elems_per_k_iteration; // Total K iterations
// Initialize accumulators for all M rows
float acc[kNumTokens] = {};
// Shared memory for warp-level reduction
__shared__ float sm_reduction[kNumTokens][kNumWarps]; // kNumWarps
// B matrix is in column-major order, so we can directly load a column for the n_idx expert
T const* b_col = mat_b + n_idx * kHiddenDim;
// Pre-compute k_base values for each iteration to help compiler optimize
int k_bases[k_iterations];
#pragma unroll
for (int ki = 0; ki < k_iterations; ki++) {
k_bases[ki] = ki * k_elems_per_k_iteration + tid * VPT;
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.wait;");
#endif
// Process the GEMM in chunks
for (int ki = 0; ki < k_iterations; ki++) {
int const k_base = k_bases[ki];
// Load B matrix values using vector load (8 bf16 values)
uint4 b_vec = *reinterpret_cast<uint4 const*>(b_col + k_base);
// Convert B values to float
float b_float[VPT];
bf16_uint4_to_float8<VPT>(b_vec, b_float);
// Process each token
#pragma unroll
for (int m_idx = 0; m_idx < kNumTokens; m_idx++) {
// Load this token's row of the A matrix using a vector load
uint4 a_vec = *reinterpret_cast<uint4 const*>(mat_a + (m_idx * kHiddenDim) + k_base);
// Convert A values to float
float a_float[VPT];
bf16_uint4_to_float8<VPT>(a_vec, a_float);
// Process elements in this chunk
#pragma unroll
for (int k = 0; k < VPT; k++) {
float a = a_float[k];
float b = b_float[k];
acc[m_idx] += a * b;
}
}
}
// Perform warp-level reduction
int const warpId = tid / kWarpSize;
int const laneId = tid % kWarpSize;
// Register for warp-level reduction results
float warp_result[kNumTokens];
#pragma unroll
for (int m_idx = 0; m_idx < kNumTokens; m_idx++) {
warp_result[m_idx] = acc[m_idx];
}
// Perform warp-level reduction using optimized butterfly pattern
#pragma unroll
for (int m = 0; m < kNumTokens; m++) {
float sum = warp_result[m];
// Butterfly reduction pattern
sum += __shfl_xor_sync(0xffffffff, sum, 16);
sum += __shfl_xor_sync(0xffffffff, sum, 8);
sum += __shfl_xor_sync(0xffffffff, sum, 4);
sum += __shfl_xor_sync(0xffffffff, sum, 2);
sum += __shfl_xor_sync(0xffffffff, sum, 1);
// Only the first thread in each warp stores to shared memory
if (laneId == 0) {
sm_reduction[m][warpId] = sum;
}
}
__syncthreads();
// Final reduction across warps (only first thread)
if (tid == 0) {
#pragma unroll
for (int m = 0; m < kNumTokens; m++) {
float final_sum = 0.0f;
// Sum across the kNumWarps
#pragma unroll
for (int w = 0; w < kNumWarps; w++) {
final_sum += sm_reduction[m][w];
}
// Write final result
out[m * kNumExperts + n_idx] = __float2bfloat16(final_sum);
}
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.launch_dependents;");
#endif
}
template <typename T, int kNumTokens, int kNumExperts, int kHiddenDim>
void invokeRouterGemmBf16Output(__nv_bfloat16* output, T const* mat_a, T const* mat_b, cudaStream_t stream) {
constexpr int VPT = 16 / sizeof(T);
constexpr int kBlockSize = 128;
cudaLaunchConfig_t config;
config.gridDim = kNumExperts;
config.blockDim = kBlockSize;
config.dynamicSmemBytes = 0;
config.stream = stream;
cudaLaunchAttribute attrs[1];
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attrs[0].val.programmaticStreamSerializationAllowed = getEnvEnablePDL();
config.numAttrs = 1;
config.attrs = attrs;
cudaLaunchKernelEx(
&config,
router_gemm_kernel_bf16_output<T, kBlockSize, VPT, kNumTokens, kNumExperts, kHiddenDim>,
output,
mat_a,
mat_b);
}
template void invokeRouterGemmBf16Output<__nv_bfloat16, 1, 256, 7168>(
__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 2, 256, 7168>(
__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 3, 256, 7168>(
__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 4, 256, 7168>(
__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 5, 256, 7168>(
__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 6, 256, 7168>(
__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 7, 256, 7168>(
__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 8, 256, 7168>(
__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 9, 256, 7168>(
__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 10, 256, 7168>(
__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 11, 256, 7168>(
__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 12, 256, 7168>(
__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 13, 256, 7168>(
__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 14, 256, 7168>(
__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 15, 256, 7168>(
__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 16, 256, 7168>(
__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
csrc/gemm/dsv3_router_gemm_entry.cu
/*
* Adapted from
* https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/dsv3MinLatencyKernels/dsv3RouterGemm.cu
* https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/thop/dsv3RouterGemmOp.cpp
*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include "cuda_bf16.h"
#include "cuda_runtime.h"
#include "utils.h"
template <typename T, int kNumTokens, int kNumExperts, int kHiddenDim>
void invokeRouterGemmFloatOutput(float* output, T const* mat_a, T const* mat_b, cudaStream_t stream);
template <typename T, int kNumTokens, int kNumExperts, int kHiddenDim>
void invokeRouterGemmBf16Output(__nv_bfloat16* output, T const* mat_a, T const* mat_b, cudaStream_t stream);
template <int kBegin, int kEnd, int kNumExperts, int kHiddenDim>
struct LoopUnroller {
static void unroll_float_output(
int num_tokens, float* output, __nv_bfloat16 const* input, __nv_bfloat16 const* weights, cudaStream_t stream) {
if (num_tokens == kBegin) {
invokeRouterGemmFloatOutput<__nv_bfloat16, kBegin, kNumExperts, kHiddenDim>(output, input, weights, stream);
} else {
LoopUnroller<kBegin + 1, kEnd, kNumExperts, kHiddenDim>::unroll_float_output(
num_tokens, output, input, weights, stream);
}
}
static void unroll_bf16_output(
int num_tokens,
__nv_bfloat16* output,
__nv_bfloat16 const* input,
__nv_bfloat16 const* weights,
cudaStream_t stream) {
if (num_tokens == kBegin) {
invokeRouterGemmBf16Output<__nv_bfloat16, kBegin, kNumExperts, kHiddenDim>(output, input, weights, stream);
} else {
LoopUnroller<kBegin + 1, kEnd, kNumExperts, kHiddenDim>::unroll_bf16_output(
num_tokens, output, input, weights, stream);
}
}
};
template <int kEnd, int kNumExperts, int kHiddenDim>
struct LoopUnroller<kEnd, kEnd, kNumExperts, kHiddenDim> {
static void unroll_float_output(
int num_tokens, float* output, __nv_bfloat16 const* input, __nv_bfloat16 const* weights, cudaStream_t stream) {
if (num_tokens == kEnd) {
invokeRouterGemmFloatOutput<__nv_bfloat16, kEnd, kNumExperts, kHiddenDim>(output, input, weights, stream);
} else {
throw std::invalid_argument("Invalid num_tokens, only supports 1 to 16");
}
}
static void unroll_bf16_output(
int num_tokens,
__nv_bfloat16* output,
__nv_bfloat16 const* input,
__nv_bfloat16 const* weights,
cudaStream_t stream) {
if (num_tokens == kEnd) {
invokeRouterGemmBf16Output<__nv_bfloat16, kEnd, kNumExperts, kHiddenDim>(output, input, weights, stream);
} else {
throw std::invalid_argument("Invalid num_tokens, only supports 1 to 16");
}
}
};
void dsv3_router_gemm(
torch::Tensor& output, // [num_tokens, num_experts]
const torch::Tensor& mat_a, // [num_tokens, hidden_dim]
const torch::Tensor& mat_b // [num_experts, hidden_dim]
) {
TORCH_CHECK(output.dim() == 2 && mat_a.dim() == 2 && mat_b.dim() == 2);
const int num_tokens = mat_a.size(0);
constexpr int num_experts = 256;
constexpr int hidden_dim = 7168;
TORCH_CHECK(mat_a.size(1) == mat_b.size(1), "mat_a and mat_b must have the same hidden_dim");
TORCH_CHECK(mat_a.size(1) == hidden_dim, "currently hidden_dim only supports 7168");
TORCH_CHECK(mat_b.size(0) == num_experts, "currently num_experts only supports 256");
TORCH_CHECK(
num_tokens >= 1 && num_tokens <= 16, "currently num_tokens must be between 1 and 16 for router_gemm");
TORCH_CHECK(mat_a.dtype() == torch::kBFloat16, "mat_a must be bf16");
TORCH_CHECK(mat_b.dtype() == torch::kBFloat16, "mat_b must be bf16");
TORCH_CHECK(
output.dtype() == torch::kFloat32 || output.dtype() == torch::kBFloat16, "output must be float32 or bf16");
auto const sm = getSMVersion();
TORCH_CHECK(sm >= 90, "required CUDA ARCH >= SM_90");
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (output.dtype() == torch::kFloat32) {
LoopUnroller<1, 16, num_experts, hidden_dim>::unroll_float_output(
num_tokens,
reinterpret_cast<float*>(output.mutable_data_ptr()),
reinterpret_cast<__nv_bfloat16 const*>(mat_a.data_ptr()),
reinterpret_cast<__nv_bfloat16 const*>(mat_b.data_ptr()),
stream);
} else if (output.dtype() == torch::kBFloat16) {
LoopUnroller<1, 16, num_experts, hidden_dim>::unroll_bf16_output(
num_tokens,
reinterpret_cast<__nv_bfloat16*>(output.mutable_data_ptr()),
reinterpret_cast<__nv_bfloat16 const*>(mat_a.data_ptr()),
reinterpret_cast<__nv_bfloat16 const*>(mat_b.data_ptr()),
stream);
}
}
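`LoopUnroller<1, 16, ...>` is a small compile-time recursion: each specialization compares the runtime `num_tokens` with its `kBegin`, either launching the kernel instantiated for exactly that many tokens or recursing with `kBegin + 1`, and the `kEnd = 16` specialization throws for anything outside the supported range. Conceptually it is just a table of sixteen pre-instantiated kernels keyed by `num_tokens`; an illustrative Python analogue (not part of the commit):

```python
# Illustrative analogue of the LoopUnroller dispatch in dsv3_router_gemm_entry.cu:
# one pre-built (specialized) kernel per supported num_tokens, selected at runtime.
def make_kernel(k):                 # stand-in for invokeRouterGemm*<..., k, 256, 7168>
    def kernel(output, mat_a, mat_b):
        ...                         # launch the kernel compiled for exactly k tokens
    return kernel

KERNELS = {k: make_kernel(k) for k in range(1, 17)}   # kBegin = 1 .. kEnd = 16

def dispatch(num_tokens, output, mat_a, mat_b):
    if num_tokens not in KERNELS:
        raise ValueError("Invalid num_tokens, only supports 1 to 16")
    KERNELS[num_tokens](output, mat_a, mat_b)
```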
@@ -46,7 +46,7 @@ __device__ __forceinline__ void bf16_uint4_to_float8(uint4 const& vec, float* ds
}
template <typename T, int kBlockSize, int VPT, int kNumTokens, int kNumExperts, int kHiddenDim>
__global__ __launch_bounds__(128, 1) void router_gemm_kernel(float* out, T const* mat_a, T const* mat_b) {
__global__ __launch_bounds__(128, 1) void router_gemm_kernel_float_output(float* out, T const* mat_a, T const* mat_b) {
// Each block handles one expert column
int const n_idx = blockIdx.x;
int const tid = threadIdx.x;
@@ -163,7 +163,7 @@ __global__ __launch_bounds__(128, 1) void router_gemm_kernel(float* out, T const
}
template <typename T, int kNumTokens, int kNumExperts, int kHiddenDim>
void invokeRouterGemm(float* output, T const* mat_a, T const* mat_b, cudaStream_t stream) {
void invokeRouterGemmFloatOutput(float* output, T const* mat_a, T const* mat_b, cudaStream_t stream) {
constexpr int VPT = 16 / sizeof(T);
constexpr int kBlockSize = 128;
cudaLaunchConfig_t config;
@@ -177,110 +177,57 @@ void invokeRouterGemm(float* output, T const* mat_a, T const* mat_b, cudaStream_
config.numAttrs = 1;
config.attrs = attrs;
cudaLaunchKernelEx(
&config, router_gemm_kernel<T, kBlockSize, VPT, kNumTokens, kNumExperts, kHiddenDim>, output, mat_a, mat_b);
&config,
router_gemm_kernel_float_output<T, kBlockSize, VPT, kNumTokens, kNumExperts, kHiddenDim>,
output,
mat_a,
mat_b);
}
template void
invokeRouterGemm<__nv_bfloat16, 1, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 1, 256, 7168>(
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void
invokeRouterGemm<__nv_bfloat16, 2, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 2, 256, 7168>(
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void
invokeRouterGemm<__nv_bfloat16, 3, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 3, 256, 7168>(
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void
invokeRouterGemm<__nv_bfloat16, 4, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 4, 256, 7168>(
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void
invokeRouterGemm<__nv_bfloat16, 5, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 5, 256, 7168>(
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void
invokeRouterGemm<__nv_bfloat16, 6, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 6, 256, 7168>(
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void
invokeRouterGemm<__nv_bfloat16, 7, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 7, 256, 7168>(
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void
invokeRouterGemm<__nv_bfloat16, 8, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 8, 256, 7168>(
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void
invokeRouterGemm<__nv_bfloat16, 9, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 9, 256, 7168>(
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void
invokeRouterGemm<__nv_bfloat16, 10, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 10, 256, 7168>(
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void
invokeRouterGemm<__nv_bfloat16, 11, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 11, 256, 7168>(
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void
invokeRouterGemm<__nv_bfloat16, 12, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 12, 256, 7168>(
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void
invokeRouterGemm<__nv_bfloat16, 13, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 13, 256, 7168>(
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void
invokeRouterGemm<__nv_bfloat16, 14, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 14, 256, 7168>(
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void
invokeRouterGemm<__nv_bfloat16, 15, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 15, 256, 7168>(
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void
invokeRouterGemm<__nv_bfloat16, 16, 256, 7168>(float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template <int kBegin, int kEnd, int kNumExperts, int kHiddenDim>
struct LoopUnroller {
static void
unroll(int num_tokens, float* output, __nv_bfloat16 const* input, __nv_bfloat16 const* weights, cudaStream_t stream) {
if (num_tokens == kBegin) {
invokeRouterGemm<__nv_bfloat16, kBegin, kNumExperts, kHiddenDim>(output, input, weights, stream);
} else {
LoopUnroller<kBegin + 1, kEnd, kNumExperts, kHiddenDim>::unroll(num_tokens, output, input, weights, stream);
}
}
};
template <int kEnd, int kNumExperts, int kHiddenDim>
struct LoopUnroller<kEnd, kEnd, kNumExperts, kHiddenDim> {
static void
unroll(int num_tokens, float* output, __nv_bfloat16 const* input, __nv_bfloat16 const* weights, cudaStream_t stream) {
if (num_tokens == kEnd) {
invokeRouterGemm<__nv_bfloat16, kEnd, kNumExperts, kHiddenDim>(output, input, weights, stream);
} else {
throw std::invalid_argument("Invalid num_tokens, only supports 1 to 16");
}
}
};
void dsv3_router_gemm(
torch::Tensor& output, // [num_tokens, num_experts]
const torch::Tensor& mat_a, // [num_tokens, hidden_dim]
const torch::Tensor& mat_b // [num_experts, hidden_dim]
) {
TORCH_CHECK(output.dim() == 2 && mat_a.dim() == 2 && mat_b.dim() == 2);
const int num_tokens = mat_a.size(0);
constexpr int num_experts = 256;
constexpr int hidden_dim = 7168;
TORCH_CHECK(mat_a.size(1) == mat_b.size(1), "mat_a and mat_b must have the same hidden_dim");
TORCH_CHECK(mat_a.size(1) == hidden_dim, "currently hidden_dim only supports 7168");
TORCH_CHECK(mat_b.size(0) == num_experts, "currently num_experts only supports 256");
TORCH_CHECK(
num_tokens >= 1 && num_tokens <= 16, "currently num_tokens must be less than or equal to 16 for router_gemm");
TORCH_CHECK(mat_a.dtype() == torch::kBFloat16, "mat_a must be bf16");
TORCH_CHECK(mat_b.dtype() == torch::kBFloat16, "mat_b must be bf16");
TORCH_CHECK(output.dtype() == torch::kFloat32, "output must be float32");
auto const sm = getSMVersion();
TORCH_CHECK(sm >= 90, "required CUDA ARCH >= SM_90");
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
LoopUnroller<1, 16, num_experts, hidden_dim>::unroll(
num_tokens,
reinterpret_cast<float*>(output.mutable_data_ptr()),
reinterpret_cast<__nv_bfloat16 const*>(mat_a.data_ptr()),
reinterpret_cast<__nv_bfloat16 const*>(mat_b.data_ptr()),
stream);
}
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 16, 256, 7168>(
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
@@ -262,12 +262,13 @@ def qserve_w4a8_per_group_gemm(
def dsv3_router_gemm(
    hidden_states: torch.Tensor,
    router_weights: torch.Tensor,
    out_dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    output = torch.empty(
        hidden_states.shape[0],
        router_weights.shape[0],
        device=hidden_states.device,
        dtype=torch.float32,
        dtype=out_dtype,
    )
    torch.ops.sgl_kernel.dsv3_router_gemm(
        output,
......
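The wrapper only allocates the output in the requested `out_dtype` and forwards to the registered op, which fills the preallocated tensor in place. Calling the op directly looks roughly like this (a sketch; it assumes that importing `sgl_kernel` registers `torch.ops.sgl_kernel.dsv3_router_gemm`, which is how the wrapper above reaches it):

```python
# Sketch of the underlying op call made by the wrapper above; assumes the
# sgl_kernel import loads the extension and registers the torch op.
import torch
import sgl_kernel  # noqa: F401

mat_a = torch.randn(8, 7168, dtype=torch.bfloat16, device="cuda")    # [num_tokens, hidden_dim]
mat_b = torch.randn(256, 7168, dtype=torch.bfloat16, device="cuda")  # [num_experts, hidden_dim]
out = torch.empty(8, 256, dtype=torch.bfloat16, device="cuda")       # or dtype=torch.float32
torch.ops.sgl_kernel.dsv3_router_gemm(out, mat_a, mat_b)
```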
@@ -15,17 +15,20 @@ def test_dsv3_router_gemm(num_tokens):
    mat_b = torch.randn(
        (num_experts, hidden_dim), dtype=torch.bfloat16, device="cuda"
    ).contiguous()
    output = torch.empty(
        (num_tokens, num_experts), dtype=torch.float32, device="cuda"
    ).contiguous()
    ref = F.linear(mat_a, mat_b).to(torch.float32)
    bf16_ref = F.linear(mat_a, mat_b)
    float_ref = bf16_ref.to(torch.float32)

    bf16_output = dsv3_router_gemm(mat_a, mat_b, out_dtype=torch.bfloat16)
    float_output = dsv3_router_gemm(mat_a, mat_b, out_dtype=torch.float32)
    output = dsv3_router_gemm(mat_a, mat_b)

    assert torch.allclose(
        bf16_output, bf16_ref, rtol=1e-2, atol=1e-3
    ), "Router GEMM output in bf16 dtype mismatch with torch.nn.functional.linear reference"
    assert torch.allclose(
        output, ref, rtol=1e-2, atol=1e-3
    ), "Router GEMM output mismatch with torch.nn.functional.linear reference"
        float_output, float_ref, rtol=1e-2, atol=1e-3
    ), "Router GEMM output in float32 dtype mismatch with torch.nn.functional.linear reference"


if __name__ == "__main__":
......