Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
bc49758a
Commit
bc49758a
authored
Sep 13, 2023
by
Geoffrey Yu
Browse files
update data_transforms_multimer to the latest version of multimer branch
parent
da92663d
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
27 additions
and
26 deletions
+27
-26
openfold/data/data_transforms_multimer.py
openfold/data/data_transforms_multimer.py
+27
-26
No files found.
openfold/data/data_transforms_multimer.py
View file @
bc49758a
...
@@ -347,7 +347,7 @@ def get_spatial_crop_idx(protein, crop_size, interface_threshold, generator):
...
@@ -347,7 +347,7 @@ def get_spatial_crop_idx(protein, crop_size, interface_threshold, generator):
return
get_contiguous_crop_idx
(
protein
,
crop_size
,
generator
)
return
get_contiguous_crop_idx
(
protein
,
crop_size
,
generator
)
target_res_idx
=
randint
(
lower
=
0
,
target_res_idx
=
randint
(
lower
=
0
,
upper
=
interface_residues
.
shape
[
-
1
],
upper
=
interface_residues
.
shape
[
-
1
]
-
1
,
generator
=
generator
,
generator
=
generator
,
device
=
positions
.
device
)
device
=
positions
.
device
)
...
@@ -374,43 +374,45 @@ def get_spatial_crop_idx(protein, crop_size, interface_threshold, generator):
...
@@ -374,43 +374,45 @@ def get_spatial_crop_idx(protein, crop_size, interface_threshold, generator):
def
get_contiguous_crop_idx
(
protein
,
crop_size
,
generator
):
def
get_contiguous_crop_idx
(
protein
,
crop_size
,
generator
):
unique_asym_ids
,
chain_lens
=
protein
[
"asym_id"
].
unique
(
return_counts
=
True
)
unique_asym_ids
,
chain_idxs
,
chain_lens
=
protein
[
"asym_id"
].
unique
(
dim
=-
1
,
return_inverse
=
True
,
return_counts
=
True
)
shuffle_idx
=
torch
.
randperm
(
chain_lens
.
shape
[
-
1
],
device
=
chain_lens
.
device
,
generator
=
generator
)
shuffle_idx
=
torch
.
randperm
(
chain_lens
.
shape
[
-
1
],
device
=
chain_lens
.
device
,
generator
=
generator
)
num_remaining
=
int
(
chain_lens
.
sum
())
_
,
idx_sorted
=
torch
.
sort
(
chain_idxs
,
stable
=
True
)
cum_sum
=
chain_lens
.
cumsum
(
dim
=
0
)
cum_sum
=
torch
.
cat
((
torch
.
tensor
([
0
]),
cum_sum
[:
-
1
]),
dim
=
0
)
asym_offsets
=
idx_sorted
[
cum_sum
]
num_budget
=
crop_size
num_budget
=
crop_size
num_remaining
=
int
(
protein
[
"seq_length"
])
crop_idxs
=
[]
crop_idxs
=
[]
for
idx
in
shuffle_idx
:
chain_len
=
int
(
chain_lens
[
idx
])
num_remaining
-=
chain_len
per_asym_residue_index
=
{}
crop_size_max
=
min
(
num_budget
,
chain_len
)
for
cur_asym_id
in
unique_asym_ids
:
crop_size_min
=
min
(
chain_len
,
max
(
0
,
num_budget
-
num_remaining
))
asym_mask
=
(
protein
[
"asym_id"
]
==
cur_asym_id
).
bool
()
per_asym_residue_index
[
int
(
cur_asym_id
)]
=
torch
.
masked_select
(
protein
[
"asym_id"
],
asym_mask
)[
0
]
for
j
,
idx
in
enumerate
(
shuffle_idx
):
this_len
=
int
(
chain_lens
[
idx
])
num_remaining
-=
this_len
# num res at most we can keep in this ent
crop_size_max
=
min
(
num_budget
,
this_len
)
# num res at least we shall keep in this ent
crop_size_min
=
min
(
this_len
,
max
(
0
,
num_budget
-
num_remaining
))
chain_crop_size
=
randint
(
lower
=
crop_size_min
,
chain_crop_size
=
randint
(
lower
=
crop_size_min
,
upper
=
crop_size_max
+
1
,
upper
=
crop_size_max
,
generator
=
generator
,
generator
=
generator
,
device
=
chain_lens
.
device
)
device
=
chain_lens
.
device
)
num_budget
-=
chain_crop_size
chain_start
=
randint
(
lower
=
0
,
chain_start
=
randint
(
lower
=
0
,
upper
=
this
_len
-
chain_crop_size
+
1
,
upper
=
chain
_len
-
chain_crop_size
,
generator
=
generator
,
generator
=
generator
,
device
=
chain_lens
.
device
)
device
=
chain_lens
.
device
)
cur_asym_id
=
unique_asym_ids
[
int
(
idx
)].
item
()
asym_offset
=
per_
asym_
residue_index
[
int
(
cur_asym_
id
)
]
asym_offset
=
asym_
offsets
[
id
x
]
crop_idxs
.
append
(
crop_idxs
.
append
(
torch
.
arange
(
asym_offset
+
chain_start
,
asym_offset
+
chain_start
+
chain_crop_size
)
torch
.
arange
(
asym_offset
+
chain_start
,
asym_offset
+
chain_start
+
chain_crop_size
)
)
)
asym_offset
+=
this_len
num_budget
-=
chain_crop_size
return
torch
.
concat
(
crop_idxs
)
return
torch
.
concat
(
crop_idxs
)
.
sort
().
values
@
curry1
@
curry1
...
@@ -453,7 +455,7 @@ def random_crop_to_size(
...
@@ -453,7 +455,7 @@ def random_crop_to_size(
if
subsample_templates
:
if
subsample_templates
:
templates_crop_start
=
randint
(
lower
=
0
,
templates_crop_start
=
randint
(
lower
=
0
,
upper
=
num_templates
+
1
,
upper
=
num_templates
,
generator
=
g
,
generator
=
g
,
device
=
protein
[
"seq_length"
].
device
)
device
=
protein
[
"seq_length"
].
device
)
templates_select_indices
=
torch
.
randperm
(
templates_select_indices
=
torch
.
randperm
(
...
@@ -480,8 +482,7 @@ def random_crop_to_size(
...
@@ -480,8 +482,7 @@ def random_crop_to_size(
for
i
,
(
dim_size
,
dim
)
in
enumerate
(
zip
(
shape_schema
[
k
],
v
.
shape
)):
for
i
,
(
dim_size
,
dim
)
in
enumerate
(
zip
(
shape_schema
[
k
],
v
.
shape
)):
is_num_res
=
dim_size
==
NUM_RES
is_num_res
=
dim_size
==
NUM_RES
if
i
==
0
and
k
.
startswith
(
"template"
):
if
i
==
0
and
k
.
startswith
(
"template"
):
crop_start
=
templates_crop_start
v
=
v
[
slice
(
templates_crop_start
,
templates_crop_start
+
num_templates_crop_size
)]
v
=
v
[
slice
(
crop_start
,
crop_start
+
num_templates_crop_size
)]
elif
is_num_res
:
elif
is_num_res
:
v
=
torch
.
index_select
(
v
,
i
,
crop_idxs
)
v
=
torch
.
index_select
(
v
,
i
,
crop_idxs
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment