Unverified commit 5e49c3e7, authored by elvischenv and committed by GitHub
Browse files

Bump Flashinfer to v0.4.0 (#26326)


Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
parent 0d7c3cb5
...@@ -371,7 +371,7 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist ...@@ -371,7 +371,7 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist
# Install FlashInfer from source # Install FlashInfer from source
ARG FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git" ARG FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git"
# Keep this in sync with "flashinfer" extra in setup.py # Keep this in sync with "flashinfer" extra in setup.py
ARG FLASHINFER_GIT_REF="v0.3.1" ARG FLASHINFER_GIT_REF="v0.4.0"
# Flag to control whether to compile FlashInfer AOT kernels # Flag to control whether to compile FlashInfer AOT kernels
# Set to "true" to enable AOT compilation: # Set to "true" to enable AOT compilation:
# docker build --build-arg FLASHINFER_AOT_COMPILE=true ... # docker build --build-arg FLASHINFER_AOT_COMPILE=true ...
...@@ -392,7 +392,7 @@ RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH' ...@@ -392,7 +392,7 @@ RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH'
FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a 10.0a 12.0" FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a 10.0a 12.0"
fi fi
pushd flashinfer pushd flashinfer
if [[ "${CUDA_VERSION}" == 12.8.* ]] && [ "$TARGETPLATFORM" = "linux/amd64" ]; then if [[ "${CUDA_VERSION}" == 12.8.* ]] && [ "$TARGETPLATFORM" = "linux/amd64" ] && [ "${FLASHINFER_GIT_REF}" = "v0.3.1" ]; then
# NOTE: To make new precompiled wheels, see tools/flashinfer-build.sh # NOTE: To make new precompiled wheels, see tools/flashinfer-build.sh
echo "🏗️ Installing FlashInfer from pre-compiled wheel" echo "🏗️ Installing FlashInfer from pre-compiled wheel"
uv pip install --system https://wheels.vllm.ai/flashinfer-python/flashinfer_python-0.3.1-cp39-abi3-manylinux1_x86_64.whl \ uv pip install --system https://wheels.vllm.ai/flashinfer-python/flashinfer_python-0.3.1-cp39-abi3-manylinux1_x86_64.whl \
......
...@@ -246,7 +246,7 @@ RUN pip install setuptools==75.6.0 packaging==23.2 ninja==1.11.1.3 build==1.2.2. ...@@ -246,7 +246,7 @@ RUN pip install setuptools==75.6.0 packaging==23.2 ninja==1.11.1.3 build==1.2.2.
# build flashinfer for torch nightly from source around 10 mins # build flashinfer for torch nightly from source around 10 mins
# release version: v0.3.1 # release version: v0.4.0
# todo(elainewy): cache flashinfer build result for faster build # todo(elainewy): cache flashinfer build result for faster build
ENV CCACHE_DIR=/root/.cache/ccache ENV CCACHE_DIR=/root/.cache/ccache
RUN --mount=type=cache,target=/root/.cache/ccache \ RUN --mount=type=cache,target=/root/.cache/ccache \
...@@ -254,7 +254,7 @@ RUN --mount=type=cache,target=/root/.cache/ccache \ ...@@ -254,7 +254,7 @@ RUN --mount=type=cache,target=/root/.cache/ccache \
echo "git clone flashinfer..." \ echo "git clone flashinfer..." \
&& git clone --recursive https://github.com/flashinfer-ai/flashinfer.git \ && git clone --recursive https://github.com/flashinfer-ai/flashinfer.git \
&& cd flashinfer \ && cd flashinfer \
&& git checkout v0.3.1 \ && git checkout v0.4.0 \
&& git submodule update --init --recursive \ && git submodule update --init --recursive \
&& echo "finish git clone flashinfer..." \ && echo "finish git clone flashinfer..." \
&& rm -rf build \ && rm -rf build \
......
...@@ -715,7 +715,7 @@ setup( ...@@ -715,7 +715,7 @@ setup(
], # Required for audio processing ], # Required for audio processing
"video": [], # Kept for backwards compatibility "video": [], # Kept for backwards compatibility
# FlashInfer should be updated together with the Dockerfile # FlashInfer should be updated together with the Dockerfile
"flashinfer": ["flashinfer-python==0.3.1"], "flashinfer": ["flashinfer-python==0.4.0"],
# Optional deps for AMD FP4 quantization support # Optional deps for AMD FP4 quantization support
"petit-kernel": ["petit-kernel"], "petit-kernel": ["petit-kernel"],
}, },
......
...@@ -7,9 +7,8 @@ import pytest ...@@ -7,9 +7,8 @@ import pytest
import torch import torch
from tests.kernels.quantization.nvfp4_utils import ( from tests.kernels.quantization.nvfp4_utils import (
FLOAT4_E2M1_MAX,
FLOAT8_E4M3_MAX,
dequantize_nvfp4_to_dtype, dequantize_nvfp4_to_dtype,
get_nvfp4_global_scale,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import round_up from vllm.utils import round_up
...@@ -171,13 +170,12 @@ def test_flashinfer_trtllm_decode_with_baseline( ...@@ -171,13 +170,12 @@ def test_flashinfer_trtllm_decode_with_baseline(
output = torch.empty(ref_query.shape, dtype=dtype) output = torch.empty(ref_query.shape, dtype=dtype)
wrapper.run(ref_query, ref_kv_cache, out=output) wrapper.run(ref_query, ref_kv_cache, out=output)
o_scale = 1.0 o_scale = 1.0
o_sf_scale = None o_sf_scale_float = None
if o_quant_dtype == FP8_DTYPE: if o_quant_dtype == FP8_DTYPE:
_, o_scale = to_float8(output) _, o_scale = to_float8(output)
elif o_quant_dtype == FP4_DTYPE: elif o_quant_dtype == FP4_DTYPE:
o_sf_scale = ( o_sf_scale = get_nvfp4_global_scale(output)
(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(output.flatten(), dim=-1) o_sf_scale_float = o_sf_scale.item()
).to(torch.float32)
# TRTLLM Decode # TRTLLM Decode
if o_quant_dtype == FP4_DTYPE: if o_quant_dtype == FP4_DTYPE:
...@@ -204,7 +202,7 @@ def test_flashinfer_trtllm_decode_with_baseline( ...@@ -204,7 +202,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
bmm1_scale=q_scale * k_scale * sm_scale, bmm1_scale=q_scale * k_scale * sm_scale,
bmm2_scale=v_scale / o_scale, bmm2_scale=v_scale / o_scale,
window_left=window_left, window_left=window_left,
o_sf_scale=o_sf_scale, o_sf_scale=o_sf_scale_float,
out=output_trtllm, out=output_trtllm,
) )
if o_quant_dtype == FP8_DTYPE: if o_quant_dtype == FP8_DTYPE:
...@@ -361,13 +359,12 @@ def test_flashinfer_trtllm_prefill_with_baseline( ...@@ -361,13 +359,12 @@ def test_flashinfer_trtllm_prefill_with_baseline(
output = torch.empty(ref_query.shape, dtype=dtype) output = torch.empty(ref_query.shape, dtype=dtype)
wrapper.run(ref_query, ref_kv_cache, out=output) wrapper.run(ref_query, ref_kv_cache, out=output)
o_scale = 1.0 o_scale = 1.0
o_sf_scale = None o_sf_scale_float = None
if o_quant_dtype == FP8_DTYPE: if o_quant_dtype == FP8_DTYPE:
_, o_scale = to_float8(output) _, o_scale = to_float8(output)
elif o_quant_dtype == FP4_DTYPE: elif o_quant_dtype == FP4_DTYPE:
o_sf_scale = ( o_sf_scale = get_nvfp4_global_scale(output)
(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(output.flatten(), dim=-1) o_sf_scale_float = o_sf_scale.item()
).to(torch.float32)
# TRTLLM Prefill # TRTLLM Prefill
if o_quant_dtype == FP4_DTYPE: if o_quant_dtype == FP4_DTYPE:
...@@ -398,7 +395,7 @@ def test_flashinfer_trtllm_prefill_with_baseline( ...@@ -398,7 +395,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
cum_seq_lens_q=q_indptr, cum_seq_lens_q=q_indptr,
cum_seq_lens_kv=kv_indptr, cum_seq_lens_kv=kv_indptr,
window_left=window_left, window_left=window_left,
o_sf_scale=o_sf_scale, o_sf_scale=o_sf_scale_float,
out=output_trtllm, out=output_trtllm,
) )
if o_quant_dtype == FP8_DTYPE: if o_quant_dtype == FP8_DTYPE:
......
...@@ -66,9 +66,11 @@ def break_fp4_bytes(a, dtype): ...@@ -66,9 +66,11 @@ def break_fp4_bytes(a, dtype):
return values.reshape(m, n * 2).to(dtype=dtype) return values.reshape(m, n * 2).to(dtype=dtype)
def get_nvfp4_global_scale(a: torch.Tensor):
return (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.abs(a).max().to(torch.float32)
def quant_nvfp4_tensor(a: torch.Tensor): def quant_nvfp4_tensor(a: torch.Tensor):
a_global_scale = (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.abs(a).max().to( a_global_scale = get_nvfp4_global_scale(a)
torch.float32
)
a_quant, a_block_scale = scaled_fp4_quant(a, a_global_scale) a_quant, a_block_scale = scaled_fp4_quant(a, a_global_scale)
return a_quant, a_block_scale, a_global_scale return a_quant, a_block_scale, a_global_scale
...@@ -50,7 +50,7 @@ def can_initialize(model: str, extra_args: Optional[list[str]] = None): ...@@ -50,7 +50,7 @@ def can_initialize(model: str, extra_args: Optional[list[str]] = None):
with RemoteOpenAIServer( with RemoteOpenAIServer(
model, model,
server_args, server_args,
max_wait_seconds=1000, # Due to FlashInfer compile max_wait_seconds=1500, # Due to FlashInfer compile
override_hf_configs=dummy_hf_overrides, override_hf_configs=dummy_hf_overrides,
) as server: ) as server:
client = server.get_client() client = server.get_client()
......
...@@ -1199,7 +1199,7 @@ def fast_plan_decode( ...@@ -1199,7 +1199,7 @@ def fast_plan_decode(
qo_indptr_host = _get_range_buf(batch_size + 1, "cpu") qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
try: try:
# Make sure we pass exactly 15 arguments for tensor core version # Make sure we pass exactly 18 arguments for tensor core version
self._plan_info = self._cached_module.plan( self._plan_info = self._cached_module.plan(
self._float_workspace_buffer, self._float_workspace_buffer,
self._int_workspace_buffer, self._int_workspace_buffer,
...@@ -1216,6 +1216,9 @@ def fast_plan_decode( ...@@ -1216,6 +1216,9 @@ def fast_plan_decode(
head_dim, head_dim,
head_dim, head_dim,
False, # causal False, # causal
window_left,
-1, # fixed_split_size
False, # disable_split_kv
) )
except Exception as e: except Exception as e:
raise RuntimeError(f"Error in tensor core plan: {e}") from e raise RuntimeError(f"Error in tensor core plan: {e}") from e
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment