"docs/source/git@developer.sourcefind.cn:hehl2/torchaudio.git" did not exist on "72ae755ab8951f08949ea53983418ecef02d85bb"
Commit 17cec5c0 authored by Dan Kondratyuk's avatar Dan Kondratyuk Committed by A. Unique TensorFlower
Browse files

Support arbitrary dimensions in stochastic depth.

PiperOrigin-RevId: 363188927
parent 8aae5ece
...@@ -232,7 +232,7 @@ class StochasticDepth(tf.keras.layers.Layer): ...@@ -232,7 +232,7 @@ class StochasticDepth(tf.keras.layers.Layer):
batch_size = tf.shape(inputs)[0] batch_size = tf.shape(inputs)[0]
random_tensor = keep_prob random_tensor = keep_prob
random_tensor += tf.random.uniform( random_tensor += tf.random.uniform(
[batch_size, 1, 1, 1], dtype=inputs.dtype) [batch_size] + [1] * (inputs.shape.rank - 1), dtype=inputs.dtype)
binary_tensor = tf.floor(random_tensor) binary_tensor = tf.floor(random_tensor)
output = tf.math.divide(inputs, keep_prob) * binary_tensor output = tf.math.divide(inputs, keep_prob) * binary_tensor
return output return output
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment