"src/vscode:/vscode.git/clone" did not exist on "0227ddfb66421164834879619ff7fd8a5c6f8960"
Commit 6ce55e4b authored by Myle Ott's avatar Myle Ott Committed by Facebook Github Bot
Browse files

Small fixes

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/835

Differential Revision: D16904038

Pulled By: myleott

fbshipit-source-id: 2c9d0b913f8d688297ac80fcabd905bd1397f66a
parent 2eb53b8e
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from setuptools import setup from setuptools import setup
from torch.utils.cpp_extension import CUDAExtension, BuildExtension from torch.utils.cpp_extension import CUDAExtension, BuildExtension
......
# Copyright (c) 2017-present, Facebook, Inc. # Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# #
# This source code is licensed under the license found in the LICENSE file in # This source code is licensed under the MIT license found in the
# the root directory of this source tree. An additional grant of patent rights # LICENSE file in the root directory of this source tree.
# can be found in the PATENTS file in the same directory.
from .lightconv_layer import LightconvLayer from .lightconv_layer import LightconvLayer # noqa
# Copyright (c) 2017-present, Facebook, Inc. # Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# #
# This source code is licensed under the license found in the LICENSE file in # This source code is licensed under the MIT license found in the
# the root directory of this source tree. An additional grant of patent rights # LICENSE file in the root directory of this source tree.
# can be found in the PATENTS file in the same directory.
def gen_forward(): def gen_forward():
...@@ -13,9 +11,10 @@ def gen_forward(): ...@@ -13,9 +11,10 @@ def gen_forward():
head = """ head = """
/** /**
* Copyright (c) 2018-present, Facebook, Inc. * Copyright (c) Facebook, Inc. and its affiliates.
* All rights reserved.
* *
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/ */
#include "lightconv_cuda.cuh" #include "lightconv_cuda.cuh"
...@@ -118,9 +117,10 @@ def gen_backward(): ...@@ -118,9 +117,10 @@ def gen_backward():
head = """ head = """
/** /**
* Copyright (c) 2018-present, Facebook, Inc. * Copyright (c) Facebook, Inc. and its affiliates.
* All rights reserved.
* *
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/ */
#include "lightconv_cuda.cuh" #include "lightconv_cuda.cuh"
......
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <torch/extension.h> #include <torch/extension.h>
#include <vector> #include <vector>
......
/** /**
* Copyright (c) 2018-present, Facebook, Inc. * Copyright (c) Facebook, Inc. and its affiliates.
* All rights reserved. *
* * This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/ */
#include <ATen/ATen.h> #include <ATen/ATen.h>
......
/** /**
* Copyright (c) 2018-present, Facebook, Inc. * Copyright (c) Facebook, Inc. and its affiliates.
* All rights reserved. *
* * This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/ */
#include "lightconv_cuda.cuh" #include "lightconv_cuda.cuh"
......
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch import torch
from torch import nn from torch import nn
from torch.autograd import Function from torch.autograd import Function
import torch.nn.functional as F import torch.nn.functional as F
import time
import lightconv_cuda import lightconv_cuda
from fairseq import utils from fairseq import utils
class lightconvFunction(Function): class lightconvFunction(Function):
@staticmethod @staticmethod
...@@ -26,6 +31,7 @@ class lightconvFunction(Function): ...@@ -26,6 +31,7 @@ class lightconvFunction(Function):
grad_input, grad_weights = outputs grad_input, grad_weights = outputs
return grad_input, grad_weights, None return grad_input, grad_weights, None
class LightconvLayer(nn.Module): class LightconvLayer(nn.Module):
def __init__( def __init__(
self, self,
...@@ -82,7 +88,7 @@ class LightconvLayer(nn.Module): ...@@ -82,7 +88,7 @@ class LightconvLayer(nn.Module):
weight = weight.view(1, H, K).expand(T*B, H, K).contiguous().view(T*B*H, K, 1) weight = weight.view(1, H, K).expand(T*B, H, K).contiguous().view(T*B*H, K, 1)
weight = F.dropout(weight, self.weight_dropout, training=self.training) weight = F.dropout(weight, self.weight_dropout, training=self.training)
output = torch.bmm(x_unfold, weight) # T*B*H x R x 1 output = torch.bmm(x_unfold, weight) # T*B*H x R x 1
output = output.view(T, B, C) output = output.view(T, B, C)
return output return output
......
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from setuptools import setup from setuptools import setup
from torch.utils.cpp_extension import CUDAExtension, BuildExtension from torch.utils.cpp_extension import CUDAExtension, BuildExtension
......
...@@ -10,6 +10,7 @@ import torch.nn.functional as F ...@@ -10,6 +10,7 @@ import torch.nn.functional as F
from fairseq import utils from fairseq import utils
from fairseq.modules.unfold import unfold1d from fairseq.modules.unfold import unfold1d
def LightweightConv(input_size, kernel_size=1, padding_l=None, num_heads=1, def LightweightConv(input_size, kernel_size=1, padding_l=None, num_heads=1,
weight_dropout=0., weight_softmax=False, bias=False): weight_dropout=0., weight_softmax=False, bias=False):
if torch.cuda.is_available(): if torch.cuda.is_available():
...@@ -26,6 +27,7 @@ def LightweightConv(input_size, kernel_size=1, padding_l=None, num_heads=1, ...@@ -26,6 +27,7 @@ def LightweightConv(input_size, kernel_size=1, padding_l=None, num_heads=1,
weight_dropout=weight_dropout, weight_dropout=weight_dropout,
weight_softmax=weight_softmax, bias=bias) weight_softmax=weight_softmax, bias=bias)
class LightweightConv1d(nn.Module): class LightweightConv1d(nn.Module):
'''Lightweight Convolution assuming the input is BxCxT '''Lightweight Convolution assuming the input is BxCxT
This is just an example that explains LightConv clearer than the TBC version. This is just an example that explains LightConv clearer than the TBC version.
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
import torch.nn.functional as F import torch.nn.functional as F
def unfold1d(x, kernel_size, padding_l, pad_value=0): def unfold1d(x, kernel_size, padding_l, pad_value=0):
'''unfold T x B x C to T x B x C x K''' '''unfold T x B x C to T x B x C x K'''
if kernel_size > 1: if kernel_size > 1:
......
...@@ -121,9 +121,9 @@ def main(): ...@@ -121,9 +121,9 @@ def main():
num = args.num_epoch_checkpoints num = args.num_epoch_checkpoints
assert args.checkpoint_upper_bound is None or args.num_epoch_checkpoints is not None, \ assert args.checkpoint_upper_bound is None or args.num_epoch_checkpoints is not None, \
'--checkpoint-upper-bound requires --num-epoch-checkpoints' '--checkpoint-upper-bound requires --num-epoch-checkpoints'
assert args.num_epoch_checkpoints is None or args.num_update_checkpoints is None, \ assert args.num_epoch_checkpoints is None or args.num_update_checkpoints is None, \
'Cannot combine --num-epoch-checkpoints and --num-update-checkpoints' 'Cannot combine --num-epoch-checkpoints and --num-update-checkpoints'
if num is not None: if num is not None:
args.inputs = last_n_checkpoints( args.inputs = last_n_checkpoints(
......
#!/usr/bin/env python #!/usr/bin/env python
"""Helper script to compare two argparse.Namespace objects.""" """Helper script to compare two argparse.Namespace objects."""
from argparse import Namespace from argparse import Namespace # noqa
def main(): def main():
......
...@@ -10,7 +10,6 @@ document in a large file. Documents should be separated by a single empty line. ...@@ -10,7 +10,6 @@ document in a large file. Documents should be separated by a single empty line.
import argparse import argparse
import gzip import gzip
import random
import sys import sys
import numpy as np import numpy as np
......
...@@ -10,8 +10,6 @@ should be separated by a single empty line. ...@@ -10,8 +10,6 @@ should be separated by a single empty line.
import argparse import argparse
import contextlib import contextlib
import random
import sys
def main(): def main():
......
...@@ -19,6 +19,8 @@ def main(): ...@@ -19,6 +19,8 @@ def main():
parser.add_argument('sample_output', help='train output file') parser.add_argument('sample_output', help='train output file')
parser.add_argument('remainder_output', help='valid output file') parser.add_argument('remainder_output', help='valid output file')
parser.add_argument('-k', type=int, help="remainder size") parser.add_argument('-k', type=int, help="remainder size")
parser.add_argument('--lines', action='store_true',
help='split lines instead of docs')
args = parser.parse_args() args = parser.parse_args()
assert args.k is not None assert args.k is not None
...@@ -48,6 +50,8 @@ def main(): ...@@ -48,6 +50,8 @@ def main():
update_sample(doc) update_sample(doc)
else: else:
doc.append(line) doc.append(line)
if args.lines:
update_sample(doc)
if i % 1000000 == 0: if i % 1000000 == 0:
print(i, file=sys.stderr, end="", flush=True) print(i, file=sys.stderr, end="", flush=True)
elif i % 100000 == 0: elif i % 100000 == 0:
...@@ -61,7 +65,7 @@ def main(): ...@@ -61,7 +65,7 @@ def main():
with open(args.sample_output, 'w', encoding='utf-8') as out: with open(args.sample_output, 'w', encoding='utf-8') as out:
first = True first = True
for doc in sample: for doc in sample:
if not first: if not first and not args.lines:
out.write("\n") out.write("\n")
first = False first = False
for line in doc: for line in doc:
...@@ -70,7 +74,7 @@ def main(): ...@@ -70,7 +74,7 @@ def main():
with open(args.remainder_output, 'w', encoding='utf-8') as out: with open(args.remainder_output, 'w', encoding='utf-8') as out:
first = True first = True
for doc in remainder: for doc in remainder:
if not first: if not first and not args.lines:
out.write("\n") out.write("\n")
first = False first = False
for line in doc: for line in doc:
......
...@@ -30,7 +30,7 @@ def main(): ...@@ -30,7 +30,7 @@ def main():
args = parser.parse_args() args = parser.parse_args()
assert len(args.inputs) == len(args.outputs), \ assert len(args.inputs) == len(args.outputs), \
"number of input and output paths should match" "number of input and output paths should match"
sp = spm.SentencePieceProcessor() sp = spm.SentencePieceProcessor()
sp.Load(args.model) sp.Load(args.model)
......
""" Helper script to pre-compute embeddings for a wav2letter++ dataset #!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Helper script to pre-compute embeddings for a wav2letter++ dataset
""" """
import glob, os import argparse
import tqdm import glob
import os
from shutil import copy from shutil import copy
import soundfile as sf
import h5py import h5py
import soundfile as sf
import numpy as np import numpy as np
import torch import torch
from torch import nn from torch import nn
import tqdm
from fairseq.models.wav2vec import Wav2VecModel from fairseq.models.wav2vec import Wav2VecModel
import argparse
def read_audio(fname): def read_audio(fname):
""" Load an audio file and return PCM along with the sample rate """ """ Load an audio file and return PCM along with the sample rate """
...@@ -228,4 +233,4 @@ if __name__ == "__main__": ...@@ -228,4 +233,4 @@ if __name__ == "__main__":
if not args.no_copy_labels: if not args.no_copy_labels:
print("Copying label data...") print("Copying label data...")
writer.copy_labels() writer.copy_labels()
print("Done.") print("Done.")
\ No newline at end of file
#!/usr/bin/env python3 #!/usr/bin/env python3
# Copyright (c) 2017-present, Facebook, Inc. # Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# #
# This source code is licensed under the license found in the LICENSE file in # This source code is licensed under the MIT license found in the
# the root directory of this source tree. An additional grant of patent rights # LICENSE file in the root directory of this source tree.
# can be found in the PATENTS file in the same directory.
import unittest import unittest
......
#!/usr/bin/env python3 #!/usr/bin/env python3
# Copyright (c) 2017-present, Facebook, Inc. # Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# #
# This source code is licensed under the license found in the LICENSE file in # This source code is licensed under the MIT license found in the
# the root directory of this source tree. An additional grant of patent rights # LICENSE file in the root directory of this source tree.
# can be found in the PATENTS file in the same directory.
from examples.speech_recognition.criterions.cross_entropy_acc import CrossEntropyWithAccCriterion from examples.speech_recognition.criterions.cross_entropy_acc import CrossEntropyWithAccCriterion
from .asr_test_base import CrossEntropyCriterionTestBase from .asr_test_base import CrossEntropyCriterionTestBase
......
...@@ -14,7 +14,6 @@ import torch ...@@ -14,7 +14,6 @@ import torch
from torch import nn from torch import nn
from scripts.average_checkpoints import average_checkpoints from scripts.average_checkpoints import average_checkpoints
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment