OpenDAS / Megatron-LM · Commits

Commit 6f72a285, authored Jun 28, 2021 by zihanl (parent 1f7cfdf6)

add dialog dataset and special tokens in tokenizer

Showing 5 changed files with 297 additions and 3 deletions
dialog_ctrl/dialog_dataset.py      +121  −0
dialog_ctrl/utils.py                +38  −0
megatron/arguments.py               +17  −0
megatron/tokenizer/tokenizer.py     +15  −3
train_dialog_gpt.py                +106  −0
dialog_ctrl/dialog_dataset.py  (new file, 0 → 100644)
"""Build Dataset for Controllable Coversational Model"""
import
os
import
torch
import
numpy
as
np
from
megatron
import
get_tokenizer
def
read_data
(
tokenizer
,
data_path
,
train_module
):
"""read and tokenize dialog data"""
data_list
=
[]
with
open
(
data_path
,
"r"
)
as
f
:
for
i
,
line
in
enumerate
(
f
):
line
=
line
.
strip
()
splits
=
line
.
split
(
"
\t
"
)
length_split
=
len
(
splits
)
assert
length_split
==
2
or
length_split
==
3
or
length_split
==
4
if
train_module
==
"dialog"
:
dialog_context
=
splits
[
0
]
response
=
splits
[
-
1
]
# only take the last three turns in the dialog context
turns
=
dialog_context
.
split
(
" [SEP] "
)
turns
=
turns
[
-
3
:]
context
=
" [SEP] "
.
join
(
turns
)
input_ids
=
tokenizer
.
tokenize
(
context
)
output_ids
=
tokenizer
.
tokenize
(
response
)
data_list
.
append
({
"input_ids"
:
input_ids
,
"output_ids"
:
output_ids
})
elif
train_module
==
"control"
:
if
length_split
==
2
:
continue
dialog_context
=
splits
[
0
]
ctrl_sent
=
splits
[
-
2
]
ctrl_code
=
splits
[
1
]
if
length_split
==
4
else
None
turns
=
dialog_context
.
split
(
" [SEP] "
)
last_turn
=
turns
[
-
1
]
if
ctrl_code
:
inputs
=
last_turn
+
" [CTRL] "
+
ctrl_code
else
:
inputs
=
last_turn
outputs
=
ctrl_sent
input_ids
=
tokenizer
.
tokenize
(
inputs
)
output_ids
=
tokenizer
.
tokenize
(
outputs
)
data_list
.
append
({
"input_ids"
:
input_ids
,
"output_ids"
:
output_ids
})
else
:
raise
ValueError
(
"Please input a correct train-module name! (either dialog or cnotrol))"
)
return
data_list
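For orientation, a minimal sketch (not part of the commit) of the tab-separated layout that read_data assumes; the example lines and DummyTokenizer are hypothetical stand-ins for the real data files and Megatron tokenizer, and the column meanings are inferred from how the splits are indexed:

import tempfile

from dialog_ctrl.dialog_dataset import read_data   # assumes the repo root is on PYTHONPATH


class DummyTokenizer:
    """Stand-in for megatron.get_tokenizer(); just splits on whitespace."""
    def tokenize(self, text):
        return text.split()


# 2 columns: context \t response                                        (used by the dialog module, skipped by control)
# 4 columns: context \t control code \t control sentence \t response    (used by the control module)
lines = [
    "hi there [SEP] hello , how are you ?\tI am fine , thanks .",
    "any hobbies ? [SEP] do you like jazz ?\tjazz\tJazz originated in New Orleans .\tI do , jazz is great .",
]
with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as f:
    f.write("\n".join(lines))
    data_path = f.name

print(read_data(DummyTokenizer(), data_path, "dialog")[0])   # last 3 turns of context -> response
print(read_data(DummyTokenizer(), data_path, "control")[0])  # last turn + " [CTRL] " + code -> control sentence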
def data_shuffle(data, seed):
    # set random seed to make the shuffling reproducible
    np.random.seed(seed)
    np.random.shuffle(data)
    return data
class ControlDialogDataset(torch.utils.data.Dataset):

    def __init__(self, data, max_seq_len, pad_id, eod_id):
        # need to deal with padding, label masking
        self.data = data
        self.max_seq_len = max_seq_len
        self.pad_id = pad_id
        self.eod_id = eod_id

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data_dict = self.data[idx]
        input_ids, output_ids = data_dict["input_ids"], data_dict["output_ids"]
        assert len(input_ids) < self.max_seq_len, "Set a larger max_seq_len!"

        # length_of_loss_mask == length_of_text - 1
        text = input_ids + [self.pad_id] + output_ids + [self.eod_id]
        loss_mask = [0] * len(input_ids) + [1] * (len(output_ids) + 1)

        text_len = len(text)
        if text_len > self.max_seq_len:
            text = text[:self.max_seq_len]
            loss_mask = loss_mask[:self.max_seq_len - 1]
        else:
            text += [self.pad_id] * (self.max_seq_len - text_len)
            loss_mask += [0] * (self.max_seq_len - text_len)

        return {"text": np.array(text, dtype=np.int64),
                "loss_mask": np.array(loss_mask, dtype=np.int64)}
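A toy illustration (not from the commit) of the sequence layout __getitem__ produces; the token ids, pad_id=0 and eod_id=9 are made up:

from dialog_ctrl.dialog_dataset import ControlDialogDataset   # assumes repo root on PYTHONPATH

ds = ControlDialogDataset(
    data=[{"input_ids": [11, 12, 13], "output_ids": [21, 22]}],
    max_seq_len=10, pad_id=0, eod_id=9)

sample = ds[0]
print(sample["text"])       # [11 12 13  0 21 22  9  0  0  0] : context, [PAD] separator, response, EOD, padding
print(sample["loss_mask"])  # [ 0  0  0  1  1  1  0  0  0]    : length max_seq_len - 1, ones over response + EOD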
def build_train_valid_test_datasets(data_folder, dataset_name, train_module,
                                    max_seq_len, seed):
    """Build train, valid, and test datasets."""

    dataname_dict = {"wizard_of_wikipedia": {
        "train": "train_entity_based_control.txt",
        "valid": "valid_random_split_entity_based_control.txt",
        "test": "test_random_split_entity_based_control.txt"}}

    train_data_path = os.path.join(
        data_folder, dataset_name + "/processed/" + dataname_dict[dataset_name]["train"])
    valid_data_path = os.path.join(
        data_folder, dataset_name + "/processed/" + dataname_dict[dataset_name]["valid"])
    test_data_path = os.path.join(
        data_folder, dataset_name + "/processed/" + dataname_dict[dataset_name]["test"])

    tokenizer = get_tokenizer()
    train_data_list = read_data(tokenizer, train_data_path, train_module)
    valid_data_list = read_data(tokenizer, valid_data_path, train_module)
    test_data_list = read_data(tokenizer, test_data_path, train_module)

    # shuffle the training data
    train_data_list = data_shuffle(train_data_list, seed)

    # build train, valid, and test datasets
    train_dataset = ControlDialogDataset(train_data_list, max_seq_len,
                                         tokenizer.pad_id, tokenizer.eod_id)
    valid_dataset = ControlDialogDataset(valid_data_list, max_seq_len,
                                         tokenizer.pad_id, tokenizer.eod_id)
    test_dataset = ControlDialogDataset(test_data_list, max_seq_len,
                                        tokenizer.pad_id, tokenizer.eod_id)

    return (train_dataset, valid_dataset, test_dataset)
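Outside Megatron's pretrain() loop the returned objects are ordinary map-style datasets; a sketch, assuming Megatron's globals (args, tokenizer) are already initialized and that the hypothetical data folder follows the <dataset>/processed/ layout above:

from torch.utils.data import DataLoader

from dialog_ctrl.dialog_dataset import build_train_valid_test_datasets

train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
    data_folder="/path/to/data",            # hypothetical
    dataset_name="wizard_of_wikipedia",
    train_module="control",
    max_seq_len=1024,
    seed=1234)

loader = DataLoader(train_ds, batch_size=8)
batch = next(iter(loader))
print(batch["text"].shape, batch["loss_mask"].shape)   # torch.Size([8, 1024]) torch.Size([8, 1023])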
dialog_ctrl/utils.py  (new file, 0 → 100644)
import torch


def get_ltor_attention_masks_and_position_ids(data, eod_token_id):
    """Build attention masks and position ids for left-to-right model."""

    micro_batch_size, seq_length = data.size()

    # Attention mask.
    attention_mask = torch.tril(torch.ones(
        (micro_batch_size, seq_length, seq_length), device=data.device)).view(
            micro_batch_size, 1, seq_length, seq_length)

    # Position ids.
    position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)
    position_ids = position_ids.unsqueeze(0).expand_as(data)
    # Clone so the per-row resets below do not write through the expanded view.
    position_ids = position_ids.clone()

    # Reset attention mask and position ids.
    # Loop through the batches:
    for b in range(micro_batch_size):
        # Find indices where EOD token is.
        eod_index = position_ids[b, data[b] == eod_token_id]
        eod_index = eod_index.clone()

        # Loop through EOD indices:
        prev_index = 0
        for j in range(eod_index.size()[0]):
            i = eod_index[j]
            # Mask attention loss.
            attention_mask[b, 0, (i + 1):, :(i + 1)] = 0
            # Reset positions.
            position_ids[b, (i + 1):] -= (i + 1 - prev_index)
            prev_index = i + 1

    # Convert attention mask to binary:
    attention_mask = (attention_mask < 0.5)

    return attention_mask, position_ids
\ No newline at end of file
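A toy run (not part of the commit) showing the effect of an EOD token on the returned masks; the token ids and eod_token_id=9 are made up:

import torch

from dialog_ctrl.utils import get_ltor_attention_masks_and_position_ids

data = torch.tensor([[11, 12, 9, 13, 14, 15]])   # one row, EOD (id 9) at index 2
attention_mask, position_ids = get_ltor_attention_masks_and_position_ids(data, eod_token_id=9)

print(position_ids)              # tensor([[0, 1, 2, 0, 1, 2]]) -- positions restart after the EOD
print(attention_mask[0, 0].int())
# In the returned boolean mask, True means "blocked": the upper triangle (causal masking)
# plus rows 3..5 over columns 0..2, so the second segment cannot attend to the first.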
megatron/arguments.py
...
@@ -752,3 +752,20 @@ def _add_vit_args(parser):
                        help='patch dimension used in vit')
 
     return parser
+
+
+def _add_dialog_ctrl_args(parser):
+    group = parser.add_argument_group(title="dialog control")
+
+    group.add_argument('--train-module', type=str, default="",
+                       help='either control module or dialogue model (control or dialog)')
+    group.add_argument('--data-folder', type=str, default="",
+                       help='data folder (path of the data folder)')
+    group.add_argument('--dataset-name', type=str, default="",
+                       help='dataset name (e.g., wizard_of_wikipedia)')
+    group.add_argument('--max-seq-len', type=int, default=1024,
+                       help='maximum sequence length')
+    group.add_argument('--spec_toks', type=str, default="[SEP],[CTRL],[PAD]",
+                       help='additional special tokens')
+
+    return parser
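Inside Megatron these flags arrive through get_args() once _add_dialog_ctrl_args is wired into parse_args; a standalone sketch of what the new argument group parses (all values hypothetical):

import argparse

from megatron.arguments import _add_dialog_ctrl_args   # assumes repo root on PYTHONPATH

parser = _add_dialog_ctrl_args(argparse.ArgumentParser())
args = parser.parse_args([
    '--train-module', 'control',
    '--data-folder', '/path/to/data',           # hypothetical
    '--dataset-name', 'wizard_of_wikipedia',
    '--max-seq-len', '1024',
    '--spec_toks', '[SEP],[CTRL],[PAD]'])

print(args.train_module, args.dataset_name, args.max_seq_len, args.spec_toks)
# control wizard_of_wikipedia 1024 [SEP],[CTRL],[PAD]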
megatron/tokenizer/tokenizer.py
...
@@ -40,7 +40,7 @@ def build_tokenizer(args):
                                                   vocab_extra_ids=args.vocab_extra_ids)
     elif args.tokenizer_type == 'GPT2BPETokenizer':
         assert args.merge_file is not None
-        tokenizer = _GPT2BPETokenizer(args.vocab_file, args.merge_file)
+        tokenizer = _GPT2BPETokenizer(args.vocab_file, args.merge_file, special_tokens=args.spec_toks)
     else:
         raise NotImplementedError('{} tokenizer is not '
                                   'implemented.'.format(args.tokenizer_type))
...
@@ -260,13 +260,25 @@ class _BertWordPieceTokenizer(AbstractTokenizer):
 class _GPT2BPETokenizer(AbstractTokenizer):
     """Original GPT2 BPE tokenizer."""
 
-    def __init__(self, vocab_file, merge_file):
+    def __init__(self, vocab_file, merge_file, special_tokens=None):
         name = 'GPT2 BPE'
         super().__init__(name)
 
+        if special_tokens is not None:
+            # for controllable dialog, special_tokens: "[SEP],[CTRL],[PAD]"
+            special_tokens = special_tokens.split(",")
+        else:
+            special_tokens = []
         self.tokenizer = GPT2Tokenizer(vocab_file, merge_file, errors='replace',
-                                       special_tokens=[], max_len=None)
+                                       special_tokens=special_tokens, max_len=None)
         self.eod_id = self.tokenizer.encoder['<|endoftext|>']
+        if len(special_tokens) > 0:
+            if "[PAD]" in special_tokens:
+                self.pad_id = self.tokenizer.encoder['[PAD]']
+            if "[SEP]" in special_tokens:
+                self.sep_id = self.tokenizer.encoder['[SEP]']
+            if "[CTRL]" in special_tokens:
+                self.ctrl_id = self.tokenizer.encoder['[CTRL]']
 
     @property
     def vocab_size(self):
...
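How the pieces connect, as a hedged sketch: the vocab/merge paths are hypothetical, and the attribute lookups simply mirror the __init__ above (whether the vendored GPT2Tokenizer really exposes the registered special tokens through its encoder is assumed here, not verified):

from megatron.tokenizer.tokenizer import _GPT2BPETokenizer   # assumes repo root on PYTHONPATH

# Hypothetical paths to the standard GPT-2 BPE assets.
tok = _GPT2BPETokenizer('gpt2-vocab.json', 'gpt2-merges.txt',
                        special_tokens="[SEP],[CTRL],[PAD]")

# Set in __init__ above; pad_id/eod_id feed ControlDialogDataset, and [SEP]/[CTRL]
# match the markers used when building inputs in dialog_ctrl/dialog_dataset.py.
print(tok.eod_id, tok.pad_id, tok.sep_id, tok.ctrl_id)

# read_data() then calls tok.tokenize(...) on strings such as
# "do you like jazz ? [CTRL] jazz"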
train_dialog_gpt.py  (new file, 0 → 100644)
"""Train dialogue model based on GPT"""
import
torch
from
functools
import
partial
from
megatron
import
get_args
from
megatron
import
print_rank_0
from
megatron
import
get_timers
from
megatron
import
get_tokenizer
from
megatron
import
mpu
# from megatron.data.gpt_dataset import build_train_valid_test_datasets
from
dialog_ctrl.dialog_dataset
import
build_train_valid_test_datasets
from
megatron.model
import
GPTModel
from
megatron.training
import
pretrain
# from megatron.utils import get_ltor_masks_and_position_ids
from
dialog_ctrl.utils
import
get_ltor_attention_masks_and_position_ids
from
megatron.utils
import
average_losses_across_data_parallel_group
def
model_provider
(
pre_process
=
True
,
post_process
=
True
):
"""Build the model."""
print_rank_0
(
'building GPT model ...'
)
model
=
GPTModel
(
num_tokentypes
=
0
,
parallel_output
=
True
,
pre_process
=
pre_process
,
post_process
=
post_process
)
return
model
def
get_batch
(
data_iterator
):
"""Generate a batch"""
args
=
get_args
()
tokenizer
=
get_tokenizer
()
# Items and their type.
keys
=
[
'text'
,
'loss_mask'
]
datatype
=
torch
.
int64
# Broadcast data.
if
data_iterator
is
not
None
:
data
=
next
(
data_iterator
)
else
:
data
=
None
data_b
=
mpu
.
broadcast_data
(
keys
,
data
,
datatype
)
tokens_
=
data_b
[
'text'
].
long
()
labels
=
tokens_
[:,
1
:].
contiguous
()
tokens
=
tokens_
[:,
:
-
1
].
contiguous
()
loss_mask
=
data_b
[
'loss_mask'
].
float
()
# Get the attention_mask and postition ids.
attention_masks
,
position_ids
=
get_ltor_attention_masks_and_position_ids
(
tokens
,
tokenizer
.
eod_id
)
return
tokens
,
labels
,
loss_mask
,
attention_mask
,
position_ids
def
loss_func
(
loss_mask
,
output_tensor
):
losses
=
output_tensor
.
float
()
loss_mask
=
loss_mask
.
view
(
-
1
).
float
()
loss
=
torch
.
sum
(
losses
.
view
(
-
1
)
*
loss_mask
)
/
loss_mask
.
sum
()
# Reduce loss for logging.
averaged_loss
=
average_losses_across_data_parallel_group
([
loss
])
return
loss
,
{
'lm loss'
:
averaged_loss
[
0
]}
def
forward_step
(
data_iterator
,
model
):
"""Forward step."""
args
=
get_args
()
timers
=
get_timers
()
# Get the batch.
timers
(
'batch-generator'
).
start
()
tokens
,
labels
,
loss_mask
,
attention_mask
,
position_ids
=
get_batch
(
data_iterator
)
timers
(
'batch-generator'
).
stop
()
output_tensor
=
model
(
tokens
,
position_ids
,
attention_mask
,
labels
=
labels
)
return
output_tensor
,
partial
(
loss_func
,
loss_mask
)
def
train_valid_test_datasets_provider
():
"""Build train, valid, and test datasets for control module"""
args
=
get_args
()
print_rank_0
(
'> building train, validation, and test datasets for %s module ...'
%
args
.
train_module
)
train_ds
,
valid_ds
,
test_ds
=
build_train_valid_test_datasets
(
data_folder
=
args
.
data_folder
,
dataset_name
=
args
.
dataset_name
,
train_module
=
args
.
train_module
,
max_seq_len
=
args
.
max_seq_len
,
seed
=
args
.
seed
)
print_rank_0
(
"> finished creating datasets for %s module ..."
%
args
.
train_module
)
if
__name__
==
"__main__"
:
pretrain
(
train_valid_test_datasets_provider
,
model_provider
,
forward_step
,
args_defaults
=
{
'tokenizer_type'
:
'GPT2BPETokenizer'
})
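A toy illustration (not in the commit) of the input/label shift done in get_batch and the masked average computed in loss_func; the ids and per-token losses are made up:

import torch

text = torch.tensor([[11, 12, 13, 0, 21, 22, 9, 0]])        # context, [PAD] sep, response, EOD, padding
loss_mask = torch.tensor([[0, 0, 0, 1, 1, 1, 0]]).float()   # length len(text) - 1, ones over response + EOD

tokens = text[:, :-1].contiguous()   # model input
labels = text[:, 1:].contiguous()    # next-token targets, aligned with loss_mask

per_token_losses = torch.rand_like(loss_mask)   # stand-in for the LM output losses
loss = torch.sum(per_token_losses.view(-1) * loss_mask.view(-1)) / loss_mask.sum()
print(tokens.shape, labels.shape, float(loss))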