OpenDAS / Megatron-LM · Commits

Commit b1a6d73b
authored Nov 25, 2021 by zihanl

fix training.py

parent 6fd0b406

Showing 1 changed file (megatron/training.py) with 2 additions and 7 deletions (+2, -7).

megatron/training.py @ b1a6d73b
@@ -141,7 +141,6 @@ def pretrain(train_valid_test_dataset_provider,
     print_rank_0('training ...')
     iteration = 0
-    # if not args.run_dialog:
     if args.do_train and args.train_iters > 0:
         iteration = train(forward_step_func,
                           model, optimizer, lr_scheduler,
@@ -163,7 +162,7 @@ def pretrain(train_valid_test_dataset_provider,
         evaluate_and_print_results(prefix, forward_step_func,
                                    test_data_iterator, model,
                                    0, True)

 def update_train_iters(args):

     # For iteration-based training, we don't need to do anything
@@ -355,8 +354,6 @@ def setup_model_and_optimizer(model_provider_func, model_type):
         torch.distributed.barrier()
         timers('load-checkpoint').start()
         args.iteration = load_checkpoint(model, optimizer, lr_scheduler)
-        # need to set train_samples to None
-        args.train_samples = None
         torch.distributed.barrier()
         timers('load-checkpoint').stop()
         timers.log(['load-checkpoint'])
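
For context on the two deleted lines above: Megatron-LM can express training length either in iterations (args.train_iters) or in samples (args.train_samples), and this commit drops the fork's forced reset of train_samples to None after checkpoint loading. A minimal, self-contained sketch of how the two settings relate; the helper and the argument values below are hypothetical, not code from megatron/training.py:

# Illustrative sketch only: relating sample-based and iteration-based
# training lengths. The values and resolve_train_iters() are hypothetical.
from types import SimpleNamespace

args = SimpleNamespace(train_iters=None, train_samples=256_000,
                       global_batch_size=256)

def resolve_train_iters(args):
    # If the length is given in samples, derive the iteration count from
    # the global batch size; otherwise use train_iters as-is.
    if args.train_samples is not None:
        return args.train_samples // args.global_batch_size
    return args.train_iters

print(resolve_train_iters(args))  # 1000 iterations for the example values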
@@ -662,9 +659,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
     # Iterations.
     iteration = args.iteration
-    # if not args.run_dialog:
     timers('interval-time').start()
     print_datetime('before the start of training step')
     report_memory_flag = True
     while iteration < args.train_iters:
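
The surrounding train() loop is iteration-based: it resumes from args.iteration (restored by load_checkpoint in the hunk above) and steps until args.train_iters is reached. A stripped-down sketch of that control flow, using stub objects rather than the actual Megatron-LM timers and train_step:

# Minimal illustration of an iteration-based loop resuming from a
# checkpointed iteration count; Timer and train_step are stand-ins.
import time

class Timer:
    def __init__(self):
        self._start = None
        self.elapsed = 0.0
    def start(self):
        self._start = time.time()
    def stop(self):
        self.elapsed += time.time() - self._start

def train_step(iteration):
    pass  # one forward/backward/optimizer step would go here

resume_iteration, train_iters = 100, 105  # e.g. resumed at iteration 100
interval_timer = Timer()
interval_timer.start()

iteration = resume_iteration
while iteration < train_iters:
    train_step(iteration)
    iteration += 1

interval_timer.stop()
print(f'ran {iteration - resume_iteration} steps '
      f'in {interval_timer.elapsed:.6f}s')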
@@ -860,7 +855,7 @@ def build_train_valid_test_data_iterators(
         else:
             train_samples = args.train_iters * args.global_batch_size
         eval_iters = (args.train_iters // args.eval_interval + 1) * \
                      args.eval_iters
         test_iters = args.eval_iters
         train_val_test_num_samples = [train_samples,
                                       eval_iters * args.global_batch_size,
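
The last hunk shows the sample-count arithmetic in build_train_valid_test_data_iterators. A worked example of those formulas; the argument values are made up for illustration and are not taken from this repository:

# Worked example of the sample-count formulas above; values are hypothetical.
train_iters, global_batch_size = 1000, 256
eval_interval, eval_iters_arg = 100, 10

train_samples = train_iters * global_batch_size                    # 256000
eval_iters = (train_iters // eval_interval + 1) * eval_iters_arg   # (10 + 1) * 10 = 110
test_iters = eval_iters_arg                                        # 10

train_val_test_num_samples = [train_samples,
                              eval_iters * global_batch_size,      # 28160
                              test_iters * global_batch_size]      # 2560
print(train_val_test_num_samples)  # [256000, 28160, 2560]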