"tests/fixtures/custom_pipeline/what_ever.py" did not exist on "46dae846dfd083a1c29c4c88e813a470c045c846"
Commit b14e47f4 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'main' of https://github.com/hpcaitech/FastFold

parents 490cb6f5 05681304
Pipeline #234 failed with stages
in 0 seconds
import torch
import pytest
import os
import copy
import torch.multiprocessing as mp
from functools import partial
import fastfold
from fastfold.config import model_config
from fastfold.model.fastnn.ops import set_chunk_size
from fastfold.model.hub import AlphaFold
from fastfold.utils.inject_fastnn import inject_fastnn
from fastfold.utils.import_weights import import_jax_weights_
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from fastfold.distributed.comm import gather, scatter, row_to_col
from fastfold.utils.test_utils import get_param_path
@pytest.fixture(scope="module")
def get_openfold_module_and_data():
with torch.no_grad():
config = model_config('model_1')
config.globals.inplace = False
target_module = AlphaFold(config)
import_jax_weights_(target_module, get_param_path())
fast_module = copy.deepcopy(target_module)
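# swap the OpenFold submodules for FastFold's fastnn implementations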
fast_module = inject_fastnn(fast_module)
fast_module = fast_module.extra_msa_stack.blocks[0].msa_stack.MSAColumnAttention.eval().cuda()
target_module = target_module.extra_msa_stack.blocks[0].msa_att_col.eval().cuda()
msa_len = 512
seq_len = 128
m = torch.randn((msa_len, seq_len, 64)).cuda()
m_mask = torch.ones((msa_len, seq_len)).cuda().to(dtype=m.dtype)
m_mask[128:, :] = 0
m_out = m + target_module(m, mask=m_mask, chunk_size=None)
return m_out, m, m_mask, fast_module
@pytest.mark.parametrize('world_size', [1, 2])
@pytest.mark.parametrize('chunk_size', [None, 32])
def test_state_dict(world_size, chunk_size, get_openfold_module_and_data):
run_func = partial(_test_msa_global_att_col, world_size=world_size, chunk_size=chunk_size, get_openfold_module_and_data=get_openfold_module_and_data)
mp.spawn(run_func, nprocs=world_size)
def _test_msa_global_att_col(rank, world_size, chunk_size, get_openfold_module_and_data):
os.environ['RANK'] = str(rank)
os.environ['LOCAL_RANK'] = str(rank)
os.environ['WORLD_SIZE'] = str(world_size)
# init distributed for Dynamic Axial Parallelism
fastfold.distributed.init_dap()
m_out, m, m_mask, fast_module = get_openfold_module_and_data
fast_module = copy.deepcopy(fast_module).cuda()
fast_m = copy.deepcopy(m.cuda()).unsqueeze(0)
dap_size = gpc.get_world_size(ParallelMode.TENSOR)
seq_length = m_mask.cuda().size(-1)
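# pad the sequence dimension up to a multiple of dap_size so the tensors can be scattered evenly across DAP ranks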
padding_size = (int(seq_length / dap_size) + 1) * dap_size - seq_length
fast_m = torch.nn.functional.pad(fast_m, (0, 0, 0, padding_size))
fast_m = scatter(fast_m, dim=1)
fast_m_mask = copy.deepcopy(m_mask.cuda()).unsqueeze(0)
fast_m_mask = torch.nn.functional.pad(fast_m_mask, (0, padding_size))
with torch.no_grad():
set_chunk_size(chunk_size)
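# re-shard the MSA from row- to column-partitioning before the column-attention module; the mask is scattered along the column dimension below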
fast_m = row_to_col(fast_m)
fast_m_mask = scatter(fast_m_mask, dim=2)
m_fast = fast_module(fast_m, fast_m_mask)
m_fast = m_fast.squeeze(0)
m_fast = gather(m_fast, dim=1)
m_fast = m_fast[:, :-padding_size, :]
error = torch.max(torch.abs(m_out.cuda() - m_fast))
assert error < 5e-5, f"Test m failed at chunk size: {chunk_size}. The max difference is {error}"
import torch
import pytest
import copy
import fastfold
import os
import torch.multiprocessing as mp
from functools import partial
from fastfold.model.fastnn.ops import set_chunk_size
from fastfold.utils.test_utils import get_param_path
from fastfold.model.hub import AlphaFold
from fastfold.utils.inject_fastnn import inject_fastnn
from fastfold.utils.import_weights import import_jax_weights_
from fastfold.config import model_config
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from fastfold.distributed.comm import gather, scatter, row_to_col
@pytest.fixture(scope="module")
def get_openfold_module_and_data():
with torch.no_grad():
config = model_config('model_1')
config.globals.inplace = False
target_module = AlphaFold(config)
import_jax_weights_(target_module, get_param_path())
fast_module = copy.deepcopy(target_module)
fast_module = inject_fastnn(fast_module)
fast_module = fast_module.evoformer.blocks[0].communication.eval().cuda()
target_module = target_module.evoformer.blocks[0].core.outer_product_mean.eval().cuda()
msa_len = 20
seq_len = 30
m = torch.randn((msa_len, seq_len, 256)).cuda()
m_mask = torch.ones((msa_len, seq_len)).cuda()
m_mask[:, -5:] = 0
z = torch.zeros((seq_len, seq_len, 128)).cuda()
out = target_module(m, m_mask)
return m, m_mask, z, fast_module, out
@pytest.mark.parametrize('world_size', [1, 2])
@pytest.mark.parametrize('chunk_size', [None, 32])
@pytest.mark.parametrize('inplace', [False, True])
def test_state_dict(world_size, chunk_size, inplace, get_openfold_module_and_data):
run_func = partial(_test_out_product_mean, world_size=world_size, chunk_size=chunk_size,
inplace=inplace, get_openfold_module_and_data=get_openfold_module_and_data)
mp.spawn(run_func, nprocs=world_size)
def _test_out_product_mean(rank, world_size, chunk_size, inplace, get_openfold_module_and_data):
os.environ['RANK'] = str(rank)
os.environ['LOCAL_RANK'] = str(rank)
os.environ['WORLD_SIZE'] = str(world_size)
# init distributed for Dynamic Axial Parallelism
fastfold.distributed.init_dap()
m, m_mask, z, fast_module, out = get_openfold_module_and_data
fast_module = copy.deepcopy(fast_module).cuda()
fast_m = copy.deepcopy(m.cuda()).unsqueeze(0)
fast_z = copy.deepcopy(z.cuda()).unsqueeze(0)
dap_size = gpc.get_world_size(ParallelMode.TENSOR)
seq_length = m_mask.cuda().size(-1)
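# as above: pad the sequence dimensions to a multiple of dap_size for an even scatter across DAP ranks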
padding_size = (int(seq_length / dap_size) + 1) * dap_size - seq_length
fast_m = torch.nn.functional.pad(fast_m, (0, 0, 0, padding_size))
fast_z = torch.nn.functional.pad(fast_z, (0, 0, 0, padding_size, 0, padding_size))
fast_m = scatter(fast_m, dim=1)
fast_z = scatter(fast_z, dim=1)
fast_m_mask = copy.deepcopy(m_mask.cuda()).unsqueeze(0)
fast_m_mask = torch.nn.functional.pad(fast_m_mask, (0, padding_size))
with torch.no_grad():
set_chunk_size(chunk_size)
fast_m = row_to_col(fast_m)
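# the fastnn module exposes an inplace variant that updates z (passed inside a list) in place and returns it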
if inplace:
out_fast = fast_module.inplace(fast_m, fast_m_mask, [fast_z])[0]
else:
out_fast = fast_module(fast_m, fast_m_mask, fast_z)
out_fast = out_fast.squeeze(0)
out_fast = gather(out_fast, dim=0)
out_fast = out_fast[:-padding_size, :-padding_size, :]
error = torch.mean(torch.abs(out.cuda() - out_fast))
assert error < 1e-5, f"Test failed at chunk size: {chunk_size}, inplace: {inplace}. The mean difference is {error}"
import torch
from fastfold.model.fastnn.kernel import fused_softmax
from fastfold.model.fastnn.kernel import softmax
def _test_softmax_core():
batch_, chunk_, head_ = 1, 8, 4
test_seq_ = [31, 32, 128, 129, 256, 259, 512, 700, 1024]
test_dtype = [torch.float32, torch.float16, torch.bfloat16]
test_device = torch.device("cuda")
tolerance_eps = {torch.float32: 1e-6, torch.float16: 2e-4, torch.bfloat16: 1e-3}
for seq_ in test_seq_:
for dtype in test_dtype:
sample_input = torch.rand(batch_, chunk_, head_, seq_,
seq_).to(device=test_device, dtype=dtype).requires_grad_(True)
sample_mask = torch.cuda.FloatTensor(batch_, chunk_, seq_).uniform_() > 0
sample_mask = sample_mask.to(device=test_device, dtype=dtype).requires_grad_(False)
sample_bias = torch.rand(batch_, 1, head_, seq_,
seq_).to(device=test_device, dtype=dtype).requires_grad_(True)
sample_input_fastnn = torch.clone(sample_input.detach()).requires_grad_(True)
sample_mask_fastnn = torch.clone(sample_mask.detach()).requires_grad_(False)
sample_bias_fastnn = torch.clone(sample_bias.detach()).requires_grad_(True)
# Forward
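# convert the {0, 1} mask into an additive bias: 0 where kept, -1e9 where masked, broadcast over heads and query positions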
sample_mask_torch = 1e9 * (sample_mask - 1)[:, :, None, None, :]
torch_out = torch.nn.functional.softmax(sample_input + sample_mask_torch + sample_bias,
dim=-1)
fastnn_out = fused_softmax(sample_input_fastnn, sample_mask_fastnn, sample_bias_fastnn)
fwd_fastnn_error = torch.max(torch.abs(torch_out - fastnn_out)).cpu().item()
assert fwd_fastnn_error < tolerance_eps[
dtype], f"fastnn fwd kernel error when {seq_} {dtype}"
# Backward
out_grad = torch.rand_like(torch_out).requires_grad_(False)
torch_out.backward(out_grad)
fastnn_out.backward(out_grad)
grad_input_error = torch.max(torch.abs(sample_input.grad -
sample_input_fastnn.grad)).cpu().item()
assert grad_input_error < tolerance_eps[
dtype], f"fastnn bwd kernel error when {seq_} {dtype}"
grad_bias_error = torch.max(torch.abs(sample_bias.grad -
sample_bias_fastnn.grad)).cpu().item()
assert grad_bias_error < tolerance_eps[
dtype], f"fastnn bwd kernel error when {seq_} {dtype}"
def test_softmax():
_test_softmax_core()
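# rerun the same checks with the Triton path disabled to cover the fallback kernel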
if softmax._triton_available:
softmax._triton_available = False
_test_softmax_core()
if __name__ == "__main__":
test_softmax()
import torch
import pytest
import pickle
import os
import copy
import torch.multiprocessing as mp
from functools import partial
import fastfold
from fastfold.config import model_config
from fastfold.model.fastnn.ops import set_chunk_size
from fastfold.model.hub import AlphaFold
from fastfold.utils.inject_fastnn import inject_fastnn
from fastfold.utils.import_weights import import_jax_weights_
from fastfold.utils.tensor_utils import tensor_tree_map
from fastfold.utils.test_utils import get_param_path, get_data_path
@pytest.fixture(scope="module")
def get_openfold_module_and_data():
with torch.no_grad():
config = model_config('model_1')
config.globals.inplace = False
target_module = AlphaFold(config)
import_jax_weights_(target_module, get_param_path())
fast_module = copy.deepcopy(target_module)
fast_module = inject_fastnn(fast_module)
fast_module = fast_module.template_embedder
fast_module = fast_module.eval().cuda()
target_module = target_module.template_embedder
target_module = target_module.eval().cuda()
batch = pickle.load(open(get_data_path(), 'rb'))
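# the pickled features carry a trailing recycling dimension; take the first recycling iteration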
fetch_cur_batch = lambda t: t[..., 0]
feats = tensor_tree_map(fetch_cur_batch, batch)
feats = {k: v.cuda() for k, v in feats.items() if k.startswith("template_")}
seq_len = 33
z = torch.randn((seq_len, seq_len, 128)).cuda()
z_mask = torch.ones((seq_len, seq_len)).cuda().to(dtype=z.dtype)
template_embeds = target_module(copy.deepcopy(feats), z, z_mask.to(dtype=z.dtype), 0, None)
z_out = z + template_embeds["template_pair_embedding"]
return fast_module, z_out, feats, z, z_mask
@pytest.mark.parametrize('world_size', [1, 2])
@pytest.mark.parametrize('chunk_size', [None, 4]) # chunk_size is set to 4 to exercise offload
@pytest.mark.parametrize('inplace', [False, True])
def test_state_dict(world_size, chunk_size, inplace, get_openfold_module_and_data):
run_func = partial(_test_template_embedder, world_size=world_size, chunk_size=chunk_size,
inplace=inplace, get_openfold_module_and_data=get_openfold_module_and_data)
mp.spawn(run_func, nprocs=world_size)
def _test_template_embedder(rank, world_size, chunk_size, inplace, get_openfold_module_and_data):
os.environ['RANK'] = str(rank)
os.environ['LOCAL_RANK'] = str(rank)
os.environ['WORLD_SIZE'] = str(world_size)
# init distributed for Dynamic Axial Parallelism
fastfold.distributed.init_dap()
fast_module, z_out, feats, z, z_mask = get_openfold_module_and_data
fast_module = copy.deepcopy(fast_module).cuda()
template_feats = copy.deepcopy(feats)
for k, v in template_feats.items():
template_feats[k] = v.cuda()
with torch.no_grad():
set_chunk_size(chunk_size)
if inplace:
template_embeds = fast_module(copy.deepcopy(template_feats), copy.deepcopy(z).cuda(), z_mask.to(dtype=z.dtype).cuda(), 0, chunk_size, inplace=inplace)
z_fast = template_embeds["template_pair_embedding"]
else:
template_embeds = fast_module(copy.deepcopy(template_feats), copy.deepcopy(z).cuda(), z_mask.to(dtype=z.dtype).cuda(), 0, chunk_size)
z_fast = z.cuda() + template_embeds["template_pair_embedding"]
error = torch.mean(torch.abs(z_out.cuda() - z_fast))
assert error < 5e-4, f"Test z failed at chunk size: {chunk_size}, inplace: {inplace}. The mean difference is {error}"
import os
import copy
import pytest
import torch
import pickle
import torch.multiprocessing as mp
from functools import partial
import fastfold
from fastfold.model.hub import AlphaFold
from fastfold.config import model_config
from fastfold.model.fastnn import set_chunk_size
from fastfold.utils.inject_fastnn import inject_fastnn
from fastfold.utils.import_weights import import_jax_weights_
from fastfold.utils.test_utils import get_data_path, get_param_path
@pytest.fixture(scope="module")
def get_openfold_module_and_data():
config = model_config('model_1')
config.globals.inplace = False
model = AlphaFold(config)
import_jax_weights_(model, get_param_path())
model.eval().cuda()
batch = pickle.load(open(get_data_path(), 'rb'))
batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
with torch.no_grad():
out = model(batch)
fastmodel = copy.deepcopy(model)
fastmodel = inject_fastnn(fastmodel)
fastmodel.eval().cuda()
return model, out, fastmodel
@pytest.mark.parametrize('world_size', [1, 2])
@pytest.mark.parametrize('chunk_size', [None, 32])
@pytest.mark.parametrize('inplace', [False, True])
def test_state_dict(world_size, chunk_size, inplace, get_openfold_module_and_data):
run_func = partial(run_dist, world_size=world_size, chunk_size=chunk_size, inplace=inplace, model=get_openfold_module_and_data)
mp.spawn(run_func, nprocs=world_size)
def run_dist(rank, world_size, chunk_size, inplace, model):
os.environ['RANK'] = str(rank)
os.environ['LOCAL_RANK'] = str(rank)
os.environ['WORLD_SIZE'] = str(world_size)
# init distributed for Dynamic Axial Parallelism
fastfold.distributed.init_dap()
inference(chunk_size, inplace, model)
def inference(chunk_size, inplace, get_openfold_module_and_data):
model, out, fastmodel = get_openfold_module_and_data
model.globals.chunk_size = chunk_size
model.globals.inplace = inplace
fastmodel = copy.deepcopy(fastmodel).cuda()
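# move the structure module's constant lookup tensors onto the GPU explicitly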
fastmodel.structure_module.default_frames = fastmodel.structure_module.default_frames.cuda()
fastmodel.structure_module.group_idx = fastmodel.structure_module.group_idx.cuda()
fastmodel.structure_module.atom_mask = fastmodel.structure_module.atom_mask.cuda()
fastmodel.structure_module.lit_positions = fastmodel.structure_module.lit_positions.cuda()
set_chunk_size(model.globals.chunk_size)
batch = pickle.load(open(get_data_path(), 'rb'))
batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
with torch.no_grad():
fastout = fastmodel(batch)
pos_dif = torch.max(torch.abs(fastout["final_atom_positions"] - out["final_atom_positions"].cuda()))
assert pos_dif < 5e-4, f"Test failed at chunk size: {chunk_size}, inplace: {inplace}. The position diff is {pos_dif}"
import os
import pytest
import torch
import pickle
import torch.multiprocessing as mp
from functools import partial
import colossalai
from fastfold.model.hub import AlphaFold
from fastfold.config import model_config
from fastfold.model.fastnn import set_chunk_size
from fastfold.utils.inject_fastnn import inject_fastnn
from fastfold.utils.test_utils import get_train_data_path
from fastfold.model.hub.loss import AlphaFoldLoss
from fastfold.utils.tensor_utils import tensor_tree_map
from fastfold.utils.test_utils import set_seed
def get_param_and_grad(model):
params = dict()
grads = dict()
for name, param in model.named_parameters():
params[name] = param.detach().clone()
grads[name] = param.grad.detach().clone()
return params, grads
@pytest.fixture(scope="module")
def get_openfold_state():
config = model_config('initial_training', train=True)
config.globals.inplace = False
set_seed(42)
model = AlphaFold(config)
model.train().cuda()
criterion = AlphaFoldLoss(config.loss)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, eps=1e-8)
batch = pickle.load(open(get_train_data_path(), 'rb'))
set_seed(42)
batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
out = model(batch)
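# the loss is computed against the features of the final recycling iteration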
batch = tensor_tree_map(lambda t: t[..., -1], batch)
loss, _ = criterion(out, batch, True)
optimizer.zero_grad()
set_seed(42)
loss.backward()
optimizer.step()
of_params, of_grads = get_param_and_grad(model)
return of_params, of_grads
@pytest.mark.skipif(torch.cuda.mem_get_info(0)[1] < 4e10, reason="Not enough cuda memory")
@pytest.mark.parametrize('world_size', [1])
def test_state_dict(world_size, get_openfold_state):
run_func = partial(run_dist, world_size=world_size, model=get_openfold_state)
mp.spawn(run_func, nprocs=world_size)
def run_dist(rank, world_size, model):
os.environ['RANK'] = str(rank)
os.environ['LOCAL_RANK'] = str(rank)
os.environ['WORLD_SIZE'] = str(world_size)
colossalai.launch(config=dict(parallel=dict(tensor=dict(size=world_size))), rank=rank, world_size=world_size,
host='localhost', port=10101, backend='nccl')
train(world_size, model)
def train(world_size, get_openfold_state):
of_params, of_grads = get_openfold_state
config = model_config('initial_training', train=True)
config.globals.inplace = False
set_seed(42)
model = AlphaFold(config)
model = inject_fastnn(model)
model.train().cuda()
criterion = AlphaFoldLoss(config.loss)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, eps=1e-8)
set_chunk_size(None)
batch = pickle.load(open(get_train_data_path(), 'rb'))
batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
set_seed(42)
out = model(batch)
batch = tensor_tree_map(lambda t: t[..., -1], batch)
loss, _ = criterion(out, batch, True)
optimizer.zero_grad()
set_seed(42)
loss.backward()
optimizer.step()
ff_params, ff_grads = get_param_and_grad(model)
params_dif = 0
grads_dif = 0
for name in ff_params.keys():
# module names in FastFold and OpenFold do not match exactly,
# which changes the ordering of some parameters.
# Comparing only the parameters present in both models is sufficient here.
if name not in of_params.keys():
continue
dif = torch.max(torch.abs(ff_params[name] - of_params[name]))
if dif > params_dif:
params_dif = dif
dif = torch.max(torch.abs(ff_grads[name] - of_grads[name]))
if dif > grads_dif:
grads_dif = dif
assert params_dif < 1e-3 and grads_dif < 5e-3, f"Test failed at world size: {world_size}, \
the param diff is {params_dif}, the grad diff is {grads_dif}"
if __name__ == '__main__':
test_state_dict(1, None)
import os
import random
import torch
import numpy as np
import colossalai
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from fastfold.config import model_config
from fastfold.model.hub import AlphaFold, AlphaFoldLRScheduler, AlphaFoldLoss
from fastfold.utils.inject_fastnn import inject_fastnn
from fastfold.data.data_modules import SetupTrainDataset, TrainDataLoader
from fastfold.utils.tensor_utils import tensor_tree_map
from fastfold.utils.validation_utils import compute_validation_metrics
#import logging
#logging.disable(logging.WARNING)
import torch.multiprocessing
torch.multiprocessing.set_sharing_strategy('file_system')
def log_loss(loss_breakdown, batch, outputs, train=True):
loss_info = ''
for loss_name, loss_value in loss_breakdown.items():
loss_info += (f' {loss_name}=' + "{:.3f}".format(loss_value))
with torch.no_grad():
other_metrics = compute_validation_metrics(
batch,
outputs,
superimposition_metrics=(not train)
)
for loss_name, loss_value in other_metrics.items():
loss_info += (f' {loss_name}=' + "{:.3f}".format(loss_value))
return loss_info
def main():
parser = colossalai.get_default_parser()
parser.add_argument('--from_torch', default=False, action='store_true')
parser.add_argument(
"--template_mmcif_dir", type=str,
help="Directory containing mmCIF files to search for templates"
)
parser.add_argument(
"--max_template_date", type=str,
help='''Cutoff for all templates. In training mode, templates are also
filtered by the release date of the target'''
)
parser.add_argument(
"--train_data_dir", type=str,
help="Directory containing training mmCIF files"
)
parser.add_argument(
"--train_alignment_dir", type=str,
help="Directory containing precomputed training alignments"
)
parser.add_argument(
"--train_chain_data_cache_path", type=str, default=None,
)
parser.add_argument(
"--distillation_data_dir", type=str, default=None,
help="Directory containing training PDB files"
)
parser.add_argument(
"--distillation_alignment_dir", type=str, default=None,
help="Directory containing precomputed distillation alignments"
)
parser.add_argument(
"--distillation_chain_data_cache_path", type=str, default=None,
)
parser.add_argument(
"--val_data_dir", type=str, default=None,
help="Directory containing validation mmCIF files"
)
parser.add_argument(
"--val_alignment_dir", type=str, default=None,
help="Directory containing precomputed validation alignments"
)
parser.add_argument(
"--kalign_binary_path", type=str, default='/usr/bin/kalign',
help="Path to the kalign binary"
)
parser.add_argument(
"--train_filter_path", type=str, default=None,
help='''Optional path to a text file containing names of training
examples to include, one per line. Used to filter the training
set'''
)
parser.add_argument(
"--distillation_filter_path", type=str, default=None,
help="""See --train_filter_path"""
)
parser.add_argument(
"--obsolete_pdbs_file_path", type=str, default=None,
help="""Path to obsolete.dat file containing list of obsolete PDBs and
their replacements."""
)
parser.add_argument(
"--template_release_dates_cache_path", type=str, default=None,
help="""Output of scripts/generate_mmcif_cache.py run on template mmCIF
files."""
)
parser.add_argument(
"--train_epoch_len", type=int, default=10000,
help=(
"The virtual length of each training epoch. Stochastic filtering "
"of training data means that training datasets have no "
"well-defined length. This virtual length affects frequency of "
"validation & checkpointing (by default, one of each per epoch)."
)
)
parser.add_argument(
"--_alignment_index_path", type=str, default=None,
help="Training alignment index. See the README for instructions."
)
parser.add_argument(
"--config_preset", type=str, default="initial_training",
help=(
'Config setting. Choose e.g. "initial_training", "finetuning", '
'"model_1", etc. By default, the actual values in the config are '
'used.'
)
)
parser.add_argument(
"--_distillation_structure_index_path", type=str, default=None,
)
parser.add_argument(
"--distillation_alignment_index_path", type=str, default=None,
help="Distillation alignment index. See the README for instructions."
)
parser.add_argument(
"--seed", type=int, default=42,
help="Random seed"
)
parser.add_argument(
"--max_epochs", type=int, default=10000,
help="The Max epochs of train"
)
parser.add_argument(
"--log_interval", type=int, default=1,
help="The interval steps of logging during training"
)
parser.add_argument(
"--log_path", type=str, default='train_log',
help="The path of log folder"
)
parser.add_argument(
"--save_ckpt_path", type=str, default=None,
help="The path where to save checkpoint, None means not save"
)
parser.add_argument(
"--save_ckpt_interval", type=int, default=1,
help="The interval epochs of save checkpoint"
)
parser.add_argument(
"--dap_size", type=int, default=1,
help="DAP size, recommended as 1 - nproc_per_node"
)
args = parser.parse_args()
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)
if args.from_torch:
colossalai.launch_from_torch(config=dict(parallel=dict(tensor=dict(size=args.dap_size)),
torch_ddp=dict(static_graph=True)))
disable_existing_loggers()
logger = get_dist_logger()
logger.log_to_file(args.log_path)
config = model_config(args.config_preset, train=True)
config.globals.inplace = False
model = AlphaFold(config)
model = inject_fastnn(model)
train_dataset, test_dataset = SetupTrainDataset(
config=config.data,
template_mmcif_dir=args.template_mmcif_dir,
max_template_date=args.max_template_date,
train_data_dir=args.train_data_dir,
train_alignment_dir=args.train_alignment_dir,
train_chain_data_cache_path=args.train_chain_data_cache_path,
distillation_data_dir=args.distillation_data_dir,
distillation_alignment_dir=args.distillation_alignment_dir,
distillation_chain_data_cache_path=args.distillation_chain_data_cache_path,
val_data_dir=args.val_data_dir,
val_alignment_dir=args.val_alignment_dir,
kalign_binary_path=args.kalign_binary_path,
# train_mapping_path=args.train_mapping_path,
# distillation_mapping_path=args.distillation_mapping_path,
obsolete_pdbs_file_path=args.obsolete_pdbs_file_path,
template_release_dates_cache_path=args.template_release_dates_cache_path,
train_epoch_len=args.train_epoch_len,
_alignment_index_path=args._alignment_index_path,
)
train_dataloader, test_dataloader = TrainDataLoader(
config=config.data,
train_dataset=train_dataset,
test_dataset=test_dataset,
batch_seed=args.seed,
)
criterion = AlphaFoldLoss(config.loss)
optimizer = HybridAdam(model.parameters(), lr=1e-3, eps=1e-8)
lr_scheduler = AlphaFoldLRScheduler(optimizer)
engine, train_dataloader, test_dataloader, lr_scheduler = colossalai.initialize(
model=model,
optimizer=optimizer,
criterion=criterion,
lr_scheduler=lr_scheduler,
train_dataloader=train_dataloader,
test_dataloader=test_dataloader,
)
logger.info('Start training.', ranks=[0])
for epoch in range(args.max_epochs):
engine.train()
for i, batch in enumerate(train_dataloader):
batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
output = engine(batch)
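# keep only the final recycling iteration of the features when computing the loss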
batch = tensor_tree_map(lambda t: t[..., -1], batch)
loss, loss_breakdown = engine.criterion(
output, batch, _return_breakdown=True)
if (i+1) % args.log_interval == 0:
logger.info(f'Training, Epoch: {epoch}, Step: {i+1}, Global_Step: {epoch*len(train_dataloader)+i+1},' +
f' Loss:{log_loss(loss_breakdown, batch, output)}', ranks=[0])
engine.zero_grad()
engine.backward(loss)
engine.step()
lr_scheduler.step()
if test_dataloader is not None:
engine.eval()
for i, batch in enumerate(test_dataloader):
batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
with torch.no_grad():
output = engine(batch)
batch = tensor_tree_map(lambda t: t[..., -1], batch)
batch["use_clamped_fape"] = 0.
_, loss_breakdown = engine.criterion(
output, batch, _return_breakdown=True)
logger.info(f'Validation, Step: {i+1}, \
Loss:{log_loss(loss_breakdown, batch, output, False)}', ranks=[0])
if (args.save_ckpt_path is not None) and ( (epoch+1) % args.save_ckpt_interval == 0):
torch.save(engine.model, os.path.join(args.save_ckpt_path, 'model.pth'))
if __name__ == "__main__":
main()
DATA_DIR=/path/to/data
PROJECT_DIR=/path/to/project
gpus_per_node=2
nnodes=1
max_template_date=2021-10-10
train_data_dir=${DATA_DIR}/mmcif_dir # directory containing the *.cif or *.pdb training structures
train_alignment_dir=${DATA_DIR}/alignment_dir # directory for templates and features.pkl of the training sequences
mkdir -p ${train_alignment_dir}
# val_data_dir=${PROJECT_DIR}/dataset/val_pdb
# val_alignment_dir=${PROJECT_DIR}/dataset/alignment_val_pdb # directory for templates and features.pkl of the validation sequences
template_mmcif_dir=${DATA_DIR}/data/pdb_mmcif/mmcif_files
template_release_dates_cache_path=${DATA_DIR}/mmcif_cache.json # a cache used to pre-filter templates
train_chain_data_cache_path=${DATA_DIR}/chain_data_cache.json # a separate chain-level cache with data used for training-time data filtering
train_epoch_len=10000 # virtual length of each training epoch, which affects frequency of validation & checkpointing
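# single-node launch; torchrun sets the RANK/WORLD_SIZE environment variables that --from_torch (colossalai.launch_from_torch) reads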
torchrun --standalone --nproc_per_node ${gpus_per_node} --nnodes ${nnodes} train.py \
--from_torch \
--template_mmcif_dir=${template_mmcif_dir} \
--max_template_date=${max_template_date} \
--train_data_dir=${train_data_dir} \
--train_alignment_dir=${train_alignment_dir} \
--train_chain_data_cache_path=${train_chain_data_cache_path} \
--template_release_dates_cache_path=${template_release_dates_cache_path} \
    --train_epoch_len=${train_epoch_len}