Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
22feac8e
Unverified
Commit
22feac8e
authored
Aug 28, 2025
by
Kyle Sayers
Committed by
GitHub
Aug 28, 2025
Browse files
[Transform] [Quantization] Add transforms to compressed tensors (#22486)
parent
c8851a47
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
661 additions
and
36 deletions
+661
-36
tests/conftest.py
tests/conftest.py
+37
-6
tests/quantization/test_compressed_tensors.py
tests/quantization/test_compressed_tensors.py
+22
-0
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+15
-1
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
...ers/quantization/compressed_tensors/compressed_tensors.py
+37
-15
vllm/model_executor/layers/quantization/compressed_tensors/transform/linear.py
...ayers/quantization/compressed_tensors/transform/linear.py
+227
-0
vllm/model_executor/layers/quantization/compressed_tensors/transform/module.py
...ayers/quantization/compressed_tensors/transform/module.py
+135
-0
vllm/model_executor/layers/quantization/compressed_tensors/transform/schemes/linear_qutlass_nvfp4.py
...pressed_tensors/transform/schemes/linear_qutlass_nvfp4.py
+21
-0
vllm/model_executor/layers/quantization/compressed_tensors/transform/utils.py
...layers/quantization/compressed_tensors/transform/utils.py
+13
-0
vllm/model_executor/parameter.py
vllm/model_executor/parameter.py
+154
-14
No files found.
tests/conftest.py
View file @
22feac8e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
json
import
math
import
os
import
tempfile
from
enum
import
Enum
from
typing
import
Any
,
Callable
,
Optional
,
TypedDict
,
TypeVar
,
Union
from
typing
import
Any
,
Callable
,
Optional
,
TypedDict
,
TypeVar
,
Union
,
cast
import
numpy
as
np
import
pytest
...
...
@@ -33,6 +34,7 @@ from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
from
vllm.logger
import
init_logger
from
vllm.outputs
import
RequestOutput
from
vllm.sampling_params
import
BeamSearchParams
from
vllm.sequence
import
Logprob
from
vllm.transformers_utils.utils
import
maybe_model_redirect
logger
=
init_logger
(
__name__
)
...
...
@@ -602,7 +604,7 @@ class HfRunner:
def
_hidden_states_to_logprobs
(
self
,
hidden_states
:
tuple
[
tuple
[
torch
.
Tensor
,
...],
...],
num_logprobs
:
int
,
num_logprobs
:
Optional
[
int
]
,
)
->
tuple
[
list
[
dict
[
int
,
float
]],
int
]:
seq_logprobs
=
self
.
_hidden_states_to_seq_logprobs
(
hidden_states
)
output_len
=
len
(
hidden_states
)
...
...
@@ -630,7 +632,7 @@ class HfRunner:
self
,
prompts
:
list
[
str
],
max_tokens
:
int
,
num_logprobs
:
int
,
num_logprobs
:
Optional
[
int
]
,
images
:
Optional
[
PromptImageInput
]
=
None
,
audios
:
Optional
[
PromptAudioInput
]
=
None
,
videos
:
Optional
[
PromptVideoInput
]
=
None
,
...
...
@@ -677,7 +679,7 @@ class HfRunner:
self
,
encoder_decoder_prompts
:
list
[
ExplicitEncoderDecoderPrompt
[
str
,
str
]],
max_tokens
:
int
,
num_logprobs
:
int
,
num_logprobs
:
Optional
[
int
]
,
images
:
Optional
[
PromptImageInput
]
=
None
,
**
kwargs
:
Any
,
)
->
list
[
TokensTextLogprobs
]:
...
...
@@ -966,7 +968,7 @@ class VllmRunner:
self
,
prompts
:
list
[
str
],
max_tokens
:
int
,
num_logprobs
:
int
,
num_logprobs
:
Optional
[
int
]
,
num_prompt_logprobs
:
Optional
[
int
]
=
None
,
images
:
Optional
[
PromptImageInput
]
=
None
,
audios
:
Optional
[
PromptAudioInput
]
=
None
,
...
...
@@ -991,11 +993,40 @@ class VllmRunner:
videos
=
videos
,
**
kwargs
)
def generate_prompt_perplexity(self, prompts: list[str]) -> list[float]:
    """
    Return the perplexity score associated with generating the prompts

    Each prompt is scored from its prompt logprobs: perplexity is the
    exponential of the negated mean log-probability over every prompt
    position except the first.

    :param prompts: list of prompts to score
    :return: perplexity score of each prompt
    """
    results = self.generate_greedy_logprobs(prompts,
                                            max_tokens=1,
                                            num_logprobs=None,
                                            num_prompt_logprobs=0)

    scores = []
    for result in results:
        result = cast(TokensTextLogprobsPromptLogprobs, result)
        # index 3 of the output tuple is read as the per-position prompt
        # logprobs
        prompt_entries = cast(list[Optional[dict[int, Logprob]]], result[3])

        # the first position is expected to carry no logprob entry
        assert prompt_entries[0] is None

        log_probs = []
        for entry in prompt_entries[1:]:
            assert entry is not None
            # exactly one candidate per position is expected
            assert len(entry) == 1
            only_logprob = list(entry.values())[0]
            log_probs.append(only_logprob.logprob)

        mean_negative_log_prob = -sum(log_probs) / len(log_probs)
        scores.append(math.exp(mean_negative_log_prob))

    return scores
def
generate_encoder_decoder_greedy_logprobs
(
self
,
encoder_decoder_prompts
:
list
[
ExplicitEncoderDecoderPrompt
[
str
,
str
]],
max_tokens
:
int
,
num_logprobs
:
int
,
num_logprobs
:
Optional
[
int
]
,
num_prompt_logprobs
:
Optional
[
int
]
=
None
,
skip_special_tokens
:
bool
=
True
,
)
->
Union
[
list
[
TokensTextLogprobs
],
...
...
tests/quantization/test_compressed_tensors.py
View file @
22feac8e
...
...
@@ -719,3 +719,25 @@ def test_compressed_tensors_w4a8_fp8(vllm_runner, args):
output
=
llm
.
generate_greedy
(
"Hello my name is"
,
max_tokens
=
20
)
print
(
output
)
assert
output
@pytest.mark.skipif(
    not current_platform.is_cuda(),
    reason="This test is skipped on non-CUDA platform.",
)
@pytest.mark.parametrize(
    "model,prompt,exp_perplexity",
    [
        (
            "nm-testing/Llama-3.2-1B-Instruct-spinquantR1R2R4-w4a16",
            "Flat is better than nested.\nSparse is better than dense.",
            150.0,
        ),
        (
            "nm-testing/Llama-3.2-1B-Instruct-quip-w4a16",
            "Flat is better than nested.\nSparse is better than dense.",
            150.0,
        ),
    ],
)
def test_compressed_tensors_transforms_perplexity(vllm_runner, model, prompt,
                                                  exp_perplexity):
    """
    Check that models with online transforms score the prompt at or below
    the expected perplexity bound.
    """
    with vllm_runner(model, enforce_eager=True) as llm:
        measured = llm.generate_prompt_perplexity([prompt])
        perplexity = measured[0]
        print(perplexity)
        assert perplexity <= exp_perplexity
\ No newline at end of file
vllm/model_executor/layers/linear.py
View file @
22feac8e
...
...
@@ -35,6 +35,7 @@ logger = init_logger(__name__)
WEIGHT_LOADER_V2_SUPPORTED
=
[
"CompressedTensorsLinearMethod"
,
"CompressedTensorsLinearTransformMethod"
,
"BitBLASLinearMethod"
,
"GPTQBitBLASLinearMethod"
,
"AWQMarlinLinearMethod"
,
...
...
@@ -199,6 +200,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
set_weight_attrs
(
weight
,
extra_weight_attrs
)
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
# special postprocessing for CPU SGL
if
current_platform
.
is_cpu
()
and
envs
.
VLLM_CPU_SGL_KERNEL
:
from
vllm.model_executor.layers.utils
import
check_cpu_sgl_kernel
N
,
K
=
layer
.
weight
.
size
()
...
...
@@ -1470,7 +1472,7 @@ class QKVCrossParallelLinear(LinearBase):
self
.
bias
=
torch
.
nn
.
Parameter
()
set_weight_attrs
(
self
.
bias
,
{
"output_dim"
:
0
,
"weight_loader"
:
self
.
weight_loader
,
"weight_loader"
:
self
.
weight_loader
_v1
,
})
else
:
self
.
bias
=
None
...
...
@@ -1580,6 +1582,18 @@ class QKVCrossParallelLinear(LinearBase):
k
,
v
=
kv_enc
.
split
(
self
.
kv_size
,
dim
=-
1
)
return
q
,
k
,
v
def weight_loader_v1(self,
                     param: torch.nn.Parameter,
                     loaded_weight: torch.Tensor,
                     loaded_shard_id: Optional[str] = None):
    """
    Route a load to the projection selected by `loaded_shard_id`.

    Shard "q" is forwarded to `q_proj_decoder` without a shard id; every
    other value (including None) is forwarded to `kv_proj_encoder` together
    with the shard id.

    :param param: parameter of this layer being loaded
    :param loaded_weight: tensor read from the checkpoint
    :param loaded_shard_id: shard identifier of the loaded tensor
    """
    # just like all other parameters, does not yet
    # support loading bias with weight_loader_v2
    if loaded_shard_id == "q":
        proj = self.q_proj_decoder
    else:
        proj = self.kv_proj_encoder

    proj_param = self.select_proj_params(proj, param)

    if loaded_shard_id != "q":
        proj.weight_loader(proj_param, loaded_weight, loaded_shard_id)
    else:
        proj.weight_loader(proj_param, loaded_weight)
def
weight_loader
(
self
,
param
:
torch
.
nn
.
Parameter
,
loaded_weight
:
torch
.
Tensor
,
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
View file @
22feac8e
...
...
@@ -11,6 +11,7 @@ from compressed_tensors.config import (CompressionFormat,
from
compressed_tensors.quantization
import
(
QuantizationArgs
,
QuantizationStrategy
,
QuantizationType
)
from
compressed_tensors.transform
import
TransformConfig
from
pydantic
import
BaseModel
import
vllm.envs
as
envs
...
...
@@ -30,6 +31,8 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsW4A16Fp4
,
CompressedTensorsW4A16Sparse24
,
CompressedTensorsW8A8Fp8
,
CompressedTensorsW8A8Int8
,
CompressedTensorsW8A16Fp8
,
CompressedTensorsWNA16
)
from
vllm.model_executor.layers.quantization.compressed_tensors.transform.linear
import
(
# noqa: E501
CompressedTensorsLinearTransformMethod
,
get_linear_transform_schemes
)
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
find_matched_target
,
is_activation_quantization_format
,
should_ignore_layer
)
...
...
@@ -60,6 +63,7 @@ class CompressedTensorsConfig(QuantizationConfig):
sparsity_ignore_list
:
list
[
str
],
kv_cache_scheme
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
config
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
transform_config
:
Optional
[
TransformConfig
]
=
None
,
):
super
().
__init__
()
self
.
ignore
=
ignore
...
...
@@ -71,6 +75,12 @@ class CompressedTensorsConfig(QuantizationConfig):
self
.
sparsity_ignore_list
=
sparsity_ignore_list
self
.
config
=
config
if
transform_config
is
not
None
:
self
.
transform_config
=
TransformConfig
.
model_validate
(
transform_config
)
else
:
self
.
transform_config
=
None
def
get_linear_method
(
self
)
->
"CompressedTensorsLinearMethod"
:
return
CompressedTensorsLinearMethod
(
self
)
...
...
@@ -103,18 +113,27 @@ class CompressedTensorsConfig(QuantizationConfig):
)
->
Optional
[
"QuantizeMethodBase"
]:
from
vllm.attention.layer
import
Attention
# Avoid circular import
# Check if the layer is skipped for quantization.
# TODO (@robertgshaw2): support module names
if
should_ignore_layer
(
prefix
,
ignore
=
self
.
ignore
,
fused_mapping
=
self
.
packed_modules_mapping
):
return
UnquantizedLinearMethod
()
if
isinstance
(
layer
,
LinearBase
):
scheme
=
self
.
get_scheme
(
layer
=
layer
,
layer_name
=
prefix
)
if
scheme
is
None
:
return
UnquantizedLinearMethod
()
layer
.
scheme
=
scheme
return
CompressedTensorsLinearMethod
(
self
)
# collect schemes
quant_scheme
=
self
.
get_scheme
(
layer
=
layer
,
layer_name
=
prefix
)
input_tfms
,
output_tfms
=
get_linear_transform_schemes
(
layer
,
prefix
,
self
.
transform_config
,
self
.
packed_modules_mapping
)
# choose quantization method
quant_method
:
LinearMethodBase
=
UnquantizedLinearMethod
()
if
quant_scheme
is
not
None
:
layer
.
scheme
=
quant_scheme
quant_method
=
CompressedTensorsLinearMethod
(
self
)
# choose transform method
if
any
((
input_tfms
,
output_tfms
)):
return
CompressedTensorsLinearTransformMethod
.
from_schemes
(
quant_method
,
input_tfms
,
output_tfms
)
else
:
return
quant_method
if
isinstance
(
layer
,
Attention
):
return
CompressedTensorsKVCacheMethod
(
self
)
if
isinstance
(
layer
,
FusedMoE
):
...
...
@@ -129,6 +148,7 @@ class CompressedTensorsConfig(QuantizationConfig):
config
=
config
)
sparsity_scheme_map
,
sparsity_ignore_list
=
cls
.
_parse_sparsity_config
(
config
=
config
)
transform_config
=
config
.
get
(
"transform_config"
)
return
cls
(
target_scheme_map
=
target_scheme_map
,
...
...
@@ -137,6 +157,7 @@ class CompressedTensorsConfig(QuantizationConfig):
sparsity_scheme_map
=
sparsity_scheme_map
,
sparsity_ignore_list
=
sparsity_ignore_list
,
config
=
config
,
transform_config
=
transform_config
,
)
@
classmethod
...
...
@@ -537,9 +558,11 @@ class CompressedTensorsConfig(QuantizationConfig):
# Find the "target" in the compressed-tensors config
# that our layer conforms to.
# TODO (@robertgshaw): add compressed-tensors as dep
# so we do not have to re-write these functions
# need to make accelerate optional in ct to do this
# TODO (@kylesayrs): support ignore module names with ct matching utils
if
should_ignore_layer
(
layer_name
,
ignore
=
self
.
ignore
,
fused_mapping
=
self
.
packed_modules_mapping
):
return
None
# Will be empty for models with only sparsity
weight_quant
=
input_quant
=
None
...
...
@@ -722,7 +745,6 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
layer input. See LinearMethodBase for param details
"""
scheme
=
layer
.
scheme
if
scheme
is
None
:
raise
ValueError
(
"A scheme must be defined for each layer"
)
...
...
vllm/model_executor/layers/quantization/compressed_tensors/transform/linear.py
0 → 100644
View file @
22feac8e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Generator
from
itertools
import
accumulate
from
typing
import
Callable
,
Optional
import
torch
from
compressed_tensors.transform
import
(
TransformArgs
,
TransformConfig
,
TransformLocation
,
TransformScheme
)
from
compressed_tensors.utils
import
is_match
from
vllm.model_executor.layers.linear
import
(
WEIGHT_LOADER_V2_SUPPORTED
,
LinearMethodBase
,
QKVCrossParallelLinear
)
from
vllm.model_executor.layers.quantization.compressed_tensors.transform.module
import
(
# noqa: E501
HadamardTransform
)
from
vllm.model_executor.layers.quantization.compressed_tensors.transform.utils
import
(
# noqa: E501
TransformTuple
)
class CompressedTensorsLinearTransformMethod(LinearMethodBase):
    """
    Wraps `CompressedTensorsLinearMethod` or `UnquantizedLinearMethod` and adds
    input and output transforms to either side of the original apply method
    """

    @classmethod
    def from_schemes(
        cls, quant_method: LinearMethodBase, input_tfms: dict[int,
                                                              TransformTuple],
        output_tfms: dict[int, TransformTuple]
    ) -> "CompressedTensorsLinearTransformMethod":
        """
        Build a transform method from the wrapped linear method and the
        per-partition transform schemes.

        :param quant_method: linear method whose `apply` is being wrapped
        :param input_tfms: input transforms keyed by partition index
        :param output_tfms: output transforms keyed by partition index
        :return: transform method wrapping `quant_method`
        """
        assert input_tfms or output_tfms

        # TODO (@ksayers): implement QutlassLinearMethodNvFP4
        # hadacore and fwht can be selected by Transform module

        return cls(quant_method, input_tfms, output_tfms)

    def __init__(self, quant_method: LinearMethodBase,
                 input_tfms: dict[int, TransformTuple],
                 output_tfms: dict[int, TransformTuple]):
        self.quant_method = quant_method
        self.input_tfms = input_tfms
        self.output_tfms = output_tfms

        # populated by `create_weights` when the matching schemes are present
        self.input_transform: Optional[HadamardTransform] = None
        self.output_transform: Optional[HadamardTransform] = None

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: list[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """
        Create the wrapped method's weights, then register one transform
        submodule on `layer` for the input side and one for the output side
        (each only if schemes for that side exist).

        :raises ValueError: if the transform schemes fail validation
        """
        # get weight loader for transforms
        weight_loader: Callable = extra_weight_attrs.get(
            "weight_loader")  # type: ignore[assignment]

        # HACK: UnquantizedLinearMethod does not support weight loader v2, but
        # transforms (specifically SharedWeightParameter) requires
        # weight loader v2. Until UnquantizedLinearMethod supports v2, we must
        # hack around this by getting weight loader v1 so ULM can load correctly
        quant_method_name = self.quant_method.__class__.__name__
        if quant_method_name not in WEIGHT_LOADER_V2_SUPPORTED:
            if isinstance(layer, QKVCrossParallelLinear):
                weight_loader_v1 = layer.weight_loader_v1
            else:
                weight_loader_v1 = layer.weight_loader
            extra_weight_attrs["weight_loader"] = weight_loader_v1

        self.quant_method.create_weights(
            layer=layer,
            input_size_per_partition=input_size_per_partition,
            output_partition_sizes=output_partition_sizes,
            input_size=input_size,
            output_size=output_size,
            params_dtype=params_dtype,
            **extra_weight_attrs)

        # validate schemes
        num_partitions = len(output_partition_sizes)
        self._validate_tfm_schemes(num_partitions)

        # create submodules for weight loading
        if self.input_tfms:
            self.input_transform = self._register_transform(
                self.input_tfms, layer, weight_loader,
                input_size_per_partition, output_partition_sizes)

        if self.output_tfms:
            self.output_transform = self._register_transform(
                self.output_tfms, layer, weight_loader,
                input_size_per_partition, output_partition_sizes)

        # compute partition ranges for slicing activations
        # each entry is (start offset, length) along the output dimension
        starts = [0] + list(accumulate(output_partition_sizes))[:-1]
        self.partition_ranges = list(zip(starts, output_partition_sizes))

    def _register_transform(
            self, tfms: dict[int, TransformTuple], layer: torch.nn.Module,
            weight_loader: Callable, input_size_per_partition: int,
            output_partition_sizes: list[int]) -> HadamardTransform:
        """Create a transform for `tfms` and register it on `layer`."""
        # the submodule is named after the first scheme and its location
        first = next(iter(tfms.values()))
        transform_name = f"{first.scheme_name}_{first.args.location}"

        transform = HadamardTransform(tfms, layer, weight_loader,
                                      input_size_per_partition,
                                      output_partition_sizes)
        layer.register_module(transform_name, transform)
        return transform

    def process_weights_after_loading(self, layer):
        """
        Post-process the wrapped method's weights, then every
        `HadamardTransform` that is a direct child of `layer`.
        """
        self.quant_method.process_weights_after_loading(layer)

        for submodule in layer.children():
            if isinstance(submodule, HadamardTransform):
                submodule.process_weights_after_loading()

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Apply the input transform, the wrapped linear method, then the output
        transform partition by partition.

        :param layer: layer holding the weights
        :param x: input activations
        :param bias: must be None; bias is not supported here
        :return: transformed output activations
        """
        if self.input_transform is not None:
            x = self.input_transform(x)

        assert bias is None
        x = self.quant_method.apply(layer, x, bias)

        # TODO (@ksayers): Write a triton kernel to do this in parallel
        if self.output_transform is not None:
            # each partition occupies a slice of dimension 1 and is
            # transformed in place with its own partition id
            for part_id, (start, length) in enumerate(self.partition_ranges):
                x[:, start:start + length] = self.output_transform(
                    x[:, start:start + length], part_id=part_id)

        return x

    def _validate_tfm_schemes(self, num_partitions: int):
        """
        Check that the transform schemes are consistent across partitions.

        Input transforms must be present and equal for every partition index
        in `range(num_partitions)`. Output transforms must all share one
        scheme name and one location.

        :param num_partitions: number of output partitions of the layer
        :return: tuple of (input transforms, output transforms)
        :raises ValueError: if any of the checks above fails
        """
        if self.input_tfms:
            if 0 not in self.input_tfms:
                raise ValueError("Must have same input")

            reference = self.input_tfms[0]
            for part_index in range(num_partitions):
                # `.get` so that a partition with no input transform is
                # reported as a mismatch rather than surfacing as a KeyError
                if self.input_tfms.get(part_index) != reference:
                    raise ValueError("Must have same input")

        if self.output_tfms:
            first = next(iter(self.output_tfms.values()))
            scheme_name = first.scheme_name
            location = first.args.location

            for tfm in self.output_tfms.values():
                if tfm.scheme_name != scheme_name:
                    raise ValueError("Must have same scheme name")
                if tfm.args.location != location:
                    raise ValueError("Must have same location")

        return self.input_tfms, self.output_tfms
def get_linear_transform_schemes(
    layer: torch.nn.Module, layer_name: str,
    transform_config: Optional[TransformConfig],
    packed_modules_mapping: dict[str, list[str]]
) -> tuple[dict[int, TransformTuple], dict[
        int, TransformTuple]]:  # [input_transform, [output_transform, ...]]
    """
    Collect the online transforms that apply to each partition of a layer.

    :param layer: module the transforms are matched against
    :param layer_name: name of the (possibly fused) layer
    :param transform_config: transform config, or None for no transforms
    :param packed_modules_mapping: fused suffix to partition suffixes
    :return: (input transforms, output transforms), each keyed by the index
        of the partition they apply to
    :raises ValueError: if a matching online transform has a location other
        than input or output
    """
    # there can only be one transform input scheme per (fused) module
    found_inputs = {}
    found_outputs = {}

    part_names = get_layer_partition_names(layer_name, packed_modules_mapping)

    for scheme_name, scheme, args in get_schemes_args(transform_config):
        for index, name in enumerate(part_names):
            matches = is_match(name, layer, args.targets, args.ignore)
            if not (matches and args.is_online()):
                continue

            if args.location == TransformLocation.INPUT:
                destination = found_inputs
            elif args.location == TransformLocation.OUTPUT:
                destination = found_outputs
            else:
                raise ValueError(f"Cannot apply `{args.location}` "
                                 f"transform to `{layer_name}`")

            destination[index] = TransformTuple(scheme_name, scheme, args)

    return (found_inputs, found_outputs)
def get_schemes_args(
    transform_config: Optional[TransformConfig]
) -> Generator[tuple[str, TransformScheme, TransformArgs]]:
    """
    Yield one `(scheme_name, scheme, args)` tuple for every entry of every
    scheme's `apply` list, in config order.

    :param transform_config: transform config; None yields nothing
    """
    if transform_config is None:
        return

    for name, scheme in transform_config.config_groups.items():
        yield from ((name, scheme, tfm_args) for tfm_args in scheme.apply)
def get_layer_partition_names(
        layer_name: str, packed_modules_mapping: dict[str,
                                                      list[str]]) -> list[str]:
    """
    Get all partition names associated with this layer.
    Names are returned in order of their partition indices.

    If `layer_name` ends with a fused suffix from `packed_modules_mapping`,
    that suffix is replaced by each of its partition suffixes in turn (the
    first matching mapping entry wins). Otherwise the layer is its own single
    partition.

    ```python
    mapping = {"gate_up_proj": ["gate_proj", "up_proj"]}

    assert get_layer_partition_names(
        "mlp.gate_up_proj", mapping) == ["mlp.gate_proj", "mlp.up_proj"]
    assert get_layer_partition_names(
        "mlp.down_proj", mapping) == ["mlp.down_proj"]
    ```

    :param layer_name: name of the (possibly fused) layer
    :param packed_modules_mapping: fused suffix to partition suffixes
    :return: partition names, prefix included
    """
    for fused_suffix, part_suffixes in packed_modules_mapping.items():
        if layer_name.endswith(fused_suffix):
            prefix = layer_name.removesuffix(fused_suffix)
            return [prefix + part_suffix for part_suffix in part_suffixes]

    return [layer_name]
vllm/model_executor/layers/quantization/compressed_tensors/transform/module.py
0 → 100644
View file @
22feac8e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
math
from
collections.abc
import
Hashable
from
typing
import
Callable
,
Optional
import
torch
from
compressed_tensors.transform
import
TransformLocation
,
TransformScheme
from
torch
import
Tensor
from
vllm.distributed.parallel_state
import
(
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.layers.linear
import
LinearBase
from
vllm.model_executor.layers.quantization.compressed_tensors.transform.utils
import
(
# noqa: E501
TransformTuple
)
from
vllm.model_executor.layers.utils
import
dispatch_unquantized_gemm
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
from
vllm.model_executor.parameter
import
SharedWeightParameter
class HadamardTransform(torch.nn.Module):
    """
    Class which handles weight loading, postprocessing, and application of
    transforms. Meant to be used with `CompressedTensorsLinearTransformMethod`
    and attention transforms method (not implemented yet)
    """
    transforms: dict[int, TransformTuple]  # info parsed from transforms config
    weight: SharedWeightParameter  # container for shared tensors

    kernel: Callable  # function used during application
    scales: dict[int, float]  # per-partition scale, 1 / sqrt(matrix.size(0))

    def __init__(self,
                 transforms: dict[int, TransformTuple],
                 layer: torch.nn.Module,
                 weight_loader: Callable,
                 input_size_per_partition: int,
                 output_partition_sizes: list[int],
                 kernel: Optional[Callable] = None):
        """
        :param transforms: transform info keyed by partition index
        :param layer: layer the transforms are attached to
        :param weight_loader: loader used for the transform weights
        :param input_size_per_partition: input size of the layer's weight
        :param output_partition_sizes: output size of each partition
        :param kernel: function used to apply the transform; inferred when
            not given
        :raises NotImplementedError: if tensor parallelism is enabled
        """
        super().__init__()
        self.transforms = transforms
        self.scales = {}

        if get_tensor_model_parallel_world_size() > 1:
            raise NotImplementedError("Online transforms with tensor "
                                      "parallelism is not supported")

        # Similar to row/col parallel params, but tensors are separate
        # to allow for loading with shared memory
        self.weight = SharedWeightParameter(weight_loader=weight_loader)

        # create shared partition data for each partition of the original weight
        input_size = input_size_per_partition
        for part_index, (_scheme_name, scheme,
                         args) in self.transforms.items():
            output_size = output_partition_sizes[part_index]
            weight_size = self._get_weight_size(layer, args.location,
                                                input_size, output_size)

            # partitions with an equal key share one square tensor
            data_key = self._get_data_key(scheme, weight_size)
            self.weight.add_partition(
                part_index,
                data_key,
                size=(weight_size, weight_size),
                dtype=scheme.precision,
            )

        # validate that shared tensors and schemes are correct
        self._validate_input_transforms()

        # select kernel based on transform schemes
        self.kernel = self._infer_kernel() if kernel is None else kernel

    def process_weights_after_loading(self):
        """Finalize the loaded weights and precompute per-partition scales."""
        # required by torch.compile
        # done once up front: the call converts every partition, so repeating
        # it for each partition inside the loop below is redundant
        if self.weight.partitions:
            self.weight.process_weights_after_loading()

        for part_id, partition in self.weight.partitions.items():
            # precompute scale as a runtime multiply, not division
            # do not fold into weight in order to utilize FWHT
            self.scales[part_id] = 1 / math.sqrt(partition.data.size(0))

            # FUTURE: avoid runtime transpose by processing weights
            # prior to apply

    def forward(self, value: Tensor, part_id: int = 0) -> Tensor:
        """
        Apply the transform of partition `part_id` to `value`.

        :param value: activations to transform
        :param part_id: index of the partition whose transform is applied
        :return: transformed activations in the dtype of `value`, or `value`
            unchanged when the partition has no transform
        """
        if part_id not in self.weight.partitions:
            return value

        weight = self.weight.partitions[part_id]
        weight = weight if self.transforms[
            part_id].args.inverse else weight.T  # linear := x(W.T)
        scale = self.scales[part_id]
        # compute in the weight's dtype, then cast back before scaling
        return self.kernel(self, value.to(weight.dtype), weight, None).to(
            value.dtype) * scale

    def _get_data_key(self, scheme: TransformScheme,
                      weight_size: int) -> Hashable:
        """Return the key under which a shared tensor is looked up."""
        # NOTE(review): the key uses the identity of the scheme object, so
        # sharing relies on the same scheme instance being passed each time
        return (id(scheme), weight_size)

    def _get_weight_size(self, layer: torch.nn.Module,
                         location: TransformLocation, input_size: int,
                         output_size: int) -> int:
        """
        Return the side length of the square transform weight.

        :raises ValueError: if the layer type or location is not supported
        """
        if isinstance(layer, LinearBase):
            if location == TransformLocation.INPUT:
                return input_size

            elif location == TransformLocation.OUTPUT:
                return output_size

        elif isinstance(layer, VocabParallelEmbedding):
            # sizes are swapped relative to linear layers
            if location == TransformLocation.INPUT:
                return output_size

            elif location == TransformLocation.OUTPUT:
                return input_size

        raise ValueError(
            f"Cannot determine transform weight size for layer of type "
            f"`{type(layer).__name__}` at location `{location}`")

    def _validate_input_transforms(self):
        """
        Check that all partitions of an input transform share one tensor.

        :raises ValueError: if the partitions do not share the same data
        """
        assert len(self.transforms) > 0
        location = next(iter(self.transforms.values())).args.location

        if location == TransformLocation.INPUT:
            partitions = list(self.weight.partitions.values())
            first_ptr = partitions[0].data.data_ptr()
            for partition in partitions:
                if partition.data.data_ptr() != first_ptr:
                    raise ValueError(
                        "All partitions of an input transform must share "
                        "the same weight tensor")

    def _infer_kernel(self) -> Callable:
        """Return the function used to apply the transform."""
        # TODO (@ksayers): use fwht, hadacore
        return dispatch_unquantized_gemm()
vllm/model_executor/layers/quantization/compressed_tensors/transform/schemes/linear_qutlass_nvfp4.py
0 → 100644
View file @
22feac8e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Optional
import
torch
from
vllm.model_executor.layers.quantization.compressed_tensors.transform.linear
import
(
# noqa: E501
CompressedTensorsLinearTransformMethod
)
# Because qutlass fuses hadamard with quantization, it cannot automatically be
# composed with kernels in the way CompressedTensorsLinearTransformMethod does.
# Therefore, a separate scheme must be created for each quantized dtype
class QutlassLinearMethodNvFP4(CompressedTensorsLinearTransformMethod):
    """
    Placeholder for a linear method with a fused hadamard-quant kernel.
    `apply` is not implemented yet.
    """

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        :raises NotImplementedError: always
        """
        # fused hadamard quant linear method
        raise NotImplementedError()
vllm/model_executor/layers/quantization/compressed_tensors/transform/utils.py
0 → 100644
View file @
22feac8e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
NamedTuple
from
compressed_tensors.transform
import
TransformArgs
,
TransformScheme
__all__
=
[
"TransformTuple"
]
class TransformTuple(NamedTuple):
    """One transform scheme entry: its name, the scheme and its args."""

    # name of the scheme
    scheme_name: str
    # scheme object itself
    scheme: TransformScheme
    # one set of args belonging to the scheme
    args: TransformArgs
vllm/model_executor/parameter.py
View file @
22feac8e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Hashable
from
fractions
import
Fraction
from
typing
import
Callable
,
Optional
,
Union
from
weakref
import
WeakValueDictionary
import
torch
from
torch.nn
import
Parameter
from
vllm.distributed
import
get_tensor_model_parallel_rank
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
from
vllm.logger
import
init_logger
from
vllm.model_executor.utils
import
_make_synced_weight_loader
...
...
@@ -27,7 +30,7 @@ class BasevLLMParameter(Parameter):
into the parameter when the provided weight loader is called.
"""
def
__new__
(
cls
,
data
:
torch
.
Tensor
,
**
kwargs
):
def
__new__
(
cls
,
data
:
Optional
[
torch
.
Tensor
]
,
**
kwargs
):
return
super
().
__new__
(
cls
,
data
=
data
,
requires_grad
=
False
)
...
...
@@ -81,6 +84,17 @@ class BasevLLMParameter(Parameter):
def
load_qkv_weight
(
self
,
loaded_weight
:
torch
.
Tensor
,
**
kwargs
):
self
.
_assert_and_load
(
loaded_weight
)
def
_shard_id_as_int
(
self
,
shard_id
:
Union
[
str
,
int
])
->
int
:
if
isinstance
(
shard_id
,
int
):
return
shard_id
# if not int, assume shard_id for qkv
# map to int and return
qkv_idxs
=
{
"q"
:
0
,
"k"
:
1
,
"v"
:
2
}
assert
isinstance
(
shard_id
,
str
)
assert
shard_id
in
qkv_idxs
return
qkv_idxs
[
shard_id
]
class
_ColumnvLLMParameter
(
BasevLLMParameter
):
"""
...
...
@@ -113,6 +127,7 @@ class _ColumnvLLMParameter(BasevLLMParameter):
shard_offset
=
kwargs
.
get
(
"shard_offset"
)
shard_size
=
kwargs
.
get
(
"shard_size"
)
# TODO: move these to PackedColumnParameter and PackedvLLMParameter
if
isinstance
(
self
,
(
PackedColumnParameter
,
...
...
@@ -137,6 +152,7 @@ class _ColumnvLLMParameter(BasevLLMParameter):
shard_id
=
kwargs
.
get
(
"shard_id"
)
num_heads
=
kwargs
.
get
(
"num_heads"
)
# TODO: move these to PackedColumnParameter and PackedvLLMParameter
if
isinstance
(
self
,
(
PackedColumnParameter
,
...
...
@@ -224,19 +240,8 @@ class PerTensorScaleParameter(BasevLLMParameter):
"""
def
__init__
(
self
,
**
kwargs
):
self
.
qkv_idxs
=
{
"q"
:
0
,
"k"
:
1
,
"v"
:
2
}
super
().
__init__
(
**
kwargs
)
def
_shard_id_as_int
(
self
,
shard_id
:
Union
[
str
,
int
])
->
int
:
if
isinstance
(
shard_id
,
int
):
return
shard_id
# if not int, assume shard_id for qkv
# map to int and return
assert
isinstance
(
shard_id
,
str
)
assert
shard_id
in
self
.
qkv_idxs
return
self
.
qkv_idxs
[
shard_id
]
# For row parallel layers, no sharding needed
# load weight into parameter as is
def
load_row_parallel_weight
(
self
,
*
args
,
**
kwargs
):
...
...
@@ -373,6 +378,141 @@ class BlockQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter):
pass
class SharedWeightParameter(BasevLLMParameter):
    """
    Parameter for weights with many shared tensors across a model

    For example, when applying transforms to the "gate" and "up" partitions of
    `MergedColumnParallelLinear`, the transform weights must stay separate
    tensors in order to allow for tensor memory sharing between layers.
    """
    # global registry for sharing tensors based on passed `data_key`
    # this dict holds weakrefs to avoid memory leak after model cleanup
    tensors_registry: WeakValueDictionary = WeakValueDictionary()

    # local container for strong references to shared tensors
    # this set compensates for the fact that torch.nn.Parameter
    # and Parameter subclasses do not hold reliable references to tensors
    local_tensors: set[torch.Tensor]

    # dictionary mapping partition indices to associated parameters
    partitions: dict[int, Union[ModelWeightParameter, Parameter]]

    def __new__(cls, **kwargs):
        # the parameter itself holds no tensor; data lives in `partitions`
        return super().__new__(cls, data=None, **kwargs)

    def __init__(self, input_dim: int = 1, output_dim: int = 0, **kwargs):
        """
        :param input_dim: input dimension passed to each partition parameter
        :param output_dim: output dimension passed to each partition parameter
        :param kwargs: `weight_loader` is read from here
        :raises NotImplementedError: if tensor parallelism is enabled
        """
        weight_loader: Callable = kwargs.get(
            "weight_loader")  # type: ignore[assignment]
        super().__init__(data=None, weight_loader=weight_loader)

        self.local_tensors = set()
        self.partitions = {}
        # arguments used to construct every partition parameter; partitions
        # get a loader that raises, so loading must go through this class
        self.kwargs = {
            "input_dim": input_dim,
            "output_dim": output_dim,
            "weight_loader": self._fake_weight_loader
        }

        self.tp_rank = get_tensor_model_parallel_rank()
        self.tp_size = get_tensor_model_parallel_world_size()

        if self.tp_size > 1:
            raise NotImplementedError(f"{self.__class__.__name__} does not "
                                      "currently support tensor parallelism")

    def add_partition(self, index: int, data_key: Hashable, *args, **kwargs):
        """
        Add a partition to the weight parameter. Partitions whose `data_key`
        is the same will share tensor data

        :param index: index of partition to add
        :param data_key: hashable key used to key shared tensors
        :param *args: arguments for `torch.empty`
        :param **kwargs: keyword arguments for `torch.empty`
        """
        # load (shared) tensor using `data_key`
        # a single `get` avoids checking membership and then indexing an
        # entry of the weak registry that may have been collected in between
        data = self.tensors_registry.get(data_key)
        if data is None:
            data = torch.empty(*args, **kwargs)
            self.tensors_registry[data_key] = data

        # create associated model parameter
        self.partitions[index] = ModelWeightParameter(
            data=data, **self.kwargs)  # type: ignore[arg-type]

        # hold local reference, since ModelWeightParameter does not
        # see https://github.com/pytorch/pytorch/issues/75932
        self.local_tensors.add(data)

    def load_column_parallel_weight(self, loaded_weight: torch.Tensor):
        """Load into the single partition (index 0) of this parameter."""
        assert len(self.partitions) == 1 and 0 in self.partitions
        partition = self.partitions[0]

        ModelWeightParameter.load_column_parallel_weight(
            partition, loaded_weight)

    def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
        """Load into the single partition (index 0) of this parameter."""
        assert len(self.partitions) == 1 and 0 in self.partitions
        partition = self.partitions[0]

        ModelWeightParameter.load_row_parallel_weight(partition, loaded_weight)

    def _select_partition_shard(self, shard_id: Union[str, int]):
        """Return (partition, shard_offset, shard_size) for a shard id."""
        partition = self.partitions[self._shard_id_as_int(shard_id)]

        input_dim = self.kwargs.get("input_dim")
        shard_size = partition.data.size(input_dim) // self.tp_size
        shard_offset = self.tp_rank * shard_size
        return partition, shard_offset, shard_size

    def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
        """
        Load into the partition selected by the `shard_id` keyword argument.
        """
        partition, shard_offset, shard_size = self._select_partition_shard(
            kwargs.pop("shard_id"))

        ModelWeightParameter.load_merged_column_weight(
            partition,
            loaded_weight,
            shard_offset=shard_offset,
            shard_size=shard_size)

    def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
        """
        Load into the partition selected by the `shard_id` keyword argument;
        `num_heads` is forwarded from the keyword arguments.
        """
        partition, shard_offset, shard_size = self._select_partition_shard(
            kwargs.pop("shard_id"))

        shard_id = "q"  # fake first partition
        num_heads = kwargs.get("num_heads")

        ModelWeightParameter.load_qkv_weight(
            partition,
            loaded_weight,
            shard_offset=shard_offset,
            shard_size=shard_size,
            shard_id=shard_id,
            num_heads=num_heads,
        )

    def process_weights_after_loading(self):
        """Replace every partition with a plain `torch.nn.Parameter`."""
        for key in self.partitions:
            self.partitions[key] = torch.nn.Parameter(
                data=self.partitions[key].data, requires_grad=False)

    @property
    def data(self):
        # this parameter has no tensor of its own; fail loudly rather than
        # returning something misleading
        raise ValueError(
            f"Accessing `data` of a `{self.__class__.__name__}` is not "
            "allowed. Instead, use `partitions[index]` to get the weight of "
            "the particular partition you want to access")

    def _fake_weight_loader(self, param: BasevLLMParameter,
                            loaded_weight: torch.Tensor,
                            loaded_weight_shard_id: Optional[Union[str, int]]):
        """Loader installed on partitions; always raises `ValueError`."""
        raise ValueError("When loading partition weights of "
                         f"{self.__class__.__name__}, use methods provided by "
                         f"{self.__class__.__name__}, not partition loader")
def
permute_param_layout_
(
param
:
BasevLLMParameter
,
input_dim
:
int
,
output_dim
:
int
,
**
kwargs
)
->
BasevLLMParameter
:
"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment