"vscode:/vscode.git/clone" did not exist on "5a290a5644d2eaa721d89e27a2b585a8ca95ce1c"
Unverified commit e3938b2f authored by Yineng Zhang, committed by GitHub

feat: update other MoE models deps (#2156)

parent c211e7b6
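
At a glance, the diff below replaces the remaining vLLM fused-MoE imports in the MoE model files with the in-tree sglang.srt.layers.triton_fused_moe package, routes expert selection through the vendored fused_topk/grouped_topk helpers, and makes the TPU path raise NotImplementedError instead of calling vLLM's Pallas kernel. The per-model change reduces to the import swap sketched here (a reconstruction from the hunks below, not part of the commit itself; both symbols are the ones the hunks import):

# Before this commit: fused-MoE building blocks came from vLLM.
# from vllm.model_executor.layers.fused_moe import FusedMoE, fused_moe

# After this commit: the same symbols come from the in-tree Triton implementation.
from sglang.srt.layers.triton_fused_moe import FusedMoE, fused_moe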
@@ -153,12 +153,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         num_expert_group: Optional[int],
         topk_group: Optional[int],
     ) -> torch.Tensor:
-        from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe
-        assert not use_grouped_topk
-        assert num_expert_group is None
-        assert topk_group is None
-        return fused_moe(x, w1, w2, router_logits, top_k, renormalize)
+        raise NotImplementedError("The TPU backend currently does not support MoE.")
 class FusedMoE(torch.nn.Module):
......
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/fused_moe.py
"""Fused MoE kernel."""
import functools
......
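
The vendored kernel keeps the calling convention the removed TPU path used (x, w1, w2, router_logits, top_k, renormalize). A minimal sketch of invoking it, assuming the weight layout of the vLLM kernel it is adapted from (expert-major w1 holding the concatenated gate/up projections, w2 holding the down projection); the tensor shapes and the renormalize keyword are assumptions, not confirmed by this diff:

import torch

from sglang.srt.layers.triton_fused_moe import fused_moe

if torch.cuda.is_available():  # the Triton kernel needs a GPU
    num_tokens, hidden, inter, num_experts, top_k = 4, 128, 256, 8, 2
    x = torch.randn(num_tokens, hidden, dtype=torch.float16, device="cuda")
    w1 = torch.randn(num_experts, 2 * inter, hidden, dtype=torch.float16, device="cuda")  # gate+up proj per expert
    w2 = torch.randn(num_experts, hidden, inter, dtype=torch.float16, device="cuda")      # down proj per expert
    router_logits = torch.randn(num_tokens, num_experts, dtype=torch.float16, device="cuda")
    # Same positional order as the removed TPU call shown above.
    out = fused_moe(x, w1, w2, router_logits, top_k, renormalize=True)
    assert out.shape == (num_tokens, hidden)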
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py
from abc import abstractmethod
from enum import Enum
from typing import Callable, List, Optional, Tuple
@@ -18,7 +20,7 @@ from sglang.srt.layers.quantization.base_config import (
 from sglang.srt.utils import set_weight_attrs
 if torch.cuda.is_available() or torch.hip.is_available():
-    from .fused_moe import fused_experts
+    from sglang.srt.layers.triton_fused_moe.fused_moe import fused_experts
 else:
     fused_experts = None  # type: ignore
@@ -512,7 +514,7 @@ class FusedMoE(torch.nn.Module):
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
     ):
-        from vllm.model_executor.layers.fused_moe.fused_moe import (
+        from sglang.srt.layers.triton_fused_moe.fused_moe import (
             fused_topk,
             grouped_topk,
         )
......
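
select_experts now pulls its routing helpers from the vendored module as well. A hedged sketch of how they are typically called, assuming their signatures match the vLLM originals the file is adapted from (both return per-token (topk_weights, topk_ids); grouped_topk adds DeepSeek-style expert-group restriction via num_expert_group/topk_group):

import torch

from sglang.srt.layers.triton_fused_moe.fused_moe import fused_topk, grouped_topk

if torch.cuda.is_available():  # fused_topk may rely on CUDA custom ops
    hidden_states = torch.randn(4, 128, dtype=torch.float16, device="cuda")
    router_logits = torch.randn(4, 8, dtype=torch.float32, device="cuda")  # 8 experts

    # Plain top-k routing over all experts.
    topk_weights, topk_ids = fused_topk(hidden_states, router_logits, topk=2, renormalize=True)

    # Grouped top-k: pick the best expert groups first, then the best experts inside them.
    gw, gi = grouped_topk(hidden_states, router_logits, topk=2, renormalize=True,
                          num_expert_group=4, topk_group=2)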
@@ -24,7 +24,6 @@ from vllm.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.layers.fused_moe import fused_moe
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.transformers_utils.configs.dbrx import DbrxConfig
@@ -37,6 +36,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.triton_fused_moe import fused_moe
 from sglang.srt.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE,
     ParallelLMHead,
......
@@ -26,7 +26,6 @@ from vllm.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.layers.fused_moe import fused_moe
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -41,6 +40,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.triton_fused_moe import fused_moe
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
......
@@ -22,7 +22,6 @@ import torch
 from torch import nn
 from transformers import MixtralConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -36,6 +35,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.torchao_utils import apply_torchao_config_
+from sglang.srt.layers.triton_fused_moe import FusedMoE
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
......
@@ -27,7 +27,6 @@ from vllm.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
@@ -43,6 +42,7 @@ from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.triton_fused_moe import FusedMoE
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
......
@@ -26,7 +26,6 @@ from vllm.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -42,6 +41,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.torchao_utils import apply_torchao_config_
+from sglang.srt.layers.triton_fused_moe import FusedMoE
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
......
@@ -24,7 +24,6 @@ from vllm.distributed import (
     tensor_model_parallel_all_reduce,
 )
 from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.fused_moe import fused_moe
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
@@ -38,6 +37,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.triton_fused_moe import fused_moe
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
......
@@ -957,6 +957,21 @@ def direct_register_custom_op(
     fake_impl: Optional[Callable] = None,
     target_lib: Optional[Library] = None,
 ):
+    """
+    `torch.library.custom_op` can have significant overhead because it
+    needs to consider complicated dispatching logic. This function
+    directly registers a custom op and dispatches it to the CUDA backend.
+    See https://gist.github.com/youkaichao/ecbea9ec9fc79a45d2adce1784d7a9a5
+    for more details.
+    By default, the custom op is registered to the vLLM library. If you
+    want to register it to a different library, you can pass the library
+    object to the `target_lib` argument.
+    IMPORTANT: the lifetime of the operator is tied to the lifetime of the
+    library object. If you want to bind the operator to a different library,
+    make sure the library object is alive when the operator is used.
+    """
     import torch.library
     if hasattr(torch.library, "infer_schema"):
......
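
The new docstring explains direct_register_custom_op: it bypasses the dispatcher overhead of torch.library.custom_op by registering the op straight into a torch.library.Library and binding it to the CUDA backend. A hedged usage sketch follows; only fake_impl and target_lib are visible in the hunk above, so the module path and the leading parameters (op_name, op_func, mutates_args) are assumed from the vLLM helper this is adapted from, and the _scale functions and the my_ops namespace are purely illustrative:

import torch
from torch.library import Library

from sglang.srt.utils import direct_register_custom_op  # assumed module path; the hunk does not show the file name

def _scale(x: torch.Tensor, factor: float) -> torch.Tensor:
    return x * factor

def _scale_fake(x: torch.Tensor, factor: float) -> torch.Tensor:
    # Shape/dtype-only implementation used when tracing (e.g. under torch.compile).
    return torch.empty_like(x)

# Keep this Library object alive for as long as the op is used
# (see the IMPORTANT note in the docstring above).
my_lib = Library("my_ops", "FRAGMENT")

direct_register_custom_op(
    op_name="scale",          # assumed parameter names, mirroring the vLLM original
    op_func=_scale,
    mutates_args=[],
    fake_impl=_scale_fake,
    target_lib=my_lib,
)

if torch.cuda.is_available():  # the op is dispatched to the CUDA backend
    y = torch.ops.my_ops.scale(torch.randn(4, device="cuda"), 2.0)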