Unverified Commit e3938b2f authored by Yineng Zhang, committed by GitHub

feat: update other MoE models deps (#2156)

parent c211e7b6
@@ -153,12 +153,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         num_expert_group: Optional[int],
         topk_group: Optional[int],
     ) -> torch.Tensor:
-        from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe
-
-        assert not use_grouped_topk
-        assert num_expert_group is None
-        assert topk_group is None
-        return fused_moe(x, w1, w2, router_logits, top_k, renormalize)
+        raise NotImplementedError("The TPU backend currently does not support MoE.")
 
 
 class FusedMoE(torch.nn.Module):
...
+# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/fused_moe.py
 """Fused MoE kernel."""
 import functools
...
+# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py
 from abc import abstractmethod
 from enum import Enum
 from typing import Callable, List, Optional, Tuple
@@ -18,7 +20,7 @@ from sglang.srt.layers.quantization.base_config import (
 from sglang.srt.utils import set_weight_attrs
 
 if torch.cuda.is_available() or torch.hip.is_available():
-    from .fused_moe import fused_experts
+    from sglang.srt.layers.triton_fused_moe.fused_moe import fused_experts
 else:
     fused_experts = None  # type: ignore
@@ -512,7 +514,7 @@ class FusedMoE(torch.nn.Module):
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
     ):
-        from vllm.model_executor.layers.fused_moe.fused_moe import (
+        from sglang.srt.layers.triton_fused_moe.fused_moe import (
             fused_topk,
             grouped_topk,
         )
...
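Note on the guard in layer.py above: when neither CUDA nor ROCm is available, fused_experts stays None, so call sites have to check for the fallback before dispatching. A minimal sketch of that defensive pattern (the apply_experts wrapper is hypothetical, not part of this commit):

import torch

try:  # mirror the guarded import: the Triton kernels need a GPU runtime
    from sglang.srt.layers.triton_fused_moe.fused_moe import fused_experts
except ImportError:
    fused_experts = None  # type: ignore


def apply_experts(x: torch.Tensor, w1, w2, topk_weights, topk_ids) -> torch.Tensor:
    # Hypothetical call site: fail with a clear error instead of crashing inside Triton.
    if fused_experts is None:
        raise NotImplementedError("fused_experts requires a CUDA or ROCm build of PyTorch.")
    return fused_experts(x, w1, w2, topk_weights, topk_ids)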
@@ -24,7 +24,6 @@ from vllm.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.layers.fused_moe import fused_moe
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.transformers_utils.configs.dbrx import DbrxConfig
@@ -37,6 +36,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.triton_fused_moe import fused_moe
 from sglang.srt.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE,
     ParallelLMHead,
...
@@ -26,7 +26,6 @@ from vllm.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.layers.fused_moe import fused_moe
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -41,6 +40,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.triton_fused_moe import fused_moe
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
...
@@ -22,7 +22,6 @@ import torch
 from torch import nn
 from transformers import MixtralConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -36,6 +35,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.torchao_utils import apply_torchao_config_
+from sglang.srt.layers.triton_fused_moe import FusedMoE
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
...
@@ -27,7 +27,6 @@ from vllm.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
@@ -43,6 +42,7 @@ from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.triton_fused_moe import FusedMoE
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
...
@@ -26,7 +26,6 @@ from vllm.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -42,6 +41,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.torchao_utils import apply_torchao_config_
+from sglang.srt.layers.triton_fused_moe import FusedMoE
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
...
@@ -24,7 +24,6 @@ from vllm.distributed import (
     tensor_model_parallel_all_reduce,
 )
 from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.fused_moe import fused_moe
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
@@ -38,6 +37,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.triton_fused_moe import fused_moe
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
...
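For the model files above the change is purely an import swap; the call shape of fused_moe stays the positional form visible in the first hunk, fused_moe(x, w1, w2, router_logits, top_k, renormalize). A minimal sketch of an MoE block wired to the relocated kernel (the class name, tensor shapes, and the renormalize choice are illustrative assumptions, not from this commit):

import torch
from torch import nn

from sglang.srt.layers.triton_fused_moe import fused_moe


class ToyMoEBlock(nn.Module):
    # Illustrative block: a linear router produces logits, and the fused
    # Triton kernel performs top-k selection plus the expert MLPs in one call.
    def __init__(self, hidden_size: int, intermediate_size: int, num_experts: int, top_k: int):
        super().__init__()
        self.top_k = top_k
        self.gate = nn.Linear(hidden_size, num_experts, bias=False)
        # w1 stacks the fused gate/up projections per expert; w2 is the down projection.
        self.w1 = nn.Parameter(torch.empty(num_experts, 2 * intermediate_size, hidden_size))
        self.w2 = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        router_logits = self.gate(x)
        # Same positional call as the one removed from the TPU path above.
        return fused_moe(x, self.w1, self.w2, router_logits, self.top_k, renormalize=True)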
@@ -957,6 +957,21 @@ def direct_register_custom_op(
     fake_impl: Optional[Callable] = None,
     target_lib: Optional[Library] = None,
 ):
+    """
+    `torch.library.custom_op` can have significant overhead because it
+    needs to consider complicated dispatching logic. This function
+    directly registers a custom op and dispatches it to the CUDA backend.
+    See https://gist.github.com/youkaichao/ecbea9ec9fc79a45d2adce1784d7a9a5
+    for more details.
+
+    By default, the custom op is registered to the vLLM library. If you
+    want to register it to a different library, you can pass the library
+    object to the `target_lib` argument.
+
+    IMPORTANT: the lifetime of the operator is tied to the lifetime of the
+    library object. If you want to bind the operator to a different library,
+    make sure the library object is alive when the operator is used.
+    """
     import torch.library
 
     if hasattr(torch.library, "infer_schema"):
...
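A hedged usage sketch of direct_register_custom_op follows. The op_name/op_func/mutates_args parameter names and the sglang.srt.utils import path are assumptions carried over from the vLLM helper this code was adapted from, and the op itself is invented for illustration:

import torch
from torch.library import Library

from sglang.srt.utils import direct_register_custom_op  # assumed import path

# Per the docstring above, the registered op lives only as long as this
# Library object, so keep it at module scope.
_my_lib = Library("my_ext", "FRAGMENT")


def scale_add(x: torch.Tensor, y: torch.Tensor, alpha: float) -> torch.Tensor:
    # Real (CUDA) implementation of the hypothetical op.
    return x + alpha * y


def scale_add_fake(x: torch.Tensor, y: torch.Tensor, alpha: float) -> torch.Tensor:
    # Shape/dtype-only stand-in used during tracing and torch.compile.
    return torch.empty_like(x)


direct_register_custom_op(
    op_name="scale_add",
    op_func=scale_add,
    mutates_args=[],
    fake_impl=scale_add_fake,
    target_lib=_my_lib,
)

# The op is now reachable through torch.ops under the library's namespace.
out = torch.ops.my_ext.scale_add(torch.ones(4, device="cuda"), torch.ones(4, device="cuda"), 0.5)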