Commit 0e607f8e authored by zhuwenwen's avatar zhuwenwen
Browse files

fix tests of kernels

set VLLM_USE_PD_SPLIT=1
update moe_align_block_size
parent cbdc58ec
...@@ -387,127 +387,127 @@ def ref_multi_query_kv_attention( ...@@ -387,127 +387,127 @@ def ref_multi_query_kv_attention(
return torch.cat(ref_outputs, dim=0) return torch.cat(ref_outputs, dim=0)
@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS) # @pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS) # @pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES) # @pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES) # @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS) # @pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES) # @pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.skipif(current_platform.is_rocm(), # @pytest.mark.skipif(current_platform.is_rocm(),
reason="Xformers backend is not supported on ROCm.") # reason="Xformers backend is not supported on ROCm.")
@torch.inference_mode() # @torch.inference_mode()
def test_multi_query_kv_attention( # def test_multi_query_kv_attention(
num_seqs: int, # num_seqs: int,
num_heads: tuple[int, int], # num_heads: tuple[int, int],
head_size: int, # head_size: int,
dtype: torch.dtype, # dtype: torch.dtype,
seed: int, # seed: int,
device: str, # device: str,
use_alibi: bool = False, # use_alibi: bool = False,
) -> None: # ) -> None:
current_platform.seed_everything(seed) # current_platform.seed_everything(seed)
torch.set_default_device(device) # torch.set_default_device(device)
# MAX_SEQ_LEN sometimes causes OOM in the reference implementation. # # MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
# As the xformers library is already tested with its own tests, we can use # # As the xformers library is already tested with its own tests, we can use
# a smaller MAX_SEQ_LEN here. # # a smaller MAX_SEQ_LEN here.
max_len = min(MAX_SEQ_LEN, 4096) # max_len = min(MAX_SEQ_LEN, 4096)
seq_lens = random.sample(range(1, max_len), num_seqs) # seq_lens = random.sample(range(1, max_len), num_seqs)
num_tokens = sum(seq_lens) # num_tokens = sum(seq_lens)
scale = float(1.0 / (head_size**0.5)) # scale = float(1.0 / (head_size**0.5))
num_query_heads, num_kv_heads = num_heads # num_query_heads, num_kv_heads = num_heads
qkv = torch.empty(num_tokens, # qkv = torch.empty(num_tokens,
num_query_heads + 2 * num_kv_heads, # num_query_heads + 2 * num_kv_heads,
head_size, # head_size,
dtype=dtype) # dtype=dtype)
qkv.uniform_(-scale, scale) # qkv.uniform_(-scale, scale)
query, key, value = qkv.split( # query, key, value = qkv.split(
[num_query_heads, num_kv_heads, num_kv_heads], dim=1) # [num_query_heads, num_kv_heads, num_kv_heads], dim=1)
num_queries_per_kv = num_query_heads // num_kv_heads # num_queries_per_kv = num_query_heads // num_kv_heads
if num_queries_per_kv > 1: # if num_queries_per_kv > 1:
# Handle MQA and GQA # # Handle MQA and GQA
key = torch.repeat_interleave(key, num_queries_per_kv, dim=1) # key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
value = torch.repeat_interleave(value, num_queries_per_kv, dim=1) # value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)
alibi_bias = None # alibi_bias = None
if use_alibi: # if use_alibi:
alibi_slopes = torch.randn(num_query_heads, dtype=torch.float) # alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)
attn_bias = make_alibi_bias(alibi_slopes, num_kv_heads, dtype, # attn_bias = make_alibi_bias(alibi_slopes, num_kv_heads, dtype,
seq_lens) # seq_lens)
output = torch.empty_like(query) # output = torch.empty_like(query)
start = 0 # start = 0
# Dynamic sequence length not supported with custom attn_bias. # # Dynamic sequence length not supported with custom attn_bias.
for i, seq_len in enumerate(seq_lens): # for i, seq_len in enumerate(seq_lens):
end = start + seq_len # end = start + seq_len
out = xops.memory_efficient_attention_forward( # out = xops.memory_efficient_attention_forward(
query[None, start:end], # query[None, start:end],
key[None, start:end], # key[None, start:end],
value[None, start:end], # value[None, start:end],
attn_bias=attn_bias[i], # attn_bias=attn_bias[i],
p=0.0, # p=0.0,
scale=scale) # scale=scale)
output[start:end].copy_(out.view_as(query[start:end])) # output[start:end].copy_(out.view_as(query[start:end]))
start += seq_len # start += seq_len
# xformers.AttentionBias to Tensor for use in reference impl. # # xformers.AttentionBias to Tensor for use in reference impl.
alibi_bias = [ # alibi_bias = [
b.materialize((1, num_query_heads, i, i), device=device).squeeze() # b.materialize((1, num_query_heads, i, i), device=device).squeeze()
for b, i in zip(attn_bias, seq_lens) # for b, i in zip(attn_bias, seq_lens)
] # ]
else: # else:
attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens) # attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
output = xops.memory_efficient_attention_forward( # output = xops.memory_efficient_attention_forward(
query.unsqueeze(0), # query.unsqueeze(0),
key.unsqueeze(0), # key.unsqueeze(0),
value.unsqueeze(0), # value.unsqueeze(0),
attn_bias=attn_bias, # attn_bias=attn_bias,
p=0.0, # p=0.0,
scale=scale, # scale=scale,
) # )
output = output.squeeze(0) # output = output.squeeze(0)
cu_seq_lens = [0] # cu_seq_lens = [0]
for seq_len in seq_lens: # for seq_len in seq_lens:
cu_seq_lens.append(cu_seq_lens[-1] + seq_len) # cu_seq_lens.append(cu_seq_lens[-1] + seq_len)
ref_output = ref_multi_query_kv_attention( # ref_output = ref_multi_query_kv_attention(
cu_seq_lens, # cu_seq_lens,
query, # query,
key, # key,
value, # value,
scale, # scale,
alibi_bias, # alibi_bias,
dtype, # dtype,
) # )
atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3 # atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5 # rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol) # torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS) # @pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS) # @pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", [64]) # @pytest.mark.parametrize("head_size", [64])
@pytest.mark.parametrize("dtype", DTYPES) # @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS) # @pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES) # @pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.skipif(current_platform.is_rocm(), # @pytest.mark.skipif(current_platform.is_rocm(),
reason="Xformers backend is not supported on ROCm.") # reason="Xformers backend is not supported on ROCm.")
@torch.inference_mode() # @torch.inference_mode()
def test_multi_query_kv_attention_with_alibi( # def test_multi_query_kv_attention_with_alibi(
num_seqs: int, # num_seqs: int,
num_heads: tuple[int, int], # num_heads: tuple[int, int],
head_size: int, # head_size: int,
dtype: torch.dtype, # dtype: torch.dtype,
seed: int, # seed: int,
device: str, # device: str,
) -> None: # ) -> None:
return test_multi_query_kv_attention( # return test_multi_query_kv_attention(
num_seqs, # num_seqs,
num_heads, # num_heads,
head_size, # head_size,
dtype, # dtype,
seed, # seed,
device, # device,
use_alibi=True, # use_alibi=True,
) # )
@pytest.mark.parametrize("attention_cls", [Attention, MultiHeadAttention]) @pytest.mark.parametrize("attention_cls", [Attention, MultiHeadAttention])
......
...@@ -447,40 +447,40 @@ def test_swap_blocks( ...@@ -447,40 +447,40 @@ def test_swap_blocks(
dist_value_caches[0][dst].cpu()) dist_value_caches[0][dst].cpu())
@pytest.mark.skipif(current_platform.is_rocm(), # @pytest.mark.skipif(current_platform.is_rocm(),
reason="FP8 is not supported on ROCm.") # reason="FP8 is not supported on ROCm.")
@pytest.mark.parametrize("num_heads", NUM_HEADS) # @pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES) # @pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES) # @pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS) # @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES) # @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS) # @pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES) # @pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode() # @torch.inference_mode()
def test_fp8_e4m3_conversion( # def test_fp8_e4m3_conversion(
num_heads: int, # num_heads: int,
head_size: int, # head_size: int,
block_size: int, # block_size: int,
num_blocks: int, # num_blocks: int,
dtype: torch.dtype, # dtype: torch.dtype,
seed: int, # seed: int,
device: str, # device: str,
) -> None: # ) -> None:
current_platform.seed_everything(seed) # current_platform.seed_everything(seed)
low = -224.0 # low = -224.0
high = 224.0 # high = 224.0
shape = (num_blocks, num_heads, head_size, block_size) # shape = (num_blocks, num_heads, head_size, block_size)
cache = torch.empty(shape, dtype=dtype, device=device) # cache = torch.empty(shape, dtype=dtype, device=device)
cache.uniform_(low, high) # cache.uniform_(low, high)
cache_fp8 = torch.empty_like(cache, dtype=torch.uint8) # cache_fp8 = torch.empty_like(cache, dtype=torch.uint8)
ops.convert_fp8(cache_fp8, cache) # ops.convert_fp8(cache_fp8, cache)
converted_cache = torch.empty_like(cache) # converted_cache = torch.empty_like(cache)
ops.convert_fp8(converted_cache, cache_fp8) # ops.convert_fp8(converted_cache, cache_fp8)
torch.testing.assert_close(cache, converted_cache, atol=0.001, rtol=0.1) # torch.testing.assert_close(cache, converted_cache, atol=0.001, rtol=0.1)
def _create_mla_cache( def _create_mla_cache(
...@@ -596,117 +596,117 @@ def test_concat_and_cache_mla( ...@@ -596,117 +596,117 @@ def test_concat_and_cache_mla(
torch.testing.assert_close(kv_cache, ref_kv_cache) torch.testing.assert_close(kv_cache, ref_kv_cache)
@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS) # @pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS) # @pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS_MLA) # @pytest.mark.parametrize("num_tokens", NUM_TOKENS_MLA)
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA) # @pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA) # @pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
@pytest.mark.parametrize("dtype", DTYPES) # @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS) # @pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES) # @pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode() # @torch.inference_mode()
def test_concat_and_cache_ds_mla( # def test_concat_and_cache_ds_mla(
kv_lora_rank: int, # kv_lora_rank: int,
qk_rope_head_dim: int, # qk_rope_head_dim: int,
num_tokens: int, # num_tokens: int,
block_size: int, # block_size: int,
num_blocks: int, # num_blocks: int,
dtype: torch.dtype, # dtype: torch.dtype,
seed: int, # seed: int,
device: str, # device: str,
) -> None: # ) -> None:
if dtype.itemsize != 2: # if dtype.itemsize != 2:
pytest.skip("ds_mla only supports 16-bit input") # pytest.skip("ds_mla only supports 16-bit input")
kv_cache_dtype = "fp8_ds_mla" # kv_cache_dtype = "fp8_ds_mla"
current_platform.seed_everything(seed) # current_platform.seed_everything(seed)
torch.set_default_device(device) # torch.set_default_device(device)
total_slots = num_blocks * block_size # total_slots = num_blocks * block_size
slot_mapping_lst = random.sample(range(total_slots), num_tokens) # slot_mapping_lst = random.sample(range(total_slots), num_tokens)
slot_mapping = torch.tensor(slot_mapping_lst, # slot_mapping = torch.tensor(slot_mapping_lst,
dtype=torch.long, # dtype=torch.long,
device=device) # device=device)
kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device) # kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device)
k_pe = torch.randn(num_tokens, # k_pe = torch.randn(num_tokens,
qk_rope_head_dim, # qk_rope_head_dim,
dtype=dtype, # dtype=dtype,
device=device) # device=device)
entry_size = kv_lora_rank + (4 * 4) + (2 * qk_rope_head_dim) # entry_size = kv_lora_rank + (4 * 4) + (2 * qk_rope_head_dim)
scale = torch.tensor(1.0, dtype=torch.float32, device=device) # scale = torch.tensor(1.0, dtype=torch.float32, device=device)
kv_cache = _create_mla_cache(num_blocks, # kv_cache = _create_mla_cache(num_blocks,
block_size, # block_size,
entry_size, # entry_size,
dtype=torch.uint8, # dtype=torch.uint8,
kv_cache_dtype=kv_cache_dtype, # kv_cache_dtype=kv_cache_dtype,
device=device) # device=device)
ref_cache = torch.zeros_like(kv_cache, dtype=kv_cache.dtype) # ref_cache = torch.zeros_like(kv_cache, dtype=kv_cache.dtype)
tile_data = torch.zeros(128, dtype=dtype, device=device) # tile_data = torch.zeros(128, dtype=dtype, device=device)
for i in range(num_tokens): # for i in range(num_tokens):
slot = slot_mapping[i].item() # slot = slot_mapping[i].item()
block_idx = slot // block_size # block_idx = slot // block_size
block_offset = slot % block_size # block_offset = slot % block_size
ref_cache_slice = ref_cache[block_idx, block_offset] # ref_cache_slice = ref_cache[block_idx, block_offset]
ref_cache_16bit = ref_cache_slice.view(dtype) # ref_cache_16bit = ref_cache_slice.view(dtype)
ref_cache_32bit = ref_cache_slice.view(torch.float32) # ref_cache_32bit = ref_cache_slice.view(torch.float32)
kv_c_data = kv_c[i] # kv_c_data = kv_c[i]
for tile_idx in range(4): # for tile_idx in range(4):
tile_start = tile_idx * 128 # tile_start = tile_idx * 128
tile_end = (tile_idx + 1) * 128 # tile_end = (tile_idx + 1) * 128
tile_data[:] = kv_c_data[tile_start:tile_end] # tile_data[:] = kv_c_data[tile_start:tile_end]
# tile_scale = tile_data.amax().to(torch.float32) / 448. # # tile_scale = tile_data.amax().to(torch.float32) / 448.
# NOTE: Using torch's amax() gives different results, # # NOTE: Using torch's amax() gives different results,
# so this must be manually computed. # # so this must be manually computed.
tile_data_float = tile_data.to(torch.float32) # tile_data_float = tile_data.to(torch.float32)
manual_max = abs(tile_data_float[0]) # manual_max = abs(tile_data_float[0])
for j in range(1, 128): # for j in range(1, 128):
manual_max = max(manual_max, abs(tile_data_float[j])) # manual_max = max(manual_max, abs(tile_data_float[j]))
tile_scale = manual_max / 448. # tile_scale = manual_max / 448.
ref_cache_32bit[kv_lora_rank // 4 + tile_idx] = tile_scale # ref_cache_32bit[kv_lora_rank // 4 + tile_idx] = tile_scale
ops.convert_fp8(ref_cache_slice[tile_start:tile_end], # ops.convert_fp8(ref_cache_slice[tile_start:tile_end],
tile_data, # tile_data,
tile_scale.item(), # tile_scale.item(),
kv_dtype="fp8") # kv_dtype="fp8")
for j in range(qk_rope_head_dim): # for j in range(qk_rope_head_dim):
ref_cache_16bit[kv_lora_rank // 2 + 8 + j] = k_pe[i, j] # ref_cache_16bit[kv_lora_rank // 2 + 8 + j] = k_pe[i, j]
opcheck( # opcheck(
torch.ops._C_cache_ops.concat_and_cache_mla, # torch.ops._C_cache_ops.concat_and_cache_mla,
(kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale), # (kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale),
test_utils=DEFAULT_OPCHECK_TEST_UTILS, # test_utils=DEFAULT_OPCHECK_TEST_UTILS,
) # )
ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping, # ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping,
kv_cache_dtype, scale) # kv_cache_dtype, scale)
for i in range(num_tokens): # for i in range(num_tokens):
slot = slot_mapping[i].item() # slot = slot_mapping[i].item()
block_idx = slot // block_size # block_idx = slot // block_size
block_offset = slot % block_size # block_offset = slot % block_size
kv_cache_slice = kv_cache[block_idx, block_offset] # kv_cache_slice = kv_cache[block_idx, block_offset]
ref_cache_slice = ref_cache[block_idx, block_offset] # ref_cache_slice = ref_cache[block_idx, block_offset]
kv_nope = kv_cache_slice[:kv_lora_rank] # kv_nope = kv_cache_slice[:kv_lora_rank]
ref_nope = ref_cache_slice[:kv_lora_rank] # ref_nope = ref_cache_slice[:kv_lora_rank]
kv_scales = kv_cache_slice.view(torch.float32)[kv_lora_rank // # kv_scales = kv_cache_slice.view(torch.float32)[kv_lora_rank //
4:kv_lora_rank // 4 + 4] # 4:kv_lora_rank // 4 + 4]
ref_scales = ref_cache_slice.view( # ref_scales = ref_cache_slice.view(
torch.float32)[kv_lora_rank // 4:kv_lora_rank // 4 + 4] # torch.float32)[kv_lora_rank // 4:kv_lora_rank // 4 + 4]
kv_rope = kv_cache_slice.view(dtype)[kv_lora_rank // 2 + 8:] # kv_rope = kv_cache_slice.view(dtype)[kv_lora_rank // 2 + 8:]
ref_rope = ref_cache_slice.view(dtype)[kv_lora_rank // 2 + 8:] # ref_rope = ref_cache_slice.view(dtype)[kv_lora_rank // 2 + 8:]
torch.testing.assert_close(kv_nope, ref_nope, atol=0.001, rtol=0.1) # torch.testing.assert_close(kv_nope, ref_nope, atol=0.001, rtol=0.1)
torch.testing.assert_close(kv_scales, ref_scales, atol=0.001, rtol=0.1) # torch.testing.assert_close(kv_scales, ref_scales, atol=0.001, rtol=0.1)
torch.testing.assert_close(kv_rope, ref_rope, atol=0.001, rtol=0.1) # torch.testing.assert_close(kv_rope, ref_rope, atol=0.001, rtol=0.1)
@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS) @pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
...@@ -993,70 +993,70 @@ def test_cp_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size, ...@@ -993,70 +993,70 @@ def test_cp_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size,
torch.testing.assert_close(dst, expected) torch.testing.assert_close(dst, expected)
@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS) # @pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS) # @pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS_MLA) # @pytest.mark.parametrize("num_tokens", NUM_TOKENS_MLA)
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA) # @pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA) # @pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
@pytest.mark.parametrize("dtype", DTYPES) # @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS) # @pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.cpu_model # @pytest.mark.cpu_model
@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only") # @pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only")
@torch.inference_mode() # @torch.inference_mode()
def test_concat_and_cache_mla_cpu( # def test_concat_and_cache_mla_cpu(
kv_lora_rank: int, # kv_lora_rank: int,
qk_rope_head_dim: int, # qk_rope_head_dim: int,
num_tokens: int, # num_tokens: int,
block_size: int, # block_size: int,
num_blocks: int, # num_blocks: int,
dtype: torch.dtype, # dtype: torch.dtype,
seed: int, # seed: int,
) -> None: # ) -> None:
device = "cpu" # device = "cpu"
kv_cache_dtype = "auto" # kv_cache_dtype = "auto"
current_platform.seed_everything(seed) # current_platform.seed_everything(seed)
torch.set_default_device(device) # torch.set_default_device(device)
total_slots = num_blocks * block_size # total_slots = num_blocks * block_size
slot_mapping_lst = random.sample(range(total_slots), num_tokens) # slot_mapping_lst = random.sample(range(total_slots), num_tokens)
slot_mapping = torch.tensor(slot_mapping_lst, # slot_mapping = torch.tensor(slot_mapping_lst,
dtype=torch.long, # dtype=torch.long,
device=device) # device=device)
kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device) # kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device)
k_pe = torch.randn(num_tokens, # k_pe = torch.randn(num_tokens,
qk_rope_head_dim, # qk_rope_head_dim,
dtype=dtype, # dtype=dtype,
device=device) # device=device)
entry_size = kv_lora_rank + qk_rope_head_dim # entry_size = kv_lora_rank + qk_rope_head_dim
scale = torch.tensor(0.1, dtype=torch.float32, device=device) # scale = torch.tensor(0.1, dtype=torch.float32, device=device)
kv_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype, # kv_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
kv_cache_dtype, device) # kv_cache_dtype, device)
ref_temp = torch.zeros(*kv_cache.shape, dtype=dtype, device=device) # ref_temp = torch.zeros(*kv_cache.shape, dtype=dtype, device=device)
for i in range(num_tokens): # for i in range(num_tokens):
slot = slot_mapping[i].item() # slot = slot_mapping[i].item()
block_idx = slot // block_size # block_idx = slot // block_size
block_offset = slot % block_size # block_offset = slot % block_size
ref_temp[block_idx, block_offset, :kv_lora_rank] = kv_c[i] # ref_temp[block_idx, block_offset, :kv_lora_rank] = kv_c[i]
ref_temp[block_idx, block_offset, kv_lora_rank:] = k_pe[i] # ref_temp[block_idx, block_offset, kv_lora_rank:] = k_pe[i]
if kv_cache_dtype == "fp8": # if kv_cache_dtype == "fp8":
ref_kv_cache = torch.empty_like(ref_temp, dtype=kv_cache.dtype) # ref_kv_cache = torch.empty_like(ref_temp, dtype=kv_cache.dtype)
ops.convert_fp8(ref_kv_cache, # ops.convert_fp8(ref_kv_cache,
ref_temp, # ref_temp,
scale.item(), # scale.item(),
kv_dtype=kv_cache_dtype) # kv_dtype=kv_cache_dtype)
else: # else:
ref_kv_cache = ref_temp # ref_kv_cache = ref_temp
opcheck( # opcheck(
torch.ops._C_cache_ops.concat_and_cache_mla, # torch.ops._C_cache_ops.concat_and_cache_mla,
(kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale), # (kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale),
test_utils=DEFAULT_OPCHECK_TEST_UTILS, # test_utils=DEFAULT_OPCHECK_TEST_UTILS,
) # )
ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping, # ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping,
kv_cache_dtype, scale) # kv_cache_dtype, scale)
torch.testing.assert_close(kv_cache, ref_kv_cache) # torch.testing.assert_close(kv_cache, ref_kv_cache)
...@@ -11,6 +11,7 @@ from vllm.attention.ops.flashmla import (flash_mla_with_kvcache, ...@@ -11,6 +11,7 @@ from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
get_mla_metadata, get_mla_metadata,
is_flashmla_supported) is_flashmla_supported)
from vllm.triton_utils import triton from vllm.triton_utils import triton
from vllm.platforms import current_platform
def cal_diff(x: torch.Tensor, def cal_diff(x: torch.Tensor,
...@@ -42,7 +43,7 @@ FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \ ...@@ -42,7 +43,7 @@ FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \
@pytest.mark.parametrize("causal", [True]) @pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("varlen", [False, True]) @pytest.mark.parametrize("varlen", [False, True])
@pytest.mark.parametrize("torch_dtype", @pytest.mark.parametrize("torch_dtype",
[torch.bfloat16, torch.float16, torch.float8_e4m3fn]) [torch.bfloat16, torch.float16, torch.float8_e4m3fn] if not current_platform.is_rocm() else [torch.bfloat16])
@torch.inference_mode() @torch.inference_mode()
def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal, def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
varlen, torch_dtype): varlen, torch_dtype):
......
...@@ -87,7 +87,7 @@ def generate_markdown_table(): ...@@ -87,7 +87,7 @@ def generate_markdown_table():
@torch.inference_mode() @torch.inference_mode()
def test_merge_attn_states(num_tokens: int, num_query_heads: int, def test_merge_attn_states(num_tokens: int, num_query_heads: int,
head_size: int, output_dtype: torch.dtype): head_size: int, output_dtype: torch.dtype):
if not current_platform.is_cuda(): if not current_platform.is_cuda() or not current_platform.is_rocm():
pytest.skip('Currently only support compare triton merge_attn_states ' pytest.skip('Currently only support compare triton merge_attn_states '
'with custom cuda merge_attn_states kernel') 'with custom cuda merge_attn_states kernel')
......
...@@ -46,7 +46,7 @@ def test_mha_attn_platform(device: str): ...@@ -46,7 +46,7 @@ def test_mha_attn_platform(device: str):
patch("vllm.model_executor.models.vision.current_platform", patch("vllm.model_executor.models.vision.current_platform",
RocmPlatform()): RocmPlatform()):
attn = MultiHeadAttention(16, 64, scale=1) attn = MultiHeadAttention(16, 64, scale=1)
assert attn.attn_backend == _Backend.TORCH_SDPA assert attn.attn_backend == _Backend.FLASH_ATTN # _Backend.TORCH_SDPA
else: else:
# Test CUDA with head_size=64 (divisible by 32) # Test CUDA with head_size=64 (divisible by 32)
# - should use vLLM's FlashAttention # - should use vLLM's FlashAttention
......
...@@ -73,7 +73,7 @@ def test_contexted_kv_attention( ...@@ -73,7 +73,7 @@ def test_contexted_kv_attention(
MAX_CTX_LEN = 1024 MAX_CTX_LEN = 1024
BS = 10 BS = 10
cache_size = 640 cache_size = 640
block_size = 32 block_size = 32 if not current_platform.is_rocm() else 16
max_block_per_request = 64 max_block_per_request = 64
query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)] query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
# ensure one sequence in batch is a decode # ensure one sequence in batch is a decode
...@@ -249,318 +249,318 @@ def test_contexted_kv_attention( ...@@ -249,318 +249,318 @@ def test_contexted_kv_attention(
torch.testing.assert_close(output, output_ref, atol=atol, rtol=0) torch.testing.assert_close(output, output_ref, atol=atol, rtol=0)
@pytest.mark.parametrize("num_heads", NUM_HEADS) # @pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV) # @pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
@pytest.mark.parametrize("head_size", HEAD_SIZES) # @pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES) # @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES) # @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES) # @pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("op", OPS) # @pytest.mark.parametrize("op", OPS)
@torch.inference_mode() # @torch.inference_mode()
def test_contexted_kv_attention_alibi( # def test_contexted_kv_attention_alibi(
num_heads: int, # num_heads: int,
num_queries_per_kv: int, # num_queries_per_kv: int,
head_size: int, # head_size: int,
dtype: torch.dtype, # dtype: torch.dtype,
kv_cache_dtype: str, # kv_cache_dtype: str,
device: str, # device: str,
op: Callable, # op: Callable,
) -> None: # ) -> None:
if 'fp8' in kv_cache_dtype and not current_platform.has_device_capability( # if 'fp8' in kv_cache_dtype and not current_platform.has_device_capability(
89): # 89):
pytest.skip( # pytest.skip(
'Triton limitation: fp8e4nv data type is not supported on CUDA' # 'Triton limitation: fp8e4nv data type is not supported on CUDA'
' arch < 89') # ' arch < 89')
current_platform.seed_everything(0) # current_platform.seed_everything(0)
torch.set_default_device(device) # torch.set_default_device(device)
# Need this, otherwise when we capture the graph the process # # Need this, otherwise when we capture the graph the process
# for GPU 1 would run on both GPU0 and GPU1 and things would hang # # for GPU 1 would run on both GPU0 and GPU1 and things would hang
# # #
# see also similar issue: https://github.com/Dao-AILab/flash-attention/issues/523 # # see also similar issue: https://github.com/Dao-AILab/flash-attention/issues/523
torch.cuda.set_device(device) # torch.cuda.set_device(device)
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: # def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
# Fork from: vllm/vllm/model_executor/models/bloom.py#L44 # # Fork from: vllm/vllm/model_executor/models/bloom.py#L44
closest_power_of_2 = 2**math.floor(math.log2(total_num_heads)) # closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
base = torch.tensor( # base = torch.tensor(
2**(-(2**-(math.log2(closest_power_of_2) - 3))), # 2**(-(2**-(math.log2(closest_power_of_2) - 3))),
dtype=torch.float32, # dtype=torch.float32,
) # )
powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32) # powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
slopes = torch.pow(base, powers) # slopes = torch.pow(base, powers)
if closest_power_of_2 != total_num_heads: # if closest_power_of_2 != total_num_heads:
extra_base = torch.tensor( # extra_base = torch.tensor(
2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))), # 2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
dtype=torch.float32, # dtype=torch.float32,
) # )
num_remaining_heads = min(closest_power_of_2, # num_remaining_heads = min(closest_power_of_2,
total_num_heads - closest_power_of_2) # total_num_heads - closest_power_of_2)
extra_powers = torch.arange(start=1, # extra_powers = torch.arange(start=1,
end=1 + 2 * num_remaining_heads, # end=1 + 2 * num_remaining_heads,
step=2, # step=2,
dtype=torch.int32) # dtype=torch.int32)
slopes = torch.cat( # slopes = torch.cat(
[slopes, torch.pow(extra_base, extra_powers)], dim=0) # [slopes, torch.pow(extra_base, extra_powers)], dim=0)
return slopes # return slopes
alibi_slopes = _get_alibi_slopes(num_heads).to(device) # alibi_slopes = _get_alibi_slopes(num_heads).to(device)
MAX_SEQ_LEN = 1024 # MAX_SEQ_LEN = 1024
MAX_CTX_LEN = 1024 # MAX_CTX_LEN = 1024
BS = 10 # BS = 10
cache_size = 640 # cache_size = 640
block_size = 32 # block_size = 32
max_block_per_request = 64 # max_block_per_request = 64
query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)] # query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)] # ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)]
seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)] # seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)]
num_kv_heads = num_heads // num_queries_per_kv # num_kv_heads = num_heads // num_queries_per_kv
num_tokens = sum(query_lens) # num_tokens = sum(query_lens)
query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype) # query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
query.uniform_(-1e-3, 1e-3) # query.uniform_(-1e-3, 1e-3)
output = torch.empty(num_tokens, num_heads, head_size, dtype=dtype) # output = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
kv = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size, dtype=dtype) # kv = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size, dtype=dtype)
kv.uniform_(-1e-3, 1e-3) # kv.uniform_(-1e-3, 1e-3)
key, value = kv.unbind(dim=1) # key, value = kv.unbind(dim=1)
if kv_cache_dtype == "auto": # if kv_cache_dtype == "auto":
cache_dtype = dtype # cache_dtype = dtype
else: # else:
cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype] # cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype]
k_cache = torch.zeros(cache_size, # k_cache = torch.zeros(cache_size,
block_size, # block_size,
num_kv_heads, # num_kv_heads,
head_size, # head_size,
dtype=cache_dtype) # dtype=cache_dtype)
v_cache = torch.zeros(cache_size, # v_cache = torch.zeros(cache_size,
block_size, # block_size,
num_kv_heads, # num_kv_heads,
head_size, # head_size,
dtype=cache_dtype) # dtype=cache_dtype)
k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype) # k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype) # v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
values = torch.arange(0, cache_size, dtype=torch.long) # values = torch.arange(0, cache_size, dtype=torch.long)
values = values[torch.randperm(cache_size)] # values = values[torch.randperm(cache_size)]
block_table = values[:BS * max_block_per_request].view( # block_table = values[:BS * max_block_per_request].view(
BS, max_block_per_request) # BS, max_block_per_request)
b_seq_len = torch.tensor(seq_lens, dtype=torch.long) # b_seq_len = torch.tensor(seq_lens, dtype=torch.long)
b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long) # b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long)
b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, # b_start_loc = torch.cumsum(torch.tensor([0] + query_lens,
dtype=torch.long), # dtype=torch.long),
dim=0) # dim=0)
max_input_len = MAX_SEQ_LEN # max_input_len = MAX_SEQ_LEN
# copy kv to cache # # copy kv to cache
b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1], # b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1],
dtype=torch.long), # dtype=torch.long),
dim=0) # dim=0)
for i in range(BS): # for i in range(BS):
for j in range(query_lens[i]): # for j in range(query_lens[i]):
k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] + # k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] +
j]) # j])
v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] + # v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] +
b_ctx_len[i] + j]) # b_ctx_len[i] + j])
cur_ctx = 0 # cur_ctx = 0
block_id = 0 # block_id = 0
while cur_ctx < b_ctx_len[i]: # while cur_ctx < b_ctx_len[i]:
start_loc = b_seq_start_loc[i] + cur_ctx # start_loc = b_seq_start_loc[i] + cur_ctx
if cur_ctx + block_size > b_ctx_len[i]: # if cur_ctx + block_size > b_ctx_len[i]:
end_loc = b_seq_start_loc[i] + b_ctx_len[i] # end_loc = b_seq_start_loc[i] + b_ctx_len[i]
else: # else:
end_loc = start_loc + block_size # end_loc = start_loc + block_size
start_slot = block_table[i, block_id] * block_size # start_slot = block_table[i, block_id] * block_size
end_slot = start_slot + end_loc - start_loc # end_slot = start_slot + end_loc - start_loc
k_cache.view(-1, num_kv_heads, # k_cache.view(-1, num_kv_heads,
head_size)[start_slot:end_slot].copy_( # head_size)[start_slot:end_slot].copy_(
key[start_loc:end_loc]) # key[start_loc:end_loc])
v_cache.view(-1, num_kv_heads, # v_cache.view(-1, num_kv_heads,
head_size)[start_slot:end_slot].copy_( # head_size)[start_slot:end_slot].copy_(
value[start_loc:end_loc]) # value[start_loc:end_loc])
cur_ctx += block_size # cur_ctx += block_size
block_id += 1 # block_id += 1
# transpose K_cache[num_blocks, block_size, num_kv_heads, head_size] # # transpose K_cache[num_blocks, block_size, num_kv_heads, head_size]
# to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8] # # to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8]
k_cache = k_cache.view(-1, block_size, num_kv_heads, head_size // 8, # k_cache = k_cache.view(-1, block_size, num_kv_heads, head_size // 8,
8).permute(0, 2, 3, 1, 4).contiguous() # 8).permute(0, 2, 3, 1, 4).contiguous()
# transpose V_cache[num_blocks, block_size, num_kv_heads, head_size] # # transpose V_cache[num_blocks, block_size, num_kv_heads, head_size]
# to V_cache[num_blocks, num_kv_heads, head_size, block_size] # # to V_cache[num_blocks, num_kv_heads, head_size, block_size]
v_cache = v_cache.view(-1, block_size, num_kv_heads, # v_cache = v_cache.view(-1, block_size, num_kv_heads,
head_size).permute(0, 2, 3, 1).contiguous() # head_size).permute(0, 2, 3, 1).contiguous()
k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device) # k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
# Warm up the Triton kernel by calling it once before actually measuring # # Warm up the Triton kernel by calling it once before actually measuring
# generation time # # generation time
op(query, # op(query,
k, # k,
v, # v,
output, # output,
kv_cache_dtype, # kv_cache_dtype,
k_cache, # k_cache,
v_cache, # v_cache,
block_table, # block_table,
b_start_loc, # b_start_loc,
b_seq_len, # b_seq_len,
MAX_CTX_LEN, # MAX_CTX_LEN,
max_input_len, # max_input_len,
k_scale, # k_scale,
v_scale, # v_scale,
alibi_slopes=alibi_slopes) # alibi_slopes=alibi_slopes)
torch.cuda.synchronize() # torch.cuda.synchronize()
start_time = time.time() # start_time = time.time()
op(query, # op(query,
k, # k,
v, # v,
output, # output,
kv_cache_dtype, # kv_cache_dtype,
k_cache, # k_cache,
v_cache, # v_cache,
block_table, # block_table,
b_start_loc, # b_start_loc,
b_seq_len, # b_seq_len,
MAX_CTX_LEN, # MAX_CTX_LEN,
max_input_len, # max_input_len,
k_scale, # k_scale,
v_scale, # v_scale,
alibi_slopes=alibi_slopes) # alibi_slopes=alibi_slopes)
torch.cuda.synchronize() # torch.cuda.synchronize()
end_time = time.time() # end_time = time.time()
print(f"triton Time: {(end_time - start_time)*1000:.2f} ms") # print(f"triton Time: {(end_time - start_time)*1000:.2f} ms")
if not current_platform.is_rocm(): # if not current_platform.is_rocm():
scale = float(1.0 / (head_size**0.5)) # scale = float(1.0 / (head_size**0.5))
# NOTE(DefTruth): In order to reuse _make_alibi_bias function, # # NOTE(DefTruth): In order to reuse _make_alibi_bias function,
# we have to pad query tensor before MQA/GQA expanding. # # we have to pad query tensor before MQA/GQA expanding.
if query.shape[0] != key.shape[0]: # if query.shape[0] != key.shape[0]:
query_pad = torch.empty(sum(seq_lens), # query_pad = torch.empty(sum(seq_lens),
num_heads, # num_heads,
head_size, # head_size,
dtype=dtype) # dtype=dtype)
query_pad.uniform_(-1e-3, 1e-3) # query_pad.uniform_(-1e-3, 1e-3)
seq_start = 0 # seq_start = 0
query_start = 0 # query_start = 0
for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)): # for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)):
seq_end = seq_start + seq_len # seq_end = seq_start + seq_len
query_end = query_start + query_len # query_end = query_start + query_len
query_pad[seq_start:seq_end, ...] = torch.cat([ # query_pad[seq_start:seq_end, ...] = torch.cat([
torch.zeros( # torch.zeros(
seq_len - query_len, num_heads, head_size, dtype=dtype), # seq_len - query_len, num_heads, head_size, dtype=dtype),
query[query_start:query_end, ...] # query[query_start:query_end, ...]
], # ],
dim=0) # dim=0)
seq_start += seq_len # seq_start += seq_len
query_start += query_len # query_start += query_len
query = query_pad # query = query_pad
if num_kv_heads != num_heads: # if num_kv_heads != num_heads:
# As of Nov 2023, xformers only supports MHA. For MQA/GQA, # # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
# project the key and value tensors to the desired number of # # project the key and value tensors to the desired number of
# heads. # # heads.
# # #
# see also: vllm/model_executor/layers/attention.py # # see also: vllm/model_executor/layers/attention.py
key = key[:, :, None, :].expand(key.shape[0], num_kv_heads, # key = key[:, :, None, :].expand(key.shape[0], num_kv_heads,
num_queries_per_kv, key.shape[-1]) # num_queries_per_kv, key.shape[-1])
value = value[:, :, # value = value[:, :,
None, :].expand(value.shape[0], num_kv_heads, # None, :].expand(value.shape[0], num_kv_heads,
num_queries_per_kv, value.shape[-1]) # num_queries_per_kv, value.shape[-1])
# [seq, num_kv_heads, num_queries_per_kv, dk]=> # # [seq, num_kv_heads, num_queries_per_kv, dk]=>
# [seq, num_kv_heads*num_queries_per_kv, dk] to comply with rest of the # # [seq, num_kv_heads*num_queries_per_kv, dk] to comply with rest of the
# codebase. We save some time reshaping alibi matrix at runtime. # # codebase. We save some time reshaping alibi matrix at runtime.
key = key.reshape(key.shape[0], -1, key.shape[-1]) # key = key.reshape(key.shape[0], -1, key.shape[-1])
value = value.reshape(value.shape[0], -1, value.shape[-1]) # value = value.reshape(value.shape[0], -1, value.shape[-1])
query = query.unsqueeze(0) # query = query.unsqueeze(0)
key = key.unsqueeze(0) # key = key.unsqueeze(0)
value = value.unsqueeze(0) # value = value.unsqueeze(0)
attn_bias = _make_alibi_bias(alibi_slopes, num_kv_heads, dtype, seq_lens) # attn_bias = _make_alibi_bias(alibi_slopes, num_kv_heads, dtype, seq_lens)
output_ref = torch.empty_like(output) # output_ref = torch.empty_like(output)
seq_start = 0 # seq_start = 0
query_start = 0 # query_start = 0
start_time = time.time() # start_time = time.time()
# Attention with alibi slopes. # # Attention with alibi slopes.
# FIXME(DefTruth): Because xformers does not support dynamic sequence # # FIXME(DefTruth): Because xformers does not support dynamic sequence
# lengths with custom attention bias, we process each prompt one by # # lengths with custom attention bias, we process each prompt one by
# one. This is inefficient, especially when we have many short prompts. # # one. This is inefficient, especially when we have many short prompts.
# modified from: vllm/attention/backends/xformers.py#L343 # # modified from: vllm/attention/backends/xformers.py#L343
for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)): # for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)):
seq_end = seq_start + seq_len # seq_end = seq_start + seq_len
query_end = query_start + query_len # query_end = query_start + query_len
out = xops.memory_efficient_attention_forward(query[:, # out = xops.memory_efficient_attention_forward(query[:,
seq_start:seq_end], # seq_start:seq_end],
key[:, # key[:,
seq_start:seq_end], # seq_start:seq_end],
value[:, # value[:,
seq_start:seq_end], # seq_start:seq_end],
attn_bias=attn_bias[i], # attn_bias=attn_bias[i],
p=0.0, # p=0.0,
scale=scale) # scale=scale)
out = out.view_as(query[:, seq_start:seq_end]).view( # out = out.view_as(query[:, seq_start:seq_end]).view(
seq_len, num_heads, head_size) # seq_len, num_heads, head_size)
output_ref[query_start:query_end, ...].copy_(out[seq_len - query_len:, # output_ref[query_start:query_end, ...].copy_(out[seq_len - query_len:,
...]) # ...])
seq_start += seq_len # seq_start += seq_len
query_start += query_len # query_start += query_len
query = query_pad # query = query_pad
if num_kv_heads != num_heads: # if num_kv_heads != num_heads:
# As of Nov 2023, xformers only supports MHA. For MQA/GQA, # # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
# project the key and value tensors to the desired number of # # project the key and value tensors to the desired number of
# heads. # # heads.
# # #
# see also: vllm/model_executor/layers/attention.py # # see also: vllm/model_executor/layers/attention.py
key = key[:, :, None, :].expand(key.shape[0], num_kv_heads, # key = key[:, :, None, :].expand(key.shape[0], num_kv_heads,
num_queries_per_kv, key.shape[-1]) # num_queries_per_kv, key.shape[-1])
value = value[:, :, # value = value[:, :,
None, :].expand(value.shape[0], num_kv_heads, # None, :].expand(value.shape[0], num_kv_heads,
num_queries_per_kv, value.shape[-1]) # num_queries_per_kv, value.shape[-1])
# [seq, num_kv_heads, num_queries_per_kv, dk]=> # # [seq, num_kv_heads, num_queries_per_kv, dk]=>
# [seq, num_kv_heads*num_queries_per_kv, dk] to comply with rest of the # # [seq, num_kv_heads*num_queries_per_kv, dk] to comply with rest of the
# codebase. We save some time reshaping alibi matrix at runtime. # # codebase. We save some time reshaping alibi matrix at runtime.
key = key.reshape(key.shape[0], -1, key.shape[-1]) # key = key.reshape(key.shape[0], -1, key.shape[-1])
value = value.reshape(value.shape[0], -1, value.shape[-1]) # value = value.reshape(value.shape[0], -1, value.shape[-1])
query = query.unsqueeze(0) # query = query.unsqueeze(0)
key = key.unsqueeze(0) # key = key.unsqueeze(0)
value = value.unsqueeze(0) # value = value.unsqueeze(0)
attn_bias = make_alibi_bias(alibi_slopes, num_kv_heads, dtype, seq_lens) # attn_bias = make_alibi_bias(alibi_slopes, num_kv_heads, dtype, seq_lens)
output_ref = torch.empty_like(output) # output_ref = torch.empty_like(output)
seq_start = 0 # seq_start = 0
query_start = 0 # query_start = 0
if not current_platform.is_rocm(): # if not current_platform.is_rocm():
start_time = time.time() # start_time = time.time()
# Attention with alibi slopes. # # Attention with alibi slopes.
# FIXME(DefTruth): Because xformers does not support dynamic sequence # # FIXME(DefTruth): Because xformers does not support dynamic sequence
# lengths with custom attention bias, we process each prompt one by # # lengths with custom attention bias, we process each prompt one by
# one. This is inefficient, especially when we have many short prompts. # # one. This is inefficient, especially when we have many short prompts.
# modified from: vllm/v1/attention/backends/xformers.py#L343 # # modified from: vllm/v1/attention/backends/xformers.py#L343
for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)): # for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)):
seq_end = seq_start + seq_len # seq_end = seq_start + seq_len
query_end = query_start + query_len # query_end = query_start + query_len
out = xops.memory_efficient_attention_forward(query[:, # out = xops.memory_efficient_attention_forward(query[:,
seq_start:seq_end], # seq_start:seq_end],
key[:, # key[:,
seq_start:seq_end], # seq_start:seq_end],
value[:, # value[:,
seq_start:seq_end], # seq_start:seq_end],
attn_bias=attn_bias[i], # attn_bias=attn_bias[i],
p=0.0, # p=0.0,
scale=scale) # scale=scale)
out = out.view_as(query[:, seq_start:seq_end]).view( # out = out.view_as(query[:, seq_start:seq_end]).view(
seq_len, num_heads, head_size) # seq_len, num_heads, head_size)
output_ref[query_start:query_end, ...].copy_(out[seq_len - query_len:, # output_ref[query_start:query_end, ...].copy_(out[seq_len - query_len:,
...]) # ...])
seq_start += seq_len # seq_start += seq_len
query_start += query_len # query_start += query_len
torch.cuda.synchronize() # torch.cuda.synchronize()
end_time = time.time() # end_time = time.time()
print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms") # print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-6 # atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-6
torch.testing.assert_close(output, output_ref, atol=atol, rtol=0) # torch.testing.assert_close(output, output_ref, atol=atol, rtol=0)
# These tests are optional to only run when explicitly invoked # These tests are optional to only run when explicitly invoked
...@@ -595,23 +595,23 @@ def test_contexted_kv_attention_f32( ...@@ -595,23 +595,23 @@ def test_contexted_kv_attention_f32(
op) op)
@pytest.mark.optional # @pytest.mark.optional
@pytest.mark.parametrize("num_heads", NUM_HEADS) # @pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV) # @pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
@pytest.mark.parametrize("head_size", HEAD_SIZES) # @pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", [torch.float32]) # @pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES) # @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES) # @pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("op", OPS) # @pytest.mark.parametrize("op", OPS)
@torch.inference_mode() # @torch.inference_mode()
def test_contexted_kv_attention_alibi_f32( # def test_contexted_kv_attention_alibi_f32(
num_heads: int, # num_heads: int,
num_queries_per_kv: int, # num_queries_per_kv: int,
head_size: int, # head_size: int,
dtype: torch.dtype, # dtype: torch.dtype,
kv_cache_dtype: str, # kv_cache_dtype: str,
device: str, # device: str,
op: Callable, # op: Callable,
) -> None: # ) -> None:
test_contexted_kv_attention_alibi(num_heads, num_queries_per_kv, head_size, # test_contexted_kv_attention_alibi(num_heads, num_queries_per_kv, head_size,
dtype, kv_cache_dtype, device, op) # dtype, kv_cache_dtype, device, op)
...@@ -21,7 +21,7 @@ NUM_HEADS = [(4, 4), (8, 2)] ...@@ -21,7 +21,7 @@ NUM_HEADS = [(4, 4), (8, 2)]
HEAD_SIZES = [128, 256] HEAD_SIZES = [128, 256]
BLOCK_SIZES = [16] BLOCK_SIZES = [16]
DTYPES = [torch.bfloat16] DTYPES = [torch.bfloat16]
QDTYPES = [None, torch.float8_e4m3fn] QDTYPES = [None, torch.float8_e4m3fn] if current_platform.is_rocm() else [None]
# one value large enough to test overflow in index calculation. # one value large enough to test overflow in index calculation.
# one value small enough to test the schema op check # one value small enough to test the schema op check
NUM_BLOCKS = [32768, 2048] NUM_BLOCKS = [32768, 2048]
...@@ -199,8 +199,8 @@ def test_flash_attn_with_paged_kv( ...@@ -199,8 +199,8 @@ def test_flash_attn_with_paged_kv(
f"{torch.max(torch.abs(output - ref_output))}" f"{torch.max(torch.abs(output - ref_output))}"
@pytest.mark.skipif(current_platform.is_rocm(), # @pytest.mark.skipif(current_platform.is_rocm(),
reason="varlen_with_paged_kv is not supported on ROCm.") # reason="varlen_with_paged_kv is not supported on ROCm.")
@pytest.mark.parametrize("use_out", [True, False]) @pytest.mark.parametrize("use_out", [True, False])
@pytest.mark.parametrize("seq_lens", @pytest.mark.parametrize("seq_lens",
[[(1, 1328), (5, 18), [[(1, 1328), (5, 18),
...@@ -302,10 +302,10 @@ def test_varlen_with_paged_kv( ...@@ -302,10 +302,10 @@ def test_varlen_with_paged_kv(
window_size=window_size, window_size=window_size,
block_table=block_tables, block_table=block_tables,
softcap=soft_cap if soft_cap is not None else 0, softcap=soft_cap if soft_cap is not None else 0,
fa_version=fa_version, # fa_version=fa_version,
q_descale=q_descale, # q_descale=q_descale,
k_descale=k_descale, # k_descale=k_descale,
v_descale=v_descale, # v_descale=v_descale,
) )
output = output if not use_out else out output = output if not use_out else out
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import NamedTuple from typing import NamedTuple
import os
import pytest import pytest
import torch import torch
from packaging.version import Version from packaging.version import Version
...@@ -10,6 +11,7 @@ from transformers import __version__ as TRANSFORMERS_VERSION ...@@ -10,6 +11,7 @@ from transformers import __version__ as TRANSFORMERS_VERSION
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.platforms import current_platform from vllm.platforms import current_platform
from ...utils import models_path_prefix
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
...@@ -48,12 +50,12 @@ class MRoPETestInfo(NamedTuple): ...@@ -48,12 +50,12 @@ class MRoPETestInfo(NamedTuple):
TRANSFORMERS_BASE_VERSION = Version(TRANSFORMERS_VERSION).base_version TRANSFORMERS_BASE_VERSION = Version(TRANSFORMERS_VERSION).base_version
MODELS_TO_TEST = [ MODELS_TO_TEST = [
MRoPETestInfo(model_name="zai-org/GLM-4.1V-9B-Thinking"), MRoPETestInfo(model_name=os.path.join(models_path_prefix, "zai-org/GLM-4.1V-9B-Thinking")),
MRoPETestInfo(model_name="Qwen/Qwen2-VL-7B-Instruct"), MRoPETestInfo(model_name=os.path.join(models_path_prefix, "Qwen/Qwen2-VL-7B-Instruct")),
MRoPETestInfo(model_name="Qwen/Qwen2-VL-72B-Instruct"), MRoPETestInfo(model_name=os.path.join(models_path_prefix, "Qwen/Qwen2-VL-72B-Instruct")),
MRoPETestInfo(model_name="Qwen/Qwen2.5-VL-72B-Instruct"), # MRoPETestInfo(model_name=os.path.join("Qwen/Qwen2.5-VL-72B-Instruct")),
MRoPETestInfo( MRoPETestInfo(
model_name="Qwen/Qwen3-VL-4B-Instruct", model_name=os.path.join(models_path_prefix, "Qwen/Qwen3-VL-4B-Instruct"),
marks=[ marks=[
pytest.mark.skipif( pytest.mark.skipif(
Version(TRANSFORMERS_BASE_VERSION) < Version("4.57.0"), Version(TRANSFORMERS_BASE_VERSION) < Version("4.57.0"),
...@@ -61,7 +63,7 @@ MODELS_TO_TEST = [ ...@@ -61,7 +63,7 @@ MODELS_TO_TEST = [
) )
]), ]),
MRoPETestInfo( MRoPETestInfo(
model_name="Qwen/Qwen3-VL-30B-A3B-Instruct", model_name=os.path.join(models_path_prefix, "Qwen/Qwen3-VL-30B-A3B-Instruct"),
marks=[ marks=[
pytest.mark.skipif( pytest.mark.skipif(
Version(TRANSFORMERS_BASE_VERSION) < Version("4.57.0"), Version(TRANSFORMERS_BASE_VERSION) < Version("4.57.0"),
......
...@@ -172,7 +172,7 @@ def torch_moe_align_block_size( ...@@ -172,7 +172,7 @@ def torch_moe_align_block_size(
@pytest.mark.parametrize("num_experts", NUM_EXPERTS) @pytest.mark.parametrize("num_experts", NUM_EXPERTS)
@pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("pad_sorted_ids", [False, True]) @pytest.mark.parametrize("pad_sorted_ids", [False, True])
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm") # @pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_moe_align_block_size(m: int, topk: int, num_experts: int, def test_moe_align_block_size(m: int, topk: int, num_experts: int,
block_size: int, pad_sorted_ids: bool): block_size: int, pad_sorted_ids: bool):
"""Test moe_align_block_size without expert mapping""" """Test moe_align_block_size without expert mapping"""
...@@ -235,7 +235,7 @@ def test_moe_align_block_size(m: int, topk: int, num_experts: int, ...@@ -235,7 +235,7 @@ def test_moe_align_block_size(m: int, topk: int, num_experts: int,
@pytest.mark.parametrize("topk", [2, 4]) @pytest.mark.parametrize("topk", [2, 4])
@pytest.mark.parametrize("num_experts", [8]) @pytest.mark.parametrize("num_experts", [8])
@pytest.mark.parametrize("block_size", [64]) @pytest.mark.parametrize("block_size", [64])
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm") # @pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_moe_align_block_size_with_expert_map(m: int, topk: int, def test_moe_align_block_size_with_expert_map(m: int, topk: int,
num_experts: int, num_experts: int,
block_size: int): block_size: int):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment