Unverified Commit 9f0247cf authored by Andreas Karatzas's avatar Andreas Karatzas Committed by GitHub
Browse files

`VLLM_USE_TRITON_FLASH_ATTN` V0 variable deprecation (#27611)


Signed-off-by: default avatarAndreas Karatzas <akaratza@amd.com>
Signed-off-by: default avatarAndreas Karatzas <Andreas.Karatzas@amd.com>
parent 7f829be7
...@@ -78,17 +78,13 @@ HF_MOUNT="/root/.cache/huggingface" ...@@ -78,17 +78,13 @@ HF_MOUNT="/root/.cache/huggingface"
commands=$@ commands=$@
echo "Commands:$commands" echo "Commands:$commands"
if [[ $commands == *"pytest -v -s basic_correctness/test_basic_correctness.py"* ]]; then commands=${commands//"pytest -v -s basic_correctness/test_basic_correctness.py"/"pytest -v -s basic_correctness/test_basic_correctness.py"}
commands=${commands//"pytest -v -s basic_correctness/test_basic_correctness.py"/"VLLM_USE_TRITON_FLASH_ATTN=0 pytest -v -s basic_correctness/test_basic_correctness.py"}
fi
if [[ $commands == *"pytest -v -s models/test_registry.py"* ]]; then if [[ $commands == *"pytest -v -s models/test_registry.py"* ]]; then
commands=${commands//"pytest -v -s models/test_registry.py"/"pytest -v -s models/test_registry.py -k 'not BambaForCausalLM and not GritLM and not Mamba2ForCausalLM and not Zamba2ForCausalLM'"} commands=${commands//"pytest -v -s models/test_registry.py"/"pytest -v -s models/test_registry.py -k 'not BambaForCausalLM and not GritLM and not Mamba2ForCausalLM and not Zamba2ForCausalLM'"}
fi fi
if [[ $commands == *"pytest -v -s compile/test_basic_correctness.py"* ]]; then commands=${commands//"pytest -v -s compile/test_basic_correctness.py"/"pytest -v -s compile/test_basic_correctness.py"}
commands=${commands//"pytest -v -s compile/test_basic_correctness.py"/"VLLM_USE_TRITON_FLASH_ATTN=0 pytest -v -s compile/test_basic_correctness.py"}
fi
if [[ $commands == *"pytest -v -s lora"* ]]; then if [[ $commands == *"pytest -v -s lora"* ]]; then
commands=${commands//"pytest -v -s lora"/"VLLM_ROCM_CUSTOM_PAGED_ATTN=0 pytest -v -s lora"} commands=${commands//"pytest -v -s lora"/"VLLM_ROCM_CUSTOM_PAGED_ATTN=0 pytest -v -s lora"}
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for the triton_flash_attention kernel
Run `pytest tests/kernels/test_triton_flash_attention.py`.
"""
import pytest
import torch
from vllm.attention.ops.triton_flash_attention import (
SUPPORTED_LAYOUTS,
MetaData,
compute_alibi_tensor,
scale_fp8,
triton_attention_rocm,
)
from vllm.platforms import current_platform
class ReferenceAttention:
def __init__(
self, Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, use_alibi, dtype, input_metadata
):
self.Z = Z
self.HQ = HQ
self.HK = HK
self.N_CTX_Q = N_CTX_Q
self.N_CTX_K = N_CTX_K
self.D_HEAD = D_HEAD
self.use_alibi = use_alibi
self.dtype = dtype
self.input_metadata = input_metadata
def fwd(self, q, k, v):
scores = (
torch.einsum("bhqd,bhkd->bhqk", q, k).float() * self.input_metadata.sm_scale
)
if self.input_metadata.causal:
mask = torch.tril(
torch.ones(self.N_CTX_Q, self.N_CTX_K, device="cuda"),
diagonal=self.N_CTX_K - self.N_CTX_Q,
)
scores[:, :, mask == 0] = float("-inf")
if self.input_metadata.bias is not None:
scores += self.input_metadata.bias
if self.use_alibi:
scores += compute_alibi_tensor(
self.input_metadata.alibi_slopes, self.N_CTX_Q, self.N_CTX_K
)
p = torch.softmax(scores, dim=-1)
if self.input_metadata.causal:
# If N_CTX_Q > N_CTX_K, there's at least one row of all -infs going
# into softmax. This creates a row of NaNs as -inf - -inf == NaN.
# So we fix this by converting the NaNs to 0s, which is what they
# should be out of the softmax.
nan_mask = torch.isnan(p)
p[nan_mask == 1] = 0
ref_out = torch.einsum("bhqk,bhkd->bhqd", p.to(self.dtype), v)
# compare
if self.input_metadata.layout == "bshd":
ref_out = ref_out.transpose(1, 2).clone()
return ref_out
def fwd_fp8(self, q_quantized, k_quantized, v_quantized):
q = (q_quantized.to(torch.float16) * self.input_metadata.q_descale).to(
self.dtype
)
k = (k_quantized.to(torch.float16) * self.input_metadata.k_descale).to(
self.dtype
)
v = (v_quantized.to(torch.float16) * self.input_metadata.v_descale).to(
self.dtype
)
result = self.fwd(q, k, v)
if self.input_metadata.o_scale is not None:
result, _ = scale_fp8(result, self.input_metadata.o_scale)
return result
def fwd_fp8_kv(self, q, k_quantized, v_quantized):
k_descale, v_descale = (
self.input_metadata.k_descale,
self.input_metadata.v_descale,
)
k_dequantized = (
k_quantized.to(torch.float32) * k_descale.to(torch.float32)
).to(self.dtype)
v_dequantized = (
v_quantized.to(torch.float32) * v_descale.to(torch.float32)
).to(self.dtype)
return self.fwd(q, k_dequantized, v_dequantized)
def varlen_fwd(self, q, k, v, is_mqa=False):
ref_out = torch.empty_like(q)
if is_mqa:
# Make KV look like HQ/HK "groups" of HK. Later, we will reshape so
# the size aligns with Q.
k_ref = k.view(k.shape[0], k.shape[1], 1, k.shape[2]).expand(
-1, -1, self.HQ // self.HK, -1
)
v_ref = v.view(v.shape[0], v.shape[1], 1, v.shape[2]).expand(
-1, -1, self.HQ // self.HK, -1
)
else:
k_ref = k
v_ref = v
for i in range(0, self.input_metadata.num_contexts):
start_q, start_k = (
self.input_metadata.cu_seqlens_q[i],
self.input_metadata.cu_seqlens_k[i],
)
end_q, end_k = (
self.input_metadata.cu_seqlens_q[i + 1],
self.input_metadata.cu_seqlens_k[i + 1],
)
k_curr = k_ref[start_k:end_k]
v_curr = v_ref[start_k:end_k]
if is_mqa:
k_curr = k_curr.reshape(k_curr.shape[0], -1, k_curr.shape[3])
v_curr = v_curr.reshape(v_curr.shape[0], -1, v_curr.shape[3])
scores = torch.einsum("qhd,khd->qhk", q[start_q:end_q], k_curr).float()
p = torch.softmax(scores * self.input_metadata.sm_scale, dim=-1).half()
ref_out[start_q:end_q] = torch.einsum("qhk,khd->qhd", p, v_curr)
return ref_out
def quantize_input(q, k, v, fp8_kv=False, use_o_scale=False):
q_descale = None
if not fp8_kv:
q, q_descale = scale_fp8(q)
k, k_descale = scale_fp8(k)
v, v_descale = scale_fp8(v)
# In real world use case, the p scale would be a parameter trained by the
# model.
p_scale = None
o_scale = torch.rand(1, device="cuda", requires_grad=False) if use_o_scale else None
return q, k, v, q_descale, k_descale, v_descale, p_scale, o_scale
def input_helper(
Z,
HQ,
HK,
N_CTX_Q,
N_CTX_K,
D_HEAD,
dtype,
layout=None,
use_alibi=None,
causal=None,
is_fp8=False,
fp8_kv=False,
use_o_scale=False,
use_bias=False,
):
assert layout in SUPPORTED_LAYOUTS, "Got unsupported layout."
current_platform.seed_everything(0)
# Initialize q, k, v
if layout == "bhsd":
q_tensor_shape = (Z, HQ, N_CTX_Q, D_HEAD)
k_tensor_shape = (Z, HK, N_CTX_K, D_HEAD)
elif layout == "bshd":
q_tensor_shape = (Z, N_CTX_Q, HQ, D_HEAD)
k_tensor_shape = (Z, N_CTX_K, HK, D_HEAD)
if use_alibi:
# for n heads the set of slopes is the geometric sequence that starts
# 2^(-8/n)
alibi_slopes = torch.tensor(
[2 ** (-8 / HQ * i) for i in range(1, HQ + 1)],
dtype=torch.float32,
device="cuda",
).repeat(Z, 1)
else:
alibi_slopes = None
if use_bias:
bias = torch.randn(
(1, HQ, N_CTX_Q, N_CTX_K), dtype=dtype, device="cuda", requires_grad=False
)
else:
bias = None
q = torch.randn(q_tensor_shape, dtype=dtype, device="cuda", requires_grad=False)
k = torch.randn(k_tensor_shape, dtype=dtype, device="cuda", requires_grad=False)
v = torch.randn(k_tensor_shape, dtype=dtype, device="cuda", requires_grad=False)
if is_fp8:
(q, k, v, q_descale, k_descale, v_descale, p_scale, o_scale) = quantize_input(
q, k, v, use_o_scale=use_o_scale, fp8_kv=fp8_kv
)
else:
q_descale = k_descale = v_descale = p_scale = o_scale = None
input_metadata = MetaData(
sm_scale=D_HEAD**-0.5,
max_seqlens_q=N_CTX_Q,
max_seqlens_k=N_CTX_K,
layout=layout,
alibi_slopes=alibi_slopes,
alibi_batch=Z,
alibi_nheads=HQ,
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
p_scale=p_scale,
o_scale=o_scale,
bias=bias,
seqlen_q=N_CTX_Q,
seqlen_k=N_CTX_K,
)
return q, k, v, input_metadata
def varlen_input_helper(
Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, equal_seqlens=False
):
current_platform.seed_everything(0)
# Random sequence lengths. Using N_CTX as kind of max of sum of individual
# seqs
if not equal_seqlens:
max_seqlens_q = N_CTX_Q // Z
max_seqlens_k = N_CTX_K // Z
seqlens_q = torch.randint(1, max_seqlens_q + 1, (Z,), dtype=torch.int32)
seqlens_k = torch.randint(1, max_seqlens_k + 1, (Z,), dtype=torch.int32)
else:
seqlens_q = torch.full((Z,), N_CTX_Q // Z)
seqlens_k = torch.full((Z,), N_CTX_K // Z)
# Calculate cumulative sequence lengths
cu_seqlens_q = torch.cat(
[
torch.tensor([0], dtype=torch.int32),
seqlens_q.cumsum(dim=0, dtype=torch.int32),
]
)
cu_seqlens_k = torch.cat(
[
torch.tensor([0], dtype=torch.int32),
seqlens_k.cumsum(dim=0, dtype=torch.int32),
]
)
cu_seqlens_q = cu_seqlens_q.to(device="cuda")
cu_seqlens_k = cu_seqlens_k.to(device="cuda")
# Initialize q, k, v with variable lengths
total_q = cu_seqlens_q[-1].item()
total_k = cu_seqlens_k[-1].item()
q = (
torch.randn((total_q, HQ, D_HEAD), dtype=dtype, device="cuda")
.normal_(mean=0.0, std=0.5)
.requires_grad_()
)
k = (
torch.randn((total_k, HK, D_HEAD), dtype=dtype, device="cuda")
.normal_(mean=0.0, std=0.5)
.requires_grad_()
)
v = (
torch.randn((total_k, HK, D_HEAD), dtype=dtype, device="cuda")
.normal_(mean=0.0, std=0.5)
.requires_grad_()
)
sm_scale = D_HEAD**-0.5
input_metadata = MetaData(sm_scale=sm_scale)
input_metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k)
return q, k, v, input_metadata
@pytest.mark.parametrize(
"Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD",
[
(1, 48, 12, 1, 1, 64),
(4, 4, 4, 128, 128, 65),
(16, 48, 48, 1, 1, 128),
(64, 48, 24, 3, 3, 128),
(4, 4, 4, 113, 123, 1),
],
)
@pytest.mark.parametrize("causal", [True, False])
@pytest.mark.parametrize("use_alibi", [True, False])
@pytest.mark.parametrize("layout", ["bshd"])
def test_op_fwd(
Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_alibi, layout, dtype=torch.float16
):
current_platform.seed_everything(0)
q, k, v, input_metadata = input_helper(
Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout, use_alibi, causal
)
o = torch.empty_like(q)
# triton implementation
tri_out, _ = triton_attention_rocm(q, k, v, o, input_metadata)
# Transpose here if layout is bshd so we have same reference code for all
# layouts
if layout == "bshd":
q = q.transpose(1, 2).clone()
k = k.transpose(1, 2).clone()
v = v.transpose(1, 2).clone()
# Replicate K and V if using MQA/GQA
if HQ != HK:
k = (
k.view(k.shape[0], k.shape[1], -1, k.shape[2], k.shape[3])
.expand(-1, -1, HQ // HK, -1, -1)
.reshape(k.shape[0], -1, k.shape[2], k.shape[3])
)
v = (
v.view(v.shape[0], v.shape[1], -1, v.shape[2], v.shape[3])
.expand(-1, -1, HQ // HK, -1, -1)
.reshape(v.shape[0], -1, v.shape[2], v.shape[3])
)
ref_impl = ReferenceAttention(
Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, use_alibi, dtype, input_metadata
)
ref_out = ref_impl.fwd(q, k, v)
torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2)
@pytest.mark.parametrize(
"Z, H, N_CTX_Q, N_CTX_K, D_HEAD",
[
(4, 48, 1, 1, 64),
(4, 48, 1, 1, 128),
(4, 48, 3, 3, 128),
(4, 4, 128, 128, 65),
],
)
@pytest.mark.parametrize("causal", [True, False])
@pytest.mark.parametrize("layout", ["bhsd"])
@pytest.mark.parametrize("use_o_scale", [True, False])
@pytest.mark.skipif(
torch.cuda.get_device_capability() < (9, 0),
reason="Triton FP8 requires CUDA 9.0 or higher",
)
def test_op_fwd_fp8(
Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, layout, use_o_scale, dtype=torch.float32
):
current_platform.seed_everything(0)
# Disable grad to save memory it won't run into OOM on CI machine.
# q, k, v, input_metadata = input_helper(Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD,
# dtype, layout)
q_quantized, k_quantized, v_quantized, input_metadata = input_helper(
Z,
H,
H,
N_CTX_Q,
N_CTX_K,
D_HEAD,
dtype,
causal=causal,
layout=layout,
is_fp8=True,
use_o_scale=use_o_scale,
)
o = torch.empty_like(q_quantized) if use_o_scale else None
tri_out, _ = triton_attention_rocm(
q_quantized, k_quantized, v_quantized, o, input_metadata
)
ref_impl = ReferenceAttention(
Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, False, dtype, input_metadata
)
ref_out = ref_impl.fwd_fp8(q_quantized, k_quantized, v_quantized)
# compare
torch.testing.assert_close(
ref_out.to(torch.float32), tri_out.to(torch.float32), atol=7e-2, rtol=2e-1
)
@pytest.mark.parametrize(
"Z, H, N_CTX_Q, N_CTX_K, D_HEAD",
[
(4, 48, 1, 1, 64),
(4, 48, 1, 1, 128),
(4, 48, 3, 3, 128),
(4, 4, 128, 128, 65),
(4, 4, 113, 123, 1),
],
)
@pytest.mark.parametrize("causal", [True, False])
@pytest.mark.parametrize("layout", ["bhsd"])
def test_op_fwd_fp8_kv(
Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, layout, dtype=torch.float32
):
current_platform.seed_everything(0)
q, k_quantized, v_quantized, input_metadata = input_helper(
Z,
H,
H,
N_CTX_Q,
N_CTX_K,
D_HEAD,
dtype,
causal=causal,
layout=layout,
is_fp8=True,
fp8_kv=True,
)
o = torch.empty_like(q)
tri_out, _ = triton_attention_rocm(q, k_quantized, v_quantized, o, input_metadata)
ref_impl = ReferenceAttention(
Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, False, dtype, input_metadata
)
ref_out = ref_impl.fwd_fp8_kv(q, k_quantized, v_quantized)
torch.testing.assert_close(ref_out, tri_out, atol=3e-2, rtol=8e-1)
@pytest.mark.parametrize(
"Z, H, N_CTX_Q, N_CTX_K, D_HEAD",
[
(4, 48, 1, 1, 64),
(4, 48, 1, 1, 128),
(4, 48, 3, 3, 128),
(4, 4, 128, 128, 65),
],
)
@pytest.mark.parametrize("causal", [True, False])
@pytest.mark.parametrize("use_bias", [True])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
def test_op_fwd_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype):
current_platform.seed_everything(0)
q, k, v, input_metadata = input_helper(
Z,
H,
H,
N_CTX_Q,
N_CTX_K,
D_HEAD,
dtype,
layout="bhsd",
causal=causal,
use_bias=use_bias,
)
o = torch.empty_like(q)
# triton implementation
tri_out, _ = triton_attention_rocm(q, k, v, o, input_metadata)
ref_impl = ReferenceAttention(
Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, False, dtype, input_metadata
)
ref_out = ref_impl.fwd(q, k, v)
# compare
torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2)
# NOTE: Uses thd layout, so also tests thd.
@pytest.mark.parametrize(
"Z, H, N_CTX, D_HEAD",
[(1, 48, 256, 64), (4, 48, 512, 64), (16, 48, 512, 64), (64, 48, 128, 128)],
)
@pytest.mark.parametrize("causal", [True, False])
def test_op_varlen_fwd(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16):
q, k, v, input_metadata = varlen_input_helper(Z, H, H, N_CTX, N_CTX, D_HEAD, dtype)
tri_out = torch.empty_like(q)
triton_attention_rocm(q, k, v, tri_out, input_metadata)
ref_impl = ReferenceAttention(
Z, H, H, N_CTX, N_CTX, D_HEAD, False, dtype, input_metadata
)
ref_out = ref_impl.varlen_fwd(q, k, v, is_mqa=False)
torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2)
# NOTE: Uses thd layout, so also tests thd.
@pytest.mark.parametrize(
"Z, HQ, HK, N_CTX, D_HEAD",
[
(2, 48, 24, 128, 64),
(4, 48, 12, 256, 64),
(4, 48, 4, 512, 64),
(4, 64, 16, 128, 128),
],
)
@pytest.mark.parametrize("causal", [False])
def test_op_varlen_mqa_fwd(Z, HQ, HK, N_CTX, D_HEAD, causal, dtype=torch.float16):
q, k, v, input_metadata = varlen_input_helper(
Z, HQ, HK, N_CTX, N_CTX, D_HEAD, dtype
)
tri_out = torch.empty_like(q)
triton_attention_rocm(q, k, v, tri_out, input_metadata)
ref_impl = ReferenceAttention(
Z, HQ, HK, N_CTX, N_CTX, D_HEAD, False, dtype, input_metadata
)
ref_out = ref_impl.varlen_fwd(q, k, v, is_mqa=True)
torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2)
...@@ -27,13 +27,7 @@ def test_models( ...@@ -27,13 +27,7 @@ def test_models(
example_prompts, example_prompts,
model: str, model: str,
dtype: str, dtype: str,
monkeypatch,
) -> None: ) -> None:
if current_platform.is_rocm():
# ROCm Triton FA does not currently support sliding window attention
# switch to use ROCm CK FA backend
monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False")
with vllm_runner(model, max_model_len=512, dtype=dtype) as vllm_model: with vllm_runner(model, max_model_len=512, dtype=dtype) as vllm_model:
vllm_outputs = vllm_model.classify(example_prompts) vllm_outputs = vllm_model.classify(example_prompts)
......
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
import pytest import pytest
from vllm.config import PoolerConfig from vllm.config import PoolerConfig
from vllm.platforms import current_platform
from ...utils import check_embeddings_close from ...utils import check_embeddings_close
...@@ -51,13 +50,7 @@ def test_models( ...@@ -51,13 +50,7 @@ def test_models(
vllm_runner, vllm_runner,
example_prompts, example_prompts,
model, model,
monkeypatch,
) -> None: ) -> None:
if model == "BAAI/bge-multilingual-gemma2" and current_platform.is_rocm():
# ROCm Triton FA does not currently support sliding window attention
# switch to use ROCm CK FA backend
monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False")
vllm_extra_kwargs = {} vllm_extra_kwargs = {}
if model == "ssmits/Qwen2-7B-Instruct-embed-base": if model == "ssmits/Qwen2-7B-Instruct-embed-base":
vllm_extra_kwargs["pooler_config"] = PoolerConfig( vllm_extra_kwargs["pooler_config"] = PoolerConfig(
......
...@@ -2,18 +2,11 @@ ...@@ -2,18 +2,11 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.config.pooler import PoolerConfig from vllm.config.pooler import PoolerConfig
from vllm.platforms import current_platform
def test_idefics_multimodal( def test_idefics_multimodal(
vllm_runner, vllm_runner,
monkeypatch,
) -> None: ) -> None:
if current_platform.is_rocm():
# ROCm Triton FA does not currently support sliding window attention
# switch to use ROCm CK FA backend
monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False")
prompts = [ prompts = [
"Hello, my name is", "Hello, my name is",
"The president of the United States is", "The president of the United States is",
...@@ -59,13 +52,7 @@ def update_config(config): ...@@ -59,13 +52,7 @@ def update_config(config):
def test_gemma_multimodal( def test_gemma_multimodal(
vllm_runner, vllm_runner,
monkeypatch,
) -> None: ) -> None:
if current_platform.is_rocm():
# ROCm Triton FA does not currently support sliding window attention
# switch to use ROCm CK FA backend
monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False")
messages = [ messages = [
{ {
"role": "system", "role": "system",
......
...@@ -76,7 +76,6 @@ def test_prm_models( ...@@ -76,7 +76,6 @@ def test_prm_models(
math_step_prompts, math_step_prompts,
model: str, model: str,
dtype: str, dtype: str,
monkeypatch,
) -> None: ) -> None:
check_transformers_version( check_transformers_version(
"Qwen/Qwen2.5-Math-PRM-7B", max_transformers_version="4.53.2" "Qwen/Qwen2.5-Math-PRM-7B", max_transformers_version="4.53.2"
...@@ -85,11 +84,6 @@ def test_prm_models( ...@@ -85,11 +84,6 @@ def test_prm_models(
if current_platform.is_cpu(): if current_platform.is_cpu():
pytest.skip("CPU only supports V1") pytest.skip("CPU only supports V1")
if current_platform.is_rocm():
# ROCm Triton FA does not currently support sliding window attention
# switch to use ROCm CK FA backend
monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False")
with vllm_runner(model, max_model_len=1024, dtype=dtype) as vllm_model: with vllm_runner(model, max_model_len=1024, dtype=dtype) as vllm_model:
vllm_outputs = vllm_model.reward(math_step_prompts) vllm_outputs = vllm_model.reward(math_step_prompts)
......
...@@ -5,7 +5,6 @@ image, embedding, and video support for different VLMs in vLLM. ...@@ -5,7 +5,6 @@ image, embedding, and video support for different VLMs in vLLM.
""" """
import math import math
import os
from collections import defaultdict from collections import defaultdict
from pathlib import PosixPath from pathlib import PosixPath
...@@ -38,13 +37,6 @@ from .vlm_utils.types import ( ...@@ -38,13 +37,6 @@ from .vlm_utils.types import (
VLMTestType, VLMTestType,
) )
# This hack is needed for phi3v & paligemma models
# ROCm Triton FA can run into shared memory issues with these models,
# use other backends in the meantime
# FIXME (mattwong, gshtrasb, hongxiayan)
if current_platform.is_rocm():
os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0"
COMMON_BROADCAST_SETTINGS = { COMMON_BROADCAST_SETTINGS = {
"test_type": VLMTestType.IMAGE, "test_type": VLMTestType.IMAGE,
"dtype": "half", "dtype": "half",
......
...@@ -11,7 +11,6 @@ from huggingface_hub import snapshot_download ...@@ -11,7 +11,6 @@ from huggingface_hub import snapshot_download
from vllm.assets.image import ImageAsset from vllm.assets.image import ImageAsset
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal.image import rescale_image_size from vllm.multimodal.image import rescale_image_size
from vllm.platforms import current_platform
from ....conftest import ( from ....conftest import (
IMAGE_ASSETS, IMAGE_ASSETS,
...@@ -46,12 +45,6 @@ models = [model_path] ...@@ -46,12 +45,6 @@ models = [model_path]
target_dtype = "half" target_dtype = "half"
# ROCm Triton FA can run into shared memory issues with these models,
# use other backends in the meantime
# FIXME (mattwong, gshtrasb, hongxiayan)
if current_platform.is_rocm():
os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0"
def run_test( def run_test(
hf_runner: type[HfRunner], hf_runner: type[HfRunner],
......
...@@ -14,7 +14,6 @@ from vllm.assets.image import ImageAsset ...@@ -14,7 +14,6 @@ from vllm.assets.image import ImageAsset
from vllm.logprobs import SampleLogprobs from vllm.logprobs import SampleLogprobs
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal.image import convert_image_mode, rescale_image_size from vllm.multimodal.image import convert_image_mode, rescale_image_size
from vllm.platforms import current_platform
from ....conftest import ( from ....conftest import (
IMAGE_ASSETS, IMAGE_ASSETS,
...@@ -68,12 +67,6 @@ def vllm_to_hf_output( ...@@ -68,12 +67,6 @@ def vllm_to_hf_output(
target_dtype = "half" target_dtype = "half"
# ROCm Triton FA can run into shared memory issues with these models,
# use other backends in the meantime
# FIXME (mattwong, gshtrasb, hongxiayan)
if current_platform.is_rocm():
os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0"
def run_test( def run_test(
hf_runner: type[HfRunner], hf_runner: type[HfRunner],
......
...@@ -8,7 +8,6 @@ See also `tests/kernels/moe/test_ocp_mx_moe.py`. ...@@ -8,7 +8,6 @@ See also `tests/kernels/moe/test_ocp_mx_moe.py`.
""" """
import importlib.metadata import importlib.metadata
import os
from dataclasses import dataclass from dataclasses import dataclass
from importlib.util import find_spec from importlib.util import find_spec
...@@ -246,8 +245,6 @@ def test_mxfp4_gsm8k_correctness(config: AccuracyTestConfig): ...@@ -246,8 +245,6 @@ def test_mxfp4_gsm8k_correctness(config: AccuracyTestConfig):
task = "gsm8k" task = "gsm8k"
rtol = 0.03 rtol = 0.03
os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0"
results = lm_eval.simple_evaluate( results = lm_eval.simple_evaluate(
model="vllm", model="vllm",
model_args=config.get_model_args(tp_size=8, model_max_len=38768), model_args=config.get_model_args(tp_size=8, model_max_len=38768),
...@@ -263,8 +260,6 @@ def test_mxfp4_gsm8k_correctness(config: AccuracyTestConfig): ...@@ -263,8 +260,6 @@ def test_mxfp4_gsm8k_correctness(config: AccuracyTestConfig):
and measured_value + rtol > EXPECTED_VALUE and measured_value + rtol > EXPECTED_VALUE
), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
del os.environ["VLLM_USE_TRITON_FLASH_ATTN"]
@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available") @pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available")
@pytest.mark.parametrize("float_dtype", [torch.bfloat16, torch.float16]) @pytest.mark.parametrize("float_dtype", [torch.bfloat16, torch.float16])
......
This diff is collapsed.
...@@ -18,7 +18,6 @@ if TYPE_CHECKING: ...@@ -18,7 +18,6 @@ if TYPE_CHECKING:
VLLM_RINGBUFFER_WARNING_INTERVAL: int = 60 VLLM_RINGBUFFER_WARNING_INTERVAL: int = 60
VLLM_NCCL_SO_PATH: str | None = None VLLM_NCCL_SO_PATH: str | None = None
LD_LIBRARY_PATH: str | None = None LD_LIBRARY_PATH: str | None = None
VLLM_USE_TRITON_FLASH_ATTN: bool = True
VLLM_V1_USE_PREFILL_DECODE_ATTENTION: bool = False VLLM_V1_USE_PREFILL_DECODE_ATTENTION: bool = False
VLLM_FLASH_ATTN_VERSION: int | None = None VLLM_FLASH_ATTN_VERSION: int | None = None
LOCAL_RANK: int = 0 LOCAL_RANK: int = 0
...@@ -521,10 +520,6 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -521,10 +520,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
# when `VLLM_NCCL_SO_PATH` is not set, vllm will try to find the nccl # when `VLLM_NCCL_SO_PATH` is not set, vllm will try to find the nccl
# library file in the locations specified by `LD_LIBRARY_PATH` # library file in the locations specified by `LD_LIBRARY_PATH`
"LD_LIBRARY_PATH": lambda: os.environ.get("LD_LIBRARY_PATH", None), "LD_LIBRARY_PATH": lambda: os.environ.get("LD_LIBRARY_PATH", None),
# flag to control if vllm should use triton flash attention
"VLLM_USE_TRITON_FLASH_ATTN": lambda: (
os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in ("true", "1")
),
# Use separate prefill and decode kernels for V1 attention instead of # Use separate prefill and decode kernels for V1 attention instead of
# the unified triton kernel. # the unified triton kernel.
"VLLM_V1_USE_PREFILL_DECODE_ATTENTION": lambda: ( "VLLM_V1_USE_PREFILL_DECODE_ATTENTION": lambda: (
...@@ -1554,7 +1549,6 @@ def compute_hash() -> str: ...@@ -1554,7 +1549,6 @@ def compute_hash() -> str:
"VLLM_PP_LAYER_PARTITION", "VLLM_PP_LAYER_PARTITION",
"VLLM_MLA_DISABLE", "VLLM_MLA_DISABLE",
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH", "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH",
"VLLM_USE_TRITON_FLASH_ATTN",
"VLLM_USE_TRITON_AWQ", "VLLM_USE_TRITON_AWQ",
"VLLM_DP_RANK", "VLLM_DP_RANK",
"VLLM_DP_SIZE", "VLLM_DP_SIZE",
......
...@@ -49,25 +49,8 @@ _ROCM_UNSUPPORTED_MODELS: list[str] = [] ...@@ -49,25 +49,8 @@ _ROCM_UNSUPPORTED_MODELS: list[str] = []
# Models partially supported by ROCm. # Models partially supported by ROCm.
# Architecture -> Reason. # Architecture -> Reason.
_ROCM_SWA_REASON = ( _ROCM_SWA_REASON = ()
"Sliding window attention (SWA) is not yet supported in " _ROCM_PARTIALLY_SUPPORTED_MODELS: dict[str, str] = {}
"Triton flash attention. For half-precision SWA support, "
"please use CK flash attention by setting "
"`VLLM_USE_TRITON_FLASH_ATTN=0`"
)
_ROCM_PARTIALLY_SUPPORTED_MODELS: dict[str, str] = {
"Qwen2ForCausalLM": _ROCM_SWA_REASON,
"MistralForCausalLM": _ROCM_SWA_REASON,
"MixtralForCausalLM": _ROCM_SWA_REASON,
"PaliGemmaForConditionalGeneration": (
"ROCm flash attention does not yet fully support 32-bit precision on PaliGemma"
),
"Phi3VForCausalLM": (
"ROCm Triton flash attention may run into compilation errors due to "
"excessive use of shared memory. If this happens, disable Triton FA "
"by setting `VLLM_USE_TRITON_FLASH_ATTN=0`"
),
}
_ROCM_DEVICE_ID_NAME_MAP: dict[str, str] = { _ROCM_DEVICE_ID_NAME_MAP: dict[str, str] = {
"0x74a0": "AMD_Instinct_MI300A", "0x74a0": "AMD_Instinct_MI300A",
"0x74a1": "AMD_Instinct_MI300X", "0x74a1": "AMD_Instinct_MI300X",
......
...@@ -37,7 +37,6 @@ _GLOBAL_RUNTIME_DATA = dict[str, str | int | bool]() ...@@ -37,7 +37,6 @@ _GLOBAL_RUNTIME_DATA = dict[str, str | int | bool]()
_USAGE_ENV_VARS_TO_COLLECT = [ _USAGE_ENV_VARS_TO_COLLECT = [
"VLLM_USE_MODELSCOPE", "VLLM_USE_MODELSCOPE",
"VLLM_USE_TRITON_FLASH_ATTN",
"VLLM_ATTENTION_BACKEND", "VLLM_ATTENTION_BACKEND",
"VLLM_USE_FLASHINFER_SAMPLER", "VLLM_USE_FLASHINFER_SAMPLER",
"VLLM_PP_LAYER_PARTITION", "VLLM_PP_LAYER_PARTITION",
......
...@@ -5,22 +5,18 @@ from typing import ClassVar ...@@ -5,22 +5,18 @@ from typing import ClassVar
import torch import torch
from vllm import envs
from vllm.attention.backends.abstract import ( from vllm.attention.backends.abstract import (
AttentionLayer, AttentionLayer,
AttentionType, AttentionType,
is_quantized_kv_cache, is_quantized_kv_cache,
) )
from vllm.attention.ops.triton_decode_attention import decode_attention_fwd from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
from vllm.attention.ops.triton_flash_attention import triton_attention
from vllm.config.cache import CacheDType from vllm.config.cache import CacheDType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import ( from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant, vllm_is_batch_invariant,
) )
from vllm.platforms import current_platform
from vllm.platforms.interface import DeviceCapability from vllm.platforms.interface import DeviceCapability
from vllm.triton_utils import HAS_TRITON
from vllm.v1.attention.backends.mla.common import ( from vllm.v1.attention.backends.mla.common import (
MLACommonBackend, MLACommonBackend,
MLACommonImpl, MLACommonImpl,
...@@ -99,46 +95,9 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]): ...@@ -99,46 +95,9 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
"TritonMLA V1 with FP8 KV cache not yet supported" "TritonMLA V1 with FP8 KV cache not yet supported"
) )
self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN
self.triton_fa_func = triton_attention if HAS_TRITON else None
def _flash_attn_varlen_diff_headdims_rocm(
self, q, k, v, softmax_scale=None, **kwargs
):
assert self.triton_fa_func is not None
# Triton Attention requires a padded V
padded_v = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]], value=0)
# The output of triton_attention is a tuple of
# [output_tensor, encoded_softmax] where encoded_softmax is always None
output_tensor, _ = self.triton_fa_func(
q,
k,
padded_v,
None, # output
kwargs["cu_seqlens_q"],
kwargs["cu_seqlens_k"],
kwargs["max_seqlen_q"],
kwargs["max_seqlen_k"],
kwargs["causal"],
softmax_scale,
None, # bias
)
return output_tensor
def _flash_attn_varlen_diff_headdims( def _flash_attn_varlen_diff_headdims(
self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs
): ):
if (
current_platform.is_rocm()
and self.use_triton_flash_attn
and not return_softmax_lse
):
return self._flash_attn_varlen_diff_headdims_rocm(
q, k, v, softmax_scale=softmax_scale, **kwargs
)
else:
return super()._flash_attn_varlen_diff_headdims( return super()._flash_attn_varlen_diff_headdims(
q, q,
k, k,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment