Unverified commit 514f37c3, authored by Byron Hsu, committed by GitHub
Browse files

[kernel] Fix position ids in rope (#3173)

parent 52c03f16
...@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" ...@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "sgl-kernel" name = "sgl-kernel"
version = "0.0.2.post19" version = "0.0.2.post20"
description = "Kernel Library for SGLang" description = "Kernel Library for SGLang"
readme = "README.md" readme = "README.md"
requires-python = ">=3.9" requires-python = ">=3.9"
......
...@@ -51,7 +51,7 @@ def apply_rope_with_cos_sin_cache_inplace( ...@@ -51,7 +51,7 @@ def apply_rope_with_cos_sin_cache_inplace(
raise ValueError("cos_sin_cache should be float32") raise ValueError("cos_sin_cache should be float32")
with query.device as device: with query.device as device:
pos_ids = pos_ids.int() positions = positions.int()
torch.ops.sgl_kernels.apply_rope_pos_ids_cos_sin_cache( torch.ops.sgl_kernels.apply_rope_pos_ids_cos_sin_cache(
q=query.view(query.shape[0], -1, head_size), q=query.view(query.shape[0], -1, head_size),
k=key.view(key.shape[0], -1, head_size), k=key.view(key.shape[0], -1, head_size),
......
...@@ -196,3 +196,7 @@ def test_correctness( ...@@ -196,3 +196,7 @@ def test_correctness(
query_ref_out, query_flashinfer_out, atol=1e-2, rtol=1e-2 query_ref_out, query_flashinfer_out, atol=1e-2, rtol=1e-2
) )
torch.testing.assert_close(key_ref_out, key_flashinfer_out, atol=1e-2, rtol=1e-2) torch.testing.assert_close(key_ref_out, key_flashinfer_out, atol=1e-2, rtol=1e-2)
if __name__ == "__main__":
pytest.main([__file__])
__version__ = "0.0.2.post19" __version__ = "0.0.2.post20"
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment