OpenDAS / fairscale · Commits

Commit 4d26a67c (unverified)

sync fbcode cp pg initialize (#1177)

Authored Apr 12, 2024 by Amy Yang; committed by GitHub on Apr 12, 2024.
Co-authored-by: amyyang <amyyang@meta.com>
Parent: 8fb39b2a

Showing 1 changed file with 21 additions and 6 deletions:

fairscale/nn/model_parallel/initialize.py (+21, -6)
@@ -20,8 +20,6 @@
 # limitations under the License.
 
 """Model and data parallel groups."""
-
 from typing import List, Optional
-
 import torch
@@ -41,13 +39,13 @@ _CONTEXT_PARALLEL_GROUP_RANKS = None
 
 
 def initialize_model_parallel(
-    model_parallel_size: int,
-    context_parallel_size: int = 1,
+    model_parallel_size_: int,
     pipeline_length: int = 1,
+    context_parallel_size: int = 1,
     *,
     model_parallel_backend: Optional[str] = None,
-    cp_backend: Optional[str] = None,
     pipeline_backend: Optional[str] = None,
+    cp_backend: Optional[str] = None,
     ddp_backend: Optional[str] = None,
 ) -> None:
     """
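For orientation, here is what the reordered signature looks like at a call site. A minimal sketch, assuming a single-process gloo setup for illustration (the MASTER_ADDR/MASTER_PORT values, the world size of 1, and the backend choices are assumptions, not part of the commit):

import os

import torch.distributed as dist

from fairscale.nn.model_parallel.initialize import initialize_model_parallel

# Illustrative single-process setup; real jobs get these from the launcher.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")

# initialize_model_parallel() asserts that the default group is initialized.
dist.init_process_group(backend="gloo", rank=0, world_size=1)

# Post-commit signature: model_parallel_size_ is the first parameter,
# pipeline_length precedes context_parallel_size, and cp_backend now
# follows pipeline_backend in the keyword-only block.
initialize_model_parallel(
    model_parallel_size_=1,
    pipeline_length=1,
    context_parallel_size=1,
    model_parallel_backend="gloo",
    pipeline_backend="gloo",
    cp_backend="gloo",
)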
@@ -67,11 +65,28 @@ def initialize_model_parallel(
     are on the same DGX box. For example if we are using 2 DGX-1 boxes
     with a total of 16 GPUs, rank 0 to 7 belong to the first box and
     ranks 8 to 15 belong to the second box.
+    Process groups are initialized in the order MP, CP, PP, DP.
+
+    Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we use
+    2 GPUs to parallelize the model tensor, 2 GPUs to parallelize the
+    context (sequence length), and 2 GPUs to parallelize the model
+    pipeline. The present function will create 8 tensor model-parallel
+    groups, 8 context-parallel groups, 8 pipeline model-parallel groups,
+    and 8 data-parallel groups:
+    when alternate_pp_config = False,
+        8 data_parallel groups:
+            [g0, g4], [g1, g5], [g2, g6], [g3, g7], [g8, g12], [g9, g13], [g10, g14], [g11, g15]
+        8 tensor model-parallel groups:
+            [g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]
+        8 context-parallel groups:
+            [g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]
+        8 pipeline model-parallel groups:
+            [g0, g8], [g1, g9], [g2, g10], [g3, g11], [g4, g12], [g5, g13], [g6, g14], [g7, g15]
     """
     # Get world size and rank. Ensure some consistencies.
     assert torch.distributed.is_initialized()
     world_size = torch.distributed.get_world_size()
-    model_parallel_size = int(min(model_parallel_size, world_size))
+    model_parallel_size = int(min(model_parallel_size_, world_size))
     ensure_divisibility(world_size, model_parallel_size)
     ensure_divisibility(world_size, context_parallel_size)
     ensure_divisibility(world_size, model_parallel_size * pipeline_length * context_parallel_size)
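The group lists in the docstring all follow from one fixed rank layout: MP is the fastest-varying dimension, then CP, then DP, then PP. A standalone sketch that reproduces all four lists for the 16-GPU example (pure Python; build_groups and the rank formula are reconstructions inferred from the docstring, not code from this commit):

def build_groups(world_size: int = 16, mp: int = 2, cp: int = 2, pp: int = 2):
    """Reproduce the docstring's 16-GPU group lists.

    Assumed rank layout, fastest to slowest dimension: MP, CP, DP, PP,
    i.e. rank = ((p * dp + d) * cp + c) * mp + m.
    """
    dp = world_size // (mp * cp * pp)

    def rank(p: int, d: int, c: int, m: int) -> int:
        return ((p * dp + d) * cp + c) * mp + m

    return {
        # Each group varies exactly one coordinate; the loops enumerate
        # the remaining coordinates so every rank lands in one group.
        "tensor": [[rank(p, d, c, m) for m in range(mp)]
                   for p in range(pp) for d in range(dp) for c in range(cp)],
        "context": [[rank(p, d, c, m) for c in range(cp)]
                    for p in range(pp) for d in range(dp) for m in range(mp)],
        "data": [[rank(p, d, c, m) for d in range(dp)]
                 for p in range(pp) for c in range(cp) for m in range(mp)],
        "pipeline": [[rank(p, d, c, m) for p in range(pp)]
                     for d in range(dp) for c in range(cp) for m in range(mp)],
    }


groups = build_groups()
assert groups["data"][:4] == [[0, 4], [1, 5], [2, 6], [3, 7]]
assert groups["tensor"][:4] == [[0, 1], [2, 3], [4, 5], [6, 7]]
assert groups["context"][:4] == [[0, 2], [1, 3], [4, 6], [5, 7]]
assert groups["pipeline"][:4] == [[0, 8], [1, 9], [2, 10], [3, 11]]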
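The three ensure_divisibility checks also pin down the data-parallel degree: whatever factor of world_size is not consumed by MP * PP * CP becomes DP. A short sketch of that arithmetic (the ensure_divisibility body mirrors the helper in fairscale.nn.model_parallel.utils; data_parallel_size is a hypothetical name used here for illustration):

def ensure_divisibility(numerator: int, denominator: int) -> None:
    # Mirrors fairscale.nn.model_parallel.utils.ensure_divisibility.
    assert numerator % denominator == 0, f"{numerator} is not divisible by {denominator}"


def data_parallel_size(world_size: int, mp: int, pp: int, cp: int) -> int:
    # The same checks the function body runs before carving up ranks.
    ensure_divisibility(world_size, mp)
    ensure_divisibility(world_size, cp)
    ensure_divisibility(world_size, mp * pp * cp)
    # The leftover factor is the data-parallel degree.
    return world_size // (mp * pp * cp)


# 16 GPUs with MP = CP = PP = 2 leaves DP = 2, as in the docstring example.
assert data_parallel_size(16, mp=2, pp=2, cp=2) == 2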