crop.py 4.04 KB
Newer Older
luopl's avatar
luopl committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import random
from typing import Any, Dict, List

import torch
from torchvision.ops import box_iou

from transforms.v2 import Transform
from transforms.v2 import functional as F
from transforms.v2.utils import query_bounding_box, query_spatial_size
from util import datapoints


class RandomSizeCrop(Transform):
    """Crop a randomly sized, randomly positioned rectangle from the inputs.

    The crop height and width are sampled independently from
    ``[min_size, max_size]``, with the upper bound capped at the corresponding
    input dimension. If an input dimension is smaller than ``min_size`` the
    lower bound is capped at that dimension as well, so the crop spans the
    whole dimension instead of failing.

    Args:
        min_size: Smallest crop side length to sample.
        max_size: Largest crop side length to sample.
    """

    def __init__(self, min_size: int, max_size: int):
        super().__init__()
        self.min_size = min_size
        self.max_size = max_size

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        """Sample the crop rectangle.

        Returns:
            A dict with ``left``, ``top``, ``height`` and ``width`` keys that is
            forwarded unchanged to ``F.crop``.

        Raises:
            ValueError: If ``min_size`` exceeds ``max_size`` (empty sampling range).
        """
        orig_h, orig_w = query_spatial_size(flat_inputs)

        # Cap the lower bound at the image dimension: without this,
        # random.randint() raises on an empty range whenever the input is
        # smaller than min_size. Only the image size is used for the cap, so
        # a misconfigured min_size > max_size still raises as before.
        crop_h = random.randint(min(self.min_size, orig_h), min(orig_h, self.max_size))
        crop_w = random.randint(min(self.min_size, orig_w), min(orig_w, self.max_size))

        # Pick the top-left corner so the crop lies fully inside the input
        # (torch.randint's upper bound is exclusive, hence the +1).
        top = torch.randint(0, orig_h - crop_h + 1, size=(1,)).item()
        left = torch.randint(0, orig_w - crop_w + 1, size=(1,)).item()

        return {"left": left, "top": top, "height": crop_h, "width": crop_w}

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        """Apply the sampled crop rectangle to a single input."""
        return F.crop(inpt, **params)


class BoxCenteredRandomSizeCrop(Transform):
    """Random-size crop that tries to keep at least one box center inside it.

    Up to ``trials`` candidate crops are sampled. For each candidate a minimum
    IoU threshold is drawn from ``sampler_options``; the candidate is accepted
    as soon as one of the boxes whose center lies inside the crop reaches that
    IoU with the crop rectangle. If no candidate is accepted, the candidate
    with the highest such IoU is used. If no candidate contained any box
    center at all, the inputs are passed through unchanged.

    Args:
        min_size: Smallest crop side length to sample.
        max_size: Largest crop side length to sample.
        sampler_options: Candidate minimum-IoU thresholds; defaults to
            ``[0.1, 0.3, 0.5, 0.7, 0.9, 1.0]``.
        trials: Maximum number of candidate crops to sample.
    """

    def __init__(self, min_size: int, max_size: int, sampler_options=None, trials: int = 40):
        super().__init__()
        self.min_size = min_size
        self.max_size = max_size
        self.trials = trials
        if sampler_options is None:
            sampler_options = [0.1, 0.3, 0.5, 0.7, 0.9, 1.0]
        self.options = sampler_options

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        """Sample a crop rectangle together with the mask of boxes it keeps.

        Returns:
            A dict with ``top``, ``left``, ``height``, ``width`` and
            ``is_within_crop_area`` (boolean mask over the boxes whose center
            lies strictly inside the crop), or an empty dict when no sampled
            crop contained a box center.
        """
        orig_h, orig_w = query_spatial_size(flat_inputs)
        bboxes = query_bounding_box(flat_inputs)

        # The boxes do not change between trials, so convert them and compute
        # their centers once instead of on every iteration.
        xyxy_bboxes = F.convert_format_bounding_box(
            bboxes.as_subclass(torch.Tensor), bboxes.format, datapoints.BoundingBoxFormat.XYXY
        )
        cx = 0.5 * (xyxy_bboxes[..., 0] + xyxy_bboxes[..., 2])
        cy = 0.5 * (xyxy_bboxes[..., 1] + xyxy_bboxes[..., 3])

        best_iou = 0.0
        # An empty dict is the "no usable crop" signal that _transform checks.
        best_region: Dict[str, Any] = {}
        for _ in range(self.trials):
            idx = int(torch.randint(low=0, high=len(self.options), size=(1,)))
            min_jaccard_overlap = self.options[idx]

            # Cap the lower bound at the image dimension so inputs smaller
            # than min_size do not make random.randint() raise on an empty range.
            crop_h = random.randint(min(self.min_size, orig_h), min(orig_h, self.max_size))
            crop_w = random.randint(min(self.min_size, orig_w), min(orig_w, self.max_size))

            # get crop region
            top = torch.randint(0, orig_h - crop_h + 1, size=(1,)).item()
            left = torch.randint(0, orig_w - crop_w + 1, size=(1,)).item()
            right = left + crop_w
            bottom = top + crop_h

            # check for any valid boxes with centers within the crop area
            is_within_crop_area = (left < cx) & (cx < right) & (top < cy) & (cy < bottom)
            if not is_within_crop_area.any():
                continue

            kept_bboxes = xyxy_bboxes[is_within_crop_area]
            ious = box_iou(
                kept_bboxes,
                kept_bboxes.new_tensor([[left, top, right, bottom]]),
            )
            max_iou = ious.max().item()
            cur_region = dict(
                top=top,
                left=left,
                height=crop_h,
                width=crop_w,
                is_within_crop_area=is_within_crop_area,
            )

            # Remember the best candidate seen so far; best_iou must be
            # updated too, otherwise the fallback is merely the last
            # overlapping candidate rather than the best one.
            if max_iou > best_iou:
                best_iou = max_iou
                best_region = cur_region

            if max_iou < min_jaccard_overlap:
                continue

            return cur_region

        return best_region

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        """Crop a single input and zero out boxes whose center fell outside the crop."""
        # Empty params means no suitable crop was found: leave the input as is.
        if not params:
            return inpt

        output = F.crop(
            inpt,
            top=params["top"],
            left=params["left"],
            height=params["height"],
            width=params["width"],
        )

        if isinstance(output, datapoints.BoundingBox):
            # We "mark" the invalid boxes as degenerate, and they can be
            # removed by a later call to SanitizeBoundingBox()
            output[~params["is_within_crop_area"]] = 0
        return output