Unverified commit 0b1bdac6, authored by wangxiyuan, committed by GitHub
Browse files

[Platform] Custom ops support for FusedMoe (#22509)


Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
parent d94e3026
...@@ -682,7 +682,8 @@ def determine_expert_map( ...@@ -682,7 +682,8 @@ def determine_expert_map(
return (local_num_experts, expert_map) return (local_num_experts, expert_map)
class FusedMoE(torch.nn.Module): @CustomOp.register("fused_moe")
class FusedMoE(CustomOp):
"""FusedMoE layer for MoE models. """FusedMoE layer for MoE models.
This layer contains both MergedColumnParallel weights (gate_up_proj / This layer contains both MergedColumnParallel weights (gate_up_proj /
......
...@@ -16,6 +16,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank, ...@@ -16,6 +16,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
tensor_model_parallel_all_gather, tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase) QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.utils import dispatch_unquantized_gemm from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
...@@ -226,7 +227,7 @@ class UnquantizedLinearMethod(LinearMethodBase): ...@@ -226,7 +227,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
return dispatch_unquantized_gemm()(layer, x, layer.weight, bias) return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)
class LinearBase(torch.nn.Module): class LinearBase(CustomOp):
"""Base linear layer. """Base linear layer.
Args: Args:
...@@ -269,12 +270,8 @@ class LinearBase(torch.nn.Module): ...@@ -269,12 +270,8 @@ class LinearBase(torch.nn.Module):
prefix=prefix) prefix=prefix)
self.return_bias = return_bias self.return_bias = return_bias
def forward(
self, x: torch.Tensor
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
raise NotImplementedError
@CustomOp.register("replicated_linear")
class ReplicatedLinear(LinearBase): class ReplicatedLinear(LinearBase):
"""Replicated linear layer. """Replicated linear layer.
...@@ -443,6 +440,7 @@ class MergedReplicatedLinear(ReplicatedLinear): ...@@ -443,6 +440,7 @@ class MergedReplicatedLinear(ReplicatedLinear):
param[shard_offset:shard_offset + shard_size] = loaded_weight param[shard_offset:shard_offset + shard_size] = loaded_weight
@CustomOp.register("column_parallel_linear")
class ColumnParallelLinear(LinearBase): class ColumnParallelLinear(LinearBase):
"""Linear layer with column parallelism. """Linear layer with column parallelism.
...@@ -1229,6 +1227,7 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -1229,6 +1227,7 @@ class QKVParallelLinear(ColumnParallelLinear):
param_data.copy_(loaded_weight) param_data.copy_(loaded_weight)
@CustomOp.register("row_parallel_linear")
class RowParallelLinear(LinearBase): class RowParallelLinear(LinearBase):
"""Linear layer with row parallelism. """Linear layer with row parallelism.
...@@ -1405,6 +1404,7 @@ class RowParallelLinear(LinearBase): ...@@ -1405,6 +1404,7 @@ class RowParallelLinear(LinearBase):
return s return s
@CustomOp.register("qkv_cross_parallel_linear")
class QKVCrossParallelLinear(LinearBase): class QKVCrossParallelLinear(LinearBase):
"""Linear layers for efficient cross-attention's QKV transformation. """Linear layers for efficient cross-attention's QKV transformation.
......
...@@ -12,6 +12,7 @@ from torch.nn.parameter import Parameter, UninitializedParameter ...@@ -12,6 +12,7 @@ from torch.nn.parameter import Parameter, UninitializedParameter
from vllm.distributed import (divide, get_tensor_model_parallel_rank, from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding) QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding)
from vllm.model_executor.layers.utils import dispatch_unquantized_gemm from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
...@@ -159,7 +160,8 @@ def get_masked_input_and_mask( ...@@ -159,7 +160,8 @@ def get_masked_input_and_mask(
return input_, ~vocab_mask return input_, ~vocab_mask
class VocabParallelEmbedding(torch.nn.Module): @CustomOp.register("vocab_parallel_embedding")
class VocabParallelEmbedding(CustomOp):
"""Embedding parallelized in the vocabulary dimension. """Embedding parallelized in the vocabulary dimension.
Adapted from torch.nn.Embedding, note that we pad the vocabulary size to Adapted from torch.nn.Embedding, note that we pad the vocabulary size to
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment