Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
45273722
Commit
45273722
authored
May 14, 2025
by
xiabo
Browse files
add kvint8
parent
17e4dd25
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
361 additions
and
3 deletions
+361
-3
csrc/attention/attention_kernels.cuh
csrc/attention/attention_kernels.cuh
+15
-1
csrc/attention/attention_kernels_opt.cu
csrc/attention/attention_kernels_opt.cu
+15
-0
csrc/attention/dtype_fp8.cuh
csrc/attention/dtype_fp8.cuh
+1
-0
csrc/cache_kernels.cu
csrc/cache_kernels.cu
+16
-0
csrc/quantization/fp8/amd/quant_utils.cuh
csrc/quantization/fp8/amd/quant_utils.cuh
+10
-0
csrc/quantization/int8_kvcache/quant_utils.cuh
csrc/quantization/int8_kvcache/quant_utils.cuh
+288
-0
vllm/attention/ops/paged_attn.py
vllm/attention/ops/paged_attn.py
+2
-1
vllm/config.py
vllm/config.py
+1
-1
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+5
-0
vllm/utils.py
vllm/utils.py
+8
-0
No files found.
csrc/attention/attention_kernels.cuh
View file @
45273722
...
...
@@ -33,6 +33,8 @@ typedef __hip_bfloat16 __nv_bfloat16;
#include "../quantization/fp8/nvidia/quant_utils.cuh"
#endif
#include "../quantization/int8_kvcache/quant_utils.cuh"
#ifndef USE_ROCM
#define WARP_SIZE 32
#else
...
...
@@ -280,7 +282,13 @@ __device__ void paged_attention_kernel(
if
constexpr
(
KV_DTYPE
==
Fp8KVCacheDataType
::
kAuto
)
{
k_vecs
[
j
]
=
*
reinterpret_cast
<
const
K_vec
*>
(
k_ptr
+
offset1
*
BLOCK_SIZE
*
x
+
offset2
);
}
else
{
}
else
if
constexpr
(
KV_DTYPE
==
Fp8KVCacheDataType
::
kInt8
)
{
Quant_vec
k_vec_quant
=
*
reinterpret_cast
<
const
Quant_vec
*>
(
k_ptr
+
offset1
*
BLOCK_SIZE
*
x
+
offset2
);
k_vecs
[
j
]
=
int8
::
scaled_vec_conversion_int8
<
K_vec
,
Quant_vec
>
(
k_vec_quant
,
*
k_scale
);
}
else
{
// Vector conversion from Quant_vec to K_vec.
Quant_vec
k_vec_quant
=
*
reinterpret_cast
<
const
Quant_vec
*>
(
k_ptr
+
offset1
*
BLOCK_SIZE
*
x
+
offset2
);
...
...
@@ -410,6 +418,12 @@ __device__ void paged_attention_kernel(
if
constexpr
(
KV_DTYPE
==
Fp8KVCacheDataType
::
kAuto
)
{
v_vec
=
*
reinterpret_cast
<
const
V_vec
*>
(
v_ptr
+
offset
);
}
else
if
constexpr
(
KV_DTYPE
==
Fp8KVCacheDataType
::
kInt8
)
{
V_quant_vec
v_quant_vec
=
*
reinterpret_cast
<
const
V_quant_vec
*>
(
v_ptr
+
offset
);
// Vector conversion from V_quant_vec to V_vec.
v_vec
=
int8
::
scaled_vec_conversion_int8
<
V_vec
,
V_quant_vec
>
(
v_quant_vec
,
*
v_scale
);
}
else
{
V_quant_vec
v_quant_vec
=
*
reinterpret_cast
<
const
V_quant_vec
*>
(
v_ptr
+
offset
);
...
...
csrc/attention/attention_kernels_opt.cu
View file @
45273722
...
...
@@ -14,6 +14,8 @@ typedef __hip_bfloat16 __nv_bfloat16;
#include "../quantization/fp8/nvidia/quant_utils.cuh"
#endif
#include "../quantization/int8_kvcache/quant_utils.cuh"
#ifndef USE_ROCM
#define WARP_SIZE 32
#else
...
...
@@ -311,6 +313,12 @@ __device__ void paged_attention_kernel_opt(
if
constexpr
(
KV_DTYPE
==
Fp8KVCacheDataType
::
kAuto
)
{
k_vecs
[
j
]
=
*
reinterpret_cast
<
const
K_vec
*>
(
k_ptr
+
offset1
*
BLOCK_SIZE
*
x
+
offset2
);
}
else
if
constexpr
(
KV_DTYPE
==
Fp8KVCacheDataType
::
kInt8
)
{
Quant_vec
k_vec_quant
=
*
reinterpret_cast
<
const
Quant_vec
*>
(
k_ptr
+
offset1
*
BLOCK_SIZE
*
x
+
offset2
);
k_vecs
[
j
]
=
int8
::
scaled_vec_conversion_int8
<
K_vec
,
Quant_vec
>
(
k_vec_quant
,
*
k_scale_ptr
);
}
else
{
// Vector conversion from Quant_vec to K_vec.
Quant_vec
k_vec_quant
=
*
reinterpret_cast
<
const
Quant_vec
*>
(
...
...
@@ -478,6 +486,13 @@ __device__ void paged_attention_kernel_opt(
if
constexpr
(
KV_DTYPE
==
Fp8KVCacheDataType
::
kAuto
)
{
v_vec
=
*
reinterpret_cast
<
const
V_vec
*>
(
v_ptr
+
offset
);
}
else
if
constexpr
(
KV_DTYPE
==
Fp8KVCacheDataType
::
kInt8
)
{
// printf("======xiabo_kvint8\n");
V_quant_vec
v_quant_vec
=
*
reinterpret_cast
<
const
V_quant_vec
*>
(
v_ptr
+
offset
);
// Vector conversion from V_quant_vec to V_vec.
v_vec
=
int8
::
scaled_vec_conversion_int8
<
V_vec
,
V_quant_vec
>
(
v_quant_vec
,
*
v_scale_ptr
);
}
else
{
V_quant_vec
v_quant_vec
=
*
reinterpret_cast
<
const
V_quant_vec
*>
(
v_ptr
+
offset
);
...
...
csrc/attention/dtype_fp8.cuh
View file @
45273722
...
...
@@ -15,6 +15,7 @@ enum class Fp8KVCacheDataType {
kAuto
=
0
,
kFp8E4M3
=
1
,
kFp8E5M2
=
2
,
kInt8
=
3
,
};
// fp8 vector types for quantization of kv cache
...
...
csrc/cache_kernels.cu
View file @
45273722
...
...
@@ -12,6 +12,8 @@
#include "quantization/fp8/nvidia/quant_utils.cuh"
#endif
#include "quantization/int8_kvcache/quant_utils.cuh"
#include <algorithm>
#include <cassert>
#include <map>
...
...
@@ -252,6 +254,13 @@ __global__ void reshape_and_cache_kernel(
if
constexpr
(
kv_dt
==
Fp8KVCacheDataType
::
kAuto
)
{
key_cache
[
tgt_key_idx
]
=
tgt_key
;
value_cache
[
tgt_value_idx
]
=
tgt_value
;
}
else
if
constexpr
(
kv_dt
==
Fp8KVCacheDataType
::
kInt8
)
{
key_cache
[
tgt_key_idx
]
=
int8
::
scaled_vec_conversion_int8
<
cache_t
,
scalar_t
>
(
tgt_key
,
*
k_scale
);
value_cache
[
tgt_value_idx
]
=
int8
::
scaled_vec_conversion_int8
<
cache_t
,
scalar_t
>
(
tgt_value
,
*
v_scale
);
}
else
{
key_cache
[
tgt_key_idx
]
=
fp8
::
scaled_convert
<
cache_t
,
scalar_t
,
kv_dt
>
(
tgt_key
,
*
k_scale
);
...
...
@@ -296,6 +305,13 @@ __global__ void reshape_and_cache_flash_kernel(
if
constexpr
(
kv_dt
==
Fp8KVCacheDataType
::
kAuto
)
{
key_cache
[
tgt_key_value_idx
]
=
tgt_key
;
value_cache
[
tgt_key_value_idx
]
=
tgt_value
;
}
else
if
constexpr
(
kv_dt
==
Fp8KVCacheDataType
::
kInt8
)
{
key_cache
[
tgt_key_value_idx
]
=
int8
::
scaled_vec_conversion_int8
<
cache_t
,
scalar_t
>
(
tgt_key
,
*
k_scale
);
value_cache
[
tgt_key_value_idx
]
=
int8
::
scaled_vec_conversion_int8
<
cache_t
,
scalar_t
>
(
tgt_value
,
*
v_scale
);
}
else
{
key_cache
[
tgt_key_value_idx
]
=
fp8
::
scaled_convert
<
cache_t
,
scalar_t
,
kv_dt
>
(
tgt_key
,
*
k_scale
);
...
...
csrc/quantization/fp8/amd/quant_utils.cuh
View file @
45273722
...
...
@@ -653,6 +653,16 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
} else { \
TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else if (KV_DTYPE == "int8") { \
if (SRC_DTYPE == at::ScalarType::Float) { \
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kInt8); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kInt8); \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kInt8); \
} else { \
TORCH_CHECK(false,"Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else { \
if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \
if (SRC_DTYPE == at::ScalarType::Float) { \
...
...
csrc/quantization/int8_kvcache/quant_utils.cuh
0 → 100644
View file @
45273722
// Adated from FasterTransformer, https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
#pragma once
#include <assert.h>
#include <stdint.h>
#include <float.h>
#include <type_traits>
#include "../../attention/attention_dtypes.h"
#include <stdio.h>
namespace
vllm
{
namespace
int8
{
// KV-CACHE int8
static
inline
__device__
float
int8_to_float
(
uint8_t
x
,
const
float
scale
)
{
int8_t
a
=
x
-
128
;
float
res
=
a
*
scale
;
return
res
;
}
static
inline
__device__
uint8_t
float_to_int8
(
float
x
,
const
float
scale
)
{
int8_t
fx
=
roundf
(
max
(
-
128.
f
,
min
(
127.
f
,
x
/
scale
)));
uint8_t
res
=
fx
+
128
;
return
res
;
}
template
<
typename
Tout
,
typename
Tin
>
__inline__
__device__
Tout
scaled_vec_conversion_int8
(
const
Tin
&
x
,
const
float
scale
)
{
return
x
;
}
// int8 -> half
// template <>
// __inline__ __device__ uint16_t scaled_vec_conversion_int8<uint16_t, uint8_t>(
// const uint8_t& a, const float scale) {
// float res = int8_to_float(a, scale);
// return float_to_half(res);
// // return half(a);__float2half
// }
// int8x2 -> half2
template
<
>
__inline__
__device__
uint32_t
scaled_vec_conversion_int8
<
uint32_t
,
uint16_t
>
(
const
uint16_t
&
a
,
const
float
scale
)
{
union
{
uint8_t
uint8
[
2
];
uint16_t
uint16
;
};
uint16
=
a
;
float2
b
;
b
.
x
=
(
uint8
[
0
]
-
128
)
*
scale
;
b
.
y
=
(
uint8
[
1
]
-
128
)
*
scale
;
union
{
half2
float16
;
uint32_t
uint32
;
};
float16
=
__float22half2_rn
(
b
);
return
uint32
;
}
template
<
typename
Tout
,
typename
Tin
>
__inline__
__device__
Tout
vec_conversion
(
const
Tin
&
x
)
{
return
x
;
}
template
<
>
__inline__
__device__
uint32_t
vec_conversion
<
uint32_t
,
float2
>
(
const
float2
&
a
)
{
union
{
half2
float16
;
uint32_t
uint32
;
};
float16
=
__float22half2_rn
(
a
);
return
uint32
;
}
template
<
>
__inline__
__device__
uint2
vec_conversion
<
uint2
,
Float4_
>
(
const
Float4_
&
a
)
{
uint2
b
;
float2
val
;
val
.
x
=
a
.
x
.
x
;
val
.
y
=
a
.
x
.
y
;
b
.
x
=
vec_conversion
<
uint32_t
,
float2
>
(
val
);
val
.
x
=
a
.
y
.
x
;
val
.
y
=
a
.
y
.
y
;
b
.
y
=
vec_conversion
<
uint32_t
,
float2
>
(
val
);
return
b
;
}
// int8x4 -> half2x2
template
<
>
__inline__
__device__
uint2
scaled_vec_conversion_int8
<
uint2
,
uint32_t
>
(
const
uint32_t
&
a
,
const
float
scale
)
{
union
{
uint8_t
uint8
[
4
];
uint32_t
uint32
;
};
uint32
=
a
;
Float4_
b
;
b
.
x
.
x
=
(
uint8
[
0
]
-
128
)
*
scale
;
b
.
x
.
y
=
(
uint8
[
1
]
-
128
)
*
scale
;
b
.
y
.
x
=
(
uint8
[
2
]
-
128
)
*
scale
;
b
.
y
.
y
=
(
uint8
[
3
]
-
128
)
*
scale
;
return
vec_conversion
<
uint2
,
Float4_
>
(
b
);
}
inline
__device__
float2
dequant
(
uint16_t
a
,
const
float
scale
)
{
union
{
uint8_t
uint8
[
2
];
uint16_t
uint16
;
};
uint16
=
a
;
float2
b
;
b
.
x
=
(
uint8
[
0
]
-
128
)
*
scale
;
b
.
y
=
(
uint8
[
1
]
-
128
)
*
scale
;
return
b
;
}
// int8x8 -> half2x4
template
<
>
__inline__
__device__
uint4
scaled_vec_conversion_int8
<
uint4
,
uint2
>
(
const
uint2
&
a
,
const
float
scale
)
{
// scaled_vec_conversion_int8<uint4, uint64_t>(const uint64_t& a, const float scale) {
union
{
uint16_t
uint16
[
4
];
uint2
uint64
;
};
uint64
=
a
;
Float8_
b
;
b
.
x
=
dequant
(
uint16
[
0
],
scale
);
b
.
y
=
dequant
(
uint16
[
1
],
scale
);
b
.
z
=
dequant
(
uint16
[
2
],
scale
);
b
.
w
=
dequant
(
uint16
[
3
],
scale
);
uint4
c
;
c
.
x
=
vec_conversion
<
uint32_t
,
float2
>
(
b
.
x
);
c
.
y
=
vec_conversion
<
uint32_t
,
float2
>
(
b
.
y
);
c
.
z
=
vec_conversion
<
uint32_t
,
float2
>
(
b
.
z
);
c
.
w
=
vec_conversion
<
uint32_t
,
float2
>
(
b
.
w
);
return
c
;
}
// int8 -> __nv_bfloat16
template
<
>
__inline__
__device__
__nv_bfloat16
scaled_vec_conversion_int8
<
__nv_bfloat16
,
uint8_t
>
(
const
uint8_t
&
a
,
const
float
scale
)
{
// Note there is no direct convert function from int8 to bf16.
float
res
=
int8_to_float
(
a
,
scale
);
return
__float2bfloat16
(
res
);
}
// int8x2 -> __nv_bfloat162
template
<
>
__inline__
__device__
__nv_bfloat162
scaled_vec_conversion_int8
<
__nv_bfloat162
,
uint16_t
>
(
const
uint16_t
&
a
,
const
float
scale
)
{
__nv_bfloat162
res
;
res
.
x
=
scaled_vec_conversion_int8
<
__nv_bfloat16
,
uint8_t
>
((
uint8_t
)
a
,
scale
);
res
.
y
=
scaled_vec_conversion_int8
<
__nv_bfloat16
,
uint8_t
>
((
uint8_t
)(
a
>>
8U
),
scale
);
return
res
;
}
// int8x4 -> bf16_4_t
template
<
>
__inline__
__device__
bf16_4_t
scaled_vec_conversion_int8
<
bf16_4_t
,
uint32_t
>
(
const
uint32_t
&
a
,
const
float
scale
)
{
bf16_4_t
res
;
res
.
x
=
scaled_vec_conversion_int8
<
__nv_bfloat162
,
uint16_t
>
((
uint16_t
)
a
,
scale
);
res
.
y
=
scaled_vec_conversion_int8
<
__nv_bfloat162
,
uint16_t
>
(
(
uint16_t
)(
a
>>
16U
),
scale
);
return
res
;
}
// int8x8 -> bf16_8_t
template
<
>
__inline__
__device__
bf16_8_t
scaled_vec_conversion_int8
<
bf16_8_t
,
uint2
>
(
const
uint2
&
a
,
const
float
scale
)
{
// scaled_vec_conversion_int8<bf16_8_t, uint64_t>(const uint64_t& a, const float scale) {
// bf16_4_t tmp1, tmp2;
// tmp1 = scaled_vec_conversion_int8<bf16_4_t, uint32_t>(a.x, scale);
// tmp2 = scaled_vec_conversion_int8<bf16_4_t, uint32_t>(a.y, scale);
bf16_8_t
res
;
// res.x = tmp1.x;
// res.y = tmp1.y;
// res.z = tmp2.x;
// res.w = tmp2.y;
return
res
;
}
// int8 -> float
template
<
>
__inline__
__device__
float
scaled_vec_conversion_int8
<
float
,
uint8_t
>
(
const
uint8_t
&
a
,
const
float
scale
)
{
float
res
=
int8_to_float
(
a
,
scale
);
return
res
;
}
// int8x2 -> float2
template
<
>
__inline__
__device__
float2
scaled_vec_conversion_int8
<
float2
,
uint16_t
>
(
const
uint16_t
&
a
,
const
float
scale
)
{
// int8x2 -> half2
uint32_t
tmp
=
scaled_vec_conversion_int8
<
uint32_t
,
uint16_t
>
(
a
,
scale
);
// half2 -> float2
return
half2_to_float2
(
tmp
);
}
// int8x4 -> float4
template
<
>
__inline__
__device__
Float4_
scaled_vec_conversion_int8
<
Float4_
,
uint32_t
>
(
const
uint32_t
&
a
,
const
float
scale
)
{
Float4_
res
;
res
.
x
=
scaled_vec_conversion_int8
<
float2
,
uint16_t
>
((
uint16_t
)
a
,
scale
);
res
.
y
=
scaled_vec_conversion_int8
<
float2
,
uint16_t
>
((
uint16_t
)(
a
>>
16U
),
scale
);
return
res
;
}
// int8x8 -> float8
template
<
>
__inline__
__device__
Float8_
scaled_vec_conversion_int8
<
Float8_
,
uint64_t
>
(
const
uint64_t
&
a
,
const
float
scale
)
{
// scaled_vec_conversion_int8<Float8_, uint2>(const uint2& a, const float scale) {
// Float4_ tmp1, tmp2;
// tmp1 = scaled_vec_conversion_int8<Float4_, uint32_t>(a.x, scale);
// tmp2 = scaled_vec_conversion_int8<Float4_, uint32_t>(a.y, scale);
Float8_
res
;
// res.x = tmp1.x;
// res.y = tmp1.y;
// res.z = tmp2.x;
// res.w = tmp2.y;
return
res
;
}
// half -> int8
template
<
>
__inline__
__device__
uint8_t
scaled_vec_conversion_int8
<
uint8_t
,
uint16_t
>
(
const
uint16_t
&
a
,
const
float
scale
)
{
uint8_t
res
=
float_to_int8
(
half_to_float
(
a
),
scale
);
return
(
uint8_t
)
res
;
// return (uint8_t)(a);
}
// bf16 -> int8
template
<
>
__inline__
__device__
uint8_t
scaled_vec_conversion_int8
<
uint8_t
,
__nv_bfloat16
>
(
const
__nv_bfloat16
&
a
,
const
float
scale
)
{
uint8_t
res
=
float_to_int8
(
__bfloat162float
(
a
),
scale
);
return
(
uint8_t
)
res
;
}
// float -> int8
template
<
>
__inline__
__device__
uint8_t
scaled_vec_conversion_int8
<
uint8_t
,
float
>
(
const
float
&
a
,
const
float
scale
)
{
uint8_t
res
=
float_to_int8
(
a
,
scale
);
return
(
uint8_t
)
res
;
// return (uint8_t)(a);
}
// int8x4 -> float4
template
<
>
__inline__
__device__
float4
scaled_vec_conversion_int8
<
float4
,
uint32_t
>
(
const
uint32_t
&
a
,
const
float
scale
)
{
Float4_
tmp
=
scaled_vec_conversion_int8
<
Float4_
,
uint32_t
>
(
a
,
scale
);
float4
res
=
make_float4
(
tmp
.
x
.
x
,
tmp
.
x
.
y
,
tmp
.
y
.
x
,
tmp
.
y
.
y
);
return
res
;
}
}
// namespace int8
}
// namespace vllm
vllm/attention/ops/paged_attn.py
View file @
45273722
...
...
@@ -131,7 +131,8 @@ class PagedAttention:
# TODO(woosuk): Tune this heuristic.
# For context len > 8192, use V2 kernel to avoid shared memory shortage.
if
(
kv_cache_dtype
==
"int8"
):
use_tc
=
False
if
use_tc
and
head_size
==
128
:
if
envs
.
VLLM_USE_PA_PRINT_PARAM
:
print
(
"PA V1 SIZE:"
)
...
...
vllm/config.py
View file @
45273722
...
...
@@ -1288,7 +1288,7 @@ class ModelConfig:
BlockSize
=
Literal
[
1
,
8
,
16
,
32
,
64
,
128
]
CacheDType
=
Literal
[
"auto"
,
"fp8"
,
"fp8_e4m3"
,
"fp8_e5m2"
]
CacheDType
=
Literal
[
"auto"
,
"fp8"
,
"fp8_e4m3"
,
"fp8_e5m2"
,
"int8"
]
PrefixCachingHashAlgo
=
Literal
[
"builtin"
,
"sha256"
]
...
...
vllm/engine/arg_utils.py
View file @
45273722
...
...
@@ -1383,6 +1383,11 @@ class EngineArgs:
from
vllm.attention.utils.fa_utils
import
(
flash_attn_supports_fp8
)
supported
=
flash_attn_supports_fp8
()
int8_attention
=
self
.
kv_cache_dtype
.
startswith
(
"int8"
)
if
int8_attention
:
supported
=
True
if
not
supported
:
_raise_or_fallback
(
feature_name
=
"--kv-cache-dtype"
,
recommend_to_remove
=
False
)
...
...
vllm/utils.py
View file @
45273722
...
...
@@ -747,6 +747,8 @@ def get_kv_cache_torch_dtype(
torch_dtype
=
STR_DTYPE_TO_TORCH_DTYPE
[
cache_dtype
]
elif
cache_dtype
==
"fp8"
:
torch_dtype
=
torch
.
uint8
elif
cache_dtype
==
"int8"
:
torch_dtype
=
torch
.
uint8
else
:
raise
ValueError
(
f
"Invalid kv cache dtype:
{
cache_dtype
}
"
)
elif
isinstance
(
cache_dtype
,
torch
.
dtype
):
...
...
@@ -792,6 +794,8 @@ def create_kv_caches_with_random_flash(
key_value_cache
.
uniform_
(
-
scale
,
scale
)
elif
cache_dtype
==
'fp8'
:
_generate_random_fp8
(
key_value_cache
,
-
scale
,
scale
)
elif
cache_dtype
==
'int8'
:
_generate_random_int8
(
value_cache
)
else
:
raise
ValueError
(
f
"Does not support key cache of type
{
cache_dtype
}
"
)
...
...
@@ -833,6 +837,8 @@ def create_kv_caches_with_random(
key_cache
.
uniform_
(
-
scale
,
scale
)
elif
cache_dtype
==
'fp8'
:
_generate_random_fp8
(
key_cache
,
-
scale
,
scale
)
elif
cache_dtype
==
'int8'
:
_generate_random_int8
(
key_value_cache
)
else
:
raise
ValueError
(
f
"Does not support key cache of type
{
cache_dtype
}
"
)
...
...
@@ -848,6 +854,8 @@ def create_kv_caches_with_random(
value_cache
.
uniform_
(
-
scale
,
scale
)
elif
cache_dtype
==
'fp8'
:
_generate_random_fp8
(
value_cache
,
-
scale
,
scale
)
elif
cache_dtype
==
'int8'
:
_generate_random_int8
(
key_cache
)
else
:
raise
ValueError
(
f
"Does not support value cache of type
{
cache_dtype
}
"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment