Unverified commit b7f5cb78, authored by marload, committed by GitHub

Early Return Pattern "if return else return" -> "if return return" (#197)

parent 675d73e0
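In short: when every branch of an if/elif/else chain ends in a return, the elif and else keywords carry no control-flow meaning, because each guarded return already exits the function. A minimal sketch of the rewrite this commit applies throughout (the sign function below is illustrative, not taken from the diff):

# Before: every branch returns, so elif/else only add indentation.
def sign(x):
    if x > 0:
        return 1
    elif x < 0:
        return -1
    else:
        return 0

# After: early returns flatten the chain; behavior is unchanged.
def sign(x):
    if x > 0:
        return 1
    if x < 0:
        return -1
    return 0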
@@ -271,11 +271,10 @@ def get_lr_from_config(config):
     if lr_schedule == LR_RANGE_TEST:
         return lr_params[LR_RANGE_TEST_MIN_LR], ''
-    elif lr_schedule == ONE_CYCLE:
+    if lr_schedule == ONE_CYCLE:
         return lr_params[CYCLE_MAX_LR], ''
-    else:
-        # Warmup LR
-        return lr_params[WARMUP_MAX_LR], ''
+    # Warmup LR
+    return lr_params[WARMUP_MAX_LR], ''
     """
@@ -624,8 +623,7 @@ class OneCycle(object):
         """
         if self.last_batch_iteration <= self.total_size:
             return self._get_cycle_lr()
-        else:
-            return self._get_decay_lr(self.last_batch_iteration - self.total_size)
+        return self._get_decay_lr(self.last_batch_iteration - self.total_size)
 
     def step(self, batch_iteration=None):
         if batch_iteration is None:
@@ -701,8 +699,7 @@ class WarmupLR(object):
     def _get_gamma(self):
         if self.last_batch_iteration < self.warmup_num_steps:
             return self.inverse_log_warm_up * math.log(self.last_batch_iteration + 1)
-        else:
-            return 1.0
+        return 1.0
 
     def _format_param(self, optimizer, param_value, param_name):
         if isinstance(param_value, list) or isinstance(param_value, tuple):
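For reference, a self-contained sketch of the warmup curve _get_gamma computes. The initialization of inverse_log_warm_up is an assumption here (taken as 1 / log(warmup_num_steps + 1), which makes gamma ramp from 0 to 1.0 over the warmup window); the diff above does not show how it is set.

import math

warmup_num_steps = 1000
# Assumed normalization so gamma reaches 1.0 at the end of warmup
# (not shown in the diff).
inverse_log_warm_up = 1.0 / math.log(warmup_num_steps + 1)

def gamma(step):
    # Log-shaped ramp during warmup, then a constant multiplier of 1.0.
    if step < warmup_num_steps:
        return inverse_log_warm_up * math.log(step + 1)
    return 1.0

print(gamma(0), gamma(500), gamma(1000))  # 0.0, ~0.90, 1.0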
@@ -712,5 +709,4 @@ class WarmupLR(object):
                     param_name,
                     FileNotFoundError(param_value)))
             return list(param_value)
-        else:
-            return [param_value] * len(optimizer.param_groups)
+        return [param_value] * len(optimizer.param_groups)
@@ -23,8 +23,7 @@ import torch
 def to_python_float(t):
     if hasattr(t, 'item'):
         return t.item()
-    else:
-        return t[0]
+    return t[0]
 
 
 class LossScaler:
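Since the commit is purely structural, the old and new function bodies should behave identically. A quick sanity check on to_python_float, reproducing both versions from the hunk above (assumes torch is installed; the test values are arbitrary):

import torch

def to_python_float_old(t):
    if hasattr(t, 'item'):
        return t.item()
    else:
        return t[0]

def to_python_float_new(t):
    if hasattr(t, 'item'):
        return t.item()
    return t[0]

# Same result for a tensor (takes the .item() branch)
# and a plain list (falls through to indexing).
for t in (torch.tensor(1.5), [2.5]):
    assert to_python_float_old(t) == to_python_float_new(t)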