"docs/vscode:/vscode.git/clone" did not exist on "6e10e3f88158f12b7a304d3c2f803d2bbdde0823"
Unverified Commit 1c4f0c49 authored by Edward Z. Yang's avatar Edward Z. Yang Committed by GitHub
Browse files

Only do meta registrations if we have the ops (#7500)


Signed-off-by: default avatarEdward Z. Yang <ezyang@meta.com>
parent 27b84916
import functools
import torch
import torch.library
......@@ -6,20 +8,22 @@ import torchvision.extension # noqa: F401
from torch._prims_common import check
_meta_lib = torch.library.Library("torchvision", "IMPL", "Meta")
vision = torch.ops.torchvision
@functools.lru_cache(None)
def get_meta_lib():
return torch.library.Library("torchvision", "IMPL", "Meta")
def register_meta(op):
def register_meta(op_name, overload_name="default"):
def wrapper(fn):
_meta_lib.impl(op, fn)
if torchvision.extension._has_ops():
get_meta_lib().impl(getattr(getattr(torch.ops.torchvision, op_name), overload_name), fn)
return fn
return wrapper
@register_meta(vision.roi_align.default)
@register_meta("roi_align")
def meta_roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
check(rois.size(1) == 5, lambda: "rois must have shape as Tensor[K, 5]")
check(
......@@ -34,7 +38,7 @@ def meta_roi_align(input, rois, spatial_scale, pooled_height, pooled_width, samp
return input.new_empty((num_rois, channels, pooled_height, pooled_width))
@register_meta(vision._roi_align_backward.default)
@register_meta("_roi_align_backward")
def meta_roi_align_backward(
grad, rois, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width, sampling_ratio, aligned
):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment