Unverified commit 1ccc033c authored by Jay Zhang, committed by GitHub

Update the example of exporting Bart + BeamSearch to ONNX to resolve review comments. (#14310)



* Update code to resolve comments left in the previous PR.

* Add README.md file for this example.

* Update examples/onnx/pytorch/translation/README.md
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update examples/onnx/pytorch/translation/README.md
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update examples/onnx/pytorch/translation/README.md
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update README.md file to resolve comments.

* Add a section name.

* Update examples/onnx/pytorch/translation/README.md
Co-authored-by: Gary Miguel <garymm@garymm.org>

* Add more comments for _convert_past_list_to_tuple().

* Change the default file name to a consistent one.

* Fix a format issue.

* Update examples/onnx/pytorch/translation/README.md
Co-authored-by: Gary Miguel <garymm@garymm.org>

* Update examples/onnx/pytorch/translation/run_onnx_exporter.py
Co-authored-by: Gary Miguel <garymm@garymm.org>

* Update examples/onnx/pytorch/translation/README.md
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Change the folder to summarization and address some other comments.

* Update the torch version.
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
Co-authored-by: Gary Miguel <garymm@garymm.org>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
parent 6cdc3a78
<!---
Copyright 2021 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
# Bart + Beam Search to ONNX
This folder contains an example of exporting Bart + Beam Search generation (`BartForConditionalGeneration`) to ONNX.
Beam search involves a loop-based workflow, so it has to be made TorchScript-compatible before it can be exported to ONNX. This example shows how to make a Bart model TorchScript-compatible by wrapping it in a new module, as sketched below. In addition, some changes were made to the `beam_search()` function to make it TorchScript-compatible.
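The wrapping idea, in a minimal sketch (the class name and details here are illustrative, not the exact code used in this example):

```python
import torch


class EncoderForExport(torch.nn.Module):  # hypothetical wrapper name
    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder

    def forward(self, input_ids, attention_mask):
        # TorchScript cannot handle the ModelOutput mapping that transformers
        # modules return, so the wrapper unpacks it into plain tensors.
        output = self.encoder(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
        return output.last_hidden_state
```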
## How to run the example
To make sure you can successfully run the latest versions of the example scripts, you have to **install the library from source** and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
```bash
git clone https://github.com/huggingface/transformers
cd transformers
pip install .
```
Then cd into this example folder and run
```bash
pip install -r requirements.txt
```
Now you can run the example command below to get the example ONNX file:
```bash
python run_onnx_exporter.py --model_name_or_path facebook/bart-base
```
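Once the export finishes, the resulting file can be loaded with ONNX Runtime. A usage sketch, assuming the default output name `BART.onnx` and the input/output names used by `run_onnx_exporter.py`:

```python
import numpy as np
import onnxruntime
from transformers import BartTokenizer

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
inputs = tokenizer(["My friends are cool but they eat too many carbs."], return_tensors="np")

sess = onnxruntime.InferenceSession("BART.onnx")
output_ids = sess.run(
    None,
    {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "num_beams": np.array(4),
        "max_length": np.array(5),
        "decoder_start_token_id": np.array(2),  # facebook/bart-base uses 2 here
    },
)[0]
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))
```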
 import copy
+import itertools
 from typing import List, Optional, Tuple

 import torch
@@ -8,23 +9,23 @@ from transformers import BartConfig
 from transformers.generation_utils import GenerationMixin
-def flatten_list(past):
-    values = []
-    if past is not None:
-        for i, p in enumerate(past):
-            for j, q in enumerate(p):
-                values.append(q)
-    return values
-
-
-def list_to_tuple(past):
+def _convert_past_list_to_tuple(past_key_values):
+    """
+    In Bart model, the type of past_key_values is tuple(tuple(torch.FloatTensor)) which is not
+    TorchScript-compatible. To support this, we have to convert it during the export process.
+    This function will convert past values from a list to tuple(tuple(torch.FloatTensor)) for
+    the inner decoder.
+
+    According to the definition of past_key_values, each inner tuple(torch.FloatTensor) has 4 tensors,
+    so we convert every 4 elements in the list as a tuple(torch.FloatTensor).
+    """
+    count_of_each_inner_tuple = 4
     results = ()
     temp_result = ()
-    count_n = len(past) // 4
+    count_n = len(past_key_values) // count_of_each_inner_tuple
     for idx in range(count_n):
-        real_idx = idx * 4
-        temp_result = tuple(past[real_idx : real_idx + 4])
+        real_idx = idx * count_of_each_inner_tuple
+        temp_result = tuple(past_key_values[real_idx : real_idx + count_of_each_inner_tuple])
         results += ((temp_result),)
     return results
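For intuition, a toy illustration (hypothetical tensor shapes) of what `_convert_past_list_to_tuple()` produces for a two-layer decoder:

```python
import torch

# Each decoder layer contributes 4 cached tensors (self-attention key/value
# plus cross-attention key/value), so 8 flat tensors become 2 inner tuples.
flat_past = [torch.zeros(1, 12, 3, 64) for _ in range(8)]
past_key_values = _convert_past_list_to_tuple(flat_past)
assert len(past_key_values) == 2
assert len(past_key_values[0]) == 4
```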
@@ -51,7 +52,7 @@ class DecoderForONNX(torch.nn.Module):
     def forward(self, input_ids, encoder_state, attention_mask, past=None):
         all_results = None
         if past is not None:
-            all_results = list_to_tuple(past)
+            all_results = _convert_past_list_to_tuple(past)
             input_ids = input_ids[:, -1:]
         last_hidden_state, past_key_values = self.decoder(
@@ -68,28 +69,33 @@ class DecoderForONNX(torch.nn.Module):
         return last_hidden_state, past_values
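The `input_ids[:, -1:]` slice above is the standard incremental-decoding shortcut: once a past is supplied, only the newest token has to run through the decoder, since everything earlier comes from the cache. In isolation:

```python
import torch

input_ids = torch.tensor([[2, 8, 15, 7]])  # tokens decoded so far (toy values)
last_token = input_ids[:, -1:]             # shape (batch, 1), fed to the cached decoder
```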
-def create_traced_encoder(encoder, input_ids, attention_mask):
+def _create_traced_encoder(encoder, input_ids, attention_mask):
     encoder_c = copy.deepcopy(encoder)
     encoder_for_onnx = EncoderForONNX(encoder_c)
-    # return torch.jit.trace(encoder, (input_ids, attention_mask))
     return torch.jit.trace(encoder_for_onnx, (input_ids, attention_mask))


-def create_traced_decoder(decoder, input_ids, encoder_state, attention_mask, past=None):
+def _create_traced_decoder(decoder, input_ids, encoder_state, attention_mask, past=None):
     decoder_c = copy.deepcopy(decoder)
     decoder_for_onnx = DecoderForONNX(decoder_c)
-    past_values = flatten_list(past)
+    past_values = list(itertools.chain.from_iterable(past or ()))
     # Do this twice so we got 2 different decoders for further work.
-    if past_values is None or len(past_values) == 0:
-        return torch.jit.trace(decoder_for_onnx, (input_ids, encoder_state, attention_mask))
-    else:
+    if past_values:
         return torch.jit.trace(decoder_for_onnx, (input_ids, encoder_state, attention_mask, past_values))
+    else:
+        return torch.jit.trace(decoder_for_onnx, (input_ids, encoder_state, attention_mask))
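Two separate traces are needed because `torch.jit.trace` records exactly one execution path through the code, so the first decoding step (no past) and all later steps (with past) each get their own traced graph. A toy demonstration of that property (not this example's actual modules):

```python
import torch


def step(x, past=None):
    # Tracing records only the branch taken for the given example inputs.
    return x + past.sum() if past is not None else x


traced_no_past = torch.jit.trace(step, (torch.ones(2),))
traced_with_past = torch.jit.trace(step, (torch.ones(2), torch.ones(2)))
```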
 class BartConfigTS(BartConfig, torch.nn.Module):
-    def init_module(self):
+    """
+    BartConfigTS is a TorchScript-compatible transformers.models.bart.configuration_bart.BartConfig.
+    TorchScript only supports sub-classes of torch.nn.Module.
+    """
+
+    def __init__(self, config):
+        BartConfig.__init__(self, config)
         torch.nn.Module.__init__(self)
@@ -127,7 +133,6 @@ class BARTGenerator(torch.nn.Module, GenerationMixin):
     def __init__(self, model):
         super().__init__()
         self.config = BartConfigTS(model.config)
-        self.config.init_module()
         self.config.force_bos_token_to_be_generated = False
         self._trace_modules(model)
         self.logits_processor = MinLengthLogitsProcessorTS(self.config.min_length, self.config.eos_token_id)
@@ -136,7 +141,6 @@ class BARTGenerator(torch.nn.Module, GenerationMixin):
         self.decoder_layers = model.config.decoder_layers

     def _trace_modules(self, model):
-        # Be aware of the last one 2 should be kept.
         input_ids = torch.tensor(
             [
                 [
@@ -200,89 +204,25 @@ class BARTGenerator(torch.nn.Module, GenerationMixin):
                     57,
                     8629,
                     5,
-                    2,
+                    model.config.eos_token_id,
                 ]
             ],
             device=model.device,
             dtype=torch.long,
         )
         attention_mask = torch.tensor(
-            [
-                [
-                    True,
-                    True,
(… 59 more identical "True," lines elided …)
-                ]
-            ],
+            [[True] * input_ids.shape[-1]],
             device=model.device,
             dtype=torch.bool,
         )
-        self.encoder = create_traced_encoder(model.get_encoder(), input_ids, attention_mask)
+        self.encoder = _create_traced_encoder(model.get_encoder(), input_ids, attention_mask)
         encoder_outputs = model.get_encoder()(input_ids, attention_mask=attention_mask, return_dict=True)
         decoder = model.model.decoder
         decoder_outputs = decoder(input_ids, attention_mask, encoder_outputs["last_hidden_state"], None, None, None)
-        self.decoder_no_past = create_traced_decoder(
+        self.decoder_no_past = _create_traced_decoder(
             model.model.decoder, input_ids, encoder_outputs["last_hidden_state"], attention_mask
         )
-        self.decoder_with_past = create_traced_decoder(
+        self.decoder_with_past = _create_traced_decoder(
             model.model.decoder, input_ids, encoder_outputs["last_hidden_state"], attention_mask, decoder_outputs[1]
         )
@@ -414,8 +354,8 @@ class BeamSearchScorerTS(torch.nn.Module):
         self._beam_hyps_count = torch.zeros(self.batch_size, dtype=torch.long)
         self._beam_hyps_worst_scores = torch.zeros(self.batch_size) + 1e9
         self._beam_hyps_max_length: int = self.max_length - 1
-        self._beam_hyps: List[torch.Tensor] = [torch.zeros(2)]  # placeholder for TorchScript compatible
-        self._beam_scores: List[torch.Tensor] = [torch.zeros(2)]  # placeholder for TorchScript compatible
+        self._beam_hyps: List[torch.Tensor] = [torch.zeros(2)]  # placeholder for TorchScript compatibility
+        self._beam_scores: List[torch.Tensor] = [torch.zeros(2)]  # placeholder for TorchScript compatibility

     def is_done(self) -> torch.Tensor:
         return self._done.all()
@@ -474,11 +414,11 @@ class BeamSearchScorerTS(torch.nn.Module):
         score = sum_logprobs / (hyp.shape[-1] ** self.length_penalty)
         hyps_count = self.hypo_len(hypo_idx)
         if hyps_count < self.num_beams or score > self._beam_hyps_worst_scores[hypo_idx]:
-            # NOTE: work around difference of torch.sum(empty_tensor) = 0, while error in onnx.
-            # Bug: https://msdata.visualstudio.com/Vienna/_workitems/edit/1486599
+            # NOTE: work around difference of torch.sum(empty_tensor) == 0, while error in onnx.
             beam_idx = (
                 torch.sum(self._beam_hyps_count[:hypo_idx]) if hypo_idx != 0 else torch.tensor(0, dtype=torch.long)
             )
-            # beam_idx = torch.sum(_beam_hyps_count[:hypo_idx])
             self._beam_scores.insert(beam_idx, torch.tensor([score]))
             self._beam_hyps.insert(beam_idx, hyp)
             if hyps_count + 1 > self.num_beams:
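The guarded sum in that hunk, isolated with toy values: summing an empty slice is fine in PyTorch (it yields 0), but the exported ONNX `ReduceSum` errored on it, hence the special case for index 0:

```python
import torch

beam_hyps_count = torch.tensor([2, 3, 1])  # toy per-batch hypothesis counts
hypo_idx = 0
beam_idx = (
    torch.sum(beam_hyps_count[:hypo_idx]) if hypo_idx != 0 else torch.tensor(0, dtype=torch.long)
)
```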
@@ -605,7 +545,7 @@ class BeamSearchScorerTS(torch.nn.Module):
                 self.hypo_add(final_tokens, final_score, batch_idx)

         # select the best hypotheses
-        # NOTE: new is not scriptable
+        # NOTE: torch.Tensor.new_zeros() is not scriptable
         sent_lengths = torch.zeros(batch_size * self.num_beam_hyps_to_keep, dtype=torch.long)
         best = []
         best_scores = torch.zeros(
@@ -782,7 +722,6 @@ class BARTBeamSearchGenerator(BARTGenerator):
             bos_token_id=bos_token_id,
         )
-        # from generation_utils.py
         batch_size = input_ids.shape[0]
         length_penalty = self.config.length_penalty
......
"""
Code to remove duplicate initializers to reduce ONNX model size.
"""
import os import os
import numpy import numpy
...@@ -5,7 +9,7 @@ import numpy ...@@ -5,7 +9,7 @@ import numpy
import onnx import onnx
def is_equal_tensor_proto(a, b): def _is_equal_tensor_proto(a, b):
name_a = a.name name_a = a.name
name_b = b.name name_b = b.name
...@@ -20,25 +24,25 @@ def is_equal_tensor_proto(a, b): ...@@ -20,25 +24,25 @@ def is_equal_tensor_proto(a, b):
return res return res
def node_replace_input_with(node_proto, name, new_name): def _node_replace_input_with(node_proto, name, new_name):
for i, input_name in enumerate(node_proto.input): for i, input_name in enumerate(node_proto.input):
if input_name == name: if input_name == name:
node_proto.input.insert(i, new_name) node_proto.input.insert(i, new_name)
node_proto.input.pop(i + 1) node_proto.input.pop(i + 1)
if node_proto.op_type == "If": if node_proto.op_type == "If":
graph_replace_input_with(node_proto.attribute[0].g, name, new_name) _graph_replace_input_with(node_proto.attribute[0].g, name, new_name)
graph_replace_input_with(node_proto.attribute[1].g, name, new_name) _graph_replace_input_with(node_proto.attribute[1].g, name, new_name)
if node_proto.op_type == "Loop": if node_proto.op_type == "Loop":
graph_replace_input_with(node_proto.attribute[0].g, name, new_name) _graph_replace_input_with(node_proto.attribute[0].g, name, new_name)
def graph_replace_input_with(graph_proto, name, new_name): def _graph_replace_input_with(graph_proto, name, new_name):
for n in graph_proto.node: for n in graph_proto.node:
node_replace_input_with(n, name, new_name) _node_replace_input_with(n, name, new_name)
def remove_dup_initializers_from_model(model, model_without_ext, ind_to_replace): def _remove_dup_initializers_from_model(model, model_without_ext, ind_to_replace):
inits_with_data = [i for i in model.graph.initializer] inits_with_data = [i for i in model.graph.initializer]
inits = [i for i in model_without_ext.graph.initializer] inits = [i for i in model_without_ext.graph.initializer]
for i, ref_i in ind_to_replace: for i, ref_i in ind_to_replace:
...@@ -52,10 +56,15 @@ def remove_dup_initializers_from_model(model, model_without_ext, ind_to_replace) ...@@ -52,10 +56,15 @@ def remove_dup_initializers_from_model(model, model_without_ext, ind_to_replace)
model_without_ext.graph.initializer.remove(inits[i]) model_without_ext.graph.initializer.remove(inits[i])
# for n in model.graph.node: # for n in model.graph.node:
graph_replace_input_with(model_without_ext.graph, name_i, name_ref) _graph_replace_input_with(model_without_ext.graph, name_i, name_ref)
def remove_dup_initializers(onnx_file_path): def remove_dup_initializers(onnx_file_path):
"""
Removes duplicate initializers from the model to reduce its size.
Writes a new file in the same directory as onnx_file_path and returns the path to that file.
"""
model_file_folder = os.path.dirname(onnx_file_path) model_file_folder = os.path.dirname(onnx_file_path)
model_file_name = os.path.basename(onnx_file_path) model_file_name = os.path.basename(onnx_file_path)
...@@ -76,7 +85,7 @@ def remove_dup_initializers(onnx_file_path): ...@@ -76,7 +85,7 @@ def remove_dup_initializers(onnx_file_path):
for j in range(i + 1, len(inits)): for j in range(i + 1, len(inits)):
if j in dup_set: if j in dup_set:
continue continue
if is_equal_tensor_proto(inits[i], inits[j]): if _is_equal_tensor_proto(inits[i], inits[j]):
dup_set.add(i) dup_set.add(i)
dup_set.add(j) dup_set.add(j)
...@@ -103,8 +112,8 @@ def remove_dup_initializers(onnx_file_path): ...@@ -103,8 +112,8 @@ def remove_dup_initializers(onnx_file_path):
print("total reduced size: ", total_reduced_size / 1024 / 1024 / 1024, "GB") print("total reduced size: ", total_reduced_size / 1024 / 1024 / 1024, "GB")
ind_to_replace = sorted(ind_to_replace, key=lambda x: x[0]) ind_to_replace = sorted(ind_to_replace)
remove_dup_initializers_from_model(model, model, ind_to_replace) _remove_dup_initializers_from_model(model, model, ind_to_replace)
optimized_model_file_name = "optimized_" + model_file_name optimized_model_file_name = "optimized_" + model_file_name
new_model = os.path.join(model_file_folder, optimized_model_file_name) new_model = os.path.join(model_file_folder, optimized_model_file_name)
......
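A usage sketch for the deduplication helper (the import path is an assumption; adjust it to wherever `remove_dup_initializers` lives in this example):

```python
import os

from reduce_onnx_size import remove_dup_initializers  # assumed module path

# Writes optimized_<name>.onnx next to the input file and returns its path.
optimized_path = remove_dup_initializers(os.path.abspath("BART.onnx"))
print(optimized_path)
```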
torch >= 1.10
\ No newline at end of file
@@ -20,7 +20,6 @@ import argparse
 import logging
 import os
 import sys
-from datetime import datetime

 import numpy as np
 import torch
@@ -46,7 +45,7 @@ tokenizer_dict = {"facebook/bart-base": BartTokenizer}

 def parse_args():
-    parser = argparse.ArgumentParser(description="Finetune a transformers model on a text classification task")
+    parser = argparse.ArgumentParser(description="Export Bart model + Beam Search to ONNX graph.")
     parser.add_argument(
         "--validation_file", type=str, default=None, help="A csv or a json file containing the validation data."
     )
@@ -104,13 +103,12 @@ def export_and_validate_model(model, tokenizer, onnx_file_path, num_beams, max_l
     model.eval()
     ort_sess = None
-    onnx_bart = torch.jit.script(BARTBeamSearchGenerator(model))
+    bart_script_model = torch.jit.script(BARTBeamSearchGenerator(model))
     with torch.no_grad():
         ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
         inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors="pt").to(model.device)
-        # Test export here.
         summary_ids = model.generate(
             inputs["input_ids"],
             attention_mask=inputs["attention_mask"],
@@ -120,53 +118,54 @@ def export_and_validate_model(model, tokenizer, onnx_file_path, num_beams, max_l
             decoder_start_token_id=model.config.decoder_start_token_id,
         )
-        if not ort_sess:
-            torch.onnx.export(
-                onnx_bart,
-                (
-                    inputs["input_ids"],
-                    inputs["attention_mask"],
-                    num_beams,
-                    max_length,
-                    model.config.decoder_start_token_id,
-                ),
-                onnx_file_path,
-                opset_version=14,
-                input_names=["input_ids", "attention_mask", "num_beams", "max_length", "decoder_start_token_id"],
-                output_names=["output_ids"],
-                dynamic_axes={
-                    "input_ids": {0: "batch", 1: "seq"},
-                    "output_ids": {0: "batch", 1: "seq_out"},
-                },
-                verbose=False,
-                strip_doc_string=False,
-                example_outputs=summary_ids,
-            )
-
-            new_onnx_file_path = remove_dup_initializers(os.path.abspath(onnx_file_path))
-
-            ort_sess = onnxruntime.InferenceSession(new_onnx_file_path)
-
+        torch.onnx.export(
+            bart_script_model,
+            (
+                inputs["input_ids"],
+                inputs["attention_mask"],
+                num_beams,
+                max_length,
+                model.config.decoder_start_token_id,
+            ),
+            onnx_file_path,
+            opset_version=14,
+            input_names=["input_ids", "attention_mask", "num_beams", "max_length", "decoder_start_token_id"],
+            output_names=["output_ids"],
+            dynamic_axes={
+                "input_ids": {0: "batch", 1: "seq"},
+                "output_ids": {0: "batch", 1: "seq_out"},
+            },
+            example_outputs=summary_ids,
+        )
+
+        logger.info("Model exported to {}".format(onnx_file_path))
+
+        new_onnx_file_path = remove_dup_initializers(os.path.abspath(onnx_file_path))
+
+        logger.info("Deduplicated and optimized model written to {}".format(new_onnx_file_path))
+
+        ort_sess = onnxruntime.InferenceSession(new_onnx_file_path)
         ort_out = ort_sess.run(
             None,
             {
                 "input_ids": inputs["input_ids"].cpu().numpy(),
                 "attention_mask": inputs["attention_mask"].cpu().numpy(),
                 "num_beams": np.array(num_beams),
                 "max_length": np.array(max_length),
                 "decoder_start_token_id": np.array(model.config.decoder_start_token_id),
             },
         )

         np.testing.assert_allclose(summary_ids.cpu().numpy(), ort_out[0], rtol=1e-3, atol=1e-3)

-        print("========= Pass - Results are matched! =========")
+        logger.info("Model outputs from torch and ONNX Runtime are similar.")
+        logger.info("Success.")
 def main():
     args = parse_args()
-    local_device = None
-    local_max_length = 5
-    local_num_beams = 4
+    max_length = 5
+    num_beams = 4

     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
@@ -175,41 +174,31 @@ def main():
         level=logging.INFO,
     )
-    logger.setLevel(logging.ERROR)
+    logger.setLevel(logging.INFO)
     transformers.utils.logging.set_verbosity_error()

-    if args.model_name_or_path:
-        model, tokenizer = load_model_tokenizer(args.model_name_or_path, local_device)
-    else:
-        raise ValueError("Make sure that model name has been passed")
+    device = torch.device(args.device)
+    model, tokenizer = load_model_tokenizer(args.model_name_or_path, device)

     if model.config.decoder_start_token_id is None:
         raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")

-    if args.device:
-        if args.device == "cuda" and not torch.cuda.is_available():
-            raise ValueError("CUDA is not available in this server.")
-        local_device = torch.device(args.device)
-    else:
-        local_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
-    model.to(local_device)
+    model.to(device)

     if args.max_length:
-        local_max_length = args.max_length
+        max_length = args.max_length

     if args.num_beams:
-        local_num_beams = args.num_beams
+        num_beams = args.num_beams

     if args.output_file_path:
         output_name = args.output_file_path
     else:
-        output_name = "onnx_model_{}.onnx".format(datetime.now().utcnow().microsecond)
+        output_name = "BART.onnx"

-    export_and_validate_model(model, tokenizer, output_name, local_num_beams, local_max_length)
-    logger.info("***** Running export *****")
+    logger.info("Exporting model to ONNX")
+    export_and_validate_model(model, tokenizer, output_name, num_beams, max_length)

 if __name__ == "__main__":
......
torch >= 1.8
\ No newline at end of file