Commit 885fda09 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 424187808
parent 159697a2
......@@ -14,7 +14,7 @@
"""Keras Layers for BERT-specific preprocessing."""
# pylint: disable=g-import-not-at-top
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Mapping, Optional, Text, Union
from absl import logging
import tensorflow as tf
......@@ -71,8 +71,9 @@ class BertTokenizer(tf.keras.layers.Layer):
def __init__(self, *,
vocab_file: str,
lower_case: bool,
lower_case: Optional[bool] = None,
tokenize_with_offsets: bool = False,
tokenizer_kwargs: Optional[Mapping[Text, Any]] = None,
**kwargs):
"""Initialize a `BertTokenizer` layer.
......@@ -81,15 +82,18 @@ class BertTokenizer(tf.keras.layers.Layer):
This is a text file with newline-separated wordpiece tokens.
This layer initializes a lookup table from it that gets used with
`text.BertTokenizer`.
lower_case: A Python boolean forwarded to `text.BertTokenizer`.
lower_case: Optional boolean forwarded to `text.BertTokenizer`.
If true, input text is converted to lower case (where applicable)
before tokenization. This must be set to match the way in which
the `vocab_file` was created.
the `vocab_file` was created. If passed, this overrides whatever value
may have been passed in `tokenizer_kwargs`.
tokenize_with_offsets: A Python boolean. If true, this layer calls
`text.BertTokenizer.tokenize_with_offsets()` instead of plain
`text.BertTokenizer.tokenize()` and outputs a triple of
`(tokens, start_offsets, limit_offsets)`
insead of just tokens.
tokenizer_kwargs: Optional mapping with keyword arguments to forward to
`text.BertTokenizer`'s constructor.
**kwargs: Standard arguments to `Layer()`.
Raises:
......@@ -111,8 +115,11 @@ class BertTokenizer(tf.keras.layers.Layer):
self._special_tokens_dict = self._create_special_tokens_dict(
self._vocab_table, vocab_file)
super().__init__(**kwargs)
self._bert_tokenizer = text.BertTokenizer(
self._vocab_table, lower_case=lower_case)
tokenizer_kwargs = dict(tokenizer_kwargs or {})
if lower_case is not None:
tokenizer_kwargs["lower_case"] = lower_case
self._bert_tokenizer = text.BertTokenizer(self._vocab_table,
**tokenizer_kwargs)
@property
def vocab_size(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment