Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
7e4488ae
Commit
7e4488ae
authored
Jun 22, 2022
by
Jiayu Ye
Committed by
A. Unique TensorFlower
Jun 22, 2022
Browse files
Internal change
PiperOrigin-RevId: 456624049
parent
a6d78dd4
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
15 additions
and
7 deletions
+15
-7
official/nlp/modeling/models/t5.py
official/nlp/modeling/models/t5.py
+9
-5
official/nlp/modeling/models/t5_test.py
official/nlp/modeling/models/t5_test.py
+6
-2
No files found.
official/nlp/modeling/models/t5.py
View file @
7e4488ae
...
...
@@ -1250,7 +1250,10 @@ class Decoder(Module):
training: Whether it is training pass, affecting dropouts.
Returns:
output of a transformer encoder.
output of a transformer decoder including
1. logits: Logits for each word in the vocab.
2. raw_logits: Logits along the model dimension.
3. cache: Used for decoding in inference mode.
"""
cfg
=
self
.
config
# Casts inputs to the dtype.
...
...
@@ -1298,7 +1301,7 @@ class Decoder(Module):
logits
=
logits
/
math
.
sqrt
(
cfg
.
d_model
)
else
:
logits
=
self
.
logits_dense
(
output
)
return
logits
,
cache
return
dict
(
logits
=
logits
,
cache
=
cache
,
raw_logits
=
output
)
class
T5Transformer
(
Module
):
...
...
@@ -1392,7 +1395,7 @@ class T5Transformer(Module):
cache
=
None
,
max_decode_len
=
None
,
decode
=
False
,
training
=
False
):
training
=
False
)
->
Dict
[
str
,
tf
.
Tensor
]
:
eligible_inputs_array
=
[]
if
encoder_input_tokens
is
not
None
:
eligible_inputs
=
tf
.
cast
(
...
...
@@ -1449,7 +1452,7 @@ class T5Transformer(Module):
decoder_mask
=
(
1.0
-
tf
.
cast
(
decoder_mask
,
self
.
compute_dtype
))
*
-
1e9
encoder_decoder_mask
=
(
1.0
-
tf
.
cast
(
encoder_decoder_mask
,
self
.
compute_dtype
))
*
-
1e9
logits
,
cache
=
self
.
decoder
(
outputs
=
self
.
decoder
(
decoder_input_tokens
,
encoded
,
decode_position
=
decode_position
,
...
...
@@ -1459,7 +1462,8 @@ class T5Transformer(Module):
max_decode_len
=
max_decode_len
,
decode
=
decode
,
training
=
training
)
return
dict
(
logits
=
logits
,
encoded
=
encoded
,
cache
=
cache
)
outputs
[
"encoded"
]
=
encoded
return
outputs
@
tf
.
Module
.
with_name_scope
def
__call__
(
self
,
...
...
official/nlp/modeling/models/t5_test.py
View file @
7e4488ae
...
...
@@ -403,7 +403,9 @@ class T5Test(tf.test.TestCase, parameterized.TestCase):
batch_size
=
4
targets
=
tf
.
zeros
((
4
,
8
),
dtype
=
tf
.
int32
)
encoded
=
tf
.
zeros
((
4
,
8
,
config
.
d_model
),
dtype
=
tf
.
float32
)
logits
,
cache
=
decoder
(
targets
,
encoded
)
outputs
=
decoder
(
targets
,
encoded
)
logits
=
outputs
[
"logits"
]
cache
=
outputs
[
"cache"
]
self
.
assertEqual
(
logits
.
shape
,
(
4
,
8
,
config
.
vocab_size
))
cache
=
{}
...
...
@@ -412,13 +414,15 @@ class T5Test(tf.test.TestCase, parameterized.TestCase):
cache
[
1
]
=
_create_cache
(
batch_size
,
max_decode_len
,
config
.
num_heads
,
config
.
d_kv
)
targets
=
tf
.
zeros
((
4
,
1
),
dtype
=
tf
.
int32
)
logits
,
cache
=
decoder
(
outputs
=
decoder
(
targets
,
encoded
,
decode_position
=
2
,
cache
=
cache
,
decode
=
True
,
max_decode_len
=
max_decode_len
)
logits
=
outputs
[
"logits"
]
cache
=
outputs
[
"cache"
]
self
.
assertEqual
(
logits
.
shape
,
(
batch_size
,
1
,
config
.
vocab_size
))
for
entry
in
cache
.
values
():
for
tensor
in
entry
.
values
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment