Commit ebbc42d9 authored by Lianmin Zheng

Optimize broadcast & Reorg code (#1598)

parent 3ff64113
@@ -148,6 +148,6 @@ def get_act_fn(

 if not is_flashinfer_available():
     logger.info(
-        "FlashInfer is not available on Non-NV GPUs. Fallback to other kernel libraries."
+        "FlashInfer is not available on Non-NV platforms. Fallback to other kernel libraries."
     )
     from vllm.model_executor.layers.activation import GeluAndMul, SiluAndMul
@@ -234,14 +234,9 @@ class Scheduler:
             recv_reqs = self.recv_requests()
             self.process_input_requests(recv_reqs)

-            # Run one step
             self.run_step()
-
-            # Send results
-            if self.tp_rank == 0:
-                for obj in self.out_pyobjs:
-                    self.send_to_detokenizer.send_pyobj(obj)
-                self.out_pyobjs = []
+            self.send_results()

     def recv_requests(self):
         if self.tp_rank == 0:
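Taken together, the loop body now reads as four symmetric steps. A minimal sketch of the refactored loop (`event_loop` is a hypothetical name for the enclosing method; the `while True:` wrapper is implied by the diff context):

```python
def event_loop(self):
    # Sketch only: the method bodies live elsewhere in the Scheduler.
    while True:
        recv_reqs = self.recv_requests()        # rank 0 reads ZMQ, then TP-broadcasts
        self.process_input_requests(recv_reqs)  # enqueue new requests / handle aborts
        self.run_step()                         # one prefill step or a few decode steps
        self.send_results()                     # rank 0 flushes out_pyobjs downstream
```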
@@ -256,7 +251,8 @@ class Scheduler:
         else:
             recv_reqs = None

-        recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
+        if self.tp_size != 1:
+            recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
         return recv_reqs

     def process_input_requests(self, recv_reqs: List):
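For a sense of what the new `tp_size != 1` guard avoids: judging from the `broadcast_pyobj` excerpt at the bottom of this diff, every call pickles the object list and issues two `dist.broadcast` collectives (size, then payload), even when there is nothing to synchronize. A hedged reconstruction, omitting any empty-payload special case the real code may have:

```python
import pickle

import numpy as np
import torch
import torch.distributed as dist


def broadcast_pyobj(data, rank, dist_group):
    """Broadcast an arbitrary picklable object from rank 0 to the group."""
    if rank == 0:
        serialized = pickle.dumps(data)
        tensor_data = torch.ByteTensor(np.frombuffer(serialized, dtype=np.uint8))
        tensor_size = torch.tensor([len(serialized)], dtype=torch.long)
        dist.broadcast(tensor_size, src=0, group=dist_group)  # collective 1: length
        dist.broadcast(tensor_data, src=0, group=dist_group)  # collective 2: payload
        return data
    else:
        tensor_size = torch.tensor([0], dtype=torch.long)
        dist.broadcast(tensor_size, src=0, group=dist_group)
        tensor_data = torch.empty(tensor_size.item(), dtype=torch.uint8)
        dist.broadcast(tensor_data, src=0, group=dist_group)
        return pickle.loads(tensor_data.numpy().tobytes())
```

With `tp_size == 1` both collectives plus a pickle round-trip ran on every scheduler iteration for no effect; the guard removes that per-step cost.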
@@ -366,43 +362,11 @@ class Scheduler:
             self.waiting_queue.append(req)

-    def run_step(self):
-        new_batch = self.get_new_batch_prefill()
-
-        if new_batch is not None:
-            # Run a new prefill batch
-            result = self.run_batch(new_batch)
-            self.process_batch_result(new_batch, result)
-
-            if not new_batch.is_empty():
-                if self.running_batch is None:
-                    self.running_batch = new_batch
-                else:
-                    self.running_batch.merge_batch(new_batch)
-        else:
-            # Run a decode batch
-            if self.running_batch is not None:
-                # Run a few decode batches continuously for reducing overhead
-                for _ in range(global_config.num_continue_decode_steps):
-                    batch = self.get_new_batch_decode()
-
-                    if batch:
-                        result = self.run_batch(batch)
-                        self.process_batch_result(batch, result)
-
-                        # Print stats
-                        if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
-                            self.print_decode_stats()
-
-                        if self.running_batch.is_empty():
-                            self.running_batch = None
-                            break
-
-                    if self.out_pyobjs and self.running_batch.has_stream:
-                        break
-            else:
-                self.check_memory()
-                self.new_token_ratio = global_config.init_new_token_ratio
+    def send_results(self):
+        if self.tp_rank == 0:
+            for obj in self.out_pyobjs:
+                self.send_to_detokenizer.send_pyobj(obj)
+            self.out_pyobjs = []

     def print_decode_stats(self):
         num_used = self.max_total_num_tokens - (
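`send_results` is just the old tail of the loop factored out: only TP rank 0 owns the detokenizer socket, so only it drains `out_pyobjs`. A standalone sketch of that channel, assuming a pyzmq PUSH socket and a made-up endpoint (the diff only shows that rank 0 calls `send_to_detokenizer.send_pyobj(obj)`):

```python
import zmq

ctx = zmq.Context()
send_to_detokenizer = ctx.socket(zmq.PUSH)
send_to_detokenizer.connect("ipc:///tmp/detokenizer")  # hypothetical endpoint

out_pyobjs = [{"rid": "req-0", "output_ids": [1, 2, 3]}]  # hypothetical payload
for obj in out_pyobjs:
    send_to_detokenizer.send_pyobj(obj)  # pickles obj, sends it as one ZMQ message
out_pyobjs.clear()
```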
@@ -441,6 +405,31 @@ class Scheduler:
             )
             exit(1) if crash_on_warning else None

+    def run_step(self):
+        new_batch = self.get_new_batch_prefill()
+        if new_batch is not None:
+            # Run a new prefill batch
+            result = self.run_batch(new_batch)
+            self.process_batch_result(new_batch, result)
+        else:
+            if self.running_batch is not None:
+                # Run a few decode batches continuously for reducing overhead
+                for _ in range(global_config.num_continue_decode_steps):
+                    batch = self.get_new_batch_decode()
+
+                    if batch:
+                        result = self.run_batch(batch)
+                        self.process_batch_result(batch, result)
+
+                    if self.running_batch is None:
+                        break
+
+                    if self.out_pyobjs and self.running_batch.has_stream:
+                        break
+            else:
+                self.check_memory()
+                self.new_token_ratio = global_config.init_new_token_ratio
+
     def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
         # Handle the cases where prefill is not allowed
         if (
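The re-added `run_step` is the same control flow minus the bookkeeping: merging a non-empty prefill batch into `running_batch` moves into `process_batch_result_prefill`, and the stats counter plus empty-batch cleanup move into `process_batch_result_decode` (both hunks below). A commented restatement of the decode fast path:

```python
# Inside run_step, when there is no new prefill batch but a running batch exists.
for _ in range(global_config.num_continue_decode_steps):
    batch = self.get_new_batch_decode()
    if batch:
        result = self.run_batch(batch)
        self.process_batch_result(batch, result)

    if self.running_batch is None:
        # process_batch_result_decode sets this to None when the batch drains,
        # replacing the old inline is_empty() check.
        break

    if self.out_pyobjs and self.running_batch.has_stream:
        # Flush streamed tokens promptly instead of running more decode steps.
        break
```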
@@ -612,7 +601,6 @@ class Scheduler:
             return None

         # Update batch tensors
-        self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
         batch.prepare_for_decode()
         return batch
@@ -723,6 +711,12 @@ class Scheduler:
         self.handle_finished_requests(batch)

+        if not batch.is_empty():
+            if self.running_batch is None:
+                self.running_batch = batch
+            else:
+                self.running_batch.merge_batch(batch)
+
     def process_batch_result_decode(self, batch: ScheduleBatch, result):
         logits_output, next_token_ids = result
         batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
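This is the merge logic that used to live in `run_step`. A toy, self-contained illustration of the accumulate-or-merge pattern (`ToyBatch` is a stand-in; `ScheduleBatch.merge_batch` itself is not shown in this diff):

```python
class ToyBatch:
    def __init__(self, reqs):
        self.reqs = list(reqs)

    def is_empty(self):
        return not self.reqs

    def merge_batch(self, other):
        self.reqs.extend(other.reqs)


running_batch = None
for new_batch in (ToyBatch(["a"]), ToyBatch([]), ToyBatch(["b", "c"])):
    if not new_batch.is_empty():
        if running_batch is None:
            running_batch = new_batch             # first live batch becomes the running batch
        else:
            running_batch.merge_batch(new_batch)  # later prefills join the decode batch

print(running_batch.reqs)  # ['a', 'b', 'c']
```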
@@ -762,6 +756,13 @@ class Scheduler:
         self.handle_finished_requests(batch)

+        self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
+        if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
+            self.print_decode_stats()
+
+        if self.running_batch.is_empty():
+            self.running_batch = None
+
     def add_logprob_return_values(
         self,
         i: int,
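The decode counter and the every-40-steps stats print now sit next to the decode results they describe, and `get_new_batch_decode` (above) no longer touches the counter. The `% (1 << 30)` simply keeps the counter bounded. A quick standalone check of the cadence:

```python
decode_forward_ct = 0
fired = []
for step in range(1, 121):
    decode_forward_ct = (decode_forward_ct + 1) % (1 << 30)
    if decode_forward_ct % 40 == 0:
        fired.append(step)
print(fired)  # [40, 80, 120] -> stats printed every 40 decode steps
```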
...
@@ -24,6 +24,7 @@ import random
 import resource
 import socket
 import time
+import warnings
 from importlib.metadata import PackageNotFoundError, version
 from io import BytesIO
 from typing import Any, Dict, List, Optional, Union
@@ -333,6 +334,10 @@ def suppress_other_loggers():
     logging.getLogger("vllm.selector").setLevel(logging.WARN)
     logging.getLogger("vllm.utils").setLevel(logging.ERROR)

+    warnings.filterwarnings(
+        "ignore", category=UserWarning, message="The given NumPy array is not writable"
+    )
+

 def assert_pkg_version(pkg: str, min_version: str, message: str):
     try:
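This filter pairs with the `broadcast_pyobj` change below: `np.frombuffer` returns a read-only view, and building a torch tensor from one emits exactly this UserWarning. A quick demonstration:

```python
import warnings

import numpy as np
import torch

arr = np.frombuffer(b"hello", dtype=np.uint8)
print(arr.flags.writeable)  # False: buffers from frombuffer are read-only views

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    t = torch.from_numpy(arr)  # triggers "The given NumPy array is not writable ..."
print(caught[0].category.__name__ if caught else "no warning")  # UserWarning
```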
@@ -615,7 +620,9 @@ def broadcast_pyobj(
         else:
             serialized_data = pickle.dumps(data)
             size = len(serialized_data)
-            tensor_data = torch.ByteTensor(list(serialized_data))
+            tensor_data = torch.ByteTensor(
+                np.frombuffer(serialized_data, dtype=np.uint8)
+            )
             tensor_size = torch.tensor([size], dtype=torch.long)

             dist.broadcast(tensor_size, src=0, group=dist_group)
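The point of the change: `torch.ByteTensor(list(serialized_data))` boxes every byte into a Python int before the tensor is built, while `np.frombuffer` wraps the buffer without copying. A rough micro-benchmark sketch (timings are machine-dependent; the read-only warning this path triggers is the one suppressed above):

```python
import pickle
import timeit

import numpy as np
import torch

payload = pickle.dumps(list(range(100_000)))  # a few hundred KB of pickled data

slow = timeit.timeit(lambda: torch.ByteTensor(list(payload)), number=10)
fast = timeit.timeit(
    lambda: torch.ByteTensor(np.frombuffer(payload, dtype=np.uint8)), number=10
)
print(f"list(): {slow:.4f}s  frombuffer(): {fast:.4f}s")
```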