Commit 424fa81f authored by zhuwenwen
Browse files

back to forward_static

parent 57e945fd
...@@ -148,7 +148,7 @@ class RMSNorm(CustomOp): ...@@ -148,7 +148,7 @@ class RMSNorm(CustomOp):
@staticmethod @staticmethod
def forward_static( def forward_static(
self, # self,
x: torch.Tensor, x: torch.Tensor,
variance_epsilon: float, variance_epsilon: float,
hidden_size: int, hidden_size: int,
...@@ -158,45 +158,45 @@ class RMSNorm(CustomOp): ...@@ -158,45 +158,45 @@ class RMSNorm(CustomOp):
variance_size_override: int | None = None, variance_size_override: int | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""PyTorch-native implementation equivalent to forward().""" """PyTorch-native implementation equivalent to forward()."""
if not torch.compiler.is_compiling() and envs.VLLM_USE_OPT_OP: # if not torch.compiler.is_compiling() and envs.VLLM_USE_OPT_OP:
return self.forward_cuda(x, residual) # return self.forward_cuda(x, residual)
# else:
orig_dtype = x.dtype
x = x.to(torch.float32)
if residual is not None:
# residual promoted f16->f32 automatically,
# otherwise Inductor eliminates the casts to and from f16,
# increasing memory usage (and complicating pattern matching)
x = x + residual
residual = x.to(orig_dtype)
if x.shape[-1] != hidden_size:
raise ValueError(
f"Expected hidden_size to be {hidden_size}, but found: {x.shape[-1]}"
)
if variance_size_override is None:
x_var = x
else: else:
orig_dtype = x.dtype if hidden_size < variance_size_override:
x = x.to(torch.float32)
if residual is not None:
# residual promoted f16->f32 automatically,
# otherwise Inductor eliminates the casts to and from f16,
# increasing memory usage (and complicating pattern matching)
x = x + residual
residual = x.to(orig_dtype)
if x.shape[-1] != hidden_size:
raise ValueError( raise ValueError(
f"Expected hidden_size to be {hidden_size}, but found: {x.shape[-1]}" "Expected hidden_size to be at least "
f"{variance_size_override}, but found: {hidden_size}"
) )
if variance_size_override is None: x_var = x[:, :, :variance_size_override]
x_var = x
else:
if hidden_size < variance_size_override:
raise ValueError(
"Expected hidden_size to be at least "
f"{variance_size_override}, but found: {hidden_size}"
)
x_var = x[:, :, :variance_size_override] variance = x_var.pow(2).mean(dim=-1, keepdim=True)
variance = x_var.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + variance_epsilon)
x = x * torch.rsqrt(variance + variance_epsilon) x = x.to(orig_dtype)
x = x.to(orig_dtype) if weight is not None:
if weight is not None: x = x * weight
x = x * weight if residual is None:
if residual is None: return x
return x else:
else: return x, residual
return x, residual
def forward_native( def forward_native(
self, self,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment