"packaging/vscode:/vscode.git/clone" did not exist on "6b242c29af317498ea51b465529cf8e68c2c88fd"
Commit 2f59ed25 authored by burchim

Demo Notebook

parent 19482bf0
.gitignore:
__pycache__
*.pyc
.DS_Store
.vscode
env/*
{"nbformat":4,"nbformat_minor":0,"metadata":{"accelerator":"GPU","colab":{"name":"EfficientConformer.ipynb","provenance":[],"collapsed_sections":[]},"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"markdown","metadata":{"id":"yA9TPERPtBUL"},"source":["#Efficient Conformer Demo\n","A quick intro to using pretrained models and how to train/evaluate models.<br>\n","repo: [https://github.com/burchim/EfficientConformer](https://github.com/burchim/EfficientConformer)"]},{"cell_type":"markdown","metadata":{"id":"-I6v5ThmlRVp"},"source":["# Install"]},{"cell_type":"code","metadata":{"id":"ugmpSZEa3g13"},"source":["!git clone https://github.com/burchim/EfficientConformer.git "],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"iIvKGdgElYih"},"source":["import os\n","os.chdir('EfficientConformer/')"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"ClzKWI_TFKZK"},"source":["!pip install -r requirements.txt"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"xr-rBVhM7mBP"},"source":["!git clone --recursive https://github.com/parlance/ctcdecode.git\n","!cd ctcdecode && pip install ."],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"61bc416smy2E"},"source":["# Download pretrained models and tokenizer"]},{"cell_type":"code","metadata":{"id":"l3KrXieZqSDm"},"source":["!pip install gdown"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"9Wg0BtnPTIpW"},"source":["pretrained_models = {\n"," \"EfficientConformerCTCSmall\": \"1MU49nbRONkOOGzvXHFDNfvWsyFmrrBam\",\n"," \"EfficientConformerCTCMedium\": \"1h5hRG9T_nErslm5eGgVzqx7dWDcOcGDB\",\n"," \"EfficientConformerCTCLarge\": \"1U4iBTKQogX4btE-S4rqCeeFZpj3gcweA\"\n","}"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"okSOe0wTT9zp"},"source":["# Select one of the official pretrained models\n","pretrained_model = \"EfficientConformerCTCSmall\""],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"ihk81osmsHAz"},"source":["import gdown\n","\n","# Create model callback directory\n","if not os.path.exists(os.path.join(\"callbacks\", pretrained_model)):\n"," os.mkdir(os.path.join(\"callbacks\", pretrained_model))\n","\n","# Download pretrained model checkpoint\n","gdown.download(\"https://drive.google.com/uc?id=\" + pretrained_models[pretrained_model], os.path.join(\"callbacks\", pretrained_model, \"checkpoints_swa-equal-401-450.ckpt\"), quiet=False)\n","\n","# Create tokenizer directory\n","if not os.path.exists(\"datasets/LibriSpeech\"):\n"," os.mkdir(\"datasets/LibriSpeech\")\n","\n","# Download pretrained model tokenizer\n","gdown.download(\"https://drive.google.com/uc?id=1hx2s4ZTDsnOFtx5_h5R_qZ3R6gEFafRx\", \"datasets/LibriSpeech/LibriSpeech_bpe_256.model\", quiet=False)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"PdXrPEoaslUq"},"source":["# Test model on LibriSpeech samples"]},{"cell_type":"code","metadata":{"id":"G9TRAOYhGKHH"},"source":["# Download LibriSPeech dev-clean subset\n","!cd datasets && wget https://www.openslr.org/resources/12/dev-clean.tar.gz && tar xzf dev-clean.tar.gz\n","\n","# Download LibriSPeech dev-other subset\n","!cd datasets && wget https://www.openslr.org/resources/12/dev-other.tar.gz && tar xzf dev-other.tar.gz"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"NaaXsV62ux3X"},"source":["import json\n","import glob\n","import 
torch\n","import torchaudio\n","import IPython.display as ipd\n","from functions import create_model\n","import matplotlib.pyplot as plt\n"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"V0dY7IBquiRC"},"source":["config_file = \"configs/\" + pretrained_model + \".json\"\n","\n","# Load model Config\n","with open(config_file) as json_config:\n"," config = json.load(json_config)\n","\n","# PyTorch Device\n","device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n","print(\"Device:\", device)\n","\n","# Create and Load pretrained model\n","model = create_model(config).to(device)\n","model.summary()\n","model.eval()\n","model.load(os.path.join(\"callbacks\", pretrained_model, \"checkpoints_swa-equal-401-450.ckpt\"))"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"_EidY8gF2Z0_"},"source":["# Get audio files paths\n","audio_files = glob.glob(\"datasets/LibriSpeech/*/*/*/*.flac\")\n","print(len(audio_files), \"audio files\")"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"9qMnfHxZzqvX"},"source":["# Random indices\n","indices = torch.randint(0, len(audio_files), size=(10,))\n","\n","# Test model\n","for i in indices:\n","\n"," # Load audio file\n"," audio, sr = torchaudio.load(audio_files[i])\n","\n"," # Plot audio\n"," plt.title(audio_files[i].split(\"/\")[-1])\n"," plt.plot(audio[0])\n"," plt.show()\n"," print()\n","\n"," # Display\n"," ipd.display(ipd.Audio(audio, rate=sr))\n"," print()\n","\n"," # Predict sentence\n"," prediction = model.gready_search_decoding(audio.to(device), x_len=torch.tensor([len(audio[0])], device=device))[0]\n"," print(\"model prediction:\", prediction, '\\n')\n","\n"," for i in range(100):\n"," print('*', end='')\n"," print('\\n')\n"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"zCc9S6BgQ1M5"},"source":["# Training\n","Download the LibriSpeech dataset using:\n","\n","- `cd datasets && bash ./download_LibriSpeech.sh`\n","\n","Or download LibriSpeech train-clean 100h subset with:\n","\n","- `cd datasets && wget https://www.openslr.org/resources/12/train-clean-100.tar.gz && tar xzf datasets/train-clean-100.tar.gz`"]},{"cell_type":"code","metadata":{"id":"y7A4c0x0RnJq"},"source":["# Download LibriSPeech train-clean-100 subset\n","!cd datasets && wget https://www.openslr.org/resources/12/train-clean-100.tar.gz && tar xzf train-clean-100.tar.gz"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"HdYQ6GDeca6M"},"source":["Train an Efficient Conformer CTC Small model.<br>\n","The `--prepare_dataset` flag will tokenize text sequences and save samples length before training/evaluation.<br>\n","Use the `--create_tokenizer` flag if you need to create a new sentencepiece tokenizer.<br>\n","Training mode is selected by default."]},{"cell_type":"code","metadata":{"id":"0HTj66OxQ4in"},"source":["# Prepare dataset and train model\n","!python main.py --config_file configs/EfficientConformerCTCSmall.json --prepare_dataset"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"Ea0W2NwVbOCe"},"source":["# Evaluation\n","Proceed to a gready search evaluation.\n","Use the `--mode` flag to select an evaluation mode:\n","\n","- `validation-clean` for evaluation on the LibriSpeech dev-clean validation set.\n","- `validation-other` for evaluation on the LibriSpeech dev-other validation set.\n","- `test-clean` for evaluation on the LibriSpeech test-clean test set.\n","- `test-other` for evaluation on 
the LibriSpeech test-other test set.\n","- `eval_time` to evaluate model inference time on the LibriSpeech dev-clean validation set.\n","\n","Select a model checkpoint to load for evaluation using the `--initial_epoch` flag.<br>\n","For example, `--initial_epoch swa-equal-401-450` will load the pretrained checkpoints_swa-equal-401-450.ckpt file."]},{"cell_type":"code","metadata":{"id":"SXwRILONbbkG"},"source":["!python main.py --config_file configs/EfficientConformerCTCSmall.json --mode validation-clean --initial_epoch swa-equal-401-450 --gready"],"execution_count":null,"outputs":[]}]}
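As a quick sanity check outside `main.py`, the greedy predictions from the test loop above can also be scored against the LibriSpeech reference transcripts. The sketch below is illustrative only: it reuses the `model` and `device` objects created earlier, the `reference_for` helper is ad hoc, and it assumes the `jiwer` package for the WER computation (not part of the repository's requirements). For reported numbers, prefer the `--mode validation-*` / `test-*` runs above; this is just a spot check on a handful of files.

```python
# pip install jiwer   (external WER utility, assumed here; not used by the repo itself)
import glob
import os
import torch
import torchaudio
import jiwer

def reference_for(flac_path):
    """Look up the LibriSpeech reference transcript for one .flac file."""
    utt_id = os.path.basename(flac_path).replace(".flac", "")
    spk, chap, _ = utt_id.split("-")
    trans_file = os.path.join(os.path.dirname(flac_path), f"{spk}-{chap}.trans.txt")
    with open(trans_file) as f:
        for line in f:
            uid, text = line.strip().split(" ", 1)
            if uid == utt_id:
                return text
    raise KeyError(utt_id)

refs, hyps = [], []
for path in sorted(glob.glob("datasets/LibriSpeech/dev-clean/*/*/*.flac"))[:20]:  # small sample
    audio, sr = torchaudio.load(path)
    with torch.no_grad():
        hyp = model.gready_search_decoding(
            audio.to(device), x_len=torch.tensor([audio.shape[1]], device=device))[0]
    refs.append(reference_for(path).lower())
    hyps.append(hyp.lower())

print("Greedy WER on sample:", jiwer.wer(refs, hyps))
```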
README.md: @@ -2,6 +2,8 @@
Official implementation of the Efficient Conformer, a progressively downsampled Conformer with grouped attention for Automatic Speech Recognition.
**Efficient Conformer [Paper](https://arxiv.org/abs/2109.01163) | [Demo Notebook](https://colab.research.google.com/github/burchim/EfficientConformer/blob/master/EfficientConformer.ipynb)**
## Efficient Conformer Encoder
Inspired by previous work in Automatic Speech Recognition and Computer Vision, the Efficient Conformer encoder is composed of three encoder stages, where each stage comprises a number of Conformer blocks using grouped attention. The encoded sequence is progressively downsampled and projected to wider feature dimensions, lowering the amount of computation while achieving better performance. Grouped multi-head attention reduces attention complexity by grouping neighbouring time elements along the feature dimension before applying scaled dot-product attention.
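To make the grouped-attention idea concrete, here is a minimal, self-contained sketch. It is not the repository's `MultiHeadSelfAttentionModule`: the class name, the `group_size` argument, and the use of `nn.MultiheadAttention` are illustrative. Neighbouring frames are folded into the feature dimension, standard scaled dot-product attention runs on the shorter sequence, and the output is unfolded back.

```python
import torch
import torch.nn as nn

class GroupedMHSASketch(nn.Module):
    """Illustrative sketch: fold `group_size` neighbouring frames into the feature
    axis before standard multi-head attention, shrinking the attended length."""

    def __init__(self, dim_model, num_heads, group_size):
        super().__init__()
        self.group_size = group_size
        # Attention operates on the widened, grouped feature dimension.
        self.mhsa = nn.MultiheadAttention(dim_model * group_size, num_heads, batch_first=True)

    def forward(self, x):  # x: (batch, time, dim_model)
        b, t, d = x.shape
        g = self.group_size
        pad = (g - t % g) % g                      # pad time axis to a multiple of g
        x = nn.functional.pad(x, (0, 0, 0, pad))   # (masking of padded frames omitted for brevity)
        xg = x.reshape(b, (t + pad) // g, d * g)   # group: (batch, time/g, dim_model*g)
        out, _ = self.mhsa(xg, xg, xg)             # standard scaled dot-product attention
        return out.reshape(b, t + pad, d)[:, :t]   # ungroup and drop padding

# Example: a 15-frame, 64-dim sequence attended over only ceil(15/4) grouped positions
x = torch.randn(2, 15, 64)
print(GroupedMHSASketch(dim_model=64, num_heads=4, group_size=4)(x).shape)  # torch.Size([2, 15, 64])
```

With group size g, the attention matrix shrinks from T x T to roughly (T/g) x (T/g), which is where the compute saving comes from.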
@@ -423,7 +423,7 @@ class MultiHeadSelfAttentionModule(nn.Module):

         # Pre Norm
         self.norm = nn.LayerNorm(dim_model, eps=1e-6)

-        # Efficient Multi-Head Attention
+        # Multi-Head Linear Attention
         if linear_att:
             self.mhsa = MultiHeadLinearAttention(dim_model, num_heads)
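For readers unfamiliar with the `linear_att` branch touched above: linear attention avoids forming the T x T softmax attention matrix and instead computes a d x d context that scales linearly in sequence length. The toy, single-head function below shows one common formulation (softmax applied separately to queries and keys, then reassociated); it is illustrative only and not the repository's `MultiHeadLinearAttention`.

```python
import torch

def linear_attention(q, k, v):
    """Toy single-head linear attention: normalize q and k separately, then use
    associativity so the (T x T) attention matrix is never materialized."""
    q = q.softmax(dim=-1)                # (T, d): softmax over the feature axis
    k = k.softmax(dim=-2)                # (T, d): softmax over the time axis
    context = k.transpose(-2, -1) @ v    # (d, d): global context, cost O(T * d^2)
    return q @ context                   # (T, d)

# Example: 1000-frame sequence with a 64-dim head
q, k, v = (torch.randn(1000, 64) for _ in range(3))
print(linear_attention(q, k, v).shape)  # torch.Size([1000, 64])
```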