Commit c4451b7a authored by Chen Chen's avatar Chen Chen Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 319267378
parent 7b5cb554
...@@ -74,45 +74,6 @@ def instantiate_bertpretrainer_from_cfg( ...@@ -74,45 +74,6 @@ def instantiate_bertpretrainer_from_cfg(
config.cls_heads)) config.cls_heads))
@dataclasses.dataclass
class BertPretrainDataConfig(cfg.DataConfig):
  """Data config for BERT pretraining task (tasks/masked_lm)."""
  # Path(s) to the input data; empty by default.
  input_path: str = ""
  # Batch size; presumably summed across all replicas (global) — confirm against reader.
  global_batch_size: int = 512
  # Training-set defaults; see BertPretrainEvalDataConfig for the eval variant.
  is_training: bool = True
  # Packed sequence length of each example.
  seq_length: int = 512
  # Upper bound on the number of masked-LM predictions per sequence.
  max_predictions_per_seq: int = 76
  # Whether examples include the next-sentence-prediction label.
  use_next_sentence_label: bool = True
  # Whether examples include explicit position-id features.
  use_position_id: bool = False
@dataclasses.dataclass
class BertPretrainEvalDataConfig(BertPretrainDataConfig):
  """Data config for the eval set in BERT pretraining task (tasks/masked_lm)."""
  # Path(s) to the eval input data; empty by default.
  input_path: str = ""
  global_batch_size: int = 512
  # Overrides the parent default: the eval set is never used for training.
  is_training: bool = False
@dataclasses.dataclass
class SentencePredictionDataConfig(cfg.DataConfig):
  """Data config for sentence prediction task (tasks/sentence_prediction)."""
  # Path(s) to the input data; empty by default.
  input_path: str = ""
  # Batch size; presumably global across replicas — confirm against reader.
  global_batch_size: int = 32
  is_training: bool = True
  # Token length of each example.
  seq_length: int = 128
@dataclasses.dataclass
class SentencePredictionDevDataConfig(cfg.DataConfig):
  """Dev Data config for sentence prediction (tasks/sentence_prediction)."""
  # Path(s) to the dev-set input data; empty by default.
  input_path: str = ""
  global_batch_size: int = 32
  # Dev set is used for evaluation only.
  is_training: bool = False
  seq_length: int = 128
  # Keep the final partial batch so no dev example is dropped.
  drop_remainder: bool = False
@dataclasses.dataclass @dataclasses.dataclass
class QADataConfig(cfg.DataConfig): class QADataConfig(cfg.DataConfig):
"""Data config for question answering task (tasks/question_answering).""" """Data config for question answering task (tasks/question_answering)."""
...@@ -137,22 +98,3 @@ class QADevDataConfig(cfg.DataConfig): ...@@ -137,22 +98,3 @@ class QADevDataConfig(cfg.DataConfig):
vocab_file: str = "" vocab_file: str = ""
tokenization: str = "WordPiece" # WordPiece or SentencePiece tokenization: str = "WordPiece" # WordPiece or SentencePiece
do_lower_case: bool = True do_lower_case: bool = True
@dataclasses.dataclass
class TaggingDataConfig(cfg.DataConfig):
  """Data config for tagging (tasks/tagging)."""
  # Path(s) to the input data; empty by default.
  input_path: str = ""
  # Batch size; presumably global across replicas — confirm against reader.
  global_batch_size: int = 48
  is_training: bool = True
  # Token length of each example.
  seq_length: int = 384
@dataclasses.dataclass
class TaggingDevDataConfig(cfg.DataConfig):
  """Dev Data config for tagging (tasks/tagging)."""
  # Path(s) to the dev-set input data; empty by default.
  input_path: str = ""
  global_batch_size: int = 48
  # Dev set is used for evaluation only.
  is_training: bool = False
  seq_length: int = 384
  # Keep the final partial batch so no dev example is dropped.
  drop_remainder: bool = False
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A global factory to access NLP registered data loaders."""
from official.utils import registry
_REGISTERED_DATA_LOADER_CLS = {}
def register_data_loader_cls(data_config_cls):
  """Builds a class decorator that registers a DataLoader for a config class.

  Decorating a DataLoader class with the returned decorator makes it
  discoverable through `get_data_loader`: an instance of `data_config_cls`
  passed to `get_data_loader` will be routed to the decorated class.

  Example:
  ```
  @dataclasses.dataclass
  class MyDataConfig(DataConfig):
    # Add fields here.
    pass

  @register_data_loader_cls(MyDataConfig)
  class MyDataLoader:
    # Inherits def __init__(self, data_config).
    pass

  my_data_config = MyDataConfig()

  # Returns MyDataLoader(my_data_config).
  my_loader = get_data_loader(my_data_config)
  ```

  Args:
    data_config_cls: a subclass of DataConfig (*not* an instance
      of DataConfig).

  Returns:
    A class decorator that records the decorated class in the registry,
    keyed by `data_config_cls`.
  """
  # Delegate bookkeeping to the shared registry helper; it returns the
  # actual decorator bound to our module-level lookup table.
  decorator = registry.register(_REGISTERED_DATA_LOADER_CLS, data_config_cls)
  return decorator
def get_data_loader(data_config):
  """Instantiates the registered DataLoader for `data_config`'s class.

  Args:
    data_config: an instance of a DataConfig subclass previously registered
      via `register_data_loader_cls`.

  Returns:
    The registered loader class, constructed with `data_config`.
  """
  # Look up the loader class keyed by the config's exact class, then build it.
  loader_cls = registry.lookup(_REGISTERED_DATA_LOADER_CLS,
                               data_config.__class__)
  return loader_cls(data_config)
...@@ -16,11 +16,27 @@ ...@@ -16,11 +16,27 @@
"""Loads dataset for the BERT pretraining task.""" """Loads dataset for the BERT pretraining task."""
from typing import Mapping, Optional from typing import Mapping, Optional
import dataclasses
import tensorflow as tf import tensorflow as tf
from official.core import input_reader from official.core import input_reader
from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.data import data_loader_factory
@dataclasses.dataclass
class BertPretrainDataConfig(cfg.DataConfig):
  """Data config for BERT pretraining task (tasks/masked_lm)."""
  # Path(s) to the input data; empty by default.
  input_path: str = ''
  # Batch size; presumably global across replicas — confirm against reader.
  global_batch_size: int = 512
  is_training: bool = True
  # Packed sequence length of each example.
  seq_length: int = 512
  # Upper bound on the number of masked-LM predictions per sequence.
  max_predictions_per_seq: int = 76
  # Whether examples include the next-sentence-prediction label.
  use_next_sentence_label: bool = True
  # Whether examples include explicit position-id features.
  use_position_id: bool = False
@data_loader_factory.register_data_loader_cls(BertPretrainDataConfig)
class BertPretrainDataLoader: class BertPretrainDataLoader:
"""A class to load dataset for bert pretraining task.""" """A class to load dataset for bert pretraining task."""
...@@ -91,7 +107,5 @@ class BertPretrainDataLoader: ...@@ -91,7 +107,5 @@ class BertPretrainDataLoader:
def load(self, input_context: Optional[tf.distribute.InputContext] = None): def load(self, input_context: Optional[tf.distribute.InputContext] = None):
"""Returns a tf.dataset.Dataset.""" """Returns a tf.dataset.Dataset."""
reader = input_reader.InputReader( reader = input_reader.InputReader(
params=self._params, params=self._params, decoder_fn=self._decode, parser_fn=self._parse)
decoder_fn=self._decode,
parser_fn=self._parse)
return reader.read(input_context) return reader.read(input_context)
...@@ -15,11 +15,24 @@ ...@@ -15,11 +15,24 @@
# ============================================================================== # ==============================================================================
"""Loads dataset for the sentence prediction (classification) task.""" """Loads dataset for the sentence prediction (classification) task."""
from typing import Mapping, Optional from typing import Mapping, Optional
import dataclasses
import tensorflow as tf import tensorflow as tf
from official.core import input_reader from official.core import input_reader
from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.data import data_loader_factory
@dataclasses.dataclass
class SentencePredictionDataConfig(cfg.DataConfig):
  """Data config for sentence prediction task (tasks/sentence_prediction)."""
  # Path(s) to the input data; empty by default.
  input_path: str = ''
  # Batch size; presumably global across replicas — confirm against reader.
  global_batch_size: int = 32
  is_training: bool = True
  # Token length of each example.
  seq_length: int = 128
@data_loader_factory.register_data_loader_cls(SentencePredictionDataConfig)
class SentencePredictionDataLoader: class SentencePredictionDataLoader:
"""A class to load dataset for sentence prediction (classification) task.""" """A class to load dataset for sentence prediction (classification) task."""
......
...@@ -15,15 +15,26 @@ ...@@ -15,15 +15,26 @@
# ============================================================================== # ==============================================================================
"""Loads dataset for the tagging (e.g., NER/POS) task.""" """Loads dataset for the tagging (e.g., NER/POS) task."""
from typing import Mapping, Optional from typing import Mapping, Optional
import dataclasses
import tensorflow as tf import tensorflow as tf
from official.core import input_reader from official.core import input_reader
from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.data import data_loader_factory
@dataclasses.dataclass
class TaggingDataConfig(cfg.DataConfig):
  """Data config for tagging (tasks/tagging)."""
  # NOTE(review): input_path/global_batch_size are presumably inherited from
  # cfg.DataConfig — confirm the base class declares them.
  is_training: bool = True
  # Token length of each example.
  seq_length: int = 128
@data_loader_factory.register_data_loader_cls(TaggingDataConfig)
class TaggingDataLoader: class TaggingDataLoader:
"""A class to load dataset for tagging (e.g., NER and POS) task.""" """A class to load dataset for tagging (e.g., NER and POS) task."""
def __init__(self, params): def __init__(self, params: TaggingDataConfig):
self._params = params self._params = params
self._seq_length = params.seq_length self._seq_length = params.seq_length
......
...@@ -20,7 +20,7 @@ import tensorflow as tf ...@@ -20,7 +20,7 @@ import tensorflow as tf
from official.core import base_task from official.core import base_task
from official.modeling.hyperparams import config_definitions as cfg from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.configs import bert from official.nlp.configs import bert
from official.nlp.data import pretrain_dataloader from official.nlp.data import data_loader_factory
@dataclasses.dataclass @dataclasses.dataclass
...@@ -95,8 +95,7 @@ class MaskedLMTask(base_task.Task): ...@@ -95,8 +95,7 @@ class MaskedLMTask(base_task.Task):
dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE) dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
return dataset return dataset
return pretrain_dataloader.BertPretrainDataLoader(params).load( return data_loader_factory.get_data_loader(params).load(input_context)
input_context)
def build_metrics(self, training=None): def build_metrics(self, training=None):
del training del training
......
...@@ -19,6 +19,7 @@ import tensorflow as tf ...@@ -19,6 +19,7 @@ import tensorflow as tf
from official.nlp.configs import bert from official.nlp.configs import bert
from official.nlp.configs import encoders from official.nlp.configs import encoders
from official.nlp.data import pretrain_dataloader
from official.nlp.tasks import masked_lm from official.nlp.tasks import masked_lm
...@@ -33,7 +34,7 @@ class MLMTaskTest(tf.test.TestCase): ...@@ -33,7 +34,7 @@ class MLMTaskTest(tf.test.TestCase):
bert.ClsHeadConfig( bert.ClsHeadConfig(
inner_dim=10, num_classes=2, name="next_sentence") inner_dim=10, num_classes=2, name="next_sentence")
]), ]),
train_data=bert.BertPretrainDataConfig( train_data=pretrain_dataloader.BertPretrainDataConfig(
input_path="dummy", input_path="dummy",
max_predictions_per_seq=20, max_predictions_per_seq=20,
seq_length=128, seq_length=128,
......
...@@ -25,7 +25,7 @@ import tensorflow_hub as hub ...@@ -25,7 +25,7 @@ import tensorflow_hub as hub
from official.core import base_task from official.core import base_task
from official.modeling.hyperparams import config_definitions as cfg from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.configs import bert from official.nlp.configs import bert
from official.nlp.data import sentence_prediction_dataloader from official.nlp.data import data_loader_factory
from official.nlp.tasks import utils from official.nlp.tasks import utils
...@@ -103,8 +103,7 @@ class SentencePredictionTask(base_task.Task): ...@@ -103,8 +103,7 @@ class SentencePredictionTask(base_task.Task):
dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE) dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
return dataset return dataset
return sentence_prediction_dataloader.SentencePredictionDataLoader( return data_loader_factory.get_data_loader(params).load(input_context)
params).load(input_context)
def build_metrics(self, training=None): def build_metrics(self, training=None):
del training del training
......
...@@ -24,6 +24,7 @@ from official.nlp.bert import configs ...@@ -24,6 +24,7 @@ from official.nlp.bert import configs
from official.nlp.bert import export_tfhub from official.nlp.bert import export_tfhub
from official.nlp.configs import bert from official.nlp.configs import bert
from official.nlp.configs import encoders from official.nlp.configs import encoders
from official.nlp.data import sentence_prediction_dataloader
from official.nlp.tasks import sentence_prediction from official.nlp.tasks import sentence_prediction
...@@ -31,8 +32,9 @@ class SentencePredictionTaskTest(tf.test.TestCase, parameterized.TestCase): ...@@ -31,8 +32,9 @@ class SentencePredictionTaskTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self): def setUp(self):
super(SentencePredictionTaskTest, self).setUp() super(SentencePredictionTaskTest, self).setUp()
self._train_data_config = bert.SentencePredictionDataConfig( self._train_data_config = (
input_path="dummy", seq_length=128, global_batch_size=1) sentence_prediction_dataloader.SentencePredictionDataConfig(
input_path="dummy", seq_length=128, global_batch_size=1))
def get_model_config(self, num_classes): def get_model_config(self, num_classes):
return bert.BertPretrainerConfig( return bert.BertPretrainerConfig(
......
...@@ -27,7 +27,7 @@ import tensorflow_hub as hub ...@@ -27,7 +27,7 @@ import tensorflow_hub as hub
from official.core import base_task from official.core import base_task
from official.modeling.hyperparams import config_definitions as cfg from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.configs import encoders from official.nlp.configs import encoders
from official.nlp.data import tagging_data_loader from official.nlp.data import data_loader_factory
from official.nlp.modeling import models from official.nlp.modeling import models
from official.nlp.tasks import utils from official.nlp.tasks import utils
...@@ -138,8 +138,7 @@ class TaggingTask(base_task.Task): ...@@ -138,8 +138,7 @@ class TaggingTask(base_task.Task):
dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE) dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
return dataset return dataset
dataset = tagging_data_loader.TaggingDataLoader(params).load(input_context) return data_loader_factory.get_data_loader(params).load(input_context)
return dataset
def validation_step(self, inputs, model: tf.keras.Model, metrics=None): def validation_step(self, inputs, model: tf.keras.Model, metrics=None):
"""Validatation step. """Validatation step.
......
...@@ -20,8 +20,8 @@ import tensorflow as tf ...@@ -20,8 +20,8 @@ import tensorflow as tf
from official.nlp.bert import configs from official.nlp.bert import configs
from official.nlp.bert import export_tfhub from official.nlp.bert import export_tfhub
from official.nlp.configs import bert
from official.nlp.configs import encoders from official.nlp.configs import encoders
from official.nlp.data import tagging_data_loader
from official.nlp.tasks import tagging from official.nlp.tasks import tagging
...@@ -31,7 +31,7 @@ class TaggingTest(tf.test.TestCase): ...@@ -31,7 +31,7 @@ class TaggingTest(tf.test.TestCase):
super(TaggingTest, self).setUp() super(TaggingTest, self).setUp()
self._encoder_config = encoders.TransformerEncoderConfig( self._encoder_config = encoders.TransformerEncoderConfig(
vocab_size=30522, num_layers=1) vocab_size=30522, num_layers=1)
self._train_data_config = bert.TaggingDataConfig( self._train_data_config = tagging_data_loader.TaggingDataConfig(
input_path="dummy", seq_length=128, global_batch_size=1) input_path="dummy", seq_length=128, global_batch_size=1)
def _run_task(self, config): def _run_task(self, config):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment