# ***************************************************************************** # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are met: # * Redistributions of source code must retain the above copyright # notice, this list of conditions and the following disclaimer. # * Redistributions in binary form must reproduce the above copyright # notice, this list of conditions and the following disclaimer in the # documentation and/or other materials provided with the distribution. # * Neither the name of the NVIDIA CORPORATION nor the # names of its contributors may be used to endorse or promote products # derived from this software without specific prior written permission. # # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
#
# *****************************************************************************

import argparse
import json
import os

import numpy as np
import torch
from scipy.io.wavfile import read


class ParseFromConfigFile(argparse.Action):
    """argparse action that loads option values from a JSON config file.

    The option's value is treated as a path to a JSON file shaped as
    ``{group: {option: value, ...}, ...}``. Every ``option`` found in any
    group is set on the namespace, with dashes in the key replaced by
    underscores. Group names themselves are ignored.
    """

    def __init__(self, option_strings, type, dest, help=None, required=False):
        # ``type`` and ``help`` shadow builtins, but they are the keyword
        # names argparse passes to actions, so they must stay as they are.
        super().__init__(
            option_strings=option_strings,
            type=type,
            dest=dest,
            help=help,
            required=required,
        )

    def __call__(self, parser, namespace, values, option_string=None):
        """Read the JSON file at ``values`` and copy its entries to ``namespace``."""
        with open(values, 'r') as config_file:
            config = json.load(config_file)

        for options in config.values():
            for key, value in options.items():
                setattr(namespace, key.replace('-', '_'), value)


def get_mask_from_lengths(lengths):
    """Build a padding mask from a tensor of sequence lengths.

    Args:
        lengths: tensor of lengths (expected to be 1-D, one entry per
            sequence -- the code indexes it with ``unsqueeze(1)``).

    Returns:
        A tensor with one row per entry of ``lengths`` and
        ``max(lengths)`` columns that is true at positions at or beyond
        the corresponding length (the padded positions) and false
        elsewhere.
    """
    longest = torch.max(lengths).item()
    positions = torch.arange(
        0, longest, device=lengths.device, dtype=lengths.dtype)
    inside = (positions < lengths.unsqueeze(1)).byte()
    # Invert: positions inside the sequence hold 1, so ``<= 0`` selects padding.
    return torch.le(inside, 0)


def load_wav_to_torch(full_path):
    """Read a WAV file and return ``(samples, sampling_rate)``.

    The samples are converted to float32 without rescaling and wrapped
    in a ``torch.FloatTensor``.
    """
    sampling_rate, samples = read(full_path)
    audio = torch.FloatTensor(samples.astype(np.float32))
    return audio, sampling_rate


def load_filepaths_and_text(dataset_path, filename, split="|"):
    """Parse a file list of ``<relative path><split><text>`` lines.

    Args:
        dataset_path: root directory joined in front of every path.
        filename: UTF-8 text file to read, one entry per line.
        split: field separator used inside each line.

    Returns:
        A list of ``(path, text)`` tuples in file order.

    Raises:
        Exception: if any line (after stripping surrounding whitespace)
            does not consist of exactly two fields. This includes blank
            lines and lines without the separator.
    """
    filepaths_and_text = []
    with open(filename, encoding='utf-8') as f:
        for line in f:
            parts = line.strip().split(split)
            # Require exactly two fields. Checking only for "too many"
            # would let short lines fall through to an IndexError.
            if len(parts) != 2:
                raise Exception(
                    "incorrect line format for file: {}".format(filename))
            relative_path, text = parts
            filepaths_and_text.append(
                (os.path.join(dataset_path, relative_path), text))
    return filepaths_and_text


def to_gpu(x):
    """Return ``x`` made contiguous and, when CUDA is available, moved to the GPU.

    The transfer is requested with ``non_blocking=True``. Without CUDA
    the contiguous tensor is returned on its current device.
    """
    x = x.contiguous()
    if torch.cuda.is_available():
        x = x.cuda(non_blocking=True)
    return x