"vscode:/vscode.git/clone" did not exist on "0669c1fcd15051ec6fe2d950079886faccf2fb33"
Unverified Commit 2fd28d43 authored by Kevin Canwen Xu's avatar Kevin Canwen Xu Committed by GitHub
Browse files

Add BERT Loses Patience (Patience-based Early Exit) (#5078)

* Add BERT Loses Patience (Patience-based Early Exit)

* update model archive

* update format

* sort import

* flake8

* Add results

* full results

* align the table

* refactor to inherit

* default per gpu eval = 1

* Formatting

* Formatting

* isort

* modify readme

* Add check

* Fix format

* Fix format

* Doc strings

* ALBERT & BERT for sequence classification don't inherit from the original anymore

* Remove incorrect comments

* Remove incorrect comments

* Remove incorrect comments

* Sync up with new code

* Sync up with new code

* Add a test

* Add a test

* Add a test

* Add a test

* Add a test

* Add a test

* Finishing up!
parent f1679d7c
# Patience-based Early Exit
Patience-based Early Exit (PABEE) is a plug-and-play inference method for pretrained language models.
We have already implemented it on BERT and ALBERT. Basically, you can make your LM faster and more robust with PABEE. It can even improve the performance of ALBERT on GLUE. The only sacrifice is that the batch size can only be 1.
Learn more in the paper ["BERT Loses Patience: Fast and Robust Inference with Early Exit"](https://arxiv.org/abs/2006.04152) and the official [GitHub repo](https://github.com/JetRunner/PABEE).
![PABEE](https://github.com/JetRunner/PABEE/raw/master/bert-loses-patience.png)
## Training
You can fine-tune a pretrained language model (you can choose from BERT and ALBERT) and train the internal classifiers by:
```bash
export GLUE_DIR=/path/to/glue_data
export TASK_NAME=MRPC
python ./run_glue_with_pabee.py \
--model_type albert \
--model_name_or_path bert-base-uncased/albert-base-v2 \
--task_name $TASK_NAME \
--do_train \
--do_eval \
--do_lower_case \
--data_dir "$GLUE_DIR/$TASK_NAME" \
--max_seq_length 128 \
--per_gpu_train_batch_size 32 \
--per_gpu_eval_batch_size 32 \
--learning_rate 2e-5 \
--save_steps 50 \
--logging_steps 50 \
--num_train_epochs 5 \
--output_dir /path/to/save/ \
--evaluate_during_training
```
## Inference
You can inference with different patience settings by:
```bash
export GLUE_DIR=/path/to/glue_data
export TASK_NAME=MRPC
python ./run_glue_with_pabee.py \
--model_type albert \
--model_name_or_path /path/to/save/ \
--task_name $TASK_NAME \
--do_eval \
--do_lower_case \
--data_dir "$GLUE_DIR/$TASK_NAME" \
--max_seq_length 128 \
--per_gpu_eval_batch_size 1 \
--learning_rate 2e-5 \
--logging_steps 50 \
--num_train_epochs 15 \
--output_dir /path/to/save/ \
--eval_all_checkpoints \
--patience 3,4,5,6,7,8
```
where `patience` can be a list of patience settings, separated by a comma. It will help determine which patience works best.
When evaluating on a regression task (STS-B), you may add `--regression_threshold 0.1` to define the regression threshold.
## Results
On the GLUE dev set:
| Model | \#Param | Speed | CoLA | MNLI | MRPC | QNLI | QQP | RTE | SST\-2 | STS\-B |
|--------------|---------|--------|-------|-------|-------|-------|-------|-------|--------|--------|
| ALBERT\-base | 12M | | 58\.9 | 84\.6 | 89\.5 | 91\.7 | 89\.6 | 78\.6 | 92\.8 | 89\.5 |
| \+PABEE | 12M | 1\.57x | 61\.2 | 85\.1 | 90\.0 | 91\.8 | 89\.6 | 80\.1 | 93\.0 | 90\.1 |
| Model | \#Param | Speed\-up | MNLI | SST\-2 | STS\-B |
|---------------|---------|-----------|-------|--------|--------|
| BERT\-base | 108M | | 84\.5 | 92\.1 | 88\.9 |
| \+PABEE | 108M | 1\.62x | 83\.6 | 92\.0 | 88\.7 |
| ALBERT\-large | 18M | | | | |
| \+PABEE | 18M | 2\.42x | 86\.8 | 95\.2 | 90\.6 |
## Citation
If you find this resource useful, please consider citing the following paper:
```bibtex
@misc{zhou2020bert,
title={BERT Loses Patience: Fast and Robust Inference with Early Exit},
author={Wangchunshu Zhou and Canwen Xu and Tao Ge and Julian McAuley and Ke Xu and Furu Wei},
year={2020},
eprint={2006.04152},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
```
# coding=utf-8
# Copyright 2020 Google AI, Google Brain, the HuggingFace Inc. team and Microsoft Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch ALBERT model with Patience-based Early Exit. """
import logging
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss, MSELoss
from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_callable
from transformers.modeling_albert import (
ALBERT_INPUTS_DOCSTRING,
ALBERT_START_DOCSTRING,
AlbertModel,
AlbertPreTrainedModel,
AlbertTransformer,
)
logger = logging.getLogger(__name__)
class AlbertTransformerWithPabee(AlbertTransformer):
def adaptive_forward(self, hidden_states, current_layer, attention_mask=None, head_mask=None):
if current_layer == 0:
hidden_states = self.embedding_hidden_mapping_in(hidden_states)
else:
hidden_states = hidden_states[0]
layers_per_group = int(self.config.num_hidden_layers / self.config.num_hidden_groups)
# Index of the hidden group
group_idx = int(current_layer / (self.config.num_hidden_layers / self.config.num_hidden_groups))
layer_group_output = self.albert_layer_groups[group_idx](
hidden_states,
attention_mask,
head_mask[group_idx * layers_per_group : (group_idx + 1) * layers_per_group],
)
hidden_states = layer_group_output[0]
return (hidden_states,)
@add_start_docstrings(
"The bare ALBERT Model transformer with PABEE outputting raw hidden-states without any specific head on top.",
ALBERT_START_DOCSTRING,
)
class AlbertModelWithPabee(AlbertModel):
def __init__(self, config):
super().__init__(config)
self.encoder = AlbertTransformerWithPabee(config)
self.init_weights()
self.patience = 0
self.inference_instances_num = 0
self.inference_layers_num = 0
self.regression_threshold = 0
def set_regression_threshold(self, threshold):
self.regression_threshold = threshold
def set_patience(self, patience):
self.patience = patience
def reset_stats(self):
self.inference_instances_num = 0
self.inference_layers_num = 0
def log_stats(self):
avg_inf_layers = self.inference_layers_num / self.inference_instances_num
message = f"*** Patience = {self.patience} Avg. Inference Layers = {avg_inf_layers:.2f} Speed Up = {1 - avg_inf_layers / self.config.num_hidden_layers:.2f} ***"
print(message)
@add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
output_dropout=None,
output_layers=None,
regression=False,
):
r"""
Return:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.AlbertConfig`) and inputs:
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
pooler_output (:obj:`torch.FloatTensor`: of shape :obj:`(batch_size, hidden_size)`):
Last layer hidden-state of the first token of the sequence (classification token)
further processed by a Linear layer and a Tanh activation function. The Linear
layer weights are trained from the next sentence prediction (classification)
objective during pre-training.
This output is usually *not* a good summary
of the semantic content of the input, you're often better with averaging or pooling
the sequence of hidden-states for the whole input sequence.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
device = input_ids.device if input_ids is not None else inputs_embeds.device
if attention_mask is None:
attention_mask = torch.ones(input_shape, device=device)
if token_type_ids is None:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
embedding_output = self.embeddings(
input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
)
encoder_outputs = embedding_output
if self.training:
res = []
for i in range(self.config.num_hidden_layers):
encoder_outputs = self.encoder.adaptive_forward(
encoder_outputs, current_layer=i, attention_mask=extended_attention_mask, head_mask=head_mask,
)
pooled_output = self.pooler_activation(self.pooler(encoder_outputs[0][:, 0]))
logits = output_layers[i](output_dropout(pooled_output))
res.append(logits)
elif self.patience == 0: # Use all layers for inference
encoder_outputs = self.encoder(encoder_outputs, extended_attention_mask, head_mask=head_mask)
pooled_output = self.pooler_activation(self.pooler(encoder_outputs[0][:, 0]))
res = [output_layers[self.config.num_hidden_layers - 1](pooled_output)]
else:
patient_counter = 0
patient_result = None
calculated_layer_num = 0
for i in range(self.config.num_hidden_layers):
calculated_layer_num += 1
encoder_outputs = self.encoder.adaptive_forward(
encoder_outputs, current_layer=i, attention_mask=extended_attention_mask, head_mask=head_mask,
)
pooled_output = self.pooler_activation(self.pooler(encoder_outputs[0][:, 0]))
logits = output_layers[i](pooled_output)
if regression:
labels = logits.detach()
if patient_result is not None:
patient_labels = patient_result.detach()
if (patient_result is not None) and torch.abs(patient_result - labels) < self.regression_threshold:
patient_counter += 1
else:
patient_counter = 0
else:
labels = logits.detach().argmax(dim=1)
if patient_result is not None:
patient_labels = patient_result.detach().argmax(dim=1)
if (patient_result is not None) and torch.all(labels.eq(patient_labels)):
patient_counter += 1
else:
patient_counter = 0
patient_result = logits
if patient_counter == self.patience:
break
res = [patient_result]
self.inference_layers_num += calculated_layer_num
self.inference_instances_num += 1
return res
@add_start_docstrings(
"""Albert Model transformer with PABEE and a sequence classification/regression head on top (a linear layer on top of
the pooled output) e.g. for GLUE tasks. """,
ALBERT_START_DOCSTRING,
)
class AlbertForSequenceClassificationWithPabee(AlbertPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.albert = AlbertModelWithPabee(config)
self.dropout = nn.Dropout(config.classifier_dropout_prob)
self.classifiers = nn.ModuleList(
[nn.Linear(config.hidden_size, self.config.num_labels) for _ in range(config.num_hidden_layers)]
)
self.init_weights()
@add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
Labels for computing the sequence classification/regression loss.
Indices should be in ``[0, ..., config.num_labels - 1]``.
If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.AlbertConfig`) and inputs:
loss: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Classification (or regression if config.num_labels==1) loss.
logits ``torch.FloatTensor`` of shape ``(batch_size, config.num_labels)``
Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
Examples::
from transformers import AlbertTokenizer
from pabee import AlbertForSequenceClassificationWithPabee
import torch
tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
model = AlbertForSequenceClassificationWithPabee.from_pretrained('albert-base-v2')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
labels = torch.tensor([1]).unsqueeze(0) # Batch size 1
outputs = model(input_ids, labels=labels)
loss, logits = outputs[:2]
"""
logits = self.albert(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_dropout=self.dropout,
output_layers=self.classifiers,
regression=self.num_labels == 1,
)
outputs = (logits[-1],)
if labels is not None:
total_loss = None
total_weights = 0
for ix, logits_item in enumerate(logits):
if self.num_labels == 1:
# We are doing regression
loss_fct = MSELoss()
loss = loss_fct(logits_item.view(-1), labels.view(-1))
else:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits_item.view(-1, self.num_labels), labels.view(-1))
if total_loss is None:
total_loss = loss
else:
total_loss += loss * (ix + 1)
total_weights += ix + 1
outputs = (total_loss / total_weights,) + outputs
return outputs
# coding=utf-8
# Copyright 2020 The Google AI Language Team Authors, The HuggingFace Inc. team and Microsoft Corporation.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BERT model with Patience-based Early Exit. """
import logging
import torch
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss
from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_callable
from transformers.modeling_bert import (
BERT_INPUTS_DOCSTRING,
BERT_START_DOCSTRING,
BertEncoder,
BertModel,
BertPreTrainedModel,
)
logger = logging.getLogger(__name__)
class BertEncoderWithPabee(BertEncoder):
def adaptive_forward(self, hidden_states, current_layer, attention_mask=None, head_mask=None):
layer_outputs = self.layer[current_layer](hidden_states, attention_mask, head_mask[current_layer])
hidden_states = layer_outputs[0]
return hidden_states
@add_start_docstrings(
"The bare Bert Model transformer with PABEE outputting raw hidden-states without any specific head on top.",
BERT_START_DOCSTRING,
)
class BertModelWithPabee(BertModel):
"""
The model can behave as an encoder (with only self-attention) as well
as a decoder, in which case a layer of cross-attention is added between
the self-attention layers, following the architecture described in `Attention is all you need`_ by Ashish Vaswani,
Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
To behave as an decoder the model needs to be initialized with the
:obj:`is_decoder` argument of the configuration set to :obj:`True`; an
:obj:`encoder_hidden_states` is expected as an input to the forward pass.
.. _`Attention is all you need`:
https://arxiv.org/abs/1706.03762
"""
def __init__(self, config):
super().__init__(config)
self.encoder = BertEncoderWithPabee(config)
self.init_weights()
self.patience = 0
self.inference_instances_num = 0
self.inference_layers_num = 0
self.regression_threshold = 0
def set_regression_threshold(self, threshold):
self.regression_threshold = threshold
def set_patience(self, patience):
self.patience = patience
def reset_stats(self):
self.inference_instances_num = 0
self.inference_layers_num = 0
def log_stats(self):
avg_inf_layers = self.inference_layers_num / self.inference_instances_num
message = f"*** Patience = {self.patience} Avg. Inference Layers = {avg_inf_layers:.2f} Speed Up = {1 - avg_inf_layers / self.config.num_hidden_layers:.2f} ***"
print(message)
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
output_dropout=None,
output_layers=None,
regression=False,
):
r"""
Return:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
pooler_output (:obj:`torch.FloatTensor`: of shape :obj:`(batch_size, hidden_size)`):
Last layer hidden-state of the first token of the sequence (classification token)
further processed by a Linear layer and a Tanh activation function. The Linear
layer weights are trained from the next sentence prediction (classification)
objective during pre-training.
This output is usually *not* a good summary
of the semantic content of the input, you're often better with averaging or pooling
the sequence of hidden-states for the whole input sequence.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
device = input_ids.device if input_ids is not None else inputs_embeds.device
if attention_mask is None:
attention_mask = torch.ones(input_shape, device=device)
if token_type_ids is None:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
# If a 2D ou 3D attention mask is provided for the cross-attention
# we need to make broadcastabe to [batch_size, num_heads, seq_length, seq_length]
if self.config.is_decoder and encoder_hidden_states is not None:
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_extended_attention_mask = None
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
embedding_output = self.embeddings(
input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
)
encoder_outputs = embedding_output
if self.training:
res = []
for i in range(self.config.num_hidden_layers):
encoder_outputs = self.encoder.adaptive_forward(
encoder_outputs, current_layer=i, attention_mask=extended_attention_mask, head_mask=head_mask
)
pooled_output = self.pooler(encoder_outputs)
logits = output_layers[i](output_dropout(pooled_output))
res.append(logits)
elif self.patience == 0: # Use all layers for inference
encoder_outputs = self.encoder(
embedding_output,
attention_mask=extended_attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask,
)
pooled_output = self.pooler(encoder_outputs[0])
res = [output_layers[self.config.num_hidden_layers - 1](pooled_output)]
else:
patient_counter = 0
patient_result = None
calculated_layer_num = 0
for i in range(self.config.num_hidden_layers):
calculated_layer_num += 1
encoder_outputs = self.encoder.adaptive_forward(
encoder_outputs, current_layer=i, attention_mask=extended_attention_mask, head_mask=head_mask
)
pooled_output = self.pooler(encoder_outputs)
logits = output_layers[i](pooled_output)
if regression:
labels = logits.detach()
if patient_result is not None:
patient_labels = patient_result.detach()
if (patient_result is not None) and torch.abs(patient_result - labels) < self.regression_threshold:
patient_counter += 1
else:
patient_counter = 0
else:
labels = logits.detach().argmax(dim=1)
if patient_result is not None:
patient_labels = patient_result.detach().argmax(dim=1)
if (patient_result is not None) and torch.all(labels.eq(patient_labels)):
patient_counter += 1
else:
patient_counter = 0
patient_result = logits
if patient_counter == self.patience:
break
res = [patient_result]
self.inference_layers_num += calculated_layer_num
self.inference_instances_num += 1
return res
@add_start_docstrings(
"""Bert Model transformer with PABEE and a sequence classification/regression head on top (a linear layer on top of
the pooled output) e.g. for GLUE tasks. """,
BERT_START_DOCSTRING,
)
class BertForSequenceClassificationWithPabee(BertPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.bert = BertModelWithPabee(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifiers = nn.ModuleList(
[nn.Linear(config.hidden_size, self.config.num_labels) for _ in range(config.num_hidden_layers)]
)
self.init_weights()
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
Labels for computing the sequence classification/regression loss.
Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`label` is provided):
Classification (or regression if config.num_labels==1) loss.
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
Examples::
from transformers import BertTokenizer, BertForSequenceClassification
from pabee import BertForSequenceClassificationWithPabee
import torch
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForSequenceClassificationWithPabee.from_pretrained('bert-base-uncased')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
labels = torch.tensor([1]).unsqueeze(0) # Batch size 1
outputs = model(input_ids, labels=labels)
loss, logits = outputs[:2]
"""
logits = self.bert(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_dropout=self.dropout,
output_layers=self.classifiers,
regression=self.num_labels == 1,
)
outputs = (logits[-1],)
if labels is not None:
total_loss = None
total_weights = 0
for ix, logits_item in enumerate(logits):
if self.num_labels == 1:
# We are doing regression
loss_fct = MSELoss()
loss = loss_fct(logits_item.view(-1), labels.view(-1))
else:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits_item.view(-1, self.num_labels), labels.view(-1))
if total_loss is None:
total_loss = loss
else:
total_loss += loss * (ix + 1)
total_weights += ix + 1
outputs = (total_loss / total_weights,) + outputs
return outputs
This diff is collapsed.
import argparse
import logging
import sys
import unittest
from unittest.mock import patch
import run_glue_with_pabee
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
def get_setup_file():
parser = argparse.ArgumentParser()
parser.add_argument("-f")
args = parser.parse_args()
return args.f
class PabeeTests(unittest.TestCase):
def test_run_glue(self):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
testargs = """
run_glue_with_pabee.py
--model_type albert
--model_name_or_path albert-base-v2
--data_dir ./tests/fixtures/tests_samples/MRPC/
--task_name mrpc
--do_train
--do_eval
--output_dir ./tests/fixtures/tests_samples/temp_dir
--per_gpu_train_batch_size=2
--per_gpu_eval_batch_size=1
--learning_rate=2e-5
--max_steps=50
--warmup_steps=2
--overwrite_output_dir
--seed=42
--max_seq_length=128
""".split()
with patch.object(sys, "argv", testargs):
result = run_glue_with_pabee.main()
for value in result.values():
self.assertGreaterEqual(value, 0.75)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment