OpenDAS / Megatron-LM · Commits · 4abd7ce2

Commit 4abd7ce2
authored Apr 17, 2020 by Neel Kant
Refactor before merge request
Parent: f1ad8c94

Showing 8 changed files with 115 additions and 269 deletions (+115 -269)

    ict_qualitative_test.py                       +0    -121
    megatron/data/bert_dataset.py                 +14   -6
    megatron/data/helpers.cpp                     +11   -6
    megatron/data/ict_dataset.py                  +30   -31
    megatron/deprecated_data_utils/datasets.py    +2    -2
    megatron/model/bert_model.py                  +23   -25
    megatron/training.py                          +0    -3
    pretrain_bert_ict.py                          +35   -75
ict_qualitative_test.py  (deleted, 100644 → 0)  view file @ f1ad8c94

import numpy as np
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

from megatron import get_args
from megatron import mpu
from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
from megatron.data.bert_dataset import get_indexed_dataset_
from megatron.data.ict_dataset import InverseClozeDataset
from megatron.data.samplers import DistributedBatchSampler
from megatron.initialize import initialize_megatron
from megatron.training import get_model
from pretrain_bert_ict import get_batch, model_provider


def main():
    initialize_megatron(extra_args_provider=None,
                        args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
    args = get_args()
    model = load_checkpoint()
    model.eval()
    dataset = get_dataset()
    data_iter = iter(get_dataloader(dataset))

    all_input_tokens = []
    all_input_logits = []
    all_block_tokens = []
    all_block_logits = []
    for i in range(100):
        input_tokens, input_types, input_pad_mask, \
            block_tokens, block_token_types, block_pad_mask = get_batch(data_iter)
        input_logits, doc_logits, _ = model.module.module.forward(
            input_tokens, input_types, input_pad_mask,
            block_tokens, block_pad_mask, block_token_types, return_logits=True)

        all_input_tokens.append(input_tokens.detach().cpu().numpy())
        all_input_logits.append(input_logits.detach().cpu().numpy())
        all_block_tokens.append(block_tokens.detach().cpu().numpy())
        all_block_logits.append(doc_logits.detach().cpu().numpy())

    all_input_tokens = np.array(all_input_tokens).reshape(-1, args.seq_length)
    all_input_logits = np.array(all_input_logits).reshape(-1, 128)
    all_block_tokens = np.array(all_block_tokens).reshape(-1, args.seq_length)
    all_block_logits = np.array(all_block_logits).reshape(-1, 128)

    np.save('input_tokens.npy', all_input_tokens)
    np.save('input_logits.npy', all_input_logits)
    np.save('block_tokens.npy', all_block_tokens)
    np.save('doc_logits.npy', all_block_logits)


def load_checkpoint():
    args = get_args()
    model = get_model(model_provider)
    if isinstance(model, torchDDP):
        model = model.module

    tracker_filename = get_checkpoint_tracker_filename(args.load)
    with open(tracker_filename, 'r') as f:
        iteration = int(f.read().strip())
        assert iteration > 0

    checkpoint_name = get_checkpoint_name(args.load, iteration, False)
    if mpu.get_data_parallel_rank() == 0:
        print('global rank {} is loading checkpoint {}'.format(
            torch.distributed.get_rank(), checkpoint_name))

    state_dict = torch.load(checkpoint_name, map_location='cpu')
    model.load_state_dict(state_dict['model'])
    torch.distributed.barrier()
    if mpu.get_data_parallel_rank() == 0:
        print('  successfully loaded {}'.format(checkpoint_name))

    return model


def get_dataset():
    args = get_args()
    block_dataset = get_indexed_dataset_(args.data_path, 'mmap', True)
    titles_dataset = get_indexed_dataset_(args.data_path + '-titles', 'mmap', True)

    doc_idx_ptr = block_dataset.get_doc_idx()
    total_num_documents = block_dataset.doc_idx.shape[0] - 1
    block_dataset.set_doc_idx(doc_idx_ptr[0:total_num_documents])
    kwargs = dict(
        name='full',
        context_dataset=block_dataset,
        titles_dataset=titles_dataset,
        data_prefix=args.data_path,
        num_epochs=None,
        max_num_samples=total_num_documents * 3,
        max_seq_length=288,  # doesn't matter
        short_seq_prob=0.0001,  # doesn't matter
        seed=1)

    dataset = InverseClozeDataset(**kwargs)
    return dataset


def get_dataloader(dataset):
    args = get_args()
    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()
    global_batch_size = args.batch_size * world_size
    num_workers = args.num_workers

    sampler = torch.utils.data.SequentialSampler(dataset)
    batch_sampler = DistributedBatchSampler(sampler,
                                            batch_size=global_batch_size,
                                            drop_last=True,
                                            rank=rank,
                                            world_size=world_size)

    return torch.utils.data.DataLoader(dataset,
                                       batch_sampler=batch_sampler,
                                       num_workers=num_workers,
                                       pin_memory=True)


if __name__ == "__main__":
    main()
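A minimal offline sketch (not part of the commit) of how the arrays dumped by the deleted script could be inspected: load the saved logits and rank blocks against each input by dot-product similarity. The file names match the np.save calls above; everything else is illustrative.

import numpy as np

# Arrays written by ict_qualitative_test.py.
input_logits = np.load('input_logits.npy')   # [num_samples, 128]
block_logits = np.load('doc_logits.npy')     # [num_samples, 128]

# Similarity of every input against every block.
scores = input_logits @ block_logits.T       # [num_samples, num_samples]

# For each input, indices of its top-5 most similar blocks.
top5 = np.argsort(-scores, axis=1)[:, :5]
print(top5[:10])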
megatron/data/bert_dataset.py  view file @ 4abd7ce2

@@ -42,9 +42,9 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                            skip_warmup)
     if ict_dataset:
-        titles_dataset = get_indexed_dataset_(data_prefix + '-titles',
+        title_dataset = get_indexed_dataset_(data_prefix + '-titles',
                                               data_impl,
                                               skip_warmup)

     # Get start and end indices of train/valid/train into doc-idx
     # Note that doc-idx is desinged to be num-docs + 1 so we can
@@ -54,6 +54,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
     # Print stats about the splits.
     print_rank_0(' > dataset split:')

     def print_split_stats(name, index):
         print_rank_0('    {}:'.format(name))
         print_rank_0('     document indices in [{}, {}) total of {} '
@@ -82,7 +83,6 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
         # Build the dataset accordingly.
         kwargs = dict(
             name=name,
-            context_dataset=indexed_dataset,
             data_prefix=data_prefix,
             num_epochs=None,
             max_num_samples=train_valid_test_num_samples[index],
@@ -92,9 +92,17 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
         )

         if ict_dataset:
-            dataset = InverseClozeDataset(titles_dataset=titles_dataset, **kwargs)
+            dataset = InverseClozeDataset(block_dataset=indexed_dataset,
+                                          title_dataset=title_dataset, **kwargs)
         else:
-            dataset = BertDataset(masked_lm_prob=masked_lm_prob, **kwargs)
+            dataset = BertDataset(indexed_dataset=indexed_dataset,
+                                  masked_lm_prob=masked_lm_prob, **kwargs)

         # Set the original pointer so dataset remains the main dataset.
         indexed_dataset.set_doc_idx(doc_idx_ptr)

         # Checks.
megatron/data/helpers.cpp  view file @ 4abd7ce2

@@ -452,6 +452,7 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
     // Current map index.
     uint64_t map_index = 0;
+    int32_t block_id = 0;

     // For each epoch:
     for (int32_t epoch=0; epoch < num_epochs; ++epoch) {
@@ -514,14 +515,16 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
                     // Populate the map.
                     if (second) {
-                        const auto map_index_0 = 3 * map_index;
+                        const auto map_index_0 = 4 * map_index;
                         maps[map_index_0] = static_cast<DocIdx>(prev_start_index);
                         maps[map_index_0 + 1] = static_cast<DocIdx>(sent_index + 1);
                         maps[map_index_0 + 2] = static_cast<DocIdx>(doc);
+                        maps[map_index_0 + 3] = static_cast<DocIdx>(block_id);
                     }

                     // Update indices / counters.
                     ++map_index;
+                    ++block_id;
                     prev_start_index = sent_index + 1;
                     seq_len = 0;
                     num_sent = 0;
@@ -529,6 +532,7 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
             } // for (auto sent_index=sent_index_first; ...
         } // if (num_remain_sent > 1) {
     } // for (int doc=0; doc < num_docs; ++doc) {
+    block_id = 0;
 } // for (int epoch=0; epoch < num_epochs; ++epoch) {

 if (!second) {
@@ -538,7 +542,7 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
     }
     assert(maps == NULL);
     assert(num_samples < 0);
-    maps = new DocIdx[3*map_index];
+    maps = new DocIdx[4*map_index];
     num_samples = static_cast<int64_t>(map_index);
 }
@@ -550,12 +554,13 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
     std::mt19937_64 rand64_gen(seed + 1);
     for (auto i=(num_samples - 1); i > 0; --i) {
         const auto j = static_cast<int64_t>(rand64_gen() % (i + 1));
-        const auto i0 = 3 * i;
-        const auto j0 = 3 * j;
+        const auto i0 = 4 * i;
+        const auto j0 = 4 * j;
         // Swap values.
         swap(maps[i0], maps[j0]);
         swap(maps[i0 + 1], maps[j0 + 1]);
         swap(maps[i0 + 2], maps[j0 + 2]);
+        swap(maps[i0 + 3], maps[j0 + 3]);
     }

     // Method to deallocate memory.
@@ -566,8 +571,8 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
     // Return the numpy array.
     const auto byte_size = sizeof(DocIdx);
-    return py::array(std::vector<int64_t>{num_samples, 3},  // shape
-                     {3*byte_size, byte_size},  // C-style contiguous strides
+    return py::array(std::vector<int64_t>{num_samples, 4},  // shape
+                     {4*byte_size, byte_size},  // C-style contiguous strides
                      maps,  // the data pointer
                      free_when_done);  // numpy array references
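A hedged sketch (not from the repo) of what this C++ change means for the mapping consumed in Python: each row of the array returned by build_blocks_mapping now has four columns instead of three, which the toy example below labels as start sentence index, end sentence index, document index, and block id. The concrete values are illustrative only.

import numpy as np

# Toy stand-in for the samples mapping built by helpers.build_blocks_mapping.
samples_mapping = np.array([
    [0,  4, 0, 0],   # sentences [0, 4) of doc 0, block id 0
    [4,  9, 0, 1],   # sentences [4, 9) of doc 0, block id 1
    [9, 12, 1, 2],   # sentences [9, 12) of doc 1, block id 2
], dtype=np.int64)

# Downstream unpacking mirrors the updated InverseClozeDataset.__getitem__.
start_idx, end_idx, doc_idx, block_idx = samples_mapping[1]
print(start_idx, end_idx, doc_idx, block_idx)  # 4 9 0 1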
megatron/data/ict_dataset.py  view file @ 4abd7ce2

 import itertools
 import random
 import os
-import sys
 import time

 import numpy as np
@@ -16,19 +15,19 @@ from megatron.data import helpers
 class InverseClozeDataset(Dataset):
     """Dataset containing sentences and their blocks for an inverse cloze task."""
-    def __init__(self, name, context_dataset, titles_dataset, data_prefix,
+    def __init__(self, name, block_dataset, title_dataset, data_prefix,
                  num_epochs, max_num_samples, max_seq_length,
                  short_seq_prob, seed):
         self.name = name
         self.seed = seed
         self.max_seq_length = max_seq_length
-        self.context_dataset = context_dataset
-        self.titles_dataset = titles_dataset
+        self.block_dataset = block_dataset
+        self.title_dataset = title_dataset
         self.short_seq_prob = short_seq_prob
         self.rng = random.Random(self.seed)

-        self.samples_mapping = get_samples_mapping(self.context_dataset,
-                                                   self.titles_dataset,
+        self.samples_mapping = get_samples_mapping(self.block_dataset,
+                                                   self.title_dataset,
                                                    data_prefix,
                                                    num_epochs,
                                                    max_num_samples,
@@ -47,38 +46,38 @@ class InverseClozeDataset(Dataset):
         return self.samples_mapping.shape[0]

     def __getitem__(self, idx):
-        start_idx, end_idx, doc_idx = self.samples_mapping[idx]
-        title = list(self.titles_dataset[int(doc_idx)])
-        context = [list(self.context_dataset[i]) for i in range(start_idx, end_idx)]
-        assert len(context) > 1
+        start_idx, end_idx, doc_idx, block_idx = self.samples_mapping[idx]
+        title = list(self.title_dataset[int(doc_idx)])
+        block = [list(self.block_dataset[i]) for i in range(start_idx, end_idx)]
+        assert len(block) > 1

         # avoid selecting the first or last sentence to be the query.
-        if len(context) == 2:
+        if len(block) == 2:
             rand_sent_idx = int(self.rng.random() > 0.5)
         else:
-            rand_sent_idx = self.rng.randint(1, len(context) - 2)
+            rand_sent_idx = self.rng.randint(1, len(block) - 2)

-        # keep the query in the context 10% of the time.
+        # keep the query in the block 10% of the time.
         if self.rng.random() < 0.1:
-            input = context[rand_sent_idx].copy()
+            query = block[rand_sent_idx].copy()
         else:
-            input = context.pop(rand_sent_idx)
+            query = block.pop(rand_sent_idx)

-        # may still need to truncate because blocks are concluded when
+        # still need to truncate because blocks are concluded when
         # the sentence lengths have exceeded max_seq_length.
-        input = input[:self.max_seq_length - 2]
-        context = list(itertools.chain(*context))[:self.max_seq_length - (3 + len(title))]
+        query = query[:self.max_seq_length - 2]
+        block = list(itertools.chain(*block))[:self.max_seq_length - (3 + len(title))]

-        input_tokens, input_token_types, input_pad_mask = self.concat_and_pad_tokens(input)
-        context_tokens, context_token_types, context_pad_mask = self.concat_and_pad_tokens(context, title)
+        query_tokens, query_token_types, query_pad_mask = self.concat_and_pad_tokens(query)
+        block_tokens, block_token_types, block_pad_mask = self.concat_and_pad_tokens(block, title)

         sample = {
-            'input_text': np.array(input_tokens),
-            'input_types': np.array(input_token_types),
-            'input_pad_mask': np.array(input_pad_mask),
-            'context_text': np.array(context_tokens),
-            'context_types': np.array(context_token_types),
-            'context_pad_mask': np.array(context_pad_mask)
+            'query_tokens': np.array(query_tokens),
+            'query_types': np.array(query_token_types),
+            'query_pad_mask': np.array(query_pad_mask),
+            'block_tokens': np.array(block_tokens),
+            'block_types': np.array(block_token_types),
+            'block_pad_mask': np.array(block_pad_mask)
         }

         return sample
@@ -97,7 +96,7 @@ class InverseClozeDataset(Dataset):
         return tokens, token_types, pad_mask


-def get_samples_mapping(context_dataset,
+def get_samples_mapping(block_dataset,
                         titles_dataset,
                         data_prefix,
                         num_epochs,
@@ -131,8 +130,8 @@ def get_samples_mapping(context_dataset,
                      'the indices on rank 0 ...'.format(indexmap_filename))

         # Make sure the types match the helpers input types.
-        assert context_dataset.doc_idx.dtype == np.int64
-        assert context_dataset.sizes.dtype == np.int32
+        assert block_dataset.doc_idx.dtype == np.int64
+        assert block_dataset.sizes.dtype == np.int32

         # Build samples mapping
         verbose = torch.distributed.get_rank() == 0
@@ -140,8 +139,8 @@ def get_samples_mapping(context_dataset,
         print_rank_0(' > building samples index mapping for {} ...'.format(
             name))
         samples_mapping = helpers.build_blocks_mapping(
-            context_dataset.doc_idx,
-            context_dataset.sizes,
+            block_dataset.doc_idx,
+            block_dataset.sizes,
             titles_dataset.sizes,
             num_epochs,
             max_num_samples,
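A standalone sketch (toy token lists, not repository code) of the query/block selection that the renamed __getitem__ performs: pick a non-edge sentence as the pseudo-query and usually remove it from the block, keeping it in place about 10% of the time.

import itertools
import random

rng = random.Random(1)
block = [[101, 7592], [2023, 2003], [1037, 2742], [102, 999]]  # four toy "sentences"

# avoid selecting the first or last sentence to be the query.
rand_sent_idx = int(rng.random() > 0.5) if len(block) == 2 \
    else rng.randint(1, len(block) - 2)

# keep the query in the block 10% of the time.
if rng.random() < 0.1:
    query = block[rand_sent_idx].copy()
else:
    query = block.pop(rand_sent_idx)

# flatten the remaining sentences into one token sequence, as the dataset does.
block = list(itertools.chain(*block))
print(query, block)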
megatron/deprecated_data_utils/datasets.py  view file @ 4abd7ce2

@@ -918,10 +918,10 @@ class InverseClozeDataset(data.Dataset):
         sample = {
             'input_text': np.array(input_tokens),
-            'input_types': np.array(input_token_types),
+            'query_types': np.array(input_token_types),
             'input_pad_mask': np.array(input_pad_mask),
             'context_text': np.array(context_tokens),
-            'context_types': np.array(context_token_types),
+            'block_types': np.array(context_token_types),
             'context_pad_mask': np.array(context_pad_mask)
         }
megatron/model/bert_model.py  view file @ 4abd7ce2

@@ -215,6 +215,7 @@ class BertModel(MegatronModule):
 class ICTBertModel(MegatronModule):
+    """Bert-based module for Inverse Cloze task."""
     def __init__(self,
                  ict_head_size,
                  num_tokentypes=0,
@@ -227,41 +228,38 @@ class ICTBertModel(MegatronModule):
             parallel_output=parallel_output
         )

-        self.question_model = BertModel(**bert_args)
-        self._question_key = 'question_model'
-        self.context_model = BertModel(**bert_args)
-        self._context_key = 'context_model'
-
-    def forward(self, input_tokens, input_attention_mask, input_types,
-                context_tokens, context_attention_mask, context_types, return_logits=False):
-
-        question_ict_logits, _ = self.question_model.forward(input_tokens, 1 - input_attention_mask, input_types)
-        context_ict_logits, _ = self.context_model.forward(context_tokens, 1 - context_attention_mask, context_types)
-
-        # [batch x h] * [h x batch]
-        retrieval_scores = question_ict_logits.matmul(torch.transpose(context_ict_logits, 0, 1))
-        if return_logits:
-            return question_ict_logits, context_ict_logits, retrieval_scores
-        return retrieval_scores
+        # this model embeds (pseudo-)queries - Embed_input in the paper
+        self.query_model = BertModel(**bert_args)
+        self._query_key = 'question_model'
+
+        # this model embeds evidence blocks - Embed_doc in the paper
+        self.block_model = BertModel(**bert_args)
+        self._block_key = 'context_model'
+
+    def forward(self, query_tokens, query_attention_mask, query_types,
+                block_tokens, block_attention_mask, block_types):
+        """Run a forward pass for each of the models and compute the similarity scores."""
+        query_logits, _ = self.query_model.forward(query_tokens, 1 - query_attention_mask, query_types)
+        block_logits, _ = self.block_model.forward(block_tokens, 1 - block_attention_mask, block_types)
+        return query_logits, block_logits

     def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                        keep_vars=False):
+        """Save dict with state dicts of each of the models."""
         state_dict_ = {}
-        state_dict_[self._question_key] \
-            = self.question_model.state_dict_for_save_checkpoint(
+        state_dict_[self._query_key] \
+            = self.query_model.state_dict_for_save_checkpoint(
                 destination, prefix, keep_vars)
-        state_dict_[self._context_key] \
-            = self.context_model.state_dict_for_save_checkpoint(
+        state_dict_[self._block_key] \
+            = self.block_model.state_dict_for_save_checkpoint(
                 destination, prefix, keep_vars)

         return state_dict_

     def load_state_dict(self, state_dict, strict=True):
-        """Customized load."""
-        self.question_model.load_state_dict(
-            state_dict[self._question_key], strict=strict)
-        self.context_model.load_state_dict(
-            state_dict[self._context_key], strict=strict)
+        """Load the state dicts of each of the models"""
+        self.query_model.load_state_dict(
+            state_dict[self._query_key], strict=strict)
+        self.block_model.load_state_dict(
+            state_dict[self._block_key], strict=strict)
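A sketch of how the refactored forward() is meant to be consumed: the model now returns the two embedding matrices and the caller forms the in-batch similarity matrix, mirroring the new code in pretrain_bert_ict.py. Random tensors stand in for real model outputs; only the shapes and the matmul pattern are taken from the diff.

import torch

batch_size, hidden = 8, 128
query_logits = torch.randn(batch_size, hidden)   # would come from the query model
block_logits = torch.randn(batch_size, hidden)   # would come from the block model

# [batch x h] * [h x batch] -> one score per (query, block) pair.
retrieval_scores = query_logits.matmul(torch.transpose(block_logits, 0, 1))
print(retrieval_scores.shape)  # torch.Size([8, 8])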
megatron/training.py  view file @ 4abd7ce2

@@ -262,19 +262,16 @@ def train_step(forward_step_func, data_iterator,
     timers('forward').start()
     loss, loss_reduced = forward_step_func(data_iterator, model)
     timers('forward').stop()
-    torch.cuda.synchronize()

     # Calculate gradients, reduce across processes, and clip.
     timers('backward').start()
     backward_step(optimizer, model, loss)
     timers('backward').stop()
-    torch.cuda.synchronize()

     # Update parameters.
     timers('optimizer').start()
     optimizer.step()
     timers('optimizer').stop()
-    torch.cuda.synchronize()

     # Update learning rate.
     skipped_iter = 0
pretrain_bert_ict.py  view file @ 4abd7ce2

@@ -25,7 +25,6 @@ from megatron import print_rank_0
 from megatron.data.bert_dataset import build_train_valid_test_datasets
 from megatron.model import ICTBertModel
 from megatron.training import pretrain
-from megatron.utils import make_data_loader
 from megatron.utils import reduce_losses

 num_batches = 0
@@ -46,8 +45,8 @@ def model_provider():
 def get_batch(data_iterator):
     # Items and their type.
-    keys = ['input_text', 'input_types', 'input_pad_mask',
-            'context_text', 'context_types', 'context_pad_mask']
+    keys = ['query_tokens', 'query_types', 'query_pad_mask',
+            'block_tokens', 'block_types', 'block_pad_mask']
     datatype = torch.int64

     # Broadcast data.
@@ -58,15 +57,15 @@ def get_batch(data_iterator):
     data_b = mpu.broadcast_data(keys, data, datatype)

     # Unpack.
-    input_tokens = data_b['input_text'].long()
-    input_types = data_b['input_types'].long()
-    input_pad_mask = data_b['input_pad_mask'].long()
-    context_tokens = data_b['context_text'].long()
-    context_types = data_b['context_types'].long()
-    context_pad_mask = data_b['context_pad_mask'].long()
+    query_tokens = data_b['query_tokens'].long()
+    query_types = data_b['query_types'].long()
+    query_pad_mask = data_b['query_pad_mask'].long()
+    block_tokens = data_b['block_tokens'].long()
+    block_types = data_b['block_types'].long()
+    block_pad_mask = data_b['block_pad_mask'].long()

-    return input_tokens, input_types, input_pad_mask, \
-           context_tokens, context_types, context_pad_mask
+    return query_tokens, query_types, query_pad_mask, \
+           block_tokens, block_types, block_pad_mask


 def forward_step(data_iterator, model):
@@ -75,15 +74,18 @@ def forward_step(data_iterator, model):
     # Get the batch.
     timers('batch generator').start()
-    input_tokens, input_types, input_pad_mask, \
-    context_tokens, context_types, context_pad_mask = get_batch(data_iterator)
+    query_tokens, query_types, query_pad_mask, \
+    block_tokens, block_types, block_pad_mask = get_batch(data_iterator)
     timers('batch generator').stop()

     # Forward model.
-    retrieval_scores = model(input_tokens, input_pad_mask, input_types,
-                             context_tokens, context_pad_mask, context_types).float()
+    query_logits, block_logits = model(query_tokens, query_pad_mask, query_types,
+                                       block_tokens, block_pad_mask, block_types).float()
+
+    # [batch x h] * [h x batch]
+    retrieval_scores = query_logits.matmul(torch.transpose(block_logits, 0, 1))

     softmaxed = F.softmax(retrieval_scores, dim=1)
     top5_vals, top5_indices = torch.topk(softmaxed, k=5, sorted=True)
     batch_size = softmaxed.shape[0]
@@ -98,71 +100,29 @@ def forward_step(data_iterator, model):
                'top5_acc': reduced_losses[2]}


-def get_train_val_test_data():
-    """Load the data on rank zero and boradcast number of tokens to all GPUS."""
-    args = get_args()
-
-    (train_data, valid_data, test_data) = (None, None, None)
-
-    # Data loader only on rank 0 of each model parallel group.
-    if mpu.get_model_parallel_rank() == 0:
-        print_rank_0('> building train, validation, and test datasets '
-                     'for BERT ...')
-
-        data_parallel_size = mpu.get_data_parallel_world_size()
-        data_parallel_rank = mpu.get_data_parallel_rank()
-        global_batch_size = args.batch_size * data_parallel_size
-
-        # Number of train/valid/test samples.
-        train_iters = args.train_iters
-        eval_iters = (train_iters // args.eval_iters + 1) * args.eval_iters
-        test_iters = args.eval_iters
-        train_val_test_num_samples = [train_iters * global_batch_size,
-                                      eval_iters * global_batch_size,
-                                      test_iters * global_batch_size]
-        print_rank_0(' > datasets target sizes (minimum size):')
-        print_rank_0('    train: {}'.format(train_val_test_num_samples[0]))
-        print_rank_0('    validation: {}'.format(train_val_test_num_samples[1]))
-        print_rank_0('    test: {}'.format(train_val_test_num_samples[2]))
-
-        train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
-            data_prefix=args.data_path,
-            data_impl=args.data_impl,
-            splits_string=args.split,
-            train_valid_test_num_samples=train_val_test_num_samples,
-            max_seq_length=args.seq_length,
-            masked_lm_prob=args.mask_prob,
-            short_seq_prob=args.short_seq_prob,
-            seed=args.seed,
-            skip_warmup=(not args.mmap_warmup),
-            ict_dataset=True)
-        print_rank_0("> finished creating BERT ICT datasets ...")
-
-        train_data = make_data_loader(train_ds)
-        valid_data = make_data_loader(valid_ds)
-        test_data = make_data_loader(test_ds)
-
-        do_train = train_data is not None and args.train_iters > 0
-        do_valid = valid_data is not None and args.eval_iters > 0
-        do_test = test_data is not None and args.eval_iters > 0
-        # Need to broadcast num_tokens and num_type_tokens.
-        flags = torch.cuda.LongTensor(
-            [int(do_train), int(do_valid), int(do_test)])
-    else:
-        flags = torch.cuda.LongTensor([0, 0, 0])
-
-    # Broadcast num tokens.
-    torch.distributed.broadcast(flags,
-                                mpu.get_model_parallel_src_rank(),
-                                group=mpu.get_model_parallel_group())
-    args.do_train = flags[0].item()
-    args.do_valid = flags[1].item()
-    args.do_test = flags[2].item()
-
-    return train_data, valid_data, test_data
+def train_valid_test_datasets_provider(train_val_test_num_samples):
+    """Build train, valid and test datasets."""
+    args = get_args()
+    print_rank_0('> building train, validation, and test datasets '
+                 'for BERT ...')
+
+    train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
+        data_prefix=args.data_path,
+        data_impl=args.data_impl,
+        splits_string=args.split,
+        train_valid_test_num_samples=train_val_test_num_samples,
+        max_seq_length=args.seq_length,
+        masked_lm_prob=args.mask_prob,
+        short_seq_prob=args.short_seq_prob,
+        seed=args.seed,
+        skip_warmup=(not args.mmap_warmup),
+        ict_dataset=True)
+    print_rank_0("> finished creating BERT ICT datasets ...")
+
+    return train_ds, valid_ds, test_ds


 if __name__ == "__main__":
-    pretrain(get_train_val_test_data, model_provider, forward_step,
+    pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
              args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
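A hedged sketch of the in-batch retrieval metric that forward_step's softmax/top-k lines feed into: the correct block for query i is block i, so accuracy checks whether the diagonal index appears among the top-k softmaxed scores. The tensors below are random stand-ins; only the shapes and the softmax/topk calls follow the code above.

import torch
import torch.nn.functional as F

batch_size = 8
retrieval_scores = torch.randn(batch_size, batch_size)  # stand-in for query_logits @ block_logits.T

softmaxed = F.softmax(retrieval_scores, dim=1)
top5_vals, top5_indices = torch.topk(softmaxed, k=5, sorted=True)

# Query i should retrieve block i.
targets = torch.arange(batch_size)
top1_acc = (top5_indices[:, 0] == targets).float().mean()
top5_acc = (top5_indices == targets.unsqueeze(1)).any(dim=1).float().mean()
print(top1_acc.item(), top5_acc.item())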