"src/vscode:/vscode.git/clone" did not exist on "2ab4d0e10f0bfd23c0940539debba205801340b4"
Unverified commit 5b214b50, authored by Cheng Wan, committed by GitHub
Browse files

[Refactor] move `deep_gemm_wrapper` out of `quantization` (#11784)

parent 13219e1e
...@@ -6,8 +6,8 @@ import triton ...@@ -6,8 +6,8 @@ import triton
from sgl_kernel import scaled_fp4_grouped_quant, silu_and_mul_scaled_fp4_grouped_quant from sgl_kernel import scaled_fp4_grouped_quant, silu_and_mul_scaled_fp4_grouped_quant
from sgl_kernel.elementwise import silu_and_mul from sgl_kernel.elementwise import silu_and_mul
from sglang.srt.layers import deep_gemm_wrapper
from sglang.srt.layers.moe.ep_moe.kernels import silu_and_mul_masked_post_quant_fwd from sglang.srt.layers.moe.ep_moe.kernels import silu_and_mul_masked_post_quant_fwd
from sglang.srt.layers.quantization import deep_gemm_wrapper
def _test_accuracy_once(E, M, K, input_dtype, device): def _test_accuracy_once(E, M, K, input_dtype, device):
......
...@@ -61,7 +61,6 @@ import torch.distributed as dist ...@@ -61,7 +61,6 @@ import torch.distributed as dist
from sglang.srt.configs.model_config import ModelConfig from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.distributed.parallel_state import destroy_distributed_environment from sglang.srt.distributed.parallel_state import destroy_distributed_environment
from sglang.srt.entrypoints.engine import _set_envs_and_config from sglang.srt.entrypoints.engine import _set_envs_and_config
from sglang.srt.environ import envs
from sglang.srt.layers.moe import initialize_moe_config from sglang.srt.layers.moe import initialize_moe_config
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.managers.scheduler import Scheduler from sglang.srt.managers.scheduler import Scheduler
......
...@@ -17,10 +17,10 @@ if is_cuda(): ...@@ -17,10 +17,10 @@ if is_cuda():
except ImportError as e: except ImportError as e:
deep_gemm = e deep_gemm = e
from sglang.srt.layers import deep_gemm_wrapper
from sglang.srt.layers.attention.nsa.utils import NSA_DUAL_STREAM from sglang.srt.layers.attention.nsa.utils import NSA_DUAL_STREAM
from sglang.srt.layers.dp_attention import get_attention_tp_group from sglang.srt.layers.dp_attention import get_attention_tp_group
from sglang.srt.layers.linear import ReplicatedLinear from sglang.srt.layers.linear import ReplicatedLinear
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.rotary_embedding import get_rope_wrapper from sglang.srt.layers.rotary_embedding import get_rope_wrapper
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
......
...@@ -8,9 +8,7 @@ import torch ...@@ -8,9 +8,7 @@ import torch
from tqdm import tqdm from tqdm import tqdm
from sglang.srt.environ import envs from sglang.srt.environ import envs
from sglang.srt.layers.quantization.deep_gemm_wrapper.configurer import ( from sglang.srt.layers.deep_gemm_wrapper.configurer import ENABLE_JIT_DEEPGEMM
ENABLE_JIT_DEEPGEMM,
)
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import ceil_div, get_bool_env_var from sglang.srt.utils import ceil_div, get_bool_env_var
......
...@@ -4,8 +4,8 @@ from typing import Tuple ...@@ -4,8 +4,8 @@ from typing import Tuple
import torch import torch
from sglang.srt.layers.quantization.deep_gemm_wrapper import compile_utils from sglang.srt.layers.deep_gemm_wrapper import compile_utils
from sglang.srt.layers.quantization.deep_gemm_wrapper.configurer import ( # noqa: F401 from sglang.srt.layers.deep_gemm_wrapper.configurer import ( # noqa: F401
DEEPGEMM_BLACKWELL, DEEPGEMM_BLACKWELL,
DEEPGEMM_SCALE_UE8M0, DEEPGEMM_SCALE_UE8M0,
ENABLE_JIT_DEEPGEMM, ENABLE_JIT_DEEPGEMM,
......
...@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union ...@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
import torch import torch
from sglang.srt import single_batch_overlap from sglang.srt import single_batch_overlap
from sglang.srt.layers import deep_gemm_wrapper
from sglang.srt.layers.moe import ( from sglang.srt.layers.moe import (
get_deepep_mode, get_deepep_mode,
get_moe_a2a_backend, get_moe_a2a_backend,
...@@ -19,7 +20,6 @@ from sglang.srt.layers.moe.ep_moe.kernels import ( ...@@ -19,7 +20,6 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
tma_align_input_scale, tma_align_input_scale,
) )
from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFusedMoE, FusedMoE from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFusedMoE, FusedMoE
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8 import Fp8Config from sglang.srt.layers.quantization.fp8 import Fp8Config
from sglang.srt.layers.quantization.fp8_kernel import ( from sglang.srt.layers.quantization.fp8_kernel import (
......
...@@ -105,10 +105,10 @@ class DeepGemmRunnerCore(MoeRunnerCore): ...@@ -105,10 +105,10 @@ class DeepGemmRunnerCore(MoeRunnerCore):
running_state: dict, running_state: dict,
) -> torch.Tensor: ) -> torch.Tensor:
from sglang.srt.layers import deep_gemm_wrapper
from sglang.srt.layers.moe.ep_moe.kernels import ( from sglang.srt.layers.moe.ep_moe.kernels import (
silu_and_mul_masked_post_quant_fwd, silu_and_mul_masked_post_quant_fwd,
) )
from sglang.srt.layers.quantization import deep_gemm_wrapper
hidden_states = runner_input.hidden_states hidden_states = runner_input.hidden_states
hidden_states_scale = runner_input.hidden_states_scale hidden_states_scale = runner_input.hidden_states_scale
......
...@@ -6,6 +6,7 @@ from dataclasses import dataclass ...@@ -6,6 +6,7 @@ from dataclasses import dataclass
from typing import TYPE_CHECKING, List, NamedTuple, Optional, Tuple, Union from typing import TYPE_CHECKING, List, NamedTuple, Optional, Tuple, Union
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.layers import deep_gemm_wrapper
from sglang.srt.layers.moe.token_dispatcher.base import ( from sglang.srt.layers.moe.token_dispatcher.base import (
BaseDispatcher, BaseDispatcher,
BaseDispatcherConfig, BaseDispatcherConfig,
...@@ -20,7 +21,6 @@ from sglang.srt.layers.moe.utils import ( ...@@ -20,7 +21,6 @@ from sglang.srt.layers.moe.utils import (
get_moe_runner_backend, get_moe_runner_backend,
is_tbo_enabled, is_tbo_enabled,
) )
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.utils import ( from sglang.srt.utils import (
get_bool_env_var, get_bool_env_var,
get_int_env_var, get_int_env_var,
......
...@@ -1007,11 +1007,11 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -1007,11 +1007,11 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
): ):
from sglang.srt.layers import deep_gemm_wrapper
from sglang.srt.layers.moe.utils import ( from sglang.srt.layers.moe.utils import (
get_moe_a2a_backend, get_moe_a2a_backend,
get_moe_runner_backend, get_moe_runner_backend,
) )
from sglang.srt.layers.quantization import deep_gemm_wrapper
self.moe_runner_config = moe_runner_config self.moe_runner_config = moe_runner_config
moe_runner_backend = get_moe_runner_backend() moe_runner_backend = get_moe_runner_backend()
......
...@@ -23,7 +23,7 @@ import torch ...@@ -23,7 +23,7 @@ import torch
import triton import triton
import triton.language as tl import triton.language as tl
from sglang.srt.layers.quantization import deep_gemm_wrapper from sglang.srt.layers import deep_gemm_wrapper
from sglang.srt.utils import ( from sglang.srt.utils import (
align, align,
direct_register_custom_op, direct_register_custom_op,
......
...@@ -2,7 +2,7 @@ from typing import Callable, List, Optional, Tuple ...@@ -2,7 +2,7 @@ from typing import Callable, List, Optional, Tuple
import torch import torch
from sglang.srt.layers.quantization import deep_gemm_wrapper from sglang.srt.layers import deep_gemm_wrapper
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8 from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
from sglang.srt.layers.quantization.mxfp4_tensor import MXFP4QuantizeUtil from sglang.srt.layers.quantization.mxfp4_tensor import MXFP4QuantizeUtil
from sglang.srt.utils import is_sm100_supported, offloader from sglang.srt.utils import is_sm100_supported, offloader
......
...@@ -64,6 +64,7 @@ from sglang.srt.eplb.expert_location import ( ...@@ -64,6 +64,7 @@ from sglang.srt.eplb.expert_location import (
set_global_expert_location_metadata, set_global_expert_location_metadata,
) )
from sglang.srt.eplb.expert_location_updater import ExpertLocationUpdater from sglang.srt.eplb.expert_location_updater import ExpertLocationUpdater
from sglang.srt.layers import deep_gemm_wrapper
from sglang.srt.layers.attention.attention_registry import ( from sglang.srt.layers.attention.attention_registry import (
ATTENTION_BACKENDS, ATTENTION_BACKENDS,
attn_backend_wrapper, attn_backend_wrapper,
...@@ -75,10 +76,7 @@ from sglang.srt.layers.dp_attention import ( ...@@ -75,10 +76,7 @@ from sglang.srt.layers.dp_attention import (
initialize_dp_attention, initialize_dp_attention,
) )
from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.quantization import ( from sglang.srt.layers.quantization import monkey_patch_isinstance_for_vllm_base_layer
deep_gemm_wrapper,
monkey_patch_isinstance_for_vllm_base_layer,
)
from sglang.srt.layers.sampler import Sampler from sglang.srt.layers.sampler import Sampler
from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
from sglang.srt.lora.lora_manager import LoRAManager from sglang.srt.lora.lora_manager import LoRAManager
......
...@@ -28,7 +28,6 @@ import torch.nn.functional as F ...@@ -28,7 +28,6 @@ import torch.nn.functional as F
from torch import nn from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from sglang.srt import single_batch_overlap
from sglang.srt.configs.model_config import ( from sglang.srt.configs.model_config import (
get_nsa_index_head_dim, get_nsa_index_head_dim,
get_nsa_index_n_heads, get_nsa_index_n_heads,
...@@ -48,6 +47,7 @@ from sglang.srt.distributed.device_communicators.pynccl_allocator import ( ...@@ -48,6 +47,7 @@ from sglang.srt.distributed.device_communicators.pynccl_allocator import (
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
from sglang.srt.layers import deep_gemm_wrapper
from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.amx_utils import PackWeightMethod from sglang.srt.layers.amx_utils import PackWeightMethod
from sglang.srt.layers.attention.npu_ops.mla_preprocess import ( from sglang.srt.layers.attention.npu_ops.mla_preprocess import (
...@@ -82,7 +82,6 @@ from sglang.srt.layers.moe import ( ...@@ -82,7 +82,6 @@ from sglang.srt.layers.moe import (
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.moe.topk import TopK, TopKOutputFormat from sglang.srt.layers.moe.topk import TopK, TopKOutputFormat
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8_kernel import ( from sglang.srt.layers.quantization.fp8_kernel import (
is_fp8_fnuz, is_fp8_fnuz,
......
...@@ -44,6 +44,7 @@ from sglang.srt.distributed import ( ...@@ -44,6 +44,7 @@ from sglang.srt.distributed import (
) )
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
from sglang.srt.layers import deep_gemm_wrapper
from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
from sglang.srt.layers.dp_attention import ( from sglang.srt.layers.dp_attention import (
...@@ -62,7 +63,6 @@ from sglang.srt.layers.moe.ep_moe.kernels import zero_experts_compute_triton ...@@ -62,7 +63,6 @@ from sglang.srt.layers.moe.ep_moe.kernels import zero_experts_compute_triton
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.moe.topk import StandardTopKOutput, TopK from sglang.srt.layers.moe.topk import StandardTopKOutput, TopK
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
from sglang.srt.layers.quantization.fp8_utils import ( from sglang.srt.layers.quantization.fp8_utils import (
......
...@@ -39,6 +39,7 @@ from torch import nn ...@@ -39,6 +39,7 @@ from torch import nn
from sglang.srt.configs import LongcatFlashConfig from sglang.srt.configs import LongcatFlashConfig
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.layers import deep_gemm_wrapper
from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
from sglang.srt.layers.dp_attention import ( from sglang.srt.layers.dp_attention import (
get_attention_tp_rank, get_attention_tp_rank,
...@@ -48,7 +49,6 @@ from sglang.srt.layers.dp_attention import ( ...@@ -48,7 +49,6 @@ from sglang.srt.layers.dp_attention import (
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import ReplicatedLinear from sglang.srt.layers.linear import ReplicatedLinear
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
from sglang.srt.layers.quantization.fp8_utils import ( from sglang.srt.layers.quantization.fp8_utils import (
......
...@@ -3,9 +3,9 @@ from typing import TYPE_CHECKING, Any, Callable, Optional ...@@ -3,9 +3,9 @@ from typing import TYPE_CHECKING, Any, Callable, Optional
import torch import torch
from sglang.srt.layers import deep_gemm_wrapper
from sglang.srt.layers.moe import get_moe_runner_backend from sglang.srt.layers.moe import get_moe_runner_backend
from sglang.srt.layers.moe.utils import is_sbo_enabled from sglang.srt.layers.moe.utils import is_sbo_enabled
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import get_int_env_var from sglang.srt.utils import get_int_env_var
......
...@@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Sequence ...@@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Sequence
import torch import torch
from sglang.srt.layers import deep_gemm_wrapper
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.communicator import ( from sglang.srt.layers.communicator import (
CommunicateContext, CommunicateContext,
...@@ -24,7 +25,6 @@ from sglang.srt.layers.moe.token_dispatcher import ( ...@@ -24,7 +25,6 @@ from sglang.srt.layers.moe.token_dispatcher import (
DeepEPDispatcher, DeepEPDispatcher,
MooncakeEPDispatcher, MooncakeEPDispatcher,
) )
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.managers.schedule_batch import ScheduleBatch from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.model_executor.forward_batch_info import ( from sglang.srt.model_executor.forward_batch_info import (
ForwardBatch, ForwardBatch,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment