Commit 6ce55e4b authored by Myle Ott's avatar Myle Ott Committed by Facebook Github Bot
Browse files

Small fixes

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/835

Differential Revision: D16904038

Pulled By: myleott

fbshipit-source-id: 2c9d0b913f8d688297ac80fcabd905bd1397f66a
parent 2eb53b8e
...@@ -200,8 +200,10 @@ def main(parsed_args): ...@@ -200,8 +200,10 @@ def main(parsed_args):
is_bpe = False is_bpe = False
w = '' w = ''
if args.output_word_probs: if args.output_word_probs:
print(str(int(sample_id)) + " " + print(
('\t'.join('{} [{:2f}]'.format(x[0], x[1]) for x in word_prob))) str(int(sample_id)) + " "
+ ('\t'.join('{} [{:2f}]'.format(x[0], x[1]) for x in word_prob))
)
wps_meter.update(sample['ntokens']) wps_meter.update(sample['ntokens'])
t.log({'wps': round(wps_meter.avg)}) t.log({'wps': round(wps_meter.avg)})
......
# Copyright (c) 2017-present, Facebook, Inc. # Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# #
# This source code is licensed under the license found in the LICENSE file in # This source code is licensed under the MIT license found in the
# the root directory of this source tree. An additional grant of patent rights # LICENSE file in the root directory of this source tree.
# can be found in the PATENTS file in the same directory.
__version__ = '0.7.2' __version__ = '0.8.0'
import examples.noisychannel # noqa import examples.noisychannel # noqa
# Copyright (c) 2017-present, Facebook, Inc. # Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# #
# This source code is licensed under the license found in the LICENSE file in # This source code is licensed under the MIT license found in the
# the root directory of this source tree. An additional grant of patent rights # LICENSE file in the root directory of this source tree.
# can be found in the PATENTS file in the same directory.
from .rerank_options import * from .rerank_options import * # noqa
...@@ -77,9 +77,11 @@ def score_target_hypo(args, a, b, c, lenpen, target_outfile, hypo_outfile, write ...@@ -77,9 +77,11 @@ def score_target_hypo(args, a, b, c, lenpen, target_outfile, hypo_outfile, write
for key in range(len(gen_keys)): for key in range(len(gen_keys)):
if args.prefix_len is None: if args.prefix_len is None:
assert hypo_lst[key] in gen_output.no_bpe_hypo[gen_keys[key]], \ assert hypo_lst[key] in gen_output.no_bpe_hypo[gen_keys[key]], (
("pred and rescore hypo mismatch: i: " + str(key) + ", " + str(hypo_lst[key]) + str(gen_keys[key]) + "pred and rescore hypo mismatch: i: " + str(key) + ", "
str(gen_output.no_bpe_hypo[key])) + str(hypo_lst[key]) + str(gen_keys[key])
+ str(gen_output.no_bpe_hypo[key])
)
sys_tok = dict.encode_line(hypo_lst[key]) sys_tok = dict.encode_line(hypo_lst[key])
ref_tok = dict.encode_line(gen_output.no_bpe_target[gen_keys[key]]) ref_tok = dict.encode_line(gen_output.no_bpe_target[gen_keys[key]])
scorer.add(ref_tok, sys_tok) scorer.add(ref_tok, sys_tok)
......
#!/usr/bin/env python3 -u #!/usr/bin/env python3 -u
# Copyright (c) 2017-present, Facebook, Inc. # Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# #
# This source code is licensed under the license found in the LICENSE file in # This source code is licensed under the MIT license found in the
# the root directory of this source tree. An additional grant of patent rights # LICENSE file in the root directory of this source tree.
# can be found in the PATENTS file in the same directory.
import rerank_utils """
Generate n-best translations using a trained model.
"""
from contextlib import redirect_stdout
import os import os
import subprocess import subprocess
import rerank_utils
from examples.noisychannel import rerank_options from examples.noisychannel import rerank_options
from fairseq import options from fairseq import options
import generate import generate
import preprocess import preprocess
from contextlib import redirect_stdout
"""
Generate n-best translations using a trained model.
"""
def gen_and_reprocess_nbest(args): def gen_and_reprocess_nbest(args):
if args.score_dict_dir is None: if args.score_dict_dir is None:
......
# Copyright (c) 2017-present, Facebook, Inc. # Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# #
# This source code is licensed under the license found in the LICENSE file in # This source code is licensed under the MIT license found in the
# the root directory of this source tree. An additional grant of patent rights # LICENSE file in the root directory of this source tree.
# can be found in the PATENTS file in the same directory.
from fairseq import options from fairseq import options
......
...@@ -27,12 +27,14 @@ def random_search(args): ...@@ -27,12 +27,14 @@ def random_search(args):
param_values += initial_params param_values += initial_params
random.seed(args.seed) random.seed(args.seed)
random_params = np.array([[random.uniform(args.lower_bound[i], args.upper_bound[i]) random_params = np.array([
for i in range(len(args.tune_param))] [random.uniform(args.lower_bound[i], args.upper_bound[i]) for i in range(len(args.tune_param))]
for k in range(args.num_trials)]) for k in range(args.num_trials)
set_params = np.array([[initial_params[i][0] ])
for i in range(len(tuneable_parameters))] set_params = np.array([
for k in range(args.num_trials)]) [initial_params[i][0] for i in range(len(tuneable_parameters))]
for k in range(args.num_trials)
])
random_params = np.concatenate((random_params, set_params), 1) random_params = np.concatenate((random_params, set_params), 1)
rerank_args = vars(args).copy() rerank_args = vars(args).copy()
......
...@@ -128,8 +128,8 @@ def write_reprocessed(sources, hypos, targets, source_outfile, ...@@ -128,8 +128,8 @@ def write_reprocessed(sources, hypos, targets, source_outfile,
"in writing reprocessed, only one type of prefix may be used" "in writing reprocessed, only one type of prefix may be used"
with open(source_outfile, 'w') as source_file, \ with open(source_outfile, 'w') as source_file, \
open(hypo_outfile, 'w') as hypo_file, \ open(hypo_outfile, 'w') as hypo_file, \
open(target_outfile, 'w') as target_file: open(target_outfile, 'w') as target_file:
assert len(sources) == len(hypos), "sources and hypos list length mismatch" assert len(sources) == len(hypos), "sources and hypos list length mismatch"
if right_to_left: if right_to_left:
......
...@@ -270,6 +270,7 @@ class WinograndeTask(WSCTask): ...@@ -270,6 +270,7 @@ class WinograndeTask(WSCTask):
Task for WinoGrande dataset. Efficient implementation for Winograd schema Task for WinoGrande dataset. Efficient implementation for Winograd schema
tasks with exactly two candidates, one of which is correct. tasks with exactly two candidates, one of which is correct.
""" """
@classmethod @classmethod
def setup_task(cls, args, **kwargs): def setup_task(cls, args, **kwargs):
assert args.criterion == 'winogrande', 'Must set --criterion=winogrande' assert args.criterion == 'winogrande', 'Must set --criterion=winogrande'
...@@ -280,7 +281,6 @@ class WinograndeTask(WSCTask): ...@@ -280,7 +281,6 @@ class WinograndeTask(WSCTask):
return cls(args, vocab) return cls(args, vocab)
def load_dataset(self, split, epoch=0, combine=False, data_path=None, return_only=False, **kwargs): def load_dataset(self, split, epoch=0, combine=False, data_path=None, return_only=False, **kwargs):
"""Load a given dataset split. """Load a given dataset split.
...@@ -299,7 +299,7 @@ class WinograndeTask(WSCTask): ...@@ -299,7 +299,7 @@ class WinograndeTask(WSCTask):
candidate_masks = [] candidate_masks = []
candidate_lengths = [] candidate_lengths = []
itr = wsc_utils.winogrande_jsonl_iterator(data_path, eval=split=='test') itr = wsc_utils.winogrande_jsonl_iterator(data_path, eval=(split == 'test'))
for sample in itr: for sample in itr:
sentence, pronoun_span, query, cand_text = sample sentence, pronoun_span, query, cand_text = sample
......
...@@ -13,7 +13,7 @@ from fairseq import utils ...@@ -13,7 +13,7 @@ from fairseq import utils
from fairseq.models import ( from fairseq.models import (
FairseqEncoder, FairseqEncoder,
FairseqIncrementalDecoder, FairseqIncrementalDecoder,
FairseqModel, FairseqEncoderDecoderModel,
register_model, register_model,
register_model_architecture, register_model_architecture,
) )
...@@ -23,7 +23,7 @@ from fairseq.modules import TransformerDecoderLayer, TransformerEncoderLayer, VG ...@@ -23,7 +23,7 @@ from fairseq.modules import TransformerDecoderLayer, TransformerEncoderLayer, VG
@register_model("asr_vggtransformer") @register_model("asr_vggtransformer")
class VGGTransformerModel(FairseqModel): class VGGTransformerModel(FairseqEncoderDecoderModel):
""" """
Transformers with convolutional context for ASR Transformers with convolutional context for ASR
https://arxiv.org/abs/1904.11660 https://arxiv.org/abs/1904.11660
......
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import math import math
import sys
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -174,6 +173,7 @@ class LightConvModel(FairseqEncoderDecoderModel): ...@@ -174,6 +173,7 @@ class LightConvModel(FairseqEncoderDecoderModel):
decoder = LightConvDecoder(args, tgt_dict, decoder_embed_tokens) decoder = LightConvDecoder(args, tgt_dict, decoder_embed_tokens)
return LightConvModel(encoder, decoder) return LightConvModel(encoder, decoder)
class LightConvEncoder(FairseqEncoder): class LightConvEncoder(FairseqEncoder):
""" """
LightConv encoder consisting of *args.encoder_layers* layers. Each layer LightConv encoder consisting of *args.encoder_layers* layers. Each layer
......
...@@ -10,14 +10,12 @@ from .character_token_embedder import CharacterTokenEmbedder ...@@ -10,14 +10,12 @@ from .character_token_embedder import CharacterTokenEmbedder
from .conv_tbc import ConvTBC from .conv_tbc import ConvTBC
from .downsampled_multihead_attention import DownsampledMultiHeadAttention from .downsampled_multihead_attention import DownsampledMultiHeadAttention
from .dynamic_convolution import DynamicConv, DynamicConv1dTBC from .dynamic_convolution import DynamicConv, DynamicConv1dTBC
#from .dynamicconv_layer import DynamicconvLayer
from .gelu import gelu, gelu_accurate from .gelu import gelu, gelu_accurate
from .grad_multiply import GradMultiply from .grad_multiply import GradMultiply
from .highway import Highway from .highway import Highway
from .layer_norm import LayerNorm from .layer_norm import LayerNorm
from .learned_positional_embedding import LearnedPositionalEmbedding from .learned_positional_embedding import LearnedPositionalEmbedding
from .lightweight_convolution import LightweightConv, LightweightConv1dTBC from .lightweight_convolution import LightweightConv, LightweightConv1dTBC
#from .lightconv_layer import LightconvLayer
from .linearized_convolution import LinearizedConvolution from .linearized_convolution import LinearizedConvolution
from .logsumexp_moe import LogSumExpMoE from .logsumexp_moe import LogSumExpMoE
from .mean_pool_gating_network import MeanPoolGatingNetwork from .mean_pool_gating_network import MeanPoolGatingNetwork
...@@ -38,7 +36,6 @@ __all__ = [ ...@@ -38,7 +36,6 @@ __all__ = [
'CharacterTokenEmbedder', 'CharacterTokenEmbedder',
'ConvTBC', 'ConvTBC',
'DownsampledMultiHeadAttention', 'DownsampledMultiHeadAttention',
# 'DyamicconvLayer',
'DynamicConv1dTBC', 'DynamicConv1dTBC',
'DynamicConv', 'DynamicConv',
'gelu', 'gelu',
...@@ -47,7 +44,6 @@ __all__ = [ ...@@ -47,7 +44,6 @@ __all__ = [
'Highway', 'Highway',
'LayerNorm', 'LayerNorm',
'LearnedPositionalEmbedding', 'LearnedPositionalEmbedding',
# 'LightconvLayer',
'LightweightConv1dTBC', 'LightweightConv1dTBC',
'LightweightConv', 'LightweightConv',
'LinearizedConvolution', 'LinearizedConvolution',
......
/** /**
* Copyright (c) 2018-present, Facebook, Inc. * Copyright (c) Facebook, Inc. and its affiliates.
* All rights reserved. *
* * This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/ */
......
...@@ -10,6 +10,7 @@ import torch.nn.functional as F ...@@ -10,6 +10,7 @@ import torch.nn.functional as F
from fairseq import utils from fairseq import utils
from .unfold import unfold1d from .unfold import unfold1d
def DynamicConv(input_size, kernel_size=1, padding_l=None, num_heads=1, def DynamicConv(input_size, kernel_size=1, padding_l=None, num_heads=1,
weight_dropout=0., weight_softmax=False, weight_dropout=0., weight_softmax=False,
renorm_padding=False, bias=False, conv_bias=False, renorm_padding=False, bias=False, conv_bias=False,
...@@ -28,6 +29,7 @@ def DynamicConv(input_size, kernel_size=1, padding_l=None, num_heads=1, ...@@ -28,6 +29,7 @@ def DynamicConv(input_size, kernel_size=1, padding_l=None, num_heads=1,
weight_dropout=weight_dropout, weight_dropout=weight_dropout,
weight_softmax=weight_softmax, bias=bias) weight_softmax=weight_softmax, bias=bias)
def Linear(in_features, out_features, bias=True): def Linear(in_features, out_features, bias=True):
m = nn.Linear(in_features, out_features, bias) m = nn.Linear(in_features, out_features, bias)
nn.init.xavier_uniform_(m.weight) nn.init.xavier_uniform_(m.weight)
...@@ -209,7 +211,7 @@ class DynamicConv1dTBC(nn.Module): ...@@ -209,7 +211,7 @@ class DynamicConv1dTBC(nn.Module):
# turn the convolution filters into band matrices # turn the convolution filters into band matrices
weight_expanded = weight.new_zeros(B*H, T, T+K-1, requires_grad=False) weight_expanded = weight.new_zeros(B*H, T, T+K-1, requires_grad=False)
weight_expanded.as_strided((B*H, T, K), (T*(T+K-1), T+K, 1)).copy_(weight) weight_expanded.as_strided((B*H, T, K), (T*(T+K-1), T+K, 1)).copy_(weight)
weight_expanded = weight_expanded.narrow(2, P, T) # B*H x T x T weight_expanded = weight_expanded.narrow(2, P, T) # B*H x T x T
output = torch.bmm(weight_expanded, x) output = torch.bmm(weight_expanded, x)
output = output.transpose(0, 1).contiguous().view(T, B, C) output = output.transpose(0, 1).contiguous().view(T, B, C)
return output return output
......
# Copyright (c) 2017-present, Facebook, Inc. # Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# #
# This source code is licensed under the license found in the LICENSE file in # This source code is licensed under the MIT license found in the
# the root directory of this source tree. An additional grant of patent rights # LICENSE file in the root directory of this source tree.
# can be found in the PATENTS file in the same directory.
from .dynamicconv_layer import DynamicconvLayer from .dynamicconv_layer import DynamicconvLayer # noqa
# Copyright (c) 2017-present, Facebook, Inc. # Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# #
# This source code is licensed under the license found in the LICENSE file in # This source code is licensed under the MIT license found in the
# the root directory of this source tree. An additional grant of patent rights # LICENSE file in the root directory of this source tree.
# can be found in the PATENTS file in the same directory.
def gen_forward(): def gen_forward():
...@@ -13,9 +11,10 @@ def gen_forward(): ...@@ -13,9 +11,10 @@ def gen_forward():
head = """ head = """
/** /**
* Copyright (c) 2018-present, Facebook, Inc. * Copyright (c) Facebook, Inc. and its affiliates.
* All rights reserved.
* *
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/ */
#include "dynamicconv_cuda.cuh" #include "dynamicconv_cuda.cuh"
...@@ -103,9 +102,10 @@ def gen_backward(): ...@@ -103,9 +102,10 @@ def gen_backward():
head = """ head = """
/** /**
* Copyright (c) 2018-present, Facebook, Inc. * Copyright (c) Facebook, Inc. and its affiliates.
* All rights reserved.
* *
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/ */
#include "dynamicconv_cuda.cuh" #include "dynamicconv_cuda.cuh"
......
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <torch/extension.h> #include <torch/extension.h>
#include <vector> #include <vector>
......
/** /**
* Copyright (c) 2018-present, Facebook, Inc. * Copyright (c) Facebook, Inc. and its affiliates.
* All rights reserved. *
* * This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/ */
#include <ATen/ATen.h> #include <ATen/ATen.h>
#include <c10/cuda/CUDAStream.h> #include <c10/cuda/CUDAStream.h>
......
/** /**
* Copyright (c) 2018-present, Facebook, Inc. * Copyright (c) Facebook, Inc. and its affiliates.
* All rights reserved. *
* * This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/ */
#include "dynamicconv_cuda.cuh" #include "dynamicconv_cuda.cuh"
......
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch import torch
from torch import nn from torch import nn
from torch.autograd import Function from torch.autograd import Function
import torch.nn.functional as F import torch.nn.functional as F
import dynamicconv_cuda import dynamicconv_cuda
from fairseq import utils from fairseq import utils
from fairseq.modules.unfold import unfold1d
class dynamicconvFunction(Function): class dynamicconvFunction(Function):
...@@ -68,7 +75,7 @@ class DynamicconvLayer(nn.Module): ...@@ -68,7 +75,7 @@ class DynamicconvLayer(nn.Module):
T, B, C = x.size() T, B, C = x.size()
K, H = self.kernel_size, self.num_heads K, H = self.kernel_size, self.num_heads
R = C // H # R = C // H
# during inference time, incremental BMM is faster # during inference time, incremental BMM is faster
if incremental_state is not None: if incremental_state is not None:
...@@ -199,7 +206,7 @@ class DynamicconvLayer(nn.Module): ...@@ -199,7 +206,7 @@ class DynamicconvLayer(nn.Module):
# turn the convolution filters into band matrices # turn the convolution filters into band matrices
weight_expanded = weight.new_zeros(B*H, T, T+K-1, requires_grad=False) weight_expanded = weight.new_zeros(B*H, T, T+K-1, requires_grad=False)
weight_expanded.as_strided((B*H, T, K), (T*(T+K-1), T+K, 1)).copy_(weight) weight_expanded.as_strided((B*H, T, K), (T*(T+K-1), T+K, 1)).copy_(weight)
weight_expanded = weight_expanded.narrow(2, P, T) # B*H x T x T weight_expanded = weight_expanded.narrow(2, P, T) # B*H x T x T
output = torch.bmm(weight_expanded, x) output = torch.bmm(weight_expanded, x)
output = output.transpose(0, 1).contiguous().view(T, B, C) output = output.transpose(0, 1).contiguous().view(T, B, C)
return output return output
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment