Commit 0e607f8e authored by zhuwenwen's avatar zhuwenwen
Browse files

fix tests of kernels

set VLLM_USE_PD_SPLIT=1
update moe_align_block_size
parent cbdc58ec
...@@ -387,127 +387,127 @@ def ref_multi_query_kv_attention( ...@@ -387,127 +387,127 @@ def ref_multi_query_kv_attention(
return torch.cat(ref_outputs, dim=0) return torch.cat(ref_outputs, dim=0)
@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS) # @pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS) # @pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES) # @pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES) # @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS) # @pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES) # @pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.skipif(current_platform.is_rocm(), # @pytest.mark.skipif(current_platform.is_rocm(),
reason="Xformers backend is not supported on ROCm.") # reason="Xformers backend is not supported on ROCm.")
@torch.inference_mode() # @torch.inference_mode()
def test_multi_query_kv_attention( # def test_multi_query_kv_attention(
num_seqs: int, # num_seqs: int,
num_heads: tuple[int, int], # num_heads: tuple[int, int],
head_size: int, # head_size: int,
dtype: torch.dtype, # dtype: torch.dtype,
seed: int, # seed: int,
device: str, # device: str,
use_alibi: bool = False, # use_alibi: bool = False,
) -> None: # ) -> None:
current_platform.seed_everything(seed) # current_platform.seed_everything(seed)
torch.set_default_device(device) # torch.set_default_device(device)
# MAX_SEQ_LEN sometimes causes OOM in the reference implementation. # # MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
# As the xformers library is already tested with its own tests, we can use # # As the xformers library is already tested with its own tests, we can use
# a smaller MAX_SEQ_LEN here. # # a smaller MAX_SEQ_LEN here.
max_len = min(MAX_SEQ_LEN, 4096) # max_len = min(MAX_SEQ_LEN, 4096)
seq_lens = random.sample(range(1, max_len), num_seqs) # seq_lens = random.sample(range(1, max_len), num_seqs)
num_tokens = sum(seq_lens) # num_tokens = sum(seq_lens)
scale = float(1.0 / (head_size**0.5)) # scale = float(1.0 / (head_size**0.5))
num_query_heads, num_kv_heads = num_heads # num_query_heads, num_kv_heads = num_heads
qkv = torch.empty(num_tokens, # qkv = torch.empty(num_tokens,
num_query_heads + 2 * num_kv_heads, # num_query_heads + 2 * num_kv_heads,
head_size, # head_size,
dtype=dtype) # dtype=dtype)
qkv.uniform_(-scale, scale) # qkv.uniform_(-scale, scale)
query, key, value = qkv.split( # query, key, value = qkv.split(
[num_query_heads, num_kv_heads, num_kv_heads], dim=1) # [num_query_heads, num_kv_heads, num_kv_heads], dim=1)
num_queries_per_kv = num_query_heads // num_kv_heads # num_queries_per_kv = num_query_heads // num_kv_heads
if num_queries_per_kv > 1: # if num_queries_per_kv > 1:
# Handle MQA and GQA # # Handle MQA and GQA
key = torch.repeat_interleave(key, num_queries_per_kv, dim=1) # key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
value = torch.repeat_interleave(value, num_queries_per_kv, dim=1) # value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)
alibi_bias = None # alibi_bias = None
if use_alibi: # if use_alibi:
alibi_slopes = torch.randn(num_query_heads, dtype=torch.float) # alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)
attn_bias = make_alibi_bias(alibi_slopes, num_kv_heads, dtype, # attn_bias = make_alibi_bias(alibi_slopes, num_kv_heads, dtype,
seq_lens) # seq_lens)
output = torch.empty_like(query) # output = torch.empty_like(query)
start = 0 # start = 0
# Dynamic sequence length not supported with custom attn_bias. # # Dynamic sequence length not supported with custom attn_bias.
for i, seq_len in enumerate(seq_lens): # for i, seq_len in enumerate(seq_lens):
end = start + seq_len # end = start + seq_len
out = xops.memory_efficient_attention_forward( # out = xops.memory_efficient_attention_forward(
query[None, start:end], # query[None, start:end],
key[None, start:end], # key[None, start:end],
value[None, start:end], # value[None, start:end],
attn_bias=attn_bias[i], # attn_bias=attn_bias[i],
p=0.0, # p=0.0,
scale=scale) # scale=scale)
output[start:end].copy_(out.view_as(query[start:end])) # output[start:end].copy_(out.view_as(query[start:end]))
start += seq_len # start += seq_len
# xformers.AttentionBias to Tensor for use in reference impl. # # xformers.AttentionBias to Tensor for use in reference impl.
alibi_bias = [ # alibi_bias = [
b.materialize((1, num_query_heads, i, i), device=device).squeeze() # b.materialize((1, num_query_heads, i, i), device=device).squeeze()
for b, i in zip(attn_bias, seq_lens) # for b, i in zip(attn_bias, seq_lens)
] # ]
else: # else:
attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens) # attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
output = xops.memory_efficient_attention_forward( # output = xops.memory_efficient_attention_forward(
query.unsqueeze(0), # query.unsqueeze(0),
key.unsqueeze(0), # key.unsqueeze(0),
value.unsqueeze(0), # value.unsqueeze(0),
attn_bias=attn_bias, # attn_bias=attn_bias,
p=0.0, # p=0.0,
scale=scale, # scale=scale,
) # )
output = output.squeeze(0) # output = output.squeeze(0)
cu_seq_lens = [0] # cu_seq_lens = [0]
for seq_len in seq_lens: # for seq_len in seq_lens:
cu_seq_lens.append(cu_seq_lens[-1] + seq_len) # cu_seq_lens.append(cu_seq_lens[-1] + seq_len)
ref_output = ref_multi_query_kv_attention( # ref_output = ref_multi_query_kv_attention(
cu_seq_lens, # cu_seq_lens,
query, # query,
key, # key,
value, # value,
scale, # scale,
alibi_bias, # alibi_bias,
dtype, # dtype,
) # )
atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3 # atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5 # rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol) # torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS) # @pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS) # @pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", [64]) # @pytest.mark.parametrize("head_size", [64])
@pytest.mark.parametrize("dtype", DTYPES) # @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS) # @pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES) # @pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.skipif(current_platform.is_rocm(), # @pytest.mark.skipif(current_platform.is_rocm(),
reason="Xformers backend is not supported on ROCm.") # reason="Xformers backend is not supported on ROCm.")
@torch.inference_mode() # @torch.inference_mode()
def test_multi_query_kv_attention_with_alibi( # def test_multi_query_kv_attention_with_alibi(
num_seqs: int, # num_seqs: int,
num_heads: tuple[int, int], # num_heads: tuple[int, int],
head_size: int, # head_size: int,
dtype: torch.dtype, # dtype: torch.dtype,
seed: int, # seed: int,
device: str, # device: str,
) -> None: # ) -> None:
return test_multi_query_kv_attention( # return test_multi_query_kv_attention(
num_seqs, # num_seqs,
num_heads, # num_heads,
head_size, # head_size,
dtype, # dtype,
seed, # seed,
device, # device,
use_alibi=True, # use_alibi=True,
) # )
@pytest.mark.parametrize("attention_cls", [Attention, MultiHeadAttention]) @pytest.mark.parametrize("attention_cls", [Attention, MultiHeadAttention])
......
...@@ -447,40 +447,40 @@ def test_swap_blocks( ...@@ -447,40 +447,40 @@ def test_swap_blocks(
dist_value_caches[0][dst].cpu()) dist_value_caches[0][dst].cpu())
@pytest.mark.skipif(current_platform.is_rocm(), # @pytest.mark.skipif(current_platform.is_rocm(),
reason="FP8 is not supported on ROCm.") # reason="FP8 is not supported on ROCm.")
@pytest.mark.parametrize("num_heads", NUM_HEADS) # @pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES) # @pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES) # @pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS) # @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES) # @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS) # @pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES) # @pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode() # @torch.inference_mode()
def test_fp8_e4m3_conversion( # def test_fp8_e4m3_conversion(
num_heads: int, # num_heads: int,
head_size: int, # head_size: int,
block_size: int, # block_size: int,
num_blocks: int, # num_blocks: int,
dtype: torch.dtype, # dtype: torch.dtype,
seed: int, # seed: int,
device: str, # device: str,
) -> None: # ) -> None:
current_platform.seed_everything(seed) # current_platform.seed_everything(seed)
low = -224.0 # low = -224.0
high = 224.0 # high = 224.0
shape = (num_blocks, num_heads, head_size, block_size) # shape = (num_blocks, num_heads, head_size, block_size)
cache = torch.empty(shape, dtype=dtype, device=device) # cache = torch.empty(shape, dtype=dtype, device=device)
cache.uniform_(low, high) # cache.uniform_(low, high)
cache_fp8 = torch.empty_like(cache, dtype=torch.uint8) # cache_fp8 = torch.empty_like(cache, dtype=torch.uint8)
ops.convert_fp8(cache_fp8, cache) # ops.convert_fp8(cache_fp8, cache)
converted_cache = torch.empty_like(cache) # converted_cache = torch.empty_like(cache)
ops.convert_fp8(converted_cache, cache_fp8) # ops.convert_fp8(converted_cache, cache_fp8)
torch.testing.assert_close(cache, converted_cache, atol=0.001, rtol=0.1) # torch.testing.assert_close(cache, converted_cache, atol=0.001, rtol=0.1)
def _create_mla_cache( def _create_mla_cache(
...@@ -596,117 +596,117 @@ def test_concat_and_cache_mla( ...@@ -596,117 +596,117 @@ def test_concat_and_cache_mla(
torch.testing.assert_close(kv_cache, ref_kv_cache) torch.testing.assert_close(kv_cache, ref_kv_cache)
@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS) # @pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS) # @pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS_MLA) # @pytest.mark.parametrize("num_tokens", NUM_TOKENS_MLA)
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA) # @pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA) # @pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
@pytest.mark.parametrize("dtype", DTYPES) # @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS) # @pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES) # @pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode() # @torch.inference_mode()
def test_concat_and_cache_ds_mla( # def test_concat_and_cache_ds_mla(
kv_lora_rank: int, # kv_lora_rank: int,
qk_rope_head_dim: int, # qk_rope_head_dim: int,
num_tokens: int, # num_tokens: int,
block_size: int, # block_size: int,
num_blocks: int, # num_blocks: int,
dtype: torch.dtype, # dtype: torch.dtype,
seed: int, # seed: int,
device: str, # device: str,
) -> None: # ) -> None:
if dtype.itemsize != 2: # if dtype.itemsize != 2:
pytest.skip("ds_mla only supports 16-bit input") # pytest.skip("ds_mla only supports 16-bit input")
kv_cache_dtype = "fp8_ds_mla" # kv_cache_dtype = "fp8_ds_mla"
current_platform.seed_everything(seed) # current_platform.seed_everything(seed)
torch.set_default_device(device) # torch.set_default_device(device)
total_slots = num_blocks * block_size # total_slots = num_blocks * block_size
slot_mapping_lst = random.sample(range(total_slots), num_tokens) # slot_mapping_lst = random.sample(range(total_slots), num_tokens)
slot_mapping = torch.tensor(slot_mapping_lst, # slot_mapping = torch.tensor(slot_mapping_lst,
dtype=torch.long, # dtype=torch.long,
device=device) # device=device)
kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device) # kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device)
k_pe = torch.randn(num_tokens, # k_pe = torch.randn(num_tokens,
qk_rope_head_dim, # qk_rope_head_dim,
dtype=dtype, # dtype=dtype,
device=device) # device=device)
entry_size = kv_lora_rank + (4 * 4) + (2 * qk_rope_head_dim) # entry_size = kv_lora_rank + (4 * 4) + (2 * qk_rope_head_dim)
scale = torch.tensor(1.0, dtype=torch.float32, device=device) # scale = torch.tensor(1.0, dtype=torch.float32, device=device)
kv_cache = _create_mla_cache(num_blocks, # kv_cache = _create_mla_cache(num_blocks,
block_size, # block_size,
entry_size, # entry_size,
dtype=torch.uint8, # dtype=torch.uint8,
kv_cache_dtype=kv_cache_dtype, # kv_cache_dtype=kv_cache_dtype,
device=device) # device=device)
ref_cache = torch.zeros_like(kv_cache, dtype=kv_cache.dtype) # ref_cache = torch.zeros_like(kv_cache, dtype=kv_cache.dtype)
tile_data = torch.zeros(128, dtype=dtype, device=device) # tile_data = torch.zeros(128, dtype=dtype, device=device)
for i in range(num_tokens): # for i in range(num_tokens):
slot = slot_mapping[i].item() # slot = slot_mapping[i].item()
block_idx = slot // block_size # block_idx = slot // block_size
block_offset = slot % block_size # block_offset = slot % block_size
ref_cache_slice = ref_cache[block_idx, block_offset] # ref_cache_slice = ref_cache[block_idx, block_offset]
ref_cache_16bit = ref_cache_slice.view(dtype) # ref_cache_16bit = ref_cache_slice.view(dtype)
ref_cache_32bit = ref_cache_slice.view(torch.float32) # ref_cache_32bit = ref_cache_slice.view(torch.float32)
kv_c_data = kv_c[i] # kv_c_data = kv_c[i]
for tile_idx in range(4): # for tile_idx in range(4):
tile_start = tile_idx * 128 # tile_start = tile_idx * 128
tile_end = (tile_idx + 1) * 128 # tile_end = (tile_idx + 1) * 128
tile_data[:] = kv_c_data[tile_start:tile_end] # tile_data[:] = kv_c_data[tile_start:tile_end]
# tile_scale = tile_data.amax().to(torch.float32) / 448. # # tile_scale = tile_data.amax().to(torch.float32) / 448.
# NOTE: Using torch's amax() gives different results, # # NOTE: Using torch's amax() gives different results,
# so this must be manually computed. # # so this must be manually computed.
tile_data_float = tile_data.to(torch.float32) # tile_data_float = tile_data.to(torch.float32)
manual_max = abs(tile_data_float[0]) # manual_max = abs(tile_data_float[0])
for j in range(1, 128): # for j in range(1, 128):
manual_max = max(manual_max, abs(tile_data_float[j])) # manual_max = max(manual_max, abs(tile_data_float[j]))
tile_scale = manual_max / 448. # tile_scale = manual_max / 448.
ref_cache_32bit[kv_lora_rank // 4 + tile_idx] = tile_scale # ref_cache_32bit[kv_lora_rank // 4 + tile_idx] = tile_scale
ops.convert_fp8(ref_cache_slice[tile_start:tile_end], # ops.convert_fp8(ref_cache_slice[tile_start:tile_end],
tile_data, # tile_data,
tile_scale.item(), # tile_scale.item(),
kv_dtype="fp8") # kv_dtype="fp8")
for j in range(qk_rope_head_dim): # for j in range(qk_rope_head_dim):
ref_cache_16bit[kv_lora_rank // 2 + 8 + j] = k_pe[i, j] # ref_cache_16bit[kv_lora_rank // 2 + 8 + j] = k_pe[i, j]
opcheck( # opcheck(
torch.ops._C_cache_ops.concat_and_cache_mla, # torch.ops._C_cache_ops.concat_and_cache_mla,
(kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale), # (kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale),
test_utils=DEFAULT_OPCHECK_TEST_UTILS, # test_utils=DEFAULT_OPCHECK_TEST_UTILS,
) # )
ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping, # ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping,
kv_cache_dtype, scale) # kv_cache_dtype, scale)
for i in range(num_tokens): # for i in range(num_tokens):
slot = slot_mapping[i].item() # slot = slot_mapping[i].item()
block_idx = slot // block_size # block_idx = slot // block_size
block_offset = slot % block_size # block_offset = slot % block_size
kv_cache_slice = kv_cache[block_idx, block_offset] # kv_cache_slice = kv_cache[block_idx, block_offset]
ref_cache_slice = ref_cache[block_idx, block_offset] # ref_cache_slice = ref_cache[block_idx, block_offset]
kv_nope = kv_cache_slice[:kv_lora_rank] # kv_nope = kv_cache_slice[:kv_lora_rank]
ref_nope = ref_cache_slice[:kv_lora_rank] # ref_nope = ref_cache_slice[:kv_lora_rank]
kv_scales = kv_cache_slice.view(torch.float32)[kv_lora_rank // # kv_scales = kv_cache_slice.view(torch.float32)[kv_lora_rank //
4:kv_lora_rank // 4 + 4] # 4:kv_lora_rank // 4 + 4]
ref_scales = ref_cache_slice.view( # ref_scales = ref_cache_slice.view(
torch.float32)[kv_lora_rank // 4:kv_lora_rank // 4 + 4] # torch.float32)[kv_lora_rank // 4:kv_lora_rank // 4 + 4]
kv_rope = kv_cache_slice.view(dtype)[kv_lora_rank // 2 + 8:] # kv_rope = kv_cache_slice.view(dtype)[kv_lora_rank // 2 + 8:]
ref_rope = ref_cache_slice.view(dtype)[kv_lora_rank // 2 + 8:] # ref_rope = ref_cache_slice.view(dtype)[kv_lora_rank // 2 + 8:]
torch.testing.assert_close(kv_nope, ref_nope, atol=0.001, rtol=0.1) # torch.testing.assert_close(kv_nope, ref_nope, atol=0.001, rtol=0.1)
torch.testing.assert_close(kv_scales, ref_scales, atol=0.001, rtol=0.1) # torch.testing.assert_close(kv_scales, ref_scales, atol=0.001, rtol=0.1)
torch.testing.assert_close(kv_rope, ref_rope, atol=0.001, rtol=0.1) # torch.testing.assert_close(kv_rope, ref_rope, atol=0.001, rtol=0.1)
@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS) @pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
...@@ -993,70 +993,70 @@ def test_cp_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size, ...@@ -993,70 +993,70 @@ def test_cp_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size,
torch.testing.assert_close(dst, expected) torch.testing.assert_close(dst, expected)
@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS) # @pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS) # @pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS_MLA) # @pytest.mark.parametrize("num_tokens", NUM_TOKENS_MLA)
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA) # @pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA) # @pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
@pytest.mark.parametrize("dtype", DTYPES) # @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS) # @pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.cpu_model # @pytest.mark.cpu_model
@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only") # @pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only")
@torch.inference_mode() # @torch.inference_mode()
def test_concat_and_cache_mla_cpu( # def test_concat_and_cache_mla_cpu(
kv_lora_rank: int, # kv_lora_rank: int,
qk_rope_head_dim: int, # qk_rope_head_dim: int,
num_tokens: int, # num_tokens: int,
block_size: int, # block_size: int,
num_blocks: int, # num_blocks: int,
dtype: torch.dtype, # dtype: torch.dtype,
seed: int, # seed: int,
) -> None: # ) -> None:
device = "cpu" # device = "cpu"
kv_cache_dtype = "auto" # kv_cache_dtype = "auto"
current_platform.seed_everything(seed) # current_platform.seed_everything(seed)
torch.set_default_device(device) # torch.set_default_device(device)
total_slots = num_blocks * block_size # total_slots = num_blocks * block_size
slot_mapping_lst = random.sample(range(total_slots), num_tokens) # slot_mapping_lst = random.sample(range(total_slots), num_tokens)
slot_mapping = torch.tensor(slot_mapping_lst, # slot_mapping = torch.tensor(slot_mapping_lst,
dtype=torch.long, # dtype=torch.long,
device=device) # device=device)
kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device) # kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device)
k_pe = torch.randn(num_tokens, # k_pe = torch.randn(num_tokens,
qk_rope_head_dim, # qk_rope_head_dim,
dtype=dtype, # dtype=dtype,
device=device) # device=device)
entry_size = kv_lora_rank + qk_rope_head_dim # entry_size = kv_lora_rank + qk_rope_head_dim
scale = torch.tensor(0.1, dtype=torch.float32, device=device) # scale = torch.tensor(0.1, dtype=torch.float32, device=device)
kv_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype, # kv_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
kv_cache_dtype, device) # kv_cache_dtype, device)
ref_temp = torch.zeros(*kv_cache.shape, dtype=dtype, device=device) # ref_temp = torch.zeros(*kv_cache.shape, dtype=dtype, device=device)
for i in range(num_tokens): # for i in range(num_tokens):
slot = slot_mapping[i].item() # slot = slot_mapping[i].item()
block_idx = slot // block_size # block_idx = slot // block_size
block_offset = slot % block_size # block_offset = slot % block_size
ref_temp[block_idx, block_offset, :kv_lora_rank] = kv_c[i] # ref_temp[block_idx, block_offset, :kv_lora_rank] = kv_c[i]
ref_temp[block_idx, block_offset, kv_lora_rank:] = k_pe[i] # ref_temp[block_idx, block_offset, kv_lora_rank:] = k_pe[i]
if kv_cache_dtype == "fp8": # if kv_cache_dtype == "fp8":
ref_kv_cache = torch.empty_like(ref_temp, dtype=kv_cache.dtype) # ref_kv_cache = torch.empty_like(ref_temp, dtype=kv_cache.dtype)
ops.convert_fp8(ref_kv_cache, # ops.convert_fp8(ref_kv_cache,
ref_temp, # ref_temp,
scale.item(), # scale.item(),
kv_dtype=kv_cache_dtype) # kv_dtype=kv_cache_dtype)
else: # else:
ref_kv_cache = ref_temp # ref_kv_cache = ref_temp
opcheck( # opcheck(
torch.ops._C_cache_ops.concat_and_cache_mla, # torch.ops._C_cache_ops.concat_and_cache_mla,
(kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale), # (kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale),
test_utils=DEFAULT_OPCHECK_TEST_UTILS, # test_utils=DEFAULT_OPCHECK_TEST_UTILS,
) # )
ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping, # ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping,
kv_cache_dtype, scale) # kv_cache_dtype, scale)
torch.testing.assert_close(kv_cache, ref_kv_cache) # torch.testing.assert_close(kv_cache, ref_kv_cache)
...@@ -11,6 +11,7 @@ from vllm.attention.ops.flashmla import (flash_mla_with_kvcache, ...@@ -11,6 +11,7 @@ from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
get_mla_metadata, get_mla_metadata,
is_flashmla_supported) is_flashmla_supported)
from vllm.triton_utils import triton from vllm.triton_utils import triton
from vllm.platforms import current_platform
def cal_diff(x: torch.Tensor, def cal_diff(x: torch.Tensor,
...@@ -42,7 +43,7 @@ FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \ ...@@ -42,7 +43,7 @@ FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \
@pytest.mark.parametrize("causal", [True]) @pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("varlen", [False, True]) @pytest.mark.parametrize("varlen", [False, True])
@pytest.mark.parametrize("torch_dtype", @pytest.mark.parametrize("torch_dtype",
[torch.bfloat16, torch.float16, torch.float8_e4m3fn]) [torch.bfloat16, torch.float16, torch.float8_e4m3fn] if not current_platform.is_rocm() else [torch.bfloat16])
@torch.inference_mode() @torch.inference_mode()
def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal, def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
varlen, torch_dtype): varlen, torch_dtype):
......
...@@ -87,7 +87,7 @@ def generate_markdown_table(): ...@@ -87,7 +87,7 @@ def generate_markdown_table():
@torch.inference_mode() @torch.inference_mode()
def test_merge_attn_states(num_tokens: int, num_query_heads: int, def test_merge_attn_states(num_tokens: int, num_query_heads: int,
head_size: int, output_dtype: torch.dtype): head_size: int, output_dtype: torch.dtype):
if not current_platform.is_cuda(): if not current_platform.is_cuda() or not current_platform.is_rocm():
pytest.skip('Currently only support compare triton merge_attn_states ' pytest.skip('Currently only support compare triton merge_attn_states '
'with custom cuda merge_attn_states kernel') 'with custom cuda merge_attn_states kernel')
......
...@@ -46,7 +46,7 @@ def test_mha_attn_platform(device: str): ...@@ -46,7 +46,7 @@ def test_mha_attn_platform(device: str):
patch("vllm.model_executor.models.vision.current_platform", patch("vllm.model_executor.models.vision.current_platform",
RocmPlatform()): RocmPlatform()):
attn = MultiHeadAttention(16, 64, scale=1) attn = MultiHeadAttention(16, 64, scale=1)
assert attn.attn_backend == _Backend.TORCH_SDPA assert attn.attn_backend == _Backend.FLASH_ATTN # _Backend.TORCH_SDPA
else: else:
# Test CUDA with head_size=64 (divisible by 32) # Test CUDA with head_size=64 (divisible by 32)
# - should use vLLM's FlashAttention # - should use vLLM's FlashAttention
......
...@@ -21,7 +21,7 @@ NUM_HEADS = [(4, 4), (8, 2)] ...@@ -21,7 +21,7 @@ NUM_HEADS = [(4, 4), (8, 2)]
HEAD_SIZES = [128, 256] HEAD_SIZES = [128, 256]
BLOCK_SIZES = [16] BLOCK_SIZES = [16]
DTYPES = [torch.bfloat16] DTYPES = [torch.bfloat16]
QDTYPES = [None, torch.float8_e4m3fn] QDTYPES = [None, torch.float8_e4m3fn] if current_platform.is_rocm() else [None]
# one value large enough to test overflow in index calculation. # one value large enough to test overflow in index calculation.
# one value small enough to test the schema op check # one value small enough to test the schema op check
NUM_BLOCKS = [32768, 2048] NUM_BLOCKS = [32768, 2048]
...@@ -199,8 +199,8 @@ def test_flash_attn_with_paged_kv( ...@@ -199,8 +199,8 @@ def test_flash_attn_with_paged_kv(
f"{torch.max(torch.abs(output - ref_output))}" f"{torch.max(torch.abs(output - ref_output))}"
@pytest.mark.skipif(current_platform.is_rocm(), # @pytest.mark.skipif(current_platform.is_rocm(),
reason="varlen_with_paged_kv is not supported on ROCm.") # reason="varlen_with_paged_kv is not supported on ROCm.")
@pytest.mark.parametrize("use_out", [True, False]) @pytest.mark.parametrize("use_out", [True, False])
@pytest.mark.parametrize("seq_lens", @pytest.mark.parametrize("seq_lens",
[[(1, 1328), (5, 18), [[(1, 1328), (5, 18),
...@@ -302,10 +302,10 @@ def test_varlen_with_paged_kv( ...@@ -302,10 +302,10 @@ def test_varlen_with_paged_kv(
window_size=window_size, window_size=window_size,
block_table=block_tables, block_table=block_tables,
softcap=soft_cap if soft_cap is not None else 0, softcap=soft_cap if soft_cap is not None else 0,
fa_version=fa_version, # fa_version=fa_version,
q_descale=q_descale, # q_descale=q_descale,
k_descale=k_descale, # k_descale=k_descale,
v_descale=v_descale, # v_descale=v_descale,
) )
output = output if not use_out else out output = output if not use_out else out
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import NamedTuple from typing import NamedTuple
import os
import pytest import pytest
import torch import torch
from packaging.version import Version from packaging.version import Version
...@@ -10,6 +11,7 @@ from transformers import __version__ as TRANSFORMERS_VERSION ...@@ -10,6 +11,7 @@ from transformers import __version__ as TRANSFORMERS_VERSION
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.platforms import current_platform from vllm.platforms import current_platform
from ...utils import models_path_prefix
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
...@@ -48,12 +50,12 @@ class MRoPETestInfo(NamedTuple): ...@@ -48,12 +50,12 @@ class MRoPETestInfo(NamedTuple):
TRANSFORMERS_BASE_VERSION = Version(TRANSFORMERS_VERSION).base_version TRANSFORMERS_BASE_VERSION = Version(TRANSFORMERS_VERSION).base_version
MODELS_TO_TEST = [ MODELS_TO_TEST = [
MRoPETestInfo(model_name="zai-org/GLM-4.1V-9B-Thinking"), MRoPETestInfo(model_name=os.path.join(models_path_prefix, "zai-org/GLM-4.1V-9B-Thinking")),
MRoPETestInfo(model_name="Qwen/Qwen2-VL-7B-Instruct"), MRoPETestInfo(model_name=os.path.join(models_path_prefix, "Qwen/Qwen2-VL-7B-Instruct")),
MRoPETestInfo(model_name="Qwen/Qwen2-VL-72B-Instruct"), MRoPETestInfo(model_name=os.path.join(models_path_prefix, "Qwen/Qwen2-VL-72B-Instruct")),
MRoPETestInfo(model_name="Qwen/Qwen2.5-VL-72B-Instruct"), # MRoPETestInfo(model_name=os.path.join("Qwen/Qwen2.5-VL-72B-Instruct")),
MRoPETestInfo( MRoPETestInfo(
model_name="Qwen/Qwen3-VL-4B-Instruct", model_name=os.path.join(models_path_prefix, "Qwen/Qwen3-VL-4B-Instruct"),
marks=[ marks=[
pytest.mark.skipif( pytest.mark.skipif(
Version(TRANSFORMERS_BASE_VERSION) < Version("4.57.0"), Version(TRANSFORMERS_BASE_VERSION) < Version("4.57.0"),
...@@ -61,7 +63,7 @@ MODELS_TO_TEST = [ ...@@ -61,7 +63,7 @@ MODELS_TO_TEST = [
) )
]), ]),
MRoPETestInfo( MRoPETestInfo(
model_name="Qwen/Qwen3-VL-30B-A3B-Instruct", model_name=os.path.join(models_path_prefix, "Qwen/Qwen3-VL-30B-A3B-Instruct"),
marks=[ marks=[
pytest.mark.skipif( pytest.mark.skipif(
Version(TRANSFORMERS_BASE_VERSION) < Version("4.57.0"), Version(TRANSFORMERS_BASE_VERSION) < Version("4.57.0"),
......
...@@ -172,7 +172,7 @@ def torch_moe_align_block_size( ...@@ -172,7 +172,7 @@ def torch_moe_align_block_size(
@pytest.mark.parametrize("num_experts", NUM_EXPERTS) @pytest.mark.parametrize("num_experts", NUM_EXPERTS)
@pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("pad_sorted_ids", [False, True]) @pytest.mark.parametrize("pad_sorted_ids", [False, True])
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm") # @pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_moe_align_block_size(m: int, topk: int, num_experts: int, def test_moe_align_block_size(m: int, topk: int, num_experts: int,
block_size: int, pad_sorted_ids: bool): block_size: int, pad_sorted_ids: bool):
"""Test moe_align_block_size without expert mapping""" """Test moe_align_block_size without expert mapping"""
...@@ -235,7 +235,7 @@ def test_moe_align_block_size(m: int, topk: int, num_experts: int, ...@@ -235,7 +235,7 @@ def test_moe_align_block_size(m: int, topk: int, num_experts: int,
@pytest.mark.parametrize("topk", [2, 4]) @pytest.mark.parametrize("topk", [2, 4])
@pytest.mark.parametrize("num_experts", [8]) @pytest.mark.parametrize("num_experts", [8])
@pytest.mark.parametrize("block_size", [64]) @pytest.mark.parametrize("block_size", [64])
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm") # @pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_moe_align_block_size_with_expert_map(m: int, topk: int, def test_moe_align_block_size_with_expert_map(m: int, topk: int,
num_experts: int, num_experts: int,
block_size: int): block_size: int):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment