Unverified commit ceee81d7 authored by LuGY, committed by GitHub

support training, fix some inplace func of nn (#118)



* support training, fix some inplace func of nn

* fix some merge issue
Co-authored-by: shenggan <csg19971016@gmail.com>
parent 164f6777
@@ -391,7 +391,7 @@ class TemplatePairStack(nn.Module):
             for i in range(0, t.shape[0]):
                 t[i] = self.layer_norm(t[i])
         else:
-            t = self.layer_norm(t[i])
+            t = self.layer_norm(t)
         return t

     def inplace(
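The chunked branch above normalizes one template at a time to bound peak memory, while the fixed else-branch normalizes the whole stack in one call; the two are numerically equivalent because LayerNorm only acts over the trailing channel dimension. A minimal sketch of that equivalence, with illustrative shapes and a standalone nn.LayerNorm standing in for self.layer_norm:

import torch
import torch.nn as nn

# Illustrative shapes only: [n_templates, n_res, n_res, c_t]
layer_norm = nn.LayerNorm(64)
t = torch.randn(4, 16, 16, 64)

# Chunked path: normalize template by template (lower peak memory).
t_chunked = t.clone()
for i in range(0, t_chunked.shape[0]):
    t_chunked[i] = layer_norm(t_chunked[i])

# Unchunked path, as in the fixed branch: normalize the full stack at once.
t_full = layer_norm(t)

# LayerNorm only normalizes over the last dimension, so both paths agree.
assert torch.allclose(t_chunked, t_full, atol=1e-6)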
 from .alphafold import AlphaFold
+from .lr_scheduler import AlphaFoldLRScheduler
+from .loss import AlphaFoldLoss
-__all__ = ["AlphaFold"]
\ No newline at end of file
+__all__ = ["AlphaFold", "AlphaFoldLRScheduler", "AlphaFoldLoss"]
\ No newline at end of file
# Copyright 2022 HPC-AI Tech Inc
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch


class AlphaFoldLRScheduler(torch.optim.lr_scheduler._LRScheduler):
    """ Implements the learning rate schedule defined in the AlphaFold 2
    supplement. A linear warmup is followed by a plateau at the maximum
    learning rate and then exponential decay.

    Note that the initial learning rate of the optimizer in question is
    ignored; use this class' base_lr parameter to specify the starting
    point of the warmup.
    """
    def __init__(self,
        optimizer,
        last_epoch: int = -1,
        verbose: bool = False,
        base_lr: float = 0.,
        max_lr: float = 0.001,
        warmup_no_steps: int = 1000,
        start_decay_after_n_steps: int = 50000,
        decay_every_n_steps: int = 50000,
        decay_factor: float = 0.95,
    ):
        step_counts = {
            "warmup_no_steps": warmup_no_steps,
            "start_decay_after_n_steps": start_decay_after_n_steps,
        }

        for k, v in step_counts.items():
            if v < 0:
                raise ValueError(f"{k} must be nonnegative")

        if warmup_no_steps > start_decay_after_n_steps:
            raise ValueError(
                "warmup_no_steps must not exceed start_decay_after_n_steps"
            )

        self.optimizer = optimizer
        self.last_epoch = last_epoch
        self.verbose = verbose
        self.base_lr = base_lr
        self.max_lr = max_lr
        self.warmup_no_steps = warmup_no_steps
        self.start_decay_after_n_steps = start_decay_after_n_steps
        self.decay_every_n_steps = decay_every_n_steps
        self.decay_factor = decay_factor

        super(AlphaFoldLRScheduler, self).__init__(
            optimizer,
            last_epoch=last_epoch,
            verbose=verbose,
        )

    def state_dict(self):
        state_dict = {
            k: v for k, v in self.__dict__.items() if k not in ["optimizer"]
        }
        return state_dict

    def load_state_dict(self, state_dict):
        self.__dict__.update(state_dict)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            raise RuntimeError(
                "To get the last learning rate computed by the scheduler, use "
                "get_last_lr()"
            )

        step_no = self.last_epoch

        if step_no <= self.warmup_no_steps:
            lr = self.base_lr + (step_no / self.warmup_no_steps) * self.max_lr
        elif step_no > self.start_decay_after_n_steps:
            steps_since_decay = step_no - self.start_decay_after_n_steps
            exp = (steps_since_decay // self.decay_every_n_steps) + 1
            lr = self.max_lr * (self.decay_factor ** exp)
        else:  # plateau
            lr = self.max_lr

        return [lr for group in self.optimizer.param_groups]
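For reference, a hedged usage sketch of the scheduler defined above (the commit exports it as fastfold.model.hub.AlphaFoldLRScheduler); the dummy parameter and the printed checkpoints are illustrative only:

import torch

# Dummy optimizer; its own lr is ignored in favor of base_lr/max_lr.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=1e-3)
scheduler = AlphaFoldLRScheduler(
    optimizer,
    base_lr=0.0,
    max_lr=0.001,
    warmup_no_steps=1000,
    start_decay_after_n_steps=50000,
    decay_every_n_steps=50000,
    decay_factor=0.95,
)

# Shape of the schedule implied by get_lr():
#   step <= 1000:          lr ramps linearly from base_lr toward max_lr
#   1000 < step <= 50000:  lr plateaus at max_lr
#   step > 50000:          lr = max_lr * 0.95 ** ((step - 50000) // 50000 + 1)
for _ in range(1500):
    optimizer.step()
    scheduler.step()
print(scheduler.get_last_lr())  # [0.001] once past the warmup steps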
...@@ -56,7 +56,7 @@ class Dropout(nn.Module): ...@@ -56,7 +56,7 @@ class Dropout(nn.Module):
shape[bd] = 1 shape[bd] = 1
mask = x.new_ones(shape) mask = x.new_ones(shape)
mask = self.dropout(mask) mask = self.dropout(mask)
x *= mask x = x * mask
return x return x
......
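The `x *= mask` to `x = x * mask` change above is the "fix some inplace func of nn" part of this commit: in-place updates can clobber tensors that autograd has saved for the backward pass, which only surfaces once these modules are trained rather than run in inference. A minimal reproduction, using sigmoid purely as an illustrative op that saves its output for backward:

import torch

x = torch.randn(4, requires_grad=True)
mask = torch.ones(4)

y = torch.sigmoid(x)   # sigmoid saves its output for the backward pass
y *= mask              # in-place: bumps y's version counter
try:
    y.sum().backward()
except RuntimeError as err:
    print(err)         # "... has been modified by an inplace operation"

# Out-of-place, as in the fixed Dropout.forward: the saved tensor stays intact.
x = torch.randn(4, requires_grad=True)
y = torch.sigmoid(x)
y = y * mask
y.sum().backward()     # succeeds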
@@ -264,11 +264,6 @@ class EvoformerBlock(nn.Module):
             eps=eps,
         )

-        self.outer_product_mean = OuterProductMean(
-            c_m,
-            c_z,
-            c_hidden_opm,
-        )
         self.is_multimer = is_multimer

     def forward(self,
@@ -17,7 +17,7 @@ import torch
 import torch.nn as nn

 from fastfold.model.nn.primitives import Linear, LayerNorm
-from fastfold.model.loss import (
+from fastfold.model.hub.loss import (
     compute_plddt,
     compute_tm,
     compute_predicted_aligned_error,
import random
import torch
import numpy as np
import colossalai
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.core import global_context as gpc
from colossalai.nn.optimizer import HybridAdam
from tqdm import tqdm
from fastfold.config import model_config
from fastfold.model.hub import AlphaFold, AlphaFoldLRScheduler, AlphaFoldLoss
from fastfold.utils.inject_fastnn import inject_fastnn
from fastfold.data.data_modules import SetupTrainDataset, TrainDataLoader
from fastfold.utils.tensor_utils import tensor_tree_map
import logging
logging.disable(logging.WARNING)
import torch.multiprocessing
torch.multiprocessing.set_sharing_strategy('file_system')
def main():
    parser = colossalai.get_default_parser()
    parser.add_argument('--from_torch', default=False, action='store_true')
    parser.add_argument(
        "--template_mmcif_dir", type=str,
        help="Directory containing mmCIF files to search for templates"
    )
    parser.add_argument(
        "--max_template_date", type=str,
        help='''Cutoff for all templates. In training mode, templates are also
                filtered by the release date of the target'''
    )
    parser.add_argument(
        "--train_data_dir", type=str,
        help="Directory containing training mmCIF files"
    )
    parser.add_argument(
        "--train_alignment_dir", type=str,
        help="Directory containing precomputed training alignments"
    )
    parser.add_argument(
        "--train_chain_data_cache_path", type=str, default=None,
    )
    parser.add_argument(
        "--distillation_data_dir", type=str, default=None,
        help="Directory containing training PDB files"
    )
    parser.add_argument(
        "--distillation_alignment_dir", type=str, default=None,
        help="Directory containing precomputed distillation alignments"
    )
    parser.add_argument(
        "--distillation_chain_data_cache_path", type=str, default=None,
    )
    parser.add_argument(
        "--val_data_dir", type=str, default=None,
        help="Directory containing validation mmCIF files"
    )
    parser.add_argument(
        "--val_alignment_dir", type=str, default=None,
        help="Directory containing precomputed validation alignments"
    )
    parser.add_argument(
        "--kalign_binary_path", type=str, default='/usr/bin/kalign',
        help="Path to the kalign binary"
    )
    parser.add_argument(
        "--train_filter_path", type=str, default=None,
        help='''Optional path to a text file containing names of training
                examples to include, one per line. Used to filter the training
                set'''
    )
    parser.add_argument(
        "--distillation_filter_path", type=str, default=None,
        help="""See --train_filter_path"""
    )
    parser.add_argument(
        "--obsolete_pdbs_file_path", type=str, default=None,
        help="""Path to obsolete.dat file containing list of obsolete PDBs and
             their replacements."""
    )
    parser.add_argument(
        "--template_release_dates_cache_path", type=str, default=None,
        help="""Output of scripts/generate_mmcif_cache.py run on template mmCIF
             files."""
    )
    parser.add_argument(
        "--train_epoch_len", type=int, default=10000,
        help=(
            "The virtual length of each training epoch. Stochastic filtering "
            "of training data means that training datasets have no "
            "well-defined length. This virtual length affects frequency of "
            "validation & checkpointing (by default, one of each per epoch)."
        )
    )
    parser.add_argument(
        "--_alignment_index_path", type=str, default=None,
        help="Training alignment index. See the README for instructions."
    )
    parser.add_argument(
        "--config_preset", type=str, default="initial_training",
        help=(
            'Config setting. Choose e.g. "initial_training", "finetuning", '
            '"model_1", etc. By default, the actual values in the config are '
            'used.'
        )
    )
    parser.add_argument(
        "--_distillation_structure_index_path", type=str, default=None,
    )
    parser.add_argument(
        "--distillation_alignment_index_path", type=str, default=None,
        help="Distillation alignment index. See the README for instructions."
    )
    parser.add_argument(
        "--seed", type=int, default=42,
        help="Random seed"
    )
    args = parser.parse_args()

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    if args.from_torch:
        colossalai.launch_from_torch(config=dict(torch_ddp=dict(static_graph=True)))
    disable_existing_loggers()
    logger = get_dist_logger()

    config = model_config(args.config_preset, train=True)
    config.globals.inplace = False
    model = AlphaFold(config)
    model = inject_fastnn(model)

    train_dataset, test_dataset = SetupTrainDataset(
        config=config.data,
        template_mmcif_dir=args.template_mmcif_dir,
        max_template_date=args.max_template_date,
        train_data_dir=args.train_data_dir,
        train_alignment_dir=args.train_alignment_dir,
        train_chain_data_cache_path=args.train_chain_data_cache_path,
        distillation_data_dir=args.distillation_data_dir,
        distillation_alignment_dir=args.distillation_alignment_dir,
        distillation_chain_data_cache_path=args.distillation_chain_data_cache_path,
        val_data_dir=args.val_data_dir,
        val_alignment_dir=args.val_alignment_dir,
        kalign_binary_path=args.kalign_binary_path,
        # train_mapping_path=args.train_mapping_path,
        # distillation_mapping_path=args.distillation_mapping_path,
        obsolete_pdbs_file_path=args.obsolete_pdbs_file_path,
        template_release_dates_cache_path=args.template_release_dates_cache_path,
        train_epoch_len=args.train_epoch_len,
        _alignment_index_path=args._alignment_index_path,
    )

    train_dataloader, test_dataloader = TrainDataLoader(
        config=config.data,
        train_dataset=train_dataset,
        test_dataset=test_dataset,
        batch_seed=args.seed,
    )

    criterion = AlphaFoldLoss(config.loss)
    optimizer = HybridAdam(model.parameters(), lr=1e-3, eps=1e-8)
    lr_scheduler = AlphaFoldLRScheduler(optimizer)

    engine, train_dataloader, test_dataloader, lr_scheduler = colossalai.initialize(
        model=model,
        optimizer=optimizer,
        criterion=criterion,
        lr_scheduler=lr_scheduler,
        train_dataloader=train_dataloader,
        test_dataloader=test_dataloader,
    )

    for epoch in range(200):
        engine.train()
        if gpc.get_global_rank() == 0:
            train_dataloader = tqdm(train_dataloader)
        for batch in train_dataloader:
            batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
            engine.zero_grad()
            output = engine(batch)
            batch = tensor_tree_map(lambda t: t[..., -1], batch)
            loss, loss_breakdown = engine.criterion(
                output, batch, _return_breakdown=True)
            if gpc.get_global_rank() == 0:
                train_dataloader.set_postfix(loss=float(loss))
            engine.backward(loss)
            engine.step()
            lr_scheduler.step()

        if test_dataloader is not None:
            engine.eval()
            if gpc.get_global_rank() == 0:
                test_dataloader = tqdm(test_dataloader)
            for batch in test_dataloader:
                batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
                with torch.no_grad():
                    output = engine(batch)
                    batch = tensor_tree_map(lambda t: t[..., -1], batch)
                    loss, loss_breakdown = engine.criterion(
                        output, batch, _return_breakdown=True)
                if gpc.get_global_rank() == 0:
                    test_dataloader.set_postfix(loss=float(loss))


if __name__ == "__main__":
    main()
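One step in the loops above is worth calling out: the ground-truth features produced by the data pipeline carry a trailing recycling dimension, and tensor_tree_map(lambda t: t[..., -1], batch) keeps only the final recycling iteration before the loss is computed. A small sketch of that mapping, with made-up feature names and sizes (illustrative only, not the real pipeline output):

import torch
from fastfold.utils.tensor_utils import tensor_tree_map

# Illustrative batch: each feature ends in a recycling dimension of size 4.
batch = {
    "aatype": torch.zeros(8, 4),
    "msa_feat": torch.zeros(16, 8, 49, 4),
}
last_recycle = tensor_tree_map(lambda t: t[..., -1], batch)
print(last_recycle["aatype"].shape)    # torch.Size([8])
print(last_recycle["msa_feat"].shape)  # torch.Size([16, 8, 49])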
DATA_DIR=/path/to/data
PROJECT_DIR=/path/to/project
gpus_per_node=2
nnodes=1
max_template_date=2021-10-10
train_data_dir=${DATA_DIR}/mmcif_dir # directory containing the training *.cif or *.pdb files
train_alignment_dir=${DATA_DIR}/alignment_dir # directory where templates and features.pkl for the training sequences are stored
mkdir -p ${train_alignment_dir}
# val_data_dir=${PROJECT_DIR}/dataset/val_pdb
# val_alignment_dir=${PROJECT_DIR}/dataset/alignment_val_pdb # directory where templates and features.pkl for the validation sequences are stored
template_mmcif_dir=${DATA_DIR}/data/pdb_mmcif/mmcif_files
template_release_dates_cache_path=${DATA_DIR}/mmcif_cache.json # a cache used to pre-filter templates
train_chain_data_cache_path=${DATA_DIR}/chain_data_cache.json # a separate chain-level cache with data used for training-time data filtering
train_epoch_len=10000 # virtual length of each training epoch, which affects frequency of validation & checkpointing
torchrun --standalone --nproc_per_node ${gpus_per_node} --nnodes ${nnodes} train.py \
--from_torch \
--template_mmcif_dir=${template_mmcif_dir} \
--max_template_date=${max_template_date} \
--train_data_dir=${train_data_dir} \
--train_alignment_dir=${train_alignment_dir} \
--train_chain_data_cache_path=${train_chain_data_cache_path} \
--template_release_dates_cache_path=${template_release_dates_cache_path} \
--train_epoch_len=${train_epoch_len} \