Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
fab1acce
Commit
fab1acce
authored
Apr 28, 2026
by
zhuwenwen
Browse files
[Feature] Support vllm v0.20.0
parent
88d34c64
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
676 additions
and
444 deletions
+676
-444
CMakeLists.txt
CMakeLists.txt
+2
-2
csrc/cache_kernels.cu
csrc/cache_kernels.cu
+18
-2
csrc/cuda_vec_utils.cuh
csrc/cuda_vec_utils.cuh
+2
-0
csrc/fused_qknorm_rope_kernel.cu
csrc/fused_qknorm_rope_kernel.cu
+9
-9
csrc/quantization/gptq/q_gemm.cu
csrc/quantization/gptq/q_gemm.cu
+2
-0
csrc/quantization/w8a8/fp8/amd/quant_utils.cuh
csrc/quantization/w8a8/fp8/amd/quant_utils.cuh
+523
-416
csrc/quantization/w8a8/fp8/common.cuh
csrc/quantization/w8a8/fp8/common.cuh
+7
-3
requirements/rocm.txt
requirements/rocm.txt
+6
-1
setup.py
setup.py
+103
-7
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+4
-4
No files found.
CMakeLists.txt
View file @
fab1acce
...
...
@@ -37,7 +37,7 @@ install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS)
set
(
PYTHON_SUPPORTED_VERSIONS
"3.10"
"3.11"
"3.12"
"3.13"
"3.14"
)
# Supported AMD GPU architectures.
set
(
HIP_SUPPORTED_ARCHS
"gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1101;gfx1102;gfx1103;gfx1150;gfx1151;gfx1152;gfx1153;gfx1200;gfx1201"
)
set
(
HIP_SUPPORTED_ARCHS
"gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1101;gfx1102;gfx1103;gfx1150;gfx1151;gfx1152;gfx1153;gfx1200;gfx1201
;gfx928;gfx936;gfx938
"
)
# ROCm installation prefix. Default to /opt/rocm but allow override via
# -DROCM_PATH=/your/rocm/path when invoking cmake.
...
...
@@ -1240,7 +1240,7 @@ if(VLLM_GPU_LANG STREQUAL "HIP")
endif
()
# For CUDA and HIP builds also build the triton_kernels external package.
if
(
VLLM_GPU_LANG STREQUAL
"CUDA"
OR VLLM_GPU_LANG STREQUAL
"HIP"
)
if
(
VLLM_GPU_LANG STREQUAL
"CUDA"
)
include
(
cmake/external_projects/triton_kernels.cmake
)
endif
()
...
...
csrc/cache_kernels.cu
View file @
fab1acce
...
...
@@ -931,6 +931,22 @@ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
CALL_CONVERT_FP8
(
__nv_bfloat16
,
uint8_t
,
vllm
::
Fp8KVCacheDataType
::
kFp8E4M3
);
}
}
else
if
(
kv_cache_dtype
==
"fp8_e5m2"
)
{
if
(
src_cache
.
dtype
()
==
at
::
ScalarType
::
Float
)
{
CALL_CONVERT_FP8
(
uint8_t
,
float
,
vllm
::
Fp8KVCacheDataType
::
kFp8E5M2
);
}
else
if
(
src_cache
.
dtype
()
==
at
::
ScalarType
::
Half
)
{
CALL_CONVERT_FP8
(
uint8_t
,
uint16_t
,
vllm
::
Fp8KVCacheDataType
::
kFp8E5M2
);
}
else
if
(
src_cache
.
dtype
()
==
at
::
ScalarType
::
BFloat16
)
{
CALL_CONVERT_FP8
(
uint8_t
,
__nv_bfloat16
,
vllm
::
Fp8KVCacheDataType
::
kFp8E5M2
);
}
else
if
(
dst_cache
.
dtype
()
==
at
::
ScalarType
::
Float
)
{
CALL_CONVERT_FP8
(
float
,
uint8_t
,
vllm
::
Fp8KVCacheDataType
::
kFp8E5M2
);
}
else
if
(
dst_cache
.
dtype
()
==
at
::
ScalarType
::
Half
)
{
CALL_CONVERT_FP8
(
uint16_t
,
uint8_t
,
vllm
::
Fp8KVCacheDataType
::
kFp8E5M2
);
}
else
if
(
dst_cache
.
dtype
()
==
at
::
ScalarType
::
BFloat16
)
{
CALL_CONVERT_FP8
(
__nv_bfloat16
,
uint8_t
,
vllm
::
Fp8KVCacheDataType
::
kFp8E5M2
);
}
}
else
{
TORCH_CHECK
(
false
,
"Unsupported data type: "
,
kv_cache_dtype
);
}
...
...
@@ -1156,9 +1172,9 @@ __global__ void cp_gather_and_upconvert_fp8_kv_cache(
const
uint2
fp8_hi
=
make_uint2
(
fp8_data
.
z
,
fp8_data
.
w
);
#ifdef USE_ROCM
const
bf16_8_t
bf16_lo
=
fp8
::
scaled_vec_conversion
<
bf16_8_t
,
uint2
>
(
fp8_lo
,
scale
);
fp8
::
scaled_vec_conversion
<
bf16_8_t
,
uint2
>
(
fp8_lo
,
scale
,
vllm
::
Fp8KVCacheDataType
::
kFp8E4M3
);
const
bf16_8_t
bf16_hi
=
fp8
::
scaled_vec_conversion
<
bf16_8_t
,
uint2
>
(
fp8_hi
,
scale
);
fp8
::
scaled_vec_conversion
<
bf16_8_t
,
uint2
>
(
fp8_hi
,
scale
,
vllm
::
Fp8KVCacheDataType
::
kFp8E4M3
);
#else
const
bf16_8_t
bf16_lo
=
fp8
::
scaled_vec_conversion
<
bf16_8_t
,
uint2
>
(
fp8_lo
,
scale
,
__NV_E4M3
);
...
...
csrc/cuda_vec_utils.cuh
View file @
fab1acce
...
...
@@ -8,6 +8,8 @@
#include <cassert>
#ifdef USE_ROCM
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <hip/hip_runtime.h>
#else
#include <cuda_bf16.h>
...
...
csrc/fused_qknorm_rope_kernel.cu
View file @
fab1acce
...
...
@@ -40,15 +40,15 @@
#ifdef USE_ROCM
#define FINAL_MASK 0xffffffffffffffffULL
#if defined(HIP_VERSION) && HIP_VERSION < 70000000
// On ROCm versions before 7.0, __syncwarp isn't defined. The below
// implementation is copy/pasted from the implementation in ROCm 7.0
__device__
inline
void
__syncwarp
()
{
__builtin_amdgcn_fence
(
__ATOMIC_RELEASE
,
"wavefront"
);
__builtin_amdgcn_wave_barrier
();
__builtin_amdgcn_fence
(
__ATOMIC_ACQUIRE
,
"wavefront"
);
}
#endif
//
#if defined(HIP_VERSION) && HIP_VERSION < 70000000
//
// On ROCm versions before 7.0, __syncwarp isn't defined. The below
//
// implementation is copy/pasted from the implementation in ROCm 7.0
//
__device__ inline void __syncwarp() {
//
__builtin_amdgcn_fence(__ATOMIC_RELEASE, "wavefront");
//
__builtin_amdgcn_wave_barrier();
//
__builtin_amdgcn_fence(__ATOMIC_ACQUIRE, "wavefront");
//
}
//
#endif
#else
#define FINAL_MASK 0xffffffff
#endif
...
...
csrc/quantization/gptq/q_gemm.cu
View file @
fab1acce
...
...
@@ -12,7 +12,9 @@ https://github.com/qwopqwop200/GPTQ-for-LLaMa
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#ifndef USE_ROCM
#include "compat.cuh"
#endif
#include "matrix_view.cuh"
#include "qdq_2.cuh"
#include "qdq_3.cuh"
...
...
csrc/quantization/w8a8/fp8/amd/quant_utils.cuh
View file @
fab1acce
#pragma once
#ifndef USE_ROCM
#include <hip/hip_fp8.h>
#endif
#include <hip/hip_fp16.h>
#include <hip/hip_bf16.h>
...
...
@@ -11,318 +13,348 @@ namespace vllm {
#ifdef USE_ROCM
namespace
fp8
{
#ifdef ENABLE_FP8
//
#ifdef ENABLE_FP8
// Use hardware cvt instruction for fp8 on rocm
template
<
typename
fp8_type
>
__device__
__forceinline__
fp8_type
cvt_c10
(
float
const
r
)
{
return
{};
// KV-CACHE int8
static
inline
__device__
float
fp8_to_float
(
uint8_t
input
)
{
const
uint32_t
w
=
(
uint32_t
)
input
<<
24
;
const
uint32_t
sign
=
w
&
UINT32_C
(
0x80000000
);
const
uint32_t
nonsign
=
w
&
UINT32_C
(
0x7FFFFFFF
);
uint32_t
renorm_shift
=
__clz
(
nonsign
);
renorm_shift
=
renorm_shift
>
4
?
renorm_shift
-
4
:
0
;
uint32_t
result
=
sign
|
((
nonsign
<<
renorm_shift
>>
4
)
+
((
0x78
-
renorm_shift
)
<<
23
));
return
c10
::
detail
::
fp32_from_bits
(
result
);
}
// __hip_fp8_e4m3 only exists starting in ROCm 6.3. The macro
// HIP_FP8_TYPE_OCP comes from the hip_fp8.h header and also makes
// its first appearance in ROCm 6.3. Since VLLM_DISPATCH_FP8_TYPES
// on ROCm instantiates both OCP and FNUZ kernels, we need to replace
// the new HW cvt with something reasonable that doesn't rely on the
// ROCm 6.3 feature. This allows compiling on ROCm 6.2 or newer.
template
<
>
__device__
__forceinline__
c10
::
Float8_e4m3fn
cvt_c10
(
float
const
r
)
{
#if HIP_FP8_TYPE_OCP
return
c10
::
Float8_e4m3fn
(
__hip_cvt_float_to_fp8
(
r
,
__hip_fp8_e4m3
::
__default_saturation
,
__hip_fp8_e4m3
::
__default_interpret
),
c10
::
Float8_e4m3fn
::
from_bits
());
#else
// Cast implemented by pytorch. Uses bit manipulation instead of HW cvt.
// HW cvt above is faster when it is available (ROCm 6.3 or newer).
return
static_cast
<
c10
::
Float8_e4m3fn
>
(
r
);
#endif
}
// float -> fp8
static
inline
__device__
uint8_t
float_to_fp8_e4m3
(
float
f
)
{
constexpr
uint32_t
fp8_max
=
UINT32_C
(
1087
)
<<
20
;
constexpr
uint32_t
denorm_mask
=
UINT32_C
(
141
)
<<
23
;
uint32_t
f_bits
=
c10
::
detail
::
fp32_to_bits
(
f
);
uint8_t
result
=
0u
;
const
uint32_t
sign
=
f_bits
&
UINT32_C
(
0x80000000
);
f_bits
^=
sign
;
if
(
f_bits
>=
fp8_max
)
{
result
=
0x7f
;
}
else
{
if
(
f_bits
<
(
UINT32_C
(
121
)
<<
23
))
{
f_bits
=
c10
::
detail
::
fp32_to_bits
(
c10
::
detail
::
fp32_from_bits
(
f_bits
)
+
c10
::
detail
::
fp32_from_bits
(
denorm_mask
));
result
=
static_cast
<
uint8_t
>
(
f_bits
-
denorm_mask
);
}
else
{
uint8_t
mant_odd
=
(
f_bits
>>
20
)
&
1
;
f_bits
+=
((
uint32_t
)(
7
-
127
)
<<
23
)
+
0x7FFFF
;
f_bits
+=
mant_odd
;
result
=
static_cast
<
uint8_t
>
(
f_bits
>>
20
);
}
}
template
<
>
__device__
__forceinline__
c10
::
Float8_e4m3fnuz
cvt_c10
(
float
const
r
)
{
return
c10
::
Float8_e4m3fnuz
(
__hip_cvt_float_to_fp8
(
r
,
__hip_fp8_e4m3_fnuz
::
__default_saturation
,
__hip_fp8_e4m3_fnuz
::
__default_interpret
),
c10
::
Float8_e4m3fnuz
::
from_bits
());
result
|=
static_cast
<
uint8_t
>
(
sign
>>
24
);
return
result
;
}
static
inline
__device__
uint8_t
float_to_fp8_e5m2
(
float
f
)
{
constexpr
uint32_t
fp32_inf
=
UINT32_C
(
255
)
<<
23
;
constexpr
uint32_t
fp8_max
=
UINT32_C
(
143
)
<<
23
;
constexpr
uint32_t
denorm_mask
=
UINT32_C
(
134
)
<<
23
;
uint32_t
f_bits
=
c10
::
detail
::
fp32_to_bits
(
f
);
uint8_t
result
=
0u
;
const
uint32_t
sign
=
f_bits
&
UINT32_C
(
0x80000000
);
f_bits
^=
sign
;
if
(
f_bits
>=
fp8_max
)
{
result
=
f_bits
>
fp32_inf
?
UINT8_C
(
0x7F
)
:
UINT8_C
(
0x7C
);
}
else
{
if
(
f_bits
<
(
UINT32_C
(
113
)
<<
23
))
{
f_bits
=
c10
::
detail
::
fp32_to_bits
(
c10
::
detail
::
fp32_from_bits
(
f_bits
)
+
c10
::
detail
::
fp32_from_bits
(
denorm_mask
));
result
=
static_cast
<
uint8_t
>
(
f_bits
-
denorm_mask
);
}
else
{
uint32_t
mant_odd
=
(
f_bits
>>
21
)
&
1
;
f_bits
+=
((
uint32_t
)(
15
-
127
)
<<
23
)
+
0xFFFFF
;
f_bits
+=
mant_odd
;
result
=
static_cast
<
uint8_t
>
(
f_bits
>>
21
);
}
}
result
|=
static_cast
<
uint8_t
>
(
sign
>>
24
);
return
result
;
}
template
<
typename
Tout
,
typename
Tin
>
__inline__
__device__
Tout
vec_conversion
(
const
Tin
&
x
)
{
return
x
;
}
//
template <typename Tout, typename Tin>
//
__inline__ __device__ Tout vec_conversion(const Tin& x) {
//
return x;
//
}
template
<
typename
Tout
,
typename
Tin
>
__inline__
__device__
Tout
scaled_vec_conversion
(
const
Tin
&
x
,
const
float
scale
)
{
const
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
return
x
;
}
#if HIP_FP8_TYPE_OCP
using
fp8_type
=
__hip_fp8_e4m3
;
using
fp8x2_type
=
__hip_fp8x2_e4m3
;
#else
using
fp8_type
=
__hip_fp8_e4m3_fnuz
;
using
fp8x2_type
=
__hip_fp8x2_e4m3_fnuz
;
#endif
// fp8 -> half
template
<
>
__inline__
__device__
uint16_t
vec_conversion
<
uint16_t
,
uint8_t
>
(
const
uint8_t
&
a
)
{
return
__hip_cvt_fp8_to_halfraw
(
a
,
fp8_type
::
__default_interpret
).
x
;
}
// fp8x2 -> half2
template
<
>
__inline__
__device__
uint32_t
vec_conversion
<
uint32_t
,
uint16_t
>
(
const
uint16_t
&
a
)
{
union
{
__half2_raw
h2r
;
uint32_t
ui32
;
}
tmp
;
tmp
.
h2r
=
__hip_cvt_fp8x2_to_halfraw2
(
a
,
fp8_type
::
__default_interpret
);
return
tmp
.
ui32
;
}
// fp8x4 -> half2x2
template
<
>
__inline__
__device__
uint2
vec_conversion
<
uint2
,
uint32_t
>
(
const
uint32_t
&
a
)
{
union
{
uint2
u32x2
;
uint32_t
u32
[
2
];
}
tmp
;
tmp
.
u32
[
0
]
=
vec_conversion
<
uint32_t
,
uint16_t
>
((
uint16_t
)
a
);
tmp
.
u32
[
1
]
=
vec_conversion
<
uint32_t
,
uint16_t
>
((
uint16_t
)(
a
>>
16U
));
return
tmp
.
u32x2
;
}
// fp8x8 -> half2x4
template
<
>
__inline__
__device__
uint4
vec_conversion
<
uint4
,
uint2
>
(
const
uint2
&
a
)
{
union
{
uint4
u64x2
;
uint2
u64
[
2
];
}
tmp
;
tmp
.
u64
[
0
]
=
vec_conversion
<
uint2
,
uint32_t
>
(
a
.
x
);
tmp
.
u64
[
1
]
=
vec_conversion
<
uint2
,
uint32_t
>
(
a
.
y
);
return
tmp
.
u64x2
;
}
using
__nv_bfloat16
=
__hip_bfloat16
;
// fp8 -> __nv_bfloat16
template
<
>
__inline__
__device__
__nv_bfloat16
vec_conversion
<
__nv_bfloat16
,
uint8_t
>
(
const
uint8_t
&
a
)
{
fp8_type
f8
;
f8
.
__x
=
a
;
return
__float2bfloat16
(
static_cast
<
float
>
(
f8
));
}
using
__nv_bfloat162
=
__hip_bfloat162
;
// fp8x2 -> __nv_bfloat162
template
<
>
__inline__
__device__
__nv_bfloat162
vec_conversion
<
__nv_bfloat162
,
uint16_t
>
(
const
uint16_t
&
a
)
{
__nv_bfloat162
res
;
res
.
x
=
vec_conversion
<
__nv_bfloat16
,
uint8_t
>
((
uint8_t
)
a
);
res
.
y
=
vec_conversion
<
__nv_bfloat16
,
uint8_t
>
((
uint8_t
)(
a
>>
8U
));
return
res
;
}
// fp8x4 -> bf16_4_t
template
<
>
__inline__
__device__
bf16_4_t
vec_conversion
<
bf16_4_t
,
uint32_t
>
(
const
uint32_t
&
a
)
{
bf16_4_t
res
;
res
.
x
=
vec_conversion
<
__nv_bfloat162
,
uint16_t
>
((
uint16_t
)
a
);
res
.
y
=
vec_conversion
<
__nv_bfloat162
,
uint16_t
>
((
uint16_t
)(
a
>>
16U
));
return
res
;
}
// fp8x8 -> bf16_8_t
template
<
>
__inline__
__device__
bf16_8_t
vec_conversion
<
bf16_8_t
,
uint2
>
(
const
uint2
&
a
)
{
bf16_4_t
tmp1
,
tmp2
;
tmp1
=
vec_conversion
<
bf16_4_t
,
uint32_t
>
(
a
.
x
);
tmp2
=
vec_conversion
<
bf16_4_t
,
uint32_t
>
(
a
.
y
);
bf16_8_t
res
;
res
.
x
=
tmp1
.
x
;
res
.
y
=
tmp1
.
y
;
res
.
z
=
tmp2
.
x
;
res
.
w
=
tmp2
.
y
;
return
res
;
}
// fp8 -> float
template
<
>
__inline__
__device__
float
vec_conversion
<
float
,
uint8_t
>
(
const
uint8_t
&
a
)
{
fp8_type
f8
;
f8
.
__x
=
a
;
return
static_cast
<
float
>
(
f8
);
}
// fp8x2 -> float2
template
<
>
__inline__
__device__
float2
vec_conversion
<
float2
,
uint16_t
>
(
const
uint16_t
&
a
)
{
fp8x2_type
f8x2
;
f8x2
.
__x
=
a
;
return
static_cast
<
float2
>
(
f8x2
);
}
// fp8x4 -> float4
template
<
>
__inline__
__device__
Float4_
vec_conversion
<
Float4_
,
uint32_t
>
(
const
uint32_t
&
a
)
{
Float4_
res
;
res
.
x
=
vec_conversion
<
float2
,
uint16_t
>
((
uint16_t
)
a
);
res
.
y
=
vec_conversion
<
float2
,
uint16_t
>
((
uint16_t
)(
a
>>
16U
));
return
res
;
}
// fp8x4 -> float4
template
<
>
__inline__
__device__
float4
vec_conversion
<
float4
,
uint32_t
>
(
const
uint32_t
&
a
)
{
Float4_
tmp
=
vec_conversion
<
Float4_
,
uint32_t
>
(
a
);
float4
res
=
make_float4
(
tmp
.
x
.
x
,
tmp
.
x
.
y
,
tmp
.
y
.
x
,
tmp
.
y
.
y
);
return
res
;
}
// fp8x8 -> float8
template
<
>
__inline__
__device__
Float8_
vec_conversion
<
Float8_
,
uint2
>
(
const
uint2
&
a
)
{
Float4_
tmp1
,
tmp2
;
tmp1
=
vec_conversion
<
Float4_
,
uint32_t
>
(
a
.
x
);
tmp2
=
vec_conversion
<
Float4_
,
uint32_t
>
(
a
.
y
);
Float8_
res
;
res
.
x
=
tmp1
.
x
;
res
.
y
=
tmp1
.
y
;
res
.
z
=
tmp2
.
x
;
res
.
w
=
tmp2
.
y
;
return
res
;
}
// half -> fp8
template
<
>
__inline__
__device__
uint8_t
vec_conversion
<
uint8_t
,
uint16_t
>
(
const
uint16_t
&
a
)
{
__half_raw
tmp
;
tmp
.
x
=
a
;
return
__hip_cvt_halfraw_to_fp8
(
tmp
,
fp8_type
::
__default_saturation
,
fp8_type
::
__default_interpret
);
}
template
<
>
__inline__
__device__
uint16_t
vec_conversion
<
uint16_t
,
uint32_t
>
(
const
uint32_t
&
a
)
{
union
{
uint32_t
ui32
;
__half2_raw
h2r
;
}
tmp
;
tmp
.
ui32
=
a
;
return
__hip_cvt_halfraw2_to_fp8x2
(
tmp
.
h2r
,
fp8_type
::
__default_saturation
,
fp8_type
::
__default_interpret
);
}
// bf16 -> fp8
template
<
>
__inline__
__device__
uint8_t
vec_conversion
<
uint8_t
,
__nv_bfloat16
>
(
const
__nv_bfloat16
&
a
)
{
return
__hip_cvt_float_to_fp8
(
__bfloat162float
(
a
),
fp8_type
::
__default_saturation
,
fp8_type
::
__default_interpret
);
}
// float -> fp8
template
<
>
__inline__
__device__
uint8_t
vec_conversion
<
uint8_t
,
float
>
(
const
float
&
a
)
{
return
__hip_cvt_float_to_fp8
(
a
,
fp8_type
::
__default_saturation
,
fp8_type
::
__default_interpret
);
}
// float2 -> half2
template
<
>
__inline__
__device__
uint32_t
vec_conversion
<
uint32_t
,
float2
>
(
const
float2
&
a
)
{
union
{
half2
float16
;
uint32_t
uint32
;
};
float16
=
__float22half2_rn
(
a
);
return
uint32
;
}
// Float4 -> half2x2
template
<
>
__inline__
__device__
uint2
vec_conversion
<
uint2
,
Float4_
>
(
const
Float4_
&
a
)
{
uint2
b
;
float2
val
;
val
.
x
=
a
.
x
.
x
;
val
.
y
=
a
.
x
.
y
;
b
.
x
=
vec_conversion
<
uint32_t
,
float2
>
(
val
);
val
.
x
=
a
.
y
.
x
;
val
.
y
=
a
.
y
.
y
;
b
.
y
=
vec_conversion
<
uint32_t
,
float2
>
(
val
);
return
b
;
}
// Float4 -> float4
template
<
>
__inline__
__device__
float4
vec_conversion
<
float4
,
Float4_
>
(
const
Float4_
&
a
)
{
float4
b
;
b
.
x
=
a
.
x
.
x
;
b
.
y
=
a
.
x
.
y
;
b
.
z
=
a
.
y
.
x
;
b
.
w
=
a
.
y
.
y
;
return
b
;
}
// Float8 -> half2x4
template
<
>
__inline__
__device__
uint4
vec_conversion
<
uint4
,
Float8_
>
(
const
Float8_
&
a
)
{
uint4
b
;
b
.
x
=
vec_conversion
<
uint32_t
,
float2
>
(
a
.
x
);
b
.
y
=
vec_conversion
<
uint32_t
,
float2
>
(
a
.
y
);
b
.
z
=
vec_conversion
<
uint32_t
,
float2
>
(
a
.
z
);
b
.
w
=
vec_conversion
<
uint32_t
,
float2
>
(
a
.
w
);
return
b
;
}
// float2 -> bfloat162
template
<
>
__inline__
__device__
__nv_bfloat162
vec_conversion
<
__nv_bfloat162
,
float2
>
(
const
float2
&
a
)
{
__nv_bfloat162
b
=
__float22bfloat162_rn
(
a
);
return
b
;
}
// Float4 -> bfloat162x2
template
<
>
__inline__
__device__
bf16_4_t
vec_conversion
<
bf16_4_t
,
Float4_
>
(
const
Float4_
&
a
)
{
bf16_4_t
b
;
b
.
x
=
__float22bfloat162_rn
(
a
.
x
);
b
.
y
=
__float22bfloat162_rn
(
a
.
y
);
return
b
;
}
// Float8 -> bfloat162x4
template
<
>
__inline__
__device__
bf16_8_t
vec_conversion
<
bf16_8_t
,
Float8_
>
(
const
Float8_
&
a
)
{
bf16_8_t
b
;
b
.
x
=
__float22bfloat162_rn
(
a
.
x
);
b
.
y
=
__float22bfloat162_rn
(
a
.
y
);
b
.
z
=
__float22bfloat162_rn
(
a
.
z
);
b
.
w
=
__float22bfloat162_rn
(
a
.
w
);
return
b
;
}
//
#if HIP_FP8_TYPE_OCP
//
using fp8_type = __hip_fp8_e4m3;
//
using fp8x2_type = __hip_fp8x2_e4m3;
//
#else
//
using fp8_type = __hip_fp8_e4m3_fnuz;
//
using fp8x2_type = __hip_fp8x2_e4m3_fnuz;
//
#endif
//
// fp8 -> half
//
template <>
//
__inline__ __device__ uint16_t
//
vec_conversion<uint16_t, uint8_t>(const uint8_t& a) {
//
return __hip_cvt_fp8_to_halfraw(a, fp8_type::__default_interpret).x;
//
}
//
// fp8x2 -> half2
//
template <>
//
__inline__ __device__ uint32_t
//
vec_conversion<uint32_t, uint16_t>(const uint16_t& a) {
//
union {
//
__half2_raw h2r;
//
uint32_t ui32;
//
} tmp;
//
tmp.h2r = __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret);
//
return tmp.ui32;
//
}
//
// fp8x4 -> half2x2
//
template <>
//
__inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(const uint32_t& a) {
//
union {
//
uint2 u32x2;
//
uint32_t u32[2];
//
} tmp;
//
tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a);
//
tmp.u32[1] = vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U));
//
return tmp.u32x2;
//
}
//
// fp8x8 -> half2x4
//
template <>
//
__inline__ __device__ uint4 vec_conversion<uint4, uint2>(const uint2& a) {
//
union {
//
uint4 u64x2;
//
uint2 u64[2];
//
} tmp;
//
tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x);
//
tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y);
//
return tmp.u64x2;
//
}
//
using __nv_bfloat16 = __hip_bfloat16;
//
// fp8 -> __nv_bfloat16
//
template <>
//
__inline__ __device__ __nv_bfloat16
//
vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a) {
//
fp8_type f8;
//
f8.__x = a;
//
return __float2bfloat16(static_cast<float>(f8));
//
}
//
using __nv_bfloat162 = __hip_bfloat162;
//
// fp8x2 -> __nv_bfloat162
//
template <>
//
__inline__ __device__ __nv_bfloat162
//
vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a) {
//
__nv_bfloat162 res;
//
res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a);
//
res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U));
//
return res;
//
}
//
// fp8x4 -> bf16_4_t
//
template <>
//
__inline__ __device__ bf16_4_t
//
vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a) {
//
bf16_4_t res;
//
res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a);
//
res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U));
//
return res;
//
}
//
// fp8x8 -> bf16_8_t
//
template <>
//
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a) {
//
bf16_4_t tmp1, tmp2;
//
tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x);
//
tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y);
//
bf16_8_t res;
//
res.x = tmp1.x;
//
res.y = tmp1.y;
//
res.z = tmp2.x;
//
res.w = tmp2.y;
//
return res;
//
}
//
// fp8 -> float
//
template <>
//
__inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a) {
//
fp8_type f8;
//
f8.__x = a;
//
return static_cast<float>(f8);
//
}
//
// fp8x2 -> float2
//
template <>
//
__inline__ __device__ float2
//
vec_conversion<float2, uint16_t>(const uint16_t& a) {
//
fp8x2_type f8x2;
//
f8x2.__x = a;
//
return static_cast<float2>(f8x2);
//
}
//
// fp8x4 -> float4
//
template <>
//
__inline__ __device__ Float4_
//
vec_conversion<Float4_, uint32_t>(const uint32_t& a) {
//
Float4_ res;
//
res.x = vec_conversion<float2, uint16_t>((uint16_t)a);
//
res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U));
//
return res;
//
}
//
// fp8x4 -> float4
//
template <>
//
__inline__ __device__ float4
//
vec_conversion<float4, uint32_t>(const uint32_t& a) {
//
Float4_ tmp = vec_conversion<Float4_, uint32_t>(a);
//
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
//
return res;
//
}
//
// fp8x8 -> float8
//
template <>
//
__inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a) {
//
Float4_ tmp1, tmp2;
//
tmp1 = vec_conversion<Float4_, uint32_t>(a.x);
//
tmp2 = vec_conversion<Float4_, uint32_t>(a.y);
//
Float8_ res;
//
res.x = tmp1.x;
//
res.y = tmp1.y;
//
res.z = tmp2.x;
//
res.w = tmp2.y;
//
return res;
//
}
//
// half -> fp8
//
template <>
//
__inline__ __device__ uint8_t
//
vec_conversion<uint8_t, uint16_t>(const uint16_t& a) {
//
__half_raw tmp;
//
tmp.x = a;
//
return __hip_cvt_halfraw_to_fp8(tmp, fp8_type::__default_saturation,
//
fp8_type::__default_interpret);
//
}
//
template <>
//
__inline__ __device__ uint16_t
//
vec_conversion<uint16_t, uint32_t>(const uint32_t& a) {
//
union {
//
uint32_t ui32;
//
__half2_raw h2r;
//
} tmp;
//
tmp.ui32 = a;
//
return __hip_cvt_halfraw2_to_fp8x2(tmp.h2r, fp8_type::__default_saturation,
//
fp8_type::__default_interpret);
//
}
//
// bf16 -> fp8
//
template <>
//
__inline__ __device__ uint8_t
//
vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a) {
//
return __hip_cvt_float_to_fp8(__bfloat162float(a),
//
fp8_type::__default_saturation,
//
fp8_type::__default_interpret);
//
}
//
// float -> fp8
//
template <>
//
__inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a) {
//
return __hip_cvt_float_to_fp8(a, fp8_type::__default_saturation,
//
fp8_type::__default_interpret);
//
}
//
// float2 -> half2
//
template <>
//
__inline__ __device__ uint32_t
//
vec_conversion<uint32_t, float2>(const float2& a) {
//
union {
//
half2 float16;
//
uint32_t uint32;
//
};
//
float16 = __float22half2_rn(a);
//
return uint32;
//
}
//
// Float4 -> half2x2
//
template <>
//
__inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a) {
//
uint2 b;
//
float2 val;
//
val.x = a.x.x;
//
val.y = a.x.y;
//
b.x = vec_conversion<uint32_t, float2>(val);
//
val.x = a.y.x;
//
val.y = a.y.y;
//
b.y = vec_conversion<uint32_t, float2>(val);
//
return b;
//
}
//
// Float4 -> float4
//
template <>
//
__inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a) {
//
float4 b;
//
b.x = a.x.x;
//
b.y = a.x.y;
//
b.z = a.y.x;
//
b.w = a.y.y;
//
return b;
//
}
//
// Float8 -> half2x4
//
template <>
//
__inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a) {
//
uint4 b;
//
b.x = vec_conversion<uint32_t, float2>(a.x);
//
b.y = vec_conversion<uint32_t, float2>(a.y);
//
b.z = vec_conversion<uint32_t, float2>(a.z);
//
b.w = vec_conversion<uint32_t, float2>(a.w);
//
return b;
//
}
//
// float2 -> bfloat162
//
template <>
//
__inline__ __device__ __nv_bfloat162
//
vec_conversion<__nv_bfloat162, float2>(const float2& a) {
//
__nv_bfloat162 b = __float22bfloat162_rn(a);
//
return b;
//
}
//
// Float4 -> bfloat162x2
//
template <>
//
__inline__ __device__ bf16_4_t
//
vec_conversion<bf16_4_t, Float4_>(const Float4_& a) {
//
bf16_4_t b;
//
b.x = __float22bfloat162_rn(a.x);
//
b.y = __float22bfloat162_rn(a.y);
//
return b;
//
}
//
// Float8 -> bfloat162x4
//
template <>
//
__inline__ __device__ bf16_8_t
//
vec_conversion<bf16_8_t, Float8_>(const Float8_& a) {
//
bf16_8_t b;
//
b.x = __float22bfloat162_rn(a.x);
//
b.y = __float22bfloat162_rn(a.y);
//
b.z = __float22bfloat162_rn(a.z);
//
b.w = __float22bfloat162_rn(a.w);
//
return b;
//
}
/* Scaled and vectorized conversions, for data exchange between high and low
precision domains
...
...
@@ -338,42 +370,47 @@ using __nv_bfloat16 = __hip_bfloat16;
// fp8 -> __nv_bfloat16
template
<
>
__inline__
__device__
__nv_bfloat16
scaled_vec_conversion
<
__nv_bfloat16
,
uint8_t
>
(
const
uint8_t
&
a
,
float
scale
)
{
fp8_type
f8
;
f8
.
__x
=
a
;
return
__float2bfloat16
(
static_cast
<
float
>
(
f8
)
*
scale
);
scaled_vec_conversion
<
__nv_bfloat16
,
uint8_t
>
(
const
uint8_t
&
a
,
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
if
(
kv_type
==
vllm
::
Fp8KVCacheDataType
::
kFp8E5M2
)
{
assert
(
false
);
}
return
__float2bfloat16
(
fp8_to_float
(
a
)
*
scale
);
// fp8_type f8;
// f8.__x = a;
// return __float2bfloat16(static_cast<float>(f8) * scale);
}
// fp8x2 -> __nv_bfloat162
template
<
>
__inline__
__device__
__nv_bfloat162
scaled_vec_conversion
<
__nv_bfloat162
,
uint16_t
>
(
const
uint16_t
&
a
,
float
scale
)
{
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
__nv_bfloat162
res
;
res
.
x
=
scaled_vec_conversion
<
__nv_bfloat16
,
uint8_t
>
((
uint8_t
)
a
,
scale
);
res
.
x
=
scaled_vec_conversion
<
__nv_bfloat16
,
uint8_t
>
((
uint8_t
)
a
,
scale
,
kv_type
);
res
.
y
=
scaled_vec_conversion
<
__nv_bfloat16
,
uint8_t
>
((
uint8_t
)(
a
>>
8U
),
scale
);
scaled_vec_conversion
<
__nv_bfloat16
,
uint8_t
>
((
uint8_t
)(
a
>>
8U
),
scale
,
kv_type
);
return
res
;
}
// fp8x4 -> bf16_4_t
template
<
>
__inline__
__device__
bf16_4_t
scaled_vec_conversion
<
bf16_4_t
,
uint32_t
>
(
const
uint32_t
&
a
,
float
scale
)
{
scaled_vec_conversion
<
bf16_4_t
,
uint32_t
>
(
const
uint32_t
&
a
,
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
bf16_4_t
res
;
res
.
x
=
scaled_vec_conversion
<
__nv_bfloat162
,
uint16_t
>
((
uint16_t
)
a
,
scale
);
res
.
x
=
scaled_vec_conversion
<
__nv_bfloat162
,
uint16_t
>
((
uint16_t
)
a
,
scale
,
kv_type
);
res
.
y
=
scaled_vec_conversion
<
__nv_bfloat162
,
uint16_t
>
((
uint16_t
)(
a
>>
16U
),
scale
);
scale
,
kv_type
);
return
res
;
}
// fp8x8 -> bf16_8_t
template
<
>
__inline__
__device__
bf16_8_t
scaled_vec_conversion
<
bf16_8_t
,
uint2
>
(
const
uint2
&
a
,
float
scale
)
{
scaled_vec_conversion
<
bf16_8_t
,
uint2
>
(
const
uint2
&
a
,
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
bf16_4_t
tmp1
,
tmp2
;
tmp1
=
scaled_vec_conversion
<
bf16_4_t
,
uint32_t
>
(
a
.
x
,
scale
);
tmp2
=
scaled_vec_conversion
<
bf16_4_t
,
uint32_t
>
(
a
.
y
,
scale
);
tmp1
=
scaled_vec_conversion
<
bf16_4_t
,
uint32_t
>
(
a
.
x
,
scale
,
kv_type
);
tmp2
=
scaled_vec_conversion
<
bf16_4_t
,
uint32_t
>
(
a
.
y
,
scale
,
kv_type
);
bf16_8_t
res
;
res
.
x
=
tmp1
.
x
;
res
.
y
=
tmp1
.
y
;
...
...
@@ -385,46 +422,55 @@ scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, float scale) {
// fp8 -> float
template
<
>
__inline__
__device__
float
scaled_vec_conversion
<
float
,
uint8_t
>
(
const
uint8_t
&
a
,
float
scale
)
{
fp8_type
f8
;
f8
.
__x
=
a
;
return
static_cast
<
float
>
(
f8
)
*
scale
;
const
uint8_t
&
a
,
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
if
(
kv_type
==
vllm
::
Fp8KVCacheDataType
::
kFp8E5M2
)
{
assert
(
false
);
}
return
fp8_to_float
(
a
)
*
scale
;
// fp8_type f8;
// f8.__x = a;
// return static_cast<float>(f8) * scale;
}
// fp8x2 -> float2
template
<
>
__inline__
__device__
float2
scaled_vec_conversion
<
float2
,
uint16_t
>
(
const
uint16_t
&
a
,
float
scale
)
{
fp8x2_type
f8x2
;
f8x2
.
__x
=
a
;
return
static_cast
<
float2
>
(
f8x2
)
*
scale
;
scaled_vec_conversion
<
float2
,
uint16_t
>
(
const
uint16_t
&
a
,
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
float2
f2r
;
f2r
.
x
=
scaled_vec_conversion
<
float
,
uint8_t
>
((
uint8_t
)
a
,
scale
,
kv_type
);
f2r
.
y
=
scaled_vec_conversion
<
float
,
uint8_t
>
((
uint8_t
)(
a
>>
8U
),
scale
,
kv_type
);
return
f2r
;
// [[maybe_unused]]
// fp8x2_type f8x2;
// f8x2.__x = a;
// return static_cast<float2>(f8x2) * scale;
}
// fp8x4 -> float4
template
<
>
__inline__
__device__
Float4_
scaled_vec_conversion
<
Float4_
,
uint32_t
>
(
const
uint32_t
&
a
,
const
float
scale
)
{
scaled_vec_conversion
<
Float4_
,
uint32_t
>
(
const
uint32_t
&
a
,
const
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
Float4_
res
;
res
.
x
=
scaled_vec_conversion
<
float2
,
uint16_t
>
((
uint16_t
)
a
,
scale
);
res
.
y
=
scaled_vec_conversion
<
float2
,
uint16_t
>
((
uint16_t
)(
a
>>
16U
),
scale
);
res
.
x
=
scaled_vec_conversion
<
float2
,
uint16_t
>
((
uint16_t
)
a
,
scale
,
kv_type
);
res
.
y
=
scaled_vec_conversion
<
float2
,
uint16_t
>
((
uint16_t
)(
a
>>
16U
),
scale
,
kv_type
);
return
res
;
}
// fp8x4 -> float4
template
<
>
__inline__
__device__
float4
scaled_vec_conversion
<
float4
,
uint32_t
>
(
const
uint32_t
&
a
,
float
scale
)
{
Float4_
res
=
scaled_vec_conversion
<
Float4_
,
uint32_t
>
(
a
,
scale
);
scaled_vec_conversion
<
float4
,
uint32_t
>
(
const
uint32_t
&
a
,
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
Float4_
res
=
scaled_vec_conversion
<
Float4_
,
uint32_t
>
(
a
,
scale
,
kv_type
);
return
{
res
.
x
.
x
,
res
.
x
.
y
,
res
.
y
.
x
,
res
.
y
.
y
};
}
// fp8x8 -> float8
template
<
>
__inline__
__device__
Float8_
scaled_vec_conversion
<
Float8_
,
uint2
>
(
const
uint2
&
a
,
float
scale
)
{
scaled_vec_conversion
<
Float8_
,
uint2
>
(
const
uint2
&
a
,
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
Float4_
tmp1
,
tmp2
;
tmp1
=
scaled_vec_conversion
<
Float4_
,
uint32_t
>
(
a
.
x
,
scale
);
tmp2
=
scaled_vec_conversion
<
Float4_
,
uint32_t
>
(
a
.
y
,
scale
);
tmp1
=
scaled_vec_conversion
<
Float4_
,
uint32_t
>
(
a
.
x
,
scale
,
kv_type
);
tmp2
=
scaled_vec_conversion
<
Float4_
,
uint32_t
>
(
a
.
y
,
scale
,
kv_type
);
Float8_
res
;
res
.
x
=
tmp1
.
x
;
res
.
y
=
tmp1
.
y
;
...
...
@@ -436,200 +482,249 @@ scaled_vec_conversion<Float8_, uint2>(const uint2& a, float scale) {
// fp8 -> half
template
<
>
__inline__
__device__
uint16_t
scaled_vec_conversion
<
uint16_t
,
uint8_t
>
(
const
uint8_t
&
a
,
float
scale
)
{
__half_raw
res
;
res
.
data
=
scaled_vec_conversion
<
float
,
uint8_t
>
(
a
,
scale
);
return
res
.
x
;
scaled_vec_conversion
<
uint16_t
,
uint8_t
>
(
const
uint8_t
&
a
,
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
if
(
kv_type
==
vllm
::
Fp8KVCacheDataType
::
kFp8E5M2
)
{
assert
(
false
);
}
float
res
=
fp8_to_float
(
a
)
*
scale
;
return
float_to_half
(
res
);
// __half_raw res;
// res.data = scaled_vec_conversion<float, uint8_t>(a, scale);
// return res.x;
}
// fp8x2 -> half2
template
<
>
__inline__
__device__
uint32_t
scaled_vec_conversion
<
uint32_t
,
uint16_t
>
(
const
uint16_t
&
a
,
float
scale
)
{
scaled_vec_conversion
<
uint32_t
,
uint16_t
>
(
const
uint16_t
&
a
,
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
union
{
__half2_raw
h2r
;
uint32_t
ui32
;
}
tmp
;
tmp
.
h2r
=
__hip_cvt_fp8x2_to_halfraw2
(
a
,
fp8_type
::
__default_interpret
);
tmp
.
h2r
.
x
.
data
*=
scale
;
tmp
.
h2r
.
y
.
data
*=
scale
;
return
tmp
.
ui32
;
uint16_t
u16
[
2
];
uint32_t
u32
;
}
res
;
res
.
u16
[
0
]
=
scaled_vec_conversion
<
uint16_t
,
uint8_t
>
((
uint8_t
)
a
,
scale
,
kv_type
);
res
.
u16
[
1
]
=
scaled_vec_conversion
<
uint16_t
,
uint8_t
>
((
uint8_t
)(
a
>>
8U
),
scale
,
kv_type
);
return
res
.
u32
;
// [[maybe_unused]] __half2_raw h2r =
// __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret);
// union {
// __half2_raw h2r;
// uint32_t ui32;
// } tmp;
// tmp.h2r = __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret);
// tmp.h2r.x.data *= scale;
// tmp.h2r.y.data *= scale;
// return tmp.ui32;
}
// fp8x4 -> half2x2
template
<
>
__inline__
__device__
uint2
scaled_vec_conversion
<
uint2
,
uint32_t
>
(
const
uint32_t
&
a
,
float
scale
)
{
scaled_vec_conversion
<
uint2
,
uint32_t
>
(
const
uint32_t
&
a
,
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
union
{
uint2
u32x2
;
uint32_t
u32
[
2
];
}
tmp
;
tmp
.
u32
[
0
]
=
scaled_vec_conversion
<
uint32_t
,
uint16_t
>
((
uint16_t
)
a
,
scale
);
tmp
.
u32
[
1
]
=
scaled_vec_conversion
<
uint32_t
,
uint16_t
>
((
uint16_t
)(
a
>>
16U
),
scale
);
tmp
.
u32
[
0
]
=
scaled_vec_conversion
<
uint32_t
,
uint16_t
>
((
uint16_t
)
a
,
scale
,
kv_type
);
tmp
.
u32
[
1
]
=
scaled_vec_conversion
<
uint32_t
,
uint16_t
>
((
uint16_t
)(
a
>>
16U
),
scale
,
kv_type
);
return
tmp
.
u32x2
;
}
// fp8x8 -> half2x4
template
<
>
__inline__
__device__
uint4
scaled_vec_conversion
<
uint4
,
uint2
>
(
const
uint2
&
a
,
float
scale
)
{
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
union
{
uint4
u64x2
;
uint2
u64
[
2
];
}
tmp
;
tmp
.
u64
[
0
]
=
scaled_vec_conversion
<
uint2
,
uint32_t
>
(
a
.
x
,
scale
);
tmp
.
u64
[
1
]
=
scaled_vec_conversion
<
uint2
,
uint32_t
>
(
a
.
y
,
scale
);
tmp
.
u64
[
0
]
=
scaled_vec_conversion
<
uint2
,
uint32_t
>
(
a
.
x
,
scale
,
kv_type
);
tmp
.
u64
[
1
]
=
scaled_vec_conversion
<
uint2
,
uint32_t
>
(
a
.
y
,
scale
,
kv_type
);
return
tmp
.
u64x2
;
}
// half -> fp8
template
<
>
__inline__
__device__
uint8_t
scaled_vec_conversion
<
uint8_t
,
uint16_t
>
(
const
uint16_t
&
a
,
float
scale
)
{
__half_raw
tmp
;
tmp
.
x
=
a
;
tmp
.
data
/=
scale
;
return
__hip_cvt_halfraw_to_fp8
(
tmp
,
fp8_type
::
__default_saturation
,
fp8_type
::
__default_interpret
);
scaled_vec_conversion
<
uint8_t
,
uint16_t
>
(
const
uint16_t
&
a
,
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
float
res_f
=
half_to_float
(
a
)
/
scale
;
if
(
kv_type
==
vllm
::
Fp8KVCacheDataType
::
kFp8E4M3
)
{
return
float_to_fp8_e4m3
(
res_f
);
}
else
{
return
float_to_fp8_e5m2
(
res_f
);
}
// __half_raw tmp;
// tmp.x = a;
// tmp.data /= scale;
// return __hip_cvt_halfraw_to_fp8(tmp, fp8_type::__default_saturation,
// fp8_type::__default_interpret);
}
// halfx2 -> fp8x2
template
<
>
__inline__
__device__
uint16_t
scaled_vec_conversion
<
uint16_t
,
uint32_t
>
(
const
uint32_t
&
a
,
float
scale
)
{
scaled_vec_conversion
<
uint16_t
,
uint32_t
>
(
const
uint32_t
&
a
,
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
union
{
uint
32
_t
ui
32
;
__half2_raw
h2r
;
uint
8
_t
ui
8
[
2
]
;
uint16_t
ui16
;
}
tmp
;
tmp
.
ui32
=
a
;
tmp
.
h2r
.
x
.
data
/=
scale
;
tmp
.
h2r
.
y
.
data
/=
scale
;
return
__hip_cvt_halfraw2_to_fp8x2
(
tmp
.
h2r
,
fp8_type
::
__default_saturation
,
fp8_type
::
__default_interpret
);
union
{
uint32_t
ui32
;
half2
h2r
;
}
tmp_a
;
tmp_a
.
ui32
=
a
;
tmp
.
ui8
[
0
]
=
scaled_vec_conversion
<
uint8_t
,
uint16_t
>
(
tmp_a
.
h2r
.
data
[
0
],
scale
,
kv_type
);
tmp
.
ui8
[
1
]
=
scaled_vec_conversion
<
uint8_t
,
uint16_t
>
(
tmp_a
.
h2r
.
data
[
1
],
scale
,
kv_type
);
return
tmp
.
ui16
;
// union {
// uint32_t ui32;
// __half2_raw h2r;
// } tmp;
// tmp.ui32 = a;
// tmp.h2r.x.data /= scale;
// tmp.h2r.y.data /= scale;
// return __hip_cvt_halfraw2_to_fp8x2(tmp.h2r, fp8_type::__default_saturation,
// fp8_type::__default_interpret);
}
// half2x2 -> fp8x4
template
<
>
__inline__
__device__
uint32_t
scaled_vec_conversion
<
uint32_t
,
uint2
>
(
const
uint2
&
a
,
float
scale
)
{
scaled_vec_conversion
<
uint32_t
,
uint2
>
(
const
uint2
&
a
,
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
union
{
uint16_t
ui16
[
2
];
uint32_t
ui32
;
}
tmp
;
tmp
.
ui16
[
0
]
=
scaled_vec_conversion
<
uint16_t
,
uint32_t
>
(
a
.
x
,
scale
);
tmp
.
ui16
[
1
]
=
scaled_vec_conversion
<
uint16_t
,
uint32_t
>
(
a
.
y
,
scale
);
tmp
.
ui16
[
0
]
=
scaled_vec_conversion
<
uint16_t
,
uint32_t
>
(
a
.
x
,
scale
,
kv_type
);
tmp
.
ui16
[
1
]
=
scaled_vec_conversion
<
uint16_t
,
uint32_t
>
(
a
.
y
,
scale
,
kv_type
);
return
tmp
.
ui32
;
}
// half2x4 -> fp8x8
template
<
>
__inline__
__device__
uint2
scaled_vec_conversion
<
uint2
,
uint4
>
(
const
uint4
&
a
,
float
scale
)
{
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
union
{
uint2
ui2
[
2
];
uint4
ui4
;
}
tmp
;
tmp
.
ui4
=
a
;
uint2
res
;
res
.
x
=
scaled_vec_conversion
<
uint32_t
,
uint2
>
(
tmp
.
ui2
[
0
],
scale
);
res
.
y
=
scaled_vec_conversion
<
uint32_t
,
uint2
>
(
tmp
.
ui2
[
1
],
scale
);
res
.
x
=
scaled_vec_conversion
<
uint32_t
,
uint2
>
(
tmp
.
ui2
[
0
],
scale
,
kv_type
);
res
.
y
=
scaled_vec_conversion
<
uint32_t
,
uint2
>
(
tmp
.
ui2
[
1
],
scale
,
kv_type
);
return
res
;
}
// bf16 -> fp8
template
<
>
__inline__
__device__
uint8_t
scaled_vec_conversion
<
uint8_t
,
__nv_bfloat16
>
(
const
__nv_bfloat16
&
a
,
float
scale
)
{
return
__hip_cvt_float_to_fp8
(
__bfloat162float
(
a
)
/
scale
,
fp8_type
::
__default_saturation
,
fp8_type
::
__default_interpret
);
const
__nv_bfloat16
&
a
,
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
float
res_f
=
(
static_cast
<
float
>
(
a
))
/
scale
;
if
(
kv_type
==
vllm
::
Fp8KVCacheDataType
::
kFp8E4M3
)
{
return
float_to_fp8_e4m3
(
res_f
);
}
else
{
return
float_to_fp8_e5m2
(
res_f
);
}
// return __hip_cvt_float_to_fp8(__bfloat162float(a) / scale,
// fp8_type::__default_saturation,
// fp8_type::__default_interpret);
}
// bf16x2 -> fp8x2
template
<
>
__inline__
__device__
uint16_t
scaled_vec_conversion
<
uint16_t
,
__nv_bfloat162
>
(
const
__nv_bfloat162
&
a
,
float
scale
)
{
const
__nv_bfloat162
&
a
,
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
union
{
uint8_t
ui8
[
2
];
uint16_t
ui16
;
}
tmp
;
tmp
.
ui8
[
0
]
=
scaled_vec_conversion
<
uint8_t
,
__nv_bfloat16
>
(
a
.
x
,
scale
);
tmp
.
ui8
[
1
]
=
scaled_vec_conversion
<
uint8_t
,
__nv_bfloat16
>
(
a
.
y
,
scale
);
tmp
.
ui8
[
0
]
=
scaled_vec_conversion
<
uint8_t
,
__nv_bfloat16
>
(
a
.
x
,
scale
,
kv_type
);
tmp
.
ui8
[
1
]
=
scaled_vec_conversion
<
uint8_t
,
__nv_bfloat16
>
(
a
.
y
,
scale
,
kv_type
);
return
tmp
.
ui16
;
}
// bf16x4 -> fp8x4
template
<
>
__inline__
__device__
uint32_t
scaled_vec_conversion
<
uint32_t
,
bf16_4_t
>
(
const
bf16_4_t
&
a
,
float
scale
)
{
scaled_vec_conversion
<
uint32_t
,
bf16_4_t
>
(
const
bf16_4_t
&
a
,
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
union
{
uint16_t
ui16
[
2
];
uint32_t
ui32
;
}
tmp
;
tmp
.
ui16
[
0
]
=
scaled_vec_conversion
<
uint16_t
,
__nv_bfloat162
>
(
a
.
x
,
scale
);
tmp
.
ui16
[
1
]
=
scaled_vec_conversion
<
uint16_t
,
__nv_bfloat162
>
(
a
.
y
,
scale
);
tmp
.
ui16
[
0
]
=
scaled_vec_conversion
<
uint16_t
,
__nv_bfloat162
>
(
a
.
x
,
scale
,
kv_type
);
tmp
.
ui16
[
1
]
=
scaled_vec_conversion
<
uint16_t
,
__nv_bfloat162
>
(
a
.
y
,
scale
,
kv_type
);
return
tmp
.
ui32
;
}
// bf16x8 -> fp8x8
template
<
>
__inline__
__device__
uint2
scaled_vec_conversion
<
uint2
,
bf16_8_t
>
(
const
bf16_8_t
&
a
,
float
scale
)
{
scaled_vec_conversion
<
uint2
,
bf16_8_t
>
(
const
bf16_8_t
&
a
,
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
uint2
res
;
res
.
x
=
scaled_vec_conversion
<
uint32_t
,
bf16_4_t
>
({
a
.
x
,
a
.
y
},
scale
);
res
.
y
=
scaled_vec_conversion
<
uint32_t
,
bf16_4_t
>
({
a
.
z
,
a
.
w
},
scale
);
res
.
x
=
scaled_vec_conversion
<
uint32_t
,
bf16_4_t
>
({
a
.
x
,
a
.
y
},
scale
,
kv_type
);
res
.
y
=
scaled_vec_conversion
<
uint32_t
,
bf16_4_t
>
({
a
.
z
,
a
.
w
},
scale
,
kv_type
);
return
res
;
}
// float -> fp8
template
<
>
__inline__
__device__
uint8_t
scaled_vec_conversion
<
uint8_t
,
float
>
(
const
float
&
a
,
float
scale
)
{
return
__hip_cvt_float_to_fp8
(
a
/
scale
,
fp8_type
::
__default_saturation
,
fp8_type
::
__default_interpret
);
scaled_vec_conversion
<
uint8_t
,
float
>
(
const
float
&
a
,
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
if
(
kv_type
==
vllm
::
Fp8KVCacheDataType
::
kFp8E4M3
)
{
return
float_to_fp8_e4m3
(
a
/
scale
);
}
else
{
return
float_to_fp8_e5m2
(
a
/
scale
);
}
// return __hip_cvt_float_to_fp8(a / scale, fp8_type::__default_saturation,
// fp8_type::__default_interpret);
}
// floatx2 -> fp8x2
template
<
>
__inline__
__device__
uint16_t
scaled_vec_conversion
<
uint16_t
,
float2
>
(
const
float2
&
a
,
float
scale
)
{
return
__hip_cvt_float2_to_fp8x2
(
a
/
scale
,
fp8_type
::
__default_saturation
,
fp8_type
::
__default_interpret
);
scaled_vec_conversion
<
uint16_t
,
float2
>
(
const
float2
&
a
,
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
union
{
uint8_t
ui8
[
2
];
uint16_t
ui16
;
}
tmp
;
tmp
.
ui8
[
0
]
=
scaled_vec_conversion
<
uint8_t
,
float
>
(
a
.
x
,
scale
,
kv_type
);
tmp
.
ui8
[
1
]
=
scaled_vec_conversion
<
uint8_t
,
float
>
(
a
.
y
,
scale
,
kv_type
);
return
tmp
.
ui16
;
// return __hip_cvt_float2_to_fp8x2(a / scale, fp8_type::__default_saturation,
// fp8_type::__default_interpret);
}
// floatx4 -> fp8x4
template
<
>
__inline__
__device__
uint32_t
scaled_vec_conversion
<
uint32_t
,
float4
>
(
const
float4
&
a
,
float
scale
)
{
scaled_vec_conversion
<
uint32_t
,
float4
>
(
const
float4
&
a
,
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
union
{
uint16_t
ui16
[
2
];
uint32_t
ui32
;
}
tmp
;
tmp
.
ui16
[
0
]
=
scaled_vec_conversion
<
uint16_t
,
float2
>
({
a
.
x
,
a
.
y
},
scale
);
tmp
.
ui16
[
1
]
=
scaled_vec_conversion
<
uint16_t
,
float2
>
({
a
.
z
,
a
.
w
},
scale
);
tmp
.
ui16
[
0
]
=
scaled_vec_conversion
<
uint16_t
,
float2
>
({
a
.
x
,
a
.
y
},
scale
,
kv_type
);
tmp
.
ui16
[
1
]
=
scaled_vec_conversion
<
uint16_t
,
float2
>
({
a
.
z
,
a
.
w
},
scale
,
kv_type
);
return
tmp
.
ui32
;
}
#endif // ENABLE_FP8
//
#endif // ENABLE_FP8
template
<
typename
Tout
,
typename
Tin
,
Fp8KVCacheDataType
kv_dt
>
__inline__
__device__
Tout
convert
(
const
Tin
&
x
)
{
#ifdef ENABLE_FP8
if
constexpr
(
kv_dt
==
Fp8KVCacheDataType
::
kFp8E4M3
)
{
return
vec_conversion
<
Tout
,
Tin
>
(
x
);
}
#endif
assert
(
false
);
return
{};
// Squash missing return statement warning
}
//
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
//
__inline__ __device__ Tout convert(const Tin& x) {
//
#ifdef ENABLE_FP8
//
if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
//
return vec_conversion<Tout, Tin>(x);
//
}
//
#endif
//
assert(false);
//
return {}; // Squash missing return statement warning
//
}
template
<
typename
Tout
,
typename
Tin
,
Fp8KVCacheDataType
kv_dt
>
__inline__
__device__
Tout
scaled_convert
(
const
Tin
&
x
,
const
float
scale
)
{
#ifdef ENABLE_FP8
if
constexpr
(
kv_dt
==
Fp8KVCacheDataType
::
kFp8E4M3
)
{
return
scaled_vec_conversion
<
Tout
,
Tin
>
(
x
,
scale
);
//
#ifdef ENABLE_FP8
if
constexpr
(
kv_dt
==
Fp8KVCacheDataType
::
kFp8E4M3
||
kv_dt
==
Fp8KVCacheDataType
::
kFp8E5M2
)
{
return
scaled_vec_conversion
<
Tout
,
Tin
>
(
x
,
scale
,
kv_dt
);
}
#endif
//
#endif
assert
(
false
);
return
{};
// Squash missing return statement warning
}
...
...
@@ -652,19 +747,31 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else if (KV_CACHE_DTYPE == vllm::Fp8KVCacheDataType::kFp8E4M3) { \
if (SRC_DTYPE == at::ScalarType::Float) { \
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else { \
TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
if (SRC_DTYPE == at::ScalarType::Float) { \
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else { \
TORCH_CHECK(false, \
"Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else if (KV_CACHE_DTYPE == vllm::Fp8KVCacheDataType::kFp8E5M2) { \
if (SRC_DTYPE == at::ScalarType::Float) { \
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \
} else { \
TORCH_CHECK(false, \
"Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else { \
TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \
}
}
}
// namespace fp8
#endif // USE_ROCM
}
// namespace vllm
}
// namespace vllm
\ No newline at end of file
csrc/quantization/w8a8/fp8/common.cuh
View file @
fab1acce
...
...
@@ -47,15 +47,19 @@ __device__ __forceinline__ fp8_type scaled_fp8_conversion(float const val,
x
=
val
/
scale
;
}
float
r
=
fmaxf
(
-
quant_type_max_v
<
fp8_type
>
,
fminf
(
x
,
quant_type_max_v
<
fp8_type
>
));
//
float r =
//
fmaxf(-quant_type_max_v<fp8_type>, fminf(x, quant_type_max_v<fp8_type>));
#ifndef USE_ROCM
// Use hardware cvt instruction for fp8 on nvidia
// Currently only support fp8_type = c10::Float8_e4m3fn
return
fp8
::
vec_conversion
<
fp8_type
,
float
>
(
r
);
#else
fp8_type
*
test
;
uint8_t
test_uint8
=
fp8
::
float_to_fp8_e4m3
(
x
);
test
=
(
fp8_type
*
)(
&
test_uint8
);
return
*
test
;
// Use hardware cvt instruction for fp8 on rocm
return
fp8
::
cvt_c10
<
fp8_type
>
(
r
);
//
return fp8::cvt_c10<fp8_type>(r);
#endif
}
...
...
requirements/rocm.txt
View file @
fab1acce
...
...
@@ -16,8 +16,13 @@ packaging>=24.2
setuptools>=77.0.3,<80.0.0
setuptools-scm>=8
runai-model-streamer[s3,gcs,azure]==0.15.7
conch-triton-kernels==1.2.1
#
conch-triton-kernels==1.2.1
timm>=1.0.17
# amd-quark: required for Quark quantization on ROCm
# To be consistent with test_quark.py
amd-quark>=0.8.99
# Other necessary dependencies
torch == 2.10.0
torchvision == 0.25.0
flash_attn == 2.8.3
setup.py
View file @
fab1acce
...
...
@@ -20,6 +20,12 @@ from setuptools import Extension, setup
from
setuptools.command.build_ext
import
build_ext
from
setuptools_scm
import
get_version
from
torch.utils.cpp_extension
import
CUDA_HOME
,
ROCM_HOME
from
typing
import
Optional
,
Union
pwd
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
add_git_version
=
False
if
int
(
os
.
environ
.
get
(
'ADD_GIT_VERSION'
,
'0'
))
==
1
:
add_git_version
=
True
def
load_module_from_path
(
module_name
,
path
):
...
...
@@ -365,7 +371,7 @@ class cmake_build_ext(build_ext):
os
.
makedirs
(
os
.
path
.
dirname
(
dst_file
),
exist_ok
=
True
)
self
.
copy_file
(
file
,
dst_file
)
if
_is_cuda
()
or
_is_hip
()
:
if
_is_cuda
():
# copy vllm/third_party/triton_kernels/**/*.py from self.build_lib
# to current directory so that they can be included in the editable
# build
...
...
@@ -895,6 +901,94 @@ def get_nvcc_cuda_version() -> Version:
return
nvcc_cuda_version
def
get_sha
(
root
:
Union
[
str
,
Path
])
->
str
:
try
:
return
subprocess
.
check_output
([
'git'
,
'rev-parse'
,
'HEAD'
],
cwd
=
root
).
decode
(
'ascii'
).
strip
()
except
Exception
:
return
'Unknown'
def
get_version_add
(
sha
:
Optional
[
str
]
=
None
)
->
str
:
command
=
"git config --global --add safe.directory "
+
pwd
subprocess
.
run
(
command
,
shell
=
True
,
capture_output
=
False
,
text
=
True
)
vllm_root
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
add_version_path
=
os
.
path
.
join
(
os
.
path
.
join
(
vllm_root
,
"vllm"
),
"version.py"
)
major
,
minor
,
_
=
torch
.
__version__
.
split
(
'.'
)
if
add_git_version
:
if
sha
!=
'Unknown'
:
if
sha
is
None
:
sha
=
get_sha
(
vllm_root
)
version
=
'das.'
+
sha
[:
7
]
else
:
version
=
'das'
# dtk version
if
os
.
getenv
(
"ROCM_PATH"
):
rocm_path
=
os
.
getenv
(
'ROCM_PATH'
,
""
)
rocm_version_path
=
os
.
path
.
join
(
rocm_path
,
'.info'
,
"rocm_version"
)
with
open
(
rocm_version_path
,
'r'
,
encoding
=
'utf-8'
)
as
file
:
lines
=
file
.
readlines
()
rocm_version
=
lines
[
0
].
replace
(
"."
,
""
)
version
+=
".dtk"
+
rocm_version
new_version_content
=
f
"""
try:
__version__ = "0.20.0"
__version_tuple__ = (0, 20, 0)
__hcu_version__ = f'0.20.0+
{
version
}
'
from vllm.version import __version__, __version_tuple__, __hcu_version__
except Exception as e:
import warnings
warnings.warn(f"Failed to read commit hash:
\\
n + str(e)",
RuntimeWarning,
stacklevel=2)
__version__ = "dev"
__version_tuple__ = (0, 0, __version__)
def _prev_minor_version_was(version_str):
'''Check whether a given version matches the previous minor version.
Return True if version_str matches the previous minor version.
For example - return True if the current version if 0.7.4 and the
supplied version_str is '0.6'.
Used for --show-hidden-metrics-for-version.
'''
# Match anything if this is a dev tree
if __version_tuple__[0:2] == (0, 0):
return True
# Note - this won't do the right thing when we release 1.0!
# assert __version_tuple__[0] == 0
assert isinstance(__version_tuple__[1], int)
return version_str == f"{{__version_tuple__[0]}}.{{__version_tuple__[1] - 1}}"
def _prev_minor_version():
'''For the purpose of testing, return a previous minor version number.'''
# In dev tree, this will return "0.-1", but that will work fine"
assert isinstance(__version_tuple__[1], int)
return f"{{__version_tuple__[0]}}.{{__version_tuple__[1] - 1}}"
"""
with
open
(
add_version_path
,
encoding
=
"utf-8"
,
mode
=
"w"
)
as
file
:
file
.
write
(
new_version_content
)
file
.
close
()
def
get_version
():
get_version_add
()
version_file
=
'vllm/version.py'
with
open
(
version_file
,
encoding
=
'utf-8'
)
as
f
:
exec
(
compile
(
f
.
read
(),
version_file
,
'exec'
))
return
locals
()[
'__hcu_version__'
]
def
get_vllm_version
()
->
str
:
# Allow overriding the version. This is useful to build platform-specific
# wheels (e.g. CPU, TPU) without modifying the source.
...
...
@@ -903,8 +997,9 @@ def get_vllm_version() -> str:
os
.
environ
[
"SETUPTOOLS_SCM_PRETEND_VERSION"
]
=
env_version
return
get_version
(
write_to
=
"vllm/_version.py"
)
version
=
get_version
(
write_to
=
"vllm/_version.py"
)
sep
=
"+"
if
"+"
not
in
version
else
"."
# dev versions might contain +
if
not
_is_hip
():
version
=
get_version
(
write_to
=
"vllm/_version.py"
)
sep
=
"+"
if
"+"
not
in
version
else
"."
# dev versions might contain +
if
_no_device
():
if
envs
.
VLLM_TARGET_DEVICE
==
"empty"
:
...
...
@@ -921,9 +1016,10 @@ def get_vllm_version() -> str:
version
+=
f
"
{
sep
}
cu
{
cuda_version_str
}
"
elif
_is_hip
():
# Get the Rocm Version
rocm_version
=
get_rocm_version
()
or
torch
.
version
.
hip
if
rocm_version
and
rocm_version
!=
envs
.
VLLM_MAIN_CUDA_VERSION
:
version
+=
f
"
{
sep
}
rocm
{
rocm_version
.
replace
(
'.'
,
''
)[:
3
]
}
"
# rocm_version = get_rocm_version() or torch.version.hip
# if rocm_version and rocm_version != envs.VLLM_MAIN_CUDA_VERSION:
# version += f"{sep}rocm{rocm_version.replace('.', '')[:3]}"
version
=
get_version
()
elif
_is_tpu
():
version
+=
f
"
{
sep
}
tpu"
elif
_is_cpu
():
...
...
@@ -991,7 +1087,7 @@ if _is_cuda() or _is_hip():
ext_modules
.
append
(
CMakeExtension
(
name
=
"vllm.cumem_allocator"
))
# Optional since this doesn't get built (produce an .so file). This is just
# copying the relevant .py files from the source repository.
ext_modules
.
append
(
CMakeExtension
(
name
=
"vllm.triton_kernels"
,
optional
=
True
))
#
ext_modules.append(CMakeExtension(name="vllm.triton_kernels", optional=True))
if
_is_hip
():
ext_modules
.
append
(
CMakeExtension
(
name
=
"vllm._rocm_C"
))
...
...
vllm/platforms/rocm.py
View file @
fab1acce
...
...
@@ -44,10 +44,10 @@ except ImportError as e:
logger
.
warning
(
"Failed to import from vllm._C with %r"
,
e
)
# import custom ops, trigger op registration
try
:
import
vllm._rocm_C
# noqa: F401
except
ImportError
as
e
:
logger
.
warning
(
"Failed to import from vllm._rocm_C with %r"
,
e
)
#
try:
#
import vllm._rocm_C # noqa: F401
#
except ImportError as e:
#
logger.warning("Failed to import from vllm._rocm_C with %r", e)
# Models not supported by ROCm.
_ROCM_UNSUPPORTED_MODELS
:
list
[
str
]
=
[]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment