chenpangpang / transformers · Commit a8ad8304
Authored Aug 28, 2019 by VictorSanh
fix bugs
Parent: 60c984da

Showing 1 changed file with 14 additions and 12 deletions:
pytorch_transformers/modeling_dilbert.py (+14, -12)
@@ -60,7 +60,7 @@ class DilBertConfig(PretrainedConfig):
                  attention_dropout=0.1,
                  activation='gelu',
                  initializer_range=0.02,
-                 tie_weights=True,
+                 tie_weights_=True,
                  **kwargs):
         super(DilBertConfig, self).__init__(**kwargs)
@@ -82,7 +82,7 @@ class DilBertConfig(PretrainedConfig):
             self.attention_dropout = attention_dropout
             self.activation = activation
             self.initializer_range = initializer_range
-            self.tie_weights = tie_weights
+            self.tie_weights_ = tie_weights_
         else:
             raise ValueError("First argument must be either a vocabulary size (int)"
                              "or the path to a pretrained model config file (str)")
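A tiny illustration of the renamed configuration flag in the two hunks above, using a hypothetical stand-in class rather than the real DilBertConfig: after this commit, both the keyword argument and the stored attribute are spelled tie_weights_.

# Hypothetical stand-in, only to show the renamed flag; not the library class.
class ToyConfig:
    def __init__(self, tie_weights_=True, **kwargs):
        self.tie_weights_ = tie_weights_   # was: self.tie_weights = tie_weights

config = ToyConfig(tie_weights_=True)
assert config.tie_weights_ is True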
@@ -274,13 +274,15 @@ class TransformerBlock(nn.Module):
         sa_output = self.attention(query=x, key=x, value=x, mask=attn_mask)
         if self.output_attentions:
             sa_output, sa_weights = sa_output              # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length)
+        else:
+            sa_output = sa_output[0]
         sa_output = self.sa_layer_norm(sa_output + x)      # (bs, seq_length, dim)

         # Feed Forward Network
         ffn_output = self.ffn(sa_output)                              # (bs, seq_length, dim)
         ffn_output = self.output_layer_norm(ffn_output + sa_output)  # (bs, seq_length, dim)

-        output = (ffn_output)
+        output = (ffn_output,)
         if self.output_attentions:
             output = (sa_weights,) + output
         return output
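Two things change in this hunk: when output_attentions is disabled, the new else branch unwraps the attention module's return value (presumably a one-element tuple) before the residual addition, and the final output gains a trailing comma so that it is an actual tuple. A minimal sketch in plain Python, independent of the model code, of why that comma matters for callers that index or unpack the block's output:

import torch

ffn_output = torch.zeros(2, 5, 8)   # (bs, seq_length, dim)

# Before the fix: parentheses alone do not build a tuple, so `output` was the bare
# tensor and `output[0]` returned the first example instead of the layer output.
output_old = (ffn_output)
assert torch.is_tensor(output_old)

# After the fix: the trailing comma makes a one-element tuple, so downstream code
# can always unpack with `hidden_state = output[0]`.
output_new = (ffn_output,)
assert isinstance(output_new, tuple) and output_new[0] is ffn_output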
@@ -468,36 +470,36 @@ class DilBertForMaskedLM(DilBertPreTrainedModel):
         self.output_attentions = config.output_attentions
         self.output_hidden_states = config.output_hidden_states

-        self.encoder = DilBertModel(config)
+        self.dilbert = DilBertModel(config)
         self.vocab_transform = nn.Linear(config.dim, config.dim)
         self.vocab_layer_norm = nn.LayerNorm(config.dim, eps=1e-12)
         self.vocab_projector = nn.Linear(config.dim, config.vocab_size)

         self.apply(self.init_weights)
-        self.tie_weights_()
+        self.tie_weights()

         self.mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)

-    def tie_weights_(self):
+    def tie_weights(self):
         """
         Tying the weights of the vocabulary projection to the base token embeddings.
         """
-        if self.config.tie_weights:
-            self.vocab_projector.weight = self.encoder.embeddings.word_embeddings.weight
+        if self.config.tie_weights_:
+            self.vocab_projector.weight = self.dilbert.embeddings.word_embeddings.weight

     def forward(self,
                 input_ids: torch.tensor,
                 attention_mask: torch.tensor = None,
                 masked_lm_labels: torch.tensor = None):
-        tfmr_output = self.encoder(input_ids=input_ids,
-                                   attention_mask=attention_mask)
-        hidden_states = tfmr_output[0]                                 # (bs, seq_length, dim)
+        dlbrt_output = self.dilbert(input_ids=input_ids,
+                                    attention_mask=attention_mask)
+        hidden_states = dlbrt_output[0]                                # (bs, seq_length, dim)
         prediction_logits = self.vocab_transform(hidden_states)       # (bs, seq_length, dim)
         prediction_logits = gelu(prediction_logits)                   # (bs, seq_length, dim)
         prediction_logits = self.vocab_layer_norm(prediction_logits)  # (bs, seq_length, dim)
         prediction_logits = self.vocab_projector(prediction_logits)   # (bs, seq_length, vocab_size)

-        outputs = (prediction_logits, ) + tfmr_output[2:]
+        outputs = (prediction_logits, ) + dlbrt_output[2:]
         if masked_lm_labels is not None:
             mlm_loss = self.mlm_loss_fct(prediction_logits.view(-1, prediction_logits.size(-1)),
                                          masked_lm_labels.view(-1))
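For context, the renamed tie_weights() hook reads the renamed config.tie_weights_ flag and shares the vocab_projector weight with the word-embedding matrix of the renamed self.dilbert backbone. A minimal, self-contained sketch of the same weight-tying pattern, using toy module names rather than the library classes in this diff:

import torch.nn as nn

class ToyMaskedLM(nn.Module):
    """Toy stand-in mirroring the tying logic in this diff; not the library class."""
    def __init__(self, vocab_size=100, dim=16, tie_weights_=True):
        super().__init__()
        self.tie_weights_flag = tie_weights_                    # plays the role of config.tie_weights_
        self.word_embeddings = nn.Embedding(vocab_size, dim)    # stands in for dilbert.embeddings.word_embeddings
        self.vocab_projector = nn.Linear(dim, vocab_size)
        self.tie_weights()

    def tie_weights(self):
        # Same idea as DilBertForMaskedLM.tie_weights(): share one weight matrix
        # between the input embedding and the output vocabulary projection.
        if self.tie_weights_flag:
            self.vocab_projector.weight = self.word_embeddings.weight

model = ToyMaskedLM()
# The projection and the embedding now point to the same Parameter object.
assert model.vocab_projector.weight is model.word_embeddings.weight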