OpenDAS / Megatron-LM / Commits / fb3328fe

Commit fb3328fe authored Nov 21, 2021 by zihanl

upload knwl_dialo

Parent: ec202742

Showing 12 changed files with 1018 additions and 405 deletions (+1018 -405)
.gitignore                            +3    -3
megatron/text_generation_utils.py     +93   -21
tasks/dialctrl/data.py                +0    -267
tasks/dialctrl/utils.py               +0    -47
tasks/knwl_dialo/data.py              +243  -0
tasks/knwl_dialo/evaluate.py          +32   -16
tasks/knwl_dialo/finetune.py          +213  -0
tasks/knwl_dialo/metrics.py           +3    -12
tasks/knwl_dialo/prompt.py            +174  -0
tasks/knwl_dialo/utils.py             +225  -0
tasks/main.py                         +26   -32
tools/generate_samples_gpt.py         +6    -7
.gitignore  (+3 -3)

@@ -7,10 +7,10 @@ dist/
 tensorboard
 commands/
 commands_new/
-commands_others/
+commands_final/
 *.log
 logs
 *.so
 *.out
-train_gpt_conv.py
-dialogctrl/
\ No newline at end of file
+dialogctrl/
+control_gen/
\ No newline at end of file
megatron/text_generation_utils.py  (+93 -21)

@@ -263,6 +263,7 @@ def generate_samples_prompt_input_from_file(model):

     args = get_args()
     tokenizer = get_tokenizer()
+    from nltk import word_tokenize

     # Read the sample file and open the output file.
     assert args.sample_input_file is not None, \

@@ -282,16 +283,35 @@ def generate_samples_prompt_input_from_file(model):
         fname_out = open(sample_output_file, "w")

     # Read the prompt file
-    with open(args.prompt_file, "r") as f:
-        prompt_examples = f.readlines()
+    if args.dynamic_prompt:
+        prompt_examples_dict = {}
+        with open(args.prompt_file, "r") as f:
+            for i, line in enumerate(f):
+                line = line.strip()
+                line_dict = json.loads(line)
+                key = list(line_dict.keys())[0]
+                if key not in prompt_examples_dict:
+                    prompt_examples = line_dict[key]
+                    prompt = ""
+                    for instance in prompt_examples:
+                        instance = instance.strip()
+                        prompt += instance + "\n"
+                    prompt_examples_dict[key] = prompt
+    else:
+        with open(args.prompt_file, "r") as f:
+            prompt_examples = f.readlines()
             prompt_examples = prompt_examples[:args.num_prompt_examples]

             prompt = ""
             for instance in prompt_examples:
                 instance = instance.strip()
                 prompt += instance + "\n"

-    assert args.prompt_type in ["context", "keyphrase"]
+    assert args.prompt_type in ["knowledge", "knowledge_notopic", "dialogue", "dialogue_notopic"]

     context_count = 0
     model.eval()
     with torch.no_grad():

@@ -306,25 +326,77 @@ def generate_samples_prompt_input_from_file(model):
                 control_codes = splits[0].split(" [CTRL] ")
                 topic = control_codes[0]
-                raw_text = prompt
-                if args.prompt_type == "context":
-                    turns = splits[1].split(" [SEP] ")
-                    context = turns[-1]
-                    raw_text += "( " + context + " ) " + topic + " :"
-                else:
-                    keyphrase_list = control_codes[1:]
-                    for i, keyphrase in enumerate(keyphrase_list):
-                        if i == 0:
-                            raw_text += "( "
-                        else:
-                            raw_text += "; "
-                        raw_text += keyphrase
-                    if len(keyphrase_list) > 0:
-                        raw_text += " ) "
-                    raw_text += topic + " :"
+                if args.dynamic_prompt:
+                    turns = splits[1].split(" [SEP] ")
+                    last_turn = turns[-1]
+                    key = topic + " " + last_turn
+                    raw_text = prompt_examples_dict[key]
+                else:
+                    raw_text = prompt
+                if args.prompt_type == "knowledge":
+                    turns = splits[1].split(" [SEP] ")
+                    context = turns[-1]
+                    raw_text += "( " + context + " ) " + topic + " =>"
+                    # raw_text += "( " + context + " ) " + topic + ":"
+                    # raw_text += "( " + context + " ) " + topic + " ->"
+                elif args.prompt_type == "knowledge_notopic":
+                    turns = splits[1].split(" [SEP] ")[-3:]
+                    for j, turn in enumerate(turns):
+                        if j != 0:
+                            raw_text += " "
+                        raw_text += "( " + turn + " )"
+                    raw_text += " =>"
+                elif args.prompt_type == "dialogue":
+                    turns = splits[1].split(" [SEP] ")
+                    # context = turns[-1]
+                    ctrl_sent = splits[2]
+                    ctrl_sent = " ".join(word_tokenize(ctrl_sent))
+                    # ## version one
+                    # turns = turns[-3:]
+                    # raw_text += "Topic: " + topic + ". "
+                    # if len(turns) == 2:
+                    #     for idx, turn in enumerate(turns):
+                    #         if idx % 2 == 0:
+                    #             raw_text += "System: " + turn + " "
+                    #         else:
+                    #             raw_text += "User: " + turn + " "
+                    # else:
+                    #     for idx, turn in enumerate(turns):
+                    #         if idx % 2 == 0:
+                    #             raw_text += "User: " + turn + " "
+                    #         else:
+                    #             raw_text += "System: " + turn + " "
+                    # raw_text += "We know that: " + ctrl_sent + " "
+                    # raw_text += "Therefore, the System will say:"
+                    ## version two
+                    last_turn = turns[-1]
+                    ctrl_sent = ctrl_sent.strip()
+                    last_turn = last_turn.strip()
+                    raw_text += "Topic: " + topic + ". "
+                    raw_text += "User says: " + last_turn + " "
+                    raw_text += "We know that: " + ctrl_sent + " "
+                    raw_text += "System replies:"
+                else:
+                    turns = splits[1].split(" [SEP] ")
+                    # context = turns[-1]
+                    ctrl_sent = splits[2]
+                    ctrl_sent = " ".join(word_tokenize(ctrl_sent))
+                    ## version two
+                    last_turn = turns[-1]
+                    ctrl_sent = ctrl_sent.strip()
+                    last_turn = last_turn.strip()
+                    raw_text += "User says: " + last_turn + " "
+                    raw_text += "We know that: " + ctrl_sent + " "
+                    raw_text += "System replies:"

                 input_pos += 1
                 raw_text_len = len(raw_text)
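Note on the new dynamic-prompt branch above: --prompt-file is now read line by line with json.loads, so each line is expected to be a JSON object whose single key is later looked up as topic + " " + last_turn and whose value is a list of prompt instances. The exact keys and example strings are not part of this commit, so the following is only an illustrative Python sketch, with made-up content, of how that loop consumes such a file:

import json

# Hypothetical prompt-file line; the real keys/examples are not shown in this commit.
line = '{"Blue I love the color blue.": ["( I love the color blue. ) Blue =>", "( Do you paint? ) Painting =>"]}'

prompt_examples_dict = {}
line_dict = json.loads(line.strip())
key = list(line_dict.keys())[0]        # single key per line, e.g. "<topic> <last user turn>"
if key not in prompt_examples_dict:
    prompt = ""
    for instance in line_dict[key]:    # the value is a list of prompt instances
        prompt += instance.strip() + "\n"
    prompt_examples_dict[key] = prompt

print(prompt_examples_dict[key])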
tasks/dialctrl/data.py  (deleted, 100644 → 0, -267)

"""Build Dataset for Controllable Coversational Model"""

import os
import torch
import numpy as np

from megatron import get_tokenizer
from megatron import print_rank_0


def read_data(tokenizer, data_path, train_module):
    """read and tokenize dialog data"""

    data_list = []
    with open(data_path, "r") as f:
        for i, line in enumerate(f):
            line = line.strip()
            splits = line.split("\t")
            length_split = len(splits)
            assert length_split == 2 or length_split == 3 or length_split == 4

            if train_module == "dialog":
                # if length_split == 2:
                #     continue
                dialog_context = splits[0]
                if length_split > 2:
                    ctrl_sent = splits[-2]
                response = splits[-1]
                # only take the last three turns in the dialog context
                turns = dialog_context.split(" [SEP] ")
                turns = turns[-3:]

                # input_ids
                input_ids = []
                if length_split > 2:
                    input_ids.extend(tokenizer.tokenize("( " + ctrl_sent + " )"))
                for idx, turn in enumerate(turns):
                    if not (turn.endswith("?") or turn.endswith(".") or turn.endswith("!")):
                        turn = turn + " ."
                    input_ids.extend(tokenizer.tokenize(turn))
                # output_ids
                output_ids = tokenizer.tokenize(response)

                data_list.append({"input_ids": input_ids, "output_ids": output_ids})

            elif train_module == "control":
                if length_split == 2:
                    continue
                dialog_context = splits[0]
                ctrl_sent = splits[-2]
                ctrl_code = splits[1] if length_split == 4 else None
                turns = dialog_context.split(" [SEP] ")

                # put control code at the begginning
                input_ids = []
                if ctrl_code:
                    ctrl_code_list = ctrl_code.split(" [CTRL] ")
                    for code in ctrl_code_list:
                        input_ids.extend(tokenizer.tokenize("( " + code + " )"))
                turns = turns[-3:]
                for turn in turns:
                    if not (turn.endswith("?") or turn.endswith(".") or turn.endswith("!")):
                        turn = turn + " ."
                    input_ids.extend(tokenizer.tokenize(turn))
                # output_ids
                outputs = ctrl_sent
                output_ids = tokenizer.tokenize(outputs)

                data_list.append({"input_ids": input_ids, "output_ids": output_ids})

            else:
                raise ValueError("Please input a correct train-module name! " \
                                 "(either dialog or cnotrol))")

    return data_list


def read_data_v2(tokenizer, data_path, train_module, last_turn=False,
                 no_control_code=False, add_separator=False,
                 add_ctrl_code_to_dialog=False, remove_ctrl_sent=False):
    """
    Read and tokenize data for version 2 (v2) data files.
    Format: control code \t dialog context \t control sentence \t response.
    Response only comes from the wizard.
    Currently, this function is used to build test dataset for calculating PPL.
    """

    data_list = []
    with open(data_path, "r") as f:
        for i, line in enumerate(f):
            line = line.rstrip()
            splits = line.split("\t")
            assert len(splits) == 4

            control_code = splits[0]
            dialog_context = splits[1]
            control_sent = splits[2]
            response = splits[3]
            turns = dialog_context.split(" [SEP] ")
            turns = turns[-3:]

            if train_module == "dialog":
                # input_ids
                if add_ctrl_code_to_dialog:
                    ctrl_code = control_code.split(" [CTRL] ")[0]
                    input_ids = tokenizer.tokenize("( " + ctrl_code + " )")
                    if not remove_ctrl_sent and control_sent != "no_passages_used":
                        input_ids.extend(tokenizer.tokenize("( " + control_sent + " )")[:256])
                else:
                    if remove_ctrl_sent or control_sent == "no_passages_used":
                        input_ids = []
                    else:
                        input_ids = tokenizer.tokenize("( " + control_sent + " )")[:256]
                for turn in turns:
                    if add_separator:
                        turn = "<< " + turn + " >>"
                    input_ids.extend(tokenizer.tokenize(turn))
                if add_separator:
                    input_ids.extend(tokenizer.tokenize(":"))
                # output_ids
                output_ids = tokenizer.tokenize(response)

                data_list.append({"input_ids": input_ids, "output_ids": output_ids})

            elif train_module == "control":
                # skip example without control sentences
                if control_sent == "no_passages_used":
                    continue

                input_ids = []
                if not no_control_code:
                    ctrl_code_list = control_code.split(" [CTRL] ")[:3]
                    # only choose maximum three control codes
                    for code in ctrl_code_list:
                        if len(code) > 0:
                            input_ids.extend(tokenizer.tokenize("( " + code + " )"))
                if last_turn:
                    input_ids.extend(tokenizer.tokenize(turns[-1]))
                else:
                    for turn in turns:
                        if add_separator:
                            turn = "<< " + turn + " >>"
                        input_ids.extend(tokenizer.tokenize(turn))
                if add_separator:
                    input_ids.extend(tokenizer.tokenize(":"))
                output_ids = tokenizer.tokenize(control_sent)

                data_list.append({"input_ids": input_ids, "output_ids": output_ids})

            else:
                raise ValueError("Please input a correct train-module name! " \
                                 "(either dialog or cnotrol))")

    return data_list


def data_shuffle(data, seed):
    # set random seed to make the shuffling reproducible
    np.random.seed(seed)
    np.random.shuffle(data)
    return data


class ControlDialogDataset(torch.utils.data.Dataset):

    def __init__(self, data, max_seq_len, sep_id, pad_id, eod_id):
        # need to deal with padding, label masking
        self.data = data
        self.max_seq_len = max_seq_len
        self.sep_id = sep_id
        self.pad_id = pad_id
        self.eod_id = eod_id

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data_dict = self.data[idx]
        input_ids, output_ids = data_dict["input_ids"], data_dict["output_ids"]
        # assert len(input_ids) < self.max_seq_len, "Set a larger max-seq-len!"
        # length_of_loss_mask == length_of_text - 1
        # text = input_ids + [self.sep_id] + output_ids + [self.eod_id]
        text = input_ids + output_ids + [self.eod_id]
        loss_mask = [0] * (len(input_ids) - 1) + [1] * (len(output_ids) + 1)
        text_len = len(text)
        if text_len > self.max_seq_len + 1:
            text = text[:self.max_seq_len + 1]
            loss_mask = loss_mask[:self.max_seq_len]
        else:
            text += [self.pad_id] * (self.max_seq_len + 1 - text_len)
            loss_mask += [0] * (self.max_seq_len + 1 - text_len)

        return {"text": np.array(text, dtype=np.int64), \
                "loss_mask": np.array(loss_mask, dtype=np.int64)}


def build_train_valid_datasets(train_data_path, valid_data_path, train_module,
                               max_seq_len, seed, last_turn, no_control_code,
                               add_separator, add_ctrl_code_to_dialog, remove_ctrl_sent):
    """Build train, valid, and test datasets."""

    # dataname_dict = {"wizard_of_wikipedia": {"train": "train_entity_based_control.txt", "valid": "valid_random_split_entity_based_control.txt", "test": "test_random_split_entity_based_control.txt"}}
    # train_data_path = os.path.join(data_folder, dataset_name+"/processed/"+dataname_dict[dataset_name]["train"])
    # valid_data_path = os.path.join(data_folder, dataset_name+"/processed/"+dataname_dict[dataset_name]["valid"])
    # test_data_path = os.path.join(data_folder, dataset_name+"/processed/"+dataname_dict[dataset_name]["test"])

    tokenizer = get_tokenizer()
    # train_data_list = read_data(tokenizer, train_data_path, train_module)
    train_data_list = read_data_v2(tokenizer, train_data_path, train_module, last_turn,
                                   no_control_code, add_separator,
                                   add_ctrl_code_to_dialog, remove_ctrl_sent)
    valid_data_list = read_data_v2(tokenizer, valid_data_path, train_module, last_turn,
                                   no_control_code, add_separator,
                                   add_ctrl_code_to_dialog, remove_ctrl_sent)

    # shuffle the training data
    train_data_list = data_shuffle(train_data_list, seed)

    # build train, valid datasets
    train_dataset = ControlDialogDataset(train_data_list, max_seq_len,
                                         sep_id=tokenizer.sep_id,
                                         pad_id=tokenizer.pad_id,
                                         eod_id=tokenizer.eod_id)
    valid_dataset = ControlDialogDataset(valid_data_list, max_seq_len,
                                         sep_id=tokenizer.sep_id,
                                         pad_id=tokenizer.pad_id,
                                         eod_id=tokenizer.eod_id)

    return train_dataset, valid_dataset


def build_test_dataset(test_data_path, train_module, max_seq_len, last_turn,
                       no_control_code, add_separator, add_ctrl_code_to_dialog,
                       remove_ctrl_sent):
    tokenizer = get_tokenizer()
    test_data_list = read_data_v2(tokenizer, test_data_path, train_module, last_turn,
                                  no_control_code, add_separator,
                                  add_ctrl_code_to_dialog, remove_ctrl_sent)
    test_dataset = ControlDialogDataset(test_data_list, max_seq_len,
                                        sep_id=tokenizer.sep_id,
                                        pad_id=tokenizer.pad_id,
                                        eod_id=tokenizer.eod_id)
    return test_dataset
tasks/dialctrl/utils.py  (deleted, 100644 → 0, -47)

import torch

from megatron import print_rank_0


def get_ltor_attention_masks_and_position_ids(data, eod_token_id):
    """Build attention masks and position id for left to right model."""

    micro_batch_size, seq_length = data.size()

    # Attention mask
    attention_mask = torch.tril(torch.ones(
        (micro_batch_size, seq_length, seq_length), device=data.device)).view(
            micro_batch_size, 1, seq_length, seq_length)

    # mask padded tokens
    for b in range(micro_batch_size):
        for idx in range(seq_length - 1):
            if data[b, idx] == eod_token_id:
                # pad tokens that come after the eod token
                attention_mask[b, 0, idx + 1:, :] = 0.0

    # Position ids.
    position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)
    position_ids = position_ids.unsqueeze(0).expand_as(data)

    # # reset attentino mask and position ids
    # # Loop through the batches:
    # for b in range(micro_batch_size):
    #     # Find indecies where EOD token is.
    #     eod_index = position_ids[b, data[b] == eod_token_id]
    #     eod_index = eod_index.clone()
    #     # Loop through EOD indecies:
    #     prev_index = 0
    #     for j in range(eod_index.size()[0]):
    #         i = eod_index[j]
    #         # Mask attention loss.
    #         attention_mask[b, 0, (i + 1):, :(i + 1)] = 0
    #         # Reset positions.
    #         position_ids[b, (i + 1):] -= (i + 1 - prev_index)
    #         prev_index = i + 1

    # Convert attention mask to binary:
    attention_mask = (attention_mask < 0.5)

    return attention_mask, position_ids
\ No newline at end of file
tasks/knwl_dialo/data.py  (new file, 0 → 100644, +243)

"""Build Dataset for Controllable Coversational Model"""

import os
import torch
import numpy as np

from megatron import get_tokenizer
from megatron import print_rank_0


def read_data_for_finetuning(tokenizer, data_path, module):
    """
    Data Format: topic \t dialog context \t knowledge \t response.
    """

    data_list = []
    with open(data_path, "r") as f:
        for i, line in enumerate(f):
            line = line.rstrip()
            splits = line.split("\t")
            assert len(splits) == 4

            topic = splits[0].split(" [CTRL] ")[0]
            dialog_context = splits[1]
            knowledge = splits[2]
            response = splits[3]
            turns = dialog_context.split(" [SEP] ")
            turns = turns[-3:]

            if module == "response":
                # input_ids
                input_ids = tokenizer.tokenize("( " + topic + " )")
                if knowledge != "no_passages_used":
                    input_ids.extend(tokenizer.tokenize("( " + knowledge + " )")[:256])
                for turn in turns:
                    turn = "<< " + turn + " >>"
                    input_ids.extend(tokenizer.tokenize(turn))
                input_ids.extend(tokenizer.tokenize(":"))
                # output_ids
                output_ids = tokenizer.tokenize(response)

                data_list.append({"input_ids": input_ids, "output_ids": output_ids})

            elif module == "knowledge":
                # skip example without knowledge sentences
                if knowledge == "no_passages_used":
                    continue

                input_ids = []
                input_ids.extend(tokenizer.tokenize("( " + topic + " )"))
                for turn in turns:
                    turn = "<< " + turn + " >>"
                    input_ids.extend(tokenizer.tokenize(turn))
                input_ids.extend(tokenizer.tokenize(":"))
                output_ids = tokenizer.tokenize(knowledge)

                data_list.append({"input_ids": input_ids, "output_ids": output_ids})

            else:
                raise ValueError("Please input a correct module name! " \
                                 "(either dialog or cnotrol))")

    return data_list


def read_data_for_prompting(tokenizer, test_data_path, prompt_file, module,
                            num_prompt_examples, dynamic_prompt):

    # get prompts
    if dynamic_prompt:
        import json
        prompt_examples_dict = {}
        with open(prompt_file, "r") as f:
            for i, line in enumerate(f):
                line = line.strip()
                line_dict = json.loads(line)
                key = list(line_dict.keys())[0]
                if key not in prompt_examples_dict:
                    prompt_examples = line_dict[key]
                    prompt_examples = prompt_examples[:num_prompt_examples]
                    prompt = ""
                    for instance in prompt_examples:
                        instance = instance.strip()
                        prompt += instance + "\n"
                    prompt_examples_dict[topic] = prompt
    else:
        with open(prompt_file, "r") as f:
            prompt_examples = f.readlines()
            prompt_examples = prompt_examples[:num_prompt_examples]
            prompt = ""
            for instance in prompt_examples:
                instance = instance.strip()
                prompt += instance + "\n"

    data_list = []
    with open(test_data_path, "r") as f:
        for i, line in enumerate(f):
            line = line.strip()
            splits = line.split("\t")
            topic = splits[0].split(" [CTRL] ")[0]
            turns = splits[1].split(" [SEP] ")[-3:]
            last_turn = turns[-1]
            ctrl_sent = splits[2]
            response = splits[3]

            if dynamic_prompt:
                prompt = prompt_examples_dict[topic]

            if module == "response":
                # input seq
                input_seq = prompt
                input_seq += "Topic: " + topic + ". "
                input_seq += "User says: " + last_turn + " "
                input_seq += "We know that: " + ctrl_sent + " "
                input_seq += "System replies:"
                # output seq
                output_seq = response

                input_ids = tokenizer.tokenize(input_seq)
                output_ids = tokenizer.tokenize(output_seq)
                data_list.append({"input_ids": input_ids, "output_ids": output_ids})

            elif module == "knowledge":
                # input seq
                input_seq = prompt
                input_seq += "( " + last_turn + " ) " + topic + " =>"
                # output seq
                output_seq = ctrl_sent

                input_ids = tokenizer.tokenize(input_seq)
                output_ids = tokenizer.tokenize(output_seq)
                data_list.append({"input_ids": input_ids, "output_ids": output_ids})

            else:
                raise ValueError("Please input a correct module name! " \
                                 "(either dialog or cnotrol))")

    return data_list


def data_shuffle(data, seed):
    # set random seed to make the shuffling reproducible
    np.random.seed(seed)
    np.random.shuffle(data)
    return data


class KnwlDialoDataset(torch.utils.data.Dataset):

    def __init__(self, data, max_seq_len, pad_id, eod_id):
        # need to deal with padding, label masking
        self.data = data
        self.max_seq_len = max_seq_len
        self.pad_id = pad_id
        self.eod_id = eod_id

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data_dict = self.data[idx]
        input_ids, output_ids = data_dict["input_ids"], data_dict["output_ids"]
        text = input_ids + output_ids + [self.eod_id]
        loss_mask = [0] * (len(input_ids) - 1) + [1] * (len(output_ids) + 1)
        text_len = len(text)
        if text_len > self.max_seq_len + 1:
            text = text[:self.max_seq_len + 1]
            loss_mask = loss_mask[:self.max_seq_len]
        else:
            text += [self.pad_id] * (self.max_seq_len + 1 - text_len)
            loss_mask += [0] * (self.max_seq_len + 1 - text_len)

        return {"text": np.array(text, dtype=np.int64), \
                "loss_mask": np.array(loss_mask, dtype=np.int64)}


def build_train_valid_datasets(train_data_path, valid_data_path, module,
                               max_seq_len, seed):
    """Build train, valid, and test datasets."""

    tokenizer = get_tokenizer()
    train_data_list = read_data_for_finetuning(tokenizer, train_data_path, module)
    valid_data_list = read_data_for_finetuning(tokenizer, valid_data_path, module)

    # shuffle the training data
    train_data_list = data_shuffle(train_data_list, seed)

    # build train, valid datasets
    train_dataset = KnwlDialoDataset(train_data_list, max_seq_len,
                                     pad_id=tokenizer.pad_id, eod_id=tokenizer.eod_id)
    valid_dataset = KnwlDialoDataset(valid_data_list, max_seq_len,
                                     pad_id=tokenizer.pad_id, eod_id=tokenizer.eod_id)

    return train_dataset, valid_dataset


def build_test_dataset(test_data_path, module, max_seq_len):
    tokenizer = get_tokenizer()
    test_data_list = read_data_for_finetuning(tokenizer, test_data_path, module)
    test_dataset = KnwlDialoDataset(test_data_list, max_seq_len,
                                    pad_id=tokenizer.pad_id, eod_id=tokenizer.eod_id)
    return test_dataset


def build_test_dataset_for_prompting(test_data_path, prompt_file, module, max_seq_len,
                                     num_prompt_examples, dynamic_prompt):
    tokenizer = get_tokenizer()
    test_data_list = read_data_for_prompting(tokenizer, test_data_path, prompt_file, module, \
                                             num_prompt_examples, dynamic_prompt)
    test_dataset = KnwlDialoDataset(test_data_list, max_seq_len,
                                    pad_id=tokenizer.pad_id, eod_id=tokenizer.eod_id)
    return test_dataset
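To make the label-masking convention in KnwlDialoDataset.__getitem__ above concrete, here is a minimal stand-alone sketch with made-up token ids (in the real dataset the ids come from the Megatron tokenizer):

import numpy as np

input_ids = [11, 12, 13]        # hypothetical context tokens
output_ids = [21, 22]           # hypothetical response tokens
pad_id, eod_id = 0, 50256       # hypothetical special ids
max_seq_len = 8

text = input_ids + output_ids + [eod_id]
# Loss is applied only to the response tokens plus the EOD token; the mask is
# one element shorter than text because the targets are the next-token shifts.
loss_mask = [0] * (len(input_ids) - 1) + [1] * (len(output_ids) + 1)

text_len = len(text)
if text_len > max_seq_len + 1:
    text = text[:max_seq_len + 1]
    loss_mask = loss_mask[:max_seq_len]
else:
    text += [pad_id] * (max_seq_len + 1 - text_len)
    loss_mask += [0] * (max_seq_len + 1 - text_len)

print(np.array(text))       # [   11    12    13    21    22 50256     0     0     0]
print(np.array(loss_mask))  # [0 0 1 1 1 0 0 0]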
tasks/dialctrl/evaluate.py → tasks/knwl_dialo/evaluate.py  (+32 -16)

@@ -7,9 +7,13 @@ from megatron.training import evaluate_and_print_results
 from megatron.training import setup_model_and_optimizer
 from megatron.checkpointing import load_checkpoint
 from tasks.finetune_utils import build_data_loader
-from tasks.dialctrl.data import build_test_dataset
-from tasks.dialctrl.finetune import model_provider, process_batch, loss_func, forward_step
-from tasks.dialctrl.metrics import F1Metric
+from tasks.knwl_dialo.data import build_test_dataset
+from tasks.knwl_dialo.data import build_test_dataset_for_prompting
+from tasks.knwl_dialo.finetune import model_provider
+from tasks.knwl_dialo.finetune import process_batch
+from tasks.knwl_dialo.finetune import loss_func
+from tasks.knwl_dialo.finetune import forward_step
+from tasks.knwl_dialo.metrics import F1Metric
 from tqdm import tqdm


 def test_dataset_provider():

@@ -18,15 +22,27 @@ def test_dataset_provider():
     print_rank_0('> building the test dataset for %s module ...' \
         % args.train_module)
-    test_ds = build_test_dataset(
-        test_data_path=args.test_data_path,
-        train_module=args.train_module,
-        max_seq_len=args.max_seq_len,
-        last_turn=args.last_turn,
-        no_control_code=args.no_control_code,
-        add_separator=args.add_separator,
-        add_ctrl_code_to_dialog=args.add_ctrl_code_to_dialog,
-        remove_ctrl_sent=args.remove_ctrl_sent)
+    if args.eval_prompting:
+        print_rank_0('> evaluating ppl for prompting')
+        test_ds = build_test_dataset_for_prompting(
+            test_data_path=args.test_data_path,
+            prompt_file=args.prompt_file,
+            train_module=args.train_module,
+            max_seq_len=args.max_seq_len,
+            num_prompt_examples=args.num_prompt_examples,
+            three_turns=args.three_turns,
+            dynamic_prompt=args.dynamic_prompt)
+    else:
+        test_ds = build_test_dataset(
+            test_data_path=args.test_data_path,
+            train_module=args.train_module,
+            max_seq_len=args.max_seq_len,
+            last_turn=args.last_turn,
+            no_control_code=args.no_control_code,
+            add_separator=args.add_separator,
+            add_ctrl_code_to_dialog=args.add_ctrl_code_to_dialog,
+            remove_ctrl_sent=args.remove_ctrl_sent)

     print_rank_0("> finished creating the test dataset for %s module ..." \
         % args.train_module)

@@ -93,7 +109,7 @@ def evaluate_ppl(test_dataset_provider, model_provider, forward_step):
     print_rank_0('done :-)')


-def evaluate_f1(guess_file, answer_file, remove_stopwords):
+def evaluate_f1(guess_file, answer_file):

     guess_list = []
     print_rank_0('reading %s' % guess_file)

@@ -116,7 +132,7 @@ def evaluate_f1(guess_file, answer_file):
     assert len(guess_list) == len(answer_list), \
         "lengths of guess and answer are different!"

-    precision, recall, f1 = F1Metric.compute_all_pairs(guess_list, answer_list, remove_stopwords)
+    precision, recall, f1 = F1Metric.compute_all_pairs(guess_list, answer_list)
     print_rank_0('Precision: %.4f; recall: %.4f; f1: %.4f' % (precision, recall, f1))

     print_rank_0('done :-)')

@@ -124,10 +140,10 @@ def evaluate_f1(guess_file, answer_file):
 def main():
     args = get_args()

     if 'ppl' in args.task:
         evaluate_ppl(test_dataset_provider, model_provider, forward_step)
     elif 'f1' in args.task:
-        evaluate_f1(args.guess_file, args.answer_file, args.remove_stopwords)
+        evaluate_f1(args.guess_file, args.answer_file)
tasks/dialctrl/finetune.py → tasks/knwl_dialo/finetune.py  (+213 -0)

-"""Controllable Dialogue Finetuning"""
+"""Dialogue Finetuning"""

 import torch
 from functools import partial
+from megatron import mpu
 from megatron import get_args
 from megatron import get_timers
 from megatron import print_rank_0
 from megatron import get_tokenizer
-from megatron import mpu
 from megatron.model import GPTModel
 from megatron.training import evaluate_and_print_results
+from megatron.training import get_model
 from megatron.utils import average_losses_across_data_parallel_group
+from megatron.initialize import initialize_megatron
 from tasks.finetune_utils import finetune
-from tasks.dialctrl.data import build_train_valid_datasets
-from tasks.dialctrl.utils import get_ltor_attention_masks_and_position_ids
+from tasks.knwl_dialo.data import build_train_valid_datasets
+from tasks.knwl_dialo.utils import get_ltor_attention_masks_and_position_ids
+from tasks.knwl_dialo.utils import get_token_stream


 def model_provider(pre_process=True, post_process=True):

@@ -113,8 +116,98 @@ def forward_step(batch, model):
     return output_tensor, partial(loss_func, loss_mask)


+def generate_samples_input_from_file(model):
+
+    args = get_args()
+    tokenizer = get_tokenizer()
+
+    # Read the sample file and open the output file.
+    assert args.sample_input_file is not None, \
+        'sample input file is not provided.'
+    if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0:
+        fname = open(args.sample_input_file, "r")
+        all_raw_text = fname.readlines()
+        input_count = len(all_raw_text)
+        input_pos = 0
+        if args.sample_output_file is None:
+            sample_output_file = args.sample_input_file + ".out"
+            print('`sample-output-file` not specified, setting '
+                  'it to {}'.format(sample_output_file))
+        else:
+            sample_output_file = args.sample_output_file
+
+        fname_out = open(sample_output_file, "w")
+
+    context_count = 0
+    model.eval()
+    with torch.no_grad():
+        while True:
+            raw_text_len = 0
+            if mpu.is_pipeline_first_stage() \
+               and mpu.get_tensor_model_parallel_rank() == 0:
+                raw_text = all_raw_text[input_pos]
+                input_pos += 1
+                raw_text_len = len(raw_text)
+                context_tokens = tokenizer.tokenize(raw_text)
+            else:
+                context_tokens = tokenizer.tokenize("EMPTY TEXT")
+
+            if input_pos % 100 == 0:
+                print_rank_0("input_pos: %d" % input_pos)
+
+            token_stream = get_token_stream(model, [context_tokens])
+            for _, decode_tokens in enumerate(token_stream):
+                pass
+
+            if mpu.get_tensor_model_parallel_rank() == 0:
+                if mpu.is_pipeline_first_stage():
+                    decode_tokens, _ = decode_tokens
+                    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
+                    trim_decode_tokens = tokenizer.detokenize(decode_tokens)[raw_text_len:]
+
+                    if "\r" in trim_decode_tokens:
+                        trim_decode_tokens = trim_decode_tokens.replace("\r", "")
+                    if "\n" in trim_decode_tokens:
+                        trim_decode_tokens = trim_decode_tokens.replace("\n", "")
+                    fname_out.write(trim_decode_tokens)
+                    fname_out.write("\n")
+
+            raw_text = None
+            context_count += 1
+            if input_pos == input_count:
+                return
+
+
+def run_generation(model_provider):
+
+    args = get_args()
+    if args.num_layers_per_virtual_pipeline_stage is not None:
+        print("Interleaved pipeline schedule is not yet supported for text generation.")
+        exit()
+
+    # Set up model and load checkpoint.
+    model = get_model(model_provider)
+    if args.load is not None:
+        _ = load_checkpoint(model, None, None)
+
+    assert len(model) == 1, "Above condition should have caught this"
+    model = model[0]
+
+    generate_samples_input_from_file(model)
+
+
 def main():
+    args = get_args()

-    finetune(train_valid_datasets_provider, model_provider, \
-             forward_step=forward_step)
+    if "finetune" in args.task:
+        finetune(train_valid_datasets_provider, model_provider, \
+                 forward_step=forward_step)
+    else:
+        # generate
+        run_generation(model_provider)
tasks/dialctrl/metrics.py → tasks/knwl_dialo/metrics.py  (+3 -12)

@@ -61,7 +61,7 @@ class F1Metric:
         return precision, recall, f1

     @staticmethod
-    def compute_each_pair(guess: str, answer: str, rm_sw: bool):
+    def compute_each_pair(guess: str, answer: str):
         if answer == "":
             return None, None, None
         if guess == "":

@@ -69,26 +69,17 @@ class F1Metric:
         g_tokens = normalize_answer(guess).split()
         a_tokens = normalize_answer(answer).split()
-        if rm_sw:
-            g_tokens = remove_stopwords(g_tokens)
-            a_tokens = remove_stopwords(a_tokens)
-            if len(a_tokens) == 0:
-                return None, None, None
-            if len(g_tokens) == 0:
-                return 0, 0, 0
         precision, recall, f1 = F1Metric._prec_recall_f1_score(g_tokens, a_tokens)

         return precision, recall, f1

     @staticmethod
-    def compute_all_pairs(guesses: List[str], answers: List[str], rm_sw=False):
+    def compute_all_pairs(guesses: List[str], answers: List[str]):
         # additional augment:
-        # rm_sw: whether to remove stopwords
         assert len(guesses) == len(answers)

         precision_list, recall_list, f1_list = [], [], []
         for guess, answer in zip(guesses, answers):
-            precision, recall, f1 = F1Metric.compute_each_pair(guess, answer, rm_sw)
+            precision, recall, f1 = F1Metric.compute_each_pair(guess, answer)
             if precision is None or recall is None or f1 is None:
                 continue
             precision_list.append(precision)
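The helper F1Metric._prec_recall_f1_score is not touched by this commit and its body is not shown here; for readers unfamiliar with the metric, the sketch below shows the standard token-overlap precision/recall/F1 that such a helper typically computes (an assumption about the implementation, not a copy of it):

from collections import Counter

def prec_recall_f1(guess_tokens, answer_tokens):
    # Overlap between the guess and the reference answer, counted with multiplicity.
    common = Counter(guess_tokens) & Counter(answer_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0, 0.0, 0.0
    precision = num_same / len(guess_tokens)
    recall = num_same / len(answer_tokens)
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1

print(prec_recall_f1("the blue whale".split(), "a blue whale".split()))
# (0.6666..., 0.6666..., 0.6666...)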
tasks/knwl_dialo/prompt.py  (new file, 0 → 100644, +174)

import json
import torch
from nltk import word_tokenize

from megatron import mpu
from megatron import get_args
from megatron import print_rank_0
from megatron import get_tokenizer
from megatron.model import GPTModel
from megatron.training import get_model
from megatron.checkpointing import load_checkpoint
from megatron.initialize import initialize_megatron
from tasks.knwl_dialo.utils import get_token_stream


def model_provider(pre_process=True, post_process=True):
    """Build the model."""

    print_rank_0('building GPT model ...')
    model = GPTModel(num_tokentypes=0, parallel_output=True,
                     pre_process=pre_process, post_process=post_process)
    return model


def generate_samples_by_prompting_input_from_file(model):

    args = get_args()
    tokenizer = get_tokenizer()

    # Read the sample file and open the output file.
    assert args.sample_input_file is not None, \
        'sample input file is not provided.'
    if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0:
        fname = open(args.sample_input_file, "r")
        all_raw_text = fname.readlines()
        input_count = len(all_raw_text)
        input_pos = 0
        if args.sample_output_file is None:
            sample_output_file = args.sample_input_file + ".out"
            print('`sample-output-file` not specified, setting '
                  'it to {}'.format(sample_output_file))
        else:
            sample_output_file = args.sample_output_file

        fname_out = open(sample_output_file, "w")

    # Read the prompt file
    if args.dynamic_prompt:
        prompt_examples_dict = {}
        with open(args.prompt_file, "r") as f:
            for i, line in enumerate(f):
                line = line.strip()
                line_dict = json.loads(line)
                key = list(line_dict.keys())[0]
                if key not in prompt_examples_dict:
                    prompt_examples = line_dict[key]
                    prompt = ""
                    for instance in prompt_examples:
                        instance = instance.strip()
                        prompt += instance + "\n"
                    prompt_examples_dict[key] = prompt
    else:
        with open(args.prompt_file, "r") as f:
            prompt_examples = f.readlines()
            prompt_examples = prompt_examples[:args.num_prompt_examples]
            prompt = ""
            for instance in prompt_examples:
                instance = instance.strip()
                prompt += instance + "\n"

    assert args.prompt_type in ["knowledge", "response"]

    context_count = 0
    model.eval()
    with torch.no_grad():
        while True:
            raw_text_len = 0
            if mpu.is_pipeline_first_stage() \
               and mpu.get_tensor_model_parallel_rank() == 0:
                input_str = all_raw_text[input_pos]
                input_str = input_str.strip()
                splits = input_str.split("\t")
                control_codes = splits[0].split(" [CTRL] ")
                topic = control_codes[0]

                if args.dynamic_prompt:
                    turns = splits[1].split(" [SEP] ")
                    last_turn = turns[-1]
                    key = topic + " " + last_turn
                    raw_text = prompt_examples_dict[key]
                else:
                    raw_text = prompt

                if args.prompt_type == "knowledge":
                    turns = splits[1].split(" [SEP] ")
                    context = turns[-1]
                    raw_text += "( " + context + " ) " + topic + " =>"
                else:
                    # args.prompt_type == "response":
                    turns = splits[1].split(" [SEP] ")
                    knowledge = splits[2]
                    knowledge = " ".join(word_tokenize(knowledge))
                    last_turn = turns[-1]
                    knowledge = knowledge.strip()
                    last_turn = last_turn.strip()
                    raw_text += "Topic: " + topic + ". "
                    raw_text += "User says: " + last_turn + " "
                    raw_text += "We know that: " + knowledge + " "
                    raw_text += "System replies:"

                input_pos += 1
                raw_text_len = len(raw_text)
                context_tokens = tokenizer.tokenize(raw_text)
            else:
                context_tokens = tokenizer.tokenize("EMPTY TEXT")

            if input_pos % 100 == 0:
                print_rank_0("input_pos: %d" % input_pos)

            token_stream = get_token_stream(model, [context_tokens])
            for _, decode_tokens in enumerate(token_stream):
                pass

            if mpu.get_tensor_model_parallel_rank() == 0:
                if mpu.is_pipeline_first_stage():
                    decode_tokens, _ = decode_tokens
                    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
                    trim_decode_tokens = tokenizer.detokenize(decode_tokens)[raw_text_len:]

                    generated_output = trim_decode_tokens.split("\n")[0]
                    generated_output = generated_output.strip()
                    fname_out.write(generated_output)
                    fname_out.write("\n")

            raw_text = None
            context_count += 1
            if input_pos == input_count:
                return


def main():

    args = get_args()
    if args.num_layers_per_virtual_pipeline_stage is not None:
        print("Interleaved pipeline schedule is not yet supported for text generation.")
        exit()

    # Set up model and load checkpoint.
    model = get_model(model_provider)
    if args.load is not None:
        _ = load_checkpoint(model, None, None)

    assert len(model) == 1, "Above condition should have caught this"
    model = model[0]

    generate_samples_by_prompting_input_from_file(model)
tasks/knwl_dialo/utils.py  (new file, 0 → 100644, +225)

import torch

from megatron import mpu
from megatron import get_args
from megatron import get_tokenizer
from megatron.utils import get_ltor_masks_and_position_ids, unwrap_model
from megatron.p2p_communication import recv_forward, send_forward
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module


def get_ltor_attention_masks_and_position_ids(data, eod_token_id):
    """Build attention masks and position id for left to right model."""

    micro_batch_size, seq_length = data.size()

    # Attention mask
    attention_mask = torch.tril(torch.ones(
        (micro_batch_size, seq_length, seq_length), device=data.device)).view(
            micro_batch_size, 1, seq_length, seq_length)

    # mask padded tokens
    for b in range(micro_batch_size):
        for idx in range(seq_length - 1):
            if data[b, idx] == eod_token_id:
                # pad tokens that come after the eod token
                attention_mask[b, 0, idx + 1:, :] = 0.0

    # Position ids.
    position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)
    position_ids = position_ids.unsqueeze(0).expand_as(data)

    # Convert attention mask to binary:
    attention_mask = (attention_mask < 0.5)

    return attention_mask, position_ids


def switch(val1, val2, boolean):
    boolean = boolean.type_as(val1)
    return (1 - boolean) * val1 + boolean * val2


def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
                 layer_past=None, get_key_value=None,
                 forward_method_parallel_output=None):

    # functions the correct size
    args = get_args()
    orig_seq_length = args.seq_length
    args.seq_length = tokens.shape[1]
    input_tensor = recv_forward()

    # Forward pass through the model.
    unwrapped_model = unwrap_model(model, (torchDDP, LocalDDP, Float16Module))
    unwrapped_model.set_input_tensor(input_tensor)
    output_tensor = model(tokens, position_ids, attention_mask,
                          tokentype_ids=tokentype_ids,
                          layer_past=layer_past,
                          get_key_value=get_key_value,
                          forward_method_parallel_output=forward_method_parallel_output)

    if get_key_value:
        output_tensor, layer_past = output_tensor

    send_forward(output_tensor)
    args.seq_length = orig_seq_length

    if get_key_value:
        return output_tensor, layer_past
    return output_tensor


def pad_batch(batch, pad_id, args):
    context_lengths = []
    for tokens in batch:
        context_length = len(tokens)
        if context_length < args.seq_length:
            tokens.extend([pad_id] * (args.seq_length - context_length))
        context_lengths.append(context_length)
    return batch, context_lengths


def get_batch(context_tokens):
    """Generate batch from context tokens."""
    args = get_args()
    tokenizer = get_tokenizer()

    # Move to GPU.
    tokens = context_tokens.view(args.micro_batch_size, -1).contiguous().cuda()

    # Get the attention mask and postition ids.
    attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
        tokens,
        tokenizer.eod,
        args.reset_position_ids,
        args.reset_attention_mask,
        args.eod_mask_loss)

    return tokens, attention_mask, position_ids


def sample_sequence_batch(model, context_tokens, context_lengths,
                          attention_mask, position_ids,
                          maxlen=None, type_ids=None):

    args = get_args()
    tokenizer = get_tokenizer()

    model.eval()
    with torch.no_grad():
        context_length = context_lengths.min().item()

        # added eos_id to support the function generate_samples_eval that passes
        # eos_id as an argument and needs termination when that id id found.
        if hasattr(args, 'eos_id'):
            eos_id = args.eos_id
        else:
            eos_id = tokenizer.eod

        counter = 0
        org_context_length = context_length

        layer_past = None
        batch_size = context_tokens.size(0)
        is_done = torch.zeros([batch_size]).byte().cuda()
        tokens = context_tokens
        if maxlen is None:
            maxlen = args.seq_length - 1
            if maxlen > (org_context_length + args.out_seq_length):
                maxlen = org_context_length + args.out_seq_length

        lengths = torch.ones([batch_size]).long().cuda() * maxlen

        while context_length <= (maxlen):
            output = forward_step(model, tokens, position_ids, attention_mask,
                                  tokentype_ids=type_ids,
                                  forward_method_parallel_output=False)

            if mpu.is_pipeline_last_stage():
                assert output is not None
                logits = output[:, context_length - 1, :]

            if mpu.is_pipeline_last_stage():
                prev = torch.argmax(logits, dim=-1).view(-1)

                started = context_lengths <= context_length

                new_tokens = switch(tokens[:, context_length].view(-1), prev, started)
                tokens[:, context_length] = new_tokens
                src = mpu.get_pipeline_model_parallel_last_rank()
                group = mpu.get_embedding_group()
                torch.distributed.broadcast(new_tokens, src, group)

                done_token = (prev == eos_id).byte() & started.byte()
                just_finished = (done_token & ~is_done).bool()
                lengths[just_finished.view(-1)] = context_length
                is_done = is_done | done_token

                done = torch.all(is_done)
                src = mpu.get_pipeline_model_parallel_last_rank()
                group = mpu.get_pipeline_model_parallel_group()
                torch.distributed.broadcast(done, src, group)
                yield tokens, lengths

            else:
                if mpu.is_pipeline_first_stage():
                    src = mpu.get_pipeline_model_parallel_last_rank()
                    group = mpu.get_embedding_group()
                    new_tokens = torch.empty_like(tokens[:, context_length])
                    torch.distributed.broadcast(new_tokens, src, group)
                    tokens[:, context_length] = new_tokens
                    yield tokens, None
                else:
                    yield None, None

                done = torch.cuda.ByteTensor([0])
                src = mpu.get_pipeline_model_parallel_last_rank()
                group = mpu.get_pipeline_model_parallel_group()
                torch.distributed.broadcast(done, src, group)

            context_length += 1
            counter += 1
            if done:
                break


def get_token_stream(model, context_tokens):

    args = get_args()
    tokenizer = get_tokenizer()

    context_tokens, context_lengths = pad_batch(context_tokens, tokenizer.eod, args)

    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
    context_length_tensor = torch.cuda.LongTensor(context_lengths)

    torch.distributed.broadcast(context_length_tensor,
                                mpu.get_tensor_model_parallel_src_rank(),
                                group=mpu.get_tensor_model_parallel_group())
    torch.distributed.broadcast(context_tokens_tensor,
                                mpu.get_tensor_model_parallel_src_rank(),
                                group=mpu.get_tensor_model_parallel_group())

    context_length = context_length_tensor.min().item()
    tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)

    batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
                                                 context_length_tensor,
                                                 attention_mask, position_ids)
    for tokens, lengths in batch_token_iterator:
        context_length += 1
        if tokens is not None:
            yield tokens[:, :context_length], lengths
        else:
            yield None, None
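To see what get_ltor_attention_masks_and_position_ids above produces, here is a small self-contained sketch of the same logic on a toy CPU batch (the values are made up; the real function is called on GPU token-id tensors with the tokenizer's EOD id):

import torch

eod = 99
data = torch.tensor([[5, 7, eod, 3, 3]])          # one sequence, padding after EOD
micro_batch_size, seq_length = data.size()

# Causal (lower-triangular) mask, one per sequence.
attention_mask = torch.tril(torch.ones(
    (micro_batch_size, seq_length, seq_length))).view(
        micro_batch_size, 1, seq_length, seq_length)

# Zero out the rows that come after the EOD token so padding attends to nothing.
for b in range(micro_batch_size):
    for idx in range(seq_length - 1):
        if data[b, idx] == eod:
            attention_mask[b, 0, idx + 1:, :] = 0.0

position_ids = torch.arange(seq_length, dtype=torch.long).unsqueeze(0).expand_as(data)

# Megatron convention: True marks positions that are masked out.
attention_mask = (attention_mask < 0.5)
print(attention_mask[0, 0].int())
print(position_ids)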
tasks/main.py  (+26 -32)

@@ -84,9 +84,24 @@ def get_tasks_args(parser):
                        help='Av.rank validation: how many other negatives to'
                        ' take from each question pool')

-    # finetune for controllable dialogue
-    group.add_argument('--train-module', type=str, default="",
-                       help='either control module or dialogue model (control or dialog)')
+    # parameters for the knowledgeable dialogue generation
+    group.add_argument("--out-seq-length", type=int, default=1024,
+                       help='Size of the output generated text.')
+    group.add_argument("--sample-input-file", type=str, default=None,
+                       help='Get input from file instead of interactive mode, '
+                       'each line is an input.')
+    group.add_argument("--sample-output-file", type=str, default=None,
+                       help='Output file got from --sample-input-file')
+    group.add_argument('--prompt-file', type=str, default="",
+                       help='prompting file')
+    group.add_argument('--prompt-type', type=str, default="",
+                       help='prompt type (knowledge or response)')
+    group.add_argument('--num-prompt-examples', type=int, default=10,
+                       help='number of prompt examples')
+    group.add_argument('--dynamic-prompt', action='store_true', default=False,
+                       help='using different prompts for different test samples')
+    group.add_argument('--module', type=str, default="",
+                       help='either knowledge generation (knowledge) or response generation (response)')
     group.add_argument('--train-data-path', type=str, default="",
                        help='datapath for training set')
     group.add_argument('--test-data-path', type=str, default="",

@@ -99,29 +114,8 @@ def get_tasks_args(parser):
                        help='maximum sequence length')
     group.add_argument('--spec-toks', type=str, default=None,
                        help='additional special tokens')
-    group.add_argument('--last-turn', action='store_true',
-                       help='only use last turn for control model')
-    group.add_argument('--no-control-code', action='store_true',
-                       help='removing control code in the training for control model')
-    group.add_argument('--remove-stopwords', action='store_true',
-                       help='removing stopwords when evaluating F1-score')
-    group.add_argument('--add-separator', action='store_true',
-                       help='add separator between turns and add colon before generation')
-    group.add_argument('--add-ctrl-code-to-dialog', action='store_true',
-                       help='add control code in the dialog modeling')
-    group.add_argument('--remove-ctrl-sent', action='store_true',
-                       help='dont use control sentence in dialog modeling')
-
-    # finetune for controllable generation
-    group.add_argument('--wiki-path', type=str, default="",
-                       help='data path for the wikipedia corpus')
-    group.add_argument('--tokenized-path', type=str, default="",
-                       help='data path for the tokenized file')
-    group.add_argument('--prop', type=float, default=1.0,
-                       help='Proportion of data used for training')
-    group.add_argument('--max-instance', type=int, default=10000000,
-                       help='Proportion of data used for training')
+    group.add_argument('--eval-prompting', action='store_true',
+                       help='Whether to evaluate prompting')

     return parser

@@ -146,12 +140,12 @@ if __name__ == '__main__':
         from orqa.evaluate_orqa import main
     elif args.task in ['RET-FINETUNE-NQ']:
         from orqa.supervised.finetune import main
-    elif args.task == 'control-gen':
-        from control_gen.finetune import main
-    elif args.task == 'dialctrl':
-        from dialctrl.finetune import main
-    elif args.task in ['dialctrl-eval-ppl', 'dialctrl-eval-f1']:
-        from dialctrl.evaluate import main
+    elif args.task == 'knwl-dialo-prompt':
+        from knwl_dialo.prompt import main
+    elif args.task == ['knwl-dialo-finetune', 'knwl-dialo-gen']:
+        from knwl_dialo.finetune import main
+    elif args.task in ['knwl-dialo-eval-ppl', 'knwl-dialo-eval-f1']:
+        from knwl_dialo.evaluate import main
     else:
         raise NotImplementedError('Task {} is not implemented.'.format(
             args.task))
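The net effect of the tasks/main.py change is that the knowledgeable-dialogue code is dispatched through the new --task values knwl-dialo-prompt, knwl-dialo-finetune / knwl-dialo-gen and knwl-dialo-eval-ppl / knwl-dialo-eval-f1, together with the prompting flags added above. (Note that the knwl-dialo-finetune/knwl-dialo-gen branch compares args.task to a list with ==, which can never match a string; in was presumably intended.) The snippet below is only a hedged, stand-alone sketch of how the new arguments parse; in the repository they are added to Megatron's shared argument parser rather than a fresh one:

import argparse

parser = argparse.ArgumentParser(description="sketch of the new knwl_dialo task arguments")
group = parser.add_argument_group(title="tasks")
group.add_argument("--task", type=str, default="")
group.add_argument("--out-seq-length", type=int, default=1024)
group.add_argument("--sample-input-file", type=str, default=None)
group.add_argument("--sample-output-file", type=str, default=None)
group.add_argument("--prompt-file", type=str, default="")
group.add_argument("--prompt-type", type=str, default="")
group.add_argument("--num-prompt-examples", type=int, default=10)
group.add_argument("--dynamic-prompt", action="store_true", default=False)
group.add_argument("--module", type=str, default="")

# Hypothetical invocation; the file names are placeholders, not files from this commit.
args = parser.parse_args([
    "--task", "knwl-dialo-prompt",
    "--prompt-type", "knowledge",
    "--sample-input-file", "test.txt",
    "--prompt-file", "prompts.jsonl",
    "--num-prompt-examples", "10",
])
print(args.task, args.prompt_type, args.num_prompt_examples)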
tools/generate_samples_gpt.py  (+6 -7)

@@ -76,6 +76,7 @@ def add_text_generate_args(parser):
                        help='additional special tokens')
     group.add_argument('--line-by-line', action="store_true",
                        help='generate samples line by line')
     group.add_argument('--prompt', action="store_true",
                        help='generate samples based on prompting')
     group.add_argument('--prompt-file', type=str, default="",

@@ -84,6 +85,10 @@ def add_text_generate_args(parser):
                        help='prompt type (context or keyphrase)')
     group.add_argument('--num-prompt-examples', type=int, default=10,
                        help='number of prompt examples')
+    group.add_argument("--noknowledge", action='store_true', default=False,
+                       help='Do not use knowledge in prompting')
+    group.add_argument('--dynamic-prompt', action='store_true', default=False,
+                       help='using different prompts for different test samples')

     return parser

@@ -114,13 +119,7 @@ def main():
     if args.num_samples == 0:
         if args.sample_input_file != None:
             args.micro_batch_size = 1
-            if args.line_by_line:
-                generate_samples_input_from_file(model)
-            else:
-                generate_samples_line_by_line_input_from_file(model)
+            if args.prompt:
+                generate_samples_prompt_input_from_file(model)
+            else:
+                generate_samples_input_from_file(model)
         else:
             generate_samples_interactive(model)
     else: