Unverified commit 53328d75, authored by LI MOU and committed by GitHub
Browse files

[BUG] fix crash on flashinfer backend with cudagraph disabled, when attention...

[BUG] fix crash on flashinfer backend with cudagraph disabled, when attention group_size not in [1,2,4,8] (#7509)
parent c75363fb
...@@ -4,7 +4,7 @@ import flashinfer ...@@ -4,7 +4,7 @@ import flashinfer
import pytest import pytest
import torch import torch
NUM_HEADS = [(16, 16), (32, 8), (64, 8)] NUM_HEADS = [(16, 16), (32, 8), (64, 8), (6, 1)]
HEAD_SIZES = [128, 256] HEAD_SIZES = [128, 256]
BLOCK_SIZES = [16, 32] BLOCK_SIZES = [16, 32]
DTYPES = [torch.float16, torch.bfloat16] DTYPES = [torch.float16, torch.bfloat16]
...@@ -123,7 +123,10 @@ def test_flashinfer_decode_with_paged_kv(kv_lens: List[int], ...@@ -123,7 +123,10 @@ def test_flashinfer_decode_with_paged_kv(kv_lens: List[int],
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8) workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
wrapper = flashinfer.\ wrapper = flashinfer.\
BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD") BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD",
use_tensor_cores=(
(num_query_heads//num_kv_heads) not in (1, 2, 4, 8))
)
wrapper.begin_forward(kv_indptr, wrapper.begin_forward(kv_indptr,
kv_indices, kv_indices,
kv_last_page_lens, kv_last_page_lens,
......
...@@ -113,7 +113,8 @@ class FlashInferState(AttentionState): ...@@ -113,7 +113,8 @@ class FlashInferState(AttentionState):
self.runner.parallel_config)) self.runner.parallel_config))
num_kv_heads = self.runner.model_config.get_num_kv_heads( num_kv_heads = self.runner.model_config.get_num_kv_heads(
self.runner.parallel_config) self.runner.parallel_config)
use_tensor_cores = num_qo_heads // num_kv_heads >= 4 use_tensor_cores = (num_qo_heads // num_kv_heads) not in \
(1, 2, 4, 8)
self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
self._get_workspace_buffer(), self._get_workspace_buffer(),
"NHD", "NHD",
...@@ -171,7 +172,8 @@ class FlashInferState(AttentionState): ...@@ -171,7 +172,8 @@ class FlashInferState(AttentionState):
self.runner.parallel_config)) self.runner.parallel_config))
num_kv_heads = self.runner.model_config.get_num_kv_heads( num_kv_heads = self.runner.model_config.get_num_kv_heads(
self.runner.parallel_config) self.runner.parallel_config)
use_tensor_cores = num_qo_heads // num_kv_heads >= 4 use_tensor_cores = (num_qo_heads // num_kv_heads) not in \
(1, 2, 4, 8)
self._graph_decode_wrapper = \ self._graph_decode_wrapper = \
CUDAGraphBatchDecodeWithPagedKVCacheWrapper( CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
self._graph_decode_workspace_buffer, _indptr_buffer, self._graph_decode_workspace_buffer, _indptr_buffer,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment