Unverified Commit 1a9e2c2d authored by Ziyue Jiang's avatar Ziyue Jiang Committed by GitHub
Browse files

[tensor] fix kwargs in colo_tensor torch_funtion (#825)

parent eb1b8990
......@@ -63,6 +63,6 @@ class ColoTensor(object):
kwargs = {}
kwargs = {
kwarg: kwargs[kwarg].torch_tensor() if isinstance(kwarg, ColoTensor) else kwarg for kwarg in kwargs
k: v.torch_tensor() if isinstance(v, ColoTensor) else v for k,v in kwargs.items()
}
return func(*args, **kwargs)
......@@ -59,14 +59,13 @@ def test_no_wrap_op():
t_ref = torch.randn(3, 5)
t = ColoTensor.init_from_torch_tensor(t_ref.clone())
assert torch.sum(t) == torch.sum(t_ref)
assert torch.sum(input=t) == torch.sum(input=t_ref)
def test_lazy_init_tensor():
lazy_t = ColoTensor((2, 3), dtype=torch.float32, requires_grad=True)
assert lazy_t._torch_tensor == None
assert lazy_t.torch_tensor().numel() == 6
if __name__ == '__main__':
test_lazy_init_tensor()
test_no_wrap_op()
# test_element_wise()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment