Commit fb785094 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Remove Python 3.9-only math function

parent 3bb47f89
......@@ -27,10 +27,16 @@ from openfold.utils.tensor_utils import (
)
def _prod(nums):
out = 1
for n in nums:
out = out * n
return nums
def _calculate_fan(shape, fan="fan_in"):
i = shape[0]
o = shape[1]
prod = math.prod(shape[:2])
prod = _prod(shape[:2])
fan_in = prod * i
fan_out = prod * o
......@@ -53,7 +59,7 @@ def trunc_normal_init_(weights, scale=1.0, fan="fan_in"):
a = -2
b = 2
std = math.sqrt(scale) / truncnorm.std(a=a, b=b, loc=0, scale=1)
size = math.prod(shape)
size = _prod(shape)
samples = truncnorm.rvs(a=a, b=b, loc=0, scale=std, size=size)
samples = np.reshape(samples, shape)
with torch.no_grad():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment