Commit 44e7092c authored by stephenwu's avatar stephenwu
Browse files

Merge branch 'master' of https://github.com/tensorflow/models into AXg

parents 431a9ca3 59434199
......@@ -14,27 +14,16 @@
# limitations under the License.
# ==============================================================================
"""TFM continuous finetuning+eval training driver."""
import gc
import os
import time
from typing import Any, Mapping, Optional
from absl import app
from absl import flags
from absl import logging
import gin
import tensorflow as tf
# pylint: disable=unused-import
from official.common import registry_imports
# pylint: enable=unused-import
from official.common import distribute_utils
from official.common import flags as tfm_flags
from official.core import config_definitions
from official.core import task_factory
from official.core import train_lib
from official.core import train_utils
from official.modeling import performance
from official.nlp import continuous_finetune_lib
FLAGS = flags.FLAGS
......@@ -44,140 +33,15 @@ flags.DEFINE_integer(
help='The number of total training steps for the pretraining job.')
def run_continuous_finetune(
mode: str,
params: config_definitions.ExperimentConfig,
model_dir: str,
run_post_eval: bool = False,
pretrain_steps: Optional[int] = None,
) -> Mapping[str, Any]:
"""Run modes with continuous training.
Currently only supports continuous_train_and_eval.
Args:
mode: A 'str', specifying the mode. continuous_train_and_eval - monitors a
checkpoint directory. Once a new checkpoint is discovered, loads the
checkpoint, finetune the model by training it (probably on another dataset
or with another task), then evaluate the finetuned model.
params: ExperimentConfig instance.
model_dir: A 'str', a path to store model checkpoints and summaries.
run_post_eval: Whether to run post eval once after training, metrics logs
are returned.
pretrain_steps: Optional, the number of total training steps for the
pretraining job.
Returns:
eval logs: returns eval metrics logs when run_post_eval is set to True,
othewise, returns {}.
"""
assert mode == 'continuous_train_and_eval', (
'Only continuous_train_and_eval is supported by continuous_finetune. '
'Got mode: {}'.format(mode))
# Sets mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
# can have significant impact on model speeds by utilizing float16 in case of
# GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only when
# dtype is float16
if params.runtime.mixed_precision_dtype:
performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype,
params.runtime.loss_scale)
distribution_strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=params.runtime.distribution_strategy,
all_reduce_alg=params.runtime.all_reduce_alg,
num_gpus=params.runtime.num_gpus,
tpu_address=params.runtime.tpu)
retry_times = 0
while not tf.io.gfile.isdir(params.task.init_checkpoint):
# Wait for the init_checkpoint directory to be created.
if retry_times >= 60:
raise ValueError(
'ExperimentConfig.task.init_checkpoint must be a directory for '
'continuous_train_and_eval mode.')
retry_times += 1
time.sleep(60)
summary_writer = tf.summary.create_file_writer(
os.path.join(model_dir, 'eval'))
global_step = 0
def timeout_fn():
if pretrain_steps and global_step < pretrain_steps:
# Keeps waiting for another timeout period.
logging.info(
'Continue waiting for new checkpoint as current pretrain '
'global_step=%d and target is %d.', global_step, pretrain_steps)
return False
# Quits the loop.
return True
for pretrain_ckpt in tf.train.checkpoints_iterator(
checkpoint_dir=params.task.init_checkpoint,
min_interval_secs=10,
timeout=params.trainer.continuous_eval_timeout,
timeout_fn=timeout_fn):
with distribution_strategy.scope():
global_step = train_utils.read_global_step_from_checkpoint(pretrain_ckpt)
if params.trainer.best_checkpoint_export_subdir:
best_ckpt_subdir = '{}_{}'.format(
params.trainer.best_checkpoint_export_subdir, global_step)
params_replaced = params.replace(
task={'init_checkpoint': pretrain_ckpt},
trainer={'best_checkpoint_export_subdir': best_ckpt_subdir})
else:
params_replaced = params.replace(task={'init_checkpoint': pretrain_ckpt})
params_replaced.lock()
logging.info('Running finetuning with params: %s', params_replaced)
with distribution_strategy.scope():
task = task_factory.get_task(params_replaced.task, logging_dir=model_dir)
_, eval_metrics = train_lib.run_experiment(
distribution_strategy=distribution_strategy,
task=task,
mode='train_and_eval',
# replace params.task.init_checkpoint to make sure that we load
# exactly this pretrain checkpoint.
params=params_replaced,
model_dir=model_dir,
run_post_eval=True,
save_summary=False)
logging.info('Evaluation finished. Pretrain global_step: %d', global_step)
train_utils.write_json_summary(model_dir, global_step, eval_metrics)
if not os.path.basename(model_dir): # if model_dir.endswith('/')
summary_grp = os.path.dirname(model_dir) + '_' + task.name
else:
summary_grp = os.path.basename(model_dir) + '_' + task.name
summaries = {}
for name, value in eval_metrics.items():
summaries[summary_grp + '/' + name] = value
train_utils.write_summary(summary_writer, global_step, summaries)
train_utils.remove_ckpts(model_dir)
# In TF2, the resource life cycle is bound with the python object life
# cycle. Force trigger python garbage collection here so those resources
# can be deallocated in time, so it doesn't cause OOM when allocating new
# objects.
# TODO(b/169178664): Fix cycle reference in Keras model and revisit to see
# if we need gc here.
gc.collect()
if run_post_eval:
return eval_metrics
return {}
def main(_):
# TODO(b/177863554): consolidate to nlp/train.py
gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
params = train_utils.parse_configuration(FLAGS)
model_dir = FLAGS.model_dir
train_utils.serialize_config(params, model_dir)
run_continuous_finetune(FLAGS.mode, params, model_dir, FLAGS.pretrain_steps)
continuous_finetune_lib.run_continuous_finetune(FLAGS.mode, params, model_dir,
FLAGS.pretrain_steps)
train_utils.save_gin_config(FLAGS.mode, model_dir)
if __name__ == '__main__':
......
......@@ -14,17 +14,13 @@
# ==============================================================================
"""Binary to generate training/evaluation dataset for NCF model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
# pylint: disable=g-bad-import-order
# Import libraries
from absl import app
from absl import flags
import tensorflow.compat.v2 as tf
import tensorflow as tf
# pylint: enable=g-bad-import-order
from official.recommendation import movielens
......
......@@ -14,15 +14,10 @@
# ==============================================================================
"""NCF model input pipeline."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
# pylint: disable=g-bad-import-order
import tensorflow.compat.v2 as tf
import tensorflow as tf
# pylint: enable=g-bad-import-order
from official.recommendation import constants as rconst
......
......@@ -18,10 +18,6 @@ The NeuMF model assembles both MF and MLP models under the NCF framework. Check
`neumf_model.py` for more details about the models.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import os
......@@ -30,7 +26,7 @@ import os
from absl import app
from absl import flags
from absl import logging
import tensorflow.compat.v2 as tf
import tensorflow as tf
# pylint: enable=g-bad-import-order
from official.common import distribute_utils
......
......@@ -14,14 +14,9 @@
# ==============================================================================
"""Some gradient util functions to help users writing custom training loop."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
from absl import logging
import tensorflow.compat.v2 as tf
import tensorflow as tf
def _filter_grads(grads_and_vars):
......
# Adding Abseil (absl) flags quickstart
**WARNING** This module is deprecated. We no long use it in new models and
your projects should not depend on it. We will remove this module when
all models using it are deprecated which may take time.
## Defining a flag
absl flag definitions are similar to argparse, although they are defined on a global namespace.
......
"""A simple Python callstack sampler."""
import contextlib
import datetime
import signal
import traceback
class CallstackSampler(object):
"""A simple signal-based Python callstack sampler.
"""
def __init__(self, interval=None):
self.stacks = []
self.interval = 0.001 if interval is None else interval
def _sample(self, signum, frame):
"""Samples the current stack."""
del signum
stack = traceback.extract_stack(frame)
formatted_stack = []
formatted_stack.append(datetime.datetime.utcnow())
for filename, lineno, function_name, text in stack:
formatted_frame = '{}:{}({})({})'.format(filename, lineno, function_name,
text)
formatted_stack.append(formatted_frame)
self.stacks.append(formatted_stack)
signal.setitimer(signal.ITIMER_VIRTUAL, self.interval, 0)
@contextlib.contextmanager
def profile(self):
signal.signal(signal.SIGVTALRM, self._sample)
signal.setitimer(signal.ITIMER_VIRTUAL, self.interval, 0)
try:
yield
finally:
signal.setitimer(signal.ITIMER_VIRTUAL, 0)
def save(self, fname):
with open(fname, 'w') as f:
for s in self.stacks:
for l in s:
f.write('%s\n' % l)
f.write('\n')
@contextlib.contextmanager
def callstack_sampling(filename, interval=None):
"""Periodically samples the Python callstack.
Args:
filename: the filename
interval: the sampling interval, in seconds. Defaults to 0.001.
Yields:
nothing
"""
sampler = CallstackSampler(interval=interval)
with sampler.profile():
yield
sampler.save(filename)
......@@ -34,10 +34,10 @@ do_pylint() {
# --incremental Performs check on only the python files changed in the
# last non-merge git commit.
# Use this list to whitelist pylint errors
ERROR_WHITELIST=""
# Use this list to ALLOWLIST pylint errors
ERROR_ALLOWLIST=""
echo "ERROR_WHITELIST=\"${ERROR_WHITELIST}\""
echo "ERROR_ALLOWLIST=\"${ERROR_ALLOWLIST}\""
PYLINT_BIN="python3 -m pylint"
......@@ -92,16 +92,16 @@ do_pylint() {
N_ERRORS=0
while read -r LINE; do
IS_WHITELISTED=0
for WL_REGEX in ${ERROR_WHITELIST}; do
IS_ALLOWLISTED=0
for WL_REGEX in ${ERROR_ALLOWLIST}; do
if echo ${LINE} | grep -q "${WL_REGEX}"; then
echo "Found a whitelisted error:"
echo "Found a ALLOWLISTed error:"
echo " ${LINE}"
IS_WHITELISTED=1
IS_ALLOWLISTED=1
fi
done
if [[ ${IS_WHITELISTED} == "0" ]]; then
if [[ ${IS_ALLOWLISTED} == "0" ]]; then
echo "${LINE}" >> ${NONWL_ERRORS_FILE}
echo "" >> ${NONWL_ERRORS_FILE}
((N_ERRORS++))
......@@ -116,7 +116,7 @@ do_pylint() {
cat "${NONWL_ERRORS_FILE}"
return 1
else
echo "PASS: No non-whitelisted pylint errors were found."
echo "PASS: No non-ALLOWLISTed pylint errors were found."
return 0
fi
}
......
# Expect to reach: box mAP: 49.3%, mask mAP: 43.4% on COCO
runtime:
distribution_strategy: 'tpu'
mixed_precision_dtype: 'bfloat16'
task:
init_checkpoint: null
train_data:
global_batch_size: 256
parser:
aug_rand_hflip: true
aug_scale_min: 0.1
aug_scale_max: 2.0
losses:
l2_weight_decay: 0.00004
model:
anchor:
anchor_size: 4.0
num_scales: 3
min_level: 3
max_level: 7
input_size: [1280, 1280, 3]
backbone:
spinenet:
stochastic_depth_drop_rate: 0.2
model_id: '143'
type: 'spinenet'
decoder:
type: 'identity'
norm_activation:
norm_epsilon: 0.001
norm_momentum: 0.99
use_sync_bn: true
detection_generator:
pre_nms_top_k: 1000
trainer:
train_steps: 231000
optimizer_config:
learning_rate:
type: 'stepwise'
stepwise:
boundaries: [219450, 226380]
values: [0.32, 0.032, 0.0032]
warmup:
type: 'linear'
linear:
warmup_steps: 2000
warmup_learning_rate: 0.0067
# Expect to reach: box mAP: 43.2%, mask mAP: 38.3% on COCO
runtime:
distribution_strategy: 'tpu'
mixed_precision_dtype: 'bfloat16'
task:
init_checkpoint: null
train_data:
global_batch_size: 256
parser:
aug_rand_hflip: true
aug_scale_min: 0.1
aug_scale_max: 2.0
losses:
l2_weight_decay: 0.00004
model:
anchor:
anchor_size: 3.0
num_scales: 3
min_level: 3
max_level: 7
input_size: [640, 640, 3]
backbone:
spinenet:
stochastic_depth_drop_rate: 0.2
model_id: '49'
type: 'spinenet'
decoder:
type: 'identity'
norm_activation:
norm_epsilon: 0.001
norm_momentum: 0.99
use_sync_bn: true
detection_generator:
pre_nms_top_k: 1000
trainer:
train_steps: 231000
optimizer_config:
learning_rate:
type: 'stepwise'
stepwise:
boundaries: [219450, 226380]
values: [0.32, 0.032, 0.0032]
warmup:
type: 'linear'
linear:
warmup_steps: 2000
warmup_learning_rate: 0.0067
# Expect to reach: box mAP: 48.1%, mask mAP: 42.4% on COCO
runtime:
distribution_strategy: 'tpu'
mixed_precision_dtype: 'bfloat16'
task:
init_checkpoint: null
train_data:
global_batch_size: 256
parser:
aug_rand_hflip: true
aug_scale_min: 0.1
aug_scale_max: 2.0
losses:
l2_weight_decay: 0.00004
model:
anchor:
anchor_size: 3.0
num_scales: 3
min_level: 3
max_level: 7
input_size: [1024, 1024, 3]
backbone:
spinenet:
stochastic_depth_drop_rate: 0.2
model_id: '96'
type: 'spinenet'
decoder:
type: 'identity'
norm_activation:
norm_epsilon: 0.001
norm_momentum: 0.99
use_sync_bn: true
detection_generator:
pre_nms_top_k: 1000
trainer:
train_steps: 231000
optimizer_config:
learning_rate:
type: 'stepwise'
stepwise:
boundaries: [219450, 226380]
values: [0.32, 0.032, 0.0032]
warmup:
type: 'linear'
linear:
warmup_steps: 2000
warmup_learning_rate: 0.0067
# 3D ResNet-50 video classification on Kinetics-400. 75.3% top-1 and 91.2% top-5 accuracy.
# 3D ResNet-50 video classification on Kinetics-400.
#
# --experiment_type=video_classification_kinetics400
# Expected accuracy on TPU 8x8: 75.1%
# Updated: 2020-12-16
# Expected accuracy: 77.0% top-1, 93.0% top-5.
runtime:
distribution_strategy: 'tpu'
mixed_precision_dtype: 'bfloat16'
......@@ -15,45 +14,63 @@ task:
resnet_3d:
block_specs: !!python/tuple
- temporal_kernel_sizes: !!python/tuple
- 3
- 3
- 3
- 1
- 1
- 1
temporal_strides: 1
use_self_gating: true
use_self_gating: false
- temporal_kernel_sizes: !!python/tuple
- 3
- 1
- 3
- 1
- 1
- 1
temporal_strides: 1
use_self_gating: true
use_self_gating: false
- temporal_kernel_sizes: !!python/tuple
- 3
- 1
- 3
- 1
- 3
- 1
- 3
- 3
- 3
temporal_strides: 1
use_self_gating: true
use_self_gating: false
- temporal_kernel_sizes: !!python/tuple
- 1
- 3
- 1
- 3
- 3
temporal_strides: 1
use_self_gating: true
use_self_gating: false
model_id: 50
stem_conv_temporal_kernel_size: 5
stem_conv_temporal_stride: 2
stem_pool_temporal_stride: 2
stem_pool_temporal_stride: 1
train_data:
name: kinetics400
feature_shape: !!python/tuple
- 32
- 224
- 224
- 3
temporal_stride: 2
global_batch_size: 1024
dtype: 'bfloat16'
shuffle_buffer_size: 1024
aug_max_area_ratio: 1.0
aug_max_aspect_ratio: 2.0
aug_min_area_ratio: 0.08
aug_min_aspect_ratio: 0.5
validation_data:
name: kinetics400
global_batch_size: 32
feature_shape: !!python/tuple
- 32
- 256
- 256
- 3
temporal_stride: 2
num_test_clips: 10
num_test_crops: 3
global_batch_size: 64
dtype: 'bfloat16'
drop_remainder: false
trainer:
......@@ -61,11 +78,11 @@ trainer:
learning_rate:
cosine:
initial_learning_rate: 0.8
decay_steps: 42000
decay_steps: 42104
warmup:
linear:
warmup_steps: 1050
train_steps: 42000
warmup_steps: 1053
train_steps: 42104
steps_per_loop: 500
summary_interval: 500
validation_interval: 500
# SlowOnly 16x4 video classification on Kinetics-400.
#
# --experiment_type=video_classification_kinetics400
# Expected accuracy: 75.6% top-1, 92.1% top-5.
runtime:
distribution_strategy: 'tpu'
mixed_precision_dtype: 'bfloat16'
task:
model:
dropout_rate: 0.5
norm_activation:
use_sync_bn: false
backbone:
resnet_3d:
block_specs: !!python/tuple
- temporal_kernel_sizes: !!python/tuple
- 1
- 1
- 1
temporal_strides: 1
use_self_gating: false
- temporal_kernel_sizes: !!python/tuple
- 1
- 1
- 1
- 1
temporal_strides: 1
use_self_gating: false
- temporal_kernel_sizes: !!python/tuple
- 3
- 3
- 3
- 3
- 3
- 3
temporal_strides: 1
use_self_gating: false
- temporal_kernel_sizes: !!python/tuple
- 3
- 3
- 3
temporal_strides: 1
use_self_gating: false
model_id: 50
stem_conv_temporal_kernel_size: 1
stem_conv_temporal_stride: 1
stem_pool_temporal_stride: 1
train_data:
name: kinetics400
feature_shape: !!python/tuple
- 16
- 224
- 224
- 3
temporal_stride: 4
global_batch_size: 1024
dtype: 'bfloat16'
shuffle_buffer_size: 1024
aug_max_area_ratio: 1.0
aug_max_aspect_ratio: 2.0
aug_min_area_ratio: 0.08
aug_min_aspect_ratio: 0.5
validation_data:
name: kinetics400
feature_shape: !!python/tuple
- 16
- 256
- 256
- 3
temporal_stride: 4
num_test_clips: 10
num_test_crops: 3
global_batch_size: 64
dtype: 'bfloat16'
drop_remainder: false
trainer:
optimizer_config:
learning_rate:
cosine:
initial_learning_rate: 0.8
decay_steps: 42104
warmup:
linear:
warmup_steps: 1053
train_steps: 42104
steps_per_loop: 500
summary_interval: 500
validation_interval: 500
# SlowOnly video classification on Kinetics-400. Expected performance to be updated.
# SlowOnly 8x8 video classification on Kinetics-400.
#
# --experiment_type=video_classification_kinetics400
# Expected accuracy: 74.1% top-1, 91.4% top-5.
runtime:
distribution_strategy: 'tpu'
mixed_precision_dtype: 'bfloat16'
......@@ -53,6 +56,10 @@ task:
global_batch_size: 1024
dtype: 'bfloat16'
shuffle_buffer_size: 1024
aug_max_area_ratio: 1.0
aug_max_aspect_ratio: 2.0
aug_min_area_ratio: 0.08
aug_min_aspect_ratio: 0.5
validation_data:
name: kinetics400
feature_shape: !!python/tuple
......@@ -61,8 +68,9 @@ task:
- 256
- 3
temporal_stride: 8
num_test_clips: 1
global_batch_size: 32
num_test_clips: 10
num_test_crops: 3
global_batch_size: 64
dtype: 'bfloat16'
drop_remainder: false
trainer:
......
# 3D ResNet-50 video classification on Kinetics-600.
#
# --experiment_type=video_classification_kinetics600
# Expected accuracy: 79.5% top-1, 94.8% top-5.
runtime:
distribution_strategy: 'tpu'
mixed_precision_dtype: 'bfloat16'
task:
model:
dropout_rate: 0.5
norm_activation:
use_sync_bn: false
backbone:
resnet_3d:
block_specs: !!python/tuple
- temporal_kernel_sizes: !!python/tuple
- 1
- 1
- 1
temporal_strides: 1
use_self_gating: false
- temporal_kernel_sizes: !!python/tuple
- 1
- 1
- 1
- 1
temporal_strides: 1
use_self_gating: false
- temporal_kernel_sizes: !!python/tuple
- 3
- 3
- 3
- 3
- 3
- 3
temporal_strides: 1
use_self_gating: false
- temporal_kernel_sizes: !!python/tuple
- 3
- 3
- 3
temporal_strides: 1
use_self_gating: false
model_id: 50
stem_conv_temporal_kernel_size: 5
stem_conv_temporal_stride: 2
stem_pool_temporal_stride: 1
train_data:
name: kinetics600
feature_shape: !!python/tuple
- 32
- 224
- 224
- 3
temporal_stride: 2
global_batch_size: 1024
dtype: 'bfloat16'
shuffle_buffer_size: 1024
aug_max_area_ratio: 1.0
aug_max_aspect_ratio: 2.0
aug_min_area_ratio: 0.08
aug_min_aspect_ratio: 0.5
validation_data:
name: kinetics600
feature_shape: !!python/tuple
- 32
- 256
- 256
- 3
temporal_stride: 2
num_test_clips: 10
num_test_crops: 3
global_batch_size: 64
dtype: 'bfloat16'
drop_remainder: false
trainer:
optimizer_config:
learning_rate:
cosine:
initial_learning_rate: 0.8
decay_steps: 71488
warmup:
linear:
warmup_steps: 1787
train_steps: 71488
steps_per_loop: 500
summary_interval: 500
validation_interval: 500
# SlowOnly 8x8 video classification on Kinetics-600.
#
# --experiment_type=video_classification_kinetics600
# Expected accuracy: 77.3% top-1, 93.6% top-5.
runtime:
distribution_strategy: 'tpu'
mixed_precision_dtype: 'bfloat16'
task:
model:
dropout_rate: 0.5
norm_activation:
use_sync_bn: false
backbone:
resnet_3d:
block_specs: !!python/tuple
- temporal_kernel_sizes: !!python/tuple
- 1
- 1
- 1
temporal_strides: 1
use_self_gating: false
- temporal_kernel_sizes: !!python/tuple
- 1
- 1
- 1
- 1
temporal_strides: 1
use_self_gating: false
- temporal_kernel_sizes: !!python/tuple
- 3
- 3
- 3
- 3
- 3
- 3
temporal_strides: 1
use_self_gating: false
- temporal_kernel_sizes: !!python/tuple
- 3
- 3
- 3
temporal_strides: 1
use_self_gating: false
model_id: 50
stem_conv_temporal_kernel_size: 1
stem_conv_temporal_stride: 1
stem_pool_temporal_stride: 1
train_data:
name: kinetics600
feature_shape: !!python/tuple
- 8
- 224
- 224
- 3
temporal_stride: 8
global_batch_size: 1024
dtype: 'bfloat16'
shuffle_buffer_size: 1024
aug_max_area_ratio: 1.0
aug_max_aspect_ratio: 2.0
aug_min_area_ratio: 0.08
aug_min_aspect_ratio: 0.5
validation_data:
name: kinetics600
feature_shape: !!python/tuple
- 8
- 256
- 256
- 3
temporal_stride: 8
num_test_clips: 10
num_test_crops: 3
global_batch_size: 64
dtype: 'bfloat16'
drop_remainder: false
trainer:
optimizer_config:
learning_rate:
cosine:
initial_learning_rate: 0.8
decay_steps: 71488
warmup:
linear:
warmup_steps: 1787
train_steps: 71488
steps_per_loop: 500
summary_interval: 500
validation_interval: 500
......@@ -35,7 +35,7 @@ class DataConfig(cfg.DataConfig):
shuffle_buffer_size: int = 10000
cycle_length: int = 10
aug_policy: Optional[str] = None # None, 'autoaug', or 'randaug'
file_type: str = 'tfrecord' # tfrecord, or sstable
file_type: str = 'tfrecord'
@dataclasses.dataclass
......
......@@ -31,11 +31,13 @@ from official.vision.beta.configs import decoders
@dataclasses.dataclass
class TfExampleDecoder(hyperparams.Config):
regenerate_source_id: bool = False
mask_binarize_threshold: Optional[float] = None
@dataclasses.dataclass
class TfExampleDecoderLabelMap(hyperparams.Config):
regenerate_source_id: bool = False
mask_binarize_threshold: Optional[float] = None
label_map: str = ''
......@@ -73,6 +75,7 @@ class DataConfig(cfg.DataConfig):
decoder: DataDecoder = DataDecoder()
parser: Parser = Parser()
shuffle_buffer_size: int = 10000
file_type: str = 'tfrecord'
@dataclasses.dataclass
......
......@@ -68,6 +68,7 @@ class DataConfig(cfg.DataConfig):
decoder: DataDecoder = DataDecoder()
parser: Parser = Parser()
shuffle_buffer_size: int = 10000
file_type: str = 'tfrecord'
@dataclasses.dataclass
......
......@@ -51,7 +51,7 @@ class DataConfig(cfg.DataConfig):
aug_scale_max: float = 1.0
aug_rand_hflip: bool = True
drop_remainder: bool = True
file_type: str = 'tfrecord' # tfrecord, or sstable
file_type: str = 'tfrecord'
@dataclasses.dataclass
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment