"git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "3b1d5cd4d48ce733ce4f4d32c950da9a998d42b0"
Unverified commit cfbb9829, authored by Ji Xin, committed by GitHub

Add DeeBERT (entropy-based early exiting for *BERT) (#5477)

* Add deebert code

* Add readme of deebert

* Add test for deebert

Update test for DeeBERT

* Update DeeBERT (README, class names, function refactoring); remove requirements.txt

* Format update

* Update test

* Update readme and model init methods
parent b4b33fdf
# DeeBERT: Early Exiting for *BERT
This is the code base for the paper [DeeBERT: Dynamic Early Exiting for Accelerating BERT Inference](https://www.aclweb.org/anthology/2020.acl-main.204/), modified from its [original code base](https://github.com/castorini/deebert).
The original code base also provides instructions for downloading sample models that we have trained in advance.
## Usage
The folder contains three scripts, each of which can be run directly.
Before running a script, modify the following settings in it:
* `PATH_TO_DATA`: path to the GLUE dataset.
* `--output_dir`: path for saving fine-tuned models. Default: `./saved_models`.
* `--plot_data_dir`: path for saving evaluation results. Default: `./results`. Results are printed to stdout and also saved to `npy` files in this directory to facilitate plotting figures and further analyses.
* `MODEL_TYPE`: `bert` or `roberta`
* `MODEL_SIZE`: `base` or `large`
* `DATASET`: `SST-2`, `MRPC`, `RTE`, `QNLI`, `QQP`, or `MNLI`
#### train_deebert.sh
This is for fine-tuning DeeBERT models.
#### eval_deebert.sh
This is for evaluating each exit layer for fine-tuned DeeBERT models.
#### entropy_eval.sh
This is for evaluating fine-tuned DeeBERT models, given a number of different early exit entropy thresholds.
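All three scripts drive `run_glue_deebert.py`; the early-exit logic itself lives in the `modeling_highway_*` modules added alongside it. As a rough sketch (not a supported entry point; the import path, checkpoint path, and sentence below are placeholders), a fine-tuned model can also be used directly by setting an entropy threshold on its encoder:
```python
import torch
from transformers import BertTokenizer
from modeling_highway_bert import DeeBertForSequenceClassification  # adjust to the package layout

checkpoint = "./saved_models/bert-base/MRPC/two_stage"  # written by train_deebert.sh
tokenizer = BertTokenizer.from_pretrained(checkpoint)   # assumes the tokenizer was saved there too
model = DeeBertForSequenceClassification.from_pretrained(checkpoint)
model.eval()
model.bert.encoder.set_early_exit_entropy(0.3)  # a layer exits early once its entropy drops below 0.3

inputs = tokenizer.encode_plus("This is a placeholder sentence.", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs)[0]
```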
## Citation
Please cite our paper if you find the resource useful:
```
@inproceedings{xin-etal-2020-deebert,
    title = "{D}ee{BERT}: Dynamic Early Exiting for Accelerating {BERT} Inference",
    author = "Xin, Ji and
      Tang, Raphael and
      Lee, Jaejun and
      Yu, Yaoliang and
      Lin, Jimmy",
    booktitle = "Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics",
    month = jul,
    year = "2020",
    address = "Online",
    publisher = "Association for Computational Linguistics",
    url = "https://www.aclweb.org/anthology/2020.acl-main.204",
    pages = "2246--2251",
}
```
#!/bin/bash
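# entropy_eval.sh (see README): evaluate a fine-tuned DeeBERT model over a range of
# early-exit entropy thresholds.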
export CUDA_VISIBLE_DEVICES=0
PATH_TO_DATA=/h/xinji/projects/GLUE
MODEL_TYPE=bert # bert or roberta
MODEL_SIZE=base # base or large
DATASET=MRPC # SST-2, MRPC, RTE, QNLI, QQP, or MNLI
MODEL_NAME=${MODEL_TYPE}-${MODEL_SIZE}
if [ $MODEL_TYPE = 'bert' ]
then
MODEL_NAME=${MODEL_NAME}-uncased
fi
ENTROPIES="0 0.1 0.2 0.3 0.4 0.5 0.6 0.7"
for ENTROPY in $ENTROPIES; do
python -u run_glue_deebert.py \
--model_type $MODEL_TYPE \
--model_name_or_path ./saved_models/${MODEL_TYPE}-${MODEL_SIZE}/$DATASET/two_stage \
--task_name $DATASET \
--do_eval \
--do_lower_case \
--data_dir $PATH_TO_DATA/$DATASET \
--output_dir ./saved_models/${MODEL_TYPE}-${MODEL_SIZE}/$DATASET/two_stage \
--plot_data_dir ./results/ \
--max_seq_length 128 \
--early_exit_entropy $ENTROPY \
--eval_highway \
--overwrite_cache \
--per_gpu_eval_batch_size=1
done
#!/bin/bash
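# eval_deebert.sh (see README): evaluate every exit layer of a fine-tuned DeeBERT model
# (--eval_each_highway), without entropy-based early exiting.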
export CUDA_VISIBLE_DEVICES=0
PATH_TO_DATA=/h/xinji/projects/GLUE
MODEL_TYPE=bert # bert or roberta
MODEL_SIZE=base # base or large
DATASET=MRPC # SST-2, MRPC, RTE, QNLI, QQP, or MNLI
MODEL_NAME=${MODEL_TYPE}-${MODEL_SIZE}
if [ $MODEL_TYPE = 'bert' ]
then
MODEL_NAME=${MODEL_NAME}-uncased
fi
python -u run_glue_deebert.py \
--model_type $MODEL_TYPE \
--model_name_or_path ./saved_models/${MODEL_TYPE}-${MODEL_SIZE}/$DATASET/two_stage \
--task_name $DATASET \
--do_eval \
--do_lower_case \
--data_dir $PATH_TO_DATA/$DATASET \
--output_dir ./saved_models/${MODEL_TYPE}-${MODEL_SIZE}/$DATASET/two_stage \
--plot_data_dir ./results/ \
--max_seq_length 128 \
--eval_each_highway \
--eval_highway \
--overwrite_cache \
--per_gpu_eval_batch_size=1
import torch
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss
from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_callable
from transformers.modeling_bert import (
BERT_INPUTS_DOCSTRING,
BERT_START_DOCSTRING,
BertEmbeddings,
BertLayer,
BertPooler,
BertPreTrainedModel,
)
def entropy(x):
""" Calculate entropy of a pre-softmax logit Tensor
"""
exp_x = torch.exp(x)
A = torch.sum(exp_x, dim=1) # sum of exp(x_i)
B = torch.sum(x * exp_x, dim=1) # sum of x_i * exp(x_i)
return torch.log(A) - B / A
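# Hedged illustration (not used anywhere else in this file): the closed form above is
# algebraically equal to the usual softmax entropy H(p) = -sum_i p_i * log(p_i) with
# p = softmax(x), since log(p_i) = x_i - log(A) and sum_i p_i * x_i = B / A.
def _check_entropy_against_softmax(x):
    # assumes well-behaved logits (no softmax probabilities numerically equal to zero)
    p = torch.softmax(x, dim=1)
    reference = -(p * torch.log(p)).sum(dim=1)
    return torch.allclose(entropy(x), reference, atol=1e-5)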
class DeeBertEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states
self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
self.highway = nn.ModuleList([BertHighway(config) for _ in range(config.num_hidden_layers)])
self.early_exit_entropy = [-1 for _ in range(config.num_hidden_layers)]
def set_early_exit_entropy(self, x):
if (type(x) is float) or (type(x) is int):
for i in range(len(self.early_exit_entropy)):
self.early_exit_entropy[i] = x
else:
self.early_exit_entropy = x
def init_highway_pooler(self, pooler):
loaded_model = pooler.state_dict()
for highway in self.highway:
for name, param in highway.pooler.state_dict().items():
param.copy_(loaded_model[name])
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
):
all_hidden_states = ()
all_attentions = ()
all_highway_exits = ()
for i, layer_module in enumerate(self.layer):
if self.output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_outputs = layer_module(
hidden_states, attention_mask, head_mask[i], encoder_hidden_states, encoder_attention_mask
)
hidden_states = layer_outputs[0]
if self.output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
current_outputs = (hidden_states,)
if self.output_hidden_states:
current_outputs = current_outputs + (all_hidden_states,)
if self.output_attentions:
current_outputs = current_outputs + (all_attentions,)
highway_exit = self.highway[i](current_outputs)
# logits, pooled_output
if not self.training:
highway_logits = highway_exit[0]
highway_entropy = entropy(highway_logits)
highway_exit = highway_exit + (highway_entropy,) # logits, hidden_states(?), entropy
all_highway_exits = all_highway_exits + (highway_exit,)
if highway_entropy < self.early_exit_entropy[i]:
new_output = (highway_logits,) + current_outputs[1:] + (all_highway_exits,)
raise HighwayException(new_output, i + 1)
else:
all_highway_exits = all_highway_exits + (highway_exit,)
# Add last layer
if self.output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = (hidden_states,)
if self.output_hidden_states:
outputs = outputs + (all_hidden_states,)
if self.output_attentions:
outputs = outputs + (all_attentions,)
outputs = outputs + (all_highway_exits,)
return outputs # last-layer hidden state, (all hidden states), (all attentions), all highway exits
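# Hedged illustration (not used by the example scripts): set_early_exit_entropy accepts either
# a single number, which is broadcast to every layer, or a per-layer list. The default of -1
# disables early exiting, because the softmax entropy computed above is never negative.
def _example_set_exit_thresholds(config):
    encoder = DeeBertEncoder(config)
    encoder.set_early_exit_entropy(0.3)  # one threshold shared by all layers
    half = config.num_hidden_layers // 2
    per_layer = [0.5] * half + [0.1] * (config.num_hidden_layers - half)
    encoder.set_early_exit_entropy(per_layer)  # one threshold per layer
    return encoder.early_exit_entropy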
@add_start_docstrings(
"The Bert Model transformer with early exiting (DeeBERT). ", BERT_START_DOCSTRING,
)
class DeeBertModel(BertPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.config = config
self.embeddings = BertEmbeddings(config)
self.encoder = DeeBertEncoder(config)
self.pooler = BertPooler(config)
self.init_weights()
def init_highway_pooler(self):
self.encoder.init_highway_pooler(self.pooler)
def get_input_embeddings(self):
return self.embeddings.word_embeddings
def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value
def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
See base class PreTrainedModel
"""
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
):
r"""
Return:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
        pooler_output (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, hidden_size)`):
Last layer hidden-state of the first token of the sequence (classification token)
further processed by a Linear layer and a Tanh activation function. The Linear
layer weights are trained from the next sentence prediction (classification)
objective during pre-training.
            This output is usually *not* a good summary of the semantic content of the input;
            you're often better off averaging or pooling the sequence of hidden states for the
            whole input sequence.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
        highway_exits (:obj:`tuple(tuple(torch.Tensor))`):
            Tuple with the results of each early exit (total length: number of layers).
            Each element is itself a tuple of length 2: the first entry is the logits and the second entry is the hidden states.
"""
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
device = input_ids.device if input_ids is not None else inputs_embeds.device
if attention_mask is None:
attention_mask = torch.ones(input_shape, device=device)
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(input_shape, device=device)
if token_type_ids is None:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
        # If a 2D or 3D attention mask is provided for the cross-attention,
        # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
if encoder_attention_mask.dim() == 3:
encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
if encoder_attention_mask.dim() == 2:
encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
encoder_extended_attention_mask = encoder_extended_attention_mask.to(
dtype=next(self.parameters()).dtype
) # fp16 compatibility
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0
# Prepare head mask if needed
        # 1.0 in head_mask indicates we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
embedding_output = self.embeddings(
input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
)
encoder_outputs = self.encoder(
embedding_output,
attention_mask=extended_attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask,
)
sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output)
outputs = (sequence_output, pooled_output,) + encoder_outputs[
1:
] # add hidden_states and attentions if they are here
return outputs # sequence_output, pooled_output, (hidden_states), (attentions), highway exits
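# Hedged illustration of the output layout documented above: the last element of the tuple
# returned by DeeBertModel.forward holds one entry per visited layer, each a
# (logits, pooled_output) pair, with the exit entropy appended in eval mode.
def _example_collect_exit_logits(model_outputs):
    highway_exits = model_outputs[-1]
    return [exit_tuple[0] for exit_tuple in highway_exits]  # per-layer exit logits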
class HighwayException(Exception):
def __init__(self, message, exit_layer):
self.message = message
self.exit_layer = exit_layer # start from 1!
class BertHighway(nn.Module):
"""A module to provide a shortcut
from (the output of one non-final BertLayer in BertEncoder) to (cross-entropy computation in BertForSequenceClassification)
"""
def __init__(self, config):
super().__init__()
self.pooler = BertPooler(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
def forward(self, encoder_outputs):
# Pooler
pooler_input = encoder_outputs[0]
pooler_output = self.pooler(pooler_input)
# "return" pooler_output
# BertModel
bmodel_output = (pooler_input, pooler_output) + encoder_outputs[1:]
# "return" bodel_output
# Dropout and classification
pooled_output = bmodel_output[1]
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
return logits, pooled_output
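# Hedged illustration (shapes are placeholders): each highway head is simply a BertPooler
# followed by dropout and a linear classifier, applied to an intermediate layer's hidden states.
def _example_highway_head(config):
    highway = BertHighway(config)
    hidden_states = torch.zeros(2, 16, config.hidden_size)  # (batch, seq_len, hidden_size)
    logits, pooled_output = highway((hidden_states,))
    return logits.shape  # (batch, config.num_labels)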
@add_start_docstrings(
"""Bert Model (with early exiting - DeeBERT) with a classifier on top,
also takes care of multi-layer training. """,
BERT_START_DOCSTRING,
)
class DeeBertForSequenceClassification(BertPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.num_layers = config.num_hidden_layers
self.bert = DeeBertModel(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)
self.init_weights()
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_layer=-1,
train_highway=False,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
Labels for computing the sequence classification/regression loss.
Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
Classification (or regression if config.num_labels==1) loss.
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
        highway_exits (:obj:`tuple(tuple(torch.Tensor))`):
            Tuple with the results of each early exit (total length: number of layers).
            Each element is itself a tuple of length 2: the first entry is the logits and the second entry is the hidden states.
"""
exit_layer = self.num_layers
try:
outputs = self.bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
)
# sequence_output, pooled_output, (hidden_states), (attentions), highway exits
pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
except HighwayException as e:
outputs = e.message
exit_layer = e.exit_layer
logits = outputs[0]
if not self.training:
original_entropy = entropy(logits)
highway_entropy = []
highway_logits_all = []
if labels is not None:
if self.num_labels == 1:
# We are doing regression
loss_fct = MSELoss()
loss = loss_fct(logits.view(-1), labels.view(-1))
else:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
# work with highway exits
highway_losses = []
for highway_exit in outputs[-1]:
highway_logits = highway_exit[0]
if not self.training:
highway_logits_all.append(highway_logits)
highway_entropy.append(highway_exit[2])
if self.num_labels == 1:
# We are doing regression
loss_fct = MSELoss()
highway_loss = loss_fct(highway_logits.view(-1), labels.view(-1))
else:
loss_fct = CrossEntropyLoss()
highway_loss = loss_fct(highway_logits.view(-1, self.num_labels), labels.view(-1))
highway_losses.append(highway_loss)
if train_highway:
outputs = (sum(highway_losses[:-1]),) + outputs
# exclude the final highway, of course
else:
outputs = (loss,) + outputs
if not self.training:
outputs = outputs + ((original_entropy, highway_entropy), exit_layer)
if output_layer >= 0:
outputs = (
(outputs[0],) + (highway_logits_all[output_layer],) + outputs[2:]
                )  # replace the final logits with the logits of the highway exit selected by output_layer
return outputs # (loss), logits, (hidden_states), (attentions), (highway_exits)
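# Hedged illustration of the return layout documented above, for eval mode with labels:
# (loss, logits, ..., highway_exits, (original_entropy, per_layer_entropies), exit_layer).
# The helper only unpacks those elements; exit_layer counts from 1 (see HighwayException).
def _example_unpack_eval_outputs(outputs):
    exit_layer = outputs[-1]
    original_entropy, per_layer_entropies = outputs[-2]
    loss, logits = outputs[0], outputs[1]
    return loss, logits, original_entropy, per_layer_entropies, exit_layer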
from __future__ import absolute_import, division, print_function, unicode_literals
import torch.nn as nn
from torch.nn import CrossEntropyLoss, MSELoss
from transformers.configuration_roberta import RobertaConfig
from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_callable
from transformers.modeling_roberta import ROBERTA_INPUTS_DOCSTRING, ROBERTA_START_DOCSTRING, RobertaEmbeddings
from .modeling_highway_bert import BertPreTrainedModel, DeeBertModel, HighwayException, entropy
@add_start_docstrings(
"The RoBERTa Model transformer with early exiting (DeeRoBERTa). ", ROBERTA_START_DOCSTRING,
)
class DeeRobertaModel(DeeBertModel):
config_class = RobertaConfig
base_model_prefix = "roberta"
def __init__(self, config):
super().__init__(config)
self.embeddings = RobertaEmbeddings(config)
self.init_weights()
@add_start_docstrings(
"""RoBERTa Model (with early exiting - DeeRoBERTa) with a classifier on top,
also takes care of multi-layer training. """,
ROBERTA_START_DOCSTRING,
)
class DeeRobertaForSequenceClassification(BertPreTrainedModel):
config_class = RobertaConfig
base_model_prefix = "roberta"
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.num_layers = config.num_hidden_layers
self.roberta = DeeRobertaModel(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)
@add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_layer=-1,
train_highway=False,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
Labels for computing the sequence classification/regression loss.
Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.RobertaConfig`) and inputs:
        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
Classification (or regression if config.num_labels==1) loss.
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
        highway_exits (:obj:`tuple(tuple(torch.Tensor))`):
            Tuple with the results of each early exit (total length: number of layers).
            Each element is itself a tuple of length 2: the first entry is the logits and the second entry is the hidden states.
"""
exit_layer = self.num_layers
try:
outputs = self.roberta(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
)
pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
except HighwayException as e:
outputs = e.message
exit_layer = e.exit_layer
logits = outputs[0]
if not self.training:
original_entropy = entropy(logits)
highway_entropy = []
highway_logits_all = []
if labels is not None:
if self.num_labels == 1:
# We are doing regression
loss_fct = MSELoss()
loss = loss_fct(logits.view(-1), labels.view(-1))
else:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
# work with highway exits
highway_losses = []
for highway_exit in outputs[-1]:
highway_logits = highway_exit[0]
if not self.training:
highway_logits_all.append(highway_logits)
highway_entropy.append(highway_exit[2])
if self.num_labels == 1:
# We are doing regression
loss_fct = MSELoss()
highway_loss = loss_fct(highway_logits.view(-1), labels.view(-1))
else:
loss_fct = CrossEntropyLoss()
highway_loss = loss_fct(highway_logits.view(-1, self.num_labels), labels.view(-1))
highway_losses.append(highway_loss)
if train_highway:
outputs = (sum(highway_losses[:-1]),) + outputs
# exclude the final highway, of course
else:
outputs = (loss,) + outputs
if not self.training:
outputs = outputs + ((original_entropy, highway_entropy), exit_layer)
if output_layer >= 0:
outputs = (
(outputs[0],) + (highway_logits_all[output_layer],) + outputs[2:]
                )  # replace the final logits with the logits of the highway exit selected by output_layer
return outputs # (loss), logits, (hidden_states), (attentions), entropy
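# Hedged note: the RoBERTa variant mirrors the BERT one; from a caller's point of view the
# only difference is the attribute name of the backbone. The checkpoint path is a placeholder.
def _example_deeroberta_threshold(checkpoint_dir="./saved_models/roberta-base/MRPC/two_stage"):
    model = DeeRobertaForSequenceClassification.from_pretrained(checkpoint_dir)
    model.eval()
    model.roberta.encoder.set_early_exit_entropy(0.2)  # same semantics as DeeBertEncoder
    return model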
import argparse
import logging
import sys
import unittest
from unittest.mock import patch
import run_glue_deebert
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
def get_setup_file():
parser = argparse.ArgumentParser()
parser.add_argument("-f")
args = parser.parse_args()
return args.f
class DeeBertTests(unittest.TestCase):
def test_glue_deebert(self):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
train_args = """
run_glue_deebert.py
--model_type roberta
--model_name_or_path roberta-base
--task_name MRPC
--do_train
--do_eval
--do_lower_case
--data_dir ./tests/fixtures/tests_samples/MRPC/
--max_seq_length 128
--per_gpu_eval_batch_size=1
--per_gpu_train_batch_size=8
--learning_rate 2e-4
--num_train_epochs 3
--overwrite_output_dir
--seed 42
--output_dir ./examples/deebert/saved_models/roberta-base/MRPC/two_stage
--plot_data_dir ./examples/deebert/results/
--save_steps 0
--overwrite_cache
--eval_after_first_stage
""".split()
eval_args = """
run_glue_deebert.py
--model_type roberta
--model_name_or_path ./examples/deebert/saved_models/roberta-base/MRPC/two_stage
--task_name MRPC
--do_eval
--do_lower_case
--data_dir ./tests/fixtures/tests_samples/MRPC/
--output_dir ./examples/deebert/saved_models/roberta-base/MRPC/two_stage
--plot_data_dir ./examples/deebert/results/
--max_seq_length 128
--eval_each_highway
--eval_highway
--overwrite_cache
--per_gpu_eval_batch_size=1
""".split()
entropy_eval_args = """
run_glue_deebert.py
--model_type roberta
--model_name_or_path ./examples/deebert/saved_models/roberta-base/MRPC/two_stage
--task_name MRPC
--do_eval
--do_lower_case
--data_dir ./tests/fixtures/tests_samples/MRPC/
--output_dir ./examples/deebert/saved_models/roberta-base/MRPC/two_stage
--plot_data_dir ./examples/deebert/results/
--max_seq_length 128
--early_exit_entropy 0.1
--eval_highway
--overwrite_cache
--per_gpu_eval_batch_size=1
""".split()
with patch.object(sys, "argv", train_args):
result = run_glue_deebert.main()
for value in result.values():
self.assertGreaterEqual(value, 0.75)
with patch.object(sys, "argv", eval_args):
result = run_glue_deebert.main()
for value in result.values():
self.assertGreaterEqual(value, 0.75)
with patch.object(sys, "argv", entropy_eval_args):
result = run_glue_deebert.main()
for value in result.values():
self.assertGreaterEqual(value, 0.75)
#!/bin/bash
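# train_deebert.sh (see README): fine-tune a DeeBERT model on a GLUE task; run_glue_deebert.py
# trains in two stages (the backbone and final classifier first, then the highway exits),
# saving the result under ./saved_models/.../two_stage.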
export CUDA_VISIBLE_DEVICES=0
PATH_TO_DATA=/h/xinji/projects/GLUE
MODEL_TYPE=bert # bert or roberta
MODEL_SIZE=base # base or large
DATASET=MRPC # SST-2, MRPC, RTE, QNLI, QQP, or MNLI
MODEL_NAME=${MODEL_TYPE}-${MODEL_SIZE}
EPOCHS=10
if [ $MODEL_TYPE = 'bert' ]
then
EPOCHS=3
MODEL_NAME=${MODEL_NAME}-uncased
fi
python -u run_glue_deebert.py \
--model_type $MODEL_TYPE \
--model_name_or_path $MODEL_NAME \
--task_name $DATASET \
--do_train \
--do_eval \
--do_lower_case \
--data_dir $PATH_TO_DATA/$DATASET \
--max_seq_length 128 \
--per_gpu_eval_batch_size=1 \
--per_gpu_train_batch_size=8 \
--learning_rate 2e-5 \
--num_train_epochs $EPOCHS \
--overwrite_output_dir \
--seed 42 \
--output_dir ./saved_models/${MODEL_TYPE}-${MODEL_SIZE}/$DATASET/two_stage \
--plot_data_dir ./results/ \
--save_steps 0 \
--overwrite_cache \
--eval_after_first_stage