Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
norm
vllm
Commits
01a5d18a
"src/libtorchaudio/rnnt/gpu/compute.cu" did not exist on "0f603eb9b385f6355a294423841b38d540a46668"
Unverified
Commit
01a5d18a
authored
Feb 29, 2024
by
CHU Tianxiang
Committed by
GitHub
Feb 28, 2024
Browse files
Add Support for 2/3/8-bit GPTQ Quantization Models (#2330)
parent
929b4f29
Changes
8
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
1736 additions
and
229 deletions
+1736
-229
csrc/ops.h
csrc/ops.h
+4
-2
csrc/quantization/gptq/matrix_view.cuh
csrc/quantization/gptq/matrix_view.cuh
+123
-0
csrc/quantization/gptq/q_gemm.cu
csrc/quantization/gptq/q_gemm.cu
+1326
-126
csrc/quantization/gptq/qdq_2.cuh
csrc/quantization/gptq/qdq_2.cuh
+87
-0
csrc/quantization/gptq/qdq_3.cuh
csrc/quantization/gptq/qdq_3.cuh
+141
-0
csrc/quantization/gptq/qdq_4.cuh
csrc/quantization/gptq/qdq_4.cuh
+6
-94
csrc/quantization/gptq/qdq_8.cuh
csrc/quantization/gptq/qdq_8.cuh
+40
-0
vllm/model_executor/layers/quantization/gptq.py
vllm/model_executor/layers/quantization/gptq.py
+9
-7
No files found.
csrc/ops.h
View file @
01a5d18a
...
...
@@ -98,11 +98,13 @@ torch::Tensor gptq_gemm(
torch
::
Tensor
b_gptq_qzeros
,
torch
::
Tensor
b_gptq_scales
,
torch
::
Tensor
b_g_idx
,
bool
use_exllama
);
bool
use_exllama
,
int
bit
);
void
gptq_shuffle
(
torch
::
Tensor
q_weight
,
torch
::
Tensor
q_perm
);
torch
::
Tensor
q_perm
,
int
bit
);
void
moe_align_block_size
(
torch
::
Tensor
topk_ids
,
...
...
csrc/quantization/gptq/matrix_view.cuh
View file @
01a5d18a
...
...
@@ -146,6 +146,129 @@ public:
__device__
__forceinline__
const
uint32_t
*
item_uint32_ptr
(
int
row
,
int
column
)
{
return
&
data
[
row
/
8
*
width
+
column
];
}
};
class
MatrixView_q2_row
{
public:
const
uint32_t
*
data
;
const
int
height
;
const
int
width
;
__device__
__forceinline__
MatrixView_q2_row
(
const
uint32_t
*
data
,
const
int
height
,
const
int
width
)
:
data
(
data
),
height
(
height
),
width
(
width
)
{
}
__device__
__forceinline__
int
item
(
int
row
,
int
column
)
const
{
int
shift
=
(
column
&
0x0f
)
*
2
;
return
(
data
[
row
*
width
/
16
+
column
/
16
]
>>
shift
)
&
0x03
;
}
__device__
__forceinline__
void
item2
(
int
(
&
items
)[
2
],
int
row
,
int
column
)
const
{
int
shift
=
(
column
&
0x0f
)
*
2
;
uint32_t
d
=
data
[
row
*
width
/
16
+
column
/
16
]
>>
shift
;
items
[
0
]
=
d
&
0x03
;
items
[
1
]
=
(
d
>>
2
)
&
0x03
;
}
__device__
__forceinline__
void
item4
(
int
(
&
items
)[
4
],
int
row
,
int
column
)
const
{
int
shift
=
(
column
&
0x0f
)
*
2
;
uint32_t
d
=
data
[
row
*
width
/
16
+
column
/
16
]
>>
shift
;
items
[
0
]
=
d
&
0x03
;
items
[
1
]
=
(
d
>>
2
)
&
0x03
;
items
[
2
]
=
(
d
>>
4
)
&
0x03
;
items
[
3
]
=
(
d
>>
6
)
&
0x03
;
}
};
class
MatrixView_q3_row
{
public:
const
uint32_t
*
data
;
const
int
height
;
const
int
width
;
__device__
__forceinline__
MatrixView_q3_row
(
const
uint32_t
*
data
,
const
int
height
,
const
int
width
)
:
data
(
data
),
height
(
height
),
width
(
width
)
{
}
__device__
__forceinline__
int
item
(
int
row
,
int
column
)
const
{
int
z_w
=
column
*
3
/
32
;
int
z_mod
=
column
&
0x1f
;
if
(
z_mod
==
10
)
{
return
(
data
[
row
*
width
*
3
/
32
+
z_w
]
>>
30
)
|
((
data
[
row
*
width
*
3
/
32
+
(
z_w
+
1
)]
<<
2
)
&
0x4
);
}
else
if
(
z_mod
==
21
)
{
return
(
data
[
row
*
width
*
3
/
32
+
z_w
]
>>
31
)
|
((
data
[
row
*
width
*
3
/
32
+
(
z_w
+
1
)]
<<
1
)
&
0x6
);
}
else
if
(
z_mod
<
10
)
{
return
(
data
[
row
*
width
*
3
/
32
+
z_w
]
>>
(
z_mod
*
3
))
&
0x07
;
}
else
if
(
z_mod
<
21
)
{
return
(
data
[
row
*
width
*
3
/
32
+
z_w
]
>>
(
z_mod
*
3
-
32
))
&
0x07
;
}
else
{
return
(
data
[
row
*
width
*
3
/
32
+
z_w
]
>>
(
z_mod
*
3
-
64
))
&
0x07
;
}
}
__device__
__forceinline__
void
item4
(
int
(
&
items
)[
4
],
int
row
,
int
column
)
const
{
int
shift
=
(
column
&
0x1f
);
uint32_t
d
;
if
(
shift
<=
4
)
{
d
=
data
[
row
*
width
/
32
*
3
+
column
*
3
/
32
]
>>
(
shift
*
3
);
}
else
if
(
shift
==
8
)
{
d
=
(
data
[
row
*
width
/
32
*
3
+
column
*
3
/
32
]
>>
24
)
|
((
data
[
row
*
width
/
32
*
3
+
column
*
3
/
32
+
1
]
&
0x0f
)
<<
8
);
}
else
if
(
shift
<=
16
)
{
d
=
data
[
row
*
width
/
32
*
3
+
column
*
3
/
32
]
>>
(
shift
*
3
-
32
);
}
else
if
(
shift
==
20
)
{
d
=
(
data
[
row
*
width
/
32
*
3
+
column
*
3
/
32
]
>>
28
)
|
((
data
[
row
*
width
/
32
*
3
+
column
*
3
/
32
+
1
]
&
0xff
)
<<
4
);
}
else
{
d
=
data
[
row
*
width
/
32
*
3
+
column
*
3
/
32
]
>>
(
shift
*
3
-
64
);
}
items
[
0
]
=
d
&
0x07
;
items
[
1
]
=
(
d
>>
3
)
&
0x07
;
items
[
2
]
=
(
d
>>
6
)
&
0x07
;
items
[
3
]
=
(
d
>>
9
)
&
0x07
;
}
};
class
MatrixView_q8_row
{
public:
const
uint32_t
*
data
;
const
int
height
;
const
int
width
;
__device__
__forceinline__
MatrixView_q8_row
(
const
uint32_t
*
data
,
const
int
height
,
const
int
width
)
:
data
(
data
),
height
(
height
),
width
(
width
)
{
}
__device__
__forceinline__
int
item
(
int
row
,
int
column
)
const
{
int
shift
=
(
column
&
0x03
)
*
8
;
return
(
data
[
row
*
width
/
4
+
column
/
4
]
>>
shift
)
&
0xff
;
}
__device__
__forceinline__
void
item2
(
int
(
&
items
)[
2
],
int
row
,
int
column
)
const
{
int
shift
=
(
column
&
0x03
)
*
8
;
uint32_t
d
=
data
[
row
*
width
/
4
+
column
/
4
]
>>
shift
;
items
[
0
]
=
d
&
0xff
;
items
[
1
]
=
(
d
>>
8
)
&
0xff
;
}
__device__
__forceinline__
void
item4
(
int
(
&
items
)[
4
],
int
row
,
int
column
)
const
{
int
shift
=
(
column
&
0x03
)
*
2
;
uint32_t
d
=
data
[
row
*
width
/
4
+
column
/
4
]
>>
shift
;
items
[
0
]
=
d
&
0xff
;
items
[
1
]
=
(
d
>>
8
)
&
0xff
;
items
[
2
]
=
(
d
>>
16
)
&
0xff
;
items
[
3
]
=
(
d
>>
24
)
&
0xff
;
}
};
}
// namespace gptq
}
// namespace vllm
#endif
csrc/quantization/gptq/q_gemm.cu
View file @
01a5d18a
This diff is collapsed.
Click to expand it.
csrc/quantization/gptq/qdq_2.cuh
0 → 100644
View file @
01a5d18a
/*
Copied from https://github.com/turboderp/exllamav2
*/
#ifndef _qdq_2_cuh
#define _qdq_2_cuh
#include "qdq_util.cuh"
namespace
vllm
{
namespace
gptq
{
// Permutation:
//
// ffddbb99 77553311 eeccaa88 66442200
__forceinline__
__device__
void
shuffle_2bit_16
(
uint32_t
*
q
,
int
stride
)
{
uint32_t
qa
=
q
[
0
];
uint32_t
qb
=
0
;
#pragma unroll
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
uint32_t
qa0
=
qa
&
0x03
;
uint32_t
qa1
=
(
qa
&
0x0c
)
>>
2
;
qa
>>=
4
;
qb
|=
(
qa1
<<
(
i
*
2
+
16
));
qb
|=
(
qa0
<<
(
i
*
2
));
}
q
[
0
]
=
qb
;
}
__forceinline__
__device__
void
dequant_2bit_16
(
const
uint32_t
q_0
,
half2
(
&
dq
)[
8
],
int
stride
,
const
uint32_t
zero
)
{
const
uint32_t
c0
=
0x64006400
;
const
half
y4_
=
__float2half_rn
(
1.0
f
/
4.0
f
);
const
half
y16_
=
__float2half_rn
(
1.0
f
/
16.0
f
);
const
half
y64_
=
__float2half_rn
(
1.0
f
/
64.0
f
);
const
half2
y4
=
__halves2half2
(
y4_
,
y4_
);
const
half2
y16
=
__halves2half2
(
y16_
,
y16_
);
const
half2
y64
=
__halves2half2
(
y64_
,
y64_
);
const
half_uint16
z1_
(
0xe400
|
zero
);
// half(-1024.0f - zero);
const
half
z4_
=
__hsub
(
__int2half_rn
(
-
256
),
__int2half_rn
(
zero
));
const
half
z16_
=
__hsub
(
__int2half_rn
(
-
64
),
__int2half_rn
(
zero
));
const
half
z64_
=
__hsub
(
__int2half_rn
(
-
16
),
__int2half_rn
(
zero
));
const
half2
z1
=
__half2half2
(
z1_
.
as_half
);
const
half2
z4
=
__half2half2
(
z4_
);
const
half2
z16
=
__half2half2
(
z16_
);
const
half2
z64
=
__half2half2
(
z64_
);
uint32_t
qa
=
q_0
;
half2_uint32
q0
((
qa
&
0x00030003
)
|
c0
);
// half2(q[ 0], q[ 1]) + 1024
half2_uint32
q1
((
qa
&
0x000c000c
)
|
c0
);
// half2(q[ 2], q[ 3]) * 4 + 1024
half2_uint32
q2
((
qa
&
0x00300030
)
|
c0
);
// half2(q[ 4], q[ 5]) * 16 + 1024
half2_uint32
q3
((
qa
&
0x00c000c0
)
|
c0
);
// half2(q[ 6], q[ 7]) * 64 + 1024
qa
>>=
8
;
half2_uint32
q4
((
qa
&
0x00030003
)
|
c0
);
// half2(q[ 8], q[ 8]) + 1024
half2_uint32
q5
((
qa
&
0x000c000c
)
|
c0
);
// half2(q[10], q[11]) * 4 + 1024
half2_uint32
q6
((
qa
&
0x00300030
)
|
c0
);
// half2(q[12], q[13]) * 16 + 1024
half2_uint32
q7
((
qa
&
0x00c000c0
)
|
c0
);
// half2(q[14], q[15]) * 64 + 1024
dq
[
0
]
=
__hadd2
(
q0
.
as_half2
,
z1
);
dq
[
1
]
=
__hfma2
(
q1
.
as_half2
,
y4
,
z4
);
dq
[
2
]
=
__hfma2
(
q2
.
as_half2
,
y16
,
z16
);
dq
[
3
]
=
__hfma2
(
q3
.
as_half2
,
y64
,
z64
);
dq
[
4
]
=
__hadd2
(
q4
.
as_half2
,
z1
);
dq
[
5
]
=
__hfma2
(
q5
.
as_half2
,
y4
,
z4
);
dq
[
6
]
=
__hfma2
(
q6
.
as_half2
,
y16
,
z16
);
dq
[
7
]
=
__hfma2
(
q7
.
as_half2
,
y64
,
z64
);
}
}
// namespace gptq
}
// namespace vllm
#endif
csrc/quantization/gptq/qdq_3.cuh
0 → 100644
View file @
01a5d18a
#ifndef _qdq_3_cuh
#define _qdq_3_cuh
#include "qdq_util.cuh"
namespace
vllm
{
namespace
gptq
{
// Permutation:
//
// v9997775 55333111 u8886664 44222000 (u, v lsb)
// vjjjhhhf ffdddbbb uiiiggge eecccaaa
// vtttrrrp ppnnnlll usssqqqo oommmkkk
__forceinline__
__device__
void
shuffle_3bit_32
(
uint32_t
*
q
,
int
stride
)
{
uint32_t
qa
=
q
[
0
*
stride
];
uint32_t
qb
=
q
[
1
*
stride
];
uint32_t
qc
=
q
[
2
*
stride
];
// qa: aa999888 77766655 54443332 22111000
// qb: lkkkjjji iihhhggg fffeeedd dcccbbba
// qc: vvvuuutt tsssrrrq qqpppooo nnnmmmll
uint32_t
qd
=
qc
>>
26
;
qc
<<=
4
;
qc
|=
qb
>>
28
;
qb
<<=
2
;
qb
|=
qa
>>
30
;
// qa: ..999888 77766655 54443332 22111000
// qb: ..jjjiii hhhgggff feeedddc ccbbbaaa
// qc: ..tttsss rrrqqqpp pooonnnm mmlllkkk
// qd: vvvuuu
uint32_t
za
=
0
;
uint32_t
zb
=
0
;
uint32_t
zc
=
0
;
for
(
int
i
=
0
;
i
<
5
;
i
++
)
{
uint32_t
t0
=
qa
&
0x07
;
uint32_t
t1
=
(
qa
&
0x38
)
>>
3
;
qa
>>=
6
;
za
|=
(
t0
<<
(
i
*
3
));
za
|=
(
t1
<<
(
i
*
3
+
16
));
}
for
(
int
i
=
0
;
i
<
5
;
i
++
)
{
uint32_t
t0
=
qb
&
0x07
;
uint32_t
t1
=
(
qb
&
0x38
)
>>
3
;
qb
>>=
6
;
zb
|=
(
t0
<<
(
i
*
3
));
zb
|=
(
t1
<<
(
i
*
3
+
16
));
}
for
(
int
i
=
0
;
i
<
5
;
i
++
)
{
uint32_t
t0
=
qc
&
0x07
;
uint32_t
t1
=
(
qc
&
0x38
)
>>
3
;
qc
>>=
6
;
zc
|=
(
t0
<<
(
i
*
3
));
zc
|=
(
t1
<<
(
i
*
3
+
16
));
}
// za: 9997775 55333111 8886664 44222000
// zb: jjjhhhf ffdddbbb iiiggge eecccaaa
// zc: tttrrrp ppnnnlll sssqqqo oommmkkk
// qd: vvvuuu
za
|=
((
qd
&
0x01
)
>>
0
)
<<
15
;
zb
|=
((
qd
&
0x02
)
>>
1
)
<<
15
;
zc
|=
((
qd
&
0x04
)
>>
2
)
<<
15
;
za
|=
((
qd
&
0x08
)
>>
3
)
<<
31
;
zb
|=
((
qd
&
0x10
)
>>
4
)
<<
31
;
zc
|=
((
qd
&
0x20
)
>>
5
)
<<
31
;
// za: v9997775 55333111 u8886664 44222000 (u, v lsb)
// zb: vjjjhhhf ffdddbbb uiiiggge eecccaaa
// zc: vtttrrrp ppnnnlll usssqqqo oommmkkk
q
[
0
*
stride
]
=
za
;
q
[
1
*
stride
]
=
zb
;
q
[
2
*
stride
]
=
zc
;
}
__forceinline__
__device__
void
dequant_3bit_32
(
const
uint32_t
q_0
,
const
uint32_t
q_1
,
const
uint32_t
q_2
,
half2
(
&
dq
)[
16
],
int
stride
,
const
uint32_t
zero
)
{
const
uint32_t
c0
=
0x64006400
;
const
half
y8_
=
__float2half_rn
(
1.0
f
/
8.0
f
);
const
half
y64_
=
__float2half_rn
(
1.0
f
/
64.0
f
);
const
half2
y8
=
__halves2half2
(
y8_
,
y8_
);
const
half2
y64
=
__halves2half2
(
y64_
,
y64_
);
const
half_uint16
z1_
(
0xe400
|
zero
);
// half(-1024.0f - zero);
const
half
z8_
=
__hsub
(
__int2half_rn
(
-
128
),
__int2half_rn
(
zero
));
const
half
z64_
=
__hsub
(
__int2half_rn
(
-
16
),
__int2half_rn
(
zero
));
const
half2
z1
=
__halves2half2
(
z1_
.
as_half
,
z1_
.
as_half
);
const
half2
z8
=
__halves2half2
(
z8_
,
z8_
);
const
half2
z64
=
__halves2half2
(
z64_
,
z64_
);
uint32_t
qa
=
q_0
;
uint32_t
qb
=
q_1
;
uint32_t
qc
=
q_2
;
half2_uint32
q0
((
qa
&
0x00070007
)
|
c0
);
// half2(q[ 0], q[ 1]) + 1024
half2_uint32
q1
((
qa
&
0x00380038
)
|
c0
);
// half2(q[ 2], q[ 3]) * 8 + 1024
qa
>>=
6
;
half2_uint32
q2
((
qa
&
0x00070007
)
|
c0
);
// half2(q[ 4], q[ 5]) + 1024
half2_uint32
q3
((
qa
&
0x00380038
)
|
c0
);
// half2(q[ 6], q[ 7]) * 8 + 1024
half2_uint32
q4
((
qa
&
0x01c001c0
)
|
c0
);
// half2(q[ 8], q[ 9]) * 64 + 1024
qa
>>=
9
;
qa
&=
0x00010001
;
half2_uint32
q5
((
qb
&
0x00070007
)
|
c0
);
// half2(q[10], q[11]) + 1024
half2_uint32
q6
((
qb
&
0x00380038
)
|
c0
);
// half2(q[12], q[13]) * 8 + 1024
qb
>>=
6
;
half2_uint32
q7
((
qb
&
0x00070007
)
|
c0
);
// half2(q[14], q[15]) + 1024
half2_uint32
q8
((
qb
&
0x00380038
)
|
c0
);
// half2(q[16], q[17]) * 8 + 1024
half2_uint32
q9
((
qb
&
0x01c001c0
)
|
c0
);
// half2(q[18], q[19]) * 64 + 1024
qb
>>=
8
;
qb
&=
0x00020002
;
half2_uint32
q10
((
qc
&
0x00070007
)
|
c0
);
// half2(q[20], q[21]) + 1024
half2_uint32
q11
((
qc
&
0x00380038
)
|
c0
);
// half2(q[22], q[23]) * 8 + 1024
qc
>>=
6
;
half2_uint32
q12
((
qc
&
0x00070007
)
|
c0
);
// half2(q[24], q[25]) + 1024
half2_uint32
q13
((
qc
&
0x00380038
)
|
c0
);
// half2(q[26], q[27]) * 8 + 1024
half2_uint32
q14
((
qc
&
0x01c001c0
)
|
c0
);
// half2(q[28], q[29]) * 64 + 1024
qc
>>=
7
;
qc
&=
0x00040004
;
half2_uint32
q15
((
qa
|
qb
|
qc
)
|
c0
);
dq
[
0
]
=
__hadd2
(
q0
.
as_half2
,
z1
);
dq
[
1
]
=
__hfma2
(
q1
.
as_half2
,
y8
,
z8
);
dq
[
2
]
=
__hadd2
(
q2
.
as_half2
,
z1
);
dq
[
3
]
=
__hfma2
(
q3
.
as_half2
,
y8
,
z8
);
dq
[
4
]
=
__hfma2
(
q4
.
as_half2
,
y64
,
z64
);
dq
[
5
]
=
__hadd2
(
q5
.
as_half2
,
z1
);
dq
[
6
]
=
__hfma2
(
q6
.
as_half2
,
y8
,
z8
);
dq
[
7
]
=
__hadd2
(
q7
.
as_half2
,
z1
);
dq
[
8
]
=
__hfma2
(
q8
.
as_half2
,
y8
,
z8
);
dq
[
9
]
=
__hfma2
(
q9
.
as_half2
,
y64
,
z64
);
dq
[
10
]
=
__hadd2
(
q10
.
as_half2
,
z1
);
dq
[
11
]
=
__hfma2
(
q11
.
as_half2
,
y8
,
z8
);
dq
[
12
]
=
__hadd2
(
q12
.
as_half2
,
z1
);
dq
[
13
]
=
__hfma2
(
q13
.
as_half2
,
y8
,
z8
);
dq
[
14
]
=
__hfma2
(
q14
.
as_half2
,
y64
,
z64
);
dq
[
15
]
=
__hadd2
(
q15
.
as_half2
,
z1
);
}
}
// namespace gptq
}
// namespace vllm
#endif
csrc/quantization/gptq/qdq_4.cuh
View file @
01a5d18a
...
...
@@ -38,16 +38,17 @@ __forceinline__ __device__ void dequant_4bit_8
(
const
uint32_t
q_0
,
half2
(
&
dq
)[
4
],
int
stride
int
stride
,
const
uint32_t
zero
)
{
const
uint32_t
c0
=
0x64006400
;
const
half
y16_
=
__float2half_rn
(
1.0
f
/
16.0
f
);
const
half2
y16
=
__halves2half2
(
y16_
,
y16_
);
const
half
z1_
=
__float2
half
_rn
(
-
1024.0
f
-
8.0
f
);
const
half
z16_
=
__
floa
t2half_rn
(
-
1024.0
f
/
16.0
f
-
8.0
f
);
const
half2
z1
=
__hal
ves
2half2
(
z1_
,
z1_
);
const
half2
z16
=
__hal
ves
2half2
(
z16_
,
z16_
);
const
half
_uint16
z1_
(
0xe400
|
zero
);
//
half(-1024.0f
- zero
);
const
half
z16_
=
__
hsub
(
__in
t2half_rn
(
-
64
),
__int2half_rn
(
zero
)
);
const
half2
z1
=
__hal
f
2half2
(
z1_
.
as_half
);
const
half2
z16
=
__hal
f
2half2
(
z16_
);
uint32_t
qa
=
q_0
;
half2_uint32
q0
((
qa
&
0x000f000f
)
|
c0
);
// half2(q[ 0], q[ 1]) + 1024
...
...
@@ -143,93 +144,4 @@ __forceinline__ __device__ void dequant_4bit_8_gptq
}
// namespace gptq
}
// namespace vllm
#else
namespace
vllm
{
namespace
gptq
{
__forceinline__
__device__
void
shuffle_4bit_8
(
uint32_t
*
q
,
int
stride
)
{
}
__forceinline__
__device__
void
dequant_4bit_8
(
const
uint32_t
q_0
,
half2
(
&
dq
)[
4
],
int
stride
)
{
half
dqh
[
8
];
for
(
int
i
=
0
;
i
<
8
;
i
++
)
dqh
[
i
]
=
dq_ns
(
exb
(
q_0
,
i
*
4
,
0x0f
),
8
);
for
(
int
i
=
0
;
i
<
4
;
i
++
)
dq
[
i
]
=
__halves2half2
(
dqh
[
i
*
2
],
dqh
[
i
*
2
+
1
]);
}
__forceinline__
__device__
void
dequant_4bit_8_prep_zero_scale
(
const
uint32_t
zero
,
const
half
scale
,
half2
(
&
z1
)[
2
],
half2
(
&
y1
)[
2
]
)
{
half
z
=
__int2half_rn
(
-
((
int
)
zero
));
z
=
__hmul
(
z
,
scale
);
z1
[
0
]
=
__half2half2
(
z
);
y1
[
0
]
=
__half2half2
(
scale
);
}
__forceinline__
__device__
void
dequant_4bit_8_prep_zero
(
const
uint32_t
zero
,
half2
(
&
z1
)[
2
],
half2
(
&
y1
)[
2
]
)
{
half
z
=
__int2half_rn
(
-
((
int
)
zero
));
z1
[
0
]
=
__half2half2
(
z
);
}
__forceinline__
__device__
void
dequant_4bit_8_gptq
(
const
uint32_t
q_0
,
half2
(
&
dq
)[
4
],
half2
(
&
z1
)[
2
],
half2
(
&
y1
)[
2
],
int
stride
,
bool
scaled
)
{
half2
dqh2
[
8
];
uint32_t
qa
=
q_0
;
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
half
d0
=
__int2half_rn
(
qa
&
0x0f
);
qa
>>=
4
;
half
d1
=
__int2half_rn
(
qa
&
0x0f
);
qa
>>=
4
;
dqh2
[
i
]
=
__halves2half2
(
d0
,
d1
);
}
if
(
scaled
)
{
dq
[
0
]
=
__hfma2
(
dqh2
[
0
],
y1
[
0
],
z1
[
0
]);
dq
[
1
]
=
__hfma2
(
dqh2
[
1
],
y1
[
0
],
z1
[
0
]);
dq
[
2
]
=
__hfma2
(
dqh2
[
2
],
y1
[
0
],
z1
[
0
]);
dq
[
3
]
=
__hfma2
(
dqh2
[
3
],
y1
[
0
],
z1
[
0
]);
}
else
{
dq
[
0
]
=
__hadd2
(
dqh2
[
0
],
z1
[
0
]);
dq
[
1
]
=
__hadd2
(
dqh2
[
1
],
z1
[
0
]);
dq
[
2
]
=
__hadd2
(
dqh2
[
2
],
z1
[
0
]);
dq
[
3
]
=
__hadd2
(
dqh2
[
3
],
z1
[
0
]);
}
}
}
// namespace gptq
}
// namespace vllm
#endif
csrc/quantization/gptq/qdq_8.cuh
0 → 100644
View file @
01a5d18a
/*
Copied from https://github.com/turboderp/exllamav2
*/
#ifndef _qdq_8_cuh
#define _qdq_8_cuh
#include "qdq_util.cuh"
namespace
vllm
{
namespace
gptq
{
__forceinline__
__device__
void
shuffle_8bit_4
(
uint32_t
*
q
,
int
stride
)
{
}
__forceinline__
__device__
void
dequant_8bit_8
(
const
uint32_t
q_0
,
const
uint32_t
q_1
,
half2
(
&
dq
)[
4
],
int
stride
,
const
uint32_t
zero
)
{
half
dqh
[
8
];
for
(
int
i
=
0
;
i
<
4
;
i
++
)
dqh
[
i
]
=
dq_ns
(
exb
(
q_0
,
i
*
8
,
0xff
),
zero
);
for
(
int
i
=
0
;
i
<
4
;
i
++
)
dqh
[
i
+
4
]
=
dq_ns
(
exb
(
q_1
,
i
*
8
,
0xff
),
zero
);
for
(
int
i
=
0
;
i
<
4
;
i
++
)
dq
[
i
]
=
__halves2half2
(
dqh
[
i
*
2
],
dqh
[
i
*
2
+
1
]);
}
}
// namespace gptq
}
// namespace vllm
#endif
vllm/model_executor/layers/quantization/gptq.py
View file @
01a5d18a
import
enum
from
enum
import
Enum
from
typing
import
Any
,
Dict
,
List
,
Optional
from
fractions
import
Fraction
import
torch
from
torch.nn.parameter
import
Parameter
...
...
@@ -27,11 +28,10 @@ class GPTQConfig(QuantizationConfig):
self
.
weight_bits
=
weight_bits
self
.
group_size
=
group_size
self
.
desc_act
=
desc_act
self
.
pack_factor
=
32
//
self
.
weight_bits
# exllama kernel v1 only supports 4 bit
if
self
.
weight_bits
!=
4
:
self
.
pack_factor
=
Fraction
(
32
,
self
.
weight_bits
)
if
self
.
weight_bits
not
in
[
2
,
3
,
4
,
8
]:
raise
ValueError
(
"Currently, only
4
-bit weight quantization is supported for "
"Currently, only
2/3/4/8
-bit weight quantization is supported for "
f
"GPTQ, but got
{
self
.
weight_bits
}
bits."
)
def
__repr__
(
self
)
->
str
:
...
...
@@ -101,7 +101,7 @@ class GPTQLinearMethod(LinearMethodBase):
"The input size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size."
)
if
output_size_per_partition
%
self
.
quant_config
.
pack_factor
!=
0
:
if
output_size_per_partition
%
self
.
quant_config
.
pack_factor
.
numerator
!=
0
:
raise
ValueError
(
"The output size is not aligned with the quantized "
"weight shape. This can be caused by too large "
...
...
@@ -201,11 +201,13 @@ class GPTQLinearMethod(LinearMethodBase):
else
:
weights
[
"g_idx"
]
=
torch
.
empty
((
1
,
1
),
device
=
"meta"
)
weights
[
"exllama_state"
]
=
ExllamaState
.
READY
ops
.
gptq_shuffle
(
weights
[
"qweight"
],
weights
[
"g_idx"
])
ops
.
gptq_shuffle
(
weights
[
"qweight"
],
weights
[
"g_idx"
],
self
.
quant_config
.
weight_bits
)
output
=
ops
.
gptq_gemm
(
reshaped_x
,
weights
[
"qweight"
],
weights
[
"qzeros"
],
weights
[
"scales"
],
weights
[
"g_idx"
],
weights
[
"exllama_state"
]
==
ExllamaState
.
READY
)
weights
[
"exllama_state"
]
==
ExllamaState
.
READY
,
self
.
quant_config
.
weight_bits
)
if
bias
is
not
None
:
output
=
output
+
bias
return
output
.
reshape
(
out_shape
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment