Unverified Commit 780fbf2f authored by Lifu Huang's avatar Lifu Huang Committed by GitHub
Browse files

[Fix] Fix accuracy bug in CSGMV kernel caching key. (#11579)

parent 825432fc
...@@ -9,7 +9,7 @@ from sglang.srt.utils import cached_triton_kernel ...@@ -9,7 +9,7 @@ from sglang.srt.utils import cached_triton_kernel
@cached_triton_kernel(lambda _, kwargs: (kwargs["NUM_SLICES"], kwargs["BLOCK_M"])) @cached_triton_kernel(lambda _, kwargs: (kwargs["NUM_SLICES"], kwargs["BLOCK_M"]))
@triton.jit @triton.jit(do_not_specialize=["num_segs"])
def _chunked_lora_expand_kernel( def _chunked_lora_expand_kernel(
# Pointers to matrices # Pointers to matrices
x, x,
......
...@@ -6,8 +6,10 @@ from sglang.srt.lora.utils import LoRABatchInfo ...@@ -6,8 +6,10 @@ from sglang.srt.lora.utils import LoRABatchInfo
from sglang.srt.utils import cached_triton_kernel from sglang.srt.utils import cached_triton_kernel
@cached_triton_kernel(lambda _, kwargs: (kwargs["NUM_SLICES"], kwargs["BLOCK_M"])) @cached_triton_kernel(
@triton.jit lambda _, kwargs: (kwargs["K"], kwargs["NUM_SLICES"], kwargs["BLOCK_M"])
)
@triton.jit(do_not_specialize=["num_segs"])
def _chunked_lora_shrink_kernel( def _chunked_lora_shrink_kernel(
# Pointers to matrices # Pointers to matrices
x, x,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment