Commit 6ce55e4b authored by Myle Ott, committed by Facebook Github Bot
Browse files

Small fixes

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/835

Differential Revision: D16904038

Pulled By: myleott

fbshipit-source-id: 2c9d0b913f8d688297ac80fcabd905bd1397f66a
parent 2eb53b8e
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from setuptools import setup
from torch.utils.cpp_extension import CUDAExtension, BuildExtension
......
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .lightconv_layer import LightconvLayer
from .lightconv_layer import LightconvLayer # noqa
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
def gen_forward():
......@@ -13,9 +11,10 @@ def gen_forward():
head = """
/**
* Copyright (c) 2018-present, Facebook, Inc.
* All rights reserved.
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#include "lightconv_cuda.cuh"
......@@ -118,9 +117,10 @@ def gen_backward():
head = """
/**
* Copyright (c) 2018-present, Facebook, Inc.
* All rights reserved.
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#include "lightconv_cuda.cuh"
......
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <torch/extension.h>
#include <vector>
......
/**
* Copyright (c) 2018-present, Facebook, Inc.
* All rights reserved.
*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <ATen/ATen.h>
......
/**
* Copyright (c) 2018-present, Facebook, Inc.
* All rights reserved.
*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#include "lightconv_cuda.cuh"
......
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from torch import nn
from torch.autograd import Function
import torch.nn.functional as F
import time
import lightconv_cuda
from fairseq import utils
class lightconvFunction(Function):
@staticmethod
......@@ -26,6 +31,7 @@ class lightconvFunction(Function):
grad_input, grad_weights = outputs
return grad_input, grad_weights, None
class LightconvLayer(nn.Module):
def __init__(
self,
......@@ -82,7 +88,7 @@ class LightconvLayer(nn.Module):
weight = weight.view(1, H, K).expand(T*B, H, K).contiguous().view(T*B*H, K, 1)
weight = F.dropout(weight, self.weight_dropout, training=self.training)
output = torch.bmm(x_unfold, weight) # T*B*H x R x 1
output = torch.bmm(x_unfold, weight) # T*B*H x R x 1
output = output.view(T, B, C)
return output
......
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from setuptools import setup
from torch.utils.cpp_extension import CUDAExtension, BuildExtension
......
......@@ -10,6 +10,7 @@ import torch.nn.functional as F
from fairseq import utils
from fairseq.modules.unfold import unfold1d
def LightweightConv(input_size, kernel_size=1, padding_l=None, num_heads=1,
weight_dropout=0., weight_softmax=False, bias=False):
if torch.cuda.is_available():
......@@ -26,6 +27,7 @@ def LightweightConv(input_size, kernel_size=1, padding_l=None, num_heads=1,
weight_dropout=weight_dropout,
weight_softmax=weight_softmax, bias=bias)
class LightweightConv1d(nn.Module):
'''Lightweight Convolution assuming the input is BxCxT
This is just an example that explains LightConv clearer than the TBC version.
......
......@@ -5,6 +5,7 @@
import torch.nn.functional as F
def unfold1d(x, kernel_size, padding_l, pad_value=0):
'''unfold T x B x C to T x B x C x K'''
if kernel_size > 1:
......
......@@ -121,9 +121,9 @@ def main():
num = args.num_epoch_checkpoints
assert args.checkpoint_upper_bound is None or args.num_epoch_checkpoints is not None, \
'--checkpoint-upper-bound requires --num-epoch-checkpoints'
'--checkpoint-upper-bound requires --num-epoch-checkpoints'
assert args.num_epoch_checkpoints is None or args.num_update_checkpoints is None, \
'Cannot combine --num-epoch-checkpoints and --num-update-checkpoints'
'Cannot combine --num-epoch-checkpoints and --num-update-checkpoints'
if num is not None:
args.inputs = last_n_checkpoints(
......
#!/usr/bin/env python
"""Helper script to compare two argparse.Namespace objects."""
from argparse import Namespace
from argparse import Namespace # noqa
def main():
......
......@@ -10,7 +10,6 @@ document in a large file. Documents should be separated by a single empty line.
import argparse
import gzip
import random
import sys
import numpy as np
......
......@@ -10,8 +10,6 @@ should be separated by a single empty line.
import argparse
import contextlib
import random
import sys
def main():
......
......@@ -19,6 +19,8 @@ def main():
parser.add_argument('sample_output', help='train output file')
parser.add_argument('remainder_output', help='valid output file')
parser.add_argument('-k', type=int, help="remainder size")
parser.add_argument('--lines', action='store_true',
help='split lines instead of docs')
args = parser.parse_args()
assert args.k is not None
......@@ -48,6 +50,8 @@ def main():
update_sample(doc)
else:
doc.append(line)
if args.lines:
update_sample(doc)
if i % 1000000 == 0:
print(i, file=sys.stderr, end="", flush=True)
elif i % 100000 == 0:
......@@ -61,7 +65,7 @@ def main():
with open(args.sample_output, 'w', encoding='utf-8') as out:
first = True
for doc in sample:
if not first:
if not first and not args.lines:
out.write("\n")
first = False
for line in doc:
......@@ -70,7 +74,7 @@ def main():
with open(args.remainder_output, 'w', encoding='utf-8') as out:
first = True
for doc in remainder:
if not first:
if not first and not args.lines:
out.write("\n")
first = False
for line in doc:
......
......@@ -30,7 +30,7 @@ def main():
args = parser.parse_args()
assert len(args.inputs) == len(args.outputs), \
"number of input and output paths should match"
"number of input and output paths should match"
sp = spm.SentencePieceProcessor()
sp.Load(args.model)
......
""" Helper script to pre-compute embeddings for a wav2letter++ dataset
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Helper script to pre-compute embeddings for a wav2letter++ dataset
"""
import glob, os
import tqdm
import argparse
import glob
import os
from shutil import copy
import soundfile as sf
import h5py
import soundfile as sf
import numpy as np
import torch
from torch import nn
import tqdm
from fairseq.models.wav2vec import Wav2VecModel
import argparse
def read_audio(fname):
""" Load an audio file and return PCM along with the sample rate """
......@@ -228,4 +233,4 @@ if __name__ == "__main__":
if not args.no_copy_labels:
print("Copying label data...")
writer.copy_labels()
print("Done.")
\ No newline at end of file
print("Done.")
#!/usr/bin/env python3
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import unittest
......
#!/usr/bin/env python3
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from examples.speech_recognition.criterions.cross_entropy_acc import CrossEntropyWithAccCriterion
from .asr_test_base import CrossEntropyCriterionTestBase
......
......@@ -14,7 +14,6 @@ import torch
from torch import nn
from scripts.average_checkpoints import average_checkpoints
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment