Commit 2e9e0478 authored by Abdullah Rashwan's avatar Abdullah Rashwan Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 382240268
parent 1624201f
......@@ -78,6 +78,7 @@ class DataConfig(cfg.DataConfig):
parser: Parser = Parser()
shuffle_buffer_size: int = 10000
file_type: str = 'tfrecord'
drop_remainder: bool = True
@dataclasses.dataclass
......@@ -215,7 +216,8 @@ class Losses(hyperparams.Config):
class MaskRCNNTask(cfg.TaskConfig):
model: MaskRCNN = MaskRCNN()
train_data: DataConfig = DataConfig(is_training=True)
validation_data: DataConfig = DataConfig(is_training=False)
validation_data: DataConfig = DataConfig(is_training=False,
drop_remainder=False)
losses: Losses = Losses()
init_checkpoint: Optional[str] = None
init_checkpoint_modules: str = 'all' # all or backbone
......@@ -260,7 +262,8 @@ def fasterrcnn_resnetfpn_coco() -> cfg.ExperimentConfig:
validation_data=DataConfig(
input_path=os.path.join(COCO_INPUT_PATH_BASE, 'val*'),
is_training=False,
global_batch_size=eval_batch_size)),
global_batch_size=eval_batch_size,
drop_remainder=False)),
trainer=cfg.TrainerConfig(
train_steps=22500,
validation_steps=coco_val_samples // eval_batch_size,
......@@ -324,7 +327,8 @@ def maskrcnn_resnetfpn_coco() -> cfg.ExperimentConfig:
validation_data=DataConfig(
input_path=os.path.join(COCO_INPUT_PATH_BASE, 'val*'),
is_training=False,
global_batch_size=eval_batch_size)),
global_batch_size=eval_batch_size,
drop_remainder=False)),
trainer=cfg.TrainerConfig(
train_steps=22500,
validation_steps=coco_val_samples // eval_batch_size,
......@@ -401,7 +405,8 @@ def maskrcnn_spinenet_coco() -> cfg.ExperimentConfig:
validation_data=DataConfig(
input_path=os.path.join(COCO_INPUT_PATH_BASE, 'val*'),
is_training=False,
global_batch_size=eval_batch_size)),
global_batch_size=eval_batch_size,
drop_remainder=False)),
trainer=cfg.TrainerConfig(
train_steps=steps_per_epoch * 350,
validation_steps=coco_val_samples // eval_batch_size,
......@@ -486,7 +491,8 @@ def cascadercnn_spinenet_coco() -> cfg.ExperimentConfig:
validation_data=DataConfig(
input_path=os.path.join(COCO_INPUT_PATH_BASE, 'val*'),
is_training=False,
global_batch_size=eval_batch_size)),
global_batch_size=eval_batch_size,
drop_remainder=False)),
trainer=cfg.TrainerConfig(
train_steps=steps_per_epoch * 500,
validation_steps=coco_val_samples // eval_batch_size,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment