"tests/vscode:/vscode.git/clone" did not exist on "b377e1b9c137f3cb6086d755c40e138163755c76"
Unverified Commit 0b2aa8a7 authored by Zhang Junda's avatar Zhang Junda Committed by GitHub
Browse files

Introduce CPU tensor as metadata to avoid blocking GPU kernel launch (#10720)


Co-authored-by: default avatarhnyls2002 <lsyincs@gmail.com>
parent 609f65ba
......@@ -76,6 +76,7 @@ class ScheduleBatchDisaggregationDecodeMixin:
req_pool_indices, dtype=torch.int64, device=self.device
)
self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64, device=self.device)
self.seq_lens_cpu = torch.tensor(seq_lens, dtype=torch.int64)
self.orig_seq_lens = torch.tensor(
seq_lens, dtype=torch.int32, device=self.device
)
......
......@@ -900,6 +900,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
token_type_ids: torch.Tensor = None # shape: [b], int64
req_pool_indices: torch.Tensor = None # shape: [b], int64
seq_lens: torch.Tensor = None # shape: [b], int64
seq_lens_cpu: torch.Tensor = None # shape: [b], int64
# The output locations of the KV cache
out_cache_loc: torch.Tensor = None # shape: [b], int64
output_ids: torch.Tensor = None # shape: [b], int64
......@@ -1055,7 +1056,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
def alloc_paged_token_slots_extend(
self,
prefix_lens: torch.Tensor,
prefix_lens_cpu: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_cpu: torch.Tensor,
last_loc: torch.Tensor,
extend_num_tokens: int,
backup_state: bool = False,
......@@ -1063,7 +1066,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
# Over estimate the number of tokens: assume each request needs a new page.
num_tokens = (
extend_num_tokens
+ len(seq_lens) * self.token_to_kv_pool_allocator.page_size
+ len(seq_lens_cpu) * self.token_to_kv_pool_allocator.page_size
)
self._evict_tree_cache_if_needed(num_tokens)
......@@ -1071,7 +1074,12 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
state = self.token_to_kv_pool_allocator.backup_state()
out_cache_loc = self.token_to_kv_pool_allocator.alloc_extend(
prefix_lens, seq_lens, last_loc, extend_num_tokens
prefix_lens,
prefix_lens_cpu,
seq_lens,
seq_lens_cpu,
last_loc,
extend_num_tokens,
)
if out_cache_loc is None:
error_msg = (
......@@ -1090,6 +1098,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
def alloc_paged_token_slots_decode(
self,
seq_lens: torch.Tensor,
seq_lens_cpu: torch.Tensor,
last_loc: torch.Tensor,
backup_state: bool = False,
):
......@@ -1100,7 +1109,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
if backup_state:
state = self.token_to_kv_pool_allocator.backup_state()
out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(seq_lens, last_loc)
out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(
seq_lens, seq_lens_cpu, last_loc
)
if out_cache_loc is None:
error_msg = (
f"Decode out of memory. Try to lower your batch size.\n"
......@@ -1169,6 +1180,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to(
self.device, non_blocking=True
)
self.seq_lens_cpu = torch.tensor(seq_lens, dtype=torch.int64)
if not decoder_out_cache_loc:
self.out_cache_loc = torch.zeros(0, dtype=torch.int64).to(
......@@ -1217,12 +1229,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to(
self.device, non_blocking=True
)
seq_lens_cpu_tensor = torch.tensor(seq_lens, dtype=torch.int64)
orig_seq_lens_tensor = torch.tensor(orig_seq_lens, dtype=torch.int32).to(
self.device, non_blocking=True
)
prefix_lens_tensor = torch.tensor(
prefix_lens, dtype=torch.int64, device=self.device
)
prefix_lens_cpu_tensor = torch.tensor(prefix_lens, dtype=torch.int64)
token_type_ids_tensor = None
if len(token_type_ids) > 0:
......@@ -1349,13 +1363,19 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
prefix_lens_tensor,
)
out_cache_loc = self.alloc_paged_token_slots_extend(
prefix_lens_tensor, seq_lens_tensor, last_loc, extend_num_tokens
prefix_lens_tensor,
prefix_lens_cpu_tensor,
seq_lens_tensor,
seq_lens_cpu_tensor,
last_loc,
extend_num_tokens,
)
# Set fields
self.input_ids = input_ids_tensor
self.req_pool_indices = req_pool_indices_tensor
self.seq_lens = seq_lens_tensor
self.seq_lens_cpu = seq_lens_cpu_tensor
self.orig_seq_lens = orig_seq_lens_tensor
self.out_cache_loc = out_cache_loc
self.input_embeds = (
......@@ -1498,7 +1518,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
)
retracted_reqs = []
seq_lens_cpu = self.seq_lens.cpu().numpy()
first_iter = True
while first_iter or (
not self.check_decode_mem(selected_indices=sorted_indices)
......@@ -1548,7 +1567,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
def release_req(self, idx: int, remaing_req_count: int, server_args: ServerArgs):
req = self.reqs[idx]
seq_lens_cpu = self.seq_lens.cpu().numpy()
seq_lens_cpu = self.seq_lens_cpu.numpy()
if server_args.disaggregation_mode == "decode":
req.offload_kv_cache(
......@@ -1592,6 +1611,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self.forward_mode = ForwardMode.IDLE
self.input_ids = torch.empty(0, dtype=torch.int64, device=self.device)
self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
self.seq_lens_cpu = torch.empty(0, dtype=torch.int64)
self.orig_seq_lens = torch.empty(0, dtype=torch.int32, device=self.device)
self.out_cache_loc = torch.empty(0, dtype=torch.int64, device=self.device)
self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
......@@ -1651,10 +1671,12 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
if self.enable_overlap:
# Do not use in-place operations in the overlap mode
self.seq_lens = self.seq_lens + 1
self.seq_lens_cpu = self.seq_lens_cpu + 1
self.orig_seq_lens = self.orig_seq_lens + 1
else:
# A faster in-place version
self.seq_lens.add_(1)
self.seq_lens_cpu.add_(1)
self.orig_seq_lens.add_(1)
self.seq_lens_sum += bs
......@@ -1673,7 +1695,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self.req_pool_indices, self.seq_lens - 2
]
self.out_cache_loc = self.alloc_paged_token_slots_decode(
self.seq_lens, last_loc
self.seq_lens, self.seq_lens_cpu, last_loc
)
self.req_to_token_pool.write(
......@@ -1719,6 +1741,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self.multimodal_inputs = [self.multimodal_inputs[i] for i in keep_indices]
self.req_pool_indices = self.req_pool_indices[keep_indices_device]
self.seq_lens = self.seq_lens[keep_indices_device]
self.seq_lens_cpu = self.seq_lens_cpu[keep_indices]
self.orig_seq_lens = self.orig_seq_lens[keep_indices_device]
self.out_cache_loc = None
self.seq_lens_sum = self.seq_lens.sum().item()
......@@ -1759,6 +1782,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
[self.req_pool_indices, other.req_pool_indices]
)
self.seq_lens = torch.cat([self.seq_lens, other.seq_lens])
self.seq_lens_cpu = torch.cat([self.seq_lens_cpu, other.seq_lens_cpu])
self.orig_seq_lens = torch.cat([self.orig_seq_lens, other.orig_seq_lens])
self.out_cache_loc = None
self.seq_lens_sum += other.seq_lens_sum
......@@ -1802,9 +1826,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self.sampling_info.grammars = None
seq_lens_cpu = (
seq_lens_cpu_cache
if seq_lens_cpu_cache is not None
else self.seq_lens.cpu()
seq_lens_cpu_cache if seq_lens_cpu_cache is not None else self.seq_lens_cpu
)
global bid
......
......@@ -27,7 +27,7 @@ import triton
import triton.language as tl
from sglang.srt.mem_cache.memory_pool import SWAKVPool
from sglang.srt.utils import get_bool_env_var, next_power_of_2
from sglang.srt.utils import get_bool_env_var, get_num_new_pages, next_power_of_2
if TYPE_CHECKING:
from sglang.srt.mem_cache.memory_pool import KVCache
......@@ -294,7 +294,6 @@ def alloc_extend_kernel(
last_loc_ptr,
free_page_ptr,
out_indices,
ret_values,
bs_upper: tl.constexpr,
page_size: tl.constexpr,
max_num_extend_tokens: tl.constexpr,
......@@ -323,13 +322,6 @@ def alloc_extend_kernel(
sum_num_new_pages = tl.sum(num_new_pages)
new_page_start_loc = sum_num_new_pages - num_page_start_loc_self
# Return value
if pid == tl.num_programs(0) - 1:
merged_value = (sum_num_new_pages.to(tl.int64)) << 32 | sum_extend_lens.to(
tl.int64
)
tl.store(ret_values, merged_value)
# Part 1: fill the old partial page
last_loc = tl.load(last_loc_ptr + pid)
num_part1 = (
......@@ -381,7 +373,6 @@ def alloc_decode_kernel(
last_loc_ptr,
free_page_ptr,
out_indices,
ret_values,
bs_upper: tl.constexpr,
page_size: tl.constexpr,
):
......@@ -404,10 +395,6 @@ def alloc_decode_kernel(
sum_num_new_pages = tl.sum(num_new_pages)
new_page_start_loc = sum_num_new_pages - num_page_start_loc_self
# Return value
if pid == tl.num_programs(0) - 1:
tl.store(ret_values, sum_num_new_pages)
if num_page_start_loc_self == 0:
last_loc = tl.load(last_loc_ptr + pid)
tl.store(out_indices + pid, last_loc + 1)
......@@ -438,7 +425,6 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
super().__init__(size, page_size, dtype, device, kvcache, need_sort)
self.num_pages = size // page_size
self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")
self.ret_values = torch.empty((), dtype=torch.int64, device=self.device)
self.seen_max_num_extend_tokens_next_power_of_2 = 1
self.clear()
......@@ -468,7 +454,9 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
def alloc_extend(
self,
prefix_lens: torch.Tensor,
prefix_lens_cpu: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_cpu: torch.Tensor,
last_loc: torch.Tensor,
extend_num_tokens: int,
):
......@@ -497,7 +485,6 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
last_loc,
self.free_pages,
out_indices,
self.ret_values,
next_power_of_2(bs),
self.page_size,
self.seen_max_num_extend_tokens_next_power_of_2,
......@@ -506,8 +493,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
if self.debug_mode:
assert len(torch.unique(out_indices)) == len(out_indices)
merged_value = self.ret_values.item()
num_new_pages = merged_value >> 32
num_new_pages = get_num_new_pages(prefix_lens_cpu, seq_lens_cpu, self.page_size)
if num_new_pages > len(self.free_pages):
return None
......@@ -517,6 +503,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
def alloc_decode(
self,
seq_lens: torch.Tensor,
seq_lens_cpu: torch.Tensor,
last_loc: torch.Tensor,
):
if self.debug_mode:
......@@ -534,7 +521,6 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
last_loc,
self.free_pages,
out_indices,
self.ret_values,
next_power_of_2(bs),
self.page_size,
)
......@@ -542,7 +528,9 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
if self.debug_mode:
assert len(torch.unique(out_indices)) == len(out_indices)
num_new_pages = self.ret_values.item()
num_new_pages = get_num_new_pages(
seq_lens_cpu - 1, seq_lens_cpu, self.page_size, decode=True
)
if num_new_pages > len(self.free_pages):
return None
......
......@@ -69,7 +69,9 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
def alloc_extend(
self,
prefix_lens: torch.Tensor,
prefix_lens_cpu: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_cpu: torch.Tensor,
last_loc: torch.Tensor,
extend_num_tokens: int,
):
......@@ -80,8 +82,8 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
num_new_pages = (
(
(seq_lens + self.page_size - 1) // self.page_size
- (prefix_lens + self.page_size - 1) // self.page_size
(seq_lens_cpu + self.page_size - 1) // self.page_size
- (prefix_lens_cpu + self.page_size - 1) // self.page_size
)
.sum()
.item()
......@@ -115,6 +117,7 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
def alloc_decode(
self,
seq_lens: torch.Tensor,
seq_lens_cpu: torch.Tensor,
last_loc: torch.Tensor,
):
if self.debug_mode:
......@@ -123,7 +126,8 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
)
need_new_pages = (seq_lens % self.page_size == 1).int()
num_new_pages = need_new_pages.sum().item()
need_new_pages_cpu = (seq_lens_cpu % self.page_size == 1).int()
num_new_pages = need_new_pages_cpu.sum().item()
if num_new_pages > len(self.free_pages):
self.merge_and_sort_free()
......
......@@ -104,14 +104,21 @@ class EagleVerifyInput(SpecInput):
end_offset = batch.seq_lens + self.draft_token_num
else:
prefix_lens = batch.seq_lens
prefix_lens_cpu = batch.seq_lens_cpu
end_offset = prefix_lens + self.draft_token_num
end_offset_cpu = prefix_lens_cpu + self.draft_token_num
last_loc = get_last_loc(
batch.req_to_token_pool.req_to_token,
batch.req_pool_indices,
prefix_lens,
)
batch.out_cache_loc = batch.alloc_paged_token_slots_extend(
prefix_lens, end_offset, last_loc, len(batch.input_ids)
prefix_lens,
prefix_lens_cpu,
end_offset,
end_offset_cpu,
last_loc,
len(batch.input_ids),
)
self.last_loc = last_loc
......@@ -380,6 +387,8 @@ class EagleVerifyInput(SpecInput):
verified_id = predict[accept_index]
evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
evict_mask[accept_index] = False
accept_length_cpu = accept_length.cpu()
accept_length_list = accept_length_cpu.tolist()
if page_size == 1:
# TODO: boolean array index leads to a device sync. Remove it.
......@@ -456,13 +465,15 @@ class EagleVerifyInput(SpecInput):
else:
batch.out_cache_loc = tgt_cache_loc
batch.seq_lens.add_(accept_length + 1)
batch.seq_lens_cpu.add_(accept_length_cpu + 1)
draft_input = EagleDraftInput(
hidden_states=batch.spec_info.hidden_states[accept_index],
verified_id=verified_id,
accept_length=accept_length,
accept_length_cpu=accept_length.tolist(),
accept_length_cpu=accept_length_list,
seq_lens_for_draft_extend=batch.seq_lens,
seq_lens_for_draft_extend_cpu=batch.seq_lens_cpu,
req_pool_indices_for_draft_extend=batch.req_pool_indices,
)
......@@ -485,15 +496,15 @@ class EagleVerifyInput(SpecInput):
next_power_of_2(bs),
)
batch.seq_lens.add_(accept_length + 1)
batch.seq_lens_cpu.add_(accept_length_cpu + 1)
accept_length_cpu = accept_length.tolist()
if len(unfinished_accept_index) > 0:
unfinished_accept_index = torch.cat(unfinished_accept_index)
unfinished_index_device = torch.tensor(
unfinished_index, dtype=torch.int64, device=predict.device
)
draft_input_accept_length_cpu = [
accept_length_cpu[i] for i in unfinished_index
accept_length_list[i] for i in unfinished_index
]
if page_size == 1 or self.topk == 1:
batch.out_cache_loc = batch.out_cache_loc[unfinished_accept_index]
......@@ -508,6 +519,7 @@ class EagleVerifyInput(SpecInput):
unfinished_index_device,
batch.seq_lens,
)
batch.seq_lens_cpu.add_(accept_length_cpu + 1)
filter_finished_cache_loc_kernel[(bs,)](
batch.out_cache_loc,
tgt_cache_loc,
......@@ -525,6 +537,7 @@ class EagleVerifyInput(SpecInput):
accept_length_cpu=draft_input_accept_length_cpu,
accept_length=accept_length[unfinished_index_device],
seq_lens_for_draft_extend=batch.seq_lens[unfinished_index_device],
seq_lens_for_draft_extend_cpu=batch.seq_lens_cpu[unfinished_index],
req_pool_indices_for_draft_extend=batch.req_pool_indices[
unfinished_index_device
],
......@@ -542,7 +555,7 @@ class EagleVerifyInput(SpecInput):
draft_input=draft_input,
logits_output=logits_output,
verified_id=verified_id,
accept_length_per_req_cpu=accept_length_cpu,
accept_length_per_req_cpu=accept_length_list,
accepted_indices=accept_index,
)
......@@ -575,6 +588,7 @@ class EagleDraftInput(SpecInput):
# Inputs for draft extend
# shape: (b,)
seq_lens_for_draft_extend: torch.Tensor = None
seq_lens_for_draft_extend_cpu: torch.Tensor = None
req_pool_indices_for_draft_extend: torch.Tensor = None
def __post_init__(self):
......@@ -631,6 +645,7 @@ class EagleDraftInput(SpecInput):
batch.extend_lens = [x + 1 for x in batch.spec_info.accept_length_cpu]
batch.extend_num_tokens = sum(batch.extend_lens)
batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend
batch.seq_lens_cpu = batch.spec_info.seq_lens_for_draft_extend_cpu
batch.req_pool_indices = batch.spec_info.req_pool_indices_for_draft_extend
batch.return_logprob = False
batch.return_hidden_states = False
......
......@@ -543,6 +543,8 @@ class EAGLEWorker(TpModelWorker):
batch.seq_lens,
self.speculative_num_steps,
)
prefix_lens_cpu = batch.seq_lens_cpu
seq_lens_cpu = batch.seq_lens_cpu + self.speculative_num_steps
extend_num_tokens = num_seqs * self.speculative_num_steps
else:
# In this case, the last partial page needs to be duplicated.
......@@ -578,14 +580,23 @@ class EAGLEWorker(TpModelWorker):
self.topk,
self.page_size,
)
# TODO(lmzheng): remove this device sync
extend_num_tokens = torch.sum(self.extend_lens).item()
prefix_lens_cpu = batch.seq_lens_cpu
last_page_lens = prefix_lens_cpu % self.page_size
num_new_pages_per_topk = (
last_page_lens + self.speculative_num_steps + self.page_size - 1
) // self.page_size
seq_lens_cpu = (
prefix_lens_cpu // self.page_size * self.page_size
+ num_new_pages_per_topk * (self.page_size * self.topk)
)
extend_num_tokens = torch.sum((seq_lens_cpu - prefix_lens_cpu)).item()
out_cache_loc, token_to_kv_pool_state_backup = (
batch.alloc_paged_token_slots_extend(
prefix_lens,
prefix_lens_cpu,
seq_lens,
seq_lens_cpu,
last_loc,
extend_num_tokens,
backup_state=True,
......@@ -1003,6 +1014,7 @@ class EAGLEWorker(TpModelWorker):
assert isinstance(batch.spec_info, EagleDraftInput)
# Backup fields that will be modified in-place
seq_lens_backup = batch.seq_lens.clone()
seq_lens_cpu_backup = batch.seq_lens_cpu.clone()
req_pool_indices_backup = batch.req_pool_indices
accept_length_backup = batch.spec_info.accept_length
return_logprob_backup = batch.return_logprob
......@@ -1081,6 +1093,7 @@ class EAGLEWorker(TpModelWorker):
ForwardMode.DECODE if not input_is_idle else ForwardMode.IDLE
)
batch.seq_lens = seq_lens_backup
batch.seq_lens_cpu = seq_lens_cpu_backup
batch.req_pool_indices = req_pool_indices_backup
batch.spec_info.accept_length = accept_length_backup
batch.return_logprob = return_logprob_backup
......
......@@ -77,6 +77,7 @@ class NgramVerifyInput(SpecInput):
batch.out_cache_loc = batch.alloc_token_slots(len(batch.input_ids))
end_offset = batch.seq_lens + self.draft_token_num
else:
# TODO(lsyin): add prefix lens cpu here to support page size > 1
prefix_lens = batch.seq_lens
end_offset = prefix_lens + self.draft_token_num
last_loc = get_last_loc(
......@@ -405,10 +406,13 @@ class NgramVerifyInput(SpecInput):
self._fill_requests(batch, logits_output)
self._free_cache(batch, page_size)
accept_length_cpu = self.accept_length.cpu()
num_accepted_tokens = accept_length_cpu.sum().item()
batch.seq_lens.add_(self.accept_length + 1)
batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
batch.seq_lens_cpu.add_(accept_length_cpu + 1)
return logits_output, self.verified_id, self.accept_length.sum().item()
return logits_output, self.verified_id, num_accepted_tokens
def filter_batch(self, new_indices: torch.Tensor, has_been_filtered: bool = True):
pass
......
......@@ -3250,6 +3250,30 @@ def get_extend_input_len_swa_limit(
return page_size + 2 * max(sliding_window_size, chunked_prefill_size)
def get_num_new_pages(
    prefix_lens: torch.Tensor,
    seq_lens: torch.Tensor,
    page_size: int,
    decode: bool = False,
) -> int:
    """Return the total number of new KV-cache pages needed to grow each
    request from ``prefix_lens[i]`` tokens to ``seq_lens[i]`` tokens.

    Both length tensors must live on the CPU: reading the result then costs
    no device synchronization, which is the whole point of carrying these
    lengths as host-side metadata.

    Args:
        prefix_lens: Per-request already-allocated lengths, CPU tensor.
        seq_lens: Per-request target lengths, CPU tensor (same shape).
        page_size: Number of token slots per page.
        decode: Kept for interface compatibility with existing callers; the
            returned page count is computed the same way on both paths.

    Returns:
        The summed per-request page delta as a Python ``int``.
    """
    cpu_device = torch.device("cpu")
    assert prefix_lens.device == cpu_device
    assert seq_lens.device == cpu_device

    # Ceil-division: pages required to hold a sequence of each length.
    num_pages_after = (seq_lens + page_size - 1) // page_size
    num_pages_before = (prefix_lens + page_size - 1) // page_size
    num_new_pages = num_pages_after - num_pages_before

    # NOTE: a previous version packed this count into the high 32 bits of a
    # merged value (`count << 32 | sum(extend_lens)`) and immediately shifted
    # it back out — a no-op that could even corrupt the count via `|` if
    # sum(extend_lens) >= 2**32. Return the count directly instead; the
    # `decode` flag no longer changes the result.
    return int(torch.sum(num_new_pages).to(torch.int64).item())
class CachedKernel:
"""
Wrapper that allows kernel[grid](...) syntax with caching based on a key function.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment