OpenDAS / Megatron-LM
Commit 064bdc46, authored Apr 07, 2020 by Neel Kant
Parent: 72fb0d5c

Fix issue with validation dataloader
Showing 4 changed files with 15 additions and 12 deletions:

    megatron/data/ict_dataset.py  +1 -1
    megatron/training.py          +5 -6
    megatron/utils.py             +7 -3
    pretrain_bert_ict.py          +2 -2
megatron/data/ict_dataset.py

@@ -35,7 +35,7 @@ class InverseClozeDataset(Dataset):
     def __getitem__(self, idx):
         # get rng state corresponding to index (allows deterministic random pair)
-        rng = random.Random(idx + self.seed)
+        rng = random.Random(idx + 20000 + self.seed)
         np_rng = np.random.RandomState(seed=[rng.randint(0, 2 ** 32 - 1) for _ in range(16)])
         # get seq length. Save 2 tokens for beginning and end
         ...
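
This hunk derives a deterministic per-sample RNG from the dataset index, so the "random" query/context pair for a given index is reproducible; the new 20000 offset presumably decorrelates this stream from other generators seeded with idx + seed. A minimal standalone sketch of the same seeding pattern (the helper name is ours; InverseClozeDataset itself is not reproduced):

    import random

    import numpy as np

    def per_index_rngs(idx, seed):
        # Python RNG keyed to the sample index; the fixed 20000 offset
        # keeps it distinct from any other generator seeded with idx + seed.
        rng = random.Random(idx + 20000 + seed)
        # Seed NumPy from 16 draws of that RNG: still fully deterministic,
        # but with far more seed entropy than a single 32-bit integer.
        np_rng = np.random.RandomState(
            seed=[rng.randint(0, 2 ** 32 - 1) for _ in range(16)])
        return rng, np_rng

    # Same (idx, seed) always reproduces the same draws, in any call order.
    _, a = per_index_rngs(idx=7, seed=1234)
    _, b = per_index_rngs(idx=7, seed=1234)
    assert a.randint(0, 100) == b.randint(0, 100)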
megatron/training.py

@@ -98,7 +98,6 @@ def pretrain(train_val_test_data_provider, model_provider, forward_step_func,
     iteration = 0
-    if args.do_train and args.train_iters > 0:
+    if args.do_train:
         iteration, _ = train(forward_step_func, model, optimizer, lr_scheduler,
                              train_data_iterator, val_data_iterator)
     ...

@@ -485,8 +484,8 @@ def get_train_val_test_data_iterators(train_data, val_data, test_data):
     if val_data is not None:
-        start_iter_val = (args.iteration // args.eval_interval) * \
-                         args.eval_iters
-        val_data.batch_sampler.start_iter = start_iter_val % \
-                                            len(val_data)
+        val_data.batch_sampler.start_iter = 0
         print_rank_0('setting validation data start iteration to {}'.
                      format(val_data.batch_sampler.start_iter))
     ...
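
The second hunk stops fast-forwarding the validation batch sampler on resume; with the loader now drawing random samples with replacement (see megatron/utils.py below), the old offset arithmetic no longer serves a purpose. A quick sketch of what the removed lines computed, using made-up numbers:

    # Hypothetical resume state, for illustration only.
    iteration, eval_interval, eval_iters = 5000, 1000, 100
    num_val_batches = 640  # stand-in for len(val_data)

    # Removed behaviour: fast-forward to where validation "left off".
    start_iter_val = (iteration // eval_interval) * eval_iters  # 5 * 100 = 500
    old_start = start_iter_val % num_val_batches                # 500
    # New behaviour: always begin validation from the first batch.
    new_start = 0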
megatron/utils.py

@@ -24,7 +24,7 @@ from megatron import get_adlr_autoresume
 from megatron import mpu
 from megatron import print_rank_0
 from megatron.checkpointing import save_checkpoint
-from megatron.data.samplers import DistributedBatchSampler
+from megatron.data.samplers import DistributedBatchSampler, RandomSampler
 from megatron.fp16 import FP16_Optimizer
 ...

@@ -102,12 +102,16 @@ def make_data_loader(dataset):
     num_workers = args.num_workers

     # Use a simple sampler with distributed batch sampler.
-    sampler = torch.utils.data.SequentialSampler(dataset)
+    #sampler = torch.utils.data.SequentialSampler(dataset)
+    sampler = RandomSampler(dataset, replacement=True,
+                            num_samples=global_batch_size * args.train_iters)
     batch_sampler = DistributedBatchSampler(sampler=sampler,
                                             batch_size=global_batch_size,
                                             drop_last=True,
                                             rank=rank,
-                                            world_size=world_size)
+                                            world_size=world_size,
+                                            wrap_last=True)

     # Torch dataloader.
     return torch.utils.data.DataLoader(dataset,
                                        batch_sampler=batch_sampler,
                                        ...
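
This appears to be the substantive fix: the loaders previously walked the dataset sequentially, while the loader now samples with replacement, sized up front for the whole run, and wraps the last partial batch (wrap_last=True). Megatron's own RandomSampler and DistributedBatchSampler live in megatron/data/samplers.py and are not reproduced here, but torch's built-in RandomSampler accepts the same replacement/num_samples arguments used in this diff, so a rough single-process stand-in looks like:

    import torch
    from torch.utils.data import DataLoader, RandomSampler, TensorDataset

    # Toy dataset and made-up sizes, for illustration only.
    dataset = TensorDataset(torch.arange(1000))
    global_batch_size, train_iters = 32, 10

    # Sampling with replacement, sized for the whole run: the sampler yields
    # exactly global_batch_size * train_iters indices and never runs dry,
    # unlike a single sequential pass over the dataset.
    sampler = RandomSampler(dataset, replacement=True,
                            num_samples=global_batch_size * train_iters)
    loader = DataLoader(dataset, batch_size=global_batch_size,
                        sampler=sampler, drop_last=True)
    assert len(loader) == train_iters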
pretrain_bert_ict.py

@@ -115,7 +115,7 @@ def get_train_val_test_data():
     # Number of train/valid/test samples.
     train_iters = args.train_iters
-    eval_iters = (train_iters // args.eval_interval + 1) * args.eval_iters
+    eval_iters = args.eval_iters
     test_iters = args.eval_iters
     train_val_test_num_samples = [train_iters * global_batch_size,
                                   eval_iters * global_batch_size,
                                   ...

@@ -159,7 +159,7 @@ def get_train_val_test_data():
     args.do_valid = flags[1].item()
     args.do_test = flags[2].item()

-    return train_data, val_data, test_data
+    return train_data, valid_data, test_data

 if __name__ == "__main__":
     ...
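
The first hunk shrinks the requested validation sample count: it no longer has to cover every evaluation round up front, since the wrapping random sampler can reuse the set. The second appears to fix a mismatched variable name in the return statement. A small sizing sketch with made-up values:

    # Made-up run configuration, for illustration only.
    train_iters, eval_interval, eval_iters = 10000, 1000, 100
    global_batch_size = 32

    # Removed sizing: enough validation samples for every evaluation round.
    old_eval_iters = (train_iters // eval_interval + 1) * eval_iters  # 1100
    # New sizing: one round's worth; the wrapping sampler reuses it.
    new_eval_iters = eval_iters                                       # 100

    old_samples = old_eval_iters * global_batch_size  # 35200
    new_samples = new_eval_iters * global_batch_size  # 3200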