Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0a0a1a19
Unverified
Commit
0a0a1a19
authored
Mar 16, 2026
by
Kyuyeun Kim
Committed by
GitHub
Mar 16, 2026
Browse files
Add ability to replace oot ops when using lora (#37181)
Signed-off-by:
Kyuyeun Kim
<
kyuyeunk@google.com
>
parent
6c1cfbad
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
16 additions
and
11 deletions
+16
-11
vllm/lora/layers/column_parallel_linear.py
vllm/lora/layers/column_parallel_linear.py
+4
-3
vllm/lora/layers/replicated_linear.py
vllm/lora/layers/replicated_linear.py
+2
-1
vllm/lora/layers/row_parallel_linear.py
vllm/lora/layers/row_parallel_linear.py
+2
-1
vllm/lora/layers/vocal_parallel_embedding.py
vllm/lora/layers/vocal_parallel_embedding.py
+2
-1
vllm/model_executor/custom_op.py
vllm/model_executor/custom_op.py
+3
-2
vllm/model_executor/layers/attention/mm_encoder_attention.py
vllm/model_executor/layers/attention/mm_encoder_attention.py
+3
-3
No files found.
vllm/lora/layers/column_parallel_linear.py
View file @
0a0a1a19
...
...
@@ -9,6 +9,7 @@ from transformers import PretrainedConfig
from
vllm.config.lora
import
LoRAConfig
from
vllm.distributed
import
tensor_model_parallel_all_gather
from
vllm.distributed.utils
import
divide
from
vllm.model_executor.custom_op
import
maybe_get_oot_by_class
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
MergedColumnParallelLinear
,
...
...
@@ -155,9 +156,9 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
packed_modules_list
:
list
,
model_config
:
PretrainedConfig
|
None
=
None
,
)
->
bool
:
if
type
(
source_layer
)
is
ColumnParallelLinear
:
if
type
(
source_layer
)
is
maybe_get_oot_by_class
(
ColumnParallelLinear
)
:
return
True
if
type
(
source_layer
)
is
MergedColumnParallelLinear
:
if
type
(
source_layer
)
is
maybe_get_oot_by_class
(
MergedColumnParallelLinear
)
:
if
len
(
packed_modules_list
)
!=
1
:
return
False
# Exclude layers with 3+ output sizes - those are handled by
...
...
@@ -606,7 +607,7 @@ class MergedColumnParallelLinearVariableSliceWithLoRA(
)
->
bool
:
# Support MergedColumnParallelLinear with 3 or more slices
# (2 slices are handled by MergedColumnParallelLinearWithLoRA)
if
type
(
source_layer
)
is
not
MergedColumnParallelLinear
:
if
type
(
source_layer
)
is
not
maybe_get_oot_by_class
(
MergedColumnParallelLinear
)
:
return
False
# If packed_modules_list has 3+ items, use this class
...
...
vllm/lora/layers/replicated_linear.py
View file @
0a0a1a19
...
...
@@ -7,6 +7,7 @@ import torch.nn as nn
from
transformers
import
PretrainedConfig
from
vllm.config.lora
import
LoRAConfig
from
vllm.model_executor.custom_op
import
maybe_get_oot_by_class
from
vllm.model_executor.layers.linear
import
ReplicatedLinear
from
.base_linear
import
BaseLinearLayerWithLoRA
...
...
@@ -55,7 +56,7 @@ class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
packed_modules_list
:
list
,
model_config
:
PretrainedConfig
|
None
=
None
,
)
->
bool
:
return
type
(
source_layer
)
is
ReplicatedLinear
return
type
(
source_layer
)
is
maybe_get_oot_by_class
(
ReplicatedLinear
)
def
slice_lora_a
(
self
,
lora_a
:
torch
.
Tensor
|
list
[
torch
.
Tensor
|
None
]
...
...
vllm/lora/layers/row_parallel_linear.py
View file @
0a0a1a19
...
...
@@ -11,6 +11,7 @@ from vllm.distributed import (
split_tensor_along_last_dim
,
tensor_model_parallel_all_reduce
,
)
from
vllm.model_executor.custom_op
import
maybe_get_oot_by_class
from
vllm.model_executor.layers.linear
import
RowParallelLinear
from
vllm.platforms
import
current_platform
...
...
@@ -89,7 +90,7 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
packed_modules_list
:
list
,
model_config
:
PretrainedConfig
|
None
=
None
,
)
->
bool
:
return
type
(
source_layer
)
is
RowParallelLinear
return
type
(
source_layer
)
is
maybe_get_oot_by_class
(
RowParallelLinear
)
# The following layer is based on the tensor parallelism strategy given in
...
...
vllm/lora/layers/vocal_parallel_embedding.py
View file @
0a0a1a19
...
...
@@ -7,6 +7,7 @@ import torch.nn.functional as F
from
transformers
import
PretrainedConfig
from
vllm.config.lora
import
LoRAConfig
from
vllm.model_executor.custom_op
import
maybe_get_oot_by_class
from
vllm.model_executor.layers.vocab_parallel_embedding
import
VocabParallelEmbedding
from
vllm.platforms
import
current_platform
...
...
@@ -132,7 +133,7 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
packed_modules_list
:
list
,
model_config
:
PretrainedConfig
|
None
=
None
,
)
->
bool
:
return
type
(
source_layer
)
is
VocabParallelEmbedding
return
type
(
source_layer
)
is
maybe_get_oot_by_class
(
VocabParallelEmbedding
)
@
property
def
weight
(
self
):
...
...
vllm/model_executor/custom_op.py
View file @
0a0a1a19
...
...
@@ -22,10 +22,11 @@ op_registry: dict[str, type["CustomOp"] | type["PluggableLayer"]] = {}
op_registry_oot
:
dict
[
str
,
type
[
"CustomOp"
]
|
type
[
"PluggableLayer"
]]
=
{}
def
get_oot_class_by_name
(
class_name
:
str
)
->
type
|
None
:
def
maybe_get_oot_by_class
(
class_type
:
type
)
->
type
:
class_name
=
class_type
.
__name__
if
class_name
in
op_registry_oot
:
return
op_registry_oot
[
class_name
]
return
None
return
class_type
class
PluggableLayer
(
nn
.
Module
):
...
...
vllm/model_executor/layers/attention/mm_encoder_attention.py
View file @
0a0a1a19
...
...
@@ -6,7 +6,7 @@ import numpy as np
import
torch
from
vllm.logger
import
init_logger
from
vllm.model_executor.custom_op
import
CustomOp
,
get_oot_class_by_name
from
vllm.model_executor.custom_op
import
CustomOp
,
maybe_get_oot_by_class
from
vllm.model_executor.models.vision
import
get_vit_attn_backend
from
vllm.utils.math_utils
import
round_up
from
vllm.v1.attention.backends.fa_utils
import
get_flash_attn_version
...
...
@@ -125,7 +125,7 @@ class MMEncoderAttention(CustomOp):
cu_seqlens
:
np
.
ndarray
,
device
:
torch
.
device
,
)
->
torch
.
Tensor
|
None
:
if
(
oot_class
:
=
get_oot_class_by_name
(
cls
.
__name__
))
is
not
None
:
if
(
oot_class
:
=
maybe_get_oot_by_class
(
cls
))
is
not
cls
:
return
oot_class
.
maybe_compute_seq_lens
(
attn_backend
,
cu_seqlens
,
device
)
# type: ignore[attr-defined]
if
attn_backend
!=
AttentionBackendEnum
.
FLASHINFER
:
...
...
@@ -149,7 +149,7 @@ class MMEncoderAttention(CustomOp):
tp_size
:
int
,
device
:
torch
.
device
,
)
->
torch
.
Tensor
:
if
(
oot_class
:
=
get_oot_class_by_name
(
cls
.
__name__
))
is
not
None
:
if
(
oot_class
:
=
maybe_get_oot_by_class
(
cls
))
is
not
cls
:
return
oot_class
.
maybe_recompute_cu_seqlens
(
# type: ignore[attr-defined]
attn_backend
,
cu_seqlens
,
hidden_size
,
tp_size
,
device
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment