Commit 20e75ed6 authored by lizhigong's avatar lizhigong Committed by maxiao1@sugon.com
Browse files

add tbo on v1 engine

parent eba84521
...@@ -159,6 +159,7 @@ if TYPE_CHECKING: ...@@ -159,6 +159,7 @@ if TYPE_CHECKING:
VLLM_ENABLE_TBO: bool = False VLLM_ENABLE_TBO: bool = False
VLLM_TBO_REQ_DELAY_MS: int = 0 VLLM_TBO_REQ_DELAY_MS: int = 0
VLLM_TBO_DECODE_BS: int = 0 VLLM_TBO_DECODE_BS: int = 0
VLLM_TBO_MIN_TOKENS: int = 200
VLLM_ZERO_OVERHEAD: bool = False VLLM_ZERO_OVERHEAD: bool = False
VLLM_ENABLE_MOE_FUSED_GATE: bool = False VLLM_ENABLE_MOE_FUSED_GATE: bool = False
VLLM_USE_FLASH_ATTN_PA: bool = False VLLM_USE_FLASH_ATTN_PA: bool = False
...@@ -1069,6 +1070,10 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1069,6 +1070,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_TBO_DECODE_BS": "VLLM_TBO_DECODE_BS":
lambda: int(os.getenv("VLLM_TBO_DECODE_BS", "0")), lambda: int(os.getenv("VLLM_TBO_DECODE_BS", "0")),
# set the minimum tokens size for each mini-batch to enable TBO on v1, default is 200.
"VLLM_TBO_MIN_TOKENS":
lambda: int(os.getenv("VLLM_TBO_MIN_TOKENS", "200")),
# Enable zero overhead scheduler. # Enable zero overhead scheduler.
"VLLM_ZERO_OVERHEAD": "VLLM_ZERO_OVERHEAD":
lambda: bool(int(os.getenv("VLLM_ZERO_OVERHEAD", "0"))), lambda: bool(int(os.getenv("VLLM_ZERO_OVERHEAD", "0"))),
......
...@@ -16,6 +16,7 @@ from vllm.logger import init_logger ...@@ -16,6 +16,7 @@ from vllm.logger import init_logger
from vllm.profiler.prof import profile from vllm.profiler.prof import profile
from vllm import envs from vllm import envs
from vllm.utils import weak_ref_tensor from vllm.utils import weak_ref_tensor
from vllm.two_batch_overlap.v1.two_batch_overlap_v1 import is_enable_tbo_v1, tbo_all_reduce_v1
tbo_one_stream = os.environ.get('VLLM_TBO_ONE_STREAM') == '1' tbo_one_stream = os.environ.get('VLLM_TBO_ONE_STREAM') == '1'
...@@ -214,6 +215,8 @@ def init_two_batch_overlap(): ...@@ -214,6 +215,8 @@ def init_two_batch_overlap():
tbo_obj.init_tbo_thread() tbo_obj.init_tbo_thread()
def tbo_all_reduce(obj): def tbo_all_reduce(obj):
if is_enable_tbo_v1():
return tbo_all_reduce_v1(obj)
if envs.VLLM_ENABLE_TBO and tbo_obj != None and tbo_obj.tbo_running: if envs.VLLM_ENABLE_TBO and tbo_obj != None and tbo_obj.tbo_running:
tid = threading.get_ident() tid = threading.get_ident()
if not tbo_one_stream: if not tbo_one_stream:
......
from typing import Any, Optional, Union
import numpy as np
import torch
from vllm import envs
from vllm.distributed.kv_transfer.kv_transfer_state import get_kv_transfer_group, has_kv_transfer_group
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.forward_context import set_forward_context
from vllm.sequence import IntermediateTensors
from vllm.two_batch_overlap.v1.two_batch_overlap_v1 import tbo_model_executable_v1
from vllm.utils import async_tensor_h2d
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.medusa import MedusaProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.worker.block_table import BlockTable
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
class TBO_GPUModelRunner(GPUModelRunner):
def __init__(self, vllm_config, device):
super().__init__(vllm_config, device)
self.req_ids_left = []
self.req_ids_right = []
self.req_num_left = 0
self.req_num_right = 0
self.scheduler_output_left = None
self.scheduler_output_right = None
def split_scheduler_output(self, scheduler_output:SchedulerOutput):
split_tokens = scheduler_output.total_num_scheduled_tokens // 2
req_ids = self.input_batch.req_ids
tokens_counter = 0
min_idx = -1
min_counter = 0
for i, id in enumerate(req_ids):
tokens_counter += scheduler_output.num_scheduled_tokens[id]
diff = abs(tokens_counter - split_tokens)
if min_idx == -1 or diff < min_counter:
min_idx = i
min_counter = diff
if tokens_counter > split_tokens or diff == 0:
break
self.req_num_left = min_idx + 1
if self.req_num_left == len(req_ids):
self.req_num_left = self.req_num_left - 1
self.req_ids_left = req_ids[:self.req_num_left]
self.req_ids_right = req_ids[self.req_num_left:]
self.req_num_right = len(req_ids) - self.req_num_left
new_req_data_left = []
new_req_data_right = []
cached_reqs_left = []
cached_reqs_right = []
num_scheduled_tokens_left = {}
num_scheduled_tokens_right = {}
total_num_scheduled_tokens_left = 0
total_num_scheduled_tokens_right = 0
for new_req in scheduler_output.scheduled_new_reqs:
if new_req.req_id in self.req_ids_left:
new_req_data_left.append(new_req)
else:
new_req_data_right.append(new_req)
for cached_req in scheduler_output.scheduled_cached_reqs:
if cached_req.req_id in self.req_ids_left:
cached_reqs_left.append(cached_req)
else:
cached_reqs_right.append(cached_req)
for key, value in scheduler_output.num_scheduled_tokens.items():
if key in self.req_ids_left:
num_scheduled_tokens_left[key] = value
total_num_scheduled_tokens_left += value
else:
num_scheduled_tokens_right[key] = value
total_num_scheduled_tokens_right += value
self.scheduler_output_left = SchedulerOutput(
scheduled_new_reqs=new_req_data_left,
scheduled_cached_reqs=cached_reqs_left,
num_scheduled_tokens=num_scheduled_tokens_left,
total_num_scheduled_tokens=total_num_scheduled_tokens_left,
scheduled_spec_decode_tokens=scheduler_output.scheduled_spec_decode_tokens,
scheduled_encoder_inputs=scheduler_output.scheduled_encoder_inputs, ##unsupport yet
num_common_prefix_blocks=scheduler_output.num_common_prefix_blocks,
# finished_req_ids is an existing state in the scheduler,
# instead of being newly scheduled in this step.
# It contains the request IDs that are finished in between
# the previous and the current steps.
finished_req_ids=scheduler_output.finished_req_ids,
free_encoder_input_ids=scheduler_output.free_encoder_input_ids,
structured_output_request_ids=scheduler_output.structured_output_request_ids,
grammar_bitmask=scheduler_output.grammar_bitmask,
)
self.scheduler_output_right = SchedulerOutput(
scheduled_new_reqs=new_req_data_right,
scheduled_cached_reqs=cached_reqs_right,
num_scheduled_tokens=num_scheduled_tokens_right,
total_num_scheduled_tokens=total_num_scheduled_tokens_right,
scheduled_spec_decode_tokens=scheduler_output.scheduled_spec_decode_tokens,
scheduled_encoder_inputs=scheduler_output.scheduled_encoder_inputs, ##unsupport yet
num_common_prefix_blocks=scheduler_output.num_common_prefix_blocks,
# finished_req_ids is an existing state in the scheduler,
# instead of being newly scheduled in this step.
# It contains the request IDs that are finished in between
# the previous and the current steps.
finished_req_ids=scheduler_output.finished_req_ids,
free_encoder_input_ids=scheduler_output.free_encoder_input_ids,
structured_output_request_ids=scheduler_output.structured_output_request_ids,
grammar_bitmask=scheduler_output.grammar_bitmask,
)
def prepare_tbo_atten_metadata(
self,
scheduler_output: "SchedulerOutput",
req_ids,
req_offset
) -> tuple[dict[str, Any], torch.Tensor, Optional[SpecDecodeMetadata]]:
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
assert total_num_scheduled_tokens > 0
num_reqs = len(req_ids)
assert num_reqs > 0
seq_len_offset = req_offset
if req_offset == 0: #left
query_start_offset = 0
else:
query_start_offset = req_offset + 1
# Get the number of scheduled tokens for each request.
# req_ids = self.input_batch.req_ids
tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
num_scheduled_tokens = np.array(tokens, dtype=np.int32)
max_num_scheduled_tokens = max(tokens)
# Get request indices.
# E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
req_indices = np.repeat(self.arange_np[:num_reqs],
num_scheduled_tokens) + req_offset
# cu_num_tokens: [2, 5, 3] -> [2, 7, 10]
# arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
cu_num_tokens, arange = self._get_cumsum_and_arange(
num_scheduled_tokens)
# Get positions.
positions_np = self.positions_np[:total_num_scheduled_tokens]
np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
arange,
out=positions_np)
# Calculate the slot mapping for each KV cache group.
for kv_cache_group_id, kv_cache_group_spec in enumerate(
self.kv_cache_config.kv_cache_groups):
block_size = kv_cache_group_spec.kv_cache_spec.block_size
block_table: BlockTable = self.input_batch.block_table[
kv_cache_group_id]
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
# -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
# where K is the max_num_blocks_per_req and the block size is 2.
# NOTE(woosuk): We can't simply use `token_indices // block_size`
# here because M (max_model_len) is not necessarily divisible by
# block_size.
block_table_indices = (
req_indices * block_table.max_num_blocks_per_req +
positions_np // block_size)
block_table_cpu = block_table.get_cpu_tensor()
block_numbers = block_table_cpu.flatten(
)[block_table_indices].numpy()
block_offsets = positions_np % block_size
np.add(
block_numbers * block_size,
block_offsets,
out=block_table.slot_mapping_np[:total_num_scheduled_tokens])
# Prepare the attention metadata.
self.query_start_loc_np[0] = 0
self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
self.seq_lens_np[:num_reqs] = (
self.input_batch.num_computed_tokens_cpu[req_offset : req_offset + num_reqs] +
num_scheduled_tokens)
self.query_start_loc[query_start_offset: query_start_offset + num_reqs + 1].copy_(
self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True)
# Note: pad query_start_loc to be non-decreasing, as kernels
# like FlashAttention requires that
if req_offset > 0: #right
self.query_start_loc[query_start_offset + num_reqs + 1:].fill_(
self.query_start_loc_cpu[num_reqs].item())
self.seq_lens[seq_len_offset :seq_len_offset + num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
non_blocking=True)
# Fill unused with -1. Needed for reshape_and_cache
if req_offset > 0: #right
self.seq_lens[seq_len_offset + num_reqs:].fill_(0)
query_start_loc = self.query_start_loc[query_start_offset: query_start_offset + num_reqs + 1]
seq_lens = self.seq_lens[seq_len_offset : seq_len_offset + num_reqs]
common_attn_metadata = CommonAttentionMetadata(
query_start_loc=query_start_loc, seq_lens=seq_lens)
attn_metadata: dict[str, Any] = {}
# Prepare the attention metadata for each KV cache group and make layers
# in the same group share the same metadata.
for kv_cache_group_id, kv_cache_group_spec in enumerate(
self.kv_cache_config.kv_cache_groups):
# Prepare for cascade attention if enabled & beneficial.
common_prefix_len = 0
if self.cascade_attn_enabled:
common_prefix_len = self._compute_cascade_attn_prefix_len(
num_scheduled_tokens,
scheduler_output.
num_common_prefix_blocks[kv_cache_group_id],
kv_cache_group_spec.kv_cache_spec,
self.attn_metadata_builders[kv_cache_group_id],
)
if req_offset > 0:
origin_block_table = self.attn_metadata_builders[kv_cache_group_id].block_table.block_table
self.attn_metadata_builders[kv_cache_group_id].block_table.block_table = origin_block_table[req_offset:, :]
origin_slot_mapping = self.attn_metadata_builders[kv_cache_group_id].block_table.slot_mapping
self.attn_metadata_builders[kv_cache_group_id].block_table.slot_mapping = \
origin_slot_mapping[self.scheduler_output_left.total_num_scheduled_tokens:]
attn_metadata_i = (
self.attn_metadata_builders[kv_cache_group_id].build(
num_reqs=num_reqs,
num_actual_tokens=total_num_scheduled_tokens,
max_query_len=max_num_scheduled_tokens,
common_prefix_len=common_prefix_len,
common_attn_metadata=common_attn_metadata)) # maybe FlashAttentionMetadata
if req_offset > 0:
self.attn_metadata_builders[kv_cache_group_id].block_table.block_table = origin_block_table
self.attn_metadata_builders[kv_cache_group_id].block_table.slot_mapping = origin_slot_mapping
for layer_name in kv_cache_group_spec.layer_names:
attn_metadata[layer_name] = attn_metadata_i
return attn_metadata
def pad_num_input_tokens(self, scheduler_output):
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
if (self.use_cuda_graph
and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
# Use piecewise CUDA graphs.
# Add padding to the batch size.
num_input_tokens = self.vllm_config.pad_for_cudagraph(
num_scheduled_tokens)
else:
# Eager mode.
# Pad tokens to multiple of tensor_parallel_size when
# enabled collective fusion for SP
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
if self.vllm_config.compilation_config.pass_config. \
enable_sequence_parallelism and tp_size > 1:
from vllm.utils import round_up
num_input_tokens = round_up(num_scheduled_tokens, tp_size)
else:
num_input_tokens = num_scheduled_tokens
# Padding for DP
num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens)
num_input_tokens += num_pad
return num_input_tokens, num_tokens_across_dp
@torch.inference_mode()
def execute_model(
self,
scheduler_output: "SchedulerOutput",
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[ModelRunnerOutput, IntermediateTensors]:
self._update_states(scheduler_output)
if not scheduler_output.total_num_scheduled_tokens:
if not has_kv_transfer_group():
# Return empty ModelRunnerOutput if there's no work to do.
return EMPTY_MODEL_RUNNER_OUTPUT
return self.kv_connector_no_forward(scheduler_output)
assert not self.use_cuda_graph, 'v1 engine with tbo do not support cuda-graph'
# Prepare the decoder inputs.
attn_metadata, logits_indices, spec_decode_metadata = (
super()._prepare_inputs(scheduler_output))
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
if (self.use_cuda_graph
and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
# Use piecewise CUDA graphs.
# Add padding to the batch size.
num_input_tokens = self.vllm_config.pad_for_cudagraph(
num_scheduled_tokens)
else:
# Eager mode.
# Pad tokens to multiple of tensor_parallel_size when
# enabled collective fusion for SP
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
if self.vllm_config.compilation_config.pass_config. \
enable_sequence_parallelism and tp_size > 1:
from vllm.utils import round_up
num_input_tokens = round_up(num_scheduled_tokens, tp_size)
else:
num_input_tokens = num_scheduled_tokens
# Padding for DP
num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens)
num_input_tokens += num_pad
# _prepare_inputs may reorder the batch, so we must gather multi
# modal outputs after that to ensure the correct order
if self.is_multimodal_model:
# Run the multimodal encoder if any.
self._execute_mm_encoder(scheduler_output)
mm_embeds = self._gather_mm_embeddings(scheduler_output)
else:
mm_embeds = []
if self.is_multimodal_model and get_pp_group().is_first_rank:
# NOTE(woosuk): To unify token ids and soft tokens (vision
# embeddings), we always use embeddings (rather than token ids)
# as input to the multimodal model, even when the input is text.
input_ids = self.input_ids[:num_scheduled_tokens]
if mm_embeds:
inputs_embeds = self.model.get_input_embeddings(
input_ids, mm_embeds)
else:
inputs_embeds = self.model.get_input_embeddings(input_ids)
# TODO(woosuk): Avoid the copy. Optimize.
self.inputs_embeds[:num_scheduled_tokens].copy_(inputs_embeds)
inputs_embeds = self.inputs_embeds[:num_input_tokens]
input_ids = None
else:
# For text-only models, we use token ids as input.
# While it is possible to use embeddings as input just like the
# multimodal models, it is not desirable for performance since
# then the embedding layer is not included in the CUDA graph.
input_ids = self.input_ids[:num_input_tokens]
inputs_embeds = None
if self.uses_mrope:
positions = self.mrope_positions[:, :num_input_tokens]
else:
positions = self.positions[:num_input_tokens]
if get_pp_group().is_first_rank:
intermediate_tensors = None
else:
intermediate_tensors = self.sync_and_slice_intermediate_tensors(
num_input_tokens, intermediate_tensors, True)
use_tbo = False
if len(scheduler_output.num_scheduled_tokens) > 1:
self.split_scheduler_output(scheduler_output)
if self.scheduler_output_left.total_num_scheduled_tokens >= envs.VLLM_TBO_MIN_TOKENS and \
self.scheduler_output_right.total_num_scheduled_tokens >= envs.VLLM_TBO_MIN_TOKENS:
use_tbo = True
if use_tbo:
num_input_tokens_left = self.scheduler_output_left.total_num_scheduled_tokens
num_input_tokens_right = num_input_tokens - num_input_tokens_left
attn_metadata_left = self.prepare_tbo_atten_metadata(self.scheduler_output_left, self.req_ids_left, 0)
attn_metadata_right = self.prepare_tbo_atten_metadata(self.scheduler_output_right, self.req_ids_right, self.req_num_left)
model_output = tbo_model_executable_v1(
self,
attn_metadata_left,
attn_metadata_right,
num_input_tokens_left,
num_input_tokens_right,
num_tokens_across_dp,
input_ids,
positions,
intermediate_tensors,
inputs_embeds)
finished_sending, finished_recving = None, None
else:
# Run the decoder.
# Use persistent buffers for CUDA graphs.
with set_forward_context(attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp):
self.maybe_setup_kv_connector(scheduler_output)
model_output = self.model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)
self.maybe_wait_for_kv_save()
finished_sending, finished_recving = (
self.get_finished_kv_transfers(scheduler_output))
if self.use_aux_hidden_state_outputs:
hidden_states, aux_hidden_states = model_output
else:
hidden_states = model_output
# Broadcast PP output for external_launcher (torchrun)
# to make sure we are synced across pp ranks
# TODO: Support overlapping mirco-batches
# https://github.com/vllm-project/vllm/issues/18019
broadcast_pp_output = \
self.parallel_config.distributed_executor_backend \
== "external_launcher" and len(get_pp_group().ranks) > 0
if not get_pp_group().is_last_rank:
# For mid-pipeline stages, return the hidden states.
if not broadcast_pp_output:
return hidden_states
assert isinstance(hidden_states, IntermediateTensors)
get_pp_group().send_tensor_dict(hidden_states.tensors,
all_gather_group=get_tp_group())
logits = None
else:
sample_hidden_states = hidden_states[logits_indices]
logits = self.model.compute_logits(sample_hidden_states, None)
if broadcast_pp_output:
model_output_broadcast_data = {
"logits": logits.contiguous(),
} if logits is not None else {}
model_output_broadcast_data = get_pp_group().broadcast_tensor_dict(
model_output_broadcast_data, src=len(get_pp_group().ranks) - 1)
assert model_output_broadcast_data is not None
logits = model_output_broadcast_data["logits"]
# Apply structured output bitmasks if present
if scheduler_output.grammar_bitmask is not None:
self.apply_grammar_bitmask(scheduler_output, logits)
# Sample the next token and get logprobs if needed.
sampling_metadata = self.input_batch.sampling_metadata
if spec_decode_metadata is None:
sampler_output = self.sampler(
logits=logits,
sampling_metadata=sampling_metadata,
)
else:
# When indexing with a tensor (bonus_logits_indices), PyTorch
# creates a new tensor with separate storage from the original
# logits tensor. This means any in-place operations on bonus_logits
# won't affect the original logits tensor.
assert logits is not None
bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
sampler_output = self.sampler(
logits=bonus_logits,
sampling_metadata=sampling_metadata,
)
bonus_token_ids = sampler_output.sampled_token_ids
# Just like `bonus_logits`, `target_logits` is a new tensor with
# separate storage from the original `logits` tensor. Therefore,
# it is safe to update `target_logits` in place.
target_logits = logits[spec_decode_metadata.target_logits_indices]
output_token_ids = self.rejection_sampler(
spec_decode_metadata,
None, # draft_probs
target_logits,
bonus_token_ids,
sampling_metadata,
)
sampler_output.sampled_token_ids = output_token_ids
# TODO(woosuk): The following loop can be slow since it iterates over
# the requests one by one. Optimize.
discard_sampled_tokens_req_indices = []
for i, req_id in enumerate(self.input_batch.req_ids):
req_state = self.requests[req_id]
seq_len = (req_state.num_computed_tokens +
scheduler_output.num_scheduled_tokens[req_id])
if seq_len < req_state.num_tokens:
# Ignore the sampled token for partial prefills.
# Rewind the generator state as if the token was not sampled.
# This relies on cuda-specific torch-internal impl details
generator = self.input_batch.generators.get(i)
if generator is not None:
generator.set_offset(generator.get_offset() - 4)
# Record the index of the request that should not be sampled,
# so that we could clear the sampled tokens before returning.
discard_sampled_tokens_req_indices.append(i)
# NOTE: GPU -> CPU Sync happens here.
# Move as many CPU operations as possible before this sync point.
logprobs_tensors = sampler_output.logprobs_tensors
logprobs_lists = logprobs_tensors.tolists() \
if logprobs_tensors is not None else None
# Compute prompt logprobs if needed.
prompt_logprobs_dict = self._get_prompt_logprobs_dict(
hidden_states[:num_scheduled_tokens],
scheduler_output,
)
# Get the valid generated tokens.
sampled_token_ids = sampler_output.sampled_token_ids
max_gen_len = sampled_token_ids.shape[-1]
if max_gen_len == 1:
# No spec decode tokens.
valid_sampled_token_ids = sampled_token_ids.tolist()
else:
# Includes spec decode tokens.
valid_sampled_token_ids = self.rejection_sampler.parse_output(
sampled_token_ids,
self.input_batch.vocab_size,
)
# Mask out the sampled tokens that should not be sampled.
for i in discard_sampled_tokens_req_indices:
valid_sampled_token_ids[i].clear()
if not self.speculative_config:
# Speculative decoding is not enabled.
spec_token_ids = None
elif self.speculative_config.method == "ngram":
assert isinstance(self.drafter, NgramProposer)
spec_token_ids = self.generate_draft_token_ids(
valid_sampled_token_ids, sampling_metadata)
elif self.speculative_config.method == "medusa":
assert isinstance(self.drafter, MedusaProposer)
if max_gen_len == 1:
hidden_states = sample_hidden_states
else:
indices = []
offset = 0
for num_draft, tokens in zip(
spec_decode_metadata.num_draft_tokens,
valid_sampled_token_ids):
indices.append(offset + len(tokens) - 1)
offset += num_draft + 1
indices = torch.tensor(indices,
device=sample_hidden_states.device)
hidden_states = sample_hidden_states[indices]
spec_token_ids = self.drafter.propose(
target_hidden_states=hidden_states,
sampling_metadata=sampling_metadata,
)
elif self.speculative_config.use_eagle():
assert isinstance(self.drafter, EagleProposer)
# TODO(woosuk): Refactor the loop.
next_token_ids: list[int] = []
for i, token_ids in enumerate(valid_sampled_token_ids):
if token_ids:
# Common case.
next_token_id = token_ids[-1]
else:
# Partial prefill (rare case).
# Get the next token id from the request state.
req_id = self.input_batch.req_ids[i]
req_state = self.requests[req_id]
seq_len = (req_state.num_computed_tokens +
scheduler_output.num_scheduled_tokens[req_id])
next_token_id = req_state.get_token_id(seq_len)
next_token_ids.append(next_token_id)
next_token_ids = torch.tensor(next_token_ids,
dtype=torch.int32,
device=self.device)
# At this moment, we assume all eagle layers belong to the same KV
# cache group, thus using the same attention metadata.
eagle_attn_metadata = attn_metadata[
self.drafter.attn_layer_names[0]]
# NOTE: deepseek_mtp uses MLA which does not have `block_table`
if hasattr(eagle_attn_metadata, "block_table"):
block_table = eagle_attn_metadata.block_table
else:
block_table = None
num_rejected_tokens_tuple = None
if spec_decode_metadata is None:
# input_ids can be None for multimodal models.
target_token_ids = self.input_ids[:num_scheduled_tokens]
target_positions = positions[:num_scheduled_tokens]
if self.use_aux_hidden_state_outputs:
target_hidden_states = torch.cat(
[h[:num_scheduled_tokens] for h in aux_hidden_states],
dim=-1)
else:
target_hidden_states = hidden_states[:num_scheduled_tokens]
target_slot_mapping = eagle_attn_metadata.slot_mapping
cu_num_tokens = eagle_attn_metadata.query_start_loc
else:
# TODO(woosuk): Refactor this.
num_draft_tokens = spec_decode_metadata.num_draft_tokens
num_rejected_tokens = [
n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
for i, n in enumerate(num_draft_tokens)
]
num_rejected_tokens_tensor = async_tensor_h2d(
num_rejected_tokens,
dtype=torch.int32,
target_device=self.device,
pin_memory=True)
num_tokens = num_scheduled_tokens - sum(num_rejected_tokens)
cu_num_tokens, token_indices = self.drafter.prepare_inputs(
eagle_attn_metadata.query_start_loc,
num_rejected_tokens_tensor,
num_tokens,
)
target_token_ids = self.input_ids[token_indices]
target_positions = positions[token_indices]
if self.use_aux_hidden_state_outputs:
target_hidden_states = torch.cat(
[h[token_indices] for h in aux_hidden_states], dim=-1)
else:
target_hidden_states = hidden_states[token_indices]
target_slot_mapping = eagle_attn_metadata.slot_mapping[
token_indices]
num_rejected_tokens_tuple = (num_rejected_tokens, num_rejected_tokens_tensor)
draft_token_ids = self.drafter.propose(
target_token_ids=target_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
target_slot_mapping=target_slot_mapping,
next_token_ids=next_token_ids,
cu_num_tokens=cu_num_tokens,
block_table=block_table,
sampling_metadata=sampling_metadata,
num_rejected_tokens_tuple=num_rejected_tokens_tuple
)
spec_token_ids = draft_token_ids.tolist()
# Clear KVConnector state after all KVs are generated.
if has_kv_transfer_group():
get_kv_transfer_group().clear_connector_metadata()
return ModelRunnerOutput(
req_ids=self.input_batch.req_ids,
req_id_to_index=self.input_batch.req_id_to_index,
sampled_token_ids=valid_sampled_token_ids,
spec_token_ids=spec_token_ids,
logprobs=logprobs_lists,
prompt_logprobs_dict=prompt_logprobs_dict,
finished_sending=finished_sending,
finished_recving=finished_recving,
)
\ No newline at end of file
import os
import queue
import threading
import torch
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import get_tp_group
from vllm.forward_context import set_forward_context
from vllm.multimodal.inputs import MultiModalKwargs
from vllm.sequence import IntermediateTensors
from vllm.two_batch_overlap.forward_context import init_tbo_forward_context
from vllm.logger import init_logger
from vllm.profiler.prof import profile
from vllm import envs
logger = init_logger(__name__)
tbo_step_stream = None
all_reduce_stream = None
class TwoBatchOverlap():
def __init__(self):
global tbo_step_stream
global all_reduce_stream
self.model_input_left_queue = queue.Queue()
self.model_input_right_queue = queue.Queue()
self.states_left_queue = queue.Queue()
self.states_right_queue = queue.Queue()
self.left_thread = None
self.right_thread = None
self.left_tid = 0
self.right_tid = 0
self.sem_left = threading.Semaphore(0)
self.sem_right = threading.Semaphore(0)
self.left_first = False
self.tbo_running = False
self.tbo_in_capture = False
if tbo_step_stream == None:
tbo_step_stream = torch.cuda.Stream()
all_reduce_stream = torch.cuda.Stream()
self.step_event = torch.cuda.Event(enable_timing=False)
self.event_left_c2t = torch.cuda.Event(enable_timing=False)
self.event_right_c2t = torch.cuda.Event(enable_timing=False)
self.event_left_t2c = torch.cuda.Event(enable_timing=False)
self.event_right_t2c = torch.cuda.Event(enable_timing=False)
def init_tbo_thread(self):
self.model_input_left_queue.empty()
self.model_input_right_queue.empty()
self.left_thread = threading.Thread(target=self.thread_two_batch_overlap, args=(self.model_input_left_queue,))
self.left_thread.start()
self.right_thread = threading.Thread(target=self.thread_two_batch_overlap, args=(self.model_input_right_queue,))
self.right_thread.start()
logger.info('tbo:two batch overlap start')
def finish_thread(self):
self.left_thread.join()
self.left_thread = None
self.right_thread.join()
self.right_thread = None
@torch.inference_mode()
def thread_two_batch_overlap(self, queue):
is_left_thread = False
tid = threading.get_ident()
if queue == self.model_input_left_queue:
self.left_tid = tid
is_left_thread = True
init_tbo_forward_context(True, self.left_tid)
else:
self.right_tid = tid
init_tbo_forward_context(False, self.right_tid)
with torch.cuda.stream(tbo_step_stream):
queue.get()
profile.ProfRangePush('start')
self.tbo_thread_synchronize(tid)
if is_left_thread:
attn_metadata = self.attn_metadata_left
num_input_tokens = self.num_input_tokens_left
input_ids = self.input_ids_left
positions = self.positions_left
else:
attn_metadata = self.attn_metadata_right
num_input_tokens = self.num_input_tokens_right
input_ids = self.input_ids_right
positions = self.positions_right
model_output = None
# Run the decoder.
# Use persistent buffers for CUDA graphs.
with set_forward_context(attn_metadata,
self.model_runner.vllm_config,
num_tokens=num_input_tokens,
num_tokens_across_dp=self.num_tokens_across_dp):
model_output = self.model_runner.model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=self.intermediate_tensors,
inputs_embeds=self.inputs_embeds,
)
if is_left_thread:
self.sem_right.release()
self.states_left_queue.put(model_output)
else:
self.states_right_queue.put(model_output)
profile.ProfRangePop()
def tbo_thread_synchronize(self, tid):
if tid == self.left_tid:
if not self.left_first:
self.sem_right.release()
self.left_first = False
profile.ProfRangePop()
self.sem_left.acquire()
profile.ProfRangePush('left')
return self.event_left_c2t, self.event_left_t2c
else:
self.sem_left.release()
profile.ProfRangePop()
self.sem_right.acquire()
profile.ProfRangePush('right')
return self.event_right_c2t, self.event_right_t2c
def set_model_input(self,
model_runner,
attn_metadata_left,
attn_metadata_right,
num_input_tokens_left,
num_input_tokens_right,
input_ids_left,
input_ids_right,
positions_left,
positions_right,
num_tokens_across_dp,
intermediate_tensors,
inputs_embeds):
self.model_runner = model_runner
self.attn_metadata_left = attn_metadata_left
self.attn_metadata_right = attn_metadata_right
self.num_input_tokens_left = num_input_tokens_left
self.num_input_tokens_right = num_input_tokens_right
self.input_ids_left = input_ids_left
self.input_ids_right = input_ids_right
self.positions_left = positions_left
self.positions_right = positions_right
self.num_tokens_across_dp = num_tokens_across_dp
self.intermediate_tensors = intermediate_tensors
self.inputs_embeds = inputs_embeds
self.model_input_left_queue.put(None)
self.model_input_right_queue.put(None)
def get_model_output(self):
states_left = self.states_left_queue.get()
states_right = self.states_right_queue.get()
return states_left, states_right
tbo_obj_v1 = None
def is_enable_tbo_v1():
global tbo_obj_v1
return tbo_obj_v1 != None
def init_two_batch_overlap():
global tbo_obj_v1
if tbo_obj_v1 == None:
tbo_obj_v1 = TwoBatchOverlap()
tbo_obj_v1.init_tbo_thread()
def tbo_all_reduce_v1(obj):
if envs.VLLM_ENABLE_TBO and tbo_obj_v1 != None and tbo_obj_v1.tbo_running:
tid = threading.get_ident()
if tid == tbo_obj_v1.left_tid:
event_c2t, event_t2c = tbo_obj_v1.event_left_c2t, tbo_obj_v1.event_left_t2c
else:
event_c2t, event_t2c = tbo_obj_v1.event_right_c2t, tbo_obj_v1.event_right_t2c
event_c2t.record()
with torch.cuda.stream(all_reduce_stream):
all_reduce_stream.wait_event(event_c2t)
output = tensor_model_parallel_all_reduce(obj)
event_t2c.record()
tbo_obj_v1.tbo_thread_synchronize(tid)
tbo_step_stream.wait_event(event_t2c)
return output
return tensor_model_parallel_all_reduce(obj)
def merge_model_output(states_left, states_right):
if isinstance(states_left, IntermediateTensors):
output_map = {}
for key in states_left.tensors:
output_map[key] = torch.concat([states_left.tensors[key], states_right.tensors[key]], dim=0)
output = IntermediateTensors(output_map)
else:
output = torch.concat([states_left, states_right], dim=0)
return output
def tbo_model_executable_v1(
model_runner,
attn_metadata_left,
attn_metadata_right,
num_input_tokens_left,
num_input_tokens_right,
num_tokens_across_dp,
input_ids,
positions,
intermediate_tensors,
inputs_embeds
):
init_two_batch_overlap()
tbo_obj_v1.tbo_running = True
tbo_obj_v1.left_first = True
tbo_obj_v1.step_event.record()
current_stream = torch.cuda.current_stream()
with torch.cuda.stream(tbo_step_stream):
tbo_step_stream.wait_event(tbo_obj_v1.step_event)
tokens_split = [num_input_tokens_left, num_input_tokens_right]
input_ids_left, input_ids_right = torch.split(input_ids, tokens_split, dim=0)
positions_left, positions_right = torch.split(positions, tokens_split, dim=0)
tbo_obj_v1.set_model_input(model_runner,
attn_metadata_left,
attn_metadata_right,
num_input_tokens_left,
num_input_tokens_right,
input_ids_left,
input_ids_right,
positions_left,
positions_right,
num_tokens_across_dp,
intermediate_tensors,
inputs_embeds)
model_output_left, model_output_right = tbo_obj_v1.get_model_output()
hidden_or_intermediate_states = merge_model_output(model_output_left, model_output_right)
tbo_obj_v1.tbo_running = False
tbo_obj_v1.step_event.record()
tbo_obj_v1.finish_thread()
current_stream.wait_event(tbo_obj_v1.step_event)
return hidden_or_intermediate_states
\ No newline at end of file
...@@ -22,6 +22,7 @@ from vllm.lora.request import LoRARequest ...@@ -22,6 +22,7 @@ from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.two_batch_overlap.v1.gpu_model_runner import TBO_GPUModelRunner
from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.outputs import ModelRunnerOutput
...@@ -162,6 +163,10 @@ class Worker(WorkerBase): ...@@ -162,6 +163,10 @@ class Worker(WorkerBase):
set_random_seed(self.model_config.seed) set_random_seed(self.model_config.seed)
# Construct the model runner # Construct the model runner
if envs.VLLM_ENABLE_TBO:
self.model_runner: TBO_GPUModelRunner = TBO_GPUModelRunner(
self.vllm_config, self.device)
else:
self.model_runner: GPUModelRunner = GPUModelRunner( self.model_runner: GPUModelRunner = GPUModelRunner(
self.vllm_config, self.device) self.vllm_config, self.device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment