Unverified Commit 1a9e2c2d authored by Ziyue Jiang's avatar Ziyue Jiang Committed by GitHub
Browse files

[tensor] fix kwargs in colo_tensor torch_function (#825)

parent eb1b8990
...@@ -63,6 +63,6 @@ class ColoTensor(object): ...@@ -63,6 +63,6 @@ class ColoTensor(object):
kwargs = {} kwargs = {}
kwargs = { kwargs = {
kwarg: kwargs[kwarg].torch_tensor() if isinstance(kwarg, ColoTensor) else kwarg for kwarg in kwargs k: v.torch_tensor() if isinstance(v, ColoTensor) else v for k,v in kwargs.items()
} }
return func(*args, **kwargs) return func(*args, **kwargs)
...@@ -59,14 +59,13 @@ def test_no_wrap_op(): ...@@ -59,14 +59,13 @@ def test_no_wrap_op():
t_ref = torch.randn(3, 5) t_ref = torch.randn(3, 5)
t = ColoTensor.init_from_torch_tensor(t_ref.clone()) t = ColoTensor.init_from_torch_tensor(t_ref.clone())
assert torch.sum(t) == torch.sum(t_ref) assert torch.sum(t) == torch.sum(t_ref)
assert torch.sum(input=t) == torch.sum(input=t_ref)
def test_lazy_init_tensor(): def test_lazy_init_tensor():
lazy_t = ColoTensor((2, 3), dtype=torch.float32, requires_grad=True) lazy_t = ColoTensor((2, 3), dtype=torch.float32, requires_grad=True)
assert lazy_t._torch_tensor == None assert lazy_t._torch_tensor == None
assert lazy_t.torch_tensor().numel() == 6 assert lazy_t.torch_tensor().numel() == 6
if __name__ == '__main__': if __name__ == '__main__':
test_lazy_init_tensor() test_no_wrap_op()
# test_element_wise() # test_element_wise()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment