Unverified Commit 622f8632 authored by アマデウス's avatar アマデウス Committed by GitHub
Browse files

[hotfix] Jit type hint #2161 (#2164)

parent 27327a4c
......@@ -281,7 +281,7 @@ def vocab_parallel_classifier_3d(
@torch.jit.script
def norm_forward(x, mean, sqr_mean, weight, bias, eps):
def norm_forward(x: Tensor, mean: Tensor, sqr_mean: Tensor, weight: Tensor, bias: Tensor, eps: float):
mu = x - mean
var = sqr_mean - mean**2
sigma = torch.sqrt(var + eps)
......@@ -292,7 +292,7 @@ def norm_forward(x, mean, sqr_mean, weight, bias, eps):
@torch.jit.script
def norm_backward(grad, mu, sigma, weight):
def norm_backward(grad: Tensor, mu: Tensor, sigma: Tensor, weight: Tensor):
# dbias, dweight = grad, grad * mu / sigma
dz = grad * weight
dmu = dz / sigma
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment