Commit ebbc42d9 authored by Lianmin Zheng

Optimize broadcast & Reorg code (#1598)

parent 3ff64113
@@ -148,6 +148,6 @@ def get_act_fn(
 if not is_flashinfer_available():
     logger.info(
-        "FlashInfer is not available on Non-NV GPUs. Fallback to other kernel libraries."
+        "FlashInfer is not available on Non-NV platforms. Fallback to other kernel libraries."
     )
     from vllm.model_executor.layers.activation import GeluAndMul, SiluAndMul
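The hunk above only rewords a log message, but for readers outside the codebase, here is a minimal sketch of the guarded-import pattern it sits in. The is_flashinfer_available body below is a simplified stand-in (the real helper also considers the device platform), and the vLLM import assumes vllm is installed:

import importlib.util
import logging

logger = logging.getLogger(__name__)

def is_flashinfer_available() -> bool:
    # Simplified stand-in: check only whether the package can be imported.
    return importlib.util.find_spec("flashinfer") is not None

if not is_flashinfer_available():
    logger.info(
        "FlashInfer is not available on Non-NV platforms. Fallback to other kernel libraries."
    )
    # Fall back to vLLM's fused activation kernels.
    from vllm.model_executor.layers.activation import GeluAndMul, SiluAndMul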
@@ -234,14 +234,9 @@ class Scheduler:
         recv_reqs = self.recv_requests()
         self.process_input_requests(recv_reqs)

-        # Run one step
         self.run_step()

-        # Send results
-        if self.tp_rank == 0:
-            for obj in self.out_pyobjs:
-                self.send_to_detokenizer.send_pyobj(obj)
-            self.out_pyobjs = []
+        self.send_results()

     def recv_requests(self):
         if self.tp_rank == 0:
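This hunk collapses the event loop into four named stages: receive requests, process them, run one scheduling step, and send the results (the rank-0 ZMQ send now lives in send_results, added further down in this commit). A rough, self-contained sketch of that shape, with a plain queue standing in for the ZMQ sockets and trivial stubs for batching:

import queue

class LoopSketch:
    # Hypothetical stand-in for the Scheduler; only the loop shape matches.
    def __init__(self):
        self.inbox = queue.Queue()   # stands in for the tokenizer-facing socket
        self.out_pyobjs = []
        self.waiting = []

    def recv_requests(self):
        reqs = []
        while not self.inbox.empty():
            reqs.append(self.inbox.get_nowait())
        return reqs

    def process_input_requests(self, reqs):
        self.waiting.extend(reqs)

    def run_step(self):
        # Stand-in for prefill/decode batching: "finish" everything at once.
        self.out_pyobjs.extend(f"output for {r}" for r in self.waiting)
        self.waiting.clear()

    def send_results(self):
        for obj in self.out_pyobjs:
            print(obj)               # stands in for send_to_detokenizer.send_pyobj
        self.out_pyobjs = []

loop = LoopSketch()
loop.inbox.put("req-0")
loop.process_input_requests(loop.recv_requests())
loop.run_step()
loop.send_results()                  # prints: output for req-0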
@@ -256,7 +251,8 @@ class Scheduler:
         else:
             recv_reqs = None

-        recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
+        if self.tp_size != 1:
+            recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
         return recv_reqs

     def process_input_requests(self, recv_reqs: List):
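With tp_size == 1 there are no peer ranks, so serializing the requests and issuing collectives on the CPU group is pure overhead; the new guard skips it entirely. A hedged sketch of what a broadcast_pyobj-style helper does and why the guard is a win (not the exact sglang implementation; it assumes a gloo process group is already initialized):

import pickle
import torch
import torch.distributed as dist

def broadcast_pyobj_sketch(data, rank, dist_group, world_size):
    if world_size == 1:
        return data  # nothing to synchronize: skip pickling and collectives

    if rank == 0:
        buf = pickle.dumps(data)
        size = torch.tensor([len(buf)], dtype=torch.long)
        payload = torch.frombuffer(bytearray(buf), dtype=torch.uint8)
        dist.broadcast(size, src=0, group=dist_group)
        dist.broadcast(payload, src=0, group=dist_group)
        return data
    else:
        size = torch.tensor([0], dtype=torch.long)
        dist.broadcast(size, src=0, group=dist_group)
        payload = torch.empty(int(size.item()), dtype=torch.uint8)
        dist.broadcast(payload, src=0, group=dist_group)
        return pickle.loads(payload.numpy().tobytes())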
@@ -366,43 +362,11 @@ class Scheduler:
             self.waiting_queue.append(req)

-    def run_step(self):
-        new_batch = self.get_new_batch_prefill()
-        if new_batch is not None:
-            # Run a new prefill batch
-            result = self.run_batch(new_batch)
-            self.process_batch_result(new_batch, result)
-
-            if not new_batch.is_empty():
-                if self.running_batch is None:
-                    self.running_batch = new_batch
-                else:
-                    self.running_batch.merge_batch(new_batch)
-        else:
-            # Run a decode batch
-            if self.running_batch is not None:
-                # Run a few decode batches continuously for reducing overhead
-                for _ in range(global_config.num_continue_decode_steps):
-                    batch = self.get_new_batch_decode()
-
-                    if batch:
-                        result = self.run_batch(batch)
-                        self.process_batch_result(batch, result)
-
-                    # Print stats
-                    if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
-                        self.print_decode_stats()
-
-                    if self.running_batch.is_empty():
-                        self.running_batch = None
-                        break
-
-                    if self.out_pyobjs and self.running_batch.has_stream:
-                        break
-            else:
-                self.check_memory()
-                self.new_token_ratio = global_config.init_new_token_ratio
+    def send_results(self):
+        if self.tp_rank == 0:
+            for obj in self.out_pyobjs:
+                self.send_to_detokenizer.send_pyobj(obj)
+            self.out_pyobjs = []

     def print_decode_stats(self):
         num_used = self.max_total_num_tokens - (
@@ -441,6 +405,31 @@ class Scheduler:
         )
         exit(1) if crash_on_warning else None

+    def run_step(self):
+        new_batch = self.get_new_batch_prefill()
+        if new_batch is not None:
+            # Run a new prefill batch
+            result = self.run_batch(new_batch)
+            self.process_batch_result(new_batch, result)
+        else:
+            if self.running_batch is not None:
+                # Run a few decode batches continuously for reducing overhead
+                for _ in range(global_config.num_continue_decode_steps):
+                    batch = self.get_new_batch_decode()
+
+                    if batch:
+                        result = self.run_batch(batch)
+                        self.process_batch_result(batch, result)
+
+                    if self.running_batch is None:
+                        break
+
+                    if self.out_pyobjs and self.running_batch.has_stream:
+                        break
+            else:
+                self.check_memory()
+                self.new_token_ratio = global_config.init_new_token_ratio
+
     def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
         # Handle the cases where prefill is not allowed
         if (
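The relocated run_step keeps the "run a few decode batches continuously" loop: fixed per-iteration costs of the outer event loop (polling the request socket, broadcasting inputs, flushing results) are paid once per num_continue_decode_steps forward passes rather than once per pass. A back-of-envelope model with hypothetical costs:

# Hypothetical costs in milliseconds, for illustration only.
loop_overhead_ms = 0.5   # recv_requests + send_results per loop iteration
decode_step_ms = 8.0     # one decode forward pass

for k in (1, 2, 4, 8):   # candidate num_continue_decode_steps values
    effective = decode_step_ms + loop_overhead_ms / k
    print(f"k={k}: {effective:.3f} ms per decode step")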
@@ -612,7 +601,6 @@ class Scheduler:
             return None

         # Update batch tensors
-        self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
         batch.prepare_for_decode()
         return batch
@@ -723,6 +711,12 @@ class Scheduler:
         self.handle_finished_requests(batch)

+        if not batch.is_empty():
+            if self.running_batch is None:
+                self.running_batch = batch
+            else:
+                self.running_batch.merge_batch(batch)
+
     def process_batch_result_decode(self, batch: ScheduleBatch, result):
         logits_output, next_token_ids = result
         batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
@@ -762,6 +756,13 @@ class Scheduler:
         self.handle_finished_requests(batch)

+        self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
+        if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
+            self.print_decode_stats()
+
+        if self.running_batch.is_empty():
+            self.running_batch = None
+
     def add_logprob_return_values(
         self,
         i: int,
@@ -24,6 +24,7 @@ import random
 import resource
 import socket
 import time
+import warnings
 from importlib.metadata import PackageNotFoundError, version
 from io import BytesIO
 from typing import Any, Dict, List, Optional, Union
@@ -333,6 +334,10 @@ def suppress_other_loggers():
     logging.getLogger("vllm.selector").setLevel(logging.WARN)
     logging.getLogger("vllm.utils").setLevel(logging.ERROR)

+    warnings.filterwarnings(
+        "ignore", category=UserWarning, message="The given NumPy array is not writable"
+    )
+
 def assert_pkg_version(pkg: str, min_version: str, message: str):
     try:
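This filter is most likely paired with the broadcast_pyobj change below: np.frombuffer returns a read-only array, and wrapping such an array in a tensor makes PyTorch warn that "The given NumPy array is not writable". Since the broadcast path never writes through that tensor, the warning is noise there. A small demonstration (the scoped catch_warnings is for the demo; suppress_other_loggers installs the filter globally):

import warnings
import numpy as np
import torch

arr = np.frombuffer(b"hello", dtype=np.uint8)  # zero-copy, read-only view

with warnings.catch_warnings():
    warnings.filterwarnings(
        "ignore", category=UserWarning, message="The given NumPy array is not writable"
    )
    t = torch.from_numpy(arr)  # would emit the UserWarning without the filter

print(t)  # tensor([104, 101, 108, 108, 111], dtype=torch.uint8)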
@@ -615,7 +620,9 @@ def broadcast_pyobj(
     else:
         serialized_data = pickle.dumps(data)
         size = len(serialized_data)
-        tensor_data = torch.ByteTensor(list(serialized_data))
+        tensor_data = torch.ByteTensor(
+            np.frombuffer(serialized_data, dtype=np.uint8)
+        )
         tensor_size = torch.tensor([size], dtype=torch.long)

         dist.broadcast(tensor_size, src=0, group=dist_group)
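The old path, torch.ByteTensor(list(serialized_data)), materializes one Python int object per byte of the pickled payload before copying into the tensor; np.frombuffer instead yields a zero-copy uint8 view, so the conversion no longer scales with payload size at Python-object speed. A quick, self-contained comparison (timings are machine-dependent):

import pickle
import time

import numpy as np
import torch

payload = pickle.dumps(list(range(100_000)))  # a stand-in request batch

t0 = time.perf_counter()
slow = torch.ByteTensor(list(payload))        # old path: one Python int per byte
t1 = time.perf_counter()
# New path; may emit the "not writable" UserWarning suppressed above.
fast = torch.ByteTensor(np.frombuffer(payload, dtype=np.uint8))
t2 = time.perf_counter()

assert torch.equal(slow, fast)
print(f"list(): {t1 - t0:.4f}s   np.frombuffer(): {t2 - t1:.4f}s")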