Commit 858087a6 authored by xiabo's avatar xiabo
Browse files

kv cache int8的实现

parent 14ad512a
......@@ -333,9 +333,9 @@ class TurboMind:
cfg.w4_weight_layout=2
cfg.w4_pad_size=0
else:
output_format = update_output_format(cfg.model_name,
inferred_model_format,
model_path, output_format)
# output_format = update_output_format(cfg.model_name,
# inferred_model_format,
# model_path, output_format)
data_type = output_format
update_config_weight_type(output_format, cfg)
......
......@@ -180,6 +180,9 @@ inline __device__ void Store(T* dst, const Array<T, N>& src)
else if constexpr (sizeof(Array<T, N>) == sizeof(uint1)) {
*(uint1*)dst = (const uint1&)src;
}
else if constexpr (sizeof(Array<T, N>) == sizeof(unsigned short)) {
*(unsigned short*)dst = (const unsigned short&)src;
}
else {
printf("=====array_ops.h 184\n");
// static_assert(!std::is_same_v<T, T>);
......@@ -380,7 +383,6 @@ struct ConvertKvCache<Ti, int8_t> {
inline __device__ uint8_t round(float x) const
{
uint32_t y;
printf("======arrat_ops.h 380\n");
// asm("cvt.rni.sat.u8.f32 %0, %1;\n" : "=r"(y) : "f"(x));
if (x >= 255) {
y = 255;
......@@ -421,15 +423,18 @@ inline __device__ Array<float, 4> fast_i2f_f32_s8(const Array<int8_t, 4>& x)
// (1 + x / 2^15) * 2^(e - 127) -> e - 127 == 15 -> e = 142
// 7 6 5 4
static constexpr uint32_t f32_magic = 0x47000000; // 2^15 = 32768
static constexpr uint32_t m0 = 0x7604;
static constexpr uint32_t m1 = 0x7614;
static constexpr uint32_t m2 = 0x7624;
static constexpr uint32_t m3 = 0x7634;
printf("======arrat_ops.h 417\n");
// asm("prmt.b32 %0,%1,%2,%3;\n" : "=r"(u32x4[0]) : "r"(i8s), "n"(f32_magic), "n"(m0));
// asm("prmt.b32 %0,%1,%2,%3;\n" : "=r"(u32x4[1]) : "r"(i8s), "n"(f32_magic), "n"(m1));
// asm("prmt.b32 %0,%1,%2,%3;\n" : "=r"(u32x4[2]) : "r"(i8s), "n"(f32_magic), "n"(m2));
// asm("prmt.b32 %0,%1,%2,%3;\n" : "=r"(u32x4[3]) : "r"(i8s), "n"(f32_magic), "n"(m3));
uint8_t elt_0 = i8s & 0xFF;
uint8_t elt_1 = (i8s & 0xFF00) >> 8;
uint8_t elt_2 = (i8s & 0xFF0000) >> 16;
uint8_t elt_3 = (i8s & 0xFF000000) >> 24;
uint8_t elt_4 = (f32_magic & 0xFF);
uint8_t elt_6 = (f32_magic & 0xFF0000) >> 16;
uint8_t elt_7 = (f32_magic & 0xFF000000) >> 24;
u32x4[0] = (elt_7 << 24) | (elt_6 << 16) | (elt_0 << 8) | (elt_4);
u32x4[1] = (elt_7 << 24) | (elt_6 << 16) | (elt_1 << 8) | (elt_4);
u32x4[2] = (elt_7 << 24) | (elt_6 << 16) | (elt_2 << 8) | (elt_4);
u32x4[3] = (elt_7 << 24) | (elt_6 << 16) | (elt_3 << 8) | (elt_4);
if (0) { // fused with dequantization
PRAGMA_UNROLL
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment