Unverified Commit ef725fea authored by wangxiyuan, committed by GitHub
Browse files

[platform] support pytorch custom op pluggable (#11328)


Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
parent d907be7d
...@@ -57,6 +57,11 @@ class CustomOp(nn.Module): ...@@ -57,6 +57,11 @@ class CustomOp(nn.Module):
# PyTorch-native implementation. # PyTorch-native implementation.
return self.forward_native(*args, **kwargs) return self.forward_native(*args, **kwargs)
def forward_oot(self, *args, **kwargs):
# Forward path for out-of-tree (OOT) platforms.
# By default, we assume that OOT ops are compatible with the
# PyTorch-native implementation, so every positional and keyword
# argument is passed through unchanged to forward_native and its
# result is returned as-is.
return self.forward_native(*args, **kwargs)
def dispatch_forward(self): def dispatch_forward(self):
# NOTE(woosuk): Here we assume that vLLM was built for only one # NOTE(woosuk): Here we assume that vLLM was built for only one
# specific backend. Currently, we do not support dynamic dispatching. # specific backend. Currently, we do not support dynamic dispatching.
...@@ -81,6 +86,8 @@ class CustomOp(nn.Module): ...@@ -81,6 +86,8 @@ class CustomOp(nn.Module):
return self.forward_tpu return self.forward_tpu
elif current_platform.is_xpu(): elif current_platform.is_xpu():
return self.forward_xpu return self.forward_xpu
elif current_platform.is_out_of_tree():
return self.forward_oot
else: else:
return self.forward_cuda return self.forward_cuda
......
...@@ -45,6 +45,7 @@ class PlatformEnum(enum.Enum): ...@@ -45,6 +45,7 @@ class PlatformEnum(enum.Enum):
CPU = enum.auto() CPU = enum.auto()
NEURON = enum.auto() NEURON = enum.auto()
OPENVINO = enum.auto() OPENVINO = enum.auto()
OOT = enum.auto()
UNSPECIFIED = enum.auto() UNSPECIFIED = enum.auto()
...@@ -107,6 +108,9 @@ class Platform: ...@@ -107,6 +108,9 @@ class Platform:
def is_openvino(self) -> bool: def is_openvino(self) -> bool:
return self._enum == PlatformEnum.OPENVINO return self._enum == PlatformEnum.OPENVINO
def is_out_of_tree(self) -> bool:
# True only when this platform's enum tag (self._enum) equals
# PlatformEnum.OOT, the member that marks an out-of-tree platform.
return self._enum == PlatformEnum.OOT
def is_cuda_alike(self) -> bool: def is_cuda_alike(self) -> bool:
"""Stateless version of :func:`torch.cuda.is_available`.""" """Stateless version of :func:`torch.cuda.is_available`."""
return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM) return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment