"driver/vscode:/vscode.git/clone" did not exist on "08c001404f16b6701ac9ff1436f1b8d75f4d1560"
Unverified commit 9090bf02, authored by zhaoyang-star, committed by GitHub

Support FP8-E5M2 KV Cache (#2279)


Co-authored-by: zhaoyang <zhao.yang16@zte.com.cn>
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
parent 7d648418
@@ -24,6 +24,7 @@ def main(args: argparse.Namespace):
         trust_remote_code=args.trust_remote_code,
         dtype=args.dtype,
         enforce_eager=args.enforce_eager,
+        kv_cache_dtype=args.kv_cache_dtype,
     )
     sampling_params = SamplingParams(
@@ -117,6 +118,13 @@ if __name__ == '__main__':
     parser.add_argument('--enforce-eager',
                         action='store_true',
                         help='enforce eager mode and disable CUDA graph')
+    parser.add_argument(
+        "--kv-cache-dtype",
+        type=str,
+        choices=['auto', 'fp8_e5m2'],
+        default='auto',
+        help=
+        'Data type for kv cache storage. If "auto", will use model data type.')
     parser.add_argument(
         '--profile',
         action='store_true',
...
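The flag added above is passed straight through to the LLM constructor in the first hunk. A minimal sketch of the resulting call (the model name here is only an example) would be:

    from vllm import LLM

    # Example only: any model vLLM supports can be used the same way.
    llm = LLM(
        model="facebook/opt-125m",
        kv_cache_dtype="fp8_e5m2",  # "auto" (the default) keeps the model's own dtype
    )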
@@ -71,6 +71,7 @@ def run_vllm(
     dtype: str,
     max_model_len: Optional[int],
     enforce_eager: bool,
+    kv_cache_dtype: str,
 ) -> float:
     from vllm import LLM, SamplingParams
     llm = LLM(
@@ -83,6 +84,7 @@ def run_vllm(
         dtype=dtype,
         max_model_len=max_model_len,
         enforce_eager=enforce_eager,
+        kv_cache_dtype=kv_cache_dtype,
     )
     # Add the requests to the engine.
@@ -206,7 +208,8 @@ def main(args: argparse.Namespace):
                                 args.quantization, args.tensor_parallel_size,
                                 args.seed, args.n, args.use_beam_search,
                                 args.trust_remote_code, args.dtype,
-                                args.max_model_len, args.enforce_eager)
+                                args.max_model_len, args.enforce_eager,
+                                args.kv_cache_dtype)
     elif args.backend == "hf":
         assert args.tensor_parallel_size == 1
         elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@@ -284,6 +287,13 @@ if __name__ == "__main__":
     parser.add_argument("--enforce-eager",
                         action="store_true",
                         help="enforce eager execution")
+    parser.add_argument(
+        "--kv-cache-dtype",
+        type=str,
+        choices=["auto", "fp8_e5m2"],
+        default="auto",
+        help=
+        'Data type for kv cache storage. If "auto", will use model data type.')
     args = parser.parse_args()
     if args.tokenizer is None:
         args.tokenizer = args.model
...
+from typing import Optional
 import argparse
 import random
 import time

 import torch

+from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random
 from vllm._C import ops

 NUM_BLOCKS = 1024
@@ -23,6 +25,7 @@ def main(
     dtype: torch.dtype,
     seed: int,
     do_profile: bool,
+    kv_cache_dtype: Optional[str] = None,
 ) -> None:
     random.seed(seed)
     torch.random.manual_seed(seed)
@@ -59,15 +62,10 @@ def main(
     block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda")

     # Create the KV cache.
-    x = 16 // torch.tensor([], dtype=dtype).element_size()
-    key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, block_size, x)
-    key_cache = torch.empty(size=key_cache_shape, dtype=dtype, device="cuda")
-    key_cache.uniform_(-scale, scale)
-    value_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size, block_size)
-    value_cache = torch.empty(size=value_cache_shape,
-                              dtype=dtype,
-                              device="cuda")
-    value_cache.uniform_(-scale, scale)
+    key_caches, value_caches = create_kv_caches_with_random(
+        NUM_BLOCKS, block_size, 1, num_kv_heads, head_size, kv_cache_dtype,
+        dtype)
+    key_cache, value_cache = key_caches[0], value_caches[0]

     # Prepare for the paged attention kernel.
     output = torch.empty_like(query)
@@ -106,6 +104,7 @@ def main(
                 block_size,
                 max_context_len,
                 alibi_slopes,
+                kv_cache_dtype,
             )
         elif version == "v2":
             ops.paged_attention_v2(
@@ -123,6 +122,7 @@ def main(
                 block_size,
                 max_context_len,
                 alibi_slopes,
+                kv_cache_dtype,
             )
         else:
             raise ValueError(f"Invalid version: {version}")
@@ -168,16 +168,18 @@ if __name__ == '__main__':
                         default="half")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--profile", action="store_true")
+    parser.add_argument(
+        "--kv-cache-dtype",
+        type=str,
+        choices=["auto", "fp8_e5m2"],
+        default="auto",
+        help=
+        'Data type for kv cache storage. If "auto", will use model data type.')
    args = parser.parse_args()
    print(args)

    if args.num_query_heads % args.num_kv_heads != 0:
        raise ValueError("num_query_heads must be divisible by num_kv_heads")
-    dtype_to_torch_dtype = {
-        "half": torch.half,
-        "bfloat16": torch.bfloat16,
-        "float": torch.float,
-    }
    main(
        version=args.version,
        num_seqs=args.batch_size,
@@ -187,7 +189,8 @@ if __name__ == '__main__':
        head_size=args.head_size,
        block_size=args.block_size,
        use_alibi=args.use_alibi,
-        dtype=dtype_to_torch_dtype[args.dtype],
+        dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
        seed=args.seed,
        do_profile=args.profile,
+        kv_cache_dtype=args.kv_cache_dtype,
    )
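The benchmark now delegates KV-cache allocation to the new helper in vllm.utils. A standalone sketch of that call, with illustrative sizes in place of the values the script derives from its CLI arguments:

    import torch
    from vllm.utils import create_kv_caches_with_random

    # Positional arguments mirror the call in the hunk above; sizes are illustrative.
    key_caches, value_caches = create_kv_caches_with_random(
        1024,        # num_blocks
        16,          # block_size
        1,           # num_layers
        8,           # num_kv_heads
        128,         # head_size
        "fp8_e5m2",  # kv_cache_dtype ("auto" stores the cache in the model dtype)
        torch.half,  # model dtype used for the random cache contents
    )
    key_cache, value_cache = key_caches[0], value_caches[0]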
@@ -4,3 +4,4 @@
 #include "dtype_float16.cuh"
 #include "dtype_float32.cuh"
 #include "dtype_bfloat16.cuh"
+#include "dtype_fp8_e5m2.cuh"
This diff is collapsed.
#pragma once
#include "attention_generic.cuh"
#include <stdint.h>
#ifdef ENABLE_FP8_E5M2
#include <cuda_fp8.h>
#endif
namespace vllm {
#ifdef ENABLE_FP8_E5M2
// fp8 vector types for quantization of kv cache
template<>
struct Vec<uint8_t, 1> {
using Type = uint8_t;
};
template<>
struct Vec<uint8_t, 2> {
using Type = uint16_t;
};
template<>
struct Vec<uint8_t, 4> {
using Type = uint32_t;
};
template<>
struct Vec<uint8_t, 8> {
using Type = uint2;
};
#endif // ENABLE_FP8_E5M2
} // namespace vllm
@@ -20,7 +20,8 @@ void reshape_and_cache(
   torch::Tensor& value,
   torch::Tensor& key_cache,
   torch::Tensor& value_cache,
-  torch::Tensor& slot_mapping);
+  torch::Tensor& slot_mapping,
+  const std::string& kv_cache_dtype);

 void gather_cached_kv(
   torch::Tensor& key,
@@ -28,3 +29,8 @@ void gather_cached_kv(
   torch::Tensor& key_cache,
   torch::Tensor& value_cache,
   torch::Tensor& slot_mapping);
+
+// Just for unittest
+void convert_fp8_e5m2(
+  torch::Tensor& src_cache,
+  torch::Tensor& dst_cache);
@@ -4,6 +4,7 @@

 #include "cuda_compat.h"
 #include "dispatch_utils.h"
+#include "quantization/fp8_e5m2_kvcache/quant_utils.cuh"

 #include <algorithm>
 #include <cassert>
@@ -131,7 +132,7 @@ void copy_blocks(
   dim3 block(std::min(1024, numel_per_block));
   const at::cuda::OptionalCUDAGuard device_guard(cache_device);
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  VLLM_DISPATCH_FLOATING_TYPES(
+  VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
     key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
       vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
         key_cache_ptrs_tensor.data_ptr<int64_t>(),
@@ -143,12 +144,12 @@ void copy_blocks(

 namespace vllm {

-template<typename scalar_t>
+template<typename scalar_t, typename cache_t, bool is_fp8_e5m2_kv_cache>
 __global__ void reshape_and_cache_kernel(
   const scalar_t* __restrict__ key,         // [num_tokens, num_heads, head_size]
   const scalar_t* __restrict__ value,       // [num_tokens, num_heads, head_size]
-  scalar_t* __restrict__ key_cache,         // [num_blocks, num_heads, head_size/x, block_size, x]
-  scalar_t* __restrict__ value_cache,       // [num_blocks, num_heads, head_size, block_size]
+  cache_t* __restrict__ key_cache,          // [num_blocks, num_heads, head_size/x, block_size, x]
+  cache_t* __restrict__ value_cache,        // [num_blocks, num_heads, head_size, block_size]
   const int64_t* __restrict__ slot_mapping, // [num_tokens]
   const int key_stride,
   const int value_stride,
@@ -185,19 +186,45 @@ __global__ void reshape_and_cache_kernel(
                                   + head_idx * head_size * block_size
                                   + head_offset * block_size
                                   + block_offset;
-    key_cache[tgt_key_idx] = key[src_key_idx];
-    value_cache[tgt_value_idx] = value[src_value_idx];
+    scalar_t tgt_key = key[src_key_idx];
+    scalar_t tgt_value = value[src_value_idx];
+    if constexpr (is_fp8_e5m2_kv_cache) {
+#ifdef ENABLE_FP8_E5M2
+      key_cache[tgt_key_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_key);
+      value_cache[tgt_value_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_value);
+#else
+      assert(false);
+#endif
+    } else {
+      key_cache[tgt_key_idx] = tgt_key;
+      value_cache[tgt_value_idx] = tgt_value;
+    }
   }
 }

 } // namespace vllm

+#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_E5M2_KV_CACHE)                                \
+  vllm::reshape_and_cache_kernel<KV_T, CACHE_T, IS_FP8_E5M2_KV_CACHE><<<grid, block, 0, stream>>>( \
+    reinterpret_cast<KV_T*>(key.data_ptr()),                                                       \
+    reinterpret_cast<KV_T*>(value.data_ptr()),                                                     \
+    reinterpret_cast<CACHE_T*>(key_cache.data_ptr()),                                              \
+    reinterpret_cast<CACHE_T*>(value_cache.data_ptr()),                                            \
+    slot_mapping.data_ptr<int64_t>(),                                                              \
+    key_stride,                                                                                    \
+    value_stride,                                                                                  \
+    num_heads,                                                                                     \
+    head_size,                                                                                     \
+    block_size,                                                                                    \
+    x);
+
 void reshape_and_cache(
   torch::Tensor& key,           // [num_tokens, num_heads, head_size]
   torch::Tensor& value,         // [num_tokens, num_heads, head_size]
   torch::Tensor& key_cache,     // [num_blocks, num_heads, head_size/x, block_size, x]
   torch::Tensor& value_cache,   // [num_blocks, num_heads, head_size, block_size]
-  torch::Tensor& slot_mapping)  // [num_tokens]
+  torch::Tensor& slot_mapping,  // [num_tokens]
+  const std::string& kv_cache_dtype)
 {
   int num_tokens = key.size(0);
   int num_heads = key.size(1);
@@ -212,23 +239,25 @@ void reshape_and_cache(
   dim3 block(std::min(num_heads * head_size, 512));
   const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  VLLM_DISPATCH_FLOATING_TYPES(
-    key.scalar_type(),
-    "reshape_and_cache_kernel",
-    [&] {
-      vllm::reshape_and_cache_kernel<scalar_t><<<grid, block, 0, stream>>>(
-        key.data_ptr<scalar_t>(),
-        value.data_ptr<scalar_t>(),
-        key_cache.data_ptr<scalar_t>(),
-        value_cache.data_ptr<scalar_t>(),
-        slot_mapping.data_ptr<int64_t>(),
-        key_stride,
-        value_stride,
-        num_heads,
-        head_size,
-        block_size,
-        x);
-    });
+  if (kv_cache_dtype == "auto") {
+    if (key.dtype() == at::ScalarType::Float) {
+      CALL_RESHAPE_AND_CACHE(float, float, false);
+    } else if (key.dtype() == at::ScalarType::Half) {
+      CALL_RESHAPE_AND_CACHE(uint16_t, uint16_t, false);
+    } else if (key.dtype() == at::ScalarType::BFloat16) {
+      CALL_RESHAPE_AND_CACHE(__nv_bfloat16, __nv_bfloat16, false);
+    }
+  } else if (kv_cache_dtype == "fp8_e5m2") {
+    if (key.dtype() == at::ScalarType::Float) {
+      CALL_RESHAPE_AND_CACHE(float, uint8_t, true);
+    } else if (key.dtype() == at::ScalarType::Half) {
+      CALL_RESHAPE_AND_CACHE(uint16_t, uint8_t, true);
+    } else if (key.dtype() == at::ScalarType::BFloat16) {
+      CALL_RESHAPE_AND_CACHE(__nv_bfloat16, uint8_t, true);
+    }
+  } else {
+    TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
+  }
 }

 namespace vllm {
@@ -256,12 +285,12 @@ __global__ void gather_cached_kv_kernel(
   for (int i = threadIdx.x; i < num_tokens; i += blockDim.x) {
     const int tgt_key_idx = token_idx * key_stride + i;
     const int tgt_value_idx = token_idx * value_stride + i;

     const int head_idx = i / head_size;
     const int head_offset = i % head_size;
     const int x_idx = head_offset / x;  // the offset of the [head_size/x] dimension
     const int x_offset = head_offset % x;

     const int src_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
                             + head_idx * (head_size / x) * block_size * x
                             + x_idx * block_size * x
@@ -373,7 +402,7 @@ void gather_cached_kv(
   dim3 block(std::min(num_heads * head_size, 512));
   const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  VLLM_DISPATCH_FLOATING_TYPES(
+  VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
     key.scalar_type(),
     "gather_cached_kv_kernel_optimized",
     [&] {
@@ -391,3 +420,55 @@ void gather_cached_kv(
         x);
     });
 }
+
+namespace vllm {
+
+template<typename Tout, typename Tin>
+__global__ void convert_fp8_e5m2_kernel(
+  const Tin* __restrict__ src_cache,
+  Tout* __restrict__ dst_cache,
+  const int64_t block_stride) {
+  const int64_t block_idx = blockIdx.x;
+  for (int i = threadIdx.x; i < block_stride; i += blockDim.x) {
+    int64_t idx = block_idx * block_stride + i;
+#ifdef ENABLE_FP8_E5M2
+    dst_cache[idx] = fp8_e5m2_unscaled::vec_conversion<Tout, Tin>(src_cache[idx]);
+#else
+    assert(false);
+#endif
+  }
+}
+
+} // namespace vllm
+
+#define CALL_CONVERT_FP8_E5M2(Tout, Tin)                                \
+  vllm::convert_fp8_e5m2_kernel<Tout, Tin><<<grid, block, 0, stream>>>( \
+    reinterpret_cast<Tin*>(src_cache.data_ptr()),                       \
+    reinterpret_cast<Tout*>(dst_cache.data_ptr()),                      \
+    block_stride);
+
+void convert_fp8_e5m2(
+  torch::Tensor& src_cache,
+  torch::Tensor& dst_cache)
+{
+  int64_t num_blocks = src_cache.size(0);
+  int64_t block_stride = src_cache.stride(0);
+
+  dim3 grid(num_blocks);
+  dim3 block(std::min(block_stride, int64_t(512)));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  if (src_cache.dtype() == at::ScalarType::Float) {
+    CALL_CONVERT_FP8_E5M2(uint8_t, float);
+  } else if (src_cache.dtype() == at::ScalarType::Half) {
+    CALL_CONVERT_FP8_E5M2(uint8_t, uint16_t);
+  } else if (src_cache.dtype() == at::ScalarType::BFloat16) {
+    CALL_CONVERT_FP8_E5M2(uint8_t, __nv_bfloat16);
+  } else if (dst_cache.dtype() == at::ScalarType::Float) {
+    CALL_CONVERT_FP8_E5M2(float, uint8_t);
+  } else if (dst_cache.dtype() == at::ScalarType::Half) {
+    CALL_CONVERT_FP8_E5M2(uint16_t, uint8_t);
+  } else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
+    CALL_CONVERT_FP8_E5M2(__nv_bfloat16, uint8_t);
+  }
+}
@@ -14,3 +14,13 @@
 #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...)             \
   AT_DISPATCH_SWITCH(                                             \
     TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
+
+#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...)           \
+  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)            \
+  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)             \
+  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)         \
+  AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)
+
+#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...)    \
+  AT_DISPATCH_SWITCH(                                             \
+    TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__))
@@ -13,7 +13,8 @@ void paged_attention_v1(
   torch::Tensor& context_lens,
   int block_size,
   int max_context_len,
-  const c10::optional<torch::Tensor>& alibi_slopes);
+  const c10::optional<torch::Tensor>& alibi_slopes,
+  const std::string& kv_cache_dtype);

 void paged_attention_v2(
   torch::Tensor& out,
@@ -29,7 +30,8 @@ void paged_attention_v2(
   torch::Tensor& context_lens,
   int block_size,
   int max_context_len,
-  const c10::optional<torch::Tensor>& alibi_slopes);
+  const c10::optional<torch::Tensor>& alibi_slopes,
+  const std::string& kv_cache_dtype);

 void rms_norm(
   torch::Tensor& out,
...
@@ -75,6 +75,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     "gather_cached_kv",
     &gather_cached_kv,
     "Gather key and value from the cache into contiguous QKV tensors");
+  cache_ops.def(
+    "convert_fp8_e5m2",
+    &convert_fp8_e5m2,
+    "Convert the key and value cache to fp8_e5m2 data type");

   // Cuda utils
   pybind11::module cuda_utils = m.def_submodule("cuda_utils", "vLLM cuda utils");
...
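From Python, the new binding is exercised only by the kernel tests further down. A sketch of that round trip, with an illustrative paged key-cache layout of [num_blocks, num_heads, head_size/x, block_size, x]:

    import torch
    from vllm._C import cache_ops

    fp16_cache = torch.randn(128, 8, 16, 16, 8, dtype=torch.half, device="cuda")
    fp8_cache = torch.empty_like(fp16_cache, dtype=torch.uint8)

    cache_ops.convert_fp8_e5m2(fp16_cache, fp8_cache)  # quantize fp16 -> fp8 (E5M2 bytes)
    roundtrip = torch.empty_like(fp16_cache)
    cache_ops.convert_fp8_e5m2(fp8_cache, roundtrip)   # dequantize fp8 -> fp16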
#pragma once
#include <assert.h>
#include <stdint.h>
#include <float.h>
#include <type_traits>
#include "../../attention/attention_dtypes.h"
#include "../../attention/dtype_float32.cuh"
#include "../../attention/dtype_float16.cuh"
#include "../../attention/dtype_bfloat16.cuh"
namespace vllm {
#ifdef ENABLE_FP8_E5M2
namespace fp8_e5m2_unscaled {
template<typename Tout, typename Tin>
__inline__ __device__ Tout vec_conversion(const Tin& x)
{
return x;
}
// fp8 -> half
template<>
__inline__ __device__ uint16_t vec_conversion<uint16_t, uint8_t>(const uint8_t& a)
{
__half_raw res = __nv_cvt_fp8_to_halfraw(a, __NV_E5M2);
return res.x;
}
// fp8x2 -> half2
template<>
__inline__ __device__ uint32_t vec_conversion<uint32_t, uint16_t>(const uint16_t& a)
{
union {
uint16_t u16[2];
uint32_t u32;
} tmp;
__half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, __NV_E5M2);
tmp.u16[0] = res.x;
tmp.u16[1] = res.y;
return tmp.u32;
}
// fp8x4 -> half2x2
template<>
__inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(const uint32_t& a)
{
union {
uint2 u32x2;
uint32_t u32[2];
} tmp;
tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a);
tmp.u32[1] = vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U));
return tmp.u32x2;
}
// fp8x8 -> half2x4
template<>
__inline__ __device__ uint4 vec_conversion<uint4, uint2>(const uint2& a)
{
union {
uint4 u64x2;
uint2 u64[2];
} tmp;
tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x);
tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y);
return tmp.u64x2;
}
// fp8 -> __nv_bfloat16
template<>
__inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a)
{
// Note there is no direct convert function from fp8 to bf16.
// fp8 -> half
__half_raw res = __nv_cvt_fp8_to_halfraw(a, __NV_E5M2);
// half -> float -> bf16
float tmp = half_to_float(res.x);
return __float2bfloat16(tmp);
}
// fp8x2 -> __nv_bfloat162
template<>
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a)
{
__nv_bfloat162 res;
res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a);
res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U));
return res;
}
// fp8x4 -> bf16_4_t
template<>
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a)
{
bf16_4_t res;
res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a);
res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U));
return res;
}
// fp8x8 -> bf16_8_t
template<>
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a)
{
bf16_4_t tmp1, tmp2;
tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x);
tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y);
bf16_8_t res;
res.x = tmp1.x;
res.y = tmp1.y;
res.z = tmp2.x;
res.w = tmp2.y;
return res;
}
// fp8 -> float
template<>
__inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a)
{
// fp8 -> half
uint16_t tmp = vec_conversion<uint16_t, uint8_t>(a);
// half -> float
return half_to_float(tmp);
}
// fp8x2 -> float2
template<>
__inline__ __device__ float2 vec_conversion<float2, uint16_t>(const uint16_t& a)
{
// fp8x2 -> half2
uint32_t tmp = vec_conversion<uint32_t, uint16_t>(a);
// half2 -> float2
return half2_to_float2(tmp);
}
// fp8x4 -> float4
template<>
__inline__ __device__ Float4_ vec_conversion<Float4_, uint32_t>(const uint32_t& a)
{
Float4_ res;
res.x = vec_conversion<float2, uint16_t>((uint16_t)a);
res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U));
return res;
}
// fp8x8 -> float8
template<>
__inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a)
{
Float4_ tmp1, tmp2;
tmp1 = vec_conversion<Float4_, uint32_t>(a.x);
tmp2 = vec_conversion<Float4_, uint32_t>(a.y);
Float8_ res;
res.x = tmp1.x;
res.y = tmp1.y;
res.z = tmp2.x;
res.w = tmp2.y;
return res;
}
// half -> fp8
template<>
__inline__ __device__ uint8_t vec_conversion<uint8_t, uint16_t>(const uint16_t& a)
{
__half_raw tmp;
tmp.x = a;
__nv_fp8_storage_t res = __nv_cvt_halfraw_to_fp8(tmp, __NV_SATFINITE, __NV_E5M2);
return (uint8_t)res;
}
// bf16 -> fp8
template<>
__inline__ __device__ uint8_t vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
#else
__nv_fp8_storage_t res = __nv_cvt_bfloat16raw_to_fp8(__nv_bfloat16_raw(a), __NV_SATFINITE, __NV_E5M2);
return (uint8_t)res;
#endif
}
// float -> fp8
template<>
__inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a)
{
__nv_fp8_storage_t res = __nv_cvt_float_to_fp8(a, __NV_SATFINITE, __NV_E5M2);
return (uint8_t)res;
}
// fp8x4 -> float4
template<>
__inline__ __device__ float4 vec_conversion<float4, uint32_t>(const uint32_t& a)
{
Float4_ tmp = vec_conversion<Float4_, uint32_t>(a);
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
return res;
}
template<>
__inline__ __device__ uint32_t vec_conversion<uint32_t, float2>(const float2& a)
{
union {
half2 float16;
uint32_t uint32;
};
float16 = __float22half2_rn(a);
return uint32;
}
template<>
__inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a)
{
uint2 b;
float2 val;
val.x = a.x.x;
val.y = a.x.y;
b.x = vec_conversion<uint32_t, float2>(val);
val.x = a.y.x;
val.y = a.y.y;
b.y = vec_conversion<uint32_t, float2>(val);
return b;
}
template<>
__inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a)
{
float4 b;
b.x = a.x.x;
b.y = a.x.y;
b.z = a.y.x;
b.w = a.y.y;
return b;
}
template<>
__inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a)
{
uint4 b;
b.x = vec_conversion<uint32_t, float2>(a.x);
b.y = vec_conversion<uint32_t, float2>(a.y);
b.z = vec_conversion<uint32_t, float2>(a.z);
b.w = vec_conversion<uint32_t, float2>(a.w);
return b;
}
template<>
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(const float2 &a) {
__nv_bfloat162 b;
from_float(b, a);
return b;
}
template<>
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, Float4_>(const Float4_ &a) {
bf16_4_t b;
from_float(b, a);
return b;
}
template<>
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, Float8_>(const Float8_ &a) {
bf16_8_t b;
from_float(b, a);
return b;
}
} // namespace fp8_e5m2_unscaled
#endif // ENABLE_FP8_E5M2
} // namespace vllm
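These device helpers perform plain, unscaled E5M2 conversions through the cuda_fp8 intrinsics. Purely for intuition (and not part of this change), recent PyTorch builds expose the same storage format as torch.float8_e5m2, so the rounding behaviour can be inspected on the host with a sketch like the following; it assumes a PyTorch version with float8 support:

    import torch

    # Illustrative only: round-trip a few fp16 values through E5M2 (5 exponent, 2 mantissa bits).
    x = torch.tensor([0.1234, 3.1416, -2.5], dtype=torch.half)
    x_fp8 = x.to(torch.float8_e5m2)
    x_back = x_fp8.to(torch.half)
    print(x, x_back)  # note the reduced mantissa precision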
.. _fp8_e5m2_kv_cache:
FP8 E5M2 KV Cache
==================
Int8/int4 quantization schemes require additional GPU memory to store the scaling factors, which reduces the expected GPU memory savings.
The FP8 data format retains 2~3 mantissa bits and can convert float/fp16/bfloat16 and fp8 to each other.
Here is an example of how to enable this feature:
.. code-block:: python

    from vllm import LLM, SamplingParams

    # Sample prompts.
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    # Create a sampling params object.
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
    # Create an LLM.
    llm = LLM(model="facebook/opt-125m", kv_cache_dtype="fp8_e5m2")
    # Generate texts from the prompts. The output is a list of RequestOutput objects
    # that contain the prompt, generated text, and other information.
    outputs = llm.generate(prompts, sampling_params)
    # Print the outputs.
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
@@ -253,6 +253,9 @@ if _is_cuda():
        num_threads = min(os.cpu_count(), nvcc_threads)
        NVCC_FLAGS += ["--threads", str(num_threads)]

+    if nvcc_cuda_version >= Version("11.8"):
+        NVCC_FLAGS += ["-DENABLE_FP8_E5M2"]
+
    # changes for punica kernels
    NVCC_FLAGS += torch_cpp_ext.COMMON_NVCC_FLAGS
    REMOVE_NVCC_FLAGS = [
...
-from typing import List, Tuple
 import pytest
-import torch
+from vllm.utils import create_kv_caches_with_random

-def create_kv_caches(
-    num_blocks: int,
-    block_size: int,
-    num_layers: int,
-    num_heads: int,
-    head_size: int,
-    dtype: torch.dtype,
-    seed: int,
-    device: str,
-) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
-    torch.random.manual_seed(seed)
-    torch.cuda.manual_seed(seed)
-    scale = head_size**-0.5
-    x = 16 // torch.tensor([], dtype=dtype).element_size()
-    key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
-    key_caches = []
-    for _ in range(num_layers):
-        key_cache = torch.empty(size=key_cache_shape,
-                                dtype=dtype,
-                                device=device)
-        key_cache.uniform_(-scale, scale)
-        key_caches.append(key_cache)
-    value_cache_shape = (num_blocks, num_heads, head_size, block_size)
-    value_caches = []
-    for _ in range(num_layers):
-        value_cache = torch.empty(size=value_cache_shape,
-                                  dtype=dtype,
-                                  device=device)
-        value_cache.uniform_(-scale, scale)
-        value_caches.append(value_cache)
-    return key_caches, value_caches

 @pytest.fixture()
 def kv_cache_factory():
-    return create_kv_caches
+    return create_kv_caches_with_random
@@ -6,14 +6,16 @@ import torch
 from xformers import ops as xops
 from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

-from vllm._C import ops
+from vllm._C import ops, cache_ops
 from vllm.utils import get_max_shared_memory_bytes

 FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
 # This will change depending on the compute capability.
 # - 512 as a buffer
 MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
-NUM_BLOCKS = 12000  # Arbitrary values for testing
+# There may not be enough gpu memory due to large NUM_BLOCKS.
+# Reduce NUM_BLOCKS when it happens.
+NUM_BLOCKS = 4321  # Arbitrary values for testing
 PARTITION_SIZE = 512

 DTYPES = [torch.half, torch.bfloat16, torch.float]
@@ -23,6 +25,7 @@ NUM_HEADS = [(40, 40), (64, 8)]  # Arbitrary values for testing
 HEAD_SIZES = [64, 80, 96, 112, 128, 256]
 BLOCK_SIZES = [16, 32]
 USE_ALIBI = [False, True]
+KV_CACHE_DTYPE = ["auto", "fp8_e5m2"]
 SEEDS = [0]
 DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]
@@ -105,6 +108,7 @@ def ref_single_query_cached_kv_attention(
 @pytest.mark.parametrize("use_alibi", USE_ALIBI)
 @pytest.mark.parametrize("block_size", BLOCK_SIZES)
 @pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
 @pytest.mark.parametrize("seed", SEEDS)
 @pytest.mark.parametrize("device", DEVICES)
 def test_paged_attention(
@@ -116,6 +120,7 @@ def test_paged_attention(
     use_alibi: bool,
     block_size: int,
     dtype: torch.dtype,
+    kv_cache_dtype: str,
     seed: int,
     device: int,
 ) -> None:
@@ -158,8 +163,9 @@ def test_paged_attention(
     # Create the KV caches.
     key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
-                                                num_kv_heads, head_size, dtype,
-                                                seed, gpu_id)
+                                                num_kv_heads, head_size,
+                                                kv_cache_dtype, dtype, seed,
+                                                gpu_id)
     key_cache, value_cache = key_caches[0], value_caches[0]

     # Call the paged attention kernel.
@@ -177,6 +183,7 @@ def test_paged_attention(
             block_size,
             max_context_len,
             alibi_slopes,
+            kv_cache_dtype,
         )
     elif version == "v2":
         num_partitions = ((max_context_len + PARTITION_SIZE - 1) //
@@ -209,11 +216,30 @@ def test_paged_attention(
             block_size,
             max_context_len,
             alibi_slopes,
+            kv_cache_dtype,
         )
     else:
         raise AssertionError(f"Unknown version: {version}")

     # Run the reference implementation.
+    if kv_cache_dtype == "fp8_e5m2":
+        # Convert cache data back to dtype.
+        x = 16 // torch.tensor([], dtype=dtype).element_size()
+        key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x,
+                           block_size, x)
+        dequantized_key_cache = torch.empty(size=key_cache_shape,
+                                            dtype=dtype,
+                                            device=gpu_id)
+        cache_ops.convert_fp8_e5m2(key_cache, dequantized_key_cache)
+        key_cache = dequantized_key_cache
+
+        value_cache_shape = value_cache.shape
+        dequantized_value_cache = torch.empty(size=value_cache_shape,
+                                              dtype=dtype,
+                                              device=gpu_id)
+        cache_ops.convert_fp8_e5m2(value_cache, dequantized_value_cache)
+        value_cache = dequantized_value_cache
+
     ref_output = torch.empty_like(query)
     ref_single_query_cached_kv_attention(
         ref_output,
@@ -230,7 +256,12 @@ def test_paged_attention(
     # NOTE(woosuk): Due to the kernel-level differences in the two
     # implementations, there is a small numerical difference in the two
     # outputs. Thus, we use a relaxed tolerance for the test.
-    assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
+    # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
+    # so we use a relaxed tolerance for the test.
+    atol, rtol = 1e-3, 1e-5
+    if kv_cache_dtype == "fp8_e5m2":
+        atol, rtol = 1e-2, 1e-5
+    assert torch.allclose(output, ref_output, atol=atol, rtol=rtol)

 def ref_multi_query_kv_attention(
...
@@ -15,6 +15,7 @@ NUM_BLOCKS = [1024, 3600]  # Arbitrary values for testing
 NUM_MAPPINGS = [256]  # Arbitrary values for testing
 SEEDS = [0]
 DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]
+KV_CACHE_DTYPE = ["auto", "fp8_e5m2"]

 @pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@@ -26,6 +27,7 @@ DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
 @pytest.mark.parametrize("device", DEVICES)
+@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
 @torch.inference_mode()
 def test_copy_blocks(
     kv_cache_factory,
@@ -38,6 +40,7 @@ def test_copy_blocks(
     dtype: torch.dtype,
     seed: int,
     device: int,
+    kv_cache_dtype: str,
 ) -> None:
     random.seed(seed)
     torch.random.manual_seed(seed)
@@ -59,7 +62,8 @@ def test_copy_blocks(
     # Create the KV caches.
     key_caches, value_caches = kv_cache_factory(num_blocks, block_size,
                                                 num_layers, num_heads,
-                                                head_size, dtype, seed, gpu_id)
+                                                head_size, kv_cache_dtype,
+                                                dtype, seed, gpu_id)

     # Clone the KV caches.
     cloned_key_caches = [key_cache.clone() for key_cache in key_caches]
@@ -124,7 +128,7 @@ def test_reshape_and_cache(
     # Create the KV caches.
     key_caches, value_caches = kv_cache_factory(num_blocks, block_size, 1,
                                                 num_heads, head_size, dtype,
-                                                seed, gpu_id)
+                                                None, seed, gpu_id)
     key_cache, value_cache = key_caches[0], value_caches[0]

     # Clone the KV caches.
@@ -133,7 +137,7 @@ def test_reshape_and_cache(
     # Call the reshape_and_cache kernel.
     cache_ops.reshape_and_cache(key, value, key_cache, value_cache,
-                                slot_mapping)
+                                slot_mapping, "auto")

     # Run the reference implementation.
     reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
...
 from typing import Optional, Union, ClassVar
 from dataclasses import dataclass
 import os
+from packaging.version import Version

 import torch
 from transformers import PretrainedConfig

 from vllm.logger import init_logger
 from vllm.transformers_utils.config import get_config
-from vllm.utils import get_cpu_memory, is_hip
+from vllm.utils import get_cpu_memory, is_hip, get_nvcc_cuda_version

 logger = init_logger(__name__)
@@ -275,6 +276,7 @@ class CacheConfig:
         gpu_memory_utilization: Fraction of GPU memory to use for the
             vLLM execution.
         swap_space: Size of the CPU swap space per GPU (in GiB).
+        cache_dtype: Data type for kv cache storage.
     """

     def __init__(
@@ -282,13 +284,16 @@ class CacheConfig:
         block_size: int,
         gpu_memory_utilization: float,
         swap_space: int,
+        cache_dtype: str,
         sliding_window: Optional[int] = None,
     ) -> None:
         self.block_size = block_size
         self.gpu_memory_utilization = gpu_memory_utilization
         self.swap_space_bytes = swap_space * _GB
+        self.cache_dtype = cache_dtype
         self.sliding_window = sliding_window
         self._verify_args()
+        self._verify_cache_dtype()

         # Will be set after profiling.
         self.num_gpu_blocks = None
@@ -300,6 +305,28 @@ class CacheConfig:
                 "GPU memory utilization must be less than 1.0. Got "
                 f"{self.gpu_memory_utilization}.")

+    def _verify_cache_dtype(self) -> None:
+        if self.cache_dtype == "auto":
+            pass
+        elif self.cache_dtype == "fp8_e5m2":
+            nvcc_cuda_version = get_nvcc_cuda_version()
+            if nvcc_cuda_version < Version("11.8"):
+                raise ValueError(
+                    "FP8 is not supported when cuda version is lower than 11.8."
+                )
+            device_name = torch.cuda.get_device_name()
+            if "AMD" in device_name:
+                raise NotImplementedError(
+                    "FP8_E5M2 KV Cache on AMD GPU has not been supported yet.")
+            logger.info(
+                "Using fp8_e5m2 data type to store kv cache. It reduces "
+                "the GPU memory footprint and boosts the performance. "
+                "But it may cause slight accuracy drop. "
+                "Currently we only support fp8 without scaling factors and "
+                "make e5m2 as a default format.")
+        else:
+            raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")
+
     def verify_with_parallel_config(
         self,
         parallel_config: "ParallelConfig",
...
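For reference, the new positional parameter sits between swap_space and sliding_window. A direct construction mirroring what EngineArgs.create_engine_configs does below, with illustrative values:

    from vllm.config import CacheConfig

    cache_config = CacheConfig(
        block_size=16,
        gpu_memory_utilization=0.90,
        swap_space=4,              # GiB of CPU swap space per GPU
        cache_dtype="fp8_e5m2",    # validated by _verify_cache_dtype()
    )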
@@ -17,6 +17,7 @@ class EngineArgs:
     download_dir: Optional[str] = None
     load_format: str = 'auto'
     dtype: str = 'auto'
+    kv_cache_dtype: str = 'auto'
     seed: int = 0
     max_model_len: Optional[int] = None
     worker_use_ray: bool = False
@@ -122,6 +123,14 @@ class EngineArgs:
             'The "auto" option will use FP16 precision '
             'for FP32 and FP16 models, and BF16 precision '
             'for BF16 models.')
+        parser.add_argument(
+            '--kv-cache-dtype',
+            type=str,
+            choices=['auto', 'fp8_e5m2'],
+            default='auto',
+            help='Data type for kv cache storage. If "auto", will use model '
+            'data type. Note FP8 is not supported when cuda version is '
+            'lower than 11.8.')
         parser.add_argument('--max-model-len',
                             type=int,
                             default=None,
@@ -269,7 +278,7 @@ class EngineArgs:
                                        self.max_context_len_to_capture)
         cache_config = CacheConfig(self.block_size,
                                    self.gpu_memory_utilization,
-                                   self.swap_space,
+                                   self.swap_space, self.kv_cache_dtype,
                                    model_config.get_sliding_window())
         parallel_config = ParallelConfig(self.pipeline_parallel_size,
                                          self.tensor_parallel_size,
...
@@ -85,6 +85,7 @@ class LLMEngine:
             f"disable_custom_all_reduce={parallel_config.disable_custom_all_reduce}, "
             f"quantization={model_config.quantization}, "
             f"enforce_eager={model_config.enforce_eager}, "
+            f"kv_cache_dtype={cache_config.cache_dtype}, "
             f"seed={model_config.seed})")
         # TODO(woosuk): Print more configs in debug mode.
@@ -144,6 +145,7 @@ class LLMEngine:
             rank=0,
             distributed_init_method=distributed_init_method,
             lora_config=self.lora_config,
+            kv_cache_dtype=self.cache_config.cache_dtype,
             is_driver_worker=True,
         )
         self._run_workers("init_model")
@@ -234,6 +236,7 @@ class LLMEngine:
         model_config = copy.deepcopy(self.model_config)
         parallel_config = copy.deepcopy(self.parallel_config)
         scheduler_config = copy.deepcopy(self.scheduler_config)
+        cache_config = copy.deepcopy(self.cache_config)
         for rank, (worker, (node_id,
                             _)) in enumerate(zip(self.workers,
@@ -249,6 +252,7 @@ class LLMEngine:
                     rank,
                     distributed_init_method,
                     lora_config=self.lora_config,
+                    cache_config=cache_config,
                 ))

         driver_rank = 0
@@ -261,6 +265,7 @@ class LLMEngine:
             driver_rank,
             distributed_init_method,
             lora_config=self.lora_config,
+            cache_config=cache_config,
             is_driver_worker=True,
         )
@@ -306,6 +311,7 @@ class LLMEngine:
             block_size=self.cache_config.block_size,
             gpu_memory_utilization=self.cache_config.gpu_memory_utilization,
             cpu_swap_space=self.cache_config.swap_space_bytes,
+            cache_dtype=self.cache_config.cache_dtype,
         )

         # Since we use a shared centralized controller, we take the minimum
...