Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
33e72b08
Commit
33e72b08
authored
Dec 13, 2019
by
thomwolf
Browse files
fix inner dimensions for 3B/11B models
parent
f19dad61
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
19 additions
and
28 deletions
+19
-28
transformers/modeling_t5.py
transformers/modeling_t5.py
+11
-16
transformers/modeling_tf_t5.py
transformers/modeling_tf_t5.py
+8
-12
No files found.
transformers/modeling_t5.py
View file @
33e72b08
...
...
@@ -30,7 +30,7 @@ from torch import nn
import
torch.nn.functional
as
F
from
torch.nn
import
CrossEntropyLoss
,
MSELoss
from
.modeling_utils
import
PreTrainedModel
from
.modeling_utils
import
PreTrainedModel
,
prune_linear_layer
from
.configuration_t5
import
T5Config
from
.file_utils
import
add_start_docstrings
,
DUMMY_INPUTS
,
DUMMY_MASK
...
...
@@ -191,28 +191,26 @@ class T5Attention(nn.Module):
self
.
output_attentions
=
config
.
output_attentions
self
.
relative_attention_num_buckets
=
config
.
relative_attention_num_buckets
self
.
d
im
=
config
.
d_model
self
.
d
_model
=
config
.
d_model
self
.
d_kv
=
config
.
d_kv
self
.
n_heads
=
config
.
num_heads
self
.
dropout
=
config
.
dropout_rate
assert
self
.
dim
%
self
.
n_heads
==
0
assert
self
.
dim
//
self
.
n_heads
==
self
.
d_kv
self
.
inner_dim
=
self
.
n_heads
*
self
.
d_kv
# Mesh TensorFlow initialization to avoid scaling before softmax
self
.
q
=
nn
.
Linear
(
self
.
d
im
,
self
.
dim
,
bias
=
False
)
self
.
k
=
nn
.
Linear
(
self
.
d
im
,
self
.
dim
,
bias
=
False
)
self
.
v
=
nn
.
Linear
(
self
.
d
im
,
self
.
dim
,
bias
=
False
)
self
.
o
=
nn
.
Linear
(
self
.
dim
,
self
.
d
im
,
bias
=
False
)
self
.
q
=
nn
.
Linear
(
self
.
d
_model
,
self
.
inner_
dim
,
bias
=
False
)
self
.
k
=
nn
.
Linear
(
self
.
d
_model
,
self
.
inner_
dim
,
bias
=
False
)
self
.
v
=
nn
.
Linear
(
self
.
d
_model
,
self
.
inner_
dim
,
bias
=
False
)
self
.
o
=
nn
.
Linear
(
self
.
inner_
dim
,
self
.
d
_model
,
bias
=
False
)
if
self
.
has_relative_attention_bias
:
self
.
relative_attention_bias
=
nn
.
Embedding
(
self
.
relative_attention_num_buckets
,
self
.
n_heads
)
self
.
pruned_heads
=
set
()
def
prune_heads
(
self
,
heads
):
attention_head_size
=
self
.
dim
//
self
.
n_heads
if
len
(
heads
)
==
0
:
return
mask
=
torch
.
ones
(
self
.
n_heads
,
attention_head_size
)
mask
=
torch
.
ones
(
self
.
n_heads
,
self
.
d_kv
)
heads
=
set
(
heads
)
-
self
.
pruned_heads
for
head
in
heads
:
head
-=
sum
(
1
if
h
<
head
else
0
for
h
in
self
.
pruned_heads
)
...
...
@@ -226,7 +224,7 @@ class T5Attention(nn.Module):
self
.
o
=
prune_linear_layer
(
self
.
o
,
index
,
dim
=
1
)
# Update hyper params
self
.
n_heads
=
self
.
n_heads
-
len
(
heads
)
self
.
dim
=
attention_head_size
*
self
.
n_heads
self
.
inner_dim
=
self
.
d_kv
*
self
.
n_heads
self
.
pruned_heads
=
self
.
pruned_heads
.
union
(
heads
)
@
staticmethod
...
...
@@ -303,17 +301,14 @@ class T5Attention(nn.Module):
klen
=
qlen
if
cache
is
None
else
cache
[
'slen'
]
+
qlen
else
:
klen
=
kv
.
size
(
1
)
# assert dim == self.dim, 'Dimensions do not match: %s input vs %s configured' % (dim, self.dim)
n_heads
=
self
.
n_heads
dim_per_head
=
self
.
dim
//
n_heads
def
shape
(
x
):
""" projection """
return
x
.
view
(
bs
,
-
1
,
self
.
n_heads
,
dim_per_head
).
transpose
(
1
,
2
)
return
x
.
view
(
bs
,
-
1
,
self
.
n_heads
,
self
.
d_kv
).
transpose
(
1
,
2
)
def
unshape
(
x
):
""" compute context """
return
x
.
transpose
(
1
,
2
).
contiguous
().
view
(
bs
,
-
1
,
self
.
n_heads
*
dim_per_head
)
return
x
.
transpose
(
1
,
2
).
contiguous
().
view
(
bs
,
-
1
,
self
.
inner_dim
)
q
=
shape
(
self
.
q
(
input
))
# (bs, n_heads, qlen, dim_per_head)
if
kv
is
None
:
...
...
transformers/modeling_tf_t5.py
View file @
33e72b08
...
...
@@ -108,17 +108,16 @@ class TFT5Attention(tf.keras.layers.Layer):
self
.
output_attentions
=
config
.
output_attentions
self
.
relative_attention_num_buckets
=
config
.
relative_attention_num_buckets
self
.
d
im
=
config
.
d_model
self
.
d
_model
=
config
.
d_model
self
.
d_kv
=
config
.
d_kv
self
.
n_heads
=
config
.
num_heads
assert
self
.
dim
%
self
.
n_heads
==
0
assert
self
.
dim
//
self
.
n_heads
==
self
.
d_kv
self
.
inner_dim
=
self
.
n_heads
*
self
.
d_kv
# Mesh TensorFlow initialization to avoid scaling before softmax
self
.
q
=
tf
.
keras
.
layers
.
Dense
(
self
.
dim
,
use_bias
=
False
,
name
=
'q'
)
self
.
k
=
tf
.
keras
.
layers
.
Dense
(
self
.
dim
,
use_bias
=
False
,
name
=
'k'
)
self
.
v
=
tf
.
keras
.
layers
.
Dense
(
self
.
dim
,
use_bias
=
False
,
name
=
'v'
)
self
.
o
=
tf
.
keras
.
layers
.
Dense
(
self
.
d
im
,
use_bias
=
False
,
name
=
'o'
)
self
.
q
=
tf
.
keras
.
layers
.
Dense
(
self
.
inner_
dim
,
use_bias
=
False
,
name
=
'q'
)
self
.
k
=
tf
.
keras
.
layers
.
Dense
(
self
.
inner_
dim
,
use_bias
=
False
,
name
=
'k'
)
self
.
v
=
tf
.
keras
.
layers
.
Dense
(
self
.
inner_
dim
,
use_bias
=
False
,
name
=
'v'
)
self
.
o
=
tf
.
keras
.
layers
.
Dense
(
self
.
d
_model
,
use_bias
=
False
,
name
=
'o'
)
self
.
dropout
=
tf
.
keras
.
layers
.
Dropout
(
config
.
dropout_rate
)
if
self
.
has_relative_attention_bias
:
...
...
@@ -199,17 +198,14 @@ class TFT5Attention(tf.keras.layers.Layer):
klen
=
qlen
if
cache
is
None
else
cache
[
'slen'
]
+
qlen
else
:
klen
=
shape_list
(
kv
)[
1
]
# assert dim == self.dim, 'Dimensions do not match: %s input vs %s configured' % (dim, self.dim)
n_heads
=
self
.
n_heads
dim_per_head
=
self
.
dim
//
n_heads
def
shape
(
x
):
""" projection """
return
tf
.
transpose
(
tf
.
reshape
(
x
,
(
bs
,
-
1
,
self
.
n_heads
,
dim_per_head
)),
perm
=
(
0
,
2
,
1
,
3
))
return
tf
.
transpose
(
tf
.
reshape
(
x
,
(
bs
,
-
1
,
self
.
n_heads
,
self
.
d_kv
)),
perm
=
(
0
,
2
,
1
,
3
))
def
unshape
(
x
):
""" compute context """
return
tf
.
reshape
(
tf
.
transpose
(
x
,
perm
=
(
0
,
2
,
1
,
3
)),
(
bs
,
-
1
,
self
.
n_heads
*
dim_per_head
))
return
tf
.
reshape
(
tf
.
transpose
(
x
,
perm
=
(
0
,
2
,
1
,
3
)),
(
bs
,
-
1
,
self
.
inner_dim
))
q
=
shape
(
self
.
q
(
input
))
# (bs, n_heads, qlen, dim_per_head)
if
kv
is
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment