Unverified commit 8594dd80, authored by Yohei Tamura, committed by GitHub

BertJapaneseTokenizer accept options for mecab (#3566)

* BertJapaneseTokenizer accept options for mecab

* black

* fix mecab_option to Option[str]
parent 216e167c
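
In practice, the new arguments let a caller forward MeCab configuration through the tokenizer stack: `BertJapaneseTokenizer` forwards `mecab_kwargs` to `MecabTokenizer`, which forwards `mecab_option` to `MeCab.Tagger`. A minimal usage sketch (not part of this commit): the checkpoint name, module path, and dictionary location are placeholders, and it assumes `from_pretrained` forwards extra keyword arguments to the constructor as usual.

from transformers import BertJapaneseTokenizer
from transformers.tokenization_bert_japanese import MecabTokenizer  # module path assumed for this version

# Point MeCab at an alternative system dictionary via the new `mecab_kwargs` dict.
# The checkpoint name and the dictionary path below are placeholders.
tokenizer = BertJapaneseTokenizer.from_pretrained(
    "cl-tohoku/bert-base-japanese",
    word_tokenizer_type="mecab",
    mecab_kwargs={"mecab_option": "-d /usr/local/lib/mecab/dic/jumandic"},
)
print(tokenizer.tokenize("アップルストアでiPhone8が発売された。"))

# The word-level tokenizer can also be configured directly.
word_tokenizer = MecabTokenizer(mecab_option="-d /usr/local/lib/mecab/dic/jumandic")
print(word_tokenizer.tokenize("アップルストアでiPhone8が発売された。"))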
@@ -89,6 +89,7 @@ class BertJapaneseTokenizer(BertTokenizer):
         pad_token="[PAD]",
         cls_token="[CLS]",
         mask_token="[MASK]",
+        mecab_kwargs=None,
         **kwargs
     ):
         """Constructs a MecabBertTokenizer.
@@ -106,6 +107,7 @@ class BertJapaneseTokenizer(BertTokenizer):
                 Type of word tokenizer.
             **subword_tokenizer_type**: (`optional`) string (default "wordpiece")
                 Type of subword tokenizer.
+            **mecab_kwargs**: (`optional`) dict passed to `MecabTokenizer` constructor (default None)
         """
         super(BertTokenizer, self).__init__(
             unk_token=unk_token,
@@ -134,7 +136,9 @@ class BertJapaneseTokenizer(BertTokenizer):
                 do_lower_case=do_lower_case, never_split=never_split, tokenize_chinese_chars=False
             )
         elif word_tokenizer_type == "mecab":
-            self.word_tokenizer = MecabTokenizer(do_lower_case=do_lower_case, never_split=never_split)
+            self.word_tokenizer = MecabTokenizer(
+                do_lower_case=do_lower_case, never_split=never_split, **(mecab_kwargs or {})
+            )
         else:
             raise ValueError("Invalid word_tokenizer_type '{}' is specified.".format(word_tokenizer_type))
@@ -164,7 +168,7 @@ class BertJapaneseTokenizer(BertTokenizer):
 class MecabTokenizer(object):
     """Runs basic tokenization with MeCab morphological parser."""
 
-    def __init__(self, do_lower_case=False, never_split=None, normalize_text=True):
+    def __init__(self, do_lower_case=False, never_split=None, normalize_text=True, mecab_option=None):
         """Constructs a MecabTokenizer.
 
         Args:
@@ -176,6 +180,7 @@ class MecabTokenizer(object):
                 List of token not to split.
             **normalize_text**: (`optional`) boolean (default True)
                 Whether to apply unicode normalization to text before tokenization.
+            **mecab_option**: (`optional`) string passed to `MeCab.Tagger` constructor (default "")
         """
         self.do_lower_case = do_lower_case
         self.never_split = never_split if never_split is not None else []
@@ -183,7 +188,7 @@ class MecabTokenizer(object):
         import MeCab
 
-        self.mecab = MeCab.Tagger()
+        self.mecab = MeCab.Tagger(mecab_option) if mecab_option is not None else MeCab.Tagger()
 
     def tokenize(self, text, never_split=None, **kwargs):
         """Tokenizes a piece of text."""
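
For reference, `mecab_option` is an ordinary MeCab command-line option string passed straight to `MeCab.Tagger`. A small sketch of typical values, with local dictionary paths as placeholders; options that change the output format, such as `-Owakati`, are best avoided here because `MecabTokenizer` parses the tagger's default tab-separated output.

import MeCab

# Switch to another system dictionary (path is a placeholder).
tagger = MeCab.Tagger("-d /usr/local/lib/mecab/dic/jumandic")

# Or layer a user dictionary on top of the default one (path is a placeholder).
tagger = MeCab.Tagger("-u /path/to/user.dic")

print(tagger.parse("アップルストアでiPhone8が発売された。"))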
@@ -91,6 +91,20 @@ class BertJapaneseTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
             ["アップルストア", "で", "iphone", "8", "が", "発売", "さ", "れ", "た", "。"],
         )
 
+    def test_mecab_tokenizer_with_option(self):
+        try:
+            tokenizer = MecabTokenizer(
+                do_lower_case=True, normalize_text=False, mecab_option="-d /usr/local/lib/mecab/dic/jumandic"
+            )
+        except RuntimeError:
+            # if dict doesn't exist in the system, previous code raises this error.
+            return
+
+        self.assertListEqual(
+            tokenizer.tokenize(" \tアップルストアでiPhone8 が \n 発売された 。 "),
+            ["アップルストア", "で", "iPhone", "8", "が", "発売", "さ", "れた", "\u3000", "。"],
+        )
+
     def test_mecab_tokenizer_no_normalize(self):
         tokenizer = MecabTokenizer(normalize_text=False)
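
Read next to the existing test above it, the new test shows what the dictionary option changes: with the default dictionary the verb ending is segmented as 「れ」「た」, whereas JumanDIC keeps 「れた」 as one token, and with normalize_text=False the full-width space ("\u3000") survives as its own token. A minimal comparison sketch, assuming both dictionaries are installed locally (the JumanDIC path is a placeholder and the module path is assumed).

from transformers.tokenization_bert_japanese import MecabTokenizer  # module path assumed

text = "アップルストアでiPhone8が発売された。"

default_tokenizer = MecabTokenizer()  # system default dictionary, typically IPADIC
juman_tokenizer = MecabTokenizer(mecab_option="-d /usr/local/lib/mecab/dic/jumandic")  # placeholder path

print(default_tokenizer.tokenize(text))  # ..., "さ", "れ", "た", "。"  (cf. the existing test)
print(juman_tokenizer.tokenize(text))    # ..., "さ", "れた", "。"      (cf. the new test)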