Unverified commit 6da76b9c authored by Jingya HUANG, committed by GitHub
Browse files

Add onnx export cuda support (#17183)


Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
parent adc0ff25
...@@ -3099,7 +3099,7 @@ class BigBirdForQuestionAnswering(BigBirdPreTrainedModel): ...@@ -3099,7 +3099,7 @@ class BigBirdForQuestionAnswering(BigBirdPreTrainedModel):
# setting lengths logits to `-inf` # setting lengths logits to `-inf`
logits_mask = self.prepare_question_mask(question_lengths, seqlen) logits_mask = self.prepare_question_mask(question_lengths, seqlen)
if token_type_ids is None: if token_type_ids is None:
token_type_ids = torch.ones(logits_mask.size(), dtype=int) - logits_mask token_type_ids = torch.ones(logits_mask.size(), dtype=int, device=logits_mask.device) - logits_mask
logits_mask = logits_mask logits_mask = logits_mask
logits_mask[:, 0] = False logits_mask[:, 0] = False
logits_mask.unsqueeze_(2) logits_mask.unsqueeze_(2)
......
...@@ -86,6 +86,7 @@ def export_pytorch( ...@@ -86,6 +86,7 @@ def export_pytorch(
opset: int, opset: int,
output: Path, output: Path,
tokenizer: "PreTrainedTokenizer" = None, tokenizer: "PreTrainedTokenizer" = None,
device: str = "cpu",
) -> Tuple[List[str], List[str]]: ) -> Tuple[List[str], List[str]]:
""" """
Export a PyTorch model to an ONNX Intermediate Representation (IR) Export a PyTorch model to an ONNX Intermediate Representation (IR)
...@@ -101,6 +102,8 @@ def export_pytorch( ...@@ -101,6 +102,8 @@ def export_pytorch(
The version of the ONNX operator set to use. The version of the ONNX operator set to use.
output (`Path`): output (`Path`):
Directory to store the exported ONNX model. Directory to store the exported ONNX model.
device (`str`, *optional*, defaults to `cpu`):
The device on which the ONNX model will be exported. Either `cpu` or `cuda`.
Returns: Returns:
`Tuple[List[str], List[str]]`: A tuple with an ordered list of the model's inputs, and the named inputs from `Tuple[List[str], List[str]]`: A tuple with an ordered list of the model's inputs, and the named inputs from
...@@ -137,6 +140,10 @@ def export_pytorch( ...@@ -137,6 +140,10 @@ def export_pytorch(
# Ensure inputs match # Ensure inputs match
# TODO: Check when exporting QA we provide "is_pair=True" # TODO: Check when exporting QA we provide "is_pair=True"
model_inputs = config.generate_dummy_inputs(preprocessor, framework=TensorType.PYTORCH) model_inputs = config.generate_dummy_inputs(preprocessor, framework=TensorType.PYTORCH)
device = torch.device(device)
if device.type == "cuda" and torch.cuda.is_available():
model.to(device)
model_inputs = dict((k, v.to(device)) for k, v in model_inputs.items())
inputs_match, matched_inputs = ensure_model_and_config_inputs_match(model, model_inputs.keys()) inputs_match, matched_inputs = ensure_model_and_config_inputs_match(model, model_inputs.keys())
onnx_outputs = list(config.outputs.keys()) onnx_outputs = list(config.outputs.keys())
...@@ -268,6 +275,7 @@ def export( ...@@ -268,6 +275,7 @@ def export(
opset: int, opset: int,
output: Path, output: Path,
tokenizer: "PreTrainedTokenizer" = None, tokenizer: "PreTrainedTokenizer" = None,
device: str = "cpu",
) -> Tuple[List[str], List[str]]: ) -> Tuple[List[str], List[str]]:
""" """
Export a Pytorch or TensorFlow model to an ONNX Intermediate Representation (IR) Export a Pytorch or TensorFlow model to an ONNX Intermediate Representation (IR)
...@@ -283,6 +291,9 @@ def export( ...@@ -283,6 +291,9 @@ def export(
The version of the ONNX operator set to use. The version of the ONNX operator set to use.
output (`Path`): output (`Path`):
Directory to store the exported ONNX model. Directory to store the exported ONNX model.
device (`str`, *optional*, defaults to `cpu`):
The device on which the ONNX model will be exported. Either `cpu` or `cuda`. Only PyTorch is supported for
export on CUDA devices.
Returns: Returns:
`Tuple[List[str], List[str]]`: A tuple with an ordered list of the model's inputs, and the named inputs from `Tuple[List[str], List[str]]`: A tuple with an ordered list of the model's inputs, and the named inputs from
...@@ -294,6 +305,9 @@ def export( ...@@ -294,6 +305,9 @@ def export(
"Please install torch or tensorflow first." "Please install torch or tensorflow first."
) )
if is_tf_available() and isinstance(model, TFPreTrainedModel) and device == "cuda":
raise RuntimeError("`tf2onnx` does not support export on CUDA device.")
if isinstance(preprocessor, PreTrainedTokenizerBase) and tokenizer is not None: if isinstance(preprocessor, PreTrainedTokenizerBase) and tokenizer is not None:
raise ValueError("You cannot provide both a tokenizer and a preprocessor to export the model.") raise ValueError("You cannot provide both a tokenizer and a preprocessor to export the model.")
if tokenizer is not None: if tokenizer is not None:
...@@ -318,7 +332,7 @@ def export( ...@@ -318,7 +332,7 @@ def export(
) )
if is_torch_available() and issubclass(type(model), PreTrainedModel): if is_torch_available() and issubclass(type(model), PreTrainedModel):
return export_pytorch(preprocessor, model, config, opset, output, tokenizer=tokenizer) return export_pytorch(preprocessor, model, config, opset, output, tokenizer=tokenizer, device=device)
elif is_tf_available() and issubclass(type(model), TFPreTrainedModel): elif is_tf_available() and issubclass(type(model), TFPreTrainedModel):
return export_tensorflow(preprocessor, model, config, opset, output, tokenizer=tokenizer) return export_tensorflow(preprocessor, model, config, opset, output, tokenizer=tokenizer)
...@@ -359,6 +373,8 @@ def validate_model_outputs( ...@@ -359,6 +373,8 @@ def validate_model_outputs(
session = InferenceSession(onnx_model.as_posix(), options, providers=["CPUExecutionProvider"]) session = InferenceSession(onnx_model.as_posix(), options, providers=["CPUExecutionProvider"])
# Compute outputs from the reference model # Compute outputs from the reference model
if is_torch_available() and issubclass(type(reference_model), PreTrainedModel):
reference_model.to("cpu")
ref_outputs = reference_model(**reference_model_inputs) ref_outputs = reference_model(**reference_model_inputs)
ref_outputs_dict = {} ref_outputs_dict = {}
......
...@@ -242,7 +242,7 @@ class OnnxExportTestCaseV2(TestCase): ...@@ -242,7 +242,7 @@ class OnnxExportTestCaseV2(TestCase):
Integration tests ensuring supported models are correctly exported Integration tests ensuring supported models are correctly exported
""" """
def _onnx_export(self, test_name, name, model_name, feature, onnx_config_class_constructor): def _onnx_export(self, test_name, name, model_name, feature, onnx_config_class_constructor, device="cpu"):
from transformers.onnx import export from transformers.onnx import export
model_class = FeaturesManager.get_model_class_for_feature(feature) model_class = FeaturesManager.get_model_class_for_feature(feature)
...@@ -273,7 +273,7 @@ class OnnxExportTestCaseV2(TestCase): ...@@ -273,7 +273,7 @@ class OnnxExportTestCaseV2(TestCase):
with NamedTemporaryFile("w") as output: with NamedTemporaryFile("w") as output:
try: try:
onnx_inputs, onnx_outputs = export( onnx_inputs, onnx_outputs = export(
preprocessor, model, onnx_config, onnx_config.default_onnx_opset, Path(output.name) preprocessor, model, onnx_config, onnx_config.default_onnx_opset, Path(output.name), device=device
) )
validate_model_outputs( validate_model_outputs(
onnx_config, onnx_config,
...@@ -294,6 +294,14 @@ class OnnxExportTestCaseV2(TestCase): ...@@ -294,6 +294,14 @@ class OnnxExportTestCaseV2(TestCase):
def test_pytorch_export(self, test_name, name, model_name, feature, onnx_config_class_constructor): def test_pytorch_export(self, test_name, name, model_name, feature, onnx_config_class_constructor):
self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor) self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)
@parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_MODELS))
@slow
@require_torch
@require_vision
@require_rjieba
def test_pytorch_export_on_cuda(self, test_name, name, model_name, feature, onnx_config_class_constructor):
self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor, device="cuda")
@parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_WITH_PAST_MODELS)) @parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_WITH_PAST_MODELS))
@slow @slow
@require_torch @require_torch
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment