Unverified commit 8088ca41, authored by Matt and committed by GitHub
Browse files

Make TF ESM inv_freq non-trainable like PyTorch (#23940)

Make TF inv_freq non-trainable like PyTorch
parent 5929f86e
......@@ -110,7 +110,7 @@ class TFRotaryEmbedding(Layer):
def build(self, input_shape):
super().build(input_shape)
self.inv_freq = self.add_weight(
- "inv_freq", shape=(self.dim // 2,), dtype=tf.float32, initializer=get_initializer(1.0)
+ "inv_freq", shape=(self.dim // 2,), dtype=tf.float32, initializer=get_initializer(1.0), trainable=False
)
self.inv_freq.assign(
1.0 / (10000 ** (tf.range(start=0, limit=self.dim, delta=2, dtype=tf.float32) / self.dim))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment