Commit 153002ad authored by wanghl6's avatar wanghl6
Browse files

[Perf]融合算子优化

parent aef3c487
......@@ -325,7 +325,8 @@ if TYPE_CHECKING:
USE_LIGHTOP_CONVERT_REQ_INDEX_TO_GLOBAL_INDEX: bool = False
VLLM_DISABLE_DSA: bool = False
VLLM_LIGHTLY_CP_THRESHOULD: int = 2048
USE_LIGHTOP_CP_CONVERT_FP8_KV_CACHE : bool = False
USE_LIGHTOP_FUSE_LN_ROPE_QUANT : bool = False
def get_default_cache_root():
return os.getenv(
......@@ -2012,6 +2013,13 @@ environment_variables: dict[str, Callable[[], Any]] = {
# MLA_CP open threshold
"VLLM_LIGHTLY_CP_THRESHOULD":
lambda: int(os.getenv("VLLM_LIGHTLY_CP_THRESHOULD", "2048")),
"USE_LIGHTOP_CP_CONVERT_FP8_KV_CACHE":
lambda: (os.environ.get("USE_LIGHTOP_CP_CONVERT_FP8_KV_CACHE", "False").lower() in
("true", "1")),
"USE_LIGHTOP_FUSE_LN_ROPE_QUANT":
lambda: (os.environ.get("USE_LIGHTOP_FUSE_LN_ROPE_QUANT", "False").lower() in
("true", "1")),
}
# --8<-- [end:env-vars-definition]
......
......@@ -1067,7 +1067,51 @@ def per_token_group_quant_fp8(
return x_q, x_s
def _lightop_fuse_norm_rope_quant_fp8_impl(
    positions: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    head_dim: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
    is_rmsnorm: bool,
    weight_k: torch.Tensor | None,
    bias_k: torch.Tensor | None,
    eps: float
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Forward all arguments to lightop's ``fuse_norm_rope_quant_fp8`` kernel.

    This is a thin pass-through: every argument is handed to the kernel
    positionally, in declaration order, and the kernel's result is returned
    unchanged.

    NOTE(review): judging by the name and by the sibling fake implementation,
    the result is presumably ``(k_out, q_fp8, q_scale)`` -- confirm against
    the lightop documentation.
    """
    # `lightop` is imported lazily so that it is only required when this op
    # is actually executed.
    from lightop import op as lightop_ops

    fused_kernel = lightop_ops.fuse_norm_rope_quant_fp8
    return fused_kernel(
        positions,
        q,
        k,
        head_dim,
        cos_sin_cache,
        is_neox,
        is_rmsnorm,
        weight_k,
        bias_k,
        eps,
    )
def _lightop_fuse_norm_rope_quant_fp8_fake(
    positions: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    head_dim: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
    is_rmsnorm: bool,
    weight_k: torch.Tensor | None,
    bias_k: torch.Tensor | None,
    eps: float
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Shape-only stand-in for the fused norm/RoPE/FP8-quant op.

    No kernel is run; only uninitialized output tensors are allocated:

    * ``k_out``: same shape, dtype and device as ``k``.
    * ``q_fp8_out``: same shape and device as ``q``, in the FP8 dtype.
    * ``q_scale_out``: float32 of shape ``(q.shape[0], q.shape[1], 1)`` on
      ``q``'s device.

    Returns:
        The tuple ``(k_out, q_fp8_out, q_scale_out)``.
    """
    # Use the platform-specific FP8 dtype when the platform object exposes
    # one; otherwise fall back to e4m3fn.
    if hasattr(current_platform, "fp8_dtype"):
        fp8_dtype = current_platform.fp8_dtype()
    else:
        fp8_dtype = torch.float8_e4m3fn

    # One scale slot per (dim0, dim1) entry of `q`.
    scale_shape = (q.shape[0], q.shape[1], 1)

    k_out = torch.empty_like(k)
    q_fp8_out = torch.empty_like(q, dtype=fp8_dtype)
    q_scale_out = torch.empty(scale_shape, dtype=torch.float32, device=q.device)
    return k_out, q_fp8_out, q_scale_out
# Register the fused norm/RoPE/FP8-quant kernel as a custom op together with
# its fake (allocation-only) implementation. No arguments are declared as
# mutated.
direct_register_custom_op(
    op_name="lightop_fuse_norm_rope_quant_fp8",
    mutates_args=[],
    op_func=_lightop_fuse_norm_rope_quant_fp8_impl,
    fake_impl=_lightop_fuse_norm_rope_quant_fp8_fake,
)
def per_token_group_quant_fp8_packed_for_deepgemm(
x: torch.Tensor,
group_size: int,
......
......@@ -860,53 +860,68 @@ class Indexer(nn.Module):
) -> torch.Tensor:
q, _ = self.wq_b(qr)
q = q.view(-1, self.n_head, self.head_dim)
q_pe, q_nope = torch.split(
q, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1
)
if envs.USE_FUSED_RMS_QUANT and self.wk.weight.dtype == torch.int8 and iqis is not None:
k, _ = self.wk(hidden_states, iqis=iqis)
else:
k, _ = self.wk(hidden_states)
k = self.k_norm(k)
k_pe, k_nope = torch.split(
k, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1
)
q_pe, k_pe = rotary_emb(positions, q_pe, k_pe.unsqueeze(1))
# Note: RoPE (NeoX) can introduce extra leading dimensions during compilation
# so we need to reshape back to token-flattened shapes
q_pe = q_pe.reshape(-1, self.n_head, self.rope_dim)
k_pe = k_pe.reshape(-1, 1, self.rope_dim)
# `rotary_emb` is shape-preserving; `q_pe` is already
# [num_tokens, n_head, rope_dim].
q = torch.cat([q_pe, q_nope], dim=-1)
# `k_pe` is [num_tokens, 1, rope_dim] (MQA).
k = torch.cat([k_pe.squeeze(-2), k_nope], dim=-1)
enable_lightly_cp = get_forward_context().enable_lightly_cp
if enable_lightly_cp:
k = tensor_model_parallel_all_gather(
k.contiguous(), 0
)
gather_indexes_tensor = get_forward_context().gather_indexes_tensor
enable_lightly_cplb = get_forward_context().enable_lightly_cplb
if enable_lightly_cplb and gather_indexes_tensor is not None:
k = torch.index_select(k, 0, gather_indexes_tensor)
# we only quant q here since k quant is fused with cache insertion
if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
q = q.view(-1, self.head_dim)
q_fp8, q_scale = per_token_group_quant_fp8(
if envs.USE_LIGHTOP_FUSE_LN_ROPE_QUANT:
is_rmsnorm = not hasattr(self.k_norm, 'bias') or self.k_norm.bias is None
weight_k = getattr(self.k_norm, 'weight', None)
bias_k = getattr(self.k_norm, 'bias', None)
eps = getattr(self.k_norm, 'eps', 1e-5)
cos_sin_cache = getattr(rotary_emb, 'cos_sin_cache', None)
is_neox = getattr(rotary_emb, 'is_neox', True)
k, q_fp8, q_scale = torch.ops.vllm.lightop_fuse_norm_rope_quant_fp8(
positions,
q,
self.quant_block_size,
column_major_scales=False,
use_ue8m0=self.scale_fmt is not None,
k,
self.head_dim,
cos_sin_cache,
is_neox,
is_rmsnorm,
weight_k,
bias_k,
eps
)
q_fp8 = q_fp8.view(-1, self.n_head, self.head_dim)
q_scale = q_scale.view(-1, self.n_head, 1)
if current_platform.is_rocm() and torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] != "gfx938":
q_fp8 = q
q_scale = None
else:
q_fp8 = q
q_pe, q_nope = torch.split(
q, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1
)
k = self.k_norm(k)
k_pe, k_nope = torch.split(
k, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1
)
q_pe, k_pe = rotary_emb(positions, q_pe, k_pe.unsqueeze(1))
# Note: RoPE (NeoX) can introduce extra leading dimensions during compilation
# so we need to reshape back to token-flattened shapes
q_pe = q_pe.reshape(-1, self.n_head, self.rope_dim)
k_pe = k_pe.reshape(-1, 1, self.rope_dim)
# `rotary_emb` is shape-preserving; `q_pe` is already
# [num_tokens, n_head, rope_dim].
q = torch.cat([q_pe, q_nope], dim=-1)
# `k_pe` is [num_tokens, 1, rope_dim] (MQA).
k = torch.cat([k_pe.squeeze(-2), k_nope], dim=-1)
# we only quant q here since k quant is fused with cache insertion
if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
q = q.view(-1, self.head_dim)
q_fp8, q_scale = per_token_group_quant_fp8(
q,
self.quant_block_size,
column_major_scales=False,
use_ue8m0=self.scale_fmt is not None,
)
q_fp8 = q_fp8.view(-1, self.n_head, self.head_dim)
q_scale = q_scale.view(-1, self.n_head, 1)
else:
q_fp8 = q
q_scale = None
if envs.USE_FUSED_RMS_QUANT and self.weights_proj.weight.dtype == torch.int8 and iqis is not None:
weights, _ = self.weights_proj(hidden_states, iqis=iqis)
......
......@@ -868,14 +868,25 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
assert fp8_metadata.prefill is not None
for chunk in fp8_metadata.prefill.chunks:
chunk_workspace = self.prefill_bf16_workspace[: chunk.chunk_tot_seqlen]
ops.cp_gather_and_upconvert_fp8_kv_cache(
kv_c_and_k_pe_cache,
chunk_workspace,
chunk.block_table,
chunk.seq_lens,
chunk.workspace_starts,
len(chunk.block_table),
)
if not envs.USE_LIGHTOP_CP_CONVERT_FP8_KV_CACHE:
ops.cp_gather_and_upconvert_fp8_kv_cache(
kv_c_and_k_pe_cache,
chunk_workspace,
chunk.block_table,
chunk.seq_lens,
chunk.workspace_starts,
len(chunk.block_table),
)
else:
from lightop import op
op.cp_gather_and_upconvert_fp8_kv_cache(
kv_c_and_k_pe_cache,
chunk_workspace,
chunk.block_table,
chunk.seq_lens,
chunk.workspace_starts,
len(chunk.block_table),
)
chunk_q = q[chunk.tokens_slice]
chunk_topk_indices_workspace = topk_indices[chunk.tokens_slice]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment