Unverified Commit 191e3fda authored by bai, committed by GitHub
Browse files

Update flashinfer to 0.6.8 (#39959)


Signed-off-by: bai <v@gor.io>
parent b9cf629b
......@@ -580,33 +580,12 @@ RUN --mount=type=cache,target=/root/.cache/uv \
# Install FlashInfer JIT cache (requires CUDA-version-specific index URL)
# https://docs.flashinfer.ai/installation.html
# From versions.json: .flashinfer.version
# 0.6.7: CUTLASS 4.4.2 bump, fixes TMA grouped GEMM on SM12x (flashinfer#2798)
# TODO: bump to 0.6.8 when released for NVFP4/MXFP4 group GEMMs on
# SM120/SM121 (RTX 50 / DGX Spark) via flashinfer#2738
ARG FLASHINFER_VERSION=0.6.7
ARG FLASHINFER_VERSION=0.6.8.post1
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system flashinfer-jit-cache==${FLASHINFER_VERSION} \
--extra-index-url https://flashinfer.ai/whl/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') \
&& flashinfer show-config
# Pre-download FlashInfer TRTLLM BMM headers for air-gapped environments.
# At runtime, MoE JIT compilation downloads these from edge.urm.nvidia.com
# which fails without internet. This step caches them at build time.
RUN python3 <<'PYEOF'
from flashinfer.jit import env as jit_env
from flashinfer.jit.cubin_loader import download_trtllm_headers, get_cubin
from flashinfer.artifacts import ArtifactPath, CheckSumHash
download_trtllm_headers(
'bmm',
jit_env.FLASHINFER_CUBIN_DIR / 'flashinfer' / 'trtllm' / 'batched_gemm' / 'trtllmGen_bmm_export',
f'{ArtifactPath.TRTLLM_GEN_BMM}/include/trtllmGen_bmm_export',
ArtifactPath.TRTLLM_GEN_BMM,
get_cubin(f'{ArtifactPath.TRTLLM_GEN_BMM}/checksums.txt', CheckSumHash.TRTLLM_GEN_BMM),
)
print('FlashInfer TRTLLM BMM headers downloaded successfully')
PYEOF
&& flashinfer show-config \
&& flashinfer download-cubin
# ============================================================
# OPENAI API SERVER DEPENDENCIES
......
......@@ -217,16 +217,13 @@ RUN pip install setuptools==75.6.0 packaging==23.2 ninja==1.11.1.3 build==1.2.2.
# build flashinfer for torch nightly from source around 10 mins
# release version: v0.6.7
# 0.6.7: CUTLASS 4.4.2 bump, fixes TMA grouped GEMM on SM12x (flashinfer#2798)
# TODO: bump to 0.6.8 when released for NVFP4/MXFP4 group GEMMs on
# SM120/SM121 (RTX 50 / DGX Spark) via flashinfer#2738
# release version: v0.6.8.post1
# todo(elainewy): cache flashinfer build result for faster build
ENV CCACHE_DIR=/root/.cache/ccache
RUN --mount=type=cache,target=/root/.cache/ccache \
--mount=type=cache,target=/root/.cache/uv \
echo "git clone flashinfer..." \
&& git clone --depth 1 --branch v0.6.7 --recursive https://github.com/flashinfer-ai/flashinfer.git \
&& git clone --depth 1 --branch v0.6.8.post1 --recursive https://github.com/flashinfer-ai/flashinfer.git \
&& cd flashinfer \
&& git submodule update --init --recursive \
&& echo "finish git clone flashinfer..." \
......
......@@ -65,7 +65,7 @@
"default": "true"
},
"FLASHINFER_VERSION": {
"default": "0.6.7"
"default": "0.6.8.post1"
},
"GDRCOPY_CUDA_VERSION": {
"default": "12.8"
......
......@@ -169,7 +169,7 @@ Priority is **1 = highest** (tried first).
| ------- | ------- | ------ | --------- | ----------- | ---------- | ---- | --------- | --- | --------------- | ------------ |
| `CPU_ATTN` | | fp16, bf16, fp32 | `auto` | Any | 32, 64, 80, 96, 112, 128, 160, 192, 224, 256, 512 | ❌ | ❌ | ❌ | All | N/A |
| `FLASHINFER` | Native† | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32, 64 | 64, 128, 256 | ❌ | ❌ | ✅ | Decoder | 7.x-9.x |
| `FLASHINFER` | TRTLLM† | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32, 64 | 64, 128, 256 | ✅ | ❌ | ✅ | Decoder | 10.0 |
| `FLASHINFER` | TRTLLM† | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32, 64 | 64, 128, 256 | ✅ | ❌ | ✅ | Decoder | 10.x |
| `FLASH_ATTN` | FA2* | fp16, bf16 | `auto`, `float16`, `bfloat16` | %16 | Any | ❌ | ❌ | ✅ | All | ≥8.0 |
| `FLASH_ATTN` | FA3* | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | Any | ✅ | ❌ | ✅ | All | 9.x |
| `FLASH_ATTN` | FA4* | fp16, bf16 | `auto`, `float16`, `bfloat16` | %16 | Any | ❌ | ❌ | ✅ | All | ≥10.0 |
......
......@@ -9,8 +9,8 @@ torchaudio==2.11.0
# These must be updated alongside torch
torchvision==0.26.0 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
# FlashInfer should be updated together with the Dockerfile
flashinfer-python==0.6.7
flashinfer-cubin==0.6.7
flashinfer-python==0.6.8.post1
flashinfer-cubin==0.6.8.post1
# Cap nvidia-cudnn-frontend (transitive dep of flashinfer) due to
# breaking changes in 1.19.0
nvidia-cudnn-frontend>=1.13.0,<1.19.0
......
......@@ -62,7 +62,7 @@ def test_supports_batch_invariant_disables():
@patch("vllm.envs.VLLM_BATCH_INVARIANT", False)
@patch(
"vllm.utils.flashinfer.current_platform.is_device_capability",
"vllm.utils.flashinfer.current_platform.is_device_capability_family",
return_value=True,
)
@patch("vllm.utils.flashinfer.has_nvidia_artifactory", return_value=True)
......
......@@ -548,7 +548,9 @@ def test_trtllm_gen_mxfp4_fused_moe(
hidden_states, hidden_states_scale = mxfp8_quantize(
hidden_states, is_sf_swizzled_layout=False
)
hidden_states_scale = hidden_states_scale.view(torch.float8_e4m3fn).reshape(-1)
hidden_states_scale = hidden_states_scale.view(torch.float8_e4m3fn).reshape(
*hidden_states.shape[:-1], -1
)
else:
hidden_states_scale = None
......@@ -595,20 +597,20 @@ def test_trtllm_gen_mxfp4_fused_moe(
if beta is not None:
beta = torch.full((num_experts,), beta, device=hidden_states.device)
tg_result = tg_mxfp4_moe(
router_logits,
topk,
num_experts,
intermediate_size,
hidden_size,
hidden_states,
hidden_states_scale,
w13,
w13_scale,
bias13,
w2,
w2_scale,
bias2,
act_type,
router_logits=router_logits,
topk=topk,
num_experts=num_experts,
intermediate_size=intermediate_size,
hidden_size=hidden_size,
hidden_states=hidden_states,
hidden_states_scale=hidden_states_scale,
w13_weight=w13,
w13_weight_scale=w13_scale,
w13_bias=bias13,
w2_weight=w2,
w2_weight_scale=w2_scale,
w2_bias=bias2,
act_type=act_type,
alpha=alpha,
beta=beta,
limit=limit,
......
......@@ -130,14 +130,7 @@ class FlashInferExperts(mk.FusedMoEExpertsModular):
p.is_device_capability(90)
or p.is_device_capability_family(100)
or p.is_device_capability_family(110)
or p.is_device_capability(120)
# NOTE: SM121 (DGX Spark) is excluded because the bf16
# unquantized CUTLASS MoE GEMM in flashinfer <= 0.6.7 has no
# Relu2 template instantiation and throws "Invalid activation
# type" on Nemotron-H. Fixed upstream by
# https://github.com/flashinfer-ai/flashinfer/pull/2926
# (merged 2026-04-01, not yet in a stable release); lift this
# restriction once flashinfer >= 0.6.8 is the minimum.
or p.is_device_capability_family(120)
)
and has_flashinfer_cutlass_fused_moe()
)
......
......@@ -305,10 +305,9 @@ def supports_trtllm_attention() -> bool:
if envs.VLLM_BATCH_INVARIANT:
return False
# TRTLLM attention is currently only validated on SM100 (CC 10.0).
# SM103 (GB300) hangs with FlashInfer >= 0.6.7.
# See: https://github.com/flashinfer-ai/flashinfer/issues/2939
return current_platform.is_device_capability(100) and has_nvidia_artifactory()
return (
current_platform.is_device_capability_family(100) and has_nvidia_artifactory()
)
def force_use_trtllm_attention() -> bool | None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment