OpenDAS / FastFold

Commit 614e2763: support running in multimer mode
Authored Feb 08, 2023 by zhuww
Parent: 7e01f6d6
Showing 9 changed files with 104 additions and 77 deletions.
fastfold/config.py                                               +1   -1
fastfold/data/templates.py                                       +14  -12
fastfold/model/fastnn/evoformer.py                               +18  -14
fastfold/model/hub/alphafold.py                                  +1   -1
fastfold/model/nn/structure_module.py                            +5   -3
fastfold/utils/import_weights.py                                 +20  -17
fastfold/workflow/template/fastfold_data_workflow.py             +2   -1
fastfold/workflow/template/fastfold_multimer_data_workflow.py    +2   -1
inference.py                                                     +41  -27
fastfold/config.py (+1 -1)

@@ -575,7 +575,7 @@ multimer_model_config_update = {
             "tm": {
                 "c_z": c_z,
                 "no_bins": aux_distogram_bins,
-                "enabled": tm_enabled,
+                "enabled": True,
             },
             "masked_msa": {
                 "c_m": c_m,
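The one-line change hard-enables the "tm" (TM-score / predicted-aligned-error) auxiliary head in the multimer config instead of honoring a tm_enabled flag. That is consistent with how AlphaFold-Multimer ranks its predictions: the ranking confidence mixes interface pTM and pTM, both derived from this head's logits, so the head must always run. A minimal sketch of that ranking score, assuming ptm and iptm have already been computed from the head's output (the dictionary keys here are illustrative, not FastFold's exact output names):

    def multimer_ranking_confidence(out: dict) -> float:
        # AlphaFold-Multimer ranks models by 0.8 * ipTM + 0.2 * pTM; both
        # scores come from the "tm" head's binned logits, which is why the
        # head stays enabled in multimer mode.
        return 0.8 * out["iptm"] + 0.2 * out["ptm"]

    print(multimer_ranking_confidence({"iptm": 0.85, "ptm": 0.90}))  # ~0.86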
fastfold/data/templates.py (+14 -12)

@@ -881,18 +881,20 @@ def _process_single_hit(
     ) as e:
         # These 3 errors indicate missing mmCIF experimental data rather than a
         # problem with the template search, so turn them into warnings.
-        warning = (
-            "%s_%s (sum_probs: %.2f, rank: %d): feature extracting errors: "
-            "%s, mmCIF parsing errors: %s"
-            % (
-                hit_pdb_code,
-                hit_chain_id,
-                hit.sum_probs,
-                hit.index,
-                str(e),
-                parsing_result.errors,
-            )
-        )
+        # warning = (
+        #     "%s_%s (sum_probs: %.2f, rank: %d): feature extracting errors: "
+        #     "%s, mmCIF parsing errors: %s"
+        #     % (
+        #         hit_pdb_code,
+        #         hit_chain_id,
+        #         hit.sum_probs,
+        #         hit.index,
+        #         str(e),
+        #         parsing_result.errors,
+        #     )
+        # )
+        warning = None
         if strict_error_check:
             return SingleHitResult(features=None, error=warning, warning=None)
         else:
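The change silences the template warning entirely: the message construction is commented out and warning is set to None, which also becomes the error handed to SingleHitResult under strict_error_check. A plausible motivation (an assumption, not stated in the commit) is that multimer template hits come from hmmsearch rather than hhsearch and can carry sum_probs=None, so the "%.2f" placeholder raises a TypeError and turns a benign missing-data case into a crash. A None-safe alternative that would keep the warning is sketched below (a hypothetical helper, not part of the commit; the arguments mirror the names in _process_single_hit):

    def format_hit_warning(hit_pdb_code, hit_chain_id, hit, e, parsing_result):
        # sum_probs can be None for hmmsearch-derived multimer hits, so avoid
        # applying "%.2f" to None.
        sum_probs = "%.2f" % hit.sum_probs if hit.sum_probs is not None else "None"
        return (
            f"{hit_pdb_code}_{hit_chain_id} (sum_probs: {sum_probs}, "
            f"rank: {hit.index}): feature extracting errors: {e}, "
            f"mmCIF parsing errors: {parsing_result.errors}"
        )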
fastfold/model/fastnn/evoformer.py (+18 -14)

@@ -9,7 +9,7 @@ from colossalai.core import global_context as gpc
 from fastfold.model.fastnn import MSACore, OutProductMean, PairCore
 from fastfold.model.fastnn.ops import Linear
-from fastfold.distributed.comm import gather, scatter
+from fastfold.distributed.comm import gather, scatter, col_to_row
 from fastfold.distributed.comm_async import All_to_All_Async, All_to_All_Async_Opp
 from fastfold.utils.checkpointing import checkpoint_blocks

@@ -49,7 +49,10 @@ class Evoformer(nn.Module):
         m = torch.nn.functional.pad(m, (0, 0, 0, padding_size))
         z = torch.nn.functional.pad(z, (0, 0, 0, padding_size, 0, padding_size))

-        m = scatter(m, dim=1)
+        if self.is_multimer:
+            m = scatter(m, dim=2)
+        else:
+            m = scatter(m, dim=1)
         z = scatter(z, dim=1)

         msa_mask = msa_mask.unsqueeze(0)

@@ -76,7 +79,10 @@ class Evoformer(nn.Module):
         m = m.squeeze(0)
         z = z.squeeze(0)

-        m = gather(m, dim=0)
+        if self.is_multimer:
+            m = gather(m, dim=1)
+        else:
+            m = gather(m, dim=0)
         z = gather(z, dim=0)

         m = m[:, :-padding_size, :]

@@ -107,8 +113,10 @@ class Evoformer(nn.Module):
         z[0] = torch.nn.functional.pad(z[0], (0, 0, 0, padding_size, 0, padding_size))
         torch.cuda.empty_cache()
-        m[0] = scatter(m[0], dim=1, drop_unused=True)
+        if self.is_multimer:
+            m[0] = scatter(m[0], dim=2)
+        else:
+            m[0] = scatter(m[0], dim=1, drop_unused=True)
         torch.cuda.empty_cache()
         z[0] = scatter(z[0], dim=1, drop_unused=True)
         torch.cuda.empty_cache()

@@ -126,15 +134,8 @@ class Evoformer(nn.Module):
             z = self.pair.inplace(z, pair_mask)
             m[0] = All_to_All_Async_Opp.apply(m[0], work, 1, 2)
         else:
-            # z = self.communication.inplace(m[0], msa_mask, z)
-            # z_ori = z[0].clone()
-            # m[0], work = All_to_All_Async.apply(m[0], 1, 2)
-            # z = self.pair_stack.inplace(z, pair_mask)
-            # m[0] = All_to_All_Async_Opp.apply(m[0], work, 1, 2)
-            # m[0] = self.msa_stack(m[0], z_ori, msa_mask)
             z = self.communication.inplace(m[0], msa_mask, z)
-            m[0], work = All_to_All_Async.apply(m[0], 1, 2)
-            m[0] = All_to_All_Async_Opp.apply(m[0], work, 1, 2)
+            m[0] = col_to_row(m[0])
             m[0] = self.msa(m[0], z[0], msa_mask)
             z = self.pair.inplace(z, pair_mask)

@@ -143,7 +144,10 @@ class Evoformer(nn.Module):
         z[0] = z[0].squeeze(0)
         torch.cuda.empty_cache()

-        m[0] = gather(m[0], dim=0, chunks=4)
+        if self.is_multimer:
+            m[0] = gather(m[0], dim=1)
+        else:
+            m[0] = gather(m[0], dim=0, chunks=4)
         z[0] = gather(z[0], dim=0, chunks=4)

         m[0] = m[0][:, :-padding_size, :]
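A note on the pattern above: under FastFold's Dynamic Axial Parallelism the MSA activation m and the pair activation z are sharded across GPUs with scatter/gather, and the multimer path shards m along a different axis (dim=2 on the way in, dim=1 on the way back, versus dim=1/dim=0 for monomer), presumably because the multimer pipeline keeps an extra leading axis on m. col_to_row, newly imported here, re-partitions an activation from column shards to row shards and replaces the explicit All_to_All_Async / All_to_All_Async_Opp pair in the non-async branch. The surrounding padding calls are easy to misread; here is a minimal runnable sketch (toy shapes, plain torch, no distributed backend) of which axes torch.nn.functional.pad touches:

    import torch

    # F.pad reads its size tuple from the last dimension backwards,
    # two entries (left, right) per dimension.
    m = torch.zeros(1, 8, 10, 32)   # [batch, N_seq, N_res, c_m]
    z = torch.zeros(1, 10, 10, 64)  # [batch, N_res, N_res, c_z]
    padding_size = 6

    # (0, 0, 0, padding_size) pads nothing on dim -1 and appends
    # padding_size entries on dim -2 (the residue axis of m).
    m = torch.nn.functional.pad(m, (0, 0, 0, padding_size))
    # z pads both residue axes so the pair representation stays square.
    z = torch.nn.functional.pad(z, (0, 0, 0, padding_size, 0, padding_size))

    assert m.shape == (1, 8, 16, 32)
    assert z.shape == (1, 16, 16, 64)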
fastfold/model/hub/alphafold.py (+1 -1)

@@ -360,7 +360,7 @@ class AlphaFold(nn.Module):
-                pair_mask=pair_mask.to(dtype=z.dtype),
+                pair_mask=pair_mask.to(dtype=z[0].dtype),
                 _mask_trans=self.config._mask_trans,
             )[0]
             del extra_msa_feat, extra_msa_fn
             torch.cuda.empty_cache()
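This one-line change reflects a convention used throughout the commit: in the in-place code paths the pair representation z travels as a single-element list so that modules can swap the underlying buffer without rebinding the caller's variable, and dtype lookups therefore go through z[0]. A generic sketch of that pattern (illustrative, not FastFold's actual module code):

    import torch

    def scale_inplace(buf: list, factor: float) -> list:
        # Mutating buf[0] lets every holder of the list see the replacement
        # tensor, without the caller having to rebind its own local name.
        buf[0] = buf[0] * factor
        return buf

    z = [torch.ones(4, 4)]
    scale_inplace(z, 2.0)
    assert z[0].dtype == torch.float32 and z[0][0, 0] == 2.0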
fastfold/model/nn/structure_module.py (+5 -3)

@@ -530,7 +530,7 @@ class InvariantPointAttention(nn.Module):
         # [*, N_res, C_s]
         if self.is_multimer:
             s = self.linear_out(
-                torch.cat((o, *o_pt, o_pt_norm, o_pair), dim=-1).to(dtype=z.dtype)
+                torch.cat((o, *o_pt, o_pt_norm, o_pair), dim=-1).to(dtype=z[0].dtype)
             )
         else:
             s = self.linear_out(

@@ -874,7 +874,8 @@ class StructureModule(nn.Module):
     def _forward_multimer(
         self,
-        evoformer_output_dict,
+        s: torch.Tensor,
+        z: torch.Tensor,
         aatype: torch.Tensor,
         mask: Optional[torch.Tensor] = None,
     ) -> Dict[str, Any]:

@@ -898,6 +899,7 @@ class StructureModule(nn.Module):
             s.device,
         )
         outputs = []
+        z = [z]
         for i in range(self.no_blocks):
             # [*, N, C_s]
             s = s + self.ipa(s, z, rigids, mask)

@@ -960,7 +962,7 @@ class StructureModule(nn.Module):
             A dictionary of outputs
         """
         if self.is_multimer:
-            outputs = self._forward_multimer(evoformer_output_dict, aatype, mask)
+            outputs = self._forward_multimer(evoformer_output_dict["single"], evoformer_output_dict["pair"], aatype, mask)
         else:
             outputs = self._forward_monomer(evoformer_output_dict, aatype, mask)
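With this change, the multimer entry point receives the single and pair representations as separate tensors instead of the whole evoformer output dict, and internally rewraps the pair tensor as a one-element list (z = [z]) before each IPA call, matching the list convention noted above. A hedged sketch of what the caller provides, with illustrative shapes (c_s = 384, c_z = 128, as in standard AlphaFold):

    import torch

    # Illustrative evoformer outputs for a 100-residue complex (batch of 1).
    evoformer_output_dict = {
        "single": torch.zeros(1, 100, 384),     # [*, N_res, c_s]
        "pair": torch.zeros(1, 100, 100, 128),  # [*, N_res, N_res, c_z]
    }

    # After this commit, the multimer branch unpacks the dict at the call site:
    #     outputs = structure_module._forward_multimer(
    #         evoformer_output_dict["single"],
    #         evoformer_output_dict["pair"],
    #         aatype, mask,
    #     )
    # while the monomer branch still passes evoformer_output_dict unchanged.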
fastfold/utils/import_weights.py (+20 -17)

@@ -126,11 +126,9 @@ def assign(translation_dict, orig_weights):
             print(ref[0].shape)
             print(weights[0].shape)
             raise

-def import_jax_weights_(model, npz_path, version="model_1", is_multimer: bool = False):
-    data = np.load(npz_path)
-    # translations = get_translation_dict(model, is_multimer=("multimer" in version))
+def get_translation_dict(model, version):
+    is_multimer = "multimer" in version
     #######################
     # Some templates
     #######################

@@ -540,16 +538,14 @@ def import_jax_weights_(model, npz_path, version="model_1", is_multimer: bool = False):
         },
     }
-    # return translations

     no_templ = [
         "model_3",
         "model_4",
         "model_5",
         "model_3_ptm",
         "model_4_ptm",
         "model_5_ptm",
     ]
     if version in no_templ:
         evo_dict = translations["evoformer"]
         keys = list(evo_dict.keys())

@@ -557,10 +553,19 @@ def import_jax_weights_(model, npz_path, version="model_1", is_multimer: bool = False):
         if "template_" in k:
             evo_dict.pop(k)

-    if "_ptm" in version:
+    if "_ptm" in version or is_multimer:
         translations["predicted_aligned_error_head"] = {
             "logits": LinearParams(model.aux_heads.tm.linear)
         }

+    return translations
+
+
+def import_jax_weights_(model, npz_path, version="model_1"):
+    data = np.load(npz_path)
+    translations = get_translation_dict(model, version)
+
     # Flatten keys and insert missing key prefixes
     flat = _process_translations_dict(translations)

@@ -578,5 +583,3 @@ def import_jax_weights_(model, npz_path, version="model_1", is_multimer: bool = False):
     # Set weights
     assign(flat, data)
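The refactor extracts the large parameter-name translation table into get_translation_dict(model, version), which derives is_multimer from the version string, and turns import_jax_weights_ into a thin wrapper that loads the .npz and applies the table. Multimer versions now also get the predicted-aligned-error head mapped, where previously only "_ptm" versions did. A hedged usage sketch; the model object and the parameter path are placeholders:

    from fastfold.utils.import_weights import import_jax_weights_

    # Hypothetical setup: `model` is an already-constructed FastFold AlphaFold
    # module configured for multimer; the .npz points at the matching released
    # DeepMind parameters (path is a placeholder).
    import_jax_weights_(
        model,
        "data/params/params_model_1_multimer.npz",
        version="model_1_multimer",
    )
    # "multimer" in the version string now triggers the PAE-head weight
    # mapping inside get_translation_dict.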
fastfold/workflow/template/fastfold_data_workflow.py (+2 -1)

@@ -119,7 +119,8 @@ class FastFoldDataWorkFlow:
     def run(self, fasta_path: str, alignment_dir: str = None, storage_dir: str = None) -> None:
-        storage_dir = "file:///tmp/ray/" + os.getlogin() + "/workflow_data"
+        timestamp = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
+        storage_dir = "file:///tmp/ray/" + str(timestamp) + "/workflow_data"
         if storage_dir is not None:
             if not os.path.exists(storage_dir):
                 os.makedirs(storage_dir[7:], exist_ok=True)
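Both workflow templates (this file and the multimer variant below) switch the Ray workflow storage path from a per-user directory based on os.getlogin(), which raises OSError when the process has no controlling terminal (common under batch schedulers), to a per-run timestamp. Two quirks survive the change: the storage_dir argument to run() is still overwritten unconditionally, so the `if storage_dir is not None` guard can never be False, and the [7:] slice strips the "file://" scheme before creating the local directory. A small runnable sketch of the path handling:

    import os
    import time

    timestamp = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
    storage_dir = "file:///tmp/ray/" + timestamp + "/workflow_data"

    # storage_dir[7:] drops the leading "file://" (7 characters), leaving the
    # plain filesystem path that os.makedirs understands.
    os.makedirs(storage_dir[7:], exist_ok=True)
    print(storage_dir, "->", storage_dir[7:])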
fastfold/workflow/template/fastfold_multimer_data_workflow.py (+2 -1)

@@ -137,7 +137,8 @@ class FastFoldMultimerDataWorkFlow:
     def run(self, fasta_path: str, alignment_dir: str = None, storage_dir: str = None) -> None:
-        storage_dir = "file:///tmp/ray/" + os.getlogin() + "/workflow_data"
+        timestamp = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
+        storage_dir = "file:///tmp/ray/" + str(timestamp) + "/workflow_data"
         if storage_dir is not None:
             if not os.path.exists(storage_dir):
                 os.makedirs(storage_dir[7:], exist_ok=True)
inference.py (+41 -27)

@@ -312,21 +312,31 @@ def inference_multimer_model(args):
         with open(unrelaxed_output_path, 'w') as f:
             f.write(protein.to_pdb(unrelaxed_protein))

-        amber_relaxer = relax.AmberRelaxation(
-            use_gpu=True,
-            **config.relax,
-        )
-        # Relax the prediction.
-        t = time.perf_counter()
-        relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
-        print(f"Relaxation time: {time.perf_counter() - t}")
-        # Save the relaxed PDB.
-        relaxed_output_path = os.path.join(args.output_dir, f'{tag}_{args.model_name}_relaxed.pdb')
-        with open(relaxed_output_path, 'w') as f:
-            f.write(relaxed_pdb_str)
+        if (args.relaxation):
+            amber_relaxer = relax.AmberRelaxation(
+                use_gpu=True,
+                **config.relax,
+            )
+            # Relax the prediction.
+            t = time.perf_counter()
+            relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
+            print(f"Relaxation time: {time.perf_counter() - t}")
+            # Save the relaxed PDB.
+            relaxed_output_path = os.path.join(args.output_dir, f'{tag}_{args.model_name}_relaxed.pdb')
+            with open(relaxed_output_path, 'w') as f:
+                f.write(relaxed_pdb_str)
+
+        if (args.save_outputs):
+            output_dict_path = os.path.join(args.output_dir, f'{tag}_{args.model_name}_output_dict.pkl')
+            with open(output_dict_path, "wb") as fp:
+                pickle.dump(out, fp, protocol=pickle.HIGHEST_PROTOCOL)
+            logger.info(f"Model output written to {output_dict_path}...")

 def inference_monomer_model(args):

@@ -454,21 +464,22 @@ def inference_monomer_model(args):
         with open(unrelaxed_output_path, 'w') as f:
             f.write(protein.to_pdb(unrelaxed_protein))

-        amber_relaxer = relax.AmberRelaxation(
-            use_gpu=True,
-            **config.relax,
-        )
-        # Relax the prediction.
-        t = time.perf_counter()
-        relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
-        print(f"Relaxation time: {time.perf_counter() - t}")
-        # Save the relaxed PDB.
-        relaxed_output_path = os.path.join(args.output_dir,
-            f'{tag}_{args.model_name}_relaxed.pdb')
-        with open(relaxed_output_path, 'w') as f:
-            f.write(relaxed_pdb_str)
+        if (args.relaxation):
+            amber_relaxer = relax.AmberRelaxation(
+                use_gpu=True,
+                **config.relax,
+            )
+            # Relax the prediction.
+            t = time.perf_counter()
+            relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
+            print(f"Relaxation time: {time.perf_counter() - t}")
+            # Save the relaxed PDB.
+            relaxed_output_path = os.path.join(args.output_dir,
+                f'{tag}_{args.model_name}_relaxed.pdb')
+            with open(relaxed_output_path, 'w') as f:
+                f.write(relaxed_pdb_str)

         if (args.save_outputs):
             output_dict_path = os.path.join(

@@ -512,6 +523,9 @@ if __name__ == "__main__":
         help="""Path to model parameters. If None, parameters are selected
              automatically according to the model name from
              ./data/params""")
+    parser.add_argument(
+        "--relaxation", action="store_false", default=False,
+    )
     parser.add_argument(
         "--save_outputs", action="store_true", default=False,
         help="Whether to save all model outputs, including embeddings, etc."
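A caveat on the new flag: --relaxation is declared with action="store_false" together with default=False, so it is False by default and passing --relaxation stores False again. As committed, Amber relaxation can never be enabled from the command line, and both inference paths will silently skip it. Presumably action="store_true" was intended, as used by --save_outputs just below it; a corrected sketch (not part of the commit):

    import argparse

    parser = argparse.ArgumentParser()
    # store_true makes the flag opt-in: absent -> False, present -> True.
    parser.add_argument(
        "--relaxation", action="store_true", default=False,
        help="Run Amber relaxation on the predicted structure."
    )

    assert parser.parse_args([]).relaxation is False
    assert parser.parse_args(["--relaxation"]).relaxation is True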