OpenDAS / Megatron-LM · Commit 5489bda9

Authored Feb 13, 2021 by Deepak Narayanan
Parent: 626645c0

More comments and some cleanup (e.g., better variable names)
Changes: 5 changed files with 85 additions and 36 deletions (+85 -36)

  megatron/model/transformer.py    +13  -1
  megatron/mpu/initialize.py       +13  -9
  megatron/p2p_communication.py    +42 -10
  megatron/schedules.py            +12 -11
  megatron/training.py              +5  -5
megatron/model/transformer.py

@@ -554,12 +554,24 @@ class ParallelTransformer(MegatronModule):
                 self_attn_mask_type=self_attn_mask_type)
         if args.virtual_pipeline_model_parallel_size is not None:
             assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, \
-                'num_layers_per_stage must be divisible by virtual_pipeline_model_parallel_size'
+                'num_layers_per_stage must be divisible by ' \
+                'virtual_pipeline_model_parallel_size'
+            # Number of layers in each model chunk is the number of layers in the stage,
+            # divided by the number of model chunks in a stage.
             self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size
+            # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
+            # layers to stages like (each list is a model chunk):
+            # Stage 0: [0]  [2]  [4]  [6]
+            # Stage 1: [1]  [3]  [5]  [7]
+            # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
+            # layers to stages like (each list is a model chunk):
+            # Stage 0: [0, 1]  [4, 5]
+            # Stage 1: [2, 3]  [6, 7]
             offset = mpu.get_virtual_pipeline_model_parallel_rank() * (
                 args.num_layers // args.virtual_pipeline_model_parallel_size) + \
                 (mpu.get_pipeline_model_parallel_rank() * self.num_layers)
         else:
+            # Each stage gets a contiguous set of layers.
             offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers
         self.layers = torch.nn.ModuleList(
             [build_layer(i + 1 + offset) for i in range(self.num_layers)])
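Note: the new comments describe how layers are divided between pipeline stages and model chunks when interleaved (virtual) pipelining is enabled. As a quick illustration of the offset arithmetic in this hunk, here is a minimal standalone sketch; the helper name `layer_offset` and the explicit rank/size arguments are illustrative only, since the real code reads these values from `args` and `mpu`.

    # Standalone sketch of the offset arithmetic above (illustrative helper,
    # not part of Megatron); ranks and sizes are passed in explicitly.
    def layer_offset(num_layers, pipeline_rank, pipeline_size,
                     virtual_rank=None, virtual_size=None):
        if virtual_size is not None:
            layers_per_chunk = num_layers // pipeline_size // virtual_size
            return virtual_rank * (num_layers // virtual_size) + \
                pipeline_rank * layers_per_chunk
        # Without virtual stages, each stage gets a contiguous block of layers.
        return pipeline_rank * (num_layers // pipeline_size)

    # 8 layers, 2 stages, 2 virtual stages: the offsets reproduce the comment's
    # assignment (Stage 0: [0, 1] [4, 5]; Stage 1: [2, 3] [6, 7]).
    for stage in range(2):
        print([layer_offset(8, stage, 2, v, 2) for v in range(2)])
    # stage 0 -> [0, 4], stage 1 -> [2, 6]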
megatron/mpu/initialize.py

@@ -271,10 +271,8 @@ def get_pipeline_model_parallel_rank():
 def is_pipeline_first_stage(ignore_virtual=False):
     """Return True if in the first pipeline model-parallel stage, False otherwise."""
     if not ignore_virtual:
-        global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
-        global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
-        if _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE is not None and \
-            _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK != 0:
+        if get_virtual_pipeline_model_parallel_world_size() is not None and \
+            get_virtual_pipeline_model_parallel_rank() != 0:
             return False
     return get_pipeline_model_parallel_rank() == 0

@@ -282,11 +280,11 @@ def is_pipeline_first_stage(ignore_virtual=False):
 def is_pipeline_last_stage(ignore_virtual=False):
     """Return True if in the last pipeline model-parallel stage, False otherwise."""
     if not ignore_virtual:
-        global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
-        global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
-        if _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE is not None and \
-            _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK != (
-                _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE - 1):
+        virtual_pipeline_model_parallel_world_size = \
+            get_virtual_pipeline_model_parallel_world_size()
+        if virtual_pipeline_model_parallel_world_size is not None and \
+            get_virtual_pipeline_model_parallel_rank() != (
+                virtual_pipeline_model_parallel_world_size - 1):
             return False
     return get_pipeline_model_parallel_rank() == (
         get_pipeline_model_parallel_world_size() - 1)

@@ -304,6 +302,12 @@ def set_virtual_pipeline_model_parallel_rank(rank):
     _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = rank


+def get_virtual_pipeline_model_parallel_world_size():
+    """Return the virtual pipeline-parallel world size."""
+    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
+    return _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
+
+
 def get_tensor_model_parallel_src_rank():
     """Calculate the global rank corresponding to the first local rank
     in the tensor model parallel group."""
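Note: the cleanup above routes the first-stage and last-stage checks through the new get_virtual_pipeline_model_parallel_world_size() accessor instead of reaching into module-level globals. A minimal sketch of the same decision logic, with the rank and size values passed in explicitly (an illustrative simplification; the real functions read process-group state via mpu):

    # Standalone sketch of the first/last-stage checks above; rank and
    # world-size values are plain arguments here for illustration.
    def is_first_stage(pipeline_rank, virtual_rank=None, virtual_size=None):
        # With interleaving, only virtual rank 0 on pipeline rank 0 holds the
        # true first model chunk.
        if virtual_size is not None and virtual_rank != 0:
            return False
        return pipeline_rank == 0

    def is_last_stage(pipeline_rank, pipeline_size,
                      virtual_rank=None, virtual_size=None):
        # Symmetrically, only the last virtual rank on the last pipeline rank
        # holds the true last model chunk.
        if virtual_size is not None and virtual_rank != virtual_size - 1:
            return False
        return pipeline_rank == pipeline_size - 1

    # 2 pipeline stages, 2 virtual stages: only (rank 0, chunk 0) is "first"
    # and only (rank 1, chunk 1) is "last".
    assert is_first_stage(0, virtual_rank=0, virtual_size=2)
    assert not is_first_stage(0, virtual_rank=1, virtual_size=2)
    assert is_last_stage(1, 2, virtual_rank=1, virtual_size=2)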
megatron/p2p_communication.py

@@ -23,7 +23,24 @@ from megatron import mpu
 def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
                  use_ring_exchange=False):
-    """Communicate tensors between stages."""
+    """Communicate tensors between stages. Used as helper method in other
+    communication methods that are used in megatron/schedules.py.
+
+    Takes the following arguments:
+        tensor_send_next: tensor to send to next rank (no tensor sent if
+                          set to None).
+        tensor_send_prev: tensor to send to prev rank (no tensor sent if
+                          set to None).
+        recv_prev: boolean for whether tensor should be received from
+                   previous rank.
+        recv_next: boolean for whether tensor should be received from
+                   next rank.
+        use_ring_exchange: boolean for whether torch.distributed.ring_exchange()
+                           API should be used.
+
+    Returns:
+        (tensor_recv_prev, tensor_recv_next)
+    """
     args = get_args()

     # Create placeholder tensors for receive in forward and backward directions
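Note: the expanded docstring pins down the helper's contract: each send tensor and receive flag is independent, and the result is always the pair (tensor_recv_prev, tensor_recv_next), with None in any slot that was not requested. Below is a hedged sketch of how the wrapper functions later in this file can be read against that contract; the keyword usage is illustrative, since the wrappers' actual call sites are collapsed in this hunk.

    # Illustrative only: reading the wrappers below in terms of _communicate's
    # documented contract. Exact call sites are not shown in this hunk.
    from megatron.p2p_communication import _communicate

    def recv_forward_sketch():
        # Forward receive: only a tensor from the previous stage is wanted.
        input_tensor, _ = _communicate(
            tensor_send_next=None, tensor_send_prev=None,
            recv_prev=True, recv_next=False)
        return input_tensor

    def send_forward_recv_backward_sketch(output_tensor):
        # Send activations forward and receive the returning gradient.
        _, output_tensor_grad = _communicate(
            tensor_send_next=output_tensor, tensor_send_prev=None,
            recv_prev=False, recv_next=True)
        return output_tensor_grad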
@@ -50,6 +67,7 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
                                        device=torch.cuda.current_device(),
                                        dtype=dtype)

+    # Split tensor into smaller chunks if using scatter-gather optimization.
     if args.scatter_gather_tensors_in_pipeline:
         if tensor_send_next is not None:
             tensor_send_next = mpu.split_tensor_into_1d_equal_chunks(tensor_send_next)

@@ -67,27 +85,32 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
     else:
         ops = []
         if tensor_send_prev is not None:
-            send_prev_op = torch.distributed.P2POp(torch.distributed.isend, tensor_send_prev,
-                                                   mpu.get_pipeline_model_parallel_prev_rank())
+            send_prev_op = torch.distributed.P2POp(
+                torch.distributed.isend, tensor_send_prev,
+                mpu.get_pipeline_model_parallel_prev_rank())
             ops.append(send_prev_op)
         if tensor_recv_prev is not None:
-            recv_prev_op = torch.distributed.P2POp(torch.distributed.irecv, tensor_recv_prev,
-                                                   mpu.get_pipeline_model_parallel_prev_rank())
+            recv_prev_op = torch.distributed.P2POp(
+                torch.distributed.irecv, tensor_recv_prev,
+                mpu.get_pipeline_model_parallel_prev_rank())
             ops.append(recv_prev_op)
         if tensor_send_next is not None:
-            send_next_op = torch.distributed.P2POp(torch.distributed.isend, tensor_send_next,
-                                                   mpu.get_pipeline_model_parallel_next_rank())
+            send_next_op = torch.distributed.P2POp(
+                torch.distributed.isend, tensor_send_next,
+                mpu.get_pipeline_model_parallel_next_rank())
             ops.append(send_next_op)
         if tensor_recv_next is not None:
-            recv_next_op = torch.distributed.P2POp(torch.distributed.irecv, tensor_recv_next,
-                                                   mpu.get_pipeline_model_parallel_next_rank())
+            recv_next_op = torch.distributed.P2POp(
+                torch.distributed.irecv, tensor_recv_next,
+                mpu.get_pipeline_model_parallel_next_rank())
             ops.append(recv_next_op)
         reqs = torch.distributed.batch_isend_irecv(ops)
         for req in reqs:
             req.wait()
+    # To protect against race condition when using batch_isend_irecv().
     torch.cuda.synchronize()

-    tensor_recv_prev_before = tensor_recv_prev
+    # If using scatter-gather optimization, gather smaller chunks.
     if args.scatter_gather_tensors_in_pipeline:
         if recv_prev:
             tensor_recv_prev = mpu.gather_split_1d_tensor(
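Note: the reindented calls above all feed torch.distributed.P2POp descriptors into a single torch.distributed.batch_isend_irecv() call, followed by the newly commented synchronization. For reference, a minimal self-contained sketch of that pattern between two neighboring ranks; the function name, shapes, and the assumption that the process group (NCCL backend) is already initialized are illustrative:

    # Minimal sketch of the batched P2P pattern used in _communicate:
    # build a list of P2POp descriptors, launch them together, then wait.
    import torch
    import torch.distributed as dist

    def exchange_with_neighbors(tensor_send_next, recv_prev_shape,
                                prev_rank, next_rank, device='cuda'):
        """Optionally send to the next rank and receive from the previous
        rank in a single batch_isend_irecv() call (illustrative helper)."""
        ops, tensor_recv_prev = [], None
        if tensor_send_next is not None:
            ops.append(dist.P2POp(dist.isend, tensor_send_next, next_rank))
        if recv_prev_shape is not None:
            tensor_recv_prev = torch.empty(recv_prev_shape, device=device)
            ops.append(dist.P2POp(dist.irecv, tensor_recv_prev, prev_rank))
        if ops:
            reqs = dist.batch_isend_irecv(ops)
            for req in reqs:
                req.wait()
            # As in the diff: synchronize to protect against a race condition
            # after batch_isend_irecv().
            torch.cuda.synchronize()
        return tensor_recv_prev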
@@ -101,6 +124,7 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
 def recv_forward(timers=None, use_ring_exchange=False):
+    """Receive tensor from previous rank in pipeline (forward receive)."""
     if mpu.is_pipeline_first_stage():
         input_tensor = None
     else:

@@ -118,6 +142,7 @@ def recv_forward(timers=None, use_ring_exchange=False):
 def recv_backward(timers=None, use_ring_exchange=False):
+    """Receive tensor from next rank in pipeline (backward receive)."""
     if mpu.is_pipeline_last_stage():
         output_tensor_grad = None
     else:

@@ -135,6 +160,7 @@ def recv_backward(timers=None, use_ring_exchange=False):
 def send_forward(output_tensor, timers=None, use_ring_exchange=False):
+    """Send tensor to next rank in pipeline (forward send)."""
     if not mpu.is_pipeline_last_stage():
         if timers is not None:
             timers('forward-send').start()

@@ -149,6 +175,7 @@ def send_forward(output_tensor, timers=None, use_ring_exchange=False):
 def send_backward(input_tensor_grad, timers=None, use_ring_exchange=False):
+    """Send tensor to previous rank in pipeline (backward send)."""
     if not mpu.is_pipeline_first_stage():
         if timers is not None:
             timers('backward-send').start()

@@ -163,6 +190,7 @@ def send_backward(input_tensor_grad, timers=None, use_ring_exchange=False):
 def send_forward_recv_backward(output_tensor, timers=None, use_ring_exchange=False):
+    """Batched send and recv with next rank in pipeline."""
     if mpu.is_pipeline_last_stage():
         output_tensor_grad = None
     else:

@@ -180,6 +208,7 @@ def send_forward_recv_backward(output_tensor, timers=None, use_ring_exchange=Fal
 def send_backward_recv_forward(input_tensor_grad, timers=None, use_ring_exchange=False):
+    """Batched send and recv with previous rank in pipeline."""
     if mpu.is_pipeline_first_stage():
         input_tensor = None
     else:

@@ -197,6 +226,7 @@ def send_backward_recv_forward(input_tensor_grad, timers=None, use_ring_exchange
 def send_forward_recv_forward(output_tensor, recv_prev, timers=None):
+    """Batched recv from previous rank and send to next rank in pipeline."""
     if timers is not None:
         timers('forward-send-forward-recv').start()
     input_tensor, _ = _communicate(

@@ -211,6 +241,7 @@ def send_forward_recv_forward(output_tensor, recv_prev, timers=None):
 def send_backward_recv_backward(input_tensor_grad, recv_next, timers=None):
+    """Batched recv from next rank and send to previous rank in pipeline."""
     if timers is not None:
         timers('backward-send-backward-recv').start()
     _, output_tensor_grad = _communicate(

@@ -227,6 +258,7 @@ def send_backward_recv_backward(input_tensor_grad, recv_next, timers=None):
 def send_forward_backward_recv_forward_backward(
         output_tensor, input_tensor_grad, recv_prev,
         recv_next, timers=None):
+    """Batched send and recv with previous and next ranks in pipeline."""
     if timers is not None:
         timers('forward-backward-send-forward-backward-recv').start()
     input_tensor, output_tensor_grad = _communicate(
megatron/schedules.py

@@ -136,19 +136,19 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
     num_microbatches_remaining = \
         num_microbatches - num_warmup_microbatches

-    def get_model_chunk_id(k, forward):
+    def get_model_chunk_id(microbatch_id, forward):
         """Helper method to get the model chunk ID given the iteration number."""
-        k_in_group = k % (pipeline_parallel_size * num_model_chunks)
-        i = k_in_group // pipeline_parallel_size
+        microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
+        model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
         if not forward:
-            i = (num_model_chunks - i - 1)
-        return i
+            model_chunk_id = (num_model_chunks - model_chunk_id - 1)
+        return model_chunk_id

-    def forward_step_helper(k):
+    def forward_step_helper(microbatch_id):
         """Helper method to run forward step with model split into chunks
         (run set_virtual_pipeline_model_parallel_rank() before calling
         forward_step())."""
-        model_chunk_id = get_model_chunk_id(k, forward=True)
+        model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
         mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
         if mpu.is_pipeline_first_stage():

@@ -164,11 +164,11 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
             return output_tensor

-    def backward_step_helper(k):
+    def backward_step_helper(microbatch_id):
         """Helper method to run backward step with model split into chunks
         (run set_virtual_pipeline_model_parallel_rank() before calling
         backward_step())."""
-        model_chunk_id = get_model_chunk_id(k, forward=False)
+        model_chunk_id = get_model_chunk_id(microbatch_id, forward=False)
         mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
         if mpu.is_pipeline_last_stage():

@@ -317,8 +317,9 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
     return losses_reduced


-def forward_backward_pipelining(forward_step_func, data_iterator, model,
-                                optimizer, timers, forward_only):
+def forward_backward_pipelining_without_interleaving(forward_step_func, data_iterator,
+                                                     model, optimizer, timers,
+                                                     forward_only):
     """Run non-interleaved 1F1B schedule, with communication between pipeline
     stages.
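Note: the renamed variables make the interleaving arithmetic easier to follow: microbatches are processed in groups of pipeline_parallel_size * num_model_chunks, the position within the group picks the model chunk, and the backward pass walks the chunks in reverse. A standalone copy of that arithmetic (with the sizes passed in explicitly, values chosen only for illustration) to sanity-check the mapping:

    # Standalone copy of get_model_chunk_id's arithmetic for a quick check.
    def get_model_chunk_id(microbatch_id, forward,
                           pipeline_parallel_size, num_model_chunks):
        microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
        model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
        if not forward:
            model_chunk_id = num_model_chunks - model_chunk_id - 1
        return model_chunk_id

    # With 2 pipeline stages and 2 model chunks, the group size is 4:
    # forward microbatches 0,1 run on chunk 0 and 2,3 on chunk 1, then repeat.
    print([get_model_chunk_id(m, True, 2, 2) for m in range(8)])
    # -> [0, 0, 1, 1, 0, 0, 1, 1]
    print([get_model_chunk_id(m, False, 2, 2) for m in range(8)])
    # -> [1, 1, 0, 0, 1, 1, 0, 0]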
megatron/training.py

@@ -50,7 +50,7 @@ from megatron.utils import unwrap_model
 from megatron.data.data_samplers import build_pretraining_data_loader
 from megatron.utils import calc_params_l2_norm
 from megatron.schedules import forward_backward_no_pipelining
-from megatron.schedules import forward_backward_pipelining
+from megatron.schedules import forward_backward_pipelining_without_interleaving
 from megatron.schedules import forward_backward_pipelining_with_interleaving
 from megatron.utils import report_memory

@@ -340,7 +340,7 @@ def train_step(forward_step_func, data_iterator,
         if args.virtual_pipeline_model_parallel_size is not None:
             forward_backward_func = forward_backward_pipelining_with_interleaving
         else:
-            forward_backward_func = forward_backward_pipelining
+            forward_backward_func = forward_backward_pipelining_without_interleaving
     else:
         forward_backward_func = forward_backward_no_pipelining
     losses_reduced = forward_backward_func(

@@ -681,7 +681,7 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
         if args.virtual_pipeline_model_parallel_size is not None:
             forward_backward_func = forward_backward_pipelining_with_interleaving
         else:
-            forward_backward_func = forward_backward_pipelining
+            forward_backward_func = forward_backward_pipelining_without_interleaving
     else:
         forward_backward_func = forward_backward_no_pipelining
     loss_dicts = forward_backward_func(

@@ -692,8 +692,8 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
             # Reduce across processes.
             for loss_dict in loss_dicts:
                 for key in loss_dict:
-                    total_loss_dict[key] = total_loss_dict.get(key, torch.cuda.FloatTensor([0.0])) + \
-                        loss_dict[key]
+                    total_loss_dict[key] = total_loss_dict.get(
+                        key, torch.cuda.FloatTensor([0.0])) + loss_dict[key]

             args.consumed_valid_samples += mpu.get_data_parallel_world_size() \
                                            * args.micro_batch_size \
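Note: both train_step() and evaluate() now select among the three schedules the same way. A condensed sketch of that dispatch follows; the outer single-stage/multi-stage check is an assumption here (the surrounding condition sits above the lines shown in these hunks), while the three schedule names and imports come straight from the diff.

    from megatron.schedules import forward_backward_no_pipelining
    from megatron.schedules import forward_backward_pipelining_without_interleaving
    from megatron.schedules import forward_backward_pipelining_with_interleaving

    def get_forward_backward_func(args, pipeline_model_parallel_size):
        # Interleaved 1F1B when virtual pipeline stages are configured,
        # plain 1F1B for multi-stage pipelines otherwise, and the
        # no-pipelining path for a single stage (outer check assumed).
        if pipeline_model_parallel_size > 1:
            if args.virtual_pipeline_model_parallel_size is not None:
                return forward_backward_pipelining_with_interleaving
            return forward_backward_pipelining_without_interleaving
        return forward_backward_no_pipelining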