Unverified Commit 622f8632 authored by アマデウス's avatar アマデウス Committed by GitHub
Browse files

[hotfix] Jit type hint #2161 (#2164)

parent 27327a4c
...@@ -281,7 +281,7 @@ def vocab_parallel_classifier_3d( ...@@ -281,7 +281,7 @@ def vocab_parallel_classifier_3d(
@torch.jit.script @torch.jit.script
def norm_forward(x, mean, sqr_mean, weight, bias, eps): def norm_forward(x: Tensor, mean: Tensor, sqr_mean: Tensor, weight: Tensor, bias: Tensor, eps: float):
mu = x - mean mu = x - mean
var = sqr_mean - mean**2 var = sqr_mean - mean**2
sigma = torch.sqrt(var + eps) sigma = torch.sqrt(var + eps)
...@@ -292,7 +292,7 @@ def norm_forward(x, mean, sqr_mean, weight, bias, eps): ...@@ -292,7 +292,7 @@ def norm_forward(x, mean, sqr_mean, weight, bias, eps):
@torch.jit.script @torch.jit.script
def norm_backward(grad, mu, sigma, weight): def norm_backward(grad: Tensor, mu: Tensor, sigma: Tensor, weight: Tensor):
# dbias, dweight = grad, grad * mu / sigma # dbias, dweight = grad, grad * mu / sigma
dz = grad * weight dz = grad * weight
dmu = dz / sigma dmu = dz / sigma
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment