OpenDAS / Megatron-LM · Commits

Commit ec202742, authored Aug 28, 2021 by zihanl (parent 7b799b7c)

    add recent code and prompting

Showing 10 changed files with 955 additions and 59 deletions.
.gitignore                            +11   -1
megatron/text_generation_utils.py    +357   -0
tasks/dialctrl/data.py               +142  -43
tasks/dialctrl/evaluate.py           +133   -0
tasks/dialctrl/finetune.py            +17   -6
tasks/dialctrl/metrics.py             +98   -0
tasks/dialctrl/utils.py                +3   -1
tasks/main.py                         +36   -6
tools/control_dialog_interactive.py  +136   -0
tools/generate_samples_gpt.py         +22   -2
.gitignore

...
@@ -4,3 +4,13 @@ __pycache__
 build/
 dist/
 *.egg-info/
+tensorboard
+commands/
+commands_new/
+*.log
+logs
+*.so
+*.out
+train_gpt_conv.py
+dialogctrl/
+control_gen/
\ No newline at end of file
megatron/text_generation_utils.py

...
@@ -24,6 +24,7 @@ import torch
 import torch.nn.functional as F

 from megatron import get_args
+from megatron import print_rank_0
 from megatron import get_tokenizer
 from megatron import mpu
 from megatron.utils import get_ltor_masks_and_position_ids, unwrap_model
...
@@ -190,6 +191,362 @@ def generate_samples_input_from_file(model):
             raw_text = None
             context_count += 1

+
+def generate_samples_line_by_line_input_from_file(model):
+
+    args = get_args()
+    tokenizer = get_tokenizer()
+
+    # Read the sample file and open the output file.
+    assert args.sample_input_file is not None, \
+        'sample input file is not provided.'
+    if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0:
+        fname = open(args.sample_input_file, "r")
+        all_raw_text = fname.readlines()
+        input_count = len(all_raw_text)
+        input_pos = 0
+        if args.sample_output_file is None:
+            sample_output_file = args.sample_input_file + ".out"
+            print('`sample-output-file` not specified, setting '
+                  'it to {}'.format(sample_output_file))
+        else:
+            sample_output_file = args.sample_output_file
+        fname_out = open(sample_output_file, "w")
+
+    context_count = 0
+    model.eval()
+    with torch.no_grad():
+        while True:
+            raw_text_len = 0
+            if mpu.is_pipeline_first_stage() \
+               and mpu.get_tensor_model_parallel_rank() == 0:
+                raw_text = all_raw_text[input_pos]
+                input_pos += 1
+                raw_text_len = len(raw_text)
+                context_tokens = tokenizer.tokenize(raw_text)
+            else:
+                context_tokens = tokenizer.tokenize("EMPTY TEXT")
+
+            if input_pos % 100 == 0:
+                print_rank_0("input_pos: %d" % input_pos)
+
+            token_stream = get_token_stream(model, [context_tokens])
+            for _, decode_tokens in enumerate(token_stream):
+                pass
+
+            if mpu.get_tensor_model_parallel_rank() == 0:
+                if mpu.is_pipeline_first_stage():
+                    decode_tokens, _ = decode_tokens
+                    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
+                    trim_decode_tokens = tokenizer.detokenize(
+                        decode_tokens)[raw_text_len:]
+                    if "\r" in trim_decode_tokens:
+                        trim_decode_tokens = trim_decode_tokens.replace("\r", "")
+                    if "\n" in trim_decode_tokens:
+                        trim_decode_tokens = trim_decode_tokens.replace("\n", "")
+                    fname_out.write(trim_decode_tokens)
+                    fname_out.write("\n")
+
+            raw_text = None
+            context_count += 1
+            if input_pos == input_count:
+                return
+
+def generate_samples_prompt_input_from_file(model):
+
+    args = get_args()
+    tokenizer = get_tokenizer()
+
+    # Read the sample file and open the output file.
+    assert args.sample_input_file is not None, \
+        'sample input file is not provided.'
+    if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0:
+        fname = open(args.sample_input_file, "r")
+        all_raw_text = fname.readlines()
+        input_count = len(all_raw_text)
+        input_pos = 0
+        if args.sample_output_file is None:
+            sample_output_file = args.sample_input_file + ".out"
+            print('`sample-output-file` not specified, setting '
+                  'it to {}'.format(sample_output_file))
+        else:
+            sample_output_file = args.sample_output_file
+        fname_out = open(sample_output_file, "w")
+
+        # Read the prompt file
+        with open(args.prompt_file, "r") as f:
+            prompt_examples = f.readlines()
+        prompt_examples = prompt_examples[:args.num_prompt_examples]
+        prompt = ""
+        for instance in prompt_examples:
+            instance = instance.strip()
+            prompt += instance + "\n"
+
+    assert args.prompt_type in ["context", "keyphrase"]
+
+    context_count = 0
+    model.eval()
+    with torch.no_grad():
+        while True:
+            raw_text_len = 0
+            if mpu.is_pipeline_first_stage() \
+               and mpu.get_tensor_model_parallel_rank() == 0:
+                input_str = all_raw_text[input_pos]
+                input_str = input_str.strip()
+                splits = input_str.split("\t")
+                control_codes = splits[0].split(" [CTRL] ")
+                topic = control_codes[0]
+
+                raw_text = prompt
+                if args.prompt_type == "context":
+                    turns = splits[1].split(" [SEP] ")
+                    context = turns[-1]
+                    raw_text += "( " + context + " ) " + topic + " :"
+                else:
+                    keyphrase_list = control_codes[1:]
+                    for i, keyphrase in enumerate(keyphrase_list):
+                        if i == 0:
+                            raw_text += "( "
+                        else:
+                            raw_text += "; "
+                        raw_text += keyphrase
+                    if len(keyphrase_list) > 0:
+                        raw_text += " ) "
+                    raw_text += topic + " :"
+
+                input_pos += 1
+                raw_text_len = len(raw_text)
+                context_tokens = tokenizer.tokenize(raw_text)
+            else:
+                context_tokens = tokenizer.tokenize("EMPTY TEXT")
+
+            if input_pos % 100 == 0:
+                print_rank_0("input_pos: %d" % input_pos)
+
+            token_stream = get_token_stream(model, [context_tokens])
+            for _, decode_tokens in enumerate(token_stream):
+                pass
+
+            if mpu.get_tensor_model_parallel_rank() == 0:
+                if mpu.is_pipeline_first_stage():
+                    decode_tokens, _ = decode_tokens
+                    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
+                    trim_decode_tokens = tokenizer.detokenize(
+                        decode_tokens)[raw_text_len:]
+                    generated_output = trim_decode_tokens.split("\n")[0]
+                    generated_output = generated_output.strip()
+                    fname_out.write(generated_output)
+                    fname_out.write("\n")
+
+            raw_text = None
+            context_count += 1
+            if input_pos == input_count:
+                return
+
+def dialog_with_gpt_control_interactive(conv_model, ctrl_model, add_separator):
+
+    args = get_args()
+    tokenizer = get_tokenizer()
+
+    conv_model.eval()
+    ctrl_model.eval()
+    dialog_history = []
+    with torch.no_grad():
+        while True:
+            ctrl_input_text_len = 0
+            if mpu.is_pipeline_first_stage() \
+               and mpu.get_tensor_model_parallel_rank() == 0:
+                # input @@ to separate the control code and current turn
+                input_text = input(">>> ")
+                while not input_text:
+                    print("Input should not be empty!")
+                    input_text = input(">>> ")
+                assert " @@ " in input_text, "Please input with a correct template"
+                splits = input_text.split(" @@ ")
+                ctrl_code = splits[0]
+                curr_turn = splits[1]
+
+                prev_two_turns = ""
+                if add_separator:
+                    for i, turn in enumerate(dialog_history[-2:]):
+                        if i == 0:
+                            prev_two_turns = "<< " + turn + " >>"
+                        else:
+                            prev_two_turns += " "
+                            prev_two_turns += "<< " + turn + " >>"
+                else:
+                    prev_two_turns = " ".join(dialog_history[-2:])
+                dialog_history.append(curr_turn)
+                print("\nHistory:", prev_two_turns)
+                print("User:", curr_turn)
+
+                if add_separator:
+                    curr_turn = "<< " + curr_turn + " >>"
+                if prev_two_turns != "":
+                    dialog_context = prev_two_turns + " " + curr_turn
+                else:
+                    dialog_context = curr_turn
+
+                ctrl_input = ctrl_code + " " + dialog_context
+                if add_separator:
+                    ctrl_input += " :"
+                ctrl_input_text_len = len(ctrl_input)
+                ctrl_context_tokens = tokenizer.tokenize(ctrl_input)
+            else:
+                ctrl_context_tokens = tokenizer.tokenize("EMPTY TEXT")
+
+            token_stream = get_token_stream(ctrl_model, [ctrl_context_tokens])
+            for _, decode_tokens in enumerate(token_stream):
+                pass
+            if mpu.get_tensor_model_parallel_rank() == 0:
+                if mpu.is_pipeline_first_stage():
+                    decode_tokens, _ = decode_tokens
+                    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
+                    control_sent = tokenizer.detokenize(
+                        decode_tokens)[ctrl_input_text_len:]
+                    control_sent = control_sent.replace("<|endoftext|>", "")
+                    print("\nControl Sentence:", control_sent)
+
+                    if control_sent != "":
+                        control_sent = "( " + control_sent + " )"
+                        conv_input = control_sent + " " + dialog_context
+                    else:
+                        conv_input = dialog_context
+                    conv_input_text_len = len(conv_input)
+                    conv_context_tokens = tokenizer.tokenize(conv_input)
+
+            token_stream = get_token_stream(conv_model, [conv_context_tokens])
+            for _, decode_tokens in enumerate(token_stream):
+                pass
+            if mpu.get_tensor_model_parallel_rank() == 0:
+                if mpu.is_pipeline_first_stage():
+                    decode_tokens, _ = decode_tokens
+                    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
+                    response = tokenizer.detokenize(
+                        decode_tokens)[conv_input_text_len:]
+                    response = response.replace("<|endoftext|>", "")
+                    print("\nChatbot:", response)
+                    dialog_history.append(response)
+
+def dialog_with_dpr_control_interactive(conv_model, ctrl_model, ctrl_tokenizer,
+                                        knowledge_corpus, knowledge_corpus_emb,
+                                        add_separator):
+
+    args = get_args()
+    tokenizer = get_tokenizer()
+
+    conv_model.eval()
+    ctrl_model.eval()
+    dialog_history = []
+    with torch.no_grad():
+        while True:
+            input_text = input(">>> ")
+            while not input_text:
+                print("Input should not be empty!")
+                input_text = input(">>> ")
+            assert " @@ " in input_text, "Please input with a correct template"
+            splits = input_text.split(" @@ ")
+            ctrl_code = splits[0]
+            curr_turn = splits[1]
+
+            prev_two_turns = " ".join(dialog_history[-2:])
+            prev_two_turns_v2 = ""
+            if add_separator:
+                for i, turn in enumerate(dialog_history[-2:]):
+                    if i == 0:
+                        prev_two_turns_v2 = "<< " + turn + " >>"
+                    else:
+                        prev_two_turns_v2 += " "
+                        prev_two_turns_v2 += "<< " + turn + " >>"
+            else:
+                prev_two_turns_v2 = prev_two_turns
+            dialog_history.append(curr_turn)
+            print("\nHistory:", prev_two_turns_v2)
+            print("\nUser:", curr_turn)
+
+            if prev_two_turns != "":
+                dialog_context = prev_two_turns + " " + curr_turn
+            else:
+                dialog_context = curr_turn
+            if add_separator:
+                curr_turn = "<< " + curr_turn + " >>"
+                dialog_context_v2 = prev_two_turns_v2 + curr_turn
+            else:
+                dialog_context_v2 = dialog_context
+
+            ctrl_input = ctrl_code + " " + dialog_context
+            ctrl_input_ids = ctrl_tokenizer.encode(ctrl_input)
+            ctrl_input_ids = torch.LongTensor([ctrl_input_ids]).cuda()
+            attn_masks = torch.ones(1, ctrl_input_ids.size()[-1]).cuda()
+            query_emb = ctrl_model(input_ids=ctrl_input_ids,
+                                   attention_mask=attn_masks).pooler_output
+            # (1, 768)
+            logits = knowledge_corpus_emb.matmul(query_emb[0])
+            retrieved_idx = torch.argmax(logits).item()
+            control_sent = knowledge_corpus[retrieved_idx].strip()
+            print("\nControl Sentence:", control_sent)
+
+            if control_sent != "":
+                control_sent = "( " + control_sent + " )"
+                conv_input = control_sent + " " + dialog_context_v2
+            else:
+                conv_input = dialog_context_v2
+            conv_input_text_len = len(conv_input)
+            conv_context_tokens = tokenizer.tokenize(conv_input)
+
+            token_stream = get_token_stream(conv_model, [conv_context_tokens])
+            for _, decode_tokens in enumerate(token_stream):
+                pass
+            if mpu.get_tensor_model_parallel_rank() == 0:
+                if mpu.is_pipeline_first_stage():
+                    decode_tokens, _ = decode_tokens
+                    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
+                    response = tokenizer.detokenize(
+                        decode_tokens)[conv_input_text_len:]
+                    response = response.replace("<|endoftext|>", "")
+                    print("\nChatbot:", response)
+                    dialog_history.append(response)

 # We added this function to support the tasks evaluation such as squad
 # and drop in the https://github.com/EleutherAI/lm-evaluation-harness
 # codebase. The lm-evaluation-harness code can now call this function
...
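
For reference, the generators above parse tab-separated input lines into prompt queries; a minimal sketch with invented example values (only the " [CTRL] " and " [SEP] " separators come from the code):

    # Hypothetical input line: control codes, then dialog turns.
    input_str = "jazz [CTRL] saxophone [CTRL] improvisation\thi there [SEP] what do you like?"
    splits = input_str.split("\t")
    control_codes = splits[0].split(" [CTRL] ")
    topic = control_codes[0]                    # "jazz"
    last_turn = splits[1].split(" [SEP] ")[-1]  # "what do you like?"
    # prompt_type "context" appends:   "( what do you like? ) jazz :"
    # prompt_type "keyphrase" appends: "( saxophone; improvisation ) jazz :"

The interactive loops instead expect each console line to carry a control code and the current user turn separated by " @@ " (made-up input below):

    input_text = "jazz @@ tell me about saxophones"
    ctrl_code, curr_turn = input_text.split(" @@ ")
    # with add_separator, turns are wrapped as "<< turn >>" and the
    # control-model query ends with " :"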
tasks/dialctrl/data.py

...
@@ -32,20 +32,15 @@ def read_data(tokenizer, data_path, train_module):
                 turns = turns[-3:]

                 # input_ids
+                input_ids = []
+                if length_split > 2:
+                    input_ids.extend(tokenizer.tokenize("( " + ctrl_sent + " )"))
                 for idx, turn in enumerate(turns):
                     if not (turn.endswith("?") or turn.endswith(".") or turn.endswith("!")):
                         turn = turn + " ."
-                    if idx == 0:
-                        input_ids = tokenizer.tokenize(turn)
-                    else:
-                        # input_ids.extend([tokenizer.sep_id] + tokenizer.tokenize(turn))
-                        input_ids.extend(tokenizer.tokenize(turn))
+                    input_ids.extend(tokenizer.tokenize(turn))
-                if length_split > 2:
-                    # when there is control sentence, add it into the input_ids
-                    # input_ids.extend([tokenizer.ctrl_id] + tokenizer.tokenize(ctrl_sent))
-                    input_ids.extend(tokenizer.tokenize("( " + ctrl_sent + " ) ."))

                 # output_ids
                 output_ids = tokenizer.tokenize(response)
...
@@ -59,23 +54,6 @@ def read_data(tokenizer, data_path, train_module):
             ctrl_code = splits[1] if length_split == 4 else None
             turns = dialog_context.split(" [SEP] ")
-            # last_turn = turns[-1]
-            # turns = turns[-3:]
-            # for idx, turn in enumerate(turns):
-            #     if idx == 0:
-            #         input_ids = tokenizer.tokenize(turn)
-            #     else:
-            #         # input_ids.extend([tokenizer.sep_id] + tokenizer.tokenize(turn))
-            #         input_ids.extend(tokenizer.tokenize(turn))
-            # # input_ids
-            # if ctrl_code:
-            #     ctrl_code_list = ctrl_code.split(" [CTRL] ")
-            #     for code in ctrl_code_list:
-            #         # input_ids.extend([tokenizer.ctrl_id] + tokenizer.tokenize(code))
-            #         input_ids.extend(tokenizer.tokenize(code + " ."))
             # put control code at the beginning
             input_ids = []
             if ctrl_code:
...
@@ -96,7 +74,95 @@ def read_data(tokenizer, data_path, train_module):
                 data_list.append({"input_ids": input_ids, "output_ids": output_ids})
             else:
-                raise ValueError("Please input a correct train-module name! (either dialog or control)")
+                raise ValueError("Please input a correct train-module name! "
+                                 "(either dialog or control)")

     return data_list

+
+def read_data_v2(tokenizer, data_path, train_module, last_turn=False,
+                 no_control_code=False, add_separator=False,
+                 add_ctrl_code_to_dialog=False, remove_ctrl_sent=False):
+    """
+    Read and tokenize data for version 2 (v2) data files.
+    Format: control code \t dialog context \t control sentence \t response.
+    Response only comes from the wizard.
+    Currently, this function is used to build test dataset for calculating PPL.
+    """
+    data_list = []
+    with open(data_path, "r") as f:
+        for i, line in enumerate(f):
+            line = line.rstrip()
+            splits = line.split("\t")
+            assert len(splits) == 4
+            control_code = splits[0]
+            dialog_context = splits[1]
+            control_sent = splits[2]
+            response = splits[3]
+            turns = dialog_context.split(" [SEP] ")
+            turns = turns[-3:]
+
+            if train_module == "dialog":
+                # input_ids
+                if add_ctrl_code_to_dialog:
+                    ctrl_code = control_code.split(" [CTRL] ")[0]
+                    input_ids = tokenizer.tokenize("( " + ctrl_code + " )")
+                    if not remove_ctrl_sent and control_sent != "no_passages_used":
+                        input_ids.extend(tokenizer.tokenize(
+                            "( " + control_sent + " )")[:256])
+                else:
+                    if remove_ctrl_sent or control_sent == "no_passages_used":
+                        input_ids = []
+                    else:
+                        input_ids = tokenizer.tokenize(
+                            "( " + control_sent + " )")[:256]
+                for turn in turns:
+                    if add_separator:
+                        turn = "<< " + turn + " >>"
+                    input_ids.extend(tokenizer.tokenize(turn))
+                if add_separator:
+                    input_ids.extend(tokenizer.tokenize(":"))
+
+                # output_ids
+                output_ids = tokenizer.tokenize(response)
+                data_list.append({"input_ids": input_ids, "output_ids": output_ids})
+
+            elif train_module == "control":
+                # skip example without control sentences
+                if control_sent == "no_passages_used":
+                    continue
+                input_ids = []
+                if not no_control_code:
+                    ctrl_code_list = control_code.split(" [CTRL] ")[:3]
+                    # only choose maximum three control codes
+                    for code in ctrl_code_list:
+                        if len(code) > 0:
+                            input_ids.extend(tokenizer.tokenize("( " + code + " )"))
+                if last_turn:
+                    input_ids.extend(tokenizer.tokenize(turns[-1]))
+                else:
+                    for turn in turns:
+                        if add_separator:
+                            turn = "<< " + turn + " >>"
+                        input_ids.extend(tokenizer.tokenize(turn))
+                if add_separator:
+                    input_ids.extend(tokenizer.tokenize(":"))
+
+                output_ids = tokenizer.tokenize(control_sent)
+                data_list.append({"input_ids": input_ids, "output_ids": output_ids})
+
+            else:
+                raise ValueError("Please input a correct train-module name! "
+                                 "(either dialog or control)")
+
+    return data_list
...
@@ -125,7 +191,7 @@ class ControlDialogDataset(torch.utils.data.Dataset):
         data_dict = self.data[idx]
         input_ids, output_ids = data_dict["input_ids"], data_dict["output_ids"]
-        assert len(input_ids) < self.max_seq_len, "Set a larger max-seq-len!"
+        # assert len(input_ids) < self.max_seq_len, "Set a larger max-seq-len!"
         # length_of_loss_mask == length_of_text - 1
         # text = input_ids + [self.sep_id] + output_ids + [self.eod_id]
...
@@ -140,29 +206,62 @@ class ControlDialogDataset(torch.utils.data.Dataset):
         text += [self.pad_id] * (self.max_seq_len + 1 - text_len)
         loss_mask += [0] * (self.max_seq_len + 1 - text_len)
-        return {"text": np.array(text, dtype=np.int64), "loss_mask": np.array(loss_mask, dtype=np.int64)}
+        return {"text": np.array(text, dtype=np.int64), \
+                "loss_mask": np.array(loss_mask, dtype=np.int64)}


-def build_train_valid_test_datasets(data_folder, dataset_name, train_module,
-                                    max_seq_len, seed):
+def build_train_valid_datasets(train_data_path, valid_data_path, train_module,
+                               max_seq_len, seed, last_turn, no_control_code,
+                               add_separator, add_ctrl_code_to_dialog,
+                               remove_ctrl_sent):
     """Build train, valid, and test datasets."""

-    dataname_dict = {"wizard_of_wikipedia": {"train": "train_entity_based_control.txt",
-                                             "valid": "valid_random_split_entity_based_control.txt",
-                                             "test": "test_random_split_entity_based_control.txt"}}
-    train_data_path = os.path.join(data_folder, dataset_name + "/processed/" + dataname_dict[dataset_name]["train"])
-    valid_data_path = os.path.join(data_folder, dataset_name + "/processed/" + dataname_dict[dataset_name]["valid"])
-    test_data_path = os.path.join(data_folder, dataset_name + "/processed/" + dataname_dict[dataset_name]["test"])
+    # dataname_dict = {"wizard_of_wikipedia": {"train": "train_entity_based_control.txt", "valid": "valid_random_split_entity_based_control.txt", "test": "test_random_split_entity_based_control.txt"}}
+    # train_data_path = os.path.join(data_folder, dataset_name+"/processed/"+dataname_dict[dataset_name]["train"])
+    # valid_data_path = os.path.join(data_folder, dataset_name+"/processed/"+dataname_dict[dataset_name]["valid"])
+    # test_data_path = os.path.join(data_folder, dataset_name+"/processed/"+dataname_dict[dataset_name]["test"])

     tokenizer = get_tokenizer()
-    train_data_list = read_data(tokenizer, train_data_path, train_module)
-    valid_data_list = read_data(tokenizer, valid_data_path, train_module)
-    test_data_list = read_data(tokenizer, test_data_path, train_module)
+    # train_data_list = read_data(tokenizer, train_data_path, train_module)
+    train_data_list = read_data_v2(tokenizer, train_data_path, train_module,
+                                   last_turn, no_control_code, add_separator,
+                                   add_ctrl_code_to_dialog, remove_ctrl_sent)
+    valid_data_list = read_data_v2(tokenizer, valid_data_path, train_module,
+                                   last_turn, no_control_code, add_separator,
+                                   add_ctrl_code_to_dialog, remove_ctrl_sent)

     # shuffle the training data
     train_data_list = data_shuffle(train_data_list, seed)

-    # build train, valid, and test datasets
-    train_dataset = ControlDialogDataset(train_data_list, max_seq_len, sep_id=tokenizer.sep_id, pad_id=tokenizer.pad_id, eod_id=tokenizer.eod_id)
-    valid_dataset = ControlDialogDataset(valid_data_list, max_seq_len, sep_id=tokenizer.sep_id, pad_id=tokenizer.pad_id, eod_id=tokenizer.eod_id)
-    test_dataset = ControlDialogDataset(test_data_list, max_seq_len, sep_id=tokenizer.sep_id, pad_id=tokenizer.pad_id, eod_id=tokenizer.eod_id)
+    # build train, valid datasets
+    train_dataset = ControlDialogDataset(train_data_list,
+                                         max_seq_len,
+                                         sep_id=tokenizer.sep_id,
+                                         pad_id=tokenizer.pad_id,
+                                         eod_id=tokenizer.eod_id)
+    valid_dataset = ControlDialogDataset(valid_data_list,
+                                         max_seq_len,
+                                         sep_id=tokenizer.sep_id,
+                                         pad_id=tokenizer.pad_id,
+                                         eod_id=tokenizer.eod_id)

-    return train_dataset, valid_dataset, test_dataset
+    return train_dataset, valid_dataset
+
+
+def build_test_dataset(test_data_path, train_module, max_seq_len, last_turn,
+                       no_control_code, add_separator, add_ctrl_code_to_dialog,
+                       remove_ctrl_sent):
+
+    tokenizer = get_tokenizer()
+    test_data_list = read_data_v2(tokenizer, test_data_path, train_module,
+                                  last_turn, no_control_code, add_separator,
+                                  add_ctrl_code_to_dialog, remove_ctrl_sent)
+    test_dataset = ControlDialogDataset(test_data_list,
+                                        max_seq_len,
+                                        sep_id=tokenizer.sep_id,
+                                        pad_id=tokenizer.pad_id,
+                                        eod_id=tokenizer.eod_id)
+    return test_dataset
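
read_data_v2 expects four tab-separated fields per line; a minimal sketch of the split, with invented field values:

    # control code \t dialog context \t control sentence \t response
    line = ("music [CTRL] jazz\t"
            "hi there [SEP] do you like jazz?\t"
            "Jazz originated in New Orleans.\t"
            "I do, jazz started in New Orleans.")
    control_code, dialog_context, control_sent, response = line.split("\t")
    turns = dialog_context.split(" [SEP] ")[-3:]  # keep at most the last 3 turns
    print(turns)  # ['hi there', 'do you like jazz?']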
tasks/dialctrl/evaluate.py (new file, 0 → 100644)

from megatron import get_args
from megatron import get_timers
from megatron import print_rank_0
from megatron import get_tokenizer
from megatron.training import evaluate_and_print_results
from megatron.training import setup_model_and_optimizer
from megatron.checkpointing import load_checkpoint
from tasks.finetune_utils import build_data_loader
from tasks.dialctrl.data import build_test_dataset
from tasks.dialctrl.finetune import model_provider, process_batch, loss_func, forward_step
from tasks.dialctrl.metrics import F1Metric
from tqdm import tqdm


def test_dataset_provider():
    """Build the test dataset for dialog/control module"""
    args = get_args()

    print_rank_0('> building the test dataset for %s module ...' \
                 % args.train_module)
    test_ds = build_test_dataset(
        test_data_path=args.test_data_path,
        train_module=args.train_module,
        max_seq_len=args.max_seq_len,
        last_turn=args.last_turn,
        no_control_code=args.no_control_code,
        add_separator=args.add_separator,
        add_ctrl_code_to_dialog=args.add_ctrl_code_to_dialog,
        remove_ctrl_sent=args.remove_ctrl_sent)
    print_rank_0("> finished creating the test dataset for %s module ..." \
                 % args.train_module)

    print_rank_0('> test set size: %d' % len(test_ds))
    args.eval_iters = len(test_ds) // args.global_batch_size
    print_rank_0('> evaluation iteration: %d' % args.eval_iters)

    return test_ds


def _build_test_iterator(test_dataset, task_collate_fn=None):
    """Test dataloader."""
    args = get_args()

    print_rank_0('building test dataloader ...')
    # Test loader
    test_dataloader = build_data_loader(test_dataset, args.micro_batch_size,
                                        args.num_workers, not args.keep_last,
                                        task_collate_fn)
    test_iterator = test_dataloader.__iter__()
    return test_iterator


def evaluate_ppl(test_dataset_provider, model_provider, forward_step):
    args = get_args()
    timers = get_timers()

    # test dataloader.
    timers('test dataset/dataloader').start()
    test_dataset = test_dataset_provider()
    test_iterator = _build_test_iterator(test_dataset)
    timers('test dataset/dataloader').stop()

    timers('model and optimizer').start()
    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
    timers('model and optimizer').stop()

    timers('pretrained checkpoint').start()
    if args.pretrained_checkpoint is not None:
        original_load = args.load
        args.load = args.pretrained_checkpoint
        original_rng = args.no_load_rng
        args.no_load_rng = True
        iteration = load_checkpoint(model, None, None)
        args.load = original_load
        args.no_load_rng = original_rng
        # This is critical when only model is loaded. We should make sure
        # main parameters are also updated.
        optimizer.reload_model_params()
    timers('pretrained checkpoint').stop()

    # Print setup timing.
    print_rank_0('done with setups ...')
    timers.log(['test dataset/dataloader', 'model and optimizer',
                'pretrained checkpoint'])
    print_rank_0('evaluating ...')

    prefix = 'iteration {}'.format(iteration)
    evaluate_and_print_results(prefix, forward_step, test_iterator, model,
                               iteration, False)
    print_rank_0('done :-)')


def evaluate_f1(guess_file, answer_file, remove_stopwords):
    guess_list = []
    print_rank_0('reading %s' % guess_file)
    with open(guess_file, "r") as f:
        for i, line in enumerate(tqdm(f)):
            line = line.strip()
            if "<|endoftext|>" in line:
                line = line.replace("<|endoftext|>", "")
            guess_list.append(line)

    answer_list = []
    print_rank_0('reading %s' % answer_file)
    with open(answer_file, "r") as f:
        for i, line in enumerate(tqdm(f)):
            line = line.strip()
            if line == "no_passages_used":
                line = ""
            answer_list.append(line)

    assert len(guess_list) == len(answer_list), \
        "lengths of guess and answer are different!"

    precision, recall, f1 = F1Metric.compute_all_pairs(guess_list, answer_list,
                                                       remove_stopwords)
    print_rank_0('Precision: %.4f; recall: %.4f; f1: %.4f' % (precision, recall, f1))
    print_rank_0('done :-)')


def main():
    args = get_args()
    if 'ppl' in args.task:
        evaluate_ppl(test_dataset_provider, model_provider, forward_step)
    elif 'f1' in args.task:
        evaluate_f1(args.guess_file, args.answer_file, args.remove_stopwords)
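
Before scoring, evaluate_f1 normalizes one line at a time; a small sketch of the two rules on invented lines:

    guess = "Jazz comes from New Orleans.<|endoftext|>"
    guess = guess.replace("<|endoftext|>", "")  # strip the GPT-2 end-of-text marker
    answer = "no_passages_used"
    if answer == "no_passages_used":
        answer = ""  # empty gold: compute_each_pair returns None and the pair is skipped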
tasks/dialctrl/finetune.py

...
@@ -12,7 +12,7 @@ from megatron.model import GPTModel
 from megatron.training import evaluate_and_print_results
 from megatron.utils import average_losses_across_data_parallel_group
 from tasks.finetune_utils import finetune
-from tasks.dialctrl.data import build_train_valid_test_datasets
+from tasks.dialctrl.data import build_train_valid_datasets
 from tasks.dialctrl.utils import get_ltor_attention_masks_and_position_ids
...
@@ -35,16 +35,27 @@ def train_valid_datasets_provider():
     print_rank_0('> building train, validation, and test datasets for %s module ...' % args.train_module)
-    train_ds, valid_ds, _ = build_train_valid_test_datasets(
-        data_folder=args.data_folder,
-        dataset_name=args.dataset_name,
-        train_module=args.train_module,
-        max_seq_len=args.max_seq_len,
-        seed=args.seed)
+    train_ds, valid_ds = build_train_valid_datasets(
+        train_data_path=args.train_data_path,
+        valid_data_path=args.test_data_path,
+        train_module=args.train_module,
+        max_seq_len=args.max_seq_len,
+        seed=args.seed,
+        last_turn=args.last_turn,
+        no_control_code=args.no_control_code,
+        add_separator=args.add_separator,
+        add_ctrl_code_to_dialog=args.add_ctrl_code_to_dialog,
+        remove_ctrl_sent=args.remove_ctrl_sent)
     print_rank_0("> finished creating datasets for %s module ..." % args.train_module)

+    print_rank_0('> Train size: %d' % len(train_ds))
+    print_rank_0('> Validation size: %d' % len(valid_ds))
     args.eval_interval = len(train_ds) // args.global_batch_size
-    print_rank_0(' > evaluation interval: %d' % args.eval_interval)
+    print_rank_0('> evaluation interval: %d' % args.eval_interval)
+    args.eval_iters = len(valid_ds) // args.global_batch_size
+    print_rank_0('> evaluation iteration: %d' % args.eval_iters)

     return train_ds, valid_ds
...
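
A back-of-envelope check of the two schedule values the provider now sets (sizes invented): eval_interval amounts to one pass over the training set and eval_iters to one pass over the validation set.

    train_size, valid_size, global_batch_size = 10000, 1000, 16
    eval_interval = train_size // global_batch_size  # 625
    eval_iters = valid_size // global_batch_size     # 62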
tasks/dialctrl/metrics.py (new file, 0 → 100644)

# The following code is adapted from
# https://github.com/facebookresearch/ParlAI/blob/master/parlai/core/metrics.py,
# which is licensed under the MIT license. More details on the license can be
# found at https://github.com/facebookresearch/ParlAI/blob/master/LICENSE.

"""Provides standard metric evaluations for dialog."""

from collections import Counter
from typing import List
import numpy as np
import re

from nltk.corpus import stopwords

re_art = re.compile(r'\b(a|an|the)\b')
re_punc = re.compile(r'[!"#$%&()*+,-./:;<=>?@\[\]\\^`{|}~_\']')

stopword_list = stopwords.words('english')
stopword_list = stopword_list + ["n's", "'s"]
stopword_dict = {token: True for token in stopword_list}


def normalize_answer(s):
    """
    Lower text and remove punctuation, articles and extra whitespace.
    """
    s = s.lower()
    s = re_punc.sub(' ', s)
    s = re_art.sub(' ', s)
    # TODO: this could almost certainly be faster with a regex \s+ -> ' '
    s = ' '.join(s.split())
    return s


def remove_stopwords(token_list):
    new_list = []
    for token in token_list:
        if token in stopword_dict:
            continue
        new_list.append(token)
    return new_list


class F1Metric:
    """
    Helper class which computes token-level F1.
    """

    @staticmethod
    def _prec_recall_f1_score(pred_items, gold_items):
        """
        Compute precision, recall and f1 given a set of gold and prediction items.
        :param pred_items: iterable of predicted values
        :param gold_items: iterable of gold values
        :return: tuple (p, r, f1) for precision, recall, f1
        """
        common = Counter(gold_items) & Counter(pred_items)
        num_same = sum(common.values())
        if num_same == 0:
            return 0, 0, 0
        precision = 1.0 * num_same / len(pred_items)
        recall = 1.0 * num_same / len(gold_items)
        f1 = (2 * precision * recall) / (precision + recall)
        return precision, recall, f1

    @staticmethod
    def compute_each_pair(guess: str, answer: str, rm_sw: bool):
        if answer == "":
            return None, None, None
        if guess == "":
            return 0, 0, 0
        g_tokens = normalize_answer(guess).split()
        a_tokens = normalize_answer(answer).split()
        if rm_sw:
            g_tokens = remove_stopwords(g_tokens)
            a_tokens = remove_stopwords(a_tokens)
        if len(a_tokens) == 0:
            return None, None, None
        if len(g_tokens) == 0:
            return 0, 0, 0
        precision, recall, f1 = F1Metric._prec_recall_f1_score(g_tokens, a_tokens)
        return precision, recall, f1

    @staticmethod
    def compute_all_pairs(guesses: List[str], answers: List[str], rm_sw=False):
        # additional argument:
        # rm_sw: whether to remove stopwords
        assert len(guesses) == len(answers)
        precision_list, recall_list, f1_list = [], [], []
        for guess, answer in zip(guesses, answers):
            precision, recall, f1 = F1Metric.compute_each_pair(guess, answer, rm_sw)
            if precision is None or recall is None or f1 is None:
                continue
            precision_list.append(precision)
            recall_list.append(recall)
            f1_list.append(f1)
        return np.mean(precision_list), np.mean(recall_list), np.mean(f1_list)
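
A minimal usage sketch for the new metric class (strings invented; assumes the NLTK stopword corpus is downloaded, since the module loads it at import time):

    from tasks.dialctrl.metrics import F1Metric

    guesses = ["jazz began in new orleans", ""]
    answers = ["Jazz originated in New Orleans.", "some gold sentence"]
    p, r, f1 = F1Metric.compute_all_pairs(guesses, answers, rm_sw=True)
    print("Precision: %.4f; recall: %.4f; f1: %.4f" % (p, r, f1))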
tasks/dialctrl/utils.py

...
@@ -8,7 +8,9 @@ def get_ltor_attention_masks_and_position_ids(data, eod_token_id):
     micro_batch_size, seq_length = data.size()

     # Attention mask
-    attention_mask = torch.tril(torch.ones((micro_batch_size, seq_length, seq_length), device=data.device)).view(micro_batch_size, 1, seq_length, seq_length)
+    attention_mask = torch.tril(torch.ones(
+        (micro_batch_size, seq_length, seq_length),
+        device=data.device)).view(micro_batch_size, 1, seq_length, seq_length)

     # mask padded tokens
     for b in range(micro_batch_size):
...
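
What the reformatted call builds, on a toy size: a lower-triangular causal mask of shape (batch, 1, seq, seq); the loop that follows then additionally masks padded tokens.

    import torch

    micro_batch_size, seq_length = 2, 4
    attention_mask = torch.tril(torch.ones(
        (micro_batch_size, seq_length, seq_length))).view(
            micro_batch_size, 1, seq_length, seq_length)
    print(attention_mask.shape)  # torch.Size([2, 1, 4, 4])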
tasks/main.py

...
@@ -87,15 +87,41 @@ def get_tasks_args(parser):
     # finetune for controllable dialogue
     group.add_argument('--train-module', type=str, default="",
                        help='either control module or dialogue model (control or dialog)')
-    group.add_argument('--data-folder', type=str, default="",
-                       help='data folder (path of the data folder)')
-    group.add_argument('--dataset-name', type=str, default="",
-                       help='dataset name (e.g., wizard_of_wikipedia)')
+    group.add_argument('--train-data-path', type=str, default="",
+                       help='datapath for training set')
+    group.add_argument('--test-data-path', type=str, default="",
+                       help='datapath for test set')
+    group.add_argument('--guess-file', type=str, default="",
+                       help='datapath for generated sentences')
+    group.add_argument('--answer-file', type=str, default="",
+                       help='datapath for golden sentences')
     group.add_argument('--max-seq-len', type=int, default=1024,
                        help='maximum sequence length')
-    group.add_argument('--spec-toks', type=str, default="[SEP],[CTRL],[PAD]",
+    group.add_argument('--spec-toks', type=str, default=None,
                        help='additional special tokens')
+    group.add_argument('--last-turn', action='store_true',
+                       help='only use last turn for control model')
+    group.add_argument('--no-control-code', action='store_true',
+                       help='removing control code in the training for control model')
+    group.add_argument('--remove-stopwords', action='store_true',
+                       help='removing stopwords when evaluating F1-score')
+    group.add_argument('--add-separator', action='store_true',
+                       help='add separator between turns and add colon before generation')
+    group.add_argument('--add-ctrl-code-to-dialog', action='store_true',
+                       help='add control code in the dialog modeling')
+    group.add_argument('--remove-ctrl-sent', action='store_true',
+                       help="don't use control sentence in dialog modeling")
+
+    # finetune for controllable generation
+    group.add_argument('--wiki-path', type=str, default="",
+                       help='data path for the wikipedia corpus')
+    group.add_argument('--tokenized-path', type=str, default="",
+                       help='data path for the tokenized file')
+    group.add_argument('--prop', type=float, default=1.0,
+                       help='proportion of data used for training')
+    group.add_argument('--max-instance', type=int, default=10000000,
+                       help='maximum number of instances used for training')

     return parser
...
@@ -120,8 +146,12 @@ if __name__ == '__main__':
         from orqa.evaluate_orqa import main
     elif args.task in ['RET-FINETUNE-NQ']:
         from orqa.supervised.finetune import main
+    elif args.task == 'control-gen':
+        from control_gen.finetune import main
     elif args.task == 'dialctrl':
         from dialctrl.finetune import main
+    elif args.task in ['dialctrl-eval-ppl', 'dialctrl-eval-f1']:
+        from dialctrl.evaluate import main
     else:
         raise NotImplementedError('Task {} is not implemented.'.format(
             args.task))
...
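
The new branches route purely on the task string; a toy table of the mapping (module names printed rather than imported):

    for task in ['control-gen', 'dialctrl', 'dialctrl-eval-ppl', 'dialctrl-eval-f1']:
        if task == 'control-gen':
            module = 'control_gen.finetune'
        elif task == 'dialctrl':
            module = 'dialctrl.finetune'
        elif task in ['dialctrl-eval-ppl', 'dialctrl-eval-f1']:
            module = 'dialctrl.evaluate'
        print(task, '->', module)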
tools/control_dialog_interactive.py (new file, 0 → 100644)

"""Sample Generate Controllable Dialog Model"""

import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
                                             os.path.pardir)))
import argparse
import torch
from transformers import DPRQuestionEncoderTokenizer
from megatron import get_args
from megatron import print_rank_0
from megatron import get_tokenizer
from megatron import mpu
from megatron.checkpointing import load_checkpoint
from megatron.initialize import initialize_megatron
from megatron.model import GPTModel
from megatron.training import get_model
from megatron.text_generation_utils import dialog_with_gpt_control_interactive, dialog_with_dpr_control_interactive


def model_provider(pre_process=True, post_process=True):
    """Build the model."""

    print_rank_0('building GPT model ...')
    model = GPTModel(num_tokentypes=0, parallel_output=False,
                     pre_process=pre_process, post_process=post_process)
    return model


def add_control_dialog_generate_args(parser):
    """Text generation arguments."""
    group = parser.add_argument_group(title='text generation')

    group.add_argument("--temperature", type=float, default=1.0,
                       help='Sampling temperature.')
    group.add_argument("--greedy", action='store_true', default=False,
                       help='Use greedy sampling.')
    group.add_argument("--top_p", type=float, default=0.0,
                       help='Top p sampling.')
    group.add_argument("--top_k", type=int, default=0,
                       help='Top k sampling.')
    group.add_argument("--out-seq-length", type=int, default=1024,
                       help='Size of the output generated text.')
    group.add_argument("--recompute", action='store_true',
                       help='During generation recompute all attention '
                       'instead of using previously computed keys/values.')
    group.add_argument("--ctrl-type", type=str, default="",
                       help="Either dpr or gpt")
    group.add_argument("--ctrl-hidden-size", type=int, default=1024,
                       help="hidden-size of gpt control model")
    group.add_argument("--ctrl-num-layers", type=int, default=24,
                       help="num-layers of gpt control model")
    group.add_argument("--ctrl-num-attention-heads", type=int, default=16,
                       help="num-attention-heads of gpt control model")
    group.add_argument("--ctrl-gpt-load", type=str, default="",
                       help="checkpoint path of the gpt control model")
    group.add_argument("--ctrl-dpr-load", type=str, default="",
                       help="checkpoint path of the dpr control model")
    group.add_argument("--knowledge-corpus-path", type=str, default="",
                       help="The path for the knowledge corpus")
    group.add_argument("--knowledge-corpus-emb", type=str, default="",
                       help="The path for the knowledge embedding")
    group.add_argument('--spec-toks', type=str, default=None,
                       help='additional special tokens')
    group.add_argument('--add-separator', action="store_true",
                       help='Add separator for the inputs')

    return parser


def main():
    """Main program."""

    initialize_megatron(extra_args_provider=add_control_dialog_generate_args,
                        args_defaults={'tokenizer_type': 'GPT2BPETokenizer',
                                       'no_load_rng': True,
                                       'no_load_optim': True})

    args = get_args()
    if args.num_layers_per_virtual_pipeline_stage is not None:
        print("Interleaved pipeline schedule is not yet supported for text generation.")
        exit()

    # Set up conversational model
    conv_model = get_model(model_provider)
    if args.load is not None:
        _ = load_checkpoint(conv_model, None, None)
    assert len(conv_model) == 1, "Above condition should have caught this"
    conv_model = conv_model[0]

    # Set up control model
    assert args.ctrl_type in ["gpt", "dpr"], \
        "please input a correct control model type"
    if args.ctrl_type == "gpt":
        args.consumed_train_samples = 0
        args.consumed_valid_samples = 0
        args.hidden_size = args.ctrl_hidden_size
        args.ffn_hidden_size = 4 * args.hidden_size
        args.num_layers = args.ctrl_num_layers
        args.num_attention_heads = args.ctrl_num_attention_heads
        args.load = args.ctrl_gpt_load
        ctrl_model = get_model(model_provider)
        if args.load is not None:
            _ = load_checkpoint(ctrl_model, None, None)
        ctrl_model = ctrl_model[0]
        dialog_with_gpt_control_interactive(conv_model, ctrl_model,
                                            args.add_separator)
    else:
        print_rank_0("> Loading model from %s" % args.ctrl_dpr_load)
        ctrl_model = torch.load(args.ctrl_dpr_load)
        ctrl_model.cuda()
        ctrl_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
            'facebook/dpr-question_encoder-single-nq-base')

        print_rank_0("> Loading knowledge corpus and embeddings")
        with open(args.knowledge_corpus_path, "r") as f:
            knowledge_corpus = f.readlines()
        knowledge_corpus_emb = torch.load(args.knowledge_corpus_emb)
        knowledge_corpus_emb = knowledge_corpus_emb.cuda()
        assert knowledge_corpus_emb.size()[0] == len(knowledge_corpus), \
            "The size of knowledge corpus and embeddings should be the same"

        dialog_with_dpr_control_interactive(conv_model, ctrl_model, ctrl_tokenizer,
                                            knowledge_corpus, knowledge_corpus_emb,
                                            args.add_separator)


if __name__ == "__main__":
    main()
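
In the dpr branch, retrieval is one matrix-vector product against the precomputed passage embeddings followed by an argmax; a self-contained sketch with random tensors standing in for the real corpus and DPR pooler_output (sizes invented, CPU for portability):

    import torch

    knowledge_corpus_emb = torch.randn(1000, 768)  # (num_passages, hidden), stand-in
    query_emb = torch.randn(1, 768)                # stand-in for pooler_output
    logits = knowledge_corpus_emb.matmul(query_emb[0])
    retrieved_idx = torch.argmax(logits).item()    # index of the control sentence
    print(retrieved_idx)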
tools/generate_samples_gpt.py

...
@@ -30,6 +30,8 @@ from megatron.model import GPTModel
 from megatron.training import get_model
 from megatron.text_generation_utils import generate_and_write_samples_unconditional
 from megatron.text_generation_utils import generate_samples_input_from_file
+from megatron.text_generation_utils import generate_samples_prompt_input_from_file
+from megatron.text_generation_utils import generate_samples_line_by_line_input_from_file
 from megatron.text_generation_utils import generate_samples_interactive
...
@@ -70,6 +72,18 @@ def add_text_generate_args(parser):
     group.add_argument("--recompute", action='store_true',
                        help='During generation recompute all attention '
                        'instead of using previously computed keys/values.')
+    group.add_argument('--spec-toks', type=str, default=None,
+                       help='additional special tokens')
+    group.add_argument('--line-by-line', action="store_true",
+                       help='generate samples line by line')
+    group.add_argument('--prompt', action="store_true",
+                       help='generate samples based on prompting')
+    group.add_argument('--prompt-file', type=str, default="",
+                       help='prompting file')
+    group.add_argument('--prompt-type', type=str, default="",
+                       help='prompt type (context or keyphrase)')
+    group.add_argument('--num-prompt-examples', type=int, default=10,
+                       help='number of prompt examples')

     return parser
...
@@ -98,8 +112,14 @@ def main():
     # Generate samples.
     if args.num_samples == 0:
         args.micro_batch_size = 1
         if args.sample_input_file != None:
-            generate_samples_input_from_file(model)
+            if args.line_by_line:
+                if args.prompt:
+                    generate_samples_prompt_input_from_file(model)
+                else:
+                    generate_samples_line_by_line_input_from_file(model)
+            else:
+                generate_samples_input_from_file(model)
     else:
         generate_samples_interactive(model)
...
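
The rewritten branch selects one of three generators from two flags; a toy table of the routing (function names printed, not called):

    for line_by_line, prompt in [(False, False), (True, False), (True, True)]:
        if line_by_line:
            fn = ('generate_samples_prompt_input_from_file' if prompt
                  else 'generate_samples_line_by_line_input_from_file')
        else:
            fn = 'generate_samples_input_from_file'
        print(line_by_line, prompt, '->', fn)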