Unverified commit 9606c719, authored by Cody Yu, committed by GitHub
Browse files

Revert #7509 (#7887)

parent 64cc6444
......@@ -113,8 +113,7 @@ class FlashInferState(AttentionState):
self.runner.parallel_config))
num_kv_heads = self.runner.model_config.get_num_kv_heads(
self.runner.parallel_config)
-use_tensor_cores = (num_qo_heads // num_kv_heads) not in \
-(1, 2, 4, 8)
+use_tensor_cores = num_qo_heads // num_kv_heads > 4
self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
self._get_workspace_buffer(),
"NHD",
......@@ -172,8 +171,7 @@ class FlashInferState(AttentionState):
self.runner.parallel_config))
num_kv_heads = self.runner.model_config.get_num_kv_heads(
self.runner.parallel_config)
-use_tensor_cores = (num_qo_heads // num_kv_heads) not in \
-(1, 2, 4, 8)
+use_tensor_cores = num_qo_heads // num_kv_heads > 4
self._graph_decode_wrapper = \
CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
self._graph_decode_workspace_buffer, _indptr_buffer,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment