Commit fb785094 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Remove Python 3.9-only math function

parent 3bb47f89
...@@ -27,10 +27,16 @@ from openfold.utils.tensor_utils import ( ...@@ -27,10 +27,16 @@ from openfold.utils.tensor_utils import (
) )
def _prod(nums):
out = 1
for n in nums:
out = out * n
return nums
def _calculate_fan(shape, fan="fan_in"): def _calculate_fan(shape, fan="fan_in"):
i = shape[0] i = shape[0]
o = shape[1] o = shape[1]
prod = math.prod(shape[:2]) prod = _prod(shape[:2])
fan_in = prod * i fan_in = prod * i
fan_out = prod * o fan_out = prod * o
...@@ -53,7 +59,7 @@ def trunc_normal_init_(weights, scale=1.0, fan="fan_in"): ...@@ -53,7 +59,7 @@ def trunc_normal_init_(weights, scale=1.0, fan="fan_in"):
a = -2 a = -2
b = 2 b = 2
std = math.sqrt(scale) / truncnorm.std(a=a, b=b, loc=0, scale=1) std = math.sqrt(scale) / truncnorm.std(a=a, b=b, loc=0, scale=1)
size = math.prod(shape) size = _prod(shape)
samples = truncnorm.rvs(a=a, b=b, loc=0, scale=std, size=size) samples = truncnorm.rvs(a=a, b=b, loc=0, scale=std, size=size)
samples = np.reshape(samples, shape) samples = np.reshape(samples, shape)
with torch.no_grad(): with torch.no_grad():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment