Unverified commit 9090bf02, authored by zhaoyang-star and committed by GitHub

Support FP8-E5M2 KV Cache (#2279)


Co-authored-by: zhaoyang <zhao.yang16@zte.com.cn>
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
parent 7d648418
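
For context, a minimal usage sketch of what this commit enables. The engine-level plumbing (the kv_cache_dtype engine argument / --kv-cache-dtype CLI flag) is not part of the excerpt below, so treat the exact argument name and the model name as assumptions rather than something this diff shows.

# Hypothetical usage sketch (not part of this diff): enable the FP8 E5M2 KV cache
# from the Python API. "auto" (the assumed default) keeps the model dtype instead.
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m", kv_cache_dtype="fp8_e5m2")
params = SamplingParams(temperature=0.8, max_tokens=16)
outputs = llm.generate(["Hello, my name is"], params)
print(outputs[0].outputs[0].text)
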
@@ -12,6 +12,7 @@ class InputMetadata:
         max_context_len: The maximum context length.
         context_lens: the length of attention context for each sequence.
         block_tables: The block tables. (Seq id -> list of physical block)
+        kv_cache_dtype: Data type to store kv cache.
     """

     def __init__(
@@ -25,6 +26,7 @@ class InputMetadata:
         context_lens: Optional[torch.Tensor],
         block_tables: Optional[torch.Tensor],
         use_cuda_graph: bool,
+        kv_cache_dtype: str,
     ) -> None:
         self.is_prompt = is_prompt
         self.prompt_lens = prompt_lens
@@ -35,6 +37,7 @@ class InputMetadata:
         self.context_lens = context_lens
         self.block_tables = block_tables
         self.use_cuda_graph = use_cuda_graph
+        self.kv_cache_dtype = kv_cache_dtype

         # Set during the execution of the first attention op.
         # FIXME(woosuk): This is a hack.
@@ -47,4 +50,5 @@ class InputMetadata:
             f"slot_mapping={self.slot_mapping}, "
             f"context_lens={self.context_lens}, "
             f"block_tables={self.block_tables}, "
-            f"use_cuda_graph={self.use_cuda_graph})")
+            f"use_cuda_graph={self.use_cuda_graph}, "
+            f"kv_cache_dtype={self.kv_cache_dtype})")
@@ -98,6 +98,7 @@ class PagedAttention(nn.Module):
                 key_cache,
                 value_cache,
                 input_metadata.slot_mapping.flatten(),
+                input_metadata.kv_cache_dtype,
             )

         if input_metadata.is_prompt:
@@ -265,6 +266,7 @@ def _paged_attention(
             block_size,
             input_metadata.max_context_len,
             alibi_slopes,
+            input_metadata.kv_cache_dtype,
         )
     else:
         # Run PagedAttention V2.
@@ -295,5 +297,6 @@ def _paged_attention(
             block_size,
             input_metadata.max_context_len,
             alibi_slopes,
+            input_metadata.kv_cache_dtype,
         )
     return output
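
As a side note on what "fp8_e5m2" buys and costs: E5M2 keeps FP16's exponent range (5 exponent bits) but only 2 mantissa bits, so cached keys and values fit in one byte each and are dequantized inside the attention kernels. Below is a minimal sketch of the rounding behaviour using PyTorch's native torch.float8_e5m2 dtype (available in PyTorch >= 2.1) purely for illustration; the vLLM kernels in this diff do their own conversion and store the raw bytes as torch.uint8.

# Illustration only: round-trip a few values through E5M2 to see the
# quantization error. Uses torch.float8_e5m2 (PyTorch >= 2.1), not the
# vllm._C cache_ops kernels touched by this commit.
import torch

vals = torch.tensor([0.0123, 0.1, 1.0, 3.1416, 300.0], dtype=torch.float32)
e5m2 = vals.to(torch.float8_e5m2)        # 1 byte per element
back = e5m2.to(torch.float32)            # dequantize
print(back)                              # values snapped to 2 mantissa bits
print((back - vals).abs() / vals.abs())  # relative error, roughly <= 12.5%
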
 import enum
 import os
 import socket
+import subprocess
 import uuid
 from platform import uname
-from typing import List
+from typing import List, Tuple, Union
+from packaging.version import parse, Version
 import psutil
 import torch
@@ -17,7 +19,17 @@ from typing import (
 from collections import OrderedDict
 from typing import Any, Hashable, Optional

+from vllm.logger import init_logger
+
 T = TypeVar("T")
+logger = init_logger(__name__)
+
+STR_DTYPE_TO_TORCH_DTYPE = {
+    "half": torch.half,
+    "bfloat16": torch.bfloat16,
+    "float": torch.float,
+    "fp8_e5m2": torch.uint8,
+}


 class Device(enum.Enum):
@@ -167,3 +179,99 @@ def get_open_port() -> int:
 def set_cuda_visible_devices(device_ids: List[int]) -> None:
     os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, device_ids))
+
+
+def get_nvcc_cuda_version() -> Version:
+    cuda_home = os.environ.get('CUDA_HOME')
+    if not cuda_home:
+        cuda_home = '/usr/local/cuda'
+        logger.info(
+            f'CUDA_HOME is not found in the environment. Using {cuda_home} as CUDA_HOME.'
+        )
+    nvcc_output = subprocess.check_output([cuda_home + "/bin/nvcc", "-V"],
+                                          universal_newlines=True)
+    output = nvcc_output.split()
+    release_idx = output.index("release") + 1
+    nvcc_cuda_version = parse(output[release_idx].split(",")[0])
+    return nvcc_cuda_version
+
+
+def _generate_random_fp8_e5m2(
+    tensor: torch.Tensor,
+    low: float,
+    high: float,
+) -> None:
+    # NOTE(zhaoyang): Because of how fp8 encodes NaN and Inf, directly using
+    # torch.randint to generate random fp8 data can produce Inf or NaN bit
+    # patterns. For example, s.11111.00 in fp8_e5m2 format represents Inf.
+    #     | E4M3       | E5M2
+    # ----|------------|-------------------
+    # Inf | N/A        | s.11111.00
+    # NaN | s.1111.111 | s.11111.{01,10,11}
+    from vllm._C import cache_ops
+    tensor_tmp = torch.empty_like(tensor, dtype=torch.float16)
+    tensor_tmp.uniform_(low, high)
+    cache_ops.convert_fp8_e5m2(tensor_tmp, tensor)
+    del tensor_tmp
+
+
+def create_kv_caches_with_random(
+    num_blocks: int,
+    block_size: int,
+    num_layers: int,
+    num_heads: int,
+    head_size: int,
+    cache_dtype: Optional[Union[str, torch.dtype]],
+    model_dtype: Optional[Union[str, torch.dtype]] = None,
+    seed: Optional[int] = 0,
+    device: Optional[str] = "cuda",
+) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+
+    if isinstance(cache_dtype, str):
+        if cache_dtype == "auto":
+            if isinstance(model_dtype, str):
+                torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
+            elif isinstance(model_dtype, torch.dtype):
+                torch_dtype = model_dtype
+            else:
+                raise ValueError(f"Invalid model dtype: {model_dtype}")
+        elif cache_dtype in ["half", "bfloat16", "float"]:
+            torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
+        elif cache_dtype == "fp8_e5m2":
+            torch_dtype = torch.uint8
+        else:
+            raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
+    elif isinstance(cache_dtype, torch.dtype):
+        torch_dtype = cache_dtype
+    else:
+        raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
+
+    scale = head_size**-0.5
+    x = 16 // torch.tensor([], dtype=torch_dtype).element_size()
+    key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
+    key_caches = []
+    for _ in range(num_layers):
+        key_cache = torch.empty(size=key_cache_shape,
+                                dtype=torch_dtype,
+                                device=device)
+        if cache_dtype in ["auto", "half", "bfloat16", "float"]:
+            key_cache.uniform_(-scale, scale)
+        elif cache_dtype == 'fp8_e5m2':
+            _generate_random_fp8_e5m2(key_cache, -scale, scale)
+        key_caches.append(key_cache)
+
+    value_cache_shape = (num_blocks, num_heads, head_size, block_size)
+    value_caches = []
+    for _ in range(num_layers):
+        value_cache = torch.empty(size=value_cache_shape,
+                                  dtype=torch_dtype,
+                                  device=device)
+        if cache_dtype in ["auto", "half", "bfloat16", "float"]:
+            value_cache.uniform_(-scale, scale)
+        elif cache_dtype == 'fp8_e5m2':
+            _generate_random_fp8_e5m2(value_cache, -scale, scale)
+        value_caches.append(value_cache)
+    return key_caches, value_caches
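
A short usage sketch of the helper added above, as it might appear in a kernel test or benchmark. The shapes are hypothetical, and the fp8 path needs a CUDA build of vLLM because _generate_random_fp8_e5m2 calls into the vllm._C cache_ops extension.

# Hypothetical example: allocate fp8_e5m2 KV caches for a small config.
from vllm.utils import create_kv_caches_with_random

key_caches, value_caches = create_kv_caches_with_random(
    num_blocks=128,             # hypothetical number of cache blocks
    block_size=16,
    num_layers=2,
    num_heads=8,
    head_size=128,
    cache_dtype="fp8_e5m2",     # stored as torch.uint8, 1 byte per element
    model_dtype="half",         # only consulted when cache_dtype == "auto"
    seed=0,
    device="cuda",
)
print(key_caches[0].dtype, key_caches[0].shape)  # torch.uint8, (128, 8, 8, 16, 16)
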
@@ -6,7 +6,7 @@ import torch
 from vllm._C import cache_ops
 from vllm.config import CacheConfig, ModelConfig, ParallelConfig
 from vllm.logger import init_logger
-from vllm.utils import in_wsl
+from vllm.utils import in_wsl, STR_DTYPE_TO_TORCH_DTYPE

 logger = init_logger(__name__)
@@ -34,12 +34,16 @@ class CacheEngine:
         self.head_size = model_config.get_head_size()
         self.num_layers = model_config.get_num_layers(parallel_config)
         self.num_heads = model_config.get_num_kv_heads(parallel_config)
-        self.dtype = model_config.dtype

         self.block_size = cache_config.block_size
         self.num_gpu_blocks = cache_config.num_gpu_blocks
         self.num_cpu_blocks = cache_config.num_cpu_blocks

+        if cache_config.cache_dtype == "auto":
+            self.dtype = model_config.dtype
+        else:
+            self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
+
         # Initialize the cache.
         self.gpu_cache = self.allocate_gpu_cache()
         self.cpu_cache = self.allocate_cpu_cache()
@@ -142,6 +146,7 @@ class CacheEngine:
     @staticmethod
     def get_cache_block_size(
         block_size: int,
+        cache_dtype: str,
         model_config: ModelConfig,
         parallel_config: ParallelConfig,
     ) -> int:
@@ -152,7 +157,11 @@ class CacheEngine:
         key_cache_block = block_size * num_heads * head_size
         value_cache_block = key_cache_block
         total = num_layers * (key_cache_block + value_cache_block)
-        dtype_size = _get_dtype_size(model_config.dtype)
+        if cache_dtype == "auto":
+            dtype = model_config.dtype
+        else:
+            dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
+        dtype_size = _get_dtype_size(dtype)
         return dtype_size * total
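
To make the effect of cache_dtype concrete, here is the same arithmetic as get_cache_block_size spelled out for a hypothetical configuration (the numbers below are illustrative, not taken from this diff). With fp8_e5m2 the per-element size drops from 2 bytes (fp16) to 1 byte, so every cache block, and therefore the whole KV cache, takes roughly half the memory.

# Worked example of the block-size arithmetic above (hypothetical numbers).
block_size = 16    # tokens per cache block
num_heads = 32     # KV heads per GPU
head_size = 128
num_layers = 32

key_cache_block = block_size * num_heads * head_size          # 65_536 elements
value_cache_block = key_cache_block
total = num_layers * (key_cache_block + value_cache_block)    # 4_194_304 elements

print(total * 2)  # "auto"/fp16 cache: 8_388_608 bytes per block (~8 MiB)
print(total * 1)  # fp8_e5m2 cache:    4_194_304 bytes per block (~4 MiB)
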
@@ -36,6 +36,7 @@ class ModelRunner:
         parallel_config: ParallelConfig,
         scheduler_config: SchedulerConfig,
         lora_config: Optional[LoRAConfig],
+        kv_cache_dtype: Optional[str] = "auto",
         is_driver_worker: bool = False,
     ):
         self.model_config = model_config
@@ -68,6 +69,7 @@ class ModelRunner:
         self.graph_block_tables = None  # Set after initial profiling.
         # cache in_wsl result
         self.in_wsl = in_wsl()
+        self.kv_cache_dtype = kv_cache_dtype

     def load_model(self) -> None:
         self.model = get_model(self.model_config, self.lora_config)
@@ -223,6 +225,7 @@ class ModelRunner:
             context_lens=context_lens_tensor,
             block_tables=block_tables,
             use_cuda_graph=False,
+            kv_cache_dtype=self.kv_cache_dtype,
         )
         return (input_tokens, input_positions, input_metadata, prompt_lens,
                 subquery_lens, lora_index_mapping, lora_prompt_mapping,
@@ -350,6 +353,7 @@ class ModelRunner:
             context_lens=context_lens,
             block_tables=block_tables,
             use_cuda_graph=use_captured_graph,
+            kv_cache_dtype=self.kv_cache_dtype,
         )
         return input_tokens, input_positions, input_metadata, lora_index_mapping, lora_prompt_mapping, lora_requests
@@ -473,6 +477,7 @@ class ModelRunner:
                 "context_lens": input_metadata.context_lens,
                 "block_tables": input_metadata.block_tables,
                 "use_cuda_graph": input_metadata.use_cuda_graph,
+                "kv_cache_dtype": input_metadata.kv_cache_dtype,
                 "selected_token_indices":
                 sampling_metadata.selected_token_indices,
                 "lora_requests": lora_requests,
@@ -495,6 +500,7 @@ class ModelRunner:
                 context_lens=metadata_dict["context_lens"],
                 block_tables=metadata_dict["block_tables"],
                 use_cuda_graph=metadata_dict["use_cuda_graph"],
+                kv_cache_dtype=metadata_dict["kv_cache_dtype"],
             )
             sampling_metadata = SamplingMetadata(
                 seq_groups=None,
@@ -665,6 +671,7 @@ class ModelRunner:
                 context_lens=context_lens[:batch_size],
                 block_tables=block_tables[:batch_size],
                 use_cuda_graph=True,
+                kv_cache_dtype=self.kv_cache_dtype,
             )

             if self.lora_config:
@@ -37,6 +37,7 @@ class Worker:
         rank: int,
         distributed_init_method: str,
         lora_config: Optional[LoRAConfig] = None,
+        kv_cache_dtype: Optional[str] = "auto",
         is_driver_worker: bool = False,
     ) -> None:
         self.model_config = model_config
@@ -54,6 +55,7 @@ class Worker:
             parallel_config,
             scheduler_config,
             lora_config=self.lora_config,
+            kv_cache_dtype=kv_cache_dtype,
             is_driver_worker=is_driver_worker)
         # Uninitialized cache engine. Will be initialized by
         # self.init_cache_engine().
@@ -95,6 +97,7 @@ class Worker:
         block_size: int,
         gpu_memory_utilization: float,
         cpu_swap_space: int,
+        cache_dtype: str,
     ) -> Tuple[int, int]:
         """Profiles the peak memory usage of the model and returns the maximum
         number of GPU and CPU cache blocks that can be allocated.
@@ -119,7 +122,7 @@ class Worker:
         peak_memory = total_gpu_memory - free_gpu_memory

         cache_block_size = CacheEngine.get_cache_block_size(
-            block_size, self.model_config, self.parallel_config)
+            block_size, cache_dtype, self.model_config, self.parallel_config)
         num_gpu_blocks = int(
             (total_gpu_memory * gpu_memory_utilization - peak_memory) //
             cache_block_size)