Commit 7df61696 authored by Sugon_ldc

add fairseq0.10.2

# cython: language_level=3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import torch
from itertools import chain
from libc.math cimport ceil
cimport cython
cimport numpy as np
from libc.stdint cimport int32_t, int64_t
DTYPE = np.int64
ctypedef int64_t DTYPE_t
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)
cdef np.ndarray[DTYPE_t, ndim=2] _get_slice_indices_none_mode(np.ndarray[DTYPE_t, ndim=1] sizes, int block_size):
cdef DTYPE_t total_size = sizes.sum()
cdef DTYPE_t length = <DTYPE_t> ceil(total_size / <double> block_size)
cdef np.ndarray[DTYPE_t, ndim=2] slice_indices = np.zeros([length, 2], dtype=DTYPE)
cdef DTYPE_t[:, :] slice_indices_view = slice_indices
cdef DTYPE_t i
cdef DTYPE_t start
cdef DTYPE_t end
for i in range(length):
start = i * block_size
end = min(start + block_size, total_size)
slice_indices_view[i][0] = start
slice_indices_view[i][1] = end
return slice_indices
cdef np.ndarray[DTYPE_t, ndim=2] _fast_convert_to_np_array(list list_of_list):
"""
Faster function to convert a DTYPE_t list of lists into a 2-D array.
Only fast when there is a huge number of rows and a low number of columns.
"""
cdef np.ndarray[DTYPE_t, ndim=1] flat = np.fromiter(chain.from_iterable(list_of_list), DTYPE, -1)
return flat.reshape((len(list_of_list), -1))
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)
cpdef np.ndarray[DTYPE_t, ndim=2] _get_slice_indices_fast(np.ndarray[DTYPE_t, ndim=1] sizes, str break_mode, int block_size, int document_sep_len):
cdef DTYPE_t tok_idx = 0
cdef DTYPE_t sz_idx = 0
cdef DTYPE_t curr_size = 0
cdef DTYPE_t i = 0
cdef DTYPE_t length
cdef DTYPE_t total_size
cdef DTYPE_t[:] sizes_view = sizes
cdef np.ndarray[DTYPE_t, ndim=2] slice_indices
cdef list slice_indices_list = []
if break_mode is None or break_mode == 'none':
slice_indices = _get_slice_indices_none_mode(sizes, block_size)
elif break_mode == 'complete':
while sz_idx < len(sizes_view):
if curr_size + sizes_view[sz_idx] <= block_size or curr_size == 0:
curr_size += sizes_view[sz_idx]
sz_idx += 1
else:
slice_indices_list.append((tok_idx, tok_idx + curr_size))
tok_idx += curr_size
curr_size = 0
if curr_size > 0:
slice_indices_list.append((tok_idx, tok_idx + curr_size))
slice_indices = _fast_convert_to_np_array(slice_indices_list)
elif break_mode == 'complete_doc':
while sz_idx < len(sizes_view):
if (
(curr_size + sizes_view[sz_idx] <= block_size or curr_size == 0)
# an empty sentence indicates end-of-document:
and sizes_view[sz_idx] != document_sep_len
):
curr_size += sizes_view[sz_idx]
sz_idx += 1
else:
# Only keep non-empty documents.
if curr_size > 1:
slice_indices_list.append((tok_idx, tok_idx + curr_size))
tok_idx += curr_size
curr_size = 0
if sizes_view[sz_idx] == document_sep_len:
tok_idx += sizes_view[sz_idx]
sz_idx += 1
if curr_size > 1:
slice_indices_list.append((tok_idx, tok_idx + curr_size))
slice_indices = _fast_convert_to_np_array(slice_indices_list)
elif break_mode == 'eos':
slice_indices = np.zeros((len(sizes), 2), dtype=DTYPE)
cumsum = sizes.cumsum(axis=0)
slice_indices[1:, 0] = cumsum[:cumsum.shape[0] - 1]
slice_indices[:, 1] = cumsum
else:
raise ValueError('Invalid break_mode: ' + break_mode)
return slice_indices
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)
cpdef np.ndarray[DTYPE_t, ndim=2] _get_block_to_dataset_index_fast(np.ndarray[DTYPE_t, ndim=1] sizes, np.ndarray[DTYPE_t, ndim=2] slice_indices):
cdef DTYPE_t start_ds_idx
cdef DTYPE_t start_offset
cdef DTYPE_t end_ds_idx
cdef DTYPE_t i
cdef DTYPE_t s
cdef DTYPE_t e
cdef DatasetSearcher ds = DatasetSearcher(sizes)
cdef np.ndarray[DTYPE_t, ndim=2] block_to_dataset_index = np.zeros([len(slice_indices), 3], dtype=DTYPE)
cdef DTYPE_t[:, :] block_to_dataset_index_view = block_to_dataset_index
cdef DTYPE_t[:, :] slice_indices_view = slice_indices
cdef Py_ssize_t x_max = slice_indices.shape[0]
for i in range(x_max):
s = slice_indices_view[i][0]
e = slice_indices_view[i][1]
ds.seek(s)
start_ds_idx = ds.current_index
start_offset = ds.current_offset
if e <= s:
end_ds_idx = start_ds_idx
else:
ds.seek(e - 1)
end_ds_idx = ds.current_index
block_to_dataset_index_view[i][0] = start_ds_idx # starting index in dataset
block_to_dataset_index_view[i][1] = start_offset # starting offset within starting index
block_to_dataset_index_view[i][2] = end_ds_idx # ending index in dataset
return block_to_dataset_index
cdef class DatasetSearcher(object):
"""Helper for mapping "flat" indices to indices and offsets in an
underlying dataset."""
cdef DTYPE_t current_i
cdef DTYPE_t current_offset
cdef DTYPE_t current_index
cdef DTYPE_t[:] sizes
def __init__(self, DTYPE_t[:] sizes):
self.sizes = sizes
self.reset()
cdef reset(self):
self.current_offset = 0 # offset within current index in underlying dataset
self.current_i = 0 # "flat" index
self.current_index = 0 # index in underlying dataset
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)
cdef int step(self, DTYPE_t i):
cdef DTYPE_t to_consume
cdef DTYPE_t remaining
if i < self.current_i:
self.reset()
if i > self.current_i:
to_consume = i - self.current_i
remaining = self.sizes[self.current_index] - self.current_offset
if remaining > to_consume:
self.current_offset += to_consume
self.current_i += to_consume
else:
assert remaining >= 0
self.current_i += remaining
self.current_index += 1
self.current_offset = 0
return 1
return 0
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)
cdef seek(self, DTYPE_t i):
cdef int not_done = 1
while not_done == 1:
not_done = self.step(i)
assert self.current_i == i
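# --- Editor's sketch (not part of the original file) --------------------------
# A minimal illustration of how the two cpdef helpers above are combined once this
# extension is compiled (it is normally importable as
# fairseq.data.token_block_utils_fast). The sizes array and block size are made up.
def _example_token_block_indexing():
    import numpy as np
    from fairseq.data.token_block_utils_fast import (
        _get_block_to_dataset_index_fast,
        _get_slice_indices_fast,
    )
    # three sentences of 5, 3 and 7 tokens, packed into blocks of at most 8 tokens
    sizes = np.array([5, 3, 7], dtype=np.int64)
    # each row is (start_token, end_token) in the flattened token stream
    slice_indices = _get_slice_indices_fast(sizes, "complete", 8, 1)
    # each row is (start_dataset_index, start_offset, end_dataset_index)
    return _get_block_to_dataset_index_fast(sizes, slice_indices)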
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from . import FairseqDataset
class TransformEosDataset(FairseqDataset):
"""A :class:`~fairseq.data.FairseqDataset` wrapper that appends/prepends/strips EOS.
Note that the transformation is applied in :func:`collater`.
Args:
dataset (~fairseq.data.FairseqDataset): dataset to wrap
eos (int): index of the end-of-sentence symbol
append_eos_to_src (bool, optional): append EOS to the end of src
remove_eos_from_src (bool, optional): remove EOS from the end of src
append_eos_to_tgt (bool, optional): append EOS to the end of tgt
remove_eos_from_tgt (bool, optional): remove EOS from the end of tgt
"""
def __init__(
self,
dataset,
eos,
append_eos_to_src=False,
remove_eos_from_src=False,
append_eos_to_tgt=False,
remove_eos_from_tgt=False,
has_target=True,
):
if not isinstance(dataset, FairseqDataset):
raise ValueError("dataset must be an instance of FairseqDataset")
if append_eos_to_src and remove_eos_from_src:
raise ValueError("cannot combine append_eos_to_src and remove_eos_from_src")
if append_eos_to_tgt and remove_eos_from_tgt:
raise ValueError("cannot combine append_eos_to_tgt and remove_eos_from_tgt")
self.dataset = dataset
self.eos = torch.LongTensor([eos])
self.append_eos_to_src = append_eos_to_src
self.remove_eos_from_src = remove_eos_from_src
self.append_eos_to_tgt = append_eos_to_tgt
self.remove_eos_from_tgt = remove_eos_from_tgt
self.has_target = has_target
# precompute how we should adjust the reported sizes
self._src_delta = 0
self._src_delta += 1 if append_eos_to_src else 0
self._src_delta -= 1 if remove_eos_from_src else 0
self._tgt_delta = 0
self._tgt_delta += 1 if append_eos_to_tgt else 0
self._tgt_delta -= 1 if remove_eos_from_tgt else 0
self._checked_src = False
self._checked_tgt = False
def _check_src(self, src, expect_eos):
if not self._checked_src:
assert (src[-1] == self.eos[0]) == expect_eos
self._checked_src = True
def _check_tgt(self, tgt, expect_eos):
if self.has_target and not self._checked_tgt:
assert (tgt[-1] == self.eos[0]) == expect_eos
self._checked_tgt = True
def __getitem__(self, index):
return self.dataset[index]
def __len__(self):
return len(self.dataset)
def collater(self, samples):
def transform(item):
if self.append_eos_to_src:
self.eos = self.eos.to(device=item["source"].device)
self._check_src(item["source"], expect_eos=False)
item["source"] = torch.cat([item["source"], self.eos])
if self.remove_eos_from_src:
self.eos = self.eos.to(device=item["source"].device)
self._check_src(item["source"], expect_eos=True)
item["source"] = item["source"][:-1]
if self.append_eos_to_tgt:
self.eos = self.eos.to(device=item["target"].device)
self._check_tgt(item["target"], expect_eos=False)
item["target"] = torch.cat([item["target"], self.eos])
if self.remove_eos_from_tgt:
self.eos = self.eos.to(device=item["target"].device)
self._check_tgt(item["target"], expect_eos=True)
item["target"] = item["target"][:-1]
return item
samples = list(map(transform, samples))
return self.dataset.collater(samples)
def num_tokens(self, index):
return self.dataset.num_tokens(index)
def size(self, index):
if self.has_target:
src_len, tgt_len = self.dataset.size(index)
return (src_len + self._src_delta, tgt_len + self._tgt_delta)
else:
return self.dataset.size(index)
def ordered_indices(self):
# NOTE: we assume that the ordering does not change based on the
# addition or removal of eos
return self.dataset.ordered_indices()
@property
def supports_prefetch(self):
return getattr(self.dataset, "supports_prefetch", False)
def prefetch(self, indices):
return self.dataset.prefetch(indices)
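# --- Editor's sketch (not part of the original file) --------------------------
# How this wrapper is typically used: `base_dataset` is assumed to be any
# FairseqDataset whose source sentences already end in EOS, and `eos_index` would
# come from the dictionary (e.g. dictionary.eos()).
def _example_transform_eos(base_dataset, eos_index):
    wrapped = TransformEosDataset(
        base_dataset,
        eos=eos_index,
        remove_eos_from_src=True,  # strip EOS from the end of each source sentence
        has_target=True,
    )
    # the EOS transformation is only applied when a batch is collated
    samples = [wrapped[i] for i in range(min(2, len(wrapped)))]
    return wrapped.collater(samples)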
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional
import torch
from . import FairseqDataset
class TransformEosLangPairDataset(FairseqDataset):
"""A :class:`~fairseq.data.FairseqDataset` wrapper that transform bos on
collated samples of language pair dataset.
Note that the transformation is applied in :func:`collater`.
Args:
dataset (~fairseq.data.FairseqDataset): dataset that collates sample into
LanguagePairDataset schema
src_eos (int): original source end-of-sentence symbol index to be replaced
new_src_eos (int, optional): new end-of-sentence symbol index to replace source eos symbol
tgt_bos (int, optional): original target beginning-of-sentence symbol index to be replaced
new_tgt_bos (int, optional): new beginning-of-sentence symbol index to replace at the
beginning of 'prev_output_tokens'
"""
def __init__(
self,
dataset: FairseqDataset,
src_eos: int,
new_src_eos: Optional[int] = None,
tgt_bos: Optional[int] = None,
new_tgt_bos: Optional[int] = None,
):
self.dataset = dataset
self.src_eos = src_eos
self.new_src_eos = new_src_eos
self.tgt_bos = tgt_bos
self.new_tgt_bos = new_tgt_bos
def __getitem__(self, index):
return self.dataset[index]
def __len__(self):
return len(self.dataset)
def collater(self, samples, **extra_args):
samples = self.dataset.collater(samples, **extra_args)
if self.new_src_eos is not None:
if self.dataset.left_pad_source:
assert (
samples["net_input"]["src_tokens"][:, -1] != self.src_eos
).sum() == 0
samples["net_input"]["src_tokens"][:, -1] = self.new_src_eos
else:
eos_idx = samples["net_input"]["src_lengths"] - 1
assert (
samples["net_input"]["src_tokens"][
torch.arange(eos_idx.size(0)), eos_idx
]
!= self.src_eos
).sum() == 0
eos_idx = eos_idx.resize_(len(samples["net_input"]["src_lengths"]), 1)
samples["net_input"]["src_tokens"].scatter_(
1, eos_idx, self.new_src_eos
)
if (
self.new_tgt_bos is not None
and "prev_output_tokens" in samples["net_input"]
):
if self.dataset.left_pad_target:
# TODO: support different padding direction on target side
raise NotImplementedError(
"TransformEosLangPairDataset does not implement --left-pad-target True option"
)
else:
assert (
samples["net_input"]["prev_output_tokens"][:, 0] != self.tgt_bos
).sum() == 0
samples["net_input"]["prev_output_tokens"][:, 0] = self.new_tgt_bos
return samples
def num_tokens(self, index):
return self.dataset.num_tokens(index)
def size(self, index):
return self.dataset.size(index)
@property
def sizes(self):
# dataset.sizes can be a dynamically computed sizes:
return self.dataset.sizes
def ordered_indices(self):
return self.dataset.ordered_indices()
@property
def supports_prefetch(self):
return getattr(self.dataset, "supports_prefetch", False)
def prefetch(self, indices):
return self.dataset.prefetch(indices)
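# --- Editor's sketch (not part of the original file) --------------------------
# A typical multilingual use: swap the source EOS and the symbol at the start of
# prev_output_tokens for a language-specific token. `lang_pair_dataset` and the
# index arguments are assumptions for illustration only.
def _example_transform_eos_lang_pair(lang_pair_dataset, src_eos, tgt_bos, lang_tok):
    return TransformEosLangPairDataset(
        lang_pair_dataset,
        src_eos=src_eos,       # original source EOS index
        new_src_eos=lang_tok,  # e.g. a target-language token index
        tgt_bos=tgt_bos,       # symbol currently at the start of prev_output_tokens
        new_tgt_bos=lang_tok,
    )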
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .utils import ChoiceEnum, FairseqDataclass
__all__ = ["FairseqDataclass", "ChoiceEnum"]
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from fairseq.dataclass.utils import ChoiceEnum
LOG_FORMAT_CHOICES = ChoiceEnum(["json", "none", "simple", "tqdm"])
DDP_BACKEND_CHOICES = ChoiceEnum(["c10d", "no_c10d"])
DISTRIBUTED_WRAPPER_CHOICES = ChoiceEnum(["DDP", "SlowMo"])
ZERO_SHARDING_CHOICES = ChoiceEnum(["none", "os"])
PIPELINE_CHECKPOINT_CHOICES = ChoiceEnum(["always", "never", "except_last"])
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import sys
from argparse import Namespace
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple, Type
import torch
from fairseq.criterions import CRITERION_DATACLASS_REGISTRY
from fairseq.data.indexed_dataset import get_available_dataset_impl
from fairseq.dataclass.constants import (
DDP_BACKEND_CHOICES,
DISTRIBUTED_WRAPPER_CHOICES,
LOG_FORMAT_CHOICES,
PIPELINE_CHECKPOINT_CHOICES,
ZERO_SHARDING_CHOICES,
)
from fairseq.dataclass.utils import ChoiceEnum, FairseqDataclass
from fairseq.models import ARCH_MODEL_REGISTRY, MODEL_DATACLASS_REGISTRY
from fairseq.optim import OPTIMIZER_DATACLASS_REGISTRY
from fairseq.optim.bmuf import FairseqBMUFConfig
from fairseq.optim.lr_scheduler import LR_SCHEDULER_DATACLASS_REGISTRY
from fairseq.tasks import TASK_DATACLASS_REGISTRY
from hydra.core.config_store import ConfigStore
@dataclass
class CommonParams(FairseqDataclass):
# This is the core dataclass that holds the common parameters shared by all jobs.
# Please add parameters to the other dataclasses if they are used only for a
# particular purpose or task, e.g. those dedicated to `distributed training` or
# `optimization`.
no_progress_bar: bool = field(
default=False, metadata={"help": "disable progress bar"}
)
log_interval: int = field(
default=100,
metadata={
"help": "log progress every N batches (when progress bar is disabled)"
},
)
log_format: Optional[LOG_FORMAT_CHOICES] = field(
default=None, metadata={"help": "log format to use"}
)
tensorboard_logdir: Optional[str] = field(
default=None,
metadata={
"help": "path to save logs for tensorboard, should match --logdir "
"of running tensorboard (default: no tensorboard logging)"
},
)
seed: int = field(
default=1, metadata={"help": "pseudo random number generator seed"}
)
cpu: bool = field(default=False, metadata={"help": "use CPU instead of CUDA"})
tpu: bool = field(default=False, metadata={"help": "use TPU instead of CUDA"})
bf16: bool = field(default=False, metadata={"help": "use bfloat16; implies --tpu"})
memory_efficient_bf16: bool = field(
default=False,
metadata={
"help": "use a memory-efficient version of BF16 training; implies --bf16"
},
)
fp16: bool = field(default=False, metadata={"help": "use FP16"})
memory_efficient_fp16: bool = field(
default=False,
metadata={
"help": "use a memory-efficient version of FP16 training; implies --fp16"
},
)
fp16_no_flatten_grads: bool = field(
default=False, metadata={"help": "don't flatten FP16 grads tensor"}
)
fp16_init_scale: int = field(
default=2 ** 7, metadata={"help": "default FP16 loss scale"}
)
fp16_scale_window: Optional[int] = field(
default=None,
metadata={"help": "number of updates before increasing loss scale"},
)
fp16_scale_tolerance: float = field(
default=0.0,
metadata={
"help": "pct of updates that can overflow before decreasing the loss scale"
},
)
min_loss_scale: float = field(
default=1e-4,
metadata={"help": "minimum FP16 loss scale, after which training is stopped"},
)
threshold_loss_scale: Optional[float] = field(
default=None, metadata={"help": "threshold FP16 loss scale from below"}
)
user_dir: Optional[str] = field(
default=None,
metadata={
"help": "path to a python module containing custom extensions (tasks and/or architectures)"
},
)
empty_cache_freq: int = field(
default=0,
metadata={"help": "how often to clear the PyTorch CUDA cache (0 to disable)"},
)
all_gather_list_size: int = field(
default=16384,
metadata={"help": "number of bytes reserved for gathering stats from workers"},
)
model_parallel_size: int = field(
default=1, metadata={"help": "total number of GPUs to parallelize model over"}
)
checkpoint_suffix: str = field(
default="", metadata={"help": "suffix to add to the checkpoint file name"}
)
checkpoint_shard_count: int = field(
default=1,
metadata={
"help": "Number of shards containing the checkpoint - "
"if the checkpoint is over 300GB, it is preferable "
"to split it into shards to prevent OOM on CPU while loading "
"the checkpoint"
},
)
quantization_config_path: Optional[str] = field(
default=None, metadata={"help": "path to quantization config file"}
)
profile: bool = field(
default=False, metadata={"help": "enable autograd profiler emit_nvtx"}
)
@dataclass
class DistributedTrainingParams(FairseqDataclass):
distributed_world_size: int = field(
default=max(1, torch.cuda.device_count()),
metadata={
"help": "total number of GPUs across all nodes (default: all visible GPUs)"
},
)
distributed_rank: Optional[int] = field(
default=0, metadata={"help": "rank of the current worker"}
)
distributed_backend: str = field(
default="nccl", metadata={"help": "distributed backend"}
)
distributed_init_method: Optional[str] = field(
default=None,
metadata={
"help": "typically tcp://hostname:port that will be used to "
"establish initial connetion"
},
)
distributed_port: int = field(
default=-1,
metadata={
"help": "port number (not required if using --distributed-init-method)"
},
)
device_id: int = field(
default=0,
metadata={
"help": "which GPU to use (usually configured automatically)",
"argparse_alias": "--local_rank",
},
)
distributed_no_spawn: bool = field(
default=False,
metadata={
"help": "do not spawn multiple processes even if multiple GPUs are visible"
},
)
ddp_backend: DDP_BACKEND_CHOICES = field(
default="c10d", metadata={"help": "DistributedDataParallel backend"}
)
bucket_cap_mb: int = field(
default=25, metadata={"help": "bucket size for reduction"}
)
fix_batches_to_gpus: bool = field(
default=False,
metadata={
"help": "don't shuffle batches between GPUs; this reduces overall "
"randomness and may affect precision but avoids the cost of re-reading the data"
},
)
find_unused_parameters: bool = field(
default=False,
metadata={
"help": "disable unused parameter detection (not applicable to "
"no_c10d ddp-backend"
},
)
fast_stat_sync: bool = field(
default=False,
metadata={"help": "[deprecated] this is now defined per Criterion"},
)
broadcast_buffers: bool = field(
default=False,
metadata={
"help": "Copy non-trainable parameters between GPUs, such as "
"batchnorm population statistics"
},
)
distributed_wrapper: DISTRIBUTED_WRAPPER_CHOICES = field(
default="DDP", metadata={"help": "DistributedDataParallel backend"}
)
slowmo_momentum: Optional[float] = field(
default=None,
metadata={
"help": "SlowMo momentum term; by default use 0.0 for 16 GPUs, "
"0.2 for 32 GPUs; 0.5 for 64 GPUs, 0.6 for > 64 GPUs"
},
)
slowmo_algorithm: str = field(
default="LocalSGD", metadata={"help": "whether to use LocalSGD or SGP"}
)
localsgd_frequency: int = field(
default=3, metadata={"help": "Local SGD allreduce frequency"}
)
nprocs_per_node: int = field(
default=max(1, torch.cuda.device_count()),
metadata={
"help": "number of GPUs in each node. An allreduce operation across GPUs in "
"a node is very fast. Hence, we do allreduce across GPUs in a node, "
"and gossip across different nodes"
},
)
pipeline_model_parallel: bool = field(
default=False,
metadata={"help": "if set, use pipeline model parallelism across GPUs"},
)
pipeline_balance: str = field(
default=None,
metadata={
"help": "partition the model into N_K pieces, where each piece "
"contains N_i layers. The sum(args.pipeline_balance) "
"should equal the total number of layers in the model"
},
)
pipeline_devices: str = field(
default=None,
metadata={
"help": "a list of device indices indicating which device to place "
"each of the N_K partitions. The length of this list should "
"equal the length of the --pipeline-balance argument"
},
)
pipeline_chunks: int = field(
default=0, metadata={"help": "microbatch count for pipeline model parallelism"}
)
pipeline_encoder_balance: str = field(
default=None,
metadata={
"help": "partition the pipeline parallel encoder into N_K pieces, where each piece "
"contains N_i layers. The sum(args.pipeline_encoder_balance) "
"should equal the total number of encoder layers in the model"
},
)
pipeline_encoder_devices: str = field(
default=None,
metadata={
"help": "a list of device indices indicating which device to place "
"each of the N_K partitions. The length of this list should "
"equal the length of the --pipeline-encoder-balance argument"
},
)
pipeline_decoder_balance: str = field(
default=None,
metadata={
"help": "partition the pipeline parallel decoder into N_K pieces, where each piece "
"contains N_i layers. The sum(args.pipeline_decoder_balance) "
"should equal the total number of decoder layers in the model"
},
)
pipeline_decoder_devices: str = field(
default=None,
metadata={
"help": "a list of device indices indicating which device to place "
"each of the N_K partitions. The length of this list should "
"equal the length of the --pipeline-decoder-balance argument"
},
)
pipeline_checkpoint: PIPELINE_CHECKPOINT_CHOICES = field(
default="never",
metadata={"help": "checkpointing mode for pipeline model parallelism"},
)
zero_sharding: ZERO_SHARDING_CHOICES = field(
default="none", metadata={"help": "ZeRO sharding"}
)
@dataclass
class DatasetParams(FairseqDataclass):
num_workers: int = field(
default=1, metadata={"help": "how many subprocesses to use for data loading"}
)
skip_invalid_size_inputs_valid_test: bool = field(
default=False,
metadata={"help": "ignore too long or too short lines in valid and test set"},
)
max_tokens: Optional[int] = field(
default=None, metadata={"help": "maximum number of tokens in a batch"}
)
batch_size: Optional[int] = field(
default=None, metadata={"help": "number of examples in a batch"}
)
required_batch_size_multiple: int = field(
default=8, metadata={"help": "batch size will be a multiplier of this value"}
)
required_seq_len_multiple: int = field(
default=1,
metadata={
"help": "maximum sequence length in batch will be a multiplier of this value"
},
)
dataset_impl: Optional[ChoiceEnum(get_available_dataset_impl())] = field(
default=None, metadata={"help": "output dataset implementation"}
)
data_buffer_size: int = field(
default=10, metadata={"help": "Number of batches to preload"}
)
train_subset: str = field(
default="train",
metadata={"help": "data subset to use for training (e.g. train, valid, test)"},
)
valid_subset: str = field(
default="valid",
metadata={
"help": "comma separated list of data subsets to use for validation"
" (e.g. train, valid, test)"
},
)
validate_interval: int = field(
default=1, metadata={"help": "validate every N epochs"}
)
validate_interval_updates: int = field(
default=0, metadata={"help": "validate every N updates"}
)
validate_after_updates: int = field(
default=0, metadata={"help": "dont validate until reaching this many updates"}
)
fixed_validation_seed: Optional[int] = field(
default=None, metadata={"help": "specified random seed for validation"}
)
disable_validation: bool = field(
default=False, metadata={"help": "disable validation"}
)
max_tokens_valid: Optional[int] = field(
default=None,
metadata={
"help": "maximum number of tokens in a validation batch"
" (defaults to --max-tokens)"
},
)
batch_size_valid: Optional[int] = field(
default=None,
metadata={
"help": "batch size of the validation batch" " (defaults to --batch-size)"
},
)
curriculum: int = field(
default=0, metadata={"help": "don't shuffle batches for first N epochs"}
)
gen_subset: str = field(
default="test",
metadata={"help": "data subset to generate (train, valid, test)"},
)
num_shards: int = field(
default=1, metadata={"help": "shard generation over N shards"}
)
shard_id: int = field(
default=0, metadata={"help": "id of the shard to generate (id < num_shards)"}
)
@dataclass
class OptimizationParams(FairseqDataclass):
max_epoch: int = field(
default=0, metadata={"help": "force stop training at specified epoch"}
)
max_update: int = field(
default=0, metadata={"help": "force stop training at specified update"}
)
stop_time_hours: float = field(
default=0,
metadata={
"help": "force stop training after specified cumulative time (if >0)"
},
)
clip_norm: float = field(
default=0.0, metadata={"help": "clip threshold of gradients"}
)
sentence_avg: bool = field(
default=False,
metadata={
"help": "normalize gradients by the number of sentences in a batch"
" (default is to normalize by number of tokens)"
},
)
update_freq: List[int] = field(
default_factory=lambda: [1],
metadata={"help": "update parameters every N_i batches, when in epoch i"},
)
lr: List[float] = field(
default_factory=lambda: [0.25],
metadata={
"help": "learning rate for the first N epochs; all epochs >N using LR_N"
" (note: this may be interpreted differently depending on --lr-scheduler)"
},
)
min_lr: float = field(
default=-1.0,
metadata={"help": "stop training when the learning rate reaches this minimum"},
)
use_bmuf: bool = field(
default=False,
metadata={
"help": "specify global optimizer for syncing models on different GPUs/shards"
},
)
@dataclass
class CheckpointParams(FairseqDataclass):
save_dir: str = field(
default="checkpoints", metadata={"help": "path to save checkpoints"}
)
restore_file: str = field(
default="checkpoint_last.pt",
metadata={
"help": "filename from which to load checkpoint "
"(default: <save-dir>/checkpoint_last.pt"
},
)
finetune_from_model: Optional[str] = field(
default=None,
metadata={
"help": "finetune from a pretrained model; note that meters and lr scheduler will be reset"
},
)
reset_dataloader: bool = field(
default=False,
metadata={
"help": "if set, does not reload dataloader state from the checkpoint"
},
)
reset_lr_scheduler: bool = field(
default=False,
metadata={
"help": "if set, does not load lr scheduler state from the checkpoint"
},
)
reset_meters: bool = field(
default=False,
metadata={"help": "if set, does not load meters from the checkpoint"},
)
reset_optimizer: bool = field(
default=False,
metadata={"help": "if set, does not load optimizer state from the checkpoint"},
)
optimizer_overrides: str = field(
default="{}",
metadata={
"help": "a dictionary used to override optimizer args when loading a checkpoint"
},
)
save_interval: int = field(
default=1, metadata={"help": "save a checkpoint every N epochs"}
)
save_interval_updates: int = field(
default=0, metadata={"help": "save a checkpoint (and validate) every N updates"}
)
keep_interval_updates: int = field(
default=-1,
metadata={
"help": "keep the last N checkpoints saved with --save-interval-updates"
},
)
keep_last_epochs: int = field(
default=-1, metadata={"help": "keep last N epoch checkpoints"}
)
keep_best_checkpoints: int = field(
default=-1, metadata={"help": "keep best N checkpoints based on scores"}
)
no_save: bool = field(
default=False, metadata={"help": "don't save models or checkpoints"}
)
no_epoch_checkpoints: bool = field(
default=False, metadata={"help": "only store last and best checkpoints"}
)
no_last_checkpoints: bool = field(
default=False, metadata={"help": "don't store last checkpoints"}
)
no_save_optimizer_state: bool = field(
default=False,
metadata={"help": "don't save optimizer-state as part of checkpoint"},
)
best_checkpoint_metric: str = field(
default="loss", metadata={"help": 'metric to use for saving "best" checkpoints'}
)
maximize_best_checkpoint_metric: bool = field(
default=False,
metadata={
"help": 'select the largest metric value for saving "best" checkpoints'
},
)
patience: int = field(
default=-1,
metadata={
"help": (
"early stop training if valid performance doesn't "
"improve for N consecutive validation runs; note "
"that this is influenced by --validate-interval"
)
},
)
@dataclass
class CommonEvalParams(FairseqDataclass):
path: Optional[str] = field(
default=None, metadata={"help": "path(s) to model file(s), colon separated"}
)
remove_bpe: Optional[str] = field(
default=None,
metadata={
"help": "remove BPE tokens before scoring (can be set to sentencepiece)",
"argparse_const": "@@ ",
},
)
quiet: bool = field(default=False, metadata={"help": "only print final scores"})
model_overrides: str = field(
default="{}",
metadata={
"help": "a dictionary used to override model args at generation that were used during model training"
},
)
results_path: Optional[str] = field(
default=None, metadata={"help": "path to save eval results (optional)"}
)
@dataclass
class EvalLMParams(FairseqDataclass):
output_word_probs: bool = field(
default=False,
metadata={
"help": "if set, outputs words and their predicted log probabilities to standard output"
},
)
output_word_stats: bool = field(
default=False,
metadata={
"help": "if set, outputs word statistics such as word count, average probability, etc"
},
)
context_window: int = field(
default=0,
metadata={
"help": "ensures that every evaluated token has access to a context of at least this size, if possible"
},
)
softmax_batch: int = field(
default=sys.maxsize,
metadata={
"help": "if BxT is more than this, will batch the softmax over vocab to this amount of tokens, in order to fit into GPU memory"
},
)
@dataclass
class TrainingConfig(FairseqDataclass):
"""Config for training, a composition of training params"""
common: CommonParams = CommonParams()
distributed_training: DistributedTrainingParams = DistributedTrainingParams()
dataset: DatasetParams = DatasetParams()
optimization: OptimizationParams = OptimizationParams()
checkpoint: CheckpointParams = CheckpointParams()
bmuf: FairseqBMUFConfig = FairseqBMUFConfig()
@dataclass
class EvalLMConfig(FairseqDataclass):
"""Config for eval lm, a composition of eval_lm params"""
common: CommonParams = CommonParams()
distributed_training: DistributedTrainingParams = DistributedTrainingParams()
dataset: DatasetParams = DatasetParams()
optimization: OptimizationParams = OptimizationParams()
checkpoint: CheckpointParams = CheckpointParams()
bmuf: FairseqBMUFConfig = FairseqBMUFConfig()
common_eval: CommonEvalParams = CommonEvalParams()
eval_lm: EvalLMParams = EvalLMParams()
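# --- Editor's sketch (not part of the original file) --------------------------
# The composed configs above are plain dataclasses, so their defaults can be
# inspected (or tweaked) directly, without going through hydra or argparse.
def _example_training_config_defaults():
    cfg = TrainingConfig()
    assert cfg.common.seed == 1
    assert cfg.optimization.clip_norm == 0.0
    assert cfg.dataset.required_batch_size_multiple == 8
    return cfg.checkpoint.save_dir  # "checkpoints"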
def register_params_dataclass(
cs: ConfigStore, name: str, group: str, data_class: Type[FairseqDataclass]
) -> None:
"""register params dataclass in config store"""
node_ = data_class(_name=data_class.name())
cs.store(name=name, group=group, node=node_)
def register_module_dataclass(
cs: ConfigStore, registry: Dict[str, Any], group: str
) -> None:
"""register dataclasses defined in modules in config store, for example, in migrated tasks, models, etc."""
# note that if `group == model`, we register all model archs, not the model name.
for k, v in registry.items():
if v is not None:
node_ = v(_name=k)
cs.store(name=k, group=group, node=node_)
def register_training_hydra_cfg(cs: ConfigStore, name: str = "default") -> None:
"""cs: config store instance, register common training configs"""
register_params_dataclass(
cs, name="training_params", group="params", data_class=TrainingConfig
)
register_module_dataclass(cs, TASK_DATACLASS_REGISTRY, "task")
register_module_dataclass(cs, MODEL_DATACLASS_REGISTRY, "model")
register_module_dataclass(cs, CRITERION_DATACLASS_REGISTRY, "criterion")
register_module_dataclass(cs, OPTIMIZER_DATACLASS_REGISTRY, "optimizer")
register_module_dataclass(cs, LR_SCHEDULER_DATACLASS_REGISTRY, "lr_scheduler")
def register_eval_lm_hydra_cfg(cs: ConfigStore, name: str = "default") -> None:
"""cs: config store instance, register common training configs"""
register_params_dataclass(
cs, name="eval_lm_params", group="params", data_class=EvalLMConfig
)
register_module_dataclass(cs, TASK_DATACLASS_REGISTRY, "task")
register_module_dataclass(cs, CRITERION_DATACLASS_REGISTRY, "criterion")
register_module_dataclass(cs, OPTIMIZER_DATACLASS_REGISTRY, "optimizer")
register_module_dataclass(cs, LR_SCHEDULER_DATACLASS_REGISTRY, "lr_scheduler")
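# --- Editor's sketch (not part of the original file) --------------------------
# One plausible way to call the registration helpers above from a hydra entry
# point, using hydra's singleton config store.
def _example_register_schemas():
    cs = ConfigStore.instance()
    register_training_hydra_cfg(cs, name="default")
    register_eval_lm_hydra_cfg(cs, name="default")
    return cs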
def _override_attr(
sub_node: str, data_class: Type[FairseqDataclass], args: Namespace
) -> List[str]:
overrides = []
for k in data_class.__dataclass_fields__.keys():
if k == "_name":
# private member, skip
continue
if not hasattr(args, k):
# print(f"cannot override {sub_node}.{k} since args does not have attribute {k}")
continue
if getattr(args, k) is None:
overrides.append("{}.{}=null".format(sub_node, k))
elif getattr(args, k) == "":
overrides.append("{}.{}=''".format(sub_node, k))
elif isinstance(getattr(args, k), str):
if (
getattr(args, k).startswith("[")
or getattr(args, k).startswith("(")
or getattr(args, k).startswith("{")
or ("," in getattr(args, k))
):
overrides.append("{}.{}='{}'".format(sub_node, k, getattr(args, k)))
else:
overrides.append("{}.{}={}".format(sub_node, k, getattr(args, k)))
else:
overrides.append("{}.{}={}".format(sub_node, k, getattr(args, k)))
return overrides
def override_training_args(args: Namespace) -> Tuple[List[str], List[str]]:
overrides = []
overrides.extend(_override_attr("params.common", CommonParams, args))
overrides.extend(_override_attr("params.dataset", DatasetParams, args))
overrides.extend(
_override_attr("params.distributed_training", DistributedTrainingParams, args)
)
overrides.extend(_override_attr("params.optimization", OptimizationParams, args))
overrides.extend(_override_attr("params.checkpoint", CheckpointParams, args))
overrides.extend(_override_attr("params.bmuf", FairseqBMUFConfig, args))
module_overrides, module_deletes = override_module_args(args)
overrides.extend(module_overrides)
return overrides, module_deletes
def override_eval_lm_args(args: Namespace) -> Tuple[List[str], List[str]]:
overrides = []
overrides.extend(_override_attr("params.common", CommonParams, args))
overrides.extend(_override_attr("params.dataset", DatasetParams, args))
overrides.extend(
_override_attr("params.distributed_training", DistributedTrainingParams, args)
)
overrides.extend(_override_attr("params.common_eval", CommonEvalParams, args))
overrides.extend(_override_attr("params.eval_lm", EvalLMParams, args))
overrides.extend(_override_attr("params.bmuf", FairseqBMUFConfig, args))
module_overrides, module_deletes = override_module_args(args)
overrides.extend(module_overrides)
return overrides, module_deletes
def override_module_args(args: Namespace) -> Tuple[List[str], List[str]]:
"""use the field in args to overrides those in cfg"""
overrides = []
deletes = []
if args is not None:
assert (
hasattr(args, "task")
and hasattr(args, "criterion")
and hasattr(args, "optimizer")
and hasattr(args, "lr_scheduler")
)
if args.task in TASK_DATACLASS_REGISTRY:
overrides.append("task={}".format(args.task))
overrides.append("task._name={}".format(args.task))
overrides.extend(
_override_attr("task", TASK_DATACLASS_REGISTRY[args.task], args)
)
else:
deletes.append("task")
if args.criterion in CRITERION_DATACLASS_REGISTRY:
overrides.append("criterion={}".format(args.criterion))
overrides.append("criterion._name={}".format(args.criterion))
overrides.extend(
_override_attr(
"criterion", CRITERION_DATACLASS_REGISTRY[args.criterion], args
)
)
else:
deletes.append("criterion")
if args.optimizer in OPTIMIZER_DATACLASS_REGISTRY:
overrides.append("optimizer={}".format(args.optimizer))
overrides.append("optimizer._name={}".format(args.optimizer))
overrides.extend(
_override_attr(
"optimizer", OPTIMIZER_DATACLASS_REGISTRY[args.optimizer], args
)
)
else:
deletes.append("optimizer")
if args.lr_scheduler in LR_SCHEDULER_DATACLASS_REGISTRY:
overrides.append("lr_scheduler={}".format(args.lr_scheduler))
overrides.append("lr_scheduler._name={}".format(args.lr_scheduler))
overrides.extend(
_override_attr(
"lr_scheduler",
LR_SCHEDULER_DATACLASS_REGISTRY[args.lr_scheduler],
args,
)
)
else:
deletes.append("lr_scheduler")
no_dc = True
if hasattr(args, "arch"):
if args.arch in ARCH_MODEL_REGISTRY:
m_cls = ARCH_MODEL_REGISTRY[args.arch]
dc = getattr(m_cls, "__dataclass", None)
if dc is not None:
overrides.append("model={}".format(args.arch))
overrides.append("model._name={}".format(args.arch))
# override model params with those that exist in args
overrides.extend(_override_attr("model", dc, args))
no_dc = False
if no_dc:
deletes.append("model")
return overrides, deletes
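# --- Editor's sketch (not part of the original file) --------------------------
# A rough illustration of turning an argparse Namespace into hydra-style override
# strings. The attribute values are made up; in fairseq the Namespace would come
# from the regular command-line parsing. Components whose dataclasses are not
# registered end up in `deletes` instead of `overrides`.
def _example_override_training_args():
    args = Namespace(
        seed=7,
        task="translation",
        criterion="cross_entropy",
        optimizer="adam",
        lr_scheduler="inverse_sqrt",
        arch="transformer",
    )
    overrides, deletes = override_training_args(args)
    # overrides contains e.g. "params.common.seed=7"
    return overrides, deletes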
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from argparse import ArgumentError, ArgumentParser
from dataclasses import MISSING, dataclass
from enum import Enum
from typing import Any, Dict, List, Optional
def eval_str_list(x, x_type=float):
if x is None:
return None
if isinstance(x, str):
x = eval(x)
try:
return list(map(x_type, x))
except TypeError:
return [x_type(x)]
class StrEnum(Enum):
def __str__(self):
return self.value
def __eq__(self, other: str):
return self.value == other
def __repr__(self):
return self.value
def ChoiceEnum(choices: List[str]):
"""return the Enum class used to enforce list of choices"""
return StrEnum("Choices", {k: k for k in choices})
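# --- Editor's sketch (not part of the original file) --------------------------
# ChoiceEnum builds a string-valued Enum on the fly; members compare equal to
# their plain string value, which is what the argparse glue further below relies on.
def _example_choice_enum():
    colors = ChoiceEnum(["red", "green"])
    assert colors.red == "red"  # StrEnum.__eq__ compares against the raw value
    assert [c.value for c in colors] == ["red", "green"]
    return colors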
@dataclass
class FairseqDataclass:
"""fairseq base dataclass that supported fetching attributes and metas"""
_name: Optional[str] = None
@staticmethod
def name():
return None
def _get_all_attributes(self) -> List[str]:
return [k for k in self.__dataclass_fields__.keys()]
def _get_meta(
self, attribute_name: str, meta: str, default: Optional[Any] = None
) -> Any:
return self.__dataclass_fields__[attribute_name].metadata.get(meta, default)
def _get_name(self, attribute_name: str) -> str:
return self.__dataclass_fields__[attribute_name].name
def _get_default(self, attribute_name: str) -> Any:
if hasattr(self, attribute_name):
if str(getattr(self, attribute_name)).startswith("${"):
return str(getattr(self, attribute_name))
elif str(self.__dataclass_fields__[attribute_name].default).startswith(
"${"
):
return str(self.__dataclass_fields__[attribute_name].default)
elif (
getattr(self, attribute_name)
!= self.__dataclass_fields__[attribute_name].default
):
return getattr(self, attribute_name)
return self.__dataclass_fields__[attribute_name].default
def _get_default_factory(self, attribute_name: str) -> Any:
if hasattr(self, attribute_name):
if str(getattr(self, attribute_name)).startswith("${"):
return str(getattr(self, attribute_name))
elif str(self.__dataclass_fields__[attribute_name].default).startswith(
"${"
):
return str(self.__dataclass_fields__[attribute_name].default)
elif (
getattr(self, attribute_name)
!= self.__dataclass_fields__[attribute_name].default_factory()
):
return getattr(self, attribute_name)
return self.__dataclass_fields__[attribute_name].default_factory()
def _get_type(self, attribute_name: str) -> Any:
return self.__dataclass_fields__[attribute_name].type
def _get_help(self, attribute_name: str) -> Any:
return self._get_meta(attribute_name, "help")
def _get_argparse_const(self, attribute_name: str) -> Any:
return self._get_meta(attribute_name, "argparse_const")
def _get_argparse_alias(self, attribute_name: str) -> Any:
return self._get_meta(attribute_name, "argparse_alias")
def _get_choices(self, attribute_name: str) -> Any:
return self._get_meta(attribute_name, "choices")
def gen_parser_from_dataclass(
parser: ArgumentParser,
dataclass_instance: FairseqDataclass,
delete_default: bool = False,
) -> None:
"""convert a dataclass instance to tailing parser arguments"""
import re
def argparse_name(name: str):
if name == "data":
# normally data is positional args
return name
if name == "_name":
# private member, skip
return None
return "--" + name.replace("_", "-")
def interpret_dc_type(field_type):
if isinstance(field_type, str):
raise RuntimeError("field types should be actual types, not strings")
typestring = str(field_type)
if re.match(r"(typing.|^)Union\[(.*), NoneType\]$", typestring):
return field_type.__args__[0]
return field_type
def get_kwargs_from_dc(
dataclass_instance: FairseqDataclass, k: str
) -> Dict[str, Any]:
"""k: dataclass attributes"""
field_type = dataclass_instance._get_type(k)
inter_type = interpret_dc_type(field_type)
if isinstance(inter_type, type) and issubclass(inter_type, List):
field_default = dataclass_instance._get_default_factory(k)
else:
field_default = dataclass_instance._get_default(k)
if isinstance(inter_type, type) and issubclass(inter_type, Enum):
field_choices = [t.value for t in list(inter_type)]
else:
field_choices = None
field_help = dataclass_instance._get_help(k)
field_const = dataclass_instance._get_argparse_const(k)
kwargs = {}
if isinstance(field_default, str) and field_default.startswith("${"):
kwargs["default"] = field_default
else:
if field_default is MISSING:
kwargs["required"] = True
if field_choices is not None:
kwargs["choices"] = field_choices
if (isinstance(inter_type, type) and issubclass(inter_type, List)) or (
"List" in str(inter_type)
):
if "int" in str(inter_type):
kwargs["type"] = lambda x: eval_str_list(x, int)
elif "float" in str(inter_type):
kwargs["type"] = lambda x: eval_str_list(x, float)
elif "str" in str(inter_type):
kwargs["type"] = lambda x: eval_str_list(x, str)
else:
raise NotImplementedError()
if field_default is not MISSING:
kwargs["default"] = ",".join(map(str, field_default))
elif (
isinstance(inter_type, type) and issubclass(inter_type, Enum)
) or "Enum" in str(inter_type):
kwargs["type"] = str
if field_default is not MISSING:
if isinstance(field_default, Enum):
kwargs["default"] = field_default.value
else:
kwargs["default"] = field_default
elif inter_type is bool:
kwargs["action"] = (
"store_false" if field_default is True else "store_true"
)
kwargs["default"] = field_default
else:
kwargs["type"] = inter_type
if field_default is not MISSING:
kwargs["default"] = field_default
kwargs["help"] = field_help
if field_const is not None:
kwargs["const"] = field_const
kwargs["nargs"] = "?"
return kwargs
for k in dataclass_instance._get_all_attributes():
field_name = argparse_name(dataclass_instance._get_name(k))
if field_name is None:
continue
kwargs = get_kwargs_from_dc(dataclass_instance, k)
field_args = [field_name]
alias = dataclass_instance._get_argparse_alias(k)
if alias is not None:
field_args.append(alias)
if "default" in kwargs:
if isinstance(kwargs["default"], str) and kwargs["default"].startswith(
"${"
):
if kwargs["help"] is None:
# this is a field with a name that will be added elsewhere
continue
else:
del kwargs["default"]
if delete_default:
del kwargs["default"]
try:
parser.add_argument(*field_args, **kwargs)
except ArgumentError:
pass
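# --- Editor's sketch (not part of the original file) --------------------------
# End-to-end use of the helper above: declare a small FairseqDataclass and let it
# populate an ArgumentParser. _ExampleConfig is purely illustrative and not a real
# fairseq config.
from dataclasses import field  # needed only for this sketch
@dataclass
class _ExampleConfig(FairseqDataclass):
    dropout: float = field(default=0.1, metadata={"help": "dropout probability"})
    activation: ChoiceEnum(["relu", "gelu"]) = field(
        default="relu", metadata={"help": "activation function"}
    )
def _example_gen_parser():
    parser = ArgumentParser()
    gen_parser_from_dataclass(parser, _ExampleConfig())
    args = parser.parse_args(["--dropout", "0.3", "--activation", "gelu"])
    return args.dropout, args.activation  # (0.3, "gelu")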
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
import pickle
import random
import socket
import struct
import subprocess
import warnings
from collections import OrderedDict
from typing import Any, Dict, Mapping
import torch
import torch.distributed as dist
from fairseq import utils
logger = logging.getLogger(__name__)
def is_master(args):
return args.distributed_rank == 0
def infer_init_method(args, force_distributed=False):
if args.distributed_init_method is not None or getattr(args, "tpu", False):
return
if args.pipeline_model_parallel:
balance_exists = (
args.pipeline_balance is not None
or args.pipeline_encoder_balance is not None
or args.pipeline_decoder_balance is not None
)
devices_exist = (
args.pipeline_devices is not None
or args.pipeline_encoder_devices is not None
or args.pipeline_decoder_devices is not None
)
if not balance_exists:
raise ValueError(
"--pipeline-balance is currently required for pipeline model parallelism"
)
if not devices_exist:
raise ValueError(
"--pipeline-devices is currently required for pipeline model parallelism"
)
args.pipeline_balance = utils.eval_str_list(args.pipeline_balance, type=int)
if args.pipeline_devices is not None:
args.pipeline_devices = utils.eval_str_list(args.pipeline_devices, type=int)
num_pipeline_devices = len(set(args.pipeline_devices))
else:
args.pipeline_encoder_devices = utils.eval_str_list(
args.pipeline_encoder_devices, type=int
)
args.pipeline_decoder_devices = utils.eval_str_list(
args.pipeline_decoder_devices, type=int
)
num_pipeline_devices = len(
set(args.pipeline_encoder_devices + args.pipeline_decoder_devices)
)
gpus_per_node = torch.cuda.device_count()
assert (
gpus_per_node >= num_pipeline_devices
and gpus_per_node % num_pipeline_devices == 0
), (
"the number of unique device IDs in --pipeline-devices must evenly divide "
"the number of GPUs per node (multi-node pipelining is not yet supported)"
)
num_pipelines_per_node = gpus_per_node // num_pipeline_devices
# support torch.distributed.launch
if all(
key in os.environ
for key in ["MASTER_ADDR", "MASTER_PORT", "WORLD_SIZE", "RANK"]
):
args.distributed_init_method = "env://"
args.distributed_world_size = int(os.environ["WORLD_SIZE"])
args.distributed_rank = int(os.environ["RANK"])
# processes are created by torch.distributed.launch
args.distributed_no_spawn = True
# we can determine the init method automatically for Slurm
elif args.distributed_port > 0:
node_list = os.environ.get("SLURM_STEP_NODELIST")
if node_list is None:
node_list = os.environ.get("SLURM_JOB_NODELIST")
if node_list is not None:
try:
hostnames = subprocess.check_output(
["scontrol", "show", "hostnames", node_list]
)
args.distributed_init_method = "tcp://{host}:{port}".format(
host=hostnames.split()[0].decode("utf-8"),
port=args.distributed_port,
)
nnodes = int(os.environ.get("SLURM_NNODES"))
ntasks_per_node = os.environ.get("SLURM_NTASKS_PER_NODE")
if ntasks_per_node is not None:
ntasks_per_node = int(ntasks_per_node)
else:
ntasks = int(os.environ.get("SLURM_NTASKS"))
nnodes = int(os.environ.get("SLURM_NNODES"))
assert ntasks % nnodes == 0
ntasks_per_node = int(ntasks / nnodes)
if ntasks_per_node == 1:
gpus_per_node = torch.cuda.device_count()
node_id = int(os.environ.get("SLURM_NODEID"))
args.distributed_rank = node_id * gpus_per_node
args.distributed_world_size = nnodes * gpus_per_node
elif args.pipeline_model_parallel:
assert ntasks_per_node == num_pipelines_per_node, (
"SLURM --ntasks-per-node must match number of pipelines per "
"node (={})".format(num_pipelines_per_node)
)
args.distributed_no_spawn = True
# For 4-way MP on nodes with 8 GPUs, ranks will be [0, 1] on
# the first node, [2, 3] on the second node, etc. This
# matches torch.distributed.launch.
node_id = int(os.environ.get("SLURM_NODEID"))
local_id = int(os.environ.get("SLURM_LOCALID"))
args.distributed_rank = node_id * num_pipelines_per_node + local_id
# In the above example, device_id will always be in [0, 1],
# which also matches torch.distributed.launch.
args.device_id = local_id
# We also want to set distributed_world_size to be the total
# number of pipelines across all nodes.
args.distributed_world_size = nnodes * num_pipelines_per_node
else:
assert ntasks_per_node == args.distributed_world_size // nnodes
args.distributed_no_spawn = True
args.distributed_rank = int(os.environ.get("SLURM_PROCID"))
args.device_id = int(os.environ.get("SLURM_LOCALID"))
except subprocess.CalledProcessError as e: # scontrol failed
raise e
except FileNotFoundError: # Slurm is not installed
pass
elif args.distributed_world_size > 1 or force_distributed:
# fallback for single node with multiple GPUs
assert args.distributed_world_size <= torch.cuda.device_count()
port = random.randint(10000, 20000)
args.distributed_init_method = "tcp://localhost:{port}".format(port=port)
if args.pipeline_model_parallel:
if not args.distributed_no_spawn:
# When distributed_no_spawn is False, we expect distributed_rank and
# distributed_world_size to be based on the total number of GPUs, so
# we need to correct them to be based on the number of pipelines.
assert args.distributed_world_size % num_pipeline_devices == 0
args.distributed_world_size = (
args.distributed_world_size // num_pipeline_devices
)
# In the case of 4-way MP on nodes with 8 GPUs, we want
# distributed_rank to be the starting pipeline rank for each node,
# i.e., 0, 2, ...
assert args.distributed_rank % gpus_per_node == 0
assert args.distributed_rank % num_pipeline_devices == 0
args.distributed_rank = args.distributed_rank // num_pipeline_devices
# launch one process per pipeline
args.distributed_num_procs = num_pipelines_per_node
# if we have 4-way MP on a node with 8 GPUs, we want device_ids to be 0
# and 4, indicating the starting device IDs for each pipeline
args.device_id *= num_pipeline_devices
if args.device_id > 0:
# if there's multiple pipelines on a node (e.g., 4-way MP on an 8
# GPU node), we need to adjust pipeline_devices accordingly
logger.debug(
"setting CUDA device={} on rank {}".format(
args.device_id, args.distributed_rank
)
)
torch.cuda.set_device(args.device_id)
args.pipeline_devices = [args.device_id + d for d in args.pipeline_devices]
logger.info(
"setting pipeline_devices={} on rank {}".format(
args.pipeline_devices, args.distributed_rank
),
)
elif not args.distributed_no_spawn:
args.distributed_num_procs = min(
torch.cuda.device_count(),
args.distributed_world_size,
)
def distributed_init(args):
if not getattr(args, "tpu", False):
if torch.distributed.is_initialized():
warnings.warn(
"Distributed is already initialized, cannot initialize twice!"
)
else:
logger.info(
"distributed init (rank {}): {}".format(
args.distributed_rank,
args.distributed_init_method,
)
)
dist.init_process_group(
backend=args.distributed_backend,
init_method=args.distributed_init_method,
world_size=args.distributed_world_size,
rank=args.distributed_rank,
)
logger.info(
"initialized host {} as rank {}".format(
socket.gethostname(),
args.distributed_rank,
)
)
# perform a dummy all-reduce to initialize the NCCL communicator
if torch.cuda.is_available():
dist.all_reduce(torch.zeros(1).cuda())
args.distributed_rank = torch.distributed.get_rank()
else:
import torch_xla.core.xla_model as xm
assert xm.xrt_world_size() == args.distributed_world_size
args.device_id = xm.get_local_ordinal()
args.distributed_rank = xm.get_ordinal()
xm.rendezvous("distributed_init") # wait for all workers
xm.mark_step()
if not is_master(args):
logging.getLogger().setLevel(logging.WARNING)
if args.model_parallel_size > 1:
try:
from fairseq.model_parallel.megatron.mpu import (
get_model_parallel_rank,
initialize_model_parallel,
model_parallel_cuda_manual_seed,
)
except ImportError:
raise ImportError(
"\n\nPlease install the megatron submodule:"
"\n\n git submodule update --init "
"fairseq/model_parallel/megatron"
)
initialize_model_parallel(args.model_parallel_size)
model_parallel_cuda_manual_seed(args.seed)
model_part_number = get_model_parallel_rank()
args.checkpoint_suffix += "-model_part-{0}".format(model_part_number)
return args.distributed_rank
def distributed_main(i, main, args, kwargs):
args.device_id = i
if torch.cuda.is_available() and not args.cpu and not getattr(args, "tpu", False):
torch.cuda.set_device(args.device_id)
if args.distributed_rank is None: # torch.multiprocessing.spawn
args.distributed_rank = kwargs.pop("start_rank", 0) + i
args.distributed_rank = distributed_init(args)
after_distributed_init_fn = kwargs.pop("after_distributed_init_fn", None)
if after_distributed_init_fn:
args = after_distributed_init_fn(args)
main(args, **kwargs)
def call_main(args, main, **kwargs):
if args.distributed_init_method is None:
infer_init_method(args)
if args.distributed_init_method is not None:
# distributed training
if not args.distributed_no_spawn:
start_rank = args.distributed_rank
args.distributed_rank = None # assign automatically
kwargs["start_rank"] = start_rank
torch.multiprocessing.spawn(
fn=distributed_main,
args=(main, args, kwargs),
nprocs=args.distributed_num_procs,
)
else:
distributed_main(args.device_id, main, args, kwargs)
elif getattr(args, "tpu", False) and args.distributed_world_size > 1:
import torch_xla.distributed.xla_multiprocessing as xmp
torch.multiprocessing.set_sharing_strategy("file_system")
xmp.spawn(
fn=distributed_main,
args=(main, args, kwargs),
nprocs=8, # use all 8 TPU cores
)
else:
# single GPU main
main(args, **kwargs)
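# --- Editor's sketch (not part of the original file) --------------------------
# Minimal single-process use of call_main. In fairseq the Namespace would come from
# the regular option parsing; here it only carries the attributes that call_main
# and infer_init_method actually read on the single-GPU/CPU path.
def _example_call_main():
    from argparse import Namespace
    def my_main(args, **unused_kwargs):
        print("running main on rank", args.distributed_rank)
    args = Namespace(
        distributed_init_method=None,
        distributed_world_size=1,
        distributed_rank=0,
        distributed_port=-1,
        distributed_no_spawn=False,
        pipeline_model_parallel=False,
        tpu=False,
    )
    call_main(args, my_main)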
def get_rank():
return dist.get_rank()
def get_world_size():
return dist.get_world_size()
def get_default_group():
return dist.group.WORLD
def all_reduce(tensor, group=None):
if isinstance(group, tuple) and group[0] == "tpu":
import torch_xla.core.xla_model as xm
return xm.all_reduce("sum", [tensor], groups=group[1])
else:
if group is None:
group = get_default_group()
return dist.all_reduce(tensor, group=group)
def all_gather_list(data, group=None, max_size=16384):
"""Gathers arbitrary data from all nodes into a list.
Similar to :func:`~torch.distributed.all_gather` but for arbitrary Python
data. Note that *data* must be picklable.
Args:
data (Any): data from the local worker to be gathered on other workers
group (optional): group of the collective
max_size (int, optional): maximum size of the data to be gathered
across workers
"""
rank = get_rank()
world_size = get_world_size()
buffer_size = max_size * world_size
if (
not hasattr(all_gather_list, "_buffer")
or all_gather_list._buffer.numel() < buffer_size
):
all_gather_list._buffer = torch.cuda.ByteTensor(buffer_size)
all_gather_list._cpu_buffer = torch.ByteTensor(max_size).pin_memory()
buffer = all_gather_list._buffer
buffer.zero_()
cpu_buffer = all_gather_list._cpu_buffer
data = utils.move_to_cpu(data)
enc = pickle.dumps(data)
enc_size = len(enc)
header_size = 4 # size of header that contains the length of the encoded data
size = header_size + enc_size
if size > max_size:
raise ValueError(
"encoded data size ({}) exceeds max_size ({})".format(size, max_size)
)
header = struct.pack(">I", enc_size)
cpu_buffer[:size] = torch.ByteTensor(list(header + enc))
start = rank * max_size
buffer[start : start + size].copy_(cpu_buffer[:size])
all_reduce(buffer, group=group)
buffer = buffer.cpu()
try:
result = []
for i in range(world_size):
out_buffer = buffer[i * max_size : (i + 1) * max_size]
(enc_size,) = struct.unpack(">I", bytes(out_buffer[:header_size].tolist()))
if enc_size > 0:
result.append(
pickle.loads(
bytes(out_buffer[header_size : header_size + enc_size].tolist())
)
)
return result
except pickle.UnpicklingError:
raise Exception(
"Unable to unpickle data from other workers. all_gather_list requires all "
"workers to enter the function together, so this error usually indicates "
"that the workers have fallen out of sync somehow. Workers can fall out of "
"sync if one of them runs out of memory, or if there are other conditions "
"in your training script that can cause one worker to finish an epoch "
"while other workers are still iterating over their portions of the data. "
"Try rerunning with --ddp-backend=no_c10d and see if that helps."
)
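# --- Editor's sketch (not part of the original file) --------------------------
# all_gather_list is only meaningful after the process group has been initialized
# (e.g. via distributed_init); every worker contributes a picklable object and
# receives the list of all workers' objects back.
def _example_all_gather_list(args):
    stats = {"rank": args.distributed_rank, "nsentences": 32}
    gathered = all_gather_list(stats, max_size=16384)
    # gathered[i] is the `stats` dict contributed by rank i
    return gathered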
def all_reduce_dict(
data: Mapping[str, Any],
device,
group=None,
) -> Dict[str, Any]:
"""
AllReduce a dictionary of values across workers. We separately
reduce items that are already on the device and items on CPU for
better performance.
Args:
data (Mapping[str, Any]): dictionary of data to all-reduce, but
cannot be a nested dictionary
device (torch.device): device for the reduction
group (optional): group of the collective
"""
data_keys = list(data.keys())
# We want to separately reduce items that are already on the
# device and items on CPU for performance reasons.
cpu_data = OrderedDict()
device_data = OrderedDict()
for k in data_keys:
t = data[k]
if not torch.is_tensor(t):
cpu_data[k] = torch.tensor(t, dtype=torch.double)
elif t.device.type != device.type:
cpu_data[k] = t.to(dtype=torch.double)
else:
device_data[k] = t.to(dtype=torch.double)
def _all_reduce_dict(data: OrderedDict):
if len(data) == 0:
return data
buf = torch.cat([t.view(-1) for t in data.values()]).to(device=device)
all_reduce(buf, group=group)
split_buf = torch.split(buf, [t.numel() for t in data.values()])
reduced_data = [t.view_as(orig) for t, orig in zip(split_buf, data.values())]
return OrderedDict(zip(data.keys(), reduced_data))
cpu_data = _all_reduce_dict(cpu_data)
device_data = _all_reduce_dict(device_data)
def get_from_stack(key):
if key in cpu_data:
return cpu_data[key]
elif key in device_data:
return device_data[key]
raise KeyError
return OrderedDict([(key, get_from_stack(key)) for key in data_keys])
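# Usage sketch for all_reduce_dict (illustrative; assumes a CUDA device and an
# initialized process group elsewhere in the training loop):
#
#     logging_outputs = {"loss": 1.25, "ntokens": 3072}
#     summed = all_reduce_dict(logging_outputs, device=torch.device("cuda"))
#     # every worker now holds the element-wise sum (as double tensors) of the
#     # values contributed by all workers; CPU and device tensors are reduced
#     # in two separate flattened buffers.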
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import shutil
from typing import List, Optional
try:
from fvcore.common.file_io import PathManager as FVCorePathManager
except ImportError:
FVCorePathManager = None
class PathManager:
"""
Wrapper for insulating OSS I/O (using Python builtin operations) from
fvcore's PathManager abstraction (for transparently handling various
internal backends).
"""
@staticmethod
def open(
path: str,
mode: str = "r",
buffering: int = -1,
encoding: Optional[str] = None,
errors: Optional[str] = None,
newline: Optional[str] = None,
):
if FVCorePathManager:
return FVCorePathManager.open(
path=path,
mode=mode,
buffering=buffering,
encoding=encoding,
errors=errors,
newline=newline,
)
return open(
path,
mode=mode,
buffering=buffering,
encoding=encoding,
errors=errors,
newline=newline,
)
@staticmethod
def copy(src_path: str, dst_path: str, overwrite: bool = False) -> bool:
if FVCorePathManager:
return FVCorePathManager.copy(
src_path=src_path, dst_path=dst_path, overwrite=overwrite
)
return shutil.copyfile(src_path, dst_path)
@staticmethod
def get_local_path(path: str, **kwargs) -> str:
if FVCorePathManager:
return FVCorePathManager.get_local_path(path, **kwargs)
return path
@staticmethod
def exists(path: str) -> bool:
if FVCorePathManager:
return FVCorePathManager.exists(path)
return os.path.exists(path)
@staticmethod
def isfile(path: str) -> bool:
if FVCorePathManager:
return FVCorePathManager.isfile(path)
return os.path.isfile(path)
@staticmethod
def ls(path: str) -> List[str]:
if FVCorePathManager:
return FVCorePathManager.ls(path)
return os.listdir(path)
@staticmethod
def mkdirs(path: str) -> None:
if FVCorePathManager:
return FVCorePathManager.mkdirs(path)
os.makedirs(path, exist_ok=True)
@staticmethod
def rm(path: str) -> None:
if FVCorePathManager:
return FVCorePathManager.rm(path)
os.remove(path)
@staticmethod
def chmod(path: str, mode: int) -> None:
if "manifold" not in path:
os.chmod(path, mode)
@staticmethod
def register_handler(handler) -> None:
if FVCorePathManager:
return FVCorePathManager.register_handler(handler=handler)
@staticmethod
def copy_from_local(
local_path: str, dst_path: str, overwrite: bool = False, **kwargs
) -> None:
if FVCorePathManager:
return FVCorePathManager.copy_from_local(
local_path=local_path, dst_path=dst_path, overwrite=overwrite, **kwargs
)
return shutil.copyfile(local_path, dst_path)
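# Usage sketch: PathManager mirrors the builtin file API, so plain local paths
# work with or without fvcore installed (the fvcore branch only adds support
# for extra backends). The path below is a placeholder, not a real file.
#
#     with PathManager.open("/tmp/example.txt", "w") as f:
#         f.write("hello")
#     assert PathManager.exists("/tmp/example.txt")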
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Utilities for working with the local dataset cache.
This file is adapted from `AllenNLP <https://github.com/allenai/allennlp>`_.
and `huggingface <https://github.com/huggingface>`_.
"""
import fnmatch
import json
import logging
import os
import shutil
import tarfile
import tempfile
from functools import partial, wraps
from hashlib import sha256
from io import open
try:
from torch.hub import _get_torch_home
torch_cache_home = _get_torch_home()
except ImportError:
torch_cache_home = os.path.expanduser(
os.getenv(
"TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch")
)
)
default_cache_path = os.path.join(torch_cache_home, "pytorch_fairseq")
try:
from urllib.parse import urlparse
except ImportError:
from urlparse import urlparse
try:
from pathlib import Path
PYTORCH_FAIRSEQ_CACHE = Path(os.getenv("PYTORCH_FAIRSEQ_CACHE", default_cache_path))
except (AttributeError, ImportError):
PYTORCH_FAIRSEQ_CACHE = os.getenv("PYTORCH_FAIRSEQ_CACHE", default_cache_path)
CONFIG_NAME = "config.json"
WEIGHTS_NAME = "pytorch_model.bin"
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
def load_archive_file(archive_file):
# redirect to the cache, if necessary
try:
resolved_archive_file = cached_path(archive_file, cache_dir=None)
except EnvironmentError:
logger.info(
"Archive name '{}' was not found in archive name list. "
"We assumed '{}' was a path or URL but couldn't find any file "
"associated to this path or URL.".format(
archive_file,
archive_file,
)
)
return None
if resolved_archive_file == archive_file:
logger.info("loading archive file {}".format(archive_file))
else:
logger.info(
"loading archive file {} from cache at {}".format(
archive_file, resolved_archive_file
)
)
# Extract the archive to a temp dir and, if necessary, replace the original
# archive file with the extracted directory
tempdir = None
if not os.path.isdir(resolved_archive_file):
tempdir = tempfile.mkdtemp()
logger.info(
"extracting archive file {} to temp dir {}".format(
resolved_archive_file, tempdir
)
)
ext = os.path.splitext(archive_file)[1][1:]
with tarfile.open(resolved_archive_file, "r:" + ext) as archive:
top_dir = os.path.commonprefix(archive.getnames())
archive.extractall(tempdir)
os.remove(resolved_archive_file)
shutil.move(os.path.join(tempdir, top_dir), resolved_archive_file)
shutil.rmtree(tempdir)
return resolved_archive_file
def url_to_filename(url, etag=None):
"""
Convert `url` into a hashed filename in a repeatable way.
If `etag` is specified, append its hash to the URL's, delimited
by a period.
"""
url_bytes = url.encode("utf-8")
url_hash = sha256(url_bytes)
filename = url_hash.hexdigest()
if etag:
etag_bytes = etag.encode("utf-8")
etag_hash = sha256(etag_bytes)
filename += "." + etag_hash.hexdigest()
return filename
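# Example of the resulting cache-naming scheme (URL and ETag below are made
# up): the filename is sha256(url) in hex, plus "." followed by sha256(etag)
# when an ETag is available.
#
#     name = url_to_filename("https://example.com/model.pt", etag='"abc123"')
#     # -> "<hex digest of the URL>.<hex digest of the ETag>"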
def filename_to_url(filename, cache_dir=None):
"""
Return the url and etag (which may be ``None``) stored for `filename`.
Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
"""
if cache_dir is None:
cache_dir = PYTORCH_FAIRSEQ_CACHE
if isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
cache_path = os.path.join(cache_dir, filename)
if not os.path.exists(cache_path):
raise EnvironmentError("file {} not found".format(cache_path))
meta_path = cache_path + ".json"
if not os.path.exists(meta_path):
raise EnvironmentError("file {} not found".format(meta_path))
with open(meta_path, encoding="utf-8") as meta_file:
metadata = json.load(meta_file)
url = metadata["url"]
etag = metadata["etag"]
return url, etag
def cached_path(url_or_filename, cache_dir=None):
"""
Given something that might be a URL (or might be a local path),
determine which. If it's a URL, download the file and cache it, and
return the path to the cached file. If it's already a local path,
make sure the file exists and then return the path.
"""
if cache_dir is None:
cache_dir = PYTORCH_FAIRSEQ_CACHE
if isinstance(url_or_filename, Path):
url_or_filename = str(url_or_filename)
if isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
parsed = urlparse(url_or_filename)
if parsed.scheme in ("http", "https", "s3"):
# URL, so get it from the cache (downloading if necessary)
return get_from_cache(url_or_filename, cache_dir)
elif os.path.exists(url_or_filename):
# File, and it exists.
return url_or_filename
elif parsed.scheme == "":
# File, but it doesn't exist.
raise EnvironmentError("file {} not found".format(url_or_filename))
else:
# Something unknown
raise ValueError(
"unable to parse {} as a URL or as a local path".format(url_or_filename)
)
def split_s3_path(url):
"""Split a full s3 path into the bucket name and path."""
parsed = urlparse(url)
if not parsed.netloc or not parsed.path:
raise ValueError("bad s3 path {}".format(url))
bucket_name = parsed.netloc
s3_path = parsed.path
# Remove '/' at beginning of path.
if s3_path.startswith("/"):
s3_path = s3_path[1:]
return bucket_name, s3_path
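# For example (bucket and key below are made up):
#
#     split_s3_path("s3://my-bucket/checkpoints/model.pt")
#     # -> ("my-bucket", "checkpoints/model.pt")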
def s3_request(func):
"""
Wrapper function for s3 requests in order to create more helpful error
messages.
"""
@wraps(func)
def wrapper(url, *args, **kwargs):
from botocore.exceptions import ClientError
try:
return func(url, *args, **kwargs)
except ClientError as exc:
if int(exc.response["Error"]["Code"]) == 404:
raise EnvironmentError("file {} not found".format(url))
else:
raise
return wrapper
@s3_request
def s3_etag(url):
"""Check ETag on S3 object."""
import boto3
s3_resource = boto3.resource("s3")
bucket_name, s3_path = split_s3_path(url)
s3_object = s3_resource.Object(bucket_name, s3_path)
return s3_object.e_tag
@s3_request
def s3_get(url, temp_file):
"""Pull a file directly from S3."""
import boto3
s3_resource = boto3.resource("s3")
bucket_name, s3_path = split_s3_path(url)
s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
def request_wrap_timeout(func, url):
import requests
for attempt, timeout in enumerate([10, 20, 40, 60, 60]):
try:
return func(timeout=timeout)
except requests.exceptions.Timeout as e:
logger.warning(
"Request for %s timed-out (attempt %d). Retrying with a timeout of %d secs",
url,
attempt,
timeout,
exc_info=e,
)
continue
raise RuntimeError(f"Unable to fetch file {url}")
def http_get(url, temp_file):
import requests
from tqdm import tqdm
req = request_wrap_timeout(partial(requests.get, url, stream=True), url)
content_length = req.headers.get("Content-Length")
total = int(content_length) if content_length is not None else None
progress = tqdm(unit="B", total=total)
for chunk in req.iter_content(chunk_size=1024):
if chunk: # filter out keep-alive new chunks
progress.update(len(chunk))
temp_file.write(chunk)
progress.close()
def get_from_cache(url, cache_dir=None):
"""
Given a URL, look for the corresponding dataset in the local cache.
If it's not there, download it. Then return the path to the cached file.
"""
if cache_dir is None:
cache_dir = PYTORCH_FAIRSEQ_CACHE
if isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
if not os.path.exists(cache_dir):
os.makedirs(cache_dir)
# Get eTag to add to filename, if it exists.
if url.startswith("s3://"):
etag = s3_etag(url)
else:
try:
import requests
response = request_wrap_timeout(
partial(requests.head, url, allow_redirects=True), url
)
if response.status_code != 200:
etag = None
else:
etag = response.headers.get("ETag")
except EnvironmentError:
etag = None
filename = url_to_filename(url, etag)
# get cache path to put the file
cache_path = os.path.join(cache_dir, filename)
# If we don't have a connection (etag is None) and can't identify the file
# try to get the last downloaded one
if not os.path.exists(cache_path) and etag is None:
matching_files = fnmatch.filter(os.listdir(cache_dir), filename + ".*")
matching_files = list(filter(lambda s: not s.endswith(".json"), matching_files))
if matching_files:
cache_path = os.path.join(cache_dir, matching_files[-1])
if not os.path.exists(cache_path):
# Download to temporary file, then copy to cache dir once finished.
# Otherwise you get corrupt cache entries if the download gets interrupted.
with tempfile.NamedTemporaryFile() as temp_file:
logger.info("%s not found in cache, downloading to %s", url, temp_file.name)
# GET file object
if url.startswith("s3://"):
s3_get(url, temp_file)
else:
http_get(url, temp_file)
# we are copying the file before closing it, so flush to avoid truncation
temp_file.flush()
# shutil.copyfileobj() starts at the current position, so go to the start
temp_file.seek(0)
logger.info("copying %s to cache at %s", temp_file.name, cache_path)
with open(cache_path, "wb") as cache_file:
shutil.copyfileobj(temp_file, cache_file)
logger.info("creating metadata file for %s", cache_path)
meta = {"url": url, "etag": etag}
meta_path = cache_path + ".json"
with open(meta_path, "w") as meta_file:
output_string = json.dumps(meta)
meta_file.write(output_string)
logger.info("removing temp file %s", temp_file.name)
return cache_path
def read_set_from_file(filename):
"""
Extract a de-duped collection (set) of text from a file.
Expected file format is one item per line.
"""
collection = set()
with open(filename, "r", encoding="utf-8") as file_:
for line in file_:
collection.add(line.rstrip())
return collection
def get_file_extension(path, dot=True, lower=True):
ext = os.path.splitext(path)[1]
ext = ext if dot else ext[1:]
return ext.lower() if lower else ext
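# Usage sketch for the cache helpers above (the URL is a placeholder, not a
# real fairseq asset): a remote file is downloaded once into
# PYTORCH_FAIRSEQ_CACHE and the cached local path is returned on later calls.
#
#     local_path = cached_path("https://example.com/dict.txt")
#     vocab = read_set_from_file(local_path)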
#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import copy
import logging
import os
from typing import Any, Dict, Iterator, List, Tuple
import torch
from fairseq import utils
from fairseq.data import encoders
from torch import nn
logger = logging.getLogger(__name__)
def from_pretrained(
model_name_or_path,
checkpoint_file="model.pt",
data_name_or_path=".",
archive_map=None,
**kwargs
):
from fairseq import checkpoint_utils, file_utils
if archive_map is not None:
if model_name_or_path in archive_map:
model_name_or_path = archive_map[model_name_or_path]
if data_name_or_path is not None and data_name_or_path in archive_map:
data_name_or_path = archive_map[data_name_or_path]
# allow archive_map to set default arg_overrides (e.g., tokenizer, bpe)
# for each model
if isinstance(model_name_or_path, dict):
for k, v in model_name_or_path.items():
if k == "checkpoint_file":
checkpoint_file = v
elif (
k != "path"
# only set kwargs that don't already have overrides
and k not in kwargs
):
kwargs[k] = v
model_name_or_path = model_name_or_path["path"]
model_path = file_utils.load_archive_file(model_name_or_path)
# convenience hack for loading data and BPE codes from model archive
if data_name_or_path.startswith("."):
kwargs["data"] = os.path.abspath(os.path.join(model_path, data_name_or_path))
else:
kwargs["data"] = file_utils.load_archive_file(data_name_or_path)
for file, arg in {
"code": "bpe_codes",
"bpecodes": "bpe_codes",
"sentencepiece.bpe.model": "sentencepiece_model",
}.items():
path = os.path.join(model_path, file)
if os.path.exists(path):
kwargs[arg] = path
if "user_dir" in kwargs:
utils.import_user_module(argparse.Namespace(user_dir=kwargs["user_dir"]))
models, args, task = checkpoint_utils.load_model_ensemble_and_task(
[os.path.join(model_path, cpt) for cpt in checkpoint_file.split(os.pathsep)],
arg_overrides=kwargs,
)
return {
"args": args,
"task": task,
"models": models,
}
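# Usage sketch for from_pretrained (the model directory and checkpoint name
# are placeholders; real archive maps are supplied by the individual model
# classes via torch.hub):
#
#     bundle = from_pretrained(
#         "/path/to/extracted/model_dir",
#         checkpoint_file="model.pt",
#         data_name_or_path=".",
#     )
#     models, args, task = bundle["models"], bundle["args"], bundle["task"]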
class GeneratorHubInterface(nn.Module):
"""
PyTorch Hub interface for generating sequences from a pre-trained
translation or language model.
"""
def __init__(self, args, task, models):
super().__init__()
self.args = args
self.task = task
self.models = nn.ModuleList(models)
self.src_dict = task.source_dictionary
self.tgt_dict = task.target_dictionary
# optimize model for generation
for model in self.models:
model.prepare_for_inference_(args)
# Load alignment dictionary for unknown word replacement
# (None if no unknown word replacement, empty if no path to align dictionary)
self.align_dict = utils.load_align_dict(getattr(args, "replace_unk", None))
self.tokenizer = encoders.build_tokenizer(args)
self.bpe = encoders.build_bpe(args)
self.max_positions = utils.resolve_max_positions(
self.task.max_positions(), *[model.max_positions() for model in models]
)
# this is useful for determining the device
self.register_buffer("_float_tensor", torch.tensor([0], dtype=torch.float))
@property
def device(self):
return self._float_tensor.device
def translate(
self, sentences: List[str], beam: int = 5, verbose: bool = False, **kwargs
) -> List[str]:
return self.sample(sentences, beam, verbose, **kwargs)
def sample(
self, sentences: List[str], beam: int = 1, verbose: bool = False, **kwargs
) -> List[str]:
if isinstance(sentences, str):
return self.sample([sentences], beam=beam, verbose=verbose, **kwargs)[0]
tokenized_sentences = [self.encode(sentence) for sentence in sentences]
batched_hypos = self.generate(tokenized_sentences, beam, verbose, **kwargs)
return [self.decode(hypos[0]["tokens"]) for hypos in batched_hypos]
def score(self, sentences: List[str], **kwargs):
if isinstance(sentences, str):
return self.score([sentences], **kwargs)[0]
# NOTE: this doesn't support translation tasks currently
tokenized_sentences = [self.encode(sentence) for sentence in sentences]
return [
hypos[0]
for hypos in self.generate(
tokenized_sentences, score_reference=True, **kwargs
)
]
def generate(
self,
tokenized_sentences: List[torch.LongTensor],
beam: int = 5,
verbose: bool = False,
skip_invalid_size_inputs=False,
inference_step_args=None,
**kwargs
) -> List[List[Dict[str, torch.Tensor]]]:
if torch.is_tensor(tokenized_sentences) and tokenized_sentences.dim() == 1:
return self.generate(
tokenized_sentences.unsqueeze(0), beam=beam, verbose=verbose, **kwargs
)[0]
# build generator using current args as well as any kwargs
gen_args = copy.copy(self.args)
gen_args.beam = beam
for k, v in kwargs.items():
setattr(gen_args, k, v)
generator = self.task.build_generator(self.models, gen_args)
inference_step_args = inference_step_args or {}
results = []
for batch in self._build_batches(tokenized_sentences, skip_invalid_size_inputs):
batch = utils.apply_to_sample(lambda t: t.to(self.device), batch)
translations = self.task.inference_step(
generator, self.models, batch, **inference_step_args
)
for id, hypos in zip(batch["id"].tolist(), translations):
results.append((id, hypos))
# sort output to match input order
outputs = [hypos for _, hypos in sorted(results, key=lambda x: x[0])]
if verbose:
def getarg(name, default):
return getattr(gen_args, name, getattr(self.args, name, default))
for source_tokens, target_hypotheses in zip(tokenized_sentences, outputs):
src_str_with_unk = self.string(source_tokens)
logger.info("S\t{}".format(src_str_with_unk))
for hypo in target_hypotheses:
hypo_str = self.decode(hypo["tokens"])
logger.info("H\t{}\t{}".format(hypo["score"], hypo_str))
logger.info(
"P\t{}".format(
" ".join(
map(
lambda x: "{:.4f}".format(x),
hypo["positional_scores"].tolist(),
)
)
)
)
if hypo["alignment"] is not None and getarg(
"print_alignment", False
):
logger.info(
"A\t{}".format(
" ".join(
[
"{}-{}".format(src_idx, tgt_idx)
for src_idx, tgt_idx in hypo["alignment"]
]
)
)
)
return outputs
def encode(self, sentence: str) -> torch.LongTensor:
sentence = self.tokenize(sentence)
sentence = self.apply_bpe(sentence)
return self.binarize(sentence)
def decode(self, tokens: torch.LongTensor) -> str:
sentence = self.string(tokens)
sentence = self.remove_bpe(sentence)
return self.detokenize(sentence)
def tokenize(self, sentence: str) -> str:
if self.tokenizer is not None:
sentence = self.tokenizer.encode(sentence)
return sentence
def detokenize(self, sentence: str) -> str:
if self.tokenizer is not None:
sentence = self.tokenizer.decode(sentence)
return sentence
def apply_bpe(self, sentence: str) -> str:
if self.bpe is not None:
sentence = self.bpe.encode(sentence)
return sentence
def remove_bpe(self, sentence: str) -> str:
if self.bpe is not None:
sentence = self.bpe.decode(sentence)
return sentence
def binarize(self, sentence: str) -> torch.LongTensor:
return self.src_dict.encode_line(sentence, add_if_not_exist=False).long()
def string(self, tokens: torch.LongTensor) -> str:
return self.tgt_dict.string(tokens)
def _build_batches(
self, tokens: List[List[int]], skip_invalid_size_inputs: bool
) -> Iterator[Dict[str, Any]]:
lengths = torch.LongTensor([t.numel() for t in tokens])
batch_iterator = self.task.get_batch_iterator(
dataset=self.task.build_dataset_for_inference(tokens, lengths),
max_tokens=self.args.max_tokens,
max_sentences=self.args.batch_size,
max_positions=self.max_positions,
ignore_invalid_inputs=skip_invalid_size_inputs,
disable_iterator_cache=True,
).next_epoch_itr(shuffle=False)
return batch_iterator
class BPEHubInterface(object):
"""PyTorch Hub interface for Byte-Pair Encoding (BPE)."""
def __init__(self, bpe, **kwargs):
super().__init__()
args = argparse.Namespace(bpe=bpe, **kwargs)
self.bpe = encoders.build_bpe(args)
assert self.bpe is not None
def encode(self, sentence: str) -> str:
return self.bpe.encode(sentence)
def decode(self, sentence: str) -> str:
return self.bpe.decode(sentence)
class TokenizerHubInterface(object):
"""PyTorch Hub interface for tokenization."""
def __init__(self, tokenizer, **kwargs):
super().__init__()
args = argparse.Namespace(tokenizer=tokenizer, **kwargs)
self.tokenizer = encoders.build_tokenizer(args)
assert self.tokenizer is not None
def encode(self, sentence: str) -> str:
return self.tokenizer.encode(sentence)
def decode(self, sentence: str) -> str:
return self.tokenizer.decode(sentence)
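# Usage sketch for GeneratorHubInterface (assumes args, task and models were
# produced by from_pretrained() above and that the task defines source and
# target dictionaries):
#
#     hub = GeneratorHubInterface(args, task, models)
#     hub.translate(["Hello world!"], beam=5)
#     # -> ["<detokenized hypothesis for the first sentence>"]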
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import uuid
from typing import Dict, Optional
from torch import Tensor
class FairseqIncrementalState(object):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.init_incremental_state()
def init_incremental_state(self):
self._incremental_state_id = str(uuid.uuid4())
def _get_full_incremental_state_key(self, key: str) -> str:
return "{}.{}".format(self._incremental_state_id, key)
def get_incremental_state(
self,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
key: str,
) -> Optional[Dict[str, Optional[Tensor]]]:
"""Helper for getting incremental state for an nn.Module."""
full_key = self._get_full_incremental_state_key(key)
if incremental_state is None or full_key not in incremental_state:
return None
return incremental_state[full_key]
def set_incremental_state(
self,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
key: str,
value: Dict[str, Optional[Tensor]],
) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]:
"""Helper for setting incremental state for an nn.Module."""
if incremental_state is not None:
full_key = self._get_full_incremental_state_key(key)
incremental_state[full_key] = value
return incremental_state
def with_incremental_state(cls):
cls.__bases__ = (FairseqIncrementalState,) + tuple(
b for b in cls.__bases__ if b != FairseqIncrementalState
)
return cls
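# Usage sketch: decorating a module class with @with_incremental_state mixes in
# FairseqIncrementalState, giving it namespaced get/set helpers for the
# per-module cache used during incremental decoding. The toy module below is
# illustrative only (nn.Module import assumed from elsewhere).
#
#     @with_incremental_state
#     class ToyAttention(nn.Module):
#         def forward(self, x, incremental_state=None):
#             saved = self.get_incremental_state(incremental_state, "prev")
#             self.set_incremental_state(incremental_state, "prev", {"x": x})
#             return x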
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from collections import namedtuple
import numpy as np
import torch
from fairseq import utils
DecoderOut = namedtuple(
"IterativeRefinementDecoderOut",
["output_tokens", "output_scores", "attn", "step", "max_step", "history"],
)
class IterativeRefinementGenerator(object):
def __init__(
self,
tgt_dict,
models=None,
eos_penalty=0.0,
max_iter=10,
max_ratio=2,
beam_size=1,
decoding_format=None,
retain_dropout=False,
adaptive=True,
retain_history=False,
reranking=False,
):
"""
Generates translations based on iterative refinement.
Args:
tgt_dict: target dictionary
eos_penalty: if > 0.0, penalizes early stopping during decoding
max_iter: maximum number of refinement iterations
max_ratio: generate sequences of maximum length ax, where x is the source length
decoding_format: decoding mode in {'unigram', 'ensemble', 'vote', 'dp', 'bs'}
retain_dropout: keep dropout enabled during inference
adaptive: decode with early stopping
"""
self.bos = tgt_dict.bos()
self.pad = tgt_dict.pad()
self.unk = tgt_dict.unk()
self.eos = tgt_dict.eos()
self.vocab_size = len(tgt_dict)
self.eos_penalty = eos_penalty
self.max_iter = max_iter
self.max_ratio = max_ratio
self.beam_size = beam_size
self.reranking = reranking
self.decoding_format = decoding_format
self.retain_dropout = retain_dropout
self.retain_history = retain_history
self.adaptive = adaptive
self.models = models
def generate_batched_itr(
self,
data_itr,
maxlen_a=None,
maxlen_b=None,
cuda=False,
timer=None,
prefix_size=0,
):
"""Iterate over a batched dataset and yield individual translations.
Args:
maxlen_a/b: generate sequences of maximum length ax + b,
where x is the source sentence length.
cuda: use GPU for generation
timer: StopwatchMeter for timing generations.
"""
for sample in data_itr:
if "net_input" not in sample:
continue
if timer is not None:
timer.start()
with torch.no_grad():
hypos = self.generate(
self.models,
sample,
prefix_tokens=sample["target"][:, :prefix_size]
if prefix_size > 0
else None,
)
if timer is not None:
timer.stop(sample["ntokens"])
for i, id in enumerate(sample["id"]):
# remove padding
src = utils.strip_pad(sample["net_input"]["src_tokens"][i, :], self.pad)
ref = utils.strip_pad(sample["target"][i, :], self.pad)
yield id, src, ref, hypos[i]
@torch.no_grad()
def generate(self, models, sample, prefix_tokens=None, constraints=None):
if constraints is not None:
raise NotImplementedError(
"Constrained decoding with the IterativeRefinementGenerator is not supported"
)
# TODO: iterative refinement generator does not support ensemble for now.
if not self.retain_dropout:
for model in models:
model.eval()
model, reranker = models[0], None
if self.reranking:
assert len(models) > 1, "Assuming the last checkpoint is the reranker"
assert (
self.beam_size > 1
), "Reranking requires multiple translation for each example"
reranker = models[-1]
models = models[:-1]
if len(models) > 1 and hasattr(model, "enable_ensemble"):
assert model.allow_ensemble, "{} does not support ensembling".format(
model.__class__.__name__
)
model.enable_ensemble(models)
# TODO: better encoder inputs?
src_tokens = sample["net_input"]["src_tokens"]
src_lengths = sample["net_input"]["src_lengths"]
bsz, src_len = src_tokens.size()
# initialize
encoder_out = model.forward_encoder([src_tokens, src_lengths])
prev_decoder_out = model.initialize_output_tokens(encoder_out, src_tokens)
if self.beam_size > 1:
assert (
model.allow_length_beam
), "{} does not support decoding with length beam.".format(
model.__class__.__name__
)
# regenerate data based on length-beam
length_beam_order = (
utils.new_arange(src_tokens, self.beam_size, bsz).t().reshape(-1)
)
encoder_out = model.encoder.reorder_encoder_out(
encoder_out, length_beam_order
)
prev_decoder_out = model.regenerate_length_beam(
prev_decoder_out, self.beam_size
)
bsz = bsz * self.beam_size
sent_idxs = torch.arange(bsz)
prev_output_tokens = prev_decoder_out.output_tokens.clone()
if self.retain_history:
prev_decoder_out = prev_decoder_out._replace(history=[prev_output_tokens])
finalized = [[] for _ in range(bsz)]
def is_a_loop(x, y, s, a):
b, l_x, l_y = x.size(0), x.size(1), y.size(1)
if l_x > l_y:
y = torch.cat([y, x.new_zeros(b, l_x - l_y).fill_(self.pad)], 1)
s = torch.cat([s, s.new_zeros(b, l_x - l_y)], 1)
if a is not None:
a = torch.cat([a, a.new_zeros(b, l_x - l_y, a.size(2))], 1)
elif l_x < l_y:
x = torch.cat([x, y.new_zeros(b, l_y - l_x).fill_(self.pad)], 1)
return (x == y).all(1), y, s, a
def finalized_hypos(step, prev_out_token, prev_out_score, prev_out_attn):
cutoff = prev_out_token.ne(self.pad)
tokens = prev_out_token[cutoff]
if prev_out_score is None:
scores, score = None, None
else:
scores = prev_out_score[cutoff]
score = scores.mean()
if prev_out_attn is None:
hypo_attn, alignment = None, None
else:
hypo_attn = prev_out_attn[cutoff]
alignment = hypo_attn.max(dim=1)[1]
return {
"steps": step,
"tokens": tokens,
"positional_scores": scores,
"score": score,
"hypo_attn": hypo_attn,
"alignment": alignment,
}
for step in range(self.max_iter + 1):
decoder_options = {
"eos_penalty": self.eos_penalty,
"max_ratio": self.max_ratio,
"decoding_format": self.decoding_format,
}
prev_decoder_out = prev_decoder_out._replace(
step=step,
max_step=self.max_iter + 1,
)
decoder_out = model.forward_decoder(
prev_decoder_out, encoder_out, **decoder_options
)
if self.adaptive:
# terminate if there is a loop
terminated, out_tokens, out_scores, out_attn = is_a_loop(
prev_output_tokens,
decoder_out.output_tokens,
decoder_out.output_scores,
decoder_out.attn,
)
decoder_out = decoder_out._replace(
output_tokens=out_tokens,
output_scores=out_scores,
attn=out_attn,
)
else:
terminated = decoder_out.output_tokens.new_zeros(
decoder_out.output_tokens.size(0)
).bool()
if step == self.max_iter: # reach last iteration, terminate
terminated.fill_(1)
# collect finalized sentences
finalized_idxs = sent_idxs[terminated]
finalized_tokens = decoder_out.output_tokens[terminated]
finalized_scores = decoder_out.output_scores[terminated]
finalized_attn = (
None
if (decoder_out.attn is None or decoder_out.attn.size(0) == 0)
else decoder_out.attn[terminated]
)
if self.retain_history:
finalized_history_tokens = [h[terminated] for h in decoder_out.history]
for i in range(finalized_idxs.size(0)):
finalized[finalized_idxs[i]] = [
finalized_hypos(
step,
finalized_tokens[i],
finalized_scores[i],
None if finalized_attn is None else finalized_attn[i],
)
]
if self.retain_history:
finalized[finalized_idxs[i]][0]["history"] = []
for j in range(len(finalized_history_tokens)):
finalized[finalized_idxs[i]][0]["history"].append(
finalized_hypos(
step, finalized_history_tokens[j][i], None, None
)
)
# check if all terminated
if terminated.sum() == terminated.size(0):
break
# for next step
not_terminated = ~terminated
prev_decoder_out = decoder_out._replace(
output_tokens=decoder_out.output_tokens[not_terminated],
output_scores=decoder_out.output_scores[not_terminated],
attn=decoder_out.attn[not_terminated]
if (decoder_out.attn is not None and decoder_out.attn.size(0) > 0)
else None,
history=[h[not_terminated] for h in decoder_out.history]
if decoder_out.history is not None
else None,
)
encoder_out = model.encoder.reorder_encoder_out(
encoder_out, not_terminated.nonzero(as_tuple=False).squeeze()
)
sent_idxs = sent_idxs[not_terminated]
prev_output_tokens = prev_decoder_out.output_tokens.clone()
if self.beam_size > 1:
if reranker is not None:
finalized = self.rerank(
reranker, finalized, [src_tokens, src_lengths], self.beam_size
)
# aggregate information from length beam
finalized = [
finalized[
np.argmax(
[
finalized[self.beam_size * i + j][0]["score"]
for j in range(self.beam_size)
]
)
+ self.beam_size * i
]
for i in range(len(finalized) // self.beam_size)
]
return finalized
def rerank(self, reranker, finalized, encoder_input, beam_size):
def rebuild_batch(finalized):
finalized_tokens = [f[0]["tokens"] for f in finalized]
finalized_maxlen = max(f.size(0) for f in finalized_tokens)
final_output_tokens = (
finalized_tokens[0]
.new_zeros(len(finalized_tokens), finalized_maxlen)
.fill_(self.pad)
)
for i, f in enumerate(finalized_tokens):
final_output_tokens[i, : f.size(0)] = f
return final_output_tokens
final_output_tokens = rebuild_batch(finalized)
final_output_tokens[
:, 0
] = self.eos # autoregressive model assumes starting with EOS
reranker_encoder_out = reranker.encoder(*encoder_input)
length_beam_order = (
utils.new_arange(
final_output_tokens, beam_size, reranker_encoder_out.encoder_out.size(1)
)
.t()
.reshape(-1)
)
reranker_encoder_out = reranker.encoder.reorder_encoder_out(
reranker_encoder_out, length_beam_order
)
reranking_scores = reranker.get_normalized_probs(
reranker.decoder(final_output_tokens[:, :-1], reranker_encoder_out),
True,
None,
)
reranking_scores = reranking_scores.gather(2, final_output_tokens[:, 1:, None])
reranking_masks = final_output_tokens[:, 1:].ne(self.pad)
reranking_scores = (
reranking_scores[:, :, 0].masked_fill_(~reranking_masks, 0).sum(1)
)
reranking_scores = reranking_scores / reranking_masks.sum(1).type_as(
reranking_scores
)
for i in range(len(finalized)):
finalized[i][0]["score"] = reranking_scores[i]
return finalized
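# Usage sketch (tgt_dict, models and sample are placeholders for objects built
# by a fairseq NAT/translation_lev task; in practice this generator is
# constructed through task.build_generator()):
#
#     generator = IterativeRefinementGenerator(tgt_dict, models=models, max_iter=10)
#     hypos = generator.generate(models, sample)
#     # hypos[i][0]["tokens"] holds the refined output tokens for sentence i.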
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
A modified version of the legacy DistributedDataParallel module that uses c10d
communication primitives. This version is simpler than the latest PyTorch
version and is useful for debugging. Notably it does not overlap gradient
communication with the backward pass, which makes it slower but more robust
than the PyTorch version.
This version also supports the *no_sync* context manager, which allows faster
training with `--update-freq`.
"""
import copy
from collections import OrderedDict
from contextlib import contextmanager
import torch
from torch import nn
from torch.autograd import Variable
from . import distributed_utils
class LegacyDistributedDataParallel(nn.Module):
"""Implements distributed data parallelism at the module level.
A simplified version of :class:`torch.nn.parallel.DistributedDataParallel`.
This version uses a c10d process group for communication and does not
broadcast buffers.
Args:
module (~torch.nn.Module): module to be parallelized
world_size (int): number of parallel workers
process_group (optional): the c10d process group to be used for
distributed data all-reduction. If None, the default process group
will be used.
buffer_size (int, optional): number of elements to buffer before
performing all-reduce (default: 256M).
"""
def __init__(self, module, world_size, process_group=None, buffer_size=2 ** 28):
super().__init__()
self.module = module
self.world_size = world_size
self.process_group = process_group
# Never use a bigger buffer than the number of model params
self.buffer_size = min(buffer_size, sum(p.numel() for p in module.parameters()))
self.buffer = None
# We can also forcibly accumulate grads locally and only do the
# all-reduce at some later time
self.accumulate_grads = False
# make per-device lists of parameters
paramlists = OrderedDict()
for param in self.module.parameters():
device = param.device
if paramlists.get(device) is None:
paramlists[device] = []
paramlists[device] += [param]
self.per_device_params = list(paramlists.values())
def __getstate__(self):
attrs = copy.copy(self.__dict__)
return attrs
def __setstate__(self, state):
super().__setstate__(state)
@contextmanager
def no_sync(self):
"""A context manager to disable gradient synchronization."""
old_accumulate_grads = self.accumulate_grads
self.accumulate_grads = True
yield
self.accumulate_grads = old_accumulate_grads
def forward(self, *inputs, **kwargs):
return self.module(*inputs, **kwargs)
def all_reduce(self):
"""
This function must be called explicitly after backward to reduce
gradients. Unlike c10d, there is no automatic backward hook.
"""
def all_reduce_params(params):
buffer = self.buffer
nonzero_buffer = False
if len(params) > 1:
offset = 0
for p in params:
sz = p.numel()
if p.grad is not None:
buffer[offset : offset + sz].copy_(p.grad.data.view(-1))
nonzero_buffer = True
else:
buffer[offset : offset + sz].zero_()
offset += sz
else:
# we only have a single grad to all-reduce
p = params[0]
if p.grad is not None:
buffer = p.grad.data
nonzero_buffer = True
elif p.numel() <= self.buffer.numel():
buffer = buffer[: p.numel()]
buffer.zero_()
else:
buffer = torch.zeros_like(p)
if nonzero_buffer:
buffer.div_(self.world_size)
distributed_utils.all_reduce(buffer, self.process_group)
# copy all-reduced grads back into their original place
offset = 0
for p in params:
sz = p.numel()
if p.grad is not None:
p.grad.data.copy_(buffer[offset : offset + sz].view_as(p))
else:
p.grad = buffer[offset : offset + sz].view_as(p).clone()
offset += sz
def reduction_fn():
# This function only needs to be called once
if self.accumulate_grads:
return
if self.buffer is None:
self.buffer = next(self.module.parameters()).new(self.buffer_size)
for params in self.per_device_params:
# All-reduce the gradients in buckets
offset = 0
buffered_params = []
for param in params:
if not param.requires_grad:
continue
if param.grad is None:
param.grad = torch.zeros_like(param)
if param.grad.requires_grad:
raise RuntimeError(
"DistributedDataParallel only works "
"with gradients that don't require "
"grad"
)
sz = param.numel()
if sz > self.buffer.numel():
# all-reduce big params directly
all_reduce_params([param])
else:
if offset + sz > self.buffer.numel():
all_reduce_params(buffered_params)
offset = 0
buffered_params.clear()
buffered_params.append(param)
offset += sz
if len(buffered_params) > 0:
all_reduce_params(buffered_params)
reduction_fn()
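# Usage sketch (process-group setup, model and optimizer are placeholders):
# gradients accumulate locally during backward and are reduced only when
# all_reduce() is called explicitly; wrap intermediate steps in no_sync() to
# skip communication when accumulating over several batches.
#
#     ddp = LegacyDistributedDataParallel(model, world_size=world_size)
#     loss = ddp(batch).sum()   # forward through the wrapped module
#     loss.backward()           # grads stay local
#     ddp.all_reduce()          # explicit bucketed gradient all-reduce
#     optimizer.step()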