Commit 0e607f8e authored by zhuwenwen
Browse files

fix tests of kernels

set VLLM_USE_PD_SPLIT=1
update moe_align_block_size
parent cbdc58ec
......@@ -387,127 +387,127 @@ def ref_multi_query_kv_attention(
return torch.cat(ref_outputs, dim=0)
@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.skipif(current_platform.is_rocm(),
                    reason="Xformers backend is not supported on ROCm.")
@torch.inference_mode()
def test_multi_query_kv_attention(
    num_seqs: int,
    num_heads: tuple[int, int],
    head_size: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    use_alibi: bool = False,
) -> None:
    """Compare xformers prefill attention with the reference implementation.

    Random query/key/value tensors are generated for ``num_seqs`` sequences
    of random length packed along the token dimension. The output of
    ``xops.memory_efficient_attention_forward`` is then checked against
    ``ref_multi_query_kv_attention`` with ``torch.testing.assert_close``.

    Args:
        num_seqs: Number of sequences to sample.
        num_heads: ``(num_query_heads, num_kv_heads)`` pair.
        head_size: Size of each attention head.
        dtype: Data type of the generated tensors.
        seed: Seed passed to ``current_platform.seed_everything``.
        device: Device set as the torch default for this test.
        use_alibi: When true, run one xformers call per sequence with the
            biases returned by ``make_alibi_bias`` instead of a single call
            with a block-diagonal causal mask.
    """
    current_platform.seed_everything(seed)
    torch.set_default_device(device)

    # MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
    # As the xformers library is already tested with its own tests, a
    # smaller upper bound on the sequence length is enough here.
    longest = min(MAX_SEQ_LEN, 4096)
    seq_lens = random.sample(range(1, longest), num_seqs)
    total_tokens = sum(seq_lens)
    scale = float(1.0 / (head_size**0.5))
    num_query_heads, num_kv_heads = num_heads

    # Allocate q, k and v as one tensor so a single uniform_ call fills them.
    packed = torch.empty(total_tokens,
                         num_query_heads + 2 * num_kv_heads,
                         head_size,
                         dtype=dtype)
    packed.uniform_(-scale, scale)
    query, key, value = packed.split(
        [num_query_heads, num_kv_heads, num_kv_heads], dim=1)

    group_size = num_query_heads // num_kv_heads
    if group_size > 1:
        # Handle MQA and GQA: expand k/v to one head per query head.
        key = torch.repeat_interleave(key, group_size, dim=1)
        value = torch.repeat_interleave(value, group_size, dim=1)

    # Cumulative token offsets: sequence i occupies
    # [cu_seq_lens[i], cu_seq_lens[i + 1]).
    cu_seq_lens = [0]
    for length in seq_lens:
        cu_seq_lens.append(cu_seq_lens[-1] + length)

    if not use_alibi:
        alibi_bias = None
        attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
        batched_out = xops.memory_efficient_attention_forward(
            query.unsqueeze(0),
            key.unsqueeze(0),
            value.unsqueeze(0),
            attn_bias=attn_bias,
            p=0.0,
            scale=scale,
        )
        output = batched_out.squeeze(0)
    else:
        alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)
        attn_bias = make_alibi_bias(alibi_slopes, num_kv_heads, dtype,
                                    seq_lens)
        output = torch.empty_like(query)
        # Dynamic sequence length not supported with custom attn_bias,
        # so each sequence is processed by its own call.
        for i, (start, end) in enumerate(zip(cu_seq_lens, cu_seq_lens[1:])):
            q_i = query[start:end]
            out = xops.memory_efficient_attention_forward(
                q_i.unsqueeze(0),
                key[start:end].unsqueeze(0),
                value[start:end].unsqueeze(0),
                attn_bias=attn_bias[i],
                p=0.0,
                scale=scale)
            output[start:end].copy_(out.view_as(q_i))
        # xformers.AttentionBias to Tensor for use in reference impl.
        alibi_bias = [
            bias.materialize((1, num_query_heads, length, length),
                             device=device).squeeze()
            for bias, length in zip(attn_bias, seq_lens)
        ]

    ref_output = ref_multi_query_kv_attention(
        cu_seq_lens,
        query,
        key,
        value,
        scale,
        alibi_bias,
        dtype,
    )

    if current_platform.is_rocm():
        atol = get_default_atol(output)
        rtol = get_default_rtol(output)
    else:
        atol = 1e-3
        rtol = 1e-5
    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", [64])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.skipif(current_platform.is_rocm(),
                    reason="Xformers backend is not supported on ROCm.")
@torch.inference_mode()
def test_multi_query_kv_attention_with_alibi(
    num_seqs: int,
    num_heads: tuple[int, int],
    head_size: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """Run ``test_multi_query_kv_attention`` with ALiBi enabled.

    All parameters are forwarded unchanged; only ``use_alibi=True`` is
    added. ``head_size`` is parametrized with the single value 64.
    """
    return test_multi_query_kv_attention(
        num_seqs=num_seqs,
        num_heads=num_heads,
        head_size=head_size,
        dtype=dtype,
        seed=seed,
        device=device,
        use_alibi=True,
    )
# @pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
# @pytest.mark.parametrize("num_heads", NUM_HEADS)
# @pytest.mark.parametrize("head_size", HEAD_SIZES)
# @pytest.mark.parametrize("dtype", DTYPES)
# @pytest.mark.parametrize("seed", SEEDS)
# @pytest.mark.parametrize("device", CUDA_DEVICES)
# @pytest.mark.skipif(current_platform.is_rocm(),
# reason="Xformers backend is not supported on ROCm.")
# @torch.inference_mode()
# def test_multi_query_kv_attention(
# num_seqs: int,
# num_heads: tuple[int, int],
# head_size: int,
# dtype: torch.dtype,
# seed: int,
# device: str,
# use_alibi: bool = False,
# ) -> None:
# current_platform.seed_everything(seed)
# torch.set_default_device(device)
# # MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
# # As the xformers library is already tested with its own tests, we can use
# # a smaller MAX_SEQ_LEN here.
# max_len = min(MAX_SEQ_LEN, 4096)
# seq_lens = random.sample(range(1, max_len), num_seqs)
# num_tokens = sum(seq_lens)
# scale = float(1.0 / (head_size**0.5))
# num_query_heads, num_kv_heads = num_heads
# qkv = torch.empty(num_tokens,
# num_query_heads + 2 * num_kv_heads,
# head_size,
# dtype=dtype)
# qkv.uniform_(-scale, scale)
# query, key, value = qkv.split(
# [num_query_heads, num_kv_heads, num_kv_heads], dim=1)
# num_queries_per_kv = num_query_heads // num_kv_heads
# if num_queries_per_kv > 1:
# # Handle MQA and GQA
# key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
# value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)
# alibi_bias = None
# if use_alibi:
# alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)
# attn_bias = make_alibi_bias(alibi_slopes, num_kv_heads, dtype,
# seq_lens)
# output = torch.empty_like(query)
# start = 0
# # Dynamic sequence length not supported with custom attn_bias.
# for i, seq_len in enumerate(seq_lens):
# end = start + seq_len
# out = xops.memory_efficient_attention_forward(
# query[None, start:end],
# key[None, start:end],
# value[None, start:end],
# attn_bias=attn_bias[i],
# p=0.0,
# scale=scale)
# output[start:end].copy_(out.view_as(query[start:end]))
# start += seq_len
# # xformers.AttentionBias to Tensor for use in reference impl.
# alibi_bias = [
# b.materialize((1, num_query_heads, i, i), device=device).squeeze()
# for b, i in zip(attn_bias, seq_lens)
# ]
# else:
# attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
# output = xops.memory_efficient_attention_forward(
# query.unsqueeze(0),
# key.unsqueeze(0),
# value.unsqueeze(0),
# attn_bias=attn_bias,
# p=0.0,
# scale=scale,
# )
# output = output.squeeze(0)
# cu_seq_lens = [0]
# for seq_len in seq_lens:
# cu_seq_lens.append(cu_seq_lens[-1] + seq_len)
# ref_output = ref_multi_query_kv_attention(
# cu_seq_lens,
# query,
# key,
# value,
# scale,
# alibi_bias,
# dtype,
# )
# atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
# rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5
# torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
# @pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
# @pytest.mark.parametrize("num_heads", NUM_HEADS)
# @pytest.mark.parametrize("head_size", [64])
# @pytest.mark.parametrize("dtype", DTYPES)
# @pytest.mark.parametrize("seed", SEEDS)
# @pytest.mark.parametrize("device", CUDA_DEVICES)
# @pytest.mark.skipif(current_platform.is_rocm(),
# reason="Xformers backend is not supported on ROCm.")
# @torch.inference_mode()
# def test_multi_query_kv_attention_with_alibi(
# num_seqs: int,
# num_heads: tuple[int, int],
# head_size: int,
# dtype: torch.dtype,
# seed: int,
# device: str,
# ) -> None:
# return test_multi_query_kv_attention(
# num_seqs,
# num_heads,
# head_size,
# dtype,
# seed,
# device,
# use_alibi=True,
# )
@pytest.mark.parametrize("attention_cls", [Attention, MultiHeadAttention])
......
......@@ -447,40 +447,40 @@ def test_swap_blocks(
dist_value_caches[0][dst].cpu())
@pytest.mark.skipif(current_platform.is_rocm(),
                    reason="FP8 is not supported on ROCm.")
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_fp8_e4m3_conversion(
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """Round-trip a random cache tensor through ``ops.convert_fp8``.

    A tensor of shape ``(num_blocks, num_heads, head_size, block_size)`` is
    filled uniformly in ``[-224, 224]``, converted into a ``uint8`` buffer,
    converted back into a buffer of the original dtype, and compared with
    the input using ``atol=0.001`` and ``rtol=0.1``.
    """
    current_platform.seed_everything(seed)

    value_range = (-224.0, 224.0)
    original = torch.empty((num_blocks, num_heads, head_size, block_size),
                           dtype=dtype,
                           device=device)
    original.uniform_(*value_range)

    # Forward conversion into a byte-typed buffer of the same shape.
    encoded = torch.empty_like(original, dtype=torch.uint8)
    ops.convert_fp8(encoded, original)

    # Conversion back into the source dtype.
    decoded = torch.empty_like(original)
    ops.convert_fp8(decoded, encoded)

    torch.testing.assert_close(original, decoded, atol=0.001, rtol=0.1)
# @pytest.mark.skipif(current_platform.is_rocm(),
# reason="FP8 is not supported on ROCm.")
# @pytest.mark.parametrize("num_heads", NUM_HEADS)
# @pytest.mark.parametrize("head_size", HEAD_SIZES)
# @pytest.mark.parametrize("block_size", BLOCK_SIZES)
# @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
# @pytest.mark.parametrize("dtype", DTYPES)
# @pytest.mark.parametrize("seed", SEEDS)
# @pytest.mark.parametrize("device", CUDA_DEVICES)
# @torch.inference_mode()
# def test_fp8_e4m3_conversion(
# num_heads: int,
# head_size: int,
# block_size: int,
# num_blocks: int,
# dtype: torch.dtype,
# seed: int,
# device: str,
# ) -> None:
# current_platform.seed_everything(seed)
# low = -224.0
# high = 224.0
# shape = (num_blocks, num_heads, head_size, block_size)
# cache = torch.empty(shape, dtype=dtype, device=device)
# cache.uniform_(low, high)
# cache_fp8 = torch.empty_like(cache, dtype=torch.uint8)
# ops.convert_fp8(cache_fp8, cache)
# converted_cache = torch.empty_like(cache)
# ops.convert_fp8(converted_cache, cache_fp8)
# torch.testing.assert_close(cache, converted_cache, atol=0.001, rtol=0.1)
def _create_mla_cache(
......@@ -596,117 +596,117 @@ def test_concat_and_cache_mla(
torch.testing.assert_close(kv_cache, ref_kv_cache)
@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS_MLA)
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_concat_and_cache_ds_mla(
    kv_lora_rank: int,
    qk_rope_head_dim: int,
    num_tokens: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """Check ``ops.concat_and_cache_mla`` with the ``"fp8_ds_mla"`` cache dtype.

    A reference cache is built entry by entry: for each token, four
    128-element tiles of ``kv_c`` are converted with ``ops.convert_fp8``
    using a per-tile scale of ``max(|tile|) / 448``, the four scales are
    stored as float32 right after the converted data, and ``k_pe`` is
    stored after the scales in the input dtype. The kernel output is then
    compared with the reference section by section.

    The test is skipped when ``dtype`` is not a 2-byte type.
    """
    if dtype.itemsize != 2:
        pytest.skip("ds_mla only supports 16-bit input")

    kv_cache_dtype = "fp8_ds_mla"
    current_platform.seed_everything(seed)
    torch.set_default_device(device)

    # Each token is assigned a distinct random slot in the paged cache.
    slot_ids = random.sample(range(num_blocks * block_size), num_tokens)
    slot_mapping = torch.tensor(slot_ids, dtype=torch.long, device=device)

    kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device)
    k_pe = torch.randn(num_tokens,
                       qk_rope_head_dim,
                       dtype=dtype,
                       device=device)

    # Bytes per cache entry: kv_lora_rank bytes, 4 * 4 bytes, then
    # 2 * qk_rope_head_dim bytes.
    entry_size = kv_lora_rank + (4 * 4) + (2 * qk_rope_head_dim)
    scale = torch.tensor(1.0, dtype=torch.float32, device=device)
    kv_cache = _create_mla_cache(num_blocks,
                                 block_size,
                                 entry_size,
                                 dtype=torch.uint8,
                                 kv_cache_dtype=kv_cache_dtype,
                                 device=device)

    tile_len = 128
    # Index of the first scale when an entry is viewed as float32.
    scale_base = kv_lora_rank // 4
    # Index of the first rope element when an entry is viewed as `dtype`.
    rope_base = kv_lora_rank // 2 + 8

    ref_cache = torch.zeros_like(kv_cache, dtype=kv_cache.dtype)
    tile_buf = torch.zeros(tile_len, dtype=dtype, device=device)

    for token, slot in enumerate(slot_mapping.tolist()):
        block_idx, block_offset = divmod(slot, block_size)
        ref_entry = ref_cache[block_idx, block_offset]
        ref_entry_16bit = ref_entry.view(dtype)
        ref_entry_32bit = ref_entry.view(torch.float32)
        token_kv_c = kv_c[token]

        for tile_idx in range(4):
            lo = tile_idx * tile_len
            hi = lo + tile_len
            tile_buf[:] = token_kv_c[lo:hi]

            # NOTE: Using torch's amax() gives different results,
            # so the absolute maximum must be manually computed.
            tile_f32 = tile_buf.to(torch.float32)
            abs_max = abs(tile_f32[0])
            for elem in tile_f32[1:]:
                abs_max = max(abs_max, abs(elem))
            tile_scale = abs_max / 448.

            ref_entry_32bit[scale_base + tile_idx] = tile_scale
            ops.convert_fp8(ref_entry[lo:hi],
                            tile_buf,
                            tile_scale.item(),
                            kv_dtype="fp8")

        for j in range(qk_rope_head_dim):
            ref_entry_16bit[rope_base + j] = k_pe[token, j]

    kernel_args = (kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale)
    opcheck(
        torch.ops._C_cache_ops.concat_and_cache_mla,
        kernel_args,
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )
    ops.concat_and_cache_mla(*kernel_args)

    # Compare only the written entries, one section at a time.
    for slot in slot_mapping.tolist():
        block_idx, block_offset = divmod(slot, block_size)
        got = kv_cache[block_idx, block_offset]
        want = ref_cache[block_idx, block_offset]

        torch.testing.assert_close(got[:kv_lora_rank],
                                   want[:kv_lora_rank],
                                   atol=0.001,
                                   rtol=0.1)
        torch.testing.assert_close(
            got.view(torch.float32)[scale_base:scale_base + 4],
            want.view(torch.float32)[scale_base:scale_base + 4],
            atol=0.001,
            rtol=0.1)
        torch.testing.assert_close(got.view(dtype)[rope_base:],
                                   want.view(dtype)[rope_base:],
                                   atol=0.001,
                                   rtol=0.1)
# @pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
# @pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
# @pytest.mark.parametrize("num_tokens", NUM_TOKENS_MLA)
# @pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
# @pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
# @pytest.mark.parametrize("dtype", DTYPES)
# @pytest.mark.parametrize("seed", SEEDS)
# @pytest.mark.parametrize("device", CUDA_DEVICES)
# @torch.inference_mode()
# def test_concat_and_cache_ds_mla(
# kv_lora_rank: int,
# qk_rope_head_dim: int,
# num_tokens: int,
# block_size: int,
# num_blocks: int,
# dtype: torch.dtype,
# seed: int,
# device: str,
# ) -> None:
# if dtype.itemsize != 2:
# pytest.skip("ds_mla only supports 16-bit input")
# kv_cache_dtype = "fp8_ds_mla"
# current_platform.seed_everything(seed)
# torch.set_default_device(device)
# total_slots = num_blocks * block_size
# slot_mapping_lst = random.sample(range(total_slots), num_tokens)
# slot_mapping = torch.tensor(slot_mapping_lst,
# dtype=torch.long,
# device=device)
# kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device)
# k_pe = torch.randn(num_tokens,
# qk_rope_head_dim,
# dtype=dtype,
# device=device)
# entry_size = kv_lora_rank + (4 * 4) + (2 * qk_rope_head_dim)
# scale = torch.tensor(1.0, dtype=torch.float32, device=device)
# kv_cache = _create_mla_cache(num_blocks,
# block_size,
# entry_size,
# dtype=torch.uint8,
# kv_cache_dtype=kv_cache_dtype,
# device=device)
# ref_cache = torch.zeros_like(kv_cache, dtype=kv_cache.dtype)
# tile_data = torch.zeros(128, dtype=dtype, device=device)
# for i in range(num_tokens):
# slot = slot_mapping[i].item()
# block_idx = slot // block_size
# block_offset = slot % block_size
# ref_cache_slice = ref_cache[block_idx, block_offset]
# ref_cache_16bit = ref_cache_slice.view(dtype)
# ref_cache_32bit = ref_cache_slice.view(torch.float32)
# kv_c_data = kv_c[i]
# for tile_idx in range(4):
# tile_start = tile_idx * 128
# tile_end = (tile_idx + 1) * 128
# tile_data[:] = kv_c_data[tile_start:tile_end]
# # tile_scale = tile_data.amax().to(torch.float32) / 448.
# # NOTE: Using torch's amax() gives different results,
# # so this must be manually computed.
# tile_data_float = tile_data.to(torch.float32)
# manual_max = abs(tile_data_float[0])
# for j in range(1, 128):
# manual_max = max(manual_max, abs(tile_data_float[j]))
# tile_scale = manual_max / 448.
# ref_cache_32bit[kv_lora_rank // 4 + tile_idx] = tile_scale
# ops.convert_fp8(ref_cache_slice[tile_start:tile_end],
# tile_data,
# tile_scale.item(),
# kv_dtype="fp8")
# for j in range(qk_rope_head_dim):
# ref_cache_16bit[kv_lora_rank // 2 + 8 + j] = k_pe[i, j]
# opcheck(
# torch.ops._C_cache_ops.concat_and_cache_mla,
# (kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale),
# test_utils=DEFAULT_OPCHECK_TEST_UTILS,
# )
# ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping,
# kv_cache_dtype, scale)
# for i in range(num_tokens):
# slot = slot_mapping[i].item()
# block_idx = slot // block_size
# block_offset = slot % block_size
# kv_cache_slice = kv_cache[block_idx, block_offset]
# ref_cache_slice = ref_cache[block_idx, block_offset]
# kv_nope = kv_cache_slice[:kv_lora_rank]
# ref_nope = ref_cache_slice[:kv_lora_rank]
# kv_scales = kv_cache_slice.view(torch.float32)[kv_lora_rank //
# 4:kv_lora_rank // 4 + 4]
# ref_scales = ref_cache_slice.view(
# torch.float32)[kv_lora_rank // 4:kv_lora_rank // 4 + 4]
# kv_rope = kv_cache_slice.view(dtype)[kv_lora_rank // 2 + 8:]
# ref_rope = ref_cache_slice.view(dtype)[kv_lora_rank // 2 + 8:]
# torch.testing.assert_close(kv_nope, ref_nope, atol=0.001, rtol=0.1)
# torch.testing.assert_close(kv_scales, ref_scales, atol=0.001, rtol=0.1)
# torch.testing.assert_close(kv_rope, ref_rope, atol=0.001, rtol=0.1)
@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
......@@ -993,70 +993,70 @@ def test_cp_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size,
torch.testing.assert_close(dst, expected)
@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS_MLA)
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.cpu_model
@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only")
@torch.inference_mode()
def test_concat_and_cache_mla_cpu(
    kv_lora_rank: int,
    qk_rope_head_dim: int,
    num_tokens: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
) -> None:
    """Check ``ops.concat_and_cache_mla`` on CPU with the ``"auto"`` cache dtype.

    The expected cache is built by writing ``kv_c[i]`` followed by
    ``k_pe[i]`` into the slot assigned to token ``i``; the kernel output
    must match it under ``torch.testing.assert_close`` defaults.
    """
    device = "cpu"
    kv_cache_dtype = "auto"
    current_platform.seed_everything(seed)
    torch.set_default_device(device)

    # Each token is assigned a distinct random slot in the paged cache.
    slot_ids = random.sample(range(num_blocks * block_size), num_tokens)
    slot_mapping = torch.tensor(slot_ids, dtype=torch.long, device=device)

    kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device)
    k_pe = torch.randn(num_tokens,
                       qk_rope_head_dim,
                       dtype=dtype,
                       device=device)

    scale = torch.tensor(0.1, dtype=torch.float32, device=device)
    kv_cache = _create_mla_cache(num_blocks, block_size,
                                 kv_lora_rank + qk_rope_head_dim, dtype,
                                 kv_cache_dtype, device)

    # Reference contents in the input dtype.
    expected = torch.zeros(*kv_cache.shape, dtype=dtype, device=device)
    for token, slot in enumerate(slot_mapping.tolist()):
        block_idx, block_offset = divmod(slot, block_size)
        expected[block_idx, block_offset, :kv_lora_rank] = kv_c[token]
        expected[block_idx, block_offset, kv_lora_rank:] = k_pe[token]

    if kv_cache_dtype != "fp8":
        ref_kv_cache = expected
    else:
        ref_kv_cache = torch.empty_like(expected, dtype=kv_cache.dtype)
        ops.convert_fp8(ref_kv_cache,
                        expected,
                        scale.item(),
                        kv_dtype=kv_cache_dtype)

    kernel_args = (kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale)
    opcheck(
        torch.ops._C_cache_ops.concat_and_cache_mla,
        kernel_args,
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )
    ops.concat_and_cache_mla(*kernel_args)

    torch.testing.assert_close(kv_cache, ref_kv_cache)
# @pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
# @pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
# @pytest.mark.parametrize("num_tokens", NUM_TOKENS_MLA)
# @pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
# @pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
# @pytest.mark.parametrize("dtype", DTYPES)
# @pytest.mark.parametrize("seed", SEEDS)
# @pytest.mark.cpu_model
# @pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only")
# @torch.inference_mode()
# def test_concat_and_cache_mla_cpu(
# kv_lora_rank: int,
# qk_rope_head_dim: int,
# num_tokens: int,
# block_size: int,
# num_blocks: int,
# dtype: torch.dtype,
# seed: int,
# ) -> None:
# device = "cpu"
# kv_cache_dtype = "auto"
# current_platform.seed_everything(seed)
# torch.set_default_device(device)
# total_slots = num_blocks * block_size
# slot_mapping_lst = random.sample(range(total_slots), num_tokens)
# slot_mapping = torch.tensor(slot_mapping_lst,
# dtype=torch.long,
# device=device)
# kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device)
# k_pe = torch.randn(num_tokens,
# qk_rope_head_dim,
# dtype=dtype,
# device=device)
# entry_size = kv_lora_rank + qk_rope_head_dim
# scale = torch.tensor(0.1, dtype=torch.float32, device=device)
# kv_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
# kv_cache_dtype, device)
# ref_temp = torch.zeros(*kv_cache.shape, dtype=dtype, device=device)
# for i in range(num_tokens):
# slot = slot_mapping[i].item()
# block_idx = slot // block_size
# block_offset = slot % block_size
# ref_temp[block_idx, block_offset, :kv_lora_rank] = kv_c[i]
# ref_temp[block_idx, block_offset, kv_lora_rank:] = k_pe[i]
# if kv_cache_dtype == "fp8":
# ref_kv_cache = torch.empty_like(ref_temp, dtype=kv_cache.dtype)
# ops.convert_fp8(ref_kv_cache,
# ref_temp,
# scale.item(),
# kv_dtype=kv_cache_dtype)
# else:
# ref_kv_cache = ref_temp
# opcheck(
# torch.ops._C_cache_ops.concat_and_cache_mla,
# (kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale),
# test_utils=DEFAULT_OPCHECK_TEST_UTILS,
# )
# ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping,
# kv_cache_dtype, scale)
# torch.testing.assert_close(kv_cache, ref_kv_cache)
......@@ -11,6 +11,7 @@ from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
get_mla_metadata,
is_flashmla_supported)
from vllm.triton_utils import triton
from vllm.platforms import current_platform
def cal_diff(x: torch.Tensor,
......@@ -42,7 +43,7 @@ FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \
@pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("varlen", [False, True])
@pytest.mark.parametrize("torch_dtype",
[torch.bfloat16, torch.float16, torch.float8_e4m3fn])
[torch.bfloat16, torch.float16, torch.float8_e4m3fn] if not current_platform.is_rocm() else [torch.bfloat16])
@torch.inference_mode()
def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
varlen, torch_dtype):
......
......@@ -87,7 +87,7 @@ def generate_markdown_table():
@torch.inference_mode()
def test_merge_attn_states(num_tokens: int, num_query_heads: int,
head_size: int, output_dtype: torch.dtype):
if not current_platform.is_cuda():
# Skip only when the platform is neither CUDA nor ROCm. Writing this as
# `not is_cuda() or not is_rocm()` would be true on any platform that is
# not both at once, so the test would be skipped unconditionally.
if not (current_platform.is_cuda() or current_platform.is_rocm()):
    pytest.skip('Currently only support compare triton merge_attn_states '
                'with custom cuda merge_attn_states kernel')
......
......@@ -46,7 +46,7 @@ def test_mha_attn_platform(device: str):
patch("vllm.model_executor.models.vision.current_platform",
RocmPlatform()):
attn = MultiHeadAttention(16, 64, scale=1)
assert attn.attn_backend == _Backend.TORCH_SDPA
assert attn.attn_backend == _Backend.FLASH_ATTN # _Backend.TORCH_SDPA
else:
# Test CUDA with head_size=64 (divisible by 32)
# - should use vLLM's FlashAttention
......
......@@ -73,7 +73,7 @@ def test_contexted_kv_attention(
MAX_CTX_LEN = 1024
BS = 10
cache_size = 640
block_size = 32
block_size = 32 if not current_platform.is_rocm() else 16
max_block_per_request = 64
query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
# ensure one sequence in batch is a decode
......@@ -249,318 +249,318 @@ def test_contexted_kv_attention(
torch.testing.assert_close(output, output_ref, atol=atol, rtol=0)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("op", OPS)
@torch.inference_mode()
def test_contexted_kv_attention_alibi(
num_heads: int,
num_queries_per_kv: int,
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: str,
device: str,
op: Callable,
) -> None:
if 'fp8' in kv_cache_dtype and not current_platform.has_device_capability(
89):
pytest.skip(
'Triton limitation: fp8e4nv data type is not supported on CUDA'
' arch < 89')
current_platform.seed_everything(0)
torch.set_default_device(device)
# Need this, otherwise when we capture the graph the process
# for GPU 1 would run on both GPU0 and GPU1 and things would hang
#
# see also similar issue: https://github.com/Dao-AILab/flash-attention/issues/523
torch.cuda.set_device(device)
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
# Fork from: vllm/vllm/model_executor/models/bloom.py#L44
closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
base = torch.tensor(
2**(-(2**-(math.log2(closest_power_of_2) - 3))),
dtype=torch.float32,
)
powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
slopes = torch.pow(base, powers)
if closest_power_of_2 != total_num_heads:
extra_base = torch.tensor(
2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
dtype=torch.float32,
)
num_remaining_heads = min(closest_power_of_2,
total_num_heads - closest_power_of_2)
extra_powers = torch.arange(start=1,
end=1 + 2 * num_remaining_heads,
step=2,
dtype=torch.int32)
slopes = torch.cat(
[slopes, torch.pow(extra_base, extra_powers)], dim=0)
return slopes
alibi_slopes = _get_alibi_slopes(num_heads).to(device)
MAX_SEQ_LEN = 1024
MAX_CTX_LEN = 1024
BS = 10
cache_size = 640
block_size = 32
max_block_per_request = 64
query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)]
seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)]
num_kv_heads = num_heads // num_queries_per_kv
num_tokens = sum(query_lens)
query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
query.uniform_(-1e-3, 1e-3)
output = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
kv = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size, dtype=dtype)
kv.uniform_(-1e-3, 1e-3)
key, value = kv.unbind(dim=1)
if kv_cache_dtype == "auto":
cache_dtype = dtype
else:
cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype]
k_cache = torch.zeros(cache_size,
block_size,
num_kv_heads,
head_size,
dtype=cache_dtype)
v_cache = torch.zeros(cache_size,
block_size,
num_kv_heads,
head_size,
dtype=cache_dtype)
k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
values = torch.arange(0, cache_size, dtype=torch.long)
values = values[torch.randperm(cache_size)]
block_table = values[:BS * max_block_per_request].view(
BS, max_block_per_request)
b_seq_len = torch.tensor(seq_lens, dtype=torch.long)
b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long)
b_start_loc = torch.cumsum(torch.tensor([0] + query_lens,
dtype=torch.long),
dim=0)
max_input_len = MAX_SEQ_LEN
# copy kv to cache
b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1],
dtype=torch.long),
dim=0)
for i in range(BS):
for j in range(query_lens[i]):
k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] +
j])
v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] +
b_ctx_len[i] + j])
cur_ctx = 0
block_id = 0
while cur_ctx < b_ctx_len[i]:
start_loc = b_seq_start_loc[i] + cur_ctx
if cur_ctx + block_size > b_ctx_len[i]:
end_loc = b_seq_start_loc[i] + b_ctx_len[i]
else:
end_loc = start_loc + block_size
start_slot = block_table[i, block_id] * block_size
end_slot = start_slot + end_loc - start_loc
k_cache.view(-1, num_kv_heads,
head_size)[start_slot:end_slot].copy_(
key[start_loc:end_loc])
v_cache.view(-1, num_kv_heads,
head_size)[start_slot:end_slot].copy_(
value[start_loc:end_loc])
cur_ctx += block_size
block_id += 1
# transpose K_cache[num_blocks, block_size, num_kv_heads, head_size]
# to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8]
k_cache = k_cache.view(-1, block_size, num_kv_heads, head_size // 8,
8).permute(0, 2, 3, 1, 4).contiguous()
# transpose V_cache[num_blocks, block_size, num_kv_heads, head_size]
# to V_cache[num_blocks, num_kv_heads, head_size, block_size]
v_cache = v_cache.view(-1, block_size, num_kv_heads,
head_size).permute(0, 2, 3, 1).contiguous()
k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
# Warm up the Triton kernel by calling it once before actually measuring
# generation time
op(query,
k,
v,
output,
kv_cache_dtype,
k_cache,
v_cache,
block_table,
b_start_loc,
b_seq_len,
MAX_CTX_LEN,
max_input_len,
k_scale,
v_scale,
alibi_slopes=alibi_slopes)
torch.cuda.synchronize()
start_time = time.time()
op(query,
k,
v,
output,
kv_cache_dtype,
k_cache,
v_cache,
block_table,
b_start_loc,
b_seq_len,
MAX_CTX_LEN,
max_input_len,
k_scale,
v_scale,
alibi_slopes=alibi_slopes)
torch.cuda.synchronize()
end_time = time.time()
print(f"triton Time: {(end_time - start_time)*1000:.2f} ms")
# @pytest.mark.parametrize("num_heads", NUM_HEADS)
# @pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
# @pytest.mark.parametrize("head_size", HEAD_SIZES)
# @pytest.mark.parametrize("dtype", DTYPES)
# @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
# @pytest.mark.parametrize("device", CUDA_DEVICES)
# @pytest.mark.parametrize("op", OPS)
# @torch.inference_mode()
# def test_contexted_kv_attention_alibi(
# num_heads: int,
# num_queries_per_kv: int,
# head_size: int,
# dtype: torch.dtype,
# kv_cache_dtype: str,
# device: str,
# op: Callable,
# ) -> None:
# if 'fp8' in kv_cache_dtype and not current_platform.has_device_capability(
# 89):
# pytest.skip(
# 'Triton limitation: fp8e4nv data type is not supported on CUDA'
# ' arch < 89')
# current_platform.seed_everything(0)
# torch.set_default_device(device)
# # Need this, otherwise when we capture the graph the process
# # for GPU 1 would run on both GPU0 and GPU1 and things would hang
# #
# # see also similar issue: https://github.com/Dao-AILab/flash-attention/issues/523
# torch.cuda.set_device(device)
# def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
# # Fork from: vllm/vllm/model_executor/models/bloom.py#L44
# closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
# base = torch.tensor(
# 2**(-(2**-(math.log2(closest_power_of_2) - 3))),
# dtype=torch.float32,
# )
# powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
# slopes = torch.pow(base, powers)
# if closest_power_of_2 != total_num_heads:
# extra_base = torch.tensor(
# 2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
# dtype=torch.float32,
# )
# num_remaining_heads = min(closest_power_of_2,
# total_num_heads - closest_power_of_2)
# extra_powers = torch.arange(start=1,
# end=1 + 2 * num_remaining_heads,
# step=2,
# dtype=torch.int32)
# slopes = torch.cat(
# [slopes, torch.pow(extra_base, extra_powers)], dim=0)
# return slopes
# alibi_slopes = _get_alibi_slopes(num_heads).to(device)
# MAX_SEQ_LEN = 1024
# MAX_CTX_LEN = 1024
# BS = 10
# cache_size = 640
# block_size = 32
# max_block_per_request = 64
# query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
# ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)]
# seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)]
# num_kv_heads = num_heads // num_queries_per_kv
# num_tokens = sum(query_lens)
# query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
# query.uniform_(-1e-3, 1e-3)
# output = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
# kv = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size, dtype=dtype)
# kv.uniform_(-1e-3, 1e-3)
# key, value = kv.unbind(dim=1)
# if kv_cache_dtype == "auto":
# cache_dtype = dtype
# else:
# cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype]
# k_cache = torch.zeros(cache_size,
# block_size,
# num_kv_heads,
# head_size,
# dtype=cache_dtype)
# v_cache = torch.zeros(cache_size,
# block_size,
# num_kv_heads,
# head_size,
# dtype=cache_dtype)
# k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
# v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
# values = torch.arange(0, cache_size, dtype=torch.long)
# values = values[torch.randperm(cache_size)]
# block_table = values[:BS * max_block_per_request].view(
# BS, max_block_per_request)
# b_seq_len = torch.tensor(seq_lens, dtype=torch.long)
# b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long)
# b_start_loc = torch.cumsum(torch.tensor([0] + query_lens,
# dtype=torch.long),
# dim=0)
# max_input_len = MAX_SEQ_LEN
# # copy kv to cache
# b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1],
# dtype=torch.long),
# dim=0)
# for i in range(BS):
# for j in range(query_lens[i]):
# k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] +
# j])
# v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] +
# b_ctx_len[i] + j])
# cur_ctx = 0
# block_id = 0
# while cur_ctx < b_ctx_len[i]:
# start_loc = b_seq_start_loc[i] + cur_ctx
# if cur_ctx + block_size > b_ctx_len[i]:
# end_loc = b_seq_start_loc[i] + b_ctx_len[i]
# else:
# end_loc = start_loc + block_size
# start_slot = block_table[i, block_id] * block_size
# end_slot = start_slot + end_loc - start_loc
# k_cache.view(-1, num_kv_heads,
# head_size)[start_slot:end_slot].copy_(
# key[start_loc:end_loc])
# v_cache.view(-1, num_kv_heads,
# head_size)[start_slot:end_slot].copy_(
# value[start_loc:end_loc])
# cur_ctx += block_size
# block_id += 1
# # transpose K_cache[num_blocks, block_size, num_kv_heads, head_size]
# # to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8]
# k_cache = k_cache.view(-1, block_size, num_kv_heads, head_size // 8,
# 8).permute(0, 2, 3, 1, 4).contiguous()
# # transpose V_cache[num_blocks, block_size, num_kv_heads, head_size]
# # to V_cache[num_blocks, num_kv_heads, head_size, block_size]
# v_cache = v_cache.view(-1, block_size, num_kv_heads,
# head_size).permute(0, 2, 3, 1).contiguous()
# k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
# # Warm up the Triton kernel by calling it once before actually measuring
# # generation time
# op(query,
# k,
# v,
# output,
# kv_cache_dtype,
# k_cache,
# v_cache,
# block_table,
# b_start_loc,
# b_seq_len,
# MAX_CTX_LEN,
# max_input_len,
# k_scale,
# v_scale,
# alibi_slopes=alibi_slopes)
# torch.cuda.synchronize()
# start_time = time.time()
# op(query,
# k,
# v,
# output,
# kv_cache_dtype,
# k_cache,
# v_cache,
# block_table,
# b_start_loc,
# b_seq_len,
# MAX_CTX_LEN,
# max_input_len,
# k_scale,
# v_scale,
# alibi_slopes=alibi_slopes)
# torch.cuda.synchronize()
# end_time = time.time()
# print(f"triton Time: {(end_time - start_time)*1000:.2f} ms")
if not current_platform.is_rocm():
scale = float(1.0 / (head_size**0.5))
# NOTE(DefTruth): In order to reuse _make_alibi_bias function,
# we have to pad query tensor before MQA/GQA expanding.
if query.shape[0] != key.shape[0]:
query_pad = torch.empty(sum(seq_lens),
num_heads,
head_size,
dtype=dtype)
query_pad.uniform_(-1e-3, 1e-3)
seq_start = 0
query_start = 0
for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)):
seq_end = seq_start + seq_len
query_end = query_start + query_len
query_pad[seq_start:seq_end, ...] = torch.cat([
torch.zeros(
seq_len - query_len, num_heads, head_size, dtype=dtype),
query[query_start:query_end, ...]
],
dim=0)
seq_start += seq_len
query_start += query_len
query = query_pad
if num_kv_heads != num_heads:
# As of Nov 2023, xformers only supports MHA. For MQA/GQA,
# project the key and value tensors to the desired number of
# heads.
#
# see also: vllm/model_executor/layers/attention.py
key = key[:, :, None, :].expand(key.shape[0], num_kv_heads,
num_queries_per_kv, key.shape[-1])
value = value[:, :,
None, :].expand(value.shape[0], num_kv_heads,
num_queries_per_kv, value.shape[-1])
# [seq, num_kv_heads, num_queries_per_kv, dk]=>
# [seq, num_kv_heads*num_queries_per_kv, dk] to comply with rest of the
# codebase. We save some time reshaping alibi matrix at runtime.
key = key.reshape(key.shape[0], -1, key.shape[-1])
value = value.reshape(value.shape[0], -1, value.shape[-1])
query = query.unsqueeze(0)
key = key.unsqueeze(0)
value = value.unsqueeze(0)
attn_bias = _make_alibi_bias(alibi_slopes, num_kv_heads, dtype, seq_lens)
output_ref = torch.empty_like(output)
seq_start = 0
query_start = 0
start_time = time.time()
# Attention with alibi slopes.
# FIXME(DefTruth): Because xformers does not support dynamic sequence
# lengths with custom attention bias, we process each prompt one by
# one. This is inefficient, especially when we have many short prompts.
# modified from: vllm/attention/backends/xformers.py#L343
for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)):
seq_end = seq_start + seq_len
query_end = query_start + query_len
out = xops.memory_efficient_attention_forward(query[:,
seq_start:seq_end],
key[:,
seq_start:seq_end],
value[:,
seq_start:seq_end],
attn_bias=attn_bias[i],
p=0.0,
scale=scale)
out = out.view_as(query[:, seq_start:seq_end]).view(
seq_len, num_heads, head_size)
output_ref[query_start:query_end, ...].copy_(out[seq_len - query_len:,
...])
seq_start += seq_len
query_start += query_len
query = query_pad
if num_kv_heads != num_heads:
# As of Nov 2023, xformers only supports MHA. For MQA/GQA,
# project the key and value tensors to the desired number of
# heads.
#
# see also: vllm/model_executor/layers/attention.py
key = key[:, :, None, :].expand(key.shape[0], num_kv_heads,
num_queries_per_kv, key.shape[-1])
value = value[:, :,
None, :].expand(value.shape[0], num_kv_heads,
num_queries_per_kv, value.shape[-1])
# [seq, num_kv_heads, num_queries_per_kv, dk]=>
# [seq, num_kv_heads*num_queries_per_kv, dk] to comply with rest of the
# codebase. We save some time reshaping alibi matrix at runtime.
key = key.reshape(key.shape[0], -1, key.shape[-1])
value = value.reshape(value.shape[0], -1, value.shape[-1])
query = query.unsqueeze(0)
key = key.unsqueeze(0)
value = value.unsqueeze(0)
attn_bias = make_alibi_bias(alibi_slopes, num_kv_heads, dtype, seq_lens)
output_ref = torch.empty_like(output)
seq_start = 0
query_start = 0
if not current_platform.is_rocm():
start_time = time.time()
# Attention with alibi slopes.
# FIXME(DefTruth): Because xformers does not support dynamic sequence
# lengths with custom attention bias, we process each prompt one by
# one. This is inefficient, especially when we have many short prompts.
# modified from: vllm/v1/attention/backends/xformers.py#L343
for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)):
seq_end = seq_start + seq_len
query_end = query_start + query_len
out = xops.memory_efficient_attention_forward(query[:,
seq_start:seq_end],
key[:,
seq_start:seq_end],
value[:,
seq_start:seq_end],
attn_bias=attn_bias[i],
p=0.0,
scale=scale)
out = out.view_as(query[:, seq_start:seq_end]).view(
seq_len, num_heads, head_size)
output_ref[query_start:query_end, ...].copy_(out[seq_len - query_len:,
...])
seq_start += seq_len
query_start += query_len
torch.cuda.synchronize()
end_time = time.time()
print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-6
torch.testing.assert_close(output, output_ref, atol=atol, rtol=0)
# if not current_platform.is_rocm():
# scale = float(1.0 / (head_size**0.5))
# # NOTE(DefTruth): In order to reuse _make_alibi_bias function,
# # we have to pad query tensor before MQA/GQA expanding.
# if query.shape[0] != key.shape[0]:
# query_pad = torch.empty(sum(seq_lens),
# num_heads,
# head_size,
# dtype=dtype)
# query_pad.uniform_(-1e-3, 1e-3)
# seq_start = 0
# query_start = 0
# for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)):
# seq_end = seq_start + seq_len
# query_end = query_start + query_len
# query_pad[seq_start:seq_end, ...] = torch.cat([
# torch.zeros(
# seq_len - query_len, num_heads, head_size, dtype=dtype),
# query[query_start:query_end, ...]
# ],
# dim=0)
# seq_start += seq_len
# query_start += query_len
# query = query_pad
# if num_kv_heads != num_heads:
# # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
# # project the key and value tensors to the desired number of
# # heads.
# #
# # see also: vllm/model_executor/layers/attention.py
# key = key[:, :, None, :].expand(key.shape[0], num_kv_heads,
# num_queries_per_kv, key.shape[-1])
# value = value[:, :,
# None, :].expand(value.shape[0], num_kv_heads,
# num_queries_per_kv, value.shape[-1])
# # [seq, num_kv_heads, num_queries_per_kv, dk]=>
# # [seq, num_kv_heads*num_queries_per_kv, dk] to comply with rest of the
# # codebase. We save some time reshaping alibi matrix at runtime.
# key = key.reshape(key.shape[0], -1, key.shape[-1])
# value = value.reshape(value.shape[0], -1, value.shape[-1])
# query = query.unsqueeze(0)
# key = key.unsqueeze(0)
# value = value.unsqueeze(0)
# attn_bias = _make_alibi_bias(alibi_slopes, num_kv_heads, dtype, seq_lens)
# output_ref = torch.empty_like(output)
# seq_start = 0
# query_start = 0
# start_time = time.time()
# # Attention with alibi slopes.
# # FIXME(DefTruth): Because xformers does not support dynamic sequence
# # lengths with custom attention bias, we process each prompt one by
# # one. This is inefficient, especially when we have many short prompts.
# # modified from: vllm/attention/backends/xformers.py#L343
# for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)):
# seq_end = seq_start + seq_len
# query_end = query_start + query_len
# out = xops.memory_efficient_attention_forward(query[:,
# seq_start:seq_end],
# key[:,
# seq_start:seq_end],
# value[:,
# seq_start:seq_end],
# attn_bias=attn_bias[i],
# p=0.0,
# scale=scale)
# out = out.view_as(query[:, seq_start:seq_end]).view(
# seq_len, num_heads, head_size)
# output_ref[query_start:query_end, ...].copy_(out[seq_len - query_len:,
# ...])
# seq_start += seq_len
# query_start += query_len
# query = query_pad
# if num_kv_heads != num_heads:
# # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
# # project the key and value tensors to the desired number of
# # heads.
# #
# # see also: vllm/model_executor/layers/attention.py
# key = key[:, :, None, :].expand(key.shape[0], num_kv_heads,
# num_queries_per_kv, key.shape[-1])
# value = value[:, :,
# None, :].expand(value.shape[0], num_kv_heads,
# num_queries_per_kv, value.shape[-1])
# # [seq, num_kv_heads, num_queries_per_kv, dk]=>
# # [seq, num_kv_heads*num_queries_per_kv, dk] to comply with rest of the
# # codebase. We save some time reshaping alibi matrix at runtime.
# key = key.reshape(key.shape[0], -1, key.shape[-1])
# value = value.reshape(value.shape[0], -1, value.shape[-1])
# query = query.unsqueeze(0)
# key = key.unsqueeze(0)
# value = value.unsqueeze(0)
# attn_bias = make_alibi_bias(alibi_slopes, num_kv_heads, dtype, seq_lens)
# output_ref = torch.empty_like(output)
# seq_start = 0
# query_start = 0
# if not current_platform.is_rocm():
# start_time = time.time()
# # Attention with alibi slopes.
# # FIXME(DefTruth): Because xformers does not support dynamic sequence
# # lengths with custom attention bias, we process each prompt one by
# # one. This is inefficient, especially when we have many short prompts.
# # modified from: vllm/v1/attention/backends/xformers.py#L343
# for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)):
# seq_end = seq_start + seq_len
# query_end = query_start + query_len
# out = xops.memory_efficient_attention_forward(query[:,
# seq_start:seq_end],
# key[:,
# seq_start:seq_end],
# value[:,
# seq_start:seq_end],
# attn_bias=attn_bias[i],
# p=0.0,
# scale=scale)
# out = out.view_as(query[:, seq_start:seq_end]).view(
# seq_len, num_heads, head_size)
# output_ref[query_start:query_end, ...].copy_(out[seq_len - query_len:,
# ...])
# seq_start += seq_len
# query_start += query_len
# torch.cuda.synchronize()
# end_time = time.time()
# print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
# atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-6
# torch.testing.assert_close(output, output_ref, atol=atol, rtol=0)
# These tests are optional and only run when explicitly invoked
......@@ -595,23 +595,23 @@ def test_contexted_kv_attention_f32(
op)
@pytest.mark.optional
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("op", OPS)
@torch.inference_mode()
def test_contexted_kv_attention_alibi_f32(
    num_heads: int,
    num_queries_per_kv: int,
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    device: str,
    op: Callable,
) -> None:
    """Optional float32 variant of the ALiBi contexted-KV attention test.

    The ``dtype`` parametrization is pinned to ``torch.float32``; every
    argument is forwarded unchanged, in the same positional order, to
    ``test_contexted_kv_attention_alibi``.
    """
    test_contexted_kv_attention_alibi(
        num_heads,
        num_queries_per_kv,
        head_size,
        dtype,
        kv_cache_dtype,
        device,
        op,
    )
# @pytest.mark.optional
# @pytest.mark.parametrize("num_heads", NUM_HEADS)
# @pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
# @pytest.mark.parametrize("head_size", HEAD_SIZES)
# @pytest.mark.parametrize("dtype", [torch.float32])
# @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
# @pytest.mark.parametrize("device", CUDA_DEVICES)
# @pytest.mark.parametrize("op", OPS)
# @torch.inference_mode()
# def test_contexted_kv_attention_alibi_f32(
# num_heads: int,
# num_queries_per_kv: int,
# head_size: int,
# dtype: torch.dtype,
# kv_cache_dtype: str,
# device: str,
# op: Callable,
# ) -> None:
# test_contexted_kv_attention_alibi(num_heads, num_queries_per_kv, head_size,
# dtype, kv_cache_dtype, device, op)
......@@ -21,7 +21,7 @@ NUM_HEADS = [(4, 4), (8, 2)]
HEAD_SIZES = [128, 256]
BLOCK_SIZES = [16]
DTYPES = [torch.bfloat16]
QDTYPES = [None, torch.float8_e4m3fn]
QDTYPES = [None, torch.float8_e4m3fn] if current_platform.is_rocm() else [None]
# one value large enough to test overflow in index calculation.
# one value small enough to test the schema op check
NUM_BLOCKS = [32768, 2048]
......@@ -199,8 +199,8 @@ def test_flash_attn_with_paged_kv(
f"{torch.max(torch.abs(output - ref_output))}"
@pytest.mark.skipif(current_platform.is_rocm(),
reason="varlen_with_paged_kv is not supported on ROCm.")
# @pytest.mark.skipif(current_platform.is_rocm(),
# reason="varlen_with_paged_kv is not supported on ROCm.")
@pytest.mark.parametrize("use_out", [True, False])
@pytest.mark.parametrize("seq_lens",
[[(1, 1328), (5, 18),
......@@ -302,10 +302,10 @@ def test_varlen_with_paged_kv(
window_size=window_size,
block_table=block_tables,
softcap=soft_cap if soft_cap is not None else 0,
fa_version=fa_version,
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
# fa_version=fa_version,
# q_descale=q_descale,
# k_descale=k_descale,
# v_descale=v_descale,
)
output = output if not use_out else out
......
......@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import NamedTuple
import os
import pytest
import torch
from packaging.version import Version
......@@ -10,6 +11,7 @@ from transformers import __version__ as TRANSFORMERS_VERSION
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.platforms import current_platform
from ...utils import models_path_prefix
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
......@@ -48,12 +50,12 @@ class MRoPETestInfo(NamedTuple):
TRANSFORMERS_BASE_VERSION = Version(TRANSFORMERS_VERSION).base_version
MODELS_TO_TEST = [
MRoPETestInfo(model_name="zai-org/GLM-4.1V-9B-Thinking"),
MRoPETestInfo(model_name="Qwen/Qwen2-VL-7B-Instruct"),
MRoPETestInfo(model_name="Qwen/Qwen2-VL-72B-Instruct"),
MRoPETestInfo(model_name="Qwen/Qwen2.5-VL-72B-Instruct"),
MRoPETestInfo(model_name=os.path.join(models_path_prefix, "zai-org/GLM-4.1V-9B-Thinking")),
MRoPETestInfo(model_name=os.path.join(models_path_prefix, "Qwen/Qwen2-VL-7B-Instruct")),
MRoPETestInfo(model_name=os.path.join(models_path_prefix, "Qwen/Qwen2-VL-72B-Instruct")),
# MRoPETestInfo(model_name=os.path.join("Qwen/Qwen2.5-VL-72B-Instruct")),
MRoPETestInfo(
model_name="Qwen/Qwen3-VL-4B-Instruct",
model_name=os.path.join(models_path_prefix, "Qwen/Qwen3-VL-4B-Instruct"),
marks=[
pytest.mark.skipif(
Version(TRANSFORMERS_BASE_VERSION) < Version("4.57.0"),
......@@ -61,7 +63,7 @@ MODELS_TO_TEST = [
)
]),
MRoPETestInfo(
model_name="Qwen/Qwen3-VL-30B-A3B-Instruct",
model_name=os.path.join(models_path_prefix, "Qwen/Qwen3-VL-30B-A3B-Instruct"),
marks=[
pytest.mark.skipif(
Version(TRANSFORMERS_BASE_VERSION) < Version("4.57.0"),
......
......@@ -172,7 +172,7 @@ def torch_moe_align_block_size(
@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("pad_sorted_ids", [False, True])
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
# @pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_moe_align_block_size(m: int, topk: int, num_experts: int,
block_size: int, pad_sorted_ids: bool):
"""Test moe_align_block_size without expert mapping"""
......@@ -235,7 +235,7 @@ def test_moe_align_block_size(m: int, topk: int, num_experts: int,
@pytest.mark.parametrize("topk", [2, 4])
@pytest.mark.parametrize("num_experts", [8])
@pytest.mark.parametrize("block_size", [64])
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
# @pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_moe_align_block_size_with_expert_map(m: int, topk: int,
num_experts: int,
block_size: int):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment