"git@developer.sourcefind.cn:orangecat/ollama.git" did not exist on "1deafd825449008ddb2c0b29469d1c56377dcadb"
Unverified Commit 6bdd2786 authored by Peter Pan's avatar Peter Pan Committed by GitHub
Browse files

[Kimi K2] dsv3_router_gemm supports NUM_EXPERTS == 384 (#8013)

parent 46e9d1c7
...@@ -13,9 +13,14 @@ from sgl_kernel import dsv3_router_gemm ...@@ -13,9 +13,14 @@ from sgl_kernel import dsv3_router_gemm
x_vals=[i + 1 for i in range(16)], x_vals=[i + 1 for i in range(16)],
x_log=False, x_log=False,
line_arg="impl", line_arg="impl",
line_vals=["torch", "sgl-kernel"], line_vals=["torch-256", "sgl-kernel-256", "torch-384", "sgl-kernel-384"],
line_names=["torch", "dsv3_router_gemm"], line_names=[
styles=[("blue", "-"), ("orange", "-")], "torch-256",
"dsv3_router_gemm-256",
"torch-384",
"dsv3_router_gemm-384",
],
styles=[("blue", "-"), ("orange", "-"), ("green", "-"), ("red", "-")],
ylabel="TFLOPs", ylabel="TFLOPs",
plot_name="input-bf16-output-bf16 dsv3 router gemm throughput", plot_name="input-bf16-output-bf16 dsv3 router gemm throughput",
args={}, args={},
...@@ -23,19 +28,26 @@ from sgl_kernel import dsv3_router_gemm ...@@ -23,19 +28,26 @@ from sgl_kernel import dsv3_router_gemm
) )
def benchmark_bf16_output(num_tokens, impl): def benchmark_bf16_output(num_tokens, impl):
# M: num_tokens, K: hidden_dim, N: num_experts # M: num_tokens, K: hidden_dim, N: num_experts
M, K, N = num_tokens, 7168, 256 M, K = num_tokens, 7168
if impl == "torch-256" or impl == "sgl-kernel-256":
N = 256
elif impl == "torch-384" or impl == "sgl-kernel-384":
N = 384
else:
raise ValueError(f"Unknown impl: {impl}")
mat_a = torch.randn((M, K), dtype=torch.bfloat16, device="cuda").contiguous() mat_a = torch.randn((M, K), dtype=torch.bfloat16, device="cuda").contiguous()
mat_b = torch.randn((N, K), dtype=torch.bfloat16, device="cuda").contiguous() mat_b = torch.randn((N, K), dtype=torch.bfloat16, device="cuda").contiguous()
quantiles = [0.5, 0.2, 0.8] quantiles = [0.5, 0.2, 0.8]
if impl == "torch": if impl == "torch-256" or impl == "torch-384":
def runner(): def runner():
F.linear(mat_a, mat_b) F.linear(mat_a, mat_b)
elif impl == "sgl-kernel": elif impl == "sgl-kernel-256" or impl == "sgl-kernel-384":
def runner(): def runner():
dsv3_router_gemm(mat_a, mat_b, out_dtype=torch.bfloat16) dsv3_router_gemm(mat_a, mat_b, out_dtype=torch.bfloat16)
...@@ -55,9 +67,14 @@ def benchmark_bf16_output(num_tokens, impl): ...@@ -55,9 +67,14 @@ def benchmark_bf16_output(num_tokens, impl):
x_vals=[i + 1 for i in range(16)], x_vals=[i + 1 for i in range(16)],
x_log=False, x_log=False,
line_arg="impl", line_arg="impl",
line_vals=["torch", "sgl-kernel"], line_vals=["torch-256", "sgl-kernel-256", "torch-384", "sgl-kernel-384"],
line_names=["torch", "dsv3_router_gemm"], line_names=[
styles=[("blue", "-"), ("orange", "-")], "torch-256",
"dsv3_router_gemm-256",
"torch-384",
"dsv3_router_gemm-384",
],
styles=[("blue", "-"), ("orange", "-"), ("green", "-"), ("red", "-")],
ylabel="TFLOPs", ylabel="TFLOPs",
plot_name="input-bf16-output-fp32 dsv3 router gemm throughput", plot_name="input-bf16-output-fp32 dsv3 router gemm throughput",
args={}, args={},
...@@ -65,19 +82,26 @@ def benchmark_bf16_output(num_tokens, impl): ...@@ -65,19 +82,26 @@ def benchmark_bf16_output(num_tokens, impl):
) )
def benchmark_float_output(num_tokens, impl): def benchmark_float_output(num_tokens, impl):
# M: num_tokens, K: hidden_dim, N: num_experts # M: num_tokens, K: hidden_dim, N: num_experts
M, K, N = num_tokens, 7168, 256 M, K = num_tokens, 7168
if impl == "torch-256" or impl == "sgl-kernel-256":
N = 256
elif impl == "torch-384" or impl == "sgl-kernel-384":
N = 384
else:
raise ValueError(f"Unknown impl: {impl}")
mat_a = torch.randn((M, K), dtype=torch.bfloat16, device="cuda").contiguous() mat_a = torch.randn((M, K), dtype=torch.bfloat16, device="cuda").contiguous()
mat_b = torch.randn((N, K), dtype=torch.bfloat16, device="cuda").contiguous() mat_b = torch.randn((N, K), dtype=torch.bfloat16, device="cuda").contiguous()
quantiles = [0.5, 0.2, 0.8] quantiles = [0.5, 0.2, 0.8]
if impl == "torch": if impl == "torch-256" or impl == "torch-384":
def runner(): def runner():
F.linear(mat_a, mat_b).to(torch.float32) F.linear(mat_a, mat_b).to(torch.float32)
elif impl == "sgl-kernel": elif impl == "sgl-kernel-256" or impl == "sgl-kernel-384":
def runner(): def runner():
dsv3_router_gemm(mat_a, mat_b, out_dtype=torch.float32) dsv3_router_gemm(mat_a, mat_b, out_dtype=torch.float32)
......
...@@ -185,6 +185,7 @@ void invokeRouterGemmBf16Output(__nv_bfloat16* output, T const* mat_a, T const* ...@@ -185,6 +185,7 @@ void invokeRouterGemmBf16Output(__nv_bfloat16* output, T const* mat_a, T const*
mat_b); mat_b);
} }
// Template instantiations for DEFAULT_NUM_EXPERTS experts
template void invokeRouterGemmBf16Output<__nv_bfloat16, 1, 256, 7168>( template void invokeRouterGemmBf16Output<__nv_bfloat16, 1, 256, 7168>(
__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
...@@ -232,3 +233,52 @@ template void invokeRouterGemmBf16Output<__nv_bfloat16, 15, 256, 7168>( ...@@ -232,3 +233,52 @@ template void invokeRouterGemmBf16Output<__nv_bfloat16, 15, 256, 7168>(
template void invokeRouterGemmBf16Output<__nv_bfloat16, 16, 256, 7168>( template void invokeRouterGemmBf16Output<__nv_bfloat16, 16, 256, 7168>(
__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); __nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
// Template instantiations for KIMI_K2_NUM_EXPERTS experts
template void invokeRouterGemmBf16Output<__nv_bfloat16, 1, 384, 7168>(
__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 2, 384, 7168>(
__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 3, 384, 7168>(
__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 4, 384, 7168>(
__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 5, 384, 7168>(
__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 6, 384, 7168>(
__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 7, 384, 7168>(
__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 8, 384, 7168>(
__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 9, 384, 7168>(
__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 10, 384, 7168>(
__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 11, 384, 7168>(
__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 12, 384, 7168>(
__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 13, 384, 7168>(
__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 14, 384, 7168>(
__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 15, 384, 7168>(
__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmBf16Output<__nv_bfloat16, 16, 384, 7168>(
__nv_bfloat16*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
...@@ -25,6 +25,10 @@ ...@@ -25,6 +25,10 @@
#include "cuda_runtime.h" #include "cuda_runtime.h"
#include "utils.h" #include "utils.h"
static constexpr int DEFAULT_NUM_EXPERTS = 256;
static constexpr int KIMI_K2_NUM_EXPERTS = 384;
static constexpr int DEFAULT_HIDDEN_DIM = 7168;
template <typename T, int kNumTokens, int kNumExperts, int kHiddenDim> template <typename T, int kNumTokens, int kNumExperts, int kHiddenDim>
void invokeRouterGemmFloatOutput(float* output, T const* mat_a, T const* mat_b, cudaStream_t stream); void invokeRouterGemmFloatOutput(float* output, T const* mat_a, T const* mat_b, cudaStream_t stream);
...@@ -91,12 +95,24 @@ void dsv3_router_gemm( ...@@ -91,12 +95,24 @@ void dsv3_router_gemm(
TORCH_CHECK(output.dim() == 2 && mat_a.dim() == 2 && mat_b.dim() == 2); TORCH_CHECK(output.dim() == 2 && mat_a.dim() == 2 && mat_b.dim() == 2);
const int num_tokens = mat_a.size(0); const int num_tokens = mat_a.size(0);
constexpr int num_experts = 256; const int num_experts = mat_b.size(0);
constexpr int hidden_dim = 7168; const int hidden_dim = mat_a.size(1);
TORCH_CHECK(mat_a.size(1) == mat_b.size(1), "mat_a and mat_b must have the same hidden_dim"); TORCH_CHECK(mat_a.size(1) == mat_b.size(1), "mat_a and mat_b must have the same hidden_dim");
TORCH_CHECK(mat_a.size(1) == hidden_dim, "currently hidden_dim only supports 7168"); TORCH_CHECK(
TORCH_CHECK(mat_b.size(0) == num_experts, "currently num_experts only supports 256"); hidden_dim == DEFAULT_HIDDEN_DIM,
"Expected hidden_dim=",
DEFAULT_HIDDEN_DIM,
", but got hidden_dim=",
hidden_dim);
TORCH_CHECK(
num_experts == DEFAULT_NUM_EXPERTS || num_experts == KIMI_K2_NUM_EXPERTS,
"Expected num_experts=",
DEFAULT_NUM_EXPERTS,
" or num_experts=",
KIMI_K2_NUM_EXPERTS,
", but got num_experts=",
num_experts);
TORCH_CHECK( TORCH_CHECK(
num_tokens >= 1 && num_tokens <= 16, "currently num_tokens must be less than or equal to 16 for router_gemm"); num_tokens >= 1 && num_tokens <= 16, "currently num_tokens must be less than or equal to 16 for router_gemm");
TORCH_CHECK(mat_a.dtype() == torch::kBFloat16, "mat_a must be bf16"); TORCH_CHECK(mat_a.dtype() == torch::kBFloat16, "mat_a must be bf16");
...@@ -110,18 +126,36 @@ void dsv3_router_gemm( ...@@ -110,18 +126,36 @@ void dsv3_router_gemm(
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (output.dtype() == torch::kFloat32) { if (output.dtype() == torch::kFloat32) {
LoopUnroller<1, 16, num_experts, hidden_dim>::unroll_float_output( if (num_experts == DEFAULT_NUM_EXPERTS) {
num_tokens, LoopUnroller<1, 16, DEFAULT_NUM_EXPERTS, DEFAULT_HIDDEN_DIM>::unroll_float_output(
reinterpret_cast<float*>(output.mutable_data_ptr()), num_tokens,
reinterpret_cast<__nv_bfloat16 const*>(mat_a.data_ptr()), reinterpret_cast<float*>(output.mutable_data_ptr()),
reinterpret_cast<__nv_bfloat16 const*>(mat_b.data_ptr()), reinterpret_cast<__nv_bfloat16 const*>(mat_a.data_ptr()),
stream); reinterpret_cast<__nv_bfloat16 const*>(mat_b.data_ptr()),
stream);
} else if (num_experts == KIMI_K2_NUM_EXPERTS) {
LoopUnroller<1, 16, KIMI_K2_NUM_EXPERTS, DEFAULT_HIDDEN_DIM>::unroll_float_output(
num_tokens,
reinterpret_cast<float*>(output.mutable_data_ptr()),
reinterpret_cast<__nv_bfloat16 const*>(mat_a.data_ptr()),
reinterpret_cast<__nv_bfloat16 const*>(mat_b.data_ptr()),
stream);
}
} else if (output.dtype() == torch::kBFloat16) { } else if (output.dtype() == torch::kBFloat16) {
LoopUnroller<1, 16, num_experts, hidden_dim>::unroll_bf16_output( if (num_experts == DEFAULT_NUM_EXPERTS) {
num_tokens, LoopUnroller<1, 16, DEFAULT_NUM_EXPERTS, DEFAULT_HIDDEN_DIM>::unroll_bf16_output(
reinterpret_cast<__nv_bfloat16*>(output.mutable_data_ptr()), num_tokens,
reinterpret_cast<__nv_bfloat16 const*>(mat_a.data_ptr()), reinterpret_cast<__nv_bfloat16*>(output.mutable_data_ptr()),
reinterpret_cast<__nv_bfloat16 const*>(mat_b.data_ptr()), reinterpret_cast<__nv_bfloat16 const*>(mat_a.data_ptr()),
stream); reinterpret_cast<__nv_bfloat16 const*>(mat_b.data_ptr()),
stream);
} else if (num_experts == KIMI_K2_NUM_EXPERTS) {
LoopUnroller<1, 16, KIMI_K2_NUM_EXPERTS, DEFAULT_HIDDEN_DIM>::unroll_bf16_output(
num_tokens,
reinterpret_cast<__nv_bfloat16*>(output.mutable_data_ptr()),
reinterpret_cast<__nv_bfloat16 const*>(mat_a.data_ptr()),
reinterpret_cast<__nv_bfloat16 const*>(mat_b.data_ptr()),
stream);
}
} }
} }
...@@ -184,6 +184,7 @@ void invokeRouterGemmFloatOutput(float* output, T const* mat_a, T const* mat_b, ...@@ -184,6 +184,7 @@ void invokeRouterGemmFloatOutput(float* output, T const* mat_a, T const* mat_b,
mat_b); mat_b);
} }
// Template instantiations for DEFAULT_NUM_EXPERTS experts
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 1, 256, 7168>( template void invokeRouterGemmFloatOutput<__nv_bfloat16, 1, 256, 7168>(
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
...@@ -231,3 +232,52 @@ template void invokeRouterGemmFloatOutput<__nv_bfloat16, 15, 256, 7168>( ...@@ -231,3 +232,52 @@ template void invokeRouterGemmFloatOutput<__nv_bfloat16, 15, 256, 7168>(
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 16, 256, 7168>( template void invokeRouterGemmFloatOutput<__nv_bfloat16, 16, 256, 7168>(
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t); float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
// Template instantiations for KIMI_K2_NUM_EXPERTS experts
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 1, 384, 7168>(
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 2, 384, 7168>(
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 3, 384, 7168>(
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 4, 384, 7168>(
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 5, 384, 7168>(
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 6, 384, 7168>(
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 7, 384, 7168>(
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 8, 384, 7168>(
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 9, 384, 7168>(
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 10, 384, 7168>(
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 11, 384, 7168>(
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 12, 384, 7168>(
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 13, 384, 7168>(
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 14, 384, 7168>(
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 15, 384, 7168>(
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
template void invokeRouterGemmFloatOutput<__nv_bfloat16, 16, 384, 7168>(
float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t);
...@@ -5,8 +5,8 @@ from sgl_kernel import dsv3_router_gemm ...@@ -5,8 +5,8 @@ from sgl_kernel import dsv3_router_gemm
@pytest.mark.parametrize("num_tokens", [i + 1 for i in range(16)]) @pytest.mark.parametrize("num_tokens", [i + 1 for i in range(16)])
def test_dsv3_router_gemm(num_tokens): @pytest.mark.parametrize("num_experts", [256, 384])
num_experts = 256 def test_dsv3_router_gemm(num_tokens, num_experts):
hidden_dim = 7168 hidden_dim = 7168
mat_a = torch.randn( mat_a = torch.randn(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment