Commit 424fa81f authored by zhuwenwen's avatar zhuwenwen
Browse files

back to forward_static

parent 57e945fd
......@@ -148,7 +148,7 @@ class RMSNorm(CustomOp):
@staticmethod
def forward_static(
self,
# self,
x: torch.Tensor,
variance_epsilon: float,
hidden_size: int,
......@@ -158,9 +158,9 @@ class RMSNorm(CustomOp):
variance_size_override: int | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""PyTorch-native implementation equivalent to forward()."""
if not torch.compiler.is_compiling() and envs.VLLM_USE_OPT_OP:
return self.forward_cuda(x, residual)
else:
# if not torch.compiler.is_compiling() and envs.VLLM_USE_OPT_OP:
# return self.forward_cuda(x, residual)
# else:
orig_dtype = x.dtype
x = x.to(torch.float32)
if residual is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment