Unverified Commit f36c5de4 authored by kurtamohler's avatar kurtamohler Committed by GitHub
Browse files

Avoid `_prims_common.check` in favor of `torch._check` (#7625)


Co-authored-by: default avatarPhilip Meier <github.pmeier@posteo.de>
Co-authored-by: default avatarNicolas Hug <nh.nicolas.hug@gmail.com>
parent 9d0a93ee
...@@ -6,8 +6,6 @@ import torch.library ...@@ -6,8 +6,6 @@ import torch.library
# Ensure that torch.ops.torchvision is visible # Ensure that torch.ops.torchvision is visible
import torchvision.extension # noqa: F401 import torchvision.extension # noqa: F401
from torch._prims_common import check
@functools.lru_cache(None) @functools.lru_cache(None)
def get_meta_lib(): def get_meta_lib():
...@@ -25,8 +23,8 @@ def register_meta(op_name, overload_name="default"): ...@@ -25,8 +23,8 @@ def register_meta(op_name, overload_name="default"):
@register_meta("roi_align") @register_meta("roi_align")
def meta_roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned): def meta_roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
check(rois.size(1) == 5, lambda: "rois must have shape as Tensor[K, 5]") torch._check(rois.size(1) == 5, lambda: "rois must have shape as Tensor[K, 5]")
check( torch._check(
input.dtype == rois.dtype, input.dtype == rois.dtype,
lambda: ( lambda: (
"Expected tensor for input to have the same type as tensor for rois; " "Expected tensor for input to have the same type as tensor for rois; "
...@@ -42,7 +40,7 @@ def meta_roi_align(input, rois, spatial_scale, pooled_height, pooled_width, samp ...@@ -42,7 +40,7 @@ def meta_roi_align(input, rois, spatial_scale, pooled_height, pooled_width, samp
def meta_roi_align_backward( def meta_roi_align_backward(
grad, rois, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width, sampling_ratio, aligned grad, rois, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width, sampling_ratio, aligned
): ):
check( torch._check(
grad.dtype == rois.dtype, grad.dtype == rois.dtype,
lambda: ( lambda: (
"Expected tensor for grad to have the same type as tensor for rois; " "Expected tensor for grad to have the same type as tensor for rois; "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment