OpenDAS / Megatron-LM · Commits

Commit aaa7aa32
authored Dec 06, 2021 by zihanl

remove finetune part

parent a87777bf

Showing 2 changed files with 1 addition and 142 deletions (+1, -142):

    tasks/knwl_dialo/evaluate.py    +1  -111
    tasks/knwl_dialo/utils.py       +0  -31
tasks/knwl_dialo/evaluate.py (view file @ aaa7aa32)

...
@@ -2,116 +2,10 @@
 """Model evaluation"""
 
 from megatron import get_args
-from megatron import get_timers
 from megatron import print_rank_0
-from megatron import get_tokenizer
-from megatron.training import evaluate_and_print_results
-from megatron.training import setup_model_and_optimizer
-from megatron.checkpointing import load_checkpoint
-from tasks.finetune_utils import build_data_loader
-from tasks.knwl_dialo.data import build_test_dataset
-from tasks.knwl_dialo.data import build_test_dataset_for_prompting
-from tasks.knwl_dialo.finetune import model_provider
-from tasks.knwl_dialo.finetune import process_batch
-from tasks.knwl_dialo.finetune import loss_func
-from tasks.knwl_dialo.finetune import forward_step
 from tasks.knwl_dialo.metrics import F1Metric
 from tqdm import tqdm
-
-
-def test_dataset_provider():
-    """Build the test dataset"""
-    args = get_args()
-
-    print_rank_0('> building the test dataset for %s module ...' \
-        % args.module)
-
-    if args.prompt_type != "":
-        print_rank_0('> evaluating ppl for prompting')
-        test_ds = build_test_dataset_for_prompting(
-            test_data_path=args.test_data_path,
-            prompt_file=args.prompt_file,
-            module=args.module,
-            max_seq_len=args.seq_length,
-            num_prompt_examples=args.num_prompt_examples,
-            three_turns=args.three_turns,
-            dynamic_prompt=args.dynamic_prompt)
-    else:
-        print_rank_0('> evaluating ppl for finetuning')
-        test_ds = build_test_dataset(
-            test_data_path=args.test_data_path,
-            module=args.module,
-            max_seq_len=args.seq_length,
-            last_turn=args.last_turn,
-            no_control_code=args.no_control_code,
-            add_separator=args.add_separator,
-            add_ctrl_code_to_dialog=args.add_ctrl_code_to_dialog,
-            remove_ctrl_sent=args.remove_ctrl_sent)
-
-    print_rank_0("> finished creating the test dataset for %s module ..." \
-        % args.module)
-    print_rank_0('> test set size: %d' % len(test_ds))
-
-    args.eval_iters = len(test_ds) // args.global_batch_size
-    print_rank_0('> evaluation iteration: %d' % args.eval_iters)
-
-    return test_ds
-
-
-def _build_test_iterator(test_dataset, task_collate_fn=None):
-    """Test dataloader."""
-    args = get_args()
-
-    print_rank_0('building test dataloader ...')
-
-    # Test loader
-    test_dataloader = build_data_loader(test_dataset, args.micro_batch_size,
-                                        args.num_workers, not args.keep_last,
-                                        task_collate_fn)
-    test_iterator = test_dataloader.__iter__()
-
-    return test_iterator
-
-
-def evaluate_ppl(test_dataset_provider, model_provider, forward_step):
-    """Evaluating perplexity"""
-    args = get_args()
-    timers = get_timers()
-
-    # test dataloader.
-    timers('test dataset/dataloder').start()
-    test_dataset = test_dataset_provider()
-    test_iterator = _build_test_iterator(test_dataset)
-    timers('test dataset/dataloder').stop()
-
-    timers('model and optimizer').start()
-    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
-    timers('model and optimizer').stop()
-
-    timers('pretrained checkpoint').start()
-    if args.pretrained_checkpoint is not None:
-        original_load = args.load
-        args.load = args.pretrained_checkpoint
-        original_rng = args.no_load_rng
-        args.no_load_rng = True
-        iteration = load_checkpoint(model, None, None)
-        args.load = original_load
-        args.no_load_rng = original_rng
-        # This is critical when only model is loaded. We should make sure
-        # main parameters are also updated.
-        optimizer.reload_model_params()
-    timers('pretrained checkpoint').stop()
-
-    # Print setup timing.
-    print_rank_0('done with setups ...')
-    timers.log(['test dataset/dataloder', 'model and optimizer',
-                'pretrained checkpoint'])
-    print_rank_0('evaluating ...')
-
-    prefix = 'iteration {}'.format(iteration)
-    evaluate_and_print_results(prefix, forward_step, test_iterator,
-                               model, iteration, False)
-
-    print_rank_0('done :-)')
 
 
 def evaluate_f1(guess_file, answer_file):
     """Evaluating F1 Score"""
...
@@ -146,9 +40,5 @@ def evaluate_f1(guess_file, answer_file):
 
 def main():
     args = get_args()
 
-    if 'PPL' in args.task:
-        evaluate_ppl(test_dataset_provider, model_provider, forward_step)
-
-    elif 'F1' in args.task:
-        evaluate_f1(args.guess_file, args.answer_file)
+    evaluate_f1(args.guess_file, args.answer_file)
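With the perplexity path removed, evaluate.py now only drives the F1 comparison between a guess file and an answer file. As a point of reference, the sketch below shows the token-overlap F1 commonly used for this kind of dialogue evaluation; it is an illustrative stand-alone version, not the repository's F1Metric class, and the one-response-per-line file format is an assumption made for the example.

# Illustrative only: unigram-overlap F1 between generated and reference
# responses. This is NOT tasks/knwl_dialo/metrics.F1Metric; the aligned
# one-response-per-line file layout is an assumption for this sketch.
from collections import Counter


def token_f1(guess, answer):
    """Unigram F1 between a hypothesis string and a reference string."""
    guess_tokens = guess.split()
    answer_tokens = answer.split()
    if not guess_tokens or not answer_tokens:
        return 0.0
    overlap = sum((Counter(guess_tokens) & Counter(answer_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(guess_tokens)
    recall = overlap / len(answer_tokens)
    return 2 * precision * recall / (precision + recall)


def corpus_f1(guess_file, answer_file):
    """Average per-line F1 over two aligned files."""
    with open(guess_file) as fg, open(answer_file) as fa:
        pairs = list(zip(fg.read().splitlines(), fa.read().splitlines()))
    return sum(token_f1(g, a) for g, a in pairs) / max(len(pairs), 1)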
tasks/knwl_dialo/utils.py (view file @ aaa7aa32)

...
@@ -12,37 +12,6 @@ from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.model import Float16Module
 
 
-def get_ltor_attention_masks_and_position_ids(data, eod_token_id):
-    """
-    Build attention masks and position id for left to right model.
-    Different from the existing get_ltor_masks_and_position_ids function,
-    we add padding to the input sequences to make sure their lengths are the same.
-    """
-    micro_batch_size, seq_length = data.size()
-
-    # Attention mask
-    attention_mask = torch.tril(torch.ones(
-        (micro_batch_size, seq_length, seq_length), device=data.device)).view(
-            micro_batch_size, 1, seq_length, seq_length)
-
-    # mask padded tokens
-    for b in range(micro_batch_size):
-        for idx in range(seq_length - 1):
-            if data[b, idx] == eod_token_id:
-                # pad tokens that come after the eod token
-                attention_mask[b, 0, idx + 1:, :] = 0.0
-
-    # Position ids.
-    position_ids = torch.arange(seq_length, dtype=torch.long,
-                                device=data.device)
-    position_ids = position_ids.unsqueeze(0).expand_as(data)
-
-    # Convert attention mask to binary:
-    attention_mask = (attention_mask < 0.5)
-
-    return attention_mask, position_ids
-
-
 def switch(val1, val2, boolean):
     """Return either val1 or val2 depending on boolean"""
...
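The helper deleted above builds a causal (lower-triangular) mask and then blanks out every query position that follows the end-of-document token, so padded tokens in shorter sequences of a batch take no part in attention. A small illustration of that behaviour, assuming the removed definition is still in scope and taking 0 as the eod/pad token id purely for the example:

# Toy illustration of the removed helper (definition shown in the diff above).
# Assumes get_ltor_attention_masks_and_position_ids is importable/in scope and
# that 0 is the eod/pad token id; both are assumptions for this example only.
import torch

# Two sequences of length 6; the second ends early at position 3 (eod = 0).
data = torch.tensor([[5, 7, 2, 9, 4, 8],
                     [5, 7, 2, 0, 0, 0]])

attention_mask, position_ids = get_ltor_attention_masks_and_position_ids(
    data, eod_token_id=0)

# attention_mask has shape [batch, 1, seq_len, seq_len]; after the `< 0.5`
# binarization, True marks positions that must NOT be attended to.
print(attention_mask.shape)            # torch.Size([2, 1, 6, 6])

# For the second sequence, every query position after the eod token is fully
# masked, on top of the usual causal masking.
print(attention_mask[1, 0, 4].all())   # tensor(True)

# position_ids are simply 0..seq_len-1 broadcast across the batch.
print(position_ids[0])                 # tensor([0, 1, 2, 3, 4, 5])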