Commit 988d0a4b (unverified), authored by Byron Hsu, committed by GitHub
Browse files

[kernel] Use sgl_kernel rope (#3169)


Co-authored-by: zhyncs <me@zhyncs.com>
parent 81262c7b
......@@ -6,9 +6,15 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from vllm import _custom_ops as ops
from vllm.model_executor.custom_op import CustomOp
from sglang.srt.layers.custom_op_util import register_custom_op
from sglang.srt.utils import is_cuda_available
_is_cuda_available = is_cuda_available()
if _is_cuda_available:
from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
......@@ -75,7 +81,9 @@ class RotaryEmbedding(CustomOp):
self.dtype = dtype
cache = self._compute_cos_sin_cache()
cache = cache.to(dtype)
# NOTE(ByronHsu): cache needs to be in FP32 for numerical stability
if not _is_cuda_available:
cache = cache.to(dtype)
self.cos_sin_cache: torch.Tensor
self.register_buffer("cos_sin_cache", cache, persistent=False)
......@@ -141,17 +149,25 @@ class RotaryEmbedding(CustomOp):
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
from vllm import _custom_ops as ops
self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
ops.rotary_embedding(
positions,
query,
key,
self.head_size,
self.cos_sin_cache,
self.is_neox_style,
)
if _is_cuda_available:
apply_rope_with_cos_sin_cache_inplace(
positions=positions,
query=query,
key=key,
head_size=self.head_size,
cos_sin_cache=self.cos_sin_cache,
is_neox=self.is_neox_style,
)
else:
self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
ops.rotary_embedding(
positions,
query,
key,
self.head_size,
self.cos_sin_cache,
self.is_neox_style,
)
return query, key
def forward_xpu(
......
......@@ -54,6 +54,7 @@ class TestSessionControl(unittest.TestCase):
chunks_ids[i] = chunks_ids[i][1:]
# 1. using session control
requests.post(self.base_url + "/flush_cache")
session_id = requests.post(
self.base_url + "/open_session",
json={"capacity_of_str_len": 1000},
......@@ -215,7 +216,9 @@ class TestSessionControl(unittest.TestCase):
print(outputs_from_session)
print("outputs from normal queries:")
print(outputs_normal)
assert outputs_from_session == outputs_normal
assert (
outputs_from_session == outputs_normal
), f"outputs_from_session: {outputs_from_session}, outputs_normal: {outputs_normal}"
async def async_generate(self, payload):
url = self.base_url + "/generate"
......@@ -250,6 +253,7 @@ class TestSessionControl(unittest.TestCase):
chunks_ids[i] = chunks_ids[i][1:]
# 1. using session control
requests.post(self.base_url + "/flush_cache")
session_id = requests.post(
self.base_url + "/open_session",
json={"capacity_of_str_len": 1000},
......@@ -320,6 +324,7 @@ class TestSessionControl(unittest.TestCase):
assert response["meta_info"]["finish_reason"]["type"] == "abort"
else:
# 2. not using session control
requests.post(self.base_url + "/flush_cache")
output_ids = tokenizer.encode(gen_so_far)
if output_ids[0] == tokenizer.bos_token_id:
output_ids = output_ids[1:]
......@@ -342,7 +347,9 @@ class TestSessionControl(unittest.TestCase):
output_no_session = response["text"]
print("second request output without session:")
print(output_no_session)
assert second_output == output_no_session
assert (
second_output == output_no_session
), f"second_output: {second_output}, output_no_session: {output_no_session}"
def test_session_control_backtrack_with_abort(self):
asyncio.run(self.run_session_control_backtrack_with_abort(replace=True))
......@@ -355,6 +362,7 @@ class TestSessionControl(unittest.TestCase):
assert len(x) == len(chunks_per_step[0])
# 1. using session control
requests.post(self.base_url + "/flush_cache")
session_id = requests.post(
self.base_url + "/open_session",
json={"capacity_of_str_len": 1000},
......@@ -459,7 +467,9 @@ class TestSessionControl(unittest.TestCase):
print(outputs_from_session)
print("====== outputs from normal queries: =======")
print(outputs_normal)
assert outputs_from_session == outputs_normal
assert (
outputs_from_session == outputs_normal
), f"outputs_from_session: {outputs_from_session}, outputs_normal: {outputs_normal}"
def test_session_control_with_branching(self):
root_prompt = "First, let me explain in one sentence about AI"
......@@ -525,6 +535,7 @@ class TestSessionControlVision(unittest.TestCase):
gen_len = 32
# 1. using session control
requests.post(self.base_url + "/flush_cache")
session_id = requests.post(
self.base_url + "/open_session",
json={"capacity_of_str_len": 1000},
......@@ -691,7 +702,9 @@ class TestSessionControlVision(unittest.TestCase):
print(outputs_from_session)
print("outputs from normal queries:")
print(outputs_normal)
assert outputs_from_session == outputs_normal
assert (
outputs_from_session == outputs_normal
), f"outputs_from_session: {outputs_from_session}, outputs_normal: {outputs_normal}"
if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.