Unverified commit 626a0a01, authored by yujun, committed by GitHub

[RoFormer] Fix some issues (#12397)



* add RoFormerTokenizerFast into AutoTokenizer

* fix typo in roformer docs

* make onnx export happy

* update RoFormerConfig embedding_size

* use jieba not rjieba

* fix #12244 and make test_alignement pass

* update ARCHIVE_MAP

* make style & quality & fixup

* update

* make style & quality & fixup

* make style quality fixup

* update

* suggestion from LysandreJik
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* make style

* use rjieba
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
parent f5b0c1ec
@@ -56,7 +56,7 @@ RoFormerTokenizer
     create_token_type_ids_from_sequences, save_vocabulary

-RobertaTokenizerFast
+RoFormerTokenizerFast
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

 .. autoclass:: transformers.RoFormerTokenizerFast
...
@@ -315,6 +315,10 @@ def is_datasets_available():
     return _datasets_available


+def is_rjieba_available():
+    return importlib.util.find_spec("rjieba") is not None
+
+
 def is_psutil_available():
     return importlib.util.find_spec("psutil") is not None
...
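The new helper mirrors the other `is_*_available` guards in file_utils.py. A minimal sketch of how such a guard is typically consumed downstream (illustrative only, not part of this diff):

    from transformers.file_utils import is_rjieba_available

    if is_rjieba_available():
        import rjieba  # safe: find_spec confirmed the module can be imported
    else:
        rjieba = None  # callers must handle the missing backend themselves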
@@ -198,6 +198,7 @@ if is_tokenizers_available():
     from ..reformer.tokenization_reformer_fast import ReformerTokenizerFast
     from ..retribert.tokenization_retribert_fast import RetriBertTokenizerFast
     from ..roberta.tokenization_roberta_fast import RobertaTokenizerFast
+    from ..roformer.tokenization_roformer_fast import RoFormerTokenizerFast
     from ..squeezebert.tokenization_squeezebert_fast import SqueezeBertTokenizerFast
     from ..t5.tokenization_t5_fast import T5TokenizerFast
     from ..xlm_roberta.tokenization_xlm_roberta_fast import XLMRobertaTokenizerFast
@@ -232,6 +233,7 @@ else:
     ReformerTokenizerFast = None
     RetriBertTokenizerFast = None
     RobertaTokenizerFast = None
+    RoFormerTokenizerFast = None
     SqueezeBertTokenizerFast = None
     T5TokenizerFast = None
     XLMRobertaTokenizerFast = None
@@ -245,7 +247,7 @@ logger = logging.get_logger(__name__)
 TOKENIZER_MAPPING = OrderedDict(
     [
         (RetriBertConfig, (RetriBertTokenizer, RetriBertTokenizerFast)),
-        (RoFormerConfig, (RoFormerTokenizer, None)),
+        (RoFormerConfig, (RoFormerTokenizer, RoFormerTokenizerFast)),
         (T5Config, (T5Tokenizer, T5TokenizerFast)),
         (MT5Config, (MT5Tokenizer, MT5TokenizerFast)),
         (MobileBertConfig, (MobileBertTokenizer, MobileBertTokenizerFast)),
...
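With the mapping entry filled in, AutoTokenizer can now hand back the fast tokenizer for RoFormer checkpoints. A quick sketch (the model id comes from the archive lists further down; the comment shows the expected type, not a captured log):

    from transformers import AutoTokenizer

    # Previously (RoFormerTokenizer, None) forced the slow tokenizer; the new
    # entry lets use_fast=True resolve to RoFormerTokenizerFast.
    tokenizer = AutoTokenizer.from_pretrained("junnyu/roformer_chinese_base", use_fast=True)
    print(type(tokenizer).__name__)  # RoFormerTokenizerFast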
@@ -22,7 +22,11 @@ logger = logging.get_logger(__name__)

 ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = {
     "junnyu/roformer_chinese_small": "https://huggingface.co/junnyu/roformer_chinese_small/resolve/main/config.json",
-    "junnyu/roformer_chinese_base": "https://huggingface.co/junnyu/roformer_chinese_base/resolve/main/config.json"
+    "junnyu/roformer_chinese_base": "https://huggingface.co/junnyu/roformer_chinese_base/resolve/main/config.json",
+    "junnyu/roformer_chinese_char_small": "https://huggingface.co/junnyu/roformer_chinese_char_small/resolve/main/config.json",
+    "junnyu/roformer_chinese_char_base": "https://huggingface.co/junnyu/roformer_chinese_char_base/resolve/main/config.json",
+    "junnyu/roformer_small_discriminator": "https://huggingface.co/junnyu/roformer_small_discriminator/resolve/main/config.json",
+    "junnyu/roformer_small_generator": "https://huggingface.co/junnyu/roformer_small_generator/resolve/main/config.json",
     # See all RoFormer models at https://huggingface.co/models?filter=roformer
 }
@@ -43,8 +47,9 @@ class RoFormerConfig(PretrainedConfig):
             Vocabulary size of the RoFormer model. Defines the number of different tokens that can be represented by
             the :obj:`inputs_ids` passed when calling :class:`~transformers.RoFormerModel` or
             :class:`~transformers.TFRoFormerModel`.
-        embedding_size (:obj:`int`, `optional`, defaults to 768):
-            Dimensionality of the encoder layers and the pooler layer.
+        embedding_size (:obj:`int`, `optional`, defaults to :obj:`None`):
+            Dimensionality of the embeddings. Defaults to :obj:`hidden_size` if not provided.
         hidden_size (:obj:`int`, `optional`, defaults to 768):
             Dimension of the encoder layers and the pooler layer.
         num_hidden_layers (:obj:`int`, `optional`, defaults to 12):
@@ -96,7 +101,7 @@ class RoFormerConfig(PretrainedConfig):
     def __init__(
         self,
         vocab_size=50000,
-        embedding_size=768,
+        embedding_size=None,
         hidden_size=768,
         num_hidden_layers=12,
         num_attention_heads=12,
@@ -117,7 +122,7 @@ class RoFormerConfig(PretrainedConfig):
         super().__init__(pad_token_id=pad_token_id, **kwargs)
         self.vocab_size = vocab_size
-        self.embedding_size = embedding_size
+        self.embedding_size = hidden_size if embedding_size is None else embedding_size
         self.hidden_size = hidden_size
         self.num_hidden_layers = num_hidden_layers
         self.num_attention_heads = num_attention_heads
...
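The `hidden_size if embedding_size is None else embedding_size` fallback keeps existing configs working while letting ELECTRA-style checkpoints use a narrower embedding than the hidden states. A sketch of both paths (the 256/128 values are illustrative, not taken from any checkpoint):

    from transformers import RoFormerConfig

    config = RoFormerConfig()  # embedding_size omitted -> falls back to hidden_size
    assert config.embedding_size == config.hidden_size == 768

    slim = RoFormerConfig(hidden_size=256, embedding_size=128)  # explicitly decoupled
    assert slim.embedding_size == 128 and slim.hidden_size == 256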
@@ -60,7 +60,11 @@ _TOKENIZER_FOR_DOC = "RoFormerTokenizer"

 ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
     "junnyu/roformer_chinese_small",
-    "junnyu/roformer_chinese_base"
+    "junnyu/roformer_chinese_base",
+    "junnyu/roformer_chinese_char_small",
+    "junnyu/roformer_chinese_char_base",
+    "junnyu/roformer_small_discriminator",
+    "junnyu/roformer_small_generator"
     # See all RoFormer models at https://huggingface.co/models?filter=roformer
 ]
@@ -327,9 +331,9 @@ class RoFormerSelfAttention(nn.Module):
         # cos [batch_size, num_heads, sequence_length, embed_size_per_head//2]
         sin, cos = sinusoidal_pos.chunk(2, dim=-1)
         # sin [θ0,θ1,θ2......θd/2-1] -> sin_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1]
-        sin_pos = torch.repeat_interleave(sin, 2, dim=-1)
+        sin_pos = torch.stack([sin, sin], dim=-1).reshape_as(sinusoidal_pos)
         # cos [θ0,θ1,θ2......θd/2-1] -> cos_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1]
-        cos_pos = torch.repeat_interleave(cos, 2, dim=-1)
+        cos_pos = torch.stack([cos, cos], dim=-1).reshape_as(sinusoidal_pos)
         # rotate_half_query_layer [-q1,q0,-q3,q2......,-qd-1,qd-2]
         rotate_half_query_layer = torch.stack([-query_layer[..., 1::2], query_layer[..., ::2]], dim=-1).reshape_as(
             query_layer
...
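The stack-and-reshape form was adopted because `torch.repeat_interleave` got in the way of ONNX export at the time ("make onnx export happy" in the commit message); the two expressions produce identical tensors. A standalone equivalence check with illustrative shapes (not part of this diff):

    import torch

    batch, heads, seq_len, half_dim = 2, 12, 8, 32
    sinusoidal_pos = torch.randn(batch, heads, seq_len, 2 * half_dim)
    sin, cos = sinusoidal_pos.chunk(2, dim=-1)  # each [..., half_dim]

    # [θ0,θ1,...] -> [θ0,θ0,θ1,θ1,...]: stacking on a new trailing axis and
    # flattening it back into the last dimension interleaves the copies.
    old_sin_pos = torch.repeat_interleave(sin, 2, dim=-1)
    new_sin_pos = torch.stack([sin, sin], dim=-1).reshape_as(sinusoidal_pos)

    assert torch.equal(old_sin_pos, new_sin_pos)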
@@ -65,7 +65,11 @@ _TOKENIZER_FOR_DOC = "RoFormerTokenizer"

 TF_ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
     "junnyu/roformer_chinese_small",
-    "junnyu/roformer_chinese_base"
+    "junnyu/roformer_chinese_base",
+    "junnyu/roformer_chinese_char_small",
+    "junnyu/roformer_chinese_char_base",
+    "junnyu/roformer_small_discriminator",
+    "junnyu/roformer_small_generator"
     # See all RoFormer models at https://huggingface.co/models?filter=roformer
 ]
...
@@ -31,15 +31,30 @@ PRETRAINED_VOCAB_FILES_MAP = {
     "vocab_file": {
         "junnyu/roformer_chinese_small": "https://huggingface.co/junnyu/roformer_chinese_small/resolve/main/vocab.txt",
         "junnyu/roformer_chinese_base": "https://huggingface.co/junnyu/roformer_chinese_base/resolve/main/vocab.txt",
+        "junnyu/roformer_chinese_char_small": "https://huggingface.co/junnyu/roformer_chinese_char_small/resolve/main/vocab.txt",
+        "junnyu/roformer_chinese_char_base": "https://huggingface.co/junnyu/roformer_chinese_char_base/resolve/main/vocab.txt",
+        "junnyu/roformer_small_discriminator": "https://huggingface.co/junnyu/roformer_small_discriminator/resolve/main/vocab.txt",
+        "junnyu/roformer_small_generator": "https://huggingface.co/junnyu/roformer_small_generator/resolve/main/vocab.txt",
     }
 }

-PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {"junnyu/roformer_chinese_small": 1536, "junnyu/roformer_chinese_base": 1536}
+PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
+    "junnyu/roformer_chinese_small": 1536,
+    "junnyu/roformer_chinese_base": 1536,
+    "junnyu/roformer_chinese_char_small": 512,
+    "junnyu/roformer_chinese_char_base": 512,
+    "junnyu/roformer_small_discriminator": 128,
+    "junnyu/roformer_small_generator": 128,
+}

 PRETRAINED_INIT_CONFIGURATION = {
     "junnyu/roformer_chinese_small": {"do_lower_case": True},
     "junnyu/roformer_chinese_base": {"do_lower_case": True},
+    "junnyu/roformer_chinese_char_small": {"do_lower_case": True},
+    "junnyu/roformer_chinese_char_base": {"do_lower_case": True},
+    "junnyu/roformer_small_discriminator": {"do_lower_case": True},
+    "junnyu/roformer_small_generator": {"do_lower_case": True},
 }
@@ -166,13 +181,8 @@ class RoFormerTokenizer(PreTrainedTokenizer):
     def __setstate__(self, d):
         self.__dict__ = d
-        try:
-            import rjieba
-        except ImportError:
-            raise ImportError(
-                "You need to install rjieba to use RoFormerTokenizer."
-                "See https://pypi.org/project/rjieba/ for installation."
-            )
+        import rjieba
+
         self.jieba = rjieba

     def get_vocab(self):
...
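The bare import is safe here because a tokenizer instance can only exist if rjieba was importable at construction time, and `__setstate__` only runs on unpickling; the module handle itself is not picklable, so it must be re-imported. A sketch of the round trip (assuming, as the paired `__getstate__` in this file does, that the jieba handle is dropped before pickling):

    import pickle

    from transformers import RoFormerTokenizer

    tokenizer = RoFormerTokenizer.from_pretrained("junnyu/roformer_chinese_base")
    restored = pickle.loads(pickle.dumps(tokenizer))  # __setstate__ rebinds restored.jieba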
@@ -33,15 +33,30 @@ PRETRAINED_VOCAB_FILES_MAP = {
     "vocab_file": {
         "junnyu/roformer_chinese_small": "https://huggingface.co/junnyu/roformer_chinese_small/resolve/main/vocab.txt",
         "junnyu/roformer_chinese_base": "https://huggingface.co/junnyu/roformer_chinese_base/resolve/main/vocab.txt",
+        "junnyu/roformer_chinese_char_small": "https://huggingface.co/junnyu/roformer_chinese_char_small/resolve/main/vocab.txt",
+        "junnyu/roformer_chinese_char_base": "https://huggingface.co/junnyu/roformer_chinese_char_base/resolve/main/vocab.txt",
+        "junnyu/roformer_small_discriminator": "https://huggingface.co/junnyu/roformer_small_discriminator/resolve/main/vocab.txt",
+        "junnyu/roformer_small_generator": "https://huggingface.co/junnyu/roformer_small_generator/resolve/main/vocab.txt",
     }
 }

-PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {"junnyu/roformer_chinese_small": 1536, "junnyu/roformer_chinese_base": 1536}
+PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
+    "junnyu/roformer_chinese_small": 1536,
+    "junnyu/roformer_chinese_base": 1536,
+    "junnyu/roformer_chinese_char_small": 512,
+    "junnyu/roformer_chinese_char_base": 512,
+    "junnyu/roformer_small_discriminator": 128,
+    "junnyu/roformer_small_generator": 128,
+}

 PRETRAINED_INIT_CONFIGURATION = {
     "junnyu/roformer_chinese_small": {"do_lower_case": True},
     "junnyu/roformer_chinese_base": {"do_lower_case": True},
+    "junnyu/roformer_chinese_char_small": {"do_lower_case": True},
+    "junnyu/roformer_chinese_char_base": {"do_lower_case": True},
+    "junnyu/roformer_small_discriminator": {"do_lower_case": True},
+    "junnyu/roformer_small_generator": {"do_lower_case": True},
 }
...
@@ -41,26 +41,26 @@ class JiebaPreTokenizer:
         splits = []
-        # this code slice normalized_string is too slow (6s) but test_alignement_methods can pass
-        # for token, start, end in self.jieba.tokenize(str(normalized_string), hmm=False):
-        #     if token in self.vocab:
-        #         splits.append(normalized_string.slice((start, end)))
-        #     else:
-        #         token_list = self.normalizers.normalize_str(token).split()
-        #         for token in token_list:
-        #             if token:
-        #                 end = start + len(token)
-        #                 splits.append(normalized_string.slice((start, end)))
-        #                 start = end
-        # this code test_alignement_methods can't pass but fast (300ms)
-        for token in self.jieba.cut(str(normalized_string), False):
-            if token in self.vocab:
-                splits.append(NormalizedString(token))
-            else:
-                token_list = self.normalizers.normalize_str(token).split()
-                for token in token_list:
-                    if token:
-                        splits.append(NormalizedString(token))
+        # slicing normalized_string is slow (~6s), but test_alignement_methods passes
+        for token, start, end in self.jieba.tokenize(str(normalized_string), hmm=False):
+            if token in self.vocab:
+                splits.append(normalized_string[start:end])
+            else:
+                token_list = self.normalizers.normalize_str(token).split()
+                for token in token_list:
+                    if token:
+                        end = start + len(token)
+                        splits.append(normalized_string[start:end])
+                        start = end
+        # this variant is fast (~300ms), but test_alignement_methods can't pass
+        # for token in self.jieba.cut(str(normalized_string), False):
+        #     if token in self.vocab:
+        #         splits.append(NormalizedString(token))
+        #     else:
+        #         token_list = self.normalizers.normalize_str(token).split()
+        #         for token in token_list:
+        #             if token:
+        #                 splits.append(NormalizedString(token))

         return splits
...
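For context: returning slices of `normalized_string` (rather than fresh `NormalizedString` objects) is what preserves the character-to-token offset mapping that `test_alignement_methods` checks. A hedged sketch of how a class like this plugs into a `tokenizers` backend, following the pattern the fast RoFormer tokenizer uses (`tokenizer` is assumed to be a fast tokenizer instance; local names are illustrative):

    from tokenizers.pre_tokenizers import PreTokenizer

    # PreTokenizer.custom accepts any object exposing pre_tokenize(self, pretok);
    # pre_tokenize calls pretok.split(...) with a callback like the logic above.
    backend = tokenizer.backend_tokenizer  # a tokenizers.Tokenizer instance
    backend.pre_tokenizer = PreTokenizer.custom(JiebaPreTokenizer(backend.get_vocab()))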
@@ -35,6 +35,7 @@ from .file_utils import (
     is_flax_available,
     is_onnx_available,
     is_pandas_available,
+    is_rjieba_available,
     is_scatter_available,
     is_sentencepiece_available,
     is_soundfile_availble,
@@ -223,6 +224,16 @@ def require_git_lfs(test_case):
     return test_case


+def require_rjieba(test_case):
+    """
+    Decorator marking a test that requires rjieba. These tests are skipped when rjieba isn't installed.
+    """
+    if not is_rjieba_available():
+        return unittest.skip("test requires rjieba")(test_case)
+    else:
+        return test_case
+
+
 def require_onnx(test_case):
     if not is_onnx_available():
         return unittest.skip("test requires ONNX")(test_case)
...
@@ -13,29 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import importlib
 import unittest

 from transformers import RoFormerTokenizer, RoFormerTokenizerFast
-from transformers.testing_utils import require_tokenizers
+from transformers.testing_utils import require_rjieba, require_tokenizers

 from .test_tokenization_common import TokenizerTesterMixin


-def is_rjieba_available():
-    return importlib.util.find_spec("rjieba") is not None
-
-
-def require_rjieba(test_case):
-    """
-    Decorator marking a test that requires Jieba. These tests are skipped when Jieba isn't installed.
-    """
-    if not is_rjieba_available():
-        return unittest.skip("test requires rjieba")(test_case)
-    else:
-        return test_case
-
-
 @require_rjieba
 @require_tokenizers
 class RoFormerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
@@ -79,6 +64,10 @@ class RoFormerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         exp_tokens = [22943, 21332, 34431, 45904, 117, 306, 1231, 1231, 2653, 33994, 1266, 100]
         self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), exp_tokens)

-    # due to the custom pre_tokenize, char_to_token may be wrong
-    def test_alignement_methods(self):
-        pass
+    # can't train a new tokenizer via the Tokenizers lib
+    def test_training_new_tokenizer(self):
+        pass
+
+    # can't train a new tokenizer via the Tokenizers lib
+    def test_training_new_tokenizer_with_special_tokens_change(self):
+        pass