Unverified Commit 2101d93b authored by Lifu Huang's avatar Lifu Huang Committed by GitHub
Browse files

Fix CI TestChunkedSGMV (#10737)

parent 70e4b218
......@@ -621,6 +621,12 @@ class CachedKernel:
return complete_args
def _clear_cache(self):
"""
Clear the kernel cache for testing purposes.
"""
self.kernel_cache.clear()
def cached_triton_kernel(key_fn=None):
"""
......
......@@ -10,11 +10,18 @@ from sglang.srt.lora.triton_ops import (
chunked_sgmv_lora_expand_forward,
chunked_sgmv_lora_shrink_forward,
)
from sglang.srt.lora.triton_ops.chunked_sgmv_expand import _chunked_lora_expand_kernel
from sglang.srt.lora.triton_ops.chunked_sgmv_shrink import _chunked_lora_shrink_kernel
from sglang.srt.lora.utils import LoRABatchInfo
CHUNK_SIZE = 16
def reset_kernel_cache():
_chunked_lora_shrink_kernel._clear_cache()
_chunked_lora_expand_kernel._clear_cache()
def safe_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
"""Matrix multiplication with mixed precision handling for float16"""
result = torch.matmul(a.float(), b.float())
......@@ -436,6 +443,10 @@ class TestChunkedSGMV(unittest.TestCase):
List[str],
]:
"""Create test batch with specified composition and mode"""
# Reset kernel cache to avoid cross-test contamination
reset_kernel_cache()
seq_lengths = self.generate_sequence_lengths(
batch_size, batch_mode, 1, self.max_seq_len
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment