Unverified commit 191e3fda, authored by bai, committed by GitHub
Browse files

Update flashinfer to 0.6.8 (#39959)


Signed-off-by: bai <v@gor.io>
parent b9cf629b
...@@ -580,33 +580,12 @@ RUN --mount=type=cache,target=/root/.cache/uv \ ...@@ -580,33 +580,12 @@ RUN --mount=type=cache,target=/root/.cache/uv \
# Install FlashInfer JIT cache (requires CUDA-version-specific index URL) # Install FlashInfer JIT cache (requires CUDA-version-specific index URL)
# https://docs.flashinfer.ai/installation.html # https://docs.flashinfer.ai/installation.html
# From versions.json: .flashinfer.version # From versions.json: .flashinfer.version
# 0.6.7: CUTLASS 4.4.2 bump, fixes TMA grouped GEMM on SM12x (flashinfer#2798) ARG FLASHINFER_VERSION=0.6.8.post1
# TODO: bump to 0.6.8 when released for NVFP4/MXFP4 group GEMMs on
# SM120/SM121 (RTX 50 / DGX Spark) via flashinfer#2738
ARG FLASHINFER_VERSION=0.6.7
RUN --mount=type=cache,target=/root/.cache/uv \ RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system flashinfer-jit-cache==${FLASHINFER_VERSION} \ uv pip install --system flashinfer-jit-cache==${FLASHINFER_VERSION} \
--extra-index-url https://flashinfer.ai/whl/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') \ --extra-index-url https://flashinfer.ai/whl/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') \
&& flashinfer show-config && flashinfer show-config \
&& flashinfer download-cubin
# Pre-download FlashInfer TRTLLM BMM headers for air-gapped environments.
# At runtime, MoE JIT compilation downloads these from edge.urm.nvidia.com
# which fails without internet. This step caches them at build time.
RUN python3 <<'PYEOF'
from flashinfer.jit import env as jit_env
from flashinfer.jit.cubin_loader import download_trtllm_headers, get_cubin
from flashinfer.artifacts import ArtifactPath, CheckSumHash
download_trtllm_headers(
'bmm',
jit_env.FLASHINFER_CUBIN_DIR / 'flashinfer' / 'trtllm' / 'batched_gemm' / 'trtllmGen_bmm_export',
f'{ArtifactPath.TRTLLM_GEN_BMM}/include/trtllmGen_bmm_export',
ArtifactPath.TRTLLM_GEN_BMM,
get_cubin(f'{ArtifactPath.TRTLLM_GEN_BMM}/checksums.txt', CheckSumHash.TRTLLM_GEN_BMM),
)
print('FlashInfer TRTLLM BMM headers downloaded successfully')
PYEOF
# ============================================================ # ============================================================
# OPENAI API SERVER DEPENDENCIES # OPENAI API SERVER DEPENDENCIES
......
...@@ -217,16 +217,13 @@ RUN pip install setuptools==75.6.0 packaging==23.2 ninja==1.11.1.3 build==1.2.2. ...@@ -217,16 +217,13 @@ RUN pip install setuptools==75.6.0 packaging==23.2 ninja==1.11.1.3 build==1.2.2.
# build flashinfer for torch nightly from source around 10 mins # build flashinfer for torch nightly from source around 10 mins
# release version: v0.6.7 # release version: v0.6.8.post1
# 0.6.7: CUTLASS 4.4.2 bump, fixes TMA grouped GEMM on SM12x (flashinfer#2798)
# TODO: bump to 0.6.8 when released for NVFP4/MXFP4 group GEMMs on
# SM120/SM121 (RTX 50 / DGX Spark) via flashinfer#2738
# todo(elainewy): cache flashinfer build result for faster build # todo(elainewy): cache flashinfer build result for faster build
ENV CCACHE_DIR=/root/.cache/ccache ENV CCACHE_DIR=/root/.cache/ccache
RUN --mount=type=cache,target=/root/.cache/ccache \ RUN --mount=type=cache,target=/root/.cache/ccache \
--mount=type=cache,target=/root/.cache/uv \ --mount=type=cache,target=/root/.cache/uv \
echo "git clone flashinfer..." \ echo "git clone flashinfer..." \
&& git clone --depth 1 --branch v0.6.7 --recursive https://github.com/flashinfer-ai/flashinfer.git \ && git clone --depth 1 --branch v0.6.8.post1 --recursive https://github.com/flashinfer-ai/flashinfer.git \
&& cd flashinfer \ && cd flashinfer \
&& git submodule update --init --recursive \ && git submodule update --init --recursive \
&& echo "finish git clone flashinfer..." \ && echo "finish git clone flashinfer..." \
......
...@@ -65,7 +65,7 @@ ...@@ -65,7 +65,7 @@
"default": "true" "default": "true"
}, },
"FLASHINFER_VERSION": { "FLASHINFER_VERSION": {
"default": "0.6.7" "default": "0.6.8.post1"
}, },
"GDRCOPY_CUDA_VERSION": { "GDRCOPY_CUDA_VERSION": {
"default": "12.8" "default": "12.8"
......
...@@ -169,7 +169,7 @@ Priority is **1 = highest** (tried first). ...@@ -169,7 +169,7 @@ Priority is **1 = highest** (tried first).
| ------- | ------- | ------ | --------- | ----------- | ---------- | ---- | --------- | --- | --------------- | ------------ | | ------- | ------- | ------ | --------- | ----------- | ---------- | ---- | --------- | --- | --------------- | ------------ |
| `CPU_ATTN` | | fp16, bf16, fp32 | `auto` | Any | 32, 64, 80, 96, 112, 128, 160, 192, 224, 256, 512 | ❌ | ❌ | ❌ | All | N/A | | `CPU_ATTN` | | fp16, bf16, fp32 | `auto` | Any | 32, 64, 80, 96, 112, 128, 160, 192, 224, 256, 512 | ❌ | ❌ | ❌ | All | N/A |
| `FLASHINFER` | Native† | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32, 64 | 64, 128, 256 | ❌ | ❌ | ✅ | Decoder | 7.x-9.x | | `FLASHINFER` | Native† | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32, 64 | 64, 128, 256 | ❌ | ❌ | ✅ | Decoder | 7.x-9.x |
| `FLASHINFER` | TRTLLM† | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32, 64 | 64, 128, 256 | ✅ | ❌ | ✅ | Decoder | 10.0 | | `FLASHINFER` | TRTLLM† | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32, 64 | 64, 128, 256 | ✅ | ❌ | ✅ | Decoder | 10.x |
| `FLASH_ATTN` | FA2* | fp16, bf16 | `auto`, `float16`, `bfloat16` | %16 | Any | ❌ | ❌ | ✅ | All | ≥8.0 | | `FLASH_ATTN` | FA2* | fp16, bf16 | `auto`, `float16`, `bfloat16` | %16 | Any | ❌ | ❌ | ✅ | All | ≥8.0 |
| `FLASH_ATTN` | FA3* | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | Any | ✅ | ❌ | ✅ | All | 9.x | | `FLASH_ATTN` | FA3* | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | Any | ✅ | ❌ | ✅ | All | 9.x |
| `FLASH_ATTN` | FA4* | fp16, bf16 | `auto`, `float16`, `bfloat16` | %16 | Any | ❌ | ❌ | ✅ | All | ≥10.0 | | `FLASH_ATTN` | FA4* | fp16, bf16 | `auto`, `float16`, `bfloat16` | %16 | Any | ❌ | ❌ | ✅ | All | ≥10.0 |
......
...@@ -9,8 +9,8 @@ torchaudio==2.11.0 ...@@ -9,8 +9,8 @@ torchaudio==2.11.0
# These must be updated alongside torch # These must be updated alongside torch
torchvision==0.26.0 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version torchvision==0.26.0 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
# FlashInfer should be updated together with the Dockerfile # FlashInfer should be updated together with the Dockerfile
flashinfer-python==0.6.7 flashinfer-python==0.6.8.post1
flashinfer-cubin==0.6.7 flashinfer-cubin==0.6.8.post1
# Cap nvidia-cudnn-frontend (transitive dep of flashinfer) due to # Cap nvidia-cudnn-frontend (transitive dep of flashinfer) due to
# breaking changes in 1.19.0 # breaking changes in 1.19.0
nvidia-cudnn-frontend>=1.13.0,<1.19.0 nvidia-cudnn-frontend>=1.13.0,<1.19.0
......
...@@ -62,7 +62,7 @@ def test_supports_batch_invariant_disables(): ...@@ -62,7 +62,7 @@ def test_supports_batch_invariant_disables():
@patch("vllm.envs.VLLM_BATCH_INVARIANT", False) @patch("vllm.envs.VLLM_BATCH_INVARIANT", False)
@patch( @patch(
"vllm.utils.flashinfer.current_platform.is_device_capability", "vllm.utils.flashinfer.current_platform.is_device_capability_family",
return_value=True, return_value=True,
) )
@patch("vllm.utils.flashinfer.has_nvidia_artifactory", return_value=True) @patch("vllm.utils.flashinfer.has_nvidia_artifactory", return_value=True)
......
...@@ -548,7 +548,9 @@ def test_trtllm_gen_mxfp4_fused_moe( ...@@ -548,7 +548,9 @@ def test_trtllm_gen_mxfp4_fused_moe(
hidden_states, hidden_states_scale = mxfp8_quantize( hidden_states, hidden_states_scale = mxfp8_quantize(
hidden_states, is_sf_swizzled_layout=False hidden_states, is_sf_swizzled_layout=False
) )
hidden_states_scale = hidden_states_scale.view(torch.float8_e4m3fn).reshape(-1) hidden_states_scale = hidden_states_scale.view(torch.float8_e4m3fn).reshape(
*hidden_states.shape[:-1], -1
)
else: else:
hidden_states_scale = None hidden_states_scale = None
...@@ -595,20 +597,20 @@ def test_trtllm_gen_mxfp4_fused_moe( ...@@ -595,20 +597,20 @@ def test_trtllm_gen_mxfp4_fused_moe(
if beta is not None: if beta is not None:
beta = torch.full((num_experts,), beta, device=hidden_states.device) beta = torch.full((num_experts,), beta, device=hidden_states.device)
tg_result = tg_mxfp4_moe( tg_result = tg_mxfp4_moe(
router_logits, router_logits=router_logits,
topk, topk=topk,
num_experts, num_experts=num_experts,
intermediate_size, intermediate_size=intermediate_size,
hidden_size, hidden_size=hidden_size,
hidden_states, hidden_states=hidden_states,
hidden_states_scale, hidden_states_scale=hidden_states_scale,
w13, w13_weight=w13,
w13_scale, w13_weight_scale=w13_scale,
bias13, w13_bias=bias13,
w2, w2_weight=w2,
w2_scale, w2_weight_scale=w2_scale,
bias2, w2_bias=bias2,
act_type, act_type=act_type,
alpha=alpha, alpha=alpha,
beta=beta, beta=beta,
limit=limit, limit=limit,
......
...@@ -130,14 +130,7 @@ class FlashInferExperts(mk.FusedMoEExpertsModular): ...@@ -130,14 +130,7 @@ class FlashInferExperts(mk.FusedMoEExpertsModular):
p.is_device_capability(90) p.is_device_capability(90)
or p.is_device_capability_family(100) or p.is_device_capability_family(100)
or p.is_device_capability_family(110) or p.is_device_capability_family(110)
or p.is_device_capability(120) or p.is_device_capability_family(120)
# NOTE: SM121 (DGX Spark) is excluded because the bf16
# unquantized CUTLASS MoE GEMM in flashinfer <= 0.6.7 has no
# Relu2 template instantiation and throws "Invalid activation
# type" on Nemotron-H. Fixed upstream by
# https://github.com/flashinfer-ai/flashinfer/pull/2926
# (merged 2026-04-01, not yet in a stable release); lift this
# restriction once flashinfer >= 0.6.8 is the minimum.
) )
and has_flashinfer_cutlass_fused_moe() and has_flashinfer_cutlass_fused_moe()
) )
......
...@@ -305,10 +305,9 @@ def supports_trtllm_attention() -> bool: ...@@ -305,10 +305,9 @@ def supports_trtllm_attention() -> bool:
if envs.VLLM_BATCH_INVARIANT: if envs.VLLM_BATCH_INVARIANT:
return False return False
# TRTLLM attention is currently only validated on SM100 (CC 10.0). return (
# SM103 (GB300) hangs with FlashInfer >= 0.6.7. current_platform.is_device_capability_family(100) and has_nvidia_artifactory()
# See: https://github.com/flashinfer-ai/flashinfer/issues/2939 )
return current_platform.is_device_capability(100) and has_nvidia_artifactory()
def force_use_trtllm_attention() -> bool | None: def force_use_trtllm_attention() -> bool | None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment