Unverified Commit 96a2093e authored by Ying Sheng, committed by GitHub

[Fix] Compatibility of window attention and cuda graph (#1090)

parent a34dd86a
@@ -34,6 +34,7 @@ class RadixAttention(nn.Module):
         scaling: float,
         num_kv_heads: int,
         layer_id: int,
+        reuse: bool = False,
         sliding_window_size: int = -1,
         logit_cap: int = -1,
         v_head_dim: int = -1,
@@ -47,6 +48,7 @@ class RadixAttention(nn.Module):
         self.v_head_dim = v_head_dim if v_head_dim != -1 else head_dim
         self.scaling = scaling
         self.layer_id = layer_id
+        self.reuse = reuse
         self.sliding_window_size = sliding_window_size
         if (
@@ -127,7 +129,8 @@ class RadixAttention(nn.Module):
         if isinstance(prefill_wrapper_paged, list):
             prefill_wrapper_paged = prefill_wrapper_paged[1]
-        if not input_metadata.flashinfer_use_ragged:
-            self.store_kv_cache(k, v, input_metadata)
+        if not input_metadata.flashinfer_use_ragged or self.reuse:
+            if not self.reuse:
+                self.store_kv_cache(k, v, input_metadata)
         o = prefill_wrapper_paged.forward(
@@ -179,6 +182,7 @@ class RadixAttention(nn.Module):
         if isinstance(decode_wrapper, list):
             decode_wrapper = decode_wrapper[1]
-        self.store_kv_cache(k, v, input_metadata)
+        if not self.reuse:
+            self.store_kv_cache(k, v, input_metadata)
         o = decode_wrapper.forward(
@@ -191,6 +195,8 @@ class RadixAttention(nn.Module):
         return o.view(-1, self.tp_q_head_num * self.head_dim)

     def forward(self, q, k, v, input_metadata: InputMetadata):
-        k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
-        v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
+        if k is not None:
+            assert v is not None
+            k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
+            v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
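Note: a minimal, hypothetical sketch of the gating the hunks above introduce. A layer constructed with `reuse=True` skips writing its own K/V and attends over whatever a non-reuse layer already stored; `cache_slot` and `run_paged_attention` are illustrative stand-ins, not the real RadixAttention members or the FlashInfer API.

```python
def attention_step(q, k, v, cache_slot, reuse, run_paged_attention):
    """Illustrative only: mirrors the reuse gating, not the actual code."""
    if not reuse:
        # Normal layers write their K/V into the paged cache first.
        cache_slot["k"], cache_slot["v"] = k, v
    # Reuse layers skip the write and attend over the cache as-is.
    return run_paged_attention(q, cache_slot["k"], cache_slot["v"])
```

The prefill hunk keeps the same shape: with `reuse` set, the paged branch is taken even when ragged prefill is in use, but `store_kv_cache` is skipped.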
...
@@ -107,9 +107,6 @@ class CudaGraphRunner:
         )

         # FlashInfer inputs
-        self.flashinfer_workspace_buffer = (
-            self.model_runner.flashinfer_workspace_buffers[0]
-        )
         self.flashinfer_kv_indptr = torch.zeros(
             (self.max_bs + 1,), dtype=torch.int32, device="cuda"
         )
@@ -121,6 +118,23 @@ class CudaGraphRunner:
         self.flashinfer_kv_last_page_len = torch.ones(
             (self.max_bs,), dtype=torch.int32, device="cuda"
         )
+        if model_runner.sliding_window_size is None:
+            self.flashinfer_workspace_buffer = (
+                self.model_runner.flashinfer_workspace_buffers[0]
+            )
+        else:
+            self.flashinfer_workspace_buffers = [
+                self.model_runner.flashinfer_workspace_buffers[0],
+                self.model_runner.flashinfer_workspace_buffers[2],
+            ]
+            self.flashinfer_kv_indptr = [
+                self.flashinfer_kv_indptr,
+                self.flashinfer_kv_indptr.clone(),
+            ]
+            self.flashinfer_kv_indices = [
+                self.flashinfer_kv_indices,
+                self.flashinfer_kv_indices.clone(),
+            ]

         self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if use_torch_compile else []
@@ -171,6 +185,7 @@ class CudaGraphRunner:
             use_tensor_cores = True
         else:
             use_tensor_cores = False
+        if self.model_runner.sliding_window_size is None:
             flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
                 self.flashinfer_workspace_buffer,
                 "NHD",
@@ -180,6 +195,22 @@ class CudaGraphRunner:
                 paged_kv_indices_buffer=self.flashinfer_kv_indices,
                 paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[:bs],
             )
+        else:
+            flashinfer_decode_wrapper = []
+            for i in range(2):
+                flashinfer_decode_wrapper.append(
+                    BatchDecodeWithPagedKVCacheWrapper(
+                        self.flashinfer_workspace_buffers[i],
+                        "NHD",
+                        use_cuda_graph=True,
+                        use_tensor_cores=use_tensor_cores,
+                        paged_kv_indptr_buffer=self.flashinfer_kv_indptr[i][: bs + 1],
+                        paged_kv_indices_buffer=self.flashinfer_kv_indices[i],
+                        paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[
+                            :bs
+                        ],
+                    )
+                )

         update_flashinfer_indices(
             ForwardMode.DECODE,
             self.model_runner,
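Note: a wrapper captured into a CUDA graph has to read its metadata from device buffers whose addresses never change between replays. With a sliding-window model there are now two decode wrappers (window attention and full attention), so each gets its own workspace plus its own copy of the kv_indptr/kv_indices buffers, as sketched below; the helper name and sizes are illustrative, not from the diff.

```python
import torch

def make_decode_buffers(max_bs, max_kv_indices, sliding_window, device="cuda"):
    # Fixed-address buffers, updated in place before every graph replay.
    kv_indptr = torch.zeros((max_bs + 1,), dtype=torch.int32, device=device)
    kv_indices = torch.zeros((max_kv_indices,), dtype=torch.int32, device=device)
    if not sliding_window:
        return kv_indptr, kv_indices
    # Two wrappers -> two independent copies, so neither overwrites the other.
    return [kv_indptr, kv_indptr.clone()], [kv_indices, kv_indices.clone()]
```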
...
@@ -154,7 +154,6 @@ class InputMetadata:
         model_runner: "ModelRunner",
         batch: ScheduleBatch,
         forward_mode: ForwardMode,
-        sliding_window_size: Optional[int] = None,
     ):
         ret = cls(
             forward_mode=forward_mode,
@@ -198,7 +197,7 @@ class InputMetadata:
         ):
             flashinfer_use_ragged = True

         ret.init_flashinfer_handlers(
-            model_runner, prefix_lens, flashinfer_use_ragged, sliding_window_size
+            model_runner, prefix_lens, flashinfer_use_ragged
         )

         return ret
@@ -221,7 +220,6 @@ class InputMetadata:
         model_runner,
         prefix_lens,
         flashinfer_use_ragged,
-        sliding_window_size=None,
     ):
         update_flashinfer_indices(
             self.forward_mode,
@@ -230,7 +228,6 @@ class InputMetadata:
             self.seq_lens,
             prefix_lens,
             flashinfer_use_ragged=flashinfer_use_ragged,
-            sliding_window_size=sliding_window_size,
         )

         (
@@ -254,7 +251,6 @@ def update_flashinfer_indices(
     prefix_lens,
     flashinfer_decode_wrapper=None,
     flashinfer_use_ragged=False,
-    sliding_window_size=None,
 ):
     """Init auxiliary variables for FlashInfer attention backend."""
     num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
@@ -262,7 +258,7 @@ def update_flashinfer_indices(
     head_dim = model_runner.model_config.head_dim
     batch_size = len(req_pool_indices)

-    if sliding_window_size is None:
+    if model_runner.sliding_window_size is None:
         if flashinfer_use_ragged:
             paged_kernel_lens = prefix_lens
         else:
@@ -335,7 +331,7 @@ def update_flashinfer_indices(
             if wrapper_id == 0 and forward_mode == ForwardMode.DECODE:
                 paged_kernel_lens = torch.minimum(
-                    paged_kernel_lens, torch.tensor(sliding_window_size)
+                    paged_kernel_lens, torch.tensor(model_runner.sliding_window_size)
                 )
                 kv_start_idx = seq_lens - paged_kernel_lens
             else:
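Note: the window clamping in the last hunk is plain arithmetic. For the window-attention wrapper (`wrapper_id == 0`) during decode, each request only attends over its most recent `sliding_window_size` tokens, so the per-request KV length is clipped and the read start is offset. A standalone sketch with made-up lengths:

```python
import torch

seq_lens = torch.tensor([5, 300, 4096])   # illustrative request lengths
sliding_window_size = 256                 # illustrative window

paged_kernel_lens = torch.minimum(seq_lens, torch.tensor(sliding_window_size))
kv_start_idx = seq_lens - paged_kernel_lens

print(paged_kernel_lens.tolist())  # [5, 256, 256]
print(kv_start_idx.tolist())       # [0, 44, 3840]
```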
...
@@ -187,6 +187,11 @@ class ModelRunner:
             scheduler_config=None,
             cache_config=None,
         )
+        self.sliding_window_size = (
+            self.model.get_window_size()
+            if hasattr(self.model, "get_window_size")
+            else None
+        )
         self.is_generation = is_generation_model(
             self.model_config.hf_config.architectures
         )
@@ -295,12 +300,6 @@ class ModelRunner:
         return c

     def init_flashinfer(self):
-        self.sliding_window_size = (
-            self.model.get_window_size()
-            if hasattr(self.model, "get_window_size")
-            else None
-        )
-
         if self.server_args.disable_flashinfer:
             assert (
                 self.sliding_window_size is None
@@ -339,7 +338,7 @@ class ModelRunner:
                 use_tensor_cores=use_tensor_cores,
             )
         else:
-            workspace_buffers = torch.empty(
+            self.flashinfer_workspace_buffers = torch.empty(
                 4,
                 global_config.flashinfer_workspace_size,
                 dtype=torch.uint8,
@@ -351,17 +350,17 @@ class ModelRunner:
             for i in range(2):
                 self.flashinfer_prefill_wrapper_ragged.append(
                     BatchPrefillWithRaggedKVCacheWrapper(
-                        workspace_buffers[2 * i + 0], "NHD"
+                        self.flashinfer_workspace_buffers[2 * i + 0], "NHD"
                     )
                 )
                 self.flashinfer_prefill_wrapper_paged.append(
                     BatchPrefillWithPagedKVCacheWrapper(
-                        workspace_buffers[2 * i + 1], "NHD"
+                        self.flashinfer_workspace_buffers[2 * i + 1], "NHD"
                     )
                 )
                 self.flashinfer_decode_wrapper.append(
                     BatchDecodeWithPagedKVCacheWrapper(
-                        workspace_buffers[2 * i + 0],
+                        self.flashinfer_workspace_buffers[2 * i + 0],
                         "NHD",
                         use_tensor_cores=use_tensor_cores,
                     )
@@ -404,7 +403,6 @@ class ModelRunner:
             self,
             batch,
             ForwardMode.DECODE,
-            sliding_window_size=self.sliding_window_size,
         )
         return self.model.forward(
@@ -417,7 +415,6 @@ class ModelRunner:
             self,
             batch,
             forward_mode=ForwardMode.EXTEND,
-            sliding_window_size=self.sliding_window_size,
         )
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
@@ -429,7 +426,6 @@ class ModelRunner:
             self,
             batch,
             forward_mode=ForwardMode.EXTEND,
-            sliding_window_size=self.sliding_window_size,
         )
         return self.model.forward(
             batch.input_ids,
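Note: two related moves in this file. The window size is now detected right after model load, so it is available to init_flashinfer and the CUDA graph runner alike, and the workspace allocation is promoted to an attribute so the graph runner can slice the same buffers (indices 0 and 2 above) instead of allocating its own. A hypothetical minimal illustration of that sharing; ToyRunner, ToyGraphRunner, and WORKSPACE_SIZE are stand-ins, not the real classes or config:

```python
import torch

WORKSPACE_SIZE = 128 * 1024 * 1024  # stand-in for global_config.flashinfer_workspace_size

class ToyRunner:
    def load_model(self, model, device="cuda"):
        self.model = model
        # Duck-typed detection: models exposing get_window_size() opt in.
        self.sliding_window_size = (
            model.get_window_size() if hasattr(model, "get_window_size") else None
        )
        # One allocation, later sliced into per-wrapper workspaces.
        self.flashinfer_workspace_buffers = torch.empty(
            4, WORKSPACE_SIZE, dtype=torch.uint8, device=device
        )

class ToyGraphRunner:
    def __init__(self, runner):
        if runner.sliding_window_size is None:
            self.workspace = runner.flashinfer_workspace_buffers[0]
        else:
            # One workspace per decode wrapper, shared with the runner.
            self.workspaces = [
                runner.flashinfer_workspace_buffers[0],
                runner.flashinfer_workspace_buffers[2],
            ]
```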
...
@@ -453,10 +453,12 @@ class ServerArgs:
             logger.info(
                 f"When using sliding window in gemma-2, disable radix_cache, regex_jump_forward, and turn on flashinfer."
             )
+            # FIXME: compatibility with radix attention
             self.disable_radix_cache = True
+            # FIXME: compatibility with jump forward
             self.disable_regex_jump_forward = True
             self.disable_flashinfer = False
-            self.disable_cuda_graph = True
+            # FIXME: compatibility with chunked prefill
             self.chunked_prefill_size = None
...
@@ -36,7 +36,7 @@ DEFAULT_PROMPTS = [
 ]

 dirpath = os.path.dirname(__file__)
-with open(os.path.join(dirpath, "long_prompt"), "r") as f:
+with open(os.path.join(dirpath, "long_prompt.txt"), "r") as f:
     long_prompt = f.read()
 DEFAULT_PROMPTS.append(long_prompt)
...