"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "7fd1d42a0139bb8dc133b872313f5dd49c694b87"
Unverified Commit 029b0d95 authored by Thomas Chaigneau, committed by GitHub

add GPT-J ONNX config to Transformers (#16274)



* add GPT-J ONNX config to Transformers

* remove token-classification features mapping
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* add question-answering features mapping
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* add GPT2 config init to GPT2 config + copied shebang for fix-copies
Co-authored-by: ChainYo <t.chaigneau.tc@gmail.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
parent aff9bc40
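For context, the new GPTJOnnxConfig is meant to slot into the existing transformers.onnx export machinery. Below is a minimal sketch of an export using it; the tiny random-weight config, the output filename, and the reuse of the GPT-2 tokenizer are illustrative assumptions (a real export would load an actual GPT-J checkpoint such as EleutherAI/gpt-j-6B), and running it requires PyTorch and onnxruntime.

from pathlib import Path

from transformers import AutoTokenizer, GPTJConfig, GPTJForCausalLM
from transformers.models.gptj import GPTJOnnxConfig
from transformers.onnx import export, validate_model_outputs

# Tiny random-weight GPT-J so the sketch runs without the 6B checkpoint.
config = GPTJConfig(n_layer=2, n_head=4, n_embd=256, n_positions=128)
model = GPTJForCausalLM(config)
tokenizer = AutoTokenizer.from_pretrained("gpt2")  # GPT-J reuses the GPT-2 tokenizer

onnx_config = GPTJOnnxConfig(config, task="causal-lm")
output = Path("gptj.onnx")

# export() returns the ONNX input and output names it matched against the model.
onnx_inputs, onnx_outputs = export(tokenizer, model, onnx_config, onnx_config.default_onnx_opset, output)

# Compare ONNX Runtime outputs against the PyTorch reference within the config's tolerance.
validate_model_outputs(onnx_config, tokenizer, model, output, onnx_outputs, onnx_config.atol_for_validation)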
@@ -54,6 +54,7 @@ Ready-made configurations include the following architectures:
 - ELECTRA
 - FlauBERT
 - GPT Neo
+- GPT-J
 - I-BERT
 - LayoutLM
 - M2M100
@@ -21,7 +21,7 @@ from ...utils import _LazyModule, is_flax_available, is_torch_available
 _import_structure = {
-    "configuration_gptj": ["GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTJConfig"],
+    "configuration_gptj": ["GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTJConfig", "GPTJOnnxConfig"],
 }
 if is_torch_available():
@@ -43,7 +43,7 @@ if is_flax_available():
 if TYPE_CHECKING:
-    from .configuration_gptj import GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTJConfig
+    from .configuration_gptj import GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTJConfig, GPTJOnnxConfig
     if is_torch_available():
         from .modeling_gptj import (
@@ -13,8 +13,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ GPT-J model configuration"""
+from collections import OrderedDict
+from typing import Any, List, Mapping, Optional
+from ... import PreTrainedTokenizer, TensorType, is_torch_available
 from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfigWithPast, PatchingSpec
 from ...utils import logging
@@ -135,3 +139,84 @@ class GPTJConfig(PretrainedConfig):
         super().__init__(
             bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs
         )
+
+
+# Copied from transformers.models.gpt2.configuration_gpt2.GPT2OnnxConfig
+class GPTJOnnxConfig(OnnxConfigWithPast):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        task: str = "default",
+        patching_specs: List[PatchingSpec] = None,
+        use_past: bool = False,
+    ):
+        super().__init__(config, task=task, patching_specs=patching_specs, use_past=use_past)
+        if not getattr(self._config, "pad_token_id", None):
+            # TODO: how to do that better?
+            self._config.pad_token_id = 0
+
+    @property
+    def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}})
+        if self.use_past:
+            self.fill_with_past_key_values_(common_inputs, direction="inputs")
+            common_inputs["attention_mask"] = {0: "batch", 1: "past_sequence + sequence"}
+        else:
+            common_inputs["attention_mask"] = {0: "batch", 1: "sequence"}
+
+        return common_inputs
+
+    @property
+    def num_layers(self) -> int:
+        return self._config.n_layer
+
+    @property
+    def num_attention_heads(self) -> int:
+        return self._config.n_head
+
+    def generate_dummy_inputs(
+        self,
+        tokenizer: PreTrainedTokenizer,
+        batch_size: int = -1,
+        seq_length: int = -1,
+        is_pair: bool = False,
+        framework: Optional[TensorType] = None,
+    ) -> Mapping[str, Any]:
+        common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
+            tokenizer, batch_size, seq_length, is_pair, framework
+        )
+
+        # We need to order the inputs in the way they appear in the forward()
+        ordered_inputs = OrderedDict({"input_ids": common_inputs["input_ids"]})
+
+        # Need to add the past_keys
+        if self.use_past:
+            if not is_torch_available():
+                raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
+            else:
+                import torch
+
+                batch, seqlen = common_inputs["input_ids"].shape
+                # Not using the same length for past_key_values
+                past_key_values_length = seqlen + 2
+                past_shape = (
+                    batch,
+                    self.num_attention_heads,
+                    past_key_values_length,
+                    self._config.hidden_size // self.num_attention_heads,
+                )
+                ordered_inputs["past_key_values"] = [
+                    (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(self.num_layers)
+                ]
+
+        ordered_inputs["attention_mask"] = common_inputs["attention_mask"]
+        if self.use_past:
+            ordered_inputs["attention_mask"] = torch.cat(
+                [ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1
+            )
+
+        return ordered_inputs
+
+    @property
+    def default_onnx_opset(self) -> int:
+        return 13
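For reference, a quick sketch of how the class above behaves. The tiny config values and the use of the GPT-2 tokenizer (GPT-J shares its vocabulary) are illustrative choices, and generating PyTorch dummy tensors assumes torch is installed.

from transformers import GPT2Tokenizer, GPTJConfig, TensorType
from transformers.models.gptj import GPTJOnnxConfig

# Small, illustrative hyperparameters; a real export would use the checkpoint's config.
config = GPTJConfig(n_layer=2, n_head=4, n_embd=256, n_positions=128)

# Plain export: dynamic batch/sequence axes on input_ids and attention_mask.
onnx_config = GPTJOnnxConfig(config, task="default")
print(onnx_config.inputs)
# OrderedDict([('input_ids', {0: 'batch', 1: 'sequence'}),
#              ('attention_mask', {0: 'batch', 1: 'sequence'})])

# "-with-past" variant: past_key_values become extra inputs and the attention
# mask spans past_sequence + sequence.
onnx_config_with_past = GPTJOnnxConfig.with_past(config, task="causal-lm")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
dummy = onnx_config_with_past.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)
print(dummy["past_key_values"][0][0].shape)  # (batch, num_heads, past_sequence_length, head_dim)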
@@ -11,6 +11,7 @@ from ..models.electra import ElectraOnnxConfig
 from ..models.flaubert import FlaubertOnnxConfig
 from ..models.gpt2 import GPT2OnnxConfig
 from ..models.gpt_neo import GPTNeoOnnxConfig
+from ..models.gptj import GPTJOnnxConfig
 from ..models.ibert import IBertOnnxConfig
 from ..models.layoutlm import LayoutLMOnnxConfig
 from ..models.m2m_100 import M2M100OnnxConfig
@@ -233,6 +234,15 @@ class FeaturesManager:
             "token-classification",
             onnx_config_cls=GPT2OnnxConfig,
         ),
+        "gpt-j": supported_features_mapping(
+            "default",
+            "default-with-past",
+            "causal-lm",
+            "causal-lm-with-past",
+            "question-answering",
+            "sequence-classification",
+            onnx_config_cls=GPTJOnnxConfig,
+        ),
         "gpt-neo": supported_features_mapping(
             "default",
             "default-with-past",
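With the mapping above registered, the generic export tooling can discover GPT-J by model type. A small sketch of the lookup, assuming the existing FeaturesManager.get_supported_features_for_model_type helper (unchanged by this PR):

from transformers import GPTJConfig
from transformers.onnx.features import FeaturesManager

# Features now advertised for the "gpt-j" model type.
features = FeaturesManager.get_supported_features_for_model_type("gpt-j")
print(sorted(features))
# ['causal-lm', 'causal-lm-with-past', 'default', 'default-with-past',
#  'question-answering', 'sequence-classification']

# Each entry maps a feature name to a constructor that builds the matching
# GPTJOnnxConfig (with use_past=True for the "-with-past" variants).
onnx_config = features["causal-lm-with-past"](GPTJConfig())
print(type(onnx_config).__name__, onnx_config.use_past)  # GPTJOnnxConfig True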