Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
dee09a40
You need to sign in or sign up before continuing.
Commit
dee09a40
authored
Nov 02, 2018
by
thomwolf
Browse files
various fixes
parent
2c731fd1
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
11 additions
and
8 deletions
+11
-8
modeling_pytorch.py
modeling_pytorch.py
+2
-1
run_classifier_pytorch.py
run_classifier_pytorch.py
+9
-7
No files found.
modeling_pytorch.py
View file @
dee09a40
...
@@ -412,7 +412,8 @@ class BertForSequenceClassification(nn.Module):
...
@@ -412,7 +412,8 @@ class BertForSequenceClassification(nn.Module):
model = modeling.BertModel(config, num_labels)
model = modeling.BertModel(config, num_labels)
logits = model(input_ids, token_type_ids, input_mask)
logits = model(input_ids, token_type_ids, input_mask)
```
```
"""
def
__init__
(
self
,
config
,
num_labels
):
"""
def
__init__
(
self
,
config
,
num_labels
):
super
(
BertForSequenceClassification
,
self
).
__init__
()
super
(
BertForSequenceClassification
,
self
).
__init__
()
self
.
bert
=
BertModel
(
config
)
self
.
bert
=
BertModel
(
config
)
self
.
dropout
=
nn
.
Dropout
(
config
.
hidden_dropout_prob
)
self
.
dropout
=
nn
.
Dropout
(
config
.
hidden_dropout_prob
)
...
...
run_classifier_pytorch.py
View file @
dee09a40
...
@@ -73,8 +73,8 @@ parser.add_argument("--init_checkpoint",
...
@@ -73,8 +73,8 @@ parser.add_argument("--init_checkpoint",
type
=
str
,
type
=
str
,
help
=
"Initial checkpoint (usually from a pre-trained BERT model)."
)
help
=
"Initial checkpoint (usually from a pre-trained BERT model)."
)
parser
.
add_argument
(
"--do_lower_case"
,
parser
.
add_argument
(
"--do_lower_case"
,
default
=
Tru
e
,
default
=
Fals
e
,
type
=
bool
,
action
=
'store_true'
,
help
=
"Whether to lower case the input text. Should be True for uncased models and False for cased models."
)
help
=
"Whether to lower case the input text. Should be True for uncased models and False for cased models."
)
parser
.
add_argument
(
"--max_seq_length"
,
parser
.
add_argument
(
"--max_seq_length"
,
default
=
128
,
default
=
128
,
...
@@ -84,11 +84,11 @@ parser.add_argument("--max_seq_length",
...
@@ -84,11 +84,11 @@ parser.add_argument("--max_seq_length",
"than this will be padded."
)
"than this will be padded."
)
parser
.
add_argument
(
"--do_train"
,
parser
.
add_argument
(
"--do_train"
,
default
=
False
,
default
=
False
,
type
=
bool
,
action
=
'store_true'
,
help
=
"Whether to run training."
)
help
=
"Whether to run training."
)
parser
.
add_argument
(
"--do_eval"
,
parser
.
add_argument
(
"--do_eval"
,
default
=
False
,
default
=
False
,
type
=
bool
,
action
=
'store_true'
,
help
=
"Whether to run eval on the dev set."
)
help
=
"Whether to run eval on the dev set."
)
parser
.
add_argument
(
"--train_batch_size"
,
parser
.
add_argument
(
"--train_batch_size"
,
default
=
32
,
default
=
32
,
...
@@ -117,7 +117,7 @@ parser.add_argument("--save_checkpoints_steps",
...
@@ -117,7 +117,7 @@ parser.add_argument("--save_checkpoints_steps",
help
=
"How often to save the model checkpoint."
)
help
=
"How often to save the model checkpoint."
)
parser
.
add_argument
(
"--no_cuda"
,
parser
.
add_argument
(
"--no_cuda"
,
default
=
False
,
default
=
False
,
type
=
bool
,
action
=
'store_true'
,
help
=
"Whether not to use CUDA when available"
)
help
=
"Whether not to use CUDA when available"
)
parser
.
add_argument
(
"--local_rank"
,
parser
.
add_argument
(
"--local_rank"
,
type
=
int
,
type
=
int
,
...
@@ -490,6 +490,7 @@ def main():
...
@@ -490,6 +490,7 @@ def main():
warmup
=
args
.
warmup_proportion
,
warmup
=
args
.
warmup_proportion
,
t_total
=
num_train_steps
)
t_total
=
num_train_steps
)
global_step
=
0
if
args
.
do_train
:
if
args
.
do_train
:
train_features
=
convert_examples_to_features
(
train_features
=
convert_examples_to_features
(
train_examples
,
label_list
,
args
.
max_seq_length
,
tokenizer
)
train_examples
,
label_list
,
args
.
max_seq_length
,
tokenizer
)
...
@@ -511,7 +512,6 @@ def main():
...
@@ -511,7 +512,6 @@ def main():
train_dataloader
=
DataLoader
(
train_data
,
sampler
=
train_sampler
,
batch_size
=
args
.
train_batch_size
)
train_dataloader
=
DataLoader
(
train_data
,
sampler
=
train_sampler
,
batch_size
=
args
.
train_batch_size
)
model
.
train
()
model
.
train
()
global_step
=
0
for
epoch
in
args
.
num_train_epochs
:
for
epoch
in
args
.
num_train_epochs
:
for
input_ids
,
input_mask
,
segment_ids
,
label_ids
in
train_dataloader
:
for
input_ids
,
input_mask
,
segment_ids
,
label_ids
in
train_dataloader
:
input_ids
=
input_ids
.
to
(
device
)
input_ids
=
input_ids
.
to
(
device
)
...
@@ -552,9 +552,11 @@ def main():
...
@@ -552,9 +552,11 @@ def main():
input_ids
=
input_ids
.
to
(
device
)
input_ids
=
input_ids
.
to
(
device
)
input_mask
=
input_mask
.
float
().
to
(
device
)
input_mask
=
input_mask
.
float
().
to
(
device
)
segment_ids
=
segment_ids
.
to
(
device
)
segment_ids
=
segment_ids
.
to
(
device
)
label_ids
=
label_ids
.
to
(
device
)
tmp_eval_loss
,
logits
=
model
(
input_ids
,
segment_ids
,
input_mask
,
label_ids
)
tmp_eval_loss
,
logits
=
model
(
input_ids
,
segment_ids
,
input_mask
,
label_ids
)
logits
=
logits
.
detach
().
cpu
().
numpy
()
label_ids
=
label_ids
.
to
(
'cpu'
).
numpy
()
tmp_eval_accuracy
=
accuracy
(
logits
,
label_ids
)
tmp_eval_accuracy
=
accuracy
(
logits
,
label_ids
)
eval_loss
+=
tmp_eval_loss
.
item
()
eval_loss
+=
tmp_eval_loss
.
item
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment