Commit 7eb87f51 authored by Liang Bowen's avatar Liang Bowen Committed by Frank Lee
Browse files

flake8 style (#352)

parent 54ee8d12
......@@ -38,7 +38,7 @@ class CheckpointModule(nn.Module):
def divide(numerator, denominator):
"""Only allow exact division
:param numerator: Numerator of the division
:param denominator: Denominator of the division
"""
......
......@@ -101,7 +101,7 @@ class WrappedDropPath(nn.Module):
@LAYERS.register_module
class VanillaPatchEmbedding(nn.Module):
"""
"""
2D Image to Patch Embedding
:param img_size: image size
......
......@@ -33,14 +33,16 @@ class PipelineSharedModuleWrapper:
self.ranks_in_group = sub_ranks
def register_module(self, module: nn.Module):
assert self.ranks_in_group is not None, f'Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}'
assert self.ranks_in_group is not None,\
f'Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}'
src = self.ranks_in_group[self.pipeline_ranks[0]]
for p in module.parameters():
setattr(p, 'pipeline_shared_module_pg', self.group)
dist.broadcast(p, src, group=self.group)
def register_parameter(self, param: nn.Parameter):
assert self.ranks_in_group is not None, f'Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}'
assert self.ranks_in_group is not None,\
f'Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}'
src = self.ranks_in_group[self.pipeline_ranks[0]]
setattr(param, 'pipeline_shared_module_pg', self.group)
dist.broadcast(param, src, group=self.group)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment