Unverified commit 62797440, authored by Chang Su, committed by GitHub

[Lint] Add `python/sglang` to ruff F401 checks and remove unused imports in files (#11685)

parent 2614adf9
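
The diff below applies two remediation patterns for ruff's F401 ("imported but unused") rule: imports with no remaining references are deleted outright, while imports kept only for their side effects or for re-export are annotated with `# noqa: F401` so the linter skips them. The check can be reproduced locally with `ruff check --select F401 python/sglang`. A minimal sketch of both patterns follows (the package name is a placeholder, not taken from this repository):

```python
# Sketch only: "custom_kernel_lib" is a hypothetical package used for illustration.
from typing import Optional  # kept: it is referenced below, so F401 stays quiet

# Pattern 1 -- a genuinely unused import is removed outright:
# from typing import Tuple   # F401 would flag this line, so the commit deletes it

# Pattern 2 -- an import needed only for its side effects (e.g. it registers
# custom ops at import time) is kept but explicitly silenced:
try:
    import custom_kernel_lib  # noqa: F401  # hypothetical side-effect import
except ImportError:
    pass  # optional dependency; nothing to register in this sketch


def head(items: list) -> Optional[int]:
    """Tiny helper so the surviving typing import is genuinely used."""
    return items[0] if items else None
```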
 from __future__ import annotations
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Optional
 import torch
 import torch.nn.functional as F
@@ -547,7 +547,7 @@ class Indexer(CustomOp):
         forward_batch: ForwardBatch,
         layer_id: int,
     ) -> torch.Tensor:
-        import custom_ops
+        import custom_ops  # noqa: F401
         import torch_npu
         from sglang.srt.layers.dp_attention import (
...
 from __future__ import annotations
-import sys
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Dict, List, Literal, Optional, TypeAlias
@@ -34,18 +33,18 @@ _is_hip = is_hip()
 if _is_hip:
     try:
-        from aiter import (
+        from aiter import (  # noqa: F401
             flash_attn_varlen_func,
             mha_batch_prefill_func,
             paged_attention_ragged,
         )
-        from aiter.mla import mla_decode_fwd, mla_prefill_fwd
+        from aiter.mla import mla_decode_fwd, mla_prefill_fwd  # noqa: F401
     except ImportError:
         print(
             "aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device."
         )
 else:
-    from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
+    from sgl_kernel.flash_attn import flash_attn_with_kvcache
 @dataclass(frozen=True)
...
@@ -372,4 +372,4 @@ if not (
     logger.info(
         "sgl-kernel layernorm implementation is not available on current platform. Fallback to other kernel libraries."
     )
-    from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
+    from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm  # noqa: F401
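
The layernorm hunk above is an example of the re-export case: `RMSNorm` and `GemmaRMSNorm` look unused inside that file, but other modules import them from it, which ruff cannot see, so the fallback import is annotated rather than removed. A hedged sketch of the same fallback-and-re-export shape, with placeholder module names:

```python
# Illustrative only: "fast_norm_lib" is a hypothetical optional backend, and
# torch.nn.LayerNorm merely stands in for a real RMSNorm implementation.
import logging

logger = logging.getLogger(__name__)

try:
    from fast_norm_lib import RMSNorm  # noqa: F401  # preferred kernel backend
except ImportError:
    logger.info("fast_norm_lib is not available; falling back to a torch implementation.")
    # Kept purely for re-export: downstream code does `from this_module import RMSNorm`.
    from torch.nn import LayerNorm as RMSNorm  # noqa: F401
```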
@@ -116,8 +116,6 @@ def cutlass_fused_experts_fp8(
     if is_cuda:
         from sglang.srt.layers.quantization.fp8_kernel import (
-            per_group_transpose,
-            per_token_group_quant_fp8_hopper_moe_mn_major,
             sglang_per_token_group_quant_fp8,
         )
...
 # SPDX-License-Identifier: Apache-2.0
 """Cutlass W4A8 MoE kernel."""
-import logging
 from typing import Optional
 import torch
...
 import logging
-from typing import List, Optional
 import torch
 import triton
-from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
-from sglang.srt.utils import ceil_div, dispose_tensor, is_cuda
-from sglang.utils import is_in_ci
+from sglang.srt.utils import ceil_div, is_cuda
 logger = logging.getLogger(__name__)
...
-from typing import Any, Dict, Optional, Union
+from typing import Optional, Union
 import torch
 from flashinfer.cute_dsl.blockscaled_gemm import grouped_gemm_nt_masked
...
@@ -43,13 +43,7 @@ from sglang.srt.utils import (
 )
 if is_flashinfer_available():
-    from flashinfer import (
-        RoutingMethodType,
-        fp4_quantize,
-        reorder_rows_for_gated_act_gemm,
-        shuffle_matrix_a,
-        shuffle_matrix_sf_a,
-    )
+    from flashinfer import RoutingMethodType, fp4_quantize
 _is_hip = is_hip()
 _is_cpu_amx_available = cpu_has_amx_support()
...
@@ -51,7 +51,9 @@ elif _is_hip:
 if _is_cuda or _is_hip:
-    from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
+    from sgl_kernel import (  # noqa: F401
+        moe_align_block_size as sgl_moe_align_block_size,
+    )
 @dataclass
...
@@ -2,7 +2,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from enum import IntEnum
-from functools import cache
 from typing import Optional
 import torch
...
@@ -3,7 +3,7 @@ from __future__ import annotations
 import logging
 from contextlib import nullcontext
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Tuple, Union
+from typing import TYPE_CHECKING, List, NamedTuple, Optional, Tuple, Union
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.layers.moe.token_dispatcher.base import (
...
@@ -22,7 +22,7 @@ try:
 except ImportError:
     use_mooncake_ep = False
-from enum import Enum, IntEnum, auto
+from enum import Enum, auto
 import torch
 import torch.distributed as dist
...
@@ -3,7 +3,7 @@ from __future__ import annotations
 import logging
 import warnings
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
 import torch
...
@@ -3,7 +3,6 @@ from __future__ import annotations
 import inspect
 from abc import ABC, abstractmethod
-from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
 import torch
...
@@ -5,7 +5,7 @@ from __future__ import annotations
 import enum
 import logging
 from enum import Enum
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, List
 import torch
 from compressed_tensors import CompressionFormat
@@ -21,14 +21,7 @@ from sglang.srt.layers.quantization.utils import (
     per_tensor_dequantize,
     replace_parameter,
 )
-from sglang.srt.utils import (
-    get_bool_env_var,
-    is_cpu,
-    is_cuda,
-    is_hip,
-    is_npu,
-    set_weight_attrs,
-)
+from sglang.srt.utils import get_bool_env_var, is_hip, set_weight_attrs
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
@@ -49,7 +42,7 @@ if _use_aiter:
     from sglang.srt.layers.moe.rocm_moe_utils import rocm_fused_experts_tkw1
 try:
-    import vllm
+    import vllm  # noqa: F401
     VLLM_AVAILABLE = True
 except ImportError:
...
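
Several of the `# noqa: F401` annotations above guard availability probes: the imported module is never used, the import exists only to decide whether an optional dependency is installed. A minimal sketch of that probe pattern (stdlib `importlib.metadata` stands in for the optional package so the snippet runs anywhere):

```python
# Availability probe: the import only answers "is the package installed?",
# so the unused-import warning is silenced on purpose.
try:
    import importlib.metadata  # noqa: F401  # stand-in for an optional dependency
    OPTIONAL_DEP_AVAILABLE = True
except ImportError:
    OPTIONAL_DEP_AVAILABLE = False

if __name__ == "__main__":
    print(f"optional dependency available: {OPTIONAL_DEP_AVAILABLE}")
```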
@@ -12,7 +12,7 @@ def _compute_enable_deep_gemm():
         return False
     try:
-        import deep_gemm
+        import deep_gemm  # noqa: F401
     except ImportError:
         return False
...
@@ -5,7 +5,7 @@ from typing import Tuple
 import torch
 from sglang.srt.layers.quantization.deep_gemm_wrapper import compile_utils
-from sglang.srt.layers.quantization.deep_gemm_wrapper.configurer import (
+from sglang.srt.layers.quantization.deep_gemm_wrapper.configurer import (  # noqa: F401
     DEEPGEMM_BLACKWELL,
     DEEPGEMM_SCALE_UE8M0,
     ENABLE_JIT_DEEPGEMM,
@@ -17,7 +17,7 @@ logger = logging.getLogger(__name__)
 if ENABLE_JIT_DEEPGEMM:
     import deep_gemm
-    from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor
+    from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor  # noqa: F401
 _SANITY_CHECK = get_bool_env_var("SGLANG_DEEPGEMM_SANITY_CHECK")
...
@@ -67,7 +67,7 @@ if _is_hip:
         raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
 else:
     try:
-        import vllm._C
+        import vllm._C  # noqa: F401
     except ImportError:
         raise ImportError("vllm is required when SGLANG_USE_AITER is set to False")
...
@@ -11,7 +11,6 @@ from torch.nn.parameter import Parameter
 from sglang.srt.layers.linear import LinearBase
 from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter
 from sglang.srt.layers.quantization.base_config import (
-    FusedMoEMethodBase,
     LinearMethodBase,
     QuantizationConfig,
     QuantizeMethodBase,
@@ -28,7 +27,7 @@ from sglang.srt.layers.quantization.marlin_utils_fp8 import (
     prepare_fp8_layer_for_marlin,
 )
 from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
-from sglang.srt.layers.quantization.utils import is_layer_skipped, replace_parameter
+from sglang.srt.layers.quantization.utils import is_layer_skipped
 from sglang.srt.utils import get_bool_env_var, is_cuda
 _is_cuda = is_cuda()
...
@@ -199,7 +199,6 @@ class GPTQConfig(QuantizationConfig):
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional[LinearMethodBase]:
         # Delay the import to avoid circular dependency
-        from sglang.srt.layers.linear import LinearBase
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
         if isinstance(layer, FusedMoE):
...