Unverified commit c7d57d5b authored by Lianmin Zheng, committed by GitHub

Fix CI and style (#12658)

parent 80802c4c
......@@ -2,7 +2,6 @@
from transformers.configuration_utils import PretrainedConfig
from sglang.srt.configs.mamba_utils import KimiLinearCacheParams, KimiLinearStateShape
from sglang.srt.layers.dp_attention import get_attention_tp_size
class KimiLinearConfig(PretrainedConfig):
......@@ -150,6 +149,8 @@ class KimiLinearConfig(PretrainedConfig):
@property
def mamba2_cache_params(self) -> KimiLinearCacheParams:
from sglang.srt.layers.dp_attention import get_attention_tp_size
shape = KimiLinearStateShape.create(
tp_world_size=get_attention_tp_size(),
num_heads=self.linear_attn_config["num_heads"],
......
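The hunk above drops the module-level `get_attention_tp_size` import and re-imports it inside the `mamba2_cache_params` property. A minimal, self-contained sketch of that deferred (function-local) import pattern; the class and attribute names below are placeholders, not SGLang's, and the motivation (avoiding an import-time dependency cycle) is inferred from the similar "avoid circular import" comment elsewhere in this commit:

```python
# Deferred-import sketch: the dependency is resolved when the property is
# read, not when the module defining DemoConfig is imported.
class DemoConfig:
    @property
    def num_workers(self) -> int:
        # Function-local import: only bound at first property access.
        import os

        return os.cpu_count() or 1


print(DemoConfig().num_workers)
```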
......@@ -156,6 +156,7 @@ class HybridMambaDecodeReqToTokenPool(HybridReqToTokenPool):
enable_memory_saver=enable_memory_saver,
pre_alloc_size=pre_alloc_size,
)
self.enable_memory_saver = enable_memory_saver
self._init_mamba_pool(
size + pre_alloc_size, cache_params, device, speculative_num_draft_tokens
)
......
......@@ -3,9 +3,14 @@ import tempfile
from contextlib import nullcontext
import torch
import torch.utils.cpp_extension
from packaging import version
from torch.cuda.memory import CUDAPluggableAllocator
from sglang.srt.distributed.parallel_state import GroupCoordinator
from sglang.srt.server_args import get_global_server_args
after_2_8_0 = version.parse(torch.__version__) >= version.parse("2.8.0")
nccl_allocator_source = """
......@@ -60,9 +65,6 @@ _cur_device = None
def is_symmetric_memory_enabled():
# Import here to avoid circular import
from sglang.srt.server_args import get_global_server_args
return get_global_server_args().enable_symm_mem
......@@ -123,7 +125,12 @@ class SymmetricMemoryContext:
_graph_pool_id is not None
), "graph_pool_id is not set under graph capture"
# Pause graph memory pool to use symmetric memory with cuda graph
if after_2_8_0:
torch._C._cuda_endAllocateToPool(_cur_device, _graph_pool_id)
else:
torch._C._cuda_endAllocateCurrentStreamToPool(
_cur_device, _graph_pool_id
)
self._mem_pool_ctx.__enter__()
......@@ -137,7 +144,12 @@ class SymmetricMemoryContext:
self._mem_pool_ctx.__exit__(exc_type, exc_val, exc_tb)
if self.is_graph_capture:
torch._C._cuda_beginAllocateCurrentThreadToPool(_cur_device, _graph_pool_id)
if after_2_8_0:
torch._C._cuda_beginAllocateCurrentThreadToPool(
_cur_device, _graph_pool_id
)
else:
torch._C._cuda_beginAllocateToPool(_cur_device, _graph_pool_id)
def use_symmetric_memory(group_coordinator: GroupCoordinator, disabled: bool = False):
......
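The two hunks above switch between renamed private CUDA allocator hooks depending on the installed PyTorch version. A minimal sketch of that version-gating pattern, using only the `torch._C` names that appear in the diff itself; these are private APIs, so treat the sketch as an illustration rather than a stable interface:

```python
# Version-gated dispatch sketch; the torch._C names are taken from the diff
# above and are private, undocumented APIs.
import torch
from packaging import version

AFTER_2_8_0 = version.parse(torch.__version__) >= version.parse("2.8.0")


def end_allocate_to_pool(device: int, pool_id) -> None:
    """Stop routing allocations into the given CUDA memory pool."""
    if AFTER_2_8_0:
        # Name used on torch >= 2.8.0, per the hunk above.
        torch._C._cuda_endAllocateToPool(device, pool_id)
    else:
        torch._C._cuda_endAllocateCurrentStreamToPool(device, pool_id)
```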
......@@ -31,8 +31,6 @@ from typing import AsyncIterator, Dict, Iterator, List, Optional, Tuple, Union
import zmq
from sglang.srt.tracing.trace import process_tracing_init, trace_set_thread_info
# Fix a bug in Python threading
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
......@@ -67,6 +65,7 @@ from sglang.srt.managers.scheduler import run_scheduler_process
from sglang.srt.managers.template_manager import TemplateManager
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.tracing.trace import process_tracing_init, trace_set_thread_info
from sglang.srt.utils import (
MultiprocessingSerializer,
assert_pkg_version,
......@@ -513,6 +512,21 @@ class Engine(EngineBase):
self.tokenizer_manager.update_weights_from_disk(obj, None)
)
def update_weights_from_ipc(
self,
zmq_handles: Dict[str, str],
flush_cache: bool = True,
):
"""Update weights from IPC for checkpoint-engine integration."""
obj = UpdateWeightsFromIPCReqInput(
zmq_handles=zmq_handles,
flush_cache=flush_cache,
)
loop = asyncio.get_event_loop()
return loop.run_until_complete(
self.tokenizer_manager.update_weights_from_ipc(obj, None)
)
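The `update_weights_from_ipc` method added above drives an async tokenizer-manager call to completion from a synchronous `Engine` method. A self-contained sketch of that sync-over-async bridge, with placeholder request/backend classes that are not SGLang's:

```python
# Sync-over-async bridge sketch; DemoRequest, DemoBackend, DemoEngine are
# illustrative stand-ins for the Engine / TokenizerManager pair above.
import asyncio
from dataclasses import dataclass


@dataclass
class DemoRequest:
    payload: str


class DemoBackend:
    async def handle(self, req: DemoRequest) -> str:
        await asyncio.sleep(0)  # stand-in for real async work
        return f"handled {req.payload}"


class DemoEngine:
    def __init__(self) -> None:
        self.backend = DemoBackend()

    def handle_sync(self, payload: str) -> str:
        # Same shape as update_weights_from_ipc: build the request object,
        # then block on the coroutine with the process event loop.
        obj = DemoRequest(payload=payload)
        loop = asyncio.get_event_loop()
        return loop.run_until_complete(self.backend.handle(obj))


print(DemoEngine().handle_sync("weights"))
```

Note that `asyncio.get_event_loop()` outside a running loop is deprecated on newer Python versions; the sketch keeps it only to mirror the Engine code above.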
def get_weights_by_name(self, name: str, truncate_size: int = 100):
"""Get weights by parameter name."""
obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)
......@@ -658,21 +672,6 @@ class Engine(EngineBase):
request=None,
)
def update_weights_from_ipc(
self,
zmq_handles: Dict[str, str],
flush_cache: bool = True,
):
"""Update weights from IPC for checkpoint-engine integration."""
obj = UpdateWeightsFromIPCReqInput(
zmq_handles=zmq_handles,
flush_cache=flush_cache,
)
loop = asyncio.get_event_loop()
return loop.run_until_complete(
self.tokenizer_manager.update_weights_from_ipc(obj, None)
)
def _set_envs_and_config(server_args: ServerArgs):
# Set global environments
......@@ -881,14 +880,14 @@ def _launch_subprocesses(
detoken_proc.start()
# Init tokenizer manager first, as the bootstrap server is initialized here
if server_args.tokenizer_worker_num > 1:
# Launch multi-tokenizer router
tokenizer_manager = MultiTokenizerRouter(server_args, port_args)
template_manager = None
else:
if server_args.tokenizer_worker_num == 1:
tokenizer_manager, template_manager = _init_tokenizer_manager(
server_args, port_args
)
else:
# Launch multi-tokenizer router
tokenizer_manager = MultiTokenizerRouter(server_args, port_args)
template_manager = None
# Wait for the model to finish loading
scheduler_infos = []
......@@ -911,7 +910,6 @@ def _launch_subprocesses(
# Assume all schedulers have the same scheduler_info
scheduler_info = scheduler_infos[0]
tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"]
return tokenizer_manager, template_manager, scheduler_info, port_args
......@@ -162,6 +162,7 @@ class LinearBase(torch.nn.Module):
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
self.quant_config = quant_config
if quant_config is None:
self.quant_method: Optional[QuantizeMethodBase] = UnquantizedLinearMethod()
else:
......
......@@ -269,10 +269,11 @@ class Scheduler(
server_args.speculative_algorithm
)
self.gpu_id = gpu_id
self.page_size = server_args.page_size
self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
self.enable_hicache_storage = server_args.hicache_storage_backend is not None
self.page_size = server_args.page_size
# Distributed rank info
self.attn_tp_rank, self.attn_tp_size, self.attn_dp_rank = (
compute_dp_attention_world_info(
server_args.enable_dp_attention,
......@@ -298,22 +299,12 @@ class Scheduler(
# Init moe config
self.init_moe_config()
# Set reasoning_parser and think_end_id if --reasoning_parser is enabled
if self.server_args.reasoning_parser and self.tokenizer:
reasoning_parser = ReasoningParser(
model_type=self.server_args.reasoning_parser, stream_reasoning=False
)
self.tokenizer.think_end_id = self.tokenizer.encode(
reasoning_parser.detector.think_end_token, add_special_tokens=False
)[0]
# Check whether overlap can be enabled
if not self.is_generation:
self.enable_overlap = False
logger.info("Overlap scheduler is disabled for embedding models.")
# Launch a tensor parallel worker
from sglang.srt.managers.tp_worker import TpModelWorker
self.tp_worker = TpModelWorker(
......@@ -327,7 +318,6 @@ class Scheduler(
)
# Launch a draft worker for speculative decoding
draft_worker_kwargs = dict(
gpu_id=gpu_id,
tp_rank=tp_rank,
......@@ -481,10 +471,6 @@ class Scheduler(
)
# Enable preemption for priority scheduling.
self.try_preemption = self.enable_priority_scheduling
assert (
server_args.schedule_conservativeness >= 0
), "Invalid schedule_conservativeness"
self.init_new_token_ratio = min(
envs.SGLANG_INIT_NEW_TOKEN_RATIO.get()
* server_args.schedule_conservativeness,
......@@ -511,7 +497,6 @@ class Scheduler(
)
self.offload_tags = set()
self.init_profiler()
self.recv_skipper = SchedulerRecvSkipper.maybe_create(server_args)
self.input_blocker = (
SchedulerInputBlocker(noop=self.attn_tp_rank != 0)
......@@ -519,18 +504,15 @@ class Scheduler(
else None
)
# Init disaggregation
self.init_disaggregation()
# Init metrics stats
self.init_metrics(tp_rank, pp_rank, dp_rank)
if self.enable_kv_cache_events:
self.init_kv_events(server_args.kv_events_config)
# Init disaggregation
self.disaggregation_mode = DisaggregationMode(
self.server_args.disaggregation_mode
)
self.init_disaggregation()
if envs.SGLANG_LOG_GC.get():
configure_gc_logger()
......@@ -695,6 +677,15 @@ class Scheduler(
revision=server_args.revision,
)
# Set reasoning_parser and think_end_id if --reasoning_parser is enabled
if self.server_args.reasoning_parser and self.tokenizer:
reasoning_parser = ReasoningParser(
model_type=self.server_args.reasoning_parser, stream_reasoning=False
)
self.tokenizer.think_end_id = self.tokenizer.encode(
reasoning_parser.detector.think_end_token, add_special_tokens=False
)[0]
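The block above resolves `think_end_id` by encoding the reasoning parser's end-of-thinking marker and taking the first token id. A small illustration of that lookup with a Hugging Face tokenizer; the model id and the `</think>` marker are assumptions for the example, and the `[0]` indexing presumes the marker encodes to a single token:

```python
# Illustrative only: resolve a marker string to a token id the same way the
# scheduler does above. Model id and marker string are example assumptions.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")  # any HF tokenizer
ids = tokenizer.encode("</think>", add_special_tokens=False)
think_end_id = ids[0]  # valid only if the marker maps to exactly one token
print(think_end_id, tokenizer.convert_ids_to_tokens(ids))
```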
def init_memory_pool_and_cache(self):
server_args = self.server_args
......@@ -835,6 +826,9 @@ class Scheduler(
init_embedding_cache(embedding_cache_size * 1024 * 1024)
def init_disaggregation(self):
self.disaggregation_mode = DisaggregationMode(
self.server_args.disaggregation_mode
)
self.transfer_backend = TransferBackend(
self.server_args.disaggregation_transfer_backend
)
......
......@@ -858,7 +858,6 @@ class SchedulerOutputProcessorMixin:
prompt_tokens.append(len(req.origin_input_ids))
completion_tokens.append(len(output_ids_))
cached_tokens.append(req.cached_tokens)
retraction_counts.append(req.retraction_count)
if not self.spec_algorithm.is_none():
......
......@@ -196,9 +196,9 @@ class TokenizerManager(TokenizerCommunicatorMixin):
else server_args.speculative_num_draft_tokens
)
# Initialize tokenizer and processor
set_global_server_args_for_tokenizer(server_args)
# Initialize tokenizer and processor
if self.model_config.is_multimodal:
import_processors("sglang.srt.multimodal.processors")
try:
......@@ -370,6 +370,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
if self.server_args.gc_warning_threshold_secs > 0.0:
configure_gc_warning(self.server_args.gc_warning_threshold_secs)
# Dispatcher and communicators
self._result_dispatcher = TypeBasedDispatcher(
[
(
......@@ -387,15 +388,11 @@ class TokenizerManager(TokenizerCommunicatorMixin):
UpdateWeightFromDiskReqOutput,
self._handle_update_weights_from_disk_req_output,
),
(
FreezeGCReq,
lambda x: None,
),
(FreezeGCReq, lambda x: None),
# Ignore the case where the scheduler skips the detokenizer and forwards directly back to the tokenizer manager.
(HealthCheckOutput, lambda x: None),
]
)
self.init_communicators(server_args)
async def generate_request(
......@@ -407,8 +404,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
self.auto_create_handle_loop()
obj.normalize_batch_and_arguments()
if request:
if "trace_context" in request.headers:
if request and "trace_context" in request.headers:
trace_set_remote_propagate_context(request.headers["trace_context"])
if self.server_args.tokenizer_worker_num > 1:
......
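The hunks above register `(message type, handler)` pairs with `TypeBasedDispatcher`, mapping some types deliberately to a no-op. A generic stand-in for that dispatch-table pattern; this is not SGLang's `TypeBasedDispatcher`, just a sketch of the idea:

```python
# Generic (type, handler) dispatch table, mirroring the structure above.
from typing import Any, Callable, Sequence, Tuple


class SimpleTypeDispatcher:
    def __init__(self, mapping: Sequence[Tuple[type, Callable[[Any], Any]]]):
        self._mapping = list(mapping)

    def __call__(self, obj: Any) -> Any:
        for klass, handler in self._mapping:
            if isinstance(obj, klass):
                return handler(obj)
        raise ValueError(f"no handler registered for {type(obj)!r}")


# Usage: some types are handled, others are intentionally ignored, just like
# FreezeGCReq and HealthCheckOutput above.
dispatcher = SimpleTypeDispatcher([(str, print), (int, lambda x: None)])
dispatcher("hello")  # printed
dispatcher(42)       # silently dropped
```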
......@@ -58,6 +58,7 @@ from sglang.srt.utils.common import (
json_list_type,
nullable_str,
parse_connector_type,
wait_port_available,
xpu_has_xmx_support,
)
from sglang.srt.utils.hf_transformers_utils import check_gguf_file, get_config
......@@ -3763,6 +3764,10 @@ class ServerArgs:
"Please set --chunked-prefill-size -1 when using --multi-item-scoring-delimiter."
)
assert (
self.schedule_conservativeness >= 0
), "schedule_conservativeness must be non-negative"
def check_lora_server_args(self):
assert self.max_loras_per_batch > 0, "max_loras_per_batch must be positive"
......@@ -3956,9 +3961,7 @@ def set_global_server_args_for_scheduler(server_args: ServerArgs):
_global_server_args = server_args
def set_global_server_args_for_tokenizer(server_args: ServerArgs):
global _global_server_args
_global_server_args = server_args
set_global_server_args_for_tokenizer = set_global_server_args_for_scheduler
def get_global_server_args() -> ServerArgs:
......@@ -4082,7 +4085,8 @@ class PortArgs:
), "please provide --dist-init-addr as host:port of head node"
dist_init_host, dist_init_port = dist_init_addr
port_base = int(dist_init_port) + 1
dist_init_port = int(dist_init_port)
port_base = dist_init_port + 1
detokenizer_port = port_base + 1
rpc_port = port_base + 2
metrics_ipc_name = port_base + 3
......@@ -4092,6 +4096,25 @@ class PortArgs:
else:
assert worker_ports is not None
scheduler_input_port = worker_ports[dp_rank]
try:
if dp_rank is None:
wait_port_available(dist_init_port, "dist_init_port")
wait_port_available(port_base, "port_base")
wait_port_available(detokenizer_port, "detokenizer_port")
wait_port_available(nccl_port, "nccl_port")
wait_port_available(rpc_port, "rpc_port")
wait_port_available(metrics_ipc_name, "metrics_ipc_name")
# Check scheduler_input_port only for dp.
# Skip check when using worker_ports since the port is already bound by our ZMQ socket
if dp_rank is None or worker_ports is None:
wait_port_available(scheduler_input_port, "scheduler_input_port")
except ValueError as e:
logger.exception(
f"Port is already in use. {dist_init_port=} {port_base=} {detokenizer_port=} {nccl_port=} {scheduler_input_port=}"
)
raise
return PortArgs(
tokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base}",
scheduler_input_ipc_name=f"tcp://{dist_init_host}:{scheduler_input_port}",
......
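The `PortArgs` change above probes each derived port with `wait_port_available` and logs all of them if one is already taken. A minimal sketch of what such a probe can look like; the real helper in `sglang.srt.utils.common` may wait or retry, so this is an assumption, not its implementation:

```python
# Hedged sketch of a TCP port-availability probe; SGLang's wait_port_available
# may behave differently (e.g. block until the port frees up).
import socket


def port_is_free(port: int, host: str = "127.0.0.1") -> bool:
    """Return True if a TCP socket can be bound to (host, port) right now."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        try:
            sock.bind((host, port))
            return True
        except OSError:
            return False


# Usage mirroring the try/except in PortArgs: fail loudly with the port name.
for name, port in {"port_base": 30001, "detokenizer_port": 30002}.items():
    if not port_is_free(port):
        raise ValueError(f"{name} ({port}) is already in use")
```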
......@@ -1557,7 +1557,7 @@ def send_generate_requests(base_url: str, num_requests: int) -> List[str]:
"text": prompt,
"sampling_params": {
"temperature": 0,
"max_new_tokens": 50,
"max_new_tokens": 500,
},
},
)
......@@ -1584,7 +1584,7 @@ async def send_concurrent_generate_requests(
"text": prompt,
"sampling_params": {
"temperature": 0,
"max_new_tokens": 50,
"max_new_tokens": 500,
},
},
) as response:
......@@ -1608,7 +1608,7 @@ async def send_concurrent_generate_requests_with_custom_params(
""",
"sampling_params": {
"temperature": 0,
"max_new_tokens": 50,
"max_new_tokens": 500,
},
}
......
......@@ -2,7 +2,6 @@ import asyncio
import os
import re
import unittest
from concurrent.futures import ThreadPoolExecutor
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
......@@ -37,6 +36,8 @@ class TestMaxQueuedRequests(CustomTestCase):
"1",
"--max-queued-requests", # Enforce max queued request number is 1
"1",
"--attention-backend",
"triton",
),
return_stdout_stderr=(cls.stdout, cls.stderr),
)
......