Unverified Commit b509db58 authored by Yineng Zhang, committed by GitHub

feat: remove the dependency on FusedMoE (#2153)

parent dbe17293
......@@ -57,12 +57,23 @@ __all__ = [
"QUANTIZATION_METHODS",
]
"""
def fp8_get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
def fp8_get_quant_method(self, layer, prefix):
from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.layers.quantization.fp8 import (
Fp8LinearMethod,
Fp8MoEMethod,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
is_layer_skipped,
)
from sglang.srt.layers.triton_fused_moe.layer import FusedMoE
if isinstance(layer, LinearBase):
if is_layer_skipped(prefix, self.ignored_layers):
from sglang.srt.layers.linear import UnquantizedLinearMethod
return UnquantizedLinearMethod()
return Fp8LinearMethod(self)
elif isinstance(layer, FusedMoE):
......@@ -71,4 +82,3 @@ def fp8_get_quant_method(
setattr(Fp8Config, "get_quant_method", fp8_get_quant_method)
"""
from contextlib import contextmanager
from typing import Any, Dict, Optional
import sglang.srt.layers.triton_fused_moe.fused_moe # noqa
from sglang.srt.layers.triton_fused_moe.fused_moe import (
fused_experts,
fused_topk,
get_config_file_name,
grouped_topk,
)
from sglang.srt.layers.triton_fused_moe.layer import (
FusedMoE,
FusedMoEMethodBase,
FusedMoeWeightScaleSupported,
)
_config: Optional[Dict[str, Any]] = None
@contextmanager
def override_config(config):
global _config
old_config = _config
_config = config
yield
_config = old_config
def get_config() -> Optional[Dict[str, Any]]:
return _config
__all__ = [
"FusedMoE",
"FusedMoEMethodBase",
"FusedMoeWeightScaleSupported",
"override_config",
"get_config",
"fused_moe",
"fused_topk",
"fused_experts",
"get_config_file_name",
"grouped_topk",
]
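As a usage sketch only (not part of the diff), the re-exported override_config context manager temporarily pins the configuration that get_config() reports; the config keys below are illustrative, not taken from this commit.

from sglang.srt.layers.triton_fused_moe import get_config, override_config

# Illustrative kernel config; real keys come from the tuned JSON files described below.
custom_cfg = {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "num_warps": 4}

assert get_config() is None            # no override active by default
with override_config(custom_cfg):
    assert get_config() == custom_cfg  # visible to the fused_moe kernels inside the block
assert get_config() is None            # previous value restored on exit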
This directory contains tuned configurations for different settings of the fused_moe kernel.
For different settings of
- E (number of experts)
- N (intermediate size)
- device_name (torch.cuda.get_device_name())
the JSON file contains a mapping from M (batch size) to the chosen configuration.
The example configurations provided are for the Mixtral model, with TP2 on H100
and TP4 on A100. Mixtral has intermediate size N = 14336, so N = 7168 for TP2
and N = 3584 for TP4.
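A hedged illustration (not part of the diff) of how a tuned config could be looked up from Python; the (E, N, dtype) signature of get_config_file_name and the JSON key names are assumptions based on the upstream Triton fused_moe kernel, not on this commit.

import json
from sglang.srt.layers.triton_fused_moe.fused_moe import get_config_file_name

E, N = 8, 7168  # e.g. Mixtral with TP2: intermediate size 14336 / 2

# Assumed signature (E, N, dtype); the returned name embeds E, N and the
# current device name, matching the JSON files in this directory.
filename = get_config_file_name(E, N, None)

with open(filename) as f:
    configs = json.load(f)  # maps batch size M (as a string key) to a kernel config dict

M = 64
config = configs.get(str(M)) or configs[min(configs, key=lambda m: abs(int(m) - M))]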
......@@ -27,7 +27,6 @@ from vllm.distributed import (
get_tp_group,
tensor_model_parallel_all_reduce,
)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
......@@ -42,6 +41,7 @@ from sglang.srt.layers.linear import (
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.triton_fused_moe import FusedMoE
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
......
......@@ -31,7 +31,7 @@ import time
import warnings
from importlib.metadata import PackageNotFoundError, version
from io import BytesIO
from typing import Any, Dict, List, Optional, Protocol, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, Union
import numpy as np
import psutil
......@@ -45,6 +45,7 @@ from packaging import version as pkg_version
from starlette.routing import Mount
from torch import nn
from torch.func import functional_call
from torch.library import Library
from torch.profiler import ProfilerActivity, profile, record_function
from triton.runtime.cache import (
FileCacheManager,
......@@ -930,3 +931,44 @@ def get_nvgpu_memory_capacity():
def crash_on_warnings():
# Crash on warning if we are running CI tests
return os.getenv("SGLANG_IS_IN_CI", "false") == "true"
def get_device_name(device_id: int = 0) -> str:
if hasattr(torch, "cuda") and torch.cuda.is_available():
return torch.cuda.get_device_name(device_id)
if hasattr(torch, "hip") and torch.hip.is_available():
return torch.hip.get_device_name(device_id)
if hasattr(torch, "xpu") and torch.xpu.is_available():
return torch.xpu.get_device_name(device_id)
if hasattr(torch, "hpu") and torch.hpu.is_available():
return torch.hpu.get_device_name(device_id)
sglang_lib = Library("sglang", "FRAGMENT") # noqa
def direct_register_custom_op(
op_name: str,
op_func: Callable,
mutates_args: List[str],
fake_impl: Optional[Callable] = None,
target_lib: Optional[Library] = None,
):
import torch.library
if hasattr(torch.library, "infer_schema"):
schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args)
else:
# for pytorch 2.4
import torch._custom_op.impl
schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args)
my_lib = target_lib or sglang_lib
my_lib.define(op_name + schema_str)
my_lib.impl(op_name, op_func, "CUDA")
if fake_impl is not None:
my_lib._register_fake(op_name, fake_impl)
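A minimal usage sketch (not from this commit) of direct_register_custom_op; the op below is hypothetical, and type annotations are required so the schema can be inferred.

import torch

def scale_inplace(x: torch.Tensor, factor: float) -> None:
    # Hypothetical op that mutates its first argument in place.
    x.mul_(factor)

def scale_inplace_fake(x: torch.Tensor, factor: float) -> None:
    # Fake (meta) implementation for tracing/compile: nothing to do, shapes are unchanged.
    pass

direct_register_custom_op(
    op_name="scale_inplace",
    op_func=scale_inplace,
    mutates_args=["x"],  # names of the arguments the op mutates
    fake_impl=scale_inplace_fake,
)

# Since sglang_lib is Library("sglang", ...), the op is then reachable on CUDA tensors as:
#   torch.ops.sglang.scale_inplace(t, 2.0)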