import os
from functools import wraps
from typing import Callable

import numpy as np
import pytest
import torch
import torch.nn.functional as F
from e3nn import o3
from torch.testing import assert_close

from mace import data, modules, tools
from mace.tools import compile as mace_compile
from mace.tools import torch_geometric

table = tools.AtomicNumberTable([6])
atomic_energies = np.array([1.0], dtype=float)
cutoff = 5.0


def create_mace(device: str, seed: int = 1702):
    # build a small two-interaction ScaleShiftMACE model for a single element (carbon)
    torch_geometric.seed_everything(seed)

    model_config = {
        "r_max": cutoff,
        "num_bessel": 8,
        "num_polynomial_cutoff": 6,
        "max_ell": 3,
        "interaction_cls": modules.interaction_classes[
            "RealAgnosticResidualInteractionBlock"
        ],
        "interaction_cls_first": modules.interaction_classes[
            "RealAgnosticResidualInteractionBlock"
        ],
        "num_interactions": 2,
        "num_elements": 1,
        "hidden_irreps": o3.Irreps("128x0e + 128x1o"),
        "MLP_irreps": o3.Irreps("16x0e"),
        "gate": F.silu,
        "atomic_energies": atomic_energies,
        "avg_num_neighbors": 8,
        "atomic_numbers": table.zs,
        "correlation": 3,
        "radial_type": "bessel",
        "atomic_inter_scale": 1.0,
        "atomic_inter_shift": 0.0,
    }
    model = modules.ScaleShiftMACE(**model_config)
    return model.to(device)


def create_batch(device: str):
    # build a single-configuration batch from a periodic diamond supercell
    from ase import build

    size = 2
    atoms = build.bulk("C", "diamond", a=3.567, cubic=True)
    atoms_list = [atoms.repeat((size, size, size))]
    print("Number of atoms", len(atoms_list[0]))

    configs = [data.config_from_atoms(atoms) for atoms in atoms_list]
    data_loader = torch_geometric.dataloader.DataLoader(
        dataset=[
            data.AtomicData.from_config(config, z_table=table, cutoff=cutoff)
            for config in configs
        ],
        batch_size=1,
        shuffle=False,
        drop_last=False,
    )
    batch = next(iter(data_loader))
    batch = batch.to(device)
    batch = batch.to_dict()
    return batch


def time_func(func: Callable):
    # wrap a callable so each benchmark iteration marks a CUDA graph step
    # and synchronizes the device, ensuring GPU work is included in timings
    @wraps(func)
    def wrapper(*args, **kwargs):
        torch._inductor.cudagraph_mark_step_begin()  # pylint: disable=W0212
        outputs = func(*args, **kwargs)
        torch.cuda.synchronize()
        return outputs

    return wrapper


@pytest.fixture(params=[torch.float32, torch.float64], ids=["fp32", "fp64"])
def default_dtype(request):
    with tools.torch_tools.default_dtype(request.param):
        yield torch.get_default_dtype()


# skip if on windows
@pytest.mark.skipif(os.name == "nt", reason="Not supported on Windows")
@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_mace(device, default_dtype):  # pylint: disable=W0621
    # compiled and eager models must agree on energies and forces
    print(f"using default dtype = {default_dtype}")
    if device == "cuda" and not torch.cuda.is_available():
        pytest.skip(reason="cuda is not available")

    model_defaults = create_mace(device)
    tmp_model = mace_compile.prepare(create_mace)(device)
    model_compiled = torch.compile(tmp_model, mode="default")

    batch = create_batch(device)
    output1 = model_defaults(batch, training=True)
    output2 = model_compiled(batch, training=True)
    assert_close(output1["energy"], output2["energy"])
    assert_close(output1["forces"], output2["forces"])


@pytest.mark.skipif(os.name == "nt", reason="Not supported on Windows")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda is not available")
def test_eager_benchmark(benchmark, default_dtype):  # pylint: disable=W0621
    print(f"using default dtype = {default_dtype}")
    batch = create_batch("cuda")
    model = create_mace("cuda")
    model = time_func(model)
    benchmark(model, batch, training=True)


@pytest.mark.skipif(os.name == "nt", reason="Not supported on Windows")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda is not available")
@pytest.mark.parametrize("compile_mode", ["default", "reduce-overhead", "max-autotune"])
@pytest.mark.parametrize("enable_amp", [False, True], ids=["fp32", "mixed"])
def test_compile_benchmark(benchmark, compile_mode, enable_amp):
    if enable_amp:
        pytest.skip(reason="autocast compiler assertion aten.slice_scatter.default")

    with tools.torch_tools.default_dtype(torch.float32):
        batch = create_batch("cuda")
        torch.compiler.reset()
        model = mace_compile.prepare(create_mace)("cuda")
        model = torch.compile(model, mode=compile_mode)
        model = time_func(model)

        with torch.autocast("cuda", enabled=enable_amp):
            benchmark(model, batch, training=True)


@pytest.mark.skipif(os.name == "nt", reason="Not supported on Windows")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda is not available")
def test_graph_breaks():
    import torch._dynamo as dynamo

    batch = create_batch("cuda")
    model = mace_compile.prepare(create_mace)("cuda")
    explanation = dynamo.explain(model)(batch, training=False)

    # these clutter the output but might be useful for investigating graph breaks
    explanation.ops_per_graph = None
    explanation.out_guards = None
    print(explanation)
    assert explanation.graph_break_count == 0