OpenDAS / Megatron-LM

Commit 3c363d57 (parent d7509658), authored Jul 01, 2021 by zihanl

    update dialog dataset and training
Showing 6 changed files with 108 additions and 54 deletions:
    dialogctrl/dialog_dataset.py   +22  -12
    dialogctrl/utils.py            +10   -3
    megatron/arguments.py           +2   -0
    megatron/checkpointing.py      +15  -13
    megatron/training.py           +58  -25
    train_gpt_conv.py               +1   -1
dialogctrl/dialog_dataset.py
...
...
@@ -20,11 +20,16 @@ def read_data(tokenizer, data_path, train_module):
        assert length_split == 2 or length_split == 3 or length_split == 4
        if train_module == "dialog":
            # if length_split == 2:
            #     continue
            dialog_context = splits[0]
            if length_split > 2:
                ctrl_sent = splits[-2]
            response = splits[-1]
            # only take the last three turns in the dialog context
            turns = dialog_context.split(" [SEP] ")
            turns = turns[-3:]
            # turns = turns[-3:]
            # input_ids
            for idx, turn in enumerate(turns):
...
...
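
As a toy illustration of the parsing in the hunk above (a sketch, not part of the commit: the tab separator and the token strings are assumed; only the " [SEP] " join between turns is taken from the code itself):

line = "turn1 [SEP] turn2 [SEP] turn3 [SEP] turn4\tcontrol sentence\tresponse text"
splits = line.split("\t")
length_split = len(splits)                     # 3 -> dialog context, control sentence, response

dialog_context = splits[0]
ctrl_sent = splits[-2] if length_split > 2 else None
response = splits[-1]

turns = dialog_context.split(" [SEP] ")[-3:]   # keep only the last three turns
print(turns)       # ['turn2', 'turn3', 'turn4']
print(ctrl_sent)   # 'control sentence'
print(response)    # 'response text'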
@@ -33,6 +38,10 @@ def read_data(tokenizer, data_path, train_module):
                else:
                    input_ids.extend([tokenizer.sep_id] + tokenizer.tokenize(turn))
            if length_split > 2:
                # when there is control sentence, add it into the input_ids
                input_ids.extend([tokenizer.ctrl_id] + tokenizer.tokenize(ctrl_sent))
            # output_ids
            output_ids = tokenizer.tokenize(response)
...
...
@@ -65,7 +74,7 @@ def read_data(tokenizer, data_path, train_module):
            else:
                raise ValueError("Please input a correct train-module name! (either dialog or control)")
    return data_list
...
...
@@ -78,10 +87,11 @@ def data_shuffle(data, seed):
class ControlDialogDataset(torch.utils.data.Dataset):
-    def __init__(self, data, max_seq_len, pad_id, eod_id):
+    def __init__(self, data, max_seq_len, sep_id, pad_id, eod_id):
        # need to deal with padding, label masking
        self.data = data
        self.max_seq_len = max_seq_len
+        self.sep_id = sep_id
        self.pad_id = pad_id
        self.eod_id = eod_id
...
...
@@ -95,16 +105,16 @@ class ControlDialogDataset(torch.utils.data.Dataset):
        assert len(input_ids) < self.max_seq_len, "Set a larger max-seq-len!"
        # length_of_loss_mask == length_of_text - 1
-        text = input_ids + [self.pad_id] + output_ids + [self.eod_id]
+        text = input_ids + [self.sep_id] + output_ids + [self.eod_id]
        loss_mask = [0] * len(input_ids) + [1] * (len(output_ids) + 1)

        text_len = len(text)
-        if text_len > self.max_seq_len:
-            text = text[:self.max_seq_len]
-            loss_mask = loss_mask[:self.max_seq_len - 1]
+        if text_len > self.max_seq_len + 1:
+            text = text[:self.max_seq_len + 1]
+            loss_mask = loss_mask[:self.max_seq_len]
        else:
-            text += [self.pad_id] * (self.max_seq_len - text_len)
-            loss_mask += [0] * (self.max_seq_len - text_len)
+            text += [self.pad_id] * (self.max_seq_len + 1 - text_len)
+            loss_mask += [0] * (self.max_seq_len + 1 - text_len)

        return {"text": np.array(text, dtype=np.int64),
                "loss_mask": np.array(loss_mask, dtype=np.int64)}
...
...
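
The net effect of the __getitem__ changes above is that a sample now carries max_seq_len + 1 tokens (so tokens and labels can be offset by one) and the response is introduced by the separator token rather than a pad token. A minimal walk-through with made-up ids:

sep_id, pad_id, eod_id = 101, 0, 102     # made-up special-token ids
max_seq_len = 10

input_ids = [11, 12, 13]                 # dialog context tokens
output_ids = [21, 22]                    # response tokens

text = input_ids + [sep_id] + output_ids + [eod_id]
loss_mask = [0] * len(input_ids) + [1] * (len(output_ids) + 1)

text_len = len(text)                     # 7, so the else branch pads up to max_seq_len + 1
text += [pad_id] * (max_seq_len + 1 - text_len)
loss_mask += [0] * (max_seq_len + 1 - text_len)

print(text)        # [11, 12, 13, 101, 21, 22, 102, 0, 0, 0, 0]  -> 11 tokens
print(loss_mask)   # [0, 0, 0, 1, 1, 1, 0, 0, 0, 0]              -> 10 entries, loss on response + EOD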
@@ -127,8 +137,8 @@ def build_train_valid_test_datasets(data_folder, dataset_name, train_module, max
    train_data_list = data_shuffle(train_data_list, seed)
    # build train, valid, and test datasets
-    train_dataset = ControlDialogDataset(train_data_list, max_seq_len, tokenizer.pad_id, tokenizer.eod_id)
-    valid_dataset = ControlDialogDataset(valid_data_list, max_seq_len, tokenizer.pad_id, tokenizer.eod_id)
-    test_dataset = ControlDialogDataset(test_data_list, max_seq_len, tokenizer.pad_id, tokenizer.eod_id)
+    train_dataset = ControlDialogDataset(train_data_list, max_seq_len, sep_id=tokenizer.sep_id, pad_id=tokenizer.pad_id, eod_id=tokenizer.eod_id)
+    valid_dataset = ControlDialogDataset(valid_data_list, max_seq_len, sep_id=tokenizer.sep_id, pad_id=tokenizer.pad_id, eod_id=tokenizer.eod_id)
+    test_dataset = ControlDialogDataset(test_data_list, max_seq_len, sep_id=tokenizer.sep_id, pad_id=tokenizer.pad_id, eod_id=tokenizer.eod_id)
    return train_dataset, valid_dataset, test_dataset
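
A hedged usage sketch of the updated constructor call (the tokenizer stub and ids are hypothetical; the per-item structure expected by __getitem__ lives in the collapsed parts of the file):

from types import SimpleNamespace

from dialogctrl.dialog_dataset import ControlDialogDataset

tokenizer = SimpleNamespace(sep_id=101, pad_id=0, eod_id=102)   # stand-in for the real tokenizer
train_data_list = []                                            # normally produced by read_data(...)
max_seq_len = 512

train_dataset = ControlDialogDataset(train_data_list, max_seq_len,
                                     sep_id=tokenizer.sep_id,
                                     pad_id=tokenizer.pad_id,
                                     eod_id=tokenizer.eod_id)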
dialogctrl/utils.py
import torch

from megatron import print_rank_0


def get_ltor_attention_masks_and_position_ids(data, eod_token_id):
    """Build attention masks and position id for left to right model."""
...
...
@@ -10,12 +10,19 @@ def get_ltor_attention_masks_and_position_ids(data, eod_token_id):
    # Attention mask
    attention_mask = torch.tril(torch.ones(
        (micro_batch_size, seq_length, seq_length), device=data.device)).view(
            micro_batch_size, 1, seq_length, seq_length)

+    # mask padded tokens
+    for b in range(micro_batch_size):
+        for idx in range(seq_length - 1):
+            if data[b, idx] == eod_token_id:
+                # pad tokens that come after the eod token
+                attention_mask[b, 0, idx + 1:, :] = 0.0
+
    # Position ids.
    position_ids = torch.arange(seq_length, dtype=torch.long,
                                device=data.device)
    position_ids = position_ids.unsqueeze(0).expand_as(data)

-    # reset attentino mask and position ids
-    # Loop through the batches:
+    # # reset attentino mask and position ids
+    # # Loop through the batches:
    # for b in range(micro_batch_size):
    #     # Find indecies where EOD token is.
    #     eod_index = position_ids[b, data[b] == eod_token_id]
...
...
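
The added loop zeroes every attention-mask row that follows the first EOD token in a sequence, so padded positions cannot attend to anything. A standalone sketch with toy sizes and made-up token ids:

import torch

eod_token_id = 0
data = torch.tensor([[5, 6, 0, 9, 9]])        # one sequence of length 5, EOD at index 2
micro_batch_size, seq_length = data.size()

attention_mask = torch.tril(torch.ones(
    (micro_batch_size, seq_length, seq_length), device=data.device)).view(
        micro_batch_size, 1, seq_length, seq_length)

# same loop as in the diff: every position after the EOD token gets a fully zeroed mask row
for b in range(micro_batch_size):
    for idx in range(seq_length - 1):
        if data[b, idx] == eod_token_id:
            attention_mask[b, 0, idx + 1:, :] = 0.0

print(attention_mask[0, 0])   # rows 3 and 4 (the positions after EOD) are all zeros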
megatron/arguments.py
...
...
@@ -760,6 +760,8 @@ def _add_dialog_ctrl_args(parser):
    group.add_argument('--run-dialog', action='store_true',
                       help='run dialog modeling')
+    group.add_argument('--num-epoch', type=int, default=30,
+                       help='number of epochs to train the model')
    group.add_argument('--train-module', type=str, default="",
                       help='either control module or dialogue model (control or dialog)')
    group.add_argument('--data-folder', type=str, default="",
...
...
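
A standalone argparse sketch that mirrors the new flag (the group title is a placeholder; only --run-dialog and --num-epoch are taken from the diff):

import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='dialog control')   # placeholder group title
group.add_argument('--run-dialog', action='store_true', help='run dialog modeling')
group.add_argument('--num-epoch', type=int, default=30,
                   help='number of epochs to train the model')

args = parser.parse_args(['--run-dialog', '--num-epoch', '12'])
print(args.run_dialog, args.num_epoch)   # True 12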
megatron/checkpointing.py
...
...
@@ -344,19 +344,21 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
    print_rank_0(f' checkpoint version {checkpoint_version}')
    fix_query_key_value_ordering(model, checkpoint_version)

-    # Optimizer.
-    if not release and not args.finetune and not args.no_load_optim:
-        try:
-            if optimizer is not None:
-                optimizer.load_state_dict(state_dict['optimizer'])
-            if lr_scheduler is not None:
-                lr_scheduler.load_state_dict(state_dict['lr_scheduler'])
-        except KeyError:
-            print_rank_0('Unable to load optimizer from checkpoint {}. '
-                         'Specify --no-load-optim or --finetune to prevent '
-                         'attempting to load the optimizer state, '
-                         'exiting ...'.format(checkpoint_name))
-            sys.exit()
+    if not args.run_dialog:
+        # Original pre-train GPT setting
+        # Optimizer.
+        if not release and not args.finetune and not args.no_load_optim:
+            try:
+                if optimizer is not None:
+                    optimizer.load_state_dict(state_dict['optimizer'])
+                if lr_scheduler is not None:
+                    lr_scheduler.load_state_dict(state_dict['lr_scheduler'])
+            except KeyError:
+                print_rank_0('Unable to load optimizer from checkpoint {}. '
+                             'Specify --no-load-optim or --finetune to prevent '
+                             'attempting to load the optimizer state, '
+                             'exiting ...'.format(checkpoint_name))
+                sys.exit()

    # rng states.
    if not release and not args.finetune and not args.no_load_rng:
...
...
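
This is not Megatron's load_checkpoint itself, only a minimal sketch of the gate the commit adds: with --run-dialog set, the optimizer and LR-scheduler state in the checkpoint is simply not restored, as one would typically want when fine-tuning from a pre-trained GPT checkpoint. All names below are local stand-ins:

import torch

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
state_dict = {'model': model.state_dict(), 'optimizer': optimizer.state_dict()}

run_dialog = True                                   # stand-in for args.run_dialog
release = finetune = no_load_optim = False          # stand-ins for the other flags

if not run_dialog:
    if not release and not finetune and not no_load_optim:
        optimizer.load_state_dict(state_dict['optimizer'])
else:
    print('--run-dialog set: optimizer state in the checkpoint is ignored')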
megatron/training.py
...
...
@@ -138,27 +138,57 @@ def pretrain(train_valid_test_dataset_provider,
    print_rank_0('training ...')

    iteration = 0
-    if args.do_train and args.train_iters > 0:
-        iteration = train(forward_step_func, model, optimizer, lr_scheduler,
-                          train_data_iterator, valid_data_iterator)
-    print_datetime('after training is done')
-
-    if args.do_valid:
-        prefix = 'the end of training for val data'
-        evaluate_and_print_results(prefix, forward_step_func,
-                                   valid_data_iterator, model,
-                                   iteration, False)
-
-    if args.save and iteration != 0:
-        save_checkpoint(iteration, model, optimizer, lr_scheduler)
-
-    if args.do_test:
-        # Run on test data.
-        prefix = 'the end of training for test data'
-        evaluate_and_print_results(prefix, forward_step_func,
-                                   test_data_iterator, model,
-                                   0, True)
+    if not args.run_dialog:
+        # original pre-training for GPT
+        if args.do_train and args.train_iters > 0:
+            iteration = train(forward_step_func, model, optimizer, lr_scheduler,
+                              train_data_iterator, valid_data_iterator)
+        print_datetime('after training is done')
+
+        if args.do_valid:
+            prefix = 'the end of training for val data'
+            evaluate_and_print_results(prefix, forward_step_func,
+                                       valid_data_iterator, model,
+                                       iteration, False)
+
+        if args.save and iteration != 0:
+            save_checkpoint(iteration, model, optimizer, lr_scheduler)
+
+        if args.do_test:
+            # Run on test data.
+            prefix = 'the end of training for test data'
+            evaluate_and_print_results(prefix, forward_step_func,
+                                       test_data_iterator, model,
+                                       0, True)
+    else:
+        # training for dialog/control model
+        timers('interval-time').start()
+        # start timers('interval-time') here to avoid it from starting multiple times
+        for e in range(args.num_epoch):
+            print_rank_0('> training on epoch %d' % (e + 1))
+            if args.do_train and args.train_iters > 0:
+                iteration += train(forward_step_func, model, optimizer, lr_scheduler,
+                                   train_data_iterator, valid_data_iterator)
+                print_datetime('after training is done')
+
+            if args.do_valid:
+                prefix = 'the end of training for val data'
+                evaluate_and_print_results(prefix, forward_step_func,
+                                           valid_data_iterator, model,
+                                           iteration, False)
+
+            if e >= 8 and e <= 13 and args.save and iteration != 0:
+                save_checkpoint(iteration, model, optimizer, lr_scheduler)
+
+        if args.do_test:
+            # Run on test data.
+            prefix = 'the end of training for test data'
+            evaluate_and_print_results(prefix, forward_step_func,
+                                       test_data_iterator, model,
+                                       0, True)


def update_train_iters(args):
...
...
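
In the epoch loop above, checkpoints are written only for epochs 9 through 14 (e from 8 to 13, zero-based), and iteration is now accumulated with iteration += train(...); that only yields a meaningful running total if train() returns the number of steps taken in the call (upstream Megatron's train() returns the absolute iteration count). A bookkeeping sketch with a hypothetical per-epoch step count:

num_epoch = 30            # matches the new --num-epoch default
steps_per_epoch = 100     # hypothetical number of steps one train(...) call performs

iteration = 0
for e in range(num_epoch):
    iteration += steps_per_epoch          # stand-in for iteration += train(...)
    if e >= 8 and e <= 13:                # save window: epochs 9-14
        print('epoch %d: save_checkpoint at iteration %d' % (e + 1, iteration))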
@@ -611,7 +641,9 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
    # Iterations.
    iteration = args.iteration

-    timers('interval-time').start()
+    if not args.run_dialog:
+        timers('interval-time').start()
    print_datetime('before the start of training step')
    report_memory_flag = True
    while iteration < args.train_iters:
...
...
@@ -813,9 +845,10 @@ def build_train_valid_test_data_iterators(
        print_rank_0(' validation: {}'.format(valid_size))
        print_rank_0(' test: {}'.format(test_size))
-        args.train_iters = train_size // args.global_batch_size
-        args.eval_iters = valid_size // args.global_batch_size
-        args.test_iters = test_size // args.global_batch_size
+        batch_size = args.micro_batch_size * args.data_parallel_size
+        args.train_iters = train_size // batch_size + 1
+        args.eval_iters = valid_size // batch_size + 1
+        args.test_iters = test_size // batch_size + 1
    else:
        # Number of train/valid/test samples.
...
...
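
A worked example of the new iteration arithmetic in the hunk above (sizes are hypothetical): the per-step batch is now micro_batch_size * data_parallel_size rather than global_batch_size, with 1 added so a partial final batch still counts as an iteration:

train_size, valid_size, test_size = 100_000, 5_000, 5_000   # hypothetical sample counts
micro_batch_size, data_parallel_size = 4, 8

batch_size = micro_batch_size * data_parallel_size          # 32
train_iters = train_size // batch_size + 1                  # 3126
eval_iters = valid_size // batch_size + 1                   # 157
test_iters = test_size // batch_size + 1                    # 157
print(batch_size, train_iters, eval_iters, test_iters)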
train_gpt_conv.py
...
...
@@ -53,7 +53,7 @@ def get_batch(data_iterator):
    # Get the attention_mask and position ids.
    attention_mask, position_ids = get_ltor_attention_masks_and_position_ids(
        tokens, tokenizer.eod_id)

    return tokens, labels, loss_mask, attention_mask, position_ids
...
...