evt_fugx1 / dcu_megatron · Commits · 62f16817

Commit 62f16817, authored Jun 10, 2025 by dongcl

    evaluate support dualpipev

Parent: 2385a133
Showing 3 changed files with 323 additions and 218 deletions:

  dcu_megatron/adaptor/features_manager/pipeline_parallel/pipeline_feature.py   +5    -0
  dcu_megatron/core/pipeline_parallel/dualpipev/dualpipev_schedules.py          +180  -217
  dcu_megatron/training/training.py                                             +138  -1
dcu_megatron/adaptor/features_manager/pipeline_parallel/pipeline_feature.py

@@ -48,6 +48,7 @@ class PipelineFeature(AbstractFeature):
     train_step,
     _allreduce_embedding_grads_wrapper
 )
+from dcu_megatron.training.training import evaluate

         patch_manager.register_patch(
             'megatron.training.training.get_model', get_model)
@@ -64,6 +65,10 @@ class PipelineFeature(AbstractFeature):
         patch_manager.register_patch(
             'megatron.core.distributed.finalize_model_grads._allreduce_embedding_grads', _allreduce_embedding_grads_wrapper)
+        # use first rank
+        patch_manager.register_patch(
+            'megatron.training.training.evaluate', evaluate)
+
         if (
             args.schedule_method == "interleaved_1f1b"
             and args.combined_1f1b
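The registration added here routes Megatron's `evaluate` through the dcu_megatron replacement. The patch manager itself is not part of this diff, so the snippet below is only a rough sketch of what `register_patch(target, replacement)` boils down to, assuming an apply-immediately manager; the real implementation may defer application, validate signatures, or keep the original for rollback.

```python
import importlib

def register_patch(target_path: str, replacement):
    """Illustrative only: replace module.attribute named by target_path.

    The real dcu_megatron patch_manager is not shown in this commit and may
    behave differently; this sketch just shows the monkey-patching idea.
    """
    module_path, attr_name = target_path.rsplit('.', 1)
    module = importlib.import_module(module_path)
    original = getattr(module, attr_name)          # keep the original around
    setattr(module, attr_name, replacement)        # swap in the replacement
    return original
```

Because `from megatron.training.training import evaluate`-style imports bind the name at import time, patches like this are normally registered before the training entry point imports the symbol.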
dcu_megatron/core/pipeline_parallel/dualpipev/dualpipev_schedules.py

@@ -25,6 +25,8 @@ from megatron.core.pipeline_parallel.schedules import (
     check_first_val_step,
     finish_embedding_wgrad_compute
 )
+from dcu_megatron.training.utils import print_rank_message
+
 # from mindspeed.core.pipeline_parallel.fb_overlap.modules.weight_grad_store import WeightGradStore
@@ -584,7 +586,11 @@ def forward_backward_pipelining_with_cutinhalf(
         fwd_model_chunk_id,
         bwd_model_chunk_id,
         fwd_input_tensor=None,
-        bwd_output_tensor_grad=None
+        bwd_output_tensor_grad=None,
+        pre_forward=None,
+        pre_backward=None,
+        post_forward=None,
+        post_backward=None,
     ):
         """Helper method to run combined forward and backward step"""
         # forward prepare
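The four new keyword arguments are optional callables that bracket the fused forward/backward step; the later hunks wire communication waits and sends into them (pp_pre_forward, pp_post_forward, pp_pre_backward, pp_post_backward). The hunk above only shows the signature, so the following is a minimal, self-contained sketch of how such hooks are commonly threaded through a combined step. The names `combined_step`, `forward_fn`, and `backward_fn` are placeholders, not functions from this file.

```python
from typing import Callable, Optional

def combined_step(
    forward_fn: Callable[[], object],
    backward_fn: Callable[[], object],
    pre_forward: Optional[Callable[[], None]] = None,
    post_forward: Optional[Callable[[object], object]] = None,
    pre_backward: Optional[Callable[[], None]] = None,
    post_backward: Optional[Callable[[object], object]] = None,
):
    # Each hook is optional and only invoked if supplied, in the order
    # pre_forward -> forward -> post_forward -> pre_backward -> backward -> post_backward.
    if pre_forward is not None:
        pre_forward()
    out = forward_fn()
    if post_forward is not None:
        out = post_forward(out)
    if pre_backward is not None:
        pre_backward()
    grad = backward_fn()
    if post_backward is not None:
        grad = post_backward(grad)
    return out, grad

# Example: the hooks are where communication waits/sends would go.
out, grad = combined_step(lambda: "output", lambda: "grad",
                          pre_forward=lambda: print("wait pending recv"))
```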
@@ -670,9 +676,9 @@ def forward_backward_pipelining_with_cutinhalf(
     total_num_tokens = torch.tensor(0, dtype=torch.int).cuda()
     forward_data_store = []
-    if not forward_only:
-        input_tensors = [[], []]
-        output_tensors = [[], []]
+    input_tensors = [[], []]
+    output_tensors = [[], []]
+    output_tensor_grads = [[], []]
     master_chunk_id = 0
     slave_chunk_id = 1
@@ -681,10 +687,23 @@ def forward_backward_pipelining_with_cutinhalf(
     num_chunk_max_microbatch = [num_microbatches, num_microbatches * 2]
     checkpoint_activations_microbatch = None
-    fwd_wait_handles_warmup = None

-    def forward_step_helper(input_tensor, model_chunk_id, cur_microbatch, is_first_microbatch=False):
+    def wait_comm_handles(comm_handles):
+        if comm_handles is None:
+            return
+        for _, req_handle in comm_handles.items():
+            if req_handle is not None:
+                req_handle.wait()
+        comm_handles = None
+
+    def forward_step_helper(model_chunk_id, cur_microbatch, is_first_microbatch=False):
         set_dualpipe_chunk(model_chunk_id)
+        if not forward_only:
+            offset = cur_bwd_chunk_microbatch[model_chunk_id]
+            input_tensor = input_tensors[model_chunk_id][cur_microbatch - offset]
+        else:
+            input_tensor = input_tensors[model_chunk_id][0]
         output_tensor, num_tokens = forward_step_no_model_graph(
             forward_step_func,
             model_chunk_id,
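The handle containers that `wait_comm_handles` drains are the dictionaries returned by the schedule's asynchronous `send_forward` / `recv_forward` helpers (name → request). With raw `torch.distributed` point-to-point ops the same idea looks like the sketch below; `irecv_forward` and the `"recv_prev"` key are made up for illustration and are not helpers from this repository.

```python
import torch
import torch.distributed as dist

def wait_comm_handles(comm_handles):
    # Same logic as the helper added in this hunk: wait every pending request.
    if comm_handles is None:
        return
    for _, req_handle in comm_handles.items():
        if req_handle is not None:
            req_handle.wait()

def irecv_forward(buffer: torch.Tensor, src: int):
    """Post an async receive and hand back {name: request}, mirroring
    recv_forward(..., async_op=True). Assumes an initialized process group."""
    return {"recv_prev": dist.irecv(buffer, src=src)}

# Later, before the received buffer is consumed:
#   wait_comm_handles(handles)
#   handles = None
```

Note that the trailing `comm_handles = None` inside the helper only rebinds the local parameter; callers that want to drop their reference still have to clear it at the call site.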
@@ -699,17 +718,17 @@ def forward_backward_pipelining_with_cutinhalf(
             is_first_microbatch=is_first_microbatch,
             current_microbatch=cur_microbatch
         )
+        output_tensors[model_chunk_id].append(output_tensor)
         nonlocal total_num_tokens
         total_num_tokens += num_tokens.item()
-        if not forward_only:
-            input_tensors[model_chunk_id].append(
-                (cur_microbatch, input_tensor))
-            output_tensors[model_chunk_id].append(output_tensor)
+        if forward_only:
+            input_tensors[model_chunk_id].pop(0)
+            output_tensors[model_chunk_id].pop()
         return output_tensor

-    def backward_step_helper(input_tensor, output_tensor, output_tensor_grad, bwd_model_chunk_id=None, bwd_cur_microbatch=None):
+    def backward_step_helper(model_chunk_id, bwd_cur_microbatch=None):
         nonlocal master_chunk_id
         nonlocal slave_chunk_id
         nonlocal num_chunk_max_microbatch
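The `if forward_only:` bookkeeping in forward_step_helper above is what keeps evaluation memory flat: during validation nothing will be replayed in backward, so the helper appends the microbatch's tensors and immediately pops them again, leaving the per-chunk queues bounded. A toy illustration of that bookkeeping, with plain lists standing in for the activation queues (the real code appends the input in the outer loop and pops it here):

```python
# Toy model of the per-chunk queues during forward_only evaluation.
input_tensors = [[], []]
output_tensors = [[], []]

def fake_forward_step(chunk_id, microbatch, forward_only=True):
    input_tensors[chunk_id].append(f"input-{microbatch}")
    output_tensors[chunk_id].append(f"output-{microbatch}")
    if forward_only:
        # Nothing is kept for backward, so drop the entries right away.
        input_tensors[chunk_id].pop(0)
        output_tensors[chunk_id].pop()

for mb in range(8):
    fake_forward_step(chunk_id=0, microbatch=mb)

print(len(input_tensors[0]), len(output_tensors[0]))   # 0 0 — the queues do not grow
```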
@@ -721,207 +740,187 @@ def forward_backward_pipelining_with_cutinhalf(

Inside backward_step_helper, the grad-sync check now tests model_chunk_id instead of the removed
bwd_model_chunk_id argument, and the tensors for the backward step are popped from the per-chunk
queues instead of being passed in:

        if (
            bwd_cur_microbatch is not None
            and bwd_cur_microbatch == num_chunk_max_microbatch[model_chunk_id] - 1
        ):
            if (
                config.grad_sync_func is None
                or (model_chunk_id == slave_chunk_id and parallel_state.is_pipeline_last_stage())
                or (model_chunk_id == master_chunk_id and parallel_state.is_pipeline_first_stage())
            ):
                enable_grad_sync()

        input_tensor = input_tensors[model_chunk_id].pop(0)
        output_tensor = output_tensors[model_chunk_id].pop(0)
        output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
        input_tensor_grad = backward_step(
            input_tensor, output_tensor, output_tensor_grad, model_type, config
        )
        return input_tensor_grad

The scalar handle variables used by the warmup, interleaved-forward and 1b1w1f phases
(fwd_wait_handles, fwd_wait_handles_warmup, fwd_wait_handles_slave_chunk, fwd_wait_handles_send,
fwd_wait_handles_recv, bwd_wait_handles, firstFB_no_overlp_handle) give way to per-chunk
containers plus references to the tensors whose asynchronous sends are still in flight:

    output_tensor_master_send = None
    output_tensor_slave_send = None
    fwd_wait_recv_handles = [None, None]
    fwd_wait_send_handles = [None, None]
    bwd_wait_recv_handles = [None, None]
    bwd_wait_send_handles = [None, None]

    # Run warmup forward passes
    input_tensor, _ = recv_forward(tensor_shape, config, master_chunk_id)
    input_tensors[master_chunk_id].append(input_tensor)
    for i in range(schedule['warmup'][rank]):
        wait_comm_handles(fwd_wait_recv_handles[master_chunk_id])
        # recv for next iteration
        input_tensor, fwd_wait_recv_handles[master_chunk_id] = recv_forward(
            tensor_shape, config, master_chunk_id, async_op=True)
        input_tensors[master_chunk_id].append(input_tensor)
        is_first_microbatch = check_first_val_step(first_val_step, forward_only, i == 0)
        output_tensor = forward_step_helper(
            master_chunk_id,
            cur_fwd_chunk_microbatch[master_chunk_id],
            is_first_microbatch=is_first_microbatch
        )
        cur_fwd_chunk_microbatch[master_chunk_id] += 1
        if fwd_wait_send_handles[master_chunk_id] is not None:
            for req, req_handle in fwd_wait_send_handles[master_chunk_id].items():
                if req_handle is not None:
                    req_handle.wait()
            fwd_wait_send_handles[master_chunk_id] = None
            if not forward_only:
                deallocate_output_tensor(
                    output_tensor_master_send, config.deallocate_pipeline_outputs)
        output_tensor_master_send = output_tensor
        fwd_wait_send_handles[master_chunk_id] = send_forward(
            output_tensor_master_send,
            tensor_shape, config, master_chunk_id, async_op=True)

The interleaved-forward and 1b1w1f phases are reworked along the same lines: inputs for both
chunks are prefetched with recv_forward(..., async_op=True) into fwd_wait_recv_handles[...] and
appended to input_tensors[...]; pending sends are drained through
wait_comm_handles(fwd_wait_send_handles[...]) before the previously sent output
(output_tensor_master_send / output_tensor_slave_send) is released with
deallocate_output_tensor(...); new sends go out as send_forward(..., async_op=True) stored back
into fwd_wait_send_handles[...]; received gradients are appended to output_tensor_grads[...] and
backward sends are tracked in bwd_wait_send_handles[...]; and the forward/backward helpers are
called with the new chunk-id-only signatures, e.g.

            input_tensor_grad = backward_step_helper(slave_chunk_id)
            cur_bwd_chunk_microbatch[slave_chunk_id] += 1

On the last pipeline stage the master chunk's output is still handed to the slave chunk locally
(input_tensor_slave = output_tensor_master.detach(); input_tensor_slave.requires_grad = True).
Ahead of the overlapped phase, the send/recv handle pair for the combined steps is initialized
next to the existing chunk-id bookkeeping:

    # Run overlaping f&bw stages
    fwd_wait_send_recv_handles = None
    bwd_wait_send_recv_handles = None
    fwd_model_chunk_id = master_chunk_id
    bwd_model_chunk_id = slave_chunk_id
    num_overlap_steps = schedule['overlap'][rank] + schedule['1b1overlap'][rank]
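The warmup and interleaved loops above follow a double-buffering pattern for asynchronous sends: wait for the previous send's handles, deallocate the tensor that has now definitely been transmitted, then post a new send for the current output. The sketch below shows the shape of that pattern with bare `torch.distributed` primitives; it assumes an already-initialized process group and stands in for the repository's `send_forward(..., async_op=True)` / `deallocate_output_tensor` helpers, which it does not reproduce exactly.

```python
import torch
import torch.distributed as dist

def drain(handles):
    # Equivalent in spirit to wait_comm_handles(): wait every pending request.
    if handles is not None:
        for handle in handles.values():
            if handle is not None:
                handle.wait()
    return None

def async_send_with_reuse(output_tensor, prev_sent, prev_handles, dst):
    """Wait the previous send, free its buffer, then post the next send.

    prev_sent / prev_handles are the tensor and {name: request} dict from the
    previous iteration (None on the first call). Assumes dist.init_process_group()
    has already been called by the training framework.
    """
    prev_handles = drain(prev_handles)
    if prev_sent is not None:
        # Crude stand-in for deallocate_output_tensor: drop the activation storage.
        prev_sent.data = torch.empty(0, device=prev_sent.device, dtype=prev_sent.dtype)
    handles = {"send_next": dist.isend(output_tensor, dst=dst)}
    return output_tensor, handles
```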
@@ -934,161 +933,130 @@ def forward_backward_pipelining_with_cutinhalf(

Within the overlapped steps (the `if not only_bwd:` branch) the pre/post hooks are rebuilt on top
of the new handle containers. The two "pre" hooks shrink to single waits:

            def pp_pre_forward():
                nonlocal fwd_wait_recv_handles
                # wait input for current step
                wait_comm_handles(fwd_wait_recv_handles[fwd_model_chunk_id])

            def pp_pre_backward():
                nonlocal bwd_wait_send_recv_handles
                if not forward_only:
                    wait_comm_handles(bwd_wait_send_recv_handles)

pp_post_forward now declares fwd_wait_send_handles / fwd_wait_send_recv_handles as nonlocal
instead of the old scalar handles, stores its asynchronous send in
fwd_wait_send_handles[fwd_model_chunk_id] in the send-only case or receives the next input through
send_forward_recv_slave_forward(..., async_op=True) into fwd_wait_send_recv_handles, appends that
input to input_tensors[1 - fwd_model_chunk_id], and returns output_tensor rather than
input_tensor. pp_post_backward waits on fwd_wait_send_handles[fwd_model_chunk_id] and
fwd_wait_send_recv_handles, routes the gradient either directly (last stage, master chunk) or via
send_backward_recv_slave_backward(..., async_op=True) into bwd_wait_send_recv_handles, appends
the received gradient to output_tensor_grads[fwd_model_chunk_id], and returns input_tensor_grad.

The overlapped forward/backward body itself now calls the helpers with the new signatures:

                # forward
                pp_pre_forward()
                output_tensor = forward_step_helper(
                    fwd_model_chunk_id,
                    cur_fwd_chunk_microbatch[fwd_model_chunk_id],
                    is_first_microbatch=False
                )
                cur_fwd_chunk_microbatch[fwd_model_chunk_id] += 1
                output_tensor = pp_post_forward(output_tensor)

                # backward
                pp_pre_backward()
                if not forward_only:
                    try:
                        input_tensor_grad = backward_step_helper(bwd_model_chunk_id)
                    except Exception as e:
                        print(f"step_id: {step_id}, rank: {torch.distributed.get_rank()}, bwd_model_chunk_id: {bwd_model_chunk_id}", flush=True)
                        raise Exception(f"{e}")
                    cur_bwd_chunk_microbatch[bwd_model_chunk_id] += 1
                else:
                    input_tensor_grad = None
                _ = pp_post_backward(input_tensor_grad)

In the backward-only branch, the next slave input is prefetched into
fwd_wait_recv_handles[slave_chunk_id] and appended to input_tensors[slave_chunk_id], pending
backward sends are drained with wait_comm_handles(bwd_wait_send_handles[1 - bwd_model_chunk_id])
and wait_comm_handles(bwd_wait_send_recv_handles), and the backward step becomes:

                input_tensor_grad = backward_step_helper(
                    bwd_model_chunk_id,
                    bwd_cur_microbatch=cur_bwd_chunk_microbatch[bwd_model_chunk_id]
                )
                cur_bwd_chunk_microbatch[bwd_model_chunk_id] += 1
                if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
                    output_tensor_grad = input_tensor_grad
                else:
                    if step_id == num_overlap_steps - 1:
                        bwd_wait_send_handles[bwd_model_chunk_id] = send_backward(
                            input_tensor_grad,
                            tensor_shape,
                            config,
                            bwd_model_chunk_id,
                        )
                        output_tensor_grad = None
                    else:
                        # send_backward_recv_slave_backward
                        output_tensor_grad, bwd_wait_send_recv_handles = send_backward_recv_slave_backward(
                            input_tensor_grad,
                            tensor_shape,
                            config,
                            fwd_model_chunk_id
                        )
                output_tensor_grads[1 - bwd_model_chunk_id].append(output_tensor_grad)

                # swap fwd & bwd chunks
                fwd_model_chunk_id, bwd_model_chunk_id = bwd_model_chunk_id, fwd_model_chunk_id
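On the last pipeline stage the master chunk's output feeds the slave chunk locally instead of going over the wire, which both this hunk and the previous one express as `output_tensor.detach()` followed by `requires_grad = True`. A minimal standalone example of why that works: detaching cuts the autograd graph at the chunk boundary, while re-enabling gradients on the boundary tensor lets the slave chunk's backward produce exactly the `output_tensor_grad` that is later fed into the master chunk's backward. The tensor names below are illustrative, not taken from the schedule.

```python
import torch

x = torch.randn(4, requires_grad=True)
master_out = x * 2                       # output of the "master" chunk on this rank

# Cut the autograd graph at the chunk boundary, as the schedule does:
slave_in = master_out.detach()
slave_in.requires_grad = True            # the boundary tensor collects its own grad

slave_out = (slave_in ** 2).sum()        # the "slave" chunk
slave_out.backward()                     # backward for the slave chunk only

# slave_in.grad plays the role of output_tensor_grad for the master chunk:
master_out.backward(slave_in.grad)
print(x.grad)                            # gradient has flowed through both chunks
```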
@@ -1102,16 +1070,11 @@ def forward_backward_pipelining_with_cutinhalf(
     # Run cooldown phases
     if not forward_only:
         for i in range(schedule['cooldown'][rank][0]):
-            output_tensor_grad_bwd, _ = recv_backward(tensor_shape, config, master_chunk_id)
-            input_tensor_bwd = input_tensors[master_chunk_id].pop(0)[1]
-            output_tensor_bwd = output_tensors[master_chunk_id].pop(0)
+            output_tensor_grad, _ = recv_backward(tensor_shape, config, master_chunk_id)
+            output_tensor_grads[master_chunk_id].append(output_tensor_grad)
             input_tensor_grad = backward_step_helper(
-                input_tensor_bwd,
-                output_tensor_bwd,
-                output_tensor_grad_bwd,
-                bwd_model_chunk_id=master_chunk_id,
+                master_chunk_id,
                 bwd_cur_microbatch=cur_bwd_chunk_microbatch[master_chunk_id]
             )
             cur_bwd_chunk_microbatch[master_chunk_id] += 1
dcu_megatron/training/training.py

 import gc
 import sys
+import time
 from functools import wraps

 import torch.distributed
@@ -14,7 +15,10 @@ from megatron.core.distributed import DistributedDataParallel as DDP
 from megatron.core.distributed.custom_fsdp import FullyShardedDataParallel as custom_FSDP
 from megatron.core.distributed import finalize_model_grads
-from megatron.core.rerun_state_machine import get_rerun_state_machine
+from megatron.core.rerun_state_machine import (
+    get_rerun_state_machine,
+    RerunMode,
+)
 from megatron.training.initialize import write_args_to_tensorboard
 from megatron.core.num_microbatches_calculator import (
     get_current_global_batch_size,
@@ -30,6 +34,8 @@ from megatron.training.utils import (
     logical_and_across_model_parallel_group,
     reduce_max_stat_across_model_parallel_group,
     unwrap_model,
+    is_rank0,
+    is_last_rank,
 )
 from megatron.training.global_vars import (
     get_args,
@@ -55,6 +61,7 @@ from megatron.training.training import (
     cuda_graph_capture,
     cuda_graph_set_manual_hooks,
     dummy_train_step,
+    _TRAIN_START_TIME,
 )
 from megatron.core.pipeline_parallel import get_forward_backward_func
@@ -560,3 +567,133 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
             sys.exit(exit_code)

     return iteration, num_floating_point_operations_so_far

The new evaluate() below is appended to the end of the file (all added lines):

def evaluate(forward_step_func,
             data_iterator,
             model,
             process_non_loss_data_func,
             config,
             verbose=False,
             non_loss_data_func=None):
    """Evaluation."""
    args = get_args()
    timers = get_timers()

    timers('evaluate', log_level=0).start(barrier=True)

    if args.vision_pretraining and args.vision_pretraining_type == "dino":
        from megatron.legacy.model.vision.knn_monitor import compute_feature_bank
        compute_feature_bank(model)

    # Turn on evaluation mode which disables dropout.
    for model_module in model:
        model_module.eval()

    # Disable result validation during evaluation
    rerun_state_machine = get_rerun_state_machine()
    rerun_mode = rerun_state_machine.get_mode()
    rerun_state_machine.set_mode(RerunMode.DISABLED)

    total_loss_dict = {}

    # make validation batch size independent from training batch size
    eval_batch_size = args.global_batch_size
    eval_num_microbatches = eval_batch_size // \
        (args.micro_batch_size * args.data_parallel_size)

    with torch.no_grad():
        iteration = 0
        if verbose:
            print_rank_0(f'Evaluating on {args.eval_iters * eval_batch_size} samples')
        while iteration < args.eval_iters:
            iteration += 1
            if verbose:
                print_rank_0(f'Evaluating iter {iteration}/{args.eval_iters}')

            forward_backward_func = get_forward_backward_func()
            # Don't care about timing during evaluation
            config.timers = None
            ft_integration.on_eval_step_start()
            loss_dicts = forward_backward_func(
                forward_step_func=forward_step_func,
                data_iterator=data_iterator,
                model=model,
                num_microbatches=eval_num_microbatches,
                seq_length=args.seq_length,
                micro_batch_size=args.micro_batch_size,
                decoder_seq_length=args.decoder_seq_length,
                forward_only=True)
            ft_integration.on_eval_step_end()
            config.timers = get_timers()

            # Empty unused memory
            if args.empty_unused_memory_level >= 1:
                torch.cuda.empty_cache()

            if args.schedule_method == 'dualpipev':
                is_last_stage = mpu.is_pipeline_first_stage(ignore_virtual=True)
            else:
                is_last_stage = mpu.is_pipeline_last_stage(ignore_virtual=True)

            if is_last_stage:
                # Reduce across processes.
                for loss_dict in loss_dicts:
                    for key in loss_dict:
                        if key not in total_loss_dict:
                            total_loss_dict[key] = torch.tensor([0.0, 0.0], dtype=torch.float).cuda()
                        val = loss_dict[key]
                        if isinstance(val, tuple) or isinstance(val, list):
                            total_loss_dict[key][0] += val[0]
                            total_loss_dict[key][1] += val[1]
                        else:
                            total_loss_dict[key][0] += val
                            total_loss_dict[key][1] += 1

            args.consumed_valid_samples += eval_batch_size

            if args.exit_duration_in_mins:
                train_time = (time.time() - _TRAIN_START_TIME) / 60.0
                done_cuda = torch.tensor(
                    [train_time > args.exit_duration_in_mins],
                    dtype=torch.int, device='cuda')
                torch.distributed.all_reduce(done_cuda, op=torch.distributed.ReduceOp.MAX)
                done = done_cuda.item()
                if done:
                    rerun_state_machine.set_mode(rerun_mode)
                    print_rank_0('Exiting during evaluation, timelimit reached')
                    return None, None, True

        is_last_rank_func = is_rank0 if args.schedule_method == 'dualpipev' else is_last_rank

        collected_non_loss_data = None
        if non_loss_data_func is not None:
            collected_non_loss_data = non_loss_data_func(model)
        elif process_non_loss_data_func is not None and is_last_rank_func():
            collected_non_loss_data = forward_backward_func(
                forward_step_func=forward_step_func,
                data_iterator=data_iterator,
                model=model,
                num_microbatches=get_num_microbatches(),
                seq_length=args.seq_length,
                micro_batch_size=args.micro_batch_size,
                decoder_seq_length=args.decoder_seq_length,
                forward_only=True,
                collect_non_loss_data=True)

    # Move model back to the train mode.
    for model_module in model:
        model_module.train()

    for key in total_loss_dict:
        numerator, denominator = total_loss_dict[key]
        total_loss_dict[key] = numerator / denominator

    timers('evaluate').stop()
    timers.log(['evaluate'])

    rerun_state_machine.set_mode(rerun_mode)

    rerun_state_machine.set_mode(rerun_mode)

    return total_loss_dict, collected_non_loss_data, False
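Two details of the new evaluate() are worth spelling out. First, under the dualpipev schedule the final model chunk lives on pipeline rank 0, which is why the loss is collected where `mpu.is_pipeline_first_stage(...)` holds and why `is_rank0` replaces `is_last_rank` for the non-loss-data path. Second, the number of microbatches per eval step is derived from the global batch size rather than the training schedule. A small worked example of that arithmetic, using illustrative numbers that are not taken from the commit:

```python
# Assumed example configuration, not values from this commit:
global_batch_size = 256      # args.global_batch_size
micro_batch_size = 2         # args.micro_batch_size
data_parallel_size = 8       # args.data_parallel_size

eval_batch_size = global_batch_size
eval_num_microbatches = eval_batch_size // (micro_batch_size * data_parallel_size)
print(eval_num_microbatches)   # 256 // (2 * 8) = 16 microbatches per pipeline per eval step
```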