"tests/python/common/sampling/test_sampling.py" did not exist on "7e6a6b4ac01e4a18433e937124f3a4c94301a34c"
Unverified commit 988d0a4b, authored by Byron Hsu, committed by GitHub
Browse files

[kernel] Use sgl_kernel rope (#3169)


Co-authored-by: zhyncs <me@zhyncs.com>
parent 81262c7b
...@@ -6,9 +6,15 @@ from typing import Any, Dict, List, Optional, Tuple, Union ...@@ -6,9 +6,15 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm import _custom_ops as ops
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
from sglang.srt.layers.custom_op_util import register_custom_op from sglang.srt.layers.custom_op_util import register_custom_op
from sglang.srt.utils import is_cuda_available
_is_cuda_available = is_cuda_available()
if _is_cuda_available:
from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
def _rotate_neox(x: torch.Tensor) -> torch.Tensor: def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
...@@ -75,7 +81,9 @@ class RotaryEmbedding(CustomOp): ...@@ -75,7 +81,9 @@ class RotaryEmbedding(CustomOp):
self.dtype = dtype self.dtype = dtype
cache = self._compute_cos_sin_cache() cache = self._compute_cos_sin_cache()
cache = cache.to(dtype) # NOTE(ByronHsu): cache needs to be in FP32 for numerical stability
if not _is_cuda_available:
cache = cache.to(dtype)
self.cos_sin_cache: torch.Tensor self.cos_sin_cache: torch.Tensor
self.register_buffer("cos_sin_cache", cache, persistent=False) self.register_buffer("cos_sin_cache", cache, persistent=False)
...@@ -141,17 +149,25 @@ class RotaryEmbedding(CustomOp): ...@@ -141,17 +149,25 @@ class RotaryEmbedding(CustomOp):
key: torch.Tensor, key: torch.Tensor,
offsets: Optional[torch.Tensor] = None, offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
from vllm import _custom_ops as ops if _is_cuda_available:
apply_rope_with_cos_sin_cache_inplace(
self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype) positions=positions,
ops.rotary_embedding( query=query,
positions, key=key,
query, head_size=self.head_size,
key, cos_sin_cache=self.cos_sin_cache,
self.head_size, is_neox=self.is_neox_style,
self.cos_sin_cache, )
self.is_neox_style, else:
) self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
ops.rotary_embedding(
positions,
query,
key,
self.head_size,
self.cos_sin_cache,
self.is_neox_style,
)
return query, key return query, key
def forward_xpu( def forward_xpu(
......
...@@ -54,6 +54,7 @@ class TestSessionControl(unittest.TestCase): ...@@ -54,6 +54,7 @@ class TestSessionControl(unittest.TestCase):
chunks_ids[i] = chunks_ids[i][1:] chunks_ids[i] = chunks_ids[i][1:]
# 1. using session control # 1. using session control
requests.post(self.base_url + "/flush_cache")
session_id = requests.post( session_id = requests.post(
self.base_url + "/open_session", self.base_url + "/open_session",
json={"capacity_of_str_len": 1000}, json={"capacity_of_str_len": 1000},
...@@ -215,7 +216,9 @@ class TestSessionControl(unittest.TestCase): ...@@ -215,7 +216,9 @@ class TestSessionControl(unittest.TestCase):
print(outputs_from_session) print(outputs_from_session)
print("outputs from normal queries:") print("outputs from normal queries:")
print(outputs_normal) print(outputs_normal)
assert outputs_from_session == outputs_normal assert (
outputs_from_session == outputs_normal
), f"outputs_from_session: {outputs_from_session}, outputs_normal: {outputs_normal}"
async def async_generate(self, payload): async def async_generate(self, payload):
url = self.base_url + "/generate" url = self.base_url + "/generate"
...@@ -250,6 +253,7 @@ class TestSessionControl(unittest.TestCase): ...@@ -250,6 +253,7 @@ class TestSessionControl(unittest.TestCase):
chunks_ids[i] = chunks_ids[i][1:] chunks_ids[i] = chunks_ids[i][1:]
# 1. using session control # 1. using session control
requests.post(self.base_url + "/flush_cache")
session_id = requests.post( session_id = requests.post(
self.base_url + "/open_session", self.base_url + "/open_session",
json={"capacity_of_str_len": 1000}, json={"capacity_of_str_len": 1000},
...@@ -320,6 +324,7 @@ class TestSessionControl(unittest.TestCase): ...@@ -320,6 +324,7 @@ class TestSessionControl(unittest.TestCase):
assert response["meta_info"]["finish_reason"]["type"] == "abort" assert response["meta_info"]["finish_reason"]["type"] == "abort"
else: else:
# 2. not using session control # 2. not using session control
requests.post(self.base_url + "/flush_cache")
output_ids = tokenizer.encode(gen_so_far) output_ids = tokenizer.encode(gen_so_far)
if output_ids[0] == tokenizer.bos_token_id: if output_ids[0] == tokenizer.bos_token_id:
output_ids = output_ids[1:] output_ids = output_ids[1:]
...@@ -342,7 +347,9 @@ class TestSessionControl(unittest.TestCase): ...@@ -342,7 +347,9 @@ class TestSessionControl(unittest.TestCase):
output_no_session = response["text"] output_no_session = response["text"]
print("second request output without session:") print("second request output without session:")
print(output_no_session) print(output_no_session)
assert second_output == output_no_session assert (
second_output == output_no_session
), f"second_output: {second_output}, output_no_session: {output_no_session}"
def test_session_control_backtrack_with_abort(self): def test_session_control_backtrack_with_abort(self):
asyncio.run(self.run_session_control_backtrack_with_abort(replace=True)) asyncio.run(self.run_session_control_backtrack_with_abort(replace=True))
...@@ -355,6 +362,7 @@ class TestSessionControl(unittest.TestCase): ...@@ -355,6 +362,7 @@ class TestSessionControl(unittest.TestCase):
assert len(x) == len(chunks_per_step[0]) assert len(x) == len(chunks_per_step[0])
# 1. using session control # 1. using session control
requests.post(self.base_url + "/flush_cache")
session_id = requests.post( session_id = requests.post(
self.base_url + "/open_session", self.base_url + "/open_session",
json={"capacity_of_str_len": 1000}, json={"capacity_of_str_len": 1000},
...@@ -459,7 +467,9 @@ class TestSessionControl(unittest.TestCase): ...@@ -459,7 +467,9 @@ class TestSessionControl(unittest.TestCase):
print(outputs_from_session) print(outputs_from_session)
print("====== outputs from normal queries: =======") print("====== outputs from normal queries: =======")
print(outputs_normal) print(outputs_normal)
assert outputs_from_session == outputs_normal assert (
outputs_from_session == outputs_normal
), f"outputs_from_session: {outputs_from_session}, outputs_normal: {outputs_normal}"
def test_session_control_with_branching(self): def test_session_control_with_branching(self):
root_prompt = "First, let me explain in one sentence about AI" root_prompt = "First, let me explain in one sentence about AI"
...@@ -525,6 +535,7 @@ class TestSessionControlVision(unittest.TestCase): ...@@ -525,6 +535,7 @@ class TestSessionControlVision(unittest.TestCase):
gen_len = 32 gen_len = 32
# 1. using session control # 1. using session control
requests.post(self.base_url + "/flush_cache")
session_id = requests.post( session_id = requests.post(
self.base_url + "/open_session", self.base_url + "/open_session",
json={"capacity_of_str_len": 1000}, json={"capacity_of_str_len": 1000},
...@@ -691,7 +702,9 @@ class TestSessionControlVision(unittest.TestCase): ...@@ -691,7 +702,9 @@ class TestSessionControlVision(unittest.TestCase):
print(outputs_from_session) print(outputs_from_session)
print("outputs from normal queries:") print("outputs from normal queries:")
print(outputs_normal) print(outputs_normal)
assert outputs_from_session == outputs_normal assert (
outputs_from_session == outputs_normal
), f"outputs_from_session: {outputs_from_session}, outputs_normal: {outputs_normal}"
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment