OpenDAS / Megatron-LM · Commits

Commit e3e5ea89, authored Jan 20, 2021 by Deepak Narayanan
Parent: 27fc4689

Compute tensor chunk size more cleanly, and add assertion for global batch size
Showing 2 changed files with 9 additions and 4 deletions (+9 -4):

    megatron/arguments.py          +5 -1
    megatron/p2p_communication.py  +4 -3
megatron/arguments.py

@@ -70,7 +70,7 @@ def parse_args(extra_args_provider=None, defaults={},
     model_parallel_size = args.pipeline_model_parallel_size * \
                           args.tensor_model_parallel_size
     assert args.world_size % model_parallel_size == 0, 'world size is not' \
-        ' divisible by tensor parallel size ({}) times pipeline paralle ' \
+        ' divisible by tensor parallel size ({}) times pipeline parallel ' \
         'size ({})'.format(args.world_size, args.tensor_model_parallel_size,
                            args.pipeline_model_parallel_size)
     args.data_parallel_size = args.world_size // model_parallel_size

@@ -116,6 +116,10 @@ def parse_args(extra_args_provider=None, defaults={},
             print('setting global batch size to {}'.format(
                 args.global_batch_size), flush=True)
     assert args.global_batch_size > 0
+    if args.virtual_pipeline_model_parallel_size is not None:
+        assert args.global_batch_size % args.pipeline_model_parallel_size == 0, \
+            'global batch size is not divisible by pipeline parallel size when ' \
+            'using interleaved schedule'

     # Parameters dtype.
     args.params_dtype = torch.float
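The assertion added here only applies when the interleaved (virtual pipeline) schedule is in use. A minimal standalone sketch of the same check, with hypothetical values standing in for the arguments parsed in megatron/arguments.py:

    # Hypothetical values for illustration only; in Megatron-LM they come
    # from the parsed command-line arguments.
    global_batch_size = 512
    pipeline_model_parallel_size = 8
    virtual_pipeline_model_parallel_size = 2  # interleaved schedule enabled

    assert global_batch_size > 0
    if virtual_pipeline_model_parallel_size is not None:
        # Same rule as the new assertion: with the interleaved schedule the
        # global batch size must split evenly across pipeline-parallel ranks.
        assert global_batch_size % pipeline_model_parallel_size == 0, \
            'global batch size is not divisible by pipeline parallel size when ' \
            'using interleaved schedule'

With these numbers the check passes (512 % 8 == 0); a global batch size of, say, 510 would now fail fast during argument parsing.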
megatron/p2p_communication.py

@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from functools import reduce
+import operator
 import torch

 from megatron import get_args

@@ -30,9 +32,8 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
     tensor_recv_next = None
     tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
     if args.scatter_gather_tensors_in_pipeline:
-        tensor_chunk_shape = (
-            args.seq_length * args.micro_batch_size * args.hidden_size) // \
-            mpu.get_tensor_model_parallel_world_size()
+        tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1) // \
+            mpu.get_tensor_model_parallel_world_size()
     else:
         tensor_chunk_shape = tensor_shape
     dtype = args.params_dtype
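The chunk-size rewrite is behavior-preserving: reduce(operator.mul, tensor_shape, 1) is simply the product of the dimensions of tensor_shape, so the result still equals seq_length * micro_batch_size * hidden_size divided by the tensor-model-parallel world size, without repeating the individual argument names. A small sketch with hypothetical sizes (the real values come from get_args() and mpu.get_tensor_model_parallel_world_size()):

    from functools import reduce
    import operator

    # Hypothetical sizes for illustration; in _communicate they come from
    # args.seq_length, args.micro_batch_size, args.hidden_size and
    # mpu.get_tensor_model_parallel_world_size().
    seq_length, micro_batch_size, hidden_size = 1024, 4, 2048
    tensor_model_parallel_world_size = 8

    tensor_shape = (seq_length, micro_batch_size, hidden_size)

    # Old form: spell out every dimension explicitly.
    old_chunk = (seq_length * micro_batch_size * hidden_size) // \
        tensor_model_parallel_world_size

    # New form: take the product of whatever dimensions tensor_shape holds.
    new_chunk = reduce(operator.mul, tensor_shape, 1) // \
        tensor_model_parallel_world_size

    assert old_chunk == new_chunk  # same chunk size, computed more cleanly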