Unverified Commit 4341f5e8 authored by Hongxin Liu's avatar Hongxin Liu Committed by GitHub
Browse files

[lazyinit] fix clone and deepcopy (#3553)

parent 1c7734bc
......@@ -14,8 +14,8 @@ from colossalai.tensor.d_tensor.layout import Layout
# reference: https://pytorch.org/cppdocs/notes/tensor_creation.html
_NORMAL_FACTORY = [
"arange",
"empty",
"full",
"empty",
"linspace",
"logspace",
"ones",
......@@ -324,7 +324,9 @@ class LazyTensor(torch.Tensor):
def clone(self) -> "LazyTensor":
def factory_fn():
return self.materialize().clone()
# if self is materialized, return self
new_tensor = self.materialize() if type(self) is LazyTensor else self
return new_tensor.clone()
target = LazyTensor(factory_fn, meta_data=self._meta_data)
......@@ -333,6 +335,26 @@ class LazyTensor(torch.Tensor):
def detach(self) -> Tensor:
return self
def __deepcopy__(self, memo):
if not self.is_leaf:
raise RuntimeError("Only Tensors created explicitly by the user "
"(graph leaves) support the deepcopy protocol at the moment")
if id(self) in memo:
return memo[id(self)]
def factory_fn():
# if self is materialized, return self
new_tensor = self.materialize() if type(self) is LazyTensor else self
copied = new_tensor.detach().clone()
if new_tensor.requires_grad:
copied.requires_grad_()
return copied
target = LazyTensor(factory_fn, meta_data=self._meta_data)
memo[id(self)] = target
return target
@property
def data(self):
return self
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment