export_lm_onnx.py 2.7 KB
Newer Older
sunzhq2's avatar
sunzhq2 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
# export_lm_onnx.py
import torch
from espnet2.tasks.lm import LMTask
from omegaconf import OmegaConf
import argparse

def get_parser():
    """Build the command-line parser for the LM -> ONNX exporter.

    Returns:
        argparse.ArgumentParser: parser defining ``--lm_file`` and
        ``--lm_config`` (both required strings) and ``--output``
        (string, defaults to ``lm.onnx``).
    """
    cli = argparse.ArgumentParser(description='Export ESPNet LM model to ONNX')
    # (flag, extra keyword arguments) -- every option is a plain string.
    options = (
        ('--lm_file', dict(required=True, help='Path to the trained LM .pth file')),
        ('--lm_config', dict(required=True, help='Path to the LM config.yaml file')),
        ('--output', dict(default='lm.onnx', help='Output ONNX file for LM')),
    )
    for flag, extra in options:
        cli.add_argument(flag, type=str, **extra)
    return cli

class ONNXLMLayer(torch.nn.Module):
    """Expose an ESPnet LM as a plain ``(input_ids, lengths) -> logits`` module for ONNX export.

    The wrapped network is resolved as follows:

    * ``lm_model.predictor`` if that attribute exists -- called as
      ``predictor(input_ids, lengths)`` (kept for models exposing it);
    * otherwise ``lm_model.lm`` -- the attribute a stock ESPnet2
      ``ESPnetLanguageModel`` uses. Its ``forward(input, hidden)`` takes a
      hidden state rather than lengths, so it is called with ``hidden=None``
      (a fresh sequence), the same way ESPnet itself scores sequences.

    Args:
        lm_model: the model returned by ``LMTask.build_model_from_file``.

    Raises:
        AttributeError: if ``lm_model`` has neither ``predictor`` nor ``lm``.
    """

    def __init__(self, lm_model):
        super().__init__()
        network = getattr(lm_model, "predictor", None)
        # Stock ESPnet2 LMs take (tokens, hidden) instead of (tokens, lengths).
        self._takes_hidden = network is None
        if network is None:
            network = getattr(lm_model, "lm", None)
        if network is None:
            raise AttributeError(
                f"{type(lm_model).__name__} has neither a 'predictor' nor an 'lm' "
                "attribute; print(model) to find the LM network to export"
            )
        self.predictor = network

    def forward(self, input_ids, lengths):
        """Return next-token logits for every position of ``input_ids``.

        Args:
            input_ids: token IDs, presumably ``(B, T)`` -- TODO confirm for custom models.
            lengths: per-sequence lengths, presumably ``(B,)``. Only forwarded to a
                ``predictor``-style network; unused on the ``lm(tokens, hidden)`` path.

        Returns:
            The first element of the wrapped network's output tuple, typically
            logits of shape ``(B, T, vocab_size)``.
        """
        # Plain Python branch: resolved once at trace time, so the exported graph
        # contains only the selected call.
        if self._takes_hidden:
            logits, _ = self.predictor(input_ids, None)
        else:
            logits, _ = self.predictor(input_ids, lengths)
        return logits

def _detect_vocab_size(lm_net):
    """Best-effort lookup of the LM vocabulary size from its layers.

    Uses the first ``torch.nn.Embedding`` found (``num_embeddings`` -- the number
    of valid input token IDs), falling back to an ``output_layer`` Linear's
    ``out_features``.

    Raises:
        ValueError: if neither layer is found.
    """
    for layer in lm_net.modules():
        if isinstance(layer, torch.nn.Embedding):
            return layer.num_embeddings
    output_layer = getattr(lm_net, "output_layer", None)
    if isinstance(output_layer, torch.nn.Linear):
        return output_layer.out_features
    raise ValueError(
        "Could not determine the LM vocabulary size; "
        "inspect the model with print(model)"
    )


def main():
    """Load a trained ESPnet LM and export it to ONNX.

    Parses ``--lm_config``/``--lm_file``/``--output`` (see ``get_parser``),
    builds the model on CPU with ``LMTask.build_model_from_file``, wraps it in
    ``ONNXLMLayer`` and exports it with ``torch.onnx.export`` using a dummy
    ``(1, 10)`` token batch. Batch and sequence axes are exported as dynamic.
    """
    args = get_parser().parse_args()

    # build_model_from_file parses the config itself; no separate load is needed.
    model, *_ = LMTask.build_model_from_file(
        args.lm_config,
        args.lm_file,
        device="cpu",  # the dummy inputs below are CPU tensors, so keep the model on CPU
    )
    model.eval()

    predictor_wrapper = ONNXLMLayer(model)

    # Dummy token IDs must be valid indices for the wrapped network's embedding.
    vocab_size = _detect_vocab_size(predictor_wrapper.predictor)
    print(f"Detected LM vocab size: {vocab_size}")

    batch_size, seq_len = 1, 10
    dummy_input_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len))
    dummy_lengths = torch.full((batch_size,), seq_len, dtype=torch.long)

    # NOTE(review): if the wrapped network ignores `lengths`, the exporter may
    # prune that input from the graph -- confirm in the exported model.
    torch.onnx.export(
        predictor_wrapper,
        (dummy_input_ids, dummy_lengths),
        args.output,
        export_params=True,
        opset_version=14,
        do_constant_folding=True,
        input_names=['input_ids', 'lengths'],
        output_names=['logits'],
        # The vocabulary dimension is fixed by the model weights, so only the
        # batch and sequence axes are declared dynamic.
        dynamic_axes={
            'input_ids': {0: 'batch_size', 1: 'sequence_length'},
            'lengths': {0: 'batch_size'},
            'logits': {0: 'batch_size', 1: 'sequence_length'},
        },
    )
    print(f"LM exported to {args.output}")


# Run the exporter only when executed as a script, not when imported.
if __name__ == "__main__":
    main()