Unverified Commit aca9778e authored by Aman Gupta, committed by GitHub

Make minor improvements to optimizer.py (#1687)

parent fd2949ab
@@ -64,9 +64,9 @@ class GlobalOptimManager:
         parameters (`torch.Tensor` or `list(torch.Tensors)`):
             The input parameters.
         key (`str`):
-            The hyperparamter to override.
+            The hyperparameter to override.
         value:
-            The hyperparameter values.
+            The hyperparameter value.
         key_value_dict (`dict`):
             A dictionary with multiple key-values to override.
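The docstring above belongs to `GlobalOptimManager.override_config`. For context, a minimal sketch of how this override API is typically used, assuming the standard `bnb.optim.GlobalOptimManager` entry points (`get_instance`, `register_parameters`, `override_config`); the model and layer choice here are illustrative, not part of this commit:

```python
import torch
import bitsandbytes as bnb

# Illustrative model; any torch.nn.Module works.
model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 8))

# Parameters must be registered before the optimizer is constructed.
mng = bnb.optim.GlobalOptimManager.get_instance()
mng.register_parameters(model.parameters())

# Single-key form: keep this layer's optimizer state in 32-bit.
mng.override_config(model[1].weight, key="optim_bits", value=32)

# Equivalent multi-key form using key_value_dict.
mng.override_config(model[1].weight, key_value_dict={"optim_bits": 32})

optimizer = bnb.optim.Adam8bit(model.parameters(), lr=1e-3)
```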
@@ -115,7 +115,7 @@ class Optimizer8bit(torch.optim.Optimizer):
     Base 8-bit optimizer class.

     Arguments:
-        params (`torch.tensor`):
+        params (`torch.Tensor`):
             The input parameters to optimize.
         optim_bits (`int`, defaults to 32):
             The number of bits of the optimizer state.
@@ -291,7 +291,7 @@ class Optimizer8bit(torch.optim.Optimizer):
                     self.update_step(group, p, gindex, pindex)
                     torch.cuda.synchronize()
         if self.is_paged:
-            # all paged operation are asynchronous, we need
+            # all paged operations are asynchronous, we need
             # to sync to make sure all tensors are in the right state
             torch.cuda.synchronize()
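The comment being fixed documents a real constraint: paged optimizer state lives in paged memory and the update kernels run asynchronously, so the host must synchronize before it can safely read that state. A minimal sketch of the pattern; the wrapper function is hypothetical, only the `is_paged` flag and the synchronize call come from the code above:

```python
import torch

def step_and_sync(optimizer):
    # Hypothetical wrapper mirroring the pattern in the hunk above:
    # paged optimizer updates are asynchronous, so synchronize before
    # anything on the host reads or checkpoints the optimizer state.
    loss = optimizer.step()
    if getattr(optimizer, "is_paged", False):
        torch.cuda.synchronize()  # ensure all state tensors are settled
    return loss
```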
@@ -371,7 +371,7 @@ class Optimizer2State(Optimizer8bit):
     Arguments:
         optimizer_name (`str`):
             The name of the optimizer.
-        params (`torch.tensor`):
+        params (`torch.Tensor`):
             The input parameters to optimize.
         lr (`float`, defaults to 1e-3):
             The learning rate.
@@ -428,7 +428,6 @@ class Optimizer2State(Optimizer8bit):
         if args is None:
             args = {}
             args["optim_bits"] = optim_bits
-            args["percentile_clipping"] = 100
             args["min_8bit_size"] = min_8bit_size
             args["percentile_clipping"] = percentile_clipping
             args["block_wise"] = block_wise
@@ -613,7 +612,7 @@ class Optimizer1State(Optimizer8bit):
     Arguments:
         optimizer_name (`str`):
             The name of the optimizer.
-        params (`torch.tensor`):
+        params (`torch.Tensor`):
             The input parameters to optimize.
         lr (`float`, defaults to 1e-3):
             The learning rate.
@@ -655,7 +654,6 @@ class Optimizer1State(Optimizer8bit):
         if args is None:
             args = {}
             args["optim_bits"] = optim_bits
-            args["percentile_clipping"] = 100
             args["min_8bit_size"] = min_8bit_size
             args["percentile_clipping"] = percentile_clipping
             args["block_wise"] = block_wise
...