# eval.py: evaluates a trained RNN-T checkpoint on the test set and reports word error rate (WER).
import logging
from argparse import ArgumentParser

import sentencepiece as spm
import torch
import torchaudio
from transforms import get_data_module


# Module-level logger named after this module; run_eval reports its progress through it.
logger = logging.getLogger(__name__)


def compute_word_level_distance(seq1, seq2):
    """Return the case-insensitive, word-level edit distance between two strings.

    Both inputs are lower-cased and split on whitespace; the resulting word
    lists are compared with ``torchaudio.functional.edit_distance``.
    """
    words1 = seq1.lower().split()
    words2 = seq2.lower().split()
    return torchaudio.functional.edit_distance(words1, words2)


def get_lightning_module(args):
    """Build the model for the requested modality and load its checkpoint.

    Args:
        args: Parsed command-line arguments. Reads ``sp_model_path``, ``md``
            and ``checkpoint_path``; the whole namespace is also passed to the
            module constructor.

    Returns:
        The constructed module with the checkpoint's ``"state_dict"`` loaded,
        switched to evaluation mode.
    """
    tokenizer = spm.SentencePieceProcessor(model_file=str(args.sp_model_path))

    # Only the module for the selected modality is imported; any value other
    # than "av" selects the plain ConformerRNNTModule.
    if args.md != "av":
        from lightning import ConformerRNNTModule as module_cls
    else:
        from lightning_av import AVConformerRNNTModule as module_cls
    module = module_cls(args, tokenizer)

    def _keep_storage(storage, loc):
        # Return each storage unchanged instead of remapping it to another device.
        return storage

    checkpoint = torch.load(args.checkpoint_path, map_location=_keep_storage)
    module.load_state_dict(checkpoint["state_dict"])
    module.eval()
    return module


def run_eval(model, data_module):
    """Evaluate ``model`` on the test set and return the word error rate (WER).

    Args:
        model: Callable that maps a batch to its predicted transcript. The
            prediction is lower-cased and split into words, so it is expected
            to be a string.
        data_module: Object whose ``test_dataloader()`` yields
            ``(batch, sample)`` pairs.

    Returns:
        The summed word-level edit distance divided by the total number of
        reference words, or ``float("nan")`` if no reference words were seen
        (empty test set, or only empty reference transcripts).
    """
    total_edit_distance = 0
    total_length = 0

    def current_wer():
        # The running word count can still be zero (e.g. the very first
        # reference is empty); report NaN instead of raising ZeroDivisionError
        # and aborting the whole evaluation.
        return total_edit_distance / total_length if total_length else float("nan")

    dataloader = data_module.test_dataloader()
    with torch.no_grad():
        for idx, (batch, sample) in enumerate(dataloader):
            # NOTE(review): sample[0][-1] is taken to be the reference transcript
            # string for this item -- confirm against the dataset definition.
            actual = sample[0][-1]
            predicted = model(batch)
            total_edit_distance += compute_word_level_distance(actual, predicted)
            total_length += len(actual.split())
            # Report the running WER every 100 items.
            if idx % 100 == 0:
                logger.warning("Processed elem %d; WER: %s", idx, current_wer())
    final_wer = current_wer()
    logger.warning("Final WER: %s", final_wer)
    return final_wer


def parse_args():
    """Parse the command-line arguments of the evaluation script.

    Required options: ``--md``, ``--mode``, ``--root-dir``, ``--sp-model-path``
    and ``--checkpoint-path``. ``--pretrained-model-path`` is optional and
    ``--debug`` is a boolean flag.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--md",
        type=str,
        help="Modality",
        required=True,
    )
    parser.add_argument(
        "--mode",
        type=str,
        help="Perform online or offline recognition.",
        required=True,
    )
    parser.add_argument(
        "--root-dir",
        type=str,
        help="Root directory to LRS3 audio-visual datasets.",
        required=True,
    )
    parser.add_argument(
        "--sp-model-path",
        type=str,
        help="Path to SentencePiece model.",
        required=True,
    )
    parser.add_argument(
        "--checkpoint-path",
        type=str,
        help="Path to checkpoint model.",
        required=True,
    )
    parser.add_argument(
        "--pretrained-model-path",
        type=str,
        help="Path to Pretrained model.",
    )
    parser.add_argument("--debug", action="store_true", help="whether to use debug level for logging")
    return parser.parse_args()


def init_logger(debug):
    """Configure the root logger through ``logging.basicConfig``.

    Args:
        debug: When truthy, log at DEBUG level and prefix each message with a
            timestamp; otherwise log at INFO level with the bare message.
    """
    if debug:
        level, fmt = logging.DEBUG, "%(asctime)s %(message)s"
    else:
        level, fmt = logging.INFO, "%(message)s"
    logging.basicConfig(format=fmt, level=level, datefmt="%Y-%m-%d %H:%M:%S")


def cli_main():
    """Command-line entry point: parse arguments, set up logging, load the model and evaluate it."""
    cli_args = parse_args()
    init_logger(cli_args.debug)
    lightning_module = get_lightning_module(cli_args)
    # The data module receives the SentencePiece model path as a plain string.
    run_eval(lightning_module, get_data_module(cli_args, str(cli_args.sp_model_path)))


# Run the evaluation only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    cli_main()