Project: evt_fugx1 / dcu_megatron

Commit 62f16817, authored Jun 10, 2025 by dongcl
Parent: 2385a133

    evaluate support dualpipev

Showing 3 changed files, with 323 additions and 218 deletions:

    dcu_megatron/adaptor/features_manager/pipeline_parallel/pipeline_feature.py  (+5 −0)
    dcu_megatron/core/pipeline_parallel/dualpipev/dualpipev_schedules.py  (+180 −217)
    dcu_megatron/training/training.py  (+138 −1)
dcu_megatron/adaptor/features_manager/pipeline_parallel/pipeline_feature.py  (+5 −0)

```diff
@@ -48,6 +48,7 @@ class PipelineFeature(AbstractFeature):
             train_step,
             _allreduce_embedding_grads_wrapper
         )
+        from dcu_megatron.training.training import evaluate

         patch_manager.register_patch('megatron.training.training.get_model', get_model)
@@ -64,6 +65,10 @@ class PipelineFeature(AbstractFeature):
         patch_manager.register_patch(
             'megatron.core.distributed.finalize_model_grads._allreduce_embedding_grads',
             _allreduce_embedding_grads_wrapper)
+        # use first rank
+        patch_manager.register_patch('megatron.training.training.evaluate', evaluate)
+
         if (args.schedule_method == "interleaved_1f1b"
                 and args.combined_1f1b
         ...
```
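The `register_patch` calls above swap in the DualPipeV-aware `evaluate` by dotted module path. The implementation of `patch_manager` is not part of this diff; the following is only a minimal sketch of the monkey-patching pattern such a helper typically implements (the name `apply_patch` is ours, not the project's):

```python
import importlib

def apply_patch(dotted_path: str, replacement):
    # Illustrative sketch only: resolve 'pkg.module.attr' and rebind the
    # attribute on the imported module, as a register_patch-style helper
    # might do under the hood.
    module_path, attr = dotted_path.rsplit('.', 1)
    module = importlib.import_module(module_path)
    setattr(module, attr, replacement)

# Example: apply_patch('megatron.training.training.evaluate', evaluate)
# makes callers that look the function up via the module
# (training.evaluate) see the patched version; code that did
# `from ... import evaluate` before patching keeps its old reference.
```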
dcu_megatron/core/pipeline_parallel/dualpipev/dualpipev_schedules.py  (+180 −217)

(This diff is collapsed in the review view and not reproduced here.)
dcu_megatron/training/training.py  (+138 −1)

```diff
 import gc
 import sys
 import time
 from functools import wraps

 import torch.distributed
 ...
@@ -14,7 +15,10 @@ from megatron.core.distributed import DistributedDataParallel as DDP
 from megatron.core.distributed.custom_fsdp import FullyShardedDataParallel as custom_FSDP
 from megatron.core.distributed import finalize_model_grads
-from megatron.core.rerun_state_machine import get_rerun_state_machine
+from megatron.core.rerun_state_machine import (
+    get_rerun_state_machine,
+    RerunMode,
+)
 from megatron.training.initialize import write_args_to_tensorboard
 from megatron.core.num_microbatches_calculator import (
     get_current_global_batch_size,
 ...
@@ -30,6 +34,8 @@ from megatron.training.utils import (
     logical_and_across_model_parallel_group,
     reduce_max_stat_across_model_parallel_group,
     unwrap_model,
+    is_rank0,
+    is_last_rank,
 )
 from megatron.training.global_vars import (
     get_args,
 ...
@@ -55,6 +61,7 @@ from megatron.training.training import (
     cuda_graph_capture,
     cuda_graph_set_manual_hooks,
     dummy_train_step,
+    _TRAIN_START_TIME,
 )
 from megatron.core.pipeline_parallel import get_forward_backward_func
 ...
```
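The widened `rerun_state_machine` import exists because the new `evaluate` below disables result validation for the duration of evaluation and restores the previous mode afterwards, including on the early time-limit return. A context manager is one way to express that disable/restore pattern; this sketch is our own illustration built on the imports above, not code from the commit:

```python
from contextlib import contextmanager

from megatron.core.rerun_state_machine import get_rerun_state_machine, RerunMode

@contextmanager
def rerun_validation_disabled():
    # Disable rerun result validation, restoring the previous mode on exit
    # even if the body returns early or raises.
    state_machine = get_rerun_state_machine()
    previous_mode = state_machine.get_mode()
    state_machine.set_mode(RerunMode.DISABLED)
    try:
        yield
    finally:
        state_machine.set_mode(previous_mode)
```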
@@ -560,3 +567,133 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,

The new `evaluate` is appended after the closing lines of `train`:

```python
            sys.exit(exit_code)

    return iteration, num_floating_point_operations_so_far


def evaluate(forward_step_func,
             data_iterator,
             model,
             process_non_loss_data_func,
             config,
             verbose=False,
             non_loss_data_func=None):
    """Evaluation."""
    args = get_args()
    timers = get_timers()

    timers('evaluate', log_level=0).start(barrier=True)

    if args.vision_pretraining and args.vision_pretraining_type == "dino":
        from megatron.legacy.model.vision.knn_monitor import compute_feature_bank
        compute_feature_bank(model)

    # Turn on evaluation mode which disables dropout.
    for model_module in model:
        model_module.eval()

    # Disable result validation during evaluation
    rerun_state_machine = get_rerun_state_machine()
    rerun_mode = rerun_state_machine.get_mode()
    rerun_state_machine.set_mode(RerunMode.DISABLED)

    total_loss_dict = {}

    # make validation batch size independent from training batch size
    eval_batch_size = args.global_batch_size
    eval_num_microbatches = eval_batch_size // \
        (args.micro_batch_size * args.data_parallel_size)

    with torch.no_grad():
        iteration = 0
        if verbose:
            print_rank_0(f'Evaluating on {args.eval_iters * eval_batch_size} samples')
        while iteration < args.eval_iters:
            iteration += 1
            if verbose:
                print_rank_0(f'Evaluating iter {iteration}/{args.eval_iters}')

            forward_backward_func = get_forward_backward_func()
            # Don't care about timing during evaluation
            config.timers = None
            ft_integration.on_eval_step_start()
            loss_dicts = forward_backward_func(
                forward_step_func=forward_step_func,
                data_iterator=data_iterator,
                model=model,
                num_microbatches=eval_num_microbatches,
                seq_length=args.seq_length,
                micro_batch_size=args.micro_batch_size,
                decoder_seq_length=args.decoder_seq_length,
                forward_only=True)
            ft_integration.on_eval_step_end()
            config.timers = get_timers()

            # Empty unused memory
            if args.empty_unused_memory_level >= 1:
                torch.cuda.empty_cache()

            if args.schedule_method == 'dualpipev':
                # DualPipeV folds the pipeline: the loss-computing chunk
                # lives on the first pipeline stage.
                is_last_stage = mpu.is_pipeline_first_stage(ignore_virtual=True)
            else:
                is_last_stage = mpu.is_pipeline_last_stage(ignore_virtual=True)

            if is_last_stage:
                # Reduce across processes.
                for loss_dict in loss_dicts:
                    for key in loss_dict:
                        if key not in total_loss_dict:
                            total_loss_dict[key] = torch.tensor(
                                [0.0, 0.0], dtype=torch.float).cuda()
                        val = loss_dict[key]
                        if isinstance(val, tuple) or isinstance(val, list):
                            total_loss_dict[key][0] += val[0]
                            total_loss_dict[key][1] += val[1]
                        else:
                            total_loss_dict[key][0] += val
                            total_loss_dict[key][1] += 1

            args.consumed_valid_samples += eval_batch_size

            if args.exit_duration_in_mins:
                train_time = (time.time() - _TRAIN_START_TIME) / 60.0
                done_cuda = torch.tensor(
                    [train_time > args.exit_duration_in_mins],
                    dtype=torch.int, device='cuda')
                torch.distributed.all_reduce(
                    done_cuda, op=torch.distributed.ReduceOp.MAX)
                done = done_cuda.item()
                if done:
                    rerun_state_machine.set_mode(rerun_mode)
                    print_rank_0('Exiting during evaluation, timelimit reached')
                    return None, None, True

        is_last_rank_func = is_rank0 if args.schedule_method == 'dualpipev' else is_last_rank
        collected_non_loss_data = None
        if non_loss_data_func is not None:
            collected_non_loss_data = non_loss_data_func(model)
        elif process_non_loss_data_func is not None and is_last_rank_func():
            collected_non_loss_data = forward_backward_func(
                forward_step_func=forward_step_func,
                data_iterator=data_iterator,
                model=model,
                num_microbatches=get_num_microbatches(),
                seq_length=args.seq_length,
                micro_batch_size=args.micro_batch_size,
                decoder_seq_length=args.decoder_seq_length,
                forward_only=True,
                collect_non_loss_data=True)

    # Move model back to the train mode.
    for model_module in model:
        model_module.train()

    for key in total_loss_dict:
        numerator, denominator = total_loss_dict[key]
        total_loss_dict[key] = numerator / denominator

    timers('evaluate').stop()
    timers.log(['evaluate'])

    rerun_state_machine.set_mode(rerun_mode)

    return total_loss_dict, collected_non_loss_data, False
```
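The two `schedule_method == 'dualpipev'` branches are the heart of the change: DualPipeV folds the pipeline into a V shape, with each rank holding two model chunks, so the final chunk that computes the loss sits on the first pipeline stage (global rank 0) rather than the last rank. A sketch that factors the inline selection out into one helper (the name `eval_loss_stage_predicates` is ours, not part of the commit):

```python
def eval_loss_stage_predicates(args, mpu, is_rank0, is_last_rank):
    # Return (is_last_stage, is_last_rank_func) for the active schedule,
    # mirroring the two inline branches in the patched evaluate().
    if args.schedule_method == 'dualpipev':
        # Under DualPipeV the loss-computing chunk lives on the first
        # pipeline stage, so rank 0 plays the "last rank" role.
        return mpu.is_pipeline_first_stage(ignore_virtual=True), is_rank0
    return mpu.is_pipeline_last_stage(ignore_virtual=True), is_last_rank
```

Note also that validation batch sizing is decoupled from training: with, say, `global_batch_size=64`, `micro_batch_size=2`, and `data_parallel_size=4`, each eval step runs `64 // (2 * 4) = 8` microbatches per data-parallel replica.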