"git@developer.sourcefind.cn:OpenDAS/ktransformers.git" did not exist on "8c51f520ddb4c56c54595630f05870b4bc9d50df"
Commit 4294c4f6 authored by Myle Ott, committed by Facebook GitHub Bot

Add code for mixture of experts (#521)

Summary:
Code for the paper: [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](https://arxiv.org/abs/1902.07816).
Pull Request resolved: https://github.com/pytorch/fairseq/pull/521

Differential Revision: D14188021

Pulled By: myleott

fbshipit-source-id: ed5b1ed5ad9a582359bd5215fa2ea26dc76c673e
parent b65c579b
@@ -5,19 +5,20 @@ developers to train custom models for translation, summarization, language
modeling and other text generation tasks. It provides reference implementations
of various sequence-to-sequence models, including:
- **Convolutional Neural Networks (CNN)**
  - [Dauphin et al. (2017): Language Modeling with Gated Convolutional Networks](examples/conv_lm/README.md)
  - [Gehring et al. (2017): Convolutional Sequence to Sequence Learning](examples/conv_seq2seq/README.md)
  - [Edunov et al. (2018): Classical Structured Prediction Losses for Sequence to Sequence Learning](https://github.com/pytorch/fairseq/tree/classic_seqlevel)
  - [Fan et al. (2018): Hierarchical Neural Story Generation](examples/stories/README.md)
- **LightConv and DynamicConv models**
  - **_New_** [Wu et al. (2019): Pay Less Attention with Lightweight and Dynamic Convolutions](examples/pay_less_attention_paper/README.md)
- **Long Short-Term Memory (LSTM) networks**
  - [Luong et al. (2015): Effective Approaches to Attention-based Neural Machine Translation](https://arxiv.org/abs/1508.04025)
  - [Wiseman and Rush (2016): Sequence-to-Sequence Learning as Beam-Search Optimization](https://arxiv.org/abs/1606.02960)
- **Transformer (self-attention) networks**
  - [Vaswani et al. (2017): Attention Is All You Need](https://arxiv.org/abs/1706.03762)
  - [Ott et al. (2018): Scaling Neural Machine Translation](examples/scaling_nmt/README.md)
  - [Edunov et al. (2018): Understanding Back-Translation at Scale](https://arxiv.org/abs/1808.09381)
  - **_New_** [Shen et al. (2019) Mixture Models for Diverse Machine Translation: Tricks of the Trade](examples/translation_moe/README.md)

Fairseq features:
- multi-GPU (distributed) training on one machine or across multiple machines
@@ -74,6 +75,7 @@ as well as example training and evaluation commands.
- [Language Modeling](examples/language_model/README.md): convolutional models are available

We also have more detailed READMEs to reproduce results from specific papers:
- [Shen et al. (2019) Mixture Models for Diverse Machine Translation: Tricks of the Trade](examples/translation_moe/README.md)
- [Wu et al. (2019): Pay Less Attention with Lightweight and Dynamic Convolutions](examples/pay_less_attention_paper/README.md)
- [Edunov et al. (2018): Classical Structured Prediction Losses for Sequence to Sequence Learning](https://github.com/pytorch/fairseq/tree/classic_seqlevel)
- [Fan et al. (2018): Hierarchical Neural Story Generation](examples/stories/README.md)
...
@@ -112,22 +112,22 @@ $ bash prepare-wmt14en2de.sh
$ cd ../..

# Binarize the dataset:
$ TEXT=examples/translation/wmt17_en_de
$ fairseq-preprocess --source-lang en --target-lang de \
  --trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
  --destdir data-bin/wmt17_en_de --thresholdtgt 0 --thresholdsrc 0

# Train the model:
# If it runs out of memory, try to set --max-tokens 1500 instead
$ mkdir -p checkpoints/fconv_wmt_en_de
$ fairseq-train data-bin/wmt17_en_de \
  --lr 0.5 --clip-norm 0.1 --dropout 0.2 --max-tokens 4000 \
  --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
  --lr-scheduler fixed --force-anneal 50 \
  --arch fconv_wmt_en_de --save-dir checkpoints/fconv_wmt_en_de

# Generate:
$ fairseq-generate data-bin/wmt17_en_de \
  --path checkpoints/fconv_wmt_en_de/checkpoint_best.pt --beam 5 --remove-bpe
...
@@ -41,6 +41,9 @@ if [ "$1" == "--icml17" ]; then
    URLS[2]="http://statmt.org/wmt14/training-parallel-nc-v9.tgz"
    FILES[2]="training-parallel-nc-v9.tgz"
    CORPORA[2]="training/news-commentary-v9.de-en"
    OUTDIR=wmt14_en_de
else
    OUTDIR=wmt17_en_de
fi

if [ ! -d "$SCRIPTS" ]; then
@@ -51,7 +54,7 @@ fi
src=en
tgt=de
lang=en-de
prep=$OUTDIR
tmp=$prep/tmp
orig=orig
dev=dev/newstest2013
...
# Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)
This page includes instructions for reproducing results from the paper [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](https://arxiv.org/abs/1902.07816).
## Training a new model on WMT'17 En-De
First, follow the [instructions to download and preprocess the WMT'17 En-De dataset](../translation#prepare-wmt14en2desh).
Make sure to learn a joint vocabulary by passing the `--joined-dictionary` option to `fairseq-preprocess`.
Then we can train a mixture of experts model using the `translation_moe` task.
Use the `--method` option to choose the MoE variant; we support hard mixtures with a learned or uniform prior (`--method hMoElp` and `hMoEup`, respectively) and soft mixtures (`--method sMoElp` and `sMoEup`).
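As a rough summary (my paraphrase of what the `translation_moe` task below implements, not text from the paper), the variants differ in how the latent expert z in {1, ..., K} enters the training objective:
```latex
% sketch of the mixture objective (my notation)
p(y \mid x) = \sum_{z=1}^{K} p(z \mid x)\, p(y \mid z, x)

% soft mixtures: negative marginal log-likelihood, with gradients weighted by the
% responsibilities p(z \mid x, y) (see LogSumExpMoE below)
\mathcal{L}_{\text{soft}} = -\log \sum_{z} p(z \mid x)\, p(y \mid z, x)

% hard mixtures: train only the expert with the highest responsibility
\mathcal{L}_{\text{hard}} = -\max_{z} \log \big[ p(z \mid x)\, p(y \mid z, x) \big]
```
With a uniform prior (`sMoEup`, `hMoEup`) the p(z | x) term is constant and dropped; with a learned prior (`sMoElp`, `hMoElp`) it comes from a gating network over the encoder output.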
To train a hard mixture of experts model with a learned prior (`hMoElp`) on 1 GPU:
```
$ CUDA_VISIBLE_DEVICES=0 fairseq-train data-bin/wmt17_en_de \
--max-update 100000 \
--task translation_moe \
--method hMoElp --mean-pool-gating-network \
--num-experts 3 \
--arch transformer_vaswani_wmt_en_de --share-all-embeddings \
--optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
--lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 4000 \
--lr 0.0007 --min-lr 1e-09 \
--dropout 0.1 --weight-decay 0.0 --criterion cross_entropy \
--max-tokens 3584 \
--update-freq 8
```
**Note**: the above command assumes 1 GPU, but accumulates gradients from 8 fwd/bwd passes to simulate training on 8 GPUs.
You can accelerate training on up to 8 GPUs by adjusting the `CUDA_VISIBLE_DEVICES` and `--update-freq` options accordingly.
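As a rough rule of thumb (an assumption on my part, not an exact fairseq invariant), the number of target tokens per optimizer step scales with both factors, which is why 1 GPU at `--update-freq 8` roughly matches 8 GPUs at `--update-freq 1`:
```python
# approximate tokens per optimizer step = max_tokens * n_gpus * update_freq (heuristic)
assert 3584 * 1 * 8 == 3584 * 8 * 1  # 28672 tokens either way
```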
Once a model is trained, we can generate translations from different experts using the `--gen-expert` option.
For example, to generate from expert 0:
```
$ fairseq-generate data-bin/wmt17_en_de \
  --path checkpoints/checkpoint_best.pt \
  --beam 1 --remove-bpe \
  --task translation_moe \
  --method hMoElp --mean-pool-gating-network \
  --num-experts 3 \
  --gen-expert 0
```
You can also use `scripts/score_moe.py` to compute pairwise BLEU and average oracle BLEU.
We'll first download a tokenized version of the multi-reference WMT'14 En-De dataset:
```
$ wget dl.fbaipublicfiles.com/fairseq/data/wmt14-en-de.extra_refs.tok
```
Next apply BPE on the fly and run generation for each expert:
```
$ BPEROOT=examples/translation/subword-nmt/
$ BPE_CODE=examples/translation/wmt17_en_de/code
$ for EXPERT in $(seq 0 2); do \
    cat wmt14-en-de.extra_refs.tok | grep ^S | cut -f 2 | \
    python $BPEROOT/apply_bpe.py -c $BPE_CODE | \
    fairseq-interactive data-bin/wmt17_en_de \
      --path checkpoints/checkpoint_best.pt \
      --beam 1 --remove-bpe \
      --buffer-size 500 --max-tokens 6000 \
      --task translation_moe \
      --method hMoElp --mean-pool-gating-network \
      --num-experts 3 \
      --gen-expert $EXPERT ; \
  done > wmt14-en-de.extra_refs.tok.gen.3experts
```
Finally, compute pairwise BLEU and average oracle BLEU:
```
$ python scripts/score_moe.py --sys wmt14-en-de.extra_refs.tok.gen.3experts --ref wmt14-en-de.extra_refs.tok
pairwise BLEU: 48.26
avg oracle BLEU: 49.50
#refs covered: 2.11
```
This reproduces row 3 from Table 7 in the paper.
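For reference, pairwise BLEU is the average corpus BLEU obtained by scoring each expert's output against every other expert's output (lower means more diverse hypotheses). Below is a minimal sketch using `sacrebleu`; it is only an illustration, and `scripts/score_moe.py` remains the actual implementation, operating on the tokenized references above:
```python
import itertools

import sacrebleu  # assumed available: pip install sacrebleu


def pairwise_bleu(expert_hyps):
    """expert_hyps: K lists of hypothesis strings, one list per expert, aligned by sentence."""
    scores = [
        sacrebleu.corpus_bleu(sys, [ref]).score
        for sys, ref in itertools.permutations(expert_hyps, 2)
    ]
    return sum(scores) / len(scores)
```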
## Citation
```bibtex
@article{shen2019mixture,
title = {Mixture Models for Diverse Machine Translation: Tricks of the Trade},
author = {Tianxiao Shen and Myle Ott and Michael Auli and Marc'Aurelio Ranzato},
journal = {arXiv preprint arXiv:1902.07816},
year = 2019,
}
```
@@ -28,11 +28,7 @@ class CrossEntropyCriterion(FairseqCriterion):
        3) logging outputs to display while training
        """
        net_output = model(**sample['net_input'])
        loss, _ = self.compute_loss(model, net_output, sample, reduce=reduce)
        sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
        logging_output = {
            'loss': utils.item(loss.data) if reduce else loss.data,
@@ -42,6 +38,14 @@ class CrossEntropyCriterion(FairseqCriterion):
        }
        return loss, sample_size, logging_output

    def compute_loss(self, model, net_output, sample, reduce=True):
        lprobs = model.get_normalized_probs(net_output, log_probs=True)
        lprobs = lprobs.view(-1, lprobs.size(-1))
        target = model.get_targets(sample, net_output).view(-1)
        loss = F.nll_loss(lprobs, target, size_average=False, ignore_index=self.padding_idx,
                          reduce=reduce)
        return loss, loss

    @staticmethod
    def aggregate_logging_outputs(logging_outputs):
        """Aggregate logging outputs from data parallel training."""
...
@@ -17,6 +17,8 @@ from .highway import Highway
from .learned_positional_embedding import LearnedPositionalEmbedding
from .lightweight_convolution import LightweightConv1dTBC
from .linearized_convolution import LinearizedConvolution
from .logsumexp_moe import LogSumExpMoE
from .mean_pool_gating_network import MeanPoolGatingNetwork
from .multihead_attention import MultiheadAttention
from .scalar_bias import ScalarBias
from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding
@@ -35,6 +37,8 @@ __all__ = [
    'LearnedPositionalEmbedding',
    'LightweightConv1dTBC',
    'LinearizedConvolution',
    'LogSumExpMoE',
    'MeanPoolGatingNetwork',
    'MultiheadAttention',
    'ScalarBias',
    'SinusoidalPositionalEmbedding',
...
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch
class LogSumExpMoE(torch.autograd.Function):
    """Standard LogSumExp forward pass, but use *posterior* for the backward.

    See `"Mixture Models for Diverse Machine Translation: Tricks of the Trade"
    (Shen et al., 2019) <https://arxiv.org/abs/1902.07816>`_.
    """

    @staticmethod
    def forward(ctx, logp, posterior, dim=-1):
        ctx.save_for_backward(posterior)
        ctx.dim = dim
        return torch.logsumexp(logp, dim=dim)

    @staticmethod
    def backward(ctx, grad_output):
        posterior, = ctx.saved_tensors
        grad_logp = grad_output.unsqueeze(ctx.dim) * posterior
        return grad_logp, None, None
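A quick usage sketch (assuming fairseq is installed; the tensors here are made up): the forward pass matches `torch.logsumexp`, while the gradient with respect to `logp` is exactly the supplied `posterior` rather than a freshly computed softmax.
```python
import torch

from fairseq.modules import LogSumExpMoE

logp = torch.randn(4, 3, requires_grad=True)         # B x K expert log-likelihoods
posterior = torch.softmax(torch.randn(4, 3), dim=1)  # B x K responsibilities, treated as constant

out = LogSumExpMoE.apply(logp, posterior, 1)
out.sum().backward()

print(torch.allclose(out, torch.logsumexp(logp, dim=1)))  # True: standard forward pass
print(torch.allclose(logp.grad, posterior))               # True: backward uses the given posterior
```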
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch
import torch.nn.functional as F
class MeanPoolGatingNetwork(torch.nn.Module):
    """A simple mean-pooling gating network for selecting experts.

    This module applies mean pooling over an encoder's output and returns
    responsibilities for each expert. The encoder format is expected to match
    :class:`fairseq.models.transformer.TransformerEncoder`.
    """

    def __init__(self, embed_dim, num_experts, dropout=None):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_experts = num_experts

        self.fc1 = torch.nn.Linear(embed_dim, embed_dim)
        self.dropout = torch.nn.Dropout(dropout) if dropout is not None else None
        self.fc2 = torch.nn.Linear(embed_dim, num_experts)

    def forward(self, encoder_out):
        if not (
            isinstance(encoder_out, dict)
            and 'encoder_out' in encoder_out
            and 'encoder_padding_mask' in encoder_out
            and encoder_out['encoder_out'].size(2) == self.embed_dim
        ):
            raise ValueError('Unexpected format for encoder_out')

        # mean pooling over time
        encoder_padding_mask = encoder_out['encoder_padding_mask']  # B x T
        encoder_out = encoder_out['encoder_out'].transpose(0, 1)    # B x T x C
        if encoder_padding_mask is not None:
            encoder_out = encoder_out.clone()  # required because of transpose above
            encoder_out[encoder_padding_mask] = 0
            ntokens = torch.sum(1 - encoder_padding_mask, dim=1, keepdim=True)
            x = torch.sum(encoder_out, dim=1) / ntokens.type_as(encoder_out)
        else:
            x = torch.mean(encoder_out, dim=1)

        x = torch.tanh(self.fc1(x))
        if self.dropout is not None:
            x = self.dropout(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=-1, dtype=torch.float32).type_as(x)
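A minimal smoke test (illustrative only, with made-up dimensions) showing the expected input format, i.e. a `T x B x C` encoder output plus an optional `B x T` padding mask:
```python
import torch

from fairseq.modules import MeanPoolGatingNetwork

T, B, C, K = 7, 2, 16, 3
gating = MeanPoolGatingNetwork(embed_dim=C, num_experts=K, dropout=0.1)

encoder_out = {
    'encoder_out': torch.randn(T, B, C),  # T x B x C, as produced by TransformerEncoder
    'encoder_padding_mask': None,         # or a B x T mask marking padding positions
}
log_prior = gating(encoder_out)           # B x K log-probabilities over experts
print(log_prior.shape, log_prior.exp().sum(dim=1))  # torch.Size([2, 3]); each row sums to ~1
```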
@@ -98,7 +98,14 @@ class SequenceGenerator(object):
            self.search = search.BeamSearch(tgt_dict)

    @torch.no_grad()
    def generate(
        self,
        models,
        sample,
        prefix_tokens=None,
        bos_token=None,
        **kwargs
    ):
        """Generate a batch of translations.

        Args:
@@ -143,7 +150,7 @@ class SequenceGenerator(object):
        scores_buf = scores.clone()
        tokens = src_tokens.new(bsz * beam_size, max_len + 2).fill_(self.pad)
        tokens_buf = tokens.clone()
        tokens[:, 0] = bos_token or self.eos
        attn, attn_buf = None, None
        nonpad_idxs = None
...
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import contextlib

import torch

from fairseq import modules, options, utils
from fairseq.data import (
    ConcatDataset,
    data_utils,
    Dictionary,
    IndexedCachedDataset,
    IndexedDataset,
    IndexedRawTextDataset,
    LanguagePairDataset,
)

from . import register_task
from .translation import TranslationTask


@contextlib.contextmanager
def eval(model):
    is_training = model.training
    model.eval()
    yield
    model.train(is_training)


@register_task('translation_moe')
class TranslationMoETask(TranslationTask):
    """
    Translation task for Mixture of Experts (MoE) models.

    See `"Mixture Models for Diverse Machine Translation: Tricks of the Trade"
    (Shen et al., 2019) <https://arxiv.org/abs/1902.07816>`_.

    Args:
        src_dict (Dictionary): dictionary for the source language
        tgt_dict (Dictionary): dictionary for the target language

    .. note::

        The translation task is compatible with :mod:`fairseq-train`,
        :mod:`fairseq-generate` and :mod:`fairseq-interactive`.

    The translation task provides the following additional command-line
    arguments:

    .. argparse::
        :ref: fairseq.tasks.translation_parser
        :prog:
    """

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        # fmt: off
        TranslationTask.add_args(parser)
        parser.add_argument('--method', required=True,
                            choices=['sMoElp', 'sMoEup', 'hMoElp', 'hMoEup'])
        parser.add_argument('--num-experts', type=int, metavar='N', required=True,
                            help='number of experts')
        parser.add_argument('--mean-pool-gating-network', action='store_true',
                            help='use a simple mean-pooling gating network')
        parser.add_argument('--mean-pool-gating-network-dropout', type=float,
                            help='dropout for mean-pooling gating network')
        parser.add_argument('--mean-pool-gating-network-encoder-dim', type=float,
                            help='encoder output dim for mean-pooling gating network')
        parser.add_argument('--gen-expert', type=int, default=0,
                            help='which expert to use for generation')
        # fmt: on

    def __init__(self, args, src_dict, tgt_dict):
        if args.method == 'sMoElp':
            # soft MoE with learned prior
            self.uniform_prior = False
            self.hard_selection = False
        elif args.method == 'sMoEup':
            # soft MoE with uniform prior
            self.uniform_prior = True
            self.hard_selection = False
        elif args.method == 'hMoElp':
            # hard MoE with learned prior
            self.uniform_prior = False
            self.hard_selection = True
        elif args.method == 'hMoEup':
            # hard MoE with uniform prior
            self.uniform_prior = True
            self.hard_selection = True

        # add indicator tokens for each expert
        for i in range(args.num_experts):
            # add to both dictionaries in case we're sharing embeddings
            src_dict.add_symbol('<expert_{}>'.format(i))
            tgt_dict.add_symbol('<expert_{}>'.format(i))

        super().__init__(args, src_dict, tgt_dict)

    def build_model(self, args):
        from fairseq import models
        model = models.build_model(args, self)
        if not self.uniform_prior and not hasattr(model, 'gating_network'):
            if self.args.mean_pool_gating_network:
                if getattr(args, 'mean_pool_gating_network_encoder_dim', None):
                    encoder_dim = args.mean_pool_gating_network_encoder_dim
                elif getattr(args, 'encoder_embed_dim', None):
                    # assume that encoder_embed_dim is the encoder's output dimension
                    encoder_dim = args.encoder_embed_dim
                else:
                    raise ValueError('Must specify --mean-pool-gating-network-encoder-dim')

                if getattr(args, 'mean_pool_gating_network_dropout', None):
                    dropout = args.mean_pool_gating_network_dropout
                elif getattr(args, 'dropout', None):
                    dropout = args.dropout
                else:
                    raise ValueError('Must specify --mean-pool-gating-network-dropout')

                model.gating_network = modules.MeanPoolGatingNetwork(
                    encoder_dim, args.num_experts, dropout,
                )
            else:
                raise ValueError(
                    'translation_moe task with learned prior requires the model to '
                    'have a gating network; try using --mean-pool-gating-network'
                )
        return model

    def expert_index(self, i):
        return i + self.tgt_dict.index('<expert_0>')

    def _get_loss(self, sample, model, criterion):
        assert hasattr(criterion, 'compute_loss'), \
            'translation_moe task requires the criterion to implement the compute_loss() method'

        k = self.args.num_experts
        bsz = sample['target'].size(0)

        def get_lprob_y(encoder_out, prev_output_tokens_k):
            net_output = model.decoder(prev_output_tokens_k, encoder_out)
            loss, _ = criterion.compute_loss(model, net_output, sample, reduce=False)
            loss = loss.view(bsz, -1)
            return -loss.sum(dim=1, keepdim=True)  # -> B x 1

        def get_lprob_yz(winners=None):
            encoder_out = model.encoder(sample['net_input']['src_tokens'], sample['net_input']['src_lengths'])

            if winners is None:
                lprob_y = []
                for i in range(k):
                    prev_output_tokens_k = sample['net_input']['prev_output_tokens'].clone()
                    assert not prev_output_tokens_k.requires_grad
                    prev_output_tokens_k[:, 0] = self.expert_index(i)
                    lprob_y.append(get_lprob_y(encoder_out, prev_output_tokens_k))
                lprob_y = torch.cat(lprob_y, dim=1)  # -> B x K
            else:
                prev_output_tokens_k = sample['net_input']['prev_output_tokens'].clone()
                prev_output_tokens_k[:, 0] = self.expert_index(winners)
                lprob_y = get_lprob_y(encoder_out, prev_output_tokens_k)  # -> B

            if self.uniform_prior:
                lprob_yz = lprob_y
            else:
                lprob_z = model.gating_network(encoder_out)  # B x K
                if winners is not None:
                    lprob_z = lprob_z.gather(dim=1, index=winners.unsqueeze(-1))
                lprob_yz = lprob_y + lprob_z.type_as(lprob_y)  # B x K

            return lprob_yz

        # compute responsibilities without dropout
        with eval(model):  # disable dropout
            with torch.no_grad():  # disable autograd
                lprob_yz = get_lprob_yz()  # B x K
                prob_z_xy = torch.nn.functional.softmax(lprob_yz, dim=1)
        assert not prob_z_xy.requires_grad

        # compute loss with dropout
        if self.hard_selection:
            winners = prob_z_xy.max(dim=1)[1]
            loss = -get_lprob_yz(winners)
        else:
            lprob_yz = get_lprob_yz()  # B x K
            loss = -modules.LogSumExpMoE.apply(lprob_yz, prob_z_xy, 1)

        loss = loss.sum()
        sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
        logging_output = {
            'loss': utils.item(loss.data),
            'ntokens': sample['ntokens'],
            'sample_size': sample_size,
            'posterior': prob_z_xy.float().sum(dim=0).cpu(),
        }
        return loss, sample_size, logging_output

    def train_step(self, sample, model, criterion, optimizer, ignore_grad=False):
        model.train()
        loss, sample_size, logging_output = self._get_loss(sample, model, criterion)
        if ignore_grad:
            loss *= 0
        optimizer.backward(loss)
        return loss, sample_size, logging_output

    def valid_step(self, sample, model, criterion):
        model.eval()
        with torch.no_grad():
            loss, sample_size, logging_output = self._get_loss(sample, model, criterion)
        return loss, sample_size, logging_output

    def inference_step(self, generator, models, sample, prefix_tokens=None):
        with torch.no_grad():
            return generator.generate(
                models,
                sample,
                prefix_tokens=prefix_tokens,
                bos_token=self.expert_index(self.args.gen_expert),
            )

    def aggregate_logging_outputs(self, logging_outputs, criterion):
        agg_logging_outputs = criterion._aggregate_logging_outputs(logging_outputs)
        agg_logging_outputs['posterior'] = sum(
            log['posterior'] for log in logging_outputs if 'posterior' in log
        )
        return agg_logging_outputs
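To make the E-step/M-step structure of `_get_loss` above concrete, here is a self-contained toy version on random tensors (an illustration of the idea, not fairseq code):
```python
import torch

from fairseq.modules import LogSumExpMoE  # assumed importable

B, K = 4, 3
lprob_yz = torch.randn(B, K, requires_grad=True)  # stand-in for log p(y, z | x), one column per expert

# E-step: responsibilities p(z | x, y), computed without gradients (and without dropout in fairseq)
with torch.no_grad():
    prob_z_xy = torch.softmax(lprob_yz, dim=1)

# M-step, hard selection: train only the winning expert for each sentence
winners = prob_z_xy.max(dim=1)[1]
hard_loss = -lprob_yz.gather(1, winners.unsqueeze(1)).sum()

# M-step, soft selection: logsumexp forward, posterior-weighted backward
soft_loss = -LogSumExpMoE.apply(lprob_yz, prob_z_xy, 1).sum()
```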