Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
4e48efdf
Commit
4e48efdf
authored
Dec 13, 2021
by
zihanl
Browse files
change directory name
parent
f24c972c
Changes
3
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
0 additions
and
730 deletions
+0
-730
tasks/knwl_dialo/evaluate.py
tasks/knwl_dialo/evaluate.py
+0
-58
tasks/knwl_dialo/metrics.py
tasks/knwl_dialo/metrics.py
+0
-77
tasks/knwl_dialo/preprocessing.py
tasks/knwl_dialo/preprocessing.py
+0
-595
No files found.
tasks/knwl_dialo/evaluate.py
deleted
100644 → 0
View file @
f24c972c
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Model evaluation"""
from
megatron
import
get_args
from
megatron
import
print_rank_0
from
tasks.knwl_dialo.metrics
import
F1Metric
from
tqdm
import
tqdm
def evaluate_f1(guess_file, answer_file):
    """Report token-level precision, recall and F1 for two line-aligned files.

    Both files are read line by line and must hold the same number of
    lines. The averaged scores returned by ``F1Metric.compute_all_pairs``
    are reported through ``print_rank_0``; nothing is returned.

    :param guess_file: path of the text file with one generated line per row
    :param answer_file: path of the text file with one reference line per row
    :raises AssertionError: if the two files have different line counts
    """
    end_marker = "<|endoftext|>"
    empty_marker = "no_passages_used"

    print_rank_0('reading %s' % guess_file)
    with open(guess_file, "r") as guess_stream:
        # Trim surrounding whitespace first, then remove every end-of-text
        # marker (the result is not trimmed again).
        guess_list = [
            row.strip().replace(end_marker, "") for row in tqdm(guess_stream)
        ]

    print_rank_0('reading %s' % answer_file)
    answer_list = []
    with open(answer_file, "r") as answer_stream:
        for row in tqdm(answer_stream):
            text = row.strip()
            # A row that is exactly the placeholder is stored as an empty
            # reference.
            answer_list.append("" if text == empty_marker else text)

    assert len(guess_list) == len(answer_list), \
        "lengths of guess and answer are different!"

    scores = F1Metric.compute_all_pairs(guess_list, answer_list)
    precision, recall, f1 = scores
    print_rank_0(
        'Precision: %.4f; recall: %.4f; f1: %.4f' % (precision, recall, f1))

    print_rank_0('done :-)')
def main():
    """Run the F1 evaluation on the files named by the global arguments.

    Reads ``guess_file`` and ``answer_file`` from the object returned by
    ``get_args()`` and hands them to ``evaluate_f1``.
    """
    run_args = get_args()
    guess_path, answer_path = run_args.guess_file, run_args.answer_file
    evaluate_f1(guess_path, answer_path)
tasks/knwl_dialo/metrics.py
deleted
100644 → 0
View file @
f24c972c
# The following code is adapted from
# https://github.com/facebookresearch/ParlAI/blob/master/parlai/core/metrics.py,
# which is licensed under the MIT license. More details on the license can be
# found at https://github.com/facebookresearch/ParlAI/blob/master/LICENSE.
"""Provides standard metric evaluations for dialog."""
from
collections
import
Counter
from
typing
import
List
import
numpy
as
np
import
re
# Stand-alone English articles ("a", "an", "the"), matched on word boundaries.
re_art = re.compile(r'\b(a|an|the)\b')
# Any single ASCII punctuation character listed in the class.
re_punc = re.compile(r'[!"#$%&()*+,-./:;<=>?@\[\]\\^`{|}~_\']')


def normalize_answer(s):
    """
    Return a canonical form of ``s`` for token comparison.

    The text is lower-cased, each punctuation character and each article is
    replaced by a space, and runs of whitespace are collapsed to single
    spaces with no leading or trailing space.
    """
    # Punctuation goes first so that articles glued to punctuation
    # (e.g. "the," or "'a'") become stand-alone words before article removal.
    without_punctuation = re_punc.sub(' ', s.lower())
    without_articles = re_art.sub(' ', without_punctuation)
    words = without_articles.split()
    return ' '.join(words)
class F1Metric:
    """
    Static helpers that score predictions against references with
    token-level precision, recall and F1.
    """

    @staticmethod
    def _prec_recall_f1_score(pred_items, gold_items):
        """
        Score one token sequence against another by multiset overlap.

        :param pred_items: sequence of predicted tokens
        :param gold_items: sequence of gold tokens
        :return: tuple (precision, recall, f1); (0, 0, 0) when no token is
            shared
        """
        # Counter intersection keeps, per token, the smaller of the two
        # counts, so repeated tokens are only credited as often as they
        # appear on both sides.
        shared = Counter(gold_items) & Counter(pred_items)
        overlap = sum(shared.values())
        if overlap == 0:
            return 0, 0, 0

        hits = 1.0 * overlap
        precision = hits / len(pred_items)
        recall = hits / len(gold_items)
        f1 = (2 * precision * recall) / (precision + recall)
        return precision, recall, f1

    @staticmethod
    def compute_each_pair(guess: str, answer: str):
        """
        Score a single guess against a single answer.

        :param guess: predicted text
        :param answer: reference text
        :return: (None, None, None) when ``answer`` is the empty string,
            (0, 0, 0) when only ``guess`` is the empty string, otherwise the
            (precision, recall, f1) of the normalized, whitespace-split texts
        """
        if answer == "":
            # No reference to compare against: signal "not scorable".
            return None, None, None
        if guess == "":
            return 0, 0, 0

        guess_tokens = normalize_answer(guess).split()
        answer_tokens = normalize_answer(answer).split()
        return F1Metric._prec_recall_f1_score(guess_tokens, answer_tokens)

    @staticmethod
    def compute_all_pairs(guesses: List[str], answers: List[str]):
        """
        Average precision, recall and F1 over aligned guesses and answers.

        Pairs for which ``compute_each_pair`` reports ``None`` are left out
        of the averages. If no pair is scored, ``np.mean`` is applied to
        empty lists and the results are nan.

        :param guesses: predicted texts
        :param answers: reference texts, aligned with ``guesses``
        :return: tuple (mean precision, mean recall, mean f1)
        :raises AssertionError: if the two lists differ in length
        """
        assert len(guesses) == len(answers)

        scored = []
        for guess, answer in zip(guesses, answers):
            triple = F1Metric.compute_each_pair(guess, answer)
            if any(value is None for value in triple):
                continue
            scored.append(triple)

        precision_list = [p for p, _, _ in scored]
        recall_list = [r for _, r, _ in scored]
        f1_list = [f for _, _, f in scored]
        return np.mean(precision_list), np.mean(recall_list), np.mean(f1_list)
tasks/knwl_dialo/preprocessing.py
deleted
100644 → 0
View file @
f24c972c
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment