Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
cfd0fc6e
Unverified
Commit
cfd0fc6e
authored
Feb 03, 2022
by
Gustaf Ahdritz
Committed by
GitHub
Feb 03, 2022
Browse files
Merge pull request #76 from aqlaboratory/chunking_experiment_rebased
parents
c9e0f894
2726892a
Changes
27
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
112 additions
and
45 deletions
+112
-45
scripts/utils.py
scripts/utils.py
+5
-5
tests/test_evoformer.py
tests/test_evoformer.py
+9
-6
tests/test_model.py
tests/test_model.py
+1
-0
tests/test_msa.py
tests/test_msa.py
+7
-10
tests/test_triangular_attention.py
tests/test_triangular_attention.py
+3
-3
tests/test_triangular_multiplicative_update.py
tests/test_triangular_multiplicative_update.py
+2
-2
train_openfold.py
train_openfold.py
+85
-19
No files found.
scripts/utils.py
View file @
cfd0fc6e
...
@@ -4,19 +4,19 @@ from datetime import date
...
@@ -4,19 +4,19 @@ from datetime import date
def
add_data_args
(
parser
:
argparse
.
ArgumentParser
):
def
add_data_args
(
parser
:
argparse
.
ArgumentParser
):
parser
.
add_argument
(
parser
.
add_argument
(
'uniref90_database_path'
,
type
=
str
,
'
--
uniref90_database_path'
,
type
=
str
,
default
=
None
,
)
)
parser
.
add_argument
(
parser
.
add_argument
(
'mgnify_database_path'
,
type
=
str
,
'
--
mgnify_database_path'
,
type
=
str
,
default
=
None
,
)
)
parser
.
add_argument
(
parser
.
add_argument
(
'pdb70_database_path'
,
type
=
str
,
'
--
pdb70_database_path'
,
type
=
str
,
default
=
None
,
)
)
parser
.
add_argument
(
parser
.
add_argument
(
'template_mmcif_dir'
,
type
=
str
,
'
--
template_mmcif_dir'
,
type
=
str
,
default
=
None
,
)
)
parser
.
add_argument
(
parser
.
add_argument
(
'uniclust30_database_path'
,
type
=
str
,
'
--
uniclust30_database_path'
,
type
=
str
,
default
=
None
,
)
)
parser
.
add_argument
(
parser
.
add_argument
(
'--bfd_database_path'
,
type
=
str
,
default
=
None
,
'--bfd_database_path'
,
type
=
str
,
default
=
None
,
...
...
tests/test_evoformer.py
View file @
cfd0fc6e
...
@@ -135,8 +135,8 @@ class TestEvoformerStack(unittest.TestCase):
...
@@ -135,8 +135,8 @@ class TestEvoformerStack(unittest.TestCase):
out_repro_msa
=
out_repro_msa
.
cpu
()
out_repro_msa
=
out_repro_msa
.
cpu
()
out_repro_pair
=
out_repro_pair
.
cpu
()
out_repro_pair
=
out_repro_pair
.
cpu
()
assert
torch
.
max
(
torch
.
abs
(
out_repro_msa
-
out_gt_msa
)
<
consts
.
eps
)
assert
(
torch
.
max
(
torch
.
abs
(
out_repro_msa
-
out_gt_msa
)
)
<
consts
.
eps
)
assert
torch
.
max
(
torch
.
abs
(
out_repro_pair
-
out_gt_pair
)
<
consts
.
eps
)
assert
(
torch
.
max
(
torch
.
abs
(
out_repro_pair
-
out_gt_pair
)
)
<
consts
.
eps
)
class
TestExtraMSAStack
(
unittest
.
TestCase
):
class
TestExtraMSAStack
(
unittest
.
TestCase
):
...
@@ -172,7 +172,7 @@ class TestExtraMSAStack(unittest.TestCase):
...
@@ -172,7 +172,7 @@ class TestExtraMSAStack(unittest.TestCase):
transition_n
,
transition_n
,
msa_dropout
,
msa_dropout
,
pair_stack_dropout
,
pair_stack_dropout
,
blocks_per_
ckpt
=
Non
e
,
ckpt
=
Fals
e
,
inf
=
inf
,
inf
=
inf
,
eps
=
eps
,
eps
=
eps
,
).
eval
()
).
eval
()
...
@@ -257,16 +257,19 @@ class TestMSATransition(unittest.TestCase):
...
@@ -257,16 +257,19 @@ class TestMSATransition(unittest.TestCase):
out_gt
=
torch
.
as_tensor
(
np
.
array
(
out_gt
))
out_gt
=
torch
.
as_tensor
(
np
.
array
(
out_gt
))
model
=
compare_utils
.
get_global_pretrained_openfold
()
model
=
compare_utils
.
get_global_pretrained_openfold
()
out_repro
=
(
out_repro
=
(
model
.
evoformer
.
blocks
[
0
]
model
.
evoformer
.
blocks
[
0
].
core
.
msa_transition
(
.
msa_transition
(
torch
.
as_tensor
(
msa_act
,
dtype
=
torch
.
float32
).
cuda
(),
torch
.
as_tensor
(
msa_act
,
dtype
=
torch
.
float32
).
cuda
(),
mask
=
torch
.
as_tensor
(
msa_mask
,
dtype
=
torch
.
float32
).
cuda
(),
mask
=
torch
.
as_tensor
(
msa_mask
,
dtype
=
torch
.
float32
).
cuda
(),
)
)
.
cpu
()
.
cpu
()
)
)
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
)
<
consts
.
eps
))
print
(
out_gt
)
print
(
out_repro
)
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
<
consts
.
eps
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
tests/test_model.py
View file @
cfd0fc6e
...
@@ -130,5 +130,6 @@ class TestModel(unittest.TestCase):
...
@@ -130,5 +130,6 @@ class TestModel(unittest.TestCase):
out_repro
=
out_repro
[
"sm"
][
"positions"
][
-
1
]
out_repro
=
out_repro
[
"sm"
][
"positions"
][
-
1
]
out_repro
=
out_repro
.
squeeze
(
0
)
out_repro
=
out_repro
.
squeeze
(
0
)
print
(
torch
.
mean
(
torch
.
abs
(
out_gt
-
out_repro
)))
print
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
)))
print
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
)))
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
<
1e-3
)
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
<
1e-3
)
tests/test_msa.py
View file @
cfd0fc6e
...
@@ -88,15 +88,13 @@ class TestMSARowAttentionWithPairBias(unittest.TestCase):
...
@@ -88,15 +88,13 @@ class TestMSARowAttentionWithPairBias(unittest.TestCase):
model
=
compare_utils
.
get_global_pretrained_openfold
()
model
=
compare_utils
.
get_global_pretrained_openfold
()
out_repro
=
(
out_repro
=
(
model
.
evoformer
.
blocks
[
0
]
model
.
evoformer
.
blocks
[
0
].
msa_att_row
(
.
msa_att_row
(
torch
.
as_tensor
(
msa_act
).
cuda
(),
torch
.
as_tensor
(
msa_act
).
cuda
(),
z
=
torch
.
as_tensor
(
pair_act
).
cuda
(),
z
=
torch
.
as_tensor
(
pair_act
).
cuda
(),
chunk_size
=
4
,
chunk_size
=
4
,
mask
=
torch
.
as_tensor
(
msa_mask
).
cuda
(),
mask
=
torch
.
as_tensor
(
msa_mask
).
cuda
(),
)
)
.
cpu
()
).
cpu
()
)
self
.
assertTrue
(
torch
.
all
(
torch
.
abs
(
out_gt
-
out_repro
)
<
consts
.
eps
))
self
.
assertTrue
(
torch
.
all
(
torch
.
abs
(
out_gt
-
out_repro
)
<
consts
.
eps
))
...
@@ -153,14 +151,14 @@ class TestMSAColumnAttention(unittest.TestCase):
...
@@ -153,14 +151,14 @@ class TestMSAColumnAttention(unittest.TestCase):
model
=
compare_utils
.
get_global_pretrained_openfold
()
model
=
compare_utils
.
get_global_pretrained_openfold
()
out_repro
=
(
out_repro
=
(
model
.
evoformer
.
blocks
[
0
]
model
.
evoformer
.
blocks
[
0
].
msa_att_col
(
.
msa_att_col
(
torch
.
as_tensor
(
msa_act
).
cuda
(),
torch
.
as_tensor
(
msa_act
).
cuda
(),
chunk_size
=
4
,
chunk_size
=
4
,
mask
=
torch
.
as_tensor
(
msa_mask
).
cuda
(),
mask
=
torch
.
as_tensor
(
msa_mask
).
cuda
(),
)
)
.
cpu
()
).
cpu
()
)
print
(
torch
.
mean
(
torch
.
abs
(
out_gt
-
out_repro
)))
self
.
assertTrue
(
torch
.
all
(
torch
.
abs
(
out_gt
-
out_repro
)
<
consts
.
eps
))
self
.
assertTrue
(
torch
.
all
(
torch
.
abs
(
out_gt
-
out_repro
)
<
consts
.
eps
))
...
@@ -218,8 +216,7 @@ class TestMSAColumnGlobalAttention(unittest.TestCase):
...
@@ -218,8 +216,7 @@ class TestMSAColumnGlobalAttention(unittest.TestCase):
model
=
compare_utils
.
get_global_pretrained_openfold
()
model
=
compare_utils
.
get_global_pretrained_openfold
()
out_repro
=
(
out_repro
=
(
model
.
extra_msa_stack
.
stack
.
blocks
[
0
]
model
.
extra_msa_stack
.
blocks
[
0
].
msa_att_col
(
.
msa_att_col
(
torch
.
as_tensor
(
msa_act
,
dtype
=
torch
.
float32
).
cuda
(),
torch
.
as_tensor
(
msa_act
,
dtype
=
torch
.
float32
).
cuda
(),
chunk_size
=
4
,
chunk_size
=
4
,
mask
=
torch
.
as_tensor
(
msa_mask
,
dtype
=
torch
.
float32
).
cuda
(),
mask
=
torch
.
as_tensor
(
msa_mask
,
dtype
=
torch
.
float32
).
cuda
(),
...
...
tests/test_triangular_attention.py
View file @
cfd0fc6e
...
@@ -85,9 +85,9 @@ class TestTriangularAttention(unittest.TestCase):
...
@@ -85,9 +85,9 @@ class TestTriangularAttention(unittest.TestCase):
model
=
compare_utils
.
get_global_pretrained_openfold
()
model
=
compare_utils
.
get_global_pretrained_openfold
()
module
=
(
module
=
(
model
.
evoformer
.
blocks
[
0
].
tri_att_start
model
.
evoformer
.
blocks
[
0
].
core
.
tri_att_start
if
starting
if
starting
else
model
.
evoformer
.
blocks
[
0
].
tri_att_end
else
model
.
evoformer
.
blocks
[
0
].
core
.
tri_att_end
)
)
out_repro
=
module
(
out_repro
=
module
(
torch
.
as_tensor
(
pair_act
,
dtype
=
torch
.
float32
).
cuda
(),
torch
.
as_tensor
(
pair_act
,
dtype
=
torch
.
float32
).
cuda
(),
...
@@ -95,7 +95,7 @@ class TestTriangularAttention(unittest.TestCase):
...
@@ -95,7 +95,7 @@ class TestTriangularAttention(unittest.TestCase):
chunk_size
=
None
,
chunk_size
=
None
,
).
cpu
()
).
cpu
()
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
)
<
consts
.
eps
)
)
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
)
)
<
consts
.
eps
)
@
compare_utils
.
skip_unless_alphafold_installed
()
@
compare_utils
.
skip_unless_alphafold_installed
()
def
test_tri_att_end_compare
(
self
):
def
test_tri_att_end_compare
(
self
):
...
...
tests/test_triangular_multiplicative_update.py
View file @
cfd0fc6e
...
@@ -87,9 +87,9 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
...
@@ -87,9 +87,9 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
model
=
compare_utils
.
get_global_pretrained_openfold
()
model
=
compare_utils
.
get_global_pretrained_openfold
()
module
=
(
module
=
(
model
.
evoformer
.
blocks
[
0
].
tri_mul_in
model
.
evoformer
.
blocks
[
0
].
core
.
tri_mul_in
if
incoming
if
incoming
else
model
.
evoformer
.
blocks
[
0
].
tri_mul_out
else
model
.
evoformer
.
blocks
[
0
].
core
.
tri_mul_out
)
)
out_repro
=
module
(
out_repro
=
module
(
torch
.
as_tensor
(
pair_act
,
dtype
=
torch
.
float32
).
cuda
(),
torch
.
as_tensor
(
pair_act
,
dtype
=
torch
.
float32
).
cuda
(),
...
...
train_openfold.py
View file @
cfd0fc6e
...
@@ -13,6 +13,7 @@ import time
...
@@ -13,6 +13,7 @@ import time
import
numpy
as
np
import
numpy
as
np
import
pytorch_lightning
as
pl
import
pytorch_lightning
as
pl
from
pytorch_lightning.callbacks.model_checkpoint
import
ModelCheckpoint
from
pytorch_lightning.callbacks.model_checkpoint
import
ModelCheckpoint
from
pytorch_lightning.loggers
import
WandbLogger
from
pytorch_lightning.plugins.training_type
import
DeepSpeedPlugin
,
DDPPlugin
from
pytorch_lightning.plugins.training_type
import
DeepSpeedPlugin
,
DDPPlugin
from
pytorch_lightning.plugins.environments
import
SLURMEnvironment
from
pytorch_lightning.plugins.environments
import
SLURMEnvironment
import
torch
import
torch
...
@@ -29,7 +30,7 @@ from openfold.utils.callbacks import (
...
@@ -29,7 +30,7 @@ from openfold.utils.callbacks import (
)
)
from
openfold.utils.exponential_moving_average
import
ExponentialMovingAverage
from
openfold.utils.exponential_moving_average
import
ExponentialMovingAverage
from
openfold.utils.argparse
import
remove_arguments
from
openfold.utils.argparse
import
remove_arguments
from
openfold.utils.loss
import
AlphaFoldLoss
from
openfold.utils.loss
import
AlphaFoldLoss
,
lddt_ca
from
openfold.utils.seed
import
seed_everything
from
openfold.utils.seed
import
seed_everything
from
openfold.utils.tensor_utils
import
tensor_tree_map
from
openfold.utils.tensor_utils
import
tensor_tree_map
from
scripts.zero_to_fp32
import
(
from
scripts.zero_to_fp32
import
(
...
@@ -66,21 +67,28 @@ class OpenFoldWrapper(pl.LightningModule):
...
@@ -66,21 +67,28 @@ class OpenFoldWrapper(pl.LightningModule):
# Compute loss
# Compute loss
loss
=
self
.
loss
(
outputs
,
batch
)
loss
=
self
.
loss
(
outputs
,
batch
)
self
.
log
(
"loss"
,
loss
)
return
{
"loss"
:
loss
}
self
.
log
(
"train/loss"
,
loss
,
on_step
=
True
,
logger
=
True
)
return
loss
def
validation_step
(
self
,
batch
,
batch_idx
):
def
validation_step
(
self
,
batch
,
batch_idx
):
# At the start of validation, load the EMA weights
# At the start of validation, load the EMA weights
if
(
self
.
cached_weights
is
None
):
if
(
self
.
cached_weights
is
None
):
self
.
cached_weights
=
self
.
model
.
state_dict
()
self
.
cached_weights
=
self
.
model
.
state_dict
()
self
.
model
.
load_state_dict
(
self
.
ema
.
state_dict
()[
"params"
])
self
.
model
.
load_state_dict
(
self
.
ema
.
state_dict
()[
"params"
])
# Calculate validation loss
# Calculate validation loss
outputs
=
self
(
batch
)
outputs
=
self
(
batch
)
batch
=
tensor_tree_map
(
lambda
t
:
t
[...,
-
1
],
batch
)
batch
=
tensor_tree_map
(
lambda
t
:
t
[...,
-
1
],
batch
)
loss
=
self
.
loss
(
outputs
,
batch
)
loss
=
lddt_ca
(
self
.
log
(
"val_loss"
,
loss
,
prog_bar
=
True
)
outputs
[
"final_atom_positions"
],
return
{
"val_loss"
:
loss
}
batch
[
"all_atom_positions"
],
batch
[
"all_atom_mask"
],
eps
=
self
.
config
.
globals
.
eps
,
per_residue
=
False
,
)
self
.
log
(
"val/loss"
,
loss
,
logger
=
True
)
def
validation_epoch_end
(
self
,
_
):
def
validation_epoch_end
(
self
,
_
):
# Restore the model weights to normal
# Restore the model weights to normal
...
@@ -101,6 +109,9 @@ class OpenFoldWrapper(pl.LightningModule):
...
@@ -101,6 +109,9 @@ class OpenFoldWrapper(pl.LightningModule):
def
on_before_zero_grad
(
self
,
*
args
,
**
kwargs
):
def
on_before_zero_grad
(
self
,
*
args
,
**
kwargs
):
self
.
ema
.
update
(
self
.
model
)
self
.
ema
.
update
(
self
.
model
)
def
on_load_checkpoint
(
self
,
checkpoint
):
self
.
ema
.
load_state_dict
(
checkpoint
[
"ema"
])
def
on_save_checkpoint
(
self
,
checkpoint
):
def
on_save_checkpoint
(
self
,
checkpoint
):
checkpoint
[
"ema"
]
=
self
.
ema
.
state_dict
()
checkpoint
[
"ema"
]
=
self
.
ema
.
state_dict
()
...
@@ -140,15 +151,15 @@ def main(args):
...
@@ -140,15 +151,15 @@ def main(args):
if
(
args
.
checkpoint_best_val
):
if
(
args
.
checkpoint_best_val
):
checkpoint_dir
=
os
.
path
.
join
(
args
.
output_dir
,
"checkpoints"
)
checkpoint_dir
=
os
.
path
.
join
(
args
.
output_dir
,
"checkpoints"
)
mc
=
ModelCheckpoint
(
mc
=
ModelCheckpoint
(
dirpath
=
checkpoint_dir
,
filename
=
"openfold_{epoch}_{step}_{val_loss:.2f}"
,
filename
=
"openfold_{epoch}_{step}_{val_loss:.2f}"
,
monitor
=
"val_loss"
,
monitor
=
"val/loss"
,
mode
=
"max"
,
)
)
callbacks
.
append
(
mc
)
callbacks
.
append
(
mc
)
if
(
args
.
early_stopping
):
if
(
args
.
early_stopping
):
es
=
EarlyStoppingVerbose
(
es
=
EarlyStoppingVerbose
(
monitor
=
"val
_
loss"
,
monitor
=
"val
/
loss"
,
min_delta
=
args
.
min_delta
,
min_delta
=
args
.
min_delta
,
patience
=
args
.
patience
,
patience
=
args
.
patience
,
verbose
=
False
,
verbose
=
False
,
...
@@ -157,7 +168,7 @@ def main(args):
...
@@ -157,7 +168,7 @@ def main(args):
strict
=
True
,
strict
=
True
,
)
)
callbacks
.
append
(
es
)
callbacks
.
append
(
es
)
if
(
args
.
log_performance
):
if
(
args
.
log_performance
):
global_batch_size
=
args
.
num_nodes
*
args
.
gpus
global_batch_size
=
args
.
num_nodes
*
args
.
gpus
perf
=
PerformanceLoggingCallback
(
perf
=
PerformanceLoggingCallback
(
...
@@ -166,24 +177,41 @@ def main(args):
...
@@ -166,24 +177,41 @@ def main(args):
)
)
callbacks
.
append
(
perf
)
callbacks
.
append
(
perf
)
loggers
=
[]
if
(
args
.
wandb
):
wdb_logger
=
WandbLogger
(
name
=
args
.
experiment_name
,
save_dir
=
args
.
output_dir
,
id
=
args
.
wandb_id
,
project
=
args
.
wandb_project
,
**
{
"entity"
:
args
.
wandb_entity
}
)
loggers
.
append
(
wdb_logger
)
if
(
args
.
deepspeed_config_path
is
not
None
):
if
(
args
.
deepspeed_config_path
is
not
None
):
if
"SLURM_JOB_ID"
in
os
.
environ
:
#if "SLURM_JOB_ID" in os.environ:
cluster_environment
=
SLURMEnvironment
()
# cluster_environment = SLURMEnvironment()
else
:
#else:
cluster_environment
=
None
# cluster_environment = None
strategy
=
DeepSpeedPlugin
(
strategy
=
DeepSpeedPlugin
(
config
=
args
.
deepspeed_config_path
,
config
=
args
.
deepspeed_config_path
,
cluster_environment
=
cluster_environment
,
#
cluster_environment=cluster_environment,
)
)
elif
(
args
.
gpus
is
not
None
and
args
.
gpus
)
>
1
or
args
.
num_nodes
>
1
:
if
(
args
.
wandb
):
wdb_logger
.
experiment
.
save
(
args
.
deepspeed_config_path
)
wdb_logger
.
experiment
.
save
(
"openfold/config.py"
)
elif
(
args
.
gpus
is
not
None
and
args
.
gpus
>
1
)
or
args
.
num_nodes
>
1
:
strategy
=
DDPPlugin
(
find_unused_parameters
=
False
)
strategy
=
DDPPlugin
(
find_unused_parameters
=
False
)
else
:
else
:
strategy
=
None
strategy
=
None
trainer
=
pl
.
Trainer
.
from_argparse_args
(
trainer
=
pl
.
Trainer
.
from_argparse_args
(
args
,
args
,
default_root_dir
=
args
.
output_dir
,
strategy
=
strategy
,
strategy
=
strategy
,
callbacks
=
callbacks
,
callbacks
=
callbacks
,
logger
=
loggers
,
)
)
if
(
args
.
resume_model_weights_only
):
if
(
args
.
resume_model_weights_only
):
...
@@ -198,7 +226,7 @@ def main(args):
...
@@ -198,7 +226,7 @@ def main(args):
)
)
trainer
.
save_checkpoint
(
trainer
.
save_checkpoint
(
os
.
path
.
join
(
trainer
.
logger
.
log
_dir
,
"checkpoints"
,
"final.ckpt"
)
os
.
path
.
join
(
args
.
output
_dir
,
"checkpoints"
,
"final.ckpt"
)
)
)
...
@@ -318,10 +346,37 @@ if __name__ == "__main__":
...
@@ -318,10 +346,37 @@ if __name__ == "__main__":
"--log_performance"
,
type
=
bool_type
,
default
=
False
,
"--log_performance"
,
type
=
bool_type
,
default
=
False
,
help
=
"Measure performance"
help
=
"Measure performance"
)
)
parser
.
add_argument
(
"--wandb"
,
action
=
"store_true"
,
default
=
False
,
)
parser
.
add_argument
(
"--experiment_name"
,
type
=
str
,
default
=
None
,
)
parser
.
add_argument
(
"--wandb_id"
,
type
=
str
,
default
=
None
,
)
parser
.
add_argument
(
"--wandb_project"
,
type
=
str
,
default
=
None
,
)
parser
.
add_argument
(
"--wandb_entity"
,
type
=
str
,
default
=
None
,
)
parser
.
add_argument
(
parser
.
add_argument
(
"--script_modules"
,
type
=
bool_type
,
default
=
False
,
"--script_modules"
,
type
=
bool_type
,
default
=
False
,
help
=
"Whether to TorchScript eligible components of them model"
help
=
"Whether to TorchScript eligible components of them model"
)
)
parser
.
add_argument
(
"--train_prot_data_cache_path"
,
type
=
str
,
default
=
None
,
)
parser
.
add_argument
(
"--distillation_prot_data_cache_path"
,
type
=
str
,
default
=
None
,
)
parser
.
add_argument
(
"--train_epoch_len"
,
type
=
int
,
default
=
10000
,
)
parser
.
add_argument
(
"--_alignment_index_path"
,
type
=
str
,
default
=
None
,
)
parser
=
pl
.
Trainer
.
add_argparse_args
(
parser
)
parser
=
pl
.
Trainer
.
add_argparse_args
(
parser
)
# Disable the initial validation pass
# Disable the initial validation pass
...
@@ -330,7 +385,15 @@ if __name__ == "__main__":
...
@@ -330,7 +385,15 @@ if __name__ == "__main__":
)
)
# Remove some buggy/redundant arguments introduced by the Trainer
# Remove some buggy/redundant arguments introduced by the Trainer
remove_arguments
(
parser
,
[
"--accelerator"
,
"--resume_from_checkpoint"
])
remove_arguments
(
parser
,
[
"--accelerator"
,
"--resume_from_checkpoint"
,
"--reload_dataloaders_every_epoch"
,
"--reload_dataloaders_every_n_epochs"
,
]
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
...
@@ -339,4 +402,7 @@ if __name__ == "__main__":
...
@@ -339,4 +402,7 @@ if __name__ == "__main__":
(
args
.
num_nodes
is
not
None
and
args
.
num_nodes
>
1
))):
(
args
.
num_nodes
is
not
None
and
args
.
num_nodes
>
1
))):
raise
ValueError
(
"For distributed training, --seed must be specified"
)
raise
ValueError
(
"For distributed training, --seed must be specified"
)
# This re-applies the training-time filters at the beginning of every epoch
args
.
reload_dataloaders_every_n_epochs
=
1
main
(
args
)
main
(
args
)
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment