Commit 4eabe123 authored by zhuwenwen
Browse files

Merge remote-tracking branch 'mirror/releases/v0.9.0' into v0.9.0-ori

parents 45840cd2 58738772
...@@ -89,18 +89,18 @@ class TopKTopPSampler(nn.Module): ...@@ -89,18 +89,18 @@ class TopKTopPSampler(nn.Module):
p: Optional[torch.Tensor], p: Optional[torch.Tensor],
) -> torch.Tensor: ) -> torch.Tensor:
"""More optimized implementation for top-k and top-p sampling.""" """More optimized implementation for top-k and top-p sampling."""
probs = logits.softmax(dim=-1, dtype=torch.float32)
if k is None and p is None: if k is None and p is None:
# We prefer `random_sample` over `flashinfer_sample` when sorting is # We prefer `random_sample` over `flashinfer_sample` when sorting is
# not needed. This is because `random_sample` does not require # not needed. This is because `random_sample` does not require
# CPU-GPU synchronization while `flashinfer_sample` does. # CPU-GPU synchronization while `flashinfer_sample` does.
probs = logits.softmax(dim=-1, dtype=torch.float32)
return random_sample(probs, generators) return random_sample(probs, generators)
if generators: if generators:
logger.warning("FlashInfer 0.2.3+ does not support " logger.warning("FlashInfer 0.2.3+ does not support "
"per-request generators. Falling back to " "per-request generators. Falling back to "
"PyTorch-native implementation.") "PyTorch-native implementation.")
return self.forward_native(logits, generators, k, p) return self.forward_native(logits, generators, k, p)
return flashinfer_sample(probs, k, p, generators) return flashinfer_sample(logits, k, p, generators)
def forward_tpu( def forward_tpu(
self, self,
...@@ -254,12 +254,12 @@ def random_sample( ...@@ -254,12 +254,12 @@ def random_sample(
def flashinfer_sample( def flashinfer_sample(
probs: torch.Tensor, logits: torch.Tensor,
k: Optional[torch.Tensor], k: Optional[torch.Tensor],
p: Optional[torch.Tensor], p: Optional[torch.Tensor],
generators: dict[int, torch.Generator], generators: dict[int, torch.Generator],
) -> torch.Tensor: ) -> torch.Tensor:
"""Sample from the probabilities using FlashInfer. """Sample from the logits using FlashInfer.
Statistically, this function is equivalent to the `random_sample` function. Statistically, this function is equivalent to the `random_sample` function.
However, this function is faster because it avoids sorting the logits tensor However, this function is faster because it avoids sorting the logits tensor
...@@ -274,18 +274,19 @@ def flashinfer_sample( ...@@ -274,18 +274,19 @@ def flashinfer_sample(
the synchronization overhead. the synchronization overhead.
""" """
assert not (k is None and p is None) assert not (k is None and p is None)
if k is None: if k is None:
# Top-p only. # Top-p only.
probs = logits.softmax(dim=-1, dtype=torch.float32)
next_token_ids = flashinfer.sampling.top_p_sampling_from_probs( next_token_ids = flashinfer.sampling.top_p_sampling_from_probs(
probs, p, deterministic=True) probs, p, deterministic=True)
elif p is None: elif p is None:
# Top-k only. # Top-k only.
probs = logits.softmax(dim=-1, dtype=torch.float32)
next_token_ids = flashinfer.sampling.top_k_sampling_from_probs( next_token_ids = flashinfer.sampling.top_k_sampling_from_probs(
probs, k, deterministic=True) probs, k, deterministic=True)
else: else:
# Both top-k and top-p. # Both top-k and top-p.
next_token_ids = (flashinfer.sampling.top_k_top_p_sampling_from_probs( next_token_ids = flashinfer.sampling.top_k_top_p_sampling_from_logits(
probs, k, p, deterministic=True)) logits, k, p, deterministic=True)
return next_token_ids.view(-1) return next_token_ids.view(-1)
...@@ -4,17 +4,17 @@ import torch.nn as nn ...@@ -4,17 +4,17 @@ import torch.nn as nn
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.config import (CompilationLevel, VllmConfig, from vllm.config import (CompilationLevel, VllmConfig,
get_layers_from_vllm_config, set_current_vllm_config) get_layers_from_vllm_config)
from vllm.distributed.parallel_state import get_pp_group from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model_loader from vllm.model_executor.model_loader import get_model
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.triton_utils import tl, triton from vllm.v1.attention.backends.flash_attn import (CommonAttentionMetadata,
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata FlashAttentionMetadata)
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.utils import prepare_eagle_input_kernel
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -27,12 +27,15 @@ class EagleProposer: ...@@ -27,12 +27,15 @@ class EagleProposer:
self, self,
vllm_config: VllmConfig, vllm_config: VllmConfig,
device: torch.device, device: torch.device,
runner=None,
): ):
self.vllm_config = vllm_config self.vllm_config = vllm_config
self.speculative_config = vllm_config.speculative_config self.speculative_config = vllm_config.speculative_config
self.draft_model_config = self.speculative_config.draft_model_config self.draft_model_config = self.speculative_config.draft_model_config
self.method = self.speculative_config.method self.method = self.speculative_config.method
self.runner = runner
self.dtype = vllm_config.model_config.dtype self.dtype = vllm_config.model_config.dtype
self.max_model_len = vllm_config.model_config.max_model_len self.max_model_len = vllm_config.model_config.max_model_len
self.block_size = vllm_config.cache_config.block_size self.block_size = vllm_config.cache_config.block_size
...@@ -108,9 +111,11 @@ class EagleProposer: ...@@ -108,9 +111,11 @@ class EagleProposer:
# FA requires seq_len to have dtype int32. # FA requires seq_len to have dtype int32.
seq_lens = (target_positions[last_token_indices] + 1).int() seq_lens = (target_positions[last_token_indices] + 1).int()
if self.method in ["eagle", "eagle3"]:
# FIXME(woosuk): The below two ops cause synchronization. Optimize. # FIXME(woosuk): The below two ops cause synchronization. Optimize.
max_seq_len = seq_lens.max().item() max_seq_len = seq_lens.max().item()
max_num_tokens = (cu_num_tokens[1:] - cu_num_tokens[:-1]).max().item() max_num_tokens = (cu_num_tokens[1:] -
cu_num_tokens[:-1]).max().item()
attn_metadata = FlashAttentionMetadata( attn_metadata = FlashAttentionMetadata(
num_actual_tokens=num_tokens, num_actual_tokens=num_tokens,
max_query_len=max_num_tokens, max_query_len=max_num_tokens,
...@@ -126,6 +131,31 @@ class EagleProposer: ...@@ -126,6 +131,31 @@ class EagleProposer:
prefix_kv_lens=None, prefix_kv_lens=None,
suffix_kv_lens=None, suffix_kv_lens=None,
) )
elif self.method == "deepseek_mtp":
query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1]
max_query_len = query_lens.max().item()
common_attn_metadata = CommonAttentionMetadata(
query_start_loc=cu_num_tokens, seq_lens=seq_lens)
assert self.runner is not None
# FIXME: need to consider multiple kv_cache_groups
attn_metadata = self.runner.attn_metadata_builder.build(
num_reqs=batch_size,
num_actual_tokens=num_tokens,
max_query_len=max_query_len,
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
)
else:
raise ValueError(f"Unsupported method: {self.method}")
# At this moment, we assume all eagle layers belong to the same KV
# cache group, thus using the same attention metadata.
per_layer_attn_metadata = {}
for layer_name in self.attn_layer_names:
per_layer_attn_metadata[layer_name] = attn_metadata
if self.use_cuda_graph and \ if self.use_cuda_graph and \
num_tokens <= self.cudagraph_batch_sizes[-1]: num_tokens <= self.cudagraph_batch_sizes[-1]:
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
...@@ -135,14 +165,18 @@ class EagleProposer: ...@@ -135,14 +165,18 @@ class EagleProposer:
self.positions[:num_tokens] = target_positions self.positions[:num_tokens] = target_positions
self.hidden_states[:num_tokens] = target_hidden_states self.hidden_states[:num_tokens] = target_hidden_states
with set_forward_context(attn_metadata, with set_forward_context(per_layer_attn_metadata,
self.vllm_config, self.vllm_config,
num_tokens=num_input_tokens): num_tokens=num_input_tokens):
last_hidden_states, hidden_states = self.model( ret_hidden_states = self.model(
input_ids=self.input_ids[:num_input_tokens], self.input_ids[:num_input_tokens],
positions=self.positions[:num_input_tokens], self.positions[:num_input_tokens],
hidden_states=self.hidden_states[:num_input_tokens], self.hidden_states[:num_input_tokens],
) )
if self.method == "deepseek_mtp":
last_hidden_states = ret_hidden_states
else:
last_hidden_states, hidden_states = ret_hidden_states
sample_hidden_states = last_hidden_states[last_token_indices] sample_hidden_states = last_hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states, None) logits = self.model.compute_logits(sample_hidden_states, None)
draft_token_ids = logits.argmax(dim=-1) draft_token_ids = logits.argmax(dim=-1)
...@@ -152,6 +186,10 @@ class EagleProposer: ...@@ -152,6 +186,10 @@ class EagleProposer:
# [batch_size, 1] # [batch_size, 1]
return draft_token_ids.view(-1, 1) return draft_token_ids.view(-1, 1)
# TODO: Currently, MTP module released by deepseek only has
# one layer. Adapt this code to support multiple layers once
# there's a multi-layer MTP module.
# Generate the remaining draft tokens. # Generate the remaining draft tokens.
draft_token_ids_list = [draft_token_ids] draft_token_ids_list = [draft_token_ids]
...@@ -213,13 +251,13 @@ class EagleProposer: ...@@ -213,13 +251,13 @@ class EagleProposer:
self.hidden_states[:batch_size] = hidden_states self.hidden_states[:batch_size] = hidden_states
# Run the model. # Run the model.
with set_forward_context(attn_metadata, with set_forward_context(per_layer_attn_metadata,
self.vllm_config, self.vllm_config,
num_tokens=input_batch_size): num_tokens=input_batch_size):
last_hidden_states, hidden_states = self.model( last_hidden_states, hidden_states = self.model(
input_ids=self.input_ids[:input_batch_size], self.input_ids[:input_batch_size],
positions=self.positions[:input_batch_size], self.positions[:input_batch_size],
hidden_states=self.hidden_states[:input_batch_size], self.hidden_states[:input_batch_size],
) )
hidden_states = hidden_states[:batch_size] hidden_states = hidden_states[:batch_size]
logits = self.model.compute_logits(last_hidden_states[:batch_size], logits = self.model.compute_logits(last_hidden_states[:batch_size],
...@@ -239,6 +277,7 @@ class EagleProposer: ...@@ -239,6 +277,7 @@ class EagleProposer:
cu_target_query_lens: torch.Tensor, cu_target_query_lens: torch.Tensor,
# [batch_size] # [batch_size]
num_rejected_tokens: torch.Tensor, num_rejected_tokens: torch.Tensor,
num_tokens: int,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
# cu_target_query_lens: [0, a, a + b, a + b + c] # cu_target_query_lens: [0, a, a + b, a + b + c]
# num_rejected_tokens: [n1, n2, n3] # num_rejected_tokens: [n1, n2, n3]
...@@ -256,21 +295,16 @@ class EagleProposer: ...@@ -256,21 +295,16 @@ class EagleProposer:
# [a - n1, b - n2, c - n3] -> # [a - n1, b - n2, c - n3] ->
# [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3] # [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
cu_num_tokens = torch.empty_like(cu_target_query_lens) cu_num_tokens = torch.zeros_like(cu_target_query_lens)
torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:]) torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])
cu_num_tokens[0] = 0
# FIXME(woosuk): Avoid synchronization.
num_tokens = cu_num_tokens[-1].item()
token_indices = torch.empty( token_indices = torch.empty(
num_tokens, num_tokens,
dtype=torch.int32, dtype=torch.int32,
device=cu_num_tokens.device, device=cu_target_query_lens.device,
) )
batch_size = num_rejected_tokens.shape[0] batch_size = num_rejected_tokens.shape[0]
BLOCK_SIZE = 1024 BLOCK_SIZE = 1024
prepare_input_kernel[(batch_size, )]( prepare_eagle_input_kernel[(batch_size, )](
token_indices, token_indices,
cu_target_query_lens, cu_target_query_lens,
cu_num_tokens, cu_num_tokens,
...@@ -279,48 +313,28 @@ class EagleProposer: ...@@ -279,48 +313,28 @@ class EagleProposer:
return cu_num_tokens, token_indices return cu_num_tokens, token_indices
def load_model(self, target_model: nn.Module) -> None: def load_model(self, target_model: nn.Module) -> None:
loader = get_model_loader(self.vllm_config.load_config) draft_model_config = \
target_layer_num = self.vllm_config.model_config.get_num_layers( self.vllm_config.speculative_config.draft_model_config
self.vllm_config.parallel_config)
target_attn_layer_names = set( target_attn_layer_names = set(
get_layers_from_vllm_config(self.vllm_config, Attention).keys()) get_layers_from_vllm_config(self.vllm_config, Attention).keys())
draft_model_config = \ self.model = get_model(vllm_config=self.vllm_config,
self.vllm_config.speculative_config.draft_model_config model_config=draft_model_config)
# FIXME(lily): This does not handle with distributed inference.
target_device = self.vllm_config.device_config.device
# We need to set the vllm_config here to register attention
# layers in the forward context.
with set_default_torch_dtype(
draft_model_config.dtype), set_current_vllm_config(
self.vllm_config):
draft_model_cls, arch = ModelRegistry.resolve_model_cls(
draft_model_config.architectures)
self.model = draft_model_cls(
vllm_config=self.vllm_config,
start_layer_id=target_layer_num).to(target_device)
draft_attn_layer_names = ( draft_attn_layer_names = (
get_layers_from_vllm_config(self.vllm_config, Attention).keys() - get_layers_from_vllm_config(self.vllm_config, Attention).keys() -
target_attn_layer_names) target_attn_layer_names)
assert len(draft_attn_layer_names) == 1
self.attn_layer_name = next(iter(draft_attn_layer_names)) self.attn_layer_names = list(draft_attn_layer_names)
loaded_weights = self.model.load_weights(
loader.get_all_weights(draft_model_config, self.model))
# share embed_tokens with the target model if needed # share embed_tokens with the target model if needed
if get_pp_group().world_size == 1: if get_pp_group().world_size == 1:
assert "model.embed_tokens.weight" not in loaded_weights, \
"For PP = 1, Eagle draft should share embed with target model"
logger.info( logger.info(
"The EAGLE head shares the same vocab embedding" \ "The EAGLE head shares the same vocab embedding" \
" with the target model." " with the target model."
) )
self.model.model.embed_tokens = target_model.model.embed_tokens self.model.model.embed_tokens = target_model.model.embed_tokens
else: else:
assert "model.embed_tokens.weight" in loaded_weights, \
"For PP > 1, Eagle draft checkpoint should its own copy of "
" the model.embed_tokens.weight"
logger.info( logger.info(
"Since PP > 1, the EAGLE head loaded its own vocab embedding" \ "Since PP > 1, the EAGLE head loaded its own vocab embedding" \
" weights instead of sharing them with the target model." " weights instead of sharing them with the target model."
...@@ -342,11 +356,30 @@ class EagleProposer: ...@@ -342,11 +356,30 @@ class EagleProposer:
with set_forward_context(None, self.vllm_config, with set_forward_context(None, self.vllm_config,
num_tokens=num_tokens): num_tokens=num_tokens):
self.model( self.model(
input_ids=self.input_ids[:num_tokens], self.input_ids[:num_tokens],
positions=self.positions[:num_tokens], self.positions[:num_tokens],
hidden_states=self.hidden_states[:num_tokens], self.hidden_states[:num_tokens],
) )
def validate_same_kv_cache_group(self,
                                 kv_cache_config: KVCacheConfig) -> None:
    """Check that every eagle attention layer shares one KV cache group.

    All eagle layers are currently given the same AttentionMetadata, which
    requires them to belong to a single KVCacheGroup. This may be extended
    to multiple AttentionMetadata in the future.

    Args:
        kv_cache_config: Object whose ``kv_cache_groups`` is iterated; each
            group's ``layer_names`` lists the layers assigned to it.

    Raises:
        KeyError: If a name in ``self.attn_layer_names`` is not listed in
            any group.
        AssertionError: If the eagle layers do not map to exactly one group
            (this includes the case where there are no eagle layers).
    """
    # Map each layer name to the index of the group that lists it. If a name
    # appears in several groups, the last group wins.
    layer_to_group: dict[str, int] = {}
    for group_id, kv_cache_group in enumerate(
            kv_cache_config.kv_cache_groups):
        for layer_name in kv_cache_group.layer_names:
            layer_to_group[layer_name] = group_id

    eagle_group_ids = {
        layer_to_group[layer_name] for layer_name in self.attn_layer_names
    }
    assert len(eagle_group_ids) == 1, \
        "All eagle layers should belong to the same kv cache group"
# NOTE(woosuk): Currently, the below code is not used and we always use argmax # NOTE(woosuk): Currently, the below code is not used and we always use argmax
# to sample the draft tokens. We will use this after we find a way to manage # to sample the draft tokens. We will use this after we find a way to manage
...@@ -389,29 +422,3 @@ def compute_probs_and_sample_next_token( ...@@ -389,29 +422,3 @@ def compute_probs_and_sample_next_token(
next_token_ids, next_token_ids,
) )
return next_token_ids, probs return next_token_ids, probs
@triton.jit
def prepare_input_kernel(
    out_ptr,
    cu_query_lens_ptr,
    cu_num_tokens_ptr,
    BLOCK_SIZE: tl.constexpr,
):
    """Fill each program's output slice with consecutive indices.

    Program ``p`` reads the half-open range ``[begin, end)`` from
    ``cu_num_tokens_ptr[p]`` and ``cu_num_tokens_ptr[p + 1]``, reads a base
    index from ``cu_query_lens_ptr[p]``, and stores ``base + j`` at
    ``out_ptr[begin + j]`` for every ``j`` in ``[0, end - begin)``.
    The range is processed in chunks of ``BLOCK_SIZE`` elements.
    """
    program_idx = tl.program_id(0)

    # Output slice written by this program: [out_begin, out_end)
    out_begin = tl.load(cu_num_tokens_ptr + program_idx)
    out_end = tl.load(cu_num_tokens_ptr + program_idx + 1)
    count = out_end - out_begin

    # First index value to write into the slice.
    base_index = tl.load(cu_query_lens_ptr + program_idx)

    block_count = tl.cdiv(count, BLOCK_SIZE)
    for block_idx in tl.range(block_count):
        lane = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        # The mask keeps the last, partially filled chunk inside the slice.
        tl.store(
            out_ptr + out_begin + lane,
            base_index + lane,
            mask=lane < count,
        )
...@@ -3,12 +3,10 @@ ...@@ -3,12 +3,10 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm.config import VllmConfig, set_current_vllm_config from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model_loader from vllm.model_executor.model_loader import get_model
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.model_executor.models.medusa import Medusa
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
# Initialize logger # Initialize logger
...@@ -49,20 +47,9 @@ class MedusaProposer: ...@@ -49,20 +47,9 @@ class MedusaProposer:
return [list(row) for row in zip(*draft_tokens)] return [list(row) for row in zip(*draft_tokens)]
def load_model(self, target_model: nn.Module) -> None: def load_model(self, target_model: nn.Module) -> None:
# Get model loader and config self.model = get_model(vllm_config=self.vllm_config,
loader = get_model_loader(self.vllm_config.load_config) model_config=self.vllm_config.
draft_config = self.vllm_config.speculative_config.draft_model_config speculative_config.draft_model_config)
# Load model with proper dtype and config
with set_default_torch_dtype(draft_config.dtype), \
set_current_vllm_config(self.vllm_config):
self.model = Medusa(
vllm_config=self.vllm_config.speculative_config).to(
self.device)
# Load model weights
weights = loader.get_all_weights(draft_config, self.model)
self.model.load_weights(weights)
@torch.inference_mode() @torch.inference_mode()
def dummy_run(self, num_tokens: int) -> None: def dummy_run(self, num_tokens: int) -> None:
......
...@@ -134,17 +134,17 @@ class SpecDecodingProm: ...@@ -134,17 +134,17 @@ class SpecDecodingProm:
self.counter_spec_decode_num_drafts = \ self.counter_spec_decode_num_drafts = \
self._counter_cls( self._counter_cls(
name="vllm:spec_decode_num_drafts_total", name="vllm:spec_decode_num_drafts",
documentation="Number of spec decoding drafts.", documentation="Number of spec decoding drafts.",
labelnames=labelnames).labels(*labelvalues) labelnames=labelnames).labels(*labelvalues)
self.counter_spec_decode_num_draft_tokens = \ self.counter_spec_decode_num_draft_tokens = \
self._counter_cls( self._counter_cls(
name="vllm:spec_decode_num_draft_tokens_total", name="vllm:spec_decode_num_draft_tokens",
documentation="Number of draft tokens.", documentation="Number of draft tokens.",
labelnames=labelnames,).labels(*labelvalues) labelnames=labelnames,).labels(*labelvalues)
self.counter_spec_decode_num_accepted_tokens = \ self.counter_spec_decode_num_accepted_tokens = \
self._counter_cls( self._counter_cls(
name="vllm:spec_decode_num_accepted_tokens_total", name="vllm:spec_decode_num_accepted_tokens",
documentation="Number of accepted tokens.", documentation="Number of accepted tokens.",
labelnames=labelnames).labels(*labelvalues) labelnames=labelnames).labels(*labelvalues)
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from vllm.triton_utils import tl, triton
from vllm.v1.worker.gpu_input_batch import InputBatch from vllm.v1.worker.gpu_input_batch import InputBatch
...@@ -16,3 +17,29 @@ def is_spec_decode_supported(req_id: str, input_batch: InputBatch) -> bool: ...@@ -16,3 +17,29 @@ def is_spec_decode_supported(req_id: str, input_batch: InputBatch) -> bool:
return False return False
return True return True
@triton.jit
def prepare_eagle_input_kernel(
    out_ptr,
    cu_query_lens_ptr,
    cu_num_tokens_ptr,
    BLOCK_SIZE: tl.constexpr,
):
    """Fill each program's output slice with consecutive indices.

    Program ``p`` reads the half-open range ``[begin, end)`` from
    ``cu_num_tokens_ptr[p]`` and ``cu_num_tokens_ptr[p + 1]``, reads a base
    index from ``cu_query_lens_ptr[p]``, and stores ``base + j`` at
    ``out_ptr[begin + j]`` for every ``j`` in ``[0, end - begin)``.
    The range is processed in chunks of ``BLOCK_SIZE`` elements.
    """
    program_idx = tl.program_id(0)

    # Output slice written by this program: [out_begin, out_end)
    out_begin = tl.load(cu_num_tokens_ptr + program_idx)
    out_end = tl.load(cu_num_tokens_ptr + program_idx + 1)
    count = out_end - out_begin

    # First index value to write into the slice.
    base_index = tl.load(cu_query_lens_ptr + program_idx)

    block_count = tl.cdiv(count, BLOCK_SIZE)
    for block_idx in tl.range(block_count):
        lane = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        # The mask keeps the last, partially filled chunk inside the slice.
        tl.store(
            out_ptr + out_begin + lane,
            base_index + lane,
            mask=lane < count,
        )
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
import re import regex as re
def grammar_is_likely_lark(grammar_str: str) -> bool: def grammar_is_likely_lark(grammar_str: str) -> bool:
......
...@@ -5,7 +5,6 @@ import torch ...@@ -5,7 +5,6 @@ import torch
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import cdiv from vllm.utils import cdiv
from vllm.v1.kv_cache_interface import KVCacheConfig
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -105,15 +104,10 @@ class MultiGroupBlockTable: ...@@ -105,15 +104,10 @@ class MultiGroupBlockTable:
def __init__(self, max_num_reqs: int, max_model_len: int, def __init__(self, max_num_reqs: int, max_model_len: int,
max_num_batched_tokens: int, pin_memory: bool, max_num_batched_tokens: int, pin_memory: bool,
device: torch.device, kv_cache_config: KVCacheConfig) -> None: device: torch.device, block_size: int) -> None:
max_num_blocks_per_req = [
cdiv(max_model_len, g.kv_cache_spec.block_size)
for g in kv_cache_config.kv_cache_groups
]
self.block_tables = [ self.block_tables = [
BlockTable(max_num_reqs, max_num_blocks_per_req[i], BlockTable(max_num_reqs, cdiv(max_model_len, block_size),
max_num_batched_tokens, pin_memory, device) max_num_batched_tokens, pin_memory, device)
for i in range(len(kv_cache_config.kv_cache_groups))
] ]
def append_row(self, block_ids: list[list[int]], row_idx: int) -> None: def append_row(self, block_ids: list[list[int]], row_idx: int) -> None:
......
...@@ -11,7 +11,6 @@ from vllm.lora.request import LoRARequest ...@@ -11,7 +11,6 @@ from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams, SamplingType from vllm.sampling_params import SamplingParams, SamplingType
from vllm.utils import swap_dict_values from vllm.utils import swap_dict_values
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.outputs import LogprobsTensors from vllm.v1.outputs import LogprobsTensors
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.utils import copy_slice from vllm.v1.utils import copy_slice
...@@ -63,7 +62,7 @@ class InputBatch: ...@@ -63,7 +62,7 @@ class InputBatch:
device: torch.device, device: torch.device,
pin_memory: bool, pin_memory: bool,
vocab_size: int, vocab_size: int,
kv_cache_config: KVCacheConfig, block_size: int,
): ):
self.max_num_reqs = max_num_reqs self.max_num_reqs = max_num_reqs
self.max_model_len = max_model_len self.max_model_len = max_model_len
...@@ -105,7 +104,7 @@ class InputBatch: ...@@ -105,7 +104,7 @@ class InputBatch:
max_num_batched_tokens=max_num_batched_tokens, max_num_batched_tokens=max_num_batched_tokens,
pin_memory=pin_memory, pin_memory=pin_memory,
device=device, device=device,
kv_cache_config=kv_cache_config, block_size=block_size,
) )
# Sampling-related. # Sampling-related.
......
...@@ -27,15 +27,15 @@ from vllm.distributed.parallel_state import ( ...@@ -27,15 +27,15 @@ from vllm.distributed.parallel_state import (
from vllm.forward_context import get_forward_context, set_forward_context from vllm.forward_context import get_forward_context, set_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import TensorizerLoader, get_model
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.multimodal.utils import group_mm_inputs_by_modality from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm.sampling_params import SamplingType from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
GiB_bytes, LazyLoader, cdiv, check_use_alibi, GiB_bytes, LazyLoader, async_tensor_h2d, cdiv,
is_pin_memory_available) check_use_alibi, is_pin_memory_available)
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
...@@ -63,6 +63,7 @@ from .utils import (gather_mm_placeholders, sanity_check_mm_encoder_outputs, ...@@ -63,6 +63,7 @@ from .utils import (gather_mm_placeholders, sanity_check_mm_encoder_outputs,
if TYPE_CHECKING: if TYPE_CHECKING:
import xgrammar as xgr import xgrammar as xgr
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
else: else:
xgr = LazyLoader("xgr", globals(), "xgrammar") xgr = LazyLoader("xgr", globals(), "xgrammar")
...@@ -150,12 +151,16 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -150,12 +151,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.use_aux_hidden_state_outputs = False self.use_aux_hidden_state_outputs = False
if self.speculative_config: if self.speculative_config:
self.use_spec_decode = True self.use_spec_decode = True
# NOTE(Jiayi): currently we put the entire draft model on
# the last PP rank. This is not ideal if there are many
# layers in the draft model.
if get_pp_group().is_last_rank: if get_pp_group().is_last_rank:
if self.speculative_config.method == "ngram": if self.speculative_config.method == "ngram":
self.drafter = NgramProposer(self.vllm_config) self.drafter = NgramProposer(self.vllm_config)
elif self.speculative_config.use_eagle(): elif self.speculative_config.use_eagle():
self.drafter = EagleProposer(self.vllm_config, self.drafter = EagleProposer(self.vllm_config, self.device,
self.device) # type: ignore self) # type: ignore
if self.speculative_config.method == "eagle3": if self.speculative_config.method == "eagle3":
self.use_aux_hidden_state_outputs = True self.use_aux_hidden_state_outputs = True
elif self.speculative_config.method == "medusa": elif self.speculative_config.method == "medusa":
...@@ -170,6 +175,16 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -170,6 +175,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Request states. # Request states.
self.requests: dict[str, CachedRequestState] = {} self.requests: dict[str, CachedRequestState] = {}
self.input_batch = InputBatch(
max_num_reqs=self.max_num_reqs,
max_model_len=self.max_model_len,
max_num_batched_tokens=self.max_num_tokens,
device=self.device,
pin_memory=self.pin_memory,
vocab_size=self.model_config.get_vocab_size(),
block_size=self.cache_config.block_size,
)
self.use_cuda_graph = (self.vllm_config.compilation_config.level self.use_cuda_graph = (self.vllm_config.compilation_config.level
== CompilationLevel.PIECEWISE == CompilationLevel.PIECEWISE
and not self.model_config.enforce_eager) and not self.model_config.enforce_eager)
...@@ -914,8 +929,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -914,8 +929,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
encoder_outputs = [] encoder_outputs = []
for grouped_mm_inputs in grouped_mm_inputs_list: for grouped_mm_inputs in grouped_mm_inputs_list:
batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs) batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs, batched_mm_inputs = MultiModalKwargs.as_kwargs(
device=self.device) batched_mm_inputs,
dtype=self.model_config.dtype,
device=self.device,
)
# Run the encoder. # Run the encoder.
# `curr_group_outputs` is either of the following: # `curr_group_outputs` is either of the following:
...@@ -1348,7 +1366,16 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1348,7 +1366,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
next_token_ids = torch.tensor(next_token_ids, next_token_ids = torch.tensor(next_token_ids,
dtype=torch.int32, dtype=torch.int32,
device=self.device) device=self.device)
eagle_attn_metadata = attn_metadata[self.drafter.attn_layer_name] # At this moment, we assume all eagle layers belong to the same KV
# cache group, thus using the same attention metadata.
eagle_attn_metadata = attn_metadata[
self.drafter.attn_layer_names[0]]
# NOTE: deepseek_mtp uses MLA which does not have `block_table`
if hasattr(eagle_attn_metadata, "block_table"):
block_table = eagle_attn_metadata.block_table
else:
block_table = None
if spec_decode_metadata is None: if spec_decode_metadata is None:
# input_ids can be None for multimodal models. # input_ids can be None for multimodal models.
...@@ -1369,14 +1396,16 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1369,14 +1396,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0 n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
for i, n in enumerate(num_draft_tokens) for i, n in enumerate(num_draft_tokens)
] ]
num_rejected_tokens = torch.tensor( num_rejected_tokens_tensor = async_tensor_h2d(
num_rejected_tokens, num_rejected_tokens,
dtype=torch.int32, dtype=torch.int32,
device=self.device, target_device=self.device,
) pin_memory=True)
num_tokens = num_scheduled_tokens - sum(num_rejected_tokens)
cu_num_tokens, token_indices = self.drafter.prepare_inputs( cu_num_tokens, token_indices = self.drafter.prepare_inputs(
eagle_attn_metadata.query_start_loc, eagle_attn_metadata.query_start_loc,
num_rejected_tokens, num_rejected_tokens_tensor,
num_tokens,
) )
target_token_ids = self.input_ids[token_indices] target_token_ids = self.input_ids[token_indices]
target_positions = positions[token_indices] target_positions = positions[token_indices]
...@@ -1387,7 +1416,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1387,7 +1416,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
target_hidden_states = hidden_states[token_indices] target_hidden_states = hidden_states[token_indices]
target_slot_mapping = eagle_attn_metadata.slot_mapping[ target_slot_mapping = eagle_attn_metadata.slot_mapping[
token_indices] token_indices]
draft_token_ids = self.drafter.propose( draft_token_ids = self.drafter.propose(
target_token_ids=target_token_ids, target_token_ids=target_token_ids,
target_positions=target_positions, target_positions=target_positions,
...@@ -1395,7 +1423,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1395,7 +1423,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
target_slot_mapping=target_slot_mapping, target_slot_mapping=target_slot_mapping,
next_token_ids=next_token_ids, next_token_ids=next_token_ids,
cu_num_tokens=cu_num_tokens, cu_num_tokens=cu_num_tokens,
block_table=eagle_attn_metadata.block_table, block_table=block_table,
sampling_metadata=sampling_metadata, sampling_metadata=sampling_metadata,
) )
spec_token_ids = draft_token_ids.tolist() spec_token_ids = draft_token_ids.tolist()
...@@ -1523,6 +1551,15 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1523,6 +1551,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
time_after_load - time_before_load) time_after_load - time_before_load)
prepare_communication_buffer_for_model(self.model) prepare_communication_buffer_for_model(self.model)
def save_tensorized_model(
    self,
    tensorizer_config: "TensorizerConfig",
) -> None:
    """Save this runner's model through `TensorizerLoader.save_model`.

    Args:
        tensorizer_config: Configuration object forwarded unchanged (by
            keyword) to `TensorizerLoader.save_model`.
    """
    # `self.model` is passed as the model to save; nothing is returned.
    TensorizerLoader.save_model(
        self.model,
        tensorizer_config=tensorizer_config,
    )
def _get_prompt_logprobs_dict( def _get_prompt_logprobs_dict(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
...@@ -1703,8 +1740,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1703,8 +1740,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
else: else:
hidden_states = outputs hidden_states = outputs
if self.use_spec_decode and \ if self.use_spec_decode and self.speculative_config.use_eagle():
self.speculative_config.method in ('eagle', 'eagle3'):
assert isinstance(self.drafter, EagleProposer) assert isinstance(self.drafter, EagleProposer)
self.drafter.dummy_run(num_tokens) self.drafter.dummy_run(num_tokens)
...@@ -1716,6 +1752,10 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1716,6 +1752,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
# The dummy hidden states may contain special values,
# like `inf` or `nan`.
# To avoid breaking the sampler, we use a random tensor here instead.
hidden_states = torch.rand_like(hidden_states)
logits = self.model.compute_logits(hidden_states, None) logits = self.model.compute_logits(hidden_states, None)
num_reqs = logits.size(0) num_reqs = logits.size(0)
...@@ -1837,7 +1877,10 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1837,7 +1877,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
batched_dummy_mm_inputs = MultiModalKwargs.batch( batched_dummy_mm_inputs = MultiModalKwargs.batch(
[dummy_mm_kwargs] * max_num_mm_items) [dummy_mm_kwargs] * max_num_mm_items)
batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs( batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs(
batched_dummy_mm_inputs, device=self.device) batched_dummy_mm_inputs,
dtype=self.model_config.dtype,
device=self.device,
)
# Run multimodal encoder. # Run multimodal encoder.
dummy_encoder_outputs = self.model.get_multimodal_embeddings( dummy_encoder_outputs = self.model.get_multimodal_embeddings(
...@@ -1947,16 +1990,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1947,16 +1990,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
kv_cache_config: Configuration for the KV cache, including the KV kv_cache_config: Configuration for the KV cache, including the KV
cache size of each layer cache size of each layer
""" """
if len(kv_cache_config.kv_cache_groups) > 1:
raise NotImplementedError(
"Hybrid models with more than one KV cache type are not "
"supported yet.")
self.kv_cache_config = kv_cache_config self.kv_cache_config = kv_cache_config
self.input_batch = InputBatch(
max_num_reqs=self.max_num_reqs,
max_model_len=self.max_model_len,
max_num_batched_tokens=self.max_num_tokens,
device=self.device,
pin_memory=self.pin_memory,
vocab_size=self.model_config.get_vocab_size(),
kv_cache_config=kv_cache_config,
)
self.initialize_attn_backend(kv_cache_config) self.initialize_attn_backend(kv_cache_config)
kv_caches: dict[str, torch.Tensor] = {} kv_caches: dict[str, torch.Tensor] = {}
...@@ -1988,6 +2026,12 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1988,6 +2026,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# KV cache specs. # KV cache specs.
raise ValueError("Unknown KV cache spec type.") raise ValueError("Unknown KV cache spec type.")
if self.speculative_config and self.speculative_config.use_eagle():
assert isinstance(self.drafter, EagleProposer)
# validate all draft model layers belong to the same kv cache
# group
self.drafter.validate_same_kv_cache_group(kv_cache_config)
bind_kv_cache( bind_kv_cache(
kv_caches, kv_caches,
self.vllm_config.compilation_config.static_forward_context, self.vllm_config.compilation_config.static_forward_context,
......
...@@ -31,6 +31,7 @@ from vllm.v1.worker.worker_base import WorkerBase ...@@ -31,6 +31,7 @@ from vllm.v1.worker.worker_base import WorkerBase
logger = init_logger(__name__) logger = init_logger(__name__)
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
...@@ -171,10 +172,9 @@ class Worker(WorkerBase): ...@@ -171,10 +172,9 @@ class Worker(WorkerBase):
Then, it calculates the free memory that can be used for KV cache in Then, it calculates the free memory that can be used for KV cache in
bytes. bytes.
:::{tip} Tip:
You may limit the usage of GPU memory You may limit the usage of GPU memory
by adjusting the `gpu_memory_utilization` parameter. by adjusting the `gpu_memory_utilization` parameter.
:::
""" """
torch.cuda.empty_cache() torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats() torch.cuda.reset_peak_memory_stats()
...@@ -326,6 +326,13 @@ class Worker(WorkerBase): ...@@ -326,6 +326,13 @@ class Worker(WorkerBase):
max_size=max_size, max_size=max_size,
) )
def save_tensorized_model(
    self,
    tensorizer_config: "TensorizerConfig",
) -> None:
    """Delegate saving of the tensorized model to the model runner.

    Args:
        tensorizer_config: Configuration object forwarded unchanged (by
            keyword) to `self.model_runner.save_tensorized_model`.
    """
    self.model_runner.save_tensorized_model(
        tensorizer_config=tensorizer_config,
    )
def init_worker_distributed_environment( def init_worker_distributed_environment(
vllm_config: VllmConfig, vllm_config: VllmConfig,
...@@ -341,8 +348,7 @@ def init_worker_distributed_environment( ...@@ -341,8 +348,7 @@ def init_worker_distributed_environment(
distributed_init_method, local_rank) distributed_init_method, local_rank)
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size, parallel_config.pipeline_parallel_size)
parallel_config.enable_expert_parallel)
ensure_kv_transfer_initialized(vllm_config) ensure_kv_transfer_initialized(vllm_config)
......
...@@ -652,8 +652,11 @@ class TPUModelRunner(LoRAModelRunnerMixin): ...@@ -652,8 +652,11 @@ class TPUModelRunner(LoRAModelRunnerMixin):
encoder_outputs = [] encoder_outputs = []
for grouped_mm_inputs in grouped_mm_inputs_list: for grouped_mm_inputs in grouped_mm_inputs_list:
batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs) batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs, batched_mm_inputs = MultiModalKwargs.as_kwargs(
device=self.device) batched_mm_inputs,
dtype=self.model_config.dtype,
device=self.device,
)
# Run the encoder. # Run the encoder.
# `curr_group_outputs` is either of the following: # `curr_group_outputs` is either of the following:
...@@ -1261,7 +1264,8 @@ class TPUModelRunner(LoRAModelRunnerMixin): ...@@ -1261,7 +1264,8 @@ class TPUModelRunner(LoRAModelRunnerMixin):
device=self.device, device=self.device,
pin_memory=self.pin_memory, pin_memory=self.pin_memory,
vocab_size=self.model_config.get_vocab_size(), vocab_size=self.model_config.get_vocab_size(),
kv_cache_config=kv_cache_config, block_size=kv_cache_config.kv_cache_groups[0].kv_cache_spec.
block_size,
) )
assert self.block_table_cpu.dtype == self.input_batch.block_table[ assert self.block_table_cpu.dtype == self.input_batch.block_table[
0].get_cpu_tensor().dtype 0].get_cpu_tensor().dtype
...@@ -1434,8 +1438,11 @@ class TPUModelRunner(LoRAModelRunnerMixin): ...@@ -1434,8 +1438,11 @@ class TPUModelRunner(LoRAModelRunnerMixin):
batched_dummy_mm_inputs = MultiModalKwargs.batch([dummy_mm_kwargs] * batched_dummy_mm_inputs = MultiModalKwargs.batch([dummy_mm_kwargs] *
batch_size) batch_size)
return MultiModalKwargs.as_kwargs(batched_dummy_mm_inputs, return MultiModalKwargs.as_kwargs(
device=self.device) batched_dummy_mm_inputs,
dtype=self.model_config.dtype,
device=self.device,
)
def _get_req_paddings(min_req_size: int, max_req_size: int) -> list[int]: def _get_req_paddings(min_req_size: int, max_req_size: int) -> list[int]:
......
...@@ -265,8 +265,7 @@ def init_tpu_worker_distributed_environment( ...@@ -265,8 +265,7 @@ def init_tpu_worker_distributed_environment(
backend="gloo", backend="gloo",
) )
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size, parallel_config.pipeline_parallel_size)
parallel_config.enable_expert_parallel)
try: try:
......
...@@ -10,7 +10,7 @@ def sanity_check_mm_encoder_outputs( ...@@ -10,7 +10,7 @@ def sanity_check_mm_encoder_outputs(
) -> None: ) -> None:
""" """
Perform sanity checks for the result of Perform sanity checks for the result of
{meth}`vllm.model_executor.models.SupportsMultiModal.get_multimodal_embeddings`. [`vllm.model_executor.models.SupportsMultiModal.get_multimodal_embeddings`][].
""" """
assert isinstance(mm_embeddings, (list, tuple, torch.Tensor)), ( assert isinstance(mm_embeddings, (list, tuple, torch.Tensor)), (
"Expected multimodal embeddings to be a list/tuple of 2D tensors, " "Expected multimodal embeddings to be a list/tuple of 2D tensors, "
...@@ -39,7 +39,7 @@ def scatter_mm_placeholders( ...@@ -39,7 +39,7 @@ def scatter_mm_placeholders(
Scatter the multimodal embeddings into a contiguous tensor that represents Scatter the multimodal embeddings into a contiguous tensor that represents
the placeholder tokens. the placeholder tokens.
{class}`vllm.multimodal.processing.PromptUpdateDetails.is_embed`. [`vllm.multimodal.processing.PromptUpdateDetails.is_embed`][].
Args: Args:
embeds: The multimodal embeddings. embeds: The multimodal embeddings.
...@@ -66,7 +66,7 @@ def gather_mm_placeholders( ...@@ -66,7 +66,7 @@ def gather_mm_placeholders(
""" """
Reconstructs the embeddings from the placeholder tokens. Reconstructs the embeddings from the placeholder tokens.
This is the inverse operation of {func}`scatter_mm_placeholders`. This is the inverse operation of [scatter_mm_placeholders][].
""" """
if is_embed is None: if is_embed is None:
return placeholders return placeholders
......
...@@ -297,8 +297,11 @@ class CPUEncoderDecoderModelRunner( ...@@ -297,8 +297,11 @@ class CPUEncoderDecoderModelRunner(
model_input.encoder_input_tokens, model_input.encoder_input_tokens,
"encoder_positions": "encoder_positions":
model_input.encoder_input_positions, model_input.encoder_input_positions,
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {}, **MultiModalKwargs.as_kwargs(
device=self.device), model_input.multi_modal_kwargs or {},
dtype=self.model_config.dtype,
device=self.device,
),
"intermediate_tensors": "intermediate_tensors":
intermediate_tensors, intermediate_tensors,
} }
......
...@@ -628,7 +628,10 @@ class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]): ...@@ -628,7 +628,10 @@ class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
multimodal_kwargs = {} multimodal_kwargs = {}
if model_input.multi_modal_kwargs is not None: if model_input.multi_modal_kwargs is not None:
multimodal_kwargs = MultiModalKwargs.as_kwargs( multimodal_kwargs = MultiModalKwargs.as_kwargs(
model_input.multi_modal_kwargs, device=self.device) model_input.multi_modal_kwargs,
dtype=self.model_config.dtype,
device=self.device,
)
execute_model_kwargs = {} execute_model_kwargs = {}
if previous_hidden_states is not None: if previous_hidden_states is not None:
execute_model_kwargs.update( execute_model_kwargs.update(
......
...@@ -50,8 +50,11 @@ class CPUPoolingModelRunner( ...@@ -50,8 +50,11 @@ class CPUPoolingModelRunner(
model_input.input_tokens, model_input.input_tokens,
"positions": "positions":
model_input.input_positions, model_input.input_positions,
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {}, **MultiModalKwargs.as_kwargs(
device=self.device), model_input.multi_modal_kwargs or {},
dtype=self.model_config.dtype,
device=self.device,
),
**cross_enc_kwargs, **cross_enc_kwargs,
"intermediate_tensors": "intermediate_tensors":
intermediate_tensors, intermediate_tensors,
......
...@@ -390,8 +390,7 @@ class CPUWorker(LocalOrDistributedWorkerBase): ...@@ -390,8 +390,7 @@ class CPUWorker(LocalOrDistributedWorkerBase):
ensure_model_parallel_initialized( ensure_model_parallel_initialized(
parallel_config.tensor_parallel_size, parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size, parallel_config.pipeline_parallel_size)
parallel_config.enable_expert_parallel)
def get_cache_block_size_bytes(self) -> int: def get_cache_block_size_bytes(self) -> int:
"""Return the size in bytes of a single KV cache block. """Return the size in bytes of a single KV cache block.
......
...@@ -202,9 +202,13 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]): ...@@ -202,9 +202,13 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
encoder_input_ids=model_input.encoder_input_tokens, encoder_input_ids=model_input.encoder_input_tokens,
encoder_positions=model_input.encoder_input_positions, encoder_positions=model_input.encoder_input_positions,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(multi_modal_kwargs, **MultiModalKwargs.as_kwargs(
device=self.device), multi_modal_kwargs,
**seqlen_agnostic_kwargs) dtype=self.model_config.dtype,
device=self.device,
),
**seqlen_agnostic_kwargs,
)
logits = self.model.compute_logits(hidden_or_intermediate_states, logits = self.model.compute_logits(hidden_or_intermediate_states,
model_input.sampling_metadata) model_input.sampling_metadata)
......
...@@ -201,10 +201,9 @@ class HPUWorker(LocalOrDistributedWorkerBase): ...@@ -201,10 +201,9 @@ class HPUWorker(LocalOrDistributedWorkerBase):
Then, it calculates the maximum possible number of GPU and CPU blocks Then, it calculates the maximum possible number of GPU and CPU blocks
that can be allocated with the remaining free memory. that can be allocated with the remaining free memory.
:::{tip} Tip:
You may limit the usage of GPU memory You may limit the usage of GPU memory
by adjusting the `gpu_memory_utilization` parameter. by adjusting the `gpu_memory_utilization` parameter.
:::
""" """
# Profile the memory usage of the model and get the maximum number of # Profile the memory usage of the model and get the maximum number of
# cache blocks that can be allocated with the remaining free memory. # cache blocks that can be allocated with the remaining free memory.
...@@ -416,8 +415,7 @@ def init_worker_distributed_environment( ...@@ -416,8 +415,7 @@ def init_worker_distributed_environment(
backend='hccl') backend='hccl')
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size, parallel_config.pipeline_parallel_size)
parallel_config.enable_expert_parallel)
if torch.distributed.is_initialized(): if torch.distributed.is_initialized():
torch_world_size = torch.distributed.get_world_size() torch_world_size = torch.distributed.get_world_size()
...@@ -443,8 +441,7 @@ def init_worker_distributed_environment( ...@@ -443,8 +441,7 @@ def init_worker_distributed_environment(
torch.distributed.all_reduce(dummy_tensor_hpu) torch.distributed.all_reduce(dummy_tensor_hpu)
assert dummy_tensor_hpu.item() == parallel_config.world_size assert dummy_tensor_hpu.item() == parallel_config.world_size
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size, parallel_config.pipeline_parallel_size)
parallel_config.enable_expert_parallel)
def raise_if_cache_size_invalid(num_gpu_blocks, block_size, max_model_len, def raise_if_cache_size_invalid(num_gpu_blocks, block_size, max_model_len,
......
...@@ -23,7 +23,7 @@ from vllm.attention.backends.abstract import AttentionState ...@@ -23,7 +23,7 @@ from vllm.attention.backends.abstract import AttentionState
from vllm.attention.backends.utils import CommonAttentionState from vllm.attention.backends.utils import CommonAttentionState
from vllm.config import CompilationLevel, VllmConfig from vllm.config import CompilationLevel, VllmConfig
from vllm.core.scheduler import SchedulerOutputs from vllm.core.scheduler import SchedulerOutputs
from vllm.distributed import get_pp_group from vllm.distributed import broadcast_tensor_dict, get_pp_group
from vllm.distributed.kv_transfer import get_kv_transfer_group from vllm.distributed.kv_transfer import get_kv_transfer_group
from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank, from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank,
graph_capture) graph_capture)
...@@ -729,7 +729,10 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -729,7 +729,10 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
mm_kwargs, placeholder_maps = MultiModalPlaceholderMap.from_seq_group( mm_kwargs, placeholder_maps = MultiModalPlaceholderMap.from_seq_group(
seq_group_metadata, seq_group_metadata,
range(positions[0], positions[0] + len(positions))) range(positions[0], positions[0] + len(positions)))
if not mm_kwargs:
# M-RoPE requires mrope_positions even for plain text; return early
# when mm_kwargs is empty only if inter_data.is_prompt is False.
if not mm_kwargs and not inter_data.is_prompt:
return return
inter_data.multi_modal_kwargs = mm_kwargs inter_data.multi_modal_kwargs = mm_kwargs
...@@ -741,12 +744,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -741,12 +744,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
video_grid_thw = mm_kwargs.get("video_grid_thw", None) video_grid_thw = mm_kwargs.get("video_grid_thw", None)
audio_feature_lengths = mm_kwargs.get("audio_feature_lengths", audio_feature_lengths = mm_kwargs.get("audio_feature_lengths",
None) None)
assert (
image_grid_thw is not None or video_grid_thw is not None
or audio_feature_lengths is not None), (
"mrope embedding type requires multi-modal input mapper "
"returns 'image_grid_thw' or 'video_grid_thw' or "
"'audio_feature_lengths'.")
second_per_grid_ts = mm_kwargs.get("second_per_grid_ts", None) second_per_grid_ts = mm_kwargs.get("second_per_grid_ts", None)
use_audio_in_video = mm_kwargs.get("use_audio_in_video", False) use_audio_in_video = mm_kwargs.get("use_audio_in_video", False)
...@@ -872,7 +869,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -872,7 +869,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
""" """
# Combine and flatten intermediate data. # Combine and flatten intermediate data.
input_tokens = list[int]() input_tokens = list[int]()
inputs_embeds_lst = list[torch.Tensor]() inputs_embeds_list = list[torch.Tensor]()
token_types = list[int]() token_types = list[int]()
for inter_data in self.inter_data_list: for inter_data in self.inter_data_list:
for cur_input_tokens in inter_data.input_tokens: for cur_input_tokens in inter_data.input_tokens:
...@@ -880,15 +877,15 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -880,15 +877,15 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
for cur_token_types in inter_data.token_types: for cur_token_types in inter_data.token_types:
token_types.extend(cur_token_types) token_types.extend(cur_token_types)
if inter_data.inputs_embeds is not None: if inter_data.inputs_embeds is not None:
inputs_embeds_lst.append( inputs_embeds_list.append(
inter_data.inputs_embeds.to( inter_data.inputs_embeds.to(
dtype=self.runner.model_config.dtype, dtype=self.runner.model_config.dtype,
device=self.runner.device)) device=self.runner.device))
inputs_embeds: Optional[torch.Tensor] inputs_embeds: Optional[torch.Tensor]
if len(inputs_embeds_lst) == 0: if len(inputs_embeds_list) == 0:
inputs_embeds = None inputs_embeds = None
else: else:
inputs_embeds = torch.cat(inputs_embeds_lst, dim=0).to( inputs_embeds = torch.cat(inputs_embeds_list, dim=0).to(
dtype=self.runner.model_config.dtype, dtype=self.runner.model_config.dtype,
device=self.runner.device) device=self.runner.device)
assert len(inputs_embeds) == len(input_tokens) assert len(inputs_embeds) == len(input_tokens)
...@@ -1848,8 +1845,11 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): ...@@ -1848,8 +1845,11 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
inputs_embeds=model_input.inputs_embeds, inputs_embeds=model_input.inputs_embeds,
positions=model_input.input_positions, positions=model_input.input_positions,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(multi_modal_kwargs, **MultiModalKwargs.as_kwargs(
device=self.device), multi_modal_kwargs,
dtype=self.model_config.dtype,
device=self.device,
),
**seqlen_agnostic_kwargs, **seqlen_agnostic_kwargs,
**model_kwargs, **model_kwargs,
) )
...@@ -1893,15 +1893,13 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): ...@@ -1893,15 +1893,13 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
logits = self.model.compute_logits(hidden_or_intermediate_states, logits = self.model.compute_logits(hidden_or_intermediate_states,
model_input.sampling_metadata) model_input.sampling_metadata)
if not self.is_driver_worker: if self.is_driver_worker:
return []
if model_input.async_callback is not None: if model_input.async_callback is not None:
model_input.async_callback() model_input.async_callback()
# Sample the next token. # Sample the next token.
assert isinstance(self.sampler, Sampler) assert isinstance(self.sampler, Sampler)
orig_include_gpu_probs_tensor = self.sampler.include_gpu_probs_tensor orig_include_gpu_probs = self.sampler.include_gpu_probs_tensor
if model_input.inputs_embeds is not None: if model_input.inputs_embeds is not None:
self.sampler.include_gpu_probs_tensor = True self.sampler.include_gpu_probs_tensor = True
...@@ -1919,24 +1917,36 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): ...@@ -1919,24 +1917,36 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
if intermediate_tensors is not None: if intermediate_tensors is not None:
orig_model_forward_time = intermediate_tensors.tensors.get( orig_model_forward_time = intermediate_tensors.tensors.get(
"model_forward_time", torch.tensor(0.0)).item() "model_forward_time", torch.tensor(0.0)).item()
# If there are multiple workers, we are still tracking the latency # If there are multiple workers, we are still tracking the
# from the start time of the driver worker to the end time of the # latency from the start time of the driver worker to the end
# driver worker. The model forward time will then end up covering # time of the driver worker. The model forward time will then
# the communication time as well. # end up covering the communication time as well.
output.model_forward_time = (orig_model_forward_time + output.model_forward_time = (orig_model_forward_time +
model_forward_time) model_forward_time)
if model_input.inputs_embeds is not None: if model_input.inputs_embeds is not None:
if self.is_driver_worker:
sampled = broadcast_tensor_dict(
{"token_ids": output.sampled_token_ids})
else:
sampled = broadcast_tensor_dict()
if sampled["token_ids"] is not None:
sampled_token_embeds = self.model.get_input_embeddings(
sampled["token_ids"].squeeze(1))
if self.is_driver_worker:
self.sampler.include_gpu_probs_tensor = \ self.sampler.include_gpu_probs_tensor = \
orig_include_gpu_probs_tensor orig_include_gpu_probs
if output.sampled_token_ids is not None:
output.sampled_token_embeds = self.model.get_input_embeddings( output.sampled_token_embeds = sampled_token_embeds
output.sampled_token_ids.squeeze(1))
for token_embed, sequence_group_output in zip( for token_embed, sequence_group_output in zip(
output.sampled_token_embeds, output.outputs): output.sampled_token_embeds, output.outputs):
assert len(sequence_group_output.samples) == 1 assert len(sequence_group_output.samples) == 1
sequence_group_output.samples[0].output_embed = token_embed sequence_group_output.samples[
0].output_embed = token_embed
if not self.is_driver_worker:
return []
if self.return_hidden_states: if self.return_hidden_states:
# we only need to pass hidden states of most recent token # we only need to pass hidden states of most recent token
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment