OpenDAS / fairscale / Commits

Unverified commit 0af41aee, authored Apr 04, 2024 by Amy Yang, committed by GitHub on Apr 04, 2024

add context parallel group init to mp init (#1174)

Co-authored-by: amyyang <amyyang@meta.com>

parent 9a173bf2
Showing 1 changed file with 80 additions and 28 deletions.

fairscale/nn/model_parallel/initialize.py  (+80, -28)
@@ -34,17 +34,21 @@ _MODEL_PARALLEL_GROUP = None
 _DATA_PARALLEL_GROUP = None
 # Pipeline parallel group that the current rank belongs to.
 _PIPELINE_PARALLEL_GROUP = None
 _PIPELINE_PARALLEL_RANKS = None
+_CONTEXT_PARALLEL_GROUP = None
+_CONTEXT_PARALLEL_GROUP_RANKS = None


 def initialize_model_parallel(
-    model_parallel_size_: int,
+    model_parallel_size: int,
+    context_parallel_size: int = 1,
     pipeline_length: int = 1,
     *,
     model_parallel_backend: Optional[str] = None,
+    cp_backend: Optional[str] = None,
     pipeline_backend: Optional[str] = None,
-    ddp_backend: Optional[str] = None
+    ddp_backend: Optional[str] = None,
 ) -> None:
     """
     Initialize model data parallel groups.
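
The new context_parallel_size and cp_backend arguments slot in alongside the existing model/pipeline/ddp options. A minimal calling sketch follows; the 8-rank launch, the backend choices, and the sizes are illustrative assumptions, not part of this commit:

import torch.distributed as dist
from fairscale.nn.model_parallel.initialize import initialize_model_parallel

# Assumed launch: torchrun --nproc_per_node=8 train.py
dist.init_process_group(backend="nccl")

# 2-way model parallelism x 2-way context parallelism on 8 ranks;
# the leftover factor, 8 / (2 * 1 * 2) = 2, becomes data parallelism.
initialize_model_parallel(
    model_parallel_size=2,
    context_parallel_size=2,
    model_parallel_backend="nccl",
    cp_backend="nccl",
    pipeline_backend="nccl",
    ddp_backend="nccl",
)
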
@@ -67,19 +71,21 @@ def initialize_model_parallel(
     # Get world size and rank. Ensure some consistencies.
     assert torch.distributed.is_initialized()
     world_size = torch.distributed.get_world_size()
-    model_parallel_size = int(min(model_parallel_size_, world_size))
+    model_parallel_size = int(min(model_parallel_size, world_size))
     ensure_divisibility(world_size, model_parallel_size)
-    ensure_divisibility(world_size, model_parallel_size * pipeline_length)
+    ensure_divisibility(world_size, context_parallel_size)
+    ensure_divisibility(world_size, model_parallel_size * pipeline_length * context_parallel_size)
     rank = torch.distributed.get_rank()

-    data_parallel_size = int(world_size / (model_parallel_size * pipeline_length))
+    data_parallel_size = int(world_size / (model_parallel_size * pipeline_length * context_parallel_size))

     if torch.distributed.get_rank() == 0:
-        print("> initializing model parallel with size {}".format(model_parallel_size_))
-        print("> initializing ddp with size {}".format(data_parallel_size))
+        print("> initializing model parallel with size {}".format(model_parallel_size))
+        print("> initializing context parallel with size {}".format(context_parallel_size))
         print("> initializing pipeline with size {}".format(pipeline_length))
+        print("> initializing ddp with size {}".format(data_parallel_size))

-    groups = torch.LongTensor(range(world_size)).reshape(data_parallel_size, pipeline_length, model_parallel_size)
+    groups = torch.LongTensor(range(world_size)).reshape(data_parallel_size, pipeline_length, context_parallel_size, model_parallel_size)

     found = torch.where(groups == rank)
     assert all(len(x) == 1 for x in found)
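
To make the 4-D reshape concrete, here is a small worked example with illustrative sizes that are not taken from the commit: with a world size of 8, model_parallel_size=2, context_parallel_size=2 and pipeline_length=1, data_parallel_size = 8 / (2 * 1 * 2) = 2, so the rank grid has shape (data, pipeline, context, model) = (2, 1, 2, 2):

import torch

world_size = 8
data_parallel_size, pipeline_length, context_parallel_size, model_parallel_size = 2, 1, 2, 2

# Same reshape as in the patch: axes are (data, pipeline, context, model).
groups = torch.LongTensor(range(world_size)).reshape(
    data_parallel_size, pipeline_length, context_parallel_size, model_parallel_size
)
# groups[0, 0] -> [[0, 1], [2, 3]], groups[1, 0] -> [[4, 5], [6, 7]]

rank = 5
found = torch.where(groups == rank)
# found == (tensor([1]), tensor([0]), tensor([0]), tensor([1])):
# rank 5 sits at data index 1, pipeline index 0, context index 0, model index 1.
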
@@ -88,41 +94,81 @@ def initialize_model_parallel(
     # Build the data parallel groups.
     global _DATA_PARALLEL_GROUP
     assert _DATA_PARALLEL_GROUP is None, "data parallel group is already initialized"
-    for j in range(pipeline_length):
-        for k in range(model_parallel_size):
-            group = torch.distributed.new_group(groups[:, j, k].tolist(), backend=ddp_backend)
-            if j == found[1] and k == found[2]:
-                _DATA_PARALLEL_GROUP = group
+    for i in range(pipeline_length):
+        for j in range(context_parallel_size):
+            for k in range(model_parallel_size):
+                group = torch.distributed.new_group(groups[:, i, j, k].tolist(), backend=ddp_backend)
+                if i == found[1] and j == found[2] and k == found[3]:
+                    _DATA_PARALLEL_GROUP = group

     # Build the model parallel groups.
     global _MODEL_PARALLEL_GROUP
-    assert _MODEL_PARALLEL_GROUP is None, "model parallel group is already initialized"
+    assert _MODEL_PARALLEL_GROUP is None, "Model parallel group is already initialized"
     for i in range(data_parallel_size):
         for j in range(pipeline_length):
-            group = torch.distributed.new_group(groups[i, j, :].tolist(), backend=model_parallel_backend)
-            if i == found[0] and j == found[1]:
-                _MODEL_PARALLEL_GROUP = group
+            for k in range(context_parallel_size):
+                group = torch.distributed.new_group(groups[i, j, k, :].tolist(), backend=model_parallel_backend)
+                if i == found[0] and j == found[1] and k == found[2]:
+                    _MODEL_PARALLEL_GROUP = group
+    # Build the pipeline parallel groups.
     global _PIPELINE_PARALLEL_GROUP
-    assert _PIPELINE_PARALLEL_GROUP is None, "model parallel group is already initialized"
     global _PIPELINE_PARALLEL_RANKS
-    assert _PIPELINE_PARALLEL_RANKS is None, "model parallel group is already initialized"
+    assert _PIPELINE_PARALLEL_GROUP is None, "Pipeline parallel group is already initialized"
     for i in range(data_parallel_size):
-        for k in range(model_parallel_size):
-            ranks = groups[i, :, k].tolist()
-            group = torch.distributed.new_group(ranks, backend=pipeline_backend)
-            if i == found[0] and k == found[2]:
-                _PIPELINE_PARALLEL_GROUP = group
-                _PIPELINE_PARALLEL_RANKS = ranks
+        for j in range(context_parallel_size):
+            for k in range(model_parallel_size):
+                ranks = groups[i, :, j, k].tolist()
+                group = torch.distributed.new_group(ranks, backend=pipeline_backend)
+                if i == found[0] and j == found[2] and k == found[3]:
+                    _PIPELINE_PARALLEL_GROUP = group
+                    _PIPELINE_PARALLEL_RANKS = ranks
+
+    # Build the context parallel groups.
+    global _CONTEXT_PARALLEL_GROUP
+    global _CONTEXT_PARALLEL_GROUP_RANKS
+    assert (
+        _CONTEXT_PARALLEL_GROUP is None
+    ), "Context parallelism is already initialized."
+    for i in range(data_parallel_size):
+        for j in range(pipeline_length):
+            for k in range(model_parallel_size):
+                ranks = groups[i, j, :, k].tolist()
+                group = torch.distributed.new_group(ranks, backend=cp_backend)
+                if i == found[0] and j == found[1] and k == found[3]:
+                    _CONTEXT_PARALLEL_GROUP = group
+                    _CONTEXT_PARALLEL_GROUP_RANKS = ranks
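
Continuing the illustrative 8-rank layout from the earlier example, the nested loops above would place rank 5 in one group per parallelism axis, each taken as a slice of the grid and created with its own backend:

import torch

groups = torch.arange(8).reshape(2, 1, 2, 2)  # (data, pipeline, context, model), as above
d, p, c, m = 1, 0, 0, 1                       # coordinates of rank 5 from `found`

print(groups[:, p, c, m].tolist())  # data parallel peers (ddp_backend):             [1, 5]
print(groups[d, p, c, :].tolist())  # model parallel peers (model_parallel_backend): [4, 5]
print(groups[d, :, c, m].tolist())  # pipeline peers (pipeline_backend):             [5]
print(groups[d, p, :, m].tolist())  # context parallel peers (cp_backend):           [5, 7]
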
 def model_parallel_is_initialized() -> bool:
     """Check if model and data parallel groups are initialized."""
-    if _MODEL_PARALLEL_GROUP is None or _DATA_PARALLEL_GROUP is None or _PIPELINE_PARALLEL_GROUP is None:
+    if _MODEL_PARALLEL_GROUP is None or _DATA_PARALLEL_GROUP is None or _PIPELINE_PARALLEL_GROUP is None or _CONTEXT_PARALLEL_GROUP is None:
         return False
     return True
+
+def get_context_parallel_group() -> torch.distributed.ProcessGroup:
+    """Get the context parallel group the caller rank belongs to."""
+    assert (
+        _CONTEXT_PARALLEL_GROUP is not None
+    ), "context parallel group is not initialized"
+    return _CONTEXT_PARALLEL_GROUP
+
+
+def get_context_parallel_world_size() -> int:
+    """Return world size for the context parallel group."""
+    return torch.distributed.get_world_size(group=get_context_parallel_group())
+
+
+def get_context_parallel_rank() -> int:
+    """Return my rank for the context parallel group."""
+    return torch.distributed.get_rank(group=get_context_parallel_group())
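
Once initialize_model_parallel has been called with context_parallel_size > 1 on every rank, these getters can scope collectives to the context parallel group. A short sketch; the all_reduce_over_cp_group helper is hypothetical, not part of fairscale:

import torch
import torch.distributed as dist
from fairscale.nn.model_parallel.initialize import (
    get_context_parallel_group,
    get_context_parallel_rank,
    get_context_parallel_world_size,
)

def all_reduce_over_cp_group(t: torch.Tensor) -> torch.Tensor:
    """Sum a tensor across the ranks that share this rank's data/pipeline/model coordinates."""
    if get_context_parallel_world_size() > 1:
        dist.all_reduce(t, group=get_context_parallel_group())
    return t

# e.g. each context-parallel rank could hold partial statistics for its slice
# of the sequence and combine them across the group:
# stats = all_reduce_over_cp_group(stats)
# print(get_context_parallel_rank(), get_context_parallel_world_size())
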

 def get_model_parallel_group() -> torch.distributed.ProcessGroup:
     """Get the model parallel group the caller rank belongs to."""
     assert _MODEL_PARALLEL_GROUP is not None, "model parallel group is not initialized"
@@ -179,10 +225,16 @@ def destroy_model_parallel() -> None:
     """Set the groups to none."""
     global _MODEL_PARALLEL_GROUP
     _MODEL_PARALLEL_GROUP = None
     global _DATA_PARALLEL_GROUP
     _DATA_PARALLEL_GROUP = None
     global _PIPELINE_PARALLEL_GROUP
     _PIPELINE_PARALLEL_GROUP = None
     global _PIPELINE_PARALLEL_RANKS
     _PIPELINE_PARALLEL_RANKS = None
+    global _CONTEXT_PARALLEL_GROUP
+    _CONTEXT_PARALLEL_GROUP = None
+    global _CONTEXT_PARALLEL_GROUP_RANKS
+    _CONTEXT_PARALLEL_GROUP_RANKS = None