Commit b6c5f6e6 authored by Vishnu Banna's avatar Vishnu Banna
Browse files

task update

parent c86a93db
...@@ -29,6 +29,7 @@ import dataclasses ...@@ -29,6 +29,7 @@ import dataclasses
MIN_LEVEL = 1 MIN_LEVEL = 1
MAX_LEVEL = 7 MAX_LEVEL = 7
GLOBAL_SEED = 1000
def _build_dict(min_level, max_level, value): def _build_dict(min_level, max_level, value):
vals = {str(key): value for key in range(min_level, max_level + 1)} vals = {str(key): value for key in range(min_level, max_level + 1)}
...@@ -213,12 +214,13 @@ class YoloTask(cfg.TaskConfig): ...@@ -213,12 +214,13 @@ class YoloTask(cfg.TaskConfig):
init_checkpoint_modules: Union[ init_checkpoint_modules: Union[
str, List[str]] = 'all' # all, backbone, and/or decoder str, List[str]] = 'all' # all, backbone, and/or decoder
gradient_clip_norm: float = 0.0 gradient_clip_norm: float = 0.0
seed = GLOBAL_SEED
COCO_INPUT_PATH_BASE = 'coco' COCO_INPUT_PATH_BASE = 'coco'
COCO_TRAIN_EXAMPLES = 118287 COCO_TRAIN_EXAMPLES = 118287
COCO_VAL_EXAMPLES = 5000 COCO_VAL_EXAMPLES = 5000
GLOBAL_SEED = 1000
@exp_factory.register_config_factory('yolo') @exp_factory.register_config_factory('yolo')
def yolo() -> cfg.ExperimentConfig: def yolo() -> cfg.ExperimentConfig:
...@@ -256,7 +258,6 @@ def yolo_darknet() -> cfg.ExperimentConfig: ...@@ -256,7 +258,6 @@ def yolo_darknet() -> cfg.ExperimentConfig:
train_data=DataConfig( train_data=DataConfig(
is_training=True, is_training=True,
global_batch_size=train_batch_size, global_batch_size=train_batch_size,
seed=GLOBAL_SEED,
dtype='float32', dtype='float32',
parser=Parser( parser=Parser(
letter_box=False, letter_box=False,
...@@ -371,7 +372,6 @@ def scaled_yolo() -> cfg.ExperimentConfig: ...@@ -371,7 +372,6 @@ def scaled_yolo() -> cfg.ExperimentConfig:
train_data=DataConfig( train_data=DataConfig(
is_training=True, is_training=True,
global_batch_size=train_batch_size, global_batch_size=train_batch_size,
seed=GLOBAL_SEED,
dtype='float32', dtype='float32',
parser=Parser( parser=Parser(
aug_rand_saturation = 0.7, aug_rand_saturation = 0.7,
......
...@@ -58,7 +58,7 @@ class YoloTask(base_task.Task): ...@@ -58,7 +58,7 @@ class YoloTask(base_task.Task):
self._metrics = [] self._metrics = []
# globally set the random seed # globally set the random seed
preprocessing_ops.set_random_seeds(seed=params.train_data.seed) preprocessing_ops.set_random_seeds(seed=params.seed)
return return
def build_model(self): def build_model(self):
...@@ -109,6 +109,7 @@ class YoloTask(base_task.Task): ...@@ -109,6 +109,7 @@ class YoloTask(base_task.Task):
anchor_dict, level_limits = model.anchor_boxes.get(backbone.min_level, anchor_dict, level_limits = model.anchor_boxes.get(backbone.min_level,
backbone.max_level) backbone.max_level)
params.seed = self.task_config.seed
# set shared patamters between mosaic and yolo_input # set shared patamters between mosaic and yolo_input
base_config = dict( base_config = dict(
letter_box=params.parser.letter_box, letter_box=params.parser.letter_box,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment