Commit 7143f128 authored by sunxx1's avatar sunxx1
Browse files

Merge branch 'hepj-test' into 'main'

更新transformer代码

See merge request dcutoolkit/deeplearing/dlexamples_new!47
parents a30b77fe c0f05c10
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from collections import namedtuple
import numpy as np
import torch
from fairseq import utils
# Immutable record passed between refinement steps: the current token/score
# tensors, optional attention, the step counters, and (optionally) the list
# of earlier token tensors.
DecoderOut = namedtuple(
    typename="IterativeRefinementDecoderOut",
    field_names=(
        "output_tokens",
        "output_scores",
        "attn",
        "step",
        "max_step",
        "history",
    ),
)
class IterativeRefinementGenerator(object):
    """Decode by repeatedly feeding a full draft back through the decoder.

    The generator asks the model for an initial draft
    (``initialize_output_tokens``) and then calls ``forward_decoder`` up to
    ``max_iter + 1`` times. Sentences are finalized as soon as they stop
    changing (when ``adaptive``) or when the last iteration is reached.
    """

    def __init__(
        self,
        tgt_dict,
        models=None,
        eos_penalty=0.0,
        max_iter=10,
        max_ratio=2,
        beam_size=1,
        decoding_format=None,
        retain_dropout=False,
        adaptive=True,
        retain_history=False,
        reranking=False,
    ):
        """
        Generates translations based on iterative refinement.

        Args:
            tgt_dict: target dictionary (provides bos/pad/unk/eos indices and
                the vocabulary size)
            models: models used by :func:`generate_batched_itr`
            eos_penalty: if > 0.0, it penalizes early-stopping in decoding
                (forwarded unchanged to ``model.forward_decoder``)
            max_iter: maximum number of refinement iterations
            max_ratio: generate sequences of maximum length ax, where x is the
                source length (forwarded unchanged to ``model.forward_decoder``)
            beam_size: number of length-beam candidates kept per sentence
            decoding_format: decoding mode in {'unigram', 'ensemble', 'vote', 'dp', 'bs'}
            retain_dropout: retaining dropout in the inference (models are not
                switched to eval mode)
            adaptive: decoding with early stop
            retain_history: attach the intermediate token sequences to each
                finalized hypothesis under the ``"history"`` key
            reranking: treat the last model as a reranker for the length beam
        """
        self.bos = tgt_dict.bos()
        self.pad = tgt_dict.pad()
        self.unk = tgt_dict.unk()
        self.eos = tgt_dict.eos()
        self.vocab_size = len(tgt_dict)
        self.eos_penalty = eos_penalty
        self.max_iter = max_iter
        self.max_ratio = max_ratio
        self.beam_size = beam_size
        self.reranking = reranking
        self.decoding_format = decoding_format
        self.retain_dropout = retain_dropout
        self.retain_history = retain_history
        self.adaptive = adaptive
        self.models = models

    def generate_batched_itr(
        self,
        data_itr,
        maxlen_a=None,
        maxlen_b=None,
        cuda=False,
        timer=None,
        prefix_size=0,
    ):
        """Iterate over a batched dataset and yield individual translations.

        Args:
            maxlen_a/b: generate sequences of maximum length ax + b,
                where x is the source sentence length.
                (accepted for interface compatibility; not read here)
            cuda: use GPU for generation (not read here)
            timer: StopwatchMeter for timing generations.
            prefix_size: when > 0, the first ``prefix_size`` target columns
                are passed to :func:`generate` as ``prefix_tokens``

        Yields:
            ``(id, src, ref, hypos)`` per sentence, with padding stripped
            from ``src`` and ``ref``.
        """
        for sample in data_itr:
            if "net_input" not in sample:
                continue
            if timer is not None:
                timer.start()
            with torch.no_grad():
                hypos = self.generate(
                    self.models,
                    sample,
                    prefix_tokens=sample["target"][:, :prefix_size]
                    if prefix_size > 0
                    else None,
                )
            if timer is not None:
                timer.stop(sample["ntokens"])
            # `sample_id` rather than `id`, so the builtin is not shadowed
            for i, sample_id in enumerate(sample["id"]):
                # remove padding
                src = utils.strip_pad(sample["net_input"]["src_tokens"][i, :], self.pad)
                ref = utils.strip_pad(sample["target"][i, :], self.pad)
                yield sample_id, src, ref, hypos[i]

    @torch.no_grad()
    def generate(self, models, sample, prefix_tokens=None, constraints=None):
        """Run iterative refinement on one batch.

        Args:
            models: list of models; the first one decodes. When
                ``self.reranking`` is set, the last one is used as reranker.
            sample: batch dict; ``sample["net_input"]`` must provide
                ``src_tokens`` and ``src_lengths``.
            prefix_tokens: accepted for interface compatibility; not read here.
            constraints: must be ``None``.

        Returns:
            A list with one entry per input sentence; each entry is a
            one-element list holding a hypothesis dict with the keys
            ``steps``, ``tokens``, ``positional_scores``, ``score``,
            ``hypo_attn``, ``alignment`` (plus ``history`` when
            ``self.retain_history``).

        Raises:
            NotImplementedError: if ``constraints`` is given.
        """
        if constraints is not None:
            raise NotImplementedError(
                "Constrained decoding with the IterativeRefinementGenerator is not supported"
            )

        # TODO: iterative refinement generator does not support ensemble for now.
        if not self.retain_dropout:
            for model in models:
                model.eval()

        model, reranker = models[0], None
        if self.reranking:
            assert len(models) > 1, "Assuming the last checkpoint is the reranker"
            assert (
                self.beam_size > 1
            ), "Reranking requires multiple translation for each example"

            reranker = models[-1]
            models = models[:-1]

        if len(models) > 1 and hasattr(model, "enable_ensemble"):
            assert model.allow_ensemble, "{} does not support ensembling".format(
                model.__class__.__name__
            )
            model.enable_ensemble(models)

        # TODO: better encoder inputs?
        src_tokens = sample["net_input"]["src_tokens"]
        src_lengths = sample["net_input"]["src_lengths"]
        bsz, src_len = src_tokens.size()

        # initialize
        encoder_out = model.forward_encoder([src_tokens, src_lengths])
        prev_decoder_out = model.initialize_output_tokens(encoder_out, src_tokens)

        if self.beam_size > 1:
            assert (
                model.allow_length_beam
            ), "{} does not support decoding with length beam.".format(
                model.__class__.__name__
            )

            # regenerate data based on length-beam
            length_beam_order = (
                utils.new_arange(src_tokens, self.beam_size, bsz).t().reshape(-1)
            )
            encoder_out = model.encoder.reorder_encoder_out(
                encoder_out, length_beam_order
            )
            prev_decoder_out = model.regenerate_length_beam(
                prev_decoder_out, self.beam_size
            )
            bsz = bsz * self.beam_size

        # Maps rows of the (shrinking) working batch back to their original
        # position in `finalized`. Created on the CPU.
        sent_idxs = torch.arange(bsz)
        prev_output_tokens = prev_decoder_out.output_tokens.clone()

        if self.retain_history:
            prev_decoder_out = prev_decoder_out._replace(history=[prev_output_tokens])

        finalized = [[] for _ in range(bsz)]

        def is_a_loop(x, y, s, a):
            # Pad the shorter of x / y (and y's scores / attention) so both
            # token tensors have equal width, then flag rows that are unchanged.
            b, l_x, l_y = x.size(0), x.size(1), y.size(1)
            if l_x > l_y:
                y = torch.cat([y, x.new_zeros(b, l_x - l_y).fill_(self.pad)], 1)
                s = torch.cat([s, s.new_zeros(b, l_x - l_y)], 1)
                if a is not None:
                    a = torch.cat([a, a.new_zeros(b, l_x - l_y, a.size(2))], 1)
            elif l_x < l_y:
                x = torch.cat([x, y.new_zeros(b, l_y - l_x).fill_(self.pad)], 1)
            return (x == y).all(1), y, s, a

        def finalized_hypos(step, prev_out_token, prev_out_score, prev_out_attn):
            # Build the hypothesis dict for one sentence, dropping pad positions.
            cutoff = prev_out_token.ne(self.pad)
            tokens = prev_out_token[cutoff]
            if prev_out_score is None:
                scores, score = None, None
            else:
                scores = prev_out_score[cutoff]
                score = scores.mean()

            if prev_out_attn is None:
                hypo_attn, alignment = None, None
            else:
                hypo_attn = prev_out_attn[cutoff]
                alignment = hypo_attn.max(dim=1)[1]
            return {
                "steps": step,
                "tokens": tokens,
                "positional_scores": scores,
                "score": score,
                "hypo_attn": hypo_attn,
                "alignment": alignment,
            }

        for step in range(self.max_iter + 1):

            decoder_options = {
                "eos_penalty": self.eos_penalty,
                "max_ratio": self.max_ratio,
                "decoding_format": self.decoding_format,
            }
            prev_decoder_out = prev_decoder_out._replace(
                step=step,
                max_step=self.max_iter + 1,
            )

            decoder_out = model.forward_decoder(
                prev_decoder_out, encoder_out, **decoder_options
            )

            if self.adaptive:
                # terminate if there is a loop
                terminated, out_tokens, out_scores, out_attn = is_a_loop(
                    prev_output_tokens,
                    decoder_out.output_tokens,
                    decoder_out.output_scores,
                    decoder_out.attn,
                )
                decoder_out = decoder_out._replace(
                    output_tokens=out_tokens,
                    output_scores=out_scores,
                    attn=out_attn,
                )

            else:
                terminated = decoder_out.output_tokens.new_zeros(
                    decoder_out.output_tokens.size(0)
                ).bool()

            if step == self.max_iter:  # reach last iteration, terminate
                terminated.fill_(1)

            # collect finalized sentences
            # `sent_idxs` lives on the CPU while `terminated` follows the model
            # device; move the mask so the indexing does not mix devices.
            finalized_idxs = sent_idxs[terminated.to(sent_idxs.device)]
            finalized_tokens = decoder_out.output_tokens[terminated]
            finalized_scores = decoder_out.output_scores[terminated]
            finalized_attn = (
                None
                if (decoder_out.attn is None or decoder_out.attn.size(0) == 0)
                else decoder_out.attn[terminated]
            )

            if self.retain_history:
                finalized_history_tokens = [h[terminated] for h in decoder_out.history]

            for i in range(finalized_idxs.size(0)):
                finalized[finalized_idxs[i]] = [
                    finalized_hypos(
                        step,
                        finalized_tokens[i],
                        finalized_scores[i],
                        None if finalized_attn is None else finalized_attn[i],
                    )
                ]

                if self.retain_history:
                    finalized[finalized_idxs[i]][0]["history"] = []
                    for j in range(len(finalized_history_tokens)):
                        finalized[finalized_idxs[i]][0]["history"].append(
                            finalized_hypos(
                                step, finalized_history_tokens[j][i], None, None
                            )
                        )

            # check if all terminated
            if terminated.sum() == terminated.size(0):
                break

            # for next step
            not_terminated = ~terminated
            prev_decoder_out = decoder_out._replace(
                output_tokens=decoder_out.output_tokens[not_terminated],
                output_scores=decoder_out.output_scores[not_terminated],
                attn=decoder_out.attn[not_terminated]
                if (decoder_out.attn is not None and decoder_out.attn.size(0) > 0)
                else None,
                history=[h[not_terminated] for h in decoder_out.history]
                if decoder_out.history is not None
                else None,
            )
            encoder_out = model.encoder.reorder_encoder_out(
                encoder_out, not_terminated.nonzero(as_tuple=False).squeeze()
            )
            # same device alignment as for `finalized_idxs` above
            sent_idxs = sent_idxs[not_terminated.to(sent_idxs.device)]
            prev_output_tokens = prev_decoder_out.output_tokens.clone()

        if self.beam_size > 1:
            if reranker is not None:
                finalized = self.rerank(
                    reranker, finalized, [src_tokens, src_lengths], self.beam_size
                )

            # aggregate information from length beam
            finalized = [
                finalized[
                    np.argmax(
                        [
                            finalized[self.beam_size * i + j][0]["score"]
                            for j in range(self.beam_size)
                        ]
                    )
                    + self.beam_size * i
                ]
                for i in range(len(finalized) // self.beam_size)
            ]

        return finalized

    def rerank(self, reranker, finalized, encoder_input, beam_size):
        """Overwrite each hypothesis' ``"score"`` with a score from ``reranker``.

        The finalized token sequences are re-padded into one batch, scored by
        the reranker's decoder, and the per-token scores of the non-pad
        positions are averaged per sentence.

        Args:
            reranker: model exposing ``encoder``, ``decoder`` and
                ``get_normalized_probs``.
            finalized: hypothesis lists as produced by :func:`generate`.
            encoder_input: positional arguments for ``reranker.encoder``.
            beam_size: number of hypotheses per source sentence.

        Returns:
            ``finalized`` with updated ``"score"`` entries (modified in place).
        """

        def rebuild_batch(finalized):
            # Right-pad all hypotheses to the longest one.
            finalized_tokens = [f[0]["tokens"] for f in finalized]
            finalized_maxlen = max(f.size(0) for f in finalized_tokens)
            final_output_tokens = (
                finalized_tokens[0]
                .new_zeros(len(finalized_tokens), finalized_maxlen)
                .fill_(self.pad)
            )
            for i, f in enumerate(finalized_tokens):
                final_output_tokens[i, : f.size(0)] = f
            return final_output_tokens

        final_output_tokens = rebuild_batch(finalized)
        final_output_tokens[
            :, 0
        ] = self.eos  # autoregressive model assumes starting with EOS

        reranker_encoder_out = reranker.encoder(*encoder_input)
        # Repeat every encoder state `beam_size` times to line up with the
        # length-beam hypotheses.
        length_beam_order = (
            utils.new_arange(
                final_output_tokens, beam_size, reranker_encoder_out.encoder_out.size(1)
            )
            .t()
            .reshape(-1)
        )
        reranker_encoder_out = reranker.encoder.reorder_encoder_out(
            reranker_encoder_out, length_beam_order
        )
        # NOTE(review): the `True` flag presumably requests log-probabilities;
        # confirm against the reranker's get_normalized_probs signature.
        reranking_scores = reranker.get_normalized_probs(
            reranker.decoder(final_output_tokens[:, :-1], reranker_encoder_out),
            True,
            None,
        )
        reranking_scores = reranking_scores.gather(2, final_output_tokens[:, 1:, None])
        reranking_masks = final_output_tokens[:, 1:].ne(self.pad)
        reranking_scores = (
            reranking_scores[:, :, 0].masked_fill_(~reranking_masks, 0).sum(1)
        )
        reranking_scores = reranking_scores / reranking_masks.sum(1).type_as(
            reranking_scores
        )

        for i in range(len(finalized)):
            finalized[i][0]["score"] = reranking_scores[i]

        return finalized
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import bisect
import time
from collections import OrderedDict
from typing import Dict, Optional
try:
    import torch
except ImportError:
    torch = None

    def type_as(a, b):
        """Without torch there is nothing to convert; hand back *a* as is."""
        return a

else:

    def type_as(a, b):
        """Return *a* converted via ``a.to(b)`` when both are tensors, else *a*."""
        both_tensors = torch.is_tensor(a) and torch.is_tensor(b)
        return a.to(b) if both_tensors else a
try:
import numpy as np
except ImportError:
np = None
class Meter(object):
    """Common interface every meter implements.

    Subclasses must override :meth:`reset` and :attr:`smoothed_value`;
    the state-dict hooks default to "no state".
    """

    def __init__(self):
        pass

    def state_dict(self):
        """Return the serializable state (empty for the base class)."""
        return dict()

    def load_state_dict(self, state_dict):
        """Restore state produced by :meth:`state_dict` (no-op here)."""
        return None

    def reset(self):
        """Clear accumulated values; must be provided by subclasses."""
        raise NotImplementedError

    @property
    def smoothed_value(self) -> float:
        """Value reported to loggers; must be provided by subclasses."""
        raise NotImplementedError
def safe_round(number, ndigits):
    """Round *number* to *ndigits* when possible, otherwise return it untouched.

    Objects that implement ``__round__`` are rounded directly. Single-element
    torch tensors and zero-dimensional numpy values exposing ``item()`` are
    unwrapped first and then rounded recursively.
    """
    if hasattr(number, "__round__"):
        return round(number, ndigits)

    # Checked in this order on purpose: the numpy probe only runs when the
    # tensor probe did not already match.
    unwrap = (
        torch is not None and torch.is_tensor(number) and number.numel() == 1
    ) or (np is not None and np.ndim(number) == 0 and hasattr(number, "item"))
    if unwrap:
        return safe_round(number.item(), ndigits)
    return number
class AverageMeter(Meter):
    """Tracks the latest value together with a weighted running average."""

    def __init__(self, round: Optional[int] = None):
        self.round = round
        self.reset()

    def reset(self):
        """Forget everything recorded so far."""
        self.count = 0  # total weight seen
        self.sum = 0  # weighted sum of values
        self.val = None  # last value passed to update()

    def update(self, val, n=1):
        """Record *val* with weight *n*; ``None`` values are ignored."""
        if val is None:
            return
        self.val = val
        if n > 0:
            self.sum = type_as(self.sum, val) + (val * n)
            self.count = type_as(self.count, n) + n

    def state_dict(self):
        """Return a dict with ``val``, ``sum``, ``count`` and ``round``."""
        return dict(val=self.val, sum=self.sum, count=self.count, round=self.round)

    def load_state_dict(self, state_dict):
        """Restore from :meth:`state_dict`; ``round`` is optional."""
        self.val = state_dict["val"]
        self.sum = state_dict["sum"]
        self.count = state_dict["count"]
        self.round = state_dict.get("round", None)

    @property
    def avg(self):
        """Weighted mean, or the last value when no weight has been recorded."""
        if self.count > 0:
            return self.sum / self.count
        return self.val

    @property
    def smoothed_value(self) -> float:
        """The average, rounded when a precision was configured."""
        value = self.avg
        if self.round is None or value is None:
            return value
        return safe_round(value, self.round)
class SumMeter(Meter):
    """Accumulates a running total of the values it is given."""

    def __init__(self, round: Optional[int] = None):
        self.round = round
        self.reset()

    def reset(self):
        """Set the total back to zero."""
        self.sum = 0  # running total across update() calls

    def update(self, val):
        """Add *val* to the total; ``None`` is ignored."""
        if val is None:
            return
        self.sum = type_as(self.sum, val) + val

    def state_dict(self):
        """Return a dict with ``sum`` and ``round``."""
        return dict(sum=self.sum, round=self.round)

    def load_state_dict(self, state_dict):
        """Restore from :meth:`state_dict`; ``round`` is optional."""
        self.sum = state_dict["sum"]
        self.round = state_dict.get("round", None)

    @property
    def smoothed_value(self) -> float:
        """The total, rounded when a precision was configured."""
        total = self.sum
        if self.round is None or total is None:
            return total
        return safe_round(total, self.round)
class TimeMeter(Meter):
    """Measures how many units of something happen per second."""

    def __init__(
        self,
        init: int = 0,
        n: int = 0,
        round: Optional[int] = None,
    ):
        self.round = round
        self.reset(init, n)

    def reset(self, init=0, n=0):
        """Restart the clock, optionally seeding elapsed time and count."""
        self.init = init  # seconds carried over from before this reset
        self.n = n  # accumulated quantity
        self.i = 0  # number of update() calls since reset
        self.start = time.perf_counter()

    def update(self, val=1):
        """Add *val* to the accumulated quantity."""
        self.n = type_as(self.n, val) + val
        self.i += 1

    def state_dict(self):
        """Return a dict with ``init`` (elapsed seconds so far), ``n``, ``round``."""
        return dict(init=self.elapsed_time, n=self.n, round=self.round)

    def load_state_dict(self, state_dict):
        """Restore from :meth:`state_dict`.

        State dicts that still carry a ``"start"`` key are the old format;
        for those only ``init`` is restored.
        """
        reset_kwargs = {"init": state_dict["init"]}
        if "start" not in state_dict:
            reset_kwargs["n"] = state_dict["n"]
        self.reset(**reset_kwargs)
        self.round = state_dict.get("round", None)

    @property
    def avg(self):
        """Accumulated quantity divided by elapsed seconds."""
        return self.n / self.elapsed_time

    @property
    def elapsed_time(self):
        """Seconds since the last reset plus the carried-over ``init``."""
        since_start = time.perf_counter() - self.start
        return self.init + since_start

    @property
    def smoothed_value(self) -> float:
        """The rate, rounded when a precision was configured."""
        rate = self.avg
        if self.round is None or rate is None:
            return rate
        return safe_round(rate, self.round)
class StopwatchMeter(Meter):
    """Accumulates the duration of start/stop intervals, in seconds."""

    def __init__(self, round: Optional[int] = None):
        self.round = round
        self.sum = 0
        self.n = 0
        self.start_time = None

    def start(self):
        """Mark the beginning of an interval."""
        self.start_time = time.perf_counter()

    def stop(self, n=1, prehook=None):
        """Close the current interval and add it to the totals.

        Does nothing if :meth:`start` has never been called. *prehook*, when
        given, is invoked before the clock is read.
        """
        if self.start_time is None:
            return
        if prehook is not None:
            prehook()
        elapsed = time.perf_counter() - self.start_time
        self.sum = self.sum + elapsed
        self.n = type_as(self.n, n) + n

    def reset(self):
        """Zero the totals and immediately start a new interval."""
        self.sum = 0  # cumulative time during which stopwatch was active
        self.n = 0  # total n across all start/stop
        self.start()

    def state_dict(self):
        """Return a dict with ``sum``, ``n`` and ``round``."""
        return dict(sum=self.sum, n=self.n, round=self.round)

    def load_state_dict(self, state_dict):
        """Restore from :meth:`state_dict`; the running interval is dropped."""
        self.sum = state_dict["sum"]
        self.n = state_dict["n"]
        self.start_time = None
        self.round = state_dict.get("round", None)

    @property
    def avg(self):
        """Total time divided by ``n``, or the plain total when ``n`` is not positive."""
        if self.n > 0:
            return self.sum / self.n
        return self.sum

    @property
    def elapsed_time(self):
        """Seconds since the last :meth:`start`, or 0.0 if never started."""
        if self.start_time is None:
            return 0.0
        return time.perf_counter() - self.start_time

    @property
    def smoothed_value(self) -> float:
        """Average once something was recorded, otherwise the running interval."""
        if self.sum > 0:
            value = self.avg
        else:
            value = self.elapsed_time
        if self.round is None or value is None:
            return value
        return safe_round(value, self.round)
class MetersDict(OrderedDict):
    """An ordered mapping of meter name to :class:`Meter`, sorted by priority.

    The priority is supplied once, when a meter is first added; entries are
    then kept in ascending priority order, with insertion order breaking
    ties.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # sorted list of (priority, insertion index, key)
        self.priorities = []

    def __setitem__(self, key, value):
        """Insert ``value == (priority, meter)`` under *key* (no overwrite)."""
        assert key not in self, "MetersDict doesn't support reassignment"
        priority, meter = value
        entry = (priority, len(self.priorities), key)
        bisect.insort(self.priorities, entry)
        super().__setitem__(key, meter)
        # Walk the keys in priority order so the dict order follows it.
        for _, _, ordered_key in self.priorities:
            self.move_to_end(ordered_key)

    def add_meter(self, key, meter, priority):
        """Register *meter* under *key* with the given *priority*."""
        self.__setitem__(key, (priority, meter))

    def state_dict(self):
        """Return ``(priority, key, class name, meter state)`` tuples.

        Derived meters are left out because their function cannot be
        serialized.
        """
        serialized = []
        for pri, _, key in self.priorities:
            meter = self[key]
            if isinstance(meter, MetersDict._DerivedMeter):
                continue
            serialized.append(
                (pri, key, meter.__class__.__name__, meter.state_dict())
            )
        return serialized

    def load_state_dict(self, state_dict):
        """Replace the contents with meters rebuilt from :meth:`state_dict`.

        Meter classes are looked up by name in this module's globals.
        """
        self.clear()
        self.priorities.clear()
        for pri, key, meter_cls, meter_state in state_dict:
            rebuilt = globals()[meter_cls]()
            rebuilt.load_state_dict(meter_state)
            self.add_meter(key, rebuilt, pri)

    def get_smoothed_value(self, key: str) -> float:
        """Get a single smoothed value."""
        meter = self[key]
        if not isinstance(meter, MetersDict._DerivedMeter):
            return meter.smoothed_value
        return meter.fn(self)

    def get_smoothed_values(self) -> Dict[str, float]:
        """Get all smoothed values, skipping keys that start with ``_``."""
        smoothed = OrderedDict()
        for key in self.keys():
            if key.startswith("_"):
                continue
            smoothed[key] = self.get_smoothed_value(key)
        return smoothed

    def reset(self):
        """Reset Meter instances (derived meters hold no state to reset)."""
        for meter in self.values():
            if not isinstance(meter, MetersDict._DerivedMeter):
                meter.reset()

    class _DerivedMeter(Meter):
        """A Meter whose values are derived from other Meters."""

        def __init__(self, fn):
            self.fn = fn

        def reset(self):
            pass
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
A standalone module for aggregating metrics.
Metrics can be logged from anywhere using the `log_*` functions defined
in this module. The logged values will be aggregated dynamically based
on the aggregation context in which the logging occurs. See the
:func:`aggregate` context manager for more details.
"""
import contextlib
import uuid
from collections import defaultdict
from typing import Callable, List, Optional
from .meters import *
# Aggregation contexts are considered "active" when inside the scope
# created by the :func:`aggregate` context manager.
# name -> MetersDict for every named aggregator created so far
_aggregators = OrderedDict()
# name -> MetersDict for the aggregators that currently receive logged values
_active_aggregators = OrderedDict()
# name -> how many nested `aggregate(name)` scopes are currently open
_active_aggregators_cnt = defaultdict(lambda: 0)
def reset() -> None:
    """Drop every aggregator and start over with only "default" active."""
    for registry in (_aggregators, _active_aggregators, _active_aggregators_cnt):
        registry.clear()

    # The "default" aggregator observes all logged values.
    default_agg = MetersDict()
    _aggregators["default"] = default_agg
    _active_aggregators["default"] = default_agg
    _active_aggregators_cnt["default"] = 1


# Install the "default" aggregator at import time.
reset()
@contextlib.contextmanager
def aggregate(name: Optional[str] = None, new_root: bool = False):
    """Context manager to aggregate metrics under a given name.

    Aggregations can be nested. If *new_root* is ``False``, then logged
    metrics will be recorded along the entire stack of nested
    aggregators, including a global "default" aggregator. If *new_root*
    is ``True``, then this aggregator will be the root of a new
    aggregation stack, thus bypassing any parent aggregators.

    Note that aggregation contexts are uniquely identified by their
    *name* (e.g., train, valid). Creating a context with an existing
    name will reuse the corresponding :class:`MetersDict` instance.
    If no name is given, then a temporary aggregator will be created.

    The active-aggregator bookkeeping is unwound even when the body of
    the ``with`` block raises, so a failed step cannot leave a stale
    aggregator active.

    Usage::

        with metrics.aggregate("train"):
            for step, batch in enumerate(epoch):
                with metrics.aggregate("train_inner") as agg:
                    metrics.log_scalar("loss", get_loss(batch))
                    if step % log_interval == 0:
                        print(agg.get_smoothed_value("loss"))
                        agg.reset()
        print(metrics.get_smoothed_values("train")["loss"])

    Args:
        name (str): name of the aggregation. Defaults to a
            random/temporary name if not given explicitly.
        new_root (bool): make this aggregation the root of a new
            aggregation stack.

    Yields:
        MetersDict: the aggregator that is active inside the block.
    """
    if name is None:
        # generate a temporary name
        name = str(uuid.uuid4())
        assert name not in _aggregators
        agg = MetersDict()
    else:
        assert name != "default"
        agg = _aggregators.setdefault(name, MetersDict())

    if new_root:
        # Park the current stack; it is restored on exit.
        backup_aggregators = _active_aggregators.copy()
        _active_aggregators.clear()
        backup_aggregators_cnt = _active_aggregators_cnt.copy()
        _active_aggregators_cnt.clear()

    _active_aggregators[name] = agg
    _active_aggregators_cnt[name] += 1

    try:
        yield agg
    finally:
        # Runs on normal exit and on exceptions alike.
        _active_aggregators_cnt[name] -= 1
        if _active_aggregators_cnt[name] == 0 and name in _active_aggregators:
            del _active_aggregators[name]

        if new_root:
            _active_aggregators.clear()
            _active_aggregators.update(backup_aggregators)
            _active_aggregators_cnt.clear()
            _active_aggregators_cnt.update(backup_aggregators_cnt)
def get_active_aggregators() -> List[MetersDict]:
    """Return a snapshot list of the aggregators that are currently active."""
    return [*_active_aggregators.values()]
def log_scalar(
    key: str,
    value: float,
    weight: float = 1,
    priority: int = 10,
    round: Optional[int] = None,
):
    """Record *value* in an :class:`AverageMeter` of every active aggregator.

    Args:
        key (str): name of the field to log
        value (float): value to log
        weight (float): weight that this value contributes to the average.
            A weight of 0 will always log the latest value.
        priority (int): smaller values are logged earlier in the output
        round (Optional[int]): number of digits to round to when displaying
    """
    for aggregator in get_active_aggregators():
        # The meter is created lazily, the first time the key is seen.
        if key not in aggregator:
            new_meter = AverageMeter(round=round)
            aggregator.add_meter(key, new_meter, priority)
        aggregator[key].update(value, weight)
def log_scalar_sum(
    key: str,
    value: float,
    priority: int = 10,
    round: Optional[int] = None,
):
    """Add *value* to a :class:`SumMeter` of every active aggregator.

    Args:
        key (str): name of the field to log
        value (float): value to log
        priority (int): smaller values are logged earlier in the output
        round (Optional[int]): number of digits to round to when displaying
    """
    for aggregator in get_active_aggregators():
        # The meter is created lazily, the first time the key is seen.
        if key not in aggregator:
            new_meter = SumMeter(round=round)
            aggregator.add_meter(key, new_meter, priority)
        aggregator[key].update(value)
def log_derived(key: str, fn: Callable[[MetersDict], float], priority: int = 20):
    """Register a value that is computed from other meters on demand.

    An already registered *key* is left untouched.

    Args:
        key (str): name of the field to log
        fn (Callable[[MetersDict], float]): function that takes a single
            argument *meters* and returns the derived value
        priority (int): smaller values are logged earlier in the output
    """
    for aggregator in get_active_aggregators():
        if key in aggregator:
            continue
        aggregator.add_meter(key, MetersDict._DerivedMeter(fn), priority)
def log_speed(
    key: str,
    value: float,
    priority: int = 30,
    round: Optional[int] = None,
):
    """Log the rate of some quantity per second.

    The first call for a *key* only creates and resets the meter; *value*
    is counted from the second call on.

    Args:
        key (str): name of the field to log
        value (float): value to log
        priority (int): smaller values are logged earlier in the output
        round (Optional[int]): number of digits to round to when displaying
    """
    for aggregator in get_active_aggregators():
        if key in aggregator:
            aggregator[key].update(value)
            continue
        aggregator.add_meter(key, TimeMeter(round=round), priority)
        aggregator[key].reset()  # reset meter on the first call
def log_start_time(key: str, priority: int = 40, round: Optional[int] = None):
    """Start timing an event in every active aggregator.

    The duration will be computed once :func:`log_stop_time` is called.

    Args:
        key (str): name of the field to log
        priority (int): smaller values are logged earlier in the output
        round (Optional[int]): number of digits to round to when displaying
    """
    for aggregator in get_active_aggregators():
        # The stopwatch is created lazily, the first time the key is seen.
        if key not in aggregator:
            stopwatch = StopwatchMeter(round=round)
            aggregator.add_meter(key, stopwatch, priority)
        aggregator[key].start()
def log_stop_time(key: str, weight: float = 0.0, prehook=None):
    """Stop timing an event in every active aggregator that knows *key*.

    The duration will be computed since :func:`log_start_time` was called.
    Set weight > 0 to report the average time instead of the sum.

    Args:
        key (str): name of the field to log
        weight (float): weight that this time contributes to the average
        prehook (function, no arguments): will be called before the timer
            is stopped. For example, use prehook=torch.cuda.synchronize to
            make sure all gpu operations are done before timer is stopped.
    """
    for aggregator in get_active_aggregators():
        if key not in aggregator:
            # never started in this aggregator; nothing to stop
            continue
        aggregator[key].stop(weight, prehook)
def log_custom(
    new_meter_fn: Callable[[], Meter],
    key: str,
    *args,
    priority: int = 50,
    **kwargs,
):
    """Log through a caller-supplied Meter type.

    Any extra *args* or *kwargs* will be passed through to the Meter's
    *update* method.

    Args:
        new_meter_fn (Callable[[], Meter]): function that returns a new
            Meter instance
        key (str): name of the field to log
        priority (int): smaller values are logged earlier in the output
    """
    for aggregator in get_active_aggregators():
        # The factory is invoked once per aggregator that lacks the key.
        if key not in aggregator:
            custom_meter = new_meter_fn()
            aggregator.add_meter(key, custom_meter, priority)
        aggregator[key].update(*args, **kwargs)
def reset_meter(name: str, key: str) -> None:
    """Reset the meter stored under *name* / *key*, if there is one."""
    target = get_meter(name, key)
    if target is None:
        return
    target.reset()
def reset_meters(name: str) -> None:
    """Reset every meter of the aggregator called *name*, if it exists."""
    aggregator = get_meters(name)
    if aggregator is None:
        return
    aggregator.reset()
def get_meter(name: str, key: str) -> Meter:
    """Look up one Meter by aggregator *name* and meter *key*.

    Returns:
        Meter or None if no metrics have been logged under *name* and *key*.
    """
    if name in _aggregators:
        return _aggregators[name].get(key, None)
    return None
def get_meters(name: str) -> MetersDict:
    """Look up the whole aggregator registered under *name*.

    Returns:
        MetersDict or None if no metrics have been logged under *name*.
    """
    aggregator = _aggregators.get(name, None)
    return aggregator
def get_smoothed_value(name: str, key: str) -> float:
    """Return the smoothed value of meter *key* in aggregator *name*.

    Raises:
        KeyError: if no metrics have been logged under *name* and *key*.
    """
    aggregator = _aggregators[name]
    return aggregator.get_smoothed_value(key)
def get_smoothed_values(name: str) -> Dict[str, float]:
    """Return all smoothed values of the aggregator called *name*.

    Raises:
        KeyError: if no metrics have been logged under *name*.
    """
    aggregator = _aggregators[name]
    return aggregator.get_smoothed_values()
def state_dict():
    """Return an ordered mapping of aggregator name to its serialized state."""
    return OrderedDict(
        (agg_name, aggregator.state_dict())
        for agg_name, aggregator in _aggregators.items()
    )
def load_state_dict(state_dict):
    """Rebuild the aggregators listed in *state_dict*, replacing same-named ones."""
    for agg_name, agg_state in state_dict.items():
        restored = MetersDict()
        # Registered before loading, so the entry exists even if loading fails.
        _aggregators[agg_name] = restored
        restored.load_state_dict(agg_state)
def xla_metrics_report():
    """Print the torch_xla metrics report; do nothing if torch_xla is missing."""
    try:
        import torch_xla.debug.metrics as xla_metrics

        report = xla_metrics.metrics_report()
        print(report)
    except ImportError:
        return None
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Wrapper around various loggers and progress bars (e.g., tqdm).
"""
import atexit
import json
import logging
import os
import sys
from collections import OrderedDict
from contextlib import contextmanager
from numbers import Number
from typing import Optional
import torch
from .meters import AverageMeter, StopwatchMeter, TimeMeter
# Module-level logger shared by the progress bars below; rename_logger()
# temporarily changes its name while a tagged record is emitted.
logger = logging.getLogger(__name__)
def progress_bar(
    iterator,
    log_format: Optional[str] = None,
    log_interval: int = 100,
    log_file: Optional[str] = None,
    epoch: Optional[int] = None,
    prefix: Optional[str] = None,
    aim_repo: Optional[str] = None,
    aim_run_hash: Optional[str] = None,
    aim_param_checkpoint_dir: Optional[str] = None,
    tensorboard_logdir: Optional[str] = None,
    default_log_format: str = "tqdm",
    wandb_project: Optional[str] = None,
    wandb_run_name: Optional[str] = None,
    azureml_logging: Optional[bool] = False,
):
    """Build a progress bar around *iterator* and wrap it with extra loggers.

    Args:
        iterator: iterable to report progress for.
        log_format: one of ``"json"``, ``"none"``, ``"simple"``, ``"tqdm"``;
            falls back to *default_log_format* when ``None``. ``"tqdm"`` is
            downgraded to ``"simple"`` when stderr is not a TTY.
        log_interval: logging interval handed to the json/simple bars.
        log_file: if given, the module logger also writes to this file.
        epoch, prefix: forwarded to the progress bar.
        aim_repo, aim_run_hash, aim_param_checkpoint_dir: when *aim_repo* is
            set, wrap the bar in an ``AimProgressBarWrapper``.
        tensorboard_logdir: when set, wrap the bar for TensorBoard logging.
        default_log_format: format used when *log_format* is ``None``.
        wandb_project, wandb_run_name: when *wandb_project* is set, wrap the
            bar in a ``WandBProgressBarWrapper``.
        azureml_logging: when truthy, wrap the bar in an
            ``AzureMLProgressBarWrapper``.

    Returns:
        The (possibly wrapped) progress bar.

    Raises:
        ValueError: if the resolved log format is unknown.
    """
    if log_format is None:
        log_format = default_log_format
    if log_file is not None:
        # This function may run many times for the same log file. Attach a
        # file handler only once per target so records are not duplicated
        # and no extra file handles are opened.
        target = os.path.abspath(log_file)
        already_attached = any(
            isinstance(existing, logging.FileHandler)
            and existing.baseFilename == target
            for existing in logger.handlers
        )
        if not already_attached:
            handler = logging.FileHandler(filename=log_file)
            logger.addHandler(handler)

    if log_format == "tqdm" and not sys.stderr.isatty():
        log_format = "simple"

    if log_format == "json":
        bar = JsonProgressBar(iterator, epoch, prefix, log_interval)
    elif log_format == "none":
        bar = NoopProgressBar(iterator, epoch, prefix)
    elif log_format == "simple":
        bar = SimpleProgressBar(iterator, epoch, prefix, log_interval)
    elif log_format == "tqdm":
        bar = TqdmProgressBar(iterator, epoch, prefix)
    else:
        raise ValueError("Unknown log format: {}".format(log_format))

    if aim_repo:
        bar = AimProgressBarWrapper(
            bar,
            aim_repo=aim_repo,
            aim_run_hash=aim_run_hash,
            aim_param_checkpoint_dir=aim_param_checkpoint_dir,
        )

    if tensorboard_logdir:
        try:
            # [FB only] custom wrapper for TensorBoard
            import palaas  # noqa

            from .fb_tbmf_wrapper import FbTbmfWrapper

            bar = FbTbmfWrapper(bar, log_interval)
        except ImportError:
            bar = TensorboardProgressBarWrapper(bar, tensorboard_logdir)

    if wandb_project:
        bar = WandBProgressBarWrapper(bar, wandb_project, run_name=wandb_run_name)

    if azureml_logging:
        bar = AzureMLProgressBarWrapper(bar)

    return bar
def build_progress_bar(
    args,
    iterator,
    epoch: Optional[int] = None,
    prefix: Optional[str] = None,
    default: str = "tqdm",
    no_progress_bar: str = "none",
):
    """Legacy wrapper that takes an argparse.Namespace.

    ``args.no_progress_bar`` switches the default format to
    *no_progress_bar*, and TensorBoard logging is only enabled when
    ``args.distributed_rank`` is 0 (or absent).
    """
    fallback_format = default
    if getattr(args, "no_progress_bar", False):
        fallback_format = no_progress_bar

    is_rank_zero = getattr(args, "distributed_rank", 0) == 0
    tb_dir = getattr(args, "tensorboard_logdir", None) if is_rank_zero else None

    return progress_bar(
        iterator,
        log_format=args.log_format,
        log_interval=args.log_interval,
        epoch=epoch,
        prefix=prefix,
        tensorboard_logdir=tb_dir,
        default_log_format=fallback_format,
    )
def format_stat(stat):
    """Convert a logged statistic into something printable.

    Numbers and meters become formatted strings, tensors become plain
    Python lists, and anything else is returned unchanged.
    """
    if isinstance(stat, Number):
        return "{:g}".format(stat)
    if isinstance(stat, AverageMeter):
        return "{:.3f}".format(stat.avg)
    if isinstance(stat, TimeMeter):
        return "{:g}".format(round(stat.avg))
    if isinstance(stat, StopwatchMeter):
        return "{:g}".format(round(stat.sum))
    if torch.is_tensor(stat):
        return stat.tolist()
    return stat
class BaseProgressBar(object):
    """Shared plumbing for progress bars; subclasses supply iteration and output."""

    def __init__(self, iterable, epoch=None, prefix=None):
        self.iterable = iterable
        # Start offset, taken from the iterable when it exposes ``n``.
        self.n = getattr(iterable, "n", 0)
        self.epoch = epoch

        # Prefix reads "epoch NNN | <prefix>", omitting whichever part is absent.
        label = ""
        if epoch is not None:
            label += "epoch {:03d}".format(epoch)
        if prefix is not None:
            separator = " | " if label != "" else ""
            label += separator + prefix
        self.prefix = label

    def __len__(self):
        return len(self.iterable)

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        # Never suppress exceptions raised inside the with-block.
        return False

    def __iter__(self):
        raise NotImplementedError

    def log(self, stats, tag=None, step=None):
        """Log intermediate stats according to log_interval."""
        raise NotImplementedError

    def print(self, stats, tag=None, step=None):
        """Print end-of-epoch stats."""
        raise NotImplementedError

    def update_config(self, config):
        """Log latest configuration."""
        pass

    def _str_commas(self, stats):
        """Render already formatted stats as ``key=value, key=value``."""
        parts = [name + "=" + text.strip() for name, text in stats.items()]
        return ", ".join(parts)

    def _str_pipes(self, stats):
        """Render already formatted stats as ``key value | key value``."""
        parts = [name + " " + text.strip() for name, text in stats.items()]
        return " | ".join(parts)

    def _format_stats(self, stats):
        """Return an ordered copy of *stats* with every value turned into a string."""
        formatted = OrderedDict(stats)
        # Preprocess stats according to datatype
        for name, raw in list(formatted.items()):
            formatted[name] = str(format_stat(raw))
        return formatted
@contextmanager
def rename_logger(logger, new_name):
    """Temporarily give *logger* the name *new_name*.

    The original name is restored when the block exits, including when the
    block raises. A *new_name* of ``None`` leaves the name untouched.

    Args:
        logger: the logger whose ``name`` attribute is swapped.
        new_name: name to use inside the block, or ``None``.

    Yields:
        The same *logger* object.
    """
    old_name = logger.name
    if new_name is not None:
        logger.name = new_name
    try:
        yield logger
    finally:
        # Restore even on error so later records are not mislabelled.
        logger.name = old_name
class JsonProgressBar(BaseProgressBar):
    """Progress bar that emits one JSON object per log record."""

    def __init__(self, iterable, epoch=None, prefix=None, log_interval=1000):
        super().__init__(iterable, epoch, prefix)
        self.log_interval = log_interval
        self.i = None  # index of the item currently being processed
        self.size = None  # length of the iterable, known once iteration starts

    def __iter__(self):
        self.size = len(self.iterable)
        for index, item in enumerate(self.iterable, start=self.n):
            self.i = index
            yield item

    def log(self, stats, tag=None, step=None):
        """Log intermediate stats according to log_interval."""
        step = step or self.i or 0
        due = (
            step > 0
            and self.log_interval is not None
            and step % self.log_interval == 0
        )
        if not due:
            return
        # Fractional epoch position, e.g. halfway through epoch 3 gives 2.5.
        if self.epoch is not None:
            update = self.epoch - 1 + (self.i + 1) / float(self.size)
        else:
            update = None
        payload = self._format_stats(stats, epoch=self.epoch, update=update)
        with rename_logger(logger, tag):
            logger.info(json.dumps(payload))

    def print(self, stats, tag=None, step=None):
        """Print end-of-epoch stats."""
        self.stats = stats
        if tag is not None:
            # Prefix every key with the tag, keeping the original order.
            self.stats = OrderedDict(
                [(tag + "_" + name, value) for name, value in self.stats.items()]
            )
        payload = self._format_stats(self.stats, epoch=self.epoch)
        with rename_logger(logger, tag):
            logger.info(json.dumps(payload))

    def _format_stats(self, stats, epoch=None, update=None):
        """Build the JSON payload: optional epoch/update first, then the stats."""
        payload = OrderedDict()
        if epoch is not None:
            payload["epoch"] = epoch
        if update is not None:
            payload["update"] = round(update, 3)
        # Preprocess stats according to datatype
        for name in stats.keys():
            payload[name] = format_stat(stats[name])
        return payload
class NoopProgressBar(BaseProgressBar):
    """Progress bar that iterates normally but never logs anything."""

    def __init__(self, iterable, epoch=None, prefix=None):
        super().__init__(iterable, epoch, prefix)

    def __iter__(self):
        yield from self.iterable

    def log(self, stats, tag=None, step=None):
        """Log intermediate stats according to log_interval."""
        return None

    def print(self, stats, tag=None, step=None):
        """Print end-of-epoch stats."""
        return None
class SimpleProgressBar(BaseProgressBar):
    """Plain-text progress logging for output that is not a terminal."""

    def __init__(self, iterable, epoch=None, prefix=None, log_interval=1000):
        super().__init__(iterable, epoch, prefix)
        self.log_interval = log_interval
        self.i = None  # index of the item currently being processed
        self.size = None  # length of the iterable, known once iteration starts

    def __iter__(self):
        self.size = len(self.iterable)
        for index, item in enumerate(self.iterable, start=self.n):
            self.i = index
            yield item

    def log(self, stats, tag=None, step=None):
        """Log intermediate stats according to log_interval."""
        step = step or self.i or 0
        due = (
            step > 0
            and self.log_interval is not None
            and step % self.log_interval == 0
        )
        if not due:
            return
        formatted = self._format_stats(stats)
        postfix = self._str_commas(formatted)
        with rename_logger(logger, tag):
            line = "{}: {:5d} / {:d} {}".format(
                self.prefix, self.i + 1, self.size, postfix
            )
            logger.info(line)

    def print(self, stats, tag=None, step=None):
        """Print end-of-epoch stats."""
        formatted = self._format_stats(stats)
        postfix = self._str_pipes(formatted)
        with rename_logger(logger, tag):
            logger.info("{} | {}".format(self.prefix, postfix))
class TqdmProgressBar(BaseProgressBar):
    """Progress bar rendered by tqdm."""

    def __init__(self, iterable, epoch=None, prefix=None):
        super().__init__(iterable, epoch, prefix)
        from tqdm import tqdm

        # The bar is disabled when the logger would not emit INFO records.
        quiet = logger.getEffectiveLevel() > logging.INFO
        self.tqdm = tqdm(iterable, self.prefix, leave=False, disable=quiet)

    def __iter__(self):
        return iter(self.tqdm)

    def log(self, stats, tag=None, step=None):
        """Show the formatted stats as the tqdm postfix (no forced refresh)."""
        formatted = self._format_stats(stats)
        self.tqdm.set_postfix(formatted, refresh=False)

    def print(self, stats, tag=None, step=None):
        """Log the end-of-epoch stats as a pipe-separated line."""
        summary = self._str_pipes(self._format_stats(stats))
        with rename_logger(logger, tag):
            logger.info("{} | {}".format(self.prefix, summary))
# Aim is an optional dependency: when it cannot be imported, both names below
# are set to None and AimProgressBarWrapper logs a warning instead of tracking.
try:
    import functools

    from aim import Repo as AimRepo

    # Cached so repeated calls with the same (repo, run_hash) reuse one Run.
    @functools.lru_cache()
    def get_aim_run(repo, run_hash):
        """Return an Aim ``Run`` for ``run_hash`` inside ``repo``."""
        from aim import Run

        return Run(run_hash=run_hash, repo=repo)

except ImportError:
    get_aim_run = None
    AimRepo = None
class AimProgressBarWrapper(BaseProgressBar):
    """Wrap another progress bar and mirror its stats to an Aim run."""

    def __init__(self, wrapped_bar, aim_repo, aim_run_hash, aim_param_checkpoint_dir):
        """
        Args:
            wrapped_bar: progress bar that receives every call after Aim.
            aim_repo: Aim repository location.
            aim_run_hash: hash of the run to append to; when falsy, a run is
                looked up by its ``checkpoint.save_dir`` parameter.
            aim_param_checkpoint_dir: checkpoint directory used for that lookup.
        """
        self.wrapped_bar = wrapped_bar

        if get_aim_run is None:
            self.run = None
            logger.warning("Aim not found, please install with: pip install aim")
        else:
            logger.info(f"Storing logs at Aim repo: {aim_repo}")

            if not aim_run_hash:
                # Find run based on save_dir parameter
                query = f"run.checkpoint.save_dir == '{aim_param_checkpoint_dir}'"
                try:
                    runs_generator = AimRepo(aim_repo).query_runs(query)
                    run = next(runs_generator.iter_runs())
                    aim_run_hash = run.run.hash
                except Exception:
                    # Best-effort lookup: with no matching run (or a failing
                    # query) the hash stays empty and is passed on as-is.
                    pass

            if aim_run_hash:
                logger.info(f"Appending to run: {aim_run_hash}")

            self.run = get_aim_run(aim_repo, aim_run_hash)

    def __iter__(self):
        return iter(self.wrapped_bar)

    def log(self, stats, tag=None, step=None):
        """Log intermediate stats to Aim."""
        self._log_to_aim(stats, tag, step)
        self.wrapped_bar.log(stats, tag=tag, step=step)

    def print(self, stats, tag=None, step=None):
        """Print end-of-epoch stats."""
        self._log_to_aim(stats, tag, step)
        self.wrapped_bar.print(stats, tag=tag, step=step)

    def update_config(self, config):
        """Log latest configuration."""
        if self.run is not None:
            for key in config:
                self.run.set(key, config[key], strict=False)
        self.wrapped_bar.update_config(config)

    def _log_to_aim(self, stats, tag=None, step=None):
        """Track every stat except ``num_updates`` on the Aim run.

        ``step`` falls back to ``stats["num_updates"]``. The tracking context
        always carries the tag and gains a ``subset`` entry when the tag
        contains ``"train"`` or ``"val"``.
        """
        if self.run is None:
            return

        if step is None:
            step = stats["num_updates"]

        context = {"tag": tag}
        # ``tag`` defaults to None; substring tests on None would raise
        # TypeError, so only derive the subset from an actual tag.
        if tag is not None:
            if "train" in tag:
                context["subset"] = "train"
            elif "val" in tag:
                context["subset"] = "val"

        for key in stats.keys() - {"num_updates"}:
            self.run.track(stats[key], name=key, step=step, context=context)
# Tensorboard is optional: prefer the writer bundled with torch, fall back to
# tensorboardX, and leave SummaryWriter as None when neither is importable.
try:
    # Cache of writers keyed by tag; closed by the atexit hook below.
    _tensorboard_writers = {}
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    try:
        from tensorboardX import SummaryWriter
    except ImportError:
        SummaryWriter = None
@atexit.register
def _close_writers():
    """Close every cached tensorboard writer when the interpreter exits."""
    for writer in _tensorboard_writers.values():
        writer.close()
class TensorboardProgressBarWrapper(BaseProgressBar):
    """Wrap another progress bar and mirror its stats to tensorboard."""

    def __init__(self, wrapped_bar, tensorboard_logdir):
        self.wrapped_bar = wrapped_bar
        self.tensorboard_logdir = tensorboard_logdir

        if SummaryWriter is None:
            logger.warning(
                "tensorboard not found, please install with: pip install tensorboard"
            )

    def _writer(self, key):
        """Return the cached writer for ``key``, creating it on first use.

        Returns None when no SummaryWriter implementation is available.
        """
        if SummaryWriter is None:
            return None
        try:
            return _tensorboard_writers[key]
        except KeyError:
            writer = SummaryWriter(os.path.join(self.tensorboard_logdir, key))
            _tensorboard_writers[key] = writer
            # Record the command line once per freshly created writer.
            writer.add_text("sys.argv", " ".join(sys.argv))
            return writer

    def __iter__(self):
        return iter(self.wrapped_bar)

    def log(self, stats, tag=None, step=None):
        """Log intermediate stats to tensorboard."""
        self._log_to_tensorboard(stats, tag, step)
        self.wrapped_bar.log(stats, tag=tag, step=step)

    def print(self, stats, tag=None, step=None):
        """Print end-of-epoch stats."""
        self._log_to_tensorboard(stats, tag, step)
        self.wrapped_bar.print(stats, tag=tag, step=step)

    def update_config(self, config):
        """Log latest configuration."""
        # TODO add hparams to Tensorboard
        self.wrapped_bar.update_config(config)

    def _log_to_tensorboard(self, stats, tag=None, step=None):
        """Write scalar-like stats (except ``num_updates``) and flush."""
        writer = self._writer(tag or "")
        if writer is None:
            return
        if step is None:
            step = stats["num_updates"]
        for key in stats.keys() - {"num_updates"}:
            value = stats[key]
            if isinstance(value, AverageMeter):
                writer.add_scalar(key, value.val, step)
            elif isinstance(value, Number):
                writer.add_scalar(key, value, step)
            elif torch.is_tensor(value) and value.numel() == 1:
                writer.add_scalar(key, value.item(), step)
        writer.flush()
try:
import wandb
except ImportError:
wandb = None
class WandBProgressBarWrapper(BaseProgressBar):
    """Wrap another progress bar and mirror its stats to Weights & Biases."""

    def __init__(self, wrapped_bar, wandb_project, run_name=None):
        self.wrapped_bar = wrapped_bar
        if wandb is not None:
            # reinit=False to ensure if wandb.init() is called multiple times
            # within one process it still references the same run
            wandb.init(project=wandb_project, reinit=False, name=run_name)
        else:
            logger.warning("wandb not found, pip install wandb")

    def __iter__(self):
        return iter(self.wrapped_bar)

    def log(self, stats, tag=None, step=None):
        """Log intermediate stats to Weights & Biases."""
        self._log_to_wandb(stats, tag, step)
        self.wrapped_bar.log(stats, tag=tag, step=step)

    def print(self, stats, tag=None, step=None):
        """Print end-of-epoch stats."""
        self._log_to_wandb(stats, tag, step)
        self.wrapped_bar.print(stats, tag=tag, step=step)

    def update_config(self, config):
        """Log latest configuration."""
        if wandb is not None:
            wandb.config.update(config)
        self.wrapped_bar.update_config(config)

    def _log_to_wandb(self, stats, tag=None, step=None):
        """Send meter and numeric stats (except ``num_updates``) to wandb.

        Keys are prefixed with ``"<tag>/"`` when a tag is given; ``step``
        falls back to ``stats["num_updates"]``.
        """
        if wandb is None:
            return
        if step is None:
            step = stats["num_updates"]

        prefix = tag + "/" if tag is not None else ""

        for key in stats.keys() - {"num_updates"}:
            value = stats[key]
            if isinstance(value, AverageMeter):
                wandb.log({prefix + key: value.val}, step=step)
            elif isinstance(value, Number):
                wandb.log({prefix + key: value}, step=step)
try:
from azureml.core import Run
except ImportError:
Run = None
class AzureMLProgressBarWrapper(BaseProgressBar):
    """Wrap another progress bar and mirror its stats to Azure ML."""

    def __init__(self, wrapped_bar):
        self.wrapped_bar = wrapped_bar
        if Run is not None:
            self.run = Run.get_context()
        else:
            logger.warning("azureml.core not found, pip install azureml-core")

    def __exit__(self, *exc):
        """Mark the Azure ML run complete; exceptions are not suppressed."""
        if Run is not None:
            self.run.complete()
        return False

    def __iter__(self):
        return iter(self.wrapped_bar)

    def log(self, stats, tag=None, step=None):
        """Log intermediate stats to AzureML"""
        self._log_to_azureml(stats, tag, step)
        self.wrapped_bar.log(stats, tag=tag, step=step)

    def print(self, stats, tag=None, step=None):
        """Print end-of-epoch stats"""
        self._log_to_azureml(stats, tag, step)
        self.wrapped_bar.print(stats, tag=tag, step=step)

    def update_config(self, config):
        """Log latest configuration."""
        self.wrapped_bar.update_config(config)

    def _log_to_azureml(self, stats, tag=None, step=None):
        """Log meter and numeric stats (except ``num_updates``) as rows.

        Row names are prefixed with ``"<tag>/"`` when a tag is given; each
        row carries the step and the value under the stat's own key.
        """
        if Run is None:
            return
        if step is None:
            step = stats["num_updates"]

        prefix = tag + "/" if tag is not None else ""

        for key in stats.keys() - {"num_updates"}:
            value = stats[key]
            row_name = prefix + key
            if isinstance(value, AverageMeter):
                self.run.log_row(name=row_name, **{"step": step, key: value.val})
            elif isinstance(value, Number):
                self.run.log_row(name=row_name, **{"step": step, key: value})
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from . import criterions, models, modules # noqa
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import importlib
import os
# automatically import any Python files in the criterions/ directory
# Files are visited in sorted order; names starting with "_" are skipped.
for file in sorted(os.listdir(os.path.dirname(__file__))):
    if file.endswith(".py") and not file.startswith("_"):
        # Strip the ".py" extension to obtain the module name.
        module = file[: file.find(".py")]
        importlib.import_module("fairseq.model_parallel.criterions." + module)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
try:
from fairseq.model_parallel.megatron.mpu.cross_entropy import (
vocab_parallel_cross_entropy,
)
has_megatron_submodule = True
except (ImportError, ModuleNotFoundError):
has_megatron_submodule = False
@register_criterion("vocab_parallel_cross_entropy")
class VocabParallelCrossEntropyCriterion(FairseqCriterion):
    """Cross-entropy criterion built on megatron's vocab-parallel loss."""

    def __init__(self, task, sentence_avg):
        super().__init__(task)
        self.sentence_avg = sentence_avg
        if not has_megatron_submodule:
            raise ImportError(
                "\n\nPlease install the megatron submodule:"
                "\n\n  git submodule update --init "
                "fairseq/model_parallel/megatron"
            )

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        net_output = model(**sample["net_input"])
        target = sample["target"]

        token_loss = vocab_parallel_cross_entropy(net_output[0].float(), target)
        # Zero out the contribution of padding positions before summing.
        loss = (token_loss * (target != self.padding_idx)).sum()

        if self.sentence_avg:
            sample_size = sample["target"].size(0)
        else:
            sample_size = sample["ntokens"]

        logging_output = {
            "loss": utils.item(loss.data) if reduce else loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample["target"].size(0),
            "sample_size": sample_size,
        }
        return loss, sample_size, logging_output

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""

        def total(field):
            return sum(log.get(field, 0) for log in logging_outputs)

        loss_sum = total("loss")
        ntokens = total("ntokens")
        sample_size = total("sample_size")

        # Losses are reported in base 2.
        metrics.log_scalar(
            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
        )
        if sample_size == ntokens:
            metrics.log_derived(
                "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
            )
        else:
            metrics.log_scalar(
                "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
            )
            metrics.log_derived(
                "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
            )

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True improves distributed training speed.
        """
        return True
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Train a network across multiple GPUs.
"""
from fairseq.dataclass.configs import FairseqConfig
from fairseq.distributed import utils as distributed_utils
from fairseq.trainer import Trainer
try:
from fairseq.model_parallel.megatron.mpu import (
get_data_parallel_rank,
get_data_parallel_world_size,
get_model_parallel_src_rank,
get_cuda_rng_tracker,
)
has_megatron_submodule = True
except (ImportError, ModuleNotFoundError):
has_megatron_submodule = False
class MegatronTrainer(Trainer):
    """Trainer for combined model-parallel and data-parallel training."""

    def __init__(self, cfg: FairseqConfig, task, model, criterion, **kwargs):
        if not has_megatron_submodule:
            raise ImportError(
                "\n\nPlease install the megatron submodule:"
                "\n\n  git submodule update --init "
                "fairseq/model_parallel/megatron"
            )
        super().__init__(cfg, task, model, criterion, **kwargs)

    def clip_grad_norm(self, clip_norm):
        """Clip gradients using a norm aggregated over the model-parallel group."""

        def _aggregate_model_parallel_grad_norm(total_norm):
            # Reduce the squared norms across the group, then take the root.
            squared = total_norm ** 2
            distributed_utils.all_reduce(
                squared, group=distributed_utils.get_model_parallel_group()
            )
            return squared ** 0.5

        return self.optimizer.clip_grad_norm(
            clip_norm,
            aggregate_norm_fn=_aggregate_model_parallel_grad_norm,
        )

    def save_checkpoint(self, filename, extra_state):
        """Save all training state in a checkpoint file."""
        rng_states = get_cuda_rng_tracker().get_states()
        extra_state["rng_tracker_states"] = rng_states
        super().save_checkpoint(filename, extra_state)

    def load_checkpoint(
        self,
        filename,
        reset_optimizer=False,
        reset_lr_scheduler=False,
        optimizer_overrides=None,
        reset_meters=False,
    ):
        """Load a checkpoint and restore the saved RNG tracker states, if any."""
        extra_state = super().load_checkpoint(
            filename,
            reset_optimizer=reset_optimizer,
            reset_lr_scheduler=reset_lr_scheduler,
            optimizer_overrides=optimizer_overrides,
            reset_meters=reset_meters,
        )
        has_rng_states = (
            extra_state is not None and "rng_tracker_states" in extra_state
        )
        if has_rng_states:
            get_cuda_rng_tracker().set_states(extra_state["rng_tracker_states"])
        return extra_state
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import importlib
import os
# automatically import any Python files in the models/ directory
models_dir = os.path.dirname(__file__)
for file in os.listdir(models_dir):
    path = os.path.join(models_dir, file)
    # Consider ".py" files and sub-directories, skipping names that start
    # with "_" or ".".
    if (
        not file.startswith("_")
        and not file.startswith(".")
        and (file.endswith(".py") or os.path.isdir(path))
    ):
        # Drop the ".py" extension for files; directories are used as-is.
        model_name = file[: file.find(".py")] if file.endswith(".py") else file
        module = importlib.import_module("fairseq.model_parallel.models." + model_name)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .model import * # noqa
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
from collections import namedtuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import options, utils
from fairseq.modules import (
AdaptiveSoftmax,
LayerNorm,
MultiheadAttention,
PositionalEmbedding,
)
# Named container for encoder results; the per-field comments give the
# tensor layouts as annotated by the original authors.
EncoderOut = namedtuple(
    "TransformerEncoderOut",
    [
        "encoder_out",  # T x B x C
        "encoder_padding_mask",  # B x T
        "encoder_embedding",  # B x T x C
        "encoder_states",  # List[T x B x C]
    ],
)
class TransformerEncoderEmbedding(nn.Module):
    """Token embedding plus optional positional embedding for the encoder."""

    def __init__(self, args, embed_tokens):
        super().__init__()
        self.dropout = args.dropout
        self.max_source_positions = args.max_source_positions
        self.embed_tokens = embed_tokens
        if isinstance(embed_tokens, nn.ModuleList):
            # Sharded embedding: the parts are concatenated along the last dim.
            self.padding_idx = embed_tokens[0].padding_idx
            embed_dim = sum(part.embedding_dim for part in embed_tokens)
        else:
            self.padding_idx = embed_tokens.padding_idx
            embed_dim = embed_tokens.embedding_dim
        self.embed_scale = math.sqrt(embed_dim)

        if args.no_token_positional_embeddings:
            self.embed_positions = None
        else:
            self.embed_positions = PositionalEmbedding(
                args.max_source_positions,
                embed_dim,
                self.padding_idx,
                learned=args.encoder_learned_pos,
            )

        use_layernorm = getattr(args, "layernorm_embedding", False)
        self.layernorm_embedding = LayerNorm(embed_dim) if use_layernorm else None

    def forward(self, input):
        """Embed ``input[0]`` and pass ``input[2]`` through untouched.

        Returns ``(x, encoder_padding_mask, prev_output_tokens)`` where ``x``
        has its first two dimensions swapped relative to the embedding output.
        """
        src_tokens, prev_output_tokens = input[0], input[2]

        # embed tokens and positions
        if isinstance(self.embed_tokens, nn.ModuleList):
            pieces = [part(src_tokens) for part in self.embed_tokens]
            embedded = torch.cat(pieces, dim=-1)
        else:
            embedded = self.embed_tokens(src_tokens)

        x = embed = self.embed_scale * embedded
        if self.embed_positions is not None:
            x = embed + self.embed_positions(src_tokens)
        if self.layernorm_embedding:
            x = self.layernorm_embedding(x)
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # compute padding mask
        encoder_padding_mask = src_tokens.eq(self.padding_idx)
        return (x, encoder_padding_mask, prev_output_tokens)
class TransformerEncoderLayerNorm(nn.Module):
    """
    Optional layer norm applied after all encoder layers; it is only
    created when args.encoder_normalize_before is true.
    """

    def __init__(self, args, embed_dim):
        super().__init__()
        self.layer_norm = (
            LayerNorm(embed_dim) if args.encoder_normalize_before else None
        )

    def forward(self, input):
        """Normalize ``input[0]`` if configured; pass the other entries through."""
        x, encoder_padding_mask, prev_output_tokens = input[0], input[1], input[2]
        if self.layer_norm:
            x = self.layer_norm(x)
        # keeping track of the incremental_state is not supported yet
        return (x, encoder_padding_mask, prev_output_tokens)
class TransformerDecoderEmbedding(nn.Module):
    """Token embedding plus optional positional embedding for the decoder."""

    def __init__(self, args, embed_tokens):
        super().__init__()
        self.dropout = args.dropout
        self.share_input_output_embed = args.share_decoder_input_output_embed

        sharded = isinstance(embed_tokens, nn.ModuleList)
        if sharded:
            input_embed_dim = sum(part.embedding_dim for part in embed_tokens)
        else:
            input_embed_dim = embed_tokens.embedding_dim
        embed_dim = args.decoder_embed_dim
        self.output_embed_dim = args.decoder_output_dim

        if sharded:
            padding_idx = embed_tokens[0].padding_idx
        else:
            padding_idx = embed_tokens.padding_idx
        self.max_target_positions = args.max_target_positions

        self.embed_tokens = embed_tokens
        self.embed_scale = math.sqrt(embed_dim)  # todo: try with input_embed_dim

        # Only project when the token embedding width differs from the model width.
        if embed_dim != input_embed_dim:
            self.project_in_dim = Linear(input_embed_dim, embed_dim, bias=False)
        else:
            self.project_in_dim = None

        if args.no_token_positional_embeddings:
            self.embed_positions = None
        else:
            self.embed_positions = PositionalEmbedding(
                args.max_target_positions,
                embed_dim,
                padding_idx,
                learned=args.decoder_learned_pos,
            )

    def forward(self, input):
        """Embed the previous output tokens.

        ``input`` is either a 3-tuple ``(encoder_out, encoder_padding_mask,
        prev_output_tokens)``, a tuple of another length whose first entry
        holds the tokens, or the token tensor itself. For the 3-tuple form
        the result is ``(x, encoder_out, encoder_padding_mask)``; otherwise
        only ``x`` is returned.
        """
        encoder_out = None
        encoder_padding_mask = None
        mt_task = isinstance(input, tuple) and len(input) == 3
        if mt_task:
            encoder_out = input[0]
            encoder_padding_mask = input[1]
            prev_output_tokens = input[2]
        elif isinstance(input, tuple):
            # HACK for now, need to fix (TODO sidgoyal)
            # only the first entry is used; "src_lengths" is discarded
            prev_output_tokens = input[0]
        else:
            prev_output_tokens = input

        # Incremental decoding is never active here, so the positional
        # embedding is always called with incremental_state=None and no
        # last-step slicing takes place.
        positions = None
        if self.embed_positions is not None:
            positions = self.embed_positions(
                prev_output_tokens,
                incremental_state=None,
            )

        # embed tokens and positions
        if isinstance(self.embed_tokens, nn.ModuleList):
            pieces = [part(prev_output_tokens) for part in self.embed_tokens]
            x = self.embed_scale * torch.cat(pieces, dim=-1)
        else:
            x = self.embed_scale * self.embed_tokens(prev_output_tokens)

        if self.project_in_dim is not None:
            x = self.project_in_dim(x)

        if positions is not None:
            x += positions
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        return (x, encoder_out, encoder_padding_mask) if mt_task else x
class TransformerDecoderOutputLayer(nn.Module):
    """Final decoder stage: optional layer norm, projection and vocab logits."""

    def __init__(self, args, embed_tokens, dictionary):
        super().__init__()
        self.share_input_output_embed = args.share_decoder_input_output_embed
        self.embed_tokens = embed_tokens
        self.output_embed_dim = args.decoder_output_dim
        embed_dim = args.decoder_embed_dim

        needs_projection = (
            embed_dim != self.output_embed_dim and not args.tie_adaptive_weights
        )
        if needs_projection:
            self.project_out_dim = Linear(embed_dim, self.output_embed_dim, bias=False)
        else:
            self.project_out_dim = None

        self.adaptive_softmax = None
        if args.adaptive_softmax_cutoff is not None:
            assert not isinstance(embed_tokens, nn.ModuleList)
            self.adaptive_softmax = AdaptiveSoftmax(
                len(dictionary),
                self.output_embed_dim,
                options.eval_str_list(args.adaptive_softmax_cutoff, type=int),
                dropout=args.adaptive_softmax_dropout,
                adaptive_inputs=embed_tokens if args.tie_adaptive_weights else None,
                factor=args.adaptive_softmax_factor,
                tie_proj=args.tie_adaptive_proj,
            )
        elif not self.share_input_output_embed:
            # Untied output embedding: a separate weight matrix replaces the
            # input embedding module stored above.
            self.embed_tokens = nn.Parameter(
                torch.Tensor(len(dictionary), self.output_embed_dim)
            )
            nn.init.normal_(
                self.embed_tokens, mean=0, std=self.output_embed_dim**-0.5
            )

        wants_final_norm = args.decoder_normalize_before and not getattr(
            args, "no_decoder_final_norm", False
        )
        self.layer_norm = LayerNorm(embed_dim) if wants_final_norm else None

    def forward(self, input, apply_final_proj=True):
        """Run the output stage on ``input`` (a tensor, or a tuple whose first
        entry is the tensor); skip the vocabulary projection when
        ``apply_final_proj`` is false."""
        x = input[0] if isinstance(input, tuple) else input

        if self.layer_norm:
            x = self.layer_norm(x)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        if self.project_out_dim is not None:
            x = self.project_out_dim(x)
        return self.output_layer(x) if apply_final_proj else x

    def output_layer(self, features, **kwargs):
        """Project features to the vocabulary size."""
        if self.adaptive_softmax is not None:
            return features

        # project back to size of vocabulary
        if not self.share_input_output_embed:
            return F.linear(features, self.embed_tokens)
        if not isinstance(self.embed_tokens, nn.ModuleList):
            return F.linear(features, self.embed_tokens.weight)

        # Sharded embedding: each shard projects its own slice of the
        # feature dimension and the partial results are accumulated.
        output = None
        for index, shard in enumerate(self.embed_tokens):
            start = index * shard.embedding_dim
            stop = (index + 1) * shard.embedding_dim
            partial = F.linear(features[:, :, start:stop], shard.weight)
            if output is None:
                output = partial
            else:
                output += partial
        return output
class TransformerEncoderLayer(nn.Module):
    """Encoder layer block.

    Each sub-block (self-attention, then feed-forward) is followed by
    dropout and a residual connection. By default layer norm is applied
    after the residual, as in the original paper; setting
    *args.encoder_normalize_before* to ``True`` applies it before the
    sub-block instead (the tensor2tensor variant).

    Args:
        args (argparse.Namespace): parsed command-line arguments
    """

    def __init__(self, args):
        super().__init__()
        self.embed_dim = args.encoder_embed_dim
        self.self_attn = MultiheadAttention(
            self.embed_dim,
            args.encoder_attention_heads,
            dropout=args.attention_dropout,
            self_attention=True,
        )
        self.self_attn_layer_norm = LayerNorm(self.embed_dim)
        self.dropout = args.dropout
        self.activation_fn = utils.get_activation_fn(
            activation=getattr(args, "activation_fn", "relu")
        )
        activation_dropout = getattr(args, "activation_dropout", 0)
        if activation_dropout == 0:
            # for backwards compatibility with models that use args.relu_dropout
            activation_dropout = getattr(args, "relu_dropout", 0)
        self.activation_dropout = activation_dropout
        self.normalize_before = args.encoder_normalize_before
        self.fc1 = Linear(self.embed_dim, args.encoder_ffn_embed_dim)
        self.fc2 = Linear(args.encoder_ffn_embed_dim, self.embed_dim)
        self.final_layer_norm = LayerNorm(self.embed_dim)

    def upgrade_state_dict_named(self, state_dict, name):
        """
        Rename layer norm states from `...layer_norms.0.weight` to
        `...self_attn_layer_norm.weight` and `...layer_norms.1.weight` to
        `...final_layer_norm.weight`
        """
        renames = (("0", "self_attn_layer_norm"), ("1", "final_layer_norm"))
        for old_index, new_name in renames:
            for suffix in ("weight", "bias"):
                old_key = "{}.layer_norms.{}.{}".format(name, old_index, suffix)
                if old_key not in state_dict:
                    continue
                new_key = "{}.{}.{}".format(name, new_name, suffix)
                state_dict[new_key] = state_dict.pop(old_key)

    def forward(self, input):
        """
        Args:
            input (Tuple):
                input[0] (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
                input[1] (ByteTensor/FloatTensor): encoder padding mask -
                    binary ByteTensor of shape `(batch, src_len)` where padding elements
                    are indicated by ``1``.
                input[2] (LongTensor): previous decoder outputs of shape
                    `(batch, tgt_len)`, for teacher forcing)

        Returns:
            output (Tuple):
                output[0] (Tensor): encoded output of shape `(batch, src_len, embed_dim)`
                output[1] (ByteTensor/FloatTensor): encoder padding mask
                output[2] (LongTensor): previous decoder outputs
        """
        x, encoder_padding_mask, prev_output_tokens = input[0], input[1], input[2]

        # Self-attention sub-block.
        shortcut = x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
        x, _ = self.self_attn(
            query=x, key=x, value=x, key_padding_mask=encoder_padding_mask
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = shortcut + x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True)

        # Feed-forward sub-block.
        shortcut = x
        x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
        x = self.activation_fn(self.fc1(x))
        x = F.dropout(x, p=self.activation_dropout, training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = shortcut + x
        x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)
        return (x, encoder_padding_mask, prev_output_tokens)

    def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
        """Apply ``layer_norm`` only at the position selected by
        ``self.normalize_before``; exactly one of ``before``/``after`` must be set."""
        assert before ^ after
        if not (after ^ self.normalize_before):
            return x
        return layer_norm(x)
class TransformerDecoderLayer(nn.Module):
    """Decoder layer block.

    In the original paper each operation (multi-head attention, encoder
    attention or FFN) is postprocessed with: `dropout -> add residual ->
    layernorm`. In the tensor2tensor code they suggest that learning is more
    robust when preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but the
    tensor2tensor approach can be enabled by setting
    *args.decoder_normalize_before* to ``True``.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
        add_bias_kv (bool, optional): forwarded to the self-attention module
            (default: False).
        add_zero_attn (bool, optional): forwarded to the self-attention module
            (default: False).
    """

    def __init__(
        self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False
    ):
        super().__init__()
        self.embed_dim = args.decoder_embed_dim
        self.self_attn = MultiheadAttention(
            embed_dim=self.embed_dim,
            num_heads=args.decoder_attention_heads,
            dropout=args.attention_dropout,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
            self_attention=True,
        )
        self.dropout = args.dropout
        self.activation_fn = utils.get_activation_fn(
            activation=getattr(args, "activation_fn", "relu")
        )
        self.activation_dropout = getattr(args, "activation_dropout", 0)
        if self.activation_dropout == 0:
            # for backwards compatibility with models that use args.relu_dropout
            self.activation_dropout = getattr(args, "relu_dropout", 0)
        self.normalize_before = args.decoder_normalize_before

        # use layerNorm rather than FusedLayerNorm for exporting.
        # char_inputs can be used to determine this.
        # TODO  remove this once we update apex with the fix
        export = getattr(args, "char_inputs", False)
        self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

        if no_encoder_attn:
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None
        else:
            self.encoder_attn = MultiheadAttention(
                self.embed_dim,
                args.decoder_attention_heads,
                kdim=getattr(args, "encoder_embed_dim", None),
                vdim=getattr(args, "encoder_embed_dim", None),
                dropout=args.attention_dropout,
                encoder_decoder_attention=True,
            )
            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

        self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
        self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)

        self.final_layer_norm = LayerNorm(self.embed_dim, export=export)
        # Whether encoder-attention weights are requested at inference time;
        # changed through make_generation_fast_.
        self.need_attn = True

        self.onnx_trace = False

    def prepare_for_onnx_export_(self):
        """Flag this layer as being traced for ONNX export."""
        self.onnx_trace = True

    def forward(self, input):
        """
        Args:
            input (Tuple or Tensor): either the layer input alone, or a tuple with
                input[0] (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
                input[1] (Tensor): encoder output of shape `(batch, src_len, embed_dim)`
                input[2] (ByteTensor/FloatTensor): encoder padding mask -
                    binary ByteTensor of shape `(batch, src_len)` where padding elements
                    are indicated by ``1``.

        Returns:
            For tuple input, a tuple ``(x, encoder_out, encoder_padding_mask)``
            where ``x`` is the layer output and the other two entries are the
            inputs passed through unchanged; for tensor input, ``x`` alone.
        """
        # Note: incremental state is not yet supported
        mt_task = False
        if isinstance(input, tuple):
            x = input[0]
            encoder_out = input[1]
            encoder_padding_mask = input[2]
            incremental_state = None
            mt_task = True
        else:
            x = input
            encoder_out = None
            encoder_padding_mask = None
            incremental_state = None

        # incremental_state is always None above, so the future mask is
        # always built.
        if incremental_state is None:
            self_attn_mask = self.buffered_future_mask(x)
        else:
            self_attn_mask = None

        # TODO: add back prev_self_attn_state, prev_attn_state,
        # self_attn_padding_mask
        prev_self_attn_state = None
        prev_attn_state = None
        self_attn_padding_mask = None

        # --- self-attention sub-block ---
        residual = x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
        # Currently unreachable: prev_self_attn_state is hard-coded to None.
        if prev_self_attn_state is not None:
            if incremental_state is None:
                incremental_state = {}
            prev_key, prev_value = prev_self_attn_state
            saved_state = {"prev_key": prev_key, "prev_value": prev_value}
            self.self_attn._set_input_buffer(incremental_state, saved_state)
        x, attn = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            need_weights=False,
            attn_mask=self_attn_mask,
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True)

        # --- encoder-attention sub-block (skipped when no_encoder_attn) ---
        if self.encoder_attn is not None:
            residual = x
            x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, before=True)
            # Currently unreachable: prev_attn_state is hard-coded to None.
            if prev_attn_state is not None:
                if incremental_state is None:
                    incremental_state = {}
                prev_key, prev_value = prev_attn_state
                saved_state = {"prev_key": prev_key, "prev_value": prev_value}
                self.encoder_attn._set_input_buffer(incremental_state, saved_state)
            x, attn = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=(not self.training and self.need_attn),
            )
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = residual + x
            x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, after=True)

        # --- feed-forward sub-block ---
        residual = x
        x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
        x = self.activation_fn(self.fc1(x))
        x = F.dropout(x, p=self.activation_dropout, training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)

        if mt_task:
            return (x, encoder_out, encoder_padding_mask)
        return x

    def buffered_future_mask(self, tensor):
        """Return a ``dim x dim`` upper-triangular mask for ``tensor``.

        ``dim`` is ``tensor.size(0)``. The mask is cached on the module and
        rebuilt when missing, on a different device, or too small; entries
        above the diagonal are filled via ``utils.fill_with_neg_inf``.
        """
        dim = tensor.size(0)
        if (
            not hasattr(self, "_future_mask")
            or self._future_mask is None
            or self._future_mask.device != tensor.device
        ):
            self._future_mask = torch.triu(
                utils.fill_with_neg_inf(tensor.new(dim, dim)), 1
            )
        if self._future_mask.size(0) < dim:
            # Cached mask is too small: grow it in place and refill.
            self._future_mask = torch.triu(
                utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1
            )
        return self._future_mask[:dim, :dim]

    def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
        """Apply ``layer_norm`` to ``x`` only at the configured position.

        Exactly one of ``before``/``after`` must be true. The norm runs for
        ``before=True`` when ``self.normalize_before`` is set, and for
        ``after=True`` otherwise.
        """
        assert before ^ after
        if after ^ self.normalize_before:
            return layer_norm(x)
        else:
            return x

    def make_generation_fast_(self, need_attn=False, **kwargs):
        """Set whether encoder-attention weights are requested (``need_attn``)."""
        self.need_attn = need_attn
def Embedding(num_embeddings, embedding_dim, padding_idx):
    """Create an ``nn.Embedding`` with fairseq-style initialization.

    Weights are drawn from a normal distribution with standard deviation
    ``embedding_dim ** -0.5``; the padding row, if any, is zeroed.

    Args:
        num_embeddings: size of the vocabulary.
        embedding_dim: width of each embedding vector.
        padding_idx: index of the padding row, or ``None`` for no padding row.

    Returns:
        The initialized ``nn.Embedding`` module.
    """
    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    nn.init.normal_(m.weight, mean=0, std=embedding_dim**-0.5)
    # Indexing with None yields a view of the *whole* table, so zeroing it
    # would wipe every embedding; only zero an actual padding row.
    if padding_idx is not None:
        nn.init.constant_(m.weight[padding_idx], 0)
    return m
def Linear(in_features, out_features, bias=True):
    """Create an ``nn.Linear`` with Xavier-uniform weights and a zero bias."""
    layer = nn.Linear(in_features, out_features, bias)
    nn.init.xavier_uniform_(layer.weight)
    if not bias:
        return layer
    nn.init.constant_(layer.bias, 0.0)
    return layer
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
from fairseq.model_parallel.models.pipeline_parallel_transformer.layers import (
Embedding,
TransformerDecoderEmbedding,
TransformerDecoderLayer,
TransformerDecoderOutputLayer,
TransformerEncoderEmbedding,
TransformerEncoderLayer,
TransformerEncoderLayerNorm,
)
from fairseq.models import (
BaseFairseqModel,
FairseqDecoder,
FairseqEncoder,
register_model,
register_model_architecture,
)
from fairseq.models.fairseq_encoder import EncoderOut
from fairseq.models.transformer import (
base_architecture,
transformer_iwslt_de_en,
transformer_wmt_en_de_big,
)
from fairseq.modules import SinusoidalPositionalEmbedding
# Module-level logger.
logger = logging.getLogger(__name__)
# Fallback position limits applied by build_model_base() when ``args`` does
# not define max_source_positions / max_target_positions.
DEFAULT_MAX_SOURCE_POSITIONS = 1024
DEFAULT_MAX_TARGET_POSITIONS = 1024
# Set to True by import_pipe() once torch's Pipe has been imported.
TORCH_PIPE = False
# Set to True by import_pipe() after the single-process RPC agent is initialised.
RPC_INIT = False
def import_pipe():
    """Import a ``Pipe`` implementation and expose it as a module global.

    torch's ``torch.distributed.pipeline.sync.Pipe`` (plus ``partition_model``)
    is tried first; on success ``TORCH_PIPE`` is set and, once per process, a
    single-worker RPC agent is initialised.  If any of the torch imports fail,
    fairscale's ``Pipe`` is imported instead.

    Raises:
        ImportError: if neither torch's pipe nor fairscale is importable.
    """
    # Every ``global`` declaration has to precede the first binding of the
    # name in this function; declaring ``Pipe`` global after importing it is
    # a SyntaxError ("name 'Pipe' is assigned to before global declaration").
    global TORCH_PIPE
    global RPC_INIT
    global Pipe
    global partition_model
    try:
        from torch.distributed.pipeline.sync import Pipe  # noqa
        from torch.distributed.pipeline.sync.utils import partition_model
        from torch.distributed import rpc
        import tempfile

        TORCH_PIPE = True
        # Initialize single process RPC agent since TORCH_PIPE requires
        # RRef. RRef depends on RPC being initialized and as a result we initialize
        # RPC with a single node.
        if not RPC_INIT:
            # Only create the rendezvous file when RPC actually needs it.
            tmpfile = tempfile.NamedTemporaryFile()
            rpc.init_rpc(
                name="worker",
                rank=0,
                world_size=1,
                rpc_backend_options=rpc.TensorPipeRpcBackendOptions(
                    init_method="file://{}".format(tmpfile.name),
                ),
            )
            RPC_INIT = True
        logger.info("Using torch pipe")
    except ImportError:
        try:
            from fairscale.nn import Pipe  # noqa

            logger.info("Using fairscale pipe")
        except ImportError:
            raise ImportError("Please install fairscale with: pip install fairscale")
@register_model("pipeline_parallel_transformer")
class PipelineParallelTransformerModel(BaseFairseqModel):
def __init__(self, encoder, decoder, balance, devices, chunks, checkpoint):
    """Chain the encoder and decoder sub-modules into a single ``Pipe``.

    Args:
        encoder: a ``FairseqEncoder`` exposing ``embedding_layer``,
            ``encoder_layers`` and ``final_layer_norm``.
        decoder: a ``FairseqDecoder`` exposing ``embedding_layer``,
            ``decoder_layers`` and ``decoder_output_layer``.
        balance, devices, chunks, checkpoint: forwarded to ``Pipe``
            (via ``partition_model`` when torch's pipe is in use).
    """
    import_pipe()
    super().__init__()
    assert isinstance(encoder, FairseqEncoder)
    assert isinstance(decoder, FairseqDecoder)

    encoder_modules = [
        encoder.embedding_layer,
        *encoder.encoder_layers,
        encoder.final_layer_norm,
    ]
    self.num_encoder_modules = len(encoder_modules)
    decoder_modules = [
        decoder.embedding_layer,
        *decoder.decoder_layers,
        decoder.decoder_output_layer,
    ]
    self.num_decoder_modules = len(decoder_modules)
    self.devices = devices

    # Encoder modules first, then decoder modules, in one sequential stack.
    stacked = nn.Sequential(*encoder_modules, *decoder_modules)
    if TORCH_PIPE:
        self.model = Pipe(
            partition_model(stacked, balance, devices),
            chunks=chunks,
            checkpoint=checkpoint,
        )
    else:
        self.model = Pipe(
            stacked,
            balance=balance,
            devices=devices,
            chunks=chunks,
            checkpoint=checkpoint,
        )

    self.encoder_max_positions = self.max_positions_helper(
        encoder.embedding_layer, "max_source_positions"
    )
    self.decoder_max_positions = self.max_positions_helper(
        decoder.embedding_layer, "max_target_positions"
    )
    self.adaptive_softmax = getattr(decoder, "adaptive_softmax", None)
    # Note: To be populated during inference
    self.encoder = None
    self.decoder = None
def forward(self, src_tokens, src_lengths, prev_output_tokens):
    """Run the model on one batch.

    In training mode the three inputs are moved to ``self.devices[0]`` and
    pushed through the pipelined ``self.model``.  In evaluation mode the
    separate ``self.encoder`` / ``self.decoder`` built by
    ``prepare_for_inference_()`` are used instead.

    Args:
        src_tokens: source tokens.
        src_lengths: source lengths.
        prev_output_tokens: previous decoder outputs.

    Returns:
        The output of the pipeline (training) or of the decoder (evaluation).

    Raises:
        AssertionError: in evaluation mode when ``prepare_for_inference_()``
            has not been called yet.
    """
    if self.training:
        input_lst = [src_tokens, src_lengths, prev_output_tokens]
        pipeline_input = tuple(
            i.to(self.devices[0], non_blocking=True) for i in input_lst
        )
        if TORCH_PIPE:
            return self.model(pipeline_input).local_value()
        else:
            return self.model(pipeline_input)
    else:
        assert self.encoder is not None and self.decoder is not None, (
            "encoder and decoder need to be initialized by "
            + "calling the `prepare_for_inference_()` method"
        )
        # The previous code passed the builtin ``input`` to the encoder and
        # the encoder output as ``prev_output_tokens`` to the decoder; call
        # both with the arguments their ``forward`` signatures expect.
        encoder_out = self.encoder(src_tokens, src_lengths)
        return self.decoder(prev_output_tokens, encoder_out)
def prepare_for_inference_(self, cfg):
    """Split the training pipeline into separate encoder and decoder pipelines.

    The modules of ``self.model.partitions`` are flattened in order; the first
    ``self.num_encoder_modules`` go to a new ``TransformerEncoder`` and the
    rest to a new ``TransformerDecoder``.  ``self.model`` is dropped.  Does
    nothing (besides logging) when both are already set.
    """
    if not (self.encoder is None or self.decoder is None):
        logger.info("Encoder and Decoder already initialized")
        return

    flat_modules = [
        module for partition in self.model.partitions for module in partition
    ]
    split_at = self.num_encoder_modules
    encoder_module_list = flat_modules[:split_at]
    decoder_module_list = flat_modules[split_at:]

    self.model = None
    self.encoder = TransformerEncoder(
        cfg.distributed_training, None, None, encoder_module_list
    )
    self.decoder = TransformerDecoder(
        cfg.distributed_training,
        None,
        None,
        decoder_module_list=decoder_module_list,
    )
@staticmethod
def add_args(parser):
    """Add model-specific arguments to the parser.

    Args:
        parser: an ``argparse``-style parser; arguments are registered on it
            in place.
    """
    # fmt: off
    parser.add_argument('--activation-fn',
                        choices=utils.get_available_activation_fns(),
                        help='activation function to use')
    parser.add_argument('--dropout', type=float, metavar='D',
                        help='dropout probability')
    parser.add_argument('--attention-dropout', type=float, metavar='D',
                        help='dropout probability for attention weights')
    parser.add_argument('--activation-dropout', '--relu-dropout', type=float, metavar='D',
                        help='dropout probability after activation in FFN.')
    parser.add_argument('--encoder-embed-path', type=str, metavar='STR',
                        help='path to pre-trained encoder embedding')
    parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
                        help='encoder embedding dimension')
    parser.add_argument('--encoder-ffn-embed-dim', type=int, metavar='N',
                        help='encoder embedding dimension for FFN')
    parser.add_argument('--encoder-layers', type=int, metavar='N',
                        help='num encoder layers')
    parser.add_argument('--encoder-attention-heads', type=int, metavar='N',
                        help='num encoder attention heads')
    parser.add_argument('--encoder-normalize-before', action='store_true',
                        help='apply layernorm before each encoder block')
    parser.add_argument('--encoder-learned-pos', action='store_true',
                        help='use learned positional embeddings in the encoder')
    parser.add_argument('--decoder-embed-path', type=str, metavar='STR',
                        help='path to pre-trained decoder embedding')
    parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
                        help='decoder embedding dimension')
    parser.add_argument('--decoder-ffn-embed-dim', type=int, metavar='N',
                        help='decoder embedding dimension for FFN')
    parser.add_argument('--decoder-layers', type=int, metavar='N',
                        help='num decoder layers')
    parser.add_argument('--decoder-attention-heads', type=int, metavar='N',
                        help='num decoder attention heads')
    parser.add_argument('--decoder-learned-pos', action='store_true',
                        help='use learned positional embeddings in the decoder')
    parser.add_argument('--decoder-normalize-before', action='store_true',
                        help='apply layernorm before each decoder block')
    parser.add_argument('--share-decoder-input-output-embed', action='store_true',
                        help='share decoder input and output embeddings')
    parser.add_argument('--share-all-embeddings', action='store_true',
                        help='share encoder, decoder and output embeddings'
                             ' (requires shared dictionary and embed dim)')
    parser.add_argument('--no-token-positional-embeddings', default=False, action='store_true',
                        help='if set, disables positional embeddings (outside self attention)')
    # (a stray trailing comma used to turn the next call into a 1-tuple expression)
    parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
                        help='comma separated list of adaptive softmax cutoff points. '
                             'Must be used with adaptive_loss criterion')
    parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D',
                        help='sets adaptive softmax dropout for the tail projections')
    # Adjacent string literals concatenate without a separator, so each
    # fragment carries its own trailing space.
    parser.add_argument('--num-embedding-chunks', type=int, metavar='N', default=1,
                        help='Number of embedding layer chunks (enables more even distribution '
                             'of optimizer states across data parallel nodes '
                             'when using optimizer state sharding and '
                             'a big embedding vocabulary)')
    # fmt: on
@classmethod
def build_model_base(cls, args, task):
    """Build the encoder and decoder for a new model instance.

    Fills in missing architecture defaults on *args*, builds the (optionally
    chunked and/or shared) token embeddings, and hands them to
    ``cls.build_encoder`` / ``cls.build_decoder``.

    Args:
        args: parsed arguments; mutated in place with defaults.
        task: task providing ``source_dictionary`` and ``target_dictionary``.

    Returns:
        ``(encoder, decoder)``.

    Raises:
        ValueError: if ``--share-all-embeddings`` is set with incompatible
            dictionaries, embedding dimensions or decoder embedding path.
    """
    # make sure all arguments are present in older models
    base_architecture(args)

    if not hasattr(args, "max_source_positions"):
        args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
    if not hasattr(args, "max_target_positions"):
        args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS

    src_dict, tgt_dict = task.source_dictionary, task.target_dictionary

    def build_embedding(dictionary, embed_dim, path=None, num_embed_chunks=1):
        # The message previously stated the relation backwards; it is the
        # embedding dimension that must be divisible by the chunk count.
        assert embed_dim % num_embed_chunks == 0, (
            f"Embedding dimension = {embed_dim} should be "
            + f"divisible by the number of embedding chunks = {num_embed_chunks}"
        )
        assert path is None or num_embed_chunks == 1, (
            "Loading embedding from a path with number of embedding chunks > 1"
            + " is not yet supported"
        )
        num_embeddings = len(dictionary)
        padding_idx = dictionary.pad()
        # if provided, load from preloaded dictionaries
        if path:
            emb = Embedding(num_embeddings, embed_dim, padding_idx)
            embed_dict = utils.parse_embedding(path)
            utils.load_embedding(embed_dict, dictionary, emb)
        else:
            # One embedding table per chunk, each covering an equal slice
            # of the embedding dimension.
            embed_chunk_dim = embed_dim // num_embed_chunks
            emb = nn.ModuleList(
                Embedding(num_embeddings, embed_chunk_dim, padding_idx)
                for _ in range(num_embed_chunks)
            )
        return emb

    num_embed_chunks = args.num_embedding_chunks
    if args.share_all_embeddings:
        if src_dict != tgt_dict:
            raise ValueError("--share-all-embeddings requires a joined dictionary")
        if args.encoder_embed_dim != args.decoder_embed_dim:
            raise ValueError(
                "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
            )
        if args.decoder_embed_path and (
            args.decoder_embed_path != args.encoder_embed_path
        ):
            raise ValueError(
                "--share-all-embeddings not compatible with --decoder-embed-path"
            )
        encoder_embed_tokens = build_embedding(
            src_dict,
            args.encoder_embed_dim,
            args.encoder_embed_path,
            num_embed_chunks,
        )
        decoder_embed_tokens = encoder_embed_tokens
        args.share_decoder_input_output_embed = True
    else:
        assert args.share_decoder_input_output_embed or num_embed_chunks == 1, (
            "Not sharing decoder I/O embeddings is not yet supported with number of "
            + "embedding chunks > 1"
        )
        encoder_embed_tokens = build_embedding(
            src_dict,
            args.encoder_embed_dim,
            args.encoder_embed_path,
            num_embed_chunks,
        )
        decoder_embed_tokens = build_embedding(
            tgt_dict,
            args.decoder_embed_dim,
            args.decoder_embed_path,
            num_embed_chunks,
        )

    encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens)
    decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens)
    return (encoder, decoder)
@classmethod
def build_encoder(cls, args, src_dict, embed_tokens):
    """Return a :class:`TransformerEncoder` built from *args*, *src_dict* and *embed_tokens*."""
    return TransformerEncoder(args, src_dict, embed_tokens)
@classmethod
def build_decoder(cls, args, tgt_dict, embed_tokens):
    """Return a :class:`TransformerDecoder` built from *args*, *tgt_dict* and *embed_tokens*."""
    return TransformerDecoder(args, tgt_dict, embed_tokens)
@classmethod
def build_model(cls, args, task):
    """Build a ``PipelineParallelTransformerModel`` from *args* and *task*.

    The encoder and decoder come from ``cls.build_model_base``; the pipeline
    balance and devices are parsed from their string form in *args* as lists
    of ints.
    """
    encoder, decoder = cls.build_model_base(args, task)
    pipeline_options = dict(
        balance=utils.eval_str_list(args.pipeline_balance, type=int),
        devices=utils.eval_str_list(args.pipeline_devices, type=int),
        chunks=args.pipeline_chunks,
        checkpoint=args.pipeline_checkpoint,
    )
    return PipelineParallelTransformerModel(
        encoder=encoder, decoder=decoder, **pipeline_options
    )
def output_layer(self, features, **kwargs):
    """Project features to the default output size (typically vocabulary size).

    Delegates to ``self.decoder.output_layer``; ``self.decoder`` is ``None``
    until ``prepare_for_inference_()`` has populated it.
    """
    return self.decoder.output_layer(features, **kwargs)
def max_positions(self):
    """Maximum length supported by the model, as ``(encoder_max, decoder_max)``.

    Both values are computed once in ``__init__`` via ``max_positions_helper``.
    """
    return (self.encoder_max_positions, self.decoder_max_positions)
def max_positions_helper(
    self, embedding_layer, max_positions_field="max_source_positions"
):
    """Maximum input length supported by the encoder or decoder.

    Args:
        embedding_layer: object exposing ``embed_positions`` and the
            attribute named by *max_positions_field*.
        max_positions_field: name of the configured-limit attribute.

    Returns:
        The configured limit, capped by ``embed_positions.max_positions``
        when positional embeddings are present.
    """
    positional = embedding_layer.embed_positions
    configured_limit = getattr(embedding_layer, max_positions_field)
    if positional is None:
        return configured_limit
    return min(configured_limit, positional.max_positions)
def get_normalized_probs(self, net_output, log_probs, sample=None):
    """Get normalized probabilities (or log probs) from a net's output.

    Args:
        net_output: a tensor of logits, or a sequence whose first element is
            that tensor.
        log_probs: return log-probabilities when true, probabilities otherwise.
        sample: optional batch dict; when given together with an adaptive
            softmax it must contain ``"target"``.
    """
    adaptive_softmax = getattr(self, "adaptive_softmax", None)
    if adaptive_softmax is not None:
        target = None
        if sample is not None:
            assert "target" in sample
            target = sample["target"]
        out = adaptive_softmax.get_log_prob(net_output, target=target)
        # get_log_prob yields log-probabilities; exponentiate in place otherwise.
        return out if log_probs else out.exp_()

    # A Pipe() module returns a tuple of tensors as the output.
    # In this case, the tuple has one element - the output tensor of logits
    if isinstance(net_output, torch.Tensor):
        logits = net_output
    else:
        logits = net_output[0]
    normalize = utils.log_softmax if log_probs else utils.softmax
    return normalize(logits, dim=-1, onnx_trace=False)
def max_decoder_positions(self):
    """Maximum length supported by the decoder (value cached in ``__init__``)."""
    return self.decoder_max_positions
def load_state_dict(self, state_dict, strict=True, model_cfg=None):
    """Copies parameters and buffers from *state_dict* into this module and
    its descendants.

    Overrides the method in :class:`nn.Module`. Compared with that method
    this additionally "upgrades" *state_dicts* from old checkpoints, and
    converts a state dict that has no ``model.partitions`` keys into the
    pipeline-parallel key layout first.  *model_cfg* is accepted but unused.
    """
    self.upgrade_state_dict(state_dict)
    has_partition_keys = any("model.partitions" in key for key in state_dict)
    if not has_partition_keys:
        state_dict = self.convert_to_pipeline_parallel_state_dict(state_dict)
    return super().load_state_dict(state_dict, strict)
def convert_to_pipeline_parallel_state_dict(self, state_dict):
    """Re-key a regular transformer *state_dict* to the pipeline layout.

    Starts from this model's own ``state_dict()`` and, for every module in
    ``self.model.partitions``, overwrites the matching
    ``model.partitions.<pid>.<mid>.*`` entries with the values stored under
    the ``encoder.*`` / ``decoder.*`` keys of *state_dict*.  Encoder and
    decoder layers are numbered in the order they are encountered.

    Returns:
        The updated copy of this model's state dict.
    """
    new_state_dict = self.state_dict()
    encoder_layer_idx = 0
    decoder_layer_idx = 0

    # Parameter-name suffixes, grouped by sub-module so the encoder and
    # decoder lists can share the common parts.
    self_attn_suffixes = [
        "self_attn.k_proj.weight",
        "self_attn.k_proj.bias",
        "self_attn.v_proj.weight",
        "self_attn.v_proj.bias",
        "self_attn.q_proj.weight",
        "self_attn.q_proj.bias",
        "self_attn.out_proj.weight",
        "self_attn.out_proj.bias",
        "self_attn_layer_norm.weight",
        "self_attn_layer_norm.bias",
    ]
    encoder_attn_suffixes = [
        "encoder_attn.k_proj.weight",
        "encoder_attn.k_proj.bias",
        "encoder_attn.v_proj.weight",
        "encoder_attn.v_proj.bias",
        "encoder_attn.q_proj.weight",
        "encoder_attn.q_proj.bias",
        "encoder_attn.out_proj.weight",
        "encoder_attn.out_proj.bias",
        "encoder_attn_layer_norm.weight",
        "encoder_attn_layer_norm.bias",
    ]
    ffn_suffixes = [
        "fc1.weight",
        "fc1.bias",
        "fc2.weight",
        "fc2.bias",
        "final_layer_norm.weight",
        "final_layer_norm.bias",
    ]
    encoder_key_suffixes = self_attn_suffixes + ffn_suffixes
    decoder_key_suffixes = self_attn_suffixes + encoder_attn_suffixes + ffn_suffixes

    for pid, partition in enumerate(self.model.partitions):
        logger.info(f"Begin Partition {pid}")
        for mid, module in enumerate(partition):
            # fmt: off
            if isinstance(module, TransformerEncoderEmbedding):
                new_state_dict[f'model.partitions.{pid}.{mid}.embed_tokens.weight'] = state_dict['encoder.embed_tokens.weight']
                new_state_dict[f'model.partitions.{pid}.{mid}.embed_positions._float_tensor'] = state_dict['encoder.embed_positions._float_tensor']
            if isinstance(module, TransformerEncoderLayer):
                for suffix in encoder_key_suffixes:
                    new_state_dict[f'model.partitions.{pid}.{mid}.{suffix}'] = state_dict[f'encoder.layers.{encoder_layer_idx}.{suffix}']
                encoder_layer_idx += 1
            if isinstance(module, TransformerDecoderLayer):
                for suffix in decoder_key_suffixes:
                    new_state_dict[f'model.partitions.{pid}.{mid}.{suffix}'] = state_dict[f'decoder.layers.{decoder_layer_idx}.{suffix}']
                decoder_layer_idx += 1
            # The final encoder layer norm is only copied when the source
            # checkpoint actually contains it.
            if isinstance(module, TransformerEncoderLayerNorm) and 'encoder.layer_norm.weight' in state_dict:
                new_state_dict[f'model.partitions.{pid}.{mid}.layer_norm.weight'] = state_dict['encoder.layer_norm.weight']
                new_state_dict[f'model.partitions.{pid}.{mid}.layer_norm.bias'] = state_dict['encoder.layer_norm.bias']
            if isinstance(module, TransformerDecoderEmbedding):
                new_state_dict[f'model.partitions.{pid}.{mid}.embed_tokens.weight'] = state_dict['decoder.embed_tokens.weight']
                new_state_dict[f'model.partitions.{pid}.{mid}.embed_positions._float_tensor'] = state_dict['decoder.embed_positions._float_tensor']
            if isinstance(module, TransformerDecoderOutputLayer):
                new_state_dict[f'model.partitions.{pid}.{mid}.output_projection.weight'] = state_dict['decoder.output_projection.weight']
            # fmt: on
    return new_state_dict
class TransformerEncoder(FairseqEncoder):
"""
Transformer encoder consisting of *args.encoder_layers* layers. Each layer
is a :class:`TransformerEncoderLayer`.
Args:
args (argparse.Namespace): parsed command-line arguments
dictionary (~fairseq.data.Dictionary): encoding dictionary
embed_tokens (torch.nn.Embedding): input embedding
"""
def __init__(self, args, dictionary, embed_tokens, encoder_module_list=None):
    """Build the encoder either from scratch or from existing modules.

    Without *encoder_module_list* the embedding layer, ``args.encoder_layers``
    encoder layers and the final layer norm are created.  With it, the given
    modules are wrapped in a ``Pipe`` using the encoder balance/devices from
    *args*.
    """
    super().__init__(dictionary)
    self.register_buffer("version", torch.Tensor([3]))
    import_pipe()
    self.use_pipeline = encoder_module_list is not None

    if self.use_pipeline:
        encoder_balance = utils.eval_str_list(
            args.pipeline_encoder_balance, type=int
        )
        encoder_devices = utils.eval_str_list(
            args.pipeline_encoder_devices, type=int
        )
        assert sum(encoder_balance) == len(encoder_module_list), (
            f"Sum of encoder_balance={encoder_balance} is not equal "
            + f"to num_encoder_modules={len(encoder_module_list)}"
        )
        stacked = nn.Sequential(*encoder_module_list)
        if TORCH_PIPE:
            self.model = Pipe(
                module=partition_model(stacked, encoder_balance, encoder_devices),
                chunks=args.pipeline_chunks,
                checkpoint=args.pipeline_checkpoint,
            )
        else:
            self.model = Pipe(
                module=stacked,
                balance=encoder_balance,
                devices=encoder_devices,
                chunks=args.pipeline_chunks,
                checkpoint=args.pipeline_checkpoint,
            )
        return

    self.embedding_layer = TransformerEncoderEmbedding(args, embed_tokens)
    self.encoder_layers = nn.Sequential(
        *(TransformerEncoderLayer(args) for _ in range(args.encoder_layers))
    )
    # Chunked embeddings arrive as a ModuleList; the full width is the sum
    # of the chunk widths.
    if isinstance(embed_tokens, nn.ModuleList):
        emb_dim = sum(chunk.embedding_dim for chunk in embed_tokens)
    else:
        emb_dim = embed_tokens.embedding_dim
    self.final_layer_norm = TransformerEncoderLayerNorm(args, emb_dim)
def forward(self, src_tokens, src_lengths):
    """Encode a batch of source tokens.

    Args:
        src_tokens (LongTensor): tokens in the source language of shape
            `(batch, src_len)`
        src_lengths (torch.LongTensor): lengths of each source sentence of
            shape `(batch)`

    Returns:
        EncoderOut whose first two fields are the first two elements of the
        encoder stack's output (the encoder output and the encoder padding
        mask); the remaining four fields are ``None``.
    """
    # The stack takes a 3-tuple; the third slot (previous output tokens) is
    # filled with a one-element zero placeholder here.
    placeholder_prev_tokens = torch.zeros(
        1, dtype=src_tokens.dtype, device=src_tokens.device
    )
    inputs = (src_tokens, src_lengths, placeholder_prev_tokens)

    if not self.use_pipeline:
        embedded = self.embedding_layer(inputs)
        hidden = self.encoder_layers(embedded)
        encoder_out = self.final_layer_norm(hidden)
    else:
        first_device = self.model.devices[0]
        inputs = tuple(t.to(first_device) for t in inputs)
        encoder_out = self.model(inputs)
        if TORCH_PIPE:
            encoder_out = encoder_out.local_value()

    # first element is the encoder output
    # second element is the encoder padding mask
    # the remaining elements of EncoderOut are not computed by
    # the PipelineParallelTransformer
    return EncoderOut(encoder_out[0], encoder_out[1], None, None, None, None)
def reorder_encoder_out(self, encoder_out, new_order):
    """
    Reorder encoder output according to *new_order*.

    ``encoder_out`` and every entry of ``encoder_states`` are re-indexed
    along dimension 1; ``encoder_padding_mask`` and ``encoder_embedding``
    along dimension 0.  Fields that are ``None`` are left alone, and the
    ``encoder_states`` list is updated in place.

    Args:
        encoder_out: output from the ``forward()`` method
        new_order (LongTensor): desired order

    Returns:
        *encoder_out* rearranged according to *new_order*
    """
    replacements = {}
    if encoder_out.encoder_out is not None:
        replacements["encoder_out"] = encoder_out.encoder_out.index_select(
            1, new_order
        )
    if encoder_out.encoder_padding_mask is not None:
        replacements["encoder_padding_mask"] = (
            encoder_out.encoder_padding_mask.index_select(0, new_order)
        )
    if encoder_out.encoder_embedding is not None:
        replacements["encoder_embedding"] = (
            encoder_out.encoder_embedding.index_select(0, new_order)
        )
    if replacements:
        encoder_out = encoder_out._replace(**replacements)

    states = encoder_out.encoder_states
    if states is not None:
        for position, state in enumerate(states):
            states[position] = state.index_select(1, new_order)
    return encoder_out
def max_positions(self):
    """Maximum input length supported by the encoder.

    Reads ``self.embedding_layer``, which ``__init__`` only creates when the
    encoder is not built from an existing module list.
    """
    embedding = self.embedding_layer
    positional = embedding.embed_positions
    configured_limit = embedding.max_source_positions
    if positional is None:
        return configured_limit
    return min(configured_limit, positional.max_positions)
class TransformerDecoder(FairseqDecoder):
"""
Transformer decoder consisting of *args.decoder_layers* layers. Each layer
is a :class:`TransformerDecoderLayer`.
Args:
args (argparse.Namespace): parsed command-line arguments
dictionary (~fairseq.data.Dictionary): decoding dictionary
embed_tokens (torch.nn.Embedding): output embedding
no_encoder_attn (bool, optional): whether to attend to encoder outputs
(default: False).
"""
def __init__(
    self,
    args,
    dictionary,
    embed_tokens,
    no_encoder_attn=False,
    decoder_module_list=None,
):
    """Build the decoder either from scratch or from existing modules.

    Without *decoder_module_list* the embedding layer, ``args.decoder_layers``
    decoder layers and the output layer are created.  With it, the given
    modules are wrapped in a ``Pipe`` using the decoder balance/devices from
    *args*.
    """
    super().__init__(dictionary)
    self.register_buffer("version", torch.Tensor([3]))
    import_pipe()
    self.use_pipeline = decoder_module_list is not None

    if self.use_pipeline:
        decoder_balance = utils.eval_str_list(
            args.pipeline_decoder_balance, type=int
        )
        decoder_devices = utils.eval_str_list(
            args.pipeline_decoder_devices, type=int
        )
        assert sum(decoder_balance) == len(decoder_module_list), (
            f"Sum of decoder_balance={decoder_balance} is not equal "
            + f"to num_decoder_modules={len(decoder_module_list)}"
        )
        stacked = nn.Sequential(*decoder_module_list)
        if TORCH_PIPE:
            self.model = Pipe(
                module=partition_model(stacked, decoder_balance, decoder_devices),
                chunks=args.pipeline_chunks,
                checkpoint=args.pipeline_checkpoint,
            )
        else:
            self.model = Pipe(
                module=stacked,
                balance=decoder_balance,
                devices=decoder_devices,
                chunks=args.pipeline_chunks,
                checkpoint=args.pipeline_checkpoint,
            )
        return

    self.embedding_layer = TransformerDecoderEmbedding(args, embed_tokens)
    layers = [
        TransformerDecoderLayer(args, no_encoder_attn)
        for _ in range(args.decoder_layers)
    ]
    self.decoder_layers = nn.Sequential(*layers)
    self.decoder_output_layer = TransformerDecoderOutputLayer(
        args, embed_tokens, dictionary
    )
def forward(
    self,
    prev_output_tokens,
    encoder_out=None,
):
    """
    Args:
        prev_output_tokens (LongTensor): previous decoder outputs of shape
            `(batch, tgt_len)`, for teacher forcing
        encoder_out: output from the encoder; its ``encoder_out`` and
            ``encoder_padding_mask`` fields are read unconditionally, so
            despite the ``None`` default it must be provided.

    Returns:
        tuple:
            - a one-element tuple holding the output of the decoder stack
    """
    decoder_input = (
        encoder_out.encoder_out,
        encoder_out.encoder_padding_mask,
        prev_output_tokens,
    )
    if not self.use_pipeline:
        embedded = self.embedding_layer(decoder_input)
        hidden = self.decoder_layers(embedded)
        return (self.decoder_output_layer(hidden),)

    first_device = self.model.devices[0]
    decoder_input = tuple(t.to(first_device) for t in decoder_input)
    output = self.model(decoder_input)
    if TORCH_PIPE:
        output = output.local_value()
    return (output,)
def output_layer(self, features, **kwargs):
    """Project features to the vocabulary size.

    With an adaptive softmax configured the features are returned as-is.
    Otherwise they are projected with the input embedding weights when
    ``self.share_input_output_embed`` is set, else with ``self.embed_out``.
    """
    if self.adaptive_softmax is not None:
        return features
    # project back to size of vocabulary
    if self.share_input_output_embed:
        projection = self.embed_tokens.weight
    else:
        projection = self.embed_out
    return F.linear(features, projection)
def max_positions(self):
    """Maximum output length supported by the decoder.

    Reads ``self.embedding_layer``, which ``__init__`` only creates when the
    decoder is not built from an existing module list.
    """
    embedding = self.embedding_layer
    positional = embedding.embed_positions
    configured_limit = embedding.max_target_positions
    if positional is None:
        return configured_limit
    return min(configured_limit, positional.max_positions)
def buffered_future_mask(self, tensor):
    """Return a cached square mask of side ``tensor.size(0)``.

    The cache (``self._future_mask``) is rebuilt from scratch whenever it is
    missing, ``None``, on a different device than *tensor*, or too small.
    The mask is the upper-triangular part (diagonal offset 1) of a matrix
    filled by ``utils.fill_with_neg_inf``.
    """
    dim = tensor.size(0)
    cached = getattr(self, "_future_mask", None)
    needs_rebuild = (
        cached is None
        or cached.device != tensor.device
        or cached.size(0) < dim
    )
    if needs_rebuild:
        filled = utils.fill_with_neg_inf(tensor.new(dim, dim))
        self._future_mask = torch.triu(filled, 1)
    return self._future_mask[:dim, :dim]
def upgrade_state_dict_named(self, state_dict, name):
    """Upgrade a (possibly old) state dict for new versions of fairseq.

    Performs three in-place rewrites of *state_dict* under the prefix *name*:
    sinusoidal position weights are replaced by a ``_float_tensor``
    placeholder, ``layer_norms.<n>`` keys are renamed to their named
    equivalents, and checkpoints with version <= 2 disable the final layer
    norm on this module and get their version entry set to 1.

    Returns:
        The same *state_dict* object.
    """
    if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
        weights_key = "{}.embed_positions.weights".format(name)
        if weights_key in state_dict:
            del state_dict[weights_key]
        state_dict[
            "{}.embed_positions._float_tensor".format(name)
        ] = torch.FloatTensor(1)

    # update layer norms
    layer_norm_map = {
        "0": "self_attn_layer_norm",
        "1": "encoder_attn_layer_norm",
        "2": "final_layer_norm",
    }
    for layer_idx in range(len(self.layers)):
        for old_name, new_name in layer_norm_map.items():
            for param in ("weight", "bias"):
                old_key = "{}.layers.{}.layer_norms.{}.{}".format(
                    name, layer_idx, old_name, param
                )
                if old_key not in state_dict:
                    continue
                new_key = "{}.layers.{}.{}.{}".format(
                    name, layer_idx, new_name, param
                )
                state_dict[new_key] = state_dict[old_key]
                del state_dict[old_key]

    version_key = "{}.version".format(name)
    stored_version = state_dict.get(version_key, torch.Tensor([1]))
    if utils.item(stored_version[0]) <= 2:
        # earlier checkpoints did not normalize after the stack of layers
        self.layer_norm = None
        self.normalize = False
        state_dict[version_key] = torch.Tensor([1])

    return state_dict
@register_model_architecture(
    "pipeline_parallel_transformer", "transformer_iwslt_de_en_pipeline_parallel"
)
def transformer_iwslt_de_en_dist(args):
    """Pipeline-parallel architecture that applies ``transformer_iwslt_de_en`` to *args*."""
    transformer_iwslt_de_en(args)
@register_model_architecture(
    "pipeline_parallel_transformer", "transformer_wmt_en_de_big_pipeline_parallel"
)
def transformer_wmt_en_de_big_dist(args):
    """Pipeline-parallel architecture that applies ``transformer_wmt_en_de_big`` to *args*."""
    transformer_wmt_en_de_big(args)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .model import * # noqa
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
RoBERTa: A Robustly Optimized BERT Pretraining Approach.
"""
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
from fairseq.model_parallel.models.transformer import ModelParallelTransformerEncoder
from fairseq.models import register_model, register_model_architecture
from fairseq.models.roberta import (
roberta_base_architecture,
roberta_prenorm_architecture,
RobertaEncoder,
RobertaModel,
)
from fairseq.modules import LayerNorm
try:
from fairseq.model_parallel.megatron.mpu import (
copy_to_model_parallel_region,
gather_from_model_parallel_region,
ColumnParallelLinear,
VocabParallelEmbedding,
)
has_megatron_submodule = True
except (ImportError, ModuleNotFoundError):
has_megatron_submodule = False
logger = logging.getLogger(__name__)
@register_model("model_parallel_roberta")
class ModelParallelRobertaModel(RobertaModel):
    """RoBERTa model built around a :class:`ModelParallelRobertaEncoder`."""

    def __init__(self, args, encoder):
        super().__init__(args, encoder)

        # Heads are added later through register_classification_head().
        self.classification_heads = nn.ModuleDict()

    @staticmethod
    def add_args(parser):
        """Add the base RoBERTa arguments plus ``--no-final-layer-norm``."""
        RobertaModel.add_args(parser)
        final_norm_help = (
            "don't add final layernorm (only applicable when "
            "--encoder-normalize-before=True"
        )
        parser.add_argument(
            "--no-final-layer-norm",
            action="store_true",
            help=final_norm_help,
        )

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance.

        Both task dictionaries are padded in place to a multiple of
        ``args.model_parallel_size * 8``.

        Raises:
            NotImplementedError: if ``untie_weights_roberta`` is set on *args*.
        """
        # make sure all arguments are present
        base_architecture(args)

        pad_multiple = args.model_parallel_size * 8
        task.source_dictionary.pad_to_multiple_(pad_multiple)
        task.target_dictionary.pad_to_multiple_(pad_multiple)

        if not hasattr(args, "max_positions"):
            args.max_positions = args.tokens_per_sample

        if getattr(args, "untie_weights_roberta", False):
            raise NotImplementedError(
                "--untie-weights-roberta is not supported in model parallel mode"
            )

        return cls(args, ModelParallelRobertaEncoder(args, task.source_dictionary))

    def forward(
        self,
        src_tokens,
        features_only=False,
        return_all_hiddens=False,
        classification_head_name=None,
        **kwargs
    ):
        """Run the encoder and, if requested, a registered classification head.

        Naming a classification head forces ``features_only`` to true.

        Returns:
            ``(x, extra)`` where *x* is the encoder (or head) output.
        """
        use_head = classification_head_name is not None
        if use_head:
            features_only = True

        x, extra = self.encoder(src_tokens, features_only, return_all_hiddens, **kwargs)

        if use_head:
            head = self.classification_heads[classification_head_name]
            x = head(x)
        return x, extra

    def register_classification_head(
        self, name, num_classes=None, inner_dim=None, **kwargs
    ):
        """Register a classification head.

        Re-registering an existing *name* replaces the head; a warning is
        logged when the new sizes differ from the previous ones.
        """
        if name in self.classification_heads:
            previous = self.classification_heads[name]
            prev_num_classes = previous.out_proj.out_features
            prev_inner_dim = previous.dense.out_features
            unchanged = (
                num_classes == prev_num_classes and inner_dim == prev_inner_dim
            )
            if not unchanged:
                logger.warning(
                    're-registering head "{}" with num_classes {} (prev: {}) '
                    "and inner_dim {} (prev: {})".format(
                        name, num_classes, prev_num_classes, inner_dim, prev_inner_dim
                    )
                )
        self.classification_heads[name] = ModelParallelRobertaClassificationHead(
            self.args.encoder_embed_dim,
            inner_dim or self.args.encoder_embed_dim,
            num_classes,
            self.args.pooler_activation_fn,
            self.args.pooler_dropout,
        )
class ModelParallelRobertaLMHead(nn.Module):
    """Head for masked language modeling."""

    def __init__(self, embed_dim, output_dim, activation_fn, weight=None):
        """
        Args:
            embed_dim: width of the incoming features.
            output_dim: size of the output projection.
            activation_fn: activation name resolved via ``utils.get_activation_fn``.
            weight: optional projection weight to reuse; a new bias-free
                linear weight is created when omitted.
        """
        super().__init__()
        self.dense = ColumnParallelLinear(embed_dim, embed_dim, gather_output=True)
        self.activation_fn = utils.get_activation_fn(activation_fn)
        self.layer_norm = LayerNorm(embed_dim)

        projection_weight = weight
        if projection_weight is None:
            projection_weight = nn.Linear(embed_dim, output_dim, bias=False).weight
        self.weight = projection_weight
        self.bias = nn.Parameter(torch.zeros(output_dim))

    def forward(self, features, masked_tokens=None, **kwargs):
        """Project *features* to the output size and add the bias."""
        # When a mask is given, keep only the selected positions before
        # projecting; saves both memory and computation
        if masked_tokens is not None:
            features = features[masked_tokens, :]

        hidden = self.layer_norm(self.activation_fn(self.dense(features)))
        hidden = copy_to_model_parallel_region(hidden)

        # project back to size of vocabulary with bias
        logits = F.linear(hidden, self.weight)
        logits = gather_from_model_parallel_region(logits).contiguous()
        return logits + self.bias
class ModelParallelRobertaClassificationHead(nn.Module):
    """Head for sentence-level classification tasks."""

    def __init__(
        self, input_dim, inner_dim, num_classes, activation_fn, pooler_dropout
    ):
        """
        Args:
            input_dim: width of the incoming features.
            inner_dim: width of the hidden dense layer.
            num_classes: number of output classes.
            activation_fn: activation name resolved via ``utils.get_activation_fn``.
            pooler_dropout: dropout probability used before and after the dense layer.
        """
        super().__init__()
        self.dense = ColumnParallelLinear(input_dim, inner_dim, gather_output=True)
        self.activation_fn = utils.get_activation_fn(activation_fn)
        self.dropout = nn.Dropout(p=pooler_dropout)
        self.out_proj = nn.Linear(inner_dim, num_classes)

    def forward(self, features, **kwargs):
        """Classify from the features at index 0 of dimension 1."""
        first_token = features[:, 0, :]  # take <s> token (equiv. to [CLS])
        hidden = self.dense(self.dropout(first_token))
        hidden = self.dropout(self.activation_fn(hidden))
        return self.out_proj(hidden)
class ModelParallelRobertaEncoder(RobertaEncoder):
    """RoBERTa encoder.

    Overrides the ``build_*`` factory methods to return the model-parallel
    embedding, transformer encoder and LM head.
    """

    def __init__(self, args, dictionary):
        super().__init__(args, dictionary)
        # Untied weights are rejected here as well as in
        # ModelParallelRobertaModel.build_model.
        assert not self.args.untie_weights_roberta

    def build_embedding(self, vocab_size, embedding_dim, padding_idx):
        """Return a ``VocabParallelEmbedding`` for the token embeddings."""
        return VocabParallelEmbedding(vocab_size, embedding_dim, padding_idx)

    def build_encoder(self, args, dictionary, embed_tokens):
        """Return a ``ModelParallelTransformerEncoder``."""
        return ModelParallelTransformerEncoder(args, dictionary, embed_tokens)

    def build_lm_head(self, embed_dim, output_dim, activation_fn, weight):
        """Return a ``ModelParallelRobertaLMHead``."""
        return ModelParallelRobertaLMHead(embed_dim, output_dim, activation_fn, weight)
@register_model_architecture("model_parallel_roberta", "model_parallel_roberta")
def base_architecture(args):
    """Default architecture: keep the final layer norm unless *args* says otherwise, then apply the pre-norm RoBERTa defaults."""
    args.no_final_layer_norm = getattr(args, "no_final_layer_norm", False)
    # model parallel RoBERTa defaults to "Pre-LN" formulation
    roberta_prenorm_architecture(args)
# earlier versions of model parallel RoBERTa removed the final layer norm
@register_model_architecture("model_parallel_roberta", "model_parallel_roberta_v1")
def model_parallel_roberta_v1_architecture(args):
    """Same as ``base_architecture`` but ``no_final_layer_norm`` defaults to True."""
    args.no_final_layer_norm = getattr(args, "no_final_layer_norm", True)
    base_architecture(args)
@register_model_architecture(
    "model_parallel_roberta", "model_parallel_roberta_postnorm"
)
def model_parallel_roberta_postnorm_architecture(args):
    """Architecture that applies ``roberta_base_architecture`` directly to *args*."""
    # the original BERT/RoBERTa uses the "Post-LN" formulation
    roberta_base_architecture(args)
@register_model_architecture("model_parallel_roberta", "model_parallel_roberta_base")
def model_parallel_roberta_base_architecture(args):
    """Alias architecture: identical to ``base_architecture``."""
    base_architecture(args)
@register_model_architecture("model_parallel_roberta", "model_parallel_roberta_large")
def model_parallel_roberta_large_architecture(args):
    """Large architecture: set the size defaults below, then apply ``base_architecture``.

    Values already present on *args* are kept.
    """
    large_defaults = (
        ("encoder_layers", 24),
        ("encoder_embed_dim", 1024),
        ("encoder_ffn_embed_dim", 4096),
        ("encoder_attention_heads", 16),
    )
    for field, default in large_defaults:
        setattr(args, field, getattr(args, field, default))
    base_architecture(args)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import torch.nn as nn
from fairseq.model_parallel.modules import (
ModelParallelTransformerDecoderLayer,
ModelParallelTransformerEncoderLayer,
)
from fairseq.models import register_model
from fairseq.models.transformer import (
TransformerDecoder,
TransformerEncoder,
TransformerModel,
)
try:
from fairseq.model_parallel.megatron.mpu import (
VocabParallelEmbedding,
copy_to_model_parallel_region,
gather_from_model_parallel_region,
)
has_megatron_submodule = True
except (ImportError, ModuleNotFoundError):
has_megatron_submodule = False
logger = logging.getLogger(__name__)
@register_model("model_parallel_transformer")
class ModelParallelTransformerModel(TransformerModel):
    """
    Model parallel Transformer model.

    Overrides the embedding, encoder and decoder factories of
    ``TransformerModel`` so that they return the model-parallel variants.
    """

    @classmethod
    def build_embedding(cls, args, dictionary, embed_dim, path=None):
        """Build a ``VocabParallelEmbedding`` for ``dictionary``.

        Args:
            args: parsed options; ``args.model_parallel_size`` is read.
            dictionary: dictionary to embed. It is padded IN PLACE to a
                multiple of ``args.model_parallel_size * 8`` entries.
            embed_dim: embedding dimension.
            path: must be falsy; loading pretrained embeddings is not
                supported.

        Returns:
            The newly constructed ``VocabParallelEmbedding``.

        Raises:
            ImportError: the megatron submodule could not be imported.
            NotImplementedError: ``path`` was given.
        """
        if not has_megatron_submodule:
            raise ImportError(
                "\n\nPlease install the megatron submodule:"
                "\n\n git submodule update --init "
                "fairseq/model_parallel/megatron"
            )
        # Reject the unsupported option up front, before the dictionary is
        # mutated and before the embedding is allocated.
        if path:
            raise NotImplementedError(
                "Loading of embedding from path is not supported for model parallel"
            )

        dictionary.pad_to_multiple_(args.model_parallel_size * 8)
        num_embeddings = len(dictionary)
        padding_idx = dictionary.pad()

        def _vocab_init(tensor, **kwargs):
            # NOTE(review): std is derived from the vocabulary size here,
            # whereas the language-model variant of this initializer uses
            # embed_dim -- confirm which one is intended.
            nn.init.normal_(tensor, mean=0, std=num_embeddings**-0.5)
            # Zero the padding row. Uses the dictionary's real pad index
            # rather than a hard-coded 1.
            # NOTE(review): assumes ``tensor`` is indexed by global vocab id
            # -- confirm against VocabParallelEmbedding's init_method contract.
            nn.init.constant_(tensor[padding_idx], 0)

        return VocabParallelEmbedding(
            num_embeddings, embed_dim, padding_idx, init_method=_vocab_init
        )

    @classmethod
    def build_encoder(cls, args, src_dict, embed_tokens):
        """Return a :class:`ModelParallelTransformerEncoder`."""
        return ModelParallelTransformerEncoder(args, src_dict, embed_tokens)

    @classmethod
    def build_decoder(cls, args, tgt_dict, embed_tokens):
        """Return a :class:`ModelParallelTransformerDecoder`.

        Encoder attention is disabled when ``args.no_cross_attention`` is
        set (it defaults to ``False`` when the attribute is missing).
        """
        return ModelParallelTransformerDecoder(
            args,
            tgt_dict,
            embed_tokens,
            no_encoder_attn=getattr(args, "no_cross_attention", False),
        )
class ModelParallelTransformerEncoder(TransformerEncoder):
    """
    Model parallel Transformer encoder consisting of *args.encoder_layers* layers. Each layer
    is a :class:`ModelParallelTransformerEncoderLayer`.
    """

    def __init__(self, args, dictionary, embed_tokens):
        """Build the encoder and optionally drop its final layer norm.

        Args:
            args: parsed options. If ``args.no_final_layer_norm`` is truthy,
                ``self.layer_norm`` is replaced by ``None`` after the parent
                constructor has run.
            dictionary: forwarded to the parent constructor.
            embed_tokens: forwarded to the parent constructor.
        """
        super().__init__(args, dictionary, embed_tokens)

        # Not every architecture defines this option, so a missing attribute
        # means "keep the final layer norm" instead of raising AttributeError.
        if getattr(args, "no_final_layer_norm", False):
            self.layer_norm = None

    def build_encoder_layer(self, args):
        """Return one :class:`ModelParallelTransformerEncoderLayer`."""
        return ModelParallelTransformerEncoderLayer(args)
class ModelParallelTransformerDecoder(TransformerDecoder):
    """
    Model Parallel Transformer decoder consisting of *args.decoder_layers* layers. Each layer
    is a :class:`ModelParallelTransformerDecoderLayer`.
    """

    def build_decoder_layer(self, args, no_encoder_attn=False):
        """Return one :class:`ModelParallelTransformerDecoderLayer`."""
        return ModelParallelTransformerDecoderLayer(args, no_encoder_attn)

    def output_layer(self, features, **kwargs):
        """Project features to the vocabulary size.

        Args:
            features: decoder features to project.
            **kwargs: accepted and ignored.

        Returns:
            The output of ``self.output_projection``. Unless the configured
            criterion is ``"vocab_parallel_cross_entropy"``, the result is
            first passed through ``gather_from_model_parallel_region`` and
            made contiguous.

        Raises:
            NotImplementedError: input and output embeddings are not shared.
        """
        if not self.share_input_output_embed:
            raise NotImplementedError(
                "Model parallel training currently requires --share-decoder-input-output-embed"
            )

        features = copy_to_model_parallel_region(features)

        # project back to size of vocabulary
        x = self.output_projection(features)

        # A missing ``criterion`` attribute is treated like any criterion
        # other than vocab_parallel_cross_entropy (i.e. the result is
        # gathered) instead of raising AttributeError.
        if getattr(self.args, "criterion", None) != "vocab_parallel_cross_entropy":
            x = gather_from_model_parallel_region(x).contiguous()
        return x
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch.nn as nn
from fairseq.model_parallel.models.transformer import ModelParallelTransformerDecoder
from fairseq.models import register_model, register_model_architecture
from fairseq.models.transformer_lm import TransformerLanguageModel
try:
from fairseq.model_parallel.megatron.mpu import VocabParallelEmbedding
has_megatron_submodule = True
except (ImportError, ModuleNotFoundError):
has_megatron_submodule = False
DEFAULT_MAX_TARGET_POSITIONS = 1024
@register_model("model_parallel_transformer_lm")
class ModelParallelTransformerLanguageModel(TransformerLanguageModel):
    """Transformer language model built from model-parallel components.

    The decoder is a ``ModelParallelTransformerDecoder`` without encoder
    attention and the token embedding is a ``VocabParallelEmbedding``.
    """

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser (same as the parent's)."""
        TransformerLanguageModel.add_args(parser)

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance.

        Side effects: fills missing options on ``args`` via
        ``base_lm_architecture`` and pads both task dictionaries IN PLACE to
        a multiple of ``args.model_parallel_size * 8``.

        Raises:
            ImportError: the megatron submodule could not be imported.
            NotImplementedError: character embeddings or adaptive input
                were requested.
        """
        if not has_megatron_submodule:
            raise ImportError(
                "\n\nPlease install the megatron submodule:"
                "\n\n git submodule update --init "
                "fairseq/model_parallel/megatron"
            )

        # make sure all arguments are present in older models
        base_lm_architecture(args)

        task.source_dictionary.pad_to_multiple_(args.model_parallel_size * 8)
        task.target_dictionary.pad_to_multiple_(args.model_parallel_size * 8)

        if args.decoder_layers_to_keep:
            # the option is a comma-separated list; its length becomes the
            # number of decoder layers
            args.decoder_layers = len(args.decoder_layers_to_keep.split(","))

        if getattr(args, "max_target_positions", None) is None:
            args.max_target_positions = getattr(
                args, "tokens_per_sample", DEFAULT_MAX_TARGET_POSITIONS
            )

        if args.character_embeddings:
            raise NotImplementedError(
                "Character embeddings is not supported for model parallel"
            )
        if args.adaptive_input:
            raise NotImplementedError(
                "Adaptive input is not supported for model parallel"
            )

        embed_tokens = cls.build_embedding(
            args, task.source_dictionary, args.decoder_input_dim
        )
        decoder = ModelParallelTransformerDecoder(
            args,
            task.target_dictionary,
            embed_tokens,
            no_encoder_attn=True,
        )
        return cls(decoder)

    @classmethod
    def build_embedding(cls, args, dictionary, embed_dim, path=None):
        """Build a ``VocabParallelEmbedding`` for ``dictionary``.

        Args:
            args: accepted for interface compatibility; not read here.
            dictionary: dictionary to embed (its length and pad index are used).
            embed_dim: embedding dimension.
            path: accepted for interface compatibility; not used.

        Raises:
            ImportError: the megatron submodule could not be imported.
        """
        # Same guard as build_model, so a direct call fails with install
        # instructions rather than a NameError on VocabParallelEmbedding.
        if not has_megatron_submodule:
            raise ImportError(
                "\n\nPlease install the megatron submodule:"
                "\n\n git submodule update --init "
                "fairseq/model_parallel/megatron"
            )

        padding_idx = dictionary.pad()

        def _vocab_init(tensor, **kwargs):
            nn.init.normal_(tensor, mean=0, std=embed_dim**-0.5)
            # Zero the padding row. Uses the dictionary's real pad index
            # rather than a hard-coded 1.
            # NOTE(review): assumes ``tensor`` is indexed by global vocab id
            # -- confirm against VocabParallelEmbedding's init_method contract.
            nn.init.constant_(tensor[padding_idx], 0)

        embed_tokens = VocabParallelEmbedding(
            len(dictionary), embed_dim, padding_idx, init_method=_vocab_init
        )
        return embed_tokens
def base_lm_architecture(args):
    """Fill in every language-model option that is missing from ``args``.

    ``args`` is modified in place. Options already present keep their
    value, with two exceptions: ``decoder_normalize_before`` is always
    forced to ``True``, and the legacy-checkpoint handling at the top may
    overwrite ``no_decoder_final_norm`` / ``tie_adaptive_proj``.
    """

    def fill(option, fallback):
        # keep an existing value, otherwise install the fallback
        setattr(args, option, getattr(args, option, fallback))

    # backward compatibility for older model checkpoints: previous models
    # defined --no-tie-adaptive-proj, so its mere presence marks an "old"
    # checkpoint
    is_old_checkpoint = hasattr(args, "no_tie_adaptive_proj")
    if is_old_checkpoint:
        args.no_decoder_final_norm = True  # old models always set this to True
        if args.no_tie_adaptive_proj is False:
            args.tie_adaptive_proj = True
    if hasattr(args, "decoder_final_norm"):
        args.no_decoder_final_norm = not args.decoder_final_norm

    for option, fallback in (
        ("activation_fn", "relu"),
        ("dropout", 0.1),
        ("attention_dropout", 0.0),
        ("activation_dropout", 0.0),
        ("relu_dropout", 0.0),
        ("decoder_embed_dim", 512),
    ):
        fill(option, fallback)

    # these two fall back to the (now resolved) decoder embedding size
    fill("decoder_output_dim", args.decoder_embed_dim)
    fill("decoder_input_dim", args.decoder_embed_dim)

    for option, fallback in (
        ("decoder_ffn_embed_dim", 2048),
        ("decoder_layers", 6),
        ("decoder_attention_heads", 8),
    ):
        fill(option, fallback)

    # Model training is not stable without this
    args.decoder_normalize_before = True

    for option, fallback in (
        ("no_decoder_final_norm", False),
        ("adaptive_softmax_cutoff", None),
        ("adaptive_softmax_dropout", 0),
        ("adaptive_softmax_factor", 4),
        ("no_token_positional_embeddings", False),
        ("share_decoder_input_output_embed", False),
        ("character_embeddings", False),
        (
            "character_filters",
            "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]",
        ),
        ("character_embedding_dim", 4),
        ("char_embedder_highway_layers", 2),
        ("adaptive_input", False),
        ("adaptive_input_factor", 4),
        ("adaptive_input_cutoff", None),
        ("tie_adaptive_weights", False),
        ("tie_adaptive_proj", False),
        ("decoder_learned_pos", False),
        ("decoder_layerdrop", 0.0),
        ("decoder_layers_to_keep", None),
        ("layernorm_embedding", False),
        ("no_scale_embedding", False),
        ("quant_noise_pq", 0.0),
        ("quant_noise_pq_block_size", 8),
        ("quant_noise_scalar", 0.0),
        ("add_bos_token", False),
    ):
        fill(option, fallback)
@register_model_architecture("model_parallel_transformer_lm", "transformer_lm_megatron")
def transformer_lm_megatron(args):
    """Apply the ``transformer_lm_megatron`` defaults to ``args`` in place.

    Each option is only a fallback for a missing attribute; afterwards
    ``base_lm_architecture`` fills in everything that is still unset.
    """
    megatron_defaults = (
        ("decoder_embed_dim", 3072),
        ("decoder_ffn_embed_dim", 3072 * 4),
        ("decoder_layers", 72),
        ("decoder_attention_heads", 32),
        ("dropout", 0.1),
        ("attention_dropout", 0.1),
        ("activation_fn", "gelu"),
    )
    for option, fallback in megatron_defaults:
        setattr(args, option, getattr(args, option, fallback))
    base_lm_architecture(args)
@register_model_architecture(
    "model_parallel_transformer_lm", "transformer_lm_megatron_11b"
)
def transformer_lm_megatron_11b(args):
    """Apply the ``transformer_lm_megatron_11b`` defaults to ``args`` in place.

    Each option is only a fallback for a missing attribute; afterwards
    ``base_lm_architecture`` fills in everything that is still unset.
    """
    fallbacks = {
        "decoder_embed_dim": 3072,
        "decoder_ffn_embed_dim": 3072 * 6,
        "decoder_layers": 72,
        "decoder_attention_heads": 32,
        "dropout": 0.1,
        "attention_dropout": 0.1,
        "activation_fn": "gelu",
    }
    for option, fallback in fallbacks.items():
        setattr(args, option, getattr(args, option, fallback))
    base_lm_architecture(args)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""isort:skip_file"""
from .multihead_attention import ModelParallelMultiheadAttention
from .transformer_layer import (
ModelParallelTransformerEncoderLayer,
ModelParallelTransformerDecoderLayer,
)
__all__ = [
"ModelParallelMultiheadAttention",
"ModelParallelTransformerEncoderLayer",
"ModelParallelTransformerDecoderLayer",
]
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, Optional, Tuple
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from fairseq import utils
from fairseq.incremental_decoding_utils import with_incremental_state
from fairseq.modules.fairseq_dropout import FairseqDropout
try:
from fairseq.model_parallel.megatron.mpu import (
ColumnParallelLinear,
RowParallelLinear,
get_cuda_rng_tracker,
get_model_parallel_world_size,
)
has_megatron_submodule = True
except (ImportError, ModuleNotFoundError):
has_megatron_submodule = False
@with_incremental_state
class ModelParallelMultiheadAttention(nn.Module):
    """Model parallel Multi-headed attention.

    This performs the Multi-headed attention over multiple gpus.

    See "Megatron-LM: https://arxiv.org/pdf/1909.08053.pdf" for more details.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        kdim=None,
        vdim=None,
        dropout=0.0,
        bias=True,
        self_attention=False,
        encoder_decoder_attention=False,
    ):
        """Create the query/key/value/output projections.

        Args:
            embed_dim: model dimension; must be divisible by ``num_heads``.
            num_heads: total number of heads; must be divisible by the
                model parallel world size.
            kdim: input size of the key projection (defaults to ``embed_dim``).
            vdim: input size of the value projection (defaults to ``embed_dim``).
            dropout: dropout probability applied to the attention weights.
            bias: whether the four projections carry a bias.
            self_attention: keys and values are projected from the query.
            encoder_decoder_attention: keys and values are both projected
                from ``key``.

        Raises:
            ImportError: the megatron submodule could not be imported.
            AssertionError: one of the divisibility/size requirements fails.
        """
        super().__init__()
        if not has_megatron_submodule:
            raise ImportError(
                "\n\nPlease install the megatron submodule:"
                "\n\n git submodule update --init "
                "fairseq/model_parallel/megatron"
            )
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.model_parallel_size = get_model_parallel_world_size()

        # number of heads handled per model-parallel partition
        self.num_heads_partition = num_heads // self.model_parallel_size
        assert (
            self.num_heads_partition * self.model_parallel_size == num_heads
        ), "Number of heads must be divisible by model parallel size"

        self.dropout_module = FairseqDropout(
            dropout, module_name=self.__class__.__name__
        )
        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim**-0.5

        self.self_attention = self_attention
        self.encoder_decoder_attention = encoder_decoder_attention

        assert (
            not self.self_attention or self.qkv_same_dim
        ), "Self-attention requires query, key and value to be of the same size"

        self.k_proj = ColumnParallelLinear(
            self.kdim, embed_dim, bias=bias, gather_output=False
        )
        self.v_proj = ColumnParallelLinear(
            self.vdim, embed_dim, bias=bias, gather_output=False
        )
        self.q_proj = ColumnParallelLinear(
            embed_dim, embed_dim, bias=bias, gather_output=False
        )
        self.out_proj = RowParallelLinear(
            embed_dim, embed_dim, bias=bias, input_is_parallel=True
        )

    def forward(
        self,
        query,
        key: Optional[Tensor],
        value: Optional[Tensor],
        key_padding_mask: Optional[Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        static_kv: bool = False,
        attn_mask: Optional[Tensor] = None,
        **unused_kwargs,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """Input shape: Time x Batch x Channel

        Args:
            key_padding_mask (ByteTensor, optional): mask to exclude
                keys that are pads, of shape `(batch, src_len)`, where
                padding elements are indicated by 1s.
            incremental_state (dict, optional): when given, cached keys,
                values and key padding mask are read from it and the
                updated ones are written back.
            static_kv (bool): reuse cached keys/values as they are instead
                of appending to them.
            attn_mask (ByteTensor, optional): typically used to
                implement causal attention, where the mask prevents the
                attention from looking forward in time (default: None).

        Returns:
            ``(attn, None)``: the attention output after ``out_proj``; the
            second element is always ``None``.
        """
        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]

        is_tpu = query.device.type == "xla"

        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if saved_state is not None and "prev_key" in saved_state:
                # previous time steps are cached - no need to recompute
                # key and value if they are static
                if static_kv:
                    assert self.encoder_decoder_attention and not self.self_attention
                    key = value = None
        else:
            saved_state = None

        if self.self_attention:
            q = self.q_proj(query)
            k = self.k_proj(query)
            v = self.v_proj(query)
        elif self.encoder_decoder_attention:
            # encoder-decoder attention
            q = self.q_proj(query)
            if key is None:
                assert value is None
                k = v = None
            else:
                # both k and v are projected from ``key`` in this mode
                k = self.k_proj(key)
                v = self.v_proj(key)
        else:
            assert key is not None and value is not None
            q = self.q_proj(query)
            k = self.k_proj(key)
            v = self.v_proj(value)
        q *= self.scaling

        # reshape to (bsz * num_heads_partition, seq_len, head_dim)
        q = (
            q.contiguous()
            .view(tgt_len, bsz * self.num_heads_partition, self.head_dim)
            .transpose(0, 1)
        )
        if k is not None:
            k = (
                k.contiguous()
                .view(-1, bsz * self.num_heads_partition, self.head_dim)
                .transpose(0, 1)
            )
        if v is not None:
            v = (
                v.contiguous()
                .view(-1, bsz * self.num_heads_partition, self.head_dim)
                .transpose(0, 1)
            )

        if saved_state is not None:
            # saved states are stored with shape (bsz, num_heads_partition, seq_len, head_dim)
            if "prev_key" in saved_state:
                _prev_key = saved_state["prev_key"]
                assert _prev_key is not None
                prev_key = _prev_key.view(
                    bsz * self.num_heads_partition, -1, self.head_dim
                )
                if static_kv:
                    k = prev_key
                else:
                    assert k is not None
                    k = torch.cat([prev_key, k], dim=1)
            if "prev_value" in saved_state:
                _prev_value = saved_state["prev_value"]
                assert _prev_value is not None
                prev_value = _prev_value.view(
                    bsz * self.num_heads_partition, -1, self.head_dim
                )
                if static_kv:
                    v = prev_value
                else:
                    assert v is not None
                    v = torch.cat([prev_value, v], dim=1)
            prev_key_padding_mask: Optional[Tensor] = None
            if "prev_key_padding_mask" in saved_state:
                prev_key_padding_mask = saved_state["prev_key_padding_mask"]
            assert k is not None and v is not None
            key_padding_mask = (
                ModelParallelMultiheadAttention._append_prev_key_padding_mask(
                    key_padding_mask=key_padding_mask,
                    prev_key_padding_mask=prev_key_padding_mask,
                    batch_size=bsz,
                    src_len=k.size(1),
                    static_kv=static_kv,
                )
            )

            saved_state["prev_key"] = k.view(
                bsz, self.num_heads_partition, -1, self.head_dim
            )
            saved_state["prev_value"] = v.view(
                bsz, self.num_heads_partition, -1, self.head_dim
            )
            saved_state["prev_key_padding_mask"] = key_padding_mask
            # In this branch incremental_state is never None
            assert incremental_state is not None
            incremental_state = self._set_input_buffer(incremental_state, saved_state)
        assert k is not None
        src_len = k.size(1)

        # This is part of a workaround to get around fork/join parallelism
        # not supporting Optional types.
        if key_padding_mask is not None and key_padding_mask.dim() == 0:
            key_padding_mask = None

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        attn_weights = torch.bmm(q, k.transpose(1, 2))

        assert list(attn_weights.size()) == [
            bsz * self.num_heads_partition,
            tgt_len,
            src_len,
        ]

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(0)
            attn_weights += attn_mask

        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(
                bsz, self.num_heads_partition, tgt_len, src_len
            )
            if not is_tpu:
                attn_weights = attn_weights.masked_fill(
                    key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                    float("-inf"),
                )
            else:
                # move the batch axis next to src_len so the (bsz, src_len)
                # mask broadcasts without unsqueezing, then move it back
                attn_weights = attn_weights.transpose(0, 2)
                attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
                attn_weights = attn_weights.transpose(0, 2)
            attn_weights = attn_weights.view(
                bsz * self.num_heads_partition, tgt_len, src_len
            )

        attn_weights_float = utils.softmax(attn_weights, dim=-1)
        attn_weights = attn_weights_float.type_as(attn_weights)

        # dropout is drawn inside the forked CUDA RNG tracker context
        with get_cuda_rng_tracker().fork():
            attn_probs = self.dropout_module(attn_weights)

        assert v is not None
        attn = torch.bmm(attn_probs, v)
        assert list(attn.size()) == [
            bsz * self.num_heads_partition,
            tgt_len,
            self.head_dim,
        ]
        embed_dim_partition = embed_dim // self.model_parallel_size
        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim_partition)
        attn = self.out_proj(attn)
        # return attn_weights None to keep the return type same as single gpu multihead attention
        # This will be deprecated.
        attn_weights: Optional[Tensor] = None

        return attn, attn_weights

    @staticmethod
    def _append_prev_key_padding_mask(
        key_padding_mask: Optional[Tensor],
        prev_key_padding_mask: Optional[Tensor],
        batch_size: int,
        src_len: int,
        static_kv: bool,
    ) -> Optional[Tensor]:
        """Combine the cached key padding mask with the current one.

        Args:
            key_padding_mask: mask for the current keys, or ``None``.
            prev_key_padding_mask: cached mask, or ``None``.
            batch_size: number of rows of the zero filler, when one is needed.
            src_len: total key length after concatenation; used to size the
                zero filler when only one of the two masks exists.
            static_kv: if true and a cached mask exists, return it unchanged.

        Returns:
            The cached mask itself (static case), a float tensor holding the
            concatenation along dim 1, or ``None`` when both masks are ``None``.
        """
        # saved key padding masks have shape (bsz, seq_len)
        if prev_key_padding_mask is not None and static_kv:
            new_key_padding_mask = prev_key_padding_mask
        elif prev_key_padding_mask is not None and key_padding_mask is not None:
            new_key_padding_mask = torch.cat(
                [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
            )
        # During incremental decoding, as the padding token enters and
        # leaves the frame, there will be a time when prev or current
        # is None
        elif prev_key_padding_mask is not None:
            # Allocate the filler directly on the mask's device so that
            # torch.cat also works for non-default CUDA devices and for
            # non-CUDA accelerators.
            filler = torch.zeros(
                batch_size,
                src_len - prev_key_padding_mask.size(1),
                device=prev_key_padding_mask.device,
            )
            new_key_padding_mask = torch.cat(
                [prev_key_padding_mask.float(), filler.float()], dim=1
            )
        elif key_padding_mask is not None:
            filler = torch.zeros(
                batch_size,
                src_len - key_padding_mask.size(1),
                device=key_padding_mask.device,
            )
            new_key_padding_mask = torch.cat(
                [filler.float(), key_padding_mask.float()], dim=1
            )
        else:
            new_key_padding_mask = prev_key_padding_mask
        return new_key_padding_mask

    def reorder_incremental_state(
        self, incremental_state: Dict[str, Dict[str, Optional[Tensor]]], new_order
    ):
        """Reorder buffered internal state (for incremental generation)."""
        input_buffer = self._get_input_buffer(incremental_state)
        if input_buffer is not None:
            # only existing keys are re-assigned, so the dict size is stable
            # while iterating
            for name, buffered in input_buffer.items():
                if buffered is not None:
                    input_buffer[name] = buffered.index_select(0, new_order)
            incremental_state = self._set_input_buffer(incremental_state, input_buffer)
        return incremental_state

    def _get_input_buffer(
        self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
    ) -> Dict[str, Optional[Tensor]]:
        """Return the stored ``"attn_state"`` buffer, or a new empty dict."""
        result = self.get_incremental_state(incremental_state, "attn_state")
        if result is not None:
            return result
        else:
            empty_result: Dict[str, Optional[Tensor]] = {}
            return empty_result

    def _set_input_buffer(
        self,
        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
        buffer: Dict[str, Optional[Tensor]],
    ):
        """Store ``buffer`` under ``"attn_state"`` and return the result of the setter."""
        return self.set_incremental_state(incremental_state, "attn_state", buffer)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from fairseq.model_parallel.modules import ModelParallelMultiheadAttention
from fairseq.modules import TransformerDecoderLayer, TransformerEncoderLayer
try:
from fairseq.model_parallel.megatron.mpu import (
ColumnParallelLinear,
RowParallelLinear,
)
has_megatron_submodule = True
except (ImportError, ModuleNotFoundError):
has_megatron_submodule = False
class ModelParallelTransformerEncoderLayer(TransformerEncoderLayer):
    """Encoder layer block over multiple gpus.

    See "Megatron-LM: https://arxiv.org/pdf/1909.08053.pdf" for more details.
    """

    def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
        """Return the first feed-forward projection (``ColumnParallelLinear``).

        Raises ``NotImplementedError`` when ``q_noise`` is positive;
        ``qn_block_size`` is not used.
        """
        quant_noise_requested = q_noise > 0
        if quant_noise_requested:
            raise NotImplementedError
        fc1 = ColumnParallelLinear(input_dim, output_dim, gather_output=False)
        return fc1

    def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size):
        """Return the second feed-forward projection (``RowParallelLinear``).

        Raises ``NotImplementedError`` when ``q_noise`` is positive;
        ``qn_block_size`` is not used.
        """
        quant_noise_requested = q_noise > 0
        if quant_noise_requested:
            raise NotImplementedError
        fc2 = RowParallelLinear(input_dim, output_dim, input_is_parallel=True)
        return fc2

    def build_self_attention(self, embed_dim, args, **unused_kwargs):
        """Return the self-attention module for this encoder layer.

        Head count and attention dropout are taken from ``args``; extra
        keyword arguments are ignored.
        """
        self_attn = ModelParallelMultiheadAttention(
            embed_dim=embed_dim,
            num_heads=args.encoder_attention_heads,
            dropout=args.attention_dropout,
            self_attention=True,
        )
        return self_attn
class ModelParallelTransformerDecoderLayer(TransformerDecoderLayer):
    """Decoder layer block.

    See "Megatron-LM: https://arxiv.org/pdf/1909.08053.pdf" for more details.
    """

    def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
        """Return the first feed-forward projection (``ColumnParallelLinear``).

        Raises ``NotImplementedError`` when ``q_noise`` is positive;
        ``qn_block_size`` is not used.
        """
        quant_noise_requested = q_noise > 0
        if quant_noise_requested:
            raise NotImplementedError
        fc1 = ColumnParallelLinear(input_dim, output_dim, gather_output=False)
        return fc1

    def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size):
        """Return the second feed-forward projection (``RowParallelLinear``).

        Raises ``NotImplementedError`` when ``q_noise`` is positive;
        ``qn_block_size`` is not used.
        """
        quant_noise_requested = q_noise > 0
        if quant_noise_requested:
            raise NotImplementedError
        fc2 = RowParallelLinear(input_dim, output_dim, input_is_parallel=True)
        return fc2

    def build_self_attention(self, embed_dim, args, **unused_kwargs):
        """Return the decoder self-attention module.

        ``self_attention`` is switched off when ``args.cross_self_attention``
        is truthy (the option defaults to ``False`` when missing). Extra
        keyword arguments are ignored.
        """
        heads = args.decoder_attention_heads
        attn_dropout = args.attention_dropout
        pure_self_attention = not getattr(args, "cross_self_attention", False)
        return ModelParallelMultiheadAttention(
            embed_dim=embed_dim,
            num_heads=heads,
            dropout=attn_dropout,
            self_attention=pure_self_attention,
        )

    def build_encoder_attention(self, embed_dim, args, **unused_kwargs):
        """Return the encoder-decoder attention module.

        Key and value input sizes both come from ``args.encoder_embed_dim``
        (``None`` when the option is missing). Extra keyword arguments are
        ignored.
        """
        heads = args.decoder_attention_heads
        key_dim = getattr(args, "encoder_embed_dim", None)
        value_dim = getattr(args, "encoder_embed_dim", None)
        attn_dropout = args.attention_dropout
        return ModelParallelMultiheadAttention(
            embed_dim=embed_dim,
            num_heads=heads,
            kdim=key_dim,
            vdim=value_dim,
            dropout=attn_dropout,
            encoder_decoder_attention=True,
        )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment