Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
de94289a
Unverified
Commit
de94289a
authored
Sep 24, 2025
by
Kyle Sayers
Committed by
GitHub
Sep 23, 2025
Browse files
[Core] Support weight_loader_v2 for `UnquantizedLinearMethod` (#23036)
Signed-off-by:
Kyle Sayers
<
kylesayrs@gmail.com
>
parent
19836092
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
70 additions
and
12 deletions
+70
-12
vllm/compilation/decorators.py
vllm/compilation/decorators.py
+37
-6
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+11
-5
vllm/model_executor/parameter.py
vllm/model_executor/parameter.py
+22
-1
No files found.
vllm/compilation/decorators.py
View file @
de94289a
...
...
@@ -8,6 +8,7 @@ from unittest.mock import patch
import
torch
import
torch.nn
as
nn
from
packaging
import
version
from
torch._dynamo.symbolic_convert
import
InliningInstructionTranslator
from
vllm.compilation.counter
import
compilation_counter
...
...
@@ -300,13 +301,13 @@ def _support_torch_compile(
logger
.
debug
(
"enable_cpp_symbolic_shape_guards config not available"
)
with
patch
.
object
(
InliningInstructionTranslator
,
'inline_call'
,
with
patch
.
object
(
InliningInstructionTranslator
,
"inline_call"
,
patched_inline_call
),
torch
.
_dynamo
.
config
.
patch
(
**
dynamo_config_patches
),
maybe_use_cudagraph_partition_wrapper
(
self
.
vllm_config
):
self
.
vllm_config
)
,
_torch27_patch_tensor_subclasses
()
:
output
=
self
.
compiled_callable
(
*
args
,
**
kwargs
)
return
output
# usually, capturing the model once is enough, and then we can
...
...
@@ -367,3 +368,33 @@ def maybe_use_cudagraph_partition_wrapper(vllm_config: VllmConfig):
if
(
compilation_config
.
cudagraph_mode
!=
CUDAGraphMode
.
NONE
and
compilation_config
.
use_inductor_graph_partition
):
torch
.
_inductor
.
utils
.
set_customized_partition_wrappers
(
None
)
@
contextlib
.
contextmanager
def
_torch27_patch_tensor_subclasses
():
"""
Add support for using tensor subclasses (ie `BasevLLMParameter`, ect) when
using torch 2.7.0. This enables using weight_loader_v2 and the use of
`BasevLLMParameters` without having to replace them with regular tensors
before `torch.compile`-time.
"""
from
vllm.model_executor.parameter
import
(
BasevLLMParameter
,
ModelWeightParameter
,
RowvLLMParameter
,
_ColumnvLLMParameter
)
def
return_false
(
*
args
,
**
kwargs
):
return
False
if
version
.
parse
(
"2.7"
)
<=
version
.
parse
(
torch
.
__version__
)
<
version
.
parse
(
"2.8"
):
yield
return
with
(
torch
.
_dynamo
.
config
.
patch
(
"traceable_tensor_subclasses"
,
[
BasevLLMParameter
,
ModelWeightParameter
,
_ColumnvLLMParameter
,
RowvLLMParameter
]),
patch
(
"torch._dynamo.variables.torch.can_dispatch_torch_function"
,
return_false
)):
yield
vllm/model_executor/layers/linear.py
View file @
de94289a
...
...
@@ -22,6 +22,7 @@ from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
# yapf: disable
from
vllm.model_executor.parameter
import
(
BasevLLMParameter
,
BlockQuantScaleParameter
,
ModelWeightParameter
,
PackedColumnParameter
,
PackedvLLMParameter
,
PerTensorScaleParameter
,
...
...
@@ -34,6 +35,7 @@ from vllm.utils import GiB_bytes
logger
=
init_logger
(
__name__
)
WEIGHT_LOADER_V2_SUPPORTED
=
[
"UnquantizedLinearMethod"
,
"CompressedTensorsLinearMethod"
,
"CompressedTensorsLinearTransformMethod"
,
"BitBLASLinearMethod"
,
...
...
@@ -196,10 +198,14 @@ class UnquantizedLinearMethod(LinearMethodBase):
# The amount of memory allocated for the weights is
# sum(output_partition_sizes) * input_size_per_partition.
try
:
weight
=
Parameter
(
torch
.
empty
(
sum
(
output_partition_sizes
),
weight_loader
=
extra_weight_attrs
.
pop
(
"weight_loader"
)
weight
=
ModelWeightParameter
(
data
=
torch
.
empty
(
sum
(
output_partition_sizes
),
input_size_per_partition
,
dtype
=
params_dtype
),
requires_grad
=
False
)
input_dim
=
1
,
output_dim
=
0
,
weight_loader
=
weight_loader
)
except
torch
.
cuda
.
OutOfMemoryError
as
e
:
logger
.
error
(
"Failed to create unquantized linear weights: %s"
,
e
)
if
torch
.
cuda
.
is_available
():
...
...
@@ -212,7 +218,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
"Failed to create unquantized linear weights. "
"This may be caused by insufficient memory to allocate "
"the weight."
)
from
e
set_weight_attrs
(
weight
,
{
"input_dim"
:
1
,
"output_dim"
:
0
})
layer
.
register_parameter
(
"weight"
,
weight
)
set_weight_attrs
(
weight
,
extra_weight_attrs
)
...
...
vllm/model_executor/parameter.py
View file @
de94289a
...
...
@@ -61,9 +61,24 @@ class BasevLLMParameter(Parameter):
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
@
property
def
weight_loader
(
self
):
def
weight_loader
(
self
)
->
Callable
:
# NOTE(@ksayers) some models such as mamba_mixer2 override the
# weight loader to support custom loading. In the future, model-specific
# weight loading should be implemented via Model.load_weights. In the
# meantime, support deleting and overriding `weight_loader`` attribute
if
self
.
_weight_loader
is
None
:
raise
AttributeError
(
f
"
{
self
.
__class__
.
__name__
}
weight_loader "
"attribute has been deleted"
)
return
self
.
_weight_loader
@
weight_loader
.
setter
def
weight_loader
(
self
,
value
:
Callable
):
self
.
_weight_loader
=
value
@
weight_loader
.
deleter
def
weight_loader
(
self
):
self
.
_weight_loader
=
None
# type: ignore[assignment]
def
_is_1d_and_scalar
(
self
,
loaded_weight
:
torch
.
Tensor
):
cond1
=
self
.
data
.
ndim
==
1
and
self
.
data
.
numel
()
==
1
cond2
=
loaded_weight
.
ndim
==
0
and
loaded_weight
.
numel
()
==
1
...
...
@@ -97,6 +112,12 @@ class BasevLLMParameter(Parameter):
assert
shard_id
in
qkv_idxs
return
qkv_idxs
[
shard_id
]
@
classmethod
def
__torch_function__
(
cls
,
func
,
types
,
args
=
(),
kwargs
=
None
):
if
kwargs
is
None
:
kwargs
=
{}
return
super
().
__torch_function__
(
func
,
types
,
args
,
kwargs
)
class
_ColumnvLLMParameter
(
BasevLLMParameter
):
"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment