Unverified commit 6b6e9877, authored by Jason Li, committed by GitHub
Browse files

[NVIDIA] flashinfer TRTLLM attention prefill token limit (#25998)


Signed-off-by: jasonlizhengjian <jason.li@centml.ai>
Signed-off-by: jasonlizhengjian <jasonlizhengjian@gmail.com>
parent 9c3c21c5
......@@ -283,11 +283,18 @@ def use_trtllm_attention(
if force_use_trtllm is None:
# Environment variable not set - use auto-detection
use_trtllm = (
num_tokens <= 256 and max_seq_len <= 131072 and kv_cache_dtype == "auto"
)
if use_trtllm:
logger.warning_once("Using TRTLLM attention (auto-detected).")
if is_prefill:
# Prefill auto-detection
use_trtllm = max_seq_len <= 131072 and kv_cache_dtype == "auto"
if use_trtllm:
logger.warning_once("Using TRTLLM prefill attention (auto-detected).")
else:
# Decode auto-detection
use_trtllm = (
num_tokens <= 256 and max_seq_len <= 131072 and kv_cache_dtype == "auto"
)
if use_trtllm:
logger.warning_once("Using TRTLLM decode attention (auto-detected).")
return use_trtllm
# Environment variable is set to 1 - respect it
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment