Commit 2b68e91f authored by Myle Ott, committed by Facebook Github Bot

Lint

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/817

Differential Revision: D16762905

Pulled By: myleott

fbshipit-source-id: d920595bec44ed26b72dfc6fbc15c0aa107b4e56
parent 969f4474
@@ -6,10 +6,10 @@
__all__ = ['pdb']
__version__ = '0.7.2'
import fairseq.criterions
import fairseq.models
import fairseq.modules
import fairseq.optim
import fairseq.optim.lr_scheduler
import fairseq.pdb
import fairseq.tasks
import fairseq.criterions # noqa
import fairseq.models # noqa
import fairseq.modules # noqa
import fairseq.optim # noqa
import fairseq.optim.lr_scheduler # noqa
import fairseq.pdb # noqa
import fairseq.tasks # noqa
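A note on the hunk above: these submodule imports exist mainly for their side effects (importing each module registers fairseq's criterions, models, optimizers, schedulers, and tasks with the library's registries), so flake8 reports them as unused imports (F401). The trailing # noqa marker suppresses that warning for the line; a more targeted form is # noqa: F401. The same pattern in miniature, with a hypothetical package name:

# mypkg/__init__.py -- hypothetical package illustrating the pattern above.
# The submodule is imported purely for its registration side effects and is
# never referenced by name here, so the unused-import check is silenced.
import mypkg.plugins  # noqa: F401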
@@ -3,7 +3,6 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
from collections import OrderedDict
from typing import Union
import collections
@@ -149,7 +149,8 @@ def filter_by_size(indices, size_fn, max_positions, raise_exception=False):
else:
# Hacky as heck, for the specific case of multilingual training with RoundRobin.
if isinstance(size_fn(idx), dict) and isinstance(max_positions, tuple):
return all(a is None or b is None or a <= b
return all(
a is None or b is None or a <= b
for a, b in zip(size_fn(idx).values(), max_positions)
)
# For MultiCorpusSampledDataset, will generalize it later
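The change above only re-wraps the all(...) call, but the predicate it formats is worth spelling out: for multilingual RoundRobin training, size_fn(idx) returns a dict of per-language lengths and max_positions is a tuple of per-language limits, and the example passes if every paired comparison fits, with None on either side meaning "no constraint". A standalone sketch of the same check (names are illustrative, not fairseq's API):

def fits_within(sizes, max_positions):
    """True when every size is within its paired limit; None means unbounded."""
    return all(
        a is None or b is None or a <= b
        for a, b in zip(sizes, max_positions)
    )

# fits_within((12, 30), (None, 256))  -> True
# fits_within((12, 300), (None, 256)) -> False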
@@ -6,6 +6,7 @@
from fairseq import file_utils
from fairseq.data.encoders import register_bpe
@register_bpe('fastbpe')
class fastBPE(object):
@@ -7,7 +7,6 @@ Original license: MIT
from functools import lru_cache
import json
import os
@lru_cache()
@@ -3,7 +3,6 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from collections import namedtuple
import os
import pickle
import socket
@@ -12,7 +11,6 @@ import warnings
import torch
import torch.distributed as dist
from torch import nn
from fairseq import utils
@@ -7,7 +7,6 @@
import argparse
import copy
import os
from typing import List
import torch
from torch import nn
@@ -3,11 +3,6 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
MODEL_REGISTRY = {}
ARCH_MODEL_REGISTRY = {}
ARCH_MODEL_INV_REGISTRY = {}
ARCH_CONFIG_REGISTRY = {}
import argparse
import importlib
import os
@@ -28,6 +23,12 @@ from .composite_encoder import CompositeEncoder
from .distributed_fairseq_model import DistributedFairseqModel
MODEL_REGISTRY = {}
ARCH_MODEL_REGISTRY = {}
ARCH_MODEL_INV_REGISTRY = {}
ARCH_CONFIG_REGISTRY = {}
__all__ = [
'BaseFairseqModel',
'CompositeEncoder',
@@ -43,7 +44,6 @@ __all__ = [
]
def build_model(args, task):
return ARCH_MODEL_REGISTRY[args.arch].build_model(args, task)
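The hunks above simply move the four registry dicts below the imports; the registries themselves are what fairseq's @register_model and @register_model_architecture decorators populate, so that build_model can look an architecture up by the string in args.arch. The registry idea in miniature (a simplified sketch, not fairseq's actual decorator):

MODEL_REGISTRY = {}

def register_model(name):
    """Decorator that records a model class under `name`."""
    def wrapper(cls):
        MODEL_REGISTRY[name] = cls
        return cls
    return wrapper

@register_model('tiny_lstm')
class TinyLSTM:
    @classmethod
    def build_model(cls, args, task):
        return cls()

# mirrors ARCH_MODEL_REGISTRY[args.arch].build_model(args, task)
model = MODEL_REGISTRY['tiny_lstm'].build_model(args=None, task=None)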
@@ -6,7 +6,6 @@
Base classes for various fairseq models.
"""
import os
from typing import Dict, List, Optional
import torch
@@ -259,6 +258,7 @@ class FairseqModel(FairseqEncoderDecoderModel):
stacklevel=4,
)
class FairseqMultiModel(BaseFairseqModel):
"""Base class for combining multiple encoder-decoder models."""
@@ -3,8 +3,6 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import List
import numpy as np
import torch
import torch.nn as nn
@@ -20,7 +20,6 @@ from fairseq.models import (
from fairseq.modules import (
AdaptiveSoftmax,
LayerNorm,
MultiheadAttention,
PositionalEmbedding,
SinusoidalPositionalEmbedding,
TransformerDecoderLayer,
@@ -51,6 +50,7 @@ class TransformerModel(FairseqEncoderDecoderModel):
@classmethod
def hub_models(cls):
# fmt: off
return {
'transformer.wmt14.en-fr': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-fr.joined-dict.transformer.tar.bz2',
'transformer.wmt16.en-de': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt16.en-de.joined-dict.transformer.tar.bz2',
@@ -64,6 +64,7 @@ class TransformerModel(FairseqEncoderDecoderModel):
'transformer.wmt19.de-en.single_model': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.de-en.joined-dict.single_model.tar.gz',
'transformer.wmt19.ru-en.single_model': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.ru-en.single_model.tar.gz',
}
# fmt: on
def __init__(self, encoder, decoder):
super().__init__(encoder, decoder)
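The # fmt: off / # fmt: on pair added around hub_models() tells the auto-formatter (e.g. Black) to leave the long checkpoint-URL dictionary exactly as written, while the rest of the file stays formatter-managed. The same bracketing, shown with placeholder entries:

# fmt: off
CHECKPOINT_URLS = {
    'example.en-de': 'https://example.com/checkpoints/en-de.tar.gz',  # hypothetical URL
    'example.en-fr': 'https://example.com/checkpoints/en-fr.tar.gz',  # hypothetical URL
}
# fmt: on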
@@ -351,7 +351,7 @@ class ConvAggegator(nn.Module):
residual = x
x = conv(x)
if self.skip_connections:
if rproj != None:
if rproj is not None:
residual = rproj(residual)
x = (x + residual) * self.residual_scale
return x
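The ConvAggegator fix above swaps rproj != None for rproj is not None. flake8 flags equality comparison with None (E711) because == and != go through __eq__/__ne__, which a class may override, while `is not` checks identity against the single None object and cannot be fooled. A small demonstration:

class AlwaysEqual:
    """Toy class whose __eq__ always answers True."""
    def __eq__(self, other):
        return True

rproj = AlwaysEqual()
print(rproj != None)      # False -- the derived __ne__ negates the overridden __eq__  # noqa: E711
print(rproj is not None)  # True  -- identity test against the None singleton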
@@ -9,7 +9,7 @@ import os
from fairseq import registry
from fairseq.optim.fairseq_optimizer import FairseqOptimizer
from fairseq.optim.fp16_optimizer import FP16Optimizer, MemoryEfficientFP16Optimizer
from fairseq.optim.bmuf import FairseqBMUF
from fairseq.optim.bmuf import FairseqBMUF # noqa
__all__ = [
@@ -19,7 +19,7 @@ class FairseqAdam(FairseqOptimizer):
super().__init__(args, params)
if torch.cuda.is_available():
try:
from apex.optimizers import FusedAdam as _FusedAdam
from apex.optimizers import FusedAdam as _FusedAdam # noqa
self._optimizer = FusedAdam(params, **self.optimizer_config)
except ImportError:
self._optimizer = Adam(params, **self.optimizer_config)
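In the FairseqAdam hunk, the import of FusedAdam as _FusedAdam is essentially a probe for NVIDIA's apex extension: the imported name itself is unused (hence the # noqa), and the except ImportError branch falls back to the plain Adam implementation. The general optional-dependency pattern, sketched with a hypothetical accelerated package:

# "fast_backend" is a hypothetical package; only the pattern matters here.
try:
    import fast_backend  # noqa: F401  -- probe import; we only care that it exists
    HAS_FAST_BACKEND = True
except ImportError:
    HAS_FAST_BACKEND = False

def step(x):
    # Prefer the accelerated implementation when available, else fall back.
    return fast_backend.step(x) if HAS_FAST_BACKEND else x + 1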
@@ -3,10 +3,6 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import sys
import time
import torch
import torch.distributed as dist
@@ -350,7 +350,7 @@ def add_optimization_args(parser):
group.add_argument('--min-lr', default=-1, type=float, metavar='LR',
help='stop training when the learning rate reaches this minimum')
group.add_argument('--use-bmuf', default=False, action='store_true',
help="specify global optimizer for syncing models on different GPUs/Shards")
help='specify global optimizer for syncing models on different GPUs/shards')
# fmt: on
return group
@@ -11,7 +11,6 @@ from collections import OrderedDict
import json
from numbers import Number
import os
import re
import sys
from fairseq import distributed_utils
@@ -19,6 +18,7 @@ from fairseq.meters import AverageMeter, StopwatchMeter, TimeMeter
g_tbmf_wrapper = None
def build_progress_bar(args, iterator, epoch=None, prefix=None, default='tqdm', no_progress_bar='none'):
if args.log_format is None:
args.log_format = no_progress_bar if args.no_progress_bar else default
@@ -7,7 +7,7 @@ import math
import torch
from fairseq import search, utils
from fairseq import search
from fairseq.models import FairseqIncrementalDecoder