Unverified commit 96e7ee72, authored Nov 27, 2019 by Thomas Wolf, committed by GitHub on Nov 27, 2019
Merge pull request #1740 from huggingface/fix-ctrl-past
Fix CTRL past
Parents: 3c28a2da, 8da47b07
Changes: 1 changed file with 3 additions and 2 deletions

transformers/modeling_ctrl.py (+3, -2)
@@ -63,7 +63,8 @@ def scaled_dot_product_attention(q, k, v, mask, attention_mask=None, head_mask=None):
     scaled_attention_logits = matmul_qk / np.sqrt(dk)
 
     if mask is not None:
-        scaled_attention_logits += (mask * -1e4)
+        nd, ns = scaled_attention_logits.size(-2), scaled_attention_logits.size(-1)
+        scaled_attention_logits += (mask[ns-nd:ns, :ns] * -1e4)
 
     if attention_mask is not None:
         # Apply the attention mask
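Why the mask is sliced: with cached key/value states ("past"), the query covers only the new tokens (nd) while the keys span past plus new tokens (ns), so a square causal mask no longer matches the attention logits. Taking the last nd rows and first ns columns keeps exactly the entries that hide future positions from the new queries. A minimal standalone sketch of the slicing (illustration only, not part of the commit):

import torch

past_length, new_tokens = 3, 2
ns = past_length + new_tokens        # key length: cached + new tokens
nd = new_tokens                      # query length: new tokens only

mask = torch.triu(torch.ones(ns, ns), 1)   # full causal mask, as built in CTRLModel
sliced = mask[ns - nd:ns, :ns]             # bottom nd rows, matching the logits' shape
print(sliced)
# tensor([[0., 0., 0., 0., 1.],
#         [0., 0., 0., 0., 0.]])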
@@ -373,7 +374,7 @@ class CTRLModel(CTRLPreTrainedModel):
             inputs_embeds = self.w(input_ids)
         # inputs_embeds = embedded.unsqueeze(0) if len(input_ids.shape)<2 else embedded
         seq_len = input_shape[-1]
-        mask = torch.triu(torch.ones(seq_len, seq_len), 1).to(inputs_embeds.device)
+        mask = torch.triu(torch.ones(seq_len + past_length, seq_len + past_length), 1).to(inputs_embeds.device)
 
         inputs_embeds *= np.sqrt(self.d_model_size)
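The companion change: CTRLModel.forward now builds the causal mask over seq_len + past_length positions, where past_length counts the cached tokens, so the slice above has the columns it needs. A hedged usage sketch of the incremental-decoding path this merge fixes, assuming the transformers 2.x tuple-return API and its `past` keyword (checkpoint name "ctrl" as published by the library):

import torch
from transformers import CTRLLMHeadModel, CTRLTokenizer

tokenizer = CTRLTokenizer.from_pretrained("ctrl")
model = CTRLLMHeadModel.from_pretrained("ctrl")
model.eval()

input_ids = tokenizer.encode("Links Hello, my dog is", return_tensors="pt")
with torch.no_grad():
    outputs = model(input_ids)             # full prompt; outputs[1] caches key/values
    logits, past = outputs[0], outputs[1]
    next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
    # Incremental step: only the new token is fed; `past` supplies the history.
    # Before this fix the mask was seq_len x seq_len, breaking this code path.
    outputs = model(next_token, past=past)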