Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
2238758c
Unverified
Commit
2238758c
authored
Apr 25, 2022
by
Frank Lee
Committed by
GitHub
Apr 25, 2022
Browse files
[usability] improved error messages in the context module (#856)
parent
9fdebadd
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
14 additions
and
18 deletions
+14
-18
colossalai/context/parallel_context.py
colossalai/context/parallel_context.py
+14
-12
colossalai/context/process_group_initializer/initializer_2p5d.py
...lai/context/process_group_initializer/initializer_2p5d.py
+0
-6
No files found.
colossalai/context/parallel_context.py
View file @
2238758c
...
@@ -4,6 +4,7 @@
...
@@ -4,6 +4,7 @@
import
random
import
random
import
socket
import
socket
from
collections
import
Counter
from
collections
import
Counter
from
threading
import
local
from
typing
import
Union
from
typing
import
Union
import
numpy
as
np
import
numpy
as
np
...
@@ -93,7 +94,8 @@ class ParallelContext(metaclass=SingletonMeta):
...
@@ -93,7 +94,8 @@ class ParallelContext(metaclass=SingletonMeta):
@staticmethod
def _check_parallel_mode(parallel_mode: ParallelMode):
    """Validate that *parallel_mode* is a member of the ParallelMode enum.

    Args:
        parallel_mode (:class:`colossalai.context.ParallelMode`): The value to validate.

    Raises:
        AssertionError: If ``parallel_mode`` is not a ``ParallelMode`` instance.
    """
    # Fail loudly with the offending type so misuse is easy to diagnose.
    mode_is_valid = isinstance(parallel_mode, ParallelMode)
    assert mode_is_valid, \
        f'expected the argument parallel_mode to be of enum ParallelMode, but got {type(parallel_mode)}'
def
get_global_rank
(
self
):
def
get_global_rank
(
self
):
"""Returns the global rank of the current device.
"""Returns the global rank of the current device.
...
@@ -133,7 +135,7 @@ class ParallelContext(metaclass=SingletonMeta):
...
@@ -133,7 +135,7 @@ class ParallelContext(metaclass=SingletonMeta):
self
.
_check_parallel_mode
(
parallel_mode
)
self
.
_check_parallel_mode
(
parallel_mode
)
return
self
.
_local_ranks
[
parallel_mode
]
return
self
.
_local_ranks
[
parallel_mode
]
def
add_local_rank
(
self
,
parallel_mode
:
ParallelMode
,
rank
:
int
):
def
_
add_local_rank
(
self
,
parallel_mode
:
ParallelMode
,
rank
:
int
):
"""Adds the local rank of the current device for `parallel_mode` to the context.
"""Adds the local rank of the current device for `parallel_mode` to the context.
Args:
Args:
...
@@ -257,11 +259,11 @@ class ParallelContext(metaclass=SingletonMeta):
...
@@ -257,11 +259,11 @@ class ParallelContext(metaclass=SingletonMeta):
self
.
_check_parallel_mode
(
parallel_mode
)
self
.
_check_parallel_mode
(
parallel_mode
)
return
self
.
_world_sizes
[
parallel_mode
]
return
self
.
_world_sizes
[
parallel_mode
]
def
add_world_size
(
self
,
parallel_mode
:
ParallelMode
,
world_size
:
int
):
def
_
add_world_size
(
self
,
parallel_mode
:
ParallelMode
,
world_size
:
int
):
"""Adds world size for `parallel_mode`.
"""Adds world size for `parallel_mode`.
Args:
Args:
parallel_mode (:class:`colossalai.context.ParallelMode`): The
chosen
parallel mode
.
parallel_mode (:class:`colossalai.context.ParallelMode`): The parallel mode
corresponding to the process group
world_size (int): The world size to be added
world_size (int): The world size to be added
Raises:
Raises:
...
@@ -287,7 +289,7 @@ class ParallelContext(metaclass=SingletonMeta):
...
@@ -287,7 +289,7 @@ class ParallelContext(metaclass=SingletonMeta):
self
.
_check_parallel_mode
(
parallel_mode
)
self
.
_check_parallel_mode
(
parallel_mode
)
return
self
.
_groups
[
parallel_mode
]
return
self
.
_groups
[
parallel_mode
]
def
add_group
(
self
,
parallel_mode
:
ParallelMode
,
group
:
dist
.
ProcessGroup
):
def
_
add_group
(
self
,
parallel_mode
:
ParallelMode
,
group
:
dist
.
ProcessGroup
):
"""Adds the group of the current device for `parallel_mode`.
"""Adds the group of the current device for `parallel_mode`.
Args:
Args:
...
@@ -314,7 +316,7 @@ class ParallelContext(metaclass=SingletonMeta):
...
@@ -314,7 +316,7 @@ class ParallelContext(metaclass=SingletonMeta):
self
.
_check_parallel_mode
(
parallel_mode
)
self
.
_check_parallel_mode
(
parallel_mode
)
return
self
.
_cpu_groups
[
parallel_mode
]
return
self
.
_cpu_groups
[
parallel_mode
]
def
add_cpu_group
(
self
,
parallel_mode
:
ParallelMode
,
group
:
dist
.
ProcessGroup
):
def
_
add_cpu_group
(
self
,
parallel_mode
:
ParallelMode
,
group
:
dist
.
ProcessGroup
):
"""Adds the Gloo group of the current device for `parallel_mode`.
"""Adds the Gloo group of the current device for `parallel_mode`.
:param parallel_mode: The chosen parallel mode
:param parallel_mode: The chosen parallel mode
...
@@ -343,7 +345,7 @@ class ParallelContext(metaclass=SingletonMeta):
...
@@ -343,7 +345,7 @@ class ParallelContext(metaclass=SingletonMeta):
self
.
_check_parallel_mode
(
parallel_mode
)
self
.
_check_parallel_mode
(
parallel_mode
)
return
self
.
_ranks_in_group
[
parallel_mode
]
return
self
.
_ranks_in_group
[
parallel_mode
]
def
add_ranks_in_group
(
self
,
parallel_mode
:
ParallelMode
,
ranks
:
list
):
def
_
add_ranks_in_group
(
self
,
parallel_mode
:
ParallelMode
,
ranks
:
list
):
"""Adds the ranks of the current device for `parallel_mode` in the group.
"""Adds the ranks of the current device for `parallel_mode` in the group.
Args:
Args:
...
@@ -378,11 +380,11 @@ class ParallelContext(metaclass=SingletonMeta):
...
@@ -378,11 +380,11 @@ class ParallelContext(metaclass=SingletonMeta):
self
.
add_global_rank
(
ParallelMode
.
GLOBAL
,
rank
)
self
.
add_global_rank
(
ParallelMode
.
GLOBAL
,
rank
)
def _register_dist(self, local_rank, world_size, process_group, cpu_group, ranks_in_group, mode):
    """Record all distributed metadata for ``mode`` in this context.

    Stores the local rank, world size, device process group, Gloo (CPU)
    process group, and group member ranks under the given parallel mode.
    """
    # Each (setter, value) pair registers one piece of metadata for `mode`.
    registrations = (
        (self._add_local_rank, local_rank),
        (self._add_world_size, world_size),
        (self._add_group, process_group),
        (self._add_cpu_group, cpu_group),
        (self._add_ranks_in_group, ranks_in_group),
    )
    for register, value in registrations:
        register(mode, value)
def
check_sanity
(
self
):
def
check_sanity
(
self
):
"""Checks sanity of the parallel context.
"""Checks sanity of the parallel context.
...
...
colossalai/context/process_group_initializer/initializer_2p5d.py
View file @
2238758c
...
@@ -105,8 +105,6 @@ class Initializer_2p5D_Col(ProcessGroupInitializer):
...
@@ -105,8 +105,6 @@ class Initializer_2p5D_Col(ProcessGroupInitializer):
self
.
num_group
=
self
.
world_size
//
self
.
tensor_parallel_size
self
.
num_group
=
self
.
world_size
//
self
.
tensor_parallel_size
self
.
tesseract_dep
=
tesseract_dep
self
.
tesseract_dep
=
tesseract_dep
self
.
tesseract_dim
=
tesseract_dim
self
.
tesseract_dim
=
tesseract_dim
assert
self
.
tensor_parallel_size
==
self
.
tesseract_dim
**
2
*
self
.
tesseract_dep
,
\
"Tensor parallel size should be depth * dim ** 2 in 2.5D parallel"
def
init_dist_group
(
self
):
def
init_dist_group
(
self
):
"""Initialize 2.5D tensor col parallel groups, and assign local_ranks and groups to each gpu.
"""Initialize 2.5D tensor col parallel groups, and assign local_ranks and groups to each gpu.
...
@@ -161,8 +159,6 @@ class Initializer_2p5D_Dep(ProcessGroupInitializer):
...
@@ -161,8 +159,6 @@ class Initializer_2p5D_Dep(ProcessGroupInitializer):
self
.
num_group
=
self
.
world_size
//
self
.
tensor_parallel_size
self
.
num_group
=
self
.
world_size
//
self
.
tensor_parallel_size
self
.
tesseract_dep
=
tesseract_dep
self
.
tesseract_dep
=
tesseract_dep
self
.
tesseract_dim
=
tesseract_dim
self
.
tesseract_dim
=
tesseract_dim
assert
self
.
tensor_parallel_size
==
self
.
tesseract_dim
**
2
*
self
.
tesseract_dep
,
\
"Tensor parallel size should be depth * dim ** 2 in 2.5D parallel"
def
init_dist_group
(
self
):
def
init_dist_group
(
self
):
"""Initialize 2.5D tensor depth parallel groups, and assign local_ranks and groups to each gpu.
"""Initialize 2.5D tensor depth parallel groups, and assign local_ranks and groups to each gpu.
...
@@ -218,8 +214,6 @@ class Initializer_2p5D_XZ(ProcessGroupInitializer):
...
@@ -218,8 +214,6 @@ class Initializer_2p5D_XZ(ProcessGroupInitializer):
self
.
num_group
=
self
.
world_size
//
self
.
tensor_parallel_size
self
.
num_group
=
self
.
world_size
//
self
.
tensor_parallel_size
self
.
tesseract_dep
=
tesseract_dep
self
.
tesseract_dep
=
tesseract_dep
self
.
tesseract_dim
=
tesseract_dim
self
.
tesseract_dim
=
tesseract_dim
assert
self
.
tensor_parallel_size
==
self
.
tesseract_dim
**
2
*
self
.
tesseract_dep
,
\
"Tensor parallel size should be depth * dim ** 2 in 2.5D parallel"
def
init_dist_group
(
self
):
def
init_dist_group
(
self
):
"""Initialize 2.5D tensor colXdepth parallel groups, and assign local_ranks and groups to each gpu.
"""Initialize 2.5D tensor colXdepth parallel groups, and assign local_ranks and groups to each gpu.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment