Unverified commit 862f2ef8, authored by Chaojun Zhang, committed by GitHub
Browse files

[XPU] Fix the bug of LoRA logits on the XPU platform (#24081)


Signed-off-by: chzhang <chaojun.zhang@intel.com>
parent 2fd1a40a
...@@ -1151,7 +1151,7 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA): ...@@ -1151,7 +1151,7 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
lora_logits = lora_logits.mT lora_logits = lora_logits.mT
indices_padded = self.punica_wrapper.sampler_indices_padded indices_padded = self.punica_wrapper.sampler_indices_padded
if current_platform.is_tpu(): if current_platform.is_tpu() or current_platform.is_xpu():
indices_padded = indices_padded[:logits.size(0)] indices_padded = indices_padded[:logits.size(0)]
lora_logits = (lora_logits.reshape( lora_logits = (lora_logits.reshape(
......
...@@ -225,6 +225,13 @@ class PunicaWrapperXPU(PunicaWrapperBase): ...@@ -225,6 +225,13 @@ class PunicaWrapperXPU(PunicaWrapperBase):
add_inputs=True, add_inputs=True,
**kwargs) **kwargs)
@property
def sampler_indices_padded(self) -> torch.Tensor:
    """Return the padded sampler indices.

    The value handed back is a full-range slice (``[:]``) of
    ``self._sampler_indices_padded`` rather than the attribute object
    itself.
    """
    padded_indices = self._sampler_indices_padded
    return padded_indices[:]
def add_lora_logits(self, def add_lora_logits(self,
y: torch.Tensor, y: torch.Tensor,
x: torch.Tensor, x: torch.Tensor,
...@@ -259,11 +266,11 @@ class PunicaWrapperXPU(PunicaWrapperBase): ...@@ -259,11 +266,11 @@ class PunicaWrapperXPU(PunicaWrapperBase):
buffer = torch.zeros((x.size(0), r), buffer = torch.zeros((x.size(0), r),
dtype=torch.float32, dtype=torch.float32,
device=x.device) device=x.device)
sampler_indices = torch.narrow(self._sampler_indices, 0, 0, x.size(0))
bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, scale) bgmv_shrink(x, lora_a_stacked, buffer, sampler_indices, scale)
bgmv_expand(buffer, bgmv_expand(buffer,
lora_b_stacked, lora_b_stacked,
y, y,
self.sampler_indices, sampler_indices,
add_inputs=True) add_inputs=True)
return y.view_as(y_org) return y.view_as(y_org)
...@@ -91,7 +91,7 @@ class XPUPlatform(Platform): ...@@ -91,7 +91,7 @@ class XPUPlatform(Platform):
cache_config.block_size = 64 cache_config.block_size = 64
# lazy import to avoid circular import # lazy import to avoid circular import
from vllm.config import CUDAGraphMode from vllm.config import CompilationLevel, CUDAGraphMode
compilation_config = vllm_config.compilation_config compilation_config = vllm_config.compilation_config
if compilation_config.cudagraph_mode is None or \ if compilation_config.cudagraph_mode is None or \
compilation_config.cudagraph_mode.max_cudagraph_mode() \ compilation_config.cudagraph_mode.max_cudagraph_mode() \
...@@ -100,6 +100,9 @@ class XPUPlatform(Platform): ...@@ -100,6 +100,9 @@ class XPUPlatform(Platform):
"cudagraphs. Fallback to cudagraph_mode=NONE") "cudagraphs. Fallback to cudagraph_mode=NONE")
compilation_config.cudagraph_mode = CUDAGraphMode.NONE compilation_config.cudagraph_mode = CUDAGraphMode.NONE
if vllm_config.lora_config is not None:
compilation_config.level = CompilationLevel.NO_COMPILATION
# check and update parallel config # check and update parallel config
parallel_config = vllm_config.parallel_config parallel_config = vllm_config.parallel_config
parallel_config.worker_cls = "vllm.v1.worker.xpu_worker.XPUWorker" parallel_config.worker_cls = "vllm.v1.worker.xpu_worker.XPUWorker"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment