Commit c4451b7a authored by Chen Chen, committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 319267378
parent 7b5cb554
......@@ -74,45 +74,6 @@ def instantiate_bertpretrainer_from_cfg(
config.cls_heads))
@dataclasses.dataclass
class BertPretrainDataConfig(cfg.DataConfig):
  """Input-data settings for the BERT pretraining task (tasks/masked_lm).

  Extends `cfg.DataConfig` with pretraining-specific fields. The defaults
  below describe a training input (`is_training` is True).
  """
  # Fields redeclared from / shared with the base data config.
  input_path: str = ""
  global_batch_size: int = 512
  is_training: bool = True

  # Pretraining-specific fields.
  seq_length: int = 512
  max_predictions_per_seq: int = 76
  use_next_sentence_label: bool = True
  use_position_id: bool = False
@dataclasses.dataclass
class BertPretrainEvalDataConfig(BertPretrainDataConfig):
  """Eval-set data settings for the BERT pretraining task (tasks/masked_lm).

  Inherits every field of `BertPretrainDataConfig`; the three fields
  redeclared here carry the eval defaults, with `is_training` set to False.
  """
  input_path: str = ""
  global_batch_size: int = 512
  # Evaluation input, so training mode is off by default.
  is_training: bool = False
@dataclasses.dataclass
class SentencePredictionDataConfig(cfg.DataConfig):
  """Input-data settings for sentence prediction (tasks/sentence_prediction).

  The defaults describe a training input (`is_training` is True) with a
  global batch size of 32 and `seq_length` of 128.
  """
  input_path: str = ""
  global_batch_size: int = 32
  is_training: bool = True
  seq_length: int = 128
@dataclasses.dataclass
class SentencePredictionDevDataConfig(cfg.DataConfig):
  """Dev-set data settings for sentence prediction (tasks/sentence_prediction).

  Same field set as the training config plus `drop_remainder`; the defaults
  turn off both `is_training` and `drop_remainder`.
  """
  input_path: str = ""
  global_batch_size: int = 32
  is_training: bool = False
  seq_length: int = 128
  drop_remainder: bool = False
@dataclasses.dataclass
class QADataConfig(cfg.DataConfig):
"""Data config for question answering task (tasks/question_answering)."""
......@@ -137,22 +98,3 @@ class QADevDataConfig(cfg.DataConfig):
vocab_file: str = ""
tokenization: str = "WordPiece" # WordPiece or SentencePiece
do_lower_case: bool = True
@dataclasses.dataclass
class TaggingDataConfig(cfg.DataConfig):
  """Input-data settings for the tagging task (tasks/tagging).

  The defaults describe a training input (`is_training` is True) with a
  global batch size of 48 and `seq_length` of 384.
  """
  input_path: str = ""
  global_batch_size: int = 48
  is_training: bool = True
  seq_length: int = 384
@dataclasses.dataclass
class TaggingDevDataConfig(cfg.DataConfig):
  """Dev-set data settings for the tagging task (tasks/tagging).

  Same field set as the training config plus `drop_remainder`; the defaults
  turn off both `is_training` and `drop_remainder`.
  """
  input_path: str = ""
  global_batch_size: int = 48
  is_training: bool = False
  seq_length: int = 384
  drop_remainder: bool = False
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A global factory to access NLP registered data loaders."""
from official.utils import registry
# Module-level registry passed to `registry.register` and `registry.lookup`;
# entries are added through the `register_data_loader_cls` decorator.
_REGISTERED_DATA_LOADER_CLS = {}
def register_data_loader_cls(data_config_cls):
  """Builds a class decorator that registers a data loader for a config type.

  The returned decorator records the decorated class in the module-level
  registry under `data_config_cls`, so that `get_data_loader` can later
  create the loader from an instance of that config class.

  Example usage:

  ```
  @dataclasses.dataclass
  class MyDataConfig(DataConfig):
    # Add fields here.
    pass

  @register_data_loader_cls(MyDataConfig)
  class MyDataLoader:
    # Inherits def __init__(self, data_config).
    pass

  my_data_config = MyDataConfig()

  # Returns MyDataLoader(my_data_config).
  my_loader = get_data_loader(my_data_config)
  ```

  Args:
    data_config_cls: the DataConfig subclass itself (a class object, *not* an
      instance of DataConfig).

  Returns:
    A callable intended for use as a class decorator; it registers the
    decorated class for creation from an instance of `data_config_cls`.
  """
  class_decorator = registry.register(
      _REGISTERED_DATA_LOADER_CLS, data_config_cls)
  return class_decorator
def get_data_loader(data_config):
  """Instantiates the data loader registered for `data_config`'s class.

  Args:
    data_config: a config instance whose class was previously registered via
      `register_data_loader_cls`.

  Returns:
    The registered loader class called with `data_config` as its only
    argument.
  """
  # The registry is keyed on the exact class of the config instance.
  config_cls = data_config.__class__
  loader_cls = registry.lookup(_REGISTERED_DATA_LOADER_CLS, config_cls)
  return loader_cls(data_config)
......@@ -16,11 +16,27 @@
"""Loads dataset for the BERT pretraining task."""
from typing import Mapping, Optional
import dataclasses
import tensorflow as tf
from official.core import input_reader
from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.data import data_loader_factory
@dataclasses.dataclass
class BertPretrainDataConfig(cfg.DataConfig):
  """Input-data settings for the BERT pretraining task (tasks/masked_lm).

  Extends `cfg.DataConfig` with pretraining-specific fields. The defaults
  below describe a training input (`is_training` is True).
  """
  # Fields redeclared from / shared with the base data config.
  input_path: str = ''
  global_batch_size: int = 512
  is_training: bool = True

  # Pretraining-specific fields.
  seq_length: int = 512
  max_predictions_per_seq: int = 76
  use_next_sentence_label: bool = True
  use_position_id: bool = False
@data_loader_factory.register_data_loader_cls(BertPretrainDataConfig)
class BertPretrainDataLoader:
"""A class to load dataset for bert pretraining task."""
......@@ -91,7 +107,5 @@ class BertPretrainDataLoader:
def load(self, input_context: Optional[tf.distribute.InputContext] = None):
"""Returns a tf.data.Dataset."""
reader = input_reader.InputReader(
params=self._params,
decoder_fn=self._decode,
parser_fn=self._parse)
params=self._params, decoder_fn=self._decode, parser_fn=self._parse)
return reader.read(input_context)
......@@ -15,11 +15,24 @@
# ==============================================================================
"""Loads dataset for the sentence prediction (classification) task."""
from typing import Mapping, Optional
import dataclasses
import tensorflow as tf
from official.core import input_reader
from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.data import data_loader_factory
@dataclasses.dataclass
class SentencePredictionDataConfig(cfg.DataConfig):
  """Input-data settings for sentence prediction (tasks/sentence_prediction).

  The defaults describe a training input (`is_training` is True) with a
  global batch size of 32 and `seq_length` of 128.
  """
  input_path: str = ''
  global_batch_size: int = 32
  is_training: bool = True
  seq_length: int = 128
@data_loader_factory.register_data_loader_cls(SentencePredictionDataConfig)
class SentencePredictionDataLoader:
"""A class to load dataset for sentence prediction (classification) task."""
......
......@@ -15,15 +15,26 @@
# ==============================================================================
"""Loads dataset for the tagging (e.g., NER/POS) task."""
from typing import Mapping, Optional
import dataclasses
import tensorflow as tf
from official.core import input_reader
from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.data import data_loader_factory
@dataclasses.dataclass
class TaggingDataConfig(cfg.DataConfig):
  """Input-data settings for the tagging task (tasks/tagging).

  Only `is_training` and `seq_length` are declared here; every other field
  comes from `cfg.DataConfig`.
  """
  is_training: bool = True
  seq_length: int = 128
@data_loader_factory.register_data_loader_cls(TaggingDataConfig)
class TaggingDataLoader:
"""A class to load dataset for tagging (e.g., NER and POS) task."""
def __init__(self, params):
def __init__(self, params: TaggingDataConfig):
self._params = params
self._seq_length = params.seq_length
......
......@@ -20,7 +20,7 @@ import tensorflow as tf
from official.core import base_task
from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.configs import bert
from official.nlp.data import pretrain_dataloader
from official.nlp.data import data_loader_factory
@dataclasses.dataclass
......@@ -95,8 +95,7 @@ class MaskedLMTask(base_task.Task):
dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
return dataset
return pretrain_dataloader.BertPretrainDataLoader(params).load(
input_context)
return data_loader_factory.get_data_loader(params).load(input_context)
def build_metrics(self, training=None):
del training
......
......@@ -19,6 +19,7 @@ import tensorflow as tf
from official.nlp.configs import bert
from official.nlp.configs import encoders
from official.nlp.data import pretrain_dataloader
from official.nlp.tasks import masked_lm
......@@ -33,7 +34,7 @@ class MLMTaskTest(tf.test.TestCase):
bert.ClsHeadConfig(
inner_dim=10, num_classes=2, name="next_sentence")
]),
train_data=bert.BertPretrainDataConfig(
train_data=pretrain_dataloader.BertPretrainDataConfig(
input_path="dummy",
max_predictions_per_seq=20,
seq_length=128,
......
......@@ -25,7 +25,7 @@ import tensorflow_hub as hub
from official.core import base_task
from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.configs import bert
from official.nlp.data import sentence_prediction_dataloader
from official.nlp.data import data_loader_factory
from official.nlp.tasks import utils
......@@ -103,8 +103,7 @@ class SentencePredictionTask(base_task.Task):
dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
return dataset
return sentence_prediction_dataloader.SentencePredictionDataLoader(
params).load(input_context)
return data_loader_factory.get_data_loader(params).load(input_context)
def build_metrics(self, training=None):
del training
......
......@@ -24,6 +24,7 @@ from official.nlp.bert import configs
from official.nlp.bert import export_tfhub
from official.nlp.configs import bert
from official.nlp.configs import encoders
from official.nlp.data import sentence_prediction_dataloader
from official.nlp.tasks import sentence_prediction
......@@ -31,8 +32,9 @@ class SentencePredictionTaskTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super(SentencePredictionTaskTest, self).setUp()
self._train_data_config = bert.SentencePredictionDataConfig(
input_path="dummy", seq_length=128, global_batch_size=1)
self._train_data_config = (
sentence_prediction_dataloader.SentencePredictionDataConfig(
input_path="dummy", seq_length=128, global_batch_size=1))
def get_model_config(self, num_classes):
return bert.BertPretrainerConfig(
......
......@@ -27,7 +27,7 @@ import tensorflow_hub as hub
from official.core import base_task
from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.configs import encoders
from official.nlp.data import tagging_data_loader
from official.nlp.data import data_loader_factory
from official.nlp.modeling import models
from official.nlp.tasks import utils
......@@ -138,8 +138,7 @@ class TaggingTask(base_task.Task):
dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
return dataset
dataset = tagging_data_loader.TaggingDataLoader(params).load(input_context)
return dataset
return data_loader_factory.get_data_loader(params).load(input_context)
def validation_step(self, inputs, model: tf.keras.Model, metrics=None):
"""Validation step.
......
......@@ -20,8 +20,8 @@ import tensorflow as tf
from official.nlp.bert import configs
from official.nlp.bert import export_tfhub
from official.nlp.configs import bert
from official.nlp.configs import encoders
from official.nlp.data import tagging_data_loader
from official.nlp.tasks import tagging
......@@ -31,7 +31,7 @@ class TaggingTest(tf.test.TestCase):
super(TaggingTest, self).setUp()
self._encoder_config = encoders.TransformerEncoderConfig(
vocab_size=30522, num_layers=1)
self._train_data_config = bert.TaggingDataConfig(
self._train_data_config = tagging_data_loader.TaggingDataConfig(
input_path="dummy", seq_length=128, global_batch_size=1)
def _run_task(self, config):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment