Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
6ae4a03c
Commit
6ae4a03c
authored
Apr 14, 2020
by
Mohammad
Browse files
Merge branch 'staging' into staging_gpt2_dataloader
parents
e06159f2
7890681a
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
163 additions
and
13 deletions
+163
-13
megatron/initialize.py
megatron/initialize.py
+18
-13
tasks/ensemble_classifier.py
tasks/ensemble_classifier.py
+145
-0
No files found.
megatron/initialize.py
View file @
6ae4a03c
...
...
@@ -59,6 +59,7 @@ def _initialize_distributed():
"""Initialize torch.distributed and mpu."""
args
=
get_args
()
device_count
=
torch
.
cuda
.
device_count
()
if
torch
.
distributed
.
is_initialized
():
if
args
.
rank
==
0
:
...
...
@@ -66,23 +67,25 @@ def _initialize_distributed():
'skipping initialization ...'
,
flush
=
True
)
args
.
rank
=
torch
.
distributed
.
get_rank
()
args
.
world_size
=
torch
.
distributed
.
get_world_size
()
device
=
torch
.
cuda
.
current_device
()
local_rank
=
args
.
rank
%
torch
.
cuda
.
device_count
()
assert
local_rank
==
device
,
\
'expected local-rank to be the same as rank % device-count.'
if
device_count
>
0
:
device
=
torch
.
cuda
.
current_device
()
local_rank
=
args
.
rank
%
device_count
assert
local_rank
==
device
,
\
'expected local-rank to be the same as rank % device-count.'
else
:
if
args
.
rank
==
0
:
print
(
'> initializing torch distributed ...'
,
flush
=
True
)
# Manually set the device ids.
device
=
args
.
rank
%
torch
.
cuda
.
device_count
()
if
args
.
local_rank
is
not
None
:
assert
args
.
local_rank
==
device
,
\
'expected local-rank to be the same as rank % device-count.'
else
:
args
.
local_rank
=
device
torch
.
cuda
.
set_device
(
device
)
if
device_count
>
0
:
device
=
args
.
rank
%
device_count
if
args
.
local_rank
is
not
None
:
assert
args
.
local_rank
==
device
,
\
'expected local-rank to be the same as rank % device-count.'
else
:
args
.
local_rank
=
device
torch
.
cuda
.
set_device
(
device
)
# Call the init process
init_method
=
'tcp://'
master_ip
=
os
.
getenv
(
'MASTER_ADDR'
,
'localhost'
)
...
...
@@ -94,7 +97,8 @@ def _initialize_distributed():
init_method
=
init_method
)
# Set the model-parallel / data-parallel communicators.
mpu
.
initialize_model_parallel
(
args
.
model_parallel_size
)
if
device_count
>
0
:
mpu
.
initialize_model_parallel
(
args
.
model_parallel_size
)
def
_init_autoresume
():
...
...
@@ -112,7 +116,8 @@ def _set_random_seed(seed):
random
.
seed
(
seed
)
np
.
random
.
seed
(
seed
)
torch
.
manual_seed
(
seed
)
mpu
.
model_parallel_cuda_manual_seed
(
seed
)
if
torch
.
cuda
.
device_count
()
>
0
:
mpu
.
model_parallel_cuda_manual_seed
(
seed
)
else
:
raise
ValueError
(
'Seed ({}) should be a positive integer.'
.
format
(
seed
))
...
...
tasks/ensemble_classifier.py
0 → 100644
View file @
6ae4a03c
import
os
import
argparse
import
collections
import
numpy
as
np
import
torch
def
process_files
(
args
):
all_predictions
=
collections
.
OrderedDict
()
all_labels
=
collections
.
OrderedDict
()
all_uid
=
collections
.
OrderedDict
()
for
path
in
args
.
paths
:
path
=
os
.
path
.
join
(
path
,
args
.
prediction_name
)
try
:
data
=
torch
.
load
(
path
)
for
dataset
in
data
:
name
,
d
=
dataset
predictions
,
labels
,
uid
=
d
if
name
not
in
all_predictions
:
all_predictions
[
name
]
=
np
.
array
(
predictions
)
if
args
.
labels
is
None
:
args
.
labels
=
[
i
for
i
in
range
(
all_predictions
[
name
].
shape
[
1
])]
if
args
.
eval
:
all_labels
[
name
]
=
np
.
array
(
labels
)
all_uid
[
name
]
=
np
.
array
(
uid
)
else
:
all_predictions
[
name
]
+=
np
.
array
(
predictions
)
assert
np
.
allclose
(
all_uid
[
name
],
np
.
array
(
uid
))
except
Exception
as
e
:
print
(
e
)
continue
return
all_predictions
,
all_labels
,
all_uid
def
get_threshold
(
all_predictions
,
all_labels
,
one_threshold
=
False
):
if
one_threshold
:
all_predictons
=
{
'combined'
:
np
.
concatenate
(
list
(
all_predictions
.
values
()))}
all_labels
=
{
'combined'
:
np
.
concatenate
(
list
(
all_predictions
.
labels
()))}
out_thresh
=
[]
for
dataset
in
all_predictions
:
preds
=
all_predictions
[
dataset
]
labels
=
all_labels
[
dataset
]
out_thresh
.
append
(
calc_threshold
(
preds
,
labels
))
return
out_thresh
def
calc_threshold
(
p
,
l
):
trials
=
[(
i
)
*
(
1.
/
100.
)
for
i
in
range
(
100
)]
best_acc
=
float
(
'-inf'
)
best_thresh
=
0
for
t
in
trials
:
acc
=
((
apply_threshold
(
p
,
t
).
argmax
(
-
1
)
==
l
).
astype
(
float
)).
mean
()
if
acc
>
best_acc
:
best_acc
=
acc
best_thresh
=
t
return
best_thresh
def
apply_threshold
(
preds
,
t
):
assert
(
np
.
allclose
(
preds
.
sum
(
-
1
),
np
.
ones
(
preds
.
shape
[
0
])))
prob
=
preds
[:,
-
1
]
thresholded
=
(
prob
>=
t
).
astype
(
int
)
preds
=
np
.
zeros_like
(
preds
)
preds
[
np
.
arange
(
len
(
thresholded
)),
thresholded
.
reshape
(
-
1
)]
=
1
return
preds
def
threshold_predictions
(
all_predictions
,
threshold
):
if
len
(
threshold
)
!=
len
(
all_predictions
):
threshold
=
[
threshold
[
-
1
]]
*
(
len
(
all_predictions
)
-
len
(
threshold
))
for
i
,
dataset
in
enumerate
(
all_predictions
):
thresh
=
threshold
[
i
]
preds
=
all_predictions
[
dataset
]
all_predictions
[
dataset
]
=
apply_threshold
(
preds
,
thresh
)
return
all_predictions
def
postprocess_predictions
(
all_predictions
,
all_labels
,
args
):
for
d
in
all_predictions
:
all_predictions
[
d
]
=
all_predictions
[
d
]
/
len
(
args
.
paths
)
if
args
.
calc_threshold
:
args
.
threshold
=
get_threshold
(
all_predictions
,
all_labels
,
args
.
one_threshold
)
print
(
'threshold'
,
args
.
threshold
)
if
args
.
threshold
is
not
None
:
all_predictions
=
threshold_predictions
(
all_predictions
,
args
.
threshold
)
return
all_predictions
,
all_labels
def
write_predictions
(
all_predictions
,
all_labels
,
all_uid
,
args
):
all_correct
=
0
count
=
0
for
dataset
in
all_predictions
:
preds
=
all_predictions
[
dataset
]
preds
=
np
.
argmax
(
preds
,
-
1
)
if
args
.
eval
:
correct
=
(
preds
==
all_labels
[
dataset
]).
sum
()
num
=
len
(
all_labels
[
dataset
])
accuracy
=
correct
/
num
count
+=
num
all_correct
+=
correct
accuracy
=
(
preds
==
all_labels
[
dataset
]).
mean
()
print
(
accuracy
)
if
not
os
.
path
.
exists
(
os
.
path
.
join
(
args
.
outdir
,
dataset
)):
os
.
makedirs
(
os
.
path
.
join
(
args
.
outdir
,
dataset
))
outpath
=
os
.
path
.
join
(
args
.
outdir
,
dataset
,
os
.
path
.
splitext
(
args
.
prediction_name
)[
0
]
+
'.tsv'
)
with
open
(
outpath
,
'w'
)
as
f
:
f
.
write
(
'id
\t
label
\n
'
)
f
.
write
(
'
\n
'
.
join
(
str
(
uid
)
+
'
\t
'
+
str
(
args
.
labels
[
p
])
for
uid
,
p
in
zip
(
all_uid
[
dataset
],
preds
.
tolist
())))
if
args
.
eval
:
print
(
all_correct
/
count
)
def
ensemble_predictions
(
args
):
all_predictions
,
all_labels
,
all_uid
=
process_files
(
args
)
all_predictions
,
all_labels
=
postprocess_predictions
(
all_predictions
,
all_labels
,
args
)
write_predictions
(
all_predictions
,
all_labels
,
all_uid
,
args
)
def
main
():
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--paths'
,
required
=
True
,
nargs
=
'+'
,
help
=
'paths to checkpoint directories used in ensemble'
)
parser
.
add_argument
(
'--eval'
,
action
=
'store_true'
,
help
=
'compute accuracy metrics against labels (dev set)'
)
parser
.
add_argument
(
'--outdir'
,
help
=
'directory to place ensembled predictions in'
)
parser
.
add_argument
(
'--prediction-name'
,
default
=
'test_predictions.pt'
,
help
=
'name of predictions in checkpoint directories'
)
parser
.
add_argument
(
'--calc-threshold'
,
action
=
'store_true'
,
help
=
'calculate threshold classification'
)
parser
.
add_argument
(
'--one-threshold'
,
action
=
'store_true'
,
help
=
'use on threshold for all subdatasets'
)
parser
.
add_argument
(
'--threshold'
,
nargs
=
'+'
,
default
=
None
,
type
=
float
,
help
=
'user supplied threshold for classification'
)
parser
.
add_argument
(
'--labels'
,
nargs
=
'+'
,
default
=
None
,
help
=
'whitespace separated list of label names'
)
args
=
parser
.
parse_args
()
ensemble_predictions
(
args
)
if
__name__
==
'__main__'
:
main
()
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment