{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tacotron 2 + WaveGlow interactive text-to-speech demo\n",
    "\n",
    "Type into the text box below. On every change the text is preprocessed and sent to a Triton Inference Server: the `tacotron2-ts-script` model returns a mel-spectrogram, and the `waveglow-tensorrt` model turns it into a waveform that is plotted and played back.\n",
    "\n",
    "**Prerequisite:** a Triton server reachable at the `url` configured below, serving both models."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "import collections\n",
    "import os\n",
    "import time\n",
    "\n",
    "import ipywidgets as widgets\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import tritonhttpclient as thc\n",
    "from IPython.display import Audio, display, clear_output, Markdown, Image\n",
    "from matplotlib import cm as cm\n",
    "\n",
    "# librosa is not needed: display_sound() draws a hop-averaged waveform with\n",
    "# matplotlib instead of calling librosa.display.waveplot.\n",
    "from tacotron2.text import text_to_sequence as text_to_sequence_internal\n",
    "from tacotron2.text.symbols import symbols"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Configuration and Triton client"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "defaults = {\n",
    "    # settings\n",
    "    'sigma_infer': 0.6,          # not read by the code in this notebook; do not change\n",
    "    'sampling_rate': 22050,      # playback rate handed to Audio(); do not change\n",
    "    'stft_hop_length': 256,      # audio samples per mel frame: sizes the WaveGlow noise input,\n",
    "                                 # the trim length and the waveform plot; do not change\n",
    "    'url': 'localhost:8000',     # address passed to the Triton HTTP client\n",
    "    'autoplay': True,            # start playback as soon as new audio is ready\n",
    "    'character_limit_min': 4,    # shorter input is ignored\n",
    "    'character_limit_max': 340   # longer input is truncated to this many characters\n",
    "}\n",
    "\n",
    "\n",
    "class Struct:\n",
    "    ''' exposes keyword arguments as attributes: Struct(a=1).a == 1 '''\n",
    "\n",
    "    def __init__(self, **entries):\n",
    "        self.__dict__.update(entries)\n",
    "\n",
    "\n",
    "# create args object\n",
    "args = Struct(**defaults)\n",
    "\n",
    "# one HTTP client, shared by sequence_to_mel() and mel_to_signal()\n",
    "triton_client = thc.InferenceServerClient(args.url)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot helpers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def display_sound(signal, title, color):\n",
    "    ''' displays signal\n",
    "    ::signal:: waveform array; only its first row, signal[0], is drawn\n",
    "    ::title:: figure title\n",
    "    ::color:: matplotlib color of the curve\n",
    "    '''\n",
    "    clear_output(wait=True)\n",
    "    fig, ax = plt.subplots(figsize=(10, 2.5))\n",
    "    ax.set_title(title)\n",
    "    ax.tick_params(\n",
    "        axis='both',\n",
    "        which='both',\n",
    "        bottom=True,\n",
    "        top=False,\n",
    "        left=False,\n",
    "        right=False,\n",
    "        labelbottom=True,\n",
    "        labelleft=False)\n",
    "    sig = signal[0]\n",
    "    hop = args.stft_hop_length\n",
    "    # one plotted point per hop: the mean of each hop-sized chunk\n",
    "    # (the last chunk may be shorter than hop)\n",
    "    smoothed = [np.average(sig[start:start + hop])\n",
    "                for start in range(0, len(sig), hop)]\n",
    "    ax.plot(smoothed, color=color)\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "def display_spectrogram(mel, title):\n",
    "    ''' displays mel spectrogram\n",
    "    ::mel:: spectrogram array; only its first entry, mel[0], is drawn\n",
    "    ::title:: figure title\n",
    "    '''\n",
    "    clear_output(wait=True)\n",
    "    fig, ax = plt.subplots(figsize=(10, 2.5))\n",
    "    ax.set_title(title)\n",
    "    ax.tick_params(\n",
    "        axis='both',\n",
    "        which='both',\n",
    "        bottom=True,\n",
    "        top=False,\n",
    "        left=False,\n",
    "        right=False,\n",
    "        labelbottom=True,\n",
    "        labelleft=False)\n",
    "    ax.set_xlabel('Time')\n",
    "    # plt.get_cmap keeps working where matplotlib.cm.get_cmap has been removed\n",
    "    cmap = plt.get_cmap('jet', 30)\n",
    "    # imshow is fed float32, whatever dtype the server returned\n",
    "    ax.imshow(mel[0].astype(np.float32), interpolation='nearest', cmap=cmap)\n",
    "    ax.grid(True)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inference helpers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def text_to_sequence(text):\n",
    "    ''' preprocessor of tacotron2\n",
    "    ::text:: the input str\n",
    "    ::returns:: sequence, the preprocessed text as an int64 numpy array\n",
    "    '''\n",
    "    sequence = text_to_sequence_internal(text, ['english_cleaners'])\n",
    "    return np.array(sequence, dtype=np.int64)\n",
    "\n",
    "\n",
    "def sequence_to_mel(sequence):\n",
    "    ''' calls tacotron2\n",
    "    ::sequence:: int64 numpy array, contains the preprocessed text\n",
    "    ::returns:: (mel, mel_lengths, alignments) tuple\n",
    "        mel is the mel-spectrogram, np.array\n",
    "        mel_lengths contains the length of the unpadded mel, np.array\n",
    "        alignments contains attention weights, np.array\n",
    "    '''\n",
    "    # the model is called with a batch of one: shape (1, sequence_length)\n",
    "    sequence = np.reshape(sequence, (1, -1))\n",
    "    input_lengths = np.array([[sequence.shape[1]]], dtype=np.int64)\n",
    "    # prepare input/output\n",
    "    inputs = [\n",
    "        thc.InferInput('input__0', sequence.shape, 'INT64'),\n",
    "        thc.InferInput('input__1', input_lengths.shape, 'INT64'),\n",
    "    ]\n",
    "    inputs[0].set_data_from_numpy(sequence, binary_data=True)\n",
    "    inputs[1].set_data_from_numpy(input_lengths, binary_data=True)\n",
    "    outputs = [\n",
    "        thc.InferRequestedOutput(name, binary_data=True)\n",
    "        for name in ('output__0', 'output__1', 'output__2')\n",
    "    ]\n",
    "    # call tacotron2\n",
    "    result = triton_client.infer(model_name='tacotron2-ts-script',\n",
    "                                 inputs=inputs,\n",
    "                                 outputs=outputs)\n",
    "    # get results\n",
    "    mel = result.as_numpy('output__0')\n",
    "    mel_lengths = result.as_numpy('output__1')\n",
    "    alignments = result.as_numpy('output__2')\n",
    "    return mel, mel_lengths, alignments\n",
    "\n",
    "\n",
    "def mel_to_signal(mel, mel_lengths):\n",
    "    ''' calls waveglow\n",
    "    ::mel:: mel spectrogram\n",
    "    ::mel_lengths:: original length of mel spectrogram\n",
    "    ::returns:: waveform, float32 numpy array\n",
    "    '''\n",
    "    # prepare input/output\n",
    "    mel = mel[:, :, :, None]  # append a trailing singleton axis\n",
    "    hop = args.stft_hop_length  # audio samples generated per mel frame\n",
    "    n_group = 8\n",
    "    n_frames = mel.shape[2]\n",
    "    z_size = n_frames * hop // n_group\n",
    "    # noise input, cast to the dtype of mel; both inputs are declared FP16 below,\n",
    "    # so mel is presumably float16 here -- TODO confirm against the tacotron2 model config\n",
    "    z = np.random.normal(0.0, 1.0, (1, n_group, z_size, 1)).astype(mel.dtype)\n",
    "\n",
    "    inputs = [\n",
    "        thc.InferInput('mel', mel.shape, 'FP16'),\n",
    "        thc.InferInput('z', z.shape, 'FP16'),\n",
    "    ]\n",
    "    inputs[0].set_data_from_numpy(mel, binary_data=True)\n",
    "    inputs[1].set_data_from_numpy(z, binary_data=True)\n",
    "    outputs = [thc.InferRequestedOutput('audio', binary_data=True)]\n",
    "    # call waveglow\n",
    "    result = triton_client.infer(model_name='waveglow-tensorrt',\n",
    "                                 inputs=inputs,\n",
    "                                 outputs=outputs)\n",
    "    # get the results\n",
    "    signal = result.as_numpy('audio')\n",
    "    # postprocessing of waveglow: trimming signal to its actual size.\n",
    "    # Use the unpadded frame count reported by tacotron2 when it is usable,\n",
    "    # otherwise fall back to the padded frame count.\n",
    "    n_valid = n_frames\n",
    "    if mel_lengths is not None and np.size(mel_lengths) > 0:\n",
    "        reported = int(np.ravel(mel_lengths)[0])\n",
    "        if 0 < reported < n_frames:\n",
    "            n_valid = reported\n",
    "    # samples live on the last axis (display_sound reads signal[0] as the waveform),\n",
    "    # so the trim must not be applied to axis 0\n",
    "    signal = signal[..., :n_valid * hop]\n",
    "    return signal.astype(np.float32)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Interactive demo"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# widgets\n",
    "def get_output_widget(width, height):\n",
    "    ''' creates an output widget with default values and returns it\n",
    "    ::width:: CSS width of the widget, e.g. '10in'\n",
    "    ::height:: CSS height of the widget, e.g. '2.1in'\n",
    "    '''\n",
    "    layout = widgets.Layout(width=width,\n",
    "                            height=height,\n",
    "                            object_fit='fill',\n",
    "                            object_position='center center')\n",
    "    return widgets.Output(layout=layout)\n",
    "\n",
    "\n",
    "text_area = widgets.Textarea(\n",
    "    value='type here',\n",
    "    placeholder='',\n",
    "    description='',\n",
    "    disabled=False,\n",
    "    continuous_update=True,\n",
    "    layout=widgets.Layout(width='550px', height='80px')\n",
    ")\n",
    "\n",
    "\n",
    "plot_spectrogram = get_output_widget(width='10in', height='2.1in')\n",
    "plot_signal = get_output_widget(width='10in', height='2.1in')\n",
    "plot_play = get_output_widget(width='10in', height='1in')\n",
    "\n",
    "\n",
    "def text_area_change(change):\n",
    "    ''' this gets called each time text_area.value changes\n",
    "    ::change:: change dict of the widget; only change['new'] is read\n",
    "    '''\n",
    "    text = change['new'].strip(' ')\n",
    "    length = len(text)\n",
    "    if length < args.character_limit_min:  # too short text\n",
    "        return\n",
    "    if length > args.character_limit_max:  # too long text\n",
    "        # assigning the truncated text fires this observer again,\n",
    "        # and that second call does the synthesis\n",
    "        text_area.value = text[:args.character_limit_max]\n",
    "        return\n",
    "    # preprocess tacotron2\n",
    "    sequence = text_to_sequence(text)\n",
    "    # run tacotron2\n",
    "    mel, mel_lengths, alignments = sequence_to_mel(sequence)\n",
    "    with plot_spectrogram:\n",
    "        display_spectrogram(mel, change['new'])\n",
    "    # run waveglow\n",
    "    signal = mel_to_signal(mel, mel_lengths)\n",
    "    with plot_signal:\n",
    "        display_sound(signal, change['new'], 'green')\n",
    "    with plot_play:\n",
    "        clear_output(wait=True)\n",
    "        display(Audio(signal, rate=args.sampling_rate, autoplay=args.autoplay))\n",
    "        # related issue: https://github.com/ipython/ipython/issues/11316\n",
    "\n",
    "\n",
    "# setup callback\n",
    "text_area.observe(text_area_change, names='value')\n",
    "\n",
    "# decorative widgets\n",
    "empty = widgets.VBox([], layout=widgets.Layout(height='1in'))\n",
    "markdown_4 = Markdown('**tacotron2 input**')\n",
    "markdown_6 = Markdown('**tacotron2 output / waveglow input**')\n",
    "markdown_7 = Markdown('**waveglow output**')\n",
    "markdown_8 = Markdown('**play**')\n",
    "\n",
    "# display widgets\n",
    "display(\n",
    "    empty,\n",
    "    markdown_4, text_area,\n",
    "    markdown_6, plot_spectrogram,\n",
    "    markdown_7, plot_signal,\n",
    "    markdown_8, plot_play,\n",
    "    empty\n",
    ")\n",
    "\n",
    "# default text; assigning it triggers text_area_change and runs the first synthesis\n",
    "text_area.value = \"The forms of printed letters should be beautiful, and that their arrangement on the page should be reasonable and a help to the shapeliness of the letters themselves.\""
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}