Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
22feac8e
Unverified
Commit
22feac8e
authored
Aug 28, 2025
by
Kyle Sayers
Committed by
GitHub
Aug 28, 2025
Browse files
[Transform] [Quantization] Add transforms to compressed tensors (#22486)
parent
c8851a47
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
661 additions
and
36 deletions
+661
-36
tests/conftest.py
tests/conftest.py
+37
-6
tests/quantization/test_compressed_tensors.py
tests/quantization/test_compressed_tensors.py
+22
-0
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+15
-1
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
...ers/quantization/compressed_tensors/compressed_tensors.py
+37
-15
vllm/model_executor/layers/quantization/compressed_tensors/transform/linear.py
...ayers/quantization/compressed_tensors/transform/linear.py
+227
-0
vllm/model_executor/layers/quantization/compressed_tensors/transform/module.py
...ayers/quantization/compressed_tensors/transform/module.py
+135
-0
vllm/model_executor/layers/quantization/compressed_tensors/transform/schemes/linear_qutlass_nvfp4.py
...pressed_tensors/transform/schemes/linear_qutlass_nvfp4.py
+21
-0
vllm/model_executor/layers/quantization/compressed_tensors/transform/utils.py
...layers/quantization/compressed_tensors/transform/utils.py
+13
-0
vllm/model_executor/parameter.py
vllm/model_executor/parameter.py
+154
-14
No files found.
tests/conftest.py
View file @
22feac8e
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
json
import
json
import
math
import
os
import
os
import
tempfile
import
tempfile
from
enum
import
Enum
from
enum
import
Enum
from
typing
import
Any
,
Callable
,
Optional
,
TypedDict
,
TypeVar
,
Union
from
typing
import
Any
,
Callable
,
Optional
,
TypedDict
,
TypeVar
,
Union
,
cast
import
numpy
as
np
import
numpy
as
np
import
pytest
import
pytest
...
@@ -33,6 +34,7 @@ from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
...
@@ -33,6 +34,7 @@ from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.outputs
import
RequestOutput
from
vllm.outputs
import
RequestOutput
from
vllm.sampling_params
import
BeamSearchParams
from
vllm.sampling_params
import
BeamSearchParams
from
vllm.sequence
import
Logprob
from
vllm.transformers_utils.utils
import
maybe_model_redirect
from
vllm.transformers_utils.utils
import
maybe_model_redirect
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -602,7 +604,7 @@ class HfRunner:
...
@@ -602,7 +604,7 @@ class HfRunner:
def
_hidden_states_to_logprobs
(
def
_hidden_states_to_logprobs
(
self
,
self
,
hidden_states
:
tuple
[
tuple
[
torch
.
Tensor
,
...],
...],
hidden_states
:
tuple
[
tuple
[
torch
.
Tensor
,
...],
...],
num_logprobs
:
int
,
num_logprobs
:
Optional
[
int
]
,
)
->
tuple
[
list
[
dict
[
int
,
float
]],
int
]:
)
->
tuple
[
list
[
dict
[
int
,
float
]],
int
]:
seq_logprobs
=
self
.
_hidden_states_to_seq_logprobs
(
hidden_states
)
seq_logprobs
=
self
.
_hidden_states_to_seq_logprobs
(
hidden_states
)
output_len
=
len
(
hidden_states
)
output_len
=
len
(
hidden_states
)
...
@@ -630,7 +632,7 @@ class HfRunner:
...
@@ -630,7 +632,7 @@ class HfRunner:
self
,
self
,
prompts
:
list
[
str
],
prompts
:
list
[
str
],
max_tokens
:
int
,
max_tokens
:
int
,
num_logprobs
:
int
,
num_logprobs
:
Optional
[
int
]
,
images
:
Optional
[
PromptImageInput
]
=
None
,
images
:
Optional
[
PromptImageInput
]
=
None
,
audios
:
Optional
[
PromptAudioInput
]
=
None
,
audios
:
Optional
[
PromptAudioInput
]
=
None
,
videos
:
Optional
[
PromptVideoInput
]
=
None
,
videos
:
Optional
[
PromptVideoInput
]
=
None
,
...
@@ -677,7 +679,7 @@ class HfRunner:
...
@@ -677,7 +679,7 @@ class HfRunner:
self
,
self
,
encoder_decoder_prompts
:
list
[
ExplicitEncoderDecoderPrompt
[
str
,
str
]],
encoder_decoder_prompts
:
list
[
ExplicitEncoderDecoderPrompt
[
str
,
str
]],
max_tokens
:
int
,
max_tokens
:
int
,
num_logprobs
:
int
,
num_logprobs
:
Optional
[
int
]
,
images
:
Optional
[
PromptImageInput
]
=
None
,
images
:
Optional
[
PromptImageInput
]
=
None
,
**
kwargs
:
Any
,
**
kwargs
:
Any
,
)
->
list
[
TokensTextLogprobs
]:
)
->
list
[
TokensTextLogprobs
]:
...
@@ -966,7 +968,7 @@ class VllmRunner:
...
@@ -966,7 +968,7 @@ class VllmRunner:
self
,
self
,
prompts
:
list
[
str
],
prompts
:
list
[
str
],
max_tokens
:
int
,
max_tokens
:
int
,
num_logprobs
:
int
,
num_logprobs
:
Optional
[
int
]
,
num_prompt_logprobs
:
Optional
[
int
]
=
None
,
num_prompt_logprobs
:
Optional
[
int
]
=
None
,
images
:
Optional
[
PromptImageInput
]
=
None
,
images
:
Optional
[
PromptImageInput
]
=
None
,
audios
:
Optional
[
PromptAudioInput
]
=
None
,
audios
:
Optional
[
PromptAudioInput
]
=
None
,
...
@@ -991,11 +993,40 @@ class VllmRunner:
...
@@ -991,11 +993,40 @@ class VllmRunner:
videos
=
videos
,
videos
=
videos
,
**
kwargs
)
**
kwargs
)
def
generate_prompt_perplexity
(
self
,
prompts
:
list
[
str
])
->
list
[
float
]:
"""
Return the perplexity score associated with generating the prompts
:param prompts: list of prompts to score
:return: perplexity score of each prompt
"""
outputs
=
self
.
generate_greedy_logprobs
(
prompts
,
max_tokens
=
1
,
num_logprobs
=
None
,
num_prompt_logprobs
=
0
)
perplexities
=
[]
for
output
in
outputs
:
output
=
cast
(
TokensTextLogprobsPromptLogprobs
,
output
)
token_datas
=
cast
(
list
[
Optional
[
dict
[
int
,
Logprob
]]],
output
[
3
])
assert
token_datas
[
0
]
is
None
token_log_probs
=
[]
for
token_data
in
token_datas
[
1
:]:
assert
token_data
is
not
None
assert
len
(
token_data
)
==
1
token_log_prob
=
list
(
token_data
.
values
())[
0
].
logprob
token_log_probs
.
append
(
token_log_prob
)
perplexity
=
math
.
exp
(
-
sum
(
token_log_probs
)
/
len
(
token_log_probs
))
perplexities
.
append
(
perplexity
)
return
perplexities
def
generate_encoder_decoder_greedy_logprobs
(
def
generate_encoder_decoder_greedy_logprobs
(
self
,
self
,
encoder_decoder_prompts
:
list
[
ExplicitEncoderDecoderPrompt
[
str
,
str
]],
encoder_decoder_prompts
:
list
[
ExplicitEncoderDecoderPrompt
[
str
,
str
]],
max_tokens
:
int
,
max_tokens
:
int
,
num_logprobs
:
int
,
num_logprobs
:
Optional
[
int
]
,
num_prompt_logprobs
:
Optional
[
int
]
=
None
,
num_prompt_logprobs
:
Optional
[
int
]
=
None
,
skip_special_tokens
:
bool
=
True
,
skip_special_tokens
:
bool
=
True
,
)
->
Union
[
list
[
TokensTextLogprobs
],
)
->
Union
[
list
[
TokensTextLogprobs
],
...
...
tests/quantization/test_compressed_tensors.py
View file @
22feac8e
...
@@ -719,3 +719,25 @@ def test_compressed_tensors_w4a8_fp8(vllm_runner, args):
...
@@ -719,3 +719,25 @@ def test_compressed_tensors_w4a8_fp8(vllm_runner, args):
output
=
llm
.
generate_greedy
(
"Hello my name is"
,
max_tokens
=
20
)
output
=
llm
.
generate_greedy
(
"Hello my name is"
,
max_tokens
=
20
)
print
(
output
)
print
(
output
)
assert
output
assert
output
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_cuda
(),
reason
=
"This test is skipped on non-CUDA platform."
)
@
pytest
.
mark
.
parametrize
(
"model,prompt,exp_perplexity"
,
[
(
"nm-testing/Llama-3.2-1B-Instruct-spinquantR1R2R4-w4a16"
,
"Flat is better than nested.
\n
Sparse is better than dense."
,
150.0
,
),
(
"nm-testing/Llama-3.2-1B-Instruct-quip-w4a16"
,
"Flat is better than nested.
\n
Sparse is better than dense."
,
150.0
,
),
])
def
test_compressed_tensors_transforms_perplexity
(
vllm_runner
,
model
,
prompt
,
exp_perplexity
):
with
vllm_runner
(
model
,
enforce_eager
=
True
)
as
llm
:
perplexity
=
llm
.
generate_prompt_perplexity
([
prompt
])[
0
]
print
(
perplexity
)
assert
perplexity
<=
exp_perplexity
\ No newline at end of file
vllm/model_executor/layers/linear.py
View file @
22feac8e
...
@@ -35,6 +35,7 @@ logger = init_logger(__name__)
...
@@ -35,6 +35,7 @@ logger = init_logger(__name__)
WEIGHT_LOADER_V2_SUPPORTED
=
[
WEIGHT_LOADER_V2_SUPPORTED
=
[
"CompressedTensorsLinearMethod"
,
"CompressedTensorsLinearMethod"
,
"CompressedTensorsLinearTransformMethod"
,
"BitBLASLinearMethod"
,
"BitBLASLinearMethod"
,
"GPTQBitBLASLinearMethod"
,
"GPTQBitBLASLinearMethod"
,
"AWQMarlinLinearMethod"
,
"AWQMarlinLinearMethod"
,
...
@@ -199,6 +200,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
...
@@ -199,6 +200,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
set_weight_attrs
(
weight
,
extra_weight_attrs
)
set_weight_attrs
(
weight
,
extra_weight_attrs
)
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
# special postprocessing for CPU SGL
if
current_platform
.
is_cpu
()
and
envs
.
VLLM_CPU_SGL_KERNEL
:
if
current_platform
.
is_cpu
()
and
envs
.
VLLM_CPU_SGL_KERNEL
:
from
vllm.model_executor.layers.utils
import
check_cpu_sgl_kernel
from
vllm.model_executor.layers.utils
import
check_cpu_sgl_kernel
N
,
K
=
layer
.
weight
.
size
()
N
,
K
=
layer
.
weight
.
size
()
...
@@ -1470,7 +1472,7 @@ class QKVCrossParallelLinear(LinearBase):
...
@@ -1470,7 +1472,7 @@ class QKVCrossParallelLinear(LinearBase):
self
.
bias
=
torch
.
nn
.
Parameter
()
self
.
bias
=
torch
.
nn
.
Parameter
()
set_weight_attrs
(
self
.
bias
,
{
set_weight_attrs
(
self
.
bias
,
{
"output_dim"
:
0
,
"output_dim"
:
0
,
"weight_loader"
:
self
.
weight_loader
,
"weight_loader"
:
self
.
weight_loader
_v1
,
})
})
else
:
else
:
self
.
bias
=
None
self
.
bias
=
None
...
@@ -1580,6 +1582,18 @@ class QKVCrossParallelLinear(LinearBase):
...
@@ -1580,6 +1582,18 @@ class QKVCrossParallelLinear(LinearBase):
k
,
v
=
kv_enc
.
split
(
self
.
kv_size
,
dim
=-
1
)
k
,
v
=
kv_enc
.
split
(
self
.
kv_size
,
dim
=-
1
)
return
q
,
k
,
v
return
q
,
k
,
v
def
weight_loader_v1
(
self
,
param
:
torch
.
nn
.
Parameter
,
loaded_weight
:
torch
.
Tensor
,
loaded_shard_id
:
Optional
[
str
]
=
None
):
# just like all other parameters, does not yet
# support loading bias with weight_loader_v2
layer
=
(
self
.
q_proj_decoder
if
loaded_shard_id
==
"q"
else
self
.
kv_proj_encoder
)
target_param
=
self
.
select_proj_params
(
layer
,
param
)
shard_id_args
=
(
loaded_shard_id
,
)
if
loaded_shard_id
!=
"q"
else
()
layer
.
weight_loader
(
target_param
,
loaded_weight
,
*
shard_id_args
)
def
weight_loader
(
self
,
def
weight_loader
(
self
,
param
:
torch
.
nn
.
Parameter
,
param
:
torch
.
nn
.
Parameter
,
loaded_weight
:
torch
.
Tensor
,
loaded_weight
:
torch
.
Tensor
,
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
View file @
22feac8e
...
@@ -11,6 +11,7 @@ from compressed_tensors.config import (CompressionFormat,
...
@@ -11,6 +11,7 @@ from compressed_tensors.config import (CompressionFormat,
from
compressed_tensors.quantization
import
(
QuantizationArgs
,
from
compressed_tensors.quantization
import
(
QuantizationArgs
,
QuantizationStrategy
,
QuantizationStrategy
,
QuantizationType
)
QuantizationType
)
from
compressed_tensors.transform
import
TransformConfig
from
pydantic
import
BaseModel
from
pydantic
import
BaseModel
import
vllm.envs
as
envs
import
vllm.envs
as
envs
...
@@ -30,6 +31,8 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
...
@@ -30,6 +31,8 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsW4A16Fp4
,
CompressedTensorsW4A16Sparse24
,
CompressedTensorsW4A16Fp4
,
CompressedTensorsW4A16Sparse24
,
CompressedTensorsW8A8Fp8
,
CompressedTensorsW8A8Int8
,
CompressedTensorsW8A8Fp8
,
CompressedTensorsW8A8Int8
,
CompressedTensorsW8A16Fp8
,
CompressedTensorsWNA16
)
CompressedTensorsW8A16Fp8
,
CompressedTensorsWNA16
)
from
vllm.model_executor.layers.quantization.compressed_tensors.transform.linear
import
(
# noqa: E501
CompressedTensorsLinearTransformMethod
,
get_linear_transform_schemes
)
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
find_matched_target
,
is_activation_quantization_format
,
find_matched_target
,
is_activation_quantization_format
,
should_ignore_layer
)
should_ignore_layer
)
...
@@ -60,6 +63,7 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -60,6 +63,7 @@ class CompressedTensorsConfig(QuantizationConfig):
sparsity_ignore_list
:
list
[
str
],
sparsity_ignore_list
:
list
[
str
],
kv_cache_scheme
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
kv_cache_scheme
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
config
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
config
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
transform_config
:
Optional
[
TransformConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
ignore
=
ignore
self
.
ignore
=
ignore
...
@@ -71,6 +75,12 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -71,6 +75,12 @@ class CompressedTensorsConfig(QuantizationConfig):
self
.
sparsity_ignore_list
=
sparsity_ignore_list
self
.
sparsity_ignore_list
=
sparsity_ignore_list
self
.
config
=
config
self
.
config
=
config
if
transform_config
is
not
None
:
self
.
transform_config
=
TransformConfig
.
model_validate
(
transform_config
)
else
:
self
.
transform_config
=
None
def
get_linear_method
(
self
)
->
"CompressedTensorsLinearMethod"
:
def
get_linear_method
(
self
)
->
"CompressedTensorsLinearMethod"
:
return
CompressedTensorsLinearMethod
(
self
)
return
CompressedTensorsLinearMethod
(
self
)
...
@@ -103,18 +113,27 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -103,18 +113,27 @@ class CompressedTensorsConfig(QuantizationConfig):
)
->
Optional
[
"QuantizeMethodBase"
]:
)
->
Optional
[
"QuantizeMethodBase"
]:
from
vllm.attention.layer
import
Attention
# Avoid circular import
from
vllm.attention.layer
import
Attention
# Avoid circular import
# Check if the layer is skipped for quantization.
# TODO (@robertgshaw2): support module names
if
should_ignore_layer
(
prefix
,
ignore
=
self
.
ignore
,
fused_mapping
=
self
.
packed_modules_mapping
):
return
UnquantizedLinearMethod
()
if
isinstance
(
layer
,
LinearBase
):
if
isinstance
(
layer
,
LinearBase
):
scheme
=
self
.
get_scheme
(
layer
=
layer
,
layer_name
=
prefix
)
# collect schemes
if
scheme
is
None
:
quant_scheme
=
self
.
get_scheme
(
layer
=
layer
,
layer_name
=
prefix
)
return
UnquantizedLinearMethod
()
input_tfms
,
output_tfms
=
get_linear_transform_schemes
(
layer
.
scheme
=
scheme
layer
,
prefix
,
self
.
transform_config
,
return
CompressedTensorsLinearMethod
(
self
)
self
.
packed_modules_mapping
)
# choose quantization method
quant_method
:
LinearMethodBase
=
UnquantizedLinearMethod
()
if
quant_scheme
is
not
None
:
layer
.
scheme
=
quant_scheme
quant_method
=
CompressedTensorsLinearMethod
(
self
)
# choose transform method
if
any
((
input_tfms
,
output_tfms
)):
return
CompressedTensorsLinearTransformMethod
.
from_schemes
(
quant_method
,
input_tfms
,
output_tfms
)
else
:
return
quant_method
if
isinstance
(
layer
,
Attention
):
if
isinstance
(
layer
,
Attention
):
return
CompressedTensorsKVCacheMethod
(
self
)
return
CompressedTensorsKVCacheMethod
(
self
)
if
isinstance
(
layer
,
FusedMoE
):
if
isinstance
(
layer
,
FusedMoE
):
...
@@ -129,6 +148,7 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -129,6 +148,7 @@ class CompressedTensorsConfig(QuantizationConfig):
config
=
config
)
config
=
config
)
sparsity_scheme_map
,
sparsity_ignore_list
=
cls
.
_parse_sparsity_config
(
sparsity_scheme_map
,
sparsity_ignore_list
=
cls
.
_parse_sparsity_config
(
config
=
config
)
config
=
config
)
transform_config
=
config
.
get
(
"transform_config"
)
return
cls
(
return
cls
(
target_scheme_map
=
target_scheme_map
,
target_scheme_map
=
target_scheme_map
,
...
@@ -137,6 +157,7 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -137,6 +157,7 @@ class CompressedTensorsConfig(QuantizationConfig):
sparsity_scheme_map
=
sparsity_scheme_map
,
sparsity_scheme_map
=
sparsity_scheme_map
,
sparsity_ignore_list
=
sparsity_ignore_list
,
sparsity_ignore_list
=
sparsity_ignore_list
,
config
=
config
,
config
=
config
,
transform_config
=
transform_config
,
)
)
@
classmethod
@
classmethod
...
@@ -537,9 +558,11 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -537,9 +558,11 @@ class CompressedTensorsConfig(QuantizationConfig):
# Find the "target" in the compressed-tensors config
# Find the "target" in the compressed-tensors config
# that our layer conforms to.
# that our layer conforms to.
# TODO (@robertgshaw): add compressed-tensors as dep
# TODO (@kylesayrs): support ignore module names with ct matching utils
# so we do not have to re-write these functions
if
should_ignore_layer
(
layer_name
,
# need to make accelerate optional in ct to do this
ignore
=
self
.
ignore
,
fused_mapping
=
self
.
packed_modules_mapping
):
return
None
# Will be empty for models with only sparsity
# Will be empty for models with only sparsity
weight_quant
=
input_quant
=
None
weight_quant
=
input_quant
=
None
...
@@ -722,7 +745,6 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
...
@@ -722,7 +745,6 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
layer input. See LinearMethodBase for param details
layer input. See LinearMethodBase for param details
"""
"""
scheme
=
layer
.
scheme
scheme
=
layer
.
scheme
if
scheme
is
None
:
if
scheme
is
None
:
raise
ValueError
(
"A scheme must be defined for each layer"
)
raise
ValueError
(
"A scheme must be defined for each layer"
)
...
...
vllm/model_executor/layers/quantization/compressed_tensors/transform/linear.py
0 → 100644
View file @
22feac8e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Generator
from
itertools
import
accumulate
from
typing
import
Callable
,
Optional
import
torch
from
compressed_tensors.transform
import
(
TransformArgs
,
TransformConfig
,
TransformLocation
,
TransformScheme
)
from
compressed_tensors.utils
import
is_match
from
vllm.model_executor.layers.linear
import
(
WEIGHT_LOADER_V2_SUPPORTED
,
LinearMethodBase
,
QKVCrossParallelLinear
)
from
vllm.model_executor.layers.quantization.compressed_tensors.transform.module
import
(
# noqa: E501
HadamardTransform
)
from
vllm.model_executor.layers.quantization.compressed_tensors.transform.utils
import
(
# noqa: E501
TransformTuple
)
class
CompressedTensorsLinearTransformMethod
(
LinearMethodBase
):
"""
Wraps `CompressedTensorsLinearMethod` or `UnquantizedLinearMethod` and adds
input and output transforms to either side of the original apply method
"""
@
classmethod
def
from_schemes
(
cls
,
quant_method
:
LinearMethodBase
,
input_tfms
:
dict
[
int
,
TransformTuple
],
output_tfms
:
dict
[
int
,
TransformTuple
]
)
->
"CompressedTensorsLinearTransformMethod"
:
assert
input_tfms
or
output_tfms
# TODO (@ksayers): implement QutlassLinearMethodNvFP4
# hadacore and fwht can be selected by Transform module
return
cls
(
quant_method
,
input_tfms
,
output_tfms
)
def
__init__
(
self
,
quant_method
:
LinearMethodBase
,
input_tfms
:
dict
[
int
,
TransformTuple
],
output_tfms
:
dict
[
int
,
TransformTuple
]):
self
.
quant_method
=
quant_method
self
.
input_tfms
=
input_tfms
self
.
output_tfms
=
output_tfms
self
.
input_transform
:
Optional
[
HadamardTransform
]
=
None
self
.
output_transform
:
Optional
[
HadamardTransform
]
=
None
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
input_size_per_partition
:
int
,
output_partition_sizes
:
list
[
int
],
input_size
:
int
,
output_size
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
):
# get weight loader for transforms
weight_loader
:
Callable
=
extra_weight_attrs
.
get
(
"weight_loader"
)
# type: ignore[assignment]
# HACK: UnquantizedLinearMethod does not support weight loader v2, but
# transforms (specifically SharedWeightParameter) requires
# weight loader v2. Until UnquantizedLinearMethod supports v2, we must
# hack around this by getting weight loader v1 so ULM can load correctly
quant_method_name
=
self
.
quant_method
.
__class__
.
__name__
if
quant_method_name
not
in
WEIGHT_LOADER_V2_SUPPORTED
:
if
isinstance
(
layer
,
QKVCrossParallelLinear
):
weight_loader_v1
=
layer
.
weight_loader_v1
else
:
weight_loader_v1
=
layer
.
weight_loader
extra_weight_attrs
[
"weight_loader"
]
=
weight_loader_v1
self
.
quant_method
.
create_weights
(
layer
=
layer
,
input_size_per_partition
=
input_size_per_partition
,
output_partition_sizes
=
output_partition_sizes
,
input_size
=
input_size
,
output_size
=
output_size
,
params_dtype
=
params_dtype
,
**
extra_weight_attrs
)
# validate schemes
num_partitions
=
len
(
output_partition_sizes
)
self
.
_validate_tfm_schemes
(
num_partitions
)
# create submodules for weight loading
if
len
(
self
.
input_tfms
)
>
0
:
scheme_name
=
list
(
self
.
input_tfms
.
values
())[
0
].
scheme_name
location
=
list
(
self
.
input_tfms
.
values
())[
0
].
args
.
location
transform_name
=
f
"
{
scheme_name
}
_
{
location
}
"
transform
=
HadamardTransform
(
self
.
input_tfms
,
layer
,
weight_loader
,
input_size_per_partition
,
output_partition_sizes
)
layer
.
register_module
(
transform_name
,
transform
)
self
.
input_transform
=
transform
if
len
(
self
.
output_tfms
)
>
0
:
scheme_name
=
list
(
self
.
output_tfms
.
values
())[
0
].
scheme_name
location
=
list
(
self
.
output_tfms
.
values
())[
0
].
args
.
location
transform_name
=
f
"
{
scheme_name
}
_
{
location
}
"
transform
=
HadamardTransform
(
self
.
output_tfms
,
layer
,
weight_loader
,
input_size_per_partition
,
output_partition_sizes
)
layer
.
register_module
(
transform_name
,
transform
)
self
.
output_transform
=
transform
# compute partition ranges for slicing activations
starts
=
[
0
]
+
list
(
accumulate
(
output_partition_sizes
))[:
-
1
]
self
.
partition_ranges
=
list
(
zip
(
starts
,
output_partition_sizes
))
def
process_weights_after_loading
(
self
,
layer
):
self
.
quant_method
.
process_weights_after_loading
(
layer
)
for
submodule
in
layer
.
children
():
if
isinstance
(
submodule
,
HadamardTransform
):
submodule
.
process_weights_after_loading
()
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
if
self
.
input_transform
is
not
None
:
x
=
self
.
input_transform
(
x
)
assert
bias
is
None
x
=
self
.
quant_method
.
apply
(
layer
,
x
,
bias
)
# TODO (@ksayers): Write a triton kernel to do this in parallel
if
self
.
output_transform
is
not
None
:
for
part_id
,
(
start
,
length
)
in
enumerate
(
self
.
partition_ranges
):
x
[:,
start
:
start
+
length
]
=
self
.
output_transform
(
x
[:,
start
:
start
+
length
],
part_id
=
part_id
)
return
x
def
_validate_tfm_schemes
(
self
,
num_partitions
:
int
):
if
len
(
self
.
input_tfms
)
>
0
:
if
0
not
in
self
.
input_tfms
:
raise
ValueError
(
"Must have same input"
)
for
part_index
in
range
(
num_partitions
):
if
self
.
input_tfms
[
part_index
]
!=
self
.
input_tfms
[
0
]:
raise
ValueError
(
"Must have same input"
)
if
len
(
self
.
output_tfms
)
>
0
:
scheme_name
=
list
(
self
.
output_tfms
.
values
())[
0
].
scheme_name
location
=
list
(
self
.
output_tfms
.
values
())[
0
].
args
.
location
for
tfm
in
self
.
output_tfms
.
values
():
if
tfm
.
scheme_name
!=
scheme_name
:
raise
ValueError
(
"Must have same scheme name"
)
if
tfm
.
args
.
location
!=
location
:
raise
ValueError
(
"Must have same location"
)
return
self
.
input_tfms
,
self
.
output_tfms
def
get_linear_transform_schemes
(
layer
:
torch
.
nn
.
Module
,
layer_name
:
str
,
transform_config
:
Optional
[
TransformConfig
],
packed_modules_mapping
:
dict
[
str
,
list
[
str
]]
)
->
tuple
[
dict
[
int
,
TransformTuple
],
dict
[
int
,
TransformTuple
]]:
# [input_transform, [output_transform, ...]]
# there can only be one transform input scheme per (fused) module
input_tfms
=
{}
output_tfms
=
{}
partition_names
=
get_layer_partition_names
(
layer_name
,
packed_modules_mapping
)
for
scheme_name
,
scheme
,
args
in
get_schemes_args
(
transform_config
):
for
part_index
,
part_name
in
enumerate
(
partition_names
):
if
is_match
(
part_name
,
layer
,
args
.
targets
,
args
.
ignore
)
and
args
.
is_online
():
if
args
.
location
==
TransformLocation
.
INPUT
:
input_tfms
[
part_index
]
=
TransformTuple
(
scheme_name
,
scheme
,
args
)
elif
args
.
location
==
TransformLocation
.
OUTPUT
:
output_tfms
[
part_index
]
=
TransformTuple
(
scheme_name
,
scheme
,
args
)
else
:
raise
ValueError
(
f
"Cannot apply `
{
args
.
location
}
` "
f
"transform to `
{
layer_name
}
`"
)
return
(
input_tfms
,
output_tfms
)
def
get_schemes_args
(
transform_config
:
Optional
[
TransformConfig
]
)
->
Generator
[
tuple
[
str
,
TransformScheme
,
TransformArgs
]]:
if
transform_config
is
None
:
return
for
scheme_name
,
scheme
in
transform_config
.
config_groups
.
items
():
for
args
in
scheme
.
apply
:
yield
(
scheme_name
,
scheme
,
args
)
def
get_layer_partition_names
(
layer_name
:
str
,
packed_modules_mapping
:
dict
[
str
,
list
[
str
]])
->
list
[
str
]:
"""
Get all partition names associated with this layer.
Names are returned in order of their partition indices.
```python
mapping = {"gate_up_proj", "gate_proj", "up_proj"}
assert get_layer_partition_names(
"mlp.gate_up_proj", mapping) == ["gate_proj", "up_proj"]
assert get_layer_partition_names(
"mlp.down_proj", mapping) == ["down_proj"]
"""
for
fused_suffix
,
part_suffixes
in
packed_modules_mapping
.
items
():
if
layer_name
.
endswith
(
fused_suffix
):
return
[
layer_name
.
removesuffix
(
fused_suffix
)
+
part_suffix
for
part_suffix
in
part_suffixes
]
return
[
layer_name
]
vllm/model_executor/layers/quantization/compressed_tensors/transform/module.py
0 → 100644
View file @
22feac8e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
math
from
collections.abc
import
Hashable
from
typing
import
Callable
,
Optional
import
torch
from
compressed_tensors.transform
import
TransformLocation
,
TransformScheme
from
torch
import
Tensor
from
vllm.distributed.parallel_state
import
(
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.layers.linear
import
LinearBase
from
vllm.model_executor.layers.quantization.compressed_tensors.transform.utils
import
(
# noqa: E501
TransformTuple
)
from
vllm.model_executor.layers.utils
import
dispatch_unquantized_gemm
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
from
vllm.model_executor.parameter
import
SharedWeightParameter
class
HadamardTransform
(
torch
.
nn
.
Module
):
"""
Class which handles weight loading, postprocessing, and application of
transforms. Meant to be used with `CompressedTensorsLinearTransformMethod`
and attention transforms method (not implemented yet)
"""
transforms
:
dict
[
int
,
TransformTuple
]
# info parsed from transforms config
weight
:
SharedWeightParameter
# container for shared tensors
kernel
:
Callable
# function used during application
scales
:
dict
[
int
,
float
]
# hadamard scale, usually sqrt(matrix.size(0))
def
__init__
(
self
,
transforms
:
dict
[
int
,
TransformTuple
],
layer
:
torch
.
nn
.
Module
,
weight_loader
:
Callable
,
input_size_per_partition
:
int
,
output_partition_sizes
:
list
[
int
],
kernel
:
Optional
[
Callable
]
=
None
):
super
().
__init__
()
self
.
transforms
=
transforms
self
.
scales
=
{}
if
get_tensor_model_parallel_world_size
()
>
1
:
raise
NotImplementedError
(
"Online transforms with tensor "
"parallelism is not supported"
)
# Similar to row/col parallel params, but tensors are separate
# to allow for loading with shared memory
self
.
weight
=
SharedWeightParameter
(
weight_loader
=
weight_loader
)
# create shared partition data for each partition of the original weight
input_size
=
input_size_per_partition
for
part_index
,
(
_scheme_name
,
scheme
,
args
)
in
self
.
transforms
.
items
():
output_size
=
output_partition_sizes
[
part_index
]
weight_size
=
self
.
_get_weight_size
(
layer
,
args
.
location
,
input_size
,
output_size
)
data_key
=
self
.
_get_data_key
(
scheme
,
weight_size
)
self
.
weight
.
add_partition
(
part_index
,
data_key
,
size
=
(
weight_size
,
weight_size
),
dtype
=
scheme
.
precision
,
)
# validate that shared tensors and schemes are correct
self
.
_validate_input_transforms
()
# select kernel based on transform schemes
self
.
kernel
=
self
.
_infer_kernel
()
if
kernel
is
None
else
kernel
def
process_weights_after_loading
(
self
):
for
part_id
in
self
.
weight
.
partitions
:
data
=
self
.
weight
.
partitions
[
part_id
].
data
# required by torch.compile
self
.
weight
.
process_weights_after_loading
()
# precompute scale as a runtime multiply, not division
# do not fold into weight in order to utilize FWHT
self
.
scales
[
part_id
]
=
1
/
math
.
sqrt
(
data
.
size
(
0
))
# FUTURE: avoid runtime tranpose by processing weights
# prior to apply
def
forward
(
self
,
value
:
Tensor
,
part_id
:
int
=
0
)
->
Tensor
:
if
part_id
not
in
self
.
weight
.
partitions
:
return
value
weight
=
self
.
weight
.
partitions
[
part_id
]
weight
=
weight
if
self
.
transforms
[
part_id
].
args
.
inverse
else
weight
.
T
# linear := x(W.T)
scale
=
self
.
scales
[
part_id
]
return
self
.
kernel
(
self
,
value
.
to
(
weight
.
dtype
),
weight
,
None
).
to
(
value
.
dtype
)
*
scale
def
_get_data_key
(
self
,
scheme
:
TransformScheme
,
weight_size
:
int
)
->
Hashable
:
return
(
id
(
scheme
),
weight_size
)
def
_get_weight_size
(
self
,
layer
:
torch
.
nn
.
Module
,
location
:
TransformLocation
,
input_size
:
int
,
output_size
:
int
)
->
int
:
if
isinstance
(
layer
,
LinearBase
):
if
location
==
TransformLocation
.
INPUT
:
return
input_size
elif
location
==
TransformLocation
.
OUTPUT
:
return
output_size
elif
isinstance
(
layer
,
VocabParallelEmbedding
):
if
location
==
TransformLocation
.
INPUT
:
return
output_size
elif
location
==
TransformLocation
.
OUTPUT
:
return
input_size
raise
ValueError
()
def
_validate_input_transforms
(
self
):
assert
len
(
self
.
transforms
)
>
0
location
=
list
(
self
.
transforms
.
values
())[
0
].
args
.
location
if
location
==
TransformLocation
.
INPUT
:
first_data
=
self
.
weight
.
partitions
[
0
].
data
for
partition
in
self
.
weight
.
partitions
.
values
():
if
partition
.
data
.
data_ptr
()
!=
first_data
.
data_ptr
():
raise
ValueError
(
""
)
def
_infer_kernel
(
self
)
->
Callable
:
# TODO (@ksayers): use fwht, hadacore
return
dispatch_unquantized_gemm
()
vllm/model_executor/layers/quantization/compressed_tensors/transform/schemes/linear_qutlass_nvfp4.py
0 → 100644
View file @
22feac8e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Optional
import
torch
from
vllm.model_executor.layers.quantization.compressed_tensors.transform.linear
import
(
# noqa: E501
CompressedTensorsLinearTransformMethod
)
# Because qutlass fuses hadamard with quantization, it cannot automatically be
# composed with kernels in the way CompressedTensorsLinearTransformMethod does.
# Therefore, a separate scheme must be created for each quantized dtype
class
QutlassLinearMethodNvFP4
(
CompressedTensorsLinearTransformMethod
):
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
# fused hadamard quant linear method
raise
NotImplementedError
()
vllm/model_executor/layers/quantization/compressed_tensors/transform/utils.py
0 → 100644
View file @
22feac8e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
NamedTuple
from
compressed_tensors.transform
import
TransformArgs
,
TransformScheme
__all__
=
[
"TransformTuple"
]
class
TransformTuple
(
NamedTuple
):
scheme_name
:
str
scheme
:
TransformScheme
args
:
TransformArgs
vllm/model_executor/parameter.py
View file @
22feac8e
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Hashable
from
fractions
import
Fraction
from
fractions
import
Fraction
from
typing
import
Callable
,
Optional
,
Union
from
typing
import
Callable
,
Optional
,
Union
from
weakref
import
WeakValueDictionary
import
torch
import
torch
from
torch.nn
import
Parameter
from
torch.nn
import
Parameter
from
vllm.distributed
import
get_tensor_model_parallel_rank
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.utils
import
_make_synced_weight_loader
from
vllm.model_executor.utils
import
_make_synced_weight_loader
...
@@ -27,7 +30,7 @@ class BasevLLMParameter(Parameter):
...
@@ -27,7 +30,7 @@ class BasevLLMParameter(Parameter):
into the parameter when the provided weight loader is called.
into the parameter when the provided weight loader is called.
"""
"""
def
__new__
(
cls
,
data
:
torch
.
Tensor
,
**
kwargs
):
def
__new__
(
cls
,
data
:
Optional
[
torch
.
Tensor
]
,
**
kwargs
):
return
super
().
__new__
(
cls
,
data
=
data
,
requires_grad
=
False
)
return
super
().
__new__
(
cls
,
data
=
data
,
requires_grad
=
False
)
...
@@ -81,6 +84,17 @@ class BasevLLMParameter(Parameter):
...
@@ -81,6 +84,17 @@ class BasevLLMParameter(Parameter):
def
load_qkv_weight
(
self
,
loaded_weight
:
torch
.
Tensor
,
**
kwargs
):
def
load_qkv_weight
(
self
,
loaded_weight
:
torch
.
Tensor
,
**
kwargs
):
self
.
_assert_and_load
(
loaded_weight
)
self
.
_assert_and_load
(
loaded_weight
)
def
_shard_id_as_int
(
self
,
shard_id
:
Union
[
str
,
int
])
->
int
:
if
isinstance
(
shard_id
,
int
):
return
shard_id
# if not int, assume shard_id for qkv
# map to int and return
qkv_idxs
=
{
"q"
:
0
,
"k"
:
1
,
"v"
:
2
}
assert
isinstance
(
shard_id
,
str
)
assert
shard_id
in
qkv_idxs
return
qkv_idxs
[
shard_id
]
class
_ColumnvLLMParameter
(
BasevLLMParameter
):
class
_ColumnvLLMParameter
(
BasevLLMParameter
):
"""
"""
...
@@ -113,6 +127,7 @@ class _ColumnvLLMParameter(BasevLLMParameter):
...
@@ -113,6 +127,7 @@ class _ColumnvLLMParameter(BasevLLMParameter):
shard_offset
=
kwargs
.
get
(
"shard_offset"
)
shard_offset
=
kwargs
.
get
(
"shard_offset"
)
shard_size
=
kwargs
.
get
(
"shard_size"
)
shard_size
=
kwargs
.
get
(
"shard_size"
)
# TODO: move these to PackedColumnParameter and PackedvLLMParameter
if
isinstance
(
if
isinstance
(
self
,
self
,
(
PackedColumnParameter
,
(
PackedColumnParameter
,
...
@@ -137,6 +152,7 @@ class _ColumnvLLMParameter(BasevLLMParameter):
...
@@ -137,6 +152,7 @@ class _ColumnvLLMParameter(BasevLLMParameter):
shard_id
=
kwargs
.
get
(
"shard_id"
)
shard_id
=
kwargs
.
get
(
"shard_id"
)
num_heads
=
kwargs
.
get
(
"num_heads"
)
num_heads
=
kwargs
.
get
(
"num_heads"
)
# TODO: move these to PackedColumnParameter and PackedvLLMParameter
if
isinstance
(
if
isinstance
(
self
,
self
,
(
PackedColumnParameter
,
(
PackedColumnParameter
,
...
@@ -224,19 +240,8 @@ class PerTensorScaleParameter(BasevLLMParameter):
...
@@ -224,19 +240,8 @@ class PerTensorScaleParameter(BasevLLMParameter):
"""
"""
def
__init__
(
self
,
**
kwargs
):
def
__init__
(
self
,
**
kwargs
):
self
.
qkv_idxs
=
{
"q"
:
0
,
"k"
:
1
,
"v"
:
2
}
super
().
__init__
(
**
kwargs
)
super
().
__init__
(
**
kwargs
)
def
_shard_id_as_int
(
self
,
shard_id
:
Union
[
str
,
int
])
->
int
:
if
isinstance
(
shard_id
,
int
):
return
shard_id
# if not int, assume shard_id for qkv
# map to int and return
assert
isinstance
(
shard_id
,
str
)
assert
shard_id
in
self
.
qkv_idxs
return
self
.
qkv_idxs
[
shard_id
]
# For row parallel layers, no sharding needed
# For row parallel layers, no sharding needed
# load weight into parameter as is
# load weight into parameter as is
def
load_row_parallel_weight
(
self
,
*
args
,
**
kwargs
):
def
load_row_parallel_weight
(
self
,
*
args
,
**
kwargs
):
...
@@ -373,6 +378,141 @@ class BlockQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter):
...
@@ -373,6 +378,141 @@ class BlockQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter):
pass
pass
class
SharedWeightParameter
(
BasevLLMParameter
):
"""
Parameter for weights with many shared tensors across a model
For example, when applying transforms to the "gate" and "up" partitions of
`MergedColumnParallelLinear`, the transform weights must stay separate
tensors in order to allow for tensor memory sharing between layers.
"""
# global registry for sharing tensors based on passed `data_key`
# this dict holds weaksrefs to avoid memory leak after model cleanup
tensors_registry
:
WeakValueDictionary
=
WeakValueDictionary
()
# local container for strong references to shared tensors
# this set compensates for the fact that torch.nn.Parameter
# and Parameter subclasses do not hold reliable references to tensors
local_tensors
:
set
[
torch
.
Tensor
]
# dictionary mapping partition indices to associated parameters
partitions
:
dict
[
int
,
Union
[
ModelWeightParameter
,
Parameter
]]
def
__new__
(
cls
,
**
kwargs
):
return
super
().
__new__
(
cls
,
data
=
None
,
**
kwargs
)
def
__init__
(
self
,
input_dim
:
int
=
1
,
output_dim
:
int
=
0
,
**
kwargs
):
weight_loader
:
Callable
=
kwargs
.
get
(
"weight_loader"
)
# type: ignore[assignment]
super
().
__init__
(
data
=
None
,
weight_loader
=
weight_loader
)
self
.
local_tensors
=
set
()
self
.
partitions
=
{}
self
.
kwargs
=
{
"input_dim"
:
input_dim
,
"output_dim"
:
output_dim
,
"weight_loader"
:
self
.
_fake_weight_loader
}
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
if
self
.
tp_size
>
1
:
raise
NotImplementedError
(
f
"
{
self
.
__class__
.
__name__
}
does not "
"currently support tensor parallelism"
)
def
add_partition
(
self
,
index
:
int
,
data_key
:
Hashable
,
*
args
,
**
kwargs
):
"""
Add a partition to the weight parameter. Partitions whose `data_key`
is the same will share tensor data
:param index: index of partition to add
:param data_key: hashable key used to key shared tensors
:param *args: arguments for `torch.empty`
:param **kwargs: keyword arguments for `torch.empty`
"""
# load (shared) tensor using `data_key`
if
data_key
not
in
self
.
tensors_registry
:
data
=
torch
.
empty
(
*
args
,
**
kwargs
)
self
.
tensors_registry
[
data_key
]
=
data
else
:
data
=
self
.
tensors_registry
[
data_key
]
# create associated model parameter
self
.
partitions
[
index
]
=
ModelWeightParameter
(
data
=
data
,
**
self
.
kwargs
)
# type: ignore[arg-type]
# hold local reference, since ModelWeightParameter does not
# see https://github.com/pytorch/pytorch/issues/75932
self
.
local_tensors
.
add
(
data
)
def
load_column_parallel_weight
(
self
,
loaded_weight
:
torch
.
Tensor
):
assert
len
(
self
.
partitions
)
==
1
and
0
in
self
.
partitions
partition
=
self
.
partitions
[
0
]
ModelWeightParameter
.
load_column_parallel_weight
(
partition
,
loaded_weight
)
def
load_row_parallel_weight
(
self
,
loaded_weight
:
torch
.
Tensor
):
assert
len
(
self
.
partitions
)
==
1
and
0
in
self
.
partitions
partition
=
self
.
partitions
[
0
]
ModelWeightParameter
.
load_row_parallel_weight
(
partition
,
loaded_weight
)
def
load_merged_column_weight
(
self
,
loaded_weight
:
torch
.
Tensor
,
**
kwargs
):
partition_id
=
kwargs
.
pop
(
"shard_id"
)
partition_id
=
self
.
_shard_id_as_int
(
partition_id
)
partition
=
self
.
partitions
[
partition_id
]
input_dim
=
self
.
kwargs
.
get
(
"input_dim"
)
shard_size
=
partition
.
data
.
size
(
input_dim
)
//
self
.
tp_size
shard_offset
=
self
.
tp_rank
*
shard_size
ModelWeightParameter
.
load_merged_column_weight
(
partition
,
loaded_weight
,
shard_offset
=
shard_offset
,
shard_size
=
shard_size
)
def
load_qkv_weight
(
self
,
loaded_weight
:
torch
.
Tensor
,
**
kwargs
):
partition_id
=
self
.
_shard_id_as_int
(
kwargs
.
pop
(
"shard_id"
))
partition
=
self
.
partitions
[
partition_id
]
input_dim
=
self
.
kwargs
.
get
(
"input_dim"
)
shard_size
=
partition
.
data
.
size
(
input_dim
)
//
self
.
tp_size
shard_offset
=
self
.
tp_rank
*
shard_size
shard_id
=
"q"
# fake first partition
num_heads
=
kwargs
.
get
(
"num_heads"
)
ModelWeightParameter
.
load_qkv_weight
(
partition
,
loaded_weight
,
shard_offset
=
shard_offset
,
shard_size
=
shard_size
,
shard_id
=
shard_id
,
num_heads
=
num_heads
,
)
def
process_weights_after_loading
(
self
):
for
key
in
self
.
partitions
:
self
.
partitions
[
key
]
=
torch
.
nn
.
Parameter
(
data
=
self
.
partitions
[
key
].
data
,
requires_grad
=
False
)
@
property
def
data
(
self
):
raise
ValueError
(
"Accessing `data` of a "
"`PartitionedModelWeightParameter` is not allowed. "
"Instead, use `get_partition` to get the weight of "
"the particular partition you want to access"
)
def
_fake_weight_loader
(
self
,
param
:
BasevLLMParameter
,
loaded_weight
:
torch
.
Tensor
,
loaded_weight_shard_id
:
Optional
[
Union
[
str
,
int
]]):
raise
ValueError
(
"When loading partition weights of "
f
"
{
self
.
__class__
.
__name__
}
, use methods provided by "
f
"
{
self
.
__class__
.
__name__
}
, not partition loader"
)
def
permute_param_layout_
(
param
:
BasevLLMParameter
,
input_dim
:
int
,
def
permute_param_layout_
(
param
:
BasevLLMParameter
,
input_dim
:
int
,
output_dim
:
int
,
**
kwargs
)
->
BasevLLMParameter
:
output_dim
:
int
,
**
kwargs
)
->
BasevLLMParameter
:
"""
"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment