Commit 8fcdb9b7 authored by Myle Ott

Fix Flake8

parent 2a84f46b
@@ -5,7 +5,6 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.
-import math
 import pickle
 import torch.distributed
...
@@ -9,8 +9,6 @@
 Train a network on multiple GPUs.
 """
-import math
 import torch
 from fairseq import optim
...
@@ -357,6 +357,19 @@ class FConvDecoder(FairseqIncrementalDecoder):
         return x, avg_attn_scores

+    def reorder_incremental_state(self, incremental_state, new_order):
+        super().reorder_incremental_state(incremental_state, new_order)
+        encoder_out = utils.get_incremental_state(self, incremental_state, 'encoder_out')
+        if encoder_out is not None:
+            encoder_out = tuple(eo.index_select(0, new_order) for eo in encoder_out)
+            utils.set_incremental_state(self, incremental_state, 'encoder_out', encoder_out)
+
+    def reorder_encoder_out(self, encoder_out_dict, new_order):
+        if encoder_out_dict['encoder_padding_mask'] is not None:
+            encoder_out_dict['encoder_padding_mask'] = \
+                encoder_out_dict['encoder_padding_mask'].index_select(0, new_order)
+        return encoder_out_dict
+
     def max_positions(self):
         """Maximum output length supported by the decoder."""
         return self.embed_positions.max_positions()
@@ -401,23 +414,6 @@ class FConvDecoder(FairseqIncrementalDecoder):
         return x

-    def reorder_incremental_state(self, incremental_state, new_order):
-        super().reorder_incremental_state(incremental_state, new_order)
-        encoder_out = utils.get_incremental_state(self, incremental_state, 'encoder_out')
-        if encoder_out is not None:
-            def update_enc_out(enc_out):
-                return enc_out.index_select(0, new_order)
-            encoder_out = tuple([update_enc_out(eo) for eo in encoder_out])
-            utils.set_incremental_state(self, incremental_state, 'encoder_out', encoder_out)
-
-    def reorder_encoder_out(self, encoder_out_dict, new_order):
-        if encoder_out_dict['encoder_padding_mask'] is not None:
-            encoder_out_dict['encoder_padding_mask'] = encoder_out_dict['encoder_padding_mask'].index_select(0, new_order)
-        return encoder_out_dict
-
 def Embedding(num_embeddings, embedding_dim, padding_idx):
     m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
     m.weight.data.normal_(0, 0.1)
...
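Note: the hunks above move `reorder_incremental_state` and `reorder_encoder_out` up next to the decoder's forward pass and replace the nested `update_enc_out` helper with a generator expression. The reordering itself is just a gather along the batch dimension; the following standalone sketch (shapes and values are illustrative assumptions, not taken from the commit) shows what `index_select(0, new_order)` does to a tuple of batch-first encoder outputs:

import torch

# Two illustrative batch-first "encoder_out" tensors: batch 4, length 10, dim 8.
encoder_out = (torch.randn(4, 10, 8), torch.randn(4, 10, 8))

# After a beam-search step, new_order says which old hypothesis each surviving
# hypothesis came from, e.g. keep items 0 and 2 and duplicate item 2.
new_order = torch.tensor([0, 2, 2])

# Equivalent of the generator expression introduced in this diff:
reordered = tuple(eo.index_select(0, new_order) for eo in encoder_out)

print([t.shape for t in reordered])  # [torch.Size([3, 10, 8]), torch.Size([3, 10, 8])]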
@@ -373,6 +373,7 @@ class LSTMDecoder(FairseqIncrementalDecoder):
         return x, attn_scores

     def reorder_incremental_state(self, incremental_state, new_order):
+        super().reorder_incremental_state(incremental_state, new_order)
         cached_state = utils.get_incremental_state(self, incremental_state, 'cached_state')
         if cached_state is None:
             return
@@ -382,16 +383,17 @@ class LSTMDecoder(FairseqIncrementalDecoder):
                 return [reorder_state(state_i) for state_i in state]
             return state.index_select(0, new_order)

-        if not isinstance(new_order, Variable):
-            new_order = Variable(new_order)
         new_state = tuple(map(reorder_state, cached_state))
         utils.set_incremental_state(self, incremental_state, 'cached_state', new_state)

     def reorder_encoder_out(self, encoder_out_dict, new_order):
         encoder_out_dict['encoder_out'] = tuple(
-            eo.index_select(1, new_order) for eo in encoder_out_dict['encoder_out'])
+            eo.index_select(1, new_order)
+            for eo in encoder_out_dict['encoder_out']
+        )
         if encoder_out_dict['encoder_padding_mask'] is not None:
-            encoder_out_dict['encoder_padding_mask'] = encoder_out_dict['encoder_padding_mask'].index_select(1, new_order)
+            encoder_out_dict['encoder_padding_mask'] = \
+                encoder_out_dict['encoder_padding_mask'].index_select(1, new_order)
         return encoder_out_dict

     def max_positions(self):
...
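In the LSTM decoder hunks above, reordering happens along dim 1 rather than dim 0: the encoder outputs and padding mask appear to be laid out time-major here, so hypotheses live in the second dimension. A minimal sketch under that assumption (shapes are illustrative, not taken from the commit):

import torch

# Assumed time-major layout: seq_len x batch x hidden for the outputs,
# seq_len x batch for the padding mask.
encoder_out = torch.randn(10, 4, 8)
encoder_padding_mask = torch.zeros(10, 4, dtype=torch.bool)

new_order = torch.tensor([0, 2, 2])  # surviving/duplicated beam hypotheses

reordered_out = encoder_out.index_select(1, new_order)            # 10 x 3 x 8
reordered_mask = encoder_padding_mask.index_select(1, new_order)  # 10 x 3

print(reordered_out.shape, reordered_mask.shape)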
@@ -6,6 +6,7 @@
 # can be found in the PATENTS file in the same directory.
 import math
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -15,7 +16,6 @@ from fairseq.modules import (
     LearnedPositionalEmbedding, MultiheadAttention,
     SinusoidalPositionalEmbedding,
 )
-from fairseq import utils
 from . import (
     FairseqIncrementalDecoder, FairseqEncoder, FairseqModel,
@@ -220,6 +220,12 @@ class TransformerDecoder(FairseqIncrementalDecoder):
         return x, attn

+    def reorder_encoder_out(self, encoder_out_dict, new_order):
+        if encoder_out_dict['encoder_padding_mask'] is not None:
+            encoder_out_dict['encoder_padding_mask'] = \
+                encoder_out_dict['encoder_padding_mask'].index_select(0, new_order)
+        return encoder_out_dict
+
     def max_positions(self):
         """Maximum output length supported by the decoder."""
         return self.embed_positions.max_positions()
@@ -233,11 +239,6 @@ class TransformerDecoder(FairseqIncrementalDecoder):
         return state_dict

-    def reorder_encoder_out(self, encoder_out, new_order):
-        if encoder_out['encoder_padding_mask'] is not None:
-            encoder_out['encoder_padding_mask'] = encoder_out['encoder_padding_mask'].index_select(0, new_order)
-        return encoder_out
-
 class TransformerEncoderLayer(nn.Module):
     """Encoder layer block.
@@ -312,7 +313,6 @@ class TransformerDecoderLayer(nn.Module):
     def forward(self, x, encoder_out, encoder_padding_mask, incremental_state):
         residual = x
         x = self.maybe_layer_norm(0, x, before=True)
         x, _ = self.self_attn(
             query=x,
             key=x,
@@ -327,7 +327,6 @@ class TransformerDecoderLayer(nn.Module):
         residual = x
         x = self.maybe_layer_norm(1, x, before=True)
         x, attn = self.encoder_attn(
             query=x,
             key=encoder_out,
@@ -336,7 +335,6 @@ class TransformerDecoderLayer(nn.Module):
             incremental_state=incremental_state,
             static_kv=True,
         )
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
         x = self.maybe_layer_norm(1, x, after=True)
...
@@ -10,8 +10,6 @@ import re
 import torch
-from fairseq import dictionary

 SPACE_NORMALIZER = re.compile("\s+")
...
@@ -11,7 +11,7 @@ Train a network across multiple GPUs.
 from collections import defaultdict, OrderedDict
 from itertools import chain
-import math
 import torch
 from fairseq import distributed_utils, optim, utils
...
@@ -7,7 +7,6 @@
 # can be found in the PATENTS file in the same directory.
 import collections
-import itertools
 import os
 import math
 import torch
...