Unverified Commit f83c4d65 authored by runluo's avatar runluo Committed by GitHub
Browse files

[NFC] polish colossalai/nn/layer/wrapper/pipeline_wrapper.py code style (#1303)

parent 7696cead
......@@ -6,6 +6,7 @@ from colossalai.core import global_context as gpc
class PipelineSharedModuleWrapper:
def __init__(self, pipeline_ranks: Union[List[int], Tuple[int]]) -> None:
assert len(pipeline_ranks) > 1, f'Expect len(pipeline_ranks) > 1, got {len(pipeline_ranks)}'
self.pipeline_ranks = pipeline_ranks
......@@ -22,10 +23,7 @@ class PipelineSharedModuleWrapper:
num_pp_stages = num_dp_groups // pp_size
for i in range(dp_size):
for j in range(num_pp_stages):
pipeline_ranks = list(
range(i * num_dp_groups + j,
(i + 1) * num_dp_groups,
num_pp_stages))
pipeline_ranks = list(range(i * num_dp_groups + j, (i + 1) * num_dp_groups, num_pp_stages))
sub_ranks = [pipeline_ranks[idx] for idx in self.pipeline_ranks]
group = dist.new_group(sub_ranks)
if rank in sub_ranks:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment