Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
40dbda68
"docs/source/vscode:/vscode.git/clone" did not exist on "40ed717232bf87f42ce2d3c16a14f0015e9c5fa9"
Commit
40dbda68
authored
Jun 18, 2019
by
thomwolf
Browse files
updating classification example
parent
7388c83b
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
20 additions
and
4 deletions
+20
-4
examples/run_classifier.py
examples/run_classifier.py
+20
-4
No files found.
examples/run_classifier.py
View file @
40dbda68
...
@@ -228,10 +228,10 @@ def main():
...
@@ -228,10 +228,10 @@ def main():
# Prepare data loader
# Prepare data loader
train_examples
=
processor
.
get_train_examples
(
args
.
data_dir
)
train_examples
=
processor
.
get_train_examples
(
args
.
data_dir
)
cached_train_features_file
=
args
.
data_dir
+
'
_{0}_{1}_{2}'
.
format
(
cached_train_features_file
=
os
.
path
.
join
(
args
.
data_dir
,
'train
_{0}_{1}_{2}'
.
format
(
list
(
filter
(
None
,
args
.
bert_model
.
split
(
'/'
))).
pop
(),
list
(
filter
(
None
,
args
.
bert_model
.
split
(
'/'
))).
pop
(),
str
(
args
.
max_seq_length
),
str
(
args
.
max_seq_length
),
str
(
task_name
))
str
(
task_name
))
)
try
:
try
:
with
open
(
cached_train_features_file
,
"rb"
)
as
reader
:
with
open
(
cached_train_features_file
,
"rb"
)
as
reader
:
train_features
=
pickle
.
load
(
reader
)
train_features
=
pickle
.
load
(
reader
)
...
@@ -311,7 +311,7 @@ def main():
...
@@ -311,7 +311,7 @@ def main():
input_ids
,
input_mask
,
segment_ids
,
label_ids
=
batch
input_ids
,
input_mask
,
segment_ids
,
label_ids
=
batch
# define a new function to compute loss values for both output_modes
# define a new function to compute loss values for both output_modes
logits
=
model
(
input_ids
,
segment_ids
,
input_mask
,
labels
=
None
)
logits
=
model
(
input_ids
,
segment_ids
,
input_mask
)
if
output_mode
==
"classification"
:
if
output_mode
==
"classification"
:
loss_fct
=
CrossEntropyLoss
()
loss_fct
=
CrossEntropyLoss
()
...
@@ -380,6 +380,22 @@ def main():
...
@@ -380,6 +380,22 @@ def main():
### Evaluation
### Evaluation
if
args
.
do_eval
:
if
args
.
do_eval
:
eval_examples
=
processor
.
get_dev_examples
(
args
.
data_dir
)
eval_examples
=
processor
.
get_dev_examples
(
args
.
data_dir
)
cached_train_features_file
=
os
.
path
.
join
(
args
.
data_dir
,
'dev_{0}_{1}_{2}'
.
format
(
list
(
filter
(
None
,
args
.
bert_model
.
split
(
'/'
))).
pop
(),
str
(
args
.
max_seq_length
),
str
(
task_name
)))
try
:
with
open
(
cached_train_features_file
,
"rb"
)
as
reader
:
train_features
=
pickle
.
load
(
reader
)
except
:
train_features
=
convert_examples_to_features
(
train_examples
,
label_list
,
args
.
max_seq_length
,
tokenizer
,
output_mode
)
if
args
.
local_rank
==
-
1
or
torch
.
distributed
.
get_rank
()
==
0
:
logger
.
info
(
" Saving train features into cached file %s"
,
cached_train_features_file
)
with
open
(
cached_train_features_file
,
"wb"
)
as
writer
:
pickle
.
dump
(
train_features
,
writer
)
eval_features
=
convert_examples_to_features
(
eval_features
=
convert_examples_to_features
(
eval_examples
,
label_list
,
args
.
max_seq_length
,
tokenizer
,
output_mode
)
eval_examples
,
label_list
,
args
.
max_seq_length
,
tokenizer
,
output_mode
)
logger
.
info
(
"***** Running evaluation *****"
)
logger
.
info
(
"***** Running evaluation *****"
)
...
@@ -414,7 +430,7 @@ def main():
...
@@ -414,7 +430,7 @@ def main():
label_ids
=
label_ids
.
to
(
device
)
label_ids
=
label_ids
.
to
(
device
)
with
torch
.
no_grad
():
with
torch
.
no_grad
():
logits
=
model
(
input_ids
,
segment_ids
,
input_mask
,
labels
=
None
)
logits
=
model
(
input_ids
,
segment_ids
,
input_mask
)
# create eval loss and other metric required by the task
# create eval loss and other metric required by the task
if
output_mode
==
"classification"
:
if
output_mode
==
"classification"
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment