ModelZoo / ResNet50_tensorflow / Commits / 8cb5ac1e

Commit 8cb5ac1e authored Mar 22, 2021 by Poorva Potdar, committed by A. Unique TensorFlower on Mar 22, 2021

Internal change

PiperOrigin-RevId: 364378436

parent 0e6f8848
Showing 2 changed files with 7 additions and 13 deletions (+7 / -13)
official/nlp/modeling/ops/decoding_module.py   +2 / -9
official/nlp/modeling/ops/sampling_module.py   +5 / -4
official/nlp/modeling/ops/decoding_module.py (view file @ 8cb5ac1e)

@@ -20,6 +20,7 @@ from typing import Any, Callable, Dict, Tuple

 import tensorflow as tf
 from tensorflow.python.framework import dtypes
+from official.modeling import tf_utils

 Output = Tuple[tf.Tensor, tf.Tensor]
 InternalState = Tuple[tf.Tensor, tf.Tensor, tf.Tensor, Dict]

@@ -64,15 +65,7 @@ def log_prob_from_logits(logits):

 def shape_list(tensor):
   """Return a list of the tensor's shape, and ensure no None values in list."""
-  # Get statically known shape (may contain None's for unknown dimensions)
-  shape = tensor.get_shape().as_list()
-
-  # Ensure that the shape values are not None
-  dynamic_shape = tf.shape(tensor)
-  for i in range(len(shape)):  # pylint: disable=consider-using-enumerate
-    if shape[i] is None:
-      shape[i] = dynamic_shape[i]
-  return shape
+  return tf_utils.get_shape_list(tensor)


 def get_shape_keep_last_dim(tensor):
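For readers following the refactor: the removed helper and the tf_utils.get_shape_list it now delegates to follow the same pattern of taking the statically known shape and substituting the runtime dimension wherever a size is unknown. A minimal standalone sketch of that pattern (the function name below is illustrative, not the actual tf_utils implementation):

import tensorflow as tf

def get_shape_list_sketch(tensor):
  """Return the tensor's shape as a list with no None entries."""
  # Statically known shape; unknown dimensions come back as None.
  static_shape = tensor.get_shape().as_list()
  # Dynamic shape, resolved at runtime inside the graph.
  dynamic_shape = tf.shape(tensor)
  # Fall back to the dynamic value only where the static one is unknown.
  return [dynamic_shape[i] if dim is None else dim
          for i, dim in enumerate(static_shape)]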
official/nlp/modeling/ops/sampling_module.py (view file @ 8cb5ac1e)

@@ -76,14 +76,15 @@ def sample_top_p(logits, top_p):
   """
   sorted_indices = tf.argsort(logits, direction="DESCENDING")
   # Flatten logits as tf.gather on TPU needs axis to be compile time constant.
-  range_for_gather = tf.expand_dims(tf.range(0, logits.shape[0]), axis=1)
-  range_for_gather = tf.tile(range_for_gather * logits.shape[1],
-                             [1, logits.shape[1]]) + sorted_indices
+  logits_shape = decoding_module.shape_list(logits)
+  range_for_gather = tf.expand_dims(tf.range(0, logits_shape[0]), axis=1)
+  range_for_gather = tf.tile(range_for_gather * logits_shape[1],
+                             [1, logits_shape[1]]) + sorted_indices
   flattened_logits = tf.reshape(logits, [-1])
   flattened_sorted_indices = tf.reshape(range_for_gather, [-1])
   sorted_logits = tf.reshape(
       tf.gather(flattened_logits, flattened_sorted_indices),
-      [logits.shape[0], logits.shape[1]])
+      [logits_shape[0], logits_shape[1]])
   cumulative_probs = tf.cumsum(tf.nn.softmax(sorted_logits, axis=-1), axis=-1)
   # Remove tokens with cumulative probability above the threshold.
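The practical effect of switching from logits.shape to decoding_module.shape_list is, presumably, that sample_top_p keeps working when the batch dimension is only known at runtime: under a tf.function traced with an unspecified batch size, logits.shape[0] is None, while shape_list substitutes a runtime tensor for that dimension. A small sketch of the situation (the function name is illustrative):

import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec([None, 4], tf.float32)])
def row_offsets(logits):
  # Traced with a dynamic batch dimension, logits.shape[0] is None, so an
  # expression such as tf.range(0, logits.shape[0]) would fail during tracing.
  batch_size = tf.shape(logits)[0]   # runtime value, valid for any batch size
  vocab_size = logits.shape[1]       # statically known, safe to use directly
  return tf.expand_dims(tf.range(0, batch_size), axis=1) * vocab_size

print(row_offsets(tf.zeros([3, 4])))  # row offsets 0, 4, 8 into the flat logits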