Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
66b809cc
Commit
66b809cc
authored
Feb 08, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.7.2' into v0.7.2-dev
parents
37b63c24
0408efc6
Changes
1000
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
314 additions
and
88 deletions
+314
-88
vllm/model_executor/layers/logits_processor.py
vllm/model_executor/layers/logits_processor.py
+36
-11
vllm/model_executor/layers/mamba/mamba_mixer.py
vllm/model_executor/layers/mamba/mamba_mixer.py
+2
-0
vllm/model_executor/layers/mamba/ops/causal_conv1d.py
vllm/model_executor/layers/mamba/ops/causal_conv1d.py
+2
-0
vllm/model_executor/layers/mamba/ops/mamba_ssm.py
vllm/model_executor/layers/mamba/ops/mamba_ssm.py
+2
-0
vllm/model_executor/layers/pooler.py
vllm/model_executor/layers/pooler.py
+2
-0
vllm/model_executor/layers/quantization/__init__.py
vllm/model_executor/layers/quantization/__init__.py
+2
-0
vllm/model_executor/layers/quantization/aqlm.py
vllm/model_executor/layers/quantization/aqlm.py
+2
-0
vllm/model_executor/layers/quantization/awq.py
vllm/model_executor/layers/quantization/awq.py
+2
-0
vllm/model_executor/layers/quantization/awq_marlin.py
vllm/model_executor/layers/quantization/awq_marlin.py
+2
-0
vllm/model_executor/layers/quantization/awq_triton.py
vllm/model_executor/layers/quantization/awq_triton.py
+2
-0
vllm/model_executor/layers/quantization/base_config.py
vllm/model_executor/layers/quantization/base_config.py
+4
-1
vllm/model_executor/layers/quantization/bitsandbytes.py
vllm/model_executor/layers/quantization/bitsandbytes.py
+2
-0
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
...ers/quantization/compressed_tensors/compressed_tensors.py
+50
-28
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
...quantization/compressed_tensors/compressed_tensors_moe.py
+2
-0
vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py
...ayers/quantization/compressed_tensors/schemes/__init__.py
+2
-0
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py
...ation/compressed_tensors/schemes/compressed_tensors_24.py
+192
-48
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py
...n/compressed_tensors/schemes/compressed_tensors_scheme.py
+2
-0
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py
...compressed_tensors/schemes/compressed_tensors_w4a16_24.py
+2
-0
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py
...ompressed_tensors/schemes/compressed_tensors_w8a16_fp8.py
+2
-0
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
...compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
+2
-0
No files found.
Too many changes to show.
To preserve performance only
1000 of 1000+
files are displayed.
Plain diff
Email patch
vllm/model_executor/layers/logits_processor.py
View file @
66b809cc
# SPDX-License-Identifier: Apache-2.0
"""A layer that compute logits from hidden_stats."""
import
inspect
from
concurrent.futures
import
ThreadPoolExecutor
from
typing
import
Optional
import
torch
...
...
@@ -14,6 +16,11 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.platforms
import
current_platform
_logits_processor_threadpool
:
Optional
[
ThreadPoolExecutor
]
=
None
if
envs
.
VLLM_LOGITS_PROCESSOR_THREADS
is
not
None
:
_logits_processor_threadpool
=
ThreadPoolExecutor
(
envs
.
VLLM_LOGITS_PROCESSOR_THREADS
)
class
LogitsProcessor
(
nn
.
Module
):
"""Process logits and apply logits processors from sampling metadata.
...
...
@@ -134,6 +141,7 @@ def _apply_logits_processors(
)
->
torch
.
Tensor
:
found_logits_processors
=
False
logits_processed
=
0
logits_row_ids_and_logits_row_futures
=
[]
for
seq_group
in
sampling_metadata
.
seq_groups
:
seq_ids
=
seq_group
.
seq_ids
sampling_params
=
seq_group
.
sampling_params
...
...
@@ -147,22 +155,39 @@ def _apply_logits_processors(
past_tokens_ids
=
seq_group
.
seq_data
[
seq_id
].
output_token_ids
prompt_tokens_ids
=
seq_group
.
seq_data
[
seq_id
].
prompt_token_ids
for
logits_processor
in
logits_processors
:
parameters
=
inspect
.
signature
(
logits_processor
).
parameters
if
len
(
parameters
)
==
3
:
logits_row
=
logits_processor
(
prompt_tokens_ids
,
past_tokens_ids
,
logits_row
)
else
:
logits_row
=
logits_processor
(
past_tokens_ids
,
logits_row
)
logits
[
logits_row_idx
]
=
logits_row
if
_logits_processor_threadpool
is
not
None
:
logits_row_ids_and_logits_row_futures
.
append
(
(
logits_row_idx
,
_logits_processor_threadpool
.
submit
(
_apply_logits_processors_single_seq
,
logits_row
,
logits_processors
,
past_tokens_ids
,
prompt_tokens_ids
)))
else
:
logits
[
logits_row_idx
]
=
\
_apply_logits_processors_single_seq
(
logits_row
,
logits_processors
,
past_tokens_ids
,
prompt_tokens_ids
)
logits_processed
+=
len
(
seq_group
.
sample_indices
)
+
len
(
seq_group
.
prompt_logprob_indices
)
for
logits_row_idx
,
future
in
logits_row_ids_and_logits_row_futures
:
logits
[
logits_row_idx
]
=
future
.
result
()
if
found_logits_processors
:
# verifies that no rows in logits were missed unexpectedly
assert
logits_processed
==
logits
.
shape
[
0
]
return
logits
def
_apply_logits_processors_single_seq
(
logits_row
,
logits_processors
,
past_tokens_ids
,
prompt_tokens_ids
)
->
torch
.
Tensor
:
for
logits_processor
in
logits_processors
:
parameters
=
inspect
.
signature
(
logits_processor
).
parameters
if
len
(
parameters
)
==
3
:
logits_row
=
logits_processor
(
prompt_tokens_ids
,
past_tokens_ids
,
logits_row
)
else
:
logits_row
=
logits_processor
(
past_tokens_ids
,
logits_row
)
return
logits_row
vllm/model_executor/layers/mamba/mamba_mixer.py
View file @
66b809cc
# SPDX-License-Identifier: Apache-2.0
import
torch
from
torch
import
nn
from
torch.nn.parameter
import
Parameter
...
...
vllm/model_executor/layers/mamba/ops/causal_conv1d.py
View file @
66b809cc
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2024, Tri Dao.
# Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py
...
...
vllm/model_executor/layers/mamba/ops/mamba_ssm.py
View file @
66b809cc
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Adapted from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/selective_state_update.py
...
...
vllm/model_executor/layers/pooler.py
View file @
66b809cc
# SPDX-License-Identifier: Apache-2.0
from
enum
import
IntEnum
from
typing
import
List
,
Optional
,
Union
...
...
vllm/model_executor/layers/quantization/__init__.py
View file @
66b809cc
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Dict
,
List
,
Type
from
vllm.model_executor.layers.quantization.base_config
import
(
...
...
vllm/model_executor/layers/quantization/aqlm.py
View file @
66b809cc
# SPDX-License-Identifier: Apache-2.0
# Supports AQLM compression, see https://github.com/Vahe1994/AQLM
# and https://arxiv.org/pdf/2401.06118.pdf
...
...
vllm/model_executor/layers/quantization/awq.py
View file @
66b809cc
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Any
,
Dict
,
List
,
Optional
import
torch
...
...
vllm/model_executor/layers/quantization/awq_marlin.py
View file @
66b809cc
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
import
torch
...
...
vllm/model_executor/layers/quantization/awq_triton.py
View file @
66b809cc
# SPDX-License-Identifier: Apache-2.0
import
torch
import
triton
import
triton.language
as
tl
...
...
vllm/model_executor/layers/quantization/base_config.py
View file @
66b809cc
# SPDX-License-Identifier: Apache-2.0
import
inspect
from
abc
import
ABC
,
abstractmethod
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Type
from
typing
import
Any
,
Dict
,
List
,
Mapping
,
Optional
,
Type
import
torch
from
torch
import
nn
...
...
@@ -57,6 +59,7 @@ def method_has_implemented_embedding(
class
QuantizationConfig
(
ABC
):
"""Base class for quantization configs."""
packed_modules_mapping
:
Mapping
[
str
,
List
[
str
]]
=
dict
()
@
abstractmethod
def
get_name
(
self
)
->
str
:
...
...
vllm/model_executor/layers/quantization/bitsandbytes.py
View file @
66b809cc
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Any
,
Dict
,
List
,
Optional
import
torch
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
View file @
66b809cc
# SPDX-License-Identifier: Apache-2.0
from
contextlib
import
suppress
from
typing
import
Any
,
Dict
,
List
,
Literal
,
Optional
,
Tuple
,
cast
...
...
@@ -81,7 +83,9 @@ class CompressedTensorsConfig(QuantizationConfig):
# Check if the layer is skipped for quantization.
# TODO (@robertgshaw2): support module names
if
should_ignore_layer
(
prefix
,
ignore
=
self
.
ignore
):
if
should_ignore_layer
(
prefix
,
ignore
=
self
.
ignore
,
fused_mapping
=
self
.
packed_modules_mapping
):
return
UnquantizedLinearMethod
()
if
isinstance
(
layer
,
LinearBase
):
scheme
=
self
.
get_scheme
(
layer
=
layer
,
layer_name
=
prefix
)
...
...
@@ -377,34 +381,29 @@ class CompressedTensorsConfig(QuantizationConfig):
# Will be empty for models with only sparsity
weight_quant
=
input_quant
=
None
sparsity_scheme
:
Optional
[
SparsityCompressionConfig
]
=
None
if
self
.
target_scheme_map
:
matched_target
=
find_matched_target
(
layer_name
=
layer_name
,
module
=
layer
,
targets
=
self
.
target_scheme_map
.
keys
())
targets
=
self
.
target_scheme_map
.
keys
(),
fused_mapping
=
self
.
packed_modules_mapping
)
scheme_dict
=
self
.
target_scheme_map
[
matched_target
]
weight_quant
=
scheme_dict
.
get
(
"weights"
)
input_quant
=
scheme_dict
.
get
(
"input_activations"
)
if
self
.
sparsity_scheme_map
:
is_ignored
=
False
with
suppress
(
ValueError
):
is_ignored
=
find_matched_target
(
layer_name
=
layer_name
,
module
=
layer
,
targets
=
self
.
sparsity_ignore_list
)
# if the layer is in the sparsity ignore list,
# we should not apply any sparsity scheme
if
not
is_ignored
:
matched_target
=
find_matched_target
(
layer_name
=
layer_name
,
module
=
layer
,
targets
=
self
.
sparsity_scheme_map
.
keys
())
sparsity_scheme
=
self
.
sparsity_scheme_map
.
get
(
matched_target
)
# Find the sparsity scheme of the layer
# assume that fused layers inerhit first component's sparsity scheme
sparsity_targets
=
(
self
.
sparsity_scheme_map
.
keys
()
-
set
(
self
.
sparsity_ignore_list
))
sparsity_scheme
:
Optional
[
SparsityCompressionConfig
]
=
None
with
suppress
(
ValueError
):
matched_target
=
find_matched_target
(
layer_name
=
layer_name
,
module
=
layer
,
targets
=
sparsity_targets
,
fused_mapping
=
self
.
packed_modules_mapping
)
sparsity_scheme
=
self
.
sparsity_scheme_map
[
matched_target
]
if
self
.
supports_cutlass_24
(
weight_quant
=
weight_quant
,
input_quant
=
input_quant
,
...
...
@@ -418,10 +417,22 @@ class CompressedTensorsConfig(QuantizationConfig):
return
None
# Have a valid sparsity scheme
# Validate layer is supported by Cutlass 2:4 Kernel
scheme
=
CompressedTensors24
(
quantized
=
weight_quant
is
not
None
or
input_quant
is
not
None
,
weight_quant
=
weight_quant
,
input_quant
=
input_quant
)
model_compression_config
=
(
None
if
sparsity_scheme
is
None
or
sparsity_scheme
.
format
==
"dense"
else
self
.
config
)
scheme
=
CompressedTensors24
(
quantized
=
weight_quant
is
not
None
or
input_quant
is
not
None
,
weight_quant
=
weight_quant
,
input_quant
=
input_quant
,
model_compression_config
=
model_compression_config
,
)
elif
weight_quant
is
None
:
logger
.
warning_once
(
"Acceleration for non-quantized schemes is "
"not supported by Compressed Tensors. "
"Falling back to UnquantizedLinearMethod"
)
return
None
else
:
# Find the quant_scheme
scheme
=
self
.
_get_scheme_from_parts
(
# type: ignore
...
...
@@ -471,10 +482,21 @@ class CompressedTensorsConfig(QuantizationConfig):
:return: True if the layer is supported by the Cutlass 2:4 Kernel
False otherwise
"""
is_valid_sparsity
=
(
sparsity_scheme
is
not
None
and
sparsity_scheme
.
sparsity_structure
==
SparsityStructure
.
TWO_FOUR
.
value
and
sparsity_scheme
.
format
==
"dense"
)
if
sparsity_scheme
is
None
:
return
False
is_valid_sparsity_structure
:
bool
=
(
sparsity_scheme
.
sparsity_structure
==
SparsityStructure
.
TWO_FOUR
.
value
)
valid_compressors
=
{
CompressionFormat
.
dense
.
value
,
CompressionFormat
.
sparse_24_bitmask
.
value
}
is_valid_sparsity
=
(
is_valid_sparsity_structure
and
sparsity_scheme
.
format
in
valid_compressors
)
if
not
is_valid_sparsity
:
return
False
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
View file @
66b809cc
# SPDX-License-Identifier: Apache-2.0
import
enum
from
enum
import
Enum
from
typing
import
Callable
,
List
,
Optional
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py
View file @
66b809cc
# SPDX-License-Identifier: Apache-2.0
from
.compressed_tensors_scheme
import
CompressedTensorsScheme
from
.compressed_tensors_w4a16_24
import
(
W4A16SPARSE24_SUPPORTED_BITS
,
CompressedTensorsW4A16Sparse24
)
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py
View file @
66b809cc
from
typing
import
Callable
,
List
,
Optional
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
import
torch
from
compressed_tensors
import
CompressionFormat
,
ModelCompressor
from
compressed_tensors.quantization
import
(
QuantizationArgs
,
QuantizationStrategy
,
QuantizationType
)
from
compressed_tensors.utils
import
combine_shards
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
QKVParallelLinear
)
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
CompressedTensorsScheme
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
...
...
@@ -20,26 +26,39 @@ __all__ = ["CompressedTensors24"]
class
CompressedTensors24
(
CompressedTensorsScheme
):
def
__init__
(
self
,
quantized
:
bool
=
False
,
weight_quant
:
Optional
[
QuantizationArgs
]
=
None
,
input_quant
:
Optional
[
QuantizationArgs
]
=
None
):
def
__init__
(
self
,
quantized
:
bool
=
False
,
weight_quant
:
Optional
[
QuantizationArgs
]
=
None
,
input_quant
:
Optional
[
QuantizationArgs
]
=
None
,
model_compression_config
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
):
self
.
quantized
=
quantized
self
.
weight_quant
=
weight_quant
self
.
input_quant
=
input_quant
self
.
model_compressor
=
(
ModelCompressor
.
from_compression_config
(
model_compression_config
)
if
model_compression_config
is
not
None
else
None
)
self
.
do_sparse_decompress
=
(
self
.
model_compressor
is
not
None
and
self
.
model_compressor
.
sparsity_config
.
format
==
CompressionFormat
.
sparse_24_bitmask
.
value
)
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
# Only cutlass 3.x kernels are implemented so far
return
90
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
input_size
:
int
,
output_partition_sizes
:
List
[
int
],
input_size_per_partition
:
int
,
params_dtype
:
torch
.
dtype
,
weight_loader
:
Callable
,
**
kwargs
):
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
input_size
:
int
,
output_partition_sizes
:
List
[
int
],
input_size_per_partition
:
int
,
params_dtype
:
torch
.
dtype
,
weight_loader
:
Callable
,
**
kwargs
,
):
if
not
sparse_cutlass_supported
():
raise
ValueError
(
"Sparse CUTLASS not supported. vLLM must be built with "
...
...
@@ -47,16 +66,56 @@ class CompressedTensors24(CompressedTensorsScheme):
self
.
output_dtype
=
params_dtype
layer
.
logical_widths
=
output_partition_sizes
layer
.
input_size
=
input_size
layer
.
input_size_per_partition
=
input_size_per_partition
self
.
weights_dtype
:
torch
.
dtype
=
self
.
_get_params_dtype
(
params_dtype
)
# parameter to store uncompressed weight
weight
=
ModelWeightParameter
(
data
=
torch
.
empty
(
sum
(
output_partition_sizes
),
input_size_per_partition
,
dtype
=
self
.
weights_dtype
),
input_dim
=
1
,
output_dim
=
0
,
weight_loader
=
weight_loader
)
weight
=
ModelWeightParameter
(
data
=
torch
.
empty
(
sum
(
output_partition_sizes
),
input_size_per_partition
,
dtype
=
self
.
weights_dtype
,
),
input_dim
=
1
,
output_dim
=
0
,
weight_loader
=
weight_loader
,
)
if
self
.
do_sparse_decompress
:
assert
all
(
partition_size
%
8
==
0
for
partition_size
in
output_partition_sizes
),
"All partitions must be divisible by 8 for "
"2:4 sparse compressed models"
shape
=
BasevLLMParameter
(
data
=
torch
.
empty
(
2
,
1
,
dtype
=
torch
.
int64
),
weight_loader
=
weight_loader
,
)
compressed_weight
=
ModelWeightParameter
(
data
=
torch
.
empty
(
sum
(
output_partition_sizes
),
input_size_per_partition
//
2
,
dtype
=
self
.
weights_dtype
,
),
input_dim
=
1
,
output_dim
=
0
,
weight_loader
=
weight_loader
,
)
bitmask
=
ModelWeightParameter
(
data
=
torch
.
empty
(
sum
(
output_partition_sizes
),
input_size_per_partition
//
8
,
dtype
=
torch
.
uint8
,
),
input_dim
=
1
,
output_dim
=
0
,
weight_loader
=
weight_loader
,
)
layer
.
register_parameter
(
"shape"
,
shape
)
layer
.
register_parameter
(
"compressed"
,
compressed_weight
)
layer
.
register_parameter
(
"bitmask"
,
bitmask
)
# Check if quantized, not just 2:4 Sparse
if
self
.
quantized
:
...
...
@@ -66,14 +125,16 @@ class CompressedTensors24(CompressedTensorsScheme):
data
=
torch
.
empty
((
sum
(
output_partition_sizes
),
1
),
dtype
=
torch
.
float32
),
output_dim
=
0
,
weight_loader
=
weight_loader
)
weight_loader
=
weight_loader
,
)
else
:
assert
(
self
.
weight_quant
and
self
.
weight_quant
.
strategy
==
QuantizationStrategy
.
TENSOR
.
value
)
weight_scale
=
PerTensorScaleParameter
(
data
=
torch
.
empty
(
len
(
output_partition_sizes
),
dtype
=
torch
.
float32
),
weight_loader
=
weight_loader
)
weight_loader
=
weight_loader
,
)
layer
.
register_parameter
(
"weight_scale"
,
weight_scale
)
...
...
@@ -82,9 +143,10 @@ class CompressedTensors24(CompressedTensorsScheme):
# register input quant scale
assert
(
self
.
input_quant
.
strategy
==
QuantizationStrategy
.
TENSOR
.
value
)
input_scale
=
BasevLLMParameter
(
data
=
torch
.
empty
(
1
,
dtype
=
torch
.
float32
),
weight_loader
=
weight_loader
)
input_scale
=
BasevLLMParameter
(
data
=
torch
.
empty
(
1
,
dtype
=
torch
.
float32
),
weight_loader
=
weight_loader
,
)
layer
.
register_parameter
(
"input_scale"
,
input_scale
)
...
...
@@ -105,13 +167,25 @@ class CompressedTensors24(CompressedTensorsScheme):
"""
Compress weights after loading. Store compressed weight and meta
tensor
:post-condition: layer.w_compressed and layer.meta are
set to the compressed weight and meta tensor in the
format expected by the Cutlass kernels
:param layer: The layer with the weights to be processed
"""
if
self
.
do_sparse_decompress
:
layer
.
weight
.
data
=
self
.
_decompress_bitmask_compressed_weight
(
compressed
=
layer
.
compressed
,
bitmask
=
layer
.
bitmask
,
layer
=
layer
,
)
# compressed and bitmask tensors
# are no longer needed after decompression
del
layer
.
compressed
del
layer
.
bitmask
# torch.compile workaround
if
hasattr
(
layer
,
"input_scale"
):
layer
.
input_scale
=
torch
.
nn
.
Parameter
(
layer
.
input_scale
.
data
,
...
...
@@ -119,10 +193,13 @@ class CompressedTensors24(CompressedTensorsScheme):
if
self
.
weight_quant
:
if
self
.
weight_quant
.
strategy
==
QuantizationStrategy
.
TENSOR
.
value
:
layer
.
weight_scale
=
torch
.
nn
.
Parameter
(
convert_to_channelwise
(
weight_scale
=
layer
.
weight_scale
,
logical_widths
=
layer
.
logical_widths
),
requires_grad
=
False
)
layer
.
weight_scale
=
torch
.
nn
.
Parameter
(
convert_to_channelwise
(
weight_scale
=
layer
.
weight_scale
,
logical_widths
=
layer
.
logical_widths
,
),
requires_grad
=
False
,
)
else
:
# torch.compile workaround
layer
.
weight_scale
=
torch
.
nn
.
Parameter
(
...
...
@@ -132,20 +209,22 @@ class CompressedTensors24(CompressedTensorsScheme):
layer
.
weight
=
torch
.
nn
.
Parameter
(
w_compressed
,
requires_grad
=
False
)
layer
.
meta
=
torch
.
nn
.
Parameter
(
meta
,
requires_grad
=
False
)
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
"""
Returns the output tensor for the layer with 2:4
Returns the output tensor for the layer with 2:4
sparse compressed weights, given the input tensor
and bias
:param layer: The layer with 2:4 sparse compressed
:param layer: The layer with 2:4 sparse compressed
weights to be used for the computation
:param x: The input tensor to the layer
:param bias: The bias to be added to the output tensor
:return: The output tensor of the layer
:return: The output tensor of the layer
"""
if
self
.
quantized
:
scale
=
None
...
...
@@ -169,13 +248,15 @@ class CompressedTensors24(CompressedTensorsScheme):
input_scale
=
layer
.
input_scale
q_input
=
x
out
=
ops
.
cutlass_scaled_sparse_mm
(
a
=
q_input
,
bt_nzs
=
layer
.
weight
,
bt_meta
=
layer
.
meta
,
scale_a
=
input_scale
,
scale_b
=
layer
.
weight_scale
,
out_dtype
=
self
.
output_dtype
,
bias
=
bias
)
out
=
ops
.
cutlass_scaled_sparse_mm
(
a
=
q_input
,
bt_nzs
=
layer
.
weight
,
bt_meta
=
layer
.
meta
,
scale_a
=
input_scale
,
scale_b
=
layer
.
weight_scale
,
out_dtype
=
self
.
output_dtype
,
bias
=
bias
,
)
assert
out
.
is_contiguous
()
return
out
...
...
@@ -201,8 +282,71 @@ class CompressedTensors24(CompressedTensorsScheme):
raise
ValueError
(
"Quantization type not supported by Cutlass"
)
def
_decompress_bitmask_compressed_weight
(
self
,
compressed
:
torch
.
Tensor
,
bitmask
:
torch
.
Tensor
,
layer
:
torch
.
nn
.
Module
,
)
->
torch
.
Tensor
:
"""
Decompress a compressed 2:4 sparse weight tensor using the bitmask and
return the result.
This function also supports sharded decompression.
:param compressed: The 2:4 sparse weight tensor compressed using the
sparse-24-bitmask compressor. This is different from
`cutlass_sparse_compress` which uses a different scheme (2 bits for
every nonzero element that represent the coordinate within the block
of 4). The bitmask compression here uses a bitmask to indicate the
positions of non-zero elements.
:param bitmask: The 2:4 bitmask associated with the compressed weights,
representing the positions of non-zero elements in the compressed
tensor.
:param layer: The layer whose weights need to be processed after
loading.
:return: The decompressed 2:4 sparse weight tensor.
"""
def
check_24
(
tensor
):
new_tensor
=
tensor
.
view
(
-
1
,
4
)
zero_counts
=
(
new_tensor
==
0
).
sum
(
dim
=
1
)
return
(
zero_counts
>=
2
).
all
().
item
()
sparsity_compressor
=
self
.
model_compressor
.
sparsity_compressor
def
_process_split
(
bitmask_compressed_weight
:
torch
.
Tensor
,
shape
,
bitmask
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
weight_data
=
dict
(
compressed
=
bitmask_compressed_weight
,
shape
=
shape
,
bitmask
=
bitmask
,
)
return
sparsity_compressor
.
decompress_weight
(
weight_data
)
split_weights
:
List
[
torch
.
Tensor
]
=
[]
split_bitmask
:
List
[
torch
.
Tensor
]
=
[]
split_shape
:
List
[
Tuple
[
int
,
int
]]
=
[]
if
isinstance
(
layer
,
(
QKVParallelLinear
,
MergedColumnParallelLinear
)):
split_weights
=
torch
.
split
(
compressed
,
layer
.
logical_widths
)
split_bitmask
=
torch
.
split
(
bitmask
,
layer
.
logical_widths
)
split_shape
=
[(
out
,
layer
.
input_size_per_partition
)
for
out
in
layer
.
logical_widths
]
if
split_weights
:
decompressed_shards
=
[
_process_split
(
compressed_weight
,
shape
,
bitmask
)
for
compressed_weight
,
shape
,
bitmask
in
zip
(
split_weights
,
split_shape
,
split_bitmask
)
]
decompressed
=
combine_shards
(
decompressed_shards
)
else
:
decompressed
=
sparsity_compressor
.
decompress_weight
(
dict
(
compressed
=
compressed
,
shape
=
(
layer
.
logical_widths
[
0
],
layer
.
input_size_per_partition
,
),
bitmask
=
bitmask
,
))
return
decompressed
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py
View file @
66b809cc
# SPDX-License-Identifier: Apache-2.0
from
abc
import
ABC
,
abstractmethod
from
typing
import
Optional
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py
View file @
66b809cc
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Callable
,
List
,
Optional
import
torch
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py
View file @
66b809cc
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Callable
,
List
,
Optional
import
torch
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
View file @
66b809cc
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Callable
,
List
,
Optional
import
torch
...
...
Prev
1
…
32
33
34
35
36
37
38
39
40
…
50
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment