Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
1113f97f
Commit
1113f97f
authored
Jul 05, 2019
by
thomwolf
Browse files
clean up glue example
parent
162ba383
Changes
4
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
423 additions
and
17 deletions
+423
-17
examples/run_bert_classifier.py
examples/run_bert_classifier.py
+3
-17
examples/run_glue.py
examples/run_glue.py
+401
-0
examples/utils_glue.py
examples/utils_glue.py
+1
-0
pytorch_transformers/tokenization_utils.py
pytorch_transformers/tokenization_utils.py
+18
-0
No files found.
examples/run_bert_classifier.py
View file @
1113f97f
...
@@ -309,14 +309,7 @@ def main():
...
@@ -309,14 +309,7 @@ def main():
# define a new function to compute loss values for both output_modes
# define a new function to compute loss values for both output_modes
ouputs
=
model
(
input_ids
,
token_type_ids
=
segment_ids
,
attention_mask
=
input_mask
,
labels
=
label_ids
)
ouputs
=
model
(
input_ids
,
token_type_ids
=
segment_ids
,
attention_mask
=
input_mask
,
labels
=
label_ids
)
loss
=
loss
=
ouputs
[
0
]
if
output_mode
==
"classification"
:
loss_fct
=
CrossEntropyLoss
()
loss
=
loss_fct
(
logits
.
view
(
-
1
,
num_labels
),
label_ids
.
view
(
-
1
))
elif
output_mode
==
"regression"
:
loss_fct
=
MSELoss
()
loss
=
loss_fct
(
logits
.
view
(
-
1
),
label_ids
.
view
(
-
1
))
if
n_gpu
>
1
:
if
n_gpu
>
1
:
loss
=
loss
.
mean
()
# mean() to average on multi-gpu.
loss
=
loss
.
mean
()
# mean() to average on multi-gpu.
...
@@ -423,15 +416,8 @@ def main():
...
@@ -423,15 +416,8 @@ def main():
label_ids
=
label_ids
.
to
(
device
)
label_ids
=
label_ids
.
to
(
device
)
with
torch
.
no_grad
():
with
torch
.
no_grad
():
logits
=
model
(
input_ids
,
token_type_ids
=
segment_ids
,
attention_mask
=
input_mask
)
outputs
=
model
(
input_ids
,
token_type_ids
=
segment_ids
,
attention_mask
=
input_mask
,
labels
=
label_ids
)
tmp_eval_loss
,
logits
=
outputs
[:
2
]
# create eval loss and other metric required by the task
if
output_mode
==
"classification"
:
loss_fct
=
CrossEntropyLoss
()
tmp_eval_loss
=
loss_fct
(
logits
.
view
(
-
1
,
num_labels
),
label_ids
.
view
(
-
1
))
elif
output_mode
==
"regression"
:
loss_fct
=
MSELoss
()
tmp_eval_loss
=
loss_fct
(
logits
.
view
(
-
1
),
label_ids
.
view
(
-
1
))
eval_loss
+=
tmp_eval_loss
.
mean
().
item
()
eval_loss
+=
tmp_eval_loss
.
mean
().
item
()
nb_eval_steps
+=
1
nb_eval_steps
+=
1
...
...
examples/run_glue.py
0 → 100644
View file @
1113f97f
This diff is collapsed.
Click to expand it.
examples/utils_glue.py
View file @
1113f97f
...
@@ -583,6 +583,7 @@ processors = {
...
@@ -583,6 +583,7 @@ processors = {
output_modes
=
{
output_modes
=
{
"cola"
:
"classification"
,
"cola"
:
"classification"
,
"mnli"
:
"classification"
,
"mnli"
:
"classification"
,
"mnli-mm"
:
"classification"
,
"mrpc"
:
"classification"
,
"mrpc"
:
"classification"
,
"sst-2"
:
"classification"
,
"sst-2"
:
"classification"
,
"sts-b"
:
"regression"
,
"sts-b"
:
"regression"
,
...
...
pytorch_transformers/tokenization_utils.py
View file @
1113f97f
...
@@ -110,6 +110,24 @@ class PreTrainedTokenizer(object):
...
@@ -110,6 +110,24 @@ class PreTrainedTokenizer(object):
return
tokenizer
return
tokenizer
def
tokenize
(
self
,
text
):
raise
NotImplementedError
def
convert_tokens_to_ids
(
self
,
tokens
):
raise
NotImplementedError
def
convert_ids_to_tokens
(
self
,
ids
):
raise
NotImplementedError
def
encode
(
self
,
text
):
raise
NotImplementedError
def
decode
(
self
,
token_ids
,
*
input
,
**
kwargs
):
raise
NotImplementedError
def
save_vocabulary
(
self
,
vocab_path
):
raise
NotImplementedError
def
clean_up_tokenization
(
out_string
):
def
clean_up_tokenization
(
out_string
):
out_string
.
replace
(
' .'
,
'.'
).
replace
(
' ?'
,
'?'
).
replace
(
' !'
,
'!'
).
replace
(
' ,'
,
','
out_string
.
replace
(
' .'
,
'.'
).
replace
(
' ?'
,
'?'
).
replace
(
' !'
,
'!'
).
replace
(
' ,'
,
','
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment