Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
862f2ef8
Unverified
Commit
862f2ef8
authored
Sep 03, 2025
by
Chaojun Zhang
Committed by
GitHub
Sep 03, 2025
Browse files
[XPU] Fix the bug of LoRA logits on the XPU platform (#24081)
Signed-off-by: chzhang <chaojun.zhang@intel.com>
parent
2fd1a40a
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing 3 changed files with 15 additions and 5 deletions
+15
-5
vllm/lora/layers.py
vllm/lora/layers.py
+1
-1
vllm/lora/punica_wrapper/punica_xpu.py
vllm/lora/punica_wrapper/punica_xpu.py
+10
-3
vllm/platforms/xpu.py
vllm/platforms/xpu.py
+4
-1
No files found.
vllm/lora/layers.py
View file @
862f2ef8
...
...
@@ -1151,7 +1151,7 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
lora_logits
=
lora_logits
.
mT
indices_padded
=
self
.
punica_wrapper
.
sampler_indices_padded
if
current_platform
.
is_tpu
():
if
current_platform
.
is_tpu
()
or
current_platform
.
is_xpu
()
:
indices_padded
=
indices_padded
[:
logits
.
size
(
0
)]
lora_logits
=
(
lora_logits
.
reshape
(
...
...
vllm/lora/punica_wrapper/punica_xpu.py
View file @
862f2ef8
...
...
@@ -225,6 +225,13 @@ class PunicaWrapperXPU(PunicaWrapperBase):
add_inputs
=
True
,
**
kwargs
)
@
property
def
sampler_indices_padded
(
self
)
->
torch
.
Tensor
:
"""
This property provides access to padded sampler indices.
"""
return
self
.
_sampler_indices_padded
[:]
def
add_lora_logits
(
self
,
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
...
...
@@ -259,11 +266,11 @@ class PunicaWrapperXPU(PunicaWrapperBase):
buffer
=
torch
.
zeros
((
x
.
size
(
0
),
r
),
dtype
=
torch
.
float32
,
device
=
x
.
device
)
bgmv_shrink
(
x
,
lora_a_stacked
,
buffer
,
self
.
sampler_indices
,
scale
)
sampler_indices
=
torch
.
narrow
(
self
.
_sampler_indices
,
0
,
0
,
x
.
size
(
0
))
bgmv_shrink
(
x
,
lora_a_stacked
,
buffer
,
sampler_indices
,
scale
)
bgmv_expand
(
buffer
,
lora_b_stacked
,
y
,
self
.
sampler_indices
,
sampler_indices
,
add_inputs
=
True
)
return
y
.
view_as
(
y_org
)
vllm/platforms/xpu.py
View file @
862f2ef8
...
...
@@ -91,7 +91,7 @@ class XPUPlatform(Platform):
cache_config
.
block_size
=
64
# lazy import to avoid circular import
from
vllm.config
import
CUDAGraphMode
from
vllm.config
import
CompilationLevel
,
CUDAGraphMode
compilation_config
=
vllm_config
.
compilation_config
if
compilation_config
.
cudagraph_mode
is
None
or
\
compilation_config
.
cudagraph_mode
.
max_cudagraph_mode
()
\
...
...
@@ -100,6 +100,9 @@ class XPUPlatform(Platform):
"cudagraphs. Fallback to cudagraph_mode=NONE"
)
compilation_config
.
cudagraph_mode
=
CUDAGraphMode
.
NONE
if
vllm_config
.
lora_config
is
not
None
:
compilation_config
.
level
=
CompilationLevel
.
NO_COMPILATION
# check and update parallel config
parallel_config
=
vllm_config
.
parallel_config
parallel_config
.
worker_cls
=
"vllm.v1.worker.xpu_worker.XPUWorker"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment