Unverified Commit aca9778e authored by Aman Gupta's avatar Aman Gupta Committed by GitHub
Browse files

Make minor improvements to optimizer.py (#1687)

parent fd2949ab
......@@ -64,9 +64,9 @@ class GlobalOptimManager:
parameters (`torch.Tensor` or `list(torch.Tensors)`):
The input parameters.
key (`str`):
The hyperparamter to override.
The hyperparameter to override.
value:
The hyperparameter values.
The hyperparameter value.
key_value_dict (`dict`):
A dictionary with multiple key-values to override.
......@@ -115,7 +115,7 @@ class Optimizer8bit(torch.optim.Optimizer):
Base 8-bit optimizer class.
Arguments:
params (`torch.tensor`):
params (`torch.Tensor`):
The input parameters to optimize.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
......@@ -291,7 +291,7 @@ class Optimizer8bit(torch.optim.Optimizer):
self.update_step(group, p, gindex, pindex)
torch.cuda.synchronize()
if self.is_paged:
# all paged operation are asynchronous, we need
# all paged operations are asynchronous, we need
# to sync to make sure all tensors are in the right state
torch.cuda.synchronize()
......@@ -371,7 +371,7 @@ class Optimizer2State(Optimizer8bit):
Arguments:
optimizer_name (`str`):
The name of the optimizer.
params (`torch.tensor`):
params (`torch.Tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-3):
The learning rate.
......@@ -428,7 +428,6 @@ class Optimizer2State(Optimizer8bit):
if args is None:
args = {}
args["optim_bits"] = optim_bits
args["percentile_clipping"] = 100
args["min_8bit_size"] = min_8bit_size
args["percentile_clipping"] = percentile_clipping
args["block_wise"] = block_wise
......@@ -613,7 +612,7 @@ class Optimizer1State(Optimizer8bit):
Arguments:
optimizer_name (`str`):
The name of the optimizer.
params (`torch.tensor`):
params (`torch.Tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-3):
The learning rate.
......@@ -655,7 +654,6 @@ class Optimizer1State(Optimizer8bit):
if args is None:
args = {}
args["optim_bits"] = optim_bits
args["percentile_clipping"] = 100
args["min_8bit_size"] = min_8bit_size
args["percentile_clipping"] = percentile_clipping
args["block_wise"] = block_wise
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment