Unverified commit db0cc57e, authored by Byron Hsu and committed by GitHub

[PD] Support decode retract and update decode.py (#7196)

parent 349bb2c9
@@ -31,7 +31,7 @@ import numpy as np
 import torch
 from torch.distributed import ProcessGroup
 
-from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVArgs, KVPoll
+from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVPoll
 from sglang.srt.disaggregation.utils import (
     FAKE_BOOTSTRAP_HOST,
     DisaggregationMode,
@@ -45,9 +45,17 @@ from sglang.srt.disaggregation.utils import (
     poll_and_all_reduce,
     prepare_abort,
 )
-from sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch
+from sglang.srt.managers.schedule_batch import (
+    FINISH_ABORT,
+    ScheduleBatch,
+    global_server_args_dict,
+)
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
-from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
+from sglang.srt.mem_cache.memory_pool import (
+    KVCache,
+    ReqToTokenPool,
+    TokenToKVPoolAllocator,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
@@ -145,7 +153,11 @@ class DecodePreallocQueue:
         gloo_group: ProcessGroup,
         tp_rank: int,
         tp_size: int,
+        dp_size: int,
+        gpu_id: int,
         bootstrap_port: int,
+        max_total_num_tokens: int,
+        prefill_pp_size: int,
         transfer_backend: TransferBackend,
     ):
         self.req_to_token_pool = req_to_token_pool
@@ -161,25 +173,35 @@ class DecodePreallocQueue:
         self.gloo_group = gloo_group
         self.tp_rank = tp_rank
         self.tp_size = tp_size
+        self.dp_size = dp_size
+        self.gpu_id = gpu_id
         self.bootstrap_port = bootstrap_port
+        self.max_total_num_tokens = max_total_num_tokens
+        self.prefill_pp_size = prefill_pp_size
         self.num_reserved_decode_tokens = int(
             os.environ.get("SGLANG_NUM_RESERVED_DECODE_TOKENS", "512")
         )
+        self.transfer_backend = transfer_backend
 
         # Queue for requests pending pre-allocation
         self.queue: List[DecodeRequest] = []
-        self.transfer_backend = transfer_backend
-        self.prefill_pp_size = prefill_pp_size
+        self.retracted_queue: List[Req] = []
 
         self.kv_manager = self._init_kv_manager()
 
     def _init_kv_manager(self) -> BaseKVManager:
-        kv_args = KVArgs()
-        kv_args.engine_rank = self.tp_rank
+        kv_args_class = get_kv_class(self.transfer_backend, KVClassType.KVARGS)
+        kv_args = kv_args_class()
+        attn_tp_size = self.tp_size // self.dp_size
+        kv_args.engine_rank = self.tp_rank % (attn_tp_size)
+        kv_args.decode_tp_size = attn_tp_size
+        kv_args.prefill_pp_size = self.prefill_pp_size
         kv_data_ptrs, kv_data_lens, kv_item_lens = (
             self.token_to_kv_pool.get_contiguous_buf_infos()
         )
 
         if self.draft_token_to_kv_pool is not None:
+            # We should also transfer the draft model's KV cache; the indices are
+            # always shared with the target model.
             draft_kv_data_ptrs, draft_kv_data_lens, draft_kv_item_lens = (
                 self.draft_token_to_kv_pool.get_contiguous_buf_infos()
             )
@@ -194,6 +216,7 @@ class DecodePreallocQueue:
         kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = (
             self.metadata_buffers.get_buf_infos()
         )
+
         kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
         kv_args.gpu_id = self.scheduler.gpu_id
         kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
@@ -205,27 +228,83 @@ class DecodePreallocQueue:
         )
         return kv_manager
 
-    def add(self, req: Req) -> None:
+    def add(self, req: Req, is_retracted: bool = False) -> None:
         """Add a request to the pending queue."""
-        if req.bootstrap_host == FAKE_BOOTSTRAP_HOST:
-            # Fake transfer for warmup reqs
-            kv_receiver_class = get_kv_class(TransferBackend.FAKE, KVClassType.RECEIVER)
-        else:
-            kv_receiver_class = get_kv_class(
-                self.transfer_backend, KVClassType.RECEIVER
-            )
-        kv_receiver = kv_receiver_class(
-            mgr=self.kv_manager,
-            bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
-            bootstrap_room=req.bootstrap_room,
-        )
-        self.queue.append(DecodeRequest(req=req, kv_receiver=kv_receiver))
+        if self._check_if_req_exceed_kv_capacity(req):
+            return
+
+        if is_retracted:
+            self.retracted_queue.append(req)
+        else:
+            if req.bootstrap_host == FAKE_BOOTSTRAP_HOST:
+                kv_receiver_class = get_kv_class(
+                    TransferBackend.FAKE, KVClassType.RECEIVER
+                )
+            else:
+                kv_receiver_class = get_kv_class(
+                    self.transfer_backend, KVClassType.RECEIVER
+                )
+            kv_receiver = kv_receiver_class(
+                mgr=self.kv_manager,
+                bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
+                bootstrap_room=req.bootstrap_room,
+                data_parallel_rank=req.data_parallel_rank,
+            )
+            self.queue.append(
+                DecodeRequest(req=req, kv_receiver=kv_receiver, waiting_for_input=False)
+            )
+
+    def _check_if_req_exceed_kv_capacity(self, req: Req) -> bool:
+        if len(req.origin_input_ids) > self.max_total_num_tokens:
+            message = f"Request {req.rid} exceeds the maximum number of tokens: {len(req.origin_input_ids)} > {self.max_total_num_tokens}"
+            logger.error(message)
+            prepare_abort(req, message)
+            self.scheduler.stream_output([req], req.return_logprob)
+            return True
+        return False
 
-    def extend(self, reqs: List[Req]) -> None:
+    def extend(self, reqs: List[Req], is_retracted: bool = False) -> None:
         """Add a request to the pending queue."""
         for req in reqs:
-            self.add(req)
+            self.add(req, is_retracted=is_retracted)
+
+    def resume_retracted_reqs(self) -> List[Req]:
+        # TODO: refactor the scheduling part and reuse the unified engine logic as much as possible
+        # Allocate memory
+        resumed_reqs = []
+        indices_to_remove = set()
+        allocatable_tokens = self._allocatable_tokens(count_retracted=False)
+
+        for i, req in enumerate(self.retracted_queue):
+            if self.req_to_token_pool.available_size() <= 0:
+                break
+
+            required_tokens_for_request = (
+                len(req.origin_input_ids)
+                + len(req.output_ids)
+                + self.num_reserved_decode_tokens
+            )
+            if required_tokens_for_request > allocatable_tokens:
+                break
+
+            resumed_reqs.append(req)
+            indices_to_remove.add(i)
+            req.is_retracted = False
+            self._pre_alloc(req)
+            allocatable_tokens -= required_tokens_for_request
+            # Load the KV cache back from the CPU and release the CPU copy
+            req.load_kv_cache(self.req_to_token_pool, self.token_to_kv_pool_allocator)
+
+        self.retracted_queue = [
+            entry
+            for i, entry in enumerate(self.retracted_queue)
+            if i not in indices_to_remove
+        ]
+
+        return resumed_reqs
 
     def _update_handshake_waiters(self) -> None:
         if not self.queue:
@@ -255,6 +334,8 @@ class DecodePreallocQueue:
                     error_message,
                     status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
                 )
+            else:
+                raise ValueError(f"Unexpected poll case: {poll}")
 
     def pop_preallocated(self) -> List[DecodeRequest]:
         """Pop the preallocated requests from the pending queue (FIFO)."""
@@ -262,8 +343,16 @@ class DecodePreallocQueue:
         preallocated_reqs = []
         indices_to_remove = set()
-        allocatable_tokens = self._allocatable_tokens()
+        # Make sure the sum of in-flight tokens and allocatable tokens exceeds the
+        # maximum input+output length of each in-flight request; otherwise one running
+        # request could run out of memory during decode while all other requests sit
+        # in the transfer queue and cannot be retracted.
+        retractable_tokens = sum(
+            len(r.origin_input_ids) + len(r.output_ids)
+            for r in self.scheduler.running_batch.reqs
+        )
+        allocatable_tokens = self._allocatable_tokens(
+            retractable_tokens=retractable_tokens, count_retracted=True
+        )
 
         # First, remove all failed requests from the queue
         for i, decode_req in enumerate(self.queue):
             if isinstance(decode_req.req.finished_reason, FINISH_ABORT):
@@ -272,6 +361,7 @@ class DecodePreallocQueue:
                 )
                 indices_to_remove.add(i)
 
+        # Then, preallocate the remaining requests if possible
         for i, decode_req in enumerate(self.queue):
             if i in indices_to_remove:
                 continue
@@ -285,10 +375,23 @@ class DecodePreallocQueue:
             if self.req_to_metadata_buffer_idx_allocator.available_size() <= 0:
                 break
 
+            # Memory estimation: don't add the request if the projected memory cannot be met
+            # TODO: add new_token ratio
+            origin_input_len = len(decode_req.req.origin_input_ids)
             required_tokens_for_request = (
-                len(decode_req.req.origin_input_ids) + self.num_reserved_decode_tokens
+                origin_input_len + self.num_reserved_decode_tokens
             )
+
+            if (
+                max(
+                    required_tokens_for_request,
+                    origin_input_len
+                    + decode_req.req.sampling_params.max_new_tokens
+                    - retractable_tokens,
+                )
+                > allocatable_tokens
+            ):
+                break
 
             if required_tokens_for_request > allocatable_tokens:
                 break
@@ -321,15 +424,35 @@ class DecodePreallocQueue:
         return preallocated_reqs
 
-    def _allocatable_tokens(self) -> int:
-        allocatable_tokens = (
-            self.token_to_kv_pool_allocator.available_size()
-            - self.num_reserved_decode_tokens
+    def _allocatable_tokens(
+        self, retractable_tokens: Optional[int] = None, count_retracted: bool = True
+    ) -> int:
+        need_space_for_single_req = (
+            max(
+                [
+                    x.sampling_params.max_new_tokens
+                    + len(x.origin_input_ids)
+                    - retractable_tokens
+                    for x in self.scheduler.running_batch.reqs
+                ]
+            )
+            if retractable_tokens is not None
+            and len(self.scheduler.running_batch.reqs) > 0
+            else 0
+        )
+
+        available_size = self.token_to_kv_pool_allocator.available_size()
+        allocatable_tokens = available_size - max(
+            # Preserve some space for future decode steps
+            self.num_reserved_decode_tokens
             * (
                 len(self.scheduler.running_batch.reqs)
                 + len(self.transfer_queue.queue)
                 + len(self.scheduler.waiting_queue)
-            )
-        )
+            ),
+            # Make sure each request can finish if it reaches max_tokens with all other requests retracted
+            need_space_for_single_req,
+        )
 
         # Note: if the last fake extend just finishes, and we enter `pop_preallocated` immediately in the next iteration
@@ -342,15 +465,27 @@ class DecodePreallocQueue:
                 self.scheduler.last_batch.reqs
             )
 
+        if count_retracted:
+            allocatable_tokens -= sum(
+                [
+                    len(req.origin_input_ids)
+                    + len(req.output_ids)
+                    + self.num_reserved_decode_tokens
+                    for req in self.retracted_queue
+                ]
+            )
+
         return allocatable_tokens
 
     def _pre_alloc(self, req: Req) -> torch.Tensor:
         """Pre-allocate the memory for req_to_token and token_kv_pool"""
         req_pool_indices = self.req_to_token_pool.alloc(1)
-        assert req_pool_indices is not None
+        assert (
+            req_pool_indices is not None
+        ), "req_pool_indices is full! There is a bug in memory estimation."
+
         req.req_pool_idx = req_pool_indices[0]
         if self.token_to_kv_pool_allocator.page_size == 1:
             kv_loc = self.token_to_kv_pool_allocator.alloc(
                 len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
@@ -375,7 +510,10 @@ class DecodePreallocQueue:
                 ),
                 extend_num_tokens=num_tokens,
             )
-        assert kv_loc is not None
+        assert (
+            kv_loc is not None
+        ), "KV cache is full! There is a bug in memory estimation."
 
         self.req_to_token_pool.write((req.req_pool_idx, slice(0, len(kv_loc))), kv_loc)
@@ -395,6 +533,7 @@ class DecodeTransferQueue:
         self,
         gloo_group: ProcessGroup,
         req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
+        tp_rank: int,
         metadata_buffers: MetadataBuffers,
         scheduler: Scheduler,
         tree_cache: BasePrefixCache,
@@ -402,6 +541,7 @@ class DecodeTransferQueue:
         self.queue: List[DecodeRequest] = []
         self.gloo_group = gloo_group
         self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
+        self.tp_rank = tp_rank
         self.metadata_buffers = metadata_buffers
         self.scheduler = scheduler
         self.tree_cache = tree_cache
@@ -412,10 +552,9 @@ class DecodeTransferQueue:
     def extend(self, decode_reqs: List[DecodeRequest]) -> None:
         self.queue.extend(decode_reqs)
 
-    def pop_transferred(self) -> List[DecodeRequest]:
+    def pop_transferred(self) -> List[Req]:
         if not self.queue:
             return []
         polls = poll_and_all_reduce(
             [decode_req.kv_receiver for decode_req in self.queue], self.gloo_group
         )
@@ -424,7 +563,7 @@ class DecodeTransferQueue:
         indices_to_remove = set()
         for i, (decode_req, poll) in enumerate(zip(self.queue, polls)):
             if poll == KVPoll.Failed:
-                error_message = f"Decode transfer failed for request rank={self.scheduler.tp_rank} {decode_req.req.rid=} {decode_req.req.bootstrap_room=}"
+                error_message = f"Decode transfer failed for request rank={self.tp_rank} {decode_req.req.rid=} {decode_req.req.bootstrap_room=}"
                 try:
                     decode_req.kv_receiver.failure_exception()
                 except Exception as e:
@@ -543,7 +682,8 @@ class SchedulerDisaggregationDecodeMixin:
                 batch, _ = self._prepare_idle_batch_and_run(None)
 
             if batch is None and (
-                len(self.disagg_decode_transfer_queue.queue)
+                len(self.waiting_queue)
+                + len(self.disagg_decode_transfer_queue.queue)
                 + len(self.disagg_decode_prealloc_queue.queue)
                 == 0
             ):
@@ -622,7 +762,8 @@ class SchedulerDisaggregationDecodeMixin:
                     self.process_batch_result(tmp_batch, tmp_result)
 
             if batch is None and (
-                len(self.disagg_decode_transfer_queue.queue)
+                len(self.waiting_queue)
+                + len(self.disagg_decode_transfer_queue.queue)
                + len(self.disagg_decode_prealloc_queue.queue)
                == 0
            ):
@@ -716,6 +857,13 @@ class SchedulerDisaggregationDecodeMixin:
         return new_batch
 
     def process_decode_queue(self: Scheduler):
+        # Try to resume retracted requests if there is enough space for another
+        # `num_reserved_decode_tokens` decode steps.
+        resumed_reqs = self.disagg_decode_prealloc_queue.resume_retracted_reqs()
+        self.waiting_queue.extend(resumed_reqs)
+        if len(self.disagg_decode_prealloc_queue.retracted_queue) > 0:
+            # If there are still retracted requests, do not allocate new requests.
+            return
+
         req_conns = self.disagg_decode_prealloc_queue.pop_preallocated()
         self.disagg_decode_transfer_queue.extend(req_conns)
         alloc_reqs = (
......
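The decode.py changes above hinge on a token-budget check: before a new request is pre-allocated, the free KV slots must cover both the reserved decode headroom for every request that is running, transferring, or waiting, and the worst-case remaining need of a single request once every other running request has been retracted. Below is a minimal, self-contained sketch of that arithmetic; the function name, constant, and numbers are illustrative assumptions, not SGLang internals.

# Rough sketch of the budget behind _allocatable_tokens / pop_preallocated (hypothetical helper).
NUM_RESERVED_DECODE_TOKENS = 512  # per-request decode headroom (cf. SGLANG_NUM_RESERVED_DECODE_TOKENS)

def allocatable_tokens(free_kv_slots, running, num_queued, retractable_tokens):
    """running: list of (input_len, max_new_tokens) pairs for requests currently decoding."""
    # Headroom reserved for every request that is running or queued.
    reserve = NUM_RESERVED_DECODE_TOKENS * (len(running) + num_queued)
    # Worst-case single request, assuming every other running request is retracted.
    worst_single = max(
        (inp + max_new - retractable_tokens for inp, max_new in running),
        default=0,
    )
    return free_kv_slots - max(reserve, worst_single)

# Example: 40k free KV slots, two running requests, three queued requests.
running = [(1_000, 4_096), (2_000, 8_192)]
retractable = sum(inp for inp, _ in running)  # in practice, input plus already generated tokens
print(allocatable_tokens(40_000, running, num_queued=3, retractable_tokens=retractable))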
@@ -25,6 +25,7 @@ from collections import deque
 from http import HTTPStatus
 from typing import TYPE_CHECKING, List, Optional
 
+import numpy as np
 import torch
 
 from sglang.srt.disaggregation.base import BaseKVManager, KVPoll
@@ -575,6 +576,7 @@ class SchedulerDisaggregationPrefillMixin:
                 self.req_to_token_pool.req_to_token[req.req_pool_idx, start_idx:end_idx]
                 .cpu()
                 .numpy()
+                .astype(np.int64)
             )
             req.start_send_idx = end_idx
             if last_chunk:
......
@@ -1415,6 +1415,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             req = self.reqs[idx]
             retracted_reqs.append(req)
 
+            if server_args.disaggregation_mode == "decode":
+                req.offload_kv_cache(
+                    self.req_to_token_pool, self.token_to_kv_pool_allocator
+                )
+
             if isinstance(self.tree_cache, ChunkCache):
                 # ChunkCache does not have eviction
                 token_indices = self.req_to_token_pool.req_to_token[
@@ -1446,6 +1451,12 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             req.reset_for_retract()
 
+        if len(retracted_reqs) == 0:
+            # Corner case: only one request left
+            raise ValueError(
+                "Failed to retract any request. No space left for only one request."
+            )
+
         self.filter_batch(keep_indices=sorted_indices)
 
         # Reqs in batch are filtered
......
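The schedule_batch.py hunk above ties retraction to a CPU offload when running in decode mode: a retracted request keeps a host copy of its KV rows (offload_kv_cache) so its device slots can be reused, and resume_retracted_reqs later restores them (load_kv_cache). The toy below only illustrates that lifecycle; the class and pool here are hypothetical stand-ins, not the real Req or memory-pool types.

import torch

# Toy stand-in for the device KV pool (one row per cached token).
kv_pool = torch.zeros(1024, 8, device="cuda" if torch.cuda.is_available() else "cpu")

class ToyReq:
    def __init__(self, slots):
        self.slots = slots      # KV rows currently owned by this request
        self.cpu_copy = None
        self.is_retracted = False

    def offload_kv_cache(self):
        # Retract: keep a host copy so the scheduler may hand the rows to other requests.
        self.cpu_copy = kv_pool[self.slots].to("cpu")
        self.is_retracted = True

    def load_kv_cache(self, new_slots):
        # Resume: freshly allocated rows are filled from the host copy, which is then dropped.
        self.slots = new_slots
        kv_pool[self.slots] = self.cpu_copy.to(kv_pool.device)
        self.cpu_copy = None
        self.is_retracted = False

req = ToyReq(slots=[3, 4, 5])
req.offload_kv_cache()           # retract
req.load_kv_cache([10, 11, 12])  # resume on different rows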
@@ -628,6 +628,7 @@ class Scheduler(
             self.disagg_decode_transfer_queue = DecodeTransferQueue(
                 gloo_group=self.attn_tp_cpu_group,
                 req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
+                tp_rank=self.tp_rank,
                 metadata_buffers=self.disagg_metadata_buffers,
                 scheduler=self,
                 tree_cache=self.tree_cache,
@@ -650,7 +651,11 @@ class Scheduler(
                 gloo_group=self.attn_tp_cpu_group,
                 tp_rank=self.tp_rank,
                 tp_size=self.tp_size,
+                dp_size=self.server_args.dp_size,
+                gpu_id=self.gpu_id,
                 bootstrap_port=self.server_args.disaggregation_bootstrap_port,
+                max_total_num_tokens=self.max_total_num_tokens,
+                prefill_pp_size=self.server_args.disaggregation_prefill_pp,
                 transfer_backend=self.transfer_backend,
             )
@@ -1124,14 +1129,14 @@ class Scheduler(
         else:
             self.waiting_queue.append(req)
 
-    def _extend_requests_to_queue(self, reqs: List[Req]):
+    def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
             self.disagg_prefill_bootstrap_queue.extend(
                 reqs, self.model_config.num_key_value_heads
             )
         elif self.disaggregation_mode == DisaggregationMode.DECODE:
             # If this is a decode server, we put the request to the decode pending prealloc queue
-            self.disagg_decode_prealloc_queue.extend(reqs)
+            self.disagg_decode_prealloc_queue.extend(reqs, is_retracted)
         else:
             self.waiting_queue.extend(reqs)
@@ -1274,6 +1279,7 @@ class Scheduler(
         if self.disaggregation_mode == DisaggregationMode.DECODE:
             msg += f"pre-allocated usage: {self.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "
+            msg += f"#retracted-req: {len(self.disagg_decode_prealloc_queue.retracted_queue)}, "
 
         msg += (
             f"cuda graph: {can_run_cuda_graph}, "
@@ -1575,7 +1581,7 @@ class Scheduler(
                 f"#retracted_reqs: {len(retracted_reqs)}, "
                 f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
             )
-            self._extend_requests_to_queue(retracted_reqs)
+            self._extend_requests_to_queue(retracted_reqs, is_retracted=True)
         else:
             self.new_token_ratio = max(
                 self.new_token_ratio - self.new_token_ratio_decay,
......
@@ -234,6 +234,12 @@ class TokenToKVPoolAllocator:
         self.is_not_in_free_group = True
         self.free_group = []
 
+    def get_cpu_copy(self, indices):
+        return self._kvcache.get_cpu_copy(indices)
+
+    def load_cpu_copy(self, kv_cache_cpu, indices):
+        return self._kvcache.load_cpu_copy(kv_cache_cpu, indices)
+
 
 class MHATokenToKVPool(KVCache):
@@ -265,6 +271,8 @@ class MHATokenToKVPool(KVCache):
         self.head_dim = head_dim
         self._create_buffers()
 
+        # used for chunked cpu-offloading
+        self.chunk_size = 8192
+
         self.layer_transfer_counter = None
         self.device_module = torch.get_device_module(self.device)
         self.alt_stream = self.device_module.Stream() if _is_cuda else None
@@ -329,6 +337,39 @@ class MHATokenToKVPool(KVCache):
         ]
         return kv_data_ptrs, kv_data_lens, kv_item_lens
 
+    def get_cpu_copy(self, indices):
+        torch.cuda.synchronize()
+        kv_cache_cpu = []
+        for layer_id in range(self.layer_num):
+            kv_cache_cpu.append([])
+            for i in range(0, len(indices), self.chunk_size):
+                chunk_indices = indices[i : i + self.chunk_size]
+                k_cpu = self.k_buffer[layer_id][chunk_indices].to(
+                    "cpu", non_blocking=True
+                )
+                v_cpu = self.v_buffer[layer_id][chunk_indices].to(
+                    "cpu", non_blocking=True
+                )
+                kv_cache_cpu[-1].append([k_cpu, v_cpu])
+        torch.cuda.synchronize()
+        return kv_cache_cpu
+
+    def load_cpu_copy(self, kv_cache_cpu, indices):
+        torch.cuda.synchronize()
+        for layer_id in range(self.layer_num):
+            for i in range(0, len(indices), self.chunk_size):
+                chunk_indices = indices[i : i + self.chunk_size]
+                k_cpu, v_cpu = (
+                    kv_cache_cpu[layer_id][i // self.chunk_size][0],
+                    kv_cache_cpu[layer_id][i // self.chunk_size][1],
+                )
+                assert k_cpu.shape[0] == v_cpu.shape[0] == len(chunk_indices)
+                k_chunk = k_cpu.to(self.k_buffer[0].device, non_blocking=True)
+                v_chunk = v_cpu.to(self.v_buffer[0].device, non_blocking=True)
+                self.k_buffer[layer_id][chunk_indices] = k_chunk
+                self.v_buffer[layer_id][chunk_indices] = v_chunk
+        torch.cuda.synchronize()
+
     # Todo: different memory layout
     def get_flat_data(self, indices):
         # prepare a large chunk of contiguous data for efficient transfer
......
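The get_cpu_copy / load_cpu_copy pair above moves only the selected KV rows, in fixed chunks of chunk_size = 8192 rows per layer, which bounds the size of each temporary buffer and lets the copies overlap via non_blocking=True between the two synchronize calls. A standalone sketch of the same chunked pattern for a single buffer follows; it assumes a CUDA device and is illustrative only, not the SGLang classes.

from typing import List

import torch

CHUNK_SIZE = 8192  # rows copied per chunk, mirroring MHATokenToKVPool.chunk_size

def offload_rows(buf: torch.Tensor, indices: torch.Tensor) -> List[torch.Tensor]:
    """Copy buf[indices] to the CPU in chunks; returns one CPU tensor per chunk."""
    torch.cuda.synchronize()
    chunks = []
    for i in range(0, len(indices), CHUNK_SIZE):
        idx = indices[i : i + CHUNK_SIZE]
        chunks.append(buf[idx].to("cpu", non_blocking=True))
    torch.cuda.synchronize()  # wait for the async copies before the chunks are used
    return chunks

def load_rows(buf: torch.Tensor, chunks: List[torch.Tensor], indices: torch.Tensor) -> None:
    """Scatter the CPU chunks back into buf at the same indices."""
    torch.cuda.synchronize()
    for i in range(0, len(indices), CHUNK_SIZE):
        idx = indices[i : i + CHUNK_SIZE]
        buf[idx] = chunks[i // CHUNK_SIZE].to(buf.device, non_blocking=True)
    torch.cuda.synchronize()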
@@ -469,5 +469,132 @@ class TestDisaggregationMooncakeSpec(CustomTestCase):
         self.assertGreater(metrics["accuracy"], 0.20)
 
 
+class TestDisaggregationSimulatedRetract(CustomTestCase):
+    @classmethod
+    def setUpClass(cls):
+        os.environ["SGLANG_TEST_RETRACT"] = "true"
+        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+        parsed_url = urlparse(DEFAULT_URL_FOR_TEST)
+        cls.base_host = parsed_url.hostname
+        base_port = str(parsed_url.port)
+        cls.lb_port = base_port
+        cls.prefill_port = f"{int(base_port) + 100}"
+        cls.decode_port = f"{int(base_port) + 200}"
+        cls.prefill_url = f"http://{cls.base_host}:{cls.prefill_port}"
+        cls.decode_url = f"http://{cls.base_host}:{cls.decode_port}"
+        cls.lb_url = f"http://{cls.base_host}:{cls.lb_port}"
+        print(f"{cls.base_host=} {cls.lb_port=} {cls.prefill_port=} {cls.decode_port=}")
+
+        # Start the prefill and decode servers without blocking
+        cls.start_prefill()
+        cls.start_decode()
+
+        # Block until both servers are ready
+        cls.wait_server_ready(cls.prefill_url + "/health")
+        cls.wait_server_ready(cls.decode_url + "/health")
+
+        lb_command = [
+            "python3",
+            "-m",
+            "sglang.srt.disaggregation.mini_lb",
+            "--prefill",
+            cls.prefill_url,
+            "--decode",
+            cls.decode_url,
+            "--host",
+            cls.base_host,
+            "--port",
+            cls.lb_port,
+        ]
+        print("Starting load balancer:", " ".join(lb_command))
+        cls.process_lb = subprocess.Popen(
+            lb_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE
+        )
+        cls.wait_server_ready(cls.lb_url + "/health")
+
+    @classmethod
+    def start_prefill(cls):
+        prefill_args = [
+            "--trust-remote-code",
+            "--disaggregation-mode",
+            "prefill",
+            "--tp",
+            "1",
+            "--disaggregation-ib-device",
+            "mlx5_roce0",
+        ]
+        cls.process_prefill = popen_launch_pd_server(
+            cls.model,
+            cls.prefill_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=prefill_args,
+        )
+
+    @classmethod
+    def start_decode(cls):
+        decode_args = [
+            "--trust-remote-code",
+            "--disaggregation-mode",
+            "decode",
+            "--tp",
+            "1",
+            "--base-gpu-id",
+            "1",
+            "--disaggregation-ib-device",
+            "mlx5_roce1",
+        ]
+        cls.process_decode = popen_launch_pd_server(
+            cls.model,
+            cls.decode_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=decode_args,
+        )
+
+    @classmethod
+    def wait_server_ready(cls, url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH):
+        start_time = time.perf_counter()
+        while True:
+            try:
+                response = requests.get(url)
+                if response.status_code == 200:
+                    print(f"Server {url} is ready")
+                    return
+            except Exception:
+                pass
+
+            if time.perf_counter() - start_time > timeout:
+                raise RuntimeError(f"Server {url} failed to start in {timeout}s")
+            time.sleep(1)
+
+    @classmethod
+    def tearDownClass(cls):
+        os.environ.pop("SGLANG_TEST_RETRACT")
+        for process in [cls.process_lb, cls.process_decode, cls.process_prefill]:
+            if process:
+                try:
+                    kill_process_tree(process.pid)
+                except Exception as e:
+                    print(f"Error killing process {process.pid}: {e}")
+
+        # Wait for 5 seconds so the processes can exit
+        time.sleep(5)
+
+    def test_gsm8k(self):
+        args = SimpleNamespace(
+            num_shots=5,
+            data_path=None,
+            num_questions=200,
+            max_new_tokens=512,
+            parallel=128,
+            host=f"http://{self.base_host}",
+            port=int(self.lb_port),
+        )
+        metrics = run_eval_few_shot_gsm8k(args)
+        print(f"Evaluation metrics: {metrics}")
+
+        self.assertGreater(metrics["accuracy"], 0.62)
+
 
 if __name__ == "__main__":
     unittest.main()