Commit 650cb815 authored by dongchl's avatar dongchl
Browse files

bug fix

parent 6bdc5d69
...@@ -975,7 +975,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -975,7 +975,7 @@ class _LayerNormLinear(torch.autograd.Function):
) )
dgrad = dgrad.reshape(inputmat.size()) dgrad = dgrad.reshape(inputmat.size())
elif ctx.normalization == "RMSNorm": elif ctx.normalization == "RMSNorm":
if enable_lightop and (rsigma is torch.bfloat16 or rsigma is torch.float16): if enable_lightop and (rsigma.dtype is torch.bfloat16 or rsigma.dtype is torch.float16):
dgrad, dgamma =rmsnorm_backward(dgrad,inputmat,rsigma,ln_weight) dgrad, dgamma =rmsnorm_backward(dgrad,inputmat,rsigma,ln_weight)
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment