Unverified commit af473f0a, authored by Po-Han Huang (NVIDIA), committed by GitHub
Browse files

[bugfix] Fix Llama3/4 issues caused by FlashInfer 0.2.10 (#22426)


Signed-off-by: Po-Han Huang <pohanh@nvidia.com>
parent 157f9c13
...@@ -6,14 +6,22 @@ import torch ...@@ -6,14 +6,22 @@ import torch
def calculate_tile_tokens_dim(num_tokens, top_k, num_experts): def calculate_tile_tokens_dim(num_tokens, top_k, num_experts):
from flashinfer import next_positive_power_of_2
# FlashInfer 0.2.10 has issues with larger tile sizes. Set to 8 for now.
# Guess tokens per expert assuming perfect expert distribution first. # TODO: Revert this to dynamic calculation once a new version of FlashInfer
num_tokens_per_expert = (num_tokens * top_k) // num_experts # with the necessary kernels is released.
# And pad the number to the next power of 2. tile_tokens_dim = 8
tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
# Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel. # from flashinfer import next_positive_power_of_2
tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
# # Guess tokens per expert assuming perfect expert distribution first.
# num_tokens_per_expert = (num_tokens * top_k) // num_experts
# # And pad the number to the next power of 2.
# tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
# # Cap to 8-64 tokens per CTA tile as it's the range supported by the
# # kernel.
# tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
return tile_tokens_dim return tile_tokens_dim
......
...@@ -524,7 +524,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -524,7 +524,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
head_dim = self.kv_cache_spec.head_size head_dim = self.kv_cache_spec.head_size
# currently prefill trtllm attention does not support fp8 kv cache # currently prefill trtllm attention does not support fp8 kv cache
prefill_use_trtllm = use_trtllm_attention( prefill_use_trtllm = not cache_dtype.startswith("fp8") \
and use_trtllm_attention(
num_prefill_tokens, max_seq_len, cache_dtype, num_prefill_tokens, max_seq_len, cache_dtype,
num_qo_heads, num_kv_heads, head_dim) num_qo_heads, num_kv_heads, head_dim)
decode_use_trtllm = use_trtllm_attention( decode_use_trtllm = use_trtllm_attention(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment