Unverified Commit c7d57d5b authored by Lianmin Zheng, committed by GitHub

Fix CI and style (#12658)

parent 80802c4c
@@ -2,7 +2,6 @@
 from transformers.configuration_utils import PretrainedConfig
 from sglang.srt.configs.mamba_utils import KimiLinearCacheParams, KimiLinearStateShape
-from sglang.srt.layers.dp_attention import get_attention_tp_size

 class KimiLinearConfig(PretrainedConfig):
@@ -150,6 +149,8 @@ class KimiLinearConfig(PretrainedConfig):
     @property
     def mamba2_cache_params(self) -> KimiLinearCacheParams:
+        from sglang.srt.layers.dp_attention import get_attention_tp_size
+
         shape = KimiLinearStateShape.create(
             tp_world_size=get_attention_tp_size(),
             num_heads=self.linear_attn_config["num_heads"],
......
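The two hunks above move the `get_attention_tp_size` import from module scope into the property that uses it. Below is a minimal sketch of that deferred-import pattern, assuming the intent is to keep the dependency off the module-import path (for example, to avoid a circular import at load time); every name in it is illustrative rather than taken from sglang:

```python
class ExampleConfig:
    @property
    def cache_params(self):
        # Importing inside the property defers the dependency until first use,
        # so merely importing this module never touches the heavier package.
        from math import prod  # stands in for a heavyweight or cyclic import

        return prod((2, 3, 4))


print(ExampleConfig().cache_params)  # -> 24
```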
@@ -156,6 +156,7 @@ class HybridMambaDecodeReqToTokenPool(HybridReqToTokenPool):
             enable_memory_saver=enable_memory_saver,
             pre_alloc_size=pre_alloc_size,
         )
+        self.enable_memory_saver = enable_memory_saver
         self._init_mamba_pool(
             size + pre_alloc_size, cache_params, device, speculative_num_draft_tokens
         )
......
@@ -3,9 +3,14 @@ import tempfile
 from contextlib import nullcontext
 import torch
+import torch.utils.cpp_extension
+from packaging import version
 from torch.cuda.memory import CUDAPluggableAllocator
 from sglang.srt.distributed.parallel_state import GroupCoordinator
+from sglang.srt.server_args import get_global_server_args
+
+after_2_8_0 = version.parse(torch.__version__) >= version.parse("2.8.0")
 nccl_allocator_source = """
@@ -60,9 +65,6 @@ _cur_device = None
 def is_symmetric_memory_enabled():
-    # Import here to avoid circular import
-    from sglang.srt.server_args import get_global_server_args
     return get_global_server_args().enable_symm_mem
@@ -123,7 +125,12 @@ class SymmetricMemoryContext:
                 _graph_pool_id is not None
             ), "graph_pool_id is not set under graph capture"
             # Pause graph memory pool to use symmetric memory with cuda graph
-            torch._C._cuda_endAllocateToPool(_cur_device, _graph_pool_id)
+            if after_2_8_0:
+                torch._C._cuda_endAllocateToPool(_cur_device, _graph_pool_id)
+            else:
+                torch._C._cuda_endAllocateCurrentStreamToPool(
+                    _cur_device, _graph_pool_id
+                )
         self._mem_pool_ctx.__enter__()
@@ -137,7 +144,12 @@ class SymmetricMemoryContext:
         self._mem_pool_ctx.__exit__(exc_type, exc_val, exc_tb)
         if self.is_graph_capture:
-            torch._C._cuda_beginAllocateCurrentThreadToPool(_cur_device, _graph_pool_id)
+            if after_2_8_0:
+                torch._C._cuda_beginAllocateCurrentThreadToPool(
+                    _cur_device, _graph_pool_id
+                )
+            else:
+                torch._C._cuda_beginAllocateToPool(_cur_device, _graph_pool_id)

 def use_symmetric_memory(group_coordinator: GroupCoordinator, disabled: bool = False):
......
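The symmetric-memory hunks gate two private CUDA allocator hooks on the installed PyTorch version. Below is a hedged sketch of the same version-gating pattern; the helper name is made up, the two `torch._C` bindings are the ones the diff switches between, and actually calling them requires a CUDA build of PyTorch with an active graph pool:

```python
import torch
from packaging import version

# Compute the flag once at import time instead of re-parsing per call.
AFTER_2_8_0 = version.parse(torch.__version__) >= version.parse("2.8.0")


def end_allocate_to_pool(device: int, pool_id) -> None:
    # Dispatch to whichever private binding this torch build exposes.
    if AFTER_2_8_0:
        torch._C._cuda_endAllocateToPool(device, pool_id)
    else:
        torch._C._cuda_endAllocateCurrentStreamToPool(device, pool_id)
```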
@@ -31,8 +31,6 @@ from typing import AsyncIterator, Dict, Iterator, List, Optional, Tuple, Union
 import zmq
-from sglang.srt.tracing.trace import process_tracing_init, trace_set_thread_info
 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
@@ -67,6 +65,7 @@ from sglang.srt.managers.scheduler import run_scheduler_process
 from sglang.srt.managers.template_manager import TemplateManager
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.tracing.trace import process_tracing_init, trace_set_thread_info
 from sglang.srt.utils import (
     MultiprocessingSerializer,
     assert_pkg_version,
@@ -513,6 +512,21 @@ class Engine(EngineBase):
             self.tokenizer_manager.update_weights_from_disk(obj, None)
         )

+    def update_weights_from_ipc(
+        self,
+        zmq_handles: Dict[str, str],
+        flush_cache: bool = True,
+    ):
+        """Update weights from IPC for checkpoint-engine integration."""
+        obj = UpdateWeightsFromIPCReqInput(
+            zmq_handles=zmq_handles,
+            flush_cache=flush_cache,
+        )
+        loop = asyncio.get_event_loop()
+        return loop.run_until_complete(
+            self.tokenizer_manager.update_weights_from_ipc(obj, None)
+        )
+
     def get_weights_by_name(self, name: str, truncate_size: int = 100):
         """Get weights by parameter name."""
         obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)
@@ -658,21 +672,6 @@ class Engine(EngineBase):
             request=None,
         )

-    def update_weights_from_ipc(
-        self,
-        zmq_handles: Dict[str, str],
-        flush_cache: bool = True,
-    ):
-        """Update weights from IPC for checkpoint-engine integration."""
-        obj = UpdateWeightsFromIPCReqInput(
-            zmq_handles=zmq_handles,
-            flush_cache=flush_cache,
-        )
-        loop = asyncio.get_event_loop()
-        return loop.run_until_complete(
-            self.tokenizer_manager.update_weights_from_ipc(obj, None)
-        )

 def _set_envs_and_config(server_args: ServerArgs):
     # Set global environments
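The two Engine hunks above only relocate `update_weights_from_ipc` within the class; its behavior is unchanged. A hedged usage sketch follows, where the model path and the contents of `zmq_handles` are placeholders for whatever a checkpoint-engine deployment actually provides:

```python
import sglang as sgl

engine = sgl.Engine(model_path="Qwen/Qwen2.5-0.5B-Instruct")  # example checkpoint

# Hypothetical handle map published by an external checkpoint engine.
zmq_handles = {"rank_0": "ipc:///tmp/checkpoint_engine_rank_0"}

# Blocks until the tokenizer manager finishes applying the new weights.
engine.update_weights_from_ipc(zmq_handles, flush_cache=True)
engine.shutdown()
```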
@@ -881,14 +880,14 @@ def _launch_subprocesses(
     detoken_proc.start()
     # Init tokenizer manager first, as the bootstrap server is initialized here
-    if server_args.tokenizer_worker_num > 1:
-        # Launch multi-tokenizer router
-        tokenizer_manager = MultiTokenizerRouter(server_args, port_args)
-        template_manager = None
-    else:
+    if server_args.tokenizer_worker_num == 1:
         tokenizer_manager, template_manager = _init_tokenizer_manager(
             server_args, port_args
         )
+    else:
+        # Launch multi-tokenizer router
+        tokenizer_manager = MultiTokenizerRouter(server_args, port_args)
+        template_manager = None
     # Wait for the model to finish loading
     scheduler_infos = []
@@ -911,7 +910,6 @@ def _launch_subprocesses(
     # Assume all schedulers have the same scheduler_info
     scheduler_info = scheduler_infos[0]
     tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"]
     return tokenizer_manager, template_manager, scheduler_info, port_args
@@ -162,6 +162,7 @@ class LinearBase(torch.nn.Module):
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
         self.params_dtype = params_dtype
+        self.quant_config = quant_config
         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = UnquantizedLinearMethod()
         else:
......
@@ -269,10 +269,11 @@ class Scheduler(
             server_args.speculative_algorithm
         )
         self.gpu_id = gpu_id
-        self.page_size = server_args.page_size
         self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
         self.enable_hicache_storage = server_args.hicache_storage_backend is not None
+        self.page_size = server_args.page_size
+        # Distributed rank info
         self.attn_tp_rank, self.attn_tp_size, self.attn_dp_rank = (
             compute_dp_attention_world_info(
                 server_args.enable_dp_attention,
@@ -298,22 +299,12 @@ class Scheduler(
         # Init moe config
         self.init_moe_config()
-        # Set reasoning_parser and think_end_id if --reasoning_parser is enabled
-        if self.server_args.reasoning_parser and self.tokenizer:
-            reasoning_parser = ReasoningParser(
-                model_type=self.server_args.reasoning_parser, stream_reasoning=False
-            )
-            self.tokenizer.think_end_id = self.tokenizer.encode(
-                reasoning_parser.detector.think_end_token, add_special_tokens=False
-            )[0]
         # Check whether overlap can be enabled
         if not self.is_generation:
             self.enable_overlap = False
             logger.info("Overlap scheduler is disabled for embedding models.")
         # Launch a tensor parallel worker
         from sglang.srt.managers.tp_worker import TpModelWorker
         self.tp_worker = TpModelWorker(
@@ -327,7 +318,6 @@ class Scheduler(
         )
         # Launch a draft worker for speculative decoding
         draft_worker_kwargs = dict(
             gpu_id=gpu_id,
             tp_rank=tp_rank,
@@ -481,10 +471,6 @@ class Scheduler(
         )
         # Enable preemption for priority scheduling.
         self.try_preemption = self.enable_priority_scheduling
-        assert (
-            server_args.schedule_conservativeness >= 0
-        ), "Invalid schedule_conservativeness"
         self.init_new_token_ratio = min(
             envs.SGLANG_INIT_NEW_TOKEN_RATIO.get()
             * server_args.schedule_conservativeness,
@@ -511,7 +497,6 @@ class Scheduler(
         )
         self.offload_tags = set()
         self.init_profiler()
         self.recv_skipper = SchedulerRecvSkipper.maybe_create(server_args)
         self.input_blocker = (
             SchedulerInputBlocker(noop=self.attn_tp_rank != 0)
@@ -519,18 +504,15 @@ class Scheduler(
             else None
         )
+        # Init disaggregation
+        self.init_disaggregation()
         # Init metrics stats
         self.init_metrics(tp_rank, pp_rank, dp_rank)
         if self.enable_kv_cache_events:
             self.init_kv_events(server_args.kv_events_config)
-        # Init disaggregation
-        self.disaggregation_mode = DisaggregationMode(
-            self.server_args.disaggregation_mode
-        )
-        self.init_disaggregation()
         if envs.SGLANG_LOG_GC.get():
             configure_gc_logger()
@@ -695,6 +677,15 @@ class Scheduler(
             revision=server_args.revision,
         )

+        # Set reasoning_parser and think_end_id if --reasoning_parser is enabled
+        if self.server_args.reasoning_parser and self.tokenizer:
+            reasoning_parser = ReasoningParser(
+                model_type=self.server_args.reasoning_parser, stream_reasoning=False
+            )
+            self.tokenizer.think_end_id = self.tokenizer.encode(
+                reasoning_parser.detector.think_end_token, add_special_tokens=False
+            )[0]
+
     def init_memory_pool_and_cache(self):
         server_args = self.server_args
@@ -835,6 +826,9 @@ class Scheduler(
         init_embedding_cache(embedding_cache_size * 1024 * 1024)

     def init_disaggregation(self):
+        self.disaggregation_mode = DisaggregationMode(
+            self.server_args.disaggregation_mode
+        )
         self.transfer_backend = TransferBackend(
             self.server_args.disaggregation_transfer_backend
         )
......
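For reference, the relocated reasoning-parser block above resolves `think_end_id` by encoding the parser's end-of-reasoning token without special tokens and taking the first id. A hedged illustration with a generic Hugging Face tokenizer follows; the model and the `</think>` literal are examples, not values pulled from sglang:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B")
think_end_token = "</think>"  # what reasoning_parser.detector.think_end_token holds for R1-style parsers
think_end_id = tokenizer.encode(think_end_token, add_special_tokens=False)[0]
print(think_end_id)
```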
@@ -858,7 +858,6 @@ class SchedulerOutputProcessorMixin:
                     prompt_tokens.append(len(req.origin_input_ids))
                     completion_tokens.append(len(output_ids_))
                     cached_tokens.append(req.cached_tokens)
                     retraction_counts.append(req.retraction_count)
                 if not self.spec_algorithm.is_none():
......
@@ -196,9 +196,9 @@ class TokenizerManager(TokenizerCommunicatorMixin):
             else server_args.speculative_num_draft_tokens
         )
-        # Initialize tokenizer and processor
         set_global_server_args_for_tokenizer(server_args)
+        # Initialize tokenizer and processor
         if self.model_config.is_multimodal:
             import_processors("sglang.srt.multimodal.processors")
             try:
@@ -370,6 +370,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         if self.server_args.gc_warning_threshold_secs > 0.0:
             configure_gc_warning(self.server_args.gc_warning_threshold_secs)
+        # Dispatcher and communicators
         self._result_dispatcher = TypeBasedDispatcher(
             [
                 (
@@ -387,15 +388,11 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                     UpdateWeightFromDiskReqOutput,
                     self._handle_update_weights_from_disk_req_output,
                 ),
-                (
-                    FreezeGCReq,
-                    lambda x: None,
-                ),
+                (FreezeGCReq, lambda x: None),
                 # For handling case when scheduler skips detokenizer and forwards back to the tokenizer manager, we ignore it.
                 (HealthCheckOutput, lambda x: None),
             ]
         )
         self.init_communicators(server_args)

     async def generate_request(
@@ -407,8 +404,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         self.auto_create_handle_loop()
         obj.normalize_batch_and_arguments()
-        if request:
-            if "trace_context" in request.headers:
-                trace_set_remote_propagate_context(request.headers["trace_context"])
+        if request and "trace_context" in request.headers:
+            trace_set_remote_propagate_context(request.headers["trace_context"])
         if self.server_args.tokenizer_worker_num > 1:
......
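The dispatcher hunk above just collapses one handler registration onto a single line; registrations stay `(message type, callable)` pairs. Below is a simplified stand-in for that dispatch style, not sglang's actual `TypeBasedDispatcher` implementation:

```python
from dataclasses import dataclass
from typing import Any, Callable, List, Tuple, Type


@dataclass
class HealthCheckOutput:
    ok: bool = True


class MiniDispatcher:
    def __init__(self, mapping: List[Tuple[Type, Callable[[Any], Any]]]):
        self._mapping = mapping

    def __call__(self, obj: Any) -> Any:
        # Linear scan is fine for a handful of message types.
        for typ, handler in self._mapping:
            if isinstance(obj, typ):
                return handler(obj)
        raise ValueError(f"no handler for {type(obj).__name__}")


dispatcher = MiniDispatcher([(HealthCheckOutput, lambda x: None)])  # ignored, as in the diff
dispatcher(HealthCheckOutput())
```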
@@ -58,6 +58,7 @@ from sglang.srt.utils.common import (
     json_list_type,
     nullable_str,
     parse_connector_type,
+    wait_port_available,
     xpu_has_xmx_support,
 )
 from sglang.srt.utils.hf_transformers_utils import check_gguf_file, get_config
@@ -3763,6 +3764,10 @@ class ServerArgs:
                 "Please set --chunked-prefill-size -1 when using --multi-item-scoring-delimiter."
             )

+        assert (
+            self.schedule_conservativeness >= 0
+        ), "schedule_conservativeness must be non-negative"
+
     def check_lora_server_args(self):
         assert self.max_loras_per_batch > 0, "max_loras_per_batch must be positive"
@@ -3956,9 +3961,7 @@ def set_global_server_args_for_scheduler(server_args: ServerArgs):
     _global_server_args = server_args


-def set_global_server_args_for_tokenizer(server_args: ServerArgs):
-    global _global_server_args
-    _global_server_args = server_args
+set_global_server_args_for_tokenizer = set_global_server_args_for_scheduler


 def get_global_server_args() -> ServerArgs:
@@ -4082,7 +4085,8 @@ class PortArgs:
             ), "please provide --dist-init-addr as host:port of head node"
             dist_init_host, dist_init_port = dist_init_addr
-            port_base = int(dist_init_port) + 1
+            dist_init_port = int(dist_init_port)
+            port_base = dist_init_port + 1
             detokenizer_port = port_base + 1
             rpc_port = port_base + 2
             metrics_ipc_name = port_base + 3
@@ -4092,6 +4096,25 @@ class PortArgs:
         else:
             assert worker_ports is not None
             scheduler_input_port = worker_ports[dp_rank]
+
+        try:
+            if dp_rank is None:
+                wait_port_available(dist_init_port, "dist_init_port")
+                wait_port_available(port_base, "port_base")
+                wait_port_available(detokenizer_port, "detokenizer_port")
+                wait_port_available(nccl_port, "nccl_port")
+                wait_port_available(rpc_port, "rpc_port")
+                wait_port_available(metrics_ipc_name, "metrics_ipc_name")
+            # Check scheduler_input_port only for dp.
+            # Skip check when using worker_ports since the port is already bound by our ZMQ socket
+            if dp_rank is None or worker_ports is None:
+                wait_port_available(scheduler_input_port, "scheduler_input_port")
+        except ValueError as e:
+            logger.exception(
+                f"Port is already in use. {dist_init_port=} {port_base=} {detokenizer_port=} {nccl_port=} {scheduler_input_port=}"
+            )
+            raise
+
         return PortArgs(
             tokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base}",
             scheduler_input_ipc_name=f"tcp://{dist_init_host}:{scheduler_input_port}",
......
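The PortArgs hunk above fails fast when a derived port is already taken. Below is an illustrative stand-in for that kind of probe (sglang's real `wait_port_available` lives in `sglang.srt.utils.common` and may wait or retry; this one-shot version only binds and raises `ValueError` like the caller expects):

```python
import socket


def check_port_available(port: int, name: str, host: str = "127.0.0.1") -> None:
    # Try to bind the port; failure means something else already owns it.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        try:
            sock.bind((host, port))
        except OSError as e:
            raise ValueError(f"{name} (port {port}) is already in use") from e


check_port_available(30001, "detokenizer_port")  # port number is an example
```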
@@ -1557,7 +1557,7 @@ def send_generate_requests(base_url: str, num_requests: int) -> List[str]:
                 "text": prompt,
                 "sampling_params": {
                     "temperature": 0,
-                    "max_new_tokens": 50,
+                    "max_new_tokens": 500,
                 },
             },
         )
@@ -1584,7 +1584,7 @@ async def send_concurrent_generate_requests(
                 "text": prompt,
                 "sampling_params": {
                     "temperature": 0,
-                    "max_new_tokens": 50,
+                    "max_new_tokens": 500,
                 },
             },
         ) as response:
@@ -1608,7 +1608,7 @@ async def send_concurrent_generate_requests_with_custom_params(
                 """,
                 "sampling_params": {
                     "temperature": 0,
-                    "max_new_tokens": 50,
+                    "max_new_tokens": 500,
                 },
             }
......
@@ -2,7 +2,6 @@ import asyncio
 import os
 import re
 import unittest
-from concurrent.futures import ThreadPoolExecutor

 from sglang.srt.utils import kill_process_tree
 from sglang.test.test_utils import (
@@ -37,6 +36,8 @@ class TestMaxQueuedRequests(CustomTestCase):
                 "1",
                 "--max-queued-requests",  # Enforce max queued request number is 1
                 "1",
+                "--attention-backend",
+                "triton",
             ),
             return_stdout_stderr=(cls.stdout, cls.stderr),
         )
......