"vscode:/vscode.git/clone" did not exist on "6674a5157f10f6f3a7ef41f2397ec90f8d20d0ef"
Unverified commit 76a2c86b, authored by Lianmin Zheng, committed by GitHub
Browse files

Fix flashinfer version in sgl-kernel (#10135)

parent e719bb0e
...@@ -1187,7 +1187,7 @@ class FlashInferMultiStepDraftBackend: ...@@ -1187,7 +1187,7 @@ class FlashInferMultiStepDraftBackend:
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int): def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
self.cuda_graph_kv_indices = torch.zeros( self.cuda_graph_kv_indices = torch.zeros(
(self.speculative_num_steps, max_bs * self.max_context_len), (self.speculative_num_steps, max_bs * self.topk * self.max_context_len),
dtype=torch.int32, dtype=torch.int32,
device="cuda", device="cuda",
) )
...@@ -1349,6 +1349,10 @@ def fast_decode_plan( ...@@ -1349,6 +1349,10 @@ def fast_decode_plan(
self.device, non_blocking=non_blocking self.device, non_blocking=non_blocking
) )
# TODO:
# We want to cache `empty_q_data`, `empty_kv_cache`, `last_page_len_host` (if it is ones) in the wrapper
# so that we do not need to create them every time.
# Create empty tensors for dtype info if needed # Create empty tensors for dtype info if needed
empty_q_data = torch.empty( empty_q_data = torch.empty(
0, 0,
......
...@@ -81,7 +81,7 @@ FetchContent_Populate(repo-triton) ...@@ -81,7 +81,7 @@ FetchContent_Populate(repo-triton)
FetchContent_Declare( FetchContent_Declare(
repo-flashinfer repo-flashinfer
GIT_REPOSITORY https://github.com/flashinfer-ai/flashinfer.git GIT_REPOSITORY https://github.com/flashinfer-ai/flashinfer.git
GIT_TAG 018b551825c8e5579206e6eb9d3229fa679202b3 GIT_TAG 1a85c439a064c1609568675aa580a409a53fb183
GIT_SHALLOW OFF GIT_SHALLOW OFF
) )
FetchContent_Populate(repo-flashinfer) FetchContent_Populate(repo-flashinfer)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment