Commit a5b976df authored by zhuwenwen
Browse files

解决PA部分size计算错误的问题

优化bf16精度
解决bf16精度问题,解决cudagraph精度问题
调整pa tc和非tc调用关系
parent 10ce38cc
...@@ -14,7 +14,9 @@ if HAS_TRITON: ...@@ -14,7 +14,9 @@ if HAS_TRITON:
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`. # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE = 512 _PARTITION_SIZE = 512
# Name of the currently selected CUDA device, queried once when this module
# is imported.
gpuname = torch.cuda.get_device_properties(
    torch.cuda.current_device()
).name
# Only devices whose name begins with one of these prefixes are allowed to
# use the `*_opt_tc` paged-attention ops.
support_tc = gpuname.startswith(('K100_AI', 'BW'))
# The `*_opt_tc` path needs both environment switches AND a supported device.
# The `and` chain is kept un-coerced, so this holds the first falsy operand
# (or `support_tc` when every operand is truthy).
use_tc = (envs.VLLM_USE_OPT_OP
          and envs.VLLM_USE_TC_PAGED_ATTN
          and support_tc)
@dataclass @dataclass
class PagedAttentionMetadata: class PagedAttentionMetadata:
...@@ -128,12 +130,62 @@ class PagedAttention: ...@@ -128,12 +130,62 @@ class PagedAttention:
# to parallelize. # to parallelize.
# TODO(woosuk): Tune this heuristic. # TODO(woosuk): Tune this heuristic.
# For context len > 8192, use V2 kernel to avoid shared memory shortage. # For context len > 8192, use V2 kernel to avoid shared memory shortage.
if envs.VLLM_USE_TC_PAGED_ATTN:
use_v1 = (max_seq_len < 8192
and (max_seq_len<(1024 if num_kv_heads == num_heads else 600) or num_seqs * num_heads > (1024 if num_kv_heads < num_heads else 512))) if use_tc and head_size==128:
else: if envs.VLLM_USE_PA_PRINT_PARAM:
use_v1 = (max_seq_len <= 8192 print("PA V1 SIZE:")
and (max_num_partitions == 1 or num_seqs * num_heads > 512)) print(f"query.shape = {query.shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}")
print(f"num_kv_heads = {num_kv_heads}, scale = {scale:.3f}, block_tables.shape = {block_tables.shape}, seq_lens.shape = {seq_lens.shape}, block_size = {block_size}, max_seq_len = {max_seq_len}")
if attn_masks is None:
ops.paged_attention_v1_opt_tc(
output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step
)
else:
ops.paged_attention_v1_opt_tc_with_mask(
output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
attn_masks,
attn_masks_stride
)
return output
use_v1 = (max_seq_len <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512))
if use_v1: if use_v1:
# Run PagedAttention V1. # Run PagedAttention V1.
...@@ -143,100 +195,52 @@ class PagedAttention: ...@@ -143,100 +195,52 @@ class PagedAttention:
print(f"num_kv_heads = {num_kv_heads}, scale = {scale:.3f}, block_tables.shape = {block_tables.shape}, seq_lens.shape = {seq_lens.shape}, block_size = {block_size}, max_seq_len = {max_seq_len}") print(f"num_kv_heads = {num_kv_heads}, scale = {scale:.3f}, block_tables.shape = {block_tables.shape}, seq_lens.shape = {seq_lens.shape}, block_size = {block_size}, max_seq_len = {max_seq_len}")
if envs.VLLM_USE_OPT_OP: if envs.VLLM_USE_OPT_OP:
if envs.VLLM_USE_TC_PAGED_ATTN: if attn_masks is None:
if attn_masks is None: ops.paged_attention_v1_opt(
ops.paged_attention_v1_opt_tc( output,
output, query,
query, key_cache,
key_cache, value_cache,
value_cache, num_kv_heads,
num_kv_heads, scale,
scale, block_tables,
block_tables, seq_lens,
seq_lens, block_size,
block_size, max_seq_len,
max_seq_len, alibi_slopes,
alibi_slopes, kv_cache_dtype,
kv_cache_dtype, k_scale,
k_scale, v_scale,
v_scale, tp_rank,
tp_rank, blocksparse_local_blocks,
blocksparse_local_blocks, blocksparse_vert_stride,
blocksparse_vert_stride, blocksparse_block_size,
blocksparse_block_size, blocksparse_head_sliding_step
blocksparse_head_sliding_step )
)
else:
ops.paged_attention_v1_opt_tc_with_mask(
output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
attn_masks,
attn_masks_stride
)
else: else:
if attn_masks is None: ops.paged_attention_v1_opt_with_mask(
ops.paged_attention_v1_opt( output,
output, query,
query, key_cache,
key_cache, value_cache,
value_cache, num_kv_heads,
num_kv_heads, scale,
scale, block_tables,
block_tables, seq_lens,
seq_lens, block_size,
block_size, max_seq_len,
max_seq_len, alibi_slopes,
alibi_slopes, kv_cache_dtype,
kv_cache_dtype, k_scale,
k_scale, v_scale,
v_scale, tp_rank,
tp_rank, blocksparse_local_blocks,
blocksparse_local_blocks, blocksparse_vert_stride,
blocksparse_vert_stride, blocksparse_block_size,
blocksparse_block_size, blocksparse_head_sliding_step,
blocksparse_head_sliding_step attn_masks,
) attn_masks_stride
else: )
ops.paged_attention_v1_opt_with_mask(
output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
attn_masks,
attn_masks_stride
)
else: else:
if attn_masks is None: if attn_masks is None:
ops.paged_attention_v1( ops.paged_attention_v1(
...@@ -306,112 +310,58 @@ class PagedAttention: ...@@ -306,112 +310,58 @@ class PagedAttention:
print(f"num_kv_heads = {num_kv_heads}, scale = {scale:.3f}, block_tables.shape = {block_tables.shape}, seq_lens.shape = {seq_lens.shape}, block_size = {block_size}, max_seq_len = {max_seq_len}") print(f"num_kv_heads = {num_kv_heads}, scale = {scale:.3f}, block_tables.shape = {block_tables.shape}, seq_lens.shape = {seq_lens.shape}, block_size = {block_size}, max_seq_len = {max_seq_len}")
if envs.VLLM_USE_OPT_OP: if envs.VLLM_USE_OPT_OP:
if envs.VLLM_USE_TC_PAGED_ATTN: if attn_masks is None:
if attn_masks is None: ops.paged_attention_v2_opt(
ops.paged_attention_v2_opt_tc( output,
output, exp_sums,
exp_sums, max_logits,
max_logits, tmp_output,
tmp_output, query,
query, key_cache,
key_cache, value_cache,
value_cache, num_kv_heads,
num_kv_heads, scale,
scale, block_tables,
block_tables, seq_lens,
seq_lens, block_size,
block_size, max_seq_len,
max_seq_len, alibi_slopes,
alibi_slopes, kv_cache_dtype,
kv_cache_dtype, k_scale,
k_scale, v_scale,
v_scale, tp_rank,
tp_rank, blocksparse_local_blocks,
blocksparse_local_blocks, blocksparse_vert_stride,
blocksparse_vert_stride, blocksparse_block_size,
blocksparse_block_size, blocksparse_head_sliding_step
blocksparse_head_sliding_step )
)
else:
ops.paged_attention_v2_opt_tc_with_mask(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
attn_masks,
attn_masks_stride
)
else: else:
if attn_masks is None: ops.paged_attention_v2_opt_with_mask(
ops.paged_attention_v2_opt( output,
output, exp_sums,
exp_sums, max_logits,
max_logits, tmp_output,
tmp_output, query,
query, key_cache,
key_cache, value_cache,
value_cache, num_kv_heads,
num_kv_heads, scale,
scale, block_tables,
block_tables, seq_lens,
seq_lens, block_size,
block_size, max_seq_len,
max_seq_len, alibi_slopes,
alibi_slopes, kv_cache_dtype,
kv_cache_dtype, k_scale,
k_scale, v_scale,
v_scale, tp_rank,
tp_rank, blocksparse_local_blocks,
blocksparse_local_blocks, blocksparse_vert_stride,
blocksparse_vert_stride, blocksparse_block_size,
blocksparse_block_size, blocksparse_head_sliding_step,
blocksparse_head_sliding_step attn_masks,
) attn_masks_stride
else: )
ops.paged_attention_v2_opt_with_mask(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
attn_masks,
attn_masks_stride
)
else: else:
if attn_masks is None: if attn_masks is None:
ops.paged_attention_v2( ops.paged_attention_v2(
...@@ -528,4 +478,4 @@ class PagedAttention: ...@@ -528,4 +478,4 @@ class PagedAttention:
) -> None: ) -> None:
key_caches = [kv_cache[0] for kv_cache in kv_caches] key_caches = [kv_cache[0] for kv_cache in kv_caches]
value_caches = [kv_cache[1] for kv_cache in kv_caches] value_caches = [kv_cache[1] for kv_cache in kv_caches]
ops.copy_blocks(key_caches, value_caches, src_to_dists) ops.copy_blocks(key_caches, value_caches, src_to_dists)
\ No newline at end of file
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment