import os
import time
import logging
import argparse

# PyTorch; torch_musa is optional and only needed on Moore Threads GPUs
import torch

try:
    import torch_musa
except ImportError:
    print("You should install torch_musa if you want to run on a Moore Threads GPU")

from mooer.models import mooer_model
from mooer.utils.utils import get_device
from mooer.utils.config_utils import parse_asr_configs
from mooer.datasets.speech_dataset_shard import SpeechDatasetShard


def parse_args():
    parser = argparse.ArgumentParser(description='ASR Testing/Decoding Script')
    parser.add_argument('--test_config', type=str, required=True, help='Path to the testing configuration file.')
    parser.add_argument('--test_data_dir', type=str, default='', help='Path to the testing sets.')
    parser.add_argument('--test_sets', type=str, default='', help='Slash-separated test sets under test_data_dir, e.g., aishell1/aishell2')
    parser.add_argument('--decode_path', type=str, required=True, help='Path to save the decoded text for WER computation')
    args = parser.parse_args()
    return args


def main():
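    """Run offline decoding over the configured test sets and write hypotheses to disk."""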
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    
    args = parse_args()
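    # resolve the inference device (CUDA, MUSA, or CPU, whichever get_device reports)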
    device = str(get_device())
    
    configs = parse_asr_configs(args.test_config)
    train_config = configs['TrainConfig']
    model_config = configs['ModelConfig']
    dataset_config = configs['DataConfig']
    # force a single pass over the test data
    logger.info("set num_epochs=1 for testing")
    dataset_config.num_epochs = 1
    
    # update paths
    if os.path.exists(args.test_data_dir):
        dataset_config.test_data_dir = args.test_data_dir
        dataset_config.test_sets = args.test_sets
    
    os.makedirs(args.decode_path, exist_ok=True)

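    # build the MooER model and its tokenizer from the parsed configs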
    model, tokenizer = mooer_model.init_model(
        model_config=model_config,
        train_config=train_config)
    
    model.to(device)
    model.eval()
    
    # dataset_config = generate_dataset_config(train_config, kwargs)
    logger.info("dataset_config: {}".format(dataset_config))
    
    test_data_dir = dataset_config.test_data_dir
    test_sets = dataset_config.test_sets
    decode_path = args.decode_path
    
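    # each test set is read from <test_data_dir>/<test_set>/data.list and its
    # hypotheses are written to <decode_path>/<test_set>/text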
    for test_set in test_sets.strip().split('/'):
        test_set_path = os.path.join(test_data_dir, test_set, "data.list")
        decode_dir = os.path.join(decode_path, test_set)
        os.makedirs(decode_dir, exist_ok=True)
        logging.info(f"Test for {test_set_path}")
        if dataset_config.get('test_data_type', 'shard') == 'shard':
            logging.info("Use shard for training...")
            dataset_test_items = SpeechDatasetShard(dataset_config=dataset_config,
                                                    tokenizer=tokenizer,
                                                    normalize=dataset_config.normalize,
                                                    mel_size=dataset_config.mel_size)
            dataset_test = dataset_test_items.dataset(
                data_type='shard',
                data_list_file=test_set_path,
                shuffle=False,
                partition=False
            )
            collator = dataset_test_items.collator
        
            test_dataloader = torch.utils.data.DataLoader(
                dataset_test,
                num_workers=train_config.num_workers_dataloader,
                pin_memory=True,
                shuffle=False,
                batch_size=train_config.val_batch_size,
                drop_last=False,
                collate_fn=collator
            )
        
        else:
            raise KeyError(f"Unsupported test_data_type: {dataset_config.get('test_data_type')}")
        
        logger.info("=====================================")
        pred_path = os.path.join(decode_dir, 'text')
        ss = time.perf_counter()
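        # pick the autocast precision from the training config (defaults to fp32)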
        dtype = torch.float32
        if train_config.use_fp16:
            dtype = torch.float16
        elif train_config.use_bf16:
            dtype = torch.bfloat16
        logging.info(f"Input data type: {dtype}")
        with torch.no_grad():
            with open(pred_path, "w") as pred:
                for step, batch in enumerate(test_dataloader):
                    for key in batch.keys():
                        batch[key] = batch[key].to(device) if isinstance(batch[key], torch.Tensor) else batch[key]
                    with torch.cuda.amp.autocast(dtype=dtype):
                        model_outputs = model.generate(**batch)
                    output_text = model.tokenizer.batch_decode(model_outputs, add_special_tokens=False,
                                                               skip_special_tokens=True)
                    for key, text, target in zip(batch["keys"], output_text, batch["targets"]):
                        logging.info(f"{key} {text}")
                        pred.write(key + "\t" + text + "\n")
        logging.info(f"Infer {test_set} Cost: {time.perf_counter() - ss}")


if __name__ == "__main__":
    main()
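
# Example invocation (the paths below are illustrative placeholders, not repo defaults):
#   python testing.py \
#       --test_config conf/asr_test.yaml \
#       --test_data_dir data/test_shards \
#       --test_sets aishell1/aishell2 \
#       --decode_path exp/decode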