Unverified commit f5a2faf2, authored by Liangsheng Yin and committed by GitHub

Introduce `FutureMap` (#10715)

parent 1c82d9db
New module sglang.srt.managers.overlap_utils:

import torch

from sglang.srt.managers.schedule_batch import ModelWorkerBatch
from sglang.srt.utils import get_compiler_backend


@torch.compile(dynamic=True, backend=get_compiler_backend())
def _resolve_future_token_ids(input_ids, future_token_ids_map):
    # Negative ids are placeholders for tokens that have not been sampled yet:
    # -k refers to slot k of the future map. Non-negative ids are real tokens
    # and are passed through unchanged.
    input_ids[:] = torch.where(
        input_ids < 0,
        future_token_ids_map[torch.clamp(-input_ids, min=0)],
        input_ids,
    )


class FutureMap:
    def __init__(
        self,
        max_running_requests: int,
        device: torch.device,
    ):
        self.future_ct = 0
        # A factor of 3 is used to avoid collision in the circular buffer.
        self.future_limit = max_running_requests * 3
        # A factor of 5 is used to ensure the buffer is large enough.
        self.future_buffer_len = max_running_requests * 5
        self.device = device

        self.token_ids_buf = torch.empty(
            (self.future_buffer_len,), dtype=torch.int64, device=self.device
        )

    def update_ct(self, bs: int) -> int:
        """Update the circular buffer pointer and return the current pointer."""
        cur_future_ct = self.future_ct
        self.future_ct = (cur_future_ct + bs) % self.future_limit
        return cur_future_ct

    def resolve_future(self, model_worker_batch: ModelWorkerBatch):
        # Replace placeholder (negative) input ids with the real token ids
        # already stored in the buffer.
        input_ids = model_worker_batch.input_ids
        _resolve_future_token_ids(input_ids, self.token_ids_buf)

    def update_next_future(self, future_ct: int, bs: int):
        # Hand out bs placeholder ids, -(future_ct + 1) .. -(future_ct + bs),
        # which stand in for this batch's not-yet-sampled next tokens.
        return torch.arange(
            -(future_ct + 1),
            -(future_ct + 1 + bs),
            -1,
            dtype=torch.int64,
            device=self.device,
        )

    def store_to_map(self, future_ct: int, bs: int, next_token_ids: torch.Tensor):
        # Write the sampled token ids into slots future_ct + 1 .. future_ct + bs
        # so that any pending placeholders can later be resolved.
        self.token_ids_buf[future_ct + 1 : future_ct + bs + 1] = next_token_ids
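For orientation, here is a minimal standalone sketch of the placeholder lifecycle this class implements. It is not part of the commit: the batch size, device choice, and the `sampled` / `placeholder_ids` / `batch_input_ids` names are illustrative stand-ins, and the private helper `_resolve_future_token_ids` is called directly instead of `resolve_future` so the sketch does not need a ModelWorkerBatch.

import torch

from sglang.srt.managers.overlap_utils import FutureMap, _resolve_future_token_ids

device = torch.device("cuda")  # assumption: a CUDA device; "cpu" also works for this sketch
future_map = FutureMap(max_running_requests=256, device=device)

bs = 8
# Scheduler side: reserve bs slots and obtain negative placeholder ids
# -(ct + 1) .. -(ct + bs) that a later batch can use as input_ids.
ct = future_map.update_ct(bs)
placeholder_ids = future_map.update_next_future(ct, bs)

# Worker side, after sampling: write the real token ids into the reserved slots.
sampled = torch.randint(0, 32000, (bs,), dtype=torch.int64, device=device)
future_map.store_to_map(ct, bs, sampled)

# When a following batch arrives with placeholders in its input, resolving
# them in place recovers the sampled tokens.
batch_input_ids = placeholder_ids.clone()
_resolve_future_token_ids(batch_input_ids, future_map.token_ids_buf)
assert torch.equal(batch_input_ids, sampled)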
Changes to sglang/srt/managers/tp_worker_overlap_thread.py:

@@ -36,10 +36,11 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightsFromDistributedReqInput,
     UpdateWeightsFromTensorReqInput,
 )
+from sglang.srt.managers.overlap_utils import FutureMap
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import DynamicGradMode, get_compiler_backend
+from sglang.srt.utils import DynamicGradMode
 from sglang.utils import get_exception_traceback
 
 if TYPE_CHECKING:
@@ -48,15 +49,6 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
-@torch.compile(dynamic=True, backend=get_compiler_backend())
-def resolve_future_token_ids(input_ids, future_token_ids_map):
-    input_ids[:] = torch.where(
-        input_ids < 0,
-        future_token_ids_map[torch.clamp(-input_ids, min=0)],
-        input_ids,
-    )
-
-
 class TpModelWorkerClient:
     """A tensor parallel model worker."""
@@ -79,11 +71,7 @@ class TpModelWorkerClient:
         self.gpu_id = gpu_id
 
         # Init future mappings
-        self.future_token_ids_ct = 0
-        self.future_token_ids_limit = self.max_running_requests * 3
-        self.future_token_ids_map = torch.empty(
-            (self.max_running_requests * 5,), dtype=torch.int64, device=self.device
-        )
+        self.future_map = FutureMap(self.max_running_requests, self.device)
 
         # Launch threads
         self.input_queue = Queue[Tuple[ModelWorkerBatch, int, torch.Event]]()
@@ -153,7 +141,7 @@ class TpModelWorkerClient:
         batch_lists: List = [None] * 2
 
         while True:
-            model_worker_batch, future_token_ids_ct, sync_event = self.input_queue.get()
+            model_worker_batch, future_map_ct, sync_event = self.input_queue.get()
             if not model_worker_batch:
                 break
@@ -169,8 +157,7 @@ class TpModelWorkerClient:
             copy_done = torch.get_device_module(self.device).Event()
 
             # Resolve future tokens in the input
-            input_ids = model_worker_batch.input_ids
-            resolve_future_token_ids(input_ids, self.future_token_ids_map)
+            self.future_map.resolve_future(model_worker_batch)
 
             # Run forward
             logits_output, next_token_ids, can_run_cuda_graph = (
@@ -187,9 +174,9 @@ class TpModelWorkerClient:
             if model_worker_batch.is_prefill_only:
                 # For prefill-only requests, create dummy token IDs on CPU
                 next_token_ids = torch.zeros(bs, dtype=torch.long)
-            self.future_token_ids_map[
-                future_token_ids_ct + 1 : future_token_ids_ct + bs + 1
-            ] = next_token_ids
+
+            # store the future indices into future map
+            self.future_map.store_to_map(future_map_ct, bs, next_token_ids)
 
             # Copy results to the CPU
             if model_worker_batch.return_logprob:
@@ -255,20 +242,14 @@ class TpModelWorkerClient:
         sync_event.record(self.scheduler_stream)
 
         # Push a new batch to the queue
-        self.input_queue.put((model_worker_batch, self.future_token_ids_ct, sync_event))
-
-        # Allocate output future objects
         bs = len(model_worker_batch.seq_lens)
-        future_next_token_ids = torch.arange(
-            -(self.future_token_ids_ct + 1),
-            -(self.future_token_ids_ct + 1 + bs),
-            -1,
-            dtype=torch.int64,
-            device=self.device,
-        )
-        self.future_token_ids_ct = (
-            self.future_token_ids_ct + bs
-        ) % self.future_token_ids_limit
+        cur_future_map_ct = self.future_map.update_ct(bs)
+        self.input_queue.put((model_worker_batch, cur_future_map_ct, sync_event))
+
+        # get this forward batch's future token ids
+        future_next_token_ids = self.future_map.update_next_future(
+            cur_future_map_ct, bs
+        )
+
         return None, future_next_token_ids, False
 
     def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):