Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
b29fe6b7
Commit
b29fe6b7
authored
May 24, 2020
by
Hongkun Yu
Committed by
A. Unique TensorFlower
May 24, 2020
Browse files
[Clean up]: remove is_v2() check inside transformer.
PiperOrigin-RevId: 312988874
parent
09d3c74a
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
7 additions
and
38 deletions
+7
-38
official/nlp/transformer/beam_search.py
official/nlp/transformer/beam_search.py
+5
-11
official/nlp/transformer/data_pipeline.py
official/nlp/transformer/data_pipeline.py
+2
-7
official/nlp/transformer/misc.py
official/nlp/transformer/misc.py
+0
-20
No files found.
official/nlp/transformer/beam_search.py
View file @
b29fe6b7
...
@@ -17,7 +17,6 @@
...
@@ -17,7 +17,6 @@
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.nlp.transformer
import
beam_search_v1
as
v1
from
official.nlp.transformer
import
beam_search_v1
as
v1
from
official.nlp.transformer
import
misc
_StateKeys
=
v1
.
_StateKeys
# pylint: disable=protected-access
_StateKeys
=
v1
.
_StateKeys
# pylint: disable=protected-access
...
@@ -52,8 +51,8 @@ class SequenceBeamSearchV2(v1.SequenceBeamSearch):
...
@@ -52,8 +51,8 @@ class SequenceBeamSearchV2(v1.SequenceBeamSearch):
# Account for corner case where there are no finished sequences for a
# Account for corner case where there are no finished sequences for a
# particular batch item. In that case, return alive sequences for that batch
# particular batch item. In that case, return alive sequences for that batch
# item.
# item.
finished_seq
=
tf
.
compat
.
v2
.
where
(
seq_cond
,
finished_seq
,
alive_seq
)
finished_seq
=
tf
.
where
(
seq_cond
,
finished_seq
,
alive_seq
)
finished_scores
=
tf
.
compat
.
v2
.
where
(
finished_scores
=
tf
.
where
(
score_cond
,
finished_scores
,
alive_log_probs
)
score_cond
,
finished_scores
,
alive_log_probs
)
return
finished_seq
,
finished_scores
return
finished_seq
,
finished_scores
...
@@ -102,14 +101,9 @@ def sequence_beam_search(symbols_to_logits_fn,
...
@@ -102,14 +101,9 @@ def sequence_beam_search(symbols_to_logits_fn,
batch_size
=
(
batch_size
=
(
initial_ids
.
shape
.
as_list
()[
0
]
if
padded_decode
else
initial_ids
.
shape
.
as_list
()[
0
]
if
padded_decode
else
tf
.
shape
(
initial_ids
)[
0
])
tf
.
shape
(
initial_ids
)[
0
])
if
misc
.
is_v2
():
sbs
=
SequenceBeamSearchV2
(
symbols_to_logits_fn
,
vocab_size
,
batch_size
,
sbs
=
SequenceBeamSearchV2
(
symbols_to_logits_fn
,
vocab_size
,
batch_size
,
beam_size
,
alpha
,
max_decode_length
,
eos_id
,
beam_size
,
alpha
,
max_decode_length
,
eos_id
,
padded_decode
,
dtype
)
padded_decode
,
dtype
)
else
:
sbs
=
v1
.
SequenceBeamSearch
(
symbols_to_logits_fn
,
vocab_size
,
batch_size
,
beam_size
,
alpha
,
max_decode_length
,
eos_id
,
padded_decode
,
dtype
)
return
sbs
.
search
(
initial_ids
,
initial_cache
)
return
sbs
.
search
(
initial_ids
,
initial_cache
)
...
...
official/nlp/transformer/data_pipeline.py
View file @
b29fe6b7
...
@@ -56,7 +56,6 @@ import os
...
@@ -56,7 +56,6 @@ import os
from
absl
import
logging
from
absl
import
logging
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.nlp.transformer
import
misc
from
official.utils.misc
import
model_helpers
from
official.utils.misc
import
model_helpers
# Buffer size for reading records from a TFRecord file. Each training file is
# Buffer size for reading records from a TFRecord file. Each training file is
...
@@ -313,9 +312,5 @@ def eval_input_fn(params, ctx=None):
...
@@ -313,9 +312,5 @@ def eval_input_fn(params, ctx=None):
def
map_data_for_transformer_fn
(
x
,
y
):
def
map_data_for_transformer_fn
(
x
,
y
):
"""Maps data for training, and handles weried behaviors for different vers."""
"""Maps data for training, and handles weried behaviors for different vers."""
# Will transform input x and targets y into tuple(x, y) as new model inputs.
# Will transform input x and targets y into tuple(x, y) as new model inputs.
if
misc
.
is_v2
():
# For TF v2, the 2nd parameter is omitted to make Keras training work.
# For TF v2, the 2nd parameter is omitted to make Keras training work.
return
((
x
,
y
),)
return
((
x
,
y
),)
else
:
# For TF v1, Keras requires a dummy placeholder as the 2nd parameter.
return
((
x
,
y
),
tf
.
constant
(
0.0
))
official/nlp/transformer/misc.py
View file @
b29fe6b7
...
@@ -22,10 +22,6 @@ from __future__ import print_function
...
@@ -22,10 +22,6 @@ from __future__ import print_function
from
absl
import
flags
from
absl
import
flags
import
tensorflow
as
tf
import
tensorflow
as
tf
# TODO(tianlin) Import internal library. Remove this when some functions for
# different TF versions are fixed.
from
tensorflow.python
import
tf2
as
tf2_internal
from
official.nlp.transformer
import
model_params
from
official.nlp.transformer
import
model_params
from
official.utils.flags
import
core
as
flags_core
from
official.utils.flags
import
core
as
flags_core
from
official.utils.misc
import
keras_utils
from
official.utils.misc
import
keras_utils
...
@@ -39,11 +35,6 @@ PARAMS_MAP = {
...
@@ -39,11 +35,6 @@ PARAMS_MAP = {
}
}
def
is_v2
():
"""Returns whether it is v2."""
return
tf2_internal
.
enabled
()
def
get_model_params
(
param_set
,
num_gpus
):
def
get_model_params
(
param_set
,
num_gpus
):
"""Gets predefined model params."""
"""Gets predefined model params."""
if
num_gpus
>
1
:
if
num_gpus
>
1
:
...
@@ -78,17 +69,6 @@ def define_transformer_flags():
...
@@ -78,17 +69,6 @@ def define_transformer_flags():
fp16_implementation
=
True
fp16_implementation
=
True
)
)
# Additional performance flags
# TODO(b/76028325): Remove when generic layout optimizer is ready.
flags
.
DEFINE_boolean
(
name
=
'enable_grappler_layout_optimizer'
,
default
=
True
,
help
=
'Enable Grappler layout optimizer. Currently Grappler can '
'de-optimize fp16 graphs by forcing NCHW layout for all '
'convolutions and batch normalizations, and this flag allows to '
'disable it.'
)
flags_core
.
define_benchmark
()
flags_core
.
define_benchmark
()
flags_core
.
define_device
(
tpu
=
True
)
flags_core
.
define_device
(
tpu
=
True
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment