"vscode:/vscode.git/clone" did not exist on "cf156afa685f8e211cd77e1360bfed215f7ed6a0"
Unverified Commit b7f5cb78 authored by marload, committed by GitHub

Early Return Pattern "if return else return" -> "if return return" (#197)

parent 675d73e0
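
The refactor is mechanical: once a branch of an if/elif/else chain ends in a return, the trailing elif/else adds nothing, so the remaining cases can be flattened to the outer indent level. A minimal sketch of the before/after shape (the function name and values are illustrative only, not taken from the files touched below):

# Before: every branch returns, yet elif/else keeps the chain nested.
def sign_before(x):
    if x < 0:
        return -1
    elif x == 0:
        return 0
    else:
        return 1

# After: each earlier case has already returned by the time the next
# check runs, so plain if statements with a final fall-through return
# are equivalent and one level flatter.
def sign_after(x):
    if x < 0:
        return -1
    if x == 0:
        return 0
    return 1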
@@ -271,9 +271,8 @@ def get_lr_from_config(config):
     if lr_schedule == LR_RANGE_TEST:
         return lr_params[LR_RANGE_TEST_MIN_LR], ''
-    elif lr_schedule == ONE_CYCLE:
+    if lr_schedule == ONE_CYCLE:
         return lr_params[CYCLE_MAX_LR], ''
-    else:
-        # Warmup LR
-        return lr_params[WARMUP_MAX_LR], ''
+    # Warmup LR
+    return lr_params[WARMUP_MAX_LR], ''
@@ -624,7 +623,6 @@ class OneCycle(object):
         """
         if self.last_batch_iteration <= self.total_size:
             return self._get_cycle_lr()
-        else:
-            return self._get_decay_lr(self.last_batch_iteration - self.total_size)
+        return self._get_decay_lr(self.last_batch_iteration - self.total_size)
 
     def step(self, batch_iteration=None):
@@ -701,7 +699,6 @@ class WarmupLR(object):
     def _get_gamma(self):
         if self.last_batch_iteration < self.warmup_num_steps:
             return self.inverse_log_warm_up * math.log(self.last_batch_iteration + 1)
-        else:
-            return 1.0
+        return 1.0
 
     def _format_param(self, optimizer, param_value, param_name):
@@ -712,5 +709,4 @@ class WarmupLR(object):
                 param_name,
                 FileNotFoundError(param_value)))
             return list(param_value)
-        else:
-            return [param_value] * len(optimizer.param_groups)
+        return [param_value] * len(optimizer.param_groups)
@@ -23,7 +23,6 @@ import torch
 def to_python_float(t):
     if hasattr(t, 'item'):
         return t.item()
-    else:
-        return t[0]
+    return t[0]