Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a4c23314
Unverified
Commit
a4c23314
authored
Jul 08, 2025
by
Yan Ma
Committed by
GitHub
Jul 08, 2025
Browse files
[xpu]feat: support multi-lora on xpu (#20616)
Signed-off-by:
yan
<
yan.ma@intel.com
>
parent
b942c094
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
28 additions
and
4 deletions
+28
-4
vllm/lora/ops/triton_ops/lora_expand_op.py
vllm/lora/ops/triton_ops/lora_expand_op.py
+2
-0
vllm/lora/ops/triton_ops/lora_shrink_op.py
vllm/lora/ops/triton_ops/lora_shrink_op.py
+2
-0
vllm/lora/ops/triton_ops/utils.py
vllm/lora/ops/triton_ops/utils.py
+9
-3
vllm/model_executor/model_loader/tensorizer.py
vllm/model_executor/model_loader/tensorizer.py
+4
-1
vllm/platforms/xpu.py
vllm/platforms/xpu.py
+11
-0
No files found.
vllm/lora/ops/triton_ops/lora_expand_op.py
View file @
a4c23314
...
...
@@ -13,6 +13,7 @@ import triton.language as tl
from
vllm.lora.ops.triton_ops.kernel_utils
import
do_expand_kernel
from
vllm.lora.ops.triton_ops.utils
import
_get_lora_b_ptr
from
vllm.platforms
import
current_platform
from
vllm.utils
import
direct_register_custom_op
...
...
@@ -283,6 +284,7 @@ try:
op_func
=
_lora_expand
,
mutates_args
=
[
"output_tensor"
],
fake_impl
=
_lora_expand_fake
,
dispatch_key
=
current_platform
.
dispatch_key
,
)
lora_expand
=
torch
.
ops
.
vllm
.
lora_expand
...
...
vllm/lora/ops/triton_ops/lora_shrink_op.py
View file @
a4c23314
...
...
@@ -13,6 +13,7 @@ import triton.language as tl
from
vllm.lora.ops.triton_ops.kernel_utils
import
do_shrink_kernel
from
vllm.lora.ops.triton_ops.utils
import
_get_lora_a_ptr
from
vllm.platforms
import
current_platform
from
vllm.utils
import
direct_register_custom_op
...
...
@@ -237,6 +238,7 @@ try:
op_func
=
_lora_shrink
,
mutates_args
=
[
"output_tensor"
],
fake_impl
=
_lora_shrink_fake
,
dispatch_key
=
current_platform
.
dispatch_key
,
)
lora_shrink
=
torch
.
ops
.
vllm
.
lora_shrink
...
...
vllm/lora/ops/triton_ops/utils.py
View file @
a4c23314
...
...
@@ -35,7 +35,9 @@ def _get_lora_a_ptr(lora_a_weights: list[torch.Tensor], device: torch.device):
lora_strides_d1
.
append
(
lora_a_weight
.
stride
(
1
))
lora_strides_d2
.
append
(
lora_a_weight
.
stride
(
2
))
if
len
(
lora_a_weights
)
>
1
:
lora_ptr_tensor
=
torch
.
tensor
(
tensor_ptrs
,
device
=
device
)
lora_ptr_tensor
=
torch
.
tensor
(
tensor_ptrs
,
device
=
device
,
dtype
=
torch
.
uint64
)
else
:
lora_ptr_tensor
=
lora_a_weights
[
0
]
...
...
@@ -89,8 +91,12 @@ def _get_lora_b_ptr(lora_weights: list[torch.Tensor], offset_start: int,
if
len
(
lora_weights
)
>
1
:
# note these are device tensors
lora_ptr_tensor
=
torch
.
tensor
(
tensor_ptrs
,
device
=
device
)
slice_start_tensor
=
torch
.
tensor
(
slice_offset_lst
,
device
=
device
)
lora_ptr_tensor
=
torch
.
tensor
(
tensor_ptrs
,
device
=
device
,
dtype
=
torch
.
uint64
)
slice_start_tensor
=
torch
.
tensor
(
slice_offset_lst
,
device
=
device
,
dtype
=
torch
.
uint64
)
else
:
slice_start_tensor
=
slice_offset_lst
[
0
]
lora_ptr_tensor
=
lora_b_weight
[
0
]
...
...
vllm/model_executor/model_loader/tensorizer.py
View file @
a4c23314
...
...
@@ -27,6 +27,7 @@ from vllm.config import (ModelConfig, ParallelConfig, VllmConfig,
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
from
vllm.platforms
import
current_platform
from
vllm.utils
import
FlexibleArgumentParser
,
PlaceholderModule
if
TYPE_CHECKING
:
...
...
@@ -513,7 +514,9 @@ def deserialize_tensorizer_model(model: nn.Module,
**
tensorizer_args
.
stream_kwargs
)
as
stream
,
TensorDeserializer
(
stream
,
dtype
=
tensorizer_config
.
dtype
,
device
=
torch
.
device
(
"cuda"
,
torch
.
cuda
.
current_device
()),
device
=
f
'xpu:
{
torch
.
xpu
.
current_device
()
}
'
if
current_platform
.
is_xpu
()
else
f
'cuda:
{
torch
.
cuda
.
current_device
()
}
'
,
**
tensorizer_args
.
deserialization_kwargs
)
as
deserializer
:
deserializer
.
load_into_module
(
model
)
end
=
time
.
perf_counter
()
...
...
vllm/platforms/xpu.py
View file @
a4c23314
...
...
@@ -58,6 +58,10 @@ class XPUPlatform(Platform):
def
get_device_name
(
cls
,
device_id
:
int
=
0
)
->
str
:
return
torch
.
xpu
.
get_device_name
(
device_id
)
@
classmethod
def
get_punica_wrapper
(
cls
)
->
str
:
return
"vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"
@
classmethod
def
get_device_total_memory
(
cls
,
device_id
:
int
=
0
)
->
int
:
device_props
=
torch
.
xpu
.
get_device_properties
(
device_id
)
...
...
@@ -78,6 +82,13 @@ class XPUPlatform(Platform):
if
cache_config
and
cache_config
.
block_size
is
None
:
cache_config
.
block_size
=
64
# FIXME: Temporarily forcing eager mode
# remove after t.compile support stabilizes.
if
(
envs
.
VLLM_USE_V1
and
vllm_config
.
model_config
is
not
None
and
not
vllm_config
.
model_config
.
enforce_eager
):
from
vllm.config
import
CompilationLevel
vllm_config
.
compilation_config
.
level
=
CompilationLevel
.
NO_COMPILATION
# noqa: E501
# Instances created using VllmConfig() typically have model_config as
# None by default. The modification involves adding a check to prevent
# potential null exceptions check and update model config.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment