Unverified Commit 0ff6d266 authored by Pingchuan Ma, committed by GitHub

replace avsr model used in the tutorial (#3602)

* replace model used in the tutorial

* Upload a torchscript model; remove model components

* Update download url

* Switch from download_url_to_file to download_asset
parent 94aafd83
@@ -27,18 +27,11 @@ attention due to its robustness against noise.
 
 .. note::
-   We do not have any pre-trained models available at this time. The
-   following recipe uses placedholders for the sentencepiece model path
-   ``spm_model_path`` and the pretrained model path ``avsr_model_path``.
-   If you are interested in the training recipe for real-time AV-ASR
-   models (AV-ASR), it can be found at `real-time
-   AV-ASR <https://github.com/pytorch/audio/tree/main/examples/avsr>`__
-   recipe.
+   To run this tutorial, please make sure you are in the `tutorial` folder.
 
 .. note::
-   To run this tutorial, please make sure you are in the `tutorial` folder.
+   We tested the tutorial on torchaudio version 2.0.2 on Macbook Pro (M1 Pro).
 
 """
@@ -48,6 +41,7 @@ import torch
 import torchaudio
 import torchvision
 
+
 ######################################################################
 # Overview
 # --------
@@ -172,7 +166,6 @@ class Preprocessing(torch.nn.Module):
             ),
             FunctionalModule(lambda x: torch.stack(x)),
             torchvision.transforms.Normalize(0.0, 255.0),
-            torchvision.transforms.CenterCrop(44),
             torchvision.transforms.Grayscale(),
             torchvision.transforms.Normalize(0.421, 0.165),
         )
@@ -203,30 +196,6 @@ class Preprocessing(torch.nn.Module):
 # .. image:: https://download.pytorch.org/torchaudio/doc-assets/avsr/architecture.png
 #
 
-from avsr.models.fusion import fusion_module
-from avsr.models.resnet import video_resnet
-from avsr.models.resnet1d import audio_resnet
-
-
-class AVSR(torch.nn.Module):
-    def __init__(
-        self,
-        audio_frontend,
-        video_frontend,
-        fusion,
-        model,
-    ):
-        super().__init__()
-        self.audio_frontend = audio_frontend
-        self.video_frontend = video_frontend
-        self.fusion = fusion
-        self.model = model
-
-    def forward(self, audio, video):
-        audio_features = self.audio_frontend(audio)
-        video_features = self.video_frontend(video)
-        return self.fusion(torch.cat([video_features, audio_features], dim=-1))
-
-
 class SentencePieceTokenProcessor:
     def __init__(self, sp_model):
@@ -269,20 +238,8 @@ class InferencePipeline(torch.nn.Module):
         return transcript
 
 
-def _get_inference_pipeline(avsr_model_config, avsr_model_path, spm_model_path):
-    model = AVSR(
-        audio_frontend=audio_resnet(),
-        video_frontend=video_resnet(),
-        fusion=fusion_module(
-            1024,
-            avsr_model_config["transformer_ffn_dim"],
-            avsr_model_config["input_dim"],
-            avsr_model_config["transformer_dropout"],
-        ),
-        model=torchaudio.models.emformer_rnnt_model(**avsr_model_config),
-    )
-    ckpt = torch.load(avsr_model_path, map_location=lambda storage, loc: storage)["state_dict"]
-    model.load_state_dict(ckpt)
+def _get_inference_pipeline(model_path, spm_model_path):
+    model = torch.jit.load(model_path)
     model.eval()
 
     sp_model = spm.SentencePieceProcessor(model_file=spm_model_path)
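
For background, the change above works because a TorchScript archive carries both the weights and the compiled model code, so the tutorial no longer needs the `avsr.models.*` component definitions at load time. A minimal sketch of that save/load round-trip, using a stand-in module (not the actual AV-ASR model; the class and file name below are placeholders):

import torch

class TinyModel(torch.nn.Module):
    # Stand-in for the real AV-ASR model; only illustrates the TorchScript round-trip.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2

scripted = torch.jit.script(TinyModel())        # compile the eager module to TorchScript
scripted.save("tiny_model.pt")                  # weights and code serialized into one archive
model = torch.jit.load("tiny_model.pt")         # reload; no Python class definition required
model.eval()
print(model(torch.ones(3)))                     # tensor([2., 2., 2.])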
@@ -310,35 +267,15 @@ def _get_inference_pipeline(avsr_model_config, avsr_model_path, spm_model_path):
 # 4. Clean up
 #
 
+from torchaudio.utils import download_asset
+
+
 def main(device, src, option=None):
     print("Building pipeline...")
-    spm_model_path = "../avsr/spm_unigram_1023.model"
-    avsr_model_path = "../avsr/online_avsr_model.pth"
-    avsr_model_config = {
-        "input_dim": 512,
-        "encoding_dim": 1024,
-        "segment_length": 32,
-        "right_context_length": 4,
-        "time_reduction_input_dim": 768,
-        "time_reduction_stride": 1,
-        "transformer_num_heads": 12,
-        "transformer_ffn_dim": 3072,
-        "transformer_num_layers": 20,
-        "transformer_dropout": 0.1,
-        "transformer_activation": "gelu",
-        "transformer_left_context_length": 30,
-        "transformer_max_memory_size": 0,
-        "transformer_weight_init_scale_strategy": "depthwise",
-        "transformer_tanh_on_mem": True,
-        "symbol_embedding_dim": 512,
-        "num_lstm_layers": 3,
-        "lstm_layer_norm": True,
-        "lstm_layer_norm_epsilon": 0.001,
-        "lstm_dropout": 0.3,
-        "num_symbols": 1024,
-    }
-    pipeline = _get_inference_pipeline(avsr_model_config, avsr_model_path, spm_model_path)
+    model_path = download_asset("tutorial-assets/device_avsr_model.pt")
+    spm_model_path = download_asset("tutorial-assets/spm_unigram_1023.model")
+
+    pipeline = _get_inference_pipeline(model_path, spm_model_path)
+
     BUFFER_SIZE = 32
     segment_length = 8
@@ -367,9 +304,9 @@ def main(device, src, option=None):
             video = torch.cat(video_chunks)
             audio = torch.cat(audio_chunks)
             video, audio = cacher(video, audio)
-            pipeline.state, pipeline.hypothesis = None, None
+            pipeline.state, pipeline.hypotheses = None, None
             transcript = pipeline(audio, video.float())
-            print(transcript, end="\r", flush=True)
+            print(transcript, end="", flush=True)
             num_video_frames = 0
             video_chunks = []
             audio_chunks = []
......
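
As I understand it, the switch to `torchaudio.utils.download_asset` means the tutorial now fetches the TorchScript model and the sentencepiece model from the torchaudio asset store on first run and reuses a local cached copy afterwards, instead of expecting files at hard-coded relative paths. A minimal sketch of that assumed behavior, using the asset keys from the diff:

from torchaudio.utils import download_asset

# download_asset fetches the file on first use, caches it locally,
# and returns the path to the cached copy; later runs reuse the cache.
model_path = download_asset("tutorial-assets/device_avsr_model.pt")
spm_model_path = download_asset("tutorial-assets/spm_unigram_1023.model")
print(model_path)
print(spm_model_path)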