Unverified Commit 2ff767b5 authored by Adrian Abeyta's avatar Adrian Abeyta Committed by GitHub
Browse files

Enable scaled FP8 (e4m3fn) KV cache on ROCm (AMD GPU) (#3290)


Co-authored-by: default avatarGregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Co-authored-by: default avatarHaiShaw <hixiao@gmail.com>
Co-authored-by: default avatarAdrianAbeyta <Adrian.Abeyta@amd.com>
Co-authored-by: default avatarMatthew Wong <Matthew.Wong2@amd.com>
Co-authored-by: default avatarroot <root@gt-pla-u18-08.pla.dcgpu>
Co-authored-by: default avatarmawong-amd <156021403+mawong-amd@users.noreply.github.com>
Co-authored-by: default avatarttbachyinsda <ttbachyinsda@outlook.com>
Co-authored-by: default avatarguofangze <guofangze@kuaishou.com>
Co-authored-by: default avatarMichael Goin <mgoin64@gmail.com>
Co-authored-by: default avatarjacobthebanana <50071502+jacobthebanana@users.noreply.github.com>
Co-authored-by: default avatarWoosuk Kwon <woosuk.kwon@berkeley.edu>
parent 3dcb3e8b
......@@ -181,6 +181,7 @@ _build/
# hip files generated by PyTorch
*.hip
*_hip*
hip_compat.h
# Benchmark dataset
*.json
......@@ -19,7 +19,7 @@ set(PYTHON_SUPPORTED_VERSIONS "3.8" "3.9" "3.10" "3.11")
set(CUDA_SUPPORTED_ARCHS "7.0;7.5;8.0;8.6;8.9;9.0")
# Supported AMD GPU architectures.
set(HIP_SUPPORTED_ARCHS "gfx908;gfx90a;gfx942;gfx1100")
set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100")
#
# Supported/expected torch versions for CUDA/ROCm.
......
......@@ -24,6 +24,7 @@ def main(args: argparse.Namespace):
dtype=args.dtype,
enforce_eager=args.enforce_eager,
kv_cache_dtype=args.kv_cache_dtype,
quantization_param_path=args.quantization_param_path,
device=args.device,
ray_workers_use_nsight=args.ray_workers_use_nsight,
enable_chunked_prefill=args.enable_chunked_prefill,
......@@ -127,10 +128,23 @@ if __name__ == '__main__':
parser.add_argument(
"--kv-cache-dtype",
type=str,
choices=['auto', 'fp8_e5m2'],
choices=['auto', 'fp8'],
default='auto',
help=
'Data type for kv cache storage. If "auto", will use model data type.')
'Data type for kv cache storage. If "auto", will use model data type. '
'FP8_E5M2 (without scaling) is only supported on cuda version greater '
'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for '
'common inference criteria.')
parser.add_argument(
'--quantization-param-path',
type=str,
default=None,
help='Path to the JSON file containing the KV cache scaling factors. '
'This should generally be supplied, when KV cache dtype is FP8. '
'Otherwise, KV cache scaling factors default to 1.0, which may cause '
'accuracy issues. FP8_E5M2 (without scaling) is only supported on '
'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
'instead supported for common inference criteria.')
parser.add_argument(
'--profile',
action='store_true',
......
......@@ -72,6 +72,7 @@ def run_vllm(
max_model_len: Optional[int],
enforce_eager: bool,
kv_cache_dtype: str,
quantization_param_path: Optional[str],
device: str,
enable_prefix_caching: bool,
gpu_memory_utilization: float = 0.9,
......@@ -89,6 +90,7 @@ def run_vllm(
gpu_memory_utilization=gpu_memory_utilization,
enforce_eager=enforce_eager,
kv_cache_dtype=kv_cache_dtype,
quantization_param_path=quantization_param_path,
device=device,
enable_prefix_caching=enable_prefix_caching,
download_dir=download_dir)
......@@ -217,7 +219,8 @@ def main(args: argparse.Namespace):
args.seed, args.n, args.use_beam_search,
args.trust_remote_code, args.dtype,
args.max_model_len, args.enforce_eager,
args.kv_cache_dtype, args.device,
args.kv_cache_dtype,
args.quantization_param_path, args.device,
args.enable_prefix_caching,
args.gpu_memory_utilization, args.download_dir)
elif args.backend == "hf":
......@@ -306,10 +309,23 @@ if __name__ == "__main__":
parser.add_argument(
"--kv-cache-dtype",
type=str,
choices=["auto", "fp8_e5m2"],
choices=["auto", "fp8"],
default="auto",
help=
'Data type for kv cache storage. If "auto", will use model data type.')
'Data type for kv cache storage. If "auto", will use model data type. '
'FP8_E5M2 (without scaling) is only supported on cuda version greater '
'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for '
'common inference criteria.')
parser.add_argument(
'--quantization-param-path',
type=str,
default=None,
help='Path to the JSON file containing the KV cache scaling factors. '
'This should generally be supplied, when KV cache dtype is FP8. '
'Otherwise, KV cache scaling factors default to 1.0, which may cause '
'accuracy issues. FP8_E5M2 (without scaling) is only supported on '
'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
'instead supported for common inference criteria.')
parser.add_argument(
"--device",
type=str,
......
......@@ -97,6 +97,9 @@ def main(
torch.cuda.cudart().cudaProfilerStart()
start_time = time.perf_counter()
# Using default kv_scale
kv_scale = 1.0
for _ in range(num_iters):
if version == "v1":
ops.paged_attention_v1(
......@@ -112,6 +115,7 @@ def main(
max_context_len,
alibi_slopes,
kv_cache_dtype,
kv_scale,
)
elif version == "v2":
ops.paged_attention_v2(
......@@ -130,6 +134,7 @@ def main(
max_context_len,
alibi_slopes,
kv_cache_dtype,
kv_scale,
)
else:
raise ValueError(f"Invalid version: {version}")
......@@ -179,11 +184,13 @@ if __name__ == '__main__':
parser.add_argument(
"--kv-cache-dtype",
type=str,
choices=["auto", "fp8_e5m2"],
choices=["auto", "fp8"],
default="auto",
help=
'Data type for kv cache storage. If "auto", will use model data type.')
parser.add_argument("--device", type=str, choices=["cuda"], default="cuda")
'Data type for kv cache storage. If "auto", will use model data type. '
'FP8_E5M2 (without scaling) is only supported on cuda version greater '
'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for '
'common inference criteria.')
args = parser.parse_args()
print(args)
......
......@@ -117,6 +117,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
list(APPEND GPU_FLAGS
"-DUSE_ROCM"
"-DENABLE_FP8_E4M3"
"-U__HIP_NO_HALF_CONVERSIONS__"
"-U__HIP_NO_HALF_OPERATORS__"
"-fno-gpu-rdc")
......
......@@ -4,4 +4,4 @@
#include "dtype_float16.cuh"
#include "dtype_float32.cuh"
#include "dtype_bfloat16.cuh"
#include "dtype_fp8_e5m2.cuh"
#include "dtype_fp8.cuh"
......@@ -22,12 +22,26 @@
#include "attention_dtypes.h"
#include "attention_utils.cuh"
#ifdef ENABLE_FP8_E5M2
#if defined(ENABLE_FP8_E5M2)
#include "../quantization/fp8_e5m2_kvcache/quant_utils.cuh"
#elif defined(ENABLE_FP8_E4M3)
#include "../quantization/fp8/amd_detail/quant_utils.cuh"
#endif
#include <algorithm>
#ifdef USE_ROCM
#include <hip/hip_bf16.h>
typedef __hip_bfloat16 __nv_bfloat16;
#endif
#ifndef USE_ROCM
#define WARP_SIZE 32
#else
#define WARP_SIZE warpSize
#endif
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
......@@ -78,7 +92,7 @@ template<
int HEAD_SIZE,
int BLOCK_SIZE,
int NUM_THREADS,
bool IS_FP8_E5M2_KV_CACHE,
bool IS_FP8_KV_CACHE,
int PARTITION_SIZE = 0> // Zero means no partitioning.
__device__ void paged_attention_kernel(
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
......@@ -95,7 +109,8 @@ __device__ void paged_attention_kernel(
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride,
const int kv_block_stride,
const int kv_head_stride) {
const int kv_head_stride,
const float kv_scale) {
const int seq_idx = blockIdx.y;
const int partition_idx = blockIdx.z;
const int max_num_partitions = gridDim.z;
......@@ -142,7 +157,7 @@ __device__ void paged_attention_kernel(
constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
#ifdef ENABLE_FP8_E5M2
#if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3)
using Quant_vec = typename Vec<cache_t, VEC_SIZE>::Type;
#endif
......@@ -208,11 +223,16 @@ __device__ void paged_attention_kernel(
const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
const int offset1 = (vec_idx * VEC_SIZE) / x;
const int offset2 = (vec_idx * VEC_SIZE) % x;
if constexpr (IS_FP8_E5M2_KV_CACHE) {
#ifdef ENABLE_FP8_E5M2
if constexpr (IS_FP8_KV_CACHE) {
#if defined(ENABLE_FP8_E5M2)
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
// Vector conversion from Quant_vec to K_vec.
k_vecs[j] = fp8_e5m2_unscaled::vec_conversion<K_vec, Quant_vec>(k_vec_quant);
#elif defined(ENABLE_FP8_E4M3)
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
// Vector conversion from Quant_vec to K_vec. Use scaled_vec_conversion to convert FP8_E4M3 quantized k
// cache vec to k vec in higher precision (FP16, BFloat16, etc.)
k_vecs[j] = fp8_e4m3::scaled_vec_conversion<K_vec, Quant_vec>(k_vec_quant, kv_scale);
#else
assert(false);
#endif
......@@ -292,7 +312,7 @@ __device__ void paged_attention_kernel(
constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
#ifdef ENABLE_FP8_E5M2
#if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3)
using V_quant_vec = typename Vec<cache_t, V_VEC_SIZE>::Type;
#endif
using Float_L_vec = typename FloatVec<L_vec>::Type;
......@@ -328,11 +348,16 @@ __device__ void paged_attention_kernel(
if (row_idx < HEAD_SIZE) {
const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
V_vec v_vec;
if constexpr (IS_FP8_E5M2_KV_CACHE) {
#ifdef ENABLE_FP8_E5M2
if constexpr (IS_FP8_KV_CACHE) {
#if defined(ENABLE_FP8_E5M2)
V_quant_vec v_quant_vec = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
// Vector conversion from V_quant_vec to V_vec.
v_vec = fp8_e5m2_unscaled::vec_conversion<V_vec, V_quant_vec>(v_quant_vec);
#elif defined(ENABLE_FP8_E4M3)
V_quant_vec v_quant_vec = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
// Vector conversion from V_quant_vec to V_vec. Use scaled_vec_conversion to convert
// FP8_E4M3 quantized v cache vec to v vec in higher precision (FP16, BFloat16, etc.)
v_vec = fp8_e4m3::scaled_vec_conversion<V_vec, V_quant_vec>(v_quant_vec, kv_scale);
#else
assert(false);
#endif
......@@ -423,7 +448,7 @@ template<
int HEAD_SIZE,
int BLOCK_SIZE,
int NUM_THREADS,
bool IS_FP8_E5M2_KV_CACHE>
bool IS_FP8_KV_CACHE>
__global__ void paged_attention_v1_kernel(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
......@@ -437,11 +462,12 @@ __global__ void paged_attention_v1_kernel(
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride,
const int kv_block_stride,
const int kv_head_stride) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_E5M2_KV_CACHE>(
const int kv_head_stride,
const float kv_scale) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_KV_CACHE>(
/* exp_sums */ nullptr, /* max_logits */ nullptr,
out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens,
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride);
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_scale);
}
// Grid: (num_heads, num_seqs, max_num_partitions).
......@@ -451,7 +477,7 @@ template<
int HEAD_SIZE,
int BLOCK_SIZE,
int NUM_THREADS,
bool IS_FP8_E5M2_KV_CACHE,
bool IS_FP8_KV_CACHE,
int PARTITION_SIZE>
__global__ void paged_attention_v2_kernel(
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
......@@ -468,11 +494,12 @@ __global__ void paged_attention_v2_kernel(
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride,
const int kv_block_stride,
const int kv_head_stride) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_E5M2_KV_CACHE, PARTITION_SIZE>(
const int kv_head_stride,
const float kv_scale) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_KV_CACHE, PARTITION_SIZE>(
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes,
q_stride, kv_block_stride, kv_head_stride);
q_stride, kv_block_stride, kv_head_stride, kv_scale);
}
// Grid: (num_heads, num_seqs).
......@@ -579,9 +606,9 @@ __global__ void paged_attention_v2_reduce_kernel(
#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
((void*)vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
IS_FP8_E5M2_KV_CACHE>), shared_mem_size); \
IS_FP8_KV_CACHE>), shared_mem_size); \
vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
IS_FP8_E5M2_KV_CACHE><<<grid, block, shared_mem_size, stream>>>( \
IS_FP8_KV_CACHE><<<grid, block, shared_mem_size, stream>>>( \
out_ptr, \
query_ptr, \
key_cache_ptr, \
......@@ -594,14 +621,15 @@ __global__ void paged_attention_v2_reduce_kernel(
alibi_slopes_ptr, \
q_stride, \
kv_block_stride, \
kv_head_stride);
kv_head_stride, \
kv_scale);
// TODO(woosuk): Tune NUM_THREADS.
template<
typename T,
typename CACHE_T,
int BLOCK_SIZE,
bool IS_FP8_E5M2_KV_CACHE,
bool IS_FP8_KV_CACHE,
int NUM_THREADS = 128>
void paged_attention_v1_launcher(
torch::Tensor& out,
......@@ -613,7 +641,8 @@ void paged_attention_v1_launcher(
torch::Tensor& block_tables,
torch::Tensor& context_lens,
int max_context_len,
const c10::optional<torch::Tensor>& alibi_slopes) {
const c10::optional<torch::Tensor>& alibi_slopes,
float kv_scale) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
......@@ -677,8 +706,8 @@ void paged_attention_v1_launcher(
}
}
#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE) \
paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE>( \
#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE>( \
out, \
query, \
key_cache, \
......@@ -688,20 +717,21 @@ void paged_attention_v1_launcher(
block_tables, \
context_lens, \
max_context_len, \
alibi_slopes);
alibi_slopes, \
kv_scale);
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_E5M2_KV_CACHE) \
#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_KV_CACHE) \
switch (block_size) { \
case 8: \
CALL_V1_LAUNCHER(T, CACHE_T, 8, IS_FP8_E5M2_KV_CACHE); \
CALL_V1_LAUNCHER(T, CACHE_T, 8, IS_FP8_KV_CACHE); \
break; \
case 16: \
CALL_V1_LAUNCHER(T, CACHE_T, 16, IS_FP8_E5M2_KV_CACHE); \
CALL_V1_LAUNCHER(T, CACHE_T, 16, IS_FP8_KV_CACHE); \
break; \
case 32: \
CALL_V1_LAUNCHER(T, CACHE_T, 32, IS_FP8_E5M2_KV_CACHE); \
CALL_V1_LAUNCHER(T, CACHE_T, 32, IS_FP8_KV_CACHE); \
break; \
default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
......@@ -720,7 +750,8 @@ void paged_attention_v1(
int block_size,
int max_context_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype) {
const std::string& kv_cache_dtype,
float kv_scale) {
if (kv_cache_dtype == "auto") {
if (query.dtype() == at::ScalarType::Float) {
CALL_V1_LAUNCHER_BLOCK_SIZE(float, float, false);
......@@ -731,7 +762,7 @@ void paged_attention_v1(
} else {
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
}
} else if (kv_cache_dtype == "fp8_e5m2") {
} else if (kv_cache_dtype == "fp8") {
if (query.dtype() == at::ScalarType::Float) {
CALL_V1_LAUNCHER_BLOCK_SIZE(float, uint8_t, true);
} else if (query.dtype() == at::ScalarType::Half) {
......@@ -748,7 +779,7 @@ void paged_attention_v1(
#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \
vllm::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
IS_FP8_E5M2_KV_CACHE, PARTITION_SIZE> \
IS_FP8_KV_CACHE, PARTITION_SIZE> \
<<<grid, block, shared_mem_size, stream>>>( \
exp_sums_ptr, \
max_logits_ptr, \
......@@ -764,7 +795,8 @@ void paged_attention_v1(
alibi_slopes_ptr, \
q_stride, \
kv_block_stride, \
kv_head_stride); \
kv_head_stride, \
kv_scale); \
vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, PARTITION_SIZE> \
<<<reduce_grid, block, reduce_shared_mem_size, stream>>>( \
out_ptr, \
......@@ -778,7 +810,7 @@ template<
typename T,
typename CACHE_T,
int BLOCK_SIZE,
bool IS_FP8_E5M2_KV_CACHE,
bool IS_FP8_KV_CACHE,
int NUM_THREADS = 128,
int PARTITION_SIZE = 512>
void paged_attention_v2_launcher(
......@@ -794,7 +826,8 @@ void paged_attention_v2_launcher(
torch::Tensor& block_tables,
torch::Tensor& context_lens,
int max_context_len,
const c10::optional<torch::Tensor>& alibi_slopes) {
const c10::optional<torch::Tensor>& alibi_slopes,
float kv_scale) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
......@@ -864,8 +897,8 @@ void paged_attention_v2_launcher(
}
}
#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE) \
paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE>( \
#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE>( \
out, \
exp_sums, \
max_logits, \
......@@ -878,20 +911,21 @@ void paged_attention_v2_launcher(
block_tables, \
context_lens, \
max_context_len, \
alibi_slopes);
alibi_slopes, \
kv_scale);
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_E5M2_KV_CACHE) \
#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_KV_CACHE) \
switch (block_size) { \
case 8: \
CALL_V2_LAUNCHER(T, CACHE_T, 8, IS_FP8_E5M2_KV_CACHE); \
CALL_V2_LAUNCHER(T, CACHE_T, 8, IS_FP8_KV_CACHE); \
break; \
case 16: \
CALL_V2_LAUNCHER(T, CACHE_T, 16, IS_FP8_E5M2_KV_CACHE); \
CALL_V2_LAUNCHER(T, CACHE_T, 16, IS_FP8_KV_CACHE); \
break; \
case 32: \
CALL_V2_LAUNCHER(T, CACHE_T, 32, IS_FP8_E5M2_KV_CACHE); \
CALL_V2_LAUNCHER(T, CACHE_T, 32, IS_FP8_KV_CACHE); \
break; \
default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
......@@ -913,7 +947,8 @@ void paged_attention_v2(
int block_size,
int max_context_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype) {
const std::string& kv_cache_dtype,
float kv_scale) {
if (kv_cache_dtype == "auto") {
if (query.dtype() == at::ScalarType::Float) {
CALL_V2_LAUNCHER_BLOCK_SIZE(float, float, false);
......@@ -924,7 +959,7 @@ void paged_attention_v2(
} else {
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
}
} else if (kv_cache_dtype == "fp8_e5m2") {
} else if (kv_cache_dtype == "fp8") {
if (query.dtype() == at::ScalarType::Float) {
CALL_V2_LAUNCHER_BLOCK_SIZE(float, uint8_t, true);
} else if (query.dtype() == at::ScalarType::Half) {
......
......@@ -8,7 +8,7 @@
#endif
namespace vllm {
#ifdef ENABLE_FP8_E5M2
#if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3)
// fp8 vector types for quantization of kv cache
template<>
......
......@@ -21,9 +21,10 @@ void reshape_and_cache(
torch::Tensor& key_cache,
torch::Tensor& value_cache,
torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype);
const std::string& kv_cache_dtype,
const float kv_scale);
// Just for unittest
void convert_fp8_e5m2(
void convert_fp8(
torch::Tensor& src_cache,
torch::Tensor& dst_cache);
......@@ -4,8 +4,10 @@
#include "cuda_compat.h"
#include "dispatch_utils.h"
#ifdef ENABLE_FP8_E5M2
#if defined(ENABLE_FP8_E5M2)
#include "quantization/fp8_e5m2_kvcache/quant_utils.cuh"
#elif defined(ENABLE_FP8_E4M3)
#include "quantization/fp8/amd_detail/quant_utils.cuh"
#endif
#include <algorithm>
......@@ -151,7 +153,7 @@ void copy_blocks(
namespace vllm {
template<typename scalar_t, typename cache_t, bool is_fp8_e5m2_kv_cache>
template<typename scalar_t, typename cache_t, bool is_fp8_kv_cache>
__global__ void reshape_and_cache_kernel(
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
......@@ -163,7 +165,8 @@ __global__ void reshape_and_cache_kernel(
const int num_heads,
const int head_size,
const int block_size,
const int x) {
const int x,
const float kv_scale) {
const int64_t token_idx = blockIdx.x;
const int64_t slot_idx = slot_mapping[token_idx];
if (slot_idx < 0) {
......@@ -195,10 +198,13 @@ __global__ void reshape_and_cache_kernel(
+ block_offset;
scalar_t tgt_key = key[src_key_idx];
scalar_t tgt_value = value[src_value_idx];
if constexpr (is_fp8_e5m2_kv_cache) {
#ifdef ENABLE_FP8_E5M2
if constexpr (is_fp8_kv_cache) {
#if defined(ENABLE_FP8_E5M2)
key_cache[tgt_key_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_key);
value_cache[tgt_value_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_value);
#elif defined(ENABLE_FP8_E4M3)
key_cache[tgt_key_idx] = fp8_e4m3::scaled_vec_conversion<uint8_t, scalar_t>(tgt_key, kv_scale);
value_cache[tgt_value_idx] = fp8_e4m3::scaled_vec_conversion<uint8_t, scalar_t>(tgt_value, kv_scale);
#else
assert(false);
#endif
......@@ -211,8 +217,8 @@ __global__ void reshape_and_cache_kernel(
} // namespace vllm
#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_E5M2_KV_CACHE) \
vllm::reshape_and_cache_kernel<KV_T, CACHE_T, IS_FP8_E5M2_KV_CACHE><<<grid, block, 0, stream>>>( \
#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_KV_CACHE) \
vllm::reshape_and_cache_kernel<KV_T, CACHE_T, IS_FP8_KV_CACHE><<<grid, block, 0, stream>>>( \
reinterpret_cast<KV_T*>(key.data_ptr()), \
reinterpret_cast<KV_T*>(value.data_ptr()), \
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
......@@ -223,7 +229,8 @@ __global__ void reshape_and_cache_kernel(
num_heads, \
head_size, \
block_size, \
x);
x, \
kv_scale);
void reshape_and_cache(
torch::Tensor& key, // [num_tokens, num_heads, head_size]
......@@ -231,7 +238,8 @@ void reshape_and_cache(
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
torch::Tensor& slot_mapping, // [num_tokens]
const std::string& kv_cache_dtype)
const std::string& kv_cache_dtype,
const float kv_scale)
{
int num_tokens = key.size(0);
int num_heads = key.size(1);
......@@ -254,7 +262,7 @@ void reshape_and_cache(
} else if (key.dtype() == at::ScalarType::BFloat16) {
CALL_RESHAPE_AND_CACHE(__nv_bfloat16, __nv_bfloat16, false);
}
} else if (kv_cache_dtype == "fp8_e5m2") {
} else if (kv_cache_dtype == "fp8") {
if (key.dtype() == at::ScalarType::Float) {
CALL_RESHAPE_AND_CACHE(float, uint8_t, true);
} else if (key.dtype() == at::ScalarType::Half) {
......@@ -270,15 +278,17 @@ void reshape_and_cache(
namespace vllm {
template<typename Tout, typename Tin>
__global__ void convert_fp8_e5m2_kernel(
__global__ void convert_fp8_kernel(
const Tin* __restrict__ src_cache,
Tout* __restrict__ dst_cache,
const int64_t block_stride) {
const int64_t block_idx = blockIdx.x;
for (int i = threadIdx.x; i < block_stride; i += blockDim.x) {
int64_t idx = block_idx * block_stride + i;
#ifdef ENABLE_FP8_E5M2
#if defined(ENABLE_FP8_E5M2)
dst_cache[idx] = fp8_e5m2_unscaled::vec_conversion<Tout, Tin>(src_cache[idx]);
#elif defined(ENABLE_FP8_E4M3)
dst_cache[idx] = fp8_e4m3::vec_conversion<Tout, Tin>(src_cache[idx]);
#else
assert(false);
#endif
......@@ -287,16 +297,25 @@ __global__ void convert_fp8_e5m2_kernel(
} // namespace vllm
#define CALL_CONVERT_FP8_E5M2(Tout, Tin) \
vllm::convert_fp8_e5m2_kernel<Tout, Tin><<<grid, block, 0, stream>>>( \
reinterpret_cast<Tin*>(src_cache.data_ptr()), \
reinterpret_cast<Tout*>(dst_cache.data_ptr()), \
#define CALL_CONVERT_FP8(Tout, Tin) \
vllm::convert_fp8_kernel<Tout, Tin><<<grid, block, 0, stream>>>( \
reinterpret_cast<Tin*>(src_cache.data_ptr()), \
reinterpret_cast<Tout*>(dst_cache.data_ptr()), \
block_stride);
void convert_fp8_e5m2(
void convert_fp8(
torch::Tensor& src_cache,
torch::Tensor& dst_cache)
{
torch::Device src_device = src_cache.device();
torch::Device dst_device = dst_cache.device();
TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU")
TORCH_CHECK(dst_device.is_cuda(), "dst must be on a GPU")
TORCH_CHECK(
src_device.index() == dst_device.index(),
"src and dst must be on the same GPU");
at::cuda::OptionalCUDAGuard device_guard(src_device);
int64_t num_blocks = src_cache.size(0);
int64_t block_stride = src_cache.stride(0);
......@@ -305,16 +324,16 @@ void convert_fp8_e5m2(
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (src_cache.dtype() == at::ScalarType::Float) {
CALL_CONVERT_FP8_E5M2(uint8_t, float);
CALL_CONVERT_FP8(uint8_t, float);
} else if (src_cache.dtype() == at::ScalarType::Half) {
CALL_CONVERT_FP8_E5M2(uint8_t, uint16_t);
CALL_CONVERT_FP8(uint8_t, uint16_t);
} else if (src_cache.dtype() == at::ScalarType::BFloat16) {
CALL_CONVERT_FP8_E5M2(uint8_t, __nv_bfloat16);
CALL_CONVERT_FP8(uint8_t, __nv_bfloat16);
} else if (dst_cache.dtype() == at::ScalarType::Float) {
CALL_CONVERT_FP8_E5M2(float, uint8_t);
CALL_CONVERT_FP8(float, uint8_t);
} else if (dst_cache.dtype() == at::ScalarType::Half) {
CALL_CONVERT_FP8_E5M2(uint16_t, uint8_t);
CALL_CONVERT_FP8(uint16_t, uint8_t);
} else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
CALL_CONVERT_FP8_E5M2(__nv_bfloat16, uint8_t);
CALL_CONVERT_FP8(__nv_bfloat16, uint8_t);
}
}
......@@ -14,7 +14,8 @@ void paged_attention_v1(
int block_size,
int max_context_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype);
const std::string& kv_cache_dtype,
float kv_scale);
void paged_attention_v2(
torch::Tensor& out,
......@@ -31,7 +32,8 @@ void paged_attention_v2(
int block_size,
int max_context_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype);
const std::string& kv_cache_dtype,
float kv_scale);
void rms_norm(
torch::Tensor& out,
......
......@@ -91,9 +91,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
&reshape_and_cache,
"Reshape the key and value tensors and cache them");
cache_ops.def(
"convert_fp8_e5m2",
&convert_fp8_e5m2,
"Convert the key and value cache to fp8_e5m2 data type");
"convert_fp8",
&convert_fp8,
"Convert the key and value cache to fp8 data type");
// Cuda utils
pybind11::module cuda_utils = m.def_submodule("cuda_utils", "vLLM cuda utils");
......
#pragma once
#ifdef __HIPCC__
#include <hip/hip_runtime.h>
#else
#include <type_traits>
#include <stdint.h>
#include <math.h>
#include <iostream>
#endif
#include "hip_float8_impl.h"
struct alignas(1) hip_fp8
{
struct from_bits_t
{
};
HIP_FP8_HOST_DEVICE static constexpr from_bits_t from_bits() { return from_bits_t(); }
uint8_t data;
hip_fp8() = default;
HIP_FP8_HOST_DEVICE constexpr hip_fp8(const hip_fp8&) = default;
HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v) = delete;
explicit HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v, from_bits_t)
: data(v)
{
}
#ifdef __HIP__MI300__
// NOTE: ON-DEVICE... always optimal bias
explicit HIP_FP8_DEVICE hip_fp8(float v)
: data(hip_fp8_impl::to_fp8_from_fp32(v))
{
}
explicit HIP_FP8_DEVICE hip_fp8(_Float16 v)
: hip_fp8(static_cast<float>(v))
{
}
// Host only implementation using s/w simulation
explicit HIP_FP8_HOST
#else // __HIP__MI300__
// both Host and DEVICE for non-MI300 using s/w simulation
explicit HIP_FP8_HOST_DEVICE
#endif // __HIP__MI300__
hip_fp8(float v)
{
data = hip_fp8_impl::to_float8<4, 3, float, true /*negative_zero_nan*/, true /*clip*/>(v);
}
explicit HIP_FP8_HOST_DEVICE hip_fp8(double v)
: hip_fp8(static_cast<float>(v))
{
}
#ifdef __HIP__MI300__
// upcast using device specific intrinsic
explicit inline HIP_FP8_DEVICE operator float() const
{
float fval;
uint32_t i32val = static_cast<uint32_t>(data);
// upcast
asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
return fval;
}
explicit inline HIP_FP8_HOST operator float() const
#else // __HIP__MI300__
explicit inline HIP_FP8_HOST_DEVICE operator float() const
#endif // __HIP__MI300__
{
return hip_fp8_impl::from_float8<4, 3, float, true /*negative_zero_nan*/>(data);
}
};
namespace std
{
inline hip_fp8 sin(hip_fp8 a)
{
return hip_fp8(sinf(float(a)));
}
inline hip_fp8 cos(hip_fp8 a)
{
return hip_fp8(cosf(float(a)));
}
HIP_FP8_HOST_DEVICE constexpr hip_fp8 real(const hip_fp8& a)
{
return a;
}
} // namespace std
// Special operator overloading
inline std::ostream& operator<<(std::ostream& os, const hip_fp8& f8)
{
return os << float(f8);
}
// all + operator overloading with mixed types
// mixed types, always converts to f32, does computation in f32, and returns float
inline HIP_FP8_HOST_DEVICE float operator+(const float fa, hip_fp8 b)
{
return (fa + float(b));
}
inline HIP_FP8_HOST_DEVICE float operator+(hip_fp8 a, const float fb)
{
return (float(a) + fb);
}
inline HIP_FP8_HOST_DEVICE hip_fp8 operator+(hip_fp8 a, hip_fp8 b)
{
return hip_fp8(float(a) + float(b));
}
inline HIP_FP8_HOST_DEVICE hip_fp8& operator+=(hip_fp8& a, hip_fp8 b)
{
return a = hip_fp8(float(a) + float(b));
}
// overloading multiplication, always returns float,
inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, hip_fp8 b)
{
return float(a) * float(b);
}
inline HIP_FP8_HOST_DEVICE float operator*(float a, hip_fp8 b)
{
return (a * float(b));
}
inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, float b)
{
return (float(a) * b);
}
inline HIP_FP8_HOST_DEVICE float operator*(int32_t a, hip_fp8 b)
{
return ((float)a * float(b));
}
inline HIP_FP8_HOST_DEVICE float operator*(double a, hip_fp8 b)
{
return ((float)a * float(b));
}
// overloading for compare
inline HIP_FP8_HOST_DEVICE bool operator==(hip_fp8 a, hip_fp8 b)
{
return (a.data == b.data);
}
inline HIP_FP8_HOST_DEVICE bool operator!=(hip_fp8 a, hip_fp8 b)
{
return (a.data != b.data);
}
inline HIP_FP8_HOST_DEVICE bool operator>=(hip_fp8 a, hip_fp8 b)
{
return static_cast<float>(a) >= static_cast<float>(b);
}
inline HIP_FP8_HOST_DEVICE bool operator>(hip_fp8 a, hip_fp8 b)
{
return static_cast<float>(a) > static_cast<float>(b);
}
#pragma once
#if defined(__HIPCC__) && (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
#define __HIP__MI300__
#endif
#ifdef __HIPCC__
#define HIP_FP8_HOST_DEVICE __host__ __device__
#define HIP_FP8_HOST __host__
#define HIP_FP8_DEVICE __device__
#else
#define HIP_FP8_HOST_DEVICE
#define HIP_FP8_HOST
#define HIP_FP8_DEVICE
#endif
namespace hip_fp8_impl
{
#ifdef __HIP__MI300__
HIP_FP8_DEVICE uint8_t to_fp8_from_fp32(float v)
{
uint8_t i8data;
union {
float fval;
uint32_t i32val;
uint8_t i8val[4]; // NOTE: not endian independent
} val;
uint32_t ival = 0;
val.fval = v;
if ((val.i32val & 0x7F800000) != 0x7F800000) { /// propagate NAN/INF, no clipping
val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0);
}
ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival,
false); // false -> WORD0
val.i32val = ival;
i8data = val.i8val[0];
return i8data;
}
#endif // __HIP__MI300__
HIP_FP8_HOST inline int clz(uint32_t x)
{
return __builtin_clz(x);
}
#if defined(__HIPCC__) || defined(__CUDA_ARCH__)
HIP_FP8_DEVICE inline int clz(uint32_t x)
{
return __clz(x);
}
#endif
template <int we, int wm, typename T, bool negative_zero_nan, bool clip>
HIP_FP8_HOST_DEVICE uint8_t to_float8(T _x, bool stoch = false, uint32_t rng = 0)
{
#ifdef __HIPCC__
constexpr bool is_half = std::is_same<T, _Float16>::value;
#else
constexpr bool is_half = false;
#endif
constexpr bool is_float = std::is_same<T, float>::value;
static_assert(wm + we == 7, "wm+we==7");
static_assert(is_half || is_float, "Only half and float can be cast to f8");
const int mfmt = (sizeof(T) == 4) ? 23 : 10;
uint32_t x;
if (sizeof(T) == 4) {
x = reinterpret_cast<uint32_t&>(_x);
} else {
x = reinterpret_cast<uint16_t&>(_x);
}
uint32_t head, mantissa;
int exponent, bias;
uint32_t sign;
if (sizeof(T) == 4) {
head = x & 0xFF800000;
mantissa = x & 0x7FFFFF;
exponent = (head >> 23) & 0xFF;
sign = head >> 31;
bias = 127;
} else {
head = x & 0xFC00;
mantissa = x & 0x3FF;
exponent = (head >> 10) & 0x1F;
sign = head >> 15;
bias = 15;
}
uint32_t signed_inf = (sign << 7) + (((1 << we) - 1) << wm);
// Deal with inf and NaNs
if (negative_zero_nan) {
if (sizeof(T) == 4) {
if ((x & 0x7F800000) == 0x7F800000) {
return 0x80;
}
} else {
// if(__hisinf(x) || __hisnan(x))
if ((x & 0x7C00) == 0x7C00) {
return 0x80;
}
}
} else {
if (sizeof(T) == 4) {
if ((x & 0x7F800000) == 0x7F800000) {
return signed_inf + (mantissa != 0 ? 1 : 0);
}
} else {
if ((x & 0x7C00) == 0x7C00) {
return signed_inf + (mantissa != 0 ? 1 : 0);
}
}
}
if (x == 0) {
return 0;
}
// First need to check if it is normal or denorm as there is a difference of
// implicit 1 Then need to adjust the exponent to align with the F8 exponent,
// in the meanwhile, shift The mantissa. Then for stochastic rounding, add rng
// to mantissa and truncate. And for RNE, no need to add rng. Then probably
// need to check whether there is carry and adjust exponent and mantissa again
// For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent
// bits
const int f8_bias = (1 << (we - 1)) - 1 + (negative_zero_nan ? 1 : 0);
const int f8_denormal_act_exponent = 1 - f8_bias; // actual exponent of f8 denormal
// act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
// f8_exponent is the converted f8 exponent with bias encoding
// exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
// the difference needs to be adjusted and mantissa shifted
int act_exponent, f8_exponent, exponent_diff;
if (exponent == 0) { // fp32/fp16 is in denormal.
/* fp32 denormal is below 2^-127 so it is usually not a concern here, we
mostly concern fp16 here. In this case, f8 is usually in denormal. But there
could be exceptions. fp16 denormal has exponent bias 15 while bf8 with NANOO has
exponent bias 16. It means that there are some numbers in fp16 denormal but they
are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers
where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8
(NANOO) normal. In this case, the fp16 mantissa should be shift left by 1 */
act_exponent = exponent - bias + 1;
exponent_diff = f8_denormal_act_exponent - act_exponent; // actual exponent is exponent-bias+1 as it is denormal
} else { // fp32/fp16 is normal with implicit 1
act_exponent = exponent - bias;
if (act_exponent <= f8_denormal_act_exponent) {
/* This is the case where fp32/fp16 is normal but it is in f8 denormal
range. For example fp8 nanoo mode, denormal exponent is -7, but if the
fp32/fp16 actual exponent is -7, it is actually larger due to the implicit 1,
Therefore it needs to be adjust to -6 and mantissa shift right by 1.
So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */
exponent_diff = f8_denormal_act_exponent - act_exponent;
} else { // both fp32/fp16 and f8 are in normal range
exponent_diff = 0; // exponent_diff=0 does not mean there is no difference
// for this case,
// act_exponent could be larger. Just that it does not need shift mantissa
}
mantissa += (1 << mfmt); // Add the implicit 1 into mantissa
}
bool midpoint = (mantissa & ((1 << (mfmt - wm + exponent_diff)) - 1)) ==
static_cast<uint32_t>(1 << (mfmt - wm + exponent_diff - 1));
/* This part is a bit tricky. The judgment of whether it is a tie needs to be
done before we shift right as shift right could rip off some residual part
and make something not midpoint look like midpoint. For example, the fp16
number 0x1002 (0 00100 0000000010), it is larger than midpoint, but after
shift right by 4 bits, it would look like midpoint.
*/
if (exponent_diff > 0) {
mantissa >>= exponent_diff;
} else if (exponent_diff == -1) {
mantissa <<= -exponent_diff;
}
bool implicit_one = mantissa & (1 << mfmt);
// if there is no implicit 1, it means the f8 is denormal and need to adjust
// to denorm exponent
f8_exponent = (act_exponent + exponent_diff) /*actual f8 exponent*/ + f8_bias - (implicit_one ? 0 : 1);
// Now we have the exponent and mantissa adjusted
uint32_t drop_mask = (1 << (mfmt - wm)) - 1;
bool odd = mantissa & (1 << (mfmt - wm)); // if the least significant bit that
// is not truncated is 1
mantissa += (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) & drop_mask;
// Now we deal with overflow
if (f8_exponent == 0) {
if ((1 << mfmt) & mantissa) {
f8_exponent = 1; // denormal overflow to become normal, promote exponent
}
} else {
if ((1 << (mfmt + 1)) & mantissa) {
mantissa >>= 1;
f8_exponent++;
}
}
mantissa >>= (mfmt - wm);
// above range: quantize to maximum possible float of the same sign
const int max_exp = (1 << we) - (negative_zero_nan ? 1 : 2);
if (f8_exponent > max_exp) {
if (clip) {
mantissa = (1 << wm) - 1;
f8_exponent = max_exp;
} else {
return signed_inf;
}
}
if (f8_exponent == 0 && mantissa == 0) {
return negative_zero_nan ? 0 : (sign << 7);
}
mantissa &= (1 << wm) - 1;
return (sign << 7) | (f8_exponent << wm) | mantissa;
}
template <int we, int wm, typename T = float, bool negative_zero_nan = true>
inline HIP_FP8_HOST_DEVICE T from_float8(uint8_t x)
{
#ifdef __HIPCC__
constexpr bool is_half = std::is_same<T, _Float16>::value;
#else
constexpr bool is_half = false;
#endif
constexpr bool is_float = std::is_same<T, float>::value;
static_assert(is_half || is_float, "only half and float are supported");
constexpr int weo = is_half ? 5 : 8;
constexpr int wmo = is_half ? 10 : (is_float ? 23 : 7);
T fInf, fNegInf, fNaN, fNeg0;
#ifdef __HIPCC__
if (is_half) {
const uint16_t ihInf = 0x7C00;
const uint16_t ihNegInf = 0xFC00;
const uint16_t ihNaN = 0x7C01;
const uint16_t ihNeg0 = 0x8000;
fInf = reinterpret_cast<const _Float16&>(ihInf);
fNegInf = reinterpret_cast<const _Float16&>(ihNegInf);
fNaN = reinterpret_cast<const _Float16&>(ihNaN);
fNeg0 = reinterpret_cast<const _Float16&>(ihNeg0);
} else
#endif
if (is_float) {
const uint32_t ifInf = 0x7F800000;
const uint32_t ifNegInf = 0xFF800000;
const uint32_t ifNaN = 0x7F800001;
const uint32_t ifNeg0 = 0x80000000;
fInf = reinterpret_cast<const float&>(ifInf);
fNegInf = reinterpret_cast<const float&>(ifNegInf);
fNaN = reinterpret_cast<const float&>(ifNaN);
fNeg0 = reinterpret_cast<const float&>(ifNeg0);
}
if (x == 0) {
return 0;
}
uint32_t sign = x >> 7;
uint32_t mantissa = x & ((1 << wm) - 1);
int exponent = (x & 0x7F) >> wm;
if (negative_zero_nan) {
if (x == 0x80) {
return fNaN;
}
} else {
if (x == 0x80) {
return fNeg0;
}
if (exponent == ((1 << we) - 1)) {
return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN;
}
}
typename std::conditional<sizeof(T) == 2, uint16_t, uint32_t>::type retval;
if (we == 5 && is_half && !negative_zero_nan) {
retval = x << 8;
return reinterpret_cast<const T&>(retval);
}
const int exp_low_cutoff = (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (negative_zero_nan ? 1 : 0);
// subnormal input
if (exponent == 0) {
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
int sh = 1 + clz(mantissa) - (32 - wm);
mantissa <<= sh;
exponent += 1 - sh;
mantissa &= ((1 << wm) - 1);
}
exponent += exp_low_cutoff - 1;
mantissa <<= wmo - wm;
// subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
if (exponent <= 0) {
mantissa |= 1 << wmo;
mantissa >>= 1 - exponent;
exponent = 0;
}
if (sizeof(T) == 2) {
retval = (sign << 15) | (exponent << 10) | mantissa;
} else {
retval = (sign << 31) | (exponent << 23) | mantissa;
}
return reinterpret_cast<const T&>(retval);
}
} // namespace hip_fp8_impl
#pragma once
#include "hip_float8.h"
#include <hip/hip_fp16.h>
#include <hip/hip_bf16.h>
#include <hip/hip_bfloat16.h>
#include "../../../attention/dtype_float32.cuh"
#include "../../../attention/dtype_bfloat16.cuh"
namespace vllm
{
namespace fp8_e4m3 {
template <typename Tout, typename Tin>
__inline__ __device__ Tout vec_conversion(const Tin& x)
{
return x;
}
template <typename Tout, typename Tin>
__inline__ __device__ Tout scaled_vec_conversion(const Tin& x, const float scale)
{
return x;
}
// fp8 -> half
template <>
__inline__ __device__ uint16_t vec_conversion<uint16_t, uint8_t>(const uint8_t& a)
{
hip_fp8 f8{a, hip_fp8::from_bits()};
__half_raw res;
res.data = static_cast<float>(f8);
return res.x;
}
// fp8x2 -> half2
template <>
__inline__ __device__ uint32_t vec_conversion<uint32_t, uint16_t>(const uint16_t& a)
{
#if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
union {
__half2_raw h2r;
uint32_t ui32;
} tmp;
tmp.h2r.x.data = f2[0];
tmp.h2r.y.data = f2[1];
return tmp.ui32;
#else
union {
uint16_t u16[2];
uint32_t u32;
} tmp;
tmp.u16[0] = vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a));
tmp.u16[1] = vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a >> 8U));
return tmp.u32;
#endif
}
// fp8x4 -> half2x2
template <>
__inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(const uint32_t& a)
{
union {
uint2 u32x2;
uint32_t u32[2];
} tmp;
tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a);
tmp.u32[1] = vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U));
return tmp.u32x2;
}
// fp8x8 -> half2x4
template <>
__inline__ __device__ uint4 vec_conversion<uint4, uint2>(const uint2& a)
{
union {
uint4 u64x2;
uint2 u64[2];
} tmp;
tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x);
tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y);
return tmp.u64x2;
}
using __nv_bfloat16 = __hip_bfloat16;
// fp8 -> __nv_bfloat16
template <>
__inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a)
{
hip_fp8 f8{a, hip_fp8::from_bits()};
float f{f8};
return __float2bfloat16(f);
}
using __nv_bfloat162 = __hip_bfloat162;
// fp8x2 -> __nv_bfloat162
template <>
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a)
{
__nv_bfloat162 res;
res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a);
res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U));
return res;
}
// fp8x4 -> bf16_4_t
template <>
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a)
{
bf16_4_t res;
res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a);
res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U));
return res;
}
// fp8x8 -> bf16_8_t
template <>
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a)
{
bf16_4_t tmp1, tmp2;
tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x);
tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y);
bf16_8_t res;
res.x = tmp1.x;
res.y = tmp1.y;
res.z = tmp2.x;
res.w = tmp2.y;
return res;
}
// fp8 -> float
template <>
__inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a)
{
hip_fp8 fp8{a, hip_fp8::from_bits()};
return static_cast<float>(fp8);
}
// fp8x2 -> float2
template <>
__inline__ __device__ float2 vec_conversion<float2, uint16_t>(const uint16_t& a)
{
#if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
float2 res;
const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
res.x = f2[0];
res.y = f2[1];
return res;
#else
float2 res;
res.x = vec_conversion<float, uint8_t>(static_cast<uint8_t>(a));
res.y = vec_conversion<float, uint8_t>(static_cast<uint8_t>(a >> 8U));
return res;
#endif
}
// fp8x4 -> float4
template <>
__inline__ __device__ Float4_ vec_conversion<Float4_, uint32_t>(const uint32_t& a)
{
Float4_ res;
res.x = vec_conversion<float2, uint16_t>((uint16_t)a);
res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U));
return res;
}
// fp8x8 -> float8
template <>
__inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a)
{
Float4_ tmp1, tmp2;
tmp1 = vec_conversion<Float4_, uint32_t>(a.x);
tmp2 = vec_conversion<Float4_, uint32_t>(a.y);
Float8_ res;
res.x = tmp1.x;
res.y = tmp1.y;
res.z = tmp2.x;
res.w = tmp2.y;
return res;
}
// half -> fp8
template <>
__inline__ __device__ uint8_t vec_conversion<uint8_t, uint16_t>(const uint16_t& a)
{
__half_raw tmp;
tmp.x = a;
hip_fp8 f8{static_cast<float>(tmp.data)};
return f8.data;
}
// bf16 -> fp8
template <>
__inline__ __device__ uint8_t vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a)
{
hip_fp8 res{__bfloat162float(a)};
return res.data;
}
// float -> fp8
template <>
__inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a)
{
hip_fp8 f8(a);
return f8.data;
}
// fp8x4 -> float4
template <>
__inline__ __device__ float4 vec_conversion<float4, uint32_t>(const uint32_t& a)
{
Float4_ tmp = vec_conversion<Float4_, uint32_t>(a);
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
return res;
}
// float2 -> half2
template <>
__inline__ __device__ uint32_t vec_conversion<uint32_t, float2>(const float2& a)
{
union {
half2 float16;
uint32_t uint32;
};
float16 = __float22half2_rn(a);
return uint32;
}
// Float4 -> half2x2
template <>
__inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a)
{
uint2 b;
float2 val;
val.x = a.x.x;
val.y = a.x.y;
b.x = vec_conversion<uint32_t, float2>(val);
val.x = a.y.x;
val.y = a.y.y;
b.y = vec_conversion<uint32_t, float2>(val);
return b;
}
// Float4 -> float4
template <>
__inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a)
{
float4 b;
b.x = a.x.x;
b.y = a.x.y;
b.z = a.y.x;
b.w = a.y.y;
return b;
}
// Float8 -> half2x4
template <>
__inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a)
{
uint4 b;
b.x = vec_conversion<uint32_t, float2>(a.x);
b.y = vec_conversion<uint32_t, float2>(a.y);
b.z = vec_conversion<uint32_t, float2>(a.z);
b.w = vec_conversion<uint32_t, float2>(a.w);
return b;
}
// float2 -> bfloat162
template <>
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(const float2& a)
{
__nv_bfloat162 b = __float22bfloat162_rn(a);
return b;
}
// Float4 -> bfloat162x2
template <>
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, Float4_>(const Float4_& a)
{
bf16_4_t b;
b.x = __float22bfloat162_rn(a.x);
b.y = __float22bfloat162_rn(a.y);
return b;
}
// Float8 -> bfloat162x4
template <>
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, Float8_>(const Float8_& a)
{
bf16_8_t b;
b.x = __float22bfloat162_rn(a.x);
b.y = __float22bfloat162_rn(a.y);
b.z = __float22bfloat162_rn(a.z);
b.w = __float22bfloat162_rn(a.w);
return b;
}
/* Scaled and vectorized conversions, for data exchange between high and low precision domains
Convention of the scale in API, e.g: FP8_data = Quantization( High_Precision_data / scale )
s.t.
Quantize(HP / scale) => FP8
Dequant(FP8) * scale => HP
*/
// fp8 -> half
template <>
__inline__ __device__ uint16_t scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, const float scale)
{
hip_fp8 f8{a, hip_fp8::from_bits()};
__half_raw res;
res.data = static_cast<float>(f8) * scale;
return res.x;
}
// fp8x2 -> half2
template <>
__inline__ __device__ uint32_t scaled_vec_conversion<uint32_t, uint16_t>(const uint16_t& a, const float scale)
{
#if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
union {
__half2_raw h2r;
uint32_t ui32;
} tmp;
tmp.h2r.x.data = f2[0] * scale;
tmp.h2r.y.data = f2[1] * scale;
return tmp.ui32;
#else
union {
uint16_t u16[2];
uint32_t u32;
} tmp;
tmp.u16[0] = scaled_vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a), scale);
tmp.u16[1] = scaled_vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a >> 8U), scale);
return tmp.u32;
#endif
}
// fp8x4 -> half2x2
template <>
__inline__ __device__ uint2 scaled_vec_conversion<uint2, uint32_t>(const uint32_t& a, const float scale)
{
union {
uint2 u32x2;
uint32_t u32[2];
} tmp;
tmp.u32[0] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale);
tmp.u32[1] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), scale);
return tmp.u32x2;
}
// fp8x8 -> half2x4
template <>
__inline__ __device__ uint4 scaled_vec_conversion<uint4, uint2>(const uint2& a, const float scale)
{
union {
uint4 u64x2;
uint2 u64[2];
} tmp;
tmp.u64[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale);
tmp.u64[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale);
return tmp.u64x2;
}
using __nv_bfloat16 = __hip_bfloat16;
// fp8 -> __nv_bfloat16
template <>
__inline__ __device__ __nv_bfloat16 scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, const float scale)
{
hip_fp8 f8{a, hip_fp8::from_bits()};
float f{f8};
return __float2bfloat16(f * scale);
}
using __nv_bfloat162 = __hip_bfloat162;
// fp8x2 -> __nv_bfloat162
template <>
__inline__ __device__ __nv_bfloat162 scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a, const float scale)
{
__nv_bfloat162 res;
res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale);
res.y = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), scale);
return res;
}
// fp8x4 -> bf16_4_t
template <>
__inline__ __device__ bf16_4_t scaled_vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a, const float scale)
{
bf16_4_t res;
res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale);
res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U), scale);
return res;
}
// fp8x8 -> bf16_8_t
template <>
__inline__ __device__ bf16_8_t scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, const float scale)
{
bf16_4_t tmp1, tmp2;
tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale);
tmp2 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale);
bf16_8_t res;
res.x = tmp1.x;
res.y = tmp1.y;
res.z = tmp2.x;
res.w = tmp2.y;
return res;
}
// fp8 -> float
template <>
__inline__ __device__ float scaled_vec_conversion<float, uint8_t>(const uint8_t& a, const float scale)
{
hip_fp8 fp8{a, hip_fp8::from_bits()};
return static_cast<float>(fp8) * scale;
}
// fp8x2 -> float2
template <>
__inline__ __device__ float2 scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, const float scale)
{
#if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
float2 res;
const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
res.x = f2[0] * scale;
res.y = f2[1] * scale;
return res;
#else
float2 res;
res.x = scaled_vec_conversion<float, uint8_t>(static_cast<uint8_t>(a), scale);
res.y = scaled_vec_conversion<float, uint8_t>(static_cast<uint8_t>(a >> 8U), scale);
return res;
#endif
}
// fp8x4 -> float4
template <>
__inline__ __device__ Float4_ scaled_vec_conversion<Float4_, uint32_t>(const uint32_t& a, const float scale)
{
Float4_ res;
res.x = scaled_vec_conversion<float2, uint16_t>((uint16_t)a, scale);
res.y = scaled_vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), scale);
return res;
}
// fp8x8 -> float8
template <>
__inline__ __device__ Float8_ scaled_vec_conversion<Float8_, uint2>(const uint2& a, const float scale)
{
Float4_ tmp1, tmp2;
tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale);
tmp2 = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale);
Float8_ res;
res.x = tmp1.x;
res.y = tmp1.y;
res.z = tmp2.x;
res.w = tmp2.y;
return res;
}
/* Quantize(HP / scale) => FP8 */
// TODO(Hai): vectorized to add
// half -> fp8
template <>
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, const float scale)
{
__half_raw tmp;
tmp.x = a;
hip_fp8 f8{static_cast<float>(tmp.data)/scale};
return f8.data;
}
// bf16 -> fp8
template <>
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a, const float scale)
{
hip_fp8 res{__bfloat162float(a)/scale};
return res.data;
}
// float -> fp8
template <>
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, float>(const float& a, const float scale)
{
hip_fp8 f8(a/scale);
return f8.data;
}
// fp8x4 -> float4
template <>
__inline__ __device__ float4 scaled_vec_conversion<float4, uint32_t>(const uint32_t& a, const float scale)
{
Float4_ tmp = scaled_vec_conversion<Float4_, uint32_t>(a, scale);
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
return res;
}
}
} // namespace vllm
......@@ -91,7 +91,8 @@ Documentation
:caption: Quantization
quantization/auto_awq
quantization/fp8_e5m2_kv_cache
quantization/fp8_e5m2_kvcache
quantization/fp8_e4m3_kvcache
.. toctree::
:maxdepth: 2
......
.. _fp8_e4m3_kvcache:
FP8 E4M3 KV Cache
==================
Quantizing the KV cache to FP8 reduces its memory footprint. This increases the number of tokens that can be stored in the cache,
improving throughput. OCP (Open Compute Project www.opencompute.org) specifies two common 8-bit floating point data formats: E5M2
(5 exponent bits and 2 mantissa bits) and E4M3FN (4 exponent bits and 3 mantissa bits), often shortened as E4M3. One benefit of
the E4M3 format over E5M2 is that floating point numbers are represented in higher precision. However, the small dynamic range of
FP8 E4M3 (±240.0 can be represented) typically necessitates the use of a higher-precision (typically FP32) scaling factor alongside
each quantized tensor. For now, only per-tensor (scalar) scaling factors are supported. Development is ongoing to support scaling
factors of a finer granularity (e.g. per-channel).
These scaling factors can be specified by passing an optional quantization param JSON to the LLM engine at load time. If
this JSON is not specified, scaling factors default to 1.0. These scaling factors are typically obtained when running an
unquantized model through a quantizer tool (e.g. AMD quantizer or NVIDIA AMMO).
To install AMMO (AlgorithMic Model Optimization):
.. code-block:: console
$ pip install --no-cache-dir --extra-index-url https://pypi.nvidia.com nvidia-ammo
Studies have shown that FP8 E4M3 quantization typically only minimally degrades inference accuracy. The most recent silicon
offerings e.g. AMD MI300, NVIDIA Hopper or later support native hardware conversion to and from fp32, fp16, bf16, etc.
Thus, LLM inference is greatly accelerated with minimal accuracy loss.
Here is an example of how to enable this feature:
.. code-block:: python
# two float8_e4m3fn kv cache scaling factor files are provided under tests/fp8_kv, please refer to
# https://github.com/vllm-project/vllm/blob/main/examples/fp8/README.md to generate kv_cache_scales.json of your own.
from vllm import LLM, SamplingParams
sampling_params = SamplingParams(temperature=1.3, top_p=0.8)
llm = LLM(model="meta-llama/Llama-2-7b-chat-hf",
kv_cache_dtype="fp8",
quantization_param_path="./tests/fp8_kv/llama2-7b-fp8-kv/kv_cache_scales.json")
prompt = "London is the capital of"
out = llm.generate(prompt, sampling_params)[0].outputs[0].text
print(out)
# output w/ scaling factors: England, the United Kingdom, and one of the world's leading financial,
# output w/o scaling factors: England, located in the southeastern part of the country. It is known
Note, current prefix caching doesn't work with FP8 KV cache enabled, forward_prefix kernel should handle different KV and cache type.
.. _fp8_e5m2_kv_cache:
.. _fp8_kv_cache:
FP8 E5M2 KV Cache
==================
......@@ -21,7 +21,7 @@ Here is an example of how to enable this feature:
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
# Create an LLM.
llm = LLM(model="facebook/opt-125m", kv_cache_dtype="fp8_e5m2")
llm = LLM(model="facebook/opt-125m", kv_cache_dtype="fp8")
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
......@@ -31,3 +31,6 @@ Here is an example of how to enable this feature:
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
Note, current prefix caching doesn't work with FP8 KV cache enabled, forward_prefix kernel should handle different KV and cache type.
# FP8 KV Cache
This utility extracts the KV cache scaling factors from a quantized HF (Hugging Face) model. The extracted scaling factors are saved to a JSON file, which can later be used by vLLM (variable-length language model) during runtime. This tool is particularly useful when the KV cache data type is FP8 and is intended for use on ROCm (AMD GPU) platforms.
## Prerequisites
- Python 3.x
- PyTorch
- NumPy
- Hugging Face Transformers
- Hugging Face Hub
- AMMO
Before incorporating the FP8 datatype for inference workloads, you must adhere to the following steps:
1. Install all necessary prerequisites and dependencies.
2. Convert HF model into a quantized HF model.
3. Extract KV Cache Scaling Factors from quantized HF model.
4. Load KV Cache Scaling Factors into VLLM.
### 2. Convert HF model into a quantized HF model.
Note: The following steps are adapted from the [TensorRT-LLM repository](https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/quantization/README.md).
`quantize.py` (examples/fp8/quantizer/quantize.py) uses the quantization toolkit (AMMO) to calibrate the PyTorch models and export TensorRT-LLM checkpoints. Each TensorRT-LLM checkpoint contains a config file (in .json format) and one or several rank weight files (in .safetensors format).
The detailed quantization toolkit (AMMO) conversion guide for FP8 can be found at `examples/fp8/quantizer/README.md`.
### 3. Extract KV Cache Scaling Factors from quantized HF model.
`extract_scales.py` (examples/fp8/extract_scales.py) can be utilized to extract the KV cache scaling factors from your quantized HF model, however at the moment, this tool exclusively supports Llama 2 models. It is also important to note the following:
1. **File Structure**: The utility operates under the assumption that all parameters, including KV cache scaling factors, corresponding to a particular Tensor Parallelism (TP) rank are stored in a single file. These files must adhere to a specific naming convention where the TP rank is immediately identified after a specific keyword (e.g., "rank") in the filename.
2. **TP Decomposition**: The utility assumes consistency between the TP decomposition employed by the quantizer tool and that used by vLLM.
3. **AMMO Compatibility**: Currently, the generated KV cache scaling factors for AMMO remain uniform across all TP ranks.
```python
# prerequisites:
# - Quantized HF LLaMa 2 model
python3 examples/fp8/extract_scales.py --help
Usage: extract_scales.py [-h] --quantized_model QUANTIZED_MODEL [--load_format {auto,safetensors,npz,pt}] [--output_dir OUTPUT_DIR] [--output_name OUTPUT_NAME] [--tp_size TP_SIZE]
KV Scale Extraction Example
optional arguments:
--quantized_model: Specify either the local path to, or name of, a quantized HF model. It is expected that the quantization format is FP8_E4M3, for use on ROCm (AMD GPU).
Optional arguments:
--cache_dir: Specify a cache directory to use in the event of a HF model download. (Default: None)
--load_format: Specify the format of the model's tensor files containing the KV cache scaling factors. (Choices: auto, safetensors, npz, pt; Default: auto)
--revision: Specify the model's revision number. (Default: None)
--output_dir: Specify the output directory. By default the KV cache scaling factors will be saved in the model directory. (Default: None)
--output_name: Specify the output filename. (Default: kv_cache_scales.json)
--tp_size: Specify the tensor-parallel (TP) size that the quantized model should correspond to. If specified, during KV cache scaling factor extraction the observed TP size will be checked against this and an error will be raised if there is a mismatch. (Default: None)
```
```python
Example:
python3 examples/fp8/extract_scales.py --quantized_model <QUANTIZED_MODEL_DIR> --tp_size <TENSOR_PARALLEL_SIZE> --output_dir <PATH_TO_OUTPUT_DIR>
```
### 4. Load KV Cache Scaling Factors into VLLM.
This script evaluates the inference throughput of language models using various backends such as vLLM. It measures the time taken to process a given number of prompts and generate sequences for each prompt. The recently generated KV cache scaling factors are now integrated into the benchmarking process and allow for KV cache scaling factors to be utilized for FP8.
```python
# prerequisites:
# - LLaMa 2 kv_cache_scales.json file
python3 benchmarks/benchmark_throughput.py --help
usage: benchmark_throughput.py [-h] [--backend {vllm,hf,mii}] [--dataset DATASET] [--input-len INPUT_LEN] [--output-len OUTPUT_LEN] [--model MODEL]
[--tokenizer TOKENIZER] [--quantization {awq,gptq,squeezellm,None}] [--tensor-parallel-size TENSOR_PARALLEL_SIZE] [--n N]
[--use-beam-search] [--num-prompts NUM_PROMPTS] [--seed SEED] [--hf-max-batch-size HF_MAX_BATCH_SIZE] [--trust-remote-code]
[--max-model-len MAX_MODEL_LEN] [--dtype {auto,half,float16,bfloat16,float,float32}] [--enforce-eager] [--kv-cache-dtype {auto,fp8}]
[--quantization-param-path KV_CACHE_quantization_param_path]
Benchmark Throughput Example
optional arguments:
-h, --help show this help message and exit
--backend {vllm,hf,mii}
--dataset DATASET Path to the dataset.
--input-len INPUT_LEN Input prompt length for each request
--output-len OUTPUT_LEN Output length for each request. Overrides the output length from the dataset.
--model MODEL
--tokenizer TOKENIZER
--quantization {awq,gptq,squeezellm,None}, -q {awq,gptq,squeezellm,None}
--tensor-parallel-size TENSOR_PARALLEL_SIZE, -tp TENSOR_PARALLEL_SIZE
--n N Number of generated sequences per prompt.
--use-beam-search
--num-prompts NUM_PROMPTS Number of prompts to process.
--seed SEED
--hf-max-batch-size HF_MAX_BATCH_SIZE Maximum batch size for HF backend.
--trust-remote-code trust remote code from huggingface
--max-model-len MAX_MODEL_LEN Maximum length of a sequence (including prompt and output). If None, will be derived from the model.
--dtype {auto,half,float16,bfloat16,float,float32} data type for model weights and activations. The "auto" option will use FP16 precision for FP32 and FP16 models, and BF16 precision for BF16 models.
--enforce-eager enforce eager execution
--kv-cache-dtype {auto,fp8} Data type for kv cache storage. If "auto", will use model data type. FP8_E5M2 (without scaling) is only supported on cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported ```for common inference criteria.
--quantization-param-path QUANT_PARAM_JSON Path to the JSON file containing the KV cache scaling factors. This should generally be supplied, when KV cache dtype is FP8. Otherwise, KV cache scaling factors default to 1.0, which may cause accuracy issues. FP8_E5M2 (without scaling) is only supported on cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for common inference criteria.
```
```
Example:
python3 benchmarks/benchmark_throughput.py --input-len <INPUT_LEN> --output-len <OUTPUT_LEN> -tp <TENSOR_PARALLEL_SIZE> --kv-cache-dtype fp8 --quantization-param-path <path/to/kv_cache_scales.json> --model <path-to-llama2>
```python
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment