Commit 2b68e91f authored by Myle Ott, committed by Facebook Github Bot

Lint

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/817

Differential Revision: D16762905

Pulled By: myleott

fbshipit-source-id: d920595bec44ed26b72dfc6fbc15c0aa107b4e56
parent 969f4474
@@ -6,10 +6,10 @@
 __all__ = ['pdb']
 __version__ = '0.7.2'
-import fairseq.criterions
-import fairseq.models
-import fairseq.modules
-import fairseq.optim
-import fairseq.optim.lr_scheduler
-import fairseq.pdb
-import fairseq.tasks
+import fairseq.criterions  # noqa
+import fairseq.models  # noqa
+import fairseq.modules  # noqa
+import fairseq.optim  # noqa
+import fairseq.optim.lr_scheduler  # noqa
+import fairseq.pdb  # noqa
+import fairseq.tasks  # noqa
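Note: the trailing `# noqa` comments added above are flake8 directives; they suppress lint warnings on that line, presumably so these side-effect imports are not reported as unused (F401). A minimal illustrative sketch of the convention (module names are hypothetical, not from this commit):

    # Imported only for its side effects (e.g. to trigger registration);
    # the bare "# noqa" tells flake8 to skip all checks on this line.
    import my_package.plugins  # noqa

    # Scoped variant: suppress only the unused-import rule.
    import my_package.extras  # noqa: F401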
@@ -3,7 +3,6 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-import argparse
 from collections import OrderedDict
 from typing import Union
 import collections
...
@@ -149,7 +149,8 @@ def filter_by_size(indices, size_fn, max_positions, raise_exception=False):
         else:
             # Hacky as heck, for the specific case of multilingual training with RoundRobin.
             if isinstance(size_fn(idx), dict) and isinstance(max_positions, tuple):
-                return all(a is None or b is None or a <= b
+                return all(
+                    a is None or b is None or a <= b
                     for a, b in zip(size_fn(idx).values(), max_positions)
                 )
             # For MultiCorpusSampledDataset, will generalize it later
...
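The rewritten `all(...)` above is purely a line-wrapping change; the predicate still requires every (size, limit) pair to satisfy the bound unless either side is None. A standalone sketch with made-up sizes and limits:

    sizes = {'task_a': 512, 'task_b': 1024}    # hypothetical per-task sample sizes
    max_positions = (1024, 1024)               # hypothetical per-task limits
    ok = all(
        a is None or b is None or a <= b
        for a, b in zip(sizes.values(), max_positions)
    )
    print(ok)  # True: every defined size is within its limit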
@@ -6,6 +6,7 @@
 from fairseq import file_utils
 from fairseq.data.encoders import register_bpe
 @register_bpe('fastbpe')
 class fastBPE(object):
...
@@ -7,7 +7,6 @@ Original license: MIT
 from functools import lru_cache
 import json
-import os
 @lru_cache()
...
@@ -3,7 +3,6 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-from collections import namedtuple
 import os
 import pickle
 import socket
@@ -12,7 +11,6 @@ import warnings
 import torch
 import torch.distributed as dist
-from torch import nn
 from fairseq import utils
...
@@ -7,7 +7,6 @@
 import argparse
 import copy
 import os
-from typing import List
 import torch
 from torch import nn
...
@@ -3,11 +3,6 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-MODEL_REGISTRY = {}
-ARCH_MODEL_REGISTRY = {}
-ARCH_MODEL_INV_REGISTRY = {}
-ARCH_CONFIG_REGISTRY = {}
 import argparse
 import importlib
 import os
@@ -28,6 +23,12 @@ from .composite_encoder import CompositeEncoder
 from .distributed_fairseq_model import DistributedFairseqModel
+MODEL_REGISTRY = {}
+ARCH_MODEL_REGISTRY = {}
+ARCH_MODEL_INV_REGISTRY = {}
+ARCH_CONFIG_REGISTRY = {}
 __all__ = [
     'BaseFairseqModel',
     'CompositeEncoder',
@@ -43,7 +44,6 @@ __all__ = [
 ]
 def build_model(args, task):
     return ARCH_MODEL_REGISTRY[args.arch].build_model(args, task)
...
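Moving the registry dicts below the imports does not change how they are used: model classes register themselves into these dicts by name (via decorators defined elsewhere in this module), and `build_model` simply looks the architecture up. A simplified, hypothetical sketch of that registry pattern, not the exact fairseq implementation:

    MODEL_REGISTRY = {}

    def register_model(name):
        # Decorator that records a model class under a string key.
        def wrapper(cls):
            MODEL_REGISTRY[name] = cls
            return cls
        return wrapper

    @register_model('toy_model')
    class ToyModel:
        @classmethod
        def build_model(cls, args, task):
            return cls()

    # Mirrors build_model() above: resolve the name, then delegate.
    model = MODEL_REGISTRY['toy_model'].build_model(args=None, task=None)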
@@ -6,7 +6,6 @@
 Base classes for various fairseq models.
 """
-import os
 from typing import Dict, List, Optional
 import torch
@@ -259,6 +258,7 @@ class FairseqModel(FairseqEncoderDecoderModel):
             stacklevel=4,
         )
 class FairseqMultiModel(BaseFairseqModel):
     """Base class for combining multiple encoder-decoder models."""
...
@@ -3,8 +3,6 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-from typing import List
 import numpy as np
 import torch
 import torch.nn as nn
...
@@ -20,7 +20,6 @@ from fairseq.models import (
 from fairseq.modules import (
     AdaptiveSoftmax,
     LayerNorm,
-    MultiheadAttention,
     PositionalEmbedding,
     SinusoidalPositionalEmbedding,
     TransformerDecoderLayer,
@@ -51,6 +50,7 @@ class TransformerModel(FairseqEncoderDecoderModel):
     @classmethod
     def hub_models(cls):
+        # fmt: off
         return {
             'transformer.wmt14.en-fr': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-fr.joined-dict.transformer.tar.bz2',
             'transformer.wmt16.en-de': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt16.en-de.joined-dict.transformer.tar.bz2',
@@ -64,6 +64,7 @@ class TransformerModel(FairseqEncoderDecoderModel):
             'transformer.wmt19.de-en.single_model': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.de-en.joined-dict.single_model.tar.gz',
             'transformer.wmt19.ru-en.single_model': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.ru-en.single_model.tar.gz',
         }
+        # fmt: on
     def __init__(self, encoder, decoder):
         super().__init__(encoder, decoder)
...
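The `# fmt: off` / `# fmt: on` pair added around `hub_models` is a formatter directive (recognized by Black, and already used elsewhere in fairseq, e.g. the `# fmt: on` in the options hunk further down): the region in between is left untouched by the auto-formatter, presumably to keep each long checkpoint URL on a single line. A small illustrative sketch with made-up entries:

    # fmt: off
    URLS = {
        'example.model.a': 'https://example.com/checkpoints/a.tar.gz',  # long literals stay as written
        'example.model.b': 'https://example.com/checkpoints/b.tar.gz',
    }
    # fmt: on
    # Code outside the off/on pair is formatted normally.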
@@ -351,7 +351,7 @@ class ConvAggegator(nn.Module):
             residual = x
             x = conv(x)
             if self.skip_connections:
-                if rproj != None:
+                if rproj is not None:
                     residual = rproj(residual)
                 x = (x + residual) * self.residual_scale
         return x
...
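The `!= None` to `is not None` change above is the standard PEP 8 fix (flake8 reports `!= None` as E711): `None` is a singleton, so an identity test is the correct check, and `==`/`!=` can be hijacked by a class that overrides `__eq__`. A small illustration of the difference:

    class AlwaysEqual:
        # Pathological __eq__ that claims equality with everything, including None.
        def __eq__(self, other):
            return True

    obj = AlwaysEqual()
    print(obj != None)      # False -- the overridden equality lies
    print(obj is not None)  # True  -- identity cannot be overridden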
@@ -9,7 +9,7 @@ import os
 from fairseq import registry
 from fairseq.optim.fairseq_optimizer import FairseqOptimizer
 from fairseq.optim.fp16_optimizer import FP16Optimizer, MemoryEfficientFP16Optimizer
-from fairseq.optim.bmuf import FairseqBMUF
+from fairseq.optim.bmuf import FairseqBMUF  # noqa
 __all__ = [
...
@@ -19,7 +19,7 @@ class FairseqAdam(FairseqOptimizer):
         super().__init__(args, params)
         if torch.cuda.is_available():
             try:
-                from apex.optimizers import FusedAdam as _FusedAdam
+                from apex.optimizers import FusedAdam as _FusedAdam  # noqa
                 self._optimizer = FusedAdam(params, **self.optimizer_config)
             except ImportError:
                 self._optimizer = Adam(params, **self.optimizer_config)
...
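The `try`/`except ImportError` above is the usual optional-dependency pattern: probe for NVIDIA Apex's fused Adam kernel and fall back to the plain implementation when it is unavailable; the added `# noqa` marks the probe import, which is never referenced under its alias, as intentional. A generic, hypothetical sketch of the same pattern (module and class names are illustrative):

    try:
        # Probe for an optional accelerated backend (illustrative name);
        # "# noqa" marks the probe import as intentionally unused here.
        from fast_backend import FusedOp as _FusedOp  # noqa
        HAS_FUSED = True
    except ImportError:
        HAS_FUSED = False

    def make_op():
        # Prefer the fused kernel when available, otherwise a plain fallback.
        return _FusedOp() if HAS_FUSED else object()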
@@ -3,10 +3,6 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-import sys
-import time
 import torch
 import torch.distributed as dist
...
@@ -350,7 +350,7 @@ def add_optimization_args(parser):
     group.add_argument('--min-lr', default=-1, type=float, metavar='LR',
                        help='stop training when the learning rate reaches this minimum')
     group.add_argument('--use-bmuf', default=False, action='store_true',
-                       help="specify global optimizer for syncing models on different GPUs/Shards")
+                       help='specify global optimizer for syncing models on different GPUs/shards')
     # fmt: on
     return group
...
@@ -11,7 +11,6 @@ from collections import OrderedDict
 import json
 from numbers import Number
 import os
-import re
 import sys
 from fairseq import distributed_utils
@@ -19,6 +18,7 @@ from fairseq.meters import AverageMeter, StopwatchMeter, TimeMeter
 g_tbmf_wrapper = None
 def build_progress_bar(args, iterator, epoch=None, prefix=None, default='tqdm', no_progress_bar='none'):
     if args.log_format is None:
         args.log_format = no_progress_bar if args.no_progress_bar else default
...
@@ -7,7 +7,7 @@ import math
 import torch
-from fairseq import search, utils
+from fairseq import search
 from fairseq.models import FairseqIncrementalDecoder
...