Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
3b7d6712
Commit
3b7d6712
authored
Dec 06, 2022
by
oahzxl
Browse files
finish region search loop
parent
7330d907
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
116 additions
and
40 deletions
+116
-40
chunk_codegen.py
chunk_codegen.py
+114
-38
chunk_codegen_run.py
chunk_codegen_run.py
+2
-2
No files found.
chunk_codegen.py
View file @
3b7d6712
...
@@ -21,7 +21,7 @@ class NodeIndexTracer(object):
...
@@ -21,7 +21,7 @@ class NodeIndexTracer(object):
def
__init__
(
self
,
gm
)
->
None
:
def
__init__
(
self
,
gm
)
->
None
:
self
.
gm
=
gm
self
.
gm
=
gm
self
.
nodes_list
=
list
(
gm
.
graph
.
nodes
)
self
.
nodes_list
=
list
(
gm
.
graph
.
nodes
)
self
.
idx_trace_list
=
[{
'idx'
:
[],
'compute'
:
[]
}
for
_
in
range
(
len
(
self
.
nodes_list
))]
self
.
idx_trace_list
=
[{
'idx'
:
[],
'compute'
:
{}
}
for
_
in
range
(
len
(
self
.
nodes_list
))]
self
.
idx_trace_equal
=
[]
self
.
idx_trace_equal
=
[]
self
.
idx_view_list
=
[]
self
.
idx_view_list
=
[]
self
.
idx_count
=
-
1
self
.
idx_count
=
-
1
...
@@ -48,9 +48,12 @@ class NodeIndexTracer(object):
...
@@ -48,9 +48,12 @@ class NodeIndexTracer(object):
"""
"""
_
,
compute_from
=
self
.
_find_trace_from_node
(
node_from
)
_
,
compute_from
=
self
.
_find_trace_from_node
(
node_from
)
idx_to
,
compute_to
=
self
.
_find_trace_from_node
(
node_to
)
idx_to
,
compute_to
=
self
.
_find_trace_from_node
(
node_to
)
for
i
in
compute_from
:
for
k
,
v
in
compute_from
.
items
():
if
i
in
idx_to
and
i
not
in
compute_to
:
if
k
in
idx_to
:
compute_to
.
append
(
i
)
if
k
in
compute_to
:
compute_to
[
k
].
extend
(
v
)
else
:
compute_to
[
k
]
=
copy
.
deepcopy
(
v
)
def
_mark_idx_equal
(
self
,
idx1
,
idx2
):
def
_mark_idx_equal
(
self
,
idx1
,
idx2
):
"""
"""
...
@@ -77,7 +80,9 @@ class NodeIndexTracer(object):
...
@@ -77,7 +80,9 @@ class NodeIndexTracer(object):
for
d
in
dim
:
for
d
in
dim
:
cur_idx
=
input_node_idx_trace
[
d
]
cur_idx
=
input_node_idx_trace
[
d
]
if
cur_idx
not
in
self
.
idx_trace_list
[
idx
][
'compute'
]:
if
cur_idx
not
in
self
.
idx_trace_list
[
idx
][
'compute'
]:
self
.
idx_trace_list
[
idx
][
'compute'
].
append
(
cur_idx
)
self
.
idx_trace_list
[
idx
][
'compute'
][
cur_idx
]
=
[
idx
]
else
:
self
.
idx_trace_list
[
idx
][
'compute'
][
cur_idx
].
append
(
idx
)
def
_find_trace_from_node
(
self
,
node
):
def
_find_trace_from_node
(
self
,
node
):
"""
"""
...
@@ -357,6 +362,11 @@ class NodeIndexTracer(object):
...
@@ -357,6 +362,11 @@ class NodeIndexTracer(object):
"dim_to"
:
dim_to
}
"dim_to"
:
dim_to
}
self
.
idx_view_list
.
append
(
view_dict
)
self
.
idx_view_list
.
append
(
view_dict
)
def
_remove_duplicate_compute
(
self
):
for
i
in
self
.
idx_trace_list
:
for
k
,
v
in
i
[
'compute'
].
items
():
i
[
'compute'
][
k
]
=
list
(
set
(
v
))
def
_merge_equal_idx
(
self
):
def
_merge_equal_idx
(
self
):
idx_equal
=
copy
.
deepcopy
(
self
.
idx_trace_equal
)
idx_equal
=
copy
.
deepcopy
(
self
.
idx_trace_equal
)
idx_equal
.
reverse
()
idx_equal
.
reverse
()
...
@@ -406,6 +416,8 @@ class NodeIndexTracer(object):
...
@@ -406,6 +416,8 @@ class NodeIndexTracer(object):
continue
continue
else
:
else
:
raise
NotImplementedError
(
node
.
op
,
"op not implemented yet!"
)
raise
NotImplementedError
(
node
.
op
,
"op not implemented yet!"
)
self
.
_remove_duplicate_compute
()
self
.
_merge_equal_idx
()
self
.
_merge_equal_idx
()
...
@@ -521,6 +533,19 @@ class MemoryEstimator(object):
...
@@ -521,6 +533,19 @@ class MemoryEstimator(object):
print
(
""
)
print
(
""
)
print
(
"
\n
"
)
print
(
"
\n
"
)
def
_print_compute_op_mem_log
(
self
,
log
,
nodes
,
title
=
None
):
if
title
:
print
(
title
)
for
idx
,
(
l
,
n
)
in
enumerate
(
zip
(
log
,
nodes
)):
if
n
.
op
in
[
'placeholder'
,
'get_attr'
,
'output'
]:
continue
if
any
(
i
in
n
.
name
for
i
in
[
'getitem'
,
'getattr'
]):
continue
print
(
"%s:%.2f
\t
"
%
(
n
.
name
,
l
),
end
=
''
)
if
(
idx
+
1
)
%
3
==
0
:
print
(
""
)
print
(
"
\n
"
)
def
estimate_chunk_inference_mem
(
self
,
gm
:
torch
.
fx
.
GraphModule
,
start_nodes
=
None
,
end_nodes
=
None
,
chunk_dims
=
None
,
chunk_sizes
=
None
):
def
estimate_chunk_inference_mem
(
self
,
gm
:
torch
.
fx
.
GraphModule
,
start_nodes
=
None
,
end_nodes
=
None
,
chunk_dims
=
None
,
chunk_sizes
=
None
):
act_memory
=
0.0
act_memory
=
0.0
act_memory_peak_log
=
[]
act_memory_peak_log
=
[]
...
@@ -584,8 +609,10 @@ class MemoryEstimator(object):
...
@@ -584,8 +609,10 @@ class MemoryEstimator(object):
active_node_list_log
.
append
(
copy
.
deepcopy
(
active_node_list
))
active_node_list_log
.
append
(
copy
.
deepcopy
(
active_node_list
))
print
(
"with chunk"
if
use_chunk
else
"without chunk"
)
print
(
"with chunk"
if
use_chunk
else
"without chunk"
)
self
.
_print_mem_log
(
act_memory_peak_log
,
node_list
,
"peak"
)
# self._print_mem_log(act_memory_peak_log, node_list, "peak")
self
.
_print_mem_log
(
act_memory_after_node_log
,
node_list
,
"after"
)
# self._print_mem_log(act_memory_after_node_log, node_list, "after")
self
.
_print_compute_op_mem_log
(
act_memory_peak_log
,
node_list
,
"peak"
)
self
.
_print_compute_op_mem_log
(
act_memory_after_node_log
,
node_list
,
"after"
)
# param_memory = parameter_size(gm)
# param_memory = parameter_size(gm)
# all_memory = act_memory + param_memory
# all_memory = act_memory + param_memory
...
@@ -602,7 +629,7 @@ class ChunkRegionSearch(object):
...
@@ -602,7 +629,7 @@ class ChunkRegionSearch(object):
def
_find_peak_node
(
self
,
mem_peak
):
def
_find_peak_node
(
self
,
mem_peak
):
max_value
=
max
(
mem_peak
)
max_value
=
max
(
mem_peak
)
max_idx
=
[
mem_peak
.
index
(
max_value
)
]
max_idx
=
mem_peak
.
index
(
max_value
)
return
max_idx
return
max_idx
def
_get_free_var
(
self
):
def
_get_free_var
(
self
):
...
@@ -635,18 +662,35 @@ class ChunkRegionSearch(object):
...
@@ -635,18 +662,35 @@ class ChunkRegionSearch(object):
raise
RuntimeError
()
raise
RuntimeError
()
# from peak_node to len-2
# from peak_node to len-2
chunk_region_end
=
None
chunk_region_end
=
None
for
i
in
range
(
peak_node
,
len
(
active_node
)
-
1
):
for
i
in
range
(
peak_node
,
len
(
active_node
)):
if
len
(
active_node
[
i
])
==
min_var
:
if
len
(
active_node
[
i
])
==
min_var
:
chunk_region_end
=
i
-
1
chunk_region_end
=
i
break
break
if
i
in
free_vars
or
i
==
0
:
if
i
in
free_vars
or
i
==
0
:
raise
RuntimeError
()
raise
RuntimeError
()
return
chunk_region_start
,
chunk_region_end
return
chunk_region_start
,
chunk_region_end
def
_not_compute
(
self
,
trace
,
chunk_range
,
dim_idx
):
if
trace
[
'idx'
][
dim_idx
]
not
in
trace
[
'compute'
]:
return
True
if
trace
[
'idx'
][
dim_idx
]
in
trace
[
'compute'
]
and
\
all
(
i
<
chunk_range
[
0
]
or
i
>
chunk_range
[
1
]
for
i
in
trace
[
'compute'
][
trace
[
'idx'
][
dim_idx
]]):
return
True
return
False
def
_search_possible_chunk_regions
(
self
,
max_chunk_region
,
peak_node
):
def
_search_possible_chunk_regions
(
self
,
max_chunk_region
,
peak_node
):
possible_chunk_region
=
[]
possible_chunk_region
=
[]
output_trace
=
copy
.
deepcopy
(
self
.
index_tracer
.
idx_trace_list
)
input_trace
=
[]
for
i
,
n
in
enumerate
(
self
.
node_list
):
if
len
(
n
.
args
)
>
0
and
n
.
op
!=
'output'
:
input_idx
=
_find_idx_by_name
(
n
.
args
[
0
].
name
,
self
.
node_list
)
input_trace
.
append
(
output_trace
[
input_idx
])
else
:
input_trace
.
append
(
None
)
for
before_idx
in
range
(
max_chunk_region
[
0
],
peak_node
):
for
before_idx
in
range
(
max_chunk_region
[
0
],
peak_node
):
for
after_idx
in
range
(
peak_node
,
max_chunk_region
[
1
]):
for
after_idx
in
range
(
peak_node
,
max_chunk_region
[
1
]
+
1
):
# skip non compute nodes
# skip non compute nodes
if
any
(
op
in
[
'placeholder'
,
'get_attr'
,
'output'
]
for
op
in
if
any
(
op
in
[
'placeholder'
,
'get_attr'
,
'output'
]
for
op
in
[
self
.
node_list
[
before_idx
].
op
,
self
.
node_list
[
after_idx
].
op
]):
[
self
.
node_list
[
before_idx
].
op
,
self
.
node_list
[
after_idx
].
op
]):
...
@@ -656,23 +700,59 @@ class ChunkRegionSearch(object):
...
@@ -656,23 +700,59 @@ class ChunkRegionSearch(object):
continue
continue
# select free dim
# select free dim
before_trace
=
self
.
index_tracer
.
idx
_trace
_list
[
before_idx
]
before_trace
=
input
_trace
[
before_idx
]
after_trace
=
self
.
index_tracer
.
idx
_trace
_list
[
after_idx
]
after_trace
=
output
_trace
[
after_idx
]
free_dim
=
[]
free_dim
=
[]
for
i
in
range
(
min
(
len
(
before_trace
[
'idx'
]),
len
(
after_trace
[
'idx'
]))):
for
i
in
range
(
min
(
len
(
before_trace
[
'idx'
]),
len
(
after_trace
[
'idx'
]))):
if
(
before_trace
[
'idx'
][
i
]
==
after_trace
[
'idx'
][
i
]
and
if
(
before_trace
[
'idx'
][
i
]
==
after_trace
[
'idx'
][
i
]
and
before_trace
[
'idx'
][
i
]
not
in
before_trace
[
'compute'
]
and
self
.
_not_compute
(
before_trace
,
(
before_idx
,
after_idx
),
i
)
and
after_trace
[
'idx'
][
i
]
not
in
after_trace
[
'compute'
]):
self
.
_not_compute
(
after_trace
,
(
before_idx
,
after_idx
),
i
)
and
self
.
node_list
[
after_idx
].
meta
[
'tensor_meta'
].
shape
[
i
]
!=
1
):
free_dim
.
append
(
i
)
free_dim
.
append
(
i
)
possible_chunk_region
.
append
({
'region'
:
(
before_idx
,
after_idx
),
'dim'
:
free_dim
})
possible_chunk_region
.
append
({
'region'
:
(
before_idx
,
after_idx
),
'dim'
:
free_dim
})
return
possible_chunk_region
return
possible_chunk_region
def
_search_best_chunk_region
(
self
,
possible_chunk_regions
):
max_region_range
=
0
best_regions
=
None
for
i
in
possible_chunk_regions
:
if
i
[
'region'
][
1
]
-
i
[
'region'
][
0
]
>
max_region_range
:
best_regions
=
i
max_region_range
=
i
[
'region'
][
1
]
-
i
[
'region'
][
0
]
return
best_regions
def
_step_search
(
self
,
peak_node
,
active_node
):
max_chunk_region
=
self
.
_search_max_chunk_region
(
active_node
,
peak_node
)
possible_chunk_regions
=
self
.
_search_possible_chunk_regions
(
max_chunk_region
,
peak_node
)
best_chunk_region
=
self
.
_search_best_chunk_region
(
possible_chunk_regions
)
return
best_chunk_region
def
_stop_search
(
self
,
init_mem_peak
,
mem_peak
):
sorted_init_mem_peak
=
sorted
(
init_mem_peak
)
if
max
(
mem_peak
)
<
sorted_init_mem_peak
[
int
(
len
(
sorted_init_mem_peak
)
*
0.5
)]:
return
True
return
False
def
search_region
(
self
):
def
search_region
(
self
):
mem_peak
,
mem_after
,
active_node
=
self
.
memory_estimator
.
estimate_chunk_inference_mem
(
self
.
gm
)
chunk_regions
=
[]
peak_nodes
=
self
.
_find_peak_node
(
mem_peak
)
init_mem_peak
,
_
,
active_node
=
self
.
memory_estimator
.
estimate_chunk_inference_mem
(
self
.
gm
)
for
idx
,
peak_node
in
enumerate
(
peak_nodes
):
mem_peak
=
init_mem_peak
max_chunk_region
=
self
.
_search_max_chunk_region
(
active_node
,
peak_node
)
possible_chunk_regions
=
self
.
_search_possible_chunk_regions
(
max_chunk_region
,
peak_node
)
while
True
:
peak_node
=
self
.
_find_peak_node
(
mem_peak
)
chunk_region
=
self
.
_step_search
(
peak_node
,
active_node
)
if
chunk_region
is
None
or
len
(
chunk_region
[
'dim'
])
==
0
:
break
chunk_regions
.
append
(
chunk_region
)
mem_peak
,
_
,
active_node
=
self
.
memory_estimator
.
estimate_chunk_inference_mem
(
self
.
gm
,
[
i
[
'region'
][
0
]
for
i
in
chunk_regions
],
[
i
[
'region'
][
1
]
for
i
in
chunk_regions
],
[
i
[
'dim'
][
0
]
for
i
in
chunk_regions
],
[
1
]
*
len
(
chunk_regions
))
if
self
.
_stop_search
(
init_mem_peak
,
mem_peak
):
break
return
chunk_regions
def
_gen_chunk_slice_dim
(
chunk_dim
,
chunk_idx_name
,
shape
):
def
_gen_chunk_slice_dim
(
chunk_dim
,
chunk_idx_name
,
shape
):
...
@@ -696,11 +776,12 @@ def _get_first_non_single_dim(shape):
...
@@ -696,11 +776,12 @@ def _get_first_non_single_dim(shape):
raise
RuntimeError
(
"can not get first non single dim for shape"
,
shape
)
raise
RuntimeError
(
"can not get first non single dim for shape"
,
shape
)
def
_gen_loop_start
(
chunk_input_meta
,
chunk_output
,
chunk_size
=
2
):
def
_gen_loop_start
(
chunk_input_meta
,
chunk_output
,
chunk_dim
,
chunk_size
=
2
):
if
len
(
chunk_input_meta
)
==
1
:
if
len
(
chunk_input_meta
)
==
1
:
node
=
chunk_input_meta
[
0
]
node
=
chunk_input_meta
[
0
]
node_shape
=
node
.
meta
[
'tensor_meta'
].
shape
node_shape
=
node
.
meta
[
'tensor_meta'
].
shape
chunk_dim
=
_get_first_non_single_dim
(
node_shape
)
free_shape
=
[
node_shape
[
i
]
if
i
in
chunk_dim
else
1
for
i
in
range
(
len
(
node_shape
))]
chunk_dim
=
_get_first_non_single_dim
(
free_shape
)
chunk_slice
=
_gen_chunk_slice_dim
(
chunk_dim
,
"gen_chunk_idx"
,
node_shape
)
chunk_slice
=
_gen_chunk_slice_dim
(
chunk_dim
,
"gen_chunk_idx"
,
node_shape
)
out_shape
=
str
(
list
(
chunk_output
.
meta
[
'tensor_meta'
].
shape
))
out_shape
=
str
(
list
(
chunk_output
.
meta
[
'tensor_meta'
].
shape
))
...
@@ -713,12 +794,13 @@ def _gen_loop_start(chunk_input_meta, chunk_output, chunk_size=2):
...
@@ -713,12 +794,13 @@ def _gen_loop_start(chunk_input_meta, chunk_output, chunk_size=2):
return
context
return
context
def
_gen_loop_end
(
chunk_outputs
,
chunk_inputs
,
node_list
):
def
_gen_loop_end
(
chunk_outputs
,
chunk_inputs
,
node_list
,
chunk_dim
):
chunk_inputs_name
=
chunk_inputs
[
0
].
name
chunk_inputs_name
=
chunk_inputs
[
0
].
name
chunk_outputs_name
=
chunk_outputs
.
name
chunk_outputs_name
=
chunk_outputs
.
name
chunk_outputs_idx
=
_find_idx_by_name
(
chunk_outputs_name
,
node_list
)
chunk_outputs_idx
=
_find_idx_by_name
(
chunk_outputs_name
,
node_list
)
chunk_output_shape
=
chunk_outputs
.
meta
[
'tensor_meta'
].
shape
chunk_output_shape
=
chunk_outputs
.
meta
[
'tensor_meta'
].
shape
chunk_dim
=
_get_first_non_single_dim
(
chunk_output_shape
)
free_shape
=
[
chunk_output_shape
[
i
]
if
i
in
chunk_dim
else
1
for
i
in
range
(
len
(
chunk_output_shape
))]
chunk_dim
=
_get_first_non_single_dim
(
free_shape
)
chunk_slice
=
_gen_chunk_slice_dim
(
chunk_dim
,
"gen_chunk_idx"
,
chunk_output_shape
)
chunk_slice
=
_gen_chunk_slice_dim
(
chunk_dim
,
"gen_chunk_idx"
,
chunk_output_shape
)
context
=
" chunk_result%s = %s
\n
"
%
(
chunk_slice
,
chunk_outputs_name
)
context
=
" chunk_result%s = %s
\n
"
%
(
chunk_slice
,
chunk_outputs_name
)
...
@@ -780,7 +862,11 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
...
@@ -780,7 +862,11 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
"""
"""
# find the offload regions
# find the offload regions
chunk_regions
=
[(
58
,
62
)]
chunk_region_search
=
ChunkRegionSearch
(
meta_graph
)
chunk_search
=
chunk_region_search
.
search_region
()
chunk_regions
=
[
i
[
'region'
]
for
i
in
chunk_search
]
chunk_dims
=
[
i
[
'dim'
]
for
i
in
chunk_search
]
chunk_starts
=
[
item
[
0
]
for
item
in
chunk_regions
]
chunk_starts
=
[
item
[
0
]
for
item
in
chunk_regions
]
chunk_ends
=
[
item
[
1
]
for
item
in
chunk_regions
]
chunk_ends
=
[
item
[
1
]
for
item
in
chunk_regions
]
chunk_inputs
=
[]
chunk_inputs
=
[]
...
@@ -789,16 +875,6 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
...
@@ -789,16 +875,6 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
node_list
=
list
(
nodes
)
node_list
=
list
(
nodes
)
memory_estimator
=
MemoryEstimator
()
memory_estimator
.
estimate_chunk_inference_mem
(
meta_graph
,
chunk_starts
,
chunk_ends
,
[
1
],
[
2
])
memory_estimator
.
estimate_chunk_inference_mem
(
meta_graph
)
node_index_tracer
=
NodeIndexTracer
(
meta_graph
)
node_index_tracer
.
trace_node_idx
()
chunk_region_search
=
ChunkRegionSearch
(
meta_graph
)
chunk_region_search
.
search_region
()
# find the input and output var names for each offload region
# find the input and output var names for each offload region
for
idx
,
(
start
,
end
)
in
enumerate
(
chunk_regions
):
for
idx
,
(
start
,
end
)
in
enumerate
(
chunk_regions
):
offload_node_list
=
node_list
[
start
:
end
+
1
]
offload_node_list
=
node_list
[
start
:
end
+
1
]
...
@@ -824,13 +900,13 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
...
@@ -824,13 +900,13 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
# add for loop
# add for loop
chunk_input_meta
=
[
meta_nodes
[
i
]
for
i
in
chunk_inputs_idx
[
region_idx
]]
chunk_input_meta
=
[
meta_nodes
[
i
]
for
i
in
chunk_inputs_idx
[
region_idx
]]
body
.
append
(
_gen_loop_start
(
chunk_input_meta
,
node_list
[
chunk_ends
[
region_idx
]]))
body
.
append
(
_gen_loop_start
(
chunk_input_meta
,
node_list
[
chunk_ends
[
region_idx
]]
,
chunk_dims
[
region_idx
]
))
if
within_chunk_region
:
if
within_chunk_region
:
emit_node_func
(
node
,
body
)
emit_node_func
(
node
,
body
)
# replace input var with chunk var
# replace input var with chunk var
if
node_idx
in
chunk_starts
:
if
node_idx
in
chunk_starts
:
body
[
-
1
]
=
body
[
-
1
].
replace
(
"("
+
chunk_inputs
[
region_idx
][
0
].
name
+
")"
,
'
(
chunk_tensor
)
'
)
body
[
-
1
]
=
body
[
-
1
].
replace
(
chunk_inputs
[
region_idx
][
0
].
name
,
'chunk_tensor'
)
body
[
-
1
]
=
' '
+
body
[
-
1
]
body
[
-
1
]
=
' '
+
body
[
-
1
]
delete_unused_value_func
(
node
,
body
,
chunk_inputs_names
)
delete_unused_value_func
(
node
,
body
,
chunk_inputs_names
)
...
@@ -840,7 +916,7 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
...
@@ -840,7 +916,7 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
delete_unused_value_func
(
node
,
body
,
chunk_inputs_names
)
delete_unused_value_func
(
node
,
body
,
chunk_inputs_names
)
if
node_idx
in
chunk_ends
:
if
node_idx
in
chunk_ends
:
body
.
append
(
_gen_loop_end
(
node
,
chunk_inputs
[
region_idx
],
node_list
))
body
.
append
(
_gen_loop_end
(
node
,
chunk_inputs
[
region_idx
],
node_list
,
chunk_dims
[
region_idx
]
))
within_chunk_region
=
False
within_chunk_region
=
False
region_idx
+=
1
region_idx
+=
1
...
...
chunk_codegen_run.py
View file @
3b7d6712
...
@@ -45,8 +45,8 @@ def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair):
...
@@ -45,8 +45,8 @@ def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair):
with
torch
.
no_grad
():
with
torch
.
no_grad
():
non_fx_out
=
model
(
node
,
pair
)
non_fx_out
=
model
(
node
,
pair
)
fx_out
=
gm
(
node
,
pair
)
fx_out
=
gm
(
node
,
pair
)
assert
torch
.
equal
(
non_fx_out
[
0
],
fx_out
[
0
]),
"fx_out doesn't comply with original output"
assert
torch
.
allclose
(
non_fx_out
[
0
],
fx_out
[
0
]
,
atol
=
1e-6
),
"fx_out doesn't comply with original output"
assert
torch
.
equal
(
non_fx_out
[
1
],
fx_out
[
1
]),
"fx_out doesn't comply with original output"
assert
torch
.
allclose
(
non_fx_out
[
1
],
fx_out
[
1
]
,
atol
=
1e-6
),
"fx_out doesn't comply with original output"
# test backward
# test backward
# loss0 = non_fx_out[0].sum() + non_fx_out[1].sum()
# loss0 = non_fx_out[0].sum() + non_fx_out[1].sum()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment