import functools

import torch
import torch._custom_ops
import torch.library

# Ensure that torch.ops.torchvision is visible
import torchvision.extension  # noqa: F401


@functools.lru_cache(maxsize=None)
def get_meta_lib():
    """Return the shared ``torch.library.Library`` used to register Meta kernels.

    The library is constructed for the ``torchvision`` namespace with kind
    ``IMPL`` and dispatch key ``Meta``. The call is memoized, so the library
    is built on first use and the same object is handed back afterwards.
    """
    namespace, kind, dispatch_key = "torchvision", "IMPL", "Meta"
    return torch.library.Library(namespace, kind, dispatch_key)


def register_meta(op_name, overload_name="default"):
    """Build a decorator that registers a function as a torchvision Meta kernel.

    Args:
        op_name: Attribute name of the operator under ``torch.ops.torchvision``.
        overload_name: Attribute name of the overload on that operator;
            ``"default"`` when not given.

    Returns:
        A decorator. When ``torchvision.extension._has_ops()`` is true it
        registers the decorated function for the selected overload through
        ``get_meta_lib().impl``; otherwise it registers nothing. In both
        cases the decorated function is returned unchanged.
    """

    def wrapper(fn):
        # Registration is conditional; the function itself is always handed back
        # so the decorated name stays usable either way.
        if torchvision.extension._has_ops():
            get_meta_lib().impl(getattr(getattr(torch.ops.torchvision, op_name), overload_name), fn)
        return fn

    return wrapper


@register_meta("roi_align")
def meta_roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
    """Meta kernel for ``roi_align``: validate the arguments and allocate the output.

    Args:
        input: Tensor whose ``size()`` must unpack into exactly four dimensions;
            dimension 1 supplies the output channel count.
        rois: Tensor whose dimension 1 must have size 5 and whose dtype must
            equal ``input.dtype``; dimension 0 supplies the number of ROIs.
        pooled_height: Size of output dimension 2.
        pooled_width: Size of output dimension 3.
        spatial_scale, sampling_ratio, aligned: Not read by this function.

    Returns:
        An uninitialized tensor created with ``input.new_empty`` of shape
        ``(rois.size(0), input.size(1), pooled_height, pooled_width)``.

    Raises:
        RuntimeError: From ``torch._check`` when ``rois.size(1) != 5`` or the
            dtypes of ``input`` and ``rois`` differ.
    """
    torch._check(rois.size(1) == 5, lambda: "rois must have shape as Tensor[K, 5]")
    torch._check(
        input.dtype == rois.dtype,
        lambda: (
            "Expected tensor for input to have the same type as tensor for rois; "
            f"but type {input.dtype} does not equal {rois.dtype}"
        ),
    )
    num_rois = rois.size(0)
    # The 4-way unpack keeps the requirement that ``input`` is 4-D; only the
    # channel count is needed for the output shape.
    _, channels, _, _ = input.size()
    return input.new_empty((num_rois, channels, pooled_height, pooled_width))


@register_meta("_roi_align_backward")
def meta_roi_align_backward(
    grad, rois, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width, sampling_ratio, aligned
):
    """Meta kernel for ``_roi_align_backward``: check dtypes and allocate the result.

    Args:
        grad: Tensor whose dtype must equal ``rois.dtype``; the result is
            created from it with ``new_empty``.
        rois: Tensor used only for the dtype comparison here.
        batch_size, channels, height, width: Sizes of the four dimensions of
            the returned tensor, in that order.
        spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned:
            Not read by this function.

    Returns:
        An uninitialized tensor of shape ``(batch_size, channels, height, width)``
        created with ``grad.new_empty``.

    Raises:
        RuntimeError: From ``torch._check`` when ``grad.dtype != rois.dtype``.
    """
    torch._check(
        grad.dtype == rois.dtype,
        lambda: (
            "Expected tensor for grad to have the same type as tensor for rois; "
            f"but type {grad.dtype} does not equal {rois.dtype}"
        ),
    )
    return grad.new_empty((batch_size, channels, height, width))


@torch._custom_ops.impl_abstract("torchvision::nms")
def meta_nms(dets, scores, iou_threshold):
    torch._check(dets.dim() == 2, lambda: f"boxes should be a 2d tensor, got {dets.dim()}D")
    torch._check(dets.size(1) == 4, lambda: f"boxes should have 4 elements in dimension 1, got {dets.size(1)}")
    torch._check(scores.dim() == 1, lambda: f"scores should be a 1d tensor, got {scores.dim()}")
    torch._check(
        dets.size(0) == scores.size(0),
        lambda: f"boxes and scores should have same number of elements in dimension 0, got {dets.size(0)} and {scores.size(0)}",
    )
    ctx = torch._custom_ops.get_ctx()
    num_to_keep = ctx.create_unbacked_symint()
    return dets.new_empty(num_to_keep, dtype=torch.long)