Commit 1591c68f authored by zhuwenwen's avatar zhuwenwen
Browse files

merge v0.4.2

parents 09bcf00b c7f2cf2b
......@@ -13,7 +13,7 @@ from vllm.distributed import (broadcast_tensor_dict,
init_distributed_environment)
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.worker.cpu_model_runner import CPUModelRunner
from vllm.worker.worker_base import LoraNotSupportedWorkerBase
......@@ -256,22 +256,24 @@ class CPUWorker(LoraNotSupportedWorkerBase):
@torch.inference_mode()
def execute_model(
self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None,
blocks_to_swap_in: Optional[Dict[int, int]] = None,
blocks_to_swap_out: Optional[Dict[int, int]] = None,
blocks_to_copy: Optional[Dict[int, List[int]]] = None,
execute_model_req: Optional[ExecuteModelRequest] = None,
) -> List[SamplerOutput]:
if execute_model_req is None:
seq_group_metadata_list = None
else:
seq_group_metadata_list = execute_model_req.seq_group_metadata_list
if self.is_driver_worker:
assert seq_group_metadata_list is not None
num_seq_groups: int = len(seq_group_metadata_list)
assert blocks_to_swap_in is not None
assert blocks_to_swap_out is not None
assert blocks_to_copy is not None
assert len(blocks_to_swap_in) == 0
assert len(blocks_to_swap_out) == 0
assert execute_model_req is not None
blocks_to_copy = execute_model_req.blocks_to_copy
assert len(execute_model_req.blocks_to_swap_in) == 0
assert len(execute_model_req.blocks_to_swap_out) == 0
data: Dict[str, Any] = {
"num_seq_groups": num_seq_groups,
"blocks_to_copy": blocks_to_copy,
"blocks_to_copy": execute_model_req.blocks_to_copy,
}
broadcast_tensor_dict(data, src=0)
else:
......@@ -279,7 +281,6 @@ class CPUWorker(LoraNotSupportedWorkerBase):
num_seq_groups = data["num_seq_groups"]
blocks_to_copy = data["blocks_to_copy"]
assert blocks_to_copy is not None
self.cache_copy(blocks_to_copy)
# If there is no input, we don't need to execute the model.
......
......@@ -9,6 +9,7 @@ import torch.nn as nn
from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage,
get_attn_backend)
from vllm.attention.backends.flashinfer import FlashInferBackend
from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig, VisionLanguageConfig)
from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce
......@@ -20,12 +21,11 @@ from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader import get_model
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sampling_params import SamplingParams
from vllm.sequence import (MultiModalData, SamplerOutput, SequenceData,
SequenceGroupMetadata)
from vllm.utils import (CudaMemoryProfiler, async_tensor_h2d, is_hip,
is_pin_memory_available, make_tensor_with_pad,
maybe_expand_dim)
from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip,
is_pin_memory_available, make_tensor_with_pad)
logger = init_logger(__name__)
......@@ -43,8 +43,8 @@ class PreparePromptMetadata(NamedTuple):
input_tokens: List[int]
input_positions: List[int]
attn_metadata: Optional[AttentionMetadataPerStage]
prompt_lens: List[int]
subquery_lens: List[int]
seq_lens: List[int]
query_lens: List[int]
lora_index_mapping: List[int]
lora_prompt_mapping: List[int]
lora_requests: Set[LoRARequest]
......@@ -57,8 +57,8 @@ class PreparePromptMetadata(NamedTuple):
input_tokens=[],
input_positions=[],
attn_metadata=None,
prompt_lens=[],
subquery_lens=[],
seq_lens=[],
query_lens=[],
lora_index_mapping=[],
lora_prompt_mapping=[],
lora_requests=set(),
......@@ -135,9 +135,8 @@ class ModelRunner:
self.graph_memory_pool: Optional[Tuple[
int, int]] = None # Set during graph capture.
self.max_context_len_to_capture = (
self.model_config.max_context_len_to_capture
if self.model_config is not None else 0)
self.max_seq_len_to_capture = (self.model_config.max_seq_len_to_capture
if self.model_config is not None else 0)
self.pin_memory = is_pin_memory_available()
self.kv_cache_dtype = kv_cache_dtype
......@@ -150,13 +149,16 @@ class ModelRunner:
self.model: torch.nn.Module # Set after load_model
self.block_size: int # Set after initial profiling.
# When using CUDA graph, the input block tables must be padded to
# max_context_len_to_capture. However, creating the block table in
# max_seq_len_to_capture. However, creating the block table in
# Python can be expensive. To optimize this, we cache the block table
# in numpy and only copy the actual input content at every iteration.
# The shape of the cached block table will be
# (max batch size to capture, max context len to capture / block size).
self.graph_block_tables: torch.Tensor # Set after initial profiling.
# Set if the backend is flashinfer.
self.flashinfer_workspace_buffer: torch.Tensor
def load_model(self) -> None:
with CudaMemoryProfiler() as m:
self.model = get_model(
......@@ -170,8 +172,8 @@ class ModelRunner:
)
self.model_memory_usage = m.consumed_memory
logger.info(f"Loading model weights took "
f"{self.model_memory_usage / float(2**30):.4f} GB")
logger.info("Loading model weights took %.4f GB",
self.model_memory_usage / float(2**30))
if self.lora_config:
assert hasattr(self.model, "supported_lora_modules"
......@@ -196,18 +198,19 @@ class ModelRunner:
self.model.load_kv_cache_scales(
self.model_config.quantization_param_path)
else:
raise RuntimeError("Using FP8 KV cache and scaling "
"factors provided but model "
f"{self.model.__class__} does not "
"support loading scaling factors.")
raise RuntimeError(
"Using FP8 KV cache and scaling factors provided but "
"model %s does not support loading scaling factors.",
self.model.__class__)
else:
logger.warn("Using FP8 KV cache but no scaling factors "
"provided. Defaulting to scaling factors of 1.0. "
"This may lead to less accurate results!")
logger.warning(
"Using FP8 KV cache but no scaling factors "
"provided. Defaulting to scaling factors of 1.0. "
"This may lead to less accurate results!")
elif self.model_config.quantization_param_path is not None:
logger.warn("KV cache scaling factors provided, "
"but the KV cache data type is not FP8. "
"KV cache scaling factors will not be used.")
logger.warning("KV cache scaling factors provided, "
"but the KV cache data type is not FP8. "
"KV cache scaling factors will not be used.")
def set_block_size(self, block_size: int) -> None:
self.block_size = block_size
......@@ -218,7 +221,7 @@ class ModelRunner:
def get_max_block_per_batch(self) -> int:
block_size = self.block_size
return (self.max_context_len_to_capture + block_size - 1) // block_size
return (self.max_seq_len_to_capture + block_size - 1) // block_size
def _prepare_prompt(
self,
......@@ -231,9 +234,9 @@ class ModelRunner:
lora_prompt_mapping: List[int] = []
lora_requests: Set[LoRARequest] = set()
prompt_lens: List[int] = []
seq_lens: List[int] = []
context_lens: List[int] = []
subquery_lens: List[int] = []
query_lens: List[int] = []
prefix_block_tables: List[List[int]] = []
multi_modal_input_list: List[torch.Tensor] = []
......@@ -257,21 +260,19 @@ class ModelRunner:
token_chunk_size = seq_group_metadata.token_chunk_size
seq_data = seq_group_metadata.seq_data[seq_id]
computed_len = seq_data.get_num_computed_tokens()
context_len = seq_data.get_num_computed_tokens()
# We should use get_len here because in case of preemption
# it contains output tokens.
prefill_end = min(seq_data.get_len(),
computed_len + token_chunk_size)
prompt_tokens = seq_data.get_token_ids()[computed_len:prefill_end]
prompt_len = prefill_end
prompt_lens.append(prompt_len)
seq_len = min(seq_data.get_len(), context_len + token_chunk_size)
prompt_tokens = seq_data.get_token_ids()[context_len:seq_len]
seq_lens.append(seq_len)
# NOTE: This only works for oooooooxxx style attention.
if computed_block_nums is not None and len(
computed_block_nums) > 0 and self.sliding_window is None:
# Prefix is not supported with sliding_window
computed_len = len(computed_block_nums) * self.block_size
prompt_tokens = prompt_tokens[computed_len:]
context_len = len(computed_block_nums) * self.block_size
prompt_tokens = prompt_tokens[context_len:]
prefix_block_tables.append(computed_block_nums)
elif self.scheduler_config.chunked_prefill_enabled:
if seq_group_metadata.block_tables is not None:
......@@ -285,25 +286,25 @@ class ModelRunner:
prefix_block_tables.append([])
# Right now, prefill start is always 0. However, this
# assumption can be changed once chunked prefill is introduced.
assert computed_len == 0
assert context_len == 0
# actual prompt lens
context_lens.append(computed_len)
subquery_lens.append(prompt_len - computed_len)
context_lens.append(context_len)
query_lens.append(seq_len - context_len)
input_tokens.extend(prompt_tokens)
# NOTE(woosuk): Here we assume that the first token in the prompt
# is always the first token in the sequence.
input_positions.extend(list(range(computed_len, prefill_end)))
input_positions.extend(list(range(context_len, seq_len)))
lora_id = seq_group_metadata.lora_int_id
if lora_id > 0:
lora_requests.add(seq_group_metadata.lora_request)
lora_index_mapping += [lora_id] * (prompt_len - computed_len)
lora_index_mapping += [lora_id] * (seq_len - context_len)
lora_prompt_mapping.extend(
[lora_id] *
(prompt_len - computed_len
(seq_len - context_len
if seq_group_metadata.sampling_params.prompt_logprobs else 1))
if seq_group_metadata.multi_modal_data:
......@@ -313,24 +314,25 @@ class ModelRunner:
if seq_group_metadata.block_tables is None:
# During memory profiling, the block tables are not initialized
# yet. In this case, we just use a dummy slot mapping.
slot_mapping.extend([_PAD_SLOT_ID] * prompt_len)
slot_mapping.extend([_PAD_SLOT_ID] * seq_len)
continue
# Compute the slot mapping.
block_table = seq_group_metadata.block_tables[seq_id]
# Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
# where start_idx is max(0, prompt_len - sliding_window).
# where start_idx is max(0, seq_len - sliding_window).
# For example, if the prompt len is 10, sliding window is 8, and
# block size is 4, the first two tokens are masked and the slot
# mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
start_idx = 0
if self.sliding_window is not None:
assert computed_len == 0, (
assert context_len == 0, (
"Prefix caching is currently not supported with "
"sliding window attention")
start_idx = max(0, prompt_len - self.sliding_window)
start_idx = max(0, seq_len - self.sliding_window)
for i in range(computed_len, prefill_end):
for i in range(context_len, seq_len):
if i < start_idx:
slot_mapping.append(_PAD_SLOT_ID)
continue
......@@ -340,9 +342,9 @@ class ModelRunner:
slot = block_number * self.block_size + block_offset
slot_mapping.append(slot)
max_subquery_len = max(subquery_lens)
max_prompt_len = max(prompt_lens)
assert max_subquery_len > 0
max_query_len = max(query_lens)
max_seq_len = max(seq_lens)
assert max_query_len > 0
context_lens_tensor = torch.tensor(context_lens,
dtype=torch.int,
......@@ -369,50 +371,57 @@ class ModelRunner:
# Query length can be shorter than key (i.e., prompt) when prefill
# is chunked or prefix cached.
subquery_lens_tensor = torch.tensor(subquery_lens,
dtype=torch.long,
device=self.device)
subquery_start_loc = torch.zeros(subquery_lens_tensor.shape[0] + 1,
query_lens_tensor = torch.tensor(query_lens,
dtype=torch.long,
device=self.device)
subquery_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=self.device)
prompt_lens_tensor = torch.tensor(prompt_lens,
dtype=torch.long,
device=self.device)
seq_start_loc = torch.zeros(prompt_lens_tensor.shape[0] + 1,
seq_lens_tensor = torch.tensor(seq_lens,
dtype=torch.int,
device=self.device)
seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=self.device)
torch.cumsum(subquery_lens_tensor,
torch.cumsum(query_lens_tensor,
dim=0,
dtype=subquery_start_loc.dtype,
out=subquery_start_loc[1:])
torch.cumsum(prompt_lens_tensor,
torch.cumsum(seq_lens_tensor,
dim=0,
dtype=seq_start_loc.dtype,
out=seq_start_loc[1:])
attn_metadata = self.attn_backend.make_metadata(
is_prompt=True,
prompt_lens=prompt_lens,
prompt_lens_tensor=prompt_lens_tensor,
max_subquery_len=max_subquery_len,
max_context_len=None,
max_prompt_len=max_prompt_len,
subquery_start_loc=subquery_start_loc,
seq_start_loc=seq_start_loc,
context_lens=context_lens_tensor,
block_tables=block_tables,
use_cuda_graph=False,
)
if self.attn_backend is FlashInferBackend:
attn_metadata = self.attn_backend.make_metadata(
is_prompt=True,
use_cuda_graph=False,
seq_start_loc=seq_start_loc,
max_seq_len=max_seq_len,
block_tables=block_tables)
else:
attn_metadata = self.attn_backend.make_metadata(
is_prompt=True,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_query_len=max_query_len,
max_seq_len=max_seq_len,
subquery_start_loc=subquery_start_loc,
seq_start_loc=seq_start_loc,
context_lens_tensor=context_lens_tensor,
block_tables=block_tables,
use_cuda_graph=False,
)
return PreparePromptMetadata(
input_tokens=input_tokens,
input_positions=input_positions,
attn_metadata=attn_metadata,
prompt_lens=prompt_lens,
subquery_lens=subquery_lens,
seq_lens=seq_lens,
query_lens=query_lens,
lora_index_mapping=lora_index_mapping,
lora_prompt_mapping=lora_prompt_mapping,
lora_requests=lora_requests,
......@@ -427,12 +436,30 @@ class ModelRunner:
input_tokens: List[int] = []
input_positions: List[int] = []
slot_mapping: List[int] = []
context_lens: List[int] = []
seq_lens: List[int] = []
block_tables: List[List[int]] = []
lora_index_mapping: List[int] = []
lora_prompt_mapping: List[int] = []
lora_requests: Set[LoRARequest] = set()
# The following fields are only for flashinfer
# Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout
# for the precise definition of the following fields.
# An example:
# request 1, page indices [0, 5, 8]
# request 2, page indices [1, 6, 7]
# request 3, page indices [3, 4]
# paged_kv_indices is a concatenation of page indices of all requests:
# [0, 5, 8, 1, 6, 7, 3, 4]
# paged_kv_indptr is used to index into paged_kv_indices:
# [0, 3, 6, 8]
paged_kv_indices: List[int] = []
# 0 at the beginning of paged_kv_indptr indicates the start of the
# first request’s page indices in the paged_kv_indices list.
paged_kv_indptr: List[int] = [0]
# paged_kv_last_page_len is the length of the last page of each request
paged_kv_last_page_len: List[int] = []
if len(seq_group_metadata_list) == 0:
return PrepareDecodeMetadata.empty()
......@@ -455,9 +482,9 @@ class ModelRunner:
position = seq_len - 1
input_positions.append(position)
context_len = seq_len if self.sliding_window is None else min(
seq_len = seq_len if self.sliding_window is None else min(
seq_len, self.sliding_window)
context_lens.append(context_len)
seq_lens.append(seq_len)
block_table = seq_group_metadata.block_tables[seq_id]
block_number = block_table[position // self.block_size]
......@@ -473,15 +500,21 @@ class ModelRunner:
block_table = block_table[-sliding_window_blocks:]
block_tables.append(block_table)
paged_kv_indices.extend(block_table)
paged_kv_indptr.append(paged_kv_indptr[-1] + len(block_table))
last_page_len = seq_data.get_len() % self.block_size
if last_page_len == 0:
last_page_len = self.block_size
paged_kv_last_page_len.append(last_page_len)
# vLLM uses cuda graph only for decoding requests.
# See `capture_model` API for more details.
# For decoding requests, batch_size == input_tokens.
batch_size = len(input_tokens)
max_context_len = max(context_lens)
use_captured_graph = (
not self.model_config.enforce_eager
and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
and max_context_len <= self.max_context_len_to_capture)
max_seq_len = max(seq_lens)
use_captured_graph = (not self.model_config.enforce_eager
and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
and max_seq_len <= self.max_seq_len_to_capture)
if use_captured_graph:
graph_batch_size = _get_graph_batch_size(batch_size)
assert graph_batch_size >= batch_size
......@@ -489,21 +522,21 @@ class ModelRunner:
input_tokens.append(0)
input_positions.append(0)
slot_mapping.append(_PAD_SLOT_ID)
context_lens.append(1)
seq_lens.append(1)
block_tables.append([])
lora_index_mapping.append(0)
batch_size = graph_batch_size
context_lens_tensor = torch.tensor(context_lens,
dtype=torch.int,
device=self.device)
seq_lens_tensor = torch.tensor(seq_lens,
dtype=torch.int,
device=self.device)
if use_captured_graph:
# When using cuda-graph all these tensors should be
# padded.
assert context_lens_tensor.shape[0] == len(input_tokens)
assert context_lens_tensor.shape[0] == len(input_positions)
assert context_lens_tensor.shape[0] == len(slot_mapping)
assert seq_lens_tensor.shape[0] == len(input_tokens)
assert seq_lens_tensor.shape[0] == len(input_positions)
assert seq_lens_tensor.shape[0] == len(slot_mapping)
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
......@@ -523,19 +556,51 @@ class ModelRunner:
device=self.device,
)
attn_metadata = self.attn_backend.make_metadata(
is_prompt=False,
prompt_lens=None,
prompt_lens_tensor=None,
max_subquery_len=None,
max_context_len=max_context_len,
max_prompt_len=None,
subquery_start_loc=None,
seq_start_loc=None,
context_lens=context_lens_tensor,
block_tables=block_tables,
use_cuda_graph=use_captured_graph,
)
if self.attn_backend is FlashInferBackend:
if not hasattr(self, "flashinfer_workspace_buffer"):
# Allocate 16MB workspace buffer
# Follow the example of flashinfer: https://docs.flashinfer.ai/api/python/decode.html
self.flashinfer_workspace_buffer = torch.empty(
16 * 1024 * 1024, dtype=torch.uint8, device=self.device)
paged_kv_indptr = torch.tensor(paged_kv_indptr,
dtype=torch.int,
device=self.device)
paged_kv_indices = torch.tensor(paged_kv_indices,
dtype=torch.int,
device=self.device)
paged_kv_last_page_len = torch.tensor(paged_kv_last_page_len,
dtype=torch.int,
device=self.device)
kv_cache_dtype = get_kv_cache_torch_dtype(self.kv_cache_dtype,
self.model_config.dtype)
attn_metadata = self.attn_backend.make_metadata(
is_prompt=False,
use_cuda_graph=False,
workspace_buffer=self.flashinfer_workspace_buffer,
paged_kv_indptr=paged_kv_indptr,
paged_kv_indices=paged_kv_indices,
paged_kv_last_page_len=paged_kv_last_page_len,
num_qo_heads=self.model_config.get_num_attention_heads(
self.parallel_config),
num_kv_heads=self.model_config.get_num_kv_heads(
self.parallel_config),
head_dim=self.model_config.get_head_size(),
page_size=self.block_size,
data_type=kv_cache_dtype)
else:
attn_metadata = self.attn_backend.make_metadata(
is_prompt=False,
seq_lens=None,
seq_lens_tensor=seq_lens_tensor,
max_query_len=None,
max_seq_len=max_seq_len,
subquery_start_loc=None,
seq_start_loc=None,
context_lens_tensor=None,
block_tables=block_tables,
use_cuda_graph=use_captured_graph,
)
return PrepareDecodeMetadata(
input_tokens=input_tokens,
input_positions=input_positions,
......@@ -546,108 +611,6 @@ class ModelRunner:
slot_mapping=slot_mapping,
)
def _prepare_sample(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
prompt_lens: List[int],
subquery_lens: Optional[List[int]],
) -> SamplingMetadata:
seq_groups: List[Tuple[List[int], SamplingParams]] = []
selected_token_indices: List[int] = []
generators: List[torch.Generator] = []
selected_token_start_idx = 0
categorized_sample_indices: Dict[SamplingType,
List[Tuple[int, int]]] = {
t: []
for t in SamplingType
}
categorized_sample_indices_start_idx = 0
categorized_sampled_token_indices_start_idx = 0
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
seq_ids = list(seq_group_metadata.seq_data.keys())
sampling_params = seq_group_metadata.sampling_params
seq_groups.append((seq_ids, sampling_params))
if seq_group_metadata.is_prompt:
assert len(seq_ids) == 1
assert subquery_lens is not None
subquery_len = subquery_lens[i]
if sampling_params.prompt_logprobs is not None:
# NOTE: prompt token positions do not need sample, skip
categorized_sample_indices_start_idx += subquery_len - 1
categorized_sample_indices[
sampling_params.sampling_type].append(
(categorized_sample_indices_start_idx,
categorized_sampled_token_indices_start_idx))
categorized_sample_indices_start_idx += 1
categorized_sampled_token_indices_start_idx += 1
if sampling_params.prompt_logprobs is not None:
selected_token_indices.extend(
range(selected_token_start_idx,
selected_token_start_idx + subquery_len - 1))
selected_token_indices.append(selected_token_start_idx +
subquery_len - 1)
selected_token_start_idx += subquery_len
if sampling_params.seed is not None:
seq_group_metadata.state.generator = torch.Generator(
device=self.device).manual_seed(sampling_params.seed)
else:
num_seqs = len(seq_ids)
selected_token_indices.extend(
range(selected_token_start_idx,
selected_token_start_idx + num_seqs))
selected_token_start_idx += num_seqs
categorized_sample_indices[
sampling_params.sampling_type].extend(
list(
zip(
range(
categorized_sample_indices_start_idx,
categorized_sample_indices_start_idx +
num_seqs),
range(
categorized_sampled_token_indices_start_idx,
categorized_sampled_token_indices_start_idx
+ num_seqs))))
categorized_sample_indices_start_idx += num_seqs
categorized_sampled_token_indices_start_idx += num_seqs
if sampling_params.seed is not None:
generators.append(seq_group_metadata.state.generator)
selected_token_indices = async_tensor_h2d(selected_token_indices,
dtype=torch.long,
target_device=self.device,
pin_memory=self.pin_memory)
categorized_sample_indices = {
t: maybe_expand_dim(
async_tensor_h2d(seq_ids,
dtype=torch.int,
target_device=self.device,
pin_memory=self.pin_memory), 2, 2)
for t, seq_ids in categorized_sample_indices.items()
}
seq_data: Dict[int, SequenceData] = {}
for seq_group_metadata in seq_group_metadata_list:
seq_data.update(seq_group_metadata.seq_data)
sampling_metadata = SamplingMetadata(
seq_groups=seq_groups,
seq_data=seq_data,
prompt_lens=prompt_lens,
selected_token_indices=selected_token_indices,
categorized_sample_indices=categorized_sample_indices,
generators=generators,
)
return sampling_metadata
def prepare_input_tensors(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
......@@ -667,8 +630,8 @@ class ModelRunner:
input_tokens,
input_positions,
prefill_attn_metadata,
prompt_lens,
subquery_lens,
seq_lens,
query_lens,
lora_index_mapping,
lora_prompt_mapping,
lora_requests,
......@@ -684,14 +647,14 @@ class ModelRunner:
decode_lora_requests,
decode_slot_mapping,
) = self._prepare_decode(decode_reqs)
sampling_metadata = self._prepare_sample(seq_group_metadata_list,
prompt_lens,
subquery_lens)
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list, seq_lens, query_lens, self.device,
self.pin_memory)
if not self.scheduler_config.chunked_prefill_enabled:
assert (len(prefill_reqs) and len(decode_reqs)) == 0
num_prefills = len(prompt_lens)
num_prefills = len(seq_lens)
num_prefill_tokens = len(input_tokens)
num_decode_tokens = len(decode_input_tokens)
......@@ -787,12 +750,9 @@ class ModelRunner:
**metadata_dict)
sampling_metadata = SamplingMetadata(
seq_groups=None,
seq_data=None,
prompt_lens=None,
selected_token_indices=selected_token_indices,
categorized_sample_indices=None,
generators=None,
perform_sampling=False,
num_prompts=0,
)
# if it is a mixed batch, decode attn_metadata is broadcasted
......@@ -851,7 +811,7 @@ class ModelRunner:
logits = self.model.compute_logits(hidden_states, sampling_metadata)
# Only perform sampling in the driver worker.
if not sampling_metadata.perform_sampling:
if not self.is_driver_worker:
return None
# Sample the next token.
......@@ -859,6 +819,7 @@ class ModelRunner:
logits=logits,
sampling_metadata=sampling_metadata,
)
return output
@torch.inference_mode()
......@@ -928,10 +889,10 @@ class ModelRunner:
torch.cuda.synchronize()
return
def remove_all_loras(self) -> bool:
def remove_all_loras(self):
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
return self.lora_manager.remove_all_loras()
self.lora_manager.remove_all_loras()
def set_active_loras(self, lora_requests: Set[LoRARequest],
lora_mapping: LoRAMapping) -> None:
......@@ -990,7 +951,7 @@ class ModelRunner:
input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda()
slot_mapping = torch.empty(max_batch_size, dtype=torch.long).cuda()
slot_mapping.fill_(_PAD_SLOT_ID)
context_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
seq_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
block_tables = torch.from_numpy(self.graph_block_tables).cuda()
graph_batch_size = _get_graph_batch_size(
......@@ -1012,14 +973,13 @@ class ModelRunner:
# Create dummy attn_metadata.
decode_metadata = self.attn_backend.make_metadata(
is_prompt=False,
prompt_lens=None,
prompt_lens_tensor=None,
max_subquery_len=None,
max_context_len=self.max_context_len_to_capture,
max_prompt_len=None,
seq_lens=None,
seq_lens_tensor=seq_lens[:batch_size],
max_query_len=None,
max_seq_len=self.max_seq_len_to_capture,
subquery_start_loc=None,
seq_start_loc=None,
context_lens=context_lens[:batch_size],
context_lens_tensor=None,
block_tables=block_tables[:batch_size],
use_cuda_graph=True,
)
......@@ -1054,7 +1014,7 @@ class ModelRunner:
end_time = time.perf_counter()
elapsed_time = end_time - start_time
# This usually takes < 10 seconds.
logger.info(f"Graph capturing finished in {elapsed_time:.0f} secs.")
logger.info("Graph capturing finished in %.0f secs.", elapsed_time)
def __del__(self) -> None:
# Delete the CUDA graphs before deleting the pynccl communicator.
......@@ -1129,7 +1089,7 @@ class CUDAGraphRunner:
"positions": positions,
"kv_caches": kv_caches,
"slot_mapping": attn_metadata.slot_mapping,
"context_lens": attn_metadata.decode_metadata.context_lens,
"seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor,
"block_tables": attn_metadata.decode_metadata.block_tables,
}
self.output_buffers = {"hidden_states": hidden_states}
......@@ -1151,8 +1111,8 @@ class CUDAGraphRunner:
self.input_buffers["positions"].copy_(positions, non_blocking=True)
self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping,
non_blocking=True)
self.input_buffers["context_lens"].copy_(
attn_metadata.decode_metadata.context_lens, non_blocking=True)
self.input_buffers["seq_lens_tensor"].copy_(
attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
self.input_buffers["block_tables"].copy_(
attn_metadata.decode_metadata.block_tables, non_blocking=True)
# Run the graph.
......
from typing import Dict, List, Optional, Tuple
from typing import List, Optional, Tuple
import torch
from torch import nn
......@@ -8,10 +8,8 @@ from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader.neuron import get_neuron_model
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
from vllm.utils import (async_tensor_h2d, is_pin_memory_available,
make_tensor_with_pad, maybe_expand_dim)
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
logger = init_logger(__name__)
......@@ -54,7 +52,7 @@ class NeuronModelRunner:
input_positions: List[List[int]] = []
input_block_ids: List[int] = []
prompt_lens: List[int] = []
seq_lens: List[int] = []
for seq_group_metadata in seq_group_metadata_list:
assert seq_group_metadata.is_prompt
seq_ids = list(seq_group_metadata.seq_data.keys())
......@@ -63,26 +61,26 @@ class NeuronModelRunner:
seq_data = seq_group_metadata.seq_data[seq_id]
prompt_tokens = seq_data.get_token_ids()
prompt_len = len(prompt_tokens)
prompt_lens.append(prompt_len)
seq_len = len(prompt_tokens)
seq_lens.append(seq_len)
input_tokens.append(prompt_tokens)
input_positions.append(list(range(prompt_len)))
input_positions.append(list(range(seq_len)))
assert seq_group_metadata.block_tables is not None
block_table = seq_group_metadata.block_tables[seq_id]
assert len(block_table) == 1
input_block_ids.append(block_table[0])
max_prompt_len = max(prompt_lens)
assert max_prompt_len > 0
max_seq_len = max(seq_lens)
assert max_seq_len > 0
input_tokens = make_tensor_with_pad(input_tokens,
max_prompt_len,
max_seq_len,
pad=0,
dtype=torch.long,
device=self.device)
input_positions = make_tensor_with_pad(input_positions,
max_prompt_len,
max_seq_len,
pad=0,
dtype=torch.long,
device=self.device)
......@@ -90,7 +88,7 @@ class NeuronModelRunner:
dtype=torch.long,
device=self.device)
return input_tokens, input_positions, input_block_ids, prompt_lens
return input_tokens, input_positions, input_block_ids, seq_lens
def _prepare_decode(
self,
......@@ -141,106 +139,6 @@ class NeuronModelRunner:
return input_tokens, input_positions, input_block_ids
def _prepare_sample(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
prompt_lens: List[int],
) -> SamplingMetadata:
seq_groups: List[Tuple[List[int], SamplingParams]] = []
selected_token_indices: List[int] = []
generators: List[torch.Generator] = []
selected_token_start_idx = 0
categorized_sample_indices: Dict[SamplingType,
List[Tuple[int, int]]] = {
t: []
for t in SamplingType
}
categorized_sample_indices_start_idx = 0
categorized_sampled_token_indices_start_idx = 0
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
seq_ids = list(seq_group_metadata.seq_data.keys())
sampling_params = seq_group_metadata.sampling_params
seq_groups.append((seq_ids, sampling_params))
if seq_group_metadata.is_prompt:
assert len(seq_ids) == 1
assert prompt_lens is not None
prompt_len = prompt_lens[i]
if sampling_params.prompt_logprobs is not None:
# NOTE: prompt token positions do not need sample, skip
categorized_sample_indices_start_idx += prompt_len - 1
categorized_sample_indices[
sampling_params.sampling_type].append(
(categorized_sample_indices_start_idx,
categorized_sampled_token_indices_start_idx))
categorized_sample_indices_start_idx += 1
categorized_sampled_token_indices_start_idx += 1
if sampling_params.prompt_logprobs is not None:
selected_token_indices.extend(
range(selected_token_start_idx,
selected_token_start_idx + prompt_len - 1))
selected_token_indices.append(selected_token_start_idx +
prompt_len - 1)
selected_token_start_idx += prompt_len
if sampling_params.seed is not None:
seq_group_metadata.state.generator = torch.Generator(
device=self.device).manual_seed(sampling_params.seed)
else:
num_seqs = len(seq_ids)
selected_token_indices.extend(
range(selected_token_start_idx,
selected_token_start_idx + num_seqs))
selected_token_start_idx += num_seqs
categorized_sample_indices[
sampling_params.sampling_type].extend(
zip(
range(
categorized_sample_indices_start_idx,
categorized_sample_indices_start_idx +
num_seqs),
range(
categorized_sampled_token_indices_start_idx,
categorized_sampled_token_indices_start_idx +
num_seqs)))
categorized_sample_indices_start_idx += num_seqs
categorized_sampled_token_indices_start_idx += num_seqs
if sampling_params.seed is not None:
generators.append(seq_group_metadata.state.generator)
selected_token_indices = async_tensor_h2d(selected_token_indices,
dtype=torch.long,
target_device=self.device,
pin_memory=self.pin_memory)
categorized_sample_indices = {
t: maybe_expand_dim(
async_tensor_h2d(seq_ids,
dtype=torch.int,
target_device=self.device,
pin_memory=self.pin_memory), 2, 2)
for t, seq_ids in categorized_sample_indices.items()
}
seq_data: Dict[int, SequenceData] = {}
for seq_group_metadata in seq_group_metadata_list:
seq_data.update(seq_group_metadata.seq_data)
sampling_metadata = SamplingMetadata(
seq_groups=seq_groups,
seq_data=seq_data,
prompt_lens=prompt_lens,
selected_token_indices=selected_token_indices,
categorized_sample_indices=categorized_sample_indices,
generators=generators,
)
return sampling_metadata
def prepare_input_tensors(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
......@@ -251,13 +149,20 @@ class NeuronModelRunner:
# Prepare input tensors.
if is_prompt:
(input_tokens, input_positions, input_block_ids,
prompt_lens) = self._prepare_prompt(seq_group_metadata_list)
seq_lens) = self._prepare_prompt(seq_group_metadata_list)
else:
(input_tokens, input_positions,
input_block_ids) = self._prepare_decode(seq_group_metadata_list)
prompt_lens = []
sampling_metadata = self._prepare_sample(seq_group_metadata_list,
prompt_lens)
seq_lens = []
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
seq_lens,
# query_lens is not needed if chunked prefill is not
# supported. Since neuron worker doesn't support chunked prefill
# just use seq_lens instead.
seq_lens,
self.device,
self.pin_memory)
return (input_tokens, input_positions, input_block_ids,
sampling_metadata)
......
......@@ -11,13 +11,14 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
VisionLanguageConfig)
from vllm.distributed import (broadcast_tensor_dict,
ensure_model_parallel_initialized,
get_tensor_model_parallel_cpu_group,
init_distributed_environment)
from vllm.distributed.device_communicators import pynccl_utils
from vllm.distributed.device_communicators.custom_all_reduce import (
init_custom_ar)
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.model_runner import ModelRunner
from vllm.worker.worker_base import WorkerBase
......@@ -210,19 +211,21 @@ class Worker(WorkerBase):
@torch.inference_mode()
def execute_model(
self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None,
blocks_to_swap_in: Optional[Dict[int, int]] = None,
blocks_to_swap_out: Optional[Dict[int, int]] = None,
blocks_to_copy: Optional[Dict[int, List[int]]] = None,
num_lookahead_slots: int = 0,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> List[SamplerOutput]:
if execute_model_req is None:
seq_group_metadata_list = None
else:
seq_group_metadata_list = execute_model_req.seq_group_metadata_list
if self.is_driver_worker:
assert seq_group_metadata_list is not None
assert execute_model_req is not None
num_seq_groups = len(seq_group_metadata_list)
assert blocks_to_swap_in is not None
assert blocks_to_swap_out is not None
assert blocks_to_copy is not None
blocks_to_swap_in = execute_model_req.blocks_to_swap_in
blocks_to_swap_out = execute_model_req.blocks_to_swap_out
blocks_to_copy = execute_model_req.blocks_to_copy
data: Dict[str, Any] = {
"num_seq_groups": num_seq_groups,
"blocks_to_swap_in": blocks_to_swap_in,
......@@ -237,9 +240,6 @@ class Worker(WorkerBase):
blocks_to_swap_out = data["blocks_to_swap_out"]
blocks_to_copy = data["blocks_to_copy"]
assert blocks_to_swap_in is not None
assert blocks_to_swap_out is not None
assert blocks_to_copy is not None
self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)
# If there is no input, we don't need to execute the model.
......@@ -288,6 +288,9 @@ def init_worker_distributed_environment(
init_distributed_environment(parallel_config.world_size, rank,
distributed_init_method, local_rank)
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)
if pynccl_utils.is_initialized():
pynccl_world_size = pynccl_utils.get_world_size()
if pynccl_world_size != parallel_config.world_size:
......@@ -298,12 +301,9 @@ def init_worker_distributed_environment(
elif parallel_config.world_size > 1:
# NOTE(woosuk): We don't initialize pynccl process group when world size
# is 1.
# NOTE(kaichao): By default, pynccl will use information inside
# `parallel_state` for initialization.
pynccl_utils.init_process_group()
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)
# NOTE(kaichao): By default, pynccl is initialized for tp group.
pynccl_utils.init_process_group(
group=get_tensor_model_parallel_cpu_group())
# Initialize a custom fast all-reduce implementation.
if not parallel_config.disable_custom_all_reduce:
......
import datetime
import importlib
import os
import tempfile
import threading
from abc import ABC, abstractmethod
from typing import Dict, List, Set, Tuple
from vllm.logger import enable_trace_function_call, init_logger
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import get_vllm_instance_id, update_environment_variables
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.utils import (enable_trace_function_call_for_thread,
update_environment_variables)
logger = init_logger(__name__)
......@@ -50,10 +48,8 @@ class WorkerBase(ABC):
@abstractmethod
def execute_model(
self, seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int], blocks_to_swap_out: Dict[int,
int],
blocks_to_copy: Dict[int, List[int]]) -> List[SamplerOutput]:
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
"""Executes at least one model step on the given sequences, unless no
sequences are provided."""
raise NotImplementedError
......@@ -128,15 +124,7 @@ class WorkerWrapperBase:
function tracing if required.
Arguments are passed to the worker class constructor.
"""
if int(os.getenv("VLLM_TRACE_FUNCTION", "0")):
tmp_dir = tempfile.gettempdir()
filename = (f"VLLM_TRACE_FUNCTION_for_process_{os.getpid()}"
f"_thread_{threading.get_ident()}_"
f"at_{datetime.datetime.now()}.log").replace(" ", "_")
log_path = os.path.join(tmp_dir, "vllm", get_vllm_instance_id(),
filename)
os.makedirs(os.path.dirname(log_path), exist_ok=True)
enable_trace_function_call(log_path)
enable_trace_function_call_for_thread()
mod = importlib.import_module(self.worker_module_name)
worker_class = getattr(mod, self.worker_class_name)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment