Unverified Commit b121bc03 authored by Lianmin Zheng, committed by GitHub

Simplify batch result resolution (#1735)

parent e12358dc
@@ -29,8 +29,8 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
 It contains low-level tensor data. Most of the data consists of GPU tensors.
 """
 
+import dataclasses
 import logging
-from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union
 
 import torch
@@ -116,7 +116,7 @@ class FINISH_ABORT(BaseFinishReason):
         }
 
 
-@dataclass
+@dataclasses.dataclass
 class ImageInputs:
     """The image related inputs."""
@@ -407,7 +407,7 @@ class Req:
 bid = 0
 
 
-@dataclass
+@dataclasses.dataclass
 class ScheduleBatch:
     """Store all information of a batch."""
@@ -902,7 +902,7 @@ class ScheduleBatch:
         )
 
 
-@dataclass
+@dataclasses.dataclass
 class ModelWorkerBatch:
     # The batch id
     bid: int
@@ -942,24 +942,7 @@ class ModelWorkerBatch:
     mrope_positions_delta: List[List[int]]
 
     def copy(self):
-        return ModelWorkerBatch(
-            bid=self.bid,
-            forward_mode=self.forward_mode,
-            input_ids=self.input_ids,
-            req_pool_indices=self.req_pool_indices,
-            seq_lens=self.seq_lens,
-            out_cache_loc=self.out_cache_loc,
-            req_to_token_pool_records=self.req_to_token_pool_records,
-            return_logprob=self.return_logprob,
-            top_logprobs_nums=self.top_logprobs_nums,
-            extend_seq_lens=self.extend_seq_lens,
-            extend_prefix_lens=self.extend_prefix_lens,
-            extend_logprob_start_lens=self.extend_logprob_start_lens,
-            image_inputs=self.image_inputs,
-            lora_paths=self.lora_paths,
-            sampling_info=self.sampling_info.copy(),
-            mrope_positions_delta=self.mrope_positions_delta,
-        )
+        return dataclasses.replace(self, sampling_info=self.sampling_info.copy())
 
     def to(self, device: str):
         self.input_ids = self.input_ids.to(device, non_blocking=True)
......
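Note on the new one-line copy() above: dataclasses.replace builds a new instance that reuses every field of the original except the ones passed as keyword overrides, so the tensors stay shared and only sampling_info gets an independent copy. A minimal sketch of that behavior with a toy dataclass (not the real ModelWorkerBatch):

```python
import dataclasses
from typing import List


@dataclasses.dataclass
class SamplingInfo:
    temperatures: List[float]

    def copy(self):
        # Deep enough for this toy example.
        return SamplingInfo(list(self.temperatures))


@dataclasses.dataclass
class Batch:
    bid: int
    input_ids: List[int]
    sampling_info: SamplingInfo

    def copy(self):
        # Shallow copy: every other field is reused as-is.
        return dataclasses.replace(self, sampling_info=self.sampling_info.copy())


b1 = Batch(bid=0, input_ids=[1, 2, 3], sampling_info=SamplingInfo([0.7]))
b2 = b1.copy()
assert b2.input_ids is b1.input_ids              # shared reference
assert b2.sampling_info is not b1.sampling_info  # independent copy
```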
@@ -149,12 +149,8 @@ class Scheduler:
         # Launch a tensor parallel worker
         if self.enable_overlap:
             TpWorkerClass = TpModelWorkerClient
-            self.resolve_next_token_ids = (
-                lambda bid, x: self.tp_worker.resolve_future_token_ids(bid)
-            )
         else:
             TpWorkerClass = TpModelWorker
-            self.resolve_next_token_ids = lambda bid, x: x.tolist()
 
         self.tp_worker = TpWorkerClass(
             server_args=server_args,
@@ -756,9 +752,12 @@ class Scheduler:
     def process_batch_result_prefill(self, batch: ScheduleBatch, result):
         if self.is_generation:
             logits_output, next_token_ids, bid = result
-            if batch.return_logprob:
-                # Move logprobs to cpu
-                if logits_output.next_token_logprobs is not None:
+
+            if self.enable_overlap:
+                logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
+            else:
+                # Move next_token_ids and logprobs to cpu
+                if batch.return_logprob:
                     logits_output.next_token_logprobs = (
                         logits_output.next_token_logprobs[
                             torch.arange(len(next_token_ids), device=self.device),
@@ -771,8 +770,7 @@ class Scheduler:
                     logits_output.normalized_prompt_logprobs = (
                         logits_output.normalized_prompt_logprobs.tolist()
                     )
-
-            next_token_ids = self.resolve_next_token_ids(bid, next_token_ids)
+                next_token_ids = next_token_ids.tolist()
 
             # Check finish conditions
             logprob_pt = 0
@@ -825,14 +823,16 @@ class Scheduler:
         logits_output, next_token_ids, bid = result
         self.num_generated_tokens += len(batch.reqs)
 
-        # Move logprobs to cpu
-        if batch.return_logprob:
-            next_token_logprobs = logits_output.next_token_logprobs[
-                torch.arange(len(next_token_ids), device=self.device),
-                next_token_ids,
-            ].tolist()
-
-        next_token_ids = self.resolve_next_token_ids(bid, next_token_ids)
+        if self.enable_overlap:
+            logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
+        else:
+            # Move next_token_ids and logprobs to cpu
+            if batch.return_logprob:
+                next_token_logprobs = logits_output.next_token_logprobs[
+                    torch.arange(len(next_token_ids), device=self.device),
+                    next_token_ids,
+                ].tolist()
+            next_token_ids = next_token_ids.tolist()
 
         self.token_to_kv_pool.free_group_begin()
......
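For orientation, the two result-resolution paths the scheduler now takes can be summarized in a short, hypothetical sketch (method name and surrounding state are simplified; only the branch structure is taken from the diff above):

```python
# Simplified sketch of the new result-resolution branch; not the actual
# Scheduler methods, which also handle logprobs and finish conditions.
def resolve_result(self, result):
    logits_output, next_token_ids, bid = result
    if self.enable_overlap:
        # Overlap mode: next_token_ids are GPU placeholders, so block on the
        # worker's output queue for the real (already CPU) token ids.
        logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
    else:
        # Synchronous mode: the sampled ids are real; just move them to CPU.
        next_token_ids = next_token_ids.tolist()
    return logits_output, next_token_ids
```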
@@ -48,19 +48,16 @@ class TpModelWorkerClient:
         self.max_running_requests = self.worker.max_running_requests
         self.device = self.worker.device
 
-        # Create future mappings
-        self.future_logits_output_dict = dict()
-        self.future_logits_output_ct = 0
+        # Init future mappings
         self.future_token_ids_ct = 0
+        self.future_token_ids_limit = self.max_running_requests * 3
         self.future_token_ids_map = torch.empty(
             (self.max_running_requests * 5,), dtype=torch.int32, device=self.device
         )
-        self.future_token_ids_limit = self.max_running_requests * 3
-        self.future_token_ids_output = dict()
 
         # Launch a thread
-        self.future_event_map = dict()
-        self.forward_queue = Queue()
+        self.input_queue = Queue()
+        self.output_queue = Queue()
         self.forward_stream = torch.cuda.Stream()
         self.forward_thread = threading.Thread(
             target=self.forward_thread_func,
@@ -90,9 +87,7 @@ class TpModelWorkerClient:
     def forward_thread_func_(self):
         while True:
             tic1 = time.time()
-            model_worker_batch, future_logits_output, future_next_token_ids = (
-                self.forward_queue.get()
-            )
+            model_worker_batch, future_token_ids_ct = self.input_queue.get()
 
             # Resolve future tokens in the input
             tic2 = time.time()
@@ -107,17 +102,22 @@ class TpModelWorkerClient:
                 model_worker_batch
             )
 
-            # Set future values
-            if model_worker_batch.return_logprob:
-                self.future_logits_output_dict[future_logits_output] = logits_output
-
+            # Update the future token ids map
+            bs = len(model_worker_batch.seq_lens)
+            future_next_token_ids = torch.arange(
+                -(future_token_ids_ct + bs),
+                -(future_token_ids_ct),
+                dtype=torch.int32,
+                device=self.device,
+            )
             self.future_token_ids_map[-future_next_token_ids] = next_token_ids.to(
                 torch.int32
             )
-            self.future_token_ids_output[model_worker_batch.bid] = (
-                next_token_ids.tolist()
-            )
-            self.future_event_map[model_worker_batch.bid].set()
+
+            # Set the result
+            next_token_ids = next_token_ids.tolist()
+            assert logits_output.next_token_logprobs is None, "Not supported"
+            self.output_queue.put((None, next_token_ids))
 
             if False:
                 tic3 = time.time()
@@ -128,38 +128,26 @@ class TpModelWorkerClient:
                     f"{self.acc_time_with_waiting=:.3f}, {self.acc_time_without_waiting=:.3f}, {self.forward_queue.qsize()=}"
                 )
 
-    def resolve_future_token_ids(self, bid: int):
-        self.future_event_map[bid].wait()
-        ret = self.future_token_ids_output[bid]
-        del self.future_event_map[bid]
-        return ret
-
-    def resolve_future_logits_output(self, future_obj):
-        return self.future_logits_output_dict.pop(future_obj)
+    def resolve_batch_result(self, bid: int):
+        logits_output, next_token_ids = self.output_queue.get()
+        return logits_output, next_token_ids
 
     def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
-        # Allocate output future objects
-        future_logits_output = self.future_logits_output_ct
-        self.future_logits_output_ct += 1
+        # Push a new batch to the queue
+        self.input_queue.put((model_worker_batch.copy(), self.future_token_ids_ct))
 
+        # Allocate output future objects
         bs = len(model_worker_batch.seq_lens)
-        with torch.cuda.stream(self.forward_stream):
-            future_next_token_ids = -torch.arange(
-                self.future_token_ids_ct + 1,
-                self.future_token_ids_ct + 1 + bs,
-                dtype=torch.int32,
-                device=self.device,
-            )
+        future_next_token_ids = torch.arange(
+            -(self.future_token_ids_ct + bs),
+            -(self.future_token_ids_ct),
+            dtype=torch.int32,
+            device=self.device,
+        )
         self.future_token_ids_ct = (
             self.future_token_ids_ct + bs
         ) % self.future_token_ids_limit
-        ret = future_logits_output, future_next_token_ids
-
-        self.future_event_map[model_worker_batch.bid] = threading.Event()
-        self.forward_queue.put(
-            (model_worker_batch.copy(), future_logits_output, future_next_token_ids)
-        )
-        return ret
+        return None, future_next_token_ids
 
     def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
......
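The core of the overlap scheme above is the placeholder trick: forward_batch_generation immediately returns negative "future" token ids, the background thread later writes the real sampled ids into future_token_ids_map, and a placeholder -k is resolved by reading slot k of the map. A self-contained toy sketch of that bookkeeping (CPU tensors and int64 dtype for simplicity; the real code keeps an int32 map on the GPU and resolves placeholders inside the next batch's input ids):

```python
import torch

# Toy version of the future-token-id bookkeeping in TpModelWorkerClient.
future_token_ids_map = torch.zeros((64,), dtype=torch.int64)
future_token_ids_ct = 0

# forward_batch_generation: hand out negative placeholders right away,
# one per sequence in the batch, without waiting for the forward pass.
bs = 4
future_next_token_ids = torch.arange(
    -(future_token_ids_ct + bs), -(future_token_ids_ct), dtype=torch.int64
)  # tensor([-4, -3, -2, -1])
future_token_ids_ct += bs

# Background forward thread, once sampling finishes: publish the real ids
# into the slots named by the placeholders.
real_next_token_ids = torch.tensor([11, 22, 33, 44], dtype=torch.int64)
future_token_ids_map[-future_next_token_ids] = real_next_token_ids

# A later batch whose input ids are still placeholders resolves them by lookup.
resolved = future_token_ids_map[-future_next_token_ids]
assert resolved.tolist() == [11, 22, 33, 44]
```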
@@ -120,7 +120,7 @@ class ModelRunner:
         )
 
         if self.is_multimodal_model:
-            logger.info(
+            logger.warning(
                 "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
             )
             server_args.chunked_prefill_size = None
@@ -131,13 +131,6 @@ class ModelRunner:
         ]:
             server_args.disable_cuda_graph = True
 
-        if self.server_args.enable_overlap_schedule:
-            logger.warning(
-                "Overlap scheduler is enabled. This is an experimental feature. "
-                "Sampling penalizer (e.g., frequency and repetition penalty), constrained decoding (e.g., regex, JSON), "
-                "and embedding APIs are not supported and will lead to wrong results."
-            )
-
         # Global vars
         if server_args.show_time_cost:
             enable_show_time_cost()
......
@@ -177,6 +177,16 @@ class ServerArgs:
         if self.sampling_backend is None:
             self.sampling_backend = "flashinfer"
 
+        if self.enable_overlap_schedule:
+            logger.warning(
+                "Overlap scheduler mode is enabled. This is an experimental feature. "
+                "Sampling penalizer (e.g., frequency and repetition penalty), constrained decoding (e.g., regex, JSON), "
+                "and embedding APIs are not supported and will lead to wrong results. "
+                "The NaN detection is also disabled."
+            )
+            self.disable_penalizer = True
+            self.disable_nan_detection = True
+
         # Model-specific patches
         if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path:
             logger.info(
......