Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
a9d64377
Commit
a9d64377
authored
Dec 06, 2022
by
oahzxl
Browse files
support new op
parent
f24c418b
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
61 additions
and
2 deletions
+61
-2
chunk_codegen.py
chunk_codegen.py
+61
-2
No files found.
chunk_codegen.py
View file @
a9d64377
...
@@ -200,8 +200,12 @@ class NodeIndexTracer(object):
...
@@ -200,8 +200,12 @@ class NodeIndexTracer(object):
Args:
Args:
node (node)
node (node)
node_idx (int)
node_idx (int)
"""
"""
input_node
,
weight
,
bias
=
node
.
args
if
len
(
node
.
args
)
==
2
:
input_node
,
weight
=
node
.
args
bias
=
None
else
:
input_node
,
weight
,
bias
=
node
.
args
input_node_idx_trace
=
self
.
_find_idx_trace_from_node
(
input_node
)
input_node_idx_trace
=
self
.
_find_idx_trace_from_node
(
input_node
)
weight_idx_trace
=
self
.
_find_idx_trace_from_node
(
weight
)
weight_idx_trace
=
self
.
_find_idx_trace_from_node
(
weight
)
...
@@ -284,6 +288,53 @@ class NodeIndexTracer(object):
...
@@ -284,6 +288,53 @@ class NodeIndexTracer(object):
self
.
_assign_index_as_input
(
node
,
idx
)
self
.
_assign_index_as_input
(
node
,
idx
)
self
.
_inherit_computation
(
node
.
args
[
0
],
node
)
self
.
_inherit_computation
(
node
.
args
[
0
],
node
)
self
.
_mark_computation
(
node
,
idx
,
[
node
.
kwargs
[
'dim'
]])
self
.
_mark_computation
(
node
,
idx
,
[
node
.
kwargs
[
'dim'
]])
def
_assign_unsqueeze_index
(
self
,
node
,
node_idx
):
"""
Assign index for unsqueeze op.
1. assign new index for unsqueeze dim
Args:
node (node)
node_idx (int)
"""
self
.
_assign_index_as_input
(
node
,
node_idx
)
self
.
_inherit_computation
(
node
.
args
[
0
],
node
)
self
.
idx_trace_list
[
node_idx
][
'idx'
].
insert
(
node
.
args
[
1
],
self
.
_add_index
())
def
_assign_dropout_index
(
self
,
node
,
node_idx
):
"""
Assign index for unsqueeze op.
1. assign new index for unsqueeze dim
Args:
node (node)
node_idx (int)
"""
self
.
_assign_index_as_input
(
node
,
node_idx
)
def
_assign_ones_like_index
(
self
,
node
,
node_idx
):
"""
Assign index for oneslike op.
1. assign new index for all dim
Args:
node (node)
node_idx (int)
"""
self
.
_assign_all_index
(
node
,
node_idx
)
def
_assign_to_index
(
self
,
node
,
node_idx
):
"""
Assign index for to op.
1. assign new index for all dim
Args:
node (node)
node_idx (int)
"""
self
.
_assign_index_as_input
(
node
,
node_idx
)
def
_assign_view_reshape_index
(
self
,
node
,
node_idx
):
def
_assign_view_reshape_index
(
self
,
node
,
node_idx
):
"""
"""
...
@@ -388,6 +439,10 @@ class NodeIndexTracer(object):
...
@@ -388,6 +439,10 @@ class NodeIndexTracer(object):
self
.
_assign_permute_index
(
node
,
idx
)
self
.
_assign_permute_index
(
node
,
idx
)
elif
'view'
in
node
.
name
or
'reshape'
in
node
.
name
:
elif
'view'
in
node
.
name
or
'reshape'
in
node
.
name
:
self
.
_assign_view_reshape_index
(
node
,
idx
)
self
.
_assign_view_reshape_index
(
node
,
idx
)
elif
'unsqueeze'
in
node
.
name
:
self
.
_assign_unsqueeze_index
(
node
,
idx
)
elif
'to'
in
node
.
name
:
self
.
_assign_to_index
(
node
,
idx
)
else
:
else
:
raise
NotImplementedError
(
node
.
name
,
"method not implemented yet!"
)
raise
NotImplementedError
(
node
.
name
,
"method not implemented yet!"
)
elif
node
.
op
==
'call_function'
:
elif
node
.
op
==
'call_function'
:
...
@@ -399,6 +454,10 @@ class NodeIndexTracer(object):
...
@@ -399,6 +454,10 @@ class NodeIndexTracer(object):
self
.
_assign_softmax_index
(
node
,
idx
)
self
.
_assign_softmax_index
(
node
,
idx
)
elif
any
(
n
in
node
.
name
for
n
in
[
'mul'
,
'add'
,
'sigmoid'
,
'relu'
]):
elif
any
(
n
in
node
.
name
for
n
in
[
'mul'
,
'add'
,
'sigmoid'
,
'relu'
]):
self
.
_assign_elementwise_index
(
node
,
idx
)
self
.
_assign_elementwise_index
(
node
,
idx
)
elif
'ones_like'
in
node
.
name
:
self
.
_assign_ones_like_index
(
node
,
idx
)
elif
'dropout'
in
node
.
name
:
self
.
_assign_dropout_index
(
node
,
idx
)
elif
'getattr'
in
node
.
name
:
elif
'getattr'
in
node
.
name
:
continue
# get attr like shape
continue
# get attr like shape
elif
'getitem'
in
node
.
name
:
elif
'getitem'
in
node
.
name
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment