wuxk1 / Megatron-LM / Commits / 5aa3d5e3

Commit 5aa3d5e3, authored Jul 05, 2021 by zihanl
Parent: 3c363d57

    add dialctrl into tasks

10 changed files with 425 additions and 51 deletions (+425, -51)
Files changed:

    dialogctrl/dialog_dataset.py         +36   -12
    megatron/arguments.py                +20   -20
    megatron/checkpointing.py            +15   -15
    megatron/text_generation_utils.py     +9    -0
    megatron/tokenizer/tokenizer.py       +2    -1
    megatron/training.py                  +7    -3
    tasks/dialctrl/data.py              +168    -0
    tasks/dialctrl/finetune.py          +109    -0
    tasks/dialctrl/utils.py              +45    -0
    tasks/main.py                        +14    -0
dialogctrl/dialog_dataset.py
@@ -29,18 +29,22 @@ def read_data(tokenizer, data_path, train_module):
                response = splits[-1]

                # only take the last three turns in the dialog context
                turns = dialog_context.split(" [SEP] ")
                # turns = turns[-3:]
                turns = turns[-3:]

                # input_ids
                for idx, turn in enumerate(turns):
                    if not (turn.endswith("?") or turn.endswith(".") or turn.endswith("!")):
                        turn = turn + " ."
                    if idx == 0:
                        input_ids = tokenizer.tokenize(turn)
                    else:
                        input_ids.extend([tokenizer.sep_id] + tokenizer.tokenize(turn))
                        # input_ids.extend([tokenizer.sep_id] + tokenizer.tokenize(turn))
                        input_ids.extend(tokenizer.tokenize(turn))
                if length_split > 2:
                    # when there is control sentence, add it into the input_ids
                    input_ids.extend([tokenizer.ctrl_id] + tokenizer.tokenize(ctrl_sent))
                    # input_ids.extend([tokenizer.ctrl_id] + tokenizer.tokenize(ctrl_sent))
                    input_ids.extend(tokenizer.tokenize("( " + ctrl_sent + " ) ."))

                # output_ids
                output_ids = tokenizer.tokenize(response)

@@ -55,16 +59,35 @@ def read_data(tokenizer, data_path, train_module):
                ctrl_code = splits[1] if length_split == 4 else None

                turns = dialog_context.split(" [SEP] ")
                last_turn = turns[-1]

                # input_ids
                # last_turn = turns[-1]
                # turns = turns[-3:]
                # for idx, turn in enumerate(turns):
                #     if idx == 0:
                #         input_ids = tokenizer.tokenize(turn)
                #     else:
                #         # input_ids.extend([tokenizer.sep_id] + tokenizer.tokenize(turn))
                #         input_ids.extend(tokenizer.tokenize(turn))
                # # input_ids
                # if ctrl_code:
                #     ctrl_code_list = ctrl_code.split(" [CTRL] ")
                #     for code in ctrl_code_list:
                #         # input_ids.extend([tokenizer.ctrl_id] + tokenizer.tokenize(code))
                #         input_ids.extend(tokenizer.tokenize(code + " ."))

                # put control code at the begginning
                input_ids = []
                if ctrl_code:
                    input_ids = tokenizer.tokenize(last_turn)
                    ctrl_code_list = ctrl_code.split(" [CTRL] ")
                    for code in ctrl_code_list:
                        input_ids.extend([tokenizer.ctrl_id] + tokenizer.tokenize(code))
                else:
                    input_ids = tokenizer.tokenize(last_turn)
                        input_ids.extend(tokenizer.tokenize("( " + code + " )"))
                turns = turns[-3:]
                for turn in turns:
                    if not (turn.endswith("?") or turn.endswith(".") or turn.endswith("!")):
                        turn = turn + " ."
                    input_ids.extend(tokenizer.tokenize(turn))

                # output_ids
                outputs = ctrl_sent

@@ -105,8 +128,9 @@ class ControlDialogDataset(torch.utils.data.Dataset):
        assert len(input_ids) < self.max_seq_len, "Set a larger max-seq-len!"
        # length_of_loss_mask == length_of_text - 1
        text = input_ids + [self.sep_id] + output_ids + [self.eod_id]
        loss_mask = [0] * len(input_ids) + [1] * (len(output_ids) + 1)
        # text = input_ids + [self.sep_id] + output_ids + [self.eod_id]
        text = input_ids + output_ids + [self.eod_id]
        loss_mask = [0] * (len(input_ids) - 1) + [1] * (len(output_ids) + 1)
        text_len = len(text)
        if text_len > self.max_seq_len + 1:
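The last hunk drops the explicit [SEP] between context and response and shortens the context side of the loss mask by one, so the mask has length len(text) - 1. A minimal standalone sketch (toy token ids, no Megatron dependencies; all id values are made up) of why that length lines up with the one-token shift between inputs and labels performed in tasks/dialctrl/finetune.py:

```python
# Toy token ids (hypothetical); no Megatron imports needed.
input_ids = [11, 12, 13]       # dialog context
output_ids = [21, 22]          # response
eod_id = 50256                 # GPT-2 <|endoftext|>

text = input_ids + output_ids + [eod_id]
loss_mask = [0] * (len(input_ids) - 1) + [1] * (len(output_ids) + 1)

tokens = text[:-1]             # model inputs (as in process_batch: tokens_[:, :-1])
labels = text[1:]              # next-token targets (tokens_[:, 1:])
assert len(loss_mask) == len(tokens) == len(labels)

for tok, lab, m in zip(tokens, labels, loss_mask):
    print(f"input={tok:>5}  target={lab:>5}  counted_in_loss={bool(m)}")
# Only predictions of the response tokens and the final eod are counted;
# the purely contextual positions are masked out.
```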
megatron/arguments.py
@@ -41,7 +41,7 @@ def parse_args(extra_args_provider=None, defaults={},
    parser = _add_biencoder_args(parser)
    parser = _add_vit_args(parser)
    parser = _add_logging_args(parser)
    parser = _add_dialog_ctrl_args(parser)
    # parser = _add_dialog_ctrl_args(parser)

    # Custom arguments.
    if extra_args_provider is not None:

@@ -755,22 +755,22 @@ def _add_vit_args(parser):
    return parser


def _add_dialog_ctrl_args(parser):
    group = parser.add_argument_group(title="dialog control")

    group.add_argument('--run-dialog', action='store_true',
                       help='run dialog modeling')
    group.add_argument('--num-epoch', type=int, default=30,
                       help='number of epoches to train the model')
    group.add_argument('--train-module', type=str, default="",
                       help='either control module or dialogue model (control or dialog)')
    group.add_argument('--data-folder', type=str, default="",
                       help='data folder (path of the data folder)')
    group.add_argument('--dataset-name', type=str, default="",
                       help='dataset name (e.g., wizard_of_wikipedia)')
    group.add_argument('--max-seq-len', type=int, default=1024,
                       help='maximum sequence length')
    group.add_argument('--spec-toks', type=str, default="[SEP],[CTRL],[PAD]",
                       help='additional special tokens')

    return parser

# def _add_dialog_ctrl_args(parser):
#     group = parser.add_argument_group(title="dialog control")
#     group.add_argument('--run-dialog', action='store_true',
#                        help='run dialog modeling')
#     group.add_argument('--num-epoch', type=int, default=30,
#                        help='number of epoches to train the model')
#     group.add_argument('--train-module', type=str, default="",
#                        help='either control module or dialogue model (control or dialog)')
#     group.add_argument('--data-folder', type=str, default="",
#                        help='data folder (path of the data folder)')
#     group.add_argument('--dataset-name', type=str, default="",
#                        help='dataset name (e.g., wizard_of_wikipedia)')
#     group.add_argument('--max-seq-len', type=int, default=1024,
#                        help='maximum sequence length')
#     group.add_argument('--spec-toks', type=str, default="[SEP],[CTRL],[PAD]",
#                        help='additional special tokens')
#     return parser
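For reference, a stripped-down replica of the new argument group, showing how the flags parse with plain argparse; the sample command line passed to parse_args is hypothetical, not taken from the commit:

```python
import argparse

# Stripped-down replica of _add_dialog_ctrl_args, for illustration only.
parser = argparse.ArgumentParser()
group = parser.add_argument_group(title="dialog control")
group.add_argument('--run-dialog', action='store_true')
group.add_argument('--num-epoch', type=int, default=30)
group.add_argument('--train-module', type=str, default="")
group.add_argument('--data-folder', type=str, default="")
group.add_argument('--dataset-name', type=str, default="")
group.add_argument('--max-seq-len', type=int, default=1024)
group.add_argument('--spec-toks', type=str, default="[SEP],[CTRL],[PAD]")

# Hypothetical command line; the path is a placeholder.
args = parser.parse_args(['--run-dialog',
                          '--train-module', 'dialog',
                          '--data-folder', '/path/to/data',
                          '--dataset-name', 'wizard_of_wikipedia'])
print(args.run_dialog, args.train_module, args.max_seq_len, args.spec_toks)
# True dialog 1024 [SEP],[CTRL],[PAD]
```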
megatron/checkpointing.py
@@ -344,21 +344,21 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
        print_rank_0(f' checkpoint version {checkpoint_version}')
        fix_query_key_value_ordering(model, checkpoint_version)

    if not args.run_dialog:
        # Original pre-train GPT setting
        # Optimizer.
        if not release and not args.finetune and not args.no_load_optim:
            try:
                if optimizer is not None:
                    optimizer.load_state_dict(state_dict['optimizer'])
                if lr_scheduler is not None:
                    lr_scheduler.load_state_dict(state_dict['lr_scheduler'])
            except KeyError:
                print_rank_0('Unable to load optimizer from checkpoint {}. '
                             'Specify --no-load-optim or --finetune to prevent '
                             'attempting to load the optimizer state, '
                             'exiting ...'.format(checkpoint_name))
                sys.exit()

    # if not args.run_dialog:
    # Original pre-train GPT setting
    # Optimizer.
    if not release and not args.finetune and not args.no_load_optim:
        try:
            if optimizer is not None:
                optimizer.load_state_dict(state_dict['optimizer'])
            if lr_scheduler is not None:
                lr_scheduler.load_state_dict(state_dict['lr_scheduler'])
        except KeyError:
            print_rank_0('Unable to load optimizer from checkpoint {}. '
                         'Specify --no-load-optim or --finetune to prevent '
                         'attempting to load the optimizer state, '
                         'exiting ...'.format(checkpoint_name))
            sys.exit()

    # rng states.
    if not release and not args.finetune and not args.no_load_rng:
megatron/text_generation_utils.py
@@ -247,6 +247,7 @@ def generate_samples_interactive(model, print_frequency=24):
                terminate_runs = 1
            else:
                context_tokens = tokenizer.tokenize(raw_text)
                # context_tokens = context_tokens + [tokenizer.sep_id]
                context_length = len(context_tokens)

                if context_length >= (args.seq_length // 2):

@@ -299,9 +300,14 @@ def generate_samples_interactive(model, print_frequency=24):
                print("\nContext:", raw_text, flush=True)

                decode_tokens, _ = decode_tokens
                # print("tokenzied inputs:", tokenizer.tokenize(raw_text))
                # print("decode_tokens:", decode_tokens)
                decode_tokens = decode_tokens[0].cpu().numpy().tolist()
                trim_decode_tokens = tokenizer.detokenize(decode_tokens)[raw_text_len:]
                # trim_decode_tokens = tokenizer.detokenize(
                #     decode_tokens[context_length:])
                print("\nMegatron-LM:", trim_decode_tokens, flush=True)

            if mpu.is_pipeline_first_stage() \

@@ -314,6 +320,9 @@ def generate_samples_interactive(model, print_frequency=24):
                decode_tokens = decode_tokens[0].cpu().numpy().tolist()
                trim_decode_tokens = tokenizer.detokenize(decode_tokens)[raw_text_len:]
                # print("decode_tokens:", decode_tokens)
                # trim_decode_tokens = tokenizer.detokenize(
                #     decode_tokens[context_length:])
                print("\nMegatron-LM:", trim_decode_tokens, flush=True)

                input("\nPress Enter to continue >>>")
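The change above keeps the full detokenized string and slices off the prompt by its character length (raw_text_len) rather than detokenizing only the tokens after context_length. A toy illustration with plain strings (the sentences are made up; the real code slices the detokenized model output):

```python
# Hypothetical prompt and generation.
raw_text = "Hello, how are you?"
full_output = "Hello, how are you? I am doing well, thanks for asking!"

raw_text_len = len(raw_text)                  # character length of the prompt
trim_decode_tokens = full_output[raw_text_len:]
print(trim_decode_tokens)                     # " I am doing well, thanks for asking!"
```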
megatron/tokenizer/tokenizer.py
@@ -41,6 +41,7 @@ def build_tokenizer(args):
    elif args.tokenizer_type == 'GPT2BPETokenizer':
        assert args.merge_file is not None
        tokenizer = _GPT2BPETokenizer(args.vocab_file, args.merge_file,
                                      special_tokens=args.spec_toks)
        # tokenizer = _GPT2BPETokenizer(args.vocab_file, args.merge_file)
    else:
        raise NotImplementedError('{} tokenizer is not '
                                  'implemented.'.format(args.tokenizer_type))

@@ -273,7 +274,7 @@ class _GPT2BPETokenizer(AbstractTokenizer):
                                       special_tokens=special_tokens, max_len=None)
        self.eod_id = self.tokenizer.encoder['<|endoftext|>']
        if len(special_tokens) > 0:
        if special_tokens is not None and len(special_tokens) > 0:
            if "[SEP]" in special_tokens:
                self.sep_id = self.tokenizer.special_tokens['[SEP]']
            if "[CTRL]" in special_tokens:
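The one-line change in the second hunk guards against special_tokens being None before calling len() on it, presumably to tolerate callers that construct the tokenizer without the new --spec-toks value. A minimal sketch of the difference (the sample values are hypothetical):

```python
# Hypothetical values of the special_tokens argument.
for special_tokens in (None, "", "[SEP],[CTRL],[PAD]"):
    # Old check: `if len(special_tokens) > 0:` raises TypeError when the value is None.
    # New check short-circuits on None first.
    if special_tokens is not None and len(special_tokens) > 0:
        print("would register:", special_tokens)
    else:
        print("nothing to register for:", repr(special_tokens))
```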
megatron/training.py
@@ -180,8 +180,12 @@ def pretrain(train_valid_test_dataset_provider,
                                       valid_data_iterator, model,
                                       iteration, False)
        if e >= 8 and e <= 13 and args.save and iteration != 0:
            save_checkpoint(iteration, model, optimizer, lr_scheduler)
        # if args.train_module == "dialog":
        #     if (e+1) >= 6 and (e+1) <= 15 and args.save and iteration != 0:
        #         save_checkpoint(iteration, model, optimizer, lr_scheduler)
        if args.train_module == "control":
            if (e + 1) >= 5 and (e + 1) <= 9 and args.save and iteration != 0:
                save_checkpoint(iteration, model, optimizer, lr_scheduler)

    if args.do_test:
        # Run on test data.

@@ -845,7 +849,7 @@ def build_train_valid_test_data_iterators(
        print_rank_0('    validation: {}'.format(valid_size))
        print_rank_0('    test:       {}'.format(test_size))

        batch_size = args.micro_batch_size * args.data_parallel_size
        batch_size = args.global_batch_size
        args.train_iters = train_size // batch_size + 1
        args.eval_iters = valid_size // batch_size + 1
        args.test_iters = test_size // batch_size + 1
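In the second hunk the per-epoch iteration counts are now derived from the global batch size instead of micro_batch_size * data_parallel_size. A toy calculation (the dataset and batch sizes are made up) showing how the train/eval/test iteration counts come out:

```python
# Made-up dataset and batch sizes, just to show the arithmetic.
train_size, valid_size, test_size = 83_000, 4_400, 4_400
global_batch_size = 64

batch_size = global_batch_size               # was: micro_batch_size * data_parallel_size
train_iters = train_size // batch_size + 1   # 1297 iterations per epoch
eval_iters = valid_size // batch_size + 1    # 69
test_iters = test_size // batch_size + 1     # 69
print(train_iters, eval_iters, test_iters)
```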
tasks/dialctrl/data.py
new file (0 → 100644)

"""Build Dataset for Controllable Coversational Model"""

import os
import torch
import numpy as np

from megatron import get_tokenizer
from megatron import print_rank_0


def read_data(tokenizer, data_path, train_module):
    """read and tokenize dialog data"""

    data_list = []
    with open(data_path, "r") as f:
        for i, line in enumerate(f):
            line = line.strip()
            splits = line.split("\t")
            length_split = len(splits)
            assert length_split == 2 or length_split == 3 or length_split == 4

            if train_module == "dialog":
                # if length_split == 2:
                #     continue
                dialog_context = splits[0]
                if length_split > 2:
                    ctrl_sent = splits[-2]
                response = splits[-1]

                # only take the last three turns in the dialog context
                turns = dialog_context.split(" [SEP] ")
                turns = turns[-3:]

                # input_ids
                for idx, turn in enumerate(turns):
                    if not (turn.endswith("?") or turn.endswith(".") or turn.endswith("!")):
                        turn = turn + " ."
                    if idx == 0:
                        input_ids = tokenizer.tokenize(turn)
                    else:
                        # input_ids.extend([tokenizer.sep_id] + tokenizer.tokenize(turn))
                        input_ids.extend(tokenizer.tokenize(turn))
                if length_split > 2:
                    # when there is control sentence, add it into the input_ids
                    # input_ids.extend([tokenizer.ctrl_id] + tokenizer.tokenize(ctrl_sent))
                    input_ids.extend(tokenizer.tokenize("( " + ctrl_sent + " ) ."))

                # output_ids
                output_ids = tokenizer.tokenize(response)

                data_list.append({"input_ids": input_ids, "output_ids": output_ids})

            elif train_module == "control":
                if length_split == 2:
                    continue
                dialog_context = splits[0]
                ctrl_sent = splits[-2]
                ctrl_code = splits[1] if length_split == 4 else None

                turns = dialog_context.split(" [SEP] ")
                # last_turn = turns[-1]
                # turns = turns[-3:]
                # for idx, turn in enumerate(turns):
                #     if idx == 0:
                #         input_ids = tokenizer.tokenize(turn)
                #     else:
                #         # input_ids.extend([tokenizer.sep_id] + tokenizer.tokenize(turn))
                #         input_ids.extend(tokenizer.tokenize(turn))
                # # input_ids
                # if ctrl_code:
                #     ctrl_code_list = ctrl_code.split(" [CTRL] ")
                #     for code in ctrl_code_list:
                #         # input_ids.extend([tokenizer.ctrl_id] + tokenizer.tokenize(code))
                #         input_ids.extend(tokenizer.tokenize(code + " ."))

                # put control code at the begginning
                input_ids = []
                if ctrl_code:
                    ctrl_code_list = ctrl_code.split(" [CTRL] ")
                    for code in ctrl_code_list:
                        input_ids.extend(tokenizer.tokenize("( " + code + " )"))
                turns = turns[-3:]
                for turn in turns:
                    if not (turn.endswith("?") or turn.endswith(".") or turn.endswith("!")):
                        turn = turn + " ."
                    input_ids.extend(tokenizer.tokenize(turn))

                # output_ids
                outputs = ctrl_sent
                output_ids = tokenizer.tokenize(outputs)

                data_list.append({"input_ids": input_ids, "output_ids": output_ids})

            else:
                raise ValueError("Please input a correct train-module name! (either dialog or cnotrol))")

    return data_list


def data_shuffle(data, seed):
    # set random seed to make the shuffling reproducible
    np.random.seed(seed)
    np.random.shuffle(data)
    return data


class ControlDialogDataset(torch.utils.data.Dataset):

    def __init__(self, data, max_seq_len, sep_id, pad_id, eod_id):
        # need to deal with padding, label masking
        self.data = data
        self.max_seq_len = max_seq_len
        self.sep_id = sep_id
        self.pad_id = pad_id
        self.eod_id = eod_id

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data_dict = self.data[idx]
        input_ids, output_ids = data_dict["input_ids"], data_dict["output_ids"]
        assert len(input_ids) < self.max_seq_len, "Set a larger max-seq-len!"

        # length_of_loss_mask == length_of_text - 1
        # text = input_ids + [self.sep_id] + output_ids + [self.eod_id]
        text = input_ids + output_ids + [self.eod_id]
        loss_mask = [0] * (len(input_ids) - 1) + [1] * (len(output_ids) + 1)
        text_len = len(text)

        if text_len > self.max_seq_len + 1:
            text = text[:self.max_seq_len + 1]
            loss_mask = loss_mask[:self.max_seq_len]
        else:
            text += [self.pad_id] * (self.max_seq_len + 1 - text_len)
            loss_mask += [0] * (self.max_seq_len + 1 - text_len)

        return {"text": np.array(text, dtype=np.int64),
                "loss_mask": np.array(loss_mask, dtype=np.int64)}


def build_train_valid_test_datasets(data_folder, dataset_name, train_module,
                                    max_seq_len, seed):
    """Build train, valid, and test datasets."""

    dataname_dict = {"wizard_of_wikipedia":
                         {"train": "train_entity_based_control.txt",
                          "valid": "valid_random_split_entity_based_control.txt",
                          "test": "test_random_split_entity_based_control.txt"}}

    train_data_path = os.path.join(data_folder, dataset_name + "/processed/" +
                                   dataname_dict[dataset_name]["train"])
    valid_data_path = os.path.join(data_folder, dataset_name + "/processed/" +
                                   dataname_dict[dataset_name]["valid"])
    test_data_path = os.path.join(data_folder, dataset_name + "/processed/" +
                                  dataname_dict[dataset_name]["test"])

    tokenizer = get_tokenizer()
    train_data_list = read_data(tokenizer, train_data_path, train_module)
    valid_data_list = read_data(tokenizer, valid_data_path, train_module)
    test_data_list = read_data(tokenizer, test_data_path, train_module)

    # shuffle the training data
    train_data_list = data_shuffle(train_data_list, seed)

    # build train, valid, and test datasets
    train_dataset = ControlDialogDataset(train_data_list, max_seq_len,
                                         sep_id=tokenizer.sep_id,
                                         pad_id=tokenizer.pad_id,
                                         eod_id=tokenizer.eod_id)
    valid_dataset = ControlDialogDataset(valid_data_list, max_seq_len,
                                         sep_id=tokenizer.sep_id,
                                         pad_id=tokenizer.pad_id,
                                         eod_id=tokenizer.eod_id)
    test_dataset = ControlDialogDataset(test_data_list, max_seq_len,
                                        sep_id=tokenizer.sep_id,
                                        pad_id=tokenizer.pad_id,
                                        eod_id=tokenizer.eod_id)

    return train_dataset, valid_dataset, test_dataset
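A usage sketch for the new ControlDialogDataset that bypasses read_data and the Megatron tokenizer by feeding toy token ids directly; it only assumes the repository root (with torch and numpy installed) is importable on PYTHONPATH, and all id values are hypothetical:

```python
# Assumes the repository root is on PYTHONPATH; all token ids below are made up.
from tasks.dialctrl.data import ControlDialogDataset

toy_data = [
    {"input_ids": [11, 12, 13, 14], "output_ids": [21, 22, 23]},
    {"input_ids": [31, 32],         "output_ids": [41]},
]
ds = ControlDialogDataset(toy_data, max_seq_len=10, sep_id=1, pad_id=0, eod_id=2)

sample = ds[0]
print(len(ds))              # 2
print(sample["text"])       # 4 context + 3 response + eod = 8 ids, padded to max_seq_len + 1 = 11
print(sample["loss_mask"])  # 3 zeros (context), 4 ones (response + eod), 3 zeros (padding)
```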
tasks/dialctrl/finetune.py
new file (0 → 100644)

"""Controllable Dialogue Finetuning"""

import torch
from functools import partial

from megatron import get_args
from megatron import get_timers
from megatron import print_rank_0
from megatron import get_tokenizer
from megatron import mpu
from megatron.model import GPTModel
from megatron.training import evaluate_and_print_results
from megatron.utils import average_losses_across_data_parallel_group
from tasks.finetune_utils import finetune
from tasks.dialctrl.data import build_train_valid_test_datasets
from tasks.dialctrl.utils import get_ltor_attention_masks_and_position_ids


def model_provider(pre_process=True, post_process=True):
    """Build the model."""

    print_rank_0('building GPT model ...')
    model = GPTModel(num_tokentypes=0, parallel_output=True,
                     pre_process=pre_process, post_process=post_process)
    return model


def train_valid_datasets_provider():
    """Build train, valid, and test datasets for dialog/control module"""
    args = get_args()

    print_rank_0('> building train, validation, and test datasets for %s module ...' % args.train_module)
    train_ds, valid_ds, _ = build_train_valid_test_datasets(
        data_folder=args.data_folder,
        dataset_name=args.dataset_name,
        train_module=args.train_module,
        max_seq_len=args.max_seq_len,
        seed=args.seed)
    print_rank_0("> finished creating datasets for %s module ..." % args.train_module)

    args.eval_interval = len(train_ds) // args.global_batch_size
    print_rank_0(' > evaluation interval: %d' % args.eval_interval)

    return train_ds, valid_ds


def process_batch(batch):
    """Generate a batch"""
    args = get_args()
    tokenizer = get_tokenizer()

    # Items and their type.
    keys = ['text', 'loss_mask']
    datatype = torch.int64

    data_b = mpu.broadcast_data(keys, batch, datatype)

    tokens_ = data_b['text'].long()
    labels = tokens_[:, 1:].contiguous()
    tokens = tokens_[:, :-1].contiguous()
    loss_mask = data_b['loss_mask'].float()

    # Get the attention_mask and postition ids.
    attention_mask, position_ids = \
        get_ltor_attention_masks_and_position_ids(tokens, tokenizer.eod_id)

    return tokens, labels, loss_mask, attention_mask, position_ids


def loss_func(loss_mask, output_tensor):
    losses = output_tensor.float()
    loss_mask = loss_mask.view(-1).float()
    loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()

    # Reduce loss for logging.
    averaged_loss = average_losses_across_data_parallel_group([loss])

    return loss, {'lm loss': averaged_loss[0]}


def forward_step(batch, model):
    """Forward step."""
    args = get_args()
    timers = get_timers()

    try:
        batch_ = next(batch)
    except BaseException:
        batch_ = batch
    tokens, labels, loss_mask, attention_mask, position_ids = process_batch(batch_)

    output_tensor = model(tokens, position_ids, attention_mask, labels=labels)

    return output_tensor, partial(loss_func, loss_mask)


def main():
    finetune(train_valid_datasets_provider, model_provider, \
             forward_step=forward_step)
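A standalone sketch (toy tensors, no Megatron) of the masked average computed in loss_func: per-position losses are summed only where loss_mask is 1 and divided by the number of unmasked positions:

```python
import torch

# Hypothetical per-token LM losses for a batch of one sequence of length four.
losses = torch.tensor([[0.5, 1.0, 2.0, 4.0]])
loss_mask = torch.tensor([[0.0, 0.0, 1.0, 1.0]])   # only the last two positions count

loss = torch.sum(losses.view(-1) * loss_mask.view(-1)) / loss_mask.sum()
print(loss)   # tensor(3.) == (2.0 + 4.0) / 2
```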
tasks/dialctrl/utils.py
new file (0 → 100644)

import torch

from megatron import print_rank_0


def get_ltor_attention_masks_and_position_ids(data, eod_token_id):
    """Build attention masks and position id for left to right model."""

    micro_batch_size, seq_length = data.size()

    # Attention mask
    attention_mask = torch.tril(torch.ones(
        (micro_batch_size, seq_length, seq_length), device=data.device)).view(
            micro_batch_size, 1, seq_length, seq_length)

    # mask padded tokens
    for b in range(micro_batch_size):
        for idx in range(seq_length - 1):
            if data[b, idx] == eod_token_id:
                # pad tokens that come after the eod token
                attention_mask[b, 0, idx + 1:, :] = 0.0

    # Position ids.
    position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)
    position_ids = position_ids.unsqueeze(0).expand_as(data)

    # # reset attentino mask and position ids
    # # Loop through the batches:
    # for b in range(micro_batch_size):
    #     # Find indecies where EOD token is.
    #     eod_index = position_ids[b, data[b] == eod_token_id]
    #     eod_index = eod_index.clone()
    #     # Loop through EOD indecies:
    #     prev_index = 0
    #     for j in range(eod_index.size()[0]):
    #         i = eod_index[j]
    #         # Mask attention loss.
    #         attention_mask[b, 0, (i + 1):, :(i + 1)] = 0
    #         # Reset positions.
    #         position_ids[b, (i + 1):] -= (i + 1 - prev_index)
    #         prev_index = i + 1

    # Convert attention mask to binary:
    attention_mask = (attention_mask < 0.5)

    return attention_mask, position_ids
\ No newline at end of file
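A toy demo of the new helper, assuming the repository root is on PYTHONPATH; a batch of one sequence with an eod token in the middle shows the usual causal mask plus the extra masking of every position after the eod (those rows are zeroed before the binary conversion, so they end up fully masked):

```python
import torch

# Assumes the repository root is on PYTHONPATH; token ids are made up, eod id = 2.
from tasks.dialctrl.utils import get_ltor_attention_masks_and_position_ids

data = torch.tensor([[5, 6, 2, 0, 0]])   # one sequence; trailing 0s stand for padding
attention_mask, position_ids = get_ltor_attention_masks_and_position_ids(
    data, eod_token_id=2)

print(attention_mask.shape)   # torch.Size([1, 1, 5, 5]); True marks masked positions
print(attention_mask[0, 0])   # causal mask, with rows 3 and 4 (after the eod) fully masked
print(position_ids)           # tensor([[0, 1, 2, 3, 4]])
```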
tasks/main.py
@@ -84,6 +84,18 @@ def get_tasks_args(parser):
                       help='Av.rank validation: how many other negatives to'
                            ' take from each question pool')

    # finetune for controllable dialogue
    group.add_argument('--train-module', type=str, default="",
                       help='either control module or dialogue model (control or dialog)')
    group.add_argument('--data-folder', type=str, default="",
                       help='data folder (path of the data folder)')
    group.add_argument('--dataset-name', type=str, default="",
                       help='dataset name (e.g., wizard_of_wikipedia)')
    group.add_argument('--max-seq-len', type=int, default=1024,
                       help='maximum sequence length')
    group.add_argument('--spec-toks', type=str, default="[SEP],[CTRL],[PAD]",
                       help='additional special tokens')

    return parser

@@ -108,6 +120,8 @@ if __name__ == '__main__':
        from orqa.evaluate_orqa import main
    elif args.task in ['RET-FINETUNE-NQ']:
        from orqa.supervised.finetune import main
    elif args.task == 'dialctrl':
        from dialctrl.finetune import main
    else:
        raise NotImplementedError('Task {} is not implemented.'.format(args.task))