Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
xdb4_94051
vllm
Commits
49e84bec
Commit
49e84bec
authored
Sep 12, 2024
by
nicodafagood
Browse files
add csrc of myq
parent
df94fba9
Changes
11
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
2908 additions
and
2 deletions
+2908
-2
csrc/ops.h
csrc/ops.h
+14
-0
csrc/pybind.cpp
csrc/pybind.cpp
+2
-0
csrc/quantization/myq/compat.cuh
csrc/quantization/myq/compat.cuh
+64
-0
csrc/quantization/myq/matrix_view.cuh
csrc/quantization/myq/matrix_view.cuh
+274
-0
csrc/quantization/myq/q_gemm.cu
csrc/quantization/myq/q_gemm.cu
+2077
-0
csrc/quantization/myq/qdq_2.cuh
csrc/quantization/myq/qdq_2.cuh
+87
-0
csrc/quantization/myq/qdq_3.cuh
csrc/quantization/myq/qdq_3.cuh
+141
-0
csrc/quantization/myq/qdq_4.cuh
csrc/quantization/myq/qdq_4.cuh
+147
-0
csrc/quantization/myq/qdq_8.cuh
csrc/quantization/myq/qdq_8.cuh
+40
-0
csrc/quantization/myq/qdq_util.cuh
csrc/quantization/myq/qdq_util.cuh
+60
-0
vllm/model_executor/layers/quantization/myq.py
vllm/model_executor/layers/quantization/myq.py
+2
-2
No files found.
csrc/ops.h
View file @
49e84bec
...
...
@@ -115,6 +115,20 @@ void gptq_shuffle(
torch
::
Tensor
q_perm
,
int
bit
);
torch
::
Tensor
myq_gemm
(
torch
::
Tensor
a
,
torch
::
Tensor
b_q_weight
,
torch
::
Tensor
b_myq_qzeros
,
torch
::
Tensor
b_myq_scales
,
torch
::
Tensor
b_g_idx
,
bool
use_exllama
,
int
bit
);
void
myq_shuffle
(
torch
::
Tensor
q_weight
,
torch
::
Tensor
q_perm
,
int
bit
);
void
moe_align_block_size
(
torch
::
Tensor
topk_ids
,
int
num_experts
,
...
...
csrc/pybind.cpp
View file @
49e84bec
...
...
@@ -61,6 +61,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
ops
.
def
(
"gptq_gemm"
,
&
gptq_gemm
,
"Quantized GEMM for GPTQ"
);
ops
.
def
(
"gptq_shuffle"
,
&
gptq_shuffle
,
"Post processing for GPTQ"
);
ops
.
def
(
"myq_gemm"
,
&
myq_gemm
,
"Quantized GEMM for myq"
);
ops
.
def
(
"myq_shuffle"
,
&
myq_shuffle
,
"Post processing for GPTQ"
);
ops
.
def
(
"squeezellm_gemm"
,
&
squeezellm_gemm
,
"Quantized GEMM for SqueezeLLM"
);
ops
.
def
(
"moe_align_block_size"
,
...
...
csrc/quantization/myq/compat.cuh
0 → 100644
View file @
49e84bec
/*
Copied from https://github.com/turboderp/exllamav2
*/
#ifndef _compat_cuh
#define _compat_cuh
namespace
vllm
{
namespace
myq
{
// atomicAdd for half types, to support CC < 7.x
__device__
__forceinline__
void
atomicAdd_half
(
half
*
address
,
half
val
)
{
unsigned
int
*
address_as_ui
=
(
unsigned
int
*
)
((
char
*
)
address
-
((
size_t
)
address
&
2
));
unsigned
int
old
=
*
address_as_ui
;
unsigned
int
assumed
;
do
{
assumed
=
old
;
__half_raw
hsum
;
hsum
.
x
=
(
size_t
)
address
&
2
?
(
old
>>
16
)
:
(
old
&
0xffff
);
half
tmpres
=
__hadd
(
hsum
,
val
);
hsum
=
__half_raw
(
tmpres
);
old
=
(
size_t
)
address
&
2
?
(
old
&
0xffff
)
|
(
hsum
.
x
<<
16
)
:
(
old
&
0xffff0000
)
|
hsum
.
x
;
old
=
atomicCAS
(
address_as_ui
,
assumed
,
old
);
}
while
(
assumed
!=
old
);
}
// atomicAdd for half2 types
__device__
__forceinline__
void
atomicAdd_half2
(
half2
*
address
,
half2
val
)
{
unsigned
int
*
address_as_ui
=
(
unsigned
int
*
)
address
;
unsigned
int
old
=
*
address_as_ui
;
unsigned
int
assumed
;
do
{
assumed
=
old
;
half2
old_val
=
*
((
half2
*
)
&
old
);
half2
new_val
=
__hadd2
(
old_val
,
val
);
old
=
atomicCAS
(
address_as_ui
,
assumed
,
*
((
unsigned
int
*
)
&
new_val
));
}
while
(
assumed
!=
old
);
}
//
#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)
__device__
__forceinline__
void
atomicAdd
(
half
*
address
,
half
val
)
{
atomicAdd_half
(
address
,
val
);
}
#if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
__device__
__forceinline__
void
atomicAdd
(
half2
*
address
,
half2
val
)
{
atomicAdd_half2
(
address
,
val
);
}
#endif
#endif
#endif
}
// namespace myq
}
// namespace vllm
#endif
csrc/quantization/myq/matrix_view.cuh
0 → 100644
View file @
49e84bec
/*
Adapted from https://github.com/turboderp/exllamav2 and https://github.com/turboderp/exllama
*/
#ifndef _matrix_view_cuh
#define _matrix_view_cuh
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include "qdq_util.cuh"
namespace
vllm
{
namespace
myq
{
class
MatrixView_half
{
public:
const
half
*
data
;
const
int
height
;
const
int
width
;
__device__
__forceinline__
MatrixView_half
(
const
half
*
data
,
const
int
height
,
const
int
width
)
:
data
(
data
),
height
(
height
),
width
(
width
)
{
}
__device__
__forceinline__
half
item
(
int
row
,
int
column
)
const
{
return
data
[
row
*
width
+
column
];
}
__device__
__forceinline__
half2
item_half2
(
int
row
,
int
column
)
const
{
return
((
half2
*
)
data
)[(
row
*
width
+
column
)
/
2
];
}
__device__
__forceinline__
half2
item_half2half2
(
int
row
,
int
column
)
const
{
return
__half2half2
(
data
[
row
*
width
+
column
]);
}
__device__
__forceinline__
const
half
*
item_ptr
(
int
row
,
int
column
)
const
{
return
&
data
[
row
*
width
+
column
];
}
__device__
__forceinline__
void
item4
(
half
(
&
items
)[
4
],
int
row
,
int
column
)
const
{
half2
*
ptr
=
(
half2
*
)
item_ptr
(
row
,
column
);
half2
i01
=
ptr
[
0
];
half2
i23
=
ptr
[
1
];
items
[
0
]
=
__low2half
(
i01
);
items
[
1
]
=
__high2half
(
i01
);
items
[
2
]
=
__low2half
(
i23
);
items
[
3
]
=
__high2half
(
i23
);
}
__device__
__forceinline__
void
item4_f
(
float
(
&
items
)[
4
],
int
row
,
int
column
)
const
{
half2
*
ptr
=
(
half2
*
)
item_ptr
(
row
,
column
);
half2
i01
=
ptr
[
0
];
half2
i23
=
ptr
[
1
];
items
[
0
]
=
__half2float
(
__low2half
(
i01
));
items
[
1
]
=
__half2float
(
__high2half
(
i01
));
items
[
2
]
=
__half2float
(
__low2half
(
i23
));
items
[
3
]
=
__half2float
(
__high2half
(
i23
));
}
__device__
__forceinline__
void
item4_h2
(
half2
(
&
items
)[
4
],
int
row
,
int
column
)
const
{
half2
*
ptr
=
(
half2
*
)
item_ptr
(
row
,
column
);
half2
i01
=
ptr
[
0
];
half2
i23
=
ptr
[
1
];
items
[
0
]
=
__half2half2
(
__low2half
(
i01
));
items
[
1
]
=
__half2half2
(
__high2half
(
i01
));
items
[
2
]
=
__half2half2
(
__low2half
(
i23
));
items
[
3
]
=
__half2half2
(
__high2half
(
i23
));
}
};
class
MatrixView_half_rw
{
public:
half
*
data
;
const
int
height
;
const
int
width
;
__device__
__forceinline__
MatrixView_half_rw
(
half
*
data
,
const
int
height
,
const
int
width
)
:
data
(
data
),
height
(
height
),
width
(
width
)
{
}
__device__
__forceinline__
half
item
(
int
row
,
int
column
)
const
{
return
data
[
row
*
width
+
column
];
}
__device__
__forceinline__
half2
item_half2
(
int
row
,
int
column
)
const
{
return
((
half2
*
)
data
)[(
row
*
width
+
column
)
/
2
];
}
__device__
__forceinline__
half
*
item_ptr
(
int
row
,
int
column
)
{
return
&
data
[
row
*
width
+
column
];
}
__device__
__forceinline__
void
set
(
int
row
,
int
column
,
half
value
)
{
data
[
row
*
width
+
column
]
=
value
;
}
__device__
__forceinline__
void
set_half2
(
int
row
,
int
column
,
half2
value
)
{
((
half2
*
)
data
)[(
row
*
width
+
column
)
/
2
]
=
value
;
}
__device__
__forceinline__
void
set4
(
int
row
,
int
column
,
half
v0
,
half
v1
,
half
v2
,
half
v3
)
{
half2
v01
=
__halves2half2
(
v0
,
v1
);
half2
v23
=
__halves2half2
(
v2
,
v3
);
half2
*
ptr
=
(
half2
*
)
item_ptr
(
row
,
column
);
ptr
[
0
]
=
v01
;
ptr
[
1
]
=
v23
;
}
};
class
MatrixView_q4_row
{
public:
const
uint32_t
*
data
;
const
int
height
;
const
int
width
;
__device__
__forceinline__
MatrixView_q4_row
(
const
uint32_t
*
data
,
const
int
height
,
const
int
width
)
:
data
(
data
),
height
(
height
),
width
(
width
)
{
}
__device__
__forceinline__
int
item
(
int
row
,
int
column
)
const
{
int
shift
=
(
column
&
0x07
)
*
4
;
return
(
data
[
row
*
width
/
8
+
column
/
8
]
>>
shift
)
&
0x0f
;
}
__device__
__forceinline__
void
item2
(
int
(
&
items
)[
2
],
int
row
,
int
column
)
const
{
int
shift
=
(
column
&
0x07
)
*
4
;
uint32_t
d
=
data
[
row
*
width
/
8
+
column
/
8
]
>>
shift
;
items
[
0
]
=
d
&
0x0f
;
items
[
1
]
=
(
d
>>
4
)
&
0x0f
;
}
__device__
__forceinline__
void
item4
(
int
(
&
items
)[
4
],
int
row
,
int
column
)
const
{
int
shift
=
(
column
&
0x07
)
*
4
;
uint32_t
d
=
data
[
row
*
width
/
8
+
column
/
8
]
>>
shift
;
items
[
0
]
=
d
&
0x0f
;
items
[
1
]
=
(
d
>>
4
)
&
0x0f
;
items
[
2
]
=
(
d
>>
8
)
&
0x0f
;
items
[
3
]
=
(
d
>>
12
)
&
0x0f
;
}
};
class
MatrixView_q4_column
{
public:
const
uint32_t
*
data
;
const
int
height
;
const
int
width
;
__device__
__forceinline__
MatrixView_q4_column
(
const
uint32_t
*
data
,
const
int
height
,
const
int
width
)
:
data
(
data
),
height
(
height
),
width
(
width
)
{
}
__device__
__forceinline__
int
item
(
int
row
,
int
column
)
const
{
int
shift
=
(
row
&
0x07
)
*
4
;
return
(
data
[
row
/
8
*
width
+
column
]
>>
shift
)
&
0x0f
;
}
__device__
__forceinline__
uint32_t
item_uint32_t
(
int
row
,
int
column
)
{
return
data
[
row
/
8
*
width
+
column
];
}
__device__
__forceinline__
const
uint32_t
*
item_uint32_ptr
(
int
row
,
int
column
)
{
return
&
data
[
row
/
8
*
width
+
column
];
}
};
class
MatrixView_q2_row
{
public:
const
uint32_t
*
data
;
const
int
height
;
const
int
width
;
__device__
__forceinline__
MatrixView_q2_row
(
const
uint32_t
*
data
,
const
int
height
,
const
int
width
)
:
data
(
data
),
height
(
height
),
width
(
width
)
{
}
__device__
__forceinline__
int
item
(
int
row
,
int
column
)
const
{
int
shift
=
(
column
&
0x0f
)
*
2
;
return
(
data
[
row
*
width
/
16
+
column
/
16
]
>>
shift
)
&
0x03
;
}
__device__
__forceinline__
void
item2
(
int
(
&
items
)[
2
],
int
row
,
int
column
)
const
{
int
shift
=
(
column
&
0x0f
)
*
2
;
uint32_t
d
=
data
[
row
*
width
/
16
+
column
/
16
]
>>
shift
;
items
[
0
]
=
d
&
0x03
;
items
[
1
]
=
(
d
>>
2
)
&
0x03
;
}
__device__
__forceinline__
void
item4
(
int
(
&
items
)[
4
],
int
row
,
int
column
)
const
{
int
shift
=
(
column
&
0x0f
)
*
2
;
uint32_t
d
=
data
[
row
*
width
/
16
+
column
/
16
]
>>
shift
;
items
[
0
]
=
d
&
0x03
;
items
[
1
]
=
(
d
>>
2
)
&
0x03
;
items
[
2
]
=
(
d
>>
4
)
&
0x03
;
items
[
3
]
=
(
d
>>
6
)
&
0x03
;
}
};
class
MatrixView_q3_row
{
public:
const
uint32_t
*
data
;
const
int
height
;
const
int
width
;
__device__
__forceinline__
MatrixView_q3_row
(
const
uint32_t
*
data
,
const
int
height
,
const
int
width
)
:
data
(
data
),
height
(
height
),
width
(
width
)
{
}
__device__
__forceinline__
int
item
(
int
row
,
int
column
)
const
{
int
z_w
=
column
*
3
/
32
;
int
z_mod
=
column
&
0x1f
;
if
(
z_mod
==
10
)
{
return
(
data
[
row
*
width
*
3
/
32
+
z_w
]
>>
30
)
|
((
data
[
row
*
width
*
3
/
32
+
(
z_w
+
1
)]
<<
2
)
&
0x4
);
}
else
if
(
z_mod
==
21
)
{
return
(
data
[
row
*
width
*
3
/
32
+
z_w
]
>>
31
)
|
((
data
[
row
*
width
*
3
/
32
+
(
z_w
+
1
)]
<<
1
)
&
0x6
);
}
else
if
(
z_mod
<
10
)
{
return
(
data
[
row
*
width
*
3
/
32
+
z_w
]
>>
(
z_mod
*
3
))
&
0x07
;
}
else
if
(
z_mod
<
21
)
{
return
(
data
[
row
*
width
*
3
/
32
+
z_w
]
>>
(
z_mod
*
3
-
32
))
&
0x07
;
}
else
{
return
(
data
[
row
*
width
*
3
/
32
+
z_w
]
>>
(
z_mod
*
3
-
64
))
&
0x07
;
}
}
__device__
__forceinline__
void
item4
(
int
(
&
items
)[
4
],
int
row
,
int
column
)
const
{
int
shift
=
(
column
&
0x1f
);
uint32_t
d
;
if
(
shift
<=
4
)
{
d
=
data
[
row
*
width
/
32
*
3
+
column
*
3
/
32
]
>>
(
shift
*
3
);
}
else
if
(
shift
==
8
)
{
d
=
(
data
[
row
*
width
/
32
*
3
+
column
*
3
/
32
]
>>
24
)
|
((
data
[
row
*
width
/
32
*
3
+
column
*
3
/
32
+
1
]
&
0x0f
)
<<
8
);
}
else
if
(
shift
<=
16
)
{
d
=
data
[
row
*
width
/
32
*
3
+
column
*
3
/
32
]
>>
(
shift
*
3
-
32
);
}
else
if
(
shift
==
20
)
{
d
=
(
data
[
row
*
width
/
32
*
3
+
column
*
3
/
32
]
>>
28
)
|
((
data
[
row
*
width
/
32
*
3
+
column
*
3
/
32
+
1
]
&
0xff
)
<<
4
);
}
else
{
d
=
data
[
row
*
width
/
32
*
3
+
column
*
3
/
32
]
>>
(
shift
*
3
-
64
);
}
items
[
0
]
=
d
&
0x07
;
items
[
1
]
=
(
d
>>
3
)
&
0x07
;
items
[
2
]
=
(
d
>>
6
)
&
0x07
;
items
[
3
]
=
(
d
>>
9
)
&
0x07
;
}
};
class
MatrixView_q8_row
{
public:
const
uint32_t
*
data
;
const
int
height
;
const
int
width
;
__device__
__forceinline__
MatrixView_q8_row
(
const
uint32_t
*
data
,
const
int
height
,
const
int
width
)
:
data
(
data
),
height
(
height
),
width
(
width
)
{
}
__device__
__forceinline__
int
item
(
int
row
,
int
column
)
const
{
int
shift
=
(
column
&
0x03
)
*
8
;
return
(
data
[
row
*
width
/
4
+
column
/
4
]
>>
shift
)
&
0xff
;
}
__device__
__forceinline__
void
item2
(
int
(
&
items
)[
2
],
int
row
,
int
column
)
const
{
int
shift
=
(
column
&
0x03
)
*
8
;
uint32_t
d
=
data
[
row
*
width
/
4
+
column
/
4
]
>>
shift
;
items
[
0
]
=
d
&
0xff
;
items
[
1
]
=
(
d
>>
8
)
&
0xff
;
}
__device__
__forceinline__
void
item4
(
int
(
&
items
)[
4
],
int
row
,
int
column
)
const
{
int
shift
=
(
column
&
0x03
)
*
2
;
uint32_t
d
=
data
[
row
*
width
/
4
+
column
/
4
]
>>
shift
;
items
[
0
]
=
d
&
0xff
;
items
[
1
]
=
(
d
>>
8
)
&
0xff
;
items
[
2
]
=
(
d
>>
16
)
&
0xff
;
items
[
3
]
=
(
d
>>
24
)
&
0xff
;
}
};
}
// namespace myq
}
// namespace vllm
#endif
csrc/quantization/myq/q_gemm.cu
0 → 100644
View file @
49e84bec
This diff is collapsed.
Click to expand it.
csrc/quantization/myq/qdq_2.cuh
0 → 100644
View file @
49e84bec
/*
Copied from https://github.com/turboderp/exllamav2
*/
#ifndef _qdq_2_cuh
#define _qdq_2_cuh
#include "qdq_util.cuh"
namespace
vllm
{
namespace
myq
{
// Permutation:
//
// ffddbb99 77553311 eeccaa88 66442200
__forceinline__
__device__
void
shuffle_2bit_16
(
uint32_t
*
q
,
int
stride
)
{
uint32_t
qa
=
q
[
0
];
uint32_t
qb
=
0
;
#pragma unroll
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
uint32_t
qa0
=
qa
&
0x03
;
uint32_t
qa1
=
(
qa
&
0x0c
)
>>
2
;
qa
>>=
4
;
qb
|=
(
qa1
<<
(
i
*
2
+
16
));
qb
|=
(
qa0
<<
(
i
*
2
));
}
q
[
0
]
=
qb
;
}
__forceinline__
__device__
void
dequant_2bit_16
(
const
uint32_t
q_0
,
half2
(
&
dq
)[
8
],
int
stride
,
const
uint32_t
zero
)
{
const
uint32_t
c0
=
0x64006400
;
const
half
y4_
=
__float2half_rn
(
1.0
f
/
4.0
f
);
const
half
y16_
=
__float2half_rn
(
1.0
f
/
16.0
f
);
const
half
y64_
=
__float2half_rn
(
1.0
f
/
64.0
f
);
const
half2
y4
=
__halves2half2
(
y4_
,
y4_
);
const
half2
y16
=
__halves2half2
(
y16_
,
y16_
);
const
half2
y64
=
__halves2half2
(
y64_
,
y64_
);
const
half_uint16
z1_
(
0xe400
|
zero
);
// half(-1024.0f - zero);
const
half
z4_
=
__hsub
(
__int2half_rn
(
-
256
),
__int2half_rn
(
zero
));
const
half
z16_
=
__hsub
(
__int2half_rn
(
-
64
),
__int2half_rn
(
zero
));
const
half
z64_
=
__hsub
(
__int2half_rn
(
-
16
),
__int2half_rn
(
zero
));
const
half2
z1
=
__half2half2
(
z1_
.
as_half
);
const
half2
z4
=
__half2half2
(
z4_
);
const
half2
z16
=
__half2half2
(
z16_
);
const
half2
z64
=
__half2half2
(
z64_
);
uint32_t
qa
=
q_0
;
half2_uint32
q0
((
qa
&
0x00030003
)
|
c0
);
// half2(q[ 0], q[ 1]) + 1024
half2_uint32
q1
((
qa
&
0x000c000c
)
|
c0
);
// half2(q[ 2], q[ 3]) * 4 + 1024
half2_uint32
q2
((
qa
&
0x00300030
)
|
c0
);
// half2(q[ 4], q[ 5]) * 16 + 1024
half2_uint32
q3
((
qa
&
0x00c000c0
)
|
c0
);
// half2(q[ 6], q[ 7]) * 64 + 1024
qa
>>=
8
;
half2_uint32
q4
((
qa
&
0x00030003
)
|
c0
);
// half2(q[ 8], q[ 8]) + 1024
half2_uint32
q5
((
qa
&
0x000c000c
)
|
c0
);
// half2(q[10], q[11]) * 4 + 1024
half2_uint32
q6
((
qa
&
0x00300030
)
|
c0
);
// half2(q[12], q[13]) * 16 + 1024
half2_uint32
q7
((
qa
&
0x00c000c0
)
|
c0
);
// half2(q[14], q[15]) * 64 + 1024
dq
[
0
]
=
__hadd2
(
q0
.
as_half2
,
z1
);
dq
[
1
]
=
__hfma2
(
q1
.
as_half2
,
y4
,
z4
);
dq
[
2
]
=
__hfma2
(
q2
.
as_half2
,
y16
,
z16
);
dq
[
3
]
=
__hfma2
(
q3
.
as_half2
,
y64
,
z64
);
dq
[
4
]
=
__hadd2
(
q4
.
as_half2
,
z1
);
dq
[
5
]
=
__hfma2
(
q5
.
as_half2
,
y4
,
z4
);
dq
[
6
]
=
__hfma2
(
q6
.
as_half2
,
y16
,
z16
);
dq
[
7
]
=
__hfma2
(
q7
.
as_half2
,
y64
,
z64
);
}
}
// namespace myq
}
// namespace vllm
#endif
csrc/quantization/myq/qdq_3.cuh
0 → 100644
View file @
49e84bec
#ifndef _qdq_3_cuh
#define _qdq_3_cuh
#include "qdq_util.cuh"
namespace
vllm
{
namespace
myq
{
// Permutation:
//
// v9997775 55333111 u8886664 44222000 (u, v lsb)
// vjjjhhhf ffdddbbb uiiiggge eecccaaa
// vtttrrrp ppnnnlll usssqqqo oommmkkk
__forceinline__
__device__
void
shuffle_3bit_32
(
uint32_t
*
q
,
int
stride
)
{
uint32_t
qa
=
q
[
0
*
stride
];
uint32_t
qb
=
q
[
1
*
stride
];
uint32_t
qc
=
q
[
2
*
stride
];
// qa: aa999888 77766655 54443332 22111000
// qb: lkkkjjji iihhhggg fffeeedd dcccbbba
// qc: vvvuuutt tsssrrrq qqpppooo nnnmmmll
uint32_t
qd
=
qc
>>
26
;
qc
<<=
4
;
qc
|=
qb
>>
28
;
qb
<<=
2
;
qb
|=
qa
>>
30
;
// qa: ..999888 77766655 54443332 22111000
// qb: ..jjjiii hhhgggff feeedddc ccbbbaaa
// qc: ..tttsss rrrqqqpp pooonnnm mmlllkkk
// qd: vvvuuu
uint32_t
za
=
0
;
uint32_t
zb
=
0
;
uint32_t
zc
=
0
;
for
(
int
i
=
0
;
i
<
5
;
i
++
)
{
uint32_t
t0
=
qa
&
0x07
;
uint32_t
t1
=
(
qa
&
0x38
)
>>
3
;
qa
>>=
6
;
za
|=
(
t0
<<
(
i
*
3
));
za
|=
(
t1
<<
(
i
*
3
+
16
));
}
for
(
int
i
=
0
;
i
<
5
;
i
++
)
{
uint32_t
t0
=
qb
&
0x07
;
uint32_t
t1
=
(
qb
&
0x38
)
>>
3
;
qb
>>=
6
;
zb
|=
(
t0
<<
(
i
*
3
));
zb
|=
(
t1
<<
(
i
*
3
+
16
));
}
for
(
int
i
=
0
;
i
<
5
;
i
++
)
{
uint32_t
t0
=
qc
&
0x07
;
uint32_t
t1
=
(
qc
&
0x38
)
>>
3
;
qc
>>=
6
;
zc
|=
(
t0
<<
(
i
*
3
));
zc
|=
(
t1
<<
(
i
*
3
+
16
));
}
// za: 9997775 55333111 8886664 44222000
// zb: jjjhhhf ffdddbbb iiiggge eecccaaa
// zc: tttrrrp ppnnnlll sssqqqo oommmkkk
// qd: vvvuuu
za
|=
((
qd
&
0x01
)
>>
0
)
<<
15
;
zb
|=
((
qd
&
0x02
)
>>
1
)
<<
15
;
zc
|=
((
qd
&
0x04
)
>>
2
)
<<
15
;
za
|=
((
qd
&
0x08
)
>>
3
)
<<
31
;
zb
|=
((
qd
&
0x10
)
>>
4
)
<<
31
;
zc
|=
((
qd
&
0x20
)
>>
5
)
<<
31
;
// za: v9997775 55333111 u8886664 44222000 (u, v lsb)
// zb: vjjjhhhf ffdddbbb uiiiggge eecccaaa
// zc: vtttrrrp ppnnnlll usssqqqo oommmkkk
q
[
0
*
stride
]
=
za
;
q
[
1
*
stride
]
=
zb
;
q
[
2
*
stride
]
=
zc
;
}
__forceinline__
__device__
void
dequant_3bit_32
(
const
uint32_t
q_0
,
const
uint32_t
q_1
,
const
uint32_t
q_2
,
half2
(
&
dq
)[
16
],
int
stride
,
const
uint32_t
zero
)
{
const
uint32_t
c0
=
0x64006400
;
const
half
y8_
=
__float2half_rn
(
1.0
f
/
8.0
f
);
const
half
y64_
=
__float2half_rn
(
1.0
f
/
64.0
f
);
const
half2
y8
=
__halves2half2
(
y8_
,
y8_
);
const
half2
y64
=
__halves2half2
(
y64_
,
y64_
);
const
half_uint16
z1_
(
0xe400
|
zero
);
// half(-1024.0f - zero);
const
half
z8_
=
__hsub
(
__int2half_rn
(
-
128
),
__int2half_rn
(
zero
));
const
half
z64_
=
__hsub
(
__int2half_rn
(
-
16
),
__int2half_rn
(
zero
));
const
half2
z1
=
__halves2half2
(
z1_
.
as_half
,
z1_
.
as_half
);
const
half2
z8
=
__halves2half2
(
z8_
,
z8_
);
const
half2
z64
=
__halves2half2
(
z64_
,
z64_
);
uint32_t
qa
=
q_0
;
uint32_t
qb
=
q_1
;
uint32_t
qc
=
q_2
;
half2_uint32
q0
((
qa
&
0x00070007
)
|
c0
);
// half2(q[ 0], q[ 1]) + 1024
half2_uint32
q1
((
qa
&
0x00380038
)
|
c0
);
// half2(q[ 2], q[ 3]) * 8 + 1024
qa
>>=
6
;
half2_uint32
q2
((
qa
&
0x00070007
)
|
c0
);
// half2(q[ 4], q[ 5]) + 1024
half2_uint32
q3
((
qa
&
0x00380038
)
|
c0
);
// half2(q[ 6], q[ 7]) * 8 + 1024
half2_uint32
q4
((
qa
&
0x01c001c0
)
|
c0
);
// half2(q[ 8], q[ 9]) * 64 + 1024
qa
>>=
9
;
qa
&=
0x00010001
;
half2_uint32
q5
((
qb
&
0x00070007
)
|
c0
);
// half2(q[10], q[11]) + 1024
half2_uint32
q6
((
qb
&
0x00380038
)
|
c0
);
// half2(q[12], q[13]) * 8 + 1024
qb
>>=
6
;
half2_uint32
q7
((
qb
&
0x00070007
)
|
c0
);
// half2(q[14], q[15]) + 1024
half2_uint32
q8
((
qb
&
0x00380038
)
|
c0
);
// half2(q[16], q[17]) * 8 + 1024
half2_uint32
q9
((
qb
&
0x01c001c0
)
|
c0
);
// half2(q[18], q[19]) * 64 + 1024
qb
>>=
8
;
qb
&=
0x00020002
;
half2_uint32
q10
((
qc
&
0x00070007
)
|
c0
);
// half2(q[20], q[21]) + 1024
half2_uint32
q11
((
qc
&
0x00380038
)
|
c0
);
// half2(q[22], q[23]) * 8 + 1024
qc
>>=
6
;
half2_uint32
q12
((
qc
&
0x00070007
)
|
c0
);
// half2(q[24], q[25]) + 1024
half2_uint32
q13
((
qc
&
0x00380038
)
|
c0
);
// half2(q[26], q[27]) * 8 + 1024
half2_uint32
q14
((
qc
&
0x01c001c0
)
|
c0
);
// half2(q[28], q[29]) * 64 + 1024
qc
>>=
7
;
qc
&=
0x00040004
;
half2_uint32
q15
((
qa
|
qb
|
qc
)
|
c0
);
dq
[
0
]
=
__hadd2
(
q0
.
as_half2
,
z1
);
dq
[
1
]
=
__hfma2
(
q1
.
as_half2
,
y8
,
z8
);
dq
[
2
]
=
__hadd2
(
q2
.
as_half2
,
z1
);
dq
[
3
]
=
__hfma2
(
q3
.
as_half2
,
y8
,
z8
);
dq
[
4
]
=
__hfma2
(
q4
.
as_half2
,
y64
,
z64
);
dq
[
5
]
=
__hadd2
(
q5
.
as_half2
,
z1
);
dq
[
6
]
=
__hfma2
(
q6
.
as_half2
,
y8
,
z8
);
dq
[
7
]
=
__hadd2
(
q7
.
as_half2
,
z1
);
dq
[
8
]
=
__hfma2
(
q8
.
as_half2
,
y8
,
z8
);
dq
[
9
]
=
__hfma2
(
q9
.
as_half2
,
y64
,
z64
);
dq
[
10
]
=
__hadd2
(
q10
.
as_half2
,
z1
);
dq
[
11
]
=
__hfma2
(
q11
.
as_half2
,
y8
,
z8
);
dq
[
12
]
=
__hadd2
(
q12
.
as_half2
,
z1
);
dq
[
13
]
=
__hfma2
(
q13
.
as_half2
,
y8
,
z8
);
dq
[
14
]
=
__hfma2
(
q14
.
as_half2
,
y64
,
z64
);
dq
[
15
]
=
__hadd2
(
q15
.
as_half2
,
z1
);
}
}
// namespace myq
}
// namespace vllm
#endif
csrc/quantization/myq/qdq_4.cuh
0 → 100644
View file @
49e84bec
/*
Copied from https://github.com/turboderp/exllamav2
*/
#ifndef _qdq_4_cuh
#define _qdq_4_cuh
#include "qdq_util.cuh"
namespace
vllm
{
namespace
myq
{
// Permutation:
//
// 77775555 33331111 66664444 22220000
__forceinline__
__device__
void
shuffle_4bit_8
(
uint32_t
*
q
,
int
stride
)
{
uint32_t
qa
=
q
[
0
];
uint32_t
qb
=
0
;
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
uint32_t
qa0
=
qa
&
0x0f
;
uint32_t
qa1
=
(
qa
&
0xf0
)
>>
4
;
qa
>>=
8
;
qb
|=
(
qa1
<<
(
i
*
4
+
16
));
qb
|=
(
qa0
<<
(
i
*
4
));
}
q
[
0
]
=
qb
;
}
__forceinline__
__device__
void
dequant_4bit_8
(
const
uint32_t
q_0
,
half2
(
&
dq
)[
4
],
int
stride
,
const
uint32_t
zero
)
{
const
uint32_t
c0
=
0x64006400
;
const
half
y16_
=
__float2half_rn
(
1.0
f
/
16.0
f
);
const
half2
y16
=
__halves2half2
(
y16_
,
y16_
);
const
half_uint16
z1_
(
0xe400
|
zero
);
// half(-1024.0f - zero);
const
half
z16_
=
__hsub
(
__int2half_rn
(
-
64
),
__int2half_rn
(
zero
));
const
half2
z1
=
__half2half2
(
z1_
.
as_half
);
const
half2
z16
=
__half2half2
(
z16_
);
uint32_t
qa
=
q_0
;
half2_uint32
q0
((
qa
&
0x000f000f
)
|
c0
);
// half2(q[ 0], q[ 1]) + 1024
half2_uint32
q1
((
qa
&
0x00f000f0
)
|
c0
);
// half2(q[ 2], q[ 3]) * 16 + 1024
qa
>>=
8
;
half2_uint32
q2
((
qa
&
0x000f000f
)
|
c0
);
// half2(q[ 4], q[ 5]) + 1024
half2_uint32
q3
((
qa
&
0x00f000f0
)
|
c0
);
// half2(q[ 6], q[ 7]) * 16 + 1024
dq
[
0
]
=
__hadd2
(
q0
.
as_half2
,
z1
);
dq
[
1
]
=
__hfma2
(
q1
.
as_half2
,
y16
,
z16
);
dq
[
2
]
=
__hadd2
(
q2
.
as_half2
,
z1
);
dq
[
3
]
=
__hfma2
(
q3
.
as_half2
,
y16
,
z16
);
}
__forceinline__
__device__
void
dequant_4bit_8_prep_zero_scale
(
const
uint32_t
zero
,
const
half
scale
,
half2
(
&
z1z16
)[
2
],
half2
(
&
y1y16
)[
2
]
)
{
half_uint16
z1
(
0xe400
|
zero
);
// half(-1024.0f - zero);
half
z16
=
__hsub
(
__int2half_rn
(
-
64
),
__int2half_rn
(
zero
));
half2
scale2
=
__half2half2
(
scale
);
z1z16
[
0
]
=
__hmul2
(
scale2
,
__half2half2
(
z1
.
as_half
));
z1z16
[
1
]
=
__hmul2
(
scale2
,
__half2half2
(
z16
));
const
half
y1
=
__float2half_rn
(
1.0
f
);
const
half
y16
=
__float2half_rn
(
1.0
f
/
16.0
f
);
y1y16
[
0
]
=
__hmul2
(
scale2
,
__half2half2
(
y1
));
y1y16
[
1
]
=
__hmul2
(
scale2
,
__half2half2
(
y16
));
}
__forceinline__
__device__
void
dequant_4bit_8_prep_zero
(
const
uint32_t
zero
,
half2
(
&
z1z16
)[
2
],
half2
(
&
y1y16
)[
2
]
)
{
half_uint16
z1
(
0xe400
|
zero
);
// half(-1024.0f - zero);
half
z16
=
__hsub
(
__int2half_rn
(
-
64
),
__int2half_rn
(
zero
));
z1z16
[
0
]
=
__half2half2
(
z1
.
as_half
);
z1z16
[
1
]
=
__half2half2
(
z16
);
const
half
y1
=
__float2half_rn
(
1.0
f
);
const
half
y16
=
__float2half_rn
(
1.0
f
/
16.0
f
);
y1y16
[
0
]
=
__half2half2
(
y1
);
y1y16
[
1
]
=
__half2half2
(
y16
);
}
__forceinline__
__device__
void
dequant_4bit_8_myq
(
const
uint32_t
q_0
,
half2
(
&
dq
)[
4
],
half2
(
&
z1z16
)[
2
],
half2
(
&
y1y16
)[
2
],
int
stride
,
bool
scaled
)
{
const
uint32_t
c0
=
0x64006400
;
uint32_t
qa
=
q_0
;
half2_uint32
q0
((
qa
&
0x000f000f
)
|
c0
);
// half2( q[0] + 1024, q[1] + 1024 )
half2_uint32
q1
((
qa
&
0x00f000f0
)
|
c0
);
// half2( q[2] * 16 + 1024, q[3] * 16 + 1024 )
qa
>>=
8
;
half2_uint32
q2
((
qa
&
0x000f000f
)
|
c0
);
// half2( q[4] + 1024, q[5] + 1024 )
half2_uint32
q3
((
qa
&
0x00f000f0
)
|
c0
);
// half2( q[6] * 16 + 1024, q[7] * 16 + 1024 )
if
(
scaled
)
{
dq
[
0
]
=
__hfma2
(
q0
.
as_half2
,
y1y16
[
0
],
z1z16
[
0
]);
// half2( q[0] * s - z * s, q[1] * s - z * s)
dq
[
1
]
=
__hfma2
(
q1
.
as_half2
,
y1y16
[
1
],
z1z16
[
1
]);
// half2( q[2] * s - z * s, q[3] * s - z * s)
dq
[
2
]
=
__hfma2
(
q2
.
as_half2
,
y1y16
[
0
],
z1z16
[
0
]);
dq
[
3
]
=
__hfma2
(
q3
.
as_half2
,
y1y16
[
1
],
z1z16
[
1
]);
}
else
{
dq
[
0
]
=
__hadd2
(
q0
.
as_half2
,
z1z16
[
0
]);
// half2( q[0] - z, q[1] - z )
dq
[
1
]
=
__hfma2
(
q1
.
as_half2
,
y1y16
[
1
],
z1z16
[
1
]);
// half2( q[2] - z, q[3] - z )
dq
[
2
]
=
__hadd2
(
q2
.
as_half2
,
z1z16
[
0
]);
// half2( q[4] - z, q[5] - z )
dq
[
3
]
=
__hfma2
(
q3
.
as_half2
,
y1y16
[
1
],
z1z16
[
1
]);
// half2( q[6] - z, q[7] - z )
}
}
}
// namespace myq
}
// namespace vllm
#endif
csrc/quantization/myq/qdq_8.cuh
0 → 100644
View file @
49e84bec
/*
Copied from https://github.com/turboderp/exllamav2
*/
#ifndef _qdq_8_cuh
#define _qdq_8_cuh
#include "qdq_util.cuh"
namespace
vllm
{
namespace
myq
{
__forceinline__
__device__
void
shuffle_8bit_4
(
uint32_t
*
q
,
int
stride
)
{
}
__forceinline__
__device__
void
dequant_8bit_8
(
const
uint32_t
q_0
,
const
uint32_t
q_1
,
half2
(
&
dq
)[
4
],
int
stride
,
const
uint32_t
zero
)
{
half
dqh
[
8
];
for
(
int
i
=
0
;
i
<
4
;
i
++
)
dqh
[
i
]
=
dq_ns
(
exb
(
q_0
,
i
*
8
,
0xff
),
zero
);
for
(
int
i
=
0
;
i
<
4
;
i
++
)
dqh
[
i
+
4
]
=
dq_ns
(
exb
(
q_1
,
i
*
8
,
0xff
),
zero
);
for
(
int
i
=
0
;
i
<
4
;
i
++
)
dq
[
i
]
=
__halves2half2
(
dqh
[
i
*
2
],
dqh
[
i
*
2
+
1
]);
}
}
// namespace myq
}
// namespace vllm
#endif
csrc/quantization/myq/qdq_util.cuh
0 → 100644
View file @
49e84bec
/*
Copied from https://github.com/turboderp/exllamav2
*/
#ifndef _qdq_util_cuh
#define _qdq_util_cuh
namespace
vllm
{
namespace
myq
{
union
half2_uint32
{
uint32_t
as_uint32
;
half2
as_half2
;
__device__
half2_uint32
(
uint32_t
val
)
:
as_uint32
(
val
)
{}
__device__
half2_uint32
(
half2
val
)
:
as_half2
(
val
)
{}
};
union
half_uint16
{
uint16_t
as_uint16
;
half
as_half
;
__device__
half_uint16
(
uint16_t
val
)
:
as_uint16
(
val
)
{}
__device__
half_uint16
(
half
val
)
:
as_half
(
val
)
{}
};
// Max_scale premultiplied by 1/256
__forceinline__
__device__
half
dq_scale
(
const
int
qs
,
const
half
max_scale
)
{
int
qs_i
=
qs
+
1
;
half
qs_h
=
__int2half_rn
(
qs_i
*
qs_i
);
qs_h
=
__hmul
(
qs_h
,
max_scale
);
return
qs_h
;
}
__forceinline__
__device__
half
dq
(
const
int
q
,
const
int
qzero
,
const
half
scale
)
{
return
__hmul
(
__int2half_rn
(
q
-
qzero
),
scale
);
}
__forceinline__
__device__
half
dq_ns
(
const
int
q
,
const
int
qzero
)
{
//return __hsub(__int2half_rn(q), __int2half_rn(qzero));
return
__int2half_rn
(
q
-
qzero
);
}
__forceinline__
__device__
int
exb
(
const
uint32_t
q
,
const
int
shift
,
const
int
mask
)
{
return
(
int
)((
q
>>
shift
)
&
mask
);
}
__forceinline__
__device__
int
exb
(
const
uint32_t
q1
,
const
uint32_t
q0
,
const
int
shift
,
const
int
mask
)
{
return
(
int
)(
__funnelshift_rc
(
q0
,
q1
,
shift
)
&
mask
);
}
}
// namespace myq
}
// namespace vllm
#endif
vllm/model_executor/layers/quantization/myq.py
View file @
49e84bec
...
...
@@ -201,9 +201,9 @@ class MYQLinearMethod(LinearMethodBase):
else
:
weights
[
"g_idx"
]
=
torch
.
empty
((
1
,
1
),
device
=
"meta"
)
weights
[
"exllama_state"
]
=
ExllamaState
.
READY
ops
.
gpt
q_shuffle
(
weights
[
"qweight"
],
weights
[
"g_idx"
],
ops
.
my
q_shuffle
(
weights
[
"qweight"
],
weights
[
"g_idx"
],
self
.
quant_config
.
weight_bits
)
output
=
ops
.
gpt
q_gemm
(
reshaped_x
,
weights
[
"qweight"
],
output
=
ops
.
my
q_gemm
(
reshaped_x
,
weights
[
"qweight"
],
weights
[
"qzeros"
],
weights
[
"scales"
],
weights
[
"g_idx"
],
weights
[
"exllama_state"
]
==
ExllamaState
.
READY
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment