Unverified commit 20181372, authored by Wentao Ye and committed by GitHub
Browse files

[Feature] Batch invariant nvfp4 linear support (#39322)


Signed-off-by: yewentao256 <zhyanwentao@126.com>
parent a776a48b
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import os
import random
import pytest
import torch
from utils import (
_extract_step_logprobs,
_random_prompt,
skip_unsupported,
)
from vllm import LLM, SamplingParams
# Module-wide pytest marker: skip every test in this file when the installed
# torch build does not expose the ``float8_e4m3fn`` attribute.
pytestmark = pytest.mark.skipif(
    not hasattr(torch, "float8_e4m3fn"),
    reason="NVFP4 tests require torch.float8_e4m3fn support.",
)

# Model identifier used by these tests; override via VLLM_TEST_NVFP4_MODEL.
NVFP4_TEST_MODEL = os.environ.get(
    "VLLM_TEST_NVFP4_MODEL",
    "nm-testing/TinyLlama-1.1B-Chat-v1.0-NVFP4",
)
def _make_llm(max_num_seqs: int, backend: str) -> LLM:
    """Construct the ``LLM`` instance used by the NVFP4 tests.

    The model is ``NVFP4_TEST_MODEL``. Prefix caching is disabled and eager
    execution is enforced. Three settings can be overridden through the
    environment:

    * ``VLLM_NVFP4_TEST_GPU_MEMORY_UTILIZATION`` (default ``"0.05"``)
    * ``VLLM_NVFP4_TEST_MAX_MODEL_LEN`` (default ``"2048"``)
    * ``VLLM_NVFP4_TEST_TP_SIZE`` (default ``"1"``)

    Args:
        max_num_seqs: Forwarded to ``LLM`` as ``max_num_seqs``.
        backend: Attention backend name, forwarded as
            ``attention_config={"backend": backend}``.

    Returns:
        The constructed ``LLM``.

    Raises:
        ValueError: If one of the environment overrides is not a valid number.
    """
    # Parse the overrides up front, in the same order the call consumes them.
    gpu_fraction = float(
        os.getenv("VLLM_NVFP4_TEST_GPU_MEMORY_UTILIZATION", "0.05")
    )
    context_len = int(os.getenv("VLLM_NVFP4_TEST_MAX_MODEL_LEN", "2048"))
    tp_size = int(os.getenv("VLLM_NVFP4_TEST_TP_SIZE", "1"))
    attention_cfg = {"backend": backend}

    return LLM(
        model=NVFP4_TEST_MODEL,
        max_num_seqs=max_num_seqs,
        gpu_memory_utilization=gpu_fraction,
        max_model_len=context_len,
        dtype="auto",
        tensor_parallel_size=tp_size,
        enable_prefix_caching=False,
        enforce_eager=True,
        attention_config=attention_cfg,
    )
@skip_unsupported
@pytest.mark.parametrize("backend", ["FLASH_ATTN"])
def test_dense_nvfp4_generation_is_deterministic_across_batch_sizes_e2e(backend):
    """Check that one prompt generates the same output alone and in a batch.

    A fixed "needle" prompt is first generated on its own to obtain a baseline
    completion and per-step logprobs. It is then placed at a random position
    in several randomly sized batches of random filler prompts, and its
    token ids, text and logprobs are compared against the baseline.

    Environment overrides (defaults in parentheses):
        VLLM_TEST_SEED (12345): seed for Python's ``random`` module.
        VLLM_NVFP4_NEEDLE_TRIALS (2): number of batched trials.
        VLLM_NVFP4_NEEDLE_BATCH_SIZE (8): maximum batch size; must be >= 2.
        VLLM_NVFP4_MIN_PROMPT (32) / VLLM_NVFP4_MAX_PROMPT (96): bounds
            passed to ``_random_prompt`` for the filler prompts.
        VLLM_NVFP4_NEEDLE_TEMPERATURE (0.6), VLLM_NVFP4_NEEDLE_TOP_P (0.95),
        VLLM_NVFP4_NEEDLE_MAX_TOKENS (16): sampling parameters.

    Args:
        backend: Attention backend name forwarded to ``_make_llm``.
    """
    seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
    random.seed(seed)

    num_trials = int(os.getenv("VLLM_NVFP4_NEEDLE_TRIALS", "2"))
    max_batch_size = int(os.getenv("VLLM_NVFP4_NEEDLE_BATCH_SIZE", "8"))
    min_random_prompt = int(os.getenv("VLLM_NVFP4_MIN_PROMPT", "32"))
    max_random_prompt = int(os.getenv("VLLM_NVFP4_MAX_PROMPT", "96"))
    assert max_batch_size >= 2, "Batch size should be >= 2 to test invariance."

    # Clamp the lower bound to 2: with max_batch_size of 2 or 3 the plain
    # ``max_batch_size // 2`` is 1, which would allow a "batch" holding only
    # the needle prompt and make the trial a repeat of the baseline.
    # For max_batch_size >= 4 this equals ``max_batch_size // 2``.
    min_batch_size = max(2, max_batch_size // 2)

    sampling = SamplingParams(
        temperature=float(os.getenv("VLLM_NVFP4_NEEDLE_TEMPERATURE", "0.6")),
        top_p=float(os.getenv("VLLM_NVFP4_NEEDLE_TOP_P", "0.95")),
        max_tokens=int(os.getenv("VLLM_NVFP4_NEEDLE_MAX_TOKENS", "16")),
        seed=20240919,
        logprobs=5,
    )
    needle_prompt = "Write one factual sentence about the moon."

    llm = None
    try:
        llm = _make_llm(max_num_seqs=max_batch_size, backend=backend)

        # Baseline: the needle prompt generated on its own.
        baseline_output = llm.generate([needle_prompt], sampling, use_tqdm=False)[0]
        baseline_completion = baseline_output.outputs[0]
        baseline_logprobs, baseline_token_ids = _extract_step_logprobs(baseline_output)
        assert baseline_logprobs is not None
        assert baseline_token_ids is not None

        for _ in range(num_trials):
            batch_size = random.randint(min_batch_size, max_batch_size)
            needle_pos = random.randint(0, batch_size - 1)

            # Filler prompts are generated in index order, skipping the
            # needle slot, so the sequence of _random_prompt calls is stable.
            prompts: list[str] = [
                needle_prompt
                if idx == needle_pos
                else _random_prompt(min_random_prompt, max_random_prompt)
                for idx in range(batch_size)
            ]

            outputs = llm.generate(prompts, sampling, use_tqdm=False)
            needle_output = outputs[needle_pos]
            needle_completion = needle_output.outputs[0]
            needle_logprobs, needle_token_ids = _extract_step_logprobs(needle_output)
            assert needle_logprobs is not None
            assert needle_token_ids is not None

            # Guard against picking the wrong slot out of the batch.
            assert needle_output.prompt == needle_prompt
            assert needle_completion.token_ids == baseline_completion.token_ids
            assert needle_completion.text == baseline_completion.text
            # NOTE(review): assert_close uses its default tolerances here;
            # confirm whether bitwise equality (rtol=0, atol=0) is intended.
            torch.testing.assert_close(needle_logprobs, baseline_logprobs)
    finally:
        if llm is not None:
            # Best-effort teardown: a failing shutdown must not mask the
            # test outcome.
            with contextlib.suppress(Exception):
                llm.shutdown()
...@@ -109,6 +109,13 @@ def select_nvfp4_linear_backend() -> NvFp4LinearBackend: ...@@ -109,6 +109,13 @@ def select_nvfp4_linear_backend() -> NvFp4LinearBackend:
Select the best available NVFP4 GEMM backend based on environment Select the best available NVFP4 GEMM backend based on environment
configuration and platform capabilities. configuration and platform capabilities.
""" """
if envs.VLLM_BATCH_INVARIANT:
logger.info_once(
"VLLM_BATCH_INVARIANT forces NVFP4 linear to use the emulation "
"backend for deterministic execution."
)
return NvFp4LinearBackend.EMULATION
selected_backend: NvFp4LinearBackend | None = None selected_backend: NvFp4LinearBackend | None = None
if envs.VLLM_USE_FBGEMM: if envs.VLLM_USE_FBGEMM:
...@@ -259,6 +266,8 @@ def apply_nvfp4_linear( ...@@ -259,6 +266,8 @@ def apply_nvfp4_linear(
alpha = layer.alpha alpha = layer.alpha
output_size = layer.output_size_per_partition output_size = layer.output_size_per_partition
input_size = layer.input_size_per_partition input_size = layer.input_size_per_partition
output_dtype = x.dtype
output_shape = [*x.shape[:-1], output_size]
if backend == NvFp4LinearBackend.MARLIN: if backend == NvFp4LinearBackend.MARLIN:
return apply_fp4_marlin_linear( return apply_fp4_marlin_linear(
...@@ -272,20 +281,19 @@ def apply_nvfp4_linear( ...@@ -272,20 +281,19 @@ def apply_nvfp4_linear(
bias=bias, bias=bias,
) )
elif backend == NvFp4LinearBackend.EMULATION: elif backend == NvFp4LinearBackend.EMULATION:
x_2d = x.reshape(-1, x.shape[-1])
out = run_nvfp4_emulations( out = run_nvfp4_emulations(
x=x, x=x_2d,
input_global_scale=input_global_scale_inv, input_global_scale=input_global_scale_inv,
weight=weight, weight=weight,
weight_scale_swizzled=weight_scale, weight_scale_swizzled=weight_scale,
weight_global_scale=weight_global_scale, weight_global_scale=weight_global_scale,
swizzle=swizzle, swizzle=swizzle,
) )
out = out[:, :output_size]
if bias is not None: if bias is not None:
out = out + bias out = out + bias
return out return out.view(*output_shape)
output_dtype = x.dtype
output_shape = [*x.shape[:-1], output_size]
# Quantize BF16 or FP16 to (FP4 and interleaved block scale) # Quantize BF16 or FP16 to (FP4 and interleaved block scale)
x_fp4, x_blockscale = scaled_fp4_quant( x_fp4, x_blockscale = scaled_fp4_quant(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment