OpenDAS / Megatron-LM / Commits

Commit 3354081f, authored Jun 23, 2020 by Neel Kant
    More ict_merge changes and interactive testing
Parent: 6495927e

Changes: 24 · Showing 20 changed files on this page, with 134 additions and 105 deletions (+134 -105)
Changed files (this page):
    megatron/__init__.py                  +2   -2
    megatron/checkpointing.py             +1   -2
    megatron/data/bert_dataset.py         +1   -52
    megatron/data/dataset_utils.py        +55  -2
    megatron/data/gpt2_dataset.py         +1   -2
    megatron/data/helpers.cpp             +43  -15
    megatron/data/realm_dataset.py        +1   -0
    megatron/data/realm_dataset_utils.py  +6   -3
    megatron/model/classification.py      +1   -2
    megatron/model/multiple_choice.py     +1   -2
    megatron/model/realm_model.py         +3   -4
    megatron/tokenizer/tokenizer.py       +9   -3
    megatron/training.py                  +1   -2
    megatron/utils.py                     +3   -2
    pretrain_bert.py                      +1   -2
    pretrain_bert_ict.py                  +1   -2
    pretrain_gpt2.py                      +1   -2
    tasks/eval_utils.py                   +1   -2
    tasks/finetune_utils.py               +1   -2
    tasks/glue/finetune.py                +1   -2
megatron/__init__.py  (+2 -2)

@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import torch

 from .package_info import (
     __description__,
@@ -30,7 +31,6 @@ from .global_vars import get_tensorboard_writer
 from .global_vars import get_adlr_autoresume
 from .global_vars import get_timers
-import torch

 def print_rank_0(message):
     """If distributed is initialized print only on rank 0."""
@@ -38,4 +38,4 @@ def print_rank_0(message):
     if torch.distributed.get_rank() == 0:
         print(message, flush=True)
     else:
-        print(message, flush=True)
\ No newline at end of file
+        print(message, flush=True)
megatron/checkpointing.py  (+1 -2)

@@ -23,9 +23,8 @@ import numpy as np
 import torch
 from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

-from megatron import mpu
+from megatron import mpu, print_rank_0
 from megatron import get_args
-from megatron import print_rank_0

 def check_checkpoint_args(checkpoint_args):
megatron/data/bert_dataset.py  (+1 -52)

@@ -22,15 +22,13 @@ import numpy as np
 import torch
 from torch.utils.data import Dataset

-from megatron import get_tokenizer, get_args
+from megatron import get_tokenizer, get_args, print_rank_0
 from megatron import mpu
-from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
 from megatron.data.dataset_utils import get_a_and_b_segments
 from megatron.data.dataset_utils import truncate_segments
 from megatron.data.dataset_utils import create_tokens_and_tokentypes
 from megatron.data.dataset_utils import pad_and_convert_to_numpy
 from megatron.data.dataset_utils import create_masked_lm_predictions
-from megatron import print_rank_0

 class BertDataset(Dataset):
@@ -85,55 +83,6 @@ class BertDataset(Dataset):
                                      self.masked_lm_prob, np_rng)

-def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
-    print_rank_0(' > building dataset index ...')
-    start_time = time.time()
-    indexed_dataset = make_indexed_dataset(data_prefix, data_impl, skip_warmup)
-    assert indexed_dataset.sizes.shape[0] == indexed_dataset.doc_idx[-1]
-    print_rank_0(' > finished creating indexed dataset in {:4f} '
-                 'seconds'.format(time.time() - start_time))
-    print_rank_0(' > indexed dataset stats:')
-    print_rank_0('    number of documents: {}'.format(
-        indexed_dataset.doc_idx.shape[0] - 1))
-    print_rank_0('    number of sentences: {}'.format(
-        indexed_dataset.sizes.shape[0]))
-    return indexed_dataset
-
-def get_train_valid_test_split_(splits_string, size):
-    """ Get dataset splits from comma or '/' separated string list."""
-    splits = []
-    if splits_string.find(',') != -1:
-        splits = [float(s) for s in splits_string.split(',')]
-    elif splits_string.find('/') != -1:
-        splits = [float(s) for s in splits_string.split('/')]
-    else:
-        splits = [float(splits_string)]
-    while len(splits) < 3:
-        splits.append(0.)
-    splits = splits[:3]
-    splits_sum = sum(splits)
-    assert splits_sum > 0.0
-    splits = [split / splits_sum for split in splits]
-    splits_index = [0]
-    for index, split in enumerate(splits):
-        splits_index.append(splits_index[index] +
-                            int(round(split * float(size))))
-    diff = splits_index[-1] - size
-    for index in range(1, len(splits_index)):
-        splits_index[index] -= diff
-    assert len(splits_index) == 4
-    assert splits_index[-1] == size
-    return splits_index
-
 def get_samples_mapping_(indexed_dataset,
                          data_prefix,
                          num_epochs,
megatron/data/dataset_utils.py  (+55 -2)

@@ -18,17 +18,19 @@
 #   https://github.com/google-research/albert/blob/master/create_pretraining_data.py
 # with some modifications.

+import time
 import collections
-import itertools

 import numpy as np
-from megatron import print_rank_0, get_args
+from megatron import get_args, print_rank_0
+from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset

 DSET_TYPE_STD = 'standard_bert'
 DSET_TYPE_ICT = 'ict'

 DSET_TYPES = [DSET_TYPE_ICT, DSET_TYPE_STD]

 def compile_helper():
     """Compile helper function ar runtime. Make sure this
     is invoked on a single process."""
@@ -447,3 +449,54 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
     test_dataset = build_dataset(2, 'test')

     return (train_dataset, valid_dataset, test_dataset)
+
+def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
+    print_rank_0(' > building dataset index ...')
+    start_time = time.time()
+    indexed_dataset = make_indexed_dataset(data_prefix, data_impl, skip_warmup)
+    assert indexed_dataset.sizes.shape[0] == indexed_dataset.doc_idx[-1]
+    print_rank_0(' > finished creating indexed dataset in {:4f} '
+                 'seconds'.format(time.time() - start_time))
+    print_rank_0(' > indexed dataset stats:')
+    print_rank_0('    number of documents: {}'.format(
+        indexed_dataset.doc_idx.shape[0] - 1))
+    print_rank_0('    number of sentences: {}'.format(
+        indexed_dataset.sizes.shape[0]))
+    return indexed_dataset
+
+def get_train_valid_test_split_(splits_string, size):
+    """ Get dataset splits from comma or '/' separated string list."""
+    splits = []
+    if splits_string.find(',') != -1:
+        splits = [float(s) for s in splits_string.split(',')]
+    elif splits_string.find('/') != -1:
+        splits = [float(s) for s in splits_string.split('/')]
+    else:
+        splits = [float(splits_string)]
+    while len(splits) < 3:
+        splits.append(0.)
+    splits = splits[:3]
+    splits_sum = sum(splits)
+    assert splits_sum > 0.0
+    splits = [split / splits_sum for split in splits]
+    splits_index = [0]
+    for index, split in enumerate(splits):
+        splits_index.append(splits_index[index] +
+                            int(round(split * float(size))))
+    diff = splits_index[-1] - size
+    for index in range(1, len(splits_index)):
+        splits_index[index] -= diff
+    assert len(splits_index) == 4
+    assert splits_index[-1] == size
+    return splits_index
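
Note on the split helper relocated into this file: get_train_valid_test_split_ turns a comma- or slash-separated splits string into four cumulative sample boundaries. A minimal sketch of its behaviour, with illustrative numbers worked out from the function body in the diff above:

    # Illustrative only; the values follow directly from get_train_valid_test_split_.
    from megatron.data.dataset_utils import get_train_valid_test_split_

    boundaries = get_train_valid_test_split_('949,50,1', 1000)
    # The weights are normalized to [0.949, 0.05, 0.001] and accumulated, so
    # boundaries == [0, 949, 999, 1000]:
    #   train -> samples [0, 949), valid -> [949, 999), test -> [999, 1000)
    print(boundaries)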
megatron/data/gpt2_dataset.py  (+1 -2)

@@ -21,8 +21,7 @@ import time
 import numpy as np
 import torch

-from megatron import print_rank_0
+from megatron import mpu, print_rank_0
-from megatron import mpu
 from megatron.data.bert_dataset import get_train_valid_test_split_
 from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
megatron/data/helpers.cpp  (+43 -15)

@@ -401,7 +401,8 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
                                     const uint64_t max_num_samples,
                                     const int32_t max_seq_length,
                                     const int32_t seed,
-                                    const bool verbose) {
+                                    const bool verbose,
+                                    const bool use_one_sent_blocks) {
     /* Build a mapping of (start-index, end-index, sequence-length) where
        start and end index are the indices of the sentences in the sample
        and sequence-length is the target sequence length.
@@ -442,6 +443,12 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
     int64_t num_samples = -1;
     DocIdx* maps = NULL;

+    // Acceptable number of sentences per block.
+    int min_num_sent = 2;
+    if (use_one_sent_blocks) {
+        min_num_sent = 1;
+    }
+
     // Perform two iterations, in the first iteration get the size
     // and allocate memory and in the second iteration populate the map.
     bool second = false;
@@ -453,6 +460,9 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
         // Current map index.
         uint64_t map_index = 0;
+        uint64_t empty_docs = 0;
+        uint64_t one_sent_docs = 0;
+        uint64_t long_sent_docs = 0;

         // For each epoch:
         for (int32_t epoch=0; epoch < num_epochs; ++epoch) {
             // assign every block a unique id
@@ -480,19 +490,31 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
                 // Remaining documents.
                 auto num_remain_sent = sent_index_last - sent_index_first;

+                // Some bookkeeping
+                if ((epoch == 0) && (!second)) {
+                    if (num_remain_sent == 0) {
+                        ++empty_docs;
+                    }
+                    if (num_remain_sent == 1) {
+                        ++one_sent_docs;
+                    }
+                }
+
                 // Detect documents with long sentences.
                 bool contains_long_sentence = false;
-                if (num_remain_sent > 1) {
+                if (num_remain_sent >= min_num_sent) {
                     for (auto sent_index=sent_index_first;
                          sent_index < sent_index_last; ++sent_index) {
                         if (sizes[sent_index] > LONG_SENTENCE_LEN){
+                            if ((epoch == 0) && (!second)) {
+                                ++long_sent_docs;
+                            }
                             contains_long_sentence = true;
                             break;
                         }
                     }
                 }
-                // If we have more than two sentences.
-                if ((num_remain_sent > 1) && (!contains_long_sentence)) {
+                // If we have enough sentences and no long sentences.
+                if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) {
                     // Set values.
                     auto seq_len = int32_t{0};
@@ -508,12 +530,12 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
                         --num_remain_sent;

                         // If we have reached the target length.
-                        // and if not only one sentence is left in the document.
-                        // and if we have at least two sentneces.
+                        // and there are an acceptable number of sentences left
+                        // and if we have at least the minimum number of sentences.
                         // or if we have reached end of the document.
                         if (((seq_len >= target_seq_len) &&
-                             (num_remain_sent > 1) &&
-                             (num_sent > 1) ) || (num_remain_sent == 0)) {
+                             (num_remain_sent >= min_num_sent) &&
+                             (num_sent >= min_num_sent) ) || (num_remain_sent == 0)) {
                             // Populate the map.
                             if (second) {
@@ -538,11 +560,16 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
                 } // for (auto sent_index=sent_index_first; ...
             } // if (num_remain_sent > 1) {
         } // for (int doc=0; doc < num_docs; ++doc) {
+        block_id = 0;
     } // for (int epoch=0; epoch < num_epochs; ++epoch) {

     if (!second) {
         if (verbose) {
+            cout << "   number of empty documents: " << empty_docs << endl << std::flush;
+            cout << "   number of documents with one sentence: " << one_sent_docs << endl << std::flush;
+            cout << "   number of documents with long sentences: " << long_sent_docs << endl << std::flush;
             cout << "    will create mapping for " << map_index <<
                 " samples" << endl << std::flush;
         }
@@ -554,9 +581,9 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
     } // for (int iteration=0; iteration < 2; ++iteration) {

     // Shuffle.
     // We need a 64 bit random number generator as we might have more
     // than 2 billion samples.
     std::mt19937_64 rand64_gen(seed + 1);
     for (auto i=(num_samples - 1); i > 0; --i) {
         const auto j = static_cast<int64_t>(rand64_gen() % (i + 1));
@@ -591,20 +618,21 @@ py::array build_blocks_mapping(const py::array_t<int64_t>& docs_,
                                const uint64_t max_num_samples,
                                const int max_seq_length,
                                const int seed,
-                               const bool verbose) {
+                               const bool verbose,
+                               const bool use_one_sent_blocks) {

    if (sizes_.size() > std::numeric_limits<uint32_t>::max()) {
       if (verbose) {
          cout << "    using uint64 for data mapping..." << endl << std::flush;
       }
       return build_blocks_mapping_impl<uint64_t>(docs_, sizes_, titles_sizes_,
-             num_epochs, max_num_samples, max_seq_length, seed, verbose);
+             num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks);
    } else {
       if (verbose) {
          cout << "    using uint32 for data mapping..." << endl << std::flush;
       }
       return build_blocks_mapping_impl<uint32_t>(docs_, sizes_, titles_sizes_,
-             num_epochs, max_num_samples, max_seq_length, seed, verbose);
+             num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks);
    }
}
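
The new use_one_sent_blocks flag flows from the pybind entry point build_blocks_mapping into the templated implementation, where it relaxes min_num_sent from 2 to 1 so that single-sentence documents can still produce blocks. A hedged sketch of calling the rebuilt helper from Python follows; the extension module path (megatron.data.helpers), array dtypes, and toy values are assumptions here, only the positional argument order comes from the diff above:

    # Illustrative only: toy inputs for the compiled C++ helper.
    import numpy as np
    from megatron.data import helpers   # assumed module path (built via compile_helper())

    # 3 documents spanning 6 sentences; docs holds sentence-index boundaries per document.
    docs = np.array([0, 2, 3, 6], dtype=np.int64)
    sizes = np.array([20, 15, 40, 10, 12, 25], dtype=np.int32)
    titles_sizes = np.array([3, 4, 2], dtype=np.int32)

    mapping = helpers.build_blocks_mapping(
        docs, sizes, titles_sizes,
        1,      # num_epochs
        10,     # max_num_samples
        64,     # max_seq_length
        1234,   # seed
        True,   # verbose
        True)   # use_one_sent_blocks: the one-sentence document above now yields a block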
megatron/data/realm_dataset.py  (+1 -0)

@@ -65,6 +65,7 @@ class ICTDataset(Dataset):
         query_tokens, query_pad_mask = self.concat_and_pad_tokens(query)
         block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
+        print(self.tokenizer.decode_token_ids(block_tokens), '\n')

         block_data = np.array([start_idx, end_idx, doc_idx, block_idx]).astype(np.int64)
         sample = {
megatron/data/realm_dataset_utils.py  (+6 -3)

@@ -4,7 +4,7 @@ import time
 import numpy as np
 import torch

-from megatron import print_rank_0, mpu
+from megatron import mpu, print_rank_0

 def join_str_list(str_list):
@@ -19,7 +19,7 @@ def join_str_list(str_list):
 def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epochs,
-                              max_num_samples, max_seq_length, seed, name):
+                              max_num_samples, max_seq_length, seed, name, use_one_sent_docs=False):
     """Get samples mapping for a dataset over fixed size blocks. This function also requires
     a dataset of the titles for the source documents since their lengths must be taken into account."""

     if not num_epochs:
@@ -39,6 +39,8 @@ def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epo
     indexmap_filename += '_{}mns'.format(max_num_samples)
     indexmap_filename += '_{}msl'.format(max_seq_length)
     indexmap_filename += '_{}s'.format(seed)
+    if use_one_sent_docs:
+        indexmap_filename += '_1sentok'
     indexmap_filename += '.npy'

     # Build the indexed mapping if not exist.
@@ -67,7 +69,8 @@ def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epo
                 max_num_samples,
                 max_seq_length - 3,  # account for added tokens
                 seed,
-                verbose)
+                verbose,
+                use_one_sent_docs)
             print_rank_0(' > done building samples index mapping')
             np.save(indexmap_filename, samples_mapping, allow_pickle=True)
             print_rank_0(' > saved the index mapping in {}'.format(
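
One practical effect of the change above: when use_one_sent_docs is set, the cached samples-mapping file gets a distinct '_1sentok' suffix, so an index map built with single-sentence blocks is not silently reused for the default two-sentence setting (and vice versa). A small sketch with illustrative values:

    # Illustrative values; only the suffix logic mirrors the diff above.
    max_num_samples, max_seq_length, seed, use_one_sent_docs = 1000000, 288, 1234, True

    indexmap_filename = 'wiki_ict_indexmap'       # placeholder for the unchanged prefix parts
    indexmap_filename += '_{}mns'.format(max_num_samples)
    indexmap_filename += '_{}msl'.format(max_seq_length)
    indexmap_filename += '_{}s'.format(seed)
    if use_one_sent_docs:
        indexmap_filename += '_1sentok'
    indexmap_filename += '.npy'
    # -> 'wiki_ict_indexmap_1000000mns_288msl_1234s_1sentok.npy'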
megatron/model/classification.py  (+1 -2)

@@ -17,7 +17,7 @@
 import torch

-from megatron import get_args
+from megatron import get_args, print_rank_0
 from megatron.model.bert_model import bert_attention_mask_func
 from megatron.model.bert_model import bert_extended_attention_mask
 from megatron.model.bert_model import bert_position_ids
@@ -26,7 +26,6 @@ from megatron.model.utils import get_linear_layer
 from megatron.model.utils import init_method_normal
 from megatron.model.utils import scaled_init_method_normal
 from megatron.module import MegatronModule
-from megatron import print_rank_0

 class Classification(MegatronModule):
megatron/model/multiple_choice.py  (+1 -2)

@@ -17,7 +17,7 @@
 import torch

-from megatron import get_args
+from megatron import get_args, print_rank_0
 from megatron.model.bert_model import bert_attention_mask_func
 from megatron.model.bert_model import bert_extended_attention_mask
 from megatron.model.bert_model import bert_position_ids
@@ -26,7 +26,6 @@ from megatron.model.utils import get_linear_layer
 from megatron.model.utils import init_method_normal
 from megatron.model.utils import scaled_init_method_normal
 from megatron.module import MegatronModule
-from megatron import print_rank_0

 class MultipleChoice(MegatronModule):
megatron/model/realm_model.py  (+3 -4)

@@ -125,7 +125,7 @@ class ICTBertModel(MegatronModule):
 class IREncoderBertModel(MegatronModule):
-    """Bert Language model."""
+    """BERT-based encoder for queries or blocks used for learned information retrieval."""

     def __init__(self, ict_head_size, num_tokentypes=2, parallel_output=True):
         super(IREncoderBertModel, self).__init__()
         args = get_args()
@@ -158,9 +158,8 @@ class IREncoderBertModel(MegatronModule):
                                 tokentype_ids=tokentype_ids)

         # Output.
-        if self.add_ict_head:
-            ict_logits = self.ict_head(pooled_output)
-            return ict_logits, None
+        ict_logits = self.ict_head(pooled_output)
+        return ict_logits, None

     def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                        keep_vars=False):
megatron/tokenizer/tokenizer.py  (+9 -3)

@@ -20,7 +20,6 @@ from abc import abstractmethod
 from .bert_tokenization import FullTokenizer as FullBertTokenizer
 from .gpt2_tokenization import GPT2Tokenizer
-from megatron.data.realm_dataset_utils import join_str_list

 def build_tokenizer(args):
@@ -160,8 +159,15 @@ class _BertWordPieceTokenizer(AbstractTokenizer):
         tokens = self.tokenizer.convert_ids_to_tokens(token_ids)
         exclude_list = ['[PAD]', '[CLS]']
         non_pads = [t for t in tokens if t not in exclude_list]
-        joined_strs = join_str_list(non_pads)
-        return joined_strs
+
+        result = ""
+        for s in non_pads:
+            if s.startswith("##"):
+                result += s[2:]
+            else:
+                result += " " + s
+
+        return result

     @property
     def cls(self):
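
The decode_token_ids change drops the dependency on realm_dataset_utils.join_str_list and re-joins WordPiece pieces inline instead. A minimal sketch of what the new loop produces, with a made-up token list:

    # Illustrative input; mirrors the inline loop added to decode_token_ids above.
    non_pads = ['the', 'quick', 'bro', '##wn', 'fox', '[SEP]']

    result = ""
    for s in non_pads:
        if s.startswith("##"):
            result += s[2:]       # glue sub-word pieces onto the previous token
        else:
            result += " " + s     # start a new whitespace-separated token

    print(result)   # ' the quick brown fox [SEP]'  (note the leading space)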
megatron/training.py  (+1 -2)

@@ -22,11 +22,10 @@ import torch
 from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
 from apex.optimizers import FusedAdam as Adam

-from megatron import get_args
+from megatron import get_args, print_rank_0
 from megatron import get_timers
 from megatron import get_tensorboard_writer
 from megatron import mpu
-from megatron import print_rank_0
 from megatron.checkpointing import load_checkpoint
 from megatron.checkpointing import save_checkpoint
 from megatron.fp16 import FP16_Module
megatron/utils.py  (+3 -2)

@@ -19,10 +19,9 @@ import sys
 import torch

-from megatron import get_args
+from megatron import get_args, print_rank_0
 from megatron import get_adlr_autoresume
 from megatron import mpu
-from megatron import print_rank_0
 from megatron.checkpointing import save_checkpoint
 from megatron.data.samplers import DistributedBatchSampler
 from megatron.fp16 import FP16_Optimizer
@@ -173,3 +172,5 @@ def get_ltor_masks_and_position_ids(data,
     attention_mask = (attention_mask < 0.5)

     return attention_mask, loss_mask, position_ids
pretrain_bert.py  (+1 -2)

@@ -18,10 +18,9 @@
 import torch
 import torch.nn.functional as F

-from megatron import get_args
+from megatron import get_args, print_rank_0
 from megatron import get_timers
 from megatron import mpu
-from megatron import print_rank_0
 from megatron.data.dataset_utils import build_train_valid_test_datasets
 from megatron.model import BertModel
 from megatron.training import pretrain
pretrain_bert_ict.py  (+1 -2)

@@ -19,10 +19,9 @@ import torch
 import torch.distributed as dist
 import torch.nn.functional as F

-from megatron import get_args
+from megatron import get_args, print_rank_0
 from megatron import get_timers
 from megatron import mpu
-from megatron import print_rank_0
 from megatron.data.dataset_utils import build_train_valid_test_datasets
 from megatron.model import ICTBertModel
 from megatron.training import pretrain
pretrain_gpt2.py  (+1 -2)

@@ -17,11 +17,10 @@
 import torch

-from megatron import get_args
+from megatron import get_args, print_rank_0
 from megatron import get_timers
 from megatron import get_tokenizer
 from megatron import mpu
-from megatron import print_rank_0
 from megatron.data.gpt2_dataset import build_train_valid_test_datasets
 from megatron.model import GPT2Model
 from megatron.training import pretrain
tasks/eval_utils.py  (+1 -2)

@@ -20,9 +20,8 @@ import time
 import torch

-from megatron import get_args
+from megatron import get_args, print_rank_0
 from megatron import mpu
-from megatron import print_rank_0
 from tasks.finetune_utils import build_data_loader
 from tasks.finetune_utils import process_batch
tasks/finetune_utils.py  (+1 -2)

@@ -17,10 +17,9 @@
 import torch

-from megatron import get_args
+from megatron import get_args, print_rank_0
 from megatron import get_timers
 from megatron import mpu
-from megatron import print_rank_0
 from megatron.checkpointing import load_checkpoint
 from megatron.checkpointing import save_checkpoint
 from megatron.training import evaluate_and_print_results
tasks/glue/finetune.py  (+1 -2)

@@ -15,9 +15,8 @@
 """GLUE finetuning/evaluation."""

-from megatron import get_args
+from megatron import get_args, print_rank_0
 from megatron import get_tokenizer
-from megatron import print_rank_0
 from megatron.model.classification import Classification
 from tasks.eval_utils import accuracy_func_provider
 from tasks.finetune_utils import finetune