Commit ff969b98 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Update type checks in tensor_tree_map

parent f34bef8e
......@@ -92,16 +92,16 @@ def dict_map(fn, dic, leaf_type):
def tree_map(fn, tree, leaf_type):
tree_type = type(tree)
if(tree_type is dict):
if(isinstance(tree, dict)):
return dict_map(fn, tree, leaf_type)
elif(tree_type is list):
elif(isinstance(tree, list)):
return [tree_map(fn, x, leaf_type) for x in tree]
elif(tree_type is tuple):
elif(isinstance(tree, tuple)):
return tuple([tree_map(fn, x, leaf_type) for x in tree])
elif(tree_type is leaf_type):
elif(isinstance(tree, leaf_type)):
return fn(tree)
else:
print(type(tree))
raise ValueError("Not supported")
tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment