evt_fugx1 / dcu_megatron / Commits

Commit fa142de0, authored May 19, 2025 by dongcl
Parent: cf5d3189

    patch for megatron core0.12.0

Showing 7 changed files with 155 additions and 185 deletions (+155 -185)
Megatron-LM                                           +1   -1
dcu_megatron/adaptor/megatron_adaptor.py              +14  -0
dcu_megatron/core/pipeline_parallel/schedules.py      +1   -13
dcu_megatron/core/transformer/transformer_layer.py    +31  -31
dcu_megatron/training/initialize.py                   +33  -1
dcu_megatron/training/tokenizer/tokenizer.py          +6   -0
dcu_megatron/training/training.py                     +69  -139
Megatron-LM @ d580efc6 (408eb718 ... d580efc6)

-Subproject commit 408eb7186a68ba4d30ee6cc8b05b4de6ba702148
+Subproject commit d580efc68a9f0dbf1945f834f6f6200cd01d3343
dcu_megatron/adaptor/megatron_adaptor.py

@@ -169,6 +169,15 @@ class CoreAdaptation(MegatronAdaptationABC):
                                     staticmethod, apply_wrapper=True)

+        # reduce_scatter_to_sequence_parallel_region
+        MegatronAdaptation.register(
+            'megatron.core.tensor_parallel.mappings.reduce_scatter_to_sequence_parallel_region',
+            torch._dynamo.disable, apply_wrapper=True)
+
+        # reduce_from_tensor_model_parallel_region
+        MegatronAdaptation.register(
+            'megatron.core.tensor_parallel.mappings.reduce_from_tensor_model_parallel_region',
+            torch._dynamo.disable, apply_wrapper=True)
         # flux
         if int(os.getenv("USE_FLUX_OVERLAP", "0")):
             from ..core.tensor_parallel.layers import (

@@ -189,6 +198,7 @@ class CoreAdaptation(MegatronAdaptationABC):
         from ..training.initialize import _initialize_distributed
         from ..training.initialize import _compile_dependencies
         from ..training.training import train
+        from ..training.initialize import _set_random_seed

         MegatronAdaptation.register('megatron.training.tokenizer.tokenizer.build_tokenizer',
                                     build_tokenizer)
@@ -199,6 +209,10 @@ class CoreAdaptation(MegatronAdaptationABC):
         MegatronAdaptation.register('megatron.training.initialize._compile_dependencies',
                                     _compile_dependencies)
+        # add fixed seed
+        MegatronAdaptation.register('megatron.training.initialize._set_random_seed', _set_random_seed)
+        # add trace_handler
+        MegatronAdaptation.register('megatron.training.training.train', train)
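For readers unfamiliar with the adaptor mechanism, here is a minimal sketch of what registering torch._dynamo.disable with apply_wrapper=True amounts to, assuming MegatronAdaptation.register ultimately rebinds the target attribute to the wrapped callable. The patch_attr helper below is a hypothetical stand-in, not the project's implementation.

import importlib

import torch


def patch_attr(dotted_path: str, wrapper):
    """Rebind `package.module.attr` to `wrapper(original_attr)`."""
    module_path, attr = dotted_path.rsplit('.', 1)
    module = importlib.import_module(module_path)
    setattr(module, attr, wrapper(getattr(module, attr)))


# Usage sketch: keep TorchDynamo from tracing into the sequence-parallel
# reduce-scatter, mirroring the intent of the registrations above.
# patch_attr(
#     'megatron.core.tensor_parallel.mappings.reduce_scatter_to_sequence_parallel_region',
#     torch._dynamo.disable,
# )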
dcu_megatron/core/pipeline_parallel/schedules.py

@@ -7,6 +7,7 @@ from megatron.training import get_args
 from megatron.core import parallel_state
 from megatron.core.enums import ModelType
 from megatron.core.pipeline_parallel import p2p_communication
+from megatron.core.pipeline_parallel.schedules import set_current_microbatch
 from megatron.core.transformer.cuda_graphs import create_cudagraphs
 from megatron.core.utils import (
     get_attr_wrapped_model,

@@ -28,19 +29,6 @@ from megatron.core.pipeline_parallel.schedules import (
 from .combined_1f1b import VppContextManager, forward_backward_step, set_streams, wrap_forward_func


-def set_current_microbatch(model, microbatch_id):
-    """Set the current microbatch."""
-    decoder_exists = True
-    decoder = None
-    try:
-        decoder = get_attr_wrapped_model(model, "decoder")
-    except RuntimeError:
-        decoder_exists = False
-    if decoder_exists and decoder is not None:
-        for layer in decoder.layers:
-            layer.current_microbatch = microbatch_id
-
-
 def get_pp_rank_microbatches(
     num_microbatches, num_model_chunks, microbatch_group_size_per_vp_stage, forward_only=False
 ):
dcu_megatron/core/transformer/transformer_layer.py

@@ -16,35 +16,6 @@ from dcu_megatron.core.transformer.utils import SubmoduleCallables, TransformerL
 class TransformerLayer(MegatronCoreTransformerLayer):

-    def _callable_wrapper(self, is_forward, func, stream, event, *args, skip_detach=False, **kwargs):
-        """
-        Wraps a function call so that it waits for a given CUDA event before
-        proceeding and then runs the function on a specified CUDA stream.
-        """
-        torch.cuda.nvtx.range_push(func.__name__)
-        event.wait(stream)
-        with torch.cuda.stream(stream):
-            outputs = func(*args, **kwargs)
-        event.record(stream)
-
-        if skip_detach:
-            torch.cuda.nvtx.range_pop()
-            return outputs
-
-        detached_output_tensors = []
-        if not is_forward:
-            torch.cuda.nvtx.range_pop()
-            return outputs, detached_output_tensors
-
-        for tensor in outputs:
-            if tensor is None:
-                detached_output_tensors.append(None)
-            elif tensor.dtype.is_floating_point:
-                detached_output_tensors.append(tensor.detach().requires_grad_(True))
-            else:
-                detached_output_tensors.append(tensor.detach())
-        torch.cuda.nvtx.range_pop()
-
-        return outputs, detached_output_tensors
-
     def forward(
         self,
         hidden_states: Tensor,

@@ -123,7 +94,13 @@ class TransformerLayer(MegatronCoreTransformerLayer):
         residual = hidden_states

         # Optional Input Layer norm
-        input_layernorm_output = self.input_layernorm(hidden_states)
+        if self.recompute_input_layernorm:
+            self.input_layernorm_checkpoint = tensor_parallel.CheckpointWithoutOutput()
+            input_layernorm_output = self.input_layernorm_checkpoint.checkpoint(
+                self.input_layernorm, hidden_states
+            )
+        else:
+            input_layernorm_output = self.input_layernorm(hidden_states)

         # Self attention.
         attention_output_with_bias = self.self_attention(

@@ -138,6 +115,13 @@ class TransformerLayer(MegatronCoreTransformerLayer):
             sequence_len_offset=sequence_len_offset,
         )

+        if self.recompute_input_layernorm:
+            # discard the output of the input layernorm and register the recompute
+            # as a gradient hook of attention_output_with_bias[0]
+            self.input_layernorm_checkpoint.discard_output_and_register_recompute(
+                attention_output_with_bias[0]
+            )
+
         # TODO: could we move `bias_dropout_add_exec_handler` itself
         # inside the module provided in the `bias_dropout_add_spec` module?
         with self.bias_dropout_add_exec_handler():

@@ -178,7 +162,13 @@ class TransformerLayer(MegatronCoreTransformerLayer):
         )

         # Optional Layer norm post the cross-attention.
-        pre_mlp_layernorm_output = self.pre_mlp_layernorm(hidden_states)
+        if self.recompute_pre_mlp_layernorm:
+            self.pre_mlp_norm_checkpoint = tensor_parallel.CheckpointWithoutOutput()
+            pre_mlp_layernorm_output = self.pre_mlp_norm_checkpoint.checkpoint(
+                self.pre_mlp_layernorm, hidden_states
+            )
+        else:
+            pre_mlp_layernorm_output = self.pre_mlp_layernorm(hidden_states)

         probs, routing_map = self.mlp.router(pre_mlp_layernorm_output)
         tokens_per_expert = self.mlp.token_dispatcher.meta_prepare(

@@ -249,6 +239,16 @@ class TransformerLayer(MegatronCoreTransformerLayer):
         if shared_expert_output is not None:
             output += shared_expert_output
         mlp_output_with_bias = (output, mlp_bias)

+        if self.recompute_pre_mlp_layernorm:
+            # discard the output of the pre-mlp layernorm and register the recompute
+            # as a gradient hook of mlp_output_with_bias[0]
+            self.pre_mlp_norm_checkpoint.discard_output_and_register_recompute(
+                mlp_output_with_bias[0]
+            )
+
         # TODO: could we move `bias_dropout_add_exec_handler` itself
         # inside the module provided in the `bias_dropout_add_spec` module?
         with self.bias_dropout_add_exec_handler():
             hidden_states = self.mlp_bda(self.training, self.config.bias_dropout_fusion)(
                 mlp_output_with_bias, residual, self.hidden_dropout
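The recompute_input_layernorm / recompute_pre_mlp_layernorm branches above trade memory for compute: the layernorm output is not kept for backward, and discard_output_and_register_recompute re-creates it from a gradient hook on the consumer's output. Below is a simplified sketch of the same idea using stock PyTorch activation checkpointing; the module names are illustrative, not the dcu_megatron classes, and plain torch.utils.checkpoint does not free the output early the way CheckpointWithoutOutput does.

import torch
from torch.utils.checkpoint import checkpoint


class BlockWithRecomputedNorm(torch.nn.Module):
    """Toy block: layernorm -> linear, with the layernorm optionally recomputed."""

    def __init__(self, hidden: int, recompute_input_layernorm: bool = True):
        super().__init__()
        self.input_layernorm = torch.nn.LayerNorm(hidden)
        self.attn = torch.nn.Linear(hidden, hidden)  # stand-in for self-attention
        self.recompute_input_layernorm = recompute_input_layernorm

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        residual = hidden_states
        if self.recompute_input_layernorm:
            # The normalized activation is not stored for backward; it is
            # recomputed when gradients flow back through this call.
            normed = checkpoint(self.input_layernorm, hidden_states, use_reentrant=False)
        else:
            normed = self.input_layernorm(hidden_states)
        return residual + self.attn(normed)


# x = torch.randn(4, 16, requires_grad=True)
# BlockWithRecomputedNorm(16)(x).sum().backward()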
dcu_megatron/training/initialize.py

@@ -105,7 +105,7 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks):
         # Call the init process
         init_process_group_kwargs = {
-            'backend': args.distributed_backend,
+            'backend': args.distributed_backend,
             'world_size': args.world_size,
             'rank': args.rank,
             'init_method': args.dist_url,
@@ -149,3 +149,35 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks):
             f"> initialized pipeline model parallel with size "
             f"{mpu.get_pipeline_model_parallel_world_size()}"
         )
+
+
+def _set_random_seed(
+    seed_: int,
+    data_parallel_random_init: bool = False,
+    te_rng_tracker: bool = False,
+    inference_rng_tracker: bool = False,
+    use_cudagraphable_rng: bool = False,
+):
+    """Set random seed for reproducibility."""
+    args = get_args()
+    if seed_ is not None and seed_ > 0:
+        # Ensure that different pipeline MP stages get different seeds.
+        seed = seed_ + (100 * mpu.get_pipeline_model_parallel_rank())
+        # Ensure different data parallel ranks get different seeds
+        if data_parallel_random_init:
+            seed = seed + (10 * mpu.get_data_parallel_rank())
+        random.seed(seed)
+        np.random.seed(seed)
+        torch.manual_seed(seed)
+        if torch.cuda.device_count() > 0:
+            tensor_parallel.model_parallel_cuda_manual_seed(
+                seed, te_rng_tracker, inference_rng_tracker, use_cudagraphable_rng
+            )
+        if args.reproduce:
+            assert (args.attention_dropout > 0) is False, \
+                f"To utilize the reproduction function, args.attention_dropout = {args.attention_dropout} must be set to 0."
+            assert (args.hidden_dropout > 0) is False, \
+                f"To utilize the reproduction function, args.hidden_dropout = {args.hidden_dropout} must be set to 0."
+            torch.backends.cudnn.deterministic = True  # set the cuDNN backend to deterministic algorithms
+            torch.backends.cudnn.benchmark = False  # pin the convolution algorithm selection
+            torch.use_deterministic_algorithms(True)  # use deterministic torch operators to avoid nondeterminism
+    else:
+        raise ValueError("Seed ({}) should be a positive integer.".format(seed_))
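As a side note, the args.reproduce branch above covers the cuDNN and torch-level knobs; a fully deterministic CUDA run typically also needs CUBLAS_WORKSPACE_CONFIG set before the first cuBLAS call. A hedged standalone sketch of an equivalent setup (not part of the patch itself):

import os
import random

import numpy as np
import torch


def make_deterministic(seed: int) -> None:
    os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")  # required by deterministic cuBLAS
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True   # deterministic cuDNN kernels
    torch.backends.cudnn.benchmark = False      # fixed convolution algorithm selection
    torch.use_deterministic_algorithms(True)    # error out on nondeterministic ops


# make_deterministic(1234)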
dcu_megatron/training/tokenizer/tokenizer.py

@@ -9,8 +9,10 @@ from megatron.training.tokenizer.tokenizer import (
     _Llama2Tokenizer,
     CustomTikTokenizer,
     _NullTokenizer,
+    _NullMultimodalTokenizer,
     _vocab_size_with_padding
 )
 from megatron.training.tokenizer.multimodal_tokenizer import MultimodalTokenizer


 def build_tokenizer(args, **kwargs):

@@ -92,7 +94,11 @@ def build_tokenizer(args, **kwargs):
             args.tokenizer_prompt_format,
             args.special_tokens,
             args.image_tag_type,
             args.force_system_message,
         )
+    elif args.tokenizer_type == 'NullMultimodalTokenizer':
+        assert args.vocab_size is not None
+        tokenizer = _NullMultimodalTokenizer(args.vocab_size)
     elif args.tokenizer_type == "DeepSeekV2Tokenizer":
         tokenizer = _DeepSeekV2Tokenizer(args.tokenizer_model, args.extra_vocab_size)
         args.padded_vocab_size = tokenizer.vocab_size
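An illustrative call into the patched build_tokenizer through the DeepSeekV2Tokenizer branch; the Namespace fields are assumptions based only on the attributes referenced in this diff, not a complete Megatron argument set.

from argparse import Namespace

# from dcu_megatron.training.tokenizer.tokenizer import build_tokenizer

args = Namespace(
    tokenizer_type="DeepSeekV2Tokenizer",
    tokenizer_model="/path/to/tokenizer.model",  # hypothetical path
    extra_vocab_size=0,
    vocab_size=None,
)
# tokenizer = build_tokenizer(args)
# For this branch, args.padded_vocab_size is set from tokenizer.vocab_size.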
dcu_megatron/training/training.py
@@ -53,18 +53,9 @@ from megatron.training.training import (
 stimer = StragglerDetector()


-def train(
-    forward_step_func,
-    model,
-    optimizer,
-    opt_param_scheduler,
-    train_data_iterator,
-    valid_data_iterator,
-    process_non_loss_data_func,
-    config,
-    checkpointing_context,
-    non_loss_data_func,
-):
+def train(forward_step_func, model, optimizer, opt_param_scheduler,
+          train_data_iterator, valid_data_iterator,
+          process_non_loss_data_func, config, checkpointing_context, non_loss_data_func):
     """Training function: run train_step desired number of times, run validation, checkpoint."""
     args = get_args()
     timers = get_timers()
@@ -74,10 +65,7 @@ def train(
     try:
         from workload_inspector.utils.webserver import run_server
         import threading
-        threading.Thread(
-            target=run_server, daemon=True, args=(torch.distributed.get_rank(),)
-        ).start()
+        threading.Thread(target=run_server, daemon=True,
+                         args=(torch.distributed.get_rank(),)).start()
     except ModuleNotFoundError:
         print_rank_0("workload inspector module not found.")
@@ -100,17 +88,11 @@ def train(
     rerun_state_machine.current_iteration = iteration

     # Track E2E metrics at the start of training.
-    one_logger_utils.on_train_start(
-        iteration=iteration,
-        consumed_train_samples=args.consumed_train_samples,
-        train_samples=args.train_samples,
-        seq_length=args.seq_length,
-        train_iters=args.train_iters,
-        save=args.save,
-        async_save=args.async_save,
-        log_throughput=args.log_throughput,
-        num_floating_point_operations_so_far=args.num_floating_point_operations_so_far,
-    )
+    one_logger_utils.on_train_start(iteration=iteration,
+                                    consumed_train_samples=args.consumed_train_samples,
+                                    train_samples=args.train_samples, seq_length=args.seq_length,
+                                    train_iters=args.train_iters, save=args.save,
+                                    async_save=args.async_save, log_throughput=args.log_throughput,
+                                    num_floating_point_operations_so_far=args.num_floating_point_operations_so_far)

     num_floating_point_operations_so_far = args.num_floating_point_operations_so_far
@@ -118,10 +100,9 @@ def train(
     config.grad_scale_func = optimizer.scale_loss
     config.timers = timers
     if isinstance(model[0], (custom_FSDP, DDP)) and args.overlap_grad_reduce:
-        assert config.no_sync_func is None, (
-            'When overlap_grad_reduce is True, config.no_sync_func must be None; '
-            'a custom no_sync_func is not supported when overlapping grad-reduce'
-        )
+        assert config.no_sync_func is None, \
+            ('When overlap_grad_reduce is True, config.no_sync_func must be None; '
+             'a custom no_sync_func is not supported when overlapping grad-reduce')
         config.no_sync_func = [model_chunk.no_sync for model_chunk in model]
         if len(model) == 1:
             config.no_sync_func = config.no_sync_func[0]
@@ -145,9 +126,8 @@ def train(
     if args.manual_gc:
         # Disable the default garbage collector and perform the collection manually.
         # This is to align the timing of garbage collection across ranks.
-        assert (
-            args.manual_gc_interval >= 0
-        ), 'Manual garbage collection interval should be larger than or equal to 0'
+        assert args.manual_gc_interval >= 0, \
+            'Manual garbage collection interval should be larger than or equal to 0'
         gc.disable()
         gc.collect()
@@ -157,13 +137,10 @@ def train(
         world = torch.distributed.get_world_size()
         rank = torch.distributed.get_rank()
         mmcnt = args.straggler_minmax_count
-        stimer.configure(
-            world,
-            rank,
-            mmcnt=mmcnt,
-            enabled=not args.disable_straggler_on_startup,
-            port=args.straggler_ctrlr_port,
-        )
+        stimer.configure(world, rank,
+                         mmcnt=mmcnt,
+                         enabled=not args.disable_straggler_on_startup,
+                         port=args.straggler_ctrlr_port)

     num_floating_point_operations_since_last_log_event = 0.0
     num_microbatches = get_num_microbatches()
@@ -171,10 +148,10 @@ def train(
     eval_iterations = 0

     def get_e2e_base_metrics():
-        """Get base metrics values for one-logger to calculate E2E tracking metrics."""
-        num_floating_point_operations_since_current_train_start = (
+        """Get base metrics values for one-logger to calculate E2E tracking metrics.
+        """
+        num_floating_point_operations_since_current_train_start = \
             num_floating_point_operations_so_far - args.num_floating_point_operations_so_far
-        )
         return {
             'iteration': iteration,
             'train_duration': timers('interval-time').active_time(),
@@ -184,7 +161,7 @@ def train(
             'num_floating_point_operations_so_far': num_floating_point_operations_so_far,
             'consumed_train_samples': args.consumed_train_samples,
             'world_size': args.world_size,
-            'seq_length': args.seq_length,
-        }
+            'seq_length': args.seq_length}

     # Cache into one-logger for callback.
     if one_logger:
@@ -192,11 +169,7 @@ def train(
         one_logger.store_set('get_e2e_base_metrics', get_e2e_base_metrics)

     prof = None
-    if (
-        args.profile
-        and torch.distributed.get_rank() in args.profile_ranks
-        and args.use_pytorch_profiler
-    ):
+    if args.profile and torch.distributed.get_rank() in args.profile_ranks and args.use_pytorch_profiler:
         def trace_handler(p):
             from pathlib import Path
             Path(f"{args.profile_dir}").mkdir(parents=True, exist_ok=True)
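For context, a trace_handler like the one above is normally passed to torch.profiler.profile via on_trace_ready. A minimal sketch with illustrative directory and schedule values (not the ones used by this patch):

from pathlib import Path

import torch
from torch.profiler import ProfilerActivity, profile, schedule


def trace_handler(p: torch.profiler.profile) -> None:
    # Write one Chrome trace per profiled step window.
    Path("./profile_dir").mkdir(parents=True, exist_ok=True)
    p.export_chrome_trace(f"./profile_dir/trace_step{p.step_num}.json")


# prof = profile(
#     activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
#     schedule=schedule(wait=1, warmup=1, active=2),
#     on_trace_ready=trace_handler,
# )
# Inside the training loop, prof.step() is then called once per iteration.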
@@ -242,9 +215,8 @@ def train(
                 pre_hook_enabled = False
             # Also, check weight hash across DP replicas to be very pedantic.
             if args.check_weight_hash_across_dp_replicas_interval is not None:
-                assert check_param_hashes_across_dp_replicas(
-                    model, cross_check=True
-                ), "Parameter hashes not matching across DP replicas"
+                assert check_param_hashes_across_dp_replicas(model, cross_check=True), \
+                    "Parameter hashes not matching across DP replicas"
                 torch.distributed.barrier()
                 print_rank_0(f">>> Weight hashes match after {iteration} iterations...")
@@ -270,20 +242,14 @@ def train(
             # to make sure training configuration is still valid.
             update_num_microbatches(args.consumed_train_samples, consistency_check=False, verbose=True)
             if get_num_microbatches() != num_microbatches and iteration != 0:
-                assert get_num_microbatches() > num_microbatches, (
-                    f"Number of microbatches should be increasing due to batch size rampup; "
-                    f"instead going from {num_microbatches} to {get_num_microbatches()}"
-                )
+                assert get_num_microbatches() > num_microbatches, \
+                    (f"Number of microbatches should be increasing due to batch size rampup; "
+                     f"instead going from {num_microbatches} to {get_num_microbatches()}")
                 if args.save is not None:
-                    save_checkpoint_and_time(
-                        iteration,
-                        model,
-                        optimizer,
-                        opt_param_scheduler,
-                        num_floating_point_operations_so_far,
-                        checkpointing_context,
-                        train_data_iterator=train_data_iterator,
-                    )
+                    save_checkpoint_and_time(iteration, model, optimizer, opt_param_scheduler,
+                                             num_floating_point_operations_so_far,
+                                             checkpointing_context,
+                                             train_data_iterator=train_data_iterator)
             num_microbatches = get_num_microbatches()
             update_num_microbatches(args.consumed_train_samples, consistency_check=True, verbose=True)
@@ -292,9 +258,9 @@ def train(
             # Dummy train_step to fast forward train_data_iterator.
             dummy_train_step(train_data_iterator)
             iteration += 1
-            batch_size = (
-                mpu.get_data_parallel_world_size() * args.micro_batch_size * get_num_microbatches()
-            )
+            batch_size = mpu.get_data_parallel_world_size() * \
+                         args.micro_batch_size * \
+                         get_num_microbatches()
             args.consumed_train_samples += batch_size
             args.skipped_train_samples += batch_size
             continue
@@ -302,28 +268,19 @@ def train(
         # Run training step.
         args.curr_iteration = iteration
         ft_integration.on_training_step_start()
-        (
-            loss_dict,
-            skipped_iter,
-            should_checkpoint,
-            should_exit,
-            exit_code,
-            grad_norm,
-            num_zeros_in_grad,
-        ) = train_step(
-            forward_step_func, train_data_iterator, model, optimizer, opt_param_scheduler, config
-        )
+        loss_dict, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad = \
+            train_step(forward_step_func,
+                       train_data_iterator,
+                       model,
+                       optimizer,
+                       opt_param_scheduler,
+                       config)
         ft_integration.on_training_step_end()
         if should_checkpoint:
-            save_checkpoint_and_time(
-                iteration,
-                model,
-                optimizer,
-                opt_param_scheduler,
-                num_floating_point_operations_so_far,
-                checkpointing_context,
-                train_data_iterator=train_data_iterator,
-            )
+            save_checkpoint_and_time(iteration, model, optimizer, opt_param_scheduler,
+                                     num_floating_point_operations_so_far, checkpointing_context,
+                                     train_data_iterator=train_data_iterator)
         if should_exit:
             break
@@ -346,13 +303,12 @@ def train(
             pre_hook_enabled = True

         iteration += 1
-        batch_size = (
-            mpu.get_data_parallel_world_size() * args.micro_batch_size * get_num_microbatches()
-        )
+        batch_size = mpu.get_data_parallel_world_size() * \
+                     args.micro_batch_size * \
+                     get_num_microbatches()
         args.consumed_train_samples += batch_size
-        num_skipped_samples_in_batch = (
-            get_current_global_batch_size() - get_current_running_global_batch_size()
-        )
+        num_skipped_samples_in_batch = (get_current_global_batch_size() -
+                                        get_current_running_global_batch_size())
         if args.decrease_batch_size_if_needed:
             assert num_skipped_samples_in_batch >= 0
         else:
@@ -378,22 +334,16 @@ def train(
                 decoupled_learning_rate = param_group['lr']
             else:
                 learning_rate = param_group['lr']
-        report_memory_flag = training_log(
-            loss_dict,
-            total_loss_dict,
-            learning_rate,
-            decoupled_learning_rate,
-            iteration,
-            loss_scale,
-            report_memory_flag,
-            skipped_iter,
-            grad_norm,
-            params_norm,
-            num_zeros_in_grad,
-        )
+        report_memory_flag = training_log(loss_dict, total_loss_dict, learning_rate,
+                                          decoupled_learning_rate, iteration, loss_scale,
+                                          report_memory_flag, skipped_iter,
+                                          grad_norm, params_norm, num_zeros_in_grad)

         # Evaluation.
-        if args.eval_interval and iteration % args.eval_interval == 0 and args.do_valid:
+        if args.eval_interval and iteration % args.eval_interval == 0 and \
+           args.do_valid:
             timers('interval-time').stop()
             if should_disable_forward_pre_hook(args):
                 disable_forward_pre_hook(model)
@@ -403,18 +353,11 @@ def train(
                 gc.collect()
             prefix = f'iteration {iteration}'
             timers('eval-time', log_level=0).start(barrier=True)
-            evaluate_and_print_results(
-                prefix,
-                forward_step_func,
-                valid_data_iterator,
-                model,
-                iteration,
-                process_non_loss_data_func,
-                config,
-                verbose=False,
-                write_to_tensorboard=True,
-                non_loss_data_func=non_loss_data_func,
-            )
+            evaluate_and_print_results(prefix, forward_step_func,
+                                       valid_data_iterator, model,
+                                       iteration, process_non_loss_data_func,
+                                       config, verbose=False, write_to_tensorboard=True,
+                                       non_loss_data_func=non_loss_data_func)
             eval_duration += timers('eval-time').elapsed()
             eval_iterations += args.eval_iters
             timers('eval-time').stop()
@@ -430,25 +373,13 @@ def train(
         # Miscellaneous post-training-step functions (e.g., FT heartbeats, GC).
         # Some of these only happen at specific iterations.
-        post_training_step_callbacks(
-            model,
-            optimizer,
-            opt_param_scheduler,
-            iteration,
-            prof,
-            num_floating_point_operations_since_last_log_event,
-        )
+        post_training_step_callbacks(model, optimizer, opt_param_scheduler, iteration, prof,
+                                     num_floating_point_operations_since_last_log_event)

         # Checkpoint and decide whether to exit.
-        should_exit = checkpoint_and_decide_exit(
-            model,
-            optimizer,
-            opt_param_scheduler,
-            iteration,
-            num_floating_point_operations_so_far,
-            checkpointing_context,
-            train_data_iterator,
-        )
+        should_exit = checkpoint_and_decide_exit(model, optimizer, opt_param_scheduler,
+                                                 iteration, num_floating_point_operations_so_far,
+                                                 checkpointing_context, train_data_iterator)
         if should_exit:
             break
@@ -477,7 +408,6 @@ def train(
         if wandb_writer:
             wandb_writer.finish()
         ft_integration.shutdown()
         one_logger_utils.finish()
         sys.exit(exit_code)

     return iteration, num_floating_point_operations_so_far