chenpangpang / transformers
"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "610cb106a216cfb99d840648b576f9502189e4d1"
Commit 0cd28352 authored Aug 27, 2019 by LysandreJik

Attempt to fix head index

Parent: c85b5db6
Showing 1 changed file with 3 additions and 0 deletions.
pytorch_transformers/modeling_gpt2.py (+3, -0)
@@ -233,12 +233,14 @@ class Attention(nn.Module):
         self.c_proj = Conv1D(n_state, nx)
         self.attn_dropout = nn.Dropout(config.attn_pdrop)
         self.resid_dropout = nn.Dropout(config.resid_pdrop)
+        self.pruned_heads = []
 
     def prune_heads(self, heads):
         if len(heads) == 0:
             return
         mask = torch.ones(self.n_head, self.split_size // self.n_head)
         for head in heads:
+            head -= len(list(filter(lambda h: h < head, self.pruned_heads)))
             mask[head] = 0
         mask = mask.view(-1).contiguous().eq(1)
         index = torch.arange(len(mask))[mask].long()
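The line added inside the loop remaps a head index from the model's original numbering to the layer's current numbering: every head pruned in an earlier call to prune_heads that sat at a lower index shifts the requested head one slot to the left. A minimal standalone sketch of that adjustment (the remap name and the example numbers are ours, not part of the commit):

    # Illustration of the index adjustment added in the diff above.
    # With heads 0 and 2 already pruned, original head 5 now sits at
    # position 3 among the remaining heads.
    def remap(head, pruned_heads):
        # Each previously pruned head with a smaller index shifts `head` left by one.
        return head - len(list(filter(lambda h: h < head, pruned_heads)))

    assert remap(5, [0, 2]) == 3
    assert remap(1, []) == 1   # nothing pruned yet: index unchanged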
@@ -249,6 +251,7 @@ class Attention(nn.Module):
         # Update hyper params
         self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads))
         self.n_head = self.n_head - len(heads)
+        self.pruned_heads.extend(heads)
 
     def _attn(self, q, k, v, head_mask=None):
         w = torch.matmul(q, k)
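For reference, the unchanged code around these additions turns the remapped head indices into a flat boolean mask over the projection columns and then into an index tensor used for the actual pruning. A small sketch of that mask/index construction with made-up toy sizes (n_head = 4, two dimensions per head; these values are ours, not from the commit):

    import torch

    n_head, split_size = 4, 8    # toy sizes: 4 heads, 2 dims per head
    heads_to_prune = [1]         # drop the second head

    mask = torch.ones(n_head, split_size // n_head)
    for head in heads_to_prune:
        mask[head] = 0           # zero the columns belonging to this head
    mask = mask.view(-1).contiguous().eq(1)
    index = torch.arange(len(mask))[mask].long()

    print(index)                 # tensor([0, 1, 4, 5, 6, 7]), i.e. the columns kept

Recording self.pruned_heads lets a later call adjust head indices that still refer to the original numbering before building this mask, which appears to be the "head index" issue the commit message refers to.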