ModelZoo / ResNet50_tensorflow

Commit da8a5778, authored Apr 08, 2021 by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 367463455

Parent: 9e3550e5
Showing 5 changed files with 560 additions and 2 deletions (+560 -2).
official/nlp/configs/encoders.py                        +3    -1
official/nlp/projects/bigbird/encoder.py                +34   -1
official/nlp/projects/bigbird/recompute_grad.py         +240  -0
official/nlp/projects/bigbird/recomputing_dropout.py    +159  -0
official/nlp/projects/bigbird/stateless_dropout.py      +124  -0
official/nlp/configs/encoders.py

@@ -138,6 +138,7 @@ class BigBirdEncoderConfig(hyperparams.Config):
   type_vocab_size: int = 16
   initializer_range: float = 0.02
   embedding_width: Optional[int] = None
+  use_gradient_checkpointing: bool = False


 @dataclasses.dataclass
@@ -296,7 +297,8 @@ def build_encoder(config: EncoderConfig,
         type_vocab_size=encoder_cfg.type_vocab_size,
         initializer=tf.keras.initializers.TruncatedNormal(
             stddev=encoder_cfg.initializer_range),
-        embedding_width=encoder_cfg.embedding_width)
+        embedding_width=encoder_cfg.embedding_width,
+        use_gradient_checkpointing=encoder_cfg.use_gradient_checkpointing)
   if encoder_type == "xlnet":
     return encoder_cls(
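Usage sketch (not part of this commit): the new use_gradient_checkpointing field is an ordinary BigBirdEncoderConfig option, so it can be switched on wherever the encoder config is assembled. The wiring below is an assumption for illustration, including the 'bigbird' one-of key; only the field name and build_encoder come from the diff above.

# Hypothetical example: enable gradient checkpointing through the config.
from official.nlp.configs import encoders

encoder_config = encoders.EncoderConfig(
    type='bigbird',  # assumed one-of key for BigBirdEncoderConfig
    bigbird=encoders.BigBirdEncoderConfig(use_gradient_checkpointing=True))
encoder_network = encoders.build_encoder(encoder_config)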
official/nlp/projects/bigbird/encoder.py

@@ -21,6 +21,30 @@ from official.modeling import activations
 from official.nlp import keras_nlp
 from official.nlp.modeling import layers
 from official.nlp.projects.bigbird import attention
+from official.nlp.projects.bigbird import recompute_grad
+from official.nlp.projects.bigbird import recomputing_dropout
+
+
+class RecomputeTransformerLayer(layers.TransformerScaffold):
+  """Transformer layer that recomputes the forward pass during backpropagation."""
+
+  def call(self, inputs):
+    emb, mask = inputs
+
+    def f(*args):
+      # recompute_grad can only handle tensor inputs. so we enumerate the
+      # nested input [emb, mask] as follows:
+      # args[0]: emb
+      # args[1]: mask[0] = band_mask
+      # args[2]: mask[1] = encoder_from_mask
+      # args[3]: mask[2] = encoder_to_mask
+      # args[4]: mask[3] = blocked_encoder_mask
+      x = super(RecomputeTransformerLayer,
+                self).call([args[0], [args[1], args[2], args[3], args[4]]])
+      return x
+
+    f = recompute_grad.recompute_grad(f)
+    return f(emb, *mask)
+
+
 @tf.keras.utils.register_keras_serializable(package='Text')
@@ -52,6 +76,8 @@ class BigBirdEncoder(tf.keras.Model):
       matrices in the shape of ['vocab_size', 'embedding_width'] and
       ['embedding_width', 'hidden_size'] ('embedding_width' is usually much
       smaller than 'hidden_size').
+    use_gradient_checkpointing: Use gradient checkpointing to trade-off compute
+      for memory.
   """

   def __init__(self,
@@ -69,10 +95,17 @@ class BigBirdEncoder(tf.keras.Model):
                attention_dropout_rate=0.1,
                initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
                embedding_width=None,
+               use_gradient_checkpointing=False,
                **kwargs):
     activation = tf.keras.activations.get(activation)
     initializer = tf.keras.initializers.get(initializer)

+    if use_gradient_checkpointing:
+      tf.keras.layers.Dropout = recomputing_dropout.RecomputingDropout
+      layer_cls = RecomputeTransformerLayer
+    else:
+      layer_cls = layers.TransformerScaffold
+
     self._self_setattr_tracking = False
     self._config_dict = {
         'vocab_size': vocab_size,
@@ -148,7 +181,7 @@ class BigBirdEncoder(tf.keras.Model):
     encoder_outputs = []
     attn_head_dim = hidden_size // num_attention_heads
     for i in range(num_layers):
-      layer = layers.TransformerScaffold(
+      layer = layer_cls(
           num_attention_heads,
           intermediate_size,
           activation,
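Usage sketch (not part of this commit): constructing the encoder directly with the new constructor argument. The hyperparameter values below are illustrative assumptions; only the class name and the use_gradient_checkpointing argument come from the diff above.

# Hypothetical example: build a BigBirdEncoder whose transformer layers
# recompute their forward pass during backprop.
from official.nlp.projects.bigbird import encoder

network = encoder.BigBirdEncoder(
    vocab_size=32000,                 # assumed value
    num_layers=12,                    # assumed value
    use_gradient_checkpointing=True)  # swaps TransformerScaffold for RecomputeTransformerLayer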
official/nlp/projects/bigbird/recompute_grad.py  (new file, 0 → 100644)

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Library for rematerialization.

Incubates a version of tf.recompute_grad that is XLA compatible.
"""
import collections
import os
import threading
from typing import Deque, List, NamedTuple, Optional, Sequence

from absl import logging
import numpy as np
import tensorflow as tf


class RecomputeContext(
    NamedTuple('RecomputeContext', [
        ('is_recomputing', bool),
        ('seed', tf.Tensor),
        ('children', Deque['RecomputeContext']),
    ])):
  """Context for recomputation.

  Attributes:
    is_recomputing: Whether we are in a recomputation phase.
    seed: Scalar integer tensor that should be used with stateless random ops
      for deterministic behavior and correct computation of the gradient.
    children: Nested `RecomputeContext` instances. Used internally by
      `recompute_grad` to track nested instances of `RecomputeContext`.
  """

  def __enter__(self):
    return _context_stack.push(self)

  def __exit__(self, exc_type, exc_value, traceback):
    _context_stack.pop(self)


# Simplified version of `_DefaultStack` in
# https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/framework/ops.py.
class _ContextStack(threading.local):
  """A thread-local stack for providing implicit recompute contexts."""

  def __init__(self):
    super(_ContextStack, self).__init__()
    self._stack = []

  def top(self) -> Optional[RecomputeContext]:
    return self._stack[-1] if self._stack else None

  def push(self, context: RecomputeContext):
    self._stack.append(context)
    return context

  def pop(self, context: RecomputeContext):
    if self._stack[-1] is not context:
      raise AssertionError('Nesting violated for RecomputeContext.')
    self._stack.pop()


_context_stack = _ContextStack()


def get_recompute_context() -> Optional[RecomputeContext]:
  """Returns the current recomputing context if it exists."""
  return _context_stack.top()


# Adapted from
# https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/control_flow_util.py.
def _get_containing_xla_context(graph: tf.Graph) -> Optional[object]:
  """Returns the first ancestor `XLAControlFlowContext` in the `graph`."""
  ctxt = graph._get_control_flow_context()  # pylint: disable=protected-access
  while ctxt:
    if ctxt.IsXLAContext():
      return ctxt
    ctxt = ctxt.outer_context
  return None


def _in_xla_context(graph: Optional[tf.Graph] = None) -> bool:
  """Detects whether we are in an XLA context."""
  if '--tf_xla_auto_jit=2' in os.environ.get('TF_XLA_FLAGS', ''):
    return True
  graph = tf.compat.v1.get_default_graph() if graph is None else graph
  while True:
    if _get_containing_xla_context(graph) is not None:
      return True
    try:
      graph = graph.outer_graph
    except AttributeError:
      return False


def _force_data_dependency(
    first_compute: Sequence[tf.Tensor],
    then_compute: Sequence[tf.Tensor]) -> List[tf.Tensor]:
  """Force all of `then_compute` to depend on all of `first_compute`.

  Uses a dummy data dependency, which is useful when running on TPUs because
  XLA ignores control dependencies. Only supports float arguments.

  Args:
    first_compute: Sequence of `Tensor`s to be executed before `then_compute`.
    then_compute: Sequence of `Tensor`s to be executed after `first_compute`.

  Returns:
    Sequence of `Tensor`s with the same length as `then_compute`.

  Raises:
    ValueError: if ranks are unknown or types are not floating.
  """

  def _first_element(x):
    if x.shape.ndims is None:
      raise ValueError('Rank of Tensor %s must be known' % x)
    ndims = x.shape.ndims
    begin = tf.zeros(ndims, dtype=tf.int32)
    size = tf.ones(ndims, dtype=tf.int32)
    return tf.reshape(tf.slice(x, begin, size), [])

  first_compute_sum = tf.add_n(
      [_first_element(x) for x in first_compute if x is not None])
  dtype = first_compute_sum.dtype
  if not dtype.is_floating:
    raise ValueError('_force_data_dependency only supports floating dtypes.')
  zero = np.finfo(dtype.as_numpy_dtype).tiny * first_compute_sum
  return [
      x + tf.cast(zero, x.dtype) if x is not None else None
      for x in then_compute
  ]


def _make_seed_if_none(seed: Optional[tf.Tensor]) -> tf.Tensor:
  """Uses the global generator to make a seed if necessary."""
  if seed is not None:
    return seed
  generator = tf.random.experimental.get_global_generator()
  # The two seeds for stateless random ops don't have individual semantics and
  # are scrambled together, so providing one seed is fine. This makes it easier
  # for users to provide a local seed without worrying about integer overflow.
  # See `make_seeds` in
  # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/stateful_random_ops.py.
  try:
    return generator.uniform_full_int([], tf.int32, name='recompute_grad_seed')
  except (RuntimeError, TypeError, ValueError, tf.errors.NotFoundError) as e:
    # For a number of reasons, the above operation can fail, such as when using
    # multiple graphs or toggling between eager and graph modes. Reset the
    # generator.
    logging.warn('Resetting the generator. %s: %s', type(e), e)
    tf.random.experimental.set_global_generator(None)
    generator = tf.random.experimental.get_global_generator()
    return generator.uniform_full_int([], tf.int32, name='recompute_grad_seed')


def recompute_grad(f, seed=None):
  """An eager-compatible version of recompute_grad.

  For f(*args, **kwargs), this supports gradients with respect to args, or to
  gradients with respect to any variables residing in the kwarg 'variables'.
  Note that for keras layer and model objects, this is handled automatically.

  Warning: If `f` was originally a tf.keras Model or Layer object, `g` will not
  be able to access the member variables of that object, because `g` returns
  through the wrapper function `inner`. When recomputing gradients through
  objects that inherit from keras, we suggest keeping a reference to the
  underlying object around for the purpose of accessing these variables.

  Args:
    f: function `f(*x)` that returns a `Tensor` or sequence of `Tensor` outputs.
    seed: Optional seed for random ops. `seed` should be an integer scalar
      `Tensor`. When compiling to XLA, `seed` must have dtype `tf.int32`. If
      `seed` is not provided one will be generated.

  Returns:
    A function `g` that wraps `f`, but which recomputes `f` on the backwards
    pass of a gradient call.
  """

  @tf.custom_gradient
  def inner(*args, **kwargs):
    """Inner function closure for calculating gradients."""
    # Detect when we're nested and in the backwards pass, so we don't generate
    # an additional seed.
    parent_context = get_recompute_context()
    if parent_context is not None and parent_context.is_recomputing:
      # Use the cached context in the recomputation phase.
      with parent_context.children.popleft()._replace(
          is_recomputing=True) as context:
        result = f(*args, **kwargs)
    else:
      with RecomputeContext(
          is_recomputing=False,
          seed=_make_seed_if_none(seed),
          children=collections.deque()) as context:
        result = f(*args, **kwargs)
      # In the forward pass, build up a tree of recomputation contexts.
      if parent_context is not None and not parent_context.is_recomputing:
        parent_context.children.append(context)

    def grad(*dresult, **grad_kwargs):
      """Gradient function calculation for inner function."""
      variables = grad_kwargs.pop('variables', None)
      if grad_kwargs:
        raise ValueError('Found unexpected kwargs for `grad`: ',
                         list(grad_kwargs.keys()))
      inputs, seed = list(args), context.seed
      if _in_xla_context():
        inputs = _force_data_dependency(
            tf.nest.flatten(dresult), inputs + [seed])
        seed = inputs.pop()
      with tf.GradientTape() as tape:
        tape.watch(inputs)
        if variables is not None:
          tape.watch(variables)
        with tf.control_dependencies(dresult):
          with context._replace(is_recomputing=True, seed=seed):
            result = f(*inputs, **kwargs)
      kw_vars = []
      if variables is not None:
        kw_vars = list(variables)
      grads = tape.gradient(
          result, list(inputs) + kw_vars, output_gradients=dresult)
      return grads[:len(inputs)], grads[len(inputs):]

    return result, grad

  return inner
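Usage sketch (not part of this commit): recompute_grad wraps a function so that its intermediate activations are not retained; the wrapped function is re-run when gradients are requested. The toy function below is an assumption for illustration.

# Hypothetical example: gradients flow through a rematerialized block.
import tensorflow as tf
from official.nlp.projects.bigbird import recompute_grad

def block(x):
  return tf.nn.relu(tf.matmul(x, tf.transpose(x)))

rematerialized_block = recompute_grad.recompute_grad(block)

x = tf.random.normal([4, 8])
with tf.GradientTape() as tape:
  tape.watch(x)
  y = tf.reduce_sum(rematerialized_block(x))
dy_dx = tape.gradient(y, x)  # `block` is recomputed inside the custom gradient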
official/nlp/projects/bigbird/recomputing_dropout.py  (new file, 0 → 100644)

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Keras dropout layer that is aware of `RecomputeContext`."""
import numpy as np
import tensorflow as tf

from official.nlp.projects.bigbird import recompute_grad as recompute_grad_lib
from official.nlp.projects.bigbird import stateless_dropout as stateless_dropout_lib


# Reimplements internal function
# https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/framework/smart_cond.py.
def smart_cond(pred, true_fn=None, false_fn=None, name=None):
  """Return either `true_fn()` if predicate `pred` is true else `false_fn()`.

  If `pred` is a bool or has a constant value, we return either `true_fn()`
  or `false_fn()`, otherwise we use `tf.cond` to dynamically route to both.

  Arguments:
    pred: A scalar determining whether to return the result of `true_fn` or
      `false_fn`.
    true_fn: The callable to be performed if pred is true.
    false_fn: The callable to be performed if pred is false.
    name: Optional name prefix when using `tf.cond`.

  Returns:
    Tensors returned by the call to either `true_fn` or `false_fn`.

  Raises:
    TypeError: If `true_fn` or `false_fn` is not callable.
  """
  if not callable(true_fn):
    raise TypeError('`true_fn` must be callable.')
  if not callable(false_fn):
    raise TypeError('`false_fn` must be callable.')
  pred_value = tf.get_static_value(pred)
  if isinstance(pred, tf.Variable) or pred_value is None:
    return tf.cond(pred, true_fn=true_fn, false_fn=false_fn, name=name)
  if pred_value:
    return true_fn()
  else:
    return false_fn()


# See https://www.tensorflow.org/api_docs/python/tf/keras/layers/Dropout.
class RecomputingDropout(tf.keras.layers.Layer):
  """`tf.keras.layers.Dropout` that supports `recompute_grad`."""

  def __init__(self,
               rate,
               noise_shape=None,
               seed=None,
               force_recomputation=False,
               **kwargs):
    """Initializes `RecomputingDropout`.

    Args:
      rate: Float between 0 and 1. Fraction of the input units to drop.
      noise_shape: 1D integer tensor representing the shape of the binary
        dropout mask that will be multiplied with the input. For instance, if
        inputs have shape `(batch_size, timesteps, features)` and you want the
        dropout mask to be the same for all timesteps, you can use
        `noise_shape=(batch_size, 1, features)`.
      seed: A Python integer to use as random seed.
      force_recomputation: If `True`, then raises an error if called outside a
        recompute context.
      **kwargs: Keyword arguments for `tf.keras.layers.Layer`.
    """
    super(RecomputingDropout, self).__init__(**kwargs)
    self.rate = rate
    self.noise_shape = noise_shape
    self.seed = seed
    self.force_recomputation = force_recomputation
    self.supports_masking = True
    # Create a layer-specific seed to combine with the global recompute seed.
    self._recompute_seed = (
        np.random.randint(-2**31, 2**31, dtype=np.int32)
        if seed is None else seed)

  def _get_noise_shape(self, inputs):
    # Subclasses of `Dropout` may implement `_get_noise_shape(self, inputs)`,
    # which will override `self.noise_shape`, and allows for custom noise
    # shapes with dynamically sized inputs.
    if self.noise_shape is None:
      return None

    concrete_inputs_shape = tf.shape(inputs)
    noise_shape = []
    for i, value in enumerate(self.noise_shape):
      noise_shape.append(concrete_inputs_shape[i] if value is None else value)
    return tf.convert_to_tensor(noise_shape)

  def call(self, inputs, training=None):
    """Builds computation graph.

    Args:
      inputs: Input tensor (of any rank).
      training: Python boolean indicating whether the layer should behave in
        training mode (adding dropout) or in inference mode (doing nothing).

    Returns:
      `inputs` masked according to layer configuration.

    Raises:
      ValueError: If `force_recomputation` is `True` and called outside a
        recompute context.
    """
    if training is None:
      training = tf.keras.backend.learning_phase()

    def dropped_inputs():
      """Randomly drops elements of `inputs` when `training=True`."""
      recompute_context = recompute_grad_lib.get_recompute_context()
      if recompute_context is None:
        if self.force_recomputation:
          raise ValueError(
              'RecomputeContext is required when force_recomputation=True.')
        return tf.nn.dropout(
            inputs,
            noise_shape=self._get_noise_shape(inputs),
            seed=self.seed,
            rate=self.rate)
      seed = tf.stack([recompute_context.seed, self._recompute_seed])
      return stateless_dropout_lib.stateless_dropout(
          inputs,
          rate=self.rate,
          seed=seed,
          noise_shape=self._get_noise_shape(inputs))

    output = smart_cond(training, dropped_inputs, lambda: tf.identity(inputs))
    return output

  def compute_output_shape(self, input_shape):
    return input_shape

  def get_config(self):
    config = {
        'rate': self.rate,
        'noise_shape': self.noise_shape,
        'seed': self.seed,
        'force_recomputation': self.force_recomputation,
    }
    base_config = super(RecomputingDropout, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
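Usage sketch (not part of this commit): when called under recompute_grad, RecomputingDropout derives its mask from the active RecomputeContext seed combined with a per-layer seed, so the original forward pass and the recomputed one drop the same elements. The toy wiring below is an assumption for illustration.

# Hypothetical example: deterministic dropout inside a rematerialized block.
import tensorflow as tf
from official.nlp.projects.bigbird import recompute_grad
from official.nlp.projects.bigbird import recomputing_dropout

dropout = recomputing_dropout.RecomputingDropout(rate=0.1)

def block(x):
  # Inside recompute_grad, a RecomputeContext is active, so the layer falls
  # back to stateless_dropout keyed on the context seed.
  return dropout(x, training=True)

block = recompute_grad.recompute_grad(block)
y = block(tf.random.normal([2, 8]))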
official/nlp/projects/bigbird/stateless_dropout.py  (new file, 0 → 100644)

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A replacement for tf.nn.dropout that uses stateless random ops."""
import numbers
from typing import Optional, Sequence, Text, Union

from absl import logging
import tensorflow as tf


def _as_shape(shape: Union[Sequence[int], tf.TensorShape]) -> tf.TensorShape:
  """Converts the given object to a TensorShape."""
  return shape if isinstance(shape, tf.TensorShape) else tf.TensorShape(shape)


def _get_noise_shape(
    x: tf.Tensor, noise_shape: Union[Sequence[int], tf.TensorShape]
) -> Union[tf.Tensor, tf.TensorShape, Sequence[int]]:
  """Computes the shape of the binary mask for dropout."""
  # If noise_shape is none return immediately.
  if noise_shape is None:
    return tf.shape(x)

  try:
    # Best effort to figure out the intended shape.
    # If not possible, let the op handle it.
    # In eager mode exception will show up.
    noise_shape_ = _as_shape(noise_shape)
  except (TypeError, ValueError):
    return noise_shape

  if x.shape.dims is not None and len(x.shape.dims) == len(noise_shape_.dims):
    new_dims = []
    for i, dim in enumerate(x.shape.dims):
      if noise_shape_.dims[i].value is None and dim.value is not None:
        new_dims.append(dim.value)
      else:
        new_dims.append(noise_shape_.dims[i].value)
    return tf.TensorShape(new_dims)

  return noise_shape


def stateless_dropout(x: tf.Tensor,
                      rate: float,
                      seed: tf.Tensor,
                      noise_shape: Optional[Union[Sequence[int],
                                                  tf.TensorShape]] = None,
                      name: Optional[Text] = None) -> tf.Tensor:
  """Computes dropout: randomly sets elements to zero to prevent overfitting.

  See https://www.tensorflow.org/api_docs/python/tf/nn/dropout.
  This version differs in that the seed is required if the rate is nonzero.

  Args:
    x: A floating point tensor.
    rate: A scalar `Tensor` with the same type as x. The probability that each
      element is dropped. For example, setting rate=0.1 would drop 10% of input
      elements.
    seed: A shape [2] integer Tensor of seeds to the random number generator.
      Must have dtype `tf.int32` when compiling to XLA.
    noise_shape: A 1-D `Tensor` of type `int32`, representing the shape for
      randomly generated keep/drop flags.
    name: A name for this operation (optional).

  Returns:
    A `Tensor` of the same shape as `x`.

  Raises:
    ValueError: If `rate` is not in `[0, 1)` or if `x` is not a floating point
      tensor. `rate=1` is disallowed, because the output would be all zeros,
      which is likely not what was intended.
  """
  with tf.name_scope(name or 'stateless_dropout') as name:
    x = tf.convert_to_tensor(x, name='x')
    if not x.dtype.is_floating:
      raise ValueError('x has to be a floating point tensor since it\'s going '
                       'to be scaled. Got a %s tensor instead.' % x.dtype)
    if isinstance(rate, numbers.Real):
      if not (rate >= 0 and rate < 1):
        raise ValueError('rate must be a scalar tensor or a float in the '
                         'range [0, 1), got %g' % rate)
      if rate > 0.5:
        logging.log_first_n(
            logging.WARN, 'Large dropout rate: %g (>0.5). In TensorFlow '
            '2.x, dropout() uses dropout rate instead of keep_prob. '
            'Please ensure that this is intended.', 5, rate)

    # Early return if nothing needs to be dropped.
    if tf.get_static_value(rate) == 0:
      return x

    rate = tf.convert_to_tensor(rate, dtype=x.dtype, name='rate')
    rate.shape.assert_has_rank(0)
    noise_shape = _get_noise_shape(x, noise_shape)
    # Sample a uniform distribution on [0.0, 1.0) and select values larger than
    # rate.
    #
    # NOTE: Random uniform actually can only generate 2^23 floats on [1.0, 2.0)
    # and subtract 1.0.
    random_tensor = tf.random.stateless_uniform(
        noise_shape, seed=seed, dtype=x.dtype)
    keep_prob = 1 - rate
    scale = 1 / keep_prob
    # NOTE: if (1.0 + rate) - 1 is equal to rate, then we want to consider that
    # float to be selected, hence we use a >= comparison.
    keep_mask = random_tensor >= rate
    ret = x * scale * tf.cast(keep_mask, x.dtype)
    if not tf.executing_eagerly():
      ret.set_shape(x.get_shape())
    return ret
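Usage sketch (not part of this commit): because the dropout mask is a pure function of the [2]-shaped seed, repeated calls with the same seed drop the same elements, which is what makes recomputation consistent with the original forward pass. The values below are illustrative.

# Hypothetical example: stateless_dropout is deterministic for a fixed seed.
import tensorflow as tf
from official.nlp.projects.bigbird import stateless_dropout

x = tf.ones([2, 4])
seed = tf.constant([7, 11], dtype=tf.int32)
a = stateless_dropout.stateless_dropout(x, rate=0.5, seed=seed)
b = stateless_dropout.stateless_dropout(x, rate=0.5, seed=seed)
# a and b are identical; a different seed would give a different mask.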