"...text-generation-inference.git" did not exist on "6e3220529df5906ae586031873b7865e9923040b"
Commit d665ab90 authored by Deyu Fu's avatar Deyu Fu
Browse files

improve backward compatibility

parent f06feced
import types
import torch
import fused_adam_cuda
......@@ -65,7 +66,9 @@ class FusedAdam(torch.optim.Optimizer):
if grads is None:
grads_group = [None]*len(self.param_groups)
# backward compatibility
# assuming a list of parameter means single group
# assuming a list/generator of parameter means single group
elif isinstance(grads, types.GeneratorType):
grads_group = [grads]
elif type(grads[0])!=list:
grads_group = [grads]
else:
......@@ -73,6 +76,8 @@ class FusedAdam(torch.optim.Optimizer):
if output_params is None:
output_params_group = [None]*len(self.param_groups)
elif isinstance(output_params, types.GeneratorType):
output_params_group = [output_params]
elif type(output_params[0])!=list:
output_params_group = [output_params]
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment