Commit a2217966 authored by 王敏's avatar 王敏
Browse files

Merge remote-tracking branch 'origin/v0.6.2-dev' into v0.6.2-dev

parents 93089fb2 1a493a24
...@@ -41,7 +41,7 @@ set(PYTHON_SUPPORTED_VERSIONS "3.8" "3.9" "3.10" "3.11" "3.12") ...@@ -41,7 +41,7 @@ set(PYTHON_SUPPORTED_VERSIONS "3.8" "3.9" "3.10" "3.11" "3.12")
set(CUDA_SUPPORTED_ARCHS "7.0;7.5;8.0;8.6;8.9;9.0") set(CUDA_SUPPORTED_ARCHS "7.0;7.5;8.0;8.6;8.9;9.0")
# Supported AMD GPU architectures. # Supported AMD GPU architectures.
set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx926;gfx928") set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx926;gfx928;gfx936")
# #
# Supported/expected torch versions for CUDA/ROCm. # Supported/expected torch versions for CUDA/ROCm.
......
...@@ -567,7 +567,7 @@ __global__ void paged_attention_v1_kernel_TC( ...@@ -567,7 +567,7 @@ __global__ void paged_attention_v1_kernel_TC(
const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step, const int blocksparse_block_size, const int blocksparse_head_sliding_step,
const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0) { const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0) {
#ifdef __gfx928__ #if defined(__gfx936__) || defined(__gfx928__)
paged_attention_kernel_TC<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, paged_attention_kernel_TC<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
KV_DTYPE, IS_BLOCK_SPARSE,REUSE_KV_TIMES,use_vmac>( KV_DTYPE, IS_BLOCK_SPARSE,REUSE_KV_TIMES,use_vmac>(
/* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache, /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache,
...@@ -607,7 +607,7 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_kernel_TC( ...@@ -607,7 +607,7 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_kernel_TC(
const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step, const int blocksparse_block_size, const int blocksparse_head_sliding_step,
const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0) { const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0) {
#ifdef __gfx928__ #if defined(__gfx936__) || defined(__gfx928__)
paged_attention_kernel_TC<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, paged_attention_kernel_TC<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
KV_DTYPE, IS_BLOCK_SPARSE, REUSE_KV_TIMES,use_vmac, KV_DTYPE, IS_BLOCK_SPARSE, REUSE_KV_TIMES,use_vmac,
PARTITION_SIZE>( PARTITION_SIZE>(
...@@ -952,7 +952,7 @@ void paged_attention_v1_opt_tc( ...@@ -952,7 +952,7 @@ void paged_attention_v1_opt_tc(
const int64_t attn_masks_stride) { const int64_t attn_masks_stride) {
const bool is_block_sparse = (blocksparse_vert_stride > 1); const bool is_block_sparse = (blocksparse_vert_stride > 1);
if(kv_cache_dtype != "auto"||query.dtype() == at::ScalarType::Float||is_block_sparse|| if(kv_cache_dtype != "auto"||query.dtype() == at::ScalarType::Float||is_block_sparse||
block_size!=16||query.size(2)!=128||get_device_name()!="gfx928"){ block_size!=16||query.size(2)!=128||(get_device_name()!="gfx928" && get_device_name()!="gfx936")){
paged_attention_v1_opt(out,query,key_cache,value_cache,num_kv_heads, paged_attention_v1_opt(out,query,key_cache,value_cache,num_kv_heads,
scale,block_tables,seq_lens,block_size,max_seq_len,alibi_slopes,kv_cache_dtype, scale,block_tables,seq_lens,block_size,max_seq_len,alibi_slopes,kv_cache_dtype,
k_scale,v_scale,tp_rank,blocksparse_local_blocks,blocksparse_vert_stride, k_scale,v_scale,tp_rank,blocksparse_local_blocks,blocksparse_vert_stride,
...@@ -1182,7 +1182,7 @@ void paged_attention_v2_opt_tc( ...@@ -1182,7 +1182,7 @@ void paged_attention_v2_opt_tc(
const int64_t attn_masks_stride) { const int64_t attn_masks_stride) {
const bool is_block_sparse = (blocksparse_vert_stride > 1); const bool is_block_sparse = (blocksparse_vert_stride > 1);
if(kv_cache_dtype != "auto"||query.dtype() == at::ScalarType::Float||is_block_sparse|| if(kv_cache_dtype != "auto"||query.dtype() == at::ScalarType::Float||is_block_sparse||
block_size!=16||query.size(2)!=128||get_device_name()!="gfx928"){ block_size!=16||query.size(2)!=128||(get_device_name()!="gfx928" && get_device_name()!="gfx936")){
paged_attention_v2_opt(out,exp_sums,max_logits,tmp_out,query,key_cache,value_cache,num_kv_heads, paged_attention_v2_opt(out,exp_sums,max_logits,tmp_out,query,key_cache,value_cache,num_kv_heads,
scale,block_tables,seq_lens,block_size,max_seq_len,alibi_slopes,kv_cache_dtype, scale,block_tables,seq_lens,block_size,max_seq_len,alibi_slopes,kv_cache_dtype,
k_scale,v_scale,tp_rank,blocksparse_local_blocks,blocksparse_vert_stride, k_scale,v_scale,tp_rank,blocksparse_local_blocks,blocksparse_vert_stride,
......
...@@ -11,7 +11,7 @@ pytest-asyncio ...@@ -11,7 +11,7 @@ pytest-asyncio
tensorizer>=2.9.0 tensorizer>=2.9.0
setuptools_scm>=8 setuptools_scm>=8
torch == 2.4.0 torch == 2.4.1
triton == 3.0.0 triton == 3.0.0
flash_attn == 2.6.1 flash_attn == 2.6.1
lmslim == 0.2.0 lmslim == 0.2.0
\ No newline at end of file
from ...utils import VLLM_PATH, RemoteOpenAIServer from ...utils import VLLM_PATH, RemoteOpenAIServer
import vllm.envs as envs
chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja" chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
assert chatml_jinja_path.exists() assert chatml_jinja_path.exists()
...@@ -6,6 +7,7 @@ assert chatml_jinja_path.exists() ...@@ -6,6 +7,7 @@ assert chatml_jinja_path.exists()
def run_and_test_dummy_opt_api_server(model, tp=1): def run_and_test_dummy_opt_api_server(model, tp=1):
# the model is registered through the plugin # the model is registered through the plugin
if envs.VLLM_USE_TRITON_FLASH_ATTN:
server_args = [ server_args = [
"--gpu-memory-utilization", "--gpu-memory-utilization",
"0.10", "0.10",
...@@ -18,6 +20,19 @@ def run_and_test_dummy_opt_api_server(model, tp=1): ...@@ -18,6 +20,19 @@ def run_and_test_dummy_opt_api_server(model, tp=1):
"-tp", "-tp",
f"{tp}", f"{tp}",
] ]
else:
server_args = [
"--gpu-memory-utilization",
"0.10",
"--dtype",
"float16",
"--chat-template",
str(chatml_jinja_path),
"--load-format",
"dummy",
"-tp",
f"{tp}",
]
with RemoteOpenAIServer(model, server_args) as server: with RemoteOpenAIServer(model, server_args) as server:
client = server.get_client() client = server.get_client()
completion = client.chat.completions.create( completion = client.chat.completions.create(
...@@ -39,4 +54,5 @@ def run_and_test_dummy_opt_api_server(model, tp=1): ...@@ -39,4 +54,5 @@ def run_and_test_dummy_opt_api_server(model, tp=1):
def test_oot_registration_for_api_server(dummy_opt_path: str): def test_oot_registration_for_api_server(dummy_opt_path: str):
dummy_opt_path="facebook/opt-125m"
run_and_test_dummy_opt_api_server(dummy_opt_path) run_and_test_dummy_opt_api_server(dummy_opt_path)
...@@ -3,7 +3,11 @@ from typing import List, Optional, Tuple ...@@ -3,7 +3,11 @@ from typing import List, Optional, Tuple
import pytest import pytest
import torch import torch
import vllm.attention.backends.flash_attn # noqa: F401 from vllm.utils import is_hip
if is_hip():
import flash_attn
else:
import vllm.attention.backends.flash_attn # noqa: F401
from tests.kernels.utils import opcheck from tests.kernels.utils import opcheck
from vllm.utils import seed_everything from vllm.utils import seed_everything
...@@ -70,16 +74,16 @@ def ref_paged_attn( ...@@ -70,16 +74,16 @@ def ref_paged_attn(
return torch.cat(outputs, dim=0) return torch.cat(outputs, dim=0)
if not is_hip():
@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]]) @pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]])
@pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0]) @pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS) @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@torch.inference_mode() @torch.inference_mode()
def test_flash_attn_with_paged_kv( def test_flash_attn_with_paged_kv(
kv_lens: List[int], kv_lens: List[int],
num_heads: Tuple[int, int], num_heads: Tuple[int, int],
head_size: int, head_size: int,
...@@ -87,7 +91,7 @@ def test_flash_attn_with_paged_kv( ...@@ -87,7 +91,7 @@ def test_flash_attn_with_paged_kv(
block_size: int, block_size: int,
soft_cap: Optional[float], soft_cap: Optional[float],
num_blocks: int, num_blocks: int,
) -> None: ) -> None:
torch.set_default_device("cuda") torch.set_default_device("cuda")
seed_everything(0) seed_everything(0)
num_seqs = len(kv_lens) num_seqs = len(kv_lens)
...@@ -212,7 +216,22 @@ def test_varlen_with_paged_kv( ...@@ -212,7 +216,22 @@ def test_varlen_with_paged_kv(
num_blocks, num_blocks,
(num_seqs, max_num_blocks_per_seq), (num_seqs, max_num_blocks_per_seq),
dtype=torch.int32) dtype=torch.int32)
if is_hip():
output = flash_attn.flash_attn_varlen_func(
q=query,
k=key_cache,
v=value_cache,
cu_seqlens_q=cu_query_lens,
cu_seqlens_k=cu_kv_lens,
max_seqlen_q=max_query_len,
max_seqlen_k=max_kv_len,
softmax_scale=scale,
causal=True,
window_size=window_size,
block_table=block_tables,
softcap=soft_cap if soft_cap is not None else 0,
)
else:
output = torch.ops.vllm.flash_attn_varlen_func( output = torch.ops.vllm.flash_attn_varlen_func(
q=query, q=query,
k=key_cache, k=key_cache,
...@@ -233,6 +252,7 @@ def test_varlen_with_paged_kv( ...@@ -233,6 +252,7 @@ def test_varlen_with_paged_kv(
else: else:
test_utils = ["test_faketensor"] test_utils = ["test_faketensor"]
if not is_hip():
opcheck(torch.ops.vllm.flash_attn_varlen_func, opcheck(torch.ops.vllm.flash_attn_varlen_func,
args=tuple(), args=tuple(),
kwargs=dict( kwargs=dict(
......
...@@ -4,14 +4,16 @@ import time ...@@ -4,14 +4,16 @@ import time
import pytest import pytest
import torch import torch
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask
from vllm.attention.backends.xformers import _make_alibi_bias
from vllm.attention.ops.prefix_prefill import context_attention_fwd from vllm.attention.ops.prefix_prefill import context_attention_fwd
from vllm.utils import is_hip from vllm.utils import is_hip
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, seed_everything from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, seed_everything
if not is_hip():
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask
from vllm.attention.backends.xformers import _make_alibi_bias
NUM_HEADS = [64] NUM_HEADS = [64]
NUM_QUERIES_PER_KV = [1, 8, 64] NUM_QUERIES_PER_KV = [1, 8, 64]
HEAD_SIZES = [128, 96, 24] HEAD_SIZES = [128, 96, 24]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment