Commit 41b18fd8 authored by zhe chen

Use pre-commit to reformat code

parent ff20ea39
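The diff below is purely mechanical: imports are re-sorted into stdlib / third-party / local groups, and double-quoted strings are rewritten to single quotes. For reference, a minimal .pre-commit-config.yaml that would produce this kind of reformatting might look like the sketch below; the hook choices and pinned revisions are assumptions for illustration, not taken from this repository.

    # Hypothetical pre-commit config; the repository's actual hooks may differ.
    repos:
      - repo: https://github.com/pre-commit/pre-commit-hooks
        rev: v4.4.0
        hooks:
          - id: double-quote-string-fixer  # rewrite "..." string literals to '...'
      - repo: https://github.com/PyCQA/isort
        rev: 5.12.0
        hooks:
          - id: isort  # sort imports; straight imports before from-imports per section

With such a config installed (pre-commit install), running pre-commit run --all-files applies every hook to the whole tree, which is the kind of sweep a commit like this one typically captures.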
@@ -5,14 +5,15 @@
 # --------------------------------------------------------
 import os
-import torch
+
 import numpy as np
+import torch
 import torch.distributed as dist
+from timm.data import Mixup, create_transform
 from torchvision import transforms
-from timm.data import Mixup
-from timm.data import create_transform

 from .cached_image_folder import ImageCephDataset
-from .samplers import SubsetRandomSampler, NodeDistributedSampler
+from .samplers import NodeDistributedSampler, SubsetRandomSampler

 try:
     from torchvision.transforms import InterpolationMode
@@ -50,7 +51,7 @@ class TTA(torch.nn.Module):
         return out

     def __repr__(self) -> str:
-        return f"{self.__class__.__name__}(size={self.size}, scale={self.scales})"
+        return f'{self.__class__.__name__}(size={self.size}, scale={self.scales})'


def build_loader(config):
@@ -58,16 +59,16 @@ def build_loader(config):
     dataset_train, config.MODEL.NUM_CLASSES = build_dataset('train',
                                                             config=config)
     config.freeze()
-    print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()}"
-          "successfully build train dataset")
+    print(f'local rank {config.LOCAL_RANK} / global rank {dist.get_rank()}'
+          'successfully build train dataset')
     dataset_val, _ = build_dataset('val', config=config)
-    print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()}"
-          "successfully build val dataset")
+    print(f'local rank {config.LOCAL_RANK} / global rank {dist.get_rank()}'
+          'successfully build val dataset')
     dataset_test, _ = build_dataset('test', config=config)
-    print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()}"
-          "successfully build test dataset")
+    print(f'local rank {config.LOCAL_RANK} / global rank {dist.get_rank()}'
+          'successfully build test dataset')
     num_tasks = dist.get_world_size()
     global_rank = dist.get_rank()
@@ -5,22 +5,24 @@
 # --------------------------------------------------------
 import io
+import json
+import logging
+import math
 import os
+import os.path as osp
 import re
 import time
-import json
-import math
+from abc import abstractmethod
+
 import mmcv
 import torch
-import logging
-import os.path as osp
-from PIL import Image
-from tqdm import tqdm, trange
-from abc import abstractmethod
-import torch.utils.data as data
 import torch.distributed as dist
+import torch.utils.data as data
 from mmcv.fileio import FileClient
-from .zipreader import is_zip_path, ZipReader
+from PIL import Image
+from tqdm import tqdm, trange
+
+from .zipreader import ZipReader, is_zip_path

 _logger = logging.getLogger(__name__)
@@ -66,7 +68,7 @@ def make_dataset(dir, class_to_idx, extensions):

 def make_dataset_with_ann(ann_file, img_prefix, extensions):
     images = []
-    with open(ann_file, "r") as f:
+    with open(ann_file, 'r') as f:
         contents = f.readlines()
         for line_str in contents:
             path_contents = [c for c in line_str.split('\t')]
@@ -108,7 +110,7 @@ class DatasetFolder(data.Dataset):
                 img_prefix='',
                 transform=None,
                 target_transform=None,
-                 cache_mode="no"):
+                 cache_mode='no'):
         # image folder mode
         if ann_file == '':
             _, class_to_idx = find_classes(root)
@@ -120,9 +122,9 @@ class DatasetFolder(data.Dataset):
                                            extensions)
         if len(samples) == 0:
-            raise (RuntimeError("Found 0 files in subfolders of: " + root +
-                                "\n" + "Supported extensions are: " +
-                                ",".join(extensions)))
+            raise (RuntimeError('Found 0 files in subfolders of: ' + root +
+                                '\n' + 'Supported extensions are: ' +
+                                ','.join(extensions)))

         self.root = root
         self.loader = loader
@@ -136,11 +138,11 @@ class DatasetFolder(data.Dataset):
         self.target_transform = target_transform
         self.cache_mode = cache_mode
-        if self.cache_mode != "no":
+        if self.cache_mode != 'no':
             self.init_cache()

     def init_cache(self):
-        assert self.cache_mode in ["part", "full"]
+        assert self.cache_mode in ['part', 'full']
         n_sample = len(self.samples)
         global_rank = dist.get_rank()
         world_size = dist.get_world_size()
@@ -155,9 +157,9 @@ class DatasetFolder(data.Dataset):
                )
                start_time = time.time()
            path, target = self.samples[index]
-            if self.cache_mode == "full":
+            if self.cache_mode == 'full':
                samples_bytes[index] = (ZipReader.read(path), target)
-            elif self.cache_mode == "part" and index % world_size == global_rank:
+            elif self.cache_mode == 'part' and index % world_size == global_rank:
                samples_bytes[index] = (ZipReader.read(path), target)
            else:
                samples_bytes[index] = (path, target)
@@ -260,7 +262,7 @@ class CachedImageFolder(DatasetFolder):
                 transform=None,
                 target_transform=None,
                 loader=default_img_loader,
-                 cache_mode="no"):
+                 cache_mode='no'):
        super(CachedImageFolder,
              self).__init__(root,
                             loader,
@@ -394,7 +396,7 @@ class ParserCephImage(Parser):
        local_size = int(os.environ.get('LOCAL_SIZE', 1))
        self.local_rank = local_rank
        self.local_size = local_size
-        self.rank = int(os.environ["RANK"])
+        self.rank = int(os.environ['RANK'])
        self.world_size = int(os.environ['WORLD_SIZE'])
        self.num_replicas = int(os.environ['WORLD_SIZE'])
        self.num_parts = local_size
@@ -405,7 +407,7 @@ class ParserCephImage(Parser):
            self.load_onto_memory_v2()

    def load_onto_memory(self):
-        print("Loading images onto memory...", self.local_rank,
+        print('Loading images onto memory...', self.local_rank,
              self.local_size)
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend, **self.kwargs)
@@ -417,7 +419,7 @@ class ParserCephImage(Parser):
            img_bytes = self.file_client.get(path)
            self.holder[path] = img_bytes

-        print("Loading complete!")
+        print('Loading complete!')

    def load_onto_memory_v2(self):
        # print("Loading images onto memory...", self.local_rank, self.local_size)
@@ -446,7 +448,7 @@ class ParserCephImage(Parser):
                self.holder[path] = img_bytes

-        print("Loading complete!")
+        print('Loading complete!')

    def __getitem__(self, index):
        if self.file_client is None:
@@ -4,12 +4,13 @@
 # Licensed under The MIT License [see LICENSE for details]
 # --------------------------------------------------------
-import torch
-import os
 import math
-from torch.utils.data.sampler import Sampler
-import torch.distributed as dist
+import os
+
 import numpy as np
+import torch
+import torch.distributed as dist
+from torch.utils.data.sampler import Sampler


class SubsetRandomSampler(torch.utils.data.Sampler):
@@ -57,12 +58,12 @@ class NodeDistributedSampler(Sampler):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError(
-                    "Requires distributed package to be available")
+                    'Requires distributed package to be available')
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError(
-                    "Requires distributed package to be available")
+                    'Requires distributed package to be available')
            rank = dist.get_rank()
        if local_rank is None:
            local_rank = int(os.environ.get('LOCAL_RANK', 0))