GitLab · chenpangpang / transformers · Commits

Commit dee09a40, authored Nov 02, 2018 by thomwolf

various fixes

parent 2c731fd1
Showing 2 changed files with 11 additions and 8 deletions:

- modeling_pytorch.py (+2, -1)
- run_classifier_pytorch.py (+9, -7)
modeling_pytorch.py

````diff
@@ -412,7 +412,8 @@ class BertForSequenceClassification(nn.Module):
     model = modeling.BertModel(config, num_labels)
     logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
     def __init__(self, config, num_labels):
         super(BertForSequenceClassification, self).__init__()
         self.bert = BertModel(config)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
 ...
````
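Only context lines of this hunk are visible on the page: the tail of the class docstring and the first lines of `__init__`; the changed lines themselves are collapsed behind the `...` marker. For orientation, here is a minimal sketch of how such a classification head is typically completed. Only the lines that appear in the hunk are verbatim; the `classifier` layer and the `forward` signature are assumptions, not necessarily what this commit contains.

```python
import torch.nn as nn

# BertModel is defined earlier in modeling_pytorch.py; this sketch assumes it
# returns (encoded_layers, pooled_output) as in the rest of the codebase.

class BertForSequenceClassification(nn.Module):
    def __init__(self, config, num_labels):
        super(BertForSequenceClassification, self).__init__()
        self.bert = BertModel(config)                            # shown in the hunk
        self.dropout = nn.Dropout(config.hidden_dropout_prob)    # shown in the hunk
        self.classifier = nn.Linear(config.hidden_size, num_labels)  # assumed

    def forward(self, input_ids, token_type_ids, input_mask):    # assumed signature
        _, pooled_output = self.bert(input_ids, token_type_ids, input_mask)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        return logits
```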
run_classifier_pytorch.py

```diff
@@ -73,8 +73,8 @@ parser.add_argument("--init_checkpoint",
                     type=str,
                     help="Initial checkpoint (usually from a pre-trained BERT model).")
 parser.add_argument("--do_lower_case",
-                    default=True,
-                    type=bool,
+                    default=False,
+                    action='store_true',
                     help="Whether to lower case the input text. Should be True for uncased models and False for cased models.")
 parser.add_argument("--max_seq_length",
                     default=128,
```
```diff
@@ -84,11 +84,11 @@ parser.add_argument("--max_seq_length",
                          "than this will be padded.")
 parser.add_argument("--do_train",
                     default=False,
-                    type=bool,
+                    action='store_true',
                     help="Whether to run training.")
 parser.add_argument("--do_eval",
                     default=False,
-                    type=bool,
+                    action='store_true',
                     help="Whether to run eval on the dev set.")
 parser.add_argument("--train_batch_size",
                     default=32,
```
```diff
@@ -117,7 +117,7 @@ parser.add_argument("--save_checkpoints_steps",
                     help="How often to save the model checkpoint.")
 parser.add_argument("--no_cuda",
                     default=False,
-                    type=bool,
+                    action='store_true',
                     help="Whether not to use CUDA when available")
 parser.add_argument("--local_rank",
                     type=int,
```
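The recurring change in the three hunks above, replacing `type=bool` with `action='store_true'`, fixes a classic argparse pitfall: argparse applies `type` to the raw command-line string, and `bool()` of any non-empty string, including `"False"`, is `True`. A short self-contained demonstration:

```python
# Why `type=bool` is replaced by `action='store_true'`: argparse passes the raw
# string to `type`, and bool() is True for any non-empty string, so a user
# writing "--do_lower_case False" would silently get True.
import argparse

broken = argparse.ArgumentParser()
broken.add_argument("--do_lower_case", default=False, type=bool)
print(broken.parse_args(["--do_lower_case", "False"]).do_lower_case)  # True (!)

fixed = argparse.ArgumentParser()
fixed.add_argument("--do_lower_case", default=False, action='store_true')
print(fixed.parse_args([]).do_lower_case)                   # False
print(fixed.parse_args(["--do_lower_case"]).do_lower_case)  # True
```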
```diff
@@ -490,6 +490,7 @@ def main():
                          warmup=args.warmup_proportion,
                          t_total=num_train_steps)

+    global_step = 0
     if args.do_train:
         train_features = convert_examples_to_features(
             train_examples, label_list, args.max_seq_length, tokenizer)
```
```diff
@@ -511,7 +512,6 @@ def main():
         train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)

         model.train()
-        global_step = 0
         for epoch in args.num_train_epochs:
             for input_ids, input_mask, segment_ids, label_ids in train_dataloader:
                 input_ids = input_ids.to(device)
```
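Together, these two hunks move `global_step = 0` out of the training branch and above `if args.do_train:`, so the variable exists even when training is skipped. A toy reproduction of the ordering issue, with illustrative names standing in for the script's:

```python
# Toy reproduction of the ordering bug fixed here (names are illustrative).
do_train = False  # e.g. the user passed only --do_eval

global_step = 0   # after the fix: initialized unconditionally, before the branch
if do_train:
    for step in range(100):   # stand-in for the training loop
        global_step += 1

# Before the fix, global_step was only assigned inside the training branch,
# so an eval-only run raised NameError at its first later use.
print("global_step =", global_step)
```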
```diff
@@ -552,9 +552,11 @@ def main():
             input_ids = input_ids.to(device)
             input_mask = input_mask.float().to(device)
             segment_ids = segment_ids.to(device)
+            label_ids = label_ids.to(device)

             tmp_eval_loss, logits = model(input_ids, segment_ids, input_mask, label_ids)
-
+            logits = logits.detach().cpu().numpy()
+            label_ids = label_ids.to('cpu').numpy()
             tmp_eval_accuracy = accuracy(logits, label_ids)

             eval_loss += tmp_eval_loss.item()
```
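The three added lines fix the eval loop in two ways: `label_ids` is moved to the same device as the model before the forward pass, and before computing accuracy the logits are detached from the autograd graph and both tensors are copied back to CPU as NumPy arrays, since `.numpy()` fails on a CUDA tensor or on a tensor that requires grad. A small standalone illustration; the inline accuracy computation stands in for the script's `accuracy` helper:

```python
# .numpy() raises on a tensor that is part of the autograd graph (or lives on
# the GPU), so the graph is detached and the data moved to CPU first.
import numpy as np
import torch

logits = torch.randn(8, 2, requires_grad=True)  # stand-in for a model output
label_ids = torch.randint(0, 2, (8,))

# logits.numpy() here would raise:
# "Can't call numpy() on Tensor that requires grad"
logits_np = logits.detach().cpu().numpy()
label_np = label_ids.to('cpu').numpy()

accuracy = np.mean(np.argmax(logits_np, axis=1) == label_np)
print("accuracy:", accuracy)
```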