Commit 72f5785f authored by huaerkl

v1.0
# @package _global_

hydra:
  job:
    config:
      override_dirname:
        kv_sep: ':'
        item_sep: '/'
        exclude_keys:
          - run_config
          - distributed_training.distributed_port
          - distributed_training.distributed_world_size
          - model.pretrained_model_path
          - model.target_network_path
          - next_script
          - task.cache_in_scratch
          - task.data
          - checkpoint.save_interval_updates
          - checkpoint.keep_interval_updates
          - checkpoint.save_on_overflow
          - common.log_interval
          - common.user_dir
  sweep:
    dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
    subdir: ''
  launcher:
    submitit_folder: ${hydra.sweep.dir}
    timeout_min: 4320
    cpus_per_task: 10
    gpus_per_node: 8
    tasks_per_node: 8
    mem_gb: 0
    nodes: 8
    name: ${env:PREFIX}_${hydra.job.config_name}
    partition: wav2vec,learnlab,learnfair
    max_num_timeout: 30
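
For reference, a hedged sketch of roughly what hydra.job.override_dirname expands to under the kv_sep/item_sep settings above; the override values are hypothetical and hydra's exact ordering may differ.

# Hypothetical overrides; items are joined as key<kv_sep>value, separated by
# item_sep, with anything in exclude_keys dropped.
overrides = {
    "optimization.lr": "[0.0005]",
    "model.ema_decay": "0.999",
    "common.log_interval": "100",
}
exclude_keys = {"common.log_interval"}
kv_sep, item_sep = ":", "/"

dirname = item_sep.join(
    f"{k}{kv_sep}{v}"
    for k, v in sorted(overrides.items())
    if k not in exclude_keys
)
print(dirname)  # model.ema_decay:0.999/optimization.lr:[0.0005]
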
# @package _group_

common:
  fp16: true
  log_format: json
  log_interval: 200
  tensorboard_logdir: tb

checkpoint:
  save_interval: 5
  save_interval_updates: 25000
  keep_interval_updates: 1
  no_epoch_checkpoints: true

task:
  _name: image_pretraining
  data: /datasets01/imagenet_full_size/061417/

dataset:
  num_workers: 6
  batch_size: 64
  skip_invalid_size_inputs_valid_test: true
  required_batch_size_multiple: 1
  disable_validation: true

distributed_training:
  distributed_world_size: 16
  ddp_backend: c10d

criterion:
  _name: model
  log_keys:
    - ema_decay
    - target_var
    - pred_var

optimization:
  max_update: 400000
  lr: [0.0005]

optimizer:
  _name: adam
  adam_betas: (0.9,0.98)
  adam_eps: 1e-06
  weight_decay: 0.01

lr_scheduler:
  _name: cosine
  warmup_updates: 10000

model:
  _name: data2vec_vision
# @package _group_

common:
  fp16: true
  log_format: json
  log_interval: 200
  tensorboard_logdir: tb

checkpoint:
  save_interval: 5
  save_interval_updates: 25000
  keep_interval_updates: 1
  no_epoch_checkpoints: true

task:
  _name: image_pretraining
  data: /datasets01/imagenet_full_size/061417

dataset:
  num_workers: 6
  batch_size: 128
  skip_invalid_size_inputs_valid_test: true
  required_batch_size_multiple: 2
  disable_validation: true

distributed_training:
  distributed_world_size: 16
  ddp_backend: legacy_ddp

criterion:
  _name: model
  log_keys:
    - ema_decay
    - target_var
    - pred_var

optimization:
  max_update: 375300  # 300 epochs * 1251 updates per epoch
  lr: [0.0005]
  clip_norm: 3.0

optimizer:
  _name: adam
  adam_betas: (0.9,0.999)
  adam_eps: 1e-08
  weight_decay: 0.05

lr_scheduler:
  _name: cosine
  warmup_updates: 12510  # 10 epochs (1251 updates each)

model:
  _name: data2vec_vision
  attention_dropout: 0.05
  ema_decay: 0.999
  ema_end_decay: 0.9998
  layer_norm_targets: True
  average_top_k_layers: 6
  loss_beta: 2.0
  drop_path: 0.25
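
data2vec anneals the teacher's EMA decay from ema_decay up to ema_end_decay during training. A hedged sketch of such a linear anneal follows; the 100k-step horizon is an assumed illustration, not a value from this config.

def annealed_ema_decay(step, start=0.999, end=0.9998, anneal_steps=100_000):
    # linear anneal from start to end, then held constant
    if step >= anneal_steps:
        return end
    return start + (end - start) * step / anneal_steps

print(annealed_ema_decay(50_000))  # ~0.9994, halfway between the endpoints
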
# @package _group_

common:
  fp16: true
  log_format: json
  log_interval: 200
  tensorboard_logdir: tb
  fp16_no_flatten_grads: true

checkpoint:
  save_interval: 5
  save_interval_updates: 25000
  keep_interval_updates: 1
  no_epoch_checkpoints: true

task:
  _name: mae_image_pretraining
  data: /datasets01/imagenet_full_size/061417/
  rebuild_batches: true

dataset:
  num_workers: 6
  batch_size: 64
  skip_invalid_size_inputs_valid_test: true
  required_batch_size_multiple: 1
  disable_validation: true

distributed_training:
  distributed_world_size: 16
  ddp_backend: c10d

criterion:
  _name: model

optimization:
  max_update: 375300
  lr: [0.0006]

optimizer:
  _name: composite
  groups:
    with_decay:
      lr_float: 6e-4
      optimizer:
        _name: adam
        adam_betas: [0.9,0.95]
        weight_decay: 0.05
      lr_scheduler:
        _name: cosine
        warmup_updates: 50040
    no_decay:
      lr_float: 6e-4
      optimizer:
        _name: adam
        adam_betas: [0.9,0.95]
        weight_decay: 0
      lr_scheduler:
        _name: cosine
        warmup_updates: 50040

lr_scheduler: pass_through

model:
  _name: mae
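
As a sanity check, the nested composite-optimizer groups above can be loaded and walked with OmegaConf. A minimal sketch, assuming omegaconf is installed; only a fragment of the config is inlined here.

from omegaconf import OmegaConf

cfg = OmegaConf.create(
    """
    optimizer:
      _name: composite
      groups:
        with_decay:
          lr_float: 0.0006
          optimizer:
            _name: adam
            weight_decay: 0.05
        no_decay:
          lr_float: 0.0006
          optimizer:
            _name: adam
            weight_decay: 0
    """
)
for name, group in cfg.optimizer.groups.items():
    print(name, group.lr_float, group.optimizer.weight_decay)
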
# @package _global_

hydra:
  sweep:
    dir: ${env:PWD}/tmp_dbg/${now:%H-%M-%S}

distributed_training:
  distributed_world_size: 1
  nprocs_per_node: 1
  distributed_port: -1

common:
  log_interval: 1

dataset:
  num_workers: 0
# @package _global_

hydra:
  job:
    config:
      override_dirname:
        kv_sep: ':'
        item_sep: '/'
        exclude_keys:
          - run_config
          - distributed_training.distributed_port
          - distributed_training.distributed_world_size
          - model.pretrained_model_path
          - model.target_network_path
          - next_script
          - task.cache_in_scratch
          - task.data
          - checkpoint.save_interval_updates
          - checkpoint.keep_interval_updates
          - checkpoint.save_on_overflow
          - common.log_interval
          - common.user_dir
  sweep:
    dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
    subdir: ''
  launcher:
    submitit_folder: ${hydra.sweep.dir}
    timeout_min: 4320
    cpus_per_task: 80
    gpus_per_node: 8
    tasks_per_node: 1
    mem_gb: 450
    nodes: 1
    name: ${env:PREFIX}_${hydra.job.config_name}
    partition: devlab,learnlab,learnfair,scavenge
    constraint: volta32gb,ib4
    max_num_timeout: 30
# @package _global_

hydra:
  job:
    config:
      override_dirname:
        kv_sep: ':'
        item_sep: '/'
        exclude_keys:
          - run_config
          - distributed_training.distributed_port
          - distributed_training.distributed_world_size
          - model.pretrained_model_path
          - model.target_network_path
          - next_script
          - task.cache_in_scratch
          - task.data
          - checkpoint.save_interval_updates
          - checkpoint.keep_interval_updates
          - checkpoint.save_on_overflow
          - common.log_interval
          - common.user_dir
  sweep:
    dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
    subdir: ''
  launcher:
    submitit_folder: ${hydra.sweep.dir}
    timeout_min: 4320
    cpus_per_task: 80
    gpus_per_node: 8
    tasks_per_node: 1
    mem_gb: 0
    nodes: 1
    name: ${env:PREFIX}_${hydra.job.config_name}
    partition: wav2vec,learnlab,learnfair
    max_num_timeout: 30
# @package _global_

hydra:
  job:
    config:
      override_dirname:
        kv_sep: ':'
        item_sep: '/'
        exclude_keys:
          - run_config
          - distributed_training.distributed_port
          - distributed_training.distributed_world_size
          - model.pretrained_model_path
          - model.target_network_path
          - next_script
          - task.cache_in_scratch
          - task.data
          - checkpoint.save_interval_updates
          - checkpoint.keep_interval_updates
          - checkpoint.save_on_overflow
          - common.log_interval
          - common.user_dir
          - task.local_cache_path
  sweep:
    dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
    subdir: ''
  launcher:
    submitit_folder: ${hydra.sweep.dir}
    timeout_min: 4320
    cpus_per_task: 10
    gpus_per_node: 8
    tasks_per_node: 8
    mem_gb: 450
    nodes: 2
    name: ${env:PREFIX}_${hydra.job.config_name}
    partition: devlab,learnlab,learnfair,scavenge
    constraint: volta32gb,ib4
    max_num_timeout: 30
# @package _global_

hydra:
  job:
    config:
      override_dirname:
        kv_sep: ':'
        item_sep: '/'
        exclude_keys:
          - run_config
          - distributed_training.distributed_port
          - distributed_training.distributed_world_size
          - model.pretrained_model_path
          - model.target_network_path
          - next_script
          - task.cache_in_scratch
          - task.data
          - checkpoint.save_interval_updates
          - checkpoint.keep_interval_updates
          - checkpoint.save_on_overflow
          - common.log_interval
          - common.user_dir
          - task.local_cache_path
  sweep:
    dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
    subdir: ''
  launcher:
    submitit_folder: ${hydra.sweep.dir}
    timeout_min: 4320
    cpus_per_task: 10
    gpus_per_node: 8
    tasks_per_node: 8
    mem_gb: 0
    nodes: 2
    name: ${env:PREFIX}_${hydra.job.config_name}
    partition: wav2vec,learnlab,learnfair
    max_num_timeout: 30
# @package _global_

hydra:
  job:
    config:
      override_dirname:
        kv_sep: ':'
        item_sep: '/'
        exclude_keys:
          - run_config
          - distributed_training.distributed_port
          - distributed_training.distributed_world_size
          - model.pretrained_model_path
          - model.target_network_path
          - next_script
          - task.cache_in_scratch
          - task.data
          - checkpoint.save_interval_updates
          - checkpoint.keep_interval_updates
          - checkpoint.save_on_overflow
          - common.log_interval
  sweep:
    dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
    subdir: ''
  launcher:
    submitit_folder: ${hydra.sweep.dir}
    timeout_min: 4320
    cpus_per_task: 80
    gpus_per_node: 8
    tasks_per_node: 1
    mem_gb: 450
    nodes: 3
    name: ${env:PREFIX}_${hydra.job.config_name}
    partition: devlab,learnlab,learnfair,scavenge
    constraint: volta32gb,ib4
    max_num_timeout: 30
# @package _global_

hydra:
  job:
    config:
      override_dirname:
        kv_sep: ':'
        item_sep: '/'
        exclude_keys:
          - run_config
          - distributed_training.distributed_port
          - distributed_training.distributed_world_size
          - model.pretrained_model_path
          - model.target_network_path
          - next_script
          - task.cache_in_scratch
          - task.data
          - checkpoint.save_interval_updates
          - checkpoint.keep_interval_updates
          - checkpoint.save_on_overflow
          - common.log_interval
  sweep:
    dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
    subdir: ''
  launcher:
    submitit_folder: ${hydra.sweep.dir}
    timeout_min: 4320
    cpus_per_task: 10
    gpus_per_node: 8
    tasks_per_node: 8
    mem_gb: 450
    nodes: 4
    name: ${env:PREFIX}_${hydra.job.config_name}
    partition: devlab,learnlab,learnfair,scavenge
    constraint: volta32gb,ib4
    max_num_timeout: 30
# @package _global_

hydra:
  job:
    config:
      override_dirname:
        kv_sep: ':'
        item_sep: '/'
        exclude_keys:
          - run_config
          - distributed_training.distributed_port
          - distributed_training.distributed_world_size
          - model.pretrained_model_path
          - model.target_network_path
          - next_script
          - task.cache_in_scratch
          - task.data
          - checkpoint.save_interval_updates
          - checkpoint.keep_interval_updates
          - checkpoint.save_on_overflow
          - common.log_interval
          - common.user_dir
  sweep:
    dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
    subdir: ''
  launcher:
    submitit_folder: ${hydra.sweep.dir}
    timeout_min: 4320
    cpus_per_task: 10
    gpus_per_node: 8
    tasks_per_node: 8
    mem_gb: 0
    nodes: 4
    name: ${env:PREFIX}_${hydra.job.config_name}
    partition: wav2vec,learnlab,learnfair
    max_num_timeout: 30
# @package _global_

hydra:
  job:
    config:
      override_dirname:
        kv_sep: ':'
        item_sep: '/'
        exclude_keys:
          - run_config
          - distributed_training.distributed_port
          - distributed_training.distributed_world_size
          - model.pretrained_model_path
          - model.target_network_path
          - next_script
          - task.cache_in_scratch
          - task.data
          - checkpoint.save_interval_updates
          - checkpoint.keep_interval_updates
          - checkpoint.save_on_overflow
          - common.log_interval
          - common.user_dir
  sweep:
    dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
    subdir: ''
  launcher:
    submitit_folder: ${hydra.sweep.dir}
    timeout_min: 4320
    cpus_per_task: 10
    gpus_per_node: 8
    tasks_per_node: 8
    mem_gb: 0
    nodes: 6
    name: ${env:PREFIX}_${hydra.job.config_name}
    partition: wav2vec,learnlab,learnfair
    max_num_timeout: 30
# @package _global_

hydra:
  job:
    config:
      override_dirname:
        kv_sep: ':'
        item_sep: '/'
        exclude_keys:
          - run_config
          - distributed_training.distributed_port
          - distributed_training.distributed_world_size
          - model.pretrained_model_path
          - model.target_network_path
          - next_script
          - task.cache_in_scratch
          - task.data
          - checkpoint.save_interval_updates
          - checkpoint.keep_interval_updates
          - checkpoint.save_on_overflow
          - common.log_interval
          - common.user_dir
  sweep:
    dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
    subdir: ''
  launcher:
    submitit_folder: ${hydra.sweep.dir}
    timeout_min: 4320
    cpus_per_task: 10
    gpus_per_node: 8
    tasks_per_node: 8
    mem_gb: 0
    nodes: 8
    name: ${env:PREFIX}_${hydra.job.config_name}
    partition: wav2vec,learnlab,learnfair
    max_num_timeout: 30
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .image_dataset import ImageDataset
from .path_dataset import PathDataset
from .mae_image_dataset import MaeImageDataset
from .mae_finetuning_image_dataset import MaeFinetuningImageDataset


__all__ = [
    "ImageDataset",
    "MaeImageDataset",
    "MaeFinetuningImageDataset",
    "PathDataset",
]
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch

from fairseq.data import BaseWrapperDataset, data_utils


class AddClassTargetDataset(BaseWrapperDataset):
    def __init__(
        self,
        dataset,
        labels,
        multi_class,
        num_classes=None,
        label_indices=None,
        add_to_input=True,
    ):
        super().__init__(dataset)

        self.label_indices = label_indices
        self.labels = labels
        self.multi_class = multi_class
        self.add_to_input = add_to_input

        if num_classes is None and multi_class:
            assert self.label_indices is not None
            num_classes = len(self.label_indices)

        self.num_classes = num_classes

    def __getitem__(self, index):
        item = self.dataset[index]
        item_labels = self.labels[index]
        if self.multi_class:
            item["label"] = torch.zeros(self.num_classes)
            for il in item_labels:
                if self.label_indices is not None:
                    il = self.label_indices[il]
                item["label"][il] = 1.0
        else:
            item["label"] = torch.tensor(
                self.labels[index]
                if self.label_indices is None
                else self.label_indices[self.labels[index]]
            )
        return item

    def collater(self, samples):
        collated = self.dataset.collater(samples)
        if len(collated) == 0:
            return collated

        indices = set(collated["id"].tolist())
        target = [s["label"] for s in samples if s["id"] in indices]
        collated["label"] = torch.stack(target, dim=0)

        if self.add_to_input:
            collated["net_input"]["label"] = collated["label"]

        return collated
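
A minimal usage sketch for the wrapper above, assuming fairseq is installed. The ToyDataset here is hypothetical, not part of this commit; any FairseqDataset whose collater emits "id" and "net_input" behaves the same way.

import torch
from fairseq.data import FairseqDataset


class ToyDataset(FairseqDataset):
    def __init__(self, n):
        self.n = n

    def __getitem__(self, index):
        return {"id": index, "img": torch.zeros(3, 4, 4)}

    def __len__(self):
        return self.n

    def collater(self, samples):
        return {
            "id": torch.LongTensor([s["id"] for s in samples]),
            "net_input": {"img": torch.stack([s["img"] for s in samples], dim=0)},
        }


labels = [[0, 2], [1]]  # per-item class indices (multi-label)
ds = AddClassTargetDataset(ToyDataset(2), labels, multi_class=True, num_classes=3)
print(ds[0]["label"])  # tensor([1., 0., 1.])
batch = ds.collater([ds[0], ds[1]])
print(batch["net_input"]["label"])  # stacked labels, since add_to_input=True
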
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging

import numpy as np
import os

from typing import Optional, Callable, Set

import torch

from torchvision.datasets.vision import VisionDataset
from torchvision.transforms import ToTensor

from fairseq.data import FairseqDataset

logger = logging.getLogger(__name__)


class ImageDataset(FairseqDataset, VisionDataset):
    def __init__(
        self,
        root: str,
        extensions: Set[str],
        load_classes: bool,
        transform: Optional[Callable] = None,
        shuffle=True,
    ):
        FairseqDataset.__init__(self)
        VisionDataset.__init__(self, root=root, transform=transform)

        self.shuffle = shuffle
        self.tensor_transform = ToTensor()

        self.classes = None
        self.labels = None

        if load_classes:
            classes = [d.name for d in os.scandir(root) if d.is_dir()]
            classes.sort()
            self.classes = {cls_name: i for i, cls_name in enumerate(classes)}
            logger.info(f"loaded {len(self.classes)} classes")
            self.labels = []

        def walk_path(root_path):
            for root, _, fnames in sorted(os.walk(root_path, followlinks=True)):
                for fname in sorted(fnames):
                    fname_ext = os.path.splitext(fname)
                    if fname_ext[-1].lower() not in extensions:
                        continue

                    path = os.path.join(root, fname)
                    yield path

        logger.info(f"finding images in {root}")
        if self.classes is not None:
            self.files = []
            self.labels = []
            for c, i in self.classes.items():
                for f in walk_path(os.path.join(root, c)):
                    self.files.append(f)
                    self.labels.append(i)
        else:
            self.files = [f for f in walk_path(root)]

        logger.info(f"loaded {len(self.files)} examples")

    def __getitem__(self, index):
        from PIL import Image

        fpath = self.files[index]

        with open(fpath, "rb") as f:
            img = Image.open(f).convert("RGB")

        if self.transform is None:
            img = self.tensor_transform(img)
        else:
            img = self.transform(img)
            assert torch.is_tensor(img)

        res = {"id": index, "img": img}

        if self.labels is not None:
            res["label"] = self.labels[index]

        return res

    def __len__(self):
        return len(self.files)

    def collater(self, samples):
        if len(samples) == 0:
            return {}

        collated_img = torch.stack([s["img"] for s in samples], dim=0)

        res = {
            "id": torch.LongTensor([s["id"] for s in samples]),
            "net_input": {
                "img": collated_img,
            },
        }

        if "label" in samples[0]:
            res["net_input"]["label"] = torch.LongTensor([s["label"] for s in samples])

        return res

    def num_tokens(self, index):
        return 1

    def size(self, index):
        return 1

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        if self.shuffle:
            order = [np.random.permutation(len(self))]
        else:
            order = [np.arange(len(self))]

        return order[0]
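
A hedged usage sketch for ImageDataset; the root path is hypothetical. Note that with the default ToTensor fallback, collater's torch.stack requires all images to share one size, so a resizing transform is usually passed in practice.

ds = ImageDataset(
    root="/path/to/imagenet/train",  # hypothetical location
    extensions={".jpg", ".jpeg", ".png"},  # compared against os.path.splitext output
    load_classes=True,
)
batch = ds.collater([ds[i] for i in range(4)])
print(batch["net_input"]["img"].shape, batch["net_input"]["label"])
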
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging

import numpy as np
import os

import torch

from torchvision import datasets, transforms

from timm.data import create_transform
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
import PIL

from fairseq.data import FairseqDataset
from .mae_image_dataset import caching_loader

logger = logging.getLogger(__name__)


def build_transform(is_train, input_size, color_jitter, aa, reprob, remode, recount):
    mean = IMAGENET_DEFAULT_MEAN
    std = IMAGENET_DEFAULT_STD

    # train transform
    if is_train:
        # this should always dispatch to transforms_imagenet_train
        transform = create_transform(
            input_size=input_size,
            is_training=True,
            color_jitter=color_jitter,
            auto_augment=aa,
            interpolation="bicubic",
            re_prob=reprob,
            re_mode=remode,
            re_count=recount,
            mean=mean,
            std=std,
        )
        return transform

    # eval transform
    t = []
    if input_size <= 224:
        crop_pct = 224 / 256
    else:
        crop_pct = 1.0
    size = int(input_size / crop_pct)
    t.append(
        transforms.Resize(
            size, interpolation=PIL.Image.BICUBIC
        ),  # to maintain same ratio w.r.t. 224 images
    )
    t.append(transforms.CenterCrop(input_size))

    t.append(transforms.ToTensor())
    t.append(transforms.Normalize(mean, std))
    return transforms.Compose(t)
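
A quick worked check of the eval branch above: for input_size=224 the resize target is int(224 / (224 / 256)) == 256, followed by CenterCrop(224).

input_size = 224
crop_pct = 224 / 256 if input_size <= 224 else 1.0
print(int(input_size / crop_pct))  # 256
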
class MaeFinetuningImageDataset(FairseqDataset):
    def __init__(
        self,
        root: str,
        split: str,
        is_train: bool,
        input_size,
        color_jitter=None,
        aa="rand-m9-mstd0.5-inc1",
        reprob=0.25,
        remode="pixel",
        recount=1,
        local_cache_path=None,
        shuffle=True,
    ):
        FairseqDataset.__init__(self)

        self.shuffle = shuffle

        transform = build_transform(
            is_train, input_size, color_jitter, aa, reprob, remode, recount
        )

        path = os.path.join(root, split)
        loader = caching_loader(local_cache_path, datasets.folder.default_loader)

        self.dataset = datasets.ImageFolder(path, loader=loader, transform=transform)

        logger.info(f"loaded {len(self.dataset)} examples")

    def __getitem__(self, index):
        img, label = self.dataset[index]
        return {"id": index, "img": img, "label": label}

    def __len__(self):
        return len(self.dataset)

    def collater(self, samples):
        if len(samples) == 0:
            return {}

        collated_img = torch.stack([s["img"] for s in samples], dim=0)

        res = {
            "id": torch.LongTensor([s["id"] for s in samples]),
            "net_input": {
                "imgs": collated_img,
            },
        }

        if "label" in samples[0]:
            res["net_input"]["labels"] = torch.LongTensor([s["label"] for s in samples])

        return res

    def num_tokens(self, index):
        return 1

    def size(self, index):
        return 1

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        if self.shuffle:
            order = [np.random.permutation(len(self))]
        else:
            order = [np.arange(len(self))]

        return order[0]
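
A hedged construction sketch; the root path is hypothetical and the dataset expects the usual root/split/class_name/*.JPEG layout consumed by datasets.ImageFolder above.

ds = MaeFinetuningImageDataset(
    root="/path/to/imagenet",  # hypothetical
    split="val",
    is_train=False,
    input_size=224,
)
sample = ds[0]
print(sample["img"].shape, sample["label"])  # torch.Size([3, 224, 224]), class index
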
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from functools import partial
import logging
import math
import random
import time

import numpy as np
import os
import torch

from torchvision import datasets, transforms
from .path_dataset import PathDataset

from fairseq.data import FairseqDataset
from fairseq.data.data_utils import compute_block_mask_1d, compute_block_mask_2d

from shutil import copyfile

logger = logging.getLogger(__name__)


def load(path, loader, cache):
    if hasattr(caching_loader, "cache_root"):
        cache = caching_loader.cache_root

    cached_path = cache + path

    num_tries = 3
    for curr_try in range(num_tries):
        try:
            if curr_try == 2:
                return loader(path)
            if not os.path.exists(cached_path) or curr_try > 0:
                os.makedirs(os.path.dirname(cached_path), exist_ok=True)
                copyfile(path, cached_path)
                os.chmod(cached_path, 0o777)
            return loader(cached_path)
        except Exception as e:
            logger.warning(str(e))
            if "Errno 13" in str(e):
                caching_loader.cache_root = f"/scratch/{random.randint(0, 69420)}"
                logger.warning(f"setting cache root to {caching_loader.cache_root}")
                cached_path = caching_loader.cache_root + path
            if curr_try == (num_tries - 1):
                raise
            time.sleep(2)


def caching_loader(cache_root: str, loader):
    if cache_root is None:
        return loader

    if cache_root == "slurm_tmpdir":
        cache_root = os.environ["SLURM_TMPDIR"]
        assert len(cache_root) > 0

    if not cache_root.endswith("/"):
        cache_root += "/"

    return partial(load, loader=loader, cache=cache_root)
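
A sketch of wiring caching_loader in front of torchvision's default loader. The cache directory is hypothetical; passing "slurm_tmpdir" would instead resolve $SLURM_TMPDIR as handled above. The first call copies a file under the cache root; later calls read the copy.

from torchvision import datasets

loader = caching_loader("/scratch/img_cache", datasets.folder.default_loader)
# img = loader("/path/to/imagenet/train/n01440764/img_0001.JPEG")  # hypothetical path
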
class RandomResizedCropAndInterpolationWithTwoPic:
    """Crop the given PIL Image to random size and aspect ratio with random interpolation.

    A crop of random size (default: of 0.08 to 1.0) of the original size and a random
    aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop
    is finally resized to given size.
    This is popularly used to train the Inception networks.

    Args:
        size: expected output size of each edge
        scale: range of size of the origin size cropped
        ratio: range of aspect ratio of the origin aspect ratio cropped
        interpolation: Default: PIL.Image.BILINEAR
    """

    def __init__(
        self,
        size,
        second_size=None,
        scale=(0.08, 1.0),
        ratio=(3.0 / 4.0, 4.0 / 3.0),
        interpolation="bilinear",
        second_interpolation="lanczos",
    ):
        if isinstance(size, tuple):
            self.size = size
        else:
            self.size = (size, size)
        if second_size is not None:
            if isinstance(second_size, tuple):
                self.second_size = second_size
            else:
                self.second_size = (second_size, second_size)
        else:
            self.second_size = None
        if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
            logger.warning("range should be of kind (min, max)")

        if interpolation == "random":
            from PIL import Image

            self.interpolation = (Image.BILINEAR, Image.BICUBIC)
        else:
            self.interpolation = self._pil_interp(interpolation)

        self.second_interpolation = (
            self._pil_interp(second_interpolation)
            if second_interpolation is not None
            else None
        )
        self.scale = scale
        self.ratio = ratio

    def _pil_interp(self, method):
        from PIL import Image

        if method == "bicubic":
            return Image.BICUBIC
        elif method == "lanczos":
            return Image.LANCZOS
        elif method == "hamming":
            return Image.HAMMING
        else:
            # default bilinear, do we want to allow nearest?
            return Image.BILINEAR

    @staticmethod
    def get_params(img, scale, ratio):
        """Get parameters for ``crop`` for a random sized crop.

        Args:
            img (PIL Image): Image to be cropped.
            scale (tuple): range of size of the origin size cropped
            ratio (tuple): range of aspect ratio of the origin aspect ratio cropped

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for a random
                sized crop.
        """
        area = img.size[0] * img.size[1]

        for attempt in range(10):
            target_area = random.uniform(*scale) * area
            log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
            aspect_ratio = math.exp(random.uniform(*log_ratio))

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

            if w <= img.size[0] and h <= img.size[1]:
                i = random.randint(0, img.size[1] - h)
                j = random.randint(0, img.size[0] - w)
                return i, j, h, w

        # Fallback to central crop
        in_ratio = img.size[0] / img.size[1]
        if in_ratio < min(ratio):
            w = img.size[0]
            h = int(round(w / min(ratio)))
        elif in_ratio > max(ratio):
            h = img.size[1]
            w = int(round(h * max(ratio)))
        else:  # whole image
            w = img.size[0]
            h = img.size[1]
        i = (img.size[1] - h) // 2
        j = (img.size[0] - w) // 2
        return i, j, h, w

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be cropped and resized.

        Returns:
            PIL Image: Randomly cropped and resized image.
        """
        import torchvision.transforms.functional as F

        i, j, h, w = self.get_params(img, self.scale, self.ratio)
        if isinstance(self.interpolation, (tuple, list)):
            interpolation = random.choice(self.interpolation)
        else:
            interpolation = self.interpolation
        if self.second_size is None:
            return F.resized_crop(img, i, j, h, w, self.size, interpolation)
        else:
            return F.resized_crop(
                img, i, j, h, w, self.size, interpolation
            ), F.resized_crop(
                img, i, j, h, w, self.second_size, self.second_interpolation
            )
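
A sketch of the two-output mode: with second_size set, __call__ returns two crops of the same sampled window at different resolutions. Requires Pillow; note that newer torchvision versions may warn about (or reject) the PIL integer interpolation constants used above.

from PIL import Image

t = RandomResizedCropAndInterpolationWithTwoPic(
    size=224,
    second_size=112,
    interpolation="bicubic",
    second_interpolation="lanczos",
)
img = Image.new("RGB", (500, 375))
a, b = t(img)
print(a.size, b.size)  # (224, 224) (112, 112)
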
class MaeImageDataset(FairseqDataset):
    def __init__(
        self,
        root: str,
        split: str,
        input_size,
        local_cache_path=None,
        shuffle=True,
        key="imgs",
        beit_transforms=False,
        target_transform=False,
        no_transform=False,
        compute_mask=False,
        patch_size: int = 16,
        mask_prob: float = 0.75,
        mask_prob_adjust: float = 0,
        mask_length: int = 1,
        inverse_mask: bool = False,
        expand_adjacent: bool = False,
        mask_dropout: float = 0,
        non_overlapping: bool = False,
        require_same_masks: bool = True,
        clone_batch: int = 1,
        dataset_type: str = "imagefolder",
    ):
        FairseqDataset.__init__(self)

        self.shuffle = shuffle
        self.key = key

        loader = caching_loader(local_cache_path, datasets.folder.default_loader)

        self.transform_source = None
        self.transform_target = None

        if target_transform:
            self.transform_source = transforms.ColorJitter(0.4, 0.4, 0.4)
            self.transform_target = transforms.ColorJitter(0.4, 0.4, 0.4)

        if no_transform:
            if input_size <= 224:
                crop_pct = 224 / 256
            else:
                crop_pct = 1.0
            size = int(input_size / crop_pct)

            self.transform_train = transforms.Compose(
                [
                    transforms.Resize(size, interpolation=3),
                    transforms.CenterCrop(input_size),
                ]
            )

            # note: this immediately replaces the Compose above, so the
            # no_transform path is effectively a plain resize to input_size
            self.transform_train = transforms.Resize((input_size, input_size))
        elif beit_transforms:
            beit_transform_list = []
            if not target_transform:
                beit_transform_list.append(transforms.ColorJitter(0.4, 0.4, 0.4))
            beit_transform_list.extend(
                [
                    transforms.RandomHorizontalFlip(p=0.5),
                    RandomResizedCropAndInterpolationWithTwoPic(
                        size=input_size,
                        second_size=None,
                        interpolation="bicubic",
                        second_interpolation=None,
                    ),
                ]
            )
            self.transform_train = transforms.Compose(beit_transform_list)
        else:
            self.transform_train = transforms.Compose(
                [
                    transforms.RandomResizedCrop(
                        input_size, scale=(0.2, 1.0), interpolation=3
                    ),  # 3 is bicubic
                    transforms.RandomHorizontalFlip(),
                ]
            )

        self.final_transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

        if dataset_type == "imagefolder":
            self.dataset = datasets.ImageFolder(
                os.path.join(root, split), loader=loader
            )
        elif dataset_type == "path":
            self.dataset = PathDataset(
                root,
                loader,
                None,
                None,
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            )
        else:
            raise Exception(f"invalid dataset type {dataset_type}")

        logger.info(
            f"initial transform: {self.transform_train}, "
            f"source transform: {self.transform_source}, "
            f"target transform: {self.transform_target}, "
            f"final transform: {self.final_transform}"
        )
        logger.info(f"loaded {len(self.dataset)} examples")

        self.is_compute_mask = compute_mask
        self.patches = (input_size // patch_size) ** 2
        self.mask_prob = mask_prob
        self.mask_prob_adjust = mask_prob_adjust
        self.mask_length = mask_length
        self.inverse_mask = inverse_mask
        self.expand_adjacent = expand_adjacent
        self.mask_dropout = mask_dropout
        self.non_overlapping = non_overlapping
        self.require_same_masks = require_same_masks
        self.clone_batch = clone_batch

    def __getitem__(self, index):
        img, _ = self.dataset[index]

        img = self.transform_train(img)

        source = None
        target = None
        if self.transform_source is not None:
            source = self.final_transform(self.transform_source(img))
        if self.transform_target is not None:
            target = self.final_transform(self.transform_target(img))

        if source is None:
            img = self.final_transform(img)

        v = {"id": index, self.key: source if source is not None else img}
        if target is not None:
            v["target"] = target

        if self.is_compute_mask:
            if self.mask_length == 1:
                mask = compute_block_mask_1d(
                    shape=(self.clone_batch, self.patches),
                    mask_prob=self.mask_prob,
                    mask_length=self.mask_length,
                    mask_prob_adjust=self.mask_prob_adjust,
                    inverse_mask=self.inverse_mask,
                    require_same_masks=True,
                )
            else:
                mask = compute_block_mask_2d(
                    shape=(self.clone_batch, self.patches),
                    mask_prob=self.mask_prob,
                    mask_length=self.mask_length,
                    mask_prob_adjust=self.mask_prob_adjust,
                    inverse_mask=self.inverse_mask,
                    require_same_masks=True,
                    expand_adjcent=self.expand_adjacent,  # spelling matches the upstream fairseq signature
                    mask_dropout=self.mask_dropout,
                    non_overlapping=self.non_overlapping,
                )

            v["precomputed_mask"] = mask

        return v

    def __len__(self):
        return len(self.dataset)

    def collater(self, samples):
        if len(samples) == 0:
            return {}

        collated_img = torch.stack([s[self.key] for s in samples], dim=0)

        res = {
            "id": torch.LongTensor([s["id"] for s in samples]),
            "net_input": {
                self.key: collated_img,
            },
        }

        if "target" in samples[0]:
            collated_target = torch.stack([s["target"] for s in samples], dim=0)
            res["net_input"]["target"] = collated_target

        if "precomputed_mask" in samples[0]:
            collated_mask = torch.cat([s["precomputed_mask"] for s in samples], dim=0)
            res["net_input"]["precomputed_mask"] = collated_mask

        return res

    def num_tokens(self, index):
        return 1

    def size(self, index):
        return 1

    @property
    def sizes(self):
        return np.full((len(self),), 1)

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        if self.shuffle:
            order = [np.random.permutation(len(self))]
        else:
            order = [np.arange(len(self))]

        return order[0]
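
The compute_mask path above reduces to a single call; a hedged sketch invoking it directly, mirroring the argument set used in __getitem__ (assumes a fairseq version that ships compute_block_mask_2d; "expand_adjcent" is the upstream parameter spelling).

from fairseq.data.data_utils import compute_block_mask_2d

# 224x224 images with 16x16 patches -> (224 // 16) ** 2 = 196 patch positions
mask = compute_block_mask_2d(
    shape=(1, 196),
    mask_prob=0.75,
    mask_length=3,
    mask_prob_adjust=0,
    inverse_mask=True,
    require_same_masks=True,
    expand_adjcent=False,
    mask_dropout=0,
    non_overlapping=False,
)
print(mask.shape)  # torch.Size([1, 196])
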
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

from enum import Enum, auto


class Modality(Enum):
    AUDIO = auto()
    IMAGE = auto()
    TEXT = auto()