Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6512937d
Unverified
Commit
6512937d
authored
Jul 31, 2024
by
HandH1998
Committed by
GitHub
Jul 31, 2024
Browse files
Support W4A8 quantization for vllm (#5218)
parent
c0644cf9
Changes
15
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
1963 additions
and
84 deletions
+1963
-84
.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml
.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml
+11
-0
.buildkite/lm-eval-harness/configs/models-small.txt
.buildkite/lm-eval-harness/configs/models-small.txt
+1
-0
CMakeLists.txt
CMakeLists.txt
+1
-0
csrc/ops.h
csrc/ops.h
+7
-0
csrc/quantization/marlin/dense/common/base.h
csrc/quantization/marlin/dense/common/base.h
+32
-0
csrc/quantization/marlin/dense/common/mem.h
csrc/quantization/marlin/dense/common/mem.h
+89
-0
csrc/quantization/marlin/dense/marlin_cuda_kernel.cu
csrc/quantization/marlin/dense/marlin_cuda_kernel.cu
+6
-84
csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu
csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu
+1243
-0
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+4
-0
tests/kernels/test_marlin_gemm.py
tests/kernels/test_marlin_gemm.py
+66
-0
vllm/_custom_ops.py
vllm/_custom_ops.py
+9
-0
vllm/model_executor/layers/quantization/__init__.py
vllm/model_executor/layers/quantization/__init__.py
+2
-0
vllm/model_executor/layers/quantization/qqq.py
vllm/model_executor/layers/quantization/qqq.py
+285
-0
vllm/model_executor/layers/quantization/utils/marlin_utils_test_qqq.py
...ecutor/layers/quantization/utils/marlin_utils_test_qqq.py
+125
-0
vllm/model_executor/layers/quantization/utils/quant_utils.py
vllm/model_executor/layers/quantization/utils/quant_utils.py
+82
-0
No files found.
.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml
0 → 100644
View file @
6512937d
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m HandH1998/QQQ-Llama-3-8b-g128 -b 32 -l 1000 -f 5 -t 1
model_name
:
"
HandH1998/QQQ-Llama-3-8b-g128"
tasks
:
-
name
:
"
gsm8k"
metrics
:
-
name
:
"
exact_match,strict-match"
value
:
0.409
-
name
:
"
exact_match,flexible-extract"
value
:
0.406
limit
:
1000
num_fewshot
:
5
.buildkite/lm-eval-harness/configs/models-small.txt
View file @
6512937d
...
...
@@ -7,3 +7,4 @@ Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml
Minitron-4B-Base.yaml
Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml
Qwen2-1.5B-Instruct-FP8W8.yaml
Meta-Llama-3-8B-QQQ.yaml
CMakeLists.txt
View file @
6512937d
...
...
@@ -170,6 +170,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/quantization/awq/gemm_kernels.cu"
"csrc/quantization/marlin/dense/marlin_cuda_kernel.cu"
"csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
"csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu"
"csrc/quantization/gptq_marlin/gptq_marlin.cu"
"csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
"csrc/quantization/gptq_marlin/awq_marlin_repack.cu"
...
...
csrc/ops.h
View file @
6512937d
...
...
@@ -115,6 +115,13 @@ void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
torch
::
Tensor
const
&
b_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
);
torch
::
Tensor
marlin_qqq_gemm
(
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b_q_weight
,
torch
::
Tensor
const
&
s_tok
,
torch
::
Tensor
const
&
s_ch
,
torch
::
Tensor
const
&
s_group
,
torch
::
Tensor
&
workspace
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
);
#endif
void
static_scaled_int8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
input
,
...
...
csrc/quantization/marlin/dense/common/base.h
0 → 100644
View file @
6512937d
/*
* Modified by HandH1998
* Modified by Neural Magic
* Copyright (C) Marlin.2024 Elias Frantar
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
// Integer division of `a` by `b` rounded up, computed as (a + b - 1) / b.
// NOTE(review): presumably only called with a >= 0 and b > 0 — confirm with
// callers; other inputs follow plain C++ integer-division semantics.
constexpr int ceildiv(int a, int b) {
  const int padded = a + b - 1;
  return padded / b;
}
// `Vec<T, n>` is a thin fixed-size wrapper around `T elems[n]`. It is meant to
// group values that should live in >>registers<< (e.g. operands of tensor core
// instructions). For that to work every index passed to `operator[]` has to be
// a compile-time constant, which is why the kernels using it rely heavily on
// `#pragma unroll`.
template <typename T, int n>
struct Vec {
  T elems[n];

  // Mutable access to element `idx`; no bounds checking is performed.
  __device__ T& operator[](int idx) { return elems[idx]; }
};
csrc/quantization/marlin/dense/common/mem.h
0 → 100644
View file @
6512937d
/*
* Modified by HandH1998
* Modified by Neural Magic
* Copyright (C) Marlin.2024 Elias Frantar
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
// Predicated asynchronous global->shared copy of 16 bytes. The copy is only
// issued when `pred` is true; it is used for the A inputs, where predication
// handles batch sizes that are not multiples of 16.
//
// smem_ptr: destination (converted to a shared-space address below).
// glob_ptr: source address in global memory.
// pred:     when false, the cp.async instruction is skipped.
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
                                      bool pred = true) {
  // Copy size must be an immediate operand of cp.async ("n" constraint).
  constexpr int kCopyBytes = 16;
  const uint32_t smem_addr =
      static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "{\n"
      " .reg .pred p;\n"
      " setp.ne.b32 p, %0, 0;\n"
      " @p cp.async.cg.shared.global [%1], [%2], %3;\n"
      "}\n" ::"r"(static_cast<int>(pred)),
      "r"(smem_addr), "l"(glob_ptr), "n"(kCopyBytes));
}
// Unconditional asynchronous global->shared copy of 16 bytes.
//
// smem_ptr: destination (converted to a shared-space address below).
// glob_ptr: source address in global memory.
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
  // Copy size must be an immediate operand of cp.async ("n" constraint).
  constexpr int kCopyBytes = 16;
  const uint32_t smem_addr =
      static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "{\n"
      " cp.async.cg.shared.global [%0], [%1], %2;\n"
      "}\n" ::"r"(smem_addr),
      "l"(glob_ptr), "n"(kCopyBytes));
}
// Async copy fence: emits `cp.async.commit_group`, closing the current group
// of previously issued cp.async operations.
__device__ inline void cp_async_fence() {
  asm volatile("cp.async.commit_group;\n" ::);
}
// Blocks until at most `n` committed async-copy groups (stages) are still
// pending. `n` is a template parameter because `cp.async.wait_group` takes an
// immediate operand.
template <int n>
__device__ inline void cp_async_wait() {
  asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
}
// Wait until the barrier at `lock` reaches `count`, then proceed with the
// whole threadblock.
//
// Only thread 0 polls the lock; every other thread waits for it at the
// trailing __syncthreads().
__device__ inline void barrier_acquire(int* lock, int count) {
  if (threadIdx.x == 0) {
    int observed = -1;
    do {
      // Acquire load at GPU scope: guarantees that subsequent writes by this
      // threadblock will be visible globally.
      asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
                   : "=r"(observed)
                   : "l"(lock));
    } while (observed != count);
  }
  __syncthreads();
}
// Release the barrier at `lock` and increment its visitation count.
//
// All threads first meet at __syncthreads(); afterwards only thread 0 touches
// the lock. With `reset == true` the lock is simply written back to 0,
// otherwise it is incremented by one.
__device__ inline void barrier_release(int* lock, bool reset = false) {
  __syncthreads();
  if (threadIdx.x != 0) {
    return;
  }
  if (reset) {
    lock[0] = 0;
    return;
  }
  int increment = 1;
  // Make sure that all writes since acquiring this barrier are visible
  // globally, while releasing the barrier.
  asm volatile("fence.acq_rel.gpu;\n");
  asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n"
               :
               : "l"(lock), "r"(increment));
}
csrc/quantization/marlin/dense/marlin_cuda_kernel.cu
View file @
6512937d
...
...
@@ -25,6 +25,12 @@
#include <iostream>
#include "common/base.h"
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#include "common/mem.h"
#endif
template
<
typename
T
>
inline
std
::
string
str
(
T
x
)
{
return
std
::
to_string
(
x
);
...
...
@@ -32,23 +38,9 @@ inline std::string str(T x) {
namespace
marlin_dense
{
constexpr
int
ceildiv
(
int
a
,
int
b
)
{
return
(
a
+
b
-
1
)
/
b
;
}
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
// Instances of `Vec` are used to organize groups of >>registers<<, as needed
// for instance as inputs to tensor core operations. Consequently, all
// corresponding index accesses must be compile-time constants, which is why we
// extensively use `#pragma unroll` throughout the kernel code to guarantee
// this.
template
<
typename
T
,
int
n
>
struct
Vec
{
T
elems
[
n
];
__device__
T
&
operator
[](
int
i
)
{
return
elems
[
i
];
}
};
using
I4
=
Vec
<
int
,
4
>
;
// Matrix fragments for tensor core instructions; their precise layout is
// documented here:
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
...
...
@@ -57,43 +49,6 @@ using FragB = Vec<half2, 2>;
using
FragC
=
Vec
<
float
,
4
>
;
using
FragS
=
Vec
<
half2
,
1
>
;
// quantization scales
// Predicated asynchronous global->shared copy; used for inputs A where we apply
// predication to handle batchsizes that are not multiples of 16.
__device__
inline
void
cp_async4_pred
(
void
*
smem_ptr
,
const
void
*
glob_ptr
,
bool
pred
=
true
)
{
const
int
BYTES
=
16
;
uint32_t
smem
=
static_cast
<
uint32_t
>
(
__cvta_generic_to_shared
(
smem_ptr
));
asm
volatile
(
"{
\n
"
" .reg .pred p;
\n
"
" setp.ne.b32 p, %0, 0;
\n
"
" @p cp.async.cg.shared.global [%1], [%2], %3;
\n
"
"}
\n
"
::
"r"
((
int
)
pred
),
"r"
(
smem
),
"l"
(
glob_ptr
),
"n"
(
BYTES
));
}
// Asynchronous global->shared copy
__device__
inline
void
cp_async4
(
void
*
smem_ptr
,
const
void
*
glob_ptr
)
{
const
int
BYTES
=
16
;
uint32_t
smem
=
static_cast
<
uint32_t
>
(
__cvta_generic_to_shared
(
smem_ptr
));
asm
volatile
(
"{
\n
"
" cp.async.cg.shared.global [%0], [%1], %2;
\n
"
"}
\n
"
::
"r"
(
smem
),
"l"
(
glob_ptr
),
"n"
(
BYTES
));
}
// Async copy fence.
__device__
inline
void
cp_async_fence
()
{
asm
volatile
(
"cp.async.commit_group;
\n
"
::
);
}
// Wait until at most `n` async copy stages are still pending.
template
<
int
n
>
__device__
inline
void
cp_async_wait
()
{
asm
volatile
(
"cp.async.wait_group %0;
\n
"
::
"n"
(
n
));
}
// m16n8k16 tensor core mma instruction with fp16 inputs and fp32
// output/accumulation.
__device__
inline
void
mma
(
const
FragA
&
a_frag
,
const
FragB
&
frag_b
,
...
...
@@ -164,39 +119,6 @@ __device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) {
frag_b
[
1
]
=
__hmul2
(
frag_b
[
1
],
s
);
}
// Wait until barrier reaches `count`, then lock for current threadblock.
__device__
inline
void
barrier_acquire
(
int
*
lock
,
int
count
)
{
if
(
threadIdx
.
x
==
0
)
{
int
state
=
-
1
;
do
// Guarantee that subsequent writes by this threadblock will be visible
// globally.
asm
volatile
(
"ld.global.acquire.gpu.b32 %0, [%1];
\n
"
:
"=r"
(
state
)
:
"l"
(
lock
));
while
(
state
!=
count
);
}
__syncthreads
();
}
// Release barrier and increment visitation count.
__device__
inline
void
barrier_release
(
int
*
lock
,
bool
reset
=
false
)
{
__syncthreads
();
if
(
threadIdx
.
x
==
0
)
{
if
(
reset
)
{
lock
[
0
]
=
0
;
return
;
}
int
val
=
1
;
// Make sure that all writes since acquiring this barrier are visible
// globally, while releasing the barrier.
asm
volatile
(
"fence.acq_rel.gpu;
\n
"
);
asm
volatile
(
"red.relaxed.gpu.global.add.s32 [%0], %1;
\n
"
:
:
"l"
(
lock
),
"r"
(
val
));
}
}
template
<
const
int
threads
,
// number of threads in a threadblock
const
int
thread_m_blocks
,
// number of 16x16 blocks in the m
// dimension (batchsize) of the
...
...
csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu
0 → 100644
View file @
6512937d
This diff is collapsed.
Click to expand it.
csrc/torch_bindings.cpp
View file @
6512937d
...
...
@@ -149,6 +149,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops
.
def
(
"fp8_marlin_gemm"
,
&
fp8_marlin_gemm
);
ops
.
impl
(
"fp8_marlin_gemm"
,
torch
::
kCUDA
,
&
fp8_marlin_gemm
);
// marlin_qqq_gemm for QQQ.
ops
.
def
(
"marlin_qqq_gemm"
,
&
marlin_qqq_gemm
);
ops
.
impl
(
"marlin_qqq_gemm"
,
torch
::
kCUDA
,
&
marlin_qqq_gemm
);
// CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
// quantization.
ops
.
def
(
...
...
tests/kernels/test_marlin_gemm.py
View file @
6512937d
...
...
@@ -10,6 +10,9 @@ from vllm import _custom_ops as ops
from
vllm.model_executor.layers.quantization.gptq_marlin_24
import
(
GPTQ_MARLIN_24_MAX_PARALLEL
,
GPTQ_MARLIN_24_MIN_THREAD_N
,
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
,
GPTQ_MARLIN_24_SUPPORTED_NUM_BITS
)
from
vllm.model_executor.layers.quantization.qqq
import
(
MARLIN_QQQ_MAX_PARALLEL
,
MARLIN_QQQ_MIN_THREAD_N
,
MARLIN_QQQ_SUPPORTED_GROUP_SIZES
,
MARLIN_QQQ_SUPPORTED_NUM_BITS
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
GPTQ_MARLIN_MAX_PARALLEL
,
GPTQ_MARLIN_MIN_THREAD_N
,
MARLIN_SUPPORTED_GROUP_SIZES
,
MARLIN_SUPPORTED_NUM_BITS
,
...
...
@@ -21,6 +24,8 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
marlin_weights
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_test_24
import
(
marlin_24_quantize
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_test_qqq
import
(
# noqa: E501
marlin_qqq_quantize
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
awq_pack
,
gptq_pack
,
quantize_weights
,
quantize_weights_with_zp
,
sort_weights
)
...
...
@@ -425,3 +430,64 @@ def test_awq_marlin_gemm(
print
(
"max_diff = {}"
.
format
(
max_diff
))
assert
max_diff
<
0.04
@pytest.mark.skipif(not is_quant_method_supported("qqq"),
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("num_bits", MARLIN_QQQ_SUPPORTED_NUM_BITS)
@pytest.mark.parametrize("group_size", MARLIN_QQQ_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_marlin_qqq_gemm(
    k_chunk,
    n_chunk,
    num_bits,
    group_size,
    mnk_factors,
):
    """Compare `ops.marlin_qqq_gemm` against a half-precision matmul.

    Activations are quantized to int8 with one scale per row, weights are
    quantized with `marlin_qqq_quantize`, and the kernel output must stay
    within a max-diff of 0.04 of the reference product.
    """
    int8_info = torch.iinfo(torch.int8)
    m_factor, n_factor, k_factor = mnk_factors

    size_m = m_factor
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    print(f"MNK = {size_m} {size_n} {size_k}")
    print(f"groupsize = {group_size}")

    a_input = rand_data((size_m, size_k))
    b_weight = rand_data((size_k, size_n))

    # Quantize activations: row-wise abs-max divided by the int8 maximum gives
    # the float scale, then round and clamp into the int8 range.
    row_abs_max = a_input.abs().max(dim=-1, keepdim=True)[0]
    act_scale = row_abs_max.div(int8_info.max).to(torch.float)
    act_int8 = (a_input / act_scale).round()
    act_int8 = act_int8.clamp(int8_info.min, int8_info.max).to(torch.int8)

    # Quantize weights
    (w_ref, packed_w, group_scale,
     channel_scale) = marlin_qqq_quantize(b_weight, num_bits, group_size)

    workspace = MarlinWorkspace(size_n, MARLIN_QQQ_MIN_THREAD_N,
                                MARLIN_QQQ_MAX_PARALLEL)

    output = ops.marlin_qqq_gemm(
        act_int8,
        packed_w,
        act_scale,
        channel_scale,
        group_scale,
        workspace.scratch,
        a_input.shape[0],
        b_weight.shape[1],
        a_input.shape[1],
    )
    # Reference: dequantized activations times dequantized weights.
    output_ref = torch.matmul(act_int8.half() * act_scale.half(), w_ref)

    torch.cuda.synchronize()

    max_diff = compute_max_diff(output, output_ref)
    print("max_diff = {}".format(max_diff))

    assert max_diff < 0.04
vllm/_custom_ops.py
View file @
6512937d
...
...
@@ -389,6 +389,15 @@ def scaled_int8_quant(
return
output
,
input_scales
# qqq ops
def marlin_qqq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                    s_tok: torch.Tensor, s_ch: torch.Tensor,
                    s_group: torch.Tensor, workspace: torch.Tensor,
                    size_m: int, size_n: int, size_k: int) -> torch.Tensor:
    """Forward all arguments unchanged to `torch.ops._C.marlin_qqq_gemm`.

    NOTE(review): judging by the parameter names, `s_tok`, `s_ch` and
    `s_group` are presumably per-token, per-channel and per-group scales —
    confirm against the kernel implementation.
    """
    qqq_gemm_op = torch.ops._C.marlin_qqq_gemm
    return qqq_gemm_op(a, b_q_weight, s_tok, s_ch, s_group, workspace, size_m,
                       size_n, size_k)
# moe
def
moe_align_block_size
(
topk_ids
:
torch
.
Tensor
,
num_experts
:
int
,
block_size
:
int
,
sorted_token_ids
:
torch
.
Tensor
,
...
...
vllm/model_executor/layers/quantization/__init__.py
View file @
6512937d
...
...
@@ -19,6 +19,7 @@ from vllm.model_executor.layers.quantization.gptq_marlin import (
from
vllm.model_executor.layers.quantization.gptq_marlin_24
import
(
GPTQMarlin24Config
)
from
vllm.model_executor.layers.quantization.marlin
import
MarlinConfig
from
vllm.model_executor.layers.quantization.qqq
import
QQQConfig
from
vllm.model_executor.layers.quantization.squeezellm
import
SqueezeLLMConfig
QUANTIZATION_METHODS
:
Dict
[
str
,
Type
[
QuantizationConfig
]]
=
{
...
...
@@ -37,6 +38,7 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
"squeezellm"
:
SqueezeLLMConfig
,
"compressed-tensors"
:
CompressedTensorsConfig
,
"bitsandbytes"
:
BitsAndBytesConfig
,
"qqq"
:
QQQConfig
,
}
...
...
vllm/model_executor/layers/quantization/qqq.py
0 → 100644
View file @
6512937d
from
typing
import
Any
,
Dict
,
List
,
Optional
import
torch
from
torch.nn.parameter
import
Parameter
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.utils
import
set_weight_attrs
logger
=
init_logger
(
__name__
)
# Limits of the Marlin QQQ kernels; QQQConfig copies the first four onto each
# config instance (tile_size, min_n_threads, min_k_threads, max_parallel).
MARLIN_QQQ_TILE = 16
MARLIN_QQQ_MIN_THREAD_N = 64
MARLIN_QQQ_MIN_THREAD_K = 128
MARLIN_QQQ_MAX_PARALLEL = 16

# Values accepted by QQQConfig; anything else is rejected with a ValueError.
MARLIN_QQQ_SUPPORTED_NUM_BITS = [4]
MARLIN_QQQ_SUPPORTED_GROUP_SIZES = [-1, 128]
MARLIN_QQQ_SUPPORTED_SYM = [True]
class QQQConfig(QuantizationConfig):
    """Quantization config for QQQ.

    Reference: https://arxiv.org/pdf/2406.09904
    """

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
        is_sym: bool = True,
    ) -> None:
        """Store the quantization settings and validate them.

        Raises:
            ValueError: if `weight_bits`, `group_size` or `is_sym` is not in
                the corresponding MARLIN_QQQ_SUPPORTED_* list.
        """
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.is_sym = is_sym

        # Verify
        bits_supported = self.weight_bits in MARLIN_QQQ_SUPPORTED_NUM_BITS
        if not bits_supported:
            raise ValueError(
                f"QQQ does not support weight_bits = {self.weight_bits}. "
                f"Only weight_bits = {MARLIN_QQQ_SUPPORTED_NUM_BITS} "
                "are supported.")

        group_supported = self.group_size in MARLIN_QQQ_SUPPORTED_GROUP_SIZES
        if not group_supported:
            raise ValueError(
                f"QQQ does not support group_size = {self.group_size}. "
                f"Only group_sizes = {MARLIN_QQQ_SUPPORTED_GROUP_SIZES} "
                "are supported.")

        sym_supported = self.is_sym in MARLIN_QQQ_SUPPORTED_SYM
        if not sym_supported:
            raise ValueError(
                f"QQQ does not support is_sym = {self.is_sym}. "
                f"Only sym = {MARLIN_QQQ_SUPPORTED_SYM} are supported.")

        # 4 Bits packed into 32 bit datatype.
        self.pack_factor = 32 // self.weight_bits

        # Kernel limits, copied from the module-level constants:
        # tile size used by QQQ kernels,
        self.tile_size = MARLIN_QQQ_TILE
        # min out_features dim,
        self.min_n_threads = MARLIN_QQQ_MIN_THREAD_N
        # min in_features dim,
        self.min_k_threads = MARLIN_QQQ_MIN_THREAD_K
        # max parallel problems to solve at once (improves large batch
        # performance).
        self.max_parallel = MARLIN_QQQ_MAX_PARALLEL

        # Permutation length used by the QQQ kernels.
        self.perm_len = 1024

    def __repr__(self) -> str:
        template = "QQQConfig(weight_bits={}, group_size={})"
        return template.format(self.weight_bits, self.group_size)

    @classmethod
    def get_name(cls) -> str:
        """Name under which this quantization method is registered."""
        return "qqq"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Activation dtypes supported by this method (half only)."""
        return [torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Minimum capability value required by this method (80)."""
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """List of filenames to search for in the model directory."""
        return [
            "quant_config.json",
            "quantize_config.json",
        ]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "QQQConfig":
        """Build a config from the "wbits" and "group_size" config entries."""
        weight_bits = cls.get_from_keys(config, ["wbits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        return cls(weight_bits, group_size)

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QQQLinearMethod"]:
        """Return a QQQLinearMethod for LinearBase layers, otherwise None."""
        if not isinstance(layer, LinearBase):
            return None
        return QQQLinearMethod(self)

    def get_scaled_act_names(self) -> List[str]:
        """This method defines no scaled activation names."""
        return []
class QQQLinearMethod(LinearMethodBase):
    """Linear method for QQQ.

    Args:
        quant_config: The QQQ quantization config.
    """

    def __init__(self, quant_config: QQQConfig):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Validate the partition sizes and register the QQQ parameters.

        Registers four parameters on `layer`: "B" (packed int32 weights),
        "s_channel" (float scales), "s_group" (half scales, empty when
        group_size == -1) and "workspace" (zero-initialised int tensor used
        for the internal locking mechanism). All tensors are allocated on
        "cuda".

        Raises:
            ValueError: if `params_dtype` is not float16 or any of the
                divisibility requirements below is violated.
        """
        cfg = self.quant_config

        if params_dtype != torch.float16:
            raise ValueError(
                f"The params dtype must be float16, but got {params_dtype}")

        # Validate output_size_per_partition
        output_size_per_partition = sum(output_partition_sizes)
        if output_size_per_partition % cfg.min_n_threads != 0:
            raise ValueError(
                f"Weight output_size_per_partition = "
                f"{output_size_per_partition} is not divisible by "
                f"min_n_threads = {self.quant_config.min_n_threads}.")
        if output_size_per_partition % cfg.pack_factor != 0:
            raise ValueError(
                f"Weight output_size_per_partition = "
                f"{output_size_per_partition} is not divisible by "
                f"pack_factor = {self.quant_config.pack_factor}.")

        # Validate input_size_per_partition
        if input_size_per_partition % cfg.min_k_threads != 0:
            raise ValueError(
                f"Weight input_size_per_partition = "
                f"{input_size_per_partition} is not divisible by "
                f"min_k_threads = {self.quant_config.min_k_threads}.")
        grouped = cfg.group_size != -1
        if grouped and input_size_per_partition % cfg.group_size != 0:
            raise ValueError(
                f"Weight input_size_per_partition = "
                f"{input_size_per_partition} is not divisible by "
                f"group_size = {self.quant_config.group_size}.")

        # The output size has to be a multiple of the number of tiles covered
        # by one permutation (perm_len // tile_size**2).
        num_tiles_per_perm = cfg.perm_len // (cfg.tile_size**2)
        if output_size_per_partition % num_tiles_per_perm != 0:
            raise ValueError(
                "Each permutation group must reside on the same gpu")

        # Quantized 4Bit weights packed into Int32.
        qweight_rows = input_size_per_partition // cfg.tile_size
        qweight_cols = (output_size_per_partition * cfg.tile_size //
                        cfg.pack_factor)
        qweight = Parameter(
            torch.empty(
                qweight_rows,
                qweight_cols,
                device="cuda",
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qweight,
            {
                "input_dim": 0,
                "output_dim": 1,
                "packed_dim": 1,
                "pack_factor": cfg.pack_factor,
                "marlin_tile_size": cfg.tile_size,
            },
        )

        s_channel = Parameter(
            torch.empty(
                1,
                output_size_per_partition,
                device="cuda",
                dtype=torch.float,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            s_channel,
            {
                "input_dim": None,
                "output_dim": 1,
            },
        )

        # Without grouping (group_size == -1) s_group is an empty tensor and
        # carries no input/output dims.
        if grouped:
            s_group_data = torch.empty(
                input_size_per_partition // cfg.group_size,
                output_size_per_partition,
                device="cuda",
                dtype=torch.half,
            )
            s_group_attrs = {"input_dim": 0, "output_dim": 1}
        else:
            s_group_data = torch.tensor(
                [],
                device="cuda",
                dtype=torch.half,
            )
            s_group_attrs = {"input_dim": None, "output_dim": None}
        s_group = Parameter(s_group_data, requires_grad=False)
        set_weight_attrs(s_group, s_group_attrs)

        # Allocate workspace (Used for internal locking mechanism)
        max_workspace_size = (output_size_per_partition //
                              cfg.min_n_threads) * cfg.max_parallel
        workspace = Parameter(torch.zeros(max_workspace_size,
                                          device="cuda",
                                          dtype=torch.int),
                              requires_grad=False)

        # Register each parameter, then attach the caller-supplied attributes.
        for param_name, param in (
            ("B", qweight),
            ("s_channel", s_channel),
            ("s_group", s_group),
            ("workspace", workspace),
        ):
            layer.register_parameter(param_name, param)
            set_weight_attrs(param, extra_weight_attrs)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Quantize `x` to int8 and run the QQQ GEMM against `layer`'s weights.

        `x` is flattened to 2D over its last dimension, quantized with
        `ops.scaled_int8_quant`, multiplied with `ops.marlin_qqq_gemm`, and
        the result is viewed back to `x`'s leading dimensions. `bias`, when
        given, is added in place.
        """
        x_2d = x.view(-1, x.shape[-1])
        size_m, size_k = x_2d.shape[0], x_2d.shape[1]
        size_n = layer.s_channel.shape[1]

        x_int8, s_tok = ops.scaled_int8_quant(x_2d)

        output_2d = ops.marlin_qqq_gemm(x_int8, layer.B, s_tok,
                                        layer.s_channel, layer.s_group,
                                        layer.workspace, size_m, size_n,
                                        size_k)

        out_shape = x.shape[:-1] + (output_2d.shape[1], )
        output = output_2d.view(out_shape)

        if bias is not None:
            output.add_(bias)  # In-place add

        return output
vllm/model_executor/layers/quantization/utils/marlin_utils_test_qqq.py
0 → 100644
View file @
6512937d
from
typing
import
List
import
numpy
import
torch
from
.marlin_utils_test
import
marlin_permute_weights
from
.quant_utils
import
get_pack_factor
,
qqq_quantize_weights
def marlin_qqq_weights(q_w, size_k, size_n, num_bits, perm, group_size):
    """Permute quantized weights and pack them into 32-bit words.

    The weights are permuted with `marlin_permute_weights`, then every
    `pack_factor` strided columns are OR-ed into one uint32 column, each
    shifted by `num_bits * i`. When `group_size == size_k` each value is first
    masked with 0xF. The result is returned as an int32 tensor on the device
    `q_w` was on after permutation.
    """
    # Permute
    q_w = marlin_permute_weights(q_w, size_k, size_n, perm)

    # Pack
    pack_factor = get_pack_factor(num_bits)
    target_device = q_w.device

    q_w = q_w.cpu().numpy().astype(numpy.uint32)

    packed_shape = (q_w.shape[0], q_w.shape[1] // pack_factor)
    q_packed = numpy.zeros(packed_shape, dtype=numpy.uint32)

    mask_low_nibble = group_size == size_k
    for i in range(pack_factor):
        lane = q_w[:, i::pack_factor]
        if mask_low_nibble:
            lane = lane & 0xF
        q_packed |= lane << num_bits * i

    as_int32 = q_packed.astype(numpy.int32)
    return torch.from_numpy(as_int32).to(target_device)
def get_qqq_scale_perms():
    """Return the two index permutations applied to QQQ scales.

    Returns:
        (scale_perm, scale_perm_single): a 64-entry and a 32-entry list of
        ints. `marlin_qqq_permute_scales` uses the first for group scales and
        the second for channel scales.
    """
    scale_perm: List[int] = [i + 8 * j for i in range(8) for j in range(8)]

    single_offsets = (0, 1, 8, 9, 16, 17, 24, 25)
    scale_perm_single: List[int] = [
        2 * i + offset for i in range(4) for offset in single_offsets
    ]
    return scale_perm, scale_perm_single
# NOTE(HandH1998): QQQ employs different perms for per-group and per-channel weight quantization. # noqa: E501
def get_qqq_weight_perm(num_bits: int, quant_type: str):
    """Build the 1024-entry weight permutation used for QQQ packing.

    Args:
        num_bits: weight bit width; only 4 is supported.
        quant_type: "per-channel" or "per-group"; selects the interleave
            pattern applied to every run of 8 permutation entries.

    Returns:
        A 1-D torch tensor holding a permutation of range(1024).

    Raises:
        AssertionError: if `quant_type` is not one of the two supported
            values.
        ValueError: if `num_bits` is not 4.
    """
    # Validate up front so bad arguments fail before any work is done. The
    # check order (quant_type first, then num_bits) matches the order in which
    # the errors have always been reported.
    assert quant_type in ["per-channel",
                          "per-group"], "not supported quantization type"
    if num_bits != 4:
        # ValueError is a subclass of Exception, so callers catching the
        # previously raised bare Exception keep working.
        raise ValueError("num_bits must be 4, got {}".format(num_bits))

    if quant_type == "per-channel":
        interleave = numpy.array([4, 0, 5, 1, 6, 2, 7, 3])
    else:
        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])

    perm_list: List[int] = []
    for i in range(32):
        col = i // 4
        first_row = 4 * (i % 4)
        # 8 offsets inside a 256-entry block: four consecutive rows (stride
        # 16), once for each of the two column halves (offset 8).
        perm1: List[int] = [
            16 * row + col + 8 * block for block in (0, 1)
            for row in range(first_row, first_row + 4)
        ]
        # Repeat the pattern for each of the four 256-entry blocks.
        for j in range(4):
            perm_list.extend(p + 256 * j for p in perm1)

    perm = numpy.array(perm_list)
    perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
    return torch.from_numpy(perm)
def marlin_qqq_permute_scales(s_group, s_channel, size_k, size_n, group_size):
    """Permute QQQ scales with the permutations from `get_qqq_scale_perms`.

    The channel scales are always permuted with the 32-entry permutation. The
    group scales are permuted with the 64-entry permutation only when
    `group_size < size_k and group_size != -1`; otherwise `s_group` is
    returned untouched.

    Returns:
        (s_group, s_channel), each reshaped to (-1, size_n) where permuted.
    """
    scale_perm, scale_perm_single = get_qqq_scale_perms()
    per_group = group_size < size_k and group_size != -1

    if per_group:
        s_group = s_group.reshape((-1, len(scale_perm)))[:, scale_perm]

    channel_blocks = s_channel.reshape((-1, len(scale_perm_single)))
    s_channel = channel_blocks[:, scale_perm_single]

    if per_group:
        s_group = s_group.reshape((-1, size_n)).contiguous()

    s_channel = s_channel.reshape((-1, size_n)).contiguous()

    return s_group, s_channel
def marlin_qqq_quantize(
    w: torch.Tensor,
    num_bits: int,
    group_size: int,
):
    """Quantize `w` with QQQ and reformat the result for the Marlin QQQ kernel.

    Args:
        w: 2-D weight tensor; its shape is unpacked as (size_k, size_n).
        num_bits: weight bit width forwarded to the quantizer.
        group_size: quantization group size; -1 means one group spanning
            size_k (per-channel).

    Returns:
        A list [w_ref, marlin_qqq_q_w, marlin_qqq_s_group,
        marlin_qqq_s_channel], every entry moved to `w.device`.
    """
    size_k, size_n = w.shape

    # Normalize group_size
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    if group_size == size_k:
        quant_type = "per-channel"
    else:
        quant_type = "per-group"

    # Quantize
    w_ref, q_w, s_group, s_channel = qqq_quantize_weights(
        w, num_bits, group_size)

    # Reformat to marlin_qqq
    weight_perm = get_qqq_weight_perm(num_bits, quant_type)
    packed_w = marlin_qqq_weights(q_w, size_k, size_n, num_bits, weight_perm,
                                  group_size)
    permuted_s_group, permuted_s_channel = marlin_qqq_permute_scales(
        s_group, s_channel, size_k, size_n, group_size)

    # Create result
    results = (w_ref, packed_w, permuted_s_group, permuted_s_channel)
    return [tensor.to(w.device) for tensor in results]
vllm/model_executor/layers/quantization/utils/quant_utils.py
View file @
6512937d
...
...
@@ -205,6 +205,88 @@ def quantize_weights_with_zp(w: torch.Tensor, num_bits: int, group_size: int):
)
# QQQ employs different quant schemes for per-group and
# per-channel quantization.
def qqq_quantize_weights(w: torch.Tensor, num_bits: int, group_size: int):
    """Quantize a 2-D floating-point weight `w` with the QQQ scheme.

    Args:
        w: weight tensor; its shape is unpacked as (size_k, size_n).
        num_bits: bit width, must be in SUPPORTED_NUM_BITS.
        group_size: must be in SUPPORTED_GROUP_SIZES or equal size_k; -1 is
            replaced by size_k.

    Returns:
        (w_ref, q_w, s_group, s_channel), all moved to `w`'s original device.
        `w_ref` is the dequantized reference weight, `q_w` the integer
        weights, `s_channel` float scales and `s_group` half scales (an empty
        tensor when group_size == size_k).
    """
    src_device = w.device
    size_k, size_n = w.shape

    assert w.is_floating_point(), "w must be float"
    assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}"
    assert group_size in SUPPORTED_GROUP_SIZES + [
        size_k
    ], f"Unsupported groupsize = {group_size}"

    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    per_group = group_size < size_k
    if per_group:
        # Regroup so that every column holds exactly one quantization group:
        # (size_k, size_n) -> (group_size, num_groups * size_n).
        w = w.reshape((-1, group_size, size_n))
        w = w.permute(1, 0, 2)
        w = w.reshape((group_size, -1))

        max_q_val = 2**num_bits - 1
        half_q_val = (max_q_val + 1) // 2

        # Compute scale for each group
        s_group = w.abs().max(dim=0, keepdim=True).values
        s_group *= 2 / max_q_val  # 2 => symmetric

        # Quantize
        q_w = torch.round(w / s_group).int()
        q_w += half_q_val
        q_w = q_w.clamp(0, max_q_val)
        # Compute ref (dequantized)
        w_ref = (q_w - half_q_val).half() * s_group

        # Restore original shapes
        def _ungroup(t):
            t = t.reshape((group_size, -1, size_n))
            t = t.permute(1, 0, 2)
            return t.reshape((size_k, size_n)).contiguous()

        q_w = _ungroup(q_w)
        w_ref = _ungroup(w_ref)

        # Compute int8 quantization scale for each channel
        s_channel = w_ref.abs().max(dim=0, keepdim=True).values
        s_channel /= 127.0
        t_int8 = (w_ref / s_channel).round().clamp(-128, 127).to(torch.int8)
        w_ref = t_int8.half() * s_channel
        s_channel = s_channel.reshape(1, -1).to(dtype=torch.float)

        # Fuse scales
        group_scales_2d = s_group.reshape(-1, size_n).contiguous()
        s_group = (group_scales_2d / s_channel).to(dtype=torch.half)
    else:
        max_q_val = 2**(num_bits - 1) - 1

        # Compute scale for each channel
        s_channel = w.abs().max(dim=0, keepdim=True).values
        s_channel /= max_q_val

        # Quantize
        q_w = torch.round(w / s_channel).int()
        q_w = q_w.clamp(-max_q_val, max_q_val)
        # Compute ref (dequantized)
        w_ref = q_w.half() * s_channel

        s_group = torch.tensor([], dtype=torch.half)
        # div 2 ** (8 - self.bits)) to offset right shift in unpacking
        s_channel /= (2**(8 - num_bits))
        s_channel = s_channel.reshape(-1, size_n).contiguous().to(torch.float)

    return (
        w_ref.to(device=src_device),
        q_w.to(device=src_device),
        s_group.to(device=src_device),
        s_channel.to(device=src_device),
    )
def
sort_weights
(
q_w
:
torch
.
Tensor
,
g_idx
:
torch
.
Tensor
):
orig_device
=
q_w
.
device
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment