norm / vllm / Commits / 0fbfc4b8

Unverified commit 0fbfc4b8, authored Dec 15, 2023 by CHU Tianxiang, committed by GitHub on Dec 15, 2023.

Add GPTQ support (#916)

Parent: c06170cc
Changes: 35 files
Showing 20 changed files with 1683 additions and 51 deletions (+1683, −51):

    benchmarks/benchmark_latency.py (+1, −1)
    benchmarks/benchmark_throughput.py (+1, −1)
    csrc/ops.h (+12, −0)
    csrc/pybind.cpp (+2, −2)
    csrc/quantization/gptq/compat.cuh (+64, −0)
    csrc/quantization/gptq/matrix_view.cuh (+151, −0)
    csrc/quantization/gptq/q_gemm.cu (+859, −0)
    csrc/quantization/gptq/qdq_4.cuh (+235, −0)
    csrc/quantization/gptq/qdq_util.cuh (+60, −0)
    setup.py (+1, −0)
    vllm/config.py (+1, −1)
    vllm/engine/arg_utils.py (+1, −1)
    vllm/entrypoints/llm.py (+3, −2)
    vllm/model_executor/layers/linear.py (+37, −23)
    vllm/model_executor/layers/quantization/__init__.py (+3, −1)
    vllm/model_executor/layers/quantization/awq.py (+13, −11)
    vllm/model_executor/layers/quantization/gptq.py (+215, −0)
    vllm/model_executor/layers/quantization/squeezellm.py (+8, −6)
    vllm/model_executor/models/aquila.py (+8, −1)
    vllm/model_executor/models/baichuan.py (+8, −1)
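Taken together, the commit wires GPTQ into vLLM's existing quantization plumbing: a GPTQConfig parsed from the checkpoint's quantize_config.json, a GPTQLinearMethod that allocates packed int32 weights, and two new CUDA ops (gptq_gemm, gptq_shuffle). A rough sketch of how those pieces connect, using only names introduced in the diffs below; the call sites are paraphrased and the config values (bits=4, group_size=128, desc_act=False) are illustrative, not taken from this commit:

    # Sketch only: how the new GPTQ pieces fit into vLLM's quantization registry.
    from vllm.model_executor.layers.quantization import _QUANTIZATION_CONFIG_REGISTRY

    quant_cls = _QUANTIZATION_CONFIG_REGISTRY["gptq"]        # -> GPTQConfig (added below)
    config = quant_cls.from_config({"bits": 4, "group_size": 128, "desc_act": False})
    linear_method = config.get_linear_method()               # -> GPTQLinearMethod
    # Each parallel linear layer then calls linear_method.create_weights(...) once,
    # and linear_method.apply_weights(...) at forward time, which dispatches to the
    # new ops.gptq_shuffle / ops.gptq_gemm CUDA kernels registered in csrc/pybind.cpp.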
benchmarks/benchmark_latency.py

@@ -84,7 +84,7 @@ if __name__ == '__main__':
     parser.add_argument('--tokenizer', type=str, default=None)
     parser.add_argument('--quantization',
                         '-q',
-                        choices=['awq', 'squeezellm', None],
+                        choices=['awq', 'gptq', 'squeezellm', None],
                         default=None)
     parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
     parser.add_argument('--input-len', type=int, default=32)
benchmarks/benchmark_throughput.py

@@ -244,7 +244,7 @@ if __name__ == "__main__":
     parser.add_argument("--tokenizer", type=str, default=None)
     parser.add_argument('--quantization',
                         '-q',
-                        choices=['awq', 'squeezellm', None],
+                        choices=['awq', 'gptq', 'squeezellm', None],
                         default=None)
     parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
     parser.add_argument("--n",
csrc/ops.h

@@ -77,3 +77,15 @@ void squeezellm_gemm(
   torch::Tensor mat,
   torch::Tensor mul,
   torch::Tensor lookup_table);
+
+torch::Tensor gptq_gemm(
+  torch::Tensor a,
+  torch::Tensor b_q_weight,
+  torch::Tensor b_gptq_qzeros,
+  torch::Tensor b_gptq_scales,
+  torch::Tensor b_g_idx,
+  bool use_exllama);
+
+void gptq_shuffle(
+  torch::Tensor q_weight,
+  torch::Tensor q_perm);
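The tensor shapes behind these declarations can be read off GPTQLinearMethod.create_weights later in this commit. The helper below is an inferred summary of that layout, not code from the diff (K = input size per partition, N = output size per partition, pack_factor = 8 for 4-bit weights):

    # Inferred from GPTQLinearMethod.create_weights below; illustrative only.
    def gptq_gemm_shapes(m: int, k: int, n: int, group_size: int, pack_factor: int = 8):
        """Expected argument shapes for ops.gptq_gemm with 4-bit GPTQ weights."""
        return {
            "a": (m, k),                                          # fp16 activations
            "b_q_weight": (k // pack_factor, n),                  # int32, 8 nibbles per word
            "b_gptq_qzeros": (k // group_size, n // pack_factor), # int32, packed zero points
            "b_gptq_scales": (k // group_size, n),                # fp16 scales
            "b_g_idx": (k,),                                      # int32 group index per input channel
            "output": (m, n),                                     # fp16 result
        }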
csrc/pybind.cpp

@@ -52,8 +52,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   // Quantization ops
   ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
 #endif
+  ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
+  ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
   ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");

   // Cache ops
csrc/quantization/gptq/compat.cuh (new file, mode 100644)

/*
Copied from https://github.com/turboderp/exllamav2
*/

#ifndef _compat_cuh
#define _compat_cuh

namespace vllm {
namespace gptq {

// atomicAdd for half types, to support CC < 7.x

__device__ __forceinline__ void atomicAdd_half(half* address, half val)
{
    unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
    unsigned int old = *address_as_ui;
    unsigned int assumed;

    do
    {
        assumed = old;
        __half_raw hsum;
        hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
        half tmpres = __hadd(hsum, val);
        hsum = __half_raw(tmpres);
        old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;
        old = atomicCAS(address_as_ui, assumed, old);
    }
    while (assumed != old);
}

// atomicAdd for half2 types

__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)
{
    unsigned int* address_as_ui = (unsigned int*)address;
    unsigned int old = *address_as_ui;
    unsigned int assumed;
    do
    {
        assumed = old;
        half2 old_val = *((half2*)&old);
        half2 new_val = __hadd2(old_val, val);
        old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val));
    }
    while (assumed != old);
}

//

#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)

__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }

#if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }
#endif

#endif
#endif

}  // namespace gptq
}  // namespace vllm
#endif
csrc/quantization/gptq/matrix_view.cuh (new file, mode 100644)

/*
Adapted from https://github.com/turboderp/exllamav2 and https://github.com/turboderp/exllama
*/

#ifndef _matrix_view_cuh
#define _matrix_view_cuh

#include <cuda_runtime.h>
#include <cuda_fp16.h>

#include "qdq_util.cuh"

namespace vllm {
namespace gptq {

class MatrixView_half
{
public:
    const half* data;
    const int height;
    const int width;

    __device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width)
        : data(data), height(height), width(width)
    { }

    __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
    __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
    __device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); }
    __device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; }

    __device__ __forceinline__ void item4(half (&items)[4], int row, int column) const
    {
        half2* ptr = (half2*) item_ptr(row, column);
        half2 i01 = ptr[0];
        half2 i23 = ptr[1];
        items[0] = __low2half(i01);
        items[1] = __high2half(i01);
        items[2] = __low2half(i23);
        items[3] = __high2half(i23);
    }
    __device__ __forceinline__ void item4_f(float (&items)[4], int row, int column) const
    {
        half2* ptr = (half2*)item_ptr(row, column);
        half2 i01 = ptr[0];
        half2 i23 = ptr[1];
        items[0] = __half2float(__low2half(i01));
        items[1] = __half2float(__high2half(i01));
        items[2] = __half2float(__low2half(i23));
        items[3] = __half2float(__high2half(i23));
    }

    __device__ __forceinline__ void item4_h2(half2 (&items)[4], int row, int column) const
    {
        half2* ptr = (half2*)item_ptr(row, column);
        half2 i01 = ptr[0];
        half2 i23 = ptr[1];
        items[0] = __half2half2(__low2half(i01));
        items[1] = __half2half2(__high2half(i01));
        items[2] = __half2half2(__low2half(i23));
        items[3] = __half2half2(__high2half(i23));
    }
};

class MatrixView_half_rw
{
public:
    half* data;
    const int height;
    const int width;

    __device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width)
        : data(data), height(height), width(width)
    { }

    __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
    __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
    __device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; }
    __device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; }
    __device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; }

    __device__ __forceinline__ void set4(int row, int column, half v0, half v1, half v2, half v3)
    {
        half2 v01 = __halves2half2(v0, v1);
        half2 v23 = __halves2half2(v2, v3);
        half2* ptr = (half2*) item_ptr(row, column);
        ptr[0] = v01;
        ptr[1] = v23;
    }
};

class MatrixView_q4_row
{
public:
    const uint32_t* data;
    const int height;
    const int width;

    __device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width)
        : data(data), height(height), width(width)
    { }

    __device__ __forceinline__ int item(int row, int column) const
    {
        int shift = (column & 0x07) * 4;
        return (data[row * width / 8 + column / 8] >> shift) & 0x0f;
    }

    __device__ __forceinline__ void item2(int (&items)[2], int row, int column) const
    {
        int shift = (column & 0x07) * 4;
        uint32_t d = data[row * width / 8 + column / 8] >> shift;
        items[0] = d & 0x0f;
        items[1] = (d >> 4) & 0x0f;
    }

    __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
    {
        int shift = (column & 0x07) * 4;
        uint32_t d = data[row * width / 8 + column / 8] >> shift;
        items[0] = d & 0x0f;
        items[1] = (d >> 4) & 0x0f;
        items[2] = (d >> 8) & 0x0f;
        items[3] = (d >> 12) & 0x0f;
    }
};

class MatrixView_q4_column
{
public:
    const uint32_t* data;
    const int height;
    const int width;

    __device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width)
        : data(data), height(height), width(width)
    { }

    __device__ __forceinline__ int item(int row, int column) const
    {
        int shift = (row & 0x07) * 4;
        return (data[row / 8 * width + column] >> shift) & 0x0f;
    }

    __device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { return data[row / 8 * width + column]; }
    __device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; }
};

}  // namespace gptq
}  // namespace vllm
#endif
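The q4 views address eight 4-bit values per 32-bit word. A minimal Python equivalent of MatrixView_q4_row::item(), useful for checking the indexing on the host side (illustrative, not part of the commit):

    import numpy as np

    def q4_row_item(data: np.ndarray, row: int, column: int, width: int) -> int:
        """Mirror of MatrixView_q4_row::item: 8 nibbles per uint32, row-major."""
        shift = (column & 0x07) * 4
        word = int(data[row * width // 8 + column // 8])
        return (word >> shift) & 0x0F

    # Example: pack the nibbles 0..7 into one word and read them back.
    packed = np.array([sum(v << (4 * i) for i, v in enumerate(range(8)))], dtype=np.uint32)
    assert [q4_row_item(packed, 0, c, width=8) for c in range(8)] == list(range(8))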
csrc/quantization/gptq/q_gemm.cu (new file, mode 100644)

This diff (859 added lines, the GPTQ/exllama GEMM kernels) is collapsed on the commit page and is not reproduced here.
csrc/quantization/gptq/qdq_4.cuh (new file, mode 100644)

/*
Copied from https://github.com/turboderp/exllamav2
*/

#ifndef _qdq_4_cuh
#define _qdq_4_cuh

#include "qdq_util.cuh"

namespace vllm {
namespace gptq {
// Permutation:
//
// 77775555 33331111  66664444 22220000

__forceinline__ __device__ void shuffle_4bit_8
(
    uint32_t* q,
    int stride
)
{
    uint32_t qa = q[0];
    uint32_t qb = 0;

    #pragma unroll
    for (int i = 0; i < 4; i++)
    {
        uint32_t qa0 = qa & 0x0f;
        uint32_t qa1 = (qa & 0xf0) >> 4;
        qa >>= 8;
        qb |= (qa1 << (i * 4 + 16));
        qb |= (qa0 << (i * 4));
    }
    q[0] = qb;
}

__forceinline__ __device__ void dequant_4bit_8
(
    const uint32_t q_0,
    half2 (&dq)[4],
    int stride
)
{
    const uint32_t c0 = 0x64006400;
    const half y16_ = __float2half_rn(1.0f / 16.0f);
    const half2 y16 = __halves2half2(y16_, y16_);
    const half z1_  = __float2half_rn(-1024.0f         - 8.0f);
    const half z16_ = __float2half_rn(-1024.0f / 16.0f - 8.0f);
    const half2 z1  = __halves2half2(z1_,  z1_);
    const half2 z16 = __halves2half2(z16_, z16_);

    uint32_t qa = q_0;
    half2_uint32 q0((qa & 0x000f000f) | c0); // half2(q[ 0], q[ 1])      + 1024
    half2_uint32 q1((qa & 0x00f000f0) | c0); // half2(q[ 2], q[ 3]) * 16 + 1024
    qa >>= 8;
    half2_uint32 q2((qa & 0x000f000f) | c0); // half2(q[ 4], q[ 5])      + 1024
    half2_uint32 q3((qa & 0x00f000f0) | c0); // half2(q[ 6], q[ 7]) * 16 + 1024

    dq[0] = __hadd2(q0.as_half2, z1);
    dq[1] = __hfma2(q1.as_half2, y16, z16);
    dq[2] = __hadd2(q2.as_half2, z1);
    dq[3] = __hfma2(q3.as_half2, y16, z16);
}

__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale
(
    const uint32_t zero,
    const half scale,
    half2 (&z1z16)[2],
    half2 (&y1y16)[2]
)
{
    half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero);
    half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));

    half2 scale2 = __half2half2(scale);

    z1z16[0] = __hmul2(scale2, __half2half2(z1.as_half));
    z1z16[1] = __hmul2(scale2, __half2half2(z16));

    const half y1 = __float2half_rn(1.0f);
    const half y16 = __float2half_rn(1.0f / 16.0f);

    y1y16[0] = __hmul2(scale2, __half2half2(y1));
    y1y16[1] = __hmul2(scale2, __half2half2(y16));
}

__forceinline__ __device__ void dequant_4bit_8_prep_zero
(
    const uint32_t zero,
    half2 (&z1z16)[2],
    half2 (&y1y16)[2]
)
{
    half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero);
    half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));

    z1z16[0] = __half2half2(z1.as_half);
    z1z16[1] = __half2half2(z16);

    const half y1 = __float2half_rn(1.0f);
    const half y16 = __float2half_rn(1.0f / 16.0f);

    y1y16[0] = __half2half2(y1);
    y1y16[1] = __half2half2(y16);
}

__forceinline__ __device__ void dequant_4bit_8_gptq
(
    const uint32_t q_0,
    half2 (&dq)[4],
    half2 (&z1z16)[2],
    half2 (&y1y16)[2],
    int stride,
    bool scaled
)
{
    const uint32_t c0 = 0x64006400;

    uint32_t qa = q_0;
    half2_uint32 q0((qa & 0x000f000f) | c0); // half2( q[0]      + 1024, q[1]      + 1024 )
    half2_uint32 q1((qa & 0x00f000f0) | c0); // half2( q[2] * 16 + 1024, q[3] * 16 + 1024 )
    qa >>= 8;
    half2_uint32 q2((qa & 0x000f000f) | c0); // half2( q[4]      + 1024, q[5]      + 1024 )
    half2_uint32 q3((qa & 0x00f000f0) | c0); // half2( q[6] * 16 + 1024, q[7] * 16 + 1024 )

    if (scaled)
    {
        dq[0] = __hfma2(q0.as_half2, y1y16[0], z1z16[0]);  // half2( q[0] * s - z * s, q[1] * s - z * s)
        dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]);  // half2( q[2] * s - z * s, q[3] * s - z * s)
        dq[2] = __hfma2(q2.as_half2, y1y16[0], z1z16[0]);
        dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]);
    }
    else
    {
        dq[0] = __hadd2(q0.as_half2,           z1z16[0]);  // half2( q[0] - z, q[1] - z )
        dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]);  // half2( q[2] - z, q[3] - z )
        dq[2] = __hadd2(q2.as_half2,           z1z16[0]);  // half2( q[4] - z, q[5] - z )
        dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]);  // half2( q[6] - z, q[7] - z )
    }
}
}  // namespace gptq
}  // namespace vllm

#else

namespace vllm {
namespace gptq {
__forceinline__ __device__ void shuffle_4bit_8
(
    uint32_t* q,
    int stride
)
{
}

__forceinline__ __device__ void dequant_4bit_8
(
    const uint32_t q_0,
    half2 (&dq)[4],
    int stride
)
{
    half dqh[8];
    for (int i = 0; i < 8; i++) dqh[i] = dq_ns(exb(q_0, i * 4, 0x0f), 8);

    for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}

__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale
(
    const uint32_t zero,
    const half scale,
    half2 (&z1)[2],
    half2 (&y1)[2]
)
{
    half z = __int2half_rn(-((int)zero));
    z = __hmul(z, scale);
    z1[0] = __half2half2(z);
    y1[0] = __half2half2(scale);
}

__forceinline__ __device__ void dequant_4bit_8_prep_zero
(
    const uint32_t zero,
    half2 (&z1)[2],
    half2 (&y1)[2]
)
{
    half z = __int2half_rn(-((int)zero));
    z1[0] = __half2half2(z);
}

__forceinline__ __device__ void dequant_4bit_8_gptq
(
    const uint32_t q_0,
    half2 (&dq)[4],
    half2 (&z1)[2],
    half2 (&y1)[2],
    int stride,
    bool scaled
)
{
    half2 dqh2[8];

    uint32_t qa = q_0;
    for (int i = 0; i < 4; i++)
    {
        half d0 = __int2half_rn(qa & 0x0f); qa >>= 4;
        half d1 = __int2half_rn(qa & 0x0f); qa >>= 4;
        dqh2[i] = __halves2half2(d0, d1);
    }

    if (scaled)
    {
        dq[0] = __hfma2(dqh2[0], y1[0], z1[0]);
        dq[1] = __hfma2(dqh2[1], y1[0], z1[0]);
        dq[2] = __hfma2(dqh2[2], y1[0], z1[0]);
        dq[3] = __hfma2(dqh2[3], y1[0], z1[0]);
    }
    else
    {
        dq[0] = __hadd2(dqh2[0], z1[0]);
        dq[1] = __hadd2(dqh2[1], z1[0]);
        dq[2] = __hadd2(dqh2[2], z1[0]);
        dq[3] = __hadd2(dqh2[3], z1[0]);
    }
}
}  // namespace gptq
}  // namespace vllm

#endif
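The constant 0x6400 is the fp16 bit pattern of 1024.0, and the spacing between consecutive fp16 values in [1024, 2048) is exactly 1, so OR-ing a 4-bit value q into the low mantissa bits yields the half 1024 + q with no integer-to-float conversion; z1 = -(1024 + 8) then recenters to the signed range. A small numpy check of that identity (illustrative, not part of the commit):

    import numpy as np

    def nibble_to_half_via_bias(q: int) -> float:
        """Reinterpret (0x6400 | q) as fp16: gives 1024.0 + q for 0 <= q <= 15."""
        bits = np.array([0x6400 | q], dtype=np.uint16)
        return float(bits.view(np.float16)[0])

    for q in range(16):
        # Subtracting (1024 + 8) reproduces the symmetric 4-bit value q - 8,
        # matching dq[0] = __hadd2(q0.as_half2, z1) with z1 = -1024 - 8 above.
        assert nibble_to_half_via_bias(q) - (1024.0 + 8.0) == q - 8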
csrc/quantization/gptq/qdq_util.cuh (new file, mode 100644)

/*
Copied from https://github.com/turboderp/exllamav2
*/

#ifndef _qdq_util_cuh
#define _qdq_util_cuh

namespace vllm {
namespace gptq {

union half2_uint32
{
    uint32_t as_uint32;
    half2 as_half2;
    __device__ half2_uint32(uint32_t val) : as_uint32(val) {}
    __device__ half2_uint32(half2 val) : as_half2(val) {}
};

union half_uint16
{
    uint16_t as_uint16;
    half as_half;
    __device__ half_uint16(uint16_t val) : as_uint16(val) {}
    __device__ half_uint16(half val) : as_half(val) {}
};

// Max_scale premultiplied by 1/256

__forceinline__ __device__ half dq_scale(const int qs, const half max_scale)
{
    int qs_i = qs + 1;
    half qs_h = __int2half_rn(qs_i * qs_i);
    qs_h = __hmul(qs_h, max_scale);
    return qs_h;
}

__forceinline__ __device__ half dq(const int q, const int qzero, const half scale)
{
    return __hmul(__int2half_rn(q - qzero), scale);
}

__forceinline__ __device__ half dq_ns(const int q, const int qzero)
{
    //return __hsub(__int2half_rn(q), __int2half_rn(qzero));
    return __int2half_rn(q - qzero);
}

__forceinline__ __device__ int exb(const uint32_t q, const int shift, const int mask)
{
    return (int)((q >> shift) & mask);
}

__forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0, const int shift, const int mask)
{
    return (int)(__funnelshift_rc(q0, q1, shift) & mask);
}

}  // namespace gptq
}  // namespace vllm
#endif
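dq and dq_ns are the scalar reference forms of GPTQ dequantization, and exb is a plain bit-field extract. In Python terms (illustrative only):

    def dq(q: int, qzero: int, scale: float) -> float:
        # Mirrors the CUDA helper: dequantized = (q - qzero) * scale.
        return (q - qzero) * scale

    def exb(q: int, shift: int, mask: int) -> int:
        # Extract a bit field, as the CUDA exb() does.
        return (q >> shift) & mask

    assert dq(q=11, qzero=8, scale=0.05) == 3 * 0.05
    assert exb(0x76543210, shift=12, mask=0xF) == 3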
setup.py

@@ -219,6 +219,7 @@ vllm_extension_sources = [
     "csrc/activation_kernels.cu",
     "csrc/layernorm_kernels.cu",
     "csrc/quantization/squeezellm/quant_cuda_kernel.cu",
+    "csrc/quantization/gptq/q_gemm.cu",
     "csrc/cuda_utils_kernels.cu",
     "csrc/pybind.cpp",
 ]
vllm/config.py

@@ -142,7 +142,7 @@ class ModelConfig:
         self.tokenizer_mode = tokenizer_mode

     def _verify_quantization(self) -> None:
-        supported_quantization = ["awq", "squeezellm"]
+        supported_quantization = ["awq", "gptq", "squeezellm"]
         rocm_not_supported_quantization = ["awq"]
         if self.quantization is not None:
             self.quantization = self.quantization.lower()
vllm/engine/arg_utils.py

@@ -179,7 +179,7 @@ class EngineArgs:
         parser.add_argument('--quantization',
                             '-q',
                             type=str,
-                            choices=['awq', 'squeezellm', None],
+                            choices=['awq', 'gptq', 'squeezellm', None],
                             default=None,
                             help='Method used to quantize the weights')
         return parser
vllm/entrypoints/llm.py

@@ -38,8 +38,9 @@ class LLM:
             However, if the `torch_dtype` in the config is `float32`, we will
             use `float16` instead.
         quantization: The method used to quantize the model weights. Currently,
-            we support "awq". If None, we assume the model weights are not
-            quantized and use `dtype` to determine the data type of the weights.
+            we support "awq", "gptq" and "squeezellm". If None, we assume the
+            model weights are not quantized and use `dtype` to determine the
+            data type of the weights.
         revision: The specific model version to use. It can be a branch name,
             a tag name, or a commit id.
         tokenizer_revision: The specific tokenizer version to use. It can be a
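With the docstring change above, the user-facing change is just the extra value accepted by the quantization argument. A minimal usage sketch; the model name is a placeholder for any 4-bit GPTQ checkpoint that ships a quantize_config.json:

    from vllm import LLM, SamplingParams

    # "some-org/some-model-GPTQ" is a placeholder, not a model referenced by this commit.
    llm = LLM(model="some-org/some-model-GPTQ", quantization="gptq")
    outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
    print(outputs[0].outputs[0].text)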
vllm/model_executor/layers/linear.py

 from abc import ABC, abstractmethod
-from typing import Dict, List, Optional
+from typing import Any, Dict, List, Optional

 import torch
 import torch.nn.functional as F

@@ -21,8 +21,10 @@ class LinearMethodBase(ABC):
     """Base class for different (maybe quantized) linear methods."""

     @abstractmethod
-    def create_weights(self, input_size: int, output_size: int,
-                       params_dtype: torch.dtype) -> Dict[str, torch.Tensor]:
+    def create_weights(self, input_size_per_partition: int,
+                       output_size_per_partition: int, input_size: int,
+                       output_size: int,
+                       params_dtype: torch.dtype) -> Dict[str, Any]:
         """Create weights for a linear layer."""
         raise NotImplementedError

@@ -46,10 +48,12 @@ class UnquantizedLinearMethod(LinearMethodBase):
     def __init__(self, separate_bias_add: bool = False):
         self.separate_bias_add = separate_bias_add

-    def create_weights(self, input_size: int, output_size: int,
-                       params_dtype: torch.dtype) -> Dict[str, torch.Tensor]:
-        weight = Parameter(torch.empty(output_size, input_size,
+    def create_weights(self, input_size_per_partition: int,
+                       output_size_per_partition: int, input_size: int,
+                       output_size: int,
+                       params_dtype: torch.dtype) -> Dict[str, Any]:
+        weight = Parameter(torch.empty(output_size_per_partition,
+                                       input_size_per_partition,
                                        device=torch.cuda.current_device(),
                                        dtype=params_dtype),
                            requires_grad=False)

@@ -102,9 +106,11 @@ class ReplicatedLinear(torch.nn.Module):
             linear_method = UnquantizedLinearMethod()
         self.linear_method = linear_method
         self.linear_weights = self.linear_method.create_weights(
-            self.input_size, self.output_size, self.params_dtype)
+            self.input_size, self.output_size, self.input_size,
+            self.output_size, self.params_dtype)
         for name, weight in self.linear_weights.items():
-            self.register_parameter(name, weight)
+            if isinstance(weight, torch.Tensor):
+                self.register_parameter(name, weight)
         if bias:
             self.bias = Parameter(
                 torch.empty(self.output_size,

@@ -168,10 +174,12 @@ class ColumnParallelLinear(torch.nn.Module):
             linear_method = UnquantizedLinearMethod()
         self.linear_method = linear_method
         self.linear_weights = self.linear_method.create_weights(
-            self.input_size, self.output_size_per_partition, self.params_dtype)
+            self.input_size, self.output_size_per_partition, self.input_size,
+            self.output_size, self.params_dtype)
         for name, weight in self.linear_weights.items():
-            self.register_parameter(name, weight)
-            set_weight_attrs(weight, {"weight_loader": self.weight_loader})
+            if isinstance(weight, torch.Tensor):
+                self.register_parameter(name, weight)
+                set_weight_attrs(weight, {"weight_loader": self.weight_loader})
         if bias:
             self.bias = Parameter(
                 torch.empty(self.output_size_per_partition,

@@ -295,10 +303,12 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                 loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                      shard_size)
             else:
-                logger.warning(
-                    "Loading a weight without `output_dim` attribute in "
-                    "MergedColumnParallelLinear, assume the weight is "
-                    "the same for all partitions.")
+                ignore_warning = getattr(param, "ignore_warning", False)
+                if not ignore_warning:
+                    logger.warning(
+                        "Loading a weight without `output_dim` attribute in "
+                        "MergedColumnParallelLinear, assume the weight is "
+                        "the same for all partitions.")

         assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)

@@ -418,10 +428,12 @@ class QKVParallelLinear(ColumnParallelLinear):
                 loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                      shard_size)
             else:
-                logger.warning(
-                    "Loading a weight without `output_dim` attribute in "
-                    "QKVParallelLinear, assume the weight is the same "
-                    "for all partitions.")
+                ignore_warning = getattr(param, "ignore_warning", False)
+                if not ignore_warning:
+                    logger.warning(
+                        "Loading a weight without `output_dim` attribute in "
+                        "QKVParallelLinear, assume the weight is the same "
+                        "for all partitions.")

         assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)

@@ -481,10 +493,12 @@ class RowParallelLinear(torch.nn.Module):
             linear_method = UnquantizedLinearMethod()
         self.linear_method = linear_method
         self.linear_weights = self.linear_method.create_weights(
-            self.input_size_per_partition, self.output_size, self.params_dtype)
+            self.input_size_per_partition, self.output_size, self.input_size,
+            self.output_size, self.params_dtype)
         for name, weight in self.linear_weights.items():
-            self.register_parameter(name, weight)
-            set_weight_attrs(weight, {"weight_loader": self.weight_loader})
+            if isinstance(weight, torch.Tensor):
+                self.register_parameter(name, weight)
+                set_weight_attrs(weight, {"weight_loader": self.weight_loader})

         if not reduce_results and (bias and not skip_bias_add):
             raise ValueError("When not reduce the results, adding bias to the "
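The widened create_weights signature is what lets a quantized method see both the sharded sizes and the full layer sizes (GPTQ needs the full input size for its scales/zeros when act-order reordering is enabled), and the isinstance check allows non-tensor state such as GPTQ's ExllamaState in the returned dict. A hypothetical linear method written against the new interface might look roughly like this (NoopLinearMethod is an illustration, not part of the commit):

    from typing import Any, Dict, Optional

    import torch
    from torch.nn.parameter import Parameter

    from vllm.model_executor.layers.linear import LinearMethodBase


    class NoopLinearMethod(LinearMethodBase):
        """Hypothetical example of the new five-argument create_weights()."""

        def create_weights(self, input_size_per_partition: int,
                           output_size_per_partition: int, input_size: int,
                           output_size: int,
                           params_dtype: torch.dtype) -> Dict[str, Any]:
            weight = Parameter(torch.empty(output_size_per_partition,
                                           input_size_per_partition,
                                           dtype=params_dtype),
                               requires_grad=False)
            # Non-tensor entries are now allowed; callers only register Tensors.
            return {"weight": weight, "some_state": "uninitialized"}

        def apply_weights(self,
                          weights: Dict[str, Any],
                          x: torch.Tensor,
                          bias: Optional[torch.Tensor] = None) -> torch.Tensor:
            return torch.nn.functional.linear(x, weights["weight"], bias)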
vllm/model_executor/layers/quantization/__init__.py

 from typing import Type

-from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.quantization.awq import AWQConfig
+from vllm.model_executor.layers.quantization.gptq import GPTQConfig
 from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig

 _QUANTIZATION_CONFIG_REGISTRY = {
     "awq": AWQConfig,
+    "gptq": GPTQConfig,
     "squeezellm": SqueezeLLMConfig,
 }
vllm/model_executor/layers/quantization/awq.py

@@ -77,14 +77,16 @@ class AWQLinearMethod(LinearMethodBase):
     def __init__(self, quant_config: AWQConfig):
         self.quant_config = quant_config

-    def create_weights(self, input_size: int, output_size: int,
-                       params_dtype: torch.dtype) -> Dict[str, torch.Tensor]:
-        if input_size % self.quant_config.group_size != 0:
+    def create_weights(self, input_size_per_partition: int,
+                       output_size_per_partition: int, input_size: int,
+                       output_size: int,
+                       params_dtype: torch.dtype) -> Dict[str, Any]:
+        if input_size_per_partition % self.quant_config.group_size != 0:
             raise ValueError(
                 "The input size is not aligned with the quantized "
                 "weight shape. This can be caused by too large "
                 "tensor parallel size.")
-        if output_size % self.quant_config.pack_factor != 0:
+        if output_size_per_partition % self.quant_config.pack_factor != 0:
             raise ValueError(
                 "The output size is not aligned with the quantized "
                 "weight shape. This can be caused by too large "

@@ -92,8 +94,8 @@ class AWQLinearMethod(LinearMethodBase):
         qweight = Parameter(
             torch.empty(
-                input_size,
-                output_size // self.quant_config.pack_factor,
+                input_size_per_partition,
+                output_size_per_partition // self.quant_config.pack_factor,
                 device="cuda",
                 dtype=torch.int32,
             ),

@@ -108,8 +110,8 @@ class AWQLinearMethod(LinearMethodBase):
         qzeros = Parameter(
             torch.empty(
-                input_size // self.quant_config.group_size,
-                output_size // self.quant_config.pack_factor,
+                input_size_per_partition // self.quant_config.group_size,
+                output_size_per_partition // self.quant_config.pack_factor,
                 device="cuda",
                 dtype=torch.int32,
             ),

@@ -124,8 +126,8 @@ class AWQLinearMethod(LinearMethodBase):
         scales = Parameter(
             torch.empty(
-                input_size // self.quant_config.group_size,
-                output_size,
+                input_size_per_partition // self.quant_config.group_size,
+                output_size_per_partition,
                 device="cuda",
                 dtype=params_dtype,
             ),

@@ -142,7 +144,7 @@ class AWQLinearMethod(LinearMethodBase):
         }

     def apply_weights(self,
-                      weights: Dict[str, torch.Tensor],
+                      weights: Dict[str, Any],
                       x: torch.Tensor,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
         qweight = weights["qweight"]
vllm/model_executor/layers/quantization/gptq.py (new file, mode 100644)

import enum
from enum import Enum
from typing import Any, Dict, List, Optional

import torch
from torch.nn.parameter import Parameter

from vllm._C import ops
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)


class GPTQConfig(QuantizationConfig):
    """Config class for GPTQ.

    Reference: https://arxiv.org/abs/2210.17323
    """

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
        desc_act: bool,
    ) -> None:
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.desc_act = desc_act
        self.pack_factor = 32 // self.weight_bits
        # exllama kernel v1 only supports 4 bit
        if self.weight_bits != 4:
            raise ValueError(
                "Currently, only 4-bit weight quantization is supported for "
                f"GPTQ, but got {self.weight_bits} bits.")

    def __repr__(self) -> str:
        return (f"GPTQConfig(weight_bits={self.weight_bits}, "
                f"group_size={self.group_size}, "
                f"desc_act={self.desc_act})")

    @classmethod
    def get_name(cls) -> str:
        return "gptq"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half]

    @classmethod
    # Need to figure it out
    def get_min_capability(cls) -> int:
        return 60

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "GPTQConfig":
        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        desc_act = cls.get_from_keys(config, ["desc_act"])
        return cls(weight_bits, group_size, desc_act)

    def get_linear_method(self) -> "GPTQLinearMethod":
        return GPTQLinearMethod(self)

    def get_scaled_act_names(self) -> List[str]:
        return []


class ExllamaState(Enum):

    UNUSED = enum.auto()
    UNINITIALIZED = enum.auto()
    READY = enum.auto()


class GPTQLinearMethod(LinearMethodBase):
    """Linear method for GPTQ.

    Args:
        quant_config: The GPTQ quantization config.
    """

    def __init__(self, quant_config: GPTQConfig):
        self.quant_config = quant_config

    def create_weights(
        self,
        input_size_per_partition: int,
        output_size_per_partition: int,
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
    ) -> Dict[str, Any]:
        del output_size  # Unused.
        if input_size_per_partition % self.quant_config.group_size != 0:
            raise ValueError(
                "The input size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")
        if output_size_per_partition % self.quant_config.pack_factor != 0:
            raise ValueError(
                "The output size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")

        if self.quant_config.group_size != -1:
            group_size = self.quant_config.group_size
        else:
            group_size = input_size
        exllama_state = ExllamaState.UNINITIALIZED
        scale_and_zero_size = input_size // group_size
        scale_and_zero_input_dim = None
        if input_size != input_size_per_partition and self.quant_config.group_size != -1:
            # For act-order models, we cannot use Exllama for row parallel layer
            if self.quant_config.desc_act:
                exllama_state = ExllamaState.UNUSED
            else:
                # we need to partition qzeros and scales for exllama kernel
                scale_and_zero_size = input_size_per_partition // group_size
                scale_and_zero_input_dim = 0

        qweight = Parameter(
            torch.empty(
                input_size_per_partition // self.quant_config.pack_factor,
                output_size_per_partition,
                device="cuda",
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qweight, {
                "input_dim": 0,
                "output_dim": 1,
                "packed_dim": 0,
                "pack_factor": self.quant_config.pack_factor,
            })
        g_idx = Parameter(
            torch.tensor(
                [
                    i // self.quant_config.group_size
                    for i in range(input_size_per_partition)
                ],
                device="cuda",
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        # Ignore warning from fused linear layers such as QKVParallelLinear.
        set_weight_attrs(g_idx, {"input_dim": 0, "ignore_warning": True})
        qzeros = Parameter(
            torch.empty(
                scale_and_zero_size,
                output_size_per_partition // self.quant_config.pack_factor,
                device="cuda",
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qzeros, {
                "input_dim": scale_and_zero_input_dim,
                "output_dim": 1,
                "packed_dim": 1,
                "pack_factor": self.quant_config.pack_factor,
            })
        scales = Parameter(
            torch.empty(
                scale_and_zero_size,
                output_size_per_partition,
                device="cuda",
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(scales, {
            "input_dim": scale_and_zero_input_dim,
            "output_dim": 1,
        })
        return {
            "qweight": qweight,
            "g_idx": g_idx,
            "qzeros": qzeros,
            "scales": scales,
            "exllama_state": exllama_state,
        }

    def apply_weights(self,
                      weights: Dict[str, Any],
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        qweight = weights["qweight"]
        out_shape = x.shape[:-1] + (qweight.shape[-1], )
        reshaped_x = x.reshape(-1, x.shape[-1])
        # exllama needs to shuffle the weight after the weight is loaded
        # here we do the shuffle on first forward pass
        if weights["exllama_state"] == ExllamaState.UNINITIALIZED:
            if self.quant_config.desc_act:
                weights["g_idx"] = torch.argsort(weights["g_idx"]).to(torch.int)
            else:
                weights["g_idx"] = torch.empty((1, 1), device="meta")
            weights["exllama_state"] = ExllamaState.READY
            ops.gptq_shuffle(weights["qweight"], weights["g_idx"])
        output = ops.gptq_gemm(reshaped_x, weights["qweight"],
                               weights["qzeros"], weights["scales"],
                               weights["g_idx"],
                               weights["exllama_state"] == ExllamaState.READY)
        if bias is not None:
            output = output + bias
        return output.reshape(out_shape)
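GPTQConfig.from_config only reads bits, group_size and desc_act, so the quantize_config.json shipped with an AutoGPTQ-style checkpoint is enough to construct it. A sketch of how such a config maps onto the class; the values are illustrative, not taken from this commit:

    from vllm.model_executor.layers.quantization.gptq import GPTQConfig

    # Typical keys of a checkpoint's quantize_config.json; only these three are read.
    quantize_config = {
        "bits": 4,           # must be 4: the exllama v1 kernel only supports 4-bit
        "group_size": 128,   # -1 means a single group spanning the whole input dimension
        "desc_act": False,   # True for act-order checkpoints
    }

    config = GPTQConfig.from_config(quantize_config)
    print(config)  # GPTQConfig(weight_bits=4, group_size=128, desc_act=False)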
vllm/model_executor/layers/quantization/squeezellm.py

@@ -67,17 +67,19 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
     def __init__(self, quant_config: SqueezeLLMConfig):
         self.quant_config = quant_config

-    def create_weights(self, input_size: int, output_size: int,
-                       params_dtype: torch.dtype) -> Dict[str, torch.Tensor]:
-        if input_size % self.quant_config.pack_factor != 0:
+    def create_weights(self, input_size_per_partition: int,
+                       output_size_per_partition: int, input_size: int,
+                       output_size: int,
+                       params_dtype: torch.dtype) -> Dict[str, Any]:
+        if input_size_per_partition % self.quant_config.pack_factor != 0:
             raise ValueError(
                 "The input size is not aligned with the quantized "
                 "weight shape. This can be caused by too large "
                 "tensor parallel size.")

         qweight = Parameter(
             torch.empty(
-                input_size // self.quant_config.pack_factor,
-                output_size,
+                input_size_per_partition // self.quant_config.pack_factor,
+                output_size_per_partition,
                 device="cuda",
                 dtype=torch.int32,
             ),

@@ -108,7 +110,7 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
         }

     def apply_weights(self,
-                      weights: Dict[str, torch.Tensor],
+                      weights: Dict[str, Any],
                       x: torch.Tensor,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
         qweight = weights["qweight"]
vllm/model_executor/models/aquila.py

@@ -332,11 +332,18 @@ class AquilaForCausalLM(nn.Module):
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
-                param = params_dict[name.replace(weight_name, param_name)]
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
vllm/model_executor/models/baichuan.py

@@ -355,11 +355,18 @@ class BaiChuanBaseForCausalLM(nn.Module):
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
-                param = params_dict[name.replace(weight_name, param_name)]
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)