Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
def329b8
Commit
def329b8
authored
Oct 16, 2023
by
Christina Floristean
Browse files
Removed references to jax
parent
8470b803
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
4 additions
and
234 deletions
+4
-234
environment.yml
environment.yml
+0
-1
openfold/utils/geometry/__init__.py
openfold/utils/geometry/__init__.py
+0
-3
openfold/utils/geometry/rigid_matrix_vector.py
openfold/utils/geometry/rigid_matrix_vector.py
+0
-1
openfold/utils/geometry/rotation_matrix.py
openfold/utils/geometry/rotation_matrix.py
+0
-2
openfold/utils/geometry/struct_of_array.py
openfold/utils/geometry/struct_of_array.py
+0
-220
openfold/utils/geometry/test_utils.py
openfold/utils/geometry/test_utils.py
+4
-5
openfold/utils/geometry/vector.py
openfold/utils/geometry/vector.py
+0
-2
No files found.
environment.yml
View file @
def329b8
...
...
@@ -19,7 +19,6 @@ dependencies:
-
deepspeed==0.5.10
-
dm-tree==0.1.6
-
ml-collections==0.1.0
-
jax==0.3.25
-
pandas==2.0.2
-
numpy==1.21.2
-
PyYAML==5.4.1
...
...
openfold/utils/geometry/__init__.py
View file @
def329b8
...
...
@@ -15,14 +15,11 @@
from
openfold.utils.geometry
import
rigid_matrix_vector
from
openfold.utils.geometry
import
rotation_matrix
from
openfold.utils.geometry
import
struct_of_array
from
openfold.utils.geometry
import
vector
Rot3Array
=
rotation_matrix
.
Rot3Array
Rigid3Array
=
rigid_matrix_vector
.
Rigid3Array
StructOfArray
=
struct_of_array
.
StructOfArray
Vec3Array
=
vector
.
Vec3Array
square_euclidean_distance
=
vector
.
square_euclidean_distance
euclidean_distance
=
vector
.
euclidean_distance
...
...
openfold/utils/geometry/rigid_matrix_vector.py
View file @
def329b8
...
...
@@ -20,7 +20,6 @@ from typing import Union, List
import
torch
from
openfold.utils.geometry
import
rotation_matrix
from
openfold.utils.geometry
import
struct_of_array
from
openfold.utils.geometry
import
vector
...
...
openfold/utils/geometry/rotation_matrix.py
View file @
def329b8
...
...
@@ -18,9 +18,7 @@ import dataclasses
from
typing
import
List
import
torch
import
numpy
as
np
from
openfold.utils.geometry
import
struct_of_array
from
openfold.utils.geometry
import
utils
from
openfold.utils.geometry
import
vector
from
openfold.utils.tensor_utils
import
tensor_tree_map
...
...
openfold/utils/geometry/struct_of_array.py
deleted
100644 → 0
View file @
8470b803
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Class decorator to represent (nested) struct of arrays."""
import
dataclasses
import
jax
def
get_item
(
instance
,
key
):
sliced
=
{}
for
field
in
get_array_fields
(
instance
):
num_trailing_dims
=
field
.
metadata
.
get
(
'num_trailing_dims'
,
0
)
this_key
=
key
if
isinstance
(
key
,
tuple
)
and
Ellipsis
in
this_key
:
this_key
+=
(
slice
(
None
),)
*
num_trailing_dims
sliced
[
field
.
name
]
=
getattr
(
instance
,
field
.
name
)[
this_key
]
return
dataclasses
.
replace
(
instance
,
**
sliced
)
@
property
def
get_shape
(
instance
):
"""Returns Shape for given instance of dataclass."""
first_field
=
dataclasses
.
fields
(
instance
)[
0
]
num_trailing_dims
=
first_field
.
metadata
.
get
(
'num_trailing_dims'
,
None
)
value
=
getattr
(
instance
,
first_field
.
name
)
if
num_trailing_dims
:
return
value
.
shape
[:
-
num_trailing_dims
]
else
:
return
value
.
shape
def
get_len
(
instance
):
"""Returns length for given instance of dataclass."""
shape
=
instance
.
shape
if
shape
:
return
shape
[
0
]
else
:
raise
TypeError
(
'len() of unsized object'
)
# Match jax.numpy behavior.
@
property
def
get_dtype
(
instance
):
"""Returns Dtype for given instance of dataclass."""
fields
=
dataclasses
.
fields
(
instance
)
sets_dtype
=
[
field
.
name
for
field
in
fields
if
field
.
metadata
.
get
(
'sets_dtype'
,
False
)
]
if
sets_dtype
:
assert
len
(
sets_dtype
)
==
1
,
'at most field can set dtype'
field_value
=
getattr
(
instance
,
sets_dtype
[
0
])
elif
instance
.
same_dtype
:
field_value
=
getattr
(
instance
,
fields
[
0
].
name
)
else
:
# Should this be Value Error?
raise
AttributeError
(
'Trying to access Dtype on Struct of Array without'
'either "same_dtype" or field setting dtype'
)
if
hasattr
(
field_value
,
'dtype'
):
return
field_value
.
dtype
else
:
# Should this be Value Error?
raise
AttributeError
(
f
'field_value
{
field_value
}
does not have dtype'
)
def
replace
(
instance
,
**
kwargs
):
return
dataclasses
.
replace
(
instance
,
**
kwargs
)
def
post_init
(
instance
):
"""Validate instance has same shapes & dtypes."""
array_fields
=
get_array_fields
(
instance
)
arrays
=
list
(
get_array_fields
(
instance
,
return_values
=
True
).
values
())
first_field
=
array_fields
[
0
]
# These slightly weird constructions about checking whether the leaves are
# actual arrays is since e.g. vmap internally relies on being able to
# construct pytree's with object() as leaves, this would break the checking
# as such we are only validating the object when the entries in the dataclass
# Are arrays or other dataclasses of arrays.
try
:
dtype
=
instance
.
dtype
except
AttributeError
:
dtype
=
None
if
dtype
is
not
None
:
first_shape
=
instance
.
shape
for
array
,
field
in
zip
(
arrays
,
array_fields
):
field_shape
=
array
.
shape
num_trailing_dims
=
field
.
metadata
.
get
(
'num_trailing_dims'
,
None
)
if
num_trailing_dims
:
array_shape
=
array
.
shape
field_shape
=
array_shape
[:
-
num_trailing_dims
]
msg
=
(
f
'field
{
field
}
should have number of trailing dims'
' {num_trailing_dims}'
)
assert
len
(
array_shape
)
==
len
(
first_shape
)
+
num_trailing_dims
,
msg
else
:
field_shape
=
array
.
shape
shape_msg
=
(
f
"Stripped Shape
{
field_shape
}
of field
{
field
}
doesn't "
f
"match shape
{
first_shape
}
of field
{
first_field
}
"
)
assert
field_shape
==
first_shape
,
shape_msg
field_dtype
=
array
.
dtype
allowed_metadata_dtypes
=
field
.
metadata
.
get
(
'allowed_dtypes'
,
[])
if
allowed_metadata_dtypes
:
msg
=
f
'Dtype is
{
field_dtype
}
but must be in
{
allowed_metadata_dtypes
}
'
assert
field_dtype
in
allowed_metadata_dtypes
,
msg
if
'dtype'
in
field
.
metadata
:
target_dtype
=
field
.
metadata
[
'dtype'
]
else
:
target_dtype
=
dtype
msg
=
f
'Dtype is
{
field_dtype
}
but must be
{
target_dtype
}
'
assert
field_dtype
==
target_dtype
,
msg
def
flatten
(
instance
):
"""Flatten Struct of Array instance."""
array_likes
=
list
(
get_array_fields
(
instance
,
return_values
=
True
).
values
())
flat_array_likes
=
[]
inner_treedefs
=
[]
num_arrays
=
[]
for
array_like
in
array_likes
:
flat_array_like
,
inner_treedef
=
jax
.
tree_flatten
(
array_like
)
inner_treedefs
.
append
(
inner_treedef
)
flat_array_likes
+=
flat_array_like
num_arrays
.
append
(
len
(
flat_array_like
))
metadata
=
get_metadata_fields
(
instance
,
return_values
=
True
)
metadata
=
type
(
instance
).
metadata_cls
(
**
metadata
)
return
flat_array_likes
,
(
inner_treedefs
,
metadata
,
num_arrays
)
def
make_metadata_class
(
cls
):
metadata_fields
=
get_fields
(
cls
,
lambda
x
:
x
.
metadata
.
get
(
'is_metadata'
,
False
))
metadata_cls
=
dataclasses
.
make_dataclass
(
cls_name
=
'Meta'
+
cls
.
__name__
,
fields
=
[(
field
.
name
,
field
.
type
,
field
)
for
field
in
metadata_fields
],
frozen
=
True
,
eq
=
True
)
return
metadata_cls
def
get_fields
(
cls_or_instance
,
filterfn
,
return_values
=
False
):
fields
=
dataclasses
.
fields
(
cls_or_instance
)
fields
=
[
field
for
field
in
fields
if
filterfn
(
field
)]
if
return_values
:
return
{
field
.
name
:
getattr
(
cls_or_instance
,
field
.
name
)
for
field
in
fields
}
else
:
return
fields
def
get_array_fields
(
cls
,
return_values
=
False
):
return
get_fields
(
cls
,
lambda
x
:
not
x
.
metadata
.
get
(
'is_metadata'
,
False
),
return_values
=
return_values
)
def
get_metadata_fields
(
cls
,
return_values
=
False
):
return
get_fields
(
cls
,
lambda
x
:
x
.
metadata
.
get
(
'is_metadata'
,
False
),
return_values
=
return_values
)
class
StructOfArray
:
"""Class Decorator for Struct Of Arrays."""
def
__init__
(
self
,
same_dtype
=
True
):
self
.
same_dtype
=
same_dtype
def
__call__
(
self
,
cls
):
cls
.
__array_ufunc__
=
None
cls
.
replace
=
replace
cls
.
same_dtype
=
self
.
same_dtype
cls
.
dtype
=
get_dtype
cls
.
shape
=
get_shape
cls
.
__len__
=
get_len
cls
.
__getitem__
=
get_item
cls
.
__post_init__
=
post_init
new_cls
=
dataclasses
.
dataclass
(
cls
,
frozen
=
True
,
eq
=
False
)
# pytype: disable=wrong-keyword-args
# pytree claims to require metadata to be hashable, not sure why,
# But making derived dataclass that can just hold metadata
new_cls
.
metadata_cls
=
make_metadata_class
(
new_cls
)
def
unflatten
(
aux
,
data
):
inner_treedefs
,
metadata
,
num_arrays
=
aux
array_fields
=
[
field
.
name
for
field
in
get_array_fields
(
new_cls
)]
value_dict
=
{}
array_start
=
0
for
num_array
,
inner_treedef
,
array_field
in
zip
(
num_arrays
,
inner_treedefs
,
array_fields
):
value_dict
[
array_field
]
=
jax
.
tree_unflatten
(
inner_treedef
,
data
[
array_start
:
array_start
+
num_array
])
array_start
+=
num_array
metadata_fields
=
get_metadata_fields
(
new_cls
)
for
field
in
metadata_fields
:
value_dict
[
field
.
name
]
=
getattr
(
metadata
,
field
.
name
)
return
new_cls
(
**
value_dict
)
jax
.
tree_util
.
register_pytree_node
(
nodetype
=
new_cls
,
flatten_func
=
flatten
,
unflatten_func
=
unflatten
)
return
new_cls
openfold/utils/geometry/test_utils.py
View file @
def329b8
...
...
@@ -18,7 +18,6 @@ import dataclasses
from
alphafold.model.geometry
import
rigid_matrix_vector
from
alphafold.model.geometry
import
rotation_matrix
from
alphafold.model.geometry
import
vector
import
jax.numpy
as
jnp
import
numpy
as
np
...
...
@@ -35,7 +34,7 @@ def assert_rotation_matrix_close(mat1: rotation_matrix.Rot3Array,
np
.
testing
.
assert_array_almost_equal
(
mat1
.
to_array
(),
mat2
.
to_array
(),
6
)
def
assert_array_equal_to_rotation_matrix
(
array
:
j
np
.
ndarray
,
def
assert_array_equal_to_rotation_matrix
(
array
:
np
.
ndarray
,
matrix
:
rotation_matrix
.
Rot3Array
):
"""Check that array and Matrix match."""
np
.
testing
.
assert_array_equal
(
matrix
.
xx
,
array
[...,
0
,
0
])
...
...
@@ -49,7 +48,7 @@ def assert_array_equal_to_rotation_matrix(array: jnp.ndarray,
np
.
testing
.
assert_array_equal
(
matrix
.
zz
,
array
[...,
2
,
2
])
def
assert_array_close_to_rotation_matrix
(
array
:
j
np
.
ndarray
,
def
assert_array_close_to_rotation_matrix
(
array
:
np
.
ndarray
,
matrix
:
rotation_matrix
.
Rot3Array
):
np
.
testing
.
assert_array_almost_equal
(
matrix
.
to_array
(),
array
,
6
)
...
...
@@ -66,11 +65,11 @@ def assert_vectors_close(vec1: vector.Vec3Array, vec2: vector.Vec3Array):
np
.
testing
.
assert_allclose
(
vec1
.
z
,
vec2
.
z
,
atol
=
1e-6
,
rtol
=
0.
)
def
assert_array_close_to_vector
(
array
:
j
np
.
ndarray
,
vec
:
vector
.
Vec3Array
):
def
assert_array_close_to_vector
(
array
:
np
.
ndarray
,
vec
:
vector
.
Vec3Array
):
np
.
testing
.
assert_allclose
(
vec
.
to_array
(),
array
,
atol
=
1e-6
,
rtol
=
0.
)
def
assert_array_equal_to_vector
(
array
:
j
np
.
ndarray
,
vec
:
vector
.
Vec3Array
):
def
assert_array_equal_to_vector
(
array
:
np
.
ndarray
,
vec
:
vector
.
Vec3Array
):
np
.
testing
.
assert_array_equal
(
vec
.
to_array
(),
array
)
...
...
openfold/utils/geometry/vector.py
View file @
def329b8
...
...
@@ -19,8 +19,6 @@ from typing import Union, List
import
torch
from
openfold.utils.geometry
import
utils
Float
=
Union
[
float
,
torch
.
Tensor
]
@
dataclasses
.
dataclass
(
frozen
=
True
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment