Commit 13dd0f7f authored by Jeremiah Harmsen, committed by A. Unique TensorFlower

Add Dataset module_import functionality to classifier data tools.  This allows local TF Datasets to be defined and used to generate classification data.

Complementary functionality to module_import in tensorflow_datasets/scripts/download_and_prepare.py

PiperOrigin-RevId: 306801627
parent a622ba7f
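For illustration, a hedged usage sketch of the new parameter (not part of this commit): it assumes tfds_params is a comma-separated list of key=value assignments, as the parsing code in the diff below suggests; the dataset name, the module path my_project.my_text_dataset, and the text_key/label_key values are made-up placeholders, and the import path of the patched library is also an assumption.

# Hedged usage sketch; all names below are illustrative assumptions.
from official.nlp.data import classifier_data_lib  # assumed location of the patched file

processor = classifier_data_lib.TfdsProcessor(
    tfds_params="dataset=my_text_dataset,"
                "module_import=my_project.my_text_dataset,"
                "text_key=sentence,"
                "label_key=label")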
@@ -20,6 +20,7 @@ from __future__ import print_function
 import collections
 import csv
+import importlib
 import os
 from absl import logging
@@ -403,6 +404,7 @@ class TfdsProcessor(DataProcessor):
   (TFDS) for the meaning of individual parameters):
     dataset: Required dataset name (potentially with subset and version number).
     data_dir: Optional TFDS source root directory.
+    module_import: Optional Dataset module to import.
     train_split: Name of the train split (defaults to `train`).
     dev_split: Name of the dev split (defaults to `validation`).
     test_split: Name of the test split (defaults to `test`).
@@ -418,6 +420,9 @@ class TfdsProcessor(DataProcessor):
                process_text_fn=tokenization.convert_to_unicode):
     super(TfdsProcessor, self).__init__(process_text_fn)
     self._process_tfds_params_str(tfds_params)
+    if self.module_import:
+      importlib.import_module(self.module_import)
+
     self.dataset, info = tfds.load(self.dataset_name, data_dir=self.data_dir,
                                    with_info=True)
     self._labels = list(range(info.features[self.label_key].num_classes))
@@ -428,6 +433,7 @@ class TfdsProcessor(DataProcessor):
     d = {k.strip(): v.strip() for k, v in tuples}
     self.dataset_name = d["dataset"]  # Required.
     self.data_dir = d.get("data_dir", None)
+    self.module_import = d.get("module_import", None)
     self.train_split = d.get("train_split", "train")
     self.dev_split = d.get("dev_split", "validation")
     self.test_split = d.get("test_split", "test")
...
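For context, a minimal sketch of the kind of locally defined TF Dataset that module_import makes visible: importing such a module registers the builder subclass with TFDS, so the tfds.load() call in the constructor above can then resolve the dataset by name. Everything here (module path, features, toy data) is an illustrative assumption rather than code from this commit, and it uses the TFDS GeneratorBasedBuilder API as it existed around the time of the change.

# Hypothetical my_project/my_text_dataset.py -- a toy local dataset definition.
# Subclassing tfds.core.GeneratorBasedBuilder registers the dataset with TFDS
# at import time, which is what module_import triggers.
import tensorflow_datasets as tfds

_TOY_TRAIN = [("great movie", "positive"), ("terrible plot", "negative")]
_TOY_DEV = [("pretty good", "positive")]
_TOY_TEST = [("not my thing", "negative")]


class MyTextDataset(tfds.core.GeneratorBasedBuilder):
  """Toy text-classification dataset, loadable as tfds.load("my_text_dataset")."""

  VERSION = tfds.core.Version("1.0.0")

  def _info(self):
    return tfds.core.DatasetInfo(
        builder=self,
        description="Toy sentiment data for illustration only.",
        features=tfds.features.FeaturesDict({
            "sentence": tfds.features.Text(),
            "label": tfds.features.ClassLabel(names=["negative", "positive"]),
        }),
        supervised_keys=("sentence", "label"),
    )

  def _split_generators(self, dl_manager):
    # A real dataset would read downloaded or local files via dl_manager.
    return [
        tfds.core.SplitGenerator(
            name=tfds.Split.TRAIN, gen_kwargs={"examples": _TOY_TRAIN}),
        tfds.core.SplitGenerator(
            name=tfds.Split.VALIDATION, gen_kwargs={"examples": _TOY_DEV}),
        tfds.core.SplitGenerator(
            name=tfds.Split.TEST, gen_kwargs={"examples": _TOY_TEST}),
    ]

  def _generate_examples(self, examples):
    # TFDS expects (key, example) pairs; the key only needs to be unique.
    for idx, (sentence, label) in enumerate(examples):
      yield idx, {"sentence": sentence, "label": label}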