Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
58ca6632
Unverified
Commit
58ca6632
authored
Jul 18, 2024
by
Robert Shaw
Committed by
GitHub
Jul 18, 2024
Browse files
[ Misc ] Improve Min Capability Checking in `compressed-tensors` (#6522)
parent
4634c872
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
41 additions
and
8 deletions
+41
-8
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
...ers/quantization/compressed_tensors/compressed_tensors.py
+14
-8
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py
...n/compressed_tensors/schemes/compressed_tensors_scheme.py
+7
-0
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py
...pressed_tensors/schemes/compressed_tensors_unquantized.py
+4
-0
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py
...compressed_tensors/schemes/compressed_tensors_w4a16_24.py
+4
-0
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
...compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
+4
-0
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
...ompressed_tensors/schemes/compressed_tensors_w8a8_int8.py
+4
-0
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
...on/compressed_tensors/schemes/compressed_tensors_wNa16.py
+4
-0
No files found.
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
View file @
58ca6632
...
@@ -37,7 +37,7 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -37,7 +37,7 @@ class CompressedTensorsConfig(QuantizationConfig):
@
classmethod
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
def
get_min_capability
(
cls
)
->
int
:
return
7
5
return
7
0
def
get_name
(
self
)
->
str
:
def
get_name
(
self
)
->
str
:
return
"compressed_tensors"
return
"compressed_tensors"
...
@@ -85,12 +85,13 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -85,12 +85,13 @@ class CompressedTensorsConfig(QuantizationConfig):
def
get_config_filenames
(
cls
)
->
List
[
str
]:
def
get_config_filenames
(
cls
)
->
List
[
str
]:
return
[]
return
[]
def
_check_
gptq_and_marlin_can_run
(
self
):
def
_check_
scheme_supported
(
self
,
min_capability
:
int
):
capability
=
current_platform
.
get_device_capability
()
capability
=
current_platform
.
get_device_capability
()
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
if
capability
<
80
:
if
capability
<
min_capability
:
raise
RuntimeError
(
"The quantization config is not supported for "
,
raise
RuntimeError
(
"the current GPU. Minimum capability: 80. "
,
"Quantization scheme is not supported for "
,
f
"the current GPU. Min capability:
{
min_capability
}
. "
,
f
"Current capability:
{
capability
}
."
)
f
"Current capability:
{
capability
}
."
)
def
_is_static_tensor_w8a8
(
self
,
weight_quant
:
BaseModel
,
def
_is_static_tensor_w8a8
(
self
,
weight_quant
:
BaseModel
,
...
@@ -171,7 +172,6 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -171,7 +172,6 @@ class CompressedTensorsConfig(QuantizationConfig):
# Detect If Mixed Precision
# Detect If Mixed Precision
if
self
.
_is_wNa16_group_channel
(
weight_quant
,
input_quant
):
if
self
.
_is_wNa16_group_channel
(
weight_quant
,
input_quant
):
self
.
_check_gptq_and_marlin_can_run
()
if
(
self
.
quant_format
==
CompressionFormat
.
marlin_24
.
value
if
(
self
.
quant_format
==
CompressionFormat
.
marlin_24
.
value
and
weight_quant
.
num_bits
in
W4A16SPARSE24_SUPPORTED_BITS
):
and
weight_quant
.
num_bits
in
W4A16SPARSE24_SUPPORTED_BITS
):
return
CompressedTensorsW4A16Sparse24
(
return
CompressedTensorsW4A16Sparse24
(
...
@@ -222,10 +222,16 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -222,10 +222,16 @@ class CompressedTensorsConfig(QuantizationConfig):
raise
ValueError
(
raise
ValueError
(
f
"Could not find quantization details for
{
layer
}
."
)
f
"Could not find quantization details for
{
layer
}
."
)
return
self
.
_get_schema
(
scheme
=
self
.
_get_schema
(
weight_quant
=
layer_quant_details
[
"weights"
],
weight_quant
=
layer_quant_details
[
"weights"
],
input_quant
=
layer_quant_details
[
"input_activations"
])
input_quant
=
layer_quant_details
[
"input_activations"
])
# Raise error if device does not support the scheme
# (e.g. fp8 needs ada lovelace)
self
.
_check_scheme_supported
(
scheme
.
get_min_capability
())
return
scheme
class
CompressedTensorsLinearMethod
(
LinearMethodBase
):
class
CompressedTensorsLinearMethod
(
LinearMethodBase
):
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py
View file @
58ca6632
...
@@ -12,6 +12,13 @@ class CompressedTensorsScheme(ABC):
...
@@ -12,6 +12,13 @@ class CompressedTensorsScheme(ABC):
of different quantization schemes supported by CompressedTensors.
of different quantization schemes supported by CompressedTensors.
"""
"""
@
abstractmethod
def
get_min_capability
(
self
)
->
int
:
"""
Get minimum device capability.
"""
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
def
create_weights
(
self
,
*
args
,
**
kwargs
):
def
create_weights
(
self
,
*
args
,
**
kwargs
):
"""
"""
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py
View file @
58ca6632
...
@@ -18,6 +18,10 @@ class CompressedTensorsUnquantized(CompressedTensorsScheme):
...
@@ -18,6 +18,10 @@ class CompressedTensorsUnquantized(CompressedTensorsScheme):
in a linear transformation.
in a linear transformation.
"""
"""
def
get_min_capability
(
self
)
->
int
:
# volta and up
return
70
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
pass
pass
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py
View file @
58ca6632
...
@@ -29,6 +29,10 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
...
@@ -29,6 +29,10 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
raise
ValueError
(
raise
ValueError
(
"group_size must be given when using strategy group"
)
"group_size must be given when using strategy group"
)
def
get_min_capability
(
self
)
->
int
:
# ampere + up
return
80
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
pass
pass
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
View file @
58ca6632
...
@@ -33,6 +33,10 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
...
@@ -33,6 +33,10 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
"Consider quantizing with per tensor scales or upgrading "
"Consider quantizing with per tensor scales or upgrading "
"to Hopper."
)
"to Hopper."
)
def
get_min_capability
(
self
)
->
int
:
# lovelace and up
return
89
def
process_weights_after_loading
(
self
,
layer
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
)
->
None
:
# If per tensor, when we have a fused module (e.g. QKV) with per
# If per tensor, when we have a fused module (e.g. QKV) with per
# tensor scales (thus N scales being passed to the kernel),
# tensor scales (thus N scales being passed to the kernel),
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
View file @
58ca6632
...
@@ -19,6 +19,10 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
...
@@ -19,6 +19,10 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
self
.
strategy
=
strategy
self
.
strategy
=
strategy
self
.
is_static_input_scheme
=
is_static_input_scheme
self
.
is_static_input_scheme
=
is_static_input_scheme
def
get_min_capability
(
self
)
->
int
:
# turing and up
return
75
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
# WEIGHT
# WEIGHT
# Cutlass kernels need transposed weight.
# Cutlass kernels need transposed weight.
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
View file @
58ca6632
...
@@ -42,6 +42,10 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
...
@@ -42,6 +42,10 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
group_size
=
self
.
group_size
,
group_size
=
self
.
group_size
,
is_sym
=
True
)
is_sym
=
True
)
def
get_min_capability
(
self
)
->
int
:
# ampere and up
return
80
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
input_size
:
int
,
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
input_size
:
int
,
output_partition_sizes
:
List
[
int
],
output_partition_sizes
:
List
[
int
],
input_size_per_partition
:
int
,
input_size_per_partition
:
int
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment