"vscode:/vscode.git/clone" did not exist on "04fd783cc50bcc6744634e7300b3828b38a4dc79"
Unverified commit 8088ca41, authored by Matt and committed by GitHub
Browse files

Make TF ESM inv_freq non-trainable like PyTorch (#23940)

Make TF inv_freq non-trainable like PyTorch
parent 5929f86e
...@@ -110,7 +110,7 @@ class TFRotaryEmbedding(Layer): ...@@ -110,7 +110,7 @@ class TFRotaryEmbedding(Layer):
def build(self, input_shape): def build(self, input_shape):
super().build(input_shape) super().build(input_shape)
self.inv_freq = self.add_weight( self.inv_freq = self.add_weight(
"inv_freq", shape=(self.dim // 2,), dtype=tf.float32, initializer=get_initializer(1.0) "inv_freq", shape=(self.dim // 2,), dtype=tf.float32, initializer=get_initializer(1.0), trainable=False
) )
self.inv_freq.assign( self.inv_freq.assign(
1.0 / (10000 ** (tf.range(start=0, limit=self.dim, delta=2, dtype=tf.float32) / self.dim)) 1.0 / (10000 ** (tf.range(start=0, limit=self.dim, delta=2, dtype=tf.float32) / self.dim))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment