Unverified Commit 4e1c6a02 authored by Lucas Wilkinson's avatar Lucas Wilkinson Committed by GitHub
Browse files

[Bugfix] fix rotary embedding test for _get_padded_tensor_shape (#18229)


Signed-off-by: default avatarLucas Wilkinson <lwilkinson@neuralmagic.com>
parent c7852a6d
......@@ -152,6 +152,10 @@ def test_batched_rotary_embedding(
query = torch.randn(query_shape, dtype=dtype)
key = torch.randn_like(query) if use_key else None
# slice tensor if required, noop otherwise
query = query[..., :head_size]
key = key[..., :head_size] if use_key else None
# NOTE(woosuk): The reference implementation should be executed first
# because the custom kernel is in-place.
ref_query, ref_key = rope.forward_native(positions, query, key)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment