"git@developer.sourcefind.cn:change/sglang.git" did not exist on "54b9a2de0a457709607d6df917d5e6ac5004f72b"
Commit fce888f4 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Introduce the abstract class for dataloaders. load() is required.

PiperOrigin-RevId: 324659665
parent f4558be1
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""An abstraction that NLP models define input pipelines."""
import abc
from typing import Optional
import tensorflow as tf
class DataLoader(metaclass=abc.ABCMeta):
"""An abstract class defining the APIs for tf.data input pipeline."""
@abc.abstractmethod
def load(
self,
input_context: Optional[tf.distribute.InputContext] = None
) -> tf.data.Dataset:
"""Implements DataLoader load method.
Builds the entire input pipeline inside the load method. Users can define
states inside the DataLoader class and returns a tf.data dataset
object.
Args:
input_context: This is a context class that is passed to the user's input
function and contains information about the compute replicas and input
pipelines. This object is used for multi-host inputs and passed by
the distribution strategy.
Returns:
A per-host tf.data dataset. Note that, we usually create the distributed
dataset through the load method, so we should not directly return a
distributed dataset here.
"""
pass
...@@ -21,6 +21,7 @@ import tensorflow as tf ...@@ -21,6 +21,7 @@ import tensorflow as tf
from official.core import input_reader from official.core import input_reader
from official.modeling.hyperparams import config_definitions as cfg from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.data import data_loader
from official.nlp.data import data_loader_factory from official.nlp.data import data_loader_factory
...@@ -37,7 +38,7 @@ class BertPretrainDataConfig(cfg.DataConfig): ...@@ -37,7 +38,7 @@ class BertPretrainDataConfig(cfg.DataConfig):
@data_loader_factory.register_data_loader_cls(BertPretrainDataConfig) @data_loader_factory.register_data_loader_cls(BertPretrainDataConfig)
class BertPretrainDataLoader: class BertPretrainDataLoader(data_loader.DataLoader):
"""A class to load dataset for bert pretraining task.""" """A class to load dataset for bert pretraining task."""
def __init__(self, params): def __init__(self, params):
......
...@@ -20,6 +20,7 @@ import tensorflow as tf ...@@ -20,6 +20,7 @@ import tensorflow as tf
from official.core import input_reader from official.core import input_reader
from official.modeling.hyperparams import config_definitions as cfg from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.data import data_loader
from official.nlp.data import data_loader_factory from official.nlp.data import data_loader_factory
...@@ -42,7 +43,7 @@ class QADataConfig(cfg.DataConfig): ...@@ -42,7 +43,7 @@ class QADataConfig(cfg.DataConfig):
@data_loader_factory.register_data_loader_cls(QADataConfig) @data_loader_factory.register_data_loader_cls(QADataConfig)
class QuestionAnsweringDataLoader: class QuestionAnsweringDataLoader(data_loader.DataLoader):
"""A class to load dataset for sentence prediction (classification) task.""" """A class to load dataset for sentence prediction (classification) task."""
def __init__(self, params): def __init__(self, params):
......
...@@ -20,6 +20,7 @@ import tensorflow as tf ...@@ -20,6 +20,7 @@ import tensorflow as tf
from official.core import input_reader from official.core import input_reader
from official.modeling.hyperparams import config_definitions as cfg from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.data import data_loader
from official.nlp.data import data_loader_factory from official.nlp.data import data_loader_factory
...@@ -37,7 +38,7 @@ class SentencePredictionDataConfig(cfg.DataConfig): ...@@ -37,7 +38,7 @@ class SentencePredictionDataConfig(cfg.DataConfig):
@data_loader_factory.register_data_loader_cls(SentencePredictionDataConfig) @data_loader_factory.register_data_loader_cls(SentencePredictionDataConfig)
class SentencePredictionDataLoader: class SentencePredictionDataLoader(data_loader.DataLoader):
"""A class to load dataset for sentence prediction (classification) task.""" """A class to load dataset for sentence prediction (classification) task."""
def __init__(self, params): def __init__(self, params):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment