Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
16302a53
Unverified
Commit
16302a53
authored
Jun 15, 2022
by
Frank Lee
Committed by
GitHub
Jun 15, 2022
Browse files
[fx] added unit test for coloproxy (#1119)
* [fx] added unit test for coloproxy * polish code * polish code
parent
7d14b473
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
40 additions
and
4 deletions
+40
-4
colossalai/fx/proxy.py
colossalai/fx/proxy.py
+17
-4
tests/test_fx/test_coloproxy.py
tests/test_fx/test_coloproxy.py
+23
-0
No files found.
colossalai/fx/proxy.py
View file @
16302a53
...
...
@@ -19,16 +19,16 @@ class ColoProxy(Proxy):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
meta_tensor
=
None
self
.
_
meta_tensor
=
None
@
property
def
meta_tensor
(
self
):
return
self
.
meta_tensor
return
self
.
_
meta_tensor
@
meta_tensor
.
setter
def
meta_tensor
(
self
,
tensor
:
torch
.
Tensor
):
assert
tensor
.
is_meta
,
'Expected to receive a meta tensor, but got a non-meta tensor'
self
.
meta_tensor
=
tensor
assert
tensor
is
None
or
tensor
.
is_meta
,
'Expected to receive a meta tensor, but got a non-meta tensor'
self
.
_
meta_tensor
=
tensor
@
property
def
has_meta_tensor
(
self
):
...
...
@@ -42,6 +42,19 @@ class ColoProxy(Proxy):
self
.
_assert_has_meta
()
return
self
.
meta_tensor
.
dtype
@
property
def
shape
(
self
):
self
.
_assert_has_meta
()
return
self
.
meta_tensor
.
shape
def
dim
(
self
):
self
.
_assert_has_meta
()
return
self
.
meta_tensor
.
dim
()
def
size
(
self
,
dim
:
int
=
None
):
self
.
_assert_has_meta
()
return
self
.
meta_tensor
.
size
(
dim
=
dim
)
def
__len__
(
self
):
self
.
_assert_has_meta
()
return
len
(
self
.
meta_tensor
)
...
...
tests/test_fx/test_coloproxy.py
0 → 100644
View file @
16302a53
import
torch
from
colossalai.fx.proxy
import
ColoProxy
def
test_coloproxy
():
# create a dummy node only for testing purpose
model
=
torch
.
nn
.
Linear
(
10
,
10
)
gm
=
torch
.
fx
.
symbolic_trace
(
model
)
node
=
list
(
gm
.
graph
.
nodes
)[
0
]
# create proxy
proxy
=
ColoProxy
(
node
=
node
)
proxy
.
meta_tensor
=
torch
.
empty
(
4
,
2
,
device
=
'meta'
)
assert
len
(
proxy
)
==
4
assert
proxy
.
shape
[
0
]
==
4
and
proxy
.
shape
[
1
]
==
2
assert
proxy
.
dim
()
==
2
assert
proxy
.
dtype
==
torch
.
float32
assert
proxy
.
size
(
0
)
==
4
if
__name__
==
'__main__'
:
test_coloproxy
()
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment