Commit 3912d41c authored by zhuwenwen's avatar zhuwenwen
Browse files

replace triton_ of rms and act_and_mul

parent 613edd7d
...@@ -77,6 +77,8 @@ class SiluAndMul(CustomOp): ...@@ -77,6 +77,8 @@ class SiluAndMul(CustomOp):
"""PyTorch-native implementation equivalent to forward().""" """PyTorch-native implementation equivalent to forward()."""
if not torch.compiler.is_compiling() and envs.VLLM_ENABLE_TBO: if not torch.compiler.is_compiling() and envs.VLLM_ENABLE_TBO:
return self.forward_cuda(x) return self.forward_cuda(x)
elif not torch.compiler.is_compiling() and envs.VLLM_USE_OPT_OP:
return self.forward_cuda(x)
else: else:
d = x.shape[-1] // 2 d = x.shape[-1] // 2
return F.silu(x[..., :d]) * x[..., d:] return F.silu(x[..., :d]) * x[..., d:]
......
...@@ -168,6 +168,8 @@ class RMSNorm(CustomOp): ...@@ -168,6 +168,8 @@ class RMSNorm(CustomOp):
if not torch.compiler.is_compiling() and envs.VLLM_ENABLE_TBO: if not torch.compiler.is_compiling() and envs.VLLM_ENABLE_TBO:
return self.forward_cuda(x, residual) return self.forward_cuda(x, residual)
elif not torch.compiler.is_compiling() and envs.VLLM_USE_OPT_OP:
return self.forward_cuda(x, residual)
else: else:
orig_dtype = x.dtype orig_dtype = x.dtype
x = x.to(torch.float32) x = x.to(torch.float32)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment