Unverified Commit f56174ac authored by tanreinama's avatar tanreinama Committed by GitHub
Browse files

add GPTSAN model (reopen) (#21291)

* add GPTSAN-Japanese

* add GPTSAN

* add GPTSAN

* add GPTSAN

* add GPTSAN

* add GPTSAN

* add GPTSAN

* add GPTSAN

* add GPTSAN

* add GPTSAN

* add GPTSAN

* add GPTSAN

* add GPTSAN

* add GPTSAN

* add GPTSAN

* add GPTSAN

* add GPTSAN

* add GPTSAN

* add GPTSAN

* add GPTSAN (update for review)

* add GPTSAN

* add GPTSAN

* add GPTSAN

* add GPTSAN

* add GPTSAN

* add GPTSAN

* add GPTSAN

* add GPTSAN

* add GPTSAN

* add GPTSAN

* add GPTSAN

* add GPTSAN

* add GPTSAN

* add GPTSAN

* add GPTSAN

* add GPTSAN

* add GPTSAN

* add GPTSAN

* add GPTSAN

* add GPTSAN

* fix typo in comment text

* add GPTSAN

* add GPTSAN

* add GPTSAN

* add GPTSAN

* fix document and comments

* fix class name GPTSAN->GPTSan

* fix import and test for tokenizer
parent c87bbe1f
...@@ -90,6 +90,7 @@ MODEL_MAPPING_NAMES = OrderedDict( ...@@ -90,6 +90,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("gpt_neox", "GPTNeoXModel"), ("gpt_neox", "GPTNeoXModel"),
("gpt_neox_japanese", "GPTNeoXJapaneseModel"), ("gpt_neox_japanese", "GPTNeoXJapaneseModel"),
("gptj", "GPTJModel"), ("gptj", "GPTJModel"),
("gptsan-japanese", "GPTSanJapaneseForConditionalGeneration"),
("graphormer", "GraphormerModel"), ("graphormer", "GraphormerModel"),
("groupvit", "GroupViTModel"), ("groupvit", "GroupViTModel"),
("hubert", "HubertModel"), ("hubert", "HubertModel"),
...@@ -216,6 +217,7 @@ MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict( ...@@ -216,6 +217,7 @@ MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
("funnel", "FunnelForPreTraining"), ("funnel", "FunnelForPreTraining"),
("gpt-sw3", "GPT2LMHeadModel"), ("gpt-sw3", "GPT2LMHeadModel"),
("gpt2", "GPT2LMHeadModel"), ("gpt2", "GPT2LMHeadModel"),
("gptsan-japanese", "GPTSanJapaneseForConditionalGeneration"),
("ibert", "IBertForMaskedLM"), ("ibert", "IBertForMaskedLM"),
("layoutlm", "LayoutLMForMaskedLM"), ("layoutlm", "LayoutLMForMaskedLM"),
("longformer", "LongformerForMaskedLM"), ("longformer", "LongformerForMaskedLM"),
...@@ -286,6 +288,7 @@ MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict( ...@@ -286,6 +288,7 @@ MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
("gpt_neox", "GPTNeoXForCausalLM"), ("gpt_neox", "GPTNeoXForCausalLM"),
("gpt_neox_japanese", "GPTNeoXJapaneseForCausalLM"), ("gpt_neox_japanese", "GPTNeoXJapaneseForCausalLM"),
("gptj", "GPTJForCausalLM"), ("gptj", "GPTJForCausalLM"),
("gptsan-japanese", "GPTSanJapaneseForConditionalGeneration"),
("ibert", "IBertForMaskedLM"), ("ibert", "IBertForMaskedLM"),
("layoutlm", "LayoutLMForMaskedLM"), ("layoutlm", "LayoutLMForMaskedLM"),
("led", "LEDForConditionalGeneration"), ("led", "LEDForConditionalGeneration"),
...@@ -579,6 +582,7 @@ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict( ...@@ -579,6 +582,7 @@ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
("blenderbot-small", "BlenderbotSmallForConditionalGeneration"), ("blenderbot-small", "BlenderbotSmallForConditionalGeneration"),
("encoder-decoder", "EncoderDecoderModel"), ("encoder-decoder", "EncoderDecoderModel"),
("fsmt", "FSMTForConditionalGeneration"), ("fsmt", "FSMTForConditionalGeneration"),
("gptsan-japanese", "GPTSanJapaneseForConditionalGeneration"),
("led", "LEDForConditionalGeneration"), ("led", "LEDForConditionalGeneration"),
("longt5", "LongT5ForConditionalGeneration"), ("longt5", "LongT5ForConditionalGeneration"),
("m2m_100", "M2M100ForConditionalGeneration"), ("m2m_100", "M2M100ForConditionalGeneration"),
......
...@@ -154,6 +154,7 @@ else: ...@@ -154,6 +154,7 @@ else:
("gpt_neox", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)), ("gpt_neox", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
("gpt_neox_japanese", ("GPTNeoXJapaneseTokenizer", None)), ("gpt_neox_japanese", ("GPTNeoXJapaneseTokenizer", None)),
("gptj", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)), ("gptj", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
("gptsan-japanese", ("GPTSanJapaneseTokenizer", None)),
("groupvit", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)), ("groupvit", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
("herbert", ("HerbertTokenizer", "HerbertTokenizerFast" if is_tokenizers_available() else None)), ("herbert", ("HerbertTokenizer", "HerbertTokenizerFast" if is_tokenizers_available() else None)),
("hubert", ("Wav2Vec2CTCTokenizer", None)), ("hubert", ("Wav2Vec2CTCTokenizer", None)),
......
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_flax_available,
is_tf_available,
is_torch_available,
)
_import_structure = {
"configuration_gptsan_japanese": ["GPTSAN_JAPANESE_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTSanJapaneseConfig"],
"tokenization_gptsan_japanese": ["GPTSanJapaneseTokenizer"],
}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_gptsan_japanese"] = [
"GPTSAN_JAPANESE_PRETRAINED_MODEL_ARCHIVE_LIST",
"GPTSanJapaneseForConditionalGeneration",
"GPTSanJapaneseModel",
"GPTSanJapanesePreTrainedModel",
]
_import_structure["tokenization_gptsan_japanese"] = [
"GPTSanJapaneseTokenizer",
]
if TYPE_CHECKING:
from .configuration_gptsan_japanese import GPTSAN_JAPANESE_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTSanJapaneseConfig
from .tokenization_gptsan_japanese import GPTSanJapaneseTokenizer
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_gptsan_japanese import (
GPTSAN_JAPANESE_PRETRAINED_MODEL_ARCHIVE_LIST,
GPTSanJapaneseForConditionalGeneration,
GPTSanJapaneseModel,
GPTSanJapanesePreTrainedModel,
)
from .tokenization_gptsan_japanese import GPTSanJapaneseTokenizer
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
# coding=utf-8
# Copyright 2023, HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" GPTSAN-japanese model configuration"""
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
GPTSAN_JAPANESE_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"tanreinama/GPTSAN-2.8B-spout_is_uniform": (
"https://huggingface.co/tanreinama/GPTSAN-2.8B-spout_is_uniform/resolve/main/config.json"
),
}
class GPTSanJapaneseConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`GPTSanJapaneseModel`]. It is used to instantiate
a GPTSANJapanese model according to the specified arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a similar configuration to that of the GPTSANJapanese
[tanreinama/GPTSAN-2.8B-spout_is_uniform](https://huggingface.co/tanreinama/GPTSAN-2.8B-spout_is_uniform)
architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Arguments:
vocab_size (`int`, *optional*, defaults to 36000):
Vocabulary size of the GPTSANJapanese model. Defines the number of different tokens that can be represented
by the `inputs_ids` passed when calling [`GPTSanJapaneseModel`].
max_position_embeddings (`int`, *optional*, defaults to 1280):
The maximum sequence length that this model might ever be used with. Defaults set this to 1280.
d_model (`int`, *optional*, defaults to 1024):
Size of the encoder layers and the pooler layer.
d_ff (`int`, *optional*, defaults to 8192):
Size of the intermediate feed forward layer in each `SwitchTransformersBlock`.
d_ext (`int`, *optional*, defaults to 4096):
Size of the intermediate feed forward layer in each Extra-layers.
d_spout (`int`, *optional*, defaults to 128):
Size of the `spout` vector.
num_switch_layers (`int`, *optional*, defaults to 10):
Number of layers in the Switch Transformer layer.
num_ext_layers (`int`, *optional*, defaults to 0):
Number of layers in the Extra-layers.
num_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each attention layer in the Transformer encoder.
num_experts (`int`, *optional*, defaults to 16):
Number of experts for each SwitchTransformer layer.
expert_capacity (`int`, *optional*, defaults to 128):
Number of tokens that can be stored in each expert. If set to 1, the model will behave like a regular
Transformer.
dropout_rate (`float`, *optional*, defaults to 0.0):
The ratio for all dropout layers.
layer_norm_eps (`float`, *optional*, defaults to 1e-5):
The epsilon used by the layer normalization layers.
router_bias (`bool`, *optional*, defaults to `False`):
Whether to add a bias to the router.
router_jitter_noise (`float`, *optional*, defaults to 0.0):
Amount of noise to add to the router. Set it to 0.0 during prediction or set small value (usually 1e-2)
during training.
router_dtype (`str`, *optional*, default to `"float32"`):
The `dtype` used for the routers. It is preferable to keep the `dtype` to `"float32"` as specified in the
*selective precision* discussion in [the paper](https://arxiv.org/abs/2101.03961).
router_ignore_padding_tokens (`bool`, *optional*, defaults to `False`):
Whether to ignore padding tokens when routing.
output_hidden_states (`bool`, *optional*, default to `False`):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
output_attentions (`bool`, *optional*, defaults to `False`):
Whether or not to return the attentions tensors of all attention layers.
initializer_factor (`float`, *optional*, defaults to 0.002):
A factor for initializing all weight matrices.
output_router_logits (`bool`, *optional*, default to `False`):
Whether or not to return the router logits of all experts.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models)
"""
model_type = "gptsan-japanese"
keys_to_ignore_at_inference = [
"past_key_values",
]
attribute_map = {
"hidden_size": "d_model",
"num_attention_heads": "num_heads",
"num_hidden_layers": "num_layers",
}
def __init__(
self,
vocab_size=36000,
max_position_embeddings=1280,
d_model=1024,
d_ff=8192,
d_ext=4096,
d_spout=128,
num_switch_layers=10,
num_ext_layers=0,
num_heads=16,
num_experts=16,
expert_capacity=128,
dropout_rate=0.0,
layer_norm_epsilon=1e-5,
router_bias=False,
router_jitter_noise=0.0,
router_dtype="float32",
router_ignore_padding_tokens=False,
output_hidden_states=False,
output_attentions=False,
initializer_factor=0.002,
output_router_logits=False,
use_cache=True,
separator_token_id=35998,
pad_token_id=35995,
eos_token_id=35999,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.d_model = d_model
self.d_ff = d_ff
self.d_ext = d_ext
self.d_spout = d_spout
self.num_switch_layers = num_switch_layers
self.num_ext_layers = num_ext_layers
self.num_layers = num_switch_layers + num_ext_layers
self.num_heads = num_heads
self.num_experts = num_experts
self.expert_capacity = expert_capacity
self.dropout_rate = dropout_rate
self.layer_norm_epsilon = layer_norm_epsilon
self.router_bias = router_bias
self.router_jitter_noise = router_jitter_noise
self.router_dtype = router_dtype
self.router_ignore_padding_tokens = router_ignore_padding_tokens
self.output_hidden_states = output_hidden_states
self.output_attentions = output_attentions
self.initializer_factor = initializer_factor
self.output_router_logits = output_router_logits
self.use_cache = use_cache
super().__init__(
separator_token_id=separator_token_id,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
**kwargs,
)
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert GPTSANJapanese checkpoints from the original repository to pytorch model."""
import argparse
import json
import os
from collections import OrderedDict
import numpy as np
import tensorflow as tf
import torch
def convert_tf_gptsan_to_pt(args):
parameter_file = os.path.join(args.tf_model_dir, "parameters.json")
params = json.loads(open(parameter_file).read())
if not params:
raise ValueError(
f"It seems that the json file at {parameter_file} is empty. Make sure you have a correct json file."
)
if not args.output.endswith(".pt"):
args.output = args.output + ".pt"
new_state = OrderedDict()
with tf.device("/CPU:0"):
reader = tf.train.load_checkpoint(args.tf_model_dir)
shapes = reader.get_variable_to_shape_map()
for key_name in shapes.keys():
vnp = reader.get_tensor(key_name).astype(np.float16)
if key_name.endswith("/adam_m") or key_name.endswith("/adam_v"):
continue
if key_name.startswith("pasts/"):
if key_name.startswith("pasts/mlp"):
player = int(key_name[9])
elif key_name.startswith("pasts/out"):
player = 8
name = "model.sqout.%d.weight" % (player * 2) # enter to nn.Sequencial with Tanh, so 2 at a time
state = vnp.transpose([1, 0]).copy() # Mesh-Tensorflow is a diagonal matrix
new_state[name] = torch.tensor(state)
elif key_name.startswith("model/moe"):
player = int(key_name[9:].split("/")[0])
if key_name.endswith("/switch_gating/kernel"):
name = "model.blocks.%d.feed_forward.mlp.router.classifier.weight" % player
state = vnp.transpose([1, 0]).copy() # Mesh-Tensorflow is a diagonal matrix
new_state[name] = torch.tensor(state)
elif key_name.endswith("/softmlp/kernel"):
name = "model.blocks.%d.feed_forward.soft_bypass_mlp.weight" % player
state = vnp.transpose([1, 0]).copy() # Mesh-Tensorflow is a diagonal matrix
new_state[name] = torch.tensor(state)
elif key_name.endswith("/wo/kernel") or key_name.endswith("/wi/kernel"):
nlayer = key_name[-9:-7]
for i in range(16):
name = "model.blocks.%d.feed_forward.mlp.experts.expert_%d.%s.weight" % (player, i, nlayer)
state = (
vnp[i].transpose([1, 0]).copy()
) # In Mesh-Tensorflow, it is one array, so it is divided
new_state[name] = torch.tensor(state)
elif key_name.startswith("model/mlp"):
player = int(key_name[9:].split("/")[0])
if key_name.endswith("/p1/kernel"):
name = "model.blocks.%d.feed_forward.mlp.wi.weight" % player
state = vnp.transpose([1, 0]).copy() # Mesh-Tensorflow is a diagonal matrix
new_state[name] = torch.tensor(state)
elif key_name.endswith("/p1/bias"):
name = "model.blocks.%d.feed_forward.mlp.wi.bias" % player
state = vnp.copy() # same because it is one dimensional
new_state[name] = torch.tensor(state)
elif key_name.endswith("/p2/kernel"):
name = "model.blocks.%d.feed_forward.mlp.wo.weight" % player
state = vnp.transpose([1, 0]).copy() # Mesh-Tensorflow is a diagonal matrix
new_state[name] = torch.tensor(state)
elif key_name.endswith("/p2/bias"):
name = "model.blocks.%d.feed_forward.mlp.wo.bias" % player
state = vnp.copy() # same because it is one dimensional
new_state[name] = torch.tensor(state)
elif key_name.startswith("model/ln"):
player = int(key_name[8:].split("/")[0])
if key_name.endswith("/b"):
name = "model.blocks.%d.feed_forward.norm.bias" % player
state = vnp.copy() # same because it is one dimensional
new_state[name] = torch.tensor(state)
elif key_name.endswith("/g"):
name = "model.blocks.%d.feed_forward.norm.weight" % player
state = vnp.copy() # same because it is one dimensional
new_state[name] = torch.tensor(state)
elif key_name.startswith("model/att"):
player = int(key_name[9:].split("/")[0])
if key_name.endswith("/qkv/kernel"):
state = vnp.copy() # Compute same dimension as Mesh-tensorflow using einsum
state_q = state[:, 0, :, :]
state_k = state[:, 1, :, :]
state_v = state[:, 2, :, :]
state_q = (
state_q.reshape([state_q.shape[0], state_q.shape[1] * state_q.shape[2]])
.transpose([1, 0])
.copy()
) # Mesh-Tensorflow is a diagonal matrix
state_k = (
state_k.reshape([state_k.shape[0], state_k.shape[1] * state_k.shape[2]])
.transpose([1, 0])
.copy()
) # Mesh-Tensorflow is a diagonal matrix
state_v = (
state_v.reshape([state_v.shape[0], state_v.shape[1] * state_v.shape[2]])
.transpose([1, 0])
.copy()
) # Mesh-Tensorflow is a diagonal matrix
name = "model.blocks.%d.self_attn.self_attn.q_proj.weight" % player
new_state[name] = torch.tensor(state_q)
name = "model.blocks.%d.self_attn.self_attn.k_proj.weight" % player
new_state[name] = torch.tensor(state_k)
name = "model.blocks.%d.self_attn.self_attn.v_proj.weight" % player
new_state[name] = torch.tensor(state_v)
elif key_name.endswith("/o/kernel"):
name = "model.blocks.%d.self_attn.self_attn.out_proj.weight" % player
state = (
vnp.reshape([vnp.shape[0] * vnp.shape[1], vnp.shape[2]]).transpose([1, 0]).copy()
) # Mesh-Tensorflow is a diagonal matrix
new_state[name] = torch.tensor(state)
elif key_name.startswith("model/an"):
player = int(key_name[8:].split("/")[0])
if key_name.endswith("/b"):
name = "model.blocks.%d.self_attn.norm.bias" % player
state = vnp.copy() # same because it is one dimensional
new_state[name] = torch.tensor(state)
elif key_name.endswith("/g"):
name = "model.blocks.%d.self_attn.norm.weight" % player
state = vnp.copy() # same because it is one dimensional
new_state[name] = torch.tensor(state)
elif (
key_name.startswith("model/wte")
or key_name.startswith("model/wpe")
or key_name.startswith("model/ete")
):
nlayer = {"wte": "embed_tokens", "wpe": "position_embeddings", "ete": "extra_position_embeddings"}[
key_name[-3:]
]
name = "model.%s.weight" % nlayer
state = vnp.copy() # same in embedded
new_state[name] = torch.tensor(state)
if key_name.startswith("model/wte"):
name = "lm_head.weight"
state = vnp.copy() # same in embedded
new_state[name] = torch.tensor(state)
elif key_name.startswith("model/wob"):
name = "final_logits_bias"
state = vnp.copy() # same in embedded
state = state.reshape((1, -1))
new_state[name] = torch.tensor(state)
elif key_name == "model/dense/kernel":
name = "model.last_project.weight"
state = vnp.transpose([1, 0]).copy() # Mesh-Tensorflow is a diagonal matrix
new_state[name] = torch.tensor(state)
elif key_name == "model/dense_1/bias":
name = "model.last_project.bias"
state = vnp.copy() # same because it is one dimensional
new_state[name] = torch.tensor(state)
torch.save(new_state, args.output)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="model converter.", formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument("--tf_model_dir", metavar="PATH", type=str, required=True, help="import model")
parser.add_argument("--output", metavar="PATH", type=str, required=True, help="output model")
args = parser.parse_args()
convert_tf_gptsan_to_pt(args)
...@@ -3085,6 +3085,30 @@ class GPTJPreTrainedModel(metaclass=DummyObject): ...@@ -3085,6 +3085,30 @@ class GPTJPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["torch"]) requires_backends(self, ["torch"])
GPTSAN_JAPANESE_PRETRAINED_MODEL_ARCHIVE_LIST = None
class GPTSanJapaneseForConditionalGeneration(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class GPTSanJapaneseModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class GPTSanJapanesePreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
GRAPHORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None GRAPHORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None
......
# coding=utf-8
# Copyright 2023 Toshiyuki Sakamoto(tanreinama) and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from transformers import (
GPTSanJapaneseConfig,
GPTSanJapaneseForConditionalGeneration,
GPTSanJapaneseModel,
GPTSanJapaneseTokenizer,
is_torch_available,
)
from transformers.generation import GenerationConfig
from transformers.testing_utils import require_torch, slow, tooslow, torch_device
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor
class GPTSanJapaneseTester:
def __init__(
self,
parent,
vocab_size=36000,
batch_size=13,
num_contexts=7,
# For common tests
is_training=True,
hidden_size=32,
ext_size=42,
num_hidden_layers=5,
num_ext_layers=2,
num_attention_heads=4,
num_experts=2,
d_ff=32,
d_ext=80,
d_spout=33,
dropout_rate=0.0,
layer_norm_epsilon=1e-6,
expert_capacity=100,
router_jitter_noise=0.0,
):
self.vocab_size = vocab_size
self.parent = parent
self.batch_size = batch_size
self.num_contexts = num_contexts
# For common tests
self.seq_length = self.num_contexts
self.is_training = is_training
self.hidden_size = hidden_size
self.num_ext_layers = num_ext_layers
self.ext_size = ext_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_experts = num_experts
self.d_ff = d_ff
self.d_ext = d_ext
self.d_spout = d_spout
self.dropout_rate = dropout_rate
self.layer_norm_epsilon = layer_norm_epsilon
self.expert_capacity = expert_capacity
self.router_jitter_noise = router_jitter_noise
def get_large_model_config(self):
return GPTSanJapaneseConfig.from_pretrained("Tanrei/GPTSAN-japanese")
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
config = self.get_config()
return (config, input_ids)
def prepare_config_and_inputs_for_common(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
config = self.get_config()
return (config, {"input_ids": input_ids})
def get_config(self):
return GPTSanJapaneseConfig(
vocab_size=36000,
num_contexts=self.seq_length,
d_model=self.hidden_size,
d_ff=self.d_ff,
d_ext=self.d_ext,
d_spout=self.d_spout,
num_switch_layers=self.num_hidden_layers - self.num_ext_layers,
num_ext_layers=self.num_ext_layers,
num_heads=self.num_attention_heads,
num_experts=self.num_experts,
expert_capacity=self.expert_capacity,
dropout_rate=self.dropout_rate,
layer_norm_epsilon=self.layer_norm_epsilon,
router_jitter_noise=self.router_jitter_noise,
)
def create_and_check_model(
self,
config,
input_ids,
):
model = GPTSanJapaneseForConditionalGeneration(config=config)
model.to(torch_device)
model.eval()
result = model(
input_ids=input_ids,
)
self.parent.assertIsNotNone(result)
@require_torch
class GPTSanJapaneseTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (GPTSanJapaneseModel,) if is_torch_available() else ()
fx_compatible = False
is_encoder_decoder = False
test_pruning = False
test_headmasking = False
test_cpu_offload = False
test_disk_offload = False
test_save_load_fast_init_to_base = False
test_training = False
# The small GPTSAN_JAPANESE model needs higher percentages for CPU/MP tests
model_split_percents = [0.8, 0.9]
def setUp(self):
self.model_tester = GPTSanJapaneseTester(self)
self.config_tester = ConfigTester(self, config_class=GPTSanJapaneseConfig, d_model=37)
def test_config(self):
GPTSanJapaneseConfig()
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
@require_torch
class GPTSanJapaneseForConditionalGenerationTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (GPTSanJapaneseForConditionalGeneration,) if is_torch_available() else ()
fx_compatible = False
is_encoder_decoder = False
test_pruning = False
test_headmasking = False
test_cpu_offload = False
test_disk_offload = False
# The small GPTSAN_JAPANESE model needs higher percentages for CPU/MP tests
model_split_percents = [0.8, 0.9]
def setUp(self):
self.model_tester = GPTSanJapaneseTester(self)
self.config_tester = ConfigTester(self, config_class=GPTSanJapaneseConfig, d_model=37)
def test_config(self):
GPTSanJapaneseConfig()
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
@slow
def test_logits(self):
model = GPTSanJapaneseForConditionalGeneration.from_pretrained("Tanrei/GPTSAN-japanese")
tokenizer = GPTSanJapaneseTokenizer.from_pretrained("Tanrei/GPTSAN-japanese")
input_ids = tokenizer.encode("武田信玄は", return_tensors="pt")
outputs = model(input_ids)
output_logits = outputs.logits.detach().cpu().numpy()
# Output of original model created with mesh-tensoflow
target = [
# fmt: off
[-12.037839889526367, -12.433061599731445, -14.333840370178223, -12.450345993041992, -11.1661376953125,
-11.930137634277344, -10.659740447998047, -12.909574508666992, -13.241043090820312, -13.398579597473145,
-11.107524871826172, -12.3685941696167, -22.97943115234375, -10.481067657470703, -12.484030723571777,
-12.807360649108887, -14.769700050354004, -12.233579635620117, -13.428145408630371, -22.624177932739258],
[-7.511149883270264, -8.281851768493652, -7.943127155303955, -7.55021333694458, -6.49869966506958,
-7.586796283721924, -6.978085994720459, -7.839145183563232, -8.21964168548584, -8.695091247558594,
-6.706910610198975, -6.6585798263549805, -19.565698623657227, -5.353842735290527, -8.350686073303223,
-8.039388656616211, -10.856569290161133, -7.75154447555542, -8.819022178649902, -19.51532745361328],
[-9.73066234588623, -10.223922729492188, -9.932981491088867, -11.857836723327637, -7.662626266479492,
-11.13529109954834, -7.765097618103027, -11.472923278808594, -9.543149948120117, -11.905633926391602,
-9.366164207458496, -11.5734281539917, -23.699003219604492, -9.429590225219727, -10.42839241027832,
-10.585240364074707, -10.94771957397461, -11.095416069030762, -10.390240669250488, -23.769372940063477],
[-9.728265762329102, -9.859712600708008, -10.09729290008545, -9.678522109985352, -6.879519939422607,
-9.68487548828125, -4.2803425788879395, -10.018914222717285, -9.308445930480957, -10.63394546508789,
-8.083646774291992, -9.06301498413086, -21.904266357421875, -8.90160846710205, -8.841876029968262,
-11.856719970703125, -12.079398155212402, -11.233753204345703, -10.177338600158691, -21.87256622314453],
[-9.669764518737793, -9.614198684692383, -9.814510345458984, -9.996501922607422, -11.375690460205078,
-10.113405227661133, -10.546867370605469, -10.04369068145752, -10.907809257507324, -10.504216194152832,
-11.129199028015137, -10.151124000549316, -21.96586799621582, -9.086349487304688, -11.730339050292969,
-10.460667610168457, -10.298049926757812, -10.784148216247559, -10.840693473815918, -22.03152847290039],
# fmt: on
]
target = np.array(target).flatten()
predict = output_logits[0, :, :20].flatten()
def check(a, b, epsilon=5e-4):
return abs(a - b) < epsilon * max(abs(a), abs(b))
self.assertTrue(np.all([check(target[i], predict[i]) for i in range(len(target))]))
@slow
def test_batch_generation(self):
model = GPTSanJapaneseForConditionalGeneration.from_pretrained("Tanrei/GPTSAN-japanese")
tokenizer = GPTSanJapaneseTokenizer.from_pretrained("Tanrei/GPTSAN-japanese")
model.to(torch_device)
# set deterministically
generation_config = GenerationConfig.from_pretrained("Tanrei/GPTSAN-japanese")
generation_config.top_k = 1
# use different length sentences to test batching
sentences = [
"甲斐なら武田と言うほど",
"織田信長は、",
]
tokenizer.padding_side = "left"
inputs = tokenizer(sentences, return_tensors="pt", padding=True)
input_ids = inputs["input_ids"].to(torch_device)
self.assertNotEqual(inputs["attention_mask"][0].numpy().tolist(), inputs["attention_mask"][1].numpy().tolist())
outputs = model.generate(
input_ids=input_ids,
attention_mask=inputs["attention_mask"].to(torch_device),
max_new_tokens=3,
generation_config=generation_config,
)
inputs_non_padded = tokenizer(sentences[0], return_tensors="pt").input_ids.to(torch_device)
output_non_padded = model.generate(
input_ids=inputs_non_padded, max_new_tokens=3, generation_config=generation_config
)
inputs_padded = tokenizer(sentences[1], return_tensors="pt").input_ids.to(torch_device)
output_padded = model.generate(input_ids=inputs_padded, max_new_tokens=3, generation_config=generation_config)
self.assertNotEqual(inputs_non_padded.shape, inputs_padded.shape)
batch_out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True)
non_padded_sentence = tokenizer.decode(output_non_padded[0], skip_special_tokens=True)
padded_sentence = tokenizer.decode(output_padded[0], skip_special_tokens=True)
expected_output_sentence = [
"甲斐なら武田と言うほど甲斐の武田",
"織田信長は、このような",
]
self.assertListEqual(expected_output_sentence, batch_out_sentence)
self.assertListEqual(batch_out_sentence, [non_padded_sentence, padded_sentence])
@tooslow
def test_sample(self):
model = GPTSanJapaneseForConditionalGeneration.from_pretrained("Tanrei/GPTSAN-japanese")
tokenizer = GPTSanJapaneseTokenizer.from_pretrained("Tanrei/GPTSAN-japanese")
# Output of original model created with mesh-tensoflow
target = [
("武田信玄は", 35675),
("武田信玄は、", 45),
("武田信玄は、この", 29),
("武田信玄は、このよう", 30642),
("武田信玄は、このような", 35680),
("武田信玄は、このような「", 8640),
("武田信玄は、このような「武田", 31617),
("武田信玄は、このような「武田家", 30646),
("武田信玄は、このような「武田家の", 31617),
("武田信玄は、このような「武田家の家", 31381),
]
for input, output in target:
input_ids = tokenizer.encode(input, return_tensors="pt")
outputs = model(input_ids)
output_logits = outputs.logits.detach().cpu().numpy()[0]
output_id = np.argmax(output_logits[-1])
self.assertEqual(output_id, output)
@slow
def test_spout_generation(self):
model = GPTSanJapaneseForConditionalGeneration.from_pretrained("Tanrei/GPTSAN-japanese")
tokenizer = GPTSanJapaneseTokenizer.from_pretrained("Tanrei/GPTSAN-japanese")
model.to(torch_device)
# set deterministically
generation_config = GenerationConfig.from_pretrained("Tanrei/GPTSAN-japanese")
generation_config.top_k = 1
input_text = "武田信玄は、"
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(torch_device)
input_ids_batch = tokenizer([input_text, input_text], return_tensors="pt").input_ids.to(torch_device)
# spout from uniform and one-hot
spouts = [
# fmt: off
[0.87882208, 0.38426396, 0.33220248, 0.43890406, 0.16562252,
0.04803985, 0.211572 , 0.23188473, 0.37153068, 0.7836377 ,
0.02160172, 0.38761719, 0.75290772, 0.90198857, 0.34365777,
0.64168169, 0.44318471, 0.14575746, 0.92562881, 0.40812148,
0.29019122, 0.88861599, 0.65524846, 0.43563456, 0.38177187,
0.70832965, 0.81527892, 0.68832812, 0.38833192, 0.4561522 ,
0.14828817, 0.47248213, 0.54357335, 0.82009566, 0.1338884 ,
0.02755417, 0.19764677, 0.2422084 , 0.04757674, 0.65409606,
0.0824589 , 0.03304383, 0.94387689, 0.98764509, 0.82433901,
0.27646741, 0.64907493, 0.76009406, 0.30087915, 0.17904689,
0.41601714, 0.67046398, 0.10422822, 0.08447374, 0.07354344,
0.61423565, 0.70284866, 0.7532333 , 0.1972038 , 0.29575659,
0.90583886, 0.29265307, 0.50000175, 0.70407655, 0.889363 ,
0.81904418, 0.66829128, 0.64468815, 0.56563723, 0.85601875,
0.94924672, 0.00166762, 0.25220643, 0.74540219, 0.67993247,
0.1549675 , 0.39385352, 0.92153607, 0.63745931, 0.27759043,
0.84702295, 0.65904271, 0.58676614, 0.8666936 , 0.39607438,
0.79954983, 0.42220697, 0.39650381, 0.7849864 , 0.56150201,
0.15678925, 0.14746032, 0.34542114, 0.47026783, 0.11956489,
0.25421435, 0.33788901, 0.68934842, 0.36424685, 0.71737898,
0.38983449, 0.94393779, 0.39575588, 0.36616553, 0.87104665,
0.64630203, 0.22516905, 0.88270804, 0.15031338, 0.75144345,
0.46459025, 0.85396454, 0.86355643, 0.65139851, 0.70266061,
0.30241389, 0.81056497, 0.88865969, 0.38773807, 0.70635849,
0.90718459, 0.43245789, 0.28000654, 0.45935562, 0.08773519,
0.9552151 , 0.93901511, 0.22489288], # uniform
[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0.],
# fmt: on
]
output1 = model.generate(
input_ids=input_ids,
spout=spouts[0],
max_new_tokens=20,
generation_config=generation_config,
)
output2 = model.generate(
input_ids=input_ids,
spout=spouts[1],
max_new_tokens=20,
generation_config=generation_config,
)
output3 = model.generate(
input_ids=input_ids_batch,
spout=spouts,
max_new_tokens=20,
generation_config=generation_config,
)
out1_sentence = tokenizer.decode(output1[0])
out2_sentence = tokenizer.decode(output2[0])
batch_out_sentence = tokenizer.batch_decode(output3)
expected_output_sentence = [
"武田信玄は、武田氏の滅亡後、武田氏の居城であった甲斐武田氏の居城である",
"武田信玄は、武田家の滅亡を防ぐため、武田家の家臣である武田信虎を討",
]
self.assertListEqual(expected_output_sentence, batch_out_sentence)
self.assertListEqual(batch_out_sentence, [out1_sentence, out2_sentence])
@slow
def test_prefix_lm_generation(self):
model = GPTSanJapaneseForConditionalGeneration.from_pretrained("Tanrei/GPTSAN-japanese")
tokenizer = GPTSanJapaneseTokenizer.from_pretrained("Tanrei/GPTSAN-japanese")
model.to(torch_device)
# set deterministically
generation_config = GenerationConfig.from_pretrained("Tanrei/GPTSAN-japanese")
generation_config.top_k = 1
prefix_text_1 = "武田信玄"
prefix_text_2 = "織田信長"
input_text_1 = "は、"
input_text_2 = "が、"
input_tok_1 = tokenizer(input_text_1, prefix_text=prefix_text_1, return_tensors="pt")
input_tok_2 = tokenizer(input_text_2, prefix_text=prefix_text_2, return_tensors="pt")
input_tok_3 = tokenizer([[prefix_text_1, input_text_1], [prefix_text_2, input_text_2]], return_tensors="pt")
output1 = model.generate(
input_ids=input_tok_1.input_ids.to(torch_device),
token_type_ids=input_tok_1.token_type_ids.to(torch_device),
max_new_tokens=20,
generation_config=generation_config,
)
output2 = model.generate(
input_ids=input_tok_2.input_ids.to(torch_device),
token_type_ids=input_tok_2.token_type_ids.to(torch_device),
max_new_tokens=20,
generation_config=generation_config,
)
output3 = model.generate(
input_ids=input_tok_3.input_ids.to(torch_device),
token_type_ids=input_tok_3.token_type_ids.to(torch_device),
attention_mask=input_tok_3.attention_mask.to(torch_device),
max_new_tokens=20,
generation_config=generation_config,
)
out1_sentence = tokenizer.decode(output1[0])
out2_sentence = tokenizer.decode(output2[0])
batch_out_sentence = tokenizer.batch_decode(output3)
expected_output_sentence = [
"武田信玄は、武田氏の祖である武田信虎を、その子・武田信友を擁して",
"織田信長が、織田信長の妻・お市の方を妻として迎えたという逸話が残",
]
self.assertListEqual(expected_output_sentence, batch_out_sentence)
self.assertListEqual(batch_out_sentence, [out1_sentence, out2_sentence])
# coding=utf-8
# Copyright 2023 Toshiyuki Sakamoto(tanreinama) and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import unittest
from transformers.models.gptsan_japanese.tokenization_gptsan_japanese import (
VOCAB_FILES_NAMES,
GPTSanJapaneseTokenizer,
)
from transformers.testing_utils import require_tokenizers, slow
from ...test_tokenization_common import TokenizerTesterMixin
@require_tokenizers
class GPTSanJapaneseTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer_class = GPTSanJapaneseTokenizer
test_rust_tokenizer = False
from_pretrained_kwargs = {"do_clean_text": False, "add_prefix_space": False}
def setUp(self):
super().setUp()
# fmt: off
vocab_tokens = ["こん", "こんに", "にちは", "ばんは", "世界,㔺界", "、", "。", "<BR>", "<SP>", "<TAB>", "<URL>", "<EMAIL>", "<TEL>", "<DATE>", "<PRICE>", "<BLOCK>", "<KIGOU>", "<U2000U2BFF>", "<|emoji1|>", "<unk>", "<|bagoftoken|>", "<|endoftext|>"]
# fmt: on
emoji_tokens = {"emoji": {"\ud83d\ude00": "<|emoji1|>"}, "emoji_inv": {"<|emoji1|>": "\ud83d\ude00"}} # 😀
self.special_tokens_map = {"unk_token": "<unk>"}
self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
self.emoji_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["emoji_file"])
with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
with open(self.emoji_file, "w") as emoji_writer:
emoji_writer.write(json.dumps(emoji_tokens))
def get_tokenizer(self, **kwargs):
kwargs.update(self.special_tokens_map)
return GPTSanJapaneseTokenizer.from_pretrained(self.tmpdirname, **kwargs)
# Copied from tests.models.gpt_neox_japanese.test_tokenization_gpt_neox_japanese.GPTNeoXJapaneseTokenizationTest.get_input_output_texts
def get_input_output_texts(self, tokenizer):
input_text = "こんにちは、世界。 \nこんばんは、㔺界。😀"
output_text = "こんにちは、世界。 \nこんばんは、世界。😀"
return input_text, output_text
# Copied from tests.models.gpt_neox_japanese.test_tokenization_gpt_neox_japanese.GPTNeoXJapaneseTokenizationTest.get_clean_sequence
def get_clean_sequence(self, tokenizer):
input_text, output_text = self.get_input_output_texts(tokenizer)
ids = tokenizer.encode(output_text, add_special_tokens=False)
text = tokenizer.decode(ids, clean_up_tokenization_spaces=False)
return text, ids
# Copied from tests.models.gpt_neox_japanese.test_tokenization_gpt_neox_japanese.GPTNeoXJapaneseTokenizationTest.test_pretokenized_inputs
def test_pretokenized_inputs(self):
pass # TODO add if relevant
# Copied from tests.models.gpt_neox_japanese.test_tokenization_gpt_neox_japanese.GPTNeoXJapaneseTokenizationTest.test_maximum_encoding_length_pair_input
def test_maximum_encoding_length_pair_input(self):
pass # TODO add if relevant
# Copied from tests.models.gpt_neox_japanese.test_tokenization_gpt_neox_japanese.GPTNeoXJapaneseTokenizationTest.test_maximum_encoding_length_single_input
def test_maximum_encoding_length_single_input(self):
pass # TODO add if relevant
# Copied from tests.models.gpt_neox_japanese.test_tokenization_gpt_neox_japanese.GPTNeoXJapaneseTokenizationTest.test_full_tokenizer
def test_full_tokenizer(self):
tokenizer = self.get_tokenizer()
# Testing tokenization
input_text = "こんにちは、世界。 こんばんは、㔺界。"
expected_token = ["こん", "にちは", "、", "世界", "。", "<SP>", "こん", "ばんは", "、", "㔺界", "。"]
tokens = tokenizer.tokenize(input_text)
self.assertListEqual(tokens, expected_token)
# Testing conversion to ids without special tokens
expected_ids = [0, 2, 5, 4, 6, 8, 0, 3, 5, 4, 6]
input_ids = tokenizer.convert_tokens_to_ids(tokens)
self.assertListEqual(input_ids, expected_ids)
# Testing conversion to ids with special tokens
input_tokens = tokens + [tokenizer.unk_token]
expected_ids = [0, 2, 5, 4, 6, 8, 0, 3, 5, 4, 6, 19]
input_ids = tokenizer.convert_tokens_to_ids(input_tokens)
self.assertListEqual(input_ids, expected_ids)
def test_token_bagging(self):
tokenizer = self.get_tokenizer()
# Testing tokenization
input_text = "こんにちは、<|bagoftoken|>世界。こんばんは、<|bagoftoken|>㔺界。"
expected_text = "こんにちは、、、、世界。こんばんは、、、、世界。"
tokens = tokenizer.encode(input_text)
output_text = tokenizer.decode(tokens)
self.assertEqual(output_text, expected_text)
@slow
def test_prefix_input(self):
tokenizer = self.tokenizer_class.from_pretrained("Tanrei/GPTSAN-japanese")
# Testing tokenization
prefix_text = "こんにちは、世界。"
input_text = "こんばんは、㔺界。😀"
expected_text = "こんにちは、世界。こんばんは、世界。😀"
tokens_1 = tokenizer.encode(prefix_text + input_text)
tokens_2 = tokenizer.encode("", prefix_text=prefix_text + input_text)
tokens_3 = tokenizer.encode(input_text, prefix_text=prefix_text)
output_text_1 = tokenizer.decode(tokens_1)
output_text_2 = tokenizer.decode(tokens_2)
output_text_3 = tokenizer.decode(tokens_3)
self.assertEqual(output_text_1, expected_text)
self.assertEqual(output_text_2, expected_text)
self.assertEqual(output_text_3, expected_text)
@slow
def test_token_type_ids(self):
tokenizer = self.tokenizer_class.from_pretrained("Tanrei/GPTSAN-japanese")
# Testing tokenization
prefix_text = "こんにちは、世界。"
input_text = "こんばんは、㔺界。😀"
len_prefix = len(tokenizer.encode(prefix_text)) - 2
len_text = len(tokenizer.encode(input_text)) - 2
expected_mask_1 = [1] + [0] * (len_prefix + len_text + 1)
expected_mask_2 = [1] * (len_prefix + len_text + 1) + [0]
expected_mask_3 = [1] + [1] * (len_prefix) + [0] * (len_text + 1)
type_id_1 = tokenizer(prefix_text + input_text).token_type_ids
type_id_2 = tokenizer("", prefix_text=prefix_text + input_text).token_type_ids
type_id_3 = tokenizer(input_text, prefix_text=prefix_text).token_type_ids
self.assertListEqual(type_id_1, expected_mask_1)
self.assertListEqual(type_id_2, expected_mask_2)
self.assertListEqual(type_id_3, expected_mask_3)
@slow
def test_prefix_tokens(self):
tokenizer = self.tokenizer_class.from_pretrained("Tanrei/GPTSAN-japanese")
x_token_1 = tokenizer.encode("あンいワ")
x_token_2 = tokenizer.encode("", prefix_text="あンいワ")
x_token_3 = tokenizer.encode("いワ", prefix_text="あン")
self.assertEqual(tokenizer.decode(x_token_1), tokenizer.decode(x_token_2))
self.assertEqual(tokenizer.decode(x_token_1), tokenizer.decode(x_token_3))
self.assertNotEqual(x_token_1, x_token_2)
self.assertNotEqual(x_token_1, x_token_3)
self.assertEqual(x_token_1[1], x_token_2[-1]) # SEG token
self.assertEqual(x_token_1[1], x_token_3[3]) # SEG token
@slow
def test_batch_encode(self):
tokenizer = self.tokenizer_class.from_pretrained("Tanrei/GPTSAN-japanese")
input_pairs = [["武田信玄", "は、"], ["織田信長", "の配下の、"]]
x_token = tokenizer(input_pairs, padding=True)
x_token_2 = tokenizer.batch_encode_plus(input_pairs, padding=True)
# fmt: off
expected_outputs = [[35993, 8640, 25948, 35998, 30647, 35675, 35999, 35999], [35993, 10382, 9868, 35998, 30646, 9459, 30646, 35675]]
expected_typeids = [[1, 1, 1, 0, 0, 0, 0, 0], [1, 1, 1, 0, 0, 0, 0, 0]]
expected_attmask = [[1, 1, 1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1]]
# fmt: on
self.assertListEqual(x_token.input_ids, expected_outputs)
self.assertListEqual(x_token.token_type_ids, expected_typeids)
self.assertListEqual(x_token.attention_mask, expected_attmask)
self.assertListEqual(x_token_2.input_ids, expected_outputs)
self.assertListEqual(x_token_2.token_type_ids, expected_typeids)
self.assertListEqual(x_token_2.attention_mask, expected_attmask)
# Copied from tests.models.gpt_neox_japanese.test_tokenization_gpt_neox_japanese.GPTNeoXJapaneseTokenizationTest.test_conversion_reversible
def test_conversion_reversible(self):
# Intentionally convert some words to accommodate character fluctuations unique to Japanese
pass
# Copied from tests.models.gpt_neox_japanese.test_tokenization_gpt_neox_japanese.GPTNeoXJapaneseTokenizationTest.test_padding_different_model_input_name
def test_padding_different_model_input_name(self):
# tokenizer has no padding token
pass
...@@ -198,6 +198,7 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [ ...@@ -198,6 +198,7 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
"CLIPSegVisionModel", "CLIPSegVisionModel",
"CLIPSegTextModel", "CLIPSegTextModel",
"EsmForProteinFolding", "EsmForProteinFolding",
"GPTSanJapaneseModel",
"TimeSeriesTransformerForPrediction", "TimeSeriesTransformerForPrediction",
"JukeboxVQVAE", "JukeboxVQVAE",
"JukeboxPrior", "JukeboxPrior",
......
...@@ -92,6 +92,7 @@ src/transformers/models/glpn/modeling_glpn.py ...@@ -92,6 +92,7 @@ src/transformers/models/glpn/modeling_glpn.py
src/transformers/models/gpt2/configuration_gpt2.py src/transformers/models/gpt2/configuration_gpt2.py
src/transformers/models/gpt2/modeling_gpt2.py src/transformers/models/gpt2/modeling_gpt2.py
src/transformers/models/gptj/modeling_gptj.py src/transformers/models/gptj/modeling_gptj.py
src/transformers/models/gptsan_japanese/modeling_gptsan_japanese.py
src/transformers/models/gpt_neo/configuration_gpt_neo.py src/transformers/models/gpt_neo/configuration_gpt_neo.py
src/transformers/models/gpt_neox/configuration_gpt_neox.py src/transformers/models/gpt_neox/configuration_gpt_neox.py
src/transformers/models/gpt_neox_japanese/configuration_gpt_neox_japanese.py src/transformers/models/gpt_neox_japanese/configuration_gpt_neox_japanese.py
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment