Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
e1620dda
Unverified
Commit
e1620dda
authored
Jun 15, 2022
by
Frank Lee
Committed by
GitHub
Jun 15, 2022
Browse files
[fx] added coloproxy (#1115)
parent
6f82ac9b
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
61 additions
and
0 deletions
+61
-0
colossalai/fx/__init__.py
colossalai/fx/__init__.py
+0
-0
colossalai/fx/proxy.py
colossalai/fx/proxy.py
+61
-0
No files found.
colossalai/fx/__init__.py
0 → 100644
View file @
e1620dda
colossalai/fx/proxy.py
0 → 100644
View file @
e1620dda
import
operator
import
torch
from
torch.fx.proxy
import
Proxy
,
Attribute
__all__
=
[
'ColoProxy'
]
class
ColoProxy
(
Proxy
):
"""
ColoProxy is a proxy class which uses meta tensor to handle data-dependent control flow. The original torch.fx proxy
cannot be used to infer the condition statement, with this proxy, torch.fx can still run even with if statements.
Usage:
proxy = tracer.create_proxy(...)
proxy.meta_tensor = torch.empty(4, 2, device='meta')
print(len(proxy)) # expect output 4
"""
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
meta_tensor
=
None
@
property
def
meta_tensor
(
self
):
return
self
.
meta_tensor
@
meta_tensor
.
setter
def
meta_tensor
(
self
,
tensor
:
torch
.
Tensor
):
assert
tensor
.
is_meta
,
'Expected to receive a meta tensor, but got a non-meta tensor'
self
.
meta_tensor
=
tensor
@
property
def
has_meta_tensor
(
self
):
return
self
.
meta_tensor
is
not
None
def
_assert_has_meta
(
self
):
assert
self
.
has_meta_tensor
,
f
'Meta tensor is not set for
{
self
.
node
.
name
}
'
@
property
def
dtype
(
self
):
self
.
_assert_has_meta
()
return
self
.
meta_tensor
.
dtype
def
__len__
(
self
):
self
.
_assert_has_meta
()
return
len
(
self
.
meta_tensor
)
def
__bool__
(
self
):
self
.
_assert_has_meta
()
return
self
.
meta_tensor
def
__getattr__
(
self
,
k
):
if
k
==
"metadata"
:
return
self
.
meta_tensor
# note: not added to the graph yet, if this is a method call
# we peephole optimize to the method invocation
return
Attribute
(
self
,
k
)
def
__setitem__
(
self
,
indices
,
values
):
return
self
.
tracer
.
create_proxy
(
"call_function"
,
operator
.
setitem
,
(
self
,
indices
,
values
),
{})
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment