"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "1438c487df5ce38a7b2ae30877b3074b96a423dd"
Unverified Commit fbf1397b authored by Funtowicz Morgan's avatar Funtowicz Morgan Committed by GitHub
Browse files

Turn on eval mode when exporting to ONNX (#12758)

* Set model in eval mode when exporting to ONNX.

* Disable t5 for now.

* Disable T5 with past too.

* Style.
parent 8ef3f365
...@@ -87,6 +87,7 @@ def export( ...@@ -87,6 +87,7 @@ def export(
logger.info(f"Using framework PyTorch: {torch.__version__}") logger.info(f"Using framework PyTorch: {torch.__version__}")
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
model.config.return_dict = True model.config.return_dict = True
model.eval()
# Check if we need to override certain configuration item # Check if we need to override certain configuration item
if config.values_override is not None: if config.values_override is not None:
......
...@@ -3,14 +3,13 @@ from tempfile import NamedTemporaryFile ...@@ -3,14 +3,13 @@ from tempfile import NamedTemporaryFile
from unittest import TestCase from unittest import TestCase
from unittest.mock import patch from unittest.mock import patch
from transformers import ( # LongformerConfig, from transformers import ( # LongformerConfig,; T5Config,
AlbertConfig, AlbertConfig,
AutoTokenizer, AutoTokenizer,
BartConfig, BartConfig,
DistilBertConfig, DistilBertConfig,
GPT2Config, GPT2Config,
RobertaConfig, RobertaConfig,
T5Config,
XLMRobertaConfig, XLMRobertaConfig,
is_torch_available, is_torch_available,
) )
...@@ -22,7 +21,8 @@ from transformers.models.distilbert import DistilBertOnnxConfig ...@@ -22,7 +21,8 @@ from transformers.models.distilbert import DistilBertOnnxConfig
# from transformers.models.longformer import LongformerOnnxConfig # from transformers.models.longformer import LongformerOnnxConfig
from transformers.models.gpt2 import GPT2OnnxConfig from transformers.models.gpt2 import GPT2OnnxConfig
from transformers.models.roberta import RobertaOnnxConfig from transformers.models.roberta import RobertaOnnxConfig
from transformers.models.t5 import T5OnnxConfig
# from transformers.models.t5 import T5OnnxConfig
from transformers.models.xlm_roberta import XLMRobertaOnnxConfig from transformers.models.xlm_roberta import XLMRobertaOnnxConfig
from transformers.onnx import EXTERNAL_DATA_FORMAT_SIZE_LIMIT, OnnxConfig, ParameterFormat, validate_model_outputs from transformers.onnx import EXTERNAL_DATA_FORMAT_SIZE_LIMIT, OnnxConfig, ParameterFormat, validate_model_outputs
from transformers.onnx.config import DEFAULT_ONNX_OPSET, OnnxConfigWithPast from transformers.onnx.config import DEFAULT_ONNX_OPSET, OnnxConfigWithPast
...@@ -122,7 +122,11 @@ class OnnxConfigWithPastTestCaseV2(TestCase): ...@@ -122,7 +122,11 @@ class OnnxConfigWithPastTestCaseV2(TestCase):
Cover the tests for model which have use_cache feature (i.e. "with_past" for ONNX) Cover the tests for model which have use_cache feature (i.e. "with_past" for ONNX)
""" """
SUPPORTED_WITH_PAST_CONFIGS = {("BART", BartConfig), ("GPT2", GPT2Config), ("T5", T5Config)} SUPPORTED_WITH_PAST_CONFIGS = {
("BART", BartConfig),
("GPT2", GPT2Config),
# ("T5", T5Config)
}
@patch.multiple(OnnxConfigWithPast, __abstractmethods__=set()) @patch.multiple(OnnxConfigWithPast, __abstractmethods__=set())
def test_use_past(self): def test_use_past(self):
...@@ -165,14 +169,13 @@ class OnnxConfigWithPastTestCaseV2(TestCase): ...@@ -165,14 +169,13 @@ class OnnxConfigWithPastTestCaseV2(TestCase):
if is_torch_available(): if is_torch_available():
from transformers import ( from transformers import ( # T5Model,
AlbertModel, AlbertModel,
BartModel, BartModel,
BertModel, BertModel,
DistilBertModel, DistilBertModel,
GPT2Model, GPT2Model,
RobertaModel, RobertaModel,
T5Model,
XLMRobertaModel, XLMRobertaModel,
) )
...@@ -185,7 +188,7 @@ if is_torch_available(): ...@@ -185,7 +188,7 @@ if is_torch_available():
# ("LongFormer", "longformer-base-4096", LongformerModel, LongformerConfig, LongformerOnnxConfig), # ("LongFormer", "longformer-base-4096", LongformerModel, LongformerConfig, LongformerOnnxConfig),
("Roberta", "roberta-base", RobertaModel, RobertaConfig, RobertaOnnxConfig), ("Roberta", "roberta-base", RobertaModel, RobertaConfig, RobertaOnnxConfig),
("XLM-Roberta", "roberta-base", XLMRobertaModel, XLMRobertaConfig, XLMRobertaOnnxConfig), ("XLM-Roberta", "roberta-base", XLMRobertaModel, XLMRobertaConfig, XLMRobertaOnnxConfig),
("T5", "t5-small", T5Model, T5Config, T5OnnxConfig), # ("T5", "t5-small", T5Model, T5Config, T5OnnxConfig),
} }
PYTORCH_EXPORT_WITH_PAST_MODELS = { PYTORCH_EXPORT_WITH_PAST_MODELS = {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment