Commit 7ce2e04f authored by Michael Carilli

Reorganizing fp16_optimizer

parent 31cee8e7
@@ -138,7 +138,7 @@ class FP16_Optimizer(object):
     the loss scale is not recommended.
 
     **Multi_GPU training**: If the wrapped ``init_optimizer`` was created from a model wrapped in
-    Pytorch DataParallel or DistributedDataParallel, :class:`FP16_Optimizer` should still work as
+    Pytorch DistributedDataParallel, :class:`FP16_Optimizer` should still work as
     intended.
     """
@@ -198,29 +198,11 @@ class FP16_Optimizer(object):
         self.overflow = False
         self.first_closure_call_this_step = True
 
-    # Promote optimizer.state, and optimizer.param_groups, to accommodate user code that
-    # directly manipulates "optimizer.param_groups" (for example, to adjust the learning rate).
-    def __getattribute__(self, name):
-        # I could condense the two cases by saying
-        # if name in ['state', 'param_groups']:
-        #     return self.optimizer.__dict__[name],
-        # but this would bypass self.optimizer's custom getters and setters, if it chose to define any.
-        # I could also use properties, as for loss_scale, but I don't know if properties bypass
-        # self.optimizer's custom getters and setters.
-        if name == 'state':
-            return self.optimizer.state
-        elif name == 'param_groups':
-            return self.optimizer.param_groups
-        else:
-            return object.__getattribute__(self, name)
-
-    def __setattr__(self, name, value):
-        if name == 'state':
-            self.optimizer.state = value
-        elif name == 'param_groups':
-            self.optimizer.param_groups = value
-        else:
-            object.__setattr__(self, name, value)
+    def __getstate__(self):
+        raise RuntimeError("FP16_Optimizer should be serialized using state_dict().")
+
+    def __setstate__(self, state):
+        raise RuntimeError("FP16_Optimizer should be deserialized using load_state_dict().")
 
     def zero_grad(self):
         """
@@ -250,6 +232,10 @@ class FP16_Optimizer(object):
     def _update_scale(self, has_overflow=False):
         self.loss_scaler.update_scale(has_overflow)
 
+    def _master_params_to_model_params(self):
+        for fp16_group, fp32_from_fp16_group in zip(self.fp16_groups, self.fp32_from_fp16_groups):
+            master_params_to_model_params(fp16_group, fp32_from_fp16_group)
+
     # To consider:  Integrate distributed with this wrapper by registering a hook on each variable
     # that does the overflow check, gradient copy + downscale, and fp32 allreduce in a different stream.
     def _model_grads_to_master_grads(self):
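For readers unfamiliar with the helper called in the added method, here is a minimal sketch of what master_params_to_model_params is assumed to do: copy each fp32 master weight back into the corresponding fp16 model weight, group by group. The real helper lives in apex's fp16 utilities and may differ in detail.

    def master_params_to_model_params(model_params, master_params):
        # Assumed behavior: overwrite each (fp16) model weight with its fp32 master copy.
        for model_param, master_param in zip(model_params, master_params):
            model_param.data.copy_(master_param.data)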
@@ -286,10 +272,6 @@ class FP16_Optimizer(object):
         else:
             return -1
 
-    def _master_params_to_model_params(self):
-        for fp16_group, fp32_from_fp16_group in zip(self.fp16_groups, self.fp32_from_fp16_groups):
-            master_params_to_model_params(fp16_group, fp32_from_fp16_group)
-
     def state_dict(self):
         """
         Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.
@@ -540,7 +522,8 @@ class FP16_Optimizer(object):
            return None
        else:
            return None
 
+    # Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale"
     def _get_loss_scale(self):
         return self.loss_scaler.loss_scale
@@ -548,3 +531,27 @@ class FP16_Optimizer(object):
         self.loss_scaler.cur_scale = value
 
     loss_scale = property(_get_loss_scale, _set_loss_scale)
+
+    # Promote optimizer.state, and optimizer.param_groups, to accommodate user code that
+    # directly manipulates "optimizer.param_groups" (for example, to adjust the learning rate).
+    def __getattribute__(self, name):
+        # I could condense the two cases by saying
+        # if name in ['state', 'param_groups']:
+        #     return self.optimizer.__dict__[name],
+        # but this would bypass self.optimizer's custom getters and setters, if it chose to define any.
+        # I could also use properties, as for loss_scale, but I don't know if properties bypass
+        # self.optimizer's custom getters and setters.
+        if name == 'state':
+            return self.optimizer.state
+        elif name == 'param_groups':
+            return self.optimizer.param_groups
+        else:
+            return object.__getattribute__(self, name)
+
+    def __setattr__(self, name, value):
+        if name == 'state':
+            self.optimizer.state = value
+        elif name == 'param_groups':
+            self.optimizer.param_groups = value
+        else:
+            object.__setattr__(self, name, value)
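The promoted attributes are what let existing learning-rate schedules and loss-scale tweaks keep operating on the wrapper directly, as the comment above describes. A short illustration (the fp16_optimizer variable is illustrative):

    # 'param_groups' is forwarded to the wrapped optimizer, so LR-schedule code keeps working.
    for group in fp16_optimizer.param_groups:
        group['lr'] *= 0.1

    # 'loss_scale' is exposed through the property defined above.
    print(fp16_optimizer.loss_scale)
    fp16_optimizer.loss_scale = 128.0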