Unverified Commit 9d61205d authored by Liangsheng Yin, committed by GitHub

[lint] improve ruff check (#11922)


Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
parent 590bc4b7
@@ -27,7 +27,9 @@ repos:
     rev: v0.11.7
     hooks:
       - id: ruff
-        args: [--select=F401,F821, --fixable=F401]
+        args:
+          - --select=F401,F821
+          - --fix
         files: ^(benchmark/|docs/|examples/|python/sglang/)
         exclude: __init__\.py$|\.ipynb$|^python/sglang/srt/grpc/.*_pb2\.py$|^python/sglang/srt/grpc/.*_pb2_grpc\.py$|^python/sglang/srt/grpc/.*_pb2\.pyi$|^python/sglang/srt/grpc/.*_pb2_grpc\.pyi$
   - repo: https://github.com/psf/black
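For reference, F401 flags unused imports and F821 flags undefined names; F401 is auto-fixable, so adding `--fix` lets the hook delete such imports, while F821 only reports. A minimal sketch (hypothetical module, not part of this commit) of what each rule catches:

```python
# lint_demo.py -- hypothetical example of the two selected ruff rules
import os      # F401: imported but never used (removed by `ruff --fix`)
import json


def load(path):
    with open(path) as f:
        data = json.load(f)
    logger.info("loaded %s", path)  # F821: `logger` is never defined here
    return data
```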
@@ -167,6 +167,7 @@ class MiniMaxText01LightningAttention(nn.Module):
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         use_cache: bool = False,
         slope_rate: Optional[torch.Tensor] = None,
+        do_eval: bool = False,
         **kwargs,
     ):
         if (not self.training) and (not do_eval):
 import itertools
+import logging
 import math
 import os
 from typing import Optional, Tuple
@@ -10,6 +11,8 @@ import triton
 import triton.language as tl
 from einops import rearrange
 
+logger = logging.getLogger(__name__)
+
 # Adapted from https://github.com/OpenNLPLab/lightning-attention/blob/main/lightning_attn/ops/triton/lightning_attn2.py
 @triton.jit
@@ -302,6 +305,7 @@ class MiniMaxText01LightningAttention(nn.Module):
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         use_cache: bool = False,
         slope_rate: Optional[torch.Tensor] = None,
+        do_eval: bool = False,
         **kwargs,
     ):
         if (not self.training) and (not do_eval):
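In both lightning-attention files, `do_eval` is added to the signature because the method body already branched on it, which F821 reports as an undefined name. A minimal sketch of the before/after shape (simplified, hypothetical class):

```python
import torch.nn as nn


class Attn(nn.Module):
    # Before: `do_eval` was referenced but never defined -> F821 (NameError at runtime).
    # After: declare it as a keyword argument with a safe default.
    def forward(self, x, do_eval: bool = False, **kwargs):
        if (not self.training) and (not do_eval):
            return self._decode_path(x)
        return self._train_path(x)

    def _decode_path(self, x):
        return x

    def _train_path(self, x):
        return x
```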
@@ -16,6 +16,7 @@ import argparse
 import dataclasses
 import itertools
 import json
+import logging
 import multiprocessing
 import os
 import random
@@ -39,6 +40,8 @@ from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import is_blackwell, kill_process_tree
 from sglang.test.test_utils import is_in_ci, write_github_step_summary
 
+logger = logging.getLogger(__name__)
+
 class ProfileLinks(BaseModel):
     """Pydantic model for profile trace links."""
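Several files in this commit repeat the same recipe: the module calls `logger.*` without defining it, F821 fires, and the fix is the standard module-level logger. A minimal sketch, assuming the module previously used `logger` without declaring it:

```python
import logging

# One logger per module, named after the module path, so handlers and levels
# configured on the package root apply to it as well.
logger = logging.getLogger(__name__)


def warmup():
    logger.info("starting warmup")  # previously an F821: `logger` undefined
```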
@@ -77,8 +77,8 @@ class CommonKVManager(BaseKVManager):
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
             self._register_to_bootstrap()
-            self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {}
-            self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {}
+            self.transfer_infos = {}
+            self.decode_kv_args_table = {}
             self.pp_group = get_pp_group()
         elif self.disaggregation_mode == DisaggregationMode.DECODE:
             self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
@@ -9,7 +9,7 @@ import struct
 import threading
 import time
 from collections import defaultdict
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Set, Tuple
 
 import numpy as np
 import numpy.typing as npt
@@ -14,6 +14,7 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Optional,
 import jinja2
 import openai.types.responses as openai_responses_types
+import orjson
 from fastapi import Request
 from fastapi.responses import ORJSONResponse
 from openai.types.responses import (
@@ -1063,7 +1064,7 @@ class OpenAIServingResponses(OpenAIServingChat):
                 ):
                     function_name = previous_item.recipient[len("browser.") :]
                     action = None
-                    parsed_args = ororjson.loads(previous_item.content[0].text)
+                    parsed_args = orjson.loads(previous_item.content[0].text)
                     if function_name == "search":
                         action = openai_responses_types.response_function_web_search.ActionSearch(
                             type="search",
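This hunk is a genuine F821 catch: `ororjson` was a typo for `orjson`, which in addition was never imported in this module. A tiny standalone sketch of the corrected call (illustrative data, not this file's logic):

```python
import orjson

# orjson.loads accepts str or bytes and returns plain Python objects.
parsed_args = orjson.loads('{"query": "sglang", "topn": 3}')
assert parsed_args["topn"] == 3
```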
@@ -194,7 +194,7 @@ class FlashInferAttnBackend(AttentionBackend):
             )
             if init_new_workspace:
                 self.workspace_buffer = torch.empty(
-                    global_config.flashinfer_workspace_size,
+                    envs.SGLANG_FLASHINFER_WORKSPACE_SIZE.get(),
                     dtype=torch.uint8,
                     device=model_runner.device,
                 )
@@ -38,6 +38,9 @@ from sglang.srt.utils import (
 )
 
 if TYPE_CHECKING:
+    from sglang.srt.layers.attention.flashinfer_mla_backend import (
+        FlashInferMlaAttnBackend,
+    )
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
     from sglang.srt.speculative.spec_info import SpecInput
@@ -66,7 +69,7 @@ global_workspace_buffer = None
 
 class FlashInferMhaChunkKVRunner:
     def __init__(
-        self, model_runner: ModelRunner, attn_backend: "FlashInferMlaAttnBackend"
+        self, model_runner: ModelRunner, attn_backend: FlashInferMlaAttnBackend
     ):
         # Parse Constants
         self.num_local_heads = (
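This hunk, the `memory_pool` hunk, and the `utils` hunk below all use the same recipe: import a name only under `typing.TYPE_CHECKING` so the annotation resolves for linters and type checkers without creating a runtime import (or an import cycle), and then write the annotation unquoted. A minimal sketch of the pattern, with hypothetical module names:

```python
from __future__ import annotations  # annotations are evaluated lazily (PEP 563)

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by type checkers and linters, never executed at runtime,
    # so it cannot introduce an import cycle or a hard dependency.
    from heavy_backend import FlashBackend  # hypothetical module


class Runner:
    def __init__(self, attn_backend: FlashBackend):  # unquoted, still valid
        self.attn_backend = attn_backend
```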
@@ -13,7 +13,8 @@ from triton_kernels.matmul_ogs import (
     PrecisionConfig,
     matmul_ogs,
 )
-from triton_kernels.numerics import InFlexData
+from triton_kernels.numerics import InFlexData, MicroscalingCtx
+from triton_kernels.quantization import downcast_to_mxfp
 from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx
 from triton_kernels.swiglu import swiglu_fn
@@ -119,14 +120,14 @@ def triton_kernel_fused_experts(
     block_shape: Optional[list[int]] = None,
 ) -> torch.Tensor:
-    assert use_fp8_w8a8 == False, "use_fp8_w8a8 is not supported"
-    assert per_channel_quant == False, "per_channel_quant is not supported"
-    assert expert_map == None, "expert_map is not supported"
-    assert w1_scale == None, "w1_scale is not supported"
-    assert w2_scale == None, "w2_scale is not supported"
-    assert a1_scale == None, "a1_scale is not supported"
-    assert a2_scale == None, "a2_scale is not supported"
-    assert block_shape == None, "block_shape is not supported"
+    assert use_fp8_w8a8 is False, "use_fp8_w8a8 is not supported"
+    assert per_channel_quant is False, "per_channel_quant is not supported"
+    assert expert_map is None, "expert_map is not supported"
+    assert w1_scale is None, "w1_scale is not supported"
+    assert w2_scale is None, "w2_scale is not supported"
+    assert a1_scale is None, "a1_scale is not supported"
+    assert a2_scale is None, "a2_scale is not supported"
+    assert block_shape is None, "block_shape is not supported"
 
     # type check
     assert hidden_states.dtype == torch.bfloat16, "hidden_states must be bfloat16"
@@ -143,7 +144,7 @@ def triton_kernel_fused_experts(
     ), f"w2 shape[-1] {w2.shape[-1]} must be equal to w1 shape[1] {w1.shape[1]}"
 
     # feature check
-    assert inplace == False, "Inplace is not supported in new triton MoE kernel"
+    assert inplace is False, "Inplace is not supported in new triton MoE kernel"
 
     M, K = hidden_states.shape
     E, _, N = w1.shape
@@ -264,14 +265,14 @@ def triton_kernel_fused_experts_with_bias(
     gemm1_alpha: Optional[float] = None,
     gemm1_clamp_limit: Optional[float] = None,
 ) -> torch.Tensor:
-    assert use_fp8_w8a8 == False, "use_fp8_w8a8 is not supported"
-    assert per_channel_quant == False, "per_channel_quant is not supported"
-    assert expert_map == None, "expert_map is not supported"
-    assert w1_scale == None, "w1_scale is not supported"
-    assert w2_scale == None, "w2_scale is not supported"
-    assert a1_scale == None, "a1_scale is not supported"
-    assert a2_scale == None, "a2_scale is not supported"
-    assert block_shape == None, "block_shape is not supported"
+    assert use_fp8_w8a8 is False, "use_fp8_w8a8 is not supported"
+    assert per_channel_quant is False, "per_channel_quant is not supported"
+    assert expert_map is None, "expert_map is not supported"
+    assert w1_scale is None, "w1_scale is not supported"
+    assert w2_scale is None, "w2_scale is not supported"
+    assert a1_scale is None, "a1_scale is not supported"
+    assert a2_scale is None, "a2_scale is not supported"
+    assert block_shape is None, "block_shape is not supported"
 
     # type check
     assert hidden_states.dtype == torch.bfloat16, "hidden_states must be bfloat16"
@@ -290,7 +291,7 @@ def triton_kernel_fused_experts_with_bias(
     ), f"w2 shape[-1] {w2.shape[-1]} must be equal to w1 shape[1] {w1.shape[1]}"
 
     # feature check
-    assert inplace == False, "Inplace is not supported in new triton MoE kernel"
+    assert inplace is False, "Inplace is not supported in new triton MoE kernel"
 
     E, _, _ = w1.shape
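The repeated change in this file swaps equality for identity when comparing against `None` and `False`, the PEP 8 rule that ruff's E711/E712 codes describe: `is` checks object identity and cannot be overridden, while `==` dispatches to `__eq__`. A small illustration of why the distinction matters (hypothetical class):

```python
import numpy as np


class AlwaysEqual:
    def __eq__(self, other):
        return True  # a pathological __eq__ makes `== None` lie


assert (AlwaysEqual() == None) is True   # equality can be overridden...
assert (AlwaysEqual() is None) is False  # ...identity cannot

# NumPy arrays are the practical case: `==` broadcasts element-wise,
# so `arr == None` yields an array rather than a single bool.
arr = np.zeros(3)
assert (arr is None) is False
```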
@@ -44,6 +44,13 @@ from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod
 from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
 
 try:
+    from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_24 import (
+        CompressedTensors24,
+    )
+    from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_w4a16_sparse24 import (
+        W4A16SPARSE24_SUPPORTED_BITS,
+        CompressedTensorsW4A16Sparse24,
+    )
     from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import (
         WNA16_SUPPORTED_BITS,
         CompressedTensorsWNA16,
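These vllm scheme imports are added inside the existing `try:` block, so the names referenced later in the file resolve for F821 while vllm remains an optional dependency. A minimal sketch of the guarded-import pattern (hypothetical names, not this module's actual fallback logic):

```python
try:
    # Optional dependency: only needed when this quantization scheme is used.
    from optional_dep.schemes import FancyScheme  # hypothetical import
    HAS_FANCY = True
except ImportError:
    FancyScheme = None
    HAS_FANCY = False


def make_scheme():
    if not HAS_FANCY:
        raise RuntimeError("optional_dep is required for this scheme")
    return FancyScheme()
```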
@@ -2,7 +2,7 @@
 from __future__ import annotations
 
 import logging
-from typing import Any, Optional
+from typing import Any, List, Optional
 
 import torch
 from torch.nn import Module
@@ -47,6 +47,8 @@ from sglang.srt.utils import get_bool_env_var, is_cuda, is_npu, next_power_of_2
 
 if TYPE_CHECKING:
     from sglang.srt.managers.cache_controller import LayerDoneCounter
+    from sglang.srt.managers.schedule_batch import Req
 
 logger = logging.getLogger(__name__)
@@ -341,7 +343,7 @@ class HybridReqToTokenPool(ReqToTokenPool):
     # For chunk prefill req, we do not need to allocate mamba cache,
     # We could use allocated mamba cache instead.
     def alloc(
-        self, need_size: int, reqs: Optional[List["Req"]] = None
+        self, need_size: int, reqs: Optional[List[Req]] = None
     ) -> Optional[List[int]]:
         select_index = super().alloc(need_size)
         if select_index == None:
@@ -110,6 +110,9 @@ def convert_bin_to_safetensor_file(
     dirname = os.path.dirname(sf_filename)
     os.makedirs(dirname, exist_ok=True)
 
+    from safetensors.torch import save_file
+
     save_file(loaded, sf_filename, metadata={"format": "pt"})
 
     # check file size
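Here `save_file` is imported inside the function rather than at module scope, which satisfies F821 at the call site while keeping `safetensors` off the module's import path until the conversion actually runs. A minimal sketch of the deferred-import pattern:

```python
import torch


def convert_to_safetensors(tensors: dict, out_path: str) -> None:
    # Deferred import: only pay for (and require) safetensors when this
    # code path executes.
    from safetensors.torch import save_file

    save_file(tensors, out_path, metadata={"format": "pt"})


# Usage:
# convert_to_safetensors({"w": torch.zeros(2, 2)}, "weights.safetensors")
```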
@@ -25,6 +25,7 @@ from typing import Any, Dict, Iterable, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
+import tqdm
 from torch import nn
 from transformers import PretrainedConfig
@@ -3499,7 +3500,7 @@ class DeepseekV2ForCausalLM(nn.Module):
         # temporarily only support DeepSeek V3/R1
         weight_block_size = [128, 128]
-        for layer_id in trange(
+        for layer_id in tqdm.trange(
             self.config.num_hidden_layers + int(is_nextn),
             desc="quant attn to fp8 ue8m0",
         ):
@@ -9,6 +9,7 @@ from transformers.models.glm4v.configuration_glm4v import Glm4vConfig, Glm4vVisi
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.attention import vision_utils
+from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
@@ -13,6 +13,7 @@
 # ==============================================================================
 """Inference-only OPT model compatible with HuggingFace weights."""
+import logging
 from collections.abc import Iterable
 from typing import Optional, Union
@@ -46,6 +47,9 @@ from sglang.srt.model_loader.weight_utils import (
     kv_cache_scales_loader,
 )
 from sglang.srt.utils import add_prefix, make_layers
+from sglang.utils import get_exception_traceback
+
+logger = logging.getLogger(__name__)
 
 def get_activation(name="relu"):
@@ -42,6 +42,7 @@ import tempfile
 import threading
 import time
 import traceback
+import types
 import uuid
 import warnings
 from collections import OrderedDict, defaultdict
@@ -55,6 +56,7 @@ from json import JSONDecodeError
 from multiprocessing.reduction import ForkingPickler
 from pathlib import Path
 from typing import (
+    TYPE_CHECKING,
     Any,
     Callable,
     Dict,
@@ -62,6 +64,7 @@ from typing import (
     List,
     Optional,
     Protocol,
+    Sequence,
     Set,
     Tuple,
     TypeVar,
@@ -91,6 +94,9 @@ from typing_extensions import Literal
 from sglang.srt.environ import envs
 from sglang.srt.metrics.func_timer import enable_func_timer
 
+if TYPE_CHECKING:
+    from sglang.srt.layers.quantization.base_config import QuantizeMethodBase
+
 logger = logging.getLogger(__name__)
 
 show_time_cost = False
@@ -1076,7 +1082,7 @@ def monkey_patch_vllm_gguf_config():
     def get_quant_method_with_embedding_replaced(
         self, layer: torch.nn.Module, prefix: str
-    ) -> Optional["QuantizeMethodBase"]:
+    ) -> Optional[QuantizeMethodBase]:
         if isinstance(layer, LinearBase):
             return GGUFLinearMethod(self)
         elif isinstance(layer, VocabParallelEmbedding):
@@ -1956,7 +1962,9 @@ def direct_register_custom_op(
         if fake_impl is not None:
             my_lib._register_fake(op_name, fake_impl)
     except RuntimeError as error:
-        if "Tried to register an operator" in str(e) and "multiple times" in str(e):
+        if "Tried to register an operator" in str(error) and "multiple times" in str(
+            error
+        ):
             # Silently ignore duplicate registration errors
             # This can happen in multi-engine scenarios
             pass
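The last hunk is the most concrete payoff of enabling F821: the handler binds the exception as `error` but the body referenced `e`, so the "silently ignore" branch would itself raise `NameError`. A minimal reproduction of the shape of the bug and its fix (hypothetical function):

```python
def register_op(register):
    try:
        register()
    except RuntimeError as error:
        # Bug shape: referencing `e` here would raise NameError instead of
        # inspecting the caught exception -- exactly what F821 reports.
        if "multiple times" in str(error):  # fixed: use the bound name
            pass  # duplicate registration is expected and ignored
        else:
            raise


# Usage: duplicate registration is swallowed, other errors propagate.
def _dup():
    raise RuntimeError("operator registered multiple times")


register_op(_dup)
```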
@@ -3,6 +3,7 @@ import ast
 import asyncio
 import re
 import time
+from typing import Optional
 
 import numpy as np