"docs/source/en/using-diffusers/custom_pipeline_examples.mdx" did not exist on "8aac1f99d7af5873db7d23c07fba370d0f5061a6"
Unverified Commit 8a599895 authored by oahzxl's avatar oahzxl Committed by GitHub
Browse files

Re-organize file and align results (#99)

* Move template modification from nn to fastnn (#97)

* update out product mean test

* polish out product mean test and update linear test

* update test for layernorm, use both kernel

* support test without triton

* polish layernorm

* use current code

* move template modification to fastnn, restore template in nn

* re-organize fastnn

* update evoformer unit test

* Align results and update unit tests (#98)

* update out product mean test

* polish out product mean test and update linear test

* update test for layernorm, use both kernel

* support test without triton

* polish layernorm

* use current code

* move template modification to fastnn, restore template in nn

* re-organize fastnn

* update evoformer unit test

* update evoformer stack test

* update test

* update msa_att_row

* update msa_att_col

* update evoformer and evo-stack

* update evoformer

* update extramsa

* move model loading out of the loop

* finish template test

* update test

* Move template modification from nn to fastnn (#84)

* update out product mean test

* polish out product mean test and update linear test

* update test for layernorm, use both kernel

* support test without triton

* polish layernorm

* use current code

* move template modification to fastnn, restore template in nn

* re-organize fastnn

* update evoformer unit test

* move model out of function

* only test inference

* remove cache in build

* update test inference

* restore changes

* restore build changes

* update inference and evoformer stack

* fix some bug

* update test

* update evoformer stack test

* update test

* update test

* fix test

* update test

* update test

* update input embedder

* update embedder

* reset core

* update test

* support template multimer in inject_nn
parent 7a69a181
import torch
import pytest
import os
import copy
import torch.multiprocessing as mp
from functools import partial
import fastfold
from fastfold.config import model_config
from fastfold.model.fastnn.ops import set_chunk_size
from fastfold.model.hub import AlphaFold
from fastfold.utils.inject_fastnn import inject_fastnn
from fastfold.utils.import_weights import import_jax_weights_
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from fastfold.distributed.comm import gather, scatter, row_to_col
from fastfold.utils.test_utils import get_param_path
@pytest.fixture(scope="module")
def get_openfold_module_and_data():
with torch.no_grad():
config = model_config('model_1')
config.globals.inplace = False
target_module = AlphaFold(config)
import_jax_weights_(target_module, get_param_path())
fast_module = copy.deepcopy(target_module)
fast_module = inject_fastnn(fast_module)
fast_module = fast_module.evoformer.blocks[0].msa.MSAColumnAttention.eval().cuda()
target_module = target_module.evoformer.blocks[0].msa_att_col.eval().cuda()
msa_len = 300
seq_len = 300
m = torch.randn((msa_len, seq_len, 256)).cuda()
m_mask = torch.ones((msa_len, seq_len)).cuda().to(dtype=m.dtype)
m_out = m + target_module(m, mask=m_mask, chunk_size=None)
return m_out, m, m_mask, fast_module
@pytest.mark.parametrize('world_size', [1, 2])
@pytest.mark.parametrize('chunk_size', [None, 32])
def test_state_dict(world_size, chunk_size, get_openfold_module_and_data):
run_func = partial(_test_msa_att_col, world_size=world_size, chunk_size=chunk_size, get_openfold_module_and_data=get_openfold_module_and_data)
mp.spawn(run_func, nprocs=world_size)
def _test_msa_att_col(rank, world_size, chunk_size, get_openfold_module_and_data):
os.environ['RANK'] = str(rank)
os.environ['LOCAL_RANK'] = str(rank)
os.environ['WORLD_SIZE'] = str(world_size)
# init distributed for Dynamic Axial Parallelism
fastfold.distributed.init_dap()
m_out, m, m_mask, fast_module = get_openfold_module_and_data
fast_module = copy.deepcopy(fast_module).cuda()
fast_m = copy.deepcopy(m.cuda()).unsqueeze(0)
dap_size = gpc.get_world_size(ParallelMode.TENSOR)
seq_length = m_mask.cuda().size(-1)
padding_size = (int(seq_length / dap_size) + 1) * dap_size - seq_length
fast_m = torch.nn.functional.pad(fast_m, (0, 0, 0, padding_size))
fast_m = scatter(fast_m, dim=1)
fast_m_mask = copy.deepcopy(m_mask.cuda()).unsqueeze(0)
fast_m_mask = torch.nn.functional.pad(fast_m_mask, (0, padding_size))
with torch.no_grad():
set_chunk_size(chunk_size)
fast_m = row_to_col(fast_m)
fast_m_mask = scatter(fast_m_mask, dim=2)
m_fast = fast_module(fast_m, fast_m_mask)
m_fast = m_fast.squeeze(0)
m_fast = gather(m_fast, dim=1)
m_fast = m_fast[:, :-padding_size, :]
error = torch.max(torch.abs(m_out.cuda() - m_fast))
assert error < 5e-5, f"Test m failed at chunk size: {chunk_size}. The position dif is {error}"
import torch
import pytest
import os
import copy
import torch.multiprocessing as mp
from functools import partial
import fastfold
from fastfold.config import model_config
from fastfold.model.fastnn.ops import set_chunk_size
from fastfold.model.hub import AlphaFold
from fastfold.utils.inject_fastnn import inject_fastnn
from fastfold.utils.import_weights import import_jax_weights_
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from fastfold.utils.test_utils import get_param_path
from fastfold.distributed.comm import gather, scatter
@pytest.fixture(scope="module")
def get_openfold_module_and_data():
with torch.no_grad():
config = model_config('model_1')
config.globals.inplace = False
target_module = AlphaFold(config)
import_jax_weights_(target_module, get_param_path())
fast_module = copy.deepcopy(target_module)
fast_module = inject_fastnn(fast_module)
fast_module = fast_module.evoformer.blocks[0].msa.MSARowAttentionWithPairBias.eval().cuda()
target_module1 = target_module.evoformer.blocks[0].msa_att_row.eval().cuda()
target_module2 = target_module.evoformer.blocks[0].msa_dropout_layer.eval().cuda()
msa_len = 300
seq_len = 300
m = torch.randn((msa_len, seq_len, 256)).cuda()
m_mask = torch.ones((msa_len, seq_len)).cuda().to(dtype=m.dtype)
z = torch.randn((seq_len, seq_len, 128)).cuda()
z_mask = torch.ones((seq_len, seq_len)).cuda().to(dtype=z.dtype)
m_out = m + target_module2(target_module1(m, z=z, mask=m_mask, chunk_size=None))
return m_out, m, z, m_mask, z_mask, fast_module
@pytest.mark.parametrize('world_size', [1, 2])
@pytest.mark.parametrize('chunk_size', [None, 32])
def test_state_dict(world_size, chunk_size, get_openfold_module_and_data):
run_func = partial(_test_msa_att_row, world_size=world_size, chunk_size=chunk_size, get_openfold_module_and_data=get_openfold_module_and_data)
mp.spawn(run_func, nprocs=world_size)
def _test_msa_att_row(rank, world_size, chunk_size, get_openfold_module_and_data):
os.environ['RANK'] = str(rank)
os.environ['LOCAL_RANK'] = str(rank)
os.environ['WORLD_SIZE'] = str(world_size)
# init distributed for Dynamic Axial Parallelism
fastfold.distributed.init_dap()
m_out, m, z, m_mask, z_mask, fast_module = get_openfold_module_and_data
fast_module = copy.deepcopy(fast_module).cuda()
fast_m = copy.deepcopy(m.cuda()).unsqueeze(0)
fast_z = copy.deepcopy(z.cuda()).unsqueeze(0)
dap_size = gpc.get_world_size(ParallelMode.TENSOR)
seq_length = z_mask.cuda().size(-1)
padding_size = (int(seq_length / dap_size) + 1) * dap_size - seq_length
fast_m = torch.nn.functional.pad(fast_m, (0, 0, 0, padding_size))
fast_z = torch.nn.functional.pad(fast_z, (0, 0, 0, padding_size, 0, padding_size))
fast_m = scatter(fast_m, dim=1)
fast_z = scatter(fast_z, dim=1)
fast_m_mask = copy.deepcopy(m_mask.cuda()).unsqueeze(0)
fast_m_mask = torch.nn.functional.pad(fast_m_mask, (0, padding_size))
with torch.no_grad():
set_chunk_size(chunk_size)
fast_m_mask = scatter(fast_m_mask.cuda(), dim=1)
m_fast = fast_module(fast_m.cuda(), fast_z.cuda(), fast_m_mask)
m_fast = m_fast.squeeze(0)
m_fast = gather(m_fast, dim=0)
m_fast = m_fast[:, :-padding_size, :]
error = torch.max(torch.abs(m_out.cuda() - m_fast))
assert error < 5e-5, f"Test m failed at chunk size: {chunk_size}. The position dif is {error}"
import torch
import pytest
import os
import copy
import torch.multiprocessing as mp
from functools import partial
import fastfold
from fastfold.config import model_config
from fastfold.model.fastnn.ops import set_chunk_size
from fastfold.model.hub import AlphaFold
from fastfold.utils.inject_fastnn import inject_fastnn
from fastfold.utils.import_weights import import_jax_weights_
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from fastfold.distributed.comm import gather, scatter, row_to_col
from fastfold.utils.test_utils import get_param_path
@pytest.fixture(scope="module")
def get_openfold_module_and_data():
with torch.no_grad():
config = model_config('model_1')
config.globals.inplace = False
target_module = AlphaFold(config)
import_jax_weights_(target_module, get_param_path())
fast_module = copy.deepcopy(target_module)
fast_module = inject_fastnn(fast_module)
fast_module = fast_module.extra_msa_stack.blocks[0].msa_stack.MSAColumnAttention.eval().cuda()
target_module = target_module.extra_msa_stack.blocks[0].msa_att_col.eval().cuda()
msa_len = 512
seq_len = 128
m = torch.randn((msa_len, seq_len, 64)).cuda()
m_mask = torch.ones((msa_len, seq_len)).cuda().to(dtype=m.dtype)
m_mask[128:, :] = 0
m_out = m + target_module(m, mask=m_mask, chunk_size=None)
return m_out, m, m_mask, fast_module
@pytest.mark.parametrize('world_size', [1, 2])
@pytest.mark.parametrize('chunk_size', [None, 32])
def test_state_dict(world_size, chunk_size, get_openfold_module_and_data):
run_func = partial(_test_msa_global_att_col, world_size=world_size, chunk_size=chunk_size, get_openfold_module_and_data=get_openfold_module_and_data)
mp.spawn(run_func, nprocs=world_size)
def _test_msa_global_att_col(rank, world_size, chunk_size, get_openfold_module_and_data):
os.environ['RANK'] = str(rank)
os.environ['LOCAL_RANK'] = str(rank)
os.environ['WORLD_SIZE'] = str(world_size)
# init distributed for Dynamic Axial Parallelism
fastfold.distributed.init_dap()
m_out, m, m_mask, fast_module = get_openfold_module_and_data
fast_module = copy.deepcopy(fast_module).cuda()
fast_m = copy.deepcopy(m.cuda()).unsqueeze(0)
dap_size = gpc.get_world_size(ParallelMode.TENSOR)
seq_length = m_mask.cuda().size(-1)
padding_size = (int(seq_length / dap_size) + 1) * dap_size - seq_length
fast_m = torch.nn.functional.pad(fast_m, (0, 0, 0, padding_size))
fast_m = scatter(fast_m, dim=1)
fast_m_mask = copy.deepcopy(m_mask.cuda()).unsqueeze(0)
fast_m_mask = torch.nn.functional.pad(fast_m_mask, (0, padding_size))
with torch.no_grad():
set_chunk_size(chunk_size)
fast_m = row_to_col(fast_m)
fast_m_mask = scatter(fast_m_mask, dim=2)
m_fast = fast_module(fast_m, fast_m_mask)
m_fast = m_fast.squeeze(0)
m_fast = gather(m_fast, dim=1)
m_fast = m_fast[:, :-padding_size, :]
error = torch.max(torch.abs(m_out.cuda() - m_fast))
assert error < 5e-5, f"Test m failed at chunk size: {chunk_size}. The position dif is {error}"
import torch
import pytest
import copy
import fastfold
from fastfold.model.fastnn.ops import OutProductMean as FastOutProductMean, set_chunk_size
from fastfold.model.nn.outer_product_mean import OuterProductMean
import os
import torch.multiprocessing as mp
from functools import partial
from fastfold.model.fastnn.ops import set_chunk_size
from fastfold.utils.test_utils import get_param_path
from fastfold.model.hub import AlphaFold
from fastfold.utils.inject_fastnn import inject_fastnn
from fastfold.utils.import_weights import import_jax_weights_
from fastfold.config import model_config
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from fastfold.utils.test_utils import get_param_path
from fastfold.distributed import scatter, row_to_col
from fastfold.distributed.comm import gather, scatter
def test_out_product_mean():
fastfold.distributed.init_dap()
msa_len = 20
seq_len = 30
dim_m = 32
dim_z = 64
hidden = 16
fast_opm = FastOutProductMean(n_feat=dim_m, n_feat_out=dim_z, n_feat_proj=hidden).cuda()
opm = OuterProductMean(c_m=dim_m, c_z=dim_z, c_hidden=hidden).cuda()
fast_opm.linear_a.weight = opm.linear_1.weight
fast_opm.linear_a.bias = opm.linear_1.bias
fast_opm.linear_b.weight = opm.linear_2.weight
fast_opm.linear_b.bias = opm.linear_2.bias
fast_opm.o_linear.weight = opm.linear_out.weight
fast_opm.o_linear.bias = opm.linear_out.bias
m = torch.randn((1, msa_len, seq_len, dim_m)).cuda()
m_mask = torch.ones((1, msa_len, seq_len)).cuda()
m_mask[:, :, -5:] = 0
z = torch.zeros((1, seq_len, seq_len, dim_z)).cuda()
out = fast_opm(m, m_mask, z)
out_fast = opm(m, m_mask)
assert torch.allclose(out, out_fast, atol=1e-6)
@pytest.fixture(scope="module")
def get_openfold_module_and_data():
with torch.no_grad():
config = model_config('model_1')
config.globals.inplace = False
target_module = AlphaFold(config)
import_jax_weights_(target_module, get_param_path())
fast_module = copy.deepcopy(target_module)
fast_module = inject_fastnn(fast_module)
fast_module = fast_module.evoformer.blocks[0].communication.eval().cuda()
target_module = target_module.evoformer.blocks[0].core.outer_product_mean.eval().cuda()
msa_len = 20
seq_len = 30
m = torch.randn((msa_len, seq_len, 256)).cuda()
m_mask = torch.ones((msa_len, seq_len)).cuda()
m_mask[:, -5:] = 0
z = torch.zeros((seq_len, seq_len, 128)).cuda()
set_chunk_size(1)
out_fast = opm(m, m_mask)
assert torch.allclose(out, out_fast, atol=1e-6)
out = target_module(m, m_mask)
return m, m_mask, z, fast_module, out
out_fast = fast_opm.inplace(m, m_mask, [z])[0]
assert torch.allclose(out, out_fast, atol=1e-6)
@pytest.mark.parametrize('world_size', [1, 2])
@pytest.mark.parametrize('chunk_size', [None, 32])
@pytest.mark.parametrize('inplace', [False, True])
def test_state_dict(world_size, chunk_size, inplace, get_openfold_module_and_data):
run_func = partial(_test_out_product_mean, world_size=world_size, chunk_size=chunk_size,
inplace=inplace, get_openfold_module_and_data=get_openfold_module_and_data)
mp.spawn(run_func, nprocs=world_size)
if __name__ == "__main__":
test_out_product_mean()
def _test_out_product_mean(rank, world_size, chunk_size, inplace, get_openfold_module_and_data):
os.environ['RANK'] = str(rank)
os.environ['LOCAL_RANK'] = str(rank)
os.environ['WORLD_SIZE'] = str(world_size)
# init distributed for Dynamic Axial Parallelism
fastfold.distributed.init_dap()
m, m_mask, z, fast_module, out = get_openfold_module_and_data
fast_module = copy.deepcopy(fast_module).cuda()
fast_m = copy.deepcopy(m.cuda()).unsqueeze(0)
fast_z = copy.deepcopy(z.cuda()).unsqueeze(0)
dap_size = gpc.get_world_size(ParallelMode.TENSOR)
seq_length = m_mask.cuda().size(-1)
padding_size = (int(seq_length / dap_size) + 1) * dap_size - seq_length
fast_m = torch.nn.functional.pad(fast_m, (0, 0, 0, padding_size))
fast_z = torch.nn.functional.pad(fast_z, (0, 0, 0, padding_size, 0, padding_size))
fast_m = scatter(fast_m, dim=1)
fast_z = scatter(fast_z, dim=1)
fast_m_mask = copy.deepcopy(m_mask.cuda()).unsqueeze(0)
fast_m_mask = torch.nn.functional.pad(fast_m_mask, (0, padding_size))
with torch.no_grad():
set_chunk_size(chunk_size)
fast_m = row_to_col(fast_m)
if inplace:
out_fast = fast_module.inplace(fast_m, fast_m_mask, [fast_z])[0]
else:
out_fast = fast_module(fast_m, fast_m_mask, fast_z)
out_fast = out_fast.squeeze(0)
out_fast = gather(out_fast, dim=0)
out_fast = out_fast[:-padding_size, :-padding_size, :]
error = torch.mean(torch.abs(out.cuda() - out_fast))
assert error < 1e-5, f"Test failed at chunk size: {chunk_size}, inplace: {inplace}. The position dif is {error}"
import torch
import pytest
import pickle
import os
import copy
import torch.multiprocessing as mp
from functools import partial
import fastfold
from fastfold.config import model_config
from fastfold.model.fastnn.ops import set_chunk_size
from fastfold.model.hub import AlphaFold
from fastfold.utils.inject_fastnn import inject_fastnn
from fastfold.utils.import_weights import import_jax_weights_
from fastfold.utils.tensor_utils import tensor_tree_map
from fastfold.utils.test_utils import get_param_path, get_data_path
@pytest.fixture(scope="module")
def get_openfold_module_and_data():
with torch.no_grad():
config = model_config('model_1')
config.globals.inplace = False
target_module = AlphaFold(config)
import_jax_weights_(target_module, get_param_path())
fast_module = copy.deepcopy(target_module)
fast_module = inject_fastnn(fast_module)
fast_module = fast_module.template_embedder
fast_module = fast_module.eval().cuda()
target_module = target_module.template_embedder
target_module = target_module.eval().cuda()
batch = pickle.load(open(get_data_path(), 'rb'))
fetch_cur_batch = lambda t: t[..., 0]
feats = tensor_tree_map(fetch_cur_batch, batch)
feats = {k: v.cuda() for k, v in feats.items() if k.startswith("template_")}
seq_len = 33
z = torch.randn((seq_len, seq_len, 128)).cuda()
z_mask = torch.ones((seq_len, seq_len)).cuda().to(dtype=z.dtype)
template_embeds = target_module(copy.deepcopy(feats), z, z_mask.to(dtype=z.dtype), 0, None)
z_out = z + template_embeds["template_pair_embedding"]
return fast_module, z_out, feats, z, z_mask
@pytest.mark.parametrize('world_size', [1, 2])
@pytest.mark.parametrize('chunk_size', [None, 32])
@pytest.mark.parametrize('inplace', [False, True])
def test_state_dict(world_size, chunk_size, inplace, get_openfold_module_and_data):
run_func = partial(_test_template_embedder, world_size=world_size, chunk_size=chunk_size,
inplace=inplace, get_openfold_module_and_data=get_openfold_module_and_data)
mp.spawn(run_func, nprocs=world_size)
def _test_template_embedder(rank, world_size, chunk_size, inplace, get_openfold_module_and_data):
os.environ['RANK'] = str(rank)
os.environ['LOCAL_RANK'] = str(rank)
os.environ['WORLD_SIZE'] = str(world_size)
# init distributed for Dynamic Axial Parallelism
fastfold.distributed.init_dap()
fast_module, z_out, feats, z, z_mask = get_openfold_module_and_data
fast_module = copy.deepcopy(fast_module).cuda()
template_feats = copy.deepcopy(feats)
for k, v in template_feats.items():
template_feats[k] = v.cuda()
with torch.no_grad():
set_chunk_size(chunk_size)
if inplace:
template_embeds = fast_module(copy.deepcopy(template_feats), copy.deepcopy(z).cuda(), z_mask.to(dtype=z.dtype).cuda(), 0, chunk_size, inplace=inplace)
z_fast = template_embeds["template_pair_embedding"]
else:
template_embeds = fast_module(copy.deepcopy(template_feats), copy.deepcopy(z).cuda(), z_mask.to(dtype=z.dtype).cuda(), 0, chunk_size)
z_fast = z.cuda() + template_embeds["template_pair_embedding"]
error = torch.mean(torch.abs(z_out.cuda() - z_fast))
assert error < 5e-4, f"Test z failed at chunk size: {chunk_size}, inplace: {inplace}. The position dif is {error}"
......@@ -10,52 +10,66 @@ import fastfold
from fastfold.model.hub import AlphaFold
from fastfold.config import model_config
from fastfold.model.fastnn import set_chunk_size
from fastfold.utils import inject_fastnn
from fastfold.utils.inject_fastnn import inject_fastnn
from fastfold.utils.import_weights import import_jax_weights_
from fastfold.utils.test_utils import get_data_path, get_param_path
@pytest.fixture(scope="module")
def get_openfold_module_and_data():
config = model_config('model_1')
config.globals.inplace = False
model = AlphaFold(config)
import_jax_weights_(model, get_param_path())
model.eval().cuda()
batch = pickle.load(open(get_data_path(), 'rb'))
batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
with torch.no_grad():
out = model(batch)
fastmodel = copy.deepcopy(model)
fastmodel = inject_fastnn(fastmodel)
fastmodel.eval().cuda()
return model, out, fastmodel
@pytest.mark.parametrize('world_size', [1, 2])
@pytest.mark.parametrize('chunk_size', [None, 2])
@pytest.mark.parametrize('chunk_size', [None, 32])
@pytest.mark.parametrize('inplace', [False, True])
def test_state_dict(world_size, chunk_size, inplace):
run_func = partial(run_dist, world_size=world_size, chunk_size=chunk_size, inplace=inplace)
def test_state_dict(world_size, chunk_size, inplace, get_openfold_module_and_data):
run_func = partial(run_dist, world_size=world_size, chunk_size=chunk_size, inplace=inplace, model=get_openfold_module_and_data)
mp.spawn(run_func, nprocs=world_size)
def run_dist(rank, world_size, chunk_size, inplace):
def run_dist(rank, world_size, chunk_size, inplace, model):
os.environ['RANK'] = str(rank)
os.environ['LOCAL_RANK'] = str(rank)
os.environ['WORLD_SIZE'] = str(world_size)
# init distributed for Dynamic Axial Parallelism
fastfold.distributed.init_dap()
inference(chunk_size, inplace)
inference(chunk_size, inplace, model)
def inference(chunk_size, inplace, get_openfold_module_and_data):
def inference(chunk_size, inplace):
model, out, fastmodel = get_openfold_module_and_data
config = model_config('model_1')
config.globals.chunk_size = chunk_size
config.globals.inplace = False
model = AlphaFold(config)
import_jax_weights_(model, '/data/scratch/fastfold/weight.npz')
model.eval()
model.cuda()
model.globals.chunk_size = chunk_size
model.globals.inplace = inplace
fastmodel = copy.deepcopy(model)
fastmodel = inject_fastnn(fastmodel)
fastmodel.eval()
fastmodel.cuda()
fastmodel = copy.deepcopy(fastmodel).cuda()
fastmodel.structure_module.default_frames = fastmodel.structure_module.default_frames.cuda()
fastmodel.structure_module.group_idx = fastmodel.structure_module.group_idx.cuda()
fastmodel.structure_module.atom_mask = fastmodel.structure_module.atom_mask.cuda()
fastmodel.structure_module.lit_positions = fastmodel.structure_module.lit_positions.cuda()
set_chunk_size(model.globals.chunk_size)
batch = pickle.load(open('/data/scratch/fastfold/mono_batch.pkl', 'rb'))
batch = pickle.load(open(get_data_path(), 'rb'))
batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
fastbatch = copy.deepcopy(batch)
with torch.no_grad():
out = model(batch)
config.globals.inplace = inplace
fastout = fastmodel(fastbatch)
fastout = fastmodel(batch)
pos_dif = torch.max(torch.abs(fastout["final_atom_positions"] - out["final_atom_positions"]))
assert pos_dif < 1.5, f"Test failed at chunk size: {chunk_size}, inplace: {inplace}. The position dif is {pos_dif}"
pos_dif = torch.max(torch.abs(fastout["final_atom_positions"] - out["final_atom_positions"].cuda()))
assert pos_dif < 5e-4, f"Test failed at chunk size: {chunk_size}, inplace: {inplace}. The position dif is {pos_dif}"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment