import logging
from argparse import ArgumentParser

import sentencepiece as spm
import torch
import torchaudio
from transforms import get_data_module


logger = logging.getLogger(__name__)


def compute_word_level_distance(seq1, seq2):
    """Return the word-level edit distance between two transcripts.

    Both inputs are lowercased and whitespace-tokenized before the
    Levenshtein distance is computed, so comparison is case-insensitive.
    """
    reference_words = seq1.lower().split()
    hypothesis_words = seq2.lower().split()
    return torchaudio.functional.edit_distance(reference_words, hypothesis_words)


def get_lightning_module(args):
    """Build the evaluation model for the requested modality.

    Loads the sentencepiece model, constructs either the audio-visual or
    the audio-only Conformer-RNNT Lightning module, restores weights from
    ``args.checkpoint_path``, and returns the module in eval mode.
    """
    sp_model = spm.SentencePieceProcessor(model_file=str(args.sp_model_path))

    # Import lazily so only the needed (heavyweight) module is loaded.
    if args.modality == "audiovisual":
        from lightning_av import AVConformerRNNTModule as ModuleClass
    else:
        from lightning import ConformerRNNTModule as ModuleClass

    model = ModuleClass(args, sp_model)
    # map_location keeps tensors on CPU regardless of where they were saved.
    state_dict = torch.load(args.checkpoint_path, map_location=lambda storage, loc: storage)["state_dict"]
    model.load_state_dict(state_dict)
    model.eval()
    return model


def run_eval(model, data_module):
    """Evaluate ``model`` on the test set and return the word error rate.

    Iterates the test dataloader, accumulating word-level edit distance
    against the reference transcript (``sample[0][-1]``) and total
    reference-word count. Logs a running WER every 100 elements.

    Raises:
        ValueError: if no reference words were seen (empty dataloader or
            all-empty transcripts), since WER would be undefined.
    """
    total_edit_distance = 0
    total_length = 0
    dataloader = data_module.test_dataloader()
    with torch.no_grad():
        for idx, (batch, sample) in enumerate(dataloader):
            actual = sample[0][-1]
            predicted = model(batch)
            total_edit_distance += compute_word_level_distance(actual, predicted)
            total_length += len(actual.split())
            # Guard total_length > 0: an empty first transcript would
            # otherwise raise ZeroDivisionError on the idx == 0 log line.
            if idx % 100 == 0 and total_length > 0:
                # warning level so the progress line survives default filtering;
                # lazy %-args avoid formatting when the record is filtered out.
                logger.warning("Processed elem %d; WER: %f", idx, total_edit_distance / total_length)
    if total_length == 0:
        raise ValueError("No reference words seen during evaluation; WER is undefined.")
    wer = total_edit_distance / total_length
    logger.warning("Final WER: %f", wer)
    return wer


def parse_args():
    """Define and parse the command-line interface for evaluation."""
    parser = ArgumentParser()
    # All string options share the same shape (required, type=str), so
    # register them table-driven: (flag, help text).
    required_string_options = [
        ("--modality", "Modality"),
        ("--mode", "Perform online or offline recognition."),
        ("--root-dir", "Root directory to LRS3 audio-visual datasets."),
        ("--sp-model-path", "Path to sentencepiece model."),
        ("--checkpoint-path", "Path to a checkpoint model."),
    ]
    for flag, help_text in required_string_options:
        parser.add_argument(flag, type=str, help=help_text, required=True)
    parser.add_argument("--debug", action="store_true", help="whether to use debug level for logging")
    return parser.parse_args()


def init_logger(debug):
    """Configure root logging.

    Debug mode adds timestamps and lowers the level to DEBUG; otherwise
    messages are emitted bare at INFO level.
    """
    if debug:
        fmt = "%(asctime)s %(message)s"
        level = logging.DEBUG
    else:
        fmt = "%(message)s"
        level = logging.INFO
    logging.basicConfig(format=fmt, level=level, datefmt="%Y-%m-%d %H:%M:%S")


def cli_main():
    """Script entry point: parse args, set up logging, build the model, evaluate."""
    args = parse_args()
    init_logger(args.debug)
    lightning_module = get_lightning_module(args)
    test_data_module = get_data_module(args, str(args.sp_model_path))
    run_eval(lightning_module, test_data_module)


if __name__ == "__main__":
    cli_main()