Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
norm
vllm
Commits
0fbfc4b8
"vscode:/vscode.git/clone" did not exist on "6534efd6ca5955316f60fce7563e3bcd97cca583"
Unverified
Commit
0fbfc4b8
authored
Dec 15, 2023
by
CHU Tianxiang
Committed by
GitHub
Dec 15, 2023
Browse files
Add GPTQ support (#916)
parent
c06170cc
Changes
35
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1683 additions
and
51 deletions
+1683
-51
benchmarks/benchmark_latency.py
benchmarks/benchmark_latency.py
+1
-1
benchmarks/benchmark_throughput.py
benchmarks/benchmark_throughput.py
+1
-1
csrc/ops.h
csrc/ops.h
+12
-0
csrc/pybind.cpp
csrc/pybind.cpp
+2
-2
csrc/quantization/gptq/compat.cuh
csrc/quantization/gptq/compat.cuh
+64
-0
csrc/quantization/gptq/matrix_view.cuh
csrc/quantization/gptq/matrix_view.cuh
+151
-0
csrc/quantization/gptq/q_gemm.cu
csrc/quantization/gptq/q_gemm.cu
+859
-0
csrc/quantization/gptq/qdq_4.cuh
csrc/quantization/gptq/qdq_4.cuh
+235
-0
csrc/quantization/gptq/qdq_util.cuh
csrc/quantization/gptq/qdq_util.cuh
+60
-0
setup.py
setup.py
+1
-0
vllm/config.py
vllm/config.py
+1
-1
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+1
-1
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+3
-2
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+37
-23
vllm/model_executor/layers/quantization/__init__.py
vllm/model_executor/layers/quantization/__init__.py
+3
-1
vllm/model_executor/layers/quantization/awq.py
vllm/model_executor/layers/quantization/awq.py
+13
-11
vllm/model_executor/layers/quantization/gptq.py
vllm/model_executor/layers/quantization/gptq.py
+215
-0
vllm/model_executor/layers/quantization/squeezellm.py
vllm/model_executor/layers/quantization/squeezellm.py
+8
-6
vllm/model_executor/models/aquila.py
vllm/model_executor/models/aquila.py
+8
-1
vllm/model_executor/models/baichuan.py
vllm/model_executor/models/baichuan.py
+8
-1
No files found.
benchmarks/benchmark_latency.py
View file @
0fbfc4b8
...
...
@@ -84,7 +84,7 @@ if __name__ == '__main__':
parser
.
add_argument
(
'--tokenizer'
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
'--quantization'
,
'-q'
,
choices
=
[
'awq'
,
'squeezellm'
,
None
],
choices
=
[
'awq'
,
'gptq'
,
'squeezellm'
,
None
],
default
=
None
)
parser
.
add_argument
(
'--tensor-parallel-size'
,
'-tp'
,
type
=
int
,
default
=
1
)
parser
.
add_argument
(
'--input-len'
,
type
=
int
,
default
=
32
)
...
...
benchmarks/benchmark_throughput.py
View file @
0fbfc4b8
...
...
@@ -244,7 +244,7 @@ if __name__ == "__main__":
parser
.
add_argument
(
"--tokenizer"
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
'--quantization'
,
'-q'
,
choices
=
[
'awq'
,
'squeezellm'
,
None
],
choices
=
[
'awq'
,
'gptq'
,
'squeezellm'
,
None
],
default
=
None
)
parser
.
add_argument
(
"--tensor-parallel-size"
,
"-tp"
,
type
=
int
,
default
=
1
)
parser
.
add_argument
(
"--n"
,
...
...
csrc/ops.h
View file @
0fbfc4b8
...
...
@@ -77,3 +77,15 @@ void squeezellm_gemm(
torch
::
Tensor
mat
,
torch
::
Tensor
mul
,
torch
::
Tensor
lookup_table
);
torch
::
Tensor
gptq_gemm
(
torch
::
Tensor
a
,
torch
::
Tensor
b_q_weight
,
torch
::
Tensor
b_gptq_qzeros
,
torch
::
Tensor
b_gptq_scales
,
torch
::
Tensor
b_g_idx
,
bool
use_exllama
);
void
gptq_shuffle
(
torch
::
Tensor
q_weight
,
torch
::
Tensor
q_perm
);
csrc/pybind.cpp
View file @
0fbfc4b8
...
...
@@ -52,8 +52,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// Quantization ops
ops
.
def
(
"awq_gemm"
,
&
awq_gemm
,
"Quantized GEMM for AWQ"
);
#endif
ops
.
def
(
"gptq_gemm"
,
&
gptq_gemm
,
"Quantized GEMM for GPTQ"
);
ops
.
def
(
"gptq_shuffle"
,
&
gptq_shuffle
,
"Post processing for GPTQ"
);
ops
.
def
(
"squeezellm_gemm"
,
&
squeezellm_gemm
,
"Quantized GEMM for SqueezeLLM"
);
// Cache ops
...
...
csrc/quantization/gptq/compat.cuh
0 → 100644
View file @
0fbfc4b8
/*
Copied from https://github.com/turboderp/exllamav2
*/
#ifndef _compat_cuh
#define _compat_cuh
namespace
vllm
{
namespace
gptq
{
// atomicAdd for half types, to support CC < 7.x
__device__
__forceinline__
void
atomicAdd_half
(
half
*
address
,
half
val
)
{
unsigned
int
*
address_as_ui
=
(
unsigned
int
*
)
((
char
*
)
address
-
((
size_t
)
address
&
2
));
unsigned
int
old
=
*
address_as_ui
;
unsigned
int
assumed
;
do
{
assumed
=
old
;
__half_raw
hsum
;
hsum
.
x
=
(
size_t
)
address
&
2
?
(
old
>>
16
)
:
(
old
&
0xffff
);
half
tmpres
=
__hadd
(
hsum
,
val
);
hsum
=
__half_raw
(
tmpres
);
old
=
(
size_t
)
address
&
2
?
(
old
&
0xffff
)
|
(
hsum
.
x
<<
16
)
:
(
old
&
0xffff0000
)
|
hsum
.
x
;
old
=
atomicCAS
(
address_as_ui
,
assumed
,
old
);
}
while
(
assumed
!=
old
);
}
// atomicAdd for half2 types
__device__
__forceinline__
void
atomicAdd_half2
(
half2
*
address
,
half2
val
)
{
unsigned
int
*
address_as_ui
=
(
unsigned
int
*
)
address
;
unsigned
int
old
=
*
address_as_ui
;
unsigned
int
assumed
;
do
{
assumed
=
old
;
half2
old_val
=
*
((
half2
*
)
&
old
);
half2
new_val
=
__hadd2
(
old_val
,
val
);
old
=
atomicCAS
(
address_as_ui
,
assumed
,
*
((
unsigned
int
*
)
&
new_val
));
}
while
(
assumed
!=
old
);
}
//
#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)
__device__
__forceinline__
void
atomicAdd
(
half
*
address
,
half
val
)
{
atomicAdd_half
(
address
,
val
);
}
#if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
__device__
__forceinline__
void
atomicAdd
(
half2
*
address
,
half2
val
)
{
atomicAdd_half2
(
address
,
val
);
}
#endif
#endif
#endif
}
// namespace gptq
}
// namespace vllm
#endif
csrc/quantization/gptq/matrix_view.cuh
0 → 100644
View file @
0fbfc4b8
/*
Adapted from https://github.com/turboderp/exllamav2 and https://github.com/turboderp/exllama
*/
#ifndef _matrix_view_cuh
#define _matrix_view_cuh
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include "qdq_util.cuh"
namespace
vllm
{
namespace
gptq
{
class
MatrixView_half
{
public:
const
half
*
data
;
const
int
height
;
const
int
width
;
__device__
__forceinline__
MatrixView_half
(
const
half
*
data
,
const
int
height
,
const
int
width
)
:
data
(
data
),
height
(
height
),
width
(
width
)
{
}
__device__
__forceinline__
half
item
(
int
row
,
int
column
)
const
{
return
data
[
row
*
width
+
column
];
}
__device__
__forceinline__
half2
item_half2
(
int
row
,
int
column
)
const
{
return
((
half2
*
)
data
)[(
row
*
width
+
column
)
/
2
];
}
__device__
__forceinline__
half2
item_half2half2
(
int
row
,
int
column
)
const
{
return
__half2half2
(
data
[
row
*
width
+
column
]);
}
__device__
__forceinline__
const
half
*
item_ptr
(
int
row
,
int
column
)
const
{
return
&
data
[
row
*
width
+
column
];
}
__device__
__forceinline__
void
item4
(
half
(
&
items
)[
4
],
int
row
,
int
column
)
const
{
half2
*
ptr
=
(
half2
*
)
item_ptr
(
row
,
column
);
half2
i01
=
ptr
[
0
];
half2
i23
=
ptr
[
1
];
items
[
0
]
=
__low2half
(
i01
);
items
[
1
]
=
__high2half
(
i01
);
items
[
2
]
=
__low2half
(
i23
);
items
[
3
]
=
__high2half
(
i23
);
}
__device__
__forceinline__
void
item4_f
(
float
(
&
items
)[
4
],
int
row
,
int
column
)
const
{
half2
*
ptr
=
(
half2
*
)
item_ptr
(
row
,
column
);
half2
i01
=
ptr
[
0
];
half2
i23
=
ptr
[
1
];
items
[
0
]
=
__half2float
(
__low2half
(
i01
));
items
[
1
]
=
__half2float
(
__high2half
(
i01
));
items
[
2
]
=
__half2float
(
__low2half
(
i23
));
items
[
3
]
=
__half2float
(
__high2half
(
i23
));
}
__device__
__forceinline__
void
item4_h2
(
half2
(
&
items
)[
4
],
int
row
,
int
column
)
const
{
half2
*
ptr
=
(
half2
*
)
item_ptr
(
row
,
column
);
half2
i01
=
ptr
[
0
];
half2
i23
=
ptr
[
1
];
items
[
0
]
=
__half2half2
(
__low2half
(
i01
));
items
[
1
]
=
__half2half2
(
__high2half
(
i01
));
items
[
2
]
=
__half2half2
(
__low2half
(
i23
));
items
[
3
]
=
__half2half2
(
__high2half
(
i23
));
}
};
class
MatrixView_half_rw
{
public:
half
*
data
;
const
int
height
;
const
int
width
;
__device__
__forceinline__
MatrixView_half_rw
(
half
*
data
,
const
int
height
,
const
int
width
)
:
data
(
data
),
height
(
height
),
width
(
width
)
{
}
__device__
__forceinline__
half
item
(
int
row
,
int
column
)
const
{
return
data
[
row
*
width
+
column
];
}
__device__
__forceinline__
half2
item_half2
(
int
row
,
int
column
)
const
{
return
((
half2
*
)
data
)[(
row
*
width
+
column
)
/
2
];
}
__device__
__forceinline__
half
*
item_ptr
(
int
row
,
int
column
)
{
return
&
data
[
row
*
width
+
column
];
}
__device__
__forceinline__
void
set
(
int
row
,
int
column
,
half
value
)
{
data
[
row
*
width
+
column
]
=
value
;
}
__device__
__forceinline__
void
set_half2
(
int
row
,
int
column
,
half2
value
)
{
((
half2
*
)
data
)[(
row
*
width
+
column
)
/
2
]
=
value
;
}
__device__
__forceinline__
void
set4
(
int
row
,
int
column
,
half
v0
,
half
v1
,
half
v2
,
half
v3
)
{
half2
v01
=
__halves2half2
(
v0
,
v1
);
half2
v23
=
__halves2half2
(
v2
,
v3
);
half2
*
ptr
=
(
half2
*
)
item_ptr
(
row
,
column
);
ptr
[
0
]
=
v01
;
ptr
[
1
]
=
v23
;
}
};
class
MatrixView_q4_row
{
public:
const
uint32_t
*
data
;
const
int
height
;
const
int
width
;
__device__
__forceinline__
MatrixView_q4_row
(
const
uint32_t
*
data
,
const
int
height
,
const
int
width
)
:
data
(
data
),
height
(
height
),
width
(
width
)
{
}
__device__
__forceinline__
int
item
(
int
row
,
int
column
)
const
{
int
shift
=
(
column
&
0x07
)
*
4
;
return
(
data
[
row
*
width
/
8
+
column
/
8
]
>>
shift
)
&
0x0f
;
}
__device__
__forceinline__
void
item2
(
int
(
&
items
)[
2
],
int
row
,
int
column
)
const
{
int
shift
=
(
column
&
0x07
)
*
4
;
uint32_t
d
=
data
[
row
*
width
/
8
+
column
/
8
]
>>
shift
;
items
[
0
]
=
d
&
0x0f
;
items
[
1
]
=
(
d
>>
4
)
&
0x0f
;
}
__device__
__forceinline__
void
item4
(
int
(
&
items
)[
4
],
int
row
,
int
column
)
const
{
int
shift
=
(
column
&
0x07
)
*
4
;
uint32_t
d
=
data
[
row
*
width
/
8
+
column
/
8
]
>>
shift
;
items
[
0
]
=
d
&
0x0f
;
items
[
1
]
=
(
d
>>
4
)
&
0x0f
;
items
[
2
]
=
(
d
>>
8
)
&
0x0f
;
items
[
3
]
=
(
d
>>
12
)
&
0x0f
;
}
};
class
MatrixView_q4_column
{
public:
const
uint32_t
*
data
;
const
int
height
;
const
int
width
;
__device__
__forceinline__
MatrixView_q4_column
(
const
uint32_t
*
data
,
const
int
height
,
const
int
width
)
:
data
(
data
),
height
(
height
),
width
(
width
)
{
}
__device__
__forceinline__
int
item
(
int
row
,
int
column
)
const
{
int
shift
=
(
row
&
0x07
)
*
4
;
return
(
data
[
row
/
8
*
width
+
column
]
>>
shift
)
&
0x0f
;
}
__device__
__forceinline__
uint32_t
item_uint32_t
(
int
row
,
int
column
)
{
return
data
[
row
/
8
*
width
+
column
];
}
__device__
__forceinline__
const
uint32_t
*
item_uint32_ptr
(
int
row
,
int
column
)
{
return
&
data
[
row
/
8
*
width
+
column
];
}
};
}
// namespace gptq
}
// namespace vllm
#endif
csrc/quantization/gptq/q_gemm.cu
0 → 100644
View file @
0fbfc4b8
/*
Adapted from https://github.com/turboderp/exllamav2 and https://github.com/qwopqwop200/GPTQ-for-LLaMa
*/
#include <cstdint>
#include <cstdio>
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include "compat.cuh"
#include "matrix_view.cuh"
#include "qdq_4.cuh"
namespace
vllm
{
namespace
gptq
{
#define BLOCK_KN_SIZE 128
#define BLOCK_M_SIZE_MAX 8
#define MAX_GROUPS_IN_BLOCK (BLOCK_KN_SIZE / 32)
#define MAX_Q_GEMM_ROWS 50
#define MAX_ALT_GEMM_ROWS 8
#define THREADS_X 32
#define THREADS_Y 32
#define DIVIDE(x, size) (((x) + (size) - 1) / (size))
#if defined(USE_ROCM)
__host__
__forceinline__
hipblasStatus_t
__compat_hipblasHgemm
(
hipblasHandle_t
handle
,
hipblasOperation_t
transA
,
hipblasOperation_t
transB
,
int
m
,
int
n
,
int
k
,
const
half
*
alpha
,
const
half
*
AP
,
int
lda
,
const
half
*
BP
,
int
ldb
,
const
half
*
beta
,
half
*
CP
,
int
ldc
)
{
return
hipblasHgemm
(
handle
,
transA
,
transB
,
m
,
n
,
k
,
reinterpret_cast
<
const
hipblasHalf
*>
(
alpha
),
reinterpret_cast
<
const
hipblasHalf
*>
(
AP
),
lda
,
reinterpret_cast
<
const
hipblasHalf
*>
(
BP
),
ldb
,
reinterpret_cast
<
const
hipblasHalf
*>
(
beta
),
reinterpret_cast
<
hipblasHalf
*>
(
CP
),
ldc
);
}
#define hipblasHgemm __compat_hipblasHgemm
// Previous version of PyTorch were converting to rocBLAS instead of hipBLAS.
#define rocblas_operation_none HIPBLAS_OP_N
#define rocblas_hgemm __compat_hipblasHgemm
#endif
__forceinline__
__device__
half2
dot22_8
(
half2
(
&
dq
)[
4
],
const
half
*
a_ptr
,
const
half2
g_result
)
{
half2
result
=
{};
const
half2
*
a2_ptr
=
(
const
half2
*
)
a_ptr
;
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
)
result
=
__hfma2
(
dq
[
i
],
*
a2_ptr
++
,
result
);
return
__hadd2
(
result
,
g_result
);
}
__forceinline__
__device__
float
dot22_8_f
(
half2
(
&
dq
)[
4
],
const
half
*
a_ptr
)
{
half2
result
=
{};
const
half2
*
a2_ptr
=
(
const
half2
*
)
a_ptr
;
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
)
result
=
__hfma2
(
dq
[
i
],
*
a2_ptr
++
,
result
);
return
__half2float
(
__low2half
(
result
))
+
__half2float
(
__high2half
(
result
));
}
typedef
void
(
*
fp_gemm_half_q_half_gptq_kernel
)
(
const
half
*
,
const
uint32_t
*
,
const
uint32_t
*
,
const
half
*
,
half
*
,
const
int
,
const
int
,
const
int
,
const
int
,
const
int
*
);
template
<
bool
first_block
,
int
m_count
>
__global__
void
gemm_half_q_half_gptq_kernel
(
const
half
*
__restrict__
a
,
const
uint32_t
*
__restrict__
b_q_weight
,
const
uint32_t
*
__restrict__
b_gptq_qzeros
,
const
half
*
__restrict__
b_gptq_scales
,
half
*
__restrict__
c
,
const
int
size_m
,
const
int
size_n
,
const
int
size_k
,
const
int
groups
,
const
int
*
__restrict__
b_q_perm
)
{
MatrixView_half
a_
(
a
,
size_m
,
size_k
);
MatrixView_half_rw
c_
(
c
,
size_m
,
size_n
);
MatrixView_q4_row
b_gptq_qzeros_
(
b_gptq_qzeros
,
groups
,
size_n
);
MatrixView_half
b_gptq_scales_
(
b_gptq_scales
,
groups
,
size_n
);
int
t
=
threadIdx
.
x
;
// Block
int
offset_n
=
blockIdx
.
x
*
BLOCK_KN_SIZE
*
4
;
int
offset_m
=
blockIdx
.
y
*
m_count
;
int
offset_k
=
blockIdx
.
z
*
BLOCK_KN_SIZE
;
int
end_n
=
min
(
offset_n
+
BLOCK_KN_SIZE
*
4
,
size_n
);
int
end_m
=
min
(
offset_m
+
m_count
,
size_m
);
int
end_k
=
min
(
offset_k
+
BLOCK_KN_SIZE
,
size_k
);
int
n
=
offset_n
+
t
*
4
;
// Preload block_a
__shared__
half
block_a
[
m_count
][
BLOCK_KN_SIZE
];
if
(
offset_k
+
t
<
end_k
)
{
for
(
int
m
=
0
;
m
<
m_count
;
++
m
)
{
const
half
*
a_ptr
=
a_
.
item_ptr
(
offset_m
+
m
,
0
);
half
*
block_a_ptr
=
block_a
[
m
];
half
a0
;
if
(
b_q_perm
)
a0
=
a_ptr
[
b_q_perm
[
offset_k
+
t
]];
else
a0
=
a_ptr
[
offset_k
+
t
];
block_a_ptr
[
t
]
=
a0
;
}
}
// Zero output
if
(
n
>=
size_n
)
return
;
if
(
blockIdx
.
z
==
0
)
{
for
(
int
m
=
0
;
m
<
m_count
;
m
++
)
*
((
uint64_t
*
)
c_
.
item_ptr
(
offset_m
+
m
,
n
))
=
0
;
}
__syncthreads
();
// Find initial group
int
groupsize
=
size_k
/
groups
;
int
group
=
offset_k
/
groupsize
;
int
nextgroup
=
offset_k
+
groupsize
;
// a, b offset
int
qk
=
offset_k
/
(
32
/
4
);
const
uint32_t
*
b_ptr
=
b_q_weight
+
qk
*
size_n
+
n
;
const
half
*
a_ptr
=
&
block_a
[
0
][
0
];
int
a_stride
=
BLOCK_KN_SIZE
;
// Initial group
int
zeros
[
4
];
float
scales
[
4
];
half2
z1z16
[
4
][
2
];
half2
y1y16
[
4
][
2
];
b_gptq_qzeros_
.
item4
(
zeros
,
group
,
n
);
b_gptq_scales_
.
item4_f
(
scales
,
group
,
n
);
dequant_4bit_8_prep_zero
(
zeros
[
0
]
+
1
,
z1z16
[
0
],
y1y16
[
0
]);
dequant_4bit_8_prep_zero
(
zeros
[
1
]
+
1
,
z1z16
[
1
],
y1y16
[
1
]);
dequant_4bit_8_prep_zero
(
zeros
[
2
]
+
1
,
z1z16
[
2
],
y1y16
[
2
]);
dequant_4bit_8_prep_zero
(
zeros
[
3
]
+
1
,
z1z16
[
3
],
y1y16
[
3
]);
// Column result
float
block_c
[
m_count
][
4
]
=
{};
// Dequantize and multiply
int
k
=
offset_k
;
while
(
k
<
end_k
)
{
if
(
k
==
nextgroup
)
{
group
++
;
nextgroup
+=
groupsize
;
b_gptq_qzeros_
.
item4
(
zeros
,
group
,
n
);
b_gptq_scales_
.
item4_f
(
scales
,
group
,
n
);
dequant_4bit_8_prep_zero
(
zeros
[
0
]
+
1
,
z1z16
[
0
],
y1y16
[
0
]);
dequant_4bit_8_prep_zero
(
zeros
[
1
]
+
1
,
z1z16
[
1
],
y1y16
[
1
]);
dequant_4bit_8_prep_zero
(
zeros
[
2
]
+
1
,
z1z16
[
2
],
y1y16
[
2
]);
dequant_4bit_8_prep_zero
(
zeros
[
3
]
+
1
,
z1z16
[
3
],
y1y16
[
3
]);
}
#pragma unroll
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
const
int4
*
b_ptr4
=
(
int4
*
)
b_ptr
;
int4
load_int4
=
*
b_ptr4
;
half2
dq
[
4
][
4
];
dequant_4bit_8_gptq
(
load_int4
.
x
,
dq
[
0
],
z1z16
[
0
],
y1y16
[
0
],
size_n
,
false
);
dequant_4bit_8_gptq
(
load_int4
.
y
,
dq
[
1
],
z1z16
[
1
],
y1y16
[
1
],
size_n
,
false
);
dequant_4bit_8_gptq
(
load_int4
.
z
,
dq
[
2
],
z1z16
[
2
],
y1y16
[
2
],
size_n
,
false
);
dequant_4bit_8_gptq
(
load_int4
.
w
,
dq
[
3
],
z1z16
[
3
],
y1y16
[
3
],
size_n
,
false
);
#pragma unroll
for
(
int
m
=
0
;
m
<
m_count
;
m
++
)
{
block_c
[
m
][
0
]
=
fma
(
dot22_8_f
(
dq
[
0
],
a_ptr
+
m
*
a_stride
),
scales
[
0
],
block_c
[
m
][
0
]);
block_c
[
m
][
1
]
=
fma
(
dot22_8_f
(
dq
[
1
],
a_ptr
+
m
*
a_stride
),
scales
[
1
],
block_c
[
m
][
1
]);
block_c
[
m
][
2
]
=
fma
(
dot22_8_f
(
dq
[
2
],
a_ptr
+
m
*
a_stride
),
scales
[
2
],
block_c
[
m
][
2
]);
block_c
[
m
][
3
]
=
fma
(
dot22_8_f
(
dq
[
3
],
a_ptr
+
m
*
a_stride
),
scales
[
3
],
block_c
[
m
][
3
]);
}
b_ptr
+=
size_n
;
a_ptr
+=
8
;
}
k
+=
32
;
}
for
(
int
m
=
0
;
m
<
m_count
;
m
++
)
{
half2
*
out
=
(
half2
*
)
c_
.
item_ptr
(
offset_m
+
m
,
n
);
half2
result01
=
__halves2half2
(
__float2half_rn
(
block_c
[
m
][
0
]),
__float2half_rn
(
block_c
[
m
][
1
]));
half2
result23
=
__halves2half2
(
__float2half_rn
(
block_c
[
m
][
2
]),
__float2half_rn
(
block_c
[
m
][
3
]));
atomicAdd
(
out
,
result01
);
atomicAdd
(
out
+
1
,
result23
);
}
}
fp_gemm_half_q_half_gptq_kernel
pick_gemm_half_q_half_gptq_kernel
(
bool
first_block
,
const
int
m_count
)
{
#if BLOCK_M_SIZE_MAX >= 1
if
(
m_count
==
1
)
return
gemm_half_q_half_gptq_kernel
<
true
,
1
>
;
#endif
#if BLOCK_M_SIZE_MAX >= 2
if
(
m_count
==
2
)
return
gemm_half_q_half_gptq_kernel
<
true
,
2
>
;
#endif
#if BLOCK_M_SIZE_MAX >= 3
if
(
m_count
==
3
)
return
gemm_half_q_half_gptq_kernel
<
true
,
3
>
;
#endif
#if BLOCK_M_SIZE_MAX >= 4
if
(
m_count
==
4
)
return
gemm_half_q_half_gptq_kernel
<
true
,
4
>
;
#endif
#if BLOCK_M_SIZE_MAX >= 5
if
(
m_count
==
5
)
return
gemm_half_q_half_gptq_kernel
<
true
,
5
>
;
#endif
#if BLOCK_M_SIZE_MAX >= 6
if
(
m_count
==
6
)
return
gemm_half_q_half_gptq_kernel
<
true
,
6
>
;
#endif
#if BLOCK_M_SIZE_MAX >= 7
if
(
m_count
==
7
)
return
gemm_half_q_half_gptq_kernel
<
true
,
7
>
;
#endif
#if BLOCK_M_SIZE_MAX >= 8
if
(
m_count
==
8
)
return
gemm_half_q_half_gptq_kernel
<
true
,
8
>
;
#endif
return
NULL
;
}
void
gemm_half_q_half_cuda_part
(
const
half
*
a
,
const
uint32_t
*
b_q_weight
,
const
uint32_t
*
b_gptq_qzeros
,
const
half
*
b_gptq_scales
,
const
int
*
b_q_perm
,
half
*
c
,
int
size_m
,
int
size_n
,
int
size_k
,
int
m_count
,
int
groups
)
{
dim3
blockDim
,
gridDim
;
blockDim
.
x
=
BLOCK_KN_SIZE
;
blockDim
.
y
=
1
;
blockDim
.
z
=
1
;
gridDim
.
x
=
DIVIDE
(
size_n
,
BLOCK_KN_SIZE
*
4
);
gridDim
.
y
=
DIVIDE
(
size_m
,
m_count
);
gridDim
.
z
=
DIVIDE
(
size_k
,
BLOCK_KN_SIZE
);
fp_gemm_half_q_half_gptq_kernel
kernel
=
pick_gemm_half_q_half_gptq_kernel
(
true
,
m_count
);
kernel
<<<
gridDim
,
blockDim
>>>
(
a
,
b_q_weight
,
b_gptq_qzeros
,
b_gptq_scales
,
c
,
size_m
,
size_n
,
size_k
,
groups
,
b_q_perm
);
}
__global__
void
reconstruct_exllama_kernel
(
const
uint32_t
*
__restrict__
b_q_weight
,
const
int
*
__restrict__
b_q_perm
,
const
uint32_t
*
__restrict__
b_gptq_qzeros
,
const
half
*
__restrict__
b_gptq_scales
,
const
int
size_k
,
const
int
size_n
,
const
int
groups
,
half
*
__restrict__
b
)
{
MatrixView_half_rw
b_
(
b
,
size_k
,
size_n
);
MatrixView_q4_row
b_gptq_qzeros_
(
b_gptq_qzeros
,
groups
,
size_n
);
MatrixView_half
b_gptq_scales_
(
b_gptq_scales
,
groups
,
size_n
);
int
offset_k
=
BLOCK_KN_SIZE
*
blockIdx
.
y
;
int
offset_n
=
BLOCK_KN_SIZE
*
blockIdx
.
x
*
4
;
int
end_k
=
min
(
offset_k
+
BLOCK_KN_SIZE
,
size_k
);
// Preload remapping table
__shared__
int
perm
[
BLOCK_KN_SIZE
];
int
t
=
threadIdx
.
x
;
if
(
b_q_perm
)
{
if
(
offset_k
+
t
<
size_k
)
perm
[
t
]
=
b_q_perm
[
offset_k
+
t
];
}
// Column
int
n
=
offset_n
+
t
*
4
;
if
(
n
>=
size_n
)
return
;
// Find initial group
int
groupsize
=
size_k
/
groups
;
int
group
=
offset_k
/
groupsize
;
int
nextgroup
=
offset_k
+
groupsize
;
// b offset
int
qk
=
offset_k
/
(
32
/
4
);
const
uint32_t
*
b_ptr
=
b_q_weight
+
qk
*
size_n
+
n
;
// Initial zeros/scale
int
zeros
[
4
];
half2
scales
[
4
];
half2
z1z16
[
4
][
2
];
half2
y1y16
[
4
][
2
];
b_gptq_qzeros_
.
item4
(
zeros
,
group
,
n
);
b_gptq_scales_
.
item4_h2
(
scales
,
group
,
n
);
dequant_4bit_8_prep_zero
(
zeros
[
0
]
+
1
,
z1z16
[
0
],
y1y16
[
0
]);
dequant_4bit_8_prep_zero
(
zeros
[
1
]
+
1
,
z1z16
[
1
],
y1y16
[
1
]);
dequant_4bit_8_prep_zero
(
zeros
[
2
]
+
1
,
z1z16
[
2
],
y1y16
[
2
]);
dequant_4bit_8_prep_zero
(
zeros
[
3
]
+
1
,
z1z16
[
3
],
y1y16
[
3
]);
__syncthreads
();
int
k
=
offset_k
;
int
lk
=
0
;
while
(
k
<
end_k
)
{
if
(
k
==
nextgroup
)
{
group
++
;
nextgroup
+=
groupsize
;
b_gptq_qzeros_
.
item4
(
zeros
,
group
,
n
);
b_gptq_scales_
.
item4_h2
(
scales
,
group
,
n
);
dequant_4bit_8_prep_zero
(
zeros
[
0
]
+
1
,
z1z16
[
0
],
y1y16
[
0
]);
dequant_4bit_8_prep_zero
(
zeros
[
1
]
+
1
,
z1z16
[
1
],
y1y16
[
1
]);
dequant_4bit_8_prep_zero
(
zeros
[
2
]
+
1
,
z1z16
[
2
],
y1y16
[
2
]);
dequant_4bit_8_prep_zero
(
zeros
[
3
]
+
1
,
z1z16
[
3
],
y1y16
[
3
]);
}
for
(
int
p
=
0
;
p
<
4
;
p
++
)
{
half2
dq
[
4
][
4
];
const
int4
*
b_ptr4
=
(
int4
*
)
b_ptr
;
int4
load_int4
=
*
b_ptr4
;
dequant_4bit_8_gptq
(
load_int4
.
x
,
dq
[
0
],
z1z16
[
0
],
y1y16
[
0
],
size_n
,
false
);
dequant_4bit_8_gptq
(
load_int4
.
y
,
dq
[
1
],
z1z16
[
1
],
y1y16
[
1
],
size_n
,
false
);
dequant_4bit_8_gptq
(
load_int4
.
z
,
dq
[
2
],
z1z16
[
2
],
y1y16
[
2
],
size_n
,
false
);
dequant_4bit_8_gptq
(
load_int4
.
w
,
dq
[
3
],
z1z16
[
3
],
y1y16
[
3
],
size_n
,
false
);
b_ptr
+=
size_n
;
//half* dqh = (half*)dq;
if
(
b_q_perm
)
{
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
for
(
int
v
=
0
;
v
<
4
;
v
++
)
dq
[
v
][
j
]
=
__hmul2
(
scales
[
v
],
dq
[
v
][
j
]);
b_
.
set4
(
perm
[
lk
++
],
n
,
__low2half
(
dq
[
0
][
j
]),
__low2half
(
dq
[
1
][
j
]),
__low2half
(
dq
[
2
][
j
]),
__low2half
(
dq
[
3
][
j
]));
b_
.
set4
(
perm
[
lk
++
],
n
,
__high2half
(
dq
[
0
][
j
]),
__high2half
(
dq
[
1
][
j
]),
__high2half
(
dq
[
2
][
j
]),
__high2half
(
dq
[
3
][
j
]));
}
}
else
{
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
for
(
int
v
=
0
;
v
<
4
;
v
++
)
dq
[
v
][
j
]
=
__hmul2
(
scales
[
v
],
dq
[
v
][
j
]);
b_
.
set4
(
offset_k
+
lk
++
,
n
,
__low2half
(
dq
[
0
][
j
]),
__low2half
(
dq
[
1
][
j
]),
__low2half
(
dq
[
2
][
j
]),
__low2half
(
dq
[
3
][
j
]));
b_
.
set4
(
offset_k
+
lk
++
,
n
,
__high2half
(
dq
[
0
][
j
]),
__high2half
(
dq
[
1
][
j
]),
__high2half
(
dq
[
2
][
j
]),
__high2half
(
dq
[
3
][
j
]));
}
}
}
k
+=
32
;
}
}
void
reconstruct_exllama
(
const
uint32_t
*
b_q_weight
,
const
uint32_t
*
b_gptq_qzeros
,
const
half
*
b_gptq_scales
,
const
int
*
b_q_perm
,
half
*
out
,
int
height
,
int
width
,
int
groups
)
{
dim3
blockDim
,
gridDim
;
blockDim
.
x
=
BLOCK_KN_SIZE
;
blockDim
.
y
=
1
;
gridDim
.
y
=
DIVIDE
(
height
,
BLOCK_KN_SIZE
);
gridDim
.
x
=
DIVIDE
(
width
,
BLOCK_KN_SIZE
);
reconstruct_exllama_kernel
<<<
gridDim
,
blockDim
>>>
(
b_q_weight
,
b_q_perm
,
b_gptq_qzeros
,
b_gptq_scales
,
height
,
width
,
groups
,
out
);
}
__global__
void
gemm_half_q_half_alt_kernel
(
const
half2
*
__restrict__
vec
,
const
uint32_t
*
__restrict__
mat
,
half
*
__restrict__
mul
,
const
half
*
__restrict__
scales
,
const
uint32_t
*
__restrict__
zeros
,
const
int
*
__restrict__
g_idx
,
int
batch
,
int
height
,
int
width
)
{
int
zero_width
=
width
/
8
;
int
vec_height
=
height
*
4
;
const
int
blockwidth2
=
BLOCK_KN_SIZE
/
2
;
int
b
=
blockIdx
.
y
*
BLOCK_M_SIZE_MAX
;
int
b_end
=
min
(
BLOCK_M_SIZE_MAX
,
batch
-
b
);
int
h
=
BLOCK_KN_SIZE
*
blockIdx
.
z
/
8
;
int
h_end
=
min
(
BLOCK_KN_SIZE
/
8
,
height
-
h
)
*
4
;
int
w
=
BLOCK_KN_SIZE
*
blockIdx
.
x
+
threadIdx
.
x
;
__shared__
half2
blockvec
[
BLOCK_M_SIZE_MAX
][
blockwidth2
];
if
(
threadIdx
.
x
<
h_end
)
{
for
(
int
m
=
0
;
m
<
b_end
;
++
m
)
{
blockvec
[
m
][
threadIdx
.
x
]
=
vec
[(
m
+
b
)
*
vec_height
+
blockIdx
.
z
*
BLOCK_KN_SIZE
/
2
+
threadIdx
.
x
];
}
}
__shared__
half2
deq2
[
256
][
8
];
int
val
=
threadIdx
.
x
/
8
;
int
off
=
threadIdx
.
x
%
8
;
for
(;
val
<
256
;
val
+=
BLOCK_KN_SIZE
/
8
)
{
deq2
[
val
][
off
]
=
__halves2half2
(
__int2half_rn
(
val
&
0xF
),
__int2half_rn
(
val
>>
4
)
);
}
if
(
blockIdx
.
z
==
0
)
{
for
(
int
m
=
0
;
m
<
b_end
;
m
++
)
mul
[(
b
+
m
)
*
width
+
w
]
=
__int2half_rn
(
0
);
}
__syncthreads
();
int
i
=
width
*
h
+
w
;
int
g_h
=
h
*
8
;
int
k
=
0
;
int
z_w
=
w
/
8
;
int
z_mod
=
(
w
%
8
)
*
4
;
half2
res2
;
half
res
[
BLOCK_M_SIZE_MAX
]
=
{};
unsigned
int
tmp
;
while
(
k
<
h_end
)
{
tmp
=
mat
[
i
];
half2
scales_tmp
[
4
];
half2
zeros_tmp
[
4
];
for
(
int
tmp_k
=
0
;
tmp_k
<
4
;
tmp_k
++
)
{
int
g
=
g_idx
[
g_h
+
(
k
+
tmp_k
)
*
2
];
int
g2
=
g_idx
[
g_h
+
(
k
+
tmp_k
)
*
2
+
1
];
half
scale_f
=
scales
[
g
*
width
+
w
];
half
scale_f2
=
scales
[
g2
*
width
+
w
];
half2
scale
=
__halves2half2
(
scale_f
,
scale_f2
);
half2
zero
=
__halves2half2
(
__hmul
(
scale_f
,
__int2half_rn
(
-
((
zeros
[
g
*
zero_width
+
z_w
]
>>
z_mod
)
&
0xF
)
-
1
)),
__hmul
(
scale_f2
,
__int2half_rn
(
-
((
zeros
[
g2
*
zero_width
+
z_w
]
>>
z_mod
)
&
0xF
)
-
1
))
);
scales_tmp
[
tmp_k
]
=
scale
;
zeros_tmp
[
tmp_k
]
=
zero
;
}
for
(
int
m
=
0
;
m
<
b_end
;
m
++
)
{
res2
=
{};
res2
=
__hfma2
(
__hfma2
(
deq2
[(
tmp
>>
0
)
&
0xff
][
off
],
scales_tmp
[
0
],
zeros_tmp
[
0
]),
blockvec
[
m
][
k
+
0
],
res2
);
res2
=
__hfma2
(
__hfma2
(
deq2
[(
tmp
>>
8
)
&
0xff
][
off
],
scales_tmp
[
1
],
zeros_tmp
[
1
]),
blockvec
[
m
][
k
+
1
],
res2
);
res2
=
__hfma2
(
__hfma2
(
deq2
[(
tmp
>>
16
)
&
0xff
][
off
],
scales_tmp
[
2
],
zeros_tmp
[
2
]),
blockvec
[
m
][
k
+
2
],
res2
);
res2
=
__hfma2
(
__hfma2
(
deq2
[(
tmp
>>
24
)
&
0xff
][
off
],
scales_tmp
[
3
],
zeros_tmp
[
3
]),
blockvec
[
m
][
k
+
3
],
res2
);
res
[
m
]
=
__hadd
(
res
[
m
],
__hadd
(
res2
.
x
,
res2
.
y
));
}
i
+=
width
;
k
+=
4
;
}
for
(
int
m
=
0
;
m
<
b_end
;
m
++
)
{
atomicAdd
(
&
mul
[(
b
+
m
)
*
width
+
w
],
res
[
m
]);
}
}
void
gemm_half_q_half_alt
(
const
half
*
a
,
const
uint32_t
*
b_q_weight
,
const
uint32_t
*
b_gptq_qzeros
,
const
half
*
b_gptq_scales
,
const
int
*
b_g_idx
,
half
*
c
,
int
size_m
,
int
size_n
,
int
size_k
)
{
dim3
blockDim
,
gridDim
;
blockDim
.
x
=
BLOCK_KN_SIZE
;
blockDim
.
y
=
1
;
blockDim
.
z
=
1
;
gridDim
.
x
=
DIVIDE
(
size_n
,
BLOCK_KN_SIZE
);
gridDim
.
y
=
DIVIDE
(
size_m
,
BLOCK_M_SIZE_MAX
);
gridDim
.
z
=
DIVIDE
(
size_k
,
BLOCK_KN_SIZE
);
gemm_half_q_half_alt_kernel
<<<
gridDim
,
blockDim
>>>
(
(
const
half2
*
)
a
,
b_q_weight
,
c
,
b_gptq_scales
,
b_gptq_qzeros
,
b_g_idx
,
size_m
,
size_k
/
8
,
size_n
);
}
__global__
void
reconstruct_gptq_kernel
(
const
uint32_t
*
__restrict__
w
,
const
half
*
__restrict__
w_scales
,
const
uint32_t
*
__restrict__
w_zeros
,
const
int
*
__restrict__
g_idx
,
const
int
height
,
const
int
width
,
const
int
group
,
half
*
__restrict__
out
)
{
// Start of block
int
column
=
BLOCK_KN_SIZE
*
blockIdx
.
x
+
threadIdx
.
x
;
int
row
=
blockIdx
.
y
*
8
;
if
(
column
>=
width
)
return
;
// Views
MatrixView_q4_column
w_
(
w
,
height
,
width
);
MatrixView_half_rw
out_
(
out
,
height
,
width
);
MatrixView_half
w_scales_
(
w_scales
,
group
,
width
);
MatrixView_q4_row
w_zeros_
(
w_zeros
,
group
,
width
);
uint32_t
w_read
=
w_
.
item_uint32_t
(
row
,
column
);
half
*
out_ptr
=
out_
.
item_ptr
(
row
,
column
);
#pragma unroll
for
(
int
s
=
0
;
s
<
32
;
s
+=
4
)
{
int
group
=
g_idx
[
row
+
s
/
4
];
half
w_scale
=
w_scales_
.
item
(
group
,
column
);
uint32_t
w_zero
=
w_zeros_
.
item
(
group
,
column
)
+
1
;
half
w_item
=
__hmul
(
__int2half_rn
((
int
)((
w_read
>>
s
)
&
0x0f
)
-
w_zero
),
w_scale
);
*
out_ptr
=
w_item
;
out_ptr
+=
out_
.
width
;
}
}
void
reconstruct_gptq
(
const
uint32_t
*
b_q_weight
,
const
uint32_t
*
b_gptq_qzeros
,
const
half
*
b_gptq_scales
,
const
int
*
b_g_idx
,
half
*
out
,
int
height
,
int
width
,
int
groups
)
{
dim3
blockDim
,
gridDim
;
blockDim
.
x
=
BLOCK_KN_SIZE
;
blockDim
.
y
=
1
;
gridDim
.
y
=
DIVIDE
(
height
,
8
);
gridDim
.
x
=
DIVIDE
(
width
,
BLOCK_KN_SIZE
);
reconstruct_gptq_kernel
<<<
gridDim
,
blockDim
>>>
(
b_q_weight
,
b_gptq_scales
,
b_gptq_qzeros
,
b_g_idx
,
height
,
width
,
groups
,
out
);
}
void
gemm_half_q_half_cuda
(
cublasHandle_t
cublas_handle
,
const
half
*
a
,
const
uint32_t
*
b_q_weight
,
const
uint32_t
*
b_gptq_qzeros
,
const
half
*
b_gptq_scales
,
const
int
*
b_g_idx
,
half
*
c
,
half
*
temp_dq
,
int
size_m
,
int
size_n
,
int
size_k
,
int
groups
,
bool
use_exllama
)
{
if
((
use_exllama
&&
size_m
>
MAX_Q_GEMM_ROWS
)
||
(
!
use_exllama
&&
size_m
>
MAX_ALT_GEMM_ROWS
))
{
// Reconstruct FP16 matrix, then cuBLAS
if
(
use_exllama
)
{
reconstruct_exllama
(
b_q_weight
,
b_gptq_qzeros
,
b_gptq_scales
,
b_g_idx
,
temp_dq
,
size_k
,
size_n
,
groups
);
}
else
{
reconstruct_gptq
(
b_q_weight
,
b_gptq_qzeros
,
b_gptq_scales
,
b_g_idx
,
temp_dq
,
size_k
,
size_n
,
groups
);
}
const
half
alpha
=
__float2half
(
1.0
f
);
const
half
beta
=
__float2half
(
0.0
f
);
cublasHgemm
(
cublas_handle
,
CUBLAS_OP_N
,
CUBLAS_OP_N
,
size_n
,
size_m
,
size_k
,
&
alpha
,
temp_dq
,
size_n
,
a
,
size_k
,
&
beta
,
c
,
size_n
);
}
else
if
(
use_exllama
)
{
// Quantized matmul
int
max_chunks
=
size_m
/
BLOCK_M_SIZE_MAX
;
int
last_chunk
=
max_chunks
*
BLOCK_M_SIZE_MAX
;
int
last_chunk_size
=
size_m
-
last_chunk
;
if
(
max_chunks
)
{
gemm_half_q_half_cuda_part
(
a
,
b_q_weight
,
b_gptq_qzeros
,
b_gptq_scales
,
b_g_idx
,
c
,
last_chunk
,
size_n
,
size_k
,
BLOCK_M_SIZE_MAX
,
groups
);
}
if
(
last_chunk_size
)
{
gemm_half_q_half_cuda_part
(
a
+
last_chunk
*
size_k
,
b_q_weight
,
b_gptq_qzeros
,
b_gptq_scales
,
b_g_idx
,
c
+
last_chunk
*
size_n
,
last_chunk_size
,
size_n
,
size_k
,
last_chunk_size
,
groups
);
}
}
else
{
gemm_half_q_half_alt
(
a
,
b_q_weight
,
b_gptq_qzeros
,
b_gptq_scales
,
b_g_idx
,
c
,
size_m
,
size_n
,
size_k
);
}
}
__global__
void
shuffle_kernel
(
uint32_t
*
__restrict__
b_q_weight
,
const
int
size_k
,
const
int
size_n
)
{
int
n
=
blockIdx
.
x
*
THREADS_X
+
threadIdx
.
x
;
if
(
n
>=
size_n
)
return
;
int
k
=
0
;
uint32_t
*
b_ptr
=
b_q_weight
+
n
;
while
(
k
<
size_k
)
{
shuffle_4bit_8
(
b_ptr
,
size_n
);
b_ptr
+=
1
*
size_n
;
k
+=
8
;
}
}
__global__
void
make_sequential_kernel
(
const
uint32_t
*
__restrict__
w
,
uint32_t
*
__restrict__
w_new
,
const
int
*
__restrict__
q_perm
,
const
int
w_height
,
const
int
w_width
)
{
const
uint64_t
*
w2
=
(
uint64_t
*
)
w
;
uint64_t
*
w_new2
=
(
uint64_t
*
)
w_new
;
int
w2_stride
=
w_width
>>
1
;
int
w2_column
=
THREADS_X
*
blockIdx
.
x
+
threadIdx
.
x
;
if
(
w2_column
>=
w2_stride
)
return
;
int
w_new2_row
=
blockIdx
.
y
;
int
q_perm_idx
=
w_new2_row
<<
3
;
uint64_t
dst
=
0
;
#pragma unroll
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
int
source_row
=
q_perm
[
q_perm_idx
++
];
int
w2_row
=
source_row
>>
3
;
int
w2_subrow
=
source_row
&
0x07
;
int
w2_row_shift
=
w2_subrow
<<
2
;
int
wnew2_row_shift
=
i
<<
2
;
uint64_t
src
=
w2
[
w2_row
*
w2_stride
+
w2_column
];
src
>>=
w2_row_shift
;
src
&=
0x0000000f0000000f
;
src
<<=
wnew2_row_shift
;
dst
|=
src
;
}
w_new2
[
w_new2_row
*
w2_stride
+
w2_column
]
=
dst
;
}
void
shuffle_exllama_weight
(
uint32_t
*
q_weight
,
int
*
q_perm
,
int
height
,
int
width
)
{
if
(
q_perm
)
{
uint32_t
*
new_qweight
=
NULL
;
cudaMalloc
(
&
new_qweight
,
height
/
8
*
width
*
sizeof
(
uint32_t
));
dim3
blockDim
,
gridDim
;
blockDim
.
x
=
THREADS_X
;
blockDim
.
y
=
1
;
gridDim
.
x
=
DIVIDE
(
width
,
THREADS_X
);
gridDim
.
y
=
height
/
8
;
make_sequential_kernel
<<<
gridDim
,
blockDim
>>>
(
q_weight
,
new_qweight
,
q_perm
,
height
/
8
,
width
);
// Replace qweights
cudaMemcpyAsync
(
q_weight
,
new_qweight
,
height
/
8
*
width
*
sizeof
(
uint32_t
),
cudaMemcpyDeviceToDevice
);
// Cleanup
cudaDeviceSynchronize
();
cudaFree
(
new_qweight
);
}
dim3
blockDim
,
gridDim
;
blockDim
.
x
=
THREADS_X
;
blockDim
.
y
=
1
;
gridDim
.
x
=
DIVIDE
(
width
,
THREADS_X
);
gridDim
.
y
=
1
;
shuffle_kernel
<<<
gridDim
,
blockDim
>>>
(
q_weight
,
height
,
width
);
}
}
// namespace gptq
}
// namespace vllm
torch
::
Tensor
gptq_gemm
(
torch
::
Tensor
a
,
torch
::
Tensor
b_q_weight
,
torch
::
Tensor
b_gptq_qzeros
,
torch
::
Tensor
b_gptq_scales
,
torch
::
Tensor
b_g_idx
,
bool
use_exllama
)
{
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
a
));
auto
options
=
torch
::
TensorOptions
().
dtype
(
a
.
dtype
()).
device
(
a
.
device
());
at
::
Tensor
c
=
torch
::
empty
({
a
.
size
(
0
),
b_q_weight
.
size
(
1
)},
options
);
at
::
Tensor
temp_dq
=
torch
::
empty
({
b_q_weight
.
size
(
0
)
*
8
,
b_q_weight
.
size
(
1
)},
options
);
vllm
::
gptq
::
gemm_half_q_half_cuda
(
at
::
cuda
::
getCurrentCUDABlasHandle
(),
(
const
half
*
)
a
.
data_ptr
(),
(
const
uint32_t
*
)
b_q_weight
.
data_ptr
(),
(
const
uint32_t
*
)
b_gptq_qzeros
.
data_ptr
(),
(
const
half
*
)
b_gptq_scales
.
data_ptr
(),
b_g_idx
.
device
().
is_meta
()
?
NULL
:
(
const
int
*
)
b_g_idx
.
data_ptr
(),
(
half
*
)
c
.
data_ptr
(),
(
half
*
)
temp_dq
.
data_ptr
(),
c
.
size
(
0
),
// m
c
.
size
(
1
),
// n
a
.
size
(
1
),
// k
b_gptq_qzeros
.
size
(
0
),
// group number
use_exllama
);
return
c
;
}
void
gptq_shuffle
(
torch
::
Tensor
q_weight
,
torch
::
Tensor
q_perm
)
{
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
q_weight
));
vllm
::
gptq
::
shuffle_exllama_weight
(
(
uint32_t
*
)
q_weight
.
data_ptr
(),
q_perm
.
device
().
is_meta
()
?
NULL
:
(
int
*
)
q_perm
.
data_ptr
(),
q_weight
.
size
(
0
)
*
8
,
q_weight
.
size
(
1
)
);
}
csrc/quantization/gptq/qdq_4.cuh
0 → 100644
View file @
0fbfc4b8
/*
Copied from https://github.com/turboderp/exllamav2
*/
#ifndef _qdq_4_cuh
#define _qdq_4_cuh
#include "qdq_util.cuh"
namespace
vllm
{
namespace
gptq
{
// Permutation:
//
// 77775555 33331111 66664444 22220000
__forceinline__
__device__
void
shuffle_4bit_8
(
uint32_t
*
q
,
int
stride
)
{
uint32_t
qa
=
q
[
0
];
uint32_t
qb
=
0
;
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
uint32_t
qa0
=
qa
&
0x0f
;
uint32_t
qa1
=
(
qa
&
0xf0
)
>>
4
;
qa
>>=
8
;
qb
|=
(
qa1
<<
(
i
*
4
+
16
));
qb
|=
(
qa0
<<
(
i
*
4
));
}
q
[
0
]
=
qb
;
}
__forceinline__
__device__
void
dequant_4bit_8
(
const
uint32_t
q_0
,
half2
(
&
dq
)[
4
],
int
stride
)
{
const
uint32_t
c0
=
0x64006400
;
const
half
y16_
=
__float2half_rn
(
1.0
f
/
16.0
f
);
const
half2
y16
=
__halves2half2
(
y16_
,
y16_
);
const
half
z1_
=
__float2half_rn
(
-
1024.0
f
-
8.0
f
);
const
half
z16_
=
__float2half_rn
(
-
1024.0
f
/
16.0
f
-
8.0
f
);
const
half2
z1
=
__halves2half2
(
z1_
,
z1_
);
const
half2
z16
=
__halves2half2
(
z16_
,
z16_
);
uint32_t
qa
=
q_0
;
half2_uint32
q0
((
qa
&
0x000f000f
)
|
c0
);
// half2(q[ 0], q[ 1]) + 1024
half2_uint32
q1
((
qa
&
0x00f000f0
)
|
c0
);
// half2(q[ 2], q[ 3]) * 16 + 1024
qa
>>=
8
;
half2_uint32
q2
((
qa
&
0x000f000f
)
|
c0
);
// half2(q[ 4], q[ 5]) + 1024
half2_uint32
q3
((
qa
&
0x00f000f0
)
|
c0
);
// half2(q[ 6], q[ 7]) * 16 + 1024
dq
[
0
]
=
__hadd2
(
q0
.
as_half2
,
z1
);
dq
[
1
]
=
__hfma2
(
q1
.
as_half2
,
y16
,
z16
);
dq
[
2
]
=
__hadd2
(
q2
.
as_half2
,
z1
);
dq
[
3
]
=
__hfma2
(
q3
.
as_half2
,
y16
,
z16
);
}
__forceinline__
__device__
void
dequant_4bit_8_prep_zero_scale
(
const
uint32_t
zero
,
const
half
scale
,
half2
(
&
z1z16
)[
2
],
half2
(
&
y1y16
)[
2
]
)
{
half_uint16
z1
(
0xe400
|
zero
);
// half(-1024.0f - zero);
half
z16
=
__hsub
(
__int2half_rn
(
-
64
),
__int2half_rn
(
zero
));
half2
scale2
=
__half2half2
(
scale
);
z1z16
[
0
]
=
__hmul2
(
scale2
,
__half2half2
(
z1
.
as_half
));
z1z16
[
1
]
=
__hmul2
(
scale2
,
__half2half2
(
z16
));
const
half
y1
=
__float2half_rn
(
1.0
f
);
const
half
y16
=
__float2half_rn
(
1.0
f
/
16.0
f
);
y1y16
[
0
]
=
__hmul2
(
scale2
,
__half2half2
(
y1
));
y1y16
[
1
]
=
__hmul2
(
scale2
,
__half2half2
(
y16
));
}
__forceinline__
__device__
void
dequant_4bit_8_prep_zero
(
const
uint32_t
zero
,
half2
(
&
z1z16
)[
2
],
half2
(
&
y1y16
)[
2
]
)
{
half_uint16
z1
(
0xe400
|
zero
);
// half(-1024.0f - zero);
half
z16
=
__hsub
(
__int2half_rn
(
-
64
),
__int2half_rn
(
zero
));
z1z16
[
0
]
=
__half2half2
(
z1
.
as_half
);
z1z16
[
1
]
=
__half2half2
(
z16
);
const
half
y1
=
__float2half_rn
(
1.0
f
);
const
half
y16
=
__float2half_rn
(
1.0
f
/
16.0
f
);
y1y16
[
0
]
=
__half2half2
(
y1
);
y1y16
[
1
]
=
__half2half2
(
y16
);
}
__forceinline__
__device__
void
dequant_4bit_8_gptq
(
const
uint32_t
q_0
,
half2
(
&
dq
)[
4
],
half2
(
&
z1z16
)[
2
],
half2
(
&
y1y16
)[
2
],
int
stride
,
bool
scaled
)
{
const
uint32_t
c0
=
0x64006400
;
uint32_t
qa
=
q_0
;
half2_uint32
q0
((
qa
&
0x000f000f
)
|
c0
);
// half2( q[0] + 1024, q[1] + 1024 )
half2_uint32
q1
((
qa
&
0x00f000f0
)
|
c0
);
// half2( q[2] * 16 + 1024, q[3] * 16 + 1024 )
qa
>>=
8
;
half2_uint32
q2
((
qa
&
0x000f000f
)
|
c0
);
// half2( q[4] + 1024, q[5] + 1024 )
half2_uint32
q3
((
qa
&
0x00f000f0
)
|
c0
);
// half2( q[6] * 16 + 1024, q[7] * 16 + 1024 )
if
(
scaled
)
{
dq
[
0
]
=
__hfma2
(
q0
.
as_half2
,
y1y16
[
0
],
z1z16
[
0
]);
// half2( q[0] * s - z * s, q[1] * s - z * s)
dq
[
1
]
=
__hfma2
(
q1
.
as_half2
,
y1y16
[
1
],
z1z16
[
1
]);
// half2( q[2] * s - z * s, q[3] * s - z * s)
dq
[
2
]
=
__hfma2
(
q2
.
as_half2
,
y1y16
[
0
],
z1z16
[
0
]);
dq
[
3
]
=
__hfma2
(
q3
.
as_half2
,
y1y16
[
1
],
z1z16
[
1
]);
}
else
{
dq
[
0
]
=
__hadd2
(
q0
.
as_half2
,
z1z16
[
0
]);
// half2( q[0] - z, q[1] - z )
dq
[
1
]
=
__hfma2
(
q1
.
as_half2
,
y1y16
[
1
],
z1z16
[
1
]);
// half2( q[2] - z, q[3] - z )
dq
[
2
]
=
__hadd2
(
q2
.
as_half2
,
z1z16
[
0
]);
// half2( q[4] - z, q[5] - z )
dq
[
3
]
=
__hfma2
(
q3
.
as_half2
,
y1y16
[
1
],
z1z16
[
1
]);
// half2( q[6] - z, q[7] - z )
}
}
}
// namespace gptq
}
// namespace vllm
#else
namespace
vllm
{
namespace
gptq
{
__forceinline__
__device__
void
shuffle_4bit_8
(
uint32_t
*
q
,
int
stride
)
{
}
__forceinline__
__device__
void
dequant_4bit_8
(
const
uint32_t
q_0
,
half2
(
&
dq
)[
4
],
int
stride
)
{
half
dqh
[
8
];
for
(
int
i
=
0
;
i
<
8
;
i
++
)
dqh
[
i
]
=
dq_ns
(
exb
(
q_0
,
i
*
4
,
0x0f
),
8
);
for
(
int
i
=
0
;
i
<
4
;
i
++
)
dq
[
i
]
=
__halves2half2
(
dqh
[
i
*
2
],
dqh
[
i
*
2
+
1
]);
}
__forceinline__
__device__
void
dequant_4bit_8_prep_zero_scale
(
const
uint32_t
zero
,
const
half
scale
,
half2
(
&
z1
)[
2
],
half2
(
&
y1
)[
2
]
)
{
half
z
=
__int2half_rn
(
-
((
int
)
zero
));
z
=
__hmul
(
z
,
scale
);
z1
[
0
]
=
__half2half2
(
z
);
y1
[
0
]
=
__half2half2
(
scale
);
}
__forceinline__
__device__
void
dequant_4bit_8_prep_zero
(
const
uint32_t
zero
,
half2
(
&
z1
)[
2
],
half2
(
&
y1
)[
2
]
)
{
half
z
=
__int2half_rn
(
-
((
int
)
zero
));
z1
[
0
]
=
__half2half2
(
z
);
}
__forceinline__
__device__
void
dequant_4bit_8_gptq
(
const
uint32_t
q_0
,
half2
(
&
dq
)[
4
],
half2
(
&
z1
)[
2
],
half2
(
&
y1
)[
2
],
int
stride
,
bool
scaled
)
{
half2
dqh2
[
8
];
uint32_t
qa
=
q_0
;
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
half
d0
=
__int2half_rn
(
qa
&
0x0f
);
qa
>>=
4
;
half
d1
=
__int2half_rn
(
qa
&
0x0f
);
qa
>>=
4
;
dqh2
[
i
]
=
__halves2half2
(
d0
,
d1
);
}
if
(
scaled
)
{
dq
[
0
]
=
__hfma2
(
dqh2
[
0
],
y1
[
0
],
z1
[
0
]);
dq
[
1
]
=
__hfma2
(
dqh2
[
1
],
y1
[
0
],
z1
[
0
]);
dq
[
2
]
=
__hfma2
(
dqh2
[
2
],
y1
[
0
],
z1
[
0
]);
dq
[
3
]
=
__hfma2
(
dqh2
[
3
],
y1
[
0
],
z1
[
0
]);
}
else
{
dq
[
0
]
=
__hadd2
(
dqh2
[
0
],
z1
[
0
]);
dq
[
1
]
=
__hadd2
(
dqh2
[
1
],
z1
[
0
]);
dq
[
2
]
=
__hadd2
(
dqh2
[
2
],
z1
[
0
]);
dq
[
3
]
=
__hadd2
(
dqh2
[
3
],
z1
[
0
]);
}
}
}
// namespace gptq
}
// namespace vllm
#endif
csrc/quantization/gptq/qdq_util.cuh
0 → 100644
View file @
0fbfc4b8
/*
Copied from https://github.com/turboderp/exllamav2
*/
#ifndef _qdq_util_cuh
#define _qdq_util_cuh
namespace
vllm
{
namespace
gptq
{
union
half2_uint32
{
uint32_t
as_uint32
;
half2
as_half2
;
__device__
half2_uint32
(
uint32_t
val
)
:
as_uint32
(
val
)
{}
__device__
half2_uint32
(
half2
val
)
:
as_half2
(
val
)
{}
};
union
half_uint16
{
uint16_t
as_uint16
;
half
as_half
;
__device__
half_uint16
(
uint16_t
val
)
:
as_uint16
(
val
)
{}
__device__
half_uint16
(
half
val
)
:
as_half
(
val
)
{}
};
// Max_scale premultiplied by 1/256
__forceinline__
__device__
half
dq_scale
(
const
int
qs
,
const
half
max_scale
)
{
int
qs_i
=
qs
+
1
;
half
qs_h
=
__int2half_rn
(
qs_i
*
qs_i
);
qs_h
=
__hmul
(
qs_h
,
max_scale
);
return
qs_h
;
}
__forceinline__
__device__
half
dq
(
const
int
q
,
const
int
qzero
,
const
half
scale
)
{
return
__hmul
(
__int2half_rn
(
q
-
qzero
),
scale
);
}
__forceinline__
__device__
half
dq_ns
(
const
int
q
,
const
int
qzero
)
{
//return __hsub(__int2half_rn(q), __int2half_rn(qzero));
return
__int2half_rn
(
q
-
qzero
);
}
__forceinline__
__device__
int
exb
(
const
uint32_t
q
,
const
int
shift
,
const
int
mask
)
{
return
(
int
)((
q
>>
shift
)
&
mask
);
}
__forceinline__
__device__
int
exb
(
const
uint32_t
q1
,
const
uint32_t
q0
,
const
int
shift
,
const
int
mask
)
{
return
(
int
)(
__funnelshift_rc
(
q0
,
q1
,
shift
)
&
mask
);
}
}
// namespace gptq
}
// namespace vllm
#endif
setup.py
View file @
0fbfc4b8
...
...
@@ -219,6 +219,7 @@ vllm_extension_sources = [
"csrc/activation_kernels.cu"
,
"csrc/layernorm_kernels.cu"
,
"csrc/quantization/squeezellm/quant_cuda_kernel.cu"
,
"csrc/quantization/gptq/q_gemm.cu"
,
"csrc/cuda_utils_kernels.cu"
,
"csrc/pybind.cpp"
,
]
...
...
vllm/config.py
View file @
0fbfc4b8
...
...
@@ -142,7 +142,7 @@ class ModelConfig:
self
.
tokenizer_mode
=
tokenizer_mode
def
_verify_quantization
(
self
)
->
None
:
supported_quantization
=
[
"awq"
,
"squeezellm"
]
supported_quantization
=
[
"awq"
,
"gptq"
,
"squeezellm"
]
rocm_not_supported_quantization
=
[
"awq"
]
if
self
.
quantization
is
not
None
:
self
.
quantization
=
self
.
quantization
.
lower
()
...
...
vllm/engine/arg_utils.py
View file @
0fbfc4b8
...
...
@@ -179,7 +179,7 @@ class EngineArgs:
parser
.
add_argument
(
'--quantization'
,
'-q'
,
type
=
str
,
choices
=
[
'awq'
,
'squeezellm'
,
None
],
choices
=
[
'awq'
,
'gptq'
,
'squeezellm'
,
None
],
default
=
None
,
help
=
'Method used to quantize the weights'
)
return
parser
...
...
vllm/entrypoints/llm.py
View file @
0fbfc4b8
...
...
@@ -38,8 +38,9 @@ class LLM:
However, if the `torch_dtype` in the config is `float32`, we will
use `float16` instead.
quantization: The method used to quantize the model weights. Currently,
we support "awq". If None, we assume the model weights are not
quantized and use `dtype` to determine the data type of the weights.
we support "awq", "gptq" and "squeezellm". If None, we assume the
model weights are not quantized and use `dtype` to determine the
data type of the weights.
revision: The specific model version to use. It can be a branch name,
a tag name, or a commit id.
tokenizer_revision: The specific tokenizer version to use. It can be a
...
...
vllm/model_executor/layers/linear.py
View file @
0fbfc4b8
from
abc
import
ABC
,
abstractmethod
from
typing
import
Dict
,
List
,
Optional
from
typing
import
Any
,
Dict
,
List
,
Optional
import
torch
import
torch.nn.functional
as
F
...
...
@@ -21,8 +21,10 @@ class LinearMethodBase(ABC):
"""Base class for different (maybe quantized) linear methods."""
@
abstractmethod
def
create_weights
(
self
,
input_size
:
int
,
output_size
:
int
,
params_dtype
:
torch
.
dtype
)
->
Dict
[
str
,
torch
.
Tensor
]:
def
create_weights
(
self
,
input_size_per_partition
:
int
,
output_size_per_partition
:
int
,
input_size
:
int
,
output_size
:
int
,
params_dtype
:
torch
.
dtype
)
->
Dict
[
str
,
Any
]:
"""Create weights for a linear layer."""
raise
NotImplementedError
...
...
@@ -46,10 +48,12 @@ class UnquantizedLinearMethod(LinearMethodBase):
def
__init__
(
self
,
separate_bias_add
:
bool
=
False
):
self
.
separate_bias_add
=
separate_bias_add
def
create_weights
(
self
,
input_size
:
int
,
output_size
:
int
,
params_dtype
:
torch
.
dtype
)
->
Dict
[
str
,
torch
.
Tensor
]:
weight
=
Parameter
(
torch
.
empty
(
output_size
,
input_size
,
def
create_weights
(
self
,
input_size_per_partition
:
int
,
output_size_per_partition
:
int
,
input_size
:
int
,
output_size
:
int
,
params_dtype
:
torch
.
dtype
)
->
Dict
[
str
,
Any
]:
weight
=
Parameter
(
torch
.
empty
(
output_size_per_partition
,
input_size_per_partition
,
device
=
torch
.
cuda
.
current_device
(),
dtype
=
params_dtype
),
requires_grad
=
False
)
...
...
@@ -102,9 +106,11 @@ class ReplicatedLinear(torch.nn.Module):
linear_method
=
UnquantizedLinearMethod
()
self
.
linear_method
=
linear_method
self
.
linear_weights
=
self
.
linear_method
.
create_weights
(
self
.
input_size
,
self
.
output_size
,
self
.
params_dtype
)
self
.
input_size
,
self
.
output_size
,
self
.
input_size
,
self
.
output_size
,
self
.
params_dtype
)
for
name
,
weight
in
self
.
linear_weights
.
items
():
self
.
register_parameter
(
name
,
weight
)
if
isinstance
(
weight
,
torch
.
Tensor
):
self
.
register_parameter
(
name
,
weight
)
if
bias
:
self
.
bias
=
Parameter
(
torch
.
empty
(
self
.
output_size
,
...
...
@@ -168,10 +174,12 @@ class ColumnParallelLinear(torch.nn.Module):
linear_method
=
UnquantizedLinearMethod
()
self
.
linear_method
=
linear_method
self
.
linear_weights
=
self
.
linear_method
.
create_weights
(
self
.
input_size
,
self
.
output_size_per_partition
,
self
.
params_dtype
)
self
.
input_size
,
self
.
output_size_per_partition
,
self
.
input_size
,
self
.
output_size
,
self
.
params_dtype
)
for
name
,
weight
in
self
.
linear_weights
.
items
():
self
.
register_parameter
(
name
,
weight
)
set_weight_attrs
(
weight
,
{
"weight_loader"
:
self
.
weight_loader
})
if
isinstance
(
weight
,
torch
.
Tensor
):
self
.
register_parameter
(
name
,
weight
)
set_weight_attrs
(
weight
,
{
"weight_loader"
:
self
.
weight_loader
})
if
bias
:
self
.
bias
=
Parameter
(
torch
.
empty
(
self
.
output_size_per_partition
,
...
...
@@ -295,10 +303,12 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
loaded_weight
=
loaded_weight
.
narrow
(
output_dim
,
start_idx
,
shard_size
)
else
:
logger
.
warning
(
"Loading a weight without `output_dim` attribute in "
"MergedColumnParallelLinear, assume the weight is "
"the same for all partitions."
)
ignore_warning
=
getattr
(
param
,
"ignore_warning"
,
False
)
if
not
ignore_warning
:
logger
.
warning
(
"Loading a weight without `output_dim` attribute in "
"MergedColumnParallelLinear, assume the weight is "
"the same for all partitions."
)
assert
param_data
.
shape
==
loaded_weight
.
shape
param_data
.
copy_
(
loaded_weight
)
...
...
@@ -418,10 +428,12 @@ class QKVParallelLinear(ColumnParallelLinear):
loaded_weight
=
loaded_weight
.
narrow
(
output_dim
,
start_idx
,
shard_size
)
else
:
logger
.
warning
(
"Loading a weight without `output_dim` attribute in "
"QKVParallelLinear, assume the weight is the same "
"for all partitions."
)
ignore_warning
=
getattr
(
param
,
"ignore_warning"
,
False
)
if
not
ignore_warning
:
logger
.
warning
(
"Loading a weight without `output_dim` attribute in "
"QKVParallelLinear, assume the weight is the same "
"for all partitions."
)
assert
param_data
.
shape
==
loaded_weight
.
shape
param_data
.
copy_
(
loaded_weight
)
...
...
@@ -481,10 +493,12 @@ class RowParallelLinear(torch.nn.Module):
linear_method
=
UnquantizedLinearMethod
()
self
.
linear_method
=
linear_method
self
.
linear_weights
=
self
.
linear_method
.
create_weights
(
self
.
input_size_per_partition
,
self
.
output_size
,
self
.
params_dtype
)
self
.
input_size_per_partition
,
self
.
output_size
,
self
.
input_size
,
self
.
output_size
,
self
.
params_dtype
)
for
name
,
weight
in
self
.
linear_weights
.
items
():
self
.
register_parameter
(
name
,
weight
)
set_weight_attrs
(
weight
,
{
"weight_loader"
:
self
.
weight_loader
})
if
isinstance
(
weight
,
torch
.
Tensor
):
self
.
register_parameter
(
name
,
weight
)
set_weight_attrs
(
weight
,
{
"weight_loader"
:
self
.
weight_loader
})
if
not
reduce_results
and
(
bias
and
not
skip_bias_add
):
raise
ValueError
(
"When not reduce the results, adding bias to the "
...
...
vllm/model_executor/layers/quantization/__init__.py
View file @
0fbfc4b8
from
typing
import
Type
from
vllm.model_executor.layers.quantization.base_config
import
QuantizationConfig
from
vllm.model_executor.layers.quantization.awq
import
AWQConfig
from
vllm.model_executor.layers.quantization.gptq
import
GPTQConfig
from
vllm.model_executor.layers.quantization.squeezellm
import
SqueezeLLMConfig
from
vllm.model_executor.layers.quantization.base_config
import
QuantizationConfig
_QUANTIZATION_CONFIG_REGISTRY
=
{
"awq"
:
AWQConfig
,
"gptq"
:
GPTQConfig
,
"squeezellm"
:
SqueezeLLMConfig
,
}
...
...
vllm/model_executor/layers/quantization/awq.py
View file @
0fbfc4b8
...
...
@@ -77,14 +77,16 @@ class AWQLinearMethod(LinearMethodBase):
def
__init__
(
self
,
quant_config
:
AWQConfig
):
self
.
quant_config
=
quant_config
def
create_weights
(
self
,
input_size
:
int
,
output_size
:
int
,
params_dtype
:
torch
.
dtype
)
->
Dict
[
str
,
torch
.
Tensor
]:
if
input_size
%
self
.
quant_config
.
group_size
!=
0
:
def
create_weights
(
self
,
input_size_per_partition
:
int
,
output_size_per_partition
:
int
,
input_size
:
int
,
output_size
:
int
,
params_dtype
:
torch
.
dtype
)
->
Dict
[
str
,
Any
]:
if
input_size_per_partition
%
self
.
quant_config
.
group_size
!=
0
:
raise
ValueError
(
"The input size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size."
)
if
output_size
%
self
.
quant_config
.
pack_factor
!=
0
:
if
output_size
_per_partition
%
self
.
quant_config
.
pack_factor
!=
0
:
raise
ValueError
(
"The output size is not aligned with the quantized "
"weight shape. This can be caused by too large "
...
...
@@ -92,8 +94,8 @@ class AWQLinearMethod(LinearMethodBase):
qweight
=
Parameter
(
torch
.
empty
(
input_size
,
output_size
//
self
.
quant_config
.
pack_factor
,
input_size
_per_partition
,
output_size
_per_partition
//
self
.
quant_config
.
pack_factor
,
device
=
"cuda"
,
dtype
=
torch
.
int32
,
),
...
...
@@ -108,8 +110,8 @@ class AWQLinearMethod(LinearMethodBase):
})
qzeros
=
Parameter
(
torch
.
empty
(
input_size
//
self
.
quant_config
.
group_size
,
output_size
//
self
.
quant_config
.
pack_factor
,
input_size
_per_partition
//
self
.
quant_config
.
group_size
,
output_size
_per_partition
//
self
.
quant_config
.
pack_factor
,
device
=
"cuda"
,
dtype
=
torch
.
int32
,
),
...
...
@@ -124,8 +126,8 @@ class AWQLinearMethod(LinearMethodBase):
})
scales
=
Parameter
(
torch
.
empty
(
input_size
//
self
.
quant_config
.
group_size
,
output_size
,
input_size
_per_partition
//
self
.
quant_config
.
group_size
,
output_size
_per_partition
,
device
=
"cuda"
,
dtype
=
params_dtype
,
),
...
...
@@ -142,7 +144,7 @@ class AWQLinearMethod(LinearMethodBase):
}
def
apply_weights
(
self
,
weights
:
Dict
[
str
,
torch
.
Tensor
],
weights
:
Dict
[
str
,
Any
],
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
qweight
=
weights
[
"qweight"
]
...
...
vllm/model_executor/layers/quantization/gptq.py
0 → 100644
View file @
0fbfc4b8
import
enum
from
enum
import
Enum
from
typing
import
Any
,
Dict
,
List
,
Optional
import
torch
from
torch.nn.parameter
import
Parameter
from
vllm._C
import
ops
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
set_weight_attrs
)
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
class
GPTQConfig
(
QuantizationConfig
):
"""Config class for GPTQ.
Reference: https://arxiv.org/abs/2210.17323
"""
def
__init__
(
self
,
weight_bits
:
int
,
group_size
:
int
,
desc_act
:
bool
,
)
->
None
:
self
.
weight_bits
=
weight_bits
self
.
group_size
=
group_size
self
.
desc_act
=
desc_act
self
.
pack_factor
=
32
//
self
.
weight_bits
# exllama kernel v1 only supports 4 bit
if
self
.
weight_bits
!=
4
:
raise
ValueError
(
"Currently, only 4-bit weight quantization is supported for "
f
"GPTQ, but got
{
self
.
weight_bits
}
bits."
)
def
__repr__
(
self
)
->
str
:
return
(
f
"GPTQConfig(weight_bits=
{
self
.
weight_bits
}
, "
f
"group_size=
{
self
.
group_size
}
, "
f
"desc_act=
{
self
.
desc_act
}
)"
)
@
classmethod
def
get_name
(
cls
)
->
str
:
return
"gptq"
@
classmethod
def
get_supported_act_dtypes
(
cls
)
->
List
[
torch
.
dtype
]:
return
[
torch
.
half
]
@
classmethod
# Need to figure it out
def
get_min_capability
(
cls
)
->
int
:
return
60
@
classmethod
def
get_config_filenames
(
cls
)
->
List
[
str
]:
return
[
"quantize_config.json"
]
@
classmethod
def
from_config
(
cls
,
config
:
Dict
[
str
,
Any
])
->
"GPTQConfig"
:
weight_bits
=
cls
.
get_from_keys
(
config
,
[
"bits"
])
group_size
=
cls
.
get_from_keys
(
config
,
[
"group_size"
])
desc_act
=
cls
.
get_from_keys
(
config
,
[
"desc_act"
])
return
cls
(
weight_bits
,
group_size
,
desc_act
)
def
get_linear_method
(
self
)
->
"GPTQLinearMethod"
:
return
GPTQLinearMethod
(
self
)
def
get_scaled_act_names
(
self
)
->
List
[
str
]:
return
[]
class
ExllamaState
(
Enum
):
UNUSED
=
enum
.
auto
()
UNINITIALIZED
=
enum
.
auto
()
READY
=
enum
.
auto
()
class
GPTQLinearMethod
(
LinearMethodBase
):
"""Linear method for GPTQ.
Args:
quant_config: The GPTQ quantization config.
"""
def
__init__
(
self
,
quant_config
:
GPTQConfig
):
self
.
quant_config
=
quant_config
def
create_weights
(
self
,
input_size_per_partition
:
int
,
output_size_per_partition
:
int
,
input_size
:
int
,
output_size
:
int
,
params_dtype
:
torch
.
dtype
,
)
->
Dict
[
str
,
Any
]:
del
output_size
# Unused.
if
input_size_per_partition
%
self
.
quant_config
.
group_size
!=
0
:
raise
ValueError
(
"The input size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size."
)
if
output_size_per_partition
%
self
.
quant_config
.
pack_factor
!=
0
:
raise
ValueError
(
"The output size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size."
)
if
self
.
quant_config
.
group_size
!=
-
1
:
group_size
=
self
.
quant_config
.
group_size
else
:
group_size
=
input_size
exllama_state
=
ExllamaState
.
UNINITIALIZED
scale_and_zero_size
=
input_size
//
group_size
scale_and_zero_input_dim
=
None
if
input_size
!=
input_size_per_partition
and
self
.
quant_config
.
group_size
!=
-
1
:
# For act-order models, we cannot use Exllama for row parallel layer
if
self
.
quant_config
.
desc_act
:
exllama_state
=
ExllamaState
.
UNUSED
else
:
# we need to partition qzeros and scales for exllama kernel
scale_and_zero_size
=
input_size_per_partition
//
group_size
scale_and_zero_input_dim
=
0
qweight
=
Parameter
(
torch
.
empty
(
input_size_per_partition
//
self
.
quant_config
.
pack_factor
,
output_size_per_partition
,
device
=
"cuda"
,
dtype
=
torch
.
int32
,
),
requires_grad
=
False
,
)
set_weight_attrs
(
qweight
,
{
"input_dim"
:
0
,
"output_dim"
:
1
,
"packed_dim"
:
0
,
"pack_factor"
:
self
.
quant_config
.
pack_factor
,
})
g_idx
=
Parameter
(
torch
.
tensor
(
[
i
//
self
.
quant_config
.
group_size
for
i
in
range
(
input_size_per_partition
)
],
device
=
"cuda"
,
dtype
=
torch
.
int32
,
),
requires_grad
=
False
,
)
# Ignore warning from fused linear layers such as QKVParallelLinear.
set_weight_attrs
(
g_idx
,
{
"input_dim"
:
0
,
"ignore_warning"
:
True
})
qzeros
=
Parameter
(
torch
.
empty
(
scale_and_zero_size
,
output_size_per_partition
//
self
.
quant_config
.
pack_factor
,
device
=
"cuda"
,
dtype
=
torch
.
int32
,
),
requires_grad
=
False
,
)
set_weight_attrs
(
qzeros
,
{
"input_dim"
:
scale_and_zero_input_dim
,
"output_dim"
:
1
,
"packed_dim"
:
1
,
"pack_factor"
:
self
.
quant_config
.
pack_factor
,
})
scales
=
Parameter
(
torch
.
empty
(
scale_and_zero_size
,
output_size_per_partition
,
device
=
"cuda"
,
dtype
=
params_dtype
,
),
requires_grad
=
False
,
)
set_weight_attrs
(
scales
,
{
"input_dim"
:
scale_and_zero_input_dim
,
"output_dim"
:
1
,
})
return
{
"qweight"
:
qweight
,
"g_idx"
:
g_idx
,
"qzeros"
:
qzeros
,
"scales"
:
scales
,
"exllama_state"
:
exllama_state
,
}
def
apply_weights
(
self
,
weights
:
Dict
[
str
,
Any
],
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
qweight
=
weights
[
"qweight"
]
out_shape
=
x
.
shape
[:
-
1
]
+
(
qweight
.
shape
[
-
1
],
)
reshaped_x
=
x
.
reshape
(
-
1
,
x
.
shape
[
-
1
])
# exllama needs to shuffle the weight after the weight is loaded
# here we do the shuffle on first forward pass
if
weights
[
"exllama_state"
]
==
ExllamaState
.
UNINITIALIZED
:
if
self
.
quant_config
.
desc_act
:
weights
[
"g_idx"
]
=
torch
.
argsort
(
weights
[
"g_idx"
]).
to
(
torch
.
int
)
else
:
weights
[
"g_idx"
]
=
torch
.
empty
((
1
,
1
),
device
=
"meta"
)
weights
[
"exllama_state"
]
=
ExllamaState
.
READY
ops
.
gptq_shuffle
(
weights
[
"qweight"
],
weights
[
"g_idx"
])
output
=
ops
.
gptq_gemm
(
reshaped_x
,
weights
[
"qweight"
],
weights
[
"qzeros"
],
weights
[
"scales"
],
weights
[
"g_idx"
],
weights
[
"exllama_state"
]
==
ExllamaState
.
READY
)
if
bias
is
not
None
:
output
=
output
+
bias
return
output
.
reshape
(
out_shape
)
vllm/model_executor/layers/quantization/squeezellm.py
View file @
0fbfc4b8
...
...
@@ -67,17 +67,19 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
def
__init__
(
self
,
quant_config
:
SqueezeLLMConfig
):
self
.
quant_config
=
quant_config
def
create_weights
(
self
,
input_size
:
int
,
output_size
:
int
,
params_dtype
:
torch
.
dtype
)
->
Dict
[
str
,
torch
.
Tensor
]:
if
input_size
%
self
.
quant_config
.
pack_factor
!=
0
:
def
create_weights
(
self
,
input_size_per_partition
:
int
,
output_size_per_partition
:
int
,
input_size
:
int
,
output_size
:
int
,
params_dtype
:
torch
.
dtype
)
->
Dict
[
str
,
Any
]:
if
input_size_per_partition
%
self
.
quant_config
.
pack_factor
!=
0
:
raise
ValueError
(
"The input size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size."
)
qweight
=
Parameter
(
torch
.
empty
(
input_size
//
self
.
quant_config
.
pack_factor
,
output_size
,
input_size
_per_partition
//
self
.
quant_config
.
pack_factor
,
output_size
_per_partition
,
device
=
"cuda"
,
dtype
=
torch
.
int32
,
),
...
...
@@ -108,7 +110,7 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
}
def
apply_weights
(
self
,
weights
:
Dict
[
str
,
torch
.
Tensor
],
weights
:
Dict
[
str
,
Any
],
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
qweight
=
weights
[
"qweight"
]
...
...
vllm/model_executor/models/aquila.py
View file @
0fbfc4b8
...
...
@@ -332,11 +332,18 @@ class AquilaForCausalLM(nn.Module):
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
param
=
params_dict
[
name
.
replace
(
weight_name
,
param_name
)]
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
...
...
vllm/model_executor/models/baichuan.py
View file @
0fbfc4b8
...
...
@@ -355,11 +355,18 @@ class BaiChuanBaseForCausalLM(nn.Module):
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
param
=
params_dict
[
name
.
replace
(
weight_name
,
param_name
)]
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment