Commit c7c224ee authored by Ziheng Qin's avatar Ziheng Qin Committed by binmakeswell
Browse files

[NFC] polish colossalai/builder/pipeline.py code style (#638)

parent 10591ecd
import copy import copy
import heapq import heapq
from colossalai.builder import build_model, build_layer from colossalai.builder import build_model, build_layer
from colossalai.context.parallel_mode import ParallelMode from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
...@@ -40,6 +39,7 @@ def _binary_partition(weights, st, ed): ...@@ -40,6 +39,7 @@ def _binary_partition(weights, st, ed):
def _heap_addition(weights, intervals, add_cnt): def _heap_addition(weights, intervals, add_cnt):
""" """
""" """
def _heap_push(heap, st, ed): def _heap_push(heap, st, ed):
value = weights[ed - 1] value = weights[ed - 1]
if st > 0: if st > 0:
...@@ -162,7 +162,10 @@ def count_layer_params(layers): ...@@ -162,7 +162,10 @@ def count_layer_params(layers):
return param_counts return param_counts
def build_pipeline_model_from_cfg(config, num_chunks: int = 1, partition_method: str = 'parameter', verbose: bool = False): def build_pipeline_model_from_cfg(config,
num_chunks: int = 1,
partition_method: str = 'parameter',
verbose: bool = False):
"""An initializer to split the model into different stages for pipeline parallelism. """An initializer to split the model into different stages for pipeline parallelism.
An example for the model config is shown below. The class VisionTransformerFromConfig should An example for the model config is shown below. The class VisionTransformerFromConfig should
...@@ -218,7 +221,7 @@ def build_pipeline_model_from_cfg(config, num_chunks: int = 1, partition_method: ...@@ -218,7 +221,7 @@ def build_pipeline_model_from_cfg(config, num_chunks: int = 1, partition_method:
log_str += f'\n===== stage={stage}, layers={num_layers} =====\n' log_str += f'\n===== stage={stage}, layers={num_layers} =====\n'
for st, ed in parts[stage]: for st, ed in parts[stage]:
for idx, layer in enumerate(layers[st: ed]): for idx, layer in enumerate(layers[st:ed]):
log_str += f'\t{idx + st:2d}: {layer}\n' log_str += f'\t{idx + st:2d}: {layer}\n'
logger.info(log_str, ranks=[0]) logger.info(log_str, ranks=[0])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment