Commit 48d9afbe authored by Myle Ott, committed by Facebook Github Bot

Speed improvements (#531)

Summary:
* Add FusedLayerNorm and FusedAdam
* Softmax and zero grad optimizations
Pull Request resolved: https://github.com/pytorch/fairseq/pull/531

Differential Revision: D14218457

Pulled By: myleott

fbshipit-source-id: 5656b2d0152cd85f77dc21ec0e1439ec04b9fa89
parent a24880bd
@@ -36,12 +36,12 @@ translation and language modeling datasets.
 ![Model](fairseq.gif)
 
 # Requirements and Installation
-* A [PyTorch installation](http://pytorch.org/)
+* [PyTorch](http://pytorch.org/) version >= 1.0.0
+* Python version >= 3.6
 * For training new models, you'll also need an NVIDIA GPU and [NCCL](https://github.com/NVIDIA/nccl)
-* Python version 3.6
 
-Currently fairseq requires PyTorch version >= 1.0.0.
-Please follow the instructions here: https://github.com/pytorch/pytorch#installation.
+Please follow the instructions here to install PyTorch: https://github.com/pytorch/pytorch#installation.
 
 If you use Docker make sure to increase the shared memory size either with
 `--ipc=host` or `--shm-size` as command line options to `nvidia-docker run`.
@@ -60,6 +60,12 @@ cd fairseq
 pip install --editable .
 ```
 
+**Improved training speed**
+
+Training speed can be further improved by installing NVIDIA's
+[apex](https://github.com/NVIDIA/apex) library with the `--cuda_ext` option.
+fairseq will automatically switch to the faster modules provided by apex.
+
 # Getting Started
 
 The [full documentation](https://fairseq.readthedocs.io/) contains instructions
...
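A quick way to confirm that the apex fast paths this commit introduces will actually be picked up is to probe for the modules it imports; this is a convenience sketch, not part of the commit or the README:

```python
# Sketch: check whether the apex modules fairseq now prefers are importable.
try:
    from apex.normalization import FusedLayerNorm  # used by the LayerNorm wrapper below
    from apex.optimizers import FusedAdam          # used by FairseqAdam below
    print('apex fused kernels found; fairseq will use them')
except ImportError:
    print('apex not installed; fairseq falls back to the standard PyTorch modules')
```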
@@ -122,8 +122,10 @@ def all_gather_list(data, group=None, max_size=16384):
     if not hasattr(all_gather_list, '_buffer') or \
             all_gather_list._buffer.numel() < buffer_size:
         all_gather_list._buffer = torch.cuda.ByteTensor(buffer_size)
+        all_gather_list._cpu_buffer = torch.ByteTensor(max_size).pin_memory()
     buffer = all_gather_list._buffer
     buffer.zero_()
+    cpu_buffer = all_gather_list._cpu_buffer
 
     enc = pickle.dumps(data)
     enc_size = len(enc)
@@ -131,10 +133,12 @@ def all_gather_list(data, group=None, max_size=16384):
         raise ValueError('encoded data exceeds max_size: {}'.format(enc_size + 2))
     assert max_size < 255*256
 
-    buffer_rank = buffer[rank * max_size : (rank + 1) * max_size]
-    buffer_rank[0] = enc_size // 255  # this encoding works for max_size < 65k
-    buffer_rank[1] = enc_size % 255
-    buffer_rank[2:enc_size+2] = torch.ByteTensor(list(enc))
+    cpu_buffer[0] = enc_size // 255  # this encoding works for max_size < 65k
+    cpu_buffer[1] = enc_size % 255
+    cpu_buffer[2 : enc_size + 2] = torch.ByteTensor(list(enc))
+    start = rank * max_size
+    size = enc_size + 2
+    buffer[start : start + size].copy_(cpu_buffer[:size])
 
     all_reduce(buffer, group=group)
@@ -144,9 +148,7 @@ def all_gather_list(data, group=None, max_size=16384):
             out_buffer = buffer[i * max_size : (i + 1) * max_size]
             size = (255 * utils.item(out_buffer[0])) + utils.item(out_buffer[1])
             if size > 0:
-                result.append(
-                    pickle.loads(bytes(out_buffer[2:size+2].tolist()))
-                )
+                result.append(pickle.loads(bytes(out_buffer[2 : size + 2].tolist())))
         return result
     except pickle.UnpicklingError:
         raise Exception(
...
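For reference, the two-byte header written into the buffer above stores the pickled payload length as `255 * buffer[0] + buffer[1]`, which is why the code asserts `max_size < 255*256`. A minimal standalone sketch of that encoding (illustrative only, not part of the commit):

```python
# Sketch: encode and decode the payload size the same way all_gather_list does.
import pickle

data = {'loss': 1.23, 'ntokens': 4096}
enc = pickle.dumps(data)
enc_size = len(enc)
assert enc_size + 2 <= 16384, 'encoded data exceeds max_size'

header = bytes([enc_size // 255, enc_size % 255])   # 2-byte size header
decoded_size = 255 * header[0] + header[1]          # receiver recovers the length
assert decoded_size == enc_size
assert pickle.loads(enc[:decoded_size]) == data
```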
@@ -6,7 +6,8 @@
 # can be found in the PATENTS file in the same directory.
 
 import torch.nn as nn
-import torch.nn.functional as F
+
+from fairseq import utils
 
 
 class FairseqDecoder(nn.Module):
@@ -15,6 +16,7 @@ class FairseqDecoder(nn.Module):
     def __init__(self, dictionary):
         super().__init__()
         self.dictionary = dictionary
+        self.onnx_trace = False
 
     def forward(self, prev_output_tokens, encoder_out):
         """
@@ -33,6 +35,9 @@ class FairseqDecoder(nn.Module):
         """
         raise NotImplementedError
 
+    def prepare_for_onnx_export_(self):
+        self.onnx_trace = True
+
     def get_normalized_probs(self, net_output, log_probs, sample):
         """Get normalized probabilities (or log probs) from a net's output."""
@@ -45,11 +50,11 @@
             out = self.adaptive_softmax.get_log_prob(net_output[0], target=target)
             return out.exp_() if not log_probs else out
 
-        logits = net_output[0].float()
+        logits = net_output[0]
         if log_probs:
-            return F.log_softmax(logits, dim=-1)
+            return utils.log_softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
         else:
-            return F.softmax(logits, dim=-1)
+            return utils.softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
 
     def max_positions(self):
         """Maximum input length supported by the decoder."""
...
@@ -13,8 +13,8 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from fairseq.modules import (
-    DownsampledMultiHeadAttention, GradMultiply, LearnedPositionalEmbedding,
-    LinearizedConvolution,
+    DownsampledMultiHeadAttention, GradMultiply, LayerNorm,
+    LearnedPositionalEmbedding, LinearizedConvolution,
 )
 from fairseq import utils
@@ -351,13 +351,13 @@ class FConvDecoder(FairseqDecoder):
             # pretrained and trained models are joined
             self.joining = nn.Sequential(
                 Linear(out_embed_dim*2, out_embed_dim*2),
-                nn.LayerNorm(out_embed_dim*2),
+                LayerNorm(out_embed_dim*2),
                 nn.GLU(),
                 Linear(out_embed_dim, out_embed_dim*2),
-                nn.LayerNorm(out_embed_dim*2),
+                LayerNorm(out_embed_dim*2),
                 nn.GLU(),
                 Linear(out_embed_dim, out_embed_dim),
-                nn.LayerNorm(out_embed_dim)
+                LayerNorm(out_embed_dim)
             )
             # pretrained model contains an output layer that is nhid -> vocab size
             # but the models are combined in their hidden state
@@ -470,7 +470,7 @@ class SelfAttention(nn.Module):
         self.in_proj_q = Linear(out_channels, embed_dim)
         self.in_proj_k = Linear(out_channels, embed_dim)
         self.in_proj_v = Linear(out_channels, embed_dim)
-        self.ln = nn.LayerNorm(out_channels)
+        self.ln = LayerNorm(out_channels)
 
     def forward(self, x):
         residual = x
...
@@ -11,17 +11,16 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from fairseq import options
-from fairseq import utils
+from fairseq import options, utils
 from fairseq.modules import (
-    AdaptiveInput, AdaptiveSoftmax, CharacterTokenEmbedder, LearnedPositionalEmbedding, MultiheadAttention,
-    SinusoidalPositionalEmbedding, DynamicConv1dTBC, LightweightConv1dTBC
+    AdaptiveInput, AdaptiveSoftmax, CharacterTokenEmbedder, LayerNorm,
+    LearnedPositionalEmbedding, MultiheadAttention, SinusoidalPositionalEmbedding,
+    DynamicConv1dTBC, LightweightConv1dTBC,
 )
 
 from . import (
-    FairseqIncrementalDecoder, FairseqEncoder, FairseqLanguageModel, FairseqModel, register_model,
-    register_model_architecture,
+    FairseqIncrementalDecoder, FairseqEncoder, FairseqLanguageModel,
+    FairseqModel, register_model, register_model_architecture,
 )
@@ -771,11 +770,6 @@ def Embedding(num_embeddings, embedding_dim, padding_idx):
     return m
 
 
-def LayerNorm(embedding_dim):
-    m = nn.LayerNorm(embedding_dim)
-    return m
-
-
 def Linear(in_features, out_features, bias=True):
     m = nn.Linear(in_features, out_features, bias)
     nn.init.xavier_uniform_(m.weight)
...
@@ -11,17 +11,15 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from fairseq import options
-from fairseq import utils
+from fairseq import options, utils
 from fairseq.modules import (
-    AdaptiveInput, AdaptiveSoftmax, CharacterTokenEmbedder, LearnedPositionalEmbedding, MultiheadAttention,
-    SinusoidalPositionalEmbedding
+    AdaptiveInput, AdaptiveSoftmax, CharacterTokenEmbedder, LayerNorm,
+    LearnedPositionalEmbedding, MultiheadAttention, SinusoidalPositionalEmbedding,
 )
 
 from . import (
-    FairseqIncrementalDecoder, FairseqEncoder, FairseqLanguageModel, FairseqModel, register_model,
-    register_model_architecture,
+    FairseqIncrementalDecoder, FairseqEncoder, FairseqLanguageModel,
+    FairseqModel, register_model, register_model_architecture,
 )
@@ -766,11 +764,6 @@ def Embedding(num_embeddings, embedding_dim, padding_idx):
     return m
 
 
-def LayerNorm(embedding_dim):
-    m = nn.LayerNorm(embedding_dim)
-    return m
-
-
 def Linear(in_features, out_features, bias=True):
     m = nn.Linear(in_features, out_features, bias)
     nn.init.xavier_uniform_(m.weight)
...
@@ -14,6 +14,7 @@ from .downsampled_multihead_attention import DownsampledMultiHeadAttention
 from .dynamic_convolution import DynamicConv1dTBC
 from .grad_multiply import GradMultiply
 from .highway import Highway
+from .layer_norm import LayerNorm
 from .learned_positional_embedding import LearnedPositionalEmbedding
 from .lightweight_convolution import LightweightConv1dTBC
 from .linearized_convolution import LinearizedConvolution
@@ -34,6 +35,7 @@ __all__ = [
     'DynamicConv1dTBC',
     'GradMultiply',
     'Highway',
+    'LayerNorm',
     'LearnedPositionalEmbedding',
     'LightweightConv1dTBC',
     'LinearizedConvolution',
...
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import torch


def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True):
    if torch.cuda.is_available():
        try:
            from apex.normalization import FusedLayerNorm
            return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
        except ImportError:
            pass
    return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
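A hypothetical usage sketch (not from the commit): callers treat the factory above exactly like `nn.LayerNorm`; only the backing implementation changes when apex and a GPU are present.

```python
# Sketch: the wrapper is a drop-in replacement for nn.LayerNorm.
import torch
from fairseq.modules import LayerNorm  # the factory defined above

ln = LayerNorm(512)           # FusedLayerNorm if apex + CUDA are available, else nn.LayerNorm
x = torch.randn(10, 2, 512)   # (time, batch, channels) layout used throughout fairseq
assert ln(x).shape == x.shape
```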
@@ -6,6 +6,7 @@
 # can be found in the PATENTS file in the same directory.
 
 import math
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -121,6 +122,8 @@ class LightweightConv1dTBC(nn.Module):
 
         self.reset_parameters()
 
+        self.onnx_trace = False
+
     def reset_parameters(self):
         nn.init.xavier_uniform_(self.weight)
         if self.bias is not None:
@@ -144,6 +147,9 @@ class LightweightConv1dTBC(nn.Module):
             output = output + self.bias.view(1, 1, -1)
         return output
 
+    def prepare_for_onnx_export_(self):
+        self.onnx_trace = True
+
     def _forward_unfolded(self, x, incremental_state):
         '''The conventional implementation of convolutions.
         Unfolding the input by having a window shifting to the right.'''
@@ -167,7 +173,7 @@ class LightweightConv1dTBC(nn.Module):
         x_unfold = x_unfold.view(T*B*H, R, K)
 
         if self.weight_softmax:
-            weight = F.softmax(weight.float(), dim=1).type_as(weight)
+            weight = utils.softmax(weight, dim=1, onnx_trace=self.onnx_trace).type_as(weight)
 
         if incremental_state is not None:
             weight = weight[:, -x_unfold.size(2):]
@@ -192,7 +198,7 @@ class LightweightConv1dTBC(nn.Module):
         weight = self.weight.view(H, K)
         if self.weight_softmax:
-            weight = F.softmax(weight.float(), dim=1).type_as(weight)
+            weight = utils.softmax(weight, dim=1, onnx_trace=self.onnx_trace).type_as(weight)
         weight = weight.view(1, H, K).expand(T*B, H, K).contiguous()
         weight = weight.view(T, B*H, K).transpose(0, 1)
...
@@ -184,7 +184,9 @@ class MultiheadAttention(nn.Module):
             ).type_as(attn_weights)  # FP16 support: cast to float and back
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
 
-        attn_weights = F.softmax(attn_weights.float(), dim=-1).type_as(attn_weights)
+        attn_weights = utils.softmax(
+            attn_weights, dim=-1, onnx_trace=self.onnx_trace,
+        ).type_as(attn_weights)
         attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)
 
         attn = torch.bmm(attn_weights, v)
...
@@ -16,7 +16,11 @@ from . import FairseqOptimizer, register_optimizer
 class FairseqAdam(FairseqOptimizer):
 
     def __init__(self, args, params):
         super().__init__(args, params)
-        self._optimizer = Adam(params, **self.optimizer_config)
+        try:
+            from apex.optimizers import FusedAdam
+            self._optimizer = FusedAdam(params, **self.optimizer_config)
+        except ImportError:
+            self._optimizer = Adam(params, **self.optimizer_config)
 
     @staticmethod
     def add_args(parser):
...
@@ -92,4 +92,7 @@ class FairseqOptimizer(object):
     def zero_grad(self):
         """Clears the gradients of all optimized parameters."""
+        for group in self.optimizer.param_groups:
+            for p in group['params']:
+                p.grad = None
         self.optimizer.zero_grad()
@@ -205,11 +205,8 @@ class FP16Optimizer(optim.FairseqOptimizer):
     def zero_grad(self):
         """Clears the gradients of all optimized parameters."""
-        self.fp32_optimizer.zero_grad()
         for p in self.params:
-            if p.grad is not None:
-                p.grad.detach_()
-                p.grad.zero_()
+            p.grad = None
         self._needs_sync = False
...
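The idea behind the zero-grad changes above, as a minimal standalone sketch (illustrative only, not part of the commit): setting `.grad` to `None` skips the per-parameter fill kernels, and the next backward pass simply allocates fresh gradient tensors instead of accumulating into zeroed ones.

```python
# Sketch: the two ways of clearing gradients that this commit switches between.
import torch

model = torch.nn.Linear(8, 8)
model(torch.randn(4, 8)).sum().backward()

# Old path: detach and zero every gradient tensor in place (one fill kernel each).
for p in model.parameters():
    if p.grad is not None:
        p.grad.detach_()
        p.grad.zero_()

# New path: drop the gradients entirely; backward() recreates them on demand.
for p in model.parameters():
    p.grad = None
```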
@@ -4,15 +4,17 @@
 # This source code is licensed under the license found in the LICENSE file in
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.
 
-from collections import defaultdict, OrderedDict
 import importlib.util
 import logging
 import os
 import re
 import sys
 import traceback
+from collections import defaultdict, OrderedDict
 
 import torch
+import torch.nn.functional as F
 from torch.serialization import default_restore_location
@@ -447,3 +449,17 @@ def import_user_module(args):
         sys.path.insert(0, module_parent)
         importlib.import_module(module_name)
         sys.path.pop(0)
+
+
+def softmax(x, dim, onnx_trace=False):
+    if onnx_trace:
+        return F.softmax(x.float(), dim=dim)
+    else:
+        return F.softmax(x, dim=dim, dtype=torch.float32)
+
+
+def log_softmax(x, dim, onnx_trace=False):
+    if onnx_trace:
+        return F.log_softmax(x.float(), dim=dim)
+    else:
+        return F.log_softmax(x, dim=dim, dtype=torch.float32)
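A small sketch (not part of the commit) checking that the two branches of `utils.softmax` agree: the normal path asks PyTorch to do the float32 computation via the `dtype` argument, while the ONNX-trace path keeps the explicit cast that traces cleanly.

```python
# Sketch: both code paths in utils.softmax produce the same probabilities.
import torch
import torch.nn.functional as F

x = torch.randn(2, 5)
fused = F.softmax(x, dim=-1, dtype=torch.float32)   # normal path (onnx_trace=False)
casted = F.softmax(x.float(), dim=-1)               # ONNX-trace path (explicit cast)
assert torch.allclose(fused, casted)
```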