Commit 573915c9 authored by zhuwenwen's avatar zhuwenwen
Browse files

fix rn_add_forward_autograd import

parent 9ddd0f97
......@@ -18,7 +18,7 @@ try:
except Exception:
print("INFO: Please install lmslim if you want to infer gptq or awq or w8a8 model.\n")
try:
import lightop
from lightop import op
except Exception:
print("INFO: Please install lightop if you want to infer awq of marlin.\n")
......@@ -351,20 +351,20 @@ def fused_add_rms_norm(
# layer norm ops (opt)
def rms_norm_opt(input: torch.Tensor, weight: torch.Tensor, out: torch.Tensor,
epsilon: float, training: Optional[bool]=False) -> None:
lightop.rmsnorm_forward(input, weight, out, epsilon, training)
op.rmsnorm_forward(input, weight, out, epsilon, training)
def rms_norm_opt_fake(
input: torch.Tensor,
weight: torch.Tensor,
out: torch.Tensor,
epsilon: float,
training: Optional[bool] = False
training: Optional[bool] = False,
) -> torch.Tensor:
return torch.empty_like(input)
def fused_add_rms_norm_opt(input: torch.Tensor, residual: torch.Tensor,
weight: torch.Tensor, epsilon: float, training: Optional[bool]=False, inplace: Optional[bool]=False) -> None:
lightop.rn_add_forward_autograd(input, residual, weight, epsilon, training, inplace)
weight: torch.Tensor, epsilon: float, training: Optional[bool]=False, inplace: Optional[bool]=True) -> None:
op.rn_add_forward_autograd(input, residual, weight, epsilon, training, inplace)
def fused_add_rms_norm_opt_fake(
input: torch.Tensor,
......@@ -372,7 +372,7 @@ def fused_add_rms_norm_opt_fake(
weight: torch.Tensor,
epsilon: float,
training: Optional[bool] = False,
inplace: Optional[bool] = False
inplace: Optional[bool] = False,
) -> torch.Tensor:
return torch.empty_like(input)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment