Commit 6a10c7bf authored by unknown's avatar unknown
Browse files

提交Swin-Transformer代码

parents
DATA:
DATASET: imagenet22K
IMG_SIZE: 192
MODEL:
TYPE: swin_moe
NAME: swin_moe_small_patch4_window12_192_8expert_32gpu_22k
DROP_PATH_RATE: 0.2
SWIN_MOE:
EMBED_DIM: 96
DEPTHS: [ 2, 2, 18, 2 ]
NUM_HEADS: [ 3, 6, 12, 24 ]
WINDOW_SIZE: 12
MLP_FC2_BIAS: False
INIT_STD: 0.005
MOE_BLOCKS: [ [ -1 ], [ -1 ], [ 1, 3, 5, 7, 9, 11, 13, 15, 17 ], [ 1 ] ]
NUM_LOCAL_EXPERTS: -4
TOP_VALUE: 1
CAPACITY_FACTOR: 1.25
IS_GSHARD_LOSS: False
MOE_DROP: 0.1
AUX_LOSS_WEIGHT: 0.01
TRAIN:
EPOCHS: 90
WARMUP_EPOCHS: 10
WEIGHT_DECAY: 0.1
BASE_LR: 1.25e-4 # 4096 batch-size
WARMUP_LR: 1.25e-7
MIN_LR: 1.25e-6
CLIP_GRAD: 3.0
TEST:
SHUFFLE: True
\ No newline at end of file
DATA:
DATASET: imagenet22K
IMG_SIZE: 192
MODEL:
TYPE: swin_moe
NAME: swin_moe_small_patch4_window12_192_cosine_router_32expert_32gpu_22k
DROP_PATH_RATE: 0.2
SWIN_MOE:
EMBED_DIM: 96
DEPTHS: [ 2, 2, 18, 2 ]
NUM_HEADS: [ 3, 6, 12, 24 ]
WINDOW_SIZE: 12
MLP_FC2_BIAS: False
INIT_STD: 0.005
MOE_BLOCKS: [ [ -1 ], [ -1 ], [ 1, 3, 5, 7, 9, 11, 13, 15, 17 ], [ 1 ] ]
NUM_LOCAL_EXPERTS: 1
TOP_VALUE: 1
CAPACITY_FACTOR: 1.25
COSINE_ROUTER: True
IS_GSHARD_LOSS: False
MOE_DROP: 0.1
AUX_LOSS_WEIGHT: 0.01
TRAIN:
EPOCHS: 90
WARMUP_EPOCHS: 10
WEIGHT_DECAY: 0.1
BASE_LR: 1.25e-4 # 4096 batch-size
WARMUP_LR: 1.25e-7
MIN_LR: 1.25e-6
CLIP_GRAD: 3.0
TEST:
SHUFFLE: True
\ No newline at end of file
DATA:
DATASET: imagenet22K
IMG_SIZE: 192
MODEL:
TYPE: swin_moe
NAME: swin_moe_small_patch4_window12_192_densebaseline_22k
DROP_PATH_RATE: 0.2
SWIN_MOE:
EMBED_DIM: 96
DEPTHS: [ 2, 2, 18, 2 ]
NUM_HEADS: [ 3, 6, 12, 24 ]
WINDOW_SIZE: 12
MLP_FC2_BIAS: False
MOE_BLOCKS: [ [ -1 ], [ -1 ], [ -1 ], [ -1 ] ]
TRAIN:
EPOCHS: 90
WARMUP_EPOCHS: 10
WEIGHT_DECAY: 0.1
BASE_LR: 1.25e-4 # 4096 batch-size
WARMUP_LR: 1.25e-7
MIN_LR: 1.25e-6
CLIP_GRAD: 3.0
MOE:
SAVE_MASTER: True
TEST:
SHUFFLE: True
\ No newline at end of file
DATA:
DATASET: imagenet22K
IMG_SIZE: 192
MODEL:
TYPE: swinv2
NAME: swinv2_base_patch4_window12_192_22k
DROP_PATH_RATE: 0.2
SWINV2:
EMBED_DIM: 128
DEPTHS: [ 2, 2, 18, 2 ]
NUM_HEADS: [ 4, 8, 16, 32 ]
WINDOW_SIZE: 12
TRAIN:
EPOCHS: 90
WARMUP_EPOCHS: 5
WEIGHT_DECAY: 0.1
BASE_LR: 1.25e-4 # 4096 batch-size
WARMUP_LR: 1.25e-7
MIN_LR: 1.25e-6
\ No newline at end of file
DATA:
IMG_SIZE: 256
MODEL:
TYPE: swinv2
NAME: swinv2_base_patch4_window12to16_192to256_22kto1k_ft
DROP_PATH_RATE: 0.2
SWINV2:
EMBED_DIM: 128
DEPTHS: [ 2, 2, 18, 2 ]
NUM_HEADS: [ 4, 8, 16, 32 ]
WINDOW_SIZE: 16
PRETRAINED_WINDOW_SIZES: [ 12, 12, 12, 6 ]
TRAIN:
EPOCHS: 30
WARMUP_EPOCHS: 5
WEIGHT_DECAY: 1e-8
BASE_LR: 2e-05
WARMUP_LR: 2e-08
MIN_LR: 2e-07
\ No newline at end of file
DATA:
IMG_SIZE: 384
MODEL:
TYPE: swinv2
NAME: swinv2_base_patch4_window12to24_192to384_22kto1k_ft
DROP_PATH_RATE: 0.2
SWINV2:
EMBED_DIM: 128
DEPTHS: [ 2, 2, 18, 2 ]
NUM_HEADS: [ 4, 8, 16, 32 ]
WINDOW_SIZE: 24
PRETRAINED_WINDOW_SIZES: [ 12, 12, 12, 6 ]
TRAIN:
EPOCHS: 30
WARMUP_EPOCHS: 5
WEIGHT_DECAY: 1e-8
BASE_LR: 2e-05
WARMUP_LR: 2e-08
MIN_LR: 2e-07
TEST:
CROP: False
\ No newline at end of file
DATA:
IMG_SIZE: 256
MODEL:
TYPE: swinv2
NAME: swinv2_base_patch4_window16_256
DROP_PATH_RATE: 0.5
SWINV2:
EMBED_DIM: 128
DEPTHS: [ 2, 2, 18, 2 ]
NUM_HEADS: [ 4, 8, 16, 32 ]
WINDOW_SIZE: 16
\ No newline at end of file
DATA:
IMG_SIZE: 256
MODEL:
TYPE: swinv2
NAME: swinv2_base_patch4_window8_256
DROP_PATH_RATE: 0.5
SWINV2:
EMBED_DIM: 128
DEPTHS: [ 2, 2, 18, 2 ]
NUM_HEADS: [ 4, 8, 16, 32 ]
WINDOW_SIZE: 8
\ No newline at end of file
DATA:
DATASET: imagenet22K
IMG_SIZE: 192
MODEL:
TYPE: swinv2
NAME: swinv2_large_patch4_window12_192_22k
DROP_PATH_RATE: 0.2
SWINV2:
EMBED_DIM: 192
DEPTHS: [ 2, 2, 18, 2 ]
NUM_HEADS: [ 6, 12, 24, 48 ]
WINDOW_SIZE: 12
TRAIN:
EPOCHS: 90
WARMUP_EPOCHS: 5
WEIGHT_DECAY: 0.1
BASE_LR: 1.25e-4 # 4096 batch-size
WARMUP_LR: 1.25e-7
MIN_LR: 1.25e-6
\ No newline at end of file
DATA:
IMG_SIZE: 256
MODEL:
TYPE: swinv2
NAME: swinv2_base_patch4_window12to16_192to256_22kto1k_ft
DROP_PATH_RATE: 0.2
SWINV2:
EMBED_DIM: 192
DEPTHS: [ 2, 2, 18, 2 ]
NUM_HEADS: [ 6, 12, 24, 48 ]
WINDOW_SIZE: 16
PRETRAINED_WINDOW_SIZES: [ 12, 12, 12, 6 ]
TRAIN:
EPOCHS: 30
WARMUP_EPOCHS: 5
WEIGHT_DECAY: 1e-8
BASE_LR: 2e-05
WARMUP_LR: 2e-08
MIN_LR: 2e-07
\ No newline at end of file
DATA:
IMG_SIZE: 384
MODEL:
TYPE: swinv2
NAME: swinv2_large_patch4_window12to24_192to384_22kto1k_ft
DROP_PATH_RATE: 0.2
SWINV2:
EMBED_DIM: 192
DEPTHS: [ 2, 2, 18, 2 ]
NUM_HEADS: [ 6, 12, 24, 48 ]
WINDOW_SIZE: 24
PRETRAINED_WINDOW_SIZES: [ 12, 12, 12, 6 ]
TRAIN:
EPOCHS: 30
WARMUP_EPOCHS: 5
WEIGHT_DECAY: 1e-8
BASE_LR: 2e-05
WARMUP_LR: 2e-08
MIN_LR: 2e-07
TEST:
CROP: False
\ No newline at end of file
DATA:
IMG_SIZE: 256
MODEL:
TYPE: swinv2
NAME: swinv2_small_patch4_window16_256
DROP_PATH_RATE: 0.3
SWINV2:
EMBED_DIM: 96
DEPTHS: [ 2, 2, 18, 2 ]
NUM_HEADS: [ 3, 6, 12, 24 ]
WINDOW_SIZE: 16
\ No newline at end of file
DATA:
IMG_SIZE: 256
MODEL:
TYPE: swinv2
NAME: swinv2_small_patch4_window8_256
DROP_PATH_RATE: 0.3
SWINV2:
EMBED_DIM: 96
DEPTHS: [ 2, 2, 18, 2 ]
NUM_HEADS: [ 3, 6, 12, 24 ]
WINDOW_SIZE: 8
\ No newline at end of file
DATA:
IMG_SIZE: 256
MODEL:
TYPE: swinv2
NAME: swinv2_tiny_patch4_window16_256
DROP_PATH_RATE: 0.2
SWINV2:
EMBED_DIM: 96
DEPTHS: [ 2, 2, 6, 2 ]
NUM_HEADS: [ 3, 6, 12, 24 ]
WINDOW_SIZE: 16
\ No newline at end of file
DATA:
IMG_SIZE: 256
MODEL:
TYPE: swinv2
NAME: swinv2_tiny_patch4_window8_256
DROP_PATH_RATE: 0.2
SWINV2:
EMBED_DIM: 96
DEPTHS: [ 2, 2, 6, 2 ]
NUM_HEADS: [ 3, 6, 12, 24 ]
WINDOW_SIZE: 8
\ No newline at end of file
from .build import build_loader as _build_loader
from .data_simmim_pt import build_loader_simmim
from .data_simmim_ft import build_loader_finetune
def build_loader(config, simmim=False, is_pretrain=False):
if not simmim:
return _build_loader(config)
if is_pretrain:
return build_loader_simmim(config)
else:
return build_loader_finetune(config)
# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu
# --------------------------------------------------------
import os
import torch
import numpy as np
import torch.distributed as dist
from torchvision import datasets, transforms
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data import Mixup
from timm.data import create_transform
from .cached_image_folder import CachedImageFolder
from .imagenet22k_dataset import IN22KDATASET
from .samplers import SubsetRandomSampler
try:
from torchvision.transforms import InterpolationMode
def _pil_interp(method):
if method == 'bicubic':
return InterpolationMode.BICUBIC
elif method == 'lanczos':
return InterpolationMode.LANCZOS
elif method == 'hamming':
return InterpolationMode.HAMMING
else:
# default bilinear, do we want to allow nearest?
return InterpolationMode.BILINEAR
import timm.data.transforms as timm_transforms
timm_transforms._pil_interp = _pil_interp
except:
from timm.data.transforms import _pil_interp
def build_loader(config):
config.defrost()
dataset_train, config.MODEL.NUM_CLASSES = build_dataset(is_train=True, config=config)
config.freeze()
print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build train dataset")
dataset_val, _ = build_dataset(is_train=False, config=config)
print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build val dataset")
num_tasks = dist.get_world_size()
global_rank = dist.get_rank()
if config.DATA.ZIP_MODE and config.DATA.CACHE_MODE == 'part':
indices = np.arange(dist.get_rank(), len(dataset_train), dist.get_world_size())
sampler_train = SubsetRandomSampler(indices)
else:
sampler_train = torch.utils.data.DistributedSampler(
dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
)
if config.TEST.SEQUENTIAL:
sampler_val = torch.utils.data.SequentialSampler(dataset_val)
else:
sampler_val = torch.utils.data.distributed.DistributedSampler(
dataset_val, shuffle=config.TEST.SHUFFLE
)
data_loader_train = torch.utils.data.DataLoader(
dataset_train, sampler=sampler_train,
batch_size=config.DATA.BATCH_SIZE,
num_workers=config.DATA.NUM_WORKERS,
pin_memory=config.DATA.PIN_MEMORY,
drop_last=True,
)
data_loader_val = torch.utils.data.DataLoader(
dataset_val, sampler=sampler_val,
batch_size=config.DATA.BATCH_SIZE,
shuffle=False,
num_workers=config.DATA.NUM_WORKERS,
pin_memory=config.DATA.PIN_MEMORY,
drop_last=False
)
# setup mixup / cutmix
mixup_fn = None
mixup_active = config.AUG.MIXUP > 0 or config.AUG.CUTMIX > 0. or config.AUG.CUTMIX_MINMAX is not None
if mixup_active:
mixup_fn = Mixup(
mixup_alpha=config.AUG.MIXUP, cutmix_alpha=config.AUG.CUTMIX, cutmix_minmax=config.AUG.CUTMIX_MINMAX,
prob=config.AUG.MIXUP_PROB, switch_prob=config.AUG.MIXUP_SWITCH_PROB, mode=config.AUG.MIXUP_MODE,
label_smoothing=config.MODEL.LABEL_SMOOTHING, num_classes=config.MODEL.NUM_CLASSES)
return dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn
def build_dataset(is_train, config):
transform = build_transform(is_train, config)
if config.DATA.DATASET == 'imagenet':
prefix = 'train' if is_train else 'val'
if config.DATA.ZIP_MODE:
ann_file = prefix + "_map.txt"
prefix = prefix + ".zip@/"
dataset = CachedImageFolder(config.DATA.DATA_PATH, ann_file, prefix, transform,
cache_mode=config.DATA.CACHE_MODE if is_train else 'part')
else:
root = os.path.join(config.DATA.DATA_PATH, prefix)
dataset = datasets.ImageFolder(root, transform=transform)
nb_classes = 1000
elif config.DATA.DATASET == 'imagenet22K':
prefix = 'ILSVRC2011fall_whole'
if is_train:
ann_file = prefix + "_map_train.txt"
else:
ann_file = prefix + "_map_val.txt"
dataset = IN22KDATASET(config.DATA.DATA_PATH, ann_file, transform)
nb_classes = 21841
else:
raise NotImplementedError("We only support ImageNet Now.")
return dataset, nb_classes
def build_transform(is_train, config):
resize_im = config.DATA.IMG_SIZE > 32
if is_train:
# this should always dispatch to transforms_imagenet_train
transform = create_transform(
input_size=config.DATA.IMG_SIZE,
is_training=True,
color_jitter=config.AUG.COLOR_JITTER if config.AUG.COLOR_JITTER > 0 else None,
auto_augment=config.AUG.AUTO_AUGMENT if config.AUG.AUTO_AUGMENT != 'none' else None,
re_prob=config.AUG.REPROB,
re_mode=config.AUG.REMODE,
re_count=config.AUG.RECOUNT,
interpolation=config.DATA.INTERPOLATION,
)
if not resize_im:
# replace RandomResizedCropAndInterpolation with
# RandomCrop
transform.transforms[0] = transforms.RandomCrop(config.DATA.IMG_SIZE, padding=4)
return transform
t = []
if resize_im:
if config.TEST.CROP:
size = int((256 / 224) * config.DATA.IMG_SIZE)
t.append(
transforms.Resize(size, interpolation=_pil_interp(config.DATA.INTERPOLATION)),
# to maintain same ratio w.r.t. 224 images
)
t.append(transforms.CenterCrop(config.DATA.IMG_SIZE))
else:
t.append(
transforms.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE),
interpolation=_pil_interp(config.DATA.INTERPOLATION))
)
t.append(transforms.ToTensor())
t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
return transforms.Compose(t)
# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu
# --------------------------------------------------------
import io
import os
import time
import torch.distributed as dist
import torch.utils.data as data
from PIL import Image
from .zipreader import is_zip_path, ZipReader
def has_file_allowed_extension(filename, extensions):
"""Checks if a file is an allowed extension.
Args:
filename (string): path to a file
Returns:
bool: True if the filename ends with a known image extension
"""
filename_lower = filename.lower()
return any(filename_lower.endswith(ext) for ext in extensions)
def find_classes(dir):
classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
classes.sort()
class_to_idx = {classes[i]: i for i in range(len(classes))}
return classes, class_to_idx
def make_dataset(dir, class_to_idx, extensions):
images = []
dir = os.path.expanduser(dir)
for target in sorted(os.listdir(dir)):
d = os.path.join(dir, target)
if not os.path.isdir(d):
continue
for root, _, fnames in sorted(os.walk(d)):
for fname in sorted(fnames):
if has_file_allowed_extension(fname, extensions):
path = os.path.join(root, fname)
item = (path, class_to_idx[target])
images.append(item)
return images
def make_dataset_with_ann(ann_file, img_prefix, extensions):
images = []
with open(ann_file, "r") as f:
contents = f.readlines()
for line_str in contents:
path_contents = [c for c in line_str.split('\t')]
im_file_name = path_contents[0]
class_index = int(path_contents[1])
assert str.lower(os.path.splitext(im_file_name)[-1]) in extensions
item = (os.path.join(img_prefix, im_file_name), class_index)
images.append(item)
return images
class DatasetFolder(data.Dataset):
"""A generic data loader where the samples are arranged in this way: ::
root/class_x/xxx.ext
root/class_x/xxy.ext
root/class_x/xxz.ext
root/class_y/123.ext
root/class_y/nsdf3.ext
root/class_y/asd932_.ext
Args:
root (string): Root directory path.
loader (callable): A function to load a sample given its path.
extensions (list[string]): A list of allowed extensions.
transform (callable, optional): A function/transform that takes in
a sample and returns a transformed version.
E.g, ``transforms.RandomCrop`` for images.
target_transform (callable, optional): A function/transform that takes
in the target and transforms it.
Attributes:
samples (list): List of (sample path, class_index) tuples
"""
def __init__(self, root, loader, extensions, ann_file='', img_prefix='', transform=None, target_transform=None,
cache_mode="no"):
# image folder mode
if ann_file == '':
_, class_to_idx = find_classes(root)
samples = make_dataset(root, class_to_idx, extensions)
# zip mode
else:
samples = make_dataset_with_ann(os.path.join(root, ann_file),
os.path.join(root, img_prefix),
extensions)
if len(samples) == 0:
raise (RuntimeError("Found 0 files in subfolders of: " + root + "\n" +
"Supported extensions are: " + ",".join(extensions)))
self.root = root
self.loader = loader
self.extensions = extensions
self.samples = samples
self.labels = [y_1k for _, y_1k in samples]
self.classes = list(set(self.labels))
self.transform = transform
self.target_transform = target_transform
self.cache_mode = cache_mode
if self.cache_mode != "no":
self.init_cache()
def init_cache(self):
assert self.cache_mode in ["part", "full"]
n_sample = len(self.samples)
global_rank = dist.get_rank()
world_size = dist.get_world_size()
samples_bytes = [None for _ in range(n_sample)]
start_time = time.time()
for index in range(n_sample):
if index % (n_sample // 10) == 0:
t = time.time() - start_time
print(f'global_rank {dist.get_rank()} cached {index}/{n_sample} takes {t:.2f}s per block')
start_time = time.time()
path, target = self.samples[index]
if self.cache_mode == "full":
samples_bytes[index] = (ZipReader.read(path), target)
elif self.cache_mode == "part" and index % world_size == global_rank:
samples_bytes[index] = (ZipReader.read(path), target)
else:
samples_bytes[index] = (path, target)
self.samples = samples_bytes
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (sample, target) where target is class_index of the target class.
"""
path, target = self.samples[index]
sample = self.loader(path)
if self.transform is not None:
sample = self.transform(sample)
if self.target_transform is not None:
target = self.target_transform(target)
return sample, target
def __len__(self):
return len(self.samples)
def __repr__(self):
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
fmt_str += ' Root Location: {}\n'.format(self.root)
tmp = ' Transforms (if any): '
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
tmp = ' Target Transforms (if any): '
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
return fmt_str
IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif']
def pil_loader(path):
# open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
if isinstance(path, bytes):
img = Image.open(io.BytesIO(path))
elif is_zip_path(path):
data = ZipReader.read(path)
img = Image.open(io.BytesIO(data))
else:
with open(path, 'rb') as f:
img = Image.open(f)
return img.convert('RGB')
return img.convert('RGB')
def accimage_loader(path):
import accimage
try:
return accimage.Image(path)
except IOError:
# Potentially a decoding problem, fall back to PIL.Image
return pil_loader(path)
def default_img_loader(path):
from torchvision import get_image_backend
if get_image_backend() == 'accimage':
return accimage_loader(path)
else:
return pil_loader(path)
class CachedImageFolder(DatasetFolder):
"""A generic data loader where the images are arranged in this way: ::
root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png
root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png
Args:
root (string): Root directory path.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
loader (callable, optional): A function to load an image given its path.
Attributes:
imgs (list): List of (image path, class_index) tuples
"""
def __init__(self, root, ann_file='', img_prefix='', transform=None, target_transform=None,
loader=default_img_loader, cache_mode="no"):
super(CachedImageFolder, self).__init__(root, loader, IMG_EXTENSIONS,
ann_file=ann_file, img_prefix=img_prefix,
transform=transform, target_transform=target_transform,
cache_mode=cache_mode)
self.imgs = self.samples
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (image, target) where target is class_index of the target class.
"""
path, target = self.samples[index]
image = self.loader(path)
if self.transform is not None:
img = self.transform(image)
else:
img = image
if self.target_transform is not None:
target = self.target_transform(target)
return img, target
# --------------------------------------------------------
# SimMIM
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Zhenda Xie
# --------------------------------------------------------
import os
import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler
from torchvision import datasets, transforms
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data import Mixup
from timm.data import create_transform
from timm.data.transforms import _pil_interp
def build_loader_finetune(config):
config.defrost()
dataset_train, config.MODEL.NUM_CLASSES = build_dataset(is_train=True, config=config)
config.freeze()
dataset_val, _ = build_dataset(is_train=False, config=config)
num_tasks = dist.get_world_size()
global_rank = dist.get_rank()
sampler_train = DistributedSampler(
dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
)
sampler_val = DistributedSampler(
dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False
)
data_loader_train = DataLoader(
dataset_train, sampler=sampler_train,
batch_size=config.DATA.BATCH_SIZE,
num_workers=config.DATA.NUM_WORKERS,
pin_memory=config.DATA.PIN_MEMORY,
drop_last=True,
)
data_loader_val = DataLoader(
dataset_val, sampler=sampler_val,
batch_size=config.DATA.BATCH_SIZE,
num_workers=config.DATA.NUM_WORKERS,
pin_memory=config.DATA.PIN_MEMORY,
drop_last=False,
)
# setup mixup / cutmix
mixup_fn = None
mixup_active = config.AUG.MIXUP > 0 or config.AUG.CUTMIX > 0. or config.AUG.CUTMIX_MINMAX is not None
if mixup_active:
mixup_fn = Mixup(
mixup_alpha=config.AUG.MIXUP, cutmix_alpha=config.AUG.CUTMIX, cutmix_minmax=config.AUG.CUTMIX_MINMAX,
prob=config.AUG.MIXUP_PROB, switch_prob=config.AUG.MIXUP_SWITCH_PROB, mode=config.AUG.MIXUP_MODE,
label_smoothing=config.MODEL.LABEL_SMOOTHING, num_classes=config.MODEL.NUM_CLASSES)
return dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn
def build_dataset(is_train, config):
transform = build_transform(is_train, config)
if config.DATA.DATASET == 'imagenet':
prefix = 'train' if is_train else 'val'
root = os.path.join(config.DATA.DATA_PATH, prefix)
dataset = datasets.ImageFolder(root, transform=transform)
nb_classes = 1000
else:
raise NotImplementedError("We only support ImageNet Now.")
return dataset, nb_classes
def build_transform(is_train, config):
resize_im = config.DATA.IMG_SIZE > 32
if is_train:
# this should always dispatch to transforms_imagenet_train
transform = create_transform(
input_size=config.DATA.IMG_SIZE,
is_training=True,
color_jitter=config.AUG.COLOR_JITTER if config.AUG.COLOR_JITTER > 0 else None,
auto_augment=config.AUG.AUTO_AUGMENT if config.AUG.AUTO_AUGMENT != 'none' else None,
re_prob=config.AUG.REPROB,
re_mode=config.AUG.REMODE,
re_count=config.AUG.RECOUNT,
interpolation=config.DATA.INTERPOLATION,
)
if not resize_im:
# replace RandomResizedCropAndInterpolation with
# RandomCrop
transform.transforms[0] = transforms.RandomCrop(config.DATA.IMG_SIZE, padding=4)
return transform
t = []
if resize_im:
if config.TEST.CROP:
size = int((256 / 224) * config.DATA.IMG_SIZE)
t.append(
transforms.Resize(size, interpolation=_pil_interp(config.DATA.INTERPOLATION)),
# to maintain same ratio w.r.t. 224 images
)
t.append(transforms.CenterCrop(config.DATA.IMG_SIZE))
else:
t.append(
transforms.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE),
interpolation=_pil_interp(config.DATA.INTERPOLATION))
)
t.append(transforms.ToTensor())
t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
return transforms.Compose(t)
# --------------------------------------------------------
# SimMIM
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Zhenda Xie
# --------------------------------------------------------
import math
import random
import numpy as np
import torch
import torch.distributed as dist
import torchvision.transforms as T
from torch.utils.data import DataLoader, DistributedSampler
from torch.utils.data._utils.collate import default_collate
from torchvision.datasets import ImageFolder
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
class MaskGenerator:
def __init__(self, input_size=192, mask_patch_size=32, model_patch_size=4, mask_ratio=0.6):
self.input_size = input_size
self.mask_patch_size = mask_patch_size
self.model_patch_size = model_patch_size
self.mask_ratio = mask_ratio
assert self.input_size % self.mask_patch_size == 0
assert self.mask_patch_size % self.model_patch_size == 0
self.rand_size = self.input_size // self.mask_patch_size
self.scale = self.mask_patch_size // self.model_patch_size
self.token_count = self.rand_size ** 2
self.mask_count = int(np.ceil(self.token_count * self.mask_ratio))
def __call__(self):
mask_idx = np.random.permutation(self.token_count)[:self.mask_count]
mask = np.zeros(self.token_count, dtype=int)
mask[mask_idx] = 1
mask = mask.reshape((self.rand_size, self.rand_size))
mask = mask.repeat(self.scale, axis=0).repeat(self.scale, axis=1)
return mask
class SimMIMTransform:
def __init__(self, config):
self.transform_img = T.Compose([
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
T.RandomResizedCrop(config.DATA.IMG_SIZE, scale=(0.67, 1.), ratio=(3. / 4., 4. / 3.)),
T.RandomHorizontalFlip(),
T.ToTensor(),
T.Normalize(mean=torch.tensor(IMAGENET_DEFAULT_MEAN),std=torch.tensor(IMAGENET_DEFAULT_STD)),
])
if config.MODEL.TYPE in ['swin', 'swinv2']:
model_patch_size=config.MODEL.SWIN.PATCH_SIZE
else:
raise NotImplementedError
self.mask_generator = MaskGenerator(
input_size=config.DATA.IMG_SIZE,
mask_patch_size=config.DATA.MASK_PATCH_SIZE,
model_patch_size=model_patch_size,
mask_ratio=config.DATA.MASK_RATIO,
)
def __call__(self, img):
img = self.transform_img(img)
mask = self.mask_generator()
return img, mask
def collate_fn(batch):
if not isinstance(batch[0][0], tuple):
return default_collate(batch)
else:
batch_num = len(batch)
ret = []
for item_idx in range(len(batch[0][0])):
if batch[0][0][item_idx] is None:
ret.append(None)
else:
ret.append(default_collate([batch[i][0][item_idx] for i in range(batch_num)]))
ret.append(default_collate([batch[i][1] for i in range(batch_num)]))
return ret
def build_loader_simmim(config):
transform = SimMIMTransform(config)
dataset = ImageFolder(config.DATA.DATA_PATH, transform)
sampler = DistributedSampler(dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True)
dataloader = DataLoader(dataset, config.DATA.BATCH_SIZE, sampler=sampler, num_workers=config.DATA.NUM_WORKERS, pin_memory=True, drop_last=True, collate_fn=collate_fn)
return dataloader
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment