OpenDAS / Megatron-LM / Commits / dcff1acd

Commit dcff1acd, authored Jan 25, 2021 by Mohammad Shoeybi; committed by Jared Casper, Jan 25, 2021.
Adding option to remove the binary head for BERT
Parent: 1aa2e08a
Showing 5 changed files with 89 additions and 43 deletions.
megatron/arguments.py           +3   -0
megatron/data/bert_dataset.py   +23  -9
megatron/data/dataset_utils.py  +11  -7
megatron/data/helpers.cpp       +18  -8
pretrain_bert.py                +34  -19
megatron/arguments.py

@@ -262,6 +262,9 @@ def _add_network_size_args(parser):
                             'reasons.')
     group.add_argument('--onnx-safe', type=bool, required=False,
                        help='Use workarounds for known problems with Torch ONNX exporter')
+    group.add_argument('--bert-no-binary-head', action='store_false',
+                       help='Disable BERT binary head.',
+                       dest='bert_binary_head')

     return parser
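The new flag uses argparse's store_false inversion: the option name is negative (--bert-no-binary-head) while the stored attribute stays positive (args.bert_binary_head, defaulting to True), so existing configurations keep the head without any change. A minimal standalone sketch of that behavior (not Megatron's full parser setup):

import argparse

parser = argparse.ArgumentParser()
# Passing --bert-no-binary-head stores False into args.bert_binary_head;
# omitting it leaves the store_false default of True.
parser.add_argument('--bert-no-binary-head', action='store_false',
                    help='Disable BERT binary head.',
                    dest='bert_binary_head')

print(parser.parse_args([]).bert_binary_head)                         # True
print(parser.parse_args(['--bert-no-binary-head']).bert_binary_head)  # False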
megatron/data/bert_dataset.py

@@ -36,13 +36,14 @@ class BertDataset(Dataset):

     def __init__(self, name, indexed_dataset, data_prefix,
                  num_epochs, max_num_samples, masked_lm_prob,
-                 max_seq_length, short_seq_prob, seed):
+                 max_seq_length, short_seq_prob, seed, binary_head):

         # Params to store.
         self.name = name
         self.seed = seed
         self.masked_lm_prob = masked_lm_prob
         self.max_seq_length = max_seq_length
+        self.binary_head = binary_head

         # Dataset.
         self.indexed_dataset = indexed_dataset

@@ -55,7 +56,8 @@ class BertDataset(Dataset):
                                                       self.max_seq_length,
                                                       short_seq_prob,
                                                       self.seed,
-                                                      self.name)
+                                                      self.name,
+                                                      self.binary_head)

         # Vocab stuff.
         tokenizer = get_tokenizer()

@@ -81,7 +83,8 @@ class BertDataset(Dataset):
                                      self.vocab_id_to_token_dict,
                                      self.cls_id, self.sep_id,
                                      self.mask_id, self.pad_id,
-                                     self.masked_lm_prob, np_rng)
+                                     self.masked_lm_prob, np_rng,
+                                     self.binary_head)


 def get_samples_mapping_(indexed_dataset,

@@ -91,7 +94,8 @@ def get_samples_mapping_(indexed_dataset,
                          max_seq_length,
                          short_seq_prob,
                          seed,
-                         name):
+                         name,
+                         binary_head):
     if not num_epochs:
         if not max_num_samples:
             raise ValueError("Need to specify either max_num_samples "

@@ -137,7 +141,8 @@ def get_samples_mapping_(indexed_dataset,
             max_seq_length - 3,  # account for added tokens
             short_seq_prob,
             seed,
-            verbose)
+            verbose,
+            2 if binary_head else 1)
         print_rank_0(' > done building samples index mapping')
         np.save(indexmap_filename, samples_mapping, allow_pickle=True)
         print_rank_0(' > saved the index mapping in {}'.format(

@@ -173,7 +178,7 @@ def build_training_sample(sample,
                           target_seq_length, max_seq_length,
                           vocab_id_list, vocab_id_to_token_dict,
                           cls_id, sep_id, mask_id, pad_id,
-                          masked_lm_prob, np_rng):
+                          masked_lm_prob, np_rng, binary_head):
     """Build training sample.

     Arguments:

@@ -193,12 +198,21 @@ def build_training_sample(sample,
         the upper bound whereas the numpy one is exclusive.
     """

     # We assume that we have at least two sentences in the sample
-    assert len(sample) > 1
+    if binary_head:
+        assert len(sample) > 1
     assert target_seq_length <= max_seq_length

     # Divide sample into two segments (A and B).
-    tokens_a, tokens_b, is_next_random = get_a_and_b_segments(sample, np_rng)
+    if binary_head:
+        tokens_a, tokens_b, is_next_random = get_a_and_b_segments(sample,
+                                                                  np_rng)
+    else:
+        tokens_a = []
+        for j in range(len(sample)):
+            tokens_a.extend(sample[j])
+        tokens_b = []
+        is_next_random = False

     # Truncate to `target_sequence_length`.
     max_num_tokens = target_seq_length
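With the binary head disabled there is no sentence pair to build, so build_training_sample collapses every sentence of the sample into a single segment A and leaves segment B empty; correspondingly, the samples-index builder is told to require only 2 if binary_head else 1 sentences per sample. A standalone sketch of the new branch, using toy token ids in place of real vocabulary ids:

# Hypothetical sample: a list of per-sentence token-id lists.
sample = [[101, 2054], [2003, 2023], [6251, 102]]
binary_head = False

if binary_head:
    # Two segments plus a random-order flag, for the binary-head objective
    # (get_a_and_b_segments and np_rng come from the real dataset code).
    tokens_a, tokens_b, is_next_random = get_a_and_b_segments(sample, np_rng)
else:
    # Single segment: concatenate all sentences; no pair, no ordering label.
    tokens_a = []
    for sentence in sample:
        tokens_a.extend(sentence)
    tokens_b = []
    is_next_random = False

print(tokens_a)  # [101, 2054, 2003, 2023, 6251, 102]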
megatron/data/dataset_utils.py

@@ -114,7 +114,6 @@ def truncate_segments(tokens_a, tokens_b, len_a, len_b, max_num_tokens, np_rng):
     """Truncates a pair of sequences to a maximum sequence length."""
     #print(len_a, len_b, max_num_tokens)
     assert len_a > 0
-    assert len_b > 0
     if len_a + len_b <= max_num_tokens:
         return False
     while len_a + len_b > max_num_tokens:

@@ -150,10 +149,11 @@ def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id):
     for token in tokens_b:
         tokens.append(token)
         tokentypes.append(1)
     # [SEP].
-    tokens.append(sep_id)
-    tokentypes.append(1)
+    if tokens_b:
+        tokens.append(sep_id)
+        tokentypes.append(1)

     return tokens, tokentypes

@@ -392,6 +392,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                     train_valid_test_num_samples,
                                     max_seq_length, masked_lm_prob,
                                     short_seq_prob, seed, skip_warmup,
+                                    binary_head,
                                     dataset_type='standard_bert'):

     if len(data_prefix) == 1:

@@ -401,6 +402,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                                 max_seq_length, masked_lm_prob,
                                                 short_seq_prob, seed,
                                                 skip_warmup,
+                                                binary_head,
                                                 dataset_type=dataset_type)

     # Blending dataset.
     # Parse the values.

@@ -417,7 +419,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
             prefixes[i], data_impl, splits_string,
             datasets_train_valid_test_num_samples[i],
             max_seq_length, masked_lm_prob, short_seq_prob,
-            seed, skip_warmup, dataset_type=dataset_type)
+            seed, skip_warmup, binary_head, dataset_type=dataset_type)
         if train_ds:
             train_datasets.append(train_ds)
         if valid_ds:

@@ -444,6 +446,7 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                      train_valid_test_num_samples,
                                      max_seq_length, masked_lm_prob,
                                      short_seq_prob, seed, skip_warmup,
+                                     binary_head,
                                      dataset_type='standard_bert'):

     if dataset_type not in DSET_TYPES:

@@ -503,7 +506,8 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                 num_epochs=None,
                 max_num_samples=train_valid_test_num_samples[index],
                 max_seq_length=max_seq_length,
-                seed=seed
+                seed=seed,
+                binary_head=binary_head
             )
             if dataset_type == DSET_TYPE_ICT:
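The guard on the trailing [SEP] means a single-segment sample is laid out as [CLS] A... [SEP] with tokentype 0 throughout, while a paired sample keeps the usual [CLS] A... [SEP] B... [SEP] layout. A condensed re-implementation of the updated function's behavior (toy ids, not the file's actual append-loop code):

def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id):
    # [CLS] + segment A + [SEP], all tokentype 0.
    tokens = [cls_id] + list(tokens_a) + [sep_id]
    tokentypes = [0] * len(tokens)
    # Segment B and its closing [SEP] are emitted only when B is non-empty.
    if tokens_b:
        tokens += list(tokens_b) + [sep_id]
        tokentypes += [1] * (len(tokens_b) + 1)
    return tokens, tokentypes

print(create_tokens_and_tokentypes([7, 8], [9], cls_id=1, sep_id=2))
# ([1, 7, 8, 2, 9, 2], [0, 0, 0, 0, 1, 1])
print(create_tokens_and_tokentypes([7, 8], [], cls_id=1, sep_id=2))
# ([1, 7, 8, 2], [0, 0, 0, 0])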
megatron/data/helpers.cpp

@@ -189,6 +189,9 @@ inline int32_t get_target_sample_len(const int32_t short_seq_ratio,
                                      const int32_t max_length,
                                      std::mt19937& rand32_gen) {
     /* Training sample length. */
+    if (short_seq_ratio == 0) {
+      return max_length;
+    }
     const auto random_number = rand32_gen();
     if ((random_number % short_seq_ratio) == 0) {
       return 2 + random_number % (max_length - 1);

@@ -205,7 +208,8 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
                              const int32_t max_seq_length,
                              const double short_seq_prob,
                              const int32_t seed,
-                             const bool verbose) {
+                             const bool verbose,
+                             const int32_t min_num_sent) {
     /* Build a mapping of (start-index, end-index, sequence-length) where
        start and end index are the indices of the sentences in the sample
        and sequence-length is the target sequence length.

@@ -214,7 +218,7 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
     // Consistency checks.
     assert(num_epochs > 0);
     assert(max_seq_length > 1);
-    assert(short_seq_prob > 0.0);
+    assert(short_seq_prob >= 0.0);
     assert(short_seq_prob <= 1.0);
     assert(seed > 0);

@@ -223,7 +227,10 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
     auto sizes = sizes_.unchecked<1>();

     // For efficiency, convert probability to ratio. Note: rand() generates int.
-    const auto short_seq_ratio = static_cast<int32_t>(round(1.0 / short_seq_prob));
+    int32_t short_seq_ratio = 0;
+    if (short_seq_prob > 0) {
+      short_seq_ratio = static_cast<int32_t>(round(1.0 / short_seq_prob));
+    }
     if (verbose) {
         const auto sent_start_index = docs[0];

@@ -322,7 +329,7 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
             }
             // If we have more than two sentences.
-            if ((num_remain_sent > 1) && (!contains_long_sentence)) {
+            if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) {
                 // Set values.
                 auto seq_len = int32_t{0};

@@ -346,7 +353,7 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
                     // and if we have reached end of the document.
                     if (((seq_len >= target_seq_len) &&
                          (num_remain_sent > 1) &&
-                         (num_sent > 1) ) || (num_remain_sent == 0)) {
+                         (num_sent >= min_num_sent) ) || (num_remain_sent == 0)) {
                         // Check for overflow.
                         if ((3 * map_index + 2) >

@@ -437,7 +444,8 @@ py::array build_mapping(const py::array_t<int64_t>& docs_,
                         const int max_seq_length,
                         const double short_seq_prob,
                         const int seed,
-                        const bool verbose) {
+                        const bool verbose,
+                        const int32_t min_num_sent) {

     if (sizes_.size() > std::numeric_limits<uint32_t>::max()) {
        if (verbose) {

@@ -445,14 +453,16 @@ py::array build_mapping(const py::array_t<int64_t>& docs_,
        }
        return build_mapping_impl<uint64_t>(docs_, sizes_, num_epochs,
                                            max_num_samples, max_seq_length,
-                                           short_seq_prob, seed, verbose);
+                                           short_seq_prob, seed, verbose,
+                                           min_num_sent);
     } else {
        if (verbose) {
           cout << "   using uint32 for data mapping..." << endl << std::flush;
        }
        return build_mapping_impl<uint32_t>(docs_, sizes_, num_epochs,
                                            max_num_samples, max_seq_length,
-                                           short_seq_prob, seed, verbose);
+                                           short_seq_prob, seed, verbose,
+                                           min_num_sent);
     }
 }
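The C++ helper converts short_seq_prob into an integer ratio (roughly one in every 1/short_seq_prob samples gets a shortened target length), and the new guard makes short_seq_prob == 0 mean "never shorten" instead of dividing by zero. A rough Python equivalent of that sampling logic, written only to illustrate the guarded ratio:

import random

def get_target_sample_len(short_seq_prob, max_length, rng):
    # Probability -> integer ratio; 0 disables the short-sequence path.
    short_seq_ratio = round(1.0 / short_seq_prob) if short_seq_prob > 0 else 0
    if short_seq_ratio == 0:
        return max_length
    random_number = rng.getrandbits(32)
    # About one in short_seq_ratio samples gets a length in [2, max_length].
    if random_number % short_seq_ratio == 0:
        return 2 + random_number % (max_length - 1)
    return max_length

rng = random.Random(1234)
print(get_target_sample_len(0.0, 512, rng))  # 512: shortening disabled
print(get_target_sample_len(0.1, 512, rng))  # usually 512, occasionally shorter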
pretrain_bert.py

@@ -23,7 +23,10 @@ from megatron import print_rank_0
 from megatron import get_timers
 from megatron import mpu
 from megatron.data.dataset_utils import build_train_valid_test_datasets
-from megatron.model import BertModel, BertModelFirstStage, BertModelIntermediateStage, BertModelLastStage
+from megatron.model import (BertModel,
+                            BertModelFirstStage,
+                            BertModelIntermediateStage,
+                            BertModelLastStage)
 from megatron.training import pretrain
 from megatron.utils import average_losses_across_data_parallel_group

@@ -34,23 +37,24 @@ def model_provider():
     print_rank_0('building BERT model ...')

     args = get_args()
+    num_tokentypes = 2 if args.bert_binary_head else 0
     if mpu.get_pipeline_model_parallel_world_size() > 1:
         # Determine model based on position of stage in pipeline.
         if mpu.is_pipeline_first_stage():
             model = BertModelFirstStage(
-                num_tokentypes=2)
+                num_tokentypes=num_tokentypes)
         elif mpu.is_pipeline_last_stage():
             model = BertModelLastStage(
-                num_tokentypes=2,
-                add_binary_head=True,
+                num_tokentypes=num_tokentypes,
+                add_binary_head=args.bert_binary_head,
                 parallel_output=True)
         else:
             model = BertModelIntermediateStage(
-                num_tokentypes=2)
+                num_tokentypes=num_tokentypes)
     else:
         model = BertModel(
-            num_tokentypes=2,
-            add_binary_head=True,
+            num_tokentypes=num_tokentypes,
+            add_binary_head=args.bert_binary_head,
             parallel_output=True)

     return model
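The flag now drives both constructor arguments: add_binary_head follows args.bert_binary_head directly, and num_tokentypes drops from 2 to 0 because a single-segment input has no A/B token types to embed. A toy sketch of that wiring, with a plain dict standing in for the real BertModel constructors:

def pick_model_kwargs(bert_binary_head):
    # No binary head -> single segment -> no token-type embeddings needed.
    num_tokentypes = 2 if bert_binary_head else 0
    return {'num_tokentypes': num_tokentypes,
            'add_binary_head': bert_binary_head,
            'parallel_output': True}

print(pick_model_kwargs(True))
# {'num_tokentypes': 2, 'add_binary_head': True, 'parallel_output': True}
print(pick_model_kwargs(False))
# {'num_tokentypes': 0, 'add_binary_head': False, 'parallel_output': True}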
@@ -92,6 +96,9 @@ def forward_step(data_iterator, model, input_tensor):
         = get_batch(data_iterator)
     timers('batch-generator').stop()

+    if not args.bert_binary_head:
+        types = None
+
     # Forward pass through the model.
     if mpu.is_pipeline_first_stage():
         assert input_tensor is None

@@ -109,22 +116,29 @@ def forward_step(data_iterator, model, input_tensor):
     if mpu.is_pipeline_last_stage():
         lm_loss_, sop_logits = output_tensor

-        sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
-                                   sentence_order.view(-1),
-                                   ignore_index=-1)
-        sop_loss = sop_loss.float()
-
         lm_loss_ = lm_loss_.float()
         loss_mask = loss_mask.float()
         lm_loss = torch.sum(
             lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()

-        loss = lm_loss + sop_loss
-        averaged_losses = average_losses_across_data_parallel_group(
-            [lm_loss, sop_loss])
-        return loss, {'lm loss': averaged_losses[0],
-                      'sop loss': averaged_losses[1]}
+        if sop_logits is not None:
+            sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
+                                       sentence_order.view(-1),
+                                       ignore_index=-1)
+            sop_loss = sop_loss.float()
+            loss = lm_loss + sop_loss
+            averaged_losses = average_losses_across_data_parallel_group(
+                [lm_loss, sop_loss])
+            return loss, {'lm loss': averaged_losses[0],
+                          'sop loss': averaged_losses[1]}
+        else:
+            loss = lm_loss
+            averaged_losses = average_losses_across_data_parallel_group(
+                [lm_loss])
+            return loss, {'lm loss': averaged_losses[0]}

     return output_tensor
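When the binary head is off, the last pipeline stage sees sop_logits of None, so the sentence-order-prediction loss is computed and reported only when the logits exist. A condensed sketch of that control flow; the helper name and shapes are stand-ins, and it omits average_losses_across_data_parallel_group from the real code:

import torch
import torch.nn.functional as F

def compute_losses(lm_loss_, sop_logits, loss_mask, sentence_order):
    lm_loss_ = lm_loss_.float()
    loss_mask = loss_mask.float()
    # Masked-LM loss averaged over the unmasked positions.
    lm_loss = torch.sum(
        lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
    if sop_logits is not None:
        # Sentence-order loss exists only when the binary head produced logits.
        sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
                                   sentence_order.view(-1),
                                   ignore_index=-1).float()
        return lm_loss + sop_loss, {'lm loss': lm_loss, 'sop loss': sop_loss}
    return lm_loss, {'lm loss': lm_loss}

lm = torch.rand(4, 8)        # per-token LM losses
mask = torch.ones(4, 8)
loss, report = compute_losses(lm, None, mask, sentence_order=None)
print(report)                # {'lm loss': tensor(...)} -- no 'sop loss' key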
@@ -143,7 +157,8 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
         masked_lm_prob=args.mask_prob,
         short_seq_prob=args.short_seq_prob,
         seed=args.seed,
-        skip_warmup=(not args.mmap_warmup))
+        skip_warmup=(not args.mmap_warmup),
+        binary_head=args.bert_binary_head)
     print_rank_0("> finished creating BERT datasets ...")

     return train_ds, valid_ds, test_ds