Unverified Commit 9b8ebb27, authored by Lianmin Zheng and committed by GitHub

move more files under srt/utils (#11285)

parent 758b887a
@@ -292,7 +292,7 @@ jobs:
     needs: [check-changes, unit-test-backend-2-gpu, sgl-kernel-build-wheels]
     if: always() && !failure() && !cancelled() &&
       ((needs.check-changes.outputs.main_package == 'true') || (needs.check-changes.outputs.sgl_kernel == 'true'))
-    runs-on: 4-gpu-runner
+    runs-on: 4-gpu-h100
     strategy:
       fail-fast: false
       matrix:
@@ -614,7 +614,7 @@ jobs:
     needs: [check-changes, unit-test-backend-2-gpu, sgl-kernel-build-wheels]
     if: always() && !failure() && !cancelled() &&
       ((needs.check-changes.outputs.main_package == 'true') || (needs.check-changes.outputs.sgl_kernel == 'true'))
-    runs-on: 4-gpu-runner
+    runs-on: 4-gpu-h100
     steps:
       - name: Checkout code
         uses: actions/checkout@v4
...
@@ -51,8 +51,8 @@ from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
-from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import get_int_env_var, require_mlp_sync
+from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter

 logger = logging.getLogger(__name__)
...
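The pattern repeated throughout this commit is a pure module move: `torch_memory_saver_adapter` (and, further down, `offloader` and `aio_rwlock`) now live under `sglang.srt.utils`, and only the import path changes at each call site. A minimal sketch of updating a caller; both paths are taken from the diff, while the compatibility fallback is an illustrative assumption, not part of the commit:

```python
# Import from the new location under sglang.srt.utils; optionally fall back
# to the old top-level path for checkouts that predate this refactor.
try:
    from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
except ImportError:
    from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
```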
@@ -68,7 +68,6 @@ from sglang.srt.managers.scheduler import run_scheduler_process
 from sglang.srt.managers.template_manager import TemplateManager
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import (
     MultiprocessingSerializer,
     assert_pkg_version,
@@ -82,6 +81,7 @@ from sglang.srt.utils import (
     set_prometheus_multiproc_dir,
     set_ulimit,
 )
+from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.version import __version__

 logger = logging.getLogger(__name__)
...
@@ -35,8 +35,8 @@ from sglang.srt.managers.io_struct import (
 from sglang.srt.managers.scheduler import run_scheduler_process
 from sglang.srt.sampling.sampling_params import SamplingParams as SGLSamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import configure_logger, prepare_model_and_tokenizer
+from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.utils import get_exception_traceback

 logger = logging.getLogger(__name__)
...
@@ -12,7 +12,10 @@ from sglang.srt.custom_op import CustomOp
 from sglang.srt.utils import add_prefix, align, is_cuda, is_hip, is_npu

 if is_cuda():
-    import deep_gemm
+    try:
+        import deep_gemm
+    except ImportError as e:
+        deep_gemm = e

 from sglang.srt.layers.attention.nsa.utils import NSA_DUAL_STREAM, NSA_USE_REAL_INDEXER
 from sglang.srt.layers.dp_attention import get_attention_tp_group
...
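The replacement stores the `ImportError` in place of the module, so importing this file succeeds on machines without DeepGEMM installed and the failure only surfaces when `deep_gemm` is actually used. A minimal sketch of the pattern; the `_ensure_deep_gemm` helper is hypothetical, not from the diff:

```python
try:
    import deep_gemm  # optional dependency; may be absent
except ImportError as e:
    deep_gemm = e  # defer the failure until first use

def _ensure_deep_gemm():
    # Hypothetical helper: re-raise the captured ImportError lazily, at call time.
    if isinstance(deep_gemm, ImportError):
        raise deep_gemm
    return deep_gemm
```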
@@ -30,9 +30,9 @@ from sglang.srt.layers.quantization.modelopt_quant import (
     ModelOptNvFp4FusedMoEMethod,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.offloader import get_offloader
 from sglang.srt.single_batch_overlap import DownGemmOverlapArgs
 from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu
+from sglang.srt.utils.offloader import get_offloader

 if TYPE_CHECKING:
     from sglang.srt.layers.moe.token_dispatcher import (
...
@@ -11,7 +11,7 @@ _is_hip = is_hip()

 @triton.jit
-def fused_moe_router_kernel(
+def fused_moe_router_cudacore_kernel(
     input_ptr,  # input (bs, hidden_dim)
     moe_router_weight_ptr,  # input (num_experts, hidden_dim)
     topk_weights_ptr,  # output (bs, topk)
@@ -114,7 +114,7 @@ def fused_moe_router_kernel(
     # assert not moe_renormalize, "moe weight renormalization not implemented"

-def fused_moe_router_impl(
+def fused_moe_router_cudacore(
     x: torch.Tensor,
     router_weight: torch.Tensor,
     topk: int,
@@ -138,7 +138,7 @@ def fused_moe_router_impl(
         ),
     }

-    fused_moe_router_kernel[(bs,)](
+    fused_moe_router_cudacore_kernel[(bs,)](
         x,
         router_weight,
         topk_weights,
@@ -157,7 +157,7 @@ def fused_moe_router_impl(

 @triton.jit
-def fused_moe_router_large_bs_kernel(
+def fused_moe_router_tensorcore_kernel(
     a_ptr,  # input (bs, hidden_dim)
     b_ptr,  # input (num_experts, hidden_dim)
     topk_weights_ptr,  # output (bs, topk)
@@ -167,12 +167,15 @@ def fused_moe_router_large_bs_kernel(
     topk: tl.constexpr,  # only support topk <= 2
     moe_softcapping: tl.constexpr,
     moe_renormalize: tl.constexpr,  # not supported
+    correction_bias_ptr,
+    is_correction_bias: tl.constexpr,
     K: tl.constexpr,
     BLOCK_SIZE_M: tl.constexpr,
     BLOCK_SIZE_N: tl.constexpr,
     BLOCK_SIZE_K: tl.constexpr,
     stride_am: tl.constexpr,
     stride_bn: tl.constexpr,
+    dp_attn_workaround_flag: tl.constexpr,
 ):
     # 1. get block id
@@ -217,6 +220,20 @@ def fused_moe_router_large_bs_kernel(
     exped = tl.exp(2 * logits_scaled)
     logits_softcapped = (exped - 1) / (exped + 1) * moe_softcapping

+    # Add bias after softcapping
+    if is_correction_bias:
+        bias = tl.load(
+            correction_bias_ptr + tl.arange(0, BLOCK_SIZE_N)[None, :],
+            mask=expert_mask.T,
+            other=0.0,
+        )
+        logits_softcapped = logits_softcapped + bias
+
+    if dp_attn_workaround_flag:
+        logits_softcapped = tl.where(
+            logits_softcapped != logits_softcapped, -1e9, logits_softcapped
+        )
+
     # 5. top1
     arange_block_size_n = tl.arange(0, BLOCK_SIZE_N)[None, :]
     cond_top1 = arange_block_size_n < num_experts
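The `dp_attn_workaround_flag` branch leans on the IEEE-754 rule that NaN compares unequal to itself: `x != x` is true exactly on NaN lanes, which the kernel overwrites with `-1e9` so padded tokens never win the top-k. A torch sketch of the same masking (tensor values are illustrative):

```python
import torch

logits = torch.tensor([0.3, float("nan"), -0.1])
# NaN != NaN is True, so this comparison selects exactly the NaN entries.
masked = torch.where(logits != logits, torch.full_like(logits, -1e9), logits)
# masked -> tensor([ 3.0000e-01, -1.0000e+09, -1.0000e-01])
```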
@@ -266,7 +283,7 @@ def fused_moe_router_large_bs_kernel(
     )

-def fused_moe_router_large_bs_impl(
+def fused_moe_router_tensorcore(
     x: torch.Tensor,
     router_weight: torch.Tensor,
     topk: int,
@@ -274,6 +291,7 @@ def fused_moe_router_large_bs_impl(
     BLOCK_SIZE_M: int,
     BLOCK_SIZE_N: int,
     BLOCK_SIZE_K: int,
+    correction_bias: Optional[torch.Tensor] = None,
 ):
     assert len(x.shape) == 2 and x.shape[1] == router_weight.shape[1]
     bs, hidden_dim = x.shape
@@ -285,10 +303,17 @@ def fused_moe_router_large_bs_impl(
     topk_weights = torch.empty((bs, topk), dtype=torch.float32, device=x.device)
     topk_ids = torch.empty((bs, topk), dtype=torch.int32, device=x.device)

+    is_correction_bias = correction_bias is not None
+
     grid = (triton.cdiv(bs, BLOCK_SIZE_M) * triton.cdiv(num_experts, BLOCK_SIZE_N),)

-    fused_moe_router_large_bs_kernel[grid](
+    # TODO(ch-wan): temporary workaround for dp attention. We should support masked
+    # router to skip padded tokens.
+    from sglang.srt.layers.dp_attention import is_dp_attention_enabled
+
+    dp_attn_workaround_flag = is_dp_attention_enabled()
+    fused_moe_router_tensorcore_kernel[grid](
         a_ptr=x,
         b_ptr=router_weight,
         topk_weights_ptr=topk_weights,
@@ -299,11 +324,14 @@ def fused_moe_router_large_bs_impl(
         moe_softcapping=moe_softcapping,
         moe_renormalize=False,
         K=hidden_dim,
+        correction_bias_ptr=correction_bias,
+        is_correction_bias=is_correction_bias,
         BLOCK_SIZE_M=BLOCK_SIZE_M,
         BLOCK_SIZE_N=BLOCK_SIZE_N,
         BLOCK_SIZE_K=BLOCK_SIZE_K,
         stride_am=hidden_dim,
         stride_bn=hidden_dim,
+        dp_attn_workaround_flag=dp_attn_workaround_flag,
     )

     return topk_weights, topk_ids
@@ -316,6 +344,7 @@ def fused_moe_router_shim(
     topk,
     renormalize,
     correction_bias: Optional[torch.Tensor] = None,
+    enable_deterministic_inference: bool = False,
 ):
     assert not renormalize
     assert (
@@ -324,16 +353,22 @@ def fused_moe_router_shim(
     )
     bs, hidden_dim = hidden_states.shape
     num_experts = gating_output.shape[0]

     BLOCK_SIZE_M = 32
-    BLOCK_SIZE_N = 16
-    BLOCK_SIZE_K = 256
+    BLOCK_SIZE_N = max(num_experts, 16)
+    BLOCK_SIZE_K = (
+        256 if num_experts < 256 else 64
+    )  # if experts are large, need to use smaller k block or shared memory OOM
+
     if (
-        bs >= 512
-        and topk <= 2
-        and num_experts <= BLOCK_SIZE_N
+        (bs >= 512 or num_experts > 8)
         and hidden_dim % BLOCK_SIZE_K == 0
+        # we keep using single kernel to avoid non-deterministic behavior
+        and not enable_deterministic_inference
     ):
-        return fused_moe_router_large_bs_impl(
+        # if large batch size or large expert, use kernel that uses tensorcore in matmul
+        return fused_moe_router_tensorcore(
             x=hidden_states,
             router_weight=gating_output,
             topk=topk,
@@ -341,9 +376,11 @@ def fused_moe_router_shim(
             BLOCK_SIZE_M=BLOCK_SIZE_M,
             BLOCK_SIZE_N=BLOCK_SIZE_N,
             BLOCK_SIZE_K=BLOCK_SIZE_K,
+            correction_bias=correction_bias,
         )
     else:
-        return fused_moe_router_impl(
+        # if smaller, use kernel that does not use tensorcore in matmul
+        return fused_moe_router_cudacore(
             x=hidden_states,
             router_weight=gating_output,
             topk=topk,
@@ -380,11 +417,10 @@ class FusedMoeRouter:
             renormalize=False,
         )

-    def forward_vllm(
+    def forward_torch(
         self,
         x: torch.Tensor,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        # g, _ = self.router_linear.forward(x)
         g = x.float() @ self.router_linear.weight.T.float()
         g = torch.tanh(g.float() / self.moe_softcapping) * self.moe_softcapping
...
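For reference, the Triton kernel's `(exp(2z) - 1) / (exp(2z) + 1)` expression is just `tanh(z)` written with a single `exp`, so the renamed `forward_torch` reference path and the kernel compute the same softcapped logits. A quick numeric check, where `z` stands for the logits already divided by the softcap:

```python
import torch

z = torch.randn(8)
softcap = 30.0
# Kernel form: tanh via the exp identity, then rescale by the softcap.
exped = torch.exp(2 * z)
kernel_form = (exped - 1) / (exped + 1) * softcap
# forward_torch form from the diff.
torch_form = torch.tanh(z) * softcap
assert torch.allclose(kernel_form, torch_form, atol=1e-6)
```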
@@ -2,11 +2,10 @@ from typing import Callable, List, Optional, Tuple

 import torch

-from sglang.srt import offloader
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
 from sglang.srt.layers.quantization.mxfp4_tensor import MXFP4QuantizeUtil
-from sglang.srt.utils import is_sm100_supported
+from sglang.srt.utils import is_sm100_supported, offloader

 try:
     from vllm import _custom_ops as ops
@@ -29,7 +28,6 @@ from sglang.srt.layers.quantization.fp8_kernel import (
 )
 from sglang.srt.utils import (
     align,
-    ceil_div,
     get_bool_env_var,
     get_cuda_version,
     get_device_capability,
...
@@ -18,8 +18,8 @@ from dataclasses import dataclass, field, fields
 from typing import Dict, List, Optional, Union
 from uuid import uuid4

-from sglang.srt.aio_rwlock import RWLock
 from sglang.srt.utils import ConcurrentCounter
+from sglang.srt.utils.aio_rwlock import RWLock


 @dataclass(frozen=True)
...
@@ -37,13 +37,13 @@ from sglang.srt.managers.io_struct import (
 from sglang.srt.managers.schedule_batch import Req
 from sglang.srt.managers.scheduler import run_scheduler_process
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import (
     bind_port,
     configure_logger,
     get_zmq_socket,
     kill_itself_when_parent_died,
 )
+from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.utils import TypeBasedDispatcher, get_exception_traceback

 logger = logging.getLogger(__name__)
...
@@ -36,7 +36,6 @@ else:
     Image = Any

-# Parameters for a session
 @dataclass
 class BaseReq(ABC):
     rid: Optional[Union[str, List[str]]] = field(default=None, kw_only=True)
@@ -60,9 +59,11 @@ class BaseBatchReq(ABC):
         return self.rids

+# Parameters for a session
 @dataclass
 class SessionParams:
     id: Optional[str] = None
+    rid: Optional[str] = None
     offset: Optional[int] = None
     replace: Optional[bool] = None
     drop_previous_output: Optional[bool] = None
...
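With the "Parameters for a session" comment relocated to the class it actually describes and the new `rid` field, constructing session parameters looks like this (field names are from the dataclass in the diff; the values are made up):

```python
params = SessionParams(
    id="session-1",
    rid="request-42",  # new field: ties the session operation to a request id
    offset=0,
    replace=False,
    drop_previous_output=None,
)
```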
@@ -156,7 +156,6 @@ from sglang.srt.model_executor.forward_batch_info import (
 from sglang.srt.parser.reasoning_parser import ReasoningParser
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
-from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.tracing.trace import (
     process_tracing_init,
     trace_set_proc_propagate_context,
@@ -192,6 +191,7 @@ from sglang.srt.utils.hf_transformers_utils import (
     get_tokenizer,
     get_tokenizer_from_processor,
 )
+from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.utils import TypeBasedDispatcher, get_exception_traceback

 logger = logging.getLogger(__name__)
...
@@ -40,7 +40,6 @@ import zmq
 import zmq.asyncio
 from fastapi import BackgroundTasks

-from sglang.srt.aio_rwlock import RWLock
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.lora.lora_registry import LoRARegistry
@@ -94,6 +93,7 @@ from sglang.srt.utils import (
     get_zmq_socket,
     kill_process_tree,
 )
+from sglang.srt.utils.aio_rwlock import RWLock
 from sglang.srt.utils.hf_transformers_utils import (
     get_processor,
     get_tokenizer,
...
@@ -20,7 +20,7 @@ from dataclasses import dataclass
 from sglang.srt.configs.mamba_utils import Mamba2CacheParams
 from sglang.srt.layers.attention.nsa import index_buf_accessor
 from sglang.srt.layers.attention.nsa.quant_k_cache import quantize_k_cache
-from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
+from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter

 """
 Memory pool.
...
@@ -117,15 +117,9 @@ from sglang.srt.model_loader.remote_instance_weight_loader_utils import (
 )
 from sglang.srt.model_loader.utils import set_default_torch_dtype
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.offloader import (
-    create_offloader_from_server_args,
-    get_offloader,
-    set_offloader,
-)
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
-from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import (
     MultiprocessingSerializer,
     cpu_has_amx_support,
@@ -148,7 +142,13 @@ from sglang.srt.utils import (
     set_cuda_arch,
     slow_rank_detector,
 )
+from sglang.srt.utils.offloader import (
+    create_offloader_from_server_args,
+    get_offloader,
+    set_offloader,
+)
 from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions
+from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.weight_sync.tensor_bucket import (
     FlattenedTensorBucket,
     FlattenedTensorMetadata,
...
@@ -43,10 +43,8 @@

 import copy
 import logging
-import math
-from collections.abc import Mapping
 from dataclasses import dataclass
-from typing import Any, Iterable, List, Optional, Tuple
+from typing import Iterable, List, Optional, Tuple

 import torch
 from torch import nn
@@ -56,10 +54,6 @@ from sglang.srt.configs import KimiVLConfig
 from sglang.srt.configs.deepseekvl2 import DeepseekV2Config
 from sglang.srt.configs.kimi_vl import KimiVLConfig
 from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig
-from sglang.srt.distributed import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)
 from sglang.srt.layers.activation import QuickGELU
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
...
@@ -49,7 +49,7 @@ from typing import List, Optional, Sequence, Tuple, Union
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from transformers.activations import ACT2FN, GELUTanh
+from transformers.activations import ACT2FN
 from transformers.modeling_utils import PreTrainedModel

 try:
@@ -596,6 +596,8 @@ class MoonVitPretrainedModel(PreTrainedModel):
     _supports_sdpa = True

     def __init__(self, config: MoonViTConfig, *inputs, **kwargs):
+        from transformers.activations import GELUTanh
+
         super().__init__(config, *inputs, **kwargs)
         config = deepcopy(config)
         self.merge_kernel_size = config.merge_kernel_size
...
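Deferring the `GELUTanh` import into `__init__` keeps the module importable on transformers releases that predate `GELUTanh`; that reading of the motivation is an assumption, but it matches the deferred deep_gemm import above. A hedged sketch of an equivalent guard with a fallback shim (the shim is illustrative, not from the diff):

```python
import torch.nn as nn
import torch.nn.functional as F

try:
    from transformers.activations import GELUTanh
except ImportError:
    # Fallback for older transformers: tanh-approximated GELU, same math.
    class GELUTanh(nn.Module):
        def forward(self, x):
            return F.gelu(x, approximate="tanh")
```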
@@ -238,6 +238,7 @@ class ServerArgs:
     log_requests: bool = False
     log_requests_level: int = 2
     crash_dump_folder: Optional[str] = None
+    crash_on_nan: bool = False
     show_time_cost: bool = False
     enable_metrics: bool = False
     enable_metrics_for_all_schedulers: bool = False
@@ -1733,6 +1734,12 @@ class ServerArgs:
             default=ServerArgs.crash_dump_folder,
             help="Folder path to dump requests from the last 5 min before a crash (if any). If not specified, crash dumping is disabled.",
         )
+        parser.add_argument(
+            "--crash-on-nan",
+            type=str,
+            default=ServerArgs.crash_on_nan,
+            help="Crash the server on nan logprobs.",
+        )
         parser.add_argument(
             "--show-time-cost",
             action="store_true",
...
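The new flag gives operators a fail-fast mode: instead of silently serving results after logprobs go NaN, the server can abort. A hypothetical illustration of the check such a flag would gate (not the scheduler's actual code):

```python
import torch

def maybe_crash_on_nan(logprobs: torch.Tensor, crash_on_nan: bool) -> None:
    # Fail fast rather than returning corrupted outputs to clients.
    if crash_on_nan and torch.isnan(logprobs).any():
        raise RuntimeError("NaN detected in logprobs; aborting (--crash-on-nan).")
```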
@@ -133,9 +133,9 @@ class TiktokenTokenizer:
         )
         return self.encode(ret) if tokenize else ret

-    def __call__(self, text, **kwargs):
+    def __call__(self, text: List[str], **kwargs):
         return {
-            "input_ids": self.encode(text),
+            "input_ids": [self.encode(x) for x in text],
         }

     def init_xgrammar(self):
...
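`__call__` now takes a batch of strings and returns one id list per input, matching the Hugging Face batch convention; note that a bare string would now be iterated character by character, so callers must pass a list. Illustrative usage, assuming a constructed instance named `tok`:

```python
out = tok(["hello world", "second prompt"])
# out["input_ids"] -> [[ids for "hello world"], [ids for "second prompt"]]
```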
-from .common import *
+# Temporarily do this to avoid changing all imports in the repo
+from sglang.srt.utils.common import *