download_whisper.py 3.61 KB
Newer Older
Attila Dusnoki's avatar
Attila Dusnoki committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
#####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#####################################################################################

from optimum.exporters.onnx import main_export
from optimum.exporters.onnx.model_configs import WhisperOnnxConfig
from transformers import AutoConfig

from optimum.exporters.onnx.base import ConfigBehavior
from typing import Dict


class CustomWhisperOnnxConfig(WhisperOnnxConfig):
    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        common_inputs = {}

        if self._behavior is ConfigBehavior.ENCODER:
            common_inputs["input_features"] = {
                0: "batch_size",
                1: "feature_size",
                2: "encoder_sequence_length"
            }

        if self._behavior is ConfigBehavior.DECODER:
            common_inputs["decoder_input_ids"] = {
                0: "batch_size",
                1: "decoder_sequence_length"
            }
            common_inputs["decoder_attention_mask"] = {
                0: "batch_size",
                1: "decoder_sequence_length"
            }
            common_inputs["encoder_outputs"] = {
                0: "batch_size",
                1: "encoder_sequence_length"
            }

        return common_inputs

    @property
    def torch_to_onnx_input_map(self) -> Dict[str, str]:
        if self._behavior is ConfigBehavior.DECODER:
            return {
                "decoder_input_ids": "input_ids",
                "decoder_attention_mask": "attention_mask",
                "encoder_outputs": "encoder_hidden_states",
            }
        return {}


def export():
    model_id = "openai/whisper-tiny.en"
    config = AutoConfig.from_pretrained(model_id)

    custom_whisper_onnx_config = CustomWhisperOnnxConfig(
        config=config,
        task="automatic-speech-recognition",
    )

    encoder_config = custom_whisper_onnx_config.with_behavior("encoder")
    decoder_config = custom_whisper_onnx_config.with_behavior("decoder",
                                                              use_past=False)

    custom_onnx_configs = {
        "encoder_model": encoder_config,
        "decoder_model": decoder_config,
    }

    output = "models/whisper-tiny.en_modified"
    main_export(model_id,
                output=output,
                no_post_process=True,
                do_validation=False,
                custom_onnx_configs=custom_onnx_configs)

    print(f"Done. Check {output}")


if __name__ == "__main__":
    export()