Commit 0eaf8026 authored by zhuwenwen
Browse files

replace triton_ of rms and act_and_mul

parent a6088d09
...@@ -75,7 +75,9 @@ class SiluAndMul(CustomOp): ...@@ -75,7 +75,9 @@ class SiluAndMul(CustomOp):
def forward_native(self, x: torch.Tensor) -> torch.Tensor: def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward().""" """PyTorch-native implementation equivalent to forward()."""
if not torch.compiler.is_compiling() and envs.VLLM_ENABLE_TBO: if not torch.compiler.is_compiling() and envs.envs.VLLM_ENABLE_TBO:
return self.forward_cuda(x)
elif envs.VLLM_USE_OPT_OP:
return self.forward_cuda(x) return self.forward_cuda(x)
else: else:
d = x.shape[-1] // 2 d = x.shape[-1] // 2
......
...@@ -165,8 +165,9 @@ class RMSNorm(CustomOp): ...@@ -165,8 +165,9 @@ class RMSNorm(CustomOp):
x: torch.Tensor, x: torch.Tensor,
residual: Optional[torch.Tensor] = None, residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
if not torch.compiler.is_compiling() and envs.VLLM_ENABLE_TBO:
if not torch.compiler.is_compiling() and envs.VLLM_ENABLE_TBO: return self.forward_cuda(x, residual)
elif envs.VLLM_USE_OPT_OP:
return self.forward_cuda(x, residual) return self.forward_cuda(x, residual)
else: else:
orig_dtype = x.dtype orig_dtype = x.dtype
...@@ -203,7 +204,7 @@ class RMSNorm(CustomOp): ...@@ -203,7 +204,7 @@ class RMSNorm(CustomOp):
self, self,
x: torch.Tensor, x: torch.Tensor,
residual: Optional[torch.Tensor] = None, residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
if self.variance_size_override is not None: if self.variance_size_override is not None:
return self.forward_native(x, residual) return self.forward_native(x, residual)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment