Commit 00707082 authored by Le Hou's avatar Le Hou Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 390025168
parent 7bd3c937
......@@ -24,10 +24,17 @@ from typing import Any, Mapping
from absl import logging
import six
import tensorflow as tf
from tensorflow.python.eager import monitoring
from official.modeling.fast_training.progressive import utils
from official.modeling.hyperparams import base_config
_progressive_policy_creation_counter = monitoring.Counter(
'/tensorflow/training/fast_training/progressive_policy_creation',
'Counter for the number of ProgressivePolicy creations.')
@dataclasses.dataclass
class ProgressiveConfig(base_config.Config):
pass
......@@ -69,6 +76,8 @@ class ProgressivePolicy:
optimizer=self.get_optimizer(stage_id),
model=self.get_model(stage_id, old_model=None))
_progressive_policy_creation_counter.get_cell().increase_by(1)
def compute_stage_id(self, global_step: int) -> int:
for stage_id in range(self.num_stages()):
global_step -= self.num_steps(stage_id)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment