Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
3aaf0ca8
"demo/java/src/ImageRestorer.java" did not exist on "a64900f3382e8e32155c47b8c5597022480b20ac"
Commit
3aaf0ca8
authored
Sep 25, 2023
by
Geoffrey Yu
Browse files
update trainining code with new input from new multimer pipeline
parent
74670a88
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
12 additions
and
10 deletions
+12
-10
train_openfold.py
train_openfold.py
+12
-10
No files found.
train_openfold.py
View file @
3aaf0ca8
...
@@ -273,27 +273,29 @@ class OpenFoldMultimerWrapper(OpenFoldWrapper):
...
@@ -273,27 +273,29 @@ class OpenFoldMultimerWrapper(OpenFoldWrapper):
return
self
.
model
(
batch
)
return
self
.
model
(
batch
)
def
training_step
(
self
,
batch
,
batch_idx
):
def
training_step
(
self
,
batch
,
batch_idx
):
features
,
gt_features
=
batch
# Log it
# Log it
if
(
self
.
ema
.
device
!=
batch
[
"aatype"
].
device
):
if
(
self
.
ema
.
device
!=
features
[
"aatype"
].
device
):
self
.
ema
.
to
(
batch
[
"aatype"
].
device
)
self
.
ema
.
to
(
features
[
"aatype"
].
device
)
# Run the model
# Run the model
outputs
=
self
(
batch
)
outputs
=
self
(
features
)
# Remove the recycling dimension
# Remove the recycling dimension
batch
=
tensor_tree_map
(
lambda
t
:
t
[...,
-
1
],
batch
)
features
=
tensor_tree_map
(
lambda
t
:
t
[...,
-
1
],
features
)
# Compute loss
# Compute loss
loss
,
loss_breakdown
=
self
.
loss
(
loss
,
loss_breakdown
=
self
.
loss
(
outputs
,
batch
,
_return_breakdown
=
True
outputs
,
(
features
,
gt_features
)
,
_return_breakdown
=
True
)
)
# Log it
# Log it
self
.
_log
(
loss_breakdown
,
batch
,
outputs
)
self
.
_log
(
loss_breakdown
,
features
,
outputs
)
return
loss
return
loss
def
validation_step
(
self
,
batch
,
batch_idx
):
def
validation_step
(
self
,
batch
,
batch_idx
):
features
,
gt_features
=
batch
# At the start of validation, load the EMA weights
# At the start of validation, load the EMA weights
if
(
self
.
cached_weights
is
None
):
if
(
self
.
cached_weights
is
None
):
# model.state_dict() contains references to model weights rather
# model.state_dict() contains references to model weights rather
...
@@ -304,15 +306,15 @@ class OpenFoldMultimerWrapper(OpenFoldWrapper):
...
@@ -304,15 +306,15 @@ class OpenFoldMultimerWrapper(OpenFoldWrapper):
self
.
model
.
load_state_dict
(
self
.
ema
.
state_dict
()[
"params"
])
self
.
model
.
load_state_dict
(
self
.
ema
.
state_dict
()[
"params"
])
# Run the model
# Run the model
outputs
=
self
(
batch
)
outputs
=
self
(
features
)
# Compute loss and other metrics
# Compute loss and other metrics
batch
[
"use_clamped_fape"
]
=
0.
features
[
"use_clamped_fape"
]
=
0.
_
,
loss_breakdown
=
self
.
loss
(
_
,
loss_breakdown
=
self
.
loss
(
outputs
,
batch
,
_return_breakdown
=
True
outputs
,
(
features
,
gt_features
)
,
_return_breakdown
=
True
)
)
self
.
_log
(
loss_breakdown
,
batch
,
outputs
,
train
=
False
)
self
.
_log
(
loss_breakdown
,
features
,
outputs
,
train
=
False
)
def
validation_epoch_end
(
self
,
_
):
def
validation_epoch_end
(
self
,
_
):
# Restore the model weights to normal
# Restore the model weights to normal
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment