Commit 9b87ce51 authored by dongchy920

arcface
import logging
import os
import time
from typing import List
import torch
from eval import verification
from utils.utils_logging import AverageMeter
from torch.utils.tensorboard import SummaryWriter
from torch import distributed
class CallBackVerification(object):
    def __init__(self, val_targets, rec_prefix, summary_writer=None, image_size=(112, 112), wandb_logger=None):
        self.rank: int = distributed.get_rank()
        self.highest_acc: float = 0.0
        self.highest_acc_list: List[float] = [0.0] * len(val_targets)
        self.ver_list: List[object] = []
        self.ver_name_list: List[str] = []
        if self.rank == 0:
            self.init_dataset(val_targets=val_targets, data_dir=rec_prefix, image_size=image_size)
        self.summary_writer = summary_writer
        self.wandb_logger = wandb_logger
    def ver_test(self, backbone: torch.nn.Module, global_step: int):
        results = []
        for i in range(len(self.ver_list)):
            acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(
                self.ver_list[i], backbone, 10, 10)
            logging.info('[%s][%d]XNorm: %f' % (self.ver_name_list[i], global_step, xnorm))
            logging.info('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (self.ver_name_list[i], global_step, acc2, std2))

            self.summary_writer: SummaryWriter
            self.summary_writer.add_scalar(tag=self.ver_name_list[i], scalar_value=acc2, global_step=global_step)
            if self.wandb_logger:
                import wandb
                self.wandb_logger.log({
                    f'Acc/val-Acc1 {self.ver_name_list[i]}': acc1,
                    f'Acc/val-Acc2 {self.ver_name_list[i]}': acc2,
                    # f'Acc/val-std1 {self.ver_name_list[i]}': std1,
                    # f'Acc/val-std2 {self.ver_name_list[i]}': std2,
                })

            if acc2 > self.highest_acc_list[i]:
                self.highest_acc_list[i] = acc2
            logging.info(
                '[%s][%d]Accuracy-Highest: %1.5f' % (self.ver_name_list[i], global_step, self.highest_acc_list[i]))
            results.append(acc2)
    def init_dataset(self, val_targets, data_dir, image_size):
        for name in val_targets:
            path = os.path.join(data_dir, name + ".bin")
            if os.path.exists(path):
                data_set = verification.load_bin(path, image_size)
                self.ver_list.append(data_set)
                self.ver_name_list.append(name)
    def __call__(self, num_update, backbone: torch.nn.Module):
        if self.rank == 0 and num_update > 0:
            backbone.eval()
            self.ver_test(backbone, num_update)
            backbone.train()
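
A minimal usage sketch (not part of this commit) of how CallBackVerification might be driven from a training loop; the target names, dataset prefix, and backbone below are placeholders, and torch.distributed is assumed to be initialized so distributed.get_rank() works.

import torch
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="work_dirs/example")              # hypothetical log directory
backbone = torch.nn.Sequential(torch.nn.Flatten(),               # stand-in for a real face backbone
                               torch.nn.Linear(3 * 112 * 112, 512))
callback_ver = CallBackVerification(
    val_targets=["lfw", "cfp_fp", "agedb_30"],                    # assumes <name>.bin files under rec_prefix
    rec_prefix="/path/to/faces_emore",                             # hypothetical dataset directory
    summary_writer=writer,
)
for global_step in range(1, 10001):
    # ... forward / backward / optimizer step ...
    if global_step % 2000 == 0:
        callback_ver(global_step, backbone)                        # verification runs on rank 0 only
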
class CallBackLogging(object):
    def __init__(self, frequent, total_step, batch_size, start_step=0, writer=None):
        self.frequent: int = frequent
        self.rank: int = distributed.get_rank()
        self.world_size: int = distributed.get_world_size()
        self.time_start = time.time()
        self.total_step: int = total_step
        self.start_step: int = start_step
        self.batch_size: int = batch_size
        self.writer = writer

        self.init = False
        self.tic = 0
    def __call__(self,
                 global_step: int,
                 loss: AverageMeter,
                 epoch: int,
                 fp16: bool,
                 learning_rate: float,
                 grad_scaler: torch.cuda.amp.GradScaler):
        if self.rank == 0 and global_step > 0 and global_step % self.frequent == 0:
            if self.init:
                try:
                    speed: float = self.frequent * self.batch_size / (time.time() - self.tic)
                    speed_total = speed * self.world_size
                except ZeroDivisionError:
                    speed_total = float('inf')

                # time_now = (time.time() - self.time_start) / 3600
                # time_total = time_now / ((global_step + 1) / self.total_step)
                # time_for_end = time_total - time_now
                time_now = time.time()
                time_sec = int(time_now - self.time_start)
                time_sec_avg = time_sec / (global_step - self.start_step + 1)
                eta_sec = time_sec_avg * (self.total_step - global_step - 1)
                time_for_end = eta_sec / 3600
                if self.writer is not None:
                    self.writer.add_scalar('time_for_end', time_for_end, global_step)
                    self.writer.add_scalar('learning_rate', learning_rate, global_step)
                    self.writer.add_scalar('loss', loss.avg, global_step)
                if fp16:
                    msg = "Speed %.2f samples/sec Loss %.4f LearningRate %.6f Epoch: %d Global Step: %d " \
                          "Fp16 Grad Scale: %2.f Required: %1.f hours" % (
                              speed_total, loss.avg, learning_rate, epoch, global_step,
                              grad_scaler.get_scale(), time_for_end
                          )
                else:
                    msg = "Speed %.2f samples/sec Loss %.4f LearningRate %.6f Epoch: %d Global Step: %d " \
                          "Required: %1.f hours" % (
                              speed_total, loss.avg, learning_rate, epoch, global_step, time_for_end
                          )
                logging.info(msg)
                loss.reset()
                self.tic = time.time()
            else:
                self.init = True
                self.tic = time.time()
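
A hedged sketch (not from this commit) of how CallBackLogging might be invoked per step, assuming an initialized torch.distributed process group and an AverageMeter tracking the loss; all values below are placeholders.

from torch.cuda.amp import GradScaler

loss_meter = AverageMeter()
callback_log = CallBackLogging(frequent=50, total_step=100000, batch_size=128)
grad_scaler = GradScaler()
for global_step in range(1, 100001):
    loss_value = 0.5                                    # placeholder for the real per-step loss
    loss_meter.update(loss_value, n=1)
    callback_log(global_step, loss_meter, epoch=0, fp16=True,
                 learning_rate=0.1, grad_scaler=grad_scaler)
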
import importlib
import os.path as osp
def get_config(config_file):
    assert config_file.startswith('configs/'), 'config file setting must start with configs/'
    temp_config_name = osp.basename(config_file)
    temp_module_name = osp.splitext(temp_config_name)[0]
    config = importlib.import_module("configs.base")
    cfg = config.config
    config = importlib.import_module("configs.%s" % temp_module_name)
    job_cfg = config.config
    cfg.update(job_cfg)
    if cfg.output is None:
        cfg.output = osp.join('work_dirs', temp_module_name)
    return cfg
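
A usage sketch, assuming a configs/ package that contains base.py and a job config such as ms1mv3_r50.py (hypothetical name), each exposing a dict-like `config` object with attribute access (e.g. an EasyDict):

cfg = get_config("configs/ms1mv3_r50.py")   # hypothetical job config
print(cfg.output)                           # falls back to work_dirs/ms1mv3_r50 if unset
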
import math
import os
import random
import numpy as np
import torch
import torch.distributed as dist
from torch.utils.data import DistributedSampler as _DistributedSampler
def setup_seed(seed, cuda_deterministic=True):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    if cuda_deterministic:  # slower, more reproducible
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    else:  # faster, less reproducible
        torch.backends.cudnn.deterministic = False
        torch.backends.cudnn.benchmark = True
def worker_init_fn(worker_id, num_workers, rank, seed):
    # The seed of each worker equals num_workers * rank + worker_id + user_seed.
    worker_seed = num_workers * rank + worker_id + seed
    np.random.seed(worker_seed)
    random.seed(worker_seed)
    torch.manual_seed(worker_seed)
def get_dist_info():
    if dist.is_available() and dist.is_initialized():
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:
        rank = 0
        world_size = 1
    return rank, world_size
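
A sketch (placeholder dataset and seed) of pairing setup_seed and worker_init_fn with a DataLoader via functools.partial, so each worker derives a distinct but reproducible seed:

from functools import partial
from torch.utils.data import DataLoader, TensorDataset

setup_seed(seed=2048, cuda_deterministic=False)
rank, world_size = get_dist_info()
dataset = TensorDataset(torch.randn(256, 3, 112, 112))   # placeholder dataset
loader = DataLoader(
    dataset,
    batch_size=32,
    num_workers=4,
    worker_init_fn=partial(worker_init_fn, num_workers=4, rank=rank, seed=2048),
)
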
def sync_random_seed(seed=None, device="cuda"):
    """Make sure different ranks share the same seed.

    All workers must call this function, otherwise it will deadlock.
    This method is generally used in `DistributedSampler`,
    because the seed should be identical across all processes
    in the distributed group.

    In distributed sampling, different ranks should sample non-overlapped
    data in the dataset. Therefore, this function is used to make sure that
    each rank shuffles the data indices in the same order based
    on the same seed. Then different ranks could use different indices
    to select non-overlapped data from the same data list.

    Args:
        seed (int, Optional): The seed. Default to None.
        device (str): The device where the seed will be put on.
            Default to 'cuda'.

    Returns:
        int: Seed to be used.
    """
    if seed is None:
        seed = np.random.randint(2**31)
    assert isinstance(seed, int)

    rank, world_size = get_dist_info()

    if world_size == 1:
        return seed

    if rank == 0:
        random_num = torch.tensor(seed, dtype=torch.int32, device=device)
    else:
        random_num = torch.tensor(0, dtype=torch.int32, device=device)

    dist.broadcast(random_num, src=0)

    return random_num.item()
class DistributedSampler(_DistributedSampler):
    def __init__(
        self,
        dataset,
        num_replicas=None,  # world_size
        rank=None,  # local_rank
        shuffle=True,
        seed=0,
    ):
        super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)

        # In distributed sampling, different ranks should sample
        # non-overlapped data in the dataset. Therefore, this function
        # is used to make sure that each rank shuffles the data indices
        # in the same order based on the same seed. Then different ranks
        # could use different indices to select non-overlapped data from the
        # same data list.
        self.seed = sync_random_seed(seed)

    def __iter__(self):
        # deterministically shuffle based on epoch
        if self.shuffle:
            g = torch.Generator()
            # When :attr:`shuffle=True`, this ensures all replicas
            # use a different random ordering for each epoch.
            # Otherwise, the next iteration of this sampler will
            # yield the same ordering.
            g.manual_seed(self.epoch + self.seed)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = torch.arange(len(self.dataset)).tolist()

        # add extra samples to make it evenly divisible
        # in case that indices is shorter than half of total_size
        indices = (indices * math.ceil(self.total_size / len(indices)))[
            : self.total_size
        ]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank : self.total_size : self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)
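
A usage sketch for the seed-synchronized sampler, assuming an initialized torch.distributed process group with CUDA available (sync_random_seed broadcasts a CUDA tensor by default); set_epoch must be called each epoch so the shuffle order changes:

from torch.utils.data import DataLoader, TensorDataset

train_set = TensorDataset(torch.randn(1024, 3, 112, 112))   # placeholder dataset
sampler = DistributedSampler(train_set, shuffle=True, seed=2048)
loader = DataLoader(train_set, batch_size=128, sampler=sampler, num_workers=2)
for epoch in range(2):
    sampler.set_epoch(epoch)        # changes the deterministic shuffle each epoch
    for (batch,) in loader:
        pass                        # training step would go here
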
import logging
import os
import sys
class AverageMeter(object):
    """Computes and stores the average and current value
    """

    def __init__(self):
        self.val = None
        self.avg = None
        self.sum = None
        self.count = None
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
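
A small illustrative example of the running statistics AverageMeter keeps (values are made up):

meter = AverageMeter()
meter.update(0.9, n=128)      # mean loss 0.9 over a 128-sample batch
meter.update(0.7, n=128)
print(meter.val, meter.avg)   # 0.7 (latest value) and 0.8 (sample-weighted average)
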
def init_logging(rank, models_root):
    if rank == 0:
        log_root = logging.getLogger()
        log_root.setLevel(logging.INFO)
        formatter = logging.Formatter("Training: %(asctime)s-%(message)s")
        handler_file = logging.FileHandler(os.path.join(models_root, "training.log"))
        handler_stream = logging.StreamHandler(sys.stdout)
        handler_file.setFormatter(formatter)
        handler_stream.setFormatter(formatter)
        log_root.addHandler(handler_file)
        log_root.addHandler(handler_stream)
        log_root.info('rank_id: %d' % rank)
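
A sketch of initializing logging on rank 0; the output directory is a placeholder and must exist before the FileHandler is created:

os.makedirs("work_dirs/example", exist_ok=True)   # hypothetical output directory
init_logging(rank=0, models_root="work_dirs/example")
logging.info("training started")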