Commit 7df61696 authored by Sugon_ldc's avatar Sugon_ldc
Browse files

add fairseq0.10.2

parents
Pipeline #471 failed with stages
in 0 seconds
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import bisect
import time
from collections import OrderedDict
from typing import Dict, Optional
try:
import torch
def type_as(a, b):
if torch.is_tensor(a) and torch.is_tensor(b):
return a.to(b)
else:
return a
except ImportError:
torch = None
def type_as(a, b):
return a
try:
import numpy as np
except ImportError:
np = None
class Meter(object):
"""Base class for Meters."""
def __init__(self):
pass
def state_dict(self):
return {}
def load_state_dict(self, state_dict):
pass
def reset(self):
raise NotImplementedError
@property
def smoothed_value(self) -> float:
"""Smoothed value used for logging."""
raise NotImplementedError
def safe_round(number, ndigits):
if hasattr(number, "__round__"):
return round(number, ndigits)
elif torch is not None and torch.is_tensor(number) and number.numel() == 1:
return safe_round(number.item(), ndigits)
elif np is not None and np.ndim(number) == 0 and hasattr(number, "item"):
return safe_round(number.item(), ndigits)
else:
return number
class AverageMeter(Meter):
"""Computes and stores the average and current value"""
def __init__(self, round: Optional[int] = None):
self.round = round
self.reset()
def reset(self):
self.val = None # most recent update
self.sum = 0 # sum from all updates
self.count = 0 # total n from all updates
def update(self, val, n=1):
if val is not None:
self.val = val
if n > 0:
self.sum = type_as(self.sum, val) + (val * n)
self.count = type_as(self.count, n) + n
def state_dict(self):
return {
"val": self.val,
"sum": self.sum,
"count": self.count,
"round": self.round,
}
def load_state_dict(self, state_dict):
self.val = state_dict["val"]
self.sum = state_dict["sum"]
self.count = state_dict["count"]
self.round = state_dict.get("round", None)
@property
def avg(self):
return self.sum / self.count if self.count > 0 else self.val
@property
def smoothed_value(self) -> float:
val = self.avg
if self.round is not None and val is not None:
val = safe_round(val, self.round)
return val
class TimeMeter(Meter):
"""Computes the average occurrence of some event per second"""
def __init__(
self,
init: int = 0,
n: int = 0,
round: Optional[int] = None,
):
self.round = round
self.reset(init, n)
def reset(self, init=0, n=0):
self.init = init
self.start = time.perf_counter()
self.n = n
self.i = 0
def update(self, val=1):
self.n = type_as(self.n, val) + val
self.i += 1
def state_dict(self):
return {
"init": self.elapsed_time,
"n": self.n,
"round": self.round,
}
def load_state_dict(self, state_dict):
if "start" in state_dict:
# backwards compatibility for old state_dicts
self.reset(init=state_dict["init"])
else:
self.reset(init=state_dict["init"], n=state_dict["n"])
self.round = state_dict.get("round", None)
@property
def avg(self):
return self.n / self.elapsed_time
@property
def elapsed_time(self):
return self.init + (time.perf_counter() - self.start)
@property
def smoothed_value(self) -> float:
val = self.avg
if self.round is not None and val is not None:
val = safe_round(val, self.round)
return val
class StopwatchMeter(Meter):
"""Computes the sum/avg duration of some event in seconds"""
def __init__(self, round: Optional[int] = None):
self.round = round
self.sum = 0
self.n = 0
self.start_time = None
def start(self):
self.start_time = time.perf_counter()
def stop(self, n=1, prehook=None):
if self.start_time is not None:
if prehook is not None:
prehook()
delta = time.perf_counter() - self.start_time
self.sum = self.sum + delta
self.n = type_as(self.n, n) + n
def reset(self):
self.sum = 0 # cumulative time during which stopwatch was active
self.n = 0 # total n across all start/stop
self.start()
def state_dict(self):
return {
"sum": self.sum,
"n": self.n,
"round": self.round,
}
def load_state_dict(self, state_dict):
self.sum = state_dict["sum"]
self.n = state_dict["n"]
self.start_time = None
self.round = state_dict.get("round", None)
@property
def avg(self):
return self.sum / self.n if self.n > 0 else self.sum
@property
def elapsed_time(self):
if self.start_time is None:
return 0.0
return time.perf_counter() - self.start_time
@property
def smoothed_value(self) -> float:
val = self.avg if self.sum > 0 else self.elapsed_time
if self.round is not None and val is not None:
val = safe_round(val, self.round)
return val
class MetersDict(OrderedDict):
"""A sorted dictionary of :class:`Meters`.
Meters are sorted according to a priority that is given when the
meter is first added to the dictionary.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.priorities = []
def __setitem__(self, key, value):
assert key not in self, "MetersDict doesn't support reassignment"
priority, value = value
bisect.insort(self.priorities, (priority, len(self.priorities), key))
super().__setitem__(key, value)
for _, _, key in self.priorities: # reorder dict to match priorities
self.move_to_end(key)
def add_meter(self, key, meter, priority):
self.__setitem__(key, (priority, meter))
def state_dict(self):
return [
(pri, key, self[key].__class__.__name__, self[key].state_dict())
for pri, _, key in self.priorities
# can't serialize DerivedMeter instances
if not isinstance(self[key], MetersDict._DerivedMeter)
]
def load_state_dict(self, state_dict):
self.clear()
self.priorities.clear()
for pri, key, meter_cls, meter_state in state_dict:
meter = globals()[meter_cls]()
meter.load_state_dict(meter_state)
self.add_meter(key, meter, pri)
def get_smoothed_value(self, key: str) -> float:
"""Get a single smoothed value."""
meter = self[key]
if isinstance(meter, MetersDict._DerivedMeter):
return meter.fn(self)
else:
return meter.smoothed_value
def get_smoothed_values(self) -> Dict[str, float]:
"""Get all smoothed values."""
return OrderedDict(
[
(key, self.get_smoothed_value(key))
for key in self.keys()
if not key.startswith("_")
]
)
def reset(self):
"""Reset Meter instances."""
for meter in self.values():
if isinstance(meter, MetersDict._DerivedMeter):
continue
meter.reset()
class _DerivedMeter(Meter):
"""A Meter whose values are derived from other Meters."""
def __init__(self, fn):
self.fn = fn
def reset(self):
pass
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
A standalone module for aggregating metrics.
Metrics can be logged from anywhere using the `log_*` functions defined
in this module. The logged values will be aggregated dynamically based
on the aggregation context in which the logging occurs. See the
:func:`aggregate` context manager for more details.
"""
import contextlib
import time
import uuid
from collections import OrderedDict, defaultdict
from typing import Callable, Dict, List, Optional
from .meters import *
# Aggregation contexts are considered "active" when inside the scope
# created by the :func:`aggregate` context manager.
_aggregators = OrderedDict()
_active_aggregators = OrderedDict()
_active_aggregators_cnt = defaultdict(lambda: 0)
def reset() -> None:
"""Reset all metrics aggregators."""
_aggregators.clear()
_active_aggregators.clear()
_active_aggregators_cnt.clear()
# The "default" aggregator observes all logged values.
_aggregators["default"] = MetersDict()
_active_aggregators["default"] = _aggregators["default"]
_active_aggregators_cnt["default"] = 1
reset()
@contextlib.contextmanager
def aggregate(name: Optional[str] = None, new_root: bool = False):
"""Context manager to aggregate metrics under a given name.
Aggregations can be nested. If *new_root* is ``False``, then logged
metrics will be recorded along the entire stack of nested
aggregators, including a global "default" aggregator. If *new_root*
is ``True``, then this aggregator will be the root of a new
aggregation stack, thus bypassing any parent aggregators.
Note that aggregation contexts are uniquely identified by their
*name* (e.g., train, valid). Creating a context with an existing
name will reuse the corresponding :class:`MetersDict` instance.
If no name is given, then a temporary aggregator will be created.
Usage::
with metrics.aggregate("train"):
for step, batch in enumerate(epoch):
with metrics.aggregate("train_inner") as agg:
metrics.log_scalar("loss", get_loss(batch))
if step % log_interval == 0:
print(agg.get_smoothed_value("loss"))
agg.reset()
print(metrics.get_smoothed_values("train")["loss"])
Args:
name (str): name of the aggregation. Defaults to a
random/temporary name if not given explicitly.
new_root (bool): make this aggregation the root of a new
aggregation stack.
"""
if name is None:
# generate a temporary name
name = str(uuid.uuid4())
assert name not in _aggregators
agg = MetersDict()
else:
assert name != "default"
agg = _aggregators.setdefault(name, MetersDict())
if new_root:
backup_aggregators = _active_aggregators.copy()
_active_aggregators.clear()
backup_aggregators_cnt = _active_aggregators_cnt.copy()
_active_aggregators_cnt.clear()
_active_aggregators[name] = agg
_active_aggregators_cnt[name] += 1
yield agg
_active_aggregators_cnt[name] -= 1
if _active_aggregators_cnt[name] == 0 and name in _active_aggregators:
del _active_aggregators[name]
if new_root:
_active_aggregators.clear()
_active_aggregators.update(backup_aggregators)
_active_aggregators_cnt.clear()
_active_aggregators_cnt.update(backup_aggregators_cnt)
def get_active_aggregators() -> List[MetersDict]:
return list(_active_aggregators.values())
def log_scalar(
key: str,
value: float,
weight: float = 1,
priority: int = 10,
round: Optional[int] = None,
):
"""Log a scalar value.
Args:
key (str): name of the field to log
value (float): value to log
weight (float): weight that this value contributes to the average.
A weight of 0 will always log the latest value.
priority (int): smaller values are logged earlier in the output
round (Optional[int]): number of digits to round to when displaying
"""
for agg in get_active_aggregators():
if key not in agg:
agg.add_meter(key, AverageMeter(round=round), priority)
agg[key].update(value, weight)
def log_derived(key: str, fn: Callable[[MetersDict], float], priority: int = 20):
"""Log a scalar value derived from other meters.
Args:
key (str): name of the field to log
fn (Callable[[MetersDict], float]): function that takes a single
argument *meters* and returns the derived value
priority (int): smaller values are logged earlier in the output
"""
for agg in get_active_aggregators():
if key not in agg:
agg.add_meter(key, MetersDict._DerivedMeter(fn), priority)
def log_speed(
key: str,
value: float,
priority: int = 30,
round: Optional[int] = None,
):
"""Log the rate of some quantity per second.
Args:
key (str): name of the field to log
value (float): value to log
priority (int): smaller values are logged earlier in the output
round (Optional[int]): number of digits to round to when displaying
"""
for agg in get_active_aggregators():
if key not in agg:
agg.add_meter(key, TimeMeter(round=round), priority)
agg[key].reset() # reset meter on the first call
else:
agg[key].update(value)
def log_start_time(key: str, priority: int = 40, round: Optional[int] = None):
"""Log the duration of some event in seconds.
The duration will be computed once :func:`log_stop_time` is called.
Args:
key (str): name of the field to log
priority (int): smaller values are logged earlier in the output
round (Optional[int]): number of digits to round to when displaying
"""
for agg in get_active_aggregators():
if key not in agg:
agg.add_meter(key, StopwatchMeter(round=round), priority)
agg[key].start()
def log_stop_time(key: str, weight: float = 0.0, prehook=None):
"""Log the duration of some event in seconds.
The duration will be computed since :func:`log_start_time` was called.
Set weight > 0 to report the average time instead of the sum.
Args:
key (str): name of the field to log
weight (float): weight that this time contributes to the average
prehook (function, no arguments): will be called before the timer
is stopped. For example, use prehook=torch.cuda.synchronize to
make sure all gpu operations are done before timer is stopped.
"""
for agg in get_active_aggregators():
if key in agg:
agg[key].stop(weight, prehook)
def log_custom(
new_meter_fn: Callable[[], Meter],
key: str,
*args,
priority: int = 50,
**kwargs,
):
"""Log using a custom Meter.
Any extra *args* or *kwargs* will be passed through to the Meter's
*update* method.
Args:
new_meter_fn (Callable[[], Meter]): function that returns a new
Meter instance
key (str): name of the field to log
priority (int): smaller values are logged earlier in the output
"""
for agg in get_active_aggregators():
if key not in agg:
agg.add_meter(key, new_meter_fn(), priority)
agg[key].update(*args, **kwargs)
def reset_meter(name: str, key: str) -> None:
"""Reset Meter instance aggregated under a given *name* and *key*."""
meter = get_meter(name, key)
if meter is not None:
meter.reset()
def reset_meters(name: str) -> None:
"""Reset Meter instances aggregated under a given *name*."""
meters = get_meters(name)
if meters is not None:
meters.reset()
def get_meter(name: str, key: str) -> Meter:
"""Get a single Meter instance aggregated under *name* and *key*.
Returns:
Meter or None if no metrics have been logged under *name* and *key*.
"""
if name not in _aggregators:
return None
return _aggregators[name].get(key, None)
def get_meters(name: str) -> MetersDict:
"""Get Meter instances aggregated under a given *name*.
Returns:
MetersDict or None if no metrics have been logged under *name*.
"""
return _aggregators.get(name, None)
def get_smoothed_value(name: str, key: str) -> float:
"""Get a single smoothed value.
Raises:
KeyError: if no metrics have been logged under *name* and *key*.
"""
return _aggregators[name].get_smoothed_value(key)
def get_smoothed_values(name: str) -> Dict[str, float]:
"""Get smoothed values aggregated under a given *name*.
Raises:
KeyError: if no metrics have been logged under *name*.
"""
return _aggregators[name].get_smoothed_values()
def state_dict():
return OrderedDict([(name, agg.state_dict()) for name, agg in _aggregators.items()])
def load_state_dict(state_dict):
for name, agg_state in state_dict.items():
_aggregators[name] = MetersDict()
_aggregators[name].load_state_dict(agg_state)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Wrapper around various loggers and progress bars (e.g., tqdm).
"""
import atexit
import json
import logging
import os
import sys
from collections import OrderedDict
from contextlib import contextmanager
from numbers import Number
from typing import Optional
import torch
from .meters import AverageMeter, StopwatchMeter, TimeMeter
logger = logging.getLogger(__name__)
def progress_bar(
iterator,
log_format: Optional[str] = None,
log_interval: int = 100,
epoch: Optional[int] = None,
prefix: Optional[str] = None,
tensorboard_logdir: Optional[str] = None,
default_log_format: str = "tqdm",
):
if log_format is None:
log_format = default_log_format
if log_format == "tqdm" and not sys.stderr.isatty():
log_format = "simple"
if log_format == "json":
bar = JsonProgressBar(iterator, epoch, prefix, log_interval)
elif log_format == "none":
bar = NoopProgressBar(iterator, epoch, prefix)
elif log_format == "simple":
bar = SimpleProgressBar(iterator, epoch, prefix, log_interval)
elif log_format == "tqdm":
bar = TqdmProgressBar(iterator, epoch, prefix)
else:
raise ValueError("Unknown log format: {}".format(log_format))
if tensorboard_logdir:
try:
# [FB only] custom wrapper for TensorBoard
import palaas # noqa
from .fb_tbmf_wrapper import FbTbmfWrapper
bar = FbTbmfWrapper(bar, log_interval)
except ImportError:
bar = TensorboardProgressBarWrapper(bar, tensorboard_logdir)
return bar
def build_progress_bar(
args,
iterator,
epoch: Optional[int] = None,
prefix: Optional[str] = None,
default: str = "tqdm",
no_progress_bar: str = "none",
):
"""Legacy wrapper that takes an argparse.Namespace."""
if getattr(args, "no_progress_bar", False):
default = no_progress_bar
if getattr(args, "distributed_rank", 0) == 0:
tensorboard_logdir = getattr(args, "tensorboard_logdir", None)
else:
tensorboard_logdir = None
return progress_bar(
iterator,
log_format=args.log_format,
log_interval=args.log_interval,
epoch=epoch,
prefix=prefix,
tensorboard_logdir=tensorboard_logdir,
default_log_format=default,
)
def format_stat(stat):
if isinstance(stat, Number):
stat = "{:g}".format(stat)
elif isinstance(stat, AverageMeter):
stat = "{:.3f}".format(stat.avg)
elif isinstance(stat, TimeMeter):
stat = "{:g}".format(round(stat.avg))
elif isinstance(stat, StopwatchMeter):
stat = "{:g}".format(round(stat.sum))
elif torch.is_tensor(stat):
stat = stat.tolist()
return stat
class BaseProgressBar(object):
"""Abstract class for progress bars."""
def __init__(self, iterable, epoch=None, prefix=None):
self.iterable = iterable
self.n = getattr(iterable, "n", 0)
self.epoch = epoch
self.prefix = ""
if epoch is not None:
self.prefix += "epoch {:03d}".format(epoch)
if prefix is not None:
self.prefix += " | {}".format(prefix)
def __len__(self):
return len(self.iterable)
def __enter__(self):
return self
def __exit__(self, *exc):
return False
def __iter__(self):
raise NotImplementedError
def log(self, stats, tag=None, step=None):
"""Log intermediate stats according to log_interval."""
raise NotImplementedError
def print(self, stats, tag=None, step=None):
"""Print end-of-epoch stats."""
raise NotImplementedError
def _str_commas(self, stats):
return ", ".join(key + "=" + stats[key].strip() for key in stats.keys())
def _str_pipes(self, stats):
return " | ".join(key + " " + stats[key].strip() for key in stats.keys())
def _format_stats(self, stats):
postfix = OrderedDict(stats)
# Preprocess stats according to datatype
for key in postfix.keys():
postfix[key] = str(format_stat(postfix[key]))
return postfix
@contextmanager
def rename_logger(logger, new_name):
old_name = logger.name
if new_name is not None:
logger.name = new_name
yield logger
logger.name = old_name
class JsonProgressBar(BaseProgressBar):
"""Log output in JSON format."""
def __init__(self, iterable, epoch=None, prefix=None, log_interval=1000):
super().__init__(iterable, epoch, prefix)
self.log_interval = log_interval
self.i = None
self.size = None
def __iter__(self):
self.size = len(self.iterable)
for i, obj in enumerate(self.iterable, start=self.n):
self.i = i
yield obj
def log(self, stats, tag=None, step=None):
"""Log intermediate stats according to log_interval."""
step = step or self.i or 0
if step > 0 and self.log_interval is not None and step % self.log_interval == 0:
update = (
self.epoch - 1 + (self.i + 1) / float(self.size)
if self.epoch is not None
else None
)
stats = self._format_stats(stats, epoch=self.epoch, update=update)
with rename_logger(logger, tag):
logger.info(json.dumps(stats))
def print(self, stats, tag=None, step=None):
"""Print end-of-epoch stats."""
self.stats = stats
if tag is not None:
self.stats = OrderedDict(
[(tag + "_" + k, v) for k, v in self.stats.items()]
)
stats = self._format_stats(self.stats, epoch=self.epoch)
with rename_logger(logger, tag):
logger.info(json.dumps(stats))
def _format_stats(self, stats, epoch=None, update=None):
postfix = OrderedDict()
if epoch is not None:
postfix["epoch"] = epoch
if update is not None:
postfix["update"] = round(update, 3)
# Preprocess stats according to datatype
for key in stats.keys():
postfix[key] = format_stat(stats[key])
return postfix
class NoopProgressBar(BaseProgressBar):
"""No logging."""
def __init__(self, iterable, epoch=None, prefix=None):
super().__init__(iterable, epoch, prefix)
def __iter__(self):
for obj in self.iterable:
yield obj
def log(self, stats, tag=None, step=None):
"""Log intermediate stats according to log_interval."""
pass
def print(self, stats, tag=None, step=None):
"""Print end-of-epoch stats."""
pass
class SimpleProgressBar(BaseProgressBar):
"""A minimal logger for non-TTY environments."""
def __init__(self, iterable, epoch=None, prefix=None, log_interval=1000):
super().__init__(iterable, epoch, prefix)
self.log_interval = log_interval
self.i = None
self.size = None
def __iter__(self):
self.size = len(self.iterable)
for i, obj in enumerate(self.iterable, start=self.n):
self.i = i
yield obj
def log(self, stats, tag=None, step=None):
"""Log intermediate stats according to log_interval."""
step = step or self.i or 0
if step > 0 and self.log_interval is not None and step % self.log_interval == 0:
stats = self._format_stats(stats)
postfix = self._str_commas(stats)
with rename_logger(logger, tag):
logger.info(
"{}: {:5d} / {:d} {}".format(
self.prefix, self.i + 1, self.size, postfix
)
)
def print(self, stats, tag=None, step=None):
"""Print end-of-epoch stats."""
postfix = self._str_pipes(self._format_stats(stats))
with rename_logger(logger, tag):
logger.info("{} | {}".format(self.prefix, postfix))
class TqdmProgressBar(BaseProgressBar):
"""Log to tqdm."""
def __init__(self, iterable, epoch=None, prefix=None):
super().__init__(iterable, epoch, prefix)
from tqdm import tqdm
self.tqdm = tqdm(
iterable,
self.prefix,
leave=False,
disable=(logger.getEffectiveLevel() > logging.INFO),
)
def __iter__(self):
return iter(self.tqdm)
def log(self, stats, tag=None, step=None):
"""Log intermediate stats according to log_interval."""
self.tqdm.set_postfix(self._format_stats(stats), refresh=False)
def print(self, stats, tag=None, step=None):
"""Print end-of-epoch stats."""
postfix = self._str_pipes(self._format_stats(stats))
with rename_logger(logger, tag):
logger.info("{} | {}".format(self.prefix, postfix))
try:
_tensorboard_writers = {}
from tensorboardX import SummaryWriter
except ImportError:
SummaryWriter = None
def _close_writers():
for w in _tensorboard_writers.values():
w.close()
atexit.register(_close_writers)
class TensorboardProgressBarWrapper(BaseProgressBar):
"""Log to tensorboard."""
def __init__(self, wrapped_bar, tensorboard_logdir):
self.wrapped_bar = wrapped_bar
self.tensorboard_logdir = tensorboard_logdir
if SummaryWriter is None:
logger.warning(
"tensorboard not found, please install with: pip install tensorboardX"
)
def _writer(self, key):
if SummaryWriter is None:
return None
_writers = _tensorboard_writers
if key not in _writers:
_writers[key] = SummaryWriter(os.path.join(self.tensorboard_logdir, key))
_writers[key].add_text("sys.argv", " ".join(sys.argv))
return _writers[key]
def __iter__(self):
return iter(self.wrapped_bar)
def log(self, stats, tag=None, step=None):
"""Log intermediate stats to tensorboard."""
self._log_to_tensorboard(stats, tag, step)
self.wrapped_bar.log(stats, tag=tag, step=step)
def print(self, stats, tag=None, step=None):
"""Print end-of-epoch stats."""
self._log_to_tensorboard(stats, tag, step)
self.wrapped_bar.print(stats, tag=tag, step=step)
def _log_to_tensorboard(self, stats, tag=None, step=None):
writer = self._writer(tag or "")
if writer is None:
return
if step is None:
step = stats["num_updates"]
for key in stats.keys() - {"num_updates"}:
if isinstance(stats[key], AverageMeter):
writer.add_scalar(key, stats[key].val, step)
elif isinstance(stats[key], Number):
writer.add_scalar(key, stats[key], step)
writer.flush()
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from . import criterions, models, modules # noqa
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import importlib
import os
# automatically import any Python files in the criterions/ directory
for file in os.listdir(os.path.dirname(__file__)):
if file.endswith(".py") and not file.startswith("_"):
module = file[: file.find(".py")]
importlib.import_module("fairseq.model_parallel.criterions." + module)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
try:
from fairseq.model_parallel.megatron.mpu.cross_entropy import (
vocab_parallel_cross_entropy,
)
has_megatron_submodule = True
except (ImportError, ModuleNotFoundError):
has_megatron_submodule = False
@register_criterion("vocab_parallel_cross_entropy")
class VocabParallelCrossEntropyCriterion(FairseqCriterion):
def __init__(self, task, sentence_avg):
super().__init__(task)
self.sentence_avg = sentence_avg
if not has_megatron_submodule:
raise ImportError(
"\n\nPlease install the megatron submodule:"
"\n\n git submodule update --init "
"fairseq/model_parallel/megatron"
)
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
net_output = model(**sample["net_input"])
target = sample["target"]
loss = vocab_parallel_cross_entropy(net_output[0].float(), target)
loss = (loss * (target != self.padding_idx)).sum()
sample_size = (
sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
)
logging_output = {
"loss": utils.item(loss.data) if reduce else loss.data,
"ntokens": sample["ntokens"],
"nsentences": sample["target"].size(0),
"sample_size": sample_size,
}
return loss, sample_size, logging_output
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
metrics.log_scalar(
"loss", loss_sum / sample_size / math.log(2), sample_size, round=3
)
if sample_size != ntokens:
metrics.log_scalar(
"nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
)
metrics.log_derived(
"ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
)
else:
metrics.log_derived(
"ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improves distributed training speed.
"""
return True
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Train a network across multiple GPUs.
"""
from fairseq import distributed_utils
from fairseq.trainer import Trainer
try:
from fairseq.model_parallel.megatron.mpu import (
get_data_parallel_group,
get_data_parallel_rank,
get_data_parallel_world_size,
get_model_parallel_group,
get_model_parallel_src_rank,
)
has_megatron_submodule = True
except (ImportError, ModuleNotFoundError):
has_megatron_submodule = False
class MegatronTrainer(Trainer):
"""Main class for model parallel with data parallel training."""
def __init__(self, args, task, model, criterion):
if not has_megatron_submodule:
raise ImportError(
"\n\nPlease install the megatron submodule:"
"\n\n git submodule update --init "
"fairseq/model_parallel/megatron"
)
super().__init__(args, task, model, criterion)
@property
def data_parallel_world_size(self):
return get_data_parallel_world_size()
@property
def data_parallel_process_group(self):
return get_data_parallel_group()
@property
def data_parallel_rank(self):
return get_data_parallel_rank()
@property
def is_data_parallel_master(self):
return get_model_parallel_src_rank() == 0
def clip_grad_norm(self, clip_norm):
def _aggregate_model_parallel_grad_norm(total_norm):
total_norm = total_norm ** 2
distributed_utils.all_reduce(total_norm, group=get_model_parallel_group())
total_norm = total_norm ** 0.5
return total_norm
return self.optimizer.clip_grad_norm(
clip_norm,
aggregate_norm_fn=_aggregate_model_parallel_grad_norm,
)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import importlib
import os
# automatically import any Python files in the models/ directory
models_dir = os.path.dirname(__file__)
for file in os.listdir(models_dir):
path = os.path.join(models_dir, file)
if (
not file.startswith("_")
and not file.startswith(".")
and (file.endswith(".py") or os.path.isdir(path))
):
model_name = file[: file.find(".py")] if file.endswith(".py") else file
module = importlib.import_module("fairseq.model_parallel.models." + model_name)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment