"tests/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "565f0c88fc5cfbcfd5476da8146e92c793af4834"
Commit d665ab90 authored by Deyu Fu's avatar Deyu Fu
Browse files

improve backward compatibility

parent f06feced
import types
import torch import torch
import fused_adam_cuda import fused_adam_cuda
...@@ -65,7 +66,9 @@ class FusedAdam(torch.optim.Optimizer): ...@@ -65,7 +66,9 @@ class FusedAdam(torch.optim.Optimizer):
if grads is None: if grads is None:
grads_group = [None]*len(self.param_groups) grads_group = [None]*len(self.param_groups)
# backward compatibility # backward compatibility
# assuming a list of parameter means single group # assuming a list/generator of parameter means single group
elif isinstance(grads, types.GeneratorType):
grads_group = [grads]
elif type(grads[0])!=list: elif type(grads[0])!=list:
grads_group = [grads] grads_group = [grads]
else: else:
...@@ -73,6 +76,8 @@ class FusedAdam(torch.optim.Optimizer): ...@@ -73,6 +76,8 @@ class FusedAdam(torch.optim.Optimizer):
if output_params is None: if output_params is None:
output_params_group = [None]*len(self.param_groups) output_params_group = [None]*len(self.param_groups)
elif isinstance(output_params, types.GeneratorType):
output_params_group = [output_params]
elif type(output_params[0])!=list: elif type(output_params[0])!=list:
output_params_group = [output_params] output_params_group = [output_params]
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment