OpenDAS / Megatron-LM / Commits

Commit b1efc33d, authored Mar 26, 2020 by Neel Kant
Modify pretrain_bert_ict.py to work with ICTBertModel
Parent: 371d2ea9
Showing 2 changed files with 28 additions and 32 deletions (+28 / -32):

  megatron/mpu/data.py     +1  -1
  pretrain_bert_ict.py     +27 -31
megatron/mpu/data.py

@@ -78,7 +78,7 @@ def broadcast_data(keys, data, datatype):
     members of the same model parallel group.

     Arguments:
-        keys: list of keys in the data disctionary to be broadcasted
+        keys: list of keys in the data dictionary to be broadcasted
         data: data dictionary of string keys and cpu tensor values.
         datatype: torch data type of all tensors in data associated
                   with keys.
pretrain_bert_ict.py

@@ -20,7 +20,7 @@ import torch.nn.functional as F
 from configure_data import configure_data
 from megatron import mpu
-from megatron.model import BertModel
+from megatron.model import ICTBertModel
 from megatron.utils import print_rank_0
 from megatron.utils import reduce_losses
 from megatron.utils import vocab_size_with_padding
@@ -30,9 +30,9 @@ from megatron.training import run
 def model_provider(args):
     """Build the model."""

-    print_rank_0('building BERT model ...')
+    print_rank_0('building BERT models ...')

-    model = BertModel(
+    model = ICTBertModel(
         num_layers=args.num_layers,
         vocab_size=args.vocab_size,
         hidden_size=args.hidden_size,
@@ -42,8 +42,8 @@ def model_provider(args):
         output_dropout_prob=args.hidden_dropout,
         max_sequence_length=args.max_position_embeddings,
         checkpoint_activations=args.checkpoint_activations,
+        ict_head_size=128,
         checkpoint_num_layers=args.checkpoint_num_layers,
-        add_binary_head=True,
         layernorm_epsilon=args.layernorm_epsilon,
         num_tokentypes=args.tokentype_size,
         parallel_output=True,
@@ -56,27 +56,30 @@ def model_provider(args):
...
@@ -56,27 +56,30 @@ def model_provider(args):
def
get_batch
(
data_iterator
,
timers
):
def
get_batch
(
data_iterator
,
timers
):
# Items and their type.
# Items and their type.
keys
=
[
'text'
,
'types'
,
'is_random'
,
'mask'
,
'mask_labels'
,
'pad_mask'
]
keys
=
[
'input_text'
,
'input_types'
,
'input_pad_mask'
,
'context_text'
,
'context_types'
,
'context_pad_mask'
]
datatype
=
torch
.
int64
datatype
=
torch
.
int64
# Broadcast data.
# Broadcast data.
timers
(
'data loader'
).
start
()
timers
(
'data loader'
).
start
()
if
data_iterator
is
not
None
:
if
data_iterator
is
None
:
data
=
next
(
data_iterator
)
else
:
data
=
None
data
=
None
else
:
data
=
next
(
data_iterator
)
timers
(
'data loader'
).
stop
()
timers
(
'data loader'
).
stop
()
data_b
=
mpu
.
broadcast_data
(
keys
,
data
,
datatype
)
data_b
=
mpu
.
broadcast_data
(
keys
,
data
,
datatype
)
# Unpack.
# Unpack.
tokens
=
data_b
[
'text'
].
long
()
input_
tokens
=
data_b
[
'
input_
text'
].
long
()
types
=
data_b
[
'types'
].
long
()
input_
types
=
data_b
[
'
input_
types'
].
long
()
next_sentence
=
data_b
[
'is_random
'
].
long
()
input_pad_mask
=
data_b
[
'input_pad_mask
'
].
long
()
loss_mask
=
data_b
[
'mask
'
].
f
lo
at
()
context_tokens
=
data_b
[
'context_text
'
].
lo
ng
()
lm_labels
=
data_b
[
'mask_label
s'
].
long
()
context_types
=
data_b
[
'context_type
s'
].
long
()
padding
_mask
=
data_b
[
'pad_mask'
].
long
()
context_pad
_mask
=
data_b
[
'
context_
pad_mask'
].
long
()
return
tokens
,
types
,
next_sentence
,
loss_mask
,
lm_labels
,
padding_mask
return
input_tokens
,
input_types
,
input_pad_mask
,
\
context_tokens
,
context_types
,
context_pad_mask
def
forward_step
(
data_iterator
,
model
,
args
,
timers
):
def
forward_step
(
data_iterator
,
model
,
args
,
timers
):
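For reference, here is a minimal sketch of the batch layout the rewritten get_batch expects: six int64 CPU tensors keyed by the new strings, one set for the query sentence and one for its context block. The batch size, sequence length, vocabulary size, and mask values below are illustrative assumptions, not values taken from this commit.

import torch

batch_size, seq_len, vocab = 8, 128, 30522   # illustrative sizes, not from the commit
sample = {
    'input_text':       torch.randint(0, vocab, (batch_size, seq_len)),       # query sentence tokens
    'input_types':      torch.zeros(batch_size, seq_len, dtype=torch.int64),  # token type ids
    'input_pad_mask':   torch.zeros(batch_size, seq_len, dtype=torch.int64),  # padding mask (forward_step passes 1 - pad_mask)
    'context_text':     torch.randint(0, vocab, (batch_size, seq_len)),       # context block tokens
    'context_types':    torch.zeros(batch_size, seq_len, dtype=torch.int64),
    'context_pad_mask': torch.zeros(batch_size, seq_len, dtype=torch.int64),
}
# mpu.broadcast_data(keys, sample, torch.int64) would then send these tensors from
# rank zero to the other members of the model parallel group before unpacking.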
@@ -84,27 +87,20 @@ def forward_step(data_iterator, model, args, timers):

     # Get the batch.
     timers('batch generator').start()
-    tokens, types, next_sentence, loss_mask, lm_labels, padding_mask \
-        = get_batch(data_iterator, timers)
+    input_tokens, input_types, input_pad_mask, \
+        context_tokens, context_types, context_pad_mask = get_batch(data_iterator, timers)
     timers('batch generator').stop()

     # Forward model.
-    lm_logits, nsp_logits = model(tokens, 1 - padding_mask, tokentype_ids=types)
-
-    nsp_loss = F.cross_entropy(nsp_logits.view(-1, 2).contiguous().float(),
-                               next_sentence.view(-1).contiguous(),
-                               ignore_index=-1)
-
-    lm_loss_ = mpu.vocab_parallel_cross_entropy(lm_logits.contiguous().float(),
-                                                lm_labels.contiguous())
-    lm_loss = torch.sum(
-        lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
-
-    loss = lm_loss + nsp_loss
-
-    reduced_losses = reduce_losses([lm_loss, nsp_loss])
-
-    return loss, {'lm loss': reduced_losses[0], 'nsp loss': reduced_losses[1]}
+    retrieval_scores = model(input_tokens, 1 - input_pad_mask, input_types,
+                             context_tokens, 1 - context_pad_mask, context_types)
+
+    softmaxed = F.softmax(retrieval_scores, dim=0).float()
+    retrieval_loss = F.cross_entropy(softmaxed, torch.arange(softmaxed.size()[0]))
+
+    reduced_losses = reduce_losses([retrieval_loss])
+
+    return retrieval_loss, {'retrieval loss': reduced_losses[0]}


 def get_train_val_test_data(args):
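The new loss appears to be an in-batch retrieval objective: assuming retrieval_scores is a square batch-by-batch similarity matrix whose diagonal holds the correct query/context pairs, torch.arange supplies the target index for each example. A minimal standalone sketch of the same computation on random scores follows; the square-matrix interpretation and the batch size of 4 are illustrative assumptions, not statements from the commit.

import torch
import torch.nn.functional as F

# Toy stand-in for retrieval_scores: entry (i, j) is read as a similarity
# between example i and example j, with correct pairs on the diagonal.
retrieval_scores = torch.randn(4, 4)

# Same computation as the new forward_step above.
softmaxed = F.softmax(retrieval_scores, dim=0).float()
retrieval_loss = F.cross_entropy(softmaxed, torch.arange(softmaxed.size()[0]))
print(retrieval_loss.item())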
@@ -152,5 +148,5 @@ def get_train_val_test_data(args):

 if __name__ == "__main__":

-    run('Pretrain BERT model', get_train_val_test_data,
+    run('Pretrain ICT BERT model', get_train_val_test_data,
         model_provider, forward_step)