Commit 714c12da authored by wangmin6
Browse files

Merge branch 'wanghl_glm5_kernel_opt' into 'v0.15.1-dev'

glm5 融合算子优化

See merge request dcutoolkit/deeplearing/vllm!534
parents 0bd5fcd2 71276043
......@@ -320,8 +320,9 @@ if TYPE_CHECKING:
VLLM_USE_LIGHTOP_FUSED_TOPP_TOPK: bool = False
VLLM_V1_USE_FA_UNIFIED_ATTN_2D: bool = False
VLLM_ENABLE_RAY_ASYNC_SCHEDULING: bool = False
USE_LIGHTOP_PER_TOKEN_GROUP_QUANT_FP8: bool = False
USE_LIGHTOP_TOPK: bool = False
USE_LIGHTOP_CONVERT_REQ_INDEX_TO_GLOBAL_INDEX: bool = False
def get_default_cache_root():
return os.getenv(
"XDG_CACHE_HOME",
......@@ -1990,6 +1991,16 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_V1_USE_FA_UNIFIED_ATTN_2D":
lambda: (os.environ.get("VLLM_V1_USE_FA_UNIFIED_ATTN_2D", "False").lower() in
("true", "1")),
"USE_LIGHTOP_PER_TOKEN_GROUP_QUANT_FP8":
lambda: (os.environ.get("USE_LIGHTOP_PER_TOKEN_GROUP_QUANT_FP8", "False").lower() in
("true", "1")),
"USE_LIGHTOP_TOPK":
lambda: (os.environ.get("USE_LIGHTOP_TOPK", "False").lower() in
("true", "1")),
"USE_LIGHTOP_CONVERT_REQ_INDEX_TO_GLOBAL_INDEX":
lambda: (os.environ.get("USE_LIGHTOP_CONVERT_REQ_INDEX_TO_GLOBAL_INDEX", "False").lower() in
("true", "1")),
}
# --8<-- [end:env-vars-definition]
......
......@@ -915,6 +915,37 @@ def _per_token_group_quant_fp8_colmajor(
tl.store(y_s_ptr, y_s)
def _lightop_per_token_group_quant_fp8_impl(
    x_q: torch.Tensor,
    x: torch.Tensor,
    x_s: torch.Tensor,
    group_size: int,
    eps: float,
    use_ue8m0: bool,
) -> None:
    """Forward a per-token-group FP8 quantization request to lightop.

    All six arguments are handed, unchanged and in the same order, to
    ``lightop.op.per_token_group_quant_fp8``. Nothing is returned; this
    wrapper is registered as a custom op with ``x_q`` and ``x_s`` declared
    as mutated arguments, so the kernel is expected to write its results
    into those two tensors in place.

    Args:
        x_q: Output tensor passed to the kernel (declared as mutated).
        x: Input tensor passed to the kernel.
        x_s: Scale tensor passed to the kernel (declared as mutated).
        group_size: Group size forwarded to the kernel.
        eps: Epsilon value forwarded to the kernel.
        use_ue8m0: Flag forwarded to the kernel.
    """
    # Function-scope import: lightop is only looked up when the op is
    # actually called (presumably so that it stays an optional dependency
    # -- confirm against the packaging).
    from lightop import op as lightop_ops

    quantize_kernel = lightop_ops.per_token_group_quant_fp8
    quantize_kernel(x_q, x, x_s, group_size, eps, use_ue8m0)
def _lightop_per_token_group_quant_fp8_fake(
    x_q: torch.Tensor,
    x: torch.Tensor,
    x_s: torch.Tensor,
    group_size: int,
    eps: float,
    use_ue8m0: bool,
) -> None:
    """No-op stand-in for ``_lightop_per_token_group_quant_fp8_impl``.

    Takes the same arguments as the real implementation, touches none of
    them, and returns ``None``. It is supplied as the ``fake_impl`` when
    the custom op is registered (presumably for tracing / shape
    propagation without running the kernel -- confirm against the
    registration helper).
    """
    return None
# Register the lightop-backed quantization wrapper as a custom op named
# "lightop_per_token_group_quant_fp8"; per_token_group_quant_fp8 invokes it
# as torch.ops.vllm.lightop_per_token_group_quant_fp8.
# x_q and x_s are declared as mutated arguments (the wrapper returns None),
# and the no-op function is supplied as the fake implementation.
direct_register_custom_op(
"lightop_per_token_group_quant_fp8",
_lightop_per_token_group_quant_fp8_impl,
mutates_args=["x_q", "x_s"],
fake_impl=_lightop_per_token_group_quant_fp8_fake,
)
def per_token_group_quant_fp8(
x: torch.Tensor,
group_size: int,
......@@ -980,7 +1011,11 @@ def per_token_group_quant_fp8(
else:
shape = x.shape[:-1] + (x.shape[-1] // group_size,)
x_s = torch.empty(shape, device=x.device, dtype=torch.float32)
if envs.USE_LIGHTOP_PER_TOKEN_GROUP_QUANT_FP8 and not column_major_scales:
torch.ops.vllm.lightop_per_token_group_quant_fp8(x_q, x, x_s, group_size, eps, use_ue8m0)
return x_q, x_s
# prefer CUDA kernel if available
# TODO(bnell): this causes some fp8 moe test to fail.
if current_platform.is_cuda() and x.is_contiguous():
......@@ -1743,4 +1778,4 @@ def process_fp8_input_tensor_strategy_moe(
"for each layer."
)
return w13_input_scale.max(), w2_input_scale.max()
return w13_input_scale.max(), w2_input_scale.max()
\ No newline at end of file
......@@ -3,7 +3,7 @@
"""Custom Sparse Attention Indexer layers."""
import torch
import vllm.envs as envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
......@@ -170,16 +170,28 @@ def sparse_attn_indexer(
topk_indices = topk_indices_buffer[
chunk.token_start : chunk.token_end, :topk_tokens
]
torch.ops._C.top_k_per_row_prefill(
logits,
chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke,
topk_indices,
num_rows,
logits.stride(0),
logits.stride(1),
topk_tokens,
)
if not envs.USE_LIGHTOP_TOPK:
torch.ops._C.top_k_per_row_prefill(
logits,
chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke,
topk_indices,
num_rows,
logits.stride(0),
logits.stride(1),
topk_tokens,
)
else:
op.top_k_per_row_prefill(
logits,
chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke,
topk_indices,
num_rows,
logits.stride(0),
logits.stride(1),
topk_tokens,
)
if has_decode:
decode_metadata = attn_metadata.decode
......@@ -230,17 +242,30 @@ def sparse_attn_indexer(
num_rows = logits.shape[0]
topk_indices = topk_indices_buffer[:num_padded_tokens, :topk_tokens]
torch.ops._C.top_k_per_row_decode(
logits,
next_n,
decode_metadata.seq_lens,
topk_indices,
num_rows,
logits.stride(0),
logits.stride(1),
topk_tokens,
)
# if torch.distributed.get_rank() == 0:
# print(f"====[DEBUG] logits shape: {logits.shape}, next_n: {next_n}, topk_tokens size: {topk_tokens}")
if not envs.USE_LIGHTOP_TOPK:
torch.ops._C.top_k_per_row_decode(
logits,
next_n,
decode_metadata.seq_lens,
topk_indices,
num_rows,
logits.stride(0),
logits.stride(1),
topk_tokens,
)
else:
op.top_k_per_row_decode(
logits,
next_n,
decode_metadata.seq_lens,
topk_indices,
num_rows,
logits.stride(0),
logits.stride(1),
topk_tokens,
)
if decode_metadata.requires_padding:
# if padded, we need to unpack
# the topk indices removing padded tokens
......
......@@ -302,6 +302,18 @@ def triton_convert_req_index_to_global_index(
prefill_workspace_starts: int32 [num_prefills], 0-indexed workspace
starts for each prefill request
"""
if (envs.USE_LIGHTOP_CONVERT_REQ_INDEX_TO_GLOBAL_INDEX):
from lightop import op
return op.convert_req_index_to_global_index(
req_id,
block_table,
token_indices,
BLOCK_SIZE,
NUM_TOPK_TOKENS,
HAS_PREFILL_WORKSPACE,
prefill_workspace_request_ids,
prefill_workspace_starts
)
assert req_id.dtype == torch.int32
assert block_table.dtype == torch.int32
assert token_indices.dtype == torch.int32
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment