OpenDAS / Megatron-LM · Commits
Commit a87777bf, authored Dec 06, 2021 by zihanl
delete finetune part
parent 5f4e63fc
Changes: 2
Showing 2 changed files with 0 additions and 453 deletions.
tasks/knwl_dialo/data.py: +0 -243
tasks/knwl_dialo/finetune.py: +0 -210
tasks/knwl_dialo/data.py · deleted · 100644 → 0
"""Build Dataset for Controllable Coversational Model"""
import
os
import
torch
import
numpy
as
np
from
megatron
import
get_tokenizer
from
megatron
import
print_rank_0
def read_data_for_finetuning(tokenizer, data_path, module):
    """Data Format: topic \t dialog context \t knowledge \t response."""

    data_list = []
    with open(data_path, "r") as f:
        for i, line in enumerate(f):
            line = line.rstrip()
            splits = line.split("\t")
            assert len(splits) == 4

            topic = splits[0].split(" [CTRL] ")[0]
            dialog_context = splits[1]
            knowledge = splits[2]
            response = splits[3]

            turns = dialog_context.split(" [SEP] ")
            turns = turns[-3:]

            if module == "response":
                # input_ids
                input_ids = tokenizer.tokenize("( " + topic + " )")
                if knowledge != "no_passages_used":
                    input_ids.extend(tokenizer.tokenize("( " + knowledge + " )")[:256])
                for turn in turns:
                    turn = "<< " + turn + " >>"
                    input_ids.extend(tokenizer.tokenize(turn))
                input_ids.extend(tokenizer.tokenize(":"))

                # output_ids
                output_ids = tokenizer.tokenize(response)
                data_list.append({"input_ids": input_ids, "output_ids": output_ids})

            elif module == "knowledge":
                # skip example without knowledge sentences
                if knowledge == "no_passages_used":
                    continue

                input_ids = []
                input_ids.extend(tokenizer.tokenize("( " + topic + " )"))
                for turn in turns:
                    turn = "<< " + turn + " >>"
                    input_ids.extend(tokenizer.tokenize(turn))
                input_ids.extend(tokenizer.tokenize(":"))

                output_ids = tokenizer.tokenize(knowledge)
                data_list.append({"input_ids": input_ids, "output_ids": output_ids})

            else:
                raise ValueError("Please input a correct module name! "
                                 "(either response or knowledge)")

    return data_list
def read_data_for_prompting(tokenizer, test_data_path, prompt_file, module,
                            num_prompt_examples, dynamic_prompt):
    # get prompts
    if dynamic_prompt:
        import json
        prompt_examples_dict = {}
        with open(prompt_file, "r") as f:
            for i, line in enumerate(f):
                line = line.strip()
                line_dict = json.loads(line)
                key = list(line_dict.keys())[0]
                if key not in prompt_examples_dict:
                    prompt_examples = line_dict[key]
                    prompt_examples = prompt_examples[:num_prompt_examples]
                    prompt = ""
                    for instance in prompt_examples:
                        instance = instance.strip()
                        prompt += instance + "\n"
                    # each key of the prompt file is a topic
                    prompt_examples_dict[key] = prompt
    else:
        with open(prompt_file, "r") as f:
            prompt_examples = f.readlines()
            prompt_examples = prompt_examples[:num_prompt_examples]
            prompt = ""
            for instance in prompt_examples:
                instance = instance.strip()
                prompt += instance + "\n"

    data_list = []
    with open(test_data_path, "r") as f:
        for i, line in enumerate(f):
            line = line.strip()
            splits = line.split("\t")

            topic = splits[0].split(" [CTRL] ")[0]
            turns = splits[1].split(" [SEP] ")[-3:]
            last_turn = turns[-1]
            ctrl_sent = splits[2]
            response = splits[3]

            if dynamic_prompt:
                prompt = prompt_examples_dict[topic]

            if module == "response":
                # input seq
                input_seq = prompt
                input_seq += "Topic: " + topic + ". "
                input_seq += "User says: " + last_turn + " "
                input_seq += "We know that: " + ctrl_sent + " "
                input_seq += "System replies:"

                # output seq
                output_seq = response

                input_ids = tokenizer.tokenize(input_seq)
                output_ids = tokenizer.tokenize(output_seq)
                data_list.append({"input_ids": input_ids, "output_ids": output_ids})

            elif module == "knowledge":
                # input seq
                input_seq = prompt
                input_seq += "( " + last_turn + " ) " + topic + " =>"

                # output seq
                output_seq = ctrl_sent

                input_ids = tokenizer.tokenize(input_seq)
                output_ids = tokenizer.tokenize(output_seq)
                data_list.append({"input_ids": input_ids, "output_ids": output_ids})

            else:
                raise ValueError("Please input a correct module name! "
                                 "(either response or knowledge)")

    return data_list
def data_shuffle(data, seed):
    # set random seed to make the shuffling reproducible
    np.random.seed(seed)
    np.random.shuffle(data)
    return data
class KnwlDialoDataset(torch.utils.data.Dataset):

    def __init__(self, data, max_seq_len, pad_id, eod_id):
        # need to deal with padding, label masking
        self.data = data
        self.max_seq_len = max_seq_len
        self.pad_id = pad_id
        self.eod_id = eod_id

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data_dict = self.data[idx]
        input_ids, output_ids = data_dict["input_ids"], data_dict["output_ids"]

        text = input_ids + output_ids + [self.eod_id]
        loss_mask = [0] * (len(input_ids) - 1) + [1] * (len(output_ids) + 1)

        text_len = len(text)
        if text_len > self.max_seq_len + 1:
            text = text[:self.max_seq_len + 1]
            loss_mask = loss_mask[:self.max_seq_len]
        else:
            text += [self.pad_id] * (self.max_seq_len + 1 - text_len)
            loss_mask += [0] * (self.max_seq_len + 1 - text_len)

        return {"text": np.array(text, dtype=np.int64),
                "loss_mask": np.array(loss_mask, dtype=np.int64)}
def build_train_valid_datasets(train_data_path, valid_data_path, module,
                               max_seq_len, seed):
    """Build train and valid datasets."""
    tokenizer = get_tokenizer()

    train_data_list = read_data_for_finetuning(tokenizer, train_data_path, module)
    valid_data_list = read_data_for_finetuning(tokenizer, valid_data_path, module)

    # shuffle the training data
    train_data_list = data_shuffle(train_data_list, seed)

    # build train, valid datasets
    train_dataset = KnwlDialoDataset(train_data_list, max_seq_len,
                                     pad_id=tokenizer.pad_id,
                                     eod_id=tokenizer.eod_id)
    valid_dataset = KnwlDialoDataset(valid_data_list, max_seq_len,
                                     pad_id=tokenizer.pad_id,
                                     eod_id=tokenizer.eod_id)

    return train_dataset, valid_dataset
def build_test_dataset(test_data_path, module, max_seq_len):
    tokenizer = get_tokenizer()

    test_data_list = read_data_for_finetuning(tokenizer, test_data_path, module)
    test_dataset = KnwlDialoDataset(test_data_list, max_seq_len,
                                    pad_id=tokenizer.pad_id,
                                    eod_id=tokenizer.eod_id)

    return test_dataset
def build_test_dataset_for_prompting(test_data_path, prompt_file, module, max_seq_len,
                                     num_prompt_examples, dynamic_prompt):
    tokenizer = get_tokenizer()

    test_data_list = read_data_for_prompting(tokenizer, test_data_path, prompt_file,
                                             module, num_prompt_examples, dynamic_prompt)
    test_dataset = KnwlDialoDataset(test_data_list, max_seq_len,
                                    pad_id=tokenizer.pad_id,
                                    eod_id=tokenizer.eod_id)

    return test_dataset
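For reference, a minimal sketch of how these builders would typically be consumed, assuming Megatron has already been initialized so that get_tokenizer() returns a valid tokenizer; the file paths and batch size below are hypothetical illustration values, not part of the deleted file:

    from torch.utils.data import DataLoader
    from tasks.knwl_dialo.data import build_train_valid_datasets

    # Hypothetical data paths and arguments, for illustration only.
    train_ds, valid_ds = build_train_valid_datasets(
        train_data_path="data/train.tsv",   # topic \t dialog context \t knowledge \t response
        valid_data_path="data/valid.tsv",
        module="response",                  # or "knowledge"
        max_seq_len=512,
        seed=1234)

    # Each sample is a dict of int64 arrays: "text" holds max_seq_len + 1 token ids
    # (prompt + target + EOD, padded), "loss_mask" holds max_seq_len entries that are
    # 1 only over the target tokens and the EOD token.
    loader = DataLoader(train_ds, batch_size=4, shuffle=False)
    batch = next(iter(loader))
    print(batch["text"].shape, batch["loss_mask"].shape)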
tasks/knwl_dialo/finetune.py · deleted · 100644 → 0
"""Finetuning a pretrained language model for knowledge/response generation"""
import
torch
from
functools
import
partial
from
megatron
import
mpu
from
megatron
import
get_args
from
megatron
import
get_timers
from
megatron
import
print_rank_0
from
megatron
import
get_tokenizer
from
megatron.model
import
GPTModel
from
megatron.training
import
evaluate_and_print_results
from
megatron.training
import
get_model
from
megatron.utils
import
average_losses_across_data_parallel_group
from
megatron.initialize
import
initialize_megatron
from
tasks.finetune_utils
import
finetune
from
tasks.knwl_dialo.data
import
build_train_valid_datasets
from
tasks.knwl_dialo.utils
import
get_ltor_attention_masks_and_position_ids
from
tasks.knwl_dialo.utils
import
get_token_stream
def model_provider(pre_process=True, post_process=True):
    """Build the model."""

    print_rank_0('building GPT model ...')
    model = GPTModel(num_tokentypes=0, parallel_output=True,
                     pre_process=pre_process, post_process=post_process)
    return model
def train_valid_datasets_provider():
    """Build train and valid datasets for the dialog/control module."""
    args = get_args()

    print_rank_0('> building train, validation, and test datasets for %s module ...'
                 % args.module)
    train_ds, valid_ds = build_train_valid_datasets(
        train_data_path=args.train_data_path,
        valid_data_path=args.test_data_path,
        module=args.module,
        max_seq_len=args.seq_length,
        seed=args.seed)
    print_rank_0("> finished creating datasets for %s module ..." % args.module)

    print_rank_0('> Train size: %d' % len(train_ds))
    print_rank_0('> Validation size: %d' % len(valid_ds))

    args.eval_interval = len(train_ds) // args.global_batch_size
    print_rank_0('> evaluation interval: %d' % args.eval_interval)
    args.eval_iters = len(valid_ds) // args.global_batch_size
    print_rank_0('> evaluation iteration: %d' % args.eval_iters)

    return train_ds, valid_ds
def process_batch(batch):
    """Generate a batch"""
    args = get_args()
    tokenizer = get_tokenizer()

    # Items and their type.
    keys = ['text', 'loss_mask']
    datatype = torch.int64

    data_b = mpu.broadcast_data(keys, batch, datatype)

    tokens_ = data_b['text'].long()
    labels = tokens_[:, 1:].contiguous()
    tokens = tokens_[:, :-1].contiguous()
    loss_mask = data_b['loss_mask'].float()

    # Get the attention_mask and position ids.
    attention_mask, position_ids = \
        get_ltor_attention_masks_and_position_ids(tokens, tokenizer.eod_id)

    return tokens, labels, loss_mask, attention_mask, position_ids
def loss_func(loss_mask, output_tensor):
    losses = output_tensor.float()
    loss_mask = loss_mask.view(-1).float()
    loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()

    # Reduce loss for logging.
    averaged_loss = average_losses_across_data_parallel_group([loss])

    return loss, {'lm loss': averaged_loss[0]}
def forward_step(batch, model):
    """Forward step."""
    args = get_args()
    timers = get_timers()

    try:
        batch_ = next(batch)
    except BaseException:
        batch_ = batch
    tokens, labels, loss_mask, attention_mask, position_ids = process_batch(batch_)

    output_tensor = model(tokens, position_ids, attention_mask, labels=labels)

    return output_tensor, partial(loss_func, loss_mask)
def generate_samples_input_from_file(model):

    args = get_args()
    tokenizer = get_tokenizer()

    # Read the sample file and open the output file.
    assert args.sample_input_file is not None, \
        'sample input file is not provided.'
    if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0:
        fname = open(args.sample_input_file, "r")
        all_raw_text = fname.readlines()
        input_count = len(all_raw_text)
        input_pos = 0
        if args.sample_output_file is None:
            sample_output_file = args.sample_input_file + ".out"
            print('`sample-output-file` not specified, setting '
                  'it to {}'.format(sample_output_file))
        else:
            sample_output_file = args.sample_output_file
        fname_out = open(sample_output_file, "w")

    context_count = 0
    model.eval()
    # start the generation process
    with torch.no_grad():
        while True:
            raw_text_len = 0

            if mpu.is_pipeline_first_stage() \
                    and mpu.get_tensor_model_parallel_rank() == 0:
                raw_text = all_raw_text[input_pos]
                input_pos += 1
                raw_text_len = len(raw_text)
                context_tokens = tokenizer.tokenize(raw_text)
            else:
                context_tokens = tokenizer.tokenize("EMPTY TEXT")

            if input_pos % 100 == 0:
                print_rank_0("input_pos: %d" % input_pos)

            # get the generation outputs
            token_stream = get_token_stream(model, [context_tokens])
            for _, decode_tokens in enumerate(token_stream):
                pass

            # write the generation to the output file
            if mpu.get_tensor_model_parallel_rank() == 0:
                if mpu.is_pipeline_first_stage():
                    decode_tokens, _ = decode_tokens
                    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
                    trim_decode_tokens = tokenizer.detokenize(
                        decode_tokens)[raw_text_len:]

                    if "\r" in trim_decode_tokens:
                        trim_decode_tokens = trim_decode_tokens.replace("\r", "")
                    if "\n" in trim_decode_tokens:
                        trim_decode_tokens = trim_decode_tokens.replace("\n", "")
                    fname_out.write(trim_decode_tokens)
                    fname_out.write("\n")

            raw_text = None
            context_count += 1
            if input_pos == input_count:
                return
def run_generation(model_provider):

    args = get_args()
    if args.num_layers_per_virtual_pipeline_stage is not None:
        print("Interleaved pipeline schedule is not yet supported for text generation.")
        exit()

    # Set up model and load checkpoint.
    model = get_model(model_provider)
    if args.load is not None:
        _ = load_checkpoint(model, None, None)

    assert len(model) == 1, "Above condition should have caught this"
    model = model[0]

    # run generation
    generate_samples_input_from_file(model)
def main():

    args = get_args()

    if "FINETUNE" in args.task:
        # finetune
        finetune(train_valid_datasets_provider, model_provider,
                 forward_step=forward_step)
    else:
        # generate
        run_generation(model_provider)
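As a side note, the label-shifting and loss-masking convention shared by KnwlDialoDataset, process_batch, and loss_func can be illustrated on toy data. The snippet below is a standalone sketch with made-up token ids and a dummy per-token loss tensor; it is not part of the deleted files and does not depend on Megatron:

    import torch

    # Toy example: 3 prompt tokens, 2 target tokens, then an end-of-document id.
    input_ids = [11, 12, 13]
    output_ids = [21, 22]
    eod_id, pad_id, max_seq_len = 0, 0, 8

    text = input_ids + output_ids + [eod_id]
    # 0 over the prompt (minus one, since labels are shifted by one position),
    # 1 over the target tokens plus the EOD token.
    loss_mask = [0] * (len(input_ids) - 1) + [1] * (len(output_ids) + 1)
    text_len = len(text)
    text += [pad_id] * (max_seq_len + 1 - text_len)
    loss_mask += [0] * (max_seq_len + 1 - text_len)

    tokens_ = torch.tensor([text])            # shape [1, max_seq_len + 1]
    tokens = tokens_[:, :-1]                  # model input
    labels = tokens_[:, 1:]                   # next-token targets
    mask = torch.tensor([loss_mask]).float()  # shape [1, max_seq_len]

    # Per-token losses would come from the model; use dummy ones here.
    losses = torch.ones_like(mask)
    loss = torch.sum(losses.view(-1) * mask.view(-1)) / mask.view(-1).sum()
    print(tokens.shape, labels.shape, mask.shape, loss.item())  # loss == 1.0 here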