Unverified Commit 6c5af09b authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[V1] Implement vLLM V1 [1/N] (#9289)

parent 3ddbe255
from dataclasses import dataclass
from typing import List, Optional
import torch
@dataclass
class SamplingMetadata:
temperature: torch.Tensor
all_greedy: bool
all_random: bool
top_p: torch.Tensor
top_k: torch.Tensor
no_top_p: bool
no_top_k: bool
generators: List[Optional[torch.Generator]]
no_generator: bool
max_num_logprobs: int
"""A layer that samples the next tokens from the model's outputs."""
from typing import List, Optional
import torch
import torch.nn as nn
from vllm.v1.outputs import SamplerOutput
from vllm.v1.sample.metadata import SamplingMetadata
_SAMPLING_EPS = 1e-5
class Sampler(nn.Module):
def forward(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
logits = self.apply_temperature(logits, sampling_metadata.temperature)
logits = self.apply_top_k_top_p(logits, sampling_metadata)
probs = self.get_probs(logits)
sampled = self.sample(probs, sampling_metadata)
# Use int32 to reduce the tensor size.
sampled = sampled.to(torch.int32)
if sampling_metadata.max_num_logprobs > 0:
logprobs = self.get_logprobs(logits)
# FIXME: Mask the sampled token_id, get topk logprobs,
# and concatenate the topk with the sampled token_id.
topk_logprobs, topk_indices = torch.topk(
logprobs, sampling_metadata.max_num_logprobs, dim=-1)
# Use int32 to reduce the tensor size.
topk_indices = topk_indices.to(torch.int32)
else:
topk_logprobs = None
topk_indices = None
sampler_output = SamplerOutput(
sampled_token_ids=sampled,
logprob_token_ids=topk_indices,
logprobs=topk_logprobs,
prompt_logprob_token_ids=None,
prompt_logprobs=None,
)
return sampler_output
def apply_temperature(
self,
logits: torch.Tensor,
temp: torch.Tensor,
) -> torch.Tensor:
# Use float32 to apply temperature scaling.
logits = logits.to(torch.float32)
# Avoid division by zero.
temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
# Use in-place division to avoid creating a new tensor.
logits.div_(temp.unsqueeze(dim=1))
return logits
def apply_top_k_top_p(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
return _apply_top_k_top_p(
logits,
sampling_metadata.no_top_k,
sampling_metadata.top_k,
sampling_metadata.no_top_p,
sampling_metadata.top_p,
)
def get_probs(self, logits: torch.Tensor) -> torch.Tensor:
return torch.softmax(logits, dim=-1, dtype=torch.float32)
def get_logprobs(self, logits: torch.Tensor) -> torch.Tensor:
return torch.log_softmax(logits, dim=-1, dtype=torch.float32)
def greedy_sample(self, probs: torch.Tensor) -> torch.Tensor:
return probs.argmax(dim=-1).view(-1)
def random_sample(
self,
probs: torch.Tensor,
generators: List[Optional[torch.Generator]],
no_generator: bool,
) -> torch.Tensor:
q = torch.empty_like(probs)
# NOTE(woosuk): To batch-process the requests without their own seeds,
# which is the common case, we first assume that every request does
# not have its own seed. Then, we overwrite the values for the requests
# that have their own seeds.
q.exponential_()
if not no_generator:
assert len(generators) == probs.shape[0]
# TODO(woosuk): This can be slow because we handle each request
# one by one. Optimize this.
for i, generator in enumerate(generators):
if generator is not None:
q[i].exponential_(generator=generator)
return probs.div_(q).argmax(dim=-1).view(-1)
def sample(
self,
probs: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
assert not (sampling_metadata.all_greedy
and sampling_metadata.all_random)
if sampling_metadata.all_greedy:
return self.greedy_sample(probs)
if sampling_metadata.all_random:
return self.random_sample(probs, sampling_metadata.generators,
sampling_metadata.no_generator)
greedy_sampled = self.greedy_sample(probs)
random_sampled = self.random_sample(probs,
sampling_metadata.generators,
sampling_metadata.no_generator)
sampled = torch.where(
sampling_metadata.temperature < _SAMPLING_EPS,
greedy_sampled,
random_sampled,
)
return sampled
# TODO(woosuk): Optimize this with a custom kernel.
def _apply_top_k_top_p(
logits: torch.Tensor,
no_top_k: bool,
k: torch.Tensor,
no_top_p: bool,
p: torch.Tensor,
) -> torch.Tensor:
if no_top_k and no_top_p:
return logits
logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
if not no_top_k:
# Apply top-k.
top_k_mask = logits_sort.size(1) - k.to(torch.long)
# Get all the top_k values.
top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
top_k_mask = logits_sort < top_k_mask
logits_sort.masked_fill_(top_k_mask, -float("inf"))
if not no_top_p:
# Apply top-p.
probs_sort = logits_sort.softmax(dim=-1)
probs_sum = probs_sort.cumsum(dim=-1)
top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
# at least one
top_p_mask[:, -1] = False
logits_sort.masked_fill_(top_p_mask, -float("inf"))
# Re-sort the probabilities.
logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort)
return logits
import multiprocessing
from dataclasses import dataclass
from typing import Dict, List, Optional
import msgspec
import zmq
from msgspec import msgpack
from vllm.transformers_utils.detokenizer_utils import (
convert_prompt_ids_to_tokens, detokenize_incrementally)
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils import get_open_port
class DetokenizerInputs(msgspec.Struct):
# [num_reqs]
req_ids: List[str]
# A request's prompt token ids is sent to the detokenizer only when
# the request is first detokenized. Otherwise, an empty list is sent.
prompt_token_ids: List[List[int]]
new_token_ids: List[List[int]]
skip_special_tokens: List[bool]
spaces_between_special_tokens: List[bool]
# [num_free_reqs]
free_req_ids: List[str]
class DetokenizerOutputs(msgspec.Struct):
# [num_reqs]
req_ids: List[str]
detokenized_texts: List[str]
# NOTE(woosuk): The number of the output token ids of each request
# at the time of detokenization. The detokenizer returns this to the engine
# because the request state (including the output token ids) is
# asynchronously updated in the engine, while RequestOutput requires the
# output token ids to be consistent with the detokenized text.
num_output_token_ids: List[int]
class Detokenizer:
def __init__(self, tokenizer_name: str):
# FIXME(woosuk): Currently, the detokenizer is just a hacky prototype.
# For example, it does not terminate properly. We need to improve this.
self.push_port = get_open_port()
self.pull_port = get_open_port()
self.detokenizer = DetokenizerProc(tokenizer_name, self.push_port,
self.pull_port)
self.detokenizer.start()
self.zmq_context = zmq.Context()
self.push_socket = self.zmq_context.socket(zmq.PUSH)
self.push_socket.connect(f"tcp://localhost:{self.push_port}")
self.pull_socket = self.zmq_context.socket(zmq.PULL)
self.pull_socket.connect(f"tcp://localhost:{self.pull_port}")
self.poller = zmq.Poller()
self.poller.register(self.pull_socket, zmq.POLLIN)
self.msgpack_encoder = msgpack.Encoder()
self.msgpack_decoder = msgpack.Decoder(DetokenizerOutputs)
def send(self, inputs: DetokenizerInputs) -> None:
self.push_socket.send(self.msgpack_encoder.encode(inputs),
flags=zmq.NOBLOCK)
def recv(self) -> Optional[DetokenizerOutputs]:
socks = dict(self.poller.poll(timeout=0))
if self.pull_socket in socks and socks[self.pull_socket] == zmq.POLLIN:
msg = self.pull_socket.recv()
return self.msgpack_decoder.decode(msg)
return None
def terminate(self) -> None:
self.push_socket.send(b"", flags=zmq.NOBLOCK)
self.detokenizer.join()
class DetokenizerProc(multiprocessing.Process):
def __init__(
self,
tokenizer_name: str,
pull_port: int,
push_port: int,
):
super().__init__()
self.tokenizer_name = tokenizer_name
# NOTE: The pull_port of the detokenizer should be the same as the
# push_port of the engine. Vice versa.
self.pull_port = pull_port
self.push_port = push_port
def run(self):
# Initialize these objects after the process is forked since they are
# not picklable.
self.msgpack_encoder = msgpack.Encoder()
self.msgpack_decoder = msgpack.Decoder(DetokenizerInputs)
self.tokenizer = get_tokenizer(self.tokenizer_name)
# req_id -> RequestState
self.request_states: Dict[str, RequestState] = {}
self.zmq_context = zmq.Context()
self.pull_socket = self.zmq_context.socket(zmq.PULL)
self.pull_socket.bind(f"tcp://*:{self.pull_port}")
self.push_socket = self.zmq_context.socket(zmq.PUSH)
self.push_socket.bind(f"tcp://*:{self.push_port}")
while True:
message = self.pull_socket.recv()
if message == b"":
# Terminate signal.
break
inputs = self.msgpack_decoder.decode(message)
for req_id in inputs.free_req_ids:
self.free(req_id)
detokenized_texts: List[str] = []
num_output_token_ids: List[int] = []
num_reqs = len(inputs.req_ids)
for i in range(num_reqs):
req_id = inputs.req_ids[i]
if req_id not in self.request_states:
self.add_request(
request_id=req_id,
prompt_token_ids=inputs.prompt_token_ids[i],
skip_special_tokens=inputs.skip_special_tokens[i],
spaces_between_special_tokens=inputs.
spaces_between_special_tokens[i],
)
new_str = self.detokenize(req_id, inputs.new_token_ids[i])
detokenized_texts.append(new_str)
req_state = self.request_states[req_id]
num_output_token_ids.append(
len(req_state.token_ids) - req_state.num_prompt_tokens)
detokenized = DetokenizerOutputs(
req_ids=inputs.req_ids,
detokenized_texts=detokenized_texts,
num_output_token_ids=num_output_token_ids,
)
self.push_socket.send(self.msgpack_encoder.encode(detokenized),
flags=zmq.NOBLOCK)
def add_request(
self,
request_id: str,
prompt_token_ids: List[int],
skip_special_tokens: bool,
spaces_between_special_tokens: bool,
) -> None:
tokens, prefix_offset, read_offset = convert_prompt_ids_to_tokens(
tokenizer=self.tokenizer,
prompt_ids=prompt_token_ids,
skip_special_tokens=skip_special_tokens,
)
self.request_states[request_id] = RequestState(
req_id=request_id,
token_ids=prompt_token_ids,
tokens=tokens,
num_prompt_tokens=len(prompt_token_ids),
prefix_offset=prefix_offset,
read_offset=read_offset,
skip_special_tokens=skip_special_tokens,
spaces_between_special_tokens=spaces_between_special_tokens,
)
def free(self, request_id: str) -> None:
del self.request_states[request_id]
def detokenize(self, request_id: str, new_token_ids: List[int]) -> str:
# TODO(woosuk): This method becomes very inefficient when the number of
# new_token_ids is more than 1. We need to optimize this.
req_state = self.request_states[request_id]
decoded_text = ""
for new_token_id in new_token_ids:
req_state.token_ids.append(new_token_id)
(new_tokens, new_decoded_token_text, prefix_offset,
read_offset) = detokenize_incrementally(
tokenizer=self.tokenizer,
all_input_ids=req_state.token_ids,
prev_tokens=req_state.tokens,
prefix_offset=req_state.prefix_offset,
read_offset=req_state.read_offset,
skip_special_tokens=req_state.skip_special_tokens,
spaces_between_special_tokens=req_state.
spaces_between_special_tokens,
)
req_state.tokens.extend(new_tokens)
req_state.prefix_offset = prefix_offset
req_state.read_offset = read_offset
req_state.output_text += new_decoded_token_text
decoded_text += new_decoded_token_text
return decoded_text
@dataclass
class RequestState:
req_id: str
token_ids: List[int]
tokens: List[str]
num_prompt_tokens: int
prefix_offset: int
read_offset: int
skip_special_tokens: bool
spaces_between_special_tokens: bool
output_text: str = ""
This diff is collapsed.
"""A GPU worker class."""
import gc
import os
from typing import TYPE_CHECKING, Optional, Tuple
import torch
import torch.distributed
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment,
set_custom_all_reduce)
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
logger = init_logger(__name__)
if TYPE_CHECKING:
from vllm.v1.core.scheduler import SchedulerOutput
class Worker:
def __init__(
self,
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
speculative_config: Optional[SpeculativeConfig] = None,
lora_config: Optional[LoRAConfig] = None,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
observability_config: Optional[ObservabilityConfig] = None,
):
self.model_config = model_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.cache_config = cache_config
self.load_config = load_config
self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
self.lora_config = lora_config
self.speculative_config = speculative_config
self.prompt_adapter_config = prompt_adapter_config
self.observability_config = observability_config
if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules
init_cached_hf_modules()
self.model_runner = GPUModelRunner(
model_config,
parallel_config,
scheduler_config,
device_config,
cache_config,
load_config,
lora_config=lora_config,
)
def initialize(self):
if self.device_config.device.type == "cuda":
# torch.distributed.all_reduce does not free the input tensor until
# the synchronization point. This causes the memory usage to grow
# as the number of all_reduce calls increases. This env var disables
# this behavior.
# Related issue:
# https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
# This env var set by Ray causes exceptions with graph building.
os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
self.device = torch.device(f"cuda:{self.local_rank}")
torch.cuda.set_device(self.device)
_check_if_gpu_supports_dtype(self.model_config.dtype)
gc.collect()
torch.cuda.empty_cache()
self.init_gpu_memory = torch.cuda.mem_get_info()[0]
else:
raise RuntimeError(
f"Not support device type: {self.device_config.device}")
# Initialize the distributed environment.
init_worker_distributed_environment(self.parallel_config, self.rank,
self.distributed_init_method,
self.local_rank)
# Set random seed.
set_random_seed(self.model_config.seed)
def load_model(self) -> None:
self.model_runner.load_model()
@torch.inference_mode()
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Profiles the peak memory usage of the model to determine how many
KV blocks may be allocated without OOMs.
The engine will first conduct a profiling of the existing memory usage.
Then, it calculate the maximum possible number of GPU and CPU blocks
that can be allocated with the remaining free memory.
.. tip::
You may limit the usage of GPU memory
by adjusting the `gpu_memory_utilization` parameter.
"""
# Profile the memory usage of the model and get the maximum number of
# cache blocks that can be allocated with the remaining free memory.
torch.cuda.empty_cache()
# Execute a forward pass with dummy inputs to profile the memory usage
# of the model.
self.model_runner.profile_run()
# Calculate the number of blocks that can be allocated with the
# profiled peak memory.
torch.cuda.synchronize()
free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
# NOTE(woosuk): Here we assume that the other processes using the same
# GPU did not change their memory usage during the profiling.
peak_memory = self.init_gpu_memory - free_gpu_memory
assert peak_memory > 0, (
"Error in memory profiling. "
f"Initial free memory {self.init_gpu_memory}, current free memory"
f" {free_gpu_memory}. This happens when the GPU memory was "
"not properly cleaned up before initializing the vLLM instance.")
cache_block_size = _get_cache_block_size(self.cache_config,
self.model_config,
self.parallel_config)
num_gpu_blocks = int(
(total_gpu_memory * self.cache_config.gpu_memory_utilization -
peak_memory) // cache_block_size)
num_gpu_blocks = max(num_gpu_blocks, 0)
# if self.model_runner.lora_manager:
# self.model_runner.remove_all_loras()
gc.collect()
torch.cuda.empty_cache()
return num_gpu_blocks, 0
def initialize_cache(self, num_gpu_blocks: int) -> None:
"""Allocate GPU and CPU KV cache with the specified number of blocks."""
if num_gpu_blocks <= 0:
raise ValueError("No available memory for the cache blocks. "
"Try increasing `gpu_memory_utilization` when "
"initializing the engine.")
max_seq_len = self.cache_config.block_size * num_gpu_blocks
max_model_len = self.model_config.max_model_len
if max_model_len > max_seq_len:
raise ValueError(
f"The model's max seq len ({max_model_len}) "
"is larger than the maximum number of tokens that can be "
f"stored in KV cache ({max_seq_len}). Try increasing "
"`gpu_memory_utilization` or decreasing `max_model_len` when "
"initializing the engine.")
self.model_runner.initialize_kv_cache(num_gpu_blocks)
def compile_or_warm_up_model(self) -> None:
if not self.model_config.enforce_eager:
self.model_runner.capture_model()
# Reset the seed to ensure that the random state is not affected by
# the model initialization and profiling.
set_random_seed(self.model_config.seed)
@torch.inference_mode()
def execute_model(
self,
scheduler_output: "SchedulerOutput",
) -> ModelRunnerOutput:
output = self.model_runner.execute_model(scheduler_output)
# TODO(woosuk): Send the output to the engine process.
return output
def init_worker_distributed_environment(
parallel_config: ParallelConfig,
rank: int,
distributed_init_method: Optional[str] = None,
local_rank: int = -1,
) -> None:
"""Initialize the distributed environment."""
set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
init_distributed_environment(parallel_config.world_size, rank,
distributed_init_method, local_rank)
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)
def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
# Check if the GPU supports the dtype.
if torch_dtype == torch.bfloat16: # noqa: SIM102
if not current_platform.has_device_capability(80):
capability = current_platform.get_device_capability()
gpu_name = current_platform.get_device_name()
if capability is None:
compute_str = "does not have a compute capability"
else:
version_str = capability.as_version_str()
compute_str = f"has compute capability {version_str}"
raise ValueError(
"Bfloat16 is only supported on GPUs with compute capability "
f"of at least 8.0. Your {gpu_name} GPU {compute_str}. "
"You can use float16 instead by explicitly setting the"
"`dtype` flag in CLI, for example: --dtype=half.")
def _get_cache_block_size(
cache_config: CacheConfig,
model_config: ModelConfig,
parallel_config: ParallelConfig,
) -> int:
head_size = model_config.get_head_size()
num_heads = model_config.get_num_kv_heads(parallel_config)
num_attention_layers = model_config.get_num_attention_layers(
parallel_config)
key_cache_block = cache_config.block_size * num_heads * head_size
value_cache_block = key_cache_block
total = num_attention_layers * (key_cache_block + value_cache_block)
if cache_config.cache_dtype == "auto":
dtype = model_config.dtype
else:
dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
dtype_size = get_dtype_size(dtype)
return dtype_size * total
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment