"recipes/vscode:/vscode.git/clone" did not exist on "cec19d4db601eb91aa228bb22e6daef340f169e4"
Unverified commit 898beca5, authored by Liangliang Ma and committed by GitHub
Browse files

[BugFix][XPU] Fix LoRA ops bgmv_expand size mismatch (#39989)


Signed-off-by: Ma, Liangliang <liangliang.ma@intel.com>
Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>
parent 629d45ea
...@@ -27,9 +27,42 @@ def bgmv_expand( ...@@ -27,9 +27,42 @@ def bgmv_expand(
lora_indices_tensor: torch.Tensor, lora_indices_tensor: torch.Tensor,
add_inputs: bool = True, add_inputs: bool = True,
) -> None: ) -> None:
torch.ops._xpu_C.bgmv_expand( weight_out_dim = lora_b_weights.size(-2)
output_tensor, inputs, lora_b_weights, lora_indices_tensor, add_inputs output_dim = output_tensor.size(1)
)
if weight_out_dim == output_dim:
torch.ops._xpu_C.bgmv_expand(
output_tensor,
inputs,
lora_b_weights,
lora_indices_tensor,
add_inputs,
)
elif weight_out_dim < output_dim:
# LoRA weight output dim can be smaller than the output tensor
# (e.g. vocab_size vs padded logits). Use expand_slice to write
# only the matching portion, mirroring torch_ops common_len logic.
torch.ops._xpu_C.bgmv_expand_slice(
output_tensor,
inputs,
lora_b_weights,
lora_indices_tensor,
0,
weight_out_dim,
add_inputs,
)
else:
# Weight output dim larger than output tensor: truncate weights.
lora_b_weights = lora_b_weights[..., :output_dim, :].contiguous()
torch.ops._xpu_C.bgmv_expand_slice(
output_tensor,
inputs,
lora_b_weights,
lora_indices_tensor,
0,
output_dim,
add_inputs,
)
def bgmv_expand_slice( def bgmv_expand_slice(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment