Unverified Commit 1c4f0c49 authored by Edward Z. Yang's avatar Edward Z. Yang Committed by GitHub
Browse files

Only do meta registrations if we have the ops (#7500)


Signed-off-by: default avatarEdward Z. Yang <ezyang@meta.com>
parent 27b84916
import functools
import torch import torch
import torch.library import torch.library
...@@ -6,20 +8,22 @@ import torchvision.extension # noqa: F401 ...@@ -6,20 +8,22 @@ import torchvision.extension # noqa: F401
from torch._prims_common import check from torch._prims_common import check
_meta_lib = torch.library.Library("torchvision", "IMPL", "Meta")
vision = torch.ops.torchvision @functools.lru_cache(None)
def get_meta_lib():
return torch.library.Library("torchvision", "IMPL", "Meta")
def register_meta(op): def register_meta(op_name, overload_name="default"):
def wrapper(fn): def wrapper(fn):
_meta_lib.impl(op, fn) if torchvision.extension._has_ops():
get_meta_lib().impl(getattr(getattr(torch.ops.torchvision, op_name), overload_name), fn)
return fn return fn
return wrapper return wrapper
@register_meta(vision.roi_align.default) @register_meta("roi_align")
def meta_roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned): def meta_roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
check(rois.size(1) == 5, lambda: "rois must have shape as Tensor[K, 5]") check(rois.size(1) == 5, lambda: "rois must have shape as Tensor[K, 5]")
check( check(
...@@ -34,7 +38,7 @@ def meta_roi_align(input, rois, spatial_scale, pooled_height, pooled_width, samp ...@@ -34,7 +38,7 @@ def meta_roi_align(input, rois, spatial_scale, pooled_height, pooled_width, samp
return input.new_empty((num_rois, channels, pooled_height, pooled_width)) return input.new_empty((num_rois, channels, pooled_height, pooled_width))
@register_meta(vision._roi_align_backward.default) @register_meta("_roi_align_backward")
def meta_roi_align_backward( def meta_roi_align_backward(
grad, rois, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width, sampling_ratio, aligned grad, rois, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width, sampling_ratio, aligned
): ):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment