"vscode:/vscode.git/clone" did not exist on "e06a6ea4b086eec71fa7d0d3dc732861aa937f0f"
Commit 5983e3d2 authored by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 388356184
parent 56b5494d
@@ -44,6 +44,8 @@ class SentencePredictionDataConfig(cfg.DataConfig):
  # Maps the key in TfExample to feature name.
  # E.g 'label_ids' to 'next_sentence_labels'
  label_name: Optional[Tuple[str, str]] = None
  # Either tfrecord, sstable, or recordio.
  file_type: str = 'tfrecord'

@data_loader_factory.register_data_loader_cls(SentencePredictionDataConfig)
@@ -111,7 +113,10 @@ class SentencePredictionDataLoader(data_loader.DataLoader):
  def load(self, input_context: Optional[tf.distribute.InputContext] = None):
    """Returns a tf.dataset.Dataset."""
    reader = input_reader.InputReader(
        dataset_fn=dataset_fn.pick_dataset_fn(self._params.file_type),
        params=self._params,
        decoder_fn=self._decode,
        parser_fn=self._parse)
    return reader.read(input_context)
@@ -168,7 +173,8 @@ class TextProcessor(tf.Module):
          vocab_file=vocab_file, lower_case=lower_case)
    elif tokenization == 'SentencePiece':
      self._tokenizer = modeling.layers.SentencepieceTokenizer(
          model_file_path=vocab_file,
          lower_case=lower_case,
          strip_diacritics=True)  # Strip diacritics to follow ALBERT model
    else:
      raise ValueError('Unsupported tokenization: %s' % tokenization)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment