"src/graph/vscode:/vscode.git/clone" did not exist on "1c9d2a03023c64380d69b24f6e6bd0393417f69d"
Unverified commit 9b8ebb27 authored by Lianmin Zheng, committed by GitHub

move more files under srt/utils (#11285)

parent 758b887a
@@ -292,7 +292,7 @@ jobs:
needs: [check-changes, unit-test-backend-2-gpu, sgl-kernel-build-wheels]
if: always() && !failure() && !cancelled() &&
((needs.check-changes.outputs.main_package == 'true') || (needs.check-changes.outputs.sgl_kernel == 'true'))
-    runs-on: 4-gpu-runner
+    runs-on: 4-gpu-h100
strategy:
fail-fast: false
matrix:
@@ -614,7 +614,7 @@ jobs:
needs: [check-changes, unit-test-backend-2-gpu, sgl-kernel-build-wheels]
if: always() && !failure() && !cancelled() &&
((needs.check-changes.outputs.main_package == 'true') || (needs.check-changes.outputs.sgl_kernel == 'true'))
-    runs-on: 4-gpu-runner
+    runs-on: 4-gpu-h100
steps:
- name: Checkout code
uses: actions/checkout@v4
...
@@ -51,8 +51,8 @@ from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
from sglang.srt.model_executor.forward_batch_info import ForwardMode
-from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.utils import get_int_env_var, require_mlp_sync
+from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
logger = logging.getLogger(__name__)
...
@@ -68,7 +68,6 @@ from sglang.srt.managers.scheduler import run_scheduler_process
from sglang.srt.managers.template_manager import TemplateManager
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.utils import (
MultiprocessingSerializer,
assert_pkg_version,
@@ -82,6 +81,7 @@ from sglang.srt.utils import (
set_prometheus_multiproc_dir,
set_ulimit,
)
+from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.version import __version__
logger = logging.getLogger(__name__)
...
@@ -35,8 +35,8 @@ from sglang.srt.managers.io_struct import (
from sglang.srt.managers.scheduler import run_scheduler_process
from sglang.srt.sampling.sampling_params import SamplingParams as SGLSamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.utils import configure_logger, prepare_model_and_tokenizer
+from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.utils import get_exception_traceback
logger = logging.getLogger(__name__)
...
@@ -12,7 +12,10 @@ from sglang.srt.custom_op import CustomOp
from sglang.srt.utils import add_prefix, align, is_cuda, is_hip, is_npu
if is_cuda():
-    import deep_gemm
+    try:
+        import deep_gemm
+    except ImportError as e:
+        deep_gemm = e
from sglang.srt.layers.attention.nsa.utils import NSA_DUAL_STREAM, NSA_USE_REAL_INDEXER
from sglang.srt.layers.dp_attention import get_attention_tp_group
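The hunk above replaces a hard import with a deferred-failure pattern: on ImportError the exception object is stored in place of the module, so merely importing this file no longer requires deep_gemm, and the failure only surfaces when the module is actually used. A minimal sketch of the pattern (the run_gemm helper and its call are illustrative, not from the diff):

    try:
        import deep_gemm
    except ImportError as e:
        deep_gemm = e  # keep the error; raise it lazily at first use

    def run_gemm(*args, **kwargs):
        if isinstance(deep_gemm, ImportError):
            raise deep_gemm  # fails here, not at import time
        return deep_gemm.gemm_fp8_fp8_bf16_nt(*args, **kwargs)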
...
@@ -30,9 +30,9 @@ from sglang.srt.layers.quantization.modelopt_quant import (
ModelOptNvFp4FusedMoEMethod,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.offloader import get_offloader
from sglang.srt.single_batch_overlap import DownGemmOverlapArgs
from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu
+from sglang.srt.utils.offloader import get_offloader
if TYPE_CHECKING:
from sglang.srt.layers.moe.token_dispatcher import (
...
@@ -11,7 +11,7 @@ _is_hip = is_hip()
@triton.jit
-def fused_moe_router_kernel(
+def fused_moe_router_cudacore_kernel(
input_ptr, # input (bs, hidden_dim)
moe_router_weight_ptr, # input (num_experts, hidden_dim)
topk_weights_ptr, # output (bs, topk)
@@ -114,7 +114,7 @@ def fused_moe_router_kernel(
# assert not moe_renormalize, "moe weight renormalization not implemented"
-def fused_moe_router_impl(
+def fused_moe_router_cudacore(
x: torch.Tensor,
router_weight: torch.Tensor,
topk: int,
@@ -138,7 +138,7 @@ def fused_moe_router_impl(
),
}
-    fused_moe_router_kernel[(bs,)](
+    fused_moe_router_cudacore_kernel[(bs,)](
x,
router_weight,
topk_weights,
@@ -157,7 +157,7 @@ def fused_moe_router_impl(
@triton.jit
-def fused_moe_router_large_bs_kernel(
+def fused_moe_router_tensorcore_kernel(
a_ptr, # input (bs, hidden_dim)
b_ptr, # input (num_experts, hidden_dim)
topk_weights_ptr, # output (bs, topk)
@@ -167,12 +167,15 @@ def fused_moe_router_large_bs_kernel(
topk: tl.constexpr, # only support topk <= 2
moe_softcapping: tl.constexpr,
moe_renormalize: tl.constexpr, # not supported
+    correction_bias_ptr,
+    is_correction_bias: tl.constexpr,
K: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
stride_am: tl.constexpr,
stride_bn: tl.constexpr,
+    dp_attn_workaround_flag: tl.constexpr,
):
# 1. get block id
@@ -217,6 +220,20 @@ def fused_moe_router_large_bs_kernel(
exped = tl.exp(2 * logits_scaled)
logits_softcapped = (exped - 1) / (exped + 1) * moe_softcapping
+    # Add bias after softcapping
+    if is_correction_bias:
+        bias = tl.load(
+            correction_bias_ptr + tl.arange(0, BLOCK_SIZE_N)[None, :],
+            mask=expert_mask.T,
+            other=0.0,
+        )
+        logits_softcapped = logits_softcapped + bias
+
+    if dp_attn_workaround_flag:
+        logits_softcapped = tl.where(
+            logits_softcapped != logits_softcapped, -1e9, logits_softcapped
+        )
# 5. top1
arange_block_size_n = tl.arange(0, BLOCK_SIZE_N)[None, :]
cond_top1 = arange_block_size_n < num_experts
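Two details in this new block are easy to miss. First, the surrounding softcapping code uses the identity tanh(x) = (e^(2x) - 1) / (e^(2x) + 1), which is what the exp-based expression computes. Second, the dp-attention workaround exploits the fact that NaN is the only float that compares unequal to itself, so `logits_softcapped != logits_softcapped` selects exactly the NaN entries produced by padded tokens. A small PyTorch check of both (a sketch, not part of the diff):

    import torch

    x, cap = torch.randn(8) * 10, 30.0

    # Softcapping via the exp identity used by the kernel.
    exped = torch.exp(2 * x / cap)
    softcapped = (exped - 1) / (exped + 1) * cap
    assert torch.allclose(softcapped, torch.tanh(x / cap) * cap, atol=1e-4)

    # NaN masking: NaN != NaN is True, so NaN logits become -1e9.
    logits = torch.tensor([0.5, float("nan"), -1.2])
    masked = torch.where(logits != logits, torch.full_like(logits, -1e9), logits)
    assert masked[1].item() == -1e9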
@@ -266,7 +283,7 @@ def fused_moe_router_large_bs_kernel(
)
-def fused_moe_router_large_bs_impl(
+def fused_moe_router_tensorcore(
x: torch.Tensor,
router_weight: torch.Tensor,
topk: int,
@@ -274,6 +291,7 @@ def fused_moe_router_large_bs_impl(
BLOCK_SIZE_M: int,
BLOCK_SIZE_N: int,
BLOCK_SIZE_K: int,
+    correction_bias: Optional[torch.Tensor] = None,
):
assert len(x.shape) == 2 and x.shape[1] == router_weight.shape[1]
bs, hidden_dim = x.shape
@@ -285,10 +303,17 @@ def fused_moe_router_large_bs_impl(
topk_weights = torch.empty((bs, topk), dtype=torch.float32, device=x.device)
topk_ids = torch.empty((bs, topk), dtype=torch.int32, device=x.device)
+    is_correction_bias = correction_bias is not None
grid = (triton.cdiv(bs, BLOCK_SIZE_M) * triton.cdiv(num_experts, BLOCK_SIZE_N),)
-    fused_moe_router_large_bs_kernel[grid](
+    # TODO(ch-wan): temporary workaround for dp attention. We should support a
+    # masked router to skip padded tokens.
+    from sglang.srt.layers.dp_attention import is_dp_attention_enabled
+
+    dp_attn_workaround_flag = is_dp_attention_enabled()
+    fused_moe_router_tensorcore_kernel[grid](
a_ptr=x,
b_ptr=router_weight,
topk_weights_ptr=topk_weights,
@@ -299,11 +324,14 @@ def fused_moe_router_large_bs_impl(
moe_softcapping=moe_softcapping,
moe_renormalize=False,
K=hidden_dim,
+        correction_bias_ptr=correction_bias,
+        is_correction_bias=is_correction_bias,
BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_N=BLOCK_SIZE_N,
BLOCK_SIZE_K=BLOCK_SIZE_K,
stride_am=hidden_dim,
stride_bn=hidden_dim,
+        dp_attn_workaround_flag=dp_attn_workaround_flag,
)
return topk_weights, topk_ids
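The launch grid above is a 2D tiling flattened to one dimension: triton.cdiv(bs, BLOCK_SIZE_M) * triton.cdiv(num_experts, BLOCK_SIZE_N) programs in total, so the kernel must recover its (token-block, expert-block) coordinates from a single program id. The usual decomposition, shown in plain Python (a sketch; the kernel's own decomposition is not visible in this hunk):

    bs, num_experts = 1024, 64
    BLOCK_SIZE_M, BLOCK_SIZE_N = 32, 16

    num_pid_m = (bs + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M            # cdiv
    num_pid_n = (num_experts + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N   # cdiv

    for pid in range(num_pid_m * num_pid_n):
        pid_m = pid // num_pid_n  # which block of token rows
        pid_n = pid % num_pid_n   # which block of expert columns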
@@ -316,6 +344,7 @@ def fused_moe_router_shim(
topk,
renormalize,
correction_bias: Optional[torch.Tensor] = None,
+    enable_deterministic_inference: bool = False,
):
assert not renormalize
assert (
@@ -324,16 +353,22 @@ def fused_moe_router_shim(
)
bs, hidden_dim = hidden_states.shape
num_experts = gating_output.shape[0]
BLOCK_SIZE_M = 32
-    BLOCK_SIZE_N = 16
-    BLOCK_SIZE_K = 256
+    BLOCK_SIZE_N = max(num_experts, 16)
+    BLOCK_SIZE_K = (
+        256 if num_experts < 256 else 64
+    )  # with many experts, a smaller K block is needed or shared memory OOMs
if (
-        bs >= 512
-        and topk <= 2
-        and num_experts <= BLOCK_SIZE_N
+        (bs >= 512 or num_experts > 8)
         and hidden_dim % BLOCK_SIZE_K == 0
+        # keep using a single kernel to avoid non-deterministic behavior
+        and not enable_deterministic_inference
):
-        return fused_moe_router_large_bs_impl(
+        # for large batch sizes or many experts, use the kernel that runs its matmul on tensor cores
+        return fused_moe_router_tensorcore(
x=hidden_states,
router_weight=gating_output,
topk=topk,
......@@ -341,9 +376,11 @@ def fused_moe_router_shim(
BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_N=BLOCK_SIZE_N,
BLOCK_SIZE_K=BLOCK_SIZE_K,
+            correction_bias=correction_bias,
)
else:
-        return fused_moe_router_impl(
+        # otherwise, use the kernel that does not run its matmul on tensor cores
+        return fused_moe_router_cudacore(
x=hidden_states,
router_weight=gating_output,
topk=topk,
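After this change the choice between the two kernels reduces to one predicate: go to the tensor-core kernel for large batches or more than 8 experts, but only when the hidden dimension tiles evenly by BLOCK_SIZE_K and deterministic inference is not requested (the single cudacore kernel avoids non-deterministic reduction order). Restated as a standalone helper for clarity (an illustrative sketch, not code from the diff):

    def use_tensorcore_router(
        bs: int, num_experts: int, hidden_dim: int,
        block_size_k: int, deterministic: bool,
    ) -> bool:
        return (
            (bs >= 512 or num_experts > 8)
            and hidden_dim % block_size_k == 0
            and not deterministic
        )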
@@ -380,11 +417,10 @@ class FusedMoeRouter:
renormalize=False,
)
-    def forward_vllm(
+    def forward_torch(
self,
x: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
-        # g, _ = self.router_linear.forward(x)
g = x.float() @ self.router_linear.weight.T.float()
g = torch.tanh(g.float() / self.moe_softcapping) * self.moe_softcapping
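The renamed forward_torch spells out the router's reference semantics in plain PyTorch. Extending that reference with the correction bias added to the tensor-core kernel (applied after softcapping, before top-k selection) might look like the sketch below; the top-k handling is an illustrative assumption, and whether the returned weights should include the bias is not visible in this hunk:

    from typing import Optional, Tuple
    import torch

    def router_reference(
        x: torch.Tensor,           # (bs, hidden_dim)
        weight: torch.Tensor,      # (num_experts, hidden_dim)
        topk: int,
        softcapping: float,
        correction_bias: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        logits = x.float() @ weight.T.float()
        logits = torch.tanh(logits / softcapping) * softcapping
        if correction_bias is not None:
            logits = logits + correction_bias  # bias after softcapping
        topk_weights, topk_ids = torch.topk(logits, topk, dim=-1)
        return topk_weights, topk_ids.to(torch.int32)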
...
@@ -2,11 +2,10 @@ from typing import Callable, List, Optional, Tuple
import torch
-from sglang.srt import offloader
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
from sglang.srt.layers.quantization.mxfp4_tensor import MXFP4QuantizeUtil
-from sglang.srt.utils import is_sm100_supported
+from sglang.srt.utils import is_sm100_supported, offloader
try:
from vllm import _custom_ops as ops
@@ -29,7 +28,6 @@ from sglang.srt.layers.quantization.fp8_kernel import (
)
from sglang.srt.utils import (
align,
-    ceil_div,
get_bool_env_var,
get_cuda_version,
get_device_capability,
...
@@ -18,8 +18,8 @@ from dataclasses import dataclass, field, fields
from typing import Dict, List, Optional, Union
from uuid import uuid4
-from sglang.srt.aio_rwlock import RWLock
from sglang.srt.utils import ConcurrentCounter
+from sglang.srt.utils.aio_rwlock import RWLock
@dataclass(frozen=True)
...
@@ -37,13 +37,13 @@ from sglang.srt.managers.io_struct import (
from sglang.srt.managers.schedule_batch import Req
from sglang.srt.managers.scheduler import run_scheduler_process
from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.utils import (
bind_port,
configure_logger,
get_zmq_socket,
kill_itself_when_parent_died,
)
+from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.utils import TypeBasedDispatcher, get_exception_traceback
logger = logging.getLogger(__name__)
...
@@ -36,7 +36,6 @@ else:
Image = Any
-# Parameters for a session
@dataclass
class BaseReq(ABC):
rid: Optional[Union[str, List[str]]] = field(default=None, kw_only=True)
@@ -60,9 +59,11 @@ class BaseBatchReq(ABC):
return self.rids
+# Parameters for a session
@dataclass
class SessionParams:
id: Optional[str] = None
rid: Optional[str] = None
offset: Optional[int] = None
replace: Optional[bool] = None
drop_previous_output: Optional[bool] = None
...
@@ -156,7 +156,6 @@ from sglang.srt.model_executor.forward_batch_info import (
from sglang.srt.parser.reasoning_parser import ReasoningParser
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.tracing.trace import (
process_tracing_init,
trace_set_proc_propagate_context,
@@ -192,6 +191,7 @@ from sglang.srt.utils.hf_transformers_utils import (
get_tokenizer,
get_tokenizer_from_processor,
)
+from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.utils import TypeBasedDispatcher, get_exception_traceback
logger = logging.getLogger(__name__)
...
@@ -40,7 +40,6 @@ import zmq
import zmq.asyncio
from fastapi import BackgroundTasks
-from sglang.srt.aio_rwlock import RWLock
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.lora.lora_registry import LoRARegistry
@@ -94,6 +93,7 @@ from sglang.srt.utils import (
get_zmq_socket,
kill_process_tree,
)
+from sglang.srt.utils.aio_rwlock import RWLock
from sglang.srt.utils.hf_transformers_utils import (
get_processor,
get_tokenizer,
...
@@ -20,7 +20,7 @@ from dataclasses import dataclass
from sglang.srt.configs.mamba_utils import Mamba2CacheParams
from sglang.srt.layers.attention.nsa import index_buf_accessor
from sglang.srt.layers.attention.nsa.quant_k_cache import quantize_k_cache
-from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
+from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
"""
Memory pool.
...
@@ -117,15 +117,9 @@ from sglang.srt.model_loader.remote_instance_weight_loader_utils import (
)
from sglang.srt.model_loader.utils import set_default_torch_dtype
from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.offloader import (
-    create_offloader_from_server_args,
-    get_offloader,
-    set_offloader,
-)
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
-from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.utils import (
MultiprocessingSerializer,
cpu_has_amx_support,
@@ -148,7 +142,13 @@ from sglang.srt.utils import (
set_cuda_arch,
slow_rank_detector,
)
+from sglang.srt.utils.offloader import (
+    create_offloader_from_server_args,
+    get_offloader,
+    set_offloader,
+)
 from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions
+from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.weight_sync.tensor_bucket import (
FlattenedTensorBucket,
FlattenedTensorMetadata,
...
@@ -43,10 +43,8 @@
import copy
import logging
import math
-from collections.abc import Mapping
from dataclasses import dataclass
-from typing import Any, Iterable, List, Optional, Tuple
+from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
@@ -56,10 +54,6 @@ from sglang.srt.configs import KimiVLConfig
from sglang.srt.configs.deepseekvl2 import DeepseekV2Config
from sglang.srt.configs.kimi_vl import KimiVLConfig
from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig
-from sglang.srt.distributed import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)
from sglang.srt.layers.activation import QuickGELU
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.quantization.base_config import QuantizationConfig
...
@@ -49,7 +49,7 @@ from typing import List, Optional, Sequence, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
-from transformers.activations import ACT2FN, GELUTanh
+from transformers.activations import ACT2FN
from transformers.modeling_utils import PreTrainedModel
try:
@@ -596,6 +596,8 @@ class MoonVitPretrainedModel(PreTrainedModel):
_supports_sdpa = True
def __init__(self, config: MoonViTConfig, *inputs, **kwargs):
+        from transformers.activations import GELUTanh
super().__init__(config, *inputs, **kwargs)
config = deepcopy(config)
self.merge_kernel_size = config.merge_kernel_size
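Deferring the GELUTanh import into __init__ means the module can still be imported on transformers versions that do not expose GELUTanh; the dependency is only resolved when a model is actually instantiated. A version-tolerant variant could fall back to torch's tanh-approximated GELU, assuming (not verified here) that the two activations are equivalent:

    import torch.nn as nn

    def make_gelu_tanh() -> nn.Module:
        try:
            from transformers.activations import GELUTanh
            return GELUTanh()
        except ImportError:
            # Assumption: GELUTanh matches torch's tanh-approximated GELU.
            return nn.GELU(approximate="tanh")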
...
@@ -238,6 +238,7 @@ class ServerArgs:
log_requests: bool = False
log_requests_level: int = 2
crash_dump_folder: Optional[str] = None
+    crash_on_nan: bool = False
show_time_cost: bool = False
enable_metrics: bool = False
enable_metrics_for_all_schedulers: bool = False
@@ -1733,6 +1734,12 @@ class ServerArgs:
default=ServerArgs.crash_dump_folder,
help="Folder path to dump requests from the last 5 min before a crash (if any). If not specified, crash dumping is disabled.",
)
+        parser.add_argument(
+            "--crash-on-nan",
+            action="store_true",
+            default=ServerArgs.crash_on_nan,
+            help="Crash the server on NaN logprobs.",
+        )
parser.add_argument(
"--show-time-cost",
action="store_true",
...
@@ -133,9 +133,9 @@ class TiktokenTokenizer:
)
return self.encode(ret) if tokenize else ret
-    def __call__(self, text, **kwargs):
+    def __call__(self, text: List[str], **kwargs):
return {
"input_ids": self.encode(text),
"input_ids": [self.encode(x) for x in text],
}
def init_xgrammar(self):
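Note that the __call__ change is breaking: the tokenizer now expects a batch (List[str]) and returns one id list per input, matching the Hugging Face convention of returning a list of sequences. A before/after sketch (tokenizer construction elided):

    # Before: tokenizer("hello world")      -> {"input_ids": [id0, id1, ...]}
    # After:  tokenizer(["hello", "world"]) -> {"input_ids": [[...], [...]]}
    out = tokenizer(["hello world", "foo bar"])
    assert len(out["input_ids"]) == 2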
...
# Temporarily do this to avoid changing all imports in the repo
-from .common import *
+from sglang.srt.utils.common import *
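The star re-export keeps every old flat import path working while the implementation moves to sglang.srt.utils.common, so both spellings below resolve to the same object until call sites are migrated (get_bool_env_var is one of the utilities seen elsewhere in this diff):

    from sglang.srt.utils import get_bool_env_var
    from sglang.srt.utils.common import get_bool_env_var as canonical

    assert get_bool_env_var is canonical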