Commit a810671a authored by zhuwenwen
Browse files

Merge tag 'v0.14.0rc0' into v0.14.0rc0-ori

parents 86b5aefe 6a09612b
...@@ -20,8 +20,8 @@ from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput ...@@ -20,8 +20,8 @@ from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
if TYPE_CHECKING: if TYPE_CHECKING:
import outlines_core as oc import outlines_core as oc
import transformers.convert_slow_tokenizer as convert_slow_tokenizer
import transformers.file_utils as file_utils import transformers.file_utils as file_utils
import transformers.models.gpt2.tokenization_gpt2 as tokenization_gpt2
import xgrammar as xgr import xgrammar as xgr
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
...@@ -30,10 +30,8 @@ else: ...@@ -30,10 +30,8 @@ else:
xgr = LazyLoader("xgr", globals(), "xgrammar") xgr = LazyLoader("xgr", globals(), "xgrammar")
oc = LazyLoader("oc", globals(), "outlines_core") oc = LazyLoader("oc", globals(), "outlines_core")
file_utils = LazyLoader("file_utils", globals(), "transformers.file_utils") file_utils = LazyLoader("file_utils", globals(), "transformers.file_utils")
tokenization_gpt2 = LazyLoader( convert_slow_tokenizer = LazyLoader(
"tokenization_gpt2", "convert_slow_tokenizer", globals(), "transformers.convert_slow_tokenizer"
globals(),
"transformers.models.gpt2.tokenization_gpt2",
) )
TokenizerLike = object TokenizerLike = object
...@@ -204,7 +202,9 @@ def _reduced_vocabulary( ...@@ -204,7 +202,9 @@ def _reduced_vocabulary(
A Dict of token string -> equivalent token ids A Dict of token string -> equivalent token ids
""" """
unicode_to_bytes = {v: k for k, v in tokenization_gpt2.bytes_to_unicode().items()} unicode_to_bytes = {
v: k for k, v in convert_slow_tokenizer.bytes_to_unicode().items()
}
def convert_token_to_string(token: str) -> str: def convert_token_to_string(token: str) -> str:
string = tokenizer.convert_tokens_to_string([token]) string = tokenizer.convert_tokens_to_string([token])
......
...@@ -11,7 +11,7 @@ from vllm.distributed.parallel_state import get_dp_group ...@@ -11,7 +11,7 @@ from vllm.distributed.parallel_state import get_dp_group
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.worker.ubatch_utils import ( from vllm.v1.worker.ubatch_utils import (
check_ubatch_thresholds, check_ubatch_thresholds,
is_second_ubatch_empty, is_last_ubatch_empty,
) )
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -56,7 +56,7 @@ def _run_ar( ...@@ -56,7 +56,7 @@ def _run_ar(
return tensor return tensor
def _post_process_ubatch(tensor: torch.Tensor) -> bool: def _post_process_ubatch(tensor: torch.Tensor, num_ubatches: int) -> bool:
orig_num_tokens_tensor = tensor[0, :] orig_num_tokens_tensor = tensor[0, :]
padded_num_tokens_tensor = tensor[1, :] padded_num_tokens_tensor = tensor[1, :]
...@@ -68,7 +68,7 @@ def _post_process_ubatch(tensor: torch.Tensor) -> bool: ...@@ -68,7 +68,7 @@ def _post_process_ubatch(tensor: torch.Tensor) -> bool:
# there are no "empty" second ubatches # there are no "empty" second ubatches
orig_min_num_tokens = int(orig_num_tokens_tensor.min().item()) orig_min_num_tokens = int(orig_num_tokens_tensor.min().item())
padded_max_num_tokens = int(padded_num_tokens_tensor.max().item()) padded_max_num_tokens = int(padded_num_tokens_tensor.max().item())
if is_second_ubatch_empty(orig_min_num_tokens, padded_max_num_tokens): if is_last_ubatch_empty(orig_min_num_tokens, padded_max_num_tokens, num_ubatches):
logger.debug( logger.debug(
"Aborting ubatching %s %s", orig_min_num_tokens, padded_max_num_tokens "Aborting ubatching %s %s", orig_min_num_tokens, padded_max_num_tokens
) )
...@@ -146,7 +146,7 @@ def _synchronize_dp_ranks( ...@@ -146,7 +146,7 @@ def _synchronize_dp_ranks(
assert should_attempt_dp_padding == should_dp_pad assert should_attempt_dp_padding == should_dp_pad
# Check conditions for microbatching # Check conditions for microbatching
should_ubatch = _post_process_ubatch(tensor) should_ubatch = _post_process_ubatch(tensor, parallel_config.num_ubatches)
if should_ubatch and not should_dp_pad: if should_ubatch and not should_dp_pad:
logger.debug_once( logger.debug_once(
......
...@@ -128,7 +128,6 @@ class InputBatch: ...@@ -128,7 +128,6 @@ class InputBatch:
# allocation if max_model_len is big. # allocation if max_model_len is big.
# Maps req_index -> tensor of shape (num_prompt_tokens, hidden_size) # Maps req_index -> tensor of shape (num_prompt_tokens, hidden_size)
self.req_prompt_embeds: dict[int, torch.Tensor] = {} self.req_prompt_embeds: dict[int, torch.Tensor] = {}
self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32)
self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32) self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32) self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
self.num_computed_tokens_cpu_tensor = torch.zeros( self.num_computed_tokens_cpu_tensor = torch.zeros(
...@@ -340,9 +339,6 @@ class InputBatch: ...@@ -340,9 +339,6 @@ class InputBatch:
self.req_prompt_embeds[req_index] = request.prompt_embeds self.req_prompt_embeds[req_index] = request.prompt_embeds
self.token_ids_cpu[req_index, start_idx:end_idx] = request.output_token_ids self.token_ids_cpu[req_index, start_idx:end_idx] = request.output_token_ids
self.is_token_ids[req_index, start_idx:end_idx] = True self.is_token_ids[req_index, start_idx:end_idx] = True
# Number of token ids in prompt (token_ids_cpu or prompt_embeds).
# NOTE(woosuk): This may include spec decode tokens.
self.num_tokens[req_index] = request.num_tokens
# Number of tokens without spec decode tokens. # Number of tokens without spec decode tokens.
self.num_tokens_no_spec[req_index] = request.num_tokens self.num_tokens_no_spec[req_index] = request.num_tokens
...@@ -522,10 +518,6 @@ class InputBatch: ...@@ -522,10 +518,6 @@ class InputBatch:
self.req_id_to_index[old_id_i2], self.req_id_to_index[old_id_i2],
self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i1],
) )
self.num_tokens[i1], self.num_tokens[i2] = (
self.num_tokens[i2],
self.num_tokens[i1],
)
self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] = ( self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] = (
self.num_tokens_no_spec[i2], self.num_tokens_no_spec[i2],
self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i1],
...@@ -661,17 +653,16 @@ class InputBatch: ...@@ -661,17 +653,16 @@ class InputBatch:
self.req_output_token_ids[last_req_index] = None self.req_output_token_ids[last_req_index] = None
self.req_id_to_index[req_id] = empty_index self.req_id_to_index[req_id] = empty_index
if last_req_index != empty_index: num_tokens = self.num_tokens_no_spec[last_req_index] + len(
( self.spec_token_ids[last_req_index]
self.spec_token_ids[last_req_index], )
self.spec_token_ids[empty_index],
) = ( (self.spec_token_ids[last_req_index], self.spec_token_ids[empty_index]) = (
self.spec_token_ids[empty_index], self.spec_token_ids[empty_index],
self.spec_token_ids[last_req_index], self.spec_token_ids[last_req_index],
) )
self.spec_token_ids[last_req_index].clear() self.spec_token_ids[last_req_index].clear()
num_tokens = self.num_tokens[last_req_index]
self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[ self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
last_req_index, :num_tokens last_req_index, :num_tokens
] ]
...@@ -682,7 +673,6 @@ class InputBatch: ...@@ -682,7 +673,6 @@ class InputBatch:
self.req_prompt_embeds[empty_index] = self.req_prompt_embeds.pop( self.req_prompt_embeds[empty_index] = self.req_prompt_embeds.pop(
last_req_index last_req_index
) )
self.num_tokens[empty_index] = num_tokens
self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[ self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[
last_req_index last_req_index
] ]
......
...@@ -923,7 +923,6 @@ class GPUModelRunner( ...@@ -923,7 +923,6 @@ class GPUModelRunner(
self.input_batch.num_prompt_tokens[req_index] self.input_batch.num_prompt_tokens[req_index]
+ num_output_tokens + num_output_tokens
) )
self.input_batch.num_tokens[req_index] = end_idx
self.input_batch.num_tokens_no_spec[req_index] = end_idx self.input_batch.num_tokens_no_spec[req_index] = end_idx
# Update the block IDs. # Update the block IDs.
...@@ -968,7 +967,6 @@ class GPUModelRunner( ...@@ -968,7 +967,6 @@ class GPUModelRunner(
req_index, start_token_index:end_token_index req_index, start_token_index:end_token_index
] = new_token_ids ] = new_token_ids
self.input_batch.num_tokens_no_spec[req_index] = end_token_index self.input_batch.num_tokens_no_spec[req_index] = end_token_index
self.input_batch.num_tokens[req_index] = end_token_index
# Add spec_token_ids to token_ids_cpu. # Add spec_token_ids to token_ids_cpu.
spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get( spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
...@@ -984,8 +982,6 @@ class GPUModelRunner( ...@@ -984,8 +982,6 @@ class GPUModelRunner(
self.input_batch.token_ids_cpu[ self.input_batch.token_ids_cpu[
req_index, start_index:end_token_index req_index, start_index:end_token_index
] = spec_token_ids ] = spec_token_ids
# NOTE(woosuk): `num_tokens` here may include spec tokens.
self.input_batch.num_tokens[req_index] += num_spec_tokens
# When speculative decoding is used with structured output, # When speculative decoding is used with structured output,
# the scheduler can drop draft tokens that do not # the scheduler can drop draft tokens that do not
...@@ -1628,6 +1624,15 @@ class GPUModelRunner( ...@@ -1628,6 +1624,15 @@ class GPUModelRunner(
logits_indices logits_indices
) )
# Cache attention metadata builds across hybrid KV-cache groups.
# When the metadata builder type and the KVCacheSpec are the same, the only thing
# that differs between hybrid KV-cache groups is the block table, so we can cache
# the built attention metadata and just update the block table using
# `builder.update_block_table` if the builder supports it.
cached_attn_metadata: dict[
tuple[KVCacheSpec, type[AttentionMetadataBuilder]], AttentionMetadata
] = {}
def _build_attn_group_metadata( def _build_attn_group_metadata(
kv_cache_gid: int, kv_cache_gid: int,
attn_gid: int, attn_gid: int,
...@@ -1635,13 +1640,15 @@ class GPUModelRunner( ...@@ -1635,13 +1640,15 @@ class GPUModelRunner(
ubid: int | None = None, ubid: int | None = None,
) -> None: ) -> None:
attn_group = self.attn_groups[kv_cache_gid][attn_gid] attn_group = self.attn_groups[kv_cache_gid][attn_gid]
builder = attn_group.get_metadata_builder(ubid or 0)
cache_key = (kv_cache_groups[kv_cache_gid].kv_cache_spec, type(builder))
cascade_attn_prefix_len = ( cascade_attn_prefix_len = (
cascade_attn_prefix_lens[kv_cache_gid][attn_gid] cascade_attn_prefix_lens[kv_cache_gid][attn_gid]
if cascade_attn_prefix_lens if cascade_attn_prefix_lens
else 0 else 0
) )
builder = attn_group.get_metadata_builder(ubid or 0)
extra_attn_metadata_args = {} extra_attn_metadata_args = {}
if use_spec_decode and isinstance(builder, GDNAttentionMetadataBuilder): if use_spec_decode and isinstance(builder, GDNAttentionMetadataBuilder):
assert ubid is None, "UBatching not supported with GDN yet" assert ubid is None, "UBatching not supported with GDN yet"
...@@ -1656,12 +1663,23 @@ class GPUModelRunner( ...@@ -1656,12 +1663,23 @@ class GPUModelRunner(
attn_metadata_i = builder.build_for_cudagraph_capture( attn_metadata_i = builder.build_for_cudagraph_capture(
common_attn_metadata common_attn_metadata
) )
elif (
cache_key in cached_attn_metadata
and builder.supports_update_block_table
):
attn_metadata_i = builder.update_block_table(
cached_attn_metadata[cache_key],
common_attn_metadata.block_table_tensor,
common_attn_metadata.slot_mapping,
)
else: else:
attn_metadata_i = builder.build( attn_metadata_i = builder.build(
common_prefix_len=cascade_attn_prefix_len, common_prefix_len=cascade_attn_prefix_len,
common_attn_metadata=common_attn_metadata, common_attn_metadata=common_attn_metadata,
**extra_attn_metadata_args, **extra_attn_metadata_args,
) )
if builder.supports_update_block_table:
cached_attn_metadata[cache_key] = attn_metadata_i
if ubid is None: if ubid is None:
assert isinstance(attn_metadata, dict) assert isinstance(attn_metadata, dict)
...@@ -2680,7 +2698,6 @@ class GPUModelRunner( ...@@ -2680,7 +2698,6 @@ class GPUModelRunner(
self.input_batch.token_ids_cpu[req_idx, start_idx:end_idx] = sampled_ids self.input_batch.token_ids_cpu[req_idx, start_idx:end_idx] = sampled_ids
self.input_batch.is_token_ids[req_idx, start_idx:end_idx] = True self.input_batch.is_token_ids[req_idx, start_idx:end_idx] = True
self.input_batch.num_tokens_no_spec[req_idx] = end_idx self.input_batch.num_tokens_no_spec[req_idx] = end_idx
self.input_batch.num_tokens[req_idx] = end_idx
req_id = req_ids[req_idx] req_id = req_ids[req_idx]
req_state = self.requests[req_id] req_state = self.requests[req_id]
...@@ -2755,6 +2772,27 @@ class GPUModelRunner( ...@@ -2755,6 +2772,27 @@ class GPUModelRunner(
**model_kwargs, **model_kwargs,
) )
@staticmethod
def _is_uniform_decode(
max_num_scheduled_tokens: int,
uniform_decode_query_len: int,
num_tokens: int,
num_reqs: int,
force_uniform_decode: bool | None = None,
) -> bool:
"""
Checks if it's a decode batch with same amount scheduled tokens
across all requests.
"""
return (
(
(max_num_scheduled_tokens == uniform_decode_query_len)
and (num_tokens == max_num_scheduled_tokens * num_reqs)
)
if force_uniform_decode is None
else force_uniform_decode
)
def _determine_batch_execution_and_padding( def _determine_batch_execution_and_padding(
self, self,
num_tokens: int, num_tokens: int,
...@@ -2776,14 +2814,12 @@ class GPUModelRunner( ...@@ -2776,14 +2814,12 @@ class GPUModelRunner(
torch.Tensor | None, torch.Tensor | None,
CUDAGraphStat | None, CUDAGraphStat | None,
]: ]:
num_tokens_padded = self._pad_for_sequence_parallelism(num_tokens) uniform_decode = self._is_uniform_decode(
uniform_decode = ( max_num_scheduled_tokens=max_num_scheduled_tokens,
( uniform_decode_query_len=self.uniform_decode_query_len,
(max_num_scheduled_tokens == self.uniform_decode_query_len) num_tokens=num_tokens,
and (num_tokens_padded == max_num_scheduled_tokens * num_reqs) num_reqs=num_reqs,
) force_uniform_decode=force_uniform_decode,
if force_uniform_decode is None
else force_uniform_decode
) )
# Encoder-decoder models only support CG for decoder_step > 0 (no enc_output # Encoder-decoder models only support CG for decoder_step > 0 (no enc_output
# is present). Also, chunked-prefill is disabled, so batch are uniform. # is present). Also, chunked-prefill is disabled, so batch are uniform.
...@@ -2797,6 +2833,7 @@ class GPUModelRunner( ...@@ -2797,6 +2833,7 @@ class GPUModelRunner(
else force_has_lora else force_has_lora
) )
num_tokens_padded = self._pad_for_sequence_parallelism(num_tokens)
dispatch_cudagraph = ( dispatch_cudagraph = (
lambda num_tokens, disable_full: self.cudagraph_dispatcher.dispatch( lambda num_tokens, disable_full: self.cudagraph_dispatcher.dispatch(
num_tokens=num_tokens, num_tokens=num_tokens,
...@@ -2812,6 +2849,15 @@ class GPUModelRunner( ...@@ -2812,6 +2849,15 @@ class GPUModelRunner(
num_tokens_padded, use_cascade_attn or has_encoder_output num_tokens_padded, use_cascade_attn or has_encoder_output
) )
num_tokens_padded = batch_descriptor.num_tokens num_tokens_padded = batch_descriptor.num_tokens
if self.compilation_config.pass_config.enable_sp:
assert (
batch_descriptor.num_tokens
% self.vllm_config.parallel_config.tensor_parallel_size
== 0
), (
"Sequence parallelism requires num_tokens to be "
"a multiple of tensor parallel size"
)
# Extra coordination when running data-parallel since we need to coordinate # Extra coordination when running data-parallel since we need to coordinate
# across ranks # across ranks
...@@ -2987,7 +3033,7 @@ class GPUModelRunner( ...@@ -2987,7 +3033,7 @@ class GPUModelRunner(
cascade_attn_prefix_lens = None cascade_attn_prefix_lens = None
# Disable cascade attention when using microbatching (DBO) # Disable cascade attention when using microbatching (DBO)
if self.cascade_attn_enabled and not self.parallel_config.enable_dbo: if self.cascade_attn_enabled and not self.parallel_config.use_ubatching:
# Pre-compute cascade attention prefix lengths # Pre-compute cascade attention prefix lengths
cascade_attn_prefix_lens = self._compute_cascade_attn_prefix_lens( cascade_attn_prefix_lens = self._compute_cascade_attn_prefix_lens(
num_scheduled_tokens_np, num_scheduled_tokens_np,
...@@ -3028,6 +3074,13 @@ class GPUModelRunner( ...@@ -3028,6 +3074,13 @@ class GPUModelRunner(
num_scheduled_tokens_np, num_scheduled_tokens_np,
num_tokens_padded, num_tokens_padded,
num_reqs_padded, num_reqs_padded,
self.parallel_config.num_ubatches,
)
logger.debug(
"ubatch_slices: %s, ubatch_slices_padded: %s",
ubatch_slices,
ubatch_slices_padded,
) )
pad_attn = cudagraph_mode == CUDAGraphMode.FULL pad_attn = cudagraph_mode == CUDAGraphMode.FULL
...@@ -3340,9 +3393,13 @@ class GPUModelRunner( ...@@ -3340,9 +3393,13 @@ class GPUModelRunner(
return async_output return async_output
def take_draft_token_ids(self) -> DraftTokenIds | None: def take_draft_token_ids(self) -> DraftTokenIds | None:
if self._draft_token_ids is None: if not self.num_spec_tokens:
return None return None
req_ids = self.input_batch.req_ids req_ids = self.input_batch.req_ids
if self._draft_token_ids is None:
return DraftTokenIds(req_ids, [[] for _ in req_ids])
if isinstance(self._draft_token_ids, torch.Tensor): if isinstance(self._draft_token_ids, torch.Tensor):
draft_token_ids = self._draft_token_ids.tolist() draft_token_ids = self._draft_token_ids.tolist()
else: else:
...@@ -3710,11 +3767,14 @@ class GPUModelRunner( ...@@ -3710,11 +3767,14 @@ class GPUModelRunner(
# wrap the model with full cudagraph wrapper if needed. # wrap the model with full cudagraph wrapper if needed.
cudagraph_mode = self.compilation_config.cudagraph_mode cudagraph_mode = self.compilation_config.cudagraph_mode
assert cudagraph_mode is not None assert cudagraph_mode is not None
if cudagraph_mode.has_full_cudagraphs() and not self.parallel_config.enable_dbo: if (
cudagraph_mode.has_full_cudagraphs()
and not self.parallel_config.use_ubatching
):
self.model = CUDAGraphWrapper( self.model = CUDAGraphWrapper(
self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
) )
elif self.parallel_config.enable_dbo: elif self.parallel_config.use_ubatching:
if cudagraph_mode.has_full_cudagraphs(): if cudagraph_mode.has_full_cudagraphs():
self.model = UBatchWrapper( self.model = UBatchWrapper(
self.model, self.vllm_config, CUDAGraphMode.FULL, self.device self.model, self.vllm_config, CUDAGraphMode.FULL, self.device
...@@ -4095,7 +4155,16 @@ class GPUModelRunner( ...@@ -4095,7 +4155,16 @@ class GPUModelRunner(
batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs
) )
ubatch_slices, ubatch_slices_padded = maybe_create_ubatch_slices( ubatch_slices, ubatch_slices_padded = maybe_create_ubatch_slices(
should_ubatch, num_scheduled_tokens, num_tokens_padded, num_reqs_padded should_ubatch,
num_scheduled_tokens,
num_tokens_padded,
num_reqs_padded,
self.vllm_config.parallel_config.num_ubatches,
)
logger.debug(
"ubatch_slices: %s, ubatch_slices_padded: %s",
ubatch_slices,
ubatch_slices_padded,
) )
attn_metadata: PerLayerAttnMetadata | None = None attn_metadata: PerLayerAttnMetadata | None = None
...@@ -4621,7 +4690,7 @@ class GPUModelRunner( ...@@ -4621,7 +4690,7 @@ class GPUModelRunner(
# is above the threshold. Otherwise we just capture a non-ubatched # is above the threshold. Otherwise we just capture a non-ubatched
# version of the graph # version of the graph
allow_microbatching = ( allow_microbatching = (
self.parallel_config.enable_dbo self.parallel_config.use_ubatching
and cudagraph_runtime_mode == CUDAGraphMode.FULL and cudagraph_runtime_mode == CUDAGraphMode.FULL
and uniform_decode and uniform_decode
and check_ubatch_thresholds( and check_ubatch_thresholds(
...@@ -4756,8 +4825,8 @@ class GPUModelRunner( ...@@ -4756,8 +4825,8 @@ class GPUModelRunner(
if kv_cache_group_id < len(kernel_block_sizes) if kv_cache_group_id < len(kernel_block_sizes)
else None, else None,
num_metadata_builders=1 num_metadata_builders=1
if not self.parallel_config.enable_dbo if not self.parallel_config.use_ubatching
else 2, else self.parallel_config.num_ubatches,
) )
# Calculate reorder batch threshold (if needed) # Calculate reorder batch threshold (if needed)
# Note (tdoublep): do this *after* constructing builders, # Note (tdoublep): do this *after* constructing builders,
......
...@@ -103,8 +103,10 @@ class UBatchWrapper: ...@@ -103,8 +103,10 @@ class UBatchWrapper:
self.vllm_config = vllm_config self.vllm_config = vllm_config
self.compilation_config = vllm_config.compilation_config self.compilation_config = vllm_config.compilation_config
self.comm_stream = torch.cuda.Stream(device=device) self.comm_stream = torch.cuda.Stream(device=device)
# Two ubatch threads plus the main thread # Ubatch threads plus the main thread
self.ready_barrier = threading.Barrier(3) self.ready_barrier = threading.Barrier(
self.vllm_config.parallel_config.num_ubatches + 1
)
self.cudagraphs: dict[int, CUDAGraphMetaData] = {} self.cudagraphs: dict[int, CUDAGraphMetaData] = {}
...@@ -309,7 +311,7 @@ class UBatchWrapper: ...@@ -309,7 +311,7 @@ class UBatchWrapper:
create_forward_context( create_forward_context(
attn_metadata[i] if attn_metadata is not None else None, attn_metadata[i] if attn_metadata is not None else None,
self.vllm_config, self.vllm_config,
dp_metadata=dp_metadata, dp_metadata=dp_metadata[i],
batch_descriptor=batch_descriptor, batch_descriptor=batch_descriptor,
cudagraph_runtime_mode=cudagraph_runtime_mode, cudagraph_runtime_mode=cudagraph_runtime_mode,
) )
...@@ -417,18 +419,19 @@ class UBatchWrapper: ...@@ -417,18 +419,19 @@ class UBatchWrapper:
# We shouldn't be here unless we are running with multiple DP ranks # We shouldn't be here unless we are running with multiple DP ranks
assert dp_metadata is not None assert dp_metadata is not None
num_tokens_per_ubatch = ( ubatch_dp_metadata = []
ubatch_slices[0].token_slice.stop - ubatch_slices[0].token_slice.start for ubatch_slice in ubatch_slices:
) dp_size = self.vllm_config.parallel_config.data_parallel_size
dp_size = self.vllm_config.parallel_config.data_parallel_size ubatch_num_tokens_across_dp = torch.tensor(
ubatch_num_tokens_across_dp = torch.tensor( [ubatch_slice.num_tokens] * dp_size, device="cpu", dtype=torch.int32
[num_tokens_per_ubatch] * dp_size, device="cpu", dtype=torch.int32 )
) ubatch_dp_metadata.append(
ubatch_dp_metadata = DPMetadata.make( DPMetadata.make(
self.vllm_config.parallel_config, self.vllm_config.parallel_config,
num_tokens_per_ubatch, ubatch_slice.num_tokens,
ubatch_num_tokens_across_dp, ubatch_num_tokens_across_dp,
) )
)
if ( if (
num_tokens not in self.cudagraphs num_tokens not in self.cudagraphs
...@@ -464,7 +467,7 @@ class UBatchWrapper: ...@@ -464,7 +467,7 @@ class UBatchWrapper:
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
compute_stream=compute_stream, compute_stream=compute_stream,
dp_metadata=dp_metadata, dp_metadata=ubatch_dp_metadata,
batch_descriptor=batch_descriptor, batch_descriptor=batch_descriptor,
cudagraph_runtime_mode=CUDAGraphMode.NONE, cudagraph_runtime_mode=CUDAGraphMode.NONE,
) )
......
...@@ -56,6 +56,8 @@ from vllm.v1.worker.utils import is_residual_scattered_for_sp ...@@ -56,6 +56,8 @@ from vllm.v1.worker.utils import is_residual_scattered_for_sp
from vllm.v1.worker.worker_base import WorkerBase from vllm.v1.worker.worker_base import WorkerBase
from vllm.v1.worker.workspace import init_workspace_manager from vllm.v1.worker.workspace import init_workspace_manager
from .utils import request_memory
logger = init_logger(__name__) logger = init_logger(__name__)
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -237,22 +239,8 @@ class Worker(WorkerBase): ...@@ -237,22 +239,8 @@ class Worker(WorkerBase):
torch.cuda.empty_cache() torch.cuda.empty_cache()
# take current memory snapshot # take current memory snapshot
self.init_snapshot = MemorySnapshot() self.init_snapshot = init_snapshot = MemorySnapshot(device=self.device)
self.requested_memory = ( self.requested_memory = request_memory(init_snapshot, self.cache_config)
self.init_snapshot.total_memory
* self.cache_config.gpu_memory_utilization
)
if self.init_snapshot.free_memory < self.requested_memory:
GiB = lambda b: round(b / GiB_bytes, 2)
raise ValueError(
f"Free memory on device "
f"({GiB(self.init_snapshot.free_memory)}/"
f"{GiB(self.init_snapshot.total_memory)} GiB) on startup "
f"is less than desired GPU memory utilization "
f"({self.cache_config.gpu_memory_utilization}, "
f"{GiB(self.requested_memory)} GiB). Decrease GPU memory "
f"utilization or reduce GPU memory used by other processes."
)
else: else:
raise RuntimeError(f"Not support device type: {self.device_config.device}") raise RuntimeError(f"Not support device type: {self.device_config.device}")
......
...@@ -51,7 +51,6 @@ class InputBatch: ...@@ -51,7 +51,6 @@ class InputBatch:
pin_memory=False, pin_memory=False,
) )
self.token_ids_cpu = self.token_ids_cpu_tensor.numpy() self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32)
self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32) self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32) self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
self.num_computed_tokens_cpu_tensor = torch.zeros( self.num_computed_tokens_cpu_tensor = torch.zeros(
...@@ -200,9 +199,6 @@ class InputBatch: ...@@ -200,9 +199,6 @@ class InputBatch:
start_idx = num_prompt_tokens start_idx = num_prompt_tokens
end_idx = start_idx + len(request.output_token_ids) end_idx = start_idx + len(request.output_token_ids)
self.token_ids_cpu[req_index, start_idx:end_idx] = request.output_token_ids self.token_ids_cpu[req_index, start_idx:end_idx] = request.output_token_ids
# Number of token ids in token_ids_cpu.
# NOTE(woosuk): This may include spec decode tokens.
self.num_tokens[req_index] = request.num_tokens
# Number of tokens without spec decode tokens. # Number of tokens without spec decode tokens.
self.num_tokens_no_spec[req_index] = request.num_tokens self.num_tokens_no_spec[req_index] = request.num_tokens
...@@ -344,10 +340,6 @@ class InputBatch: ...@@ -344,10 +340,6 @@ class InputBatch:
self.req_id_to_index[old_id_i2], self.req_id_to_index[old_id_i2],
self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i1],
) )
self.num_tokens[i1], self.num_tokens[i2] = (
self.num_tokens[i2],
self.num_tokens[i1],
)
self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] = ( self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] = (
self.num_tokens_no_spec[i2], self.num_tokens_no_spec[i2],
self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i1],
...@@ -448,11 +440,10 @@ class InputBatch: ...@@ -448,11 +440,10 @@ class InputBatch:
self.req_output_token_ids[last_req_index] = None self.req_output_token_ids[last_req_index] = None
self.req_id_to_index[req_id] = empty_index self.req_id_to_index[req_id] = empty_index
num_tokens = self.num_tokens[last_req_index] num_tokens = self.num_tokens_no_spec[last_req_index]
self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[ self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
last_req_index, :num_tokens last_req_index, :num_tokens
] ]
self.num_tokens[empty_index] = num_tokens
self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[ self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[
last_req_index last_req_index
] ]
......
...@@ -1283,7 +1283,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -1283,7 +1283,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
token_id = valid_sampled_token_ids[i][0] token_id = valid_sampled_token_ids[i][0]
self.input_batch.token_ids_cpu[i, seq_len] = token_id self.input_batch.token_ids_cpu[i, seq_len] = token_id
req_state.output_token_ids.append(token_id) req_state.output_token_ids.append(token_id)
self.input_batch.num_tokens[i] += 1 self.input_batch.num_tokens_no_spec[i] += 1
else: else:
valid_mask = selected_token_ids != INVALID_TOKEN_ID valid_mask = selected_token_ids != INVALID_TOKEN_ID
...@@ -1291,7 +1291,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -1291,7 +1291,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
valid_sampled_token_ids = [ valid_sampled_token_ids = [
seq.tolist() for seq in selected_token_ids[valid_mask].split(gen_lens) seq.tolist() for seq in selected_token_ids[valid_mask].split(gen_lens)
] ]
self.input_batch.num_tokens[:num_reqs] += gen_lens self.input_batch.num_tokens_no_spec[:num_reqs] += gen_lens
for i, req_state, seq_len in request_seq_lens: for i, req_state, seq_len in request_seq_lens:
target_slice = slice(seq_len - gen_lens[i] + 1, seq_len + 1) target_slice = slice(seq_len - gen_lens[i] + 1, seq_len + 1)
self.input_batch.token_ids_cpu[i, target_slice] = ( self.input_batch.token_ids_cpu[i, target_slice] = (
......
...@@ -27,14 +27,16 @@ class UBatchSlice: ...@@ -27,14 +27,16 @@ class UBatchSlice:
UBatchSlices: TypeAlias = list[UBatchSlice] UBatchSlices: TypeAlias = list[UBatchSlice]
def is_second_ubatch_empty(orig_num_tokens: int, padded_num_tokens: int) -> bool: def is_last_ubatch_empty(
return (padded_num_tokens // 2) >= orig_num_tokens orig_num_tokens: int, padded_num_tokens: int, num_ubatches: int
) -> bool:
return (padded_num_tokens // num_ubatches) * (num_ubatches - 1) >= orig_num_tokens
def check_ubatch_thresholds( def check_ubatch_thresholds(
config: ParallelConfig, num_tokens: int, uniform_decode: bool config: ParallelConfig, num_tokens: int, uniform_decode: bool
) -> bool: ) -> bool:
if not config.enable_dbo: if not config.use_ubatching:
return False return False
if uniform_decode: if uniform_decode:
return num_tokens >= config.dbo_decode_token_threshold return num_tokens >= config.dbo_decode_token_threshold
...@@ -42,21 +44,17 @@ def check_ubatch_thresholds( ...@@ -42,21 +44,17 @@ def check_ubatch_thresholds(
return num_tokens >= config.dbo_prefill_token_threshold return num_tokens >= config.dbo_prefill_token_threshold
# This pads the last ubatch slice out to the total number of tokens
# (num_tokens + padding) since we do `create_ubatch_slices` before applying DP padding.
def _pad_out_ubatch_slices(
    ubatch_slices: UBatchSlices, num_total_tokens: int, num_reqs_padded: int
) -> UBatchSlices:
    """Return a copy of ``ubatch_slices`` whose last entry is stretched to the padded sizes.

    Only the stop of the last slice changes: its request slice is extended to
    ``num_reqs_padded`` and its token slice to ``num_total_tokens``. All
    preceding slices are carried over unchanged, and the input list itself is
    not modified.

    Args:
        ubatch_slices: Unpadded micro-batch slices; the last element is read,
            so the list must not be empty.
        num_total_tokens: New stop for the last token slice.
        num_reqs_padded: New stop for the last request slice.

    Returns:
        A new list with the same leading slices and a rebuilt final slice.
    """
    last_slice = ubatch_slices[-1]
    # Keep the original starts; only the stops grow to absorb the padding.
    padded_last_request_slice = slice(last_slice.request_slice.start, num_reqs_padded)
    padded_last_token_slice = slice(last_slice.token_slice.start, num_total_tokens)

    return ubatch_slices[:-1] + [
        UBatchSlice(padded_last_request_slice, padded_last_token_slice)
    ]
...@@ -65,40 +63,45 @@ def maybe_create_ubatch_slices( ...@@ -65,40 +63,45 @@ def maybe_create_ubatch_slices(
num_scheduled_tokens: np.ndarray, num_scheduled_tokens: np.ndarray,
num_tokens_padded: int, num_tokens_padded: int,
num_reqs_padded: int, num_reqs_padded: int,
split_point: int | None = None, num_ubatches: int,
split_point: list[int] | int | None = None,
) -> tuple[UBatchSlices | None, UBatchSlices | None]: ) -> tuple[UBatchSlices | None, UBatchSlices | None]:
if not should_ubatch: if not should_ubatch:
return None, None return None, None
if split_point is None: if split_point is None:
split_point = int(num_tokens_padded) // 2 split_point = int(num_tokens_padded) // num_ubatches
token_split_points = [split_point * i for i in range(1, num_ubatches)]
# TODO(lucas): Refactor the gpu_model_runner.py so we can pass # TODO(lucas): Refactor the gpu_model_runner.py so we can pass
# in cu_num_tokens directly (i.e. query_start_loc) # in cu_num_tokens directly (i.e. query_start_loc)
cu_num_tokens = np.zeros(len(num_scheduled_tokens) + 1, dtype=np.int32) cu_num_tokens = np.zeros(len(num_scheduled_tokens) + 1, dtype=np.int32)
np.cumsum(num_scheduled_tokens, dtype=np.int32, out=cu_num_tokens[1:]) np.cumsum(num_scheduled_tokens, dtype=np.int32, out=cu_num_tokens[1:])
first_ubatch_token_slice = slice(0, split_point) ubatch_slices = []
second_ubatch_token_slice = slice(split_point, cu_num_tokens[-1]) start_token = 0
# Determine request slices using exclusive stop semantics # Add the end point to the split points to make iteration easier
# First ubatch includes requests whose tokens overlap [0, split_point) all_points = token_split_points + [cu_num_tokens[-1]]
first_ubatch_req_stop = int(
np.searchsorted(cu_num_tokens, split_point, side="left")
)
first_ubatch_req_slice = slice(0, first_ubatch_req_stop)
# Second ubatch starts at the request that contains the split_point for end_token in all_points:
# or the request starting exactly at split_point (if on boundary) token_slice = slice(start_token, end_token)
second_ubatch_req_start = int(
np.searchsorted(cu_num_tokens, split_point, side="right") - 1
)
second_ubatch_req_slice = slice(second_ubatch_req_start, len(cu_num_tokens) - 1)
ubatch_slices = [ # Determine request slices using exclusive stop semantics
UBatchSlice(first_ubatch_req_slice, first_ubatch_token_slice), # Ubatch includes requests whose tokens overlap [start_token, end_token)
UBatchSlice(second_ubatch_req_slice, second_ubatch_token_slice),
] # Start at the request that contains the start_token
# or the request starting exactly at start_token (if on boundary)
req_start = int(np.searchsorted(cu_num_tokens, start_token, side="right") - 1)
# Stop at the request that starts at or after end_token
req_stop = int(np.searchsorted(cu_num_tokens, end_token, side="left"))
req_slice = slice(req_start, req_stop)
ubatch_slices.append(UBatchSlice(req_slice, token_slice))
start_token = end_token
ubatch_slices_padded = _pad_out_ubatch_slices( ubatch_slices_padded = _pad_out_ubatch_slices(
ubatch_slices, num_tokens_padded, num_reqs_padded ubatch_slices, num_tokens_padded, num_reqs_padded
......
...@@ -7,10 +7,15 @@ import torch ...@@ -7,10 +7,15 @@ import torch
from vllm import forward_context from vllm import forward_context
from vllm.forward_context import ForwardContext from vllm.forward_context import ForwardContext
from vllm.logger import init_logger
from vllm.utils.torch_utils import current_stream from vllm.utils.torch_utils import current_stream
logger = init_logger(__name__)
# Maps a thread identifier (threading.get_ident()) to the id of the
# UBatchContext that thread has entered.
_THREAD_ID_TO_CONTEXT: dict = {}
# Number of micro-batches. Defaults to 2 and is overwritten by
# make_ubatch_contexts with the requested micro-batch count.
_NUM_UBATCHES: int = 2
# Active contexts indexed by context id. Starts empty; make_ubatch_contexts
# extends it with None placeholders until it can hold every micro-batch.
_CURRENT_CONTEXTS: list[Optional["UBatchContext"]] = []
class UBatchContext: class UBatchContext:
...@@ -48,6 +53,7 @@ class UBatchContext: ...@@ -48,6 +53,7 @@ class UBatchContext:
global _CURRENT_CONTEXTS, _THREAD_ID_TO_CONTEXT global _CURRENT_CONTEXTS, _THREAD_ID_TO_CONTEXT
_THREAD_ID_TO_CONTEXT[threading.get_ident()] = self.id _THREAD_ID_TO_CONTEXT[threading.get_ident()] = self.id
_CURRENT_CONTEXTS[self.id] = self _CURRENT_CONTEXTS[self.id] = self
# _NUM_UBATCHES is set in make_ubatch_contexts
self.ready_barrier.wait() self.ready_barrier.wait()
self.cpu_wait_event.wait() self.cpu_wait_event.wait()
...@@ -181,7 +187,7 @@ dbo_switch_to_compute_sync = _register_ubatch_function( ...@@ -181,7 +187,7 @@ dbo_switch_to_compute_sync = _register_ubatch_function(
def dbo_register_recv_hook(recv_hook):
    """Store ``recv_hook`` on the micro-batch context that follows the caller's.

    The calling thread's context id is looked up by thread identifier, and
    the hook is assigned to the context at the next index, wrapping around
    after ``_NUM_UBATCHES`` contexts. When no thread has registered a context
    at all, the call is a no-op.

    Args:
        recv_hook: Object assigned to the next context's ``recv_hook``
            attribute.
    """
    # Nothing registered means micro-batching is not active on any thread.
    if not _THREAD_ID_TO_CONTEXT:
        return
    ctx_idx = _THREAD_ID_TO_CONTEXT[threading.get_ident()]
    # Wrap around so the last context hands its hook to the first one.
    next_ctx = _CURRENT_CONTEXTS[(ctx_idx + 1) % _NUM_UBATCHES]
    next_ctx.recv_hook = recv_hook
...@@ -202,7 +208,14 @@ def make_ubatch_contexts( ...@@ -202,7 +208,14 @@ def make_ubatch_contexts(
ready_barrier: threading.Barrier, ready_barrier: threading.Barrier,
schedule: str = "default", schedule: str = "default",
) -> list[UBatchContext]: ) -> list[UBatchContext]:
assert num_micro_batches == 2, "only been tested with 2 micro-batches" global _NUM_UBATCHES, _CURRENT_CONTEXTS
assert num_micro_batches > 1, "num_micro_batches must be greater than 1"
_NUM_UBATCHES = num_micro_batches
# Ensure the global context list is large enough
if len(_CURRENT_CONTEXTS) < num_micro_batches:
_CURRENT_CONTEXTS.extend([None] * (num_micro_batches - len(_CURRENT_CONTEXTS)))
""" """
Create a context manager for micro-batching synchronization. Create a context manager for micro-batching synchronization.
""" """
...@@ -210,8 +223,6 @@ def make_ubatch_contexts( ...@@ -210,8 +223,6 @@ def make_ubatch_contexts(
gpu_comm_done_events = [torch.Event() for _ in range(num_micro_batches)] gpu_comm_done_events = [torch.Event() for _ in range(num_micro_batches)]
gpu_compute_done_events = [torch.Event() for _ in range(num_micro_batches)] gpu_compute_done_events = [torch.Event() for _ in range(num_micro_batches)]
assert len(forward_contexts) == 2
ctxs = [] ctxs = []
for i in range(num_micro_batches): for i in range(num_micro_batches):
ctx = UBatchContext( ctx = UBatchContext(
......
...@@ -8,13 +8,15 @@ from typing_extensions import deprecated ...@@ -8,13 +8,15 @@ from typing_extensions import deprecated
from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.config import ModelConfig, SchedulerConfig, VllmConfig from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import MultiModalEmbeddings from vllm.model_executor.models.interfaces import MultiModalEmbeddings
from vllm.model_executor.models.utils import extract_layer_index from vllm.model_executor.models.utils import extract_layer_index
from vllm.multimodal.cache import processor_only_cache_from_config from vllm.multimodal.cache import processor_only_cache_from_config
from vllm.multimodal.registry import MultiModalRegistry from vllm.multimodal.registry import MultiModalRegistry
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.mem_constants import GiB_bytes
from vllm.utils.mem_utils import MemorySnapshot
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget
from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec
...@@ -248,6 +250,28 @@ def gather_mm_placeholders( ...@@ -248,6 +250,28 @@ def gather_mm_placeholders(
return placeholders[is_embed] return placeholders[is_embed]
def request_memory(init_snapshot: MemorySnapshot, cache_config: CacheConfig) -> float:
    """
    Calculate the amount of memory required by vLLM, then validate
    that the current amount of free memory is sufficient for that.

    The requested amount is ``init_snapshot.total_memory`` scaled by
    ``cache_config.gpu_memory_utilization``.

    Args:
        init_snapshot: Memory snapshot providing ``total_memory``,
            ``free_memory`` and ``device_``.
        cache_config: Cache configuration providing
            ``gpu_memory_utilization``.

    Returns:
        The requested amount of memory, in the same unit as the snapshot's
        ``total_memory``.

    Raises:
        ValueError: If the snapshot's free memory is smaller than the
            requested amount.
    """
    requested_memory = init_snapshot.total_memory * cache_config.gpu_memory_utilization

    if init_snapshot.free_memory < requested_memory:

        def to_gib(num_bytes: float) -> float:
            # Only used for the human-readable error message below.
            return round(num_bytes / GiB_bytes, 2)

        raise ValueError(
            f"Free memory on device {init_snapshot.device_} "
            f"({to_gib(init_snapshot.free_memory)}/"
            f"{to_gib(init_snapshot.total_memory)} GiB) on startup "
            f"is less than desired GPU memory utilization "
            f"({cache_config.gpu_memory_utilization}, "
            f"{to_gib(requested_memory)} GiB). Decrease GPU memory "
            f"utilization or reduce GPU memory used by other processes."
        )

    return requested_memory
def add_kv_sharing_layers_to_kv_cache_groups( def add_kv_sharing_layers_to_kv_cache_groups(
shared_kv_cache_layers: dict[str, str], shared_kv_cache_layers: dict[str, str],
kv_cache_groups: list[KVCacheGroupSpec], kv_cache_groups: list[KVCacheGroupSpec],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment