"docs/vscode:/vscode.git/clone" did not exist on "7c052cead232cba0b321459b988c776f801ee163"
Unverified Commit 9d61205d authored by Liangsheng Yin, committed by GitHub

[lint] improve ruff check (#11922)


Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
parent 590bc4b7
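For context on the rules this commit wires into the pre-commit ruff hook: F401 flags imports that are never used, and F821 flags names that are referenced but never defined. A minimal, made-up sketch (not part of this diff) of the kind of code each rule catches:

# Illustrative only, not from this repository.
import os                      # F401: `os` is imported but never used; `--fix` removes it automatically


def greet() -> str:
    return f"hello {name}"     # F821: `name` is undefined; ruff reports it but does not auto-fix it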
......@@ -27,7 +27,9 @@ repos:
rev: v0.11.7
hooks:
- id: ruff
args: [--select=F401,F821, --fixable=F401]
args:
- --select=F401,F821
- --fix
files: ^(benchmark/|docs/|examples/|python/sglang/)
exclude: __init__\.py$|\.ipynb$|^python/sglang/srt/grpc/.*_pb2\.py$|^python/sglang/srt/grpc/.*_pb2_grpc\.py$|^python/sglang/srt/grpc/.*_pb2\.pyi$|^python/sglang/srt/grpc/.*_pb2_grpc\.pyi$
- repo: https://github.com/psf/black
......
......@@ -167,6 +167,7 @@ class MiniMaxText01LightningAttention(nn.Module):
past_key_value: Optional[Tuple[torch.Tensor]] = None,
use_cache: bool = False,
slope_rate: Optional[torch.Tensor] = None,
do_eval: bool = False,
**kwargs,
):
if (not self.training) and (not do_eval):
......
import itertools
import logging
import math
import os
from typing import Optional, Tuple
......@@ -10,6 +11,8 @@ import triton
import triton.language as tl
from einops import rearrange
logger = logging.getLogger(__name__)
# Adapted from https://github.com/OpenNLPLab/lightning-attention/blob/main/lightning_attn/ops/triton/lightning_attn2.py
@triton.jit
......@@ -302,6 +305,7 @@ class MiniMaxText01LightningAttention(nn.Module):
past_key_value: Optional[Tuple[torch.Tensor]] = None,
use_cache: bool = False,
slope_rate: Optional[torch.Tensor] = None,
do_eval: bool = False,
**kwargs,
):
if (not self.training) and (not do_eval):
......
......@@ -16,6 +16,7 @@ import argparse
import dataclasses
import itertools
import json
import logging
import multiprocessing
import os
import random
......@@ -39,6 +40,8 @@ from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import is_blackwell, kill_process_tree
from sglang.test.test_utils import is_in_ci, write_github_step_summary
logger = logging.getLogger(__name__)
class ProfileLinks(BaseModel):
"""Pydantic model for profile trace links."""
......
......@@ -77,8 +77,8 @@ class CommonKVManager(BaseKVManager):
if self.disaggregation_mode == DisaggregationMode.PREFILL:
self._register_to_bootstrap()
self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {}
self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {}
self.transfer_infos = {}
self.decode_kv_args_table = {}
self.pp_group = get_pp_group()
elif self.disaggregation_mode == DisaggregationMode.DECODE:
self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
......
......@@ -9,7 +9,7 @@ import struct
import threading
import time
from collections import defaultdict
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Set, Tuple
import numpy as np
import numpy.typing as npt
......
......@@ -14,6 +14,7 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Optional,
import jinja2
import openai.types.responses as openai_responses_types
import orjson
from fastapi import Request
from fastapi.responses import ORJSONResponse
from openai.types.responses import (
......@@ -1063,7 +1064,7 @@ class OpenAIServingResponses(OpenAIServingChat):
):
function_name = previous_item.recipient[len("browser.") :]
action = None
parsed_args = ororjson.loads(previous_item.content[0].text)
parsed_args = orjson.loads(previous_item.content[0].text)
if function_name == "search":
action = openai_responses_types.response_function_web_search.ActionSearch(
type="search",
......
......@@ -194,7 +194,7 @@ class FlashInferAttnBackend(AttentionBackend):
)
if init_new_workspace:
self.workspace_buffer = torch.empty(
global_config.flashinfer_workspace_size,
envs.SGLANG_FLASHINFER_WORKSPACE_SIZE.get(),
dtype=torch.uint8,
device=model_runner.device,
)
......
......@@ -38,6 +38,9 @@ from sglang.srt.utils import (
)
if TYPE_CHECKING:
from sglang.srt.layers.attention.flashinfer_mla_backend import (
FlashInferMlaAttnBackend,
)
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.speculative.spec_info import SpecInput
......@@ -66,7 +69,7 @@ global_workspace_buffer = None
class FlashInferMhaChunkKVRunner:
def __init__(
self, model_runner: ModelRunner, attn_backend: "FlashInferMlaAttnBackend"
self, model_runner: ModelRunner, attn_backend: FlashInferMlaAttnBackend
):
# Parse Constants
self.num_local_heads = (
......
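The hunk above (and similar ones later in this diff) drops the quotes around types that are only imported under TYPE_CHECKING. A minimal sketch of that idiom, assuming the module uses postponed annotation evaluation (PEP 563); the function below is made up for illustration:

from __future__ import annotations  # annotations are stored as strings, never evaluated at runtime

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Imported for type checkers only; no runtime import cost or import cycle.
    from sglang.srt.layers.radix_attention import RadixAttention


def describe_layer(layer: RadixAttention) -> str:
    # The unquoted annotation is safe here because it is never resolved at runtime.
    return type(layer).__name__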
......@@ -13,7 +13,8 @@ from triton_kernels.matmul_ogs import (
PrecisionConfig,
matmul_ogs,
)
from triton_kernels.numerics import InFlexData
from triton_kernels.numerics import InFlexData, MicroscalingCtx
from triton_kernels.quantization import downcast_to_mxfp
from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx
from triton_kernels.swiglu import swiglu_fn
......@@ -119,14 +120,14 @@ def triton_kernel_fused_experts(
block_shape: Optional[list[int]] = None,
) -> torch.Tensor:
assert use_fp8_w8a8 == False, "use_fp8_w8a8 is not supported"
assert per_channel_quant == False, "per_channel_quant is not supported"
assert expert_map == None, "expert_map is not supported"
assert w1_scale == None, "w1_scale is not supported"
assert w2_scale == None, "w2_scale is not supported"
assert a1_scale == None, "a1_scale is not supported"
assert a2_scale == None, "a2_scale is not supported"
assert block_shape == None, "block_shape is not supported"
assert use_fp8_w8a8 is False, "use_fp8_w8a8 is not supported"
assert per_channel_quant is False, "per_channel_quant is not supported"
assert expert_map is None, "expert_map is not supported"
assert w1_scale is None, "w1_scale is not supported"
assert w2_scale is None, "w2_scale is not supported"
assert a1_scale is None, "a1_scale is not supported"
assert a2_scale is None, "a2_scale is not supported"
assert block_shape is None, "block_shape is not supported"
# type check
assert hidden_states.dtype == torch.bfloat16, "hidden_states must be bfloat16"
......@@ -143,7 +144,7 @@ def triton_kernel_fused_experts(
), f"w2 shape[-1] {w2.shape[-1]} must be equal to w1 shape[1] {w1.shape[1]}"
# feature check
assert inplace == False, "Inplace is not supported in new triton MoE kernel"
assert inplace is False, "Inplace is not supported in new triton MoE kernel"
M, K = hidden_states.shape
E, _, N = w1.shape
......@@ -264,14 +265,14 @@ def triton_kernel_fused_experts_with_bias(
gemm1_alpha: Optional[float] = None,
gemm1_clamp_limit: Optional[float] = None,
) -> torch.Tensor:
assert use_fp8_w8a8 == False, "use_fp8_w8a8 is not supported"
assert per_channel_quant == False, "per_channel_quant is not supported"
assert expert_map == None, "expert_map is not supported"
assert w1_scale == None, "w1_scale is not supported"
assert w2_scale == None, "w2_scale is not supported"
assert a1_scale == None, "a1_scale is not supported"
assert a2_scale == None, "a2_scale is not supported"
assert block_shape == None, "block_shape is not supported"
assert use_fp8_w8a8 is False, "use_fp8_w8a8 is not supported"
assert per_channel_quant is False, "per_channel_quant is not supported"
assert expert_map is None, "expert_map is not supported"
assert w1_scale is None, "w1_scale is not supported"
assert w2_scale is None, "w2_scale is not supported"
assert a1_scale is None, "a1_scale is not supported"
assert a2_scale is None, "a2_scale is not supported"
assert block_shape is None, "block_shape is not supported"
# type check
assert hidden_states.dtype == torch.bfloat16, "hidden_states must be bfloat16"
......@@ -290,7 +291,7 @@ def triton_kernel_fused_experts_with_bias(
), f"w2 shape[-1] {w2.shape[-1]} must be equal to w1 shape[1] {w1.shape[1]}"
# feature check
assert inplace == False, "Inplace is not supported in new triton MoE kernel"
assert inplace is False, "Inplace is not supported in new triton MoE kernel"
E, _, _ = w1.shape
......
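The assert rewrites above replace `== None` / `== False` with `is None` / `is False`. A short sketch of why identity comparison is preferred for singletons; the class below is made up for illustration:

class AlwaysEqual:
    """Toy class whose __eq__ always returns True."""

    def __eq__(self, other):
        return True


obj = AlwaysEqual()
print(obj == None)   # True  -- an overridden __eq__ gives a misleading answer
print(obj is None)   # False -- identity check against the None singleton is unambiguous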
......@@ -44,6 +44,13 @@ from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
try:
from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_24 import (
CompressedTensors24,
)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_w4a16_sparse24 import (
W4A16SPARSE24_SUPPORTED_BITS,
CompressedTensorsW4A16Sparse24,
)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import (
WNA16_SUPPORTED_BITS,
CompressedTensorsWNA16,
......
......@@ -2,7 +2,7 @@
from __future__ import annotations
import logging
from typing import Any, Optional
from typing import Any, List, Optional
import torch
from torch.nn import Module
......
......@@ -47,6 +47,8 @@ from sglang.srt.utils import get_bool_env_var, is_cuda, is_npu, next_power_of_2
if TYPE_CHECKING:
from sglang.srt.managers.cache_controller import LayerDoneCounter
from sglang.srt.managers.schedule_batch import Req
logger = logging.getLogger(__name__)
......@@ -341,7 +343,7 @@ class HybridReqToTokenPool(ReqToTokenPool):
# For chunk prefill req, we do not need to allocate mamba cache,
# We could use allocated mamba cache instead.
def alloc(
self, need_size: int, reqs: Optional[List["Req"]] = None
self, need_size: int, reqs: Optional[List[Req]] = None
) -> Optional[List[int]]:
select_index = super().alloc(need_size)
if select_index == None:
......
......@@ -110,6 +110,9 @@ def convert_bin_to_safetensor_file(
dirname = os.path.dirname(sf_filename)
os.makedirs(dirname, exist_ok=True)
from safetensors.torch import save_file
save_file(loaded, sf_filename, metadata={"format": "pt"})
# check file size
......
......@@ -25,6 +25,7 @@ from typing import Any, Dict, Iterable, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import tqdm
from torch import nn
from transformers import PretrainedConfig
......@@ -3499,7 +3500,7 @@ class DeepseekV2ForCausalLM(nn.Module):
# temporarily only support DeepSeek V3/R1
weight_block_size = [128, 128]
for layer_id in trange(
for layer_id in tqdm.trange(
self.config.num_hidden_layers + int(is_nextn),
desc="quant attn to fp8 ue8m0",
):
......
......@@ -9,6 +9,7 @@ from transformers.models.glm4v.configuration_glm4v import Glm4vConfig, Glm4vVisi
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.attention import vision_utils
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
ColumnParallelLinear,
......
......@@ -13,6 +13,7 @@
# ==============================================================================
"""Inference-only OPT model compatible with HuggingFace weights."""
import logging
from collections.abc import Iterable
from typing import Optional, Union
......@@ -46,6 +47,9 @@ from sglang.srt.model_loader.weight_utils import (
kv_cache_scales_loader,
)
from sglang.srt.utils import add_prefix, make_layers
from sglang.utils import get_exception_traceback
logger = logging.getLogger(__name__)
def get_activation(name="relu"):
......
......@@ -42,6 +42,7 @@ import tempfile
import threading
import time
import traceback
import types
import uuid
import warnings
from collections import OrderedDict, defaultdict
......@@ -55,6 +56,7 @@ from json import JSONDecodeError
from multiprocessing.reduction import ForkingPickler
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
......@@ -62,6 +64,7 @@ from typing import (
List,
Optional,
Protocol,
Sequence,
Set,
Tuple,
TypeVar,
......@@ -91,6 +94,9 @@ from typing_extensions import Literal
from sglang.srt.environ import envs
from sglang.srt.metrics.func_timer import enable_func_timer
if TYPE_CHECKING:
from sglang.srt.layers.quantization.base_config import QuantizeMethodBase
logger = logging.getLogger(__name__)
show_time_cost = False
......@@ -1076,7 +1082,7 @@ def monkey_patch_vllm_gguf_config():
def get_quant_method_with_embedding_replaced(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
) -> Optional[QuantizeMethodBase]:
if isinstance(layer, LinearBase):
return GGUFLinearMethod(self)
elif isinstance(layer, VocabParallelEmbedding):
......@@ -1956,7 +1962,9 @@ def direct_register_custom_op(
if fake_impl is not None:
my_lib._register_fake(op_name, fake_impl)
except RuntimeError as error:
if "Tried to register an operator" in str(e) and "multiple times" in str(e):
if "Tried to register an operator" in str(error) and "multiple times" in str(
error
):
# Silently ignore duplicate registration errors
# This can happen in multi-engine scenarios
pass
......
......@@ -3,6 +3,7 @@ import ast
import asyncio
import re
import time
from typing import Optional
import numpy as np
......