Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
23f32229
Unverified
Commit
23f32229
authored
Sep 06, 2024
by
Dipika Sikka
Committed by
GitHub
Sep 06, 2024
Browse files
[Misc] Remove `SqueezeLLM` (#8220)
parent
9db52eab
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
6 additions
and
389 deletions
+6
-389
CMakeLists.txt
CMakeLists.txt
+0
-1
csrc/ops.h
csrc/ops.h
+0
-3
csrc/quantization/squeezellm/quant_cuda_kernel.cu
csrc/quantization/squeezellm/quant_cuda_kernel.cu
+0
-216
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+0
-6
docs/source/quantization/supported_hardware.rst
docs/source/quantization/supported_hardware.rst
+0
-11
examples/fp8/README.md
examples/fp8/README.md
+2
-2
vllm/_custom_ops.py
vllm/_custom_ops.py
+0
-6
vllm/config.py
vllm/config.py
+2
-2
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+1
-1
vllm/lora/layers.py
vllm/lora/layers.py
+1
-1
vllm/model_executor/layers/quantization/__init__.py
vllm/model_executor/layers/quantization/__init__.py
+0
-2
vllm/model_executor/layers/quantization/squeezellm.py
vllm/model_executor/layers/quantization/squeezellm.py
+0
-138
No files found.
CMakeLists.txt
View file @
23f32229
...
@@ -181,7 +181,6 @@ set(VLLM_EXT_SRC
...
@@ -181,7 +181,6 @@ set(VLLM_EXT_SRC
"csrc/pos_encoding_kernels.cu"
"csrc/pos_encoding_kernels.cu"
"csrc/activation_kernels.cu"
"csrc/activation_kernels.cu"
"csrc/layernorm_kernels.cu"
"csrc/layernorm_kernels.cu"
"csrc/quantization/squeezellm/quant_cuda_kernel.cu"
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
"csrc/quantization/fp8/common.cu"
"csrc/quantization/fp8/common.cu"
...
...
csrc/ops.h
View file @
23f32229
...
@@ -170,9 +170,6 @@ void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
...
@@ -170,9 +170,6 @@ void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
void
dynamic_scaled_int8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
input
,
void
dynamic_scaled_int8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
&
scales
);
torch
::
Tensor
&
scales
);
void
squeezellm_gemm
(
torch
::
Tensor
vec
,
torch
::
Tensor
mat
,
torch
::
Tensor
mul
,
torch
::
Tensor
lookup_table
);
torch
::
Tensor
gptq_gemm
(
torch
::
Tensor
a
,
torch
::
Tensor
b_q_weight
,
torch
::
Tensor
gptq_gemm
(
torch
::
Tensor
a
,
torch
::
Tensor
b_q_weight
,
torch
::
Tensor
b_gptq_qzeros
,
torch
::
Tensor
b_gptq_qzeros
,
torch
::
Tensor
b_gptq_scales
,
torch
::
Tensor
b_g_idx
,
torch
::
Tensor
b_gptq_scales
,
torch
::
Tensor
b_g_idx
,
...
...
csrc/quantization/squeezellm/quant_cuda_kernel.cu
deleted
100644 → 0
View file @
9db52eab
#include <torch/all.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
// half-tensor
#include <c10/cuda/CUDAStream.h>
#include <ATen/cuda/CUDATensorMethods.cuh>
#include <c10/cuda/CUDAGuard.h>
#define BLOCKWIDTH 128
#define BLOCKHEIGHT4 16
namespace
vllm
{
namespace
squeezellm
{
// Returns the 32-bit pattern of `i` as an unsigned value so that the packed
// 4-bit fields can be extracted with logical (zero-filling) right shifts.
__device__ inline unsigned int as_unsigned(int i) {
  // Signed -> unsigned conversion is defined modulo 2^32, which yields the
  // same bits the original pointer reinterpretation produced.
  return static_cast<unsigned int>(i);
}
// LUT-based 4-bit matrix-vector multiply.
//
// Work split used by the body: blockIdx.x selects a group of BLOCKHEIGHT4
// consecutive rows of `mat`, blockIdx.y selects a group of BLOCKWIDTH
// consecutive columns, and each thread of the block owns one column.
//
//   vec          - input values read as packed pairs of halves; batch entry b
//                  starts at pair index b * vec_height / 2.
//   mat          - packed weights; every 32-bit word carries eight 4-bit
//                  lookup-table indices. Consecutive rows are `width` apart.
//   mul          - output, accumulated with atomicAdd (half2 on CUDA, float2
//                  on ROCm); batch entry b starts at pair index b * width / 2.
//   lookup_table - 16 values per column; entry v of column c is read from
//                  index c * 16 + v.
//   height       - not read by the kernel body.
//   width        - row stride of `mat` / number of output columns.
//   batch        - number of input vectors processed.
//   vec_height   - number of elements in one input vector.
//
// The kernel performs no bounds checks on any of the indices it computes.
__global__ void NUQ4MatMulKernel(
#ifndef USE_ROCM
    const half2* __restrict__ vec,
#else
    const __half2* __restrict__ vec,
#endif
    const int* __restrict__ mat,
#ifndef USE_ROCM
    half2* __restrict__ mul,
#else
    float2* __restrict__ mul,
#endif
    const __half* __restrict__ lookup_table, int height, int width, int batch,
    int vec_height) {
  // Packed pair-of-halves type for the current platform.
#ifndef USE_ROCM
  using pair_t = half2;
#else
  using pair_t = __half2;
#endif

  constexpr int kPairsPerBlock = BLOCKWIDTH / 2;

  const int first_row = BLOCKHEIGHT4 * blockIdx.x;
  const int column = BLOCKWIDTH * blockIdx.y + threadIdx.x;
  const int lane = threadIdx.x;

  // Slice of the input vector consumed by this block's rows.
  __shared__ pair_t blockvec[kPairsPerBlock];
  // Per-thread copy of its column's 16 lookup-table entries.
  __shared__ __half deq2[16][BLOCKWIDTH];

  for (int entry = 0; entry < 16; ++entry) {
    deq2[entry][lane] = lookup_table[column * 16 + entry];
  }

  for (int b = 0; b < batch; ++b) {
    __half acc = __int2half_rd(0);

    // First barrier: nobody may still be reading the previous batch entry's
    // blockvec. Second barrier: the new contents are visible to all threads.
    __syncthreads();
    if (threadIdx.x < kPairsPerBlock) {
      blockvec[threadIdx.x] =
          vec[b * vec_height / 2 + (first_row / BLOCKHEIGHT4) * kPairsPerBlock +
              threadIdx.x];
    }
    __syncthreads();

    int mat_index = width * first_row + column;
    for (int k = 0; k < kPairsPerBlock; k += 4, mat_index += width) {
      const unsigned int packed = as_unsigned(mat[mat_index]);

      pair_t pair_acc;
      pair_t weights;
#ifndef USE_ROCM
      pair_acc = {};
      weights = {};
#else
      pair_acc.x = __half_as_ushort(__float2half(0));
      pair_acc.y = __half_as_ushort(__float2half(0));
      weights.x = __half_as_ushort(__float2half(0));
      weights.y = __half_as_ushort(__float2half(0));
#endif

      // Each byte of `packed` holds two 4-bit indices (low nibble first);
      // byte j pairs with input pair k + j.
#pragma unroll
      for (int j = 0; j < 4; ++j) {
        const unsigned int low_index = (packed >> (8 * j)) & 0xF;
        const unsigned int high_index = (packed >> (8 * j + 4)) & 0xF;
#ifndef USE_ROCM
        weights.x = deq2[low_index][lane];
        weights.y = deq2[high_index][lane];
#else
        weights.x = __half_as_ushort(deq2[low_index][lane]);
        weights.y = __half_as_ushort(deq2[high_index][lane]);
#endif
        pair_acc = __hfma2(weights, blockvec[k + j], pair_acc);
      }

      // Fold the two lanes of the pair accumulator into the scalar sum.
#ifndef USE_ROCM
      acc = __hadd(__hadd(pair_acc.x, pair_acc.y), acc);
#else
      acc = __hadd(__hadd(__ushort_as_half(pair_acc.x),
                          __ushort_as_half(pair_acc.y)),
                   acc);
#endif
    }

    // The output is addressed in pairs, so place this column's sum in the
    // lane selected by col % 2 and leave the other lane at zero.
    const bool even_column = (column % 2 == 0);
#ifndef USE_ROCM
    pair_t contribution = {};
    if (even_column) {
      contribution.x = acc;
    } else {
      contribution.y = acc;
    }
    atomicAdd(&mul[b * width / 2 + column / 2], contribution);
#else
    pair_t contribution;
    contribution.x = __half_as_ushort(__float2half(0));
    contribution.y = __half_as_ushort(__float2half(0));
    if (even_column) {
      contribution.x = __half_as_ushort(acc);
    } else {
      contribution.y = __half_as_ushort(acc);
    }
    const int out_index = b * width / 2 + column / 2;
    atomicAdd(&(mul[out_index].x),
              __half2float(__ushort_as_half(contribution.x)));
    atomicAdd(&(mul[out_index].y),
              __half2float(__ushort_as_half(contribution.y)));
#endif
  }
}
}
// namespace squeezellm
}
// namespace vllm
// Host launcher for the LUT-based 4-bit matvec kernel.
//
//   vec          - 2-D half tensor (batch, vec_height) of input vectors.
//   mat          - 2-D int32 tensor (height, width); each word packs eight
//                  4-bit lookup-table indices.
//   mul          - output tensor the kernel accumulates into with atomicAdd
//                  (half on CUDA, float on ROCm). It is only added to, never
//                  cleared, so the caller provides it zero-initialized.
//   lookup_table - half tensor providing 16 entries per column of `mat`.
//
// The kernel does no bounds checking while the grid below is rounded up, so
// shapes that do not tile exactly would read and write out of bounds. They
// are rejected here instead.
void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
                     torch::Tensor lookup_table) {
  constexpr int kIndicesPerWord = 8;  // eight 4-bit fields per 32-bit word
  constexpr int kLutEntries = 16;     // table entries read per column

  TORCH_CHECK(vec.dim() == 2, "squeezellm_gemm: vec must be 2-D, got ",
              vec.dim(), "-D");
  TORCH_CHECK(mat.dim() == 2, "squeezellm_gemm: mat must be 2-D, got ",
              mat.dim(), "-D");
  // The kernel indexes raw pointers linearly. Inputs can be compacted here;
  // the output must already be dense because it is written in place.
  TORCH_CHECK(mul.is_contiguous(), "squeezellm_gemm: mul must be contiguous");
  vec = vec.contiguous();
  mat = mat.contiguous();
  lookup_table = lookup_table.contiguous();

  const int height = static_cast<int>(mat.size(0));
  const int width = static_cast<int>(mat.size(1));
  const int batch = static_cast<int>(vec.size(0));
  const int vec_height = static_cast<int>(vec.size(1));

  TORCH_CHECK(height % BLOCKHEIGHT4 == 0,
              "squeezellm_gemm: mat.size(0) must be a multiple of ",
              BLOCKHEIGHT4, ", got ", height);
  TORCH_CHECK(width % BLOCKWIDTH == 0,
              "squeezellm_gemm: mat.size(1) must be a multiple of ",
              BLOCKWIDTH, ", got ", width);
  TORCH_CHECK(vec_height == height * kIndicesPerWord,
              "squeezellm_gemm: vec.size(1) (", vec_height,
              ") must equal mat.size(0) * ", kIndicesPerWord, " (",
              height * kIndicesPerWord, ")");
  TORCH_CHECK(mul.numel() >= static_cast<int64_t>(batch) * width,
              "squeezellm_gemm: mul has ", mul.numel(),
              " elements, need at least ",
              static_cast<int64_t>(batch) * width);
  TORCH_CHECK(lookup_table.numel() >= static_cast<int64_t>(width) * kLutEntries,
              "squeezellm_gemm: lookup_table has ", lookup_table.numel(),
              " elements, need at least ",
              static_cast<int64_t>(width) * kLutEntries);

  // Nothing to accumulate; a zero-sized grid is also an invalid launch.
  if (batch == 0 || height == 0 || width == 0) {
    return;
  }

  const dim3 blocks((height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
                    (width + BLOCKWIDTH - 1) / BLOCKWIDTH);
  const dim3 threads(BLOCKWIDTH);

  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  vllm::squeezellm::NUQ4MatMulKernel<<<blocks, threads, 0, stream>>>(
#ifndef USE_ROCM
      reinterpret_cast<const half2*>(vec.data_ptr<at::Half>()),
#else
      reinterpret_cast<const __half2*>(vec.data_ptr<at::Half>()),
#endif
      mat.data_ptr<int>(),
#ifndef USE_ROCM
      reinterpret_cast<half2*>(mul.data_ptr<at::Half>()),
      reinterpret_cast<const __half*>(lookup_table.data_ptr<at::Half>()),
#else
      reinterpret_cast<float2*>(mul.data_ptr<float>()),
      reinterpret_cast<const __half*>(lookup_table.data_ptr<at::Half>()),
#endif
      height, width, batch, vec_height);
}
#undef BLOCKWIDTH
#undef BLOCKHEIGHT4
csrc/torch_bindings.cpp
View file @
23f32229
...
@@ -237,12 +237,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -237,12 +237,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops
.
def
(
"gptq_shuffle(Tensor! q_weight, Tensor q_perm, int bit) -> ()"
);
ops
.
def
(
"gptq_shuffle(Tensor! q_weight, Tensor q_perm, int bit) -> ()"
);
ops
.
impl
(
"gptq_shuffle"
,
torch
::
kCUDA
,
&
gptq_shuffle
);
ops
.
impl
(
"gptq_shuffle"
,
torch
::
kCUDA
,
&
gptq_shuffle
);
// Quantized GEMM for SqueezeLLM.
ops
.
def
(
"squeezellm_gemm(Tensor vec, Tensor mat, Tensor! mul, Tensor "
"lookup_table) -> ()"
);
ops
.
impl
(
"squeezellm_gemm"
,
torch
::
kCUDA
,
&
squeezellm_gemm
);
// Compute FP8 quantized tensor for given scaling factor.
// Compute FP8 quantized tensor for given scaling factor.
ops
.
def
(
ops
.
def
(
"static_scaled_fp8_quant(Tensor! out, Tensor input, Tensor scale) -> ()"
);
"static_scaled_fp8_quant(Tensor! out, Tensor input, Tensor scale) -> ()"
);
...
...
docs/source/quantization/supported_hardware.rst
View file @
23f32229
...
@@ -119,17 +119,6 @@ The table below shows the compatibility of various quantization implementations
...
@@ -119,17 +119,6 @@ The table below shows the compatibility of various quantization implementations
- ✗
- ✗
- ✗
- ✗
- ✗
- ✗
* - SqueezeLLM
- ✅︎
- ✅︎
- ✅︎
- ✅︎
- ✅︎
- ✗
- ✗
- ✗
- ✗
- ✗
Notes:
Notes:
^^^^^^
^^^^^^
...
...
examples/fp8/README.md
View file @
23f32229
...
@@ -62,7 +62,7 @@ This script evaluates the inference throughput of language models using various
...
@@ -62,7 +62,7 @@ This script evaluates the inference throughput of language models using various
python3
benchmarks
/
benchmark_throughput
.
py
--
help
python3
benchmarks
/
benchmark_throughput
.
py
--
help
usage
:
benchmark_throughput
.
py
[
-
h
]
[
--
backend
{
vllm
,
hf
,
mii
}]
[
--
dataset
DATASET
]
[
--
input
-
len
INPUT_LEN
]
[
--
output
-
len
OUTPUT_LEN
]
[
--
model
MODEL
]
usage
:
benchmark_throughput
.
py
[
-
h
]
[
--
backend
{
vllm
,
hf
,
mii
}]
[
--
dataset
DATASET
]
[
--
input
-
len
INPUT_LEN
]
[
--
output
-
len
OUTPUT_LEN
]
[
--
model
MODEL
]
[
--
tokenizer
TOKENIZER
]
[
--
quantization
{
awq
,
gptq
,
squeezellm
,
None
}]
[
--
tensor
-
parallel
-
size
TENSOR_PARALLEL_SIZE
]
[
--
n
N
]
[
--
tokenizer
TOKENIZER
]
[
--
quantization
{
awq
,
gptq
,
None
}]
[
--
tensor
-
parallel
-
size
TENSOR_PARALLEL_SIZE
]
[
--
n
N
]
[
--
use
-
beam
-
search
]
[
--
num
-
prompts
NUM_PROMPTS
]
[
--
seed
SEED
]
[
--
hf
-
max
-
batch
-
size
HF_MAX_BATCH_SIZE
]
[
--
trust
-
remote
-
code
]
[
--
use
-
beam
-
search
]
[
--
num
-
prompts
NUM_PROMPTS
]
[
--
seed
SEED
]
[
--
hf
-
max
-
batch
-
size
HF_MAX_BATCH_SIZE
]
[
--
trust
-
remote
-
code
]
[
--
max
-
model
-
len
MAX_MODEL_LEN
]
[
--
dtype
{
auto
,
half
,
float16
,
bfloat16
,
float
,
float32
}]
[
--
enforce
-
eager
]
[
--
kv
-
cache
-
dtype
{
auto
,
fp8
}]
[
--
max
-
model
-
len
MAX_MODEL_LEN
]
[
--
dtype
{
auto
,
half
,
float16
,
bfloat16
,
float
,
float32
}]
[
--
enforce
-
eager
]
[
--
kv
-
cache
-
dtype
{
auto
,
fp8
}]
[
--
quantization
-
param
-
path
KV_CACHE_quantization_param_path
]
[
--
quantization
-
param
-
path
KV_CACHE_quantization_param_path
]
...
@@ -76,7 +76,7 @@ optional arguments:
...
@@ -76,7 +76,7 @@ optional arguments:
--
output
-
len
OUTPUT_LEN
Output
length
for
each
request
.
Overrides
the
output
length
from
the
dataset
.
--
output
-
len
OUTPUT_LEN
Output
length
for
each
request
.
Overrides
the
output
length
from
the
dataset
.
--
model
MODEL
--
model
MODEL
--
tokenizer
TOKENIZER
--
tokenizer
TOKENIZER
--
quantization
{
awq
,
gptq
,
squeezellm
,
None
},
-
q
{
awq
,
gptq
,
squeezellm
,
None
}
--
quantization
{
awq
,
gptq
,
None
},
-
q
{
awq
,
gptq
,
None
}
--
tensor
-
parallel
-
size
TENSOR_PARALLEL_SIZE
,
-
tp
TENSOR_PARALLEL_SIZE
--
tensor
-
parallel
-
size
TENSOR_PARALLEL_SIZE
,
-
tp
TENSOR_PARALLEL_SIZE
--
n
N
Number
of
generated
sequences
per
prompt
.
--
n
N
Number
of
generated
sequences
per
prompt
.
--
use
-
beam
-
search
--
use
-
beam
-
search
...
...
vllm/_custom_ops.py
View file @
23f32229
...
@@ -209,12 +209,6 @@ def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
...
@@ -209,12 +209,6 @@ def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
torch
.
ops
.
_C
.
gptq_shuffle
(
q_weight
,
q_perm
,
bit
)
torch
.
ops
.
_C
.
gptq_shuffle
(
q_weight
,
q_perm
,
bit
)
# squeezellm
def
squeezellm_gemm
(
vec
:
torch
.
Tensor
,
mat
:
torch
.
Tensor
,
mul
:
torch
.
Tensor
,
lookup_table
:
torch
.
Tensor
)
->
None
:
torch
.
ops
.
_C
.
squeezellm_gemm
(
vec
,
mat
,
mul
,
lookup_table
)
# marlin
# marlin
def
marlin_gemm
(
a
:
torch
.
Tensor
,
b_q_weight
:
torch
.
Tensor
,
def
marlin_gemm
(
a
:
torch
.
Tensor
,
b_q_weight
:
torch
.
Tensor
,
b_scales
:
torch
.
Tensor
,
workspace
:
torch
.
Tensor
,
size_m
:
int
,
b_scales
:
torch
.
Tensor
,
workspace
:
torch
.
Tensor
,
size_m
:
int
,
...
...
vllm/config.py
View file @
23f32229
...
@@ -277,7 +277,7 @@ class ModelConfig:
...
@@ -277,7 +277,7 @@ class ModelConfig:
def
_verify_quantization
(
self
)
->
None
:
def
_verify_quantization
(
self
)
->
None
:
supported_quantization
=
[
*
QUANTIZATION_METHODS
]
supported_quantization
=
[
*
QUANTIZATION_METHODS
]
rocm_supported_quantization
=
[
"awq"
,
"gptq"
,
"squeezellm"
,
"fp8"
]
rocm_supported_quantization
=
[
"awq"
,
"gptq"
,
"fp8"
]
optimized_quantization_methods
=
[
optimized_quantization_methods
=
[
"fp8"
,
"marlin"
,
"gptq_marlin_24"
,
"gptq_marlin"
,
"awq_marlin"
,
"fp8"
,
"marlin"
,
"gptq_marlin_24"
,
"gptq_marlin"
,
"awq_marlin"
,
"fbgemm_fp8"
,
"compressed_tensors"
,
"compressed-tensors"
,
"fbgemm_fp8"
,
"compressed_tensors"
,
"compressed-tensors"
,
...
@@ -1537,7 +1537,7 @@ class LoRAConfig:
...
@@ -1537,7 +1537,7 @@ class LoRAConfig:
if
model_config
.
quantization
and
model_config
.
quantization
not
in
[
if
model_config
.
quantization
and
model_config
.
quantization
not
in
[
"awq"
,
"gptq"
"awq"
,
"gptq"
]:
]:
# TODO support marlin
and squeezellm
# TODO support marlin
logger
.
warning
(
"%s quantization is not tested with LoRA yet."
,
logger
.
warning
(
"%s quantization is not tested with LoRA yet."
,
model_config
.
quantization
)
model_config
.
quantization
)
...
...
vllm/entrypoints/llm.py
View file @
23f32229
...
@@ -55,7 +55,7 @@ class LLM:
...
@@ -55,7 +55,7 @@ class LLM:
However, if the `torch_dtype` in the config is `float32`, we will
However, if the `torch_dtype` in the config is `float32`, we will
use `float16` instead.
use `float16` instead.
quantization: The method used to quantize the model weights. Currently,
quantization: The method used to quantize the model weights. Currently,
we support "awq", "gptq",
"squeezellm",
and "fp8" (experimental).
we support "awq", "gptq", and "fp8" (experimental).
If None, we first check the `quantization_config` attribute in the
If None, we first check the `quantization_config` attribute in the
model config file. If that is None, we assume the model weights are
model config file. If that is None, we assume the model weights are
not quantized and use `dtype` to determine the data type of
not quantized and use `dtype` to determine the data type of
...
...
vllm/lora/layers.py
View file @
23f32229
...
@@ -39,7 +39,7 @@ def _get_lora_device(base_layer: nn.Module) -> torch.device:
...
@@ -39,7 +39,7 @@ def _get_lora_device(base_layer: nn.Module) -> torch.device:
# unquantizedLinear
# unquantizedLinear
if
hasattr
(
base_layer
,
"weight"
):
if
hasattr
(
base_layer
,
"weight"
):
return
base_layer
.
weight
.
device
return
base_layer
.
weight
.
device
# GPTQ/AWQ
/SqueezeLLM
# GPTQ/AWQ
elif
hasattr
(
base_layer
,
"qweight"
):
elif
hasattr
(
base_layer
,
"qweight"
):
return
base_layer
.
qweight
.
device
return
base_layer
.
qweight
.
device
# marlin
# marlin
...
...
vllm/model_executor/layers/quantization/__init__.py
View file @
23f32229
...
@@ -25,7 +25,6 @@ from vllm.model_executor.layers.quantization.marlin import MarlinConfig
...
@@ -25,7 +25,6 @@ from vllm.model_executor.layers.quantization.marlin import MarlinConfig
from
vllm.model_executor.layers.quantization.neuron_quant
import
(
from
vllm.model_executor.layers.quantization.neuron_quant
import
(
NeuronQuantConfig
)
NeuronQuantConfig
)
from
vllm.model_executor.layers.quantization.qqq
import
QQQConfig
from
vllm.model_executor.layers.quantization.qqq
import
QQQConfig
from
vllm.model_executor.layers.quantization.squeezellm
import
SqueezeLLMConfig
from
vllm.model_executor.layers.quantization.tpu_int8
import
Int8TpuConfig
from
vllm.model_executor.layers.quantization.tpu_int8
import
Int8TpuConfig
QUANTIZATION_METHODS
:
Dict
[
str
,
Type
[
QuantizationConfig
]]
=
{
QUANTIZATION_METHODS
:
Dict
[
str
,
Type
[
QuantizationConfig
]]
=
{
...
@@ -43,7 +42,6 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
...
@@ -43,7 +42,6 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
"gptq_marlin"
:
GPTQMarlinConfig
,
"gptq_marlin"
:
GPTQMarlinConfig
,
"awq_marlin"
:
AWQMarlinConfig
,
"awq_marlin"
:
AWQMarlinConfig
,
"gptq"
:
GPTQConfig
,
"gptq"
:
GPTQConfig
,
"squeezellm"
:
SqueezeLLMConfig
,
"compressed-tensors"
:
CompressedTensorsConfig
,
"compressed-tensors"
:
CompressedTensorsConfig
,
"bitsandbytes"
:
BitsAndBytesConfig
,
"bitsandbytes"
:
BitsAndBytesConfig
,
"qqq"
:
QQQConfig
,
"qqq"
:
QQQConfig
,
...
...
vllm/model_executor/layers/quantization/squeezellm.py
deleted
100644 → 0
View file @
9db52eab
from
typing
import
Any
,
Dict
,
List
,
Optional
import
torch
from
torch.nn.parameter
import
Parameter
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.linear
import
LinearBase
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.utils
import
is_hip
class SqueezeLLMConfig(QuantizationConfig):
    """Quantization settings for SqueezeLLM checkpoints.

    Reference: https://arxiv.org/pdf/2306.07629
    """

    def __init__(
        self,
        weight_bits: int,
    ) -> None:
        """Store the weight bit-width and derive the int32 packing factor.

        Raises:
            ValueError: if ``weight_bits`` is anything other than 4.
        """
        self.weight_bits = weight_bits
        supported = self.weight_bits == 4
        if not supported:
            raise ValueError(
                "Currently, only 4-bit weight quantization is supported for "
                f"SqueezeLLM, but got {self.weight_bits} bits.")

        # Number of quantized weights stored in one 32-bit word.
        self.pack_factor = 32 // self.weight_bits

    def __repr__(self) -> str:
        return f"SqueezeLLMConfig(weight_bits={self.weight_bits})"

    def get_name(self) -> str:
        """Name under which this quantization method is identified."""
        return "squeezellm"

    def get_supported_act_dtypes(self) -> List[torch.dtype]:
        """Activation dtypes accepted by this method (half precision only)."""
        supported_dtypes = [torch.half]
        return supported_dtypes

    @classmethod
    def get_min_capability(cls) -> int:
        """Minimum capability value reported for this method."""
        return 70

    @staticmethod
    def get_config_filenames() -> List[str]:
        """File names that may carry the quantization config."""
        return ["quant_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "SqueezeLLMConfig":
        """Build a config from a parsed dict, reading the ``wbits`` key."""
        return cls(cls.get_from_keys(config, ["wbits"]))

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional[QuantizeMethodBase]:
        """Return the linear method for ``LinearBase`` layers, else None."""
        if not isinstance(layer, LinearBase):
            return None
        return SqueezeLLMLinearMethod(self)

    def get_scaled_act_names(self) -> List[str]:
        """No activation names are reported for this method."""
        return []
class SqueezeLLMLinearMethod(QuantizeMethodBase):
    """Linear method for SqueezeLLM.

    Args:
        quant_config: The SqueezeLLM quantization config.
    """

    def __init__(self, quant_config: SqueezeLLMConfig):
        self.quant_config = quant_config

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Register the packed weight and lookup-table parameters on `layer`.

        Creates ``qweight`` (int32, packed along the input dimension) and
        ``lookup_table`` (one row per output channel) and attaches the
        loader attributes to both.

        Raises:
            ValueError: if ``input_size_per_partition`` is not a multiple of
                the pack factor.
        """
        pack_factor = self.quant_config.pack_factor
        if input_size_per_partition % pack_factor != 0:
            raise ValueError(
                "The input size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")

        output_size_per_partition = sum(output_partition_sizes)
        qweight = Parameter(
            torch.empty(
                input_size_per_partition // pack_factor,
                output_size_per_partition,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qweight, {
                "input_dim": 0,
                "output_dim": 1,
                "packed_dim": 0,
                "pack_factor": pack_factor,
            })
        # An n-bit index selects one of 2**n table entries. The previous
        # spelling `weight_bits**2` only matched because n is fixed at 4.
        lookup_table = Parameter(
            torch.empty(
                output_size,
                2**self.quant_config.weight_bits,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(lookup_table, {
            "output_dim": 0,
        })

        layer.register_parameter("qweight", qweight)
        set_weight_attrs(qweight, extra_weight_attrs)
        layer.register_parameter("lookup_table", lookup_table)
        set_weight_attrs(lookup_table, extra_weight_attrs)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Compute the quantized linear output for `x`, adding `bias` if given.

        Returns a float16 tensor shaped like `x` with the last dimension
        replaced by the number of columns of ``layer.qweight``.
        """
        qweight = layer.qweight
        lookup_table = layer.lookup_table
        out_shape = x.shape[:-1] + (qweight.shape[-1], )
        reshaped_x = x.reshape(-1, x.shape[-1])
        # NOTE: The output tensor should be zero-initialized because the
        # kernel accumulates into it. Allocate it on the input's device
        # explicitly rather than relying on the ambient default device.
        if is_hip():
            # The HIP build accumulates in float32; convert afterwards.
            out_f = torch.zeros(out_shape, dtype=torch.float, device=x.device)
            ops.squeezellm_gemm(reshaped_x, qweight, out_f, lookup_table)
            out = out_f.to(dtype=torch.float16)
        else:
            out = torch.zeros(out_shape,
                              dtype=torch.float16,
                              device=x.device)
            ops.squeezellm_gemm(reshaped_x, qweight, out, lookup_table)

        if bias is not None:
            out.add_(bias)
        return out.reshape(out_shape)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment