Unverified commit 5b214b50, authored by Cheng Wan and committed by GitHub

[Refactor] move `deep_gemm_wrapper` out of `quantization` (#11784)
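
This is a pure import-path refactor: the `deep_gemm_wrapper` module moves from `sglang.srt.layers.quantization` up one level to `sglang.srt.layers`, and all call sites (MoE, attention, model runner, and model code) are updated to match. Below is a minimal sketch of the resulting import change; the `try`/`except` fallback is a hypothetical shim for out-of-tree code and is not part of this commit.

```python
# Path mapping applied across this commit:
#   sglang.srt.layers.quantization.deep_gemm_wrapper -> sglang.srt.layers.deep_gemm_wrapper

# Hypothetical compatibility shim (not part of this commit) for external code
# that must import deep_gemm_wrapper on sglang versions from either side of
# this refactor.
try:
    # New location, introduced by this commit.
    from sglang.srt.layers import deep_gemm_wrapper
    from sglang.srt.layers.deep_gemm_wrapper.configurer import ENABLE_JIT_DEEPGEMM
except ImportError:
    # Old location, removed by this commit.
    from sglang.srt.layers.quantization import deep_gemm_wrapper
    from sglang.srt.layers.quantization.deep_gemm_wrapper.configurer import (
        ENABLE_JIT_DEEPGEMM,
    )
```

Keeping the wrapper directly under `sglang.srt.layers` also reflects that its consumers are not limited to quantization code, as the hunks below show.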

parent 13219e1e
@@ -6,8 +6,8 @@ import triton
from sgl_kernel import scaled_fp4_grouped_quant, silu_and_mul_scaled_fp4_grouped_quant
from sgl_kernel.elementwise import silu_and_mul
from sglang.srt.layers import deep_gemm_wrapper
from sglang.srt.layers.moe.ep_moe.kernels import silu_and_mul_masked_post_quant_fwd
from sglang.srt.layers.quantization import deep_gemm_wrapper
def _test_accuracy_once(E, M, K, input_dtype, device):
@@ -61,7 +61,6 @@ import torch.distributed as dist
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.distributed.parallel_state import destroy_distributed_environment
from sglang.srt.entrypoints.engine import _set_envs_and_config
from sglang.srt.environ import envs
from sglang.srt.layers.moe import initialize_moe_config
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.managers.scheduler import Scheduler
@@ -17,10 +17,10 @@ if is_cuda():
except ImportError as e:
deep_gemm = e
from sglang.srt.layers import deep_gemm_wrapper
from sglang.srt.layers.attention.nsa.utils import NSA_DUAL_STREAM
from sglang.srt.layers.dp_attention import get_attention_tp_group
from sglang.srt.layers.linear import ReplicatedLinear
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.rotary_embedding import get_rope_wrapper
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
@@ -8,9 +8,7 @@ import torch
from tqdm import tqdm
from sglang.srt.environ import envs
from sglang.srt.layers.quantization.deep_gemm_wrapper.configurer import (
ENABLE_JIT_DEEPGEMM,
)
from sglang.srt.layers.deep_gemm_wrapper.configurer import ENABLE_JIT_DEEPGEMM
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import ceil_div, get_bool_env_var
@@ -4,8 +4,8 @@ from typing import Tuple
import torch
from sglang.srt.layers.quantization.deep_gemm_wrapper import compile_utils
from sglang.srt.layers.quantization.deep_gemm_wrapper.configurer import ( # noqa: F401
from sglang.srt.layers.deep_gemm_wrapper import compile_utils
from sglang.srt.layers.deep_gemm_wrapper.configurer import ( # noqa: F401
DEEPGEMM_BLACKWELL,
DEEPGEMM_SCALE_UE8M0,
ENABLE_JIT_DEEPGEMM,
@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
import torch
from sglang.srt import single_batch_overlap
from sglang.srt.layers import deep_gemm_wrapper
from sglang.srt.layers.moe import (
get_deepep_mode,
get_moe_a2a_backend,
@@ -19,7 +20,6 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
tma_align_input_scale,
)
from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFusedMoE, FusedMoE
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8 import Fp8Config
from sglang.srt.layers.quantization.fp8_kernel import (
@@ -105,10 +105,10 @@ class DeepGemmRunnerCore(MoeRunnerCore):
running_state: dict,
) -> torch.Tensor:
from sglang.srt.layers import deep_gemm_wrapper
from sglang.srt.layers.moe.ep_moe.kernels import (
silu_and_mul_masked_post_quant_fwd,
)
from sglang.srt.layers.quantization import deep_gemm_wrapper
hidden_states = runner_input.hidden_states
hidden_states_scale = runner_input.hidden_states_scale
@@ -6,6 +6,7 @@ from dataclasses import dataclass
from typing import TYPE_CHECKING, List, NamedTuple, Optional, Tuple, Union
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.layers import deep_gemm_wrapper
from sglang.srt.layers.moe.token_dispatcher.base import (
BaseDispatcher,
BaseDispatcherConfig,
@@ -20,7 +21,6 @@ from sglang.srt.layers.moe.utils import (
get_moe_runner_backend,
is_tbo_enabled,
)
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.utils import (
get_bool_env_var,
get_int_env_var,
@@ -1007,11 +1007,11 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
):
from sglang.srt.layers import deep_gemm_wrapper
from sglang.srt.layers.moe.utils import (
get_moe_a2a_backend,
get_moe_runner_backend,
)
from sglang.srt.layers.quantization import deep_gemm_wrapper
self.moe_runner_config = moe_runner_config
moe_runner_backend = get_moe_runner_backend()
@@ -23,7 +23,7 @@ import torch
import triton
import triton.language as tl
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers import deep_gemm_wrapper
from sglang.srt.utils import (
align,
direct_register_custom_op,
@@ -2,7 +2,7 @@ from typing import Callable, List, Optional, Tuple
import torch
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers import deep_gemm_wrapper
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
from sglang.srt.layers.quantization.mxfp4_tensor import MXFP4QuantizeUtil
from sglang.srt.utils import is_sm100_supported, offloader
@@ -64,6 +64,7 @@ from sglang.srt.eplb.expert_location import (
set_global_expert_location_metadata,
)
from sglang.srt.eplb.expert_location_updater import ExpertLocationUpdater
from sglang.srt.layers import deep_gemm_wrapper
from sglang.srt.layers.attention.attention_registry import (
ATTENTION_BACKENDS,
attn_backend_wrapper,
@@ -75,10 +76,7 @@ from sglang.srt.layers.dp_attention import (
initialize_dp_attention,
)
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.quantization import (
deep_gemm_wrapper,
monkey_patch_isinstance_for_vllm_base_layer,
)
from sglang.srt.layers.quantization import monkey_patch_isinstance_for_vllm_base_layer
from sglang.srt.layers.sampler import Sampler
from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
from sglang.srt.lora.lora_manager import LoRAManager
@@ -28,7 +28,6 @@ import torch.nn.functional as F
from torch import nn
from transformers import PretrainedConfig
from sglang.srt import single_batch_overlap
from sglang.srt.configs.model_config import (
get_nsa_index_head_dim,
get_nsa_index_n_heads,
@@ -48,6 +47,7 @@ from sglang.srt.distributed.device_communicators.pynccl_allocator import (
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
from sglang.srt.layers import deep_gemm_wrapper
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.amx_utils import PackWeightMethod
from sglang.srt.layers.attention.npu_ops.mla_preprocess import (
@@ -82,7 +82,6 @@ from sglang.srt.layers.moe import (
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.moe.topk import TopK, TopKOutputFormat
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8_kernel import (
is_fp8_fnuz,
@@ -44,6 +44,7 @@ from sglang.srt.distributed import (
)
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
from sglang.srt.layers import deep_gemm_wrapper
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
from sglang.srt.layers.dp_attention import (
@@ -62,7 +63,6 @@ from sglang.srt.layers.moe.ep_moe.kernels import zero_experts_compute_triton
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.moe.topk import StandardTopKOutput, TopK
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
from sglang.srt.layers.quantization.fp8_utils import (
@@ -39,6 +39,7 @@ from torch import nn
from sglang.srt.configs import LongcatFlashConfig
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.layers import deep_gemm_wrapper
from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
from sglang.srt.layers.dp_attention import (
get_attention_tp_rank,
@@ -48,7 +49,6 @@ from sglang.srt.layers.dp_attention import (
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import ReplicatedLinear
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
from sglang.srt.layers.quantization.fp8_utils import (
@@ -3,9 +3,9 @@ from typing import TYPE_CHECKING, Any, Callable, Optional
import torch
from sglang.srt.layers import deep_gemm_wrapper
from sglang.srt.layers.moe import get_moe_runner_backend
from sglang.srt.layers.moe.utils import is_sbo_enabled
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import get_int_env_var
@@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Sequence
import torch
from sglang.srt.layers import deep_gemm_wrapper
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.communicator import (
CommunicateContext,
@@ -24,7 +25,6 @@ from sglang.srt.layers.moe.token_dispatcher import (
DeepEPDispatcher,
MooncakeEPDispatcher,
)
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.model_executor.forward_batch_info import (
ForwardBatch,