"src/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "5f3c5ffd5343c88ecd99c9ae64fbec009275f73b"
Unverified Commit 4cdb7ee5 authored by yujun's avatar yujun Committed by GitHub
Browse files

fix #11724 (#11897)

parent 83f02512
...@@ -151,11 +151,12 @@ class TFMarianSinusoidalPositionalEmbedding(tf.keras.layers.Layer): ...@@ -151,11 +151,12 @@ class TFMarianSinusoidalPositionalEmbedding(tf.keras.layers.Layer):
position_enc = np.array( position_enc = np.array(
[[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)] [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
) )
table = np.zeros_like(position_enc)
# index 0 is all zero # index 0 is all zero
position_enc[:, 0 : dim // 2] = np.sin(position_enc[:, 0::2]) table[:, 0 : dim // 2] = np.sin(position_enc[:, 0::2])
position_enc[:, dim // 2 :] = np.cos(position_enc[:, 1::2]) table[:, dim // 2 :] = np.cos(position_enc[:, 1::2])
# convert to tensor # convert to tensor
table = tf.convert_to_tensor(position_enc) table = tf.convert_to_tensor(table)
tf.stop_gradient(table) tf.stop_gradient(table)
return table return table
......
...@@ -152,11 +152,12 @@ class TFPegasusSinusoidalPositionalEmbedding(tf.keras.layers.Layer): ...@@ -152,11 +152,12 @@ class TFPegasusSinusoidalPositionalEmbedding(tf.keras.layers.Layer):
position_enc = np.array( position_enc = np.array(
[[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)] [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
) )
table = np.zeros_like(position_enc)
# index 0 is all zero # index 0 is all zero
position_enc[:, 0 : dim // 2] = np.sin(position_enc[:, 0::2]) table[:, 0 : dim // 2] = np.sin(position_enc[:, 0::2])
position_enc[:, dim // 2 :] = np.cos(position_enc[:, 1::2]) table[:, dim // 2 :] = np.cos(position_enc[:, 1::2])
# convert to tensor # convert to tensor
table = tf.convert_to_tensor(position_enc) table = tf.convert_to_tensor(table)
tf.stop_gradient(table) tf.stop_gradient(table)
return table return table
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment