Commit cb68935c authored by wanghl6's avatar wanghl6
Browse files

topk opt

parent 0bd5fcd2
......@@ -320,8 +320,9 @@ if TYPE_CHECKING:
VLLM_USE_LIGHTOP_FUSED_TOPP_TOPK: bool = False
VLLM_V1_USE_FA_UNIFIED_ATTN_2D: bool = False
VLLM_ENABLE_RAY_ASYNC_SCHEDULING: bool = False
USE_LIGHTOP_PER_TOKEN_GROUP_QUANT_FP8: bool = False
USE_LIGHTOP_TOPK: bool = False
USE_LIGHTOP_CONVERT_REQ_INDEX_TO_GLOBAL_INDEX: bool = False
def get_default_cache_root():
return os.getenv(
"XDG_CACHE_HOME",
......@@ -1990,6 +1991,16 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_V1_USE_FA_UNIFIED_ATTN_2D":
lambda: (os.environ.get("VLLM_V1_USE_FA_UNIFIED_ATTN_2D", "False").lower() in
("true", "1")),
"USE_LIGHTOP_PER_TOKEN_GROUP_QUANT_FP8":
lambda: (os.environ.get("USE_LIGHTOP_PER_TOKEN_GROUP_QUANT_FP8", "False").lower() in
("true", "1")),
"USE_LIGHTOP_TOPK":
lambda: (os.environ.get("USE_LIGHTOP_TOPK", "False").lower() in
("true", "1")),
"USE_LIGHTOP_CONVERT_REQ_INDEX_TO_GLOBAL_INDEX":
lambda: (os.environ.get("USE_LIGHTOP_CONVERT_REQ_INDEX_TO_GLOBAL_INDEX", "False").lower() in
("true", "1")),
}
# --8<-- [end:env-vars-definition]
......
......@@ -3,7 +3,7 @@
"""Custom Sparse Attention Indexer layers."""
import torch
import vllm.envs as envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
......@@ -170,6 +170,7 @@ def sparse_attn_indexer(
topk_indices = topk_indices_buffer[
chunk.token_start : chunk.token_end, :topk_tokens
]
if not envs.USE_LIGHTOP_TOPK:
torch.ops._C.top_k_per_row_prefill(
logits,
chunk.cu_seqlen_ks,
......@@ -180,6 +181,17 @@ def sparse_attn_indexer(
logits.stride(1),
topk_tokens,
)
else:
op.top_k_per_row_prefill(
logits,
chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke,
topk_indices,
num_rows,
logits.stride(0),
logits.stride(1),
topk_tokens,
)
if has_decode:
decode_metadata = attn_metadata.decode
......@@ -230,6 +242,9 @@ def sparse_attn_indexer(
num_rows = logits.shape[0]
topk_indices = topk_indices_buffer[:num_padded_tokens, :topk_tokens]
# if torch.distributed.get_rank() == 0:
# print(f"====[DEBUG] logits shape: {logits.shape}, next_n: {next_n}, topk_tokens size: {topk_tokens}")
if not envs.USE_LIGHTOP_TOPK:
torch.ops._C.top_k_per_row_decode(
logits,
next_n,
......@@ -240,7 +255,17 @@ def sparse_attn_indexer(
logits.stride(1),
topk_tokens,
)
else:
op.top_k_per_row_decode(
logits,
next_n,
decode_metadata.seq_lens,
topk_indices,
num_rows,
logits.stride(0),
logits.stride(1),
topk_tokens,
)
if decode_metadata.requires_padding:
# if padded, we need to unpack
# the topk indices removing padded tokens
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment