Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
aea19f09
Unverified
Commit
aea19f09
authored
Jul 12, 2024
by
Robert Shaw
Committed by
GitHub
Jul 12, 2024
Browse files
[ Misc ] Support Models With Bias in `compressed-tensors` integration (#6356)
parent
f7160d94
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
58 additions
and
21 deletions
+58
-21
.buildkite/lm-eval-harness/configs/Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml
.../configs/Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml
+11
-0
.buildkite/lm-eval-harness/configs/Qwen2-1.5B-Instruct-W8A16-compressed-tensors.yaml
...configs/Qwen2-1.5B-Instruct-W8A16-compressed-tensors.yaml
+11
-0
.buildkite/lm-eval-harness/configs/models-small.txt
.buildkite/lm-eval-harness/configs/models-small.txt
+1
-0
.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh
.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh
+1
-1
tests/models/test_compressed_tensors.py
tests/models/test_compressed_tensors.py
+3
-0
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
...ers/quantization/compressed_tensors/compressed_tensors.py
+1
-4
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py
...n/compressed_tensors/schemes/compressed_tensors_scheme.py
+5
-2
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py
...pressed_tensors/schemes/compressed_tensors_unquantized.py
+5
-4
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py
...compressed_tensors/schemes/compressed_tensors_w4a16_24.py
+7
-1
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
...ompressed_tensors/schemes/compressed_tensors_w8a8_int8.py
+6
-3
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
...on/compressed_tensors/schemes/compressed_tensors_wNa16.py
+5
-2
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+2
-4
No files found.
.buildkite/lm-eval-harness/configs/Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml
0 → 100644
View file @
aea19f09
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m neuralmagic/Qwen2-1.5B-Instruct-quantized.w8a8 -b "auto" -l 1000 -f 5 -t 1
model_name: "neuralmagic/Qwen2-1.5B-Instruct-quantized.w8a8"
tasks:
- name: "gsm8k"
  metrics:
  - name: "exact_match,strict-match"
    value: 0.593
  - name: "exact_match,flexible-extract"
    value: 0.588
limit: 1000
num_fewshot: 5
.buildkite/lm-eval-harness/configs/Qwen2-1.5B-Instruct-W8A16-compressed-tensors.yaml
0 → 100644
View file @
aea19f09
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Qwen2-1.5B-Instruct-W8A16-Channelwise -b "auto" -l 1000 -f 5 -t 1
model_name: "nm-testing/Qwen2-1.5B-Instruct-W8A16-Channelwise"
tasks:
- name: "gsm8k"
  metrics:
  - name: "exact_match,strict-match"
    value: 0.595
  - name: "exact_match,flexible-extract"
    value: 0.582
limit: 1000
num_fewshot: 5
.buildkite/lm-eval-harness/configs/models-small.txt
View file @
aea19f09
...
...
@@ -2,3 +2,4 @@ Meta-Llama-3-8B-Instruct.yaml
Meta-Llama-3-8B-Instruct-FP8.yaml
Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml
Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml
.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh
View file @
aea19f09
...
...
@@ -46,6 +46,6 @@ while getopts "m:b:l:f:t:" OPT; do
done
lm_eval
--model
vllm
\
--model_args
pretrained
=
$MODEL
,tensor_parallel_size
=
$TP_SIZE
,add_bos_token
=
true
\
--model_args
pretrained
=
$MODEL
,tensor_parallel_size
=
$TP_SIZE
,add_bos_token
=
true
,distributed_executor_backend
=
"ray"
\
--tasks
gsm8k
--num_fewshot
$FEWSHOT
--limit
$LIMIT
\
--batch_size
$BATCH_SIZE
tests/models/test_compressed_tensors.py
View file @
aea19f09
...
...
@@ -12,7 +12,10 @@ from tests.quantization.utils import is_quant_method_supported
from
.utils
import
check_logprobs_close
MODELS
=
[
# No bias
"nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Per-Token-Test"
,
# Bias
"neuralmagic/Qwen2-1.5B-Instruct-quantized.w8a8"
]
MAX_TOKENS
=
32
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
View file @
aea19f09
...
...
@@ -267,10 +267,7 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
"""
if
bias
is
not
None
:
raise
ValueError
(
"bias is not supported for this linear method"
)
scheme
=
layer
.
scheme
if
scheme
is
None
:
raise
ValueError
(
"A scheme must be defined for each layer"
)
return
scheme
.
apply_weights
(
layer
,
x
)
return
scheme
.
apply_weights
(
layer
,
x
,
bias
=
bias
)
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py
View file @
aea19f09
from
abc
import
ABC
,
abstractmethod
from
typing
import
Optional
import
torch
...
...
@@ -20,14 +21,16 @@ class CompressedTensorsScheme(ABC):
raise
NotImplementedError
@
abstractmethod
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
):
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]):
"""
Run the forward pass for the particular scheme. This is where
scheme-specific dequant/quant steps/kernels should be applied.
:param layer: toch.nn.Module with the registered weights and
:param layer: torch.nn.Module with the registered weights and
other parameters relevant to the particular scheme.
:param x: input to the layer
:param bias: bias parameter
"""
raise
NotImplementedError
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py
View file @
aea19f09
from
typing
import
Callable
,
List
from
typing
import
Callable
,
List
,
Optional
import
torch
import
torch.nn.functional
as
F
...
...
@@ -37,6 +37,7 @@ class CompressedTensorsUnquantized(CompressedTensorsScheme):
layer
.
register_parameter
(
"weight"
,
weight
)
set_weight_attrs
(
weight
,
{
"weight_loader"
:
weight_loader
})
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
):
weight
=
layer
.
weight
return
F
.
linear
(
x
,
weight
)
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
return
F
.
linear
(
x
,
layer
.
weight
,
bias
)
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py
View file @
aea19f09
...
...
@@ -118,7 +118,9 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
requires_grad
=
False
)
layer
.
workspace
=
workspace
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
):
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
qweight
=
layer
.
weight_packed
meta
=
layer
.
meta
scales
=
layer
.
scale_packed
...
...
@@ -135,4 +137,8 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
size_n
,
size_k
)
output
=
output_2d
.
view
(
x
.
shape
[:
-
1
]
+
(
output_2d
.
shape
[
1
],
))
if
bias
is
not
None
:
output
.
add_
(
bias
)
# In-place add
return
output
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
View file @
aea19f09
from
typing
import
Callable
,
List
from
typing
import
Callable
,
List
,
Optional
import
torch
from
torch.nn
import
Parameter
...
...
@@ -78,8 +78,11 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
**
layer_kwargs
)
layer
.
register_parameter
(
"input_scale"
,
scale
)
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
):
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
return
apply_int8_linear
(
input
=
x
,
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale
,
input_scale
=
layer
.
input_scale
)
input_scale
=
layer
.
input_scale
,
bias
=
bias
)
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
View file @
aea19f09
...
...
@@ -148,7 +148,9 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
group_size
=
layer
.
group_size
)
replace_tensor
(
layer
,
"weight_scale"
,
marlin_scales
)
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
):
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
return
apply_marlin_linear
(
input
=
x
,
weight
=
layer
.
weight_packed
,
...
...
@@ -159,4 +161,5 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
num_bits
=
self
.
num_bits
,
output_size_per_partition
=
layer
.
output_size_per_partition
,
input_size_per_partition
=
layer
.
input_size_per_partition
,
is_k_full
=
True
)
is_k_full
=
True
,
bias
=
bias
)
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
View file @
aea19f09
...
...
@@ -148,9 +148,6 @@ def apply_int8_linear(
input_scale
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
):
if
bias
is
not
None
:
raise
NotImplementedError
(
"W8A8 with int8 does not yet support bias."
)
# ops.scaled_int8_quant supports both dynamic and static quant.
# * dynamic, layer.input_scale is None and x_scale computed from x.
# * static, layer.input_scale is scalar and x_scale is input_scale.
...
...
@@ -160,4 +157,5 @@ def apply_int8_linear(
weight
,
scale_a
=
x_scale
,
scale_b
=
weight_scale
,
out_dtype
=
input
.
dtype
)
out_dtype
=
input
.
dtype
,
bias
=
bias
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment