Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3b2005e1
Unverified
Commit
3b2005e1
authored
Feb 05, 2025
by
Rahul Tuli
Committed by
GitHub
Feb 05, 2025
Browse files
Add: Support for Sparse24Bitmask Compressed Models
parent
af8486de
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
503 additions
and
112 deletions
+503
-112
.buildkite/lm-eval-harness/configs/SparseLlama3.1_2of4_fp8_compressed.yaml
...l-harness/configs/SparseLlama3.1_2of4_fp8_compressed.yaml
+11
-0
tests/quantization/test_compressed_tensors.py
tests/quantization/test_compressed_tensors.py
+276
-56
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
...ers/quantization/compressed_tensors/compressed_tensors.py
+26
-8
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py
...ation/compressed_tensors/schemes/compressed_tensors_24.py
+190
-48
No files found.
.buildkite/lm-eval-harness/configs/SparseLlama3.1_2of4_fp8_compressed.yaml
0 → 100644
View file @
3b2005e1
# bash ./run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM -b "auto" -t 2
model_name
:
"
nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM"
tasks
:
-
name
:
"
gsm8k"
metrics
:
-
name
:
"
exact_match,strict-match"
value
:
0.6353
-
name
:
"
exact_match,flexible-extract"
value
:
0.637
limit
:
null
num_fewshot
:
null
tests/quantization/test_compressed_tensors.py
View file @
3b2005e1
...
...
@@ -3,6 +3,7 @@
Run `pytest tests/quantization/test_compressed_tensors.py`.
"""
from
typing
import
Optional
import
pytest
...
...
@@ -22,12 +23,30 @@ from vllm.platforms import current_platform
@
pytest
.
mark
.
parametrize
(
"model_args"
,
[(
"nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change"
,
"tensor"
,
QuantizationType
.
INT
,
2560
,
True
),
(
"nm-testing/tinyllama-oneshot-w8-channel-a8-tensor"
,
"channel"
,
QuantizationType
.
INT
,
2560
,
True
),
(
"nm-testing/asym-w8w8-int8-static-per-tensor-tiny-llama"
,
"tensor"
,
QuantizationType
.
INT
,
2560
,
False
)])
[
(
"nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change"
,
"tensor"
,
QuantizationType
.
INT
,
2560
,
True
,
),
(
"nm-testing/tinyllama-oneshot-w8-channel-a8-tensor"
,
"channel"
,
QuantizationType
.
INT
,
2560
,
True
,
),
(
"nm-testing/asym-w8w8-int8-static-per-tensor-tiny-llama"
,
"tensor"
,
QuantizationType
.
INT
,
2560
,
False
,
),
],
)
def
test_compressed_tensors_w8a8_static_setup
(
vllm_runner
,
model_args
):
model_path
,
strategy
,
quant_type
,
shape_0
,
is_symmetric
=
model_args
with
vllm_runner
(
model_path
,
enforce_eager
=
True
)
as
llm
:
...
...
@@ -85,21 +104,31 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
assert
output
@
pytest
.
mark
.
parametrize
(
"model_path"
,
[
"neuralmagic/Llama-3.2-1B-quantized.w8a8"
,
"nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dynamic-Asym"
,
"nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Sym"
,
"nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym"
])
@
pytest
.
mark
.
parametrize
(
"model_path"
,
[
"neuralmagic/Llama-3.2-1B-quantized.w8a8"
,
"nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dynamic-Asym"
,
"nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Sym"
,
"nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym"
,
],
)
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
32
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
10
])
def
test_compressed_tensors_w8a8_logprobs
(
hf_runner
,
vllm_runner
,
example_prompts
,
model_path
,
max_tokens
,
num_logprobs
):
def
test_compressed_tensors_w8a8_logprobs
(
hf_runner
,
vllm_runner
,
example_prompts
,
model_path
,
max_tokens
,
num_logprobs
,
):
dtype
=
"bfloat16"
# skip language translation prompt for the static per tensor asym model
if
model_path
==
"nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym"
:
# noqa: E501
if
(
model_path
==
"nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym"
):
# noqa: E501
example_prompts
=
example_prompts
[
0
:
-
1
]
with
hf_runner
(
model_path
,
dtype
=
dtype
)
as
hf_model
:
...
...
@@ -125,13 +154,21 @@ def test_compressed_tensors_no_enforce_eager(vllm_runner):
assert
output
@
pytest
.
mark
.
parametrize
(
"model_args"
,
[
(
"nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2"
,
"tensor"
),
(
"nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2-asym"
,
"tensor"
),
(
"nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2"
,
"channel"
),
(
"nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2-asym"
,
"channel"
),
])
@
pytest
.
mark
.
parametrize
(
"model_args"
,
[
(
"nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2"
,
"tensor"
),
(
"nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2-asym"
,
"tensor"
),
(
"nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2"
,
"channel"
,
),
(
"nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2-asym"
,
"channel"
,
),
],
)
def
test_compressed_tensors_w8a8_dynamic_per_token
(
vllm_runner
,
model_args
):
model_path
,
strategy
=
model_args
with
vllm_runner
(
model_path
,
dtype
=
torch
.
float16
)
as
llm
:
...
...
@@ -156,9 +193,12 @@ def test_compressed_tensors_w8a8_dynamic_per_token(vllm_runner, model_args):
@
pytest
.
mark
.
parametrize
(
"wNa16_args"
,
[(
"nm-testing/tinyllama-oneshot-w4a16-channel-v2"
,
"channel"
,
None
,
8
),
(
"nm-testing/tinyllama-oneshot-w4a16-group128-v2"
,
"group"
,
128
,
8
),
(
"nm-testing/tinyllama-oneshot-w8a16-per-channel"
,
"channel"
,
None
,
4
)])
[
(
"nm-testing/tinyllama-oneshot-w4a16-channel-v2"
,
"channel"
,
None
,
8
),
(
"nm-testing/tinyllama-oneshot-w4a16-group128-v2"
,
"group"
,
128
,
8
),
(
"nm-testing/tinyllama-oneshot-w8a16-per-channel"
,
"channel"
,
None
,
4
),
],
)
def
test_compressed_tensors_wNa16
(
vllm_runner
,
wNa16_args
):
model
,
strategy
,
group
,
pack_factor
=
wNa16_args
with
vllm_runner
(
model
)
as
llm
:
...
...
@@ -218,7 +258,8 @@ def test_compressed_tensors_fp8(vllm_runner):
CompressedTensorsLinearMethod
)
assert
isinstance
(
qkv_proj
.
scheme
,
(
CompressedTensorsW8A8Fp8
,
CompressedTensorsW8A16Fp8
))
(
CompressedTensorsW8A8Fp8
,
CompressedTensorsW8A16Fp8
),
)
assert
qkv_proj
.
input_scale
.
dtype
is
torch
.
float32
...
...
@@ -241,9 +282,14 @@ def test_compressed_tensors_kv_cache(vllm_runner):
assert
output
@
pytest
.
mark
.
skipif
(
not
sparse_cutlass_supported
(),
reason
=
"Sparse FP8 is not yet supported on this GPU type."
)
def
_test_2of4_quant_models
(
qkv_proj
,
weight_strategy
,
input_strategy
):
@
pytest
.
mark
.
skipif
(
not
sparse_cutlass_supported
(),
reason
=
"Sparse FP8 is not yet supported on this GPU type."
,
)
def
_test_2of4_quant_models
(
qkv_proj
,
weight_strategy
,
input_strategy
,
format
=
"dense"
):
assert
isinstance
(
qkv_proj
.
quant_method
,
CompressedTensorsLinearMethod
)
assert
isinstance
(
qkv_proj
.
scheme
,
CompressedTensors24
)
...
...
@@ -252,22 +298,39 @@ def _test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy):
assert
qkv_proj
.
scheme
.
quantized
assert
qkv_proj
.
quant_method
.
quantization_config
.
sparsity_scheme_map
sparsity_map
=
qkv_proj
.
quant_method
.
quantization_config
.
sparsity_scheme_map
# noqa: E501
assert
sparsity_map
.
get
(
"Linear"
).
format
==
"dense"
assert
sparsity_map
.
get
(
"Linear"
).
format
==
format
assert
sparsity_map
.
get
(
"Linear"
).
sparsity_structure
==
"2:4"
@
pytest
.
mark
.
skipif
(
not
current_platform
.
has_device_capability
(
90
),
reason
=
"Sparse FP8 is not yet supported on this GPU type."
)
@
pytest
.
mark
.
parametrize
(
"args_2of4"
,
[
(
"nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-2of4-testing"
,
"channel"
,
"token"
),
(
"nm-testing/Meta-Llama-3-8B-Instruct-FP8-Static-Per-Tensor-testing"
,
"channel"
,
"tensor"
),
(
"nm-testing/Meta-Llama-3-8B-Instruct-FP8-Static-testing"
,
"tensor"
,
"tensor"
),
(
"nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-IA-Per-Tensor-Weight-testing"
,
"tensor"
,
"token"
),
])
@
pytest
.
mark
.
skipif
(
not
current_platform
.
has_device_capability
(
90
),
reason
=
"Sparse FP8 is not yet supported on this GPU type."
,
)
@
pytest
.
mark
.
parametrize
(
"args_2of4"
,
[
(
"nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-2of4-testing"
,
"channel"
,
"token"
,
),
(
"nm-testing/Meta-Llama-3-8B-Instruct-FP8-Static-Per-Tensor-testing"
,
"channel"
,
"tensor"
,
),
(
"nm-testing/Meta-Llama-3-8B-Instruct-FP8-Static-testing"
,
"tensor"
,
"tensor"
,
),
(
"nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-IA-Per-Tensor-Weight-testing"
,
"tensor"
,
"token"
,
),
],
)
def
test_compressed_tensors_2of4_quant_fp8
(
vllm_runner
,
args_2of4
):
model
,
weight_strategy
,
input_strategy
=
args_2of4
with
vllm_runner
(
model
)
as
llm
:
...
...
@@ -286,16 +349,134 @@ def test_compressed_tensors_2of4_quant_fp8(vllm_runner, args_2of4):
assert
output
@
pytest
.
mark
.
skipif
(
not
sparse_cutlass_supported
(),
reason
=
"Sparse FP8 is not yet supported on this GPU type."
)
@
pytest
.
mark
.
parametrize
(
"args_2of4"
,
[
(
"nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Dynamic-IA-Per-Channel-Weight-testing"
,
"channel"
,
"token"
),
(
"nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Static-testing"
,
"tensor"
,
"tensor"
),
(
"nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Dynamic-IA-Per-Tensor-Weight-testing"
,
"tensor"
,
"token"
),
])
@
pytest
.
mark
.
skipif
(
not
current_platform
.
has_device_capability
(
90
),
reason
=
"Sparse FP8 is not yet supported on this GPU type."
,
)
@
pytest
.
mark
.
parametrize
(
"args_2of4"
,
[
(
"nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM"
,
"channel"
,
"token"
,
),
(
"nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_tensor_act_fp8-BitM"
,
"channel"
,
"tensor"
,
),
(
"nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_per_tok_dyn_act_fp8-BitM"
,
"tensor"
,
"token"
,
),
(
"nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_tensor_act_fp8-BitM"
,
"tensor"
,
"tensor"
,
),
],
)
def
test_compressed_tensors_2of4_quant_fp8_compressed
(
vllm_runner
,
args_2of4
):
model
,
weight_strategy
,
input_strategy
=
args_2of4
with
vllm_runner
(
model
)
as
llm
:
def
check_model
(
model
):
layer
=
model
.
model
.
layers
[
0
]
qkv_proj
=
layer
.
self_attn
.
qkv_proj
assert
qkv_proj
.
scheme
.
weights_dtype
==
torch
.
float8_e4m3fn
_test_2of4_quant_models
(
qkv_proj
,
weight_strategy
,
input_strategy
,
format
=
"sparse-24-bitmask"
,
)
llm
.
apply_model
(
check_model
)
output
=
llm
.
generate_greedy
(
"Hello my name is"
,
max_tokens
=
20
)
print
(
output
)
assert
output
@
pytest
.
mark
.
skipif
(
not
sparse_cutlass_supported
(),
reason
=
"cutlass is not yet supported on this GPU type."
,
)
@
pytest
.
mark
.
parametrize
(
"args_2of4"
,
[
(
"nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_int8-BitM"
,
"channel"
,
"token"
,
),
(
"nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_tensor_act_int8-BitM"
,
"channel"
,
"tensor"
,
),
(
"nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_per_tok_dyn_act_int8-BitM"
,
"tensor"
,
"token"
,
),
(
"nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_tensor_act_int8-BitM"
,
"tensor"
,
"tensor"
,
),
],
)
def
test_compressed_tensors_2of4_quant_int8_compressed
(
vllm_runner
,
args_2of4
):
model
,
weight_strategy
,
input_strategy
=
args_2of4
with
vllm_runner
(
model
)
as
llm
:
def
check_model
(
model
):
layer
=
model
.
model
.
layers
[
0
]
qkv_proj
=
layer
.
self_attn
.
qkv_proj
assert
qkv_proj
.
scheme
.
weights_dtype
==
torch
.
int8
_test_2of4_quant_models
(
qkv_proj
,
weight_strategy
,
input_strategy
,
format
=
"sparse-24-bitmask"
,
)
llm
.
apply_model
(
check_model
)
output
=
llm
.
generate_greedy
(
"Hello my name is"
,
max_tokens
=
20
)
print
(
output
)
assert
output
@
pytest
.
mark
.
skipif
(
not
sparse_cutlass_supported
(),
reason
=
"Sparse FP8 is not yet supported on this GPU type."
,
)
@
pytest
.
mark
.
parametrize
(
"args_2of4"
,
[
(
"nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Dynamic-IA-Per-Channel-Weight-testing"
,
"channel"
,
"token"
,
),
(
"nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Static-testing"
,
"tensor"
,
"tensor"
,
),
(
"nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Dynamic-IA-Per-Tensor-Weight-testing"
,
"tensor"
,
"token"
,
),
],
)
def
test_compressed_tensors_2of4_quant_int8
(
vllm_runner
,
args_2of4
):
model
,
weight_strategy
,
input_strategy
=
args_2of4
with
vllm_runner
(
model
)
as
llm
:
...
...
@@ -317,10 +498,12 @@ def test_compressed_tensors_2of4_quant_int8(vllm_runner, args_2of4):
@
pytest
.
mark
.
skip
(
reason
=
"2of4 sparse w16a16 CUTLASS produces bad output."
)
@
pytest
.
mark
.
skipif
(
not
sparse_cutlass_supported
(),
reason
=
"2of4 Sparse is not yet supported on this GPU type."
)
reason
=
"2of4 Sparse is not yet supported on this GPU type."
,
)
@
pytest
.
mark
.
parametrize
(
"args_2of4"
,
[(
"nm-testing/TinyLlama-1.1B-Chat-v1.0-2of4-Sparse-Dense-Compressor"
)])
[(
"nm-testing/TinyLlama-1.1B-Chat-v1.0-2of4-Sparse-Dense-Compressor"
)],
)
def
test_compressed_tensors_2of4_sparse
(
vllm_runner
,
args_2of4
):
model
=
args_2of4
with
vllm_runner
(
model
)
as
llm
:
...
...
@@ -337,7 +520,9 @@ def test_compressed_tensors_2of4_sparse(vllm_runner, args_2of4):
assert
qkv_proj
.
scheme
.
input_quant
is
None
assert
not
qkv_proj
.
scheme
.
quantized
assert
qkv_proj
.
quant_method
.
quantization_config
.
sparsity_scheme_map
sparsity_map
=
qkv_proj
.
quant_method
.
quantization_config
.
sparsity_scheme_map
# noqa: E501
sparsity_map
=
(
qkv_proj
.
quant_method
.
quantization_config
.
sparsity_scheme_map
)
# noqa: E501
assert
sparsity_map
.
get
(
"Linear"
).
format
==
"dense"
assert
sparsity_map
.
get
(
"Linear"
).
sparsity_structure
==
"2:4"
...
...
@@ -346,3 +531,38 @@ def test_compressed_tensors_2of4_sparse(vllm_runner, args_2of4):
output
=
llm
.
generate_greedy
(
"Hello my name is"
,
max_tokens
=
20
)
print
(
output
)
assert
output
@
pytest
.
mark
.
skipif
(
not
sparse_cutlass_supported
(),
reason
=
"Cutlass is not yet supported on this GPU type."
,
)
@
pytest
.
mark
.
parametrize
(
"args_2of4"
,
[(
"nm-testing/llama2.c-stories42M-pruned2.4-compressed"
)])
def
test_compressed_tensors_2of4_sparse_compressed
(
vllm_runner
,
args_2of4
):
model
=
args_2of4
with
vllm_runner
(
model
)
as
llm
:
def
check_model
(
model
):
layer
=
model
.
model
.
layers
[
0
]
qkv_proj
=
layer
.
self_attn
.
qkv_proj
assert
isinstance
(
qkv_proj
.
quant_method
,
CompressedTensorsLinearMethod
)
assert
isinstance
(
qkv_proj
.
scheme
,
CompressedTensors24
)
assert
qkv_proj
.
scheme
.
weight_quant
is
None
assert
qkv_proj
.
scheme
.
input_quant
is
None
assert
not
qkv_proj
.
scheme
.
quantized
assert
qkv_proj
.
quant_method
.
quantization_config
.
sparsity_scheme_map
sparsity_map
=
(
qkv_proj
.
quant_method
.
quantization_config
.
sparsity_scheme_map
)
# noqa: E501
assert
sparsity_map
.
get
(
"Linear"
).
format
==
"sparse-24-bitmask"
assert
sparsity_map
.
get
(
"Linear"
).
sparsity_structure
==
"2:4"
llm
.
apply_model
(
check_model
)
output
=
llm
.
generate_greedy
(
"Hello my name is"
,
max_tokens
=
20
)
print
(
output
)
assert
output
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
View file @
3b2005e1
...
...
@@ -417,15 +417,22 @@ class CompressedTensorsConfig(QuantizationConfig):
return
None
# Have a valid sparsity scheme
# Validate layer is supported by Cutlass 2:4 Kernel
scheme
=
CompressedTensors24
(
quantized
=
weight_quant
is
not
None
or
input_quant
is
not
None
,
weight_quant
=
weight_quant
,
input_quant
=
input_quant
)
model_compression_config
=
(
None
if
sparsity_scheme
is
None
or
sparsity_scheme
.
format
==
"dense"
else
self
.
config
)
scheme
=
CompressedTensors24
(
quantized
=
weight_quant
is
not
None
or
input_quant
is
not
None
,
weight_quant
=
weight_quant
,
input_quant
=
input_quant
,
model_compression_config
=
model_compression_config
,
)
elif
weight_quant
is
None
:
logger
.
warning_once
(
"Acceleration for non-quantized schemes is "
"not supported by Compressed Tensors. "
"Falling back to UnquantizedLinearMethod"
)
return
None
else
:
# Find the quant_scheme
scheme
=
self
.
_get_scheme_from_parts
(
# type: ignore
...
...
@@ -475,10 +482,21 @@ class CompressedTensorsConfig(QuantizationConfig):
:return: True if the layer is supported by the Cutlass 2:4 Kernel
False otherwise
"""
is_valid_sparsity
=
(
sparsity_scheme
is
not
None
and
sparsity_scheme
.
sparsity_structure
==
SparsityStructure
.
TWO_FOUR
.
value
and
sparsity_scheme
.
format
==
"dense"
)
if
sparsity_scheme
is
None
:
return
False
is_valid_sparsity_structure
:
bool
=
(
sparsity_scheme
.
sparsity_structure
==
SparsityStructure
.
TWO_FOUR
.
value
)
valid_compressors
=
{
CompressionFormat
.
dense
.
value
,
CompressionFormat
.
sparse_24_bitmask
.
value
}
is_valid_sparsity
=
(
is_valid_sparsity_structure
and
sparsity_scheme
.
format
in
valid_compressors
)
if
not
is_valid_sparsity
:
return
False
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py
View file @
3b2005e1
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Callable
,
List
,
Optional
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
import
torch
from
compressed_tensors
import
CompressionFormat
,
ModelCompressor
from
compressed_tensors.quantization
import
(
QuantizationArgs
,
QuantizationStrategy
,
QuantizationType
)
from
compressed_tensors.utils
import
combine_shards
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
QKVParallelLinear
)
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
CompressedTensorsScheme
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
...
...
@@ -22,26 +26,39 @@ __all__ = ["CompressedTensors24"]
class
CompressedTensors24
(
CompressedTensorsScheme
):
def
__init__
(
self
,
quantized
:
bool
=
False
,
weight_quant
:
Optional
[
QuantizationArgs
]
=
None
,
input_quant
:
Optional
[
QuantizationArgs
]
=
None
):
def
__init__
(
self
,
quantized
:
bool
=
False
,
weight_quant
:
Optional
[
QuantizationArgs
]
=
None
,
input_quant
:
Optional
[
QuantizationArgs
]
=
None
,
model_compression_config
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
):
self
.
quantized
=
quantized
self
.
weight_quant
=
weight_quant
self
.
input_quant
=
input_quant
self
.
model_compressor
=
(
ModelCompressor
.
from_compression_config
(
model_compression_config
)
if
model_compression_config
is
not
None
else
None
)
self
.
do_sparse_decompress
=
(
self
.
model_compressor
is
not
None
and
self
.
model_compressor
.
sparsity_config
.
format
==
CompressionFormat
.
sparse_24_bitmask
.
value
)
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
# Only cutlass 3.x kernels are implemented so far
return
90
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
input_size
:
int
,
output_partition_sizes
:
List
[
int
],
input_size_per_partition
:
int
,
params_dtype
:
torch
.
dtype
,
weight_loader
:
Callable
,
**
kwargs
):
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
input_size
:
int
,
output_partition_sizes
:
List
[
int
],
input_size_per_partition
:
int
,
params_dtype
:
torch
.
dtype
,
weight_loader
:
Callable
,
**
kwargs
,
):
if
not
sparse_cutlass_supported
():
raise
ValueError
(
"Sparse CUTLASS not supported. vLLM must be built with "
...
...
@@ -49,16 +66,56 @@ class CompressedTensors24(CompressedTensorsScheme):
self
.
output_dtype
=
params_dtype
layer
.
logical_widths
=
output_partition_sizes
layer
.
input_size
=
input_size
layer
.
input_size_per_partition
=
input_size_per_partition
self
.
weights_dtype
:
torch
.
dtype
=
self
.
_get_params_dtype
(
params_dtype
)
# parameter to store uncompressed weight
weight
=
ModelWeightParameter
(
data
=
torch
.
empty
(
sum
(
output_partition_sizes
),
input_size_per_partition
,
dtype
=
self
.
weights_dtype
),
input_dim
=
1
,
output_dim
=
0
,
weight_loader
=
weight_loader
)
weight
=
ModelWeightParameter
(
data
=
torch
.
empty
(
sum
(
output_partition_sizes
),
input_size_per_partition
,
dtype
=
self
.
weights_dtype
,
),
input_dim
=
1
,
output_dim
=
0
,
weight_loader
=
weight_loader
,
)
if
self
.
do_sparse_decompress
:
assert
all
(
partition_size
%
8
==
0
for
partition_size
in
output_partition_sizes
),
"All partitions must be divisible by 8 for "
"2:4 sparse compressed models"
shape
=
BasevLLMParameter
(
data
=
torch
.
empty
(
2
,
1
,
dtype
=
torch
.
int64
),
weight_loader
=
weight_loader
,
)
compressed_weight
=
ModelWeightParameter
(
data
=
torch
.
empty
(
sum
(
output_partition_sizes
),
input_size_per_partition
//
2
,
dtype
=
self
.
weights_dtype
,
),
input_dim
=
1
,
output_dim
=
0
,
weight_loader
=
weight_loader
,
)
bitmask
=
ModelWeightParameter
(
data
=
torch
.
empty
(
sum
(
output_partition_sizes
),
input_size_per_partition
//
8
,
dtype
=
torch
.
uint8
,
),
input_dim
=
1
,
output_dim
=
0
,
weight_loader
=
weight_loader
,
)
layer
.
register_parameter
(
"shape"
,
shape
)
layer
.
register_parameter
(
"compressed"
,
compressed_weight
)
layer
.
register_parameter
(
"bitmask"
,
bitmask
)
# Check if quantized, not just 2:4 Sparse
if
self
.
quantized
:
...
...
@@ -68,14 +125,16 @@ class CompressedTensors24(CompressedTensorsScheme):
data
=
torch
.
empty
((
sum
(
output_partition_sizes
),
1
),
dtype
=
torch
.
float32
),
output_dim
=
0
,
weight_loader
=
weight_loader
)
weight_loader
=
weight_loader
,
)
else
:
assert
(
self
.
weight_quant
and
self
.
weight_quant
.
strategy
==
QuantizationStrategy
.
TENSOR
.
value
)
weight_scale
=
PerTensorScaleParameter
(
data
=
torch
.
empty
(
len
(
output_partition_sizes
),
dtype
=
torch
.
float32
),
weight_loader
=
weight_loader
)
weight_loader
=
weight_loader
,
)
layer
.
register_parameter
(
"weight_scale"
,
weight_scale
)
...
...
@@ -84,9 +143,10 @@ class CompressedTensors24(CompressedTensorsScheme):
# register input quant scale
assert
(
self
.
input_quant
.
strategy
==
QuantizationStrategy
.
TENSOR
.
value
)
input_scale
=
BasevLLMParameter
(
data
=
torch
.
empty
(
1
,
dtype
=
torch
.
float32
),
weight_loader
=
weight_loader
)
input_scale
=
BasevLLMParameter
(
data
=
torch
.
empty
(
1
,
dtype
=
torch
.
float32
),
weight_loader
=
weight_loader
,
)
layer
.
register_parameter
(
"input_scale"
,
input_scale
)
...
...
@@ -107,13 +167,25 @@ class CompressedTensors24(CompressedTensorsScheme):
"""
Compress weights after loading. Store compressed weight and meta
tensor
:post-condition: layer.w_compressed and layer.meta are
set to the compressed weight and meta tensor in the
format expected by the Cutlass kernels
:param layer: The layer with the weights to be processed
"""
if
self
.
do_sparse_decompress
:
layer
.
weight
.
data
=
self
.
_decompress_bitmask_compressed_weight
(
compressed
=
layer
.
compressed
,
bitmask
=
layer
.
bitmask
,
layer
=
layer
,
)
# compressed and bitmask tensors
# are no longer needed after decompression
del
layer
.
compressed
del
layer
.
bitmask
# torch.compile workaround
if
hasattr
(
layer
,
"input_scale"
):
layer
.
input_scale
=
torch
.
nn
.
Parameter
(
layer
.
input_scale
.
data
,
...
...
@@ -121,10 +193,13 @@ class CompressedTensors24(CompressedTensorsScheme):
if
self
.
weight_quant
:
if
self
.
weight_quant
.
strategy
==
QuantizationStrategy
.
TENSOR
.
value
:
layer
.
weight_scale
=
torch
.
nn
.
Parameter
(
convert_to_channelwise
(
weight_scale
=
layer
.
weight_scale
,
logical_widths
=
layer
.
logical_widths
),
requires_grad
=
False
)
layer
.
weight_scale
=
torch
.
nn
.
Parameter
(
convert_to_channelwise
(
weight_scale
=
layer
.
weight_scale
,
logical_widths
=
layer
.
logical_widths
,
),
requires_grad
=
False
,
)
else
:
# torch.compile workaround
layer
.
weight_scale
=
torch
.
nn
.
Parameter
(
...
...
@@ -134,20 +209,22 @@ class CompressedTensors24(CompressedTensorsScheme):
layer
.
weight
=
torch
.
nn
.
Parameter
(
w_compressed
,
requires_grad
=
False
)
layer
.
meta
=
torch
.
nn
.
Parameter
(
meta
,
requires_grad
=
False
)
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
"""
Returns the output tensor for the layer with 2:4
Returns the output tensor for the layer with 2:4
sparse compressed weights, given the input tensor
and bias
:param layer: The layer with 2:4 sparse compressed
:param layer: The layer with 2:4 sparse compressed
weights to be used for the computation
:param x: The input tensor to the layer
:param bias: The bias to be added to the output tensor
:return: The output tensor of the layer
:return: The output tensor of the layer
"""
if
self
.
quantized
:
scale
=
None
...
...
@@ -171,13 +248,15 @@ class CompressedTensors24(CompressedTensorsScheme):
input_scale
=
layer
.
input_scale
q_input
=
x
out
=
ops
.
cutlass_scaled_sparse_mm
(
a
=
q_input
,
bt_nzs
=
layer
.
weight
,
bt_meta
=
layer
.
meta
,
scale_a
=
input_scale
,
scale_b
=
layer
.
weight_scale
,
out_dtype
=
self
.
output_dtype
,
bias
=
bias
)
out
=
ops
.
cutlass_scaled_sparse_mm
(
a
=
q_input
,
bt_nzs
=
layer
.
weight
,
bt_meta
=
layer
.
meta
,
scale_a
=
input_scale
,
scale_b
=
layer
.
weight_scale
,
out_dtype
=
self
.
output_dtype
,
bias
=
bias
,
)
assert
out
.
is_contiguous
()
return
out
...
...
@@ -203,8 +282,71 @@ class CompressedTensors24(CompressedTensorsScheme):
raise
ValueError
(
"Quantization type not supported by Cutlass"
)
def
_decompress_bitmask_compressed_weight
(
self
,
compressed
:
torch
.
Tensor
,
bitmask
:
torch
.
Tensor
,
layer
:
torch
.
nn
.
Module
,
)
->
torch
.
Tensor
:
"""
Decompress a compressed 2:4 sparse weight tensor using the bitmask and
return the result.
This function also supports sharded decompression.
:param compressed: The 2:4 sparse weight tensor compressed using the
sparse-24-bitmask compressor. This is different from
`cutlass_sparse_compress` which uses a different scheme (2 bits for
every nonzero element that represent the coordinate within the block
of 4). The bitmask compression here uses a bitmask to indicate the
positions of non-zero elements.
:param bitmask: The 2:4 bitmask associated with the compressed weights,
representing the positions of non-zero elements in the compressed
tensor.
:param layer: The layer whose weights need to be processed after
loading.
:return: The decompressed 2:4 sparse weight tensor.
"""
def
check_24
(
tensor
):
new_tensor
=
tensor
.
view
(
-
1
,
4
)
zero_counts
=
(
new_tensor
==
0
).
sum
(
dim
=
1
)
return
(
zero_counts
>=
2
).
all
().
item
()
sparsity_compressor
=
self
.
model_compressor
.
sparsity_compressor
def
_process_split
(
bitmask_compressed_weight
:
torch
.
Tensor
,
shape
,
bitmask
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
weight_data
=
dict
(
compressed
=
bitmask_compressed_weight
,
shape
=
shape
,
bitmask
=
bitmask
,
)
return
sparsity_compressor
.
decompress_weight
(
weight_data
)
split_weights
:
List
[
torch
.
Tensor
]
=
[]
split_bitmask
:
List
[
torch
.
Tensor
]
=
[]
split_shape
:
List
[
Tuple
[
int
,
int
]]
=
[]
if
isinstance
(
layer
,
(
QKVParallelLinear
,
MergedColumnParallelLinear
)):
split_weights
=
torch
.
split
(
compressed
,
layer
.
logical_widths
)
split_bitmask
=
torch
.
split
(
bitmask
,
layer
.
logical_widths
)
split_shape
=
[(
out
,
layer
.
input_size_per_partition
)
for
out
in
layer
.
logical_widths
]
if
split_weights
:
decompressed_shards
=
[
_process_split
(
compressed_weight
,
shape
,
bitmask
)
for
compressed_weight
,
shape
,
bitmask
in
zip
(
split_weights
,
split_shape
,
split_bitmask
)
]
decompressed
=
combine_shards
(
decompressed_shards
)
else
:
decompressed
=
sparsity_compressor
.
decompress_weight
(
dict
(
compressed
=
compressed
,
shape
=
(
layer
.
logical_widths
[
0
],
layer
.
input_size_per_partition
,
),
bitmask
=
bitmask
,
))
return
decompressed
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment