Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b1169d7b
Unverified
Commit
b1169d7b
authored
Mar 18, 2026
by
Xin Yang
Committed by
GitHub
Mar 18, 2026
Browse files
[Kernel] Add gpt-oss Router GEMM kernel (#37205)
Signed-off-by:
Xin Yang
<
xyangx@amazon.com
>
parent
17808394
Changes
13
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
875 additions
and
13 deletions
+875
-13
CMakeLists.txt
CMakeLists.txt
+1
-0
benchmarks/kernels/benchmark_router_gemm.py
benchmarks/kernels/benchmark_router_gemm.py
+134
-0
csrc/moe/gpt_oss_router_gemm.cu
csrc/moe/gpt_oss_router_gemm.cu
+144
-0
csrc/moe/gpt_oss_router_gemm.cuh
csrc/moe/gpt_oss_router_gemm.cuh
+447
-0
csrc/moe/moe_ops.h
csrc/moe/moe_ops.h
+4
-0
csrc/moe/torch_bindings.cpp
csrc/moe/torch_bindings.cpp
+6
-0
tests/kernels/moe/test_router_gemm.py
tests/kernels/moe/test_router_gemm.py
+37
-0
vllm/_custom_ops.py
vllm/_custom_ops.py
+13
-0
vllm/lora/layers/__init__.py
vllm/lora/layers/__init__.py
+2
-0
vllm/lora/layers/gate_linear.py
vllm/lora/layers/gate_linear.py
+30
-0
vllm/lora/utils.py
vllm/lora/utils.py
+2
-0
vllm/model_executor/layers/fused_moe/router/gate_linear.py
vllm/model_executor/layers/fused_moe/router/gate_linear.py
+52
-6
vllm/model_executor/models/gpt_oss.py
vllm/model_executor/models/gpt_oss.py
+3
-7
No files found.
CMakeLists.txt
View file @
b1169d7b
...
@@ -999,6 +999,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
...
@@ -999,6 +999,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
list
(
APPEND VLLM_MOE_EXT_SRC
list
(
APPEND VLLM_MOE_EXT_SRC
"csrc/moe/moe_wna16.cu"
"csrc/moe/moe_wna16.cu"
"csrc/moe/grouped_topk_kernels.cu"
"csrc/moe/grouped_topk_kernels.cu"
"csrc/moe/gpt_oss_router_gemm.cu"
"csrc/moe/router_gemm.cu"
)
"csrc/moe/router_gemm.cu"
)
endif
()
endif
()
...
...
benchmarks/kernels/benchmark_router_gemm.py
0 → 100644
View file @
b1169d7b
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch
import
torch.nn.functional
as
F
from
vllm
import
_custom_ops
as
ops
from
vllm.platforms
import
current_platform
from
vllm.transformers_utils.config
import
get_config
from
vllm.triton_utils
import
triton
from
vllm.utils.argparse_utils
import
FlexibleArgumentParser
# Dimensions supported by the DSV3 specialized kernel.
# The benchmark only calls ops.dsv3_router_gemm when both the expert count
# and the hidden size of the model appear in these lists.
DSV3_SUPPORTED_NUM_EXPERTS = [256, 384]
DSV3_SUPPORTED_HIDDEN_SIZES = [7168]

# Dimensions supported by the gpt-oss specialized kernel.
# The benchmark only calls ops.gpt_oss_router_gemm when both the expert count
# and the hidden size of the model appear in these lists.
GPT_OSS_SUPPORTED_NUM_EXPERTS = [32, 128]
GPT_OSS_SUPPORTED_HIDDEN_SIZES = [2880]
def get_batch_size_range(max_batch_size, num_powers=14):
    """Return the power-of-two batch sizes to benchmark.

    Args:
        max_batch_size: Inclusive upper bound on the returned batch sizes.
        num_powers: Number of candidate exponents to consider; candidates are
            ``2**0 .. 2**(num_powers - 1)``. The default of 14 preserves the
            previous fixed ceiling of 8192.

    Returns:
        Ascending list of the candidate powers of two that are
        ``<= max_batch_size`` (empty when ``max_batch_size < 1``).
    """
    return [1 << exp for exp in range(num_powers) if (1 << exp) <= max_batch_size]
def get_model_params(config):
    """Return ``(num_experts, hidden_size)`` for a supported MoE model config.

    Only the first entry of ``config.architectures`` is consulted.

    Args:
        config: Model config object exposing ``architectures``,
            ``hidden_size`` and the architecture-specific expert-count
            attribute.

    Returns:
        Tuple ``(num_experts, hidden_size)``.

    Raises:
        ValueError: If the architecture list is missing, empty, or names an
            architecture this benchmark does not know.
    """
    # Config attribute that holds the routed-expert count, per architecture.
    num_experts_attr = {
        "DeepseekV2ForCausalLM": "n_routed_experts",
        "DeepseekV3ForCausalLM": "n_routed_experts",
        "DeepseekV32ForCausalLM": "n_routed_experts",
        "GptOssForCausalLM": "num_local_experts",
    }

    architectures = getattr(config, "architectures", None)
    # A missing/empty list is reported as "unsupported" instead of surfacing
    # as an IndexError/TypeError from the subscript.
    arch = architectures[0] if architectures else None
    if arch not in num_experts_attr:
        raise ValueError(f"Unsupported architecture: {architectures}")

    num_experts = getattr(config, num_experts_attr[arch])
    hidden_size = config.hidden_size
    return num_experts, hidden_size
def get_benchmark(model, max_batch_size, trust_remote_code):
    """Build a Triton perf-report benchmark for the router GEMM of ``model``.

    The returned object sweeps power-of-two batch sizes up to
    ``max_batch_size`` and compares ``F.linear`` ("torch") against the
    specialized vLLM router GEMM op ("vllm").

    Args:
        model: Model name or path passed to ``get_config``.
        max_batch_size: Inclusive upper bound of the batch-size sweep.
        trust_remote_code: Forwarded to ``get_config``.

    Returns:
        The decorated benchmark; call ``.run(...)`` on it to execute.

    Raises:
        ValueError: From ``get_model_params`` when the architecture is not
            supported (raised here, before any benchmark point runs).
    """
    # The model config and the device capability do not depend on the swept
    # batch size or the provider, so resolve them once instead of once per
    # benchmark point.
    config = get_config(model=model, trust_remote_code=trust_remote_code)
    num_experts, hidden_size = get_model_params(config)

    is_hopper_or_blackwell = current_platform.is_device_capability(
        90
    ) or current_platform.is_device_capability_family(100)
    allow_dsv3_router_gemm = (
        is_hopper_or_blackwell
        and num_experts in DSV3_SUPPORTED_NUM_EXPERTS
        and hidden_size in DSV3_SUPPORTED_HIDDEN_SIZES
    )
    allow_gpt_oss_router_gemm = (
        is_hopper_or_blackwell
        and num_experts in GPT_OSS_SUPPORTED_NUM_EXPERTS
        and hidden_size in GPT_OSS_SUPPORTED_HIDDEN_SIZES
    )
    # Only the gpt-oss kernel takes a bias; the torch baseline mirrors that so
    # both providers do the same amount of work.
    has_bias = bool(allow_gpt_oss_router_gemm)

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["batch_size"],
            x_vals=get_batch_size_range(max_batch_size),
            x_log=False,
            line_arg="provider",
            line_vals=[
                "torch",
                "vllm",
            ],
            line_names=["PyTorch", "vLLM"],
            styles=([("blue", "-"), ("red", "-")]),
            ylabel="TFLOPs",
            plot_name=f"{model} router gemm throughput",
            args={},
        )
    )
    def benchmark(batch_size, provider):
        mat_a = torch.randn(
            (batch_size, hidden_size), dtype=torch.bfloat16, device="cuda"
        ).contiguous()
        mat_b = torch.randn(
            (num_experts, hidden_size), dtype=torch.bfloat16, device="cuda"
        ).contiguous()
        bias = torch.randn(
            num_experts, dtype=torch.bfloat16, device="cuda"
        ).contiguous()

        quantiles = [0.5, 0.2, 0.8]

        # Pick the callable up front so that an unsupported configuration is
        # reported before timing starts rather than from inside the
        # graph-captured runner.
        if provider == "torch":
            linear_bias = bias if has_bias else None

            def runner():
                F.linear(mat_a, mat_b, linear_bias)

        elif provider == "vllm":
            if allow_dsv3_router_gemm:

                def runner():
                    ops.dsv3_router_gemm(mat_a, mat_b, torch.bfloat16)

            elif allow_gpt_oss_router_gemm:

                def runner():
                    ops.gpt_oss_router_gemm(mat_a, mat_b, bias)

            else:
                raise ValueError("Unsupported router gemm")
        else:
            raise ValueError(f"Unknown provider: {provider}")

        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            runner, quantiles=quantiles
        )

        def tflops(t_ms):
            # One multiply-add per (token, hidden, expert) triple.
            flops = 2 * batch_size * hidden_size * num_experts
            return flops / (t_ms * 1e-3) / 1e12

        # Throughput is inversely related to time, so the slower quantile
        # (max_ms) yields the lower throughput bound and vice versa.
        return tflops(ms), tflops(max_ms), tflops(min_ms)

    return benchmark
if __name__ == "__main__":
    # Script entry point: parse the CLI options, build the benchmark for the
    # requested model and print the measured throughput table.
    parser = FlexibleArgumentParser()
    for flag, options in (
        ("--model", {"type": str, "default": "openai/gpt-oss-20b"}),
        ("--max-batch-size", {"default": 16, "type": int}),
        ("--trust-remote-code", {"action": "store_true"}),
    ):
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    # Get the benchmark function
    benchmark = get_benchmark(
        args.model,
        args.max_batch_size,
        args.trust_remote_code,
    )
    # Run performance benchmark
    benchmark.run(print_data=True)
csrc/moe/gpt_oss_router_gemm.cu
0 → 100644
View file @
b1169d7b
/*
* Adapted from
* https://github.com/NVIDIA/TensorRT-LLM/blob/v1.3.0rc7/cpp/tensorrt_llm/kernels/tinygemm2/tinygemm2_cuda.cu
* Copyright (c) 2025, The vLLM team.
* SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.
* All rights reserved. SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <torch/all.h>
#include "gpt_oss_router_gemm.cuh"
// Host-side launcher for gpt_oss_router_gemm_kernel.
//
// Encodes one 2-D bf16 TMA tensor map for the weights (gB, described as
// [output_features, input_features]) and one for the activations (gA,
// described as [batch_size, input_features]), both with 128-byte swizzle and
// a row stride of input_features * sizeof(bf16), then launches the kernel on
// `stream` with programmatic stream serialization allowed.
//
//   gA    activations, read through activation_map
//   gB    weights, read through weight_map
//   gC    output buffer written by the kernel
//   bias  one bf16 per output feature
//
// Errors from tensor-map encoding are reported through TORCH_CHECK; errors
// from the CUDA runtime calls are reported through gpuErrChk.
void launch_gpt_oss_router_gemm(__nv_bfloat16* gA, __nv_bfloat16* gB,
                                __nv_bfloat16* gC, __nv_bfloat16* bias,
                                int batch_size, int output_features,
                                int input_features, cudaStream_t stream) {
  static int const WARP_TILE_M = 16;
  static int const TILE_M = WARP_TILE_M;
  static int const TILE_N = 8;
  static int const TILE_K = 64;
  static int const STAGES = 16;
  static int const STAGE_UNROLL = 4;
  static bool const PROFILE = false;

  CUtensorMap weight_map{};
  CUtensorMap activation_map{};

  constexpr uint32_t rank = 2;
  // Dimension 0 is the contiguous (K) dimension.
  uint64_t size[rank] = {(uint64_t)input_features, (uint64_t)output_features};
  uint64_t stride[rank - 1] = {input_features * sizeof(__nv_bfloat16)};
  uint32_t box_size[rank] = {TILE_K, TILE_M};
  uint32_t elem_stride[rank] = {1, 1};

  CUresult res = cuTensorMapEncodeTiled(
      &weight_map, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, rank,
      gB, size, stride, box_size, elem_stride,
      CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,
      CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B,
      CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE,
      CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
  TORCH_CHECK(res == CUDA_SUCCESS,
              "cuTensorMapEncodeTiled failed for weight_map, error code=",
              static_cast<int>(res));

  // Re-use the same descriptors for the activations: only the outer extent
  // and the box height differ.
  size[1] = batch_size;
  box_size[1] = TILE_N;

  res = cuTensorMapEncodeTiled(
      &activation_map, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
      rank, gA, size, stride, box_size, elem_stride,
      CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,
      CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B,
      CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE,
      CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
  TORCH_CHECK(res == CUDA_SUCCESS,
              "cuTensorMapEncodeTiled failed for activation_map, error code=",
              static_cast<int>(res));

  // Dynamic shared memory: STAGES * STAGE_UNROLL slots, each holding one
  // weight tile (TILE_M x TILE_K) and one activation tile (TILE_N x TILE_K).
  int smem_size = STAGES * STAGE_UNROLL *
                  (TILE_M * TILE_K * sizeof(__nv_bfloat16) +
                   TILE_N * TILE_K * sizeof(__nv_bfloat16));

  gpuErrChk(cudaFuncSetAttribute(
      gpt_oss_router_gemm_kernel<WARP_TILE_M, TILE_M, TILE_N, TILE_K, STAGES,
                                 STAGE_UNROLL, PROFILE>,
      cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));

  // One block per (TILE_M output features, TILE_N tokens) tile.
  int tiles_m = (output_features + TILE_M - 1) / TILE_M;
  int tiles_n = (batch_size + TILE_N - 1) / TILE_N;

  dim3 grid(tiles_m, tiles_n);
  dim3 block(384);  // must match the kernel's __launch_bounds__(384, 1)

  cudaLaunchConfig_t config = {};
  cudaLaunchAttribute attrs[1];
  config.gridDim = grid;
  config.blockDim = block;
  config.dynamicSmemBytes = smem_size;
  config.stream = stream;
  config.attrs = attrs;
  attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
  attrs[0].val.programmaticStreamSerializationAllowed = 1;
  config.numAttrs = 1;

  // The kernel signature is (output, weights, activations, bias, M, N, K,
  // ...), so the weight pointer (gB) goes before the activation pointer (gA).
  // The launch status is checked so a failed launch is not silently dropped.
  gpuErrChk(cudaLaunchKernelEx(
      &config,
      &gpt_oss_router_gemm_kernel<WARP_TILE_M, TILE_M, TILE_N, TILE_K, STAGES,
                                  STAGE_UNROLL, PROFILE>,
      gC, gB, gA, bias, output_features, batch_size, input_features,
      weight_map, activation_map, nullptr));
}
// Extracts the problem sizes and raw bf16 pointers from the tensors and
// forwards them to launch_gpt_oss_router_gemm on the current CUDA stream.
// Throws std::invalid_argument when `input` is not bfloat16.
void gpt_oss_router_gemm_cuda_forward(torch::Tensor& output,
                                      torch::Tensor input,
                                      torch::Tensor weight,
                                      torch::Tensor bias) {
  auto const num_tokens = input.size(0);
  auto const hidden_dim = input.size(1);
  auto const num_outputs = weight.size(0);

  auto stream = at::cuda::getCurrentCUDAStream();

  bool const is_bf16 = input.scalar_type() == at::ScalarType::BFloat16;
  if (!is_bf16) {
    throw std::invalid_argument("Unsupported dtype, only supports bfloat16");
  }

  auto* input_ptr = (__nv_bfloat16*)input.data_ptr();
  auto* weight_ptr = (__nv_bfloat16*)weight.data_ptr();
  auto* output_ptr = (__nv_bfloat16*)output.mutable_data_ptr();
  auto* bias_ptr = (__nv_bfloat16*)bias.data_ptr();

  launch_gpt_oss_router_gemm(input_ptr, weight_ptr, output_ptr, bias_ptr,
                             num_tokens, num_outputs, hidden_dim, stream);
}
// Validating entry point for the gpt-oss router GEMM (registered as the
// `gpt_oss_router_gemm` op): output = input @ weight^T + bias.
//
//   output  [input.size(0), weight.size(0)], bfloat16, contiguous; written
//   input   2-D bfloat16, contiguous
//   weight  2-D bfloat16, contiguous, weight.size(1) == input.size(1)
//   bias    1-D bfloat16, contiguous, bias.size(0) == weight.size(0)
//
// All shape/dtype/layout violations are reported through TORCH_CHECK.
void gpt_oss_router_gemm(torch::Tensor& output, torch::Tensor input,
                         torch::Tensor weight, torch::Tensor bias) {
  TORCH_CHECK(input.dim() == 2, "input must be 2D");
  TORCH_CHECK(weight.dim() == 2, "weight must be 2D");
  TORCH_CHECK(bias.dim() == 1, "bias must be 1D");
  TORCH_CHECK(input.sizes()[1] == weight.sizes()[1],
              "input.size(1) must match weight.size(1)");
  TORCH_CHECK(weight.sizes()[0] == bias.sizes()[0],
              "weight.size(0) must match bias.size(0)");
  TORCH_CHECK(input.scalar_type() == at::ScalarType::BFloat16,
              "input tensor must be bfloat16");
  TORCH_CHECK(weight.scalar_type() == at::ScalarType::BFloat16,
              "weight tensor must be bfloat16");
  TORCH_CHECK(bias.scalar_type() == at::ScalarType::BFloat16,
              "bias tensor must be bfloat16");

  // The kernel writes through a raw pointer assuming a dense row-major
  // [tokens, outputs] bf16 buffer, so the output must be validated too.
  TORCH_CHECK(output.dim() == 2 && output.sizes()[0] == input.sizes()[0] &&
                  output.sizes()[1] == weight.sizes()[0],
              "output must have shape [input.size(0), weight.size(0)]");
  TORCH_CHECK(output.scalar_type() == at::ScalarType::BFloat16,
              "output tensor must be bfloat16");

  // The launcher describes input and weight to TMA with a row stride of
  // size(1) * sizeof(bf16) and hands raw pointers to the kernel, which is
  // only valid for contiguous tensors.
  TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
  TORCH_CHECK(weight.is_contiguous(), "weight must be contiguous");
  TORCH_CHECK(bias.is_contiguous(), "bias must be contiguous");
  TORCH_CHECK(output.is_contiguous(), "output must be contiguous");

  // TMA tensor maps require global strides that are multiples of 16 bytes,
  // i.e. a multiple of 8 bf16 elements per row.
  TORCH_CHECK(input.sizes()[1] % 8 == 0,
              "input.size(1) must be a multiple of 8");

  // Nothing to compute for an empty batch; a zero-sized grid / tensor map
  // would otherwise be rejected by the driver.
  if (input.sizes()[0] == 0) {
    return;
  }

  gpt_oss_router_gemm_cuda_forward(output, input, weight, bias);
}
csrc/moe/gpt_oss_router_gemm.cuh
0 → 100644
View file @
b1169d7b
/*
* Adapted from
* https://github.com/NVIDIA/TensorRT-LLM/blob/v1.3.0rc7/cpp/tensorrt_llm/kernels/tinygemm2/tinygemm2_kernel.cuh
* Copyright (c) 2025, The vLLM team.
* SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.
* All rights reserved. SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "cuda_bf16.h"
#include <stdint.h>
#include <stdio.h>
#include <vector>
#include "cuda_pipeline.h"
#include <cuda.h>
#include <cuda/barrier>
#include <cuda/std/utility>
#include <cuda_runtime.h>
using
barrier
=
cuda
::
barrier
<
cuda
::
thread_scope_block
>
;
namespace
cde
=
cuda
::
device
::
experimental
;
namespace
ptx
=
cuda
::
ptx
;
// Evaluates a CUDA runtime call and forwards its status, together with the
// call site, to gpuAssert. Wrapped in do { } while (0) so that
// `gpuErrChk(x);` is exactly one statement and stays correct inside an
// unbraced if/else.
#define gpuErrChk(ans)                    \
  do {                                    \
    gpuAssert((ans), __FILE__, __LINE__); \
  } while (0)
// Reports a failed CUDA runtime call: prints the error string with the
// originating file/line to stderr and, when `abort` is true, throws
// std::runtime_error carrying the same error string. A cudaSuccess status is
// a no-op.
inline void gpuAssert(cudaError_t code, char const* file, int line,
                      bool abort = true) {
  if (code == cudaSuccess) {
    return;
  }
  char const* message = cudaGetErrorString(code);
  fprintf(stderr, "GPUassert: %s %s %d\n", message, file, line);
  if (!abort) {
    return;
  }
  throw std::runtime_error(message);
}
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
// Returns the value of the PTX %globaltimer special register, read with a
// 64-bit move. Used by the kernel to timestamp Profile entries.
__device__ uint64_t gclock64() {
  unsigned long long int rv;
  asm volatile("mov.u64 %0, %%globaltimer;" : "=l"(rv));
  return rv;
}
// Issues a warp-wide `ldmatrix ... x1.m8n8 ... b16` from the shared-memory
// address `smem_ptr` and stores the single 32-bit result register into
// rv[0..1] (two bf16 values viewed as one int).
__device__ void ldmatrix(__nv_bfloat16 rv[2], uint32_t smem_ptr) {
  int dst;
  asm volatile("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];\n"
               : "=r"(dst)
               : "r"(smem_ptr));
  // Write the packed register back through an int view of the bf16 array.
  int* rvi = reinterpret_cast<int*>(&rv[0]);
  rvi[0] = dst;
}
// Issues a warp-wide `ldmatrix ... x2.m8n8 ... b16` from the shared-memory
// address `smem_ptr` and stores the two 32-bit result registers into
// rv[0..3] (four bf16 values viewed as two ints).
__device__ void ldmatrix2(__nv_bfloat16 rv[4], uint32_t smem_ptr) {
  int x, y;
  asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n"
               : "=r"(x), "=r"(y)
               : "r"(smem_ptr));

  // Write the packed registers back through an int view of the bf16 array.
  int* rvi = reinterpret_cast<int*>(&rv[0]);
  rvi[0] = x;
  rvi[1] = y;
}
// Issues a warp-wide `ldmatrix ... x4.m8n8 ... b16` from the shared-memory
// address `smem_ptr` and stores the four 32-bit result registers into
// rv[0..7] (eight bf16 values viewed as four ints).
__device__ void ldmatrix4(__nv_bfloat16 rv[8], uint32_t smem_ptr) {
  int x, y, z, w;
  asm volatile(
      "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];"
      : "=r"(x), "=r"(y), "=r"(z), "=r"(w)
      : "r"(smem_ptr));
  // Write the packed registers back through an int view of the bf16 array.
  int* rvi = reinterpret_cast<int*>(&rv[0]);
  rvi[0] = x;
  rvi[1] = y;
  rvi[2] = z;
  rvi[3] = w;
}
// Wraps the PTX instruction
// `mma.sync.aligned.m16n8k8.row.col.f32.bf16.bf16.f32`: bf16 operands `a`
// (two packed 32-bit registers) and `b` (one packed register), fp32
// accumulator input `c` and fp32 result `d` (four registers each). `d` and
// `c` may refer to the same array.
__device__ void HMMA_1688(float d[4], __nv_bfloat16 a[4], __nv_bfloat16 b[2],
                          float c[4]) {
  // View the bf16 fragments as the packed 32-bit registers the instruction
  // expects.
  uint32_t const* A = reinterpret_cast<uint32_t const*>(&a[0]);
  uint32_t const* B = reinterpret_cast<uint32_t const*>(&b[0]);
  float const* C = reinterpret_cast<float const*>(&c[0]);
  float* D = reinterpret_cast<float*>(&d[0]);

  asm volatile(
      "mma.sync.aligned.m16n8k8.row.col.f32.bf16.bf16.f32 "
      "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
      : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
      : "r"(A[0]), "r"(A[1]), "r"(B[0]), "f"(C[0]), "f"(C[1]), "f"(C[2]),
        "f"(C[3]));
}
// Wraps the PTX instruction
// `mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32`: bf16 operands `a`
// (four packed 32-bit registers) and `b` (two packed registers), fp32
// accumulator input `c` and fp32 result `d` (four registers each). The
// kernel calls it with `d` and `c` being the same array to accumulate in
// place.
__device__ void HMMA_16816(float d[4], __nv_bfloat16 a[8], __nv_bfloat16 b[4],
                           float c[4]) {
  // View the bf16 fragments as the packed 32-bit registers the instruction
  // expects.
  uint32_t const* A = reinterpret_cast<uint32_t const*>(&a[0]);
  uint32_t const* B = reinterpret_cast<uint32_t const*>(&b[0]);
  float const* C = reinterpret_cast<float const*>(&c[0]);
  float* D = reinterpret_cast<float*>(&d[0]);

  asm volatile(
      "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
      "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
      : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
      : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]),
        "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]));
}
// Spins on `mbarrier.try_wait.parity` for the shared-memory mbarrier at
// address `bar_ptr` until the wait for parity `phase` succeeds; does not
// return before that.
__device__ void bar_wait(uint32_t bar_ptr, int phase) {
  // The labels live inside a braced PTX block together with the predicate
  // register declaration; the loop retries try_wait until P1 is set.
  asm volatile(
      "{\n"
      ".reg .pred P1;\n"
      "LAB_WAIT:\n"
      "mbarrier.try_wait.parity.shared::cta.b64 P1, [%0], %1;\n"
      "@P1 bra.uni DONE;\n"
      "bra.uni LAB_WAIT;\n"
      "DONE:\n"
      "}\n" ::"r"(bar_ptr),
      "r"(phase));
}
// Performs a single `mbarrier.try_wait.parity` on the shared-memory mbarrier
// at address `bar_ptr` for parity `phase` and returns whether it succeeded.
// Unlike bar_wait this does not loop; callers poll it themselves.
__device__ bool bar_try_wait(uint32_t bar_ptr, int phase) {
  uint32_t success;
#ifdef INTERNAL
  // Only emitted when INTERNAL is defined: passes a compiler pragma through
  // PTX (NOTE(review): purpose of the knob is not visible here).
  asm volatile(".pragma \"set knob DontInsertYield\";\n" : : : "memory");
#endif
  // selp converts the predicate result into 1/0 in `success`.
  asm volatile(
      "{\n\t"
      ".reg .pred P1;\n\t"
      "mbarrier.try_wait.parity.shared::cta.b64 P1, [%1], %2;\n\t"
      "selp.b32 %0, 1, 0, P1;\n\t"
      "}"
      : "=r"(success)
      : "r"(bar_ptr), "r"(phase));
  return success;
}
// Runs `elect.sync` with a full 0xFFFFFFFF member mask and returns 1 in the
// thread for which the instruction sets its predicate, 0 in the others. The
// lane id produced by the instruction is captured but not returned.
__device__ uint32_t elect_one_sync() {
  uint32_t pred = 0;
  uint32_t laneid = 0;
  asm volatile(
      "{\n"
      ".reg .b32 %%rx;\n"
      ".reg .pred %%px;\n"
      " elect.sync %%rx|%%px, %2;\n"
      "@%%px mov.s32 %1, 1;\n"
      " mov.s32 %0, %%rx;\n"
      "}\n"
      : "+r"(laneid), "+r"(pred)
      : "r"(0xFFFFFFFF));
  return pred;
}
#endif
// Timestamps recorded by gpt_oss_router_gemm_kernel when its PROFILE template
// flag is true. The kernel indexes the array by blockIdx.x and only blocks
// with blockIdx.y == 0 write; every field is assigned from gclock64().
struct Profile {
  uint64_t start;              // written by thread 0 at kernel entry
  uint64_t weight_load_start;  // weight-loading thread, first K iteration
  uint64_t act_load_start;     // activation-loading thread, first K iteration
  uint64_t compute_start;      // thread 0, after its first tiles are ready
  uint64_t complete;           // thread 0, after the output has been stored
};
// Router GEMM kernel: output[n, m] = sum_k weight[m, k] * act[n, k] + bias[m]
// with M output features, N tokens and reduction length K.
//
// Each block of 384 threads (12 warps) produces one TILE_M x TILE_N output
// tile (blockIdx.x selects the feature tile, blockIdx.y the token tile):
//   * warps 0-3  : compute warps. Each consumes the pipeline stages
//                  warp_id, warp_id + 4, ... and accumulates a partial sum
//                  over its share of K; the four partials are reduced
//                  through shared memory and warp 0 adds the bias and stores.
//   * warps 4-7  : one elected thread per warp issues TMA loads of weight
//                  tiles.
//   * warps 8-11 : one elected thread per warp issues TMA loads of activation
//                  tiles.
// Producers and consumers are synchronized with per-stage mbarriers
// (bar_wt_ready / bar_act_ready / bar_data_consumed).
//
// `weights` and `activations` are not dereferenced; all input data is read
// through `weight_map` / `activation_map`. `profile` is only accessed when
// PROFILE is true. The body is compiled only for __CUDA_ARCH__ >= 900.
template <int WARP_TILE_M, int TILE_M, int TILE_N, int TILE_K, int STAGES,
          int STAGE_UNROLL, bool PROFILE>
__global__ __launch_bounds__(384, 1) void gpt_oss_router_gemm_kernel(
    __nv_bfloat16* output, __nv_bfloat16* weights, __nv_bfloat16* activations,
    __nv_bfloat16* bias, int M, int N, int K,
    const __grid_constant__ CUtensorMap weight_map,
    const __grid_constant__ CUtensorMap activation_map,
    Profile* profile = nullptr) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)

  if (PROFILE && threadIdx.x == 0 && blockIdx.y == 0)
    profile[blockIdx.x].start = gclock64();

  // Dynamic shared memory: all weight tiles first, then all activation tiles.
  extern __shared__ __align__(128) char smem[];

  __nv_bfloat16* sh_weights = (__nv_bfloat16*)&smem[0];
  __nv_bfloat16* sh_activations =
      (__nv_bfloat16*)&smem[STAGES * STAGE_UNROLL * TILE_M * TILE_K *
                            sizeof(__nv_bfloat16)];

  #pragma nv_diag_suppress static_var_with_dynamic_init
  __shared__ barrier bar_wt_ready[STAGES];
  __shared__ barrier bar_act_ready[STAGES];
  __shared__ barrier bar_data_consumed[STAGES];

  // One float4 of partial accumulators per compute thread (4 warps x 32).
  __shared__ float4 reduction_buffer[128];

  __shared__ nv_bfloat16 sh_bias[TILE_M];

  if (threadIdx.x == 0) {
    for (int i = 0; i < STAGES; i++) {
      // "ready" barriers expect the single producer thread; the "consumed"
      // barrier expects all 32 lanes of the consuming warp.
      init(&bar_wt_ready[i], 1);
      init(&bar_act_ready[i], 1);
      init(&bar_data_consumed[i], 32);
    }
    ptx::fence_proxy_async(ptx::space_shared);
    asm volatile("prefetch.tensormap [%0];"
                 :
                 : "l"(reinterpret_cast<uint64_t>(&weight_map))
                 : "memory");
    asm volatile("prefetch.tensormap [%0];"
                 :
                 : "l"(reinterpret_cast<uint64_t>(&activation_map))
                 : "memory");
  }
  __syncthreads();

  int warp_id = threadIdx.x / 32;
  int lane_id = threadIdx.x % 32;

  int phase = 0;

  // First output feature / first token handled by this block.
  int mib = blockIdx.x * TILE_M;
  int ni = blockIdx.y * TILE_N;

  float accum[4];
  for (int i = 0; i < 4; i++) accum[i] = 0.f;

  // Each loop iteration of a warp covers TILE_K * STAGE_UNROLL elements of K
  // and four warps split K between them, hence the factor 4.
  int const K_LOOPS_DMA =
      (K + 4 * TILE_K * STAGE_UNROLL - 1) / (4 * (TILE_K * STAGE_UNROLL));
  int const K_LOOPS_COMPUTE = K_LOOPS_DMA;

  // Data loading thread
  if (warp_id >= 4 && elect_one_sync()) {
    int stage = warp_id % 4;

    bool weight_warp = warp_id < 8;
    if (!weight_warp) {
      // Only the activation loaders wait on the preceding grid (programmatic
      // dependent launch); weight loads may start earlier.
      cudaGridDependencySynchronize();
      cudaTriggerProgrammaticLaunchCompletion();
    }

    for (int ki = 0; ki < K_LOOPS_DMA; ki++) {
      int k = (ki * 4 + (warp_id % 4)) * TILE_K * STAGE_UNROLL;

      uint64_t desc_ptr_wt = reinterpret_cast<uint64_t>(&weight_map);
      uint64_t desc_ptr_act = reinterpret_cast<uint64_t>(&activation_map);

      uint32_t bar_ptr_wt = __cvta_generic_to_shared(&bar_wt_ready[stage]);
      uint32_t bar_ptr_act = __cvta_generic_to_shared(&bar_act_ready[stage]);
      int bytes_wt = TILE_M * TILE_K * sizeof(__nv_bfloat16);
      int bytes_act = TILE_N * TILE_K * sizeof(__nv_bfloat16);

      // Wait until the consumer has released this stage's buffers.
      bar_wait(__cvta_generic_to_shared(&bar_data_consumed[stage]), phase ^ 1);

      if (weight_warp)
        asm volatile("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;"
                     :
                     : "r"(bar_ptr_wt), "r"(STAGE_UNROLL * bytes_wt));
      if (!weight_warp)
        asm volatile("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;"
                     :
                     : "r"(bar_ptr_act), "r"(STAGE_UNROLL * bytes_act));

      if (PROFILE && blockIdx.y == 0 && ki == 0 && weight_warp)
        profile[blockIdx.x].weight_load_start = gclock64();
      if (PROFILE && blockIdx.y == 0 && ki == 0 && !weight_warp)
        profile[blockIdx.x].act_load_start = gclock64();

      for (int i = 0; i < STAGE_UNROLL; i++) {
        uint32_t smem_ptr_wt = __cvta_generic_to_shared(
            &sh_weights[(stage * STAGE_UNROLL + i) * TILE_M * TILE_K]);
        uint32_t crd0 = k + i * TILE_K;
        uint32_t crd1 = mib;
        if (weight_warp)
          asm volatile(
              "cp.async.bulk.tensor.2d.shared::cta.global.mbarrier::complete_"
              "tx::bytes [%0], [%1, {%3,%4}], "
              "[%2];"
              :
              : "r"(smem_ptr_wt), "l"(desc_ptr_wt), "r"(bar_ptr_wt), "r"(crd0),
                "r"(crd1)
              : "memory");

        uint32_t smem_ptr_act = __cvta_generic_to_shared(
            &sh_activations[(stage * STAGE_UNROLL + i) * TILE_N * TILE_K]);
        crd0 = k + i * TILE_K;
        crd1 = ni;
        if (!weight_warp)
          asm volatile(
              "cp.async.bulk.tensor.2d.shared::cta.global.mbarrier::complete_"
              "tx::bytes [%0], [%1, {%3,%4}], "
              "[%2];"
              :
              : "r"(smem_ptr_act), "l"(desc_ptr_act), "r"(bar_ptr_act),
                "r"(crd0), "r"(crd1)
              : "memory");
      }

      stage += 4;
      if (stage >= STAGES) {
        stage = warp_id % 4;
        phase ^= 1;
      }
    }
    // Wait for pending loads to be consumed before exiting, to avoid race
    for (int i = 0; i < (STAGES / 4) - 1; i++) {
      bar_wait(__cvta_generic_to_shared(&bar_data_consumed[stage]), phase ^ 1);
      stage += 4;
      if (stage >= STAGES) {
        stage = warp_id % 4;
        phase ^= 1;
      }
    }
  }
  // Compute threads
  else if (warp_id < 4) {
    // Sneak the bias load into the compute warps since they're just waiting for
    // stuff anyway.
    // The last feature tile may extend past M (the stores below are guarded
    // with `tm < M`), so guard the load as well instead of reading past the
    // end of `bias`; padded entries are never used for a stored element.
    if (threadIdx.x < TILE_M)
      sh_bias[threadIdx.x] = (mib + threadIdx.x < M)
                                 ? bias[mib + threadIdx.x]
                                 : __float2bfloat16(0.f);

    int stage = warp_id;

    int phase = 0;

    int lane_id_div8 = lane_id / 8;
    int lane_id_mod8 = lane_id % 8;

    // Per-lane row/column-chunk selection for the ldmatrix address operands.
    int lane_row_offset_wt = (lane_id_div8 % 2) ? 8 : 0;
    int lane_col_offset_wt = (lane_id_div8 / 2) ? 1 : 0;

    int row_wt = lane_id_mod8 + lane_row_offset_wt;
    int row_act = lane_id_mod8;

    // Offset of the buffer within the 128-byte swizzle pattern.
    int row_offset_wt = (reinterpret_cast<uintptr_t>(sh_weights) / 128) % 8;
    int row_offset_act = row_offset_wt;

    uint32_t bar_ptr_wt = __cvta_generic_to_shared(&bar_wt_ready[stage]);
    uint32_t bar_ptr_act = __cvta_generic_to_shared(&bar_act_ready[stage]);

    bool weight_ready = bar_try_wait(bar_ptr_wt, phase);
    bool act_ready = bar_try_wait(bar_ptr_act, phase);

  #pragma unroll 2
    for (int ki = 0; ki < K_LOOPS_COMPUTE; ki++) {
      int next_stage = stage + 4;
      int next_phase = phase;
      if (next_stage >= STAGES) {
        next_stage = warp_id;
        next_phase ^= 1;
      }

      // Spin until both tiles of the *current* stage have landed.
      // bar_ptr_wt / bar_ptr_act are refreshed at the bottom of the loop so
      // they always refer to `stage`.
      while (!weight_ready || !act_ready) {
        weight_ready = bar_try_wait(bar_ptr_wt, phase);
        act_ready = bar_try_wait(bar_ptr_act, phase);
      }

      if (PROFILE && blockIdx.y == 0 && threadIdx.x == 0 && ki == 0)
        profile[blockIdx.x].compute_start = gclock64();

      // Probe the next stage early so the result is usually available by the
      // time the next iteration starts.
      if (ki + 1 < K_LOOPS_COMPUTE) {
        weight_ready = bar_try_wait(
            __cvta_generic_to_shared(&bar_wt_ready[next_stage]), next_phase);
        act_ready = bar_try_wait(
            __cvta_generic_to_shared(&bar_act_ready[next_stage]), next_phase);
      }

  #pragma unroll
      for (int su = 0; su < STAGE_UNROLL; su++) {
        __nv_bfloat16* ptr_weights =
            &sh_weights[(stage * STAGE_UNROLL + su) * TILE_M * TILE_K];
        __nv_bfloat16* ptr_act =
            &sh_activations[(stage * STAGE_UNROLL + su) * TILE_N * TILE_K];

  #pragma unroll
        for (int kii = 0; kii < TILE_K / 16; kii++) {
          __nv_bfloat16 a[8];
          __nv_bfloat16 b[4];

          // XOR the 8-element column chunk with the row to undo the swizzle
          // applied by the TMA load.
          int col = 2 * kii + lane_col_offset_wt;
          int col_sw = ((row_wt + row_offset_wt) % 8) ^ col;

          ldmatrix4(a, __cvta_generic_to_shared(
                           &ptr_weights[row_wt * TILE_K + col_sw * 8]));

          col = 2 * kii + lane_id_div8;
          col_sw = ((row_act + row_offset_act) % 8) ^ col;

          ldmatrix2(b, __cvta_generic_to_shared(
                           &ptr_act[row_act * TILE_K + 8 * col_sw]));

          HMMA_16816(accum, a, b, accum);
        }
      }

      // Release this stage's buffers back to the producers.
      uint32_t bar_c = __cvta_generic_to_shared(&bar_data_consumed[stage]);
      asm volatile("mbarrier.arrive.shared::cta.b64 _, [%0];" : : "r"(bar_c));

      stage = next_stage;
      phase = next_phase;
      // Keep the polled barrier addresses in step with `stage`; without this
      // the spin-wait above would keep polling the barriers of the initial
      // stage whenever the early probe reported "not ready".
      bar_ptr_wt = __cvta_generic_to_shared(&bar_wt_ready[stage]);
      bar_ptr_act = __cvta_generic_to_shared(&bar_act_ready[stage]);
    }

    float4 accum4;
    accum4.x = accum[0];
    accum4.y = accum[1];
    accum4.z = accum[2];
    accum4.w = accum[3];
    reduction_buffer[threadIdx.x] = accum4;

    // NOTE(review): this barrier is only reached by the compute warps; the
    // loader threads take the other branch. Confirm this is intended.
    __syncthreads();

    if (warp_id == 0) {
      int mi = mib + warp_id * WARP_TILE_M;
      int tm = mi + lane_id / 4;
      int tn = ni + 2 * (lane_id % 4);

      // Add the partial sums of warps 1-3 to warp 0's own accumulators.
      float4 accum1 = reduction_buffer[32 + threadIdx.x];
      float4 accum2 = reduction_buffer[64 + threadIdx.x];
      float4 accum3 = reduction_buffer[96 + threadIdx.x];

      accum[0] = accum[0] + accum1.x + accum2.x + accum3.x;
      accum[1] = accum[1] + accum1.y + accum2.y + accum3.y;
      accum[2] = accum[2] + accum1.z + accum2.z + accum3.z;
      accum[3] = accum[3] + accum1.w + accum2.w + accum3.w;

      float bias_lo = __bfloat162float(sh_bias[tm - mib]);
      float bias_hi = __bfloat162float(sh_bias[tm + 8 - mib]);

      // Each lane owns features tm and tm + 8 for tokens tn and tn + 1; the
      // output is indexed as output[token * M + feature].
      if (tn < N && tm < M)
        output[tn * M + tm] = __float2bfloat16(accum[0] + bias_lo);
      if (tn + 1 < N && tm < M)
        output[(tn + 1) * M + tm] = __float2bfloat16(accum[1] + bias_lo);
      if (tn < N && tm + 8 < M)
        output[tn * M + tm + 8] = __float2bfloat16(accum[2] + bias_hi);
      if (tn + 1 < N && tm + 8 < M)
        output[(tn + 1) * M + tm + 8] = __float2bfloat16(accum[3] + bias_hi);

      if (PROFILE && blockIdx.y == 0 && threadIdx.x == 0)
        profile[blockIdx.x].complete = gclock64();
    }
  }
#endif  // end if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
}
csrc/moe/moe_ops.h
View file @
b1169d7b
...
@@ -70,4 +70,8 @@ torch::Tensor router_gemm_bf16_fp32(torch::Tensor const& input,
...
@@ -70,4 +70,8 @@ torch::Tensor router_gemm_bf16_fp32(torch::Tensor const& input,
// Supports num_tokens in [1, 16], num_experts in {256, 384}, hidden_dim = 7168
// Supports num_tokens in [1, 16], num_experts in {256, 384}, hidden_dim = 7168
void
dsv3_router_gemm
(
torch
::
Tensor
&
output
,
const
torch
::
Tensor
&
mat_a
,
void
dsv3_router_gemm
(
torch
::
Tensor
&
output
,
const
torch
::
Tensor
&
mat_a
,
const
torch
::
Tensor
&
mat_b
);
const
torch
::
Tensor
&
mat_b
);
// gpt-oss optimized router GEMM kernel for SM90+
void
gpt_oss_router_gemm
(
torch
::
Tensor
&
output
,
torch
::
Tensor
input
,
torch
::
Tensor
weight
,
torch
::
Tensor
bias
);
#endif
#endif
csrc/moe/torch_bindings.cpp
View file @
b1169d7b
...
@@ -132,6 +132,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
...
@@ -132,6 +132,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
// DeepSeek V3 optimized router GEMM for SM90+
// DeepSeek V3 optimized router GEMM for SM90+
m
.
def
(
"dsv3_router_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()"
);
m
.
def
(
"dsv3_router_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()"
);
// conditionally compiled so impl registration is in source file
// conditionally compiled so impl registration is in source file
// gpt-oss optimized router GEMM kernel for SM90+
m
.
def
(
"gpt_oss_router_gemm(Tensor! output, Tensor input, Tensor weights, "
"Tensor bias) -> ()"
);
m
.
impl
(
"gpt_oss_router_gemm"
,
torch
::
kCUDA
,
&
gpt_oss_router_gemm
);
#endif
#endif
}
}
...
...
tests/kernels/moe/test_router_gemm.py
0 → 100644
View file @
b1169d7b
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for optimized router GEMM kernel
Run `pytest tests/kernels/moe/test_router_gemm.py`.
"""
import
pytest
import
torch
import
vllm._custom_ops
as
ops
from
vllm.platforms
import
current_platform
from
vllm.utils.torch_utils
import
set_random_seed
@pytest.mark.skipif(
    not current_platform.is_cuda()
    or not (
        current_platform.is_device_capability(90)
        or current_platform.is_device_capability_family(100)
    ),
    reason="This test only runs on Hopper or Blackwell GPUs.",
)
@pytest.mark.parametrize("batch_size", [1, 2, 4, 8])
@pytest.mark.parametrize("input_dim", [360, 720, 1440, 2880])
@pytest.mark.parametrize("output_dim", [32, 64, 128])
def test_gpt_oss_router_gemm(batch_size, input_dim, output_dim):
    """Check ops.gpt_oss_router_gemm against torch.nn.functional.linear.

    Random bfloat16 inputs are generated on the GPU with a fixed seed and the
    two results must agree within atol=rtol=1e-2.
    """
    set_random_seed(0)
    tensor_opts = {"device": "cuda", "dtype": torch.bfloat16}
    x = torch.randn(batch_size, input_dim, **tensor_opts)
    weight = torch.randn(output_dim, input_dim, **tensor_opts)
    bias = torch.randn(output_dim, **tensor_opts)

    actual = ops.gpt_oss_router_gemm(x, weight, bias)
    expected = torch.nn.functional.linear(x, weight, bias)

    torch.testing.assert_close(actual, expected, atol=1e-2, rtol=1e-2)
vllm/_custom_ops.py
View file @
b1169d7b
...
@@ -2362,6 +2362,19 @@ def dsv3_router_gemm(
...
@@ -2362,6 +2362,19 @@ def dsv3_router_gemm(
return
output
return
output
def gpt_oss_router_gemm(
    hidden_states: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
) -> torch.Tensor:
    """Run the gpt-oss router GEMM custom op and return its result.

    Allocates an uninitialized tensor of shape
    ``(hidden_states.shape[0], weight.shape[0])`` with the device and dtype of
    ``hidden_states`` and lets ``torch.ops._moe_C.gpt_oss_router_gemm`` fill
    it in place.

    Args:
        hidden_states: Input activations.
        weight: Router weight.
        bias: Router bias.

    Returns:
        The tensor written by the custom op.
    """
    num_tokens = hidden_states.shape[0]
    num_experts = weight.shape[0]
    output = hidden_states.new_empty((num_tokens, num_experts))
    torch.ops._moe_C.gpt_oss_router_gemm(output, hidden_states, weight, bias)
    return output
def
topk_softmax
(
def
topk_softmax
(
topk_weights
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
...
...
vllm/lora/layers/__init__.py
View file @
b1169d7b
...
@@ -13,6 +13,7 @@ from vllm.lora.layers.column_parallel_linear import (
...
@@ -13,6 +13,7 @@ from vllm.lora.layers.column_parallel_linear import (
QKVParallelLinearWithShardedLoRA
,
QKVParallelLinearWithShardedLoRA
,
)
)
from
vllm.lora.layers.fused_moe
import
FusedMoE3DWithLoRA
,
FusedMoEWithLoRA
from
vllm.lora.layers.fused_moe
import
FusedMoE3DWithLoRA
,
FusedMoEWithLoRA
from
vllm.lora.layers.gate_linear
import
GateLinearWithLoRA
from
vllm.lora.layers.logits_processor
import
LogitsProcessorWithLoRA
from
vllm.lora.layers.logits_processor
import
LogitsProcessorWithLoRA
from
vllm.lora.layers.replicated_linear
import
ReplicatedLinearWithLoRA
from
vllm.lora.layers.replicated_linear
import
ReplicatedLinearWithLoRA
from
vllm.lora.layers.row_parallel_linear
import
(
from
vllm.lora.layers.row_parallel_linear
import
(
...
@@ -38,6 +39,7 @@ __all__ = [
...
@@ -38,6 +39,7 @@ __all__ = [
"RowParallelLinearWithLoRA"
,
"RowParallelLinearWithLoRA"
,
"RowParallelLinearWithShardedLoRA"
,
"RowParallelLinearWithShardedLoRA"
,
"ReplicatedLinearWithLoRA"
,
"ReplicatedLinearWithLoRA"
,
"GateLinearWithLoRA"
,
"LoRAMapping"
,
"LoRAMapping"
,
"LoRAMappingType"
,
"LoRAMappingType"
,
"FusedMoEWithLoRA"
,
"FusedMoEWithLoRA"
,
...
...
vllm/lora/layers/gate_linear.py
0 → 100644
View file @
b1169d7b
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch.nn
as
nn
from
transformers
import
PretrainedConfig
from
vllm.config.lora
import
LoRAConfig
from
vllm.model_executor.custom_op
import
maybe_get_oot_by_class
from
vllm.model_executor.layers.fused_moe.router.gate_linear
import
GateLinear
from
.replicated_linear
import
ReplicatedLinearWithLoRA
class GateLinearWithLoRA(ReplicatedLinearWithLoRA):
    """LoRA wrapper for ``GateLinear``, built on ``ReplicatedLinearWithLoRA``."""

    def __init__(self, base_layer: GateLinear) -> None:
        super().__init__(base_layer)

    # GateLinearWithLoRA should always be replaced, regardless of the fully
    # sharded LoRAs setting, because it is, by definition, copied per GPU.
    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        """Return True only when ``source_layer`` is exactly the gate class.

        The comparison is an identity check on ``type(source_layer)`` (not
        ``isinstance``) against ``maybe_get_oot_by_class(GateLinear)``; the
        remaining arguments are not consulted.
        """
        source_type = type(source_layer)
        gate_linear_cls = maybe_get_oot_by_class(GateLinear)
        return source_type is gate_linear_cls
vllm/lora/utils.py
View file @
b1169d7b
...
@@ -21,6 +21,7 @@ from vllm.lora.layers import (
...
@@ -21,6 +21,7 @@ from vllm.lora.layers import (
ColumnParallelLinearWithShardedLoRA
,
ColumnParallelLinearWithShardedLoRA
,
FusedMoE3DWithLoRA
,
FusedMoE3DWithLoRA
,
FusedMoEWithLoRA
,
FusedMoEWithLoRA
,
GateLinearWithLoRA
,
LogitsProcessorWithLoRA
,
LogitsProcessorWithLoRA
,
MergedColumnParallelLinearVariableSliceWithLoRA
,
MergedColumnParallelLinearVariableSliceWithLoRA
,
MergedColumnParallelLinearWithLoRA
,
MergedColumnParallelLinearWithLoRA
,
...
@@ -81,6 +82,7 @@ _all_lora_classes: set[type[BaseLayerWithLoRA]] = {
...
@@ -81,6 +82,7 @@ _all_lora_classes: set[type[BaseLayerWithLoRA]] = {
MergedQKVParallelLinearWithLoRA
,
MergedQKVParallelLinearWithLoRA
,
RowParallelLinearWithLoRA
,
RowParallelLinearWithLoRA
,
ReplicatedLinearWithLoRA
,
ReplicatedLinearWithLoRA
,
GateLinearWithLoRA
,
LogitsProcessorWithLoRA
,
LogitsProcessorWithLoRA
,
ColumnParallelLinearWithShardedLoRA
,
ColumnParallelLinearWithShardedLoRA
,
QKVParallelLinearWithShardedLoRA
,
QKVParallelLinearWithShardedLoRA
,
...
...
vllm/model_executor/layers/fused_moe/router/gate_linear.py
View file @
b1169d7b
...
@@ -3,9 +3,11 @@
...
@@ -3,9 +3,11 @@
import
torch
import
torch
from
torch.nn.parameter
import
Parameter
from
torch.nn.parameter
import
Parameter
import
vllm._custom_ops
as
ops
from
vllm.model_executor.custom_op
import
PluggableLayer
from
vllm.model_executor.custom_op
import
PluggableLayer
from
vllm.model_executor.layers.linear
import
ReplicatedLinear
from
vllm.model_executor.layers.linear
import
ReplicatedLinear
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils.torch_utils
import
direct_register_custom_op
@
PluggableLayer
.
register
(
"gate_linear"
)
@
PluggableLayer
.
register
(
"gate_linear"
)
...
@@ -13,8 +15,9 @@ class GateLinear(ReplicatedLinear):
...
@@ -13,8 +15,9 @@ class GateLinear(ReplicatedLinear):
"""MoE gate linear layer with three-tier GEMM dispatch:
"""MoE gate linear layer with four-tier GEMM dispatch:
1. DSV3 specialized kernel (SM90+, batch<=16, supported dims)
1. DSV3 specialized kernel (SM90+, batch<=16, supported dims)
2. cuBLAS bf16×bf16→fp32 (SM90+ + bf16 + fp32 out_dtype)
2. gpt-oss specialized kernel (SM90+, batch<=128, supported dims)
3. F.linear via ReplicatedLinear (ultimate fallback)
3. cuBLAS bf16×bf16→fp32 (SM90+ + bf16 + fp32 out_dtype)
4. F.linear via ReplicatedLinear (ultimate fallback)
The ``out_dtype`` attribute is mutable and can be set after init
The ``out_dtype`` attribute is mutable and can be set after init
(e.g. when the required dtype depends on the expert quantization
(e.g. when the required dtype depends on the expert quantization
...
@@ -25,6 +28,10 @@ class GateLinear(ReplicatedLinear):
...
@@ -25,6 +28,10 @@ class GateLinear(ReplicatedLinear):
DSV3_SUPPORTED_NUM_EXPERTS
=
[
256
,
384
]
DSV3_SUPPORTED_NUM_EXPERTS
=
[
256
,
384
]
DSV3_SUPPORTED_HIDDEN_SIZES
=
[
7168
]
DSV3_SUPPORTED_HIDDEN_SIZES
=
[
7168
]
# Dimensions supported by the gpt-oss specialized kernel
GPT_OSS_SUPPORTED_NUM_EXPERTS
=
[
32
,
128
]
GPT_OSS_SUPPORTED_HIDDEN_SIZES
=
[
2880
]
def
__init__
(
def
__init__
(
self
,
self
,
input_size
:
int
,
input_size
:
int
,
...
@@ -65,6 +72,15 @@ class GateLinear(ReplicatedLinear):
...
@@ -65,6 +72,15 @@ class GateLinear(ReplicatedLinear):
and
input_size
in
self
.
DSV3_SUPPORTED_HIDDEN_SIZES
and
input_size
in
self
.
DSV3_SUPPORTED_HIDDEN_SIZES
)
)
# gpt-oss specialized kernel eligibility (SM90+, exact dims)
self
.
allow_gpt_oss_router_gemm
=
(
self
.
weight
.
dtype
==
torch
.
bfloat16
and
current_platform
.
is_cuda
()
and
is_hopper_or_blackwell
and
output_size
in
self
.
GPT_OSS_SUPPORTED_NUM_EXPERTS
and
input_size
in
self
.
GPT_OSS_SUPPORTED_HIDDEN_SIZES
)
# cuBLAS bf16→fp32 eligibility
# cuBLAS bf16→fp32 eligibility
self
.
allow_cublas_router_gemm
=
(
self
.
allow_cublas_router_gemm
=
(
self
.
allow_specialized_router_gemm
self
.
allow_specialized_router_gemm
...
@@ -92,8 +108,6 @@ class GateLinear(ReplicatedLinear):
...
@@ -92,8 +108,6 @@ class GateLinear(ReplicatedLinear):
def
forward
(
def
forward
(
self
,
x
:
torch
.
Tensor
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
Parameter
|
None
]:
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
Parameter
|
None
]:
import
vllm._custom_ops
as
ops
# Tier 1: DSV3 specialized kernel
# Tier 1: DSV3 specialized kernel
if
self
.
allow_dsv3_router_gemm
and
x
.
shape
[
0
]
<=
16
:
if
self
.
allow_dsv3_router_gemm
and
x
.
shape
[
0
]
<=
16
:
output
=
ops
.
dsv3_router_gemm
(
output
=
ops
.
dsv3_router_gemm
(
...
@@ -103,15 +117,47 @@ class GateLinear(ReplicatedLinear):
...
@@ -103,15 +117,47 @@ class GateLinear(ReplicatedLinear):
)
)
return
output
,
None
return
output
,
None
# Tier 2: cuBLAS bf16→fp32
# Tier 2: gpt-oss specialized kernel
if
self
.
allow_gpt_oss_router_gemm
:
output
=
torch
.
ops
.
vllm
.
gpt_oss_router_gemm
(
x
,
self
.
weight
,
self
.
bias
)
return
output
,
None
# Tier 3: cuBLAS bf16→fp32
if
self
.
allow_cublas_router_gemm
and
x
.
dtype
==
torch
.
bfloat16
:
if
self
.
allow_cublas_router_gemm
and
x
.
dtype
==
torch
.
bfloat16
:
output
=
ops
.
router_gemm_bf16_fp32
(
x
,
self
.
weight
)
output
=
ops
.
router_gemm_bf16_fp32
(
x
,
self
.
weight
)
return
output
,
None
return
output
,
None
# Tier
3
: F.linear (ReplicatedLinear)
# Tier
4
: F.linear (ReplicatedLinear)
if
self
.
out_dtype
is
not
None
and
x
.
dtype
!=
self
.
weight
.
dtype
:
if
self
.
out_dtype
is
not
None
and
x
.
dtype
!=
self
.
weight
.
dtype
:
x
=
x
.
to
(
self
.
weight
.
dtype
)
x
=
x
.
to
(
self
.
weight
.
dtype
)
output
,
output_bias
=
super
().
forward
(
x
)
output
,
output_bias
=
super
().
forward
(
x
)
if
self
.
out_dtype
is
not
None
and
output
.
dtype
!=
self
.
out_dtype
:
if
self
.
out_dtype
is
not
None
and
output
.
dtype
!=
self
.
out_dtype
:
output
=
output
.
to
(
self
.
out_dtype
)
output
=
output
.
to
(
self
.
out_dtype
)
return
output
,
output_bias
return
output
,
output_bias
def gpt_oss_router_gemm_impl(
    x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
) -> torch.Tensor:
    """
    Pick the router GEMM implementation from the number of rows in ``x``.

    Inputs with at most 128 rows go to ``ops.gpt_oss_router_gemm``
    (the min-latency kernel); larger inputs fall back to
    ``torch.nn.functional.linear``.

    This must be wrapped in a custom op because our torch.compile integration
    does not support runtime dispatching on num_tokens.
    """
    num_tokens = x.shape[0]
    if num_tokens > 128:
        return torch.nn.functional.linear(x, weight, bias)
    return ops.gpt_oss_router_gemm(x, weight, bias)
def gpt_oss_router_gemm_fake(
    x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
) -> torch.Tensor:
    """Shape-only stand-in for ``gpt_oss_router_gemm_impl``.

    Returns an uninitialized tensor of shape
    ``(x.shape[0], weight.shape[0])`` created via ``x.new_empty``, so it
    shares ``x``'s dtype and device. ``bias`` is accepted only to mirror
    the real op's signature.
    """
    num_rows = x.shape[0]
    out_features = weight.shape[0]
    return x.new_empty((num_rows, out_features))
# Register the token-count-dispatching implementation together with its
# shape-only fake counterpart under the op name "gpt_oss_router_gemm".
# GateLinear.forward calls it as torch.ops.vllm.gpt_oss_router_gemm.
direct_register_custom_op(
    op_name="gpt_oss_router_gemm",
    op_func=gpt_oss_router_gemm_impl,
    fake_impl=gpt_oss_router_gemm_fake,
)
vllm/model_executor/models/gpt_oss.py
View file @
b1169d7b
...
@@ -20,12 +20,11 @@ from vllm.distributed import (
...
@@ -20,12 +20,11 @@ from vllm.distributed import (
tensor_model_parallel_all_gather
,
tensor_model_parallel_all_gather
,
)
)
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
,
GateLinear
from
vllm.model_executor.layers.fused_moe.config
import
FusedMoEParallelConfig
from
vllm.model_executor.layers.fused_moe.config
import
FusedMoEParallelConfig
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
from
vllm.model_executor.layers.linear
import
(
QKVParallelLinear
,
QKVParallelLinear
,
ReplicatedLinear
,
RowParallelLinear
,
RowParallelLinear
,
)
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
...
@@ -175,13 +174,11 @@ class MLPBlock(torch.nn.Module):
...
@@ -175,13 +174,11 @@ class MLPBlock(torch.nn.Module):
self
.
hidden_size
=
config
.
hidden_size
self
.
hidden_size
=
config
.
hidden_size
self
.
experts_per_token
=
config
.
num_experts_per_tok
self
.
experts_per_token
=
config
.
num_experts_per_tok
self
.
world_size
=
dist
.
get_world_size
()
if
dist
.
is_initialized
()
else
1
self
.
world_size
=
dist
.
get_world_size
()
if
dist
.
is_initialized
()
else
1
self
.
router
=
Replic
ate
d
Linear
(
self
.
router
=
G
ateLinear
(
config
.
hidden_size
,
config
.
hidden_size
,
config
.
num_local_experts
,
config
.
num_local_experts
,
bias
=
True
,
bias
=
True
,
quant_config
=
None
,
prefix
=
f
"
{
prefix
}
.router"
,
prefix
=
f
"
{
prefix
}
.router"
,
return_bias
=
False
,
)
)
assert
config
.
intermediate_size
%
self
.
world_size
==
0
assert
config
.
intermediate_size
%
self
.
world_size
==
0
self
.
experts
=
FusedMoE
(
self
.
experts
=
FusedMoE
(
...
@@ -209,7 +206,7 @@ class MLPBlock(torch.nn.Module):
...
@@ -209,7 +206,7 @@ class MLPBlock(torch.nn.Module):
self
,
x
[:,
:
self
.
hidden_size
],
self
.
router
.
weight
,
self
.
router
.
bias
self
,
x
[:,
:
self
.
hidden_size
],
self
.
router
.
weight
,
self
.
router
.
bias
)
)
else
:
else
:
g
=
self
.
router
(
x
)
g
,
_
=
self
.
router
(
x
)
x
=
self
.
experts
(
hidden_states
=
x
,
router_logits
=
g
)[:,
:
self
.
hidden_size
]
x
=
self
.
experts
(
hidden_states
=
x
,
router_logits
=
g
)[:,
:
self
.
hidden_size
]
if
self
.
is_sequence_parallel
:
if
self
.
is_sequence_parallel
:
...
@@ -273,7 +270,6 @@ class GptOssModel(nn.Module, EagleModelMixin):
...
@@ -273,7 +270,6 @@ class GptOssModel(nn.Module, EagleModelMixin):
self
.
config
=
vllm_config
.
model_config
.
hf_config
self
.
config
=
vllm_config
.
model_config
.
hf_config
self
.
quant_config
=
vllm_config
.
quant_config
self
.
quant_config
=
vllm_config
.
quant_config
self
.
parallel_config
=
vllm_config
.
parallel_config
self
.
parallel_config
=
vllm_config
.
parallel_config
self
.
config
.
hidden_size
=
self
.
config
.
hidden_size
self
.
embedding
=
VocabParallelEmbedding
(
self
.
embedding
=
VocabParallelEmbedding
(
self
.
config
.
vocab_size
,
self
.
config
.
vocab_size
,
self
.
config
.
hidden_size
,
self
.
config
.
hidden_size
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment