Unverified commit 0b1e0fcf, authored by Thomas Chaigneau, committed by GitHub
Browse files

Fix GPT-J onnx conversion (#16780)



* add gptj to TOKENIZER_MAPPING_NAMES

* fix int32 to float to avoid problem in onnx

* Update src/transformers/models/gptj/modeling_gptj.py
Co-authored-by: ChainYo <t.chaigneau.tc@gmail.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
parent bae9b645
...@@ -136,6 +136,7 @@ else: ...@@ -136,6 +136,7 @@ else:
("bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), ("bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
("openai-gpt", ("OpenAIGPTTokenizer", "OpenAIGPTTokenizerFast" if is_tokenizers_available() else None)), ("openai-gpt", ("OpenAIGPTTokenizer", "OpenAIGPTTokenizerFast" if is_tokenizers_available() else None)),
("gpt2", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)), ("gpt2", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
("gptj", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
("transfo-xl", ("TransfoXLTokenizer", None)), ("transfo-xl", ("TransfoXLTokenizer", None)),
( (
"xlnet", "xlnet",
......
...@@ -60,7 +60,9 @@ def fixed_pos_embedding(x, seq_dim=1, seq_len=None): ...@@ -60,7 +60,9 @@ def fixed_pos_embedding(x, seq_dim=1, seq_len=None):
if seq_len is None: if seq_len is None:
seq_len = x.shape[seq_dim] seq_len = x.shape[seq_dim]
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2) / dim)) inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2) / dim))
sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(seq_len), inv_freq).to(x.device).float() sinusoid_inp = (
torch.einsum("i , j -> i j", torch.arange(seq_len, dtype=torch.float), inv_freq).to(x.device).float()
)
return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp) return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)
......
...@@ -250,7 +250,7 @@ class FeaturesManager: ...@@ -250,7 +250,7 @@ class FeaturesManager:
"token-classification", "token-classification",
onnx_config_cls=GPT2OnnxConfig, onnx_config_cls=GPT2OnnxConfig,
), ),
"gpt-j": supported_features_mapping( "gptj": supported_features_mapping(
"default", "default",
"default-with-past", "default-with-past",
"causal-lm", "causal-lm",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment