Unverified Commit 96a2093e authored by Ying Sheng, committed by GitHub

[Fix] Compatibility of window attention and cuda graph (#1090)
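This change lets CUDA graph capture work for models with sliding-window attention (gemma-2 style). ModelRunner now derives `sliding_window_size` right after the model is loaded and keeps its FlashInfer workspace buffers as an attribute; CudaGraphRunner duplicates its index buffers and builds a pair of decode wrappers when a window is present; `update_flashinfer_indices` reads the window size from the model runner instead of a threaded parameter; and ServerArgs no longer forces `disable_cuda_graph` for such models. RadixAttention additionally gains a `reuse` flag that skips KV-cache writes, and its `forward` now tolerates `k`/`v` being `None`.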

parent a34dd86a
@@ -34,6 +34,7 @@ class RadixAttention(nn.Module):
         scaling: float,
         num_kv_heads: int,
         layer_id: int,
+        reuse: bool = False,
         sliding_window_size: int = -1,
         logit_cap: int = -1,
         v_head_dim: int = -1,
@@ -47,6 +48,7 @@ class RadixAttention(nn.Module):
         self.v_head_dim = v_head_dim if v_head_dim != -1 else head_dim
         self.scaling = scaling
         self.layer_id = layer_id
+        self.reuse = reuse
         self.sliding_window_size = sliding_window_size

         if (
@@ -127,8 +129,9 @@ class RadixAttention(nn.Module):
         if isinstance(prefill_wrapper_paged, list):
             prefill_wrapper_paged = prefill_wrapper_paged[1]

-        if not input_metadata.flashinfer_use_ragged:
-            self.store_kv_cache(k, v, input_metadata)
+        if not input_metadata.flashinfer_use_ragged or self.reuse:
+            if not self.reuse:
+                self.store_kv_cache(k, v, input_metadata)

             o = prefill_wrapper_paged.forward(
                 q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
@@ -179,7 +182,8 @@ class RadixAttention(nn.Module):
         if isinstance(decode_wrapper, list):
             decode_wrapper = decode_wrapper[1]

-        self.store_kv_cache(k, v, input_metadata)
+        if not self.reuse:
+            self.store_kv_cache(k, v, input_metadata)

         o = decode_wrapper.forward(
             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
@@ -191,8 +195,10 @@ class RadixAttention(nn.Module):
         return o.view(-1, self.tp_q_head_num * self.head_dim)

     def forward(self, q, k, v, input_metadata: InputMetadata):
-        k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
-        v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
+        if k is not None:
+            assert v is not None
+            k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
+            v = v.view(-1, self.tp_v_head_num, self.v_head_dim)

         if input_metadata.forward_mode == ForwardMode.EXTEND:
             return self.extend_forward(q, k, v, input_metadata)
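Note on the RadixAttention changes above: a layer constructed with `reuse=True` no longer writes its own KV cache in the prefill and decode paths, and `forward` accepts `k`/`v` as `None`, presumably so such a layer can attend over a cache that was populated elsewhere. A minimal, self-contained sketch of that control flow, using a stand-in class (`TinyAttention` and its fields are illustrative, not the real module):

```python
# Illustrative stand-in, not the real RadixAttention: it only mirrors the
# reuse/None-guard control flow introduced in the diff above.
class TinyAttention:
    def __init__(self, reuse: bool = False):
        self.reuse = reuse
        self.cache_writes = 0

    def store_kv_cache(self, k, v):
        self.cache_writes += 1

    def forward(self, q, k, v):
        if k is not None:          # new guard: k/v may be None for reusing layers
            assert v is not None
        if not self.reuse:         # new skip: reusing layers never write the cache
            self.store_kv_cache(k, v)
        return q                   # attention math elided

layer = TinyAttention(reuse=True)
layer.forward(q=[1.0], k=None, v=None)
assert layer.cache_writes == 0     # a reusing layer never writes the cache
```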
@@ -107,9 +107,6 @@ class CudaGraphRunner:
         )

         # FlashInfer inputs
-        self.flashinfer_workspace_buffer = (
-            self.model_runner.flashinfer_workspace_buffers[0]
-        )
         self.flashinfer_kv_indptr = torch.zeros(
             (self.max_bs + 1,), dtype=torch.int32, device="cuda"
         )
@@ -121,6 +118,23 @@ class CudaGraphRunner:
         self.flashinfer_kv_last_page_len = torch.ones(
             (self.max_bs,), dtype=torch.int32, device="cuda"
         )
+        if model_runner.sliding_window_size is None:
+            self.flashinfer_workspace_buffer = (
+                self.model_runner.flashinfer_workspace_buffers[0]
+            )
+        else:
+            self.flashinfer_workspace_buffers = [
+                self.model_runner.flashinfer_workspace_buffers[0],
+                self.model_runner.flashinfer_workspace_buffers[2],
+            ]
+            self.flashinfer_kv_indptr = [
+                self.flashinfer_kv_indptr,
+                self.flashinfer_kv_indptr.clone(),
+            ]
+            self.flashinfer_kv_indices = [
+                self.flashinfer_kv_indices,
+                self.flashinfer_kv_indices.clone(),
+            ]

         self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if use_torch_compile else []
@@ -171,15 +185,32 @@ class CudaGraphRunner:
             use_tensor_cores = True
         else:
             use_tensor_cores = False
-        flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-            self.flashinfer_workspace_buffer,
-            "NHD",
-            use_cuda_graph=True,
-            use_tensor_cores=use_tensor_cores,
-            paged_kv_indptr_buffer=self.flashinfer_kv_indptr[: bs + 1],
-            paged_kv_indices_buffer=self.flashinfer_kv_indices,
-            paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[:bs],
-        )
+        if self.model_runner.sliding_window_size is None:
+            flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
+                self.flashinfer_workspace_buffer,
+                "NHD",
+                use_cuda_graph=True,
+                use_tensor_cores=use_tensor_cores,
+                paged_kv_indptr_buffer=self.flashinfer_kv_indptr[: bs + 1],
+                paged_kv_indices_buffer=self.flashinfer_kv_indices,
+                paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[:bs],
+            )
+        else:
+            flashinfer_decode_wrapper = []
+            for i in range(2):
+                flashinfer_decode_wrapper.append(
+                    BatchDecodeWithPagedKVCacheWrapper(
+                        self.flashinfer_workspace_buffers[i],
+                        "NHD",
+                        use_cuda_graph=True,
+                        use_tensor_cores=use_tensor_cores,
+                        paged_kv_indptr_buffer=self.flashinfer_kv_indptr[i][: bs + 1],
+                        paged_kv_indices_buffer=self.flashinfer_kv_indices[i],
+                        paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[
+                            :bs
+                        ],
+                    )
+                )

         update_flashinfer_indices(
             ForwardMode.DECODE,
             self.model_runner,
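Note on the CudaGraphRunner changes above: a captured CUDA graph replays fixed device addresses, so every FlashInfer wrapper has to own stable, pre-allocated buffers. When the model reports a sliding window, the runner therefore keeps two workspace buffers and clones the indptr/indices buffers, then captures a list of two `BatchDecodeWithPagedKVCacheWrapper` instances; wrapper 0 appears to serve the sliding-window layers (its KV lengths are clamped to the window in `update_flashinfer_indices`) and wrapper 1 the full-attention layers. A rough sketch of the duplication pattern on CPU tensors (illustrative sizes; the real buffers live on `cuda`):

```python
import torch

max_bs = 32
kv_indptr = torch.zeros(max_bs + 1, dtype=torch.int32)
kv_indices = torch.zeros(16384, dtype=torch.int32)  # illustrative capacity

has_sliding_window = True
if has_sliding_window:
    # one stable buffer per wrapper: [0] -> window layers, [1] -> full layers
    kv_indptr = [kv_indptr, kv_indptr.clone()]
    kv_indices = [kv_indices, kv_indices.clone()]

# Each wrapper is then constructed with its own buffer slice,
# e.g. paged_kv_indptr_buffer=kv_indptr[i][: bs + 1] for wrapper i.
```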
@@ -154,7 +154,6 @@ class InputMetadata:
         model_runner: "ModelRunner",
         batch: ScheduleBatch,
         forward_mode: ForwardMode,
-        sliding_window_size: Optional[int] = None,
     ):
         ret = cls(
             forward_mode=forward_mode,
@@ -198,7 +197,7 @@ class InputMetadata:
         ):
             flashinfer_use_ragged = True

         ret.init_flashinfer_handlers(
-            model_runner, prefix_lens, flashinfer_use_ragged, sliding_window_size
+            model_runner, prefix_lens, flashinfer_use_ragged
         )

         return ret
@@ -221,7 +220,6 @@ class InputMetadata:
         model_runner,
         prefix_lens,
         flashinfer_use_ragged,
-        sliding_window_size=None,
     ):
         update_flashinfer_indices(
             self.forward_mode,
@@ -230,7 +228,6 @@ class InputMetadata:
             self.seq_lens,
             prefix_lens,
             flashinfer_use_ragged=flashinfer_use_ragged,
-            sliding_window_size=sliding_window_size,
         )

         (
@@ -254,7 +251,6 @@ def update_flashinfer_indices(
     prefix_lens,
     flashinfer_decode_wrapper=None,
     flashinfer_use_ragged=False,
-    sliding_window_size=None,
 ):
     """Init auxiliary variables for FlashInfer attention backend."""
     num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
@@ -262,7 +258,7 @@ def update_flashinfer_indices(
     head_dim = model_runner.model_config.head_dim
     batch_size = len(req_pool_indices)

-    if sliding_window_size is None:
+    if model_runner.sliding_window_size is None:
         if flashinfer_use_ragged:
             paged_kernel_lens = prefix_lens
         else:
@@ -335,7 +331,7 @@ def update_flashinfer_indices(
             if wrapper_id == 0 and forward_mode == ForwardMode.DECODE:
                 paged_kernel_lens = torch.minimum(
-                    paged_kernel_lens, torch.tensor(sliding_window_size)
+                    paged_kernel_lens, torch.tensor(model_runner.sliding_window_size)
                 )
                 kv_start_idx = seq_lens - paged_kernel_lens
             else:
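Note on the changes above: `sliding_window_size` is no longer threaded through `InputMetadata.create` and `init_flashinfer_handlers`; `update_flashinfer_indices` now reads it directly from the `ModelRunner`. For the window wrapper (`wrapper_id == 0`) during decode, the paged KV length is clamped to the window and `kv_start_idx` skips everything that falls outside it. A small worked example of that arithmetic (made-up sequence lengths and a 1024-token window):

```python
import torch

seq_lens = torch.tensor([4096, 512])   # two requests in the batch
sliding_window_size = 1024

paged_kernel_lens = torch.minimum(seq_lens, torch.tensor(sliding_window_size))
kv_start_idx = seq_lens - paged_kernel_lens

print(paged_kernel_lens.tolist())  # [1024, 512] -> at most one window of KV per request
print(kv_start_idx.tolist())       # [3072, 0]   -> offset of the first in-window token
```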
@@ -187,6 +187,11 @@ class ModelRunner:
             scheduler_config=None,
             cache_config=None,
         )
+        self.sliding_window_size = (
+            self.model.get_window_size()
+            if hasattr(self.model, "get_window_size")
+            else None
+        )
         self.is_generation = is_generation_model(
             self.model_config.hf_config.architectures
         )
@@ -295,12 +300,6 @@ class ModelRunner:
         return c

     def init_flashinfer(self):
-        self.sliding_window_size = (
-            self.model.get_window_size()
-            if hasattr(self.model, "get_window_size")
-            else None
-        )
-
         if self.server_args.disable_flashinfer:
             assert (
                 self.sliding_window_size is None
@@ -339,7 +338,7 @@ class ModelRunner:
                 use_tensor_cores=use_tensor_cores,
             )
         else:
-            workspace_buffers = torch.empty(
+            self.flashinfer_workspace_buffers = torch.empty(
                 4,
                 global_config.flashinfer_workspace_size,
                 dtype=torch.uint8,
@@ -351,17 +350,17 @@ class ModelRunner:
             for i in range(2):
                 self.flashinfer_prefill_wrapper_ragged.append(
                     BatchPrefillWithRaggedKVCacheWrapper(
-                        workspace_buffers[2 * i + 0], "NHD"
+                        self.flashinfer_workspace_buffers[2 * i + 0], "NHD"
                     )
                 )
                 self.flashinfer_prefill_wrapper_paged.append(
                     BatchPrefillWithPagedKVCacheWrapper(
-                        workspace_buffers[2 * i + 1], "NHD"
+                        self.flashinfer_workspace_buffers[2 * i + 1], "NHD"
                     )
                 )
                 self.flashinfer_decode_wrapper.append(
                     BatchDecodeWithPagedKVCacheWrapper(
-                        workspace_buffers[2 * i + 0],
+                        self.flashinfer_workspace_buffers[2 * i + 0],
                         "NHD",
                         use_tensor_cores=use_tensor_cores,
                     )
@@ -404,7 +403,6 @@ class ModelRunner:
             self,
             batch,
             ForwardMode.DECODE,
-            sliding_window_size=self.sliding_window_size,
         )
         return self.model.forward(
@@ -417,7 +415,6 @@ class ModelRunner:
             self,
             batch,
             forward_mode=ForwardMode.EXTEND,
-            sliding_window_size=self.sliding_window_size,
         )
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
@@ -429,7 +426,6 @@ class ModelRunner:
             self,
             batch,
             forward_mode=ForwardMode.EXTEND,
-            sliding_window_size=self.sliding_window_size,
         )
         return self.model.forward(
             batch.input_ids,
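Note on the ModelRunner changes above: the window size is discovered right after `get_model` returns (instead of inside `init_flashinfer`), and the workspace buffers are promoted to `self.flashinfer_workspace_buffers`, so `CudaGraphRunner` can read both before capture. A sketch of the discovery pattern with a stand-in model class (`FakeWindowModel` and the 4096 value are illustrative):

```python
class FakeWindowModel:
    def get_window_size(self) -> int:
        return 4096  # illustrative per-layer window width

model = FakeWindowModel()
sliding_window_size = (
    model.get_window_size() if hasattr(model, "get_window_size") else None
)
assert sliding_window_size == 4096  # stays None for models without the hook
```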
@@ -453,10 +453,12 @@ class ServerArgs:
             logger.info(
                 f"When using sliding window in gemma-2, disable radix_cache, regex_jump_forward, and turn on flashinfer."
             )
+            # FIXME: compatibility with radix attention
             self.disable_radix_cache = True
+            # FIXME: compatibility with jump forward
             self.disable_regex_jump_forward = True
             self.disable_flashinfer = False
-            self.disable_cuda_graph = True
+            # FIXME: compatibility with chunked prefill
             self.chunked_prefill_size = None
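Note on the ServerArgs change above: the gemma-2 override block keeps flashinfer on and still disables the radix cache, regex jump-forward, and chunked prefill (the remaining FIXMEs), but it no longer forces `disable_cuda_graph = True`, which is the user-visible point of this commit. Roughly, the resulting overrides look like this (field names mirror the block above; this dict is only a summary, not real configuration code):

```python
gemma2_overrides = {
    "disable_radix_cache": True,         # FIXME: radix attention compatibility
    "disable_regex_jump_forward": True,  # FIXME: jump-forward compatibility
    "disable_flashinfer": False,         # sliding window requires flashinfer
    "chunked_prefill_size": None,        # FIXME: chunked prefill compatibility
    # "disable_cuda_graph" is no longer forced; CUDA graphs stay available
}
```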
@@ -36,7 +36,7 @@ DEFAULT_PROMPTS = [
 ]

 dirpath = os.path.dirname(__file__)
-with open(os.path.join(dirpath, "long_prompt"), "r") as f:
+with open(os.path.join(dirpath, "long_prompt.txt"), "r") as f:
     long_prompt = f.read()

 DEFAULT_PROMPTS.append(long_prompt)