Commit a1c29028 authored by zhangqha

update uni-fold

# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from functools import lru_cache
from . import BaseWrapperDataset
class LRUCacheDataset(BaseWrapperDataset):
def __init__(self, dataset, token=None):
super().__init__(dataset)
@lru_cache(maxsize=16)
def __getitem__(self, index):
return self.dataset[index]
@lru_cache(maxsize=16)
def collater(self, samples):
return self.dataset.collater(samples)
# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from functools import lru_cache
import numpy as np
import torch
from unicore.data import Dictionary, data_utils
from . import BaseWrapperDataset, LRUCacheDataset
class MaskTokensDataset(BaseWrapperDataset):
@classmethod
def apply_mask(cls, dataset: torch.utils.data.Dataset, *args, **kwargs):
"""Return the source and target datasets for masked LM training."""
dataset = LRUCacheDataset(dataset)
return (
LRUCacheDataset(cls(dataset, *args, **kwargs, return_masked_tokens=False)),
LRUCacheDataset(cls(dataset, *args, **kwargs, return_masked_tokens=True)),
)
def __init__(
self,
dataset: torch.utils.data.Dataset,
vocab: Dictionary,
pad_idx: int,
mask_idx: int,
return_masked_tokens: bool = False,
seed: int = 1,
mask_prob: float = 0.15,
leave_unmasked_prob: float = 0.1,
random_token_prob: float = 0.1,
):
assert 0.0 < mask_prob < 1.0
assert 0.0 <= random_token_prob <= 1.0
assert 0.0 <= leave_unmasked_prob <= 1.0
assert random_token_prob + leave_unmasked_prob <= 1.0
self.dataset = dataset
self.vocab = vocab
self.pad_idx = pad_idx
self.mask_idx = mask_idx
self.return_masked_tokens = return_masked_tokens
self.seed = seed
self.mask_prob = mask_prob
self.leave_unmasked_prob = leave_unmasked_prob
self.random_token_prob = random_token_prob
if random_token_prob > 0.0:
weights = np.ones(len(self.vocab))
weights[vocab.special_index()] = 0
self.weights = weights / weights.sum()
self.epoch = None
@property
def can_reuse_epoch_itr_across_epochs(self):
return True # only the noise changes, not item sizes
def set_epoch(self, epoch, **unused):
super().set_epoch(epoch)
self.epoch = epoch
def __getitem__(self, index: int):
return self.__getitem_cached__(self.epoch, index)
@lru_cache(maxsize=16)
def __getitem_cached__(self, epoch: int, index: int):
with data_utils.numpy_seed(self.seed, epoch, index):
item = self.dataset[index]
sz = len(item)
# require at least one maskable token between the first and last positions
assert sz > 2
assert (
self.mask_idx not in item
), "Dataset contains mask_idx (={}), this is not expected!".format(
self.mask_idx,
)
# decide elements to mask
mask = np.full(sz, False)
num_mask = int(
# add a random number for probabilistic rounding
self.mask_prob * (sz - 2) + np.random.rand()
)
# don't mask first and last position
mask_idc = np.random.choice(sz - 2, num_mask, replace=False) + 1
mask[mask_idc] = True
if self.return_masked_tokens:
new_item = np.full(len(mask), self.pad_idx)
new_item[mask] = item[torch.from_numpy(mask.astype(np.uint8)) == 1]
return torch.from_numpy(new_item)
# decide unmasking and random replacement
rand_or_unmask_prob = self.random_token_prob + self.leave_unmasked_prob
if rand_or_unmask_prob > 0.0:
rand_or_unmask = mask & (np.random.rand(sz) < rand_or_unmask_prob)
if self.random_token_prob == 0.0:
unmask = rand_or_unmask
rand_mask = None
elif self.leave_unmasked_prob == 0.0:
unmask = None
rand_mask = rand_or_unmask
else:
unmask_prob = self.leave_unmasked_prob / rand_or_unmask_prob
decision = np.random.rand(sz) < unmask_prob
unmask = rand_or_unmask & decision
rand_mask = rand_or_unmask & (~decision)
else:
unmask = rand_mask = None
if unmask is not None:
mask = mask ^ unmask
new_item = np.copy(item)
new_item[mask] = self.mask_idx
if rand_mask is not None:
num_rand = rand_mask.sum()
if num_rand > 0:
new_item[rand_mask] = np.random.choice(
len(self.vocab),
num_rand,
p=self.weights,
)
return torch.from_numpy(new_item)
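# Usage sketch (hypothetical, not part of this file): apply_mask pairs a masked
# source dataset with its target, assuming `token_dataset` yields 1-D LongTensors
# that already include leading/trailing special tokens and `vocab` is a loaded
# unicore Dictionary exposing `pad()`.
#
#   src_dataset, tgt_dataset = MaskTokensDataset.apply_mask(
#       token_dataset,
#       vocab=vocab,
#       pad_idx=vocab.pad(),
#       mask_idx=mask_idx,
#       seed=1,
#       mask_prob=0.15,
#   )
#   # src_dataset[i]: the sequence with ~15% of inner tokens replaced by mask_idx
#   # (or kept / randomized per leave_unmasked_prob / random_token_prob);
#   # tgt_dataset[i]: pad_idx everywhere except the masked positions, which keep
#   # the original token ids.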
# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from collections import OrderedDict
import torch
from torch.utils.data.dataloader import default_collate
from . import UnicoreDataset
def _flatten(dico, prefix=None):
"""Flatten a nested dictionary."""
new_dico = OrderedDict()
if isinstance(dico, dict):
prefix = prefix + "." if prefix is not None else ""
for k, v in dico.items():
if v is None:
continue
new_dico.update(_flatten(v, prefix + k))
elif isinstance(dico, list):
for i, v in enumerate(dico):
new_dico.update(_flatten(v, prefix + ".[" + str(i) + "]"))
else:
new_dico = OrderedDict({prefix: dico})
return new_dico
def _unflatten(dico):
"""Unflatten a flattened dictionary into a nested dictionary."""
new_dico = OrderedDict()
for full_k, v in dico.items():
full_k = full_k.split(".")
node = new_dico
for k in full_k[:-1]:
if k.startswith("[") and k.endswith("]"):
k = int(k[1:-1])
if k not in node:
node[k] = OrderedDict()
node = node[k]
node[full_k[-1]] = v
return new_dico
class NestedDictionaryDataset(UnicoreDataset):
def __init__(self, defn):
super().__init__()
self.defn = _flatten(defn)
first = None
for v in self.defn.values():
if not isinstance(
v,
(
UnicoreDataset,
torch.utils.data.Dataset,
),
):
raise ValueError("Expected Dataset but found: {}".format(v.__class__))
first = first or v
if len(v) > 0:
assert len(v) == len(first), "dataset lengths must match"
self._len = len(first)
def __getitem__(self, index):
return OrderedDict((k, ds[index]) for k, ds in self.defn.items())
def __len__(self):
return self._len
def collater(self, samples):
"""Merge a list of samples to form a mini-batch.
Args:
samples (List[dict]): samples to collate
Returns:
dict: a mini-batch suitable for forwarding with a Model
"""
if len(samples) == 0:
return {}
sample = OrderedDict()
for k, ds in self.defn.items():
try:
sample[k] = ds.collater([s[k] for s in samples])
except NotImplementedError:
sample[k] = default_collate([s[k] for s in samples])
return _unflatten(sample)
@property
def supports_prefetch(self):
"""Whether this dataset supports prefetching."""
return any(ds.supports_prefetch for ds in self.defn.values())
def prefetch(self, indices):
"""Prefetch the data required for this epoch."""
for ds in self.defn.values():
if getattr(ds, "supports_prefetch", False):
ds.prefetch(indices)
@property
def can_reuse_epoch_itr_across_epochs(self):
return all(ds.can_reuse_epoch_itr_across_epochs for ds in self.defn.values())
def set_epoch(self, epoch):
super().set_epoch(epoch)
for ds in self.defn.values():
ds.set_epoch(epoch)
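# Illustrative example (hypothetical): a nested definition is flattened internally,
# indexed jointly, and re-nested by collater, so batches come back with the same
# structure that was passed in. The pad/count wrappers used here are the ones
# defined elsewhere in this commit; `vocab.pad()` is an assumed Dictionary helper.
#
#   dataset = NestedDictionaryDataset({
#       "net_input": {
#           "src_tokens": RightPadDataset(src_dataset, pad_idx=vocab.pad()),
#           "src_lengths": NumelDataset(src_dataset, reduce=False),
#       },
#       "target": RightPadDataset(tgt_dataset, pad_idx=vocab.pad()),
#       "nsentences": NumSamplesDataset(),
#   })
#   batch = dataset.collater([dataset[0], dataset[1]])
#   # batch["net_input"]["src_tokens"] is a padded 2-D tensor and
#   # batch["nsentences"] == 2.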
# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from . import UnicoreDataset
class NumSamplesDataset(UnicoreDataset):
def __getitem__(self, index):
return 1
def __len__(self):
return 0
def collater(self, samples):
return sum(samples)
# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import torch
from . import BaseWrapperDataset
class NumelDataset(BaseWrapperDataset):
def __init__(self, dataset, reduce=False):
super().__init__(dataset)
self.reduce = reduce
def __getitem__(self, index):
item = self.dataset[index]
if torch.is_tensor(item):
return torch.numel(item)
else:
return np.size(item)
def __len__(self):
return len(self.dataset)
def collater(self, samples):
if self.reduce:
return sum(samples)
else:
return torch.tensor(samples)
# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from unicore.data import data_utils
from . import BaseWrapperDataset
class PadDataset(BaseWrapperDataset):
def __init__(self, dataset, pad_idx, left_pad):
super().__init__(dataset)
self.pad_idx = pad_idx
self.left_pad = left_pad
def collater(self, samples):
return data_utils.collate_tokens(samples, self.pad_idx, left_pad=self.left_pad, pad_to_multiple=8)
class LeftPadDataset(PadDataset):
def __init__(self, dataset, pad_idx):
super().__init__(dataset, pad_idx, left_pad=True)
class RightPadDataset(PadDataset):
def __init__(self, dataset, pad_idx):
super().__init__(dataset, pad_idx, left_pad=False)
class RightPadDataset2D(BaseWrapperDataset):
def __init__(self, dataset, pad_idx, left_pad=False):
super().__init__(dataset)
self.pad_idx = pad_idx
self.left_pad = left_pad
def collater(self, samples):
return data_utils.collate_tokens_2d(samples, self.pad_idx, left_pad=self.left_pad, pad_to_multiple=8)
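# Usage sketch (hypothetical): the collaters above pad a list of variable-length
# samples to a common length rounded up to a multiple of 8 (pad_to_multiple=8),
# which tends to be friendlier to tensor-core kernels. `token_dataset` and
# `vocab.pad()` are assumptions.
#
#   pad_ds = RightPadDataset(token_dataset, pad_idx=vocab.pad())
#   batch = pad_ds.collater([token_dataset[0], token_dataset[1]])
#   # batch.shape == (2, L), where L is the longest sample length rounded up
#   # to the next multiple of 8.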
# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import torch
from functools import lru_cache
from . import BaseWrapperDataset
class PrependTokenDataset(BaseWrapperDataset):
def __init__(self, dataset, token=None):
super().__init__(dataset)
self.token = token
@lru_cache(maxsize=16)
def __getitem__(self, idx):
item = self.dataset[idx]
if self.token is not None:
item = torch.cat([torch.full_like(item[0], self.token).unsqueeze(0), item], dim=0)
return item
# Copyright (c) DP Technology.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from torch.utils.data.dataloader import default_collate
from functools import lru_cache
from . import UnicoreDataset
class RawLabelDataset(UnicoreDataset):
def __init__(self, labels):
super().__init__()
self.labels = labels
@lru_cache(maxsize=16)
def __getitem__(self, index):
return self.labels[index]
def __len__(self):
return len(self.labels)
def collater(self, samples):
return torch.tensor(samples)
class RawArrayDataset(UnicoreDataset):
def __init__(self, dataset):
super().__init__()
self.dataset = dataset
@lru_cache(maxsize=16)
def __getitem__(self, index):
return self.dataset[index]
def __len__(self):
return len(self.dataset)
def collater(self, samples):
if hasattr(self.dataset, 'collater'):
return self.dataset.collater(samples)
else:
return default_collate(samples)
class RawNumpyDataset(UnicoreDataset):
def __init__(self, dataset):
super().__init__()
self.dataset = dataset
@lru_cache(maxsize=16)
def __getitem__(self, index):
return torch.from_numpy(self.dataset[index])
def __len__(self):
return len(self.dataset)
def collater(self, samples):
if hasattr(self.dataset, 'collater'):
return self.dataset.collater(samples)
else:
return default_collate(samples)
# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
from . import BaseWrapperDataset, data_utils
class SortDataset(BaseWrapperDataset):
def __init__(self, dataset, sort_order):
super().__init__(dataset)
if not isinstance(sort_order, (list, tuple)):
sort_order = [sort_order]
self.sort_order = sort_order
assert all(len(so) == len(dataset) for so in sort_order)
def ordered_indices(self):
return np.lexsort(self.sort_order)
class EpochShuffleDataset(BaseWrapperDataset):
def __init__(self, dataset, size, seed):
super().__init__(dataset)
self.size = size
self.seed = seed
self.set_epoch(1)
def set_epoch(self, epoch):
super().set_epoch(epoch)
with data_utils.numpy_seed(self.seed + epoch - 1):
self.sort_order = np.random.permutation(self.size)
def ordered_indices(self):
return self.sort_order
@property
def can_reuse_epoch_itr_across_epochs(self):
return False
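# Usage sketch (hypothetical): SortDataset only changes the order in which indices
# are visited; np.lexsort sorts by the *last* key first, so putting a random
# permutation before a `sizes` array yields length-bucketed ordering with random
# order among equal-sized items. `sizes` is an assumed per-item length array.
#
#   with data_utils.numpy_seed(seed):
#       shuffle = np.random.permutation(len(dataset))
#   bucketed = SortDataset(dataset, sort_order=[shuffle, sizes])
#   # bucketed.ordered_indices(): grouped by sizes, shuffled within each group.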
# Copyright (c) DP Technology.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from functools import lru_cache
import torch
from unicore.data import Dictionary
from . import BaseWrapperDataset
class TokenizeDataset(BaseWrapperDataset):
def __init__(
self,
dataset: torch.utils.data.Dataset,
dictionary: Dictionary,
max_seq_len: int = 512,
):
self.dataset = dataset
self.dictionary = dictionary
self.max_seq_len = max_seq_len
@lru_cache(maxsize=16)
def __getitem__(self, index: int):
raw_data = self.dataset[index]
assert len(raw_data) < self.max_seq_len and len(raw_data) > 0
return torch.from_numpy(self.dictionary.vec_index(raw_data)).long()
# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import numpy as np
import torch.utils.data
logger = logging.getLogger(__name__)
class EpochListening:
"""Mixin for receiving updates whenever the epoch increments."""
@property
def can_reuse_epoch_itr_across_epochs(self):
"""
Whether we can reuse the :class:`unicore.data.EpochBatchIterator` for
this dataset across epochs.
This needs to return ``False`` if the sample sizes can change across
epochs, in which case we may need to regenerate batches at each epoch.
If your dataset relies on ``set_epoch``, you should consider setting
this to ``False``.
"""
return True
def set_epoch(self, epoch):
"""Will receive the updated epoch number at the beginning of the epoch."""
pass
class UnicoreDataset(torch.utils.data.Dataset, EpochListening):
"""A dataset that provides helpers for batching."""
def __getitem__(self, index):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
def collater(self, samples):
"""Merge a list of samples to form a mini-batch.
Args:
samples (List[dict]): samples to collate
Returns:
dict: a mini-batch suitable for forwarding with a Model
"""
raise NotImplementedError
def ordered_indices(self):
"""Return an ordered list of indices. Batches will be constructed based
on this order."""
return np.arange(len(self), dtype=np.int64)
@property
def supports_prefetch(self):
"""Whether this dataset supports prefetching."""
return False
def attr(self, attr: str, index: int):
return getattr(self, attr, None)
def prefetch(self, indices):
"""Prefetch the data required for this epoch."""
raise NotImplementedError
def batch_by_size(
self,
indices,
batch_size=None,
required_batch_size_multiple=1,
):
"""
Given an ordered set of indices, group them into batches of at most
*batch_size* items, respecting *required_batch_size_multiple*.
"""
from unicore.data import data_utils
return data_utils.batch_by_size(
indices,
batch_size=batch_size,
required_batch_size_multiple=required_batch_size_multiple,
)
@property
def supports_fetch_outside_dataloader(self):
"""Whether this dataset supports fetching outside the workers of the dataloader."""
return True
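# Usage sketch (hypothetical): batching starts from ordered_indices() and is
# delegated to data_utils.batch_by_size; `dataset` can be any UnicoreDataset
# subclass.
#
#   indices = dataset.ordered_indices()
#   batches = dataset.batch_by_size(
#       indices,
#       batch_size=16,
#       required_batch_size_multiple=8,
#   )
#   # `batches` is a sequence of index arrays, one per mini-batch.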
# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .module_proxy_wrapper import ModuleProxyWrapper
from .legacy_distributed_data_parallel import LegacyDistributedDataParallel
__all__ = [
"ModuleProxyWrapper",
]
# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
A modified version of the legacy DistributedDataParallel module that uses c10d
communication primitives. This version is simpler than the latest PyTorch
version and is useful for debugging. Notably it does not overlap gradient
communication with the backward pass, which makes it slower but more robust
than the PyTorch version.
This version also supports the *no_sync* context manager, which allows faster
training with `--update-freq`.
"""
from collections import OrderedDict
from contextlib import contextmanager
import torch
from torch import nn
from unicore.distributed import utils
class LegacyDistributedDataParallel(nn.Module):
"""Implements distributed data parallelism at the module level.
A simplified version of :class:`torch.nn.parallel.DistributedDataParallel`.
This version uses a c10d process group for communication and does not
broadcast buffers.
Args:
module (~torch.nn.Module): module to be parallelized
process_group: the c10d process group to be used for distributed data
parallel all-reduction.
buffer_size (int, optional): number of elements to buffer before
performing all-reduce (default: 256M).
"""
def __init__(self, module, process_group, buffer_size=2 ** 28):
super().__init__()
self.module = module
self.process_group = process_group
self.world_size = utils.get_world_size(self.process_group)
# Never use a bigger buffer than the number of model params
self.buffer_size = min(buffer_size, sum(p.numel() for p in module.parameters()))
self.buffer = None
# We can also forcibly accumulate grads locally and only do the
# all-reduce at some later time
self.accumulate_grads = False
# make per-device lists of parameters
paramlists = OrderedDict()
for param in self.module.parameters():
device = param.device
if paramlists.get(device) is None:
paramlists[device] = []
paramlists[device] += [param]
self.per_device_params = list(paramlists.values())
@contextmanager
def no_sync(self):
"""A context manager to disable gradient synchronization."""
old_accumulate_grads = self.accumulate_grads
self.accumulate_grads = True
yield
self.accumulate_grads = old_accumulate_grads
def forward(self, *inputs, **kwargs):
return self.module(*inputs, **kwargs)
def all_reduce_params(self, params):
if self.accumulate_grads:
return
buffer = self.buffer
nonzero_buffer = False
if len(params) > 1:
offset = 0
for p in params:
sz = p.numel()
if p.grad is not None:
buffer[offset : offset + sz].copy_(p.grad.data.view(-1))
nonzero_buffer = True
else:
buffer[offset : offset + sz].zero_()
offset += sz
else:
# we only have a single grad to all-reduce
p = params[0]
if p.grad is not None:
buffer = p.grad.data
nonzero_buffer = True
elif p.numel() <= self.buffer.numel():
buffer = buffer[: p.numel()]
buffer.zero_()
else:
buffer = torch.zeros_like(p)
if nonzero_buffer:
buffer.div_(self.world_size)
utils.all_reduce(buffer, self.process_group)
# copy all-reduced grads back into their original place
offset = 0
for p in params:
sz = p.numel()
if p.grad is not None:
p.grad.data.copy_(buffer[offset : offset + sz].view_as(p))
else:
p.grad = buffer[offset : offset + sz].view_as(p).clone()
offset += sz
def all_reduce_grads(self):
"""
This function must be called explicitly after the backward pass to reduce
gradients; unlike c10d DistributedDataParallel, there is no automatic hook.
"""
def reduction_fn():
# This function only needs to be called once
if self.accumulate_grads:
return
if self.buffer is None:
self.buffer = next(self.module.parameters()).new(self.buffer_size)
for params in self.per_device_params:
# All-reduce the gradients in buckets
offset = 0
buffered_params = []
for param in params:
if not param.requires_grad:
continue
if param.grad is None:
param.grad = torch.zeros_like(param)
if hasattr(param, "expert"):
# Skip gradient sync for unshared parameters
continue
if param.grad.requires_grad:
raise RuntimeError(
"DistributedDataParallel only works "
"with gradients that don't require "
"grad"
)
sz = param.numel()
if sz > self.buffer.numel():
# all-reduce big params directly
self.all_reduce_params([param])
else:
if offset + sz > self.buffer.numel():
self.all_reduce_params(buffered_params)
offset = 0
buffered_params.clear()
buffered_params.append(param)
offset += sz
if len(buffered_params) > 0:
self.all_reduce_params(buffered_params)
reduction_fn()
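# Illustrative example (hypothetical): gradients are only exchanged when
# all_reduce_grads() is called, so gradient accumulation (`--update-freq`) can skip
# communication on intermediate steps via no_sync(). `model`, `loss_fn`, `batches`,
# `optimizer`, and the process group `pg` are assumed to exist.
#
#   ddp_model = LegacyDistributedDataParallel(model, process_group=pg)
#   for i, batch in enumerate(batches):
#       if i < len(batches) - 1:
#           with ddp_model.no_sync():      # accumulate locally, no communication
#               loss_fn(ddp_model(batch)).backward()
#       else:
#           loss_fn(ddp_model(batch)).backward()
#   ddp_model.all_reduce_grads()           # one bucketed all-reduce for all grads
#   optimizer.step()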
# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from torch import nn
class ModuleProxyWrapper(nn.Module):
"""
Wrap a DistributedDataParallel module and forward requests for missing
attributes to the module wrapped by DDP (the twice-wrapped module).
Also forward calls to :func:`state_dict` and :func:`load_state_dict`.
Usage::
module.xyz = "hello world"
wrapped_module = DistributedDataParallel(module, **ddp_args)
wrapped_module = ModuleProxyWrapper(wrapped_module)
assert wrapped_module.xyz == "hello world"
assert wrapped_module.state_dict().keys() == module.state_dict().keys()
Args:
module (nn.Module): module to wrap
"""
def __init__(self, module: nn.Module):
super().__init__()
assert hasattr(module, "module"), \
"ModuleProxyWrapper expects input to wrap another module"
self.module = module
def __getattr__(self, name):
"""Forward missing attributes to twice-wrapped module."""
try:
# defer to nn.Module's logic
return super().__getattr__(name)
except AttributeError:
try:
# forward to the once-wrapped module
return getattr(self.module, name)
except AttributeError:
# forward to the twice-wrapped module
return getattr(self.module.module, name)
def state_dict(self, *args, **kwargs):
"""Forward to the twice-wrapped module."""
return self.module.module.state_dict(*args, **kwargs)
def load_state_dict(self, *args, **kwargs):
"""Forward to the twice-wrapped module."""
return self.module.module.load_state_dict(*args, **kwargs)
def forward(self, *args, **kwargs):
return self.module(*args, **kwargs)
# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import datetime
import io
import logging
import os
import pickle
import random
import socket
import struct
import subprocess
import warnings
from collections import OrderedDict
from typing import Any, Dict, List, Mapping, Optional
from dataclasses import dataclass
import torch
import torch.distributed as dist
logger = logging.getLogger(__name__)
def is_master(args):
return args.distributed_rank == 0
def infer_init_method(args, force_distributed=False):
if args.distributed_init_method is not None:
return
if all(
key in os.environ
for key in ["MASTER_ADDR", "MASTER_PORT", "WORLD_SIZE", "RANK"]
):
# support torch.distributed.launch
_infer_torch_distributed_launch_init(args)
elif args.distributed_port > 0:
# we can determine the init method automatically for Slurm
_infer_slurm_init(args)
elif args.distributed_world_size > 1 or force_distributed:
# fallback for single node with multiple GPUs
_infer_single_node_init(args)
elif not args.distributed_no_spawn:
args.distributed_num_procs = min(
torch.cuda.device_count(), args.distributed_world_size
)
def _infer_torch_distributed_launch_init(args):
args.distributed_init_method = "env://"
args.distributed_world_size = int(os.environ["WORLD_SIZE"])
args.distributed_rank = int(os.environ["RANK"])
# processes are created by torch.distributed.launch
args.distributed_no_spawn = True
def _infer_slurm_init(args):
node_list = os.environ.get("SLURM_STEP_NODELIST")
if node_list is None:
node_list = os.environ.get("SLURM_JOB_NODELIST")
if node_list is not None:
try:
hostnames = subprocess.check_output(
["scontrol", "show", "hostnames", node_list]
)
args.distributed_init_method = "tcp://{host}:{port}".format(
host=hostnames.split()[0].decode("utf-8"),
port=args.distributed_port,
)
nnodes = int(os.environ.get("SLURM_NNODES"))
ntasks_per_node = os.environ.get("SLURM_NTASKS_PER_NODE")
if ntasks_per_node is not None:
ntasks_per_node = int(ntasks_per_node)
else:
ntasks = int(os.environ.get("SLURM_NTASKS"))
nnodes = int(os.environ.get("SLURM_NNODES"))
assert ntasks % nnodes == 0
ntasks_per_node = int(ntasks / nnodes)
if ntasks_per_node == 1:
gpus_per_node = torch.cuda.device_count()
node_id = int(os.environ.get("SLURM_NODEID"))
args.distributed_rank = node_id * gpus_per_node
args.distributed_world_size = nnodes * gpus_per_node
else:
assert ntasks_per_node == args.distributed_world_size // nnodes
args.distributed_no_spawn = True
args.distributed_rank = int(os.environ.get("SLURM_PROCID"))
args.device_id = int(os.environ.get("SLURM_LOCALID"))
except subprocess.CalledProcessError as e: # scontrol failed
raise e
except FileNotFoundError: # Slurm is not installed
pass
def _infer_single_node_init(args):
assert (
args.distributed_world_size <= torch.cuda.device_count()
), f"world size is {args.distributed_world_size} but have {torch.cuda.device_count()} available devices"
port = random.randint(10000, 20000)
args.distributed_init_method = "tcp://localhost:{port}".format(port=port)
def distributed_init(args):
if torch.distributed.is_available() and torch.distributed.is_initialized():
warnings.warn(
"Distributed is already initialized, cannot initialize twice!"
)
else:
logger.info(
"distributed init (rank {}): {}".format(
args.distributed_rank,
args.distributed_init_method,
)
)
dist.init_process_group(
backend=args.distributed_backend,
init_method=args.distributed_init_method,
world_size=args.distributed_world_size,
rank=args.distributed_rank,
timeout=datetime.timedelta(seconds=30),
)
logger.info(
"initialized host {} as rank {}".format(
socket.gethostname(),
args.distributed_rank,
)
)
# perform a dummy all-reduce to initialize the NCCL communicator
if torch.cuda.is_available():
dist.all_reduce(torch.zeros(1).cuda())
args.distributed_rank = torch.distributed.get_rank()
if is_master(args):
logging.getLogger().setLevel(logging.INFO)
else:
logging.getLogger().setLevel(logging.WARNING)
return args.distributed_rank
def distributed_main(i, main, args, kwargs):
args.device_id = i
if torch.cuda.is_available() and not args.cpu:
torch.cuda.set_device(args.device_id)
if args.distributed_rank is None: # torch.multiprocessing.spawn
args.distributed_rank = kwargs.pop("start_rank", 0) + i
args.distributed_rank = distributed_init(args)
after_distributed_init_fn = kwargs.pop("after_distributed_init_fn", None)
if after_distributed_init_fn:
args = after_distributed_init_fn(args)
main(args, **kwargs)
if torch.distributed.is_initialized():
torch.distributed.barrier(get_global_group())
def call_main(args, main, **kwargs):
if args.distributed_init_method is None:
infer_init_method(args)
if args.distributed_init_method is not None:
# distributed training
if not args.distributed_no_spawn:
start_rank = args.distributed_rank
args.distributed_rank = None # assign automatically
kwargs["start_rank"] = start_rank
torch.multiprocessing.spawn(
fn=distributed_main,
args=(main, args, kwargs),
nprocs=min(
torch.cuda.device_count(),
args.distributed_world_size,
),
join=True,
)
else:
distributed_main(args.device_id, main, args, kwargs)
else:
# single GPU main
main(args, **kwargs)
def new_groups(grouped_ranks: List[List[int]]):
groups = [dist.new_group(g) for g in grouped_ranks]
my_group_idx = _find_my_group_index(grouped_ranks)
return groups[my_group_idx]
def _find_my_group_index(grouped_ranks):
my_rank = get_global_rank()
for i, group in enumerate(grouped_ranks):
if my_rank in group:
return i
raise RuntimeError
def _find_my_group(grouped_ranks):
index = _find_my_group_index(grouped_ranks)
return grouped_ranks[index]
def get_rank(group):
return dist.get_rank(group=group)
def get_world_size(group):
if torch.distributed.is_initialized():
return dist.get_world_size(group=group)
else:
return 1
def get_global_group():
if torch.distributed.is_initialized():
if not hasattr(get_global_group, "_global_group"):
# ideally we could use torch.distributed.group.WORLD, but it seems
# to cause random NCCL hangs in some cases
get_global_group._global_group = dist.new_group()
return get_global_group._global_group
else:
return None
def get_global_rank():
if torch.distributed.is_initialized():
return torch.distributed.get_rank()
else:
return 0
def get_global_world_size():
if torch.distributed.is_initialized():
return torch.distributed.get_world_size()
else:
return 1
def get_data_parallel_group():
"""Get the data parallel group the caller rank belongs to."""
return get_global_group()
def get_data_parallel_rank():
"""Return my rank for the data parallel group."""
return get_rank(get_data_parallel_group())
def get_data_parallel_world_size():
"""Return world size for the data parallel group."""
return get_world_size(get_data_parallel_group())
def all_reduce(tensor, group, op="sum"):
if op == "sum":
op = dist.ReduceOp.SUM
elif op == "max":
op = dist.ReduceOp.MAX
else:
raise NotImplementedError
dist.all_reduce(tensor, op=op, group=group)
return tensor
def broadcast(tensor, src, group):
dist.broadcast(tensor, src=src, group=group)
def all_to_all(tensor, group):
"""Perform an all-to-all operation on a 1D Tensor."""
assert tensor.dim() == 1
split_count = get_world_size(group=group)
assert tensor.numel() % split_count == 0
output = torch.zeros_like(tensor)
dist.all_to_all_single(output, tensor, group=group)
return output
def all_gather(tensor, group, return_tensor=False):
"""Perform an all-gather operation."""
world_size = get_world_size(group=group)
rank = get_rank(group=group)
tensor_list = [
tensor if i == rank else torch.empty_like(tensor) for i in range(world_size)
]
dist.all_gather(tensor_list, tensor, group=group)
if return_tensor:
return torch.stack(tensor_list, dim=0)
else:
return tensor_list
def all_gather_list(data, group=None, max_size=16384):
"""Gathers arbitrary data from all nodes into a list.
Similar to :func:`~torch.distributed.all_gather` but for arbitrary Python
data. Note that *data* must be picklable and any CUDA tensors will be moved
to CPU and returned on CPU as well.
Args:
data (Any): data from the local worker to be gathered on other workers
group: group of the collective
max_size (int, optional): maximum size of the data to be gathered
across workers
"""
from unicore import utils
if group is None:
group = get_global_group()
rank = get_rank(group=group)
world_size = get_world_size(group=group)
buffer_size = max_size * world_size
if (
not hasattr(all_gather_list, "_buffer")
or all_gather_list._buffer.numel() < buffer_size
):
all_gather_list._buffer = torch.cuda.ByteTensor(buffer_size)
all_gather_list._cpu_buffer = torch.ByteTensor(max_size).pin_memory()
buffer = all_gather_list._buffer
buffer.zero_()
cpu_buffer = all_gather_list._cpu_buffer
data = utils.move_to_cpu(data)
enc = pickle.dumps(data)
enc_size = len(enc)
header_size = 4 # size of header that contains the length of the encoded data
size = header_size + enc_size
if size > max_size:
raise ValueError(
"encoded data size ({}) exceeds max_size ({})".format(size, max_size)
)
header = struct.pack(">I", enc_size)
cpu_buffer[:size] = torch.ByteTensor(list(header + enc))
start = rank * max_size
buffer[start : start + size].copy_(cpu_buffer[:size])
all_reduce(buffer, group=group)
buffer = buffer.cpu()
try:
result = []
for i in range(world_size):
out_buffer = buffer[i * max_size : (i + 1) * max_size]
(enc_size,) = struct.unpack(">I", bytes(out_buffer[:header_size].tolist()))
if enc_size > 0:
result.append(
pickle.loads(
bytes(out_buffer[header_size : header_size + enc_size].tolist())
)
)
return result
except pickle.UnpicklingError:
raise Exception(
"Unable to unpickle data from other workers. all_gather_list requires all "
"workers to enter the function together, so this error usually indicates "
"that the workers have fallen out of sync somehow. Workers can fall out of "
"sync if one of them runs out of memory, or if there are other conditions "
"in your training script that can cause one worker to finish an epoch "
"while other workers are still iterating over their portions of the data. "
"Try rerunning with --ddp-backend=legacy_ddp and see if that helps."
)
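# Usage sketch (hypothetical): all_gather_list round-trips arbitrary picklable
# objects through a fixed-size byte buffer, so it suits small per-rank metadata
# (e.g. logging stats) rather than tensors.
#
#   stats = {"rank": get_global_rank(), "nseqs": 123}
#   gathered = all_gather_list(stats, group=get_global_group(), max_size=16384)
#   # `gathered` is a list with one dict per worker, in rank order.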
def all_reduce_dict(data: Mapping[str, Any], device, group) -> Dict[str, Any]:
"""
AllReduce a dictionary of values across workers. We separately
reduce items that are already on the device and items on CPU for
better performance.
Args:
data (Mapping[str, Any]): dictionary of data to all-reduce, but
cannot be a nested dictionary
device (torch.device): device for the reduction
group: group of the collective
"""
data_keys = list(data.keys())
# We want to separately reduce items that are already on the
# device and items on CPU for performance reasons.
cpu_data = OrderedDict()
device_data = OrderedDict()
for k in data_keys:
t = data[k]
if not torch.is_tensor(t):
cpu_data[k] = torch.tensor(t, dtype=torch.double)
elif t.device.type != device.type:
cpu_data[k] = t.to(dtype=torch.double)
else:
device_data[k] = t.to(dtype=torch.double)
def _all_reduce_dict(data: OrderedDict):
if len(data) == 0:
return data
buf = torch.cat([t.view(-1) for t in data.values()]).to(device=device)
all_reduce(buf, group=group)
split_buf = torch.split(buf, [t.numel() for t in data.values()])
reduced_data = [t.view_as(orig) for t, orig in zip(split_buf, data.values())]
return OrderedDict(zip(data.keys(), reduced_data))
cpu_data = _all_reduce_dict(cpu_data)
device_data = _all_reduce_dict(device_data)
def get_from_stack(key):
if key in cpu_data:
return cpu_data[key]
elif key in device_data:
return device_data[key]
raise KeyError
return OrderedDict([(key, get_from_stack(key)) for key in data_keys])
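# Usage sketch (hypothetical): all_reduce_dict sums a flat dict of scalars or
# tensors across workers and returns double tensors on `device`, keeping key order.
#
#   logging_output = {"loss": 1.25, "ntokens": 4096, "bsz": 8}
#   totals = all_reduce_dict(
#       logging_output, device=torch.device("cuda"), group=get_global_group()
#   )
#   # totals["ntokens"] now holds the sum of ntokens over all ranks.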
@dataclass
class _TensorPlaceholder:
index: int
def broadcast_tensors(
tensors: Optional[List[torch.Tensor]],
src_rank: int,
group: object,
dist_device: Optional[torch.device] = None,
) -> List[torch.Tensor]:
"""
Broadcasts a list of tensors without other (non-src) ranks needing to know
the dtypes/shapes of the tensors.
"""
if dist_device is None:
if torch.distributed.get_backend(group) == "nccl":
dist_device = torch.device("cuda")
else:
dist_device = torch.device("cpu")
# share metadata first to simplify transfer
is_src_rank = get_rank(group) == src_rank
if is_src_rank:
metadata = [
{"size": t.size(), "dtype": t.dtype, "device": t.device} for t in tensors
]
metadata = _broadcast_object_slow(metadata, src_rank, group, dist_device)
else:
metadata = _broadcast_object_slow(None, src_rank, group, dist_device)
out_tensors = []
for i, meta in enumerate(metadata):
if is_src_rank:
tensor = tensors[i]
broadcast(tensors[i].to(dist_device), src=src_rank, group=group)
else:
tensor = torch.zeros(
[meta["size"].numel()], dtype=meta["dtype"], device=dist_device
)
broadcast(tensor, src=src_rank, group=group)
tensor = tensor.view(meta["size"]).to(meta["device"])
out_tensors.append(tensor)
return out_tensors
def broadcast_object(
obj: Any,
src_rank: int,
group: object,
dist_device: Optional[torch.device] = None,
) -> Any:
"""Broadcast an arbitrary Python object to other workers."""
if dist_device is None:
if torch.distributed.get_backend(group) == "nccl":
dist_device = torch.device("cuda")
else:
dist_device = torch.device("cpu")
if get_rank(group) == src_rank:
# split the tensors from the non-tensors so we can broadcast them
# directly, avoiding unnecessary serialization/deserialization
tensors = []
obj = _split_tensors_from_obj(obj, tensors)
obj = _broadcast_object_slow(obj, src_rank, group, dist_device)
tensors = broadcast_tensors(tensors, src_rank, group, dist_device)
else:
obj = _broadcast_object_slow(None, src_rank, group, dist_device)
tensors = broadcast_tensors(None, src_rank, group, dist_device)
return _put_tensors_in_obj(obj, tensors)
def _broadcast_object_slow(
obj: Any,
src_rank: int,
group: object,
dist_device: torch.device,
) -> Any:
if get_rank(group) == src_rank:
# Emit data
buffer = io.BytesIO()
torch.save(obj, buffer)
buffer = torch.ByteTensor(buffer.getbuffer()).to(dist_device)
length = torch.LongTensor([len(buffer)]).to(dist_device)
broadcast(length, src=src_rank, group=group)
broadcast(buffer, src=src_rank, group=group)
else:
# Fetch from the source
length = torch.LongTensor([0]).to(dist_device)
broadcast(length, src=src_rank, group=group)
buffer = torch.ByteTensor(int(length.item())).to(dist_device)
broadcast(buffer, src=src_rank, group=group)
buffer = io.BytesIO(buffer.cpu().numpy())
obj = torch.load(buffer, map_location="cpu")
return obj
def _split_tensors_from_obj(obj: Any, tensors: List[torch.Tensor]) -> Any:
if torch.is_tensor(obj):
placeholder = _TensorPlaceholder(index=len(tensors))
tensors.append(obj)
return placeholder
elif isinstance(obj, dict):
return {k: _split_tensors_from_obj(v, tensors) for k, v in obj.items()}
elif isinstance(obj, list):
return [_split_tensors_from_obj(v, tensors) for v in obj]
elif isinstance(obj, tuple):
return tuple(_split_tensors_from_obj(v, tensors) for v in obj)
elif isinstance(obj, set):
return {_split_tensors_from_obj(v, tensors) for v in obj}
else:
return obj
def _put_tensors_in_obj(obj: Any, tensors: List[torch.Tensor]) -> Any:
if isinstance(obj, _TensorPlaceholder):
return tensors[obj.index]
elif isinstance(obj, dict):
return {k: _put_tensors_in_obj(v, tensors) for k, v in obj.items()}
elif isinstance(obj, list):
return [_put_tensors_in_obj(v, tensors) for v in obj]
elif isinstance(obj, tuple):
return tuple(_put_tensors_in_obj(v, tensors) for v in obj)
elif isinstance(obj, set):
return {_put_tensors_in_obj(v, tensors) for v in obj}
else:
return obj
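# Illustrative example (hypothetical): broadcast_object ships an arbitrary Python
# object from src_rank to every rank; tensors inside it are extracted and broadcast
# directly, everything else goes through torch.save/torch.load.
#
#   state = {"step": 1000, "weights": torch.randn(4, 4)} if get_global_rank() == 0 else None
#   state = broadcast_object(state, src_rank=0, group=get_global_group())
#   # every rank now holds the same dict; "weights" ends up on the device it had
#   # on the source rank.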
# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import bisect
import time
from collections import OrderedDict
from typing import Dict, Optional
try:
import torch
def type_as(a, b):
if torch.is_tensor(a) and torch.is_tensor(b):
return a.to(b)
else:
return a
except ImportError:
torch = None
def type_as(a, b):
return a
try:
import numpy as np
except ImportError:
np = None
class Meter(object):
"""Base class for Meters."""
def __init__(self):
pass
def state_dict(self):
return {}
def load_state_dict(self, state_dict):
pass
def reset(self):
raise NotImplementedError
@property
def smoothed_value(self) -> float:
"""Smoothed value used for logging."""
raise NotImplementedError
def safe_round(number, ndigits):
if hasattr(number, "__round__"):
return round(number, ndigits)
elif torch is not None and torch.is_tensor(number) and number.numel() == 1:
return safe_round(number.item(), ndigits)
elif np is not None and np.ndim(number) == 0 and hasattr(number, "item"):
return safe_round(number.item(), ndigits)
else:
return number
class AverageMeter(Meter):
"""Computes and stores the average and current value"""
def __init__(self, round: Optional[int] = None):
self.round = round
self.reset()
def reset(self):
self.val = None # most recent update
self.sum = 0 # sum from all updates
self.count = 0 # total n from all updates
def update(self, val, n=1):
if val is not None:
self.val = val
if n > 0:
self.sum = type_as(self.sum, val) + (val * n)
self.count = type_as(self.count, n) + n
def state_dict(self):
return {
"val": self.val,
"sum": self.sum,
"count": self.count,
"round": self.round,
}
def load_state_dict(self, state_dict):
self.val = state_dict["val"]
self.sum = state_dict["sum"]
self.count = state_dict["count"]
self.round = state_dict.get("round", None)
@property
def avg(self):
return self.sum / self.count if self.count > 0 else self.val
@property
def smoothed_value(self) -> float:
val = self.avg
if self.round is not None and val is not None:
val = safe_round(val, self.round)
return val
class TimeMeter(Meter):
"""Computes the average occurrence of some event per second"""
def __init__(
self,
init: int = 0,
n: int = 0,
round: Optional[int] = None,
):
self.round = round
self.reset(init, n)
def reset(self, init=0, n=0):
self.init = init
self.start = time.perf_counter()
self.n = n
self.i = 0
def update(self, val=1):
self.n = type_as(self.n, val) + val
self.i += 1
def state_dict(self):
return {
"init": self.elapsed_time,
"n": self.n,
"round": self.round,
}
def load_state_dict(self, state_dict):
if "start" in state_dict:
# backwards compatibility for old state_dicts
self.reset(init=state_dict["init"])
else:
self.reset(init=state_dict["init"], n=state_dict["n"])
self.round = state_dict.get("round", None)
@property
def avg(self):
return self.n / self.elapsed_time
@property
def elapsed_time(self):
return self.init + (time.perf_counter() - self.start)
@property
def smoothed_value(self) -> float:
val = self.avg
if self.round is not None and val is not None:
val = safe_round(val, self.round)
return val
class StopwatchMeter(Meter):
"""Computes the sum/avg duration of some event in seconds"""
def __init__(self, round: Optional[int] = None):
self.round = round
self.sum = 0
self.n = 0
self.start_time = None
def start(self):
self.start_time = time.perf_counter()
def stop(self, n=1, prehook=None):
if self.start_time is not None:
if prehook is not None:
prehook()
delta = time.perf_counter() - self.start_time
self.sum = self.sum + delta
self.n = type_as(self.n, n) + n
def reset(self):
self.sum = 0 # cumulative time during which stopwatch was active
self.n = 0 # total n across all start/stop
self.start()
def state_dict(self):
return {
"sum": self.sum,
"n": self.n,
"round": self.round,
}
def load_state_dict(self, state_dict):
self.sum = state_dict["sum"]
self.n = state_dict["n"]
self.start_time = None
self.round = state_dict.get("round", None)
@property
def avg(self):
return self.sum / self.n if self.n > 0 else self.sum
@property
def elapsed_time(self):
if self.start_time is None:
return 0.0
return time.perf_counter() - self.start_time
@property
def smoothed_value(self) -> float:
val = self.avg if self.sum > 0 else self.elapsed_time
if self.round is not None and val is not None:
val = safe_round(val, self.round)
return val
class MetersDict(OrderedDict):
"""A sorted dictionary of :class:`Meters`.
Meters are sorted according to a priority that is given when the
meter is first added to the dictionary.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.priorities = []
def __setitem__(self, key, value):
assert key not in self, "MetersDict doesn't support reassignment"
priority, value = value
bisect.insort(self.priorities, (priority, len(self.priorities), key))
super().__setitem__(key, value)
for _, _, key in self.priorities: # reorder dict to match priorities
self.move_to_end(key)
def add_meter(self, key, meter, priority):
self.__setitem__(key, (priority, meter))
def state_dict(self):
return [
(pri, key, self[key].__class__.__name__, self[key].state_dict())
for pri, _, key in self.priorities
# can't serialize DerivedMeter instances
if not isinstance(self[key], MetersDict._DerivedMeter)
]
def load_state_dict(self, state_dict):
self.clear()
self.priorities.clear()
for pri, key, meter_cls, meter_state in state_dict:
meter = globals()[meter_cls]()
meter.load_state_dict(meter_state)
self.add_meter(key, meter, pri)
def get_smoothed_value(self, key: str) -> float:
"""Get a single smoothed value."""
meter = self[key]
if isinstance(meter, MetersDict._DerivedMeter):
return meter.fn(self)
else:
return meter.smoothed_value
def get_smoothed_values(self) -> Dict[str, float]:
"""Get all smoothed values."""
return OrderedDict(
[
(key, self.get_smoothed_value(key))
for key in self.keys()
if not key.startswith("_")
]
)
def reset(self):
"""Reset Meter instances."""
for meter in self.values():
if isinstance(meter, MetersDict._DerivedMeter):
continue
meter.reset()
class _DerivedMeter(Meter):
"""A Meter whose values are derived from other Meters."""
def __init__(self, fn):
self.fn = fn
def reset(self):
pass
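# Usage sketch (hypothetical): meters carry their own smoothing and rounding, so
# callers just update them and read `smoothed_value` / `avg`.
#
#   loss_meter = AverageMeter(round=3)
#   wps_meter = TimeMeter()
#   for loss, ntokens in [(1.2, 4096), (1.0, 4096)]:
#       loss_meter.update(loss, n=ntokens)   # token-weighted running average
#       wps_meter.update(ntokens)            # events (tokens) per second
#   print(loss_meter.smoothed_value, wps_meter.avg)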
# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
A standalone module for aggregating metrics.
Metrics can be logged from anywhere using the `log_*` functions defined
in this module. The logged values will be aggregated dynamically based
on the aggregation context in which the logging occurs. See the
:func:`aggregate` context manager for more details.
"""
import contextlib
import uuid
from collections import OrderedDict, defaultdict
from typing import Callable, Dict, List, Optional
from .meters import *
# Aggregation contexts are considered "active" when inside the scope
# created by the :func:`aggregate` context manager.
_aggregators = OrderedDict()
_active_aggregators = OrderedDict()
_active_aggregators_cnt = defaultdict(lambda: 0)
def reset() -> None:
"""Reset all metrics aggregators."""
_aggregators.clear()
_active_aggregators.clear()
_active_aggregators_cnt.clear()
# The "default" aggregator observes all logged values.
_aggregators["default"] = MetersDict()
_active_aggregators["default"] = _aggregators["default"]
_active_aggregators_cnt["default"] = 1
reset()
@contextlib.contextmanager
def aggregate(name: Optional[str] = None, new_root: bool = False):
"""Context manager to aggregate metrics under a given name.
Aggregations can be nested. If *new_root* is ``False``, then logged
metrics will be recorded along the entire stack of nested
aggregators, including a global "default" aggregator. If *new_root*
is ``True``, then this aggregator will be the root of a new
aggregation stack, thus bypassing any parent aggregators.
Note that aggregation contexts are uniquely identified by their
*name* (e.g., train, valid). Creating a context with an existing
name will reuse the corresponding :class:`MetersDict` instance.
If no name is given, then a temporary aggregator will be created.
Usage::
with metrics.aggregate("train"):
for step, batch in enumerate(epoch):
with metrics.aggregate("train_inner") as agg:
metrics.log_scalar("loss", get_loss(batch))
if step % log_interval == 0:
print(agg.get_smoothed_value("loss"))
agg.reset()
print(metrics.get_smoothed_values("train")["loss"])
Args:
name (str): name of the aggregation. Defaults to a
random/temporary name if not given explicitly.
new_root (bool): make this aggregation the root of a new
aggregation stack.
"""
if name is None:
# generate a temporary name
name = str(uuid.uuid4())
assert name not in _aggregators
agg = MetersDict()
else:
assert name != "default"
agg = _aggregators.setdefault(name, MetersDict())
if new_root:
backup_aggregators = _active_aggregators.copy()
_active_aggregators.clear()
backup_aggregators_cnt = _active_aggregators_cnt.copy()
_active_aggregators_cnt.clear()
_active_aggregators[name] = agg
_active_aggregators_cnt[name] += 1
yield agg
_active_aggregators_cnt[name] -= 1
if _active_aggregators_cnt[name] == 0 and name in _active_aggregators:
del _active_aggregators[name]
if new_root:
_active_aggregators.clear()
_active_aggregators.update(backup_aggregators)
_active_aggregators_cnt.clear()
_active_aggregators_cnt.update(backup_aggregators_cnt)
def get_active_aggregators() -> List[MetersDict]:
return list(_active_aggregators.values())
def log_scalar(
key: str,
value: float,
weight: float = 1,
priority: int = 10,
round: Optional[int] = None,
):
"""Log a scalar value.
Args:
key (str): name of the field to log
value (float): value to log
weight (float): weight that this value contributes to the average.
A weight of 0 will always log the latest value.
priority (int): smaller values are logged earlier in the output
round (Optional[int]): number of digits to round to when displaying
"""
for agg in get_active_aggregators():
if key not in agg:
agg.add_meter(key, AverageMeter(round=round), priority)
agg[key].update(value, weight)
def log_derived(key: str, fn: Callable[[MetersDict], float], priority: int = 20):
"""Log a scalar value derived from other meters.
Args:
key (str): name of the field to log
fn (Callable[[MetersDict], float]): function that takes a single
argument *meters* and returns the derived value
priority (int): smaller values are logged earlier in the output
"""
for agg in get_active_aggregators():
if key not in agg:
agg.add_meter(key, MetersDict._DerivedMeter(fn), priority)
def log_speed(
key: str,
value: float,
priority: int = 30,
round: Optional[int] = None,
):
"""Log the rate of some quantity per second.
Args:
key (str): name of the field to log
value (float): value to log
priority (int): smaller values are logged earlier in the output
round (Optional[int]): number of digits to round to when displaying
"""
for agg in get_active_aggregators():
if key not in agg:
agg.add_meter(key, TimeMeter(round=round), priority)
agg[key].reset() # reset meter on the first call
else:
agg[key].update(value)
def log_start_time(key: str, priority: int = 40, round: Optional[int] = None):
"""Log the duration of some event in seconds.
The duration will be computed once :func:`log_stop_time` is called.
Args:
key (str): name of the field to log
priority (int): smaller values are logged earlier in the output
round (Optional[int]): number of digits to round to when displaying
"""
for agg in get_active_aggregators():
if key not in agg:
agg.add_meter(key, StopwatchMeter(round=round), priority)
agg[key].start()
def log_stop_time(key: str, weight: float = 0.0, prehook=None):
"""Log the duration of some event in seconds.
The duration will be computed since :func:`log_start_time` was called.
Set weight > 0 to report the average time instead of the sum.
Args:
key (str): name of the field to log
weight (float): weight that this time contributes to the average
prehook (function, no arguments): will be called before the timer
is stopped. For example, use prehook=torch.cuda.synchronize to
make sure all GPU operations are done before the timer is stopped.
"""
for agg in get_active_aggregators():
if key in agg:
agg[key].stop(weight, prehook)
def log_custom(
new_meter_fn: Callable[[], Meter],
key: str,
*args,
priority: int = 50,
**kwargs,
):
"""Log using a custom Meter.
Any extra *args* or *kwargs* will be passed through to the Meter's
*update* method.
Args:
new_meter_fn (Callable[[], Meter]): function that returns a new
Meter instance
key (str): name of the field to log
priority (int): smaller values are logged earlier in the output
"""
for agg in get_active_aggregators():
if key not in agg:
agg.add_meter(key, new_meter_fn(), priority)
agg[key].update(*args, **kwargs)
def reset_meter(name: str, key: str) -> None:
"""Reset Meter instance aggregated under a given *name* and *key*."""
meter = get_meter(name, key)
if meter is not None:
meter.reset()
def reset_meters(name: str) -> None:
"""Reset Meter instances aggregated under a given *name*."""
meters = get_meters(name)
if meters is not None:
meters.reset()
def get_meter(name: str, key: str) -> Meter:
"""Get a single Meter instance aggregated under *name* and *key*.
Returns:
Meter or None if no metrics have been logged under *name* and *key*.
"""
if name not in _aggregators:
return None
return _aggregators[name].get(key, None)
def get_meters(name: str) -> MetersDict:
"""Get Meter instances aggregated under a given *name*.
Returns:
MetersDict or None if no metrics have been logged under *name*.
"""
return _aggregators.get(name, None)
def get_smoothed_value(name: str, key: str) -> float:
"""Get a single smoothed value.
Raises:
KeyError: if no metrics have been logged under *name* and *key*.
"""
return _aggregators[name].get_smoothed_value(key)
def get_smoothed_values(name: str) -> Dict[str, float]:
"""Get smoothed values aggregated under a given *name*.
Raises:
KeyError: if no metrics have been logged under *name*.
"""
return _aggregators[name].get_smoothed_values()
def state_dict():
return OrderedDict([(name, agg.state_dict()) for name, agg in _aggregators.items()])
def load_state_dict(state_dict):
for name, agg_state in state_dict.items():
_aggregators[name] = MetersDict()
_aggregators[name].load_state_dict(agg_state)
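# Illustrative example (hypothetical): logging goes through every active aggregator,
# so the same call sites feed a named context and the global "default" one.
# `valid_batches` and `compute_loss` are assumed to exist.
#
#   with aggregate("valid") as agg:
#       log_start_time("valid_wall", round=0)
#       for batch in valid_batches:
#           log_scalar("valid_loss", compute_loss(batch), round=3)
#       log_stop_time("valid_wall")
#       print(agg.get_smoothed_values()["valid_loss"])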
# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Wrapper around various loggers and progress bars (e.g., tqdm).
"""
import atexit
import json
import logging
import os
import sys
from collections import OrderedDict
from contextlib import contextmanager
from numbers import Number
from typing import Optional
import torch
from .meters import AverageMeter, StopwatchMeter, TimeMeter
logger = logging.getLogger(__name__)
def progress_bar(
iterator,
log_format: Optional[str] = None,
log_interval: int = 100,
epoch: Optional[int] = None,
prefix: Optional[str] = None,
tensorboard_logdir: Optional[str] = None,
default_log_format: str = "tqdm",
):
if log_format is None:
log_format = default_log_format
if log_format == "tqdm" and not sys.stderr.isatty():
log_format = "simple"
if log_format == "json":
bar = JsonProgressBar(iterator, epoch, prefix, log_interval)
elif log_format == "none":
bar = NoopProgressBar(iterator, epoch, prefix)
elif log_format == "simple":
bar = SimpleProgressBar(iterator, epoch, prefix, log_interval)
elif log_format == "tqdm":
bar = TqdmProgressBar(iterator, epoch, prefix)
else:
raise ValueError("Unknown log format: {}".format(log_format))
if tensorboard_logdir:
try:
# [FB only] custom wrapper for TensorBoard
import palaas # noqa
from .fb_tbmf_wrapper import FbTbmfWrapper
bar = FbTbmfWrapper(bar, log_interval)
except ImportError:
bar = TensorboardProgressBarWrapper(bar, tensorboard_logdir)
return bar
def build_progress_bar(
args,
iterator,
epoch: Optional[int] = None,
prefix: Optional[str] = None,
default: str = "tqdm",
no_progress_bar: str = "none",
):
"""Legacy wrapper that takes an argparse.Namespace."""
if getattr(args, "no_progress_bar", False):
default = no_progress_bar
if getattr(args, "distributed_rank", 0) == 0:
tensorboard_logdir = getattr(args, "tensorboard_logdir", None)
else:
tensorboard_logdir = None
return progress_bar(
iterator,
log_format=args.log_format,
log_interval=args.log_interval,
epoch=epoch,
prefix=prefix,
tensorboard_logdir=tensorboard_logdir,
default_log_format=default,
)
def format_stat(stat):
if isinstance(stat, Number):
stat = "{:g}".format(stat)
elif isinstance(stat, AverageMeter):
stat = "{:.3f}".format(stat.avg)
elif isinstance(stat, TimeMeter):
stat = "{:g}".format(round(stat.avg))
elif isinstance(stat, StopwatchMeter):
stat = "{:g}".format(round(stat.sum))
elif torch.is_tensor(stat):
stat = stat.tolist()
return stat
class BaseProgressBar(object):
"""Abstract class for progress bars."""
def __init__(self, iterable, epoch=None, prefix=None):
self.iterable = iterable
self.n = getattr(iterable, "n", 0)
self.epoch = epoch
self.prefix = ""
if epoch is not None:
self.prefix += "epoch {:03d}".format(epoch)
if prefix is not None:
self.prefix += (" | " if self.prefix != "" else "") + prefix
def __len__(self):
return len(self.iterable)
def __enter__(self):
return self
def __exit__(self, *exc):
return False
def __iter__(self):
raise NotImplementedError
def log(self, stats, tag=None, step=None):
"""Log intermediate stats according to log_interval."""
raise NotImplementedError
def print(self, stats, tag=None, step=None):
"""Print end-of-epoch stats."""
raise NotImplementedError
def update_config(self, config):
"""Log latest configuration."""
pass
def _str_commas(self, stats):
return ", ".join(key + "=" + stats[key].strip() for key in stats.keys())
def _str_pipes(self, stats):
return " | ".join(key + " " + stats[key].strip() for key in stats.keys())
def _format_stats(self, stats):
postfix = OrderedDict(stats)
# Preprocess stats according to datatype
for key in postfix.keys():
postfix[key] = str(format_stat(postfix[key]))
return postfix
@contextmanager
def rename_logger(logger, new_name):
old_name = logger.name
if new_name is not None:
logger.name = new_name
yield logger
logger.name = old_name
class JsonProgressBar(BaseProgressBar):
"""Log output in JSON format."""
def __init__(self, iterable, epoch=None, prefix=None, log_interval=1000):
super().__init__(iterable, epoch, prefix)
self.log_interval = log_interval
self.i = None
self.size = None
def __iter__(self):
self.size = len(self.iterable)
for i, obj in enumerate(self.iterable, start=self.n):
self.i = i
yield obj
def log(self, stats, tag=None, step=None):
"""Log intermediate stats according to log_interval."""
step = step or self.i or 0
if step > 0 and self.log_interval is not None and step % self.log_interval == 0:
update = (
self.epoch - 1 + (self.i + 1) / float(self.size)
if self.epoch is not None
else None
)
stats = self._format_stats(stats, epoch=self.epoch, update=update)
with rename_logger(logger, tag):
logger.info(json.dumps(stats))
def print(self, stats, tag=None, step=None):
"""Print end-of-epoch stats."""
self.stats = stats
if tag is not None:
self.stats = OrderedDict(
[(tag + "_" + k, v) for k, v in self.stats.items()]
)
stats = self._format_stats(self.stats, epoch=self.epoch)
with rename_logger(logger, tag):
logger.info(json.dumps(stats))
def _format_stats(self, stats, epoch=None, update=None):
postfix = OrderedDict()
if epoch is not None:
postfix["epoch"] = epoch
if update is not None:
postfix["update"] = round(update, 3)
# Preprocess stats according to datatype
for key in stats.keys():
postfix[key] = format_stat(stats[key])
return postfix
class NoopProgressBar(BaseProgressBar):
"""No logging."""
def __init__(self, iterable, epoch=None, prefix=None):
super().__init__(iterable, epoch, prefix)
def __iter__(self):
for obj in self.iterable:
yield obj
def log(self, stats, tag=None, step=None):
"""Log intermediate stats according to log_interval."""
pass
def print(self, stats, tag=None, step=None):
"""Print end-of-epoch stats."""
pass
class SimpleProgressBar(BaseProgressBar):
"""A minimal logger for non-TTY environments."""
def __init__(self, iterable, epoch=None, prefix=None, log_interval=1000):
super().__init__(iterable, epoch, prefix)
self.log_interval = log_interval
self.i = None
self.size = None
def __iter__(self):
self.size = len(self.iterable)
for i, obj in enumerate(self.iterable, start=self.n):
self.i = i
yield obj
def log(self, stats, tag=None, step=None):
"""Log intermediate stats according to log_interval."""
step = step or self.i or 0
if step > 0 and self.log_interval is not None and step % self.log_interval == 0:
stats = self._format_stats(stats)
postfix = self._str_commas(stats)
with rename_logger(logger, tag):
logger.info(
"{}: {:5d} / {:d} {}".format(
self.prefix, self.i + 1, self.size, postfix
)
)
def print(self, stats, tag=None, step=None):
"""Print end-of-epoch stats."""
postfix = self._str_pipes(self._format_stats(stats))
with rename_logger(logger, tag):
logger.info("{} | {}".format(self.prefix, postfix))
class TqdmProgressBar(BaseProgressBar):
"""Log to tqdm."""
def __init__(self, iterable, epoch=None, prefix=None):
super().__init__(iterable, epoch, prefix)
from tqdm import tqdm
self.tqdm = tqdm(
iterable,
self.prefix,
leave=False,
disable=(logger.getEffectiveLevel() > logging.INFO),
)
def __iter__(self):
return iter(self.tqdm)
def log(self, stats, tag=None, step=None):
"""Log intermediate stats according to log_interval."""
self.tqdm.set_postfix(self._format_stats(stats), refresh=False)
def print(self, stats, tag=None, step=None):
"""Print end-of-epoch stats."""
postfix = self._str_pipes(self._format_stats(stats))
with rename_logger(logger, tag):
logger.info("{} | {}".format(self.prefix, postfix))
try:
_tensorboard_writers = {}
from torch.utils.tensorboard import SummaryWriter
except ImportError:
try:
from tensorboardX import SummaryWriter
except ImportError:
SummaryWriter = None
def _close_writers():
for w in _tensorboard_writers.values():
w.close()
atexit.register(_close_writers)
class TensorboardProgressBarWrapper(BaseProgressBar):
"""Log to tensorboard."""
def __init__(self, wrapped_bar, tensorboard_logdir):
self.wrapped_bar = wrapped_bar
self.tensorboard_logdir = tensorboard_logdir
if SummaryWriter is None:
logger.warning(
"tensorboard not found, please install with: pip install tensorboard"
)
def _writer(self, key):
if SummaryWriter is None:
return None
_writers = _tensorboard_writers
if key not in _writers:
_writers[key] = SummaryWriter(os.path.join(self.tensorboard_logdir, key))
_writers[key].add_text("sys.argv", " ".join(sys.argv))
return _writers[key]
def __iter__(self):
return iter(self.wrapped_bar)
def log(self, stats, tag=None, step=None):
"""Log intermediate stats to tensorboard."""
self._log_to_tensorboard(stats, tag, step)
self.wrapped_bar.log(stats, tag=tag, step=step)
def print(self, stats, tag=None, step=None):
"""Print end-of-epoch stats."""
self._log_to_tensorboard(stats, tag, step)
self.wrapped_bar.print(stats, tag=tag, step=step)
def update_config(self, config):
"""Log latest configuration."""
# TODO add hparams to Tensorboard
self.wrapped_bar.update_config(config)
def _log_to_tensorboard(self, stats, tag=None, step=None):
writer = self._writer(tag or "")
if writer is None:
return
if step is None:
step = stats["num_updates"]
for key in stats.keys() - {"num_updates"}:
if isinstance(stats[key], AverageMeter):
writer.add_scalar(key, stats[key].val, step)
elif isinstance(stats[key], Number):
writer.add_scalar(key, stats[key], step)
elif torch.is_tensor(stats[key]) and stats[key].numel() == 1:
writer.add_scalar(key, stats[key].item(), step)
writer.flush()
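# Usage sketch (hypothetical): progress_bar picks a concrete bar from `log_format`
# and optionally wraps it with the TensorBoard writer. `epoch_itr` is assumed to
# be a sized iterable of batches.
#
#   bar = progress_bar(
#       epoch_itr,
#       log_format="simple",
#       log_interval=100,
#       epoch=1,
#   )
#   for i, batch in enumerate(bar):
#       stats = {"loss": 1.23, "num_updates": i}
#       bar.log(stats, tag="train_inner", step=i)
#   bar.print({"loss": 1.20}, tag="train")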
# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""isort:skip_file"""
import importlib
import os
from unicore import registry
from unicore.losses.unicore_loss import ( # noqa
UnicoreLoss,
)
(
build_loss_,
register_loss,
CRITERION_REGISTRY,
) = registry.setup_registry(
"--loss", base_class=UnicoreLoss, default="cross_entropy"
)
def build_loss(args, task):
return build_loss_(args, task)
# automatically import any Python files in the losses/ directory
for file in os.listdir(os.path.dirname(__file__)):
if file.endswith(".py") and not file.startswith("_"):
file_name = file[: file.find(".py")]
importlib.import_module("unicore.losses." + file_name)
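# Illustrative sketch (hypothetical): setup_registry returns the decorator used to
# register new losses, which then become constructible by name through
# build_loss(args, task); the forward signature below assumes the usual
# UnicoreLoss contract and is only a sketch.
#
#   @register_loss("my_loss")
#   class MyLoss(UnicoreLoss):
#       def forward(self, model, sample, reduce=True):
#           ...  # return loss, sample_size, logging_output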