OpenDAS / Megatron-LM · Commits

Commit 458d7785, authored Feb 23, 2022 by rprenger

Fixing merge conflict

Parents: d7bf1ab5, 9c5a830f
Showing 5 changed files with 333 additions and 8 deletions (+333, -8)
tasks/msdp/prompt.py                              +322  -0
tasks/vision/classification/classification.py       +0  -0
tasks/vision/classification/eval_utils.py            +0  -0
tasks/vision/finetune_utils.py                       +7  -7
tools/run_text_generation_server.py                  +4  -1
tasks/msdp/prompt.py (new file, mode 0 → 100644)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Prompting the pretrained language model to generate knowledge/response"""
import json
import torch
import requests
from nltk import word_tokenize
from megatron import mpu
from megatron import get_args
from megatron import print_rank_0
from megatron import get_tokenizer
from megatron.model import GPTModel
from megatron.training import get_model
from megatron.checkpointing import load_checkpoint
from megatron.initialize import initialize_megatron
from megatron.text_generation import generate_and_post_process
def call_model_api(inputs, tokens_to_generate):
    """Calling the model api to get the output generations"""

    args = get_args()

    # The following is an example of using the Megatron API
    # You can also implement your own API function to place this part
    headers = {'Content-Type': 'application/json; charset=UTF-8'}
    data = {"prompts": [inputs], "tokens_to_generate": tokens_to_generate, "top_k": 1}
    data_json = json.dumps(data)
    outputs = requests.put(args.megatron_api_url,
                           headers=headers,
                           data=data_json).json()["text"][0]

    input_len = len(inputs)
    outputs = outputs[input_len:]
    outputs = outputs.split("\n")[0].strip()

    return outputs
def read_prompts(prompt_path, prompt_type, n_example):
    """Read prompt data"""

    if prompt_type == "knowledge":
        # prompts for the knowledge generation
        prompt_examples_dict = {}
        # read prompt_path
        with open(prompt_path, "r") as f:
            for i, line in enumerate(f):
                line = line.strip()
                line_dict = json.loads(line)
                key = list(line_dict.keys())[0]

                if key not in prompt_examples_dict:
                    prompt_examples = line_dict[key]
                    prompt = ""
                    for instance in prompt_examples:
                        instance = instance.strip()
                        prompt += instance + "\n"
                    prompt_examples_dict[key] = prompt

        return prompt_examples_dict

    else:
        # prompts for the response generation
        # read prompt_path
        prompt = ""
        with open(prompt_path, "r") as f:
            prompt_examples = f.readlines()
            prompt_examples = prompt_examples[:n_example]
            for instance in prompt_examples:
                instance = instance.strip()
                prompt += instance + "\n"

        return prompt
def generate_samples_by_calling_api():
    """Generate outputs by calling the model API"""
    args = get_args()
    assert args.prompt_type in ["knowledge", "response"], \
        "Please input a correct prompt type!"

    if args.prompt_type == "knowledge":
        # read knowledge generation prompts
        knwl_gen_prompt_dict = read_prompts(
            args.prompt_file, args.prompt_type, args.num_prompt_examples)

    else:
        resp_gen_prompt = read_prompts(
            args.prompt_file, args.prompt_type, args.num_prompt_examples)

    # read the test data
    fname = open(args.sample_input_file, "r")
    test_sample_list = fname.readlines()
    # create output file
    fname_out = open(args.sample_output_file, "w")

    # call the api to get the output generations
    for test_sample in test_sample_list:
        test_sample = test_sample.strip()
        splits = test_sample.split("\t")
        topic = splits[0]

        # prepare the inputs for the api
        if args.prompt_type == "knowledge":
            ## inputs = prompt + current test
            # get the prompt
            turns = splits[1].split(" [SEP] ")
            last_turn = turns[-1]
            key = topic + " " + last_turn
            inputs = knwl_gen_prompt_dict[key]

            # add current test
            inputs += "( " + last_turn + " ) " + topic + " =>"

        else:
            # inputs = prompt + current test
            # get the prompt
            inputs = resp_gen_prompt

            # add current test
            turns = splits[1].split(" [SEP] ")
            knowledge = splits[2]
            last_turn = turns[-1]
            last_turn = " ".join(word_tokenize(last_turn))
            knowledge = " ".join(word_tokenize(knowledge))
            knowledge = knowledge.strip()
            last_turn = last_turn.strip()
            inputs += "Topic: " + topic + ". "
            inputs += "User says: " + last_turn + " "
            inputs += "We know that: " + knowledge + " "
            inputs += "System replies:"

        # get the output generations from the api,
        # and write to the output file
        generations = call_model_api(inputs, args.out_seq_length)
        fname_out.write(generations)
        fname_out.write("\n")

    fname.close()
    fname_out.close()
def model_provider(pre_process=True, post_process=True):
    """Build the model."""

    print_rank_0('building GPT model ...')
    model = GPTModel(num_tokentypes=0, parallel_output=True,
                     pre_process=pre_process, post_process=post_process)
    return model
def generate_samples_by_prompting_input_from_file(model):
    """Prompt a pretrained language model to generate knowledge/response"""

    # get tokenizer
    args = get_args()
    tokenizer = get_tokenizer()

    # Read the sample file and open the output file.
    assert args.sample_input_file is not None, \
        'sample input file is not provided.'
    if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0:
        fname = open(args.sample_input_file, "r")
        all_raw_text = fname.readlines()
        input_count = len(all_raw_text)
        if args.sample_output_file is None:
            sample_output_file = args.sample_input_file + ".out"
            print('`sample-output-file` not specified, setting '
                  'it to {}'.format(sample_output_file))
        else:
            sample_output_file = args.sample_output_file

        fname_out = open(sample_output_file, "w")

    # only two prompt types (i.e., knowledge and response) are allowed
    assert args.prompt_type in ["knowledge", "response"], \
        "Please input a correct prompt type!"

    # Read the prompt file
    if args.prompt_type == "knowledge":
        # read the prompts for the knowledge generation
        prompt_examples_dict = {}
        with open(args.prompt_file, "r") as f:
            for i, line in enumerate(f):
                line = line.strip()
                line_dict = json.loads(line)
                key = list(line_dict.keys())[0]

                # get the prompt examples based on the key
                if key not in prompt_examples_dict:
                    prompt_examples = line_dict[key]
                    prompt = ""
                    for instance in prompt_examples:
                        instance = instance.strip()
                        prompt += instance + "\n"
                    prompt_examples_dict[key] = prompt

    else:
        # read the prompts for the response generation
        # prompts are fixed for all test samples
        with open(args.prompt_file, "r") as f:
            prompt_examples = f.readlines()
            prompt_examples = prompt_examples[:args.num_prompt_examples]

            prompt = ""
            for instance in prompt_examples:
                instance = instance.strip()
                prompt += instance + "\n"

    input_pos = 0
    model.eval()
    # perform prompting
    with torch.no_grad():
        while True:
            raw_text_len = 0
            if mpu.is_pipeline_first_stage() \
               and mpu.get_tensor_model_parallel_rank() == 0:
                input_str = all_raw_text[input_pos]
                input_str = input_str.strip()
                splits = input_str.split("\t")
                topic = splits[0]

                if args.prompt_type == "knowledge":
                    # first add the prompt into the raw_text
                    turns = splits[1].split(" [SEP] ")
                    last_turn = turns[-1]
                    key = topic + " " + last_turn
                    raw_text = prompt_examples_dict[key]

                    # construct inputs for knowledge generation
                    # then add the constructed inputs into the raw_text
                    raw_text += "( " + last_turn + " ) " + topic + " =>"

                else:
                    # first add the prompt into the raw_text
                    raw_text = prompt

                    # construct inputs for response generation
                    # then add the constructed inputs into the raw_text
                    turns = splits[1].split(" [SEP] ")
                    knowledge = splits[2]
                    last_turn = turns[-1]
                    last_turn = " ".join(word_tokenize(last_turn))
                    knowledge = " ".join(word_tokenize(knowledge))
                    knowledge = knowledge.strip()
                    last_turn = last_turn.strip()
                    raw_text += "Topic: " + topic + ". "
                    raw_text += "User says: " + last_turn + " "
                    raw_text += "We know that: " + knowledge + " "
                    raw_text += "System replies:"

                input_pos += 1
                raw_text_len = len(raw_text)

            else:
                raw_text = "EMPTY TEXT"

            if input_pos % 100 == 0:
                print_rank_0("input_pos: %d" % input_pos)

            outputs = generate_and_post_process(
                model=model,
                prompts=[raw_text],
                tokens_to_generate=args.out_seq_length,
                top_k_sampling=1)
            prompts_plus_generations = outputs[0]
            prompts_plus_generations = prompts_plus_generations[0]

            # write the generated output to the output file
            if mpu.get_tensor_model_parallel_rank() == 0:
                if mpu.is_pipeline_first_stage():
                    generations = prompts_plus_generations[raw_text_len:]
                    generations = generations.split("\n")[0]
                    generations = generations.strip()
                    fname_out.write(generations)
                    fname_out.write("\n")

            raw_text = None
            if input_pos == input_count:
                return
def main():

    args = get_args()
    if args.api_prompt:
        # obtain the generations by calling the api
        generate_samples_by_calling_api()
        return

    if args.num_layers_per_virtual_pipeline_stage is not None:
        print("Interleaved pipeline schedule is not yet supported for text generation.")
        exit()

    # Set up model and load checkpoint.
    model = get_model(model_provider, wrap_with_ddp=False)

    if args.load is not None:
        _ = load_checkpoint(model, None, None)

    assert len(model) == 1, "Above condition should have caught this"
    model = model[0]

    # perform the prompting
    generate_samples_by_prompting_input_from_file(model)
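For orientation, here is a minimal sketch of how the new prompting script can be driven end to end. None of this is part of the commit: the argument definitions, file paths, and example strings below are assumptions (the real entry point that registers the msdp-specific flags lives outside the files shown here), but the tab-separated sample layout mirrors what the parsing loops above expect.

# Hypothetical driver for tasks/msdp/prompt.py; flag names, defaults, and paths are assumptions.
from megatron.initialize import initialize_megatron
from tasks.msdp.prompt import main


def msdp_args(parser):
    """Stand-in for the task-specific flags prompt.py reads (defined elsewhere in the repo)."""
    group = parser.add_argument_group(title='msdp prompting (hypothetical)')
    group.add_argument('--prompt-file', type=str, default=None)
    group.add_argument('--prompt-type', type=str, choices=['knowledge', 'response'])
    group.add_argument('--num-prompt-examples', type=int, default=10)
    group.add_argument('--sample-input-file', type=str, default=None)
    group.add_argument('--sample-output-file', type=str, default=None)
    group.add_argument('--out-seq-length', type=int, default=100)
    group.add_argument('--api-prompt', action='store_true')
    group.add_argument('--megatron-api-url', type=str, default=None)
    return parser


# One test sample in the tab-separated layout the loops above parse:
# splits[0] = topic, splits[1] = dialogue turns joined by " [SEP] ",
# splits[2] = grounding knowledge (used only for response generation).
SAMPLE = "\t".join([
    "Eiffel Tower",
    "Hi there [SEP] Where is the Eiffel Tower located ?",
    "The Eiffel Tower is in Paris , France .",
])

if __name__ == "__main__":
    with open("msdp_test_samples.tsv", "w") as f:  # hypothetical path
        f.write(SAMPLE + "\n")
    initialize_megatron(extra_args_provider=msdp_args,
                        args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})
    main()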
tasks/vision/classification.py → tasks/vision/classification/classification.py (file moved)
tasks/vision/eval_utils.py → tasks/vision/classification/eval_utils.py (file moved)
tasks/vision/finetune_utils.py
@@ -135,7 +135,7 @@ def _build_train_valid_dataloaders(train_dataset, valid_dataset):
 def _train(
     model,
     optimizer,
-    lr_scheduler,
+    opt_param_scheduler,
     forward_step,
     train_dataloader,
     valid_dataloader,
@@ -179,7 +179,7 @@ def _train(
             # Train for one step.
             losses_dict, skipped_iter, grad_norm, num_zeros_in_grad = train_step(
-                forward_step, batch, model, optimizer, lr_scheduler
+                forward_step, batch, model, optimizer, opt_param_scheduler
             )
             iteration += 1
@@ -206,7 +206,7 @@ def _train(
                 iteration % args.adlr_autoresume_interval == 0
             ):
                 check_adlr_autoresume_termination(
-                    iteration, model, optimizer, lr_scheduler
+                    iteration, model, optimizer, opt_param_scheduler
                 )

             # Checkpointing
@@ -215,7 +215,7 @@ def _train(
                 and args.save_interval
                 and iteration % args.save_interval == 0
             ):
-                save_checkpoint(iteration, model, optimizer, lr_scheduler)
+                save_checkpoint(iteration, model, optimizer, opt_param_scheduler)

             # Evaluation
             if args.eval_interval and iteration % args.eval_interval == 0:
@@ -231,7 +231,7 @@ def _train(
         # Checkpointing at the end of each epoch.
         if args.save:
-            save_checkpoint(iteration, model, optimizer, lr_scheduler)
+            save_checkpoint(iteration, model, optimizer, opt_param_scheduler)

         # Callback at the end of each epoch.
         if end_of_epoch_callback is not None:
@@ -266,7 +266,7 @@ def finetune(
     # Build model, optimizer and learning rate scheduler.
     timers("model and optimizer").start()
-    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
+    model, optimizer, opt_param_scheduler = setup_model_and_optimizer(model_provider)
     timers("model and optimizer").stop()

     # If pretrained checkpoint is provided and we have not trained for
@@ -300,7 +300,7 @@ def finetune(
         _train(
             model,
             optimizer,
-            lr_scheduler,
+            opt_param_scheduler,
             forward_step,
             train_dataloader,
             valid_dataloader,
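The hunks above are part of a rename of the scheduler argument from lr_scheduler to opt_param_scheduler: the object is still whatever setup_model_and_optimizer returns, only the name changes at every call site. Presumably the new name reflects that the scheduler drives optimizer parameters beyond the learning rate. A toy sketch of that idea, purely illustrative and not Megatron's implementation:

# Toy stand-in, not Megatron's scheduler: shows why "lr_scheduler" is no longer an
# accurate name once the object also adjusts other optimizer hyperparameters
# (weight decay here, chosen purely for illustration).
import torch


class ToyOptimizerParamScheduler:
    def __init__(self, optimizer, max_lr, end_wd, total_steps):
        self.optimizer = optimizer
        self.max_lr = max_lr
        self.end_wd = end_wd
        self.total_steps = total_steps
        self.num_steps = 0

    def step(self, increment=1):
        self.num_steps += increment
        frac = min(self.num_steps / self.total_steps, 1.0)
        for group in self.optimizer.param_groups:
            group["lr"] = self.max_lr * (1.0 - frac)    # linearly decay the learning rate
            group["weight_decay"] = self.end_wd * frac  # ramp weight decay up over training


params = [torch.nn.Parameter(torch.zeros(1))]
opt = torch.optim.SGD(params, lr=0.1, weight_decay=0.0)
opt_param_scheduler = ToyOptimizerParamScheduler(opt, max_lr=0.1, end_wd=0.01, total_steps=1000)
opt_param_scheduler.step(increment=1)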
tools/run_text_generation_server.py
@@ -78,4 +78,7 @@ if __name__ == "__main__":
         choice = torch.cuda.LongTensor(1)
         torch.distributed.broadcast(choice, 0)
         if choice[0].item() == 0:
-            generate_and_post_process(model)
+            try:
+                generate_and_post_process(model)
+            except ValueError as ve:
+                pass
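The added try/except lets this broadcast-driven worker loop survive a request that fails with a ValueError instead of taking the rank down. For completeness, a client-side sketch for exercising the server this script launches; the host, port, and route are assumptions that must match your launch configuration, and the payload shape mirrors call_model_api in the new tasks/msdp/prompt.py above.

# Client-side sketch (assumption: the server is reachable at this URL; adjust to your launch).
import json
import requests

payload = {"prompts": ["Megatron-LM is"], "tokens_to_generate": 32, "top_k": 1}
resp = requests.put("http://localhost:5000/api",
                    headers={"Content-Type": "application/json; charset=UTF-8"},
                    data=json.dumps(payload))
print(resp.json()["text"][0])  # the prompt followed by the generated continuation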