Commit b6c5f6e6 authored by Vishnu Banna's avatar Vishnu Banna
Browse files

task update

parent c86a93db
......@@ -29,6 +29,7 @@ import dataclasses
MIN_LEVEL = 1
MAX_LEVEL = 7
GLOBAL_SEED = 1000
def _build_dict(min_level, max_level, value):
vals = {str(key): value for key in range(min_level, max_level + 1)}
......@@ -213,12 +214,13 @@ class YoloTask(cfg.TaskConfig):
init_checkpoint_modules: Union[
str, List[str]] = 'all' # all, backbone, and/or decoder
gradient_clip_norm: float = 0.0
seed = GLOBAL_SEED
COCO_INPUT_PATH_BASE = 'coco'
COCO_TRAIN_EXAMPLES = 118287
COCO_VAL_EXAMPLES = 5000
GLOBAL_SEED = 1000
@exp_factory.register_config_factory('yolo')
def yolo() -> cfg.ExperimentConfig:
......@@ -256,7 +258,6 @@ def yolo_darknet() -> cfg.ExperimentConfig:
train_data=DataConfig(
is_training=True,
global_batch_size=train_batch_size,
seed=GLOBAL_SEED,
dtype='float32',
parser=Parser(
letter_box=False,
......@@ -371,7 +372,6 @@ def scaled_yolo() -> cfg.ExperimentConfig:
train_data=DataConfig(
is_training=True,
global_batch_size=train_batch_size,
seed=GLOBAL_SEED,
dtype='float32',
parser=Parser(
aug_rand_saturation = 0.7,
......
......@@ -58,7 +58,7 @@ class YoloTask(base_task.Task):
self._metrics = []
# globally set the random seed
preprocessing_ops.set_random_seeds(seed=params.train_data.seed)
preprocessing_ops.set_random_seeds(seed=params.seed)
return
def build_model(self):
......@@ -109,6 +109,7 @@ class YoloTask(base_task.Task):
anchor_dict, level_limits = model.anchor_boxes.get(backbone.min_level,
backbone.max_level)
params.seed = self.task_config.seed
# set shared patamters between mosaic and yolo_input
base_config = dict(
letter_box=params.parser.letter_box,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment