Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
7eb87f51
Commit
7eb87f51
authored
Mar 09, 2022
by
Liang Bowen
Committed by
Frank Lee
Mar 11, 2022
Browse files
flake8 style (#352)
parent
54ee8d12
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
6 additions
and
4 deletions
+6
-4
colossalai/nn/layer/utils/common.py
colossalai/nn/layer/utils/common.py
+1
-1
colossalai/nn/layer/vanilla/layers.py
colossalai/nn/layer/vanilla/layers.py
+1
-1
colossalai/nn/layer/wrapper/pipeline_wrapper.py
colossalai/nn/layer/wrapper/pipeline_wrapper.py
+4
-2
No files found.
colossalai/nn/layer/utils/common.py
View file @
7eb87f51
...
...
@@ -38,7 +38,7 @@ class CheckpointModule(nn.Module):
def
divide
(
numerator
,
denominator
):
"""Only allow exact division
:param numerator: Numerator of the division
:param denominator: Denominator of the division
"""
...
...
colossalai/nn/layer/vanilla/layers.py
View file @
7eb87f51
...
...
@@ -101,7 +101,7 @@ class WrappedDropPath(nn.Module):
@
LAYERS
.
register_module
class
VanillaPatchEmbedding
(
nn
.
Module
):
"""
"""
2D Image to Patch Embedding
:param img_size: image size
...
...
colossalai/nn/layer/wrapper/pipeline_wrapper.py
View file @
7eb87f51
...
...
@@ -33,14 +33,16 @@ class PipelineSharedModuleWrapper:
self
.
ranks_in_group
=
sub_ranks
def
register_module
(
self
,
module
:
nn
.
Module
):
assert
self
.
ranks_in_group
is
not
None
,
f
'Rank
{
gpc
.
get_local_rank
(
ParallelMode
.
PIPELINE
)
}
is not in pipeline_ranks
{
self
.
pipeline_ranks
}
'
assert
self
.
ranks_in_group
is
not
None
,
\
f
'Rank
{
gpc
.
get_local_rank
(
ParallelMode
.
PIPELINE
)
}
is not in pipeline_ranks
{
self
.
pipeline_ranks
}
'
src
=
self
.
ranks_in_group
[
self
.
pipeline_ranks
[
0
]]
for
p
in
module
.
parameters
():
setattr
(
p
,
'pipeline_shared_module_pg'
,
self
.
group
)
dist
.
broadcast
(
p
,
src
,
group
=
self
.
group
)
def
register_parameter
(
self
,
param
:
nn
.
Parameter
):
assert
self
.
ranks_in_group
is
not
None
,
f
'Rank
{
gpc
.
get_local_rank
(
ParallelMode
.
PIPELINE
)
}
is not in pipeline_ranks
{
self
.
pipeline_ranks
}
'
assert
self
.
ranks_in_group
is
not
None
,
\
f
'Rank
{
gpc
.
get_local_rank
(
ParallelMode
.
PIPELINE
)
}
is not in pipeline_ranks
{
self
.
pipeline_ranks
}
'
src
=
self
.
ranks_in_group
[
self
.
pipeline_ranks
[
0
]]
setattr
(
param
,
'pipeline_shared_module_pg'
,
self
.
group
)
dist
.
broadcast
(
param
,
src
,
group
=
self
.
group
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment