Commit 092dacfd authored by thomwolf

changing is_regression to unified API

parent e55d4c4e
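The change in one snippet: the separate `is_regression` constructor flag is dropped, and `num_labels == 1` now signals a regression head across models. A minimal sketch of the before/after call sites (the config path is illustrative, and `XLNetConfig` is assumed importable alongside the classes the conversion script already imports):

```python
from pytorch_pretrained_bert.modeling_xlnet import (XLNetConfig,
                                                    XLNetForSequenceClassification)

config = XLNetConfig.from_json_file("xlnet_config.json")  # illustrative path

# Before this commit: regression was a dedicated flag.
# model = XLNetForSequenceClassification(config, is_regression=True)

# After this commit: num_labels == 1 selects the regression head (e.g. STS-B);
# any other value selects classification with that many classes.
model = XLNetForSequenceClassification(config, num_labels=1)
```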
@@ -591,3 +591,15 @@ output_modes = {
     "rte": "classification",
     "wnli": "classification",
 }
+
+GLUE_TASKS_NUM_LABELS = {
+    "cola": 2,
+    "mnli": 3,
+    "mrpc": 2,
+    "sst-2": 2,
+    "sts-b": 1,
+    "qqp": 2,
+    "qnli": 2,
+    "rte": 2,
+    "wnli": 2,
+}
@@ -28,16 +28,16 @@ from pytorch_pretrained_bert.modeling_xlnet import (CONFIG_NAME, WEIGHTS_NAME,
                                                     XLNetForSequenceClassification,
                                                     load_tf_weights_in_xlnet)

-GLUE_TASKS = {
-    "cola": "classification",
-    "mnli": "classification",
-    "mrpc": "classification",
-    "sst-2": "classification",
-    "sts-b": "regression",
-    "qqp": "classification",
-    "qnli": "classification",
-    "rte": "classification",
-    "wnli": "classification",
-}
+GLUE_TASKS_NUM_LABELS = {
+    "cola": 2,
+    "mnli": 3,
+    "mrpc": 2,
+    "sst-2": 2,
+    "sts-b": 1,
+    "qqp": 2,
+    "qnli": 2,
+    "rte": 2,
+    "wnli": 2,
+}
@@ -46,9 +46,9 @@ def convert_xlnet_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, py
     config = XLNetConfig.from_json_file(bert_config_file)

     finetuning_task = finetuning_task.lower() if finetuning_task is not None else ""
-    if finetuning_task in GLUE_TASKS:
+    if finetuning_task in GLUE_TASKS_NUM_LABELS:
         print("Building PyTorch XLNetForSequenceClassification model from configuration: {}".format(str(config)))
-        model = XLNetForSequenceClassification(config, is_regression=bool(GLUE_TASKS[finetuning_task] == "regression"))
+        model = XLNetForSequenceClassification(config, num_labels=GLUE_TASKS_NUM_LABELS[finetuning_task])
     elif 'squad' in finetuning_task:
         model = XLNetForQuestionAnswering(config)
     else:
...
@@ -27,7 +27,7 @@ from io import open

 import torch
 from torch import nn
-from torch.nn import CrossEntropyLoss
+from torch.nn import CrossEntropyLoss, MSELoss

 from .file_utils import cached_path, WEIGHTS_NAME, CONFIG_NAME
@@ -1196,6 +1196,11 @@ class BertForSequenceClassification(BertPreTrainedModel):
         logits = self.classifier(pooled_output)

         if labels is not None:
-            loss_fct = CrossEntropyLoss()
-            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            if self.num_labels == 1:
+                # We are doing regression
+                loss_fct = MSELoss()
+                loss = loss_fct(logits.view(-1), labels.view(-1))
+            else:
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
             return loss
...
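To see the two BERT loss paths concretely, here is a standalone sketch with dummy tensors (batch size and label counts are arbitrary); it mirrors the dispatch added above, without the surrounding model:

```python
import torch
from torch.nn import CrossEntropyLoss, MSELoss

batch_size = 4

# Regression path (num_labels == 1, e.g. STS-B): float targets, MSE.
num_labels = 1
logits = torch.randn(batch_size, num_labels)
labels = torch.randn(batch_size)
loss = MSELoss()(logits.view(-1), labels.view(-1))
print("regression loss:", loss.item())

# Classification path (num_labels > 1, e.g. MNLI): integer targets, cross-entropy.
num_labels = 3
logits = torch.randn(batch_size, num_labels)
labels = torch.randint(0, num_labels, (batch_size,))
loss = CrossEntropyLoss()(logits.view(-1, num_labels), labels.view(-1))
print("classification loss:", loss.item())
```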
@@ -1175,7 +1175,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):

     def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None,
                 mems=None, perm_mask=None, target_mapping=None, inp_q=None,
-                target=None, output_all_encoded_layers=True, head_mask=None):
+                labels=None, output_all_encoded_layers=True, head_mask=None):
         """
         Args:
             inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
@@ -1212,11 +1212,11 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
         logits = self.lm_loss(output)

-        if target is not None:
+        if labels is not None:
             # Flatten the tokens
             loss_fct = CrossEntropyLoss(ignore_index=-1)
             loss = loss_fct(logits.view(-1, logits.size(-1)),
-                            target.view(-1))
+                            labels.view(-1))
             return loss, new_mems

         # if self.output_attentions:
@@ -1305,13 +1305,13 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
     Outputs: Tuple of (logits or loss, mems)
         `logits or loss`:
-            if target is None:
+            if labels is None:
                 Token logits with shape [batch_size, sequence_length]
             else:
                 CrossEntropy loss with the targets
         `new_mems`: list (num layers) of updated mem states at the entry of each layer
             each mem state is a torch.FloatTensor of size [self.config.mem_len, batch_size, self.config.d_model]
-            Note that the first two dimensions are transposed in `mems` with regards to `input_ids` and `target`
+            Note that the first two dimensions are transposed in `mems` with regards to `input_ids` and `labels`

     Example usage:
     ```python
@@ -1328,13 +1328,13 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
     ```
     """
     def __init__(self, config, summary_type="last", use_proj=True, num_labels=2,
-                 is_regression=False, output_attentions=False, keep_multihead_output=False):
+                 output_attentions=False, keep_multihead_output=False):
         super(XLNetForSequenceClassification, self).__init__(config)
         self.output_attentions = output_attentions
         self.attn_type = config.attn_type
         self.same_length = config.same_length
         self.summary_type = summary_type
-        self.is_regression = is_regression
+        self.num_labels = num_labels
         self.transformer = XLNetModel(config, output_attentions=output_attentions,
                                       keep_multihead_output=keep_multihead_output)
@@ -1342,12 +1342,12 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
         self.sequence_summary = XLNetSequenceSummary(config, summary_type=summary_type,
                                                      use_proj=use_proj, output_attentions=output_attentions,
                                                      keep_multihead_output=keep_multihead_output)
-        self.logits_proj = nn.Linear(config.d_model, num_labels if not is_regression else 1)
+        self.logits_proj = nn.Linear(config.d_model, num_labels)
         self.apply(self.init_weights)

     def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None,
                 mems=None, perm_mask=None, target_mapping=None, inp_q=None,
-                target=None, output_all_encoded_layers=True, head_mask=None):
+                labels=None, output_all_encoded_layers=True, head_mask=None):
         """
         Args:
             inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
@@ -1382,13 +1382,14 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
         output = self.sequence_summary(output)
         logits = self.logits_proj(output)

-        if target is not None:
-            if self.is_regression:
+        if labels is not None:
+            if self.num_labels == 1:
+                # We are doing regression
                 loss_fct = MSELoss()
-                loss = loss_fct(logits.view(-1), target.view(-1))
+                loss = loss_fct(logits.view(-1), labels.view(-1))
             else:
-                loss_fct = CrossEntropyLoss(ignore_index=-1)
-                loss = loss_fct(logits.view(-1, logits.size(-1)), target.view(-1))
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
             return loss, new_mems

         # if self.output_attentions:
...
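Putting the XLNet-side pieces together, here is a small self-contained mock of the unified head pattern; `ToySequenceClassifier` is a hypothetical stand-in, not the real `XLNetForSequenceClassification`, and the `d_model` of 8 is arbitrary:

```python
import torch
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss

class ToySequenceClassifier(nn.Module):
    """Hypothetical mock of the unified num_labels API (not the real class)."""
    def __init__(self, d_model, num_labels=2):
        super(ToySequenceClassifier, self).__init__()
        self.num_labels = num_labels
        # The projection is now always sized by num_labels; the old
        # `num_labels if not is_regression else 1` special case is gone.
        self.logits_proj = nn.Linear(d_model, num_labels)

    def forward(self, summary, labels=None):
        logits = self.logits_proj(summary)
        if labels is None:
            return logits
        if self.num_labels == 1:
            # We are doing regression
            return MSELoss()(logits.view(-1), labels.view(-1))
        return CrossEntropyLoss()(logits.view(-1, self.num_labels), labels.view(-1))

# Regression (STS-B style): one output, float labels.
reg = ToySequenceClassifier(d_model=8, num_labels=1)
print(reg(torch.randn(4, 8), labels=torch.randn(4)).item())

# Classification (MNLI style): three outputs, integer labels.
clf = ToySequenceClassifier(d_model=8, num_labels=3)
print(clf(torch.randn(4, 8), labels=torch.randint(0, 3, (4,))).item())
```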