Unverified commit c77092a5, authored by Suraj Patil and committed by GitHub
Browse files

[FlaxGPTJ] Fix bug in rotary embeddings (#16298)

parent 4b277483
......@@ -122,7 +122,7 @@ def create_sinusoidal_positions(num_pos, dim):
def rotate_every_two(tensor):
    """Swap each adjacent pair along the last axis, negating the odd-indexed member.

    For every consecutive pair ``(x[2i], x[2i + 1])`` on the last axis the
    result holds ``(-x[2i + 1], x[2i])`` at the same two positions. This is
    the pairwise rotation used by the rotary embeddings this commit fixes.

    Args:
        tensor: Array indexed here with four axes; the last axis is expected
            to have even length so both strided slices have matching shapes.
            (The meaning of the leading axes is not visible here; presumably
            batch / sequence / heads, to be confirmed against the caller.)

    Returns:
        An array with the same shape as ``tensor``.
    """
    # The odd-indexed half must be negated; pairing it un-negated was the
    # bug this change fixes.
    negated_odd = -tensor[:, :, :, 1::2]
    even = tensor[:, :, :, ::2]
    # Stack on a new trailing axis so each (-odd, even) pair is adjacent...
    rotate_half_tensor = jnp.stack((negated_odd, even), axis=-1)
    # ...then merge the last two axes to interleave the pairs back into one axis.
    rotate_half_tensor = rotate_half_tensor.reshape(rotate_half_tensor.shape[:-2] + (-1,))
    return rotate_half_tensor
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment