Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c8331017
Unverified
Commit
c8331017
authored
May 09, 2024
by
Cody Yu
Committed by
GitHub
May 09, 2024
Browse files
[Kernel] Refactor FP8 kv-cache with NVIDIA float8_e4m3 support (#4535)
parent
379da6dc
Changes
17
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
843 additions
and
558 deletions
+843
-558
.buildkite/check-wheel-size.py
.buildkite/check-wheel-size.py
+1
-1
CMakeLists.txt
CMakeLists.txt
+1
-1
cmake/utils.cmake
cmake/utils.cmake
+2
-2
csrc/attention/attention_kernels.cu
csrc/attention/attention_kernels.cu
+110
-176
csrc/attention/dtype_fp8.cuh
csrc/attention/dtype_fp8.cuh
+11
-5
csrc/cache.h
csrc/cache.h
+3
-1
csrc/cache_kernels.cu
csrc/cache_kernels.cu
+70
-73
csrc/quantization/fp8/amd/hip_float8.h
csrc/quantization/fp8/amd/hip_float8.h
+0
-0
csrc/quantization/fp8/amd/hip_float8_impl.h
csrc/quantization/fp8/amd/hip_float8_impl.h
+0
-0
csrc/quantization/fp8/amd/quant_utils.cuh
csrc/quantization/fp8/amd/quant_utils.cuh
+58
-1
csrc/quantization/fp8/common.cu
csrc/quantization/fp8/common.cu
+0
-0
csrc/quantization/fp8/nvidia/quant_utils.cuh
csrc/quantization/fp8/nvidia/quant_utils.cuh
+568
-0
csrc/quantization/fp8_e5m2_kvcache/quant_utils.cuh
csrc/quantization/fp8_e5m2_kvcache/quant_utils.cuh
+0
-277
tests/kernels/test_attention.py
tests/kernels/test_attention.py
+2
-2
tests/kernels/test_cache.py
tests/kernels/test_cache.py
+11
-16
vllm/_custom_ops.py
vllm/_custom_ops.py
+5
-2
vllm/utils.py
vllm/utils.py
+1
-1
No files found.
.buildkite/check-wheel-size.py
View file @
c8331017
import
os
import
os
import
zipfile
import
zipfile
MAX_SIZE_MB
=
1
0
0
MAX_SIZE_MB
=
1
5
0
def
print_top_10_largest_files
(
zip_file
):
def
print_top_10_largest_files
(
zip_file
):
...
...
CMakeLists.txt
View file @
c8331017
...
@@ -167,7 +167,7 @@ set(VLLM_EXT_SRC
...
@@ -167,7 +167,7 @@ set(VLLM_EXT_SRC
"csrc/layernorm_kernels.cu"
"csrc/layernorm_kernels.cu"
"csrc/quantization/squeezellm/quant_cuda_kernel.cu"
"csrc/quantization/squeezellm/quant_cuda_kernel.cu"
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/fp8/
fp8_cuda_kernels
.cu"
"csrc/quantization/fp8/
common
.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/moe_align_block_size_kernels.cu"
"csrc/moe_align_block_size_kernels.cu"
"csrc/pybind.cpp"
)
"csrc/pybind.cpp"
)
...
...
cmake/utils.cmake
View file @
c8331017
...
@@ -99,7 +99,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
...
@@ -99,7 +99,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
"Failed to determine torch nvcc compiler flags"
)
"Failed to determine torch nvcc compiler flags"
)
if
(
CUDA_VERSION VERSION_GREATER_EQUAL 11.8
)
if
(
CUDA_VERSION VERSION_GREATER_EQUAL 11.8
)
list
(
APPEND GPU_FLAGS
"-DENABLE_FP8
_E5M2
"
)
list
(
APPEND GPU_FLAGS
"-DENABLE_FP8"
)
endif
()
endif
()
if
(
CUDA_VERSION VERSION_GREATER_EQUAL 12.0
)
if
(
CUDA_VERSION VERSION_GREATER_EQUAL 12.0
)
list
(
REMOVE_ITEM GPU_FLAGS
list
(
REMOVE_ITEM GPU_FLAGS
...
@@ -119,7 +119,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
...
@@ -119,7 +119,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
list
(
APPEND GPU_FLAGS
list
(
APPEND GPU_FLAGS
"-DUSE_ROCM"
"-DUSE_ROCM"
"-DENABLE_FP8
_E4M3
"
"-DENABLE_FP8"
"-U__HIP_NO_HALF_CONVERSIONS__"
"-U__HIP_NO_HALF_CONVERSIONS__"
"-U__HIP_NO_HALF_OPERATORS__"
"-U__HIP_NO_HALF_OPERATORS__"
"-fno-gpu-rdc"
)
"-fno-gpu-rdc"
)
...
...
csrc/attention/attention_kernels.cu
View file @
c8331017
This diff is collapsed.
Click to expand it.
csrc/attention/dtype_fp8.cuh
View file @
c8331017
...
@@ -3,14 +3,21 @@
...
@@ -3,14 +3,21 @@
#include "attention_generic.cuh"
#include "attention_generic.cuh"
#include <stdint.h>
#include <stdint.h>
#ifdef ENABLE_FP8_E5M2
#ifdef ENABLE_FP8
#ifndef USE_ROCM
#include <cuda_fp8.h>
#include <cuda_fp8.h>
#endif
#endif // USE_ROCM
#endif // ENABLE_FP8
namespace
vllm
{
namespace
vllm
{
#if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3)
// fp8 vector types for quantization of kv cache
enum
class
Fp8KVCacheDataType
{
kAuto
=
0
,
kFp8E4M3
=
1
,
kFp8E5M2
=
2
,
};
// fp8 vector types for quantization of kv cache
template
<
>
template
<
>
struct
Vec
<
uint8_t
,
1
>
{
struct
Vec
<
uint8_t
,
1
>
{
using
Type
=
uint8_t
;
using
Type
=
uint8_t
;
...
@@ -30,6 +37,5 @@ template<>
...
@@ -30,6 +37,5 @@ template<>
struct
Vec
<
uint8_t
,
8
>
{
struct
Vec
<
uint8_t
,
8
>
{
using
Type
=
uint2
;
using
Type
=
uint2
;
};
};
#endif // ENABLE_FP8_E5M2
}
// namespace vllm
}
// namespace vllm
csrc/cache.h
View file @
c8331017
...
@@ -34,5 +34,7 @@ void reshape_and_cache_flash(
...
@@ -34,5 +34,7 @@ void reshape_and_cache_flash(
// Just for unittest
// Just for unittest
void
convert_fp8
(
void
convert_fp8
(
torch
::
Tensor
&
dst_cache
,
torch
::
Tensor
&
src_cache
,
torch
::
Tensor
&
src_cache
,
torch
::
Tensor
&
dst_cache
);
const
float
scale
,
const
std
::
string
&
kv_cache_dtype
);
csrc/cache_kernels.cu
View file @
c8331017
...
@@ -4,10 +4,11 @@
...
@@ -4,10 +4,11 @@
#include "cuda_compat.h"
#include "cuda_compat.h"
#include "dispatch_utils.h"
#include "dispatch_utils.h"
#if defined(ENABLE_FP8_E5M2)
#include "quantization/fp8_e5m2_kvcache/quant_utils.cuh"
#ifdef USE_ROCM
#elif defined(ENABLE_FP8_E4M3)
#include "quantization/fp8/amd/quant_utils.cuh"
#include "quantization/fp8/amd_detail/quant_utils.cuh"
#else
#include "quantization/fp8/nvidia/quant_utils.cuh"
#endif
#endif
#include <algorithm>
#include <algorithm>
...
@@ -149,7 +150,7 @@ void copy_blocks(
...
@@ -149,7 +150,7 @@ void copy_blocks(
namespace
vllm
{
namespace
vllm
{
template
<
typename
scalar_t
,
typename
cache_t
,
bool
is_fp8_kv_cache
>
template
<
typename
scalar_t
,
typename
cache_t
,
Fp8KVCacheDataType
kv_dt
>
__global__
void
reshape_and_cache_kernel
(
__global__
void
reshape_and_cache_kernel
(
const
scalar_t
*
__restrict__
key
,
// [num_tokens, num_heads, head_size]
const
scalar_t
*
__restrict__
key
,
// [num_tokens, num_heads, head_size]
const
scalar_t
*
__restrict__
value
,
// [num_tokens, num_heads, head_size]
const
scalar_t
*
__restrict__
value
,
// [num_tokens, num_heads, head_size]
...
@@ -194,19 +195,12 @@ __global__ void reshape_and_cache_kernel(
...
@@ -194,19 +195,12 @@ __global__ void reshape_and_cache_kernel(
+
block_offset
;
+
block_offset
;
scalar_t
tgt_key
=
key
[
src_key_idx
];
scalar_t
tgt_key
=
key
[
src_key_idx
];
scalar_t
tgt_value
=
value
[
src_value_idx
];
scalar_t
tgt_value
=
value
[
src_value_idx
];
if
constexpr
(
is_fp8_kv_cache
)
{
if
constexpr
(
kv_dt
==
Fp8KVCacheDataType
::
kAuto
)
{
#if defined(ENABLE_FP8_E5M2)
key_cache
[
tgt_key_idx
]
=
fp8_e5m2_unscaled
::
vec_conversion
<
uint8_t
,
scalar_t
>
(
tgt_key
);
value_cache
[
tgt_value_idx
]
=
fp8_e5m2_unscaled
::
vec_conversion
<
uint8_t
,
scalar_t
>
(
tgt_value
);
#elif defined(ENABLE_FP8_E4M3)
key_cache
[
tgt_key_idx
]
=
fp8_e4m3
::
scaled_vec_conversion
<
uint8_t
,
scalar_t
>
(
tgt_key
,
kv_scale
);
value_cache
[
tgt_value_idx
]
=
fp8_e4m3
::
scaled_vec_conversion
<
uint8_t
,
scalar_t
>
(
tgt_value
,
kv_scale
);
#else
assert
(
false
);
#endif
}
else
{
key_cache
[
tgt_key_idx
]
=
tgt_key
;
key_cache
[
tgt_key_idx
]
=
tgt_key
;
value_cache
[
tgt_value_idx
]
=
tgt_value
;
value_cache
[
tgt_value_idx
]
=
tgt_value
;
}
else
{
key_cache
[
tgt_key_idx
]
=
fp8
::
scaled_convert
<
cache_t
,
scalar_t
,
kv_dt
>
(
tgt_key
,
kv_scale
);
value_cache
[
tgt_value_idx
]
=
fp8
::
scaled_convert
<
cache_t
,
scalar_t
,
kv_dt
>
(
tgt_value
,
kv_scale
);
}
}
}
}
}
}
...
@@ -248,19 +242,22 @@ __global__ void reshape_and_cache_flash_kernel(
...
@@ -248,19 +242,22 @@ __global__ void reshape_and_cache_flash_kernel(
}
}
}
// namespace vllm
}
// namespace vllm
#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_KV_CACHE) \
// KV_T is the stored data type of kv-cache.
vllm::reshape_and_cache_kernel<KV_T, CACHE_T, IS_FP8_KV_CACHE><<<grid, block, 0, stream>>>( \
// CACHE_T is the data type of key and value tensors.
reinterpret_cast<KV_T*>(key.data_ptr()), \
// KV_DTYPE is the real data type of kv-cache.
reinterpret_cast<KV_T*>(value.data_ptr()), \
#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
vllm::reshape_and_cache_kernel<KV_T, CACHE_T, KV_DTYPE><<<grid, block, 0, stream>>>( \
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
reinterpret_cast<KV_T*>(key.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), \
reinterpret_cast<KV_T*>(value.data_ptr()), \
key_stride, \
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
value_stride, \
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
num_heads, \
slot_mapping.data_ptr<int64_t>(), \
head_size, \
key_stride, \
block_size, \
value_stride, \
x, \
num_heads, \
head_size, \
block_size, \
x, \
kv_scale);
kv_scale);
void
reshape_and_cache
(
void
reshape_and_cache
(
...
@@ -285,25 +282,8 @@ void reshape_and_cache(
...
@@ -285,25 +282,8 @@ void reshape_and_cache(
dim3
block
(
std
::
min
(
num_heads
*
head_size
,
512
));
dim3
block
(
std
::
min
(
num_heads
*
head_size
,
512
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
key
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
key
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
if
(
kv_cache_dtype
==
"auto"
)
{
if
(
key
.
dtype
()
==
at
::
ScalarType
::
Float
)
{
DISPATCH_BY_KV_CACHE_DTYPE
(
key
.
dtype
(),
kv_cache_dtype
,
CALL_RESHAPE_AND_CACHE
)
CALL_RESHAPE_AND_CACHE
(
float
,
float
,
false
);
}
else
if
(
key
.
dtype
()
==
at
::
ScalarType
::
Half
)
{
CALL_RESHAPE_AND_CACHE
(
uint16_t
,
uint16_t
,
false
);
}
else
if
(
key
.
dtype
()
==
at
::
ScalarType
::
BFloat16
)
{
CALL_RESHAPE_AND_CACHE
(
__nv_bfloat16
,
__nv_bfloat16
,
false
);
}
}
else
if
(
kv_cache_dtype
==
"fp8"
)
{
if
(
key
.
dtype
()
==
at
::
ScalarType
::
Float
)
{
CALL_RESHAPE_AND_CACHE
(
float
,
uint8_t
,
true
);
}
else
if
(
key
.
dtype
()
==
at
::
ScalarType
::
Half
)
{
CALL_RESHAPE_AND_CACHE
(
uint16_t
,
uint8_t
,
true
);
}
else
if
(
key
.
dtype
()
==
at
::
ScalarType
::
BFloat16
)
{
CALL_RESHAPE_AND_CACHE
(
__nv_bfloat16
,
uint8_t
,
true
);
}
}
else
{
TORCH_CHECK
(
false
,
"Unsupported data type of kv cache: "
,
kv_cache_dtype
);
}
}
}
void
reshape_and_cache_flash
(
void
reshape_and_cache_flash
(
...
@@ -353,35 +333,34 @@ void reshape_and_cache_flash(
...
@@ -353,35 +333,34 @@ void reshape_and_cache_flash(
namespace
vllm
{
namespace
vllm
{
template
<
typename
Tout
,
typename
Tin
>
template
<
typename
Tout
,
typename
Tin
,
Fp8KVCacheDataType
kv_dt
>
__global__
void
convert_fp8_kernel
(
__global__
void
convert_fp8_kernel
(
const
Tin
*
__restrict__
src_cache
,
const
Tin
*
__restrict__
src_cache
,
Tout
*
__restrict__
dst_cache
,
Tout
*
__restrict__
dst_cache
,
const
float
kv_scale
,
const
int64_t
block_stride
)
{
const
int64_t
block_stride
)
{
const
int64_t
block_idx
=
blockIdx
.
x
;
const
int64_t
block_idx
=
blockIdx
.
x
;
for
(
int
i
=
threadIdx
.
x
;
i
<
block_stride
;
i
+=
blockDim
.
x
)
{
for
(
int
i
=
threadIdx
.
x
;
i
<
block_stride
;
i
+=
blockDim
.
x
)
{
int64_t
idx
=
block_idx
*
block_stride
+
i
;
int64_t
idx
=
block_idx
*
block_stride
+
i
;
#if defined(ENABLE_FP8_E5M2)
dst_cache
[
idx
]
=
fp8
::
scaled_convert
<
Tout
,
Tin
,
kv_dt
>
(
src_cache
[
idx
],
kv_scale
);
dst_cache
[
idx
]
=
fp8_e5m2_unscaled
::
vec_conversion
<
Tout
,
Tin
>
(
src_cache
[
idx
]);
#elif defined(ENABLE_FP8_E4M3)
dst_cache
[
idx
]
=
fp8_e4m3
::
vec_conversion
<
Tout
,
Tin
>
(
src_cache
[
idx
]);
#else
assert
(
false
);
#endif
}
}
}
}
}
// namespace vllm
}
// namespace vllm
#define CALL_CONVERT_FP8(Tout, Tin) \
#define CALL_CONVERT_FP8(Tout, Tin, KV_DTYPE) \
vllm::convert_fp8_kernel<Tout, Tin><<<grid, block, 0, stream>>>( \
vllm::convert_fp8_kernel<Tout, Tin, KV_DTYPE><<<grid, block, 0, stream>>>( \
reinterpret_cast<Tin*>(src_cache.data_ptr()), \
reinterpret_cast<Tin*>(src_cache.data_ptr()), \
reinterpret_cast<Tout*>(dst_cache.data_ptr()), \
reinterpret_cast<Tout*>(dst_cache.data_ptr()), \
kv_scale, \
block_stride);
block_stride);
// Only for testing.
void
convert_fp8
(
void
convert_fp8
(
torch
::
Tensor
&
dst_cache
,
torch
::
Tensor
&
src_cache
,
torch
::
Tensor
&
src_cache
,
torch
::
Tensor
&
dst_cache
)
const
float
kv_scale
,
const
std
::
string
&
kv_cache_dtype
)
{
{
torch
::
Device
src_device
=
src_cache
.
device
();
torch
::
Device
src_device
=
src_cache
.
device
();
torch
::
Device
dst_device
=
dst_cache
.
device
();
torch
::
Device
dst_device
=
dst_cache
.
device
();
...
@@ -399,17 +378,35 @@ void convert_fp8(
...
@@ -399,17 +378,35 @@ void convert_fp8(
dim3
block
(
std
::
min
(
block_stride
,
int64_t
(
512
)));
dim3
block
(
std
::
min
(
block_stride
,
int64_t
(
512
)));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
if
(
src_cache
.
dtype
()
==
at
::
ScalarType
::
Float
)
{
if
(
kv_cache_dtype
==
"auto"
)
{
CALL_CONVERT_FP8
(
uint8_t
,
float
);
if
(
src_cache
.
dtype
()
==
at
::
ScalarType
::
Float
)
{
}
else
if
(
src_cache
.
dtype
()
==
at
::
ScalarType
::
Half
)
{
CALL_CONVERT_FP8
(
uint8_t
,
float
,
vllm
::
Fp8KVCacheDataType
::
kAuto
);
CALL_CONVERT_FP8
(
uint8_t
,
uint16_t
);
}
else
if
(
src_cache
.
dtype
()
==
at
::
ScalarType
::
Half
)
{
}
else
if
(
src_cache
.
dtype
()
==
at
::
ScalarType
::
BFloat16
)
{
CALL_CONVERT_FP8
(
uint8_t
,
uint16_t
,
vllm
::
Fp8KVCacheDataType
::
kAuto
);
CALL_CONVERT_FP8
(
uint8_t
,
__nv_bfloat16
);
}
else
if
(
src_cache
.
dtype
()
==
at
::
ScalarType
::
BFloat16
)
{
}
else
if
(
dst_cache
.
dtype
()
==
at
::
ScalarType
::
Float
)
{
CALL_CONVERT_FP8
(
uint8_t
,
__nv_bfloat16
,
vllm
::
Fp8KVCacheDataType
::
kAuto
);
CALL_CONVERT_FP8
(
float
,
uint8_t
);
}
else
if
(
dst_cache
.
dtype
()
==
at
::
ScalarType
::
Float
)
{
}
else
if
(
dst_cache
.
dtype
()
==
at
::
ScalarType
::
Half
)
{
CALL_CONVERT_FP8
(
float
,
uint8_t
,
vllm
::
Fp8KVCacheDataType
::
kAuto
);
CALL_CONVERT_FP8
(
uint16_t
,
uint8_t
);
}
else
if
(
dst_cache
.
dtype
()
==
at
::
ScalarType
::
Half
)
{
}
else
if
(
dst_cache
.
dtype
()
==
at
::
ScalarType
::
BFloat16
)
{
CALL_CONVERT_FP8
(
uint16_t
,
uint8_t
,
vllm
::
Fp8KVCacheDataType
::
kAuto
);
CALL_CONVERT_FP8
(
__nv_bfloat16
,
uint8_t
);
}
else
if
(
dst_cache
.
dtype
()
==
at
::
ScalarType
::
BFloat16
)
{
CALL_CONVERT_FP8
(
__nv_bfloat16
,
uint8_t
,
vllm
::
Fp8KVCacheDataType
::
kAuto
);
}
}
else
if
(
kv_cache_dtype
==
"fp8"
||
kv_cache_dtype
==
"fp8_e4m3"
)
{
if
(
src_cache
.
dtype
()
==
at
::
ScalarType
::
Float
)
{
CALL_CONVERT_FP8
(
uint8_t
,
float
,
vllm
::
Fp8KVCacheDataType
::
kFp8E4M3
);
}
else
if
(
src_cache
.
dtype
()
==
at
::
ScalarType
::
Half
)
{
CALL_CONVERT_FP8
(
uint8_t
,
uint16_t
,
vllm
::
Fp8KVCacheDataType
::
kFp8E4M3
);
}
else
if
(
src_cache
.
dtype
()
==
at
::
ScalarType
::
BFloat16
)
{
CALL_CONVERT_FP8
(
uint8_t
,
__nv_bfloat16
,
vllm
::
Fp8KVCacheDataType
::
kFp8E4M3
);
}
else
if
(
dst_cache
.
dtype
()
==
at
::
ScalarType
::
Float
)
{
CALL_CONVERT_FP8
(
float
,
uint8_t
,
vllm
::
Fp8KVCacheDataType
::
kFp8E4M3
);
}
else
if
(
dst_cache
.
dtype
()
==
at
::
ScalarType
::
Half
)
{
CALL_CONVERT_FP8
(
uint16_t
,
uint8_t
,
vllm
::
Fp8KVCacheDataType
::
kFp8E4M3
);
}
else
if
(
dst_cache
.
dtype
()
==
at
::
ScalarType
::
BFloat16
)
{
CALL_CONVERT_FP8
(
__nv_bfloat16
,
uint8_t
,
vllm
::
Fp8KVCacheDataType
::
kFp8E4M3
);
}
}
else
{
TORCH_CHECK
(
false
,
"Unsupported data type: "
,
kv_cache_dtype
);
}
}
}
}
csrc/quantization/fp8/amd
_detail
/hip_float8.h
→
csrc/quantization/fp8/amd/hip_float8.h
View file @
c8331017
File moved
csrc/quantization/fp8/amd
_detail
/hip_float8_impl.h
→
csrc/quantization/fp8/amd/hip_float8_impl.h
View file @
c8331017
File moved
csrc/quantization/fp8/amd
_detail
/quant_utils.cuh
→
csrc/quantization/fp8/amd/quant_utils.cuh
View file @
c8331017
...
@@ -5,12 +5,17 @@
...
@@ -5,12 +5,17 @@
#include <hip/hip_bf16.h>
#include <hip/hip_bf16.h>
#include <hip/hip_bfloat16.h>
#include <hip/hip_bfloat16.h>
#include "../../../attention/dtype_fp8.cuh"
#include "../../../attention/dtype_float32.cuh"
#include "../../../attention/dtype_float32.cuh"
#include "../../../attention/dtype_bfloat16.cuh"
#include "../../../attention/dtype_bfloat16.cuh"
namespace
vllm
namespace
vllm
{
{
namespace
fp8_e4m3
{
#ifdef USE_ROCM
namespace
fp8
{
#ifdef ENABLE_FP8
template
<
typename
Tout
,
typename
Tin
>
template
<
typename
Tout
,
typename
Tin
>
__inline__
__device__
Tout
vec_conversion
(
const
Tin
&
x
)
__inline__
__device__
Tout
vec_conversion
(
const
Tin
&
x
)
{
{
...
@@ -512,6 +517,58 @@ __inline__ __device__ float4 scaled_vec_conversion<float4, uint32_t>(const uint3
...
@@ -512,6 +517,58 @@ __inline__ __device__ float4 scaled_vec_conversion<float4, uint32_t>(const uint3
float4
res
=
make_float4
(
tmp
.
x
.
x
,
tmp
.
x
.
y
,
tmp
.
y
.
x
,
tmp
.
y
.
y
);
float4
res
=
make_float4
(
tmp
.
x
.
x
,
tmp
.
x
.
y
,
tmp
.
y
.
x
,
tmp
.
y
.
y
);
return
res
;
return
res
;
}
}
#endif // ENABLE_FP8
template
<
typename
Tout
,
typename
Tin
,
Fp8KVCacheDataType
kv_dt
>
__inline__
__device__
Tout
convert
(
const
Tin
&
x
)
{
#ifdef ENABLE_FP8
if
constexpr
(
kv_dt
==
Fp8KVCacheDataType
::
kFp8E4M3
)
{
return
vec_conversion
<
Tout
,
Tin
>
(
x
);
}
#endif
assert
(
false
);
}
}
template
<
typename
Tout
,
typename
Tin
,
Fp8KVCacheDataType
kv_dt
>
__inline__
__device__
Tout
scaled_convert
(
const
Tin
&
x
,
const
float
scale
)
{
#ifdef ENABLE_FP8
if
constexpr
(
kv_dt
==
Fp8KVCacheDataType
::
kFp8E4M3
)
{
return
scaled_vec_conversion
<
Tout
,
Tin
>
(
x
,
scale
);
}
#endif
assert
(
false
);
}
// The following macro is used to dispatch the conversion function based on the
// data type of the key and value cache. The FN is a macro that calls a function
// with template<typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>.
#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \
if (KV_DTYPE == "auto") { \
if (SRC_DTYPE == at::ScalarType::Float) { \
FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); \
} else { \
TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else { \
if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \
if (SRC_DTYPE == at::ScalarType::Float) { \
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else { \
TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else { \
TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \
} \
}
}
// fp8
#endif // USE_ROCM
}
// namespace vllm
}
// namespace vllm
csrc/quantization/fp8/
fp8_cuda_kernels
.cu
→
csrc/quantization/fp8/
common
.cu
View file @
c8331017
File moved
csrc/quantization/fp8/nvidia/quant_utils.cuh
0 → 100644
View file @
c8331017
#pragma once
#include "../../../attention/attention_dtypes.h"
#include <assert.h>
#include <float.h>
#include <stdint.h>
#include <type_traits>
namespace
vllm
{
#ifndef USE_ROCM
namespace
fp8
{
#ifdef ENABLE_FP8
#if 0 // Disable the following code to reduce the binary size.
template <typename Tout, typename Tin>
__inline__ __device__ Tout
vec_conversion(const Tin &x, const __nv_fp8_interpretation_t fp8_type) {
return x;
}
// fp8 -> half
template <>
__inline__ __device__ uint16_t vec_conversion<uint16_t, uint8_t>(
const uint8_t &a, const __nv_fp8_interpretation_t fp8_type) {
__half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
return res.x;
}
// fp8x2 -> half2
template <>
__inline__ __device__ uint32_t vec_conversion<uint32_t, uint16_t>(
const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) {
union {
uint16_t u16[2];
uint32_t u32;
} tmp;
__half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, fp8_type);
tmp.u16[0] = res.x;
tmp.u16[1] = res.y;
return tmp.u32;
}
// fp8x4 -> half2x2
template <>
__inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(
const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) {
union {
uint2 u32x2;
uint32_t u32[2];
} tmp;
tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a, fp8_type);
tmp.u32[1] =
vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), fp8_type);
return tmp.u32x2;
}
// fp8x8 -> half2x4
template <>
__inline__ __device__ uint4 vec_conversion<uint4, uint2>(
const uint2 &a, const __nv_fp8_interpretation_t fp8_type) {
union {
uint4 u64x2;
uint2 u64[2];
} tmp;
tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x, fp8_type);
tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y, fp8_type);
return tmp.u64x2;
}
// fp8 -> __nv_bfloat16
template <>
__inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>(
const uint8_t &a, const __nv_fp8_interpretation_t fp8_type) {
// Note there is no direct convert function from fp8 to bf16.
// fp8 -> half
__half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
// half -> float -> bf16
float tmp = half_to_float(res.x);
return __float2bfloat16(tmp);
}
// fp8x2 -> __nv_bfloat162
template <>
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>(
const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) {
__nv_bfloat162 res;
res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, fp8_type);
res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), fp8_type);
return res;
}
// fp8x4 -> bf16_4_t
template <>
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, uint32_t>(
const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) {
bf16_4_t res;
res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, fp8_type);
res.y =
vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U), fp8_type);
return res;
}
// fp8x8 -> bf16_8_t
template <>
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(
const uint2 &a, const __nv_fp8_interpretation_t fp8_type) {
bf16_4_t tmp1, tmp2;
tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x, fp8_type);
tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y, fp8_type);
bf16_8_t res;
res.x = tmp1.x;
res.y = tmp1.y;
res.z = tmp2.x;
res.w = tmp2.y;
return res;
}
// fp8 -> float
template <>
__inline__ __device__ float
vec_conversion<float, uint8_t>(const uint8_t &a,
const __nv_fp8_interpretation_t fp8_type) {
// fp8 -> half
uint16_t tmp = vec_conversion<uint16_t, uint8_t>(a, fp8_type);
// half -> float
return half_to_float(tmp);
}
// fp8x2 -> float2
template <>
__inline__ __device__ float2 vec_conversion<float2, uint16_t>(
const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) {
// fp8x2 -> half2
uint32_t tmp = vec_conversion<uint32_t, uint16_t>(a, fp8_type);
// half2 -> float2
return half2_to_float2(tmp);
}
// fp8x4 -> float4
template <>
__inline__ __device__ Float4_ vec_conversion<Float4_, uint32_t>(
const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) {
Float4_ res;
res.x = vec_conversion<float2, uint16_t>((uint16_t)a, fp8_type);
res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), fp8_type);
return res;
}
// fp8x8 -> float8
template <>
__inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(
const uint2 &a, const __nv_fp8_interpretation_t fp8_type) {
Float4_ tmp1, tmp2;
tmp1 = vec_conversion<Float4_, uint32_t>(a.x, fp8_type);
tmp2 = vec_conversion<Float4_, uint32_t>(a.y, fp8_type);
Float8_ res;
res.x = tmp1.x;
res.y = tmp1.y;
res.z = tmp2.x;
res.w = tmp2.y;
return res;
}
// half -> fp8
template <>
__inline__ __device__ uint8_t vec_conversion<uint8_t, uint16_t>(
const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) {
__half_raw tmp;
tmp.x = a;
__nv_fp8_storage_t res =
__nv_cvt_halfraw_to_fp8(tmp, __NV_SATFINITE, fp8_type);
return (uint8_t)res;
}
// bf16 -> fp8
template <>
__inline__ __device__ uint8_t vec_conversion<uint8_t, __nv_bfloat16>(
const __nv_bfloat16 &a, const __nv_fp8_interpretation_t fp8_type) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
#else
__nv_fp8_storage_t res = __nv_cvt_bfloat16raw_to_fp8(
__nv_bfloat16_raw(a), __NV_SATFINITE, fp8_type);
return (uint8_t)res;
#endif
}
// float -> fp8
template
<
>
__inline__
__device__
uint8_t
vec_conversion
<
uint8_t
,
float
>
(
const
float
&
a
,
const
__nv_fp8_interpretation_t
fp8_type
)
{
__nv_fp8_storage_t
res
=
__nv_cvt_float_to_fp8
(
a
,
__NV_SATFINITE
,
fp8_type
);
return
(
uint8_t
)
res
;
}
// fp8x4 -> float4
template
<
>
__inline__
__device__
float4
vec_conversion
<
float4
,
uint32_t
>
(
const
uint32_t
&
a
,
const
__nv_fp8_interpretation_t
fp8_type
)
{
Float4_
tmp
=
vec_conversion
<
Float4_
,
uint32_t
>
(
a
,
fp8_type
);
float4
res
=
make_float4
(
tmp
.
x
.
x
,
tmp
.
x
.
y
,
tmp
.
y
.
x
,
tmp
.
y
.
y
);
return
res
;
}
template
<
>
__inline__
__device__
uint32_t
vec_conversion
<
uint32_t
,
float2
>
(
const
float2
&
a
,
const
__nv_fp8_interpretation_t
fp8_type
)
{
union
{
half2
float16
;
uint32_t
uint32
;
};
float16
=
__float22half2_rn
(
a
);
return
uint32
;
}
template
<
>
__inline__
__device__
uint2
vec_conversion
<
uint2
,
Float4_
>
(
const
Float4_
&
a
,
const
__nv_fp8_interpretation_t
fp8_type
)
{
uint2
b
;
float2
val
;
val
.
x
=
a
.
x
.
x
;
val
.
y
=
a
.
x
.
y
;
b
.
x
=
vec_conversion
<
uint32_t
,
float2
>
(
val
,
fp8_type
);
val
.
x
=
a
.
y
.
x
;
val
.
y
=
a
.
y
.
y
;
b
.
y
=
vec_conversion
<
uint32_t
,
float2
>
(
val
,
fp8_type
);
return
b
;
}
template
<
>
__inline__
__device__
float4
vec_conversion
<
float4
,
Float4_
>
(
const
Float4_
&
a
,
const
__nv_fp8_interpretation_t
fp8_type
)
{
float4
b
;
b
.
x
=
a
.
x
.
x
;
b
.
y
=
a
.
x
.
y
;
b
.
z
=
a
.
y
.
x
;
b
.
w
=
a
.
y
.
y
;
return
b
;
}
template
<
>
__inline__
__device__
uint4
vec_conversion
<
uint4
,
Float8_
>
(
const
Float8_
&
a
,
const
__nv_fp8_interpretation_t
fp8_type
)
{
uint4
b
;
b
.
x
=
vec_conversion
<
uint32_t
,
float2
>
(
a
.
x
,
fp8_type
);
b
.
y
=
vec_conversion
<
uint32_t
,
float2
>
(
a
.
y
,
fp8_type
);
b
.
z
=
vec_conversion
<
uint32_t
,
float2
>
(
a
.
z
,
fp8_type
);
b
.
w
=
vec_conversion
<
uint32_t
,
float2
>
(
a
.
w
,
fp8_type
);
return
b
;
}
template
<
>
__inline__
__device__
__nv_bfloat162
vec_conversion
<
__nv_bfloat162
,
float2
>
(
const
float2
&
a
,
const
__nv_fp8_interpretation_t
fp8_type
)
{
__nv_bfloat162
b
;
from_float
(
b
,
a
);
return
b
;
}
template
<
>
__inline__
__device__
bf16_4_t
vec_conversion
<
bf16_4_t
,
Float4_
>
(
const
Float4_
&
a
,
const
__nv_fp8_interpretation_t
fp8_type
)
{
bf16_4_t
b
;
from_float
(
b
,
a
);
return
b
;
}
template
<
>
__inline__
__device__
bf16_8_t
vec_conversion
<
bf16_8_t
,
Float8_
>
(
const
Float8_
&
a
,
const
__nv_fp8_interpretation_t
fp8_type
)
{
bf16_8_t
b
;
from_float
(
b
,
a
);
return
b
;
}
#endif
/* Scaled and vectorized conversions, for data exchange between high and low
precision domains Convention of the scale in API, e.g: FP8_data =
Quantization( High_Precision_data / scale ) s.t. Quantize(HP / scale) => FP8
Dequant(FP8) * scale => HP
*/
template
<
typename
Tout
,
typename
Tin
>
__inline__
__device__
Tout
scaled_vec_conversion
(
const
Tin
&
x
,
const
float
scale
,
const
__nv_fp8_interpretation_t
fp8_type
)
{
return
x
;
}
// fp8 -> half
template
<
>
__inline__
__device__
uint16_t
scaled_vec_conversion
<
uint16_t
,
uint8_t
>
(
const
uint8_t
&
a
,
const
float
scale
,
const
__nv_fp8_interpretation_t
fp8_type
)
{
__half_raw
tmp
=
__nv_cvt_fp8_to_halfraw
(
a
,
fp8_type
);
return
float_to_half
(
half_to_float
(
tmp
.
x
)
*
scale
);
}
// fp8x2 -> half2
template
<
>
__inline__
__device__
uint32_t
scaled_vec_conversion
<
uint32_t
,
uint16_t
>
(
const
uint16_t
&
a
,
const
float
scale
,
const
__nv_fp8_interpretation_t
fp8_type
)
{
union
{
uint16_t
u16
[
2
];
uint32_t
u32
;
}
tmp
;
__half2_raw
res
=
__nv_cvt_fp8x2_to_halfraw2
(
a
,
fp8_type
);
tmp
.
u16
[
0
]
=
float_to_half
(
half_to_float
(
res
.
x
)
*
scale
);
tmp
.
u16
[
1
]
=
float_to_half
(
half_to_float
(
res
.
y
)
*
scale
);
return
tmp
.
u32
;
}
// fp8x4 -> half2x2
template
<
>
__inline__
__device__
uint2
scaled_vec_conversion
<
uint2
,
uint32_t
>
(
const
uint32_t
&
a
,
const
float
scale
,
const
__nv_fp8_interpretation_t
fp8_type
)
{
union
{
uint2
u32x2
;
uint32_t
u32
[
2
];
}
tmp
;
tmp
.
u32
[
0
]
=
scaled_vec_conversion
<
uint32_t
,
uint16_t
>
((
uint16_t
)
a
,
scale
,
fp8_type
);
tmp
.
u32
[
1
]
=
scaled_vec_conversion
<
uint32_t
,
uint16_t
>
((
uint16_t
)(
a
>>
16U
),
scale
,
fp8_type
);
return
tmp
.
u32x2
;
}
// fp8x8 -> half2x4
template
<
>
__inline__
__device__
uint4
scaled_vec_conversion
<
uint4
,
uint2
>
(
const
uint2
&
a
,
const
float
scale
,
const
__nv_fp8_interpretation_t
fp8_type
)
{
union
{
uint4
u64x2
;
uint2
u64
[
2
];
}
tmp
;
tmp
.
u64
[
0
]
=
scaled_vec_conversion
<
uint2
,
uint32_t
>
(
a
.
x
,
scale
,
fp8_type
);
tmp
.
u64
[
1
]
=
scaled_vec_conversion
<
uint2
,
uint32_t
>
(
a
.
y
,
scale
,
fp8_type
);
return
tmp
.
u64x2
;
}
// fp8 -> __nv_bfloat16
template
<
>
__inline__
__device__
__nv_bfloat16
scaled_vec_conversion
<
__nv_bfloat16
,
uint8_t
>
(
const
uint8_t
&
a
,
const
float
scale
,
const
__nv_fp8_interpretation_t
fp8_type
)
{
// Note there is no direct convert function from fp8 to bf16.
// fp8 -> half
__half_raw
res
=
__nv_cvt_fp8_to_halfraw
(
a
,
fp8_type
);
// half -> float -> bf16
float
tmp
=
half_to_float
(
res
.
x
);
return
__float2bfloat16
(
tmp
*
scale
);
}
// fp8x2 -> __nv_bfloat162
template
<
>
__inline__
__device__
__nv_bfloat162
scaled_vec_conversion
<
__nv_bfloat162
,
uint16_t
>
(
const
uint16_t
&
a
,
const
float
scale
,
const
__nv_fp8_interpretation_t
fp8_type
)
{
__nv_bfloat162
res
;
res
.
x
=
scaled_vec_conversion
<
__nv_bfloat16
,
uint8_t
>
((
uint8_t
)
a
,
scale
,
fp8_type
);
res
.
y
=
scaled_vec_conversion
<
__nv_bfloat16
,
uint8_t
>
((
uint8_t
)(
a
>>
8U
),
scale
,
fp8_type
);
return
res
;
}
// fp8x4 -> bf16_4_t
template
<
>
__inline__
__device__
bf16_4_t
scaled_vec_conversion
<
bf16_4_t
,
uint32_t
>
(
const
uint32_t
&
a
,
const
float
scale
,
const
__nv_fp8_interpretation_t
fp8_type
)
{
bf16_4_t
res
;
res
.
x
=
scaled_vec_conversion
<
__nv_bfloat162
,
uint16_t
>
((
uint16_t
)
a
,
scale
,
fp8_type
);
res
.
y
=
scaled_vec_conversion
<
__nv_bfloat162
,
uint16_t
>
((
uint16_t
)(
a
>>
16U
),
scale
,
fp8_type
);
return
res
;
}
// fp8x8 -> bf16_8_t
template
<
>
__inline__
__device__
bf16_8_t
scaled_vec_conversion
<
bf16_8_t
,
uint2
>
(
const
uint2
&
a
,
const
float
scale
,
const
__nv_fp8_interpretation_t
fp8_type
)
{
bf16_4_t
tmp1
,
tmp2
;
tmp1
=
scaled_vec_conversion
<
bf16_4_t
,
uint32_t
>
(
a
.
x
,
scale
,
fp8_type
);
tmp2
=
scaled_vec_conversion
<
bf16_4_t
,
uint32_t
>
(
a
.
y
,
scale
,
fp8_type
);
bf16_8_t
res
;
res
.
x
=
tmp1
.
x
;
res
.
y
=
tmp1
.
y
;
res
.
z
=
tmp2
.
x
;
res
.
w
=
tmp2
.
y
;
return
res
;
}
// fp8 -> float
template
<
>
__inline__
__device__
float
scaled_vec_conversion
<
float
,
uint8_t
>
(
const
uint8_t
&
a
,
const
float
scale
,
const
__nv_fp8_interpretation_t
fp8_type
)
{
// fp8 -> half
__half_raw
res
=
__nv_cvt_fp8_to_halfraw
(
a
,
fp8_type
);
uint16_t
tmp
=
res
.
x
;
// half -> float
return
half_to_float
(
tmp
)
*
scale
;
}
// fp8x2 -> float2
template
<
>
__inline__
__device__
float2
scaled_vec_conversion
<
float2
,
uint16_t
>
(
const
uint16_t
&
a
,
const
float
scale
,
const
__nv_fp8_interpretation_t
fp8_type
)
{
// fp8x2 -> half2
uint32_t
tmp
=
scaled_vec_conversion
<
uint32_t
,
uint16_t
>
(
a
,
scale
,
fp8_type
);
// half2 -> float2
return
half2_to_float2
(
tmp
);
}
// fp8x4 -> float4
template
<
>
__inline__
__device__
Float4_
scaled_vec_conversion
<
Float4_
,
uint32_t
>
(
const
uint32_t
&
a
,
const
float
scale
,
const
__nv_fp8_interpretation_t
fp8_type
)
{
Float4_
res
;
res
.
x
=
scaled_vec_conversion
<
float2
,
uint16_t
>
((
uint16_t
)
a
,
scale
,
fp8_type
);
res
.
y
=
scaled_vec_conversion
<
float2
,
uint16_t
>
((
uint16_t
)(
a
>>
16U
),
scale
,
fp8_type
);
return
res
;
}
// fp8x8 -> float8
template
<
>
__inline__
__device__
Float8_
scaled_vec_conversion
<
Float8_
,
uint2
>
(
const
uint2
&
a
,
const
float
scale
,
const
__nv_fp8_interpretation_t
fp8_type
)
{
Float4_
tmp1
,
tmp2
;
tmp1
=
scaled_vec_conversion
<
Float4_
,
uint32_t
>
(
a
.
x
,
scale
,
fp8_type
);
tmp2
=
scaled_vec_conversion
<
Float4_
,
uint32_t
>
(
a
.
y
,
scale
,
fp8_type
);
Float8_
res
;
res
.
x
=
tmp1
.
x
;
res
.
y
=
tmp1
.
y
;
res
.
z
=
tmp2
.
x
;
res
.
w
=
tmp2
.
y
;
return
res
;
}
// half -> fp8
template
<
>
__inline__
__device__
uint8_t
scaled_vec_conversion
<
uint8_t
,
uint16_t
>
(
const
uint16_t
&
a
,
const
float
scale
,
const
__nv_fp8_interpretation_t
fp8_type
)
{
__nv_fp8_storage_t
res
=
__nv_cvt_float_to_fp8
(
half_to_float
(
a
)
/
scale
,
__NV_SATFINITE
,
fp8_type
);
return
(
uint8_t
)
res
;
}
// bf16 -> fp8
template
<
>
__inline__
__device__
uint8_t
scaled_vec_conversion
<
uint8_t
,
__nv_bfloat16
>
(
const
__nv_bfloat16
&
a
,
const
float
scale
,
const
__nv_fp8_interpretation_t
fp8_type
)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert
(
false
);
#else
__nv_fp8_storage_t
res
=
__nv_cvt_float_to_fp8
(
__bfloat162float
(
a
)
/
scale
,
__NV_SATFINITE
,
fp8_type
);
return
(
uint8_t
)
res
;
#endif
}
// float -> fp8
template
<
>
__inline__
__device__
uint8_t
scaled_vec_conversion
<
uint8_t
,
float
>
(
const
float
&
a
,
const
float
scale
,
const
__nv_fp8_interpretation_t
fp8_type
)
{
__nv_fp8_storage_t
res
=
__nv_cvt_float_to_fp8
(
a
/
scale
,
__NV_SATFINITE
,
fp8_type
);
return
(
uint8_t
)
res
;
}
// fp8x4 -> float4
template
<
>
__inline__
__device__
float4
scaled_vec_conversion
<
float4
,
uint32_t
>
(
const
uint32_t
&
a
,
const
float
scale
,
const
__nv_fp8_interpretation_t
fp8_type
)
{
Float4_
tmp
=
scaled_vec_conversion
<
Float4_
,
uint32_t
>
(
a
,
scale
,
fp8_type
);
float4
res
=
make_float4
(
tmp
.
x
.
x
,
tmp
.
x
.
y
,
tmp
.
y
.
x
,
tmp
.
y
.
y
);
return
res
;
}
#endif // ENABLE_FP8
template
<
typename
Tout
,
typename
Tin
,
Fp8KVCacheDataType
kv_dt
>
__inline__
__device__
Tout
convert
(
const
Tin
&
x
)
{
#if 0 // Disable the following code to reduce the binary size.
if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
return vec_conversion<Tout, Tin>(x, __NV_E4M3);
} else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) {
return vec_conversion<Tout, Tin>(x, __NV_E5M2);
}
#endif
assert
(
false
);
}
template
<
typename
Tout
,
typename
Tin
,
Fp8KVCacheDataType
kv_dt
>
__inline__
__device__
Tout
scaled_convert
(
const
Tin
&
x
,
const
float
scale
)
{
#ifdef ENABLE_FP8
if
constexpr
(
kv_dt
==
Fp8KVCacheDataType
::
kFp8E4M3
)
{
return
scaled_vec_conversion
<
Tout
,
Tin
>
(
x
,
scale
,
__NV_E4M3
);
}
else
if
constexpr
(
kv_dt
==
Fp8KVCacheDataType
::
kFp8E5M2
)
{
return
scaled_vec_conversion
<
Tout
,
Tin
>
(
x
,
scale
,
__NV_E5M2
);
}
#endif
assert
(
false
);
}
// The following macro is used to dispatch the conversion function based on the
// data type of the key and value cache. The FN is a macro that calls a function
// with template<typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>.
#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \
if (KV_DTYPE == "auto") { \
if (SRC_DTYPE == at::ScalarType::Float) { \
FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); \
} else { \
TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else { \
if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \
if (SRC_DTYPE == at::ScalarType::Float) { \
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else { \
TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else if (KV_DTYPE == "fp8_e5m2") { \
if (SRC_DTYPE == at::ScalarType::Float) { \
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
} else { \
TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else { \
TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \
} \
}
}
// namespace fp8
#endif // not USE_ROCM
}
// namespace vllm
csrc/quantization/fp8_e5m2_kvcache/quant_utils.cuh
deleted
100644 → 0
View file @
379da6dc
#pragma once
#include <assert.h>
#include <stdint.h>
#include <float.h>
#include <type_traits>
#include "../../attention/attention_dtypes.h"
#include "../../attention/dtype_float32.cuh"
#include "../../attention/dtype_float16.cuh"
#include "../../attention/dtype_bfloat16.cuh"
namespace
vllm
{
#ifdef ENABLE_FP8_E5M2
namespace
fp8_e5m2_unscaled
{
template
<
typename
Tout
,
typename
Tin
>
__inline__
__device__
Tout
vec_conversion
(
const
Tin
&
x
)
{
return
x
;
}
// fp8 -> half
template
<
>
__inline__
__device__
uint16_t
vec_conversion
<
uint16_t
,
uint8_t
>
(
const
uint8_t
&
a
)
{
__half_raw
res
=
__nv_cvt_fp8_to_halfraw
(
a
,
__NV_E5M2
);
return
res
.
x
;
}
// fp8x2 -> half2
template
<
>
__inline__
__device__
uint32_t
vec_conversion
<
uint32_t
,
uint16_t
>
(
const
uint16_t
&
a
)
{
union
{
uint16_t
u16
[
2
];
uint32_t
u32
;
}
tmp
;
__half2_raw
res
=
__nv_cvt_fp8x2_to_halfraw2
(
a
,
__NV_E5M2
);
tmp
.
u16
[
0
]
=
res
.
x
;
tmp
.
u16
[
1
]
=
res
.
y
;
return
tmp
.
u32
;
}
// fp8x4 -> half2x2
template
<
>
__inline__
__device__
uint2
vec_conversion
<
uint2
,
uint32_t
>
(
const
uint32_t
&
a
)
{
union
{
uint2
u32x2
;
uint32_t
u32
[
2
];
}
tmp
;
tmp
.
u32
[
0
]
=
vec_conversion
<
uint32_t
,
uint16_t
>
((
uint16_t
)
a
);
tmp
.
u32
[
1
]
=
vec_conversion
<
uint32_t
,
uint16_t
>
((
uint16_t
)(
a
>>
16U
));
return
tmp
.
u32x2
;
}
// fp8x8 -> half2x4
template
<
>
__inline__
__device__
uint4
vec_conversion
<
uint4
,
uint2
>
(
const
uint2
&
a
)
{
union
{
uint4
u64x2
;
uint2
u64
[
2
];
}
tmp
;
tmp
.
u64
[
0
]
=
vec_conversion
<
uint2
,
uint32_t
>
(
a
.
x
);
tmp
.
u64
[
1
]
=
vec_conversion
<
uint2
,
uint32_t
>
(
a
.
y
);
return
tmp
.
u64x2
;
}
// fp8 -> __nv_bfloat16
template
<
>
__inline__
__device__
__nv_bfloat16
vec_conversion
<
__nv_bfloat16
,
uint8_t
>
(
const
uint8_t
&
a
)
{
// Note there is no direct convert function from fp8 to bf16.
// fp8 -> half
__half_raw
res
=
__nv_cvt_fp8_to_halfraw
(
a
,
__NV_E5M2
);
// half -> float -> bf16
float
tmp
=
half_to_float
(
res
.
x
);
return
__float2bfloat16
(
tmp
);
}
// fp8x2 -> __nv_bfloat162
template
<
>
__inline__
__device__
__nv_bfloat162
vec_conversion
<
__nv_bfloat162
,
uint16_t
>
(
const
uint16_t
&
a
)
{
__nv_bfloat162
res
;
res
.
x
=
vec_conversion
<
__nv_bfloat16
,
uint8_t
>
((
uint8_t
)
a
);
res
.
y
=
vec_conversion
<
__nv_bfloat16
,
uint8_t
>
((
uint8_t
)(
a
>>
8U
));
return
res
;
}
// fp8x4 -> bf16_4_t
template
<
>
__inline__
__device__
bf16_4_t
vec_conversion
<
bf16_4_t
,
uint32_t
>
(
const
uint32_t
&
a
)
{
bf16_4_t
res
;
res
.
x
=
vec_conversion
<
__nv_bfloat162
,
uint16_t
>
((
uint16_t
)
a
);
res
.
y
=
vec_conversion
<
__nv_bfloat162
,
uint16_t
>
((
uint16_t
)(
a
>>
16U
));
return
res
;
}
// fp8x8 -> bf16_8_t
template
<
>
__inline__
__device__
bf16_8_t
vec_conversion
<
bf16_8_t
,
uint2
>
(
const
uint2
&
a
)
{
bf16_4_t
tmp1
,
tmp2
;
tmp1
=
vec_conversion
<
bf16_4_t
,
uint32_t
>
(
a
.
x
);
tmp2
=
vec_conversion
<
bf16_4_t
,
uint32_t
>
(
a
.
y
);
bf16_8_t
res
;
res
.
x
=
tmp1
.
x
;
res
.
y
=
tmp1
.
y
;
res
.
z
=
tmp2
.
x
;
res
.
w
=
tmp2
.
y
;
return
res
;
}
// fp8 -> float
template
<
>
__inline__
__device__
float
vec_conversion
<
float
,
uint8_t
>
(
const
uint8_t
&
a
)
{
// fp8 -> half
uint16_t
tmp
=
vec_conversion
<
uint16_t
,
uint8_t
>
(
a
);
// half -> float
return
half_to_float
(
tmp
);
}
// fp8x2 -> float2
template
<
>
__inline__
__device__
float2
vec_conversion
<
float2
,
uint16_t
>
(
const
uint16_t
&
a
)
{
// fp8x2 -> half2
uint32_t
tmp
=
vec_conversion
<
uint32_t
,
uint16_t
>
(
a
);
// half2 -> float2
return
half2_to_float2
(
tmp
);
}
// fp8x4 -> float4
template
<
>
__inline__
__device__
Float4_
vec_conversion
<
Float4_
,
uint32_t
>
(
const
uint32_t
&
a
)
{
Float4_
res
;
res
.
x
=
vec_conversion
<
float2
,
uint16_t
>
((
uint16_t
)
a
);
res
.
y
=
vec_conversion
<
float2
,
uint16_t
>
((
uint16_t
)(
a
>>
16U
));
return
res
;
}
// fp8x8 -> float8
template
<
>
__inline__
__device__
Float8_
vec_conversion
<
Float8_
,
uint2
>
(
const
uint2
&
a
)
{
Float4_
tmp1
,
tmp2
;
tmp1
=
vec_conversion
<
Float4_
,
uint32_t
>
(
a
.
x
);
tmp2
=
vec_conversion
<
Float4_
,
uint32_t
>
(
a
.
y
);
Float8_
res
;
res
.
x
=
tmp1
.
x
;
res
.
y
=
tmp1
.
y
;
res
.
z
=
tmp2
.
x
;
res
.
w
=
tmp2
.
y
;
return
res
;
}
// half -> fp8
template
<
>
__inline__
__device__
uint8_t
vec_conversion
<
uint8_t
,
uint16_t
>
(
const
uint16_t
&
a
)
{
__half_raw
tmp
;
tmp
.
x
=
a
;
__nv_fp8_storage_t
res
=
__nv_cvt_halfraw_to_fp8
(
tmp
,
__NV_SATFINITE
,
__NV_E5M2
);
return
(
uint8_t
)
res
;
}
// bf16 -> fp8
template
<
>
__inline__
__device__
uint8_t
vec_conversion
<
uint8_t
,
__nv_bfloat16
>
(
const
__nv_bfloat16
&
a
)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert
(
false
);
#else
__nv_fp8_storage_t
res
=
__nv_cvt_bfloat16raw_to_fp8
(
__nv_bfloat16_raw
(
a
),
__NV_SATFINITE
,
__NV_E5M2
);
return
(
uint8_t
)
res
;
#endif
}
// float -> fp8
template
<
>
__inline__
__device__
uint8_t
vec_conversion
<
uint8_t
,
float
>
(
const
float
&
a
)
{
__nv_fp8_storage_t
res
=
__nv_cvt_float_to_fp8
(
a
,
__NV_SATFINITE
,
__NV_E5M2
);
return
(
uint8_t
)
res
;
}
// fp8x4 -> float4
template
<
>
__inline__
__device__
float4
vec_conversion
<
float4
,
uint32_t
>
(
const
uint32_t
&
a
)
{
Float4_
tmp
=
vec_conversion
<
Float4_
,
uint32_t
>
(
a
);
float4
res
=
make_float4
(
tmp
.
x
.
x
,
tmp
.
x
.
y
,
tmp
.
y
.
x
,
tmp
.
y
.
y
);
return
res
;
}
template
<
>
__inline__
__device__
uint32_t
vec_conversion
<
uint32_t
,
float2
>
(
const
float2
&
a
)
{
union
{
half2
float16
;
uint32_t
uint32
;
};
float16
=
__float22half2_rn
(
a
);
return
uint32
;
}
template
<
>
__inline__
__device__
uint2
vec_conversion
<
uint2
,
Float4_
>
(
const
Float4_
&
a
)
{
uint2
b
;
float2
val
;
val
.
x
=
a
.
x
.
x
;
val
.
y
=
a
.
x
.
y
;
b
.
x
=
vec_conversion
<
uint32_t
,
float2
>
(
val
);
val
.
x
=
a
.
y
.
x
;
val
.
y
=
a
.
y
.
y
;
b
.
y
=
vec_conversion
<
uint32_t
,
float2
>
(
val
);
return
b
;
}
template
<
>
__inline__
__device__
float4
vec_conversion
<
float4
,
Float4_
>
(
const
Float4_
&
a
)
{
float4
b
;
b
.
x
=
a
.
x
.
x
;
b
.
y
=
a
.
x
.
y
;
b
.
z
=
a
.
y
.
x
;
b
.
w
=
a
.
y
.
y
;
return
b
;
}
template
<
>
__inline__
__device__
uint4
vec_conversion
<
uint4
,
Float8_
>
(
const
Float8_
&
a
)
{
uint4
b
;
b
.
x
=
vec_conversion
<
uint32_t
,
float2
>
(
a
.
x
);
b
.
y
=
vec_conversion
<
uint32_t
,
float2
>
(
a
.
y
);
b
.
z
=
vec_conversion
<
uint32_t
,
float2
>
(
a
.
z
);
b
.
w
=
vec_conversion
<
uint32_t
,
float2
>
(
a
.
w
);
return
b
;
}
template
<
>
__inline__
__device__
__nv_bfloat162
vec_conversion
<
__nv_bfloat162
,
float2
>
(
const
float2
&
a
)
{
__nv_bfloat162
b
;
from_float
(
b
,
a
);
return
b
;
}
template
<
>
__inline__
__device__
bf16_4_t
vec_conversion
<
bf16_4_t
,
Float4_
>
(
const
Float4_
&
a
)
{
bf16_4_t
b
;
from_float
(
b
,
a
);
return
b
;
}
template
<
>
__inline__
__device__
bf16_8_t
vec_conversion
<
bf16_8_t
,
Float8_
>
(
const
Float8_
&
a
)
{
bf16_8_t
b
;
from_float
(
b
,
a
);
return
b
;
}
}
// namespace fp8_e5m2_unscaled
#endif // ENABLE_FP8_E5M2
}
// namespace vllm
tests/kernels/test_attention.py
View file @
c8331017
...
@@ -236,14 +236,14 @@ def test_paged_attention(
...
@@ -236,14 +236,14 @@ def test_paged_attention(
dequantized_key_cache
=
torch
.
empty
(
size
=
key_cache_shape
,
dequantized_key_cache
=
torch
.
empty
(
size
=
key_cache_shape
,
dtype
=
dtype
,
dtype
=
dtype
,
device
=
device
)
device
=
device
)
ops
.
convert_fp8
(
key_cache
,
dequantized_key_cache
)
ops
.
convert_fp8
(
dequantized_
key_cache
,
key_cache
)
key_cache
=
dequantized_key_cache
key_cache
=
dequantized_key_cache
value_cache_shape
=
value_cache
.
shape
value_cache_shape
=
value_cache
.
shape
dequantized_value_cache
=
torch
.
empty
(
size
=
value_cache_shape
,
dequantized_value_cache
=
torch
.
empty
(
size
=
value_cache_shape
,
dtype
=
dtype
,
dtype
=
dtype
,
device
=
device
)
device
=
device
)
ops
.
convert_fp8
(
value_cache
,
dequantized_
value_cache
)
ops
.
convert_fp8
(
dequantized_
value_cache
,
value_cache
)
value_cache
=
dequantized_value_cache
value_cache
=
dequantized_value_cache
ref_output
=
torch
.
empty_like
(
query
)
ref_output
=
torch
.
empty_like
(
query
)
...
...
tests/kernels/test_cache.py
View file @
c8331017
...
@@ -5,8 +5,6 @@ import pytest
...
@@ -5,8 +5,6 @@ import pytest
import
torch
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm._C
import
cache_ops
from
vllm.utils
import
is_hip
COPYING_DIRECTION
=
[(
'cuda'
,
'cpu'
),
(
'cuda'
,
'cuda'
),
(
'cpu'
,
'cuda'
)]
COPYING_DIRECTION
=
[(
'cuda'
,
'cpu'
),
(
'cuda'
,
'cuda'
),
(
'cpu'
,
'cuda'
)]
DTYPES
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
]
DTYPES
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
]
...
@@ -25,6 +23,8 @@ SEEDS = [0]
...
@@ -25,6 +23,8 @@ SEEDS = [0]
CUDA_DEVICES
=
[
CUDA_DEVICES
=
[
f
"cuda:
{
i
}
"
for
i
in
range
(
1
if
torch
.
cuda
.
device_count
()
==
1
else
2
)
f
"cuda:
{
i
}
"
for
i
in
range
(
1
if
torch
.
cuda
.
device_count
()
==
1
else
2
)
]
]
# We assume fp8 is always enabled for testing.
KV_CACHE_DTYPE
=
[
"auto"
,
"fp8"
]
KV_CACHE_DTYPE
=
[
"auto"
,
"fp8"
]
...
@@ -124,8 +124,6 @@ def test_reshape_and_cache(
...
@@ -124,8 +124,6 @@ def test_reshape_and_cache(
device
:
str
,
device
:
str
,
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
)
->
None
:
)
->
None
:
if
not
is_hip
()
and
kv_cache_dtype
==
"fp8"
:
pytest
.
skip
()
# This test is not tuned for e5m2 cuda precision
random
.
seed
(
seed
)
random
.
seed
(
seed
)
torch
.
random
.
manual_seed
(
seed
)
torch
.
random
.
manual_seed
(
seed
)
if
torch
.
cuda
.
is_available
():
if
torch
.
cuda
.
is_available
():
...
@@ -149,9 +147,9 @@ def test_reshape_and_cache(
...
@@ -149,9 +147,9 @@ def test_reshape_and_cache(
# Clone the KV caches.
# Clone the KV caches.
if
kv_cache_dtype
==
"fp8"
:
if
kv_cache_dtype
==
"fp8"
:
cloned_key_cache
=
torch
.
empty_like
(
key_cache
,
dtype
=
torch
.
float16
)
cloned_key_cache
=
torch
.
empty_like
(
key_cache
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
key_cache
,
cloned_
key_cache
)
ops
.
convert_fp8
(
cloned_
key_cache
,
key_cache
)
cloned_value_cache
=
torch
.
empty_like
(
value_cache
,
dtype
=
torch
.
float16
)
cloned_value_cache
=
torch
.
empty_like
(
value_cache
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
value_cache
,
cloned_
value_cache
)
ops
.
convert_fp8
(
cloned_
value_cache
,
value_cache
)
else
:
else
:
cloned_key_cache
=
key_cache
.
clone
()
cloned_key_cache
=
key_cache
.
clone
()
cloned_value_cache
=
value_cache
.
clone
()
cloned_value_cache
=
value_cache
.
clone
()
...
@@ -165,9 +163,9 @@ def test_reshape_and_cache(
...
@@ -165,9 +163,9 @@ def test_reshape_and_cache(
if
kv_cache_dtype
==
"fp8"
:
if
kv_cache_dtype
==
"fp8"
:
result_key_cache
=
torch
.
empty_like
(
key_cache
,
dtype
=
torch
.
float16
)
result_key_cache
=
torch
.
empty_like
(
key_cache
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
key_cache
,
result_
key_cache
)
ops
.
convert_fp8
(
result_
key_cache
,
key_cache
)
result_value_cache
=
torch
.
empty_like
(
value_cache
,
dtype
=
torch
.
float16
)
result_value_cache
=
torch
.
empty_like
(
value_cache
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
value_cache
,
result_
value_cache
)
ops
.
convert_fp8
(
result_
value_cache
,
value_cache
)
# Run the reference implementation.
# Run the reference implementation.
reshaped_key
=
key
.
reshape
(
num_tokens
,
*
key_cache
[
0
,
:,
:,
0
,
:].
shape
)
reshaped_key
=
key
.
reshape
(
num_tokens
,
*
key_cache
[
0
,
:,
:,
0
,
:].
shape
)
...
@@ -255,8 +253,8 @@ def test_reshape_and_cache_flash(
...
@@ -255,8 +253,8 @@ def test_reshape_and_cache_flash(
cloned_value_cache
=
value_cache
.
clone
()
cloned_value_cache
=
value_cache
.
clone
()
# Call the reshape_and_cache kernel.
# Call the reshape_and_cache kernel.
cache_
ops
.
reshape_and_cache_flash
(
key
,
value
,
key_cache
,
value_cache
,
ops
.
reshape_and_cache_flash
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
kv_cache_dtype
)
slot_mapping
,
kv_cache_dtype
)
# Run the reference implementation.
# Run the reference implementation.
block_indicies
=
torch
.
div
(
slot_mapping
,
block_size
,
rounding_mode
=
'floor'
)
block_indicies
=
torch
.
div
(
slot_mapping
,
block_size
,
rounding_mode
=
'floor'
)
...
@@ -299,8 +297,6 @@ def test_swap_blocks(
...
@@ -299,8 +297,6 @@ def test_swap_blocks(
)
->
None
:
)
->
None
:
if
kv_cache_dtype
==
"fp8"
and
"cpu"
in
direction
:
if
kv_cache_dtype
==
"fp8"
and
"cpu"
in
direction
:
pytest
.
skip
()
pytest
.
skip
()
if
not
is_hip
()
and
kv_cache_dtype
==
"fp8"
:
pytest
.
skip
()
# This test is not tuned for e5m2 cuda precision
random
.
seed
(
seed
)
random
.
seed
(
seed
)
torch
.
random
.
manual_seed
(
seed
)
torch
.
random
.
manual_seed
(
seed
)
if
torch
.
cuda
.
is_available
():
if
torch
.
cuda
.
is_available
():
...
@@ -348,7 +344,6 @@ def test_swap_blocks(
...
@@ -348,7 +344,6 @@ def test_swap_blocks(
dist_value_caches
[
0
][
dst
].
cpu
())
dist_value_caches
[
0
][
dst
].
cpu
())
@
pytest
.
mark
.
skipif
(
not
is_hip
(),
reason
=
"FP8 conversion test requires e4m3"
)
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
HEAD_SIZES
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
HEAD_SIZES
)
@
pytest
.
mark
.
parametrize
(
"block_size"
,
BLOCK_SIZES
)
@
pytest
.
mark
.
parametrize
(
"block_size"
,
BLOCK_SIZES
)
...
@@ -357,7 +352,7 @@ def test_swap_blocks(
...
@@ -357,7 +352,7 @@ def test_swap_blocks(
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_fp8_conversion
(
def
test_fp8_
e4m3_
conversion
(
num_heads
:
int
,
num_heads
:
int
,
head_size
:
int
,
head_size
:
int
,
block_size
:
int
,
block_size
:
int
,
...
@@ -377,9 +372,9 @@ def test_fp8_conversion(
...
@@ -377,9 +372,9 @@ def test_fp8_conversion(
cache
.
uniform_
(
low
,
high
)
cache
.
uniform_
(
low
,
high
)
cache_fp8
=
torch
.
empty_like
(
cache
,
dtype
=
torch
.
uint8
)
cache_fp8
=
torch
.
empty_like
(
cache
,
dtype
=
torch
.
uint8
)
ops
.
convert_fp8
(
cache
,
cache
_fp8
)
ops
.
convert_fp8
(
cache
_fp8
,
cache
)
converted_cache
=
torch
.
empty_like
(
cache
)
converted_cache
=
torch
.
empty_like
(
cache
)
ops
.
convert_fp8
(
cache_fp8
,
converted_cache
)
ops
.
convert_fp8
(
converted_cache
,
cache_fp8
)
assert
torch
.
allclose
(
cache
,
converted_cache
,
atol
=
0.001
,
rtol
=
0.1
)
assert
torch
.
allclose
(
cache
,
converted_cache
,
atol
=
0.001
,
rtol
=
0.1
)
vllm/_custom_ops.py
View file @
c8331017
...
@@ -270,8 +270,11 @@ def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
...
@@ -270,8 +270,11 @@ def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
vllm_cache_ops
.
swap_blocks
(
src
,
dst
,
block_mapping
)
vllm_cache_ops
.
swap_blocks
(
src
,
dst
,
block_mapping
)
def
convert_fp8
(
output
:
torch
.
Tensor
,
input
:
torch
.
Tensor
)
->
None
:
def
convert_fp8
(
output
:
torch
.
Tensor
,
vllm_cache_ops
.
convert_fp8
(
output
,
input
)
input
:
torch
.
Tensor
,
scale
:
float
=
1.0
,
kv_dtype
:
str
=
"fp8"
)
->
None
:
vllm_cache_ops
.
convert_fp8
(
output
,
input
,
scale
,
kv_dtype
)
#TODO: cuda_utils, custom_ar
#TODO: cuda_utils, custom_ar
vllm/utils.py
View file @
c8331017
...
@@ -329,7 +329,7 @@ def _generate_random_fp8(
...
@@ -329,7 +329,7 @@ def _generate_random_fp8(
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
tensor_tmp
=
torch
.
empty_like
(
tensor
,
dtype
=
torch
.
float16
)
tensor_tmp
=
torch
.
empty_like
(
tensor
,
dtype
=
torch
.
float16
)
tensor_tmp
.
uniform_
(
low
,
high
)
tensor_tmp
.
uniform_
(
low
,
high
)
ops
.
convert_fp8
(
tensor
_tmp
,
tensor
)
ops
.
convert_fp8
(
tensor
,
tensor
_tmp
)
del
tensor_tmp
del
tensor_tmp
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment