"git@developer.sourcefind.cn:OpenDAS/colossalai.git" did not exist on "afb239bbf83737655bf6b6baef2e261768d5c60f"
Commit 964a2867 authored by Kai Wang (Victor Kai)'s avatar Kai Wang (Victor Kai) Committed by binmakeswell
Browse files

[NFC] polish initializer_3d.py code style (#3279)

parent 94eec1c5
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
import math import math
import torch.distributed as dist import torch.distributed as dist
from colossalai.global_variables import tensor_parallel_env as env from colossalai.global_variables import tensor_parallel_env as env
from colossalai.registry import DIST_GROUP_INITIALIZER from colossalai.registry import DIST_GROUP_INITIALIZER
...@@ -213,7 +214,8 @@ class Initializer_3D_InputxWeight(ProcessGroupInitializer): ...@@ -213,7 +214,8 @@ class Initializer_3D_InputxWeight(ProcessGroupInitializer):
for h in range(self.num_group): for h in range(self.num_group):
for k in range(self.depth): for k in range(self.depth):
ranks = [ ranks = [
h * self.depth**3 + i + self.depth * (j + self.depth * k) for j in range(self.depth) h * self.depth**3 + i + self.depth * (j + self.depth * k)
for j in range(self.depth)
for i in range(self.depth) for i in range(self.depth)
] ]
group = dist.new_group(ranks) group = dist.new_group(ranks)
...@@ -266,7 +268,8 @@ class Initializer_3D_OutputxWeight(ProcessGroupInitializer): ...@@ -266,7 +268,8 @@ class Initializer_3D_OutputxWeight(ProcessGroupInitializer):
for h in range(self.num_group): for h in range(self.num_group):
for j in range(self.depth): for j in range(self.depth):
ranks = [ ranks = [
h * self.depth**3 + i + self.depth * (j + self.depth * k) for k in range(self.depth) h * self.depth**3 + i + self.depth * (j + self.depth * k)
for k in range(self.depth)
for i in range(self.depth) for i in range(self.depth)
] ]
group = dist.new_group(ranks) group = dist.new_group(ranks)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment