Commit 1591c68f authored by zhuwenwen's avatar zhuwenwen
Browse files

merge v0.4.2

parents 09bcf00b c7f2cf2b
...@@ -13,7 +13,7 @@ from vllm.distributed import (broadcast_tensor_dict, ...@@ -13,7 +13,7 @@ from vllm.distributed import (broadcast_tensor_dict,
init_distributed_environment) init_distributed_environment)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor import set_random_seed from vllm.model_executor import set_random_seed
from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.worker.cpu_model_runner import CPUModelRunner from vllm.worker.cpu_model_runner import CPUModelRunner
from vllm.worker.worker_base import LoraNotSupportedWorkerBase from vllm.worker.worker_base import LoraNotSupportedWorkerBase
...@@ -256,22 +256,24 @@ class CPUWorker(LoraNotSupportedWorkerBase): ...@@ -256,22 +256,24 @@ class CPUWorker(LoraNotSupportedWorkerBase):
@torch.inference_mode() @torch.inference_mode()
def execute_model( def execute_model(
self, self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None, execute_model_req: Optional[ExecuteModelRequest] = None,
blocks_to_swap_in: Optional[Dict[int, int]] = None,
blocks_to_swap_out: Optional[Dict[int, int]] = None,
blocks_to_copy: Optional[Dict[int, List[int]]] = None,
) -> List[SamplerOutput]: ) -> List[SamplerOutput]:
if execute_model_req is None:
seq_group_metadata_list = None
else:
seq_group_metadata_list = execute_model_req.seq_group_metadata_list
if self.is_driver_worker: if self.is_driver_worker:
assert seq_group_metadata_list is not None assert seq_group_metadata_list is not None
num_seq_groups: int = len(seq_group_metadata_list) num_seq_groups: int = len(seq_group_metadata_list)
assert blocks_to_swap_in is not None assert execute_model_req is not None
assert blocks_to_swap_out is not None blocks_to_copy = execute_model_req.blocks_to_copy
assert blocks_to_copy is not None assert len(execute_model_req.blocks_to_swap_in) == 0
assert len(blocks_to_swap_in) == 0 assert len(execute_model_req.blocks_to_swap_out) == 0
assert len(blocks_to_swap_out) == 0
data: Dict[str, Any] = { data: Dict[str, Any] = {
"num_seq_groups": num_seq_groups, "num_seq_groups": num_seq_groups,
"blocks_to_copy": blocks_to_copy, "blocks_to_copy": execute_model_req.blocks_to_copy,
} }
broadcast_tensor_dict(data, src=0) broadcast_tensor_dict(data, src=0)
else: else:
...@@ -279,7 +281,6 @@ class CPUWorker(LoraNotSupportedWorkerBase): ...@@ -279,7 +281,6 @@ class CPUWorker(LoraNotSupportedWorkerBase):
num_seq_groups = data["num_seq_groups"] num_seq_groups = data["num_seq_groups"]
blocks_to_copy = data["blocks_to_copy"] blocks_to_copy = data["blocks_to_copy"]
assert blocks_to_copy is not None
self.cache_copy(blocks_to_copy) self.cache_copy(blocks_to_copy)
# If there is no input, we don't need to execute the model. # If there is no input, we don't need to execute the model.
......
...@@ -9,6 +9,7 @@ import torch.nn as nn ...@@ -9,6 +9,7 @@ import torch.nn as nn
from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage, from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage,
get_attn_backend) get_attn_backend)
from vllm.attention.backends.flashinfer import FlashInferBackend
from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig, VisionLanguageConfig) ParallelConfig, SchedulerConfig, VisionLanguageConfig)
from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce
...@@ -20,12 +21,11 @@ from vllm.lora.request import LoRARequest ...@@ -20,12 +21,11 @@ from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor import SamplingMetadata from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.sampling_params import SamplingParams, SamplingType from vllm.sampling_params import SamplingParams
from vllm.sequence import (MultiModalData, SamplerOutput, SequenceData, from vllm.sequence import (MultiModalData, SamplerOutput, SequenceData,
SequenceGroupMetadata) SequenceGroupMetadata)
from vllm.utils import (CudaMemoryProfiler, async_tensor_h2d, is_hip, from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip,
is_pin_memory_available, make_tensor_with_pad, is_pin_memory_available, make_tensor_with_pad)
maybe_expand_dim)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -43,8 +43,8 @@ class PreparePromptMetadata(NamedTuple): ...@@ -43,8 +43,8 @@ class PreparePromptMetadata(NamedTuple):
input_tokens: List[int] input_tokens: List[int]
input_positions: List[int] input_positions: List[int]
attn_metadata: Optional[AttentionMetadataPerStage] attn_metadata: Optional[AttentionMetadataPerStage]
prompt_lens: List[int] seq_lens: List[int]
subquery_lens: List[int] query_lens: List[int]
lora_index_mapping: List[int] lora_index_mapping: List[int]
lora_prompt_mapping: List[int] lora_prompt_mapping: List[int]
lora_requests: Set[LoRARequest] lora_requests: Set[LoRARequest]
...@@ -57,8 +57,8 @@ class PreparePromptMetadata(NamedTuple): ...@@ -57,8 +57,8 @@ class PreparePromptMetadata(NamedTuple):
input_tokens=[], input_tokens=[],
input_positions=[], input_positions=[],
attn_metadata=None, attn_metadata=None,
prompt_lens=[], seq_lens=[],
subquery_lens=[], query_lens=[],
lora_index_mapping=[], lora_index_mapping=[],
lora_prompt_mapping=[], lora_prompt_mapping=[],
lora_requests=set(), lora_requests=set(),
...@@ -135,9 +135,8 @@ class ModelRunner: ...@@ -135,9 +135,8 @@ class ModelRunner:
self.graph_memory_pool: Optional[Tuple[ self.graph_memory_pool: Optional[Tuple[
int, int]] = None # Set during graph capture. int, int]] = None # Set during graph capture.
self.max_context_len_to_capture = ( self.max_seq_len_to_capture = (self.model_config.max_seq_len_to_capture
self.model_config.max_context_len_to_capture if self.model_config is not None else 0)
if self.model_config is not None else 0)
self.pin_memory = is_pin_memory_available() self.pin_memory = is_pin_memory_available()
self.kv_cache_dtype = kv_cache_dtype self.kv_cache_dtype = kv_cache_dtype
...@@ -150,13 +149,16 @@ class ModelRunner: ...@@ -150,13 +149,16 @@ class ModelRunner:
self.model: torch.nn.Module # Set after load_model self.model: torch.nn.Module # Set after load_model
self.block_size: int # Set after initial profiling. self.block_size: int # Set after initial profiling.
# When using CUDA graph, the input block tables must be padded to # When using CUDA graph, the input block tables must be padded to
# max_context_len_to_capture. However, creating the block table in # max_seq_len_to_capture. However, creating the block table in
# Python can be expensive. To optimize this, we cache the block table # Python can be expensive. To optimize this, we cache the block table
# in numpy and only copy the actual input content at every iteration. # in numpy and only copy the actual input content at every iteration.
# The shape of the cached block table will be # The shape of the cached block table will be
# (max batch size to capture, max context len to capture / block size). # (max batch size to capture, max context len to capture / block size).
self.graph_block_tables: torch.Tensor # Set after initial profiling. self.graph_block_tables: torch.Tensor # Set after initial profiling.
# Set if the backend is flashinfer.
self.flashinfer_workspace_buffer: torch.Tensor
def load_model(self) -> None: def load_model(self) -> None:
with CudaMemoryProfiler() as m: with CudaMemoryProfiler() as m:
self.model = get_model( self.model = get_model(
...@@ -170,8 +172,8 @@ class ModelRunner: ...@@ -170,8 +172,8 @@ class ModelRunner:
) )
self.model_memory_usage = m.consumed_memory self.model_memory_usage = m.consumed_memory
logger.info(f"Loading model weights took " logger.info("Loading model weights took %.4f GB",
f"{self.model_memory_usage / float(2**30):.4f} GB") self.model_memory_usage / float(2**30))
if self.lora_config: if self.lora_config:
assert hasattr(self.model, "supported_lora_modules" assert hasattr(self.model, "supported_lora_modules"
...@@ -196,18 +198,19 @@ class ModelRunner: ...@@ -196,18 +198,19 @@ class ModelRunner:
self.model.load_kv_cache_scales( self.model.load_kv_cache_scales(
self.model_config.quantization_param_path) self.model_config.quantization_param_path)
else: else:
raise RuntimeError("Using FP8 KV cache and scaling " raise RuntimeError(
"factors provided but model " "Using FP8 KV cache and scaling factors provided but "
f"{self.model.__class__} does not " "model %s does not support loading scaling factors.",
"support loading scaling factors.") self.model.__class__)
else: else:
logger.warn("Using FP8 KV cache but no scaling factors " logger.warning(
"provided. Defaulting to scaling factors of 1.0. " "Using FP8 KV cache but no scaling factors "
"This may lead to less accurate results!") "provided. Defaulting to scaling factors of 1.0. "
"This may lead to less accurate results!")
elif self.model_config.quantization_param_path is not None: elif self.model_config.quantization_param_path is not None:
logger.warn("KV cache scaling factors provided, " logger.warning("KV cache scaling factors provided, "
"but the KV cache data type is not FP8. " "but the KV cache data type is not FP8. "
"KV cache scaling factors will not be used.") "KV cache scaling factors will not be used.")
def set_block_size(self, block_size: int) -> None: def set_block_size(self, block_size: int) -> None:
self.block_size = block_size self.block_size = block_size
...@@ -218,7 +221,7 @@ class ModelRunner: ...@@ -218,7 +221,7 @@ class ModelRunner:
def get_max_block_per_batch(self) -> int: def get_max_block_per_batch(self) -> int:
block_size = self.block_size block_size = self.block_size
return (self.max_context_len_to_capture + block_size - 1) // block_size return (self.max_seq_len_to_capture + block_size - 1) // block_size
def _prepare_prompt( def _prepare_prompt(
self, self,
...@@ -231,9 +234,9 @@ class ModelRunner: ...@@ -231,9 +234,9 @@ class ModelRunner:
lora_prompt_mapping: List[int] = [] lora_prompt_mapping: List[int] = []
lora_requests: Set[LoRARequest] = set() lora_requests: Set[LoRARequest] = set()
prompt_lens: List[int] = [] seq_lens: List[int] = []
context_lens: List[int] = [] context_lens: List[int] = []
subquery_lens: List[int] = [] query_lens: List[int] = []
prefix_block_tables: List[List[int]] = [] prefix_block_tables: List[List[int]] = []
multi_modal_input_list: List[torch.Tensor] = [] multi_modal_input_list: List[torch.Tensor] = []
...@@ -257,21 +260,19 @@ class ModelRunner: ...@@ -257,21 +260,19 @@ class ModelRunner:
token_chunk_size = seq_group_metadata.token_chunk_size token_chunk_size = seq_group_metadata.token_chunk_size
seq_data = seq_group_metadata.seq_data[seq_id] seq_data = seq_group_metadata.seq_data[seq_id]
computed_len = seq_data.get_num_computed_tokens() context_len = seq_data.get_num_computed_tokens()
# We should use get_len here because in case of preemption # We should use get_len here because in case of preemption
# it contains output tokens. # it contains output tokens.
prefill_end = min(seq_data.get_len(), seq_len = min(seq_data.get_len(), context_len + token_chunk_size)
computed_len + token_chunk_size) prompt_tokens = seq_data.get_token_ids()[context_len:seq_len]
prompt_tokens = seq_data.get_token_ids()[computed_len:prefill_end] seq_lens.append(seq_len)
prompt_len = prefill_end
prompt_lens.append(prompt_len)
# NOTE: This only works for oooooooxxx style attention. # NOTE: This only works for oooooooxxx style attention.
if computed_block_nums is not None and len( if computed_block_nums is not None and len(
computed_block_nums) > 0 and self.sliding_window is None: computed_block_nums) > 0 and self.sliding_window is None:
# Prefix is not supported with sliding_window # Prefix is not supported with sliding_window
computed_len = len(computed_block_nums) * self.block_size context_len = len(computed_block_nums) * self.block_size
prompt_tokens = prompt_tokens[computed_len:] prompt_tokens = prompt_tokens[context_len:]
prefix_block_tables.append(computed_block_nums) prefix_block_tables.append(computed_block_nums)
elif self.scheduler_config.chunked_prefill_enabled: elif self.scheduler_config.chunked_prefill_enabled:
if seq_group_metadata.block_tables is not None: if seq_group_metadata.block_tables is not None:
...@@ -285,25 +286,25 @@ class ModelRunner: ...@@ -285,25 +286,25 @@ class ModelRunner:
prefix_block_tables.append([]) prefix_block_tables.append([])
# Right now, prefill start is always 0. However, this # Right now, prefill start is always 0. However, this
# assumption can be changed once chunked prefill is introduced. # assumption can be changed once chunked prefill is introduced.
assert computed_len == 0 assert context_len == 0
# actual prompt lens # actual prompt lens
context_lens.append(computed_len) context_lens.append(context_len)
subquery_lens.append(prompt_len - computed_len) query_lens.append(seq_len - context_len)
input_tokens.extend(prompt_tokens) input_tokens.extend(prompt_tokens)
# NOTE(woosuk): Here we assume that the first token in the prompt # NOTE(woosuk): Here we assume that the first token in the prompt
# is always the first token in the sequence. # is always the first token in the sequence.
input_positions.extend(list(range(computed_len, prefill_end))) input_positions.extend(list(range(context_len, seq_len)))
lora_id = seq_group_metadata.lora_int_id lora_id = seq_group_metadata.lora_int_id
if lora_id > 0: if lora_id > 0:
lora_requests.add(seq_group_metadata.lora_request) lora_requests.add(seq_group_metadata.lora_request)
lora_index_mapping += [lora_id] * (prompt_len - computed_len) lora_index_mapping += [lora_id] * (seq_len - context_len)
lora_prompt_mapping.extend( lora_prompt_mapping.extend(
[lora_id] * [lora_id] *
(prompt_len - computed_len (seq_len - context_len
if seq_group_metadata.sampling_params.prompt_logprobs else 1)) if seq_group_metadata.sampling_params.prompt_logprobs else 1))
if seq_group_metadata.multi_modal_data: if seq_group_metadata.multi_modal_data:
...@@ -313,24 +314,25 @@ class ModelRunner: ...@@ -313,24 +314,25 @@ class ModelRunner:
if seq_group_metadata.block_tables is None: if seq_group_metadata.block_tables is None:
# During memory profiling, the block tables are not initialized # During memory profiling, the block tables are not initialized
# yet. In this case, we just use a dummy slot mapping. # yet. In this case, we just use a dummy slot mapping.
slot_mapping.extend([_PAD_SLOT_ID] * prompt_len) slot_mapping.extend([_PAD_SLOT_ID] * seq_len)
continue continue
# Compute the slot mapping. # Compute the slot mapping.
block_table = seq_group_metadata.block_tables[seq_id] block_table = seq_group_metadata.block_tables[seq_id]
# Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID, # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
# where start_idx is max(0, prompt_len - sliding_window). # where start_idx is max(0, seq_len - sliding_window).
# For example, if the prompt len is 10, sliding window is 8, and # For example, if the prompt len is 10, sliding window is 8, and
# block size is 4, the first two tokens are masked and the slot # block size is 4, the first two tokens are masked and the slot
# mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
start_idx = 0 start_idx = 0
if self.sliding_window is not None: if self.sliding_window is not None:
assert computed_len == 0, ( assert context_len == 0, (
"Prefix caching is currently not supported with " "Prefix caching is currently not supported with "
"sliding window attention") "sliding window attention")
start_idx = max(0, prompt_len - self.sliding_window) start_idx = max(0, seq_len - self.sliding_window)
for i in range(computed_len, prefill_end): for i in range(context_len, seq_len):
if i < start_idx: if i < start_idx:
slot_mapping.append(_PAD_SLOT_ID) slot_mapping.append(_PAD_SLOT_ID)
continue continue
...@@ -340,9 +342,9 @@ class ModelRunner: ...@@ -340,9 +342,9 @@ class ModelRunner:
slot = block_number * self.block_size + block_offset slot = block_number * self.block_size + block_offset
slot_mapping.append(slot) slot_mapping.append(slot)
max_subquery_len = max(subquery_lens) max_query_len = max(query_lens)
max_prompt_len = max(prompt_lens) max_seq_len = max(seq_lens)
assert max_subquery_len > 0 assert max_query_len > 0
context_lens_tensor = torch.tensor(context_lens, context_lens_tensor = torch.tensor(context_lens,
dtype=torch.int, dtype=torch.int,
...@@ -369,50 +371,57 @@ class ModelRunner: ...@@ -369,50 +371,57 @@ class ModelRunner:
# Query length can be shorter than key (i.e., prompt) when prefill # Query length can be shorter than key (i.e., prompt) when prefill
# is chunked or prefix cached. # is chunked or prefix cached.
subquery_lens_tensor = torch.tensor(subquery_lens, query_lens_tensor = torch.tensor(query_lens,
dtype=torch.long, dtype=torch.long,
device=self.device) device=self.device)
subquery_start_loc = torch.zeros(subquery_lens_tensor.shape[0] + 1, subquery_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
dtype=torch.int32, dtype=torch.int32,
device=self.device) device=self.device)
prompt_lens_tensor = torch.tensor(prompt_lens, seq_lens_tensor = torch.tensor(seq_lens,
dtype=torch.long, dtype=torch.int,
device=self.device) device=self.device)
seq_start_loc = torch.zeros(prompt_lens_tensor.shape[0] + 1, seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
dtype=torch.int32, dtype=torch.int32,
device=self.device) device=self.device)
torch.cumsum(subquery_lens_tensor, torch.cumsum(query_lens_tensor,
dim=0, dim=0,
dtype=subquery_start_loc.dtype, dtype=subquery_start_loc.dtype,
out=subquery_start_loc[1:]) out=subquery_start_loc[1:])
torch.cumsum(prompt_lens_tensor, torch.cumsum(seq_lens_tensor,
dim=0, dim=0,
dtype=seq_start_loc.dtype, dtype=seq_start_loc.dtype,
out=seq_start_loc[1:]) out=seq_start_loc[1:])
attn_metadata = self.attn_backend.make_metadata( if self.attn_backend is FlashInferBackend:
is_prompt=True, attn_metadata = self.attn_backend.make_metadata(
prompt_lens=prompt_lens, is_prompt=True,
prompt_lens_tensor=prompt_lens_tensor, use_cuda_graph=False,
max_subquery_len=max_subquery_len, seq_start_loc=seq_start_loc,
max_context_len=None, max_seq_len=max_seq_len,
max_prompt_len=max_prompt_len, block_tables=block_tables)
subquery_start_loc=subquery_start_loc, else:
seq_start_loc=seq_start_loc, attn_metadata = self.attn_backend.make_metadata(
context_lens=context_lens_tensor, is_prompt=True,
block_tables=block_tables, seq_lens=seq_lens,
use_cuda_graph=False, seq_lens_tensor=seq_lens_tensor,
) max_query_len=max_query_len,
max_seq_len=max_seq_len,
subquery_start_loc=subquery_start_loc,
seq_start_loc=seq_start_loc,
context_lens_tensor=context_lens_tensor,
block_tables=block_tables,
use_cuda_graph=False,
)
return PreparePromptMetadata( return PreparePromptMetadata(
input_tokens=input_tokens, input_tokens=input_tokens,
input_positions=input_positions, input_positions=input_positions,
attn_metadata=attn_metadata, attn_metadata=attn_metadata,
prompt_lens=prompt_lens, seq_lens=seq_lens,
subquery_lens=subquery_lens, query_lens=query_lens,
lora_index_mapping=lora_index_mapping, lora_index_mapping=lora_index_mapping,
lora_prompt_mapping=lora_prompt_mapping, lora_prompt_mapping=lora_prompt_mapping,
lora_requests=lora_requests, lora_requests=lora_requests,
...@@ -427,12 +436,30 @@ class ModelRunner: ...@@ -427,12 +436,30 @@ class ModelRunner:
input_tokens: List[int] = [] input_tokens: List[int] = []
input_positions: List[int] = [] input_positions: List[int] = []
slot_mapping: List[int] = [] slot_mapping: List[int] = []
context_lens: List[int] = [] seq_lens: List[int] = []
block_tables: List[List[int]] = [] block_tables: List[List[int]] = []
lora_index_mapping: List[int] = [] lora_index_mapping: List[int] = []
lora_prompt_mapping: List[int] = [] lora_prompt_mapping: List[int] = []
lora_requests: Set[LoRARequest] = set() lora_requests: Set[LoRARequest] = set()
# The following fields are only for flashinfer
# Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout
# for the precise definition of the following fields.
# An example:
# request 1, page indices [0, 5, 8]
# request 2, page indices [1, 6, 7]
# request 3, page indices [3, 4]
# paged_kv_indices is a concatenation of page indices of all requests:
# [0, 5, 8, 1, 6, 7, 3, 4]
# paged_kv_indptr is used to index into paged_kv_indices:
# [0, 3, 6, 8]
paged_kv_indices: List[int] = []
# 0 at the beginning of paged_kv_indptr indicates the start of the
# first request’s page indices in the paged_kv_indices list.
paged_kv_indptr: List[int] = [0]
# paged_kv_last_page_len is the length of the last page of each request
paged_kv_last_page_len: List[int] = []
if len(seq_group_metadata_list) == 0: if len(seq_group_metadata_list) == 0:
return PrepareDecodeMetadata.empty() return PrepareDecodeMetadata.empty()
...@@ -455,9 +482,9 @@ class ModelRunner: ...@@ -455,9 +482,9 @@ class ModelRunner:
position = seq_len - 1 position = seq_len - 1
input_positions.append(position) input_positions.append(position)
context_len = seq_len if self.sliding_window is None else min( seq_len = seq_len if self.sliding_window is None else min(
seq_len, self.sliding_window) seq_len, self.sliding_window)
context_lens.append(context_len) seq_lens.append(seq_len)
block_table = seq_group_metadata.block_tables[seq_id] block_table = seq_group_metadata.block_tables[seq_id]
block_number = block_table[position // self.block_size] block_number = block_table[position // self.block_size]
...@@ -473,15 +500,21 @@ class ModelRunner: ...@@ -473,15 +500,21 @@ class ModelRunner:
block_table = block_table[-sliding_window_blocks:] block_table = block_table[-sliding_window_blocks:]
block_tables.append(block_table) block_tables.append(block_table)
paged_kv_indices.extend(block_table)
paged_kv_indptr.append(paged_kv_indptr[-1] + len(block_table))
last_page_len = seq_data.get_len() % self.block_size
if last_page_len == 0:
last_page_len = self.block_size
paged_kv_last_page_len.append(last_page_len)
# vLLM uses cuda graph only for decoding requests. # vLLM uses cuda graph only for decoding requests.
# See `capture_model` API for more details. # See `capture_model` API for more details.
# For decoding requests, batch_size == input_tokens. # For decoding requests, batch_size == input_tokens.
batch_size = len(input_tokens) batch_size = len(input_tokens)
max_context_len = max(context_lens) max_seq_len = max(seq_lens)
use_captured_graph = ( use_captured_graph = (not self.model_config.enforce_eager
not self.model_config.enforce_eager and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1] and max_seq_len <= self.max_seq_len_to_capture)
and max_context_len <= self.max_context_len_to_capture)
if use_captured_graph: if use_captured_graph:
graph_batch_size = _get_graph_batch_size(batch_size) graph_batch_size = _get_graph_batch_size(batch_size)
assert graph_batch_size >= batch_size assert graph_batch_size >= batch_size
...@@ -489,21 +522,21 @@ class ModelRunner: ...@@ -489,21 +522,21 @@ class ModelRunner:
input_tokens.append(0) input_tokens.append(0)
input_positions.append(0) input_positions.append(0)
slot_mapping.append(_PAD_SLOT_ID) slot_mapping.append(_PAD_SLOT_ID)
context_lens.append(1) seq_lens.append(1)
block_tables.append([]) block_tables.append([])
lora_index_mapping.append(0) lora_index_mapping.append(0)
batch_size = graph_batch_size batch_size = graph_batch_size
context_lens_tensor = torch.tensor(context_lens, seq_lens_tensor = torch.tensor(seq_lens,
dtype=torch.int, dtype=torch.int,
device=self.device) device=self.device)
if use_captured_graph: if use_captured_graph:
# When using cuda-graph all these tensors should be # When using cuda-graph all these tensors should be
# padded. # padded.
assert context_lens_tensor.shape[0] == len(input_tokens) assert seq_lens_tensor.shape[0] == len(input_tokens)
assert context_lens_tensor.shape[0] == len(input_positions) assert seq_lens_tensor.shape[0] == len(input_positions)
assert context_lens_tensor.shape[0] == len(slot_mapping) assert seq_lens_tensor.shape[0] == len(slot_mapping)
# The shape of graph_block_tables is # The shape of graph_block_tables is
# [max batch size, max context len // block size]. # [max batch size, max context len // block size].
...@@ -523,19 +556,51 @@ class ModelRunner: ...@@ -523,19 +556,51 @@ class ModelRunner:
device=self.device, device=self.device,
) )
attn_metadata = self.attn_backend.make_metadata( if self.attn_backend is FlashInferBackend:
is_prompt=False, if not hasattr(self, "flashinfer_workspace_buffer"):
prompt_lens=None, # Allocate 16MB workspace buffer
prompt_lens_tensor=None, # Follow the example of flashinfer: https://docs.flashinfer.ai/api/python/decode.html
max_subquery_len=None, self.flashinfer_workspace_buffer = torch.empty(
max_context_len=max_context_len, 16 * 1024 * 1024, dtype=torch.uint8, device=self.device)
max_prompt_len=None, paged_kv_indptr = torch.tensor(paged_kv_indptr,
subquery_start_loc=None, dtype=torch.int,
seq_start_loc=None, device=self.device)
context_lens=context_lens_tensor, paged_kv_indices = torch.tensor(paged_kv_indices,
block_tables=block_tables, dtype=torch.int,
use_cuda_graph=use_captured_graph, device=self.device)
) paged_kv_last_page_len = torch.tensor(paged_kv_last_page_len,
dtype=torch.int,
device=self.device)
kv_cache_dtype = get_kv_cache_torch_dtype(self.kv_cache_dtype,
self.model_config.dtype)
attn_metadata = self.attn_backend.make_metadata(
is_prompt=False,
use_cuda_graph=False,
workspace_buffer=self.flashinfer_workspace_buffer,
paged_kv_indptr=paged_kv_indptr,
paged_kv_indices=paged_kv_indices,
paged_kv_last_page_len=paged_kv_last_page_len,
num_qo_heads=self.model_config.get_num_attention_heads(
self.parallel_config),
num_kv_heads=self.model_config.get_num_kv_heads(
self.parallel_config),
head_dim=self.model_config.get_head_size(),
page_size=self.block_size,
data_type=kv_cache_dtype)
else:
attn_metadata = self.attn_backend.make_metadata(
is_prompt=False,
seq_lens=None,
seq_lens_tensor=seq_lens_tensor,
max_query_len=None,
max_seq_len=max_seq_len,
subquery_start_loc=None,
seq_start_loc=None,
context_lens_tensor=None,
block_tables=block_tables,
use_cuda_graph=use_captured_graph,
)
return PrepareDecodeMetadata( return PrepareDecodeMetadata(
input_tokens=input_tokens, input_tokens=input_tokens,
input_positions=input_positions, input_positions=input_positions,
...@@ -546,108 +611,6 @@ class ModelRunner: ...@@ -546,108 +611,6 @@ class ModelRunner:
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
) )
def _prepare_sample(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
prompt_lens: List[int],
subquery_lens: Optional[List[int]],
) -> SamplingMetadata:
seq_groups: List[Tuple[List[int], SamplingParams]] = []
selected_token_indices: List[int] = []
generators: List[torch.Generator] = []
selected_token_start_idx = 0
categorized_sample_indices: Dict[SamplingType,
List[Tuple[int, int]]] = {
t: []
for t in SamplingType
}
categorized_sample_indices_start_idx = 0
categorized_sampled_token_indices_start_idx = 0
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
seq_ids = list(seq_group_metadata.seq_data.keys())
sampling_params = seq_group_metadata.sampling_params
seq_groups.append((seq_ids, sampling_params))
if seq_group_metadata.is_prompt:
assert len(seq_ids) == 1
assert subquery_lens is not None
subquery_len = subquery_lens[i]
if sampling_params.prompt_logprobs is not None:
# NOTE: prompt token positions do not need sample, skip
categorized_sample_indices_start_idx += subquery_len - 1
categorized_sample_indices[
sampling_params.sampling_type].append(
(categorized_sample_indices_start_idx,
categorized_sampled_token_indices_start_idx))
categorized_sample_indices_start_idx += 1
categorized_sampled_token_indices_start_idx += 1
if sampling_params.prompt_logprobs is not None:
selected_token_indices.extend(
range(selected_token_start_idx,
selected_token_start_idx + subquery_len - 1))
selected_token_indices.append(selected_token_start_idx +
subquery_len - 1)
selected_token_start_idx += subquery_len
if sampling_params.seed is not None:
seq_group_metadata.state.generator = torch.Generator(
device=self.device).manual_seed(sampling_params.seed)
else:
num_seqs = len(seq_ids)
selected_token_indices.extend(
range(selected_token_start_idx,
selected_token_start_idx + num_seqs))
selected_token_start_idx += num_seqs
categorized_sample_indices[
sampling_params.sampling_type].extend(
list(
zip(
range(
categorized_sample_indices_start_idx,
categorized_sample_indices_start_idx +
num_seqs),
range(
categorized_sampled_token_indices_start_idx,
categorized_sampled_token_indices_start_idx
+ num_seqs))))
categorized_sample_indices_start_idx += num_seqs
categorized_sampled_token_indices_start_idx += num_seqs
if sampling_params.seed is not None:
generators.append(seq_group_metadata.state.generator)
selected_token_indices = async_tensor_h2d(selected_token_indices,
dtype=torch.long,
target_device=self.device,
pin_memory=self.pin_memory)
categorized_sample_indices = {
t: maybe_expand_dim(
async_tensor_h2d(seq_ids,
dtype=torch.int,
target_device=self.device,
pin_memory=self.pin_memory), 2, 2)
for t, seq_ids in categorized_sample_indices.items()
}
seq_data: Dict[int, SequenceData] = {}
for seq_group_metadata in seq_group_metadata_list:
seq_data.update(seq_group_metadata.seq_data)
sampling_metadata = SamplingMetadata(
seq_groups=seq_groups,
seq_data=seq_data,
prompt_lens=prompt_lens,
selected_token_indices=selected_token_indices,
categorized_sample_indices=categorized_sample_indices,
generators=generators,
)
return sampling_metadata
def prepare_input_tensors( def prepare_input_tensors(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
...@@ -667,8 +630,8 @@ class ModelRunner: ...@@ -667,8 +630,8 @@ class ModelRunner:
input_tokens, input_tokens,
input_positions, input_positions,
prefill_attn_metadata, prefill_attn_metadata,
prompt_lens, seq_lens,
subquery_lens, query_lens,
lora_index_mapping, lora_index_mapping,
lora_prompt_mapping, lora_prompt_mapping,
lora_requests, lora_requests,
...@@ -684,14 +647,14 @@ class ModelRunner: ...@@ -684,14 +647,14 @@ class ModelRunner:
decode_lora_requests, decode_lora_requests,
decode_slot_mapping, decode_slot_mapping,
) = self._prepare_decode(decode_reqs) ) = self._prepare_decode(decode_reqs)
sampling_metadata = self._prepare_sample(seq_group_metadata_list, sampling_metadata = SamplingMetadata.prepare(
prompt_lens, seq_group_metadata_list, seq_lens, query_lens, self.device,
subquery_lens) self.pin_memory)
if not self.scheduler_config.chunked_prefill_enabled: if not self.scheduler_config.chunked_prefill_enabled:
assert (len(prefill_reqs) and len(decode_reqs)) == 0 assert (len(prefill_reqs) and len(decode_reqs)) == 0
num_prefills = len(prompt_lens) num_prefills = len(seq_lens)
num_prefill_tokens = len(input_tokens) num_prefill_tokens = len(input_tokens)
num_decode_tokens = len(decode_input_tokens) num_decode_tokens = len(decode_input_tokens)
...@@ -787,12 +750,9 @@ class ModelRunner: ...@@ -787,12 +750,9 @@ class ModelRunner:
**metadata_dict) **metadata_dict)
sampling_metadata = SamplingMetadata( sampling_metadata = SamplingMetadata(
seq_groups=None, seq_groups=None,
seq_data=None,
prompt_lens=None,
selected_token_indices=selected_token_indices, selected_token_indices=selected_token_indices,
categorized_sample_indices=None, categorized_sample_indices=None,
generators=None, num_prompts=0,
perform_sampling=False,
) )
# if it is a mixed batch, decode attn_metadata is broadcasted # if it is a mixed batch, decode attn_metadata is broadcasted
...@@ -851,7 +811,7 @@ class ModelRunner: ...@@ -851,7 +811,7 @@ class ModelRunner:
logits = self.model.compute_logits(hidden_states, sampling_metadata) logits = self.model.compute_logits(hidden_states, sampling_metadata)
# Only perform sampling in the driver worker. # Only perform sampling in the driver worker.
if not sampling_metadata.perform_sampling: if not self.is_driver_worker:
return None return None
# Sample the next token. # Sample the next token.
...@@ -859,6 +819,7 @@ class ModelRunner: ...@@ -859,6 +819,7 @@ class ModelRunner:
logits=logits, logits=logits,
sampling_metadata=sampling_metadata, sampling_metadata=sampling_metadata,
) )
return output return output
@torch.inference_mode() @torch.inference_mode()
...@@ -928,10 +889,10 @@ class ModelRunner: ...@@ -928,10 +889,10 @@ class ModelRunner:
torch.cuda.synchronize() torch.cuda.synchronize()
return return
def remove_all_loras(self) -> bool: def remove_all_loras(self):
if not self.lora_manager: if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.") raise RuntimeError("LoRA is not enabled.")
return self.lora_manager.remove_all_loras() self.lora_manager.remove_all_loras()
def set_active_loras(self, lora_requests: Set[LoRARequest], def set_active_loras(self, lora_requests: Set[LoRARequest],
lora_mapping: LoRAMapping) -> None: lora_mapping: LoRAMapping) -> None:
...@@ -990,7 +951,7 @@ class ModelRunner: ...@@ -990,7 +951,7 @@ class ModelRunner:
input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda() input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda()
slot_mapping = torch.empty(max_batch_size, dtype=torch.long).cuda() slot_mapping = torch.empty(max_batch_size, dtype=torch.long).cuda()
slot_mapping.fill_(_PAD_SLOT_ID) slot_mapping.fill_(_PAD_SLOT_ID)
context_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda() seq_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
block_tables = torch.from_numpy(self.graph_block_tables).cuda() block_tables = torch.from_numpy(self.graph_block_tables).cuda()
graph_batch_size = _get_graph_batch_size( graph_batch_size = _get_graph_batch_size(
...@@ -1012,14 +973,13 @@ class ModelRunner: ...@@ -1012,14 +973,13 @@ class ModelRunner:
# Create dummy attn_metadata. # Create dummy attn_metadata.
decode_metadata = self.attn_backend.make_metadata( decode_metadata = self.attn_backend.make_metadata(
is_prompt=False, is_prompt=False,
prompt_lens=None, seq_lens=None,
prompt_lens_tensor=None, seq_lens_tensor=seq_lens[:batch_size],
max_subquery_len=None, max_query_len=None,
max_context_len=self.max_context_len_to_capture, max_seq_len=self.max_seq_len_to_capture,
max_prompt_len=None,
subquery_start_loc=None, subquery_start_loc=None,
seq_start_loc=None, seq_start_loc=None,
context_lens=context_lens[:batch_size], context_lens_tensor=None,
block_tables=block_tables[:batch_size], block_tables=block_tables[:batch_size],
use_cuda_graph=True, use_cuda_graph=True,
) )
...@@ -1054,7 +1014,7 @@ class ModelRunner: ...@@ -1054,7 +1014,7 @@ class ModelRunner:
end_time = time.perf_counter() end_time = time.perf_counter()
elapsed_time = end_time - start_time elapsed_time = end_time - start_time
# This usually takes < 10 seconds. # This usually takes < 10 seconds.
logger.info(f"Graph capturing finished in {elapsed_time:.0f} secs.") logger.info("Graph capturing finished in %.0f secs.", elapsed_time)
def __del__(self) -> None: def __del__(self) -> None:
# Delete the CUDA graphs before deleting the pynccl communicator. # Delete the CUDA graphs before deleting the pynccl communicator.
...@@ -1129,7 +1089,7 @@ class CUDAGraphRunner: ...@@ -1129,7 +1089,7 @@ class CUDAGraphRunner:
"positions": positions, "positions": positions,
"kv_caches": kv_caches, "kv_caches": kv_caches,
"slot_mapping": attn_metadata.slot_mapping, "slot_mapping": attn_metadata.slot_mapping,
"context_lens": attn_metadata.decode_metadata.context_lens, "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor,
"block_tables": attn_metadata.decode_metadata.block_tables, "block_tables": attn_metadata.decode_metadata.block_tables,
} }
self.output_buffers = {"hidden_states": hidden_states} self.output_buffers = {"hidden_states": hidden_states}
...@@ -1151,8 +1111,8 @@ class CUDAGraphRunner: ...@@ -1151,8 +1111,8 @@ class CUDAGraphRunner:
self.input_buffers["positions"].copy_(positions, non_blocking=True) self.input_buffers["positions"].copy_(positions, non_blocking=True)
self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping, self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping,
non_blocking=True) non_blocking=True)
self.input_buffers["context_lens"].copy_( self.input_buffers["seq_lens_tensor"].copy_(
attn_metadata.decode_metadata.context_lens, non_blocking=True) attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
self.input_buffers["block_tables"].copy_( self.input_buffers["block_tables"].copy_(
attn_metadata.decode_metadata.block_tables, non_blocking=True) attn_metadata.decode_metadata.block_tables, non_blocking=True)
# Run the graph. # Run the graph.
......
from typing import Dict, List, Optional, Tuple from typing import List, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
...@@ -8,10 +8,8 @@ from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig, ...@@ -8,10 +8,8 @@ from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader.neuron import get_neuron_model from vllm.model_executor.model_loader.neuron import get_neuron_model
from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata from vllm.utils import is_pin_memory_available, make_tensor_with_pad
from vllm.utils import (async_tensor_h2d, is_pin_memory_available,
make_tensor_with_pad, maybe_expand_dim)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -54,7 +52,7 @@ class NeuronModelRunner: ...@@ -54,7 +52,7 @@ class NeuronModelRunner:
input_positions: List[List[int]] = [] input_positions: List[List[int]] = []
input_block_ids: List[int] = [] input_block_ids: List[int] = []
prompt_lens: List[int] = [] seq_lens: List[int] = []
for seq_group_metadata in seq_group_metadata_list: for seq_group_metadata in seq_group_metadata_list:
assert seq_group_metadata.is_prompt assert seq_group_metadata.is_prompt
seq_ids = list(seq_group_metadata.seq_data.keys()) seq_ids = list(seq_group_metadata.seq_data.keys())
...@@ -63,26 +61,26 @@ class NeuronModelRunner: ...@@ -63,26 +61,26 @@ class NeuronModelRunner:
seq_data = seq_group_metadata.seq_data[seq_id] seq_data = seq_group_metadata.seq_data[seq_id]
prompt_tokens = seq_data.get_token_ids() prompt_tokens = seq_data.get_token_ids()
prompt_len = len(prompt_tokens) seq_len = len(prompt_tokens)
prompt_lens.append(prompt_len) seq_lens.append(seq_len)
input_tokens.append(prompt_tokens) input_tokens.append(prompt_tokens)
input_positions.append(list(range(prompt_len))) input_positions.append(list(range(seq_len)))
assert seq_group_metadata.block_tables is not None assert seq_group_metadata.block_tables is not None
block_table = seq_group_metadata.block_tables[seq_id] block_table = seq_group_metadata.block_tables[seq_id]
assert len(block_table) == 1 assert len(block_table) == 1
input_block_ids.append(block_table[0]) input_block_ids.append(block_table[0])
max_prompt_len = max(prompt_lens) max_seq_len = max(seq_lens)
assert max_prompt_len > 0 assert max_seq_len > 0
input_tokens = make_tensor_with_pad(input_tokens, input_tokens = make_tensor_with_pad(input_tokens,
max_prompt_len, max_seq_len,
pad=0, pad=0,
dtype=torch.long, dtype=torch.long,
device=self.device) device=self.device)
input_positions = make_tensor_with_pad(input_positions, input_positions = make_tensor_with_pad(input_positions,
max_prompt_len, max_seq_len,
pad=0, pad=0,
dtype=torch.long, dtype=torch.long,
device=self.device) device=self.device)
...@@ -90,7 +88,7 @@ class NeuronModelRunner: ...@@ -90,7 +88,7 @@ class NeuronModelRunner:
dtype=torch.long, dtype=torch.long,
device=self.device) device=self.device)
return input_tokens, input_positions, input_block_ids, prompt_lens return input_tokens, input_positions, input_block_ids, seq_lens
def _prepare_decode( def _prepare_decode(
self, self,
...@@ -141,106 +139,6 @@ class NeuronModelRunner: ...@@ -141,106 +139,6 @@ class NeuronModelRunner:
return input_tokens, input_positions, input_block_ids return input_tokens, input_positions, input_block_ids
def _prepare_sample(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
prompt_lens: List[int],
) -> SamplingMetadata:
seq_groups: List[Tuple[List[int], SamplingParams]] = []
selected_token_indices: List[int] = []
generators: List[torch.Generator] = []
selected_token_start_idx = 0
categorized_sample_indices: Dict[SamplingType,
List[Tuple[int, int]]] = {
t: []
for t in SamplingType
}
categorized_sample_indices_start_idx = 0
categorized_sampled_token_indices_start_idx = 0
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
seq_ids = list(seq_group_metadata.seq_data.keys())
sampling_params = seq_group_metadata.sampling_params
seq_groups.append((seq_ids, sampling_params))
if seq_group_metadata.is_prompt:
assert len(seq_ids) == 1
assert prompt_lens is not None
prompt_len = prompt_lens[i]
if sampling_params.prompt_logprobs is not None:
# NOTE: prompt token positions do not need sample, skip
categorized_sample_indices_start_idx += prompt_len - 1
categorized_sample_indices[
sampling_params.sampling_type].append(
(categorized_sample_indices_start_idx,
categorized_sampled_token_indices_start_idx))
categorized_sample_indices_start_idx += 1
categorized_sampled_token_indices_start_idx += 1
if sampling_params.prompt_logprobs is not None:
selected_token_indices.extend(
range(selected_token_start_idx,
selected_token_start_idx + prompt_len - 1))
selected_token_indices.append(selected_token_start_idx +
prompt_len - 1)
selected_token_start_idx += prompt_len
if sampling_params.seed is not None:
seq_group_metadata.state.generator = torch.Generator(
device=self.device).manual_seed(sampling_params.seed)
else:
num_seqs = len(seq_ids)
selected_token_indices.extend(
range(selected_token_start_idx,
selected_token_start_idx + num_seqs))
selected_token_start_idx += num_seqs
categorized_sample_indices[
sampling_params.sampling_type].extend(
zip(
range(
categorized_sample_indices_start_idx,
categorized_sample_indices_start_idx +
num_seqs),
range(
categorized_sampled_token_indices_start_idx,
categorized_sampled_token_indices_start_idx +
num_seqs)))
categorized_sample_indices_start_idx += num_seqs
categorized_sampled_token_indices_start_idx += num_seqs
if sampling_params.seed is not None:
generators.append(seq_group_metadata.state.generator)
selected_token_indices = async_tensor_h2d(selected_token_indices,
dtype=torch.long,
target_device=self.device,
pin_memory=self.pin_memory)
categorized_sample_indices = {
t: maybe_expand_dim(
async_tensor_h2d(seq_ids,
dtype=torch.int,
target_device=self.device,
pin_memory=self.pin_memory), 2, 2)
for t, seq_ids in categorized_sample_indices.items()
}
seq_data: Dict[int, SequenceData] = {}
for seq_group_metadata in seq_group_metadata_list:
seq_data.update(seq_group_metadata.seq_data)
sampling_metadata = SamplingMetadata(
seq_groups=seq_groups,
seq_data=seq_data,
prompt_lens=prompt_lens,
selected_token_indices=selected_token_indices,
categorized_sample_indices=categorized_sample_indices,
generators=generators,
)
return sampling_metadata
def prepare_input_tensors( def prepare_input_tensors(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
...@@ -251,13 +149,20 @@ class NeuronModelRunner: ...@@ -251,13 +149,20 @@ class NeuronModelRunner:
# Prepare input tensors. # Prepare input tensors.
if is_prompt: if is_prompt:
(input_tokens, input_positions, input_block_ids, (input_tokens, input_positions, input_block_ids,
prompt_lens) = self._prepare_prompt(seq_group_metadata_list) seq_lens) = self._prepare_prompt(seq_group_metadata_list)
else: else:
(input_tokens, input_positions, (input_tokens, input_positions,
input_block_ids) = self._prepare_decode(seq_group_metadata_list) input_block_ids) = self._prepare_decode(seq_group_metadata_list)
prompt_lens = [] seq_lens = []
sampling_metadata = self._prepare_sample(seq_group_metadata_list, sampling_metadata = SamplingMetadata.prepare(
prompt_lens) seq_group_metadata_list,
seq_lens,
# query_lens is not needed if chunked prefill is not
# supported. Since neuron worker doesn't support chunked prefill
# just use seq_lens instead.
seq_lens,
self.device,
self.pin_memory)
return (input_tokens, input_positions, input_block_ids, return (input_tokens, input_positions, input_block_ids,
sampling_metadata) sampling_metadata)
......
...@@ -11,13 +11,14 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ...@@ -11,13 +11,14 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
VisionLanguageConfig) VisionLanguageConfig)
from vllm.distributed import (broadcast_tensor_dict, from vllm.distributed import (broadcast_tensor_dict,
ensure_model_parallel_initialized, ensure_model_parallel_initialized,
get_tensor_model_parallel_cpu_group,
init_distributed_environment) init_distributed_environment)
from vllm.distributed.device_communicators import pynccl_utils from vllm.distributed.device_communicators import pynccl_utils
from vllm.distributed.device_communicators.custom_all_reduce import ( from vllm.distributed.device_communicators.custom_all_reduce import (
init_custom_ar) init_custom_ar)
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed from vllm.model_executor import set_random_seed
from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.worker.cache_engine import CacheEngine from vllm.worker.cache_engine import CacheEngine
from vllm.worker.model_runner import ModelRunner from vllm.worker.model_runner import ModelRunner
from vllm.worker.worker_base import WorkerBase from vllm.worker.worker_base import WorkerBase
...@@ -210,19 +211,21 @@ class Worker(WorkerBase): ...@@ -210,19 +211,21 @@ class Worker(WorkerBase):
@torch.inference_mode() @torch.inference_mode()
def execute_model( def execute_model(
self, self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None, execute_model_req: Optional[ExecuteModelRequest] = None
blocks_to_swap_in: Optional[Dict[int, int]] = None,
blocks_to_swap_out: Optional[Dict[int, int]] = None,
blocks_to_copy: Optional[Dict[int, List[int]]] = None,
num_lookahead_slots: int = 0,
) -> List[SamplerOutput]: ) -> List[SamplerOutput]:
if execute_model_req is None:
seq_group_metadata_list = None
else:
seq_group_metadata_list = execute_model_req.seq_group_metadata_list
if self.is_driver_worker: if self.is_driver_worker:
assert seq_group_metadata_list is not None assert seq_group_metadata_list is not None
assert execute_model_req is not None
num_seq_groups = len(seq_group_metadata_list) num_seq_groups = len(seq_group_metadata_list)
assert blocks_to_swap_in is not None blocks_to_swap_in = execute_model_req.blocks_to_swap_in
assert blocks_to_swap_out is not None blocks_to_swap_out = execute_model_req.blocks_to_swap_out
assert blocks_to_copy is not None blocks_to_copy = execute_model_req.blocks_to_copy
data: Dict[str, Any] = { data: Dict[str, Any] = {
"num_seq_groups": num_seq_groups, "num_seq_groups": num_seq_groups,
"blocks_to_swap_in": blocks_to_swap_in, "blocks_to_swap_in": blocks_to_swap_in,
...@@ -237,9 +240,6 @@ class Worker(WorkerBase): ...@@ -237,9 +240,6 @@ class Worker(WorkerBase):
blocks_to_swap_out = data["blocks_to_swap_out"] blocks_to_swap_out = data["blocks_to_swap_out"]
blocks_to_copy = data["blocks_to_copy"] blocks_to_copy = data["blocks_to_copy"]
assert blocks_to_swap_in is not None
assert blocks_to_swap_out is not None
assert blocks_to_copy is not None
self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy) self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)
# If there is no input, we don't need to execute the model. # If there is no input, we don't need to execute the model.
...@@ -288,6 +288,9 @@ def init_worker_distributed_environment( ...@@ -288,6 +288,9 @@ def init_worker_distributed_environment(
init_distributed_environment(parallel_config.world_size, rank, init_distributed_environment(parallel_config.world_size, rank,
distributed_init_method, local_rank) distributed_init_method, local_rank)
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)
if pynccl_utils.is_initialized(): if pynccl_utils.is_initialized():
pynccl_world_size = pynccl_utils.get_world_size() pynccl_world_size = pynccl_utils.get_world_size()
if pynccl_world_size != parallel_config.world_size: if pynccl_world_size != parallel_config.world_size:
...@@ -298,12 +301,9 @@ def init_worker_distributed_environment( ...@@ -298,12 +301,9 @@ def init_worker_distributed_environment(
elif parallel_config.world_size > 1: elif parallel_config.world_size > 1:
# NOTE(woosuk): We don't initialize pynccl process group when world size # NOTE(woosuk): We don't initialize pynccl process group when world size
# is 1. # is 1.
# NOTE(kaichao): By default, pynccl will use information inside # NOTE(kaichao): By default, pynccl is initialized for tp group.
# `parallel_state` for initialization. pynccl_utils.init_process_group(
pynccl_utils.init_process_group() group=get_tensor_model_parallel_cpu_group())
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)
# Initialize a custom fast all-reduce implementation. # Initialize a custom fast all-reduce implementation.
if not parallel_config.disable_custom_all_reduce: if not parallel_config.disable_custom_all_reduce:
......
import datetime
import importlib import importlib
import os import os
import tempfile
import threading
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Dict, List, Set, Tuple from typing import Dict, List, Set, Tuple
from vllm.logger import enable_trace_function_call, init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.utils import get_vllm_instance_id, update_environment_variables from vllm.utils import (enable_trace_function_call_for_thread,
update_environment_variables)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -50,10 +48,8 @@ class WorkerBase(ABC): ...@@ -50,10 +48,8 @@ class WorkerBase(ABC):
@abstractmethod @abstractmethod
def execute_model( def execute_model(
self, seq_group_metadata_list: List[SequenceGroupMetadata], self,
blocks_to_swap_in: Dict[int, int], blocks_to_swap_out: Dict[int, execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
int],
blocks_to_copy: Dict[int, List[int]]) -> List[SamplerOutput]:
"""Executes at least one model step on the given sequences, unless no """Executes at least one model step on the given sequences, unless no
sequences are provided.""" sequences are provided."""
raise NotImplementedError raise NotImplementedError
...@@ -128,15 +124,7 @@ class WorkerWrapperBase: ...@@ -128,15 +124,7 @@ class WorkerWrapperBase:
function tracing if required. function tracing if required.
Arguments are passed to the worker class constructor. Arguments are passed to the worker class constructor.
""" """
if int(os.getenv("VLLM_TRACE_FUNCTION", "0")): enable_trace_function_call_for_thread()
tmp_dir = tempfile.gettempdir()
filename = (f"VLLM_TRACE_FUNCTION_for_process_{os.getpid()}"
f"_thread_{threading.get_ident()}_"
f"at_{datetime.datetime.now()}.log").replace(" ", "_")
log_path = os.path.join(tmp_dir, "vllm", get_vllm_instance_id(),
filename)
os.makedirs(os.path.dirname(log_path), exist_ok=True)
enable_trace_function_call(log_path)
mod = importlib.import_module(self.worker_module_name) mod = importlib.import_module(self.worker_module_name)
worker_class = getattr(mod, self.worker_class_name) worker_class = getattr(mod, self.worker_class_name)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment