Unverified Commit b121bc03 authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Simplify batch result resolution (#1735)

parent e12358dc
......@@ -29,8 +29,8 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
It contains low-level tensor data. Most of the data consists of GPU tensors.
"""
import dataclasses
import logging
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import torch
......@@ -116,7 +116,7 @@ class FINISH_ABORT(BaseFinishReason):
}
@dataclass
@dataclasses.dataclass
class ImageInputs:
"""The image related inputs."""
......@@ -407,7 +407,7 @@ class Req:
bid = 0
@dataclass
@dataclasses.dataclass
class ScheduleBatch:
"""Store all inforamtion of a batch."""
......@@ -902,7 +902,7 @@ class ScheduleBatch:
)
@dataclass
@dataclasses.dataclass
class ModelWorkerBatch:
# The batch id
bid: int
......@@ -942,24 +942,7 @@ class ModelWorkerBatch:
mrope_positions_delta: List[List[int]]
def copy(self):
return ModelWorkerBatch(
bid=self.bid,
forward_mode=self.forward_mode,
input_ids=self.input_ids,
req_pool_indices=self.req_pool_indices,
seq_lens=self.seq_lens,
out_cache_loc=self.out_cache_loc,
req_to_token_pool_records=self.req_to_token_pool_records,
return_logprob=self.return_logprob,
top_logprobs_nums=self.top_logprobs_nums,
extend_seq_lens=self.extend_seq_lens,
extend_prefix_lens=self.extend_prefix_lens,
extend_logprob_start_lens=self.extend_logprob_start_lens,
image_inputs=self.image_inputs,
lora_paths=self.lora_paths,
sampling_info=self.sampling_info.copy(),
mrope_positions_delta=self.mrope_positions_delta,
)
return dataclasses.replace(self, sampling_info=self.sampling_info.copy())
def to(self, device: str):
self.input_ids = self.input_ids.to(device, non_blocking=True)
......
......@@ -149,12 +149,8 @@ class Scheduler:
# Launch a tensor parallel worker
if self.enable_overlap:
TpWorkerClass = TpModelWorkerClient
self.resolve_next_token_ids = (
lambda bid, x: self.tp_worker.resolve_future_token_ids(bid)
)
else:
TpWorkerClass = TpModelWorker
self.resolve_next_token_ids = lambda bid, x: x.tolist()
self.tp_worker = TpWorkerClass(
server_args=server_args,
......@@ -756,9 +752,12 @@ class Scheduler:
def process_batch_result_prefill(self, batch: ScheduleBatch, result):
if self.is_generation:
logits_output, next_token_ids, bid = result
if batch.return_logprob:
# Move logprobs to cpu
if logits_output.next_token_logprobs is not None:
if self.enable_overlap:
logits_output, next_token_ids = self.tp_worker.resulve_batch_result(bid)
else:
# Move next_token_ids and logprobs to cpu
if batch.return_logprob:
logits_output.next_token_logprobs = (
logits_output.next_token_logprobs[
torch.arange(len(next_token_ids), device=self.device),
......@@ -771,8 +770,7 @@ class Scheduler:
logits_output.normalized_prompt_logprobs = (
logits_output.normalized_prompt_logprobs.tolist()
)
next_token_ids = self.resolve_next_token_ids(bid, next_token_ids)
next_token_ids = next_token_ids.tolist()
# Check finish conditions
logprob_pt = 0
......@@ -825,14 +823,16 @@ class Scheduler:
logits_output, next_token_ids, bid = result
self.num_generated_tokens += len(batch.reqs)
# Move logprobs to cpu
if batch.return_logprob:
next_token_logprobs = logits_output.next_token_logprobs[
torch.arange(len(next_token_ids), device=self.device),
next_token_ids,
].tolist()
next_token_ids = self.resolve_next_token_ids(bid, next_token_ids)
if self.enable_overlap:
logits_output, next_token_ids = self.tp_worker.resulve_batch_result(bid)
else:
# Move next_token_ids and logprobs to cpu
if batch.return_logprob:
next_token_logprobs = logits_output.next_token_logprobs[
torch.arange(len(next_token_ids), device=self.device),
next_token_ids,
].tolist()
next_token_ids = next_token_ids.tolist()
self.token_to_kv_pool.free_group_begin()
......
......@@ -48,19 +48,16 @@ class TpModelWorkerClient:
self.max_running_requests = self.worker.max_running_requests
self.device = self.worker.device
# Create future mappings
self.future_logits_output_dict = dict()
self.future_logits_output_ct = 0
# Init future mappings
self.future_token_ids_ct = 0
self.future_token_ids_limit = self.max_running_requests * 3
self.future_token_ids_map = torch.empty(
(self.max_running_requests * 5,), dtype=torch.int32, device=self.device
)
self.future_token_ids_limit = self.max_running_requests * 3
self.future_token_ids_output = dict()
# Launch a thread
self.future_event_map = dict()
self.forward_queue = Queue()
self.input_queue = Queue()
self.output_queue = Queue()
self.forward_stream = torch.cuda.Stream()
self.forward_thread = threading.Thread(
target=self.forward_thread_func,
......@@ -90,9 +87,7 @@ class TpModelWorkerClient:
def forward_thread_func_(self):
while True:
tic1 = time.time()
model_worker_batch, future_logits_output, future_next_token_ids = (
self.forward_queue.get()
)
model_worker_batch, future_token_ids_ct = self.input_queue.get()
# Resolve future tokens in the input
tic2 = time.time()
......@@ -107,17 +102,22 @@ class TpModelWorkerClient:
model_worker_batch
)
# Set future values
if model_worker_batch.return_logprob:
self.future_logits_output_dict[future_logits_output] = logits_output
# Update the future token ids map
bs = len(model_worker_batch.seq_lens)
future_next_token_ids = torch.arange(
-(future_token_ids_ct + bs),
-(future_token_ids_ct),
dtype=torch.int32,
device=self.device,
)
self.future_token_ids_map[-future_next_token_ids] = next_token_ids.to(
torch.int32
)
self.future_token_ids_output[model_worker_batch.bid] = (
next_token_ids.tolist()
)
self.future_event_map[model_worker_batch.bid].set()
# Set the result
next_token_ids = next_token_ids.tolist()
assert logits_output.next_token_logprobs is None, "Not supported"
self.output_queue.put((None, next_token_ids))
if False:
tic3 = time.time()
......@@ -128,38 +128,26 @@ class TpModelWorkerClient:
f"{self.acc_time_with_waiting=:.3f}, {self.acc_time_without_waiting=:.3f}, {self.forward_queue.qsize()=}"
)
def resolve_future_token_ids(self, bid: int):
self.future_event_map[bid].wait()
ret = self.future_token_ids_output[bid]
del self.future_event_map[bid]
return ret
def resolve_future_logits_output(self, future_obj):
return self.future_logits_output_dict.pop(future_obj)
def resulve_batch_result(self, bid: int):
logits_output, next_token_ids = self.output_queue.get()
return logits_output, next_token_ids
def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
# Allocate output future objects
future_logits_output = self.future_logits_output_ct
self.future_logits_output_ct += 1
# Push a new batch to the queue
self.input_queue.put((model_worker_batch.copy(), self.future_token_ids_ct))
# Allocate output future objects
bs = len(model_worker_batch.seq_lens)
with torch.cuda.stream(self.forward_stream):
future_next_token_ids = -torch.arange(
self.future_token_ids_ct + 1,
self.future_token_ids_ct + 1 + bs,
dtype=torch.int32,
device=self.device,
)
future_next_token_ids = torch.arange(
-(self.future_token_ids_ct + bs),
-(self.future_token_ids_ct),
dtype=torch.int32,
device=self.device,
)
self.future_token_ids_ct = (
self.future_token_ids_ct + bs
) % self.future_token_ids_limit
ret = future_logits_output, future_next_token_ids
self.future_event_map[model_worker_batch.bid] = threading.Event()
self.forward_queue.put(
(model_worker_batch.copy(), future_logits_output, future_next_token_ids)
)
return ret
return None, future_next_token_ids
def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
......
......@@ -120,7 +120,7 @@ class ModelRunner:
)
if self.is_multimodal_model:
logger.info(
logger.warning(
"Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
)
server_args.chunked_prefill_size = None
......@@ -131,13 +131,6 @@ class ModelRunner:
]:
server_args.disable_cuda_graph = True
if self.server_args.enable_overlap_schedule:
logger.warning(
"Overlap scheduler is enabled. This is an experimental feature. "
"Sampling penalizer (e.g., frequency and repetition penalty), constrained decoding (e.g., regex, JSON), "
"and embedding APIs are not supported and will lead to wrong results."
)
# Global vars
if server_args.show_time_cost:
enable_show_time_cost()
......
......@@ -177,6 +177,16 @@ class ServerArgs:
if self.sampling_backend is None:
self.sampling_backend = "flashinfer"
if self.enable_overlap_schedule:
logger.warning(
"Overlap scheduler mode is enabled. This is an experimental feature. "
"Sampling penalizer (e.g., frequency and repetition penalty), constrained decoding (e.g., regex, JSON), "
"and embedding APIs are not supported and will lead to wrong results. "
"The NaN detection is also disabled."
)
self.disable_penalizer = True
self.disable_nan_detection = True
# Model-specific patches
if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path:
logger.info(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment