Unverified Commit 2101d93b authored by Lifu Huang, committed by GitHub

Fix CI TestChunkedSGMV (#10737)

parent 70e4b218
@@ -621,6 +621,12 @@ class CachedKernel:
         return complete_args
 
+    def _clear_cache(self):
+        """
+        Clear the kernel cache for testing purposes.
+        """
+        self.kernel_cache.clear()
+
 
 def cached_triton_kernel(key_fn=None):
     """
...
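For context, the hunk above only shows the new `_clear_cache` hook. Below is a minimal, self-contained sketch of how a `cached_triton_kernel`-style wrapper could expose such a hook; it is an illustrative simplification, not the actual sglang `CachedKernel` implementation, and the `key_fn`-based dispatch is assumed from the decorator signature visible in the diff (names are suffixed with `_sketch` to make that clear).

# Minimal sketch of a kernel-caching wrapper with a test-only cache-clearing
# hook. This is NOT the sglang implementation; it only illustrates the pattern
# the hunk above adds (a kernel_cache dict plus a _clear_cache() method).
class _CachedKernelSketch:
    def __init__(self, fn, key_fn=None):
        self.fn = fn
        self.key_fn = key_fn
        self.kernel_cache = {}  # cache key -> previously resolved kernel

    def __call__(self, *args, **kwargs):
        # Derive a cache key from the call arguments (if a key_fn was given)
        # and reuse the cached entry on subsequent calls with the same key.
        key = self.key_fn(*args, **kwargs) if self.key_fn else ()
        kernel = self.kernel_cache.get(key)
        if kernel is None:
            kernel = self.fn
            self.kernel_cache[key] = kernel
        return kernel(*args, **kwargs)

    def _clear_cache(self):
        """Clear the kernel cache for testing purposes."""
        self.kernel_cache.clear()


def cached_triton_kernel_sketch(key_fn=None):
    # Decorator form mirroring cached_triton_kernel(key_fn=None) from the diff.
    def decorator(fn):
        return _CachedKernelSketch(fn, key_fn)
    return decorator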
@@ -10,11 +10,18 @@ from sglang.srt.lora.triton_ops import (
     chunked_sgmv_lora_expand_forward,
     chunked_sgmv_lora_shrink_forward,
 )
+from sglang.srt.lora.triton_ops.chunked_sgmv_expand import _chunked_lora_expand_kernel
+from sglang.srt.lora.triton_ops.chunked_sgmv_shrink import _chunked_lora_shrink_kernel
 from sglang.srt.lora.utils import LoRABatchInfo
 
 CHUNK_SIZE = 16
 
 
+def reset_kernel_cache():
+    _chunked_lora_shrink_kernel._clear_cache()
+    _chunked_lora_expand_kernel._clear_cache()
+
+
 def safe_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
     """Matrix multiplication with mixed precision handling for float16"""
     result = torch.matmul(a.float(), b.float())
@@ -436,6 +443,10 @@ class TestChunkedSGMV(unittest.TestCase):
         List[str],
     ]:
         """Create test batch with specified composition and mode"""
+        # Reset kernel cache to avoid cross-test contamination
+        reset_kernel_cache()
+
         seq_lengths = self.generate_sequence_lengths(
             batch_size, batch_mode, 1, self.max_seq_len
         )
...
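The test-side change invokes the new hook before each test batch is built. An equivalent, slightly more explicit placement is a per-test reset in setUp, sketched below; reset_kernel_cache() and the two kernel imports mirror the diff above, while the test class name and the setUp placement are illustrative only.

# Illustrative only: resetting the chunked SGMV kernel caches once per test via
# setUp instead of inside the batch-creation helper. The imports and
# reset_kernel_cache() come from the diff; the test class is hypothetical.
import unittest

from sglang.srt.lora.triton_ops.chunked_sgmv_expand import _chunked_lora_expand_kernel
from sglang.srt.lora.triton_ops.chunked_sgmv_shrink import _chunked_lora_shrink_kernel


def reset_kernel_cache():
    _chunked_lora_shrink_kernel._clear_cache()
    _chunked_lora_expand_kernel._clear_cache()


class ChunkedSGMVCacheIsolationExample(unittest.TestCase):
    def setUp(self):
        # Start every test from an empty kernel cache so that kernels cached by
        # a previous test (e.g. with a different batch composition) cannot leak
        # into this one.
        reset_kernel_cache()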