chenpangpang / transformers · Commits · 0c8e823b
"...resnet50_tensorflow.git" did not exist on "092a5461e4c6d272ffdeb26b940bec45f3019427"
Commit 0c8e823b authored Aug 29, 2019 by LysandreJik

Added patch to remaining models

parent 0cd28352
Showing 3 changed files with 9 additions and 0 deletions (+9 -0)
pytorch_transformers/modeling_bert.py    +3 -0
pytorch_transformers/modeling_openai.py  +3 -0
pytorch_transformers/modeling_xlm.py     +3 -0
pytorch_transformers/modeling_bert.py
@@ -337,12 +337,14 @@ class BertAttention(nn.Module):
         super(BertAttention, self).__init__()
         self.self = BertSelfAttention(config)
         self.output = BertSelfOutput(config)
+        self.pruned_heads = []

     def prune_heads(self, heads):
         if len(heads) == 0:
             return
         mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size)
         for head in heads:
+            head -= len(list(filter(lambda h: h < head, self.pruned_heads)))
             mask[head] = 0
         mask = mask.view(-1).contiguous().eq(1)
         index = torch.arange(len(mask))[mask].long()
@@ -354,6 +356,7 @@ class BertAttention(nn.Module):
         # Update hyper params
         self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
         self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
+        self.pruned_heads.extend(heads)

     def forward(self, input_tensor, attention_mask, head_mask=None):
         self_outputs = self.self(input_tensor, attention_mask, head_mask)
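The only non-trivial addition is the index shift inside the loop: `heads` is given in terms of the original model, and every previously pruned head with a smaller index moves the target one position to the left. A minimal sketch of that remapping in plain Python, with no model involved (the `remap` helper and the numbers are illustrative, not part of the patch):

```python
def remap(head, already_pruned):
    # `head` is an index in the original model; every previously pruned
    # head with a smaller index shifts the remaining heads down by one.
    return head - len(list(filter(lambda h: h < head, already_pruned)))

# 12 heads originally; heads 2 and 5 were pruned in an earlier call.
already_pruned = [2, 5]

# Pruning original head 7 now: it currently sits at position 5.
assert remap(7, already_pruned) == 5
# Original head 1 is unaffected because no smaller index was pruned.
assert remap(1, already_pruned) == 1
```

The same shift is applied verbatim in the other two files of this commit.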
pytorch_transformers/modeling_openai.py
@@ -249,12 +249,14 @@ class Attention(nn.Module):
         self.c_proj = Conv1D(n_state, nx)
         self.attn_dropout = nn.Dropout(config.attn_pdrop)
         self.resid_dropout = nn.Dropout(config.resid_pdrop)
+        self.pruned_heads = []

     def prune_heads(self, heads):
         if len(heads) == 0:
             return
         mask = torch.ones(self.n_head, self.split_size // self.n_head)
         for head in heads:
+            head -= len(list(filter(lambda h: h < head, self.pruned_heads)))
             mask[head] = 0
         mask = mask.view(-1).contiguous().eq(1)
         index = torch.arange(len(mask))[mask].long()
@@ -265,6 +267,7 @@ class Attention(nn.Module):
         # Update hyper params
         self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads))
         self.n_head = self.n_head - len(heads)
+        self.pruned_heads.extend(heads)

     def _attn(self, q, k, v, head_mask=None):
         w = torch.matmul(q, k)
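For the GPT-style `Attention` module the bookkeeping differs slightly: the per-head width is `split_size // n_head`, so `split_size` has to be recomputed from the surviving head count. A short sketch of the mask/index construction and the hyper-parameter update; the variable names mirror the hunks above, but the GPT-small shapes (12 heads, 768 dims) are only an assumed example:

```python
import torch

# Assumed GPT-small shapes: 12 heads, split_size 768, i.e. 64 dims per head.
n_head, split_size = 12, 768
heads = [0, 3]

head_size = split_size // n_head                 # 64
mask = torch.ones(n_head, head_size)
for head in heads:
    mask[head] = 0                               # zero the rows of pruned heads
mask = mask.view(-1).contiguous().eq(1)          # flat boolean keep-mask over 768 dims
index = torch.arange(len(mask))[mask].long()     # surviving column indices

# Hyper-parameter update mirrors the second hunk above.
split_size = head_size * (n_head - len(heads))   # 640
n_head = n_head - len(heads)                     # 10
assert index.numel() == split_size               # 640 surviving dims
```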
pytorch_transformers/modeling_xlm.py
@@ -271,6 +271,7 @@ class MultiHeadAttention(nn.Module):
         self.k_lin = nn.Linear(dim, dim)
         self.v_lin = nn.Linear(dim, dim)
         self.out_lin = nn.Linear(dim, dim)
+        self.pruned_heads = []

     def prune_heads(self, heads):
         attention_head_size = self.dim // self.n_heads
@@ -278,6 +279,7 @@ class MultiHeadAttention(nn.Module):
             return
         mask = torch.ones(self.n_heads, attention_head_size)
         for head in heads:
+            head -= len(list(filter(lambda h: h < head, self.pruned_heads)))
             mask[head] = 0
         mask = mask.view(-1).contiguous().eq(1)
         index = torch.arange(len(mask))[mask].long()
@@ -289,6 +291,7 @@ class MultiHeadAttention(nn.Module):
         # Update hyper params
         self.n_heads = self.n_heads - len(heads)
         self.dim = attention_head_size * self.n_heads
+        self.pruned_heads.extend(heads)

     def forward(self, input, mask, kv=None, cache=None, head_mask=None):
         """
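With `pruned_heads` tracked in every attention module, repeated pruning composes: head indices can always be given with respect to the original model. A hedged usage sketch, assuming the model-level `prune_heads({layer_index: [head_indices]})` helper in pytorch_transformers and a downloadable `bert-base-uncased` checkpoint:

```python
import torch
from pytorch_transformers import BertModel

model = BertModel.from_pretrained('bert-base-uncased')

# First call: drop heads 2 and 5 of layer 0.
model.prune_heads({0: [2, 5]})
# Second call: drop what was originally head 7 of layer 0. The new
# pruned_heads bookkeeping shifts index 7 to its current position (5)
# before the mask is built, so the intended head is removed.
model.prune_heads({0: [7]})

input_ids = torch.tensor([[101, 7592, 2088, 102]])  # toy token ids
sequence_output = model(input_ids)[0]
print(sequence_output.shape)  # hidden size (768) is unchanged by head pruning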