Unverified Commit f5a2faf2 authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

Introduce `FutureMap` (#10715)

parent 1c82d9db
import torch
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
from sglang.srt.utils import get_compiler_backend
@torch.compile(dynamic=True, backend=get_compiler_backend())
def _resolve_future_token_ids(input_ids, future_token_ids_map):
input_ids[:] = torch.where(
input_ids < 0,
future_token_ids_map[torch.clamp(-input_ids, min=0)],
input_ids,
)
class FutureMap:
def __init__(
self,
max_running_requests: int,
device: torch.device,
):
self.future_ct = 0
# A factor of 3 is used to avoid collision in the circular buffer.
self.future_limit = max_running_requests * 3
# A factor of 5 is used to ensure the buffer is large enough.
self.future_buffer_len = max_running_requests * 5
self.device = device
self.token_ids_buf = torch.empty(
(self.future_buffer_len,), dtype=torch.int64, device=self.device
)
def update_ct(self, bs: int) -> int:
"""Update the circular buffer pointer and return the current pointer."""
cur_future_ct = self.future_ct
self.future_ct = (cur_future_ct + bs) % self.future_limit
return cur_future_ct
def resolve_future(self, model_worker_batch: ModelWorkerBatch):
input_ids = model_worker_batch.input_ids
_resolve_future_token_ids(input_ids, self.token_ids_buf)
def update_next_future(self, future_ct: int, bs: int):
return torch.arange(
-(future_ct + 1),
-(future_ct + 1 + bs),
-1,
dtype=torch.int64,
device=self.device,
)
def store_to_map(self, future_ct: int, bs: int, next_token_ids: torch.Tensor):
self.token_ids_buf[future_ct + 1 : future_ct + bs + 1] = next_token_ids
......@@ -36,10 +36,11 @@ from sglang.srt.managers.io_struct import (
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromTensorReqInput,
)
from sglang.srt.managers.overlap_utils import FutureMap
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import DynamicGradMode, get_compiler_backend
from sglang.srt.utils import DynamicGradMode
from sglang.utils import get_exception_traceback
if TYPE_CHECKING:
......@@ -48,15 +49,6 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
@torch.compile(dynamic=True, backend=get_compiler_backend())
def resolve_future_token_ids(input_ids, future_token_ids_map):
input_ids[:] = torch.where(
input_ids < 0,
future_token_ids_map[torch.clamp(-input_ids, min=0)],
input_ids,
)
class TpModelWorkerClient:
"""A tensor parallel model worker."""
......@@ -79,11 +71,7 @@ class TpModelWorkerClient:
self.gpu_id = gpu_id
# Init future mappings
self.future_token_ids_ct = 0
self.future_token_ids_limit = self.max_running_requests * 3
self.future_token_ids_map = torch.empty(
(self.max_running_requests * 5,), dtype=torch.int64, device=self.device
)
self.future_map = FutureMap(self.max_running_requests, self.device)
# Launch threads
self.input_queue = Queue[Tuple[ModelWorkerBatch, int, torch.Event]]()
......@@ -153,7 +141,7 @@ class TpModelWorkerClient:
batch_lists: List = [None] * 2
while True:
model_worker_batch, future_token_ids_ct, sync_event = self.input_queue.get()
model_worker_batch, future_map_ct, sync_event = self.input_queue.get()
if not model_worker_batch:
break
......@@ -169,8 +157,7 @@ class TpModelWorkerClient:
copy_done = torch.get_device_module(self.device).Event()
# Resolve future tokens in the input
input_ids = model_worker_batch.input_ids
resolve_future_token_ids(input_ids, self.future_token_ids_map)
self.future_map.resolve_future(model_worker_batch)
# Run forward
logits_output, next_token_ids, can_run_cuda_graph = (
......@@ -187,9 +174,9 @@ class TpModelWorkerClient:
if model_worker_batch.is_prefill_only:
# For prefill-only requests, create dummy token IDs on CPU
next_token_ids = torch.zeros(bs, dtype=torch.long)
self.future_token_ids_map[
future_token_ids_ct + 1 : future_token_ids_ct + bs + 1
] = next_token_ids
# store the future indices into future map
self.future_map.store_to_map(future_map_ct, bs, next_token_ids)
# Copy results to the CPU
if model_worker_batch.return_logprob:
......@@ -255,20 +242,14 @@ class TpModelWorkerClient:
sync_event.record(self.scheduler_stream)
# Push a new batch to the queue
self.input_queue.put((model_worker_batch, self.future_token_ids_ct, sync_event))
# Allocate output future objects
bs = len(model_worker_batch.seq_lens)
future_next_token_ids = torch.arange(
-(self.future_token_ids_ct + 1),
-(self.future_token_ids_ct + 1 + bs),
-1,
dtype=torch.int64,
device=self.device,
cur_future_map_ct = self.future_map.update_ct(bs)
self.input_queue.put((model_worker_batch, cur_future_map_ct, sync_event))
# get this forward batch's future token ids
future_next_token_ids = self.future_map.update_next_future(
cur_future_map_ct, bs
)
self.future_token_ids_ct = (
self.future_token_ids_ct + bs
) % self.future_token_ids_limit
return None, future_next_token_ids, False
def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment