33e72b08
Commit
33e72b08
authored
Dec 13, 2019
by
thomwolf
Browse files
fix inner dimensions for 3B/11B models
parent
f19dad61
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
19 additions
and
28 deletions
+19
-28
transformers/modeling_t5.py
transformers/modeling_t5.py
+11
-16
transformers/modeling_tf_t5.py
transformers/modeling_tf_t5.py
+8
-12
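Note on why the fix matters (illustration only, not part of the commit): in T5's attention, the q/k/v projections must produce num_heads * d_kv features ("inner_dim"), and the output projection must map inner_dim back to d_model. For t5-base and t5-large those two sizes happen to coincide, so the old Linear(d_model, d_model) layers worked by coincidence; for the 3B and 11B checkpoints d_model stays at 1024 while d_kv is 128, so inner_dim (4096 and 16384) no longer equals d_model. A minimal sketch of the arithmetic, using configuration values assumed here from the published T5 configs, not taken from this diff:

    # Hypothetical illustration: d_model vs. inner_dim across released T5 sizes.
    # name: (d_model, num_heads, d_kv) -- values assumed from the T5 configs.
    t5_sizes = {
        "t5-base":  (768,  12,  64),
        "t5-large": (1024, 16,  64),
        "t5-3b":    (1024, 32, 128),
        "t5-11b":   (1024, 128, 128),
    }

    for name, (d_model, num_heads, d_kv) in t5_sizes.items():
        inner_dim = num_heads * d_kv  # width the q/k/v projections must output
        print(f"{name}: d_model={d_model}, inner_dim={inner_dim}, equal={inner_dim == d_model}")

    # t5-base and t5-large print equal=True; t5-3b (4096) and t5-11b (16384) print
    # equal=False, which is why Linear(d_model, d_model) broke those checkpoints.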
transformers/modeling_t5.py  (view file @ 33e72b08)

@@ -30,7 +30,7 @@ from torch import nn
 import torch.nn.functional as F
 from torch.nn import CrossEntropyLoss, MSELoss
 
-from .modeling_utils import PreTrainedModel
+from .modeling_utils import PreTrainedModel, prune_linear_layer
 from .configuration_t5 import T5Config
 from .file_utils import add_start_docstrings, DUMMY_INPUTS, DUMMY_MASK

@@ -191,28 +191,26 @@ class T5Attention(nn.Module):
         self.output_attentions = config.output_attentions
         self.relative_attention_num_buckets = config.relative_attention_num_buckets
-        self.dim = config.d_model
+        self.d_model = config.d_model
         self.d_kv = config.d_kv
         self.n_heads = config.num_heads
         self.dropout = config.dropout_rate
-        assert self.dim % self.n_heads == 0
-        assert self.dim // self.n_heads == self.d_kv
+        self.inner_dim = self.n_heads * self.d_kv
         # Mesh TensorFlow initialization to avoid scaling before softmax
-        self.q = nn.Linear(self.dim, self.dim, bias=False)
-        self.k = nn.Linear(self.dim, self.dim, bias=False)
-        self.v = nn.Linear(self.dim, self.dim, bias=False)
-        self.o = nn.Linear(self.dim, self.dim, bias=False)
+        self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
+        self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
+        self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
+        self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)
         if self.has_relative_attention_bias:
             self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
         self.pruned_heads = set()
 
     def prune_heads(self, heads):
-        attention_head_size = self.dim // self.n_heads
         if len(heads) == 0:
             return
-        mask = torch.ones(self.n_heads, attention_head_size)
+        mask = torch.ones(self.n_heads, self.d_kv)
         heads = set(heads) - self.pruned_heads
         for head in heads:
             head -= sum(1 if h < head else 0 for h in self.pruned_heads)

@@ -226,7 +224,7 @@ class T5Attention(nn.Module):
         self.o = prune_linear_layer(self.o, index, dim=1)
         # Update hyper params
         self.n_heads = self.n_heads - len(heads)
-        self.dim = attention_head_size * self.n_heads
+        self.inner_dim = self.d_kv * self.n_heads
         self.pruned_heads = self.pruned_heads.union(heads)
 
     @staticmethod

@@ -303,17 +301,14 @@ class T5Attention(nn.Module):
             klen = qlen if cache is None else cache['slen'] + qlen
         else:
             klen = kv.size(1)
-        # assert dim == self.dim, 'Dimensions do not match: %s input vs %s configured' % (dim, self.dim)
-        n_heads = self.n_heads
-        dim_per_head = self.dim // n_heads
 
         def shape(x):
             """ projection """
-            return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2)
+            return x.view(bs, -1, self.n_heads, self.d_kv).transpose(1, 2)
 
         def unshape(x):
             """ compute context """
-            return x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head)
+            return x.transpose(1, 2).contiguous().view(bs, -1, self.inner_dim)
 
         q = shape(self.q(input))  # (bs, n_heads, qlen, dim_per_head)
         if kv is None:
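For the PyTorch side, the net effect of the change above is that the attention projections now go d_model -> inner_dim on the way in and inner_dim -> d_model on the way out, with each head owning a slice of width d_kv. A small self-contained sketch of that shape flow (assumptions: T5-3B-like sizes, and the stand-alone q_proj/o_proj layers below are illustrative stand-ins, not the real T5Attention module):

    import torch
    from torch import nn

    # Assumed T5-3B-like sizes; inner_dim != d_model here, the case this commit fixes.
    d_model, n_heads, d_kv = 1024, 32, 128
    inner_dim = n_heads * d_kv  # 4096

    q_proj = nn.Linear(d_model, inner_dim, bias=False)  # was nn.Linear(d_model, d_model) before the fix
    o_proj = nn.Linear(inner_dim, d_model, bias=False)  # was nn.Linear(d_model, d_model) before the fix

    bs, qlen = 2, 7
    hidden = torch.randn(bs, qlen, d_model)

    # shape(): (bs, qlen, inner_dim) -> (bs, n_heads, qlen, d_kv)
    q = q_proj(hidden).view(bs, -1, n_heads, d_kv).transpose(1, 2)
    print(q.shape)  # torch.Size([2, 32, 7, 128])

    # unshape(): (bs, n_heads, qlen, d_kv) -> (bs, qlen, inner_dim) -> (bs, qlen, d_model)
    context = q.transpose(1, 2).contiguous().view(bs, -1, inner_dim)
    print(o_proj(context).shape)  # torch.Size([2, 7, 1024])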
transformers/modeling_tf_t5.py  (view file @ 33e72b08)

@@ -108,17 +108,16 @@ class TFT5Attention(tf.keras.layers.Layer):
         self.output_attentions = config.output_attentions
         self.relative_attention_num_buckets = config.relative_attention_num_buckets
-        self.dim = config.d_model
+        self.d_model = config.d_model
         self.d_kv = config.d_kv
         self.n_heads = config.num_heads
-        assert self.dim % self.n_heads == 0
-        assert self.dim // self.n_heads == self.d_kv
+        self.inner_dim = self.n_heads * self.d_kv
         # Mesh TensorFlow initialization to avoid scaling before softmax
-        self.q = tf.keras.layers.Dense(self.dim, use_bias=False, name='q')
-        self.k = tf.keras.layers.Dense(self.dim, use_bias=False, name='k')
-        self.v = tf.keras.layers.Dense(self.dim, use_bias=False, name='v')
-        self.o = tf.keras.layers.Dense(self.dim, use_bias=False, name='o')
+        self.q = tf.keras.layers.Dense(self.inner_dim, use_bias=False, name='q')
+        self.k = tf.keras.layers.Dense(self.inner_dim, use_bias=False, name='k')
+        self.v = tf.keras.layers.Dense(self.inner_dim, use_bias=False, name='v')
+        self.o = tf.keras.layers.Dense(self.d_model, use_bias=False, name='o')
         self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
         if self.has_relative_attention_bias:

@@ -199,17 +198,14 @@ class TFT5Attention(tf.keras.layers.Layer):
             klen = qlen if cache is None else cache['slen'] + qlen
         else:
             klen = shape_list(kv)[1]
-        # assert dim == self.dim, 'Dimensions do not match: %s input vs %s configured' % (dim, self.dim)
-        n_heads = self.n_heads
-        dim_per_head = self.dim // n_heads
 
         def shape(x):
             """ projection """
-            return tf.transpose(tf.reshape(x, (bs, -1, self.n_heads, dim_per_head)), perm=(0, 2, 1, 3))
+            return tf.transpose(tf.reshape(x, (bs, -1, self.n_heads, self.d_kv)), perm=(0, 2, 1, 3))
 
         def unshape(x):
             """ compute context """
-            return tf.reshape(tf.transpose(x, perm=(0, 2, 1, 3)), (bs, -1, self.n_heads * dim_per_head))
+            return tf.reshape(tf.transpose(x, perm=(0, 2, 1, 3)), (bs, -1, self.inner_dim))
 
         q = shape(self.q(input))  # (bs, n_heads, qlen, dim_per_head)
         if kv is None: