Commit 0a458091 authored by Hang Zhang, committed by Facebook GitHub Bot

Try DeiT, PiT with DETR in d2go

Summary:
Add new backbone
Experimental results: https://fburl.com/7fyecmrc

Reviewed By: bichenwu09

Differential Revision: D26877909

fbshipit-source-id: ba3f97a1e4d84bec22d6a345f1fca06c741010cc
parent b4d9aad9
# code adapted from https://www.internalfb.com/intern/diffusion/FBS/browse/master/fbcode/mobile-vision/experimental/deit/models.py
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
import math
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from detectron2.modeling import Backbone, BACKBONE_REGISTRY
from detectron2.utils.file_io import PathManager
from aml.multimodal_video.utils.einops.lib import rearrange
from timm.models.vision_transformer import VisionTransformer, PatchEmbed
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.registry import register_model
from timm.models.layers import trunc_normal_
def monkey_patch_forward(self, x):
    # bypass timm's fixed input-size assertion in PatchEmbed so the backbone
    # can take variable input resolutions
    x = self.proj(x).flatten(2).transpose(1, 2)
    return x

PatchEmbed.forward = monkey_patch_forward
class DistilledVisionTransformer(VisionTransformer, Backbone):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        num_patches = self.patch_embed.num_patches
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim))
        self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()

        trunc_normal_(self.dist_token, std=.02)
        trunc_normal_(self.pos_embed, std=.02)
        self.head_dist.apply(self._init_weights)
        self.norm = None
    def _get_pos_embed(self, H, W):
        embed_size = self.pos_embed.shape[-1]
        # get rid of the extra (cls, dist) tokens
        pos_tokens = self.pos_embed[:, 2:, :]
        npatchs = pos_tokens.shape[1]
        H0 = W0 = int(math.sqrt(npatchs))
        if H0 == H and W0 == W:
            return self.pos_embed
        # reshape to 2D
        pos_tokens = pos_tokens.transpose(1, 2).reshape(-1, embed_size, H0, W0)
        # interp
        pos_tokens = F.interpolate(
            pos_tokens, size=(H, W), mode="bilinear", align_corners=False,
        )
        # flatten and reshape back
        pos_tokens = pos_tokens.reshape(-1, embed_size, H * W).transpose(1, 2)
        pos_embed = torch.cat((self.pos_embed[:, :2, :], pos_tokens), dim=1)
        return pos_embed
    def forward_features(self, x):
        # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
        # with slight modifications to add the dist_token
        patch_size = self.patch_embed.patch_size[0]
        H, W = x.shape[-2:]
        H, W = H // patch_size, W // patch_size
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        dist_token = self.dist_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, dist_token, x), dim=1)

        # pick the spatial embed and interpolate it to the current feature size
        pos_embed = self._get_pos_embed(H, W)
        x = x + pos_embed
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        # x = self.norm(x)
        spatial = rearrange(x[:, 2:], 'b (h w) c -> b c h w', h=H, w=W)
        return x[:, 0], x[:, 1], spatial

    def forward(self, x):
        x, x_dist, x0 = self.forward_features(x)
        return x0
        # x = self.head(x)
        # x_dist = self.head_dist(x_dist)
        # if self.training:
        #     return x, x_dist
        # else:
        #     # during inference, return the average of both classifier predictions
        #     return (x + x_dist) / 2
def _cfg(input_size=224, url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, input_size, input_size), 'pool_size': None,
        'crop_pct': .9, 'interpolation': 'bilinear',
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'patch_embed.proj', 'classifier': 'head',
        **kwargs
    }
def deit_scalable_distilled(model_config, pretrained=False, **kwargs):
    assert not pretrained
    model = DistilledVisionTransformer(
        img_size=model_config["I"],
        patch_size=model_config["p"],
        embed_dim=model_config["h"] * model_config["e"],
        depth=model_config["d"],
        num_heads=model_config["h"],
        mlp_ratio=model_config["r"],
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
    model.default_cfg = _cfg(input_size=model_config["I"])
    print("model arch config: {}".format(model_config))
    print("model train config: {}".format(model.default_cfg))
    return model
def add_deit_backbone_config(cfg):
    cfg.MODEL.DEIT = type(cfg)()
    cfg.MODEL.DEIT.MODEL_CONFIG = None
    cfg.MODEL.DEIT.WEIGHTS = None
@BACKBONE_REGISTRY.register()
def deit_d2go_model_wrapper(cfg, _):
    assert cfg.MODEL.DEIT.MODEL_CONFIG is not None
    with PathManager.open(cfg.MODEL.DEIT.MODEL_CONFIG) as f:
        model_config = json.load(f)
    model = deit_scalable_distilled(
        model_config,
        num_classes=0,  # set num_classes=0 to avoid building cls head
        drop_rate=0,
        drop_path_rate=0.1,
    )
    # load weights
    if cfg.MODEL.DEIT.WEIGHTS is not None:
        with PathManager.open(cfg.MODEL.DEIT.WEIGHTS, "rb") as f:
            state_dict = torch.load(f, map_location="cpu")["model"]
        rm_keys = [k for k in state_dict if "head" in k]
        rm_keys = rm_keys + ["norm.weight", "norm.bias"]
        print(rm_keys)
        for k in rm_keys:
            del state_dict[k]
        model.load_state_dict(state_dict)
        print(f"loaded weights from {cfg.MODEL.DEIT.WEIGHTS}")
    return model
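
For orientation, here is a minimal usage sketch of the DeiT backbone above (not part of the diff). The DeiT-Tiny-like numbers in model_config are only illustrative; in d2go the same dictionary would come from the JSON file referenced by cfg.MODEL.DEIT.MODEL_CONFIG and the model would be built through the registered deit_d2go_model_wrapper.

# Minimal sketch, illustrative values only (DeiT-Tiny-like: embed_dim = h * e = 192).
if __name__ == "__main__":
    model_config = {"I": 224, "p": 16, "h": 3, "e": 64, "d": 12, "r": 4}
    model = deit_scalable_distilled(model_config, num_classes=0, drop_rate=0, drop_path_rate=0.1)
    model.eval()
    x = torch.rand(1, 3, 320, 256)  # input size need not match "I"; the pos embed is interpolated
    y = model(x)
    print(y.shape)  # expected (1, 192, 20, 16): a single stride-16 feature map for DETR
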
# https://www.internalfb.com/intern/diffusion/FBS/browse/master/fbcode/mobile-vision/experimental/deit/pit_models.py
# PiT
# Copyright 2021-present NAVER Corp.
# Apache License v2.0
import json
import torch
from aml.multimodal_video.utils.einops.lib import rearrange
from torch import nn
import torch.nn.functional as F
import math
from detectron2.modeling import Backbone, BACKBONE_REGISTRY
from detectron2.utils.file_io import PathManager
from functools import partial
from timm.models.layers import trunc_normal_
from timm.models.vision_transformer import Block as transformer_block
from timm.models.registry import register_model
class Transformer(nn.Module):
    def __init__(self, base_dim, depth, heads, mlp_ratio,
                 drop_rate=.0, attn_drop_rate=.0, drop_path_prob=None):
        super(Transformer, self).__init__()
        self.layers = nn.ModuleList([])
        embed_dim = base_dim * heads

        if drop_path_prob is None:
            drop_path_prob = [0.0 for _ in range(depth)]

        self.blocks = nn.ModuleList([
            transformer_block(
                dim=embed_dim,
                num_heads=heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=True,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=drop_path_prob[i],
                norm_layer=partial(nn.LayerNorm, eps=1e-6)
            )
            for i in range(depth)])

    def forward(self, x, cls_tokens):
        h, w = x.shape[2:4]
        x = rearrange(x, 'b c h w -> b (h w) c')

        token_length = cls_tokens.shape[1]
        x = torch.cat((cls_tokens, x), dim=1)
        for blk in self.blocks:
            x = blk(x)

        cls_tokens = x[:, :token_length]
        x = x[:, token_length:]
        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
        return x, cls_tokens
class conv_head_pooling(nn.Module):
    def __init__(self, in_feature, out_feature, stride, conv_type,
                 padding_mode='zeros', dilation=1):
        super(conv_head_pooling, self).__init__()
        if conv_type == "depthwise":
            _groups = in_feature
        else:
            _groups = 1
        print("_groups in conv_head_pooling: ", _groups)
        self.conv = nn.Conv2d(in_feature, out_feature, kernel_size=3,
                              padding=dilation, dilation=dilation, stride=stride,
                              padding_mode=padding_mode, groups=_groups)
        self.fc = nn.Linear(in_feature, out_feature)

    def forward(self, x, cls_token):
        x = self.conv(x)
        cls_token = self.fc(cls_token)
        return x, cls_token


class conv_embedding(nn.Module):
    def __init__(self, in_channels, out_channels, patch_size,
                 stride, padding):
        super(conv_embedding, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=patch_size,
                              stride=stride, padding=padding, bias=True)

    def forward(self, x):
        x = self.conv(x)
        return x
class PoolingTransformer(Backbone):
    def __init__(self, image_size, patch_size, stride, base_dims, depth, heads,
                 mlp_ratio, conv_type="depthwise", num_classes=1000, in_chans=3,
                 attn_drop_rate=.0, drop_rate=.0, drop_path_rate=.0,
                 dilated=False):
        super(PoolingTransformer, self).__init__()

        total_block = sum(depth)
        padding = 0
        block_idx = 0
        self.padding = padding
        self.stride = stride

        width = math.floor(
            (image_size + 2 * padding - patch_size) / stride + 1)

        self.conv_type = conv_type
        self.base_dims = base_dims
        self.heads = heads
        self.num_classes = num_classes
        self.patch_size = patch_size

        self.pos_embed = nn.Parameter(
            torch.randn(1, base_dims[0] * heads[0], width, width),
            requires_grad=True
        )
        self.patch_embed = conv_embedding(in_chans, base_dims[0] * heads[0],
                                          patch_size, stride, padding)
        self.cls_token = nn.Parameter(
            torch.randn(1, 1, base_dims[0] * heads[0]),
            requires_grad=True
        )
        self.pos_drop = nn.Dropout(p=drop_rate)

        self.transformers = nn.ModuleList([])
        self.pools = nn.ModuleList([])

        for stage in range(len(depth)):
            drop_path_prob = [drop_path_rate * i / total_block
                              for i in range(block_idx, block_idx + depth[stage])]
            block_idx += depth[stage]

            self.transformers.append(
                Transformer(base_dims[stage], depth[stage], heads[stage],
                            mlp_ratio,
                            drop_rate, attn_drop_rate, drop_path_prob)
            )
            if stage < len(heads) - 1:
                if stage == len(heads) - 2 and dilated:
                    pool_dilation = 2
                    pool_stride = 1
                else:
                    pool_dilation = 1
                    pool_stride = 2
                self.pools.append(
                    conv_head_pooling(base_dims[stage] * heads[stage],
                                      base_dims[stage + 1] * heads[stage + 1],
                                      stride=pool_stride, dilation=pool_dilation,
                                      conv_type=self.conv_type
                                      )
                )

        # self.norm = nn.LayerNorm(base_dims[-1] * heads[-1], eps=1e-6)
        self.embed_dim = base_dims[-1] * heads[-1]

        # Classifier head
        if num_classes > 0:
            self.head = nn.Linear(base_dims[-1] * heads[-1], num_classes)
        else:
            self.head = nn.Identity()

        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)
    def _init_weights(self, m):
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token'}

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        self.num_classes = num_classes
        if num_classes > 0:
            self.head = nn.Linear(self.embed_dim, num_classes)
        else:
            self.head = nn.Identity()

    def _get_pos_embed(self, H, W):
        H0, W0 = self.pos_embed.shape[-2:]
        if H0 == H and W0 == W:
            return self.pos_embed
        # interp
        pos_embed = F.interpolate(
            self.pos_embed, size=(H, W), mode="bilinear", align_corners=False,
        )
        return pos_embed
    def forward_features(self, x):
        H, W = x.shape[-2:]
        x = self.patch_embed(x)
        # feature map size after patch embedding
        H = math.floor(
            (H + 2 * self.padding - self.patch_size) / self.stride + 1)
        W = math.floor(
            (W + 2 * self.padding - self.patch_size) / self.stride + 1)
        pos_embed = self._get_pos_embed(H, W)
        x = self.pos_drop(x + pos_embed)
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)

        for stage in range(len(self.pools)):
            x, cls_tokens = self.transformers[stage](x, cls_tokens)
            x, cls_tokens = self.pools[stage](x, cls_tokens)
        x, cls_tokens = self.transformers[-1](x, cls_tokens)

        # cls_tokens = self.norm(cls_tokens)  # no gradient for layer norm, which causes failure
        return cls_tokens, x

    def forward(self, x):
        cls_token, _ = self.forward_features(x)
        cls_token = self.head(cls_token[:, 0])
        return cls_token
class DistilledPoolingTransformer(PoolingTransformer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.cls_token = nn.Parameter(
            torch.randn(1, 2, self.base_dims[0] * self.heads[0]),
            requires_grad=True)

        if self.num_classes > 0:
            self.head_dist = nn.Linear(self.base_dims[-1] * self.heads[-1],
                                       self.num_classes)
        else:
            self.head_dist = nn.Identity()

        trunc_normal_(self.cls_token, std=.02)
        self.head_dist.apply(self._init_weights)

    def forward(self, x):
        cls_token, x = self.forward_features(x)
        return x
        # x_cls = self.head(cls_token[:, 0])
        # x_dist = self.head_dist(cls_token[:, 1])
        # if self.training:
        #     return x_cls, x_dist
        # else:
        #     return (x_cls + x_dist) / 2
def pit_scalable_distilled(model_config, pretrained=False, print_info=True, **kwargs):
    if "conv_type" in model_config:
        conv_type = model_config["conv_type"]
    else:
        conv_type = "depthwise"
    model = DistilledPoolingTransformer(
        image_size=model_config["I"],
        patch_size=model_config["p"],
        stride=model_config["s"],
        base_dims=model_config["e"],
        depth=model_config["d"],
        heads=model_config["h"],
        mlp_ratio=model_config["r"],
        conv_type=conv_type,
        **kwargs
    )
    if print_info:
        print("model arch config: {}".format(model_config))
    assert not pretrained, "pretrained must be False"
    return model
def add_pit_backbone_config(cfg):
    cfg.MODEL.PIT = type(cfg)()
    cfg.MODEL.PIT.MODEL_CONFIG = None
    cfg.MODEL.PIT.WEIGHTS = None
    cfg.MODEL.PIT.DILATED = True
@BACKBONE_REGISTRY.register()
def pit_d2go_model_wrapper(cfg, _):
    assert cfg.MODEL.PIT.MODEL_CONFIG is not None
    dilated = cfg.MODEL.PIT.DILATED
    with PathManager.open(cfg.MODEL.PIT.MODEL_CONFIG) as f:
        model_config = json.load(f)
    model = pit_scalable_distilled(
        model_config,
        num_classes=0,  # set num_classes=0 to avoid building cls head
        drop_rate=0,
        drop_path_rate=0.1,
        dilated=dilated,
    )
    # load weights
    if cfg.MODEL.PIT.WEIGHTS is not None:
        with PathManager.open(cfg.MODEL.PIT.WEIGHTS, "rb") as f:
            state_dict = torch.load(f, map_location="cpu")["model"]
        rm_keys = [k for k in state_dict if "head" in k]
        rm_keys = rm_keys + ["norm.weight", "norm.bias"]
        print(rm_keys)
        for k in rm_keys:
            del state_dict[k]
        model.load_state_dict(state_dict)
        print(f"loaded weights from {cfg.MODEL.PIT.WEIGHTS}")
    return model
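
A similar sketch for the PiT backbone (not part of the diff), mainly to make the effective output stride explicit. The hyper-parameters follow the published PiT-Ti configuration and are only illustrative; the real values come from the JSON file referenced by cfg.MODEL.PIT.MODEL_CONFIG.

# Minimal sketch, illustrative PiT-Ti-like values.
if __name__ == "__main__":
    model_config = {"I": 224, "p": 16, "s": 8, "e": [32, 32, 32],
                    "d": [2, 6, 4], "h": [2, 4, 8], "r": 4}
    # patch embedding stride 8, then two conv_head_pooling layers:
    #   dilated=True  -> pool strides (2, 1): overall stride roughly 8 * 2 = 16
    #   dilated=False -> pool strides (2, 2): overall stride roughly 8 * 4 = 32
    model = pit_scalable_distilled(model_config, num_classes=0, drop_rate=0,
                                   drop_path_rate=0.1, dilated=True)
    model.eval()
    y = model(torch.rand(1, 3, 320, 256))
    print(y.shape)  # expected (1, 256, 20, 16): channels = e[-1] * h[-1]
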
@@ -2,7 +2,7 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
 # -*- coding: utf-8 -*-
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
-from detectron2.config import CfgNode as CN
+from d2go.config import CfgNode as CN
 def add_detr_config(cfg):
@@ -11,6 +11,9 @@ def add_detr_config(cfg):
     """
     cfg.MODEL.DETR = CN()
     cfg.MODEL.DETR.NUM_CLASSES = 80
+    cfg.MODEL.BACKBONE.SIMPLE = False
+    cfg.MODEL.BACKBONE.STRIDE = 1
+    cfg.MODEL.BACKBONE.CHANNEL = 0
     # FBNet
     cfg.MODEL.FBNET_V2.OUT_FEATURES = ["trunk3"]
......
@@ -120,4 +120,5 @@ class DetrDatasetMapper:
         ]
         instances = utils.annotations_to_instances(annos, image_shape)
         dataset_dict["instances"] = utils.filter_empty_instances(instances)
         return dataset_dict
@@ -76,11 +76,10 @@ class ResNetMaskedBackbone(nn.Module):
         return masks
-class FBNetMaskedBackbone(nn.Module):
-    """This is a thin wrapper around D2's backbone to provide padding masking"""
+class FBNetMaskedBackbone(ResNetMaskedBackbone):
+    """ This is a thin wrapper around D2's backbone to provide padding masking"""
     def __init__(self, cfg):
-        super().__init__()
+        nn.Module.__init__(self)
         self.backbone = build_backbone(cfg)
         self.out_features = cfg.MODEL.FBNET_V2.OUT_FEATURES
         self.feature_strides = list(self.backbone._out_feature_strides.values())
@@ -105,22 +104,32 @@ class FBNetMaskedBackbone(nn.Module):
             ret_features[k] = NestedTensor(features[k], masks[i])
         return ret_features
-    def mask_out_padding(self, feature_shapes, image_sizes, device):
-        masks = []
-        assert len(feature_shapes) == len(self.feature_strides)
-        for idx, shape in enumerate(feature_shapes):
-            N, _, H, W = shape
-            masks_per_feature_level = torch.ones(
-                (N, H, W), dtype=torch.bool, device=device
+class SimpleSingleStageBackbone(ResNetMaskedBackbone):
+    """This is a simple wrapper for single stage backbone,
+    please set the required configs:
+    cfg.MODEL.BACKBONE.SIMPLE == True,
+    cfg.MODEL.BACKBONE.STRIDE, cfg.MODEL.BACKBONE.CHANNEL
+    """
+    def __init__(self, cfg):
+        nn.Module.__init__(self)
+        self.backbone = build_backbone(cfg)
+        self.out_features = ['out']
+        assert cfg.MODEL.BACKBONE.SIMPLE is True
+        self.feature_strides = [cfg.MODEL.BACKBONE.STRIDE]
+        self.num_channels = [cfg.MODEL.BACKBONE.CHANNEL]
+        self.strides = [cfg.MODEL.BACKBONE.STRIDE]
+    def forward(self, images):
+        y = self.backbone(images.tensor)
+        masks = self.mask_out_padding(
+            [y.shape],
+            images.image_sizes,
+            images.tensor.device,
+        )
-            )
-            for img_idx, (h, w) in enumerate(image_sizes):
-                masks_per_feature_level[
-                    img_idx,
-                    : int(np.ceil(float(h) / self.feature_strides[idx])),
-                    : int(np.ceil(float(w) / self.feature_strides[idx])),
-                ] = 0
-            masks.append(masks_per_feature_level)
-        return masks
+        assert len(masks) == 1
+        ret_features = {}
+        ret_features[self.out_features[0]] = NestedTensor(y, masks[0])
+        return ret_features
 @META_ARCH_REGISTRY.register()
@@ -158,8 +167,10 @@ class Detr(nn.Module):
         N_steps = hidden_dim // 2
         if "resnet" in cfg.MODEL.BACKBONE.NAME.lower():
             d2_backbone = ResNetMaskedBackbone(cfg)
-        elif "fbnet" in cfg.MODEL.BACKBONE.NAME.lower():
+        elif 'fbnet' in cfg.MODEL.BACKBONE.NAME.lower():
             d2_backbone = FBNetMaskedBackbone(cfg)
+        elif cfg.MODEL.BACKBONE.SIMPLE:
+            d2_backbone = SimpleSingleStageBackbone(cfg)
         else:
             raise NotImplementedError
......
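
As a reference for the new code path above, a hedged sketch of the config values SimpleSingleStageBackbone expects when DETR is pointed at one of the transformer wrappers. The stride and channel numbers are placeholders for a DeiT-Tiny-like model and the JSON path is hypothetical; none of this is part of the diff.

# Illustrative only; assumes `cfg` is a DETR d2go config that already went through
# add_detr_config / add_deit_backbone_config.
def example_simple_backbone_cfg(cfg):
    cfg.MODEL.BACKBONE.NAME = "deit_d2go_model_wrapper"  # contains neither "resnet" nor "fbnet"
    cfg.MODEL.BACKBONE.SIMPLE = True                     # routes Detr to SimpleSingleStageBackbone
    cfg.MODEL.BACKBONE.STRIDE = 16                       # single output stride (DeiT patch size)
    cfg.MODEL.BACKBONE.CHANNEL = 192                     # embed_dim (h * e) of the chosen DeiT config
    cfg.MODEL.DEIT.MODEL_CONFIG = "/path/to/deit_config.json"  # placeholder path
    cfg.MODEL.DEIT.WEIGHTS = None
    return cfg
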
@@ -5,6 +5,8 @@ from d2go.data.dataset_mappers.build import D2GO_DATA_MAPPER_REGISTRY
 from d2go.data.dataset_mappers.d2go_dataset_mapper import D2GoDatasetMapper
 from d2go.runner import GeneralizedRCNNRunner
 from detr.d2 import DetrDatasetMapper, add_detr_config
+from detr.backbone.deit import add_deit_backbone_config
+from detr.backbone.pit import add_pit_backbone_config
 @D2GO_DATA_MAPPER_REGISTRY.register()
@@ -30,5 +32,6 @@ class DETRRunner(GeneralizedRCNNRunner):
     def get_default_cfg(self):
         _C = super().get_default_cfg()
         add_detr_config(_C)
-        _C.MODEL.DETR = CN(_C.MODEL.DETR)
+        add_deit_backbone_config(_C)
+        add_pit_backbone_config(_C)
         return _C
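
End-to-end, the runner change means the default config now carries the MODEL.DEIT and MODEL.PIT nodes. Below is a hedged sketch of wiring a PiT backbone through it; it assumes DETRRunner can be instantiated directly and that the runner module path matches the test import further down, and all paths and numbers are placeholders rather than values from this diff.

# Illustrative only; not part of the diff.
from d2go.projects.detr.runner import DETRRunner

cfg = DETRRunner().get_default_cfg()                # includes MODEL.DETR, MODEL.DEIT and MODEL.PIT
cfg.MODEL.BACKBONE.NAME = "pit_d2go_model_wrapper"
cfg.MODEL.BACKBONE.SIMPLE = True
cfg.MODEL.BACKBONE.STRIDE = 16                      # dilated PiT keeps roughly stride 16
cfg.MODEL.BACKBONE.CHANNEL = 256                    # e[-1] * h[-1] of the chosen PiT config
cfg.MODEL.PIT.MODEL_CONFIG = "/path/to/pit_config.json"  # placeholder path
cfg.MODEL.PIT.WEIGHTS = None
cfg.MODEL.PIT.DILATED = True
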
import unittest
from detr.backbone.deit import add_deit_backbone_config
from detr.backbone.pit import add_pit_backbone_config
import torch
from detectron2.utils.file_io import PathManager
from detectron2.checkpoint import DetectionCheckpointer
from d2go.config import CfgNode as CN
from detectron2.modeling import BACKBONE_REGISTRY
import logging
logger = logging.getLogger(__name__)
# avoid testing on sandcastle due to access to manifold
USE_CUDA = torch.cuda.device_count() > 0
class TestTransformerBackbone(unittest.TestCase):
    @unittest.skipIf(not USE_CUDA, "avoid testing on sandcastle due to access to manifold")
    def test_deit_model(self):
        cfg = CN()
        cfg.MODEL = CN()
        add_deit_backbone_config(cfg)
        build_model = BACKBONE_REGISTRY.get("deit_d2go_model_wrapper")
        deit_models = {
            "8X-7-RM_4": 170,
            "DeiT-Tiny": 224,
            "DeiT-Small": 224,
            "32X-1-RM_2": 221,
            "8X-7": 160,
            "32X-1": 256,
        }
        deit_model_weights = {
            "8X-7-RM_4": "manifold://mobile_vision_workflows/tree/workflows/kyungminkim/20210511/deit_[model]deit_scaling_distill_[bs]128_[mcfg]8X-7-RM_4_.OIXarYpbZw/checkpoint_best.pth",
            "DeiT-Tiny": "manifold://mobile_vision_workflows/tree/workflows/cl114/DeiT-official-ckpt/deit_tiny_distilled_patch16_224-b40b3cf7.pth",
            "DeiT-Small": "manifold://mobile_vision_workflows/tree/workflows/cl114/DeiT-official-ckpt/deit_small_distilled_patch16_224-649709d9.pth",
            "32X-1-RM_2": "manifold://mobile_vision_workflows/tree/workflows/kyungminkim/20210511/deit_[model]deit_scaling_distill_[bs]64_[mcfg]32X-1-RM_2_.xusuFyNMdD/checkpoint_best.pth",
            "8X-7": "manifold://mobile_vision_workflows/tree/workflows/cl114/scaled_best/8X-7.pth",
            "32X-1": "manifold://mobile_vision_workflows/tree/workflows/cl114/scaled_best/32X-1.pth",
        }
        for model_name, org_size in deit_models.items():
            print("model_name", model_name)
            cfg.MODEL.DEIT.MODEL_CONFIG = f"manifold://mobile_vision_workflows/tree/workflows/wbc/deit/model_cfgs/{model_name}.json"
            cfg.MODEL.DEIT.WEIGHTS = deit_model_weights[model_name]
            model = build_model(cfg, None)
            model.eval()
            for input_size_h in [org_size, 192, 224, 256, 320]:
                for input_size_w in [org_size, 192, 224, 256, 320]:
                    x = torch.rand(1, 3, input_size_h, input_size_w)
                    y = model(x)
                    print(f"x.shape: {x.shape}, y.shape: {y.shape}")
    @unittest.skipIf(not USE_CUDA, "avoid testing on sandcastle due to access to manifold")
    def test_pit_model(self):
        cfg = CN()
        cfg.MODEL = CN()
        add_pit_backbone_config(cfg)
        build_model = BACKBONE_REGISTRY.get("pit_d2go_model_wrapper")
        pit_models = {
            "pit_ti_ours": 160,
            "pit_ti": 224,
            "pit_s_ours_v1": 256,
            "pit_s": 224,
        }
        pit_model_weights = {
            "pit_ti_ours": "manifold://mobile_vision_workflows/tree/workflows/kyungminkim/20210515/deit_[model]pit_scalable_distilled_[bs]128_[mcfg]pit_ti_ours_.HImkjNCpJI/checkpoint_best.pth",
            "pit_ti": "manifold://mobile_vision_workflows/tree/workflows/kyungminkim/20210515/deit_[model]pit_scalable_distilled_[bs]128_[mcfg]pit_ti_.QJeFNUfYOD/checkpoint_best.pth",
            "pit_s_ours_v1": "manifold://mobile_vision_workflows/tree/workflows/kyungminkim/20210515/deit_[model]pit_scalable_distilled_[bs]64_[mcfg]pit_s_ours_v1_.LXdwyBDaNY/checkpoint_best.pth",
            "pit_s": "manifold://mobile_vision_workflows/tree/workflows/kyungminkim/20210515/deit_[model]pit_scalable_distilled_[bs]128_[mcfg]pit_s_.zReQLPOuJe/checkpoint_best.pth",
        }
        for model_name, org_size in pit_models.items():
            print("model_name", model_name)
            cfg.MODEL.PIT.MODEL_CONFIG = f"manifold://mobile_vision_workflows/tree/workflows/wbc/deit/model_cfgs/{model_name}.json"
            cfg.MODEL.PIT.WEIGHTS = pit_model_weights[model_name]
            cfg.MODEL.PIT.DILATED = True
            model = build_model(cfg, None)
            model.eval()
            for input_size_h in [org_size, 192, 224, 256, 320]:
                for input_size_w in [org_size, 192, 224, 256, 320]:
                    x = torch.rand(1, 3, input_size_h, input_size_w)
                    y = model(x)
                    print(f"x.shape: {x.shape}, y.shape: {y.shape}")
@@ -6,7 +6,7 @@ import os
 import tempfile
 import unittest
-from detr import runner as oss_runner
+import d2go.projects.detr.runner as oss_runner
 import d2go.runner.default_runner as default_runner
 from d2go.utils.testing.data_loader_helper import create_local_dataset
......