Unverified commit 1ccc033c authored by Jay Zhang, committed by GitHub

Update the example of exporting Bart + BeamSearch to ONNX to resolve review comments. (#14310)



* Update code to resolve comments left in the previous PR.

* Add README.md file for this example.

* Update examples/onnx/pytorch/translation/README.md
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update examples/onnx/pytorch/translation/README.md
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update examples/onnx/pytorch/translation/README.md
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update README.md file to resolve comments.

* Add a section name.

* Update examples/onnx/pytorch/translation/README.md
Co-authored-by: Gary Miguel <garymm@garymm.org>

* Add more comments for _convert_past_list_to_tuple().

* Change the default file name to a consistent one.

* Fix a format issue.

* Update examples/onnx/pytorch/translation/README.md
Co-authored-by: Gary Miguel <garymm@garymm.org>

* Update examples/onnx/pytorch/translation/run_onnx_exporter.py
Co-authored-by: Gary Miguel <garymm@garymm.org>

* Update examples/onnx/pytorch/translation/README.md
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Change the folder to summarization and address some other comments.

* Update the torch version.
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
Co-authored-by: Gary Miguel <garymm@garymm.org>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
parent 6cdc3a78
<!---
Copyright 2021 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
# Bart + Beam Search to ONNX
This folder contains an example of exporting Bart + Beam Search generation (`BartForConditionalGeneration`) to ONNX.
Beam Search contains a for-loop workflow, so we need to make it TorchScript-compatible before exporting to ONNX. This example shows how to make a Bart model TorchScript-compatible by wrapping it in a new module. In addition, some changes were made to the `beam_search()` function to make it TorchScript-compatible.
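As a rough illustration of the constraint (a minimal sketch, not taken from this example): `torch.jit.script` can capture a data-dependent loop like the one in beam search, whereas `torch.jit.trace` would freeze the loop count at tracing time.

```python
import torch


class LoopModule(torch.nn.Module):
    def forward(self, x, steps: int):
        # A data-dependent loop: tracing would bake `steps` into the graph,
        # while scripting keeps it as a runtime input.
        for _ in range(steps):
            x = x + 1
        return x


scripted = torch.jit.script(LoopModule())
print(scripted(torch.zeros(2), 3))  # tensor([3., 3.])
```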
## How to run the example
To make sure you can successfully run the latest versions of the example scripts, you have to **install the library from source** and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
```bash
git clone https://github.com/huggingface/transformers
cd transformers
pip install .
```
Then cd into this example folder and run
```bash
pip install -r requirements.txt
```
Now you can run the command below to export the example ONNX file:
```bash
python run_onnx_exporter.py --model_name_or_path facebook/bart-base
```
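Once the export and validation succeed, the script writes `BART.onnx` (or the path passed via `--output_file_path`) plus a deduplicated `optimized_*.onnx` copy next to it. The exported graph can then be run directly with ONNX Runtime; a minimal sketch, assuming the default output name and the input/output names used by `run_onnx_exporter.py`:

```python
import numpy as np
import onnxruntime
from transformers import BartConfig, BartTokenizer

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
config = BartConfig.from_pretrained("facebook/bart-base")

# "optimized_BART.onnx" is the deduplicated file the script writes by default.
sess = onnxruntime.InferenceSession("optimized_BART.onnx")

inputs = tokenizer(["My friends are cool but they eat too many carbs."], return_tensors="np")
output_ids = sess.run(
    None,
    {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "num_beams": np.array(4),
        "max_length": np.array(5),
        "decoder_start_token_id": np.array(config.decoder_start_token_id),
    },
)[0]
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))
```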
import copy
+import itertools
from typing import List, Optional, Tuple

import torch

@@ -8,23 +9,23 @@ from transformers import BartConfig
from transformers.generation_utils import GenerationMixin


-def flatten_list(past):
-    values = []
-    if past is not None:
-        for i, p in enumerate(past):
-            for j, q in enumerate(p):
-                values.append(q)
-    return values
-def list_to_tuple(past):
+def _convert_past_list_to_tuple(past_key_values):
+    """
+    In the Bart model, the type of past_key_values is tuple(tuple(torch.FloatTensor)), which is not
+    TorchScript-compatible, so it has to be converted during the export process.
+    This function converts past values from a list into tuple(tuple(torch.FloatTensor)) for
+    the inner decoder.
+
+    According to the definition of past_key_values, each inner tuple(torch.FloatTensor) has 4 tensors,
+    so we convert every 4 elements in the list into a tuple(torch.FloatTensor).
+    """
+    count_of_each_inner_tuple = 4
    results = ()
    temp_result = ()
-    count_n = len(past) // 4
+    count_n = len(past_key_values) // count_of_each_inner_tuple
    for idx in range(count_n):
-        real_idx = idx * 4
-        temp_result = tuple(past[real_idx : real_idx + 4])
+        real_idx = idx * count_of_each_inner_tuple
+        temp_result = tuple(past_key_values[real_idx : real_idx + count_of_each_inner_tuple])
        results += ((temp_result),)
    return results
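For intuition, here is a small sketch of the round trip (it assumes `_convert_past_list_to_tuple` above is in scope; the tensor shapes are made up):

```python
import torch

# A flat list of 8 tensors becomes 2 inner tuples of 4 tensors each,
# i.e. one tuple(torch.FloatTensor) per decoder layer.
flat_past = [torch.zeros(1, 2, 3, 4) for _ in range(8)]
structured = _convert_past_list_to_tuple(flat_past)
assert len(structured) == 2
assert all(len(inner) == 4 for inner in structured)
```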
@@ -51,7 +52,7 @@ class DecoderForONNX(torch.nn.Module):
    def forward(self, input_ids, encoder_state, attention_mask, past=None):
        all_results = None
        if past is not None:
-            all_results = list_to_tuple(past)
+            all_results = _convert_past_list_to_tuple(past)
            input_ids = input_ids[:, -1:]
        last_hidden_state, past_key_values = self.decoder(
@@ -68,28 +69,33 @@ class DecoderForONNX(torch.nn.Module):
        return last_hidden_state, past_values
-def create_traced_encoder(encoder, input_ids, attention_mask):
+def _create_traced_encoder(encoder, input_ids, attention_mask):
    encoder_c = copy.deepcopy(encoder)
    encoder_for_onnx = EncoderForONNX(encoder_c)

-    # return torch.jit.trace(encoder, (input_ids, attention_mask))
    return torch.jit.trace(encoder_for_onnx, (input_ids, attention_mask))


-def create_traced_decoder(decoder, input_ids, encoder_state, attention_mask, past=None):
+def _create_traced_decoder(decoder, input_ids, encoder_state, attention_mask, past=None):
    decoder_c = copy.deepcopy(decoder)
    decoder_for_onnx = DecoderForONNX(decoder_c)
-    past_values = flatten_list(past)
+    past_values = list(itertools.chain.from_iterable(past or ()))
    # Do this twice so we get 2 different decoders for further work.
-    if past_values is None or len(past_values) == 0:
-        return torch.jit.trace(decoder_for_onnx, (input_ids, encoder_state, attention_mask))
-    else:
+    if past_values:
        return torch.jit.trace(decoder_for_onnx, (input_ids, encoder_state, attention_mask, past_values))
+    else:
+        return torch.jit.trace(decoder_for_onnx, (input_ids, encoder_state, attention_mask))
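Tracing records one concrete execution path, which is why the decoder is traced twice, once without `past` and once with it. A standalone sketch of that pattern (the toy module below is made up for illustration):

```python
import torch


class ToyDecoder(torch.nn.Module):
    def forward(self, x, past=None):
        if past is not None:
            x = x + past
        return x * 2


toy = ToyDecoder()
x = torch.ones(2)
# Each trace captures only the branch exercised by its example inputs,
# so two traced graphs are needed to cover both call signatures.
traced_no_past = torch.jit.trace(toy, (x,))
traced_with_past = torch.jit.trace(toy, (x, x))
```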
class BartConfigTS(BartConfig, torch.nn.Module):
+    """
+    BartConfigTS is a TorchScript-compatible transformers.models.bart.configuration_bart.BartConfig.
+    TorchScript only supports sub-classes of torch.nn.Module.
+    """

-    def init_module(self):
+    def __init__(self, config):
+        BartConfig.__init__(self, config)
        torch.nn.Module.__init__(self)
@@ -127,7 +133,6 @@ class BARTGenerator(torch.nn.Module, GenerationMixin):
    def __init__(self, model):
        super().__init__()
        self.config = BartConfigTS(model.config)
-        self.config.init_module()
        self.config.force_bos_token_to_be_generated = False
        self._trace_modules(model)
        self.logits_processor = MinLengthLogitsProcessorTS(self.config.min_length, self.config.eos_token_id)
@@ -136,7 +141,6 @@ class BARTGenerator(torch.nn.Module, GenerationMixin):
        self.decoder_layers = model.config.decoder_layers

    def _trace_modules(self, model):
-        # Be aware of the last one 2 should be kept.
        input_ids = torch.tensor(
            [
                [
@@ -200,89 +204,25 @@ class BARTGenerator(torch.nn.Module, GenerationMixin):
                    57,
                    8629,
                    5,
-                    2,
+                    model.config.eos_token_id,
                ]
            ],
            device=model.device,
            dtype=torch.long,
        )
        attention_mask = torch.tensor(
-            [
-                [
-                    True,
-                    ...  # (a long literal list of `True` values, one per input token, elided)
-                    True,
-                ]
-            ],
+            [[True] * input_ids.shape[-1]],
            device=model.device,
            dtype=torch.bool,
        )
-        self.encoder = create_traced_encoder(model.get_encoder(), input_ids, attention_mask)
+        self.encoder = _create_traced_encoder(model.get_encoder(), input_ids, attention_mask)
        encoder_outputs = model.get_encoder()(input_ids, attention_mask=attention_mask, return_dict=True)
        decoder = model.model.decoder
        decoder_outputs = decoder(input_ids, attention_mask, encoder_outputs["last_hidden_state"], None, None, None)
-        self.decoder_no_past = create_traced_decoder(
+        self.decoder_no_past = _create_traced_decoder(
            model.model.decoder, input_ids, encoder_outputs["last_hidden_state"], attention_mask
        )
-        self.decoder_with_past = create_traced_decoder(
+        self.decoder_with_past = _create_traced_decoder(
            model.model.decoder, input_ids, encoder_outputs["last_hidden_state"], attention_mask, decoder_outputs[1]
        )
@@ -414,8 +354,8 @@ class BeamSearchScorerTS(torch.nn.Module):
        self._beam_hyps_count = torch.zeros(self.batch_size, dtype=torch.long)
        self._beam_hyps_worst_scores = torch.zeros(self.batch_size) + 1e9
        self._beam_hyps_max_length: int = self.max_length - 1
-        self._beam_hyps: List[torch.Tensor] = [torch.zeros(2)]  # placeholder for TorchScript compatible
-        self._beam_scores: List[torch.Tensor] = [torch.zeros(2)]  # placeholder for TorchScript compatible
+        self._beam_hyps: List[torch.Tensor] = [torch.zeros(2)]  # placeholder for TorchScript compatibility
+        self._beam_scores: List[torch.Tensor] = [torch.zeros(2)]  # placeholder for TorchScript compatibility

    def is_done(self) -> torch.Tensor:
        return self._done.all()
@@ -474,11 +414,11 @@ class BeamSearchScorerTS(torch.nn.Module):
        score = sum_logprobs / (hyp.shape[-1] ** self.length_penalty)
        hyps_count = self.hypo_len(hypo_idx)
        if hyps_count < self.num_beams or score > self._beam_hyps_worst_scores[hypo_idx]:
-            # NOTE: work around difference of torch.sum(empty_tensor) = 0, while error in onnx.
+            # NOTE: work around torch.sum(empty_tensor) == 0, which raises an error in ONNX.
            # Bug: https://msdata.visualstudio.com/Vienna/_workitems/edit/1486599
            beam_idx = (
                torch.sum(self._beam_hyps_count[:hypo_idx]) if hypo_idx != 0 else torch.tensor(0, dtype=torch.long)
            )
-            # beam_idx = torch.sum(_beam_hyps_count[:hypo_idx])
            self._beam_scores.insert(beam_idx, torch.tensor([score]))
            self._beam_hyps.insert(beam_idx, hyp)
            if hyps_count + 1 > self.num_beams:
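For reference, the eager-mode behavior the workaround relies on can be checked directly (a small sketch; the exported ONNX graph historically rejected the empty-tensor reduction, per the bug link above):

```python
import torch

counts = torch.tensor([3, 1, 2], dtype=torch.long)
print(torch.sum(counts[:0]))  # tensor(0): summing an empty slice is fine in eager mode
print(torch.sum(counts[:2]))  # tensor(4)
```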
@@ -605,7 +545,7 @@ class BeamSearchScorerTS(torch.nn.Module):
            self.hypo_add(final_tokens, final_score, batch_idx)

        # select the best hypotheses
-        # NOTE: new is not scriptable
+        # NOTE: torch.Tensor.new_zeros() is not scriptable
        sent_lengths = torch.zeros(batch_size * self.num_beam_hyps_to_keep, dtype=torch.long)
        best = []
        best_scores = torch.zeros(
@@ -782,7 +722,6 @@ class BARTBeamSearchGenerator(BARTGenerator):
            bos_token_id=bos_token_id,
        )

-        # from generation_utils.py
        batch_size = input_ids.shape[0]
        length_penalty = self.config.length_penalty
......
"""
Code to remove duplicate initializers to reduce ONNX model size.
"""
import os
import numpy
......@@ -5,7 +9,7 @@ import numpy
import onnx
def is_equal_tensor_proto(a, b):
def _is_equal_tensor_proto(a, b):
name_a = a.name
name_b = b.name
......@@ -20,25 +24,25 @@ def is_equal_tensor_proto(a, b):
return res
-def node_replace_input_with(node_proto, name, new_name):
+def _node_replace_input_with(node_proto, name, new_name):
    for i, input_name in enumerate(node_proto.input):
        if input_name == name:
            node_proto.input.insert(i, new_name)
            node_proto.input.pop(i + 1)

    if node_proto.op_type == "If":
-        graph_replace_input_with(node_proto.attribute[0].g, name, new_name)
-        graph_replace_input_with(node_proto.attribute[1].g, name, new_name)
+        _graph_replace_input_with(node_proto.attribute[0].g, name, new_name)
+        _graph_replace_input_with(node_proto.attribute[1].g, name, new_name)
    if node_proto.op_type == "Loop":
-        graph_replace_input_with(node_proto.attribute[0].g, name, new_name)
+        _graph_replace_input_with(node_proto.attribute[0].g, name, new_name)


-def graph_replace_input_with(graph_proto, name, new_name):
+def _graph_replace_input_with(graph_proto, name, new_name):
    for n in graph_proto.node:
-        node_replace_input_with(n, name, new_name)
+        _node_replace_input_with(n, name, new_name)
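A quick sanity check of what the rewiring helpers do (a sketch that assumes the functions above are in scope; the node itself is made up):

```python
import onnx

# Hypothetical node that reads from a duplicate initializer named "dup_w".
node = onnx.helper.make_node("Add", inputs=["x", "dup_w"], outputs=["y"])
_node_replace_input_with(node, "dup_w", "w")
assert list(node.input) == ["x", "w"]
```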
-def remove_dup_initializers_from_model(model, model_without_ext, ind_to_replace):
+def _remove_dup_initializers_from_model(model, model_without_ext, ind_to_replace):
    inits_with_data = [i for i in model.graph.initializer]
    inits = [i for i in model_without_ext.graph.initializer]
    for i, ref_i in ind_to_replace:
@@ -52,10 +56,15 @@ def remove_dup_initializers_from_model(model, model_without_ext, ind_to_replace)
        model_without_ext.graph.initializer.remove(inits[i])
-        # for n in model.graph.node:
-        graph_replace_input_with(model_without_ext.graph, name_i, name_ref)
+        _graph_replace_input_with(model_without_ext.graph, name_i, name_ref)
def remove_dup_initializers(onnx_file_path):
+    """
+    Removes duplicate initializers from the model to reduce its size.
+
+    Writes a new file in the same directory as onnx_file_path and returns the path to that file.
+    """
    model_file_folder = os.path.dirname(onnx_file_path)
    model_file_name = os.path.basename(onnx_file_path)
@@ -76,7 +85,7 @@ def remove_dup_initializers(onnx_file_path):
        for j in range(i + 1, len(inits)):
            if j in dup_set:
                continue

-            if is_equal_tensor_proto(inits[i], inits[j]):
+            if _is_equal_tensor_proto(inits[i], inits[j]):
                dup_set.add(i)
                dup_set.add(j)
@@ -103,8 +112,8 @@ def remove_dup_initializers(onnx_file_path):
    print("total reduced size: ", total_reduced_size / 1024 / 1024 / 1024, "GB")

-    ind_to_replace = sorted(ind_to_replace, key=lambda x: x[0])
-    remove_dup_initializers_from_model(model, model, ind_to_replace)
+    ind_to_replace = sorted(ind_to_replace)
+    _remove_dup_initializers_from_model(model, model, ind_to_replace)

    optimized_model_file_name = "optimized_" + model_file_name
    new_model = os.path.join(model_file_folder, optimized_model_file_name)
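The entry point can also be used on its own for any exported graph; a minimal sketch (the module name is assumed from this file, and `BART.onnx` is just the example's default output):

```python
from reduce_onnx_size import remove_dup_initializers

# Returns the path of the new "optimized_*" file written next to the input.
new_path = remove_dup_initializers("BART.onnx")
print("Optimized model written to", new_path)
```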
......
+torch >= 1.10
\ No newline at end of file
@@ -20,7 +20,6 @@ import argparse
import logging
import os
import sys
-from datetime import datetime

import numpy as np
import torch
@@ -46,7 +45,7 @@ tokenizer_dict = {"facebook/bart-base": BartTokenizer}
def parse_args():
-    parser = argparse.ArgumentParser(description="Finetune a transformers model on a text classification task")
+    parser = argparse.ArgumentParser(description="Export Bart model + Beam Search to ONNX graph.")
    parser.add_argument(
        "--validation_file", type=str, default=None, help="A csv or a json file containing the validation data."
    )
@@ -104,13 +103,12 @@ def export_and_validate_model(model, tokenizer, onnx_file_path, num_beams, max_l
    model.eval()

    ort_sess = None
-    onnx_bart = torch.jit.script(BARTBeamSearchGenerator(model))
+    bart_script_model = torch.jit.script(BARTBeamSearchGenerator(model))
    with torch.no_grad():
        ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
        inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors="pt").to(model.device)

-        # Test export here.
        summary_ids = model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
@@ -120,53 +118,54 @@ def export_and_validate_model(model, tokenizer, onnx_file_path, num_beams, max_l
            decoder_start_token_id=model.config.decoder_start_token_id,
        )
-        if not ort_sess:
-            torch.onnx.export(
-                onnx_bart,
-                (
-                    inputs["input_ids"],
-                    inputs["attention_mask"],
-                    num_beams,
-                    max_length,
-                    model.config.decoder_start_token_id,
-                ),
-                onnx_file_path,
-                opset_version=14,
-                input_names=["input_ids", "attention_mask", "num_beams", "max_length", "decoder_start_token_id"],
-                output_names=["output_ids"],
-                dynamic_axes={
-                    "input_ids": {0: "batch", 1: "seq"},
-                    "output_ids": {0: "batch", 1: "seq_out"},
-                },
-                verbose=False,
-                strip_doc_string=False,
-                example_outputs=summary_ids,
-            )
-
-            new_onnx_file_path = remove_dup_initializers(os.path.abspath(onnx_file_path))
-            ort_sess = onnxruntime.InferenceSession(new_onnx_file_path)
-            ort_out = ort_sess.run(
-                None,
-                {
-                    "input_ids": inputs["input_ids"].cpu().numpy(),
-                    "attention_mask": inputs["attention_mask"].cpu().numpy(),
-                    "num_beams": np.array(num_beams),
-                    "max_length": np.array(max_length),
-                    "decoder_start_token_id": np.array(model.config.decoder_start_token_id),
-                },
-            )
-
-            np.testing.assert_allclose(summary_ids.cpu().numpy(), ort_out[0], rtol=1e-3, atol=1e-3)
-
-            print("========= Pass - Results are matched! =========")
+        torch.onnx.export(
+            bart_script_model,
+            (
+                inputs["input_ids"],
+                inputs["attention_mask"],
+                num_beams,
+                max_length,
+                model.config.decoder_start_token_id,
+            ),
+            onnx_file_path,
+            opset_version=14,
+            input_names=["input_ids", "attention_mask", "num_beams", "max_length", "decoder_start_token_id"],
+            output_names=["output_ids"],
+            dynamic_axes={
+                "input_ids": {0: "batch", 1: "seq"},
+                "output_ids": {0: "batch", 1: "seq_out"},
+            },
+            example_outputs=summary_ids,
+        )
+
+        logger.info("Model exported to {}".format(onnx_file_path))
+
+        new_onnx_file_path = remove_dup_initializers(os.path.abspath(onnx_file_path))
+
+        logger.info("Deduplicated and optimized model written to {}".format(new_onnx_file_path))
+
+        ort_sess = onnxruntime.InferenceSession(new_onnx_file_path)
+        ort_out = ort_sess.run(
+            None,
+            {
+                "input_ids": inputs["input_ids"].cpu().numpy(),
+                "attention_mask": inputs["attention_mask"].cpu().numpy(),
+                "num_beams": np.array(num_beams),
+                "max_length": np.array(max_length),
+                "decoder_start_token_id": np.array(model.config.decoder_start_token_id),
+            },
+        )

+        np.testing.assert_allclose(summary_ids.cpu().numpy(), ort_out[0], rtol=1e-3, atol=1e-3)
+
+        logger.info("Model outputs from torch and ONNX Runtime are similar.")
+        logger.info("Success.")
def main():
    args = parse_args()
-    local_device = None
-    local_max_length = 5
-    local_num_beams = 4
+    max_length = 5
+    num_beams = 4
    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
@@ -175,41 +174,31 @@ def main():
        level=logging.INFO,
    )
-    logger.setLevel(logging.ERROR)
+    logger.setLevel(logging.INFO)
    transformers.utils.logging.set_verbosity_error()

-    if args.model_name_or_path:
-        model, tokenizer = load_model_tokenizer(args.model_name_or_path, local_device)
-    else:
-        raise ValueError("Make sure that model name has been passed")
+    device = torch.device(args.device)
+
+    model, tokenizer = load_model_tokenizer(args.model_name_or_path, device)

    if model.config.decoder_start_token_id is None:
        raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")

-    if args.device:
-        if args.device == "cuda" and not torch.cuda.is_available():
-            raise ValueError("CUDA is not available in this server.")
-        local_device = torch.device(args.device)
-    else:
-        local_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
-    model.to(local_device)
+    model.to(device)

    if args.max_length:
-        local_max_length = args.max_length
+        max_length = args.max_length

    if args.num_beams:
-        local_num_beams = args.num_beams
+        num_beams = args.num_beams

    if args.output_file_path:
        output_name = args.output_file_path
    else:
-        output_name = "onnx_model_{}.onnx".format(datetime.now().utcnow().microsecond)
-
-    export_and_validate_model(model, tokenizer, output_name, local_num_beams, local_max_length)
+        output_name = "BART.onnx"

-    logger.info("***** Running export *****")
+    logger.info("Exporting model to ONNX")
+    export_and_validate_model(model, tokenizer, output_name, num_beams, max_length)
if __name__ == "__main__":
......
-torch >= 1.8
\ No newline at end of file