Commit 3842b316 authored by laibao's avatar laibao Committed by zhangzbb
Browse files

[FEATURE] 接入 LightOP 的 silu_and_mul 自定义算子并统一 OPT 路径

parent d4bca618
......@@ -383,6 +383,24 @@ def fused_add_rms_norm_opt_fake(
) -> None:
return None
def silu_and_mul_opt_lightop(input: torch.Tensor) -> torch.Tensor:
    """Dispatch to the registered ``vllm::silu_and_mul_opt_lightop`` custom op.

    Args:
        input: Tensor forwarded unchanged to the custom op.

    Returns:
        The tensor returned by ``torch.ops.vllm.silu_and_mul_opt_lightop``.
    """
    # Go through the torch.ops namespace rather than calling the Python
    # implementation directly, so the call resolves to the registered op.
    registered_op = torch.ops.vllm.silu_and_mul_opt_lightop
    return registered_op(input)
def silu_and_mul_opt_lightop_impl(input: torch.Tensor) -> torch.Tensor:
    """Run ``op.silu_and_mul_opt`` into a freshly allocated output tensor.

    The output keeps every leading dimension of ``input`` and has a last
    dimension of ``input.shape[-1] // 2``; dtype and device are copied from
    ``input``.

    Args:
        input: Source tensor handed to the kernel as its second argument.

    Returns:
        The newly allocated tensor that the kernel was given as its
        destination argument.
    """
    half_width = input.shape[-1] // 2
    result = torch.empty(
        (*input.shape[:-1], half_width),
        dtype=input.dtype,
        device=input.device,
    )
    # The kernel takes (destination, source); presumably it fills ``result``
    # in place — confirm against the LightOP kernel documentation.
    op.silu_and_mul_opt(result, input)
    return result
def silu_and_mul_opt_lightop_fake(input: torch.Tensor) -> torch.Tensor:
    """Fake implementation: produce an uninitialized tensor of the output shape.

    No kernel is invoked. The result has the leading dimensions of ``input``
    and a last dimension of ``input.shape[-1] // 2``, and is created with
    ``input.new_empty`` so it follows ``input``'s dtype and device.

    Args:
        input: Tensor whose shape, dtype and device determine the result.

    Returns:
        An uninitialized tensor with the halved last dimension.
    """
    half_width = input.shape[-1] // 2
    return input.new_empty((*input.shape[:-1], half_width))
def fused_qk_norm_rope(
qkv: torch.Tensor,
num_heads_q: int,
......@@ -3631,6 +3649,13 @@ direct_register_custom_op(
fake_impl=fused_add_rms_norm_opt_fake,
)
# Register the LightOP silu_and_mul op under the name that the
# `silu_and_mul_opt_lightop` wrapper looks up via `torch.ops.vllm`.
# `op_func` is the real implementation and `fake_impl` only builds a tensor
# of the output shape. `mutates_args` is empty: the implementation allocates
# and returns its own output tensor (assumes the kernel does not write to
# `input` — TODO confirm).
direct_register_custom_op(
op_name="silu_and_mul_opt_lightop",
op_func=silu_and_mul_opt_lightop_impl,
mutates_args=[],
fake_impl=silu_and_mul_opt_lightop_fake,
)
"""
qwen3-vl-8b中LLM的修改 rms+mrope dim==1 2026/03/18
"""
......
......@@ -150,12 +150,14 @@ class SiluAndMul(CustomOp):
return F.silu(x[..., :d]) * x[..., d:]
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
    """Compute the activation for ``x`` on the CUDA path.

    When ``envs.VLLM_USE_OPT_OP`` is truthy the call is routed to
    ``vllm._custom_ops.silu_and_mul_opt_lightop``, which returns the output
    tensor itself. Otherwise an output tensor is allocated here and filled
    by ``self.op(out, x)``.

    Args:
        x: Input tensor; the output's last dimension is
            ``x.shape[-1] // 2``.

    Returns:
        The output tensor, with the same dtype and device as ``x`` on the
        default path.
    """
    if envs.VLLM_USE_OPT_OP:
        # Imported at call time, only when the optimized path is enabled.
        from vllm import _custom_ops as ops

        return ops.silu_and_mul_opt_lightop(x)

    # Default path. A second VLLM_USE_OPT_OP check here (dispatching to
    # ``self.op_opt``) could never be taken because the flag is fully
    # handled by the early return above, so only ``self.op`` remains.
    d = x.shape[-1] // 2
    output_shape = x.shape[:-1] + (d,)
    out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
    self.op(out, x)
    return out
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment