# export_lm_onnx.py
"""Export a trained ESPNet language model (LM) to ONNX format.

Loads an LM checkpoint via ESPNet's ``LMTask`` interface, wraps its core
``predictor`` module in an ONNX-friendly ``nn.Module``, and exports it with
dynamic batch/sequence axes.
"""

import argparse

import torch
from espnet2.tasks.lm import LMTask
from omegaconf import OmegaConf


def get_parser():
    """Build the command-line argument parser for the export script."""
    parser = argparse.ArgumentParser(description='Export ESPNet LM model to ONNX')
    parser.add_argument('--lm_file', type=str, required=True,
                        help='Path to the trained LM .pth file')
    parser.add_argument('--lm_config', type=str, required=True,
                        help='Path to the LM config.yaml file')
    parser.add_argument('--output', type=str, default='lm.onnx',
                        help='Output ONNX file for LM')
    return parser


class ONNXLMLayer(torch.nn.Module):
    """Thin wrapper exposing the LM predictor with an ONNX-exportable forward.

    The core LM network is usually the ``predictor`` attribute (the
    Transformer LM part). If your model differs, inspect it with
    ``model.children()`` / ``model.modules()``.
    """

    def __init__(self, lm_model):
        super().__init__()
        self.predictor = lm_model.predictor  # core Transformer LM module

    def forward(self, input_ids, lengths):
        """Run the LM and return next-token logits for all positions.

        Args:
            input_ids: (B, T) tensor of token IDs.
            lengths: (B,) tensor of actual sequence lengths.

        Returns:
            Logits tensor of shape (B, T, vocab_size).
        """
        y, _ = self.predictor(input_ids, lengths)
        return y


def _detect_vocab_size(predictor):
    """Best-effort detection of the LM vocabulary size.

    Prefers ``embed.num_embeddings`` (an ``nn.Embedding``'s vocab size).
    Note: the previous heuristic read ``embed.out_features``, but for a
    Linear-style embed that is the embedding *dimension*, not the vocab
    size; it is kept only as a legacy fallback. Finally falls back to the
    output projection's ``out_features``.

    Raises:
        AttributeError: if none of the known attributes are present.
    """
    embed = getattr(predictor, 'embed', None)
    if embed is not None:
        if hasattr(embed, 'num_embeddings'):
            # nn.Embedding: num_embeddings is the vocabulary size.
            return embed.num_embeddings
        if hasattr(embed, 'out_features'):
            # Legacy fallback kept for backward compatibility.
            return embed.out_features
    return predictor.output_layer.out_features


def main():
    """Load the trained LM, wrap its predictor, and export it to ONNX."""
    parser = get_parser()
    args = parser.parse_args()

    # Build model directly using ESPNet's task interface. The config file is
    # parsed internally by build_model_from_file, so no separate OmegaConf
    # load is needed here.
    model, *_ = LMTask.build_model_from_file(
        args.lm_config,
        args.lm_file,
        device="cpu",  # or "cuda"
    )
    model.eval()

    # Wrap the core LM predictor module for export.
    predictor_wrapper = ONNXLMLayer(model)

    vocab_size = _detect_vocab_size(model.predictor)
    print(f"Detected LM vocab size: {vocab_size}")

    # Dummy inputs: batch=1, sequence length=10.
    # input_ids: (B, T) token IDs; lengths: (B,) actual lengths.
    dummy_input_ids = torch.randint(low=0, high=vocab_size, size=(1, 10))
    dummy_lengths = torch.LongTensor([10])

    torch.onnx.export(
        predictor_wrapper,
        (dummy_input_ids, dummy_lengths),
        args.output,
        export_params=True,
        opset_version=14,
        do_constant_folding=True,
        input_names=['input_ids', 'lengths'],
        output_names=['logits'],
        dynamic_axes={
            'input_ids': {0: 'batch_size', 1: 'sequence_length'},
            'lengths': {0: 'batch_size'},
            'logits': {0: 'batch_size', 1: 'sequence_length', 2: 'vocab_size'},
        },
    )
    print(f"LM exported to {args.output}")


if __name__ == "__main__":
    main()