Commit c0f05c10 authored by hepj's avatar hepj
Browse files

更新transformer代码

parent c056df78
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import torch
from torch.utils.data.dataloader import default_collate
from fairseq.data import ConcatDataset
logger = logging.getLogger(__name__)
class TransformEosConcatLangPairDataset(ConcatDataset):
    """
    It is a combination of TransformEosLangPairDataset and ConcatDataset for multiple LangPairDataset datasets.
    Assume all datasets share the same src_eos, tgt_bos, left_pad_source and left_pad_target

    Every item is returned as ``(dataset_idx, sample)`` so that
    :func:`collater` can substitute a per-dataset EOS/BOS symbol once the
    batch has been collated.

    Args:
        datasets: datasets to concatenate
        src_eos: source EOS index that collated ``src_tokens`` are asserted
            to contain at the end of every sentence
        tgt_bos: target BOS index that ``prev_output_tokens`` are asserted
            to start with
        new_src_eos (optional): replacement source EOS indices, one per dataset
        new_tgt_bos (optional): replacement target BOS indices, one per dataset
    """

    def __init__(
        self,
        datasets,
        src_eos,
        tgt_bos,
        new_src_eos=None,
        new_tgt_bos=None,
    ):
        super().__init__(datasets)
        if new_src_eos is not None:
            assert len(new_src_eos) == len(datasets)
        else:
            new_src_eos = []
        if new_tgt_bos is not None:
            assert len(new_tgt_bos) == len(datasets)
        else:
            new_tgt_bos = []
        self.src_eos = src_eos
        self.tgt_bos = tgt_bos
        # Replacement symbols are kept as CPU LongTensors so they can be
        # indexed with the list of per-sample dataset ids in ``collater``;
        # an empty list means "no replacement requested".
        self.new_src_eos = (
            torch.LongTensor(new_src_eos).cpu() if len(new_src_eos) > 0 else []
        )
        self.new_tgt_bos = (
            torch.LongTensor(new_tgt_bos).cpu() if len(new_tgt_bos) > 0 else []
        )
        self.left_pad_source = self.is_left_pad_source(datasets)
        self.left_pad_target = self.is_left_pad_target(datasets)
        self.pad_idx = self.src_dict_pad()

    def src_dict_pad(self):
        """Return the pad index of the first dataset's source dictionary.

        Looks for ``src_dict`` on the first dataset, then on its wrapped
        ``dataset`` attribute.

        Raises:
            NotImplementedError: if neither location provides ``src_dict``.
        """
        if hasattr(self.datasets[0], "src_dict"):
            return self.datasets[0].src_dict.pad()
        if hasattr(self.datasets[0], "dataset"):
            return self.datasets[0].dataset.src_dict.pad()
        raise NotImplementedError("No src_dict is found")

    def __getitem__(self, idx):
        """Return ``(dataset_idx, sample)`` for the global index *idx*."""
        dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
        return dataset_idx, self.datasets[dataset_idx][sample_idx]

    def is_left_pad_source(self, datasets):
        """Return the ``left_pad_source`` setting shared by all *datasets*.

        Wrapped datasets are unwrapped through their ``dataset`` attribute;
        a dataset exposing neither attribute defaults to ``True``.

        Raises:
            ValueError: if the datasets disagree on the setting.
        """

        def _left_pad_source(ds):
            if hasattr(ds, "left_pad_source"):
                return ds.left_pad_source
            if hasattr(ds, "dataset"):
                return _left_pad_source(ds.dataset)
            # Logger.warn is a deprecated alias; use Logger.warning.
            logger.warning(
                "%s has no left_pad_source, using default True", type(ds)
            )
            return True

        left_pad_source = _left_pad_source(datasets[0])
        for ds in datasets:
            if left_pad_source != _left_pad_source(ds):
                raise ValueError("Different left_pad_source setting detected!")
        return left_pad_source

    def is_left_pad_target(self, datasets):
        """Return the ``left_pad_target`` setting shared by all *datasets*.

        Wrapped datasets are unwrapped through their ``dataset`` attribute;
        a dataset exposing neither attribute defaults to ``False``.

        Raises:
            ValueError: if the datasets disagree on the setting.
        """

        def _left_pad_target(ds):
            if hasattr(ds, "left_pad_target"):
                return ds.left_pad_target
            if hasattr(ds, "dataset"):
                return _left_pad_target(ds.dataset)
            # Logger.warn is a deprecated alias; use Logger.warning.
            logger.warning(
                "%s has no left_pad_target, using default False", type(ds)
            )
            return False

        left_pad_target = _left_pad_target(datasets[0])
        for ds in datasets:
            if left_pad_target != _left_pad_target(ds):
                raise ValueError("Different left_pad_target setting detected!")
        return left_pad_target

    def collater(self, samples, **extra_args):
        """Collate ``(dataset_idx, sample)`` pairs and swap EOS/BOS symbols.

        The samples are collated with the first dataset's ``collater`` (or
        ``default_collate`` if it has none). The source EOS and the first
        token of ``prev_output_tokens`` are then overwritten in place with
        the replacement symbol of the dataset each row came from.

        Raises:
            NotImplementedError: if a target BOS replacement is requested
                while ``left_pad_target`` is true.
        """
        if len(samples) == 0:
            return samples
        dataset_ids = [s[0] for s in samples]
        samples = [s[1] for s in samples]
        if hasattr(self.datasets[0], "collater"):
            samples = self.datasets[0].collater(samples, **extra_args)
        else:
            samples = default_collate(samples, **extra_args)
        if len(self.new_src_eos) > 0:
            if self.left_pad_source:
                # With left padding every sentence ends in the last column.
                assert (
                    samples["net_input"]["src_tokens"][:, -1] != self.src_eos
                ).sum() == 0
                samples["net_input"]["src_tokens"][:, -1] = self.new_src_eos[
                    dataset_ids
                ]
            else:
                # With right padding the EOS sits at position length - 1.
                eos_idx = samples["net_input"]["src_lengths"] - 1
                assert (
                    samples["net_input"]["src_tokens"][
                        torch.arange(eos_idx.size(0)), eos_idx
                    ]
                    != self.src_eos
                ).sum() == 0
                samples["net_input"]["src_tokens"].scatter_(
                    1, eos_idx.view(-1, 1), self.new_src_eos[dataset_ids].view(-1, 1)
                )
        if len(self.new_tgt_bos) > 0 and "prev_output_tokens" in samples["net_input"]:
            if self.left_pad_target:
                # TODO: support different padding direction on target side
                raise NotImplementedError(
                    "TransformEosLangPairDataset does not implement --left-pad-target True option"
                )
            else:
                assert (
                    samples["net_input"]["prev_output_tokens"][:, 0] != self.tgt_bos
                ).sum() == 0
                samples["net_input"]["prev_output_tokens"][:, 0] = self.new_tgt_bos[
                    dataset_ids
                ]
        return samples
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from . import FairseqDataset
class TransformEosDataset(FairseqDataset):
    """Wrap a :class:`~fairseq.data.FairseqDataset` and add or strip a trailing EOS.

    Items are returned untouched by ``__getitem__``; the EOS edit happens
    when a batch is assembled in :func:`collater`.

    Args:
        dataset (~fairseq.data.FairseqDataset): dataset to wrap
        eos (int): index of the end-of-sentence symbol
        append_eos_to_src (bool, optional): append EOS to the end of src
        remove_eos_from_src (bool, optional): remove EOS from the end of src
        append_eos_to_tgt (bool, optional): append EOS to the end of tgt
        remove_eos_from_tgt (bool, optional): remove EOS from the end of tgt
        has_target (bool, optional): whether the wrapped items carry a target
    """

    def __init__(
        self,
        dataset,
        eos,
        append_eos_to_src=False,
        remove_eos_from_src=False,
        append_eos_to_tgt=False,
        remove_eos_from_tgt=False,
        has_target=True,
    ):
        if not isinstance(dataset, FairseqDataset):
            raise ValueError("dataset must be an instance of FairseqDataset")
        if append_eos_to_src and remove_eos_from_src:
            raise ValueError("cannot combine append_eos_to_src and remove_eos_from_src")
        if append_eos_to_tgt and remove_eos_from_tgt:
            raise ValueError("cannot combine append_eos_to_tgt and remove_eos_from_tgt")

        self.dataset = dataset
        self.eos = torch.LongTensor([eos])
        self.append_eos_to_src = append_eos_to_src
        self.remove_eos_from_src = remove_eos_from_src
        self.append_eos_to_tgt = append_eos_to_tgt
        self.remove_eos_from_tgt = remove_eos_from_tgt
        self.has_target = has_target

        # Net change in reported length on each side: +1 when appending,
        # -1 when removing, 0 otherwise.
        self._src_delta = int(bool(append_eos_to_src)) - int(bool(remove_eos_from_src))
        self._tgt_delta = int(bool(append_eos_to_tgt)) - int(bool(remove_eos_from_tgt))

        # The EOS sanity checks run only once per side.
        self._checked_src = False
        self._checked_tgt = False

    def _check_src(self, src, expect_eos):
        """On the first call only, assert whether *src* ends with EOS."""
        if self._checked_src:
            return
        assert (src[-1] == self.eos[0]) == expect_eos
        self._checked_src = True

    def _check_tgt(self, tgt, expect_eos):
        """On the first call only (and only with targets), assert whether *tgt* ends with EOS."""
        if not self.has_target or self._checked_tgt:
            return
        assert (tgt[-1] == self.eos[0]) == expect_eos
        self._checked_tgt = True

    def __getitem__(self, index):
        return self.dataset[index]

    def __len__(self):
        return len(self.dataset)

    def collater(self, samples):
        """Apply the configured EOS edits to each sample, then delegate collation."""
        # (item key, append flag, remove flag, one-shot checker), in the
        # order the edits are applied: source first, then target.
        plan = (
            ("source", self.append_eos_to_src, self.remove_eos_from_src, self._check_src),
            ("target", self.append_eos_to_tgt, self.remove_eos_from_tgt, self._check_tgt),
        )

        def transform(item):
            for key, do_append, do_remove, check in plan:
                if do_append:
                    # Keep the EOS tensor on the same device as the data.
                    self.eos = self.eos.to(device=item[key].device)
                    check(item[key], expect_eos=False)
                    item[key] = torch.cat([item[key], self.eos])
                if do_remove:
                    self.eos = self.eos.to(device=item[key].device)
                    check(item[key], expect_eos=True)
                    item[key] = item[key][:-1]
            return item

        return self.dataset.collater([transform(sample) for sample in samples])

    def num_tokens(self, index):
        return self.dataset.num_tokens(index)

    def size(self, index):
        """Return the wrapped size, adjusted by the EOS deltas when targets exist."""
        wrapped = self.dataset.size(index)
        if not self.has_target:
            return wrapped
        src_len, tgt_len = wrapped
        return (src_len + self._src_delta, tgt_len + self._tgt_delta)

    def ordered_indices(self):
        # NOTE: adding or removing EOS is assumed not to change the ordering
        return self.dataset.ordered_indices()

    @property
    def supports_prefetch(self):
        return getattr(self.dataset, "supports_prefetch", False)

    def prefetch(self, indices):
        return self.dataset.prefetch(indices)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional
import torch
from . import FairseqDataset
class TransformEosLangPairDataset(FairseqDataset):
    """A :class:`~fairseq.data.FairseqDataset` wrapper that transform bos on
    collated samples of language pair dataset.

    Note that the transformation is applied in :func:`collater`.

    Args:
        dataset (~fairseq.data.FairseqDataset): dataset that collates sample into
            LanguagePairDataset schema
        src_eos (int): original source end-of-sentence symbol index to be replaced
        new_src_eos (int, optional): new end-of-sentence symbol index to replace source eos symbol
        tgt_bos (int, optional): original target beginning-of-sentence symbol index to be replaced
        new_tgt_bos (int, optional): new beginning-of-sentence symbol index to replace at the
            beginning of 'prev_output_tokens'
    """

    def __init__(
        self,
        dataset: FairseqDataset,
        src_eos: int,
        new_src_eos: Optional[int] = None,
        tgt_bos: Optional[int] = None,
        new_tgt_bos: Optional[int] = None,
    ):
        self.dataset = dataset
        self.src_eos = src_eos
        self.new_src_eos = new_src_eos
        self.tgt_bos = tgt_bos
        self.new_tgt_bos = new_tgt_bos

    def __getitem__(self, index):
        return self.dataset[index]

    def __len__(self):
        return len(self.dataset)

    def collater(self, samples, **extra_args):
        """Collate with the wrapped dataset, then overwrite EOS/BOS in place.

        Returns the collated batch unchanged when it is empty or has no
        ``net_input`` entry.

        Raises:
            NotImplementedError: if a target BOS replacement is requested
                while the wrapped dataset uses ``left_pad_target``.
        """
        samples = self.dataset.collater(samples, **extra_args)
        if len(samples) == 0:
            return samples
        if "net_input" not in samples:
            return samples
        if self.new_src_eos is not None:
            if self.dataset.left_pad_source:
                # With left padding every sentence ends in the last column.
                assert (
                    samples["net_input"]["src_tokens"][:, -1] != self.src_eos
                ).sum() == 0
                samples["net_input"]["src_tokens"][:, -1] = self.new_src_eos
            else:
                # With right padding the EOS sits at position length - 1.
                eos_idx = samples["net_input"]["src_lengths"] - 1
                assert (
                    samples["net_input"]["src_tokens"][
                        torch.arange(eos_idx.size(0)), eos_idx
                    ]
                    != self.src_eos
                ).sum() == 0
                # Reshape to a column of indices with view() rather than
                # the low-level in-place resize_().
                eos_idx = eos_idx.view(-1, 1)
                samples["net_input"]["src_tokens"].scatter_(
                    1, eos_idx, self.new_src_eos
                )
        if (
            self.new_tgt_bos is not None
            and "prev_output_tokens" in samples["net_input"]
        ):
            if self.dataset.left_pad_target:
                # TODO: support different padding direction on target side
                raise NotImplementedError(
                    "TransformEosLangPairDataset does not implement --left-pad-target True option"
                )
            else:
                assert (
                    samples["net_input"]["prev_output_tokens"][:, 0] != self.tgt_bos
                ).sum() == 0
                samples["net_input"]["prev_output_tokens"][:, 0] = self.new_tgt_bos
        return samples

    def num_tokens(self, index):
        return self.dataset.num_tokens(index)

    def size(self, index):
        return self.dataset.size(index)

    @property
    def sizes(self):
        # dataset.sizes can be a dynamically computed sizes:
        return self.dataset.sizes

    def ordered_indices(self):
        return self.dataset.ordered_indices()

    @property
    def supports_prefetch(self):
        return getattr(self.dataset, "supports_prefetch", False)

    def prefetch(self, indices):
        return self.dataset.prefetch(indices)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .configs import FairseqDataclass
from .constants import ChoiceEnum
# Names exported by `from <this package> import *`; both are imported above.
__all__ = [
    "FairseqDataclass",
    "ChoiceEnum",
]
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import sys
from dataclasses import _MISSING_TYPE, dataclass, field
from typing import Any, List, Optional
import torch
from omegaconf import II, MISSING
from fairseq.dataclass.constants import (
DATASET_IMPL_CHOICES,
DDP_BACKEND_CHOICES,
DDP_COMM_HOOK_CHOICES,
GENERATION_CONSTRAINTS_CHOICES,
GENERATION_DECODING_FORMAT_CHOICES,
LOG_FORMAT_CHOICES,
PIPELINE_CHECKPOINT_CHOICES,
PRINT_ALIGNMENT_CHOICES,
ZERO_SHARDING_CHOICES,
)
@dataclass
class FairseqDataclass:
    """Base dataclass for fairseq configs, with helpers to inspect fields and their metadata."""

    _name: Optional[str] = None

    @staticmethod
    def name():
        """Return ``None`` in this base class."""
        return None

    def _get_all_attributes(self) -> List[str]:
        """List the names of all dataclass fields, in declaration order."""
        return list(self.__dataclass_fields__.keys())

    def _get_meta(
        self, attribute_name: str, meta: str, default: Optional[Any] = None
    ) -> Any:
        """Look up key *meta* in the field's metadata, falling back to *default*."""
        spec = self.__dataclass_fields__[attribute_name]
        return spec.metadata.get(meta, default)

    def _get_name(self, attribute_name: str) -> str:
        """Return the declared name of the field."""
        spec = self.__dataclass_fields__[attribute_name]
        return spec.name

    def _get_default(self, attribute_name: str) -> Any:
        """Return the effective default for a field.

        Interpolation strings (starting with ``${``) win first, either on
        the instance or on the declaration; then an instance value that
        differs from the declared default; then the default factory's
        result; and finally the declared default.
        """
        if hasattr(self, attribute_name):
            current = getattr(self, attribute_name)
            if str(current).startswith("${"):
                return str(current)
            declared = self.__dataclass_fields__[attribute_name].default
            if str(declared).startswith("${"):
                return str(declared)
            if current != declared:
                return current

        spec = self.__dataclass_fields__[attribute_name]
        if isinstance(spec.default_factory, _MISSING_TYPE):
            return spec.default
        return spec.default_factory()

    def _get_type(self, attribute_name: str) -> Any:
        """Return the annotated type of the field."""
        spec = self.__dataclass_fields__[attribute_name]
        return spec.type

    def _get_help(self, attribute_name: str) -> Any:
        """Return the ``help`` metadata entry, or ``None``."""
        return self._get_meta(attribute_name, "help")

    def _get_argparse_const(self, attribute_name: str) -> Any:
        """Return the ``argparse_const`` metadata entry, or ``None``."""
        return self._get_meta(attribute_name, "argparse_const")

    def _get_argparse_alias(self, attribute_name: str) -> Any:
        """Return the ``argparse_alias`` metadata entry, or ``None``."""
        return self._get_meta(attribute_name, "argparse_alias")

    def _get_choices(self, attribute_name: str) -> Any:
        """Return the ``choices`` metadata entry, or ``None``."""
        return self._get_meta(attribute_name, "choices")

    @classmethod
    def from_namespace(cls, args):
        """Build a config from *args*, copying every public field it defines.

        An object that is already an instance of *cls* is returned as-is.
        Fields whose name starts with an underscore are never copied.
        """
        if isinstance(args, cls):
            return args

        config = cls()
        public_fields = [
            key for key in config.__dataclass_fields__.keys() if not key.startswith("_")
        ]
        for key in public_fields:
            if hasattr(args, key):
                setattr(config, key, getattr(args, key))
        return config
@dataclass
class CommonConfig(FairseqDataclass):
    """Options shared by every job type: logging, seeding, device/precision and misc runtime knobs."""

    # This is the core dataclass including common parameters shared by all different jobs. Please append your params to other dataclasses if they were
    # used for a particular purpose or task, such as those dedicated for `distributed training`, `optimization`, etc.

    # --- progress reporting and metric logging back-ends ---
    no_progress_bar: bool = field(
        default=False, metadata={"help": "disable progress bar"}
    )
    log_interval: int = field(
        default=100,
        metadata={
            "help": "log progress every N batches (when progress bar is disabled)"
        },
    )
    log_format: Optional[LOG_FORMAT_CHOICES] = field(
        default=None, metadata={"help": "log format to use"}
    )
    log_file: Optional[str] = field(
        default=None, metadata={"help": "log file to copy metrics to."}
    )
    aim_repo: Optional[str] = field(
        default=None,
        metadata={"help": "path to Aim repository"},
    )
    aim_run_hash: Optional[str] = field(
        default=None,
        metadata={
            "help": "Aim run hash. If skipped, creates or continues run "
            "based on save_dir"
        },
    )
    tensorboard_logdir: Optional[str] = field(
        default=None,
        metadata={
            "help": "path to save logs for tensorboard, should match --logdir "
            "of running tensorboard (default: no tensorboard logging)"
        },
    )
    wandb_project: Optional[str] = field(
        default=None,
        metadata={"help": "Weights and Biases project name to use for logging"},
    )
    azureml_logging: Optional[bool] = field(
        default=False,
        metadata={"help": "Log scalars to AzureML context"},
    )

    # --- reproducibility and device selection ---
    seed: int = field(
        default=1, metadata={"help": "pseudo random number generator seed"}
    )
    cpu: bool = field(default=False, metadata={"help": "use CPU instead of CUDA"})
    tpu: bool = field(default=False, metadata={"help": "use TPU instead of CUDA"})

    # --- reduced-precision training (bf16 / fp16) and loss scaling ---
    bf16: bool = field(default=False, metadata={"help": "use bfloat16; implies --tpu"})
    memory_efficient_bf16: bool = field(
        default=False,
        metadata={
            "help": "use a memory-efficient version of BF16 training; implies --bf16"
        },
    )
    fp16: bool = field(default=False, metadata={"help": "use FP16"})
    memory_efficient_fp16: bool = field(
        default=False,
        metadata={
            "help": "use a memory-efficient version of FP16 training; implies --fp16"
        },
    )
    fp16_no_flatten_grads: bool = field(
        default=False, metadata={"help": "don't flatten FP16 grads tensor"}
    )
    fp16_init_scale: int = field(
        default=2**7, metadata={"help": "default FP16 loss scale"}
    )
    fp16_scale_window: Optional[int] = field(
        default=None,
        metadata={"help": "number of updates before increasing loss scale"},
    )
    fp16_scale_tolerance: float = field(
        default=0.0,
        metadata={
            "help": "pct of updates that can overflow before decreasing the loss scale"
        },
    )
    on_cpu_convert_precision: bool = field(
        default=False,
        metadata={
            "help": "if set, the floating point conversion to fp16/bf16 runs on CPU. "
            "This reduces bus transfer time and GPU memory usage."
        },
    )
    min_loss_scale: float = field(
        default=1e-4,
        metadata={
            "help": "minimum FP16/AMP loss scale, after which training is stopped"
        },
    )
    threshold_loss_scale: Optional[float] = field(
        default=None, metadata={"help": "threshold FP16 loss scale from below"}
    )

    # --- automatic mixed precision (AMP) ---
    amp: bool = field(default=False, metadata={"help": "use automatic mixed precision"})
    amp_batch_retries: int = field(
        default=2,
        metadata={
            "help": "number of retries of same batch after reducing loss scale with AMP"
        },
    )
    amp_init_scale: int = field(
        default=2**7, metadata={"help": "default AMP loss scale"}
    )
    amp_scale_window: Optional[int] = field(
        default=None,
        metadata={"help": "number of updates before increasing AMP loss scale"},
    )

    # --- miscellaneous runtime options ---
    user_dir: Optional[str] = field(
        default=None,
        metadata={
            "help": "path to a python module containing custom extensions (tasks and/or architectures)"
        },
    )
    empty_cache_freq: int = field(
        default=0,
        metadata={"help": "how often to clear the PyTorch CUDA cache (0 to disable)"},
    )
    all_gather_list_size: int = field(
        default=16384,
        metadata={"help": "number of bytes reserved for gathering stats from workers"},
    )
    model_parallel_size: int = field(
        default=1, metadata={"help": "total number of GPUs to parallelize model over"}
    )
    quantization_config_path: Optional[str] = field(
        default=None, metadata={"help": "path to quantization config file"}
    )
    profile: bool = field(
        default=False, metadata={"help": "enable autograd profiler emit_nvtx"}
    )
    reset_logging: bool = field(
        default=False,
        metadata={
            "help": "when using Hydra, reset the logging at the beginning of training"
        },
    )
    suppress_crashes: bool = field(
        default=False,
        metadata={
            "help": "suppress crashes when training with the hydra_train entry point so that the "
            "main method can return a value (useful for sweeps)"
        },
    )
    use_plasma_view: bool = field(
        default=False, metadata={"help": "Store indices and sizes in shared memory"}
    )
    plasma_path: Optional[str] = field(
        default="/tmp/plasma",
        metadata={
            "help": "path to run plasma_store, defaults to /tmp/plasma. Paths outside /tmp tend to fail."
        },
    )
@dataclass
class DistributedTrainingConfig(FairseqDataclass):
    """Options for multi-process / multi-GPU training.

    Several defaults are computed when this module is imported: the world
    size, process count and per-node process count come from
    ``torch.cuda.device_count()`` (at least 1), and ``device_id`` comes from
    the ``LOCAL_RANK`` environment variable.
    """

    # --- process topology and rendezvous ---
    distributed_world_size: int = field(
        default=max(1, torch.cuda.device_count()),
        metadata={
            "help": "total number of GPUs across all nodes (default: all visible GPUs)"
        },
    )
    distributed_num_procs: Optional[int] = field(
        default=max(1, torch.cuda.device_count()),
        metadata={
            "help": "total number of processes to fork (default: all visible GPUs)"
        },
    )
    distributed_rank: Optional[int] = field(
        default=0, metadata={"help": "rank of the current worker"}
    )
    distributed_backend: str = field(
        default="nccl", metadata={"help": "distributed backend"}
    )
    distributed_init_method: Optional[str] = field(
        default=None,
        metadata={
            "help": "typically tcp://hostname:port that will be used to "
            "establish initial connetion"
        },
    )
    distributed_port: int = field(
        default=-1,
        metadata={
            "help": "port number (not required if using --distributed-init-method)"
        },
    )
    device_id: int = field(
        # os.getenv returns a str when LOCAL_RANK is set; convert so the
        # default always matches the declared int type.
        default=int(os.getenv("LOCAL_RANK", 0)),
        metadata={
            "help": "which GPU to use (by default looks for $LOCAL_RANK, usually configured automatically)",
            "argparse_alias": "--local_rank",
        },
    )
    distributed_no_spawn: bool = field(
        default=False,
        metadata={
            "help": "do not spawn multiple processes even if multiple GPUs are visible"
        },
    )

    # --- DistributedDataParallel behaviour ---
    ddp_backend: DDP_BACKEND_CHOICES = field(
        default="pytorch_ddp", metadata={"help": "DistributedDataParallel backend"}
    )
    ddp_comm_hook: DDP_COMM_HOOK_CHOICES = field(
        default="none", metadata={"help": "communication hook"}
    )
    bucket_cap_mb: int = field(
        default=25, metadata={"help": "bucket size for reduction"}
    )
    fix_batches_to_gpus: bool = field(
        default=False,
        metadata={
            "help": "don't shuffle batches between GPUs; this reduces overall "
            "randomness and may affect precision but avoids the cost of re-reading the data"
        },
    )
    # NOTE(review): the help text says "disable" although the option name
    # suggests it enables detection — confirm against the code consuming it.
    find_unused_parameters: bool = field(
        default=False,
        metadata={
            "help": "disable unused parameter detection (not applicable to "
            "--ddp-backend=legacy_ddp)"
        },
    )
    gradient_as_bucket_view: bool = field(
        default=False,
        metadata={
            "help": "when set to True, gradients will be views pointing to different offsets of allreduce communication buckets. This can reduce peak memory usage, where the saved memory size will be equal to the total gradients size. "
            "--gradient-as-bucket-view=gradient_as_bucket_view)"
        },
    )
    fast_stat_sync: bool = field(
        default=False,
        metadata={"help": "[deprecated] this is now defined per Criterion"},
    )
    heartbeat_timeout: int = field(
        default=-1,
        metadata={
            "help": "kill the job if no progress is made in N seconds; "
            "set to -1 to disable"
        },
    )
    broadcast_buffers: bool = field(
        default=False,
        metadata={
            "help": "Copy non-trainable parameters between GPUs, such as "
            "batchnorm population statistics"
        },
    )

    # --- SlowMo / local SGD ---
    slowmo_momentum: Optional[float] = field(
        default=None,
        metadata={
            "help": "SlowMo momentum term; by default use 0.0 for 16 GPUs, "
            "0.2 for 32 GPUs; 0.5 for 64 GPUs, 0.6 for > 64 GPUs"
        },
    )
    slowmo_base_algorithm: str = field(
        default="localsgd",
        metadata={
            "help": "Base algorithm. Either 'localsgd' or 'sgp'. Please refer "
            "to the documentation of 'slowmo_base_algorithm' parameter in "
            "https://fairscale.readthedocs.io/en/latest/api/experimental/nn/slowmo_ddp.html "
            "for more details"
        },
    )
    localsgd_frequency: int = field(
        default=3, metadata={"help": "Local SGD allreduce frequency"}
    )
    nprocs_per_node: int = field(
        default=max(1, torch.cuda.device_count()),
        metadata={
            "help": "number of GPUs in each node. An allreduce operation across GPUs in "
            "a node is very fast. Hence, we do allreduce across GPUs in a node, "
            "and gossip across different nodes"
        },
    )

    # --- pipeline model parallelism ---
    pipeline_model_parallel: bool = field(
        default=False,
        metadata={"help": "if set, use pipeline model parallelism across GPUs"},
    )
    pipeline_balance: Optional[str] = field(
        default=None,
        metadata={
            "help": "partition the model into N_K pieces, where each piece "
            "contains N_i layers. The sum(args.pipeline_balance) "
            "should equal the total number of layers in the model"
        },
    )
    pipeline_devices: Optional[str] = field(
        default=None,
        metadata={
            "help": "a list of device indices indicating which device to place "
            "each of the N_K partitions. The length of this list should "
            "equal the length of the --pipeline-balance argument"
        },
    )
    pipeline_chunks: Optional[int] = field(
        default=0, metadata={"help": "microbatch count for pipeline model parallelism"}
    )
    pipeline_encoder_balance: Optional[str] = field(
        default=None,
        metadata={
            "help": "partition the pipeline parallel encoder into N_K pieces, where each piece "
            "contains N_i layers. The sum(args.pipeline_encoder_balance) "
            "should equal the total number of encoder layers in the model"
        },
    )
    pipeline_encoder_devices: Optional[str] = field(
        default=None,
        metadata={
            "help": "a list of device indices indicating which device to place "
            "each of the N_K partitions. The length of this list should "
            "equal the length of the --pipeline-encoder-balance argument"
        },
    )
    pipeline_decoder_balance: Optional[str] = field(
        default=None,
        metadata={
            "help": "partition the pipeline parallel decoder into N_K pieces, where each piece "
            "contains N_i layers. The sum(args.pipeline_decoder_balance) "
            "should equal the total number of decoder layers in the model"
        },
    )
    pipeline_decoder_devices: Optional[str] = field(
        default=None,
        metadata={
            "help": "a list of device indices indicating which device to place "
            "each of the N_K partitions. The length of this list should "
            "equal the length of the --pipeline-decoder-balance argument"
        },
    )
    pipeline_checkpoint: PIPELINE_CHECKPOINT_CHOICES = field(
        default="never",
        metadata={"help": "checkpointing mode for pipeline model parallelism"},
    )
    zero_sharding: ZERO_SHARDING_CHOICES = field(
        default="none", metadata={"help": "ZeRO sharding"}
    )

    # Values taken from the `common` config group through omegaconf's II().
    fp16: bool = II("common.fp16")
    memory_efficient_fp16: bool = II("common.memory_efficient_fp16")
    tpu: bool = II("common.tpu")

    # configuration for --ddp-backend=fully_sharded
    no_reshard_after_forward: bool = field(
        default=False,
        metadata={"help": "don't reshard parameters after forward pass"},
    )
    fp32_reduce_scatter: bool = field(
        default=False,
        metadata={"help": "reduce-scatter grads in FP32"},
    )
    cpu_offload: bool = field(
        default=False, metadata={"help": "offload FP32 params to CPU"}
    )
    use_sharded_state: bool = field(
        default=False,
        metadata={"help": "use sharded checkpoint files"},
    )
    not_fsdp_flatten_parameters: bool = field(
        default=False,
        metadata={"help": "not flatten parameter param for fsdp"},
    )
@dataclass
class DatasetConfig(FairseqDataclass):
    """Options for data loading, batching, validation scheduling and generation sharding."""

    # --- loading and batching ---
    num_workers: int = field(
        default=1, metadata={"help": "how many subprocesses to use for data loading"}
    )
    skip_invalid_size_inputs_valid_test: bool = field(
        default=False,
        metadata={"help": "ignore too long or too short lines in valid and test set"},
    )
    max_tokens: Optional[int] = field(
        default=None, metadata={"help": "maximum number of tokens in a batch"}
    )
    batch_size: Optional[int] = field(
        default=None,
        metadata={
            "help": "number of examples in a batch",
            "argparse_alias": "--max-sentences",
        },
    )
    required_batch_size_multiple: int = field(
        default=8, metadata={"help": "batch size will be a multiplier of this value"}
    )
    required_seq_len_multiple: int = field(
        default=1,
        metadata={
            "help": "maximum sequence length in batch will be a multiplier of this value"
        },
    )
    dataset_impl: Optional[DATASET_IMPL_CHOICES] = field(
        default=None, metadata={"help": "output dataset implementation"}
    )
    data_buffer_size: int = field(
        default=10, metadata={"help": "Number of batches to preload"}
    )

    # --- train / validation subsets and validation schedule ---
    train_subset: str = field(
        default="train",
        metadata={"help": "data subset to use for training (e.g. train, valid, test)"},
    )
    valid_subset: str = field(
        default="valid",
        metadata={
            "help": "comma separated list of data subsets to use for validation"
            " (e.g. train, valid, test)"
        },
    )
    # NOTE(review): this help text is identical to valid_subset's and does
    # not describe a boolean flag — confirm the intended wording.
    combine_valid_subsets: Optional[bool] = field(
        default=None,
        metadata={
            "help": "comma separated list of data subsets to use for validation"
            " (e.g. train, valid, test)",
            "argparse_alias": "--combine-val",
        },
    )
    ignore_unused_valid_subsets: Optional[bool] = field(
        default=False,
        metadata={"help": "do not raise error if valid subsets are ignored"},
    )
    validate_interval: int = field(
        default=1, metadata={"help": "validate every N epochs"}
    )
    validate_interval_updates: int = field(
        default=0, metadata={"help": "validate every N updates"}
    )
    validate_after_updates: int = field(
        default=0, metadata={"help": "dont validate until reaching this many updates"}
    )
    fixed_validation_seed: Optional[int] = field(
        default=None, metadata={"help": "specified random seed for validation"}
    )
    disable_validation: bool = field(
        default=False, metadata={"help": "disable validation"}
    )
    # The two validation batch limits default to interpolations of the
    # corresponding training options in this same config group.
    max_tokens_valid: Optional[int] = field(
        default=II("dataset.max_tokens"),
        metadata={
            "help": "maximum number of tokens in a validation batch"
            " (defaults to --max-tokens)"
        },
    )
    batch_size_valid: Optional[int] = field(
        default=II("dataset.batch_size"),
        metadata={
            "help": "batch size of the validation batch (defaults to --batch-size)",
            "argparse_alias": "--max-sentences-valid",
        },
    )
    max_valid_steps: Optional[int] = field(
        default=None,
        metadata={"help": "How many batches to evaluate", "argparse_alias": "--nval"},
    )
    curriculum: int = field(
        default=0, metadata={"help": "don't shuffle batches for first N epochs"}
    )

    # --- generation subset and sharding ---
    gen_subset: str = field(
        default="test",
        metadata={"help": "data subset to generate (train, valid, test)"},
    )
    num_shards: int = field(
        default=1, metadata={"help": "shard generation over N shards"}
    )
    shard_id: int = field(
        default=0, metadata={"help": "id of the shard to generate (id < num_shards)"}
    )

    # --- shuffling and batch-iterator reuse ---
    grouped_shuffling: bool = field(
        default=False,
        metadata={
            "help": "shuffle batches in groups of num_shards to enable similar sequence lengths on each GPU worker when batches are sorted by length",
        },
    )
    update_epoch_batch_itr: bool = field(
        default=II("dataset.grouped_shuffling"),
        metadata={
            "help": "if true then prevents the reuse the epoch batch iterator by setting can_reuse_epoch_itr to false, defaults to --grouped-shuffling )",
        },
    )
    update_ordered_indices_seed: bool = field(
        default=False,
        metadata={
            "help": "if true then increment seed with epoch for getting batch iterators, defautls to False.",
        },
    )
@dataclass
class OptimizationConfig(FairseqDataclass):
    """Options controlling training length, gradient handling and learning rate."""

    # --- stopping criteria ---
    max_epoch: int = field(
        default=0, metadata={"help": "force stop training at specified epoch"}
    )
    max_update: int = field(
        default=0, metadata={"help": "force stop training at specified update"}
    )
    stop_time_hours: float = field(
        default=0,
        metadata={
            "help": "force stop training after specified cumulative time (if >0)"
        },
    )

    # --- gradient handling ---
    clip_norm: float = field(
        default=0.0, metadata={"help": "clip threshold of gradients"}
    )
    sentence_avg: bool = field(
        default=False,
        metadata={
            "help": "normalize gradients by the number of sentences in a batch"
            " (default is to normalize by number of tokens)"
        },
    )
    # List-valued fields use default_factory so instances do not share one list.
    update_freq: List[int] = field(
        default_factory=lambda: [1],
        metadata={"help": "update parameters every N_i batches, when in epoch i"},
    )
    lr: List[float] = field(
        default_factory=lambda: [0.25],
        metadata={
            "help": "learning rate for the first N epochs; all epochs >N using LR_N"
            " (note: this may be interpreted differently depending on --lr-scheduler)"
        },
    )
    stop_min_lr: float = field(
        default=-1.0,
        metadata={"help": "stop training when the learning rate reaches this minimum"},
    )
    use_bmuf: bool = field(
        default=False,
        metadata={
            "help": "specify global optimizer for syncing models on different GPUs/shards"
        },
    )
    # NOTE(review): the help text speaks of *including* the last batch while
    # the option name says *skip* — confirm against the consumer of this flag.
    skip_remainder_batch: Optional[bool] = field(
        default=False,
        metadata={
            "help": "if set, include the last (partial) batch of each epoch in training"
            " (default is to skip it)."
        },
    )
@dataclass
class CheckpointConfig(FairseqDataclass):
    """Options for saving, restoring and pruning checkpoints, plus early stopping."""

    # --- where to save and what to restore from ---
    save_dir: str = field(
        default="checkpoints", metadata={"help": "path to save checkpoints"}
    )
    # NOTE(review): the help string below is missing its closing parenthesis.
    restore_file: str = field(
        default="checkpoint_last.pt",
        metadata={
            "help": "filename from which to load checkpoint "
            "(default: <save-dir>/checkpoint_last.pt"
        },
    )
    continue_once: Optional[str] = field(
        default=None,
        metadata={
            "help": "continues from this checkpoint, unless a checkpoint indicated in 'restore_file' option is present"
        },
    )
    finetune_from_model: Optional[str] = field(
        default=None,
        metadata={
            "help": "finetune from a pretrained model; note that meters and lr scheduler will be reset"
        },
    )

    # --- which pieces of checkpoint state to ignore when loading ---
    reset_dataloader: bool = field(
        default=False,
        metadata={
            "help": "if set, does not reload dataloader state from the checkpoint"
        },
    )
    reset_lr_scheduler: bool = field(
        default=False,
        metadata={
            "help": "if set, does not load lr scheduler state from the checkpoint"
        },
    )
    reset_meters: bool = field(
        default=False,
        metadata={"help": "if set, does not load meters from the checkpoint"},
    )
    reset_optimizer: bool = field(
        default=False,
        metadata={"help": "if set, does not load optimizer state from the checkpoint"},
    )
    optimizer_overrides: str = field(
        default="{}",
        metadata={
            "help": "a dictionary used to override optimizer args when loading a checkpoint"
        },
    )

    # --- save frequency and retention ---
    save_interval: int = field(
        default=1, metadata={"help": "save a checkpoint every N epochs"}
    )
    save_interval_updates: int = field(
        default=0, metadata={"help": "save a checkpoint (and validate) every N updates"}
    )
    keep_interval_updates: int = field(
        default=-1,
        metadata={
            "help": "keep the last N checkpoints saved with --save-interval-updates"
        },
    )
    keep_interval_updates_pattern: int = field(
        default=-1,
        metadata={
            "help": "when used with --keep-interval-updates, skips deleting "
            "any checkpoints with update X where "
            "X %% keep_interval_updates_pattern == 0"
        },
    )
    keep_last_epochs: int = field(
        default=-1, metadata={"help": "keep last N epoch checkpoints"}
    )
    keep_best_checkpoints: int = field(
        default=-1, metadata={"help": "keep best N checkpoints based on scores"}
    )
    no_save: bool = field(
        default=False, metadata={"help": "don't save models or checkpoints"}
    )
    no_epoch_checkpoints: bool = field(
        default=False, metadata={"help": "only store last and best checkpoints"}
    )
    no_last_checkpoints: bool = field(
        default=False, metadata={"help": "don't store last checkpoints"}
    )
    no_save_optimizer_state: bool = field(
        default=False,
        metadata={"help": "don't save optimizer-state as part of checkpoint"},
    )

    # --- "best" checkpoint selection and early stopping ---
    best_checkpoint_metric: str = field(
        default="loss", metadata={"help": 'metric to use for saving "best" checkpoints'}
    )
    maximize_best_checkpoint_metric: bool = field(
        default=False,
        metadata={
            "help": 'select the largest metric value for saving "best" checkpoints'
        },
    )
    patience: int = field(
        default=-1,
        metadata={
            "help": (
                "early stop training if valid performance doesn't "
                "improve for N consecutive validation runs; note "
                "that this is influenced by --validate-interval"
            )
        },
    )

    # --- file naming, sharding and write mode ---
    checkpoint_suffix: str = field(
        default="", metadata={"help": "suffix to add to the checkpoint file name"}
    )
    checkpoint_shard_count: int = field(
        default=1,
        metadata={
            "help": "Number of shards containing the checkpoint - "
            "if the checkpoint is over 300GB, it is preferable "
            "to split it into shards to prevent OOM on CPU while loading "
            "the checkpoint"
        },
    )
    load_checkpoint_on_all_dp_ranks: bool = field(
        default=False,
        metadata={
            "help": "load checkpoints on all data parallel devices "
            "(default: only load on rank 0 and broadcast to other devices)"
        },
    )
    write_checkpoints_asynchronously: bool = field(
        default=False,
        metadata={
            "help": (
                "Write checkpoints asynchronously in a separate "
                "thread. NOTE: This feature is currently being tested."
            ),
            "argparse_alias": "--save-async",
        },
    )
    # Taken from the `common` config group through omegaconf's II().
    model_parallel_size: int = II("common.model_parallel_size")
@dataclass
class FairseqBMUFConfig(FairseqDataclass):
    """Options for BMUF training; each field's ``help`` metadata describes it."""

    block_lr: float = field(
        default=1, metadata={"help": "block learning rate for bmuf"}
    )
    block_momentum: float = field(
        default=0.875, metadata={"help": "block momentum for bmuf"}
    )
    global_sync_iter: int = field(
        default=50, metadata={"help": "Iteration for syncing global model"}
    )
    warmup_iterations: int = field(
        default=500, metadata={"help": "warmup iterations for model to broadcast"}
    )
    use_nbm: bool = field(
        default=False,
        metadata={"help": "Specify whether you want to use classical BM / Nesterov BM"},
    )
    average_sync: bool = field(
        default=False,
        metadata={
            "help": "Specify whether you want to average the local momentum after each sync"
        },
    )
    # No literal default here: the value is II(...) applied to the key
    # "distributed_training.distributed_world_size".
    # NOTE(review): II is imported outside this view; presumably omegaconf's
    # interpolation helper -- confirm.
    distributed_world_size: int = II("distributed_training.distributed_world_size")
@dataclass
class GenerationConfig(FairseqDataclass):
    """Options for sequence generation.

    Covers beam search, sampling, constrained decoding, diverse beam search,
    LM fusion and iterative-refinement decoding. Each field's ``help``
    metadata is the authoritative description of that option.
    """

    beam: int = field(
        default=5,
        metadata={"help": "beam size"},
    )
    nbest: int = field(
        default=1,
        metadata={"help": "number of hypotheses to output"},
    )
    # max_len_a and max_len_b are the "a" and "b" of the length bound
    # "ax + b" that both help strings describe.
    max_len_a: float = field(
        default=0,
        metadata={
            "help": "generate sequences of maximum length ax + b, where x is the source length"
        },
    )
    max_len_b: int = field(
        default=200,
        metadata={
            "help": "generate sequences of maximum length ax + b, where x is the source length"
        },
    )
    min_len: int = field(
        default=1,
        metadata={"help": "minimum generation length"},
    )
    match_source_len: bool = field(
        default=False,
        metadata={"help": "generations should match the source length"},
    )
    unnormalized: bool = field(
        default=False,
        metadata={"help": "compare unnormalized hypothesis scores"},
    )
    # Marked deprecated by its own help text; kept as a field.
    no_early_stop: bool = field(
        default=False,
        metadata={"help": "deprecated"},
    )
    no_beamable_mm: bool = field(
        default=False,
        metadata={"help": "don't use BeamableMM in attention layers"},
    )
    lenpen: float = field(
        default=1,
        metadata={
            "help": "length penalty: <1.0 favors shorter, >1.0 favors longer sentences"
        },
    )
    unkpen: float = field(
        default=0,
        metadata={
            "help": "unknown word penalty: <0 produces more unks, >0 produces fewer"
        },
    )
    # "argparse_const" supplies the value used when the flag is given without
    # an argument (see how gen_parser_from_dataclass sets const / nargs="?").
    replace_unk: Optional[str] = field(
        default=None,
        metadata={
            "help": "perform unknown replacement (optionally with alignment dictionary)",
            "argparse_const": "@@ ",
        },
    )
    sacrebleu: bool = field(
        default=False,
        metadata={"help": "score with sacrebleu"},
    )
    score_reference: bool = field(
        default=False,
        metadata={"help": "just score the reference translation"},
    )
    prefix_size: int = field(
        default=0,
        metadata={"help": "initialize generation by target prefix of given length"},
    )
    no_repeat_ngram_size: int = field(
        default=0,
        metadata={
            "help": "ngram blocking such that this size ngram cannot be repeated in the generation"
        },
    )
    sampling: bool = field(
        default=False,
        metadata={"help": "sample hypotheses instead of using beam search"},
    )
    sampling_topk: int = field(
        default=-1,
        metadata={"help": "sample from top K likely next words instead of all words"},
    )
    sampling_topp: float = field(
        default=-1.0,
        metadata={
            "help": "sample from the smallest set whose cumulative probability mass exceeds p for next words"
        },
    )
    constraints: Optional[GENERATION_CONSTRAINTS_CHOICES] = field(
        default=None,
        metadata={
            "help": "enables lexically constrained decoding",
            "argparse_const": "ordered",
        },
    )
    temperature: float = field(
        default=1.0,
        metadata={"help": "temperature for generation"},
    )
    diverse_beam_groups: int = field(
        default=-1,
        metadata={"help": "number of groups for Diverse Beam Search"},
    )
    diverse_beam_strength: float = field(
        default=0.5,
        metadata={"help": "strength of diversity penalty for Diverse Beam Search"},
    )
    diversity_rate: float = field(
        default=-1.0,
        metadata={"help": "strength of diversity penalty for Diverse Siblings Search"},
    )
    print_alignment: Optional[PRINT_ALIGNMENT_CHOICES] = field(
        default=None,
        metadata={
            "help": "if set, uses attention feedback to compute and print alignment to source tokens "
            "(valid options are: hard, soft, otherwise treated as hard alignment)",
            "argparse_const": "hard",
        },
    )
    print_step: bool = field(
        default=False,
        metadata={"help": "print steps"},
    )
    lm_path: Optional[str] = field(
        default=None,
        metadata={"help": "path to lm checkpoint for lm fusion"},
    )
    lm_weight: float = field(
        default=0.0,
        metadata={"help": "weight for lm probs for lm fusion"},
    )

    # arguments for iterative refinement generator
    iter_decode_eos_penalty: float = field(
        default=0.0,
        metadata={"help": "if > 0.0, it penalized early-stopping in decoding."},
    )
    iter_decode_max_iter: int = field(
        default=10,
        metadata={"help": "maximum iterations for iterative refinement."},
    )
    iter_decode_force_max_iter: bool = field(
        default=False,
        metadata={
            "help": "if set, run exact the maximum number of iterations without early stop"
        },
    )
    iter_decode_with_beam: int = field(
        default=1,
        metadata={
            "help": "if > 1, model will generate translations varying by the lengths."
        },
    )
    iter_decode_with_external_reranker: bool = field(
        default=False,
        metadata={
            "help": "if set, the last checkpoint are assumed to be a reranker to rescore the translations"
        },
    )
    retain_iter_history: bool = field(
        default=False,
        metadata={
            "help": "if set, decoding returns the whole history of iterative refinement"
        },
    )
    retain_dropout: bool = field(
        default=False,
        metadata={"help": "Use dropout at inference time"},
    )
    # temporarily set to Any until https://github.com/facebookresearch/hydra/issues/1117 is fixed
    # retain_dropout_modules: Optional[List[str]] = field(
    retain_dropout_modules: Any = field(
        default=None,
        metadata={
            "help": "if set, only retain dropout for the specified modules; "
            "if not set, then dropout will be retained for all modules"
        },
    )
    # special decoding format for advanced decoding.
    decoding_format: Optional[GENERATION_DECODING_FORMAT_CHOICES] = field(
        default=None,
        metadata={"help": "special decoding format for advanced decoding."},
    )
    no_seed_provided: bool = field(
        default=False,
        metadata={"help": "if set, dont use seed for initializing random generators"},
    )
    eos_token: Optional[str] = field(
        default=None,
        metadata={"help": "EOS token"},
    )
@dataclass
class CommonEvalConfig(FairseqDataclass):
    """Common evaluation options: model path(s), post-processing, verbosity,
    model-argument overrides and an optional results path."""

    path: Optional[str] = field(
        default=None,
        metadata={"help": "path(s) to model file(s), colon separated"},
    )
    # Also exposed on the command line as --remove-bpe (argparse_alias); when
    # the flag is given without a value, "subword_nmt" is used (argparse_const).
    post_process: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "post-process text by removing BPE, letter segmentation, etc. "
                "Valid options can be found in fairseq.data.utils.post_process."
            ),
            "argparse_const": "subword_nmt",
            "argparse_alias": "--remove-bpe",
        },
    )
    quiet: bool = field(default=False, metadata={"help": "only print final scores"})
    # Stored as a string; the default "{}" is the text of an empty dictionary.
    model_overrides: str = field(
        default="{}",
        metadata={
            "help": "a dictionary used to override model args at generation that were used during model training"
        },
    )
    results_path: Optional[str] = field(
        default=None, metadata={"help": "path to save eval results (optional)"}
    )
@dataclass
class EvalLMConfig(FairseqDataclass):
    """Options for language-model evaluation (word-level output, context
    window size and softmax batching); see each field's help text."""

    output_word_probs: bool = field(
        default=False,
        metadata={
            "help": "if set, outputs words and their predicted log probabilities to standard output"
        },
    )
    output_word_stats: bool = field(
        default=False,
        metadata={
            "help": "if set, outputs word statistics such as word count, average probability, etc"
        },
    )
    context_window: int = field(
        default=0,
        metadata={
            "help": "ensures that every evaluated token has access to a context of at least this size, if possible"
        },
    )
    # Defaults to sys.maxsize, i.e. the largest value the threshold can take.
    softmax_batch: int = field(
        default=sys.maxsize,
        metadata={
            "help": "if BxT is more than this, will batch the softmax over vocab to this amount of tokens, in order to fit into GPU memory"
        },
    )
@dataclass
class InteractiveConfig(FairseqDataclass):
    """Options for interactive use: input buffering and the input source."""

    buffer_size: int = field(
        default=0,
        metadata={
            "help": "read this many sentences into a buffer before processing them"
        },
    )
    # "-" selects stdin, per the help text.
    input: str = field(
        default="-",
        metadata={"help": "file to read from; use - for stdin"},
    )
@dataclass
class EMAConfig(FairseqDataclass):
    """Options for keeping an exponential moving average (EMA) of the model.

    Each field's ``help`` metadata describes the option.
    """

    # The metadata key must be the string "help". The builtin ``help`` object
    # was used here before, so lookups of metadata["help"] found nothing and
    # the option had no help text.
    store_ema: bool = field(
        default=False,
        metadata={"help": "store exponential moving average shadow model"},
    )
    ema_decay: float = field(
        default=0.9999, metadata={"help": "decay for exponential moving average model"}
    )
    ema_start_update: int = field(
        default=0, metadata={"help": "start EMA update after this many model updates"}
    )
    ema_seed_model: Optional[str] = field(
        default=None,
        metadata={
            "help": "Seed to load EMA model from. "
            "Used to load EMA model separately from the actual model."
        },
    )
    ema_update_freq: int = field(
        default=1, metadata={"help": "Do EMA update every this many model updates"}
    )
    ema_fp32: bool = field(
        default=False,
        metadata={"help": "If true, store EMA model in fp32 even if model is in fp16"},
    )
@dataclass
class FairseqConfig(FairseqDataclass):
    """Top-level configuration grouping every sub-configuration.

    Fields typed with a concrete config class default to an instance created
    once, when this class body is evaluated. Fields typed ``Any`` are filled in
    later; ``add_defaults`` selects exactly those fields via ``v.type == Any``.
    """

    # NOTE(review): hydra_init reads ``__dataclass_fields__[k].default`` for
    # every field, so these must remain plain defaults (not default_factory)
    # unless hydra_init is changed together with them.
    common: CommonConfig = CommonConfig()
    common_eval: CommonEvalConfig = CommonEvalConfig()
    distributed_training: DistributedTrainingConfig = DistributedTrainingConfig()
    dataset: DatasetConfig = DatasetConfig()
    optimization: OptimizationConfig = OptimizationConfig()
    checkpoint: CheckpointConfig = CheckpointConfig()
    bmuf: FairseqBMUFConfig = FairseqBMUFConfig()
    generation: GenerationConfig = GenerationConfig()
    eval_lm: EvalLMConfig = EvalLMConfig()
    interactive: InteractiveConfig = InteractiveConfig()
    # MISSING marks the model entry as having no default value.
    model: Any = MISSING
    task: Any = None
    criterion: Any = None
    optimizer: Any = None
    lr_scheduler: Any = None
    scoring: Any = None
    bpe: Any = None
    tokenizer: Any = None
    ema: EMAConfig = EMAConfig()
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from enum import Enum, EnumMeta
from typing import List
class StrEnumMeta(EnumMeta):
    """Enum metaclass with a relaxed ``isinstance`` check."""

    # this is workaround for submitit pickling leading to instance checks failing in hydra for StrEnum, see
    # https://github.com/facebookresearch/hydra/issues/1156
    @classmethod
    def __instancecheck__(cls, other):
        """Return True when the textual form of ``type(other)`` mentions "enum"."""
        type_repr = str(type(other))
        return "enum" in type_repr
class StrEnum(Enum, metaclass=StrEnumMeta):
    """Enum whose members print, compare and hash through their ``value``."""

    def __str__(self):
        # str(member) is the member's value itself.
        return self.value

    def __eq__(self, other: str):
        # Compares the member's value against ``other`` directly.
        return self.value == other

    def __repr__(self):
        # repr(member) is also the bare value (no "<Choices.x: ...>" wrapper).
        return self.value

    def __hash__(self):
        # Defined alongside __eq__; hashes the same text that __str__ returns.
        return hash(str(self))
def ChoiceEnum(choices: List[str]):
    """Build the ``StrEnum`` class (named "Choices") that enforces *choices*.

    Every choice becomes a member whose name and value are both that choice.
    """
    members = dict((choice, choice) for choice in choices)
    return StrEnum("Choices", members)
# Each constant below is an Enum class produced by ChoiceEnum; its members are
# exactly the strings listed.
LOG_FORMAT_CHOICES = ChoiceEnum(["json", "none", "simple", "tqdm"])
DDP_BACKEND_CHOICES = ChoiceEnum(
    [
        "c10d",  # alias for pytorch_ddp
        "fully_sharded",  # FullyShardedDataParallel from fairscale
        "legacy_ddp",
        "no_c10d",  # alias for legacy_ddp
        "pytorch_ddp",
        "slowmo",
    ]
)
DDP_COMM_HOOK_CHOICES = ChoiceEnum(["none", "fp16"])
DATASET_IMPL_CHOICES = ChoiceEnum(["raw", "lazy", "cached", "mmap", "fasta", "huffman"])
GENERATION_CONSTRAINTS_CHOICES = ChoiceEnum(["ordered", "unordered"])
GENERATION_DECODING_FORMAT_CHOICES = ChoiceEnum(
    ["unigram", "ensemble", "vote", "dp", "bs"]
)
ZERO_SHARDING_CHOICES = ChoiceEnum(["none", "os"])
PIPELINE_CHECKPOINT_CHOICES = ChoiceEnum(["always", "never", "except_last"])
PRINT_ALIGNMENT_CHOICES = ChoiceEnum(["hard", "soft"])
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""isort:skip_file"""
import logging
from hydra.core.config_store import ConfigStore
from fairseq.dataclass.configs import FairseqConfig
from omegaconf import DictConfig, OmegaConf
logger = logging.getLogger(__name__)
def hydra_init(cfg_name="config") -> None:
    """Register ``FairseqConfig`` and each of its field defaults with hydra.

    The whole config is stored under *cfg_name*; every field's default is then
    stored under the field's own name. If storing a field fails, the field name
    and default are logged and the exception is re-raised.
    """
    store = ConfigStore.instance()
    store.store(name=f"{cfg_name}", node=FairseqConfig)

    for field_name, field_info in FairseqConfig.__dataclass_fields__.items():
        default_node = field_info.default
        try:
            store.store(name=field_name, node=default_node)
        except BaseException:
            logger.error(f"{field_name} - {default_node}")
            raise
def add_defaults(cfg: DictConfig) -> None:
    """This function adds default values that are stored in dataclasses that hydra doesn't know about

    For every ``FairseqConfig`` field typed ``Any`` that is present (not None)
    in *cfg*, the matching dataclass is looked up by the entry's ``_name`` and,
    when one is found, ``cfg[k]`` is replaced by
    ``merge_with_parent(dc, field_cfg)``. *cfg* is modified in place.
    """

    # Imported locally rather than at module level.
    # NOTE(review): presumably to avoid import cycles -- confirm.
    from fairseq.registry import REGISTRIES
    from fairseq.tasks import TASK_DATACLASS_REGISTRY
    from fairseq.models import ARCH_MODEL_NAME_REGISTRY, MODEL_DATACLASS_REGISTRY
    from fairseq.dataclass.utils import merge_with_parent
    from typing import Any

    # Struct mode is switched off before the entries are reassigned below.
    OmegaConf.set_struct(cfg, False)

    for k, v in FairseqConfig.__dataclass_fields__.items():
        field_cfg = cfg.get(k)
        if field_cfg is not None and v.type == Any:
            dc = None

            if isinstance(field_cfg, str):
                # A bare string is treated as the component name.
                field_cfg = DictConfig({"_name": field_cfg})
                # NOTE(review): this assigns the "_parent" entry to itself, so
                # it changes nothing (it would raise KeyError if the entry
                # were absent). Possibly a different parent was intended.
                field_cfg.__dict__["_parent"] = field_cfg.__dict__["_parent"]

            name = getattr(field_cfg, "_name", None)

            if k == "task":
                dc = TASK_DATACLASS_REGISTRY.get(name)
            elif k == "model":
                # Map the name through ARCH_MODEL_NAME_REGISTRY first, falling
                # back to the name itself when it has no entry there.
                name = ARCH_MODEL_NAME_REGISTRY.get(name, name)
                dc = MODEL_DATACLASS_REGISTRY.get(name)
            elif k in REGISTRIES:
                dc = REGISTRIES[k]["dataclass_registry"].get(name)

            if dc is not None:
                cfg[k] = merge_with_parent(dc, field_cfg)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import ast
import inspect
import logging
import os
import re
from argparse import ArgumentError, ArgumentParser, Namespace
from dataclasses import _MISSING_TYPE, MISSING, is_dataclass
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, Type
from fairseq.dataclass import FairseqDataclass
from fairseq.dataclass.configs import FairseqConfig
from hydra.core.global_hydra import GlobalHydra
from hydra.experimental import compose, initialize
from omegaconf import DictConfig, OmegaConf, open_dict, _utils
logger = logging.getLogger(__name__)
def eval_str_list(x, x_type=float):
    """Convert *x* into a list of ``x_type`` values.

    ``None`` is returned unchanged and an empty string gives ``[]``. Any other
    string is first parsed with ``ast.literal_eval``. The resulting (or given)
    value is converted element by element; if that raises ``TypeError`` (for
    example because the value is a scalar), the value itself is converted and
    wrapped in a one-element list.
    """
    if x is None:
        return None

    if isinstance(x, str):
        if not x:
            return []
        x = ast.literal_eval(x)

    try:
        return [x_type(item) for item in x]
    except TypeError:
        # Not iterable (or not convertible element-wise): treat as one value.
        return [x_type(x)]
def interpret_dc_type(field_type):
    """Reduce a dataclass field type to the type used for argument parsing.

    ``Any`` maps to ``str``. A type whose textual form is a ``Union[...]``
    ending in ``NoneType``, or which starts with ``typing.Optional``, maps to
    its first type argument. Everything else is returned unchanged.

    Raises:
        RuntimeError: if *field_type* is a string instead of a type.
    """
    if isinstance(field_type, str):
        raise RuntimeError("field should be a type")

    if field_type == Any:
        return str

    type_repr = str(field_type)
    union_with_none = re.match(r"(typing.|^)Union\[(.*), NoneType\]$", type_repr)
    if union_with_none or type_repr.startswith("typing.Optional"):
        return field_type.__args__[0]
    return field_type
def gen_parser_from_dataclass(
    parser: ArgumentParser,
    dataclass_instance: FairseqDataclass,
    delete_default: bool = False,
    with_prefix: Optional[str] = None,
) -> None:
    """
    Convert a dataclass instance into arguments added to *parser*.

    If `with_prefix` is provided, prefix all the keys in the resulting parser with it. It means that we are
    building a flat namespace from a structured dataclass (see transformer_config.py for example).

    Args:
        parser: parser that receives one argument per dataclass attribute.
        dataclass_instance: instance whose attributes are converted.
        delete_default: if True, the "default" entry is dropped from every
            argument before it is added.
        with_prefix: optional prefix for the generated option names.
    """

    def argparse_name(name: str):
        # Map an attribute name to its option string; None means "skip".
        if name == "data" and (with_prefix is None or with_prefix == ""):
            # normally data is positional args, so we don't add the -- nor the prefix
            return name
        if name == "_name":
            # private member, skip
            return None
        full_name = "--" + name.replace("_", "-")
        if with_prefix is not None and with_prefix != "":
            # if a prefix is specified, construct the prefixed arg name
            full_name = with_prefix + "-" + full_name[2:]  # strip -- when composing
        return full_name

    def get_kwargs_from_dc(
        dataclass_instance: FairseqDataclass, k: str
    ) -> Dict[str, Any]:
        """k: dataclass attributes

        Returns the keyword arguments for ``parser.add_argument`` for
        attribute *k*.
        """

        kwargs = {}

        field_type = dataclass_instance._get_type(k)
        inter_type = interpret_dc_type(field_type)

        field_default = dataclass_instance._get_default(k)

        if isinstance(inter_type, type) and issubclass(inter_type, Enum):
            field_choices = [t.value for t in list(inter_type)]
        else:
            field_choices = None

        field_help = dataclass_instance._get_help(k)
        field_const = dataclass_instance._get_argparse_const(k)

        if isinstance(field_default, str) and field_default.startswith("${"):
            # Interpolation string: pass it through untouched as the default.
            kwargs["default"] = field_default
        else:
            if field_default is MISSING:
                kwargs["required"] = True
            if field_choices is not None:
                kwargs["choices"] = field_choices
            if (
                isinstance(inter_type, type)
                and (issubclass(inter_type, List) or issubclass(inter_type, Tuple))
            ) or ("List" in str(inter_type) or "Tuple" in str(inter_type)):
                # List/Tuple fields: the element type is chosen by substring
                # match on the type's textual form, checked in this order.
                if "int" in str(inter_type):
                    kwargs["type"] = lambda x: eval_str_list(x, int)
                elif "float" in str(inter_type):
                    kwargs["type"] = lambda x: eval_str_list(x, float)
                elif "str" in str(inter_type):
                    kwargs["type"] = lambda x: eval_str_list(x, str)
                else:
                    raise NotImplementedError(
                        "parsing of type " + str(inter_type) + " is not implemented"
                    )
                if field_default is not MISSING:
                    # The default is given as a comma-joined string.
                    kwargs["default"] = (
                        ",".join(map(str, field_default))
                        if field_default is not None
                        else None
                    )
            elif (
                isinstance(inter_type, type) and issubclass(inter_type, Enum)
            ) or "Enum" in str(inter_type):
                # Enum fields are parsed as plain strings.
                kwargs["type"] = str
                if field_default is not MISSING:
                    if isinstance(field_default, Enum):
                        kwargs["default"] = field_default.value
                    else:
                        kwargs["default"] = field_default
            elif inter_type is bool:
                # A bool that defaults to True becomes a flag that turns it off.
                kwargs["action"] = (
                    "store_false" if field_default is True else "store_true"
                )
                kwargs["default"] = field_default
            else:
                kwargs["type"] = inter_type
                if field_default is not MISSING:
                    kwargs["default"] = field_default

        # build the help with the hierarchical prefix
        if with_prefix is not None and with_prefix != "" and field_help is not None:
            field_help = with_prefix[2:] + ": " + field_help

        kwargs["help"] = field_help
        if field_const is not None:
            # With a const the option may be given without a value.
            kwargs["const"] = field_const
            kwargs["nargs"] = "?"

        return kwargs

    for k in dataclass_instance._get_all_attributes():
        field_name = argparse_name(dataclass_instance._get_name(k))
        field_type = dataclass_instance._get_type(k)
        if field_name is None:
            continue
        elif inspect.isclass(field_type) and issubclass(field_type, FairseqDataclass):
            # for fields that are of type FairseqDataclass, we can recursively
            # add their fields to the namespace (so we add the args from model, task, etc. to the root namespace)
            prefix = None
            if with_prefix is not None:
                # if a prefix is specified, then we don't want to copy the subfields directly to the root namespace
                # but we prefix them with the name of the current field.
                prefix = field_name
            gen_parser_from_dataclass(parser, field_type(), delete_default, prefix)
            continue

        kwargs = get_kwargs_from_dc(dataclass_instance, k)

        field_args = [field_name]
        alias = dataclass_instance._get_argparse_alias(k)
        if alias is not None:
            field_args.append(alias)

        if "default" in kwargs:
            if isinstance(kwargs["default"], str) and kwargs["default"].startswith(
                "${"
            ):
                if kwargs["help"] is None:
                    # this is a field with a name that will be added elsewhere
                    continue
                else:
                    del kwargs["default"]
            if delete_default and "default" in kwargs:
                del kwargs["default"]
        try:
            parser.add_argument(*field_args, **kwargs)
        except ArgumentError:
            # Deliberately ignored.
            # NOTE(review): presumably raised for option strings the parser
            # already has -- confirm.
            pass
def _set_legacy_defaults(args, cls):
"""Helper to set default arguments based on *add_args*."""
if not hasattr(cls, "add_args"):
return
import argparse
parser = argparse.ArgumentParser(
argument_default=argparse.SUPPRESS, allow_abbrev=False
)
cls.add_args(parser)
# copied from argparse.py:
defaults = argparse.Namespace()
for action in parser._actions:
if action.dest is not argparse.SUPPRESS:
if not hasattr(defaults, action.dest):
if action.default is not argparse.SUPPRESS:
setattr(defaults, action.dest, action.default)
for key, default_value in vars(defaults).items():
if not hasattr(args, key):
setattr(args, key, default_value)
def _override_attr(
    sub_node: str, data_class: Type[FairseqDataclass], args: Namespace
) -> List[str]:
    """Build ``"<sub_node>.<field>=<value>"`` override strings for a dataclass.

    For each public field of *data_class* the value is taken from *args* when
    it has an attribute of that name, otherwise from the field default.

    Args:
        sub_node: config path prefix for the generated keys.
        data_class: dataclass type to walk; anything that is not a
            ``FairseqDataclass`` subclass yields an empty list.
        args: flat namespace providing the values.

    Returns:
        The list of override strings.
    """
    overrides = []

    if not inspect.isclass(data_class) or not issubclass(data_class, FairseqDataclass):
        return overrides

    def get_default(f):
        # Prefer default_factory when the field defines one.
        if not isinstance(f.default_factory, _MISSING_TYPE):
            return f.default_factory()
        return f.default

    for k, v in data_class.__dataclass_fields__.items():
        if k.startswith("_"):
            # private member, skip
            continue

        val = get_default(v) if not hasattr(args, k) else getattr(args, k)

        field_type = interpret_dc_type(v.type)
        if (
            isinstance(val, str)
            and not val.startswith("${")  # not interpolation
            and field_type != str
            and (
                not inspect.isclass(field_type) or not issubclass(field_type, Enum)
            )  # not choices enum
        ):
            # upgrade old models that stored complex parameters as string
            val = ast.literal_eval(val)

        if isinstance(val, tuple):
            val = list(val)

        v_type = getattr(v.type, "__origin__", None)
        if (
            (v_type is List or v_type is list or v_type is Optional)
            # skip interpolation
            and not (isinstance(val, str) and val.startswith("${"))
        ):
            # if type is int but val is float, then we will crash later - try to convert here
            if hasattr(v.type, "__args__"):
                t_args = v.type.__args__
                if len(t_args) == 1 and (t_args[0] is float or t_args[0] is int):
                    val = list(map(t_args[0], val))
        elif val is not None and (
            field_type is int or field_type is bool or field_type is float
        ):
            try:
                val = field_type(val)
            except Exception:
                # Best-effort conversion: ignore errors here, they are often
                # from interpolation args. ``Exception`` (not a bare except)
                # so KeyboardInterrupt/SystemExit are not swallowed.
                pass

        if val is None:
            overrides.append("{}.{}=null".format(sub_node, k))
        elif val == "":
            overrides.append("{}.{}=''".format(sub_node, k))
        elif isinstance(val, str):
            # Escape single quotes because the value is wrapped in them.
            val = val.replace("'", r"\'")
            overrides.append("{}.{}='{}'".format(sub_node, k, val))
        elif isinstance(val, FairseqDataclass):
            # Nested dataclass: recurse with an extended path.
            overrides += _override_attr(f"{sub_node}.{k}", type(val), args)
        elif isinstance(val, Namespace):
            sub_overrides, _ = override_module_args(val)
            for so in sub_overrides:
                overrides.append(f"{sub_node}.{k}.{so}")
        else:
            overrides.append("{}.{}={}".format(sub_node, k, val))

    return overrides
def migrate_registry(
    name, value, registry, args, overrides, deletes, use_name_as_val=False
):
    """Record how the registry-backed option *name* should be migrated.

    If *value* is in *registry*, ``name=value`` and ``name._name=value`` are
    appended to *overrides*, followed by the overrides derived from the
    registered dataclass. Otherwise, when *use_name_as_val* is set and *value*
    is not None, only ``name=value`` is appended. In every other case *name*
    is appended to *deletes*. Both lists are modified in place.
    """
    if value in registry:
        overrides.append(f"{name}={value}")
        overrides.append(f"{name}._name={value}")
        overrides.extend(_override_attr(name, registry[value], args))
        return

    if use_name_as_val and value is not None:
        overrides.append(f"{name}={value}")
        return

    deletes.append(name)
def override_module_args(args: Namespace) -> Tuple[List[str], List[str]]:
    """use the field in args to overrides those in cfg

    Returns:
        ``(overrides, deletes)``: the override strings built from *args*, and
        the names of top-level config entries for which no registered
        dataclass (or no value) was found.
    """
    overrides = []
    deletes = []

    # Overrides for every FairseqConfig field whose type is a FairseqDataclass
    # (_override_attr returns nothing for the other field types).
    for k in FairseqConfig.__dataclass_fields__.keys():
        overrides.extend(
            _override_attr(k, FairseqConfig.__dataclass_fields__[k].type, args)
        )

    if args is not None:
        if hasattr(args, "task"):
            from fairseq.tasks import TASK_DATACLASS_REGISTRY

            migrate_registry(
                "task", args.task, TASK_DATACLASS_REGISTRY, args, overrides, deletes
            )
        else:
            deletes.append("task")

        # these options will be set to "None" if they have not yet been migrated
        # so we can populate them with the entire flat args
        CORE_REGISTRIES = {"criterion", "optimizer", "lr_scheduler"}

        from fairseq.registry import REGISTRIES

        for k, v in REGISTRIES.items():
            if hasattr(args, k):
                migrate_registry(
                    k,
                    getattr(args, k),
                    v["dataclass_registry"],
                    args,
                    overrides,
                    deletes,
                    # Non-core registries keep the plain name when no
                    # dataclass is registered for it.
                    use_name_as_val=k not in CORE_REGISTRIES,
                )
            else:
                deletes.append(k)

        # no_dc stays True unless the architecture's model class exposes a
        # dataclass through its "__dataclass" attribute.
        no_dc = True
        if hasattr(args, "arch"):
            from fairseq.models import ARCH_MODEL_REGISTRY, ARCH_MODEL_NAME_REGISTRY

            if args.arch in ARCH_MODEL_REGISTRY:
                m_cls = ARCH_MODEL_REGISTRY[args.arch]
                dc = getattr(m_cls, "__dataclass", None)
                if dc is not None:
                    m_name = ARCH_MODEL_NAME_REGISTRY[args.arch]
                    overrides.append("model={}".format(m_name))
                    overrides.append("model._name={}".format(args.arch))
                    # override model params with those exist in args
                    overrides.extend(_override_attr("model", dc, args))
                    no_dc = False
        if no_dc:
            deletes.append("model")

    return overrides, deletes
class omegaconf_no_object_check:
    """Context manager that replaces ``_utils.is_primitive_type``.

    While the context is active, the function accepts every value; on exit the
    function captured at construction time is put back.
    """

    def __init__(self):
        # Remember the current implementation so __exit__ can restore it.
        self.old_is_primitive = _utils.is_primitive_type

    def __enter__(self):
        def _accept_everything(_):
            return True

        _utils.is_primitive_type = _accept_everything

    def __exit__(self, type, value, traceback):
        _utils.is_primitive_type = self.old_is_primitive
def convert_namespace_to_omegaconf(args: Namespace) -> DictConfig:
    """Convert a flat argparse.Namespace to a structured DictConfig.

    Composes the "config" hydra config with overrides derived from *args*,
    sets the entries listed for deletion to None, and then, for task, model,
    optimizer, lr_scheduler and criterion entries that are still None while
    *args* names one, stores a copy of the whole namespace instead (with
    legacy defaults filled in). The returned config has struct mode enabled.
    """

    # Here we are using field values provided in args to override counterparts inside config object
    overrides, deletes = override_module_args(args)

    # configs will be in fairseq/config after installation
    config_path = os.path.join("..", "config")

    # Reset any previously initialized global hydra state before initialize().
    GlobalHydra.instance().clear()

    with initialize(config_path=config_path):
        try:
            composed_cfg = compose("config", overrides=overrides, strict=False)
        except:
            # Log the overrides for debugging, then re-raise unchanged.
            logger.error("Error when composing. Overrides: " + str(overrides))
            raise

        for k in deletes:
            composed_cfg[k] = None

    # Round-trip through a plain container with interpolations resolved and
    # enums converted to strings.
    cfg = OmegaConf.create(
        OmegaConf.to_container(composed_cfg, resolve=True, enum_to_str=True)
    )

    # hack to be able to set Namespace in dict config. this should be removed when we update to newer
    # omegaconf version that supports object flags, or when we migrate all existing models
    # NOTE(review): this local import is not referenced again in this function.
    from omegaconf import _utils

    with omegaconf_no_object_check():
        if cfg.task is None and getattr(args, "task", None):
            cfg.task = Namespace(**vars(args))
            from fairseq.tasks import TASK_REGISTRY

            _set_legacy_defaults(cfg.task, TASK_REGISTRY[args.task])
            cfg.task._name = args.task
        if cfg.model is None and getattr(args, "arch", None):
            cfg.model = Namespace(**vars(args))
            from fairseq.models import ARCH_MODEL_REGISTRY

            _set_legacy_defaults(cfg.model, ARCH_MODEL_REGISTRY[args.arch])
            cfg.model._name = args.arch
        if cfg.optimizer is None and getattr(args, "optimizer", None):
            cfg.optimizer = Namespace(**vars(args))
            from fairseq.optim import OPTIMIZER_REGISTRY

            _set_legacy_defaults(cfg.optimizer, OPTIMIZER_REGISTRY[args.optimizer])
            cfg.optimizer._name = args.optimizer
        if cfg.lr_scheduler is None and getattr(args, "lr_scheduler", None):
            cfg.lr_scheduler = Namespace(**vars(args))
            from fairseq.optim.lr_scheduler import LR_SCHEDULER_REGISTRY

            _set_legacy_defaults(
                cfg.lr_scheduler, LR_SCHEDULER_REGISTRY[args.lr_scheduler]
            )
            cfg.lr_scheduler._name = args.lr_scheduler
        if cfg.criterion is None and getattr(args, "criterion", None):
            cfg.criterion = Namespace(**vars(args))
            from fairseq.criterions import CRITERION_REGISTRY

            _set_legacy_defaults(cfg.criterion, CRITERION_REGISTRY[args.criterion])
            cfg.criterion._name = args.criterion

    OmegaConf.set_struct(cfg, True)
    return cfg
def overwrite_args_by_name(cfg: DictConfig, overrides: Dict[str, Any]):
    """Apply *overrides* to *cfg* in place, matching entries by key name.

    For each key of *cfg*:

    * a ``DictConfig`` child takes the nested dict ``overrides[k]`` when there
      is one (recursing into nested dicts whose target is not None), otherwise
      the same *overrides* are applied to the child recursively;
    * a ``Namespace`` child receives every override as an attribute;
    * any other entry whose key is in *overrides* is replaced, using the
      registered dataclass from ``REGISTRIES`` when the new value names one.
    """
    # this will be deprecated when we get rid of argparse and model_overrides logic

    from fairseq.registry import REGISTRIES

    with open_dict(cfg):
        for k in cfg.keys():
            # "k in cfg" will return false if its a "mandatory value (e.g. ???)"
            if k in cfg and isinstance(cfg[k], DictConfig):
                if k in overrides and isinstance(overrides[k], dict):
                    for ok, ov in overrides[k].items():
                        if isinstance(ov, dict) and cfg[k][ok] is not None:
                            overwrite_args_by_name(cfg[k][ok], ov)
                        else:
                            cfg[k][ok] = ov
                else:
                    overwrite_args_by_name(cfg[k], overrides)
            elif k in cfg and isinstance(cfg[k], Namespace):
                for override_key, val in overrides.items():
                    setattr(cfg[k], override_key, val)
            elif k in overrides:
                if (
                    k in REGISTRIES
                    and overrides[k] in REGISTRIES[k]["dataclass_registry"]
                ):
                    # Rebuild the entry from the registered dataclass, apply
                    # the overrides to it, then record which one was chosen.
                    cfg[k] = DictConfig(
                        REGISTRIES[k]["dataclass_registry"][overrides[k]]
                    )
                    overwrite_args_by_name(cfg[k], overrides)
                    cfg[k]._name = overrides[k]
                else:
                    cfg[k] = overrides[k]
def merge_with_parent(dc: FairseqDataclass, cfg: DictConfig, remove_missing=False):
    """Merge *cfg* onto *dc* and return the merged config in struct mode.

    When *remove_missing* is true, keys of *cfg* that *dc* does not define are
    first deleted from *cfg* in place. The merged config is given *cfg*'s
    ``_parent`` entry.
    """
    if remove_missing:
        known_keys = set(
            dc.__dataclass_fields__.keys() if is_dataclass(dc) else dc.keys()
        )
        with open_dict(cfg):
            unknown_keys = [key for key in list(cfg.keys()) if key not in known_keys]
            for key in unknown_keys:
                del cfg[key]

    merged = OmegaConf.merge(dc, cfg)
    merged.__dict__["_parent"] = cfg.__dict__["_parent"]
    OmegaConf.set_struct(merged, True)
    return merged
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .distributed_timeout_wrapper import DistributedTimeoutWrapper
from .fully_sharded_data_parallel import (
fsdp_enable_wrap,
fsdp_wrap,
FullyShardedDataParallel,
)
from .legacy_distributed_data_parallel import LegacyDistributedDataParallel
from .module_proxy_wrapper import ModuleProxyWrapper
from .tpu_distributed_data_parallel import TPUDistributedDataParallel
# Public API of this package: exactly the names imported above.
__all__ = [
    "DistributedTimeoutWrapper",
    "fsdp_enable_wrap",
    "fsdp_wrap",
    "FullyShardedDataParallel",
    "LegacyDistributedDataParallel",
    "ModuleProxyWrapper",
    "TPUDistributedDataParallel",
]
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
import signal
import threading
from torch import nn
logger = logging.getLogger(__name__)
class DistributedTimeoutWrapper(nn.Module):
    """
    A wrapper that kills the process if no progress is made within a given
    *timeout*. The timer is reset every time :func:`forward` is called.

    Usage::

        module = DistributedTimeoutWrapper(module, timeout=30)
        x = module(input)
        time.sleep(20)  # safe
        x = module(input)
        time.sleep(45)  # job will be killed before this returns

    Args:
        module (nn.Module): module to wrap
        timeout (int): number of seconds before killing the process
            (set to a value <= 0 to disable the timeout)
        signal (Optional): signal to send once timeout is triggered
    """

    def __init__(self, module: nn.Module, timeout: int, signal=signal.SIGINT):
        super().__init__()
        self.module = module
        self.timeout = timeout
        self.signal = signal

        if timeout > 0:
            # Initialise the flag before the watchdog thread starts so the
            # thread can never observe it as missing.
            self._terminated = False
            self._heartbeat = threading.Event()
            self._heartbeat_thread = threading.Thread(
                target=self._check_heartbeat,
                args=(os.getpid(),),
                daemon=True,
            )
            self._heartbeat_thread.start()
        else:
            self._heartbeat = None
            self._heartbeat_thread = None

    def __del__(self):
        self.stop_timeout()

    def __getattr__(self, name):
        """Forward missing attributes to wrapped module."""
        try:
            return super().__getattr__(name)  # defer to nn.Module's logic
        except AttributeError:
            return getattr(self.module, name)

    def stop_timeout(self):
        """Stop the watchdog thread (no-op when the timeout is disabled)."""
        if self._heartbeat_thread is not None:
            self._terminated = True
            # Wake the watchdog so it sees the flag immediately. Without this,
            # join() blocks until the current wait expires -- indefinitely if
            # forward() was never called.
            self._heartbeat.set()
            self._heartbeat_thread.join()

    def state_dict(self, *args, **kwargs):
        """Return the wrapped module's state dict."""
        return self.module.state_dict(*args, **kwargs)

    def load_state_dict(self, *args, **kwargs):
        """Load a state dict into the wrapped module."""
        return self.module.load_state_dict(*args, **kwargs)

    def forward(self, *args, **kwargs):
        """Signal a heartbeat (if enabled) and call the wrapped module."""
        if self._heartbeat is not None:
            self._heartbeat.set()
        return self.module(*args, **kwargs)

    def _check_heartbeat(self, parent_pid):
        """Watchdog loop: signal *parent_pid* if no heartbeat arrives in time."""
        self._heartbeat.wait()  # wait for the first forward pass
        while not self._terminated:
            self._heartbeat.clear()
            if self._terminated:
                # stop_timeout() may have set the event just before clear();
                # re-check so that wake-up is not lost.
                break
            success = self._heartbeat.wait(timeout=self.timeout)
            if self._terminated:
                break
            elif not success:
                logger.error(
                    (
                        "Killing job for not making progress in {} seconds. "
                        "Set --heartbeat-timeout=-1 to disable this timeout."
                    ).format(int(self.timeout))
                )
                os.kill(parent_pid, self.signal)
                return
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import contextlib
from typing import Optional
import torch
from fairseq.dataclass.configs import DistributedTrainingConfig
from fairseq.distributed import utils as dist_utils
try:
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
has_FSDP = True
except ImportError:
FSDP = torch.nn.Module
has_FSDP = False
class FullyShardedDataParallel(FSDP):
    """
    A small wrapper around fairscale's FullyShardedDataParallel (FSDP) with some
    fairseq-specific checkpoint saving/loading logic.

    Args:
        use_sharded_state (bool): if True, then ``state_dict`` will return
            ``FSDP.local_state_dict`` and ``load_state_dict`` will call
            ``FSDP.load_local_state_dict``. Otherwise, ``state_dict`` will
            return the full model weights on data parallel rank 0 (empty on
            other ranks) and ``load_state_dict`` will broadcast model weights
            from rank 0 to other ranks.
    """

    def __init__(self, *args, use_sharded_state: bool = False, **kwargs):
        if not has_FSDP:
            raise ImportError(
                "Cannot find FullyShardedDataParallel. "
                "Please install fairscale with: pip install fairscale"
            )
        super().__init__(*args, **kwargs)
        self.use_sharded_state = use_sharded_state

    @property
    def unwrapped_module(self) -> torch.nn.Module:
        """The wrapped module, one level deeper when parameters are flattened."""
        return self.module.module if self.flatten_parameters else self.module

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        """Return the local shard, or the full state dict on rank 0 only."""
        if self.use_sharded_state:
            return super().local_state_dict(
                destination=destination, prefix=prefix, keep_vars=keep_vars
            )

        if self.rank == 0:
            return super().state_dict(
                destination=destination, prefix=prefix, keep_vars=keep_vars
            )

        # We must call state_dict() due to use of communication
        # primitives. But we don't use the result.
        super().state_dict()
        return destination or {}

    def load_state_dict(self, state_dict, strict=True, model_cfg=None):
        """Load a local shard, or broadcast rank 0's state dict and load it."""
        if self.use_sharded_state:
            return super().load_local_state_dict(state_dict, strict=strict)

        state_dict = dist_utils.broadcast_object(
            state_dict, src_rank=0, group=self.process_group
        )
        return super().load_state_dict(state_dict, strict=strict)
@contextlib.contextmanager
def fsdp_enable_wrap(cfg: DistributedTrainingConfig):
    """Context manager that activates fairscale's ``enable_wrap`` with
    :class:`FullyShardedDataParallel` as the wrapper class and keyword
    arguments derived from *cfg*.

    Raises:
        ImportError: if fairscale cannot be imported.
    """
    try:
        from fairscale.nn import enable_wrap
    except ImportError:
        raise ImportError(
            "Cannot find FullyShardedDataParallel. "
            "Please install fairscale with: pip install fairscale"
        )

    if cfg.memory_efficient_fp16:
        assert cfg.fp16  # memory_efficient_fp16 should imply fp16

    process_group = dist_utils.get_data_parallel_group()
    if process_group is None and cfg.distributed_world_size == 1:
        # No data parallel group and a world size of 1: substitute
        # fairscale's dummy group.
        from fairscale.utils.testing import DummyProcessGroup

        process_group = DummyProcessGroup(rank=0, size=1)

    wrap_kwargs = dict(
        process_group=process_group,
        reshard_after_forward=not cfg.no_reshard_after_forward,
        mixed_precision=cfg.fp16 and not cfg.memory_efficient_fp16,
        fp32_reduce_scatter=cfg.fp32_reduce_scatter,
        flatten_parameters=not cfg.not_fsdp_flatten_parameters,
        cpu_offload=cfg.cpu_offload,
        compute_dtype=torch.float16 if cfg.fp16 else torch.float32,
        bucket_cap_mb=cfg.bucket_cap_mb,
        state_dict_device=torch.device("cpu"),  # reduce GPU mem usage
    )
    with enable_wrap(
        wrapper_cls=FullyShardedDataParallel,
        use_sharded_state=cfg.use_sharded_state,
        **wrap_kwargs,
    ):
        yield
def fsdp_wrap(module, min_num_params: Optional[int] = None, **kwargs):
    """
    Wrap *module* with fairscale's ``wrap`` when possible; return it
    untouched if fairscale cannot be imported.

    Args:
        module (nn.Module): module to (maybe) wrap
        min_num_params (int, Optional): if given, modules with fewer total
            parameters than this are returned unwrapped
    """
    try:
        from fairscale.nn import wrap

        if min_num_params is not None:
            total_params = sum(p.numel() for p in module.parameters())
            if total_params < min_num_params:
                return module
        return wrap(module, **kwargs)
    except ImportError:
        return module
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
A modified version of the legacy DistributedDataParallel module that uses c10d
communication primitives. This version is simpler than the latest PyTorch
version and is useful for debugging. Notably it does not overlap gradient
communication with the backward pass, which makes it slower but more robust
than the PyTorch version.
This version also supports the *no_sync* context manager, which allows faster
training with `--update-freq`.
"""
from collections import OrderedDict
from contextlib import contextmanager
import torch
from torch import nn
from fairseq.distributed import utils
class LegacyDistributedDataParallel(nn.Module):
    """Implements distributed data parallelism at the module level.

    A simplified version of :class:`torch.nn.parallel.DistributedDataParallel`.
    This version uses a c10d process group for communication and does not
    broadcast buffers.

    Args:
        module (~torch.nn.Module): module to be parallelized
        process_group: the c10d process group to be used for distributed data
            parallel all-reduction.
        buffer_size (int, optional): number of elements to buffer before
            performing all-reduce (default: 256M).
    """

    def __init__(self, module, process_group, buffer_size=2**28):
        super().__init__()

        self.module = module
        self.process_group = process_group
        self.world_size = utils.get_world_size(self.process_group)

        # Never use a bigger buffer than the number of model params
        self.buffer_size = min(buffer_size, sum(p.numel() for p in module.parameters()))
        # The flat communication buffer is allocated lazily on the first
        # call to all_reduce_grads().
        self.buffer = None

        # We can also forcibly accumulate grads locally and only do the
        # all-reduce at some later time
        self.accumulate_grads = False

        # make per-device lists of parameters
        paramlists = OrderedDict()
        for param in self.module.parameters():
            paramlists.setdefault(param.device, []).append(param)
        self.per_device_params = list(paramlists.values())

    @contextmanager
    def no_sync(self):
        """A context manager to disable gradient synchronization.

        The previous value of ``accumulate_grads`` is restored on exit, even
        when the body of the ``with`` block raises.
        """
        old_accumulate_grads = self.accumulate_grads
        self.accumulate_grads = True
        try:
            yield
        finally:
            # Without the finally, an exception inside the block would leave
            # gradient synchronization disabled afterwards.
            self.accumulate_grads = old_accumulate_grads

    def forward(self, *inputs, **kwargs):
        """Run the wrapped module on the given inputs."""
        return self.module(*inputs, **kwargs)

    def all_reduce_grads(self):
        """
        This function must be called explicitly after backward to reduce
        gradients. There is no automatic hook like c10d.

        Does nothing while ``accumulate_grads`` is True (see :meth:`no_sync`).

        Raises:
            RuntimeError: if a parameter's gradient itself requires grad.
        """

        def all_reduce_params(params):
            # Pack the grads of *params* into one flat buffer, all-reduce it,
            # then scatter the reduced values back into each ``p.grad``.
            buffer = self.buffer
            nonzero_buffer = False
            if len(params) > 1:
                offset = 0
                for p in params:
                    sz = p.numel()
                    if p.grad is not None:
                        buffer[offset : offset + sz].copy_(p.grad.data.view(-1))
                        nonzero_buffer = True
                    else:
                        buffer[offset : offset + sz].zero_()
                    offset += sz
            else:
                # we only have a single grad to all-reduce
                p = params[0]
                if p.grad is not None:
                    buffer = p.grad.data
                    nonzero_buffer = True
                elif p.numel() <= self.buffer.numel():
                    buffer = buffer[: p.numel()]
                    buffer.zero_()
                else:
                    buffer = torch.zeros_like(p)

            # Pre-divide by the world size before the all-reduce; skipped
            # when the buffer holds only zeros.
            if nonzero_buffer:
                buffer.div_(self.world_size)

            utils.all_reduce(buffer, self.process_group)

            # copy all-reduced grads back into their original place
            offset = 0
            for p in params:
                sz = p.numel()
                if p.grad is not None:
                    p.grad.data.copy_(buffer[offset : offset + sz].view_as(p))
                else:
                    p.grad = buffer[offset : offset + sz].view_as(p).clone()
                offset += sz

        def reduction_fn():
            # This function only needs to be called once
            if self.accumulate_grads:
                return

            if self.buffer is None:
                self.buffer = next(self.module.parameters()).new(self.buffer_size)

            for params in self.per_device_params:
                # All-reduce the gradients in buckets
                offset = 0
                buffered_params = []
                for param in params:
                    if not param.requires_grad:
                        continue
                    if param.grad is None:
                        param.grad = torch.zeros_like(param)

                    if hasattr(param, "expert"):
                        # Skip gradient sync for unshared parameters
                        continue

                    if param.grad.requires_grad:
                        raise RuntimeError(
                            "DistributedDataParallel only works "
                            "with gradients that don't require "
                            "grad"
                        )
                    sz = param.numel()
                    if sz > self.buffer.numel():
                        # all-reduce big params directly
                        all_reduce_params([param])
                    else:
                        if offset + sz > self.buffer.numel():
                            # The bucket is full: flush it before adding
                            # this parameter.
                            all_reduce_params(buffered_params)
                            offset = 0
                            buffered_params.clear()
                        buffered_params.append(param)
                        offset += sz

                if len(buffered_params) > 0:
                    all_reduce_params(buffered_params)

        reduction_fn()
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from torch import nn
class ModuleProxyWrapper(nn.Module):
    """
    Proxy around a module that itself wraps another module (e.g. a
    DistributedDataParallel instance). Attribute lookups that fail on the
    proxy are retried on the wrapped module and then on the module it wraps
    (the twice-wrapped module). :func:`state_dict` and
    :func:`load_state_dict` go straight to the twice-wrapped module.

    Usage::

        module.xyz = "hello world"
        wrapped_module = DistributedDataParallel(module, **ddp_args)
        wrapped_module = ModuleProxyWrapper(wrapped_module)
        assert wrapped_module.xyz == "hello world"
        assert wrapped_module.state_dict().keys() == module.state_dict().keys()

    Args:
        module (nn.Module): module to wrap
    """

    def __init__(self, module: nn.Module):
        super().__init__()
        assert hasattr(
            module, "module"
        ), "ModuleProxyWrapper expects input to wrap another module"
        self.module = module

    def __getattr__(self, name):
        """Resolve *name* on the proxy, then the wrapped module, then the
        twice-wrapped module."""
        try:
            # nn.Module's own lookup comes first
            return super().__getattr__(name)
        except AttributeError:
            pass
        try:
            # then the once-wrapped module
            return getattr(self.module, name)
        except AttributeError:
            pass
        # finally the twice-wrapped module; a miss here propagates
        return getattr(self.module.module, name)

    def state_dict(self, *args, **kwargs):
        """Return the twice-wrapped module's state dict."""
        innermost = self.module.module
        return innermost.state_dict(*args, **kwargs)

    def load_state_dict(self, *args, **kwargs):
        """Load a state dict into the twice-wrapped module."""
        innermost = self.module.module
        return innermost.load_state_dict(*args, **kwargs)

    def forward(self, *args, **kwargs):
        """Call the wrapped module with the given arguments."""
        return self.module(*args, **kwargs)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from torch import nn
from fairseq.distributed import utils
class TPUDistributedDataParallel(nn.Module):
    """Module wrapper whose gradients are averaged across workers by an
    explicit call to :meth:`all_reduce_grads`, using ``torch_xla``."""

    def __init__(self, module, process_group):
        super().__init__()
        self.module = module
        self.process_group = process_group
        self.world_size = utils.get_world_size(self.process_group)

    def forward(self, *inputs, **kwargs):
        """Call the wrapped module with the given arguments."""
        return self.module(*inputs, **kwargs)

    def all_reduce_grads(self):
        """Sum every trainable parameter's gradient across the group and
        scale the result by ``1 / world_size``.

        Trainable parameters without a gradient get a zero gradient first.

        Raises:
            RuntimeError: if a gradient itself requires grad.
        """
        grads = []
        for param in self.parameters():
            if param.requires_grad:
                if param.grad is None:
                    param.grad = torch.zeros_like(param)
                if param.grad.requires_grad:
                    raise RuntimeError(
                        "TPUDistributedDataParallel only works with gradients that don't "
                        "require grad"
                    )
                grads.append(param.grad)

        import torch_xla.core.xla_model as xm

        xm.all_reduce(
            "sum",
            grads,
            scale=1.0 / self.world_size,
            groups=self.process_group[1],
        )
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import io
import logging
import os
import pickle
import random
import socket
import struct
import subprocess
import warnings
from argparse import Namespace
from collections import OrderedDict
from dataclasses import dataclass
from typing import Any, Dict, List, Mapping, Optional
import torch
import torch.distributed as dist
from fairseq.dataclass.configs import DistributedTrainingConfig, FairseqConfig
from omegaconf import open_dict
try:
import torch_xla.core.xla_model as xm
except ImportError:
xm = None
# Flag to indicate if we're using Megatron
# NOTE: this is a temporary hack until we move away from Megatron's model parallel init
_USE_MEGATRON = False
# Whether to use XLA ops (e.g., on TPUs) instead of CUDA ops.
_USE_XLA = False
logger = logging.getLogger(__name__)
def is_master(cfg: DistributedTrainingConfig):
    """Return True when ``cfg.distributed_rank`` equals 0."""
    rank = cfg.distributed_rank
    return rank == 0
def infer_init_method(cfg: DistributedTrainingConfig, force_distributed=False):
    """Populate the distributed init settings on *cfg* when they are unset.

    Returns immediately if ``cfg.distributed_init_method`` is already set or
    ``cfg.tpu`` is truthy. Otherwise one strategy is chosen, in this order:
    the torch.distributed.launch environment variables, Slurm (when
    ``cfg.distributed_port > 0``), or a single-node fallback (when the world
    size exceeds 1 or *force_distributed* is set).
    """
    if cfg.distributed_init_method is not None or cfg.tpu:
        return

    num_pipelines_per_node = None
    if cfg.pipeline_model_parallel:
        num_pipeline_devices, num_pipelines_per_node = _pipeline_parallel_pre_init(cfg)

    launch_env_vars = ("MASTER_ADDR", "MASTER_PORT", "WORLD_SIZE", "RANK")
    launched_externally = all(name in os.environ for name in launch_env_vars)
    if launched_externally:
        # support torch.distributed.launch
        _infer_torch_distributed_launch_init(cfg)
    elif cfg.distributed_port > 0:
        # we can determine the init method automatically for Slurm
        _infer_slurm_init(cfg, num_pipelines_per_node)
    elif cfg.distributed_world_size > 1 or force_distributed:
        # fallback for single node with multiple GPUs
        _infer_single_node_init(cfg)

    if cfg.pipeline_model_parallel:
        _pipeline_parallel_post_init(cfg, num_pipeline_devices, num_pipelines_per_node)
        return
    if cfg.distributed_no_spawn:
        return
    with open_dict(cfg):
        cfg.distributed_num_procs = min(
            torch.cuda.device_count(), cfg.distributed_world_size
        )
def _infer_torch_distributed_launch_init(cfg: DistributedTrainingConfig):
cfg.distributed_init_method = "env://"
cfg.distributed_world_size = int(os.environ["WORLD_SIZE"])
cfg.distributed_rank = int(os.environ["RANK"])
# processes are created by torch.distributed.launch
cfg.distributed_no_spawn = True
def _infer_slurm_init(cfg: DistributedTrainingConfig, num_pipelines_per_node):
"""
Fill in distributed settings on *cfg* from Slurm environment variables.
Does nothing when neither SLURM_STEP_NODELIST nor SLURM_JOB_NODELIST is
set, or when running ``scontrol`` raises FileNotFoundError. A failing
``scontrol`` call (CalledProcessError) is re-raised.
Args:
cfg: config whose distributed_* (and possibly device_id) fields are set
num_pipelines_per_node: number of pipelines per node; only read when
``cfg.pipeline_model_parallel`` is set
"""
# Prefer the step-level node list and fall back to the job-level one.
node_list = os.environ.get("SLURM_STEP_NODELIST")
if node_list is None:
node_list = os.environ.get("SLURM_JOB_NODELIST")
if node_list is not None:
try:
# Ask Slurm to expand the node list; the first hostname becomes the
# rendezvous host for the tcp:// init method.
hostnames = subprocess.check_output(
["scontrol", "show", "hostnames", node_list]
)
cfg.distributed_init_method = "tcp://{host}:{port}".format(
host=hostnames.split()[0].decode("utf-8"),
port=cfg.distributed_port,
)
nnodes = int(os.environ.get("SLURM_NNODES"))
# Use SLURM_NTASKS_PER_NODE when set; otherwise derive it from
# SLURM_NTASKS / SLURM_NNODES.
ntasks_per_node = os.environ.get("SLURM_NTASKS_PER_NODE")
if ntasks_per_node is not None:
ntasks_per_node = int(ntasks_per_node)
else:
ntasks = int(os.environ.get("SLURM_NTASKS"))
nnodes = int(os.environ.get("SLURM_NNODES"))
assert ntasks % nnodes == 0
ntasks_per_node = int(ntasks / nnodes)
if ntasks_per_node == 1:
# One task per node: rank and world size are counted in GPUs, and
# distributed_no_spawn is left untouched.
gpus_per_node = torch.cuda.device_count()
node_id = int(os.environ.get("SLURM_NODEID"))
cfg.distributed_rank = node_id * gpus_per_node
cfg.distributed_world_size = nnodes * gpus_per_node
elif cfg.pipeline_model_parallel:
assert ntasks_per_node == num_pipelines_per_node, (
"SLURM --ntasks-per-node must match number of pipelines per "
"node (={})".format(num_pipelines_per_node)
)
cfg.distributed_no_spawn = True
# For 4-way MP on nodes with 8 GPUs, ranks will be [0, 1] on
# the first node, [2, 3] on the second node, etc. This
# matches torch.distributed.launch.
node_id = int(os.environ.get("SLURM_NODEID"))
local_id = int(os.environ.get("SLURM_LOCALID"))
cfg.distributed_rank = node_id * num_pipelines_per_node + local_id
# In the above example, device_id will always be in [0, 1],
# which also matches torch.distributed.launch.
cfg.device_id = local_id
# We also want to set distributed_world_size to be the total
# number of pipelines across all nodes.
cfg.distributed_world_size = nnodes * num_pipelines_per_node
else:
# Several tasks per node without pipeline parallelism: rank and
# device id come directly from SLURM_PROCID / SLURM_LOCALID.
assert ntasks_per_node == cfg.distributed_world_size // nnodes
cfg.distributed_no_spawn = True
cfg.distributed_rank = int(os.environ.get("SLURM_PROCID"))
cfg.device_id = int(os.environ.get("SLURM_LOCALID"))
except subprocess.CalledProcessError as e: # scontrol failed
raise e
except FileNotFoundError: # Slurm is not installed
pass
def _infer_single_node_init(cfg: DistributedTrainingConfig):
assert (
cfg.distributed_world_size <= torch.cuda.device_count()
), f"world size is {cfg.distributed_world_size} but have {torch.cuda.device_count()} available devices"
port = random.randint(10000, 20000)
cfg.distributed_init_method = "tcp://localhost:{port}".format(port=port)
def _pipeline_parallel_pre_init(cfg: DistributedTrainingConfig):
    """Validate and parse the pipeline settings on *cfg*.

    Parses ``pipeline_balance`` and the device lists into lists of ints (in
    place on *cfg*) and checks that the number of unique pipeline devices
    evenly divides the number of visible CUDA devices.

    Returns:
        tuple: ``(num_pipeline_devices, num_pipelines_per_node)``

    Raises:
        ValueError: if no balance option or no devices option is set.
    """
    from fairseq import utils

    has_balance = not (
        cfg.pipeline_balance is None
        and cfg.pipeline_encoder_balance is None
        and cfg.pipeline_decoder_balance is None
    )
    has_devices = not (
        cfg.pipeline_devices is None
        and cfg.pipeline_encoder_devices is None
        and cfg.pipeline_decoder_devices is None
    )
    if not has_balance:
        raise ValueError(
            "--pipeline-balance is currently required for pipeline model parallelism"
        )
    if not has_devices:
        raise ValueError(
            "--pipeline-devices is currently required for pipeline model parallelism"
        )

    cfg.pipeline_balance = utils.eval_str_list(cfg.pipeline_balance, type=int)
    if cfg.pipeline_devices is None:
        # Separate encoder/decoder device lists: count the unique devices
        # across both.
        cfg.pipeline_encoder_devices = utils.eval_str_list(
            cfg.pipeline_encoder_devices, type=int
        )
        cfg.pipeline_decoder_devices = utils.eval_str_list(
            cfg.pipeline_decoder_devices, type=int
        )
        unique_devices = set(
            cfg.pipeline_encoder_devices + cfg.pipeline_decoder_devices
        )
    else:
        cfg.pipeline_devices = utils.eval_str_list(cfg.pipeline_devices, type=int)
        unique_devices = set(cfg.pipeline_devices)
    num_pipeline_devices = len(unique_devices)

    gpus_per_node = torch.cuda.device_count()
    assert (
        gpus_per_node >= num_pipeline_devices
        and gpus_per_node % num_pipeline_devices == 0
    ), (
        "the number of unique device IDs in --pipeline-devices must evenly divide "
        "the number of GPUs per node (multi-node pipelining is not yet supported)"
    )
    return num_pipeline_devices, gpus_per_node // num_pipeline_devices
def _pipeline_parallel_post_init(
cfg: DistributedTrainingConfig, num_pipeline_devices, num_pipelines_per_node
):
"""
Rescale the distributed settings on *cfg* from per-GPU to per-pipeline
units, then offset ``cfg.pipeline_devices`` and select the CUDA device when
the resulting ``cfg.device_id`` is greater than 0.
Args:
cfg: config whose distributed_world_size, distributed_rank,
distributed_num_procs, device_id and pipeline_devices may be updated
num_pipeline_devices: number of devices used by one pipeline
num_pipelines_per_node: number of pipelines launched on each node
"""
if not cfg.distributed_no_spawn:
# When distributed_no_spawn is False, we expect distributed_rank and
# distributed_world_size to be based on the total number of GPUs, so
# we need to correct them to be based on the number of pipelines.
assert cfg.distributed_world_size % num_pipeline_devices == 0
cfg.distributed_world_size = cfg.distributed_world_size // num_pipeline_devices
# In the case of 4-way MP on nodes with 8 GPUs, we want
# distributed_rank to be the starting GPU index for each pipeline
# i.e., 0, 2, ...
gpus_per_node = torch.cuda.device_count()
assert cfg.distributed_rank % gpus_per_node == 0
assert cfg.distributed_rank % num_pipeline_devices == 0
# The assignments below are made under open_dict(cfg).
with open_dict(cfg):
cfg.distributed_rank = cfg.distributed_rank // num_pipeline_devices
# launch one process per pipeline
cfg.distributed_num_procs = num_pipelines_per_node
# if we have 4-way MP on a node with 8 GPUs, we want device_ids to be 0
# and 4, indicating the starting device IDs for each pipeline
cfg.device_id *= num_pipeline_devices
if cfg.device_id > 0:
# if there's multiple pipelines on a node (e.g., 4-way MP on an 8
# GPU node), we need to adjust pipeline_devices accordingly
logger.debug(
"setting CUDA device={} on rank {}".format(
cfg.device_id, cfg.distributed_rank
)
)
torch.cuda.set_device(cfg.device_id)
# Shift every pipeline device index by this pipeline's starting device.
with open_dict(cfg):
cfg.pipeline_devices = [cfg.device_id + d for d in cfg.pipeline_devices]
logger.info(
"setting pipeline_devices={} on rank {}".format(
cfg.pipeline_devices, cfg.distributed_rank
)
)
def distributed_init(cfg: FairseqConfig):
"""
Initialize the distributed backend described by *cfg* and return this
worker's rank.
Without TPU this initializes ``torch.distributed`` (warning instead if it
is already initialized); with TPU it records XLA usage in ``_USE_XLA`` and
reads the rank from ``xm``. The root logger level is then set to INFO on
the master and WARNING elsewhere. Model parallelism is initialized when
``cfg.common.model_parallel_size > 1``.
Args:
cfg: a FairseqConfig, or an argparse Namespace (converted first)
Returns:
``cfg.distributed_training.distributed_rank`` after initialization
Raises:
ImportError: if model parallelism is requested but the megatron
submodule cannot be imported
"""
# Legacy argparse namespaces are converted to the omegaconf-based config.
if isinstance(cfg, Namespace):
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
cfg = convert_namespace_to_omegaconf(cfg)
if not cfg.common.tpu:
if torch.distributed.is_available() and torch.distributed.is_initialized():
warnings.warn(
"Distributed is already initialized, cannot initialize twice!"
)
else:
logger.info(
"distributed init (rank {}): {}".format(
cfg.distributed_training.distributed_rank,
cfg.distributed_training.distributed_init_method,
)
)
dist.init_process_group(
backend=cfg.distributed_training.distributed_backend,
init_method=cfg.distributed_training.distributed_init_method,
world_size=cfg.distributed_training.distributed_world_size,
rank=cfg.distributed_training.distributed_rank,
)
logger.info(
"initialized host {} as rank {}".format(
socket.gethostname(),
cfg.distributed_training.distributed_rank,
)
)
# perform a dummy all-reduce to initialize the NCCL communicator
if torch.cuda.is_available():
dist.all_reduce(torch.zeros(1).cuda())
# Record the rank reported by torch.distributed on the config.
cfg.distributed_training.distributed_rank = torch.distributed.get_rank()
else:
# TPU path: the world size, device id and rank all come from torch_xla.
assert xm.xrt_world_size() == cfg.distributed_training.distributed_world_size
global _USE_XLA
_USE_XLA = True
cfg.distributed_training.device_id = xm.get_local_ordinal()
cfg.distributed_training.distributed_rank = xm.get_ordinal()
xm.rendezvous("distributed_init") # wait for all workers
# The master logs at INFO; all other ranks log at WARNING.
if is_master(cfg.distributed_training):
logging.getLogger().setLevel(logging.INFO)
else:
logging.getLogger().setLevel(logging.WARNING)
if cfg.common.model_parallel_size > 1:
try:
from fairseq.model_parallel.megatron.mpu import (
initialize_model_parallel,
model_parallel_cuda_manual_seed,
)
except ImportError:
raise ImportError(
"\n\nPlease install the megatron submodule:"
"\n\n git submodule update --init "
"fairseq/model_parallel/megatron"
)
global _USE_MEGATRON
_USE_MEGATRON = True
initialize_model_parallel(cfg.common.model_parallel_size)
model_parallel_cuda_manual_seed(cfg.common.seed)
# The model-parallel rank is appended to the checkpoint suffix.
model_part_number = get_model_parallel_rank()
cfg.checkpoint.checkpoint_suffix += "-model_part-{0}".format(model_part_number)
# With base_layers > 0 the checkpoint suffix is replaced (not appended to)
# by a per-rank suffix.
if hasattr(cfg, "model") and getattr(cfg.model, "base_layers", 0) > 0:
cfg.checkpoint.checkpoint_suffix = (
f"-rank-{cfg.distributed_training.distributed_rank}"
)
return cfg.distributed_training.distributed_rank
def distributed_main(i, main, cfg: FairseqConfig, kwargs):
    """Per-process entry point: set up the device and rank, initialize the
    distributed backend, run *main*, then wait at a barrier.

    Args:
        i: value stored as ``cfg.distributed_training.device_id``; also
            added to ``start_rank`` when the rank is unset
        main: callable invoked as ``main(cfg, **kwargs)``
        cfg: the run configuration
        kwargs: dict of extra arguments; ``start_rank`` and
            ``after_distributed_init_fn`` may be popped from it
    """
    cfg.distributed_training.device_id = i
    cuda_usable = torch.cuda.is_available() and not (cfg.common.cpu or cfg.common.tpu)
    if cuda_usable:
        torch.cuda.set_device(cfg.distributed_training.device_id)

    if cfg.distributed_training.distributed_rank is None:  # torch.multiprocessing.spawn
        first_rank = kwargs.pop("start_rank", 0)
        cfg.distributed_training.distributed_rank = first_rank + i

    cfg.distributed_training.distributed_rank = distributed_init(cfg)

    # Optional hook that may replace cfg once distributed init is done.
    post_init_hook = kwargs.pop("after_distributed_init_fn", None)
    if post_init_hook:
        cfg = post_init_hook(cfg)

    main(cfg, **kwargs)

    if torch.distributed.is_initialized():
        torch.distributed.barrier(get_global_group())
def call_main(cfg: FairseqConfig, main, **kwargs):
    """Run *main* under the launch mode implied by *cfg*.

    If an init method is set (after inference), either spawn one process
    per device with ``torch.multiprocessing.spawn`` or call
    :func:`distributed_main` in the current process when
    ``distributed_no_spawn`` is set. On TPU with a world size above 1,
    processes are spawned via ``torch_xla``. Otherwise *main* is called
    directly.
    """
    if cfg.distributed_training.distributed_init_method is None:
        infer_init_method(cfg.distributed_training)

    if cfg.distributed_training.distributed_init_method is not None:
        # distributed training
        if cfg.distributed_training.distributed_no_spawn:
            distributed_main(cfg.distributed_training.device_id, main, cfg, kwargs)
            return
        first_rank = cfg.distributed_training.distributed_rank
        cfg.distributed_training.distributed_rank = None  # assign automatically
        kwargs["start_rank"] = first_rank
        num_procs = min(
            torch.cuda.device_count(),
            cfg.distributed_training.distributed_world_size,
        )
        torch.multiprocessing.spawn(
            fn=distributed_main,
            args=(main, cfg, kwargs),
            nprocs=num_procs,
            join=True,
        )
        return

    if cfg.common.tpu and cfg.distributed_training.distributed_world_size > 1:
        import torch_xla.distributed.xla_multiprocessing as xmp

        torch.multiprocessing.set_sharing_strategy("file_system")
        xmp.spawn(
            fn=distributed_main,
            args=(main, cfg, kwargs),
            # tpu-comment:
            #    8 devices in one TPU VM, is the max processes to be spawned.
            #    The rest is driven by xm.distributed.xla_dist
            nprocs=min(cfg.distributed_training.distributed_world_size, 8),
        )
        return

    # single GPU main
    main(cfg, **kwargs)
def use_xla():
    """Return the current value of the module-level ``_USE_XLA`` flag."""
    return _USE_XLA
def new_groups(grouped_ranks: List[List[int]]):
    """Create process groups for *grouped_ranks* and return the caller's.

    Under XLA no groups are created; the tuple ``("tpu", grouped_ranks)``
    is returned instead.
    """
    if use_xla():
        return ("tpu", grouped_ranks)
    # Every listed group is created before the caller's own is picked.
    created = [dist.new_group(ranks) for ranks in grouped_ranks]
    return created[_find_my_group_index(grouped_ranks)]
def _find_my_group_index(grouped_ranks):
    """Return the index of the first group containing the caller's global rank.

    Args:
        grouped_ranks: iterable of rank collections

    Raises:
        RuntimeError: if the caller's global rank is in none of the groups.
    """
    my_rank = get_global_rank()
    for i, group in enumerate(grouped_ranks):
        if my_rank in group:
            return i
    # Name the rank and the groups so a misconfigured grouping is diagnosable.
    raise RuntimeError(
        f"global rank {my_rank} is not a member of any group in {grouped_ranks}"
    )
def _find_my_group(grouped_ranks):
    """Return the entry of *grouped_ranks* containing the caller's rank."""
    return grouped_ranks[_find_my_group_index(grouped_ranks)]
def get_rank(group):
    """Return the caller's rank within *group*.

    Under XLA, *group* is a ``("tpu", grouped_ranks)`` tuple and the rank is
    the position of the global rank inside the caller's group.
    """
    if not use_xla():
        return dist.get_rank(group=group)
    assert group[0] == "tpu"
    members = _find_my_group(group[1])
    return members.index(get_global_rank())
def get_world_size(group):
    """Return the number of workers in *group*; 1 when neither XLA nor
    ``torch.distributed`` is active."""
    if use_xla():
        assert group[0] == "tpu"
        members = _find_my_group(group[1])
        return len(members)
    if torch.distributed.is_initialized():
        return dist.get_world_size(group=group)
    return 1
def get_global_group():
    """Return a group spanning all workers, or None when not distributed.

    With ``torch.distributed`` the group is created once and cached as an
    attribute on this function.
    """
    if use_xla():
        all_ranks = list(range(get_global_world_size()))
        return new_groups([all_ranks])
    if not torch.distributed.is_initialized():
        return None
    if not hasattr(get_global_group, "_global_group"):
        # ideally we could use torch.distributed.group.WORLD, but it seems
        # to cause random NCCL hangs in some cases
        get_global_group._global_group = dist.new_group()
    return get_global_group._global_group
def get_global_rank():
    """Return the caller's global rank; 0 when not distributed."""
    if use_xla():
        return xm.get_ordinal()
    if torch.distributed.is_initialized():
        return torch.distributed.get_rank()
    return 0
def get_global_world_size():
    """Return the total number of workers; 1 when not distributed."""
    if use_xla():
        return xm.xrt_world_size()
    if torch.distributed.is_initialized():
        return torch.distributed.get_world_size()
    return 1
def get_data_parallel_group():
    """Return the data parallel group of the calling rank: Megatron's when
    ``_USE_MEGATRON`` is set, otherwise the global group."""
    if not _USE_MEGATRON:
        return get_global_group()
    from fairseq.model_parallel.megatron import mpu

    return mpu.get_data_parallel_group()
def get_data_parallel_rank():
    """Return the caller's rank inside the data parallel group."""
    dp_group = get_data_parallel_group()
    return get_rank(dp_group)
def get_data_parallel_world_size():
    """Return the size of the data parallel group."""
    dp_group = get_data_parallel_group()
    return get_world_size(dp_group)
def get_model_parallel_group():
    """Return Megatron's model parallel group, or None when
    ``_USE_MEGATRON`` is not set."""
    if not _USE_MEGATRON:
        return None
    from fairseq.model_parallel.megatron import mpu

    return mpu.get_model_parallel_group()
def get_model_parallel_rank():
    """Return the caller's rank inside the model parallel group."""
    mp_group = get_model_parallel_group()
    return get_rank(mp_group)
def get_model_parallel_world_size():
    """Return the size of the model parallel group."""
    mp_group = get_model_parallel_group()
    return get_world_size(mp_group)
def all_reduce(tensor, group, op="sum"):
    """All-reduce *tensor* across *group* and return the reduced tensor.

    Args:
        tensor: tensor to reduce
        group: process group, or a ``("tpu", ...)`` tuple under XLA
        op (str): ``"sum"`` or ``"max"`` on the non-XLA path; under XLA the
            string is passed through to ``xm.all_reduce``

    Raises:
        NotImplementedError: on the non-XLA path for any other *op*.
    """
    if use_xla():
        assert isinstance(group, tuple) and group[0] == "tpu"
        wrapped = [tensor]  # wrap in a list to make xm.all_reduce in-place
        return xm.all_reduce(op, wrapped, groups=group[1])[0]

    if op not in ("sum", "max"):
        raise NotImplementedError
    reduce_op = dist.ReduceOp.SUM if op == "sum" else dist.ReduceOp.MAX
    dist.all_reduce(tensor, op=reduce_op, group=group)
    return tensor
def broadcast(tensor, src, group):
    """Broadcast *tensor* from rank *src* to every rank in *group*."""
    if not use_xla():
        dist.broadcast(tensor, src=src, group=group)
        return
    # XLA doesn't support broadcast, hack it with all_reduce: every rank
    # except the source zeroes its tensor before the reduction.
    if get_rank(group) != src:
        tensor.zero_()
    all_reduce(tensor, group)
def all_to_all(tensor, group):
    """Perform an all-to-all operation on a 1D Tensor.

    The number of elements must be divisible by the group's world size.
    """
    assert tensor.dim() == 1
    split_count = get_world_size(group=group)
    assert tensor.numel() % split_count == 0
    if not use_xla():
        received = torch.zeros_like(tensor)
        dist.all_to_all_single(received, tensor, group=group)
        return received
    assert isinstance(group, tuple) and group[0] == "tpu"
    return xm.all_to_all(
        tensor,
        split_dimension=0,
        concat_dimension=0,
        split_count=split_count,
        groups=group[1],
    )
def all_gather(tensor, group, return_tensor=False):
    """Perform an all-gather operation.

    Returns:
        A list with one tensor per rank, or, when *return_tensor* is True,
        a single tensor with the per-rank tensors along a new leading
        dimension.
    """
    if use_xla():
        gathered = xm.all_gather(tensor, groups=group[1])
        world_size = get_world_size(group=group)
        gathered = gathered.view(world_size, *tensor.size())
        if return_tensor:
            return gathered
        return [gathered[i] for i in range(world_size)]

    world_size = get_world_size(group=group)
    rank = get_rank(group=group)
    # The caller's own slot reuses the input tensor; other slots are empty
    # receive buffers.
    slots = [
        tensor if i == rank else torch.empty_like(tensor) for i in range(world_size)
    ]
    dist.all_gather(slots, tensor, group=group)
    return torch.stack(slots, dim=0) if return_tensor else slots
def all_gather_list(data, group=None, max_size=16384):
"""Gathers arbitrary data from all nodes into a list.
Similar to :func:`~torch.distributed.all_gather` but for arbitrary Python
data. Note that *data* must be picklable and any CUDA tensors will be moved
to CPU and returned on CPU as well.
Args:
data (Any): data from the local worker to be gathered on other workers
group: group of the collective
max_size (int, optional): maximum size of the data to be gathered
across workers
Returns:
list with the unpickled data of every worker whose encoded size is
greater than zero, in rank order
Raises:
ValueError: if the pickled data plus the 4-byte header exceeds *max_size*
Exception: if data received from another worker cannot be unpickled
"""
from fairseq import utils
if group is None:
group = get_global_group()
rank = get_rank(group=group)
world_size = get_world_size(group=group)
# One max_size-byte slot per worker. The CUDA and pinned CPU buffers are
# cached on the function object and reallocated only when too small.
buffer_size = max_size * world_size
if (
not hasattr(all_gather_list, "_buffer")
or all_gather_list._buffer.numel() < buffer_size
):
all_gather_list._buffer = torch.cuda.ByteTensor(buffer_size)
all_gather_list._cpu_buffer = torch.ByteTensor(max_size).pin_memory()
buffer = all_gather_list._buffer
buffer.zero_()
cpu_buffer = all_gather_list._cpu_buffer
data = utils.move_to_cpu(data)
enc = pickle.dumps(data)
enc_size = len(enc)
header_size = 4 # size of header that contains the length of the encoded data
size = header_size + enc_size
if size > max_size:
raise ValueError(
"encoded data size ({}) exceeds max_size ({})".format(size, max_size)
)
# The header is the payload length as a big-endian unsigned 32-bit int.
header = struct.pack(">I", enc_size)
cpu_buffer[:size] = torch.ByteTensor(list(header + enc))
# Write only this worker's slot; the buffer was zeroed above, so the
# all-reduce (default op "sum") leaves every worker's bytes in its own slot.
start = rank * max_size
buffer[start : start + size].copy_(cpu_buffer[:size])
all_reduce(buffer, group=group)
buffer = buffer.cpu()
try:
result = []
for i in range(world_size):
out_buffer = buffer[i * max_size : (i + 1) * max_size]
(enc_size,) = struct.unpack(">I", bytes(out_buffer[:header_size].tolist()))
# Slots whose header reports a length of 0 are skipped.
if enc_size > 0:
# NOTE(review): pickle.loads runs on bytes received from other workers;
# this is only safe if every worker in the group is trusted.
result.append(
pickle.loads(
bytes(out_buffer[header_size : header_size + enc_size].tolist())
)
)
return result
except pickle.UnpicklingError:
raise Exception(
"Unable to unpickle data from other workers. all_gather_list requires all "
"workers to enter the function together, so this error usually indicates "
"that the workers have fallen out of sync somehow. Workers can fall out of "
"sync if one of them runs out of memory, or if there are other conditions "
"in your training script that can cause one worker to finish an epoch "
"while other workers are still iterating over their portions of the data. "
"Try rerunning with --ddp-backend=legacy_ddp and see if that helps."
)
def all_reduce_dict(data: Mapping[str, Any], device, group) -> Dict[str, Any]:
    """
    All-reduce every value of a flat dictionary across workers. Values that
    already live on *device*'s device type are reduced separately from all
    other values, for better performance.

    Args:
        data (Mapping[str, Any]): dictionary of data to all-reduce, but
            cannot be a nested dictionary
        device (torch.device): device for the reduction
        group: group of the collective

    Returns:
        An ``OrderedDict`` with the keys of *data* in their original order,
        mapped to the reduced values as double-precision tensors.
    """
    keys_in_order = list(data.keys())

    # Split the entries into "not on the target device type" and "already
    # there"; everything is converted to double precision.
    host_items = OrderedDict()
    device_items = OrderedDict()
    for key in keys_in_order:
        value = data[key]
        if not torch.is_tensor(value):
            host_items[key] = torch.tensor(value, dtype=torch.double)
        elif value.device.type != device.type:
            host_items[key] = value.to(dtype=torch.double)
        else:
            device_items[key] = value.to(dtype=torch.double)

    def _reduce_bucket(items: OrderedDict):
        # Flatten all values into one buffer, reduce it once, then split it
        # back into tensors shaped like the inputs.
        if not items:
            return items
        flat = torch.cat([t.view(-1) for t in items.values()]).to(device=device)
        all_reduce(flat, group=group)
        sizes = [t.numel() for t in items.values()]
        pieces = torch.split(flat.clone(), sizes)
        return OrderedDict(
            (name, piece.view_as(original))
            for (name, original), piece in zip(items.items(), pieces)
        )

    host_items = _reduce_bucket(host_items)
    device_items = _reduce_bucket(device_items)

    def _lookup(key):
        for bucket in (host_items, device_items):
            if key in bucket:
                return bucket[key]
        raise KeyError

    return OrderedDict((key, _lookup(key)) for key in keys_in_order)
def broadcast_tensors(
    tensors: Optional[List[torch.Tensor]],
    src_rank: int,
    group: object,
    dist_device: Optional[torch.device] = None,
) -> List[torch.Tensor]:
    """
    Broadcast a list of tensors from *src_rank*. Receiving ranks pass
    ``tensors=None``; they learn each tensor's size, dtype and device from
    metadata that is broadcast first.

    When *dist_device* is None, CUDA is used for the "nccl" backend and CPU
    otherwise.
    """
    if dist_device is None:
        backend = torch.distributed.get_backend(group)
        dist_device = torch.device("cuda" if backend == "nccl" else "cpu")

    # share metadata first to simplify transfer
    is_src_rank = get_rank(group) == src_rank
    local_meta = None
    if is_src_rank:
        local_meta = [
            {"size": t.size(), "dtype": t.dtype, "device": t.device} for t in tensors
        ]
    metadata = _broadcast_object_slow(local_meta, src_rank, group, dist_device)

    out_tensors = []
    for i, meta in enumerate(metadata):
        if is_src_rank:
            # The source sends a copy on dist_device and returns its original.
            tensor = tensors[i]
            broadcast(tensors[i].to(dist_device), src=src_rank, group=group)
        else:
            # Receive into a flat buffer, then restore the shape and device.
            tensor = torch.zeros(
                [meta["size"].numel()], dtype=meta["dtype"], device=dist_device
            )
            broadcast(tensor, src=src_rank, group=group)
        out_tensors.append(tensor.view(meta["size"]).to(meta["device"]))
    return out_tensors
def broadcast_object(
    obj: Any,
    src_rank: int,
    group: object,
    dist_device: Optional[torch.device] = None,
) -> Any:
    """Broadcast an arbitrary Python object from *src_rank* to other workers.

    Tensors inside *obj* are replaced by placeholders and sent separately,
    so they are not serialized with the rest of the object. When
    *dist_device* is None, CUDA is used for the "nccl" backend and CPU
    otherwise.
    """
    if dist_device is None:
        backend = torch.distributed.get_backend(group)
        dist_device = torch.device("cuda" if backend == "nccl" else "cpu")

    skeleton, tensor_list = None, None
    if get_rank(group) == src_rank:
        # split the tensors from the non-tensors so we can broadcast them
        # directly, avoiding unnecessary serialization/deserialization
        tensor_list = []
        skeleton = _split_tensors_from_obj(obj, tensor_list)

    received_obj = _broadcast_object_slow(skeleton, src_rank, group, dist_device)
    received_tensors = broadcast_tensors(tensor_list, src_rank, group, dist_device)
    return _put_tensors_in_obj(received_obj, received_tensors)
def _broadcast_object_slow(
    obj: Any,
    src_rank: int,
    group: object,
    dist_device: torch.device,
) -> Any:
    """Broadcast *obj* by serializing it with ``torch.save``.

    The source rank sends the byte length and then the bytes, and returns
    *obj* itself. Every other rank returns the object rebuilt with
    ``torch.load(..., map_location="cpu")``.
    """
    if get_rank(group) == src_rank:
        # Emit data
        stream = io.BytesIO()
        torch.save(obj, stream)
        payload = torch.ByteTensor(stream.getbuffer()).to(dist_device)
        length = torch.LongTensor([len(payload)]).to(dist_device)
        broadcast(length, src=src_rank, group=group)
        broadcast(payload, src=src_rank, group=group)
        return obj

    # Fetch from the source: the length arrives first and sizes the buffer.
    length = torch.LongTensor([0]).to(dist_device)
    broadcast(length, src=src_rank, group=group)
    payload = torch.ByteTensor(int(length.item())).to(dist_device)
    broadcast(payload, src=src_rank, group=group)
    stream = io.BytesIO(payload.cpu().numpy())
    return torch.load(stream, map_location="cpu")
@dataclass(frozen=True)
class _TensorPlaceholder:
    """Marker that stands in for a tensor removed from an object.

    Created by ``_split_tensors_from_obj`` and resolved again by
    ``_put_tensors_in_obj``.
    """

    # position of the removed tensor in the side list of tensors
    index: int
def _split_tensors_from_obj(obj: Any, tensors: List[torch.Tensor]) -> Any:
if torch.is_tensor(obj):
placeholder = _TensorPlaceholder(index=len(tensors))
tensors.append(obj)
return placeholder
elif isinstance(obj, dict):
return {k: _split_tensors_from_obj(v, tensors) for k, v in obj.items()}
elif isinstance(obj, list):
return [_split_tensors_from_obj(v, tensors) for v in obj]
elif isinstance(obj, tuple):
return tuple(_split_tensors_from_obj(v, tensors) for v in obj)
elif isinstance(obj, set):
return {_split_tensors_from_obj(v, tensors) for v in obj}
else:
return obj
def _put_tensors_in_obj(obj: Any, tensors: List[torch.Tensor]) -> Any:
    """Return a copy of ``obj`` with every placeholder replaced by its tensor.

    Inverse of ``_split_tensors_from_obj``: each ``_TensorPlaceholder`` is
    replaced by ``tensors[placeholder.index]``. Dicts (values only), lists,
    tuples and sets are traversed recursively; other values pass through.
    """
    if isinstance(obj, _TensorPlaceholder):
        return tensors[obj.index]

    if isinstance(obj, dict):
        return {
            key: _put_tensors_in_obj(value, tensors) for key, value in obj.items()
        }

    # list / tuple / set are rebuilt as the plain builtin container type
    for container in (list, tuple, set):
        if isinstance(obj, container):
            return container(_put_tensors_in_obj(item, tensors) for item in obj)

    return obj
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import typing as tp
def _safe_readline(fd) -> str:
pos = fd.tell()
while True:
try:
return fd.readline()
except UnicodeDecodeError:
pos -= 1
fd.seek(pos) # search where this character begins
def find_offsets(filename: str, num_chunks: int) -> tp.List[int]:
    """
    Given a file and a number of chunks, find the offsets in the file
    to be able to chunk around full lines.

    Returns a list of ``num_chunks + 1`` offsets: the first is 0, the last is
    the file size, and each inner one is the position right after the line
    that contains the evenly spaced split point.
    """
    with open(filename, "r", encoding="utf-8") as f:
        size = os.fstat(f.fileno()).st_size
        chunk_size = size // num_chunks
        offsets = [0] * (num_chunks + 1)
        for chunk_idx in range(1, num_chunks):
            # jump to the raw split point, then advance to the next line start
            f.seek(chunk_size * chunk_idx)
            _safe_readline(f)
            offsets[chunk_idx] = f.tell()
        offsets[-1] = size
    return offsets
class ChunkLineIterator:
    """
    Iterator to properly iterate over lines of a file chunk.
    """

    def __init__(self, fd, start_offset: int, end_offset: int):
        self._fd = fd
        self._start_offset = start_offset
        self._end_offset = end_offset

    def __iter__(self) -> tp.Iterable[str]:
        """Yield lines from ``start_offset`` until the position passes ``end_offset``."""
        fd = self._fd
        end = self._end_offset
        fd.seek(self._start_offset)
        # next(f) breaks f.tell(), hence readline() must be used
        line = _safe_readline(fd)
        while line:
            pos = fd.tell()
            # f.tell() does not always give the byte position in the file;
            # sometimes it skips to a very large number. It is unlikely that
            # a normal read goes from end bytes to end + 2**32 bytes (4 GB),
            # so positions beyond that window are not treated as "past the
            # end", which makes it unlikely that the procedure breaks because
            # of the non-deterministic behavior of f.tell().
            if end > 0 and end < pos < end + 2**32:
                break
            yield line
            line = fd.readline()
class Chunker:
    """
    Context manager to read a chunk of a file line by line.

    Entering opens ``path`` as UTF-8 text and returns a ``ChunkLineIterator``
    over ``[start_offset, end_offset]``; exiting closes the file.
    """

    def __init__(self, path: str, start_offset: int, end_offset: int):
        self.path = path
        self.start_offset = start_offset
        self.end_offset = end_offset

    def __enter__(self) -> ChunkLineIterator:
        handle = open(self.path, "r", encoding="utf-8")
        # kept on the instance so __exit__ can close it
        self.fd = handle
        return ChunkLineIterator(handle, self.start_offset, self.end_offset)

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        self.fd.close()
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
import shutil
from typing import List, Optional
# NOTE(review): this logger is keyed on __file__ rather than __name__ —
# confirm that is intended.
logger = logging.getLogger(__file__)

# Use iopath's global path manager when iopath can be imported; otherwise
# IOPathManager is None and PathManager falls back to Python builtins.
try:
    from iopath.common.file_io import g_pathmgr as IOPathManager

    try:
        # [FB only - for now] AWS PathHandler for PathManager
        from .fb_pathhandlers import S3PathHandler

        IOPathManager.register_handler(S3PathHandler())
    except KeyError:
        # presumably raised by register_handler for a duplicate handler
        logging.warning("S3PathHandler already registered.")
    except ImportError:
        # the S3 handler is optional; continue without it
        logging.debug(
            "S3PathHandler couldn't be imported. Either missing fb-only files, or boto3 module."
        )

except ImportError:
    # iopath itself is not installed
    IOPathManager = None
class PathManager:
    """
    Wrapper for insulating OSS I/O (using Python builtin operations) from
    iopath's PathManager abstraction (for transparently handling various
    internal backends).

    Every method delegates to ``IOPathManager`` when it is available and
    otherwise falls back to the equivalent ``os`` / ``shutil`` / builtin call.
    """

    @staticmethod
    def open(
        path: str,
        mode: str = "r",
        buffering: int = -1,
        encoding: Optional[str] = None,
        errors: Optional[str] = None,
        newline: Optional[str] = None,
    ):
        """Open ``path`` and return a file object (same arguments as builtin ``open``)."""
        if IOPathManager:
            return IOPathManager.open(
                path=path,
                mode=mode,
                buffering=buffering,
                encoding=encoding,
                errors=errors,
                newline=newline,
            )
        return open(
            path,
            mode=mode,
            buffering=buffering,
            encoding=encoding,
            errors=errors,
            newline=newline,
        )

    @staticmethod
    def copy(src_path: str, dst_path: str, overwrite: bool = False) -> bool:
        """Copy ``src_path`` to ``dst_path``.

        Returns:
            The result of ``IOPathManager.copy`` when iopath is available,
            otherwise ``True`` once ``shutil.copyfile`` has succeeded (a
            failure raises).

        Note:
            The builtin fallback does not consult ``overwrite``;
            ``shutil.copyfile`` always replaces an existing destination.
        """
        if IOPathManager:
            return IOPathManager.copy(
                src_path=src_path, dst_path=dst_path, overwrite=overwrite
            )
        # shutil.copyfile returns the destination path; honor the declared
        # ``bool`` return type instead of leaking that string to callers.
        shutil.copyfile(src_path, dst_path)
        return True

    @staticmethod
    def get_local_path(path: str, **kwargs) -> str:
        """Return a local path for ``path`` (``path`` itself without iopath)."""
        if IOPathManager:
            return IOPathManager.get_local_path(path, **kwargs)
        return path

    @staticmethod
    def exists(path: str) -> bool:
        """Return whether ``path`` exists."""
        if IOPathManager:
            return IOPathManager.exists(path)
        return os.path.exists(path)

    @staticmethod
    def isfile(path: str) -> bool:
        """Return whether ``path`` is a file."""
        if IOPathManager:
            return IOPathManager.isfile(path)
        return os.path.isfile(path)

    @staticmethod
    def ls(path: str) -> List[str]:
        """List the entries of directory ``path``."""
        if IOPathManager:
            return IOPathManager.ls(path)
        return os.listdir(path)

    @staticmethod
    def mkdirs(path: str) -> None:
        """Create ``path`` and missing parents; an existing directory is fine."""
        if IOPathManager:
            return IOPathManager.mkdirs(path)
        os.makedirs(path, exist_ok=True)

    @staticmethod
    def rm(path: str) -> None:
        """Remove the file at ``path``."""
        if IOPathManager:
            return IOPathManager.rm(path)
        os.remove(path)

    @staticmethod
    def chmod(path: str, mode: int) -> None:
        """Change the mode of ``path``; a no-op for paths handled by iopath."""
        if not PathManager.path_requires_pathmanager(path):
            os.chmod(path, mode)

    @staticmethod
    def register_handler(handler) -> None:
        """Register ``handler`` with iopath; a no-op when iopath is unavailable."""
        if IOPathManager:
            return IOPathManager.register_handler(handler=handler)

    @staticmethod
    def copy_from_local(
        local_path: str, dst_path: str, overwrite: bool = False, **kwargs
    ) -> None:
        """Copy the local file ``local_path`` to ``dst_path``.

        The builtin fallback ignores ``overwrite`` and ``kwargs``.
        """
        if IOPathManager:
            return IOPathManager.copy_from_local(
                local_path=local_path, dst_path=dst_path, overwrite=overwrite, **kwargs
            )
        return shutil.copyfile(local_path, dst_path)

    @staticmethod
    def path_requires_pathmanager(path: str) -> bool:
        """Do we require PathManager to access given path?"""
        if IOPathManager:
            # a path needs iopath when it starts with a registered handler prefix
            for p in IOPathManager._path_handlers.keys():
                if path.startswith(p):
                    return True
        return False

    @staticmethod
    def supports_rename(path: str) -> bool:
        """Return whether ``rename`` can be used on ``path``."""
        # PathManager doesn't yet support renames
        return not PathManager.path_requires_pathmanager(path)

    @staticmethod
    def rename(src: str, dst: str):
        """Rename ``src`` to ``dst`` with ``os.rename`` (never goes through iopath)."""
        os.rename(src, dst)

    # ioPath async PathManager methods:

    @staticmethod
    def opena(
        path: str,
        mode: str = "r",
        buffering: int = -1,
        encoding: Optional[str] = None,
        errors: Optional[str] = None,
        newline: Optional[str] = None,
    ):
        """
        Return file descriptor with asynchronous write operations.

        If no iopath manager is set yet, one is created lazily and stored in
        the module-level ``IOPathManager``.
        """
        global IOPathManager
        if not IOPathManager:
            logging.info("ioPath is initializing PathManager.")
            try:
                from iopath.common.file_io import PathManager

                IOPathManager = PathManager()
            except Exception:
                # NOTE(review): the failure is only logged; the call below
                # then fails with AttributeError on None. Kept as-is so the
                # exception type seen by callers does not change.
                logging.exception("Failed to initialize ioPath PathManager object.")
        return IOPathManager.opena(
            path=path,
            mode=mode,
            buffering=buffering,
            encoding=encoding,
            errors=errors,
            newline=newline,
        )

    @staticmethod
    def async_close() -> bool:
        """
        Wait for files to be written and clean up asynchronous PathManager.
        NOTE: `PathManager.async_close()` must be called at the end of any
        script that uses `PathManager.opena(...)`.

        Returns ``False`` when no iopath manager is available.
        """
        global IOPathManager
        if IOPathManager:
            return IOPathManager.async_close()
        return False
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Utilities for working with the local dataset cache.
This file is adapted from `AllenNLP <https://github.com/allenai/allennlp>`_.
and `huggingface <https://github.com/huggingface>`_.
"""
import fnmatch
import json
import logging
import os
import shutil
import tarfile
import tempfile
from functools import partial, wraps
from hashlib import sha256
from io import open
# Locate the torch cache directory: ask torch.hub when it can be imported,
# otherwise rebuild the path from TORCH_HOME / XDG_CACHE_HOME.
try:
    from torch.hub import _get_torch_home

    torch_cache_home = _get_torch_home()
except ImportError:
    torch_cache_home = os.path.expanduser(
        os.getenv(
            "TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch")
        )
    )
# default location of the fairseq cache inside the torch cache
default_cache_path = os.path.join(torch_cache_home, "pytorch_fairseq")

# urlparse lives in urllib.parse; the fallback import is the legacy module name
try:
    from urllib.parse import urlparse
except ImportError:
    from urlparse import urlparse

# The PYTORCH_FAIRSEQ_CACHE environment variable overrides the default cache
# path. The value is a Path when pathlib is usable, a plain string otherwise.
try:
    from pathlib import Path

    PYTORCH_FAIRSEQ_CACHE = Path(os.getenv("PYTORCH_FAIRSEQ_CACHE", default_cache_path))
except (AttributeError, ImportError):
    PYTORCH_FAIRSEQ_CACHE = os.getenv("PYTORCH_FAIRSEQ_CACHE", default_cache_path)

# file names looked up inside a model archive
CONFIG_NAME = "config.json"
WEIGHTS_NAME = "pytorch_model.bin"

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
def load_archive_file(archive_file):
    """Resolve ``archive_file`` to a local directory, extracting it if needed.

    The name is resolved through ``cached_path``. When the result is not a
    directory it is treated as a tar archive: it is extracted into a
    temporary directory and the archive file is replaced by its extracted
    top-level directory.

    Args:
        archive_file: path or URL of a directory or tar archive. The text
            after the last "." is used as the tar compression mode.

    Returns:
        The resolved local path, or ``None`` when ``cached_path`` raises
        ``EnvironmentError``.
    """
    # redirect to the cache, if necessary
    try:
        resolved_archive_file = cached_path(archive_file, cache_dir=None)
    except EnvironmentError:
        logger.info(
            "Archive name '{}' was not found in archive name list. "
            "We assumed '{}' was a path or URL but couldn't find any file "
            "associated to this path or URL.".format(
                archive_file,
                archive_file,
            )
        )
        return None

    if resolved_archive_file == archive_file:
        logger.info("loading archive file {}".format(archive_file))
    else:
        logger.info(
            "loading archive file {} from cache at {}".format(
                archive_file, resolved_archive_file
            )
        )

    # Extract archive to temp dir and replace .tar.bz2 if necessary
    if not os.path.isdir(resolved_archive_file):
        tempdir = tempfile.mkdtemp()
        try:
            logger.info(
                "extracting archive file {} to temp dir {}".format(
                    resolved_archive_file, tempdir
                )
            )
            ext = os.path.splitext(archive_file)[1][1:]
            with tarfile.open(resolved_archive_file, "r:" + ext) as archive:
                top_dir = os.path.commonprefix(archive.getnames())
                # NOTE(review): extractall trusts the member paths stored in
                # the archive; only use this on archives from trusted sources.
                archive.extractall(tempdir)
            os.remove(resolved_archive_file)
            shutil.move(os.path.join(tempdir, top_dir), resolved_archive_file)
        finally:
            # Always drop the temp dir, also when extraction or the move
            # fails, so a broken archive does not leak extracted files.
            shutil.rmtree(tempdir, ignore_errors=True)
    return resolved_archive_file
def url_to_filename(url, etag=None):
    """
    Convert `url` into a hashed filename in a repeatable way.
    If `etag` is specified, append its hash to the URL's, delimited
    by a period.

    Both hashes are SHA-256 hex digests of the UTF-8 encoded text.
    """
    parts = [sha256(url.encode("utf-8")).hexdigest()]
    if etag:
        parts.append(sha256(etag.encode("utf-8")).hexdigest())
    return ".".join(parts)
def filename_to_url(filename, cache_dir=None):
    """
    Return the url and etag (which may be ``None``) stored for `filename`.
    Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.

    The metadata is read from the JSON file ``<cache_dir>/<filename>.json``;
    ``cache_dir`` defaults to ``PYTORCH_FAIRSEQ_CACHE``.
    """
    if cache_dir is None:
        cache_dir = PYTORCH_FAIRSEQ_CACHE
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    cache_path = os.path.join(cache_dir, filename)
    meta_path = cache_path + ".json"
    # the cached file is checked first, then its metadata
    for required in (cache_path, meta_path):
        if not os.path.exists(required):
            raise EnvironmentError("file {} not found".format(required))

    with open(meta_path, encoding="utf-8") as meta_file:
        metadata = json.load(meta_file)
    return metadata["url"], metadata["etag"]
def cached_path_from_pm(url_or_filename):
    """
    Tries to cache the specified URL using PathManager class.

    Returns the local path reported by ``PathManager.get_local_path`` on
    success, or ``None`` if anything raises (including the import).
    """
    try:
        from fairseq.file_io import PathManager

        return PathManager.get_local_path(url_or_filename)
    except Exception:
        # best effort: any failure means "not available through PathManager"
        return None
def cached_path(url_or_filename, cache_dir=None):
    """
    Given something that might be a URL (or might be a local path),
    determine which. If it's a URL, download the file and cache it, and
    return the path to the cached file. If it's already a local path,
    make sure the file exists and then return the path.

    Raises:
        EnvironmentError: for a scheme-less path that does not exist.
        ValueError: for any other scheme that PathManager cannot resolve.
    """
    if cache_dir is None:
        cache_dir = PYTORCH_FAIRSEQ_CACHE
    if isinstance(url_or_filename, Path):
        url_or_filename = str(url_or_filename)
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    scheme = urlparse(url_or_filename).scheme

    if scheme in ("http", "https", "s3"):
        # URL, so get it from the cache (downloading if necessary)
        return get_from_cache(url_or_filename, cache_dir)

    if os.path.exists(url_or_filename):
        # File, and it exists.
        return url_or_filename

    if scheme == "":
        # File, but it doesn't exist.
        raise EnvironmentError("file {} not found".format(url_or_filename))

    # Some other scheme: give PathManager a chance to resolve it
    local_copy = cached_path_from_pm(url_or_filename)
    if local_copy:
        return local_copy

    # Something unknown
    raise ValueError(
        "unable to parse {} as a URL or as a local path".format(url_or_filename)
    )
def split_s3_path(url):
    """Split a full s3 path into the bucket name and path.

    Raises ``ValueError`` when the URL has no bucket (netloc) or no path.
    """
    parsed = urlparse(url)
    bucket_name, s3_path = parsed.netloc, parsed.path
    if not (bucket_name and s3_path):
        raise ValueError("bad s3 path {}".format(url))
    # Remove '/' at beginning of path.
    key = s3_path[1:] if s3_path.startswith("/") else s3_path
    return bucket_name, key
def s3_request(func):
    """
    Wrapper function for s3 requests in order to create more helpful error
    messages.

    A ``ClientError`` whose error code is 404 is turned into an
    ``EnvironmentError`` naming the URL; every other ``ClientError`` is
    re-raised unchanged.
    """

    @wraps(func)
    def wrapper(url, *args, **kwargs):
        from botocore.exceptions import ClientError

        try:
            return func(url, *args, **kwargs)
        except ClientError as exc:
            # Error codes can be non-numeric strings (e.g. "NoSuchKey");
            # compare as text so such codes re-raise the original error
            # instead of failing with ValueError inside int().
            code = exc.response.get("Error", {}).get("Code")
            if str(code) == "404":
                raise EnvironmentError("file {} not found".format(url)) from exc
            raise

    return wrapper
@s3_request
def s3_etag(url):
    """Check ETag on S3 object."""
    import boto3

    s3 = boto3.resource("s3")
    bucket, key = split_s3_path(url)
    return s3.Object(bucket, key).e_tag
@s3_request
def s3_get(url, temp_file):
    """Pull a file directly from S3, writing it into ``temp_file``."""
    import boto3

    s3 = boto3.resource("s3")
    bucket, key = split_s3_path(url)
    s3.Bucket(bucket).download_fileobj(key, temp_file)
def request_wrap_timeout(func, url):
    """Call ``func(timeout=...)`` with increasing timeouts until it succeeds.

    Timeouts of 10, 20, 40, 60 and 60 seconds are tried in turn; each
    ``requests`` timeout is logged as a warning.

    Args:
        func: callable accepting a ``timeout`` keyword.
        url: only used in log and error messages.

    Raises:
        RuntimeError: when every attempt timed out.
    """
    import requests

    timeouts = (10, 20, 40, 60, 60)
    for attempt, timeout in enumerate(timeouts):
        try:
            result = func(timeout=timeout)
        except requests.exceptions.Timeout as e:
            logger.warning(
                "Request for %s timed-out (attempt %d). Retrying with a timeout of %d secs",
                url,
                attempt,
                timeout,
                exc_info=e,
            )
        else:
            return result
    raise RuntimeError(f"Unable to fetch file {url}")
def http_get(url, temp_file):
    """Stream ``url`` into ``temp_file`` in 1024-byte chunks with a progress bar."""
    import requests
    from tqdm import tqdm

    fetch = partial(requests.get, url, stream=True)
    response = request_wrap_timeout(fetch, url)

    # the bar total is unknown when the server sends no Content-Length
    length_header = response.headers.get("Content-Length")
    total = None if length_header is None else int(length_header)

    bar = tqdm(unit="B", total=total)
    for chunk in response.iter_content(chunk_size=1024):
        if not chunk:  # filter out keep-alive new chunks
            continue
        bar.update(len(chunk))
        temp_file.write(chunk)
    bar.close()
def get_from_cache(url, cache_dir=None):
    """
    Given a URL, look for the corresponding dataset in the local cache.
    If it's not there, download it. Then return the path to the cached file.

    The cache file name is derived from the URL and, when available, its
    ETag (see ``url_to_filename``). A ``<cache file>.json`` file holding the
    url and etag is written next to each downloaded file.

    Args:
        url: "s3://" URLs are fetched with the S3 helpers, anything else
            over HTTP.
        cache_dir: cache directory; defaults to ``PYTORCH_FAIRSEQ_CACHE``.

    Returns:
        Path of the cached file.
    """
    if cache_dir is None:
        cache_dir = PYTORCH_FAIRSEQ_CACHE
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    # exist_ok avoids a crash when another process creates the directory
    # between an existence check and the makedirs call.
    os.makedirs(cache_dir, exist_ok=True)

    # Get eTag to add to filename, if it exists.
    if url.startswith("s3://"):
        etag = s3_etag(url)
    else:
        try:
            import requests

            response = request_wrap_timeout(
                partial(requests.head, url, allow_redirects=True), url
            )
            if response.status_code != 200:
                etag = None
            else:
                etag = response.headers.get("ETag")
        except RuntimeError:
            # every attempt timed out; continue without an etag
            etag = None

    filename = url_to_filename(url, etag)

    # get cache path to put the file
    cache_path = os.path.join(cache_dir, filename)

    # If we don't have a connection (etag is None) and can't identify the file
    # try to get the last downloaded one
    if not os.path.exists(cache_path) and etag is None:
        matching_files = fnmatch.filter(os.listdir(cache_dir), filename + ".*")
        matching_files = list(filter(lambda s: not s.endswith(".json"), matching_files))
        if matching_files:
            # NOTE(review): os.listdir order is arbitrary, so the last match
            # is not guaranteed to be the most recent download.
            cache_path = os.path.join(cache_dir, matching_files[-1])

    if not os.path.exists(cache_path):
        # Download to temporary file, then copy to cache dir once finished.
        # Otherwise you get corrupt cache entries if the download gets interrupted.
        with tempfile.NamedTemporaryFile() as temp_file:
            logger.info("%s not found in cache, downloading to %s", url, temp_file.name)

            # GET file object
            if url.startswith("s3://"):
                s3_get(url, temp_file)
            else:
                http_get(url, temp_file)

            # we are copying the file before closing it, so flush to avoid truncation
            temp_file.flush()
            # shutil.copyfileobj() starts at the current position, so go to the start
            temp_file.seek(0)

            logger.info("copying %s to cache at %s", temp_file.name, cache_path)
            with open(cache_path, "wb") as cache_file:
                shutil.copyfileobj(temp_file, cache_file)

            logger.info("creating metadata file for %s", cache_path)
            meta = {"url": url, "etag": etag}
            meta_path = cache_path + ".json"
            with open(meta_path, "w") as meta_file:
                output_string = json.dumps(meta)
                meta_file.write(output_string)

            logger.info("removing temp file %s", temp_file.name)

    return cache_path
def read_set_from_file(filename):
    """
    Extract a de-duped collection (set) of text from a file.
    Expected file format is one item per line.

    Trailing whitespace (including the newline) is stripped from each line.
    """
    with open(filename, "r", encoding="utf-8") as file_:
        return {line.rstrip() for line in file_}
def get_file_extension(path, dot=True, lower=True):
    """Return the extension of ``path`` as given by ``os.path.splitext``.

    Args:
        path: file path or name.
        dot: keep the leading "." when True, drop it otherwise.
        lower: lower-case the result when True.
    """
    _, ext = os.path.splitext(path)
    if not dot:
        ext = ext[1:]
    if lower:
        ext = ext.lower()
    return ext
#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import copy
import logging
import os
from typing import Any, Dict, Iterator, List
import torch
from omegaconf import open_dict
from torch import nn
from fairseq import utils
from fairseq.data import encoders
logger = logging.getLogger(__name__)
def from_pretrained(
    model_name_or_path,
    checkpoint_file="model.pt",
    data_name_or_path=".",
    archive_map=None,
    **kwargs
):
    """Load a model ensemble and its task from a model name, path or archive.

    Args:
        model_name_or_path: key of ``archive_map`` or something accepted by
            ``file_utils.load_archive_file``.
        checkpoint_file: checkpoint file name(s) inside the model directory,
            joined with ``os.pathsep``.
        data_name_or_path: data location; a value starting with "." is taken
            relative to the model directory.
        archive_map: optional mapping from names to paths, or to dicts with a
            "path" entry plus default overrides.
        **kwargs: argument overrides passed on when loading the checkpoints.

    Returns:
        dict with the keys "args", "task" and "models".
    """
    from fairseq import checkpoint_utils, file_utils

    if archive_map is not None:
        if model_name_or_path in archive_map:
            model_name_or_path = archive_map[model_name_or_path]
        if data_name_or_path is not None and data_name_or_path in archive_map:
            data_name_or_path = archive_map[data_name_or_path]

        # allow archive_map to set default arg_overrides (e.g., tokenizer, bpe)
        # for each model
        if isinstance(model_name_or_path, dict):
            for k, v in model_name_or_path.items():
                if k == "checkpoint_file":
                    checkpoint_file = v
                elif (
                    k != "path"
                    # only set kwargs that don't already have overrides
                    and k not in kwargs
                ):
                    kwargs[k] = v
            model_name_or_path = model_name_or_path["path"]

    model_path = file_utils.load_archive_file(model_name_or_path)

    # convenience hack for loading data and BPE codes from model archive
    # NOTE(review): data_name_or_path=None is tolerated above but would fail
    # on .startswith here — confirm whether None is a supported value.
    if data_name_or_path.startswith("."):
        kwargs["data"] = os.path.abspath(os.path.join(model_path, data_name_or_path))
    else:
        kwargs["data"] = file_utils.load_archive_file(data_name_or_path)
    # Point the matching override at each known BPE/tokenizer file found in
    # the model directory; a later entry wins when two map to the same arg.
    for file, arg in {
        "code": "bpe_codes",
        "bpecodes": "bpe_codes",
        "sentencepiece.bpe.model": "sentencepiece_model",
        "merges.txt": "bpe_merges",
        "vocab.json": "bpe_vocab",
    }.items():
        path = os.path.join(model_path, file)
        if os.path.exists(path):
            kwargs[arg] = path

    if "user_dir" in kwargs:
        utils.import_user_module(argparse.Namespace(user_dir=kwargs["user_dir"]))

    models, args, task = checkpoint_utils.load_model_ensemble_and_task(
        [os.path.join(model_path, cpt) for cpt in checkpoint_file.split(os.pathsep)],
        arg_overrides=kwargs,
    )

    return {
        "args": args,
        "task": task,
        "models": models,
    }
class GeneratorHubInterface(nn.Module):
    """
    PyTorch Hub interface for generating sequences from a pre-trained
    translation or language model.
    """

    def __init__(self, cfg, task, models):
        """Wrap ``models`` and ``task`` and build the tokenizer/BPE from ``cfg``."""
        super().__init__()
        self.cfg = cfg
        self.task = task
        self.models = nn.ModuleList(models)
        self.src_dict = task.source_dictionary
        self.tgt_dict = task.target_dictionary

        # optimize model for generation
        for model in self.models:
            model.prepare_for_inference_(cfg)

        # Load alignment dictionary for unknown word replacement
        # (None if no unknown word replacement, empty if no path to align dictionary)
        self.align_dict = utils.load_align_dict(cfg.generation.replace_unk)

        self.tokenizer = encoders.build_tokenizer(cfg.tokenizer)
        self.bpe = encoders.build_bpe(cfg.bpe)

        self.max_positions = utils.resolve_max_positions(
            self.task.max_positions(), *[model.max_positions() for model in models]
        )

        # this is useful for determining the device
        self.register_buffer("_float_tensor", torch.tensor([0], dtype=torch.float))

    @property
    def device(self):
        """Device of the registered ``_float_tensor`` buffer."""
        return self._float_tensor.device

    def translate(
        self, sentences: List[str], beam: int = 5, verbose: bool = False, **kwargs
    ) -> List[str]:
        """Same as :meth:`sample`, with a default beam size of 5."""
        return self.sample(sentences, beam, verbose, **kwargs)

    def sample(
        self, sentences: List[str], beam: int = 1, verbose: bool = False, **kwargs
    ) -> List[str]:
        """Generate for each sentence and return the decoded first hypothesis.

        A single string is accepted as well; a single string is then returned.
        """
        if isinstance(sentences, str):
            return self.sample([sentences], beam=beam, verbose=verbose, **kwargs)[0]
        tokenized_sentences = [self.encode(sentence) for sentence in sentences]
        batched_hypos = self.generate(tokenized_sentences, beam, verbose, **kwargs)
        return [self.decode(hypos[0]["tokens"]) for hypos in batched_hypos]

    def score(
        self, sentences: List[str], replace_newline_with_eos: bool = False, **kwargs
    ):
        """Score ``sentences`` by running generation with ``score_reference=True``.

        With ``replace_newline_with_eos`` each line of a sentence is encoded
        separately and the pieces are concatenated. Returns the first
        hypothesis per sentence (a single one for a single string input).
        """
        if isinstance(sentences, str):
            return self.score(
                [sentences], replace_newline_with_eos=replace_newline_with_eos, **kwargs
            )[0]

        def encode(sentence):
            if replace_newline_with_eos:
                return torch.cat([self.encode(line) for line in sentence.splitlines()])
            else:
                return self.encode(sentence)

        # NOTE: this doesn't support translation tasks currently
        tokenized_sentences = [encode(sentence) for sentence in sentences]
        return [
            hypos[0]
            for hypos in self.generate(
                tokenized_sentences, score_reference=True, **kwargs
            )
        ]

    def generate(
        self,
        tokenized_sentences: List[torch.LongTensor],
        beam: int = 5,
        verbose: bool = False,
        skip_invalid_size_inputs=False,
        inference_step_args=None,
        prefix_allowed_tokens_fn=None,
        **kwargs
    ) -> List[List[Dict[str, torch.Tensor]]]:
        """Run generation on already-tokenized sentences.

        Args:
            tokenized_sentences: list of token tensors, or one 1-D tensor.
            beam: beam size written into the generation config.
            verbose: log source, hypotheses, positional scores and
                (optionally) alignments.
            skip_invalid_size_inputs: passed to the batch iterator as
                ``ignore_invalid_inputs``.
            inference_step_args: extra keyword arguments for
                ``task.inference_step``.
            prefix_allowed_tokens_fn: passed to ``task.build_generator``.
            **kwargs: further attributes set on the generation config.

        Returns:
            Hypotheses per input sentence, in input order (the hypotheses of
            the single sentence when a 1-D tensor was given).
        """
        if torch.is_tensor(tokenized_sentences) and tokenized_sentences.dim() == 1:
            # Forward every option; the explicit keyword parameters are not
            # part of **kwargs and would otherwise be silently dropped.
            return self.generate(
                tokenized_sentences.unsqueeze(0),
                beam=beam,
                verbose=verbose,
                skip_invalid_size_inputs=skip_invalid_size_inputs,
                inference_step_args=inference_step_args,
                prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
                **kwargs
            )[0]

        # build generator using current args as well as any kwargs
        gen_args = copy.deepcopy(self.cfg.generation)
        with open_dict(gen_args):
            gen_args.beam = beam
            for k, v in kwargs.items():
                setattr(gen_args, k, v)
        generator = self.task.build_generator(
            self.models,
            gen_args,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
        )

        inference_step_args = inference_step_args or {}
        results = []
        for batch in self._build_batches(tokenized_sentences, skip_invalid_size_inputs):
            batch = utils.apply_to_sample(lambda t: t.to(self.device), batch)
            translations = self.task.inference_step(
                generator, self.models, batch, **inference_step_args
            )
            for sample_id, hypos in zip(batch["id"].tolist(), translations):
                results.append((sample_id, hypos))

        # sort output to match input order
        outputs = [hypos for _, hypos in sorted(results, key=lambda x: x[0])]

        if verbose:

            def getarg(name, default):
                # generation args take precedence over the top-level config
                return getattr(gen_args, name, getattr(self.cfg, name, default))

            for source_tokens, target_hypotheses in zip(tokenized_sentences, outputs):
                src_str_with_unk = self.string(source_tokens)
                logger.info("S\t{}".format(src_str_with_unk))
                for hypo in target_hypotheses:
                    hypo_str = self.decode(hypo["tokens"])
                    logger.info("H\t{}\t{}".format(hypo["score"], hypo_str))
                    logger.info(
                        "P\t{}".format(
                            " ".join(
                                map(
                                    lambda x: "{:.4f}".format(x),
                                    hypo["positional_scores"].tolist(),
                                )
                            )
                        )
                    )
                    if hypo["alignment"] is not None and getarg(
                        "print_alignment", False
                    ):
                        logger.info(
                            "A\t{}".format(
                                " ".join(
                                    [
                                        "{}-{}".format(src_idx, tgt_idx)
                                        for src_idx, tgt_idx in hypo["alignment"]
                                    ]
                                )
                            )
                        )
        return outputs

    def encode(self, sentence: str) -> torch.LongTensor:
        """Tokenize, apply BPE and binarize ``sentence``."""
        sentence = self.tokenize(sentence)
        sentence = self.apply_bpe(sentence)
        return self.binarize(sentence)

    def decode(self, tokens: torch.LongTensor) -> str:
        """Turn ``tokens`` into text, then remove BPE and detokenize."""
        sentence = self.string(tokens)
        sentence = self.remove_bpe(sentence)
        return self.detokenize(sentence)

    def tokenize(self, sentence: str) -> str:
        """Apply the tokenizer, if one is configured."""
        if self.tokenizer is not None:
            sentence = self.tokenizer.encode(sentence)
        return sentence

    def detokenize(self, sentence: str) -> str:
        """Undo tokenization, if a tokenizer is configured."""
        if self.tokenizer is not None:
            sentence = self.tokenizer.decode(sentence)
        return sentence

    def apply_bpe(self, sentence: str) -> str:
        """Apply BPE, if configured."""
        if self.bpe is not None:
            sentence = self.bpe.encode(sentence)
        return sentence

    def remove_bpe(self, sentence: str) -> str:
        """Undo BPE, if configured."""
        if self.bpe is not None:
            sentence = self.bpe.decode(sentence)
        return sentence

    def binarize(self, sentence: str) -> torch.LongTensor:
        """Encode ``sentence`` with the source dictionary without adding new symbols."""
        return self.src_dict.encode_line(sentence, add_if_not_exist=False).long()

    def string(self, tokens: torch.LongTensor) -> str:
        """Convert ``tokens`` to a string with the target dictionary."""
        return self.tgt_dict.string(tokens)

    def _build_batches(
        self, tokens: List[List[int]], skip_invalid_size_inputs: bool
    ) -> Iterator[Dict[str, Any]]:
        """Return an unshuffled batch iterator over ``tokens``."""
        # NOTE(review): items are used via .numel(), so they are expected to
        # be tensors despite the List[List[int]] annotation.
        lengths = torch.LongTensor([t.numel() for t in tokens])
        batch_iterator = self.task.get_batch_iterator(
            dataset=self.task.build_dataset_for_inference(tokens, lengths),
            max_tokens=self.cfg.dataset.max_tokens,
            max_sentences=self.cfg.dataset.batch_size,
            max_positions=self.max_positions,
            ignore_invalid_inputs=skip_invalid_size_inputs,
            disable_iterator_cache=True,
        ).next_epoch_itr(shuffle=False)
        return batch_iterator
class BPEHubInterface(object):
    """PyTorch Hub interface for Byte-Pair Encoding (BPE)."""

    def __init__(self, bpe, **kwargs):
        """Build the BPE named ``bpe``; ``kwargs`` become further namespace fields."""
        super().__init__()
        bpe_args = argparse.Namespace(bpe=bpe, **kwargs)
        self.bpe = encoders.build_bpe(bpe_args)
        assert self.bpe is not None

    def encode(self, sentence: str) -> str:
        """Apply BPE to ``sentence``."""
        encoded = self.bpe.encode(sentence)
        return encoded

    def decode(self, sentence: str) -> str:
        """Undo BPE on ``sentence``."""
        decoded = self.bpe.decode(sentence)
        return decoded
class TokenizerHubInterface(object):
    """PyTorch Hub interface for tokenization."""

    def __init__(self, tokenizer, **kwargs):
        """Build the tokenizer named ``tokenizer``; ``kwargs`` become further namespace fields."""
        super().__init__()
        tokenizer_args = argparse.Namespace(tokenizer=tokenizer, **kwargs)
        self.tokenizer = encoders.build_tokenizer(tokenizer_args)
        assert self.tokenizer is not None

    def encode(self, sentence: str) -> str:
        """Tokenize ``sentence``."""
        encoded = self.tokenizer.encode(sentence)
        return encoded

    def decode(self, sentence: str) -> str:
        """Detokenize ``sentence``."""
        decoded = self.tokenizer.decode(sentence)
        return decoded
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import uuid
from typing import Dict, Optional
from torch import Tensor
class FairseqIncrementalState(object):
    """Mixin giving an object a private namespace inside a shared state dict.

    Every instance gets a random id; keys stored in the incremental state are
    prefixed with it, so several instances can share one dict.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.init_incremental_state()

    def init_incremental_state(self):
        """Assign a fresh random id to this instance."""
        self._incremental_state_id = str(uuid.uuid4())

    def _get_full_incremental_state_key(self, key: str) -> str:
        """Return ``key`` prefixed with this instance's id."""
        prefix = self._incremental_state_id
        return "{}.{}".format(prefix, key)

    def get_incremental_state(
        self,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
        key: str,
    ) -> Optional[Dict[str, Optional[Tensor]]]:
        """Helper for getting incremental state for an nn.Module.

        Returns ``None`` when there is no state dict or no entry for ``key``.
        """
        full_key = self._get_full_incremental_state_key(key)
        if incremental_state is None:
            return None
        if full_key not in incremental_state:
            return None
        return incremental_state[full_key]

    def set_incremental_state(
        self,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
        key: str,
        value: Dict[str, Optional[Tensor]],
    ) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]:
        """Helper for setting incremental state for an nn.Module.

        Stores ``value`` in place when a state dict is given and returns the
        dict that was passed in (``None`` stays ``None``).
        """
        if incremental_state is None:
            return incremental_state
        full_key = self._get_full_incremental_state_key(key)
        incremental_state[full_key] = value
        return incremental_state
def with_incremental_state(cls):
    """Class decorator that makes ``FairseqIncrementalState`` the first base of ``cls``.

    An existing occurrence of ``FairseqIncrementalState`` among the bases is
    dropped first, so it appears exactly once. ``cls`` is modified in place
    and returned.
    """
    other_bases = tuple(b for b in cls.__bases__ if b != FairseqIncrementalState)
    cls.__bases__ = (FairseqIncrementalState,) + other_bases
    return cls
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment