Unverified Commit 3cdfe1f3 authored by bnellnm's avatar bnellnm Committed by GitHub
Browse files

[Bugfix] Make torch registration of punica ops optional (#7970)

parent fdd9daaf
......@@ -160,6 +160,9 @@ def _bgmv_expand(
return
bgmv_expand = torch.library.custom_op("lora::bgmv_expand",
_bgmv_expand,
mutates_args=["output_tensor"])
try:
bgmv_expand = torch.library.custom_op("lora::bgmv_expand",
_bgmv_expand,
mutates_args=["output_tensor"])
except AttributeError:
bgmv_expand = _bgmv_expand
......@@ -173,6 +173,9 @@ def _bgmv_expand_slice(
return
bgmv_expand_slice = torch.library.custom_op("lora::bgmv_expand_slice",
_bgmv_expand_slice,
mutates_args=["output_tensor"])
try:
bgmv_expand_slice = torch.library.custom_op("lora::bgmv_expand_slice",
_bgmv_expand_slice,
mutates_args=["output_tensor"])
except AttributeError:
bgmv_expand_slice = _bgmv_expand_slice
......@@ -142,6 +142,9 @@ def _bgmv_shrink(
return
bgmv_shrink = torch.library.custom_op("lora::bgmv_shrink",
_bgmv_shrink,
mutates_args=["output_tensor"])
try:
bgmv_shrink = torch.library.custom_op("lora::bgmv_shrink",
_bgmv_shrink,
mutates_args=["output_tensor"])
except AttributeError:
bgmv_shrink = _bgmv_shrink
......@@ -192,6 +192,9 @@ def _sgmv_expand(
return
sgmv_expand = torch.library.custom_op("lora::sgmv_expand",
_sgmv_expand,
mutates_args=["output_tensor"])
try:
sgmv_expand = torch.library.custom_op("lora::sgmv_expand",
_sgmv_expand,
mutates_args=["output_tensor"])
except AttributeError:
sgmv_expand = _sgmv_expand
......@@ -205,6 +205,9 @@ def _sgmv_expand_slice(
return
sgmv_expand_slice = torch.library.custom_op("lora::sgmv_expand_slice",
_sgmv_expand_slice,
mutates_args=["output_tensor"])
try:
sgmv_expand_slice = torch.library.custom_op("lora::sgmv_expand_slice",
_sgmv_expand_slice,
mutates_args=["output_tensor"])
except AttributeError:
sgmv_expand_slice = _sgmv_expand_slice
......@@ -189,6 +189,9 @@ def _sgmv_shrink(
return
sgmv_shrink = torch.library.custom_op("lora::sgmv_shrink",
_sgmv_shrink,
mutates_args=["output_tensor"])
try:
sgmv_shrink = torch.library.custom_op("lora::sgmv_shrink",
_sgmv_shrink,
mutates_args=["output_tensor"])
except AttributeError:
sgmv_shrink = _sgmv_shrink
......@@ -10,10 +10,8 @@ from typing import TYPE_CHECKING, Callable, List, Optional, Tuple, Union
import torch
from vllm.triton_utils import HAS_TRITON
from vllm.utils import is_xpu
# FIXME: xpu path doesn't support torch.library.custom_op
if HAS_TRITON and not is_xpu():
if HAS_TRITON:
from vllm.lora.ops.bgmv_expand import bgmv_expand
from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice
from vllm.lora.ops.bgmv_shrink import bgmv_shrink
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment