Unverified Commit eac3dc7b authored by Kai Zhang's avatar Kai Zhang Committed by GitHub
Browse files

Simplified usage log API (#5095)



* log API v3

* make torchscript happy

* make torchscript happy

* add missing logs to constructor

* log ops C++ API as well

* fix type hint

* check function with isinstance
Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>
parent 0b02d420
...@@ -310,7 +310,7 @@ class RegNet(nn.Module): ...@@ -310,7 +310,7 @@ class RegNet(nn.Module):
activation: Optional[Callable[..., nn.Module]] = None, activation: Optional[Callable[..., nn.Module]] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
_log_api_usage_once("models", self.__class__.__name__) _log_api_usage_once(self)
if stem_type is None: if stem_type is None:
stem_type = SimpleStemIN stem_type = SimpleStemIN
......
...@@ -174,7 +174,7 @@ class ResNet(nn.Module): ...@@ -174,7 +174,7 @@ class ResNet(nn.Module):
norm_layer: Optional[Callable[..., nn.Module]] = None, norm_layer: Optional[Callable[..., nn.Module]] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
_log_api_usage_once("models", self.__class__.__name__) _log_api_usage_once(self)
if norm_layer is None: if norm_layer is None:
norm_layer = nn.BatchNorm2d norm_layer = nn.BatchNorm2d
self._norm_layer = norm_layer self._norm_layer = norm_layer
......
...@@ -13,7 +13,7 @@ class _SimpleSegmentationModel(nn.Module): ...@@ -13,7 +13,7 @@ class _SimpleSegmentationModel(nn.Module):
def __init__(self, backbone: nn.Module, classifier: nn.Module, aux_classifier: Optional[nn.Module] = None) -> None: def __init__(self, backbone: nn.Module, classifier: nn.Module, aux_classifier: Optional[nn.Module] = None) -> None:
super().__init__() super().__init__()
_log_api_usage_once("models", self.__class__.__name__) _log_api_usage_once(self)
self.backbone = backbone self.backbone = backbone
self.classifier = classifier self.classifier = classifier
self.aux_classifier = aux_classifier self.aux_classifier = aux_classifier
......
...@@ -38,7 +38,7 @@ class LRASPP(nn.Module): ...@@ -38,7 +38,7 @@ class LRASPP(nn.Module):
self, backbone: nn.Module, low_channels: int, high_channels: int, num_classes: int, inter_channels: int = 128 self, backbone: nn.Module, low_channels: int, high_channels: int, num_classes: int, inter_channels: int = 128
) -> None: ) -> None:
super().__init__() super().__init__()
_log_api_usage_once("models", self.__class__.__name__) _log_api_usage_once(self)
self.backbone = backbone self.backbone = backbone
self.classifier = LRASPPHead(low_channels, high_channels, num_classes, inter_channels) self.classifier = LRASPPHead(low_channels, high_channels, num_classes, inter_channels)
......
...@@ -100,7 +100,7 @@ class ShuffleNetV2(nn.Module): ...@@ -100,7 +100,7 @@ class ShuffleNetV2(nn.Module):
inverted_residual: Callable[..., nn.Module] = InvertedResidual, inverted_residual: Callable[..., nn.Module] = InvertedResidual,
) -> None: ) -> None:
super().__init__() super().__init__()
_log_api_usage_once("models", self.__class__.__name__) _log_api_usage_once(self)
if len(stages_repeats) != 3: if len(stages_repeats) != 3:
raise ValueError("expected stages_repeats as list of 3 positive ints") raise ValueError("expected stages_repeats as list of 3 positive ints")
......
...@@ -36,7 +36,7 @@ class Fire(nn.Module): ...@@ -36,7 +36,7 @@ class Fire(nn.Module):
class SqueezeNet(nn.Module): class SqueezeNet(nn.Module):
def __init__(self, version: str = "1_0", num_classes: int = 1000, dropout: float = 0.5) -> None: def __init__(self, version: str = "1_0", num_classes: int = 1000, dropout: float = 0.5) -> None:
super().__init__() super().__init__()
_log_api_usage_once("models", self.__class__.__name__) _log_api_usage_once(self)
self.num_classes = num_classes self.num_classes = num_classes
if version == "1_0": if version == "1_0":
self.features = nn.Sequential( self.features = nn.Sequential(
......
...@@ -37,7 +37,7 @@ class VGG(nn.Module): ...@@ -37,7 +37,7 @@ class VGG(nn.Module):
self, features: nn.Module, num_classes: int = 1000, init_weights: bool = True, dropout: float = 0.5 self, features: nn.Module, num_classes: int = 1000, init_weights: bool = True, dropout: float = 0.5
) -> None: ) -> None:
super().__init__() super().__init__()
_log_api_usage_once("models", self.__class__.__name__) _log_api_usage_once(self)
self.features = features self.features = features
self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
self.classifier = nn.Sequential( self.classifier = nn.Sequential(
......
...@@ -209,7 +209,7 @@ class VideoResNet(nn.Module): ...@@ -209,7 +209,7 @@ class VideoResNet(nn.Module):
zero_init_residual (bool, optional): Zero init bottleneck residual BN. Defaults to False. zero_init_residual (bool, optional): Zero init bottleneck residual BN. Defaults to False.
""" """
super().__init__() super().__init__()
_log_api_usage_once("models", self.__class__.__name__) _log_api_usage_once(self)
self.inplanes = 64 self.inplanes = 64
self.stem = stem() self.stem = stem()
......
...@@ -34,7 +34,8 @@ def nms(boxes: Tensor, scores: Tensor, iou_threshold: float) -> Tensor: ...@@ -34,7 +34,8 @@ def nms(boxes: Tensor, scores: Tensor, iou_threshold: float) -> Tensor:
Tensor: int64 tensor with the indices of the elements that have been kept Tensor: int64 tensor with the indices of the elements that have been kept
by NMS, sorted in decreasing order of scores by NMS, sorted in decreasing order of scores
""" """
_log_api_usage_once("ops", "nms") if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(nms)
_assert_has_ops() _assert_has_ops()
return torch.ops.torchvision.nms(boxes, scores, iou_threshold) return torch.ops.torchvision.nms(boxes, scores, iou_threshold)
...@@ -63,7 +64,8 @@ def batched_nms( ...@@ -63,7 +64,8 @@ def batched_nms(
Tensor: int64 tensor with the indices of the elements that have been kept by NMS, sorted Tensor: int64 tensor with the indices of the elements that have been kept by NMS, sorted
in decreasing order of scores in decreasing order of scores
""" """
_log_api_usage_once("ops", "batched_nms") if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(batched_nms)
# Benchmarks that drove the following thresholds are at # Benchmarks that drove the following thresholds are at
# https://github.com/pytorch/vision/issues/1311#issuecomment-781329339 # https://github.com/pytorch/vision/issues/1311#issuecomment-781329339
if boxes.numel() > (4000 if boxes.device.type == "cpu" else 20000) and not torchvision._is_tracing(): if boxes.numel() > (4000 if boxes.device.type == "cpu" else 20000) and not torchvision._is_tracing():
...@@ -122,7 +124,8 @@ def remove_small_boxes(boxes: Tensor, min_size: float) -> Tensor: ...@@ -122,7 +124,8 @@ def remove_small_boxes(boxes: Tensor, min_size: float) -> Tensor:
Tensor[K]: indices of the boxes that have both sides Tensor[K]: indices of the boxes that have both sides
larger than min_size larger than min_size
""" """
_log_api_usage_once("ops", "remove_small_boxes") if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(remove_small_boxes)
ws, hs = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1] ws, hs = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1]
keep = (ws >= min_size) & (hs >= min_size) keep = (ws >= min_size) & (hs >= min_size)
keep = torch.where(keep)[0] keep = torch.where(keep)[0]
...@@ -141,7 +144,8 @@ def clip_boxes_to_image(boxes: Tensor, size: Tuple[int, int]) -> Tensor: ...@@ -141,7 +144,8 @@ def clip_boxes_to_image(boxes: Tensor, size: Tuple[int, int]) -> Tensor:
Returns: Returns:
Tensor[N, 4]: clipped boxes Tensor[N, 4]: clipped boxes
""" """
_log_api_usage_once("ops", "clip_boxes_to_image") if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(clip_boxes_to_image)
dim = boxes.dim() dim = boxes.dim()
boxes_x = boxes[..., 0::2] boxes_x = boxes[..., 0::2]
boxes_y = boxes[..., 1::2] boxes_y = boxes[..., 1::2]
...@@ -181,8 +185,8 @@ def box_convert(boxes: Tensor, in_fmt: str, out_fmt: str) -> Tensor: ...@@ -181,8 +185,8 @@ def box_convert(boxes: Tensor, in_fmt: str, out_fmt: str) -> Tensor:
Returns: Returns:
Tensor[N, 4]: Boxes into converted format. Tensor[N, 4]: Boxes into converted format.
""" """
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once("ops", "box_convert") _log_api_usage_once(box_convert)
allowed_fmts = ("xyxy", "xywh", "cxcywh") allowed_fmts = ("xyxy", "xywh", "cxcywh")
if in_fmt not in allowed_fmts or out_fmt not in allowed_fmts: if in_fmt not in allowed_fmts or out_fmt not in allowed_fmts:
raise ValueError("Unsupported Bounding Box Conversions for given in_fmt and out_fmt") raise ValueError("Unsupported Bounding Box Conversions for given in_fmt and out_fmt")
...@@ -232,7 +236,8 @@ def box_area(boxes: Tensor) -> Tensor: ...@@ -232,7 +236,8 @@ def box_area(boxes: Tensor) -> Tensor:
Returns: Returns:
Tensor[N]: the area for each box Tensor[N]: the area for each box
""" """
_log_api_usage_once("ops", "box_area") if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(box_area)
boxes = _upcast(boxes) boxes = _upcast(boxes)
return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
...@@ -268,7 +273,8 @@ def box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor: ...@@ -268,7 +273,8 @@ def box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
Returns: Returns:
Tensor[N, M]: the NxM matrix containing the pairwise IoU values for every element in boxes1 and boxes2 Tensor[N, M]: the NxM matrix containing the pairwise IoU values for every element in boxes1 and boxes2
""" """
_log_api_usage_once("ops", "box_iou") if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(box_iou)
inter, union = _box_inter_union(boxes1, boxes2) inter, union = _box_inter_union(boxes1, boxes2)
iou = inter / union iou = inter / union
return iou return iou
...@@ -290,8 +296,8 @@ def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor: ...@@ -290,8 +296,8 @@ def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
Tensor[N, M]: the NxM matrix containing the pairwise generalized IoU values Tensor[N, M]: the NxM matrix containing the pairwise generalized IoU values
for every element in boxes1 and boxes2 for every element in boxes1 and boxes2
""" """
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once("ops", "generalized_box_iou") _log_api_usage_once(generalized_box_iou)
# degenerate boxes gives inf / nan results # degenerate boxes gives inf / nan results
# so do an early check # so do an early check
assert (boxes1[:, 2:] >= boxes1[:, :2]).all() assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
...@@ -323,7 +329,8 @@ def masks_to_boxes(masks: torch.Tensor) -> torch.Tensor: ...@@ -323,7 +329,8 @@ def masks_to_boxes(masks: torch.Tensor) -> torch.Tensor:
Returns: Returns:
Tensor[N, 4]: bounding boxes Tensor[N, 4]: bounding boxes
""" """
_log_api_usage_once("ops", "masks_to_boxes") if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(masks_to_boxes)
if masks.numel() == 0: if masks.numel() == 0:
return torch.zeros((0, 4), device=masks.device, dtype=torch.float) return torch.zeros((0, 4), device=masks.device, dtype=torch.float)
......
...@@ -60,8 +60,8 @@ def deform_conv2d( ...@@ -60,8 +60,8 @@ def deform_conv2d(
>>> # returns >>> # returns
>>> torch.Size([4, 5, 8, 8]) >>> torch.Size([4, 5, 8, 8])
""" """
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once("ops", "deform_conv2d") _log_api_usage_once(deform_conv2d)
_assert_has_ops() _assert_has_ops()
out_channels = weight.shape[0] out_channels = weight.shape[0]
...@@ -124,6 +124,7 @@ class DeformConv2d(nn.Module): ...@@ -124,6 +124,7 @@ class DeformConv2d(nn.Module):
bias: bool = True, bias: bool = True,
): ):
super().__init__() super().__init__()
_log_api_usage_once(self)
if in_channels % groups != 0: if in_channels % groups != 0:
raise ValueError("in_channels must be divisible by groups") raise ValueError("in_channels must be divisible by groups")
......
...@@ -77,7 +77,7 @@ class FeaturePyramidNetwork(nn.Module): ...@@ -77,7 +77,7 @@ class FeaturePyramidNetwork(nn.Module):
extra_blocks: Optional[ExtraFPNBlock] = None, extra_blocks: Optional[ExtraFPNBlock] = None,
): ):
super().__init__() super().__init__()
_log_api_usage_once("ops", self.__class__.__name__) _log_api_usage_once(self)
self.inner_blocks = nn.ModuleList() self.inner_blocks = nn.ModuleList()
self.layer_blocks = nn.ModuleList() self.layer_blocks = nn.ModuleList()
for in_channels in in_channels_list: for in_channels in in_channels_list:
......
...@@ -32,7 +32,8 @@ def sigmoid_focal_loss( ...@@ -32,7 +32,8 @@ def sigmoid_focal_loss(
Returns: Returns:
Loss tensor with the reduction option applied. Loss tensor with the reduction option applied.
""" """
_log_api_usage_once("ops", "sigmoid_focal_loss") if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(sigmoid_focal_loss)
p = torch.sigmoid(inputs) p = torch.sigmoid(inputs)
ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
p_t = p * targets + (1 - p) * (1 - targets) p_t = p * targets + (1 - p) * (1 - targets)
......
...@@ -61,7 +61,7 @@ class FrozenBatchNorm2d(torch.nn.Module): ...@@ -61,7 +61,7 @@ class FrozenBatchNorm2d(torch.nn.Module):
warnings.warn("`n` argument is deprecated and has been renamed `num_features`", DeprecationWarning) warnings.warn("`n` argument is deprecated and has been renamed `num_features`", DeprecationWarning)
num_features = n num_features = n
super().__init__() super().__init__()
_log_api_usage_once("ops", self.__class__.__name__) _log_api_usage_once(self)
self.eps = eps self.eps = eps
self.register_buffer("weight", torch.ones(num_features)) self.register_buffer("weight", torch.ones(num_features))
self.register_buffer("bias", torch.zeros(num_features)) self.register_buffer("bias", torch.zeros(num_features))
...@@ -155,7 +155,7 @@ class ConvNormActivation(torch.nn.Sequential): ...@@ -155,7 +155,7 @@ class ConvNormActivation(torch.nn.Sequential):
if activation_layer is not None: if activation_layer is not None:
layers.append(activation_layer(inplace=inplace)) layers.append(activation_layer(inplace=inplace))
super().__init__(*layers) super().__init__(*layers)
_log_api_usage_once("ops", self.__class__.__name__) _log_api_usage_once(self)
self.out_channels = out_channels self.out_channels = out_channels
...@@ -179,7 +179,7 @@ class SqueezeExcitation(torch.nn.Module): ...@@ -179,7 +179,7 @@ class SqueezeExcitation(torch.nn.Module):
scale_activation: Callable[..., torch.nn.Module] = torch.nn.Sigmoid, scale_activation: Callable[..., torch.nn.Module] = torch.nn.Sigmoid,
) -> None: ) -> None:
super().__init__() super().__init__()
_log_api_usage_once("ops", self.__class__.__name__) _log_api_usage_once(self)
self.avgpool = torch.nn.AdaptiveAvgPool2d(1) self.avgpool = torch.nn.AdaptiveAvgPool2d(1)
self.fc1 = torch.nn.Conv2d(input_channels, squeeze_channels, 1) self.fc1 = torch.nn.Conv2d(input_channels, squeeze_channels, 1)
self.fc2 = torch.nn.Conv2d(squeeze_channels, input_channels, 1) self.fc2 = torch.nn.Conv2d(squeeze_channels, input_channels, 1)
......
...@@ -276,7 +276,7 @@ class MultiScaleRoIAlign(nn.Module): ...@@ -276,7 +276,7 @@ class MultiScaleRoIAlign(nn.Module):
canonical_level: int = 4, canonical_level: int = 4,
): ):
super().__init__() super().__init__()
_log_api_usage_once("ops", self.__class__.__name__) _log_api_usage_once(self)
if isinstance(output_size, int): if isinstance(output_size, int):
output_size = (output_size, output_size) output_size = (output_size, output_size)
self.featmap_names = featmap_names self.featmap_names = featmap_names
......
...@@ -43,7 +43,8 @@ def ps_roi_align( ...@@ -43,7 +43,8 @@ def ps_roi_align(
Returns: Returns:
Tensor[K, C / (output_size[0] * output_size[1]), output_size[0], output_size[1]]: The pooled RoIs Tensor[K, C / (output_size[0] * output_size[1]), output_size[0], output_size[1]]: The pooled RoIs
""" """
_log_api_usage_once("ops", "ps_roi_align") if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(ps_roi_align)
_assert_has_ops() _assert_has_ops()
check_roi_boxes_shape(boxes) check_roi_boxes_shape(boxes)
rois = boxes rois = boxes
...@@ -68,6 +69,7 @@ class PSRoIAlign(nn.Module): ...@@ -68,6 +69,7 @@ class PSRoIAlign(nn.Module):
sampling_ratio: int, sampling_ratio: int,
): ):
super().__init__() super().__init__()
_log_api_usage_once(self)
self.output_size = output_size self.output_size = output_size
self.spatial_scale = spatial_scale self.spatial_scale = spatial_scale
self.sampling_ratio = sampling_ratio self.sampling_ratio = sampling_ratio
......
...@@ -37,7 +37,8 @@ def ps_roi_pool( ...@@ -37,7 +37,8 @@ def ps_roi_pool(
Returns: Returns:
Tensor[K, C / (output_size[0] * output_size[1]), output_size[0], output_size[1]]: The pooled RoIs. Tensor[K, C / (output_size[0] * output_size[1]), output_size[0], output_size[1]]: The pooled RoIs.
""" """
_log_api_usage_once("ops", "ps_roi_pool") if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(ps_roi_pool)
_assert_has_ops() _assert_has_ops()
check_roi_boxes_shape(boxes) check_roi_boxes_shape(boxes)
rois = boxes rois = boxes
...@@ -55,6 +56,7 @@ class PSRoIPool(nn.Module): ...@@ -55,6 +56,7 @@ class PSRoIPool(nn.Module):
def __init__(self, output_size: int, spatial_scale: float): def __init__(self, output_size: int, spatial_scale: float):
super().__init__() super().__init__()
_log_api_usage_once(self)
self.output_size = output_size self.output_size = output_size
self.spatial_scale = spatial_scale self.spatial_scale = spatial_scale
......
...@@ -50,7 +50,8 @@ def roi_align( ...@@ -50,7 +50,8 @@ def roi_align(
Returns: Returns:
Tensor[K, C, output_size[0], output_size[1]]: The pooled RoIs. Tensor[K, C, output_size[0], output_size[1]]: The pooled RoIs.
""" """
_log_api_usage_once("ops", "roi_align") if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(roi_align)
_assert_has_ops() _assert_has_ops()
check_roi_boxes_shape(boxes) check_roi_boxes_shape(boxes)
rois = boxes rois = boxes
...@@ -75,6 +76,7 @@ class RoIAlign(nn.Module): ...@@ -75,6 +76,7 @@ class RoIAlign(nn.Module):
aligned: bool = False, aligned: bool = False,
): ):
super().__init__() super().__init__()
_log_api_usage_once(self)
self.output_size = output_size self.output_size = output_size
self.spatial_scale = spatial_scale self.spatial_scale = spatial_scale
self.sampling_ratio = sampling_ratio self.sampling_ratio = sampling_ratio
......
...@@ -39,7 +39,8 @@ def roi_pool( ...@@ -39,7 +39,8 @@ def roi_pool(
Returns: Returns:
Tensor[K, C, output_size[0], output_size[1]]: The pooled RoIs. Tensor[K, C, output_size[0], output_size[1]]: The pooled RoIs.
""" """
_log_api_usage_once("ops", "roi_pool") if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(roi_pool)
_assert_has_ops() _assert_has_ops()
check_roi_boxes_shape(boxes) check_roi_boxes_shape(boxes)
rois = boxes rois = boxes
...@@ -57,6 +58,7 @@ class RoIPool(nn.Module): ...@@ -57,6 +58,7 @@ class RoIPool(nn.Module):
def __init__(self, output_size: BroadcastingList2[int], spatial_scale: float): def __init__(self, output_size: BroadcastingList2[int], spatial_scale: float):
super().__init__() super().__init__()
_log_api_usage_once(self)
self.output_size = output_size self.output_size = output_size
self.spatial_scale = spatial_scale self.spatial_scale = spatial_scale
......
...@@ -23,7 +23,8 @@ def stochastic_depth(input: Tensor, p: float, mode: str, training: bool = True) ...@@ -23,7 +23,8 @@ def stochastic_depth(input: Tensor, p: float, mode: str, training: bool = True)
Returns: Returns:
Tensor[N, ...]: The randomly zeroed tensor. Tensor[N, ...]: The randomly zeroed tensor.
""" """
_log_api_usage_once("ops", "stochastic_depth") if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(stochastic_depth)
if p < 0.0 or p > 1.0: if p < 0.0 or p > 1.0:
raise ValueError(f"drop probability has to be between 0 and 1, but got {p}") raise ValueError(f"drop probability has to be between 0 and 1, but got {p}")
if mode not in ["batch", "row"]: if mode not in ["batch", "row"]:
...@@ -53,6 +54,7 @@ class StochasticDepth(nn.Module): ...@@ -53,6 +54,7 @@ class StochasticDepth(nn.Module):
def __init__(self, p: float, mode: str) -> None: def __init__(self, p: float, mode: str) -> None:
super().__init__() super().__init__()
_log_api_usage_once(self)
self.p = p self.p = p
self.mode = mode self.mode = mode
......
...@@ -140,7 +140,7 @@ class VisionTransformer(nn.Module): ...@@ -140,7 +140,7 @@ class VisionTransformer(nn.Module):
norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
): ):
super().__init__() super().__init__()
_log_api_usage_once("models", self.__class__.__name__) _log_api_usage_once(self)
torch._assert(image_size % patch_size == 0, "Input shape indivisible by patch size!") torch._assert(image_size % patch_size == 0, "Input shape indivisible by patch size!")
self.image_size = image_size self.image_size = image_size
self.patch_size = patch_size self.patch_size = patch_size
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment