chenpangpang / transformers · Commits · 092dacfd

Commit 092dacfd authored Jun 26, 2019 by thomwolf
parent e55d4c4e

    changing is_regression to unified API

Showing 4 changed files with 49 additions and 31 deletions (+49 −31)
examples/utils_glue.py                                             +12   −0
pytorch_pretrained_bert/convert_xlnet_checkpoint_to_pytorch.py     +12  −12
pytorch_pretrained_bert/modeling.py                                 +8   −3
pytorch_pretrained_bert/modeling_xlnet.py                          +17  −16
examples/utils_glue.py

```diff
@@ -591,3 +591,15 @@ output_modes = {
     "rte": "classification",
     "wnli": "classification",
 }
+
+GLUE_TASKS_NUM_LABELS = {
+    "cola": 2,
+    "mnli": 3,
+    "mrpc": 2,
+    "sst-2": 2,
+    "sts-b": 1,
+    "qqp": 2,
+    "qnli": 2,
+    "rte": 2,
+    "wnli": 2,
+}
```
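For illustration only, a minimal sketch of how a training script might consume the new table (the `task` variable and the assertion are hypothetical, not part of this commit; the import path assumes a script living next to examples/utils_glue.py):

```python
from utils_glue import GLUE_TASKS_NUM_LABELS

task = "sts-b"  # hypothetical task name, e.g. from a command-line flag
num_labels = GLUE_TASKS_NUM_LABELS[task]

# sts-b is GLUE's regression task, so it maps to a single output;
# every other task is classification with 2 labels (3 for mnli).
assert num_labels == 1
```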
pytorch_pretrained_bert/convert_xlnet_checkpoint_to_pytorch.py

```diff
@@ -28,16 +28,16 @@ from pytorch_pretrained_bert.modeling_xlnet import (CONFIG_NAME, WEIGHTS_NAME,
                                                     XLNetForSequenceClassification,
                                                     load_tf_weights_in_xlnet)
 
-GLUE_TASKS = {
-    "cola": "classification",
-    "mnli": "classification",
-    "mrpc": "classification",
-    "sst-2": "classification",
-    "sts-b": "regression",
-    "qqp": "classification",
-    "qnli": "classification",
-    "rte": "classification",
-    "wnli": "classification",
+GLUE_TASKS_NUM_LABELS = {
+    "cola": 2,
+    "mnli": 3,
+    "mrpc": 2,
+    "sst-2": 2,
+    "sts-b": 1,
+    "qqp": 2,
+    "qnli": 2,
+    "rte": 2,
+    "wnli": 2,
 }
@@ -46,9 +46,9 @@ def convert_xlnet_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, py
     config = XLNetConfig.from_json_file(bert_config_file)
 
     finetuning_task = finetuning_task.lower() if finetuning_task is not None else ""
-    if finetuning_task in GLUE_TASKS:
+    if finetuning_task in GLUE_TASKS_NUM_LABELS:
         print("Building PyTorch XLNetForSequenceClassification model from configuration: {}".format(str(config)))
-        model = XLNetForSequenceClassification(config, is_regression=bool(GLUE_TASKS[finetuning_task] == "regression"))
+        model = XLNetForSequenceClassification(config, num_labels=GLUE_TASKS_NUM_LABELS[finetuning_task])
     elif 'squad' in finetuning_task:
         model = XLNetForQuestionAnswering(config)
     else:
```
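The net effect on the construction call, as a before/after sketch (a `config` loaded via `XLNetConfig.from_json_file` is assumed; the old line comes from the removed code above):

```python
# Before this commit: regression was a dedicated constructor flag.
model = XLNetForSequenceClassification(config, is_regression=True)  # e.g. sts-b

# After: regression falls out of the unified num_labels API,
# mirroring BertForSequenceClassification (num_labels == 1 => regression).
model = XLNetForSequenceClassification(config, num_labels=1)        # e.g. sts-b
```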
pytorch_pretrained_bert/modeling.py

```diff
@@ -27,7 +27,7 @@ from io import open
 import torch
 from torch import nn
-from torch.nn import CrossEntropyLoss
+from torch.nn import CrossEntropyLoss, MSELoss
 
 from .file_utils import cached_path, WEIGHTS_NAME, CONFIG_NAME
@@ -1196,6 +1196,11 @@ class BertForSequenceClassification(BertPreTrainedModel):
         logits = self.classifier(pooled_output)
 
         if labels is not None:
-            loss_fct = CrossEntropyLoss()
-            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            if self.num_labels == 1:
+                #  We are doing regression
+                loss_fct = MSELoss()
+                loss = loss_fct(logits.view(-1), labels.view(-1))
+            else:
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
             return loss
```
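A self-contained sketch of the loss convention this hunk adopts, using bare tensors instead of the model (shapes and values are illustrative only):

```python
import torch
from torch.nn import CrossEntropyLoss, MSELoss

batch_size, num_labels = 4, 1                 # num_labels == 1 => regression
logits = torch.randn(batch_size, num_labels)
labels = torch.rand(batch_size)               # float targets, e.g. sts-b scores

if num_labels == 1:
    # Regression: flatten logits to [batch_size] and use mean-squared error.
    loss = MSELoss()(logits.view(-1), labels.view(-1))
else:
    # Classification: cross-entropy over [batch_size, num_labels] logits
    # against int64 class indices.
    loss = CrossEntropyLoss()(logits.view(-1, num_labels), labels.long().view(-1))
```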
pytorch_pretrained_bert/modeling_xlnet.py

```diff
@@ -1175,7 +1175,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
     def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None,
                 mems=None, perm_mask=None, target_mapping=None, inp_q=None,
-                target=None, output_all_encoded_layers=True, head_mask=None):
+                labels=None, output_all_encoded_layers=True, head_mask=None):
         """
         Args:
             inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
@@ -1212,11 +1212,11 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
         logits = self.lm_loss(output)
 
-        if target is not None:
+        if labels is not None:
             # Flatten the tokens
             loss_fct = CrossEntropyLoss(ignore_index=-1)
             loss = loss_fct(logits.view(-1, logits.size(-1)),
-                            target.view(-1))
+                            labels.view(-1))
             return loss, new_mems
 
         # if self.output_attentions:
```
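For callers of XLNetLMHeadModel the rename is keyword-only; a hypothetical call site (model construction elided) changes like this:

```python
# Before: loss, new_mems = model(inp_k, target=token_ids)
loss, new_mems = model(inp_k, labels=token_ids)  # after this commit
```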
````diff
@@ -1305,13 +1305,13 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
     Outputs: Tuple of (logits or loss, mems)
         `logits or loss`:
-            if target is None:
+            if labels is None:
                 Token logits with shape [batch_size, sequence_length]
             else:
                 CrossEntropy loss with the targets
         `new_mems`: list (num layers) of updated mem states at the entry of each layer
             each mem state is a torch.FloatTensor of size [self.config.mem_len, batch_size, self.config.d_model]
-            Note that the first two dimensions are transposed in `mems` with regards to `input_ids` and `target`
+            Note that the first two dimensions are transposed in `mems` with regards to `input_ids` and `labels`
 
     Example usage:
     ```python
@@ -1328,13 +1328,13 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
     ```
     """
-    def __init__(self, config, summary_type="last", use_proj=True, num_labels=2, is_regression=False,
+    def __init__(self, config, summary_type="last", use_proj=True, num_labels=2,
                  output_attentions=False, keep_multihead_output=False):
         super(XLNetForSequenceClassification, self).__init__(config)
         self.output_attentions = output_attentions
         self.attn_type = config.attn_type
         self.same_length = config.same_length
         self.summary_type = summary_type
-        self.is_regression = is_regression
+        self.num_labels = num_labels
 
         self.transformer = XLNetModel(config, output_attentions=output_attentions,
                                       keep_multihead_output=keep_multihead_output)
````
```diff
@@ -1342,12 +1342,12 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
         self.sequence_summary = XLNetSequenceSummary(config, summary_type=summary_type,
                                                      use_proj=use_proj, output_attentions=output_attentions,
                                                      keep_multihead_output=keep_multihead_output)
-        self.logits_proj = nn.Linear(config.d_model, num_labels if not is_regression else 1)
+        self.logits_proj = nn.Linear(config.d_model, num_labels)
         self.apply(self.init_weights)
 
     def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None,
                 mems=None, perm_mask=None, target_mapping=None, inp_q=None,
-                target=None, output_all_encoded_layers=True, head_mask=None):
+                labels=None, output_all_encoded_layers=True, head_mask=None):
         """
         Args:
             inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
@@ -1382,13 +1382,14 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
         output = self.sequence_summary(output)
         logits = self.logits_proj(output)
 
-        if target is not None:
-            if self.is_regression:
+        if labels is not None:
+            if self.num_labels == 1:
                 #  We are doing regression
                 loss_fct = MSELoss()
-                loss = loss_fct(logits.view(-1), target.view(-1))
+                loss = loss_fct(logits.view(-1), labels.view(-1))
             else:
-                loss_fct = CrossEntropyLoss(ignore_index=-1)
-                loss = loss_fct(logits.view(-1, logits.size(-1)),
-                                target.view(-1))
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
             return loss, new_mems
 
         # if self.output_attentions:
```
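Putting the pieces together, a hedged usage sketch of the unified API on a regression task. The config path and tensor shapes are assumptions; `XLNetConfig.from_json_file` is used the same way in the conversion script above, and the `(loss, new_mems)` return shape comes from the final hunk.

```python
import torch
from pytorch_pretrained_bert.modeling_xlnet import XLNetConfig, XLNetForSequenceClassification

config = XLNetConfig.from_json_file("xlnet_config.json")      # hypothetical path
model = XLNetForSequenceClassification(config, num_labels=1)  # regression head

inp_k = torch.randint(0, 32000, (2, 16))  # [batch, seq_len] token IDs (illustrative vocab)
labels = torch.rand(2)                    # float targets, e.g. sts-b scores

# With labels provided and num_labels == 1, forward computes the MSE loss branch
# and returns (loss, new_mems).
loss, new_mems = model(inp_k, labels=labels)
```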