Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0b1bdac6
Unverified
Commit
0b1bdac6
authored
Aug 13, 2025
by
wangxiyuan
Committed by
GitHub
Aug 13, 2025
Browse files
[Platform] Custom ops support for FusedMoe (#22509)
Signed-off-by:
wangxiyuan
<
wangxiyuan1007@gmail.com
>
parent
d94e3026
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
11 additions
and
8 deletions
+11
-8
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+2
-1
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+6
-6
vllm/model_executor/layers/vocab_parallel_embedding.py
vllm/model_executor/layers/vocab_parallel_embedding.py
+3
-1
No files found.
vllm/model_executor/layers/fused_moe/layer.py
View file @
0b1bdac6
...
@@ -682,7 +682,8 @@ def determine_expert_map(
...
@@ -682,7 +682,8 @@ def determine_expert_map(
return
(
local_num_experts
,
expert_map
)
return
(
local_num_experts
,
expert_map
)
class
FusedMoE
(
torch
.
nn
.
Module
):
@
CustomOp
.
register
(
"fused_moe"
)
class
FusedMoE
(
CustomOp
):
"""FusedMoE layer for MoE models.
"""FusedMoE layer for MoE models.
This layer contains both MergedColumnParallel weights (gate_up_proj /
This layer contains both MergedColumnParallel weights (gate_up_proj /
...
...
vllm/model_executor/layers/linear.py
View file @
0b1bdac6
...
@@ -16,6 +16,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
...
@@ -16,6 +16,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
tensor_model_parallel_all_gather
,
tensor_model_parallel_all_gather
,
tensor_model_parallel_all_reduce
)
tensor_model_parallel_all_reduce
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.layers.utils
import
dispatch_unquantized_gemm
from
vllm.model_executor.layers.utils
import
dispatch_unquantized_gemm
...
@@ -226,7 +227,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
...
@@ -226,7 +227,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
return
dispatch_unquantized_gemm
()(
layer
,
x
,
layer
.
weight
,
bias
)
return
dispatch_unquantized_gemm
()(
layer
,
x
,
layer
.
weight
,
bias
)
class
LinearBase
(
torch
.
nn
.
Module
):
class
LinearBase
(
CustomOp
):
"""Base linear layer.
"""Base linear layer.
Args:
Args:
...
@@ -269,12 +270,8 @@ class LinearBase(torch.nn.Module):
...
@@ -269,12 +270,8 @@ class LinearBase(torch.nn.Module):
prefix
=
prefix
)
prefix
=
prefix
)
self
.
return_bias
=
return_bias
self
.
return_bias
=
return_bias
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
Optional
[
Parameter
]]]:
raise
NotImplementedError
@
CustomOp
.
register
(
"replicated_linear"
)
class
ReplicatedLinear
(
LinearBase
):
class
ReplicatedLinear
(
LinearBase
):
"""Replicated linear layer.
"""Replicated linear layer.
...
@@ -443,6 +440,7 @@ class MergedReplicatedLinear(ReplicatedLinear):
...
@@ -443,6 +440,7 @@ class MergedReplicatedLinear(ReplicatedLinear):
param
[
shard_offset
:
shard_offset
+
shard_size
]
=
loaded_weight
param
[
shard_offset
:
shard_offset
+
shard_size
]
=
loaded_weight
@
CustomOp
.
register
(
"column_parallel_linear"
)
class
ColumnParallelLinear
(
LinearBase
):
class
ColumnParallelLinear
(
LinearBase
):
"""Linear layer with column parallelism.
"""Linear layer with column parallelism.
...
@@ -1229,6 +1227,7 @@ class QKVParallelLinear(ColumnParallelLinear):
...
@@ -1229,6 +1227,7 @@ class QKVParallelLinear(ColumnParallelLinear):
param_data
.
copy_
(
loaded_weight
)
param_data
.
copy_
(
loaded_weight
)
@
CustomOp
.
register
(
"row_parallel_linear"
)
class
RowParallelLinear
(
LinearBase
):
class
RowParallelLinear
(
LinearBase
):
"""Linear layer with row parallelism.
"""Linear layer with row parallelism.
...
@@ -1405,6 +1404,7 @@ class RowParallelLinear(LinearBase):
...
@@ -1405,6 +1404,7 @@ class RowParallelLinear(LinearBase):
return
s
return
s
@
CustomOp
.
register
(
"qkv_cross_parallel_linear"
)
class
QKVCrossParallelLinear
(
LinearBase
):
class
QKVCrossParallelLinear
(
LinearBase
):
"""Linear layers for efficient cross-attention's QKV transformation.
"""Linear layers for efficient cross-attention's QKV transformation.
...
...
vllm/model_executor/layers/vocab_parallel_embedding.py
View file @
0b1bdac6
...
@@ -12,6 +12,7 @@ from torch.nn.parameter import Parameter, UninitializedParameter
...
@@ -12,6 +12,7 @@ from torch.nn.parameter import Parameter, UninitializedParameter
from
vllm.distributed
import
(
divide
,
get_tensor_model_parallel_rank
,
from
vllm.distributed
import
(
divide
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_reduce
)
tensor_model_parallel_all_reduce
)
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
,
method_has_implemented_embedding
)
QuantizationConfig
,
QuantizeMethodBase
,
method_has_implemented_embedding
)
from
vllm.model_executor.layers.utils
import
dispatch_unquantized_gemm
from
vllm.model_executor.layers.utils
import
dispatch_unquantized_gemm
...
@@ -159,7 +160,8 @@ def get_masked_input_and_mask(
...
@@ -159,7 +160,8 @@ def get_masked_input_and_mask(
return
input_
,
~
vocab_mask
return
input_
,
~
vocab_mask
class
VocabParallelEmbedding
(
torch
.
nn
.
Module
):
@
CustomOp
.
register
(
"vocab_parallel_embedding"
)
class
VocabParallelEmbedding
(
CustomOp
):
"""Embedding parallelized in the vocabulary dimension.
"""Embedding parallelized in the vocabulary dimension.
Adapted from torch.nn.Embedding, note that we pad the vocabulary size to
Adapted from torch.nn.Embedding, note that we pad the vocabulary size to
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment