Commit 2216a4e5 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge remote-tracking branch 'mirror/main'

parents ad385667 51c24c97
import enum
from typing import TYPE_CHECKING, List, Optional, Union
from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import RequestMetrics
if TYPE_CHECKING:
from vllm.inputs import DecoderOnlyInputs
class Request:
def __init__(
self,
request_id: str,
inputs: "DecoderOnlyInputs",
sampling_params: SamplingParams,
eos_token_id: Optional[int],
arrival_time: float,
lora_request: Optional[LoRARequest] = None,
) -> None:
self.request_id = request_id
self.inputs = inputs
self.sampling_params = sampling_params
# Because of LoRA, the eos token id can be different for each request.
self.eos_token_id = eos_token_id
self.metrics = RequestMetrics(arrival_time=arrival_time,
last_token_time=arrival_time,
first_scheduled_time=None,
first_token_time=None,
time_in_queue=None)
self.lora_request = lora_request
self.status = RequestStatus.WAITING
self.stop_reason: Union[int, str, None] = None
assert sampling_params.max_tokens is not None
self.max_tokens = sampling_params.max_tokens
self.prompt = inputs.get("prompt")
self.prompt_token_ids = inputs["prompt_token_ids"]
self.num_prompt_tokens = len(self.prompt_token_ids)
self.output_token_ids: List[int] = []
self.output_text = ""
self.num_computed_tokens = 0
@property
def num_tokens(self) -> int:
return self.num_prompt_tokens + len(self.output_token_ids)
@property
def num_output_tokens(self) -> int:
return len(self.output_token_ids)
def is_finished(self) -> bool:
return RequestStatus.is_finished(self.status)
def get_finished_reason(self) -> Union[str, None]:
return RequestStatus.get_finished_reason(self.status)
class RequestStatus(enum.IntEnum):
"""Status of a sequence."""
WAITING = 0
RUNNING = 1
PREEMPTED = 2
# Note: anything after PREEMPTED (2) will be considered
# as a finished status.
FINISHED_STOPPED = 3
FINISHED_LENGTH_CAPPED = 4
FINISHED_ABORTED = 5
FINISHED_IGNORED = 6
@staticmethod
def is_finished(status: "RequestStatus") -> bool:
return status > RequestStatus.PREEMPTED
@staticmethod
def get_finished_reason(status: "RequestStatus") -> Union[str, None]:
return _FINISHED_REASON_MAP.get(status)
# Mapping of finished statuses to their finish reasons.
# NOTE: The ignored sequences are the sequences whose prompt lengths
# are longer than the model's length cap. Therefore, the stop
# reason should also be "length" as in OpenAI API.
_FINISHED_REASON_MAP = {
RequestStatus.FINISHED_STOPPED: "stop",
RequestStatus.FINISHED_LENGTH_CAPPED: "length",
RequestStatus.FINISHED_ABORTED: "abort",
RequestStatus.FINISHED_IGNORED: "length",
}
from dataclasses import dataclass
from typing import List, Optional
import torch
@dataclass
class SamplingMetadata:
temperature: torch.Tensor
all_greedy: bool
all_random: bool
top_p: torch.Tensor
top_k: torch.Tensor
no_top_p: bool
no_top_k: bool
generators: List[Optional[torch.Generator]]
no_generator: bool
max_num_logprobs: int
"""A layer that samples the next tokens from the model's outputs."""
from typing import List, Optional
import torch
import torch.nn as nn
from vllm.v1.outputs import SamplerOutput
from vllm.v1.sample.metadata import SamplingMetadata
_SAMPLING_EPS = 1e-5
class Sampler(nn.Module):
def forward(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
logits = self.apply_temperature(logits, sampling_metadata.temperature)
logits = self.apply_top_k_top_p(logits, sampling_metadata)
probs = self.get_probs(logits)
sampled = self.sample(probs, sampling_metadata)
# Use int32 to reduce the tensor size.
sampled = sampled.to(torch.int32)
if sampling_metadata.max_num_logprobs > 0:
logprobs = self.get_logprobs(logits)
# FIXME: Mask the sampled token_id, get topk logprobs,
# and concatenate the topk with the sampled token_id.
topk_logprobs, topk_indices = torch.topk(
logprobs, sampling_metadata.max_num_logprobs, dim=-1)
# Use int32 to reduce the tensor size.
topk_indices = topk_indices.to(torch.int32)
else:
topk_logprobs = None
topk_indices = None
sampler_output = SamplerOutput(
sampled_token_ids=sampled,
logprob_token_ids=topk_indices,
logprobs=topk_logprobs,
prompt_logprob_token_ids=None,
prompt_logprobs=None,
)
return sampler_output
def apply_temperature(
self,
logits: torch.Tensor,
temp: torch.Tensor,
) -> torch.Tensor:
# Use float32 to apply temperature scaling.
logits = logits.to(torch.float32)
# Avoid division by zero.
temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
# Use in-place division to avoid creating a new tensor.
logits.div_(temp.unsqueeze(dim=1))
return logits
def apply_top_k_top_p(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
return _apply_top_k_top_p(
logits,
sampling_metadata.no_top_k,
sampling_metadata.top_k,
sampling_metadata.no_top_p,
sampling_metadata.top_p,
)
def get_probs(self, logits: torch.Tensor) -> torch.Tensor:
return torch.softmax(logits, dim=-1, dtype=torch.float32)
def get_logprobs(self, logits: torch.Tensor) -> torch.Tensor:
return torch.log_softmax(logits, dim=-1, dtype=torch.float32)
def greedy_sample(self, probs: torch.Tensor) -> torch.Tensor:
return probs.argmax(dim=-1).view(-1)
def random_sample(
self,
probs: torch.Tensor,
generators: List[Optional[torch.Generator]],
no_generator: bool,
) -> torch.Tensor:
q = torch.empty_like(probs)
# NOTE(woosuk): To batch-process the requests without their own seeds,
# which is the common case, we first assume that every request does
# not have its own seed. Then, we overwrite the values for the requests
# that have their own seeds.
q.exponential_()
if not no_generator:
assert len(generators) == probs.shape[0]
# TODO(woosuk): This can be slow because we handle each request
# one by one. Optimize this.
for i, generator in enumerate(generators):
if generator is not None:
q[i].exponential_(generator=generator)
return probs.div_(q).argmax(dim=-1).view(-1)
def sample(
self,
probs: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
assert not (sampling_metadata.all_greedy
and sampling_metadata.all_random)
if sampling_metadata.all_greedy:
return self.greedy_sample(probs)
if sampling_metadata.all_random:
return self.random_sample(probs, sampling_metadata.generators,
sampling_metadata.no_generator)
greedy_sampled = self.greedy_sample(probs)
random_sampled = self.random_sample(probs,
sampling_metadata.generators,
sampling_metadata.no_generator)
sampled = torch.where(
sampling_metadata.temperature < _SAMPLING_EPS,
greedy_sampled,
random_sampled,
)
return sampled
# TODO(woosuk): Optimize this with a custom kernel.
def _apply_top_k_top_p(
logits: torch.Tensor,
no_top_k: bool,
k: torch.Tensor,
no_top_p: bool,
p: torch.Tensor,
) -> torch.Tensor:
if no_top_k and no_top_p:
return logits
logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
if not no_top_k:
# Apply top-k.
top_k_mask = logits_sort.size(1) - k.to(torch.long)
# Get all the top_k values.
top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
top_k_mask = logits_sort < top_k_mask
logits_sort.masked_fill_(top_k_mask, -float("inf"))
if not no_top_p:
# Apply top-p.
probs_sort = logits_sort.softmax(dim=-1)
probs_sum = probs_sort.cumsum(dim=-1)
top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
# at least one
top_p_mask[:, -1] = False
logits_sort.masked_fill_(top_p_mask, -float("inf"))
# Re-sort the probabilities.
logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort)
return logits
import multiprocessing
from dataclasses import dataclass
from typing import Dict, List, Optional
import msgspec
import zmq
from msgspec import msgpack
from vllm.transformers_utils.detokenizer_utils import (
convert_prompt_ids_to_tokens, detokenize_incrementally)
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils import get_open_port
class DetokenizerInputs(msgspec.Struct):
# [num_reqs]
req_ids: List[str]
# A request's prompt token ids is sent to the detokenizer only when
# the request is first detokenized. Otherwise, an empty list is sent.
prompt_token_ids: List[List[int]]
new_token_ids: List[List[int]]
skip_special_tokens: List[bool]
spaces_between_special_tokens: List[bool]
# [num_free_reqs]
free_req_ids: List[str]
class DetokenizerOutputs(msgspec.Struct):
# [num_reqs]
req_ids: List[str]
detokenized_texts: List[str]
# NOTE(woosuk): The number of the output token ids of each request
# at the time of detokenization. The detokenizer returns this to the engine
# because the request state (including the output token ids) is
# asynchronously updated in the engine, while RequestOutput requires the
# output token ids to be consistent with the detokenized text.
num_output_token_ids: List[int]
class Detokenizer:
def __init__(self, tokenizer_name: str):
# FIXME(woosuk): Currently, the detokenizer is just a hacky prototype.
# For example, it does not terminate properly. We need to improve this.
self.push_port = get_open_port()
self.pull_port = get_open_port()
self.detokenizer = DetokenizerProc(tokenizer_name, self.push_port,
self.pull_port)
self.detokenizer.start()
self.zmq_context = zmq.Context()
self.push_socket = self.zmq_context.socket(zmq.PUSH)
self.push_socket.connect(f"tcp://localhost:{self.push_port}")
self.pull_socket = self.zmq_context.socket(zmq.PULL)
self.pull_socket.connect(f"tcp://localhost:{self.pull_port}")
self.poller = zmq.Poller()
self.poller.register(self.pull_socket, zmq.POLLIN)
self.msgpack_encoder = msgpack.Encoder()
self.msgpack_decoder = msgpack.Decoder(DetokenizerOutputs)
def send(self, inputs: DetokenizerInputs) -> None:
self.push_socket.send(self.msgpack_encoder.encode(inputs),
flags=zmq.NOBLOCK)
def recv(self) -> Optional[DetokenizerOutputs]:
socks = dict(self.poller.poll(timeout=0))
if self.pull_socket in socks and socks[self.pull_socket] == zmq.POLLIN:
msg = self.pull_socket.recv()
return self.msgpack_decoder.decode(msg)
return None
def terminate(self) -> None:
self.push_socket.send(b"", flags=zmq.NOBLOCK)
self.detokenizer.join()
class DetokenizerProc(multiprocessing.Process):
def __init__(
self,
tokenizer_name: str,
pull_port: int,
push_port: int,
):
super().__init__()
self.tokenizer_name = tokenizer_name
# NOTE: The pull_port of the detokenizer should be the same as the
# push_port of the engine. Vice versa.
self.pull_port = pull_port
self.push_port = push_port
def run(self):
# Initialize these objects after the process is forked since they are
# not picklable.
self.msgpack_encoder = msgpack.Encoder()
self.msgpack_decoder = msgpack.Decoder(DetokenizerInputs)
self.tokenizer = get_tokenizer(self.tokenizer_name)
# req_id -> RequestState
self.request_states: Dict[str, RequestState] = {}
self.zmq_context = zmq.Context()
self.pull_socket = self.zmq_context.socket(zmq.PULL)
self.pull_socket.bind(f"tcp://*:{self.pull_port}")
self.push_socket = self.zmq_context.socket(zmq.PUSH)
self.push_socket.bind(f"tcp://*:{self.push_port}")
while True:
message = self.pull_socket.recv()
if message == b"":
# Terminate signal.
break
inputs = self.msgpack_decoder.decode(message)
for req_id in inputs.free_req_ids:
self.free(req_id)
detokenized_texts: List[str] = []
num_output_token_ids: List[int] = []
num_reqs = len(inputs.req_ids)
for i in range(num_reqs):
req_id = inputs.req_ids[i]
if req_id not in self.request_states:
self.add_request(
request_id=req_id,
prompt_token_ids=inputs.prompt_token_ids[i],
skip_special_tokens=inputs.skip_special_tokens[i],
spaces_between_special_tokens=inputs.
spaces_between_special_tokens[i],
)
new_str = self.detokenize(req_id, inputs.new_token_ids[i])
detokenized_texts.append(new_str)
req_state = self.request_states[req_id]
num_output_token_ids.append(
len(req_state.token_ids) - req_state.num_prompt_tokens)
detokenized = DetokenizerOutputs(
req_ids=inputs.req_ids,
detokenized_texts=detokenized_texts,
num_output_token_ids=num_output_token_ids,
)
self.push_socket.send(self.msgpack_encoder.encode(detokenized),
flags=zmq.NOBLOCK)
def add_request(
self,
request_id: str,
prompt_token_ids: List[int],
skip_special_tokens: bool,
spaces_between_special_tokens: bool,
) -> None:
tokens, prefix_offset, read_offset = convert_prompt_ids_to_tokens(
tokenizer=self.tokenizer,
prompt_ids=prompt_token_ids,
skip_special_tokens=skip_special_tokens,
)
self.request_states[request_id] = RequestState(
req_id=request_id,
token_ids=prompt_token_ids,
tokens=tokens,
num_prompt_tokens=len(prompt_token_ids),
prefix_offset=prefix_offset,
read_offset=read_offset,
skip_special_tokens=skip_special_tokens,
spaces_between_special_tokens=spaces_between_special_tokens,
)
def free(self, request_id: str) -> None:
del self.request_states[request_id]
def detokenize(self, request_id: str, new_token_ids: List[int]) -> str:
# TODO(woosuk): This method becomes very inefficient when the number of
# new_token_ids is more than 1. We need to optimize this.
req_state = self.request_states[request_id]
decoded_text = ""
for new_token_id in new_token_ids:
req_state.token_ids.append(new_token_id)
(new_tokens, new_decoded_token_text, prefix_offset,
read_offset) = detokenize_incrementally(
tokenizer=self.tokenizer,
all_input_ids=req_state.token_ids,
prev_tokens=req_state.tokens,
prefix_offset=req_state.prefix_offset,
read_offset=req_state.read_offset,
skip_special_tokens=req_state.skip_special_tokens,
spaces_between_special_tokens=req_state.
spaces_between_special_tokens,
)
req_state.tokens.extend(new_tokens)
req_state.prefix_offset = prefix_offset
req_state.read_offset = read_offset
req_state.output_text += new_decoded_token_text
decoded_text += new_decoded_token_text
return decoded_text
@dataclass
class RequestState:
req_id: str
token_ids: List[int]
tokens: List[str]
num_prompt_tokens: int
prefix_offset: int
read_offset: int
skip_special_tokens: bool
spaces_between_special_tokens: bool
output_text: str = ""
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Set
from unittest.mock import patch
import numpy as np
import torch
import torch.distributed
import torch.nn as nn
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig)
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.multimodal import MultiModalDataDict
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, cdiv,
is_pin_memory_available)
from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
FlashAttentionMetadata)
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.sampler import Sampler
if TYPE_CHECKING:
from vllm.v1.core.scheduler import SchedulerOutput
logger = init_logger(__name__)
class GPUModelRunner:
def __init__(
self,
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig] = None,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
observability_config: Optional[ObservabilityConfig] = None,
):
self.model_config = model_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.cache_config = cache_config
self.lora_config = lora_config
self.load_config = load_config
self.prompt_adapter_config = prompt_adapter_config
self.observability_config = observability_config
self.device = self.device_config.device
self.pin_memory = is_pin_memory_available()
self.dtype = self.model_config.dtype
if cache_config.cache_dtype == "auto":
self.kv_cache_dtype = self.dtype
else:
self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
cache_config.cache_dtype]
self.sliding_window = model_config.get_sliding_window()
self.block_size = cache_config.block_size
self.max_model_len = model_config.max_model_len
self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
self.max_num_tokens = scheduler_config.max_num_batched_tokens
# Model-related.
self.num_attn_layers = model_config.get_num_attention_layers(
parallel_config)
self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
self.head_size = model_config.get_head_size()
# Lazy initialization
# self.model: nn.Module # Set after load_model
self.kv_caches: List[torch.Tensor] = []
# Request states.
self.requests: Dict[str, CachedRequestState] = {}
# Persistent batch.
self.input_batch = InputBatch(
max_num_reqs=self.scheduler_config.max_num_seqs,
max_model_len=self.max_model_len,
max_num_blocks_per_req=self.max_num_blocks_per_req,
device=self.device,
pin_memory=self.pin_memory,
)
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
# Remove stopped requests from the cached states.
# Keep the states of the pre-empted requests.
for req_id in scheduler_output.finished_req_ids:
self.requests.pop(req_id, None)
# Remove the requests from the persistent batch.
stopped_req_ids = set().union(
scheduler_output.preempted_req_ids,
scheduler_output.finished_req_ids,
)
removed_req_indices: List[int] = []
for req_id in stopped_req_ids:
req_index = self.input_batch.remove_request(req_id)
if req_index is not None:
removed_req_indices.append(req_index)
# Update the states of the running requests.
for req_data in scheduler_output.scheduled_running_reqs:
req_id = req_data.req_id
req_state = self.requests[req_id]
req_index = self.input_batch.req_id_to_index[req_id]
# Update the num_computed_tokens.
req_state.num_computed_tokens = req_data.num_computed_tokens
self.input_batch.num_computed_tokens_cpu[req_index] = (
req_data.num_computed_tokens)
# Update the block table.
num_new_blocks = len(req_data.new_block_ids)
if num_new_blocks == 0:
continue
start_index = len(req_state.block_ids)
end_index = start_index + num_new_blocks
req_state.block_ids.extend(req_data.new_block_ids)
self.input_batch.block_table_cpu[
req_index, start_index:end_index] = req_data.new_block_ids
req_ids_to_add: List[str] = []
# Add new requests to the cached states.
for req_data in scheduler_output.scheduled_new_reqs:
req_id = req_data.req_id
self.requests[req_id] = CachedRequestState(
req_id=req_id,
prompt_token_ids=req_data.prompt_token_ids,
prompt=req_data.prompt,
multi_modal_data=req_data.multi_modal_data,
sampling_params=req_data.sampling_params,
generator=None, # TODO
block_ids=req_data.block_ids,
num_computed_tokens=req_data.num_computed_tokens,
output_token_ids=[],
)
req_ids_to_add.append(req_id)
# Update the cached states of the resumed requests.
for req_data in scheduler_output.scheduled_resumed_reqs:
req_id = req_data.req_id
req_state = self.requests[req_id]
req_state.block_ids = req_data.block_ids
req_state.num_computed_tokens = req_data.num_computed_tokens
req_ids_to_add.append(req_id)
# Add the new or resumed requests to the persistent batch.
# The smaller empty indices are filled first.
removed_req_indices = sorted(removed_req_indices, reverse=True)
for req_id in req_ids_to_add:
req_state = self.requests[req_id]
if removed_req_indices:
# Fill the empty index.
req_index = removed_req_indices.pop()
else:
# Append to the end.
req_index = None
self.input_batch.add_request(req_state, req_index)
# Condense the batched states if there are empty indices.
if removed_req_indices:
self.input_batch.condense(removed_req_indices)
def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
assert total_num_scheduled_tokens > 0
num_reqs = self.input_batch.num_reqs
assert num_reqs > 0
# OPTIMIZATION: Start copying the block table first.
# This way, we can overlap the copy with the following CPU operations.
self.input_batch.block_table[:num_reqs].copy_(
self.input_batch.block_table_cpu_tensor[:num_reqs],
non_blocking=True)
# Get the number of scheduled tokens for each request.
# TODO: The Python loop can be slow. Optimize.
num_scheduled_tokens = []
max_num_scheduled_tokens = 0
for req_id in self.input_batch.req_ids[:num_reqs]:
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
num_scheduled_tokens.append(num_tokens)
max_num_scheduled_tokens = max(max_num_scheduled_tokens,
num_tokens)
num_scheduled_tokens = np.array(num_scheduled_tokens, dtype=np.int32)
assert max_num_scheduled_tokens > 0
# Get request indices.
# E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
indices = np.arange(num_reqs)
req_indices = np.repeat(indices, num_scheduled_tokens)
# Get batched arange.
# E.g., [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
arange_matrix = np.tile(np.arange(max_num_scheduled_tokens),
(num_reqs, 1))
mask = arange_matrix < num_scheduled_tokens[:, np.newaxis]
arange = arange_matrix[mask]
# Get positions.
positions = torch.empty((total_num_scheduled_tokens, ),
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory)
positions_np = positions.numpy()
np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
arange,
out=positions_np)
# Get token indices.
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
# -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
# where M is the max_model_len.
token_indices = positions_np + req_indices * self.max_model_len
token_indices = torch.from_numpy(token_indices)
input_ids = torch.empty((total_num_scheduled_tokens, ),
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory)
torch.index_select(torch.from_numpy(
self.input_batch.token_ids_cpu).flatten(),
0,
token_indices,
out=input_ids)
# Calculate the slot mapping.
block_numbers = self.input_batch.block_table_cpu_tensor.flatten()[
token_indices // self.block_size]
block_offsets = token_indices % self.block_size
slot_mapping = torch.empty((total_num_scheduled_tokens, ),
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory)
torch.add(block_numbers * self.block_size,
block_offsets,
out=slot_mapping)
# Prepare the attention metadata.
query_start_loc = torch.empty((num_reqs + 1, ),
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory)
query_start_loc_np = query_start_loc.numpy()
query_start_loc_np[0] = 0
np.cumsum(num_scheduled_tokens, out=query_start_loc_np[1:])
seq_lens = (self.input_batch.num_computed_tokens_cpu[:num_reqs] +
num_scheduled_tokens)
max_seq_len = seq_lens.max()
seq_start_loc = torch.empty((num_reqs + 1, ),
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory)
seq_start_loc_np = seq_start_loc.numpy()
seq_start_loc_np[0] = 0
np.cumsum(seq_lens, out=seq_start_loc_np[1:])
input_ids = input_ids.to(self.device, non_blocking=True)
positions = positions.to(self.device, non_blocking=True).long()
query_start_loc = query_start_loc.to(self.device, non_blocking=True)
seq_start_loc = seq_start_loc.to(self.device, non_blocking=True)
slot_mapping = slot_mapping.to(self.device, non_blocking=True).long()
attn_metadata = FlashAttentionMetadata(
max_query_len=max_num_scheduled_tokens,
query_start_loc=query_start_loc,
max_seq_len=max_seq_len,
seq_start_loc=seq_start_loc,
block_table=self.input_batch.block_table[:num_reqs],
slot_mapping=slot_mapping,
)
# NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial
# request in the batch. While we should not sample any token from this
# partial request, we do so for simplicity. We will ignore the sampled
# token from the partial request.
# TODO: Support prompt logprobs.
logits_indices = query_start_loc[1:] - 1
return input_ids, positions, attn_metadata, logits_indices
def _prepare_sampling(
self,
scheduler_output: "SchedulerOutput",
) -> SamplingMetadata:
skip_copy = True
if (scheduler_output.finished_req_ids
or scheduler_output.preempted_req_ids):
skip_copy = False
if (scheduler_output.scheduled_new_reqs
or scheduler_output.scheduled_resumed_reqs):
skip_copy = False
# Create the sampling metadata.
sampling_metadata = self.input_batch.make_sampling_metadata(skip_copy)
return sampling_metadata
@torch.inference_mode()
def execute_model(
self,
scheduler_output: "SchedulerOutput",
) -> ModelRunnerOutput:
self._update_states(scheduler_output)
inputs = self._prepare_inputs(scheduler_output)
input_ids, positions, attn_metadata, logits_indices = inputs
with set_forward_context(attn_metadata):
hidden_states = self.model(
input_ids=input_ids,
positions=positions,
kv_caches=self.kv_caches,
attn_metadata=attn_metadata,
)
hidden_states = hidden_states[logits_indices]
logits = self.model.compute_logits(hidden_states, None)
# Sample the next token and get logprobs if needed.
sampling_metadata = self._prepare_sampling(scheduler_output)
sampler_output = self.model.sample(
logits=logits,
sampling_metadata=sampling_metadata,
)
# NOTE: CPU-GPU synchronization happens here.
sampled_token_ids = sampler_output.sampled_token_ids.cpu()
sampled_token_ids_list = sampled_token_ids.tolist()
# TODO(woosuk): The following loop can be slow since it iterates over
# the requests one by one. Optimize.
num_reqs = self.input_batch.num_reqs
for i, req_id in enumerate(self.input_batch.req_ids[:num_reqs]):
req_state = self.requests[req_id]
seq_len = (req_state.num_computed_tokens +
scheduler_output.num_scheduled_tokens[req_id])
assert seq_len <= req_state.num_tokens
if seq_len == req_state.num_tokens:
# Append the sampled token to the output token ids.
token_id = sampled_token_ids_list[i]
self.input_batch.token_ids_cpu[i, seq_len] = token_id
req_state.output_token_ids.append(token_id)
else:
# Ignore the sampled token from the partial request.
# Rewind the generator state as if the token was not sampled.
generator = self.input_batch.generators[i]
if generator is not None:
offset = generator.get_offset()
generator = generator.set_offset(offset - 1)
self.input_batch.generators[i] = generator
if sampler_output.logprob_token_ids is None:
logprob_token_ids = None
else:
logprob_token_ids = sampler_output.logprob_token_ids.cpu()
if sampler_output.logprobs is None:
logprobs = None
else:
logprobs = sampler_output.logprobs.cpu()
model_runner_output = ModelRunnerOutput(
req_ids=self.input_batch.req_ids[:num_reqs],
req_id_to_index=self.input_batch.req_id_to_index,
sampled_token_ids_cpu=sampled_token_ids,
logprob_token_ids_cpu=logprob_token_ids,
logprobs_cpu=logprobs,
)
return model_runner_output
def load_model(self) -> None:
logger.info("Starting to load model %s...", self.model_config.model)
with DeviceMemoryProfiler() as m: # noqa: SIM117
with patch("vllm.model_executor.layers.sampler.Sampler", Sampler):
self.model = get_model(model_config=self.model_config,
device_config=self.device_config,
load_config=self.load_config,
lora_config=self.lora_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
cache_config=self.cache_config)
self.model_memory_usage = m.consumed_memory
logger.info("Loading model weights took %.4f GB",
self.model_memory_usage / float(2**30))
def _dummy_run(self, model: nn.Module, num_tokens: int) -> None:
input_ids = torch.zeros(num_tokens,
dtype=torch.int32,
device=self.device)
positions = torch.zeros(num_tokens,
dtype=torch.long,
device=self.device)
kv_caches = [None for _ in range(self.num_attn_layers)]
model(input_ids, positions, kv_caches, attn_metadata=None)
return
@torch.inference_mode()
def profile_run(self) -> None:
self._dummy_run(self.model, self.max_num_tokens)
torch.cuda.synchronize()
return
@torch.inference_mode()
def capture_model(self) -> None:
# TODO: Implement CUDA graph support.
return
def initialize_kv_cache(self, num_blocks: int) -> None:
assert len(self.kv_caches) == 0
kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(
num_blocks, self.block_size, self.num_kv_heads, self.head_size)
for _ in range(self.num_attn_layers):
self.kv_caches.append(
torch.zeros(kv_cache_shape,
dtype=self.kv_cache_dtype,
device=self.device))
@dataclass
class CachedRequestState:
req_id: str
prompt_token_ids: List[int]
prompt: Optional[str]
multi_modal_data: Optional["MultiModalDataDict"]
sampling_params: SamplingParams
generator: Optional[torch.Generator]
block_ids: List[int]
num_computed_tokens: int
output_token_ids: List[int]
@property
def num_tokens(self) -> int:
return len(self.prompt_token_ids) + len(self.output_token_ids)
class InputBatch:
def __init__(
self,
max_num_reqs: int,
max_model_len: int,
max_num_blocks_per_req: int,
device: torch.device,
pin_memory: bool,
):
self.max_num_reqs = max_num_reqs
self.max_model_len = max_model_len
self.max_num_blocks_per_req = max_num_blocks_per_req
self.device = device
self.pin_memory = pin_memory
self.req_ids: List[Optional[str]] = [None] * max_num_reqs
self.req_id_to_index: Dict[str, int] = {}
self.token_ids_cpu = np.empty((max_num_reqs, max_model_len),
dtype=np.int32)
self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32)
# Attention-related.
self.block_table = torch.zeros((max_num_reqs, max_num_blocks_per_req),
device=self.device,
dtype=torch.int32)
self.block_table_cpu_tensor = torch.zeros(
(max_num_reqs, max_num_blocks_per_req),
device="cpu",
dtype=torch.int32,
pin_memory=pin_memory,
)
self.block_table_cpu = self.block_table_cpu_tensor.numpy()
# Sampling-related.
self.temperature = torch.empty((max_num_reqs, ),
dtype=torch.float32,
device=device)
self.temperature_cpu_tensor = torch.empty((max_num_reqs, ),
dtype=torch.float32,
device="cpu",
pin_memory=pin_memory)
self.temperature_cpu = self.temperature_cpu_tensor.numpy()
self.greedy_reqs: Set[str] = set()
self.random_reqs: Set[str] = set()
self.top_p = torch.empty((max_num_reqs, ),
dtype=torch.float32,
device=device)
self.top_p_cpu_tensor = torch.empty((max_num_reqs, ),
dtype=torch.float32,
device="cpu",
pin_memory=pin_memory)
self.top_p_cpu = self.top_p_cpu_tensor.numpy()
self.top_p_reqs: Set[str] = set()
self.top_k = torch.empty((max_num_reqs, ),
dtype=torch.int32,
device=device)
self.top_k_cpu_tensor = torch.empty((max_num_reqs, ),
dtype=torch.int32,
device="cpu",
pin_memory=pin_memory)
self.top_k_cpu = self.top_k_cpu_tensor.numpy()
self.top_k_reqs: Set[str] = set()
self.generators: List[Optional[torch.Generator]] = [None
] * max_num_reqs
self.num_logprobs: Dict[str, int] = {}
self.prompt_logprob_reqs: Set[str] = set()
def add_request(
self,
request: "CachedRequestState",
req_index: Optional[int] = None,
) -> None:
if req_index is None:
req_index = self.num_reqs
assert req_index < self.max_num_reqs
self.req_ids[req_index] = request.req_id
self.req_id_to_index[request.req_id] = req_index
# Copy the prompt token ids and output token ids.
num_prompt_tokens = len(request.prompt_token_ids)
self.token_ids_cpu[
req_index, :num_prompt_tokens] = request.prompt_token_ids
start_idx = num_prompt_tokens
end_idx = start_idx + len(request.output_token_ids)
self.token_ids_cpu[req_index,
start_idx:end_idx] = request.output_token_ids
self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
num_blocks = len(request.block_ids)
self.block_table_cpu[req_index, :num_blocks] = request.block_ids
sampling_params = request.sampling_params
self.temperature_cpu[req_index] = sampling_params.temperature
if sampling_params.sampling_type == SamplingType.GREEDY:
self.greedy_reqs.add(req_index)
elif sampling_params.sampling_type == SamplingType.RANDOM:
self.random_reqs.add(req_index)
elif sampling_params.sampling_type == SamplingType.RANDOM_SEED:
# TODO(woosuk): Support per-request random seed.
raise NotImplementedError("Per-request seed is not supported yet.")
self.top_p_cpu[req_index] = sampling_params.top_p
if sampling_params.top_p < 1:
self.top_p_reqs.add(req_index)
self.top_k_cpu[req_index] = sampling_params.top_k
if sampling_params.top_k > 0:
self.top_k_reqs.add(req_index)
self.generators[req_index] = request.generator
num_logprobs = sampling_params.logprobs
if num_logprobs is not None and num_logprobs > 0:
self.num_logprobs[request.req_id] = num_logprobs
if sampling_params.prompt_logprobs:
self.prompt_logprob_reqs.add(req_index)
def remove_request(self, req_id: str) -> Optional[int]:
req_index = self.req_id_to_index.pop(req_id, None)
if req_index is None:
return None
self.req_ids[req_index] = None
self.greedy_reqs.discard(req_id)
self.random_reqs.discard(req_id)
self.top_p_reqs.discard(req_id)
self.top_k_reqs.discard(req_id)
self.generators[req_index] = None
self.num_logprobs.pop(req_id, None)
self.prompt_logprob_reqs.discard(req_id)
return req_index
def clear(self) -> None:
self.req_ids = [None] * self.max_num_reqs
self.req_id_to_index.clear()
self.greedy_reqs.clear()
self.random_reqs.clear()
self.top_p_reqs.clear()
self.top_k_reqs.clear()
self.generators.clear()
self.num_logprobs.clear()
self.prompt_logprob_reqs.clear()
def condense(self, empty_req_indices: List[int]) -> None:
if self.num_reqs == 0:
# The batched states are empty.
return
# NOTE(woosuk): This function assumes that the empty_req_indices
# is sorted in descending order.
last_req_index = self.num_reqs + len(empty_req_indices) - 1
while empty_req_indices:
# Find the largest non-empty index.
while last_req_index in empty_req_indices:
last_req_index -= 1
# Find the smallest empty index.
empty_index = empty_req_indices.pop()
if empty_index >= last_req_index:
break
# Swap the states.
req_id = self.req_ids[last_req_index]
self.req_ids[empty_index] = req_id
self.req_ids[last_req_index] = None
self.req_id_to_index[req_id] = empty_index
# TODO(woosuk): Optimize the copy of token_ids_cpu and
# block_table_cpu.
self.token_ids_cpu[empty_index] = self.token_ids_cpu[
last_req_index]
self.num_computed_tokens_cpu[
empty_index] = self.num_computed_tokens_cpu[last_req_index]
self.block_table_cpu[empty_index] = self.block_table_cpu[
last_req_index]
self.temperature_cpu[empty_index] = self.temperature_cpu[
last_req_index]
self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]
self.generators[empty_index] = self.generators[last_req_index]
# Decrement last_req_index since it is now empty.
last_req_index -= 1
def make_sampling_metadata(
self,
skip_copy: bool = False,
) -> SamplingMetadata:
if not skip_copy:
self.temperature[:self.num_reqs].copy_(
self.temperature_cpu_tensor[:self.num_reqs], non_blocking=True)
self.top_p[:self.num_reqs].copy_(
self.top_p_cpu_tensor[:self.num_reqs], non_blocking=True)
self.top_k[:self.num_reqs].copy_(
self.top_k_cpu_tensor[:self.num_reqs], non_blocking=True)
return SamplingMetadata(
temperature=self.temperature[:self.num_reqs],
all_greedy=self.all_greedy,
all_random=self.all_random,
top_p=self.top_p[:self.num_reqs],
top_k=self.top_k[:self.num_reqs],
no_top_p=self.no_top_p,
no_top_k=self.no_top_k,
generators=self.generators[:self.num_reqs],
no_generator=self.no_generator,
max_num_logprobs=self.max_num_logprobs,
)
@property
def num_reqs(self) -> int:
return len(self.req_id_to_index)
@property
def all_greedy(self) -> bool:
return len(self.random_reqs) == 0
@property
def all_random(self) -> bool:
return len(self.greedy_reqs) == 0
@property
def no_top_p(self) -> bool:
return len(self.top_p_reqs) == 0
@property
def no_top_k(self) -> bool:
return len(self.top_k_reqs) == 0
@property
def no_generator(self) -> bool:
return len(self.generators) == 0
@property
def max_num_logprobs(self) -> int:
if self.num_logprobs:
return max(self.num_logprobs.values())
else:
return 0
@property
def no_logprob(self) -> bool:
return len(self.num_logprobs) == 0
@property
def no_prompt_logprob(self) -> bool:
return len(self.prompt_logprob_reqs) == 0
"""A GPU worker class."""
import gc
import os
from typing import TYPE_CHECKING, Optional, Tuple
import torch
import torch.distributed
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment,
set_custom_all_reduce)
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
logger = init_logger(__name__)
if TYPE_CHECKING:
from vllm.v1.core.scheduler import SchedulerOutput
class Worker:
def __init__(
self,
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
speculative_config: Optional[SpeculativeConfig] = None,
lora_config: Optional[LoRAConfig] = None,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
observability_config: Optional[ObservabilityConfig] = None,
):
self.model_config = model_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.cache_config = cache_config
self.load_config = load_config
self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
self.lora_config = lora_config
self.speculative_config = speculative_config
self.prompt_adapter_config = prompt_adapter_config
self.observability_config = observability_config
if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules
init_cached_hf_modules()
self.model_runner = GPUModelRunner(
model_config,
parallel_config,
scheduler_config,
device_config,
cache_config,
load_config,
lora_config=lora_config,
)
def initialize(self):
if self.device_config.device.type == "cuda":
# torch.distributed.all_reduce does not free the input tensor until
# the synchronization point. This causes the memory usage to grow
# as the number of all_reduce calls increases. This env var disables
# this behavior.
# Related issue:
# https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
# This env var set by Ray causes exceptions with graph building.
os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
self.device = torch.device(f"cuda:{self.local_rank}")
torch.cuda.set_device(self.device)
_check_if_gpu_supports_dtype(self.model_config.dtype)
gc.collect()
torch.cuda.empty_cache()
self.init_gpu_memory = torch.cuda.mem_get_info()[0]
else:
raise RuntimeError(
f"Not support device type: {self.device_config.device}")
# Initialize the distributed environment.
init_worker_distributed_environment(self.parallel_config, self.rank,
self.distributed_init_method,
self.local_rank)
# Set random seed.
set_random_seed(self.model_config.seed)
def load_model(self) -> None:
self.model_runner.load_model()
@torch.inference_mode()
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Profiles the peak memory usage of the model to determine how many
KV blocks may be allocated without OOMs.
The engine will first conduct a profiling of the existing memory usage.
Then, it calculate the maximum possible number of GPU and CPU blocks
that can be allocated with the remaining free memory.
.. tip::
You may limit the usage of GPU memory
by adjusting the `gpu_memory_utilization` parameter.
"""
# Profile the memory usage of the model and get the maximum number of
# cache blocks that can be allocated with the remaining free memory.
torch.cuda.empty_cache()
# Execute a forward pass with dummy inputs to profile the memory usage
# of the model.
self.model_runner.profile_run()
# Calculate the number of blocks that can be allocated with the
# profiled peak memory.
torch.cuda.synchronize()
free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
# NOTE(woosuk): Here we assume that the other processes using the same
# GPU did not change their memory usage during the profiling.
peak_memory = self.init_gpu_memory - free_gpu_memory
assert peak_memory > 0, (
"Error in memory profiling. "
f"Initial free memory {self.init_gpu_memory}, current free memory"
f" {free_gpu_memory}. This happens when the GPU memory was "
"not properly cleaned up before initializing the vLLM instance.")
cache_block_size = _get_cache_block_size(self.cache_config,
self.model_config,
self.parallel_config)
num_gpu_blocks = int(
(total_gpu_memory * self.cache_config.gpu_memory_utilization -
peak_memory) // cache_block_size)
num_gpu_blocks = max(num_gpu_blocks, 0)
# if self.model_runner.lora_manager:
# self.model_runner.remove_all_loras()
gc.collect()
torch.cuda.empty_cache()
return num_gpu_blocks, 0
def initialize_cache(self, num_gpu_blocks: int) -> None:
"""Allocate GPU and CPU KV cache with the specified number of blocks."""
if num_gpu_blocks <= 0:
raise ValueError("No available memory for the cache blocks. "
"Try increasing `gpu_memory_utilization` when "
"initializing the engine.")
max_seq_len = self.cache_config.block_size * num_gpu_blocks
max_model_len = self.model_config.max_model_len
if max_model_len > max_seq_len:
raise ValueError(
f"The model's max seq len ({max_model_len}) "
"is larger than the maximum number of tokens that can be "
f"stored in KV cache ({max_seq_len}). Try increasing "
"`gpu_memory_utilization` or decreasing `max_model_len` when "
"initializing the engine.")
self.model_runner.initialize_kv_cache(num_gpu_blocks)
def compile_or_warm_up_model(self) -> None:
if not self.model_config.enforce_eager:
self.model_runner.capture_model()
# Reset the seed to ensure that the random state is not affected by
# the model initialization and profiling.
set_random_seed(self.model_config.seed)
@torch.inference_mode()
def execute_model(
self,
scheduler_output: "SchedulerOutput",
) -> ModelRunnerOutput:
output = self.model_runner.execute_model(scheduler_output)
# TODO(woosuk): Send the output to the engine process.
return output
def init_worker_distributed_environment(
parallel_config: ParallelConfig,
rank: int,
distributed_init_method: Optional[str] = None,
local_rank: int = -1,
) -> None:
"""Initialize the distributed environment."""
set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
init_distributed_environment(parallel_config.world_size, rank,
distributed_init_method, local_rank)
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)
def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
# Check if the GPU supports the dtype.
if torch_dtype == torch.bfloat16: # noqa: SIM102
if not current_platform.has_device_capability(80):
capability = current_platform.get_device_capability()
gpu_name = current_platform.get_device_name()
if capability is None:
compute_str = "does not have a compute capability"
else:
version_str = capability.as_version_str()
compute_str = f"has compute capability {version_str}"
raise ValueError(
"Bfloat16 is only supported on GPUs with compute capability "
f"of at least 8.0. Your {gpu_name} GPU {compute_str}. "
"You can use float16 instead by explicitly setting the"
"`dtype` flag in CLI, for example: --dtype=half.")
def _get_cache_block_size(
cache_config: CacheConfig,
model_config: ModelConfig,
parallel_config: ParallelConfig,
) -> int:
head_size = model_config.get_head_size()
num_heads = model_config.get_num_kv_heads(parallel_config)
num_attention_layers = model_config.get_num_attention_layers(
parallel_config)
key_cache_block = cache_config.block_size * num_heads * head_size
value_cache_block = key_cache_block
total = num_attention_layers * (key_cache_block + value_cache_block)
if cache_config.cache_dtype == "auto":
dtype = model_config.dtype
else:
dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
dtype_size = get_dtype_size(dtype)
return dtype_size * total
...@@ -53,7 +53,6 @@ class CacheEngine: ...@@ -53,7 +53,6 @@ class CacheEngine:
# Get attention backend. # Get attention backend.
self.attn_backend = get_attn_backend(self.head_size, self.attn_backend = get_attn_backend(self.head_size,
model_config.get_sliding_window(),
model_config.dtype, model_config.dtype,
cache_config.cache_dtype, cache_config.cache_dtype,
self.block_size, self.block_size,
......
...@@ -420,7 +420,6 @@ class CPUModelRunner(ModelRunnerBase[ModelInputForCPU]): ...@@ -420,7 +420,6 @@ class CPUModelRunner(ModelRunnerBase[ModelInputForCPU]):
self.block_size = cache_config.block_size self.block_size = cache_config.block_size
self.attn_backend = get_attn_backend( self.attn_backend = get_attn_backend(
self.model_config.get_head_size(), self.model_config.get_head_size(),
self.model_config.get_sliding_window(),
self.model_config.dtype, self.model_config.dtype,
self.kv_cache_dtype, self.kv_cache_dtype,
self.block_size, self.block_size,
......
...@@ -57,7 +57,6 @@ class CPUCacheEngine: ...@@ -57,7 +57,6 @@ class CPUCacheEngine:
# Get attention backend. # Get attention backend.
self.attn_backend = get_attn_backend( self.attn_backend = get_attn_backend(
self.model_config.get_head_size(), self.model_config.get_head_size(),
self.model_config.get_sliding_window(),
self.model_config.dtype, self.model_config.dtype,
cache_config.cache_dtype, cache_config.cache_dtype,
self.block_size, self.block_size,
......
...@@ -828,7 +828,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -828,7 +828,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
cuda_graph_pad_size = self._get_cuda_graph_pad_size( cuda_graph_pad_size = self._get_cuda_graph_pad_size(
num_seqs=len(seq_lens), num_seqs=len(seq_lens),
max_decode_seq_len=max_encoder_seq_len, max_decode_seq_len=max_decode_seq_len,
max_encoder_seq_len=max_encoder_seq_len) max_encoder_seq_len=max_encoder_seq_len)
batch_size = len(input_tokens) batch_size = len(input_tokens)
...@@ -1011,7 +1011,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1011,7 +1011,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self.attn_backend = get_attn_backend( self.attn_backend = get_attn_backend(
self.model_config.get_head_size(), self.model_config.get_head_size(),
self.model_config.get_sliding_window(),
self.model_config.dtype, self.model_config.dtype,
self.kv_cache_dtype, self.kv_cache_dtype,
self.block_size, self.block_size,
...@@ -1856,7 +1855,7 @@ class CUDAGraphRunner(nn.Module): ...@@ -1856,7 +1855,7 @@ class CUDAGraphRunner(nn.Module):
self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True) self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
self.input_buffers["positions"].copy_(positions, non_blocking=True) self.input_buffers["positions"].copy_(positions, non_blocking=True)
if self.backend_name != "placeholder-attn": if self.backend_name != "NO_ATTENTION":
self.input_buffers["slot_mapping"].copy_( self.input_buffers["slot_mapping"].copy_(
attn_metadata.slot_mapping, non_blocking=True) attn_metadata.slot_mapping, non_blocking=True)
......
...@@ -29,8 +29,8 @@ if TYPE_CHECKING: ...@@ -29,8 +29,8 @@ if TYPE_CHECKING:
logger = init_logger(__name__) logger = init_logger(__name__)
MULTI_STEP_ATTENTION_BACKENDS = ["flash-attn", "rocm-flash-attn", "flashinfer"] MULTI_STEP_ATTENTION_BACKENDS = ["FLASH_ATTN", "ROCM_FLASH", "FLASHINFER"]
MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS = ["flash-attn"] MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS = ["FLASH_ATTN"]
def _get_supported_attention_backends(chunked_prefill_enabled: bool) \ def _get_supported_attention_backends(chunked_prefill_enabled: bool) \
-> List[str]: -> List[str]:
......
...@@ -75,7 +75,6 @@ class OpenVINOModelRunner: ...@@ -75,7 +75,6 @@ class OpenVINOModelRunner:
self.attn_backend = get_attn_backend( self.attn_backend = get_attn_backend(
self.model_config.get_head_size(), self.model_config.get_head_size(),
self.model_config.get_sliding_window(),
self.model_config.dtype, self.model_config.dtype,
self.kv_cache_dtype, self.kv_cache_dtype,
self.block_size, self.block_size,
......
...@@ -71,7 +71,6 @@ class OpenVINOCacheEngine: ...@@ -71,7 +71,6 @@ class OpenVINOCacheEngine:
# Get attention backend. # Get attention backend.
self.attn_backend = get_attn_backend( self.attn_backend = get_attn_backend(
self.head_size, self.head_size,
self.model_config.get_sliding_window(),
self.model_config.dtype, self.model_config.dtype,
self.cache_config.cache_dtype, self.cache_config.cache_dtype,
self.block_size, self.block_size,
......
...@@ -114,7 +114,6 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]): ...@@ -114,7 +114,6 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
dtype=np.int32) dtype=np.int32)
self.attn_backend = get_attn_backend( self.attn_backend = get_attn_backend(
self.model_config.get_head_size(), self.model_config.get_head_size(),
self.model_config.get_sliding_window(),
self.model_config.dtype, self.model_config.dtype,
self.cache_config.cache_dtype, self.cache_config.cache_dtype,
self.block_size, self.block_size,
......
...@@ -92,7 +92,7 @@ class Worker(LocalOrDistributedWorkerBase): ...@@ -92,7 +92,7 @@ class Worker(LocalOrDistributedWorkerBase):
ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner
if model_runner_cls is not None: if model_runner_cls is not None:
ModelRunnerClass = model_runner_cls ModelRunnerClass = model_runner_cls
elif self._is_embedding_model(): elif model_config.task == "embedding":
ModelRunnerClass = EmbeddingModelRunner ModelRunnerClass = EmbeddingModelRunner
elif self._is_encoder_decoder_model(): elif self._is_encoder_decoder_model():
ModelRunnerClass = EncoderDecoderModelRunner ModelRunnerClass = EncoderDecoderModelRunner
...@@ -147,9 +147,6 @@ class Worker(LocalOrDistributedWorkerBase): ...@@ -147,9 +147,6 @@ class Worker(LocalOrDistributedWorkerBase):
def _is_encoder_decoder_model(self): def _is_encoder_decoder_model(self):
return self.model_config.is_encoder_decoder_model return self.model_config.is_encoder_decoder_model
def _is_embedding_model(self):
return self.model_config.is_embedding_model
def init_device(self) -> None: def init_device(self) -> None:
if self.device_config.device.type == "cuda": if self.device_config.device.type == "cuda":
# torch.distributed.all_reduce does not free the input tensor until # torch.distributed.all_reduce does not free the input tensor until
...@@ -217,42 +214,79 @@ class Worker(LocalOrDistributedWorkerBase): ...@@ -217,42 +214,79 @@ class Worker(LocalOrDistributedWorkerBase):
# Profile the memory usage of the model and get the maximum number of # Profile the memory usage of the model and get the maximum number of
# cache blocks that can be allocated with the remaining free memory. # cache blocks that can be allocated with the remaining free memory.
torch.cuda.empty_cache() torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
free_memory_pre_profile, total_gpu_memory = torch.cuda.mem_get_info()
# Execute a forward pass with dummy inputs to profile the memory usage # Execute a forward pass with dummy inputs to profile the memory usage
# of the model. # of the model.
self.model_runner.profile_run() self.model_runner.profile_run()
torch.cuda.synchronize()
self._assert_memory_footprint_increased_during_profiling()
# Get the peak memory allocation recorded by torch
peak_memory = torch.cuda.memory_stats()["allocated_bytes.all.peak"]
# Check for any memory left around that may have been allocated on the
# gpu outside of `torch`. NCCL operations, for example, can use a few
# GB during a forward pass
torch.cuda.empty_cache()
torch_allocated_bytes = torch.cuda.memory_stats(
)["allocated_bytes.all.current"]
total_allocated_bytes = torch.cuda.mem_get_info(
)[1] - torch.cuda.mem_get_info()[0]
non_torch_allocations = total_allocated_bytes - torch_allocated_bytes
if non_torch_allocations > 0:
peak_memory += non_torch_allocations
available_kv_cache_memory = (
total_gpu_memory * self.cache_config.gpu_memory_utilization -
peak_memory)
# Calculate the number of blocks that can be allocated with the # Calculate the number of blocks that can be allocated with the
# profiled peak memory. # profiled peak memory.
torch.cuda.synchronize()
free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
# NOTE(woosuk): Here we assume that the other processes using the same
# GPU did not change their memory usage during the profiling.
peak_memory = self.init_gpu_memory - free_gpu_memory
assert peak_memory > 0, (
"Error in memory profiling. "
f"Initial free memory {self.init_gpu_memory}, current free memory"
f" {free_gpu_memory}. This happens when the GPU memory was "
"not properly cleaned up before initializing the vLLM instance.")
cache_block_size = self.get_cache_block_size_bytes() cache_block_size = self.get_cache_block_size_bytes()
if cache_block_size == 0: if cache_block_size == 0:
num_gpu_blocks = 0 num_gpu_blocks = 0
num_cpu_blocks = 0 num_cpu_blocks = 0
else: else:
num_gpu_blocks = int( num_gpu_blocks = int(available_kv_cache_memory // cache_block_size)
(total_gpu_memory * self.cache_config.gpu_memory_utilization -
peak_memory) // cache_block_size)
num_cpu_blocks = int(self.cache_config.swap_space_bytes // num_cpu_blocks = int(self.cache_config.swap_space_bytes //
cache_block_size) cache_block_size)
num_gpu_blocks = max(num_gpu_blocks, 0) num_gpu_blocks = max(num_gpu_blocks, 0)
num_cpu_blocks = max(num_cpu_blocks, 0) num_cpu_blocks = max(num_cpu_blocks, 0)
logger.info(
"Memory profiling results: total_gpu_memory=%.2fGiB"
" initial_memory_usage=%.2fGiB peak_torch_memory=%.2fGiB"
" memory_usage_post_profile=%.2fGib"
" non_torch_memory=%.2fGiB kv_cache_size=%.2fGiB"
" gpu_memory_utilization=%.2f", total_gpu_memory / (1024**3),
(total_gpu_memory - free_memory_pre_profile) / (1024**3),
(peak_memory - non_torch_allocations) / (1024**3),
total_allocated_bytes / (1024**3),
non_torch_allocations / (1024**3),
available_kv_cache_memory / (1024**3),
self.cache_config.gpu_memory_utilization)
# Final cleanup
if self.model_runner.lora_manager: if self.model_runner.lora_manager:
self.model_runner.remove_all_loras() self.model_runner.remove_all_loras()
gc.collect() gc.collect()
torch.cuda.empty_cache()
return num_gpu_blocks, num_cpu_blocks return num_gpu_blocks, num_cpu_blocks
def _assert_memory_footprint_increased_during_profiling(self):
# NOTE(woosuk): Here we assume that the other processes using the same
# GPU did not change their memory usage during the profiling.
free_gpu_memory, _ = torch.cuda.mem_get_info()
assert self.init_gpu_memory - free_gpu_memory > 0, (
"Error in memory profiling. "
f"Initial free memory {self.init_gpu_memory}, current free memory"
f" {free_gpu_memory}. This happens when the GPU memory was "
"not properly cleaned up before initializing the vLLM instance.")
def initialize_cache(self, num_gpu_blocks: int, def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None: num_cpu_blocks: int) -> None:
"""Allocate GPU and CPU KV cache with the specified number of blocks. """Allocate GPU and CPU KV cache with the specified number of blocks.
......
...@@ -374,7 +374,6 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]): ...@@ -374,7 +374,6 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
self.attn_backend = get_attn_backend( self.attn_backend = get_attn_backend(
self.model_config.get_head_size(), self.model_config.get_head_size(),
self.model_config.get_sliding_window(),
self.model_config.dtype, self.model_config.dtype,
self.kv_cache_dtype, self.kv_cache_dtype,
self.block_size, self.block_size,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment