Commit 12494cf5 authored by xuxzh1's avatar xuxzh1 🎱
Browse files

adapt v3.0.0

parent 8f326c97
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#ifndef _qdq_5_cuh
#define _qdq_5_cuh
#include "../../hip/quant/qdq_util.cuh"
#include "../../config.h"
#if QMODE_5BIT == 1
// Permutation:
//
// v5555533 33311111 u4444422 22200000 (u, v lsb)
// vbbbbb99 99977777 uaaaaa88 88866666
// vhhhhhff fffddddd ugggggee eeeccccc
// vnnnnnll llljjjjj ummmmmkk kkkiiiii
// vtttttrr rrrppppp usssssqq qqqooooo
__forceinline__ __device__ void shuffle_5bit_32
(
uint32_t* q,
int stride
)
{
uint32_t qa = q[0 * stride];
uint32_t qb = q[1 * stride];
uint32_t qc = q[2 * stride];
uint32_t qd = q[3 * stride];
uint32_t qe = q[4 * stride];
// qa: 66555554 44443333 32222211 11100000
// qb: ccccbbbb baaaaa99 99988888 77777666
// qc: jiiiiihh hhhggggg fffffeee eedddddc
// qd: pppooooo nnnnnmmm mmlllllk kkkkjjjj
// qe: vvvvvuuu uuttttts ssssrrrr rqqqqqpp
uint32_t qf = qe >> 22;
qe <<= 8;
qe |= qd >> 24;
qd <<= 6;
qd |= qc >> 26;
qc <<= 4;
qc |= qb >> 28;
qb <<= 2;
qb |= qa >> 30;
// qa: 555554 44443333 32222211 11100000
// qb: bbbbba aaaa9999 98888877 77766666
// qc: hhhhhg ggggffff feeeeedd dddccccc
// qd: nnnnnm mmmmllll lkkkkkjj jjjiiiii
// qe: ttttts ssssrrrr rqqqqqpp pppooooo
// qf: vv vvvuuuuu
uint32_t za = 0;
uint32_t zb = 0;
uint32_t zc = 0;
uint32_t zd = 0;
uint32_t ze = 0;
for (int i = 0; i < 3; i++) { uint32_t t0 = qa & 0x1f; uint32_t t1 = (qa & 0x3e0) >> 5; qa >>= 10; za |= (t0 << (i * 5)); za |= (t1 << (i * 5 + 16)); }
for (int i = 0; i < 3; i++) { uint32_t t0 = qb & 0x1f; uint32_t t1 = (qb & 0x3e0) >> 5; qb >>= 10; zb |= (t0 << (i * 5)); zb |= (t1 << (i * 5 + 16)); }
for (int i = 0; i < 3; i++) { uint32_t t0 = qc & 0x1f; uint32_t t1 = (qc & 0x3e0) >> 5; qc >>= 10; zc |= (t0 << (i * 5)); zc |= (t1 << (i * 5 + 16)); }
for (int i = 0; i < 3; i++) { uint32_t t0 = qd & 0x1f; uint32_t t1 = (qd & 0x3e0) >> 5; qd >>= 10; zd |= (t0 << (i * 5)); zd |= (t1 << (i * 5 + 16)); }
for (int i = 0; i < 3; i++) { uint32_t t0 = qe & 0x1f; uint32_t t1 = (qe & 0x3e0) >> 5; qe >>= 10; ze |= (t0 << (i * 5)); ze |= (t1 << (i * 5 + 16)); }
// za: 5555533 33311111 4444422 22200000
// zb: bbbbb99 99977777 aaaaa88 88866666
// zc: hhhhhff fffddddd gggggee eeeccccc
// zd: nnnnnll llljjjjj mmmmmkk kkkiiiii
// ze: tttttrr rrrppppp sssssqq qqqooooo
// qf: vv vvvuuuuu
za |= ((qf & 0x001) >> 0) << 15;
zb |= ((qf & 0x002) >> 1) << 15;
zc |= ((qf & 0x004) >> 2) << 15;
zd |= ((qf & 0x008) >> 3) << 15;
ze |= ((qf & 0x010) >> 4) << 15;
za |= ((qf & 0x020) >> 5) << 31;
zb |= ((qf & 0x040) >> 6) << 31;
zc |= ((qf & 0x080) >> 7) << 31;
zd |= ((qf & 0x100) >> 8) << 31;
ze |= ((qf & 0x200) >> 9) << 31;
// za: v5555533 33311111 u4444422 22200000 (u, v lsb)
// zb: vbbbbb99 99977777 uaaaaa88 88866666
// zc: vhhhhhff fffddddd ugggggee eeeccccc
// zd: vnnnnnll llljjjjj ummmmmkk kkkiiiii
// ze: vtttttrr rrrppppp usssssqq qqqooooo
q[0 * stride] = za;
q[1 * stride] = zb;
q[2 * stride] = zc;
q[3 * stride] = zd;
q[4 * stride] = ze;
}
__forceinline__ __device__ void dequant_5bit_32
(
const uint32_t q_0,
const uint32_t q_1,
const uint32_t q_2,
const uint32_t q_3,
const uint32_t q_4,
half2 (&dq)[16],
int stride
)
{
const uint32_t c0 = 0x64006400;
const half y32_ = __float2half_rn(1.0f / 32.0f);
const half2 y32 = __halves2half2(y32_, y32_);
const half z1_ = __float2half_rn(-1024.0f - 16.0f);
const half z32_ = __float2half_rn(-1024.0f / 32.0f - 16.0f);
const half2 z1 = __halves2half2(z1_, z1_);
const half2 z32 = __halves2half2(z32_, z32_);
uint32_t qa = q_0;
uint32_t qb = q_1;
uint32_t qc = q_2;
uint32_t qd = q_3;
uint32_t qe = q_4;
half2_uint32 q0 ((qa & 0x001f001f) | c0); // half2(q[ 0], q[ 1]) + 1024
half2_uint32 q1 ((qa & 0x03e003e0) | c0); // half2(q[ 2], q[ 3]) * 32 + 1024
qa >>= 10;
half2_uint32 q2 ((qa & 0x001f001f) | c0); // half2(q[ 4], q[ 5]) + 1024
qa >>= 5;
qa &= 0x00010001;
half2_uint32 q3 ((qb & 0x001f001f) | c0); // half2(q[ 6], q[ 7]) + 1024
half2_uint32 q4 ((qb & 0x03e003e0) | c0); // half2(q[ 8], q[ 9]) * 32 + 1024
qb >>= 10;
half2_uint32 q5 ((qb & 0x001f001f) | c0); // half2(q[10], q[11]) + 1024
qb >>= 4;
qb &= 0x00020002;
half2_uint32 q6 ((qc & 0x001f001f) | c0); // half2(q[12], q[13]) + 1024
half2_uint32 q7 ((qc & 0x03e003e0) | c0); // half2(q[14], q[15]) * 32 + 1024
qc >>= 10;
half2_uint32 q8 ((qc & 0x001f001f) | c0); // half2(q[16], q[17]) + 1024
qc >>= 3;
qc &= 0x00040004;
half2_uint32 q9 ((qd & 0x001f001f) | c0); // half2(q[18], q[19]) + 1024
half2_uint32 q10((qd & 0x03e003e0) | c0); // half2(q[20], q[21]) * 32 + 1024
qd >>= 10;
half2_uint32 q11((qd & 0x001f001f) | c0); // half2(q[22], q[23]) + 1024
qd >>= 2;
qd &= 0x00080008;
half2_uint32 q12((qe & 0x001f001f) | c0); // half2(q[24], q[25]) + 1024
half2_uint32 q13((qe & 0x03e003e0) | c0); // half2(q[26], q[27]) * 32 + 1024
qe >>= 10;
half2_uint32 q14((qe & 0x001f001f) | c0); // half2(q[28], q[29]) + 1024
qe >>= 1;
qe &= 0x00100010;
half2_uint32 q15((qa | qb | qc | qd | qe) | c0);
dq[ 0] = __hadd2( q0.as_half2, z1);
dq[ 1] = __hfma2( q1.as_half2, y32, z32);
dq[ 2] = __hadd2( q2.as_half2, z1);
dq[ 3] = __hadd2( q3.as_half2, z1);
dq[ 4] = __hfma2( q4.as_half2, y32, z32);
dq[ 5] = __hadd2( q5.as_half2, z1);
dq[ 6] = __hadd2( q6.as_half2, z1);
dq[ 7] = __hfma2( q7.as_half2, y32, z32);
dq[ 8] = __hadd2( q8.as_half2, z1);
dq[ 9] = __hadd2( q9.as_half2, z1);
dq[10] = __hfma2(q10.as_half2, y32, z32);
dq[11] = __hadd2(q11.as_half2, z1);
dq[12] = __hadd2(q12.as_half2, z1);
dq[13] = __hfma2(q13.as_half2, y32, z32);
dq[14] = __hadd2(q14.as_half2, z1);
dq[15] = __hadd2(q15.as_half2, z1);
}
#else
__forceinline__ __device__ void shuffle_5bit_32
(
uint32_t* q,
int stride
)
{
}
__forceinline__ __device__ void dequant_5bit_32
(
const uint32_t q_0,
const uint32_t q_1,
const uint32_t q_2,
const uint32_t q_3,
const uint32_t q_4,
half2 (&dq)[16],
int stride
)
{
half dqh[32];
for (int i = 0; i < 6; i++) dqh[ i] = dq_ns(exb( q_0, i * 5 , 0x1f), 16);
dqh[ 6 ] = dq_ns(exb(q_1, q_0, 30, 0x1f), 16);
for (int i = 0; i < 5; i++) dqh[ 7 + i] = dq_ns(exb( q_1, i * 5 + 3, 0x1f), 16);
dqh[12 ] = dq_ns(exb(q_2, q_1, 28, 0x1f), 16);
for (int i = 0; i < 6; i++) dqh[13 + i] = dq_ns(exb( q_2, i * 5 + 1, 0x1f), 16);
dqh[19 ] = dq_ns(exb(q_3, q_2, 31, 0x1f), 16);
for (int i = 0; i < 5; i++) dqh[20 + i] = dq_ns(exb( q_3, i * 5 + 4, 0x1f), 16);
dqh[25 ] = dq_ns(exb(q_4, q_3, 29, 0x1f), 16);
for (int i = 0; i < 6; i++) dqh[26 + i] = dq_ns(exb( q_4, i * 5 + 2, 0x1f), 16);
for (int i = 0; i < 16; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}
#endif
#endif
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#ifndef _qdq_6_cuh
#define _qdq_6_cuh
#include "../../hip/quant/qdq_util.cuh"
#include "../../config.h"
#if QMODE_6BIT == 1
// Not implemented
#else
__forceinline__ __device__ void shuffle_6bit_16
(
uint32_t* q,
int stride
)
{
}
__forceinline__ __device__ void dequant_6bit_16
(
const uint32_t q_0,
const uint32_t q_1,
const uint32_t q_2,
half2 (&dq)[8],
int stride
)
{
half dqh[16];
for (int i = 0; i < 5; i++) dqh[ i] = dq_ns(exb( q_0, i * 6 , 0x3f), 32);
dqh[ 5 ] = dq_ns(exb(q_1, q_0, 30, 0x3f), 32);
for (int i = 0; i < 4; i++) dqh[ 6 + i] = dq_ns(exb( q_1, i * 6 + 4, 0x3f), 32);
dqh[10 ] = dq_ns(exb(q_2, q_1, 28, 0x3f), 32);
for (int i = 0; i < 5; i++) dqh[11 + i] = dq_ns(exb( q_2, i * 6 + 2, 0x3f), 32);
for (int i = 0; i < 8; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}
#endif
#endif
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#ifndef _qdq_8_cuh
#define _qdq_8_cuh
#include "../../hip/quant/qdq_util.cuh"
#include "../../config.h"
#if QMODE_8BIT == 1
// Not implemented
#else
__forceinline__ __device__ void shuffle_8bit_4
(
uint32_t* q,
int stride
)
{
}
__forceinline__ __device__ void dequant_8bit_8
(
const uint32_t q_0,
const uint32_t q_1,
half2 (&dq)[4],
int stride
)
{
half dqh[8];
for (int i = 0; i < 4; i++) dqh[i ] = dq_ns(exb(q_0, i * 8, 0xff), 128);
for (int i = 0; i < 4; i++) dqh[i + 4] = dq_ns(exb(q_1, i * 8, 0xff), 128);
for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}
#endif
#endif
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#ifndef _qdq_util_cuh
#define _qdq_util_cuh
union half2_uint32
{
uint32_t as_uint32;
half2 as_half2;
__device__ half2_uint32(uint32_t val) : as_uint32(val) {}
__device__ half2_uint32(half2 val) : as_half2(val) {}
__device__ half2_uint32() : as_uint32(0) {}
};
union half_uint16
{
uint16_t as_uint16;
half as_half;
__device__ half_uint16(uint16_t val) : as_uint16(val) {}
__device__ half_uint16(half val) : as_half(val) {}
__device__ half_uint16() : as_uint16(0) {}
};
// Max_scale premultiplied by 1/256
__forceinline__ __device__ half dq_scale(const int qs, const half max_scale)
{
int qs_i = qs + 1;
half qs_h = __int2half_rn(qs_i * qs_i);
qs_h = __hmul(qs_h, max_scale);
return qs_h;
}
__forceinline__ __device__ half dq(const int q, const int qzero, const half scale)
{
return __hmul(__int2half_rn(q - qzero), scale);
}
__forceinline__ __device__ half dq_ns(const int q, const int qzero)
{
//return __hsub(__int2half_rn(q), __int2half_rn(qzero));
return __int2half_rn(q - qzero);
}
__forceinline__ __device__ int exb(const uint32_t q, const int shift, const int mask)
{
return (int)((q >> shift) & mask);
}
__forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0, const int shift, const int mask)
{
return (int)(__funnelshift_rc(q0, q1, shift) & mask);
}
#endif
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#ifndef _util_cuh
#define _util_cuh
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <cstdint>
#include <cstdio>
#include <ATen/hip/HIPContext.h>
#define DIVIDE(x, size) (((x) + (size) - 1) / (size))
#define DBGS(__x) printf("%s\n", __x)
#define DBGI(__x) printf("%s: %i\n", #__x, __x)
#define DBGI2(__x, __y) printf("%s, %s: %i, %i\n", #__x, #__y, __x, __y)
#define DBGI3(__x, __y, __z) printf("%s, %s, %s: %i, %i, %i\n", #__x, #__y, #__z, __x, __y, __z)
#define DBGX(__x) printf("%s: %x\n", #__x, __x)
#define DBGX2(__x, __y) printf("%s, %s: %x, %x\n", #__x, #__y, __x, __y)
#define DBGX3(__x, __y, __z) printf("%s, %s, %s: %x, %x, %x\n", #__x, #__y, #__z, __x, __y, __z)
#define DBGF(__x) printf("%s: %f\n", #__x, __x)
#define DBGF2(__x, __y) printf("%s, %s: %f, %f\n", #__x, #__y, __x, __y)
#define DBGF3(__x, __y, __z) printf("%s, %s, %s: %f, %f, %f\n", #__x, #__y, #__z, __x, __y, __z)
#define DBGH(__x) printf("%s: %f\n", #__x, __half2float(__x))
#define DBGH2(__x, __y) printf("%s, %s: %f, %f\n", #__x, #__y, __half2float(__x), __half2float(__y))
#define DBGH3(__x, __y, __z) printf("%s, %s, %s: %f, %f, %f\n", #__x, #__y, #__z, __half2float(__x), __half2float(__y), __half2float(__z))
#define DBGIH(__x, __y) printf("%s, %s: %i, %f\n", #__x, #__y, __x, __half2float(__y))
#define DBGIH2(__x, __y, __z) printf("%s, %s, %s: %i, %f, %f\n", #__x, #__y, #__z, __x, __half2float(__y), __half2float(__z))
__forceinline__ __device__ half dq_scale_(const int qs, const half max_scale)
{
half qs_h = __hmul(__int2half_rn(qs + 1), __float2half_rn(1.0f / 16.0f));
qs_h = __hmul(qs_h, qs_h);
qs_h = __hmul(qs_h, max_scale);
return qs_h;
}
__forceinline__ __device__ float clamp(float x, float a, float b)
{
return fmaxf(a, fminf(b, x));
}
#define cuda_check(ans) { gpu_assert((ans), __FILE__, __LINE__); }
inline void gpu_assert(hipError_t code, const char *file, int line, bool abort=true)
{
if (code != hipSuccess)
{
fprintf(stderr,"CUDA error: %s %s %d\n", hipGetErrorString(code), file, line);
if (abort) exit(code);
}
}
void print_global_mem(const half* ptr, int rows, int columns, int stride);
#endif
...@@ -72,7 +72,8 @@ if SYSTEM == "cuda": ...@@ -72,7 +72,8 @@ if SYSTEM == "cuda":
return normed_hidden_states, residual return normed_hidden_states, residual
elif SYSTEM == "rocm": elif SYSTEM == "rocm":
from vllm._C import ops #from vllm._C import ops
from vllm import _custom_ops
class FastLayerNorm(nn.LayerNorm): class FastLayerNorm(nn.LayerNorm):
def forward(self, hidden_states, residual=None): def forward(self, hidden_states, residual=None):
......
...@@ -3,19 +3,19 @@ from text_generation_server.utils.import_utils import SYSTEM ...@@ -3,19 +3,19 @@ from text_generation_server.utils.import_utils import SYSTEM
from torch.nn import functional as F from torch.nn import functional as F
import os import os
if SYSTEM == "rocm": # if SYSTEM == "rocm":
ROCM_USE_SKINNY_GEMM = os.getenv("ROCM_USE_SKINNY_GEMM", "True").lower() in ( # ROCM_USE_SKINNY_GEMM = os.getenv("ROCM_USE_SKINNY_GEMM", "True").lower() in (
"true", # "true",
"1", # "1",
) # )
if ROCM_USE_SKINNY_GEMM: # if ROCM_USE_SKINNY_GEMM:
try: # try:
from vllm import _custom_C # from vllm import _custom_C
except Exception as e: # except Exception as e:
raise ImportError( # raise ImportError(
f"Could not load `vllm._custom_C` for ROCm skinny gemm. Full error: {e}" # f"Could not load `vllm._custom_C` for ROCm skinny gemm. Full error: {e}"
) # )
class FastLinear(torch.nn.Module): class FastLinear(torch.nn.Module):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment