export_asr_onnx.py 9.11 KB
Newer Older
sunzhq2's avatar
sunzhq2 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
# export_asr_onnx.py
import torch
import numpy as np
import soundfile as sf
from espnet2.bin.asr_inference import Speech2Text # Import the main inference class
# from espnet2.utils.fileio import read_kaldi_ascii_vec # Remove this unused import
import argparse

def get_parser():
    """Build the command-line argument parser for the ONNX export script."""
    parser = argparse.ArgumentParser(description='Export ESPNet ASR model to ONNX')
    # Required inputs: the trained checkpoint and its training config.
    required_args = (
        ('--asr_model_file', 'Path to the trained ASR model .pth file'),
        ('--asr_config', 'Path to the ASR model config.yaml file'),
    )
    for flag, help_text in required_args:
        parser.add_argument(flag, type=str, required=True, help=help_text)
    # Optional output locations for the two exported graphs.
    parser.add_argument('--output_encoder', type=str, default='asr_encoder.onnx', help='Output ONNX file for encoder')
    parser.add_argument('--output_decoder', type=str, default='asr_decoder.onnx', help='Output ONNX file for decoder')
    return parser

class ONNXEncoder(torch.nn.Module):
    """Wrapper that pins the encoder's output arity for ONNX tracing.

    Some encoders return extra values (caches, intermediate states) beyond
    (encoder_out, encoder_out_lens); ONNX export needs a fixed signature,
    so only the first two outputs are forwarded.
    """

    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder  # the underlying ESPNet encoder module

    def forward(self, speech, speech_lengths):
        outputs = self.encoder(speech, speech_lengths)
        # Keep only encoder_out and encoder_out_lens; drop anything else.
        encoder_out, encoder_out_lens = outputs[:2]
        return encoder_out, encoder_out_lens

class ONNXDecoder(torch.nn.Module):
    """ONNX-export wrapper that scores a full hypothesis in one pass.

    Combines the attention decoder (plus its vocabulary projection layer)
    with an optional CTC head so a single exported graph emits both the
    CTC frame-level log-probs and the attention next-token log-probs.
    """

    def __init__(self, decoder, decoder_output_layer, ctc=None):
        super(ONNXDecoder, self).__init__()
        self.decoder = decoder  # attention decoder module
        self.decoder_output_layer = decoder_output_layer  # Linear projecting D_dec -> VocabSize
        self.ctc = ctc  # optional CTC module; None disables the CTC branch

    def forward(self, encoder_out, encoder_out_lens, hyp_seq, hyp_len):
        """
        encoder_out: (B, T_enc, D_enc)
        encoder_out_lens: (B,)
        hyp_seq: (B, U) - Padded hypothesis sequence IDs (e.g., <sos>, id1, id2, ...)
        hyp_len: (B,) - Actual lengths of hyp_seq
        Returns:
            ctc_logprobs: (B, T_enc, VocabSize) if CTC exists, else an empty (B, T_enc, 0) tensor
            att_logprobs: (B, VocabSize) for the *next* token prediction based on the whole hyp_seq
        """
        # Prepare masks
        batch_size, T_enc = encoder_out.size(0), encoder_out.size(1)
        U = hyp_seq.size(1)

        # src_mask: True where encoder frames are valid (padding frames masked out).
        # NOTE(review): espnet transformer decoders usually take a 3-D (B, 1, T_enc)
        # src_mask — confirm the extra unsqueeze matches this decoder's attention impl.
        src_mask = (~make_pad_mask(encoder_out_lens, T_enc)).unsqueeze(1).unsqueeze(2).to(encoder_out.device) # (B, 1, 1, T_enc)

        # tgt_mask: padding mask over hypothesis tokens (True = valid position).
        tgt_mask = (~make_pad_mask(hyp_len, U)).unsqueeze(1).unsqueeze(2).to(encoder_out.device) # (B, 1, 1, U)
        # Ensure upper triangular part (future info) is masked in tgt_mask.
        # Broadcasting (B, 1, 1, U) & (1, U, U) yields the causal (B, 1, U, U) mask.
        future_mask = torch.triu(torch.ones(U, U, device=tgt_mask.device), diagonal=1).bool()
        tgt_mask = tgt_mask & (~future_mask.unsqueeze(0))

        # Forward through decoder
        dec_out, _ = self.decoder(hyp_seq, tgt_mask, encoder_out, src_mask) # dec_out: (B, U, D_dec)

        # Gather the hidden state of the last real (non-padding) token of each
        # hypothesis; hyp_len - 1 indexes that position per batch item.
        last_token_states = dec_out[torch.arange(batch_size), hyp_len - 1] # (B, D_dec)
        # Apply the output layer to get logits, then log_softmax
        att_logits = self.decoder_output_layer(last_token_states) # (B, VocabSize)
        att_logprobs = att_logits.log_softmax(dim=-1) # (B, VocabSize)

        # CTC branch: when absent, an empty (B, T_enc, 0) tensor keeps the
        # graph's output arity fixed for ONNX.
        ctc_logprobs = self.ctc.log_softmax(encoder_out) if self.ctc else torch.empty(batch_size, T_enc, 0, device=encoder_out.device, dtype=encoder_out.dtype) # (B, T_enc, VocabSize)

        return ctc_logprobs, att_logprobs

def make_pad_mask(lengths, max_len=None):
    """Build a boolean padding mask from per-item sequence lengths.

    Args:
        lengths: (B,) int tensor of valid lengths.
        max_len: mask width; defaults to ``lengths.max()``.

    Returns:
        (B, max_len) bool tensor, True at PADDING positions (index >= length).
    """
    if max_len is None:
        max_len = int(lengths.max())
    positions = torch.arange(max_len, dtype=torch.int64, device=lengths.device)
    # Broadcast (1, max_len) against (B, 1): True wherever the position index
    # falls at or beyond the item's valid length.
    return positions.unsqueeze(0) >= lengths.unsqueeze(1)

def get_nested_attr(obj, attr_path):
    """Resolve a dot-separated attribute path (e.g. 'decoder.output_layer') on obj."""
    target = obj
    remaining = attr_path
    while remaining:
        name, _, remaining = remaining.partition('.')
        target = getattr(target, name)
    return target

def _find_output_layer_path(asr_model, vocab_size):
    """Return the dotted module path of the decoder's vocabulary projection.

    Scans the model for the first ``torch.nn.Linear`` whose ``out_features``
    equals ``vocab_size`` (e.g. ``decoder.output_layer``). Returns an empty
    string when no such layer exists.
    """
    for name, module in asr_model.named_modules():
        if isinstance(module, torch.nn.Linear) and module.out_features == vocab_size:
            print(f"Found potential output layer: {name} -> {module}")
            return name
    return ""


def _export_encoder(asr_model, output_path):
    """Trace the encoder with dummy input and export it to ONNX."""
    encoder_wrapper = ONNXEncoder(asr_model.encoder)
    # (B, T, F) dummy features - Adjust T and F to match the real front-end!
    dummy_speech = torch.randn(1, 200, 80, dtype=torch.float32)
    dummy_speech_lengths = torch.LongTensor([200])  # actual length in frames

    torch.onnx.export(
        encoder_wrapper,
        (dummy_speech, dummy_speech_lengths),
        output_path,
        export_params=True,
        opset_version=14,
        do_constant_folding=True,
        input_names=['speech', 'speech_lengths'],
        output_names=['encoder_out', 'encoder_out_lens'],
        dynamic_axes={
            'speech': {0: 'batch_size', 1: 'time'},
            'speech_lengths': {0: 'batch_size'},
            'encoder_out': {0: 'batch_size', 1: 'time_enc'},
            'encoder_out_lens': {0: 'batch_size'}
        }
    )
    print(f"Encoder exported to {output_path}")


def _export_decoder(asr_model, output_layer_path, output_path):
    """Export the decoder scoring graph (attention + optional CTC) to ONNX."""
    target_layer = get_nested_attr(asr_model, output_layer_path)
    decoder_wrapper = ONNXDecoder(asr_model.decoder, target_layer, getattr(asr_model, 'ctc', None))

    dummy_encoder_out = torch.randn(1, 100, asr_model.encoder.output_size(), dtype=torch.float32)  # (B, T_enc, D_enc)
    dummy_encoder_out_lens = torch.LongTensor([100])
    dummy_hyp_seq = torch.full((1, 5), 0, dtype=torch.long)  # initialize with padding ID (often 0)
    dummy_hyp_seq[0, 0] = asr_model.sos  # <sos>
    dummy_hyp_seq[0, 1] = 100  # some dummy token ID
    dummy_hyp_seq[0, 2] = 200  # another dummy token ID
    dummy_hyp_len = torch.LongTensor([3])  # length of the actual sequence (excluding padding)

    dyn_axes_dec = {
        'encoder_out': {0: 'batch_size', 1: 'time_enc'},
        'encoder_out_lens': {0: 'batch_size'},
        'hyp_seq': {0: 'batch_size', 1: 'time_dec'},
        'hyp_len': {0: 'batch_size'},
        'ctc_logprobs': {0: 'batch_size', 1: 'time_enc'},  # only dynamic if CTC is used
        'att_logprobs': {0: 'batch_size'}
    }

    torch.onnx.export(
        decoder_wrapper,
        (dummy_encoder_out, dummy_encoder_out_lens, dummy_hyp_seq, dummy_hyp_len),
        output_path,
        export_params=True,
        opset_version=14,
        do_constant_folding=True,
        input_names=['encoder_out', 'encoder_out_lens', 'hyp_seq', 'hyp_len'],
        output_names=['ctc_logprobs', 'att_logprobs'],  # outputs for scoring
        dynamic_axes=dyn_axes_dec
    )
    print(f"Decoder (scoring part) exported to {output_path}")


def main():
    """Load an ESPNet ASR model and export its encoder and decoder to ONNX.

    Parses CLI arguments, builds the internal ASR model from the checkpoint,
    locates the decoder's output projection layer, then writes two ONNX
    graphs (encoder, decoder scoring). Returns early with a message when the
    output layer cannot be found.
    """
    args = get_parser().parse_args()

    from omegaconf import OmegaConf
    from espnet2.tasks.asr import ASRTask

    # Build the internal ASR model directly from config + checkpoint.
    # build_model_from_file already loads the weights, so the previous
    # unused torch.load of the raw checkpoint and the unused Speech2Text
    # inference object (which loaded the model a second time) are removed.
    asr_model, *_ = ASRTask.build_model_from_file(
        args.asr_config,
        args.asr_model_file,
        device="cpu"  # or "cuda"
    )
    asr_model.eval()
    print("Loaded internal ASR model successfully.")

    config = OmegaConf.load(args.asr_config)
    vocab_size = len(config.token_list)  # vocab size from the training config
    print(f"Expected vocab size: {vocab_size}")

    decoder_output_layer_path = _find_output_layer_path(asr_model, vocab_size)
    if not decoder_output_layer_path:
        print("Could not find a Linear layer matching the vocab size. Please inspect the model manually.")
        return
    print(f"Using '{decoder_output_layer_path}' as the decoder output layer path.")

    _export_encoder(asr_model, args.output_encoder)
    _export_decoder(asr_model, decoder_output_layer_path, args.output_decoder)


if __name__ == "__main__":
    main()