Unverified Commit 0b1e0fcf authored by Thomas Chaigneau, committed by GitHub

Fix GPT-J onnx conversion (#16780)



* add gptj to TOKENIZER_MAPPING_NAMES

* fix int32 to float to avoid problem in onnx

* Update src/transformers/models/gptj/modeling_gptj.py
Co-authored-by: ChainYo <t.chaigneau.tc@gmail.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
parent bae9b645
@@ -136,6 +136,7 @@ else:
         ("bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
         ("openai-gpt", ("OpenAIGPTTokenizer", "OpenAIGPTTokenizerFast" if is_tokenizers_available() else None)),
         ("gpt2", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
+        ("gptj", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
         ("transfo-xl", ("TransfoXLTokenizer", None)),
         (
             "xlnet",
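Note: GPT-J reuses GPT-2's BPE tokenizer, so mapping the `gptj` model type to `GPT2Tokenizer`/`GPT2TokenizerFast` lets `AutoTokenizer` resolve GPT-J checkpoints, which the ONNX export path relies on to load a preprocessor. A minimal usage sketch, assuming network access and the public `EleutherAI/gpt-j-6B` checkpoint:

```python
# Illustrative sketch: with "gptj" registered in TOKENIZER_MAPPING_NAMES,
# AutoTokenizer resolves a GPT-J checkpoint to the GPT-2 BPE tokenizer.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
print(type(tokenizer).__name__)  # GPT2TokenizerFast when the tokenizers library is installed
```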
@@ -60,7 +60,9 @@ def fixed_pos_embedding(x, seq_dim=1, seq_len=None):
     if seq_len is None:
         seq_len = x.shape[seq_dim]
     inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2) / dim))
-    sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(seq_len), inv_freq).to(x.device).float()
+    sinusoid_inp = (
+        torch.einsum("i , j -> i j", torch.arange(seq_len, dtype=torch.float), inv_freq).to(x.device).float()
+    )
     return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)
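Note: the old line built the position indices with `torch.arange(seq_len)`, i.e. an integer tensor; per the commit message, that integer input caused problems when the rotary embedding was traced for ONNX, so the patch creates the indices as float before the `einsum`. A standalone sketch of the patched computation (the `dim` and `seq_len` values below are arbitrary example sizes, not taken from the model):

```python
import torch

# Illustrative sketch of the patched rotary position computation:
# the indices are created as float up front, so the traced graph only sees float inputs.
dim, seq_len = 64, 8  # arbitrary example sizes
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2) / dim))
pos = torch.arange(seq_len, dtype=torch.float)  # previously torch.arange(seq_len) -> int64
sinusoid_inp = torch.einsum("i , j -> i j", pos, inv_freq).float()
sin, cos = torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)
print(sin.shape, cos.shape)  # torch.Size([8, 32]) torch.Size([8, 32])
```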
@@ -250,7 +250,7 @@ class FeaturesManager:
             "token-classification",
             onnx_config_cls=GPT2OnnxConfig,
         ),
-        "gpt-j": supported_features_mapping(
+        "gptj": supported_features_mapping(
            "default",
            "default-with-past",
            "causal-lm",
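Note: the export machinery looks entries up in this table by the model's `model_type`, and GPT-J's `model_type` is `"gptj"`, so the old `"gpt-j"` key was never matched. A hedged sanity check, assuming `FeaturesManager.get_supported_features_for_model_type` is available in this version of the ONNX utilities:

```python
# Illustrative check (assumes transformers with the ONNX export utilities installed):
# the features table is keyed by model_type, and GPT-J's model_type is "gptj",
# so the old "gpt-j" key could not be resolved during export.
from transformers import GPTJConfig
from transformers.onnx.features import FeaturesManager

print(GPTJConfig.model_type)  # "gptj"
print(FeaturesManager.get_supported_features_for_model_type("gptj").keys())
```

With both changes in place, a typical exporter invocation such as `python -m transformers.onnx --model=EleutherAI/gpt-j-6B --feature=causal-lm output/` (shown as an example of the CLI's documented usage, not taken from this PR) should pick up the GPT-J ONNX config and tokenizer.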