Unverified Commit 95150c38 authored by LuGY, committed by GitHub

fix distributed sampler for ddp and add dap (#141)

* fix sampler of ddp and add dap

* add end to end unit test for training

* modify unit test datapath

* lower the bar for cuda kernel

* skip unit test of end-to-end train on 3080
parent 19ce8406
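The "fix sampler of ddp" bullet amounts to attaching a torch DistributedSampler to the data loaders whenever the run is distributed, so each DDP rank reads a disjoint shard instead of duplicating the full epoch. A minimal sketch of that pattern, assuming nothing beyond stock PyTorch (build_loader is an illustrative name, not a FastFold function):

# Minimal sketch of the pattern this commit applies, in plain PyTorch rather
# than FastFold's own code; build_loader is a made-up helper name.
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler


def build_loader(dataset, batch_size):
    sampler = None
    if dist.is_available() and dist.is_initialized():
        # Under DDP each rank gets a disjoint shard of the indices, so the
        # ranks no longer all iterate over the full dataset.
        sampler = DistributedSampler(dataset)
    # shuffle must be left off whenever an explicit sampler is supplied
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler, shuffle=False)


if __name__ == "__main__":
    loader = build_loader(TensorDataset(torch.arange(8.0)), batch_size=2)
    print(len(loader))  # 4 batches on a single, non-distributed process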
@@ -21,7 +21,7 @@ from typing import Optional, Sequence, List, Any
 import ml_collections as mlc
 import torch
+from colossalai.utils import is_using_ddp
 from fastfold.data import (
     data_pipeline,
     feature_pipeline,

@@ -384,8 +384,8 @@ class OpenFoldBatchCollator:
 class OpenFoldDataLoader(torch.utils.data.DataLoader):
-    def __init__(self, *args, config, stage="train", generator=None, **kwargs):
-        super().__init__(*args, **kwargs)
+    def __init__(self, dataset, config, stage="train", generator=None, **kwargs):
+        super().__init__(dataset, **kwargs)
         self.config = config
         self.stage = stage
@@ -604,28 +604,36 @@ def TrainDataLoader(
         generator = generator.manual_seed(batch_seed)

     train_batch_collator = OpenFoldBatchCollator(config, "train")
+    train_sampler = None
+    if is_using_ddp():
+        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)

     train_dataset.reroll()
     train_dataloader = OpenFoldDataLoader(
-        train_dataset,
+        dataset=train_dataset,
         config=config,
         stage="train",
         generator=generator,
         batch_size=config.data_module.data_loaders.batch_size,
         num_workers=config.data_module.data_loaders.num_workers,
         collate_fn=train_batch_collator,
+        sampler=train_sampler,
     )
     test_dataloader = None
     if test_dataset is not None:
+        test_sampler = None
+        if is_using_ddp():
+            test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset)
         test_batch_collator = OpenFoldBatchCollator(config, "test")
         test_dataloader = OpenFoldDataLoader(
-            train_dataset,
+            dataset=test_dataset,
             config=config,
             stage="test",
             generator=generator,
             batch_size=config.data_module.data_loaders.batch_size,
             num_workers=config.data_module.data_loaders.num_workers,
             collate_fn=test_batch_collator,
+            sampler=test_sampler,
         )

     return train_dataloader, test_dataloader
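The hunks above appear to come from fastfold/data/data_modules.py: the loader now takes a named dataset argument, and when is_using_ddp() reports a distributed run, a DistributedSampler drives it. One convention the diff does not show is the per-epoch reshuffle: with a DistributedSampler attached, the caller normally invokes sampler.set_epoch(epoch) before each epoch so every rank draws the same permutation. A self-contained sketch of that convention with stand-in data (not FastFold code):

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

# Stand-ins for the objects built in TrainDataLoader above.
dataset = TensorDataset(torch.arange(16.0))
sampler = None  # would be a DistributedSampler when is_using_ddp() is true
loader = DataLoader(dataset, batch_size=4, sampler=sampler)

for epoch in range(2):
    if isinstance(sampler, DistributedSampler):
        # Reseeds the shuffled index permutation identically on every rank.
        sampler.set_epoch(epoch)
    for (batch,) in loader:
        pass  # the training step would run here

The next hunks extend the shared test utilities.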
 import os
+import random
+import torch
+import numpy as np


 def get_param_path():
     # develop
@@ -15,3 +18,16 @@ def get_data_path():
     return '/home/lczxl/data2/fastfold/example_input/mono_batch.pkl'
     # test
     return '/data/scratch/fastfold/mono_batch.pkl'
+
+
+def get_train_data_path():
+    return '/data/scratch/fastfold/std_train_batch.pkl'
+
+
+def set_seed(seed):
+    random.seed(seed)
+    os.environ['PYTHONHASHSEED'] = str(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
\ No newline at end of file
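These additions extend what looks like fastfold/utils/test_utils.py with a path to a pickled training batch and a set_seed helper that pins the Python, NumPy, and torch (CPU and CUDA) RNGs and forces deterministic cuDNN kernels. A small usage sketch, assuming FastFold is installed, of the property the end-to-end test relies on: reseeding restores the RNG state, so identical call sequences give identical tensors.

import torch
from fastfold.utils.test_utils import set_seed

set_seed(42)
first = torch.rand(4)

set_seed(42)
second = torch.rand(4)

assert torch.equal(first, second)  # same seed and same call order, same draws

The commit's new end-to-end training test, reproduced below, calls set_seed in the same way around model construction, the forward pass, and the backward pass.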
import os
import pytest
import torch
import pickle
import torch.multiprocessing as mp
from functools import partial
import colossalai
from fastfold.model.hub import AlphaFold
from fastfold.config import model_config
from fastfold.model.fastnn import set_chunk_size
from fastfold.utils.inject_fastnn import inject_fastnn
from fastfold.utils.test_utils import get_train_data_path
from fastfold.model.hub.loss import AlphaFoldLoss
from fastfold.utils.tensor_utils import tensor_tree_map
from fastfold.utils.test_utils import set_seed


def get_param_and_grad(model):
    params = dict()
    grads = dict()
    for name, param in model.named_parameters():
        params[name] = param.detach().clone()
        grads[name] = param.grad.detach().clone()
    return params, grads


@pytest.fixture(scope="module")
def get_openfold_state():
    config = model_config('initial_training', train=True)
    config.globals.inplace = False
    set_seed(42)
    model = AlphaFold(config)
    model.train().cuda()
    criterion = AlphaFoldLoss(config.loss)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, eps=1e-8)
    batch = pickle.load(open(get_train_data_path(), 'rb'))
    set_seed(42)
    batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
    out = model(batch)
    batch = tensor_tree_map(lambda t: t[..., -1], batch)
    loss, _ = criterion(out, batch, True)
    optimizer.zero_grad()
    set_seed(42)
    loss.backward()
    optimizer.step()
    of_params, of_grads = get_param_and_grad(model)
    return of_params, of_grads


@pytest.mark.skipif(torch.cuda.mem_get_info(0)[1] < 4e10, reason="Not enough cuda memory")
@pytest.mark.parametrize('world_size', [1])
def test_state_dict(world_size, get_openfold_state):
    run_func = partial(run_dist, world_size=world_size, model=get_openfold_state)
    mp.spawn(run_func, nprocs=world_size)


def run_dist(rank, world_size, model):
    os.environ['RANK'] = str(rank)
    os.environ['LOCAL_RANK'] = str(rank)
    os.environ['WORLD_SIZE'] = str(world_size)
    colossalai.launch(config=dict(parallel=dict(tensor=dict(size=world_size))), rank=rank, world_size=world_size,
                      host='localhost', port=10101, backend='nccl')
    train(world_size, model)


def train(world_size, get_openfold_state):
    of_params, of_grads = get_openfold_state
    config = model_config('initial_training', train=True)
    config.globals.inplace = False
    set_seed(42)
    model = AlphaFold(config)
    # Swap in FastFold's optimized (FastNN) modules before training.
    model = inject_fastnn(model)
    model.train().cuda()
    criterion = AlphaFoldLoss(config.loss)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, eps=1e-8)
    set_chunk_size(None)
    batch = pickle.load(open(get_train_data_path(), 'rb'))
    batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
    set_seed(42)
    out = model(batch)
    batch = tensor_tree_map(lambda t: t[..., -1], batch)
    loss, _ = criterion(out, batch, True)
    optimizer.zero_grad()
    set_seed(42)
    loss.backward()
    optimizer.step()
    ff_params, ff_grads = get_param_and_grad(model)

    params_dif = 0
    grads_dif = 0
    for name in ff_params.keys():
        # The module names in FastFold and OpenFold do not match exactly, which
        # changes the order of some parameters. Mapping them would not be hard,
        # but comparing only the parameters and gradients that share a name is
        # enough for this test.
        if name not in of_params.keys():
            continue
        dif = torch.max(torch.abs(ff_params[name] - of_params[name]))
        if dif > params_dif:
            params_dif = dif
        dif = torch.max(torch.abs(ff_grads[name] - of_grads[name]))
        if dif > grads_dif:
            grads_dif = dif
    assert params_dif < 1e-3 and grads_dif < 5e-3, f"Test failed at world size: {world_size}, \
        the param dif is {params_dif}, the grad diff is {grads_dif}"


if __name__ == '__main__':
    test_state_dict(1, None)
\ No newline at end of file
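The test above folds every shared parameter and gradient into two scalar max-absolute-difference values and asserts on those. A hedged alternative, shown only for comparison and using dummy dictionaries standing in for of_params and ff_params, is a per-tensor check with torch.testing.assert_close, which pinpoints the first offending tensor in its error message.

import torch

ff_params = {"linear.weight": torch.ones(3)}
of_params = {"linear.weight": torch.ones(3) + 1e-5}

for name, value in ff_params.items():
    if name in of_params:
        # Raises AssertionError with a readable report if any element
        # differs by more than atol (relative tolerance disabled here).
        torch.testing.assert_close(value, of_params[name], rtol=0, atol=1e-3)

The remaining hunks touch the training script itself.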
@@ -4,7 +4,6 @@ import torch
 import numpy as np
 import colossalai
 from colossalai.logging import disable_existing_loggers, get_dist_logger
-from colossalai.core import global_context as gpc
 from colossalai.nn.optimizer import HybridAdam
 from fastfold.config import model_config

@@ -13,7 +12,6 @@ from fastfold.utils.inject_fastnn import inject_fastnn
 from fastfold.data.data_modules import SetupTrainDataset, TrainDataLoader
 from fastfold.utils.tensor_utils import tensor_tree_map
 from fastfold.utils.validation_utils import compute_validation_metrics
 #import logging
 #logging.disable(logging.WARNING)
 import torch.multiprocessing
@@ -153,6 +151,10 @@ def main():
         "--save_ckpt_interval", type=int, default=1,
         help="The interval epochs of save checkpoint"
     )
+    parser.add_argument(
+        "--dap_size", type=int, default=1,
+        help="DAP size, recommended as 1 - nproc_per_node"
+    )
     args = parser.parse_args()

     random.seed(args.seed)
@@ -160,7 +162,8 @@ def main():
     torch.manual_seed(args.seed)
     torch.cuda.manual_seed_all(args.seed)
     if args.from_torch:
-        colossalai.launch_from_torch(config=dict(torch_ddp=dict(static_graph=True)))
+        colossalai.launch_from_torch(config=dict(parallel=dict(tensor=dict(size=args.dap_size)),
+                                                 torch_ddp=dict(static_graph=True)))
     disable_existing_loggers()
     logger = get_dist_logger()
     logger.log_to_file(args.log_path)
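The two hunks above add a --dap_size flag (its help string "1 - nproc_per_node" reads as a range, i.e. any value from 1 up to the number of processes per node) and pass it to Colossal-AI as the tensor-parallel size, which is how the DAP groups are formed. Under the usual Colossal-AI convention the remaining factor of the world size acts as the data-parallel dimension; a sketch of that arithmetic with example numbers (the launch command in the comment is illustrative, other training arguments elided):

# e.g.  torchrun --nproc_per_node=8 train.py --from_torch --dap_size 2 ...
world_size = 8   # processes started by the launcher
dap_size = 2     # value passed via --dap_size

# Assumption based on the parallel=dict(tensor=dict(size=dap_size)) config:
# ranks are split into DAP groups of dap_size, and the groups themselves
# are replicated with DDP.
assert world_size % dap_size == 0, "dap_size must divide the rank count"
num_ddp_replicas = world_size // dap_size
print(num_ddp_replicas)  # -> 4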
@@ -227,7 +230,7 @@ def main():
             loss, loss_breakdown = engine.criterion(
                 output, batch, _return_breakdown=True)
             if (i+1) % args.log_interval == 0:
-                logger.info(f'Training, Epoch: {epoch}, Step: {i+1}, Global_Step: {epoch*args.train_epoch_len+i+1},' +
+                logger.info(f'Training, Epoch: {epoch}, Step: {i+1}, Global_Step: {epoch*len(train_dataloader)+i+1},' +
                             f' Loss:{log_loss(loss_breakdown, batch, output)}', ranks=[0])
             engine.zero_grad()
             engine.backward(loss)