"tests/vscode:/vscode.git/clone" did not exist on "7015417fd4910a47263ea34c79c2cdb2ff314fdf"
Unverified Commit 5f1ac1e1 authored by Isotr0py's avatar Isotr0py Committed by GitHub
Browse files

Revert "[v1] Add fp32 support to v1 engine through flex attn" (#19404)

parent 9368cc90
...@@ -183,34 +183,6 @@ def test_env( ...@@ -183,34 +183,6 @@ def test_env(
assert backend.get_name() == expected assert backend.get_name() == expected
@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("use_v1", [True, False])
def test_fp32_fallback(
    device: str,
    use_v1: bool,
    monkeypatch: pytest.MonkeyPatch,
):
    """Test attention backend selection with fp32.

    For each combination of platform (CPU / CUDA) and engine version
    (``VLLM_USE_V1`` set to "1" or "0"), patch the selector's
    ``current_platform`` and check the name of the backend returned by
    ``get_attn_backend`` when both dtype arguments are ``torch.float32``.

    Expected backend names:
        * cpu  + V1     -> "TORCH_SDPA_VLLM_V1"
        * cpu  + non-V1 -> "TORCH_SDPA"
        * cuda + V1     -> "FLEX_ATTENTION"
        * cuda + non-V1 -> "XFORMERS"
    """
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1" if use_v1 else "0")

        if device == "cpu":
            with patch("vllm.attention.selector.current_platform",
                       CpuPlatform()):
                backend = get_attn_backend(16, torch.float32, torch.float32,
                                           16, False)
                # Compute the expected name first. Writing
                # `assert (name == "A" if use_v1 else "B")` parses as
                # `(name == "A") if use_v1 else "B"`, which asserts a
                # non-empty string (always truthy) when use_v1 is False.
                expected = "TORCH_SDPA_VLLM_V1" if use_v1 else "TORCH_SDPA"
                assert backend.get_name() == expected

        elif device == "cuda":
            with patch("vllm.attention.selector.current_platform",
                       CudaPlatform()):
                backend = get_attn_backend(16, torch.float32, torch.float32,
                                           16, False)
                # Same precedence pitfall as the CPU branch: compare against
                # an explicitly selected expected value.
                expected = "FLEX_ATTENTION" if use_v1 else "XFORMERS"
                assert backend.get_name() == expected
def test_flash_attn(monkeypatch: pytest.MonkeyPatch): def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
"""Test FlashAttn validation.""" """Test FlashAttn validation."""
# TODO: When testing for v1, pipe in `use_v1` as an argument to # TODO: When testing for v1, pipe in `use_v1` as an argument to
......
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
import random import random
import pytest import pytest
import torch
from vllm.attention import Attention from vllm.attention import Attention
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
...@@ -400,7 +399,6 @@ def test_load_model_weights_inplace(dist_init, model_runner, model_runner_2): ...@@ -400,7 +399,6 @@ def test_load_model_weights_inplace(dist_init, model_runner, model_runner_2):
def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order(): def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order():
torch.set_default_dtype(torch.float16)
layer_0 = "model.layers.0.self_attn.attn" layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn" layer_1 = "model.layers.1.self_attn.attn"
error_msg = f"{layer_1} must come before the current layer" error_msg = f"{layer_1} must come before the current layer"
...@@ -429,7 +427,6 @@ def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order(): ...@@ -429,7 +427,6 @@ def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order():
def test_init_kv_cache_with_kv_sharing_target_layer_not_exist(): def test_init_kv_cache_with_kv_sharing_target_layer_not_exist():
torch.set_default_dtype(torch.float16)
layer_0 = "model.layers.0.self_attn.attn" layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn" layer_1 = "model.layers.1.self_attn.attn"
invalid_layer = "model.layers.0.cross_attn.attn" invalid_layer = "model.layers.0.cross_attn.attn"
...@@ -458,7 +455,6 @@ def test_init_kv_cache_with_kv_sharing_target_layer_not_exist(): ...@@ -458,7 +455,6 @@ def test_init_kv_cache_with_kv_sharing_target_layer_not_exist():
def test_init_kv_cache_with_kv_sharing_target_same_as_current(): def test_init_kv_cache_with_kv_sharing_target_same_as_current():
torch.set_default_dtype(torch.float16)
layer_0 = "model.layers.0.self_attn.attn" layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn" layer_1 = "model.layers.1.self_attn.attn"
error_msg = f"{layer_1} cannot be the same as the current layer" error_msg = f"{layer_1} cannot be the same as the current layer"
...@@ -487,7 +483,6 @@ def test_init_kv_cache_with_kv_sharing_target_same_as_current(): ...@@ -487,7 +483,6 @@ def test_init_kv_cache_with_kv_sharing_target_same_as_current():
def test_init_kv_cache_without_kv_sharing(): def test_init_kv_cache_without_kv_sharing():
torch.set_default_dtype(torch.float16)
layer_0 = "model.layers.0.self_attn.attn" layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn" layer_1 = "model.layers.1.self_attn.attn"
vllm_config = get_vllm_config() vllm_config = get_vllm_config()
...@@ -555,7 +550,6 @@ def test_init_kv_cache_without_kv_sharing(): ...@@ -555,7 +550,6 @@ def test_init_kv_cache_without_kv_sharing():
def test_init_kv_cache_with_kv_sharing_valid(): def test_init_kv_cache_with_kv_sharing_valid():
torch.set_default_dtype(torch.float16)
layer_0 = "model.layers.0.self_attn.attn" layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn" layer_1 = "model.layers.1.self_attn.attn"
vllm_config = get_vllm_config() vllm_config = get_vllm_config()
......
...@@ -1337,6 +1337,13 @@ class EngineArgs: ...@@ -1337,6 +1337,13 @@ class EngineArgs:
recommend_to_remove=False) recommend_to_remove=False)
return False return False
# Only Fp16 and Bf16 dtypes since we only support FA.
V1_SUPPORTED_DTYPES = [torch.bfloat16, torch.float16]
if model_config.dtype not in V1_SUPPORTED_DTYPES:
_raise_or_fallback(feature_name=f"--dtype {model_config.dtype}",
recommend_to_remove=False)
return False
# No Embedding Models so far. # No Embedding Models so far.
if model_config.task not in ["generate"]: if model_config.task not in ["generate"]:
_raise_or_fallback(feature_name=f"--task {model_config.task}", _raise_or_fallback(feature_name=f"--task {model_config.task}",
......
...@@ -233,10 +233,6 @@ class CudaPlatformBase(Platform): ...@@ -233,10 +233,6 @@ class CudaPlatformBase(Platform):
logger.info_once("Using Triton backend on V1 engine.") logger.info_once("Using Triton backend on V1 engine.")
return ("vllm.v1.attention.backends." return ("vllm.v1.attention.backends."
"triton_attn.TritonAttentionBackend") "triton_attn.TritonAttentionBackend")
if dtype not in (torch.float16, torch.bfloat16):
logger.info_once(
f"Using FlexAttenion backend for {dtype} on V1 engine.")
return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" # noqa: E501
if cls.is_device_capability(100): if cls.is_device_capability(100):
# Prefer FlashInfer for V1 on Blackwell GPUs if installed # Prefer FlashInfer for V1 on Blackwell GPUs if installed
try: try:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment