Unverified Commit 16302a53 authored by Frank Lee's avatar Frank Lee Committed by GitHub
Browse files

[fx] added unit test for coloproxy (#1119)

* [fx] added unit test for coloproxy

* polish code

* polish code
parent 7d14b473
...@@ -19,16 +19,16 @@ class ColoProxy(Proxy): ...@@ -19,16 +19,16 @@ class ColoProxy(Proxy):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.meta_tensor = None self._meta_tensor = None
@property @property
def meta_tensor(self): def meta_tensor(self):
return self.meta_tensor return self._meta_tensor
@meta_tensor.setter @meta_tensor.setter
def meta_tensor(self, tensor: torch.Tensor): def meta_tensor(self, tensor: torch.Tensor):
assert tensor.is_meta, 'Expected to receive a meta tensor, but got a non-meta tensor' assert tensor is None or tensor.is_meta, 'Expected to receive a meta tensor, but got a non-meta tensor'
self.meta_tensor = tensor self._meta_tensor = tensor
@property @property
def has_meta_tensor(self): def has_meta_tensor(self):
...@@ -42,6 +42,19 @@ class ColoProxy(Proxy): ...@@ -42,6 +42,19 @@ class ColoProxy(Proxy):
self._assert_has_meta() self._assert_has_meta()
return self.meta_tensor.dtype return self.meta_tensor.dtype
@property
def shape(self):
self._assert_has_meta()
return self.meta_tensor.shape
def dim(self):
self._assert_has_meta()
return self.meta_tensor.dim()
def size(self, dim: int = None):
self._assert_has_meta()
return self.meta_tensor.size(dim=dim)
def __len__(self): def __len__(self):
self._assert_has_meta() self._assert_has_meta()
return len(self.meta_tensor) return len(self.meta_tensor)
......
import torch
from colossalai.fx.proxy import ColoProxy
def test_coloproxy():
# create a dummy node only for testing purpose
model = torch.nn.Linear(10, 10)
gm = torch.fx.symbolic_trace(model)
node = list(gm.graph.nodes)[0]
# create proxy
proxy = ColoProxy(node=node)
proxy.meta_tensor = torch.empty(4, 2, device='meta')
assert len(proxy) == 4
assert proxy.shape[0] == 4 and proxy.shape[1] == 2
assert proxy.dim() == 2
assert proxy.dtype == torch.float32
assert proxy.size(0) == 4
if __name__ == '__main__':
test_coloproxy()
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment