Unverified Commit 9090bf02 authored by zhaoyang-star, committed by GitHub

Support FP8-E5M2 KV Cache (#2279)


Co-authored-by: zhaoyang <zhao.yang16@zte.com.cn>
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
parent 7d648418
......@@ -12,6 +12,7 @@ class InputMetadata:
max_context_len: The maximum context length.
context_lens: The length of attention context for each sequence.
block_tables: The block tables. (Seq id -> list of physical block)
kv_cache_dtype: Data type used to store the KV cache.
"""
def __init__(
......@@ -25,6 +26,7 @@ class InputMetadata:
context_lens: Optional[torch.Tensor],
block_tables: Optional[torch.Tensor],
use_cuda_graph: bool,
kv_cache_dtype: str,
) -> None:
self.is_prompt = is_prompt
self.prompt_lens = prompt_lens
......@@ -35,6 +37,7 @@ class InputMetadata:
self.context_lens = context_lens
self.block_tables = block_tables
self.use_cuda_graph = use_cuda_graph
self.kv_cache_dtype = kv_cache_dtype
# Set during the execution of the first attention op.
# FIXME(woosuk): This is a hack.
......@@ -47,4 +50,5 @@ class InputMetadata:
f"slot_mapping={self.slot_mapping}, "
f"context_lens={self.context_lens}, "
f"block_tables={self.block_tables}, "
f"use_cuda_graph={self.use_cuda_graph})")
f"use_cuda_graph={self.use_cuda_graph}, "
f"kv_cache_dtype={self.kv_cache_dtype})")
......@@ -98,6 +98,7 @@ class PagedAttention(nn.Module):
key_cache,
value_cache,
input_metadata.slot_mapping.flatten(),
input_metadata.kv_cache_dtype,
)
if input_metadata.is_prompt:
......@@ -265,6 +266,7 @@ def _paged_attention(
block_size,
input_metadata.max_context_len,
alibi_slopes,
input_metadata.kv_cache_dtype,
)
else:
# Run PagedAttention V2.
......@@ -295,5 +297,6 @@ def _paged_attention(
block_size,
input_metadata.max_context_len,
alibi_slopes,
input_metadata.kv_cache_dtype,
)
return output
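# Note (sketch of the mechanism, not shown in this hunk): the kv_cache_dtype
# string carried by InputMetadata is forwarded to the CUDA cache-write and
# paged-attention kernel calls above; with "fp8_e5m2" they quantize keys and
# values to fp8 on write and dequantize them on read, while "auto" keeps the
# cache in the model dtype.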
import enum
import os
import socket
import subprocess
import uuid
from platform import uname
from typing import List
from typing import List, Tuple, Union
from packaging.version import parse, Version
import psutil
import torch
......@@ -17,7 +19,17 @@ from typing import (
from collections import OrderedDict
from typing import Any, Hashable, Optional
from vllm.logger import init_logger
T = TypeVar("T")
logger = init_logger(__name__)
STR_DTYPE_TO_TORCH_DTYPE = {
"half": torch.half,
"bfloat16": torch.bfloat16,
"float": torch.float,
"fp8_e5m2": torch.uint8,
}
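# Illustrative lookup (sketch): "fp8_e5m2" maps to torch.uint8 because there is
# no native fp8 torch dtype here, so each fp8 value is stored as one raw byte,
# half the size of a float16 element.
#
#     >>> STR_DTYPE_TO_TORCH_DTYPE["fp8_e5m2"]
#     torch.uint8
#     >>> torch.tensor([], dtype=torch.uint8).element_size()
#     1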
class Device(enum.Enum):
......@@ -167,3 +179,99 @@ def get_open_port() -> int:
def set_cuda_visible_devices(device_ids: List[int]) -> None:
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, device_ids))
def get_nvcc_cuda_version() -> Version:
cuda_home = os.environ.get('CUDA_HOME')
if not cuda_home:
cuda_home = '/usr/local/cuda'
logger.info(
f'CUDA_HOME is not found in the environment. Using {cuda_home} as CUDA_HOME.'
)
nvcc_output = subprocess.check_output([cuda_home + "/bin/nvcc", "-V"],
universal_newlines=True)
output = nvcc_output.split()
release_idx = output.index("release") + 1
nvcc_cuda_version = parse(output[release_idx].split(",")[0])
return nvcc_cuda_version
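# Illustrative guard (assumption: the actual fp8 version check lives elsewhere
# in this change, and the exact threshold may differ):
#
#     if get_nvcc_cuda_version() < Version("11.8"):
#         raise ValueError(
#             "fp8_e5m2 KV cache requires a newer CUDA toolkit.")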
def _generate_random_fp8_e5m2(
tensor: torch.Tensor,
low: float,
high: float,
) -> None:
# NOTE(zhaoyang): Because of the NaN and Inf bit patterns of the fp8 data
# type, directly using torch.randint to generate random fp8 data may
# produce Inf or NaN values.
# For example, s.11111.00 in the fp8_e5m2 format represents Inf.
#     | E4M3       | E5M2
# ----|------------|-------------------
# Inf | N/A        | s.11111.00
# NaN | s.1111.111 | s.11111.{01,10,11}
from vllm._C import cache_ops
tensor_tmp = torch.empty_like(tensor, dtype=torch.float16)
tensor_tmp.uniform_(low, high)
cache_ops.convert_fp8_e5m2(tensor_tmp, tensor)
del tensor_tmp
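# Usage sketch: fill a uint8 cache tensor with random fp8_e5m2 bytes by
# sampling in fp16 and converting, which avoids the Inf/NaN bit patterns
# documented above.
#
#     cache = torch.empty(1024, dtype=torch.uint8, device="cuda")
#     _generate_random_fp8_e5m2(cache, -1.0, 1.0)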
def create_kv_caches_with_random(
num_blocks: int,
block_size: int,
num_layers: int,
num_heads: int,
head_size: int,
cache_dtype: Optional[Union[str, torch.dtype]],
model_dtype: Optional[Union[str, torch.dtype]] = None,
seed: Optional[int] = 0,
device: Optional[str] = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
if isinstance(cache_dtype, str):
if cache_dtype == "auto":
if isinstance(model_dtype, str):
torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
elif isinstance(model_dtype, torch.dtype):
torch_dtype = model_dtype
else:
raise ValueError(f"Invalid model dtype: {model_dtype}")
elif cache_dtype in ["half", "bfloat16", "float"]:
torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
elif cache_dtype == "fp8_e5m2":
torch_dtype = torch.uint8
else:
raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
elif isinstance(cache_dtype, torch.dtype):
torch_dtype = cache_dtype
else:
raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
scale = head_size**-0.5
x = 16 // torch.tensor([], dtype=torch_dtype).element_size()
key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
key_caches = []
for _ in range(num_layers):
key_cache = torch.empty(size=key_cache_shape,
dtype=torch_dtype,
device=device)
if cache_dtype in ["auto", "half", "bfloat16", "float"]:
key_cache.uniform_(-scale, scale)
elif cache_dtype == 'fp8_e5m2':
_generate_random_fp8_e5m2(key_cache, -scale, scale)
key_caches.append(key_cache)
value_cache_shape = (num_blocks, num_heads, head_size, block_size)
value_caches = []
for _ in range(num_layers):
value_cache = torch.empty(size=value_cache_shape,
dtype=torch_dtype,
device=device)
if cache_dtype in ["auto", "half", "bfloat16", "float"]:
value_cache.uniform_(-scale, scale)
elif cache_dtype == 'fp8_e5m2':
_generate_random_fp8_e5m2(value_cache, -scale, scale)
value_caches.append(value_cache)
return key_caches, value_caches
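# Usage sketch (dimensions are made up for illustration):
#
#     key_caches, value_caches = create_kv_caches_with_random(
#         num_blocks=128,
#         block_size=16,
#         num_layers=32,
#         num_heads=32,
#         head_size=128,
#         cache_dtype="fp8_e5m2",  # stored as torch.uint8, one byte per element
#         model_dtype="half",
#     )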
......@@ -6,7 +6,7 @@ import torch
from vllm._C import cache_ops
from vllm.config import CacheConfig, ModelConfig, ParallelConfig
from vllm.logger import init_logger
from vllm.utils import in_wsl
from vllm.utils import in_wsl, STR_DTYPE_TO_TORCH_DTYPE
logger = init_logger(__name__)
......@@ -34,12 +34,16 @@ class CacheEngine:
self.head_size = model_config.get_head_size()
self.num_layers = model_config.get_num_layers(parallel_config)
self.num_heads = model_config.get_num_kv_heads(parallel_config)
self.dtype = model_config.dtype
self.block_size = cache_config.block_size
self.num_gpu_blocks = cache_config.num_gpu_blocks
self.num_cpu_blocks = cache_config.num_cpu_blocks
if cache_config.cache_dtype == "auto":
self.dtype = model_config.dtype
else:
self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
# Initialize the cache.
self.gpu_cache = self.allocate_gpu_cache()
self.cpu_cache = self.allocate_cpu_cache()
......@@ -142,6 +146,7 @@ class CacheEngine:
@staticmethod
def get_cache_block_size(
block_size: int,
cache_dtype: str,
model_config: ModelConfig,
parallel_config: ParallelConfig,
) -> int:
......@@ -152,7 +157,11 @@ class CacheEngine:
key_cache_block = block_size * num_heads * head_size
value_cache_block = key_cache_block
total = num_layers * (key_cache_block + value_cache_block)
dtype_size = _get_dtype_size(model_config.dtype)
if cache_dtype == "auto":
dtype = model_config.dtype
else:
dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
dtype_size = _get_dtype_size(dtype)
return dtype_size * total
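# Worked example (hypothetical 7B-class model: 32 layers, 32 KV heads,
# head_size 128, block_size 16):
#   key_cache_block   = 16 * 32 * 128          = 65,536 elements
#   value_cache_block = key_cache_block        = 65,536 elements
#   total             = 32 * (65,536 + 65,536) = 4,194,304 elements per block
# With cache_dtype="auto" and a float16 model, dtype_size = 2 bytes, i.e.
# 8 MiB per cache block; with cache_dtype="fp8_e5m2", dtype_size = 1 byte,
# i.e. 4 MiB, so roughly twice as many KV cache blocks fit in the same memory.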
......
......@@ -36,6 +36,7 @@ class ModelRunner:
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
lora_config: Optional[LoRAConfig],
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
):
self.model_config = model_config
......@@ -68,6 +69,7 @@ class ModelRunner:
self.graph_block_tables = None # Set after initial profiling.
# cache in_wsl result
self.in_wsl = in_wsl()
self.kv_cache_dtype = kv_cache_dtype
def load_model(self) -> None:
self.model = get_model(self.model_config, self.lora_config)
......@@ -223,6 +225,7 @@ class ModelRunner:
context_lens=context_lens_tensor,
block_tables=block_tables,
use_cuda_graph=False,
kv_cache_dtype=self.kv_cache_dtype,
)
return (input_tokens, input_positions, input_metadata, prompt_lens,
subquery_lens, lora_index_mapping, lora_prompt_mapping,
......@@ -350,6 +353,7 @@ class ModelRunner:
context_lens=context_lens,
block_tables=block_tables,
use_cuda_graph=use_captured_graph,
kv_cache_dtype=self.kv_cache_dtype,
)
return input_tokens, input_positions, input_metadata, lora_index_mapping, lora_prompt_mapping, lora_requests
......@@ -473,6 +477,7 @@ class ModelRunner:
"context_lens": input_metadata.context_lens,
"block_tables": input_metadata.block_tables,
"use_cuda_graph": input_metadata.use_cuda_graph,
"kv_cache_dtype": input_metadata.kv_cache_dtype,
"selected_token_indices":
sampling_metadata.selected_token_indices,
"lora_requests": lora_requests,
......@@ -495,6 +500,7 @@ class ModelRunner:
context_lens=metadata_dict["context_lens"],
block_tables=metadata_dict["block_tables"],
use_cuda_graph=metadata_dict["use_cuda_graph"],
kv_cache_dtype=metadata_dict["kv_cache_dtype"],
)
sampling_metadata = SamplingMetadata(
seq_groups=None,
......@@ -665,6 +671,7 @@ class ModelRunner:
context_lens=context_lens[:batch_size],
block_tables=block_tables[:batch_size],
use_cuda_graph=True,
kv_cache_dtype=self.kv_cache_dtype,
)
if self.lora_config:
......
......@@ -37,6 +37,7 @@ class Worker:
rank: int,
distributed_init_method: str,
lora_config: Optional[LoRAConfig] = None,
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
) -> None:
self.model_config = model_config
......@@ -54,6 +55,7 @@ class Worker:
parallel_config,
scheduler_config,
lora_config=self.lora_config,
kv_cache_dtype=kv_cache_dtype,
is_driver_worker=is_driver_worker)
# Uninitialized cache engine. Will be initialized by
# self.init_cache_engine().
......@@ -95,6 +97,7 @@ class Worker:
block_size: int,
gpu_memory_utilization: float,
cpu_swap_space: int,
cache_dtype: str,
) -> Tuple[int, int]:
"""Profiles the peak memory usage of the model and returns the maximum
number of GPU and CPU cache blocks that can be allocated.
......@@ -119,7 +122,7 @@ class Worker:
peak_memory = total_gpu_memory - free_gpu_memory
cache_block_size = CacheEngine.get_cache_block_size(
block_size, self.model_config, self.parallel_config)
block_size, cache_dtype, self.model_config, self.parallel_config)
num_gpu_blocks = int(
(total_gpu_memory * gpu_memory_utilization - peak_memory) //
cache_block_size)
......
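# Worked example (all numbers illustrative): with 80 GiB of total GPU memory,
# gpu_memory_utilization=0.9, a 20 GiB profiled peak, and the 4 MiB fp8_e5m2
# cache block from the example above:
#   num_gpu_blocks = (80 GiB * 0.9 - 20 GiB) // 4 MiB = 13,312
# The same setup with a float16 KV cache (8 MiB blocks) yields 6,656 blocks.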