Commit 1ad55bb4 authored by mashun1's avatar mashun1
Browse files

i2vgen-xl

parents
Pipeline #819 canceled with stages
import os, yaml
from copy import deepcopy, copy
# def get prior and ldm config
def assign_prior_mudule_cfg(cfg):
    '''
    Derive two configs from `cfg`: one overridden by the YAML file at
    `cfg.prior_cfg` and one overridden by the YAML file at `cfg.vldm_cfg`.
    Dict-valued overrides whose key already exists in `cfg` are merged
    shallowly; everything else replaces the value outright.
    '''

    def _overlay(target, yaml_path):
        # Load the override file and merge it into `target` in place.
        with open(yaml_path, 'r') as f:
            overrides = yaml.load(f.read(), Loader=yaml.SafeLoader)
        for key, val in overrides.items():
            if isinstance(val, dict) and key in cfg:
                target[key].update(val)
            else:
                target[key] = val

    prior_cfg = deepcopy(cfg)
    vldm_cfg = deepcopy(cfg)
    _overlay(prior_cfg, cfg.prior_cfg)
    _overlay(vldm_cfg, cfg.vldm_cfg)
    return prior_cfg, vldm_cfg
# def get prior and ldm config
def assign_vldm_vsr_mudule_cfg(cfg):
    '''
    Derive two configs from `cfg`: one overridden by the YAML file at
    `cfg.vldm_cfg` and one overridden by the YAML file at `cfg.vsr_cfg`.
    Dict-valued overrides whose key already exists in `cfg` are merged
    shallowly; everything else replaces the value outright.
    '''

    def _overlay(target, yaml_path):
        # Load the override file and merge it into `target` in place.
        with open(yaml_path, 'r') as f:
            overrides = yaml.load(f.read(), Loader=yaml.SafeLoader)
        for key, val in overrides.items():
            if isinstance(val, dict) and key in cfg:
                target[key].update(val)
            else:
                target[key] = val

    vldm_cfg = deepcopy(cfg)
    vsr_cfg = deepcopy(cfg)
    _overlay(vldm_cfg, cfg.vldm_cfg)
    _overlay(vsr_cfg, cfg.vsr_cfg)
    return vldm_cfg, vsr_cfg
# def get prior and ldm config
def assign_signle_cfg(cfg, _cfg_update, tname):
    '''
    Overlay the YAML file referenced by `_cfg_update[tname]` (when that path
    exists) onto a deep copy of `cfg` and return the merged config. Dict
    values whose key already exists in `cfg` are merged shallowly; all other
    values are replaced.
    '''
    merged = deepcopy(cfg)
    if os.path.exists(_cfg_update[tname]):
        with open(_cfg_update[tname], 'r') as f:
            overrides = yaml.load(f.read(), Loader=yaml.SafeLoader)
        for key, val in overrides.items():
            if isinstance(val, dict) and key in cfg:
                merged[key].update(val)
            else:
                merged[key] = val
    return merged
\ No newline at end of file
import os
import yaml
import json
import copy
import argparse
import utils.logging as logging
logger = logging.get_logger(__name__)
class Config(object):
    """Nested experiment configuration built from YAML files plus CLI overrides.

    Load order (later wins): ./configs/base.yaml -> the file given by --cfg,
    recursively merging its _BASE / _BASE_RUN / _BASE_MODEL parents ->
    positional ``opts`` key/value overrides -> the parsed argparse values.
    The final dict is mirrored onto attributes, with nested dicts wrapped in
    child Config instances (so ``cfg.model.dim`` style access works).
    """

    def __init__(self, load=True, cfg_dict=None, cfg_level=None):
        # Dotted path of this (sub-)config, e.g. "cfg" or "cfg.model".
        self._level = "cfg" + ("." + cfg_level if cfg_level is not None else "")
        if load:
            self.args = self._parse_args()
            logger.info("Loading config from {}.".format(self.args.cfg_file))
            self.need_initialization = True
            cfg_base = self._initialize_cfg()
            cfg_dict = self._load_yaml(self.args)
            cfg_dict = self._merge_cfg_from_base(cfg_base, cfg_dict)
            cfg_dict = self._update_from_args(cfg_dict)
            self.cfg_dict = cfg_dict
        # Expose the dict entries as attributes (recursively).
        self._update_dict(cfg_dict)

    def _parse_args(self):
        """Parse CLI arguments: --cfg, --init_method, --debug and trailing opts."""
        parser = argparse.ArgumentParser(
            description="Argparser for configuring [code base name to think of] codebase"
        )
        parser.add_argument(
            "--cfg",
            dest="cfg_file",
            help="Path to the configuration file",
            default='configs/i2vgen_xl_infer.yaml'
        )
        parser.add_argument(
            "--init_method",
            help="Initialization method, includes TCP or shared file-system",
            default="tcp://localhost:9999",
            type=str,
        )
        parser.add_argument(
            '--debug',
            action='store_true',
            default=False,
            help='Into debug information'
        )
        # Everything after the named flags is captured verbatim as overrides.
        parser.add_argument(
            "opts",
            help="other configurations",
            default=None,
            nargs=argparse.REMAINDER)
        return parser.parse_args()

    def _path_join(self, path_list):
        """Join path components with '/' and strip the trailing slash."""
        path = ""
        for p in path_list:
            path+= p + '/'
        return path[:-1]

    def _update_from_args(self, cfg_dict):
        """Copy every parsed argparse attribute into the config dict."""
        args = self.args
        for var in vars(args):
            cfg_dict[var] = getattr(args, var)
        return cfg_dict

    def _initialize_cfg(self):
        """Load the shared base config (configs/base.yaml) on first call."""
        if self.need_initialization:
            self.need_initialization = False
            if os.path.exists('./configs/base.yaml'):
                with open("./configs/base.yaml", 'r') as f:
                    cfg = yaml.load(f.read(), Loader=yaml.SafeLoader)
            else:
                # Fallback: resolve base.yaml relative to this file's package.
                with open(os.path.realpath(__file__).split('/')[-3] + "/configs/base.yaml", 'r') as f:
                    cfg = yaml.load(f.read(), Loader=yaml.SafeLoader)
        # NOTE(review): if need_initialization is already False, `cfg` is
        # unbound here and this raises — confirm callers only call this once.
        return cfg

    def _load_yaml(self, args, file_name=""):
        """Recursively load a YAML config, resolving _BASE/_BASE_RUN/_BASE_MODEL parents.

        `file_name` is empty for the top-level call (reads args.cfg_file) and
        set when recursing into a referenced base file.
        """
        assert args.cfg_file is not None
        if not file_name == "": # reading from base file
            with open(file_name, 'r') as f:
                cfg = yaml.load(f.read(), Loader=yaml.SafeLoader)
        else:
            # Normalize a cfg path given relative to the cwd's directory name.
            if os.getcwd().split("/")[-1] == args.cfg_file.split("/")[0]:
                args.cfg_file = args.cfg_file.replace(os.getcwd().split("/")[-1], "./")
            with open(args.cfg_file, 'r') as f:
                cfg = yaml.load(f.read(), Loader=yaml.SafeLoader)
                file_name = args.cfg_file
        if "_BASE_RUN" not in cfg.keys() and "_BASE_MODEL" not in cfg.keys() and "_BASE" not in cfg.keys():
            # return cfg if the base file is being accessed
            cfg = self._merge_cfg_from_command_update(args, cfg)
            return cfg
        if "_BASE" in cfg.keys():
            if cfg["_BASE"][1] == '.':
                # Relative reference like "../../x.yaml": climb one directory
                # per ".." pair, then append the remainder of the path.
                prev_count = cfg["_BASE"].count('..')
                cfg_base_file = self._path_join(file_name.split('/')[:(-1-cfg["_BASE"].count('..'))] + cfg["_BASE"].split('/')[prev_count:])
            else:
                # "./"-style reference: resolve against the cfg file's directory.
                cfg_base_file = cfg["_BASE"].replace(
                    "./",
                    args.cfg_file.replace(args.cfg_file.split('/')[-1], "")
                )
            cfg_base = self._load_yaml(args, cfg_base_file)
            cfg = self._merge_cfg_from_base(cfg_base, cfg)
        else:
            if "_BASE_RUN" in cfg.keys():
                if cfg["_BASE_RUN"][1] == '.':
                    prev_count = cfg["_BASE_RUN"].count('..')
                    cfg_base_file = self._path_join(file_name.split('/')[:(-1-prev_count)] + cfg["_BASE_RUN"].split('/')[prev_count:])
                else:
                    cfg_base_file = cfg["_BASE_RUN"].replace(
                        "./",
                        args.cfg_file.replace(args.cfg_file.split('/')[-1], "")
                    )
                cfg_base = self._load_yaml(args, cfg_base_file)
                # Keep *_BASE_* keys so _BASE_MODEL below can still be found.
                cfg = self._merge_cfg_from_base(cfg_base, cfg, preserve_base=True)
            if "_BASE_MODEL" in cfg.keys():
                if cfg["_BASE_MODEL"][1] == '.':
                    prev_count = cfg["_BASE_MODEL"].count('..')
                    cfg_base_file = self._path_join(file_name.split('/')[:(-1-cfg["_BASE_MODEL"].count('..'))] + cfg["_BASE_MODEL"].split('/')[prev_count:])
                else:
                    cfg_base_file = cfg["_BASE_MODEL"].replace(
                        "./",
                        args.cfg_file.replace(args.cfg_file.split('/')[-1], "")
                    )
                cfg_base = self._load_yaml(args, cfg_base_file)
                cfg = self._merge_cfg_from_base(cfg_base, cfg)
        cfg = self._merge_cfg_from_command(args, cfg)
        return cfg

    def _merge_cfg_from_base(self, cfg_base, cfg_new, preserve_base=False):
        """Recursively overlay `cfg_new` onto `cfg_base`; returns cfg_base.

        New keys containing "BASE" are dropped unless preserve_base=True.
        """
        for k,v in cfg_new.items():
            if k in cfg_base.keys():
                if isinstance(v, dict):
                    self._merge_cfg_from_base(cfg_base[k], v)
                else:
                    cfg_base[k] = v
            else:
                if "BASE" not in k or preserve_base:
                    cfg_base[k] = v
        return cfg_base

    def _merge_cfg_from_command_update(self, args, cfg):
        """Apply flat ``KEY VALUE`` pairs from the positional `opts` to cfg."""
        if len(args.opts) == 0:
            return cfg
        assert len(args.opts) % 2 == 0, 'Override list {} has odd length: {}.'.format(
            args.opts, len(args.opts)
        )
        keys = args.opts[0::2]
        vals = args.opts[1::2]
        for key, val in zip(keys, vals):
            cfg[key] = val
        return cfg

    def _merge_cfg_from_command(self, args, cfg):
        """Apply dotted ``a.b.c VALUE`` overrides from `opts` (up to 4 key levels)."""
        assert len(args.opts) % 2 == 0, 'Override list {} has odd length: {}.'.format(
            args.opts, len(args.opts)
        )
        keys = args.opts[0::2]
        vals = args.opts[1::2]

        # maximum supported depth 3
        for idx, key in enumerate(keys):
            key_split = key.split('.')
            assert len(key_split) <= 4, 'Key depth error. \nMaximum depth: 3\n Get depth: {}'.format(
                len(key_split)
            )
            assert key_split[0] in cfg.keys(), 'Non-existant key: {}.'.format(
                key_split[0]
            )
            # Validate every intermediate key exists before assigning.
            if len(key_split) == 2:
                assert key_split[1] in cfg[key_split[0]].keys(), 'Non-existant key: {}.'.format(
                    key
                )
            elif len(key_split) == 3:
                assert key_split[1] in cfg[key_split[0]].keys(), 'Non-existant key: {}.'.format(
                    key
                )
                assert key_split[2] in cfg[key_split[0]][key_split[1]].keys(), 'Non-existant key: {}.'.format(
                    key
                )
            elif len(key_split) == 4:
                assert key_split[1] in cfg[key_split[0]].keys(), 'Non-existant key: {}.'.format(
                    key
                )
                assert key_split[2] in cfg[key_split[0]][key_split[1]].keys(), 'Non-existant key: {}.'.format(
                    key
                )
                assert key_split[3] in cfg[key_split[0]][key_split[1]][key_split[2]].keys(), 'Non-existant key: {}.'.format(
                    key
                )
            if len(key_split) == 1:
                cfg[key_split[0]] = vals[idx]
            elif len(key_split) == 2:
                cfg[key_split[0]][key_split[1]] = vals[idx]
            elif len(key_split) == 3:
                cfg[key_split[0]][key_split[1]][key_split[2]] = vals[idx]
            elif len(key_split) == 4:
                cfg[key_split[0]][key_split[1]][key_split[2]][key_split[3]] = vals[idx]
        return cfg

    def _update_dict(self, cfg_dict):
        """Mirror `cfg_dict` onto attributes; nested dicts become child Configs."""
        def recur(key, elem):
            if type(elem) is dict:
                return key, Config(load=False, cfg_dict=elem, cfg_level=key)
            else:
                # YAML leaves scientific notation like "1e-4" as str; coerce.
                if type(elem) is str and elem[1:3]=="e-":
                    elem = float(elem)
                return key, elem
        dic = dict(recur(k, v) for k, v in cfg_dict.items())
        self.__dict__.update(dic)

    def get_args(self):
        """Return the parsed argparse namespace."""
        return self.args

    def __repr__(self):
        return "{}\n".format(self.dump())

    def dump(self):
        """Return the underlying config dict as pretty-printed JSON."""
        return json.dumps(self.cfg_dict, indent=2)

    def deep_copy(self):
        """Return a deep copy of this Config."""
        return copy.deepcopy(self)
if __name__ == '__main__':
    # debug
    # Smoke test: parse CLI args, load the YAML config and print its DATA section.
    cfg = Config(load=True)
    print(cfg.DATA)
\ No newline at end of file
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import torch
import torch.nn.functional as F
import torch.distributed as dist
import functools
import pickle
import numpy as np
from collections import OrderedDict
from torch.autograd import Function
# Public names exported via `from <this module> import *`.
__all__ = ['is_dist_initialized',
           'get_world_size',
           'get_rank',
           'new_group',
           'destroy_process_group',
           'barrier',
           'broadcast',
           'all_reduce',
           'reduce',
           'gather',
           'all_gather',
           'reduce_dict',
           'get_global_gloo_group',
           'generalized_all_gather',
           'generalized_gather',
           'scatter',
           'reduce_scatter',
           'send',
           'recv',
           'isend',
           'irecv',
           'shared_random_seed',
           'diff_all_gather',
           'diff_all_reduce',
           'diff_scatter',
           'diff_copy',
           'spherical_kmeans',
           'sinkhorn']
#-------------------------------- Distributed operations --------------------------------#
def is_dist_initialized():
    """Return True when torch.distributed is both available and initialized."""
    if not dist.is_available():
        return False
    return dist.is_initialized()
def get_world_size(group=None):
    """World size of `group`; 1 when not running distributed."""
    if is_dist_initialized():
        return dist.get_world_size(group)
    return 1
def get_rank(group=None):
    """Rank of this process within `group`; 0 when not running distributed."""
    if is_dist_initialized():
        return dist.get_rank(group)
    return 0
def new_group(ranks=None, **kwargs):
    """Create a new process group, or return None when not running distributed."""
    if not is_dist_initialized():
        return None
    return dist.new_group(ranks, **kwargs)
def destroy_process_group():
    """Tear down the default process group if one is active; no-op otherwise."""
    if not is_dist_initialized():
        return
    dist.destroy_process_group()
def barrier(group=None, **kwargs):
    """Synchronize all processes in `group`; no-op for a single process."""
    if get_world_size(group) <= 1:
        return
    dist.barrier(group, **kwargs)
def broadcast(tensor, src, group=None, **kwargs):
    """Broadcast `tensor` from rank `src`; no-op for a single process."""
    if get_world_size(group) <= 1:
        return None
    return dist.broadcast(tensor, src, group, **kwargs)
def all_reduce(tensor, op=dist.ReduceOp.SUM, group=None, **kwargs):
    """In-place all-reduce of `tensor` with `op`; no-op for a single process."""
    if get_world_size(group) <= 1:
        return None
    return dist.all_reduce(tensor, op, group, **kwargs)
def reduce(tensor, dst, op=dist.ReduceOp.SUM, group=None, **kwargs):
    """Reduce `tensor` onto rank `dst` with `op`; no-op for a single process."""
    if get_world_size(group) <= 1:
        return None
    return dist.reduce(tensor, dst, op, group, **kwargs)
def gather(tensor, dst=0, group=None, **kwargs):
    """Gather `tensor` from every rank onto rank `dst`.

    Returns the list of gathered tensors on `dst`, None on other ranks,
    and `[tensor]` when running single-process.
    """
    world_size = get_world_size(group)
    if world_size == 1:
        return [tensor]
    rank = get_rank()  # global rank
    if rank == dst:
        out = [torch.empty_like(tensor) for _ in range(world_size)]
    else:
        out = None
    dist.gather(tensor, out, dst, group, **kwargs)
    return out
def all_gather(tensor, uniform_size=True, group=None, **kwargs):
    """Gather `tensor` from every rank and return the per-rank list.

    With uniform_size=False, per-rank shapes may differ: shapes are exchanged
    first, each tensor is flattened and zero-padded to the largest element
    count, gathered, then trimmed and reshaped to each rank's original shape.
    """
    world_size = get_world_size(group)
    if world_size == 1:
        return [tensor]
    assert tensor.is_contiguous(), 'ops.all_gather requires the tensor to be contiguous()'
    if uniform_size:
        tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
        dist.all_gather(tensor_list, tensor, group, **kwargs)
        return tensor_list
    else:
        # collect tensor shapes across GPUs
        shape = tuple(tensor.shape)
        shape_list = generalized_all_gather(shape, group)

        # flatten the tensor
        tensor = tensor.reshape(-1)
        size = int(np.prod(shape))
        size_list = [int(np.prod(u)) for u in shape_list]
        max_size = max(size_list)

        # pad to maximum size
        if size != max_size:
            padding = tensor.new_zeros(max_size - size)
            tensor = torch.cat([tensor, padding], dim=0)

        # all_gather
        tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
        dist.all_gather(tensor_list, tensor, group, **kwargs)

        # reshape tensors: trim padding, restore each rank's original shape
        tensor_list = [t[:n].view(s) for t, n, s in zip(
            tensor_list, size_list, shape_list)]
        return tensor_list
@torch.no_grad()
def reduce_dict(input_dict, group=None, reduction='mean', **kwargs):
    """Reduce the tensor values of `input_dict` across all ranks.

    All ranks must hold the same keys with same-shaped tensor values.
    Returns a dict of the same type with reduced ('sum' or 'mean') values;
    the input is returned unchanged when running single-process.
    """
    assert reduction in ['mean', 'sum']
    world_size = get_world_size(group)
    if world_size == 1:
        return input_dict

    # ensure that the orders of keys are consistent across processes
    if isinstance(input_dict, OrderedDict):
        # Fix: `.keys` was passed un-called, making list() raise TypeError.
        keys = list(input_dict.keys())
    else:
        keys = sorted(input_dict.keys())
    vals = [input_dict[key] for key in keys]
    vals = torch.stack(vals, dim=0)

    # reduce onto rank 0, average there, then broadcast the result back
    dist.reduce(vals, dst=0, group=group, **kwargs)
    if dist.get_rank(group) == 0 and reduction == 'mean':
        vals /= world_size
    dist.broadcast(vals, src=0, group=group, **kwargs)

    reduced_dict = type(input_dict)([
        (key, val) for key, val in zip(keys, vals)])
    return reduced_dict
@functools.lru_cache()
def get_global_gloo_group():
    """Return a gloo-backend group for CPU byte-tensor collectives (cached).

    NCCL cannot transport the pickled CPU tensors used by the generalized
    gather helpers, so a parallel gloo group is created in that case.
    """
    backend = dist.get_backend()
    assert backend in ['gloo', 'nccl']
    return dist.new_group(backend='gloo') if backend == 'nccl' else dist.group.WORLD
def _serialize_to_tensor(data, group):
    """Pickle `data` into a 1-D uint8 tensor on the device matching `group`'s backend.

    gloo communicates via CPU tensors, nccl via CUDA tensors.
    """
    backend = dist.get_backend(group)
    assert backend in ['gloo', 'nccl']
    device = torch.device('cpu' if backend == 'gloo' else 'cuda')

    buffer = pickle.dumps(data)
    if len(buffer) > 1024 ** 3:
        # Fix: `logging` was never imported at module level, so this warning
        # path raised NameError for payloads over 1 GB.
        import logging
        logger = logging.getLogger(__name__)
        logger.warning(
            'Rank {} trying to all-gather {:.2f} GB of data on device'
            '{}'.format(get_rank(), len(buffer) / (1024 ** 3), device))
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to(device=device)
    return tensor
def _pad_to_largest_tensor(tensor, group):
    """Zero-pad a 1-D uint8 `tensor` so every rank holds the same length.

    Returns (size_list, padded_tensor) where size_list holds each rank's
    original element count, used later to trim the gathered payloads.
    """
    world_size = dist.get_world_size(group=group)
    assert world_size >= 1, \
        'gather/all_gather must be called from ranks within' \
        'the give group!'
    local_size = torch.tensor(
        [tensor.numel()], dtype=torch.int64, device=tensor.device)
    size_list = [torch.zeros(
        [1], dtype=torch.int64, device=tensor.device)
        for _ in range(world_size)]

    # gather tensors and compute the maximum size
    dist.all_gather(size_list, local_size, group=group)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    # pad tensors to the same size
    if local_size != max_size:
        padding = torch.zeros(
            (max_size - local_size, ),
            dtype=torch.uint8, device=tensor.device)
        tensor = torch.cat((tensor, padding), dim=0)
    return size_list, tensor
def generalized_all_gather(data, group=None):
    """All-gather an arbitrary picklable `data` object from every rank.

    Serializes via pickle, pads the byte tensors to a common length, gathers,
    then unpickles each rank's payload. Returns a list ordered by rank.
    """
    if get_world_size(group) == 1:
        return [data]
    if group is None:
        # Default to a gloo group so byte tensors can travel over CPU.
        group = get_global_gloo_group()

    tensor = _serialize_to_tensor(data, group)
    size_list, tensor = _pad_to_largest_tensor(tensor, group)
    max_size = max(size_list)

    # receiving tensors from all ranks
    tensor_list = [torch.empty(
        (max_size, ), dtype=torch.uint8, device=tensor.device)
        for _ in size_list]
    dist.all_gather(tensor_list, tensor, group=group)

    # trim each payload to its true size before unpickling
    data_list = []
    for size, tensor in zip(size_list, tensor_list):
        buffer = tensor.cpu().numpy().tobytes()[:size]
        data_list.append(pickle.loads(buffer))
    return data_list
def generalized_gather(data, dst=0, group=None):
    """Gather an arbitrary picklable `data` object from every rank onto `dst`.

    Returns the list of unpickled per-rank objects on rank `dst`, an empty
    list on other ranks, and `[data]` when running single-process.
    """
    world_size = get_world_size(group)
    if world_size == 1:
        return [data]
    if group is None:
        # Default to a gloo group so byte tensors can travel over CPU.
        group = get_global_gloo_group()
    rank = dist.get_rank()  # global rank
    tensor = _serialize_to_tensor(data, group)
    size_list, tensor = _pad_to_largest_tensor(tensor, group)

    # receiving tensors from all ranks to dst
    if rank == dst:
        max_size = max(size_list)
        tensor_list = [torch.empty(
            (max_size, ), dtype=torch.uint8, device=tensor.device)
            for _ in size_list]
        dist.gather(tensor, tensor_list, dst=dst, group=group)

        # trim each payload to its true size before unpickling
        data_list = []
        for size, tensor in zip(size_list, tensor_list):
            buffer = tensor.cpu().numpy().tobytes()[:size]
            data_list.append(pickle.loads(buffer))
        return data_list
    else:
        dist.gather(tensor, [], dst=dst, group=group)
        return []
def scatter(data, scatter_list=None, src=0, group=None, **kwargs):
    r"""Scatter `scatter_list` from rank `src` into `data`.

    No-op for a single process. NOTE: only supports CPU tensor communication.
    """
    if get_world_size(group) <= 1:
        return None
    return dist.scatter(data, scatter_list, src, group, **kwargs)
def reduce_scatter(output, input_list, op=dist.ReduceOp.SUM, group=None, **kwargs):
    """Reduce `input_list` across ranks and scatter the chunks; single-process no-op."""
    if get_world_size(group) <= 1:
        return None
    return dist.reduce_scatter(output, input_list, op, group, **kwargs)
def send(tensor, dst, group=None, **kwargs):
    """Blocking send of `tensor` to rank `dst`; no-op for a single process."""
    if get_world_size(group) <= 1:
        return None
    assert tensor.is_contiguous(), 'ops.send requires the tensor to be contiguous()'
    return dist.send(tensor, dst, group, **kwargs)
def recv(tensor, src=None, group=None, **kwargs):
    """Blocking receive into `tensor` from rank `src`; no-op for a single process."""
    if get_world_size(group) <= 1:
        return None
    assert tensor.is_contiguous(), 'ops.recv requires the tensor to be contiguous()'
    return dist.recv(tensor, src, group, **kwargs)
def isend(tensor, dst, group=None, **kwargs):
    """Non-blocking send of `tensor` to rank `dst`; no-op for a single process."""
    if get_world_size(group) <= 1:
        return None
    assert tensor.is_contiguous(), 'ops.isend requires the tensor to be contiguous()'
    return dist.isend(tensor, dst, group, **kwargs)
def irecv(tensor, src=None, group=None, **kwargs):
    """Non-blocking receive into `tensor` from rank `src`; no-op for a single process."""
    if get_world_size(group) <= 1:
        return None
    assert tensor.is_contiguous(), 'ops.irecv requires the tensor to be contiguous()'
    return dist.irecv(tensor, src, group, **kwargs)
def shared_random_seed(group=None):
    """Draw a random seed and make every rank in `group` adopt rank 0's value."""
    local_seed = np.random.randint(2 ** 31)
    return generalized_all_gather(local_seed, group)[0]
#-------------------------------- Differentiable operations --------------------------------#
def _all_gather(x):
if not (dist.is_available() and dist.is_initialized()) or dist.get_world_size() == 1:
return x
rank = dist.get_rank()
world_size = dist.get_world_size()
tensors = [torch.empty_like(x) for _ in range(world_size)]
tensors[rank] = x
dist.all_gather(tensors, x)
return torch.cat(tensors, dim=0).contiguous()
def _all_reduce(x):
if not (dist.is_available() and dist.is_initialized()) or dist.get_world_size() == 1:
return x
dist.all_reduce(x)
return x
def _split(x):
if not (dist.is_available() and dist.is_initialized()) or dist.get_world_size() == 1:
return x
rank = dist.get_rank()
world_size = dist.get_world_size()
return x.chunk(world_size, dim=0)[rank].contiguous()
class DiffAllGather(Function):
    r"""Autograd-aware all-gather: gathers in forward, splits gradients in backward."""

    @staticmethod
    def symbolic(graph, input):
        return _all_gather(input)

    @staticmethod
    def forward(ctx, input):
        gathered = _all_gather(input)
        return gathered

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = _split(grad_output)
        return grad_input
class DiffAllReduce(Function):
    r"""Autograd-aware all-reduce: reduces in forward, passes gradients through."""

    @staticmethod
    def symbolic(graph, input):
        return _all_reduce(input)

    @staticmethod
    def forward(ctx, input):
        reduced = _all_reduce(input)
        return reduced

    @staticmethod
    def backward(ctx, grad_output):
        # The reduction is linear, so the gradient flows through unchanged.
        return grad_output
class DiffScatter(Function):
    r"""Differentiable scatter: splits in forward, all-gathers gradients in backward."""

    @staticmethod
    def symbolic(graph, input):
        return _split(input)

    @staticmethod
    def forward(ctx, input):
        # Fix: this was mistakenly declared as a second `symbolic`, leaving the
        # Function without a `forward` and breaking DiffScatter.apply().
        return _split(input)

    @staticmethod
    def backward(ctx, grad_output):
        return _all_gather(grad_output)
class DiffCopy(Function):
    r"""Differentiable copy that reduces all gradients during backward."""

    @staticmethod
    def symbolic(graph, input):
        return input

    @staticmethod
    def forward(ctx, input):
        # Identity in the forward direction.
        return input

    @staticmethod
    def backward(ctx, grad_output):
        # Sum gradients across ranks so each replica sees the full gradient.
        return _all_reduce(grad_output)
# Functional aliases for the differentiable collectives defined above.
diff_all_gather = DiffAllGather.apply
diff_all_reduce = DiffAllReduce.apply
diff_scatter = DiffScatter.apply
diff_copy = DiffCopy.apply
#-------------------------------- Distributed algorithms --------------------------------#
@torch.no_grad()
def spherical_kmeans(feats, num_clusters, num_iters=10):
    """Distributed spherical k-means over sharded features.

    Args:
        feats: (n, c) feature matrix held by this rank (one shard per rank).
            NOTE(review): assumes rows are L2-normalized so `feats @ clusters.T`
            acts as cosine similarity — confirm at call sites.
        num_clusters: number of centroids k.
        num_iters: number of EM iterations.

    Returns:
        (clusters, assigns, scores): (k, c) L2-normalized centroids, the
        per-sample centroid index, and the matching similarity score.
    """
    k, n, c = num_clusters, *feats.size()
    ones = feats.new_ones(n, dtype=torch.long)

    # distributed settings
    rank = get_rank()
    world_size = get_world_size()

    # init clusters: each rank contributes ~k/world_size random rows, then the
    # gathered pool is truncated to exactly k centroids.
    rand_inds = torch.randperm(n)[:int(np.ceil(k / world_size))]
    clusters = torch.cat(all_gather(feats[rand_inds]), dim=0)[:k]

    # variables (reused across iterations to avoid reallocation)
    new_clusters = feats.new_zeros(k, c)
    counts = feats.new_zeros(k, dtype=torch.long)

    # iterative Expectation-Maximization
    for step in range(num_iters + 1):
        # Expectation step: assign each sample to its most similar centroid
        simmat = torch.mm(feats, clusters.t())
        scores, assigns = simmat.max(dim=1)
        # The extra final pass only refreshes assigns/scores for the output.
        if step == num_iters:
            break

        # Maximization step: sum member features per cluster across all ranks
        new_clusters.zero_().scatter_add_(0, assigns.unsqueeze(1).repeat(1, c), feats)
        all_reduce(new_clusters)
        counts.zero_()
        counts.index_add_(0, assigns, ones)
        all_reduce(counts)
        # Only update non-empty clusters, then re-project onto the unit sphere
        mask = (counts > 0)
        clusters[mask] = new_clusters[mask] / counts[mask].view(-1, 1)
        clusters = F.normalize(clusters, p=2, dim=1)
    return clusters, assigns, scores
@torch.no_grad()
def sinkhorn(Q, eps=0.5, num_iters=3):
    """Distributed iterative row/column (Sinkhorn) normalization of a score matrix.

    Args:
        Q: (batch, m) score matrix held by this rank; batches are sharded
            across ranks while the m columns are shared.
        eps: temperature applied before the exponential.
        num_iters: number of row/column normalization rounds.

    Returns:
        (batch, m) float matrix whose columns sum to 1 after the final
        column normalization.
    """
    # normalize Q: exponentiate and divide by the global sum across all ranks
    Q = torch.exp(Q / eps).t()
    sum_Q = Q.sum()
    all_reduce(sum_Q)
    Q /= sum_Q

    # variables: target row marginal r and column marginal c
    n, m = Q.size()
    u = Q.new_zeros(n)
    r = Q.new_ones(n) / n
    # columns are split across ranks, so scale the marginal by world size
    c = Q.new_ones(m) / (m * get_world_size())

    # iterative update: alternate row scaling (global) and column scaling (local)
    cur_sum = Q.sum(dim=1)
    all_reduce(cur_sum)
    for i in range(num_iters):
        u = cur_sum
        Q *= (r / u).unsqueeze(1)
        Q *= (c / Q.sum(dim=0)).unsqueeze(0)
        cur_sum = Q.sum(dim=1)
        all_reduce(cur_sum)
    # final column normalization; transpose back to (batch, m)
    return (Q / Q.sum(dim=0, keepdim=True)).t().float()
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
"""Logging."""
import builtins
import decimal
import functools
import logging
import os
import sys
import simplejson
# from fvcore.common.file_io import PathManager
import utils.distributed as du
def _suppress_print():
"""
Suppresses printing from the current process.
"""
def print_pass(*objects, sep=" ", end="\n", file=sys.stdout, flush=False):
pass
builtins.print = print_pass
# @functools.lru_cache(maxsize=None)
# def _cached_log_stream(filename):
# return PathManager.open(filename, "a")
def setup_logging(cfg, log_file):
    """
    Sets up the logging for multiple processes. Only enable the logging for the
    master process, and suppress logging for the non-master processes.

    Args:
        cfg: config object; cfg.OUTPUT_DIR is used as the log directory.
        log_file: file name for the log file, or None to skip file logging.
    """
    if du.is_master_proc():
        # Enable logging for the master process.
        logging.root.handlers = []
    else:
        # Suppress logging for non-master processes.
        _suppress_print()

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    # Avoid double emission through ancestor loggers.
    logger.propagate = False
    plain_formatter = logging.Formatter(
        "[%(asctime)s][%(levelname)s] %(name)s: %(lineno)4d: %(message)s",
        datefmt="%m/%d %H:%M:%S",
    )

    if du.is_master_proc():
        # Console handler on the master process only.
        ch = logging.StreamHandler(stream=sys.stdout)
        ch.setLevel(logging.DEBUG)
        ch.setFormatter(plain_formatter)
        logger.addHandler(ch)

    # NOTE(review): file logging is gated on being the master across *all*
    # processes (du.is_master_proc(du.get_world_size())) — confirm intent.
    if log_file is not None and du.is_master_proc(du.get_world_size()):
        filename = os.path.join(cfg.OUTPUT_DIR, log_file)
        fh = logging.FileHandler(filename)
        fh.setLevel(logging.DEBUG)
        fh.setFormatter(plain_formatter)
        logger.addHandler(fh)
def get_logger(name):
    """
    Retrieve the logger registered under `name`; passing None yields the
    root logger of the hierarchy.

    Args:
        name (string): name of the logger.
    """
    return logging.getLogger(name)
def log_json_stats(stats):
    """
    Serialize `stats` to JSON (floats rendered as 6-decimal Decimals) and
    log the result.

    Args:
        stats (dict): a dictionary of statistical information to log.
    """
    formatted = {}
    for k, v in stats.items():
        formatted[k] = decimal.Decimal("{:.6f}".format(v)) if isinstance(v, float) else v
    json_stats = simplejson.dumps(formatted, sort_keys=True, use_decimal=True)
    get_logger(__name__).info("{:s}".format(json_stats))
import socket
from contextlib import closing
def find_free_port():
    """Return a currently-free TCP port number as a string.

    https://stackoverflow.com/questions/1365265/on-localhost-how-do-i-pick-a-free-port-number
    """
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with closing(sock) as s:
        s.bind(('', 0))
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        return str(s.getsockname()[1])
\ No newline at end of file
from .lr_scheduler import *
from .adafactor import *
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment