# export_asr_onnx.py
"""Export an ESPnet2 ASR model to ONNX.

Produces two graphs:
  * an encoder graph: (speech, speech_lengths) -> (encoder_out, encoder_out_lens)
  * a decoder/scoring graph: attention log-probs for the next token of a
    padded hypothesis, plus CTC log-probs when a CTC head is attached.
"""
import argparse

import numpy as np
import soundfile as sf
import torch
from espnet2.bin.asr_inference import Speech2Text


def get_parser():
    """Build the CLI argument parser for the export script."""
    parser = argparse.ArgumentParser(description='Export ESPNet ASR model to ONNX')
    parser.add_argument('--asr_model_file', type=str, required=True,
                        help='Path to the trained ASR model .pth file')
    parser.add_argument('--asr_config', type=str, required=True,
                        help='Path to the ASR model config.yaml file')
    parser.add_argument('--output_encoder', type=str, default='asr_encoder.onnx',
                        help='Output ONNX file for encoder')
    parser.add_argument('--output_decoder', type=str, default='asr_decoder.onnx',
                        help='Output ONNX file for decoder')
    return parser


def make_pad_mask(lengths, max_len=None):
    """Create a boolean padding mask.

    Args:
        lengths: (B,) integer tensor of valid sequence lengths.
        max_len: padded sequence length; defaults to ``lengths.max()``.

    Returns:
        (B, max_len) bool tensor that is True at PADDED positions.
    """
    batch_size = lengths.size(0)
    if max_len is None:
        max_len = lengths.max().item()
    seq_range = torch.arange(0, max_len, dtype=torch.int64, device=lengths.device)
    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
    seq_length_expand = lengths.unsqueeze(-1).expand_as(seq_range_expand)
    return seq_range_expand >= seq_length_expand


class ONNXEncoder(torch.nn.Module):
    """Wrapper that pins the encoder's ONNX signature to exactly two outputs."""

    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder

    def forward(self, speech, speech_lengths):
        # ESPnet encoders may return extra values beyond the first two
        # (e.g. intermediate/cache outputs); keep only
        # (encoder_out, encoder_out_lens) so the ONNX graph is stable.
        encoder_outputs = self.encoder(speech, speech_lengths)
        return encoder_outputs[0], encoder_outputs[1]


class ONNXDecoder(torch.nn.Module):
    """Wrapper exposing the attention-decoder scoring pass (plus optional CTC)."""

    def __init__(self, decoder, decoder_output_layer, ctc=None):
        super().__init__()
        self.decoder = decoder
        self.decoder_output_layer = decoder_output_layer
        self.ctc = ctc

    def forward(self, encoder_out, encoder_out_lens, hyp_seq, hyp_len):
        """Score a batch of partial hypotheses.

        Args:
            encoder_out: (B, T_enc, D_enc) encoder states.
            encoder_out_lens: (B,) valid encoder frame counts.
            hyp_seq: (B, U) padded hypothesis token IDs (starting with <sos>).
            hyp_len: (B,) valid lengths of ``hyp_seq``.

        Returns:
            ctc_logprobs: (B, T_enc, V) CTC log-probs, or an empty
                (B, T_enc, 0) tensor when no CTC head is attached.
            att_logprobs: (B, V) attention log-probs for the NEXT token,
                taken at the last valid position of each hypothesis.
        """
        batch_size, T_enc = encoder_out.size(0), encoder_out.size(1)
        U = hyp_seq.size(1)

        # Mask padded encoder frames: shape (B, 1, 1, T_enc).
        src_mask = (~make_pad_mask(encoder_out_lens, T_enc)).unsqueeze(1).unsqueeze(2).to(encoder_out.device)

        # Padding mask over the hypothesis: (B, 1, 1, U) before broadcasting.
        tgt_mask = (~make_pad_mask(hyp_len, U)).unsqueeze(1).unsqueeze(2).to(encoder_out.device)
        # Forbid attention to future positions; broadcasting the causal mask
        # against the padding mask yields (B, 1, U, U).
        future_mask = torch.triu(torch.ones(U, U, device=tgt_mask.device), diagonal=1).bool()
        tgt_mask = tgt_mask & (~future_mask.unsqueeze(0))
        # NOTE(review): ESPnet's TransformerDecoder conventionally takes a
        # (B, U, U) tgt_mask — confirm the extra head dim is accepted here.

        dec_out, _ = self.decoder(hyp_seq, tgt_mask, encoder_out, src_mask)  # (B, U, D_dec)

        # Hidden state at the last valid token of each hypothesis.
        last_token_states = dec_out[torch.arange(batch_size), hyp_len - 1]  # (B, D_dec)
        att_logits = self.decoder_output_layer(last_token_states)  # (B, V)
        att_logprobs = att_logits.log_softmax(dim=-1)

        # Explicit `is not None` rather than module truthiness.
        if self.ctc is not None:
            ctc_logprobs = self.ctc.log_softmax(encoder_out)  # (B, T_enc, V)
        else:
            # Empty vocab dim keeps the output count fixed for ONNX.
            ctc_logprobs = torch.empty(batch_size, T_enc, 0,
                                       device=encoder_out.device,
                                       dtype=encoder_out.dtype)
        return ctc_logprobs, att_logprobs


def get_nested_attr(obj, attr_path):
    """Resolve a dot-separated attribute path, e.g. 'decoder.output_layer'."""
    current = obj
    for attr in attr_path.split('.'):
        current = getattr(current, attr)
    return current


def main():
    args = get_parser().parse_args()

    from omegaconf import OmegaConf
    from espnet2.tasks.asr import ASRTask

    # Build the model once through the task API — this also loads the
    # checkpoint weights, so the previous extra torch.load() and the unused
    # Speech2Text construction (each loading the model again) are dropped.
    config = OmegaConf.load(args.asr_config)
    asr_model, *_ = ASRTask.build_model_from_file(
        args.asr_config, args.asr_model_file, device="cpu"
    )
    asr_model.eval()
    print("Loaded internal ASR model successfully.")

    # --- Find the decoder output layer by matching the vocab size ---
    vocab_size = len(config.token_list)
    print(f"Expected vocab size: {vocab_size}")
    decoder_output_layer_path = ""  # e.g. 'decoder.output_layer'
    for name, module in asr_model.named_modules():
        if isinstance(module, torch.nn.Linear) and module.out_features == vocab_size:
            print(f"Found potential output layer: {name} -> {module}")
            decoder_output_layer_path = name
            break
    if not decoder_output_layer_path:
        print("Could not find a Linear layer matching the vocab size. "
              "Please inspect the model manually.")
        return
    print(f"Using '{decoder_output_layer_path}' as the decoder output layer path.")

    # --- Export encoder ---
    encoder_wrapper = ONNXEncoder(asr_model.encoder)
    # (B, T, F) — adjust T/F to match the model's front-end feature size.
    dummy_speech = torch.randn(1, 200, 80, dtype=torch.float32)
    dummy_speech_lengths = torch.LongTensor([200])
    torch.onnx.export(
        encoder_wrapper,
        (dummy_speech, dummy_speech_lengths),
        args.output_encoder,
        export_params=True,
        opset_version=14,
        do_constant_folding=True,
        input_names=['speech', 'speech_lengths'],
        output_names=['encoder_out', 'encoder_out_lens'],
        dynamic_axes={
            'speech': {0: 'batch_size', 1: 'time'},
            'speech_lengths': {0: 'batch_size'},
            'encoder_out': {0: 'batch_size', 1: 'time_enc'},
            'encoder_out_lens': {0: 'batch_size'},
        },
    )
    print(f"Encoder exported to {args.output_encoder}")

    # --- Export decoder (scoring part) ---
    target_layer = get_nested_attr(asr_model, decoder_output_layer_path)
    decoder_wrapper = ONNXDecoder(asr_model.decoder, target_layer,
                                  getattr(asr_model, 'ctc', None))
    dummy_encoder_out = torch.randn(1, 100, asr_model.encoder.output_size(),
                                    dtype=torch.float32)  # (B, T_enc, D_enc)
    dummy_encoder_out_lens = torch.LongTensor([100])

    # Dummy hypothesis: <sos> + two tokens, rest padding (pad ID assumed 0).
    # Position 1 is now filled so the sequence is consistent with hyp_len=3
    # (it was previously left as padding inside the valid region).
    dummy_hyp_seq = torch.full((1, 5), 0, dtype=torch.long)
    dummy_hyp_seq[0, 0] = asr_model.sos
    dummy_hyp_seq[0, 1] = 100  # arbitrary dummy token ID
    dummy_hyp_seq[0, 2] = 200  # arbitrary dummy token ID
    dummy_hyp_len = torch.LongTensor([3])

    torch.onnx.export(
        decoder_wrapper,
        (dummy_encoder_out, dummy_encoder_out_lens, dummy_hyp_seq, dummy_hyp_len),
        args.output_decoder,
        export_params=True,
        opset_version=14,
        do_constant_folding=True,
        input_names=['encoder_out', 'encoder_out_lens', 'hyp_seq', 'hyp_len'],
        output_names=['ctc_logprobs', 'att_logprobs'],
        dynamic_axes={
            'encoder_out': {0: 'batch_size', 1: 'time_enc'},
            'encoder_out_lens': {0: 'batch_size'},
            'hyp_seq': {0: 'batch_size', 1: 'time_dec'},
            'hyp_len': {0: 'batch_size'},
            'ctc_logprobs': {0: 'batch_size', 1: 'time_enc'},  # only dynamic if CTC is used
            'att_logprobs': {0: 'batch_size'},
        },
    )
    print(f"Decoder (scoring part) exported to {args.output_decoder}")


if __name__ == "__main__":
    main()