OpenDAS / Megatron-LM · Commit 0c151638

Add implementation for pipelined zeroshot GPT-2 evaluation

Authored Dec 09, 2020 by Jared Casper; committed by Deepak Narayanan, Dec 19, 2020.
Parent: 3afcba6e

Showing 1 changed file with 92 additions and 49 deletions (+92 −49):
tasks/zeroshot_gpt2/evaluate.py
@@ -20,12 +20,12 @@ import math
 import torch

 from megatron import get_args
-from megatron import print_rank_0
+from megatron import print_rank_0, is_last_rank
 from megatron import get_tokenizer
 from megatron import mpu
 from megatron.checkpointing import load_checkpoint
-from megatron.model import GPT2Model
+from megatron.model import GPT2Model, GPT2ModelFirstStage, GPT2ModelLastStage, GPT2ModelIntermediateStage
-from megatron.training import get_model
+from megatron.training import get_model, communicate
 from megatron.utils import get_ltor_masks_and_position_ids
 from tasks.finetune_utils import build_data_loader
@@ -48,7 +48,17 @@ def get_model_provider(eval_metric):
                          'is not supported.'.format(eval_metric))

     print_rank_0('building GPT2 model ...')
-    model = GPT2Model(num_tokentypes=0, parallel_output=parallel_output)
+    if mpu.get_pipeline_model_parallel_world_size() > 1:
+        # Determine model based on position of stage in pipeline.
+        if mpu.is_pipeline_first_stage():
+            model = GPT2ModelFirstStage(num_tokentypes=0)
+        elif mpu.is_pipeline_last_stage():
+            model = GPT2ModelLastStage(parallel_output=parallel_output,
+                                       num_tokentypes=0)
+        else:
+            model = GPT2ModelIntermediateStage(num_tokentypes=0)
+    else:
+        model = GPT2Model(num_tokentypes=0, parallel_output=parallel_output)

     return model
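
Note: the hunk above picks a different module class depending on where the current process sits in the pipeline. The sketch below illustrates the same selection pattern with plain PyTorch modules; it deliberately avoids Megatron's mpu helpers, and the hypothetical build_stage_module, pipeline_rank, pipeline_world_size, hidden and vocab names are assumptions for the example, not the project's API.

# Illustrative sketch only: stage-dependent module construction for a
# hypothetical pipeline of `pipeline_world_size` stages.
import torch.nn as nn


def build_stage_module(pipeline_rank: int, pipeline_world_size: int,
                       hidden: int = 1024, vocab: int = 50257) -> nn.Module:
    """Return a module appropriate for this pipeline stage."""
    if pipeline_world_size == 1:
        # Single stage: embedding, trunk and output head live together.
        return nn.Sequential(nn.Embedding(vocab, hidden),
                             nn.Linear(hidden, hidden),
                             nn.Linear(hidden, vocab))
    if pipeline_rank == 0:
        # First stage: consumes token ids, produces hidden states.
        return nn.Sequential(nn.Embedding(vocab, hidden),
                             nn.Linear(hidden, hidden))
    if pipeline_rank == pipeline_world_size - 1:
        # Last stage: consumes hidden states, produces logits.
        return nn.Sequential(nn.Linear(hidden, hidden),
                             nn.Linear(hidden, vocab))
    # Intermediate stage: hidden states in, hidden states out.
    return nn.Linear(hidden, hidden)

This mirrors the GPT2ModelFirstStage / GPT2ModelIntermediateStage / GPT2ModelLastStage split in the diff: only the last stage owns the output projection, so only it can produce logits for the evaluation metric.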
@@ -83,27 +93,58 @@ def forward_step(batch, model, eval_metric):
     tokens, labels, attention_mask, position_ids, loss_mask = process_batch(
         batch)

+    # Tell the model what our actual batch size will be
+    args = get_args()
+    args.micro_batch_size = len(labels)
+
     # Forward model.
-    output = model(tokens, position_ids, attention_mask)
+    if not mpu.is_pipeline_first_stage():
+        input_tensor, _ = communicate(
+            tensor_send_next=None,
+            tensor_send_prev=None,
+            recv_forward=True,
+            recv_backward=False)
+    else:
+        input_tensor = None

-    # For loss, return the unreduced loss.
-    if eval_metric == 'loss':
-        losses = mpu.vocab_parallel_cross_entropy(
-            output.contiguous().float(), labels.contiguous())
-        loss = torch.sum(
-            losses.view(-1) * loss_mask.contiguous().view(-1).float())
-        return loss
+    # Forward pass through the model.
+    if mpu.is_pipeline_first_stage():
+        assert input_tensor is None
+        if mpu.is_pipeline_last_stage():
+            output = model(tokens, position_ids, attention_mask)
+        else:
+            output = model(tokens, position_ids, attention_mask)
+    else:
+        assert input_tensor is not None
+        output = model(input_tensor, attention_mask)
+
+    if not mpu.is_pipeline_last_stage():
+        communicate(tensor_send_next=output,
+                    tensor_send_prev=None,
+                    recv_forward=False,
+                    recv_backward=False)
+        return None
+
+    if mpu.is_pipeline_last_stage():
+        # For loss, return the unreduced loss.
+        if eval_metric == 'loss':
+            losses = mpu.vocab_parallel_cross_entropy(
+                output.contiguous().float(), labels.contiguous())
+            loss = torch.sum(
+                losses.view(-1) * loss_mask.contiguous().view(-1).float())
+            return loss

-    # For accuracy, return the number of correctly predicted samples.
-    if eval_metric == 'accuracy':
-        outputs = torch.argmax(output, -1)
-        correct = (outputs == labels).float()
-        correct[(1 - loss_mask).bool()] = 1
-        correct = correct.prod(-1)
-        return correct.sum()
+        # For accuracy, return the number of correctly predicted samples.
+        if eval_metric == 'accuracy':
+            outputs = torch.argmax(output, -1)
+            correct = (outputs == labels).float()
+            correct[(1 - loss_mask).bool()] = 1
+            correct = correct.prod(-1)
+            return correct.sum()

-    raise NotImplementedError('forward method for evaluation metric {} '
-                              'is not implemented.'.format(eval_metric))
+        raise NotImplementedError('forward method for evaluation metric {} '
+                                  'is not implemented.'.format(eval_metric))
+
+    return None


 def evaluate(data_loader, model, eval_metric):
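
Note: in the new forward_step, a stage that is not first receives its input activations from the previous stage via communicate(...), a stage that is not last forwards its output to the next stage and returns None, and only the last stage turns logits into a loss or an accuracy count. The sketch below shows one way such a receive/forward/send relay can be written directly on top of torch.distributed point-to-point ops; the pipelined_forward name, the prev_rank/next_rank arguments and the fixed activation shape are assumptions for illustration, not Megatron's communicate API.

# Minimal sketch of a pipelined evaluation forward step using raw
# torch.distributed send/recv. Assumes the process group is already
# initialized (e.g. via torchrun) and that every stage knows the global
# ranks of its neighbours and the (batch, seq, hidden) activation shape.
import torch
import torch.distributed as dist


def pipelined_forward(stage_module, batch, *, is_first, is_last,
                      prev_rank=None, next_rank=None,
                      act_shape=(4, 1024, 1024), device='cpu'):
    if is_first:
        # First stage consumes the raw batch directly.
        output = stage_module(batch)
    else:
        # Later stages receive activations produced by the previous stage.
        input_tensor = torch.empty(act_shape, device=device)
        dist.recv(input_tensor, src=prev_rank)
        output = stage_module(input_tensor)

    if not is_last:
        # Hand the activations to the next stage; nothing to report locally.
        dist.send(output.contiguous(), dst=next_rank)
        return None

    # Only the last stage holds logits, so only it can compute the metric.
    return output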
@@ -123,10 +164,11 @@ def evaluate(data_loader, model, eval_metric):
             output = forward_step(batch, model, eval_metric)

             # Reduce across processes.
-            torch.distributed.all_reduce(output,
-                                         group=mpu.get_data_parallel_group())
+            if mpu.is_pipeline_last_stage():
+                torch.distributed.all_reduce(output,
+                                             group=mpu.get_data_parallel_group())

-            total_output += output
+                total_output += output

     return total_output
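
Note: because forward_step now returns None everywhere except the last pipeline stage, the data-parallel all_reduce and the running sum are guarded by mpu.is_pipeline_last_stage(). A small sketch of that guarded reduction, with the accumulate name, is_last_stage flag and data_parallel_group handle standing in for the mpu helpers, might look like:

# Sketch: accumulate a per-iteration metric only on the last pipeline stage,
# summing it across data-parallel replicas of that stage.
import torch.distributed as dist


def accumulate(total_output, output, is_last_stage, data_parallel_group=None):
    if is_last_stage:
        # Every data-parallel replica of the last stage contributes its sum.
        dist.all_reduce(output, group=data_parallel_group)
        total_output += output
    return total_output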
@@ -138,33 +180,34 @@ def evaluate_and_print_results(task, data_loader, model, eval_metric):
     output = evaluate(data_loader, model, eval_metric)

     string = ' validation results on {} | '.format(task)
-    if eval_metric == 'loss':
-        num_tokenized_tokens = data_loader.dataset.num_tokenized_tokens
-        num_original_tokens = data_loader.dataset.num_original_tokens
-        val_loss = output / (num_tokenized_tokens - 1)
-        ppl = math.exp(min(20, val_loss))
-        token_ratio = (num_tokenized_tokens - 1) / (num_original_tokens - 1)
-        adjusted_ppl = math.exp(min(20, val_loss * token_ratio))
-        string += 'avg loss: {:.4E} | '.format(val_loss)
-        string += 'ppl: {:.4E} | '.format(ppl)
-        string += 'adjusted ppl: {:.4E} | '.format(adjusted_ppl)
-        string += 'token ratio: {} |'.format(token_ratio)
-    elif eval_metric == 'accuracy':
-        num_examples = len(data_loader.dataset)
-        acc = output / num_examples
-        string += 'number correct: {:.4E} | '.format(output)
-        string += 'total examples: {:.4E} | '.format(num_examples)
-        string += 'avg accuracy: {:.4E}'.format(acc)
-    else:
-        raise NotImplementedError('evaluation method for {} metric is not '
-                                  'implemented yet.'.format(eval_metric))
+    if is_last_rank():
+        if eval_metric == 'loss':
+            num_tokenized_tokens = data_loader.dataset.num_tokenized_tokens
+            num_original_tokens = data_loader.dataset.num_original_tokens
+            val_loss = output / (num_tokenized_tokens - 1)
+            ppl = math.exp(min(20, val_loss))
+            token_ratio = (num_tokenized_tokens - 1) / (num_original_tokens - 1)
+            adjusted_ppl = math.exp(min(20, val_loss * token_ratio))
+            string += 'avg loss: {:.4E} | '.format(val_loss)
+            string += 'ppl: {:.4E} | '.format(ppl)
+            string += 'adjusted ppl: {:.4E} | '.format(adjusted_ppl)
+            string += 'token ratio: {} |'.format(token_ratio)
+        elif eval_metric == 'accuracy':
+            num_examples = len(data_loader.dataset)
+            acc = output / num_examples
+            string += 'number correct: {:.4E} | '.format(output)
+            string += 'total examples: {:.4E} | '.format(num_examples)
+            string += 'avg accuracy: {:.4E}'.format(acc)
+        else:
+            raise NotImplementedError('evaluation method for {} metric is not '
+                                      'implemented yet.'.format(eval_metric))

-    length = len(string) + 1
-    print_rank_0('-' * length)
-    print_rank_0(string)
-    print_rank_0('-' * length)
+        length = len(string) + 1
+        print('-' * length)
+        print(string)
+        print('-' * length)


 def main():
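
Note: the result string is now built and printed only under is_last_rank(), with a plain print instead of print_rank_0, presumably because with pipeline parallelism the evaluated output lives on the last stage rather than on global rank 0. The reporting arithmetic itself is unchanged: val_loss is the summed token loss divided by (num_tokenized_tokens - 1), ppl exponentiates it (clamped at 20 to avoid overflow in math.exp), and adjusted ppl rescales the loss by the tokenized-to-original token ratio so perplexity is expressed per original word. A worked example with made-up counts (all numbers below are hypothetical):

# Worked example of the reporting arithmetic above, with made-up counts.
import math

output = 1.03e6                 # summed token loss returned by evaluate() (hypothetical)
num_tokenized_tokens = 287_645  # tokens produced by the tokenizer (hypothetical)
num_original_tokens = 245_566   # whitespace-split tokens in the raw text (hypothetical)

val_loss = output / (num_tokenized_tokens - 1)            # ~3.581 nats per token
ppl = math.exp(min(20, val_loss))                         # ~35.9
token_ratio = (num_tokenized_tokens - 1) / (num_original_tokens - 1)   # ~1.171
adjusted_ppl = math.exp(min(20, val_loss * token_ratio))  # ~66.3, per original word

print('avg loss: {:.4E} | ppl: {:.4E} | adjusted ppl: {:.4E} | token ratio: {} |'
      .format(val_loss, ppl, adjusted_ppl, token_ratio))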