Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
0534bae0
Commit
0534bae0
authored
Aug 08, 2022
by
A. Unique TensorFlower
Browse files
Internal change
PiperOrigin-RevId: 466203688
parent
8e8a0713
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
53 additions
and
15 deletions
+53
-15
official/nlp/modeling/ops/decoding_module.py
official/nlp/modeling/ops/decoding_module.py
+24
-6
official/nlp/modeling/ops/decoding_module_test.py
official/nlp/modeling/ops/decoding_module_test.py
+1
-0
official/nlp/modeling/ops/sampling_module.py
official/nlp/modeling/ops/sampling_module.py
+28
-9
No files found.
official/nlp/modeling/ops/decoding_module.py
View file @
0534bae0
...
...
@@ -22,7 +22,7 @@ import tensorflow as tf
from
tensorflow.python.framework
import
dtypes
from
official.modeling
import
tf_utils
Output
=
Tuple
[
tf
.
Tensor
,
tf
.
Tensor
]
Output
=
Tuple
[
tf
.
Tensor
,
tf
.
Tensor
,
Optional
[
tf
.
Tensor
]
]
InternalState
=
Tuple
[
tf
.
Tensor
,
tf
.
Tensor
,
tf
.
Tensor
,
Dict
]
InitialState
=
Tuple
[
Dict
[
str
,
Any
],
Dict
[
str
,
Any
]]
...
...
@@ -46,6 +46,10 @@ class StateKeys:
# the previous iteration.
ALIVE_CACHE
=
"ALIVE_CACHE"
# The initial model state/cache after the model processes the initial token.
# The cache will be filled if extra_cache_output is true.
INITIAL_OUTPUT_CACHE
=
"INITIAL_OUTPUT_CACHE"
# Top finished sequences for each batch item.
# Has shape [batch_size, beam_size, CUR_INDEX + 1]. Sequences that are
# shorter than CUR_INDEX + 1 are padded with 0s.
...
...
@@ -109,7 +113,8 @@ class DecodingModule(tf.Module, metaclass=abc.ABCMeta):
def
__init__
(
self
,
length_normalization_fn
:
Callable
[[
int
,
tf
.
DType
],
float
],
dtype
:
tf
.
DType
=
tf
.
float32
,
decoding_name
:
Optional
[
str
]
=
None
):
decoding_name
:
Optional
[
str
]
=
None
,
extra_cache_output
:
bool
=
False
):
"""Initialize the Decoding Module.
Args:
...
...
@@ -118,24 +123,26 @@ class DecodingModule(tf.Module, metaclass=abc.ABCMeta):
dtype: A tensorflow data type used for score computation. The default is
tf.float32.
decoding_name: an optional name for the decoding loop tensors.
extra_cache_output: If true, the cache produced after the initial token is
kept in the state under StateKeys.INITIAL_OUTPUT_CACHE.
"""
self
.
length_normalization_fn
=
length_normalization_fn
self
.
dtype
=
tf
.
as_dtype
(
dtype
)
self
.
decoding_name
=
decoding_name
def
generate
(
self
,
initial_ids
:
tf
.
Tensor
,
def
generate
(
self
,
initial_ids
:
tf
.
Tensor
,
initial_cache
:
Dict
[
str
,
tf
.
Tensor
])
->
Output
:
"""Implements the decoding strategy (beam_search or sampling).
Args:
initial_ids: initial ids to pass into the symbols_to_logits_fn.
int tensor
with shape [batch_size, 1]
initial_ids: initial ids to pass into the symbols_to_logits_fn.
int tensor
with shape [batch_size, 1]
initial_cache: dictionary for caching model outputs from previous step.
Returns:
Tuple of tensors representing
finished_sequence: shape [batch, max_seq_length]
finished_scores: [batch]
first_cache: The cache after the initial token (optional; see
extra_cache_output).
"""
batch_size
=
(
initial_ids
.
shape
.
as_list
()[
0
]
...
...
@@ -163,6 +170,17 @@ class DecodingModule(tf.Module, metaclass=abc.ABCMeta):
}
new_state
.
update
(
alive_state
)
new_state
.
update
(
finished_state
)
if
self
.
extra_cache_output
:
i
=
state
[
StateKeys
.
CUR_INDEX
]
old_cache
=
state
[
StateKeys
.
INITIAL_OUTPUT_CACHE
]
def
update_with_cache
(
new_state
,
cache
):
"""Updates new_state with cache."""
new_state
.
update
({
StateKeys
.
INITIAL_OUTPUT_CACHE
:
cache
})
tf
.
cond
(
tf
.
equal
(
i
,
0
),
lambda
:
update_with_cache
(
new_state
,
new_cache
),
lambda
:
update_with_cache
(
new_state
,
old_cache
))
return
[
new_state
]
finished_state
=
tf
.
nest
.
map_structure
(
...
...
official/nlp/modeling/ops/decoding_module_test.py
View file @
0534bae0
...
...
@@ -29,6 +29,7 @@ class TestSubclass(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
def
__init__
(
self
,
length_normalization_fn
=
length_normalization
,
extra_cache_output
=
True
,
dtype
=
tf
.
float32
):
super
(
TestSubclass
,
self
).
__init__
(
length_normalization_fn
=
length_normalization
,
dtype
=
dtype
)
...
...
official/nlp/modeling/ops/sampling_module.py
View file @
0534bae0
...
...
@@ -163,7 +163,8 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
sample_temperature
=
0.0
,
enable_greedy
:
bool
=
True
,
dtype
:
tf
.
DType
=
tf
.
float32
,
decoding_name
:
Optional
[
str
]
=
None
):
decoding_name
:
Optional
[
str
]
=
None
,
extra_cache_output
:
bool
=
False
):
"""Initialize sampling module."""
self
.
symbols_to_logits_fn
=
symbols_to_logits_fn
self
.
length_normalization_fn
=
length_normalization_fn
...
...
@@ -178,10 +179,12 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
sample_temperature
,
dtype
=
tf
.
float32
)
self
.
enable_greedy
=
enable_greedy
self
.
decoding_name
=
decoding_name
self
.
extra_cache_output
=
extra_cache_output
super
(
SamplingModule
,
self
).
__init__
(
length_normalization_fn
=
length_normalization_fn
,
dtype
=
dtype
,
decoding_name
=
decoding_name
)
decoding_name
=
decoding_name
,
extra_cache_output
=
extra_cache_output
)
def
_grow_alive_seq
(
self
,
state
:
Dict
[
str
,
Any
],
...
...
@@ -300,16 +303,14 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
decoding_module
.
StateKeys
.
CUR_INDEX
:
tf
.
TensorShape
([]),
decoding_module
.
StateKeys
.
ALIVE_SEQ
:
tf
.
TensorShape
(
[
batch_size
,
self
.
max_decode_length
+
1
]),
tf
.
TensorShape
([
batch_size
,
self
.
max_decode_length
+
1
]),
decoding_module
.
StateKeys
.
ALIVE_LOG_PROBS
:
tf
.
TensorShape
([
batch_size
,
1
]),
decoding_module
.
StateKeys
.
ALIVE_CACHE
:
tf
.
nest
.
map_structure
(
lambda
state
:
state
.
get_shape
(),
alive_cache
),
decoding_module
.
StateKeys
.
FINISHED_SEQ
:
tf
.
TensorShape
(
[
batch_size
,
self
.
max_decode_length
+
1
]),
tf
.
TensorShape
([
batch_size
,
self
.
max_decode_length
+
1
]),
decoding_module
.
StateKeys
.
FINISHED_SCORES
:
tf
.
TensorShape
([
batch_size
,
1
]),
decoding_module
.
StateKeys
.
FINISHED_FLAGS
:
...
...
@@ -324,9 +325,8 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
decoding_module
.
StateKeys
.
ALIVE_LOG_PROBS
:
tf
.
TensorShape
([
None
,
1
]),
decoding_module
.
StateKeys
.
ALIVE_CACHE
:
tf
.
nest
.
map_structure
(
decoding_module
.
get_shape_keep_last_dim
,
alive_cache
),
tf
.
nest
.
map_structure
(
decoding_module
.
get_shape_keep_last_dim
,
alive_cache
),
decoding_module
.
StateKeys
.
FINISHED_SEQ
:
tf
.
TensorShape
([
None
,
None
]),
decoding_module
.
StateKeys
.
FINISHED_SCORES
:
...
...
@@ -335,6 +335,22 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
tf
.
TensorShape
([
None
,
1
])
}
if
self
.
extra_cache_output
:
state
.
update
(
{
decoding_module
.
StateKeys
.
INITIAL_OUTPUT_CACHE
:
alive_cache
})
if
self
.
padded_decode
:
state_shape_invariants
.
update
({
decoding_module
.
StateKeys
.
INITIAL_OUTPUT_CACHE
:
tf
.
nest
.
map_structure
(
lambda
state
:
state
.
get_shape
(),
alive_cache
)
})
else
:
state_shape_invariants
.
update
({
decoding_module
.
StateKeys
.
INITIAL_OUTPUT_CACHE
:
tf
.
nest
.
map_structure
(
decoding_module
.
get_shape_keep_last_dim
,
alive_cache
),
})
return
state
,
state_shape_invariants
def
_get_new_alive_state
(
self
,
new_seq
:
tf
.
Tensor
,
new_log_probs
:
tf
.
Tensor
,
...
...
@@ -428,6 +444,9 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
finished_scores
)
finished_seq
=
tf
.
where
(
seq_cond
,
finished_seq
,
alive_seq
)
finished_scores
=
tf
.
where
(
score_cond
,
finished_scores
,
alive_log_probs
)
if
self
.
extra_cache_output
:
return
finished_seq
,
finished_scores
,
finished_state
[
decoding_module
.
StateKeys
.
INITIAL_OUTPUT_CACHE
]
return
finished_seq
,
finished_scores
def
_continue_search
(
self
,
state
)
->
tf
.
Tensor
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment