"docs/vscode:/vscode.git/clone" did not exist on "7ceff67e1a75dc4060e30dcd1f09052d6ea697a8"
Unverified Commit 0875b250 authored by Matt's avatar Matt Committed by GitHub
Browse files

Allow passing kwargs through to TFBertTokenizer (#24324)

parent cfc838dd
......@@ -48,7 +48,9 @@ class TFBertTokenizer(tf.keras.layers.Layer):
return_attention_mask (`bool`, *optional*, defaults to `True`):
Whether to return the attention_mask.
use_fast_bert_tokenizer (`bool`, *optional*, defaults to `True`):
If set to false will use standard TF Text BertTokenizer, making it servable by TF Serving.
If True, will use the FastBertTokenizer class from Tensorflow Text. If False, will use the BertTokenizer
class instead. BertTokenizer supports some additional options, but is slower and cannot be exported to
TFLite.
"""
def __init__(
......@@ -65,11 +67,12 @@ class TFBertTokenizer(tf.keras.layers.Layer):
return_token_type_ids: bool = True,
return_attention_mask: bool = True,
use_fast_bert_tokenizer: bool = True,
**tokenizer_kwargs,
):
super().__init__()
if use_fast_bert_tokenizer:
self.tf_tokenizer = FastBertTokenizer(
vocab_list, token_out_type=tf.int64, lower_case_nfd_strip_accents=do_lower_case
vocab_list, token_out_type=tf.int64, lower_case_nfd_strip_accents=do_lower_case, **tokenizer_kwargs
)
else:
lookup_table = tf.lookup.StaticVocabularyTable(
......@@ -81,7 +84,9 @@ class TFBertTokenizer(tf.keras.layers.Layer):
),
num_oov_buckets=1,
)
self.tf_tokenizer = BertTokenizerLayer(lookup_table, token_out_type=tf.int64, lower_case=do_lower_case)
self.tf_tokenizer = BertTokenizerLayer(
lookup_table, token_out_type=tf.int64, lower_case=do_lower_case, **tokenizer_kwargs
)
self.vocab_list = vocab_list
self.do_lower_case = do_lower_case
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment