OpenDAS / Megatron-LM · Commits

Commit 77753d0a, authored Sep 27, 2022 by Jared Casper

Small fixes.

parent 5f4ddd9b
Showing 2 changed files with 10 additions and 9 deletions (+10, -9):

megatron/core/parallel_state.py (+5, -5)
megatron/core/tensor_parallel/utils.py (+5, -4)
megatron/core/parallel_state.py

@@ -174,14 +174,14 @@ def initialize_model_parallel(
     if len(ranks) > 1:
         embedding_ranks = [ranks[0], ranks[-1]]
         position_embedding_ranks = [ranks[0]]
-        if pipeline_model_parallel_split_rank_ is not None:
-            if ranks[pipeline_model_parallel_split_rank_] not in embedding_ranks:
+        if pipeline_model_parallel_split_rank is not None:
+            if ranks[pipeline_model_parallel_split_rank] not in embedding_ranks:
                 embedding_ranks = [ranks[0],
-                                   ranks[pipeline_model_parallel_split_rank_],
+                                   ranks[pipeline_model_parallel_split_rank],
                                    ranks[-1]]
-            if ranks[pipeline_model_parallel_split_rank_] not in position_embedding_ranks:
+            if ranks[pipeline_model_parallel_split_rank] not in position_embedding_ranks:
                 position_embedding_ranks = [ranks[0],
-                                            ranks[pipeline_model_parallel_split_rank_]]
+                                            ranks[pipeline_model_parallel_split_rank]]
     else:
         embedding_ranks = ranks
         position_embedding_ranks = ranks
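This hunk renames pipeline_model_parallel_split_rank_ to pipeline_model_parallel_split_rank throughout the body, presumably to match the argument name in the enclosing initialize_model_parallel signature. The argument marks the pipeline stage at the encoder/decoder split so that rank also joins the embedding groups. A minimal sketch of the selection logic on a hypothetical pipeline group (example values only, not from the commit):

ranks = [0, 4, 8, 12]                    # one pipeline-model-parallel group (hypothetical)
pipeline_model_parallel_split_rank = 2   # pipeline stage index of the split

embedding_ranks = [ranks[0], ranks[-1]]  # first and last pipeline stages
position_embedding_ranks = [ranks[0]]    # first pipeline stage
if pipeline_model_parallel_split_rank is not None:
    if ranks[pipeline_model_parallel_split_rank] not in embedding_ranks:
        embedding_ranks = [ranks[0],
                           ranks[pipeline_model_parallel_split_rank],
                           ranks[-1]]
    if ranks[pipeline_model_parallel_split_rank] not in position_embedding_ranks:
        position_embedding_ranks = [ranks[0],
                                    ranks[pipeline_model_parallel_split_rank]]

print(embedding_ranks)           # [0, 8, 12]
print(position_embedding_ranks)  # [0, 8]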
megatron/core/tensor_parallel/utils.py

@@ -4,6 +4,7 @@ import torch
 from typing import List, Sequence
 
 from megatron.core.utils import divide
+from megatron.core import parallel_state
 
 def split_tensor_along_last_dim(
     tensor: torch.Tensor,
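The added import gives the helpers below a resolvable home for the rank and world-size getters; the following hunks qualify each bare call with parallel_state accordingly. For reference, split_tensor_along_last_dim chunks a tensor along its final dimension; only its first parameter is visible in this hunk, so the sketch below reproduces the effect with plain torch.split rather than assuming the rest of the signature:

import torch

# Plain-torch equivalent of splitting along the last dimension into
# three equal partitions (illustrative only, not the Megatron code).
x = torch.arange(12.0).reshape(2, 6)
last_dim_size = x.size(-1) // 3                 # width of each partition
chunks = torch.split(x, last_dim_size, dim=-1)  # tuple of 3 views, each (2, 2)
assert all(c.shape == (2, 2) for c in chunks)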
@@ -47,8 +48,8 @@ def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
     """
     partition_size = torch.numel(tensor) // \
-        get_tensor_model_parallel_world_size()
-    start_index = partition_size * get_tensor_model_parallel_rank()
+        parallel_state.get_tensor_model_parallel_world_size()
+    start_index = partition_size * parallel_state.get_tensor_model_parallel_rank()
     end_index = start_index + partition_size
     if new_buffer:
         data = torch.empty(partition_size, dtype=tensor.dtype,
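The two rewritten lines only change where the getters are resolved from; the chunking arithmetic itself is untouched. A self-contained sketch of that arithmetic for a single rank, with world_size and rank as stand-in values instead of an initialized tensor-parallel group:

import torch

world_size, rank = 4, 1              # stand-ins for the parallel_state getters
tensor = torch.arange(16.0)

partition_size = torch.numel(tensor) // world_size  # 16 // 4 = 4
start_index = partition_size * rank                 # 4
end_index = start_index + partition_size            # 8
print(tensor[start_index:end_index])                # tensor([4., 5., 6., 7.])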
@@ -70,7 +71,7 @@ def gather_split_1d_tensor(tensor):
         tensor: A Tensor or view of this rank's portion of the data.
     """
     numel_gathered = torch.numel(tensor) * \
-        get_tensor_model_parallel_world_size()
+        parallel_state.get_tensor_model_parallel_world_size()
     gathered = torch.empty(numel_gathered, dtype=tensor.dtype,
                            device=torch.cuda.current_device(),
                            requires_grad=False)
@@ -80,7 +81,7 @@ def gather_split_1d_tensor(tensor):
     # This API calls directly NCCL all-gather versus the former does
     # internal copies and can potentially cause slow down.
     torch.distributed._all_gather_base(gathered, tensor,
-                                       group=get_tensor_model_parallel_group())
+                                       group=parallel_state.get_tensor_model_parallel_group())
     return gathered
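Together, split_tensor_into_1d_equal_chunks and gather_split_1d_tensor form a scatter/gather pair over the tensor-model-parallel group. The sketch below shows the round trip against a generic torch.distributed process group; it is hypothetical driver code, assumes the process group is already initialized, and uses all_gather_into_tensor, the public name PyTorch later gave to the private _all_gather_base called above:

import torch
import torch.distributed as dist

def round_trip(full: torch.Tensor, group) -> torch.Tensor:
    # Hypothetical driver, not part of the commit. Assumes
    # dist.init_process_group(...) has already run and every rank
    # passes the same `full` tensor.
    world_size = dist.get_world_size(group=group)
    rank = dist.get_rank(group=group)
    partition_size = torch.numel(full) // world_size
    local = full.view(-1)[rank * partition_size:(rank + 1) * partition_size]
    gathered = torch.empty(partition_size * world_size, dtype=full.dtype,
                           device=local.device, requires_grad=False)
    dist.all_gather_into_tensor(gathered, local.contiguous(), group=group)
    return gathered  # equals full.view(-1) on every rank

The single-tensor all-gather is preferred here because, as the source comment notes, the list-based torch.distributed.all_gather performs internal copies that can slow things down on NCCL backends.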