Commit 2bde9d2b authored by oahzxl's avatar oahzxl
Browse files

code format

parent 8a634af2
......@@ -220,7 +220,9 @@ if CODEGEN_AVAILABLE:
self.max_memory = max_memory
self.meta_node = list(meta_graph.graph.nodes)
# find the chunk regions
self.chunk_region_search = ChunkRegionSearch(meta_graph, max_memory, print_mem)
self.chunk_region_search = ChunkRegionSearch(
meta_graph, max_memory, print_mem
)
self.chunk_infos = self.chunk_region_search.search_region()
def _gen_python_code(
......
import copy
from .chunk_selector import ChunkSelector
from .index_tracer import IndexTracer
from .memory_estiamtor import MemoryEstimator
from .chunk_selector import ChunkSelector
import copy
from .utils import is_non_compute_node, is_non_compute_node_except_placeholder, get_node_shape
from .utils import (
get_node_shape,
is_non_compute_node,
is_non_compute_node_except_placeholder,
)
class ChunkRegionSearch(object):
......@@ -11,7 +16,7 @@ class ChunkRegionSearch(object):
self.print_mem = print_mem
self.index_tracer = IndexTracer(list(gm.graph.nodes))
self.index_tracer.trace_index()
self.memory_estimator = MemoryEstimator(self.index_tracer)
self.memory_estimator = MemoryEstimator()
self.chunk_selector = ChunkSelector(
self.index_tracer, self.memory_estimator, max_memory=max_memory
)
......@@ -211,4 +216,3 @@ class ChunkRegionSearch(object):
self.index_tracer.node_list, chunk_infos, print_mem=True
)
return chunk_infos
......@@ -16,7 +16,7 @@ from .utils import (
class MemoryEstimator(object):
def __init__(self, index_tracer: IndexTracer) -> None:
def __init__(self) -> None:
pass
def _get_meta_node_size(self, x):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment