Unverified Commit 0875b250 authored by Matt's avatar Matt Committed by GitHub
Browse files

Allow passing kwargs through to TFBertTokenizer (#24324)

parent cfc838dd
...@@ -48,7 +48,9 @@ class TFBertTokenizer(tf.keras.layers.Layer): ...@@ -48,7 +48,9 @@ class TFBertTokenizer(tf.keras.layers.Layer):
return_attention_mask (`bool`, *optional*, defaults to `True`): return_attention_mask (`bool`, *optional*, defaults to `True`):
Whether to return the attention_mask. Whether to return the attention_mask.
use_fast_bert_tokenizer (`bool`, *optional*, defaults to `True`): use_fast_bert_tokenizer (`bool`, *optional*, defaults to `True`):
If set to false will use standard TF Text BertTokenizer, making it servable by TF Serving. If True, will use the FastBertTokenizer class from Tensorflow Text. If False, will use the BertTokenizer
class instead. BertTokenizer supports some additional options, but is slower and cannot be exported to
TFLite.
""" """
def __init__( def __init__(
...@@ -65,11 +67,12 @@ class TFBertTokenizer(tf.keras.layers.Layer): ...@@ -65,11 +67,12 @@ class TFBertTokenizer(tf.keras.layers.Layer):
return_token_type_ids: bool = True, return_token_type_ids: bool = True,
return_attention_mask: bool = True, return_attention_mask: bool = True,
use_fast_bert_tokenizer: bool = True, use_fast_bert_tokenizer: bool = True,
**tokenizer_kwargs,
): ):
super().__init__() super().__init__()
if use_fast_bert_tokenizer: if use_fast_bert_tokenizer:
self.tf_tokenizer = FastBertTokenizer( self.tf_tokenizer = FastBertTokenizer(
vocab_list, token_out_type=tf.int64, lower_case_nfd_strip_accents=do_lower_case vocab_list, token_out_type=tf.int64, lower_case_nfd_strip_accents=do_lower_case, **tokenizer_kwargs
) )
else: else:
lookup_table = tf.lookup.StaticVocabularyTable( lookup_table = tf.lookup.StaticVocabularyTable(
...@@ -81,7 +84,9 @@ class TFBertTokenizer(tf.keras.layers.Layer): ...@@ -81,7 +84,9 @@ class TFBertTokenizer(tf.keras.layers.Layer):
), ),
num_oov_buckets=1, num_oov_buckets=1,
) )
self.tf_tokenizer = BertTokenizerLayer(lookup_table, token_out_type=tf.int64, lower_case=do_lower_case) self.tf_tokenizer = BertTokenizerLayer(
lookup_table, token_out_type=tf.int64, lower_case=do_lower_case, **tokenizer_kwargs
)
self.vocab_list = vocab_list self.vocab_list = vocab_list
self.do_lower_case = do_lower_case self.do_lower_case = do_lower_case
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment