Unverified Commit 40ca1336 authored by Bruno Alvisio's avatar Bruno Alvisio Committed by GitHub
Browse files

Fix passing kwargs to TFBertTokenizer (#21619)

parent fc28c006
...@@ -114,15 +114,24 @@ class TFBertTokenizer(tf.keras.layers.Layer): ...@@ -114,15 +114,24 @@ class TFBertTokenizer(tf.keras.layers.Layer):
tf_tokenizer = TFBertTokenizer.from_tokenizer(tokenizer) tf_tokenizer = TFBertTokenizer.from_tokenizer(tokenizer)
``` ```
""" """
do_lower_case = kwargs.pop("do_lower_case", None)
do_lower_case = tokenizer.do_lower_case if do_lower_case is None else do_lower_case
cls_token_id = kwargs.pop("cls_token_id", None)
cls_token_id = tokenizer.cls_token_id if cls_token_id is None else cls_token_id
sep_token_id = kwargs.pop("sep_token_id", None)
sep_token_id = tokenizer.sep_token_id if sep_token_id is None else sep_token_id
pad_token_id = kwargs.pop("pad_token_id", None)
pad_token_id = tokenizer.pad_token_id if pad_token_id is None else pad_token_id
vocab = tokenizer.get_vocab() vocab = tokenizer.get_vocab()
vocab = sorted([(wordpiece, idx) for wordpiece, idx in vocab.items()], key=lambda x: x[1]) vocab = sorted([(wordpiece, idx) for wordpiece, idx in vocab.items()], key=lambda x: x[1])
vocab_list = [entry[0] for entry in vocab] vocab_list = [entry[0] for entry in vocab]
return cls( return cls(
vocab_list=vocab_list, vocab_list=vocab_list,
do_lower_case=tokenizer.do_lower_case, do_lower_case=do_lower_case,
cls_token_id=tokenizer.cls_token_id, cls_token_id=cls_token_id,
sep_token_id=tokenizer.sep_token_id, sep_token_id=sep_token_id,
pad_token_id=tokenizer.pad_token_id, pad_token_id=pad_token_id,
**kwargs, **kwargs,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment