OpenDAS / Megatron-LM / Commits

Commit ab754c8c, authored Apr 03, 2020 by Raul Puri (parent 5dd2a9ad)

    functionalized code

Showing 1 changed file with 101 additions and 65 deletions:
tasks/ensemble_classifier.py (+101, -65)
Before this commit the script ran at module scope: the imports (torch, os, numpy, argparse, collections) were followed by a module-level argparse.ArgumentParser() with bare add_argument calls, no help strings, for --paths, --eval, --outdir, --prediction-name, --calc-threshold, --one-threshold, --threshold, and --labels, an immediate args = parser.parse_args(), and then the prediction-loading loop, the averaging and thresholding steps, and the output-writing loop as top-level statements, including module-level all_correct = 0 and count = 0 accuracy counters. The threshold helpers already existed as functions but reached into the global args. The commit moves each stage into its own function and defers argument parsing to a new main(). The file after the commit, with the hunks that the diff view collapsed left as '...' markers:

import os
import argparse
import collections

import numpy as np
import torch


def process_files(args):
    # Collect predictions, labels, and example uids from every ensemble member.
    all_predictions = collections.OrderedDict()
    all_labels = collections.OrderedDict()
    all_uid = collections.OrderedDict()
    for path in args.paths:
        path = os.path.join(path, args.prediction_name)
        try:
            data = torch.load(path)
            # ... (accumulation lines collapsed in the diff view: @@ -38,10 +29,11 @@)
        except Exception as e:
            # A member whose prediction file is missing or unreadable is
            # skipped rather than aborting the whole ensemble.
            print(e)
            continue
    return all_predictions, all_labels, all_uid
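The accumulation logic itself is collapsed above, but the rest of the script implies a contract for each member file: per dataset name, an (N, num_classes) prediction array, labels (possibly absent), and per-example uids. A hedged sketch of writing a compatible dummy member; the directory name, dataset key, and tuple layout are all invented for illustration, not taken from the collapsed hunk:

import os
import numpy as np
import torch

os.makedirs('ckpt_seed1', exist_ok=True)        # hypothetical member directory
p = np.random.rand(8, 3)
p /= p.sum(-1, keepdims=True)                   # rows as probability distributions
dummy = {'mnli': (p, np.random.randint(3, size=8), np.arange(8))}  # assumed layout
torch.save(dummy, os.path.join('ckpt_seed1', 'test_predictions.pt'))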
def get_threshold(all_predictions, all_labels, one_threshold=False):
    # Was get_threshold(all_predictions, all_labels), reading the module-level
    # args.one_threshold; the flag is now an explicit parameter.
    if one_threshold:
        all_predictons = {'combined': np.concatenate(list(all_predictions.values()))}
        all_labels = {'combined': np.concatenate(list(all_predictions.labels()))}
    out_thresh = []
    # ... (per-dataset loop collapsed in the diff view: @@ -50,6 +42,8 @@)
        labels = all_labels[dataset]
        out_thresh.append(calc_threshold(preds, labels))
    return out_thresh
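Two pre-existing issues ride through the refactor untouched: 'all_predictons' is misspelled, so the combined prediction dict is built and then discarded, and dicts have no .labels() method, so the one_threshold branch would raise AttributeError if it ever ran. A hedged correction, assuming the intent was to pool predictions and labels symmetrically:

if one_threshold:
    all_predictions = {'combined': np.concatenate(list(all_predictions.values()))}
    all_labels = {'combined': np.concatenate(list(all_labels.values()))}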
def calc_threshold(p, l):
    trials = [(i) * (1. / 100.) for i in range(100)]
    best_acc = float('-inf')
    # ... (sweep body collapsed in the diff view: @@ -61,6 +55,7 @@)
            best_thresh = t
    return best_thresh
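The sweep body is collapsed, but the surviving lines pin down its shape: 100 candidate thresholds at 0.01 spacing, a running best accuracy, and the best threshold returned. A minimal self-contained sketch of such a sweep, assuming (N, 2) probabilities p and integer labels l; the real collapsed code may differ:

def calc_threshold_sketch(p, l):
    # Try t = 0.00, 0.01, ..., 0.99 and keep the most accurate hard decision.
    best_acc, best_thresh = float('-inf'), 0.0
    for t in [i * (1. / 100.) for i in range(100)]:
        acc = ((p[:, -1] > t).astype(int) == l).mean()
        if acc > best_acc:
            best_acc, best_thresh = acc, t
    return best_thresh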
def apply_threshold(preds, t):
    assert (np.allclose(preds.sum(-1), np.ones(preds.shape[0])))
    prob = preds[:, -1]
    # ... (collapsed in the diff view: @@ -69,6 +64,7 @@)
    preds[np.arange(len(thresholded)), thresholded.reshape(-1)] = 1
    return preds
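What survives of apply_threshold reads as: assert each row is a probability distribution, take the last column as the positive-class probability, and, in the collapsed part, turn each row into a one-hot vote. A worked toy example under that reading; the comparison step is assumed, not shown:

import numpy as np

preds = np.array([[0.7, 0.3],
                  [0.4, 0.6]])
prob = preds[:, -1]                      # positive-class probability per row
thresholded = (prob > 0.5).astype(int)   # assumed form of the collapsed step
onehot = np.zeros_like(preds)
onehot[np.arange(len(thresholded)), thresholded.reshape(-1)] = 1
print(onehot)                            # -> [[1. 0.] [0. 1.]]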
def threshold_predictions(all_predictions, threshold):
    if len(threshold) != len(all_predictions):
        threshold = [threshold[-1]] * (len(all_predictions) - len(threshold))
    # ... (collapsed in the diff view: @@ -78,17 +74,25 @@)
        all_predictions[dataset] = apply_threshold(preds, thresh)
    return all_predictions
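The length-mismatch handling looks suspect even after the refactor: when fewer thresholds than datasets are supplied, the list is replaced by the last threshold repeated len(all_predictions) - len(threshold) times, which both drops the supplied values and produces the wrong count (one threshold for three datasets yields a list of length two). A hedged repair that pads instead of replacing:

if len(threshold) < len(all_predictions):
    threshold = threshold + [threshold[-1]] * (len(all_predictions) - len(threshold))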
def postprocess_predictions(all_predictions, all_labels, args):
    # Previously a top-level loop; the summed member predictions are averaged here.
    for d in all_predictions:
        all_predictions[d] = all_predictions[d] / len(args.paths)

    if args.calc_threshold:
        # Was get_threshold(all_predictions, all_labels), with the flag read
        # from the global args inside the callee.
        args.threshold = get_threshold(all_predictions, all_labels, args.one_threshold)
        print('threshold', args.threshold)

    if args.threshold is not None:
        all_predictions = threshold_predictions(all_predictions, args.threshold)

    return all_predictions, all_labels
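Since process_files presumably sums each member's predictions per dataset (the accumulation is collapsed) and postprocess_predictions divides by the member count, the ensemble is a plain mean of the members' class probabilities. A toy check:

import numpy as np

member_a = np.array([[0.9, 0.1], [0.2, 0.8]])
member_b = np.array([[0.6, 0.4], [0.4, 0.6]])
summed = member_a + member_b          # what process_files accumulates
ensemble = summed / 2                 # all_predictions[d] / len(args.paths)
print(np.argmax(ensemble, -1))        # -> [0 1]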
def write_predictions(all_predictions, all_labels, all_uid, args):
    # Previously top-level code; all_correct and count were module-level counters.
    all_correct = 0
    count = 0
    for dataset in all_predictions:
        preds = all_predictions[dataset]
        preds = np.argmax(preds, -1)
        if args.eval:
            # ... (accuracy bookkeeping and outpath construction collapsed
            # in the diff view: @@ -105,5 +109,37 @@)
        with open(outpath, 'w') as f:
            f.write('id\tlabel\n')
            f.write('\n'.join(str(uid) + '\t' + str(args.labels[p])
                              for uid, p in zip(all_uid[dataset], preds.tolist())))
    if args.eval:
        print(all_correct / count)
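The writer emits one GLUE-style TSV per dataset: an 'id' and 'label' header row, then one line per example mapping the argmax index through the user-supplied args.labels names. Reconstructing one row by hand, with the uid and label names invented:

labels = ['not_entailment', 'entailment']   # invented --labels values
uid, p = 63, 1                              # invented uid and argmax index
row = str(uid) + '\t' + str(labels[p])      # -> '63\tentailment'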
The commit's net new code ties the stages together and moves argument parsing into main(), now with help strings:

def ensemble_predictions(args):
    all_predictions, all_labels, all_uid = process_files(args)
    all_predictions, all_labels = postprocess_predictions(all_predictions, all_labels, args)
    write_predictions(all_predictions, all_labels, all_uid, args)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--paths', required=True, nargs='+',
                        help='paths to checkpoint directories used in ensemble')
    parser.add_argument('--eval', action='store_true',
                        help='compute accuracy metrics against labels (dev set)')
    parser.add_argument('--outdir',
                        help='directory to place ensembled predictions in')
    parser.add_argument('--prediction-name', default='test_predictions.pt',
                        help='name of predictions in checkpoint directories')
    parser.add_argument('--calc-threshold', action='store_true',
                        help='calculate threshold classification')
    parser.add_argument('--one-threshold', action='store_true',
                        help='use one threshold for all subdatasets')
    parser.add_argument('--threshold', nargs='+', default=None, type=float,
                        help='user supplied threshold for classification')
    parser.add_argument('--labels', nargs='+', default=None,
                        help='whitespace separated list of label names')
    args = parser.parse_args()
    ensemble_predictions(args)


if __name__ == '__main__':
    main()

(The file ends without a trailing newline.)
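A hypothetical invocation; the directory and label names are invented, and only the flags themselves come from the parser above:

python tasks/ensemble_classifier.py \
    --paths ckpt_seed1 ckpt_seed2 ckpt_seed3 \
    --outdir ensemble_out \
    --labels not_entailment entailment \
    --eval --calc-threshold --one-threshold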