chenpangpang / transformers

Commit 2c731fd1, authored Nov 02, 2018 by thomwolf

small tweaks

Parent: 9343a231
Showing 2 changed files with 33 additions and 26 deletions:

    modeling_pytorch.py        +22 -10
    run_classifier_pytorch.py  +11 -16
modeling_pytorch.py

...
@@ -349,7 +349,6 @@ class BertModel(nn.Module):
     """BERT model ("Bidirectional Embedding Representations from a Transformer").

     Example usage:
     ```python
     # Already been converted into WordPiece token ids
     input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
...
@@ -359,16 +358,10 @@ class BertModel(nn.Module):
     config = modeling.BertConfig(vocab_size=32000, hidden_size=512,
         num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024)
-    model = modeling.BertModel(config=config, is_training=True,
-        input_ids=input_ids, input_mask=input_mask, token_type_ids=token_type_ids)
-    label_embeddings = tf.get_variable(...)
-    pooled_output = model.get_pooled_output()
-    logits = tf.matmul(pooled_output, label_embeddings)
...
+    model = modeling.BertModel(config=config)
+    all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
     ```
     """

     def __init__(self, config: BertConfig):
         """Constructor for BertModel.
...
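The change above is the point of this commit: the example docstring drops the TensorFlow idiom (inputs passed to the constructor, outputs fetched via getters and `tf.matmul`) in favor of the PyTorch convention, where the module is built from a config alone and the tensors flow through the forward call. A minimal end-to-end sketch of the new usage, assuming the file is importable as `modeling_pytorch` and the model runs with freshly initialized weights:

```python
import torch
import modeling_pytorch as modeling  # assumed import path for the file in this diff

# Inputs already converted into WordPiece token ids, as in the docstring.
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]])

config = modeling.BertConfig(vocab_size=32000, hidden_size=512,
    num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024)

# No is_training flag and no inputs at construction time anymore:
model = modeling.BertModel(config=config)
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
# all_encoder_layers: one hidden-state tensor per encoder layer;
# pooled_output: a [batch_size, hidden_size] summary of the first ([CLS]) token.
```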
@@ -400,7 +393,26 @@ class BertModel(nn.Module):
         return all_encoder_layers, pooled_output

 class BertForSequenceClassification(nn.Module):
-    def __init__(self, config, num_labels):
+    """BERT model for classification.
+    This module is composed of the BERT model with a linear layer on top of
+    the pooled output.
+
+    Example usage:
+    ```python
+    # Already been converted into WordPiece token ids
+    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
+    input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
+    token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]])
+
+    config = modeling.BertConfig(vocab_size=32000, hidden_size=512,
+        num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024)
+
+    num_labels = 2
+    model = modeling.BertModel(config, num_labels)
+    logits = model(input_ids, token_type_ids, input_mask)
+    ```
+    """
+    def __init__(self, config, num_labels):
         super(BertForSequenceClassification, self).__init__()
         self.bert = BertModel(config)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
...
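Note that the new docstring's example instantiates `modeling.BertModel(config, num_labels)` where `BertForSequenceClassification` is presumably meant. The hunk also cuts off right after `self.dropout`, so the classifier head and the forward pass are not shown. The sketch below fills them in under two assumptions: the head is a single `nn.Linear` over the pooled output, and the module returns the loss first when labels are supplied, which is what the `loss, _ = model(...)` call in run_classifier_pytorch.py below suggests:

```python
import torch.nn as nn
from torch.nn import CrossEntropyLoss

class BertForSequenceClassification(nn.Module):
    """Sketch: BERT encoder plus dropout and a linear classifier (assumed head)."""

    def __init__(self, config, num_labels):
        super(BertForSequenceClassification, self).__init__()
        self.bert = BertModel(config)  # BertModel as defined earlier in this file
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Assumption: pooled [CLS] vector -> label logits.
        self.classifier = nn.Linear(config.hidden_size, num_labels)

    def forward(self, input_ids, token_type_ids, attention_mask, labels=None):
        # BertModel's forward returns (all_encoder_layers, pooled_output).
        _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask)
        logits = self.classifier(self.dropout(pooled_output))
        if labels is not None:
            # Assumed convention: return the loss first, matching
            # `loss, _ = model(input_ids, segment_ids, input_mask, label_ids)`.
            loss_fct = CrossEntropyLoss()
            return loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1)), logits
        return logits
```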
run_classifier_pytorch.py

...
@@ -115,16 +115,10 @@ parser.add_argument("--save_checkpoints_steps",
                     default=1000,
                     type=int,
                     help="How often to save the model checkpoint.")
 parser.add_argument("--iterations_per_loop",
                     default=1000,
                     type=int,
                     help="How many steps to make in each estimator call.")
 parser.add_argument("--no_cuda",
                     default=False,
                     type=bool,
                     help="Whether not to use CUDA when available")
 parser.add_argument("--local_rank",
                     type=int,
                     default=-1,
...
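The hunk header shows this block shrinking from 16 lines to 10, so some of the TF-Estimator-era options above, presumably including `--iterations_per_loop` (which has no meaning outside `tf.estimator`), were dropped here. One surviving definition is worth flagging: `--no_cuda` uses `type=bool`, and argparse applies `bool()` to the raw string, so any non-empty value, including the literal `False`, parses as `True`. A `store_true` flag is the idiomatic fix; a sketch:

```python
import argparse

parser = argparse.ArgumentParser()
# type=bool is a trap: bool("False") == True, so "--no_cuda False" would still
# switch CUDA off. A store_true flag defaults to False and flips to True only
# when the bare flag is present on the command line.
parser.add_argument("--no_cuda",
                    action='store_true',
                    help="Whether not to use CUDA when available")

assert parser.parse_args([]).no_cuda is False
assert parser.parse_args(["--no_cuda"]).no_cuda is True
```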
@@ -518,16 +512,17 @@ def main():
     model.train()
     global_step = 0
-    for input_ids, input_mask, segment_ids, label_ids in train_dataloader:
-        input_ids = input_ids.to(device)
-        input_mask = input_mask.float().to(device)
-        segment_ids = segment_ids.to(device)
-        label_ids = label_ids.to(device)
-        loss, _ = model(input_ids, segment_ids, input_mask, label_ids)
-        loss.backward()
-        optimizer.step()
-        global_step += 1
+    for epoch in args.num_train_epochs:
+        for input_ids, input_mask, segment_ids, label_ids in train_dataloader:
+            input_ids = input_ids.to(device)
+            input_mask = input_mask.float().to(device)
+            segment_ids = segment_ids.to(device)
+            label_ids = label_ids.to(device)
+            loss, _ = model(input_ids, segment_ids, input_mask, label_ids)
+            loss.backward()
+            optimizer.step()
+            global_step += 1

     if args.do_eval:
         eval_examples = processor.get_dev_examples(args.data_dir)
...
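Two things stand out in the new loop. `args.num_train_epochs` is a number, so `for epoch in args.num_train_epochs:` would raise a TypeError at runtime; `range()` over the count is presumably intended. The loop also never resets gradients, so each `loss.backward()` accumulates into the previous step's gradients. A hedged sketch of what the loop presumably means to do:

```python
# Sketch of the presumed intent of the new training loop. Names are taken from
# the diff above; optimizer.zero_grad() is an addition, not in the commit.
model.train()
global_step = 0
for epoch in range(int(args.num_train_epochs)):  # range() makes the count iterable
    for input_ids, input_mask, segment_ids, label_ids in train_dataloader:
        input_ids = input_ids.to(device)
        input_mask = input_mask.float().to(device)
        segment_ids = segment_ids.to(device)
        label_ids = label_ids.to(device)
        optimizer.zero_grad()  # reset gradients so steps don't accumulate
        loss, _ = model(input_ids, segment_ids, input_mask, label_ids)
        loss.backward()
        optimizer.step()
        global_step += 1
```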