Unverified commit be0124bd, authored by Lianmin Zheng, committed by GitHub

Rename triton_fused_moe -> fused_moe_triton (#2163)

parent fe5d3e81
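
This commit is a mechanical rename: the package sglang.srt.layers.triton_fused_moe becomes sglang.srt.layers.fused_moe_triton, and the separate Grok-specific fused_moe module becomes fused_moe_grok. Only module paths change; the exported symbols stay the same. A minimal before/after sketch of the Triton-package imports, using paths taken from the hunks below:

# Before this commit:
#   from sglang.srt.layers.triton_fused_moe import fused_moe, get_config
#   from sglang.srt.layers.triton_fused_moe.layer import FusedMoE

# After this commit:
from sglang.srt.layers.fused_moe_triton import fused_moe, get_config
from sglang.srt.layers.fused_moe_triton.layer import FusedMoE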
@@ -376,7 +376,7 @@ def try_get_optimal_moe_config(
     M: int,
     is_marlin: bool = False,
 ):
-    from sglang.srt.layers.triton_fused_moe import get_config
+    from sglang.srt.layers.fused_moe_triton import get_config
     override_config = get_config()
     if override_config:
...
@@ -20,7 +20,7 @@ from sglang.srt.layers.quantization.base_config import (
 from sglang.srt.utils import set_weight_attrs
 
 if torch.cuda.is_available() or torch.hip.is_available():
-    from sglang.srt.layers.triton_fused_moe.fused_moe import fused_experts
+    from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts
 else:
     fused_experts = None  # type: ignore
 
@@ -514,7 +514,7 @@ class FusedMoE(torch.nn.Module):
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
     ):
-        from sglang.srt.layers.triton_fused_moe.fused_moe import (
+        from sglang.srt.layers.fused_moe_triton.fused_moe import (
             fused_topk,
             grouped_topk,
         )
...
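One of the hunks above sits inside a guarded import: when neither CUDA nor HIP is available, fused_experts is bound to None instead of being imported, since the Triton kernels only exist on GPU builds. A minimal sketch of how calling code can cope with that placeholder; run_moe and its argument list are illustrative, not sglang API, and the sketch checks only torch.cuda.is_available() because ROCm builds of PyTorch report availability through the same call:

import torch

# Mirror of the guarded import in the hunk above.
if torch.cuda.is_available():
    from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts
else:
    fused_experts = None  # type: ignore


def run_moe(hidden_states, w1, w2, topk_weights, topk_ids):
    # Fail loudly on CPU-only installs rather than calling None.
    if fused_experts is None:
        raise RuntimeError("fused_experts requires a CUDA or ROCm device")
    return fused_experts(hidden_states, w1, w2, topk_weights, topk_ids)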
@@ -68,7 +68,7 @@ def fp8_get_quant_method(self, layer, prefix):
         is_layer_skipped,
     )
-    from sglang.srt.layers.triton_fused_moe.layer import FusedMoE
+    from sglang.srt.layers.fused_moe_triton.layer import FusedMoE
 
     if isinstance(layer, LinearBase):
         if is_layer_skipped(prefix, self.ignored_layers):
...
@@ -28,6 +28,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.transformers_utils.configs.dbrx import DbrxConfig
 
+from sglang.srt.layers.fused_moe_triton import fused_moe
 from sglang.srt.layers.linear import (
     QKVParallelLinear,
     ReplicatedLinear,
@@ -36,7 +37,6 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.triton_fused_moe import fused_moe
 from sglang.srt.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE,
     ParallelLMHead,
...
@@ -30,6 +30,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.fused_moe_triton import fused_moe
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     MergedColumnParallelLinear,
@@ -40,7 +41,6 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.triton_fused_moe import fused_moe
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
...
@@ -31,6 +31,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.fused_moe_triton import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
@@ -41,7 +42,6 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.triton_fused_moe import FusedMoE
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
...
@@ -31,7 +31,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.model_loader.loader import DefaultModelLoader
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
-from sglang.srt.layers.fused_moe import FusedMoE
+from sglang.srt.layers.fused_moe_grok import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     QKVParallelLinear,
...
@@ -25,6 +25,7 @@ from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
+from sglang.srt.layers.fused_moe_triton import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     QKVParallelLinear,
@@ -35,7 +36,6 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.torchao_utils import apply_torchao_config_
-from sglang.srt.layers.triton_fused_moe import FusedMoE
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
...
@@ -38,11 +38,11 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.utils import print_warning_once
 
 from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.fused_moe_triton import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.triton_fused_moe import FusedMoE
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
...
@@ -30,6 +30,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.fused_moe_triton import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     MergedColumnParallelLinear,
@@ -41,7 +42,6 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.torchao_utils import apply_torchao_config_
-from sglang.srt.layers.triton_fused_moe import FusedMoE
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
...
@@ -34,10 +34,10 @@ from vllm.model_executor.layers.linear import (
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
+from sglang.srt.layers.fused_moe_triton import fused_moe
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.triton_fused_moe import fused_moe
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
...
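
The commit updates every in-tree import, so nothing inside sglang references the old path after this change. Out-of-tree code that still does would need a shim to keep working during migration; a minimal sketch of one, assuming it were wanted (the commit itself does not add such a shim):

# Hypothetical sglang/srt/layers/triton_fused_moe/__init__.py;
# NOT part of this commit, shown only as a migration aid.
import warnings

from sglang.srt.layers.fused_moe_triton import *  # noqa: F401,F403

warnings.warn(
    "sglang.srt.layers.triton_fused_moe was renamed to "
    "sglang.srt.layers.fused_moe_triton; update your imports.",
    DeprecationWarning,
    stacklevel=2,
)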