Commit 8fcdb9b7 authored by Myle Ott

Fix Flake8

parent 2a84f46b
......@@ -5,7 +5,6 @@
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import math
import pickle
import torch.distributed
......
......@@ -9,8 +9,6 @@
Train a network on multiple GPUs.
"""
import math
import torch
from fairseq import optim
......
......@@ -357,6 +357,19 @@ class FConvDecoder(FairseqIncrementalDecoder):
        return x, avg_attn_scores

    def reorder_incremental_state(self, incremental_state, new_order):
        super().reorder_incremental_state(incremental_state, new_order)
        encoder_out = utils.get_incremental_state(self, incremental_state, 'encoder_out')
        if encoder_out is not None:
            encoder_out = tuple(eo.index_select(0, new_order) for eo in encoder_out)
            utils.set_incremental_state(self, incremental_state, 'encoder_out', encoder_out)

    def reorder_encoder_out(self, encoder_out_dict, new_order):
        if encoder_out_dict['encoder_padding_mask'] is not None:
            encoder_out_dict['encoder_padding_mask'] = \
                encoder_out_dict['encoder_padding_mask'].index_select(0, new_order)
        return encoder_out_dict

    def max_positions(self):
        """Maximum output length supported by the decoder."""
        return self.embed_positions.max_positions()
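As background rather than part of this commit: these hooks exist because, during incremental (beam-search) decoding, the generator reorders the batch of active hypotheses at every step, and any tensors cached per hypothesis have to follow that reordering. A minimal, self-contained sketch of the index_select pattern the methods above use, with toy tensors standing in for fairseq's cached encoder outputs:

import torch

# Toy stand-ins for cached encoder state: batch-first tensors for a batch of 3.
encoder_out = (torch.arange(12, dtype=torch.float32).view(3, 4),)
encoder_padding_mask = torch.tensor([[False, False, True, True],
                                     [False, True, True, True],
                                     [False, False, False, False]])

# Suppose beam search now wants hypothesis 2 first, then hypothesis 0 twice.
new_order = torch.tensor([2, 0, 0])

# Same operation as reorder_incremental_state / reorder_encoder_out above:
# select rows along the batch dimension (dim 0 for these batch-first tensors).
encoder_out = tuple(eo.index_select(0, new_order) for eo in encoder_out)
encoder_padding_mask = encoder_padding_mask.index_select(0, new_order)

print(encoder_out[0])        # rows now ordered 2, 0, 0
print(encoder_padding_mask)  # mask stays aligned with the reordered hypotheses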
......@@ -401,23 +414,6 @@ class FConvDecoder(FairseqIncrementalDecoder):
        return x

    def reorder_incremental_state(self, incremental_state, new_order):
        super().reorder_incremental_state(incremental_state, new_order)
        encoder_out = utils.get_incremental_state(self, incremental_state, 'encoder_out')
        if encoder_out is not None:
            def update_enc_out(enc_out):
                return enc_out.index_select(0, new_order)
            encoder_out = tuple([update_enc_out(eo) for eo in encoder_out])
            utils.set_incremental_state(self, incremental_state, 'encoder_out', encoder_out)

    def reorder_encoder_out(self, encoder_out_dict, new_order):
        if encoder_out_dict['encoder_padding_mask'] is not None:
            encoder_out_dict['encoder_padding_mask'] = encoder_out_dict['encoder_padding_mask'].index_select(0, new_order)
        return encoder_out_dict


def Embedding(num_embeddings, embedding_dim, padding_idx):
    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    m.weight.data.normal_(0, 0.1)
......
......@@ -373,6 +373,7 @@ class LSTMDecoder(FairseqIncrementalDecoder):
        return x, attn_scores

    def reorder_incremental_state(self, incremental_state, new_order):
        super().reorder_incremental_state(incremental_state, new_order)
        cached_state = utils.get_incremental_state(self, incremental_state, 'cached_state')
        if cached_state is None:
            return
......@@ -382,16 +383,17 @@ class LSTMDecoder(FairseqIncrementalDecoder):
                return [reorder_state(state_i) for state_i in state]
            return state.index_select(0, new_order)

        if not isinstance(new_order, Variable):
            new_order = Variable(new_order)

        new_state = tuple(map(reorder_state, cached_state))
        utils.set_incremental_state(self, incremental_state, 'cached_state', new_state)

    def reorder_encoder_out(self, encoder_out_dict, new_order):
        encoder_out_dict['encoder_out'] = tuple(
            eo.index_select(1, new_order) for eo in encoder_out_dict['encoder_out'])
            eo.index_select(1, new_order)
            for eo in encoder_out_dict['encoder_out']
        )
        if encoder_out_dict['encoder_padding_mask'] is not None:
            encoder_out_dict['encoder_padding_mask'] = encoder_out_dict['encoder_padding_mask'].index_select(1, new_order)
            encoder_out_dict['encoder_padding_mask'] = \
                encoder_out_dict['encoder_padding_mask'].index_select(1, new_order)
        return encoder_out_dict

    def max_positions(self):
......
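Also for context (not in the diff): the LSTM decoder's cached_state is a nested structure of per-layer hidden and cell states, hence the recursive reorder_state helper above; note too that the LSTM encoder outputs are time-major (seq_len x batch x channels), which is why they are reordered along dim 1 while cached_state uses dim 0. A rough sketch of the same pattern with plain tensors, names illustrative only:

import torch

def reorder_state(state, new_order):
    # Recursively index_select the batch dimension (0) of every tensor in a nested state.
    if isinstance(state, (list, tuple)):
        return [reorder_state(s, new_order) for s in state]
    return state.index_select(0, new_order)

# Toy cached state: (hidden states per layer, cell states per layer) for a batch of 3.
cached_state = (
    [torch.randn(3, 2), torch.randn(3, 2)],
    [torch.randn(3, 2), torch.randn(3, 2)],
)
new_order = torch.tensor([1, 1, 0])
new_state = tuple(reorder_state(s, new_order) for s in cached_state)

# Time-major encoder outputs (seq_len x batch x channels) are reordered on dim 1.
encoder_out = torch.randn(5, 3, 2)
encoder_out = encoder_out.index_select(1, new_order)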
......@@ -6,6 +6,7 @@
# can be found in the PATENTS file in the same directory.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
......@@ -15,7 +16,6 @@ from fairseq.modules import (
    LearnedPositionalEmbedding, MultiheadAttention,
    SinusoidalPositionalEmbedding,
)
from fairseq import utils

from . import (
    FairseqIncrementalDecoder, FairseqEncoder, FairseqModel,
......@@ -220,6 +220,12 @@ class TransformerDecoder(FairseqIncrementalDecoder):
        return x, attn

    def reorder_encoder_out(self, encoder_out_dict, new_order):
        if encoder_out_dict['encoder_padding_mask'] is not None:
            encoder_out_dict['encoder_padding_mask'] = \
                encoder_out_dict['encoder_padding_mask'].index_select(0, new_order)
        return encoder_out_dict

    def max_positions(self):
        """Maximum output length supported by the decoder."""
        return self.embed_positions.max_positions()
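Background on where new_order comes from (not shown in this commit): the sequence generator typically expands each source sentence beam_size times before decoding and then, at every step, keeps only the surviving hypotheses. A hypothetical sketch of building such index vectors and applying them the same way the reorder hooks above do:

import torch

batch_size, beam_size = 2, 3

# Expand a batch for beam search: sentence indices repeated beam_size times each,
# i.e. [0, 0, 0, 1, 1, 1] for batch_size=2, beam_size=3.
expand_order = torch.arange(batch_size).repeat_interleave(beam_size)

# Later, suppose these rows of the expanded batch (size batch_size * beam_size)
# survive to the next decoding step.
surviving = torch.tensor([1, 0, 2, 3, 3, 5])

# Both are plain index vectors; they play the role of `new_order` passed to
# reorder_encoder_out / reorder_incremental_state so cached tensors stay aligned.
encoder_padding_mask = torch.zeros(batch_size, 7, dtype=torch.bool)
expanded_mask = encoder_padding_mask.index_select(0, expand_order)
reordered_mask = expanded_mask.index_select(0, surviving)
print(reordered_mask.shape)  # torch.Size([6, 7])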
......@@ -233,11 +239,6 @@ class TransformerDecoder(FairseqIncrementalDecoder):
        return state_dict

    def reorder_encoder_out(self, encoder_out, new_order):
        if encoder_out['encoder_padding_mask'] is not None:
            encoder_out['encoder_padding_mask'] = encoder_out['encoder_padding_mask'].index_select(0, new_order)
        return encoder_out


class TransformerEncoderLayer(nn.Module):
    """Encoder layer block.
......@@ -312,22 +313,20 @@ class TransformerDecoderLayer(nn.Module):
    def forward(self, x, encoder_out, encoder_padding_mask, incremental_state):
        residual = x
        x = self.maybe_layer_norm(0, x, before=True)
        x, _ = self.self_attn(
            query=x,
            key=x,
            value=x,
            mask_future_timesteps=True,
            incremental_state=incremental_state,
            need_weights=False,
        )
            query=x,
            key=x,
            value=x,
            mask_future_timesteps=True,
            incremental_state=incremental_state,
            need_weights=False,
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(0, x, after=True)

        residual = x
        x = self.maybe_layer_norm(1, x, before=True)
        x, attn = self.encoder_attn(
            query=x,
            key=encoder_out,
......@@ -336,7 +335,6 @@ class TransformerDecoderLayer(nn.Module):
            incremental_state=incremental_state,
            static_kv=True,
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(1, x, after=True)
......
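For readers unfamiliar with the block touched above: the decoder layer wraps each attention call in the same residual + dropout + (pre- or post-) layer-norm pattern, toggled by normalize_before via maybe_layer_norm. A simplified, self-contained sketch of that pattern, using torch.nn.MultiheadAttention as a stand-in for fairseq's own MultiheadAttention (the real module additionally takes mask_future_timesteps, incremental_state, static_kv, etc., and the real layer keeps one layer norm per sub-block):

import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyDecoderSelfAttnBlock(nn.Module):
    """Illustrative only: residual + dropout + pre/post layer-norm around self-attention."""

    def __init__(self, embed_dim=8, num_heads=2, dropout=0.1, normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.layer_norm = nn.LayerNorm(embed_dim)
        self.dropout = dropout
        self.normalize_before = normalize_before

    def maybe_layer_norm(self, x, before=False, after=False):
        # Apply layer norm either before the sub-block or after the residual,
        # mirroring the normalize_before toggle in the original code.
        assert before ^ after
        if after ^ self.normalize_before:
            return self.layer_norm(x)
        return x

    def forward(self, x, attn_mask=None):
        residual = x
        x = self.maybe_layer_norm(x, before=True)
        x, _ = self.self_attn(query=x, key=x, value=x,
                              attn_mask=attn_mask, need_weights=False)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(x, after=True)
        return x

# Usage: input is (seq_len, batch, embed_dim), as nn.MultiheadAttention expects.
block = ToyDecoderSelfAttnBlock()
out = block(torch.randn(5, 2, 8))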
......@@ -10,8 +10,6 @@ import re
import torch
from fairseq import dictionary
SPACE_NORMALIZER = re.compile("\s+")
......
......@@ -11,7 +11,7 @@ Train a network across multiple GPUs.
from collections import defaultdict, OrderedDict
from itertools import chain
import math
import torch
from fairseq import distributed_utils, optim, utils
......
......@@ -7,7 +7,6 @@
# can be found in the PATENTS file in the same directory.
import collections
import itertools
import os
import math
import torch
......