chenpangpang / transformers / Commits / 42968138

Commit 42968138, authored Aug 27, 2019 by VictorSanh
Parent: 1d232400

    wip wouf

Showing 2 changed files with 343 additions and 65 deletions.
pytorch_transformers/__init__.py           +2    -0
pytorch_transformers/modeling_dilbert.py   +341  -65
pytorch_transformers/__init__.py

@@ -40,6 +40,8 @@ from .modeling_xlm import (XLMConfig, XLMPreTrainedModel, XLMModel,
                           XLM_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_roberta import (RobertaConfig, RobertaForMaskedLM, RobertaModel, RobertaForSequenceClassification,
                               ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP)
+from .modeling_dilbert import (DilBertconfig, DilBertForMaskedLM, DilBertModel, DilBertForSequenceClassification,
+                               DILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DILBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_utils import (WEIGHTS_NAME, CONFIG_NAME, TF_WEIGHTS_NAME,
                             PretrainedConfig, PreTrainedModel, prune_layer, Conv1D)
pytorch_transformers/modeling_dilbert.py
@@ -20,6 +20,7 @@ from __future__ import absolute_import, division, print_function, unicode_literals
import json
import logging
import math
+import copy
import sys
from io import open
@@ -54,6 +55,7 @@ class DilBertconfig(PretrainedConfig):
                 n_layers=6,
                 n_heads=12,
                 dim=768,
                 hidden_dim=4*768,
                 dropout=0.1,
                 attention_dropout=0.1,
                 activation='gelu',
@@ -62,7 +64,7 @@ class DilBertconfig(PretrainedConfig):
                 **kwargs):
        super(DilBertconfig, self).__init__(**kwargs)
-        if isintance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
+        if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
                        and isinstance(vocab_size_or_config_json_file, unicode)):
            with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
                json_config = json.loads(reader.read())
@@ -85,6 +87,7 @@ class DilBertconfig(PretrainedConfig):
                             "or the path to a pretrained model config file (str)")


### UTILS AND BUILDING BLOCKS OF THE ARCHITECTURE ###
def gelu(x):
    return 0.5 * x * (1.0 + torch.erf(x / math.sqrt(2.0)))
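
As a quick sanity check outside the diff: the erf-based formula above is the exact GELU, and under recent PyTorch versions (an assumption, not something this commit depends on) it matches the built-in torch.nn.functional.gelu to numerical precision:

    import math
    import torch
    import torch.nn.functional as F

    def gelu(x):
        # Exact GELU via the Gaussian error function, as in the diff above.
        return 0.5 * x * (1.0 + torch.erf(x / math.sqrt(2.0)))

    x = torch.randn(8)
    print(torch.allclose(gelu(x), F.gelu(x), atol=1e-6))  # expected: True
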
@@ -102,9 +105,9 @@ class Embeddings(nn.Module):
    def __init__(self, config):
        super(Embeddings, self).__init__()
-        self.word_embeddings = nn.Embedding(config.vocab_size, dim, padding_idx=0)
+        self.word_embeddings = nn.Embedding(config.vocab_size, config.dim, padding_idx=0)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.dim)
-        if sinusoidal_pos_embds:
+        if config.sinusoidal_pos_embds:
            create_sinusoidal_embeddings(n_pos=config.max_position_embeddings,
                                         dim=config.dim,
                                         out=self.position_embeddings.weight)
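
create_sinusoidal_embeddings is called here but its body is not part of this hunk. As a hypothetical sketch of what such a helper conventionally does (a standard Transformer sin/cos table written into the embedding weight and frozen; an assumption, not the commit's code):

    import numpy as np
    import torch
    import torch.nn as nn

    def create_sinusoidal_embeddings(n_pos, dim, out):
        # Even columns get sin, odd columns get cos of the scaled positions.
        position_enc = np.array(
            [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
        )
        with torch.no_grad():
            out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
            out[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
        out.requires_grad = False  # keep the table fixed during training

    position_embeddings = nn.Embedding(512, 768)
    create_sinusoidal_embeddings(n_pos=512, dim=768, out=position_embeddings.weight)
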
@@ -116,7 +119,13 @@ class Embeddings(nn.Module):
        """
        Parameters
        ----------
-        input_ids: torch.tensor(bs, max_seq_length) - The token ids to embed.
+        input_ids: torch.tensor(bs, max_seq_length)
+            The token ids to embed.
+
+        Outputs
+        -------
+        embeddings: torch.tensor(bs, max_seq_length, dim)
+            The embedded tokens (plus position embeddings, no token_type embeddings)
        """
        seq_length = input_ids.size(1)
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) # (max_seq_length)
@@ -125,9 +134,9 @@ class Embeddings(nn.Module):
        word_embeddings = self.word_embeddings(input_ids)                   # (bs, max_seq_length, dim)
        position_embeddings = self.position_embeddings(position_ids)        # (bs, max_seq_length, dim)

-        embeddings = word_embeddings + position_embeddings
-        embeddings = self.LayerNorm(embeddings)
-        embeddings = self.dropout(embeddings)
+        embeddings = word_embeddings + position_embeddings  # (bs, max_seq_length, dim)
+        embeddings = self.LayerNorm(embeddings)             # (bs, max_seq_length, dim)
+        embeddings = self.dropout(embeddings)               # (bs, max_seq_length, dim)
        return embeddings


class MultiHeadSelfAttention(nn.Module):
@@ -142,10 +151,10 @@ class MultiHeadSelfAttention(nn.Module):
        assert self.dim % self.n_heads == 0

-        self.q_lin = nn.Linear(in_features=dim, out_features=dim)
-        self.k_lin = nn.Linear(in_features=dim, out_features=dim)
-        self.v_lin = nn.Linear(in_features=dim, out_features=dim)
-        self.out_lin = nn.Linear(in_features=dim, out_features=dim)
+        self.q_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
+        self.k_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
+        self.v_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
+        self.out_lin = nn.Linear(in_features=config.dim, out_features=config.dim)

    def forward(self,
                query: torch.tensor,
@@ -153,8 +162,6 @@ class MultiHeadSelfAttention(nn.Module):
                value: torch.tensor,
                mask: torch.tensor):
        """
-        Classic Self Attention. I don't understand the one of PyTorch...
-
        Parameters
        ----------
        query: torch.tensor(bs, seq_length, dim)
@@ -162,12 +169,12 @@ class MultiHeadSelfAttention(nn.Module):
        value: torch.tensor(bs, seq_length, dim)
        mask: torch.tensor(bs, seq_length)

-        Return
-        ------
+        Outputs
+        -------
        weights: torch.tensor(bs, n_heads, seq_length, seq_length)
            Attention weights
        context: torch.tensor(bs, seq_length, dim)
-            Contextualized layer
+            Contextualized layer. Optional: only if `output_attentions=True`
        """
        bs, q_length, dim = query.size()
        k_length = key.size(1)
@@ -204,9 +211,9 @@ class MultiHeadSelfAttention(nn.Module):
        context = self.out_lin(context)        # (bs, q_length, dim)

        if self.output_attentions:
-            return context, weights
+            return (context, weights)
        else:
-            return context
+            return (context,)


class FFN(nn.Module):
    def __init__(self,
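
The change above moves the module from returning bare tensors to always returning a tuple, with optional extras placed after the main output. A toy, self-contained illustration of that calling convention (a sketch, not the commit's class):

    import torch
    import torch.nn as nn

    class ToyAttention(nn.Module):
        # Always return a tuple: index 0 is the context, index 1 the attention
        # weights when output_attentions is set.
        def __init__(self, output_attentions=False):
            super(ToyAttention, self).__init__()
            self.output_attentions = output_attentions

        def forward(self, x):
            weights = torch.softmax(x @ x.transpose(-1, -2), dim=-1)  # (bs, seq, seq)
            context = weights @ x                                     # (bs, seq, dim)
            if self.output_attentions:
                return (context, weights)
            return (context,)

    outputs = ToyAttention(output_attentions=True)(torch.randn(2, 5, 8))
    context, weights = outputs[0], outputs[1]
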
@@ -215,8 +222,8 @@ class FFN(nn.Module):
        self.dropout = nn.Dropout(p=config.dropout)
        self.lin1 = nn.Linear(in_features=config.dim, out_features=config.hidden_dim)
        self.lin2 = nn.Linear(in_features=config.hidden_dim, out_features=config.dim)
-        assert activation in ['relu', 'gelu'], ValueError(f"activation ({config.activation}) must be in ['relu', 'gelu']")
-        self.activation = gelu if activation == 'gelu' else nn.ReLU()
+        assert config.activation in ['relu', 'gelu'], ValueError(f"activation ({config.activation}) must be in ['relu', 'gelu']")
+        self.activation = gelu if config.activation == 'gelu' else nn.ReLU()

    def forward(self, input: torch.tensor):
@@ -238,19 +245,12 @@ class TransformerBlock(nn.Module):
        self.activation = config.activation
        self.output_attentions = config.output_attentions

-        assert dim % n_heads == 0
+        assert config.dim % config.n_heads == 0

-        self.attention = MultiHeadSelfAttention(dim=config.dim,
-                                                n_heads=config.n_heads,
-                                                dropout=config.attention_dropout,
-                                                output_attentions=config.output_attentions)
+        self.attention = MultiHeadSelfAttention(config)
        self.sa_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12)

-        self.ffn = FFN(in_dim=config.dim,
-                       hidden_dim=config.hidden_dim,
-                       out_dim=config.dim,
-                       dropout=config.dropout,
-                       activation=config.activation)
+        self.ffn = FFN(config)
        self.output_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12)

    def forward(self,
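
Both sub-modules are now built from the whole config object rather than from individual keyword arguments. A minimal sketch of that pattern, using a stand-in namespace instead of DilBertconfig (whose full signature is not shown in this diff):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from types import SimpleNamespace

    # Hypothetical stand-in for the configuration object; field names mirror the diff.
    config = SimpleNamespace(dim=768, hidden_dim=4 * 768, dropout=0.1)

    class ToyFFN(nn.Module):
        def __init__(self, config):
            super(ToyFFN, self).__init__()
            # Hyper-parameters are read off the config instead of being passed one by one.
            self.lin1 = nn.Linear(config.dim, config.hidden_dim)
            self.lin2 = nn.Linear(config.hidden_dim, config.dim)
            self.dropout = nn.Dropout(p=config.dropout)

        def forward(self, x):
            return self.dropout(self.lin2(F.gelu(self.lin1(x))))

    y = ToyFFN(config)(torch.randn(2, 5, config.dim))  # (2, 5, 768)
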
@@ -261,21 +261,28 @@ class TransformerBlock(nn.Module):
        ----------
        x: torch.tensor(bs, seq_length, dim)
        attn_mask: torch.tensor(bs, seq_length)
+
+        Outputs
+        -------
+        sa_weights: torch.tensor(bs, n_heads, seq_length, seq_length)
+            The attention weights
+        ffn_output: torch.tensor(bs, seq_length, dim)
+            The output of the transformer block contextualization.
        """
        # Self-Attention
        sa_output = self.attention(query=x, key=x, value=x, mask=attn_mask)
        if self.output_attentions:
-            sa_output, sa_weights = sa_output    # (bs, seq_length, dim)
+            sa_output, sa_weights = sa_output    # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length)
        sa_output = self.sa_layer_norm(sa_output + x)          # (bs, seq_length, dim)

        # Feed Forward Network
        ffn_output = self.ffn(sa_output)                             # (bs, seq_length, dim)
        ffn_output = self.output_layer_norm(ffn_output + sa_output)  # (bs, seq_length, dim)

-        if self.output_attentions:
-            return sa_weights, ffn_output
-        else:
-            return ffn_output
+        output = (ffn_output)
+        if self.output_attentions:
+            output = (sa_weights,) + output
+        return output


class Transformer(nn.Module):
    def __init__(self,
@@ -283,52 +290,286 @@ class Transformer(nn.Module):
        super(Transformer, self).__init__()
        self.n_layers = config.n_layers
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states

-        layer = TransformerBlock(n_heads=config.n_heads,
-                                 dim=config.dim,
-                                 hidden_dim=config.hidden_dim,
-                                 dropout=config.dropout,
-                                 attention_dropout=config.attention_dropout,
-                                 activation=config.activation,
-                                 output_attentions=config.output_attentions)
-        self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(n_layers)])
+        layer = TransformerBlock(config)
+        self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.n_layers)])

    def forward(self,
                x: torch.tensor,
-                attn_mask: torch.tensor = None,
-                output_all_encoded_layers: bool = True):
+                attn_mask: torch.tensor = None):
        """
        Parameters
        ----------
        x: torch.tensor(bs, seq_length, dim)
            Input sequence embedded.
        attn_mask: torch.tensor(bs, seq_length)
-        output_all_encoded_layers: bool
+            Attention mask on the sequence.
+
+        Outputs
+        -------
+        hidden_state: torch.tensor(bs, seq_length, dim)
+            Sequence of hidden states in the last (top) layer
+        all_hidden_states: Tuple[torch.tensor(bs, seq_length, dim)]
+            Tuple of length n_layers with the hidden states from each layer.
+            Optional: only if output_hidden_states=True
+        all_attentions: Tuple[torch.tensor(bs, n_heads, seq_length, seq_length)]
+            Tuple of length n_layers with the attention weights from each layer
+            Optional: only if output_attentions=True
        """
-        all_encoder_layers = []
-        all_attentions = []
+        all_hidden_states = ()
+        all_attentions = ()

+        hidden_state = x
        for _, layer_module in enumerate(self.layer):
-            x = layer_module(x=x, attn_mask=attn_mask)
+            hidden_state = layer_module(x=hidden_state, attn_mask=attn_mask)
            if self.output_attentions:
-                attentions, x = x
-                all_attentions.append(attentions)
-            all_encoder_layers.append(x)
-
-        if not output_all_encoded_layers:
-            all_encoder_layers = all_encoder_layers[-1]
-
-        if self.output_attentions:
-            return all_attentions, all_encoder_layers
-        else:
-            return all_encoder_layers
+                attentions, hidden_state = hidden_state
+                all_attentions = all_attentions + (attentions,)
+            all_hidden_states = all_hidden_states + (hidden_state,)

+        outputs = (hidden_state,)
+        if self.output_hidden_states:
+            outputs = outputs + (all_hidden_states,)
+        if self.output_attentions:
+            outputs = outputs + (all_attentions,)
+        return outputs
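
A quick check, outside the diff, that the copy.deepcopy/nn.ModuleList construction used in Transformer.__init__ above gives each layer its own parameters rather than sharing one set:

    import copy
    import torch.nn as nn

    layer = nn.Linear(8, 8)
    layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(3)])
    # Deep copies own their storage, so the layers do not share weights.
    print(layers[0].weight.data_ptr() != layers[1].weight.data_ptr())  # expected: True
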
### INTERFACE FOR ENCODER AND TASK SPECIFIC MODEL ###
class DilBertPreTrainedModel(PreTrainedModel):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """
    config_class = DilBertconfig
    pretrained_model_archive_map = DILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
    load_tf_weights = None
    base_model_prefix = "dilbert"

    def __init__(self, *inputs, **kwargs):
        super(DilBertPreTrainedModel, self).__init__(*inputs, **kwargs)

    def init_weights(self, module):
        """ Initialize the weights.
        """
        if isinstance(module, nn.Embedding):
            if module.weight.requires_grad:
                module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
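
init_weights above is wired up through self.apply(self.init_weights) in the model classes below; nn.Module.apply walks every sub-module and calls the given function on each one. A small self-contained illustration (the 0.02 std is an arbitrary value for the sketch):

    import torch.nn as nn

    def init_weights(module):
        # Re-initialize Linear layers; leave everything else untouched.
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()

    net = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8), nn.Linear(8, 2))
    net.apply(init_weights)  # visits the Sequential itself, both Linears, and the LayerNorm
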
DILBERT_START_DOCSTRING = r"""
    Smaller, faster, cheaper, lighter: DilBERT
    For more information on DilBERT, you should check TODO(Victor): Link to Medium

    Parameters:
        config (:class:`~pytorch_transformers.DilBertconfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~pytorch_transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""

DILBERT_INPUTS_DOCSTRING = r"""
    Inputs:
        **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Indices of input sequence tokens in the vocabulary.
            The input sequences should start with `[CLS]` and end with `[SEP]` tokens.
            For now, ONLY BertTokenizer(`bert-base-uncased`) is supported and you should use this tokenizer when using DilBERT.
        **attention_mask**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
"""
@add_start_docstrings("The bare DilBERT encoder/transformer outputing raw hidden-states without any specific head on top.",
                      DILBERT_START_DOCSTRING, DILBERT_INPUTS_DOCSTRING)
class DilBertModel(DilBertPreTrainedModel):
    def __init__(self, config):
        super(DilBertModel, self).__init__(config)

        self.embeddings = Embeddings(config)   # Embeddings
        self.transformer = Transformer(config) # Encoder

        self.apply(self.init_weights)

    def forward(self,
                input_ids: torch.tensor,
                attention_mask: torch.tensor = None):
        """
        Parameters
        ----------
        input_ids: torch.tensor(bs, seq_length)
            Sequences of token ids.
        attention_mask: torch.tensor(bs, seq_length)
            Attention mask on the sequences. Optional: If None, it's like there was no padding.

        Outputs
        -------
        hidden_state: torch.tensor(bs, seq_length, dim)
            Sequence of hidden states in the last (top) layer
        pooled_output: torch.tensor(bs, dim)
            Pooled output: for DilBert, the pooled output is simply the hidden state of the [CLS] token.
        all_hidden_states: Tuple[torch.tensor(bs, seq_length, dim)]
            Tuple of length n_layers with the hidden states from each layer.
            Optional: only if output_hidden_states=True
        all_attentions: Tuple[torch.tensor(bs, n_heads, seq_length, seq_length)]
            Tuple of length n_layers with the attention weights from each layer
            Optional: only if output_attentions=True
        """
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids) # (bs, seq_length)

        embedding_output = self.embeddings(input_ids)   # (bs, seq_length, dim)
        tfmr_output = self.transformer(x=embedding_output,
                                       attn_mask=attention_mask)
        hidden_state = tfmr_output[0]
        pooled_output = hidden_state[:, 0]
        output = (hidden_state, pooled_output) + tfmr_output[1:]

-        # TODO(Victor)
-        # class DilBertWithLMHeadModel(DilBertPreTrainedModel):
-        # class DilBertForSequenceClassification(DilBertPretrainedModel):
        return output # hidden_state, pooled_output, (hidden_states), (attentions)
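
The docstring above notes that the pooled output is simply the hidden state of the [CLS] token; a toy tensor-level illustration of that slicing:

    import torch

    hidden_state = torch.randn(2, 7, 768)  # (bs, seq_length, dim)
    pooled_output = hidden_state[:, 0]     # (bs, dim): hidden state at the first ([CLS]) position
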
@add_start_docstrings("""DilBert Model with a `masked language modeling` head on top. """,
                      DILBERT_START_DOCSTRING, DILBERT_INPUTS_DOCSTRING)
class DilBertForMaskedLM(DilBertPreTrainedModel):
    def __init__(self, config):
        super(DilBertForMaskedLM, self).__init__(config)
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states

        self.encoder = DilBertModel(config)
        self.vocab_transform = nn.Linear(config.dim, config.dim)
        self.vocab_layer_norm = nn.LayerNorm(config.dim, eps=1e-12)
        self.vocab_projector = nn.Linear(config.dim, config.vocab_size)

        self.apply(self.init_weights)
        self.tie_weights()

        self.mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)

    def tie_weights_(self):
        """
        Tying the weights of the vocabulary projection to the base token embeddings.
        """
        if self.config.tie_weights:
            self.vocab_projector.weight = self.encoder.embeddings.word_embeddings.weight

    def forward(self,
                input_ids: torch.tensor,
                attention_mask: torch.tensor = None,
                masked_lm_labels: torch.tensor = None):
        """
        Parameters
        ----------
        input_ids: torch.tensor(bs, seq_length)
            Token ids.
        attention_mask: torch.tensor(bs, seq_length)
            Attention mask. Optional: If None, it's like there was no padding.
        masked_lm_labels: torch.tensor(bs, seq_length)
            The masked language modeling labels. Optional: If None, no loss is computed.

        Outputs
        -------
        mlm_loss: torch.tensor(1,)
            Masked Language Modeling loss to optimize.
            Optional: only if `masked_lm_labels` is not None
        prediction_logits: torch.tensor(bs, seq_length, voc_size)
            Token prediction logits
        all_hidden_states: Tuple[torch.tensor(bs, seq_length, dim)]
            Tuple of length n_layers with the hidden states from each layer.
            Optional: only if `output_hidden_states`=True
        all_attentions: Tuple[torch.tensor(bs, n_heads, seq_length, seq_length)]
            Tuple of length n_layers with the attention weights from each layer
            Optional: only if `output_attentions`=True
        """
        tfmr_output = self.encoder(input_ids=input_ids,
                                   attention_mask=attention_mask)
        hidden_states = tfmr_output[0]                                # (bs, seq_length, dim)
        prediction_logits = self.vocab_transform(hidden_states)      # (bs, seq_length, dim)
        prediction_logits = gelu(prediction_logits)                  # (bs, seq_length, dim)
        prediction_logits = self.vocab_layer_norm(prediction_logits) # (bs, seq_length, dim)
        prediction_logits = self.vocab_projector(prediction_logits)  # (bs, seq_length, vocab_size)

        outputs = (prediction_logits, ) + tfmr_output[2:]
        if masked_lm_labels is not None:
            mlm_loss = self.mlm_loss_fct(prediction_logits.view(-1, prediction_logits.size(-1)),
                                         masked_lm_labels.view(-1))
            outputs = (mlm_loss,) + outputs

        return outputs # (mlm_loss), prediction_logits, (hidden_states), (attentions)
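
tie_weights_ above shares the output projection's weight with the input embedding table. A minimal sketch of that tying (the shapes are assumed for illustration, not taken from the config):

    import torch.nn as nn

    vocab_size, dim = 30522, 768  # assumed shapes for illustration
    word_embeddings = nn.Embedding(vocab_size, dim)
    vocab_projector = nn.Linear(dim, vocab_size)

    # Point the projector at the embedding matrix: both now share and update one tensor.
    vocab_projector.weight = word_embeddings.weight
    assert vocab_projector.weight.data_ptr() == word_embeddings.weight.data_ptr()
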
@add_start_docstrings("""DilBert Model transformer with a sequence classification/regression head on top (a linear layer on top of
    the pooled output) e.g. for GLUE tasks. """,
                      DILBERT_START_DOCSTRING, DILBERT_INPUTS_DOCSTRING)
class DilBertForSequenceClassification(DilBertPreTrainedModel):
    def __init__(self, config):
        super(DilBertForSequenceClassification, self).__init__(config)
        self.num_labels = config.num_labels

        self.dilbert = DilBertModel(config)
        self.pre_classifier = nn.Linear(config.dim, config.dim)
        self.classifier = nn.Linear(config.dim, config.num_labels)
        self.dropout = nn.Dropout(config.seq_classif_dropout)

        self.apply(self.init_weights)

    def forward(self,
                input_ids: torch.tensor,
                attention_mask: torch.tensor = None,
                labels: torch.tensor = None):
        """
        Parameters
        ----------
        input_ids: torch.tensor(bs, seq_length)
            Token ids.
        attention_mask: torch.tensor(bs, seq_length)
            Attention mask. Optional: If None, it's like there was no padding.
        labels: torch.tensor(bs,)
            Classification Labels. Optional: If None, no loss will be computed.

        Outputs
        -------
        loss: torch.tensor(1)
            Sequence classification loss.
            Optional: Is computed only if `labels` is not None.
        logits: torch.tensor(bs, seq_length)
            Classification (or regression if config.num_labels==1) scores
        all_hidden_states: Tuple[torch.tensor(bs, seq_length, dim)]
            Tuple of length n_layers with the hidden states from each layer.
            Optional: only if `output_hidden_states`=True
        all_attentions: Tuple[torch.tensor(bs, n_heads, seq_length, seq_length)]
            Tuple of length n_layers with the attention weights from each layer
            Optional: only if `output_attentions`=True
        """
        dilbert_output = self.dilbert(input_ids=input_ids,
                                      attention_mask=attention_mask)
        pooled_output = dilbert_output[1]                    # (bs, dim)
        pooled_output = self.pre_classifier(pooled_output)   # (bs, dim)
        pooled_output = nn.ReLU()(pooled_output)             # (bs, dim)
        pooled_output = self.dropout(pooled_output)          # (bs, dim)
        logits = self.classifier(pooled_output)              # (bs, dim)

        outputs = (logits,) + dilbert_output[2:]
        if labels is not None:
            if self.num_labels == 1:
                loss_fct = nn.MSELoss()
                loss = loss_fct(logits.view(-1), labels.view(-1))
            else:
                loss_fct = nn.CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            outputs = (loss,) + outputs

        return outputs # (loss), logits, (hidden_states), (attentions)
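
The loss selection above switches between regression and classification on config.num_labels; a toy, self-contained version of the same branching (example shapes are assumed):

    import torch
    import torch.nn as nn

    num_labels = 3
    logits = torch.randn(4, num_labels)   # (bs, num_labels)
    labels = torch.tensor([0, 2, 1, 2])   # (bs,)

    if num_labels == 1:
        # Regression: a single score per example, mean-squared error.
        loss = nn.MSELoss()(logits.view(-1), labels.view(-1).float())
    else:
        # Classification: cross-entropy over the label dimension.
        loss = nn.CrossEntropyLoss()(logits.view(-1, num_labels), labels.view(-1))
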
@add_start_docstrings("""DilBert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
    the hidden-states output to compute `span start logits` and `span end logits`). """,
                      DILBERT_START_DOCSTRING, DILBERT_INPUTS_DOCSTRING)
class DilBertForQuestionAnswering(DilBertPreTrainedModel):
    def __init__(self, config):
        super(DilBertForQuestionAnswering, self).__init__(config)
@@ -345,16 +586,51 @@ class DilBertForQuestionAnswering(DilBertPreTrainedModel):
                attention_mask: torch.tensor = None,
                start_positions: torch.tensor = None,
                end_positions: torch.tensor = None):
-        _, _, hidden_states = self.dilbert(input_ids=input_ids,
-                                           attention_mask=attention_mask) # _, _, (bs, max_query_len, dim)
        """
        Parameters
        ----------
        input_ids: torch.tensor(bs, seq_length)
            Token ids.
        attention_mask: torch.tensor(bs, seq_length)
            Attention mask. Optional: If None, it's like there was no padding.
        start_positions: torch.tensor(bs)
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`).
            Positions outside of the sequence are not taken into account for computing the loss.
            Optional: if None, no loss is computed.
        end_positions: torch.tensor(bs)
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`).
            Positions outside of the sequence are not taken into account for computing the loss.
            Optional: if None, no loss is computed.

        Outputs
        -------
        loss: torch.tensor(1)
            Question answering loss.
            Optional: Is computed only if `start_positions` and `end_positions` are not None.
        start_logits: torch.tensor(bs, seq_length)
            Span-start scores.
        end_logits: torch.tensor(bs, seq_length)
            Span-end scores.
        all_hidden_states: Tuple[torch.tensor(bs, seq_length, dim)]
            Tuple of length n_layers with the hidden states from each layer.
            Optional: only if `output_hidden_states`=True
        all_attentions: Tuple[torch.tensor(bs, n_heads, seq_length, seq_length)]
            Tuple of length n_layers with the attention weights from each layer
            Optional: only if `output_attentions`=True
        """
+        dilbert_output = self.dilbert(input_ids=input_ids,
+                                      attention_mask=attention_mask)
+        hidden_states = dilbert_output[0]                  # (bs, max_query_len, dim)

        hidden_states = self.dropout(hidden_states)        # (bs, max_query_len, dim)
        logits = self.qa_outputs(hidden_states)            # (bs, max_query_len, 2)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)            # (bs, max_query_len)
        end_logits = end_logits.squeeze(-1)                # (bs, max_query_len)

-        outputs = (start_logits, end_logits,) + (hidden_states,)
+        outputs = (start_logits, end_logits,) + dilbert_output[2:]
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split add a dimension
            if len(start_positions.size()) > 1:
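
The span head above projects each position to two scores and splits them into start and end logits; a toy tensor-level walkthrough of those shapes (the 768-to-2 head mirrors the BERT-style QA setup, assumed here for illustration):

    import torch
    import torch.nn as nn

    hidden_states = torch.randn(2, 9, 768)             # (bs, max_query_len, dim)
    qa_outputs = nn.Linear(768, 2)                     # one start score and one end score per position

    logits = qa_outputs(hidden_states)                 # (bs, max_query_len, 2)
    start_logits, end_logits = logits.split(1, dim=-1)
    start_logits = start_logits.squeeze(-1)            # (bs, max_query_len)
    end_logits = end_logits.squeeze(-1)                # (bs, max_query_len)
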
@@ -372,4 +648,4 @@ class DilBertForQuestionAnswering(DilBertPreTrainedModel):
            total_loss = (start_loss + end_loss) / 2
            outputs = (total_loss,) + outputs

-        return outputs # (loss), start_logits, end_logits, hidden_states
\ No newline at end of file
+        return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions)