Commit 0e607f8e authored by zhuwenwen's avatar zhuwenwen
Browse files

fix tests of kernels

set VLLM_USE_PD_SPLIT=1
update moe_align_block_size
parent cbdc58ec
......@@ -387,127 +387,127 @@ def ref_multi_query_kv_attention(
return torch.cat(ref_outputs, dim=0)
@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.skipif(current_platform.is_rocm(),
reason="Xformers backend is not supported on ROCm.")
@torch.inference_mode()
def test_multi_query_kv_attention(
num_seqs: int,
num_heads: tuple[int, int],
head_size: int,
dtype: torch.dtype,
seed: int,
device: str,
use_alibi: bool = False,
) -> None:
current_platform.seed_everything(seed)
torch.set_default_device(device)
# MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
# As the xformers library is already tested with its own tests, we can use
# a smaller MAX_SEQ_LEN here.
max_len = min(MAX_SEQ_LEN, 4096)
seq_lens = random.sample(range(1, max_len), num_seqs)
num_tokens = sum(seq_lens)
scale = float(1.0 / (head_size**0.5))
num_query_heads, num_kv_heads = num_heads
qkv = torch.empty(num_tokens,
num_query_heads + 2 * num_kv_heads,
head_size,
dtype=dtype)
qkv.uniform_(-scale, scale)
query, key, value = qkv.split(
[num_query_heads, num_kv_heads, num_kv_heads], dim=1)
num_queries_per_kv = num_query_heads // num_kv_heads
if num_queries_per_kv > 1:
# Handle MQA and GQA
key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)
alibi_bias = None
if use_alibi:
alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)
attn_bias = make_alibi_bias(alibi_slopes, num_kv_heads, dtype,
seq_lens)
output = torch.empty_like(query)
start = 0
# Dynamic sequence length not supported with custom attn_bias.
for i, seq_len in enumerate(seq_lens):
end = start + seq_len
out = xops.memory_efficient_attention_forward(
query[None, start:end],
key[None, start:end],
value[None, start:end],
attn_bias=attn_bias[i],
p=0.0,
scale=scale)
output[start:end].copy_(out.view_as(query[start:end]))
start += seq_len
# xformers.AttentionBias to Tensor for use in reference impl.
alibi_bias = [
b.materialize((1, num_query_heads, i, i), device=device).squeeze()
for b, i in zip(attn_bias, seq_lens)
]
else:
attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
output = xops.memory_efficient_attention_forward(
query.unsqueeze(0),
key.unsqueeze(0),
value.unsqueeze(0),
attn_bias=attn_bias,
p=0.0,
scale=scale,
)
output = output.squeeze(0)
cu_seq_lens = [0]
for seq_len in seq_lens:
cu_seq_lens.append(cu_seq_lens[-1] + seq_len)
ref_output = ref_multi_query_kv_attention(
cu_seq_lens,
query,
key,
value,
scale,
alibi_bias,
dtype,
)
atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", [64])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.skipif(current_platform.is_rocm(),
reason="Xformers backend is not supported on ROCm.")
@torch.inference_mode()
def test_multi_query_kv_attention_with_alibi(
num_seqs: int,
num_heads: tuple[int, int],
head_size: int,
dtype: torch.dtype,
seed: int,
device: str,
) -> None:
return test_multi_query_kv_attention(
num_seqs,
num_heads,
head_size,
dtype,
seed,
device,
use_alibi=True,
)
# @pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
# @pytest.mark.parametrize("num_heads", NUM_HEADS)
# @pytest.mark.parametrize("head_size", HEAD_SIZES)
# @pytest.mark.parametrize("dtype", DTYPES)
# @pytest.mark.parametrize("seed", SEEDS)
# @pytest.mark.parametrize("device", CUDA_DEVICES)
# @pytest.mark.skipif(current_platform.is_rocm(),
# reason="Xformers backend is not supported on ROCm.")
# @torch.inference_mode()
# def test_multi_query_kv_attention(
# num_seqs: int,
# num_heads: tuple[int, int],
# head_size: int,
# dtype: torch.dtype,
# seed: int,
# device: str,
# use_alibi: bool = False,
# ) -> None:
# current_platform.seed_everything(seed)
# torch.set_default_device(device)
# # MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
# # As the xformers library is already tested with its own tests, we can use
# # a smaller MAX_SEQ_LEN here.
# max_len = min(MAX_SEQ_LEN, 4096)
# seq_lens = random.sample(range(1, max_len), num_seqs)
# num_tokens = sum(seq_lens)
# scale = float(1.0 / (head_size**0.5))
# num_query_heads, num_kv_heads = num_heads
# qkv = torch.empty(num_tokens,
# num_query_heads + 2 * num_kv_heads,
# head_size,
# dtype=dtype)
# qkv.uniform_(-scale, scale)
# query, key, value = qkv.split(
# [num_query_heads, num_kv_heads, num_kv_heads], dim=1)
# num_queries_per_kv = num_query_heads // num_kv_heads
# if num_queries_per_kv > 1:
# # Handle MQA and GQA
# key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
# value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)
# alibi_bias = None
# if use_alibi:
# alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)
# attn_bias = make_alibi_bias(alibi_slopes, num_kv_heads, dtype,
# seq_lens)
# output = torch.empty_like(query)
# start = 0
# # Dynamic sequence length not supported with custom attn_bias.
# for i, seq_len in enumerate(seq_lens):
# end = start + seq_len
# out = xops.memory_efficient_attention_forward(
# query[None, start:end],
# key[None, start:end],
# value[None, start:end],
# attn_bias=attn_bias[i],
# p=0.0,
# scale=scale)
# output[start:end].copy_(out.view_as(query[start:end]))
# start += seq_len
# # xformers.AttentionBias to Tensor for use in reference impl.
# alibi_bias = [
# b.materialize((1, num_query_heads, i, i), device=device).squeeze()
# for b, i in zip(attn_bias, seq_lens)
# ]
# else:
# attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
# output = xops.memory_efficient_attention_forward(
# query.unsqueeze(0),
# key.unsqueeze(0),
# value.unsqueeze(0),
# attn_bias=attn_bias,
# p=0.0,
# scale=scale,
# )
# output = output.squeeze(0)
# cu_seq_lens = [0]
# for seq_len in seq_lens:
# cu_seq_lens.append(cu_seq_lens[-1] + seq_len)
# ref_output = ref_multi_query_kv_attention(
# cu_seq_lens,
# query,
# key,
# value,
# scale,
# alibi_bias,
# dtype,
# )
# atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
# rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5
# torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
# @pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
# @pytest.mark.parametrize("num_heads", NUM_HEADS)
# @pytest.mark.parametrize("head_size", [64])
# @pytest.mark.parametrize("dtype", DTYPES)
# @pytest.mark.parametrize("seed", SEEDS)
# @pytest.mark.parametrize("device", CUDA_DEVICES)
# @pytest.mark.skipif(current_platform.is_rocm(),
# reason="Xformers backend is not supported on ROCm.")
# @torch.inference_mode()
# def test_multi_query_kv_attention_with_alibi(
# num_seqs: int,
# num_heads: tuple[int, int],
# head_size: int,
# dtype: torch.dtype,
# seed: int,
# device: str,
# ) -> None:
# return test_multi_query_kv_attention(
# num_seqs,
# num_heads,
# head_size,
# dtype,
# seed,
# device,
# use_alibi=True,
# )
@pytest.mark.parametrize("attention_cls", [Attention, MultiHeadAttention])
......
......@@ -447,40 +447,40 @@ def test_swap_blocks(
dist_value_caches[0][dst].cpu())
@pytest.mark.skipif(current_platform.is_rocm(),
reason="FP8 is not supported on ROCm.")
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_fp8_e4m3_conversion(
num_heads: int,
head_size: int,
block_size: int,
num_blocks: int,
dtype: torch.dtype,
seed: int,
device: str,
) -> None:
current_platform.seed_everything(seed)
low = -224.0
high = 224.0
shape = (num_blocks, num_heads, head_size, block_size)
cache = torch.empty(shape, dtype=dtype, device=device)
cache.uniform_(low, high)
cache_fp8 = torch.empty_like(cache, dtype=torch.uint8)
ops.convert_fp8(cache_fp8, cache)
converted_cache = torch.empty_like(cache)
ops.convert_fp8(converted_cache, cache_fp8)
torch.testing.assert_close(cache, converted_cache, atol=0.001, rtol=0.1)
# @pytest.mark.skipif(current_platform.is_rocm(),
# reason="FP8 is not supported on ROCm.")
# @pytest.mark.parametrize("num_heads", NUM_HEADS)
# @pytest.mark.parametrize("head_size", HEAD_SIZES)
# @pytest.mark.parametrize("block_size", BLOCK_SIZES)
# @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
# @pytest.mark.parametrize("dtype", DTYPES)
# @pytest.mark.parametrize("seed", SEEDS)
# @pytest.mark.parametrize("device", CUDA_DEVICES)
# @torch.inference_mode()
# def test_fp8_e4m3_conversion(
# num_heads: int,
# head_size: int,
# block_size: int,
# num_blocks: int,
# dtype: torch.dtype,
# seed: int,
# device: str,
# ) -> None:
# current_platform.seed_everything(seed)
# low = -224.0
# high = 224.0
# shape = (num_blocks, num_heads, head_size, block_size)
# cache = torch.empty(shape, dtype=dtype, device=device)
# cache.uniform_(low, high)
# cache_fp8 = torch.empty_like(cache, dtype=torch.uint8)
# ops.convert_fp8(cache_fp8, cache)
# converted_cache = torch.empty_like(cache)
# ops.convert_fp8(converted_cache, cache_fp8)
# torch.testing.assert_close(cache, converted_cache, atol=0.001, rtol=0.1)
def _create_mla_cache(
......@@ -596,117 +596,117 @@ def test_concat_and_cache_mla(
torch.testing.assert_close(kv_cache, ref_kv_cache)
@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS_MLA)
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_concat_and_cache_ds_mla(
kv_lora_rank: int,
qk_rope_head_dim: int,
num_tokens: int,
block_size: int,
num_blocks: int,
dtype: torch.dtype,
seed: int,
device: str,
) -> None:
if dtype.itemsize != 2:
pytest.skip("ds_mla only supports 16-bit input")
kv_cache_dtype = "fp8_ds_mla"
current_platform.seed_everything(seed)
torch.set_default_device(device)
total_slots = num_blocks * block_size
slot_mapping_lst = random.sample(range(total_slots), num_tokens)
slot_mapping = torch.tensor(slot_mapping_lst,
dtype=torch.long,
device=device)
kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device)
k_pe = torch.randn(num_tokens,
qk_rope_head_dim,
dtype=dtype,
device=device)
entry_size = kv_lora_rank + (4 * 4) + (2 * qk_rope_head_dim)
scale = torch.tensor(1.0, dtype=torch.float32, device=device)
kv_cache = _create_mla_cache(num_blocks,
block_size,
entry_size,
dtype=torch.uint8,
kv_cache_dtype=kv_cache_dtype,
device=device)
ref_cache = torch.zeros_like(kv_cache, dtype=kv_cache.dtype)
tile_data = torch.zeros(128, dtype=dtype, device=device)
for i in range(num_tokens):
slot = slot_mapping[i].item()
block_idx = slot // block_size
block_offset = slot % block_size
ref_cache_slice = ref_cache[block_idx, block_offset]
ref_cache_16bit = ref_cache_slice.view(dtype)
ref_cache_32bit = ref_cache_slice.view(torch.float32)
kv_c_data = kv_c[i]
for tile_idx in range(4):
tile_start = tile_idx * 128
tile_end = (tile_idx + 1) * 128
tile_data[:] = kv_c_data[tile_start:tile_end]
# tile_scale = tile_data.amax().to(torch.float32) / 448.
# NOTE: Using torch's amax() gives different results,
# so this must be manually computed.
tile_data_float = tile_data.to(torch.float32)
manual_max = abs(tile_data_float[0])
for j in range(1, 128):
manual_max = max(manual_max, abs(tile_data_float[j]))
tile_scale = manual_max / 448.
ref_cache_32bit[kv_lora_rank // 4 + tile_idx] = tile_scale
ops.convert_fp8(ref_cache_slice[tile_start:tile_end],
tile_data,
tile_scale.item(),
kv_dtype="fp8")
for j in range(qk_rope_head_dim):
ref_cache_16bit[kv_lora_rank // 2 + 8 + j] = k_pe[i, j]
opcheck(
torch.ops._C_cache_ops.concat_and_cache_mla,
(kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale),
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
)
ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping,
kv_cache_dtype, scale)
for i in range(num_tokens):
slot = slot_mapping[i].item()
block_idx = slot // block_size
block_offset = slot % block_size
kv_cache_slice = kv_cache[block_idx, block_offset]
ref_cache_slice = ref_cache[block_idx, block_offset]
kv_nope = kv_cache_slice[:kv_lora_rank]
ref_nope = ref_cache_slice[:kv_lora_rank]
kv_scales = kv_cache_slice.view(torch.float32)[kv_lora_rank //
4:kv_lora_rank // 4 + 4]
ref_scales = ref_cache_slice.view(
torch.float32)[kv_lora_rank // 4:kv_lora_rank // 4 + 4]
kv_rope = kv_cache_slice.view(dtype)[kv_lora_rank // 2 + 8:]
ref_rope = ref_cache_slice.view(dtype)[kv_lora_rank // 2 + 8:]
torch.testing.assert_close(kv_nope, ref_nope, atol=0.001, rtol=0.1)
torch.testing.assert_close(kv_scales, ref_scales, atol=0.001, rtol=0.1)
torch.testing.assert_close(kv_rope, ref_rope, atol=0.001, rtol=0.1)
# @pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
# @pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
# @pytest.mark.parametrize("num_tokens", NUM_TOKENS_MLA)
# @pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
# @pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
# @pytest.mark.parametrize("dtype", DTYPES)
# @pytest.mark.parametrize("seed", SEEDS)
# @pytest.mark.parametrize("device", CUDA_DEVICES)
# @torch.inference_mode()
# def test_concat_and_cache_ds_mla(
# kv_lora_rank: int,
# qk_rope_head_dim: int,
# num_tokens: int,
# block_size: int,
# num_blocks: int,
# dtype: torch.dtype,
# seed: int,
# device: str,
# ) -> None:
# if dtype.itemsize != 2:
# pytest.skip("ds_mla only supports 16-bit input")
# kv_cache_dtype = "fp8_ds_mla"
# current_platform.seed_everything(seed)
# torch.set_default_device(device)
# total_slots = num_blocks * block_size
# slot_mapping_lst = random.sample(range(total_slots), num_tokens)
# slot_mapping = torch.tensor(slot_mapping_lst,
# dtype=torch.long,
# device=device)
# kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device)
# k_pe = torch.randn(num_tokens,
# qk_rope_head_dim,
# dtype=dtype,
# device=device)
# entry_size = kv_lora_rank + (4 * 4) + (2 * qk_rope_head_dim)
# scale = torch.tensor(1.0, dtype=torch.float32, device=device)
# kv_cache = _create_mla_cache(num_blocks,
# block_size,
# entry_size,
# dtype=torch.uint8,
# kv_cache_dtype=kv_cache_dtype,
# device=device)
# ref_cache = torch.zeros_like(kv_cache, dtype=kv_cache.dtype)
# tile_data = torch.zeros(128, dtype=dtype, device=device)
# for i in range(num_tokens):
# slot = slot_mapping[i].item()
# block_idx = slot // block_size
# block_offset = slot % block_size
# ref_cache_slice = ref_cache[block_idx, block_offset]
# ref_cache_16bit = ref_cache_slice.view(dtype)
# ref_cache_32bit = ref_cache_slice.view(torch.float32)
# kv_c_data = kv_c[i]
# for tile_idx in range(4):
# tile_start = tile_idx * 128
# tile_end = (tile_idx + 1) * 128
# tile_data[:] = kv_c_data[tile_start:tile_end]
# # tile_scale = tile_data.amax().to(torch.float32) / 448.
# # NOTE: Using torch's amax() gives different results,
# # so this must be manually computed.
# tile_data_float = tile_data.to(torch.float32)
# manual_max = abs(tile_data_float[0])
# for j in range(1, 128):
# manual_max = max(manual_max, abs(tile_data_float[j]))
# tile_scale = manual_max / 448.
# ref_cache_32bit[kv_lora_rank // 4 + tile_idx] = tile_scale
# ops.convert_fp8(ref_cache_slice[tile_start:tile_end],
# tile_data,
# tile_scale.item(),
# kv_dtype="fp8")
# for j in range(qk_rope_head_dim):
# ref_cache_16bit[kv_lora_rank // 2 + 8 + j] = k_pe[i, j]
# opcheck(
# torch.ops._C_cache_ops.concat_and_cache_mla,
# (kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale),
# test_utils=DEFAULT_OPCHECK_TEST_UTILS,
# )
# ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping,
# kv_cache_dtype, scale)
# for i in range(num_tokens):
# slot = slot_mapping[i].item()
# block_idx = slot // block_size
# block_offset = slot % block_size
# kv_cache_slice = kv_cache[block_idx, block_offset]
# ref_cache_slice = ref_cache[block_idx, block_offset]
# kv_nope = kv_cache_slice[:kv_lora_rank]
# ref_nope = ref_cache_slice[:kv_lora_rank]
# kv_scales = kv_cache_slice.view(torch.float32)[kv_lora_rank //
# 4:kv_lora_rank // 4 + 4]
# ref_scales = ref_cache_slice.view(
# torch.float32)[kv_lora_rank // 4:kv_lora_rank // 4 + 4]
# kv_rope = kv_cache_slice.view(dtype)[kv_lora_rank // 2 + 8:]
# ref_rope = ref_cache_slice.view(dtype)[kv_lora_rank // 2 + 8:]
# torch.testing.assert_close(kv_nope, ref_nope, atol=0.001, rtol=0.1)
# torch.testing.assert_close(kv_scales, ref_scales, atol=0.001, rtol=0.1)
# torch.testing.assert_close(kv_rope, ref_rope, atol=0.001, rtol=0.1)
@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
......@@ -993,70 +993,70 @@ def test_cp_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size,
torch.testing.assert_close(dst, expected)
@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS_MLA)
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.cpu_model
@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only")
@torch.inference_mode()
def test_concat_and_cache_mla_cpu(
kv_lora_rank: int,
qk_rope_head_dim: int,
num_tokens: int,
block_size: int,
num_blocks: int,
dtype: torch.dtype,
seed: int,
) -> None:
device = "cpu"
kv_cache_dtype = "auto"
current_platform.seed_everything(seed)
torch.set_default_device(device)
total_slots = num_blocks * block_size
slot_mapping_lst = random.sample(range(total_slots), num_tokens)
slot_mapping = torch.tensor(slot_mapping_lst,
dtype=torch.long,
device=device)
kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device)
k_pe = torch.randn(num_tokens,
qk_rope_head_dim,
dtype=dtype,
device=device)
entry_size = kv_lora_rank + qk_rope_head_dim
scale = torch.tensor(0.1, dtype=torch.float32, device=device)
kv_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
kv_cache_dtype, device)
ref_temp = torch.zeros(*kv_cache.shape, dtype=dtype, device=device)
for i in range(num_tokens):
slot = slot_mapping[i].item()
block_idx = slot // block_size
block_offset = slot % block_size
ref_temp[block_idx, block_offset, :kv_lora_rank] = kv_c[i]
ref_temp[block_idx, block_offset, kv_lora_rank:] = k_pe[i]
if kv_cache_dtype == "fp8":
ref_kv_cache = torch.empty_like(ref_temp, dtype=kv_cache.dtype)
ops.convert_fp8(ref_kv_cache,
ref_temp,
scale.item(),
kv_dtype=kv_cache_dtype)
else:
ref_kv_cache = ref_temp
opcheck(
torch.ops._C_cache_ops.concat_and_cache_mla,
(kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale),
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
)
ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping,
kv_cache_dtype, scale)
torch.testing.assert_close(kv_cache, ref_kv_cache)
# @pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
# @pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
# @pytest.mark.parametrize("num_tokens", NUM_TOKENS_MLA)
# @pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
# @pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
# @pytest.mark.parametrize("dtype", DTYPES)
# @pytest.mark.parametrize("seed", SEEDS)
# @pytest.mark.cpu_model
# @pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only")
# @torch.inference_mode()
# def test_concat_and_cache_mla_cpu(
# kv_lora_rank: int,
# qk_rope_head_dim: int,
# num_tokens: int,
# block_size: int,
# num_blocks: int,
# dtype: torch.dtype,
# seed: int,
# ) -> None:
# device = "cpu"
# kv_cache_dtype = "auto"
# current_platform.seed_everything(seed)
# torch.set_default_device(device)
# total_slots = num_blocks * block_size
# slot_mapping_lst = random.sample(range(total_slots), num_tokens)
# slot_mapping = torch.tensor(slot_mapping_lst,
# dtype=torch.long,
# device=device)
# kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device)
# k_pe = torch.randn(num_tokens,
# qk_rope_head_dim,
# dtype=dtype,
# device=device)
# entry_size = kv_lora_rank + qk_rope_head_dim
# scale = torch.tensor(0.1, dtype=torch.float32, device=device)
# kv_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
# kv_cache_dtype, device)
# ref_temp = torch.zeros(*kv_cache.shape, dtype=dtype, device=device)
# for i in range(num_tokens):
# slot = slot_mapping[i].item()
# block_idx = slot // block_size
# block_offset = slot % block_size
# ref_temp[block_idx, block_offset, :kv_lora_rank] = kv_c[i]
# ref_temp[block_idx, block_offset, kv_lora_rank:] = k_pe[i]
# if kv_cache_dtype == "fp8":
# ref_kv_cache = torch.empty_like(ref_temp, dtype=kv_cache.dtype)
# ops.convert_fp8(ref_kv_cache,
# ref_temp,
# scale.item(),
# kv_dtype=kv_cache_dtype)
# else:
# ref_kv_cache = ref_temp
# opcheck(
# torch.ops._C_cache_ops.concat_and_cache_mla,
# (kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale),
# test_utils=DEFAULT_OPCHECK_TEST_UTILS,
# )
# ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping,
# kv_cache_dtype, scale)
# torch.testing.assert_close(kv_cache, ref_kv_cache)
......@@ -11,6 +11,7 @@ from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
get_mla_metadata,
is_flashmla_supported)
from vllm.triton_utils import triton
from vllm.platforms import current_platform
def cal_diff(x: torch.Tensor,
......@@ -42,7 +43,7 @@ FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \
@pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("varlen", [False, True])
@pytest.mark.parametrize("torch_dtype",
[torch.bfloat16, torch.float16, torch.float8_e4m3fn])
[torch.bfloat16, torch.float16, torch.float8_e4m3fn] if not current_platform.is_rocm() else [torch.bfloat16])
@torch.inference_mode()
def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
varlen, torch_dtype):
......
......@@ -87,7 +87,7 @@ def generate_markdown_table():
@torch.inference_mode()
def test_merge_attn_states(num_tokens: int, num_query_heads: int,
head_size: int, output_dtype: torch.dtype):
if not current_platform.is_cuda():
if not current_platform.is_cuda() or not current_platform.is_rocm():
pytest.skip('Currently only support compare triton merge_attn_states '
'with custom cuda merge_attn_states kernel')
......
......@@ -46,7 +46,7 @@ def test_mha_attn_platform(device: str):
patch("vllm.model_executor.models.vision.current_platform",
RocmPlatform()):
attn = MultiHeadAttention(16, 64, scale=1)
assert attn.attn_backend == _Backend.TORCH_SDPA
assert attn.attn_backend == _Backend.FLASH_ATTN # _Backend.TORCH_SDPA
else:
# Test CUDA with head_size=64 (divisible by 32)
# - should use vLLM's FlashAttention
......
......@@ -21,7 +21,7 @@ NUM_HEADS = [(4, 4), (8, 2)]
HEAD_SIZES = [128, 256]
BLOCK_SIZES = [16]
DTYPES = [torch.bfloat16]
QDTYPES = [None, torch.float8_e4m3fn]
QDTYPES = [None, torch.float8_e4m3fn] if current_platform.is_rocm() else [None]
# one value large enough to test overflow in index calculation.
# one value small enough to test the schema op check
NUM_BLOCKS = [32768, 2048]
......@@ -199,8 +199,8 @@ def test_flash_attn_with_paged_kv(
f"{torch.max(torch.abs(output - ref_output))}"
@pytest.mark.skipif(current_platform.is_rocm(),
reason="varlen_with_paged_kv is not supported on ROCm.")
# @pytest.mark.skipif(current_platform.is_rocm(),
# reason="varlen_with_paged_kv is not supported on ROCm.")
@pytest.mark.parametrize("use_out", [True, False])
@pytest.mark.parametrize("seq_lens",
[[(1, 1328), (5, 18),
......@@ -302,10 +302,10 @@ def test_varlen_with_paged_kv(
window_size=window_size,
block_table=block_tables,
softcap=soft_cap if soft_cap is not None else 0,
fa_version=fa_version,
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
# fa_version=fa_version,
# q_descale=q_descale,
# k_descale=k_descale,
# v_descale=v_descale,
)
output = output if not use_out else out
......
......@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import NamedTuple
import os
import pytest
import torch
from packaging.version import Version
......@@ -10,6 +11,7 @@ from transformers import __version__ as TRANSFORMERS_VERSION
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.platforms import current_platform
from ...utils import models_path_prefix
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
......@@ -48,12 +50,12 @@ class MRoPETestInfo(NamedTuple):
TRANSFORMERS_BASE_VERSION = Version(TRANSFORMERS_VERSION).base_version
MODELS_TO_TEST = [
MRoPETestInfo(model_name="zai-org/GLM-4.1V-9B-Thinking"),
MRoPETestInfo(model_name="Qwen/Qwen2-VL-7B-Instruct"),
MRoPETestInfo(model_name="Qwen/Qwen2-VL-72B-Instruct"),
MRoPETestInfo(model_name="Qwen/Qwen2.5-VL-72B-Instruct"),
MRoPETestInfo(model_name=os.path.join(models_path_prefix, "zai-org/GLM-4.1V-9B-Thinking")),
MRoPETestInfo(model_name=os.path.join(models_path_prefix, "Qwen/Qwen2-VL-7B-Instruct")),
MRoPETestInfo(model_name=os.path.join(models_path_prefix, "Qwen/Qwen2-VL-72B-Instruct")),
# MRoPETestInfo(model_name=os.path.join("Qwen/Qwen2.5-VL-72B-Instruct")),
MRoPETestInfo(
model_name="Qwen/Qwen3-VL-4B-Instruct",
model_name=os.path.join(models_path_prefix, "Qwen/Qwen3-VL-4B-Instruct"),
marks=[
pytest.mark.skipif(
Version(TRANSFORMERS_BASE_VERSION) < Version("4.57.0"),
......@@ -61,7 +63,7 @@ MODELS_TO_TEST = [
)
]),
MRoPETestInfo(
model_name="Qwen/Qwen3-VL-30B-A3B-Instruct",
model_name=os.path.join(models_path_prefix, "Qwen/Qwen3-VL-30B-A3B-Instruct"),
marks=[
pytest.mark.skipif(
Version(TRANSFORMERS_BASE_VERSION) < Version("4.57.0"),
......
......@@ -172,7 +172,7 @@ def torch_moe_align_block_size(
@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("pad_sorted_ids", [False, True])
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
# @pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_moe_align_block_size(m: int, topk: int, num_experts: int,
block_size: int, pad_sorted_ids: bool):
"""Test moe_align_block_size without expert mapping"""
......@@ -235,7 +235,7 @@ def test_moe_align_block_size(m: int, topk: int, num_experts: int,
@pytest.mark.parametrize("topk", [2, 4])
@pytest.mark.parametrize("num_experts", [8])
@pytest.mark.parametrize("block_size", [64])
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
# @pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_moe_align_block_size_with_expert_map(m: int, topk: int,
num_experts: int,
block_size: int):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment