Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
a8ad8304
Commit
a8ad8304
authored
Aug 28, 2019
by
VictorSanh
Browse files
fix bugs
parent
60c984da
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
14 additions
and
12 deletions
+14
-12
pytorch_transformers/modeling_dilbert.py
pytorch_transformers/modeling_dilbert.py
+14
-12
No files found.
pytorch_transformers/modeling_dilbert.py
View file @
a8ad8304
...
@@ -60,7 +60,7 @@ class DilBertConfig(PretrainedConfig):
...
@@ -60,7 +60,7 @@ class DilBertConfig(PretrainedConfig):
attention_dropout
=
0.1
,
attention_dropout
=
0.1
,
activation
=
'gelu'
,
activation
=
'gelu'
,
initializer_range
=
0.02
,
initializer_range
=
0.02
,
tie_weights
=
True
,
tie_weights
_
=
True
,
**
kwargs
):
**
kwargs
):
super
(
DilBertConfig
,
self
).
__init__
(
**
kwargs
)
super
(
DilBertConfig
,
self
).
__init__
(
**
kwargs
)
...
@@ -82,7 +82,7 @@ class DilBertConfig(PretrainedConfig):
...
@@ -82,7 +82,7 @@ class DilBertConfig(PretrainedConfig):
self
.
attention_dropout
=
attention_dropout
self
.
attention_dropout
=
attention_dropout
self
.
activation
=
activation
self
.
activation
=
activation
self
.
initializer_range
=
initializer_range
self
.
initializer_range
=
initializer_range
self
.
tie_weights
=
tie_weights
self
.
tie_weights
_
=
tie_weights
_
else
:
else
:
raise
ValueError
(
"First argument must be either a vocabulary size (int)"
raise
ValueError
(
"First argument must be either a vocabulary size (int)"
"or the path to a pretrained model config file (str)"
)
"or the path to a pretrained model config file (str)"
)
...
@@ -274,13 +274,15 @@ class TransformerBlock(nn.Module):
...
@@ -274,13 +274,15 @@ class TransformerBlock(nn.Module):
sa_output
=
self
.
attention
(
query
=
x
,
key
=
x
,
value
=
x
,
mask
=
attn_mask
)
sa_output
=
self
.
attention
(
query
=
x
,
key
=
x
,
value
=
x
,
mask
=
attn_mask
)
if
self
.
output_attentions
:
if
self
.
output_attentions
:
sa_output
,
sa_weights
=
sa_output
# (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length)
sa_output
,
sa_weights
=
sa_output
# (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length)
else
:
sa_output
=
sa_output
[
0
]
sa_output
=
self
.
sa_layer_norm
(
sa_output
+
x
)
# (bs, seq_length, dim)
sa_output
=
self
.
sa_layer_norm
(
sa_output
+
x
)
# (bs, seq_length, dim)
# Feed Forward Network
# Feed Forward Network
ffn_output
=
self
.
ffn
(
sa_output
)
# (bs, seq_length, dim)
ffn_output
=
self
.
ffn
(
sa_output
)
# (bs, seq_length, dim)
ffn_output
=
self
.
output_layer_norm
(
ffn_output
+
sa_output
)
# (bs, seq_length, dim)
ffn_output
=
self
.
output_layer_norm
(
ffn_output
+
sa_output
)
# (bs, seq_length, dim)
output
=
(
ffn_output
)
output
=
(
ffn_output
,
)
if
self
.
output_attentions
:
if
self
.
output_attentions
:
output
=
(
sa_weights
,)
+
output
output
=
(
sa_weights
,)
+
output
return
output
return
output
...
@@ -468,36 +470,36 @@ class DilBertForMaskedLM(DilBertPreTrainedModel):
...
@@ -468,36 +470,36 @@ class DilBertForMaskedLM(DilBertPreTrainedModel):
self
.
output_attentions
=
config
.
output_attentions
self
.
output_attentions
=
config
.
output_attentions
self
.
output_hidden_states
=
config
.
output_hidden_states
self
.
output_hidden_states
=
config
.
output_hidden_states
self
.
encod
er
=
DilBertModel
(
config
)
self
.
dilb
er
t
=
DilBertModel
(
config
)
self
.
vocab_transform
=
nn
.
Linear
(
config
.
dim
,
config
.
dim
)
self
.
vocab_transform
=
nn
.
Linear
(
config
.
dim
,
config
.
dim
)
self
.
vocab_layer_norm
=
nn
.
LayerNorm
(
config
.
dim
,
eps
=
1e-12
)
self
.
vocab_layer_norm
=
nn
.
LayerNorm
(
config
.
dim
,
eps
=
1e-12
)
self
.
vocab_projector
=
nn
.
Linear
(
config
.
dim
,
config
.
vocab_size
)
self
.
vocab_projector
=
nn
.
Linear
(
config
.
dim
,
config
.
vocab_size
)
self
.
apply
(
self
.
init_weights
)
self
.
apply
(
self
.
init_weights
)
self
.
tie_weights
_
()
self
.
tie_weights
()
self
.
mlm_loss_fct
=
nn
.
CrossEntropyLoss
(
ignore_index
=-
1
)
self
.
mlm_loss_fct
=
nn
.
CrossEntropyLoss
(
ignore_index
=-
1
)
def
tie_weights
_
(
self
):
def
tie_weights
(
self
):
"""
"""
Tying the weights of the vocabulary projection to the base token embeddings.
Tying the weights of the vocabulary projection to the base token embeddings.
"""
"""
if
self
.
config
.
tie_weights
:
if
self
.
config
.
tie_weights
_
:
self
.
vocab_projector
.
weight
=
self
.
encod
er
.
embeddings
.
word_embeddings
.
weight
self
.
vocab_projector
.
weight
=
self
.
dilb
er
t
.
embeddings
.
word_embeddings
.
weight
def
forward
(
self
,
def
forward
(
self
,
input_ids
:
torch
.
tensor
,
input_ids
:
torch
.
tensor
,
attention_mask
:
torch
.
tensor
=
None
,
attention_mask
:
torch
.
tensor
=
None
,
masked_lm_labels
:
torch
.
tensor
=
None
):
masked_lm_labels
:
torch
.
tensor
=
None
):
tfmr
_output
=
self
.
encod
er
(
input_ids
=
input_ids
,
dlbrt
_output
=
self
.
dilb
er
t
(
input_ids
=
input_ids
,
attention_mask
=
attention_mask
)
attention_mask
=
attention_mask
)
hidden_states
=
tfmr
_output
[
0
]
# (bs, seq_length, dim)
hidden_states
=
dlbrt
_output
[
0
]
# (bs, seq_length, dim)
prediction_logits
=
self
.
vocab_transform
(
hidden_states
)
# (bs, seq_length, dim)
prediction_logits
=
self
.
vocab_transform
(
hidden_states
)
# (bs, seq_length, dim)
prediction_logits
=
gelu
(
prediction_logits
)
# (bs, seq_length, dim)
prediction_logits
=
gelu
(
prediction_logits
)
# (bs, seq_length, dim)
prediction_logits
=
self
.
vocab_layer_norm
(
prediction_logits
)
# (bs, seq_length, dim)
prediction_logits
=
self
.
vocab_layer_norm
(
prediction_logits
)
# (bs, seq_length, dim)
prediction_logits
=
self
.
vocab_projector
(
prediction_logits
)
# (bs, seq_length, vocab_size)
prediction_logits
=
self
.
vocab_projector
(
prediction_logits
)
# (bs, seq_length, vocab_size)
outputs
=
(
prediction_logits
,
)
+
tfmr
_output
[
2
:]
outputs
=
(
prediction_logits
,
)
+
dlbrt
_output
[
2
:]
if
masked_lm_labels
is
not
None
:
if
masked_lm_labels
is
not
None
:
mlm_loss
=
self
.
mlm_loss_fct
(
prediction_logits
.
view
(
-
1
,
prediction_logits
.
size
(
-
1
)),
mlm_loss
=
self
.
mlm_loss_fct
(
prediction_logits
.
view
(
-
1
,
prediction_logits
.
size
(
-
1
)),
masked_lm_labels
.
view
(
-
1
))
masked_lm_labels
.
view
(
-
1
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment