From 2409a22f188e5aeb85439acc5bc948275d4e7815 Mon Sep 17 00:00:00 2001 From: fanding2000 Date: Mon, 1 Dec 2025 14:32:46 +0800 Subject: [PATCH] Format fix. More options in readme --- README.md | 224 +- basic_function/CSP_generator_normal.py | 1228 +++---- .../__pycache__/CSP_function.cpython-310.pyc | Bin 1078 -> 0 bytes .../__pycache__/CSP_function.cpython-311.pyc | Bin 1934 -> 0 bytes .../__pycache__/CSP_function.cpython-313.pyc | Bin 1729 -> 0 bytes .../__pycache__/CSP_function.cpython-38.pyc | Bin 1076 -> 0 bytes .../__pycache__/CSP_function.cpython-39.pyc | Bin 1092 -> 0 bytes .../CSP_generator_normal.cpython-310.pyc | Bin 19052 -> 0 bytes .../CSP_generator_normal.cpython-311.pyc | Bin 30681 -> 0 bytes .../CSP_generator_normal.cpython-313.pyc | Bin 23867 -> 0 bytes .../CSP_generator_normal.cpython-38.pyc | Bin 19038 -> 0 bytes .../CSP_generator_normal.cpython-39.pyc | Bin 19036 -> 0 bytes .../chemical_knowledge.cpython-310.pyc | Bin 10019 -> 0 bytes .../chemical_knowledge.cpython-311.pyc | Bin 14370 -> 0 bytes .../chemical_knowledge.cpython-313.pyc | Bin 13212 -> 0 bytes .../chemical_knowledge.cpython-38.pyc | Bin 8187 -> 0 bytes .../chemical_knowledge.cpython-39.pyc | Bin 7537 -> 0 bytes .../conformer_search.cpython-310.pyc | Bin 1992 -> 0 bytes .../conformer_search.cpython-311.pyc | Bin 3474 -> 0 bytes .../conformer_search.cpython-313.pyc | Bin 3058 -> 0 bytes .../conformer_search.cpython-38.pyc | Bin 1935 -> 0 bytes .../conformer_search.cpython-39.pyc | Bin 1992 -> 0 bytes .../__pycache__/data_classes.cpython-310.pyc | Bin 24657 -> 0 bytes .../__pycache__/data_classes.cpython-311.pyc | Bin 38043 -> 0 bytes .../__pycache__/data_classes.cpython-313.pyc | Bin 34167 -> 0 bytes .../__pycache__/data_classes.cpython-38.pyc | Bin 25263 -> 0 bytes .../__pycache__/data_classes.cpython-39.pyc | Bin 25184 -> 0 bytes .../__pycache__/descriptor.cpython-39.pyc | Bin 3807 -> 0 bytes .../__pycache__/format_parser.cpython-310.pyc | Bin 10361 -> 0 bytes .../__pycache__/format_parser.cpython-311.pyc | Bin 26332 -> 0 bytes .../__pycache__/format_parser.cpython-313.pyc | Bin 22219 -> 0 bytes .../__pycache__/format_parser.cpython-38.pyc | Bin 10320 -> 0 bytes .../__pycache__/format_parser.cpython-39.pyc | Bin 10303 -> 0 bytes .../__pycache__/operation.cpython-310.pyc | Bin 14168 -> 0 bytes .../__pycache__/operation.cpython-311.pyc | Bin 16658 -> 0 bytes .../__pycache__/operation.cpython-313.pyc | Bin 2860 -> 0 bytes .../__pycache__/operation.cpython-38.pyc | Bin 14170 -> 0 bytes .../__pycache__/operation.cpython-39.pyc | Bin 14138 -> 0 bytes .../__pycache__/operation_new.cpython-310.pyc | Bin 5811 -> 0 bytes .../__pycache__/operation_new.cpython-313.pyc | Bin 12016 -> 0 bytes .../__pycache__/operation_new.cpython-38.pyc | Bin 5759 -> 0 bytes .../__pycache__/operation_new.cpython-39.pyc | Bin 5839 -> 0 bytes .../__pycache__/others.cpython-310.pyc | Bin 1562 -> 0 bytes .../__pycache__/others.cpython-311.pyc | Bin 3270 -> 0 bytes .../__pycache__/others.cpython-313.pyc | Bin 2865 -> 0 bytes .../__pycache__/others.cpython-38.pyc | Bin 1590 -> 0 bytes .../__pycache__/others.cpython-39.pyc | Bin 1586 -> 0 bytes .../packaged_function.cpython-310.pyc | Bin 3446 -> 0 bytes .../packaged_function.cpython-311.pyc | Bin 5959 -> 0 bytes .../packaged_function.cpython-313.pyc | Bin 5214 -> 0 bytes .../packaged_function.cpython-38.pyc | Bin 3354 -> 0 bytes .../packaged_function.cpython-39.pyc | Bin 3351 -> 0 bytes .../unit_cell_parser.cpython-310.pyc | Bin 6600 -> 0 bytes .../unit_cell_parser.cpython-311.pyc | Bin 8999 -> 0 bytes .../unit_cell_parser.cpython-313.pyc | Bin 8352 -> 0 bytes .../unit_cell_parser.cpython-38.pyc | Bin 6702 -> 0 bytes .../unit_cell_parser.cpython-39.pyc | Bin 6659 -> 0 bytes basic_function/chemical_knowledge.py | 280 +- basic_function/conformer_search.py | 118 +- basic_function/data_classes.py | 1226 +++---- basic_function/format_parser.py | 664 ++-- basic_function/operation.py | 936 +++--- basic_function/packaged_function.py | 204 +- basic_function/unit_cell_parser.py | 456 +-- csp.sh | 73 +- .../3rdparty/SevenNet/sevenn/__init__.py | 26 +- .../__pycache__/__init__.cpython-310.pyc | Bin 482 -> 0 bytes .../sevenn/__pycache__/_const.cpython-310.pyc | Bin 7317 -> 0 bytes .../sevenn/__pycache__/_keys.cpython-310.pyc | Bin 6386 -> 0 bytes .../atom_graph_data.cpython-310.pyc | Bin 3003 -> 0 bytes .../__pycache__/calculator.cpython-310.pyc | Bin 20783 -> 0 bytes .../__pycache__/checkpoint.cpython-310.pyc | Bin 14780 -> 0 bytes .../__pycache__/model_build.cpython-310.pyc | Bin 12071 -> 0 bytes .../sevenn/__pycache__/util.cpython-310.pyc | Bin 9907 -> 0 bytes mace-bench/3rdparty/SevenNet/sevenn/_const.py | 620 ++-- mace-bench/3rdparty/SevenNet/sevenn/_keys.py | 452 +-- .../SevenNet/sevenn/atom_graph_data.py | 150 +- .../3rdparty/SevenNet/sevenn/calculator.py | 1692 +++++----- .../3rdparty/SevenNet/sevenn/checkpoint.py | 1104 +++---- .../SevenNet/sevenn/error_recorder.py | 860 ++--- mace-bench/3rdparty/SevenNet/sevenn/logger.py | 672 ++-- .../3rdparty/SevenNet/sevenn/main/sevenn.py | 496 +-- .../SevenNet/sevenn/main/sevenn_cp.py | 184 +- .../SevenNet/sevenn/main/sevenn_get_model.py | 140 +- .../sevenn/main/sevenn_graph_build.py | 260 +- .../SevenNet/sevenn/main/sevenn_inference.py | 258 +- .../sevenn/main/sevenn_patch_lammps.py | 110 +- .../SevenNet/sevenn/main/sevenn_preset.py | 90 +- .../3rdparty/SevenNet/sevenn/model_build.py | 1112 +++---- .../nn/__pycache__/__init__.cpython-310.pyc | Bin 179 -> 0 bytes .../nn/__pycache__/activation.cpython-310.pyc | Bin 449 -> 0 bytes .../__pycache__/convolution.cpython-310.pyc | Bin 4258 -> 0 bytes .../nn/__pycache__/cue_helper.cpython-310.pyc | Bin 5559 -> 0 bytes .../edge_embedding.cpython-310.pyc | Bin 6629 -> 0 bytes .../equivariant_gate.cpython-310.pyc | Bin 2580 -> 0 bytes .../__pycache__/force_output.cpython-310.pyc | Bin 5367 -> 0 bytes .../interaction_blocks.cpython-310.pyc | Bin 1851 -> 0 bytes .../nn/__pycache__/linear.cpython-310.pyc | Bin 4878 -> 0 bytes .../node_embedding.cpython-310.pyc | Bin 3017 -> 0 bytes .../nn/__pycache__/scale.cpython-310.pyc | Bin 11977 -> 0 bytes .../self_connection.cpython-310.pyc | Bin 3838 -> 0 bytes .../nn/__pycache__/sequential.cpython-310.pyc | Bin 6182 -> 0 bytes .../nn/__pycache__/util.cpython-310.pyc | Bin 535 -> 0 bytes .../3rdparty/SevenNet/sevenn/nn/activation.py | 16 +- .../SevenNet/sevenn/nn/convolution.py | 282 +- .../3rdparty/SevenNet/sevenn/nn/cue_helper.py | 378 +-- .../SevenNet/sevenn/nn/edge_embedding.py | 434 +-- .../SevenNet/sevenn/nn/equivariant_gate.py | 122 +- .../SevenNet/sevenn/nn/force_output.py | 448 +-- .../SevenNet/sevenn/nn/interaction_blocks.py | 152 +- .../3rdparty/SevenNet/sevenn/nn/linear.py | 360 +- .../SevenNet/sevenn/nn/node_embedding.py | 182 +- .../3rdparty/SevenNet/sevenn/nn/scale.py | 774 ++--- .../SevenNet/sevenn/nn/self_connection.py | 256 +- .../3rdparty/SevenNet/sevenn/nn/sequential.py | 366 +-- .../3rdparty/SevenNet/sevenn/nn/util.py | 34 +- .../sevenn/pair_e3gnn/patch_lammps.sh | 308 +- .../3rdparty/SevenNet/sevenn/parse_input.py | 492 +-- .../__pycache__/__init__.cpython-310.pyc | Bin 184 -> 0 bytes .../backward_compatibility.cpython-310.pyc | Bin 5093 -> 0 bytes .../sevenn/scripts/backward_compatibility.py | 368 +-- .../sevenn/scripts/convert_model_modality.py | 602 ++-- .../SevenNet/sevenn/scripts/deploy.py | 296 +- .../SevenNet/sevenn/scripts/graph_build.py | 238 +- .../SevenNet/sevenn/scripts/inference.py | 454 +-- .../sevenn/scripts/processing_continue.py | 546 +-- .../sevenn/scripts/processing_dataset.py | 962 +++--- .../sevenn/scripts/processing_epoch.py | 364 +- .../3rdparty/SevenNet/sevenn/scripts/train.py | 278 +- .../3rdparty/SevenNet/sevenn/sevenn_logger.py | 12 +- .../SevenNet/sevenn/sevennet_calculator.py | 12 +- .../__pycache__/__init__.cpython-310.pyc | Bin 182 -> 0 bytes .../__pycache__/dataload.cpython-310.pyc | Bin 14649 -> 0 bytes .../train/__pycache__/dataset.cpython-310.pyc | Bin 17074 -> 0 bytes .../SevenNet/sevenn/train/atoms_dataset.py | 628 ++-- .../3rdparty/SevenNet/sevenn/train/collate.py | 82 +- .../SevenNet/sevenn/train/dataload.py | 1218 +++---- .../3rdparty/SevenNet/sevenn/train/dataset.py | 992 +++--- .../SevenNet/sevenn/train/graph_dataset.py | 1414 ++++---- .../3rdparty/SevenNet/sevenn/train/loss.py | 446 +-- .../SevenNet/sevenn/train/modal_dataset.py | 730 ++--- .../3rdparty/SevenNet/sevenn/train/optim.py | 46 +- .../3rdparty/SevenNet/sevenn/train/trainer.py | 460 +-- mace-bench/3rdparty/SevenNet/sevenn/util.py | 660 ++-- .../data/inferences/snet0_on_hfo2/errors.txt | 12 +- .../SevenNet/tests/lammps_tests/conftest.py | 48 +- .../tests/lammps_tests/test_lammps.py | 934 +++--- .../tests/unit_tests/test_calculator.py | 434 +-- .../SevenNet/tests/unit_tests/test_cli.py | 466 +-- .../SevenNet/tests/unit_tests/test_cueq.py | 564 ++-- .../SevenNet/tests/unit_tests/test_data.py | 1042 +++--- .../SevenNet/tests/unit_tests/test_errors.py | 570 ++-- .../SevenNet/tests/unit_tests/test_modal.py | 272 +- .../SevenNet/tests/unit_tests/test_model.py | 426 +-- .../tests/unit_tests/test_pretrained.py | 688 ++-- .../tests/unit_tests/test_shift_scale.py | 988 +++--- .../SevenNet/tests/unit_tests/test_train.py | 804 ++--- mace-bench/3rdparty/mace/mace/__init__.py | 10 +- .../mace/__pycache__/__init__.cpython-310.pyc | Bin 285 -> 0 bytes .../mace/__pycache__/__init__.cpython-313.pyc | Bin 315 -> 0 bytes .../__pycache__/__version__.cpython-310.pyc | Bin 221 -> 0 bytes .../__pycache__/__version__.cpython-313.pyc | Bin 224 -> 0 bytes mace-bench/3rdparty/mace/mace/__version__.py | 6 +- .../mace/mace/calculators/__init__.py | 22 +- .../__pycache__/__init__.cpython-310.pyc | Bin 412 -> 0 bytes .../__pycache__/__init__.cpython-313.pyc | Bin 418 -> 0 bytes .../foundations_models.cpython-310.pyc | Bin 12018 -> 0 bytes .../foundations_models.cpython-313.pyc | Bin 15702 -> 0 bytes .../__pycache__/lammps_mace.cpython-310.pyc | Bin 2326 -> 0 bytes .../__pycache__/lammps_mace.cpython-313.pyc | Bin 4061 -> 0 bytes .../__pycache__/mace.cpython-310.pyc | Bin 16640 -> 0 bytes .../__pycache__/mace.cpython-313.pyc | Bin 29046 -> 0 bytes .../mace/calculators/foundations_models.py | 678 ++-- .../mace/mace/calculators/lammps_mace.py | 210 +- .../mace/calculators/lammps_mliap_mace.py | 428 +-- .../3rdparty/mace/mace/calculators/mace.py | 1408 ++++---- .../cli/__pycache__/__init__.cpython-310.pyc | Bin 178 -> 0 bytes .../cli/__pycache__/__init__.cpython-313.pyc | Bin 171 -> 0 bytes .../convert_e3nn_cueq.cpython-310.pyc | Bin 5946 -> 0 bytes .../convert_e3nn_cueq.cpython-313.pyc | Bin 8523 -> 0 bytes .../visualise_train.cpython-310.pyc | Bin 11810 -> 0 bytes .../visualise_train.cpython-313.pyc | Bin 22408 -> 0 bytes .../mace/mace/cli/active_learning_md.py | 386 +-- .../mace/mace/cli/convert_cueq_e3nn.py | 416 +-- .../3rdparty/mace/mace/cli/convert_device.py | 62 +- .../mace/mace/cli/convert_e3nn_cueq.py | 408 +-- .../mace/mace/cli/create_lammps_model.py | 228 +- .../3rdparty/mace/mace/cli/eval_configs.py | 330 +- .../mace/mace/cli/fine_tuning_select.py | 988 +++--- .../3rdparty/mace/mace/cli/plot_train.py | 684 ++-- .../3rdparty/mace/mace/cli/preprocess_data.py | 600 ++-- .../3rdparty/mace/mace/cli/run_train.py | 2014 ++++++------ .../3rdparty/mace/mace/cli/select_head.py | 120 +- .../3rdparty/mace/mace/cli/visualise_train.py | 1280 ++++---- .../3rdparty/mace/mace/data/__init__.py | 80 +- .../data/__pycache__/__init__.cpython-310.pyc | Bin 934 -> 0 bytes .../data/__pycache__/__init__.cpython-313.pyc | Bin 956 -> 0 bytes .../__pycache__/atomic_data.cpython-310.pyc | Bin 5568 -> 0 bytes .../__pycache__/atomic_data.cpython-313.pyc | Bin 14075 -> 0 bytes .../__pycache__/hdf5_dataset.cpython-310.pyc | Bin 3229 -> 0 bytes .../__pycache__/hdf5_dataset.cpython-313.pyc | Bin 5190 -> 0 bytes .../__pycache__/lmdb_dataset.cpython-310.pyc | Bin 2258 -> 0 bytes .../__pycache__/lmdb_dataset.cpython-313.pyc | Bin 4216 -> 0 bytes .../__pycache__/neighborhood.cpython-310.pyc | Bin 1590 -> 0 bytes .../__pycache__/neighborhood.cpython-313.pyc | Bin 2772 -> 0 bytes .../data/__pycache__/utils.cpython-310.pyc | Bin 10534 -> 0 bytes .../data/__pycache__/utils.cpython-313.pyc | Bin 17743 -> 0 bytes .../3rdparty/mace/mace/data/atomic_data.py | 600 ++-- .../3rdparty/mace/mace/data/hdf5_dataset.py | 194 +- .../3rdparty/mace/mace/data/lmdb_dataset.py | 138 +- .../3rdparty/mace/mace/data/neighborhood.py | 132 +- mace-bench/3rdparty/mace/mace/data/utils.py | 736 ++--- .../3rdparty/mace/mace/modules/__init__.py | 200 +- .../__pycache__/__init__.cpython-310.pyc | Bin 2297 -> 0 bytes .../__pycache__/__init__.cpython-313.pyc | Bin 2584 -> 0 bytes .../__pycache__/blocks.cpython-310.pyc | Bin 20740 -> 0 bytes .../__pycache__/blocks.cpython-313.pyc | Bin 39273 -> 0 bytes .../__pycache__/irreps_tools.cpython-310.pyc | Bin 3331 -> 0 bytes .../__pycache__/irreps_tools.cpython-313.pyc | Bin 5857 -> 0 bytes .../modules/__pycache__/loss.cpython-310.pyc | Bin 13402 -> 0 bytes .../modules/__pycache__/loss.cpython-313.pyc | Bin 27560 -> 0 bytes .../__pycache__/models.cpython-310.pyc | Bin 15897 -> 0 bytes .../__pycache__/models.cpython-313.pyc | Bin 30422 -> 0 bytes .../__pycache__/radial.cpython-310.pyc | Bin 10527 -> 0 bytes .../__pycache__/radial.cpython-313.pyc | Bin 19652 -> 0 bytes .../symmetric_contraction.cpython-310.pyc | Bin 6201 -> 0 bytes .../symmetric_contraction.cpython-313.pyc | Bin 11475 -> 0 bytes .../modules/__pycache__/utils.cpython-310.pyc | Bin 12677 -> 0 bytes .../modules/__pycache__/utils.cpython-313.pyc | Bin 25100 -> 0 bytes .../__pycache__/wrapper_ops.cpython-310.pyc | Bin 4673 -> 0 bytes .../__pycache__/wrapper_ops.cpython-313.pyc | Bin 7737 -> 0 bytes .../3rdparty/mace/mace/modules/blocks.py | 1844 +++++------ .../mace/mace/modules/irreps_tools.py | 232 +- mace-bench/3rdparty/mace/mace/modules/loss.py | 1132 +++---- .../3rdparty/mace/mace/modules/models.py | 1894 +++++------ .../3rdparty/mace/mace/modules/radial.py | 716 ++-- .../mace/modules/symmetric_contraction.py | 466 +-- .../3rdparty/mace/mace/modules/utils.py | 1164 +++---- .../3rdparty/mace/mace/modules/wrapper_ops.py | 384 +-- .../3rdparty/mace/mace/tools/__init__.py | 146 +- .../__pycache__/__init__.cpython-310.pyc | Bin 1507 -> 0 bytes .../__pycache__/__init__.cpython-313.pyc | Bin 1558 -> 0 bytes .../__pycache__/arg_parser.cpython-310.pyc | Bin 15563 -> 0 bytes .../__pycache__/arg_parser.cpython-313.pyc | Bin 24109 -> 0 bytes .../arg_parser_tools.cpython-310.pyc | Bin 2874 -> 0 bytes .../arg_parser_tools.cpython-313.pyc | Bin 6300 -> 0 bytes .../mace/tools/__pycache__/cg.cpython-310.pyc | Bin 5777 -> 0 bytes .../mace/tools/__pycache__/cg.cpython-313.pyc | Bin 10042 -> 0 bytes .../__pycache__/checkpoint.cpython-310.pyc | Bin 7529 -> 0 bytes .../__pycache__/checkpoint.cpython-313.pyc | Bin 11960 -> 0 bytes .../tools/__pycache__/compile.cpython-310.pyc | Bin 3061 -> 0 bytes .../tools/__pycache__/compile.cpython-313.pyc | Bin 3881 -> 0 bytes .../__pycache__/default_keys.cpython-310.pyc | Bin 874 -> 0 bytes .../__pycache__/default_keys.cpython-313.pyc | Bin 1106 -> 0 bytes .../finetuning_utils.cpython-310.pyc | Bin 3891 -> 0 bytes .../finetuning_utils.cpython-313.pyc | Bin 10867 -> 0 bytes .../tools/__pycache__/scatter.cpython-310.pyc | Bin 2874 -> 0 bytes .../tools/__pycache__/scatter.cpython-313.pyc | Bin 5246 -> 0 bytes .../__pycache__/scripts_utils.cpython-310.pyc | Bin 24603 -> 0 bytes .../__pycache__/scripts_utils.cpython-313.pyc | Bin 45514 -> 0 bytes .../__pycache__/torch_tools.cpython-310.pyc | Bin 4686 -> 0 bytes .../__pycache__/torch_tools.cpython-313.pyc | Bin 7647 -> 0 bytes .../tools/__pycache__/train.cpython-310.pyc | Bin 14147 -> 0 bytes .../tools/__pycache__/train.cpython-313.pyc | Bin 28206 -> 0 bytes .../tools/__pycache__/utils.cpython-310.pyc | Bin 6327 -> 0 bytes .../tools/__pycache__/utils.cpython-313.pyc | Bin 11054 -> 0 bytes .../3rdparty/mace/mace/tools/arg_parser.py | 1942 +++++------ .../mace/mace/tools/arg_parser_tools.py | 244 +- mace-bench/3rdparty/mace/mace/tools/cg.py | 422 +-- .../3rdparty/mace/mace/tools/checkpoint.py | 454 +-- .../3rdparty/mace/mace/tools/compile.py | 190 +- .../3rdparty/mace/mace/tools/default_keys.py | 42 +- .../mace/tools/fairchem_dataset/__init__.py | 6 +- .../__pycache__/__init__.cpython-310.pyc | Bin 274 -> 0 bytes .../__pycache__/__init__.cpython-313.pyc | Bin 276 -> 0 bytes .../lmdb_dataset_tools.cpython-310.pyc | Bin 26544 -> 0 bytes .../lmdb_dataset_tools.cpython-313.pyc | Bin 42257 -> 0 bytes .../fairchem_dataset/lmdb_dataset_tools.py | 1908 +++++------ .../mace/mace/tools/finetuning_utils.py | 408 +-- .../mace/mace/tools/model_script_utils.py | 530 +-- .../mace/mace/tools/multihead_tools.py | 400 +-- .../mace/mace/tools/run_train_utils.py | 434 +-- .../3rdparty/mace/mace/tools/scatter.py | 224 +- .../3rdparty/mace/mace/tools/scripts_utils.py | 1776 +++++----- .../mace/mace/tools/slurm_distributed.py | 80 +- .../3rdparty/mace/mace/tools/tables_utils.py | 492 +-- .../mace/tools/torch_geometric/__init__.py | 14 +- .../__pycache__/__init__.cpython-310.pyc | Bin 438 -> 0 bytes .../__pycache__/__init__.cpython-313.pyc | Bin 441 -> 0 bytes .../__pycache__/batch.cpython-310.pyc | Bin 6866 -> 0 bytes .../__pycache__/batch.cpython-313.pyc | Bin 12410 -> 0 bytes .../__pycache__/data.cpython-310.pyc | Bin 15844 -> 0 bytes .../__pycache__/data.cpython-313.pyc | Bin 22430 -> 0 bytes .../__pycache__/dataloader.cpython-310.pyc | Bin 3767 -> 0 bytes .../__pycache__/dataloader.cpython-313.pyc | Bin 4816 -> 0 bytes .../__pycache__/dataset.cpython-310.pyc | Bin 10260 -> 0 bytes .../__pycache__/dataset.cpython-313.pyc | Bin 15667 -> 0 bytes .../__pycache__/seed.cpython-310.pyc | Bin 631 -> 0 bytes .../__pycache__/seed.cpython-313.pyc | Bin 844 -> 0 bytes .../__pycache__/utils.cpython-310.pyc | Bin 1724 -> 0 bytes .../__pycache__/utils.cpython-313.pyc | Bin 2374 -> 0 bytes .../mace/mace/tools/torch_geometric/batch.py | 514 +-- .../mace/mace/tools/torch_geometric/data.py | 882 ++--- .../mace/tools/torch_geometric/dataloader.py | 174 +- .../mace/tools/torch_geometric/dataset.py | 560 ++-- .../mace/mace/tools/torch_geometric/seed.py | 34 +- .../mace/mace/tools/torch_geometric/utils.py | 108 +- .../3rdparty/mace/mace/tools/torch_tools.py | 306 +- mace-bench/3rdparty/mace/mace/tools/train.py | 1338 ++++---- mace-bench/3rdparty/mace/mace/tools/utils.py | 332 +- .../3rdparty/mace/scripts/eval_configs.py | 12 +- .../3rdparty/mace/scripts/preprocess_data.py | 12 +- .../3rdparty/mace/scripts/run_checks.sh | 18 +- mace-bench/3rdparty/mace/scripts/run_train.py | 12 +- mace-bench/3rdparty/mace/tests/__init__.py | 6 +- .../mace/tests/modules/test_radial.py | 190 +- .../3rdparty/mace/tests/test_benchmark.py | 242 +- .../3rdparty/mace/tests/test_calculator.py | 1378 ++++---- mace-bench/3rdparty/mace/tests/test_cg.py | 24 +- .../3rdparty/mace/tests/test_compile.py | 308 +- mace-bench/3rdparty/mace/tests/test_cueq.py | 362 +- mace-bench/3rdparty/mace/tests/test_data.py | 426 +-- .../mace/tests/test_finetuning_select.py | 328 +- .../3rdparty/mace/tests/test_foundations.py | 1024 +++--- .../3rdparty/mace/tests/test_hessian.py | 108 +- .../3rdparty/mace/tests/test_lmdb_database.py | 268 +- mace-bench/3rdparty/mace/tests/test_models.py | 748 ++--- .../3rdparty/mace/tests/test_modules.py | 536 +-- .../3rdparty/mace/tests/test_multifiles.py | 2058 ++++++------ .../3rdparty/mace/tests/test_preprocess.py | 412 +-- .../3rdparty/mace/tests/test_run_train.py | 2916 ++++++++--------- .../mace/tests/test_run_train_allkeys.py | 936 +++--- .../3rdparty/mace/tests/test_schedulefree.py | 254 +- mace-bench/3rdparty/mace/tests/test_tools.py | 96 +- mace-bench/reproduce/init_7net.sh | 20 +- mace-bench/reproduce/init_mace.sh | 26 +- mace-bench/reproduce/mace_opt_new.py | 598 ++-- mace-bench/reproduce/mace_opt_origin.py | 592 ++-- mace-bench/reproduce/perf_v2_base/run_mace.sh | 7 +- mace-bench/reproduce/perf_v2_batch/opt.sh | 9 +- mace-bench/reproduce/subtest.sh | 48 +- mace-bench/reproduce/subtest_baseline.sh | 46 +- mace-bench/requirements.txt | 272 +- mace-bench/scripts/mace_opt_batch.py | 224 +- mace-bench/setup.py | 44 +- mace-bench/src/BOMLIP_CSP.egg-info/PKG-INFO | 46 +- .../src/BOMLIP_CSP.egg-info/SOURCES.txt | 38 +- .../BOMLIP_CSP.egg-info/dependency_links.txt | 2 +- .../src/BOMLIP_CSP.egg-info/top_level.txt | 2 +- mace-bench/src/batchopt/__init__.py | 58 +- .../__pycache__/__init__.cpython-310.pyc | Bin 894 -> 0 bytes .../atoms_to_graphs.cpython-310.pyc | Bin 10226 -> 0 bytes .../__pycache__/baseline.cpython-310.pyc | Bin 4440 -> 0 bytes .../__pycache__/pbc_graph.cpython-310.pyc | Bin 3374 -> 0 bytes .../pbc_graph_legacy.cpython-310.pyc | Bin 9723 -> 0 bytes .../__pycache__/relaxengine.cpython-310.pyc | Bin 27428 -> 0 bytes .../__pycache__/utils.cpython-310.pyc | Bin 3439 -> 0 bytes mace-bench/src/batchopt/atoms_to_graphs.py | 618 ++-- mace-bench/src/batchopt/baseline.py | 340 +- .../src/batchopt/extensions/__init__.py | 24 +- .../__pycache__/__init__.cpython-310.pyc | Bin 558 -> 0 bytes .../batchopt/extensions/cuda_ops/__init__.py | 182 +- .../__pycache__/__init__.cpython-310.pyc | Bin 2301 -> 0 bytes mace-bench/src/batchopt/pbc_graph.py | 314 +- mace-bench/src/batchopt/pbc_graph_legacy.py | 1126 +++---- .../src/batchopt/relaxation/__init__.py | 22 +- .../__pycache__/__init__.cpython-310.pyc | Bin 542 -> 0 bytes .../__pycache__/ase_utils.cpython-310.pyc | Bin 2889 -> 0 bytes .../__pycache__/optimizable.cpython-310.pyc | Bin 22681 -> 0 bytes .../src/batchopt/relaxation/ase_utils.py | 190 +- .../src/batchopt/relaxation/optimizable.py | 1582 ++++----- .../relaxation/optimizers/__init__.py | 24 +- .../__pycache__/__init__.cpython-310.pyc | Bin 540 -> 0 bytes .../__pycache__/bfgs_torch.cpython-310.pyc | Bin 7588 -> 0 bytes .../__pycache__/bfgsfusedls.cpython-310.pyc | Bin 19795 -> 0 bytes .../bfgslinesearch_torch.cpython-310.pyc | Bin 7251 -> 0 bytes .../__pycache__/lbfgs_torch.cpython-310.pyc | Bin 5951 -> 0 bytes .../linesearch_torch.cpython-310.pyc | Bin 10077 -> 0 bytes .../relaxation/optimizers/bfgs_torch.py | 570 ++-- .../relaxation/optimizers/bfgsfusedls.py | 1986 +++++------ mace-bench/src/batchopt/relaxengine.py | 2866 ++++++++-------- mace-bench/src/batchopt/utils.py | 240 +- mace-bench/util/env.sh | 11 +- mace-bench/util/mps_clean.sh | 19 +- mace-bench/util/mps_start.sh | 18 +- main.py | 199 +- post_process/check_match.py | 306 +- post_process/clean_table.py | 96 +- post_process/duplicate_remove.py | 492 +-- post_process/run_remove.sh | 81 +- 390 files changed, 48240 insertions(+), 48217 deletions(-) delete mode 100644 basic_function/__pycache__/CSP_function.cpython-310.pyc delete mode 100644 basic_function/__pycache__/CSP_function.cpython-311.pyc delete mode 100644 basic_function/__pycache__/CSP_function.cpython-313.pyc delete mode 100644 basic_function/__pycache__/CSP_function.cpython-38.pyc delete mode 100644 basic_function/__pycache__/CSP_function.cpython-39.pyc delete mode 100644 basic_function/__pycache__/CSP_generator_normal.cpython-310.pyc delete mode 100644 basic_function/__pycache__/CSP_generator_normal.cpython-311.pyc delete mode 100644 basic_function/__pycache__/CSP_generator_normal.cpython-313.pyc delete mode 100644 basic_function/__pycache__/CSP_generator_normal.cpython-38.pyc delete mode 100644 basic_function/__pycache__/CSP_generator_normal.cpython-39.pyc delete mode 100644 basic_function/__pycache__/chemical_knowledge.cpython-310.pyc delete mode 100644 basic_function/__pycache__/chemical_knowledge.cpython-311.pyc delete mode 100644 basic_function/__pycache__/chemical_knowledge.cpython-313.pyc delete mode 100644 basic_function/__pycache__/chemical_knowledge.cpython-38.pyc delete mode 100644 basic_function/__pycache__/chemical_knowledge.cpython-39.pyc delete mode 100644 basic_function/__pycache__/conformer_search.cpython-310.pyc delete mode 100644 basic_function/__pycache__/conformer_search.cpython-311.pyc delete mode 100644 basic_function/__pycache__/conformer_search.cpython-313.pyc delete mode 100644 basic_function/__pycache__/conformer_search.cpython-38.pyc delete mode 100644 basic_function/__pycache__/conformer_search.cpython-39.pyc delete mode 100644 basic_function/__pycache__/data_classes.cpython-310.pyc delete mode 100644 basic_function/__pycache__/data_classes.cpython-311.pyc delete mode 100644 basic_function/__pycache__/data_classes.cpython-313.pyc delete mode 100644 basic_function/__pycache__/data_classes.cpython-38.pyc delete mode 100644 basic_function/__pycache__/data_classes.cpython-39.pyc delete mode 100644 basic_function/__pycache__/descriptor.cpython-39.pyc delete mode 100644 basic_function/__pycache__/format_parser.cpython-310.pyc delete mode 100644 basic_function/__pycache__/format_parser.cpython-311.pyc delete mode 100644 basic_function/__pycache__/format_parser.cpython-313.pyc delete mode 100644 basic_function/__pycache__/format_parser.cpython-38.pyc delete mode 100644 basic_function/__pycache__/format_parser.cpython-39.pyc delete mode 100644 basic_function/__pycache__/operation.cpython-310.pyc delete mode 100644 basic_function/__pycache__/operation.cpython-311.pyc delete mode 100644 basic_function/__pycache__/operation.cpython-313.pyc delete mode 100644 basic_function/__pycache__/operation.cpython-38.pyc delete mode 100644 basic_function/__pycache__/operation.cpython-39.pyc delete mode 100644 basic_function/__pycache__/operation_new.cpython-310.pyc delete mode 100644 basic_function/__pycache__/operation_new.cpython-313.pyc delete mode 100644 basic_function/__pycache__/operation_new.cpython-38.pyc delete mode 100644 basic_function/__pycache__/operation_new.cpython-39.pyc delete mode 100644 basic_function/__pycache__/others.cpython-310.pyc delete mode 100644 basic_function/__pycache__/others.cpython-311.pyc delete mode 100644 basic_function/__pycache__/others.cpython-313.pyc delete mode 100644 basic_function/__pycache__/others.cpython-38.pyc delete mode 100644 basic_function/__pycache__/others.cpython-39.pyc delete mode 100644 basic_function/__pycache__/packaged_function.cpython-310.pyc delete mode 100644 basic_function/__pycache__/packaged_function.cpython-311.pyc delete mode 100644 basic_function/__pycache__/packaged_function.cpython-313.pyc delete mode 100644 basic_function/__pycache__/packaged_function.cpython-38.pyc delete mode 100644 basic_function/__pycache__/packaged_function.cpython-39.pyc delete mode 100644 basic_function/__pycache__/unit_cell_parser.cpython-310.pyc delete mode 100644 basic_function/__pycache__/unit_cell_parser.cpython-311.pyc delete mode 100644 basic_function/__pycache__/unit_cell_parser.cpython-313.pyc delete mode 100644 basic_function/__pycache__/unit_cell_parser.cpython-38.pyc delete mode 100644 basic_function/__pycache__/unit_cell_parser.cpython-39.pyc delete mode 100644 mace-bench/3rdparty/SevenNet/sevenn/__pycache__/__init__.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/SevenNet/sevenn/__pycache__/_const.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/SevenNet/sevenn/__pycache__/_keys.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/SevenNet/sevenn/__pycache__/atom_graph_data.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/SevenNet/sevenn/__pycache__/calculator.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/SevenNet/sevenn/__pycache__/checkpoint.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/SevenNet/sevenn/__pycache__/model_build.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/SevenNet/sevenn/__pycache__/util.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/__init__.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/activation.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/convolution.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/cue_helper.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/edge_embedding.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/equivariant_gate.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/force_output.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/interaction_blocks.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/linear.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/node_embedding.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/scale.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/self_connection.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/sequential.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/util.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/SevenNet/sevenn/scripts/__pycache__/__init__.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/SevenNet/sevenn/scripts/__pycache__/backward_compatibility.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/SevenNet/sevenn/train/__pycache__/__init__.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/SevenNet/sevenn/train/__pycache__/dataload.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/SevenNet/sevenn/train/__pycache__/dataset.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/__pycache__/__init__.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/__pycache__/__init__.cpython-313.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/__pycache__/__version__.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/__pycache__/__version__.cpython-313.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/calculators/__pycache__/__init__.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/calculators/__pycache__/__init__.cpython-313.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/calculators/__pycache__/foundations_models.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/calculators/__pycache__/foundations_models.cpython-313.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/calculators/__pycache__/lammps_mace.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/calculators/__pycache__/lammps_mace.cpython-313.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/calculators/__pycache__/mace.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/calculators/__pycache__/mace.cpython-313.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/cli/__pycache__/__init__.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/cli/__pycache__/__init__.cpython-313.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/cli/__pycache__/convert_e3nn_cueq.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/cli/__pycache__/convert_e3nn_cueq.cpython-313.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/cli/__pycache__/visualise_train.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/cli/__pycache__/visualise_train.cpython-313.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/data/__pycache__/__init__.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/data/__pycache__/__init__.cpython-313.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/data/__pycache__/atomic_data.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/data/__pycache__/atomic_data.cpython-313.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/data/__pycache__/hdf5_dataset.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/data/__pycache__/hdf5_dataset.cpython-313.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/data/__pycache__/lmdb_dataset.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/data/__pycache__/lmdb_dataset.cpython-313.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/data/__pycache__/neighborhood.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/data/__pycache__/neighborhood.cpython-313.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/data/__pycache__/utils.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/data/__pycache__/utils.cpython-313.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/modules/__pycache__/__init__.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/modules/__pycache__/__init__.cpython-313.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/modules/__pycache__/blocks.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/modules/__pycache__/blocks.cpython-313.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/modules/__pycache__/irreps_tools.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/modules/__pycache__/irreps_tools.cpython-313.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/modules/__pycache__/loss.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/modules/__pycache__/loss.cpython-313.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/modules/__pycache__/models.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/modules/__pycache__/models.cpython-313.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/modules/__pycache__/radial.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/modules/__pycache__/radial.cpython-313.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/modules/__pycache__/symmetric_contraction.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/modules/__pycache__/symmetric_contraction.cpython-313.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/modules/__pycache__/utils.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/modules/__pycache__/utils.cpython-313.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/modules/__pycache__/wrapper_ops.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/modules/__pycache__/wrapper_ops.cpython-313.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/tools/__pycache__/__init__.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/tools/__pycache__/__init__.cpython-313.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/tools/__pycache__/arg_parser.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/tools/__pycache__/arg_parser.cpython-313.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/tools/__pycache__/arg_parser_tools.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/tools/__pycache__/arg_parser_tools.cpython-313.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/tools/__pycache__/cg.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/tools/__pycache__/cg.cpython-313.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/tools/__pycache__/checkpoint.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/tools/__pycache__/checkpoint.cpython-313.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/tools/__pycache__/compile.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/tools/__pycache__/compile.cpython-313.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/tools/__pycache__/default_keys.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/tools/__pycache__/default_keys.cpython-313.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/tools/__pycache__/finetuning_utils.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/tools/__pycache__/finetuning_utils.cpython-313.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/tools/__pycache__/scatter.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/tools/__pycache__/scatter.cpython-313.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/tools/__pycache__/scripts_utils.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/tools/__pycache__/scripts_utils.cpython-313.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/tools/__pycache__/torch_tools.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/tools/__pycache__/torch_tools.cpython-313.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/tools/__pycache__/train.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/tools/__pycache__/train.cpython-313.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/tools/__pycache__/utils.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/tools/__pycache__/utils.cpython-313.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/tools/fairchem_dataset/__pycache__/__init__.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/tools/fairchem_dataset/__pycache__/__init__.cpython-313.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/tools/fairchem_dataset/__pycache__/lmdb_dataset_tools.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/tools/fairchem_dataset/__pycache__/lmdb_dataset_tools.cpython-313.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/tools/torch_geometric/__pycache__/__init__.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/tools/torch_geometric/__pycache__/__init__.cpython-313.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/tools/torch_geometric/__pycache__/batch.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/tools/torch_geometric/__pycache__/batch.cpython-313.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/tools/torch_geometric/__pycache__/data.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/tools/torch_geometric/__pycache__/data.cpython-313.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/tools/torch_geometric/__pycache__/dataloader.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/tools/torch_geometric/__pycache__/dataloader.cpython-313.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/tools/torch_geometric/__pycache__/dataset.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/tools/torch_geometric/__pycache__/dataset.cpython-313.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/tools/torch_geometric/__pycache__/seed.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/tools/torch_geometric/__pycache__/seed.cpython-313.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/tools/torch_geometric/__pycache__/utils.cpython-310.pyc delete mode 100644 mace-bench/3rdparty/mace/mace/tools/torch_geometric/__pycache__/utils.cpython-313.pyc delete mode 100644 mace-bench/src/batchopt/__pycache__/__init__.cpython-310.pyc delete mode 100644 mace-bench/src/batchopt/__pycache__/atoms_to_graphs.cpython-310.pyc delete mode 100644 mace-bench/src/batchopt/__pycache__/baseline.cpython-310.pyc delete mode 100644 mace-bench/src/batchopt/__pycache__/pbc_graph.cpython-310.pyc delete mode 100644 mace-bench/src/batchopt/__pycache__/pbc_graph_legacy.cpython-310.pyc delete mode 100644 mace-bench/src/batchopt/__pycache__/relaxengine.cpython-310.pyc delete mode 100644 mace-bench/src/batchopt/__pycache__/utils.cpython-310.pyc delete mode 100644 mace-bench/src/batchopt/extensions/__pycache__/__init__.cpython-310.pyc delete mode 100644 mace-bench/src/batchopt/extensions/cuda_ops/__pycache__/__init__.cpython-310.pyc delete mode 100644 mace-bench/src/batchopt/relaxation/__pycache__/__init__.cpython-310.pyc delete mode 100644 mace-bench/src/batchopt/relaxation/__pycache__/ase_utils.cpython-310.pyc delete mode 100644 mace-bench/src/batchopt/relaxation/__pycache__/optimizable.cpython-310.pyc delete mode 100644 mace-bench/src/batchopt/relaxation/optimizers/__pycache__/__init__.cpython-310.pyc delete mode 100644 mace-bench/src/batchopt/relaxation/optimizers/__pycache__/bfgs_torch.cpython-310.pyc delete mode 100644 mace-bench/src/batchopt/relaxation/optimizers/__pycache__/bfgsfusedls.cpython-310.pyc delete mode 100644 mace-bench/src/batchopt/relaxation/optimizers/__pycache__/bfgslinesearch_torch.cpython-310.pyc delete mode 100644 mace-bench/src/batchopt/relaxation/optimizers/__pycache__/lbfgs_torch.cpython-310.pyc delete mode 100644 mace-bench/src/batchopt/relaxation/optimizers/__pycache__/linesearch_torch.cpython-310.pyc diff --git a/README.md b/README.md index de49402..3882afe 100644 --- a/README.md +++ b/README.md @@ -1,103 +1,123 @@ -# BOMLIP-CSP - -An open-source Python framework that integrates machine learning interatomic -potentials (MLIPs) with a tailored batched optimization strategy, enabling rapid, -unbiased structure prediction across the full density range - -## Perform the complete CSP process - -```sh -git clone https://github.com/pic-ai-robotic-chemistry/BOMLIP-CSP.git --recursive && cd BOMLIP-CSP - -conda create -n BOMLIP_CSP python=3.10 -y && conda activate BOMLIP_CSP -cd BOMLIP-CSP -top_dir=$(pwd) -cd $top_dir/mace-bench -./reproduce/init_mace.sh && source util/env.sh -sudo ./util/mps_start.sh - -cd $top_dir -./csp.sh - -sudo ./util/mps_clean.sh -``` -## Reproduce mace batch opt speedup. - -```sh -#!/bin/bash - -git clone https://github.com/pic-ai-robotic-chemistry/BOMLIP-CSP.git --recursive && cd BOMLIP-CSP -conda create -n BOMLIP_CSP python=3.10 -y && conda activate BOMLIP_CSP -cd BOMLIP-CSP/mace-bench - -# initialize mace env. -./reproduce/init_mace.sh && source util/env.sh -sudo ./util/mps_start.sh -cd reproduce - -# run baseline sub-test -./subtest_baseline.sh - -# run baseline mixed test -cd perf_v2_base -./run_mace.sh - -# run BOMLIP_CSP sub-test -cd ../ -./subtest.sh - -# run BOMLIP_CSP mixed test -cd perf_v2_batch -./opt.sh - -# clean mps -./util/mps_clean.sh - -``` - -## If you want to configure the 7net environment. - -```sh -#!/bin/bash -conda create -n 7net-cueq python=3.10 -y && conda activate 7net-cueq -./reproduce/init_7net.sh && source util/env.sh - -# Use a fixed batch size for structural optimization -python ../../scripts/mace_opt_batch.py --target_folder "../../data/perf_v2" \ - --molecule_single 46 --gpu_offset 0 --n_gpus 4 --num_workers 4 \ - --batch_size 2 --max_steps 3000 --filter1 UnitCellFilter \ - --filter2 UnitCellFilter --optimizer1 BFGSFusedLS --optimizer2 BFGS \ - --num_threads 2 --cueq true --use_ordered_files true --model sevennet -``` - -## License - -This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details. - -### Third-party Dependencies - -This project includes dependencies with various licenses: -- **MACE**: MIT License (compatible) -- **FairChem**: MIT License (compatible) -- **SevenNet**: GPL v3 License (Note: GPL is a copyleft license) - -### License Compatibility Notice - -**Important**: This project can run completely without relying on SevenNet. -This project includes SevenNet as an optional dependency, which is licensed under GPL v3. -If you use SevenNet functionality, you should be aware of the GPL licensing requirements. -For commercial use or to avoid GPL restrictions, consider using only the MACE calculator -functionality. - -## Citation - -If you use this code in your research, please cite: - -```bibtex -@software{BOMLIP_CSP, - author = {Chengxi Zhao, Zhaojia Ma, Dingrui Fan}, - title = {BOMLIP_CSP: Integrating machine learning interatomic potentials with batched optimization for crystal structure prediction}, - year = {2025}, - url = {https://github.com/pic-ai-robotic-chemistry/BOMLIP-CSP} -} +# BOMLIP-CSP + +An open-source Python framework that integrates machine learning interatomic +potentials (MLIPs) with a tailored batched optimization strategy, enabling rapid, +unbiased structure prediction across the full density range + +## Perform a complete CSP process + +```sh +git clone https://github.com/pic-ai-robotic-chemistry/BOMLIP-CSP.git --recursive && cd BOMLIP-CSP + +conda create -n BOMLIP_CSP python=3.10 -y && conda activate BOMLIP_CSP +cd BOMLIP-CSP/mace-bench +./reproduce/init_mace.sh && source util/env.sh +sudo ./util/mps_start.sh + +cd .. +./csp.sh + +sudo ./util/mps_clean.sh +``` + +## Perform conformer search / structure generation / structure optimization separately + +In csp.sh, the argument --mode controls the jobs to do. +Use conformer_only to perform conformer search task only. +```sh +python "${TOP_DIR}/main.py" --path ${TAR_DIR} --smiles "OC(=O)c1cc(O)c(O)c(O)c1.O" \ + --molecule_num_in_cell 1,1 --space_group_list 13,14 --add_name KONTIQ --max_workers 16\ + --num_generation 100 --generate_conformers 20 --use_conformers 4 --mode conformer_only > generate.log 2>&1 +``` +Or use structure_only to perform structure generation only. +```sh +python "${TOP_DIR}/main.py" --path ${TAR_DIR} --smiles "OC(=O)c1cc(O)c(O)c(O)c1.O" \ + --molecule_num_in_cell 1,1 --space_group_list 13,14 --add_name KONTIQ --max_workers 16\ + --num_generation 100 --generate_conformers 20 --use_conformers 4 --mode structure_only > generate.log 2>&1 +``` +Structure optimization is done by a seperate command +```sh +python "${TOP_DIR}/mace-bench/scripts/mace_opt_batch.py" ... +``` +Change this command into a comment if you don't want to do that. + +## Reproduce mace batch opt speedup. + +```sh +#!/bin/bash + +git clone https://github.com/pic-ai-robotic-chemistry/BOMLIP-CSP.git --recursive && cd BOMLIP-CSP +conda create -n BOMLIP_CSP python=3.10 -y && conda activate BOMLIP_CSP +cd BOMLIP-CSP/mace-bench + +# initialize mace env. +./reproduce/init_mace.sh && source util/env.sh +sudo ./util/mps_start.sh +cd reproduce + +# run baseline sub-test +./subtest_baseline.sh + +# run baseline mixed test +cd perf_v2_base +./run_mace.sh + +# run BOMLIP_CSP sub-test +cd ../ +./subtest.sh + +# run BOMLIP_CSP mixed test +cd perf_v2_batch +./opt.sh + +# clean mps +./util/mps_clean.sh + +``` + +## If you want to configure the 7net environment. + +```sh +#!/bin/bash +conda create -n 7net-cueq python=3.10 -y && conda activate 7net-cueq +./reproduce/init_7net.sh && source util/env.sh + +# Use a fixed batch size for structural optimization +python ../../scripts/mace_opt_batch.py --target_folder "../../data/perf_v2" \ + --molecule_single 46 --gpu_offset 0 --n_gpus 4 --num_workers 4 \ + --batch_size 2 --max_steps 3000 --filter1 UnitCellFilter \ + --filter2 UnitCellFilter --optimizer1 BFGSFusedLS --optimizer2 BFGS \ + --num_threads 2 --cueq true --use_ordered_files true --model sevennet +``` + +## License + +This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details. + +### Third-party Dependencies + +This project includes dependencies with various licenses: +- **MACE**: MIT License (compatible) +- **FairChem**: MIT License (compatible) +- **SevenNet**: GPL v3 License (Note: GPL is a copyleft license) + +### License Compatibility Notice + +**Important**: This project can run completely without relying on SevenNet. +This project includes SevenNet as an optional dependency, which is licensed under GPL v3. +If you use SevenNet functionality, you should be aware of the GPL licensing requirements. +For commercial use or to avoid GPL restrictions, consider using only the MACE calculator +functionality. + +## Citation + +If you use this code in your research, please cite: + +```bibtex +@software{BOMLIP_CSP, + author = {Chengxi Zhao, Zhaojia Ma, Dingrui Fan}, + title = {BOMLIP_CSP: Integrating machine learning interatomic potentials with batched optimization for crystal structure prediction}, + year = {2025}, + url = {https://github.com/pic-ai-robotic-chemistry/BOMLIP-CSP} +} ``` \ No newline at end of file diff --git a/basic_function/CSP_generator_normal.py b/basic_function/CSP_generator_normal.py index 146cc70..3c70e18 100644 --- a/basic_function/CSP_generator_normal.py +++ b/basic_function/CSP_generator_normal.py @@ -1,615 +1,615 @@ -""" -This module provides the CrystalGenerator class for crystal structure prediction (CSP). - -It uses a Sobol sequence-based random search to generate candidate crystal -structures for a given set of molecules and space group, followed by a crude -packing minimization. -""" - -# Standard library imports -import itertools -from typing import List, Tuple, Optional, Any - -# Third-party imports -import numpy as np -from scipy.spatial import cKDTree -from scipy.stats import qmc - -# Local application/library specific imports -from basic_function import chemical_knowledge -from basic_function import operation -from basic_function import data_classes - -# Module-level constants for better readability and maintenance -_VDW_CLASH_FACTOR = 0.9 # Scaling factor for van der Waals radii in collision checks -_SUPERCELL_RANGE = np.arange(-2, 3) # Range for generating supercell translations - - -class CrystalGenerator: - """ - Generates candidate crystal structures for Crystal Structure Prediction (CSP). - - The generator takes a list of unique molecules and a space group, then searches - the conformational space of cell parameters and molecular orientations to - produce tightly packed, sterically plausible crystal structures. - """ - - def __init__(self, - molecules: list[data_classes.Molecule], - space_group: int = 1, - angles: tuple[float, float] = (45.0, 135.0)): - """ - Initializes the CrystalGenerator. - - Args: - molecules: A list of molecule objects (from data_classes) that will form - the asymmetric unit. - space_group: The international space group number (e.g., 1 for P1). - angles: A tuple (min, max) defining the range for sampling cell angles in degrees. - """ - if not (0 < space_group <= 230): - raise ValueError("Space group must be an integer between 1 and 230.") - - self.molecules = molecules - self.space_group_number = space_group - self.angle_sampling_range = angles - - # Derived properties from the space group - self.symmetry_ops = chemical_knowledge.space_group[self.space_group_number][0] - self.point_group = chemical_knowledge.space_group[self.space_group_number][2] - - # Calculate counts and dimensions - self.num_asym_molecules = len(self.molecules) - self.num_total_molecules = len(self.symmetry_ops) * self.num_asym_molecules - self.atomic_counts_per_molecule = self._calculate_atomic_counts() - - # Determine search space dimensionality - self.search_dimensions, self.search_dimension_shape = self._determine_search_dimensions() - - # Pre-calculate molecular and crystal properties - self.max_vdw_radius = self._find_max_vdw_radius() - self.estimated_packed_volume = self._calculate_estimated_packed_volume() - self._orient_molecules() - - # Pre-generate supercell translation vectors, sorted by distance from origin - self.supercell_frac_translations = np.array( - sorted(list(itertools.product(_SUPERCELL_RANGE, repeat=3)), - key=lambda p: p[0]**2 + p[1]**2 + p[2]**2) - ) - - def _calculate_atomic_counts(self) -> list[int]: - """Calculates the number of atoms for each molecule in the asymmetric unit.""" - return [len(mol.atoms) for mol in self.molecules] - - def _orient_molecules(self) -> None: - """ - Orients each molecule to a standardized principal axis frame. - This reduces the rotational search space. For details, see: http://sobereva.com/426 - """ - for i, molecule in enumerate(self.molecules): - if len(molecule.atoms) > 1: - self.molecules[i] = operation.orient_molecule(molecule) - - def _find_max_vdw_radius(self) -> float: - """Finds the maximum van der Waals radius among all atoms in all molecules.""" - vdw_max = 0.0 - for molecule in self.molecules: - elements, _ = molecule.get_ele_and_cart() - for ele in set(elements): - vdw_max = max(vdw_max, chemical_knowledge.element_vdw_radii[ele]) - return vdw_max - - def _determine_search_dimensions(self) -> tuple[int, list[int]]: - """ - Determines the dimensionality of the search space. - - The search space consists of: - - 3 dimensions for cell angles (alpha, beta, gamma) - - 3 dimensions for cell lengths (a, b, c) - - 3 * N dimensions for molecular translations (x, y, z for each of N molecules) - - 3 * N dimensions for molecular rotations (Euler angles for each of N molecules) - - Returns: - A tuple containing the total dimension count and a list detailing the - breakdown of dimensions. - """ - dim_cell_lengths = 3 - dim_cell_angles = 3 - dim_translations = 3 * self.num_asym_molecules - dim_rotations = 3 * self.num_asym_molecules - total_dimension = dim_cell_lengths + dim_cell_angles + dim_translations + dim_rotations - shape = [dim_cell_lengths, dim_cell_angles, dim_translations, dim_rotations] - return total_dimension, shape - - def _calculate_estimated_packed_volume(self) -> float: - """ - Estimates the total volume of all molecules in the unit cell based on their - van der Waals radii. This is used for heuristics during generation. - """ - total_volume = 0.0 - for molecule in self.molecules: - elements, _ = molecule.get_ele_and_cart() - vdws = np.array([chemical_knowledge.element_vdw_radii[x] for x in elements]) - volumes = (4 / 3) * np.pi * vdws**3 - total_volume += np.sum(volumes) - return total_volume * len(self.symmetry_ops) # Multiply by Z - - def _map_random_to_angle(self, value: float) -> float: - """ - Maps a random number from [0, 1] to an angle in the specified range. - - This uses an arcsin distribution to more densely sample angles near the - midpoint of the range, which can be more efficient if orthogonal angles - are more likely. - """ - min_angle, max_angle = self.angle_sampling_range - angle_range = max_angle - min_angle - # A non-linear mapping to bias sampling - a = np.arcsin(2 * value - 1.0) / np.pi - return (0.5 + a) * angle_range + min_angle - - def _get_cell_angles_from_vector(self, vector: np.ndarray) -> tuple[float, float, float]: - """ - Determines the three cell angles based on a 3D random vector, respecting - the constraints of the crystal's point group. - """ - angle_candidates = [self._map_random_to_angle(v) for v in vector] - - if self.point_group == "Triclinic": - return angle_candidates[0], angle_candidates[1], angle_candidates[2] - if self.point_group == "Monoclinic": - return 90.0, angle_candidates[1], 90.0 - if self.point_group in ["Orthorhombic", "Tetragonal", "Cubic"]: - return 90.0, 90.0, 90.0 - if self.point_group == "Hexagonal": - return 90.0, 90.0, 120.0 - if self.point_group == "Trigonal": - # For rhombohedral lattices described in hexagonal axes, angles are fixed. - # This assumes a rhombohedral setting where angles are variable and equal. - return angle_candidates[0], angle_candidates[0], angle_candidates[0] - # Fallback for safety, though should be covered by above cases - return 90.0, 90.0, 90.0 - - - def _get_cell_lengths_from_vector(self, - vector: np.ndarray, - cell_angles: list[float], - rotated_molecules_cart: list[np.ndarray] - ) -> tuple[float, float, float]: - """ - Determines the three cell lengths based on a 3D random vector and molecule size. - - The method first calculates the minimum bounding box for the rotated molecules, - then scales the lengths based on the random vector to explore larger volumes. - """ - # Estimate minimum cell lengths to avoid self-collision within a molecule - min_lengths = np.zeros(3) - conversion_matrix = operation.c2f_matrix([[1, 1, 1], cell_angles]) - for cart_coords in rotated_molecules_cart: - frac_coords = cart_coords @ conversion_matrix - max_vals = np.max(frac_coords, axis=0) - min_vals = np.min(frac_coords, axis=0) - min_lengths = np.maximum(min_lengths, max_vals - min_vals) - - # Add a buffer based on the largest VdW radius - min_lengths += self.max_vdw_radius * 2 - - # Scale the lengths using the random vector to explore the search space - a = min_lengths[0] + vector[0] * (self.num_total_molecules * min_lengths[0]) - b = min_lengths[1] + vector[1] * (self.num_total_molecules * min_lengths[1]) - c = min_lengths[2] + vector[2] * (self.num_total_molecules * min_lengths[2]) - - # Apply constraints based on the point group - if self.point_group in ["Tetragonal", "Hexagonal"]: - return a, a, c - if self.point_group in ["Trigonal", "Cubic"]: - return a, a, a - return a, b, c - - def _check_for_collisions(self, - atom_elements: np.ndarray, - atom_cart_coords: np.ndarray - ) -> bool: - """ - Performs a steric clash test for the generated structure. - - It checks for intermolecular distances that are smaller than the sum of - the van der Waals radii (with a tolerance factor). - - Args: - atom_elements: A numpy array of element symbols for all atoms in the supercell. - atom_cart_coords: A numpy array of Cartesian coordinates for all atoms. - - Returns: - True if a collision is detected, False otherwise. - """ - vdw_radii = np.array([chemical_knowledge.element_vdw_radii[el.item()] for el in atom_elements]) - - start_index = 0 - for i in range(self.num_asym_molecules): - # Define the asymmetric unit molecule to check against its environment - num_atoms_in_mol = self.atomic_counts_per_molecule[i] - end_index = start_index + num_atoms_in_mol - - asym_mol_coords = atom_cart_coords[start_index:end_index] - asym_mol_vdws = vdw_radii[start_index:end_index] - - # The rest of the atoms form the environment - neighbor_coords = atom_cart_coords[end_index:] - neighbor_vdws = vdw_radii[end_index:] - - # A coarse filter using a bounding box around the asymmetric molecule - mol_min = np.min(asym_mol_coords, axis=0) - self.max_vdw_radius * 2 - mol_max = np.max(asym_mol_coords, axis=0) + self.max_vdw_radius * 2 - box_indices = np.all((neighbor_coords > mol_min) & (neighbor_coords < mol_max), axis=1) - - if not np.any(box_indices): - # Move to the next molecule in the asymmetric unit - num_atoms_in_supercell_mol = num_atoms_in_mol * len(self.supercell_frac_translations) * len(self.symmetry_ops) - start_index += num_atoms_in_supercell_mol - continue - - nearby_neighbor_coords = neighbor_coords[box_indices] - nearby_neighbor_vdws = neighbor_vdws[box_indices] - - # Use KD-Trees for efficient nearest-neighbor search - tree_asym = cKDTree(asym_mol_coords, compact_nodes=False, balanced_tree=False) - tree_neighbors = cKDTree(nearby_neighbor_coords, compact_nodes=False, balanced_tree=False) - - # Find all pairs of atoms within the maximum possible interaction distance - possible_contacts = tree_asym.query_ball_tree(tree_neighbors, self.max_vdw_radius * 2) - - for j, neighbor_indices in enumerate(possible_contacts): - if not neighbor_indices: - continue - - # Check precise distances for potential contacts - diff = asym_mol_coords[j] - nearby_neighbor_coords[neighbor_indices] - # einsum is a fast way to compute squared norms row-wise - distances = np.sqrt(np.einsum('ij,ij->i', diff, diff)) - - sum_radii = (asym_mol_vdws[j] + nearby_neighbor_vdws[neighbor_indices]) * _VDW_CLASH_FACTOR - - if np.any(distances < sum_radii): - return True # Collision detected - - # Update start index for the next asymmetric molecule - num_atoms_in_supercell_mol = num_atoms_in_mol * len(self.supercell_frac_translations) * len(self.symmetry_ops) - start_index += num_atoms_in_supercell_mol - - return False # No collisions found - - - def _shrink_cell_dimensions(self, a: float, b: float, c: float, locked_dims: list[bool] - ) -> tuple[float, float, float, list[int]]: - """ - Shrinks the crystal cell along the longest unlocked dimension by 1 Angstrom. - This is a crude optimization step to pack the molecules more tightly. - - Args: - a, b, c: Current cell lengths. - locked_dims: A boolean list [a, b, c] where True means the dimension - cannot be shrunk further. - - Returns: - A tuple of (new_a, new_b, new_c, last_change_indices). - """ - lengths = [val for val, is_locked in zip([a, b, c], locked_dims) if not is_locked] - if not lengths: - return a, b, c, [] # All dimensions are locked - - max_length = max(lengths) - last_change = [] - - # Logic to shrink the largest dimension(s) while respecting point group constraints - if self.point_group in ["Triclinic", "Monoclinic", "Orthorhombic"]: - if a == max_length and not locked_dims[0]: - a -= 1.0 - last_change = [0] - elif b == max_length and not locked_dims[1]: - b -= 1.0 - last_change = [1] - elif c == max_length and not locked_dims[2]: - c -= 1.0 - last_change = [2] - elif self.point_group in ["Tetragonal", "Hexagonal"]: - if (a == max_length or b == max_length) and not locked_dims[0]: - a -= 1.0 - b -= 1.0 - last_change = [0, 1] - elif c == max_length and not locked_dims[2]: - c -= 1.0 - last_change = [2] - elif self.point_group in ["Trigonal", "Cubic"]: - if (a == max_length or b == max_length or c == max_length) and not locked_dims[0]: - a -= 1.0 - b -= 1.0 - c -= 1.0 - last_change = [0, 1, 2] - - return a, b, c, last_change - - def _setup_crystal_from_vector(self, vector: np.ndarray - ) -> tuple[Optional[list], Optional[list[np.ndarray]], Optional[list[Any]]]: - """ - Performs the initial setup of a crystal structure from a random vector. - This includes setting angles, rotating molecules, and setting initial lengths. - This helper is used by both `generate` and `_generate_from_vector`. - """ - # Unpack the Sobol vector into its components for cell parameters and molecules - # Slicing indices based on the defined search space shape - s = self.search_dimension_shape - cell_angle_seed = vector[0:s[1]] - cell_length_seed = vector[s[1]:s[1]+s[0]] - move_part_seed = vector[s[1]+s[0] : s[1]+s[0]+s[2]] - rotate_part_seed = vector[s[1]+s[0]+s[2]:] - - # 1. Set cell angles - alpha, beta, gamma = self._get_cell_angles_from_vector(cell_angle_seed) - cell_angles = [alpha, beta, gamma] - - # Check for valid cell matrix from angles - ca, cb, cg = np.cos(np.deg2rad([alpha, beta, gamma])) - volume_sqrt_term = 1 - ca**2 - cb**2 - cg**2 + 2 * ca * cb * cg - if volume_sqrt_term <= 0: - print("Failed: Invalid angles cannot form a valid parallelepiped.") - return None, None, None - - # 2. Rotate molecules - rotated_molecules_cart = [] - rotated_molecules_ele = [] - rotate_vectors = rotate_part_seed.reshape(-1, 3) - for r_vec, molecule in zip(rotate_vectors, self.molecules): - elements, cart_coords = molecule.get_ele_and_cart() - rotation_matrix = operation.get_rotate_matrix(r_vec) - rotated_cart = cart_coords @ rotation_matrix - rotated_molecules_cart.append(rotated_cart) - rotated_molecules_ele.append(elements) - - # 3. Set initial cell lengths - a, b, c = self._get_cell_lengths_from_vector(cell_length_seed, cell_angles, rotated_molecules_cart) - cell_lengths = [a, b, c] - - crystal_params = [cell_lengths, cell_angles, move_part_seed, rotated_molecules_cart, rotated_molecules_ele] - - return crystal_params, volume_sqrt_term, rotate_part_seed - - def _build_supercell_for_clash_test(self, - cell_params: list, - rotated_molecules_cart: list[np.ndarray], - rotated_molecules_ele: list[list[str]], - move_part_seed: np.ndarray - ) -> tuple[np.ndarray, np.ndarray, list, list]: - """ - Builds a supercell and returns all atomic elements and coordinates for clash testing. - This version correctly handles asymmetric units with multiple, different-sized molecules. - """ - f2c_matrix = operation.f2c_matrix(cell_params) - c2f_matrix = operation.c2f_matrix(cell_params) - supercell_cart_translations = self.supercell_frac_translations @ f2c_matrix - - all_asym_frac_coords = [] - all_asym_elements = [] - - # Use lists to collect 2D blocks of coordinates and elements. This is efficient. - sc_cart_blocks = [] - sc_ele_blocks = [] - - for i, cart_coords in enumerate(rotated_molecules_cart): - # Apply translation vector to this molecule's fractional coordinates - trans_vector = move_part_seed[i * 3:(i + 1) * 3] - frac_coords = cart_coords @ c2f_matrix + trans_vector - - all_asym_frac_coords.append(frac_coords) - all_asym_elements.append(rotated_molecules_ele[i]) - - # Apply symmetry operations - symm_cart_coords = operation.apply_SYMM(frac_coords, self.symmetry_ops) @ f2c_matrix - symm_elements_list = [rotated_molecules_ele[i]] * len(self.symmetry_ops) - - # Center molecules that were moved across periodic boundaries - centroid_frac = np.mean(frac_coords, axis=0) - centroids_all_symm = operation.apply_SYMM(centroid_frac, self.symmetry_ops) - for j, cent in enumerate(centroids_all_symm): - move_to_center = (np.mod(cent, 1) - cent) @ f2c_matrix - symm_cart_coords[j] += move_to_center - - # --- Core Correction Logic --- - # 1. Create the full block of atoms for the current molecule type by applying all - # supercell translations. - mol_block_cart_temp = [] - for translation_vec in supercell_cart_translations: - # Adding the translation vector to all symmetry-equivalent molecules - translated_coords = symm_cart_coords + translation_vec - # Reshape to a flat (N_atoms * N_symm, 3) 2D array and append - mol_block_cart_temp.append(translated_coords.reshape(-1, 3)) - - # 2. Stack all translated blocks for this molecule type into a single 2D array - sc_cart_blocks.append(np.vstack(mol_block_cart_temp)) - - # 3. Handle the corresponding elements, ensuring they are flattened correctly - num_translations = len(self.supercell_frac_translations) - ele_block = np.array(symm_elements_list * num_translations).reshape(-1, 1) - sc_ele_blocks.append(ele_block) - - # After iterating through all molecule types, stack their respective complete blocks - final_sc_cart = np.vstack(sc_cart_blocks) - final_sc_ele = np.vstack(sc_ele_blocks) - - return final_sc_cart, final_sc_ele, all_asym_frac_coords, all_asym_elements - - def _create_final_crystal_object(self, - cell_params: list, - asym_frac_coords: list, - asym_elements: list, - seed: Any - ) -> data_classes.Crystal: - """Creates the final Crystal object from the successful structure.""" - - flat_elements = np.concatenate(asym_elements, axis=0).reshape(-1, 1) - flat_frac_coords = np.concatenate(asym_frac_coords, axis=0).reshape(-1, 3) - - atoms = [] - for ele, frac in zip(flat_elements, flat_frac_coords): - atoms.append(data_classes.Atom(element=ele.item(), frac_xyz=frac)) - - return data_classes.Crystal( - cell_para=cell_params, - atoms=atoms, - comment=str(seed), - system_name=str(seed), - space_group=self.space_group_number, - SYMM=self.symmetry_ops - ) - - def generate(self, - seed: Any = "unknown", - test: bool = False, - densely_pack_method: bool = False, - frame_tolerance: float = 1.5 - ) -> Optional[data_classes.Crystal]: - """ - The main generation method. - - Uses a Sobol sequence to get a random vector, then attempts to build and - pack a crystal structure through an iterative shrinking process. - - Args: - seed: A seed for the Sobol sequence generator. If "unknown", an error is raised. - test: A flag for enabling verbose test-mode output (prints cycle number). - densely_pack_method: If True, applies a heuristic to shrink very large - initial volumes. - frame_tolerance: Tolerance for checking if the final structure is a 2D slab. - - Returns: - A `data_classes.Crystal` object if a valid structure is found, otherwise `None`. - """ - if seed == "unknown": - raise ValueError("A seed must be provided for the Sobol generator.") - - sobol_gen = qmc.Sobol(d=self.search_dimensions, seed=seed) - initial_vector = sobol_gen.random(n=1).flatten() - - setup_result, volume_sqrt_term, _ = self._setup_crystal_from_vector(initial_vector) - if setup_result is None: - return None # Invalid initial angles - - cell_lengths, cell_angles, move_part_seed, rot_carts, rot_eles = setup_result - a, b, c = cell_lengths - alpha, beta, gamma = cell_angles - - # Heuristic to shrink extremely sparse initial structures - if densely_pack_method: - crystal_volume = a * b * c * np.sqrt(volume_sqrt_term) - if crystal_volume > self.estimated_packed_volume * 20: - c = self.estimated_packed_volume * 20 / (a * b * np.sqrt(volume_sqrt_term)) - - locked_dims = [False, False, False] - old_a, old_b, old_c = a, b, c - - for cycle_no in range(1001): - if cycle_no == 1001: - print(f"Stopping: Max optimization cycles reached. Seed: {seed}") - return None - - if a < 0 or b < 0 or c < 0: - print(f"BUG: Negative cell dimension. sg={self.space_group_number}, seed={seed}") - return None - - if test: - print(f"Cycle: {cycle_no}") - - cell_params = [[a, b, c], [alpha, beta, gamma]] - - sc_cart, sc_ele, asym_fracs, asym_eles = self._build_supercell_for_clash_test( - cell_params, rot_carts, rot_eles, move_part_seed - ) - - has_collision = self._check_for_collisions(sc_ele, sc_cart) - - if has_collision: - if cycle_no == 0: - print(f"Failed: Initial structure has collisions. Seed: {seed}") - return None - - # Collision occurred, so revert to last good state and lock the changed dimension - a, b, c = old_a, old_b, old_c - for dim_idx in last_change: - locked_dims[dim_idx] = True - else: - # No collision, this is a valid (though maybe not dense) structure. - # Check if optimization is finished (all dimensions are locked). - if cycle_no > 0 and all(locked_dims): - final_crystal = self._create_final_crystal_object(cell_params, asym_fracs, asym_eles, seed) - - # Final check to filter out 2D slab-like structures - if not operation.detect_is_frame_vdw_new(final_crystal, tolerance=frame_tolerance): - print(f"Failed: Generated structure is a 2D slab. Seed: {seed}") - return None - - print(f"Success: Generated a valid crystal structure. Seed: {seed}") - return final_crystal - - # If no collision and not finished, save current state and shrink further - old_a, old_b, old_c = a, b, c - a, b, c, last_change = self._shrink_cell_dimensions(a, b, c, locked_dims) - - # ============================================================================== - # Test-related functions, kept for compatibility, marked as internal. - # ============================================================================== - - def _generate_from_vector(self, - seed_vector: np.ndarray, - frame_tolerance: float = 1.5 - ) -> Optional[data_classes.Crystal]: - """ - Generates a single crystal structure directly from a vector, without optimization. - This is an internal method intended for testing and analysis. - Original name: generate_by_vector_2. - - Args: - seed_vector: A numpy array of shape (self.search_dimensions,) defining the structure. - frame_tolerance: Tolerance for checking if the final structure is a 2D slab. - - Returns: - A `data_classes.Crystal` object if valid, otherwise `None`. - """ - if not isinstance(seed_vector, np.ndarray): - raise TypeError("seed_vector must be a numpy array.") - - expected_len = self.search_dimensions - if len(seed_vector) != expected_len: - raise ValueError(f"Length of seed_vector must be {expected_len}, got {len(seed_vector)}.") - - setup_result, _, _ = self._setup_crystal_from_vector(seed_vector) - if setup_result is None: - return None # Invalid initial angles - - cell_lengths, cell_angles, move_part_seed, rot_carts, rot_eles = setup_result - cell_params = [cell_lengths, cell_angles] - - sc_cart, sc_ele, asym_fracs, asym_eles = self._build_supercell_for_clash_test( - cell_params, rot_carts, rot_eles, move_part_seed - ) - - if self._check_for_collisions(sc_ele, sc_cart): - print("Failed: Structure from vector has collisions.") - return None - - generated_crystal = self._create_final_crystal_object( - cell_params, asym_fracs, asym_eles, seed="from_vector" - ) - - # Optional: Keep the slab check for consistency - # if not operation.detect_is_frame_vdw_new(generated_crystal, tolerance=frame_tolerance): - # print("Failed: Generated structure is a 2D slab.") - # return None - - return generated_crystal - - def _is_valid_vector(self, seed_vector: np.ndarray) -> bool: - """ - Checks if a given vector produces a valid, collision-free structure. - Internal method for testing. - """ +""" +This module provides the CrystalGenerator class for crystal structure prediction (CSP). + +It uses a Sobol sequence-based random search to generate candidate crystal +structures for a given set of molecules and space group, followed by a crude +packing minimization. +""" + +# Standard library imports +import itertools +from typing import List, Tuple, Optional, Any + +# Third-party imports +import numpy as np +from scipy.spatial import cKDTree +from scipy.stats import qmc + +# Local application/library specific imports +from basic_function import chemical_knowledge +from basic_function import operation +from basic_function import data_classes + +# Module-level constants for better readability and maintenance +_VDW_CLASH_FACTOR = 0.9 # Scaling factor for van der Waals radii in collision checks +_SUPERCELL_RANGE = np.arange(-2, 3) # Range for generating supercell translations + + +class CrystalGenerator: + """ + Generates candidate crystal structures for Crystal Structure Prediction (CSP). + + The generator takes a list of unique molecules and a space group, then searches + the conformational space of cell parameters and molecular orientations to + produce tightly packed, sterically plausible crystal structures. + """ + + def __init__(self, + molecules: list[data_classes.Molecule], + space_group: int = 1, + angles: tuple[float, float] = (45.0, 135.0)): + """ + Initializes the CrystalGenerator. + + Args: + molecules: A list of molecule objects (from data_classes) that will form + the asymmetric unit. + space_group: The international space group number (e.g., 1 for P1). + angles: A tuple (min, max) defining the range for sampling cell angles in degrees. + """ + if not (0 < space_group <= 230): + raise ValueError("Space group must be an integer between 1 and 230.") + + self.molecules = molecules + self.space_group_number = space_group + self.angle_sampling_range = angles + + # Derived properties from the space group + self.symmetry_ops = chemical_knowledge.space_group[self.space_group_number][0] + self.point_group = chemical_knowledge.space_group[self.space_group_number][2] + + # Calculate counts and dimensions + self.num_asym_molecules = len(self.molecules) + self.num_total_molecules = len(self.symmetry_ops) * self.num_asym_molecules + self.atomic_counts_per_molecule = self._calculate_atomic_counts() + + # Determine search space dimensionality + self.search_dimensions, self.search_dimension_shape = self._determine_search_dimensions() + + # Pre-calculate molecular and crystal properties + self.max_vdw_radius = self._find_max_vdw_radius() + self.estimated_packed_volume = self._calculate_estimated_packed_volume() + self._orient_molecules() + + # Pre-generate supercell translation vectors, sorted by distance from origin + self.supercell_frac_translations = np.array( + sorted(list(itertools.product(_SUPERCELL_RANGE, repeat=3)), + key=lambda p: p[0]**2 + p[1]**2 + p[2]**2) + ) + + def _calculate_atomic_counts(self) -> list[int]: + """Calculates the number of atoms for each molecule in the asymmetric unit.""" + return [len(mol.atoms) for mol in self.molecules] + + def _orient_molecules(self) -> None: + """ + Orients each molecule to a standardized principal axis frame. + This reduces the rotational search space. For details, see: http://sobereva.com/426 + """ + for i, molecule in enumerate(self.molecules): + if len(molecule.atoms) > 1: + self.molecules[i] = operation.orient_molecule(molecule) + + def _find_max_vdw_radius(self) -> float: + """Finds the maximum van der Waals radius among all atoms in all molecules.""" + vdw_max = 0.0 + for molecule in self.molecules: + elements, _ = molecule.get_ele_and_cart() + for ele in set(elements): + vdw_max = max(vdw_max, chemical_knowledge.element_vdw_radii[ele]) + return vdw_max + + def _determine_search_dimensions(self) -> tuple[int, list[int]]: + """ + Determines the dimensionality of the search space. + + The search space consists of: + - 3 dimensions for cell angles (alpha, beta, gamma) + - 3 dimensions for cell lengths (a, b, c) + - 3 * N dimensions for molecular translations (x, y, z for each of N molecules) + - 3 * N dimensions for molecular rotations (Euler angles for each of N molecules) + + Returns: + A tuple containing the total dimension count and a list detailing the + breakdown of dimensions. + """ + dim_cell_lengths = 3 + dim_cell_angles = 3 + dim_translations = 3 * self.num_asym_molecules + dim_rotations = 3 * self.num_asym_molecules + total_dimension = dim_cell_lengths + dim_cell_angles + dim_translations + dim_rotations + shape = [dim_cell_lengths, dim_cell_angles, dim_translations, dim_rotations] + return total_dimension, shape + + def _calculate_estimated_packed_volume(self) -> float: + """ + Estimates the total volume of all molecules in the unit cell based on their + van der Waals radii. This is used for heuristics during generation. + """ + total_volume = 0.0 + for molecule in self.molecules: + elements, _ = molecule.get_ele_and_cart() + vdws = np.array([chemical_knowledge.element_vdw_radii[x] for x in elements]) + volumes = (4 / 3) * np.pi * vdws**3 + total_volume += np.sum(volumes) + return total_volume * len(self.symmetry_ops) # Multiply by Z + + def _map_random_to_angle(self, value: float) -> float: + """ + Maps a random number from [0, 1] to an angle in the specified range. + + This uses an arcsin distribution to more densely sample angles near the + midpoint of the range, which can be more efficient if orthogonal angles + are more likely. + """ + min_angle, max_angle = self.angle_sampling_range + angle_range = max_angle - min_angle + # A non-linear mapping to bias sampling + a = np.arcsin(2 * value - 1.0) / np.pi + return (0.5 + a) * angle_range + min_angle + + def _get_cell_angles_from_vector(self, vector: np.ndarray) -> tuple[float, float, float]: + """ + Determines the three cell angles based on a 3D random vector, respecting + the constraints of the crystal's point group. + """ + angle_candidates = [self._map_random_to_angle(v) for v in vector] + + if self.point_group == "Triclinic": + return angle_candidates[0], angle_candidates[1], angle_candidates[2] + if self.point_group == "Monoclinic": + return 90.0, angle_candidates[1], 90.0 + if self.point_group in ["Orthorhombic", "Tetragonal", "Cubic"]: + return 90.0, 90.0, 90.0 + if self.point_group == "Hexagonal": + return 90.0, 90.0, 120.0 + if self.point_group == "Trigonal": + # For rhombohedral lattices described in hexagonal axes, angles are fixed. + # This assumes a rhombohedral setting where angles are variable and equal. + return angle_candidates[0], angle_candidates[0], angle_candidates[0] + # Fallback for safety, though should be covered by above cases + return 90.0, 90.0, 90.0 + + + def _get_cell_lengths_from_vector(self, + vector: np.ndarray, + cell_angles: list[float], + rotated_molecules_cart: list[np.ndarray] + ) -> tuple[float, float, float]: + """ + Determines the three cell lengths based on a 3D random vector and molecule size. + + The method first calculates the minimum bounding box for the rotated molecules, + then scales the lengths based on the random vector to explore larger volumes. + """ + # Estimate minimum cell lengths to avoid self-collision within a molecule + min_lengths = np.zeros(3) + conversion_matrix = operation.c2f_matrix([[1, 1, 1], cell_angles]) + for cart_coords in rotated_molecules_cart: + frac_coords = cart_coords @ conversion_matrix + max_vals = np.max(frac_coords, axis=0) + min_vals = np.min(frac_coords, axis=0) + min_lengths = np.maximum(min_lengths, max_vals - min_vals) + + # Add a buffer based on the largest VdW radius + min_lengths += self.max_vdw_radius * 2 + + # Scale the lengths using the random vector to explore the search space + a = min_lengths[0] + vector[0] * (self.num_total_molecules * min_lengths[0]) + b = min_lengths[1] + vector[1] * (self.num_total_molecules * min_lengths[1]) + c = min_lengths[2] + vector[2] * (self.num_total_molecules * min_lengths[2]) + + # Apply constraints based on the point group + if self.point_group in ["Tetragonal", "Hexagonal"]: + return a, a, c + if self.point_group in ["Trigonal", "Cubic"]: + return a, a, a + return a, b, c + + def _check_for_collisions(self, + atom_elements: np.ndarray, + atom_cart_coords: np.ndarray + ) -> bool: + """ + Performs a steric clash test for the generated structure. + + It checks for intermolecular distances that are smaller than the sum of + the van der Waals radii (with a tolerance factor). + + Args: + atom_elements: A numpy array of element symbols for all atoms in the supercell. + atom_cart_coords: A numpy array of Cartesian coordinates for all atoms. + + Returns: + True if a collision is detected, False otherwise. + """ + vdw_radii = np.array([chemical_knowledge.element_vdw_radii[el.item()] for el in atom_elements]) + + start_index = 0 + for i in range(self.num_asym_molecules): + # Define the asymmetric unit molecule to check against its environment + num_atoms_in_mol = self.atomic_counts_per_molecule[i] + end_index = start_index + num_atoms_in_mol + + asym_mol_coords = atom_cart_coords[start_index:end_index] + asym_mol_vdws = vdw_radii[start_index:end_index] + + # The rest of the atoms form the environment + neighbor_coords = atom_cart_coords[end_index:] + neighbor_vdws = vdw_radii[end_index:] + + # A coarse filter using a bounding box around the asymmetric molecule + mol_min = np.min(asym_mol_coords, axis=0) - self.max_vdw_radius * 2 + mol_max = np.max(asym_mol_coords, axis=0) + self.max_vdw_radius * 2 + box_indices = np.all((neighbor_coords > mol_min) & (neighbor_coords < mol_max), axis=1) + + if not np.any(box_indices): + # Move to the next molecule in the asymmetric unit + num_atoms_in_supercell_mol = num_atoms_in_mol * len(self.supercell_frac_translations) * len(self.symmetry_ops) + start_index += num_atoms_in_supercell_mol + continue + + nearby_neighbor_coords = neighbor_coords[box_indices] + nearby_neighbor_vdws = neighbor_vdws[box_indices] + + # Use KD-Trees for efficient nearest-neighbor search + tree_asym = cKDTree(asym_mol_coords, compact_nodes=False, balanced_tree=False) + tree_neighbors = cKDTree(nearby_neighbor_coords, compact_nodes=False, balanced_tree=False) + + # Find all pairs of atoms within the maximum possible interaction distance + possible_contacts = tree_asym.query_ball_tree(tree_neighbors, self.max_vdw_radius * 2) + + for j, neighbor_indices in enumerate(possible_contacts): + if not neighbor_indices: + continue + + # Check precise distances for potential contacts + diff = asym_mol_coords[j] - nearby_neighbor_coords[neighbor_indices] + # einsum is a fast way to compute squared norms row-wise + distances = np.sqrt(np.einsum('ij,ij->i', diff, diff)) + + sum_radii = (asym_mol_vdws[j] + nearby_neighbor_vdws[neighbor_indices]) * _VDW_CLASH_FACTOR + + if np.any(distances < sum_radii): + return True # Collision detected + + # Update start index for the next asymmetric molecule + num_atoms_in_supercell_mol = num_atoms_in_mol * len(self.supercell_frac_translations) * len(self.symmetry_ops) + start_index += num_atoms_in_supercell_mol + + return False # No collisions found + + + def _shrink_cell_dimensions(self, a: float, b: float, c: float, locked_dims: list[bool] + ) -> tuple[float, float, float, list[int]]: + """ + Shrinks the crystal cell along the longest unlocked dimension by 1 Angstrom. + This is a crude optimization step to pack the molecules more tightly. + + Args: + a, b, c: Current cell lengths. + locked_dims: A boolean list [a, b, c] where True means the dimension + cannot be shrunk further. + + Returns: + A tuple of (new_a, new_b, new_c, last_change_indices). + """ + lengths = [val for val, is_locked in zip([a, b, c], locked_dims) if not is_locked] + if not lengths: + return a, b, c, [] # All dimensions are locked + + max_length = max(lengths) + last_change = [] + + # Logic to shrink the largest dimension(s) while respecting point group constraints + if self.point_group in ["Triclinic", "Monoclinic", "Orthorhombic"]: + if a == max_length and not locked_dims[0]: + a -= 1.0 + last_change = [0] + elif b == max_length and not locked_dims[1]: + b -= 1.0 + last_change = [1] + elif c == max_length and not locked_dims[2]: + c -= 1.0 + last_change = [2] + elif self.point_group in ["Tetragonal", "Hexagonal"]: + if (a == max_length or b == max_length) and not locked_dims[0]: + a -= 1.0 + b -= 1.0 + last_change = [0, 1] + elif c == max_length and not locked_dims[2]: + c -= 1.0 + last_change = [2] + elif self.point_group in ["Trigonal", "Cubic"]: + if (a == max_length or b == max_length or c == max_length) and not locked_dims[0]: + a -= 1.0 + b -= 1.0 + c -= 1.0 + last_change = [0, 1, 2] + + return a, b, c, last_change + + def _setup_crystal_from_vector(self, vector: np.ndarray + ) -> tuple[Optional[list], Optional[list[np.ndarray]], Optional[list[Any]]]: + """ + Performs the initial setup of a crystal structure from a random vector. + This includes setting angles, rotating molecules, and setting initial lengths. + This helper is used by both `generate` and `_generate_from_vector`. + """ + # Unpack the Sobol vector into its components for cell parameters and molecules + # Slicing indices based on the defined search space shape + s = self.search_dimension_shape + cell_angle_seed = vector[0:s[1]] + cell_length_seed = vector[s[1]:s[1]+s[0]] + move_part_seed = vector[s[1]+s[0] : s[1]+s[0]+s[2]] + rotate_part_seed = vector[s[1]+s[0]+s[2]:] + + # 1. Set cell angles + alpha, beta, gamma = self._get_cell_angles_from_vector(cell_angle_seed) + cell_angles = [alpha, beta, gamma] + + # Check for valid cell matrix from angles + ca, cb, cg = np.cos(np.deg2rad([alpha, beta, gamma])) + volume_sqrt_term = 1 - ca**2 - cb**2 - cg**2 + 2 * ca * cb * cg + if volume_sqrt_term <= 0: + print("Failed: Invalid angles cannot form a valid parallelepiped.") + return None, None, None + + # 2. Rotate molecules + rotated_molecules_cart = [] + rotated_molecules_ele = [] + rotate_vectors = rotate_part_seed.reshape(-1, 3) + for r_vec, molecule in zip(rotate_vectors, self.molecules): + elements, cart_coords = molecule.get_ele_and_cart() + rotation_matrix = operation.get_rotate_matrix(r_vec) + rotated_cart = cart_coords @ rotation_matrix + rotated_molecules_cart.append(rotated_cart) + rotated_molecules_ele.append(elements) + + # 3. Set initial cell lengths + a, b, c = self._get_cell_lengths_from_vector(cell_length_seed, cell_angles, rotated_molecules_cart) + cell_lengths = [a, b, c] + + crystal_params = [cell_lengths, cell_angles, move_part_seed, rotated_molecules_cart, rotated_molecules_ele] + + return crystal_params, volume_sqrt_term, rotate_part_seed + + def _build_supercell_for_clash_test(self, + cell_params: list, + rotated_molecules_cart: list[np.ndarray], + rotated_molecules_ele: list[list[str]], + move_part_seed: np.ndarray + ) -> tuple[np.ndarray, np.ndarray, list, list]: + """ + Builds a supercell and returns all atomic elements and coordinates for clash testing. + This version correctly handles asymmetric units with multiple, different-sized molecules. + """ + f2c_matrix = operation.f2c_matrix(cell_params) + c2f_matrix = operation.c2f_matrix(cell_params) + supercell_cart_translations = self.supercell_frac_translations @ f2c_matrix + + all_asym_frac_coords = [] + all_asym_elements = [] + + # Use lists to collect 2D blocks of coordinates and elements. This is efficient. + sc_cart_blocks = [] + sc_ele_blocks = [] + + for i, cart_coords in enumerate(rotated_molecules_cart): + # Apply translation vector to this molecule's fractional coordinates + trans_vector = move_part_seed[i * 3:(i + 1) * 3] + frac_coords = cart_coords @ c2f_matrix + trans_vector + + all_asym_frac_coords.append(frac_coords) + all_asym_elements.append(rotated_molecules_ele[i]) + + # Apply symmetry operations + symm_cart_coords = operation.apply_SYMM(frac_coords, self.symmetry_ops) @ f2c_matrix + symm_elements_list = [rotated_molecules_ele[i]] * len(self.symmetry_ops) + + # Center molecules that were moved across periodic boundaries + centroid_frac = np.mean(frac_coords, axis=0) + centroids_all_symm = operation.apply_SYMM(centroid_frac, self.symmetry_ops) + for j, cent in enumerate(centroids_all_symm): + move_to_center = (np.mod(cent, 1) - cent) @ f2c_matrix + symm_cart_coords[j] += move_to_center + + # --- Core Correction Logic --- + # 1. Create the full block of atoms for the current molecule type by applying all + # supercell translations. + mol_block_cart_temp = [] + for translation_vec in supercell_cart_translations: + # Adding the translation vector to all symmetry-equivalent molecules + translated_coords = symm_cart_coords + translation_vec + # Reshape to a flat (N_atoms * N_symm, 3) 2D array and append + mol_block_cart_temp.append(translated_coords.reshape(-1, 3)) + + # 2. Stack all translated blocks for this molecule type into a single 2D array + sc_cart_blocks.append(np.vstack(mol_block_cart_temp)) + + # 3. Handle the corresponding elements, ensuring they are flattened correctly + num_translations = len(self.supercell_frac_translations) + ele_block = np.array(symm_elements_list * num_translations).reshape(-1, 1) + sc_ele_blocks.append(ele_block) + + # After iterating through all molecule types, stack their respective complete blocks + final_sc_cart = np.vstack(sc_cart_blocks) + final_sc_ele = np.vstack(sc_ele_blocks) + + return final_sc_cart, final_sc_ele, all_asym_frac_coords, all_asym_elements + + def _create_final_crystal_object(self, + cell_params: list, + asym_frac_coords: list, + asym_elements: list, + seed: Any + ) -> data_classes.Crystal: + """Creates the final Crystal object from the successful structure.""" + + flat_elements = np.concatenate(asym_elements, axis=0).reshape(-1, 1) + flat_frac_coords = np.concatenate(asym_frac_coords, axis=0).reshape(-1, 3) + + atoms = [] + for ele, frac in zip(flat_elements, flat_frac_coords): + atoms.append(data_classes.Atom(element=ele.item(), frac_xyz=frac)) + + return data_classes.Crystal( + cell_para=cell_params, + atoms=atoms, + comment=str(seed), + system_name=str(seed), + space_group=self.space_group_number, + SYMM=self.symmetry_ops + ) + + def generate(self, + seed: Any = "unknown", + test: bool = False, + densely_pack_method: bool = False, + frame_tolerance: float = 1.5 + ) -> Optional[data_classes.Crystal]: + """ + The main generation method. + + Uses a Sobol sequence to get a random vector, then attempts to build and + pack a crystal structure through an iterative shrinking process. + + Args: + seed: A seed for the Sobol sequence generator. If "unknown", an error is raised. + test: A flag for enabling verbose test-mode output (prints cycle number). + densely_pack_method: If True, applies a heuristic to shrink very large + initial volumes. + frame_tolerance: Tolerance for checking if the final structure is a 2D slab. + + Returns: + A `data_classes.Crystal` object if a valid structure is found, otherwise `None`. + """ + if seed == "unknown": + raise ValueError("A seed must be provided for the Sobol generator.") + + sobol_gen = qmc.Sobol(d=self.search_dimensions, seed=seed) + initial_vector = sobol_gen.random(n=1).flatten() + + setup_result, volume_sqrt_term, _ = self._setup_crystal_from_vector(initial_vector) + if setup_result is None: + return None # Invalid initial angles + + cell_lengths, cell_angles, move_part_seed, rot_carts, rot_eles = setup_result + a, b, c = cell_lengths + alpha, beta, gamma = cell_angles + + # Heuristic to shrink extremely sparse initial structures + if densely_pack_method: + crystal_volume = a * b * c * np.sqrt(volume_sqrt_term) + if crystal_volume > self.estimated_packed_volume * 20: + c = self.estimated_packed_volume * 20 / (a * b * np.sqrt(volume_sqrt_term)) + + locked_dims = [False, False, False] + old_a, old_b, old_c = a, b, c + + for cycle_no in range(1001): + if cycle_no == 1001: + print(f"Stopping: Max optimization cycles reached. Seed: {seed}") + return None + + if a < 0 or b < 0 or c < 0: + print(f"BUG: Negative cell dimension. sg={self.space_group_number}, seed={seed}") + return None + + if test: + print(f"Cycle: {cycle_no}") + + cell_params = [[a, b, c], [alpha, beta, gamma]] + + sc_cart, sc_ele, asym_fracs, asym_eles = self._build_supercell_for_clash_test( + cell_params, rot_carts, rot_eles, move_part_seed + ) + + has_collision = self._check_for_collisions(sc_ele, sc_cart) + + if has_collision: + if cycle_no == 0: + print(f"Failed: Initial structure has collisions. Seed: {seed}") + return None + + # Collision occurred, so revert to last good state and lock the changed dimension + a, b, c = old_a, old_b, old_c + for dim_idx in last_change: + locked_dims[dim_idx] = True + else: + # No collision, this is a valid (though maybe not dense) structure. + # Check if optimization is finished (all dimensions are locked). + if cycle_no > 0 and all(locked_dims): + final_crystal = self._create_final_crystal_object(cell_params, asym_fracs, asym_eles, seed) + + # Final check to filter out 2D slab-like structures + if not operation.detect_is_frame_vdw_new(final_crystal, tolerance=frame_tolerance): + print(f"Failed: Generated structure is a 2D slab. Seed: {seed}") + return None + + print(f"Success: Generated a valid crystal structure. Seed: {seed}") + return final_crystal + + # If no collision and not finished, save current state and shrink further + old_a, old_b, old_c = a, b, c + a, b, c, last_change = self._shrink_cell_dimensions(a, b, c, locked_dims) + + # ============================================================================== + # Test-related functions, kept for compatibility, marked as internal. + # ============================================================================== + + def _generate_from_vector(self, + seed_vector: np.ndarray, + frame_tolerance: float = 1.5 + ) -> Optional[data_classes.Crystal]: + """ + Generates a single crystal structure directly from a vector, without optimization. + This is an internal method intended for testing and analysis. + Original name: generate_by_vector_2. + + Args: + seed_vector: A numpy array of shape (self.search_dimensions,) defining the structure. + frame_tolerance: Tolerance for checking if the final structure is a 2D slab. + + Returns: + A `data_classes.Crystal` object if valid, otherwise `None`. + """ + if not isinstance(seed_vector, np.ndarray): + raise TypeError("seed_vector must be a numpy array.") + + expected_len = self.search_dimensions + if len(seed_vector) != expected_len: + raise ValueError(f"Length of seed_vector must be {expected_len}, got {len(seed_vector)}.") + + setup_result, _, _ = self._setup_crystal_from_vector(seed_vector) + if setup_result is None: + return None # Invalid initial angles + + cell_lengths, cell_angles, move_part_seed, rot_carts, rot_eles = setup_result + cell_params = [cell_lengths, cell_angles] + + sc_cart, sc_ele, asym_fracs, asym_eles = self._build_supercell_for_clash_test( + cell_params, rot_carts, rot_eles, move_part_seed + ) + + if self._check_for_collisions(sc_ele, sc_cart): + print("Failed: Structure from vector has collisions.") + return None + + generated_crystal = self._create_final_crystal_object( + cell_params, asym_fracs, asym_eles, seed="from_vector" + ) + + # Optional: Keep the slab check for consistency + # if not operation.detect_is_frame_vdw_new(generated_crystal, tolerance=frame_tolerance): + # print("Failed: Generated structure is a 2D slab.") + # return None + + return generated_crystal + + def _is_valid_vector(self, seed_vector: np.ndarray) -> bool: + """ + Checks if a given vector produces a valid, collision-free structure. + Internal method for testing. + """ return self._generate_from_vector(seed_vector) is not None \ No newline at end of file diff --git a/basic_function/__pycache__/CSP_function.cpython-310.pyc b/basic_function/__pycache__/CSP_function.cpython-310.pyc deleted file mode 100644 index 6b80d351bd5c8f8361a22ef4112c988465283110..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 1078 zcmY*YKaUhQ6t_M9XJM1aW7_L$7B&l&ja)O=if$rtrp!>Do1bH~c>zecm<0CCW(8McY& zxGZ!fS}>;Nk~dXZ)O6K~i<+|q2X9oqeD+3i{`r@I{&`#y_xCrNfDB=sz*3)pX|%%z z?Qu^G*^xcM2&Ctv_Rj)C7-nSg9bb6N`#^fnc)M_q-jd$$1LpUkj*P#>&(LpaAL-bH zIyp;C^b^{WK72Pbp~1(91;?n5P0%NLWRflJQmwFI&apl+<1JnRSbDK4cOzdeG@hPRrYvkQR2s>+5}gR!G`Yc4mmD=MihX)p8Oxh%S6J}DTZ zZi%+n`lwP)o4p@vu{4CdDG>i z|E!i@*rn$)BO@e_za(qT%zk;&Oh<qpV>p_NcY1c}rgU@4B_yMr_A!~@- zMYB|9AnfG(q{C#dCN@@O)ot#p4&GN$OCL_lTfp+T<@!R%)up`cd@krv>@{{5d^UdA W2)3^IBX{SN%NC;yKxF`OLcRh8FEqvg diff --git a/basic_function/__pycache__/CSP_function.cpython-311.pyc b/basic_function/__pycache__/CSP_function.cpython-311.pyc deleted file mode 100644 index ba4e2e82f4acb503b75c35068c00cbe94f590075..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 1934 zcmZ`4TSy#N^v=%R_xc(&W?cu{2A2@4Al1-BZR=yK(BbuN^=0bQ*d@ep=dv7=PFNv;LiPhJf{)TrcuoWd(ifXVF zBW+DP7lyN&0AF7!#Yx90ix`ob+(PZ3Iq$}^4e^iA@4!f)F~C|x{MWb z4q@I1ha54vxQNycc`xNVDO4?cjV3gN>>3p}ltm=lBec`7mXWLw%aDsEv1ING^ERa^ zXNldXnt-;?AtEE48Je1`Dwt5wK@1tnD&|qq4yH7aN#td1g@vLOi!^p})oA+LC4|>x zTPb3PM$Azs*JXk1~1VUf`!CwSZrjq?au+*ITWNh>mc%#BRu zF>>ZDU86DEVx@2>By3X9=ty_S#SJ-Cq49~hk^94JXOsO>}PN6WpDV|n?d%b56_h+ z*T1T^wv{Ij+N79V{CJ*JnzVh@7zk4|79{z3G>$v52+zL8w zl|Ox%Ze4%0F}^we)5tH=mA>Gu>wf=e&_C+uM)%UA9vnSpb!4{#^C6i}VSbRbHGD9& z{y*vK7Dp}GYeSCB{#O(o^J$Fv*rEv&krfLsu3|p+YA5sz;9*W*W$1O{J;TzLbTo!X y8Ov{~V={Ck3W87t>2mE?1p{9Ase-Fs_^E<-yzo;6-Cq4Uf;r*F5%?cei|1c89ocOF diff --git a/basic_function/__pycache__/CSP_function.cpython-313.pyc b/basic_function/__pycache__/CSP_function.cpython-313.pyc deleted file mode 100644 index 3e1cc991a9a0fa43b4921f8ef6b936e1f4cf0f68..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 1729 zcmZ8h%}*Og6rc63zdwKl0|r7JY9YA=90XRBstF2BN`%{}^>(XDKCHI8gY9C^ZfDnk zs~mjjAqrAUt48gi2af5ze*o%0vKxXfjPr!3&(EJwnI39)Kd4MYV|tH63yQ)|j9GS+KQFGH#gYJ~@%VF)!Z z3$`~%Kc0SPMXG%cN=)YxldsbRkU^g7Xpv9$6={9fDSxg=18Qf{uTj@t3VEiY+c8h- zaC(+c;v`@0%9`UP{*DjvGGB~e^=Fb7yFqg|A4&oWlzKcbS@1IN2anPKpkS%*!WJlr z9xe`esLSObD>Q%J5fu1%0)Qw;gD=!+YLNVFMIqUb!lfZ^mmK!aLlHSrOnLe)mqQmG z(b9;gFOGSr%O@@>Loqp4e5-5mkR0`5r9ro@OHtFFXP46s$Glx7T2Y6j0PSzQ^tLC< zanDBzUefxC8PB83Jx|t;_r3)n1G1sVr~gqpWm7|EFrvX)O~EQEL(JNF$`-2US7Gxs z%y2&uG{bC?kPe~gqF`NJg^D>#`s>CTR8-rjD^+A03X6r=@REqtpyTf61&KX;%#7>b_L1 zs-~u9Cf=8X6TO^(xFXTp;TFK)H-zAN-M(C^j!7;SP}a3Ti~dF=;ymX-rXMB z@6WbAIuv@_!muL@9|$8Z&NxEmK$tudN894$OL21h$&R>le0YaagbCEo7?f|R1 zAQ)n}LYdQ~!ZpbVcD4t$mwCPISAHF2VQ=Fq%3>C?kVUU?mN5JbO_M$3L64`{CZgps z*QIE{nG_3NmqlLDWg|{2&gLAVQSs>hbIJMVe|!3JyCCk@KWGd(1ilBXK7i0@iw)Z0 zju^5fJA@Hv&uQ%+2Zk`r$ow}s_n7yFbe{3H;SRkZo!Y`MKzj@g!3ffVR~40cGw0HU_HF_Y_m+IU zPn>s7q5fy1Y!W7xJ_E1g$1(V~K@(G=Vcf5fQfb(;-- zdJR9dQ4WIvWHxbz!>w%O5CU%N!HaCdO2`Lks9s2}7NTM{Yz2%$Szp7wvcWlU{2)uK zTmo86i{c%Ue?r?|AI=t{=ChUJa(2YztX)Y??>&2R!C32sX!_Xyxhe~Kv}y`>qi1mJ zmvaM_bB%Eg?bFiDBd>WsO6VejLh3pQgi{>Da~%({4?J-?j7ds>y$O+j!Sc&ue_2zoR(KXWO2jwsgTPvdCkRK(!SXX;xCBV_+~BG Ss^YiYom1{FA_hS2iOF}miZrj7!AK=1(Vc1EtP(UzCp0uak<(i~?CxBZi zZHoK@sm$L9{3}?vNfD@a=f)ZC*~WPm^IU4?(c0GuCz z_0J$U;<)hM?CCflh|;F&(LqDePU7@ zo9rUD$rorx`uOE@8(Vyecyx;T)JA<~MmF2wKDYUPY*VN|GL!xEfcmi=?{C@B0Rw^* zVz=%1fXxYyq2|P#*vSs_1T?vr?I7biga)NgkMDtw7%l`Y$V*YzRIzp`ln)(!2Ou7- z#rBwa`;H4s*==3Qh3Ez=Y1mrO?i+Wfk*@@0MmDr+jif%8@(En7H_-hd;BGI4p`wO^ zEvF?@#syk{XxNqV<98?UvJtJJulOZE;Z^10>%q9u%Z*T5+ObM$N7}0*daPKtDyEEc z>Zhpbxsr`LEoIw^(g>&^n@+ZnPP?EMt&2Q9J)gp~)m9r;yUfTMR--M28-IvVBbu%l zee_y!=@JHw0c0-ohR3a4;t>LF7r={R#w*ANX{cE#p;xlzF770>Lc7qwUb*NRIDTB@ zHB*4r(6am+sW0K$?~3xf#jjc@z4(bYRSUUbU8P!_3B5M5TXY*G=%W|UZWvp$lI;-t zo$0Ei=bN_l8@_;JpPUCUtb1G-Xa}u)e^?`i*21k4D6H>;Ksd)KJooV!hrlzhY>1=V}bfO+PO>O!|6)6J1u__JQu;dtfh$WJhIS=~Y|D}&W_Dl9kZ<|HjRQ&2qt)zF=lHn^hw+-t_z+%G~O8n{JeU zyGzw}Gt#fO52|XEtN+Oxw}k5=tKV(a@iSN7bz5HDZd&{8&i$tAG^n!B>C&p8I{K|@g_Hw>e=L>jd)%l`pjLdQkPZ;iL zYUf;x=ai>kzEBsBxb9rS(`WeU%jn?>uIJqqyk%zZEN1kA^Ci@uD}1@|75w*A=LXk) z%{}G(6t1hhrpwfp=l|2ri`c!V(C$^|>pbdsi>&jK^DOL+u|qaI%<_=MotRtc8Gq;a8K*vPfmg(57px($_xa2>af;}0 z_Sznf{lpoy$4@D^4IO4;1g_8XB>q$Hw83jye28qY4?frEvFU`g4Ke=nvIjP*o+#g=(l+Xopk_b)}5h_4=3JHR}_?FM@m5ND9HtR zv``1-@!tHxtTA3eC>;o?8!m^TLjRKpEw=RSabZ=6w$oL$oy3DRmtG(S&UO zpaqHoK@iOba?_IrBinkF3EpdicG}~#$-8d0d#!C(m>1mj#`>Chh2`ZtS0E6RzUWs{ z2E9!{a>l#>VYFtp><3lTad#k?AfISb5Rwg-yYTH+w@JlJ{_=U#YomDs;wYw3>Y7$v z`REUk^sj6UWue*X;XrQVJ==IJa2uH8wj11carCaR;C%jt%j+LO-YBAk1Y#>I+^%Z} zb?PwrqdbYMEnNPyNJ6GSN1LY{_$lRdC-Ef$+Hp*ctDhqd`SJPU9m_7l>G7L zrQJ@;z0|V%yLM;Kvyp?UfRjsmd#QU6>~`7~G>c2n9=y7>(`&P~aS26M%6+UhF^uMV z_aMr@*tA>Qj{VYG=-sE+dRFOwagsBxTNV~Ou&nhLn;mch|0QJPbC{sM?;dPc&ql>> z*v+2%st_F^D(ESi8=(hFZg{l7dtxQ~WbxJsI_qMVo+-ujbzpV6el*qXU>9OWp*P)j zG>3jHB0?)6Y(JW(vY-PkY@~46!bG9(KtZygmqGWi>YZLY@T1kSTMHUU`dD6wI*SZq1X01mTLqm?(~okJ z&j(S(+IsJuSKr-u_02b}cdu{0_G)z6?_o9Q=vX_#u3G_i$HxlN%ahLMyUiWmTHY%$ zN4u8NrdX!^%O|Ar9Pa!tAeo!atJ>qdn$@yuR-M!5)Do_uTG5MoQO&Cc?y{=!xcIog z@ zIT5!wiP4RXOrw&F;`2=_(7>p-a|1O{LY1V+uWEN;Jn-UFHDRI!_JszlYLT{n8MaKP zIk|Mj7ip{NomTgyck#e~qd^vx{@0H$$`dGi^mHaYRaM@w4^{DHbT=`icm`$f;Np{C z(fZY&c&yE_^&nnH>%lwm6=bjuELCqPx98v22I@cyv{2iF#CFtuLo5fzVJ2ibxTl8d zi^|;$8(Vt+cTxfHHmfMUIleovfNW?aFkx)rKukKYsl0aG>q7UjAHYTf;kMkN7-72& z8W|KDNdN+TLK-WHQN>(p-Mk6442R$Lnm%+M*S%`)20{1grAvMXlFGejuVZU2ed+nH zq%&k5y;;ph1s7yRhKwsuhlv#sU!$aeB$}Bxr&XOdBT~I6pOD?;&aC0VTey6(n&&cQ zwW4a@F+S~|KfXf~z58Q3lf~n7@y@)2jH3=nzrhK)y)?j{9IB+%4D^A49m4*gyg0}N zdZ@Cllvz<8V*l>W-pzz`qO|_&H^JECT7sK-tzOH#2Wja*Qhw8hY=AT-UT4~^4&*S@ z#`vV+T+vPP!v6Yy!EE@inw9dR(r^RIZGvBebJlGUL^^EONXIZNA);ASf*A)XOZD(5 z(RGv~CQI$E(Oo`n=SQj)>A0C(-3w@S8<$T~*3gP-3IEN3O&MDMx#KH*0xKU|GCGPKCJV2sjIFDL^&%^6TrHHdy4m$9S_u&hv7_0VV-gKU@yvpa-3X!Zh*e@baL z5}_p5JrzmN6oUh@ABbZ~NhB?R5k2wNL={;KKG=hg=Z5;@Me_wS?N+kIQ5EWf-R$n# zYorSyZP=}rT}@~Aq-Pf$H zJHQ)n4k_vXqTl57V^XgI3lXu3`Q)C{mwcDy&xFRED5h9aFpg9oWnEKF>VwH`y^7)waEYHr(tqvvMh;buC#a(*DIbmZ`0szbEQvNRrQ<6!&P7VrkwjY8 z1E24;SW6P;P?V5$Ns^@83+C+NHG1rIBsBvZ->>Clr9b&Tq@Fd~V1fkx#yTt8?ccr(Yxy<-!H~IpwiBB4yKddv@g!#4 zfC`inBh#G(nemLM`(zMXYSA4+41TGG@0HGASW274MK_Ij??~Z49)GE-advOp?tE0 zV>RKwzTA)t{|7Hu4RIEErXj#b{6gYQl*HPA*hIcofDs2x0tMyfqQ zH!#~;c9&|g?l39kL%)pfd^$wMd|qAFD(Xo(!U@ks*?X|cJ7SD#bo4`1L%;0n__N9m z`8c!>@*$M3htL{+33TM>`IDKk`ED9_8)=`iXQcEjgHh%F57JGUP&tBK;3s3KO3L|c^MxA;smd2#Glfg! z7=Wdv9g>Pdst1G+YW}Lqifc#`j1{%+uP9nq9x79K!HA8Z*08 zK~_}C^6rTjaVz(mELA}#_O~d*Ebsbpra;;#3srdzU=o8S%fgS*zx}pqnfwE_WIY)PlN7^8dyG(jlKtI{g2=ie9 zJX#wVcPqGquHa-A6lf1dXb@xMBp*qzRLX~HWDqnkf~#ucTc{A< zrdmzBPx%^>YEjCROy3n9UmEYl`tv&w3J{hL#0|88IqJ2eTx|M^Dm{Ci5+(0jNGvR#PQV~iAN}e7~vrQ7zdgT zFpdt6v=L@;@NzJlaiqTs1bvVv13JzZP@Z!PE-#`y$xrd2&w~YI9>I&VmvyoW3Y4@G z%5&jFIr^VQd7cjCzJYXlU%95hgM()WWz4t`mXB1w6J`VWi1wz)tl(UUa)TMnW`~4T;LU0hhlkZuoy1- zzj621z^dTe4J+Z|3cbM+=2O8n>lDMKaOp_L3ND8QZZk`5mXu%)vs&SKFNZ7pP({U$ zRd}kLsVxQl{U6`L$wNsk*?&07dfssbd5cIiYDa-YX<;`&kr?GICxB6&dEven>|&k) zVj!U0)!{6%Ni(AW46?fE@3=p0jH<~C1x37z2noX1Nqvb4-h?*-ksIOhExM?8lNP_oMF za1N}QH^ER`vxA9>`<_37o1Un55nJ=={4(5KAh?_C%|q=7?*CDnQT%}7G-Bb+!6W0>@J~{~Snt{tchJ$0+1~`@K24}Pb{w8B z#O*}W+jf(vz=4Oub^Cd5Z_V4g_>y<4IzM8VX4tM_phY}OT@vmg8NL(**y~8Q+Xup= ziZ4+@=BpqjC7OXhUmRE%smY^=uT!bv-vz9hb-gwyH=4D+apRlT#+%o-Ubk*u-?;Vm zyVX@GKw8vK8%eE@xQ{%z%>$AXfM4AQ;$^C=fFjwu<)f5zry8JWN4BWVByO@;vec&2 z0=AuJ+Aci(+q|kId1}*XfwZmBp9QH?%W=kjP@95?r_Mbb%G&DK*eRLwEEdUT+&-|< z4r6h)Fj>w66=)aOsaGp=HhIji&2~FJ2OQuX0I|V;9jSYG?PTWhWFo`yc6Opd$}(yN ztS@`ryh$N>$~uXhejm%}zk+08*3h8ZYI(J!0TtFu>axD9uA*caWty5tTGk;r^@?WT zzM?N^3tE}#9xoeEhqeCaeliMXK`!<%rr1XkO~JQ{fD3Xs`gJmMl0Py%>7FI#yI{f6 z&qcl_f{q)IeGfG@9W!CZKNXiE)@Mrz)Gz~+GTR_|iAV@Xhhq5-T4Z4gUWRfC^UuFb z>gm@2lQ~919pu_RdZqenQnmH3k&%gd7@%BVhg7Bdw^5InBLD3WG8Of@ljHg=9*I=g zpO+)~x8ok$@rdnY#I$E9(G3k_e>v`Z5cfTp==(YDn~dBa#eIKY_U(T^e)3207^s(T zm5h<}1Vf?tL%gS4r!YtBggA`f8{jeI-C>u&?nf#=I)Dbhlvf{c!kF)2a^f45{3%M%vm)N6Bt;O!Hr>{dpgGs4RlpJS zzSm{jMA{&EDMH9h7;+c!r*RwQVp@obWLL-=@SWkJPP}Tggm8wF-sCjB{xx*D3Ksyt zf3U)$USX?Zm}cR!EUL>~N=^dA0C0)PC1in=wEl&kbn(VeG{XOKxR@pWA|?a200xC4 z50*$lHLU)_i7D^W6xjHZtYbeMwL~44j<>-bto^tx`{*Xxa&2s8z_W0uhETiT*9W;l zeozQuP=tksJ}440feXfoEI9&RfmTz4(g1;M;gqA#Lxm4Z@$a*2PkhYg6fpGgQh<$D_AI={Z*vY#v zI29BRr@}cH2#evVa1qGJQaE#1g3*9J!wPZ>l)L*qywNm_m|{3{q+`CzgO$N*2-6{4 z#ao>YPopLi&V-dlW^g7v6PDq(1a?s^^gm9yLTVonx1{(X1h&CRiBK+5o{0%^^#}4%wDupY3`R;mJOAMDSvaB?EC}zdj05zPZNmM$#^sz%g#f-R`TVD z*{w?0=YR1dk#IW{Dg*Md6Ghh!m_O zkE?*Ih-)e^LPV{CjeH+$1ndgzpyUAEKnkZ7<(SEq!|5d6c{MI!6Fcmg$zv5aPh0~uOO;u8}G3c=I{jbXY3&^z??-!x1l@^mqB=K0;h5PBmzKRNl2sxJaEW= zn3u)iB20cLfpv=HhQ8$Rhacm?yFNmG!Q@)KX5b;bb`2Wr&JGm5c5sof;2|_T8aG=a zQe*7{1f@t@`xlX`&ct}>&hzyIFLkK}+r?!3ia&=WDuUrO53H?kz4ex~mQAW6cW`8o zjtP9V49*Zyf_?kJ z0dYZcEES6cDH~7^iSJclLxbD z>8knu%NvLRBB+n_Cf-Xw_V*H&R6szO|ly#|u+`+2|I+*gX(y;{) zmWx*+%JUg`aL}jlMV5UeN;qf{UvG5~;}2zG%7=#LwyZWnDrI=KLFcwwlIL{_m`oir zC18?ItW-MamYjfbbZa20Vv? zltHa(_K#0;XnsbIQ4WSKzHMvYls;bNf4&?UbQq%fIIe;NyDjNci)Jw8mWvpJ7?B<$ zN`f0x^X4TD<_!!RATDrRI*`e~K!GIW!+?R~=1sEcorat~QQ3%^G0FdycR$ggn< zMgzMn7ZgPCek;k#w!?M^ahV0mk_XDNNOj??u50>WiOh%q-`uer|6Ldqt4;xLbr4L! zbD{pG^4oYf_Daw!XPl{h#s4n+Bl35ae#M_hTXLJ0eiI*Q@-sL{=`))6E!3C8)ub;u zIuNos_*E7fIxwcwY}%ZL)=~~jN5=cc;LPCc;N0LTl$Rm3&){I2p(!HFj>0GYIeMHq zHoGM&BRe?nz$X))hvxC`?t1*Z@dFjRHWMyBR4IosDlneTg-=2IG6v5C z_);!>28h+{BhCK=Bh5{YgtGZzS@?_wuMMrv_^P7cd07@N@;mXY22bDhXveDy_J05o zrAo*xu#Og@Fh;`7Vps17K=a;7U!@X41tS9bi;R&mr4GJEyb7^J0f{2ScAyn=YSLZO zQ5o4f$?QXHNvGFU$dhpf-8g zj#xL}*fGz?r}aFUiKa_mt&sVL;1opX4K*qge{_hcb1IGP{(&mGJtCYsG?p^UEN<|GVm~B3P!-U?kJ+ePf z!1nHpesq@jnC~P)4MQZB@Yn!56ko75^ii32Hal&1)ad@3;FkT%@#ZHVaDi&*Bkf6Q zNI!yFhcr|jmNYCO=}W8j=0_B9R@HCbyy?+5$M`lZ=ybbSfve_Q_JeW6miHDELlMa! z8FPzikfm?`=~v!+?W(!yHY72VfF0t)>!#oMdOycs892yn5RqIp`mnmT-;E-~pc8P$Jzd;E_ z2SiKKNofIvl|)K@FrdKVyQr-`E7gV{QS-k|$?sC~caXsIL*LO8Qd29(v>$8e__&PC zc0eRBdBm^NBd=3Ju}v~ugKX%i0HA<1Yd^|U7CMuRx)^~|zfX_+fRbNDG6W3dDPy#? zmgzu6w&l;Z{HfRSOg8ih9_c($F4nj4y%rl8u#XZI){hE_@(LRS6abwz_I`>FXwV#x zWy^COB&vJ7W0QJNJ^`_QpLCu(crw43hb~^hh0m`DdmtShns*t9MIPlP@^L{AT1KE3 zv?1v03z~tl(&Lgrzf@Pz3BAxBFJur@3x8M^{h(JJdepx^;X|%u4$vuaOyUh0iwl{> z#K?2Hfy@89NCId*q~|=;{9*uwob(<@eq*2q0A%r%uD-7$6d(&2RfjeZmm+8UHXu@J zdk7y7>p;dJ104za5pmmP;7-jP+K3Bw#jKEUWw3K1f-q-eL7xj ztH&5mhfEpcJVWV}dX(hXs`%6vAK1v4EefW9qd4Y&kWa+MaXxX`%>#S@GVBFk zS~W;THc8LFn#f7my9e<;SAsjz(LO27%lI8{ahH!oIqe*{l;)~u?0t^TK z|4&>oZ2MV~segWi*oU7MkH(R#_s_k_pUDs=ep2gx7T?&veN0E+j23PkbmI?riJ{B52Ilq5$q%WF zc{THI;?>nz$%y|cRT2A*N{H0s;1Y5#iN88$~{xOmo z&}#lSRS3V3u7XL zeZo0-Lw-}A9&`!_BK;HRdw5tjvEBa?lS14L+t<856frHAa2Yv#{|nj5!^Vepn--4V ztDNaSy&;!8W?oR5CtIUyh5={fFki(F{$UtF%Gk}D;%`x@Du2$?SHZ+vEGsHn0Q2cz z(LugsS$BJOGp@;5meZ+Qmbi>ZC4P{93)(3Cn>CS6zKsn3D*#bur`bU~=2xiMGnA}R zLK`lgqvQ*eyh{nI9{-GT>@{g32S)^4-I9t)lMn_a8A_PO{~F~EDfz3Eus-rnC`YR) z{uL#|W#k;KrI*8l)zEtJPv9Hr$N4Mgk!lwKOgzqWKiVgdz#k(9`+sTIsFZrw z&;#W2#|D+I5Qu@XKc1}3Q|-3#DUCu;ki{|%h7p3VRO diff --git a/basic_function/__pycache__/CSP_generator_normal.cpython-311.pyc b/basic_function/__pycache__/CSP_generator_normal.cpython-311.pyc deleted file mode 100644 index 5f70bc4a231a585fccb4ebe1a15bb1a3306868e7..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 30681 zcmchA32+-%nqK1~KmsH{@V=Y3C{mOt>aZwFGIi6sPuuhXLuinKL=sd3lqDLNbtbbL zc+07x6K{r2JPF3JC-hoV;Uu1f+00awWHzfLo1~lRiLjk8VJhWhrz)wU%QIDzsY*8A z|GLrW21&0cSF82#`t|E~zyJOp|3-Ov8HelsAI;tV@g&FnCwjlB0t601g@#^tv#LLEO1n+pQ;2W#Gj`o{i(m&O~RTVmKL1@DcH0 zA{n0I6G?F{lAIHxd?Y+0#02(oA};bL^bFUG441{I5Q`*Z@fm){iEEep2FuE>-;O2t z>9{aA73GEKL~I5nCvQjdAM4ejw8&IAkr?3LxE+h!=I0W~7Uu8Gg%hz|BC3r~mtBkB zila5rdvnp5NR&^;`9yen7R6F=`9`9>vtcnj9Zg0>J~10c#vPI9)D$norlT_nv@9`D z7M_`$ilQ!*i>6IQBXd(>k&lb9=u9$9txxQpjVEF()<+dh>Vv}1-hP;fMZ!~45Bb@t z@LVExYbsi{X`m8={8*BY%|xc=1oQ)qxg8f0JdIhtozZ-msBHl8I5LT`NLhR~8JmvH zvsUo;V_}{(K5<)&&D@R6Olob8#HXfW)EA6sGRnq#uwgh&t{~NAMGj+>TCD4?}ln;%`hv zg-Lo@8lR;_p?&1T3J8bTbVU<*>BCHiCgx^XJJiza=*CoRCKgf4F2rZzdfazO#2|>b z6`Fz4|!H3^+{o8WHD};(@DSLJbemwhyKs0ctGT#hVGFpJT z9Hz3JLM8GBg`g#G6$(2BTD) zZ&jhB36Ic((Wn!e%{8IbPSjwk+Z;EKrl};S(1P0Qg;soP=|>@t(Dv(>EAZyJ*?4}9 zcA;PGXidZaR*hfPgW}fE>TBF`Tsu&IlhA2teV65_8#CmIR->&w$lEORVy1TBe;@wu z#7I?}M$VXVq?=c%vyf+^Oz1~BE!Y=B!a$UIVOmY*n$0t)ua9c%i(Oa=ZGsmFVy_y| zA-(~$!~HIx3?A37fi}YYWbA%)29P|-$0zatnBYCYu?Dgy#rWJTkO!5_5D0@d{B7g_ z;AQ|HpgIPWG%Lm<(F9)F z8cpq^OGA464u%(Ao|uYnCYQ}H)|G-(#*W&!W{Tk*M>JlY~mOM}MHkyccp zXuZmZAp=dIA2YLqGeTGt!w*sXHLXz(c~cHThty(yH55jtQ1EOVLzo{|piaG^ESc9v z5=28h5(7;l@NdMDx3!7isw@n7$Qb5OY$n9IlAoqI%es^yFh4 zcc+lYL=XwK0@}fszzo{;HNG?a5kkFn{I?3<)^kot;ySAz{|0Yk#Ni8utm;XIC9l(HYgnpgx}LgL_m5rYAx5@%*pOMn+z_k%)?k zS0>QW*efTaiMz@8>?^Ym!B)`#p&?W;OxU+l|4%n z9f?oR9-ZH{X|;qv8N`yI(BQLEad5$jqlo0Q-$YXqJc744spqnGuXKK#SQ>tG=)2$e z=o?1jm&n8Y?0^2rpZv+s2)=+vE)k=%(Qwi-aP*1|+*w4vjbSqbP{BSAP=z!Thp45^ zBRA5`fS2W^bZigjMK-Bo8y-d4VPZQ1G{H8Tph@7x9)!R0JUbPhz9oc@&i8+MLQfrG-G)hTrFqLDaaMX&T=lW2dU@4+k)cX93VoFA}&&*;25XG#P zB;q1y9o0#TPc4msk4VO`8&o$K9s%Mw5C$MhCZlR_GxefYhBR50L6j!?oT@Vsotn^w znhiP)e}c#gf--DTsGj`lpKsW9Q5>VsXdS=BZPas>!G&{Dpa=IV3$w%8ua>qJY?fM1 z$-&b~@U-MRy zG#y&$SDH>gzNa*eu}{{zcPv#aeWP;sai#nC?_YaT{fDjcsTY+~FUs97QjU%s+nA55 zf6|KGDEC}edM;CX=SB&4^tkQmk=Is_ye1!cT{-f)>=~gqEsIH|Z6w#)@wD~X)z)X_ z)}u=6QF_so+cEHT$EnpFr{o=HlpSX%t@{@quCx229p68&<_r8zgHqY?@!pRPEnQu% zlsku&&f%3&sq&~?c~n8@JEr)K9cgFMAg=&liU>uN*O7K5saE|;*(_w8@omagz~ol+pSHcz`Of*24KT_+fA(~Y zSW9AuejRpzkQng>2Hqb0FZl-bm9Vv~1}3A)P;@F90-C^f6(RD3SOB;r{6GUSm3RqJ zaMIvUV)*nFic#Ifu^^x7$(OHs^vns>7E&E}NR%Ka?HGN}h|l18e&^=xQh+#C$S{fg ziMt5i;?{#T8F#L}`GblL&QW^RriGcaxyFvCjl-*r!vLX5AfsBUQrsaEVvik8xG2`?yt>MuytfvKeX@J z`k6&UTR&Jsb_>$xwnN%!2UrdV5!7j?5p#iEuru7`WTA_~bP1)VoGI6&Gwn{5q}&sv zYOsfIU|||0&Pfv1gN)HBID_N-3YOORmiSdoutEeHfMtT1R>p>bo|sJDhM1o{X*853 zKPLvw!f_86g^rHdgIaX(VFBD4w{q!`;K zXi+9!LjW98%QZD;{JDn4OvPG#)8fewF3R=&N_~H(JlEK?I1Q2mZ|mw8kH~d9GQRc3 zzNL$Dr(D4w1qO^kyNC+ruFR+)^F9h3N+JT(doHUqBFm1-SDZ4elx$czhS$z4< z2vgcQ=07yn$0eq}N?2DmX|KZ2I$<4$MS&Ni6Oet7`V5?ci4kBzRw(c0AWUUYIX}@P zx7<>gbTwbn2!A`7oE_P{I}wLSGkQNf2-* zrD@*=XD>YqKuRHC>}ck}nm@Sk%zH;=e~04l$X;3X_elPpwXR)C*O8Uma@QH9>&$}l zy^6(qa-jW}8?M6Vn*l6MB?GYQh-`_VBmXA|y0EYWJLDLSV(Eo-`J`QfBgqyV8iVVO z5ob~nm*6C3g8_5^Z1WN}Ey)sVne&im0Hq5Ip}S1WFfWGF#qJSE|r_PTgwg7LeLg&8KDFjNDRxB1Drm=M)8s=Cmh_%;XvG-z+&5K~Fw7`nb` zw=)iGFW%^^Yc*I`8YBHG?D>>Qyal$k=9RKq_U+rd4pZQIHsWkjWiOU z$M+%u5Ag^B)jk^oZ8SIC=hn8FPXp=N46q$1UZ%`fD6G9V z^L<~tj|<42coP92d>L2Yl=10wkh64PgF~<#6;xI8-ZOglZ{t^PenCN4^`cVsV&-J7 zuK8p8N3M6{OU|Y6CvLfRcjipaSO1;kADqq}eE-r1mtb6IgcSvhJLusjk94 zo(KO+`!vM}r<9ZJOh!-^JUfu%<(yIjV*`~tAi%rC^ z^%#DKfF0~&flz=jg{1d9RAdR#aS+u5x(Fh6vMWFtu$Ca34=6q;S+lhunG-gu24h)H z^Zp`{gHDfc93wgx*ZQ$Or^e~i!XvJgsuOSFK_|OK$_e>5G%^|?Qc47~7WJ|pXc36h zwS|j`Z?y{K`fMl}*RpH*=DRkp!p*t9gGv+U5wHcwcYW}t>}k%FE*#3$wq^U|+Ww69 zb6>?fV+-NM?ssF^&UfRouSfCqd_VZ8M)nP6Y`H+yLik_-}!68ai0_ILP=45h7_FR^!n_AyC{z# z_crou?x9KKW$it+m1^zt%s*pdCQNicZ|H)|xcH;cl#jz8x`?+-Sk{o|jfNLA<-&;LH#F zIzzdHdJ>NjVA$64!-hdxc@u$Lb55x_mnmEK`W0`}cOI^KcSznHxvG}M@OwwHJD09} zv}?I>W%RQasp_cYJDSt7&98dyLfkzw_g6sqHz* zOLwlSQ>yCJzcec7lB&A&vmXBJ==z1to z!Q(OMNW!97q&?0sUn8K5mHjCYLPF76w~CGm7;O^R??bTvlSRx-+LZ#m`sx^j10}$A zfJSMzS@2MxPE&5d@yKZvBUy457^W?9GW_hZ3w)LKAy>&Zxyp=O<;Z0cm|4p58o4S` zr70g1r}$H4sd5$zq`avL7OTWp1{0aA&bI~HKEt$&<0kBY%&wFl|0)SuTE3yJuV2Z= ze9ZEVpMB#`1+=gG-1A2YWWI&))`czwc&m%jpnEsbGgPkamTHS6f`}6RY`qvZfWpE6ZTf(k0RL9Pl&{GC z_fb#cIZzyo0S;EbSHIZ)Zqs{BnUb}dR;8wI=@9)}zANoKF4vq;YEA&uHGv3sln${V zb1v7|y;LSQ?#-NAtF3=;V)5mrds6L?Tsx%H4*l@>XQx*VKe;09yCmyWmg863T9jk(~0=|Ub%$OGg;cwc^(s>gIHA|Ij zkL=bm(*SH)11M^>Y9Rx>STvvFzjkE}z?e^=-|yl7r$z!my9pe=LZH1Q?M{2rr70JH zJ>YA0)jaeG{vh5alxIc zB(R?fvapJw0JI@povIdGciCKsHA&Ln>6ZZJ;jeH)>353S0b>ae#25h+VV(a5<(FIA znXZ8?GncB_#D>&vXNE%PN9%v8_Ip^(pVEqE7aPsGbbY!ZRhOz~{b-=jGKztG?o_q4 zC6qJWn6it*sYa}&g3UBThq~0Xd-I!oU0=Tz#&9=)Nb!AWOE=N?sfGzhx;fdMZyVV6 zqH#LOE!bZD>*GN-$e zB(l@5JH)23E7g^9JgT+A^u$&pT({i_TUKv9TUia8^QO8Yb{niqQtgS3{F+TSV>Efa zA(kI4N2=+2n6*#Y%+YIx4H#TVcP!Z6`Is^6jp6@>%4;j9CZn8gQYDjuIT+6S?i=8D5g zG;D$)BumB&tu#nors|787a4{G4PxH#E!c6~jtZuxZ8k7g$nx4XZGAx2H#mCHGmsP!~Jw=p~W-n(DPUHy2MnQ9Q z;$xb9Gc(*T_)>j!TQL7vO;h1pkXu4f3c~D-xOmqz729ZSU$+?H>+*!xb%T3qNf;)~ zQ!)4~82)dlYJ!aB2lzOaocKm85#?XM7@vv0UYLaW*6Ub}y)?bOhM^r+9i&jvN%#c7 z_{}DWe+fH=u}5UNHN!+S7a2=5?;X2%`VuVaX?e__;3uFdrgbwmqcz4X^wO3ANr3s8 z_{=WNr|o_;Zy!AoCSzx- zRvR35ll!*K#=YYeV9D}{U7wkDyud>je$*13WdmU2bXnD}N}r`JKbRGdZXuc4Vl?u( z8!;G(9u)M)_=f=r9>AULSvasb`WHsDCv(V#W#0S3=+%q)6*6Ls$HnL*ME00|wnVf> z0%%0AU}1lORRZ{E27QC`JB@G0bmdxMso#b&&MY5BsPlr`|g#O>c|4B8VDI-G~7z-ut zsWqn2PB0@dW&maudM23ElgVMkzd#-Mi`0RNh_*6xpAvY%lE{4up18@74ky$=dw`|Xrg`y3_Qw0KF5OtZ@yV+zFUY%2N-Zaqnp2rF$~l*vdq2JO zh774r+cUV8xP#-&K>H!Vx(S8w^GmL!Uh?v5zDg+=%Q_Kgx9q#4`0hx)JJ1F!>|5u% zp7MuR`NJ#cWqwTI$0ScnPD6A8e81u2s~^1}?K&cNA62@K$_+;|6*+G!P2)?p z?9t_u%DyYo^_%j(o2#8SrOunVK(!LswQQFIdzZs2?LUbA_|}gmrPDX%qc12&Uyu*H zAO~K|jM8WhOD%ip_B7_2TQEKCi-*wT&l_47lkYcu(6qq~l}^~cU}0t~=dWMfx9V?$ zHexMMBh`*)Pa@E6IqEl_XjoXwRC0Kwmh^nh9^0AKtbp~p!g5K zHmAIql_F3nf_BS+s1k@ufv8@}dcINFxg}xwl+-gKH6G5KL(BVjKizq3b?33iUg^Ru zNtls$#+98=AJ~y(z{g$MDIO4_mS=xyGIEx)*ED$WbZK zA$1Nfk1Cx9AB?T>o!OJ$E&Zr8$9F0G zo?JtV(lE49?rsl$!6EoL1qK+&7WaKlvwAk0M4;VrU`z>&NrACkU6&NlZ;V8BI}P51E!WkZ zy|>h!JtPJA+`yiv17}wU&OWJpGWz76bXAZCqRK#24)ke*BsKK?_l<~+R+P4`G!^9& zfGLYE9bzI)i25r;h)#tyl`V_-=8Mkbu(HT$huLa_eBcU}8`hlAQ=MnjuM?t8SC+`b zO_x;!%5p>Kxs73;(<+ikds1$QHUWPjBDHTViiHRihZ0!70|9y=MW(cl`5jluI2Ckl z-e@^dP`)r8Quk>b{2BaF{zTKfxK%eiAj$hvb9!~wQtYx3$W5zK!cUqF-F+zKd~tLuDP*S}iVzjS=rCfDs% z>h@-Qv`fw}cYQV}^P>tsN)|@|HV9;2bLRBI*<3yUw0>x{en_q#R_ce}9?P6qs9tOB zRay_N^vJEFO6zFm688LJT&~*%Qzc)^Q(xDruPYl_dL|o^eft&Peu>>VKbiAui}>H5 zS)xS2A0V*k1W2|{i|fn3gYeJQTeSiO)oZbA39480H-8lt@uqwvmp2wKjEiW&m0;7j z@!ykjWbwTD+G69BCJm`LBwa;|l1WFevkg$jLNC;ykkV6aP+^u8)S)EHPnTQTZpgY) zklPoOu-5IEu7FgNwf?v3xBfb5NQ`0TXh{E4<(n1BUr~$S(o4T}>j5O}VZ95aD=oFr z+GQ6c{Vr>}A=SrQecXd@*QWvx;ILa556JsLtw$UYJMk0m3}xzJ7*qnopps(AaJtG; zr!gC;Dr+mh_Ka313(WwpYZ`g1tWS;Am9DYWOS}cUpa-b2_BCC*P0re)nvC^mZFx~? zbt(5YI*cOCfsIRfw#iv%t-%H<{yR4}-#xx>g*de*e4hi;!NGq1dz;dPw z$*?ktM$`cKLyh;)PdnJc1cR{TtNjeS@D?& zP#sRxh-w_L6Jpn>_@9ukuR%jhRv4=~P$$^)y%TDk$xn{)V>((?y);TXlZG~E4`AQa zssd*UCXiuEK_st)0xA=Ko0`bR84PqJf2sudN$`GRJSH&4I-=U{kZ=S(N0bY1;r$4S z24O&&r|LY5NK8H%<}jx}e^ILHBL@xgkHj$>@LJ4sn1xF50TP+00Huik6>Uq<<}M&q zYwS4z?J{@)f!DpwfB}rnsNC_DVaBXufwFTvO@b7ZkjG zF*Ex1n+uMG>+6lJ*{b(pQ5c-IGl^w3&I*GDk#TppD>4k#T5RvP4vMrTl)uZWD`s0<+ejg+adIBt)=aQ7o_g}KTNFb zQx1a_y`~(zCbwKyTCQgq|@%Mdy@1vtD{c`_lrT?_#KP`Dqvt4W*KRV&nl))djr9$PE}3Y>gCC zUMnyMA}QG>XK8WHcSg;mLgC1VJ)l(Ii6 zvTFef8$=gu(0jObn5KFdG=1=J9{7xjF`d*I)d@f7={$=u$A5}Ll9AC8m!ogLxD1r8+j=$OayS<9w4LvcNIc0-_ycm)0XTd3nxFbQSACFo$i6^|+d)Ry z{KHJ3cpn!*0;gby%QZn-{o3r54P%tQ0cdJw?MybOaJY4823KJ5?1+0TMpmWdmRpoo zhI}5uVaewuTiK#CU&23Ff`1yy zPZ~x;SebZES6Em;qn1>uV1HE1fK@PSV87*y0>ifT7pNO1VOZGtt6&law#Tm-r63m2 z6od1~DU=k~3?|WBLcx5j(%S!o?VUeHS(OE>l@5Y=Ha}OHjWtq5xF}t1c^kBro~}t% zZ^Inbq$(a2o$dn`&M+~-oD~`P0PEgir~t+ssd}{A1LnT2(C$<)RbO~wZ((maO9;T_1!caWu7kM(8DN|{xrRHm z&XzU?WOJbz5)?Q@m=I`MxhSj%PU8fcXu+1CU_@X+*0lE+vLNXb`Bp*>A{C;HuV(2p%~s<@ds2p%z+YW^$eWCZ#{&v*j!`Hrc_pODGdYDPpcIy zYJ8p>vNXETh zTbJ>!Rn{(~<;rfQvU`K;FXg|WATtUVR|u0q!PFnfjIGyqW@B>wZfMNvy0dqe`=p-3 zO5L*%CfB!Ry-vYEQJN2AF63IevM*)Fv*Sw39>(Sk6fuDP*M55Q zC29OMWn7eRCX|~Axj!NKcI13@Qqw`%caUkuTH2N7-OKHA^YBj(oRgc+WiDu@{Vy(! zO0|P>?VwUSxZEPw9)&GGv-#iR#pjL9AM|8*E%&a1(Vu=Yxq9@5+Uu_QIi$24%3NIM`=0WLR{29KgHJ9<{2`eiS9oyyEuRk_l%5%r2cK64 zpWnbaIy?*Hz|IAa)VMQO*NaLkTXNOS>ve5e546hCz;n{*6=mS6Tz7TBxrRRUWuISm zE(>zwVWsi#f;-o{XSrVK9g&V)m99mUBZAy3NRzkap4;Hv~}WxTE@Tigc%vXcLzk{VUZ1!^*L8jKgeS;5vIVo*`YLM#}g zBx|wq=mY`qg;NdtR>$}<_PGk?-9c!4=}<6&x3O5# z5~5pKx_t)TWqRD9jZw-yX``N4YBy*Q)JUE<`o;XDI#6;6{H210+6v`sWeQHgHECz1 z7S$$fRT~+q>DMN&thz00DJm6a>Yfj^UeJL$`rP!D&L6vKI1t)#&L9HOWCGx`8 zW)nn=W;{jvNQDdYa!TcM3(b|-%2id_mq?=Ho=3*fnw&sEnf+{5XGS8;M(dE4hT z^^0}yzVhBHuqbNa|60pmZpqHe4SSV_y_t&j+J=l*cOEPK{>TqgQrk(%OLxxSlC50z zcStxUW^}1Z?%Ib#5Jz*BZP}sK$}Xw03+Aea<(>oCa5kI^RxP|K2Rjzpb+@v`*QDkF z%^P?t&omTWEBSkI=K_@rFUtXb0f!A3F&uT&zIY{Di3&>=ego(1t#9$pv5*3?(o45^ zw)x5znn=k#n%gQ?#k}Gc$LJpnQy@Lj{ofYtnKxno!xp4kqNz?ICcMHK3hWX95v$yz zSZWJ6jM>z#yQ~axhZ)NVMQ4Xvu@jt~a7uBZ1Ux&wQ-qhcb%rZ)&A>SxCdmbJP=K8W zxvT+)z%@=gi^C!ffwEk{8%tU#%dMZ=p`YJTLVL=RtIU*(y%k)R_g=VD5YFFp!BZLI zRZzt;7Z_i<9P$rCLIP#C5x1n0Jc?bAC6uQszJ|LR+~CrcLP^SBa7#-COmUMN8{F0W zB#B7{Sjb%0sNaxW1c9^JEaN-~Sd61@hI~eEkk+aq**hHNaB^^}s+dB71Kc)eFq3bz z)-$-0VKo&wqSd5p({)Ku4EalQbW7Ey>L&h_QaL=OcYDLWQ;?7 zstV49)>Z&GryJnXw#^C9kbpE==(-}cwzV{4G)+{JseZ^xitC4*#3akPW6(Xw*@V?^ zm`fmT&_R#9R@WUk#q}0yxnoc>sb+}teOB#6g3Y&34pdqZ{?}Lnh6qF#rNqc_LO=6g z`};i_IHA8TM1RhK&LBe$X0g5LZ!;K1;RR7h-p3~lk0JUa1$5*M{4ZvryArqK$-EUf z9cn`ooA_LkiEPY-ex+rip~#WevRa~VDW)>Chrhx1myB&=V7?S>ITX3FXfh)7s?cuw)M)!x z%9}`LrFTFdaMBxTHuA;@e}(jEzHV%076Yl5VRV+C@8pFz24e;T&-x56i+gkUI~_v7 z`=9@nZIDqc7Yj7ZX$z@lkfB8H=w#?u_@Id2LAEW(%|n|g3{WYolbSQV*7~hzO&)tz z{A(z6GZ~4!{S^)(4C0v`M8prNly(aKI)crV6iHT=<)X`AjJ_=+KE4=7^{W})CgpW{{TT>i$-8RgW+Nh1mZuUSO1uTpHT4o z6#NqkW-0gv1%E&RnH`G%l!AXo!9S;Q|t8&i$jayg@@VwOI^z@spEiDdl1$xJv*QF9A51?ymC$&drf-% zj@)xs>A5R=Iy2sd*Fc&-^>(j%N%D8}_lKo3FUlugQck=id&d>;IOegmAy-wuupb;x zZR69Ly{k2QrQx$r`lT1ba?LHJ=2oU`4Hn$>1G(0|<;Ew~%GgV~skh6~KnF}W&{5M3 zWtnoESRH6q{KLx^o`fX-tFr%9#s6xdf^~mYCaD1jy9SR*^baQU41O;5*-wkVL(ewt zho>XL$%`u#Qp-ZHRiZ!Zhqx~k!k_A!(%-m4At4?Kg+v0Dn*QE%LH`>@;`b?yAiU~g zf4D_FO$pCXu%CjH6pT+ARLD-g7)MzFtbK629p>hauaA_ zkch#g$Mlv+VetsPXs6f?3VJ9Y>W~F}h$Mc3R=;KZ+pt%;JR2MZgIsq{Zuher?jBeB z28STm%x9aG_5(`u!3_tV*IU~+T=cSpYixsaN^~)$Gz=8-vyZgkvd!(PTiU87yS>uO$t;2A%~L1U?Fc&P>=QrGY@ zD#uT8P+RHZm&MQK;f>(pA4i{@lCHfdjfbT3A?5VzEU`G~X?4}*hWCFqvf)5%eeZ$g zxzGAHT!>>F4joy!`lE5>nX?;iO7L*~dzM|xb1PA4&k1Gc$&FG<<*dO;8#g$d+qrAQ zeb?q9HvsmNt83hFu-JOf&JCP}L9a@<&K>z|9r~%PxTV=Kln3&uJ+}< zXfOt#{c+dhxhGfua9n!nW#y%abRnXg5tQTbg}}SwpmxA@#)gf(v~St<=&;g*X{VI+ zJ^Pk#tnB^K$Yb#*^U~of%Kod$o@*N}dWR)-VV^M=^^;2wPGLJXx0 z4sEz8hQscAH%ckSS%Z07hpUMe5`JhQ;b-Lr`+;ShbmJwol|ALU`j)zvYgSw<$;T&t z^ibM)R_Pksa3DK-Z3=GMxPv1bu7j@X4GuxBTOZ^V`;W?`ew_AyV#9&-^}%5+Qwi76 zv*FHXXWtYB+6Je+-PbCP(zbsNakT{6E>PkcY$7Z1$lBeaaJ& z*zChWoH9=5%ZkLuiq$BO{QTGpxL+j1G9;@N*1tibsY{7f)CeL*0nO24kpS)?r4i?> zN#UxfwG1j*Ar<~=hX!(wisz68H<#Fp5LAK9w&C>HoB#(1YN3XZxPW%&xDn~sxiSB? zDv1kfcaH0qtapxUlB_py@nc<`k4vUjJ5UuDs!`GRWB8@@X-T n|1B5vS9syAdCP-k+0ZO;Sg7YbHY}H|!!KC+&$oHY7RUbre7%DB diff --git a/basic_function/__pycache__/CSP_generator_normal.cpython-313.pyc b/basic_function/__pycache__/CSP_generator_normal.cpython-313.pyc deleted file mode 100644 index 90c3d72f1a63b723e914900de2a9ad63daa8df46..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 23867 zcmd^ndvIIlb>{`~Bmt5j0D=I)2SD%ziVu1;;yO0yQZ5=LQlLKrrT|mY&yji7y-GVchy$Av)b+K*pj<# z`bV?pd>0o0K}fPQ{iCz{Nj&$v-}`yKbI$j}eN3OsW zp4TZS<&+FlCnT)IivB$C=dpRjF9pApl@h;fSjNgBUN)>?74Vb~D^IA{d=jS^E;ylP z)x@uaJQ`MGrA$-}fFR^}*pfG-XSuk*}#7knP~yz52A|I)m7X4VU# zh5kh^>kj(-jEnKU1Ob{EchK$f%)0{tZ{QRJC_TeR$5}5XmR?x!V3H%wGp_SXj0dxM zgt`QpteARq1>Rq$o}*ZalVbCn5?1QeJM+q^7bL9ADJ`d z7uW)4K2Wu@0I0^P23qJeI5i}Hk+Trqi=9P4wGdiNLQ9-lc-J{gSpBTdQOawNu$Kct z_xy1$gZ=Zf4{(A!;#^>JHSi~ezeit$X1`8N?w-d`;wG*0645`E$}=W*PmF>mohJwC zlpv+4c`T2(PnC=0g0+ z{j8y;?Bd&0UCLK4`oT`nKv3%Xpb`{(_@sR9SVBP7Np6=(+X6>5ES z{lfd>8>8=^yIpd-{U?U#=-5xp+<~zueKtBV$esB-W%Z6fkn5+>zZZ#ON%UyA}@2E8yVr%i+&#y6vFxZ1^>L)vxMoT^InE8U|eDd zuVh@nQWgTdHsE!$o;lZyZ^6q1fE)70>=0LA&b{d6l?(1mu8T7-@tX5KX2vDH8Od9A zo%Pn@0)95=o#AEkzCe)A_XWLd(C?oQ@XAHjKeOZsLN8h9rnhm4m6 zyz}STY8WQA0tuN3Ua`ElW$>(#Ki))^~B1fsG5CWhf+rYW3Iry&`z-v`g@6gkPj zCpZ|oRENbdWR!!zg&#MtW2sX&ontC>Qqq#K)H#`s7f%TBX{APB$Z*Jkqf1UX*U-DDZV^3OPSmNmOx4_D)ii8}f`;DQYS;sADNnBUv>(YCyCS>&o2{TG9aNMvf&1 zJNN#Pwo?5>YKU5lBvwxA}k*KeBF+@%*BII zE(r_(`3m)kimGae1mAx7&6n?1cE2(bI*CtiQ# zPW_RM!C3uBSaz*wxo9O2)7l^AQ&l}X=NXqX*)+mDKatBEMJ@w#iRrVP$YsdPcV^7I zRGckX8CJ+I&)h&yo@-i;3rG~efOo;p!HqNm(g#$)D9TwXG|}f)Dg*2N~jDeBcD8zK5!;=lbvlckd{*m33wwML0U-Z zkP&t$Ok>Fo5s0nDG|5>i1_4WuWIz;~4bJgJsQ@C=@rC%l;J@f~E#ev==~;geB^nZt zpOVma%q21rzG256%`KSihs1$%K)ynKR76=SLd6Mld8jC1GQVL5p@g~4#L8PZQ%k5Y zQEqwTNJRE_(VIoFiZ-si4U!t{oS}2=OPt~8#+Nz6*-(DM*0iarN;GueE{xXChGaii z*@d~q13`4Dx&1zbpq!seG!SHe*tEH0iR8~nCp%hz)iScFRQY3AdX@uyNJ>j^ML1=c zUfkFjxl-FPowWrtL`LqE)AHnMfc$hG`a#c7B=rnrvqagL6*7KwDCxO(2-X9NmaGE^ z3bCGmkKywcec~FBUv$qec|k7s!6}G>Ja|E-7QAvn=9Rd!rK_qE9g%qweFbe ztQ5o00CA30RH^ye@bYj(vV01r%V5fQ(V97}JtW`K6n%Z@YeQl8O5L(AVqf;hH1+T3 z-YvaV3UW+Hk|;5RFD#d>9OJaM$PlM(iK<$6&M}TB*}nsD{GvnyGbsT^YR6p(j5u{D zL%O7cnNIC0F$_jL#fHUM12zf_57{W>$T8c=L%9fNOVAMENpwee%3u&B?yAE z-U}(ygwG%JUJ{aI#yz+MsdyDsO>#rjLCp4fBxr|#g}w(KAZp`+4)Dcf3Bzf_%=p0g z=4GcKHh?pgI&k;`qQ(hp-Lhi=rz2_1`e#7k)F}t<^vhC9=vYEyiEC;&O--aTZtvmj zJ?rwA{oq~AA&{|+75CJ|akYh0TUNdhQ#U2_HIa5sUl;MOxw-n@_3GOtT;Fh1KOEHz zCkjjAg_T@kWuzij*pkqfhQA!Sz?C^xJEHoQsHPt=nCsp>5lqzc8dC;NENH)rTSRzl+mA(v{=hI0sscn%Ts;AD%19@#Asl%oO zofo8=CrYn^L~|x@#3XFJKprC(HzcEQ=4F@%Q8CVP#8jtnO3{!*4rV1S$;bgqKifeY zm16lx{qMI@vQw(1v(SDk{4w(LDbvp^P2Ard5)S2jLF_#a6%bC$qFhQODfR^VaIb)f zwY0z&XW0xKDq&e`f^hE6CsjfsR29gJU8wjdHrjX_dqL#k&TJP1~pUlQ8&i-`vcFAUCTfDNLtL)#fZB`!r z$t+iKHm08l9Zi%~yp{L1{7w0?e^s_7xs}fu+e60_8q*txZw}lTi1d8pG^g41wUZy` zqfk~rS(|^Zs`%AIB~>{g2_VZ{uRLrn*k)~3LMU0dNI}hlSHb1dR^bGeh||f)aQO|o zanaH8jAfAH46cEb!5g>-@JXHC$=x2v=7uWaa4bk_DzOpRq|{uNa_Gm;YcQ8At5N3-jeV# z)Yi#KeM`#1!l78)~xB6ROh(E0|>AevFt+iCU-#xS*5(;cC3Z z$IE7X=g+e)%q5xTC7u9C!xUj@0ucMrO>Z2CTlUC<@8cXtxe5?$kA?dGxza#u45i@<;jYlh zl&lc7?TZ!e`!qG&`xGtxNC7phl!SXir!s2jjum!)ni_!Fi@854q>7Su7UU?R%Lfl( z^v@L%Mh|l+I6B4Bt;o6rTNo^pDNIeAGsU#Mndf~@Mh{Ff2P4OA*$3x4Eg!=LnC0-A zt53oxXvG+`E!&Bf9gKG>oJtwRD8Z+6s=%iLpUSD4Di(9)X42xJ(2Zrm7?o&rYNPG@CZ$(%NRp6 zxn?;S!)E9inpLu#xE_){Y@w4Sw@J30Ke=)y`%m^rYsMHqlT`)0V*PMJmcmKf_SBZP zJ-MaK?u=R3Esg;2rq<*fv@;zUtvQ^tH62b_s^`SSo9zD{x`l4Xk)`*b`h#x`@En1? zL=1uss0Y>n(@F2jH7^0U)wWa2P>3VA48tJZ#HZ>lXjQl7$FL{m0!N@v6jkHai zXV^>B7^*W_j&7tI#57Dd-JRt1z<+@+DhS`-ek!NVD?F7+nwNm<|0ATg0v|Vx3Ucbh zZITn3g0dX!iTdQ8=*j4L?v!=Ol)ZG%Q^(Ov_dX^~`qXa*F4dQD@|)>qaSULE0|#8f z^LsC|kM7I4&QFS+89OEV(og8w-m%g2yXX^T1A3p+oCb)&+6Ku7ch^+9)G;wn`U$0G zoicRaWC_e^u5&IYuXC37DVY6JAj1f545Wl?2es@6mEF>p=d8#v_m$#mCUd_Z%>3ed z0a+-u767i7S_`-~z=$_-fT@@`$W%@oVyY&dfiVVMgz0w{(pI{ka2SvpaSTq^L5V`m zAi2>M%mKQRIY?JAhoGl;kCx?xF`=zpaN3|-f~O$MFwa1HWW+#ZJ1!e(235W46pNOm ze=PH|te>@68NW5)_0Cv*0qd;qqL*>x@e)8vVHhzQAhL}SE&|X7^8jQJ>cRJb5JSU1 z?mO@H&0D9v9yj3Ktp4*>K@$h;zzqDsg_G6eUz(qRYJ%43B_H_Q)+KEn9?K|t5eWESbR`dRBV5|4Gib$7EhFl#+v{lb@ohNI8-KR|rD z(z-J%*ek4_%N}S3>vaQ`cLp+JZ!r^?Lr@SRCS1`^%qf>3^omNz#`v#vZ|fqWBs}a= zV9pu@(<ZUJI~SS{d| z)lF~`fEV^J;DtND>;Qr5U2rjO092ri&m!&c%NQ&7E-VHwv$%~}JgxZx0RPNib|J)&i~uiN@VXhWLiuNibsa8( z0IBSGkuUQ2=KTiKpCe24IngRH-8E??* z3A%g%#Q6b&4{;4(vEns;7`X*s01y(>9>-z9y7hacFj9$r29S$hXdhrAUvw>a;ReABjChE2%B)xHV-2f5for znSg@A^O2X4mcm9M*m^GjZ-hbQ<^FjnMtDpM4-c=NcL#zl&m8bjXa%+?`4YhGkzphK z4ZVVCRVB722L@G)!et-$DzJSps;UY@Gk`)>7l->|>WVF$C0fxH({)ER-CKsTYeQFu zBC@!(o3nPWk0U5G|DMhg2lJ?|CNdh;)kJj-0O`@0ueB|=#dLP?mxIYQ(!bISFb@k{ zp+-hm_D2!0YK@dc+TU`5*|h3bRjjUey@adV%h~rv&HDg1f=P2l+V@2*m>-OlRW&zV z*IlcwSmgmuRUVRu=}#0?k@o9nzjii!Y{kAj^v3i}?{)9;=~!WHC~vEB7ub|*Zk=CaZ@svF zm}@=AX$}I?L9Gc(UhUq}ltgu>H#KJx`Z6?Kmt8GeDS53tq=3Gd%Hej-XbPzl#`?Ii zl{2=kvTN-zV^>JMWwl3!*JM8^davlN^$@422n~b>6WLg)@MNN_`mL6DP5+&m{*C%r z%@9{Mgy>O|fOw3kHZ_ifrYx?h;xtt&FWuENz#X0K83~|u^ku>XM(umn$D>_CxA(_% zldmX4l2Cidea~FJGJ3rQVz0!5#dsUp_nXu%ciZbpB6%a#qNR=qk0njzv{_4_-WTpQba##(c z-zweGmPUHQiI1!zEE7JI~V{ov{;xD4ojfA7Y{=$X+UUjY<8H*hxEJi%Ee!m4{(b6jhM zToKczwh8Vnt`#pAN7SnWoOREpw)4Jycig`3j(y+y;Kt0K`+n?;+0R6sXF2=XsMeaW zx2|5~?A_6to)7HnGe7dZ@B3jZS2GwczF%2=bK=HCWO5DeZo9(;TejA?t?!PlZ@np5 zjO|Rg=zewMYSrE9-sR)r{%~;1W`EoGP2;L;)%~sNwNB2~v%ZJ39t;m9n)a@@b4>>} z^0}s=sG&A88Z$I}ET^n}k5p7;d-!O=V2v9ZI77obdGD&et%|nyZye>?2V;#x@y1cE zaWpzM9z8n|Yn;4mI0uM|%J%!^6*nC>9Fd84rr(|W_T1VRW35NyttYwGlefoUjAE^0 z(OPG$d_1g#frYLN-WZIyzUPegjl{Z6adoF+6{q7B6I=x>jc23JeJ)nSfhoko1n7QvY#m8+_-Twdcc?qR8z$yv9APg}sO>FF{Voez{^oEYzD>kgiUj0(U zv^x5&O0t|zZE8>B8rx0jV9~!;8QFWM{YcDyk{Ry@$zo2ynBsZZ;zE92rCno-Em7VXX#znt&hemhr+6asXQ#- zveiZ=R_E^8dX|U6M^+5?4CZSmmruS`5w~^Rv30EP`BCrty+7O=Ju}YP=(~oqFyb}6 zxX{|;wY_(0dpAaJOK!Js3`Nh-(a%jsPkW+AW}@>K68*=bUSHh1#Cewz2M^<-_*kA- zHzbK_ZGfAFsnYk{(f34qhobZhr}svEOia&g8O+i0u1!OC!cv_mYm90eVU_F3R%)U1 zy{i|#bztq;J4fhP7o9Me!;&=lrM)#goM`O`p8#--Vb?<`Roiy2#JE!TS`|S4!txB# zuqzICYz@HQ*ZbEmY&69hMmTM4SQ2grWEk$}!qOE9V0i~0x@1(bUO={hsGoA_dm|7- z0p=$`;FoaupMgE81${_9mCBS0JOFn;C@f|B{sBnf!*pWJOyQGif}cD=h1+@o){kr_mW64FR;kQr1L|>E4m9ZkC%IjNQOR)| z1iC4dl!P(BsKM<~x=&z3&vu`wJjl!xXOvf-T^iaXGvagJs7(ItUQiMz{}QP(^_%7M9b6+lrr*9FuNLBi}TB9=)$D0FvAj40QZM*gV-l!N-{#>#wHo6 zqf2t#yU@C)kFwYYMxXv6Af2+4gH7T~NxP4MDa}|7rD9JqevJUXm7xo01NH;@Q;JUl za4}|#3+$L|w{6VDfyqb&cKdv|g))jaGh)q58Ldv%R+y{}atUjPF=f=5doIdo6Pd%L z#agVG%*Yo(qK1F71!#Q%i{f7j`+QcJf&5YnVB7~VJ`{@<0%HG|{fnQW``bgrRx$v_ zlS0tsK`RV`1~dcuUg5NR9+mbPR}g*yjTS-dFrkwW;s_$-a;>s++i{ zEdeNV2>?Cb(65Z&eE$0LF>`agqy>;2y0V+i*PA03uD5YIN9ahR%)V-mmbHS^ZK{rV zqNXMgDowSKapFTc#dO0IX;`g|w8YH2LnrQOO0GS({9MGg?20(K(#Ci6tK;vUyLIk6 zRa|Kgr|AhDgLB(de(mDbiz{=hqg+|rTI~;--)r6|;W`d)IJu5Nu51v0C1L)QNw&1Q zYeiR!R$k(&JJyD{&f_ud;Oz?^OY=)EC z&vhNUE#A+UcPa8wR~-Sy)#z*OsuRwEK5|@!@xhcdSWfeHSgyt_pf`nivDo^ z{Upw(He%+=o8tOiG5xN;wp4#oN>#LeypPhgegu$lZRxKbo|91eNdc4i%C13&%!9m5 zbf5_fP?djzCG+@S7T_(RdmfI|l(;gfqor`=oVx5lcg~*)T z@IC?uA}t5u1CC>4PmGe01K32_MiJ9Av<%C2=5@+ysf-^MrGTv?AC3r~za{j14d5~qSVX);dgl#{ss1xJlIA{oyl-f-?; zPQPbxgoHi<3@1GT8NCI}BV1SKpZ={U^f&OB{tDQ9&%dDl!aNK8B|R5rP@lW!=~;zK zD;0nb@^LLqw7m_2R4n)JFNq67DQw} zo4FR`S_eYMIioXCY+R{%^-F-BUTI$$4ZjpF+A>$HOnjpvbONjwuYNh9EeVe zwyt=tsuLx;*Q&P%F6RmitN-Di_@jZSZGWtA{}#4py=2X?{_JM!;HHtj(>jQ)**r6j z%>hQw8R@4rCt~|nTe|j)wna)-yq}^W2=@DDYslK@cU?IeBesgKstc>(bOjs4i=d-M zRkmL}8j(>6d{{~Jc@b5+dwr%_AZVJ|7&EVRZKOzu?=M;Y~A^R!~N>oDG#BBp5VvWS5!HmF7 zaLr~pENHIgKr*fgn z#qyGB8MG1dsxa?o&q);3d}ud}RJNlPVod~&_#H+wl~Os!9;Zwy2^=wH4k++~8Q_8H48&L}gt;``al)H+p@9}-m^&x9PnL{PK} zF+LNLNK6hTG(Z`1fDsGCaV22H^oZqDt?3dmHKPNl0Rbe^`D3WaWI3lg{UsStJ5l{1 z4uD9xS18Cu=s;Z^PpYd%P!G$-Y(jV@6?X;7R*Z=;)8;8;8pMr(VJO>yiNPqgn*hGb zY5iD3b}AUNaxs;a0=QzX6sB~F3^1Z&XIYb1CO-hlP%o;y1dM3GD}365m-_&!`wp!k zpt`Dz);H}e30hwz*2+|8q=l>hY_}~fz*ZO={MBTv!`AE^1XQOZQGqkROHq%xDs#5`QhWz>ko+@0k!dntl*GSuP&Pb!! z9+&e7*?GQM;bn(p^t(Dbt)W3n};^>@ipeWcRc%tr@M^o7oP*53Ol+ zs!}~~MRQQH|82B`#tu<*8_pKK;g(5(>gu2#P??zBbVIIr$*6a~P*)~}dL8s`ELWW0 zr*RoE9RQCG2zpgU4F@yJ7W_~{=4#K~h6b_Jr`4VTV0(J0O&RN@POJfkQuIHjw$wZ@ z?N9AnE#02crrHcp9W~a!b2R8$+AhvA(?NGQi;12HX6#Hzve>??v*Pg1@?dmplB;-6 zMq6{I1V9s6@169Xr}o@LcRnUf>dR$JBk+~3j5EhXH;R3OH3@r;to?4LhwjR`GLMS= z85C!|Wl$v!e&|Mk2e@c!`hdEO}3;-T*QYV_$_lS9XK$)p!w*Gk(@#4a^aM*R79xP3zKy<@k$eXUo`HfL_8s&QtM0!=-@ikGCp!Bsk~JjnAwj0e z{xK3%h1u^RsQ}{WBT6$sQ~*Zg5@v(_0miN)c^}CSk$iyU-y``CNPdK51Idq({0Wl( zh~!U^{3j%ThNKk9pCkDylG{jrg5;-2ko#x+MEO7Cs|!h5eZPUPU`nUhC=w3I9V9U% zn@E0+EW`G}(Ohx4L+FUI>=u8)ldUiy2A3QA7o zeJng4QPkeyJfN7WxF?~j{Qn1C^ZzEc2B0utSSnfR51)&au6m=^&Zwa)l)qJEi5J!1 zDXNcLjuo|sq}lKH3`Y3mioP_Y+R_=WHC}C8>5l1YLQ3)Xu=TML8$lIWYva|uTy-x1 zI;sywRTcNGj@9P46)qP95Rs)4eoP8hcl+|ot3_)UHze;}+33FA|Kq*zXWwvT=iH~Q<11{TJ47GhG^BE?|Ig1f8==I@m+?iIvP5CAMQiSZj?n#YZA`f5jwF| z*%Ytrx>MP;_Cm6h%9Ei}_u;Z$W`uH3vd;C!}*E& z?zIbCec$>JSAPr;GLaJa+!p{RD!KrwVs3%!0&PWHYlqL$ymc7AVHj;Wup#4Gj>c+_ z#cM~n+L7DS(J^PNcKoiE#_#L3+y@+x@2W3S_D;>ajo)ru+aGIsCf+p2H4Wa@-R_Sy zor+pd$Bbt}#{rlD9Z}s-MXJB22kca={RC$_5ilZj{e-u9IsH>0bS~;Bn9s{J4GNFgtoZXy$ z_nK~PG^XzX+?7@b_gQGfjLh9Nx4kkHIubU(RTO?4?~RJMx#^C%X>HFBdf)5)?%vxn z&OCBga|%YXd^bXJ9Px_XcPe(Tk8Vgd+Si9}%cJzO(c@Fmea}T_d|Q3{*Qa8ALz@%N zMyGtSi5H@mz6c!bvSbK37(No+dw<}24{<4#{gSB44A?U?{8@K?_D2c~9seHTX-O_% zfW>CJw5bih8aXSOg)iNm8W?v-SUtRp-x1UybIyaj;zv1)0SXLXm;sciaG>%#KYWyx zq874~tH`7cxt3GcMy6uwc23=w(3=uT)QY1>IO{D?839BK*SOXAT7$%fqzZ{0NiC95 zBzV>fVgsIJr=aP2cslZCfx@XW@QmQO_xWew8#zS~G%z59Q2%hqVZhIfB9`3c;tO1W zp`C$`vw~maa-p!6j45`x2yPud`r~E%E*IMg#jv1`Q7oD|$rt&^=iAr@eEmEU21y5! z5R$(|LWGv@qYnp}H3H#FT`mx`g7A@BfSR+u=_L@#pauh`G9xh}DG&*I0uw-fhx$ZT zB5!>lz%1?qcmz#h{9jiq9(Vu% diff --git a/basic_function/__pycache__/CSP_generator_normal.cpython-38.pyc b/basic_function/__pycache__/CSP_generator_normal.cpython-38.pyc deleted file mode 100644 index 4eac67ff271c8e16a847713e1674a10f17243250..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 19038 zcmd6PTW}m#dS2h=iosw23;}{SwJ5EYvkQVEwbsgvRw9=aC2b@~5h$%aD{iwf-54}5 zH=gbxiE&S&n6k4@nU4J6+$_5Q`$^?oH-1P`Nt|0b%& ztBO**x$XI8yX*8@uGtgayPo6vX0Yv=*TkM5*sWLHjw|e-E6hgA_I-1Ua?U!tOZTHcD;L z*fxW%*^~phW&<^zV+Y8|af<0Ua(LTpdUxFp-VMy|7AD_v8<;+Z>Nuv~vm3726y1Jr z6_ahXy7w^J<{swJ5PiojqJG!wG|jfx@!Hop@X-thb&%G~VtT5gno zt4Gy#E7Gra_G)UBYy9c!H-+mWtKVri@Uzg^cH3UVZdtpX?!A`lG^w)C?a{8`T|AVq z7~A6QbN%MO{WtXa{>z{K$)iV)KGiX7gT@~J5I|kQ*Z;3LAy)yYD0R(I>$>y2Yq*)6 z?26*%+`LIRW7Ww!g%6GT3_yUMW}Fh9N^XIlW}Pyg zX3zqE>byKoL^Z{R+MyZXPu|;#tF1Z=0v%Z)F$qQ=0RVSZ_zkUW1KHI&)}_7 z{FdRaVCPReUqtyir-t$~&R0=(41WcTjyk+`9<^)E1)Rha&I{b%sP0*;@uG7Hbtcyw zbDqPleZ~0+^!21u1WNv7RGGv&`wAhciBF7(0P`aVGD1X&AaN(=Mv9+rAH`4V^Cs{- z#yg;OVDB=%wtxYISpAL%l%2pY`!I%qQV=2%61zUnlW0ey+X0DU zdz+wNN9dBj#+`hMJybr?nn)GpgBw?5=6koV)byy3l8wRh_ZI437XO%@|}Dr zpU3wazW%>Q;;5l!D7vz%2oTj>)j!kGR3%Ux{ek+kYN&>ag%(J)P)Di<#!eX7sr#n*QbViAvI~y=-1J6R1iG&FNx(;glvDW4QK)|5WEKRz>@}J2YQ)--Rl5uI)~ws&)e+u+ncU1&$(;OwN>*X z^SQS#f{`VC(XWIFT1`Ow!8`{hvTC;N`!&;Xx4=-qS7=gTZcUfF@a=Z5Ma2yI@_p0m zpm`H)B!*Awns#0J^pB7XE^dr?nAz?F88@-c4qgk~Cg!;52KQW`-9=`2m!5le?Nbf! zi4qd9o2+nqt{pU}!{kGGVl^B1`cESX83G+`k#a=Akz1xb!6xTHkBoJmm8jwcQAY9! z^{@8C5*1MLgKu2e?zY_v;Pm3cmLo3Q@W3r^x_)q>w-;=8I~JsY3*h)(!`kY1SVFjf zA}ggjR)+{hYpu5zh%Br2S}-%7QNBrm@1~7A6V_1tN_FNea?})#&y+fghbX^k_k5 zNMDOiSq>3NK$vSu0&6*5+wJ%eD*R~vXwlMSk>zjOJvXXafK0~%lv;P4dswF9_5G+k z@qBh{k~lsO#M^Z&Np7vX-B!QtM$3<@n71VDPG|2&TBjFfY$5EuDC>6xI*SY<0a3xj zs)BB}?1s<6{SD@g$}B zZflE=md{Eo(5a8z(iDH8i6}BGL1@HiSe4&-^8f5%0o3!LY274FKc&<4Y@TnO+1MfI29VOrbXJ? zv(P=dt;vlmzDNhv=(c+=y@MBg8X&7^%HV59H|8;rJ&v78Up18v>=RWyj}9jW6-1QY zrqO6giZ-bI*ki39+7Y61G+{mq8VEbV{PenV>zVhpp*qw8E!1|vp&fPC5XHdQ&xFhY zcfhe;6XUNofy_Ss%E&z*E4_9Ibow)c(N(xA#xd}Y2>3kTG>P7hk@Fq{@JYKwsub)Mt zWERx2IyigubS8TC51k6VN?$&emymJPAxSmBk6UX)oX5UOvdmB)8aN%C4a%#-OrVD< zOGud&FYP|b!df=Gv^8R-~?xkEIMN)X~8C8r)-y zF}_e7X``#kP)~CPL{kmyUdPv1R)!gRhKOo3jbS#-gxM{!5NLK4$e6;|^@JVCT~C=2 zq{Cp3^aY|;k_SnNe@r&KIZ;I#f)B#rp_>SnMFDU)7YH5-TD_>#FkIbsYMHuRw}o@1W7ziRHSngf>Vm=Iva>mxY& zzvwr?drayT*eygXRzAMxbdm44#Cs;B>_iyFmV#uYGAK(Sqs3vzA#$+9XiHgKQkYm| zO@|;^bW^zYuG78Op;H|l$cU81qNX?tFstPl_Ujciw3uPXTV9__OQhT@n~d*jeU_S} z`%#~hqEk9#lwlEO5^so?Z{h2cOsbR(l`yBMR)9N$^G7k~bL6Na7?eSWOr*Yz3}H}A zq^?gAseRqi934zvCl-I84l@o&*Ny=-E#qVs3E#3HUK*H7_CWor0^$YI6=cz0?txkb z$F5}upQPLVN-T0o+>&P}r7k*yLli3!3`pIPYK63@!S)6W^hDa{QF`dDK??$B!aqpt z4ny3w+ZVt@ui={xav+wmYz|Rt?8L8PKMw8I=ZGG}NLQ4@q6f>>ub>l%A(!zJsrM(3 z^-rPrLwrX9$E!yVb0ly)Mi`wYdNkkV&;NW`0&ZUN#;;>6u|~-SB$3whKR0nlb>|Vgg<8RB%<0cI~(V-m2~sI{^4Ocs@`Abw@u5^6%(jhFAfB z7iPx>%GPx+x8$%!&pY@Gxz6iZFvo{F`3>LtWD5Qn`GRLJ5KwvF*LVtdAk^X3<*dPj^%~_ z_Ht8h{P$k28R87`3`(#Y@q)w~D2c@av5tJb08I`e1(tv;bTOHo;UB4cMyfplKG5CT zc8_YY@6apdOTU2be8MLs54}Y@K2=)vKm(2`U2+ zxwJzPT1XfHTY`R1@(9Z@4e5ILqHjvjVk6NA5{^(p;v0*!cjHRfca(joVF!nlFjbIF zl@h_b;zc~lV<)Xu5E1+s-w_eKdX!9%gvzW|q6AcAfQMY#x*4eRo!U@&^@lK%?2@TnP;6RoH& zK@=$COSBeWeMy@p6nsv9z-O9OG&8ov(Gp83Fo~p^1M3D7lL&K?&nz;js`JNq2$X$b z5+nr32P`#cj@AT0ZbIpV?lvHlj`^Gzg5sR-RSkH0eh7WA@yM#UeS=SaREtbGn8oYFDq z*euz)NSP;n-H$5#IJPQ@zQpvB&k1D6RB{bp|GR)-iN0*lV@3$1)g5W4P{@jOUHM48 z^&KEGAudn`sBDB;AY2YAHBkBIfMmlwDc5nnfbyJUaCs5sNq&ZrJrC5+JcL(gC+lRZ z3Z%9Y%5&jFIr^VPd7hAQ*FZYEt6YKKm);$gG3P>9K2ZJdh1mc;qn#O2IXG9M+;9%F znIn<`5>gDym|+EoOTY8MEXJb0!b5$yfTv3Ov=}bzmqU1!UQ6~Ljgy|YT|r(Xl3=CVkV6>|bN~$^B?Ph3 z193HtHW;HZ2xfqTv#~3U66{49*$GH?wGr7s@&2HM^Loxaav}R@JGJ7_o1kFqzHa3uF!_opfJ2fWv=WFV&^L zhxF3OfMcSa2xgi)X%=V6>_1V*#O>K0e)%XpVkofyZXlcSLi}kW8593?aT{HY+-o-= zq0f@t9cK=|7$R|^*-g7eaNxiL;<|&px3lW)oPWu?Sz8=aT{EomusKG2fx0B?h~xp1 zH{c*5-R|s3-S#J_mK0n;f=V<8FTdEcFj9+05nrQH!@mR5XV&#PfZb@`dhhx-t!r;w z-FVHqarN5Gx8A9pkUXVL{dADj3yFKlgDX8C)&Uc(dtbavl@-7wd%t{?k`PsgsoIe( z>T`)3EoL_L*|dPQ>$#2#|NkcMDlwz_Y+4|-Z}ew@5!G{?vG3Psz~rfO4`^9GacJxm z(0LY1q)u+`S!stchpSGO^FReU1$O$?%bZPK^Xv1yuFqizSOW}k@N7rw4wjwFJf2Ks zINsJ)R7id z^+p`2C(YLJhlf5;LOWd!x$w6|zq4y@^`Q&9Vr=$PSy*C* z2fsbnaqn4}2>rn4KtGLDXaqilX$S;OH=#^p;U$Oom=5eAF2j6`G+X9l73sbVH6Q+) zF)qer1Nt=C{30E)c)j5HmfWQYlEO@^254btUKVW3PzEm^B?DMGU`W~+(mk8X$T04^ z7+1VU$xl&&-WBl{C8?P~Y|>)`31o78;PqH#k*Y~vvJuiAMifT;IXp(W`1qqDsT=YE zv}d%B6RV4s5%Mr?;wNdvWKB@YCtwpOfiN0SD=MsPjHgAAPf|dVx{`X1pHZR>&i%Oi zbjU1w8l&d$Woq~xOav4Gv2L*HTP#xxj zT!>J$ojmb^TQ35(7lw=B;(mc0!PVihptwI1EBr=#g>&?>C%-yTg-B5I1G z>x0WT5OLr-ml5;_BH}smbjgHH6p?f&ELc`i5GOR8mfQ1suCrFt*VorSML_c->0v30 za{zfk&TVPB&>LOE`ylG=68!Q}h9V_`CZ3!OC2B=(4` z-RnVU6fHEWogSGqBGs?Yjf`~&mvWtYg$v{katT77x_4cKG(j+xE#ltO5`@N(hdDCH z!yIMU!_4NANNd>mZsOaFD$=H7k$h}Xgcci9qGPcsY0N-~?WvdJL8OJquV(}qYR8;N z+gMC}j$4gPFZEK=8yPDm#jn!X()@ZkhD}rH5Rw*K(1d>Q(v8-1?3d`EG$h^mnQ_IA*)Odzud=h_| zl3$|a1tg7e0g}8aBt$aFypKu3CIIz~%%#YIbRtKNRCuk(WD3HP$G3oQ5#O1>2oc@N zc9}CEDp@5kU@T(8o!B_j!J=h}}HJxf$e6gFv1p2Ed|a zcJO;CiTrxsYa!ys9$5;Q&IlQkkcQYG5I3>6as8xCfbKy^+Y2b-h#E0vi|vmv0zd|A zP`o$9C5Nwe91ipK5gZJX*Y39h4}rL=5MZ~qAn$d8^JEjwY$?iSWO$44Pvgmpa$7R+Rl1Ja7 z+0YB(XV6<@P(@vmTsSW2$rbfQ($9Hrav)e-dWiaf`63SPq#4!clLDxi@l@D+kQ9SX z`Y<0P>#ENpHVP59@!IR9^bt)a$>9^$Cc&{avI4NNhJX^p+zxU3A#44X|+mPH7F(=`EcpywY*m4Fw!Mx*7VNc*`uQ3=MM`Vo8SW$YHT5IZM<@r zu&E?D1d4ohoued^t_-Cl^UvaeEnbNDa82~Y3#)s_Ywyh4rF=c4CK{#A1N!*=*y{3VglG&26SBfKq z@`AQCGOogjAi9*z#pjAZkw6F%?PZ`yaSX)|sPs1|`AsAMPJ}I|=U1=r;}{yt=*-vA zG%~1o5@1h(bJIpr6gPl8VdL`l#6KWtB}1o@Us9IC)@ zx%V@}qEuCe<*f*bk0zd|shbJVrpAYi~xvGL_ zU@`@9UbY>zLzK)s`$6(hDVC@tOz}>r0{Z-}H zus-%;&`jr?nO!I?eieR|BsVW};#s^wuGZ48;-V(M3G_dmdi=(peDGgp9+O5O;jGh-|S-8Y&VH1+}XE z9{^aXz;Y9$rH!bLv4FPNQ9NeLd3V|!RWfEl{eY|^LvKtey{{52L&#DXvI{W zC%lDQXCQNg809Y^kw2h*iD^-KsC8@P{d8Xy9BXpvkUhN zwKxwM2j;(i(CY{0Io9KFBXO^R+fXttY|N%by-uMPP%yXHFo!uoqVm8aY%>i*q+OM} zf?XGT(vhC-E_FHelc>x`9PZS5f;!p$Cm#&+TQ8vfU z=P_o$+v8W+5|{Pugy~_3M29{!z!pU(td2Zf=I!-v#~oMK{|+eV;Mw@_lWSq|mF-Dr zNUv9|fm^B$({JXWbbqU}{wW2r-SFu4F>ad$-Chr?zihr~-#={F<#PqJP`oq1#oS={ zW6nHy^6T%udf8len-Y+T`HsxUYo_1)+91cb7aU}+5rAAa2hWa746%tN;rEdA0Y@XB zSv=NFOt_G;mC^gJ#<#`C+_i^~DSm(Ol?^FeTuv|0!mbe;F=AdiCUnGa)vDqYc1RGz zs%IF%#21mr-8zo|I? zNc;-D@){)+BPAn0NacW1?>=y-Pz}*~ELU1Y~=)hKd?baSHn~ZwF zeXb@6(G$(^)ZpwG*pIF*kH@2|4bHs5S8~X%KCSg2i@QBAH%IE)`rv%RFE&y>!SEuX zWs+yq5u=s{!XnO;XlJby6%kK@0AC8DkQ!f(uI$mx)WXfZUVKfM$h(Y=V9L*QeV@vh zax*n2%3Yh6WcgoG6%*qUBK|l?g&b7kA5r-fiYxvxRs0i5*bekNl>4WY{4+}aIg&ao z-u!>9kkv!F8z$-8v8guwi$FgAE9gO~ltA>dFxH+($zJ|2QDb5_B6Hcov>->F8~jX)Lf2#r70W>K|8R*9GX8HKh+@P* zmNJ6#TNL6w1wUSwcllES06Pv#sKEOX60HLaeh-tv^qE2IO&&!|)F4(*wno>qY%xc}7=pyK8#lz?r6*PXoTa|M47Et7lEH&rDmj73%a2xYLq-f z3Ees2D03;8{409AMM(<@oEz|$OQI(#pS}9POg`*OZL5 zk8{VVjG~?ias2y8z;KJu-+rV4+BCKJqtcfbjB=q|II&cot&~q4H(;8`tEJCMPm_B= z`%&o|xQYF!bkR`HJ}RxzTNJsy@WawoGGx%7@k3bSNfTC;&yq5<`(bHQ`4iNlMgB*n zi#&_kG4bzcM&#FsvcVpSwt@~*5R{IV8)Z03PErKeNxk0Q8t&hc<4=Y~&BSH6r0cUJ z7?lyCJoX$`h?biWDW;(umM-ckN!%#hR{~ zo|@?!PgRqgPIV$IWqa*_h#Vkx;*}ROo4g2y4O>ZoBtYU^oFG>O-iHJz_+gO%He~oA zKsG?It2Oie|Ea3(=@~McJf(+y`qVk6&h@|l|DioST~zS-^S}B2)^Fcdl>bH#;~xVL z7xDFf6^WzN6-RZnmReWwtGBeaUf0`3-Dqd(nRd3GRb_pyo=3gWDzuCBB9~=aQ|(f{ z#QAJ%x?Qf9IiEv*rar^@Jo1%#1^GgK*3IwDIYnn`U#ri*p*SUH`UAz8zOU7*j<&AU z%648=l;ZU*&o|p$r{8kTp6K539M?C4E!Vs(cKyI^z2SCTVFz7dHd?mto12uARi+<^ zek14$dee2hM&NZj=Gn{Z*J>-p;#+~)_tBSau6H-OEz@^z_uWpzJ-1=|u44+j<8<37 zwMApg47z4h4&<5*)Od~^AScHursK%rZL{g!aXWZ7FuR+We9LWM`WUL?n10W0xMovy z`@LmMw$r5FQPt&lH5nClXnUq81*TD06k4PB|Mef0zXYV zWjsxzkDPpi`=NF-P6ajPxZSKXho>1D$(jF1bE?iUJXhS=?KwFkPMmIU1D5FV8e?x%iRcEIB99vdS%~r(@13tl$afG)fm}Ec_YnvC*6H91V>09BsqN`bl^3&g4QlxZtS?uGFI#g!a6Ey z*?r&J0R1{dm;4p(doIF@U;eobKqn%DI2Am&^pJl0vh)vF$bj-#oi1zzD`^WDP^vfqB;h zv=AH}=#)B+KQtlR-)#e$01O1Lfjsb}!PtSWGO&9cz)j~MT=IFFoql`673NuYrMa?f zo@YLH?K~J+(ii1j1?JXtxeMQJ_gYlUpfBGy zy$+f;!A4^Eq^@b#l~4Z!$>99@h=-Z&K9F$(>+ImQz-?lV8*Xsd1=^iwhIiqG=T|<} z@SZ3k0lUcxx98eHgE~w;lqXiRj<5d=l8_f6dbwZlqcBaJm`_J#-s)N`82)v3ldHK5+TQC5JdFSTzX?H?Fw;PSEY$=WRAx8+AUN#ld4Vy%Dm+UxIKe*NvYt@kdiz43Z<((hw82z9JYVK*!QSI5T= z(&9;0^WD}aA1$Ain4?omVN;x-^UEio@;N;D&mvK#7u39#Rq?0KsdIWsol}c?MK7vF zHLn|Z%04t64i+B?bP~rP_jV597~vsv2ZA>e^gjSjZ7Wxl`zr8BAG~=vfdNTQW2ytz zNH8F|w`)TnPJoGM#ZhGLz{oTz@g~M@T7MIx-YgH*KnYdiCBLTKHrD0V)HLxFTHsV@ zz?v3mE6+m%>9!^}uJ|$?RHNJOz49Jj5KBk~Upc%fkAUiN%uM>JseD);tKvoK@!&}b zB1_ls^=UbZHmLo?W33(74We_j8+;Bl5O#q%>J{bYvp2M%I@AI!)V9H+9d*YL#lYCh zgvomL`WFGq-^eO;t z+a0kG*0mssL7VUEgKbF02L`8uEY#FZSm8T6p%zS6F6Ga`7k2Yi}DH3O`gm-ym%d7 zKZ``k%&BE{aOUvoO!Vv@I2C%8zI-aLAmgY*l52n)H&=!@k3E&7nV~*3a5^{}lvjtD zKo3=xkur(sTpMZGldMVm53ML^^cMNXIbD8KPNKf(QpGDfRFw zfm)O!;!5o<(o;Te=SQj)>3Epj*&15i#MdX@t!M^FPzirkb?~{v`+EcdKd{9)bddu7 zMLf_J;|tZ1HoB?|^)zQdIMu-Lb$pHE$}mIE5K@h%G0cXUFuO@M0?n=h8B-X$lCUDV z>nS6Gd>HJK#z2%x@*gSlk4cBuC#pzC@Ie@Syf=~>&zUcnX}6Lbjtfp_?N)EgUM2|$ zY13}E?OHm!N4*O!-3+!MXVQzyX5-)+Uoh7uM~orEhDJ2TZ_Kmzmd)K|bHH*P69TMw zZ3IXE7yTx9k4e1_%Y}#q%SZQ|F7iE>_|Al!ort5@Qjm;P3S|jov^eZI#0{1iZ7GdQ z3KI*h=@2B1ZV1=jak_UqbgH8R8IiJB*c4{~X0;r{e!YT*7BlO3%j+{~iIjY0lkr`x zPg9e0Kk745Y)XfWGAzDK;tlcfO?-WlN|myq66O@u3UFs|?l9&&Mus|sK^b(&MCuoi zAq5<8<3!kHsyCTk`Cr#6@RtfMO*=0VzCEv5+=3Sl@two=E#VOb@*k=t1C2 z_y_Bp!w|RS_60D}Yxt&v9EfEsmqW}NJMl}{j|02)7_oyG>7sH_>|mMtb#wv|WEDS= zdT;Vr{}PHnz*hsye2vCgq2xIvk=FA-<@;@x zj>MBFiorT6NsyFh!E{}`L9e}uq;7!T`}Lfx^e5NzMKlTVm0+5Y*f^FqM(=%>B8Wn_hz~Os2O9TNCusCYg5Qj??~Z49)FZ-VTNzLjY2P zV_D(9zuJ@=|NU2MhB%EpgA%Mpydd#5N@9^fyo-Fj01Xb}1eSm-bTFBn;vcDVMyfpl zKG50Pc8_YY@6adZOTU8de8MFq4}C>Daipkl#Cw39>>U_fpbQ@Z9v!_A@X#;&DSlbm zB>#l=K|X{ab{{gtw*XF#zN3i?q2`dON^>BRA%!t|Q(srOK8O0Baq_ax$*1LoJBs)@ zGB0tfDYV*jO7YV)o_;CqQ}&FM-erSSdGKL+HWLC!umu~+0n18C2W|6(D+$EPRWF;u zC8!KQ;L;9BU?D*SED3r&$s#PnG^FF<^S&uTi%mo$NH|0ViD%5y-jPa^C|}=F_Mm?4 zA5gwjK{{GW0Pl#G@hFd-v{gX_@FU8wwsq++i6F_8S*%0{sKo#cxxk}HUsSjb%NPt$ zUPGF4h+ncjCn~tg))g^N)hQ{Wz;{+W0+);I)d6!nV4fYhWNlTh9k+7TE)9<$`%KntIm zK`GIS>HyT8l(Y?p=4k3MfrfKK? z@;FU^O9P$7O8TLB^81n*O3Wjc<=*SHh-;8 z2T&$b$`I>5kX#`Z$nb_SqJyXgcsd(9(l){Fq>;IR#8?|44n%+8zamM}hgX9Xd=yAS zE>sLi8RNv}2&f)sT69Co-J0CI9uL0Io0^nncU zB-_OBZ^B)HkPNwo1BNjoZK=(DM7(IUNsKw7`^%sbuJ2)`OdUKp+OQYc@{#pzZ2Cu< zGOi1FoFShhS_82g9uXnZRY`$wz1arH={jHE4}Fqwt2IkJe8 zYPzTG!}C9`p6b%&L+WW{z&FuOWHim4w2ZTm_FvIZ3Tdl{X%Rz+g>{2u#*6TelFnEp zTN5_*Mal*9g{u(Mr^)P&^M+dt;W*LshTS3A&O)IO#~YJG)ZF{VA#?Jy(#-63xKXFLo`A)Z$UZSE$tRZ^QbTb-fPY zHk!4*cIE5V<+m@bziC~)bou(b@70zhKj~0ET_p8F;$8CKPY;M~z)I`h6R%Qb1@Oqu zFCV2OPt{?qc4UkCOyWh08BKjUEnp>krsEN(EX_v%yN z_|&-vl&miu7&`@Wp2Y%bl^eTO+F{Jys*~kBP=QW?eSY;aXOq|b`fRW3b1VYZ09zc~ z+mX7BWhXO_CleWtx49V=Qi4%0V1L;u=S}j=W7bjN^helE{~{7)whFUx1=4R`gXtLF ztdf30J)tgXMd*fkRnx#nb?{MrPOHG@W#|>Hia(gj(5|Qrp8d(-x0YePE_N`aXd#KF z;Ce+Y2Kgg>$v#O9NY;y&MAnyJRCA1`I?Q!m0W+og_ocY&zfVdi>Y<== z{b?{*s{an^5h3M&Cj@&%z3$|={`Yt!5^Fz`Bl-V0?(vV~5&t+DG3^eImdJ3T3#K>&`z!|87=7(YDDq7Sh3u(Fh@I9o3j64 zoYAds!HfqUMdZ8@2kS{YdW(D^R7@FSd;N|)#*vkmg5}PempV=GzixZP1QNR;NHYb} znO(%jB>{BAPkST=ldG8fkGUe~KO0lx_?5!?B1!UhCinaM1rjcU0cscAXoqtxMn546!vm!mHHEz$4nn45hl z#jY5e{gKQovBQJYp6$4IElh-d;B=s$#xj%xA7V5_fu@_#rLhQ;0~|~T?vQfH9E?<4 z=3o`+zYH}W4xBOm#Y6*&^(q=hI;8M=!SgM-OA`cz*;ftFBH7o=htVI42J}cXL%Lf> zGBAwzHH;{}PRY+xf_@c2URQKrq(_UXS$_sh{K}KVgq^N@~RW zcsa_&kPsC~=a3H|LZdC6SZH({p$|u`gvr+i9NSksToI`-UA zOVn}c!8Uk;6(DZQUcHI7TpMd0@I~yYAtdk{`Y<=l4+|mmi?Gnthefhj;0s+vmh8ih zfmTz)(s0_*!YN0ehcq9S;@@fc3sdqG%*Z8tr`h^Y4$E+|tKtU%d?NQ%NbuQpfbwt# z-%6O>%Tf>;^oQ9X7b0+NJ5NmE=F0%{x#4^`zgJ+VaCLYrDDF*#bI=tQ!eikAEGEap znY|M92lN?MkgHM-xLiB7htlJ<<^1>4$0DK4@_qkt|6%JB=3T zrcWJ_5i&MB!!F5D04PXl<}w>Tl6J`i4pBnxcFS!cMkaOWfOu~}IyP@5ddw~E>sAtL z<&JCIw?QzJ^xXvm<06|YYC!Q{u)QLEfB83IZDhla@n$vQ7Ue9$_YWmvR z+NTI_ekiRTWpNfDFUX-L?H78Zi?|=ez+HfkKFUzkL=fd*pQTVG@l9%1q2x76IBa2- ztaj39lNcqkcCQENQMA#h_Qc4p5vhKCW@Nxa*p%zkD_kIVkV_Ey)V<>(vn8-sUL7-e$I$L|VhfcLU#MRFRe*iv(qhBDL6}5*>>TOk+kvY)`!$4rH5St&Z91d>Q(v8-1iLYB3oBJ>_PPAwDL zSK}Gpq8XAUSNwHKUP96smn6x%LeeCY%zKz53N$cTy@NGEdSRD~Cd45=V2d3+1_ z7V(`5j1aM|Y@9g_vI6o2;!tv6{6GqS7v-4fmc!{VTT#L?Oe!;B9&XziCN2f!^suxD zGX%xq?&<-8ZLYpJ{d$L{cEh`c$foeJPl(>B|L#SE{4Z3v&FrVlYx;L zMrBpG{aIMt%b>?;B8)s=iYA6}YPf{bGDz%F73<)4Fe1p|6#6*H@4m?I6|tMAI5&md z6CjXJ5CdQ-GduXVDT#co@3jzpV~@-QOlO3QPe>JhgykkiH?E&F4A89zX@CJm98n{t zY_Sm%mH>!@4T}DTn zz$S~X=Wu*xeJ*+QEt(CzAVTyO8B|f1Bo~fMdU8d5o)mSSn;ZycmmVTOV8@7qJZVPt z*`xrvW;_+PBP7M3lRoSS$-3&(h?7F(ZM^n+DSbqfNpkq4wLx%fjm!aTwIQGcG53RA zqKPaj7Rm5jgsRO{P@98tU4qgLwOXx`dJT#R$39&Ac`dJ1Ih=F>DmJ|{c>b`c`1oO= zY!iIIOpPt%t&LX>J2sUhhd>dqu5*}V(gmWFWd0RAu-OaIUyk-bLl6MGOhA5+PZ+}> zMW&@0l(O25uscMpuqSaU}ku{JcDn60qf?0eIKuS-QV5*Z7-#h+L@!R*EM!mLdpACE^ka z!&;V$wYou9YP2PQh@~ zz)Z<@$=EB!5kh%E+Zq{{;7JfY%I4y8Mc_yv1c~-CaHKef;@_pxzeCC2MFQYN_;PxF z^$I_Zp|OnAd<#t@gNi2s_7pfbZ6rl=1IQCLF7HzOBZ5}4b}IP=bpfzC2Q!eNYWSmY zXTa_ck8+4UrpG7;^%wWrI#>Ucx>c0_^(tVKRuUP6AJKdqrNJ@cmUPEOGnjeXMZ`jE zyFLt{&!C4?4Z5a54T1s$>kJ$N4!HF>%A5erG9gYmxh_+!Ffm1`>%j}bDh4?VS zsz!>7lkF?wZBj47+>OkzC{>kVIV_^h%&;lO!+Oj@lt=W)^msJnS$FR2R3=K zWgzz*^g45BrpHNv;x)>hd;vsX%`aiJxyjK`Ha{v0pVr{`q1`&T&**txmW2zvHlEq= zsoUQp7^qe4{{pZ|#g*$IEN#SijD@zv&f_to4lc5gM04=IDp|Fle4tzg;+Rr)UnWY1 z0H%PL79vN`ifKAYIq9fu4LeF*MC3}h--Lq)-VM+h1cBcH;bUtBX@HgGdaMv09Xx2V@6Xa&^EO}5ZsPLQWO zK7ceJk#<$?3U*!WO6PjIyVT{8pG09kLdsZ1@4pe>9v}1C9z3e}{lS;krFgNLUZ{m-BQ|5i+;mLr zh~KJJ#VPEPAeL3nFp`PSBagdw90wBRh>t^P5PzGRkrE*8QL;=)K*_IC!Xm;x<^C!q zeM)|h61H(1mtIZ_mRkv=7rEa30YOd$RD<%Q5W7 zvOKP@vAPe_3X~)9>-5T-l>9A9H~@r{?x+9*1Iyrkl&35tDjCNy^7n`j>6Krj%)5U?644F3sVtDuwgN{P8F&qi!o-TOXDmQ+KLrDD z1tk=2lGRS5EQ{-5WjM(o%|9#|B|K4JSr*n1v@xn`1vzSqcaS^cs#r2U4BS&hpcB=K zj!Q~IK(+KUiyS#&FXQX~LnHy@A(DXZYkoC=*v_|!5mqwP1DJGiA6Vbf#l;{CL#z(@ zATCAD_$?S&ar0Q&gMWynBV(9>^aRNX($Nhk13qd%I?76OFM`9eVGfeQ6yC@q@S_0Z zXr8$#>?Ff7>>w0CG7afL6W@hogS*Hu;ZiRMC80nO$S^SAJ}_X+t&kDb*2lSSf|YD1 z07*IOjy8xytlh^p)P%p1P7mI4fC7!URO*eA_hNCW7gu^@G#Eu?z~>y(LGURe=fG2Z z?ba?Xos4?Hy{;w+(k5vGtBG&`4S6>{G3&ypamsj(X2$UwN5n<2S<^gAw-r|=uV?zp z6L*JGPB0qy|39|HX5YtgsKJ>rupeDw9uG)a8JvEbujY_V{ixQ1EbjQg)EucRYlCwM zr&v!p1jCDjmPu|=N03??7>l@5qMEf*R75}tf_*8LLTY2c=R1xy!;(yOa{Vd_vBc01|1-7LM}% zCu5I03J?v$>ty3B21!4ZCsM|wi~M>Z7$9EzQ0Mycn%-Yd(dM!YoMNd*{tqtk;l=;S z15t?>uu`&rZj(a1kHCqyJxBGT0uE|-J(`{ImI8PI0OGEx0NT&2Z+eA9~M>70>4n&#FR@V;1m#G=4(c&B> zHAxeY#Iq6i+>LZ*lQ6Q+mAKCnWh$hT>9diQ7)7VOAF=cO8L|Y1Ez?)TKcT? z47n4uAD4axH?SX;&Kv5Phou#Ii=wyZepI?dh70;Legx|~>AtG+SyF~}KPqh~e}P)G z$p5%>o@Y@zCjJx6i2N8)HrORWR?uMzg3{4)qYQ_}Nh$!lsMp(F!QESO^l>~?eJU=) z1zn%zzNm~q<+1<3Kn9GF;KrvV20S9{O>pjuFHte=%2C^e|KdOb!7GqX zUqXXDdnTCh8N2b}`cDR%Oy6-O_Hza*4wR5R(PZ+Q$VBEl<1E$kiU|%{L>__K-{RIcmz*T0D+DVAU4qjEtC(xeR82KBgJ!O_tTnr} zot%ZaB$^G%GzYYi8?@8KU;@nr6KNipM3;ccR12oid@z+fU>end>Es3b5`q252ll50 z-~dweovcS~F!{kDv=AIBvXnuKU=J7B%M!bkO%3qmh`i-d0Cqk#f+Hyij-n=TjNn*t z3dYf;@JwJ^-$_0jg`|Q`3V}t`3>NcV(-h7ul~4;D_Ab#sGv63 zl@tbD!arU3tC-ly%V{)O^qu4kny4KuC0YW?bOmUmKY@0-5=@}0z(l$lOroV=GF=0v z&@wQU{tTwkwO~5^1?)@Lf&J)uus7> z=>c#GJqVW2LtrVb1*g)(;52#!ETc!ka#{yg&|_dFJr26)32-_+30Bcl`bC`$O%}0Z@)$Ae84n7s|5_g7Vyhp*-_>P@Z)N zlxI92$}(BaS=XcjaV znhni^=0NkIc~E{E=R-$AM?yzIM?ptJ$3VwG$3n+K$3e$ICqT#R`7D4v0lq@$M5q&5 z04;(#pvBNa=oIKAXbIE_Erm{oPK6dhr$LLMWzZ?ma%c&(0$K{KgieLJpwpn!p=Ho2 zXt^HC)37VxTMw;-J_~h0pNCF|z6h;?Zh%%pH$rQmuR>=)Ux&_wZie#ve+zUrbSsp< z1Kx(Zq3=TZd*FTOTXr_lM(e?UFZ&!BbC&!JxE7f^zJ3H3q0 zhAx2afGW`Mp!Lx2p?>I((1p+_bP@CyXajU7Gywe%v=O=+8q{Oi3%d!veb7sx2cRnS zAT$I$1Z{>Mf$}%NG3a7(OZax+8vwMRv9&gnH zJk~8S{&Qh-k6T5L+hU#!7CmkkJ?@C{pD+A(3jbX({vpDDxA3oy@oVkx5&nB){6j_i z`-K1gll*H=@(&a39}xZrW7=m3|3kvRHpZVR{0|HNBk(_}??jgHuM_^qV*I0o|8e1e zBE~;X_@5O1rwqUFJuQ6CXg=XtFFb#X@jNR$&&7D27oHblJTD5*OEI1e!t*jb8x@Nf z`4!>edu@1L6CS?dhG&!TY=-9zz1J4uc{9edRe0Wl=WV^sJHo?n0Q3UiTb+pShd43o z`wy6fegKt@^L%)c=c7)~$9k_W&nLQP+evMVE{4Nj~NB)TgHG2>($KF;7xn>d#M(!glK2 z55cZbKKy_`+@?~g5{=j{R_?l!Kl?CHg->~?X+pwh5B|XwkCHKW#T)yENe;JEt|&Y~ z>QjgCq_9D`lv_3N<8A-Isf4Pr^gPV?xKRHZ~j#80NF`krufRdp4JSB;qG zU{Dn^)gczlSmVVrOjK1Lkz9na6+}W2sk9Nzxzy?FH97Q6JyxL4VRmF#B2tYnVlItH zvwezK$!DJy=-aTZK=1i`fu7~x3iR#zxj@h9PsXDk3iKoLeSyB^8c?&`!=SV+t7ixye2lS{kqt?W`Ow(D(76F%Asv4E#}`F9%a- zXLnzLo;Y~4q`}Z4++&PmcIamqd!)^Un6-2tFk*(Goq4Q6yD#u)1H#Y_9Ugr*vhSvg z2O{RLzv;1Ros{fQzC6!%)va63b||k{Q&J0TgB;2e{oFL}z99~!eo#Ye*g6cezqx)y zqb1X!JU(&O#p{n{JCqgG&F9^IEZ3pT>GR??(};YB^3&%}bj&?6Qcq{aXooTc13M3m zbts>IbfEgX-s2rg(Z6>6KK0;4hZd^c;ZP9^hh zLp+L~>`<;B;Qh@VE_Nv2e=_KcgTGC2C^z2Tn6k9B)S>)yYasv1V=(q?NB`Do4lRSf z!YFYaP7hsH?of6ud$(ruY#1Vpt(AJ3X1g59(oOHJDQcWM?w!RK zcu*&$l^7ovIH29f_t{(2 zqDaH9P5b;PSEi75*dO&`*qS=!p|4*{TK%OLPbRl%Ph5mbNl-BbuadXk!<8acDY)p8 zmrH-*N=-mne#Cd+-{<&HnNz2TH_sX$Mk{Mk0zV#FxBDx8QnEmct8xMN@hJV;J-IXP zU4S?6ENxdi_(7FV!Fpf2uUJ8)GNfItx4BZPYJ=J8QE@dZ+inUayqAzT}+I<^aZXd6c!UZtipbjs`4fhJtJN`mdV;7`vh= zx>GcFwM2i15s?Z5(Y?as3`7qIW6~nzeL?h)s1${whlNob@SE|CiWfBbC1F$q{T3L0 z8^g*-VAvVG8Hqqf95x=W_NoQQAYSj)c$H78$Hs&(Qgak7Oi_tpVaNRCK4U0j7$bvm0b@8Llaa+3!N_J@$jD*jGV&PtjEfi}8KW4Z z8Dkh@8RHn^850;283hanqmVI);bcr^6fue!Qy3+TQpQxqG)5VtoKeB3WVjg98C8sG zMh#;IVV+B!_ByuF_$rqaS5ZAF`waKcp1d-F%~fRE=TJbe#Sz^B1Quv0C@Rq z7#}_ud@gvEA&yxwEit#_KRxqOmq_hWvuUv@)N3mK@O!_xU0P#aYQEasZoVJ>cIhnB z8XVk9W&VTkHKt}uv)qoJOQoe|{x8!KOS`4olp$Ay`&_C&Z^qF}0iK+2%@1%pAu`=~AiMw9cg7fk*s9WWZ(MD2Ta(Y-)7zq|*v8tRN;GGhq|^x zy&J(sdUqQtpTQ7cc-d?>r5?p|55m!HX2c2M+wSMg#2J}sH?I=8>6$acqTY#S z5z8W9Tg1B9)6nAUdZ-1IFdkc>Z7H8?v88Dy^;-thUUoaM|A)<7FTXgB%wUm7}OeP$m%0JQXDrzgN~l9qN&y) z96lo}ww@Kn*KBS_f*2no4(GYO`WWidJWx+|w`MXCT1T(WDk2g+FoH9~E@oAqKo9j9 z?v{jS2tn(X6;{`=(uH0;6>0z4J?(*d^AOFNTLp-M|~s)7WfnTlJoK zXxx^`?CgmOdTK6C(Q(-@u0y6i3OD9+6*CI=6fR?w4&&0mA-btgazMSbzTCVscC9Zz zFkD@IwKL(Ci-~fLqx&jn1uNry){_wKN;1aPkEr3)8o{|yY>Ad0t#jBdyS&~;ePnT` z#LC3I)s8qHxz4-Hz4Y_XXG7(a*}*NeUPd?=wXSnibu+iz1cdeXf(ZJ!aX2#d6#zXa zISci8!e?sl9X%QP%@UXUb?zW;c*=1vCoMzb-f3GGPSW39`o+WqjRB8-R<%QtfUw1G z@@d8H={8=I`MHsV8eg;OQG$)02L1eLos#smiln1sPi;{~RlhW%Q9fS6TDr8HqZ%is zVMELfx)F_8$%cqxR#*3$f( zY>#og_{JUIN}KA0vGmJhjKY&}09}VcyJB$lz;ycg2#0OF*Z2;yx$^Vy-obaD`0BCq z8d1U>@yMrW&tEHfk@yl{tG@j~5tQ$4gvbFxRG zL82baMO?!H8kv)ksPZpX)yv@DnkxU+RQ-%r#!|*|Ai9FFig62rPd9$KnLQD6UBv8# zY4tQT)q5gxov+ywu_~Sb4zfjEhIZ;F47Lo1@am2a_@Tf@T8l0RTj=sJx9CLd7*P%~ zBTjvkTRhHqnX!fOHsdY8ZFRQr*G7D#m+|{+F4}k_>;q z6AJl4>ODN#gFwXUZ}tU3?vxXL1BRL@(j(M_pQg2n8f$b?6ANjGJt{_LYZi}tUj}vJ-T|o|QCl%ZAKIF-Ex^jlJ z+jd;BolwVf?zvqPs6^YVPDd4YtyjyG^{NR@uwj+dx{;Sq(~`d V?Vlb0@r|t~>_{18JY!Al_%8$sBYgk> diff --git a/basic_function/__pycache__/chemical_knowledge.cpython-311.pyc b/basic_function/__pycache__/chemical_knowledge.cpython-311.pyc deleted file mode 100644 index 963bc37b65997d89f3d17d8ec9db98bf02ef9527..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 14370 zcmaJ{34D}Swg0}!GLt}9LJ~p#B9#g!beSjzxbp>lqsuH!?O7K4OX{ z(JVHJ%@0RCVs4X}+=5s((E^MXTY<4+8_*%P1LH+2aJ<+7oFH}rCy6#-qSysY5^i9M zXa}Z>4q%!Pz;xjO&J??W8NyPT?1Y*lyuf*44={^YDO>CXU&O1I%UQ}3UGU`dx)q9U z@FLLzEEc`MWugzbg0X~m!Ah|Yp4G%vCTVY!2_MiU`hk^V09Zx27S*(}qDJfouI2r( zP8X$1{&1T}5qC@@;w4U82Z z0y@MoV7&M+aJ={kaDuo8I7!?KOceJ4lf=IPQ^ZGssp5WMnm7(j7as%86dwm>hzC?A zKLIsI{5x=-co3K+P5`sTN#G*!Nnoxx1QwSBi&$tHtLcnfyF-nfLVVTz((;673rJe0`x`kE#SAs{NVH@tuZm+QW9- z2{zivP;?{OVAElvVd-v+g`ElOfX#rV5oE$r{8_LReKst`o&!sf=fYCdd9W05J}gDd zf=z>60Gke*4LcKdA#4WhBDvSWa^RT*n+rP+b}?)gY#wYj>=M{Tu=%jLum!Mru!XSs zutl(iuyh+2!4|_7!!Cne2D=<~1?&pg64(;hm9Q&eSHrGS>sboE8on~vHLxz&QrJpZ zCu|jL8EiFdIcyEA3wAAR1?)Q5O4#+VRj?ait6^(lYhdeO*TUAru7hoWT@Skvb^~l9 zY^@s0H^J-Ry9iqk`)$|;*zdq@g#8|DBkT}t6YTe4n_-`X-30q1*v+uRuyp@F1=|Aq zG%P&>ehRw{_6jUL2cCs(h5ZHW4%lD8?o?w5g15o~oen$!fKm7!Vip62X8_QSub){Q*j>!x z>9&KS-G77`?f;|9=mgx&jLyJ^Xmqg0@cl5pAHnw?eDB5gK79WTUwQf{_DAnEw$X{Z z|6aorbv8!71dX+iLOd36R2X;aZLwZH!RX_-MR7ld@5iku9uJ+ho#=nyNkFo!2+RwhPv+oU$&z+I~OJK%NJFMduRs#9+@awXzb)@RSq=S_4}=dr%rw>tj0FQLFz=eDw2 zAN=u~3vAAit=6(AK+-qRm94DH+DEneD0q`sb!<38W`pAIqT{BJ&DCYe(++Jb^I1H| zBpTUzRS=85U2Zje3e?nLQJCtN(4DQ>YrT#+-uD^;nVZ$vs4th#;sLBiU$)vFi&Jsm zEr=?QMI&1sPr}uUX?RuGQF81L*|znlwv@+eH@+SWu3x2rr=&AYy_fr}8XggqtoNy$ z(C8Z5>o8tl$lQufeTdYKu^?h|2g(qru`^_A5a`=NVA!=is7~K@HDAN2l6tFD#W0=O zQRu)EGS`M|Egp+!^4zaWRT_e&s^#BH)hfSHs?zm8rD{$8tc(6ssySTY z0iyrp;^LmD9H;fznl0Ne-pX@Y4>k?Vedtz!)7m=qdqHDyk<$-)-H|;A}^t4r)5=md&4_6Aq1U?S3y46`Y+e$ip@@I|EbjI&A$X0S2b_H ztz?TD&%0ZkBkkA$avdtRI<05VRety85C~sdpWWuPo;m&f&zyJ*pO*-#LHd483rZ#RX{fN6@)8`?~ zEi$X4RM6vDgogXlPeUK%(4h-ORY!+m^A1N_t0voTajXk*iD^dq)1iq{=t)g7mmY z{wV8ojLRO9*a;j2aRl@`;m9}w z{2FMCBoIs>pq~jxCJ{^~NFoz1TzUT2r>y~5zHpQ zPm{*TT!MK7^9iyD77%0;EF@S&kVBA5u$Ul^U+c!5V^60w+NkK{XBb@|Zf;|L#3AzZn z0UfkA^oMo=?U0*q03I?d4d`012aNt1>+mD-*|Du7`|0F7AgIRdBi+RoACpmu@R2gX!FR=oxPdH18artUSohpt3Cs?bJgbv#{4Mye=HK? z&i3#B>)g?(BT-F8c2u2z>Rm@OJzbt|PtQQ6w?DI|cOcVq$lE_~M~>+5(5}BpJDhTK z#MohPMXaMHW$C2t0Y{T@-tfJKVtPqhyd)G;-&j%E7_wJxsI0B5Ywo8MXJ%%OR&|uL zZtC~=`dfFqdxW>O!qdNZptrBJ@9;opZ%>=6siC#q-S6#a+r7W1W5CGpQG zyV~~l^d9W;2+Om$@9=2+sxBI3Z+G7se|CSbZ=kLHaGQIe7ggJO_II~?e2Z6g^`cq- z8W3HT_EkR}$zw(!rCt6xx9059!K@PpBaWLWH?m_wQ~EqUZ?8ay1Mc=NPaE|TirwAo z>vqGs`}#aRA{0xpJt7q4_x5E+hbFnZy4qAe+XfEzc|r*t`+YuSx2@MFJid@!6^16t zhPrZpKP4F%%JlseB$WJr&dEntCmo&sPDu5Dec!%+K<)0I4&&^o#wlN(R^UX@TsDk zkp6iEJH~N)#~d+F*&c{%k~>Ci8{g}w8j9NMIUI^P;O^S*(R+pyf(6G)$OZS^LQT33 z`sq%^QYP2SKj*fcZ5o_+!gIPQ;^Lu=#vIByoa4vSD?b$1?DKYXd3(GaqmJA|Ik|^( za{d0OhTQzoagc`><>wW^6N+jm$PdNW_V)D3%GuFA1BZLmPL)t$G~p0el@0vNT?KiydJxgnCAId)ad%uJ2D8rV)=?)>6%I}HIKD5D z*hGTK0PeC|=O&(WLj#oK;)(jB8w&FD)cu@eq>pkZqsttzL-j+sj4Z_MiW6v(oO7=1 z=dngIG)0flPnnEl=zW9XLHiQ)$X$BmlvTMHoH(8!FEwQ)6pO6jztH%4A1*0hXK#1A zSI@|ogw94M+!;_W1X@$a{*Ybvkgwy>;_GBB^84ducnZe5Fw!oj!BtkETYX?=i~Mpq z(lLZkOhX-2ox~f4suuB>im+M4s$gc6XSE3a+l!2T)fUxGSqaWmPKjI|+1MY7Zs_RfQ9bjS;+gTem0O&)T;wcM zE4BY=Z}@0uj#3TTN)HFUwnuvP&7qpJbZqp+gTA958kuhz04>X1WokVBDe~S?i=j>y zAMOllz$ZNMj>}1|5T84lI{yT9cd3Jk06pDqwO8dPiN~-xo3vYzJ;iZN78Qizn>_+O+Y@i-6*iUg;_@f$g z3quY(K;>Gi#^cqv^79I5k@eULNcuQa<7CkMykb-suLiS}6gE(Ce)f3ZSMVVBJx*|m zV2I!ef-3~q2>uH&@*u$}f=>~g2jGR%bcf9Lkl6tmI~bcdqbJp=AgjOFgeMvn6R z7WI59C?CYMgT2bJlnNP^fuY822QfI7Q;i6W?PKH+b!zv~UtWFmcUIq2f^333g2jMP zjB7tVHsV5)_50xgaj?zj7G7^CMY}qB@#;fwi_$+q`@7x!{hof`-)J864j78@4tToz zLy7NuN9`XFp>h3v?ha3zqU1g_(h9fh~z7kV^w#r*x%(@kM8mO0Ox*o*!*&+&QjAk#a`Fig)?!BH*`$bAuxPL% zA_|Ky3`E45GV}7>h+1oIFrR6T$Y!&}3~b$@^aL4gB1fBGa5ODu%sM+atVMS2pooa2 z%P%EGgwu4Hmn$N2y;*Mt8k&_hRL5J)?Pj2Lm+~eWF%Har3Ku3Clafv}M?^Z!QI{;p z8N6LmwCqxJq@Uuh;);gTl9j0XlI=fWsBsuA14xc0^tp!h73!bzVOIKW)8`h%C ztTZo+h^#d?nt{g6O3O)0Oft5bJIuh&_DEfi=`{m;dX!F;W0}TdNtI*Kf|pwhPK6dl zv;bN-%?k5!b3|@18_dAQO%dH7yFQApUj0x z@)&5r$3P2SdoB1FXd#b*RywHGI-J(}7)YHWCq0$Jr^rcb!IRd4C#?l9y%uWewbaKz zYB_0%DUoBK>*`}5wcdt27QGFX;N_mK1gAm^J_cH-W1yuz23qT5ptU{*Qm4yY%-|^L zG8bBKF0|lWXu-MALghkAIqUSX$6#y$7YB@#@a)vYCG;pr3C9q_Hb1~D%!grKI1XO$ z#1!}y&zy6%O7mR&x}f>WHJ1)*&YOLUR$I(%W?*}}mNN7Eq~vp!6;7Zy^1{E1JZ|+=D2@qcrq1DGNw#D z9jp1Qd1o!nmlR#ttT`RyEG^`!wrHutsqSYnvlQkHtIcDb3 zd`-D|xmXJsX|tB{7LQczG`q||w@;SxJQ?`OGC!P5F_UO^(US8mpB0`){!|)EvgWhq zp53B(Ug3o*%~$iRRRcX1CFNLl>8kQn_ejYHX+d~8MM^W0r=6a!IcDw9yl~kiyXNI) zm3jG~RvXPtW?*xxmThK-84!D7SL3kGV zGifkcn$KM@n5sEW@l?%SX0`c@Lo0rWwCvQ@fqqs+;+8K3NoMKFudJtTHX&S;Jar}RgIDrC;HzrOwS+4oC8E4uxpObaorg`p?i;0>SuNaE7 zm7Q`&hU{&vS!-TVeb6O*JTW|(N+!`n3N%kmKhvoB?D^+>n&;$Qv}wMS;;6PVjuI+_ zQgnEOE!$*Gx{NWiDHJhL&P;Khnc_S%#dGs7M(QtHIaIFO(qneGT`Ha$L!v?NKEZ}Y zT}^JeRw?OC2hpce^b{Fg@o6+)#d*Gp=TKf1FD@A>(41Z}h9fQ3nRVtBpDv?U%HY=R zQqql&kkcq+stl<(PgrrDu;M&n#d*TaZ%;ARP*YcR*2~~Z>EJBsZ3e+-Q1EmaTyYMr zI0sjpgDcL#6;~mZ(m7UA=a}V888d@o&Xh3~=a`CfOvO1S^V^f0X>fE^Lq%6gM^{ou zXGu>Q%y&om~|j#iv^wBo$Nit`F9&MU0AT4ANs3M-|P%#xls2swvB&XyBaoF}X} zPnh}bNzOKS>g=U^9=&8yFY{$DigPcDb1#Z>Tg7pY6mZ)c^9?m{DuGf1S4t0@CEfA} znN1ZJ$dHOpn{m2b`#GfYBP5;*vi|Z_Lrt2~)A~vjt1wLu-8Ivx-|3EAuNOOREJ6tv zGX#t&H0v{|jU*!le`Uo#H5FDpSGeJ7;f7#g-L=9xc!J6G*Wfo2!{ak2+MbV& zIc`68_+-+_y@A--;V8IXBG<|IfPMOz%5#eXv5Uho@ZP{LDzWe<1ne`<>_68Uh%E@) z$&a_{*noXi(7p;d)I6LMh^-IDQ3-x_p%&u<_8DiyxmAJKCE;=8$CZp0WkGuxaCq+U zkw9#7IDty=(}_Lc*cWu{BlI1@zts87DqGkv9M!h)M5=;&7FCLaj$+`&gF{CG_6_05 zRD%2P|7kphs^DF6OydJXT><<0a59zPIzveb{!I}$RE$~J>%&v21h)=K%7cz_;PA35 zm}X0O8kOL^5G5Z9Iz9x%P|Ix9HkcY)Gmm(aEj*p7;pRZK)j`K<;85rA-2waN@C+)! z`9?`a&`|*#-g*TkJHr`Ng0qK`RYAup;Lw)godJ7ucovo5h@iw3bhv=S)mQ2QcD!Pt zBs|xQ+llyuVa6LaW0!3=4XM+1z=j{n90zSLju~&ct%eHM*;;tKTW#0J6h0kK+=SKDgIRcCAEE_R@cr9sEifMe-J&(MZ|eQnrg#0dR%reE;|sHeSYrwy8@1t*W=a);&aW9<6G$()o{@9r=;@<_+sg+@7<0 X0#oyEB*acE3>$!BY`P4ZHp$p*P1}U7gJh|g<^RrIMw0x~zVCeBxy!lBxpy98 zCSzivG<>SlEIn_|*R)qjs85U}lbJI#?Ott}#x$mP=*+^xfZ;3xXk|8_okaqpSTt}3 zivh;6nZP(U3mDI40~1&xFo`7tQ`j6}Dw_*TW9h(o95YyEhmJV&*#h_$vMk^tmJQ5d zxxhTO7`TMx0}I$vU?E!uEMmpL5GBE?^w%0LHW3zy#(3Cb3Rn3hM%34 z27s$LuI5>=hTRCyI^uFADK{#Z33ReSU=i4 z3z)<12IjG2z$NVSzyfv;u#kNLSj6rHmas1ZOWA$EGIl?31v?I0#l8ew&Atp=!yW*x zV-JQ_@+;64?5jX0I{~a>CxO-MYrq=zbzm*~25=+$Ch!LKE#M~hZD1Wc1*~VMfeq|C zz()2Eu!(&axS4$q*v!5!BmID#fj-NA2>cQIG4Lnsr@)`FbHEYyF!1N>5#XcjG2r9u zJaCl#0{8^`CGbi1E8wr$Q@{)CH^8UaZ-KvKzX$$-`DLUrb`km#dj|MN_9x)8>^b1` zEC9UB{tSG9y$F1Xy$pPX{RQ|c`z!Eo?C-$W*z3S?_7C7c*}s7QX8!^Hmj!_n?23#u z$=-l|lf4CeoBa=Xl}!P!v3G#)viE@R*XT^^(1G+=I12*{=PUxu%9#z!&RJxLMSHdELwhi`4So#k5HSBiS3$XM( z@HA{Y?C)TA!u|nvmmJF&cn5r!V0Xj*5!MC!ENmz2^RQj8mtnhMUx0PPz684m_7zwI z_Ep#(*uTMgU|)mX3p)- z(FenpX@i#8TCGjfhAo;lptr?!Sn4A+En3Ue2E)9%S2L~4_-RJJM+#{SNAY~7FpQ6l z&fO6_cNEe|zjIm%okS}6EHC*S*j=(sMj`cg_h#zr*fjrqaBA^+-r}BVB@20rFYp%k zPV;AT{};LczG?nN+R?tn(A@`RO#z&$#E@G|ve4JPgmzWvfTHhi(Std5nAL zjyNytjB?K}L>=H0OQY~zVdEqF~jEvj=x%u3@cOfd> z#!U&|uX)Ayzw&@yX`Gu!CGUYdW%Fy`4aV<&OCPG!Xf*Rm#H)=j9jDKm=%~SMd}Cl; z)U#hlfu+G^pyr-{rKuZz8wd=$(TnQzbjtZ^ z4!NlBm&+KsBQFdcxC8pefThiC@R>Y&sa!6@Sh;NZM!DSO|CGz+`bW9k)7Mqe-^*o+ z{HrZ-AYz19R6n{y<}lVs#KM@xcE7F03?gzKvk>P6 z9?>C)NIDd~T9$f#M<&Mn(#z5JoVLU{j2E)3x8D22T!-;UM0`TIb%De9Zrb+lHD6ui zFnSjB9`r@zBKj}SF7<}xJB)|cZ{2bB{X&Ov7aPjD@BLzju`T&$W7^VEhw=LJ-<@iI zw@hwl!%By-2pv1Gt#%mC|Nd?EYT{alQT1H#jT^46cZi|d9S&o|9lftz8>w&@m+Cv? z+O9esM$z%bDfiw9!iyp|IE>F^bp2<$ui9b!&A;IC)&JBujL+WZjlb>SMu+jIC;Lhd zzYlWdBHAC^*lU3f$YU7$6 z#(nSKeap3nLHIg9${faS9u-7fLM;x%*x>08?rcQ}Iuoyfe7^lX=hUh#4rBP9l<+OT z2Z^p;G;CSDRgUMKt&Z_dWPsR*4ci>X(K}vIJ+Dsy*`MW@b4WCvy2p<-3ZhQQgthiR_yldrkiC@tlI-5 zDqTj~svkeKf3XXq&Lv9PNjTlcC>G#t^d~WCm$qH zWg6wncRjrEbrLOo#vSjv-~Ma58)93hINT>m)D0VbuU$Jm@gj-ZJz`vqd#H`eNW0lp z-1-#=eXCg2DJp1o8<_87U#&JE8VAME`ZbA-rs&Mt1JN{Oj3K#h=_FAxY@9KBj{W;9 zBx()OTf7INrPmlchRwc{L{q2S0uq%TV>07ne@Dh%D(N2|=%8!3;p24GWK}y?GCI5Z61yRbGi}y><{i4UMCQHjZ~zZ9I%1oFIb0N?;?f6X0c5 z8>e3e$LV!=9KZi*pI`yOLV_%UMFiOdIRv=`c?63ImJs9<6c8*WC?r@$P()BnP(n~jK*w>MPT@Em zzVVd=s|Z#TtRYxSu#R9oK{wep1X~H(2(}SyC)h#IPOy_;7eNQXZUPrU7eP0Hn_v$Co#pW!0uRAn zf_((N1bu)m$_@3Qbf65Gd>wGFR;YRPRtWJy{oS6fjv-fPue+lc?tpbqzuD)4bqx%-z1;!pV81!!?hb_c zJOg>wK%A?$w?p>VF?49a9f;`~Hcj;2(QkIU&45i71|k}pH&itTY*lqt^;HcmftjKs z)p&4_wj`5(0!@WI{eAAj;X$`qxTo7J92hp;9ZuFXjT{>4>GyJ9r)v;H-ZSj&8uIje z3%h#UeV#5?Z^u4w|AAh2x8Ytoa43*S(K*OBeZ+Y_=F-~T&n$O6lhk=}x$ETm6U*;8Ft*%vxN$70 z^X2GNe|YNE$p|gEOXu?*of})R)H)S$F#k}#57%u`AhN~uboF|?o~|i-!NL53L-_?h zUszK?(Nr|#p(RCy#qb2enu?18QT6@aeo;9u+@wo)%0`u_Nz6fBRW$JVW)@7>=)mGE zo{HlAqsJu_yG!KEB3SK3Y2wU91sz$bM)oLmP3uVk}LPuQ{*m z<6{l2P$UhBR>(&y8Csza4TcZeq%T80a;F+Ot*Y1zzHoeke5+|y0#>XF{uh|hXyR~~ zI6|EsH6xS0N~dB!7*K5lT2sfqfKBvJq-IE)uNG^G&le@aQ!t)|p>`2X&Whqts}Ij? ziBD|DEDRwK(bPayMZ)l%4~@wuHsX4UA>(KOa+hcZCaoV&LeVNTg7Bu!F0V?61XM!e z^Tu*&KAFcK?J<_Roe2npH&r?-WFNlS;+)8>l$k)u7&=0<4A-|k{sRX~Y(38zkqW+8 zdG2IVA(*$%CD+-ErU6HxLvdoT(m1=42O&jbh~rP~}o= zQH+?Nl5)d&rBZ$y@SUbgByBtl5gr@l!;6o-qgOA!e7sHx55p6YH@>bhVAO!7I3iP& z(AO$T7ol^7bdPvJIxV~Zz>Y|$?MaVSQ3b=bsv<=Zxzd*hVhx{1mbq0%GAQ=vr`Y$ zJKfhMd7{2A!Ux|WUj#bOKPcMR3w$)9f<)n~6Wkq>W<~h0?M+hD!xIpj>!Jh%+cM;q}@CAa?0Q~Tw zy8?P=K<@&LaPk-C>yh!vlfnfqoCV z<08)z2;1jA6o^2KVfP@tuq|6=zDeEL*Y@>y5BIv)n=3F@{9hZqioYjqCu+`5qTslXyYPEh-fKhCEnc~R3bg|L_$d9E;+R!BudK8428t{3jKUm zNNv>X^ru=vvPEyx{o8g*9V?>EJEwJE6*i_ghO}g=Qo7p zCcRPjH?>TthFkR=y1#w5ERWYB?3nj#9xq;tOSrowB+_Vhsw6l6)b*0m73ac3CFMMv zswcyRN>w;j-lZG5zo%E0Cy41L^6&{_x=QfrD#53##PU_=vVuxf;9p{=HtQQ$+kZJbsFZuLR#*CHVA}D5d3A!l7@_&$onRo!+GT zH@Aj#o8GSbce+BlTkp~Ro<6BlMf|xueyWJC1fRbWeEv%C`70sjKcOVkw^BTWxWZk2 zZB{TdHdSObZMBUHy9gY~#S9^0q~J-X1YZRuc#12*Q(Os=;!4RBS5l?8&~YN#EFLXR z%vlLOXC?TYmEb9^giLWIRf;REa!zRE7|lD8$4d|?t^`kUCHNjI!Bbobk>X0J6jxe> zQ(C3C&`Dyt$vk|Ln646hx=QfrD#5o{37O(bsuWjRrMS{6#f4TW#^b9LlY(#VTq*eU zmEb9^giLWIRf;REQe0`3;zFm1`KR+JX=46L@cApj=dT2xzY;RVm7JF1+~wDjf(ugO zm(dlT6pSE-6yDESkSF@0U?jZY@k#JYp1JTuwc-W%>rwGF>(3ofoG0>DrMBwZb^ne| zCB-`5C?r2aIl&l;Geb*Axj#wq1=%MKC|_ zeY)Q}Bua6B>_?J=vnW)YHaq!#tK#z)oiG$%R(fWO;&eK5ln~q3rlgFdq>99Q$G1b| zk(}UciWE+Sf&MB+nn3P-b^_BYhG9^T$ElP?!4OQ*Ydv$-Gd4>1RzX-B}i4-Y; z<`gbCzo2u1bI6}UBS=(ye%6VtiWiogsaAX)pQ~)3Mj)ja!ERMm98|B6{G8+l=Tf9p zEpg8M*@|PXcEw9poUWW}BMhx(`NO7Vlywo6Hu?$-T#JWBTJ z1G@jl141qkJG+oVEfDD{InwoDQ1arkGsBA09ayVGBmR>nf6I<5vfeJeQ}=i66QxusBo?fP!r?=qC!t9x~Szw9Pk%p!+EWz#HD zf|6swk`%`*4l2&~Y*;BVjdmr)o*7DtQ05ihWiEwU7|f$k`GOY+j&~*&lzeW+i7>^p z7oS?Gc-iW6AzrO-)Gy@nl8Jn+Ro|wMZQrTnZoOL{bN48@Pw&^q1_p&(im*o)1`82( zVky4IjxojOW}FODJbUr!62*&_pY2k7twSH3twfE!Nx$HnkX^Omm2j-nJ=Qnyd{oRa zb~5aA$z&0t9F7f^CMK`Jy?S(3FoM{OxMTA<d*(Mr(~U8qMF1Y;<`3~lD@W0i`}Nk87H_`;mi7R3veosCz# zY}H7pE$?(fcC=Bi*DuKCbdes756+?xaWscw#Z%IbH!Hp%`=qIOe&Jb*;>#(7Y%9hs zg$yAT9m=t`?V=`KsF>7jiWn~@B{`pzfyylnMIvEuX|FdAxcgWjNDFjX17v4%2~;I&PvXAQ*yqWlJnh^T<)e+ayO+?OUosFS0Lm<3b{Z`Q*u5{$!VHb zBvTgmwnNzU6!j;o=Vx1Gq=%N4pJ&vs2NK^wm|dvcjwpNrQP z{P(BNmeh@v)IC$u@JwRk#gc}T+fT5&-Jjn(R?=`WvC*%^2WOX0Tnd-#BI- zoWSpUdbK5}Y4#dRa3&SvT0~*l*#jdt`)zf>S>(gT_5U`QO@(-Q;|=zY^!ja^f{EnA ztpnfckur?X))-7CAFc)XDo0mbSnId71?P|tuU+tc9K)!v)L39@EiHUBt(M?is(^!y z3hPFCMsM-kwgl73hl2;-hS6;o;M*0W%O4zta!*>c$ZBd(W2q<8U}@(KcA~-Zv+j{Pzis2BMUQN^ zOjxv?mTs!!w%lmJZ%cO5G8wK#%4s!G$!5z=-g=kia%A4=tkXC7?Q1SaZn{u4*1Xei z-*q|CIa+mLvESZ$IkJ6hXXjX#*KhBiT!n!jj-Fhvr7ym;xY9q@c{zRYL+kx%t1oBd zKeXLHZzbu~{Eub~$Cy*Z2C9i>}02XO`R=eRXn^t|c{z Nf9|?yjctpv{vV^^R{Q_} diff --git a/basic_function/__pycache__/chemical_knowledge.cpython-38.pyc b/basic_function/__pycache__/chemical_knowledge.cpython-38.pyc deleted file mode 100644 index eee87999297cf343e10f70862b9b3776c278dba9..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 8187 zcmcgxd3;n=mack9Dj|V@>>`^8N+2O7L<9u#vJ!-Zn1rnihDx2{r6iS#uPU09!B&xW zZ$#W17uuzp(RSJy$98FV9kpFrai(owk9)g3w^31P0j0n1zE`OvGChCH{JQG*eRb|V z%RTp;`109Hcoayn1N4%m*nxspH_#{)ZWQb%- zjtrGyl8f(f86ii=NEs!gWsHoKaWWpZN5X3YW=6?TGEpYUWSR1Jzf5!;&${AuxOybfdXUbX9Ej^Nu zv!z!y%Q^BBIakh;^JR-%AY0|9a-m!#7t1Absaz(P%gTy!Tz1Nha+BOFx5zH}h1@E)$?bB7{8D~}?{4`gxl`_vU(4O{&%kfw9{CsfSGjj* z#IMP}$$j!$xnCZT2jwC8ojfd$$fNR@JT6bjlk${2E&ndh$g}dC{9c}y7vx2GN&X-& z%PX=+UX}ll|CB$bmgwDLg^|8_%SBet#WOO@XLLasP(HMrPg1H z5p%0iah5?kxv#d=EeBR8UF`r}9Y!mGdZjyMfUW@}?%B~b4$w7Wvaqn=Dpoy47XVv| zJ16-wF}ex38MsAx?t=6S;8r6|UBd&lxvn<%f7H^Vzn1<^WTiQb_NKmdx%RGpk$>!a z+aRsuJ2>*~J2U-RGNsWo^A4}aW6V{by`{!R_4&5cn6mYgQ|_PqeW`Kw=)0G`llW_? zasOma{ZQ%}A1F0me1RXR8k7#g4cGJYO{sBP=i=oBZ^YZIb2z_p>x z(&`0nNTUM3%MhxOfL#-^US?@)Lx!r;Yp<5NRedGd_pg(E|Gd<#{Kv_5f0R^xkUTB# zC(HY8sll!{Z|f6Kd1nTdS0|zwhwVw4Jn?1H9fETtzDO?=d|F|IorV!&N0S;bdCgOoFf6j$>*zN9Ix&kY&6+f3pqPp`;s|BaIOOE!ntHH z;{=4@oH3~0n>XRHQ>P;5x8Ba#v0EFV8*fbaop<$2x zM%%RZ?nK5+tp3zJb2`%Ub>q$@YfioAP@!&YYm867_E52ItQ+^+S3Gk{bmODf@7mw| z{XF+P*3H+A83BjTBs_xtS00eapyTPv;&tCV?HCOC|aB%mR zduk^%qY_!@#JvzyF27!{yZ5KlbmzwKL5sVm2Oz9C@#3-#;SFw)c0yRY;NEA?oE30q znjt)n-QKqa!o_>ud^_hx2xpwtHr}kGt-MH0J{4t~6>M-8^Z&H=-N2Km7Q(W8Z1l4V>slKYc8s z+n>?sYtmlQKGFfwK~UI>hf!D~rICE33{o~Ji!_Eb6qFY+!5&j1rIFG}86+Pmlaxgo zLgJslW)5j6X&5P&G@LYobOdQ6X%uNRX$)yBX&h-h=}6K9(ov*|q)DX7q$#ASq@zjG zNcfA}V@@Z{ARS9Ojx>{$N18>NP0A+~kmisINyn3lNX4YNq!Q8zq>l0##jpZ7CiBO|5>dvd6F0L@@`OCg|4Taj*jVnvEC-Ebkh> z)_@Qz5po$!>f?T`jA{ewc8py<1ltI_f;n%8R*MQ~Ag0N|W4q;XygHDKWEyF4m zK=~&8VwEa{uEw|)8P?%}wSw-LuS9l@Sf(5yYG6|vx7lZ#eYV+Wn|-#~XPbSt*=L)5 zw%KQ!eYV+Wn|-#~XPbSt*=L)5w%KQ!eYV+Wn|-#~XPbSt*=L)5w%KQ!eYV+Wn|-#~ zXPbSt*=KtyiG8-&XPbSt*=L)5w%KQ!eYV+Wn|-#~XPbSt*=L)5w%KQ!eYV+Wn|-#~ zXPbSt*=L)5w%KRibkZhLJ1GJR@@ynOJQzF|++~b88=lSHUi>Y;e&%MaSBrZ#`C`-J_}7P9Fxuta z;yvHn>%Ao{?&;Nz^z6dT=Zq}=tJejdczQge7oJNW;?1h1H4ibQzvr&-$Uw+Ga2F*sh(C-ssuwG4F>e zFU`{C9PiIq*-%;4U}aXVsH&~1Yf>sj`)dl@qLENxS1e=}ZWObyI@}RxZwkfYg`GX| zwrEF7MPq$oYakX5wruR`2*$(Fj>2GDC=w0^+FLetM9*pui4mI9*<)odYG%I`8Gz)NFZ+c0-c?q4zYY(AtY8>BHWpm?jI3oZ*OtS*b?vQ4Ecuz zyG#>BZHbx^GObKk=^ySmCXHjM!k7;seQV5NQ|F0{@$&b0)4f^VvEC8dc+GqYf}%tYhCQ>33X3u0PpdC3 zva)NV9Z}aflUmvi>sB_7Wg;=Wp#Qiafu}%mB3lJUsM%aJm>5Q}JD-d^5=Zn6vAnW^ zS4AQhB35y%u#!(jN%Wg{vQ0(9IFU%NC@!o`3{j41A_KV>M$~*Hk*)G(WLVZ07Y3bu zVkg{&o2esXD}4n}%3LO(cIz0Zz!fU}zC7K?$idlwvrQ*a){GhXeaHx3Yv@W7SC5Hn zs}=Rn9TddfKRj51gg?{qDpKWRZq9rnn`*Ztq0Yg;VkOT>z8i(c>`}ic^i)&* zqo{ggVt^Yh`!aUi7%w@MJcP+^`O~XKin$TCCs{Yv*J%%8;W)TsN6(W@`7;AGr8d&HuT-PZJ7?h>DBa%Se#89HC|`ZaTwJZVKVkw%B!-|r z%8<9Sxz>=HP`6v!A)KZzk;tMg1Mqe_H<6LRpeK<5|NL&mN^qfq9u-v$5@`(ubFD1g zkDV?Jb(bqDoQwR2C6_E#!yG4gQQJpKznCfz)Eio#RFD`5r{+rb~h~YGRuzzfFDj3Bhkq>b-Z{;TEBLTb}#LR~ojSpcfBODJ! zV*ZhXUL|62v2tRafncb`Fr!_a{+!Ne7~#~Q--Gsk22zjztVNNibhU?;@Rchz60CVO zO~cc+-}iqT)2F{Z<-tCOv)St#^$%R_E3Ee8yawkvcszJ#9RJ^K4i~k51%KRw!__Y} z_mOYnVe%TB=iqV5KehJ%myNr_biBvAUvHnJ^+^MCT5`nSl=UaR@k6u6`bKz1<2z9y rAHM{@$=`|IsVN^P(jkpAhcqo8R{Z{boAKW1{JxVlZL+uDcmLl2sti2y diff --git a/basic_function/__pycache__/chemical_knowledge.cpython-39.pyc b/basic_function/__pycache__/chemical_knowledge.cpython-39.pyc deleted file mode 100644 index 513e6fcb056eab5499db4fd6ec3c3525870b8358..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 7537 zcmaJ`33yahmag}*l0XO{EFzl-N+cmBL{tRwvJ!-Zn1rnihDx2{r6iS#uM$m4aN0(v z+oeZs+qQLZ!L}Wlww-pyvBiC9r`=m{+U40UxboU%P*mE0(*J+od!azTNqzUPbI(2J z+;h)=-ck@6H7X~7zcU`ZqwA&P1A+HAN&V}=#K1`0nGpy`Ks3=MP0}SpG9^p0WrXBN zu8fpXl84V|$(J!wAY-La#>se@AQNGG47?^mGggk3$udQz%C!IYGhL388FKvA`N4q9 zlv#3uoG7!UNKTSDQYm?+e5*Cq&Y!FkrBr501M(LKA^hjKK=gS2$AcK;W z3uTjRmW$+{hj-mpZr1YmmTteJSY#z!}5qcDv!zI@`OAof0U=>f8i zpXDXlDZAul`Cs{qydtm4Yx26-@`n6X-juiGZFxs_%e(TPye}Wfhw_o^5l8+eAIm55 zseC5?Cwt{{`9k)|m-3Z-E&Ju~@{J710XZn&%6D?;YE1$d8itOMhLMhufsu)kg^`Uh z0wV__7h@#GD2zOe(HQv{V=xLZ#$ps=jFSX#JjMi!i5SOVOu{%8V=~4RjHwvYFs7>t zaqr_WW?T{oGKb?aW?mAAq}`gaEl?ZSq6L}JER(sOfN`Qpk69SAP3B&7gzhAzn*&|3 znP&2gCAZOYuB!pkvs77~?CO+NnX)QR(acktiWJQ$N;5x2bE?vumZCXbX%?ht&iI~Y zq0%f$v01D%l_?rSX{w;9_RdkGG_@(3I;E*k(JWD#h7`?Gr8zT2vrK80r)V0LrYS|U zLTQ>)G-oMIONypdY1*J!>1DV|X;!Cb)+o)|6irZR&W5Jl^IE4g=cH&llx96NA_-#Mw(GH-_{yzy&+=j%cHhl9kI1)TdT3PCR}48gq$a0>U5fyNC8!R;|Dy*Gc-V`ooC&Tqey zd+l~T-!R{tnR)SbPabQS4`mb-R%XsH%)2HB<&0Zq8D`gv?tx^+Z0!EWd*=3}6&vPn z7q341os5G8h;nB3F*)X>}_UyL0p0j-7S`730LswsP-~kA#&JVX5=6ba& z1ouAFW|(GeG`@G;N^HTa@*@a8Z~v}l==4>F*?;ZC^i?lH$gQ8%uP<2b<#TYg;dG({ zTpnuI8s>A`>h9aO1A@AT-54~?r+@X}FR$JY;o$bqcQsCGMJG5AgFy-Y%t8pn?mI)Z-$`GJp=(sK5|RFiI(DJ`6md< zE3?ZmcY^MVIwAbt>bmCBn;|H#f+&`*LCLR#pl*>LLfAKLSG;4|xrVb7DLT6dZxP-m zyi4G(0<+=K%~MX(VutyfhaaDC!tQRvL`OgR$zw6Y`IJUq67~`H6Alm#0-_ooMp2!R zM#v;&5ON6FgmHwCfTEZM3|KlLjgU^jZ;ybLNys8(6GjmDKHIbm+Yz6fu!8hLJTQTbT zgZdIY0@^x*dhK9PZ-_$&n>MVi#p7T()*Cip9AoI~g!tS5vBVS*4Mgbf6f&_#$6&LwOlbQ5BLFnh!QV`H!{c*+EG zHUpcrP533YZuVw}7@e;#PZYB(bqC-rL?g*w+>B>8NRKZtn~wqT!AW{XOBn zXuPLA+!cvM!=dhujXm)Tx+7vn=JpQSISad)Z#>q!nAGgU2qcRVajUPRbFc%U;nUI6 zAM1=*a~F2U;h$Ix!S^Z3u*ddBtY}1VrMd9MC`O= zw6`cdm>=ry?(hoN(Kpx|362W)TNcXL5w|2_*;(FVaJ1{_8z&Bz!g>g)+T(sORrX+l z7kofV*Rr+oTE0F}x1PfM2bobF!mB+>X7YSyHfh%?ZVvBRlx95uGdpc#WYEqyKh)hH zIlSm5l(P_}a0|-5wjLoo4H(K8C>|_M;`Uc&k7%=^;cn0u9?C8mC@vW+E=eZSnoG)t zaxop8Q&w6IO)#ywyv)vNjQ7M{KS-l%wJ! z;vDWB_F`B&8E6#M1I0w7HZ*20TV3&3XVgEDWh2;7-jN&n zM8KPYlff)ELYcoG$@hyjCz;G~x94W63y1wOnro`ceXsAGY);ap;~k_B%xG?+X})S0 znpUWsD$re!u?bb`j1?+8RMs0T#3;YmZdpCgc`zO(ijH8zzZ{CaC^TkY zr9`2pQu$Ga{EG$(&pdSb=cSKiC}uYNV(s-ZzB*cQI=nMfP~~> zk`dQyQD5mALHVXCiA#^`oD8Nnhr>NyWOYrcld0?0ZE?2T@U=`Yum9N_7CUo@7Ot0{ z4x)ArN44Ig<-7jgh^0E`h&as5dJ6#0>Y6GqpX6BgzT=VMU0Lel-c1K}g%=!gb8;C{ z*PUCsWS)2L@-9ql&=U)Jz3Mhe4$@ZDWVfPvo{ZaMMY)~R7Qv&?jQ51Pz5a8Z^1QPu z(&0GDJ(NH7E*fQ2ft#>Pm&-Y(S!x>!YHtWfml+-WzF#tB!!Gf7R?dg#NEjCEAME3@ z8A~h5N<)5oy~Ax4r^*q0xmz6-%e#2|D70k)%F(T$pF}RU z^ZFy1%w(k0pL~!r$ zG*(TY%w_~e=QD3!k!zpKNJ2N5%s}ws0oNzHB*~0Q29w-ZO_FIXCG+fTJV0HpJ?D{L zUS*~8NV1<>Ij5GYRIa;eS?PQjjP#N@g)=nJ{Ia4c7GG{Hz9?FJueJDIYw@Mc;@hmn zcUX&Wv#!LqSxW)mn=HQ4TIzw%7g>uhs21NytqTa35PnLyjBq*O2Ey%xy8zBrglh@E zAZ!QVWl+PfrcPT6L(B+u_jZNs^v+0M$j&fBF{O7l|c;1xLMBCi=wAP4tGskq*;}_xA>Kd*e}r zQ@) diff --git a/basic_function/__pycache__/conformer_search.cpython-310.pyc b/basic_function/__pycache__/conformer_search.cpython-310.pyc deleted file mode 100644 index 2c51dcd1cd052eaa87d4c9de27bff35e72993366..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 1992 zcmZuyPj4JG6t`z)XJ=>sWZQ-wP=$P>D@r1SR8>(0L`fT$x^SMZesK7#h{@4t?A zJ%s)=ljUZ^L_$lD^LM)X1%h@8fdx&n&Q$eYK1$rAP$>eV1#@EZHFgZP{m1{MtqP* zRFFZM3{w$vA=ms$XP*iha|MEpYHs$;)H%6(u>B-dWH*gSdUV1Es<%O&jyMUEELYcR zYu6i-Jnk<%kh@`GcI@Tx5f>yKE_o*^CC9Vy7W-p*)~8DGI8$=IdzYStaUK(}J+m^` zSnLzA?5mOBazyS9qm&x22fK&c51uVqx(Y}@`ccmcJcxJ;(k-={nIy|zU6@h}YLA{i+1?3!SbqvZ zQznBZoP2Y3@Su~4JmCkiG^=tH+7#Ej~{=PsW1-59InUQF7MT~ zH@`aF0qaY5p5*I%xW(wMKubUdQzNl1Yk z@DDSQnxa_G&ebSQ`df$l>qk_EgZ?m21}aRG^@UZ*sThpbvU7cFp1P&*TyP?dC$H^d z^L4S0Jxg>!blH6$p@9L&ycwy%R0WtWqluW(=DZS;KrbFXZG1*hibJe_Bt!#Zbb|vT zs^?I!OnsP%n#4wqcjRyz+Rsk=5$6Qrty&@}SfA?(w-T|q+2gI^g0cJk- zEg!GiUAv1NkaTeu<{pf9?3-p4TmLx!iXLo$i#Il%HS#c#Jr)>)L09(ARgcD-2M+lc{PIo diff --git a/basic_function/__pycache__/conformer_search.cpython-311.pyc b/basic_function/__pycache__/conformer_search.cpython-311.pyc deleted file mode 100644 index 8673545a0a78603330b152008bfcab2535679d92..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 3474 zcmbUkTTC3+_0B$LXCJ^m@Pe_kaf8`)B9M(^W85@F5NMTvBa_&{1+RylfgQ6i&CJ$3 zc9oUoHfxIrlqP~Yb%LV=yTw)LS6kI;t$w*uKX%YatfsOQDOwEu36^Wg@>kEDWgp{| zpWc~0_uPBWJ+FJuIkUIyb_zlJ_r;Urev8NCcw;TUhYcrXHb( z*ME)MKVA|hS(~&eWlWjIbYQnW`D2)-e9#_#{P{+9 zzL5Q8;rH_ouV2~t?8@dJZho7&OsT7dxyu`$tbTX*x;l!+BWx%cVSWD7J#TViaZ};z zKf{hdr*HGGzbgFVtLjP22Fr*gqrp%-HWufjEHC){TucP|qU5qnkhAr{Txm;`o($5W z$VL-ldBetUGKDX16;^L-e*MYDotv?zy2`L~Y?pMF-dHNpwv_jD{|7dCI+E=$En8PF*t z*w+o_35qE}^K?`Qn0b7Xit&wi-t6g9C|G|7w4)h8v4UY;J9a6K1fPttr=tR16~c=1 zjp!)L^d}=Cmx!=Mmacbudru`qF3L@_{X6~-m=&^g7K%gbUBy=II)$pZhR0W0v6l}b z$bmj)5DfAJ!3>M|te`#X33U#?E5K0=kI^xP8}4C+2{E1+PE3j6cr4gG`1bH9EpVaW zSTYt8xp-{2!c<^sJ``?EOewB$HpcR_$X1;_?*%dVh6=*~GU!Poa(HG(t_Ej=nclp& zA=4`xJo$!Z+}G!u4&uI{;NB~{TW?R^Yt0P>bM6b0`-1Ggkm-M7LN)cWBXHX}$W2rHgUE~lcavAREN;&=*N;^*1q-vJjB|g=Cih5omP{^A&>gT;t_`BeOwF(fha+X;a7aZ1efqdCm4P7ME)nN%z=9-NAf#v_Ui(6X&c%?{z)ae>3-Lm?&=VoD*V z6+;}jE7T>P6WItC!+L}wae)ogDkh<*7#UBBLG0S+oT=!cB=Y#Q`Fd;*!6^oOhzbc# z2#$eK3UN^}(4ZjjDypSVfLlx=&5yI9fG@XN|Ed6wvw>1uz6`-)dE;M$=@;N95O7U1 zXxpGS4iOJXYSub`YTbQk&3)*0hvYsgyN^obt1|g&PJN$l>j;(3o^za)$oFLOy$I{^MhC0X4^8|Q1130UUl6&pBs&BL((RCb+q*e^e+L8Z-U#$AQ+pZ*fF&y)3(4-bS^?y^k=sYR(ebqYp@s|H9&70L51#eKP6G zk-l$So-COsYv*dN)hwDmw_mp}g>J>ah)V~L%Lk83-VWK@A(1C!@h`q#Z} zYu>h1Lh>Gwy+o<_{%lX))39* z>^zz6-Zmk-XYSOsQ@N(s*PD9RntJd3{J!wdWbVw6eCB7lkqh$3sMHjan?jO}k!?)Q z#ys#ebApjF9-dLV6L&Pu> z+)m(8D&C6!cBr3?{C-S`Ka|vK_FKd^U>x7XXXFo~aR$D?f5dme5P#YW41jGtK@fRV ipVPj1v_GeP^T?7>pKa3~;=nf2f+Icx4B7>#TKpF-BYBtr diff --git a/basic_function/__pycache__/conformer_search.cpython-313.pyc b/basic_function/__pycache__/conformer_search.cpython-313.pyc deleted file mode 100644 index 41e00fec48d2e4ff92b25041e65a438b875182f1..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 3058 zcma)8eM}t36`#E?j=c}?ePF=YUK5jW!~)rd6x(T2f(>p6IJwPXEJQBL-NLPXyLX!1 zYXHm2$x55UX$nz;Oq?%UN*h;d8~3las?}QgqiX)@oe%ZU)TtsRhoS$(t*KP?ufEy4 z<<2fqJJQa)dGq%7K4#vVx$knZ2-@Gy9-IEej?fo$Vl-2kxZexJ65;Gkf`%De9E4g?PkE+TcqU|tZ^?}9k5d%AjdjPjG*sc~4jVG1u@#L0 z3U4}MR;*Q$&IhIFtQCb&4^!eRwz1J}(|!~%r&@r;-hJzXwaj8M^Znw_7w=wqb?yCE z*ME8Ki}Xd7EiPPKTe|Vp%`2>&h>MX_T;$q=$A=`XBwYOXHy{}}%B}zIN5${`xSC0S zS5k5~l2E1+SQfF$wM&WyvL$U>dH{dpwFN3J3o~Is(?mI`Z3(Zvn=XF%QSrum>mM(z z{p=bDQ%NkVVJ(J5HO941#S;RBUH{d4cYpHRjpa-+ee|Q24Gf%0YLYC? ziNobM2Wp7jh&gdT7)j!+4i%{i2d<$uxwe8HmcX7GPz?H1l`e9nMn<|~30dq)sUq&0 zisG(h3X9=>{<$h*R*NMRoi{0{QY1W;QX-m^P`WDcswm({tTQ=Fywjp0VnGwDb{zXb zAarD3(@(G^S6m%d#EU*~OK=k=NF znQ+ecQhIpPhUyv$?!aw#;Cg%BeK2~Ukw}k9 I%W4M`4lA%zXGL9ZcG8F|~bK3juyw;pOn zwK8yxkwagDU9&B=-T2QKA)Y%8RiXb&vlpGbg^=O^=_cdP${2DO$ci)M4DPC|7?k0O z;sS)-2DdVXY7KCs_@-kH)v?tX)=j6LO>z_1P? z_n5KD*KVt`u?J2NZ9HpGN?oWf7$|Y8`w&~f1&PP2*GjYzHx3Pdf{+7oa0Z7@l&ER6 z59U?p&{RR&3vwgxeAB!aO)&w-$eio^g-YE}=LSy;Sdo=K)9ZEG-c*kT2u8Kha=3< zIrI4o>|9krq*Lzv(I+lIckqeM-l=T}c5crD)>aG;pTE#KGdo9^mjeyNlu)5Y3+F_z z1MUT?RM3buiQ%5Zv19MWZKGy(*Gg5o)?#Fo{h_CPgJm9D|}Z zl9KzhgbcD+P^LxV1|DUcmbBP+qBEox=KQQYnTQi7kXyczy0YMN300y+nlLkjnI+5| zu_OSUurFgt6XTLX?TEw^su-vzHnn7!J(1GFG_(%|lbGQO!X6r1T1W^Z5eqH;#126S zPk~c}IZG@87zkJ;v&=~di%k>ow5X|cTS^yBrN=%BRsZlMf{y^3rk*Opn@d zue|w6erNlicLXwhtKNo!w=L&wyDH?p?V01N{-(<>zWHL_-~<(^GKn&N+84$5xz=!EE8#`DcGP zvgt*hj(^l4-+_PKYhv`!y0Y%m1^d{)RR@IqV%xwW^oK*PL6bRP-k^8f#yTA0yI0<( z?bD|q(~bdCs)zD}p9BhCC{`RjCza1StJ9tIfvmrba5Kp0XImAgpptOP@Vo~tTe_>zCC3n8`y-$8;FS%T+v&?tD^PQRRulasv z&gJtag7(kSTZ?}g2tAR5(Wb)S7<7IZ2tfpU$j3Iul2<&%S8X-nRZp`uk|SwiT*Y>Z zsc3TXPD{f-1$_oO zKMF*kHtwL7MsTE+l}G`qMk-KEprurVB9IehQ*EVSr5>RUZs}1f(yI#ip*D_TgiH(C z_RVj$Hg0q`KJET=WBc~it$SB{KX;z2UpAB3?)A%Cw;nxxd^;KWp+_qXk6IIFW>30e zu&Mj-H`q}+V)g#`y!*+wqe7`h+6Wqcxe^A|kolBxYr+i#tRGk&+5|iOYu9(T_|8h% z5rX=4G2F2A^?LXIcil&Kdk=4I{n&vGyQ`Vc%c908uUQjSFLb2(-miDJAN=0`a-+Nc z`03-%`qwtQ_csUONt=nTGp812hq3nlyxF~TZ#P!Y<>Eg+b~nE0b-o{3g&d-iVjB?O z41~vHz3wo_=cP20^7BUjLY*}NddBB+me=h4$NdYG zoN9Q&t$Q?4X&*Q@H}`&BxW3z@r*`sRO2@|TS;U3svm56|Qx}_isbR8);{5O+$}Zt? z8VrLy{sUqKn0_`tRSSJO)!>v(RSBD_HyAC?ESw%;Rz)og64?cZyOnaa5mba522*>Q zIdxd2Hd$Yd4=mDvGDpyn0?@$_m)20;z`FdDqG~Fv2$FvB4TLJP(s!GbN^m5VKMpBx zj?3C7xUEFE10P1`Q2-P26Dm=zfaF6IC?6o9wE+oG`6=L_`V%el%77LvGVUOfx~ocUi>PHrW;;W4u$gL`9ndN5CbBfhjKNrB2eXN; zYD4KdWMfQ z)wPz{94RX~wAPLsY5{rk(B$Fj|1f6rwdcIubG*FPnp|0J#`sd{rRcCO6Fa>rrl}rU`kMCzTXuC?F%Ly;yhZbs7*mdxDDjhCeMr zpT{ZY1dG%zfJ_R`yQ20US+O%PURd=ngr1!Pb{MNYwzKeb$X!|QFl>yQ>*IBb!M1hoS)CifR#1mWu-kduP(~r95z(*m45*fcwZO* diff --git a/basic_function/__pycache__/conformer_search.cpython-39.pyc b/basic_function/__pycache__/conformer_search.cpython-39.pyc deleted file mode 100644 index 05b66e4192783036227ca675516b91f854fd434f..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 1992 zcmZuyOK%)S5bmDW&dxr(iA@eDLVd$4MFuWWWJQ2wCk6$Z$W9b7TC^H(_j<>kSGs!? zdp&b9*ZcrLNV^A+I3n>U_zQjI0R9U(Q9b+cg6PrIR99D5Rae#56SP_$g7N-uKaUPv zg#Of%)nUTqb9lugAQVv?p&0uZYu<>A*!0bcHzUipXp`2c`yKXe+M*t90OwHi7 zup4!;vUW!-hHq^riZpw8qut;94Q~GW;d?#)_-tTl<@u<=*M@f&Uhx=+L=!wkrA2XJ z^^L*+Y8EC?OQJKoKn2LD@zyMD&~yqk#idi&g)=lD4l}sI2zdm?K1&!6BqMPeu|XaM zoD9-rnDUr$vE@}d&jTLBOoCv$n(KXYbw)Oi_MU~39HbFRPtMptcDKok5hG!e<#MUE zwcMEGasSo>*$flC<1mj;7$@m)#XFHHIh}`hyFU&t`hk=z&ZJoEJ_s(tIFAX~o?BUH z-0tJC=*tmjVnjBFQ5tBkM+e7yWF|}_O5oJmy$C7eVRE_}N+1Z$vnU)usupQEVlwB+ zN=~9s$ZfJi^j9Bc5jLU5I!W(k+ye=_HG8jq6tP?I$mu?d|&>tiOPu z11kImoP1|~aKD}LJYh$%(5qtP-+clnrw4f?!z^NzNq_z2)2Cl&GK|A9gYxL=vTjYe zi_)n!SYN4yYOdy1dAAwky8TpZS+;&ig{pyL@F#vnyba6mz8Y-zUJ7WE-VTk!1abg< zmGa&m6X!C`dfBBMrAdGH_<8Rn5aFOd%#(o()1-H6SFnH&MqAmXT3;k@B|rz<2<;2C zTKd~5ZL(kW}wEu3q-hrazC=5ARBD^}tDh^YOOsjZ!& z(kr}4ojTyto_JHRsr^>68pzz)TsCH_m91f^XU(DkUYkX8(gK*bK+-PiQ#@I#^M`>X@+9eD3fEp1({)4Dz*SU0G5ZIo-p+GL~fD$i?^_oi^(*T1~>DT1q@aX4;W zl^-ojAjbQz0-iwoKLBJf0YDZch;$2Hk{}^a)qDVb&ji3F0-ylUK@pU)K#m>|aUNzQ zKBpn?n!FC_=eh>G0c7myzN~H*@(yp0&8yNIuQZjOlB?~9C6G@ZZao_Q595tL_*b;I zf%mV90M?Ql5R2aeL_Fnbd>n@n6TZDe=@+8v zBvogCi|SHTPLO3Rp?+hZ$-_L}k!dWH&4c8W`E8JC!K+Y?zN8m^9p=ZE@ktu_O<fE+~C>;8hE~4*368b|pl)`7q0tCj&5mi{q%pfg}4EQON6=r8-P4Dk4JD@)U;eC*b zO(4iH@CJabZFtx-H_Q(17#2u627nK{xMSYOcY)(t8H%odTU-N?YrKM>;_oa4HQ(m+ yJe=3TtN1Pb6Rd7NzM&O0b}M@;VQS#>DpvP#oKm>VKGnCVfZsK0;wE-Y@7_Nqjy8z^ diff --git a/basic_function/__pycache__/data_classes.cpython-310.pyc b/basic_function/__pycache__/data_classes.cpython-310.pyc deleted file mode 100644 index f4abbefd3107b3109fdce332739469a127cb391b..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 24657 zcmch94R9RSec$cg-UkkcPl6yo(bmeA1snw)Mah;GhNc8jqNB)y6p@zfqv-SDb^#oC zAHeQ{1hxlTHvKtcb>hx68IP+LbURJiq^Z-SY3od!G;y1@?li62y3?e6UXM@gf!1!mv9_x64M-~au;-v4c9WTcS7-&g1~gmQgWUnM$UWtz=udO0Jc!(aNZl$u!4WW0f(9XPe`#iOPh;bBIq?CMBLn{7~hPmU@568nOzXN?C>L zdSwc!VXKH#5vjvSjaZ{djUqK|=@(L^vD$epl`33ZYq(~sV{J8U)3R3@ZQC`yHQTIr z9AwnIn(2DZR^8ij@M5Lon2x=HsNMD&?Nzhpby|&jqR8pOBC;0DbIy+I)tU=tt! zX?QZ9Rl9@coJL(*izbUgHIzi7F2-kKbhgu|H8JFR!){|3Zlkq@A)!6DTzCb`pIRzq z1LFe?*9$Tiw>Fw~kojO6uY>erdnd?0y`lOJjK>;vucQT;x`ncc7~ZDU!teN2yWv&q zcC%T<99`Q%-ei5v27qhL>XmlqTGO^xZR8GdiqdzT9Vh={jJzMYb_$r7s_2$hNn5&Y ztmszSGCpNgGQf?Tl}9X#7_r8vbw2o8TL96GGZM%(H8yFm8k9lc(!Q5Fe{mUHo zVtW-M=(MnK92WmpoLYTAW9N|_qD)_FTMOpZ1@qcv^IF4O!;Ww_?79c60z^VI?KRE0 zR2!B#-)MUn-=c|S-rNG#Tiin{4InSK75AJ>n@tLAn>A+@FC&b^?ongKXgdxtuz%j? zCGRT6fyOSaV7@Du@Cs(kDPRj=b47jhT9U!l*VvOeztZf~IK*cFLf|?UNAeSEvah1u zww=|TUd0Oy5Bsy@abl<+z8+Y=6_SJqEKeMSMqC8&0P+#%mhEN-iw8*eMzY*#Teek| z_w!9|DFT@hV$&0kx%}Z<*HlHhPSRwjoq(xjyLG3rLFmob+Qg(qbE8?S+iRVs6|T{W zT_YVd6_nAi6H=dE|(q=v#&UQ&hCl=Ltpsp0TlK~lwrdN*v1fS`=lj*`tpsQUzH9B`xIgZg8SE^<{n z%j-bfkZ`czA|~A<8_YR=T!G@W`J}yb4g1;zV73TeB7^DLif3S*BeK&&5Z344v-ajz z!?CT?=3GdL=7c)|4TvIh5kZ;by3R!uVeVi<5h+JQ#4BuW5+~V|q#;a5zS=)=Qf&t$ zZLi*IG{}X6336L)68UzJk46&YqfrF8u&2sU+)pr^=p{_3K7xF-cY@)>_6Q0+JPmRo zdM<(Fr&qxM$AC-D1cONihY$oA>CBm8#$g824303EVK9r}rXCos-CS`-nMgTr$Iq?v zC)dDLPqu3QTCKC*s3BIjPkZ&1t-9=ebZAuD|39huBa2vxa#T9VIyQD{d#P{hfTg7}c+Rn{;|RK`_0*nr zBlTG7^2iJNuC}Xt*fL+ncIxV1PQ9GIu4$=FY#5+%+CTLSHY?Dr)`Vyfqgfxv62d~! zsqKQim3{rRxl~GX65RRKYLEeeY`H;ZgH4sv&f74EAbrK&2{Kn}%`IDir-D2A%Aja; zmcz{=nAE2<<28Rsrph-5W=Fyxvs*%hv$M1-X?N(yQR`=sV+$kV;_y5VJ3^y`?h zsfx1L9x*7GH!!2)@&4Kkhgk=d}66N-Qvd{I@$^D4tpi+@FmVvu%S@f--Zr z2*;0`m+ihUXrmH$(*Sl}14&n+qHXmyrZWy0`gao8W@`)FWUY1;EV>+Oh5q10hj)1F>GvZ71* zJ1rCMW(K~y%X$ai%~~bvZu@9BhId)_U<`MNkbOMJh5C{ItAezBLN;XkMh67LtScQV zDFI69bDg$V0}e&9W}gBzPX!{OByqEZ@=8hA9YoYnHmU)&U1`x?Jt(kT6onbPnwl~I?kG{Cw?b4@3GU=lmCYwS zz?SS}3$LKhLCOY(|p%1KSUT_~}5zgGxsR8|qXM9`|T; zG_G>C*)7>jyyp_D}-(1k_BZVBy`Z;S(+F!;Y7zl_G*K6jKE-jer9`NXTe8_-dJ9t zvawVeb&lfKc@KjJ8N8RlLkuYNDe><|7(312M;V-9@IC~UaWTqNd$yQhHeAChc7o_? zRX_$t5|~wvv09I(2s<&6_hCFPM=?6KU*H>u2$t)ftrPG zx@*vk_fpz@Pb56GkD;AMAx=HoYXc-F(;kPIwS@MUkUy+iqW0D{yrN}L>_Q&+)fL^X zFy=0MqgIyU=qgFpNPCgF5L@53o55@{>tl(rXC-Gmk(29Yo!xHE{Yp3Id>%bcpvE{_ zAW_Uz$2M)e|MSS7WJ_`W_@<4~)Lf~H{T`frmf;s>JAto@8I}nNNTOoC{`NxzYO$(KwsbS?45L&q8tSywme_ek z@P(eJm+uqMwf1fIVCc*$=-Bp#bM^*%q8g*Wcr@j|JKEm2r5@!o%??)3JzI{OSpAtl z``5no8_&ObdUf^5KmOLIzw_@O`!3p(zopWwGmn^}Yt$(n>VeL^3`picR>^c6O0kuR zcnPXrhc?+-kXzwMZ0P7<)%9k_wZl?;%TlD(LE$+nbE3y6x_LL-k9F(d?Oc_r==P{0 z+Wzv6viwSY^69|HW*MU)AR(0PJIuGOhHXfJi>ST6O%j?tnIP?%* z%f6OJEU)3w$J2Sl3OY5Eyk=<7Xz?7?^;@GuS;S`ag6`k-hRc2cOC!Dh)pwwK5FG0C z!@yH*52SoOjT{5SPOD}%)0PG~Pxqhr;0CM>G((9(Ptp)6l8L>6GY?pbDf>2Vr$0iq z%c^B2);Jd!hzUiQCs!TH7d3qa&3;Z79NLm$#8|-9{uva0l^TCRY5aFY@EB0`wvceh zKL^f}T?|lz6B8W>iGB0D$QcPv%RFVq73SS$f-xEc%0I_)FCg&uPc2!hZ)zk%1c`9* zPa@*MxIvBx@f{`{{;um-^kZ1`^d796>lx%^Z|E@UNdNG5i+zYYhX5ivD|IOGr%Q!H)V3A@u`l zoNtEJ$a4->ESJm9BbcGH$bb{}A4ni4b&1<7q16#X33Q42TDcE=NlGq4l~D5hE~z+9 zF0Z6RIxc%d7|ha5O^krf6C6eVqB>pYTH=3(VDKO2K z{U>8Cwgd}Lv=tL2%n#zcu)GTWM~RIMt13oGYrZ50F0R{H72UK|_D9bt3)BX=js<)z zIzo9#n{K~Dg$?`!iwHK8t_Eq0Bs9RtuI*7tZjJrWpjG4-@Q8sjujlnbdK|Q@2+B8Q z6hQTgFi=K&SdgP{ZHU@@xYDnZ*hZ+W>~+^QpmjmnC%dprR?`py^{TN6lRUU?+Fyj} z3_y_c2p-MSNJ)0jvk%$S_YqP{Cr&Bpq-o=HQO-h$ezYT&{fFXZc$N*Qv)za;vK=BF zP03M}@Vxe}rJEUnaWG^<-QcODtsqS>ilrDbftWe9#$ft4U}?WWK!_m;d?HIe@xvaW z7*M_sn#A;o8@OclAvY-WxPe8%SF5Hs!58)n(fyT;_am{a_*Ca=f;I3AI5h6Ublc)l zO>7AdBQ(Q z5+Ts*mbROQ>ayL%G1GgY41ug_v#C{W*HE!F?eiebyP0kV^`OjLNja^q(ap@JcC%er z(@`e_WeFU)o1KL+lx!cEetZ||N;h{!cg}5|@v`f=sHFVnX1Sgq=RC=NDT7^+-yQ1a zot19xbE&JEv+}vr-cUDp10=d@>=jgO^Gt6n?%ZcnD;kY*pUQWI(C_L{&?Y#dDo({g z*7LRBftDDfLVPqyKMCc?;3jQi5q*a(9D+Bj-_btd`91*+TS45Ia<{N?{sD7NWJRx; zt4?iW4aZnDr@m&+KYIQ`DQQNBD+9Ph8*&@BA7qhM++N(N3rUcgNCyWAJY;ujR-+*% zrj5>qn30r5^-+?Bax{^Ta%0DheU#@o@w?D|fFgrT_y^DEV#MN6Cx^qqu0YhF85~dl z&`6Rjs{bvxoJ1h&Z56ZsZ&6B^P?mV^fO*3HPw>t!E|`y&%roXZ=ey<3FP6-EW$uVS zS}K{x(NBo)y~Z6ZZ^}7~;CxAUuHd&)3`HwK=l%oqgj|jU&E=^VRq)HgIR)cDNxrCW zf`0V40vi2s6b#Z0Yx|~FIdl~!e-CHWjW+D<3WHQ(+B$;585(414mq2~z+UCH1Vb%F z3W6UUc{!iE2GK^#M&si zT^S0yw^6IA0MO#F1qvg5)BIP+cWF<`7bl@17&>)^0)C6&26=GSN$3%}W{5UnAf{`D z|5eNvw1TF;So|-OGsbrccln10g0OFOi3u^cIEsV^anb`tvvd&!(@GS)LAfaHKm4p5 zD}o)GZU?7MR%}ZTojM{H7gyX_#C3>f^AOkvMKZ3ni|17?l89tkC6e{jYUey&K^Bwq z+1trtLqr@Ui%q5cd!l*cT(=qyDST=&2nYtDb5URV zWx!xK(RvR3yPV8X-9I#-_X98|pAlYj4Vxxc7FlSt5jWNQ)d+HliPU*?*>Z1ZM+zXw(Dfd zFAtb?wbtSo6QGaxY#%@fI3giQj9`L{Cbz4Bnc)ECMv}n5`jn}&RVbAtGRn51Rf>>U zqqTN>)WzXA;Y_H=$~09`vE+V?Rd*T4s(yws*@u&fRSm~K%*7>gjqB5zfAsbf?VA)Q z!_&xfc-&D;QArd^j$MV^fXBc?HH{MP%a9JTd-{#k9yD~G-|QJy`i8cf^D=u`%RnqI zF&G?ohrC=jCvx#Mv{%?2_VRl}R<=6~b+6bhqP6wik?x2EDZ+gOzj;U#@3$Z)SOv%r zg)^zm2fadfXkbH`EVD&9`u?XJ|B zOw>7q+|Nm!IM*BAD|Sb2Xx7wL%K5T4;*IW&F@0U@!Xq!L>y2B7yCX-iFQB9zhfFik z&5LxCL2QVz%`tD%DtLz|Yh9-bJi9vzNKN56+|_aNr9t`{ouyMwhQCKryJOunYLD$r z-$+Rdztq)GAM0>kn@a6!Se0?Buv_qsAfK&Z9rk9dqBVU3=5q8f38+u(PIkxlCcN41 zM0c`Vu#UWhQjod|)(|AT!sdUnW*|{%fJCZ#v zp&&t@AvAo-oPP$=JM|kLiSUi+P-XZhHq_}TMUe&%E9h^e4ACsrS%z0Ya|cokwJ-1j zv?kVRb$$wdMjyJ^2~izq#q&iAwdS?jj%(h3!fK#pm!4s8n1y%kRrJHzS?~tJnW9~5 znKqpw5{FLCW7t+f>Xx=p(qADZbH0E8f(NA1z-!FWhanNLwAxd781D3s+3rhq&Z7;Dv#eU#kI==*=J}S&CRvEh`;TYHTi~vlUu^ z#WovEG3MASknFtbMh7AimQoR9SYn*UcJ-;g)VWFyP^se2al z5V-JP(56M`8i%+={23cU{BXCFlA z3+#i)BnKI=9L^$3Nm}6%#v!1AxM5Vi+H%vbl5>Z=cAy<2f2LV$EnBs-PoWQy;fo0w zK9aBw3S(g{$XM^faF#;;43|1gF~Oi^(S$*d?Q84FsK065KN7)rUzt8Q6UOA`0_R5& zA)H-!bkbN3(}=bLWrmhg5r9P`6VnM4?qnX=FjSwe&eEmKtuv8mhRqNOVQ>ex zWkMfpS;nlnGD1=M#>%zgc)4EAiwp!1uPc>ZKvT{qP}X0Fz>y?F9KoBC$PV_4MG_O? zASs-3*2Spjn}#ABilW)A4R1&2O=t^q{xS*&LnIpvc%J9SPWuWojY(E zd&{4>-Q@ZJ$Vo(t?fe}i;K%pAH@8#~P3_}iZrS;JXr-@@x3Wl=B2W!pB=-Fl7pbHv zTiWdw>4m6g=T`_0Wu252smO}3>-{U@{5ceTo)v!`k-?y6z4LM|2t6OZnS@I8$Ab%U z$a8qyui{N?;l)Oy+y!q_cxcM;f_R-Gk8;)#9A3cIYZ~R>sEFN9de7j^GYG?#*X1CJicdz(5E-4*7=&feXTlOjn`J>M% zpTGIVk37KiVY6#KrhNa*`PGy4mdZ8@oXDcsbIfzG4ubJs9!%zZIGL#2I*saY@h{Myc+A(&yr*KYcf=m_v*$rR%@6tL z^u93h;qbTOr!h%C%dx2v_lF6Tkj%<0?SyiQ$(*{c{NRe8IdwmOW(5jQ10x;!)j*=- zy+B#e{Nr+W31At~D$r+r(mkcwAS(bli2+$>-@pQP36; zp2hz$NCAd^kK~=fwSaL0GW%=sqZ>bj6fTecT0S$5nnxjxPvX~jEq+j<5I%6y$c{s{ zn9}@WBw27%Iw#RN1Rsu}q@6DvbM8aTImO_91`jZx^~y0BJj_7Ou`V;lqfuvtL4(1U z8C+vP_NajMF~ndhZ+p(?nRk=H-)2DBMU+D2r1NV`k`V{SGVX6Ghjf8fFQA0%yc@w! z;Bnu9Kzz}2g*yu;3mN<$D&(eyCLfqI3faQxW5%)kOkT>@QMB)mrmIKrxJy`p*YQy= z$3E&AJ7!fAclc+Oe z7oxFZykIWyV9&sVJ$x@-^zXFpM&8jV&(2vXDI0mV&)LUh)CPR-OX%@!)_vB!c0tBx zSPRw()PB48zn=>7Ple|oK1{mdePUYG+nx4_Nb!nK+i)Qdc9bU7Ybe%q2o{yLRDJ{Z za}Ue+|HS>ATB#0$-a{TP(UPUoUQ1IK1PVN*g8zeiWo3Y#Z64UM1okEG^^|1m^mFQbA5sl*Y+=xky_?%-u zA7xNuu*|?>u#XCn==~LR>+sayKNffKM*a(^(mRpjq#%o>+{`L;U@GXISjT&+`E@Al zK)JO4NUYUGM<6s=aJRhv*;Ir>2MJl`A2~T6 zN43622^BVa0neK?N&F%q^9fY&PYxRNTRA6kKFR(CKGzxhG=tB?c>*)qBz}oqf0p^5 zW6%Swl$By+yI(~y*JQ zJAv~xj`9fx`*tqD_A61xVkK~XbwIISL*6g5*smkv&)#-?Vq0*&jzV<;EBd3z#Pm5l z?g9d6l5zqA&2|*$D{1)CXx%iM2#8_m`4X#I;m3k?y&YepdY*PJ71x#8zk-%EXi7(fplE%mb zgZRI1MdZZyx1u%3!iTu{TJ|;np?D1hFtI}kJVJJO`XK8!Xe$v#MBAyOTjvG?q1gkr zL~zIJwEc}}NHaJMDC)%@`ggyXA@va(G|}5LL_|fbyW4 z*zq|RSN276>y(`(G}zZ&MCyeJKF*m76MR2og9rzi;BRA$&QpjarbqKLyG(#L_Q43x zo2zjKn#}N0t96;YWO;Q+PsTiQe~(oVzIANj&%xK;- zaMZ5s{j!f>*-s-6PI`zLxYcAp*1e}!*Q0MWuolvO=F_OXo~0%;1!JL=eM#>ccL3c> zIe0|$gT4+W8;UptlEL8r0*2W)34)&!=H}^hi%_>uW>X1A4V-+65bE8DY&Tro=$4&OVn08{fGj}_{SMa?x4P|X zQgA;x4YJ|Im=CeSK8h+H82w}U7hL-y%8u(rtbPF}ZFleI(fu|00Eyc1wsHPC5}@xt zi=P992yu~Xehv|Z?0=YD92O>lc0rFFpi?2u0jI;rXER z5q5Dqfcytk6F;bI%752^k)xm9b)`24ji3<0vtj*y43t9yn(l)K7qWmyIO%^y1jb)* zPBBx52M!ipaQT^w$b4-?ZoZM>7j@vAvzNJ%va&EEM+X#i$jv6o@P1z|k%sj{9WU^H zK1|a^+|F0Oy!5d>HiA@dWN{Sa6pS^L~*xIIO9@ z)M1rUhXb3>TGPx)q8{Wq%xI%=jB+v<^DI0f_m-usSLeFie9JOLCbu zHa;M0QcS7lQ{*M{wbmBQvCcP8(fMZ#ej7oc6`X&-#6M*47=zzp@Q)b$V+Q|(!8aKY z-<;oMaF)Rr5iG$A^KK@-%|Mh~VZ=go0_Hu$V3t8oV5l&;Z?DcF?>`bXk0O%SXJF-< zf~{{{n}_`{3wiWa>I#Cs!z^|c^A-cGV)dY)Hn2nF!U7~Qen@3+ zpQ0*Ug}HR>lRA{yA`W92;3WDw?&DJ6JpwF%BT~krqj(xObm1D-&0I;r8H_h?b5`bt z4sU?WDx^P@%QsTHaQ}kCQjU@$FzLbFVca_%o=v$QheJ}14oQ%_hUog$9fqqqoWz`u zcZ+*@+~yrY$q|&?#T{3~Py*#<=|&D`EbJJaZsaThS1mZK?~bD0aCcOkor>-Yn@@CU z^BJmr6QJmseWLGKxv}3fT!qxfcz91>ReifC{Mc8Y9J0?@9F9U;t_|NOaiWf}meLHq z*}w~?@UX^*JJ32=6mBHqfMW&TO;_O%)r`jjXX(&|9Pg!I8|)Z}t3I;k)|wtpNATGi z`W~6?8ZMu@_}E%p|Dbs)_NzW^UU-Yn-f`=p>0;jpVUpj=YQZ^~E>ZZ9UkE&0Z{e5Z zAOMS>IdAE~j|)Cn@dqyA=4sg?p~vYFzN>)~0QDU)))Ut<05WQ|XUIB0Pt(AAVU8~% zqITHV7#k!BQiV)|2rH$(#~j7UIqoaL3?d2RK(CPi)aQEMyMCJ#fkMxf=}@P^4SRJL zpelC(6p6UqU=0qogHbP?SEPtk=pr1;h1dw$5i#d!&iMBjP{Rx0wT;6&t5I8}^P0F= z1!-J35Ra=Ma}6giHoj3HXLzMCrBsrjREFTdr96`=!&G7*a?w+Y-s`xxP&v$2u~EooFe?5)f6>VM z(=k2mYqXD?X3=;|PD#*7PN5KLFVVq9-_zv#RAZq_<^eK|bb zl;C#q`xs>36%%%n+Gh7-HdpdQ|-a04akaivmU;S1@@A? z0mo~KG`J*cK(I+8mG7mj%u8w9QA$Vmj5gsP3kHz&KNb_53-o7H3TJpLqPHG$lPsoD z_=?rFp6sE%goh8Cl49|X-=^c;8CkwQ8Xm1)k6<}}2rvw!ABPAeE@98FoG8|>K;gP8 zN<_NTUqw72or|52MD5+A6QVnnW8C)WK#0$_#+xQ7J4%De%e zDkXAK)Pw=Uz+R@Pn&}!hv>6JVI!qh;^5HL^6|oZ{=nRBF3Zd)#BuoZEDMZ2*{YH8> zZyC6iz#9fQ610Y1!o8u(M;%W4i+qSx<1H1KA>v_27Q(#C8?x5J2KnC>*VPxPdN9+U_-z*4u zQD68hs31)lu;3T|WKQu&6@5*FI}N*OnPi+IQ-f~v(_0}_At4gt^^Q7G!x1&_A@Xu( zVz2|H*o&xUyB;nL(@$#Pk^#*1tz~@hE%GHw`uQ&buLlx!5k-FoZ-t13_OD7R4)Dh~mqTNPfrI^cNu@d*78~9pwtF|2&_zDRwB&r4Fhfx~@w=={BoRsnQBfC;SleqE-oOKbeZAbA#5mqB(JDg@^MD+*%d18O1DA%mu zndahC%@y6*z$daPe!83cMVsCyw-mZON`l9u=)x$?bu;N1FpH_Qk)G71^+}vM^??>}*nD=iPh)Yhu*zYqC z&ztWsMx%j48B4L{KVs~EFu0e&|74H_bxHuEhgoFbFHv;92N7|>(F;?Pxyj72?6Lf@ z>}Y;GQy4B71;ZSjIysY@JUo+eSf_ME`4jyy;z1T)*o2OzT!my3RECtV5$V1RF35lY zKsFFG`U}kcY1W;_ciK00%DD5x??;unaPq}X%8i0Y!6*db#HVNFlRwIaOuixqF5(#^ u;~F%{Ab&=EJ@YIknsQA92Hdo8@c}DL6#@Sc8E|PLJ$gFz@zlpB&;B1j4zle4 diff --git a/basic_function/__pycache__/data_classes.cpython-311.pyc b/basic_function/__pycache__/data_classes.cpython-311.pyc deleted file mode 100644 index a8512e072ae1ca290b0d22f1478940cd4d26b6d6..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 38043 zcmeHwdvqIDdLN!7Kmy~@>_|Hs|!vtM#dm zn?8Qu9n4?`56RkO(@m4h;fFhS?rZMjyWjWS?|$=7oz5Zy+*4QXkL-NGVEC_eqc}^- z^WZ~x-ZF3o&NyZmG~(AZW*RpSnvE1@hTk%1VSWqz)CXfq~Drze78-*8}T%s1udLjfLP<-->O3EWtRxnL*^Pr+pPLVyp!X9>T=jSr{l#(s>cn+MJCyk!Wmwi+~X z#z8Y@dety!;mm+m&H`A#Sy8Woq@^!BIUf1E$P3rUk4}#VCc>c|MeKfjIu!O^@Lvw> zV7Gy>06q0~cFK>#S0=r~7bYh|0q^8QV25|u&xd_iuSL9*ymy5650jrIR39KEm7_dV zb9#1)UeiA&moqZS4+rQO#r30lzTpdgepDY6mW6tqlT>h zVu*@T=52XLFgq2EDdso>CY*C1S(utf$-?Ogn(q@yn=(O?Hf3rg3)#RAa>=6fI8NHx z)S!8lEY?hn)QvnPk~U?SCyR$BC%6FTqo~Qk;mL8DzHN@Ac{C8_9jFxVB2Yr0lt39k z(#i_qD=5HCpprlpfog#FO-V~AFgB7bxO4@xG8Cd=?Dc-RcX-FylfY-8vm^cqE_ild zAap4_IdyjGTKK}`gs=a(qi2Wwq2RD@WO`zlru$jWANKo($NZsCAk;N=Eot-lf`lA> z5oaH2+@;)lQ9x(~;JP7E-SAfA%}A_m@w8aoD^>T-TIL*ynpNK_dZ#G1WAU0;vt6p$ zjzDMbz?{kJyx(enr#*IdX}egrORC$2$2&V> z<4dQ-`aM$po>_}ztH~3Zvr=jX$Hn?>QvEhL1GQGi#p;bx^+tNBS^Pb~I~#OJiY6o> zfCuG(ZyCZW;xS-tjA`yDrR<`Hiz=dd4>3MaW9jdV`rJ%;b5jwJKE3`fBR6y2+*HJ* zPp`kr$jy>BHx*&&)9ddta+dphE6SUjia_=0^>;+c0lgLvM2b`Bb%*!mnMe&` z)<7uepYRS(PV!uE!XE~uf+y94X$cKwK9CV>kYll2Y2?f9%VgRFsR*aCpc3+d4k8Lg z6CfjF7T{12viw+slqmrC*gC}rC&I}BWIpcCkD2*e%;;Qr`F$f_k4GVz*Ytetyi%1vb-CYUF+hmb&KxXb)svt=M0E8?t{&00NwRGcY?~gH z8U^6uSEC(u7@2S@!^mUsfNV5e#INi)<0XbGLq&(cvAl=C50poSuNgTb;;Oj#vN2S_ zLR4Ih5C_tkv(jOlv$1X&;GM`X(ng)^z2XmfY4r^H$AXan=e-idvO692Dnxtu+EgIY zn>VU=B7g(|K25|=XsGx$b0)d41|@69n-wg~@TEu*8*t3Oono8>z^YwRF&AC97B5*R z$XEMGv1FaZ{GXK6#4PVRg_2HOOSN~KKk5{^pOvqKWB+GnL(fs8gpd+k6i0r<)7e?g_<6m`3hd-D? z(4(0)ux^E10ica51hf;QI#|THxuRE%U}KAM=itDI4mvp}po?<>mT)D2rCcds8CS`b zBW^iY0fbqVEbQm6g~I+Z|9RAev0RyKppM5Q#`SVwI802Q!o~VoGIqxvb#{8kz^!^G zN4#gwjQPXi;BbH?^7(z;>%AEvJy{`}&aG!T691R&^f?Y?CQ8AO3cDVgK0F1*9Z56bQ?{QUCb3A7!#a z{h09~ZGnu6_R{2A?+tW~c6rZ`A8k!PXb$?>)WsZm2U7Xy`R@HXmbzbN7ZlzZ4nqLo9|KEGt11iI7nleI!zy3?{NxxU3Q8V2#A(5BXZ-d&#@*lZ+u$k#wZ`BOjXb13ixNlhac_CVHUF#n<46-#}m^fo=jl1c(7=CGnfcx0%2e zfMl(LRKa^Qe%}Xne`+ic4kSyN#LzdzlRPmT421Ztl&qIP9YBhKmzP0WAtR-!B|+uD zPf@fIMm;8KSAVPSox0dFOSNL{6H@IHU{+nZ15+$?;bN^=xk;+r#8~Jm&s!JYycnxl z+%8sak*c;JusAnYpOYcmSuA++gjtcR?)Uj zvaJ(r>k{53%A_ITu6ygyn}-%IFBZnGiSBO6-HpoG(|I2g>wBeoNCLq03+&1!2Z_mJ znI$o{BBLZ#AfT<4 zzE4n+(%goLk}yg5QjW9eV}6a86BSdd?^YTn1f-;~>eCnv6R&AgVf4H*V+K=WzNlj4 z%SOJ4g{asTA;!7Nj93P~Mc0ZXreqFzsVTKBm4VI8KDETT{*3Z6^Cs!1lDC@bp?Rw{ zn0rarh6Vce_?l7o4y9^28n#cL0{uf0SmwCG8B0cMdS=W~OSoBy0UV(AAlijQwalgG zR0K7hvCK6@DW`PKwg@my1h|B=6GW*b;bo4keMK^3`DcIm{%7$cLr3>~xmG3gy;xpk zvY!$9LE<@wCqeRIk(&rdTK4n&B)`Mk4=w>uWO)Y3Sp0mzJ24sdjs`CWCL-Qbetsf2 zF}lM$Fv$cdl#Lhcr*~N9MItpsrFy3(K}%9{l#`N`^E4qw7k_>6?cYE3V8`glGoQzQ z3S* zqxVbN_g@_ju)QgH4Q@mpX_L2S+$t2eE^fZP`_Acj#{r?^z5owCSYq{GjwaCvFNzn*w6{h}1r^+&&R+pAg%pr1mM%`;z2+X~kf2UIAlF z@D0bjb$~W%hT3#=B`o;2nQu#i?v2Ufo;3%Eju}~W?ZW4-{?mNna^39@Si{#iM zu zH%raW2`65VPWXl4QEB)kv6+{e`T2b}kIWx=jHsX1tr42h>U*TRJ%W4B=LLr9wNh2@ z9m|I*FNXP=an3jQJmS6K}c z5x{fxQ43!drm3BBoK0Vfyzi%rVb2w#Ek)GwJCPAj_%*VZ(U|ECLLzMmwPb7xGP)Jp za>Ey7A{De~ib_4fh|vR)GC!zYBD_LOS;`yaB7H>cW(BYJhWTq=2u-J(x zaz3`N$k8(jMiy|ky`Yfw< z5a|khOkKx6L*Ot0TCb8N{xLq_=dSrk)yoH!6Jsz$(gj95rl`kO+((?sy2Unl-b7I7 z;NNjd(PpTuB_hXaKZ8Y@AgCGLmR8%{%^&R&dXD0{-`sh>rtX&Q65u;tVGhFja%{$+iSRNu3- zMXKL^XTMnAclV@FAI24D@Q;RmctJRIMmoib0|99uAdFm+2F8TRusD#4A;$(x)C^*m z%`Vf zq>|=Ed%UDuDCtf#t(`CVB2iq2TH8+<86DdqI$9-1tDsz;5^;M{bUZIPo)?rWdj+RK z$X4**h9^_oMN8uWuV5U?PCxszun?%jClh8m`y`uT6 zP?>}9>&d1ZxFk(B<#_l3a(F;xT{nE1uai5dlRMDK%`G4>56!#52-^U0Avd?Z&ScH7J}eyD(e=y7yNIWxp@X#=+2!C%o`I`bvMUj zeNt8HVp+VZU8rhLtnGMrX6YC-K<+e(Yo8ziUehpZt|lfAQ(V_$_C; z_H4ZW*$8t#*)yyTQoO6mhQtv{13!kqZ1*05XjaM&$b)3=Dmv1!v4%7ot4BFA=26J| zqJT0ZGjHS@qR>1}Pn1z3h_K9%f~YyBA~mH3^sA4lFAM5h#p$@w4B01VO0iN)(E?1} zcV-GvGh5UQKR06!LoTVfi|VdKv@mMd5r{_4eB1IK2Kfi-Kxd?M(TqJ>#P5k(V=CdS zR2Lza;*rXjfOARa>IYK_H_B2V7(~BfSE!F{^5DHu6PPCbl z_As>UL46Jfl8%&ogJ|QPHU~e9cuC9fB~8S>T_V3QTaaRx<5^wmD@iHA1x@EG1`I>>mjSct&7{l1_DBmK0M^$H4fM zk9pZvZgogYXvBFtL3oq+}G{gy7NrSERb zy-MNe3F+v0;UXzt_nQuy@YDK~=`#j@%HZ2Y3+9EBiTYI__G-5$JZ;OKjd9P$rApD$ zD|vbaPjAB0vg}zO_pDz$DS9?Zo=t*hQ^M1}?CFksx|do+&o;@kP4H~{BC&Q0q=c0{ z3@j8UY8t-fc*hYtBi3|FHQfTc67Gg&cXQm`96KqxJ0y39z^>1$4HZ@T_gtwl)Z(p7 z&Nc?-EeUVqyyd1dQCh!T+7vH^>Z4e?PAXj|unX-zUv}fbY~M^GdtH+ZZc* zJH)n_KdJOWAXzkvzii@0B~*vdxq3o1EndJ{4XEaNQv6^vl;h?#_u1&#Jeq(cZVnI~$duNgNfz zUI<|Z90^D%OPE0j%VfzK-IXL=qV{}{1RDCP%e5lm!6L_&VEmG-L*)hb8Rm`WZp{<` z{cA}B6eBzR;vZ9oh4P$!vuWboO9?baQEr+KnJLsRBx$vQ%q`y;Tb_6s+q_u5iC#>( zXvV^Y#po|ni*#|gOo;{ZMO~`7f@sF+IV@fdr2El~GwS3D^T2WyXl_ZO3!*<_cZ3WRyG$VP4JYvsP=`gD*L{;#Mdrs?r?ifES14y0S|m|3LTym9Uyhg;HlaYf;^>APr# z+Jc@B!kBt-#-$t8>L$4^R2@Z86V%Ccvj+r^E8T1Q{G(RwGisBW5)gyA->f7?C#bTV ze&*}THX7#qQKRmCH8@2dm$vk0jKDjkT+vKXm{_=!Lk=qWhIyEN*9<%nBc-z}$2qtv z{afqr(8e3sd@q|RkCx_s_b^duDHkn^mPdja?j1_n4x~ij{64DD7y(r%l{9mcVJ3r47KUD$4g?|)F-L+^ zRF%*LKlBTU1Cdn=yUBMPpsj+@=Qz;_)o~CJNgH{n#H5>id|(6wbJ(Y7dn6qcN(W(r zSA9oxcl{{#+PDwRM~aoqBj>2bwxxXiNpq09nzUTOp`T=#Ca#aQRnnmZp&p}2t~xA$&^ShruQ+duovr%bP+K*X+m_1f8|WHpQQZ9{A9WwBwC)UavA zQ0{yNl4XK#*yqjjC+;`5zI#a6@btaSV)Id{`KVCZm}mw~URp-+2+ki)RIiDJ;??a! zb^GGj9WLJWq|o(bqGSDnEmn4`IMy%LtP^V1-LLg5oE2-^rP}sIE?(Os)b=DgI`wf9 zwJnPk@!Iu5?fNChT}yo9K4Ig&tehy8J6_u<)OP+Ha=2f+TB>bZTrbx4O0~UrHVd_T zaNY0N_^)cE?*6-n#qMKL_c5{KxYTi6U;9s2trgboxN}xq_q4d`khJQMP;+Rdk+oE_ z!CAZPSRHq)j@d*k!MqX6?jwn(sF@$4-4`==KGn z_ZhM4u+(*Ux$8u{>x6Lfl<>l7v1?H38WbDPNR1?(vsW-6IvQpV&cBdwR4zL_afj!d zCl=QxIyNQRJ6EixRU1At0Q{5ye9n#VIX4nNwhS3Wsuq)tI&+0QCa%Lh@E2N6bX{j5vIT8g$x7b*L6 zk*<$2<}L3!L`SFO=oHumZJ*`pwejk;Vs)ET4F=t4-@^bfouh8q(HM6$#?~)4ZH_l> zUOFW4tt}VHK5|sTpkLcc+90~##Hw~mQM9cR$!3~PI6#+!~P+hRbw5)_;DlutB9cC zdcU;o9P^J4asJ(rHl;HZ#yJNwE=OIB_h1Xvb;HM7pOUto70#WP&W%W0N3Y8|GVdFc z=1YNVVE+p!vjCTy<1Wc4DN>!M3Jp}C7d6LLQf2+ZzFUWc^`mrEUJxrUNR=1zi%rxH zW*3>#D4?9A$kwcQ8E#wBX}uNq$XcOHc>^ML}Hw1 zc30heN^UABA{N0XHuOt??tPMbAC@;(%f)@S521Eqd9PI7J6o72samiLB_O>Tvm5ts z^~h{9%8DS5(c46g=Ng&oXmY1Ursqk4U5nv`6SvOEe!%R;`#F!ye06O@UHL!2fM7M^ zKY~x$z`RAFS)+~rH3a7#U{uChdiT@-tEjvQAvHC?QuzB3{D__>CBTY)siJ?@mUn>J zgc7e{^RgNp&}|Enlzeg_U!}I1(zUhH)kGUdJC~2EF3A;+W=635BH9l-8YwpSJ6m)3t^+ zT-yTLaP3<^w4QSVwAC0jVnh2Ix*9ri*HC$T4hKEC60Vdh<8a)abH4{d>jNUTn0rKP z{ANaLROT-0=g}I-lb9OjGFmkqu5Z9My`{y}YA)YuC1s8E=+{ZRyKW&et&%zpSPh zay%WB4<;28xyDTG_+ChAZcb}&2ES~4dB?UH(0f^>sUeYrdCFlibR|DLBw0kuIH~iZ zMrr?PjeijGQL;9O2r18L#WR?;v=i|?veIhpb9#2m?Uwf^Ql6r6y&AJYYM6W`j2RD7 zwm!Y7Q(zSMWpT<>1B+4gJUe5~EDM$Jjp#SNi2&^g$ebRH==34BOrt_Q^zyp*rfFSyqWS6oQ>1ttT347uFbNc1Uoo7HwY1MkEna8*E>8b;n)Z zOU>W!x?3jo^xy51dJf#P{Lm$wdO>>Xv@m!^dg`ojen>h$CX7!==O=}!m!$K&=n6@$ zkZ21_wyUX(vs^wDFCR*j)-LQ? z?7Y)(&mlaIJI_LJv3BX{J2UsjKeL)D&LLBRB+WZden6@)V1c2wfsUsYIQtPqaNdH` zVP#c9)s-JrH@tZn?Ccm=up@UGrS&`Sa37A}>yvg5Fcp(?!g-%`^hIH4L>gjUJj5qf zHQgFt?2}e?krAppere-gmiN#)^454cj8(mhr<*0saXxywe!08_N;xoug_3)u z@*dH#Npfrwlxt-*I%kD?h_Md?(A(f>A{q~1M`bO|%XI|W2(%Mm){VB4kG5U-lLT4` z_y`OWm?m(Yzy}1rN8kqp{*VCiL$Z0K4)W2l3&xi7o#Z317yltICiQjoZ^h&$<46S7 zJFVMR3@TW)#(EA-u7QV)CMz9L)4&3aUkR%KzD<+amAnq zyG>S(6{SL3w{^{mK?N6#HY@a&RNyvQgDimxi_F$03>gLB7Nw}5sMd<^$phrZqb8$v zSDh*4dk}#~YwQVznR43bMR+cAPj-YC!{iegdm@uhPG<7S1v5}Cn06|`ZwY=&@mq@D zGOi49%3=1&Fj&FNK2;`dM^b8PI!z)AszvnEgRh~Fq5Wcv;U?Q<>!N7)X!EA}yJ-Yh zvu6*rST+K1rinHbQ!e+JCPuN@2jn%|iKrPCV3HQHZpKW&L>kq@1(^rx6=lRWn`~5u z?aHz>o~`U8h2Z}afz1RMxj^C`oj{ZDN~BB~FRD?PL$tON8o|Ei&%@M(UMIa7D)R)EKn0CFAcI@~r{T z?8>m^Gvdyb<5^1m%S!6pP5J-z9wq-Bdin2CgOOASY@T9$f+&qheXd4J?Ki$iZTBxJ zdDbwbMTt#O9-e%|@E3KPh`&qklv@qR+Pv>dsD!_#WHLudodT{@ z0AvjNWZBd-Q=CrOfAGUZYMt5;Sw#vMj)a_8YWPyJ*I##*}WW z=@2Z~lAd{5Gf%$`=<>+OKTCt~G*mM*-5Jx?8`YLwSd}_L0T~~4%qhk6_)*gh) zc%m63qvj|bJ3<}u>nF|)GU^ORUuZm~MPv|?sH21_^I@pGKrz?=Vrq0_bW5GlW|5Dr z(6*F3iU#2e5wcU!X0m>0Iyja-Qi3Bf*GS=w4W~%6I+7j|i#j4icok!+quxC|9DQk~Zq)Bfk|3r{4D5@`z=vQW>41u=g^6v!F0hb7b||>ZtcBph$s9Xtrm_muKf6U!S|wc9Cl`y&}Nu%^uxTcHXWjAqv0@!S)swhib(neeTWO+$LD=R@SQ!i{qY%+#i z`dFrMN-AcHvm1iM@vI@r*>wJDrwXaDVu+vZ-a`hv6a^J2sI3g zE-keqb5U3@i>Vt%N+|tm&je;_Md4^QI`jjzQ%Tu5YGfA7XsYU-;ot>**v6{Y-=PYy z1(H}FQW3)m88^-xQ^_fpUQVAG#o^m9ASea*Dm|3G&&lQ68ax?$K>m7-^cZxQvJ?5$MBsAHIxm6wBj*zc`W{` z!4OXK{smQsNpbd&Z!Z8;ey7Lzcj%UokYfFT4PO=OBnxEgVo8_2o|zUPWF0YmoD70a z|L~XTrDP>BmVA*l0?D%ZK{o49lRWbWw>vy~US=SFL?mYiH^&i&XJtnpvl-du9-h=lnqW7?H z{5hfagjjn*syzXnw7+>JF`AETkdfpZODv$`(t|7a0U+&0te%=lQzA6@lI0odW1%o7 zoAY4n&c8+!Evs?|_mOsWG-jxDX@{b$JoY;EnFooWJhu6OwE6V%W?y`>Pk7NUy%-cX zUz9dqyq;o>eh0Dm-zBgHKwBuu6C*YG3Z11Qyoi!k7wfnv;HtbNR$h`SFXdMkQ%RV^ zm~na%=5OZXC!CqHa8|D1F*ntXT4IRH_el?ej$*;Kcsmnt0SQ*o#6@R2rLb4=>t z^_&~f)ho|tzYZ^J)CKgT^kH`y+nIT`>@fohF)EmQjNF{L*E8fsE96(mjS+ILfryQy zX|!v`{}EEjWZWNCc_vf$?6d&egR4*mzh*m+Pat8NT^6x5 zQ6^06rsphVVv*J;dE@OEd^&MXGXf$`CHpSgVw|tt6^HvF1p$! zSG#EIkZc`-t%GfoDWC4RcBk`RgYZH?bd5-^5z#g(*+vE1Xv*ZA$^INk!U18Hfo!#G zSuWohFW*TgI`3`%;V$T$OV6AU&iSQtZ1d&Zcw+CqkDT}X(%z%kh^jb_{T6~C^qg!{ z3Xljq1%65&8CkRJYLC0xzq1LuCz1E}+}TN+$~k^*5PBe1Eterqe;se^nW9x*NjAyr!;W!Kz5m=A}q31el7e)&~bf;{r=>(rDQAA`#it^5XHl6pc8)0L z(Z4~1x(-=^p!W31*CGeC)wAktHfngxn%@0usVOaAN=;Yq7udDb%)W@D+_IX&sPXjh zSda_&>1&_`<2WcE9M&~74S9xs_QCUYOhm@yc=9tjxehDbFcEJBrl&z+dQ!(0s+5By z!;`QyjAr-oel8eP_+rKk5lhUE0Hnl!*X7M)#u7b7=lvEnNhh-T0=?JrO6nS$`nYwQ z)cPz_7&t4nLT6x9Y#mi2WN=2hrXiwWOB^yC9IPf4$&F;f;bcDHkd;8E#l^ z6hOWP?3QM2arPBBM>fYY=Lup4YK;JKmdyCbyYck`yhXu?iL<_H)5gxOhaP7}%{K~m zqG{%uqGrtAf&m?3R~zkni1KyITLx6tkbVK)M3fU}q)^HvYPHX&U`RAae_cT;1fpM@M~K%D5pOJ+3STnnS=Oarh+lRml%7I_erOIq27 zHuy#Io4FBu#DfLkS3WtKEW8|q)$srqA}kN#&UH9>AQxctM4nKW={#77Wk*6f_9fjk zc0yV7#-Xdo79%@R0c%$`@$o2QSv|fyF@FGj@qFQe@r|N{r*Xbu{-rnU_iI;uGZfqU z_RF_k7Hd1CS`q;>aa>vJd>_7^ zdCR%r`~pYuQpwmPIgqAdb+bf8AnQ@XubWZ`q;$r31C5++_#0+f2K-%D)CfdTmO>P= zh{Pg`MW}Uoe0e!#HH>nmhScDLCx6yvOdB2fU209Dq|ur%q%B{u?xh2Fy1)PQ-Lmfu zEDbCiTLN4fVDF95Kw=O!B(NQBL-M3?BIz8S=6OPBtmDWMQD}NB%u{ElUxBw%VMKSm zh#|2Tj=f9LrSI)X*WK-kLz#R2z#GLkiy@P(_RNpQ zjBi}R@wuw%`R5k8-#9gUAmM44J@P4;RD0uvn=e3Z#krON7*wlXweZr-2o0vS33v5; z|AKL0m*{Sj+--upE#YpA9TVNH^GCk;q_TSUfNYg*T&!F#RpMx>(YcRN^828RolqKG zZu;+77ya+sZrjMHxu;+9^vk1o-u_8N?d%b`3+%F7^?y=7_z0-NDxQcReA7Az<5AR2 zN#hB0Qa#0?B}G;uO`QsdMJBBD%1xQFG?BR|*{Boy@NMcj?^1e_^`DnI5Kitt*KC3_tiN+csCrtgdRnS_8eupD z{FJ)Xktcjr3jY#3V#_O1%PXH53=MszpE4-eY7g4ldoYCbJio&NajdFkx1z)gRA1w#So`v@6!x@sJtZ0HAp&zg?X&lTe_dUnou zoSvyaCn&J`gbAjpsv2W`Z%)n@ub67=u=}Zk9z(fzepRekKG-~4kSM8~b+VD5TWYB5 z84n@rVB`~KM~$~LR#yGG%bB!7=ggfed|(E}=obrGJ?U7mb}JB@h8Bdo$xt3>d8!w4 zT6@Jl!f4FRRn9fB<;G0uX5xIr!OhtY{}fO%`8vkT_5 ziWas=rHxW)yHMJWogF;2DoL|@*@>F4UjBX5>oxOhASvWy+A%!HcGlyv&>^a+ggP6yP*q0@S-$yNjOT^l3V@-WFtIc_Tdx0W_ej- z5eABTlKMHZ#Rf#Yo#)p|Mp{M8PZlT z%v^Zm4)3>X)TWa?q^J#_xN%I{IQDVqvT)^PtQ%L2eY9YNjZe}lLO-9ARlx(a!W=Xm zHvJF!dCv3``3Hd~OeZP9LDTb;fPQe-Z#qSH^usbcMY#k_7busYX`G$~P1opIz!af7 z`azVJO}|7@UNP-6qo7wz&zMmV{UGqL`G^?_UNH@rk$`@1_p<3(x}zVK*|QY0=_uU* z;)(2c9wR>>o~kfW;(axyq9dN|3$rSx0yg%c&3+`7m|b+jh4*8L@Th`T2tu51qg z{CWYnd?ROh)lP;%Vrn9-=RMqfpav)^MBu7n2sab=&sZS`Gr>BTRZmfnjXxl-LjCqL zRf`EAh_#}Mh1FsptW;lu?Q7#x(Ge^X|-T#d^W9UhXey2*puBrqJg&YRU zhkj*>ral}pDUwb(T51QMZ=qMDH&EFJGE&Q;`hpWm+iiYy+hwQvQ_9+ltg$PKv*z31 z+WqG4Sj%Fk*wialY?CUs-8m4icv5i4r|o|n`&FYsrzl5M1#8O$GWxplHM@RmBR#~T z8ErNL>3DAWI%kBD6y%nsL$ng+OxaROTAasJ&SU0rz3-~TsNhVIw|_@OZ(1I9FDv?#qW+BdGr)xEt5i8;EC{Zq%$`*nMPo0 zcK;0phI;WzU;Me@&oKdossOIvYq(o^uRy47!?oD77+88lIC%kjy8|`uM*0g;P6?#!s&@1eG-wU*_t%_C$8~tlh1-r^Kxy)$pVH-LiFtx zh}M(Eu+isZeGv#Hi;x@aVNXwRAllZR;I3y4j2du3T&vPK(O+m4YPKYJhCr{b&tM26ywAOs>WER2YiEmCC* z1XuD0h8N2g>SU?scZQc*-oJSJqFB3a_RzzhBD03ox1L`1bjCfMix=;--+fs)aZ)(& zyy!V4c~0RQX$1zz#SY)Es$1A4Y}hYU9gxK?5QdT94Pf4MzXE3#X$Yv-$HyT(!8z1n z@i-?P=MXMd1th?VfK(9x*;3Y;aC>8YqPu1O@E54WnwEFhF19a4#P&Y1=}D;xn0qyf zZfIM4LiB7Eb{rAwkIWuURIEmLZbF73^B#0bN(}omDyt<`;n-#Fio9i`sIL7RWUkXK zVEZ)dAKk5K3GC<8k&Va)lg52qzV2@NxN@yjx&6)>>>aD2>?Og{bs`0R{t#&t>0T>> z_$LW`1EEY4Hpy%R$+H0FO~$-hS_ppt{C!xF2ky(9C~aC)(-d9;>|_(amjrg)`#5d1^4>J z(|7Em`$@_DBtlB)`_ZgT2KXekg*)>==PGIoCLz5}zV8vBR^$JSz>f+11%U?yNaIRA zx5)&nO!8DmExL+8BY}eiIw;<10u2NhJ4)awJPZ*MxUT-KSnO8l7Kqvy6?Wbzq|a~{qJ9knVIPv-2RL24{ht?xc@>giZet5i*pFP!r3?* ze~=S+Jaq?koq9oUr8qsp2Eo9>284}*k%f&1Q#wOpg-Suk6er6VpwjkB;gGvYGUxU6Bm!SWFm>Tckr z>~XIr;23m|j5x-eg5ND5E_d*pd(<=N9C2Lm`Yw*RUBhm~ruoL)f-~UpdGTT#bb0&% z1X6r~b8f+pkRkAlYjjZVjO_E#m&HZ|Ug1u$aS?PjUeMcgg2ARQW)hAeo1m_@y)740nP$F6B!(;SS&XKnCA)helrZ^|s zcMP6$3d6Bix7RHUPpGLLml|}r{eyyMEP&QfcGNfG9vmNG1qYQeprN3(l7aul|BcvJ zI6L1HpGrC4Y`jgU%$1!RQFAPl8zUK)JGvIQl%hjEdklXl# zx>FN!r#vCIDIs_26LO~|NlYV3zd5?tqYqE(uxWWs{dfUM@Vz$QlvyDWZV9Lh_2pD~5MNCmHPL)hczaJ?9&BuN?<0Sv%ws){cz}ZbyghNL*kdaL(staRW}jXV5V; z?j58pxYp$iI30r{PQTynZyuYFOb&;K;FCkjY)9Xll~*g8^8b5yFLIHhimQ`XCTAPw zPK1g!PZ_>!jufxBo_;NTwry@ARJ;{2nHn)mOO;x$H(qO;J#}+ysB}9@Wkm|guAaPd zayA1s7f{XVk+RC`ZP(gnM{k}8mF=7|%$tg}qL#L=mg@fWVNva@qN>K;|%IBsLKl=8I3*P~me9(6x{&U*CA z)T8db&smRQnR?WH_Brb@E>jO-lq9E&i}FkupOj_lu_RT3SEao2OsFSynR;l)B{`+M z@=U15v`js;Uz40tUU??eleSDfwAYiIQeJr`)RVqUJp?I|oKjwSS~)<$jPA*dC=_h7 zK6iPtm;j*L?{RvqgFc_&@_3yApmU@Y5kMpuk(?mIKaY#$vczB{d(V*(I!Y1fV^LWY z0_`Jajl!gZU)y3HZWmU+|G$EHkq!OaE2AVPJMW^3Nc*^4(@l3j1 z){7n>xbc8hAyESpWA4e#%VxED-6-G|Vq~i~j&5zxC!^?Ev>d;9jG)&CFG%TZ^|0_( z%5eZ5@MBiqSD%0B`I(7OR(;T94>sCE@U!Ya$SR&SyqWn%W+U!@;a3HYE)iam$$z3XB?H!~>B8jBb#7%ho<4C43R z_c}-2z~$)<$EeQ*3K?Ou!|}|xb0nIQbs)_yb@@20_yV1kb=YpoYV7gJa!fTS_3D!8`)`e(LD&a1V7kM$@BJ8OD}JrRm1D zUltpTdwIDGWgo=9>XZeuIEkpw$Gw1uG1*e#rx9-^q}$TqXV~&>;L6VPg4vdV_e>l3 zF~MTXgr8-zz|Xd2!OyW}!_T!9*m79zJX@}iKb+T4D5Z7?6ZqPZk9MIa&K2lb0#B5{ z_z)xata~s(oR7lQIn!B)MK?8BN5HvReM8p%{t*nGXVA?GIGm1_HP)n<^-E$loL-Y| z)pUuq4J@f`z>!?rfaCMlHn^m=!6mf~qPElWmx7aZf)k7XQnzNnwg_u zH|mHnpLzS9Pij7}_Jzmy`APK+m%8OJPzfO2@pM+e-3l;H>ba*s0j zFQq)|5r8NqgY86D*QD{%nu#^wV*O)I0HR^RH$L_e(M1gwp#%?M6M37-+d>{8wpdF- z8@)b7-Zpqri2}ZyUYXf-fY%)xaR=N|Hj`61#srdL1w3xQu$_wSAg>HwlzWv?2O%3N zI|_AGrt>&u`%`#8swFG0mtHHKJ$SPuRPq#9LyJaClzmK{D+v{BVC-XI`PK7R&d(Ol zZ4DJ}L`;ST^SYD`U0Hj5^4esu@u^!ULM!(X_gJpQbuO*kWDc8FiKbQaruvAriV9Xl z@=LFFUg?~9b}n^xB9z~PPNc;uI~pnjpL^d_%ur$i7$KPpBf$O$TKb z$po77KF9SFZ_8Z2x($0YxWy{2U*-LJ)P<~+^yzCP1l6)vg`c1?YcIsV&^W5$rDE6Z=g-I%v!@Yc3om* zavWi6C9*m0J%h7yhp%dnZ*(THJGFa-^getQ2)4i8{KZQ^4iXHlz*yS4lD9)j`Qll zpy+#wm6EY1F(^*({w$p+?j*<3X|8Ldd9><48RzkZz;+QxOrOZaVyoD)HGf{$Lb)Ch&fxT z*Rbr2R@02y4gEDr?jG!tWQWnNPmbd_f*eiw8n1z7l=|GatEXydKAWitS(`v8*wVKfBuu*4TUjEKuU9iuAs9hx1{rqfesM_ z=Y&AmNy<^l6-U5FCt6M^b%>|}w>O$bag;{iC#CJVH0WlhKQeiNxrvj7ODl-NaWGLl zN-*L4z=V8Q#)`}7EFS4BUqX7eN~t@ZWz;!{ANCZ8Dwix#`G8}T3A`EOETy_+lD@&n zghZ0j1JV5@$@r{uWZW$os7I243ixR?FvdhCT6R&Dl?u_aEG0CAZ2s*SVYFTH#T;snbZmQc~!aM5P5XmhA&>#bU`XxHtHV$uG) zdw$9B#fSKX((>!|uhv75wd0K)q0;r?(l)WQEmXSgmR&6Exa}585B^e586k(rxyVPl zo`oC&frqSgQY2oH^O2sb6!?!*I7{wV&%AU7f{x7kx!$+V+&FVA?hcxFqaAbhSKD4{n<)upR4tfug1N2t z%^N=`F2A1sdiu2CgS_&nj4Q9@rssCxoy&JFhc@;8BEL9Nuwu66&89b+LIrCVYFfiJ zwtF?UU~gZz*D3ZogM-82!Dqz5XF@f?^zLtT{bsp8EUgaKU_5sQ^LIYX<4USODdvjm z!iAgf6>h#|_=)9*mQW!m$1i7~``I~DhDd(#)y}VXPVJ7Ev*+_yh4VLx`5SM#elYU= z5i!3*Gy;NlnIEk_^% zp=|*8;`?crGpBWA0Z_~LsKk(xlaU%4B%+FJ_z~&@b6ueA2s{pRH~8*`6nS$~?}RS$ zm|Ti+7DN`_(6OB+*b$Cr1Ejq1F&s#_)s0UZTq;v`m`384P{ZO{c;FC}5nd=?6D~h^ zul(TM-9PVow=3B12={x$eov_U{M7ERbYb-tto#Qvn!*WuT*7#EBSK&TKsKIiWKSYa zb0Ad+2`T@Qfs9Z53e|d=yh?Zpj0hdjvUT<{VCZt|EJ~h!_!#n(NWyuL1L~T268RlR z9;ieO`Vw`FJWyWZo(hSfJRipzD0%*g5n0NJfSl4G*)dWJDuE0_IN31?K4483$R9=| zYY^vZp1Qx{YA8%@l@W`2Ek@w+q^dIi$`fDb`P>YtzDMKxm711eI*II8;jA|M4MOzu z)Qcfd8Ou|Qu`~b#Gw0?RV;T2Okh$xRWe9CEdJX*~>0!>6(oez|=Co&A`631CV~lFB z0*uF8#_~;q&=$YVISNN<~JnIcr`a;EHztZ!y z0)OSu^qvpPDz0yNbxV{8En9zcqgb}})*i8}{kHvnS>Qhm{QTU{&IOP6hmX6&>dgF0%CUn89-DQ^z)J8lOVhi5SPRR%L0o=YVDyBq|3+IT-nO0tc9Gy>1STA zo^f42f9?EhYaiq^PPNm?vUxAb^V-$4l zeeMy@@b7Ql>hq(BV zq+1%bM)RpDU29~}(XWc;Z91i8Muh6S*%qFxRvF{==vKxKWX!YR_CrkVs;sQMks=J> zGlW6%*p`c(v{oZdAy?f9CsJ7&8$Cv4X(WCZH7pXGdXf83YyN6#zuNgy=kzl#bp#CS6W^E9*cv3rj8}KE=D~>1C+4PuJ zpV%YTjx*Y$u2-T9khThQ%V2Zu@8^0Tj81mSma+3AYOP%nxgA*p0*!MY(Yy&e;Webn ztCcNYh9FoHW4Kga3(5=(GY~!y*b1ghKmf8DOx~zZ#Z_Jsxj@boTf!QNZ=K*cho>{pRI2F5f%~wUt|yV%<~IyRUX# z>6)z+^J_lVb46uw+3y!NevnWe5x+_KaAD)aRIae{lXMO;h{y4OnJN!RL#jBM6B?bx zw-L+s5B5>AQ>ZK}=v%MuiF`Gu#aAn!%Bx=s(V`tb>Cx=tiT<%e!I*a0S{^i3ESv(m zkT|gH)j$_2eZrIZmN{;PEFy?k+k1GC+b=@9E!P!%un8Wm>99^y`C27L<Vm+w5CClB&vN0dNj6s(lu z92;|cU9y^T8?rWJ%D~2tFatJnl^M>VfLmY`v0utm&#Fl+27PpojyYJ69dEZtmWxg( zAOjY;AgzVOnMJRc^kW{clp)8ls$}s|CYqAT*{CnWwag(%aiR1iUks266#6zH`o}2i z?~>_UI=Z!rChNSZDq_x@;uo^>uV%iSDbrV>?1m|0q_B7@BVx&$7QTGpgNmxzmTP@8 zol{+rk`*)0Tq~YQpE?jR=X~|hONV9(C}Q4RwUApd?fF{sYwe%VNpOmeb*$^pj z2$!!D%h%m3m@nTPDX$HeuMx}F%-QG5k=7V4ZxPE|Zq`01-|{entEyWxbM>2MG9twl z*Ui_=v;7Z>TR;HUZ~7p=BAj0XxI1fqkiYt20k@*=6Frw#DEIMxZtY?*SF&o_5V2N% zJu{MB7S65`v!Md{AiEy4N69LPGBYjHIWO&NlPQo4j>hzzb0V*zp!zU+FiZ;~AjplQ`H2biN`LM4vl^0;HXU$+*t7wtGs6aaf+HO&W3O$HL8Qao= z6P5_vGN9P+h@LJ;dCF>wi+m5o772HN^vkLYXKlWhwfV<-Cg(qU@7USUHYfbMuAiLx z@hS1xS#g^)v}GWaH5fDves*u0bCHJRTLtX~v$kmN4AP&3uo#HW1=wdu2Tx+*H6|br z0B8U60AeWs&FWO3{istAU{9(N6MM#!l3HLDr`=JM>P>x;crSL!>oq+wuX?r%R7MOl zb9NPoW1;-_s9@*2{X{A;rzL2?Ew2>2dY{W{wwqD*XBdr4d!{Xw&e&a8k_emFc{>2I zEzNGSrPDy!(}{@ef^en3KyHC@Dqn2NP)nGTN~~u!X2#WSQS(ymX>qkfPJq_4)O4yV zJ{>(odUjl0Ici?2D>tq#FUy;!rc+(<>0iK?sKPd_^Kb}ZW8as5h;w29VgjXpqCA=F z(LP;P%PM1JPfHldKSFJlN*%TWdnOU^N?!R4x$7^YsMF9>t>lb7?QtR`kRn<+;z$u7 z+83gzzE-Kjmfla3%A8%*J5WeNu zbF#DZao?-`J?cL1X;DhWo;I_sQ29RPiIK#Zj{_cgbG*6sY^^Wt-=NfC&#~v)(``jt z)po4deK>muboQoqVa5pX+lu40S+GyJ5fEpxvuEQiUi3+dZb?oVLy{;olypnHRpKOm zO@mJGAt&i>ZaUrAkBegKnzo)=d#d^L>IR9wG^tyCDMp-4>Y6Tn z^nc(aegY}@Z<>(!NdzZ9B8ZIl$==EAi-ONPZ1qo!T76@rKrFWy)lMfU#sq+qsr^i) z?{q_jlpV!54${XXktb8SlZ+$2i_i*872IPZ(E5{%{;?5HfXN*tr>R-`l#fKfbX92)UCLA5F3RKH}}839$_$d3V(DUt=aQo<;jIJBPK-J*e54Y|0&%ORV!7ruJD-#vV>i9H0eqwzRh*SARCQ(yIofq7Wu@)&yAs3<% ze?mp4unusjS#Fi7yHNQC(Nr>(8d+Hxgf?PP$&@+jm0g{9WnyYyq_pyS)3v78)`m-4 z#L||V)i;GuX~jw)He5-Y)-Tko zdb9JJox!yS?zV<%js&wS->ca;eJD~?J?jq@HO`IPa)p|oj;vlYV*=%uG20O;u3soA zpE(sOX`FL~O4dhKH^n7IN^0lwLM3Z%nr|CI>vkuWqs07BNz-4u$U@1=a7n|xl7_i8 zp_0wFTJM$YTv)yC_e;Vp9rs!~ZXXJ@91X4RiR;#fE9!#vZMRN|^#?*LI)lZXig6RTwP@KhDc-6L!)lR+D|xL z<~pL`vx$cPV#mWYuId#3v6(AaAuFYp6Ae#V6L?LJC=AhkQ`J(6o~Uv?{ud+gUg3Zu z2uOI(?BFnOs~i}O3{*8(K8~xmhq_!C?qKm zNx5E#WQFRqPLQ=DT@pM6d_z}ks97KzhQo0ZUWQC6iUKHbz)K7 zgQ5mN33F-KTq&9>XV-+QTE(i?o5w>{+aH*BAa8Nw2U!(CQ^nHRCJKbjHjPC5aey`h zVdG)`K(`gEA$FM_2vpylz`GEhBlx(uO|BHa?6i0iG(_B4lB0e_rwL+C${^(drFeWU zFm3_o0PZ`n37~1i6~Jjc{4~>1*{5mSh;wwn<=i1PD1)vrf6@7C-ZtWcMwx#H0&>=s zD9^4oy~xdP+8^F@O5AiRc=}BE^pJRZ=-#GbHh4F9Nq@mTf&1DiRH54?bBfG0Da1^y+bNw7yjr&#n7)R>slC`1XA$FCU+>ol5K<+io`uvNXz$2q}$d z1UeG-3*W}OpL$2wSjNQ(bQ9&0ij}#Dr>H2Uigw{ql*UET{KBjIU)~RWuk;_KhblUL zlz!82ef0IwAEt{H9ijZ)0NrcWz18`h&QR{=?{v=YzPjz}+irADr3Q01N3sfoSwQ8L zpFIM`+0B0@aX!tWtMV>X3X z5U)LfL`2rnZi-G|UcuEZFK@XS_`&4clcDmb-=3TcTtEH#>F-a9qzWBVoNcwE2a9Lw~}&MUmX z3u1)0(bF_)U`w?@aAQl~7@gbLxMQv~AT1~yELxm6#bza=Cpu@!Jkb+WmIh2kEOA;% zE;<_8Y4SK|r^zF*q~~2WnR*(=*{(7cIMo8cjq9~ptJlhC*l;(_mSxMf<=ApmBvOKb!wc>sdZ|lDRpY4DRmNb zO>$@{Hz^9QOJ}3&59Jw;o8Lmsao4=2qs*%EOn`+&%haO^BxC(k$}7)=dWx5+r#Mz# z2`S~3XF@$C%haQal$2Uxt5vg}(q-yVg{6z>R~t{c#`2wV`EEPbx6WmEH( zq081K0d^iJ3!k93!Tq|3a>)4-eEB*=hN?H`=#%1JPQ<^Q{B`AXZCfsbep=F+8Z)_X znRZx?gq9y4Bc&7CCgnZ!C>A~u7d{EG6Z3W9-r{Gk}sMRAb(R8uFN82cvTt zGs=ivMjg>IjD4znZ}$kr5lb%L;3ZBgst+Q|)Yye0B9HxqNEB+AS$&Yq3IR-+))39~ z6DE?9jfl*ftSrzNwzP}9nM`Z=B~fF_15U!)4@}o zQ11D8^M#0V^<*KtAXvCHT)0at+;uzuo$@>7cLP7a{O)BigyLaG(CH34y`s|_^o@zm zXCm!;-l@1#arYUq{U|j3ifsI^aO0+zZqt-he~QaAg{;Z`aPEMZI}ph(nb|(qbgSa7 zId}|jDM=1H1!hA!chRLJ*Wpc$4^7%aT_Cq)%kp2e~C$hDEV z-EVc>=z`YGnr*jSKNjifJ}n+O9X#U@kDLt-421`L;sBeH0U@%Y>iX!l z(YbbUMYC+>=9W`jw+rS<8t*jTJ$k2!DMa>*NBe`PPAh7W=f$(^y(t`B3h*dGmJUpKc20t`c)$80XFB-#~Qhym=!8GtnM59@DOv4 zb~f;P8w!Ngl(dGtCi0reqf-;vn942+k$_upkk>%o1@gwo`z`XO$otpi0k3kx&&Ugs zM=XtOjb$B$TF9#;50W)b*r0g35b#qIWO*eYmRHtf8n--JQEfcUKkDRl#tZzT6y}?X zjK}#$>y5^Zk1Qrcf0AEetbDXXXLRw8OfAOhN9TBxamAy2ozcTTO4l2!$gBJ`y~HT% z9kM0|(U9sCwrBBo5S47b1$I_6txqCe%YF;2w`kgKDIn`Du5f=~`Zu;PMH zE?IFAV8aD|zAaBE7|w?!7gJYM1;n!-AS{8SChCHmg@e#eUP00{s8D8TRblJ9(a+D>Y+Lh?Ug!^AYW*o5mP1%f4%sglqc~DIk z()G`Ti;NrZn>RrT(p=5DAQIUuCG`1|k@B?td7RIr6}=(yk^?>xMOc_FALm8MW8WRq89GJZ`$a1B zC3p$ntFD9yQvOk&wFtLp0Dc57>7;~VL#a$_;G*XyXKDO?LdBP?7j^t3HhzR1Wv2z( z*sa2-Ttmr6PYubn3MFItVv_54owL6~X}Ito(HY@C!AlfyGa^DNNghC^CP4h&c^bZZ zRF;V_X?rP-!fd;_5Qqmv+Q>m+c_FSUifUZ7_T@1Vq+vuwCVkb>7_}wRnxccZVfok% za^AgYFHYr-I2{OviA!cOU*YEz`d8#J4u&cO2ct;zbH>J*LSFPPu@mO{O4ZL_lIFwr z`-vcfP``w%KDNEl_MoKsr2|tPk+Q05+opDZ`4IFjuXVE<#73@kt-Q|E?xQHl*VsV) z1p)~JrOJ9^QWCp*U4sT-sm42=&^ zh@Gg?peZu_Z&NIrRX0M33s@=}8260CGQ43c+3e$cEO3dVb|5FnO7q!QK{(k(Pq%eYyi&K2)`2Lbb{bHLb_P?7;^R-Z8+7uVVf5LYP-g^4 znB}uG<=V(GidZ5ckW6$cMm90xs4FGvIT~kN+QB$Sva){z(+633QyEJKNk#XKsP{=a zu519~Hzm43ZlnB=ZB148)3-9GHC&{P2XJptRg58U#Fbg>c-HKzQ=mH?!JUh=U&gg9Vtvb@yRgXyxL(E|6 zPLyW*5(m--=OvH7Ax-!-#W9vLcEqMmJ#@atM}==u7TT=1V8p;JTLiJ)IgKPKMK+-# zSq5RI2Fm@uA+{d;q%g+7umXsgl0#yqBut)|_|IthJcn9n`Q&}|@=KS$^a2)9MP;~R zy;!k6RM9G$N~cm6a`LVgzg+xoQK;Bu%aJ< zAHzU7DK|+!!XI$Smxh@%@!ueeALAd@kP!YY6=_s~Ym&agQtgN@(%e;`5VRv&_l8?f zh^;5Wtq!r(5j^V*pY@1mJ@c*S*&Y}E9kK~`$*YAI-|ElsE)_4=>J+t6fGqLK5UbCN z1sAlHAyl-y9y=#~(*AaBy3@5No6(l?8y@e)A$ps(4Rc=ec#qdS-ZRm1fnY-eVnubv zwbPh3`_i5($9pki#HM)!*-wj!xyP?fWsNG&gyR;?V?xwN0)-@}lvkb!^<*w{+>)JC z30{@*%9H8+TDm7qB(ESzybwk%S-&$vpAMNif;_(ZWD92 z-Rk|xsUM!Y+xqkE?`{vCI4K_N51w|0PoEc0GXe1F(a5gd?_}P|yz3Nq9f6oTuZP6k znKm8@W(6ywx0kEKmPXOi_`MB~=1$(2{O-<><*9koHrAJD8mwbIb>pd!W%Img3*wj9 zW4~`|SxPGWt-xmjhYUWKBEaYmJd>RQWg(D6lAsAkYt^Yy;Hby76ok&ulo74*B6qX> z2M69h@C0ZnrN?TM)Wy{ti~NQryxcAp?1;h3C{4k>iV!&Aq{VMBcnKpe%aEkUU`sRZ z5>}NKzDi(-bx)xyP7Dxg)WBm3jRxzC(9PFkTVgpLPmM7qo_$QB2dXZ)&l>btVxJXQ z;MK>3oqZc@M04q%L=}Fneqhr>sY8WETF34n>+=)8gx+M>PIMr8T&vQZDr1TJUGY>a zsMutO{uyJYVB)AuL#z55M|9;hZnXc3wP;zgq?)0y3oO}&mF7*crIg}YvON{OQZFv! zin2+iYIQ>Zcvk#-6G9ECvUNpRDI6ah@wnVh`n{}_QC!&b494zUBv5ihFtJVDj3s%7 zJCkxI^^6IdJH?)A=j^6d1JQ2hO`|C>RM)Gb2QP2KbsZ4j82V9@O|OPiZ3avwZ64X;wz6SXoa2L*Zh{s+>EQngO%>em9 zbkm}4oH${5s;Tkq9y6+&6idu&>`j#|^w{-^+@T8y)o#QR!l@UvgyA*zb=0AUUCnZg z-Jni^5sZbgJ3+;bzK0Fv_Yl}&cm*G0XKPdaNDT2Xm|KdcYVCY4Y;{s7Hn_3Q_o}>( zvh(nPKy6-KFYYT+W3mwnrFM4F99>84Ok|`L`r?u11p+tm%{732q&Mz72;E+r&Ku4E*sN* z5&HF~7^r8}DVIh4v3Rugd#J$Qjh6os+)o~?DqL-Y0ckKl)2TE3mGnq?<#fvQGgs0U zN>;q?pWXD@a}P>ZPq)LU^^Eh%#B9#h=dV0Jx9?{ATb(yLe^FdE)BZ0q!O)bhVK`)O zj6T6<526wF?Vvk~soT6)${ZXQ1awul3NDWZu7mr>M*?g(1o}Rqh`bW&C`pmdGG=er zu3fSsttIi}q~_mjFuqOIejC;Pj~BU5b*Y)FKIY(klFQ}fPw)L&281z1<KFM%GyHshLnOa)_NbV@>g6sn z53vGvfN+(2?*g-0G5Vlj%{0GI-!$j^o+(t`@g37_%gw#FTE2hacFy;^uOGc(63aWL zQ@)l40;Z7eXylhpbupdK-MpW$$DMVVPv@rmluv<=lVwJGF7FMI1xGgWLNwPZHmYSz@fXpUT=Pj;ZEEZ$05&JT4k2 zFF8uBbZK9{TF-uuESfH?ppMZjvI&eb>k7LrUz}GKtyHKzf3H&Mo$#l@1GaT+d~;lbawcq!}rY z?F;aOBV+TcwuDz57FQi+3Js^kBd5ew&}kT+Uo|WSz981Vu*h*0?Yf|;BvM%Ut*77I z_sxC5HQRse4|ep1cEZBy@!*cWP|XRk@WlM7GvQN%_f8E4UE}i?CqL%6=lL!jieKWp zb$C0hJ3^0({85TL&G+a&rAIH_(~mxidUY@kP*^$Je#JMH@km#kR{C%~musEP2;pu} zN+hdbD)X~P2Y9aHY5ucE>6~>h|Jk3`zTl^oGF`AopK_X&>-o|ewb0jJfu<2srp zPz7^}v`e8r*d-5Kv|x*rgijFqlD2gW#>(4+hzP}E8OzH+3Io0t;*{8$b&z2Q7~^$a zayXNXmf@ZwnOGZ*p6!!mv$(U3q@+8vRtctJt>zSjbE?Ih>N!s^r#hIk1@fV&9j2kz zR?qR*qIR?@s;`}x>VB`JZrT`WYQCC&C4FY2m|giGyKzy^mDU5M=B}36k9f+NST#0N ze}op!!F)?*>pPempnfgdI406Zn;QH+Xx|tqPzIw3c1`tE zW+lI`mGCUi6lAkXIGJGt0t7n>fe@pG(HaPgWT$*VET2uUX#_=;k}4%=ea1?mb`8o` z$s5b}o7A3khyk*NzRDiEWFiwe0?C^h`8cU};8)1>^4evbISmbdnki{*vjdGkM^YRT zc%5EZnHmbQf>B$2wDRR!q{=$>%ea*`f-G~Vy)VOO8N22L=*lR5=!$N`GO8l(P-{ur zU3`F_2wW`u;?k*1b@ks&W2?VSS^ar4$6hO&J$kL;K~5c=ci`@$RWw__RZLo6KWz-H z8~JJDt(q_Z>=C-`M5amx&WQsT#3L7ekbkF9TsJcBe>QmWIY7lrd^LawId0W{hGUwR(I{5WDDEh=+^I`}p8$?&XRuBFECS1k0{})Fh-7{IX$=@1&GQ=2 z?@_Q~*Tn(b)e7M%+XObstBm>*HJeTcxnwtov}NgdOtmo|Ga;!Q>ug3TjbojaG4tuH z65L`>s7R_;$$=;2{5h>Xi~+4Z<%VLiVdj&Ap>o&fG08lKdWbzEri_@cQewLdVAEiH zirBRCzO6l9q0@lzyYJTOPa^5A<@z*XzOQS4Qwh&a|{_6Xmu zI+~w6{4ch`qfh8ple7dT>{y3>ti*&fJFK-DDC|zI?@)e!ikbA6^@y2R(dr+=&)B); zUjeXs0)8tIkmLU5VLNlb+^JY8AR1KCjd_SQmhI5! z7{8KS=043ikOk)agZ!e`*I(awZR3^gvzz8x-qxg9Si$8+jQo|&*bi)A zQt_%Gt1!=!{EfqYtFs118G=nM`A`toYSjW#n3)JaS7* z{er5NPfk9>rL?VAwqD(FWyfr-ShYEnx8>GeG4JW1`DvNvh#x;J6M6h*QNa%WOxlyi zkYrNIXnB{B#E})Bt$cr~G8r*}0+ZR0qJf&GuEuP8aZ7<~HM2G6Mxs#b~_cpLBa#tBxc}=S{J5)IJrk=BX#b%2?D# zTV?VD04{rqEtQF2Qo7d3dq1xKkT)bdz=RGO9i={RYTp3+C6H9w#BLq_-W9$?aTO}$Tkn51`lCxm}VAp?HW z%+-(~rLc|fCxgu3*`*9vNO7={cl)Ju)CQY*<6akqotKnwrd+oSGyIa7SrGv2AgCne zu>xdeZ@_m+`TY`e{5S}t#LULmPum4!E4y8?{7EW~&m2E!QrlmBz~86QeHYDr_9BM^ z95C;FFZaz1g$imRc#(gWZ7ye~G+46wdxJM?zk6OR*)r9+YyxViSb6=xwFBYuCb7I} z?)>fNg1z?O-eaNi<5OLbuRxS@XrZulW_xh$p83MP)4GMc!uy5p`5v(Br{;THp+a{! z&mGEhNAj&OGn!xf@*$YdDB8!v>SuNBn|0r;n`;ys+e1|Vm@83VMT1zrDcIH}mUT@X zisY?CECZtryl3`zc)AZ0#J9=v7*%retm~pIic#b)baWw{B=6gln`tS%gOGf$TwVb% zw}{ZecxB!Q0%y0oT8CS6(w|EZ}U&H8aO zhieuI>JY!^JQ~XH`B~lV4f6$c;exGV!PZ;V^94KJt(!^<=J)&}e+_<9J^yLkNY5$` znulXeDK7JGkROpf`4l60{0(n#9 z{ULd5P(h$`0_bv1Ry)zr`2CDtA<)GyveGN#p_(YvLLPxn!9gCN?vi^Be@=1a(LRXY zb9iJ(Gj4uV$Zj&o*B98uh3J)q)sNDRNVXV^`H##-+-R`qj3c<>U^1?!%MCj}G1VFi zKk4R;#<gE6Jk^ z9+GLB{QJE-1Q)&>fAWtFf2`+u{yi@DeQx#pT+{np&HG&adtAnQT-y6w<$GNA`&_}N jh8BJ=e$9vcPtWmt41DWDj{HxLn0z|EBZy+=_xpNgcJ^*TvJ}aO+Mem2et!M+_v)`3&yS4c6Zre;&wXL#KerQ!?=lkq zOCoX>Kj#+_n2Acl)J(mpRWv-4O}&+@BwMLUs+F#!TbW9xm91o3xk^q``GzWa$v<2f zmS>?-z%$hxX%#C)$&+r5wni(X@}6mqwZ<#s@}9-}L}fzWb9kSuOlparNSH%r{*{E8 zzo}QI5F0iNh!qe!fY^vxM68I|w5eZ8ltydkwL~I+d8OePt&X|Tv<%Z)YP2oKa91p& z-m#HTb8Cj<+8cFu!^VfDj&0c18eXlo+h{KvHMi4h)MHsr=NFJPZ=AEY9Jkh-H)?Iu zxX@`@bu>`UqdAl{>dl(tpgzZCD~&elZPnaHr)|_0I~#6TZN1a6%|^TCTF$)DthsKZ zZW(JeyVkN?%Xa4Tth|HX+Q!<-meZ)$n#P*lS+i`n!6JCS-arNDC%V4T#&FOEGP#{j z^IF4|@hn>%RA)EpQd>A!WU3)2Ds|951HH5CMy-i1*Be$F-EbPM4Ri_hIpzGDnEu4Y zQpQg`-Edq#eR*T8Y5D2r+xYA!7us8X?wK{!wx4>uQFlw4pRSw8iBwm;zro z`-D!-2)o(C1k5#yEw}*f;s$2$w8WWaHLVu#$Cv}KOQ#K>M6CQ)Yq8TbI!l6t(!iWm zUM|lY$A0n{f^!U>IaVr1wbg63Tix99jJft&8B=WAwJp?o4#}3&sI`H~5f&L&Hs_74 zdBeNPZZEW#(SuG46Uc7yZ^^FJ2UK<*$pPy0)wVHjT%R{?Ts3Yq+!d@0XU(d+z%D=} zK-XT?980xf8gq@di~cPbnC|rrV86+=wA26+b9r&?$*@_az_wYlm+>*gOspc+SM;`H z0~7nleO}V8qaUd3$`Z!AgaI#M#2f;a0v1`=Mz1F6Tz!Qt8FNd`PK{lB79a%9V{#-t zHYa-uT5Zc--s%;+)Nrv*J1z%?^1bgw6i zowjM2Rrx;GJ^URRK3$u=x%c1vp&9%5ZPNY<)nf_r0m6pYcmxDsYcqX zMt4i|GuLil4xBnDcJ!y?Ers7x_&I+EL08)m#Z zWrA2*V%Z>;kytK>WhFKg#BvhLFDInVA&CtK-|`YGG}O0Ya|9%%SUW_96C&@EpmxBH zh6g@?Rl2}=?JTYWaRbu9j0+fak8m*M_;Ce_)5cTQ)(xy|1Ay5efQb%f6rO#8x7ksPaDSqT69eK6cB+(ax5e%$2hacLYi=Fe?=itM@7UeEN_x0S(c;p`gW^j?a+v!f7mjUVskZwT4p>?ujo&%^ zoZ|?(iPgl8wwri7arMy0^lfchcd=l;h{e>^znFLv@rylHRFju~>m@wI)P_ z2+8_5mJk+-P;Cd~t?cQijf#Y$d==$*H~34Y2SrT_{nS5mY=>}Yi?Kq zIu#iFl%k0A2;Lkn_=G;8c|$T(rtcdY^T^o#8eTXyQ@fUQCS1s-H?@~Fj4iR6bW^Ko zH?ySQ#CRno8%x3Ho%3DYIgc_J{mYO_v4GRwe~(u0vz85vAT8z88BPuJvqd&uzXp2P zV=>1_v3n~qD*s%q0RgaoX)iz~1jp&LpRnzYecE_(DH0Vx_}d*f5Kzr|uFJ-f(YC-N zL6f;kgx5#a%QD{+v{8twXaFOxftV{{(K34r(jEu=yt@f#qqPA(vQoPa=G=~YWNG=u{(<`W2IMJ7%MdGN+4*{V&Cdp+jGrP4la&}E zN^Yg*R5@l}-Z-Yc2$PnX5G2T95Q*G`#($nj=Vkd=$`K*kkfni;`D#_^0C+D{tLq!J zX7D8or0vwJRhuiuzJ~!P$fI0n2x-w!Ft0b2YKJzGga##fLlAjH9DDPhjTK5UykpCMU5EG-W^sKmaD7nvd+HMV4 z70RJ~O4b~ejF=+DN*edq{0u!V(PH zBVSR+CBi8ID^z)F10<{yf=QCu2&MXmHXjP`(}wT|b&t$9(6>T7?osiuYe7vkA4D}> zxl)@q7Uzxn)p_Get+}?sC}i(>W4YF9)vjJ0RE-SCn0p#4&?C ziI~6`34zA~`HSMA1FMPE5~?iGH7Fj{WxKPn#^|MwUbwJNY8(>q5t_oUd==*~{j_w| z�zEMiXhx9}T6#YI~!#XxaX7tYJTMy+K=tpJIDndUJkj-b0RFU*4n|aZJ+2DKz>osd%X?U? z$88Ko$%{UUpF`O*F}hzI9D|6MNqS=e9qq}5_IpvFk0cEtA&Z}I%)f=}FTDo36f-akD!)me|F=ouS>tc9N+q5yVp{neQgMDVq9TO*&t4hm~m6 zei-$<0%7ZwUL7Dg>9zqO>~+-tI?_RaSQGWScFr9!QxwXORy5Psbmwb)bH3${ni-0! z%OqLF_NSN%k@geYX^bYlIu`4DMpDLODcNquzTM3_-|l8#)9fj3p`S)@%tWX|Dv0MDT}((EI?Wp}2LSM~lN`-EElhC5^C&EZ!F zQ}pqsc2nEVbqT?PW??FE80N~u9K~!qV;1>TgxCAK5#>7k7ST;O-4ALTR64(88!q#RYT{N$iXA4}bsh@1Z{VyI4A8-;Xy%+2;9nf@ewrOW$>0G5e%|fC{x4IzXCCbZeBZ>+AzMgfpqS`6eOeoXf}+2b z(Q|miPtPU!%;|YiU2%{K-caeE2C8a1py#Ve zq@+Oav}$@iV`>ogbnnUM*Iz(uPpqj4cjqj=+>m0l0A|H4_?OMolsi$G*fs zEG)u7Ip?=@0ltD|KfCiMH>4X84sfo&j52{qAg_6cLud@h`6Gb?yD&hTjsapIr1rq! zGT|T?FK`fukgxASn+O0Al4BOW^0oko=j;OjX;lV4VfQcN#f6Q8tP^rR>_B|3>OhQ# zff&gh7)V#sNXhK#unc8aVfb7f>guL$Cc((3-Mq;h!(e)8+KRyRrrm;>F|)$nP!HLg z-9{TZwgCnX_Q*Fp^L(1aK(vDQfyk&5aRTtHew#{tpT6fAK`}C>{RPYAvi%rFXmhOg z69~MAVhBv7;||kkcEsiadS7388*)z@AjtjLHc34;L9ZnPVlS&lcn}P6VnwuRklHh? z_64MU6Tcok4;huRQw*Tp-sD2C6dh%gfpx#^Jryy!i!c?1i#3+Rcph7i}rKyyB;rMevzKUztyn=uTk1 zk;UC3rCcFtONpAu6DL_K@HMrGFV<*J|?xD!qet_V*c;b|j zVwx6qD`iiG7)TkR>^&UK#dEAc?G=aQll2hqXpjzz1Y*TK7jLBn(EgAG)r7m8mI60< znAn>kJ&3tfYYZln1LpTzgr(TQ0J#Y;kshfWkh>3-lx!lFF^32sNR~0xV;Lp|WUZQ9 zhq%9!5-S1s*OT7IBZ=}t=X#8Pa7Q^X5yJr7;)YLTGcm3_FmIeHA;%q-WlzvtMsM)< zKOVhUN-G#~uKgqod5W!65VZ$V(iZv?tO4{y#ED-=W`{H}k;>D)95IstdB3+Le}DEU zq_rVPxR@r|F5id3bpkB>Bby`GiHy#lA^?CA1dt2O$`xKq(g@ifTh!Z0=sJJg#g5g7 zpgn=SYO{%DZ9ApjRi*=xG}sr5}avziTa$}>C5X$4v5_V}x5EREduP&a39cC(*JT+7;hUOAs|^O~Ssn&*|&fXz8&7RLb35KiM!hZ)ugu z>)0u9vB}+N!(6X0NM$ydD}RPoTiQ$>Nup!2nln(4atW^#atk>4m{HjnYmG``7tpp) zn`k?J8Y_F*@{_;>MPg{#w+jd=6Ah;tAZRtfTv3?(WX&}FOl@t=YMWyIDdm+^EAS~2 zFplUy986_690)c~eQn)Otynb^%F2dq1N%Zoy;&(La}i4;=VoOnXx>7pssh3)`(_~U zNBV~OJ4kXqh#*lQ`A_L7(JJzI7N}uph2#W6J&APuOKG~M|DSxWpyjn@_=zyYc2LN2K;EX_JM?;BN9&HC!(YDr?{R_y9KDKojNZD-x|PR2~(EhlfVNp26h*=|-u;=e$R`R!phw=-mBy2DWR z3f%%^2Y-8{3;QPQtS{r4gADPS3F*MhLu$yMNvyZse0SugW)9=a&zpt#7vvv7xo_ZE z#MfVwazkM`q>du>J5nZ|I)*ZTim&7N`h6)g5i2u!O{axcN=51JFl@vlyP7$5Er+jX z+!439Gs={kS{IIeVR?7VJkT9Eh!p}gwY~lVOoQZAoTYD1n$LrY?a^)$rAK$BcN47PHEn&ZtDy$W$W3i3v8`c# z#?1V7-aUwP)`R)jnK29I^e$Ww(8f_fe|&qQJGL|K&UVMU6WzRd@Kxl3%#}BXAm8Ot z>+hH|kghbqBvC!og$=rT7}#(`gzj0a?nAYQfn$*z^2FvE#1bX1)Idy?qo8>rc#n;E=t;j()j?QnD^9E-3CY1LYWML&z!E|qKUThaf63Q2J^lVol@n8t@bnX# zZ@z^;U)#bL{7!87nuk~KO%f`^z02O%4I7SpggAgrI;>huDn`YXFZ#2Wh2ti6p{`ym zy~iI52#ahJ(RLx5N`AW8xq+RmEX?|J*YnfPS`+Scq7$V(%kDV#OQ_G+Z3_lkh(Vxq zFnzkVh%-LriGXLN#R1YE3hg3}pF1B|y8NLhHtUwS^4cBtuF0UqdXr|ywV&Zz)`4HA z1~tymbsEE(UxW zo_QExj8V&k8OY&}?y2i3&KXCP(7l5pbob=xgEApZu~DG>7+wfvXKER0{tiUoU2O^W z5L!z`6c+JJtS3;m<7r^Vu%mQ!<}Rggp9#e@EQC-nBcXtMxH2CpQ~O{GGe*sM5o*%c zSB{JQmD&1%WwuKM5r-@?+k6P5IALPH-j&4mZhC{s6ocX*$(@qf<#>N?r4)ftl+S9d zxm!Y)0>hd81r+dyNM6{-K*;N%Lmzo%j%QdoO*sjyIj+oU&6~N?ko$oBBS8%B!|x6; zLh9YeFP*ae3m8aWFMl{FaV0(l-awRKkA7$#TLyVTujk$CPfwF;}qH*O_ z+|^#woX?`w33HNTG^fl1@|+Hy2l;#hWu%4~K3NCga~3WpiEeT=r%K{&C1m~xY=dGpw{SMC^@G_6H;WZ-Ddh0Fw6uiX$=1n@7ADi=RMzG>n^NaHNyocOyAF zbO9G5BDeZY1AFLV$%iXi%v7__FVn!^ev2oGF;|cHJYTjqEK2Ok%N*Xun{Xw-(H29b z2J8&fn;Wuk(xY%u$5?&uucs%#@GDzgcv2||&dfze<$4xPZ~DZE6X6D&aY@|K3?$vo zT!g3aX~R4GG{(zwcrQROh4x^~g~zB$r-kRy2GA#EjJPFkK zz0=-E;C_cn;6!&jS$_YLr{j3mq5Zyi#++#_Ba>HrL3uRJEqweT#t#@><8kHLXv{62 zthZFMk>@}b#296qixejG_c9$5P%vHM!2IkJN(Z~0Jag)S+r#I;@zlAokK9f^a_WQe zl#j+EmD{IL{0IE2=ubZG>1RGrNwIajebCFC2Q4)|;-S$8g2+dMXT?imkY0v;Q#l?C zBFG_$mD}10i6rKS_+VmKL0Lqh*GNHLAGgRf;*lC_4p$2usU=2h|{ww=f+&yc(}4>jNs{5t|wBBLEeEU!J5 zJo}cOW{$je2ui}zTe%ElFfB7 zi=Xo}0QqiDgj!jP)RSjK)v3@Sq`-(2bG9G03M!X1vRaHB-j(b9XzvTQ$3|aInq%fT zoZZLvrGxJwecYOmn!)`Ln`Dg=pEeI7KE*9I^kl}IMeINjJ7gY4Y&wV?G2esOLGy$) zVa-^1)m!P+-LLDWVcrAB`dPD#cG4zOz|W9!$UKSsSu=r=G4)>aK04rAN093%;>XN6 z)cPLt0n|8Q9!JXk<~(xTB`*4>{M?0L7si88IXD(gtA4xFJ`pN=(Ow_U6T))Sq(To> zo<7T>M3>6%;32<;$@`Ib$Ws&6VPb5$I0sC=NP97jYv6Wy?|(;3rQg9l9P)k-_x*hB zT^!a6w}1LJ2YW0`y~Pgf3Cxc5xnK9VS6J?y+^l)FWv{z+uNFns#sT&{&dh#_1;v}$ zn~yvkB7of4uaZZ9#8+BO<~_1F%Qk+Mfsomwe7lpY^J}Doq_TZy-(VtvGst+@hqv=_ zCJo~69DWX!Lx@AL&T=;%!p=Lmh1efQ0lUhe#$XSfsw3@7>>v&GiO|d9NYt%&a`T9lvKjxCx^IkTmI?jda0q%Z5f3uC2GF zZ>8BTvJBkPxquJ9NJ!8;AnvKhdQf<$oiss@VOG7t;5vie-ipB7U>f<%FR&33cc|rw zl;%whYNk)oC6GfTa2`OU2jq9Vd1h~+SYM@t7c0Gl-w(bC_Ddne#GrC=&`^F9yI{75 zUfW#;vT@aeFs&+Ii+9TY1atK!BiiT^R{vFIRk+W|fY5RHZN|_CKmzZFZauSqp1u7M znL54^hl_V&??A;aPs7m<&DLRmg`>HXrQ?_5-4RQN{g0V{4~e48^>t?5W6GMnWB*|= z`2s8YDueeSsB^BuKgx-WpTjR8r&v!QsL;)iU^6QT*Cnl+#Kv2ys(1C^Es1l)Bwf1q z`165VG1bLZAV`-h5BHx(6N8$z6+hZLOm+^C9;oo58F=pNq=)xR*BjYlQmYV1X z6qm-r{gz7@=}vMt0VL&J1{&NS6mUVCC^`XgwH>%i;w^;(X({LcPxs2Y^mYeYLkCL# z3`(!!mVnvB6zrj9=2g9$I*j#D%ED8ne`@NmAi&B1(QEMZyo^5g4T96d6(*}|8ZV){ ziaItln~1s4U=Lcvq9>6J@3N3?;(gHS`Q31$2FMbslhw}B-YFwZ5zjfXVe#gORBf~2 zl(K5IP+Yb@%YbY|wo`1*Dvs$}Hzea;itA^BQ%xUXfjvZ5d}I3i@*ml25)O)UdA$H{ zZ}@p9c&sB@XT*j6QhfkL4ZY<>fO2^a+U>M=@Nr*J&WFrJ0WbSDG3{NX+JBEh|6uGd zA?-V8+y47_@s3Bm-LqvFTOYv|rTq_3;G+Fy);VBxh~?z>2p0AXn~^hjs(Rm0$clsw zy&V+1ci8gYH#l>{k>BY5oLUT(*RCu-j{Y4DmYx>W?Bo-M4QvXr*^f?cKr1 z&xcTMn12uh)#QN6`=CUSQV?!@J>3|?hEE<@Eqgw4Y5U_?r|!oD4T*?I6}x(k6_;z zLIw6b)E*Ab;h1oNk{yr&Pvw|XeX#>7ruI+Pzhq8>avk*Ue`oW3E0D*^-WV*r(FJxi zoZ7zxXnDgh$SUhNtfu1^WLyGL!*FB7Blfz9m+(jG%9)Zxjg;SJF`@ zk*_<9;|6f~vp?T0?BsBSdjvU0kn?Le)QmUeK)zWX1%O)}%qBXI0x$>M?r`W}yNGhb z-J-bAeQo{oUE0xxYQGLZ^h{~t^=3}w#SUjLbu$wWd@QSrtb)7W`sABEuK#fL!ktR+ zUK5x25ML=J;hGMfCx*W~ZgD~FWNSErid~;2_@`ZmOItJQ4_x1a0|EFh8QWmvKv{L? zo>ObO*fPQ`M)V#toE4lkc5wGzRDQq7FY=N%d$ zDw+h=3m>3uPDGdi9>8D0w}F!7HGD{d827wMdJI&Y*rOiASVJhQ?CVP;PR)dU!LhvF zDeSmF75F4g2Pzbyg%z^pSsGPZ>aNh)4j1)>Ux#_NBZyS;HFGumkPVJf_!;rtgo~d%!RbXogZH2d zRTzE^iv4fW>#ec8vWOkeJ@hj;Y!J@VpOKPsUnx<9B?#u(X>B@rC`mJI0bvUN)S&}X zR(0S2esAz{D{9_!L}mM?zK_UK+K7nk!-$Z`LIc%arTdh=gX60pb9V&>2r|vF;A9lM zb#XE(xo)ga1y`9p5+3xcMFt2iAG(~<8(!8%&1otCo%RH1OY1)0-! z#!b2KCGTc-;apGQ21kHX5OI=-<$5tQ{b~{?&ywNkvGwEKB$z_R`$R;9F40R>DYwCy zm)=~+F~Nv#;kvUMJxN4em54iVNyYfb%L{orN~W%lu7@+%BVzXN0}KPH$0UJ7K_FDd z2&^uS;>Br)CbJysVXq^1Onx^QC0ToaoJx4*2)A9jhT@iA`nd{D#HHSlv{lmb8O9Vr zs7T0x)FTOXw$j78Zy^eMN{5#PnwGW3x0qzq!Dnp2(BvQ}YySb>D&p4~VwfZA-u^>= zD26=?2o9AzqN2i4QB@4&^T+oMWsYM$f^Ug5S3lM~jvL9v20&QcQ1gE?=|bG1wyy#J zm5X}68x^rPTK5JbT=&4xDN*pT@^Ns8eP3>wk6iRg%PWGhZY(z#xQ;4ZE5ELn>uJW#+ zGz3?OjZ6A&ayw_Ha1MoMO0W@Y4!w%gVmNb=M^4@tl%K&h0h4f2V|zHCqC?0WhS>&b z*I>-Ce-`IEh9Srft>dk$;oh9=6lhziQ2Tyd^o-#G<;v^0#}IB8#)~!HY8~7Xrp`aq zKxJUd;T6iw4um7u!l_c(vj#%oylG7Si=RRbK#hxk#dAsSF2-%%>WuedXA@E>W8xka zt~#1$BMKIs;#A~N+N97LR49_|5|K5$8o8~GXTfNG-Fy*8F(^F5I=5dqjbVhkcNz(} z{)&y@prZ*~pd&&d7vCa;ZbuXhBu1dujak`idoHBs9zOjg#07Eyd_<4?jCecVpjFb60+uZ!| zP(-|L-cZtEK>2;$3P_^xZUc)*2>+#U-74?sXx)aB9^5}Oi0ms-^4qNGWdz2H$E2>pZF6d6*ns+gTfPh^hd zW->FmVtQ=2n8Lr}T_-29(+9@VHcONaDkrb+;@!{S{#Gb^%706Sz@?!az{nb;YknFu z0>Xsg)90D`D$7pe{{dLrD&zPU?;R_12IUR}3Y&sZei4#$>{4U7EJ_)%$$Dfl$7Yus#K*^BICP6i3Qu`D6yg>a#>Z0m9iq`;kcYsWgN>dB?_w& zS7If91X0BNeqYbb&fYCZmLhq#rf0gRyQll>ufJD+-L;XCd;-5O|I+7I|FE4%{23$h zUlNf=a5=w@z)VyUre^9*t)k(cZ0fCKCD}?nRCXm!lZre&Dca-(fIhP!GR z^^T2%np-m**WRqVn>Jo7cWlG9*70by-9~%GsJWe1qaMq0I=_IVdE=bD?YOn(yiseL z#?zgqRmTU)d3+9Kje4`@IH=EY`ISZ+^|oqmqtiBOOPx(OthV0i*k+?$b1i4yXx3b} zQMZisnq6yIu4Oy(c~;&*Yi(nFb=zswYfWR_?yOt3+h7qqUu&QOv=dF=Y@<8)1~R#w zPV;KRmHw<)9aLvG>QY_)ALCf6HQ8_jSUtxYru^*QDI>lpsT zVkzUNo@zL*pT4-c-n9JmN85PqCl}h=e(odd>f3(mu}0l3X@0tHA}=0N?uOaIeQdMc zaI1Bz*{q_Ej%6clqP}VYz_n)eYP)m2X_+e)Qis?@`F5Qjj{hQzJdIQ<0ZdF(bW^J& zt(2+13am6!FQqDJGi_!N%bGdFGC_>kld5FRG_YgXPjMh?KZgqQ27b?m!|ni>4qys= z;ph`OH6v_h4-+ufEVkeRv`d>9!P63Fn$@&gz#n4{z%HFOfD*Ct+pVQe)95S<63PeW ztnx~E-Z=iD;|R_%_{i~6IjXH*v)$^}wr9+>*UK1U+pcY+)^kX*yl&0fLF5Hm50R9n&7 zjtxxg@Ar90yM}h4vP;Y8?=m{Pj2^QKmx}MgJMTN9ZxA-PvUa^CW5ZED~x4N?JMJN|I&D?vnk!_S4Y z#~+SO5I^6;R6iRa=whkhCs)7<$AC-r1cONiQwaRDe9S(?glPtc8O$)4WpD(+P2Ep9 zR&&`NWh7zWiQBCY)lXmkD43aZdBJQo+L(^uTpj!JW0rH(?W|v3-*#6!?drJ;&s<)r zfrVF>H`_8hE^`7@14*I04i;LiqP$zJmQb(5MkI!c+QGj!vVgHDhq)IyeQXO8w%s>< zz}OOLT<35(ClGWKYl&TLFY#F7%8|?Zj<%z_m^NR)l}O1ASv2>D!0#z`mK<)uc1=;e@H*&|cJB z4QaZYTuZs>W&H;FD=C>+3VQFH@9NHZltJ%bglvjwoc8`(G>4zDY~TecDu>Q+Y8anw zGV}UXP{kgjIZnFWn~zcXXKM|Jf&CMF0g@s3PN)62ZFlU`#uLktzyR{!?zn-NYR+?Z zHkOUH1x^Xd%()^QKdN3P`@W!!LYzwj*m(^kU5SjA*_)R3IN;~qNkAK|O>mRd+BLA~ zcGN0&dmW>DTr#vR%XExbgG87JKFR0uVks*V%+EIgoauXW=p2*~7d0EO40o#zEfq?mHRZ0GME>x=< zo3&=}CJQ9()T>pSGsM1&0cV|kKLgI|SYR+Zf`BB!;fHftYP@);n3mK!68e5b%tvrJ z&~bib5mm%jDmbZ1E)ZX-;G`-;L2T3gFMH`lWPb#41{& zvk9r~usMS~V@MefQfAE~NSO%gIBFh4Y%++w!@L8rDeI7`!J1w>ysQW1?lcX&n;H1- zF7usuH*3yWM^MwXZ*)LBjJndU;xeJ+KG$ix zHDFaJefFtVb5t#2>Jz6=Ai#D*K7nAJizu{*Mt zfR;T9x;Iw)G`bMcmOk+%RCf;wFc(FY#yqDE4S?H9^yrNcRak;MK2&An2^X*>YuUsr zXnN29o7>U15raAnl@z4VZd|j>nAVFiyb8&*(}YYUqY8oA0agD06{!M`jynSs@=-M<#j3_i>0D{ z6nFc54BpS+AqF2{Kp{|xgS0-_KgHk-gAX#G$`V05g_8MhvYVoT@8zWP z#ZZ)LKZJT-f^hXxuMUu$blZT?^(yLr73ssOCTer-oLewc6vB`;Vy3U^&R6*4e8Vl8 z8H%PWBv~WvPcaoD?9c9`(VO(zXsqoSNg0c!WV;#rRyXT>tDAjAv#0RQag-TD9VCnC z>gdL|x*5DF<~ z^XBkNgels1L%X5vqg%n$u zv9t|SgR?2RG^K7#Rha6wIk&G{77Rh;f06l8+3x0}syS=owxW5Vo*}8FetwYH_6WuI zesJ@D{NB?`&pZrmBZaD(-UH8w^%2B5qI9apkPL!~(e@d&)Gd4G0;G-TTEE0wHi9~# zqK35`Af`%Qz4!T=-A1Ef2n%F|2v}&5n5)__Q=!jJ`Br*?eZQp{-kp$cDP+e0(Z`tj zrVlxK5|_h^)3KYlG66jT$kWx4aELtAal z&n~ki7BqG+>Uy)|SV1nnWiC=`KmQ!nI?-zs)x3xGN1C>O2gjo-nm&q%?!UCHEW%P= zy_a;nmwfT}N^z&UUKtDV)~YH|0OVan0zc_&w)|`$6qJTU>8yxG!B40sHX3UYO16*= zRz^a_NTN1vPIj5WeF*%#+kvrPMs(ls*^l7;7A}X3Ad!JSqUZE!Z47#e{zgX6;R%-I#K3<96C46upcZWa5l{fq|G)gmH4r zZ|VYk1fYKZT#6U>xgTqC_K`>t6 zAP^xt--k9403syEDE!)+0wA8V4*(=l8C=5UU&Mn8QwTXG#Cw=~cwf_j7!Lt4lDn{m zuBDNZ+0$VF%C5o6xi-|*P2Eg_e@}ydGe;i0E=`jVc-^!+Y-Y@?a5dCJu4cFKjU2xL zz74L(C*1SAo5Mh~g7@CYb`miFu&RF3N&SGD=NUmU@}q+V%jL4YfF9b9GI$Ju_h1Zx zscYP37|o8DR-kLtSId3i>!b2C_<_og?U2G_#q(-1pz$($g!RDQCMHAs3AUjhukELh z_I0K|jlc_elQK;VAlcsJLNE;-Wp;rXzwA93@wP=+gTjd#%VEIaXp~o=87YCXVOGU_ zY0i~oamH~PE26cw%3krDvR18=8=1h?f-O|}5w=o9+d=bJSj5jTi{P@DNRwzuV62fj z-J_tK82ceXOUmEJr3{&%iMeDRv~Wr*ypc&wX?ajaZzP=Kf);(FLZs#;NKIZ9BDJ#X z-Ozy0Q_70jh5fRcgm|b|QyUYsU*1f53ox?*2(os;m^uy^+Hat&7@rb=+yr<>k3tT}-3Lozc*H5@5D7TNRF6}b6o<8Has$Hs zZc1zdTtiQKpNQngr#shTY=f)DfdLri-4+*hA`6Le>Hc}+R0%n5GaGwiTS_o>sNL0?NFT?xs-=y3W;v{kz>%H$9u!$#k(=g)(XAQefKM%q(=Jc>O^4 zi#yO}y4kC`{i%&DH?x)vbILtC%V7mI=MwnKX-tjW&QLdJZ*{YuNnF$HtS2y z;r9T>(1tL_c?S`t71bAg)PZ@p+_+RAbDH;uH&q(uNLNLC24-YhZte+!bs zfbPV_3G@?-1O)Hg!o2Zl$v9)ovA>(n+(OB?SNe|UM@uE+1ilj>darUD!<(=lMsU8Q z+pD-&3W3Z;=sa+Ml+eCSxyB*0q*kujrC=FkhhNk;u#D1U2dMPJ$mk~<=GINEGIW}1Gcw!Ut)P5U&mmhwuB7046`5=RsscB3*J zb^%MHzP90~R;`)|?PSxofp;Ny-l`OpS%{^PbE`5GeBMH-s_On4n{6QQNBX+?`$%#q zOeP8>{3$&pYD6CQ0#ypFkeooMCy|bSO4Bv{|KxK8Ew2@mQ@VF(An^J|l^75)2r9ZR zBu5w2%hW~cODd^vk3vw=d-z#dYy|%^oDPSweNYlu`=&A@-sQMU$h0 z1=_%Pm;}!ktigxLta})(Mrk-!dluhzglp*Dtw|**wynqu9U*h^<)j%eLZqA280$Y*lj-k495bFH=1zZj#JekIM1^4&_ zcJXeYpnDZBMk0up2lTmGYq5nfz(=#R4;%y=k#7>02#eB|;*JJ}g*BADI6;B&DP3o+ z&@D++lyOC?6d;<0BklC4iOp7W6;8+q%80zX#e&xvi15C{Cj)^$5gXB<^}`%k61;>y zrg=wh-_5=*u`84r&f#*7po2?=!>?Vl~d-`t5 zOzvqrSvS3#F;jTT$rJ32J40@^n-u~0f1<|x&aj)?9WpcBVQ6%PZUK^lzcbPuF(D~9 zFXEnq#PEsWW4q(-Y8fZ1 z8L?ypF-jSJ`^2n>_`*>^Ry-hMVfBR~1RaV{@hM~O8HnsuX}CVZH=;I`;l9{VOQjS- z8eF-cGm^4MGgoH`P6N$t$S+j5zztBFn5orbDtH-P>>?jUwYC-Y7d6zH*K6C3@xV#5 zftnpUj=`@MPPo_b9rn(Is}L3#ty;^l=ra*pfwCXR6!R0ewE2?$Iw_j{A_9mUlxO7G zJ6t~T@Z6<)PhL533H}N9pZvh(`!1KS+*|UsEl4Y= zdMAew3V4|Bk+?Lssl z^C5uZM2T&BTawz_X$b~WtcHUmbxLFx!?xT^DZ-*Cp4D1+w}l=B<}v&8$m|c1tgwB7 zVAn%~KC;RjA0ZylV3WXHNT7vhyhEW#2SV0PF=ucriNO?`%AW+MQeF}1nh47NPAUdCaMQh(Eqg+hX_oW}w0zAQ01udWB_aw%C^kVD>V|U@5>TxN^se0UtF(+cL zjj}!H@qY!JPXv&>7f>9DVctCA#aR3#;-g{QEQ15x=lxkESBLJ~Vj|?yy=h=!TmB1nHRf>Wn$lT0thS_?+@ioLl(BgNz?Cy2fM5OVOBHxv$<*$wraV}Dj z(B6ymHbB92i5>H^k1L()R`SfL`)>`O1E*6vzdm*=dG^%%;wc}GM=G~Yqxg^btLRTW z=ILkNTS@WjczU0gIS*QDe9Xf~?++p$5AGE&iB5VMwoT=DAc!D`Bvx)|Czb0?`qcg9 z_bq$rQxEVqD^U0dFw&;?4+Ku`ZIl7cJudsM0QMpHMN#VnOmfeO8z1)?jH#7Oal@0< zc5&DHD8|x$o&l8{C18l*P-K!Rru{VrqN;s>PqcW7A0Ip##L~xgOFs#J5fO6z42`;$ zT}jp&<~}Q+c#T)di)bEBYjm+|eAcD{7XK7}Xy0R|5;cn*gJjtWXNmcex2cI{6x?UxvQi2)@hQB{>E(U%z| zWA;-^*nFzI=|xoQ1qSoH7s1cqau|cTI$ON6cwaG%|70;+9GZA=A~lgIo-U@2<&Nbf zHyhaZBfowGm-7?=|91C)T3L(Kmb*gLKF}egz{nJHv>&z#Dwj2~R*an8mD~Mj?+q5q zMqf>uW9B#<-Nz23gHIxT+?tS@!7UM+WQ`J^HV-2{#RWICWX7CD>`)LpVje|oI*1)J z-+|a+^Q1Ll%~*NWT4~jtuj-~@-UX-nS+k7qq)n!P{~_gwc^~p;%>;VJ)Vs}h)A`;y zhFtd`e%zcxt?w}JM~xHa38dU>&LhVi;+}uX&pjQi$#^g}2OFVj>2G)1Cqrd0TKU5c zLRfa1RQRFl(_dMX>{9tHyyMp~cs~*EcxuWzOpr|%JAcVXX-lT@4csyB{qKkg^*gtR zT|VgDevl8ojqQ2il2G4LVUOjhr`Q%ff!Xms*Xtg)3(LKgi#7MP?02>9)uO22IKaNo zk=akNpm;KS^O3(p1d!YNR4$fs%1>HM>OC?#%Wr&xfskb`Z`{_+`4!SlQsusVZZM(1 zHDo+&!<+dxlLqlRhs&8q0PzXdTdvJRID*IyA@sIRBK9Xxz^*c=G1y0^>PY({8<@jG zYSvphU7%=hP3>+(&>|?OE70RSwMa>Y>HSwY;GLePl_m8eU*c=&$Kmr z+qT1y@_APD_ZYkjL7ih1{wNPJehybaHZglZI6*FtVNojyhb66>!~$HZs(1C^DT%$o zBvHEe*hd2&Vycm?Kwgk7urx!9cO7Y#97qsYdzJk&rtQHIF=fAi2J92i;RYrmc(3_O z%=9FKehTf6uW!ooe&9r0t6%+k4}&2@W4$&4(1EOd03k346afe|378B34us$#SlOI` zS1jV#AS@(Tj#R;-+JY=!`cQ^w4ct+2^;XM?=Uw{&M?!fhC zn@c@1a)ag>kwdh%7x>-3!9bYFfZ5?6eUrMs!R9CiKBHyy!VkT>-%gYIs2i&2%@*RO z;xwEKZh#O>Q7$q-f|PgST)2NxEzij20UWX8$c%_jnzFr!3j016@*v^Z)J$)`=|q-l z?4L%Sfz*Y3``@7Tif{KiPP;!F69atY5)-&cJjXhFd?I2eel9HhFT;Tre%+JS!uNfB z;@pbnmk9rqg~r0gmJ8_PZgMXH%;cd28hjtr(hCR@f|%N)6v-}x*9@(o!}u{oaDdv9HrjQdHfp9!`qeT)V6QCV?{>2J%w<=7XH z7fvyH0WRIj**2P4#DV@2eSk#i-h?6`Iei8#aoW3hd7voALB@bBIIwbQ?;+Lx+YI_U zVt)~7--%#-B5Livb;8)J2-YR-zk`yC_TOcl1LlO7L;fQ{!j@qLa^`kbALt7Cj_{l} zgM#-K3%#Hx*rY=x5}t+q(mx^1OTb0Rj*t$Pc@MihEEYq8 zz`}=G$iWU96Mj#!3{v2J8*{2Jc1Xq4BFe@W&1q1q!`{7bZNqN`!dKZFgPAuvhmO`$ z`yT*W-Y~4P%0v#+>G+~-i;nCxNSY8+v+*=}(_F2!33Ijm6%@38i^10s_*&lnhm3rc z!Q%}65rcos;GZz~8iU_tKwPtbgTccLXjNZ?`>YrNzQGi+3G~de-(=bc7?9D$gpZFi zx^MoLkoH~T4krs0@)=N`DY1vn!D5i1z&8@%ldSYDvziqIean5LNQk(Rjv6s)kgXe- z8Dcj8&*9NH$`h}rp$i{kDm^o$Ou(IlCz&bm9(@hFswp-f#^}R`EbY>tKZ(O<;N90v zUroTDkNc-_1l=A!Y^GNrMIv9Wk=TI;AAHTSlsJJwYdgc(H#|I>a6SiLvMhbcz;1@< zz1JPajsp1k*`Mncc5~R1J%XGg$oUoQQpOWhSPV2P|<<$BH_SDmcQePoml9-4Fk+O~-=E zns~j3_)575=WcL=FkIqsEDLHUOT)fWEd4CQDeW5k*qTv$;N2eh{Nuf3Y=eCRW!2Gn zPOa%;-3Z4J(PhkVRy@$3~TLK{0>mj zB(Pri01a*;x(smleN#;}31X~d$sY3*Pxa{QF}4uOC=_Z)IX`_Jt5gJIrkxL8OwenQP&TY_Lnh&xoTY{Q2Y#4li05Tm@a|yKrhyY=4U` zZjF`IB`kODqnp8Qf^dxff)wlze#M)G_r-1ZS?I(bX8ywn- znl~NM*1kdSBdwG%BGO7ole9v~)Ly2ilpdTuc1K`iAgde;wmiWx7h9f^8^*>|a01!c za0g#4vN3Sd(7Til?=mAIZ?~{}P@2`Dre!!GabT!P{vq!6yBWyfAI4L-(#I^~HmN9@ z{T~=FD5$F0X8!>T_VXe*baL-yNX;-RcLCi3MY8~w1Os-=II}NrAH;}tRMa^=ou5PS z8DT~=6T|;qw(lYZ<^}~rI*$X6P?i9@X8_04PFl7LP+`I$6=X>}85d`sfz4$0;6qQ5 z23vYl5Nwi&<$5tQ{c;jp%97!>u#FSlBv?Vl``L)}T%bd$QeuNGExoaj{eKa)!YO6f zd-8`m=MV?jl6LXOlLxu&NrtYEnujCTBVG1)0EU5dW0F83`wvww0;`KXc5$kq$tZ_f z*Q>}KliCeNuOfI)NFTg%gx4-TKyf@T-B|@6;&N_C&?;$p4O@vIQzX_v&XI&VztF?D z-$oSXk`B-JGc9YKPj*yD$O=|c#NwEqY}Mch|I>~dtr+kea##js@osi8_oG*tL0 z%87w|{`i5e%n?B;WhK(wdJET9Y-qp{1wd6CG4p>j=|bFGcAz-G`F@l==sQKcjb^-o z2xmNSb4nC>%z2zbz`dp1n;V`%a6YDrqeBCC-2+$XMah}y7@Tdv@g_XFLHTCDf*5QW zrGikoLnf_|-u-7su#7Tl5rIM{GwIKhNZ3EQ?}R*{j-OITP{3=P_NwTVlMfP`TW6h>VD$5KCtu7*a4@<6%E(2&xbhm-W5m zPR>kWvk5njU^UeodKufmu(2YKoILU;KZ6bblkiOA_pk{?hfp~T^9<6i!f0dv0yc3B zLvS71z*ASl!8h3{P_n8N&t2k=JnWAbc&1=W9HSIymM_ZE~oAzQAU~laHGn z2t%%gjiR#f41~ZI(wMv#-&h)e8t3(jyOJC^j6=B9w(X_P7Nk(d#3w48bF{riv@6<} zsR*HblR|1xp%TD|tl^2qZMF9b7V&H5^Vn}ep&{0|gTiPGBh+!xNWl43{0Me1nh^cM zZ{Yw~!JIN+!P6Y2Pf<%1t&N4-4XbGyWS}BxgNE}^wE(IRzzOksM=dX6b)VZwxkoqF z*nzU`MHI6f7u&w+S2l3S0OtDU5{~i^rv*^TevdSk!hcwlKoGa_DKe0OZ&Z({yEcrhk-aceUnfBiour{{A&iZG1wHx z6bJtkK7E(LNe2Ik!S@;T#K9hl{We~_OzHY=JixI@FHRM+6X}V}vD{2%CRa?44Hr}R z7w@=lB0GI(EN!zy>9F$A`cpjn8633=wN78m@M_+pGKha|^d(BDb-U<}X@6B>W8g`l`W_ISy`j-~7AA|F^CvKhR|H(J=WGExv`0E1V6KA+s5yu{uy~O}=$o=M}DV<6CAMyvobm z#Hhs82TE({9FodQmR0|JoCuHO~H~ZYuYcy$_~(s&nmtiFa^0+&)9P% zsXfD9o-5Ck{_=BJ=+|i<(6^Er{RU}2qZn0R=u_i6V#-X{HVUmd zi#7C<{fM`%ct1KCaO<&exz-?v6YFV^>{|mry!v(P?%iA4*4Nfu>!I6z-1gdhuPyB! zg5+A)t;Yd6uaAx&v_ViA`pVzcL;ZpBl4_3Wxgv3n$CH<8rf%9JTZ# z6?jhfDD)DPsYl$8JrRtONNkVBS@jO3lNSv~_fRU8|DY+8ts-Vlp0wL{29f6u;(Hhr z+=5~!$p040boV%jZIc2oqJV$ZQpGA(Z1cePM?4tD=d`&uc58r0R^t!Q$=x=m>={ZV z6~?CW3@zVxzCqFF@IG|;s!Wx5w(hbD&)a-`u%Gc_D# zhAV_S#*6aL)`z|u+ND7dx`UprQ(D=Y7sYm&`@NgO<+g@Q7IZXr85VID{Hzp&u`d$a znD`=!Gb7%2M}B6=!_C-n#>N?&2*ROOB{Pe7+N*C=0eg8y%Ck0G4*jQ2!5?QV$k@m( zA>9&p$5sz_Y-8krCEqp$Pte}8l|p28WE9%CjSdxyHCcnb$1X7wZ&NezUc+pM-Fj70 zO;%!;SyOGYPgU_t@Sl^u*D$O}TF^d2i&GqVen6aY6(glD4{JjKj$HvhX@pOSdaC6p z2Xq035~86FGj8yb91{YSa-^GE%Q}HMq$u64%%a*G!_O4@HG)%@L-`d)emuwDD7LWlbAJ!jar%(nYnMN`p0u;(%(V)qVpEu6}%t(6`v{wk##B)`E8^@7C5=U*7*x zhMNH1wGw|gvI6eAZ2+3*9>u;DM?;@5qc9pAC495w8JOoghLfLmVDNK_X2d@7OcqPpLq zMXK(Tn?>ZExj#-46`=Q4#3I)O0c%UmDi?_@-Xrp_NU3q$9=9ix`iYhDEEh1Q1mwByG+&fo zBSSY4!Li+#FU}|u&Mv58i&U?_262wt#h<3EJj*;=fAE*jKetVJp`1VM+B%KwCT^%B z{4PGY6Q{VBc8xS$%$8;<=J$JEXgC_@c_7n*uvSVhB>LQ6 X^Ug5h_-k>GzOZ7V*I?CDs~`Oc-kbXwvERAGYE86nPx;VCOPnLfqZ+;N19ZNa;##VTtrC&H&5tJTQaL@B$wNmgTd2 zf=>c7`5d3($AINHyQviCAERGorO2YJ*KmasRJ=MK^4tlWQn~8*zU$+8qP*eODrKiy zx>NUdsxDu5%ZyqkKT?veYqK^)X2k(kc3*4=Z__OY7aP05hF33L zyZP?LRmZQCOKaQpGEL;gHBZ!>pww`L?~297Ze(1kI<-~qTpgvJNd!Y_o)WIZOZRr0 zrL{`cUA$8DFwy>1ATjNM_L2X>V;#hgsl{Vjvqde+xK+32)`Q3>J0hU=cE$<3TB*XD zZ@uq`dZoU8+PbyjT2;5c9&A|Nnnm@##l8AMU_r9-h83W`zFk{&MX!Fna@Vb+hm<6b z^tw}Xqny9%2X3uI55W%F{j$V|K(%7Rh%aXTNl1hkK@Dww|I8~{0_a2E>cOz4ZQzG0pY6lG7&$aSv|jw@0ZeAFJey)V8)NhM=b8G1u7YZ@G&8)T8ZC{u6!m!}w43)eqy}({rk;)3Gi#-Nk0Q*eqdz!AD?l^i6XQ zh7?BeTq`BVat`JWc7H~px)2g==DXMkY0&qC+P)`9x1S>3TYnqE+F%an@2C%%_)p3A zcsmyqHZ_`2j09Tlv+iuuIG;ejs2eElX*V#7FsjPtm^9}Tu{oE$JOeH;>X07Byu<$? z+j>FP(p)P)W3D^(!{&nKx>G-FE@-Z)9&MjG*PZ%d{AsQ`^~3o0M&lQmYo$Ba3rFS} zb;(?Juz0R}PmmX~#Z)7_JJ%T7AJJ^{lh~S0IiF$kF)jSwESKFzV3h^zs#CS(W`DsV z+bfN7xw2-FiL<=zK+-Q^8*yqK`lSvMH#|h)d4gB!u)qG2wI-ahV@b!6@; zQK?(B6)wUk3&2j7%5Jq9eUI)rY3q&$Q~Oww>9T{`u3Z0}CBv2mo=!z8P2aD>7-OtS9OYDZ`e#-(Qv z4@uQob*tvVOkYv3+i+=+Ry|Zq_Epdu1$S#$;qI_P4CA@`ORuyn>;BTqEr3g}EM8t4 zf|$3O=X%rf)U|NGweVGz!u?h=bsxR7R*K_d4kIg`Bk&A>o$F&B>1D66i(tid8xqUZ zoI7x=E=5^6_hfaVl)!wr;#uM#ov3vDB8y`vi3}PNS~`N9RB@8Z^qb$gaieHNnZ&-q zHQcR9FU&UReInyw=ab439Lq+;DsUqCJ%Kna_Pb+0OuAqN`8t+5ElCSM1TYI}_+Qhi z!A$sEro5ZDT)5I_(Zcb)~3^H_=885^&-krQLf0 zT;gIHqZV%zevin`Vmy2~l;RCg+a{S&yf!Ev^kGN40~j~(K<`t?sKbjbz<1S|SjMBc z1W--^bbcp*kbU_MmLCjVz z%(daJB2GLXX4-lv-Oa6e!h;l1O2;eB(|F1tzRb)jL<*Ed2GReJTs%rDRE>gTN`ojK?Y78h3QI=ZJQj@B=PLHlq(Op8e1Qr?{V~< zh9)~(AB5vPP0k9Q{;AuC2}fM^!3$K*Gpz|eI#?S{kQ8^B_($3ASYN*#Sra}nWg^RD zIKeZoDy@ki*UpEN)cZd3|0W#8cqLCvsb`Ys4AmkX@zmb^@AvBfXY=wQImb!H_6VPV zTM1pLeuy@p>SzOsSEW9aPtm7KJE#vy%%|D~Xp!M^cJ}3b56Iqaf0sREJ{yq!Q-17_ z{BjoVDv;{aA#$U(PwpQMwv6lewA5R4#vKmZiqvrCDH?XU4rgH5@{Ql5~MuY+=?GfM-jrYkS41^+)lIg9KSz3}GMW}L%TK+7>{&(c?vF3o94^Z5~)Odrh|&}bri zgl4>t=2=N|{16TLqggJgLwkR@P zGJ$04=)RiZ3w`)BpQ3VTK7vj8Y-|{8%ITaH9d^nHNnbz^Jkj&}a=CSZNjTay^ z0UV!Gdb9XY+wySn(2(@g;aFI}Ca>}HpXhRy5x`3c|0I~<#Xa`BwEZJ(1L^-9yVUd3 z7@PPhnB_0D=R)j2l-1n&W@x}AxUeEZ^WrLiT{s;24Lq+7JeTe9n7%J=v&Z_2C{8;}B*U`Dw89VAM@3zHi~6SM z!p@XEBpvF3j9Z$gRy-sQD)oxL;c^S{NN~Zr_0IKoEF=}m9(Nyev}v9~u!HRLb{Sa+ z*S9ttwtuAsh@bp5SCkMoA>xcO=b=R zMIYh?03_}Mw^3?T%6F>B;&&o0D^0lG$yBJ5d?OQ zLbW7eZ;$q}=Q4z|=}dwW4R-2YeA|t0O?fjmyq!|jtFPNSvA3rnL`k@7$m!#d#6xld zO56FmyVHFduv1AYzGInoAd&xL07}K{_!|i9bahOnU{8HeV6z&PQiA@RI*-6lr6fH& zi_#NpQ3gI^te~DaP#-L2bcB9XHp@=_pZeguuB#_eLVT%o4F8k=13GgNB`4Wg^whhk z=aZl7=4Eve!PFG>^u%~#XeZIzWwa_}r%`_NqOdZb{M82(h{Z<=`6v35S&g(ZQC5I1 zI7Ik>CP=4PI~tFVluF2OmCAB`ABscC;FG{0cQ(sY(e|d?Uv86Gooy& zbY6hXZ$d_^`)sH75$jSA$KlY( z4UqaIid;gGK0?8G9wAgqN;4wcBz$gLtu*C&A=P{fM~|CENFmv-q$K$=n^`F#-AP8e z_oOUs*-4I_LbOk1tI@B(|4|=*GvRM0{LMc8-`o0q;veMr7-WIYw7nO{SDM=Wmf6%V zExu$nkAF$Br?->K%}&sK9Wv!)Z**KfuFuGH{SJosw$Wl-Y2)Z%pM6__r=ZGXWVM;T zzjXB#zqRs+NrUl7jnu`L#0Z)VzCzh6>u%k>*AQ3#n>bS!>VsLOc{a`v;5|gHcu8u+XSTVApvD{LlFc0VA?um ztDELEK-aH0b+bADzJT3;x$ma7ElSWWA-`43$JU$dWRkPG3obBD9*#v>Oz;@XWsse zw{N_?a?94=y!FnF0qvvlBZK|#XzITKAng~di^0ywiy&DEvK3@`^e5^A^$WdV{B%fZ zrtpPW1(PyOOXoKQq-zddK~4AT0DKcb862*5yk{7S_)7pa)Vg^adhDu z-&fmN`Y;22ekL?wl16v|Mk~`cJ0=NH*@*bFE)T+D9)um`kpqv>Av|cL#?bTMp%pYh zo@pEzFO^RqekY9zVFBM%Oi*^JRp677LjhDcup;Bx8kG{@fKvQhh%YesG{!Tg9t+1n zhhxgtI6f#L>BnV@3FIr{u^&y2DxX1In#J7Ya5Om@PGX#M1f@BPF*2Aj200MX%*Ao( z)_)vj`s^D|nY>tx;1ErzuTQ~HPMmU=U`8L(*NSq(EHk$HtqmEia^u6_gCocy9< zN@+eP59}QuzQ`&*$3UvNKZfbF7;0@v9^n#@rfTgeJ&SczyoxTx29Rd^(sxN9 zM|#whlb=h>eKO@z*{u9HN5K-ohdBF*%0+g-7K%))fC~QjZ{rr}WWfe3h!o&82;kG3 zJ`WfALNi4P`=w@1-mGhQIN#K7U$8zrqA5uzg`s#v3NuUl3TIeT4C3Hm81D<#OL9-i zyZnrRR$M$g#mo`Dw3!SU!Hy!SyH0i69j3{(?oht;mF5of*_lRBMyGOt2@dRl6`Se> z3thfTL*TpSj_FDwHUb*--J-!cM=Mg~Qh!EfIh{`>C zwBs~)3bK_tg&24X?L5#mm>$u<38a>DVvzh_7oN$$yvRq}e>J;4VS&g8n&SI0B0?imUx-?VO zG`^sjNL|M14xHaDbWl{6#Y=S8mzR5a4@?kw55lGzwhVXRG_a^}_*<+veA9@K7O7Yw zaFM`u0?P!*!IrzH?5$7puGfMOxN0892S`_;wBpnHFqGJVH<+QB$ISOk!wd@6{{u2~ BSiArL diff --git a/basic_function/__pycache__/format_parser.cpython-311.pyc b/basic_function/__pycache__/format_parser.cpython-311.pyc deleted file mode 100644 index d2cb79e5634ebefc5ae6c888d0c397951ae20b42..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 26332 zcmeHwdr%u$x?i^*sPzCzfRMz)$O4S9F&GB$7-MXV!4ELl2HT7cmIW;^Hb~GC*p^$f zWV1KROt@KQ${Vv&c2>KYnPqlnv+UMPEq80SE=kSJ^&gw6ZslyDQsRnlC0lhTf81g+ zyH$65lS;nRtyc@g$etNzle+EZ(>bTV^ZL&FJHK=Mlk98_0nc0iksbS=Z3OXOu_gA% zqK40}gW*0w6ErzNT*kk$335_)Sw>>B9L(~|a>1+sv*NNsFq0F?%Su}L9pZ8pZJ?>| zke5~PtA<~SHqu)7Rnxh24*Y89JX#OGTAJ)5s`5X@(e_l4kt~nzx{L9GNk=<G7yQ2;4xzHdY_6h zBB1oClYL9WK29u>L#jbYeIP5K@@W&CzU+h<%iX6>*l$oVGo8h%r1?*=NpnK^@?t&- zf~BNsp;b2pDKfwDnG#$#;%Q>3YIaYAWEflnKE+l4 z)Bg>Rq#5E*xcYpKrF);iO;1crAMu{7CqleOyhldJ+g|M3_OxxJdGLZ48vNj>(@l>L z9(Q?f`aDyEQ+IsVJ#I&P-`T-or+0kBF*@TO@r`@jgQFg1(&=+dIT`R)Gj-=F)iU9n z9HyPE{!+%}q#d{K_#LC;6Rw(;2@muZ-c~S(47hE0ao0CTEZ42betKYz_#AA+Q~98G z@23hdeJW%O%dScuQMo2ulPi2eAeZL7?c8gMB#Q$x*}Td9k0(d z>A;q%+=#+6<#I>VI7<`bZkIPAABUvKy*?(Qc0qtH#_4lK6cgiKUnC3i>vJ&?8RLp5 zy;Bq8zzAt#a7H5Paj!#6KqSjK1%c8G&T2%_25F5bM<+Z^Uqm^@;NYn0 zO-4f(3c(hC7z6f}Eh#;w5!`SG=3RuEGDkc%nij`{dAzX#87t;ahjj((I$KC*<8^kV zv(I%rF0d|D1jz@x-w{!f8w1*@f)&E6uCU4-fP9He_!Dx5tJ8WUvJbm)qVh;SIIOP`e=~ zbUPk5G^2*&YnQpR=RO(Y8~RW~pWyyPXZ&9G>)j8Fk#&z4(&MH>tEbndKR&}ZokmTk z|3e3Nevliwiq1Ru?rW&~+V6(Pxa$n(^`YxC{O~Lqp5+>Egd+7aJ5onIFoRI)0#epb4q$oXY-~%I0#;^Hd+A`Z%iZ_xd8FFJEq4 z@jc4t^gX=32kCp}PCn7*&Gj<5P%@nnxqsyLNW7NF!sq=`PO4Goih4L`!7EPOv^Eqx|0M;(h6~uT9h;UA-;GTX=QZulCwlW;ohA* z>fVTt^lVCq0q75dS<#G0ex$Pnh&f$OvNokL9j7W|&XkdPo7mMEhpEY!bJJ*&Ht#yg z97=H`K2M^v6C-%CS(=7z$~m13b25rinTY)>$EeRJM#I+_qcLraN%{I(rJ9>jj7hP7 zWvS+66yw(^)y&&ezBG5+v_4bDoCO(knlt9Kq;mS=Gd<`(vn8`g#SLL)h`TtcFX?*W zI5G531Io1R-5QLq29i5V7FLG1Ewf3nut~M$MtsJemLtB}Nv_u!a3{r*0e5nJf-&L! z__!dI2}vw`hn4r?(WWpy+W6B-rD$uZ_=*!UN{my!DTlh+5L-E?Et`@w^2mzT*p$`t*3w>ROdGy*1SDSg z2$(L$wDsvZ#$4$Lm@dZTwgnUWg#P3?Vw8@6$uTD7E4inT){N(fQ91&qi!mwouN*D2M^+wT95BOviYM!=+4(u{!VIJO=EznnX%PC_pqALS*nC&EsyPZ^AW^bXJr zl{)i4d!oUs`nHh~*ObpT!hlA@IbjoYxYahS^$^q{BjckstX!~pW_$vFGiY<1lQI70 zm_hU?;qZ7EdfW|~4{x(=lyQ#OJd6!%XOfw`cP5){ouK(~H49NXCnlnL8Hn$Yt)H23 zHQyQseGM>9x_rzXTYvBI-sBLl3I@W*NyNr4{lvI@GmIP3JwED)6k6 z%I3Z9nVFz%!!DbkaLL5(DcWz;h<;kkuvF=bb?5OZ8lvW>y#WX4>+!&qmu+kU)x z5EkY&kOF&k<*mx9XAGt07BlX132L6*(HuQ<5F-;FuWPsHEuxxnW0k!-q7+p6Pf1ND zgH=2JN=K{+9NsDCh|4j?cxI*q4VdHPzB4$#m1pZJHU9Dh<0-e#(eE6daK*G*QK{5v z{8J{%eRTcCS@6OPEKy( zd1@4NcA`iooWrgO%~Nxdlkd)y3kzx310V8|KHwGyiqNeb?`-7=V%&ANxhY_?-ED3N z0I6%Ld3kgTgQnlVD^V%!eYWhPbOZosd+3$X-xs*DVtUNgHfu>k=sYoek>6eFIg zI}wU@xuyh@YSMWVRGxJawNUoBSBfYZC=VC23x`mML&y>vGGoRT3X23+9R@EFn4Q>7 z(f7)kGgVYX7482(X?tf<(2s&98F~oZuS9gwj}B1zPDRLZQ18;WnN#53in(-|k&J z@vuFp_`#_Mr}+GeAkF9R;T3z4VlSuI8&+xNJHGAzC;qTD7io6{$&hv@r`;JYvd?$V zcmKYq^ugg}GhbAPit6UO!#Yc{;+50c!<6BB`q%ZG+^b{4uHsw&uQ~MFMpQHB2 zrPat&O^9mZsHV769Z%IGs-C0jGjn=uELfa=n71^DN@{uQeq`Ox8|#pYYaxtnpwAzNMT2WQ!8jDT{g?wTVop5oZW9Y>7 za77)eIIv2hie^;NzSf2+wq*NEPL|#v|3F~+nOv5A6U>-mO?aGAguybef1=jS)2~}! zv&J)bY~6AyWI6S_(}U>rb#8nLjkD-KdZ1an!YKN@?5H^4Mw*9(cP& zl=B8bszKWigqpwiV9+z;`?SB*`)_V zymc?K?&Xcu$XJbAms4nCPBL8^Gk;LBGL5PZqN=vF%1`#erW87M9v!>F4Gy7WSHl(6 zZx5|ND{Dj*t*g_hVk;)xzOD)F>zdHMoI?A;5zDgURBT@x2MA5xyj9pglI$Ge-4$Wy zXc!Fl31L@=0NqVhKp{bLC|Ht}eL$w{O{djGvMKxGptIY9b#_&dj-aS<3i z@aH`NWR3_Ma_3Hk^NK)5$_jzZs~(%mf+tqwtL2>O2yZ%qOh@K5Vam8pSwfU$sgS42 z5mn9!Z&;hpDe?s|GA~1bBXJ!iSnxi#Aj{vvHr#e<3wl5}zr1aaTHvHwlpMH#> zsUMS3Dhn8-GGbM%E}`4SHV7ATn{Z2y83II*YUn{srpc%uiFRKNh#(Bch(3vRKdoW4 z()6d;B)YkPF>R@(>;f|6!WoI=j8ZAnS<#w9E;v~uot4Ndt&4v7?<}B|B$Ckxnz|I( z(AiP`5UmOTrffT3Y_3$Uvn2CTTAAs$k=ZsWON}E%o}e0eAA?JTM=Fy*9-9j#54A6l z>%*&x=##V`l9wzgsc&yFt54HTw(&(0?xJn{sS@QPjSWzK&nW4Yj{vQ7DNR}b=QO2e zkd{|4sb$8b18p?5ZpCWEFSe}_LfEOL1vTPc+8UAj%r#=mIG(RYq$?#lFBr$KC1e7- zYk9PaCersD8JFjld6*?v?k&m|j-)Bir*wXr5>08nvjp?swk;)re~VK3>+Ql%<)=;2 zJ|Zj8!WNp-mUF=~m-CjXe7C>Wz0uu;(NZnBpM%t`x>3~%;O8e{7stFY*}M6lz}B$o8@e7Am9DwHxPq0 zDF&0<`3=Mn-6>~q9tzz1zkwKvqIr-(4CX|;f5vf(Sos_9Z%LlZr0q9~Uv$4AZMST+ zUM!MKUD`C3yy!H-tR{Qwc}eoVePY+1&13Tuy@IbSF{Z&6Q{sy>?4=@`=yY($7-GaVWX&q^>=pb(C}uBy*)Ly+4^6wo}Iueg>)V&;CwM~c+&%CgL0&; zZX4cR&%kS<^HRfIR4x1~HXAe&bigW^bH&%mto25IEQzc>#eX4N2!1;PMQqM9^={T1 z@zXMFQF84ETT&eWThiY9m-j{soeM6O${-~L@f;c095~-0|Df_+(5}26*RV+YRp{r) zcve?;k{5I(0VrXk-FX%9;{Gbbquk-D`x+K4Qf{Bhd;Y_ePX_!(QoOhuX ze$G;$wcFdnVDPY@^6Kwo4q`K2RSVn6RdT^SK{7p;!f0_2c#X;6C6#zIMNs207%CY( zR}$NO+kCF%4eSX+D&Snn^vbcH-kKvs%^a>$PsuVkGl4gznL}WWkd7CXtOpR?=VG!} z+tzY_eIk}DToa#?wilI-cOW`u2ov1+8I1CTP3ok5SiJKK@&WJyoN0kGKEhT!(+qqs zDv3YG2}>jq_wUSMOjBN3y>pM+_M7?PbZhC#|CKo&m5`s}CVV3mUDz;PGixK$$_0PMJZ z!k)grw8sOWnsN8I_qvO=!Ct+uS^)Cc0B|$npWX3%=Q;(U+NwEYd;jEYAY@L8q z5R2v|Av6qQh#fQ2C(e9{@CJO64ufza%G=`49r4F6{7~0Dw;U55_gF-MgOB7xA{>ls z6i%YSIX@2ocOcUdo!fOQZu1sHq%n-dR=>$d%3>Wi^UuK3E%@_lpu?Xd{+LDN7HoT{ zES#G^*Ci5F?joQoN;p+XIJaP4HLrS-Yo1sAKKUS-TnVSifQSdlo@n*oJM#LGr31XS z3~9^oKJ2cr&b+SM5z_5gn&EX7NLRt>D*isK(61{BLW+X#dzTvD@Zar?8ccbjx8-$Zowv(TUKl3E)C(SFvd8T;^P#*yOFg0-E z+t|;EZ~ldO_3w+ygBO+$txSK=y1IAG`H1|tBdGZ_xA!!v=;DjIQBn5>(Mzh2Lb$@u zYhCjliw%#hB@airiY9n}Bm0Alk6&XBQ;GnRoHoIlGgUTV8n608gvt2krTvTLPl zjr_1@?E;*)dvuO(=s`8TT;Fu4W}2&+4(}*k)P@VJK~wPVO5NM9uDJM0fCliDhgWa% zm8a0o(_8_5AMe`pvnH;-9p1k>|H;6kQSKt_WWp)AaWr^?yE(-VPI1!=KM05BUgZa0 zMHj#I860GTZG70L$3I}J7IyQ2g1NFCe*n7J8_q8c8sF$y?0Q_Z6BX@U?gd~%bt_+V z1Qi{DVmG}UwiN$R`<6D?yV8crn|R9sWI4bI?~|>+Y-AJW!Y|9v7aKY01l7nloP1=& zrxRw6<;z=8c?)kjj4X#a%i#@uMzSm*@=BJ@-P;#D5X#-didF4e+s8SN zOwdJe=g#vlpGWn5>-DdM>H)no#MfU%^;fy!*--r~UmvfYOBaHNmZ#rtU8&&90MEmh z9bDbZmvx|$PObpIi4t+)!f^5>2Y{iZ1DQM5&D|k$_oF`Ed=8n8~H?WjQKV;H})~6%XGxh5Dva zIX2{EV?Vhem)S4CaD^$Lxv7R&5KekscYgXz$ z$mZ=$$lkQ5UfOw2w{-3e{TuqQu?P>9-xOx=0kO?wx?}nnX?_0MwhU|?iBdcY+y%pJ8lO5^RcHVvz*^f#c zkMZ^+$bMw8>x*!{b%Rg{AVW;&k3TM}LS;RVe7_A~bZFTKDjVT+wy>=M1C?Sxz(rly zV7ae%9o%5><}uVI)9xi3{iy-TToH$YSp8cxuJ1rd$-AU3=%9B+8$7tb%Pf*ztPnd}~;5Sj>BEW{G^=Kkt8R zFmVM3)(cug1pwdT3))aY8*ey<497Ubu`ixz3phoAuq=-#Fp@c96z|n=!1)2kh_FJE zptP_y0WWwy|64^Ir6mu7!{&!hbH5HNlWKRQHsKWQ=D@&(PVonv#Em$j0-w+ayQmDTU9dN^D686)EN{HQlf$PHc^AdKc1@z)@|J!!&p&xJL={8pJE& zO^-C}4rS=Q%%63)xwXj~=y{J66hiOGBC?wO3XtE=+jV=2 zY4w+jz*Nwyz`429i;`koPjLWPSM;~=Z;s%K0htf$+NFUXI^J?DpW{nv0nIPECR0(t zyGe2Tb!VEwnXcJC8t7sEIi!(Uz^P5jF7q$IA+wDB1NOQfNGziRT=Bq)bLsMWNlmDv zW_cPGfT*Mn98+E~tKY>Dsg92`psLcqaUxJ0`Zeu9UEXo0TjMXh#DHE4)RFOHi#B|? zsu|8ERp~_SBG%(Y&l}A`C=tEa1D8z=-*H55pJF_iK?|q*VlI)~B#Xm2>6rnQC{~ar z)U=ERTfIQ4C=siyO-h4!7D+wp^iGLr<~MK@3L$z`M|a1Wj-GzF9;(0h%ob;njGO8Y z^A>oA$qV8AH&6(3#1jl{exk@;SLBBj`HL6t54|}QY~W41k!d%ts6vV=PEi$(f4{^l zN|2(2Q-qba-n=B%-Bt7Y!S2w(ZvJ2oI@kkl zd~q)VYU)L%Uhe!w-gJ>yTtbRVoZ?a<+0g)Sy;wc4e&|Bz&;|a`C3FY^9pH-x5Kz+q zG7aF99VQ2_xP}zhIK{OvC=y3Qs$x!69M4q}v^#9xv2NZIGVkHd)yND0eua7l#DMAi ziLj;gK{n``vau7Ui;8fb`CbdG+w=Btd3(b7CF}X+q5Sf&rD@%AIAl4zN`E~0s13FD zaA*6`SvVBsKxeOU!=q@}!%a=2VTQMOk;MxDe$xz|wE_j11O59Ar4Wr=h+ji!OiNiI z^$t$GBa!No#F4B&#Z>=MN9Zk_rD5H2C}cSVtA5K#WI4&_oI*LLIO^1s1cH7fK8=J* z460#qiaQL`gf9ip%UIdg1B-MjM9WgQ!T}yGUi+9dtz+fV;RmKCB5WQmeFT6zo_KG1 z=oCKyn^$bT9h?e5fl*|CI^JYZN@}X4X-t95Djzxopi+_wuzT>B&O@>L9aqAl>%H|lx2Jhm@H6fnQbHUl;Z%5 zlBUEIFEperr<99M1V7n0rwcu-j`_2+{!$J*FjK&ZPl#@MQm=Ppt^j9zN}9}0YhVfu zb^iT$y-&lZP2~fbXtu6hbT%vtRKfUUEjd1^T(@5qz+^uA{~>8;(?}!dYm-L0cCVMV zEoK~~pFyqOx`g18UZu1~(8ff&$|T%C4x2+OX=8F54CK%Nod=-ImbC()ITC1HYP|^P zAxB&Hc_|l2e3j(5khI@RDI2L| zg+*r`tKPciBrjT$4uI&{yhM4h+GLw#872LaypWQ9NxJS9Kt>j63Z>-&B`MeqkK19e8L&g9pIY5}Q;Cn;1klQDMe_#!@2mEbXPCV7L;0rup zAAPFTUja%n+TrxVe)cfjN()E%lkq??wgXVJy%ua*Pmhg+qe34&{e32#YCc;yLBcvsH-Ou6iQkNSWT?LdWu`_&!d1-ZCDni)LQ zY@L@u_!2FRD&gX7uy1~(qeO5s1GvyB>MJO{Jy3_pr(|Un$+&?7z$$+ZKM@69K1mjZ zL2%Ac3hh`5c=-#(S9C>uR~1Xge|Y+>(?95b(7j$<6DqD*_VUFqq2ia~EP_;?lGT1C zzCNPauN8i5?O@sMR}59#u1MsAduyS#egcN5e1@$b8?`x>bvC>q|Iwxjan~QUQMy(NAXsluOB-QVH zh7;bP>-!e~c`!DUue}+ny~#~X@wL;ab{YcV^BI)SY{(_C2w8=U2oNu< zHUKl$2M8L4s+<&b6f=oG?Z#vmCY_kzb{C=G+_VUqii;-+1jPVLRla~$VrH-(+`0rH z5`**3+{UC16O1=v{FvMY643!P3N8kMyR@0{+Yx;X4is}xkGnB=3Py;C4vewS-iSV` zK7osWz$ajfsF*_}TMR7rL4gZT%yJLJ=*L9gLT3B zU_H|9n(IiQivZ|jwk2+)e!)PZ1ImlAhlMG?hy(Tqefu;&Jw zV}^VIK*aL;YNW@v7XV9X8GJ3?348|26TJm*)pe3f{emL0Pt?wdZY%V!CVx4(*3LJb zL`^3}x3IbBKKmvcY+ue56)NlIhLE{orEk^Fn-3%NVbGu$iq{QgAwyYol=l6~|4aWb z*biB-^X118P(ueYbZ~|aoWf!p@F_B=_=*0<`V~2EYe3@u;VE)8|1V2^U9tvS>|N+! zm&lh<0!r5nyF!Lt;^?|s{+D}xz2`5gKduHlUwHxnHJn6-lbqpXOon%P{a&QsyFnLEwPgCq=KYx?!07eYrba2H?Uk6uPcFM}IjbOjY%S)^jI-s1I@NME@@XtQ&l z*vjAOTQ1=%4&Vz|WE@<;Vmq{$4dGcTLw%WQ0gabwBKRVj91>;$b%`-nVjnb^61bj~w_O^fM0b zniHKF=4dbHn?dv}ck4Ey?|?3nCA;vBX}9bQ-aci?9&GECofU28uPJAS3y8YFQ5V3Fw4+)j>tvdsYp7Cmb|x_C(7S#Y z9GQ9ih3o_^a57cS4BwN*%+dMoi9Pg71U;qOD zE5f9QhKu}LnO}hw59;3k1O#@hNis~R=Y)S@Vkakk!$c(~eZ#~)PWpxk?VRv$BP);0 v-yq^-PcKPA1KmVvnUXX?Wt9-aEJ>O-h)tA{mta0AIZB=oxVNLiK=ywC!Ct52 diff --git a/basic_function/__pycache__/format_parser.cpython-313.pyc b/basic_function/__pycache__/format_parser.cpython-313.pyc deleted file mode 100644 index 483d4d73676e3d746e0cd79f14b4ec2c499d209a..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 22219 zcmd^ne^47)o?o|qARz>jKtke|X#^O{#$Y`B7-Jh7V}96}A66q98+$YYnjdSBpe2lr z?HS)yQj6~{J2*45gKu0JT$yC><#NmHZEBg#R1K5LW!dCXr*^Y$O)pzlO1&!@DWvypE4@BmUhC5T!M4&TZ<_udhIzE^tXO#^Ur)B=bLUFF z*j=cV(270z($mCxu8Ub$gLN_M%2uxSB$4{PFrG-m0ktGC0&Dv_3`@ZX%%hHL53Lkm zJBeXFX#Ji7sVroQ6h~ljHE~MuvAYxN?(%3z;{mll*^^v^xl-GCoTgE~{t8FV|AA}W zfIv~GQ#?T<<6R65gW5q1E5h!p+Ij3gb|2^QD?yZ&;9JXw0;5#; zHjtfho~xmufM0ml>kAI|c?Tx^eWBrizkDD-kNQI1F&_=7%f_zq`a>hW(cXSv9iLBA zzJBkOt5e>A;Ss9r&`1ErZ?FzLB7q4q59-0WiOp9msQ#=51__3JtG*E2^OXjCUkTMi zWt=>h+c0!SE0r4u+Va3AXogxN0);%Bzc(k82)lRoREW z##E_?A17n>k~K4?x8GIIHok9Qba|{ zdIGZ+PHNDK(X8@)A>@>lSPn>|MHD_Lu?O>OI(GJ;P9a=wVRs>g#5>HtGJn6uqjAeC zUW`s!dm(@2@+ZZWP?!A+@n6>qUT`&e z7XKHWE#p@cW&T3RU&?CoEdDPlpZL|}S^QsAK3k8M|94gsOlhu{M(9gfO%!u@#%l5` z`ps99=S?N_EtowwU+I+c+Qw=!xWkq2`&&5e8S4A`sId^yN5hHOH$n(Hxl#h1nuX(Z z-|zr|4yQz5B4nrqJ%DdiqOFxYVj?*<5TN^q{cvIq))E7>ua5}O1Ul7A1i`DLwL}x> z3aDBk8{f!C^ehi~93;Bv395E-7|!@WI7)@+t3+4F@eV09IqhbTb>uq#@N=Sto;5(rYeK)axu z@FTUVpVta%R~|Ps(dcA1RpOOI>r(HwbkBZI0pF9Hm%a5b1mjQ~C7@D+IbsF4w`sP6Il6yz=F#}epz73dO3`V>fLq%@W2eEvaD#`s3YhBoB7o}`xw zZJ;35lMMPsM>o)XYZ3G~qDV%3z0`=|TYKE2(A6;t38^;#5gBn2Xf%OxOzGiOrH3HK zo!4rsZxF<_+R7W?R8*I}I-uY&bWQDyHOr=b;o6Oo?R14}H>R}LAeS5ME(@IplBHe9 zDF6pFZ?TTXYx)96rPxnVV}egN>bnGLunIm|XnQo<@LC$$gQ80i2SMePBo15}?eTd% z5((-$G}^J!NVm#s$a8INEQ z1f&;qmQo)P)n#eINc2jH=}sU+^$NZZx!%OqaxuM$(`7Td?Aa5nt~g@KZIeI31wwpP6E_}X+IeLLTbdhmf48PXP zIC@v~egA;S6=hkzqTeUWplU^5Ee9%A^!qmttXeW>#_wj#Uf}Y|nY?l~Yu`gl#dLkd zXnm{WR!2Cu@xSaSx-&3$dT#t4IiJh94l=HTi}j4Fj&n6JuBN5yT=O}mS!k2y3tThB zG*jV$L8f^qQe43m*D%F3i#StU%M~{;#SKgKOz|_4JW=EM-^0IFt5PrFkJVWE{xt(; zc1)WdB^#&v|0GMU#4*lsf^nQ!IeCFQImDbC3J;HQ!`B)3oxBl#e@yuIU`pbdGhL5BGQ(hj&H)x}uCX^w-ra z=UmN4hwdGEn6>ZLiRp%@)MlGMXus1w|H}OM{i+2kn_ImQVsZ~YG#-LVWmiO5haIS6 z>$sz*nWLw}Zjw3L74|?UKg%BKW*pt&bEtc-=wFjd`~N?=qR@ZWj(zrni+3)vS$iH@ zN;m7)_8;Zl&s!L0Tn9K8bnW`3lArJWd@rc*nPYD5SPyfoCw$=|bL_Q9ap^}F?_FG| zV2Z1_;yR|dZgHF`-jp15=hubq>=U~4-e`ATtHBHz8x`^hv zYA8(%R_?9`fEWijZWGm%M~3|X79^x&JOFp9c9)871>VAX1Aeqr@JZB_5MV*BBv^pN zeReQL1dHA9Ku}nf=#z-rfOep;13{0GYDQd(#FfY-uSC%lXMi98|KKa&+{7a0^qVJs zm60>8TeTI=H7^!1wj2+7 z(TCvM6l>Q4VvDE}E=k|-PXf{E;IH?nNtJ{ukbrM?s7Q@JIZl%*BLL70pb?n>Sa9rL zc~ab{sS++y7W=MfjDTOmQtZ*WWn?~<$P=g4Z`&7bCPm5^Zts!HBb_e+y-k&Pl1Xut zGBQEx5Lz5Nhc@e4AUD4e>Yob#H29m!F;agTkGd+qIVOSbL5~qNK=*av19#+)*&SqH z%}Fi$CAsvV))B5ay(?H<^rTDAmClM_UtsX0Nad672p)#ZZ`nlJGU6dg+Xrpa279D9 z8CWDKkx4=;BZ~9T9u_hQ)q^w&PnV%~(OP;8f(}WpDVZuz{1j+){&X@;iO!QQ$>kuF z+4s+&1WIx+ZAT8z<}b?z>VYn)cK!@cIxCT+o&}GHx_JL+%MB5qc_;J7BvY@h=7U zzi*H`4|20A^-0vq2jTDg#U_#!Mo@xe7dEtF0u^Z!MjOnnvHsjvzDavzzCe2<o_`JPr)vAT;Qq%Atq3LAtPQsqJ$BR{{Ts$3p!6_O6L@ZS zL<&5C1GF-kw2e6x^|Bvt$St8Ix2T6h=vT;cAl+~67*8nq^YO}x<7Mk;{hfFTn3-VF zknYdyDE^&z_?2*7Oy;IB_V}1`HM!$}N|Fi@ zHz#&qM|AOf6>+AGwOwqVDD~E6O&bR#Op!{cDBJI@by1Il6M3_?%_+FWP#66NNNF#qV=}2myrNJ3Cp<|Zi_wllu=pyL z`cmvpTF5kMRo#?k=Qe0;WJZinb5a`|-305_R5*tX!ag@EZe89P%U|-rn9OoZTU+TN z+&w0lYqlpFcpU~_Ii9o^8ueN3-Epj=UQdqLGvO|@Fx&;iD{(Zp9T98PHXN;)ev2{y z&YyTeEzR9u!Wp*YfF^F9z{OS*?0sSPBwpk$iGspGeIQud3GJ$DgVUkMNV@jIab7$p zMW4e-qNhSmzuS!t@d=m5C~=kU1lFu;14^N_^q_Ru(@yIk75yqWT}|`>)-K(V}CX&C>m%4 zWb$|+Shgalh(?2f$GtD5Nc1^ib{GCn#B3Mqm(qXMEM|*NY&=f9ynJ9X!RXhJlYtzt z-9+$F9JYbb6sRG9fM7cemec~gMb`q|%PYi>P}Z?R96~`DZ5UmL9O(TReH5IRmmy4n zQll>53+^B5MD0;jj{uPvhWZ^6$p*vpEjpCYdjA0u#wrBbGlcQbE#UBJ>*2tJaBo7m zkGF{QaabO2j(>!(yY)l@RQHt@WrkmHY3N3jH#+(PkQ4wQ+8+|Itf~C=0D#BC{^8&d z)lUFM6)F_~Y65`ezCb_qRX?Pe$_D@p?8hhi0EMQ4#E=j0^j?bc6Cp6Y4pIHCY#MEO zX!PcS2wNEdf2Tu1gnWy58Cv(~Cgh+^4_$3!)Z*HC=M-62{TlkY7vA3KyTt2)LtwZ*F(!x!05(x*0-8*tNsb^{V{9c{>aJ-Vw8qogNQP}%4|wIYhPPEdsOkz^ zJ=4iw<`m6&7RK+_E$&&yKks8|Plk(IxtvqCTBjRVv-0i^FvZoMsy@^IOuv-5Y-L|< zXLoe4S)J1-Bc_bE-n{kZY>+b*GN!^W%+^(lb>{rr`Ev!V1qGpe>AAD@xAW#oSc_|( zTEahWU-B@OCznq%mF-MfN0=OE%f=%+@@G;anOSqTg^GJ`EKp1dLdHrCFJ59wPA~;0 z+02uxJ9qziHM76r-@AX_{qMS$2f}B1!b8K%h3|zgjWHL-!s9e^0c`!>U@pADoO$yJ zhBb-6Rp~JVUzIedW;!C){5i|*_LE=y~_@XMj>mW-UTCpBpgFdxY1=qD`xXc}Eh*6;l zCN473^V)nPlUK1&z~sHM(8}Z;VeChjidcJ7#8Egm&Nz0> zrhMTbzPlc9?4#y;&GQ{>(V-Q`;Wg9NGRed;@@7xp-aA*rrthAgSh9Tl+7iW7wJ!UZ zst%^SlPf>Nlmq5|o-OYQdxLCwFk;J}!*5@l8(gvNT_|3v|0HeM#=P1dKJ8{+bu;@( zZvQ!EKVZQZ+5NADdoQo-m%GWVXMX%%-9j-_i16UT1B-i@LXch)o7p4@lR|RzF?lr* zm)FSHn>hO^#(rvJv7lx0tAJJhfyGuPznRHu;j-G9 ztoG#~n?;7tjzex9IL=W@pW( z>7O}%=2+UfT#T^A&ey`;ex|cOObsxd1K~lM>7<#HLH0n1&7WYiFSD7GYZ!JNKaQpa z`*GDv=~Y?pow|MM!(6Ul|AT`4i$|7gn1keMRW10cn1iR;D)$NrFf^d+e&$?&CCApq zt*P;Cw6dM0_c;I?u0^gE`A zB?nFFFN>s#wN!19Q6+1s+=Ap;TpY8)gvZs~Bh1 zLf>K`>#U1^(!bciI*+b@a*TBzdAs$oSB2%a;ZI&yVYck)-ZR}ceQ4x}iU&l+V$o-%pOi+bk1bVv99X&@?(%&8X1H7E ztlfj*s-eefya4PrFx+wpf1<_;ssP5dWdqusY~}RXj6VDBAugxlg43;t7{8;DL3On|{!|bpB^BS# z_Q)Ez*a;Ub?UmBYlmvQL4suJp(0gbijx0d;N<0+bTY$GpG-&IO-Y$vW1^luqTF+R? zWE>e&LZ~JD8_K6#14Z;1y+IH4ku`_0T1oseFv%yCgiJ>GoD_P_{NU|;$Pbih+*8#?FPKbh_N5Gv0I{@3J zB~fm3ojZLazi2;DjuZQaa@=P9<@Kl*Dj3z+sD*cVMAUS1*`>v`?B7q>AyrPoCCg4J zPLiEc+;hmT2ej7WMVgc>P&~LKIq6Q}E_3^o`q~Y{l%}F=pJ6KJ8HTVi#hG9eENFJt zDNxX)^9uZ*90#WmBfvz)gDi+Z#pK6HOljJR73nheu9m>sK&^q4cY}D1dH{E+rjoAJ z)>Q{@wBN^tl0tVm#Dr) z43L)LE9@o~&h2<`pnai!w)+F`9q;^UFhvaKmBXWNo}eoI;SK`dsGdSGFT1jORN1n9 z12@{~KLYmjAA=KLP5Pgpr#}Gaf1wxx9I^b}4-TGP@Xem*^2(UJviWf?uOggR0S~{S z-vT0DgKUAPbjOEjIKUggxERzaQ-%g`so=2BZt_4eFmr>NY zAniuf`uH+%;Z4B+e2$>^syF(50(uM)q=0?D^n_233wV8_feBC)A+=FVc|~WTUnVfvt7aS9fz}`k#XoT!_ji(yh8r!?z0|H$gZGt{k(mIG;!m^pvDuhQ4zXdxR{)9%)iJjg`?@XU@EL5zZEstZmnd#ubqtpIy=9Ma*{2T)>zMz=F)2 z1-8qQ)mS*IhRdpBvg#H?a2%Oj3N6=vKDiu%Ey^z@*{rh<_1$1|X~}-1N#Qg(j3x*1 zk~zEIJA3;qliRUq=MJ1=4xCy!(9Y&|u(pn{`wU|{v!Xe>fwp_GhCAqC4tiD&o@H~p zSz9+U8MS#=G_OZMD>A_9a^lFU#oH31pY*eW$lEjI`*#Z5UuCdY?#3|o5UvXh!rRg*;Rk%1u}~T zdaA`tjTjkD*g^ax*? z>249<07W8MXV0fmi)>XX)57izaxA7F!7S)nSq^23&Phlf-2~fu+#LX#dZ@OdB@eMYJ+_yTA+g}zxj{WiaK0CQzUtppGjKBXlAg(G5zou(;WQ-JP+ zS9ZIm44qNzD_G0pJv@F*4JQ5gCP*CJKNed@~O2 z3_{T55O7KQGKwlk4r~rE`U-NcA_rDR5m}_KA>Vaycq4$I@aijkL5v>0!kZ-EPkN9% z>_;#uh=wkGW|ATHznL7fUBVfc zgcBkHkNST)`tj&e16SR`RJTOzIqzM+eSNNBKAp9fa`sBbUb#RnI$8T+I5C)WIddUn zE?fxxbn4@&pIm2)j{o@jLcxcg`HFwp{qc3C==h4c@oUIGcTV$>>7Hpp%@LIhQMqXS zEbo)Nr66~pl{o;K`Df;l&zW~J=AEGN-Tle#PfOX7=AV@=oc?p|yzejcpOi8s%`4^> zss7hE(;mjO2Q3MMU{Gw*|_y;HNoP7Tjms`fE;ocvlWchxSAKyhhtIIV5v!jqr?~P!2X{mKkmShcKvBLj-}g$7b?5o z&wFSokzYBx3su_tAb0P=(Bk;VmzKTgdv9&t@asONtvB2s42LF|{>$OXD@^~@HH7rE zqBj>$soIdD{5sx_LLI73F?1S*+Hf}t<&mndF)X!91&6=%T_6J`oja4$#pykd12|Ha z)3~TC|K_*#I9A+FXOoVDXq@KZGU$ZZ-#(yNd3QpXY-r@ diff --git a/basic_function/__pycache__/format_parser.cpython-38.pyc b/basic_function/__pycache__/format_parser.cpython-38.pyc deleted file mode 100644 index f13219069742d4945b050cacebdcf4bba3c3f806..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 10320 zcmb_iO>7)TcJAu#`Rn=N@JIZUOtocM;@DbBava;?%8D(?gtJ-Eip5=L+tN;kQ%!Qn znVw;FQzAz_m!&tS7*T>;a#;)mB!GiG4@|InzluOm( zj#GbFwfUx9W;8PT(IV-(iOc<86rSQ^F^6i4`K-;3l$PqLzRJ}(<&d>BPXnyY(c<)< zhHd($(u#SyuO6~n%Kg8%r?%p)gctMT;R3xxJIND%;z&INMl0#1{Ny3y$sJW(_l$Ph zPZGT&)~%%mYDQS|W9^KW-Od3UvwCCUs6!?pjEEp$AF@r&kMq>5!qcB>M~o@E*L@AM z#tzwzCjQ;a^UNVyL2cZR`MR(3?5A2Q<)^Tzxz&IPLHqMW7EkdT*e+^Saoxq`CQ;C4 z?PzzE&(zgG`AqqY1#Dl?hRWjj=H;z=&AzL|vUf!sSTFEOmN&=&n8V5o0TD4SL=cVgop!;iFO7pC+OT4&$&@66L zs`k>gY8@hWuLB4%kF|U5t54_|eM0O$(NGk$AYoVSn(cT&vRo1#jkgjduU;!wc=N_b zCE-+@%@@tJE!(Wx&Zf6z);CP*cTHY*E_x>TDsPz{`klSnx-EMBo0SK)gBfBHJb`HD7BpE??V2HUlj*gV&2GKI8h1`k09?fvNErUcg(v z($bv?PjAP3;I!~EwBtt#aEU@(jv^@0g_2z;)rE{Mlhr@Cu8)rF?J&`cMaCCE$iEL4rYZ5%>{ zLLIKQVsbC%py;6SXB6u5!O?c63uS2qeKqv<)xh0RjA%b*yFc@xX&}dW$WasjDk-0C zr@h>^Ml!`-VC)g=N}P833}!~(P-|~}Lvn;dRkrg|m7fn)xt!$%P=P`R|4`%|`j6FJ zOZJkqm9LR@r+<_*NZOtLQPLo3Q$NOjrL;T!qv(^gJN={R_jcp|BW>jw(nen-?G6-5 zyH^9h;4Q?;@}9J{-9I5|^97itv!!3d@E9h3@0QDU!!yeQnzd9lrJ=uMlJ=Eqxm?*W zNy(Y@Jx}5?9N-le&Uai(2Qmtb}mnC`Hv9J4iC9!FHZCd;=_xDZ%Ae7p{RJaG$)JEioP+4a<%T-i{Y>!vFDKA<^(b6*9FX25pn4M=EY)X zi;HffRJMzoqQ2LVmbUo*cW%?7FW$L&(P*CUA^x7@71v7ZRl8$7qpi7WG~?BJy-_rp zXL{K7?q20Vu~B!6F4hDO!8krWs<-5lA7M0;SIzxP2bTc7SS;JsYB6*nilq?(kS^8` z-uei59FI%iAl#B_Y2B_G z$1{CR-a*5rOvDEUneTxR}Mxit|*QL1Cr)hzEMP-Z(&@V%rS~C2FM|xK>w# zl$1SboggM44_llg3et^=hcB|2LQ9aOEulw8kee!Apf>&Pk8a;CB!fhx-{2e`)TAF~ zCFy-4##Vypt6%^$dbwodU zZZf#!qW%TT>qeI`M06Bg#<9!gByN2ifYFtZ*YIbuXaDj{NtEE zj@5j*`?){P<8WIt()IqE?kolzb2$sohCMvdn&7$Net&{E*-1Q(i`V4rV|}xCIF4wA zxDiK_{sd3FskA1%bUWiu5)Vi0F?0XiAH&{7b*4DEsGgvH>0~Yp@T`jO4e*pFe%e_+ z4u_IBNB;=UfpvoOiLcmw@U%=&dVS4=4|<=lx-xx_dqYt~Ai- zJykI=KY-mYA;LyKk3Zwtia*weWZd1KNY1k~r}%|G(VBAnfw_h;=OpHT0`rvf;WIHQ zgwe$`FsGfxXJVd~nAd?h)0Ks`+eCd9StfcT&pe+*4w+9Y>AXoVJm09+q@Qi;-q?2D zpY*4Ydm6|66osS`n(jg~1VO7Te5%r$)ZsPKDt+&0sIBCxRv*WAUwf?Yu6E^r7VRT@oJ1eaixTILPvGSHaL!4bza7BAoV%7GI@&t~Yi^u%*PF#HMr?>`<__?3vFNbflOjJzpah^L;Bltoee2Py}J64^A5q!Tj z2TeBV&uvfnV`!iE=h3>8o>CA1P@=w!Vp~NasBADvQ13OK-{g$7U_|tyg&%vl_ z{Nks&lo!HgtRpUd?#=MUL-rFgw@8=>0eeuG#;y?T&GJ{;bAINCQ4Voe_(_=GS62lY zJ!$^dR>W0;DN!gr_2)|;3QFvRc9+sSLl#)%(DtX)P8q-``PcDkPA30E9$3fhB`YSq zdhuJR2CVq4aq^cPPNeUB%*=0%@?Ui+rvSkp?{hiH$wiFPO?JzuL%asOZ*BUYcc*`y zpx>tA4JzKG;w=>4+Eo8Q6iJ%j8pU69D8m0E2V1<2!paTCVng-yp?cXG5Api~DJ$RK zM1j*Fl8lp{&tp0En_N+jtMC8w_^L@?I650J*?(eSs~i1TUaZk1(dSC1yO;*_o%h z>^qMLuy^U8sH<-QH-{+Cc&a{4YDtVYke8oE+sP{l+8O!J$rwmx^`GX|S@0M3A&dV* ztSQhb&<)1J^PD2mOU1H$mIuN=(*G1v$w;XI>w!px68T8(w_%Bnq~|N^p?9HdtCX&X z4LmOcLZtbD7vnlC9}>;7B%eP#0bBHywiYG!5qH3^mQf5^V@C)vDR$!UP^2BCISC_8 z{t+>=-O|!i z5Iw%y)E>2rrhawliqSm%jEqaqxJv^l$W{%gandoA)NtvdEQ+t(&$k#?+Q{A?u^%dM z9aMRWt2X10mao6zwpKr5Qbl~G2I|rkK|%BIOOv{`X*>3QLtOtaqDvF#IE-wb55sf3 zc}Uj|p9gk^oHUf*gvlf^xpZZtwJLr@oLMK?6Zg>e-?+MriH}fgCT>*(6oX;uluT|K zH&NPdrQ{gR`Huv&2$X*}^K4T7Y#HgNLMGJZq&1_I(gRR|s;cAMa{Ztf9b}5zK(NwC z;C5Q9iM~p)RIBg70FgQB8AWl5s&`Nnl40^A-N!?iZc05Uxs5PAa-XKq<@Aa7zW?6s z_g2>|{oS<>ZVwA+{DdI>14G?6P$;C$pc&%O4d>)VC22RxD?wL14YN;FhbBTjr{z@f zV~i}5pgKP?>|HhSJ@j-FDmbL4gbGp`GE{D9P$1#ayy|P+GzyO9kv!2j)Pi^gr+vCF zPx*EwjKBn#^uporUPnzw8Y32^QP};Ul?2TcE(2E@L1hM47O4{rj&u^QGfFxm>Pwxl z`!}Lfq(nP`bWb=t&J)Ey@?i2ln6daDGf?vBPbD{dC9@wVI0an|UH_V=Y|Lpp|z%QCo=Nx@-h4P znBD(s4R{yNt4j~=oKX8Z|*I$E4l$s?@ zD(_`^x?XbOrL5!231o%)qoB^1kvjS7ITEF>;76>ZlOiG@y zoRnNvxeXKY+ZqLr55A*0iZ-Lbj@iWm6W<3GJot}s3v|+QLs~{M@peBG$LvxwMp^ge zW?J6No2V={_4}91k58~eJW6&boRREYA$};%zUGzZk~W@zj^ITR%*cmQ{1KpIR%oi1Oia5jSP8(jue^(l~LKt*N)FxdlXQ3|=CSVa6jb)Bc;92GJMA%O--B;vymrOQHJ z$ngPC3R5yR5Pv`eyKcYukgBw=qU+s@HLCsqg_S`dg0nw-mLn?r)>y~+?lhz-4uvXs zYpo2xPPc1~MdlM{dl>0#g1_iFwz46ysKe5+L2kI@cWGKOK4Bq0a~VG&nl$a)9|DV< z`8-USt_{2K)OJBl!np?oM?Si%3vj(>HCd_|Hp?>V6mX>3RMp=kV-svyomRyTXcyEK zak>AludK-Yi!2}{@;wBxQD7@@CQLv7)TcJAu#`Rn=Na7gh-lBJd`OB`FvNE^p?xUynd60vKnXvfmp**4|LaH>fT zHPbV!Zc5~+=d$$P;zR**$z?G#kRSqb&>^|xU?4$GIqe~bpic`VhZO?sVHXJkB3R{n z)jhu)*>d0+bj_=(SFc{x>w14RFWPit`zs;>KspO7K~p<{3aqKF4!>0#J(2 z^8%j&WbosBnjZs{=IpjoJnUb$4Mm0j0%Q7=?B?RvFR zu9faNjR!TGZ`c(^Bcl&3lAh~$-2X-4DLxjnueO=bI_yActDfqsT%A|;SzGfoz}h@5 zPXB4xrf(?in5X;dKD(jZ{fkw#9d9SRm=_Nh=p{Nyp70X~>OL^qNiXFm_Zd&#Q^i%! z=%oE5(K}$hT570fgf&0b$#~hV9I!E~HxZ87XA;7Q2mUxk=L#*Dg-Fsf6x!Tk@u zQkA#f{kSZgs*RWSX9P9FI z5valrVs5im1tNG7w3;AQbxR=wLA=}qQ7&kKfqo6V2C?;8qwEE-rT}>>d8<*gD?2s2 zXap(RENrZ|85mWkx#P**3u2NRYa%3BYOI&!BE&3ciWwptt2*4?CDt4}FeE!Af~{md zqTZLh$kK@HkbvlYaTic1%wS2DVbe@!d3GG%%rli9HOEXA|E-~CSW?ZaViCB@f!3<* zR%mJ64-rb^k%#s_(Bmsx@;wxJg4!yMq~>cK#^vGNXS2{nvp8f-@fp|esgIcWH83?E z!!f+!D{b8=czP%11E-B6(TN`@z$FTCIf|e}4@&l+R1Y$GP`U?YpgQ!e1kXZY;2GkB zDCP$GV`KU!2Kw`3`X{4)j2%nAFwj4a{?tJKIQso{PWRSytOw2XpxGWYN067~StuNR z%h-n^g;HE;$K+nlL)}5|&nndCgQKlX56aRC25RUXsDZnK7}0*jwtwbB-$0JXAxBO8 zLQ+1_Nqf01jbw_wz}N%UlQ`}4Nz9DCk=FkDM&t;!s%+(@LO&NObUDjQpaQiH{-Msh z^dG6)mh2^ID_j?q>a8v z+FdA=cE1LG!CQ!x<$Y;u+kZmR=BHqsPM3cJ^JAF!T&q;b)bRWC39VrD`rEOWF-cX+`amexeQBV zFUcv(wOV8!FyHsgJ7ULPdQgSk0Ak(t#NOawvH~DY+z}Xd3ESHd4*06Bm%2Rn7EyIf zG6rU75lwfqu~Xybnr%weZtl1+S900wjaseofNEVUx+KZdj(y$7D~k==dz)<2!r2>3 zt5CU%;Na}V3lA<7zaf>y2cqiP(wrXufh{jG+TH4b0zjKQgec^-47mU`qKH?7?uXLxpRiYQ7VrS zfOND*@YcrQji6MJ*G*?@sB*P5G0W*tUZv%Mo2<}k`exn4hlb2uq|gK$e~ zCBbfP|F@l`7A#aXOGCk6pdli86 z-W2!^C)+y0G2Lefx)_2K=|%X`57-qV6OkZp?|N{Di*f8%Xd>J` zl9j?vxN;K(0$as0NP*(Opm}ubCVE>J4rl=1rR`PVxh7c-KCHNjnz)3b5~GgjL+?rk zk6hHxvAkYf8AB{bab+CKOyI#nxC%FPfg=iYbgs8kFw}ABe4K>yVYJhJx&v1gG2-if zqNDrL)!de_7b8sRbj5ibwFF|z#GFEI*Q5B|kJ=f`kj0b3GqIuYlfQkF!vPf)*fDvuBGtcv9!o+gQ(PL>zoKoaNZAHzAY zj&MHp9L{^Vqns0Ge7chZw+Sw#aUkUj$M;`{c+|uUKQ_dr+}ZmIoZ^`=+@W`nyU~$h zAD@w{`-kp6!nqYv0<+IqM~~|0I))^mx}v|XXg#1jvYs2GJoWZ`4qR!V(SNIAVrB?? zP(p-_ejfjdcT0JskI1;Ul1R?;G^hB|FSMr}e`u~z%rFr7B`}XUAATj~3DDF<0hlw+ z!dGIRl$cAvobAa%+rE$b9CAx^B42tjiR?3<_M>}}PQBQu_LQIP=-$Lu-k{SM}bb`1ax>efzzh{7tmS#1d#va-2!< zoR|6|ensiwoRK)69>GZr;GC5>KRJRkIe_!B#QEzX9L&iF=Q<~G{^bbH;9RdroL?Qm zk#q5wpT4h%Cw_`+(1=g`G(R7nW|^2MAgMWVphob80eFT_Q#)3jg#rA2dmiU#%Aenw z_6h&Ee;log?Gt`3lFt0rG5-YZ6-ulE<>b85mj`y5cKdZdDe-6gyq|+X)%e2ax|A2f zWUM1D{>Gc-#eMb@GPOvM2myQW>I`;;U~i7U)|vOA=_zx#E&L=*@5L2C22Yy4JImrS z!IU19o_SG|?fv4#vWS8*JE0M#{LY9u7CE&28MRXiFv|XQ-I}A>KamI4A$!@1Nw;47 zE~)`5eQTWjWtS7_d>=CNTci9%k8%nS9P$C5le}C6DZQMx3_HYkf%mOV|F^yAUnl7I zsCa{lH>r3F#kV%qKN3Zf=C?-ipSl#`43dW}-bP{NhJ&$@`r1gnVoiql1Cf-KA8ew? zX&6ZcN>AxT9J1A-E|zF+^7pMorGYenIHQA$S+fJMZ6k-$a;S zX>j`~$C%b>gdfNV?^KW%uw8SrjLgQGZ9ArijH+kz;&Cxg6C9^P1`P%>A;NPh?+`{S zGT`_g^`D~RG!-TlXQ)uAco_xq%ATk;1AVJebz}lryh45FsV^byW({$*b#CwtL=N9^ z*KJWkAcGh%j5d*TmRlW$HHqt?-yNiG?oz4-8GUhq$X-EFJSiv(Y-N!7$J<=O6;0({ zkZ?DVir#6;Z9;exZkmkcMUjS5Xn9GmrVyECSyB)XUW6_9N=J(_`v^qfXd|%s44KDH>;Um51yUR?i!_WhCt;)` zjQ9}@#xpRE|0B7{$}th0yYIB)lxYQ5ueWjax0UqCjD9IGiNC~FN^(fQ#3TLox%_hb z?;H>L?Jz|8RF0Y$qQBgu|5l$q?levf%wR+_7|{$ye+GmSr|?WWIWNug;0NoLrk^6} z$(5G&ux+&T%Zrzc)`_o3x%92QG=PF^)sPw|9Ya|SmmW%^_{!aUn{lOs^!)+*NP#b* z%FA4}6@R#N^$oYZ@|Z~#@mLMi#Y=)B=+Q%yy0T$A_HI*L{V(d3uE4o4PRFT3u6Fb? zu#0}wNPZIrlf>ZSrS-jm3wy_B;_iL)i-or;-Ihm*F9a!L0=1uCk63(K{=QgnkUR)Aook-6=*ScQQ~ zsa$XD!0?b6>Kj9GjH-W#qL>UbAn5@f!aP&ze%Wn?Ig(Fl3SCa0xcS4Iw{EW7vGi+q z-oG^}n8_nT_)iRV$#YOhi$NpAp%c!^hf30Flu&|>dKRXgs16N;dQQu!qJ)tZ5>od= z2EC^q-a}6>tAaav%Bpa+>t#c6gg5i5uk~^%xS2-=MdMHh;#FMp>At+)+mlKKB#tpkx*QtFp|HE}M%PD)P69cf zaCTfIiht(Ga6GQFh$CD;iGM>|mgu4_gX4Tw_8{96!AXv*#Rx_|hG!1cPKy3Afyf}? z8&FnRo`X6|bd0XDLPVCuh}r-pg1eLmK1wMKN)ux!5ubUnjVtC*A{ z3fn)6q%jd)CSE~onnE;_#^vFJKZSiyQz_M52$4z+A@G4%Xg-Whq4Iu>+#8L|`UOO0 z$HVsRO+N<3cOr^sWF|d!8*YZ#*Ko_N#BxJaa4U4DWqc4BcX#QKeR9a|K40_Q#{25h zOS?h3FVkjtU!iLMT;0}hhl&mQjy$b#>Lf6((7g{Fq2|@^!W4$8?iG2ZUUuP`tl{qx zNCgi@K3)8=did+P5o^!kMO4u7W{-)~&F~)-v;#{bfxaUXk|!)DBo|dKzaaloqu{l{ z-)9b@O(?QMcA?0`EnvZ8U&T+LlXe@?E|P(_y9^w%i>(-?*q2&q`D0#3Wuc|ty=Z=N zgcagXvO=MYWaScBArO=dIv*k6%$A3l5KmzJ?$O~V%zEACEP(;&PnrJ*2H~G7m+@*&>B(m>eDAHqLAhh@ZD8*?R1Bef4 zV9&i5w4LG(6+JIr{D`U_p|COtIB>~_|KEt}t~JqhvbznbszV_OPN|gv*zI<$iO5{y z5)UJt4e%Gezg9LR7Ij!UHpmT^e2u0hV-ptg|16V7M3bhSBTG+Cdmg4s*GAoSYCEqc z;lu-i!yf&rr{H4GX|hx^Y>s8rY2ZjRsj9zC#tLjnol!*%w2SJp_?^LDeR;W`TY->> z1My?RV9RhGOh7?y_#aha$fpq^Nm~3q74!!id1n${W00dRt*)GPVD)kz diff --git a/basic_function/__pycache__/operation.cpython-310.pyc b/basic_function/__pycache__/operation.cpython-310.pyc deleted file mode 100644 index 89df7b9d5215982543a19f24f09b8b7813b65328..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 14168 zcmd^G+ix7_b)VbLUbtLR6fKFe?2HvV;@Z53l;a?dV_Tsj$4xA|l_EEzsO#a*aJkg% z&hndC-a2e9av~Y13lv2Q6lnl5ZwC6<=N9NgANvQ)LkhGHf!~Ver7#k!zjMBs*(9NeDE~n(!$11^I<8RB6vb6+#Z_I+S8bJ_ znyv9uw{<*q-{=(VLdUetPSGxQN_MGJw#%K0T~TGbF}up}<93ao6CJ~z@^cfl=sPp=tn$t}Nc z*k5of-czV417c=Ik@q!iMghc_TLpyWj>~@B8bjv2u9@zHI|+zq%9^|Eo^q$$(_RU% zMeibwI-|HncM4F`?i249>@T`k+;i>>;GcERdovhw24g(wUO>%r?o&Ji%vO#?GXO-1 z<^-sx-7f&@OP>Q}xz7OV66QPOUgWvXVDvA#&!TqLea@?R7tlNQfqVZ_4psJMFc(^L z$-U&x0=n+jy(cmLd5rzM`vPj7PiuHbQ9I|(qxJ>&MPB!!cOKABW3kUDxUa zXxjH3VFdtnsFQH6684+xmJ?dxZl~i#V%KtdJ-^-L&vRDT2WZpteXHel(?)aLTgwXw zp8IzE-LM@lR6Zv3y-_z4{c3kNF|M_nQDS_j9Y#s%?H-Ni_(|bbzvp{N;k#YD*VUxh z{Pwk5!t?4{GSyu7IvCn--0KG0zUQ{Qr0P16(`fom7s4u!rmBvbcR4(CK z!4-Z7x4mMd#L9hDRo(^m-&71GQg?yv^!Gibt8OcMl}L%S13gv_)gy4vWpK3b+e%&E z`{$K)uX!)D+H00$g;+FgvD39yJ*(Y~Jki6PJePLSvAX@vswb@3H9v473&*l196qG_ zLZxEicU82)m$^i5;mg*ob?TPAx31TXq`;F(Dz_cK?|oBe#@f(>BeVsAHE54w$&o9`^HI$^un$dGa|$FhaqZc=WC4Y@TJ z=#b9grj<*osa91()gEbjO{21jvN(;8Nf8Ly2zuejLDRH|Ey{z&vA2|c6_*w(*Oa?I z+t*|LKsi(nm_;6H2b!xNs{2Ny$HsltHSX!+naFSpvHml~HGi((SAVQVg#*(q0;<5J z5=v%NBzW}L`o6Yb+%H9?gR)zWONZKiIWD8#-mk=!gUX>o&+yvj*IYce6mc81WA50Y zD&BUhUDX|r3vMkgu4%6*fSF*J^&G~@VJ35!Qw+10!-P4^R0b1OviyUr}_0#=}Za#}%w}TCQ|+F;)*IwO!h&z#AR~MkRND@ejtc$6j&XH*nkK6mzcKR#s_rJmBay}x7Ix$d?e>r7Fs!5 zLE8mCb3MxiM@TsinpnO5DqeHotluMaBeCOU>+0_^&slGwwJ>MhxwJE9?ao?7m}@QyZ!Kfe+1^;QZ(%ed zc=`>uTtdr8INi|a#?}^yDiEPHyOK}O0TmC8JD}c|tv8t?Wvu(?0nssZ9G|3u=gQ$f zQLRMzM4kIs2N~^6-#p&*(%$$bYS`{>Ieyz+_(&5(t$eF%9}}v+e-?2o(RIE?19^m!15AJ zcu8@s-E|=l(FVxcODYhLA_}+L(K@pmF^=|Pf^L)4lL;y4b=~@yn8vGpDm&XocA&}W zp+jq=hh|qK(`sN4yPB4JG!WD=qLQ6PCoeSI0}Y5N^ARqd!sNW~-Hgm)@jiGSn`#?xc+3E-^sJ z{J{VNDn7DVBdmz_@v*bc*yrN}{eYzbW3yG0UgYmu-e$iIx+G#bdFpb0^)jEWcn;^3 z6dWO(-9+C5;_Ge@CF(7R*Ui4?fmY@S0tiUl9<7$}Dwe6P3Y;IH`#7U}hUm%kU>=QL z5Vl#Fv`cgX^q%NwoD{f?UF>?>RCyGypB>&n23_++1I(%+x4XgTiK&D{+|{Z9yq*G% zPpXC{UZp7(Q$2TYdRR+bSdrRqQX$2)0p8m3n!}h(S_G+!!~(+`K(-kALLcPsLSJZU z!}pfH552)HybqXbP$)mp9;j=oYoewUDGxQ$67H+}8t`9_RP<9MQ)7mDq>@h8S%oQK zfhYv^R3{<$>kz~R=pY19$|0yU+(%{KgodNVOiL~3HlUhYn=6~k7{B;X-OxdCU-&?Y z>3RMG{ekj8d!VoBZkgw+y3iYVz9rD078x6bhsuT-m)x-rHSTL;a}B)|-6}MtpFh-S zp15NM)AYn_rj`9NLBwU0DqYR3?pM%H3G`lxi|+V`Iw-j2`~aJ}0`npg{U+DK6Ggoc@7L52-gTYN-ACFEY=5d zc}y}n``x)Szcgnp-G#Prq;MBzl?Sdux)J0(^_`1zDpTVCAYmYbRt9WTu*&%#Dv9;l~1ZM1>|oAU1xJkl(_s1xx8}b)OIq ziX`spjB-wK!Jz+A+a$G6<4#zFGYP{|(-tl#syIr}l~;&~G=om>)xVd%pw}#7t+P^xs#rSsv@hKy&Vg zL^!X~+*co>-E_LL!ThntZvZIK!CF63;ZbRKy$gNqk5t?bo5dqG2MgMC|O z1SJIliWChZ7I3Sp;x#H?r&~VAU*jz_aWj<*s`*H6xrO&*Yoin7wS5T%UfTq%P43D= zjrQiO@(^b_qdYmh+K*Jp;lG7ubv-eN#U=)p3ymq~@JE*Z3&5tcI#07AZ4fw6^n8z$ zHtsMf_Pb=Xck8-%1K$&q+y$*&$R3OkLadM{CZyG);W9J@-zOv2k^5ma)YF)y?$^hwT&7g#sM zAD|mK#J8zLo&!D%sE0&M-0(-J3}3;mG;v0)X%(2M<|D(nsM6mQnXY)NK*6fPc*PeT z&qvyCDy4H|+G+;Fi+9mMQc68A-X~9t$g8)cW5ydfWl}?nxK?n5e@{Y6A=8#9VPAuo zfar16d#b3#AZ7Aokl4z^6qGPh!IT5z5VQ=zqaDELjSYx0<0tSNvd^k;TFGk)SeK6+Fu%%vw4L*!{6c+c2S~!l=jISNF%`@r}}c4eiPh!?g`^3QS<9D=?!c z4mFt8IFEYK`3a8fZP8ABGug>ojvv69Y(p0G;6`e;d$5k39mo{unQf3v)OJ!aKyD#; z*wzEr>XSiEZlAZ{h=Jw@*f%At6~S(RlL>+aLIxc`Ho-F$bV*%=_ld@5yHW}T?lZ-< zOW}yd>j(@Nxb3wzJX!7TJeybx)^&8=3GmJJU<&&oT!Nk_-B;^T)O&exF%01E^0u4> zF#g3CmtQ)`%atu(3iJ%$Q&~=vhbK)=mh4oP%g2mzntSMUXcO3vi59F1SY*WUjxq~? zU>x*BQut&MNW$SJk=7a9n1L4e)kxW3mMKMwHQR{Z3-~uzUQLZFue0x*)?)e3=@Yf_IU7q~021RlZ2?hF#0Czd_h#XxPe`d;4+NBFG=@t)%#F}2^px)O>3bM$$D^E zpotDXrMLzw#QIQ8foqIdLu(@j|A3lY;5YDJeuKM>zZv{i z$(tq_GCMCrtJUtTg3W@S?=4tYxijBei{{rsxaf1>h=36vMqY>Aq9F)1=W?S?%`Gxv zuFR*b#=6n(yn}cQ8Sk`Co@5$^@pz7dbv=%0qPro_WVRNbsnjl?0x6_!I^}h#=v%0& z7p3dK2sZ^hY#nFE)_QF)nRYkPo3Q1H4?(MQ@UV+pmm36Umr+=fMH5!H?=9K#xom5j z=wWjg*B-9zqZEFMI2v!cmvWDPL(d_(3#TW}!R(&`J=dUP)u2;-OQk>ax1xp5$rvlx zaFIEk(O>~-P^2fb2KC=4#HscatKZl6^?d{8HfTHswc{O5DrgxBjB7K7Q;K9iYRmf- zXaEXy1PTlwf*O}y?KJ2suB6W~`E;Q*fO;yp#)xi6Pe8yy5np!;H(`*+2EKrz(3cU9 zGws(QLx0fY3AgwPbRTMsatY-E_cq~{xxUKv6|N`!1aRlTA>!(j1V4uQS~?0K zxvq-3i47ApH||^m&8;{-h5yKaC0%0RKAMk$`4w<}FKokUFFi4E8B2e(^joI*HHxl7 z8&8OP%x8Y*jx%Si&RI<|rX9bxPOnfS=B$>}={R@q&M_jAkRJJMBXZ?g@Nu=Gbr{lw}WvXH_Lp_(%dU|lbDdpKkENWReu2`cl z9^IOBBZH46J0$oz%62i`sFc(YBy|#XHBpB~}$n4@pTWRpEP9&%jgnyrDfR=?0bI ziLUbBIdxh!HPHp&h&8cx~t*eLXyU^Vzrn{kS z7`sp9@mfG3UaL7jJ&x@;eiO0}Q2%6{7?w0rQUC+DfDq6*W&Bm{h?Q|;*IM1R-W%*k zCaeD6)60!|fsb1%uYZE0O>}IfJ@%b;_p0CG7&It>5n0Fn9mBEaGU>O3ezSyQ#M4VP zj%#Q&U>e7Zd(>@-!Xw9m7R3Ka)z@)_u3Sg~8L$A>k+lqdeF$u- z>&_yE&4UrgqGe#Eq+q1+fzV^yYzVyuDnQ1aMo}VjRS^Fi7WY82lY1FNCI&voY~{la z&@5Oq$o2UAsKb#7{NBNb5!$AmH*uY&6Cma#y;Z3n1M-(VavK`LLfCZ*3k@4zhQu4P zVWc(l6_ns&`UX`HqYz)G+naP7wZO=&9}-V76;qpjWIoz^8Y-f{+qn6?x86F*1cN*1 zGIjGJ-Evb4=*tn2O{)DW-Cm~K*XYIr^3deKODco@U#G#zz^GQi{512^N_nt=C!dLiu5Ml$4B4I&dCot5YP}@P_ko+a^aKH-zy*}1Ckgqxa z`q@}J2P>JfXh>y-fpDvVKO5Ge_6j7W^nxyea0r|StZj^R zno0Y74}NF#|G!82qgu-#HHV_~M@V9Hu|x3%+QiI!9nA5B9)544l1hWgcr@?_*hbX$ zJ=Vqr(}pQRV2z0A#z;%3iV$3%d82)rV&{}=;kjh>lQ-#DT<+`uksIV&$pCL>Bvm<8 zQqKdgMW04_1Gi%oPds-87yIy`=c*hSCI<_A)0BVsA;QKHS<(hPlhDBLOKwY{-y?de zRBtlDOYSodc_}h3?6#d2WtGv_<1r=j7-jG|B&@tHoDfn6hF_rt-E#Kw9RBJVt_D3m zN{=-@U^858*6h+8{_5E@J|}!daYu6bOAiXX74Q z&ah@I5l?EMs%bEQpORl>F3RZ&z8Fr#UyeJ$oYCt~B6#R4hLO z@1nFN8tIH0c?)@70U--SZ6p7MzM}_`p5y=b%f<2q~&yx-XwD@MZ4o5Niq#k4iYR0I=4XBRX}mD zer9Zq_Cn`AyL9<-eGaI%L=Vc(oV|P*2)5>Q+MM|Slpg7HRB_0U>$6ZSx4$y<`7C4` zWF1I+9ZdO-+knmHHekxZ9|>bl8fd6Z{V4e+<@Ov*j&09FE*EbLn^Z_#^)m`voo~ud za^+!Q1zm}~BRE2K0ZGG~xQ;k#!BJob4C@No0MnsSvWDQG-4WE~VBUu~Z*d17_f`go zoUNBUs$>Gvd>cIf2CSp>z)#d1bPy4Gly;po;&f@;ajk_8O=K$}U&2`p z*;OluG{qL(C^Ms8W&UcHv)P64wn!gpQbbA<<=t^WF;O?fchSHu_5$Ehy*rA(=sxSK>ASlOWbm~LAs%Ks35k&I-;sX@UW5;L9)is9X?MTa)pvWp%+2J zBoB|dQT#xOO04TsB1b7g+(xEBK$|k*?SYThg)b9HN|m^Px>8))7=zcf_@N4}Qau=# zX`TnQXd*)T9p&OaGLaI6vC;ISwQ0^Z55agUBu>~g}l%|fi9f*6@ zR@<>|yYE`?613a7&6-DzU}pvXR+c%qv|xR6rw2ZhteWQkG}5do+lDzD(; zwD2E6{@+C~j_kr2iK}7-ZLcGVL%-+keoRrL`kBON z2FPbGxt`bKQnBOQLx^n2KHUirDE^L{^1{VOIM%%= zG39?k*kee|6T1*XZ4n};b;mBeK?Gvg9M^3yRbT=Nj_NrBe@-y_ezlG0S8nfiHtKnvhAJNuZk`30TeMSxYND4q3b#Pm$&^f4A*ZE_%gwnog1&x$Z)_#5-;DQ*-#Kg`sm1yChbc$ z>_}y2$t`!5-5rA)A6w}aqzv?Z?^PU=68;!Bq#>JhQWj1}x_u}0DxDS1NIR<`UuH_H ze5$=fjyrlovht_edE-+m*EH=@O|Sg6a%JiT?IPvYnDq5hUOGomd{Ox3Pq-EKbsD~xXZ#jO*Z_^JYacpN5If?Dq%TVMDi7 znQFIgS6XPh)Ive1h11dtCrcwfINL)X5i~eY?$fG;=o_p?d&i(GWhre{W>=cBzemg!K=%lFM<3ll+3X!LO1CdXt zFcqew)ChfvrZH}SxZ#q4#ElR)UNVxnDQdc8hI~`ha>+veTQ6C|=CEbRpr%8{K{{-` zMPFjV%q{AYEo=j758Ht{!VaL$uoI{&>;hU9t^!&et_E5YHuY0oHD6=nx@g%M4km(u zP&CN%EDz7Mp<#9;5(-8G*JAPOQ8qlpLaHU67-l*C+}HmN;cFU-h72g65FGH|2;Zl} zW{Ua|{RuThr3~;jCaJ4M9x7#;qM^<7Ee29i1PNio+G3PCNqr9SFG|mZ->Bm^J%it@ zLjarsB5!2=h>1QQ%us3#d%y2l9r@3$qX4khk5&=G zxp>l~EukN$G=0U9vLy{CDTtX;_M|CkE6>-}(%hsRx*SWf4t0-{cJ&jEUgdey<)~(4 z0^Vq4RwrOxSFBexo=~sqsZg&fMV+gg?-#FEHK$Oo>Zwq#Dn(bXI;Ud2@J=|B#?@Cr zqPjTh6m^v;#*>buQ$6nRr0P37sp=gv-nVMkfOdV*unuM{@O6QXD6OL~bB(M@?wUGP zH-oA=bvdi|71ow+F0>L>)Fw(pebpllDZ(P)4R#)eoPTj}#netS4{ z;PTr%Y#NsbgRyYr@=2DzmWYpD9=(|uj>iHg`oDL1Ajn5Tfx+=uC=rRrF2_e%E{L)1 zqc^{`ABsl!L?}KodSs%BABU^}RZ`BF-Zw^?<=1N!L&cq znDyV@kR6lOcI8eljNL!G*dlE`B(6QQI3}(=Cf55)+)tr2Mf&!VL#i4$-H@sVn+z$n zVn}f`m2t&+?=o_@oyzn8=Ys#m@#L_V9KL|h8V=;)YLK{*p!ApQl&a~vPU)Akew`S< zF0GSQdKP-go{>(Kd**?jbs1zU8)Zk>SVCq7xnL-8!do_0kJ?dm`FrItV|Ef^fj4| z%5+3FVMAqii9Jy4R&6<{)ImE+DRhTOEHs11$@jhQQwOqgWPAA?;&uQNs&D=j^bB5_ z$MxRHvw1?Cbu(utPdzTaMeu@rwNHBEvT)`5(i;JI66@ZV>fX=)iTTTBL)`#9`NqF9 zR8?L2KqfeMIO~?`+On4x=(|@Ix}}YK7HM(ge$jP6avfNvjOLw$Owo_r-tf9 z%DM0MNNu|p`lPmfi|vbJzv&d){12`QeQ*7CL^%Jp)ccNbF(iap>EfU;G%Q_=h`kXZ zIwJOth^?{w|8aSk!PZe)a5gP7l&kTf)3@aGeYRd|*(Dr2lizv#LASK?l+b%YxOhqG zeNVU?6XK)NZpfH6V(2SOq48pk+LoCO>)LaTAymg^~|`qoF8wM*7k!P=Vl zte^XSc4DFBZVCpjbKm0T2YtWU_7yE1I4yRb{;E&xJSTeI%zJ%V_uNRh}DxcB|X> z$(De5p=$C^FHgOi%B@T>n${r3Q%akXOqomj4#3fv$~udj>irtf{9?v@r}fzY-l{n9`Fq(5~Bxp?BL@AEXRHj z+zz07$L?Kw33BW*{#&tHt{Y;6XcQd|6ghV6f(TlE;LuT@q7%hp@V{*BNvb9rjuf7L ziX)iUz@Zqqs(^U{BZ@i7P=LZWcvN{lp9%!GDP0K^yRq8AU^f)vTY$j1zYYt*9LUL;r{M-ZkT!L$cT-gvC1uV&N z!JBd|a)e-L;g}L1B+Zx26D$|!W#dF-RHo0%rf?)UBwM19STH&yn-D0<#&A3#yGO?r z=_olc5DZ=Gvdc`dU$PZJBvdXl7(pJ{q70Ah=2$?Ikq_9wXq=DW$w_8jhsqONh%&g&*R{iv7&+$t92vGodu?C@a2!gpa5EI@I7qkUTWwT8(xtb zUJ*^6$ycW?Jh9rRdVY|)l}e8-SsMgv!=q~V%#BPl{s|3TV)bUJdb7Z6e&VX0xlovj zcj3S2Ju3lqos(SWCVTS@zHHCKwj)bzN5r;cQrj`H;h13ZD6%YKb+c66Jo(y_#toT4 zp<|!e*dsOeEK?Tqk#CSpS529whVx8K+CAe)`=_d=s%EM(?b%Ja?z>wS++yplg$s+p zFW-CM7x$bJTTh9Ny&}_#Y|~Z$m!(c-dvgA}2gUW<7aA5@zFa4+KO{6B7Ma7VH|@{M zK3(;=dQE22?ag9!`)AzU>%ZdeUthQ(?l>mxI3{-brB1(C?T0aUxMueM^vI8oWHyP8 zHp$Tj;L6blM~*&7Hshj#$52s^1N#36t}oscFg7l)0WwiKac0X^&vq zlXuijzPS=B;VL9psR<|;%IPFvnDTB$wW$k97s{}NjgeY<--~p0(k|F)g%B1X$KhQa4iw3n6_Fw z6v@SUlv+#eP}Ni9P`5;BX=MqO+o2nwA*0^lj{AFXC%go9RNCi@fe@r#zTpfRhFBsG{JX05{>#|@q~{b z9~@*k-yjzs@x2-u83;ywVC=g_3Pjn#gl`!3U*C8PHaw1Jl@3Ge;9`m;JO@UDTyVtq z!2#uoV-w?COnDem*5UF|vac7O9v>}4sAB}{DBHc6>j7rj2p~!}j7DT*D9&^HA*stO z(;sk%DP;PlOi#%4bspBKvVSUDryF4a8Dh4u4RdIWeGR_+moROU)N(ClcTJ7K29=m@ z%os&x-KPV$2ebXZ+I4@=o%a^H7y7>3EpFT=HuXqNJtEUH>CfAoFUHTAI=DX;0dn?w+kn_sx2yH-JDtNOI4I=X>Yv*}m);$VuFJM{L;g5~5n> zjT!&EIZJ0-?ijQFJLY*82!w%Hpzoc(xA)G(f`1_hO$z_ISL{3hty;F4JeF@LlhgVK zB+Hb^Y+bos;k^W^%%=iuHF&gF1AcWJOz+w})uxuhEkK2a!^ZN4hxeu$m964SnRIz- znPy#{TBb##noApL(gb&1a2Tn)g=)7HxP2^@)gXKG3(8?C$g#Yj9GgZC&I<|27nK8d zhj7PfEp7`*CdriIwxq2X4;vL8CAW|cHNp+0157idZVs+n`oXMH*r3QJp=o7lV@x`< zql>kqoElzKF(jZor9Lp0khZ}_glTpx1!-_YTW%*wv*XMn7w|bt?IHY1&MF&>`J8h0 zu|lOpBYyoQ#JVg*EpmOB`aL9X0+CJp7@S9JWV3)X2z$>}Fu0>C=J4do;n|S$BC|Od z9UTtJrU5n)l+8oIk&&Qm2;p}CzC*HsACl<{+*`oQqiG2DxKb!su?KgBikg)LmiLf1 zPBd$tai?2moaw%qYSe5MADo_aaOO4D6y z*&&hHkTqoAx?>fY&Rk1w;9keVu6tV+4dUhlB69$GWo4$^KjbI*e0AMa0*m{n&1rht z_VMw#S2OxFh4;hm-2ddqYf`iBEP!=g&;3_CoQ2DEK zDJWl2z@+@DnBAgRsoVwVCNZV0u2;E>n2(^7BO0!fHVo}TTN46#4x#KAeEIBm+Hj>d zRG=$hO=G%WVAkgC&Z)XIojyKmnQqFs0dcm@56-X8jtNZHW2XyL>Z$Zwnbz5NGu-TD z(b<~q6P>6gom=wHm3w^pbjCe>4m79DzHDU>Q!F}Asl3<;PmM~ZyTp$&YghE(L*TC~@u19=JU>Yw0xN{8pvIK2qQCXy6t^XTP7j5U zC|5^aMGZuPzMOW{lFTaO{JX@4Vrfv+o}?2mY_#QmNYDHic%7m5!itczYI)(J!&HHj z(R0FO(cdOo$|A$|L|geOsaxyXbp>nceelatum0L%Sxstbe3O*^+7(c_t2~Ta`QTZo z%EK`Mzb-?Co>gV*i*>92Yu|i*?`MJ)+Nrxz?x*f~&wWe#Wa?XbE4K$K07ox^|M4nh zV$Jr=M>oB}l~NSQc4QFEP(!z~qmbHo=P#cQg29@!j?UuQYl%CV!-V2G8?{Ae_iP!1jk zL4<5$W8)*}B*w}Hj+L#0kysdRMZj2-;5a-e$+id|03Nu>lFftBcn}UJ=rn_OBXAkg zTZqcGQ#V5FD5i2vP>jO`N-lLcKgRZ>C@ViU&ax9Mc=-?+=vlY(=*<8&nwtjBA0zh!Ow}ZnU$;RpHRS6&lV=KK z%`(+!?)V1Dlr`_&@X5%}MrPx4@rT}ROWtj{(84a!yI1n=eds;3mT`A@7jeerye9k z-x`x(=bPBe!+2eQEO%Ve^rE=ce>)nZCIEGw&y#$v^;rFOSXT>vq|u5%AHv}xwQ3|u=UsqNyv46 z$Qw#6kVGn*Cue6G9M-ce5@K8Z(Ba)Zr{kQ%{j#8-BR=Jh3JFsFXLiEuhh`{ z)q3HbcmJ_jY`8Gx$}{U9GHpvtTlVb3_Weuk`^ENyQu{%XIV3TM1oFX)qaL9iZjHJm zN0(sgQVeP(Lk%p*68NcscbI@5BLW}zn&As}{sI63^Gp`lR|c&b{4)a}Mx#V410Y)k zK0N@kYXFdL%PIjN2bh6Mn=TmIO4|?sAg6XCNx70P0)W6RSOtKpGyuq^*_a)x0H7-E z_MrhltCijlwUq%7{KTXLM8I680XIr*eZa2!gvULZY z3S2rsI%og~nF#1RZUjh|8{9^sEYE$2DM=(LBp)I93m|f}8kPVnH;Gw4K=MN*e~IK4 zlD|Uo*GPVZWC{sE5_lxi0tt9hu*8}&mIz#rB!&^eppOyv6DYF)U;gh`2NvFj$v5+L z-=PbFQ~rOV3!`~IdT{Umv!aVUv*sbww8S(Kuwt73==MhsoA)d=?-85#NzMD<`cz_i z1o9y$B>kr7XqO!ApS^YWg0TI#xbcLv@r2+wA(&3Ayye4<9zn{}Wgwqaa6N(hHr3sn zR(*gEoXQeYc308f(do(|^8$;}S55{#BW!2uU&Fs3`4{-|13)IJM~r=jNw3Xx&#i~s4~blW zJXg|V@0v`@ESm|=Mke2U(&W1x`McQ9W4VCXbXaOSEST2i9j$_?b>-4PEJl{bKfnYN zyh?-DH4R?ZaBWMd@fY-<5%$mDfiHK3MF9RVs-Kvt_DJCWqEsej{3GR6zPc65YdvyG za+4JLxn-0-# z4}RiwwuF?`MOGS-X8tKjqd(0#_`wx+OXUJ$BzB5&DsZ4sRjRH4Q+_vOa|paX!R<3R zz$?ZCd|9}kB0*C_7o+^*$ilJwcocrL#Is;#h=H|({+>+xW%{^GpOfj6JRXXa7o)iU z?kK++VfgJ^5DKBQDPFZkqii$)Mw1vDh{OWL-z0qsB7P5F{37rb zX2@UmUHlX3x1m*fEg^7xcUH~3mcE#|k)v;aD0#Ms&aIMjD=6+4On*Z7`!e+OxtViQ z^b=?G$A&pedJNt<=bTe?ZjhWCJ~Q04{j%y8Rk<|khiV?^R7NF6a%A9b}b?!P6{M8fjWhFj29E*=b zNYE}qFbvTw9F7|CN+z`tt@JH(Xe1s6XXztc7IN_|;;Vqdk6CCsPjw2~PoCN)Xg_(X zMbLioRIO0_9REeCSJLr>OdQs80XohZs#Ze^nPP%rPT1m)w F{SUvTw{rjh diff --git a/basic_function/__pycache__/operation.cpython-313.pyc b/basic_function/__pycache__/operation.cpython-313.pyc deleted file mode 100644 index 9cc28ad459db11bc76eb9ddea996da626a2fe4dd..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 2860 zcmc&$O>7fK6rS1jAO1=FlPpaj)Ia&bZG#g8g#-dMl(dQokhrGg2h7Gy_53E(ozPaDcPg6t+%L4)A^h0Lf4YDw-fg^yQhnyB#MVq}uU%J64{d`gZ_e z0#0nrYhMsW_JX76s-i@(Us+R*qRd_vMfZXzw3>^)D#|R5B5TS~l-bLo=zTPA7$GiY zs>Yx^Q_oM(*stP?F&qYf-&SQ?ky&-t3X&DJ2eGZO%C_RX)ww)l%V1kmh#JIA`81vT zDrEA}`$-#Na)SQXl&bZ2t9(@9Pa+55$Tnp+?f^~6vf`va;ArL$>6!1aWI1<|fHf#z zjgptqQLk-5^vGs#Q^{ReQKrIf0K3YPzUWWpGWu zS8MsB!tp(^SNc-qptgG!0QMwC;ko)U;-#GAY}EQx-Tg9(rJV}GPiKz-$2!h96v76agG-ZL`$xYK;CCb`^^diZqV%8Is=fO(qkZZzjbFzUa8%o0<{0@bTc83Slli zndIfQF*X&8grhTjBpplgzT^}quteH3H76NUGk7ZJDYW2N1gg|&JkEplI%dAVC=X`G zjsENX^C#8=gN48-3XF>9FWd`+?>jEyag|+a%=m9LWu{SsH`|w+Ug^)f&|8O5!{Iz$ zdgrG)KMI7A<6`j@T5oD8#~7RsJTib9$Ib2~cCja8Lk>@7I0u){=Q>f#{yao2U2B%^ z1^B?m+`N);EndwGpzUqh#+-kpDL0M0U3ow9_N>{CEchSPH!gK&*jqhW8`|+^b~q1L z&gVN(d(SFF?f$j;V>e;Z+ElPMBWrVp&H6J#YgXSD4VdE)kM^SvX;5?Mk&zsx^Yh43 z%V(F?EUpFkbDeA9-J-K)*-&tHAZLeYZd&*3UOrgx>_eV?xdG%kkZ;aU-)=)5|LP@i zVDL^t{2+vS&x%74ahyX#QE_4t4aHDzOpGT`Zvwgb;wyK6!FNwV(Y$NZ04$9KvlE$} z-|Q^7_TF>t6%U;(b{t#nL>(u@-gDy6FzOuL8{+kyNWs{^;&??TkwhuZq?;-yol z{&bOXW^9Xzj2|(r6;(wHDJ;{2Lor9!U$RF_F-NZ(aU#s?3&!8=y9)N!d-m4cj@4oD z!*OKi<^$^{M$|Bm|2PEdSom0STqIx0hA-CF&OWHQQ}684>Ap2M`%H#Y_<^gbzMk@9 zC11rAu!U`a(2NJka5#}1pNVrv1TUVF--nb351TZE@E6dw-rSm*&b*!7n`N`DId^U# R$B6EOe*zEOF6_qp{sG{umi_<$ diff --git a/basic_function/__pycache__/operation.cpython-38.pyc b/basic_function/__pycache__/operation.cpython-38.pyc deleted file mode 100644 index 62d500d9fe786633c5ff73d4ecfc0e8e7275946f..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 14170 zcmd^GTWlQHd7j(OUbwu7q9sw5m5FRe+?W?pa+6k7Z7WoixQS&qQfy~rbv@h}E|;9$ zS)Q4txWurR$cg0w22vC)P}Bv;Jk>zar=sXn3lv44iarE=nuoM#A5y;+C{PqeWA*$0 zGc&tOQg#~jtxL?znVB=^`d`2Q=eNhlD++$!{ooJRpM6PD{*`Wqetx%o zr}&-WpJ-3pC%IhkXWCELPqa_jr&Pt8+Boe^Y@C@@>{+jfzq7Ln%1@$n4%er=bKZH^ z{6KSyZs{Gv{+#!eSH`>9Sw)`Lv{}V1<9)>)!+T4%fLeK5ZuJA*9d{@2_VezdJL#49 z-396g-%h$ycz?p3ey3nR?ViNn3|f1_J%#$uc(dNDdm8s=+*#Z|%Y9>nvKQ(btzj(g zS=4#bJ%>79_|!U2x#v;mBHDclEq@L@TkhxaWX`K#raWhEmEM=Bcb?yc4-EHd_Zify zc~5$0-DlC;7u<`u|03>r&G^}zTf?*G+$GHNEY|Hgv_J1IU}p2Kc3Y`k+C8XTw;F-( zdyS|SbgW>_TI+SV99nCEu)3bWrQPZ@t%m4_k>mS8Q#jrAR>N{SuGJ3kY0r0r6`-m^ zjfC@+u-90(oX`sU?Y0+*zU6egeyhQ+=dG}ZstwQgt)|yWKbq&(nqENlxoyYqhplk2 z@)42et(uwW*E{{hxY24viSf-=7$v1wyVRfKCxtt`uJ0vs3$rNP<{TRatX%@j_?jnyAzQT zD-Tpvc?0x)TQQVK-K>ZUszQI?Q99b2%I-|0MB2U{D+lV4nzyo=wEAxE& z-D_-yR%^|2tPsGD?&H9u6u30yvYacLd9 zzSIjnv9#ujrJF5m=^ZbOmb(3DJ?PX|Zojs)>V&OEJ)_N~oY)q-{iNIq>vCPr(N4|a zq?9$aq?+o4YHFH>GfLtlo@2L*pvrpC4Mi3ABO6YmBJs!@{w-{OQ`u8-Xt8ob`GI=x z>Yg6!`^tf`&*Jbv+t*zEK;1JUJvJVwuCb|$8Dw7ue9DJt!k-Ev$y(Dur48PET6uM$`GD+dZ)!{6E3adF)=#h>Bn zm^*f$ihZ}*(cJO4;7-KFHSL;$I+I-IM|qvU&Ff6%bxv@d|IF+BTV7{6s}oh@$^-QW zs<2{HwqAf>jK-shee^(N{bfa0s7K88WL&{qXXISBcDCM*)%~fsaGSlHJb0oDUO`(R{}I=Q6CcQz$Lutg^dthPg9!Gk=I^tRr@3pCJ?ECPzL z);%9`C6`|ozH+vMmJ9jjdX@_bkxC$ZVs(3~xXmkP?H=MDNhnvX>p#ml=e)({;=FbD z;?BI)pSO0OTVA?+@m|K4NBHf!6_L=w=sJ

JZaM_XF-|h$|O{`o6Mr#S%fpQglgp zYnjl_*2bE91-%j5({K38WqcV4rxW`8v9%3xABfPJTgiv#fR~5*9kB3K>t&WpnJ7QJ zL9`8R$0se}xw89@RVz_GR_8y`!A!d|w~u~$ad-S8ec0-3JATVu{7@5TP?t|N?IV)B zT;83YyZhY2y%+E9^mi{U{Lfkf&^u=4OwPg+&zQ7JI9dRV(mmy@PqrFNQ#hszSDqC2|+heqWfM)Oi;sGr<)XZ zJrRUS*@M_21?DB1@RH(MtK&jJ;u{caH>p5fiYUC_iq={7h;e)`Ch0UqEt$}YUdOGC z2@+ZMiEMA{*@h-3hc>OAZkk<@@~e(c>}p!>Qb#bqh)Z_r?Yz`T)K+59Tic9+O|KK$ z1-`QAm8WSM%h@_Agq6*2 zRBC%V#IX^n`w+=2QuSDi^#>Y6<~-h;dquo6;zDe06b`hwxNjaPrmRrN9<3DJ>ZT5RGb+ecyjbVna7|)EdSEf2~8Kp|caI1S2 zv{VA)SK^{O{yrJe6V7j9Sy$kEM55PVdv)Mttcjov*6MXykwx2dEVHH?IczVd+iV4^ z8_<^nb{uT=FemcfOQ*?m;Qm0>MzB~?T{~~Fb&zXk($LxO?%jpud29I|Y==XIdvLlu zh!?Vyp#7o6^H~jslQlzEsOnom$D6lYZ_VlXSO{1_pVI#01p`>g8eIS@cL&&3AvYm& zxJ7EVogQTNX55$VCO#95tH&L}P<_3vt1B$+jCgcUmT zGOTuO@lv9S!<=2YM$)7aw7W0-9<^@ZWGb4nd!ENu=k31c2M}rD3yXuw$BvKugK#bz zp9OuQzNfsafc@a@y;A}aVyv@p1sOVx?bWB=i7A8f{_sE3f=8|HsLoR)%ri-tkc488k(CkAAq5TjV<{~{U zNSRE^N4!QSa_|!4I_ij5=>Ao@FQ%jpTZ|XzB|NmpNS&l?EEcGmx}K<;iR#;Bdd2yxK>mZm1KFN=8*Z17pj;8x7R2JrZx=>)j_twAL#Usv%IM zl)^Ttu7c>ye`J{!ajf77|CDrp-{d=JnegI9%n`Z)SU-fP-N@whK$ z24mv-G@3COj-OGN$Dh_MAF7+SGdl9LZUuF9n5m_>+|+S4;|i{2(r8Vc^z7bPq-~fm zTH)(s?W=p^@%Tn*Zvx+yp^Yat2rHPxN>|`nPabISwy__zqVqj$*{h8g~0+SH5fh7MsYB|?|Ps;(-vsHkGIHCw|35!}F z2~ZnLrkY+?zlF_+iacEJK z$wN!{=NNf-0Vid=3g6ezX4I;FS*5>u`qM5LQ>v*?;GOt&)FNz;_VD zd}qegvJo46SLC}wdIuK+U?l;Va5zpvB%{{?jO;?oc~u_4J^fk@mUQBejodJ)egZs}PZ<^8gF}H5qxxEbAOs zlR)DRTweAWuG~oNckAu#w{rmr;3`9+rHpj&S+`4O@saNcg|xPMFyaLTC+4mD2%SJ~ zuwd=5$bp!3Nf-eMlrU zf6?XupMY;4MqZmC(-8cci@Q<8Vw)VAD+{U2v2OL+uOW6rK0VEo2bqR!JRajRwvOLaPdhT!N#lO zVi&hB)v2CcMqybNO}OH|w`|MnWn0@q3tN30yEyJ2X7m#z*m%jkREYc&S`GNC zGe85T^$FNq_=A=8b^80M@bDHaKE}p205qq4OjLm^Dl(nfqWUis;?$ao)wlFLeb0aw z4pxuB^tcCv1*=1W&uzwltjHF?)AC*g7J&krf#L@Us>WqkI|(+6E9rGiUR_ubV4@0+ zF=8LG8xWUJ#I9So4L?0L@CJ;9wv2e3MF4_Y`vn&`J?sfRo^*@XU<=XLD3?$!aBGup znV(nrd4->oy@GmYASmMMlT?2U&nMDeP?Mik@oaL#1nZ4EH^6!;j!zLo@@mPp7=TL) zQLwNAJ>Z2c;QA7$0jew^YYASWAMq z;nc?inHgI zEYlo%PUHi!o`*hPL)k8-3zc$Qn5a&ot|n^3 zF7TS$2CcbW;5D}k;Nd#%=&~v=O}hZ@t~c$8?ZAhKs$=SHFtcsSdom9Q-agE_m9@)_ zlxI~;JtUJwsg~d)O=+jpit)uqnhuCsI|VR1cB8JvfkAL^$pl^6a?PTHVg!UujS!g)O!;Oo1hHm`xw+An$M}6NPfTo z$5Y0z+L3Albhr(&fF&x!wQ^moOfp5Oq8ra?1ItHMX403-<9PD&I^IX=We z`he_o9Y;vW3@9xRaWLeBv>PD4&(rT1AGCYIi>97?5f;+yIK^$WMOlH#`MoF{y(jZ;rvtty82Jmx-Dn@vn zR^G&Ml6HWEnFO>_oCErpK)VeMK_sj?1(AlGGec4kc{tM7`7%n_c9I$58l8THPA}7G z)I&oCNa~?!=A+&7up0e-{q}cWdF43g40sWR&jbbE22L1=m}4hxdiqOrx=N>CrW1F@ z-AYht(Edx*G5H)6-l%A$Pt1bo;0|dUrPSU79GC5TD7f`1Dl}-NvLCHjZp+ZHh;D%i zAutD8MZSZiQc$Wv0lNLd0pTh@LV#Al?2mQM0nnU(_;l)>hc`{>IAl@7KeMs{vd;Z1 z+RZ9Ur$4i0L3pcyKNr?u6$@mwbb}5;cZj(MY=w+8nppw)YW&P>@PAMIhpm}GZVv_c z53$DRWQT$gw1}CLI~e0JE&SZfC5;J_X=;EP*hbXyJvP<_^N1+|P>tx)iAjve7J<7S z3rhPWh21Gl!*dB6Ak^tdH1G6)sOyBmWEEd03McigoPeU&&T3){Wyo_vhmZQ;MAjsN zS8y=u4>MThs50R(K(i@h@qNUcBl5Qm0GTk*A4uU#apOZOU0OexMkcfk8BA$V*=;#Z z%2T7aN8?ZgFlF`gn(zfXKrEz*3=pCT({t|9JpO7K$p&5CQI8Em5H*lFYi@ZSf3<8F zpAqt+kR^o(B&Q%t!NW*K!d8qAl9f&t>EnDno&a0PS-(f7Gi*3Z`jZB#Y8%c@E3BCV z@h0xlc8$}CHMG%z50GN~c9m^?8gsqVwI{RXN~6Vx_4VfgMoL$so{p%Ve-enEjN0M=7FLyFat%H9|fsmvJgFG8+XOmt^D?~c+P zX9p5u+Hr*CA`s?ukRm1XHbtxLAn`H{YYvh&sdnz7Fr+}yvtH(njkZENKev48Qf(f@ zw@eqx`JB6S2_&}Wv|F4w0iqshbkvf_H|*1hEO*c{0{S#^8ze1AiX9C3u3Lvw=hort z0ltK{Cp|bkP4Sj|k#dI*KFNL0Lp~TU3;SG1hV?TZTU}_#S90cIUe&w%K^ami9AV2aZGb%K+f5m34|p-k;Yu-In?1o zrJH`NNuf*kaNzUl(*HwGigf&Rz6`2mo?xMYoF^n)IIAIJyUd|!?n1=qWFNsfGbIWJuXE=#Vzluv=MOIfRjmMK~K zUoOgT(UKDQu2c~^Gb%_9B?G>sSM_t0uXG=8esYC>|M4Qk4ut~|rUdqnKw{&V^2 zkUj~2j8m>h$aK^WkQppZNL7h|%aT?>%E!4vk;xP=vP^(OKz3aqpvZWt#05Mn#iflgV70~fRmhm?{+0R3%JY<|r77m)i6 z028F109GLKBlBU%l1K7I4?qcJ0|K@nNR!kUphMVFZSQ^{Hm&WJW4-RaVF6TVwQ}b+ z4?4ji3;gXY7jb#f`sz*>QYXtf$ogXc=C$Q4X{moW8R7Dx^_5=B=kJarPY`<}fJGxk z(6xo|6D4!;P$_|(@fJd3TP_fwG^z2*qP3C%v@{q8Sy#w|Pi2gZ0TE8qLAEJcfSd$~ zMd{|22Ms5e3i7liGTD<+G|48iW1u4)D+D=XwE)qkqOc{bZ8U^|SZ_GUV#t<_Y}Ds5 zBuX8p;b0?3S|bym4UfUXlXwTLgzz0Ypel0M=#=YX9fDqSJsT7j1~IY0+-XQ?-bYFa zRtomQi{ZPB)u3T0OO^JDSA_?k z6AmrWV-1qgRQ8l6G2>$i@-Zpx;yQR)TD7Wp3FRXaU&q5MI9Qw@qXm&XavEoq*TM5! znhSMMQ*TLu^&QkhlE?VAdmmY1Abj8mUHOVVwemGWZ+5!kg}=&VzCn2`Uqt+qtR1Ad z^P1PG0a)`xvOW7yN^e!HpsSlm0nsJ! zjo8DUa9p>}Jb?)y=i=9>S%T=~8)6We8fAn~Dvw>@y7o*?!1bJfLr&m;o*|hm`ZR=U zDh_}Zw!8L;>@H7C68EU5au#7Em4x-9Xu@>Hes!`U!u7y+;jF*{whKH50F^xcO3G37 zMhXX!5hznOPCsUEen>lb0o9agLL5u_d9A9M>RE&efj{DyMhwAJ4Sl*OzK>cU5+vsA zP-!74*6VK2sMp1B;9e#RQ%VvNC;Jp&5|F)$Aj%VD7YMRC1R1D;k~Zv82lfnv?2dRq zw*~lB-9FnJ$Wx*6qHM<2JKe}050MDaUqp5!!vDS?-${HIr(|OE-z}I7IB~`<-nwy} z;z4$i|JUP7&)e1X$xZr?6t>&#u@!g=ucx|ShgPtsZ>JE|t9e>+Vyp&%FB78~<)B1M z!s8-Ie-u9D0E>j|e7YsTg1tI6)LB>a3Jw_sjDRppm98PLpyzup(2|AZOBg1rdey=n zFcPd!>4XNC^rXoVIrkU?4@&0!3Yb%u@{%eJYARwLnqyspvz{r+G+=_969KfdEBeG*-X= zKQpttBxR>T-@3%ioS8XuuK)G>e>&si6$QWVzW+ySFI-iWf1{h>pNX5RIKqEX6~$F- z#Z_I+S8bKAnyv9ww{=`~-)I-?Lff>>cF``jOLnPUw#)5`T~XzCV|JD6kJ}S`oopNS z6u&e4W9@1CIF}3lO#3POsrCu`gsON`>nFX5^;5HoJ?jaz|7`d?WR(@u(MyeW;Fue z_Zm?v=vcw3wc6`&IkZ*-VRb!$OS{!+S`E<;BgglHrf|Azt%l`vT&o@6)1L1LD?n9; z8VTnsVXv`fIiVHy+iforeaq=~{Z@lt&s$*+RU4k~TTQQ%el*XmHNAl9bK8#J4_o14 z{5S@pA>HOy1thbzTLrnO-+i8uV23e?f6|tiQ4Ix65l@ zVZgMc9=Z;_PBiFw-g?dPL(d9O zP26vV-XiyP>$Mv-BPsBxlFDtz?|H8Z5s0LmjW0|}uVq73#TZud14SIe<)c>{m+yW% z^h9{~nhPSsu7ZXFare3xZbU)%Znqz;1)cix%{T9^IAN<%cmPHv0ceo}6Qb-6mv z&|b~pq?9$aq?&3~H8o9};F36w=h*Ke=&~MkLs7;3$d=RCNI>$Ie+vZARCZMyTC7}G zexTmFysO9hp0cm(u|V9{_B2=DS9gs_kBtYaYi#J^dStkTSpQ?iHUC6^png}43VWto zM6CjsN+_98k?LPp?)~Vlwp-jSMWwy6TaHWn+HN^6cXAJmvft3;U)czegC#fd)E~ zQ6L7aHP44+$t9SDubi!*@w5=dGO=m(E{2cQ514LtJ;wib!l>be+fn@q_iF`vG?}#G4C4eP7zXWQibR zIeK1rtC`r&*2bE91HBRR({K38C433q>V!UjY;8f{2O_lQmh<5`;OL=#2aJ5#dYz?H zCd?0R5N$)-@kvj3uI&CJ)k>6))cFr}Fw@S=&BLFb+ZjJcAGSJMj^A<@KhVS})a6r6 z`;cTWmv^S;?!35g@0B~-{hbR7Us^nW@5Nf8ZhxdMG*Jb=|M;7qq=i=meU+$vahk6? zA8NSnOusI>AaU7jL!w#dc21%1&JiTHSUZ*!{NTPPk|HFa?=+xQLeNc==)Ttx6V$NQ z=_Z98PXu97_8@jhg?Wi4yrj6=>bOvm_y)w|!z*cCoJYrU-i7ff%N}%6?@O3OHB{-$ns? z7eWW3Qrp!bj*Up&gGgqPs>fQaKhPjD=kea$E#jRK7h-e0u&>3%J#$|(cPaJh7L!YGi2XYN$o6VZcCBJWZ zo4pp8lQ`z+q08~r%Dlzm1?*E&aD;IBiM|8s*WDmW)LT%un?271!^~3+P$7wYe6@sI zagm-?LHGljA7`4+s6Ck(jHBKS!WP?@c8PX??h_p!Ck6h-E_S^8^za~YKR>*HTy@nG zbx5qbT<$vWCx#Ldhbz^pW`MY-LFiMOco|Q2W`=FX#bs#=gRG0GS)5eJqOC*3Ha%%q z3==fj6=X`1I1Jwe&Bd@EdWtyX!hUFK!{?^H3p>Isyn{N|VQjvuJycg!*Tj=jq&(8d zYIvaTYG4CBQqfKk?goyuq~lIMsW3y_B}Tz>dM0b|+o2l244Z^%l=2!>8g64|*M!BR z#mrJo*g9aIKi&NPri=cIkJNP?jQ8@pN=(<6-qjx}54DH-s_vF~ys8U3g2!6|D{7Ik zUU;Oen{mkOVB`&(-?~y4z;rtesbs6qQBzg_DS_gi{stDR(tzM@US+q?@GHbe#!&Y;;%~r6o z4t+Ub$H7()b0Y7(belW}4iHpr1dAorweuER2f2184W0e&+*w$fx0dd~b~sSD2e-?E zcp*y(+8+wjVRXM3V!d}iQPY=)Y167Ln*YO!7r9m<_ zF)&fsO}U6aFzFwnZZ;ya2gzIWe2>gFZY(MGI^@WAYP$F;-X|tO2+cke4_XhQRnF1V zf|SFgaKx|CiJZE`xQ0672Hlf8pA=K_hW*8ha}^KmF;XEZ>xu=crmiLGMxy$5nO<=} zc9n{mdbh_S809|GTWDA4l^hjMu4ag@qZ!%7H>mUmooEwaAQA)dhrfgS@C!I8lO^?p zHUZJxUzi8_IQ&K4YuZQgs4r#) zW8(TWnlTuPpHi2{pVloOsGGJkI`Xt`1$A|prKPys)NwZB3a(|+U`?I$>h4&it(!1X z;oD>FtGnaz_}4%luc7gF zfOoD3Z`cok6MCLRv(}=hd-?qNFaZ4JZ8?jO_UB)^_=TfDuWb5Ki)RdwDsUP+fHgT- za#dLwAM(p-?4i@9MPP6fU$FkS$dBU{Wf=f%IB1Df@6mdXbiz$ytrIx0^qkmLBW0cC zrPL%+OlX4K9(y7ccd&;Eo=k3Oc9V?0Fj<+qKCb9V*2oT{HASqt>&5_ zp(TEaJ~#%bkz1tJ>MKrTqc$bJgEvVL9tFLD-HrmiM$mvK2(Q;8P-4r=t$W*ey}rFo z=YF5gJ3IV@%iCP;@7Sf+rN^*sk6{!Z`S?9Aw5#$)!W;HPPWyGDeS4f_bM8bOSkPqh zz!LriMjoEUNg1!g&o#6ewW?oK>2IF?w6n&PYU&esC+esLIje3R1{G;rb5L;|iyr|# zj36}0>T%TvFocOB=$V^7g)JnXL)3y+bOVuSCBd{;>C;7$OTB(M_B#c_yZ^jd&HJt!B@uQ@|cODGj*pyyGhrzodS@&1(3 z)YIo9wb356UKna|xOLPWY6q>J9JDjiKKFX4_7Sb;^N=DH;?Z;-V8Op4;|!S#417y}k8zE*t?;WvH{1j}AWTcF6=j^c^9N)@Ba|yP#0SymcRe637e| zsvQuq2Bn+-~1nvgP%nt!<)(%|4DD9QO}0`Y{r0yyRXgME)5qhXflclQZfm zc>dG+1Z*w*!Mge?{rya=;w@Nwj16pnWlsBu2m)DCWH7Te^2*wAU04ucq6&^N zVjr>@5Q9*}j$611pFB4328@NajChR}lb==bY;xTM>y10t!FtP%Pw_$WV9Bl+07(l` zu&@k0;Ds$<_!5%=f-K=_31g;YHcGWat55WM2tvPe$CHoG{K~BEI2&LM3%pEy(U78S=``AD}uP; z$eGPk_>TN^in`H^ET$*K`hdyHz9lR!d573Ws&0o$SXSkRn8g`hB;81jo|cJe7z=w#cW9n@%vu(2^HUv-@Ka7WFk)%Ot$frjm4h-NI4oZx0I<35k z<2daAi82Xdr3eSKF+p_e8p1_bbqW^^yJd!?Ao6ddr}IUWus>^QLA#-Z=xC(eKx9e&>xhj&jX_5>cQ`P?&Atf`N#MyhMjiU!v1xI{gZrxHIln zLP>-6uTsb4aZn(mqLn^23nIWB(lkn`wFd|;+xJkm&bbh@QrVAIEVpE6SVXr#gAjrP zts=8QQYk3apitaiVV@8cU?9LLVD!g2NBcGBA3vFT=ix_F>J8b`@XoAkfUI#pjdruj z(%sK2SrFbz;Ln9M*u(-^EZv}kARVIX0ox!Wjb=7Lz8XI_8~ooB|3O=3kjq1%`~$2p zI@zJH11)0a;ts}mL<_$#b4gReWSSaa2DTBke2-0a!8~G$E~rLq(uoO}$QB{G9t%qQ zI0e`#A;WVC86dprP~7h1fT-()zGM||6NQuVR*pfC*PeG&13A9&%V@_}FbZP0C55 zw}&H61TAIt^P2DkJHRWXfeZ+u3Bz;l!aV+J8Oa7+-cgUuK@c@CIcsid9)GoL7@rc{ zqTnP21tg~+AHl;&Mgmlf43dpb_UMy5Je~mC$XUOKrZa3dOZt=MscIX}b}Ouz1Mw#A z(sprD1Z!xc0}miY_U$TL`ZVTxr)y7U%az8359;gB0)~{1Mm-%-J^w=PS3opFLQSbw zpuk0h(ukTM2}tjmO14ozh?6M;U^mBoDCp|yKP1ST!kxF^s==K@fE3};S%gDz1tNo7 z1GsazmYK=`_(C88{>7b6aPK|<^_0qxO7*m|JBEZQGY0I7(CV3(sLpU6j#3@x0uoZ% zafIa};N^6X3MDf%MXT)~#WD?F4$?EJcJ7`qoIp{sUgn97wnFJ2L zoV#!V1h(q5Tbv*Pk{)Sv)Q-rv>yt<ta^_)R1szGfBQ&CcEJ!Kd#Bl)8 zg+zgqFl;(#14@TzVuRsre0P9!IoS6e_FLk(%L78r=1br!`G+*#I*-2&*D2lbBaI0i zOoSGt(!MMx)#r`4(~WbJY`x)6wr zgj~_U6NzpTFOfofyk26FJ)68>{8b1fkjHr|1X#2 zw=q%PAf<|6nNdOtC>if1y{bP$SxLm-Ke@!epT2^%pa3Akl0f_sI&2(~!jVY_(k0=K zamsZFnT^^$a)71zs45W=-%&Bh{f}D1j8F5{yrRg3R^eyGm4IH-M5pN)duR zGGzh=mI+u0n63-#6M0RQxPWJ+xU@b7l(zVu3JFu)8<#1fdlS)QgtW^C&|4%)kluML znvToa-TLwHFWnLXcqM!@LrI*`lv~Dl#_-h`LVOimwe_bSpuY`@z>hiq3Sz$wNP<)o zU<$;0WY!B=@kp`g0V1KCKL8d4W|9g6Xb3y1?cEQ=hPBmltlREe7T|N{Q->v=9#4a)JG%sf(8ut>p}wrNKDJtwO$gDq&>ghY*?$a!Szxq$J=gN-MWK zXg9f3kftq>NuG?NNivZQ1O4b&A;=l41#mVMge_rhp&<;!ddoq+LbhyVpT2}4QJOdn z2OB}M8X5O&bPN`rggXEw1ntlPRguF+r(7575cHaB*`Tm6h=dL1PD4WTJ`zW;QuvPm z;Be?I-_4_;I7tsu-Y=)bk;cks-(UvychW3znn%uE4BTa`2K_?0sI*tSDm?g{aA=91 zXpD@et*11J86QcIPe@-E*TBotqE*FJln=>!9S<+zU~z)H6~ynzVVqTNgXcFj7s{Zf z-jD*zMm?l+jBmO3kqZXG2Zqpf~Um@gXyDMJ)YX~SQ@js7S|Rn!t8vS7FE!dHkj>&y?B05UATMa>d0C*Kfr(9|gVgOYXZ0@t-?a^kJ$#2a$D26PN3WbDuos;MXdLfG!w z$FjRT^+@~%^;FK{ilmOPeiTEP&RDNb_CvT9_%56jSiW|F=Ky$;$6ragsoqFopx>kk z@P3_q!ruIl)XP~^Q>F=LEa_*os%EOE5g-KWh+`U&15-8h>7sauS|AXl;Pk1qkQD26 zH)z!B;MG+^v)PW@f0lO{Uq1yudscxSw z4P>QIc~LfF>z!_7kB12T_ZQ)uAmHx{@|nZ~oRW#rf2d$G;4~S#c;ostiUQe1{-2I7 zyku9?C$G_ep0K@ck1fMXxSi^K9U8%&zL`Q(-^>$-6JsR^e3|;p=mw=y5)Kzh`m57u zR_R3OECXuNsU&ZX4RqGiyn;gpK@PJdr7OrJ==t8uv}7Trz!)a0dDX%mF#fAg>4X87 z^r=&tI{k6=i@=IMuAWvt)>Ojrw4YTkO}#kv+_Y+bT%F_dERvZ&uAam@{8j$DIy-f2 H=Is9fuQal# diff --git a/basic_function/__pycache__/operation_new.cpython-310.pyc b/basic_function/__pycache__/operation_new.cpython-310.pyc deleted file mode 100644 index ba8434ce6cdfe9eb37fe62a693c22be7c68969bf..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 5811 zcmd^D&2JmW72nwpa!DN|r2-@G9B`H#NF1>ci`Fiu_ z?d;6^y?Jkx(Auy; zvqERX(PeHU$8GL>W^Cknj^|Oj+~oz71zzNHD2v?VC6sf#%;!;he1R{bEb%42jIzwF zdq%bL1aCH~OfA*7M7vcF!rG%wv=<8A6rkp!WJ}2SZk-X^=%H_Nk!~Tye?(7e92rC7 zfjM9UGhwND$T)js%1UA$SxtP&=d8(w?4ChwZgKlFBW1UZL(m3RYGSlKG)@}>8~qM< zKy{K_YNhs}dE0pSGsw*i@~Ojfs|L?G#;TF#(mZEq<(CcA@>y-gs2JZhEaRF1Ug~8N zUK;uJyn6gIct>xTkxPgDWzDSA;W4piMd@kDpH%?znvdY|6FRP?sk4?4hmRTFJ)mwM%9$WtI8r2 zG__hgh+`2~O=YfKR%~D9(5tgA-(BCr?ALp-kn0UD*SkF_YPasaH&r=EwxW)vZ3c0x zUTgF^^`sSb)}yYF0o7_9vA5PeP{nIuD^BoH-RotVBTeFz(La~1MQyweYuCaE$cV3_ zGUk75KCx(0CCX?yp3vZGNi8%*Qp0!Df)20M1DPl*6dh%Dy2=Tp3=Y&1an$jdLe`>& z{#NDmg^Xfl^;=!V?kk(Of~Lxatxga&l`UG$EoJd2QH$MPCY>bLHiP=3YEil4iK=`M zhV*jfQboT#m#yW-q7)%Mepl4GQQV?c*l^#$%M;n!S2i&zR}0-g=qo$e7qKc~nI_!v z@g_*J0BZikD&;|pUqI%Wp1Hs(rpJnG$-Kzu_Y!{R*%@=mEHjVUOcLXavXBaqTy0hQ zSW(VW*p5|L25v1IiE+fZ4aG5mC4)PJr6H>r18DyVK3JG%@WIZIv}ecyT1T?18J=l5 z$2^dI1k1~FsmXK87>W78T20k4s}!>&Dr$DcSfM&m(WY@3^|aB%C{3T?hCGtarX)^F zr_q@llBlImJ(UH8S+!9M5Z-GedJV4c^NXUjZ~Louw7IndZ|x;VX4hSC@sf zpz&rAooLivF$Qz#96Sy~ZgHG7IP*qY52%sCxUJLQ(6Z4g zp4ENmrSp?yfrn@#bU1hQ1|M(dlfPm;t2w==@>{6OmneG?nX=;@903Kb=`g(FaY+^G z0>+v2NHP336|3WD2)kQ>vNuH%D5n{;+krCcls1u?%8Z+eeI(D5+{s~>+2La2N`0gV z-$8dw@AhD1$G-~8eixYw#W;nuz@7&U%JBwRpHhyAYq#-#r5sQ!^PhTWwaS1W<~LKLG=HLJq(sC!QKQ^RHIkQTcj+8_wl2vv z>L;Ab%aonegZdS6CdlHxi|&|^Kh^;%pVkA9oxzcOhLxH56|~_x2Gq%m89(!=dLQU+J4lfP~!1($_b-AAyrsa^h~-N;0ch zfe2L)onD(1`{bQ&C!Z$;6b#ZoJJT0u$&DUXzv~F>06_!D*Rx- zF1qBys$9_RiVjx}T#2TT3y>nuP)0aW`FN)%L|;_rGeI@1K&5j8_m?aoZ=7Ni}tYxrC>H>zZ=_JrFV)ov-P#o<)%MUtzEVI+hGtDoGOCG#% zAN@sBZbC?%R(kx98#00T-!h16cmv`N61-InbDYG_fVn$?c|q3*=0(8Voxpr<;EiDJ z0p_K&q+wo8%Npk2LrE}y6VSuG1)w(vn135EFOOjUW592!2mA!{4^a}#Kh@(1(0UxG z1oM_2GfgF!@1U1pzOTm+pvPkf=7)OBgsNd)9>M%~BbcKNm=~w!NiZ*uVE&>8cR|D4 z0Z@w?q{R`;?Gen+4d$^@8O-&q2;e=hVeU`~m@jB^u`q^tnmeKqFbYF9rtSP$Q*uM5 zOBmb+!f1H|dns=tt1fQX2m)hypC}(t_7loJr0gSP8}rk;4_5gxb!|}g31tr`dq~-* zl+`HvDP;j=&r;WwV|A_VwURA@28Hs!#DJJ4mFe#Ps=a0Ce8rS?2%2$N2sC8sZXv~Q zA=BQa_FSQhY}A6Uik3NQ!COVkBHwNWz980I$BXblZP0O`Ap){b9?n%eGi3VsJ&hMB zyMw@`nj5)u6a>jDpvzY%dllKnJVo&lLik!1B5kaMFq$GxYLgB`QO#uQB=*07+Xh^~ zz);q{X2f40KPFRO2IOG?5OO1^Zh$beJmkyK66xjn6)J! zB&Y0sWZ)GjR8ESx7beP$1&)Ufjw<$^V&7NnO~vjicKhUZtRHU&$Nf)$5p;1w9())I zxbk7p5w%vQR+k6p4*!mq=y=SAbG4*@p7}iPL$F0%n#a5<&t6UG>#ytxba4K453-MX z2h8#=WiL=h9-a2twX3F$xkiNEcP)8p=CtiufZ$$hM;zCA*JQSSj0pN1gO#0=%R4!{pnYG&J2ef z(bU@BYzq{~1$pK^&OP_s*E#3jbDugKRtmz+>0gD;xhd*jFrfxrF7k8)A|Fv+%1fW8 z&e4Q+V_XMu9nr~g{dqkxK%V}*kr?H#iI}{GI*ORRMu3dh1kmC&1GIV>fHtoMpxtW) z=X8@WFvw1_~=xG8yV#w*A$J7aYVQnq!BF_82IYp|I}na zE`8LXwls&-o%>EgpC&4X5MK^HU3U)KOCR!66uXt8hJcO|22EiTOb%0{`N94o!I+$NZ-JU!D;xr*UG-JP3t;PVTLlE-V>2Xl%1cnwE&M|BwU zPkU)O$Jhb=@G)LvA=Lrp9gyq*9tG(2Zb$>3R1C?c z=a&0S>fEbedCizKbYYDahe6mI#i2k+8NE^*^A-)=x?&2*um9~*_0*NfTVYRcf3}{# zPQ`G;u5W=oqU@6peRkf{?+HKr|%h zO2%1^LbcRMq=fou80w=w1Y9ceSu_!O+4YNWx7zj=B zL2SA+I>8a(1`45#y%RH1nKE4gwKrrZw1%@F{2JzFj{2_Fy;j>L)^=rTUlk3mxib&V z*7=^F|0HoWW3F82GR0Yc)M$ zP0#)EZ)y&Ibw#Z1&p0m5^<{x<^D(jcSf=@PvG(<}!Ih(xG-b*f=FUE>+nx+PsOyP4 z=M7nAOQJk(OZ2|)Otvg{Ew!zdi%t7h{rA~>Ltpla-KWK-(_&qp$n?GM{J+3GwcNXO zSghZ>T6@3o-ZruRNV@K*$Q<3+_zyw)3+ML^YTLobTc|CycUn`eVp+?l)Xj32ylEPgDxCu3_~>=kXh)@*x3+n%{o>s8ga`fm26 zx3z!%*2-I%s{M2QS%ZDeutPNLNKSoY=+4?I=Pv!%lV+;sA}ydg{;F*InPbcb3jp;e zZmFkD0Krb7eW0O1j8iDpyn577Ly%I!U{1fOY5amVG-xpL`oDl9OL|Jtav0CcO^EaK zb6PBEMf{)^b!rcPue~$LV^-Wd{ibdGkbaeJOFIjB`=~GrFi%FW$*1$0a87X^k z^v~qb@{H*aLJhr~GRQT69ONkFaz0Yjl*5|o1bR7)QWS7iAs_TJtv<_#LRoD`X*JJ> zrX3YS%ZbA2<`g#PqiBRW^__WZuZ8lM*(#8n*2`mI!L7<&H^dc{$+|&X?YZVTdh}KI zH7*#764aUF?0cM>b(3gJ7S51H^{^L+a4Hs33ljkpD`9yd6eZ)_h&#+;^lLV#m3l}rVMd|QC^6V$zbf+<`kBdrrlHFNW{%YV{TzG6yk_GM55#FGvV=J zHsX#%gYQ`*Tqx!q110R9?23d5i?iTT0S{nIY*W!je816$o0|9-~BGBlTEF&q(yFL{AC0 zv1I*7S}?yHRlPivTu#rDy^!Aof8lRrrS`ajvO4Ckf>w#e>ymnr+4k}9ozU&jqW80X zpLefxFAc4Btq$DVFYbITv!h33dggkw7W+$9U@Y^EaYI5Mw|zK#D|9oI^nSYUZue4m zYG}D@d0=I~xVobecO3+N&-I?H#X3I#95q50-ip^H*d(3oN*NYHqNP5` zF4C#sVq=P1WS8ls;Bw;1!Z;BBJmRPDx_HMpf|8ef_w^IV#6+I&eqEgG67wGy6SUy`QPtMgCE8Uyybe?A{e zsWGqKXU4dXk)cJ-$C+#e)HMXvBx^RU4xHOrzpV5 zslq3>>+eB)&?mIN6=VL_{BiuEe#EcOioM5j!1yYs@2oP0oWDUo7%$4=uORFm{WUb~ zG075-cpy&pAlMB+G6+}UgkzCqKAa!86^CRAa&XK*=_=9i8zLmE7I?rwxQtc#Z+=QHJYjyt9C6{h*fQArd=+XznY{MPA5++oD<6$(u^A< z+E$H3Bb_k98zdU8Sy`g}>yHvGe+J-xlxVOlC~ENcBGJ4$u)IMgfrV}C)cY*0l)p}u zQy`IL=_Z%coqB&WmiT963CHARogK>+iRZ7BF=SZ?(pr^5&H9E0{8V>Bz$Piyt$rC&YRO&@_DfQK7>UJri*-teEH&;BXsY@^Tk+dL41 zC_jq=Xl0#xmpKS4x;SjOtIyjvrT3fHpjr@My*&@tn8#&?V9zL_aHk7CfM4f)0Jn$i zYWZMke}c(pRr`Pz5Y+u=v|xO8uNgfnBfty98QIGM`_OCgnY~u*tj~gmL_d07@_q=P zUFFPc%h#|M)qs5l6dfAs&O9zsw`i!Jm2u1RG zq@^wYDfA?Nb3a(}hF;UZstXhg129yhj_(loe3UR)^0vHCR~;kpC1^(y{0Bq%`*N=f zq#BS@aXA|ON&Ww+_!xZp!3MRp!q;b3vE7utwYUf+?T(5 zP}c{n4Iiuxhh}}^xwt5RAEU5-cNG;vvI$O91uvZcD9(`Dg1;YA*l*tH)0OOja}X`r zqM-WF3cc&UlIz>^c*%PbTd0-lixTaY-~ynXpwRH>6k8e(gtIk!TaWJ1z-dO$>4pvzhA0C-vr@8;E;*seG8Hua}_Na z=--e!3DjoB1JGT`mWwDpjyyJ#i=CA8;Sq3UOhw6vRIb4d$Ri@z)hy^6oI?JX;L;V! z0H_C$+&6a)fi6a70?0HO1Edke{trcQZ`+FCewzNlgAVO?{&<-q-VvE_88t}k8H_q35} z-1ohas(a<}9?IfcW46JCSn}*Q%+4%RrKZ-Ho!@s;RkgqAp`4Y8t`CFB_DuEe2aY}S z^m=XG9n)>o$K=lR?dj#M%eG9z{?*96Xr{LBAM5|AL9F%9JF-mu8q+K?&5QkOEeFMx zgPE4Y511p+s?+luhO#^HPAgpn9t=;~;8BbNtt!tO2Y;y=2eM%R#(@$3aK)TA5X?UF z|Du7wXoM#6R20zFp$a!}%oc=NY21Q=VAU81dEu`y5WaW;1HrC}XWrp+#Aeqvu+s4VgteXTMUHK>ffSLc2i*PmNFr% z1&rq!2>BJMt`Arn^PsxRH0$$sGY~Y-bo`e#7zly9fv{y0BWTNu83;a8ss3v?-^6rL z$LGab>A(JhW=Q^N$8$Nzn`QfTEfhHof`Xg_;3ZW(Qnp=Crlp7e8HV=oTBgGYzjj zV0ykcqdvCg^l{=+#@6!b;9dVG{`B4xnVl!owv%bY$!9MvK_@F0m;VltkF+ij?ExRk zFyO&LamA>6l&3Wp@tf6<-7@(ax)-P+J8tqd^x!Cf*1=f~kk`7xk^7=BXfDctq%SW= z@e%?Yy@ukeD~~}j|M3J*TmU2k4}MyYDR*a&{$6qb3tvOfgBWG_=qL&=jsl7|SgMJ@ zQG*j=fqVftj7MW*^79O3Q!6{U5#DdY<0?2`1U9Kc2(<85@E1_T%u(x%^#*gDNo-Aa zE!1b^YwZ`v`L4QVflIOr;kio>cew9_Q{g*&ieC<7b{tI`wqKa)R$arP)mM4v7wex(^iQE1HPjC+fLpzlwUrYS|v?De^6D*>O$4jfP0+cFz6`G zfrXh@N_aENn|<^kp00|5_ST{Qi`@MLjqepMdh|seDagRK3%QViwyY)>N3M|PVG1i> zG6vyQ9XzXJhXv(uid%tz)kYCG?Nfjog^y$+oG=-QNsPe3v6Y7-H2tPT_e%5$iG~CI zDFFve)ndw$?<_3&k+2YBp%OmkLtj`V!husP!tz`o%m;#G28tu(CqUxw;4l0F06b=M zl+E=?uKL?oKDm;q>VUH#96ZimKbv@anNIzQ=xWQ@+u=O(!kG_qm;>)$oTne!%RbcI zGTk&Ku6|(8*tdVGyK7mp+;uKFm#=2rU822fp3XAP8;-am!9HMWzb~g;d*LO5y<)Ck z-al|`cW!*HCqHFW?h1@>Q7|#1jPfPu47rS8CxX2QdJ&vNFopoNxxD@5Ay}d029!Z$ zj*Le~;8Er=@@IgAdr3GB;IW>j>0ePDzoHtyr7FIqEZpjn d0PJg_Tc2E@chRTlCxbLYH$Sbi(-m?9{{!1__tXFY diff --git a/basic_function/__pycache__/operation_new.cpython-38.pyc b/basic_function/__pycache__/operation_new.cpython-38.pyc deleted file mode 100644 index 6b878d7e801c09b61695e984efab052285e58f68..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 5759 zcmd^D&5t8T74PZ~yX~<(U(U{c_^Asg?SUXt(NqP}}W<`+?w10cs|UcZ7`Y)ETjjfAmf+(oLl37wAch zp)oSO+!;GZhvef?3VEmw1JhF7;le^lg7YC zzr!6+oj8+NiT%*LWo-Qmk~4#B;_%EVgJ&G$l#yhTEN5tCR}Iv%X>H9Y8{aZ4Sa@28u{hCdi+y($1}{yB}sE~%d`S;mW_p@+4X$05;Yg*nd8#(r(9Y_`Zh;gdmkiD z;@r=!k?ved3G_2!(3TUM_jAKMwQ2tQ*#Y%0BzEH7&yLJN4tgylEXj>vMTvpkcKfgO zqE@Hr-45HL6MJF9d!r*{-15CR^g`JpB6i$shk>Z~0&&&biR13ojg2VW7E(O$*Xv<> zLq8)*|MI)SW-$+t=a7900jf`$IX+_N{0M@Gez#H&go5VV|H z+m9jpt&tX+F-)WXT&jWEKAQd7h=yh+XP@LJuMv9bcuQD&#B9A8TRP^}P09WN?mEo|tw zRZd^XFj7{()m7}SvU$sIs!Y)8_(42bV_M3B7(I{7FbigZt(aw2V0m+eUDVH6c9xwpPorG|T@vG*xi}z@T=-w zBx>nXIdk&$tY6@MiuCDwz&l&yohw=YGBzOZ;dghZ6?qd3d8nt3!}nm+=N?pEdg%r4 zfvCr!B+t?IcZKJBGK@9Q&?f6z9}5x_PLkM?5nPP)J5eK)ZNa@(ha~vk+dUt^=)gWN zk_+mLCyaOec&3kd>tQE~Wv?E8amu5mr%UkmTS4G;!q|&?jfRk3LxyedwN`uE4?KW- zR}Tn8BldQHTd&svP-G<17sl#v`U%Z*wd+g2?LD}fwxo!AvXeHNuwMNBenb0%aqF`i zcGi>@IrM9^$9MB#~x#Z zc${&acrt|0#OC%HBSjxh@2sXdb2L}e5Vg$kDfrtWQff{P(bvuhWv~0BUXLKKW z>C_}y;33)=9nKxS!AINq_&3mJC8PILejRmrma?xQQ+Bk6BOs?O9gbH#Dym#vz&Vp0 zDTe>HVs#u1L3hVj_O^(98k%%#sPC4}*p9tl1gM6GB z;@Xnr5u^gbpsPV4De&yB8O11FTpm8+IfTM_g!;K_#@>gqn-p-ZSU}Iucwx2&{6&oW zCCU`sJ42wj0$P#yEttI|UQdqW&&SJ&^^oaH z&WnR0G`E7ZnizChg0m`~GDarCd6aN=rNqSO(y%;20K~8dd)u1Zq@(-Mi94S&PZ^)H zi?V>NLu{>29rF79we|Bi&b@!(($@L+5#(LI^zC~a_tv*ARP=e%H!mEdyx3@?Z(cfh z{BQhZMYk>-f^sO&r;WZo#X#xb?Jw=i5I2eFu)tW{SPVBHOv8Z2MJN@816Wep8|Jx3iq6@5!~~DyE}#Z!k{pQ zdjW7SCPfYRQc}`zKR`)v{}CXFFBgE|0^rU8_tF^dLx6Cm2mA#0U!f$p|6Y$HVC!+9 z65Ri&$IMa*?teiq!Tle448eLbCWSk`m|&(UUBkUJh5MB;+(7}{^D}EDxR=Io=X>AR zAkS&II{<86!@WF)yFG^c>w`t;D}_7LaJQ314R?o1z+PD@*@-9){q3j2g{gASEkyRIGF#uNi9(7eId!Mp< zlx}h7Zc4W4-{Z_m~5TWb+k1-&61=+-g|2Hcxv6H5(L&{tLL(n0I zcM~ak6`77UbtDUmWaAcMSG3G=3(+fD7R7pNhz+50eSbtmYJ-mB4Q|yweppxR)Sc<$ z_#`G$b_W-kN@g71(WOXU0bRa8*^9`ki*&0Gaiy=NSElM(0M{v^xHj#;t*ROC9QOcU z!Pf_ThJm}ReaVPlLw-aK-vHe4=qtFStH+#rD7GUNY4!XR6+l~`+Z@3yM?g%U=%+6$c6wsVQ_R#^nJNC{t4$Q0UBZ0CvOzXxaE!XN~Yt5Z41| zm`Px)XAFGKV?($U4Gu5C2`0(MRoo-(DyNQ5BltGrZ%3&^Cv{6f+1n{^1BJ>+5%q#t zxsiaY@8H;BZz=X2#a>bDj$*ftD|O>YsT^N*e2k#aF%*{vfk1#B_#IJeb!v5ai0<&synm=HM*r9;Pq#2AJg?%ATi;;xrwJ>i|r< xW^LDc_q61hmDA>>VZVE=9ddk{yH3Y-L@J?x-=)WN&$z4Z3HOwH)~&et{{WBwBZ&Y2 diff --git a/basic_function/__pycache__/operation_new.cpython-39.pyc b/basic_function/__pycache__/operation_new.cpython-39.pyc deleted file mode 100644 index 46e609252f9f8023be8402c94270be444a8ccad8..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 5839 zcmd^D&5s*N74PZ~yY2CKJ~xxe?4f}rj5o<-$u1zGY?9q%KOtcuyRhrb;O%jhJZ)_FY_7XB|gjNkWce@ zzJR>Utvg0#@exKgDooASw?w;D_k-GYC)^7JZwgRzVZ0?|bi2-oE&I_oxp;5jjed@n z#5gnt#=6;OeKTf>`G9e@ZOYZyJhYnll8;!E4cHxn>fGY?hepC~8V^A0TZxI@_P{u8 z^lh{|+yT{zbBUGM56qj!#=jsp*Uu*o&n+7~=NQXIl1uWOp_X4TP|ByJC1cU}nqe7N z4DeDb8}U-l&*s(rPvFf)n2<}7Cgd*C48XZ)OdXA`#~YWZxiHQImzF=`(lXMoIpW$` zPn^WLmtUf_bLma6J_Cl;a$@sd;jl<`8vkCtPwi8Qow)b%1G8Vix|R}_6b8_t#K2y= zz1Mc5R;TIR3frO+dtt+Sqa$S8^1V3pLfIlB_Sn}+4PddnSGNCA>-ZtJlc%xHDnwZE&xzA$b9!oIkpO`!L9pj-< zF(vV;vIqoCt=9IVNJJG=nQPA}wl8yN)yb!K>zCKxiJ+6~FY$J((~4s0$Dv%mDWdH- z?5=lrrKsJw^VV45Al?c)nzrdjt$MAo+o{K`u(KX^1v;Zt>xjLz?tv;^2_O(3*}YoE zi2uNA^uCa;Ol`O-Ygd90IEk*JFyww{KC)(&n&08Ix-Vm8 z1)`(OPFFd;l>UL5CyqKkTF6@1&|j{co{(XrtX`|D*xSnHEx)O9L962jO=XK#b4yt~ zjMZFsHGwsg8%8Z!gR1)m29ITH zU)jW@TrG6}pr>qqUqosei!|bnj2y)=D-8}djx8v0a+1WB zj9_}C--#NbYzywSIwZmOzP;-M7#-;6c`{9%Y`}QSkH^}Gw;pz23wG=AlS3X&J)MHL z*9ro!6UJV&+h_>sHDuWKUTd{C{lEjbw{?d=G-7WHxb=2B0E&!6`oWkT&P1VkE_Z$D zx4rw9)0!0VuI!|hCbSp7zg^SzAgg_HEh`w;C0Rx}gQVifd1T6haaLxxrL1}w$pzh_ z*nL@`Ms}cBPqDq@xSJ#Mp_f=>@clg4BNAv7R@o|h29M2Nr{}S}0O|zJril%&#u;rh zbe_&7ynR9vm4TB!;f2Z;%0pZbkDXzJc$}}Ur^KtXFd46k@NtqOGk<4k&_82+1z)p0Zg-7R0)nw=x9vsaItZzo+-jt&>YdQ22AYHv&hotJaeHKr|{0Qr`h6Tc@>;bD96ar ztNvdp2h@$Nkm+@rR#KQEHPYJ-x>gdjF|?T4sOPmBxu;?NlhTS-8SumWrfQVrkF{h< zM7$$4${klDxrXh~KGn&(Brj4sSqb@NN{;J6{SrA7WO2WN=IA_<|E32ecE&uz&am=h z^Rc`J{yKRv!ylQdJYH=oK7uwhVHWVh#fKSyL;qv8aUDj$pVGvnsKJTy5D5tsuD z$E2-|0F}tecVj2XJwU&Wf5OEt^ow!s5W!4Rg!2jrgRc6eq{Q>vW)vfL5mJ1>3viK( z2muOLjGgymHz^_5m_p0Xcxk)^{L|?73*=?A{zmsJ4EupP1M0`R&p34!eLg|!99sXX z`^=B^S=csz%w$gY8qs5N!b@<%7k9A&axI>UONY}$dB75PY_4VgR8m?(*z*C~d4K2i z#5$Zw@&g2WiH(v=rJdP$mKWo>#CpJV@bf}{8Y?%CcOfw-h=Q@2UN#0M+#Qgj2qE|e2>gI1i{$EP4d6+~yZO|H05rW7jQXmvO|Rvd`2 zl$PdbMK$ms3Q%CBNH8;jv(<`D%LlV9X2L7HT|M>}EIU=iVz=+dR zWT(s`(pi{$7b*GHWu3cOUex(%a~6Chc;6nIH)S0X>a@c3B+u3-+SJ9y!> zW^l)ud;;9v5!?&9L~t(x?(PWgQ~go~_Y&YfolI-Emy@!F`#y4l`wsv?+-CrSQ-C`M z+{+o<4*|ll7Vs0?e~Fyn{yW`|fUWz1N^t*!?lVp$xc?cg1oyw|J_PGwpA_!6N5M!V zx`ums1oulB+(7}{i(_*oxR*1y^PO*MkQX%E9RRkd;l7x`-Ok|tdVdD%mBL-$nE>W9 z8tx9|fcvaA8?!^WC%HrF0pl=WLmEs#L(_5trgIqJhIz-3*RiYe6(p6ps*SKPlJ5}Z zyOey7l6NV24@q@q90OpLRcg9R$vPz)lzg9(8YS;j;#0Co$w_LvbfmVmy;i(M5TQ{2 zhv*Q!jAW?8|C<(<*=bV}9wr)Xb~~?-ezRe7z<3 zhFEhQKf)umLC5iih{!&ESXb=Soay6u83QT1gFvQ|%iKE(isbX4%NHmi_QGJMz zzLthe)ujNYQ$%rX)PSg}8E+l80H4F{1FmOaE^D9B<5!R$m5>;KJ02Z^OA0+EmWO;h zS|YtXKSBY}X1K6pxUfsF?cnf##2OaB84MQ)fv&xbOE_Ajbjd(L0In$%4kK)!U^t^V zpwMaM0Os%lP_uD;XAN*kh`_)(%q39M6$2N1>=4F8!{SV$!_M0k>zHc|0(l`vR1j!? z$Eb8%9nmW8(Y`3Bj!PBXs`#5xdKAz`NkPeXkbt*=Oy#7Ac7s^Ck-#C*!GXozQtTzg zURLb3VmFU($Hk-V;J6R+(St5>$fFMe0oOk8JEGR=)avp8&B1RmNE)A+Hk_?_dayxH z!&b{3n#YVPPhM2%%dzwzbZ`oG57Li%6U_2kl+g8wj_tG$uU$87)HNdXzH7;2GpB7& l0|xg>JLI_RyGkn;okl{Tf=iF-uDA>CDR

sH+2e*mq#F~R@< diff --git a/basic_function/__pycache__/others.cpython-310.pyc b/basic_function/__pycache__/others.cpython-310.pyc deleted file mode 100644 index 2bf4f34a432a4270531ccd0b0eb2b3ce7c09169d..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 1562 zcmZ7%O>f&al%zf^+p_b~G#~557*?Qanz+{u!8R;y4#OG@NU>pR0zolpN2@F;6g}G- zb!m3qZ;+3@^uE8s%MLr`wsS7squONNDDb_<_r>=fAE~s{X(1TzzOYAsHW2ztFO~;^ z#n&+NCjcm-I7a>ljD1qC4d0-I8owdmq$afhTNIxl&%Oo=bUeTseH@y0U>?HEzXMP- z1d`|qWAr2VK0ziDc)p5}{(eFUo}!{9kRX*I(E|TZ@8PW;B`0W#JmdJ~k!LDv7^k67 ztsla8!k%-Ua@8F2a3G>I$(3{Tk@T>_`CG&xYjpFhf02$^f08raA5z}WCY%L_C*Lm- zb8(R-HSavkqd_p7B(=+aDlQn$_p`ZbMtP7-#^;P%I>BuK1UuNqE;ccFPmFF&+d%Us z2nV{lw_^ne9ZOxCcTlMjDhW&jrYZ3=boREiq*b8>sxZZrEJy`)U*OUfhP1CRHD8nb z({B;ADpT6JT5qthDw_gbHfpK?lp`8i8~pupO3G&05>C~mj%+PR*_LhKN2M!W+5~S$ zfLpH#e?i-_LERlBn+NFfK(wR_`*lF#)9$ha^jE;{9PlbgquQ>0YX{sl;ELL2nY#`) z5~TM)dR5zOEZN)#?(enD9apsL2sS5(SHcyYsw-Wcv2>QAT%lXi-Gc6VgD;;-qgs`a zKSt67%F)!NBi5)3)ibW}6@knjlpUzkI?Q_#jqn0XLw0u1f=E;2EmU?T0qbtH0jF?( z0X+#`{*cgr-uCeg?kRMGpIjSojKxkqANPJOdxLZm(_WH_-Z|?D&O+f4-T{poc;s4N zikdFA_8))s=+*At+2dEUdGU1b^V9z6{@Jdl@T?%aGv3vjTD>59v+MUT6rO_w=H6W4 z0?6W8_Y|y)m&NLor^%?N<47~nONZX3vf^~gxN2~g#o>S{E6?IcC?jTxvWF=j!_6T{ z=352hX|9?K`kC=iFhw|1jbW6~Fpiax3tqP{*V}PZ1Ksng`Fu8DnZ9bu4zr9UR9QSs zMobwr6}}a69?tznemP;RV9FM09Oc4WQI5W%0i^A(WcD!aQ*726b)z zbp{cR6#iMc37g) zCjJDs;9G^K!U5O?+^w;*gk8LiT|)kDt-E-g^za5Y34a7`Ji9&u--6T1=6nZO{DB7g jm8mtB!qkrpuVZx%#wne|>`UEKx&9G2$btPG;&%Q4Np+O~ diff --git a/basic_function/__pycache__/others.cpython-311.pyc b/basic_function/__pycache__/others.cpython-311.pyc deleted file mode 100644 index 167d18603fad440640727a0719ce514ffa283430..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 3270 zcmcImO>7&-6`tiT$)!k&lE|ej$+p(!niBQr+9?`G3Y-Y~soTm*0@(rV&>(2;Ql>(2 z+1X_zF-ycSa$o=j5D*v9p+R^ETo*e!zA+3`!`_xV#SJ_jzKw)JY@5w1OHwous9 z3J3wU5dpRl0Hi5!3uKrLuu0zn8BMyOrXoQlnv7?}RCHca z-dE*pG`pqVlQXfomA9jJ<4P(KOXf0lo~W$e!|7ILMOSSGiASAvi%5*eVPVR0YS zY>eGoWxd7tC+s#`WGg;@Q9Z1>xf-DC-j=Aao+4ADaogeHVa>tVIc)T4%IxM#y+(J) z><$^MV6uYF3Kohw&ylG30>xZ0M@TqKqMX42Jn;F!XI(t~CB@TUott%nZ=7AT9Q{pz z!ZOjeXn@7&`m%q^kY!&51si^8{4yFf4|b?K|18)wdkMoj%Ol+Mq}g0*rbn|0DTcPrIU2fn z!c)Qd5}eZ-QGr!`4GsWK`kn`Oeu(0T3$8+!+S|gdxf;62YF*ZOD!4T_wnGJv<`G@T z<8@%Y>qv(l(XBZ}&p?9$Xs*|>zCEaN%_g9EtZ|CorgitvTEB)p_}aYnH1WinyTjW% z7~k8(5YA3a0%Q8xIQ>mL;S6+eeu(j06GJ#VF`ugazc9*R4#07>MGVxT7J99_p`V-` z>s)E<&&=+3>TtsvXLrBm$>XZk?g6!ZsL`dK7OtgfC-Q~@8uJn)Yh|VIZruJ(4YjyL zk7bXX*g@88!kGjMUP~_bHJB`|{zj_?Rw-))8m)O-Q)=bcnD#z0Cxvd!_L7lxmjVqh z?f=#tT58MU7RrjT#z%o|MZ?!Zmp)1AV*kT&-YJ_y)pXU_~f0j_cpikznXk4dNDe6XFP(a&Ae@V6OzS+ zlJH$4!MG|R8{u5tdtRv9mX@O5~4*}MfNy^@h#+3{+N@byo4O8oK7if z#D`dNXU6cbkx$9u$YQHd&Shgba1_8EV^B(_rGy&G%1WwE?;))?T$tL7T}SP;Ipo1( zCZ;h$j@*r!l$H?NG-+sah)#){h}n>#hvZFd=*2hR*2BsfwSwzjP!{jnu%P zJ^WuHik?dF{C;q>92_-*V`gw{KX|zuy!`Z)XEq}^V+LpTgR|w}>|V$SE}Fr`qPyzr z+nF(Z!=`VzbgAqU^rI&{yC1$*4qy9r$Ozvw!#53|VAgx8e(vFw-9^J6GX0?wR)7Fwc{!$I}cnOZk9Z$tS zTzaMK59|K$Bn|%RKH<3lE+CT8{ zgD>v>`QBr;lsC>^HqTyuy0q8(_qUAxd9#0hPt$MQ{KrQ||GPM^*wg!LsLX~+OZ&o= zvT(%^rcGgbUzjZmvxYEd3UdZKZ?f}x{rR_hSZ9Wx$INTe+BRg(aENG!M?sOi@CIP70%!)Ir@}kEbx~kMO8qL-gqjYPj5UG;JRP`)fhL0 RKR}yeP@wo4=oEE-KLLZ7@#Fvi diff --git a/basic_function/__pycache__/others.cpython-313.pyc b/basic_function/__pycache__/others.cpython-313.pyc deleted file mode 100644 index 4ade1f669eccc13606e596eb17f247fa38c9ce6f..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 2865 zcmbtWO>7&-6`tji)KXlLwyCB59BUR%H#2+K*VhY3`}6r=@$WqVd_xDU`P$53f-tv$3RLb@fQ~4KeD0V> zyy_bU$glFG4-~nif1jFN&T(yPN&@8L$#?iM>285jAk&d*;9Q%RL6I8+8MZ3u0Xngf z!6281npFhyo68g6F_k?CN_McW* zIs{yx?`uIk;0gWiAm|TBMB%+sBO*VD$h*g#$i!EI;!&2qj`mY!={swLfc_>N0l(`V z;cs&z{3oLq!56&?BL8KWBYm#35kyZQpSyueBw9pjk*MHR1EU%n+MZ7!V!HRWsDOSG+%W!q2j+|n7<-a14 z3_K;0_$tszs2LLJQOgW7E`5@5T}rNjP8L`;G~uzF{gt4NG7EJ;6KBbQ=|V0{uAsGV z0H9Rk(_>7tOSMh~5DouRH3Wtyle4n+;SpM9f?=qlr?xv4&nkgbL5@^GpQjQPX4$3V zz$2gR&NGf)h$)WwarI@o7i@*AYji~8HX@qusS!ldNru#=>_nxfyK{_VGPWYgB5Y^T zr-0<(Wz5MKU~XogUJJD$YimP(=a+KMtd{e#VcPOFDBB3?cG8FFv2welZ=Yj)ZQ#u5 zi;2wi>|3WZn_IQ#W`FQz>diB6O(!w8S@TVAB6`>my<2L&+0Fg`l7_i0A|c=G7UpV% zt?jqvwJWt_*O6%yW$K4nv1PN6jN?GrybclWL9kNRa}WotO1WfXz6=c(Obdq~xw#?IZHRpc;+{gu$m``Y<}Lf^Hb?Z$LiD*p^fJStOPe`Zp>GNf>XiyK z@;HEWqX;pdH*Fly5z@DCkM)Zxgf)nRwplJ&b~1_u`W9tL*f?rch-cX%GOLv=aUA!P z6|)Ey$}neVE2dR);YB)2+(o6$>>Bot7O_ZTrlwg6j>BvO%Mc446UHrKzLej@{_7^n z``tyV8p{B=6;Sv177d=VbQIWA8&}>+F0?%VUr|S1&FdF$62dn6w5HH+vWy( ziKx=6(+ZP~y8)Ub&+PdJn=$Fes@t^&r?$mre7q4))Z+;!e!3B#tHuPu#rVL`QeeI#Fd$(Do)58j~ya$(6s2IFp(~CZb9!CVepf{`}6e z6Pwr;o)`^{?_PK~l=|&YwijFS_@8<|?cJNcbl-O5)jjcVGkK0k< z6-P?`-zF{T)Xuw(G;`PgKzgAyHS_t;8nYK4%wD`NI0)(W7Fj5yrc2>R;o^1;KY}62*{q`IfO8k2sL}EA3{vo$B zTcN7p`CzrB7(*sVFI=I*ZkHfQ4{<=Q!2H8g8WJ8#! z3lqDm4Q0Nr%sa~YhO$ss793^qp|C`jk0cKSAj*6G(L+)bAQ^Oz77mb8SlL4K2M3*^ z9ewWI()kRadCI#LC;V{3%va0s6?B2j(R;-jBpvxZ;g5wbY>hoj^m_hx7%)49HJEmyxle&0j*$ByA>BIq!w7btt>Y)Nkbjm8Bf&2 zj4xZxXnz2Q zTNR<-^ z2G|40jsfLOfh4-Z7<~!8k5B~(eBtm%80q;HipdLK^uqy^hoGJU@4eMcrvmX$L+oS9K&UrFqopHu^ zr$>1w9dj1!9euXMOvQN;m$I`k8+3!-I4)gw5^>IWwwq2>ZIA`=czDLRsV}+?2w{S4 zY~d!h2+Z5`H7yuQC_t9&7>KCRoXK0RQm|{xiWCq!v zRb^1J*~?sb+?eLJw>Bj)TImOHDKo_+FafN{S~nL1!xO)e=Y6m zw{{?11?eB9&9Zb2Z2p4Xdl%$2=1ys~zGQX3ls_u%?&JcZgeW#HY_R!6G(>Z@B3*rN zX)ndvqMOp)gi3pbNBQVX8naaidB8|kfU`9>YKb-KLKclHd_^Dw?-dQm*FAuBiTZes zr6C(zXij8B(@j(~B?0T^>^{84#vJMs!wZe1gj)0-9A4wLLf81o!hk2ux3cN5{Y}yC zCgX^<<3zO2SX*!w3XkvxXq3n!3*96opKtCy{^*^T+dHR^Urwg^yF2fnbWV0pw>^a? zIoY1@rq*IMvBi#$C_I%Q1srIR?FilxZmC3`n&y+zE;hQ1n;nc5YqcLMSQ(PZ|4ih?#*oV5p NGA**Y}Tbl&WT*Y?_R(lluh2Mck)O%rm2s)C}V5voK~DFKA*XtnW-U6b|hdUmYX z%lOi8M*RUC>?0SB9QXnJ0#0+`z$u(LMM&^w+XTam-p9OoZ|2RL_g+}5RS1qBAG-bD zN`(Aw7R!U;;xR(@01+b$BGP_Mp-qdm)wURA*0-c>Gn+X`9R|mw5U$v$7Bz^!dtDQ zlf#y+onDlLQdhnRqcMLXL?U##C&I2ABypy_!#8XTG-N*zfiEZ5C*6mg&oeHv&L?a* zhzFSzp-e>QfM@42NjvFS@L>P=*%C9AFOs;BorT$;8}!C;aYZMQFYutGgGe)jYW)#6fuoWDTS8&>%Im(%n z1(^c_Mi+FBKe2$BD=p<-0<+I)cKDPqXKpLk#P20YXYMlOXQhHGA?L}G(LNxje|G4s zJgZ1=US^)EEaUVsQ9j;R!=0a4eR&7< zSJ3Vkqz&BtwXmz+*r9Y4rGFGQ%hENp`3raNoYPkzDuvbhlGWWp{!U?cE0@*6i;XeH zMtmxLS)13DZ{DrEC0|$Art&v2>t4brADt;{-cT3^Aj(G0Gu)^lSDBAdv@YS2Vgx>z z)i7Rb2zM0e!vd70YFlJMm2K!Inbj3V>-u~hf8*`~vlY;bky4nOt$Rn;(A4A_9$Z=Y zrTJDi9X7w2HM_|;V$C>_%`@JVf`_t2MGZ9y)S_3WUlV7;bPx8ivNaB9e_>rWt zneZkS5u{9N+eh6x|hrN$}Z!R{L!q z-)u|+b9lRYv*HnNd$%Sx+`hI3Y@+RpM6mYil5SAM=HJ86!@h>EwIB0~TR)1v@*0`0 zr{7)n9=l7N^|%NyvOzY8`wp3e&P%Zt(%y!M{qUVhcs$ZG%j|DKwUxH0?$do_ddl2^&ZJBa=~? L5qWsCN4?rV6;h;e diff --git a/basic_function/__pycache__/packaged_function.cpython-310.pyc b/basic_function/__pycache__/packaged_function.cpython-310.pyc deleted file mode 100644 index 457e7cec1b07deaa9daa78a128184375542910d6..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 3446 zcmcguOK;rP8Rct84(B0Rc5J70TP7{iPNPV55ICu78wLCb3Z$wLISo3tLD0N+BuX4| zbT6l}D|n6kNEqaWC5BITvwRh$yG(Y(1gH&K}xU2j4Ax@-Hx@v27|#naq4) zRgKbO7PFrl%DQbFnuliPs>Zgh4CcHrE3b668x!})Fb(A~V{D2~N^jyT@5rqDvZ?Gp z8myrlSXY_a8n9j&DC=;Vcaiu1)Mm{?>$7gvRLx0C`A2Zur*NEu_t1vhf(_%UA=md% zPVcYqQ5kDt^X>JQX5EsbIN{NzD8{42#3p8sg@lby()SIR-x|tz{;BvJ#qrMd#95^dUI;hRP`*(Y7L7T1@ zZ%v2#aE5@|mg8aVA0vVzyDw{3maG_;bz_i?<(43q)!jXjmOR3M(ICxu?Gz)P3zw<_ zU6rLdm%Ts`bZSd(*3FrlWZlAOqcOp-+^YSUu_%v+m@JOt{^CiqC?5|Wb5XZ)zBk)* z?a_?17`KeYCNH(er-r;$FPw2w6NUO*yXKPFHUrbg_a3GG=~#pfo06!UTLc!b`Z!{V57dn*qgyM&;c$9{r|j zGP??tTeXN4+pcnmBcE+o?bN6`Dp3B8A%0(as&i;62k8eYnC31Sy6%B$Plz=w)l@*I zY8<&$7rBn=UN=;WHQqNc_P^~ff1>9NR7d|?%za@!w5o+_Q7vG79gNkTw2urSRiMX0 zOVzz$?528c4e>kGK|7kWJLwRocdOi)u3Nq{&&S*c8~-rR#@*L7gS_>F&eLiRn_qtT zX)N+I-&_vMEiQSO$}r9ZkJ)}0=i%2N4hKcXxCql64bh08vmDBubQBJESSkb^9pK*@ z#!sTXLhNuM#dmdCza=<`kVku4JnwaDtB~Mqyu;BA;GSNIbz>ybysU$y$dj=U9BFSb z)&%OGzc*#xy`sqOJOS5nS8A;HT{(U{OiK-++S!ZKvcGWmXCH6ek8b_q=l4FkbN|lm z+64#484#PqdBU^aIYIY~_$Ex<&B4~_WSjNk+&vX#DJs5CZLd(<0HhaXMEvfz=i14B zJ5CZlDx;(r=cRa^8oSdXO~1a+XN!IZb$xfbui9Y+F8(e}K;WKZ(ZxAg;DY2L&N808 zhF0<>jAJ^cWnC~`O(@DOf=Ir31!X?{|8d!+KVHkXiE@_de{May@|E?>`MbD=A?}_M znKnv@aGGF$MdH6CmIFZEfe3+6phSb&kR+&O<*`NunNb#K)K@OZ<|~^3PZGks9ZOtO zz6QC!`=*wpCL}2+o6ILFs|G~ICef){sx=8zK-5O6MX-mI5!Jle2-A7wTf46@4>A!P zyb3}*Yvi_Jcygbug>gtn8a|j6U0pt@eRnxrSy_2B-*N7eq^`gM(9hB&E;*4!Qy?q` z;k4Z`LFNZIjjc8GF6XP~Af%zTcD%e(ljugF*3FQ{3(F#uJQm589*XGv@C3e6yoI9^ zq*$7dQvzC34LHwuE_t>OG$lLHkV6yM`oX1B{U1K9qNf$flvuzZkf|TgSUoa7B=a^I zn+#$sLoAUYJrWnmTq1KBrfwuzAvyY=P)%H=%J<2a* zp+hW$R64|Rl(TICy$G`?-~!4W;P*1*r9tw-oDDZUgXdw1XIIl1naW zsZGB+T;83Xx%0X+bI#oPwaaBk5dONGoBo>@p}&(4Etsp3M?Z$hb;Kf;N}wclj-p81 z0CB@PgBmv`jOUDyW=xpQnOW1j=$wT$vzBR!wX!s8nKHAscd2t$NVAjFDI@EEZ<=+n zPI%i`;{*!3ACl(6l;WC7<0Kcv`XqP0nv7F=~b(fBLGnnV*wvq-5DfwTx7osOZg235-f z()yW`r5ne&A7yQ#ZH2CvuE#{Wu^x3wVTv1Tpis)ugZ3icG;g?MJd5UG8JtBF{8-xr zV(l=6j>s)aaqx3YjEhd=bY?DAvl_H!g8x4FKROKJIF!4NM0BAeOX+7zr0Vn4*sIg% zI<&ad$s%Zpii8b{F~Ox2Gpvau4_}7A+W=X9m*(){ zh3+|=j&VF6jp56@z$E$y66qL|;1BoLbjXh4`#}5yev3#7j>5Bg;Inpzl#SmRsF!jgcQwlYIY9wsK#F|RWP^A;c zw9TOy`AkxwtDA+xWSuJRc^ntGC=`xP#S@%jO3!gAOcteLSF;lF6vv0{xP>Gc_-Vyf zv!mijXN0*7*^R=CLNhEIO)*JMF)|FJwA9q1sZ8>14l9lnH?LVnv62>{xMC3)Jk1F_ zF-)iyFj?-3r^%kU2U5v0<=bKLETVP$R@uIL)pid{_OXI}Otz0L4z0HbE4I6P8Qz89 zs0^}wRBj)YO#YIu?Rwt_eJjKHzJf0-`@)OIOIDv`-Ma4ECj0ttUMctv$-YB5>$=OE zKQH-`@Z6lKyzly|Z?^KjAA1n~BK4P41snzO6u z>{^~KICsd-9g=fL*$8>6f-1fts^1b(PunQ=|lBHH^D5+IJ z-XhhwFB`uAB?3LLw*_6xY!glq+(;QK7$RDSQJ%UmkOdOR+C@v&A?Prwj+QGHeIK;A zXu4x+Y#(sW8An-{Xcy^ORSiFGz)V-8zG$jy&4~8@q_;puuGM%_bZ(5K0-2X;5FMgT z1Z#=r8_@ULwf<+_kmC~F1E425Sn~pG0~oa)*Ej29;Jf#6_2edYROw#pqxKeE8*5*5 zXc~7cjchOpYa{E)wuqi#bYUxu$-Uy#boK8>CdFpaULCQW*x)s^v`w8Ul@)PNbU|$a zwKg3W*6VAm4cMSPWvS2f>SH*9{13d9r|C9lZtVPuorKEe&O$dP~0gXp=fQGN{f!uqlq|V>?FEfMAr^1h=nMjEybzR;}s_p zi*a*8G?vb!1RNyFmP%5EX=qOC^jsd_HS{bqqaBhzOj?Ezdaxr1fYX%2ovM9xt1`f`?%r!!~z+SM+*wynAL z62o zvwHeZv3tFLNc=HT*gy7Qvbg`OwEt|W>-o>4pG1Efz02M^^1vodyd_1a3tcmE*Gzt- zv}0%f*z(A=*GvA6<-u#C`QgV_)E@i})joEjEzirrzO~@qVsP(j@BIb&*jf2lTng?j z1TV_Li`T4qQ=Tbp?Jgse_XSlzd+lAIF4n5aQ%5X%toNSpt$TQ`8&(KsI^nGN7)Q^QjMjG zZ;08xC8D0TQA}mZXSM!Sw!!qM6UCDS{mb5d^gHh$W&C2-(BY%6nD4)0I(o=<|1bq| zq`{FOJTPCV<1=gNw4U1A#;{I2^Fdb?%IXHD$ri%F-f`dzIQg5-pq1 zCoLMAft|PstaX;)7Qz;huEI6F6xLJJ#TrSAwnjqD+6A)xs)409J4A?Pa67=#w~4jY zz}lL|;t(Ax8`G7&u}unybfrdBT7Bf*Oc)~AMkQQpePne@t$wt0m$ia>VPE*CoxsgO z0xQW-?U;Ko#Doa^h0be9d;4m+FZG4``};5G3+m3!q$uDfFa;ZrF#<=3MWu^SdMZ@W ztCESnf<_!Z2&DymwhnGpSUygYHBC9Z$p@QD6ESKWB#ASfq*laY(fkVYc7TO!iPhQDJuVjhag2|!?J27#K z@Gc_yiP%j9ab>WdhyW2}LEv^GI*8~5p_pTdG|xf#8b{#+By)&}SBW6>g0~X!8VH4o z;}H4SPDD2do*Yj!3L)X7`;t@&Rp`d_?x zyx>0|`wtZ8S7rKDNqzn+cWf;qx4BgnIm5cUTXyeT9ewbUZI0J6`exmxqg+O+;XDciJR@dOxt(Yq7<#`~Hv zMrKwE#33M4`{=Zata_)v?()Nd+3?I24P1k4f+9xN@C*W)MYw^g;+MC<4KX*P25Xu? zn_#OMYHm96W-a;)8rHg^b8)mm688wMnT$=y)^r&Ij+b>~2ST(o_YpL8=wENA8QI8K z!nTE5>Nd{sm2EC?fK>!&YPV;AF&L0&I{s4*&avdCriu)3Jp!n}v3hvXK#s3~Llq76 zpr+T>fl-4#skNG*hZ+@DPKB#zVj^9QECf!fXX4Nl0~Zyn)@=?>&_@1pa={7c(gZ2? zX82Gcr3ph+tRX;xhe3pC9gL7kQ~|959wG8s2Mw~CHPG1b+zCh3DCs&L1IaKXzh@!9qTk4J9iZ5UYYKd zwCC&Uy~SNY@(mY!!?JG}E-$=o`Se<#uNdgNdH(K}d#4J4F*z{yV5s07m%ZZvBD8yr z?kv)9&MeT~GTklF-H`Ua^F8NP_mVrO{xESnVGf8(QbBjS;%>YVtKc!&$trj&zPc;Y zU-d-r9+KNb#2^s_K2?LNPxm8y0U$8F-KEgawK_!8F)Q>98LFcK}Rktp|2 zs$&OfjSd)`VMS#OL7ib(f?*FX&<8)dT?eE$Adga%Od8`B-PV`93~rqvTb}kGN|fcy zZT7Gu=;yif-2U(X4>#?020{31rx^WPEkb`M2`w1Qk${fGe@IfEBr6KCX1ofK#0XwG!l$XUGT4GNnfmm&Fw^_*4Bp*cHeQ)8B> z-siYSUg@J0`&1Gqgj8roz_N%T!`;VCgrcG(Vj-2pA&C^n$HUqXNizrh_3+;uq{wiC zBu@>KIxV*KY(;?Qj;$(UzTwW8knE_-M{5MBIv_ZQoh(~qkWiPB56|B_cshxH}8{1pxijhTQ zPatF^Z${7Q;l0$jPp9bPqNEsMzY{W~bR7p2I+2Wv;dES-$>dp9IoUa#Oo*LnI3%4@ z5!^YG#$pJLN>yYoHJy~yymNva3x}rCQaBY$N}VvN^FmaNRB}3I<`l~@a>y>s;1`wF z8JrA@vK$KIIXNZ7JC4PZVIeNR*iorNHd=;}qkCupWxd0n%&i{zZP)UJYj0nDyRfSx zyQ|}qxqr)Knci*P3$s!TvM5I8ZJlxor^6{2tlX(kAw?IOpEq=bV^az>duq^U!K8^& z(^q=&kap}9y_`-cbXg5VOisJvn8mS_7=p&3saRZ849OW$!WL4*s73LZB+5Pp*N`lo z991lpRaC6WbZRC|^ei>4&_W~2*CYTw!{=~}k z%4@4{uDR9>YrSicwbM5aeHp&#`$EbN@LB$J_SIU%NqpNbXJ;=JZLWfC zch0tZIeOdn5}=|6Wd$652e2xQzNhx{Iy zkb=V1h#-rh!ss7xQb8^BfC-&M?|CeiIBK#2Bd+AN@*Yq&3#XSD z-XvLn8bXs5>}dTW#2YxnQ3@bpJAfp6V5P@}kN_C20+3yMCvQ+oPeUtar}Af@Tb@?y z^h3`~VD`$KLMhdgpg5aZLaQ_cufR!JEaXe;C z>dV8adhqtgwDD9`E=M`jQC%9|C^@AX-U$$FhPZ=wOg8;@UCqE~@D*Ukd8RyTCh&~~ zjkT0;2>bCV><0jqw?m(Gg*0){BvN4$z;#d?S zbF3h-TQFp&lJSU$*_Z?sp%TH8pOw$YX4u5}NDO0`3ULpt3E@I$Hi^%RSjPLxc~}IX zN}<_lQSvzyT~Y>fA)FVXnXH&f{)A$j!7(YNFyW*WPGc-W+&q<5L9EnGlz2apOvVQ; zfSCbDqr_2@oIaO`rBrw+hFKw&Qk;XsliaCL-^sCw(ScI~{fY@#o{j?z!-5nRR_c0#dL)sq#Q5DtqosZcnXmQomw z1X4_;tP)8W)F(uExBp{X0+7a{I62*MoV_fz{KW zhHte0F`nx_ar11x`%H27zH6bYp^rjqk(-umD4O3rof#}Xxqo?Z>15H}xZJxmoaw)B zMh)JbySUSR=DrQ>+E?)Ub6$VJdobrcxZ3{Zd|~KJZs<&5D3%+FWxWTp-Zz%a8AC?6 zyQdZUZ`l38rKL+BytDL9v2oA!maEof)3=Syzj1zVLCtL&4&?Uc>wI@xeaj{&zh=K` z|DE%DJ!;$kt-EQX9o046akW1%BF7kY%Rc&p1=i$W8%9c9Hhb3bYJm|aMymfLX_A^?~ zELft+hEls~snHwCVaBSZGH;E(G=G82e*yk~D7&@hpCat|TA7^8{5MPw(Y*v)alU$3 zfhCM!gId->f2H)0pY7=AcwK8$<;{_102V-LBo-D@A|bEJQ#Lupmg<#ht2N#ZsUB#V z(u%4$r1Gk#Jfv-%tCd*wmcWv}*i@30MFEGWnVu4xdl+IE&Q?GuV;?LmCLVxF_7ci< zC7jhoL`fFobHJkT`A|XxH$E~x->|Ly*^8G#7cUXV8Iwccvtd8!i})e9gNRNd3`Bs# zfH1M!s@IIkzT!qAh<#R!;doLOp?w9~_y{TNBjP9#gtr)cEreeHp-?f*l0*v;aL|$J zA8Je`&k|Rw@2$X9xTnOx@(n@*(!cyQhy}D!Q^nG(dvC%0bk6{U*UaZ9nI zhR0zI`p~J6GHu<}IrC)2A=RdUP{76OD|o5g%E?!Oo!npEv^5H-Q9;*Me_MEo8hn=d z_ZYa`s=AaC0RlveE1y9?fib9gH1?JVUKF{cm(@>@BA|L=MAKi59CDAAQ%wpjuS&}+ z^++#}(DYZ;)l?3;0JgSqbo|2RQ%iq6La9>swq160 zRr^*oj-6X8XB*9i1rkDPZ$=8AfCZm`73a>+weG6q)YWExZ z-|0Kt=8YGu?{bH`TL$-dle>ck_j&7u(f3f=rt*QsJG_gMCcnfN_##q1w;r;UrPuf- zTQTa+AeY0a3`dbHghavet%vu+Es+TsmAMQvs!lgzn(XYMWp(i0!dv_arer&YGL^xN zmuA%{O>T1Q1ykm2c4QnGm7^LvmSWs~X;f}$?=&XPF*BHQIooZ@KbG#qQ|_@*d1X^s ze`dU)Y*(T^p$xu%{$0@e`)jPk@?wT)l|(%OL@oGwolkk*l*Ya_9^53L&jTJ zW_#n6QMZaw6pL_6=HpRpV7jx%M8d*b(EA3AVow$;qh!k1!Mq<~j(a9&7{}OHF>5Oo zS>3>{3=8DU-j8cp!jkW|oE$NL8>gEkPf@dRnvqh7$VQgly!|#k_9*d#~V|h@N zQQEta=5dr3H+!cYaE5jv1V3!A6{Q@<5DWOr;Oy5DD~b?4eeIFaAA|0C@hv?n9((sjJDw$9^UK5vSAzYEtJ3JWQdK|9!$+* z7k{(o5)yR0ns*1SlvTOPth~eJl1)JU1neKky1u-~%G{1mL;W>5Z6 zHMv#!%BfnUf*nWMq>0b|TD22ab(F8XT_(RMUDY`u?gU)*z=+P+hapAM5!BT?J?QBw97?*8Jz!_Dy4uYYy_mv=Vr+^!v{X`F(+ILcy? zu3VCDVqWqaFm*SB?!wb$){ApDPnagC_&&A0Ms0n_SC%0uwcnm=SM=La9E(vI#`!oa z<=3gPGcD3W>U&}~=^vr4?@ZTK+dRjm-K7r@taEy}GIym5NmqT8rXu|sS{3iXFv!3* zO~}E4Ks=pWg!Rm8D1%h2H(>EIJ&XQ*L%xq*ch5P6HcCh|S|oo%;wy;525v9lI3NmO z-QX772s~Wn@xDmf>a*rgb8gL92 zIf|;KS`%OSBvGVVC+>&j;+~DL*W)fOd;jpYAK+Q1jT#20chPzf1$2nP?`B2suAbJu zw;J?%y(jY>=Po$viqi-0Jc*-HkUTW>K|Tnk?M@^yKfrmK*3rA1ubzWO!?k`gy>lPY zjRNia0vazU^PmusjJNerB;d!Vz!vg*I5aO;6vS5a`}Xl~s}lepzJ7;Q!$;)>Kqh+spXYKbsIb9EZL z0VxUawvs;xT)q?poq!N!91sXTPt%O);{T5M9LJ~M{wfTg=#&;Tr;(|bSPyg{LK?$P zfJ?n(P59=@3D`F6+{WRR2imU$g9xz<4`$&HybuDQ;b9w~U!dE+NV2B;0?rzHVp@I( zv*OJWO)<6B)|BJ_&F_Uv9Cv+4^ZpoP6r?5A>}q0tv#WWPAE69Od>emL9k=T4`OtH! zJc}BTO#4e`-m-V%EafWIUMBMv8M^ni3iR-E8q;di(U*5)m~+I$H_3x52s9gsL+%`M$REMEPe^d_87YUHc&}z>+`9*)Tc5g!~IL^~Z+HFX5F#FpMyok`WzHifIGV#=y{NGc^Y$l$ohDu$jf| z=Vah8hq>F7d91Sy&m|=Cg5#uBm4=$$-?;UvMXJ87lOO>H0qc4ooE(~Tc z^Epw*U2;T^Xz3_>*HnaAFKFp1myrn-e=V$;t=wZ;dWECRKNDsv3#_B4wgjwOxXL)1 z=N-s%94f`{a zkTd$%DsPmhJf5Wdb0TnizJ5Pm-}(@CRBru>jgl;ph2GDtyIk%S`FLx5SnT9k2%Ek2 zB$7!S4kuY$Bzd+4Q{0QTIa`+W$A?wp79PpdvAB(+{|8<&{lQq|F_$uo#i1;sw0|qj z<0zH4`^&~~o|fdCObaG zl#@~Ao$$Z~ZdBDd5J|y9=o}7{lvh?h=9zG?%GXtCl5yGdMF)!vxm`6D+as%HJ}Jf% z1g_Yryoj+di$=UMqbM45m(9X#GJ3*A)y(+8Vttj1UxepH53r3>|>|wK(4Jiw}@&o`yEm?_y3?r z_f=c}H<|Oo*fPqMva42L7MGw$XVy9<;vdS_J)ouP{G9B6sb^q|N2(3&TDo2P7>WP) z{#W~t>iXKhd7XSpWvi~R{hL}B?#8tJtgKP?^?M&hB1^LED?zctB@YrAM5*8rI}D;M zc!gsy%u~iikYvyh8X*{00=burgV7#Kgunv=kej3EX?T!}JrI@nzAhUy1qYr>mW*s!OH#Le6aa2yz{&FAN=;-!+Uot2h=o4fyX$?VxIQS3A{VRx4~4M z40IPRmsxddcTbq&eegqUdkNe6AYW00sMJBL)=myuQ5^Ge5ytr>E5t=??97X_xCS0y zO!_BKH*n_bsw|en1-^nGK!j@=xlp^_xxi~bN>iSmhgR}!Fa*S583tXW4v5Cnxh`1G zxC~_=AoDe_@H0FU|9vWMK(G6!97GFBP)b@lUqIp;NX7yPpMw(tsQ~OYGr^@e%qm@G zm!1LvfQY;YGQx{tD{nloSd zC|pQ25p{4%C?faO2>X5J!gcRYU-|(&>$p*)V0k~S2T_2B82o-wbp6V*_U$V{zu$jc z?^wI^uq&KC@XnGrDmcnRLm%YBVBYRT67>O2)3gq~i+Xhp8V%R_$@ETrL^lew_X}{m zpvVKsBN6ZDp-{kEOJJpV0}f7DXi$N3$r*py1mLEVD18SSN}M8bU*X?D0RZMG1&r?QO+PBzKxfM#7MBEBQVeMnit+!z5HL> C3}V0l diff --git a/basic_function/__pycache__/unit_cell_parser.cpython-310.pyc b/basic_function/__pycache__/unit_cell_parser.cpython-310.pyc deleted file mode 100644 index d82f489d384a12c0fd3469c7ff8f2468dffb8cd1..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6600 zcmdT|OOF)S6|QGLsvq-UfWbJp5-B!KhCx7{gy7(Dh;br>0`Z~^X>M2D>8_!=s(q_^ zU`ExOfRt>qPZULBHj&uo4@64-MM+s@5s{T8@4N{4&aLY1nx-+f;z&ui>ORhWpL@=C z&OKqHVJZ0i`rvncW-H3ysZ#t=>ksio@2QI7Ds9D8UCmbqT3Z|FZGB+04V7rRZ?;P^ zFAvmqMOAoZr^-t^)(NFu<2u)P{e+UG+Zt#MZlETqpw}d=DQVO%t1ndz_chn#4Y$OX z-7;@(m)we5eQdT@K(}}^r#C>axpmN2<%kAfDzwTqwkvMaT|$o|(;l;VCq0(k74%q> zJysXaal~CikE7EbFV1n)J%%2~+~YV)13Si9vGe-2;jX(U9vkiBu6|2to!tMMbyI}9 zp35V9dlWda7X~PW!gj)7my6g7dUlt`4>%8O*W2Faf(NlJ_)zc&C6)~cw{2!QVlRrB z?}t6XhJDYmM}ZgH4)=Zgv`mLgumO*`h-?Lise{&q9)3LI@S_x zM|0KtTD1JW;%bl7UxM!HCzP>{+(2$5`W=02j!WY*M(K%hprT$GSH~9WMp8+tiPh63 zwKlGg8=#trg}Ixqnbh3U5oKLT>Ujrk&x%`@_+Kczup=wlldaNYLhRFI}0`StkxVw8}9o5uHY2 zEL-ZTTG1-1r7f$+G)p~WYR{KV3zWM2udB;w7t~pN9myd)TmoT=hkrx`NHf7gAmai& z)PQ&T*ch9!vZDj>fOX?iUIrQxE=szIas?<=PAWZpTy=HVc&v=A#7Ycc7^shoaSb&c zwUVsWWzC>5Xv>gO;uqoRlhT1QZUS9l{eV;^TPY&dob|DRR8l`?;PrAKA$Agwkj_aV zJnMV+xgB=tqF&UyxuyO)`v1j2E!&4^;0ixIr15mT1sZ=*YP@5h|Kvw9(ymIx>0pI+g-}P9suFpqC{9pM0*}Jc zGu1Di?Z@a4y^h4P45+@P)iqLdD7@Y@n`#prfx3SO9nvcKTOIPO95MxyA0Z_%PdY(2 zQJ5;dE=wv@*GM#RDlu|*=;+iP(m~PTAe{omAO}eb135z4*i6hPhIAxM_LHJ|7JR#+ z>jy2Jxp~Eg5k#SP$igjD-|6Y3`kv)JACNzr8BUQYGg;8Z%NI14JV#>j1O^5y7T!Z? zd=I;|g)!t}F0l3mOP{$4ue9|OJAZL8GYX0=TDYsHsx-g*Z{=ov%7mhlU$k>(M4K&h zPT<8E%#w+BNqm>eWEC>=S)b0<35McZpyx8=z#X3#%~P|^P&?|;>F#Bwi-1Bi)&}6Yx)^ zgSx=S{9JMeX(P0NBPaa_x;ucolp4A+*_g@C9BzGva1dvHg~6ZD#n&56Ra~K|E+IK2 z7iAEpFm(r&{|u(2S3({NJQuvl;CBGBN=b=43+YiDf~>{B1jJelCU!x}NciytemP{3 zs=PwP7A{|qi~kYCTFi`j$P(YftN)hWR zWjP{B7L^c#s^xh-%UvU}&@Yce9mg2Fz-YhK*mAX`L2)SBS&SGK?e6!8TKz#M2e4IP z*U2VNJoGH$kv)#dg`!}^GfCUYeu9a4&nA)@6A!>zL}nFiy)i-n}YgP-V%YZ+|nb=ag&2GVw0vXg9ZT z{+ob>EcBh$sQy7_@Y}VmSzH@&r7-j}ETXrVbbB*LhoDJN=jh;DBZr&ox!cKF#{q`g!fm0tGE4P_@?=&{h}FPQFMxweoMF CCi4FP diff --git a/basic_function/__pycache__/unit_cell_parser.cpython-311.pyc b/basic_function/__pycache__/unit_cell_parser.cpython-311.pyc deleted file mode 100644 index fcff2bcbad0ad0c1a7eac96c8cfbb9ef049a3c52..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 8999 zcmd5>ZA@F&8NS!w#x^#F5RyQEq#-d$C?w4XN!2I`p>aY8C-)FCD;H`tV7T$Vz>)>sGw;tX`cpKntg10eXexJE+3a~wl%kl(Oc6a*osxfJU zpB8@K#dKJW!1Asbi-PPkQiu!8rx3qeNXIpTW~87gz>d;*wCoZS(0mGo@8(r7zhDI> zpxseO%j0paN~SB~;yRT~FO%s(Cc`Rp8dNeFE^buG*L2GLk*$G#8j-o9}I^ZCODqsDFBTRkE8eGJ_IgPMiZIH=sBL}rZf8JFgFp*XeL4# zt$#e4(MCg&j4nFK$1;ZK7BaFL;K(8q-*hPg~xHZc&4ae2WIH#Wv)%sdy!VN4jM2%4Fb( zPRZVo>Rhuoi`oA}=ZU>pqVdP}nqPh3jNCF>67ygq3mj=d>Aakvuhb$i9^)Z7fQH>RM@{fUd45RaWZlV_8YVnt)x zx$};9;mwqr)U`-;EyUTfWcuRB%JBy-a=J@8-9=ivRyncaY94oy)E$=U4io2L%;6Vq zrb48?RjP0O{K);|%LmBO^U~4tUk!XYxO$Cr_DP+6<$y>7-|?;7)o)Ajei`tH@I|-w49Mz&XDFaU+wv_X|;x&^GfHuq}lthw@>Q5A@&DJ z@1WE>NSX&p)8NCNcDcHdjGdtc7%2Plfm#?hTZR4KX~|F`{&;KZ!TXX z$1X_6E|7zr(!ox^?SGwJcVGT`$HSgpsppz_-A{S~Qcr-mgOWRl@>=6#I^)n_Q7fQLTmH?H{z%o!_Wc3v>K$JjRTus8{b?57o z*CCI~>r>OE)x%*8N-_~ zz>ROl!0~>{)dJ(d(!!w#7aq#!BIEoR4_#wi4!#z#6&N_NfC(!Z=aHc0VnEPxF+UGm zXoj!CU>gQ#Ecl%m?1BJzA|-A*O33t-jPVW-L@|)FFpTqcSR_lNP(KnW%0pWqA7#mu3$-b3p?=Y|boSn%rOA6QFSkGNth~C)uAC8TE)#o~WbXnR^a9ao z$BtyHSfD37b4yf$vZ+$jF26p#2%LH;Vesi{CqGK}x*zm>Fuvqc((!S-YxYL#@pmBUXEK+~L^g#m6 zH>GixVC4dh)9{@NGg)u|0;_TkT0#4)Cq1!(Q9-hx)DzIk&WLndoDs>~t-varxlBS& z0N^yHNCf~ZOUcVog9_aSYG730WvQ+3^yEoar>Zrw+Me1P6*fmQWc7n?SxVlMeCy*# z-AcDcR{E)}QDL9jnl11^?10P&2H3wF1tWC~d?So8zXt^wc)!9>d=NG$TtAx z@f-jRJAWKX9D!f-9DD$|ov~K`>W|qNuiKq-?H}Hry`8F9v+ohJ|1X^;_B|4fKekoR zP0scw-C*sz7AqF65Zl4U&@y{3ynOA!a=+6`qzT^&X719lHM5+ zN5j%PW8}&hv5h4<(?2ym72A{5scO-20Je{&?Dlc9-9d{$?aoDre5^Ph6PTFHLJ6#( zxyLH07|j99-2x6yjI}+?XPRyoONoI?L#L3QaKM^sgG#7d&$5PNIJ0Q^a8W#UK~N9h z4(M}O><|PAX5G=UIQ{U{E;Qo!X|dgw}NKA34)AnA{4oqF^qAsv8iyz5E~DNqA{15$JM2#sKziCi-r8b zj3pQGVOcA+Le%=ek!3c8Ce9a`pi4H2q!`yIedG(3mI2E}(F)6ji}ZE)MgIiBi?Lka zlsjgJ=cY+{qg39QczxYjHShYk<<~8#?ZoMloUX*>^~!DYCzEbc*(6mqCAu)#p6p*d zO)5`Fl_wHiX-DlN$KEx^UgFp%Irf$QR&X^@GtmVRm6_PWFu>VO!6cxA7qnZL3Z0?w)i{)@zfVQGBI@*b& zU2?Po$3*L~>{^994{>-ThetUAh3^5%7Yrp{Jt;>(-q#`(QxmL|Y(zj`?ooafDoU2S zw&UR%F3v{Ylz)RStWHgq2J(&=F8bV~?hCZW9>Yyw?g^;&*>I`IDlPzJ=Q(gSsg%jc z6SxnhOh(Xv_9OTr69vR~xI)q?^M_Z9GU%2HD zPT&;|J?h}CSjL*WpbA9>$0>=)uv~^6$gqCC8sz+T3~De?ooyJht%;%*YQAzo){!_i zwfYL+Zon_PKuPbqvFxKGvnS`=#Ijqm?0#fvTC+3}%K^!9fEZnp(Y5GVGq#AvmULC! z2j&mV-M0 zRxI}~17oidYc7&zrrSo;aH>oseG~{co?7Y~^zmANJ09KD?GZ z`+LnNhSW4^Xp$9)fr+%crFQqKhqQN-+U|#*9?8=uzQqyGfaDnv z2ZxAfh|~@NTVJ=9&vwqW!#-ZJ)=_=8NJFdC&`PY?yCxSo+94h7Al44JjR~`zct=y) z$)+#80qI@cm@L1-COS6^jMbTVgVJdi+vzfA^pUBtiRtX7iTW~JM_KNRl$E+R9vu>o zv-=rCi_!(~z3J7l@xW9#_!{p85&E3y?;zOFvMifsoMQfe!%)p`+hCNz5hvy;gJv7% fDT4+L<|%_R3+5?XdnIq|6d2Y diff --git a/basic_function/__pycache__/unit_cell_parser.cpython-313.pyc b/basic_function/__pycache__/unit_cell_parser.cpython-313.pyc deleted file mode 100644 index ba4c6646677fb1a3416be0b56394ecc4ea066a6e..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 8352 zcmd5>eQZU1 zE>%4>LGvm|sYyx=DGfQw6Fb6Y5e+Dl6er(#r2uu%HB=og15YE(~A>nWFNZ@$<^ zGn7lM1x})zTIMyF%V}g@OL#5tq$Dq|lX*#AFY}VTLD6zuX`3jD(RF}Xm#)JEbr~zj z>J_q86=V$xSremn!C#k}(I~VGD_aBOXex{&>KDnAl~SxzDo^UrR;;6{sE*3=#)ACn zEmxE*>SqV^O`%b#r!ZG!bd)VO7mo~bEwEBtN}2K_F82;?Dq2xqx@hbwN*DJM+N98^ zF4tm3dHMR~))wZZWbD43Rk>Hs)T-QjrByjyJld<6^DfvyoS89ir5yTxhw2eJqDSbl z&aE5d!`xU?fQ!UqOk8S~R7;-8ulGCz0&n9x(=r|kk_8bA{B1 zobu*OEH@VPdDviNoXu59%!n_>77E7LP>9VLd3H1}W56VlAes4c(u6ZUZ!C&lx=H!v z`n$%%AzxQK>f^ioqkLB+&ig!vU9T$=iP(6UBY7iiG~o64yUH$$a3gi$pNX_ESv-n>*>cX|BMcq$>*I(}`+TI(0BO@g&4-6LAt zGls*N9fw7r*7k?ix>U{d`P8B5x0A=RPiO2KQ-@}Eq^4%OlER%WZdj^F)gF@rr1NYAs-Rylgn;DrD-9ce7pS676`f2Mzo%pg-Xm>tf1{WFkea4+R?GqWl z(C!yn{SN|@ivj+AfX7kyU-GAh#oBF8)bs@H_>!WxIhM9}e9--V_bun`lOLU&KPK)y zDsJ!jQcZ7rgpMFBojbk(zSu;*AO#~Mw zBKIdE+1f@hK1nx!saCCNeyr8lTOMn5HCvLskE<#3x{PM+6IcsW5qN-OT>=lIu&w`p z;6dRop8`vi;6jEH1S-fJ0SaB}f@9*;tFWqoATr=6Se433ttyZHr?r*nNG%4d5?8P) zaivxf4k_T1OWOg*hMbpg3FZNYTh7nhD&zzF0SRsyN=X(Awj?VnZRK-d!lA7twh)F% zwG`Wu@>bf)x2UYG3J|VEd8n2>G>4Av&5yWZX3HEnPIRs%AEPO(i*KQ;E$AY&K!{o3 zD?!aw^l9&R=##wGpJ|^eK;l*O>G>b1Px2hEm?smX3<}DCE1Xpq2@C+5C`aJyWi5hy=GoPRH_If`dlGfPjJljt!3GG+dYu@z71mVdEQ6wg!nE6%1HOm`CuF zPXYMJr@TBwv=qM<$vPw(kkljD2m~+~0Yw@DcC-ZO@HQxjB2gk7O!J#iMS>=1AE61s zc%6`sWzkkeo4VOG-+gc3&PxmQor9UW zFva&KoV9I8HP4z;1GCojF2S}notUe6@BG}M_uigw7PcJ_YzO88(30T7t~=q( zfK&L1Q?NNRgQo=RDL_;|m@6TldDr z7fw8I{&dml7M$)y=bM7_%?uk5ofEbmxdRxoc$s|E9x z&Z`QOHfDzvaOVNul zN#R4-Om0=)hd=~?k*Y0U!X7S)2FpVNT0St;jQX40sy z(D@^qa3sQpn-O#|P(uDo<-|#1E;zn?DfwSqN^Wk7?GhVh?aC#2m{M%FsFXv=UP;1LWOYO|m^yySDxXQj8HCIjumhCqJ z^YqQ&{P4Z0JEs?#?|6jHKEcwrkjM<5`pfyu8)KRAfbhmdW-=(e5fV;>1WPE{ll`&T zsac<{&Dh!?x-m%6&B{%T<4-nmJ0#wv+<2!zK8AA3n|;XcqPyYp6Q~-lG1XhCo2q^X zO-@UPip&TJq+yBd0%da6Nk=Je#V9%ooa2h~OOCy6kbW7}@a#)cOSaA%i3cwpU+5KH z^kya_f^BLQF6sP_)+H&hQ6KCy zuFY5A7kwAVGx0xPRNF3&Ul^ZBh}F%>SC;H+FFR%(mpf-W)9XdMBYAvj&AQ9`FYQnD zifdYveaLlR>Q0>&*SwJI%i04TX^Rt|kJ~>r&QH|GT^O1Ph?b^BOPgS66D{qc zar=!a(YSMI?Q?KdzINxd@hcr=t;^IM6RUeOhTeyUnwjQl^UT1sHF+rebVhmY|9w0N zSgNNPXvA>p&>vNQ()~e~PKZsr<_3hOeL};&PwaPVKd!xJ5W9O9288Z@p`ri5fNODJ z@czJHW{4FBMl${}VPH&X7z1#-WUjv0bD?LZTQqOV8#^mD?G(&~cLImF=ZIiF0$1)q zx(9FD8+z#MtFHmSu^-v$6LhlYv5qp^ldlo*>7aWY#+;Umhaw3no*+jHE))qzBK#vK z4UgvokJERPBC6618(s>9N8>@?Pk1L3;gJyC4&3Xl?yv))y_D1b3ZHU^^zWLdlJ|n-YyJTw_$N?HoH!s~P|i6d{J!d*nV#_&CnAwxcDo*5y{o?O zSHJq+ZZsSXzuz4He&1!9_7BPweqFSZzxund#f@ zl8Va%yb~@MlFstcoVL zKQ?)Zmp`)GYy7NeLe}6F$g13dY@MGID@v;Rvv z=jYn2p=6z(K*=-w`MEY5D0zmTM9E3auOXT=S5fDuJ~sGie&!>qeTtiJX)Sj2cju-I z4?Hd+wl@lTu@?qNg_8Bc;6TXO3;L`p;tz!g829$}gcL!{q!>yOA?0Qb!ae3XJ$V?# zuJ4C^=??e39vcN-%zDE2*##92UFi-)EM&yofHU8XW3MOJf#}7dj9N;Fp`BQ|L4?GB z26y9dfWCTRD7hE7F(x{h&z^MEgk4`Xz2-`kdTziXOhgQ#E$0!bMaxXBpL2YEZGB-O;} zn@U<6*T)S=ONoQNp+!lJmriJ>wWOZMjV$Knq_zPmzQ%@jN*k}BRXMJ3Gg-NBpzSL1 zRbD}AEwPg|UcIU%C8~?GqtZ>4t|47k{jBl2O4pJ))jO@Bym^-DH_=BrAq_V@=__oP#iR(N;<4)mq-F2E z>u$2{ChNVw$=-GS;XY+WjIi$x2JZXsZ%#X+&hp=f`!EOQi=ZFxW7TN1?#)<|(QGgDe3b*aoSw>PXaUvf5`hsGmP`bQ>$qt=^h@Y~Z} zyUhk8biMB$heN9#o^;HHy{LqjD*D7cks=5l0u92LY zPSPYmn%Ym~DqWP5pjAtaV3^vjlfKcDVFsit zsYSYy+Dcb|Bw^GcKta^aqCV#61)YA{JP7^KKy(0tgHH58yH?2OXhg^JP?IlWCZAp3 z-VX<2djzG}-s5unIuOvmC8BtHco^@8LFd|AH@CZPDnak zcH*!@Bb-F{(FFwBs--*nx>3<9hNG_hCn{OFSU^Hx;JB2hfs%xmN^T=v0eY2_O5Yq; zxyh}Mw6T*oi3Mzf^u9H&A!j02Qn|XyS=0t)8KNp&Bb7o3%hqv@VGrL4E4?6!dC~oY#y?KBm8U*321ps zkYGn4LC-?{$&d7a6yz;@O3yM~Nc3aemx;NdoyENfnSswtEMjQKI2oLVc_W98Pya0I zXx*}86H#jNFiuOZ?2$EFJ>MI)?9@1T8QeRN%7?>1#2`y$tN@I zrxr~rEdzos6Y!r#S!-oxKxut8R4{DY7_;MMQ8$ZvZ5!xALW&@ys3*_RNS6_`s|tQP z7@=Ju2vVu4f+0%EXQ?J3Nn!mN78DowQ&fn~Aup@_Q(6 zS!tCBS%;|CY$X{K(8V$N5suQ9D#~;di76DADy74qt;CR56DucA7pLUOgv11$x(bO! zbP6p%p+K9Av7OitY(+b4_h&_;EpWjiRfO%%XkoEn1O@5O-G@t!db(I;)R%-W2E>1} zYgT3^PWSEd;S2UmJd^f(0uBQ=mfk&eRnODlLK`BaOSnOS7i7fZsmx%)3U-!nM#0?W zMV!?`U0v2kFXkj*x(P+o-(mS?L|d)WGBc1bp{q>kOY-}ao?ad8`n1bVFqE&LFxRUd zIQ*i*o}X2QMN*AU_b|KRXh(gtjtC*NhNC+M^jBZE>$?0QB)|?e^7L+hTvJI)r~~jPfa@y@$+y2Wg7e5XS-c!D|o91@Kl%O2kbR zCz*q{_B<3p$j;9Cbc{xpo6hX4sXX(^BPCt$=s5ra8R(9Y&6bY!vX$r2AUF z2H8ZBzZt|^ss}`n3p-y>y)T1!w9c0*s%PblIjX+q;r+3j_7qe`-z&2BR&ktR=b=+>G&=@`! zxOO5*ep`8hf@dfomq+T%7Sw9WU5M~nr0a{8K`5;rSL)q#mKxEsNIplwc?zDVfKl)* z3cgLjcM!Czxg!S-rUf`(ojo_RI)6p|h#rFga6Gjx;;X;~1J{f7#utXUVan?$Yw6b) z|4*moPKSrRP6uwQ!y%ZKB#)gmDIjWs_wQ&hJj~w9b(P-ADyC*IjMMUsAMYxUSy}Da zE8k7)x#T+SpnL;;w42wk9XEj*nX5jnQU0BbM6_$yW`1~hy~5DXs2M#D6+csKO5H^A zb&}Ci3Ol0qMm@tTR|g><`QlYtLi7#-tGr>t3M3MbS+8Up0xXu3#Ys>gvMl1{FxfEf?W&z^yQh0n z-LvtodM-E;;=mn5kh~WpU-J(j!9Rgo;=}>@f^yCw;rCVd%=AviI1z~iv)lFf>Rt7H zzxvhZwpxyc-*1k7zvnVd`v+ypzX~$1;ET`cn#Q%R#&vE4`oQQK1G8%mtgfY#%na;q zMa9*D-mU4HsNJoL%3bG-)@=w=7@~Pb%fnp*vX-!rQ&PwqO4e2~s+Z@N__LxVmPMP} zADO(us~_6k6@FHyOBjx4|qFv@Xo;-*X zHwdD>bccJs$3~%_Fi!*lyP(3ME8T%egp8RRau&Er;(LPa3onUe+)+Xd?Ih9-VqtkwHm;8yEf% zDrsZf9Je5~QwM!Ri_!+KoX}2dX|sr1dCaS6V+~S#tu^hGHeN!jN>b%!y7a(6+hycy zyoS_DYNso_eoad&R2OMSrQ0gqK)R~>S>a8UZlq1BcUnVv`z+OODf#j=-$s7rVf{#h z4lTKV!Is>})g{i)yL8e?G=I<=j>GV~y(i{M*d6x{^PwDp-wy?6ANa{0OHl3pCHKA` zZ>T1@wy--pcUa&<8T?T7rm7uW8U}7CiUGD{Ki*V@G~D#0H`x|TND)ZI6W0$(%iew0 zU1vM%%zJO0z3T?UJ<5s%Vc#7L-1pvFpLRr@6~7PmU=A!0VL#c!w2}5hwiWi{L`H*L z$=$FYV3tK07kw!N6m!ypQo2p{#xA=pM`E41aSWr2)vEn)I7(Q&=MK>?Ha3zHmQpP4 z7fMY_>)T>i?_b51#0i5nQXJ3Bl$J<4q|C6eH>|T*az33$#$!_Y$0mZq&Kqj*JJVge z!v-UCz31+WVoDdTth1|~4T;0o9#im7eE5CxvaBQOs|o zgt(GKv^SkqNfs^Xs^~UU`(6$J)#5laCqKF#dd+k<(|0o6!-%9CnMJyh*-AHn98ugO z070~qM}5r64}1Nry&namf#?DB2EF+GZljbh(1(uWktScjOg??n+uVK|8W3-9aiAY2 zDuI7;`&AL&OQPZS@F3ZX!rqOyZf)ZgT5GZS15@ zY5~z8ePE3n$eGAhRIaIV7PUcHj-v|A2wfjmja4d>*oeBmDH3{nPRm*#A{!|DJ{V zlOO2;Dac#+l%D0fkm^UcB~x=vJB!;8G6Ns(LE>dcI2oLVd9whFPyQ^MXx*`78&PKR zD9I|W^vI?yFYt#QJ2Un#x&A% zq$#So>ubj0ca^EoLdjI<%r~wD5t_u;H|AS*s%mPI6kje==am4=B*&YZ0X{ZEt!y*;7`{ev4}q@?I50HjP2BZ zXepjyyFV}6Yl-}oxgpGU#skX*Bj`r|*af)Io~Mgd_IybMVnBp9zgJ~$-E_e&9=>GC zL@a5~C%7=3JaC;z}M$>_57?d?2u~o zb`J7;jW#nt>zHssYdN}OKy~$1yQ#}fNPrk>WWu1*RaI6}QqbvD^(-<8sC)|Y{##>g zCb+b4l~9hJMuvO$tEyW`#h!9&Ojr4JXc>*Qskrgnh4pk+`()i43P~@pnC*j&bM?$) ziXai=#>aA)8sH5c^gSy{guehrHEe}r|Vy(;mLLP38_XIMbp;`4km)GbsQ+L({g=?v;#Rznnz zOg{vg6RLUi+@AdCCZ(sfB5nUdVhVG5s^>IKUQdX?5oxzm(DgG^r-Bq5TnORbIx_I5 zKLCkWZR~r7{9CGT8)Mw5DLwV@(4*(w!zw-Wkgt;K3;7y7;nI5H_;SW@ec?%9ozPO7 zLO7c$c?_49Ln%1H8b#d9V{WAm>J^T#Q)q)nAI-O(INV5EK{#_ zn7jtyt-Ei~9Y1*(7dH>zp`(O0lcnku$!i`Z|v6U$|s%W2(m|u3F)$_C`B!pH7`plUv&L zV^Q)Rz2gQW@tTy8RQ!`xo;#kf4Y?DDR;j#mVU)1k_k;Pf3;8QhtwTK>a|e&|>?qew z4Qh3G9=k*B6mNi&7~bC!sbz;N%H>=dXyaz5ImaBCQ9q_(0N{9FUBp*|&jr2~>nEQZ=9(#ALs>_Eeg6M+R_*n8 zg+99N8`zFpK#klZpEW4|PR<{?jT None: - """Prints all attributes of the atom to the console.""" - for key, value in self.__dict__.items(): - print(f"{key}: {value}") - - def check(self) -> None: - """ - Performs basic sanity checks on the atom's attributes. - - Raises: - AssertionError: If element is not defined, or if neither cartesian - nor fractional coordinates are provided. - """ - assert self.element != "unknown", "Atom must have an element type." - has_cart = isinstance(self.cart_xyz, (np.ndarray, list)) - has_frac = isinstance(self.frac_xyz, (np.ndarray, list)) - assert has_cart or has_frac, "Atom needs either cart_xyz or frac_xyz." - - -class Crystal: - """ - Represents a periodic crystal structure. - - Contains lattice information (cell parameters or vectors) and a list of atoms - that constitute the structure within the unit cell. - """ - - def __init__(self, **kwargs: Any): - """ - Initializes a Crystal object. - - The constructor requires lattice and atom information. It will automatically - calculate derived properties like volume and density. - - Args: - **kwargs: Keyword arguments to set crystal attributes. - Required: 'atoms' and one of 'cell_vect' or 'cell_para'. - - cell_vect (list): 3x3 list or array of lattice vectors. - - cell_para (list): [[a, b, c], [alpha, beta, gamma]]. - - atoms (List[Atom]): A list of Atom objects. - Optional: 'energy', 'comment', 'system_name', 'space_group', 'SYMM', etc. - """ - self.cell_vect: Union[str, np.ndarray] = kwargs.get("cell_vect", "unknown") - self.cell_para: Union[str, list] = kwargs.get("cell_para", "unknown") - self.atoms: Union[str, List[Atom]] = kwargs.get("atoms", "unknown") - self.energy: Union[str, float] = kwargs.get("energy", "unknown") - self.comment: Any = kwargs.get("comment", "unknown") - self.descriptor: Any = kwargs.get("descriptor", "unknown") - self.molecule_number: Union[str, int] = kwargs.get("molecule_number", "unknown") - self.system_name: str = kwargs.get("system_name", "unknown") - self.virial: Any = kwargs.get("virial", "unknown") - self.SYMM: list = kwargs.get("SYMM", ["x,y,z"]) - self.space_group: int = kwargs.get("space_group", 1) - self.other_properties: dict = {} - - # This method completes the initialization. - self.lattice_and_atom_complete() - - def lattice_and_atom_complete(self) -> None: - """ - Completes the initialization by ensuring consistency between cell representations, - atom coordinates, and calculating derived properties. - """ - # --- 1. Finalize Lattice Representation --- - has_vect = isinstance(self.cell_vect, (np.ndarray, list)) - has_para = isinstance(self.cell_para, (np.ndarray, list)) - - if has_vect and has_para: - # If both are provided, check for consistency - derived_para = np.array(unit_cell_parser.cell_vect_to_para(self.cell_vect)).flatten() - provided_para = np.array(self.cell_para).flatten() - assert np.allclose(derived_para, provided_para, atol=1e-3), \ - "Provided cell_para and cell_vect are inconsistent." - elif has_vect and not has_para: - self.cell_para = unit_cell_parser.cell_vect_to_para(self.cell_vect) - elif not has_vect and has_para: - self.cell_vect = unit_cell_parser.cell_para_to_vect(self.cell_para) - else: - raise ValueError("Crystal lattice is not defined. Provide 'cell_vect' or 'cell_para'.") - - # --- 2. Finalize Atom Coordinates --- - if self.atoms == "unknown" or not self.atoms: - print("Warning: Crystal initialized with no atoms.") - self.atoms = [] - else: - for atom in self.atoms: - has_cart = isinstance(atom.cart_xyz, (np.ndarray, list)) - has_frac = isinstance(atom.frac_xyz, (np.ndarray, list)) - if not (has_cart and has_frac): - if has_frac: - atom.cart_xyz = unit_cell_parser.atom_frac_to_cart_by_cell_vect(atom.frac_xyz, self.cell_vect) - elif has_cart: - atom.frac_xyz = unit_cell_parser.atom_cart_to_frac_by_cell_vect(atom.cart_xyz, self.cell_vect) - else: - raise ValueError(f"Atom {atom.element} {atom.atom_id} has no coordinate information.") - - # --- 3. Calculate Derived Properties --- - self.volume = unit_cell_parser.calculate_volume(self.cell_para) - if self.atoms: - total_mass = sum(chemical_knowledge.element_masses[atom.element] for atom in self.atoms) - # Density in g/cm^3 - self.density = total_mass / (self.volume * 1e-24) / (6.022140857e23) - else: - self.density = 0.0 - - def update_cart_by_frac(self) -> None: - """Updates all atom cartesian coordinates from their fractional coordinates.""" - for atom in self.atoms: - atom.cart_xyz = unit_cell_parser.atom_frac_to_cart_by_cell_vect(atom.frac_xyz, self.cell_vect) - - def update_frac_by_cart(self) -> None: - """Updates all atom fractional coordinates from their cartesian coordinates.""" - for atom in self.atoms: - atom.frac_xyz = unit_cell_parser.atom_cart_to_frac_by_cell_vect(atom.cart_xyz, self.cell_vect) - - def check(self) -> None: - """Performs consistency checks on the crystal structure.""" - print("Performing consistency checks...") - # Check lattice consistency - self.lattice_and_atom_complete() - - # Check atom coordinate consistency - for atom in self.atoms: - derived_cart = unit_cell_parser.atom_frac_to_cart_by_cell_vect(atom.frac_xyz, self.cell_vect) - assert np.allclose(atom.cart_xyz, derived_cart, atol=1e-3), \ - f"Atom {atom.atom_id} cartesian and fractional coordinates do not match." - - # Check atom IDs - if all(atom.atom_id != "unknown" for atom in self.atoms): - print("All atoms have IDs.") - else: - print("Warning: Not all atoms have IDs. Use .give_atom_id_forced() to assign them.") - print("Checks passed.") - - def give_atom_id_forced(self) -> None: - """Assigns or resets atom IDs from 0 to N-1 and clears bonding info.""" - print("Warning: Resetting all atom IDs and bonding information!") - for i, atom in enumerate(self.atoms): - atom.atom_id = i - atom.bonded_atom = [] - - def move_atom_into_cell(self) -> None: - """ - Moves all atoms into the primary unit cell [0, 1) in fractional coordinates. - """ - for atom in self.atoms: - # Use modulo for a more direct and efficient way to wrap coordinates - atom.frac_xyz = np.mod(atom.frac_xyz, 1.0) - self.update_cart_by_frac() - - def find_molecule(self, tolerance: float = 1.15) -> None: - """ - Identifies molecules within the crystal based on bonding distances. - - This method performs a graph search (BFS) on the atoms, connecting them - based on scaled covalent radii. It populates the `atom.molecule` and - `self.molecule_number` attributes. - - Args: - tolerance: A scaling factor for covalent radii to determine bonding. - A bond is formed if dist(A, B) < (radius(A) + radius(B)) * tolerance. - """ - self.move_atom_into_cell() - atoms_to_visit = list(range(len(self.atoms))) - molecule_id = 0 - - while atoms_to_visit: - molecule_id += 1 - # Start a Breadth-First Search (BFS) from the first unvisited atom - q = [atoms_to_visit[0]] - visited_in_molecule = {atoms_to_visit[0]} - - head = 0 - while head < len(q): - current_atom_idx = q[head] - head += 1 - self.atoms[current_atom_idx].molecule = molecule_id - - # Check for bonds with all other atoms - for other_atom_idx in range(len(self.atoms)): - if current_atom_idx == other_atom_idx: - continue - - # is_bonding_crystal handles periodic boundaries - is_bonded, _ = operation.is_bonding_crystal( - self.atoms[current_atom_idx], - self.atoms[other_atom_idx], - self.cell_vect, - tolerance=tolerance, - update_atom2=False # Do not modify coordinates during search - ) - - if is_bonded and other_atom_idx not in visited_in_molecule: - visited_in_molecule.add(other_atom_idx) - q.append(other_atom_idx) - - # Remove all atoms found in the new molecule from the list to visit - atoms_to_visit = [idx for idx in atoms_to_visit if idx not in visited_in_molecule] - - self.molecule_number = molecule_id - - def get_element(self) -> List[str]: - """Returns a sorted list of unique element symbols in the crystal.""" - return chemical_knowledge.sort_by_atomic_number(set(atom.element for atom in self.atoms)) - - def get_element_amount(self) -> List[int]: - """Returns the count of each element, sorted by atomic number.""" - all_elements = [atom.element for atom in self.atoms] - return [all_elements.count(element) for element in self.get_element()] - - - def make_p1(self) -> None: - """ - Expands the asymmetric unit to the full P1 cell using symmetry operations. - - The crystal's space group is set to 1 (P1) and SYMM is reset. This - implementation is robustly designed to ensure the final coordinate array - is always 2-dimensional, preventing downstream errors. - """ - all_ele, all_frac = self.get_ele_and_frac() - all_reflect_position = [] - all_matrix_M = [] - all_matrix_C = [] - for sym_opt in self.SYMM: - sym_opt_ele = sym_opt.lower().replace(" ", "").split(",") - # assert len(sym_opt_ele) == 3, "sym {} could not be treat".format(sym_opt_ele) - matrix_M = np.zeros((3, 3)) - matrix_C = np.zeros((1, 3)) - for idx, word in enumerate(sym_opt_ele): - sym_opt_ele_split = re.findall(r".*?([+-]*[xyz0-9\/\.]+)", word) - for sym_opt_frag in sym_opt_ele_split: - if sym_opt_frag == 'x' or sym_opt_frag == '+x': - matrix_M[0][idx] = 1 - elif str(sym_opt_frag) == '-x': - matrix_M[0][idx] = -1 - elif sym_opt_frag == 'y' or sym_opt_frag == '+y': - matrix_M[1][idx] = 1 - elif sym_opt_frag == '-y': - matrix_M[1][idx] = -1 - elif sym_opt_frag == 'z' or sym_opt_frag == '+z': - matrix_M[2][idx] = 1 - elif sym_opt_frag == '-z': - matrix_M[2][idx] = -1 - elif operation.is_number(sym_opt_frag) is True: - matrix_C[0][idx] = float(fractions.Fraction(sym_opt_frag)) - else: - raise Exception("wrong sym opt of" + sym_opt_frag) - - all_matrix_M.append(matrix_M) - all_matrix_C.append(matrix_C) - - for j in range(0, len(all_matrix_M)): - new_positions = np.dot(np.array([all_frac]), all_matrix_M[j]) + all_matrix_C[j] - all_reflect_position.append(new_positions.squeeze()) - all_ele = all_ele*len(self.SYMM) - - new_atoms = [] - idx=0 - for element, frac_xyz in zip(all_ele, np.array(all_reflect_position).reshape(-1,3)): - new_atoms.append(Atom(element=element, - frac_xyz=frac_xyz, - atom_id=idx)) - idx+=1 - - self.SYMM = "[x,y,z]" - self.space_group = 1 - self.atoms = new_atoms - self.update_cart_by_frac() - - def sort_by_element(self) -> None: - """Sorts the atoms list based on atomic number.""" - self.atoms.sort(key=lambda atom: chemical_knowledge.periodic_table_list[atom.element]) - - def get_ele_and_cart(self) -> Tuple[List[str], np.ndarray]: - """Returns all element symbols and their cartesian coordinates.""" - if not self.atoms: - return [], np.empty((0, 3)) - all_ele = [atom.element for atom in self.atoms] - all_carts = np.array([atom.cart_xyz for atom in self.atoms]) - return all_ele, all_carts - - def get_ele_and_frac(self) -> Tuple[List[str], np.ndarray]: - """Returns all element symbols and their fractional coordinates.""" - if not self.atoms: - return [], np.empty((0, 3)) - all_ele = [atom.element for atom in self.atoms] - all_fracs = np.array([atom.frac_xyz for atom in self.atoms]) - return all_ele, all_fracs - - def info(self, all_info: bool = False) -> None: - """ - Prints a formatted summary of the crystal structure. - - Args: - all_info: If True, prints an extended table including fractional - coordinates, forces, and other properties. - """ - print("--- Crystal System ---") - print(f"Name: {self.system_name}") - print("Lattice Vectors (Angstrom):") - for vec in self.cell_vect: - print(f"{vec[0]:16.8f} {vec[1]:16.8f} {vec[2]:16.8f}") - print("Lattice Parameters:") - print(f"a, b, c (A): {self.cell_para[0][0]:.4f}, {self.cell_para[0][1]:.4f}, {self.cell_para[0][2]:.4f}") - print(f"alpha, beta, gamma (deg): {self.cell_para[1][0]:.4f}, {self.cell_para[1][1]:.4f}, {self.cell_para[1][2]:.4f}") - print(f"Volume (A^3): {self.volume:.4f} | Density (g/cm^3): {self.density:.4f}") - print(f"\n--- Atomic Coordinates (Total: {len(self.atoms)}) ---") - - if not all_info: - print(f"{'Element':<10} {'Cartesian X':>16} {'Cartesian Y':>16} {'Cartesian Z':>16}") - print("-" * 58) - for atom in self.atoms: - print(f"{atom.element:<10} {atom.cart_xyz[0]:16.8f} {atom.cart_xyz[1]:16.8f} {atom.cart_xyz[2]:16.8f}") - else: - header = ( - f"{'ID':<5} {'Elem':<6} " - f"{'Frac X':>10} {'Frac Y':>10} {'Frac Z':>10} | " - f"{'Cart X':>12} {'Cart Y':>12} {'Cart Z':>12}" - ) - print(header) - print("-" * len(header)) - for atom in self.atoms: - aid = str(atom.atom_id) if atom.atom_id != 'unknown' else '-' - print( - f"{aid:<5} {atom.element:<6} " - f"{atom.frac_xyz[0]:10.6f} {atom.frac_xyz[1]:10.6f} {atom.frac_xyz[2]:10.6f} | " - f"{atom.cart_xyz[0]:12.6f} {atom.cart_xyz[1]:12.6f} {atom.cart_xyz[2]:12.6f}" - ) - - print("\n--- Other Properties ---") - print(f"Energy: {self.energy}") - print(f"Comment: {self.comment}") - print(f"Virial: {self.virial}") - - -class Molecule: - """Represents a non-periodic molecule (a collection of atoms).""" - - def __init__(self, **kwargs: Any): - """ - Initializes a Molecule object. - - Args: - **kwargs: Keyword arguments to set molecule attributes. - Required: 'atoms' (List[Atom]). - Optional: 'energy', 'comment', 'name', 'system_name'. - """ - self.atoms: Union[str, List[Atom]] = kwargs.get("atoms", "unknown") - self.energy: Union[str, float] = kwargs.get("energy", "unknown") - self.comment: Any = kwargs.get("comment", "unknown") - self.descriptor: Any = kwargs.get("descriptor", "unknown") - self.name: str = kwargs.get("name", "unknown") - self.system_name: str = kwargs.get("system_name", "unknown") - - if self.atoms == "unknown": - print("Warning: Molecule initialized with no atoms.") - self.atoms = [] - - def give_atom_id_forced(self) -> None: - """Assigns or resets atom IDs from 0 to N-1 and clears bonding info.""" - print("Warning: Resetting all atom IDs and bonding information!") - for i, atom in enumerate(self.atoms): - atom.atom_id = i - atom.bonded_atom = [] - - def get_element(self) -> List[str]: - """Returns a sorted list of unique element symbols in the molecule.""" - if not self.atoms: return [] - return chemical_knowledge.sort_by_atomic_number(set(atom.element for atom in self.atoms)) - - def get_element_amount(self) -> List[int]: - """Returns the count of each element, sorted by atomic number.""" - if not self.atoms: return [] - all_elements = [atom.element for atom in self.atoms] - return [all_elements.count(element) for element in self.get_element()] - - def get_ele_and_cart(self) -> Tuple[List[str], np.ndarray]: - """Returns all element symbols and their cartesian coordinates.""" - if not self.atoms: - return [], np.empty((0, 3)) - all_ele = [atom.element for atom in self.atoms] - all_carts = np.array([atom.cart_xyz for atom in self.atoms]) - return all_ele, all_carts - - def put_ele_cart_back(self, all_ele: List[str], all_carts: np.ndarray) -> None: - """Updates the molecule's atoms from lists of elements and coordinates.""" - for i, atom in enumerate(self.atoms): - atom.element = all_ele[i] - atom.cart_xyz = all_carts[i] - - def build_molecules_by_ele_cart(self, all_ele: List[str], all_carts: np.ndarray) -> None: - """Rebuilds the molecule's atoms list from elements and coordinates.""" - assert len(all_ele) == len(all_carts), "Element and coordinate lists must have the same length." - self.atoms = [ - Atom(element=ele, cart_xyz=cart, atom_id=i) - for i, (ele, cart) in enumerate(zip(all_ele, all_carts)) - ] - - def get_mass(self) -> float: - """Calculates the total mass of the molecule.""" - if not self.atoms: return 0.0 - return sum(chemical_knowledge.element_masses[atom.element] for atom in self.atoms) - - def get_center_of_mass(self) -> np.ndarray: - """Calculates the center of mass of the molecule.""" - if not self.atoms: return np.zeros(3) - - all_ele, all_carts = self.get_ele_and_cart() - masses = np.array([chemical_knowledge.element_masses[x] for x in all_ele]) - total_mass = np.sum(masses) - - if total_mass == 0: return np.zeros(3) - return np.sum(all_carts * masses[:, np.newaxis], axis=0) / total_mass - - def sort_by_element(self) -> None: - """Sorts the atoms list based on atomic number.""" - self.atoms.sort(key=lambda atom: chemical_knowledge.periodic_table_list[atom.element]) - - def sort_by_id(self) -> None: - """Sorts the atoms list based on their atom_id.""" - self.atoms.sort(key=lambda atom: atom.atom_id) - - def info(self) -> None: - """Prints a formatted summary of the molecule.""" - print(f"--- Molecule ---") - print(f"Name: {self.name} | System: {self.system_name}") - print(f"Number of atoms: {len(self.atoms)}") - print(f"Total Mass (amu): {self.get_mass():.4f}") - print(f"Energy: {self.energy}") - print(f"Comment: {self.comment}") - print(f"\n{'Element':<10} {'Cartesian X':>16} {'Cartesian Y':>16} {'Cartesian Z':>16}") - print("-" * 58) - if self.atoms: - for atom in self.atoms: - print(f"{atom.element:<10} {atom.cart_xyz[0]:16.8f} {atom.cart_xyz[1]:16.8f} {atom.cart_xyz[2]:16.8f}") - - def find_fragment(self, tolerance: float = 1.15) -> Dict[int, List[int]]: - """ - Identifies covalently bonded fragments within the molecule. - - This is useful for molecules that are actually composed of several - disconnected components (e.g., salts, solvent shells). - - Args: - tolerance: Scaling factor for covalent radii to determine bonding. - - Returns: - A dictionary mapping a fragment ID (starting from 1) to a list of - atom indices belonging to that fragment. - """ - if not self.atoms: return {} - - num_atoms = len(self.atoms) - cart_matrix = np.array([atom.cart_xyz for atom in self.atoms]) - radii = np.array([chemical_knowledge.element_covalent_radii[atom.element] for atom in self.atoms]) - - # Create a matrix of bond thresholds (r_i + r_j) - bond_threshold_matrix = (radii[:, np.newaxis] + radii) * tolerance - - # True where distance is less than the bond threshold - dist_matrix = cdist(cart_matrix, cart_matrix) - adj_matrix = dist_matrix < bond_threshold_matrix - np.fill_diagonal(adj_matrix, False) - - # Graph traversal (DFS) to find connected components - visited = [False] * num_atoms - groups = {} - group_id = 0 - for i in range(num_atoms): - if not visited[i]: - group_id += 1 - groups[group_id] = [] - stack = [i] - while stack: - atom_idx = stack.pop() - if not visited[atom_idx]: - visited[atom_idx] = True - groups[group_id].append(atom_idx) - # Find neighbors and add to stack - neighbors = np.where(adj_matrix[atom_idx])[0] - stack.extend(neighbors) - return groups - - def give_molecule_id(self, tolerance: float = 1.15) -> None: - """Assigns a molecule ID to each atom based on fragment analysis.""" - fragments = self.find_fragment(tolerance=tolerance) - for group_id, atom_indices in fragments.items(): - for atom_idx in atom_indices: - self.atoms[atom_idx].molecule = group_id - - def take_out_fragment(self, tolerance: float = 1.15) -> List['Molecule']: - """ - Splits the current molecule into a list of new Molecule objects, - one for each disconnected fragment. - """ - if not self.atoms: return [] - - self.give_atom_id_forced() # Ensure IDs are set for lookup - fragments = self.find_fragment(tolerance=tolerance) - new_molecules = [] - - for i, atom_indices in fragments.items(): - fragment_atoms = [self.atoms[j] for j in atom_indices] - new_mol = Molecule( - atoms=copy.deepcopy(fragment_atoms), - name=f"{self.name}_frag{i}", - system_name=f"{self.system_name}_frag{i}" - ) - new_molecules.append(new_mol) - return new_molecules - - def calculate_frac_xyz_by_cell_para(self, cell_para: list) -> None: - """Calculates fractional coordinates for all atoms given cell parameters.""" - for atom in self.atoms: - atom.frac_xyz = unit_cell_parser.atom_cart_to_frac_by_cell_para(atom.cart_xyz, cell_para) - - def molecule_volume(self, num_samples: int = 100000) -> float: - """ - Calculates the van der Waals volume using a Monte Carlo integration method. - - This method samples points in a bounding box around the molecule and - determines the ratio of points that fall within any atom's vdW sphere. - - Args: - num_samples: The number of random points to sample. More points - yield a more accurate volume at the cost of performance. - - Returns: - The estimated van der Waals volume in cubic Angstroms. - """ - if not self.atoms: return 0.0 - - elements, coords = self.get_ele_and_cart() - radii = np.array([chemical_knowledge.element_vdw_radii[el] for el in elements]) - - # Determine bounding box for sampling - min_bounds = np.min(coords, axis=0) - np.max(radii) - max_bounds = np.max(coords, axis=0) + np.max(radii) - bounding_box_volume = np.prod(max_bounds - min_bounds) - - # Generate random sample points within the bounding box - random_points = np.random.uniform(min_bounds, max_bounds, (num_samples, 3)) - - # Check for each point if it's inside ANY sphere - count_inside = 0 - for rp in tqdm(random_points, desc="Monte Carlo Volume", leave=False): - # Calculate squared distances from the point to all atom centers - dist_sq = np.sum((coords - rp)**2, axis=1) - # If any distance is within the radius, the point is inside - if np.any(dist_sq <= radii**2): - count_inside += 1 - +""" +This module defines the core data structures for representing atomic structures: +Atom, Crystal, and Molecule. + +These classes store information about atomic coordinates, lattice parameters, +and other physical properties, providing a foundational toolkit for geometric +and structural analysis in materials science simulations. +""" +import copy +from typing import List, Tuple, Union, Any, Optional, Dict + +import numpy as np +import fractions +import re +from scipy.spatial.distance import cdist +from tqdm import tqdm + +from basic_function import unit_cell_parser +from basic_function import chemical_knowledge +from basic_function import operation + + +class Atom: + """ + Represents a single atom in a chemical structure. + + Attributes: + element (str): The chemical symbol of the atom (e.g., 'H', 'C', 'O'). + cart_xyz (np.ndarray): Cartesian coordinates [x, y, z] in Angstroms. + frac_xyz (np.ndarray): Fractional coordinates [u, v, w] with respect to a lattice. + atom_id (int): A unique identifier for the atom within a larger structure. + force (np.ndarray): Force vector [fx, fy, fz] acting on the atom. + atom_charge (float): Partial charge of the atom. + atom_energy (float): Site potential energy of the atom. + molecule (int): Identifier for the molecule this atom belongs to. + bonded_atom (list): A list of IDs of atoms bonded to this one. + descriptor (any): A placeholder for feature vectors or other descriptors. + comment (dict): A dictionary for storing arbitrary metadata. + """ + + def __init__(self, **kwargs: Any): + """ + Initializes an Atom object. + + Args: + **kwargs: Keyword arguments to set atom attributes. + Required: 'element' and one of 'cart_xyz' or 'frac_xyz'. + Optional: 'atom_id', 'force_xyz', 'atom_charge', 'atom_energy', etc. + """ + self.element: str = kwargs.get("element", "unknown") + self.cart_xyz: Union[str, np.ndarray] = kwargs.get('cart_xyz', "unknown") + self.frac_xyz: Union[str, np.ndarray] = kwargs.get('frac_xyz', "unknown") + self.atom_id: Union[str, int] = kwargs.get("atom_id", "unknown") + self.force: Union[str, np.ndarray] = kwargs.get('force_xyz', 'unknown') + self.atom_charge: Union[str, float] = kwargs.get('atom_charge', 'unknown') + self.atom_energy: Union[str, float] = kwargs.get('atom_energy', 'unknown') + self.molecule: Union[str, int] = kwargs.get('molecule', 'unknown') + self.bonded_atom: list = kwargs.get('bonded_atom', []) + self.descriptor: Any = kwargs.get("descriptor", "unknown") + self.comment: dict = kwargs.get("comment", {}) + + def info(self) -> None: + """Prints all attributes of the atom to the console.""" + for key, value in self.__dict__.items(): + print(f"{key}: {value}") + + def check(self) -> None: + """ + Performs basic sanity checks on the atom's attributes. + + Raises: + AssertionError: If element is not defined, or if neither cartesian + nor fractional coordinates are provided. + """ + assert self.element != "unknown", "Atom must have an element type." + has_cart = isinstance(self.cart_xyz, (np.ndarray, list)) + has_frac = isinstance(self.frac_xyz, (np.ndarray, list)) + assert has_cart or has_frac, "Atom needs either cart_xyz or frac_xyz." + + +class Crystal: + """ + Represents a periodic crystal structure. + + Contains lattice information (cell parameters or vectors) and a list of atoms + that constitute the structure within the unit cell. + """ + + def __init__(self, **kwargs: Any): + """ + Initializes a Crystal object. + + The constructor requires lattice and atom information. It will automatically + calculate derived properties like volume and density. + + Args: + **kwargs: Keyword arguments to set crystal attributes. + Required: 'atoms' and one of 'cell_vect' or 'cell_para'. + - cell_vect (list): 3x3 list or array of lattice vectors. + - cell_para (list): [[a, b, c], [alpha, beta, gamma]]. + - atoms (List[Atom]): A list of Atom objects. + Optional: 'energy', 'comment', 'system_name', 'space_group', 'SYMM', etc. + """ + self.cell_vect: Union[str, np.ndarray] = kwargs.get("cell_vect", "unknown") + self.cell_para: Union[str, list] = kwargs.get("cell_para", "unknown") + self.atoms: Union[str, List[Atom]] = kwargs.get("atoms", "unknown") + self.energy: Union[str, float] = kwargs.get("energy", "unknown") + self.comment: Any = kwargs.get("comment", "unknown") + self.descriptor: Any = kwargs.get("descriptor", "unknown") + self.molecule_number: Union[str, int] = kwargs.get("molecule_number", "unknown") + self.system_name: str = kwargs.get("system_name", "unknown") + self.virial: Any = kwargs.get("virial", "unknown") + self.SYMM: list = kwargs.get("SYMM", ["x,y,z"]) + self.space_group: int = kwargs.get("space_group", 1) + self.other_properties: dict = {} + + # This method completes the initialization. + self.lattice_and_atom_complete() + + def lattice_and_atom_complete(self) -> None: + """ + Completes the initialization by ensuring consistency between cell representations, + atom coordinates, and calculating derived properties. + """ + # --- 1. Finalize Lattice Representation --- + has_vect = isinstance(self.cell_vect, (np.ndarray, list)) + has_para = isinstance(self.cell_para, (np.ndarray, list)) + + if has_vect and has_para: + # If both are provided, check for consistency + derived_para = np.array(unit_cell_parser.cell_vect_to_para(self.cell_vect)).flatten() + provided_para = np.array(self.cell_para).flatten() + assert np.allclose(derived_para, provided_para, atol=1e-3), \ + "Provided cell_para and cell_vect are inconsistent." + elif has_vect and not has_para: + self.cell_para = unit_cell_parser.cell_vect_to_para(self.cell_vect) + elif not has_vect and has_para: + self.cell_vect = unit_cell_parser.cell_para_to_vect(self.cell_para) + else: + raise ValueError("Crystal lattice is not defined. Provide 'cell_vect' or 'cell_para'.") + + # --- 2. Finalize Atom Coordinates --- + if self.atoms == "unknown" or not self.atoms: + print("Warning: Crystal initialized with no atoms.") + self.atoms = [] + else: + for atom in self.atoms: + has_cart = isinstance(atom.cart_xyz, (np.ndarray, list)) + has_frac = isinstance(atom.frac_xyz, (np.ndarray, list)) + if not (has_cart and has_frac): + if has_frac: + atom.cart_xyz = unit_cell_parser.atom_frac_to_cart_by_cell_vect(atom.frac_xyz, self.cell_vect) + elif has_cart: + atom.frac_xyz = unit_cell_parser.atom_cart_to_frac_by_cell_vect(atom.cart_xyz, self.cell_vect) + else: + raise ValueError(f"Atom {atom.element} {atom.atom_id} has no coordinate information.") + + # --- 3. Calculate Derived Properties --- + self.volume = unit_cell_parser.calculate_volume(self.cell_para) + if self.atoms: + total_mass = sum(chemical_knowledge.element_masses[atom.element] for atom in self.atoms) + # Density in g/cm^3 + self.density = total_mass / (self.volume * 1e-24) / (6.022140857e23) + else: + self.density = 0.0 + + def update_cart_by_frac(self) -> None: + """Updates all atom cartesian coordinates from their fractional coordinates.""" + for atom in self.atoms: + atom.cart_xyz = unit_cell_parser.atom_frac_to_cart_by_cell_vect(atom.frac_xyz, self.cell_vect) + + def update_frac_by_cart(self) -> None: + """Updates all atom fractional coordinates from their cartesian coordinates.""" + for atom in self.atoms: + atom.frac_xyz = unit_cell_parser.atom_cart_to_frac_by_cell_vect(atom.cart_xyz, self.cell_vect) + + def check(self) -> None: + """Performs consistency checks on the crystal structure.""" + print("Performing consistency checks...") + # Check lattice consistency + self.lattice_and_atom_complete() + + # Check atom coordinate consistency + for atom in self.atoms: + derived_cart = unit_cell_parser.atom_frac_to_cart_by_cell_vect(atom.frac_xyz, self.cell_vect) + assert np.allclose(atom.cart_xyz, derived_cart, atol=1e-3), \ + f"Atom {atom.atom_id} cartesian and fractional coordinates do not match." + + # Check atom IDs + if all(atom.atom_id != "unknown" for atom in self.atoms): + print("All atoms have IDs.") + else: + print("Warning: Not all atoms have IDs. Use .give_atom_id_forced() to assign them.") + print("Checks passed.") + + def give_atom_id_forced(self) -> None: + """Assigns or resets atom IDs from 0 to N-1 and clears bonding info.""" + print("Warning: Resetting all atom IDs and bonding information!") + for i, atom in enumerate(self.atoms): + atom.atom_id = i + atom.bonded_atom = [] + + def move_atom_into_cell(self) -> None: + """ + Moves all atoms into the primary unit cell [0, 1) in fractional coordinates. + """ + for atom in self.atoms: + # Use modulo for a more direct and efficient way to wrap coordinates + atom.frac_xyz = np.mod(atom.frac_xyz, 1.0) + self.update_cart_by_frac() + + def find_molecule(self, tolerance: float = 1.15) -> None: + """ + Identifies molecules within the crystal based on bonding distances. + + This method performs a graph search (BFS) on the atoms, connecting them + based on scaled covalent radii. It populates the `atom.molecule` and + `self.molecule_number` attributes. + + Args: + tolerance: A scaling factor for covalent radii to determine bonding. + A bond is formed if dist(A, B) < (radius(A) + radius(B)) * tolerance. + """ + self.move_atom_into_cell() + atoms_to_visit = list(range(len(self.atoms))) + molecule_id = 0 + + while atoms_to_visit: + molecule_id += 1 + # Start a Breadth-First Search (BFS) from the first unvisited atom + q = [atoms_to_visit[0]] + visited_in_molecule = {atoms_to_visit[0]} + + head = 0 + while head < len(q): + current_atom_idx = q[head] + head += 1 + self.atoms[current_atom_idx].molecule = molecule_id + + # Check for bonds with all other atoms + for other_atom_idx in range(len(self.atoms)): + if current_atom_idx == other_atom_idx: + continue + + # is_bonding_crystal handles periodic boundaries + is_bonded, _ = operation.is_bonding_crystal( + self.atoms[current_atom_idx], + self.atoms[other_atom_idx], + self.cell_vect, + tolerance=tolerance, + update_atom2=False # Do not modify coordinates during search + ) + + if is_bonded and other_atom_idx not in visited_in_molecule: + visited_in_molecule.add(other_atom_idx) + q.append(other_atom_idx) + + # Remove all atoms found in the new molecule from the list to visit + atoms_to_visit = [idx for idx in atoms_to_visit if idx not in visited_in_molecule] + + self.molecule_number = molecule_id + + def get_element(self) -> List[str]: + """Returns a sorted list of unique element symbols in the crystal.""" + return chemical_knowledge.sort_by_atomic_number(set(atom.element for atom in self.atoms)) + + def get_element_amount(self) -> List[int]: + """Returns the count of each element, sorted by atomic number.""" + all_elements = [atom.element for atom in self.atoms] + return [all_elements.count(element) for element in self.get_element()] + + + def make_p1(self) -> None: + """ + Expands the asymmetric unit to the full P1 cell using symmetry operations. + + The crystal's space group is set to 1 (P1) and SYMM is reset. This + implementation is robustly designed to ensure the final coordinate array + is always 2-dimensional, preventing downstream errors. + """ + all_ele, all_frac = self.get_ele_and_frac() + all_reflect_position = [] + all_matrix_M = [] + all_matrix_C = [] + for sym_opt in self.SYMM: + sym_opt_ele = sym_opt.lower().replace(" ", "").split(",") + # assert len(sym_opt_ele) == 3, "sym {} could not be treat".format(sym_opt_ele) + matrix_M = np.zeros((3, 3)) + matrix_C = np.zeros((1, 3)) + for idx, word in enumerate(sym_opt_ele): + sym_opt_ele_split = re.findall(r".*?([+-]*[xyz0-9\/\.]+)", word) + for sym_opt_frag in sym_opt_ele_split: + if sym_opt_frag == 'x' or sym_opt_frag == '+x': + matrix_M[0][idx] = 1 + elif str(sym_opt_frag) == '-x': + matrix_M[0][idx] = -1 + elif sym_opt_frag == 'y' or sym_opt_frag == '+y': + matrix_M[1][idx] = 1 + elif sym_opt_frag == '-y': + matrix_M[1][idx] = -1 + elif sym_opt_frag == 'z' or sym_opt_frag == '+z': + matrix_M[2][idx] = 1 + elif sym_opt_frag == '-z': + matrix_M[2][idx] = -1 + elif operation.is_number(sym_opt_frag) is True: + matrix_C[0][idx] = float(fractions.Fraction(sym_opt_frag)) + else: + raise Exception("wrong sym opt of" + sym_opt_frag) + + all_matrix_M.append(matrix_M) + all_matrix_C.append(matrix_C) + + for j in range(0, len(all_matrix_M)): + new_positions = np.dot(np.array([all_frac]), all_matrix_M[j]) + all_matrix_C[j] + all_reflect_position.append(new_positions.squeeze()) + all_ele = all_ele*len(self.SYMM) + + new_atoms = [] + idx=0 + for element, frac_xyz in zip(all_ele, np.array(all_reflect_position).reshape(-1,3)): + new_atoms.append(Atom(element=element, + frac_xyz=frac_xyz, + atom_id=idx)) + idx+=1 + + self.SYMM = "[x,y,z]" + self.space_group = 1 + self.atoms = new_atoms + self.update_cart_by_frac() + + def sort_by_element(self) -> None: + """Sorts the atoms list based on atomic number.""" + self.atoms.sort(key=lambda atom: chemical_knowledge.periodic_table_list[atom.element]) + + def get_ele_and_cart(self) -> Tuple[List[str], np.ndarray]: + """Returns all element symbols and their cartesian coordinates.""" + if not self.atoms: + return [], np.empty((0, 3)) + all_ele = [atom.element for atom in self.atoms] + all_carts = np.array([atom.cart_xyz for atom in self.atoms]) + return all_ele, all_carts + + def get_ele_and_frac(self) -> Tuple[List[str], np.ndarray]: + """Returns all element symbols and their fractional coordinates.""" + if not self.atoms: + return [], np.empty((0, 3)) + all_ele = [atom.element for atom in self.atoms] + all_fracs = np.array([atom.frac_xyz for atom in self.atoms]) + return all_ele, all_fracs + + def info(self, all_info: bool = False) -> None: + """ + Prints a formatted summary of the crystal structure. + + Args: + all_info: If True, prints an extended table including fractional + coordinates, forces, and other properties. + """ + print("--- Crystal System ---") + print(f"Name: {self.system_name}") + print("Lattice Vectors (Angstrom):") + for vec in self.cell_vect: + print(f"{vec[0]:16.8f} {vec[1]:16.8f} {vec[2]:16.8f}") + print("Lattice Parameters:") + print(f"a, b, c (A): {self.cell_para[0][0]:.4f}, {self.cell_para[0][1]:.4f}, {self.cell_para[0][2]:.4f}") + print(f"alpha, beta, gamma (deg): {self.cell_para[1][0]:.4f}, {self.cell_para[1][1]:.4f}, {self.cell_para[1][2]:.4f}") + print(f"Volume (A^3): {self.volume:.4f} | Density (g/cm^3): {self.density:.4f}") + print(f"\n--- Atomic Coordinates (Total: {len(self.atoms)}) ---") + + if not all_info: + print(f"{'Element':<10} {'Cartesian X':>16} {'Cartesian Y':>16} {'Cartesian Z':>16}") + print("-" * 58) + for atom in self.atoms: + print(f"{atom.element:<10} {atom.cart_xyz[0]:16.8f} {atom.cart_xyz[1]:16.8f} {atom.cart_xyz[2]:16.8f}") + else: + header = ( + f"{'ID':<5} {'Elem':<6} " + f"{'Frac X':>10} {'Frac Y':>10} {'Frac Z':>10} | " + f"{'Cart X':>12} {'Cart Y':>12} {'Cart Z':>12}" + ) + print(header) + print("-" * len(header)) + for atom in self.atoms: + aid = str(atom.atom_id) if atom.atom_id != 'unknown' else '-' + print( + f"{aid:<5} {atom.element:<6} " + f"{atom.frac_xyz[0]:10.6f} {atom.frac_xyz[1]:10.6f} {atom.frac_xyz[2]:10.6f} | " + f"{atom.cart_xyz[0]:12.6f} {atom.cart_xyz[1]:12.6f} {atom.cart_xyz[2]:12.6f}" + ) + + print("\n--- Other Properties ---") + print(f"Energy: {self.energy}") + print(f"Comment: {self.comment}") + print(f"Virial: {self.virial}") + + +class Molecule: + """Represents a non-periodic molecule (a collection of atoms).""" + + def __init__(self, **kwargs: Any): + """ + Initializes a Molecule object. + + Args: + **kwargs: Keyword arguments to set molecule attributes. + Required: 'atoms' (List[Atom]). + Optional: 'energy', 'comment', 'name', 'system_name'. + """ + self.atoms: Union[str, List[Atom]] = kwargs.get("atoms", "unknown") + self.energy: Union[str, float] = kwargs.get("energy", "unknown") + self.comment: Any = kwargs.get("comment", "unknown") + self.descriptor: Any = kwargs.get("descriptor", "unknown") + self.name: str = kwargs.get("name", "unknown") + self.system_name: str = kwargs.get("system_name", "unknown") + + if self.atoms == "unknown": + print("Warning: Molecule initialized with no atoms.") + self.atoms = [] + + def give_atom_id_forced(self) -> None: + """Assigns or resets atom IDs from 0 to N-1 and clears bonding info.""" + print("Warning: Resetting all atom IDs and bonding information!") + for i, atom in enumerate(self.atoms): + atom.atom_id = i + atom.bonded_atom = [] + + def get_element(self) -> List[str]: + """Returns a sorted list of unique element symbols in the molecule.""" + if not self.atoms: return [] + return chemical_knowledge.sort_by_atomic_number(set(atom.element for atom in self.atoms)) + + def get_element_amount(self) -> List[int]: + """Returns the count of each element, sorted by atomic number.""" + if not self.atoms: return [] + all_elements = [atom.element for atom in self.atoms] + return [all_elements.count(element) for element in self.get_element()] + + def get_ele_and_cart(self) -> Tuple[List[str], np.ndarray]: + """Returns all element symbols and their cartesian coordinates.""" + if not self.atoms: + return [], np.empty((0, 3)) + all_ele = [atom.element for atom in self.atoms] + all_carts = np.array([atom.cart_xyz for atom in self.atoms]) + return all_ele, all_carts + + def put_ele_cart_back(self, all_ele: List[str], all_carts: np.ndarray) -> None: + """Updates the molecule's atoms from lists of elements and coordinates.""" + for i, atom in enumerate(self.atoms): + atom.element = all_ele[i] + atom.cart_xyz = all_carts[i] + + def build_molecules_by_ele_cart(self, all_ele: List[str], all_carts: np.ndarray) -> None: + """Rebuilds the molecule's atoms list from elements and coordinates.""" + assert len(all_ele) == len(all_carts), "Element and coordinate lists must have the same length." + self.atoms = [ + Atom(element=ele, cart_xyz=cart, atom_id=i) + for i, (ele, cart) in enumerate(zip(all_ele, all_carts)) + ] + + def get_mass(self) -> float: + """Calculates the total mass of the molecule.""" + if not self.atoms: return 0.0 + return sum(chemical_knowledge.element_masses[atom.element] for atom in self.atoms) + + def get_center_of_mass(self) -> np.ndarray: + """Calculates the center of mass of the molecule.""" + if not self.atoms: return np.zeros(3) + + all_ele, all_carts = self.get_ele_and_cart() + masses = np.array([chemical_knowledge.element_masses[x] for x in all_ele]) + total_mass = np.sum(masses) + + if total_mass == 0: return np.zeros(3) + return np.sum(all_carts * masses[:, np.newaxis], axis=0) / total_mass + + def sort_by_element(self) -> None: + """Sorts the atoms list based on atomic number.""" + self.atoms.sort(key=lambda atom: chemical_knowledge.periodic_table_list[atom.element]) + + def sort_by_id(self) -> None: + """Sorts the atoms list based on their atom_id.""" + self.atoms.sort(key=lambda atom: atom.atom_id) + + def info(self) -> None: + """Prints a formatted summary of the molecule.""" + print(f"--- Molecule ---") + print(f"Name: {self.name} | System: {self.system_name}") + print(f"Number of atoms: {len(self.atoms)}") + print(f"Total Mass (amu): {self.get_mass():.4f}") + print(f"Energy: {self.energy}") + print(f"Comment: {self.comment}") + print(f"\n{'Element':<10} {'Cartesian X':>16} {'Cartesian Y':>16} {'Cartesian Z':>16}") + print("-" * 58) + if self.atoms: + for atom in self.atoms: + print(f"{atom.element:<10} {atom.cart_xyz[0]:16.8f} {atom.cart_xyz[1]:16.8f} {atom.cart_xyz[2]:16.8f}") + + def find_fragment(self, tolerance: float = 1.15) -> Dict[int, List[int]]: + """ + Identifies covalently bonded fragments within the molecule. + + This is useful for molecules that are actually composed of several + disconnected components (e.g., salts, solvent shells). + + Args: + tolerance: Scaling factor for covalent radii to determine bonding. + + Returns: + A dictionary mapping a fragment ID (starting from 1) to a list of + atom indices belonging to that fragment. + """ + if not self.atoms: return {} + + num_atoms = len(self.atoms) + cart_matrix = np.array([atom.cart_xyz for atom in self.atoms]) + radii = np.array([chemical_knowledge.element_covalent_radii[atom.element] for atom in self.atoms]) + + # Create a matrix of bond thresholds (r_i + r_j) + bond_threshold_matrix = (radii[:, np.newaxis] + radii) * tolerance + + # True where distance is less than the bond threshold + dist_matrix = cdist(cart_matrix, cart_matrix) + adj_matrix = dist_matrix < bond_threshold_matrix + np.fill_diagonal(adj_matrix, False) + + # Graph traversal (DFS) to find connected components + visited = [False] * num_atoms + groups = {} + group_id = 0 + for i in range(num_atoms): + if not visited[i]: + group_id += 1 + groups[group_id] = [] + stack = [i] + while stack: + atom_idx = stack.pop() + if not visited[atom_idx]: + visited[atom_idx] = True + groups[group_id].append(atom_idx) + # Find neighbors and add to stack + neighbors = np.where(adj_matrix[atom_idx])[0] + stack.extend(neighbors) + return groups + + def give_molecule_id(self, tolerance: float = 1.15) -> None: + """Assigns a molecule ID to each atom based on fragment analysis.""" + fragments = self.find_fragment(tolerance=tolerance) + for group_id, atom_indices in fragments.items(): + for atom_idx in atom_indices: + self.atoms[atom_idx].molecule = group_id + + def take_out_fragment(self, tolerance: float = 1.15) -> List['Molecule']: + """ + Splits the current molecule into a list of new Molecule objects, + one for each disconnected fragment. + """ + if not self.atoms: return [] + + self.give_atom_id_forced() # Ensure IDs are set for lookup + fragments = self.find_fragment(tolerance=tolerance) + new_molecules = [] + + for i, atom_indices in fragments.items(): + fragment_atoms = [self.atoms[j] for j in atom_indices] + new_mol = Molecule( + atoms=copy.deepcopy(fragment_atoms), + name=f"{self.name}_frag{i}", + system_name=f"{self.system_name}_frag{i}" + ) + new_molecules.append(new_mol) + return new_molecules + + def calculate_frac_xyz_by_cell_para(self, cell_para: list) -> None: + """Calculates fractional coordinates for all atoms given cell parameters.""" + for atom in self.atoms: + atom.frac_xyz = unit_cell_parser.atom_cart_to_frac_by_cell_para(atom.cart_xyz, cell_para) + + def molecule_volume(self, num_samples: int = 100000) -> float: + """ + Calculates the van der Waals volume using a Monte Carlo integration method. + + This method samples points in a bounding box around the molecule and + determines the ratio of points that fall within any atom's vdW sphere. + + Args: + num_samples: The number of random points to sample. More points + yield a more accurate volume at the cost of performance. + + Returns: + The estimated van der Waals volume in cubic Angstroms. + """ + if not self.atoms: return 0.0 + + elements, coords = self.get_ele_and_cart() + radii = np.array([chemical_knowledge.element_vdw_radii[el] for el in elements]) + + # Determine bounding box for sampling + min_bounds = np.min(coords, axis=0) - np.max(radii) + max_bounds = np.max(coords, axis=0) + np.max(radii) + bounding_box_volume = np.prod(max_bounds - min_bounds) + + # Generate random sample points within the bounding box + random_points = np.random.uniform(min_bounds, max_bounds, (num_samples, 3)) + + # Check for each point if it's inside ANY sphere + count_inside = 0 + for rp in tqdm(random_points, desc="Monte Carlo Volume", leave=False): + # Calculate squared distances from the point to all atom centers + dist_sq = np.sum((coords - rp)**2, axis=1) + # If any distance is within the radius, the point is inside + if np.any(dist_sq <= radii**2): + count_inside += 1 + return (count_inside / num_samples) * bounding_box_volume \ No newline at end of file diff --git a/basic_function/format_parser.py b/basic_function/format_parser.py index c17226e..6a34377 100644 --- a/basic_function/format_parser.py +++ b/basic_function/format_parser.py @@ -1,333 +1,333 @@ -import re -from basic_function import operation -from basic_function import data_classes -from basic_function import chemical_knowledge -import copy - - -def read_xyz_file(file_path): - input_file = open(file_path, 'r') - lines = input_file.readlines() - number_of_atoms = int(lines[0]) - name = str(lines[1][:-1]) - atoms = [] - for index,line in enumerate(lines): - split_line = list(filter(lambda x: x != '', re.split("\\s+", line))) - if len(split_line)==4 and operation.is_number(split_line[1]) and \ - operation.is_number(split_line[2]) and operation.is_number(split_line[3]): - atoms.append(data_classes.Atom(element=split_line[0], - cart_xyz=[float(split_line[1]), float(split_line[2]), float(split_line[3])], - atom_id=index-2)) - - if number_of_atoms!=len(atoms): - print("Warning! The length of atoms don't match the number of atoms given") - - molecule = data_classes.Molecule(atoms=atoms, name=name, system_name=name) - - return molecule - -def write_cif_file(crystal, sym=False, name="zcx"): - """ - Accept crystal class, give the cif file out - :param crystal: crystal class - :param coordinates: frac or cart - :param sym: False:give all atoms out; True:with symmetry - :param name: file name - :return: cif_out - cif file in list format should be print using the following function: - target=open("D:\\zcx.cif",'w') - target.writelines(cif_out) - target.close() - """ - - if crystal.system_name!="unknown": - name = crystal.system_name - cif_file = [] - - cif_file.append("data_"+str(name)+"\n") - if sym==False: - if crystal.space_group==1: - crystal_temp = crystal - else: - crystal_temp = copy.deepcopy(crystal) - crystal_temp.make_p1() - cif_file.append("_symmetry_space_group_name_H-M \'P1\'"+"\n") - cif_file.append("_symmetry_Int_Tables_number 1"+"\n") - - cif_file.append("loop_"+"\n") - cif_file.append("_symmetry_equiv_pos_site_id"+"\n") - cif_file.append("_symmetry_equiv_pos_as_xyz"+"\n") - cif_file.append("1 x,y,z"+"\n") - cif_file.append("_cell_length_a "+str(crystal_temp.cell_para[0][0])+"\n") - cif_file.append("_cell_length_b "+str(crystal_temp.cell_para[0][1])+"\n") - cif_file.append("_cell_length_c "+str(crystal_temp.cell_para[0][2])+"\n") - cif_file.append("_cell_angle_alpha "+str(crystal_temp.cell_para[1][0])+"\n") - cif_file.append("_cell_angle_beta "+str(crystal_temp.cell_para[1][1])+"\n") - cif_file.append("_cell_angle_gamma "+str(crystal_temp.cell_para[1][2])+"\n") - - cif_file.append("loop_"+"\n") - cif_file.append("_atom_site_label"+"\n") - cif_file.append("_atom_site_type_symbol"+"\n") - cif_file.append("_atom_site_fract_x"+"\n") - cif_file.append("_atom_site_fract_y"+"\n") - cif_file.append("_atom_site_fract_z"+"\n") - for i in range(0,len(crystal_temp.atoms)): - cif_file.append("{:6} {:4} {:16.8f} {:16.8f} {:16.8f}\n" - .format(i+1,crystal_temp.atoms[i].element,crystal_temp.atoms[i].frac_xyz[0], - crystal_temp.atoms[i].frac_xyz[1],crystal_temp.atoms[i].frac_xyz[2])) - - return cif_file - - elif sym==True: - cif_file.append("_symmetry_space_group_name_H-M \'{}\'".format(chemical_knowledge.space_group[crystal.space_group][1])+"\n") - cif_file.append("_symmetry_Int_Tables_number {}".format(crystal.space_group)+"\n") - - cif_file.append("loop_"+"\n") - cif_file.append("_symmetry_equiv_pos_site_id"+"\n") - cif_file.append("_symmetry_equiv_pos_as_xyz"+"\n") - for idx, SYMM in enumerate(crystal.SYMM): - cif_file.append("{} {}".format(idx+1,SYMM)+"\n") - cif_file.append("_cell_length_a "+str(crystal.cell_para[0][0])+"\n") - cif_file.append("_cell_length_b "+str(crystal.cell_para[0][1])+"\n") - cif_file.append("_cell_length_c "+str(crystal.cell_para[0][2])+"\n") - cif_file.append("_cell_angle_alpha "+str(crystal.cell_para[1][0])+"\n") - cif_file.append("_cell_angle_beta "+str(crystal.cell_para[1][1])+"\n") - cif_file.append("_cell_angle_gamma "+str(crystal.cell_para[1][2])+"\n") - - cif_file.append("loop_"+"\n") - cif_file.append("_atom_site_label"+"\n") - cif_file.append("_atom_site_type_symbol"+"\n") - cif_file.append("_atom_site_fract_x"+"\n") - cif_file.append("_atom_site_fract_y"+"\n") - cif_file.append("_atom_site_fract_z"+"\n") - for i in range(0,len(crystal.atoms)): - cif_file.append("{:6} {:4} {:16.8f} {:16.8f} {:16.8f}\n" - .format(i+1,crystal.atoms[i].element,crystal.atoms[i].frac_xyz[0], - crystal.atoms[i].frac_xyz[1],crystal.atoms[i].frac_xyz[2])) - - return cif_file - - -def write_cifs_file(crystals, sym=False, name="zcx"): - cifs_file = [] - for crystal in crystals: - single_cif = write_cif_file(crystal,sym=sym, name=name) - cifs_file.extend(single_cif) - return cifs_file - - -def read_cif_file(file_path,on_sym_check=False,shut_up=False,system_name="unknown",comment_name="unknown"): - input_file = open(file_path, 'r') - lines = input_file.readlines() - step_pickle = [] - crystal_all = [] - if system_name=="unknown": - no_name = True - else: - no_name = False - # first time scan - for index,line in enumerate(lines): - # find out all the step pickle - if line.startswith("data_"): - step_pickle.append(index) - step_pickle.append(len(lines)) - - # treat every step and return a crystal - for m in range(0,len(step_pickle)-1): - atoms = [] - atoms_P1 = [] - SYMM = [] - cell_para = [["unknown","unknown","unknown"],["unknown","unknown","unknown"]] - - for index, line in enumerate(lines[step_pickle[m]:step_pickle[m+1]]): - split_line = list(filter(lambda x: x != '', re.split("\\s+", line))) - if line.startswith("#"): - pass - elif len(split_line)==0: - pass - # read the loop of symmetry - # elif split_line[0]=="loop_" and lines[step_pickle[m]+index+1]=="_symmetry_equiv_pos_as_xyz\n": - elif split_line[0] == "loop_" and lines[step_pickle[m] + index + 1] == "_symmetry_equiv_pos_as_xyz\n": - temp_number = 1 - while "_" not in lines[step_pickle[m]+index+1+temp_number]: - split_line_temp = list(filter(lambda x: x != '', re.split("\\s+", lines[step_pickle[m]+index+1+temp_number]))) - temp_number+=1 - if not operation.is_number(split_line_temp[0]): - SYMM.append(split_line_temp[0]) - else: - SYMM.append(split_line_temp[1]) - elif split_line[0] == "loop_" and lines[step_pickle[m]+index+2].strip(" ")=="_symmetry_equiv_pos_as_xyz\n": - temp_number = 1 - while "_" not in lines[step_pickle[m]+index+2+temp_number]: - split_line_temp = list(filter(lambda x: x != '', re.split("\\s+", lines[step_pickle[m]+index+2+temp_number]))) - temp_number+=1 - SYMM.append("".join(split_line_temp[1:])) - elif split_line[0] == "loop_" and "_space_group_symop_operation_xyz\n" in lines[step_pickle[m] + index + 1]: - # ase format - temp_number = 1 - while "_" not in lines[step_pickle[m]+index+1+temp_number]: - if lines[step_pickle[m]+index+1+temp_number]=="\n": - temp_number += 1 - continue - split_line_temp = list(filter(lambda x: x != '', re.split("\\s+", lines[step_pickle[m]+index+1+temp_number]))) - temp_number+=1 - if not operation.is_number(split_line_temp[0]): - SYMM.append("".join(split_line_temp)) - - # read the loop of atoms: - elif (split_line[0] == "loop_" and lines[step_pickle[m] + index + 1].strip(" ") == "_atom_site_label\n") or \ - (split_line[0] == "loop_" and lines[step_pickle[m] + index + 2].strip(" ") == "_atom_site_label\n"): - temp_number = 0 - while "_" in lines[step_pickle[m]+index+1+temp_number]: - if lines[step_pickle[m] + index + 1 + temp_number].strip(" ") == "_atom_site_type_symbol\n": - ele_pos = temp_number - elif lines[step_pickle[m] + index + 1 + temp_number].strip(" ") == "_atom_site_fract_x\n": - x_pos = temp_number - elif lines[step_pickle[m] + index + 1 + temp_number].strip(" ") == "_atom_site_fract_y\n": - y_pos = temp_number - elif lines[step_pickle[m] + index + 1 + temp_number].strip(" ") == "_atom_site_fract_z\n": - z_pos = temp_number - temp_number+=1 - how_long = temp_number - - while len(list(filter(lambda x: x != '', re.split("\\s+", lines[step_pickle[m]+index+1+temp_number])))) == how_long: - split_line_temp = list(filter(lambda x: x != '', re.split("\\s+", lines[step_pickle[m]+index+1+temp_number]))) - atoms.append(data_classes.Atom(element=split_line_temp[ele_pos], - frac_xyz=[float(split_line_temp[x_pos]),float(split_line_temp[y_pos]), - float(split_line_temp[z_pos])])) - temp_number += 1 - if step_pickle[m]+index+1+temp_number==len(lines): - break - - elif split_line[0] == "_cell_length_a": - cell_para[0][0] = float(split_line[1]) - elif split_line[0] == "_cell_length_b": - cell_para[0][1] = float(split_line[1]) - elif split_line[0] == "_cell_length_c": - cell_para[0][2] = float(split_line[1]) - elif split_line[0] == "_cell_angle_alpha": - cell_para[1][0] = float(split_line[1]) - elif split_line[0] == "_cell_angle_beta": - cell_para[1][1] = float(split_line[1]) - elif split_line[0] == "_cell_angle_gamma": - cell_para[1][2] = float(split_line[1]) - elif "data_" in line: - if no_name == True: - system_name = line[5:] - system_name = system_name.replace(" ","_") - system_name = system_name.replace("\\", "_") - system_name = system_name.replace("\n", "") - for atom in atoms: - all_reflect_position = operation.space_group_transfer_for_single_atom(atom.frac_xyz, SYMM) - for new_position in all_reflect_position: - atoms_P1.append(data_classes.Atom(element=atom.element, - frac_xyz=[new_position[0], new_position[1], new_position[2]])) - crystal_all.append(data_classes.Crystal(cell_para=cell_para, atoms=atoms_P1, comment=comment_name, system_name=system_name)) - if on_sym_check == True: - raise Exception("Not finished part, TODO in code") - if shut_up==False: - if m%100 == 0: - print("{} structures have been treated".format(m)) - - return crystal_all - - -def write_poscar_file(crystal, coordinates = 'frac', name = "parser_zcx_create"): - - vasp_file = [] - - vasp_file.append('{}\n'.format(name)) - vasp_file.append('1.0\n') - cell_vect = crystal.cell_vect - for vect in cell_vect: - vasp_file.append("{:16.8f} {:16.8f} {:16.8f}\n".format(vect[0],vect[1],vect[2])) - crystal.sort_by_element() - vasp_file.append("".join("{:>6s}".format(x) for x in crystal.get_element()) + "\n") - vasp_file.append("".join("{:>6.0f}".format(x) for x in crystal.get_element_amount()) + "\n") - if coordinates == 'frac': - vasp_file.append('Direct\n') - for ELEMENT in crystal.get_element(): - for ATOM in crystal.atoms: - if ATOM.element == ELEMENT: - vasp_file.append( - "{:16.8f} {:16.8f} {:16.8f}\n".format(ATOM.frac_xyz[0], ATOM.frac_xyz[1], ATOM.frac_xyz[2])) - elif coordinates == 'cart': - vasp_file.append('Cartesian\n') - for ELEMENT in crystal.get_element(): - for ATOM in crystal.atoms: - if ATOM.element == ELEMENT: - vasp_file.append( - "{:16.8f} {:16.8f} {:16.8f}\n".format(ATOM.cart_xyz[0], ATOM.cart_xyz[1], ATOM.cart_xyz[2])) - else: - raise Exception("Wrong coordinates type: {}".format(coordinates)) - - return vasp_file - - -def read_ase_pbc_file(file_path,shut_up=False): - input_file = open(file_path, 'r') - lines = input_file.readlines()[2:] - step_pickle = [] - crystal_all = [] - - # first time scan - for index,line in enumerate(lines): - # find out all the step pickle - if line.startswith("Step "): - step_pickle.append(index) - step_pickle.append(len(lines)) - - # treat every step and return a crystal - for m in range(0,len(step_pickle)-1): - atoms_P1 = [] - force_matrix = [] - position_matrix = [] - in_forces = False - in_positions = False - - - for index, line in enumerate(lines[step_pickle[m]:step_pickle[m+1]]): - # split_line = list(filter(lambda x: x != '', re.split("\\s+", line))) - line = line.strip() - # check Forces part - if line.startswith("Forces:"): - in_forces = True - in_positions = False - continue - - # check Positions part - if line.startswith("Positions:"): - in_positions = True - in_forces = False - continue - - if in_forces and line.startswith("[") and line.endswith("]"): - line = line.replace("[", "").replace("]", "") - force_matrix.append([float(x) for x in line.split()]) - - # analyse Positions part - if in_positions and line.startswith("[") and line.endswith("]"): - line = line.replace("[", "").replace("]", "") - position_matrix.append([float(x) for x in line.split()]) - - if line.startswith("Elements:"): - elements_string = line.strip().split(":", 1)[-1].strip() - elements_string = elements_string[1:-1] - - elements = [elem.strip().strip("'") for elem in elements_string.split(",")] - - if line.startswith("cell:"): - matrix_string = line[len("cell: Cell("):-1] - rows = matrix_string.split("], [") - cell_vect = [ - [float(value) for value in row.replace('[', '').replace(']', '').replace(')', '').split(", ")] - for row in rows - ] - - for i in range(0,len(elements)): - atoms_P1.append(data_classes.Atom(element=elements[i], cart_xyz=[position_matrix[i][0], position_matrix[i][1], position_matrix[i][2]])) - - crystal_all.append(data_classes.Crystal(cell_vect=cell_vect, atoms=atoms_P1)) - - return crystal_all +import re +from basic_function import operation +from basic_function import data_classes +from basic_function import chemical_knowledge +import copy + + +def read_xyz_file(file_path): + input_file = open(file_path, 'r') + lines = input_file.readlines() + number_of_atoms = int(lines[0]) + name = str(lines[1][:-1]) + atoms = [] + for index,line in enumerate(lines): + split_line = list(filter(lambda x: x != '', re.split("\\s+", line))) + if len(split_line)==4 and operation.is_number(split_line[1]) and \ + operation.is_number(split_line[2]) and operation.is_number(split_line[3]): + atoms.append(data_classes.Atom(element=split_line[0], + cart_xyz=[float(split_line[1]), float(split_line[2]), float(split_line[3])], + atom_id=index-2)) + + if number_of_atoms!=len(atoms): + print("Warning! The length of atoms don't match the number of atoms given") + + molecule = data_classes.Molecule(atoms=atoms, name=name, system_name=name) + + return molecule + +def write_cif_file(crystal, sym=False, name="zcx"): + """ + Accept crystal class, give the cif file out + :param crystal: crystal class + :param coordinates: frac or cart + :param sym: False:give all atoms out; True:with symmetry + :param name: file name + :return: cif_out + cif file in list format should be print using the following function: + target=open("D:\\zcx.cif",'w') + target.writelines(cif_out) + target.close() + """ + + if crystal.system_name!="unknown": + name = crystal.system_name + cif_file = [] + + cif_file.append("data_"+str(name)+"\n") + if sym==False: + if crystal.space_group==1: + crystal_temp = crystal + else: + crystal_temp = copy.deepcopy(crystal) + crystal_temp.make_p1() + cif_file.append("_symmetry_space_group_name_H-M \'P1\'"+"\n") + cif_file.append("_symmetry_Int_Tables_number 1"+"\n") + + cif_file.append("loop_"+"\n") + cif_file.append("_symmetry_equiv_pos_site_id"+"\n") + cif_file.append("_symmetry_equiv_pos_as_xyz"+"\n") + cif_file.append("1 x,y,z"+"\n") + cif_file.append("_cell_length_a "+str(crystal_temp.cell_para[0][0])+"\n") + cif_file.append("_cell_length_b "+str(crystal_temp.cell_para[0][1])+"\n") + cif_file.append("_cell_length_c "+str(crystal_temp.cell_para[0][2])+"\n") + cif_file.append("_cell_angle_alpha "+str(crystal_temp.cell_para[1][0])+"\n") + cif_file.append("_cell_angle_beta "+str(crystal_temp.cell_para[1][1])+"\n") + cif_file.append("_cell_angle_gamma "+str(crystal_temp.cell_para[1][2])+"\n") + + cif_file.append("loop_"+"\n") + cif_file.append("_atom_site_label"+"\n") + cif_file.append("_atom_site_type_symbol"+"\n") + cif_file.append("_atom_site_fract_x"+"\n") + cif_file.append("_atom_site_fract_y"+"\n") + cif_file.append("_atom_site_fract_z"+"\n") + for i in range(0,len(crystal_temp.atoms)): + cif_file.append("{:6} {:4} {:16.8f} {:16.8f} {:16.8f}\n" + .format(i+1,crystal_temp.atoms[i].element,crystal_temp.atoms[i].frac_xyz[0], + crystal_temp.atoms[i].frac_xyz[1],crystal_temp.atoms[i].frac_xyz[2])) + + return cif_file + + elif sym==True: + cif_file.append("_symmetry_space_group_name_H-M \'{}\'".format(chemical_knowledge.space_group[crystal.space_group][1])+"\n") + cif_file.append("_symmetry_Int_Tables_number {}".format(crystal.space_group)+"\n") + + cif_file.append("loop_"+"\n") + cif_file.append("_symmetry_equiv_pos_site_id"+"\n") + cif_file.append("_symmetry_equiv_pos_as_xyz"+"\n") + for idx, SYMM in enumerate(crystal.SYMM): + cif_file.append("{} {}".format(idx+1,SYMM)+"\n") + cif_file.append("_cell_length_a "+str(crystal.cell_para[0][0])+"\n") + cif_file.append("_cell_length_b "+str(crystal.cell_para[0][1])+"\n") + cif_file.append("_cell_length_c "+str(crystal.cell_para[0][2])+"\n") + cif_file.append("_cell_angle_alpha "+str(crystal.cell_para[1][0])+"\n") + cif_file.append("_cell_angle_beta "+str(crystal.cell_para[1][1])+"\n") + cif_file.append("_cell_angle_gamma "+str(crystal.cell_para[1][2])+"\n") + + cif_file.append("loop_"+"\n") + cif_file.append("_atom_site_label"+"\n") + cif_file.append("_atom_site_type_symbol"+"\n") + cif_file.append("_atom_site_fract_x"+"\n") + cif_file.append("_atom_site_fract_y"+"\n") + cif_file.append("_atom_site_fract_z"+"\n") + for i in range(0,len(crystal.atoms)): + cif_file.append("{:6} {:4} {:16.8f} {:16.8f} {:16.8f}\n" + .format(i+1,crystal.atoms[i].element,crystal.atoms[i].frac_xyz[0], + crystal.atoms[i].frac_xyz[1],crystal.atoms[i].frac_xyz[2])) + + return cif_file + + +def write_cifs_file(crystals, sym=False, name="zcx"): + cifs_file = [] + for crystal in crystals: + single_cif = write_cif_file(crystal,sym=sym, name=name) + cifs_file.extend(single_cif) + return cifs_file + + +def read_cif_file(file_path,on_sym_check=False,shut_up=False,system_name="unknown",comment_name="unknown"): + input_file = open(file_path, 'r') + lines = input_file.readlines() + step_pickle = [] + crystal_all = [] + if system_name=="unknown": + no_name = True + else: + no_name = False + # first time scan + for index,line in enumerate(lines): + # find out all the step pickle + if line.startswith("data_"): + step_pickle.append(index) + step_pickle.append(len(lines)) + + # treat every step and return a crystal + for m in range(0,len(step_pickle)-1): + atoms = [] + atoms_P1 = [] + SYMM = [] + cell_para = [["unknown","unknown","unknown"],["unknown","unknown","unknown"]] + + for index, line in enumerate(lines[step_pickle[m]:step_pickle[m+1]]): + split_line = list(filter(lambda x: x != '', re.split("\\s+", line))) + if line.startswith("#"): + pass + elif len(split_line)==0: + pass + # read the loop of symmetry + # elif split_line[0]=="loop_" and lines[step_pickle[m]+index+1]=="_symmetry_equiv_pos_as_xyz\n": + elif split_line[0] == "loop_" and lines[step_pickle[m] + index + 1] == "_symmetry_equiv_pos_as_xyz\n": + temp_number = 1 + while "_" not in lines[step_pickle[m]+index+1+temp_number]: + split_line_temp = list(filter(lambda x: x != '', re.split("\\s+", lines[step_pickle[m]+index+1+temp_number]))) + temp_number+=1 + if not operation.is_number(split_line_temp[0]): + SYMM.append(split_line_temp[0]) + else: + SYMM.append(split_line_temp[1]) + elif split_line[0] == "loop_" and lines[step_pickle[m]+index+2].strip(" ")=="_symmetry_equiv_pos_as_xyz\n": + temp_number = 1 + while "_" not in lines[step_pickle[m]+index+2+temp_number]: + split_line_temp = list(filter(lambda x: x != '', re.split("\\s+", lines[step_pickle[m]+index+2+temp_number]))) + temp_number+=1 + SYMM.append("".join(split_line_temp[1:])) + elif split_line[0] == "loop_" and "_space_group_symop_operation_xyz\n" in lines[step_pickle[m] + index + 1]: + # ase format + temp_number = 1 + while "_" not in lines[step_pickle[m]+index+1+temp_number]: + if lines[step_pickle[m]+index+1+temp_number]=="\n": + temp_number += 1 + continue + split_line_temp = list(filter(lambda x: x != '', re.split("\\s+", lines[step_pickle[m]+index+1+temp_number]))) + temp_number+=1 + if not operation.is_number(split_line_temp[0]): + SYMM.append("".join(split_line_temp)) + + # read the loop of atoms: + elif (split_line[0] == "loop_" and lines[step_pickle[m] + index + 1].strip(" ") == "_atom_site_label\n") or \ + (split_line[0] == "loop_" and lines[step_pickle[m] + index + 2].strip(" ") == "_atom_site_label\n"): + temp_number = 0 + while "_" in lines[step_pickle[m]+index+1+temp_number]: + if lines[step_pickle[m] + index + 1 + temp_number].strip(" ") == "_atom_site_type_symbol\n": + ele_pos = temp_number + elif lines[step_pickle[m] + index + 1 + temp_number].strip(" ") == "_atom_site_fract_x\n": + x_pos = temp_number + elif lines[step_pickle[m] + index + 1 + temp_number].strip(" ") == "_atom_site_fract_y\n": + y_pos = temp_number + elif lines[step_pickle[m] + index + 1 + temp_number].strip(" ") == "_atom_site_fract_z\n": + z_pos = temp_number + temp_number+=1 + how_long = temp_number + + while len(list(filter(lambda x: x != '', re.split("\\s+", lines[step_pickle[m]+index+1+temp_number])))) == how_long: + split_line_temp = list(filter(lambda x: x != '', re.split("\\s+", lines[step_pickle[m]+index+1+temp_number]))) + atoms.append(data_classes.Atom(element=split_line_temp[ele_pos], + frac_xyz=[float(split_line_temp[x_pos]),float(split_line_temp[y_pos]), + float(split_line_temp[z_pos])])) + temp_number += 1 + if step_pickle[m]+index+1+temp_number==len(lines): + break + + elif split_line[0] == "_cell_length_a": + cell_para[0][0] = float(split_line[1]) + elif split_line[0] == "_cell_length_b": + cell_para[0][1] = float(split_line[1]) + elif split_line[0] == "_cell_length_c": + cell_para[0][2] = float(split_line[1]) + elif split_line[0] == "_cell_angle_alpha": + cell_para[1][0] = float(split_line[1]) + elif split_line[0] == "_cell_angle_beta": + cell_para[1][1] = float(split_line[1]) + elif split_line[0] == "_cell_angle_gamma": + cell_para[1][2] = float(split_line[1]) + elif "data_" in line: + if no_name == True: + system_name = line[5:] + system_name = system_name.replace(" ","_") + system_name = system_name.replace("\\", "_") + system_name = system_name.replace("\n", "") + for atom in atoms: + all_reflect_position = operation.space_group_transfer_for_single_atom(atom.frac_xyz, SYMM) + for new_position in all_reflect_position: + atoms_P1.append(data_classes.Atom(element=atom.element, + frac_xyz=[new_position[0], new_position[1], new_position[2]])) + crystal_all.append(data_classes.Crystal(cell_para=cell_para, atoms=atoms_P1, comment=comment_name, system_name=system_name)) + if on_sym_check == True: + raise Exception("Not finished part, TODO in code") + if shut_up==False: + if m%100 == 0: + print("{} structures have been treated".format(m)) + + return crystal_all + + +def write_poscar_file(crystal, coordinates = 'frac', name = "parser_zcx_create"): + + vasp_file = [] + + vasp_file.append('{}\n'.format(name)) + vasp_file.append('1.0\n') + cell_vect = crystal.cell_vect + for vect in cell_vect: + vasp_file.append("{:16.8f} {:16.8f} {:16.8f}\n".format(vect[0],vect[1],vect[2])) + crystal.sort_by_element() + vasp_file.append("".join("{:>6s}".format(x) for x in crystal.get_element()) + "\n") + vasp_file.append("".join("{:>6.0f}".format(x) for x in crystal.get_element_amount()) + "\n") + if coordinates == 'frac': + vasp_file.append('Direct\n') + for ELEMENT in crystal.get_element(): + for ATOM in crystal.atoms: + if ATOM.element == ELEMENT: + vasp_file.append( + "{:16.8f} {:16.8f} {:16.8f}\n".format(ATOM.frac_xyz[0], ATOM.frac_xyz[1], ATOM.frac_xyz[2])) + elif coordinates == 'cart': + vasp_file.append('Cartesian\n') + for ELEMENT in crystal.get_element(): + for ATOM in crystal.atoms: + if ATOM.element == ELEMENT: + vasp_file.append( + "{:16.8f} {:16.8f} {:16.8f}\n".format(ATOM.cart_xyz[0], ATOM.cart_xyz[1], ATOM.cart_xyz[2])) + else: + raise Exception("Wrong coordinates type: {}".format(coordinates)) + + return vasp_file + + +def read_ase_pbc_file(file_path,shut_up=False): + input_file = open(file_path, 'r') + lines = input_file.readlines()[2:] + step_pickle = [] + crystal_all = [] + + # first time scan + for index,line in enumerate(lines): + # find out all the step pickle + if line.startswith("Step "): + step_pickle.append(index) + step_pickle.append(len(lines)) + + # treat every step and return a crystal + for m in range(0,len(step_pickle)-1): + atoms_P1 = [] + force_matrix = [] + position_matrix = [] + in_forces = False + in_positions = False + + + for index, line in enumerate(lines[step_pickle[m]:step_pickle[m+1]]): + # split_line = list(filter(lambda x: x != '', re.split("\\s+", line))) + line = line.strip() + # check Forces part + if line.startswith("Forces:"): + in_forces = True + in_positions = False + continue + + # check Positions part + if line.startswith("Positions:"): + in_positions = True + in_forces = False + continue + + if in_forces and line.startswith("[") and line.endswith("]"): + line = line.replace("[", "").replace("]", "") + force_matrix.append([float(x) for x in line.split()]) + + # analyse Positions part + if in_positions and line.startswith("[") and line.endswith("]"): + line = line.replace("[", "").replace("]", "") + position_matrix.append([float(x) for x in line.split()]) + + if line.startswith("Elements:"): + elements_string = line.strip().split(":", 1)[-1].strip() + elements_string = elements_string[1:-1] + + elements = [elem.strip().strip("'") for elem in elements_string.split(",")] + + if line.startswith("cell:"): + matrix_string = line[len("cell: Cell("):-1] + rows = matrix_string.split("], [") + cell_vect = [ + [float(value) for value in row.replace('[', '').replace(']', '').replace(')', '').split(", ")] + for row in rows + ] + + for i in range(0,len(elements)): + atoms_P1.append(data_classes.Atom(element=elements[i], cart_xyz=[position_matrix[i][0], position_matrix[i][1], position_matrix[i][2]])) + + crystal_all.append(data_classes.Crystal(cell_vect=cell_vect, atoms=atoms_P1)) + + return crystal_all \ No newline at end of file diff --git a/basic_function/operation.py b/basic_function/operation.py index 1e157fb..a2ccb8f 100644 --- a/basic_function/operation.py +++ b/basic_function/operation.py @@ -1,469 +1,469 @@ -# -*- coding: utf-8 -*- -""" -A collection of functions for performing crystallographic and molecular operations, -such as symmetry application, supercell generation, and geometric analysis. -""" - -# --- Standard Library Imports --- -import copy -import fractions -import re -from typing import Any, Dict, List, Optional, Tuple, Union - -# --- Third-Party Imports --- -import networkx as nx -import numpy as np -import numpy.typing as npt -from scipy.spatial import cKDTree as KDTree - -# --- Local Application Imports --- -from basic_function import chemical_knowledge, data_classes - -# Type aliases for clarity -NDArrayFloat = npt.NDArray[np.float64] -CellVectors = List[List[float]] -SymmetryOperations = List[str] - - -def is_number(s: str) -> bool: - """Checks if a string can be interpreted as a number (float or fraction). - - Args: - s: The input string. - - Returns: - True if the string represents a number, False otherwise. - """ - try: - float(s) - return True - except ValueError: - pass - - try: - # Check for fractional representations like "1/2" - float(fractions.Fraction(s)) - return True - except ValueError: - return False - - -def _parse_symmetry_operations( - sym_ops: SymmetryOperations, -) -> Tuple[List[NDArrayFloat], List[NDArrayFloat]]: - """Parses a list of symmetry operation strings into matrices. - - This is an internal helper function to avoid code duplication in public functions. - - Args: - sym_ops: A list of symmetry operation strings (e.g., ['x, y, z+1/2']). - - Returns: - A tuple containing two lists: - - A list of 3x3 rotation/reflection matrices (M). - - A list of 1x3 translation vectors (C). - - Raises: - ValueError: If a symmetry operation string is malformed. - """ - rotation_matrices = [] - translation_vectors = [] - - for sym_op_str in sym_ops: - sym_op_parts = sym_op_str.lower().replace(" ", "").split(",") - if len(sym_op_parts) != 3: - raise ValueError(f"Symmetry operation '{sym_op_str}' is invalid.") - - matrix_m = np.zeros((3, 3)) - matrix_c = np.zeros((1, 3)) - - for i, part in enumerate(sym_op_parts): - # Regex to find elements like '+x', '-y', 'z', '1/2', '-0.5' - tokens = re.findall(r"([+-]?[xyz0-9./]+)", part) - for token in tokens: - token = token.strip() - if not token: - continue - - if "x" in token: - matrix_m[0, i] = -1.0 if token.startswith("-") else 1.0 - elif "y" in token: - matrix_m[1, i] = -1.0 if token.startswith("-") else 1.0 - elif "z" in token: - matrix_m[2, i] = -1.0 if token.startswith("-") else 1.0 - elif is_number(token): - matrix_c[0, i] += float(fractions.Fraction(token)) - else: - raise ValueError(f"Invalid fragment '{token}' in symmetry operation.") - - rotation_matrices.append(matrix_m) - translation_vectors.append(matrix_c) - - return rotation_matrices, translation_vectors - - -def space_group_transfer_for_single_atom( - frac_xyz: List[float], space_group_ops: SymmetryOperations -) -> List[List[float]]: - """Applies space group symmetry operations to a single atomic coordinate. - - Args: - frac_xyz: The fractional coordinates [x, y, z] of a single atom. - space_group_ops: A list of space group symmetry operation strings. - - Returns: - A list of all symmetrically equivalent fractional coordinates. - """ - rot_matrices, trans_vectors = _parse_symmetry_operations(space_group_ops) - - equivalent_positions = [] - atom_pos = np.array(frac_xyz) - - for rot, trans in zip(rot_matrices, trans_vectors): - new_pos = np.dot(atom_pos, rot.T) + trans.squeeze() - equivalent_positions.append(new_pos.tolist()) - - return equivalent_positions - - -def super_cell( - crystal: "data_classes.Crystal", - cell_range: Optional[List[List[int]]] = None, -) -> "data_classes.Crystal": - """Constructs a supercell from a unit cell. - - Args: - crystal: The input Crystal object. - cell_range: A list of ranges for each lattice vector, e.g., - [[-1, 1], [-1, 1], [-1, 1]] creates a 3x3x3 supercell. - If None, defaults to [[-1, 1], [-1, 1], [-1, 1]]. - - Returns: - A new Crystal object representing the supercell. - """ - if cell_range is None: - cell_range = [[-1, 1], [-1, 1], [-1, 1]] - - dims = [r[1] - r[0] + 1 for r in cell_range] - - new_lattice = [ - [dim * val for val in crystal.cell_vect[i]] for i, dim in enumerate(dims) - ] - - translation_vectors = [] - for h in range(cell_range[0][0], cell_range[0][1] + 1): - for k in range(cell_range[1][0], cell_range[1][1] + 1): - for l in range(cell_range[2][0], cell_range[2][1] + 1): - translation_vectors.append([h, k, l]) - - new_atoms = [] - for atom in crystal.atoms: - for trans_vec in translation_vectors: - new_frac_xyz = [ - (atom.frac_xyz[i] + trans_vec[i]) / dims[i] for i in range(3) - ] - new_atoms.append( - data_classes.Atom(element=atom.element, frac_xyz=new_frac_xyz) - ) - - if crystal.energy != "unknown": - total_cells = dims[0] * dims[1] * dims[2] - new_energy = crystal.energy * total_cells - else: - new_energy = "unknown" - - return data_classes.Crystal( - cell_vect=new_lattice, energy=new_energy, atoms=new_atoms - ) - - -def orient_molecule(molecule: "data_classes.Molecule") -> "data_classes.Molecule": - """Orients a molecule along its principal axes of inertia. - - The method uses the Moment of Inertia tensor to define a canonical orientation. - The molecule's coordinates are modified in-place. For more details, see: - http://sobereva.com/426 - - Args: - molecule: The Molecule object to be oriented. - - Returns: - The same Molecule object with its atoms reoriented. - """ - all_ele, all_cart = molecule.get_ele_and_cart() - - if len(all_cart) <= 1: - return molecule # No orientation needed for single atoms or empty molecules. - - masses = np.array([chemical_knowledge.element_masses[el] for el in all_ele]) - relative_position = all_cart - molecule.get_center_of_mass() - - # Calculate the moment of inertia tensor - I_xx = np.sum(masses * (relative_position[:, 1] ** 2 + relative_position[:, 2] ** 2)) - I_yy = np.sum(masses * (relative_position[:, 0] ** 2 + relative_position[:, 2] ** 2)) - I_zz = np.sum(masses * (relative_position[:, 0] ** 2 + relative_position[:, 1] ** 2)) - I_xy = -np.sum(masses * relative_position[:, 0] * relative_position[:, 1]) - I_xz = -np.sum(masses * relative_position[:, 0] * relative_position[:, 2]) - I_yz = -np.sum(masses * relative_position[:, 1] * relative_position[:, 2]) - - I_matrix = np.array([[I_xx, I_xy, I_xz], [I_xy, I_yy, I_yz], [I_xz, I_yz, I_zz]]) - - # Eigenvectors of the inertia tensor are the principal axes. - # np.linalg.eigh is used for symmetric matrices. - eigenvalues, eigenvectors = np.linalg.eigh(I_matrix) - principal_axes = eigenvectors.T - - # Project the relative positions onto the new axes system. - new_positions = np.dot(relative_position, principal_axes.T) - - molecule.put_ele_cart_back(all_ele, new_positions) - return molecule - - -def get_rotate_matrix(v: NDArrayFloat) -> NDArrayFloat: - """Generates a 3x3 rotation matrix from a 3D vector `v`. - - This function uses a mapping from a 3D vector to a quaternion, which is then - used to construct the rotation matrix. This method avoids gimbal lock. A - left-handed coordinate system is assumed. - - Args: - v: A 3-element NumPy array used to generate the quaternion. - - Returns: - A 3x3 rotation matrix. - """ - # Ensure v elements are within valid ranges if necessary, though the - # formulas handle most inputs gracefully. - v0_sqrt = np.sqrt(max(v[0], 0)) - v0_1_sqrt = np.sqrt(max(1.0 - v[0], 0)) - - angle1 = 2.0 * np.pi * v[1] - angle2 = 2.0 * np.pi * v[2] - - # Quaternion components (x, y, z, w) - qx = v0_1_sqrt * np.sin(angle1) - qy = v0_1_sqrt * np.cos(angle1) - qz = v0_sqrt * np.sin(angle2) - qw = v0_sqrt * np.cos(angle2) - - return np.array([ - [1 - 2*qy**2 - 2*qz**2, 2*qx*qy + 2*qw*qz, 2*qx*qz - 2*qw*qy], - [2*qx*qy - 2*qw*qz, 1 - 2*qx**2 - 2*qz**2, 2*qy*qz + 2*qw*qx], - [2*qx*qz + 2*qw*qy, 2*qy*qz - 2*qw*qx, 1 - 2*qx**2 - 2*qy**2] - ]) - - -def f2c_matrix( - cell_params: Tuple[List[float], List[float]] -) -> Optional[NDArrayFloat]: - """Calculates the fractional-to-Cartesian transformation matrix. - - Args: - cell_params: A tuple containing [[a, b, c], [alpha, beta, gamma]], - where lengths are in Angstroms and angles are in degrees. - - Returns: - The 3x3 transformation matrix, or None if cell parameters are invalid. - """ - lengths, angles = cell_params - a, b, c = lengths - alpha, beta, gamma = np.deg2rad(angles) - - cos_a, cos_b, cos_g = np.cos([alpha, beta, gamma]) - sin_g = np.sin(gamma) - - # Volume calculation term - volume_term_sq = ( - 1.0 - cos_a**2 - cos_b**2 - cos_g**2 + 2.0 * cos_a * cos_b * cos_g - ) - if volume_term_sq < 0: - return None - - volume = a * b * c * np.sqrt(volume_term_sq) - - matrix = np.zeros((3, 3)) - matrix[0, 0] = a - matrix[0, 1] = b * cos_g - matrix[0, 2] = c * cos_b - matrix[1, 1] = b * sin_g - matrix[1, 2] = c * (cos_a - cos_b * cos_g) / sin_g - matrix[2, 2] = volume / (a * b * sin_g) - - return matrix.T - - -def c2f_matrix( - cell_params: Tuple[List[float], List[float]] -) -> Optional[NDArrayFloat]: - """Calculates the Cartesian-to-fractional transformation matrix. - - This is the inverse of the matrix generated by `f2c_matrix`. - - Args: - cell_params: A tuple containing [[a, b, c], [alpha, beta, gamma]], - where lengths are in Angstroms and angles are in degrees. - - Returns: - The 3x3 transformation matrix, or None if cell parameters are invalid. - """ - f2c = f2c_matrix(cell_params) - if f2c is None: - return None - - try: - return np.linalg.inv(f2c) - except np.linalg.LinAlgError: - return None - - -def apply_SYMM( - frac_xyz: NDArrayFloat, symm_ops: SymmetryOperations -) -> NDArrayFloat: - """Applies symmetry operations to a single set of fractional coordinates. - - Args: - frac_xyz: A NumPy array of fractional coordinates [x, y, z]. - symm_ops: A list of symmetry operation strings. - - Returns: - A NumPy array of all symmetrically equivalent fractional coordinates. - """ - rot_matrices, trans_vectors = _parse_symmetry_operations(symm_ops) - - equivalent_positions = [ - np.dot(frac_xyz, rot.T) + trans.squeeze() - for rot, trans in zip(rot_matrices, trans_vectors) - ] - - return np.array(equivalent_positions) - - -def apply_SYMM_with_element( - elements: Union[str, List[str]], - frac_xyzs: NDArrayFloat, - symm_ops: SymmetryOperations, -) -> Tuple[NDArrayFloat, NDArrayFloat]: - """Applies symmetry operations, returning new elements and coordinates. - - Args: - elements: The element symbol(s) corresponding to the coordinates. - frac_xyzs: A NumPy array of fractional coordinates. - symm_ops: A list of symmetry operation strings. - - Returns: - A tuple containing: - - A NumPy array of element symbols for each new position. - - A NumPy array of all symmetrically equivalent fractional coordinates. - """ - equivalent_positions = apply_SYMM(frac_xyzs, symm_ops) - num_ops = len(equivalent_positions) - - replicated_elements = np.tile(np.array(elements).squeeze(), (num_ops, 1)) - - return replicated_elements, equivalent_positions - - -def calculate_longest_diagonal_length(cell_vect: CellVectors) -> float: - """Calculates the length of the longest space diagonal of a unit cell. - - The longest diagonal connects the origin (0,0,0) to the opposite - corner (1,1,1) of the unit cell. - - Args: - cell_vect: The three lattice vectors of the cell. - - Returns: - The length of the longest diagonal in Angstroms. - """ - cell_vect_np = np.array(cell_vect) - diagonal_vector = np.sum(cell_vect_np, axis=0) - return float(np.linalg.norm(diagonal_vector)) - - -def calculate_distance_of_parallel_plane_in_crystal(cell_vect: CellVectors) -> List[float]: - """Calculates inter-planar distances for primary crystallographic planes. - - This computes the distances for the (100), (010), and (001) families of planes. - - Args: - cell_vect: The three lattice vectors [a, b, c] of the cell. - - Returns: - A list of three distances [d_a, d_b, d_c], where d_a is the distance - between planes parallel to the b-c plane, and so on. - """ - distances = [] - vectors = [np.array(v) for v in cell_vect] - - # Permutations to calculate distance for each primary plane - # (a to b-c plane, b to a-c plane, c to a-b plane) - indices = [(0, 1, 2), (1, 0, 2), (2, 0, 1)] - - for i, j, k in indices: - point_p = vectors[i] - plane_v1 = vectors[j] - plane_v2 = vectors[k] - - # Normal vector to the plane defined by plane_v1 and plane_v2 - normal_vector = np.cross(plane_v1, plane_v2) - - # Distance from point P to the plane is |N · P| / ||N|| - distance = abs(np.dot(normal_vector, point_p)) / np.linalg.norm(normal_vector) - distances.append(distance) - - return distances - - -def detect_is_frame_vdw_new(crystal: "data_classes.Crystal", tolerance: float = 1.2) -> bool: - """Detects if a crystal structure forms a connected framework via VdW radii. - - The method involves: - 1. Expanding the crystal to a P1 symmetry supercell. - 2. Building a 3x3x3 supercell to ensure periodic connections are considered. - 3. Constructing a graph where atoms are nodes and an edge exists if their - distance is within a scaled sum of their van der Waals radii. - 4. Checking if the largest connected component in the graph is large enough - to be considered a single, percolating framework. - - Args: - crystal: The Crystal object to analyze. - tolerance: A tolerance factor to scale the VdW radii sum. - - Returns: - True if the structure is a connected framework, False otherwise. - """ - crystal_temp = copy.deepcopy(crystal) - crystal_temp.make_p1() - crystal_temp.move_atom_into_cell() - - # Create a 3x3x3 supercell to check for connectivity across boundaries - crystal_supercell = super_cell(crystal_temp, cell_range=[[-1, 1], [-1, 1], [-1, 1]]) - - all_ele, all_carts = crystal_supercell.get_ele_and_cart() - - vdw_radii_map = chemical_knowledge.element_vdw_radii - vdw_max = max(vdw_radii_map[el] for el in set(all_ele)) - distance_threshold = vdw_max * tolerance * 2 - - # KDTree for efficient nearest-neighbor search - tree = KDTree(all_carts) - pairs = tree.query_pairs(r=distance_threshold) - - # Build a graph to find connected components - graph = nx.Graph() - graph.add_nodes_from(range(len(all_carts))) - graph.add_edges_from(list(pairs)) - - if not graph.nodes: - return False - - # Find the largest connected component - largest_cc = max(nx.connected_components(graph), key=len) - - # A heuristic to check for a percolating framework. A connected framework - # should connect most atoms. The threshold '9' is empirical but robustly - # distinguishes between isolated molecules and a fully connected lattice. - # In a 3x3x3 supercell (27 unit cells), a connected framework should involve - # significantly more atoms than in a few unit cells. +# -*- coding: utf-8 -*- +""" +A collection of functions for performing crystallographic and molecular operations, +such as symmetry application, supercell generation, and geometric analysis. +""" + +# --- Standard Library Imports --- +import copy +import fractions +import re +from typing import Any, Dict, List, Optional, Tuple, Union + +# --- Third-Party Imports --- +import networkx as nx +import numpy as np +import numpy.typing as npt +from scipy.spatial import cKDTree as KDTree + +# --- Local Application Imports --- +from basic_function import chemical_knowledge, data_classes + +# Type aliases for clarity +NDArrayFloat = npt.NDArray[np.float64] +CellVectors = List[List[float]] +SymmetryOperations = List[str] + + +def is_number(s: str) -> bool: + """Checks if a string can be interpreted as a number (float or fraction). + + Args: + s: The input string. + + Returns: + True if the string represents a number, False otherwise. + """ + try: + float(s) + return True + except ValueError: + pass + + try: + # Check for fractional representations like "1/2" + float(fractions.Fraction(s)) + return True + except ValueError: + return False + + +def _parse_symmetry_operations( + sym_ops: SymmetryOperations, +) -> Tuple[List[NDArrayFloat], List[NDArrayFloat]]: + """Parses a list of symmetry operation strings into matrices. + + This is an internal helper function to avoid code duplication in public functions. + + Args: + sym_ops: A list of symmetry operation strings (e.g., ['x, y, z+1/2']). + + Returns: + A tuple containing two lists: + - A list of 3x3 rotation/reflection matrices (M). + - A list of 1x3 translation vectors (C). + + Raises: + ValueError: If a symmetry operation string is malformed. + """ + rotation_matrices = [] + translation_vectors = [] + + for sym_op_str in sym_ops: + sym_op_parts = sym_op_str.lower().replace(" ", "").split(",") + if len(sym_op_parts) != 3: + raise ValueError(f"Symmetry operation '{sym_op_str}' is invalid.") + + matrix_m = np.zeros((3, 3)) + matrix_c = np.zeros((1, 3)) + + for i, part in enumerate(sym_op_parts): + # Regex to find elements like '+x', '-y', 'z', '1/2', '-0.5' + tokens = re.findall(r"([+-]?[xyz0-9./]+)", part) + for token in tokens: + token = token.strip() + if not token: + continue + + if "x" in token: + matrix_m[0, i] = -1.0 if token.startswith("-") else 1.0 + elif "y" in token: + matrix_m[1, i] = -1.0 if token.startswith("-") else 1.0 + elif "z" in token: + matrix_m[2, i] = -1.0 if token.startswith("-") else 1.0 + elif is_number(token): + matrix_c[0, i] += float(fractions.Fraction(token)) + else: + raise ValueError(f"Invalid fragment '{token}' in symmetry operation.") + + rotation_matrices.append(matrix_m) + translation_vectors.append(matrix_c) + + return rotation_matrices, translation_vectors + + +def space_group_transfer_for_single_atom( + frac_xyz: List[float], space_group_ops: SymmetryOperations +) -> List[List[float]]: + """Applies space group symmetry operations to a single atomic coordinate. + + Args: + frac_xyz: The fractional coordinates [x, y, z] of a single atom. + space_group_ops: A list of space group symmetry operation strings. + + Returns: + A list of all symmetrically equivalent fractional coordinates. + """ + rot_matrices, trans_vectors = _parse_symmetry_operations(space_group_ops) + + equivalent_positions = [] + atom_pos = np.array(frac_xyz) + + for rot, trans in zip(rot_matrices, trans_vectors): + new_pos = np.dot(atom_pos, rot.T) + trans.squeeze() + equivalent_positions.append(new_pos.tolist()) + + return equivalent_positions + + +def super_cell( + crystal: "data_classes.Crystal", + cell_range: Optional[List[List[int]]] = None, +) -> "data_classes.Crystal": + """Constructs a supercell from a unit cell. + + Args: + crystal: The input Crystal object. + cell_range: A list of ranges for each lattice vector, e.g., + [[-1, 1], [-1, 1], [-1, 1]] creates a 3x3x3 supercell. + If None, defaults to [[-1, 1], [-1, 1], [-1, 1]]. + + Returns: + A new Crystal object representing the supercell. + """ + if cell_range is None: + cell_range = [[-1, 1], [-1, 1], [-1, 1]] + + dims = [r[1] - r[0] + 1 for r in cell_range] + + new_lattice = [ + [dim * val for val in crystal.cell_vect[i]] for i, dim in enumerate(dims) + ] + + translation_vectors = [] + for h in range(cell_range[0][0], cell_range[0][1] + 1): + for k in range(cell_range[1][0], cell_range[1][1] + 1): + for l in range(cell_range[2][0], cell_range[2][1] + 1): + translation_vectors.append([h, k, l]) + + new_atoms = [] + for atom in crystal.atoms: + for trans_vec in translation_vectors: + new_frac_xyz = [ + (atom.frac_xyz[i] + trans_vec[i]) / dims[i] for i in range(3) + ] + new_atoms.append( + data_classes.Atom(element=atom.element, frac_xyz=new_frac_xyz) + ) + + if crystal.energy != "unknown": + total_cells = dims[0] * dims[1] * dims[2] + new_energy = crystal.energy * total_cells + else: + new_energy = "unknown" + + return data_classes.Crystal( + cell_vect=new_lattice, energy=new_energy, atoms=new_atoms + ) + + +def orient_molecule(molecule: "data_classes.Molecule") -> "data_classes.Molecule": + """Orients a molecule along its principal axes of inertia. + + The method uses the Moment of Inertia tensor to define a canonical orientation. + The molecule's coordinates are modified in-place. For more details, see: + http://sobereva.com/426 + + Args: + molecule: The Molecule object to be oriented. + + Returns: + The same Molecule object with its atoms reoriented. + """ + all_ele, all_cart = molecule.get_ele_and_cart() + + if len(all_cart) <= 1: + return molecule # No orientation needed for single atoms or empty molecules. + + masses = np.array([chemical_knowledge.element_masses[el] for el in all_ele]) + relative_position = all_cart - molecule.get_center_of_mass() + + # Calculate the moment of inertia tensor + I_xx = np.sum(masses * (relative_position[:, 1] ** 2 + relative_position[:, 2] ** 2)) + I_yy = np.sum(masses * (relative_position[:, 0] ** 2 + relative_position[:, 2] ** 2)) + I_zz = np.sum(masses * (relative_position[:, 0] ** 2 + relative_position[:, 1] ** 2)) + I_xy = -np.sum(masses * relative_position[:, 0] * relative_position[:, 1]) + I_xz = -np.sum(masses * relative_position[:, 0] * relative_position[:, 2]) + I_yz = -np.sum(masses * relative_position[:, 1] * relative_position[:, 2]) + + I_matrix = np.array([[I_xx, I_xy, I_xz], [I_xy, I_yy, I_yz], [I_xz, I_yz, I_zz]]) + + # Eigenvectors of the inertia tensor are the principal axes. + # np.linalg.eigh is used for symmetric matrices. + eigenvalues, eigenvectors = np.linalg.eigh(I_matrix) + principal_axes = eigenvectors.T + + # Project the relative positions onto the new axes system. + new_positions = np.dot(relative_position, principal_axes.T) + + molecule.put_ele_cart_back(all_ele, new_positions) + return molecule + + +def get_rotate_matrix(v: NDArrayFloat) -> NDArrayFloat: + """Generates a 3x3 rotation matrix from a 3D vector `v`. + + This function uses a mapping from a 3D vector to a quaternion, which is then + used to construct the rotation matrix. This method avoids gimbal lock. A + left-handed coordinate system is assumed. + + Args: + v: A 3-element NumPy array used to generate the quaternion. + + Returns: + A 3x3 rotation matrix. + """ + # Ensure v elements are within valid ranges if necessary, though the + # formulas handle most inputs gracefully. + v0_sqrt = np.sqrt(max(v[0], 0)) + v0_1_sqrt = np.sqrt(max(1.0 - v[0], 0)) + + angle1 = 2.0 * np.pi * v[1] + angle2 = 2.0 * np.pi * v[2] + + # Quaternion components (x, y, z, w) + qx = v0_1_sqrt * np.sin(angle1) + qy = v0_1_sqrt * np.cos(angle1) + qz = v0_sqrt * np.sin(angle2) + qw = v0_sqrt * np.cos(angle2) + + return np.array([ + [1 - 2*qy**2 - 2*qz**2, 2*qx*qy + 2*qw*qz, 2*qx*qz - 2*qw*qy], + [2*qx*qy - 2*qw*qz, 1 - 2*qx**2 - 2*qz**2, 2*qy*qz + 2*qw*qx], + [2*qx*qz + 2*qw*qy, 2*qy*qz - 2*qw*qx, 1 - 2*qx**2 - 2*qy**2] + ]) + + +def f2c_matrix( + cell_params: Tuple[List[float], List[float]] +) -> Optional[NDArrayFloat]: + """Calculates the fractional-to-Cartesian transformation matrix. + + Args: + cell_params: A tuple containing [[a, b, c], [alpha, beta, gamma]], + where lengths are in Angstroms and angles are in degrees. + + Returns: + The 3x3 transformation matrix, or None if cell parameters are invalid. + """ + lengths, angles = cell_params + a, b, c = lengths + alpha, beta, gamma = np.deg2rad(angles) + + cos_a, cos_b, cos_g = np.cos([alpha, beta, gamma]) + sin_g = np.sin(gamma) + + # Volume calculation term + volume_term_sq = ( + 1.0 - cos_a**2 - cos_b**2 - cos_g**2 + 2.0 * cos_a * cos_b * cos_g + ) + if volume_term_sq < 0: + return None + + volume = a * b * c * np.sqrt(volume_term_sq) + + matrix = np.zeros((3, 3)) + matrix[0, 0] = a + matrix[0, 1] = b * cos_g + matrix[0, 2] = c * cos_b + matrix[1, 1] = b * sin_g + matrix[1, 2] = c * (cos_a - cos_b * cos_g) / sin_g + matrix[2, 2] = volume / (a * b * sin_g) + + return matrix.T + + +def c2f_matrix( + cell_params: Tuple[List[float], List[float]] +) -> Optional[NDArrayFloat]: + """Calculates the Cartesian-to-fractional transformation matrix. + + This is the inverse of the matrix generated by `f2c_matrix`. + + Args: + cell_params: A tuple containing [[a, b, c], [alpha, beta, gamma]], + where lengths are in Angstroms and angles are in degrees. + + Returns: + The 3x3 transformation matrix, or None if cell parameters are invalid. + """ + f2c = f2c_matrix(cell_params) + if f2c is None: + return None + + try: + return np.linalg.inv(f2c) + except np.linalg.LinAlgError: + return None + + +def apply_SYMM( + frac_xyz: NDArrayFloat, symm_ops: SymmetryOperations +) -> NDArrayFloat: + """Applies symmetry operations to a single set of fractional coordinates. + + Args: + frac_xyz: A NumPy array of fractional coordinates [x, y, z]. + symm_ops: A list of symmetry operation strings. + + Returns: + A NumPy array of all symmetrically equivalent fractional coordinates. + """ + rot_matrices, trans_vectors = _parse_symmetry_operations(symm_ops) + + equivalent_positions = [ + np.dot(frac_xyz, rot.T) + trans.squeeze() + for rot, trans in zip(rot_matrices, trans_vectors) + ] + + return np.array(equivalent_positions) + + +def apply_SYMM_with_element( + elements: Union[str, List[str]], + frac_xyzs: NDArrayFloat, + symm_ops: SymmetryOperations, +) -> Tuple[NDArrayFloat, NDArrayFloat]: + """Applies symmetry operations, returning new elements and coordinates. + + Args: + elements: The element symbol(s) corresponding to the coordinates. + frac_xyzs: A NumPy array of fractional coordinates. + symm_ops: A list of symmetry operation strings. + + Returns: + A tuple containing: + - A NumPy array of element symbols for each new position. + - A NumPy array of all symmetrically equivalent fractional coordinates. + """ + equivalent_positions = apply_SYMM(frac_xyzs, symm_ops) + num_ops = len(equivalent_positions) + + replicated_elements = np.tile(np.array(elements).squeeze(), (num_ops, 1)) + + return replicated_elements, equivalent_positions + + +def calculate_longest_diagonal_length(cell_vect: CellVectors) -> float: + """Calculates the length of the longest space diagonal of a unit cell. + + The longest diagonal connects the origin (0,0,0) to the opposite + corner (1,1,1) of the unit cell. + + Args: + cell_vect: The three lattice vectors of the cell. + + Returns: + The length of the longest diagonal in Angstroms. + """ + cell_vect_np = np.array(cell_vect) + diagonal_vector = np.sum(cell_vect_np, axis=0) + return float(np.linalg.norm(diagonal_vector)) + + +def calculate_distance_of_parallel_plane_in_crystal(cell_vect: CellVectors) -> List[float]: + """Calculates inter-planar distances for primary crystallographic planes. + + This computes the distances for the (100), (010), and (001) families of planes. + + Args: + cell_vect: The three lattice vectors [a, b, c] of the cell. + + Returns: + A list of three distances [d_a, d_b, d_c], where d_a is the distance + between planes parallel to the b-c plane, and so on. + """ + distances = [] + vectors = [np.array(v) for v in cell_vect] + + # Permutations to calculate distance for each primary plane + # (a to b-c plane, b to a-c plane, c to a-b plane) + indices = [(0, 1, 2), (1, 0, 2), (2, 0, 1)] + + for i, j, k in indices: + point_p = vectors[i] + plane_v1 = vectors[j] + plane_v2 = vectors[k] + + # Normal vector to the plane defined by plane_v1 and plane_v2 + normal_vector = np.cross(plane_v1, plane_v2) + + # Distance from point P to the plane is |N · P| / ||N|| + distance = abs(np.dot(normal_vector, point_p)) / np.linalg.norm(normal_vector) + distances.append(distance) + + return distances + + +def detect_is_frame_vdw_new(crystal: "data_classes.Crystal", tolerance: float = 1.2) -> bool: + """Detects if a crystal structure forms a connected framework via VdW radii. + + The method involves: + 1. Expanding the crystal to a P1 symmetry supercell. + 2. Building a 3x3x3 supercell to ensure periodic connections are considered. + 3. Constructing a graph where atoms are nodes and an edge exists if their + distance is within a scaled sum of their van der Waals radii. + 4. Checking if the largest connected component in the graph is large enough + to be considered a single, percolating framework. + + Args: + crystal: The Crystal object to analyze. + tolerance: A tolerance factor to scale the VdW radii sum. + + Returns: + True if the structure is a connected framework, False otherwise. + """ + crystal_temp = copy.deepcopy(crystal) + crystal_temp.make_p1() + crystal_temp.move_atom_into_cell() + + # Create a 3x3x3 supercell to check for connectivity across boundaries + crystal_supercell = super_cell(crystal_temp, cell_range=[[-1, 1], [-1, 1], [-1, 1]]) + + all_ele, all_carts = crystal_supercell.get_ele_and_cart() + + vdw_radii_map = chemical_knowledge.element_vdw_radii + vdw_max = max(vdw_radii_map[el] for el in set(all_ele)) + distance_threshold = vdw_max * tolerance * 2 + + # KDTree for efficient nearest-neighbor search + tree = KDTree(all_carts) + pairs = tree.query_pairs(r=distance_threshold) + + # Build a graph to find connected components + graph = nx.Graph() + graph.add_nodes_from(range(len(all_carts))) + graph.add_edges_from(list(pairs)) + + if not graph.nodes: + return False + + # Find the largest connected component + largest_cc = max(nx.connected_components(graph), key=len) + + # A heuristic to check for a percolating framework. A connected framework + # should connect most atoms. The threshold '9' is empirical but robustly + # distinguishes between isolated molecules and a fully connected lattice. + # In a 3x3x3 supercell (27 unit cells), a connected framework should involve + # significantly more atoms than in a few unit cells. return len(largest_cc) > 9 * len(crystal_temp.atoms) \ No newline at end of file diff --git a/basic_function/packaged_function.py b/basic_function/packaged_function.py index b131115..11632b7 100644 --- a/basic_function/packaged_function.py +++ b/basic_function/packaged_function.py @@ -1,102 +1,102 @@ -from basic_function import format_parser -from basic_function import CSP_generator_normal -import os -import concurrent.futures -import sys - - - -def process_crystal(seed, sg, molecules,output_path,add_name): - aaa = CSP_generator_normal.CrystalGenerator(molecules, space_group=sg) - molecules_number = sum(len(molecule.atoms) for molecule in molecules) - new_crystal = aaa.generate(seed=seed) - sys.stdout.flush() - if new_crystal is not None: - cif_out = format_parser.write_cif_file(new_crystal) - with open(f"{output_path}/structures/{add_name}_{sg}_{seed}_z{len(molecules)}_{molecules_number}.cif", 'w') as target: - target.writelines(cif_out) - return True - return False - -def CSP_generater_parallel(molecules,output_path,need_structure = 100, space_group_list=[1],max_workers=8,add_name='',start_seed=1): - space_groups = space_group_list - accept_count = need_structure - - try: - os.makedirs("{}/structures".format(output_path)) - except: - print("Warning, these is already an structures folder in this path, skip mkdir") - for sg in space_groups: - accept = 0 - seed = start_seed - - with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor: - futures = {} - while accept < accept_count: - # submit new task - while len(futures) < max_workers and accept + len(futures) < accept_count: - future = executor.submit(process_crystal, seed, sg, molecules, output_path,add_name) - futures[future] = seed - seed += 1 - - # check the finished task - done, _ = concurrent.futures.wait(futures, return_when=concurrent.futures.FIRST_COMPLETED) - for future in done: - if future.result(): - accept += 1 - # remove it from list, no matter what result it is - del futures[future] - - # cancel all task if the number need is arrived. - if accept >= accept_count: - for future in futures: - future.cancel() - break - - -def CSP_generater_serial(molecules,output_path,need_structure = 100, densely_pack_method=False, space_group_list=[1]): - """ - :param molecules: a list [molecule1, molecule2, ...] - :param output_path: a str indicate the path of output folder - :param need_structure: int - :param space_group_list:a list indicate the space group need to search - """ - try: - os.makedirs("{}\\structures".format(output_path)) - except: - print("Warning, these is already an structures folder in this path, skip mkdir") - for sg in space_group_list: - aaa = CSP_generator_normal.CrystalGenerator(molecules, space_group=sg) - accept=0 - i=1 - while accept= accept_count: + for future in futures: + future.cancel() + break + + +def CSP_generater_serial(molecules,output_path,need_structure = 100, densely_pack_method=False, space_group_list=[1]): + """ + :param molecules: a list [molecule1, molecule2, ...] + :param output_path: a str indicate the path of output folder + :param need_structure: int + :param space_group_list:a list indicate the space group need to search + """ + try: + os.makedirs("{}\\structures".format(output_path)) + except: + print("Warning, these is already an structures folder in this path, skip mkdir") + for sg in space_group_list: + aaa = CSP_generator_normal.CrystalGenerator(molecules, space_group=sg) + accept=0 + i=1 + while accept CellVectors: - """Converts cell parameters to lattice vectors. - - The lattice vector `a` is aligned with the x-axis. The vector `b` lies in - the xy-plane. - - Args: - cell_para: A tuple containing [[a, b, c], [alpha, beta, gamma]], - where lengths are in Angstroms and angles are in degrees. - check: If True, asserts the input shape is correct. - - Returns: - A 3x3 list of lists representing the cell vectors [a, b, c]. - """ - if check: - shape_check = np.array(cell_para) - assert shape_check.shape == (2, 3), "Input `cell_para` must have shape (2, 3)." - - lengths = cell_para[0] - angles_deg = cell_para[1] - - a, b, c = lengths - alpha, beta, gamma = np.deg2rad(angles_deg) - - cos_a, cos_b, cos_g = np.cos([alpha, beta, gamma]) - sin_g = np.sin(gamma) - - # This term is related to the square of the cell volume. - # It ensures the cell parameters are physically valid. - volume_term_sq = ( - 1.0 - cos_a**2 - cos_b**2 - cos_g**2 + 2.0 * cos_a * cos_b * cos_g - ) - - # Ensure the argument for sqrt is non-negative - volume_term = np.sqrt(max(0, volume_term_sq)) - - cell_vect = np.zeros((3, 3)) - cell_vect[0, 0] = a - cell_vect[1, 0] = b * cos_g - cell_vect[1, 1] = b * sin_g - cell_vect[2, 0] = c * cos_b - cell_vect[2, 1] = c * (cos_a - cos_b * cos_g) / sin_g - cell_vect[2, 2] = c * volume_term / sin_g - - return cell_vect.tolist() - - -def cell_vect_to_para(cell_vect: CellVectors, check: bool = False) -> CellParameters: - """Converts lattice vectors to cell parameters. - - Args: - cell_vect: A 3x3 array-like object representing the lattice vectors. - check: If True, asserts the input shape is correct. - - Returns: - A tuple containing [[a, b, c], [alpha, beta, gamma]]. - """ - cell_vect_np = np.array(cell_vect) - if check: - assert cell_vect_np.shape == (3, 3), "Input `cell_vect` must have shape (3, 3)." - - vec_a, vec_b, vec_c = cell_vect_np - - len_a = np.linalg.norm(vec_a) - len_b = np.linalg.norm(vec_b) - len_c = np.linalg.norm(vec_c) - - lengths = [len_a, len_b, len_c] - - # Calculate angles using the dot product formula; handle potential floating point inaccuracies. - def _calculate_angle(v1, v2, norm1, norm2): - cosine_angle = np.dot(v1, v2) / (norm1 * norm2) - # Clip to handle values slightly outside [-1, 1] due to precision issues - return np.arccos(np.clip(cosine_angle, -1.0, 1.0)) - - alpha_rad = _calculate_angle(vec_b, vec_c, len_b, len_c) - beta_rad = _calculate_angle(vec_a, vec_c, len_a, len_c) - gamma_rad = _calculate_angle(vec_a, vec_b, len_a, len_b) - - angles_deg = np.rad2deg([alpha_rad, beta_rad, gamma_rad]).tolist() - - return (lengths, angles_deg) - - -def atom_frac_to_cart_by_cell_vect( - atom_frac: Coordinates, cell_vect: CellVectors, check: bool = False -) -> List[float]: - """Converts fractional coordinates to Cartesian coordinates using cell vectors. - - Args: - atom_frac: A 3-element list or array of fractional coordinates. - cell_vect: A 3x3 matrix of lattice vectors. - check: If True, asserts input shapes are correct. - - Returns: - A list of 3 Cartesian coordinates. - """ - atom_frac_np = np.array(atom_frac) - cell_vect_np = np.array(cell_vect) - - if check: - assert cell_vect_np.shape == (3, 3), "Input `cell_vect` must have shape (3, 3)." - assert atom_frac_np.shape == (3,), "Input `atom_frac` must have 3 elements." - - # The transformation is a linear combination of the basis vectors. - # atom_cart = frac_x * vec_a + frac_y * vec_b + frac_z * vec_c - # This is equivalent to a dot product: [fx, fy, fz] @ [[ax,ay,az],[bx,by,bz],[cx,cy,cz]] - atom_cart = np.dot(atom_frac_np, cell_vect_np) - return atom_cart.tolist() - - -def atom_frac_to_cart_by_cell_para( - atom_frac: Coordinates, cell_para: CellParameters, check: bool = False -) -> List[float]: - """Converts fractional coordinates to Cartesian using cell parameters. - - Args: - atom_frac: A 3-element list or array of fractional coordinates. - cell_para: The cell parameters [[a, b, c], [alpha, beta, gamma]]. - check: If True, performs validation checks in underlying functions. - - Returns: - A list of 3 Cartesian coordinates. - """ - cell_vect = cell_para_to_vect(cell_para, check=check) - return atom_frac_to_cart_by_cell_vect(atom_frac, cell_vect, check=check) - - -def atom_cart_to_frac_by_cell_vect( - atom_cart: Coordinates, cell_vect: CellVectors, check: bool = False -) -> List[float]: - """Converts Cartesian coordinates to fractional coordinates using cell vectors. - - Args: - atom_cart: A 3-element list or array of Cartesian coordinates. - cell_vect: A 3x3 matrix of lattice vectors. - check: If True, asserts input shapes are correct. - - Returns: - A list of 3 fractional coordinates. - """ - atom_cart_np = np.array(atom_cart) - cell_vect_np = np.array(cell_vect) - - if check: - assert cell_vect_np.shape == (3, 3), "Input `cell_vect` must have shape (3, 3)." - assert atom_cart_np.shape == (3,), "Input `atom_cart` must have 3 elements." - - # The transformation is atom_frac = atom_cart @ inverse(cell_vect) - inv_cell_vect = np.linalg.inv(cell_vect_np) - atom_frac = np.dot(atom_cart_np, inv_cell_vect) - return atom_frac.tolist() - - -def atom_cart_to_frac_by_cell_para( - atom_cart: Coordinates, cell_para: CellParameters, check: bool = False -) -> List[float]: - """Converts Cartesian coordinates to fractional using cell parameters. - - Args: - atom_cart: A 3-element list or array of Cartesian coordinates. - cell_para: The cell parameters [[a, b, c], [alpha, beta, gamma]]. - check: If True, performs validation checks in underlying functions. - - Returns: - A list of 3 fractional coordinates. - """ - cell_vect = cell_para_to_vect(cell_para, check=check) - return atom_cart_to_frac_by_cell_vect(atom_cart, cell_vect, check=check) - - -def calculate_volume(cell_info: Union[CellParameters, CellVectors]) -> float: - """Calculates the volume of the unit cell. - - Args: - cell_info: Can be either cell parameters [[a,b,c], [al,be,ga]] or - a 3x3 matrix of cell vectors. - - Returns: - The volume of the cell in cubic Angstroms. - - Raises: - ValueError: If the shape of `cell_info` is not (2, 3) or (3, 3). - """ - cell_info_np = np.array(cell_info) - - if cell_info_np.shape == (3, 3): - # Input is cell vectors, calculate volume using the scalar triple product. - return float(np.abs(np.dot(cell_info_np[0], np.cross(cell_info_np[1], cell_info_np[2])))) - - elif cell_info_np.shape == (2, 3): - # Input is cell parameters. - lengths, angles_deg = cell_info_np - a, b, c = lengths - alpha, beta, gamma = np.deg2rad(angles_deg) - - cos_a, cos_b, cos_g = np.cos([alpha, beta, gamma]) - - # Standard formula for volume from cell parameters - volume_sq = ( - a**2 * b**2 * c**2 * (1 - cos_a**2 - cos_b**2 - cos_g**2 + 2 * cos_a * cos_b * cos_g) - ) - return float(np.sqrt(max(0, volume_sq))) - - else: +# -*- coding: utf-8 -*- +""" +Provides functions for converting between different representations of a +crystallographic unit cell (cell parameters and lattice vectors) and for +transforming atomic coordinates between fractional and Cartesian systems. +""" + +# --- Standard Library Imports --- +from typing import List, Tuple, Union + +# --- Third-Party Imports --- +import numpy as np +import numpy.typing as npt + +# --- Type Aliases for Clarity --- +NDArrayFloat = npt.NDArray[np.float64] +CellParameters = Tuple[List[float], List[float]] +CellVectors = Union[List[List[float]], NDArrayFloat] +Coordinates = Union[List[float], NDArrayFloat] + + +def cell_para_to_vect( + cell_para: CellParameters, check: bool = False +) -> CellVectors: + """Converts cell parameters to lattice vectors. + + The lattice vector `a` is aligned with the x-axis. The vector `b` lies in + the xy-plane. + + Args: + cell_para: A tuple containing [[a, b, c], [alpha, beta, gamma]], + where lengths are in Angstroms and angles are in degrees. + check: If True, asserts the input shape is correct. + + Returns: + A 3x3 list of lists representing the cell vectors [a, b, c]. + """ + if check: + shape_check = np.array(cell_para) + assert shape_check.shape == (2, 3), "Input `cell_para` must have shape (2, 3)." + + lengths = cell_para[0] + angles_deg = cell_para[1] + + a, b, c = lengths + alpha, beta, gamma = np.deg2rad(angles_deg) + + cos_a, cos_b, cos_g = np.cos([alpha, beta, gamma]) + sin_g = np.sin(gamma) + + # This term is related to the square of the cell volume. + # It ensures the cell parameters are physically valid. + volume_term_sq = ( + 1.0 - cos_a**2 - cos_b**2 - cos_g**2 + 2.0 * cos_a * cos_b * cos_g + ) + + # Ensure the argument for sqrt is non-negative + volume_term = np.sqrt(max(0, volume_term_sq)) + + cell_vect = np.zeros((3, 3)) + cell_vect[0, 0] = a + cell_vect[1, 0] = b * cos_g + cell_vect[1, 1] = b * sin_g + cell_vect[2, 0] = c * cos_b + cell_vect[2, 1] = c * (cos_a - cos_b * cos_g) / sin_g + cell_vect[2, 2] = c * volume_term / sin_g + + return cell_vect.tolist() + + +def cell_vect_to_para(cell_vect: CellVectors, check: bool = False) -> CellParameters: + """Converts lattice vectors to cell parameters. + + Args: + cell_vect: A 3x3 array-like object representing the lattice vectors. + check: If True, asserts the input shape is correct. + + Returns: + A tuple containing [[a, b, c], [alpha, beta, gamma]]. + """ + cell_vect_np = np.array(cell_vect) + if check: + assert cell_vect_np.shape == (3, 3), "Input `cell_vect` must have shape (3, 3)." + + vec_a, vec_b, vec_c = cell_vect_np + + len_a = np.linalg.norm(vec_a) + len_b = np.linalg.norm(vec_b) + len_c = np.linalg.norm(vec_c) + + lengths = [len_a, len_b, len_c] + + # Calculate angles using the dot product formula; handle potential floating point inaccuracies. + def _calculate_angle(v1, v2, norm1, norm2): + cosine_angle = np.dot(v1, v2) / (norm1 * norm2) + # Clip to handle values slightly outside [-1, 1] due to precision issues + return np.arccos(np.clip(cosine_angle, -1.0, 1.0)) + + alpha_rad = _calculate_angle(vec_b, vec_c, len_b, len_c) + beta_rad = _calculate_angle(vec_a, vec_c, len_a, len_c) + gamma_rad = _calculate_angle(vec_a, vec_b, len_a, len_b) + + angles_deg = np.rad2deg([alpha_rad, beta_rad, gamma_rad]).tolist() + + return (lengths, angles_deg) + + +def atom_frac_to_cart_by_cell_vect( + atom_frac: Coordinates, cell_vect: CellVectors, check: bool = False +) -> List[float]: + """Converts fractional coordinates to Cartesian coordinates using cell vectors. + + Args: + atom_frac: A 3-element list or array of fractional coordinates. + cell_vect: A 3x3 matrix of lattice vectors. + check: If True, asserts input shapes are correct. + + Returns: + A list of 3 Cartesian coordinates. + """ + atom_frac_np = np.array(atom_frac) + cell_vect_np = np.array(cell_vect) + + if check: + assert cell_vect_np.shape == (3, 3), "Input `cell_vect` must have shape (3, 3)." + assert atom_frac_np.shape == (3,), "Input `atom_frac` must have 3 elements." + + # The transformation is a linear combination of the basis vectors. + # atom_cart = frac_x * vec_a + frac_y * vec_b + frac_z * vec_c + # This is equivalent to a dot product: [fx, fy, fz] @ [[ax,ay,az],[bx,by,bz],[cx,cy,cz]] + atom_cart = np.dot(atom_frac_np, cell_vect_np) + return atom_cart.tolist() + + +def atom_frac_to_cart_by_cell_para( + atom_frac: Coordinates, cell_para: CellParameters, check: bool = False +) -> List[float]: + """Converts fractional coordinates to Cartesian using cell parameters. + + Args: + atom_frac: A 3-element list or array of fractional coordinates. + cell_para: The cell parameters [[a, b, c], [alpha, beta, gamma]]. + check: If True, performs validation checks in underlying functions. + + Returns: + A list of 3 Cartesian coordinates. + """ + cell_vect = cell_para_to_vect(cell_para, check=check) + return atom_frac_to_cart_by_cell_vect(atom_frac, cell_vect, check=check) + + +def atom_cart_to_frac_by_cell_vect( + atom_cart: Coordinates, cell_vect: CellVectors, check: bool = False +) -> List[float]: + """Converts Cartesian coordinates to fractional coordinates using cell vectors. + + Args: + atom_cart: A 3-element list or array of Cartesian coordinates. + cell_vect: A 3x3 matrix of lattice vectors. + check: If True, asserts input shapes are correct. + + Returns: + A list of 3 fractional coordinates. + """ + atom_cart_np = np.array(atom_cart) + cell_vect_np = np.array(cell_vect) + + if check: + assert cell_vect_np.shape == (3, 3), "Input `cell_vect` must have shape (3, 3)." + assert atom_cart_np.shape == (3,), "Input `atom_cart` must have 3 elements." + + # The transformation is atom_frac = atom_cart @ inverse(cell_vect) + inv_cell_vect = np.linalg.inv(cell_vect_np) + atom_frac = np.dot(atom_cart_np, inv_cell_vect) + return atom_frac.tolist() + + +def atom_cart_to_frac_by_cell_para( + atom_cart: Coordinates, cell_para: CellParameters, check: bool = False +) -> List[float]: + """Converts Cartesian coordinates to fractional using cell parameters. + + Args: + atom_cart: A 3-element list or array of Cartesian coordinates. + cell_para: The cell parameters [[a, b, c], [alpha, beta, gamma]]. + check: If True, performs validation checks in underlying functions. + + Returns: + A list of 3 fractional coordinates. + """ + cell_vect = cell_para_to_vect(cell_para, check=check) + return atom_cart_to_frac_by_cell_vect(atom_cart, cell_vect, check=check) + + +def calculate_volume(cell_info: Union[CellParameters, CellVectors]) -> float: + """Calculates the volume of the unit cell. + + Args: + cell_info: Can be either cell parameters [[a,b,c], [al,be,ga]] or + a 3x3 matrix of cell vectors. + + Returns: + The volume of the cell in cubic Angstroms. + + Raises: + ValueError: If the shape of `cell_info` is not (2, 3) or (3, 3). + """ + cell_info_np = np.array(cell_info) + + if cell_info_np.shape == (3, 3): + # Input is cell vectors, calculate volume using the scalar triple product. + return float(np.abs(np.dot(cell_info_np[0], np.cross(cell_info_np[1], cell_info_np[2])))) + + elif cell_info_np.shape == (2, 3): + # Input is cell parameters. + lengths, angles_deg = cell_info_np + a, b, c = lengths + alpha, beta, gamma = np.deg2rad(angles_deg) + + cos_a, cos_b, cos_g = np.cos([alpha, beta, gamma]) + + # Standard formula for volume from cell parameters + volume_sq = ( + a**2 * b**2 * c**2 * (1 - cos_a**2 - cos_b**2 - cos_g**2 + 2 * cos_a * cos_b * cos_g) + ) + return float(np.sqrt(max(0, volume_sq))) + + else: raise ValueError(f"Cannot understand input shape {cell_info_np.shape} for `cell_info`.") \ No newline at end of file diff --git a/csp.sh b/csp.sh index 42b7820..e1ca131 100644 --- a/csp.sh +++ b/csp.sh @@ -1,36 +1,37 @@ -#!/bin/bash -TOP_DIR=$(pwd) -TAR_DIR="${TOP_DIR}/test" - -mkdir -p "${TAR_DIR}" -cd ${TAR_DIR} - -# generate structures -python "${TOP_DIR}/main.py" --path ${TAR_DIR} --smiles "OC(=O)c1cc(O)c(O)c(O)c1.O" \ - --molecule_num_in_cell 1,1 --space_group_list 13,14 --add_name KONTIQ --max_workers 16\ - --num_generation 100 --generate_conformers 20 --use_conformers 4 > generate.log 2>&1 - -# opt structures using mace, --batch_size 0 means auto batch size only for mace -mkdir -p "${TAR_DIR}/mace_opt" -cd "${TAR_DIR}/mace_opt" -python "${TOP_DIR}/mace-bench/scripts/mace_opt_batch.py" --target_folder "${TAR_DIR}/structures" \ - --molecule_single 21 --gpu_offset 0 --n_gpus 8 --num_workers 80 --batch_size 0 \ - --max_steps 3000 --filter1 UnitCellFilter --filter2 UnitCellFilter \ - --optimizer1 BFGSFusedLS --optimizer2 BFGS --num_threads 1 --cueq true \ - --use_ordered_files true --model mace > opt.log 2>&1 - -# opt structures using 7net -# mkdir -p "${TAR_DIR}/7net_opt" -# cd "${TAR_DIR}/7net_opt" -# python "${TOP_DIR}/mace-bench/scripts/mace_opt_batch.py" --target_folder "${TAR_DIR}/structures" \ -# --molecule_single 21 --gpu_offset 0 --n_gpus 8 --num_workers 48 --batch_size 2 \ -# --max_steps 3000 --filter1 UnitCellFilter --filter2 UnitCellFilter \ -# --optimizer1 BFGSFusedLS --optimizer2 BFGS --num_threads 2 --cueq true \ -# --use_ordered_files true --model sevennet > opt.log 2>&1 - -# Postprocess the opt structures -python "${TOP_DIR}/post_process/clean_table.py" -## Make sure you have installed csd-python-api in current env before execuing following commands -# conda activate ccdc -# python "${TOP_DIR}/post_process/check_match.py" --workers 80 --timeout 20 --ref_path "${TAR_DIR}/refs" -# python "${TOP_DIR}/post_process/duplicate_remove.py" --workers 80 + +TOP_DIR=$(pwd) +TAR_DIR="${TOP_DIR}/test" + +mkdir -p "${TAR_DIR}" +cd ${TAR_DIR} + +# conformer search and structure generation +# change --mode to conformer_only or structure_only to seperate the process. +python "${TOP_DIR}/main.py" --path ${TAR_DIR} --smiles "OC(=O)c1cc(O)c(O)c(O)c1.O" \ + --molecule_num_in_cell 1,1 --space_group_list 13,14 --add_name KONTIQ --max_workers 16\ + --num_generation 100 --generate_conformers 20 --use_conformers 4 --mode all > generate.log 2>&1 + +# opt structures using mace, --batch_size 0 means auto batch size only for mace +mkdir -p "${TAR_DIR}/mace_opt" +cd "${TAR_DIR}/mace_opt" +python "${TOP_DIR}/mace-bench/scripts/mace_opt_batch.py" --target_folder "${TAR_DIR}/structures" \ + --molecule_single 21 --gpu_offset 0 --n_gpus 8 --num_workers 80 --batch_size 0 \ + --max_steps 3000 --filter1 UnitCellFilter --filter2 UnitCellFilter \ + --optimizer1 BFGSFusedLS --optimizer2 BFGS --num_threads 1 --cueq true \ + --use_ordered_files true --model mace > opt.log 2>&1 + +# opt structures using 7net +# mkdir -p "${TAR_DIR}/7net_opt" +# cd "${TAR_DIR}/7net_opt" +# python "${TOP_DIR}/mace-bench/scripts/mace_opt_batch.py" --target_folder "${TAR_DIR}/structures" \ +# --molecule_single 21 --gpu_offset 0 --n_gpus 8 --num_workers 48 --batch_size 2 \ +# --max_steps 3000 --filter1 UnitCellFilter --filter2 UnitCellFilter \ +# --optimizer1 BFGSFusedLS --optimizer2 BFGS --num_threads 2 --cueq true \ +# --use_ordered_files true --model sevennet > opt.log 2>&1 + +# Postprocess the opt structures +python "${TOP_DIR}/post_process/clean_table.py" +## Make sure you have installed csd-python-api in current env before execuing following commands +# conda activate ccdc +# python "${TOP_DIR}/post_process/check_match.py" --workers 80 --timeout 20 --ref_path "${TAR_DIR}/refs" +# python "${TOP_DIR}/post_process/duplicate_remove.py" --workers 80 diff --git a/mace-bench/3rdparty/SevenNet/sevenn/__init__.py b/mace-bench/3rdparty/SevenNet/sevenn/__init__.py index 85a1516..6145e44 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/__init__.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/__init__.py @@ -1,13 +1,13 @@ -from importlib.metadata import version - -from packaging.version import Version - -__version__ = version('sevenn') - -from e3nn import __version__ as e3nn_ver - -if Version(e3nn_ver) < Version('0.5.0'): - raise ValueError( - 'The e3nn version MUST be 0.5.0 or later due to changes in CG coefficient ' - 'convention.' - ) +from importlib.metadata import version + +from packaging.version import Version + +__version__ = version('sevenn') + +from e3nn import __version__ as e3nn_ver + +if Version(e3nn_ver) < Version('0.5.0'): + raise ValueError( + 'The e3nn version MUST be 0.5.0 or later due to changes in CG coefficient ' + 'convention.' + ) diff --git a/mace-bench/3rdparty/SevenNet/sevenn/__pycache__/__init__.cpython-310.pyc b/mace-bench/3rdparty/SevenNet/sevenn/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 16f58640b1ac6fead56c3a802b5fe32431897b71..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 482 zcmYjOy-ve05VoDPji5-t6Ig5=h>;K*LP&_81%#B+A&ce4*Ty2TBgaK5BYheMMqVZ> z6R*HhY}$d7{OS9#e5bS7+uK3DK0e>e2}bBEi2oINap_Nc=RqQgD^%kgL`ye z0xY8y&SUSzRj-cozV{;NuP{Wix5CfJ9>G9*{f6ZrL!;!|zlKJ*8@zyRmC=DtHni)v z@B%tmyPR)x&R=^c=~;U6Iv$rm;Z$q7jp+STHl|ace>pXlR>Fa$askvCT9iW1p`}XG zt6N$an9Wq7Ks(R1@8f*;^kI~AhpKLjb(NZ?H8>%K6PEZX^Mghd&tk6hJl(dmVJo&5 z!Z?hQCjrB(-OfZ*3%It{*mq=;a4*15#!6EITeM)=Oj>rMw5Z0=I@T;*X*9pe9$764 zIG%zoN_J{xBdlApEG**z9NTP#ajvv-oTtsQOU`Q}7ZqIipF{1Thano0gbc|(A(-HJ H9j|`@iu#CH diff --git a/mace-bench/3rdparty/SevenNet/sevenn/__pycache__/_const.cpython-310.pyc b/mace-bench/3rdparty/SevenNet/sevenn/__pycache__/_const.cpython-310.pyc deleted file mode 100644 index aec61878f46aec9a92dea1495af02a0a6455b5b0..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 7317 zcmahtYjYddb-NdfB?y8~Q4}dTwxX9sJqYSy*;Zs(OJG650v6Ejg0eP|*AVU!TnR7j zE@@Md6FXAcrjMD<^i$h-#?DMTqYs_-OFpzepnst=H`9FSOdGaMlh$dSxT)Q9&n`fK z5^Z2`&bjBFd+&MOd(TCCbTlcz=dXVJWW9D=5dMib`ajP@;~M<@zl#e36{y${ERIP- zY!ZuzoQ^eOP05mw78`LZj&gEy$VwnjY794%RubsAHA3aOXr<^7P1K1sN{49@@E9Ei z?l>KxDQHd91v>g8(V7H226zhaIN*JB0;Hy$_}%??XYL+|@GRgt_#JeT01rjn=jh}Q zhpew5k4w+esUH&S`K~P82dx*J=babmG*E{f;qKQzI)d=1bL?aB&jq^wL-Mg`9e0ic zJ@X;4P9Svvs27o%1?nVHb3n}_br7giNF4&|G=0uGrZ zISZgodX3Bp*5a~2U!aE{3-qu<9>uJ4^y~Dvehxx!33^dYbzJTgpIdiJqZn z>S0fYo<)glq=Opa96b*6E+bq)m`C^q!s`eNbb&6^3CI;`8lXmp0hSP!=}Dk(xPqJd z3(?Zi?u(9rYh8-)%MpGh!mm0e_J{T@(sPf60UOWvaV|I&l)W5{&Cmp zeHo&B3uG+%3czpDR{?$tt-p;vdkdkB@NIe-daIP%dxxF?SVON-dIe|)W!4bZ=?HMv z>1*z|SH5(+o=G*>KFj;Af)taq<5XZ9#QM`T0}WGY7Jq{ zF=3^ZH z;LA|^-*XmFi}#~^`yj&KbIym_Mg8B2@b{zs#~fQfpf{)v(fEBT8p6D>%};f0USyKm z+Gui|b!!3Oabw+G3mj^+*Mf3m!{2@f1etut@qMRZB|7cKmR;KjtoWPdLP@vemUDl@ z?J&9Fww$VGO?Mp6t_JOdt;+q|MWAqrvHvZQtLhHliL`)l_}?7K`UfdPap@ zHg(l7o(SF%SRcV>H`43vrjy?A9WT8`y>!lPRU4+`2kFjMu-Uv4H&0A&FU{ihkqH5~-j4@A`_%0JScYN#}g=yK0<9Tfl4z^m`qug244cvAs4}JbY z04a3StSE~q{1L?D3yFAVp}P)0_)J~fJ`5=5LDem6e*qtJn>!KD4?I1#uR)WIRyXT5 zoY;2DaqH{bydL+)0SrWdo69}GBk&0i)%_{|a zNip)q6VW>j(z`88TxnFBcWCwM5BPk1;)LM60%*@gy^DOBmjJ+Whx(VxRLoK?XXh$f z#>|&A?`1A{qzAZ1aq2(h#H#?FT7!Fq8|L%Vl&4l8J|dmOLr5BK;ylsdxIGVj)|7wW z3+w@hK@gL!??OUX;SPe`2d6}^rxJh6m*M_P3PdEL^u^xF90RG)%bm$U1d2Q&?a3fU zA*UZ>#y^7m2K8dvz$Dz8+#;X<_y+p`gl`YMCq9Jju_>$yqHv#3u&e9M!~Nbg)I!eA zlr&96Z({O-iDY4s5x=&}uWa~A9ZsEhJKoj59 zOWtcB=Y4~N7XiS1*J$kNu^)kiF9S%3Niro4nA)F}0}1k6ze#z>N?qA8XV0acaE+eI z5bjTI7-mQTh(LgYM!ZWN3UE-2gI9r`!>qYWu0(haa%V&Ejy)7Vib20bg^%GW8@1s1 z8n!UEa&ST)&8j)2QZQ|(5!4%%{Hmho70tAZrL0=u;mGBeS?s!ci;d~3k}XwCTQTfh zMq{x~yTfL>OyJnu(p80T1$MBXzmm^pRn6AaieA!;CxppxvaGpvmY~jBb)ymZyQ4dX zC5qkc{%f$^=ZBvJk(dz2@C#=S_X%7#ZmZsKAeO!p9YApE@VK#f-K%!iJ)ZS@Mi2MI z;D#&XSYG%=itV+re*%I}hmGe(cVmCTw?X&L2*?`Q6gnfYqhfq}huc8%PVi~2aDYiU z*`GqjNIu%Ush?gX^q0kfps=isSU6GXMf5+fBSmF&3%_m8+=+#HQ_1LdkGD2d#@cf+LPr^n7$@|y|9U1FEn|@ z-zUJ!dTYdnr$u+&xy*{1xn3^8Wn^bdH#OW|g}ip%fR|q2G1+ule%rGq6*%yg`GJ{* z9mNZKEQhhMV{%yQusIFx1v{rIW<^)+B3vB#av`6|o443#Uf0#KVHXvxq7;}^C@OEV zVGeCIyR1eeAJ2v%XUlNWz^HMa0g%tDr9uU7j#PgU$0qw*wx$$SHomLuV$-`Aq1a?! zsKQBDe?*Opq8AjM^+$ABF;#Y~Pnl@@*qUY!-eGKlOWAO%m5O;yF-tm|3^!rKGL;<` z9^p1DSWwwi#Zc{~JY1o9&4wVTpu#NYK?+0Vir!U=OKLU?`A~s<>j@ZX>ZLvU z;C_a_nVoUKD8rq-R48Sxdn(WV6%Lm97J!U=y5wRLyQ>y9272j5IFK#XHp(k(f2Oii zDrD5|am33M7K1lElkT+JjlLUAqVRu=HPc^H@Il!;md$vJ$Fuyg!bb2qfIV1LOg7Qi zuoZpTU@1dcRc+3Os{yi}x0Y_%1!YMsuzlTT7zw;Vyf?X0$tA_itk?z^z^2QPrFtm? zZ%jMPJ$P=isZ6N|eggYUxM>U%-bgGJt%H>)o8Iy1%{&AQqxv)#qZ|)A;}~;q6;!bH zqb*=G)Kh)7^V^Ay=!%vt6>URRvn;NbOPLkeJ|)9oNe=l|QHpH1R5tTPFhOS{NOt)K zZ18NQ0Q~VD8TF;`P2If(Y`)lVS%m`**GPam@LH^>EV*+=uyODQ98x>0MU39upf5^< zlN{lG>3NOc+ZENYOHj)9jvaR5nKJyO&6K0eFKLBaY$#)_+J(|G8{v_;S<MS zCq#KTAtj(Nk@zDF{&SFjLy`bLACr>eL6U@vxBHby657(|WATYtLMG#~NWPE>`CN{X z|C1&1TRBetSC+|V@(}rroFKoJhPed!R7xTKl{`-VGc`eenVKZONKTRekoS{+lV`}k z$_L1{JSW2gKKv#@yWaOFn(w`!i`6@uzZ%{7N1rpCre~f63G2-v{QNB|nqr z$iK)3$v?}7x-)}SLrDKa;;cN{hsmdM zlKhW6LjGGCSdV;0KI!=((p5UxU6K42v_6cjpOO<_wu+SKj^P+){k06+ z&jX*KGZ|h0&eKAM3&2InXLu2~M1>450WVWA!z;k6w3y*F;ImZ9@H+50TFUTw;1AJq zhA#kLq?HU`0)Bv2Gkh8N!?c#+2Z0}=vl%V}KTPWxz5@IRoy+i}z>m@S3||F)oIaG{ zCxEZfg$zFl{1jcx@YBH0(4`Dt2i~9uGW;y?b96bw&jY_eAI|WLz#pLpGyD?pN9myq zzYP2_Drfi=;E&V88NLDhDqYF&YrwD5BN=`JSffWXya`;P$1+?6*6C`7w}7|l@eJ<( z*XW51e**X>U6XZsGUM%nXV6m_ZU8sw=?uRK+@fbPYyw+!UAAdM?$NXI7Ck55qUYsp zdO;GsC_hObk)NWMQdQm|UG`~94rp6?v?G10$sv70zC$-GcS*|ckt4rPUHJn#kUylJ{1LhG$MjkG6S^b+fco+eX(0cIJo(4u%Rix^{8M^I z{uu@G&nc9DL6Q7R8p*$+L;2TqB>#r)%D<&!`FHfL{CoPG{3(52{*1mLe@+|ml>Khm^@gLUjOjee4I{^k zo-UN}a}>Cd1L@8?BO48LOE+?TuWWm=ES;`B>PIX)Tb^Y@`UkD=`O$qcj{Q-2FbX4O zBtTI1kGxc+L%Z)vTS@Hr5(@iDt)U%+j!;^&N&m~q_S#fDx8-_vKVGsUf8ch6HyZ3a zK^QM-R%5qT5k`A=Q*W8^%+L?xnP$U`*E>$XFZy;AxgAFwI-SVJ_}PkHuTy@w-w|QY z?M88d8N%GDZCPny%HA{ukUyVQ7RWN1RTsQ1ah|mn*6pSqPu-0djYd@$Te@boTRNS&D-L>o z7{zNR*OjW(seFmBZp3)Oc)!ut_C;t7JP}#O2;ps9=3dX#8S4R@=RP% z>nN|#VrpU1vzS_z^et!eQr8vL#KeDbDV>@2jXjt4=5ZWS|L1WMl*T3%n5$wbj0?4wAy0)p;sfg2QizCN9=tXg{s#%)2rPsE1 zEGnrU6?JZO9G6s|3cQ+iJD%IOqfRfL+tjSePCOIY2UH53ewPzTB%;AEUN`mnmZ&rg zL$6r1h9N9{w;31IJQ2F@I&ncw7G{m~IIpC^J$^h}GqxJEVEAKom8QaI7&@&F17{fc z9fTdH2S`H%(rRcQI)YaSok$Grs0U@3wSSzffzV806xgmOR9Z@H$36}0h*p#W2@rO0 z5YiI%VbtMtE&9l*wA%M=d9Rc;#Zqf#k`tbiknxhKkrPH4>wGH8A4SOBp#T(6pVm75 z01QWj5%Ok;#1_)x*fe#@ivf~1vP|eWl2(=4(miyMGXl=ERK!+`R8~Ca41Mg=Y~K$< zDlmnA0(%hW{bA$|unhqfRddV_J3U8^`rt2T&7_Z3p(OHvS|te?T1$fUBrl$q&LQ$~ zJSSOSDkVM-d!uf*@6fUvij3s>5 z?Zicnz9T=l!xz*ncH~YB@8ZN^6S$48zEojO^DtPVUMhl!eYyR-rWaalreICMHLPD6TAsF=ON_TSUZXClq3 zFah@~t)&+&duzPLxf8rp9bP)GWQ&14q(bU3ZkJr-;b92AH$u)i)9Kls=k#eYu%!#l z_H7g;v;zMjMI?yVr{%;3(H(gmt`TTu==YBWxMG-WAp&<`-=#cL!FN~M>j2|lO81&CjjvQ z8u%a3hrEHl^OJc=A`$zD!aKASIJN`_DMfTU9$gw^fYTm2f3gp*a1T|%MkV|3KqXm^ zZwCaL?;Q%b7PXm&2nyo4q`I0@*A`u7RnqtoM=k>{zLb;Uq=}-FzWce@5QkRB31c!n zWu1^tE6PW#?|5&{CiCu9`&?&);ukgruI(Xt!HTYXKug#GNv)=pTuW`kbXyqcpkftf+?6rr5fNXKrWs^=8KnwK&Yc;Vl6@;f_d&=V3KE2WKv?X z%w(0xStjS0oM&=@$welYm|SM^Ad`ofl$l&%@(7bhnLNhiDwD^VJi+7|lP8%x#pD?# z8%&;M@*I=rnY_T{MJ6wS#50z*9lXr&6(%>Byb3~G88ust3hFwpyK#@B#G^%1+tUSC z0;Y}%L7~mtX(@4b46T+{GsK=&uT?22DNz0FY8I`i8X^HgYi*mfgyN=M;hI>~YoBw3IH}Xj{#;Wzt%uv5ST-OiQz{bgcsQQ_WRXPQ{jH zR8jSrx?ZKZ7DhFwh|&m!idr(B)0>UT4$amZC~gZ(MN?~OyOeJV%8NlHFFMcS6eql10~yJ$*b=>29}6Q zQzVa&hEcywXDa5Ns5kI_&|59^wsgb`Hh&YI6uY|Bs#U1SVRWm}y2($tS@`~D44p2Bg9tQP(?KrnvMEx0j_8wPzyDn|AG|I!J2ws-)D58g-ys9WqihW@i zVy9NE>V`1%cB^5Sw9JN`7)fgyB<}GkF$5kpt-6LNKf&A9ES*+QFcMSf{&xU}`07bR zg`w@Dwm)TC!d>J2-K=S*Ku)RG3>{)GWP6TR0iM;9_v@uek9vOKTp(&dTMc^uKI9AZ zouAB05{TGDY~G}$mabL75s<%CG3e451Dy8Q`ICJUKTt{A+g3T$;D;6B5AGLmHNI&W zh^$sTr@lnVzyz>6X@L(TegB z>&tI6-kY<@ynEF?S83~Swh;iWnr2jV4vnqaHttb_b9)`h5K8aR+Bp62xxq`kUD4mm z1#k1c!sP#&8$EyEY@iwlHo7v{P+xp36xJKVCcXKRXLv(v$HOXKHrgzLjRx aNFJU1O=fV5S0B>^Pw063SUZeenEHR-9?NC` diff --git a/mace-bench/3rdparty/SevenNet/sevenn/__pycache__/atom_graph_data.cpython-310.pyc b/mace-bench/3rdparty/SevenNet/sevenn/__pycache__/atom_graph_data.cpython-310.pyc deleted file mode 100644 index 68284e4d537453a5df64c1a00e65d12b10aaa87c..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 3003 zcma)8O>Y~=8J?N_;F2O`)wP`jZHh$#1g5KyfnDSv2u4(icG5bOA*lhvR2{E&hUC)A zU20}1i^A$80Xemo{s9{3#fKhy@9&svPx%XV8~2@ES|SCd@DlU6`+m&xzV9S9)6-g&K4Vr5o2M;(tXu_}yZUST!1e2&g=XX$%n zqfQk(T+y`+l@1?^VAS6VR1gFfLKPpL4bhWp`OK`2{kwZSk*Qeq)7eq#YkmNtQRpY* z;Q<%Ye-Wv^FZ;oW`)ldJVC~s=@P;NS<8t*vlhIbl(V5)-vxa|{@oq4V)!Iz5ohJNQ z{X&Sd9;Z>lIQa1-O?ub(k{`hj(R@uLLRsbEnOD1M>KIM7YSFi_Vxe1FxP zh8lbFUHuA_64%1h0xf=*2Wl+15B5ZKJZEyjp*g&A+m~+8nSyt5XF_?xMQ^g^%P@!o z;eYIplSuh#*XQ5(cfZ-#ed>$g#U-v23P^-}&7X^u|6qSXuDNl&q*Q`S38J6x`ylY$ z@$srZnXLLJC#(K3j1w47j<4$a(}U}@jZ%3nf^EnhA6`bLdfpwvBIWW(vS2h&d?bdkPK%evonj~ zKie3=3Mp$~ucOSFIxo={+V}oQPTkW|R?1N3uA{@-3_;P8b{z|&vhHd}jz?SwsCuM5 z3N;ofO11U-*0;Lc-rj1_jlI1c?L>+XW!=-Zlt>^IOwAhLf4z{aN#6ooQ! zh7}%=F%ow&GPPc!sSOQ+l0eA*N)F+65I+Y&h=8poBd#3joaHpl02rx!dFHi7~0dworR^`lA=H?~RKfIm6{rQ}T`g>|Lx(Sy=5KOl}C%)t7}yC%wF}xBc~3o3#ChuYcRx*@fG( zShw1(oyXs5>*@Bcb{=i-Y_@de=}v2l7W1;XL`<$qZJW}`!PJ#VMu}8G5^`aZO}uAJ zl;q)<^AoPi<776iJ2~u~8kA*c8p*@Lx=Qk#?Gg@9W> zXR*2^3>*A3NH@9yIOH<@x8UoXJ64F{j+Hh0anVQcJZ1Plo0-c9!Ol}bfF2VL|kB}cdi~QHK$$va&`L$Q`JaY3UEVfBqGyWnX~#f>5-oqQ-FwlsTVLL)u}kF=euu+2Zq-0s{V b^5f*h<*gUcJv4r&jzuN2jH8%I?AVr^#Lge5;*|d+l^-cvl}=Jo;!14hk{mga zMM_4V=k1;gSlX;|CHdo->bJlC?*5MV{<@vbrZoJ1{#V}G_(@OGeuodEKNAm6;c|}Y znr3O1-q1?HT`%eKY?KVSM@kXgjfUBbmZHsADb|da;>|=Up$GYsrIh4Lm(p_2g!io6 zBaK{hqBPOWm-3R|Y)m$%N>dV!Hug2AOVeRJ`4nr z+uHWc^R=q$C$4O|wRWr0@T1SQ5GWdc>RhE!?KCQGTlwiqqfxGIR9fq{Q`G(JwN7)G ziimu5!*159l}6dwZmzZ)lIg75Za%Fln;YjVZUve44I-}D*E@F0tyLO`p6j$4l~uc8 zTV=G>bjohKyv|}(_H*=?3m=>w0=C9E#{BA~NYY3~l4dfRj3uMiq7}DIJ8H-3aXYaV znb)j@9kr5(OV(3%bj?6a%1+x++|zajBbv6e^V(YGbK072Wv$$+*-}o%R4YvcBU;K^ zQ7dXsu0>=#_0p8J-P)!wubw4F1l0)9u=lx>&omaW#ej~*`^9fsE0s%krob6jOR z5|UJuJt}2KiWV@ym4;IYiZ6xrFlWKReA-TdLtU%(Go++~%bkW>JMr|@=L(f-6%(Q= zToZ+wQ)so_f)n)GwhFJ*+>OH7YZnTb@Nk7JNd+&qI|XN>-Dy~bj$;=toxObd>a_xT zUT`<;u}Xvd@5DaXvfUG>{1nzt_2tcWt>yY=rQ^1jFZ$7D+p-&ea;?^|%kK83?Z+&8 zt5&ro1uHdew<@bht9I<`{)EG^Ms4M4yR}wZuf|7KP9EU|t~4&^_Y6&Qfh=z7y5{OP zwYp*5ThqMw&+1=5jAeLQIpUh0?ndjen}%9&;~w(v=ct$U=tL+$`;U%_o zb;3kq`YKr%0o>%X)oQ|XT{kAFIk^<_t$6UHPkd4KF!I~JfHa1tsS(GLkr&+wD1kj zR6pY`+|)j!V@r(w^KZm=GU)wc{eY)?8BeQct;yFTXftC?&1=1b);{zgYfW4GU)9`0 z?p-(a^V&`Q$28eyGl)&NhrDdBu!~rxuwU+Bl%DOEmbm%WL(-eOZyI0JRn*h#vzF1( z)PKgf&&i1E^?7R^z5W(r78rxEyT@9@*y-haJ$*0&91oP1_eUAr=UXPq9FSVHdW7|& zjae^-kRF6ME?ye&K3N$@&=;+KpBwb1!f^fQPR`3A_f_|pm+KYzK06bh`9{ zvYi=qV(=S8$-X=TM4}2FiU1dQg*{OF@q%TqRe(~SJR0os086n_{+7{Fp}htSXuGOX zYXPrF#ZKXvWN)sWXg4=2#|!e*ZdTkd*f?GLKo~q#EVR|2U#xYL-ErH3Z-zL9UD<53 zxBG3Nl;n^G3pLlV8*77V2KYu^apWzA^^LztyCpAjbR-=kauj&By5Uq+ZPPuFQlh#P5}IB{5nZ(3Uuth`?yL{DE^}+)M=WN5i!Rag<u*J5*2$1q(d$=1I z@FATYj79gJBS(%D1{<5a#GPAvxfoFspm~0JWHI_VnX|!4>`o3=ysW5Nt57u59Ezz~ zfbLVnr(TvuRhs(Uelb>j;Ms-vaK!$oWxAo=vyb zw86_}%JvOlnKEC+@iWcJX1UR>${?#lX!f0G$6Y(|Xcf%a;0L@3Ru0@6E@uV+OoH~Q z_$Q4WU5LN_Q^s|m*taxw1c}9nZ!Dei^_Cw23oKCr-g)liM!RXB#7wJ`YnD2Bu?8-C z#dh42o7?V2yH!4S?dr*9rD~s8#X8$Kd0JVU73FT9>`&*(V8NXnE`p`aZ9nw{mt7So z@!FO&rA6gog-Sdt2vBCuadE;d8#;<7}IA z>){Jxi5*2&U$3bm?r-U(6b68EhDi#t-H&+n0Kg4gP7ENX<@DGsbr*u&4_+pp2c+V- zg^kLVEtr|JSH?7Pb@lg&u${wE_+1p1U-#~yhG21E>DkSIW94fHWY1dZ9(=B4-@uY_ zvEyp3P42;Va}|@po#u4!!YZmX3Kg~9X+kt9G;5B-O?6DL>6yYida;}9v|ettUulIn zvYRX%BV&vW(~U{-C%g9%Ki1Y;;IXYj1*lTqZ3CfJJ1EAtDn|YE2)FLX0zTc31x&sl z3s`(V(P!{w9E&IX=%ylz90f0KD?ibnS;gh)r)ubNKwbPe31OqQ>L;#tluaUyb?GWU z+GxLmlBwq_jgEamDadw-S1PK7S9N@oAwMd(LU9qOdDfe$IpqOO@y(7~Yxon5cEy53 z*yq|y>D5lHVU^{5{Yh;9GD@vh-0DWzg7}B6D{Q$R`RIigONo^iuU;r$K6}+q^e4I$ zJNMkmm5Uc;P5AqOEuU#_Hf+v?Z3P{T)N7Cq*f1r5Dn`~%+E`?iQ*4!FSEJ4*h?XB+ zLoQd5-S=ZINJU$fvi26l(|&VmlCL+-3`v1Gghkh6UeN@UdE`dJEbXz&=`A5aFFx^ZB#e?_-dv4vfZ*u85uj^ zrHW-+ekwp%+~lPUF=Ah`R##DVDeN?+IWuNV87ZWt^pu`6a)49F7t;;HK<@rO_Ptm# zf&H6{Od&@KC1)Zm!*ugjDw5DMh&9m;(u^4+rpJsJYC{>M#L!kspTSdJpF%5WN$)Q9 zvE3ect4iF~|3wj?Xm<{ma|ytO5CLG2;TRK~GX-@L zT>#OFx*e=m>@nN*r`ECh1NoXb5M>)6XHl!S4;jugttT0JqIVrgFFvsAokpEam%f$W>?r6gy19zytN zEs$4?H_;Ac#TX=NOo|uvVv;vAPnnVm9P^DBCMUY=Co!I78H%@Jip2VD0bC$&lUBdY zA0&8~;30z306|r@MTJ&QU1ZGU)n_l9FRxr#IeV#mVfn(dPrpb! z<5sF0ip0=2Yp&faWzVi$x%|wzevT-51acf-gQ5(W#!pEl7q2{f?t*W$HZi+ORkr=K zL|+Gw&g5iX+v9T@cg_C+!o zlqKDTdIt=;^UJ8fxffstn{Girf(^+{Nof24izd{OL$MQkI=BQKlNi}GA)-P*2SGjW8dSDGERXa|7Xtii8Wk&;a1)`px|`55-9$b4 zgyxyO6rR%*YTXP%SvLno$0yy1dR`O>`FbRXXDz609~R}b z6&3NA&-+K8do!|lO)K^!lrxsrnB9qb(cYYex~a;MQhZK&rj?r4#^Tb5!{Y46y{17` zV{!o_H19#-^O{y(5PXtderI|a0tNM&!YG5DbU>}f*eaG)@V zJ$+Z3*M35nO_}w#i~yaqkrm;ixT#f!by-Hi${-B@lSqOlOyY0j`8#ZBLF=B{W3Kyr zl|8~G4ho}{!ze2BZN*QLpo6Cbv%=lijqw?1@c{3ifa0)G=4ODVuDk_)tF4?*ohkU! zdnI)D4a;n`8=&RPnwG4=b8%5|k@?9%2L*X5S_b@yHPvns$8K(yNhwu^y(CvDq&?Wy z+4E;t&XzAx`US9`u1%bp~?W`+u zzBOTo9T&l9?iD1G#o(L4G z*yEtYF31Izt|@!mKq#sf2tC7vP9D1x3RFBNWsl2lcQe@ID{i))8`F;Mf zrHt;OKYkRdTr>UO1OjB(WgQ&@TJm%!He5p>|He(j$$9#^DBj6{fUCQ>V}>+@lp~-a zQH0{^2%#4ZXo$DpQ|8zPzj{iN~7aeRNInd2~%hdefdz zl zU?)}6pbSsp?G9nbEMS!G+SRFw_tI7l}1xp zoveG0m#N=t#pXe!F?)Ha$+Nv_^x;033vxsDa2Ey)UAzAK=*1B)ixHW8-RK>)reJH^ zw;Kt2HI25@*uDFgyYm;Co9=c%RLi7M&>^?Cps%$4)O)%Jf)?!QrT1RVf6RU+J%mX! zajx9}3A9x=Cn7MUzK}-VJNo=yv{HQr)pQr{%sL>E>Lzjs_6V3oVGGv~2Wc>3)tYbL z!cV=-Os4^g#{v{16z^k*y!SG`gRqmKKZol()G$yY1aXIC1}KU@ zN@D9k#Nm@EDUb;z4rxfi6CgJHdIYhwtj1qAD3U`>nphTYQPokosu8Rc$Qs>1sHh89 z_Euzsm~oLILTDD(aX{i^I0f9#xHkYQ9LhWbrao;<>(i8QqOo_(MECF>12Dc+ASBYU zK$Oy=sb0Rg{Fc64oTYxP1uIp#?5E1qY;|C4E<=}Grg1$?iI>Y(y9&(<*&_8g!D9qx z2&jMxUPrbIr6=E{Z9CFL1lIZU%9Fh6TW-2W|!vYn@g#@P?83_fuQe zEATR)=>c*`D-2dyUH?bHgU_6C=SKm+%mkW@xRH%u=@{{q-$!P+22gk*@3#z+dh z4sL3_DE1Baj=mlUp|t$RC=r(9Zo*3nDV+lG2E)@!?%j%z(L74Z8r^DjpMzdtkozF+h2|1%5kC^op=2_Qr*!;2o7d_{eRQR$K40&Z4P) z1(;L)1A>1@@Q(lxE1c*QBZQ0+WiJaAET;U-C?YGxt7wv;yKs_tI zVg`h>2rd&>6oO#PHR$065$!>kdgBmANrB=J>EPD}VKM2!c;d!-aR`*6sV9AbqLy)> zNaP10JSI^JK4+2^B^B$X!*uY5j2Egj1tm=e{%@ugg^(APfxnghgoYHTQNxrRh38#^ zo>-ES@uH(CNQ)pX0-?~b^e9*#)`1?wlkS7iDoDH$##1;QjgN%!=+|gGsAZ!2_|-sf zBSz@JDqEIoDg}`_vCT!C81A$YQB0UsjJl@&D#1?^kkc*Z)vqxuyxk>+ z{xN~@dt87+-q%kJb(Lp;zQTkN>oO0RM%^ z0^po6r_br8G1M!L?$M>N{;}m6L4|$sL;Wn<`db8FCHUI}vZ02%io@_P*^`jxj|$HZ zen0kIGu^#s?|m{RAQqFivyJK(+0ZW$P_wDNM(}lj;-q?u;lD%hcM0Anudkrk($q46 zpdFcYnZaLR-0+=-eEi=>LX$5>`5~c!XCEMf5FPyQy?A*eZhQ|0oj){^x$h<}<`)&~ z@lBEjKL!W8j%tnRMmW}Vh!ym32<=%A>l$czM{h!4Yese=u)mn~XfGxZtS2|{j@g}{ zDNo{H1@T}5@pREyfmwY47E}{fo9*K|v|qaRA`ssz*q|O9XQ5>?cA{gYfg<)S{n63V zu1VVp+RN=4>Uq!HiM3+xgaw-|P*8LyX0d;HSb1BBjWZVa1n*XoDOL=q04#LB$5zTp z??%??%$tFmWDNe^5i9pPG;|~H6&vpv5W_?uyu!&Re(9MHmoI#L#h)(I%o}(;STF+6 z(6OyPfGgFtZq6lK(J5OCsrqYWn*dd@(EtSQ+&@yej>mcnCF=_}ATKoR;KWA(KazTQi(#h>4pcZx&Dsq*puzbFXGIEQCvJvpIP#Yt zEIhe~@8_|h-z;k_KZIPPba*5^AQ0nC9Q%Zn0y+pvf2tU`iLZ?Nqe3izntKiIi%xJX zgaRZtrn~feB0KA@#>>W zac83@UB-?h^~_pf8;4BbDcSC@WiV&k1)F|s&hY}q6oXg*>aDG9OKa@SXeG21aC-_* z$Hbr)p|`CZfr9&9tMC%gW@i(AmM;bD9o*6EGlk*#7z$dA2D}z8pC2IF!SNWB7q4CV zbBz{1x3*S8gFGD*x-qrq7i*Qmo~V;9~XZZvS=kaH+8 z68Z$0FVbjjvsK-7K~jyi(#2^jsuPyy$KgaFXYMAC(Bd3_a>!= zef=~yhf^wYAusN`x}@al`*-zqYR0nlner_Bi-OS9o#j|w%A&+*sJd zEOpiDEWr;0_@;6<;G9}*;igE@=wyP{gajw4?2uq!#UjD-O~^8<>Ypa@xz$0K}>pBqF~+cs3h`H|-){JCsDsE_<&Ae4ng7$Jh$NQOX2@e@+V zTCD|E=y|;4Ug`%?IqaoxI81W8hEh`r(tYX_=%&eV%E%iq9z_B_FvPu^iU&_LJO%z? z2z@t|?#JRu@*wq(QYPI)r%=O`F@-NhT>-5!K;Q= z>^)yo7Ws%3n`13bVq-l-CGz6?*u!Z6aZ8KQEzN_fWyB#Z=vnBA_6NubD|8(G)Wiuv z|2QLxPBUAa1=c*Rb(h{(=LSw_Py}*msQ*FmErLq`h?V{EZASb)fG<|wz`;6b{C_b? z)?*d9O#{ zvI0FCt|(muAwuKs0Y}4g){r|qUTIn7Gjfx*#@2Xp&s+x9fdtUNc{~V|Cb@~C0dzpt zJXjDL;~_rYnJ(D}WsvS5n;&cN7|ObDTvxun0!Mb7S>UFttoNN!O?q{`&9+^r!tM}K7I+jF%$N&&ODsxR@S&jmHitfsj?9QVI6aKS zD^2en9iM{ttYK_ExS2SSG9Mp7fUc<6LVpVDHD*fe)^rz>(WClq#BN>ZBS##*ksMLw z61L%Yeha4mJ>04&JyEy)~V;ZVE-jLjyyy3$L?EQv#=Yd@wVOrv2yLwtqD`QQ? zmqE-23j$%mD+CD3YLYb+#)o0aKN^-o@b4YClU+dt;%9OjGJ;b^;EnN6@M5oX6o^T1 zq=1J4!aq#Wx6IV-V=-28xk?n=|IxvdY^w~$*HNhq2Cx#g)SjNlU{yo6cFVv);cH-L z!JI;ar1xThA8&aV7OnKz>cUDavoC#PdNroqf2&*KB7RcYGmXXFu9s+|@oA*_qjy z-I>Ff9N3xn=Hcb7JHPP}4Ts<7dj}DF2>xAJFXv6{Mll96y}R53w$qF^8-`}RxiB<` z5{IP(QoT8E(wp-3dDGthT{y=e?QTg6#~- z4E#`KKXyMg(9n%y1PFqGgK?o!AjVxv4YHlZ=}<7u?R9>%LlO%Ong$^nw>#@1VXu?f zgcq8%MEX&=QFGRR`b#H|M8EaTAL16q@7q5VSS2DWD=TUV5hHY^#z(<&_aOuxWQ!5s zVEAEi$qrmqox1~`2gdxjN5GUlrEj4B2BQ-k!IyaKz~fURjf$@o~s5rH5AK|uW{M*l9szmYU56`;HMw~YBcg0~6&3&DQ{fNoCM202w%!I3yQ z*g;=zp5X}a|9@jXkZq_YTCSt8{6iO>6gw2S@2M{$6SPXgw!DO}_@b$wWYOesz~UG% zjvqOx8xlw{5W1mcI80>>$-2AYC@9-Cg|)9KvO#_fhZ=F<0C(vy4kwAeGIpcPQ!0Lr zsWnJc!Pz7~y4t?cXv=M*<{LM*ap54+cGpkhz@)0R!5{ibRM8K12f>I8M&Qm~zTjuV z(8cm5&KhP1Pc)GFiT>jtbnP~JAuH->nRzQ=((7#w&p4HaYYOHL zo->=q39=bDeN7pMBtOr##UP--wH=}g1)6D|ax*FLgd)znxkTU=PnH8-cehgTj*?Sr zC^rYfWZc3(7N|%w-|xO zITBKI#Rc~llA+L(O)l7p8;w#kYmnKD58EOHwqwvmc~It1s@%de97k6#cr8qS@X7!~%8W>m49R zLMWzt_znqBHV8}@+7I#+#F0D|qy$pE8E?YN2joad(QJ=oio{LEANRZQjgf#9-5Z3Y zg;_W^Gat-VAVsr6m_S4%OtM6RGKYAQC5GpNQQKL7!*#KKzc=qKgrw*JrU)ro#GD@Y z@I8~AMUbil#(NPpoCs?;AgKrHOGw8X_fB$7gd`mhQiS|MiXIeFbijSc!&g%R=q+-M zdItzenn>y(hr<}A6;I=2etUB}(750;S(>7|s-Sk2Vkk|JArSUA%g_Lys~jIE2=K`) zZHUIzBp=@=8dp=yJdVQEKE_QGjN@;$pKt2MAy_)zk@sC4!R#rwAS-c!=QF*~4)?hxiVBn>h2wkr)V4|G#i%EEcl#--l>3 zH-t_7@l$c6sOz@T*A3a&}K4K)>F^ zqybYO|K__e=80D?V&Ei{dpEJ1fCoe(1&`ir;`#iGVz?@v2pFc{WD}1OdBtmfV>IOe?w>aQ57H0IegFUf diff --git a/mace-bench/3rdparty/SevenNet/sevenn/__pycache__/checkpoint.cpython-310.pyc b/mace-bench/3rdparty/SevenNet/sevenn/__pycache__/checkpoint.cpython-310.pyc deleted file mode 100644 index 54ee9d8573bd5121535b081ada352ebc5ccef1d5..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 14780 zcmbVzdvF`cnIC2_7z`c+KoBHB>amctS_!h2+Ff0ocyF^_iiCD)QIJlgw6$}#ClGo- zg2GEZLurW`#Osx8udA|`y|eGGD&D;UQ1J?dw zZL0Fe`O1mUk5ztO_YeR{tnD08)6@O+*WF+DbbtN5*q)e(hw%6NUw?OV^_N1S?@(d% zFO0$qc$}mX3b7DV8X;4@6;lyq)l|hhY=-ex8(K4BMijycH}s}%>LRZ-qD{jznlUrh zjGOUh!b}Lb(ny*Uz>PFg&9s>oFujp!PMVV*h~q=cHs%;N%LG;+rK0;`QgRZj;}4(tFCWc>$vrHtJ3fzH(SU6GkVMBj?C)ocFSpV0JTQD!thQs+stlQ zx3}vJD&{wA*W$L*Zrrn(LzaT&Sinm4zS1 zdYY?vp$7_5<$-cXb!I;ndR_4}FVt6eL!LtE1JyY>fQAJWDWFddpqhYEsxmDX^1_rp z(46P@Ft#Er!t^&W${@ENbB!=r2+JNVbL_;>zWLDHXg*VToGg+zLR%pZ^wE5n@}`D% zYPqQT8h6_bU#lR+${CeLquOXY^sSVNuiWMtwB?gX-c@|H)pEQj$aO!*x1M*o)rYXrI@sp$UQJq6zL&>VS@QgB}oK#dLj-QTS^2hvfYNK=8-R7+- zeXhMf%G0h}!sC1liHq;#Di2hi!X8u~qC|UY8GCL`)`d|w&#+^tZUn9YehS!;UF-^v zc1EPUZzvjuHu@S3MQdFL+9zTkS9?lN^%M-OK2W}(-c{G+Z1Ks{bcij-#Syj`j%_jQ zs|!#0;XC#YmcP>2w&l8r#S!Zge~$2~?Plls^E3{9&iPRH)PatMUv!ASk63C)yI2sL zLuP}B)q6xHPTsE$$>dL+(*Pg4D_(c%Ei6H+ zYWo^Z&X3R_4pYBX-V-0L)d4E_86cjc;ZsQBN*0?Vg~w2|u&CQNed;xe zM{El)aPjH;3bCbKd?i{2SKSIDhcD-;cQk&?)%p>J(wA`0cq&ruoye}D;48a&Kk9{c zm4Bu5!miPeQN9;)L+5FvXqNO6`KZ^(mJYBZaX$B0b%Uc({?C3wb);I^N-a zhBoCVj-x)>Gd%QvFT~gHUU4V;Sr%o+u8PeLio7T&iFwACLhpp#sa;hGc?K9!^tsU8 zD;P=kl9s)=InQvXv5kA4(Vqcxlz1(%nd4Q13$ z-?w?e-K?~bpzO4_ap2E(pC~RA)(Oz8>=ddM@UX&m%YLVOsam@^p>Za9hFiQ^0 z6Er}k-m2r`8(x1rDaO50dc`-^U%6hgt}b5pV^^;&FRoISKSkAQ8y$l*yLHb`+yt9_ zSrWtPbjS@Iu6DMIF+mpQXzJAt5pNo_oRp(0Ud$LhiB;nKA5h~DQ~$~xU%7W+`}^5| za0^}W_O@GZ*@WyogCrx)L6TVJqf{7dQ&p9`Vmva!hLTYYz>{F9apdF3sjq1YUq!Qh zR~~U6y7I({`9rV|T&>DnsHbf}2K$#vFR~Hp>3E`GAld^3d-oTyZzFw$_9hFzqtaO! zl(nF&_Ek5$rR^##a!2JC2lcwFZ)Ln8MidxJUNq=2!0_~-Y{(XOGc2~D^kOW|5*unS z?!~nzusQ9>={?q|Ij=WE`S z#9-5efw_Fm%YybPu`ap3u^R&CKzS~-6-N)qJr|-L(Dsi$25Hcn!Wjo|dR@8eU>yBk z=q-ioUg`rz|M!E@{|r6+2=phsX?%%US|{hP2h`*PYJT+dQuCJLM8?n)74&?}n*=?R zhgSo%9t$W^-%{>Qcr)ndqXR1DBo)B==>G>DE1+Y3-@5$F6~)?-EPbT=QoMk_W~H?a zp$y!Lq@@DxU8hiKv4WI73iTRzThqp=Y;BCPvdv1T0|{Dm;ZH+6qQiHu-N3Pfgx09+ z*c?~+6z;6ewrja{tUO-cGb4(tt{Q<=BlH z@Lii%s$zhr#{r-5^#%@og_|?71xRXv^Abg9%-g=+g1lbu_{mXfyK}PDLT$nAEYz&p zs|(<;@s&TADQ;J4nBOSKIl7m)%L~;%2sWGRoqH3 z+&wqi67)7Z^#(MNW~GkxwV{ph`o^ZaVBNpC)y#&uY$7UUEP!4#e?6qA00S6Ul3M0YyJWfKetv|y=*OA zE0;@4>nqpF7TsL(s){(BM&0$bE$9iHt_nX2aUI*+rc0V$Byn-qxnFlTi+SIGoQDgM zS528?jtdN44K5!)MiaKG+xA^E8XN?3O3vSs-`HwZnzm!+N50Sgn&XGoVqd+l)Xg~z zV6|I~9V_U%)^0S~kVzbK-oDsseL&r0eWTUJ^l%)4hUV;`C!z<6%U(bEVLu0I(j&c4 zl{r1=1C$SYm>Kr44?Z!{zE(Ap+Z}9qY;EE_oHP`Eg}wtR*&S#ED}RAR4t<6X3vOUHc0h|0qIx`tBUqMa3Z>@sv2YhP0c`Z z)=&p^3vW$P-=BfhkVh^jB#Aw!7EUSZqpTWNsg5KF0~pjodw)J?0cdz$(IBpZlq#Nm zf21$HfX8_WGPoTg3lFS23M@OSt=W++-Hz7Oxe)9mq>yQ*!6F1NCl!MbErmGYcS5FsL9&6tw9ahZfNwJfD(OV z#UaS_07BLrg3Pe9Y>v%?`dL zBOOvrq!E$oB8`gF5NS-LaYpzFVRcc`uG7 z=}!{JfHiBVEc(+{8wRmvy=(LRU9xhp{rWy;^&_#h?vsyX{)dk5-{swrtqV8qLb9ehexROk8z$ z$5gLfG{bAxF7m5r>gR-D1RlI6=D0f@o5QFK#SLGzS%-RQInX}qwo^=WGh*>uHnd2o z=bGBfn~=P;m+-AkeMt;%MwjH^rnZW~yQvqUOmzxpB>}~5c4a7s6@;MYbaQLi6r`>O zbQFk=ZhW=kxCOBxO{3JXn=q6)eiGU&?am6W-JQ!_^?bpP%Q9q}j%i#Cs)}j;A|}F* zQ9@U_uZg{=$&D$xUm|j#?8?&*w-Qkw?O`tkt1%ky5EIzIhqo+#(N+g~*O4#z>Pfi>c zN~NKEz;>6t?c3>Or6VpPb~@P~$x62cJJcEsqhv=7gJ8aeigy)3@&O{niTpOIh3yFk z?c|t)c7SahA|MeW+&I#ODdYnw=5JB&Bb4x8 zC4ewDG$|+OAOqPrI{eH*ZuuR;{52%qQ)7rZi2e5>tM#+V%8ldL|7$?nnK7h|Yw`CY zul;k7H}2;6o1aSFxE6=V8=M2O%I@EWD3S44eCV*)a`HnrXw-2#?9-_D0Un1eH?+fa zrJEnKZi1a0qGWI&$hy9tl7k0=j>8^hkrB$o7QBn*{8uO$N8_hZ@jjlBQvfPw$52TO z@B@kDzd-|tr8z`EhpIl1fY(q9|6z##CNg6OCZ0Rq)-M>Ak9M>?9Ci^ntH8K$N#n{S z2LlVg8OHUg_HjYeWl5rtC`E!&M3nTPq>EBCC`CmHB9drlh!QMYq7)OQL{N&mur+T% zzGO*=1#51yKfxx*RxPb!k)Xw2%$EO5MPg=`G+y7~g}_!lYpbxLUear(gg9wo0)LcG2^huac{Phm{nyZ@!& z**$fTzl=dzFJ|N>BT3AULxqO%R`6cBDou!pqI;_iD*^e>xZr|<=ZKRn4vg5AVrz&? z_E{qC7by{}n+%}Eu()i@C~T@i_Ef8#gG`(_ExD?llgN-iDy61y503nxP8mv4P2r83 zY25J>r2Em^mFgX6?)zkGGDFk#AT}C2+vGV_8=)o~6ELSL$xsI$jXYQ^d@s}scjNfL zVOQIVK&jBkJo|E}mC)d#*m+#RMFt0Sq^G?CqvU_;l|;Di1ZWWU_c)l{Hz+eIJT0W z4zFwUO%*29STF8QcyVk$1I`uKP4&}oW5j90x*2j~pf)CIRp-+%^TCk2Xa4jekRlW1 z_h=)*B{VO5)>i4FMb@o*5ELsdn5ReqD5_={%DJy0x$B1^dcm*B|A}m;lMGaK<$lbs z!ORz!|F|tY5Wr@z*`%)49XO~qL?!GjlfpB=e+!-X5n*(YJO2d&OfTIm{nE{qTZ=bV z7RyT|vbkJddC@nn*gGZ8+nhfSEOQon3Hx=Zomll6xgNMUJ)|e|i_}}1cJv6E6b6U; zv;?GLLdrvLQ!66MjFN3tHIh zz#jJ#a*0I`*hV3TB=;nIu`1O4BzCv@C>bVs5c~g;uB8B>y!J?ifk1zh)qtzaqrQ85 z3|k-R+}CJ~eFtCyL#6PD1MoI3N8+pZ6rOP5caZ#wY~n2JDqAWZdLxKn3u?@kjwg!8 zz!Sq0$CJR5#4~{>wUGjE4sE63x663ca&iu;Ob;8qpG9ekl&W4t)a2&iy~x2c2RPIm zCZW^9SC=Jw1Qeke_~jG^4*eO}EIy8-L%utA*3+?(6F6pd)gK8q1}0q}PP9M+2fuA!e3pztd!2gPt&_Jx_A#7t*U0wzkB zEt?&bPN9?^l;Ck>bAytBbx}b1Jm~x`prD}tIC`K}KaCZCqOR1TAc7smdP#iIBX{5x z;O8;>Gr%~@j_#uDoM*>mTY{ngqbOmHC>2qnI$1imuk@jPrSm~a&frOus25pUK?S6P9bco_1M6dEw-iS|JZZ0Th84Ph`L-@*+M zjZh%{OBh>(#eZR^(rn0^Zm5q0uFgSAQsErvsdVQ^e+2Ogc*E;e=>Y{0T#DPxHo|k@ zD-}T)6z$X48FW$%n=w+g$i{8^25v<71i=VWdpW6gc2hT7cUtZHtpXjrXA5wcLMi6& z;K+6J=gtuc+Hx9ay0R0D;* zn>#?FAa4{3k8zzz=-G4UY87%LfD$ZP4L)q)+C7g~d%Hm%>cR(OlCSy@CjT9x>UWV` zUPqwWj$G5HhuLTL39KaBNAb{-IV4q!sEbq&78iz_T?7}X@8Fi*07H9Ql~EW77Sd`U zyBix2YGCQ)=G{Ea{?58pux0_n~k*0y`T-Xc`#&yKZ)uu5+@g`XGm3N9_f#| zPmSTu^sPi}%1B7bFnr?~KYokCROIdU0($im%gfiTON&caO66rgzIts5uJId-GPU)t~vehM)D392w4x(PQ4 ze~yZ$mP?lxZ?3MBU^0{m)-5^dKrooyLlR@GlwaJF7^e0rXv%xQO!?+j>t=c7mQ`9U zT`iT@u~09SuCCAuuU#)Kt(4Y$NwQc1S}9$67yLL7m#$pDwo+b~)=QF;{p8IxL@-eB z!piz9etuB0%J?!?>FTA@@-m3>=Z7`Im4seOH!P$#u6;n=wVPBoGi(RI^_%O~CG3YQ zer^bVc?BfMkADbCdqJ4D_$6B9*O3$_q`31g0T?BpM`FehnkGZp{P_JgM-+zxOD|-e z%2u1>Rfpe)e+8I+7-2I*FN;ZD78nVIPwJr56B}&(5}lNPisnWH4v(;Ka6rNt^ZTAY z3CUk2H%S3b5rk7f>OV?>7tjdgLje|2;Bi%L>z6iimvL4 zs%GBTBHg1u?R@&W)OO%@4zwQ07^DX&fn~Wv{F~MJzPOGfm=>`l;@{4nBHdFkL|m)d z({M+F88yO9tw&L|3ey^qrrv`tf_psrd4xXp4Kn(&$gY~C_>UN5tQcgh7;gI*WGoGF zAWAg2^dQ$cFS%L7%AtgNehT)sRX4{BlJEp{8o0Mn zJ0oyE8Q{(YxVM1|r2)86eple5Z|CkCWRxb^Nc4-FI?7+$>&ZQH3pyjVX3wnZmQuavmeGD^#B>=g+FwSF@k;l5!RJ^f{NU^FlwWflu zWk{Ww-lR7P>1&2o*8?Wa0g{ife7+GP2*E+?`(70A3;Xe$Q)0HFWk2nyvW1lP?!JEj z_amdY5|j1?dX7*SqKLyOVAZA&zj6ZaoO{yC0Y2^JK_oL;^+=jB;ONSHluj!g!+?AGc26Hrp8J$!YjNoLw*9FS6XOf_3|rfS(ia-xu)d5%`A${2Ky3GXg&^;NJ!fv$A*h zNpEUv!NcBT`3Gum9B#=PVf)5HJO4$hmw-Iw6Pb6tRLCBqD${O412F4HkDKcy(035d0K|c?(S>6Qy7= z{vjk^CeBGx&rGO$aSZnwrD;UK7{s`CpIL8<06@W~3-{}Y1q@_1cs#(m;F7tMJ` zM1L;`S(>72aJ!L^2)R~B(Eh|e7!5?x;)J;u{x5cnc{-?Ys=VHDodswOaGM2&m4V2~ z{|AlzA|)hI@E0ig50nV$DnSxUAe!)h3$TCUpu}Q9jDv^=DaHcwrP2-l@2I&?$)8ZN zK*_IA@(oJ<3nkw~VkTs)iTEs1LU@|$NKBBHk_*Ve_Ja^>KhlA}%k@)4PayC>TH?P= zxECooLkS6Z(y1nbEh!#GL|pR!NV$JQ$$z4RJWc#RBk`jM^Q^co;x%bW;6{dft{rz4nzY3#+vmTcuNa71rIbvpDR|LR8VO!xn94%Hk}gUW zH8Y4qp(s~^q>xL&RF!;R*K(-SVam#@Is8zIxL5=4I7GJ`L@;^~+oB>)CanF$(6t{L zQSI-Hxb}aHg!X+SrF}1+&V&&@13`}{Qx)S!+G}w`Q@fA-45qh_0b`~^V=wJ{NLcic zPeB-H@z&f(Uq`HHltF@`X!Bjfi!$}?DENN3`?-Pozh`~eMt~PA^md_Llfl{}ia7;! zQ80Ib92_FK_O_oCVk=C+jZtg zDE$RLylKBW&cEbO(SM{2CoOc@Be;^B47$|xJt3ZUi~e`rC&qlQF&#j)K8@yO>2NCG z_zN)`k>_!Xwk<_&SqQYW?jpRDjuuK$%VO=SW$`G@B}$f+M*X(0T}MVTI2RTCpHZy= zyA5I)+#RW{(XGriO2h(vfpWtIB~Pi49zQcO3y0*mrUT#9D|c zi*?jH2zaYF`~k4d$qrm76}Wsh7J?WL`Hun-;UN;C_dHJB{tHToCHoV>Jx8nK1My diff --git a/mace-bench/3rdparty/SevenNet/sevenn/__pycache__/model_build.cpython-310.pyc b/mace-bench/3rdparty/SevenNet/sevenn/__pycache__/model_build.cpython-310.pyc deleted file mode 100644 index f40f0573b12f208c6d90acf01ab4d630bc456f6d..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 12071 zcmai4Z*UvOb-z0t4u`)4|NoI>L6&R@vP9aF|0Is%h9*Io2?a7FWk)wjFCg|n9(6cC zdk0x$51lr08>fvj>Eu&uXNKzJLzK2@r=7H&cG7nGKkc;B51D?rX}`2HZT+=rnmCTd z{@xw{f|Q+-bBo=#Z};|g-+TMsZ?}=nrX>9R*0=6&egBFi{W}%<|ETt3c)ah*lEfq? zJCYIo%7%=x;wbf)5tE4)bJTj=h}Shkt0#;^T{m<=SDj=%Wu!!2cGC5Xk?G{K^_(%# z$>-~X#vt->ra43PVPjb26V6C|)EKRg8DsTvW4u0LOw=chNl~XeQ}tuUF_BL?$LlAI z6C$5-9;u%+PSy)Xp?=CZC1`2q(fVoQw8&?iGxf)e$Lfz8kJrx{X9X?moU0d&V*R{v zzW#*qgrMb|3-yb}MOm`s?P)f^@@(*~f|?;V{JLb!un{(j-zV7^8^`ZcY=TYV_Y#|8 z$MAcZ9k*gNg`HrJysjEgcS23|xBw27IQn{c9<3ABMJwV_1m8vK$@nmn z*QDai+qB;0q8z4|IkULM=IyFqR6}*i_WUrhWcwDcIAQ#nYd73bZ``uD)2Oifa;PnG zZZ*BPAsQ0e4@SkghI^~ww0v+E<6-(G%kwPfrHW^JVYb9Jtx|o%V$60o!~9Ca*>M|n zyW-5X{6?)74z4w~EN)jTPRXq{1XKD`D@)6(Q9WKVxzONMYt85AcDdy@TmAxX)Tv)o zPL!>it;H47c4^F28iKjuG^#g>u`o63H|nbvYgMgKUzjVK)@9cXGtrorY}cxA^f0jO zT3ZdjH%2tfUv038^Q!GxtCok6Sz)4^8Cq*vRon9VD$!tg&2nm(1=kWISaf~f2qzB{ zqZdY1NBHqoyn&$%(}1t=N^@%sUuC&IrYY3yO}D|V-G%intFVr#*jlS{yXhDFou*ar zc0Avz7g)pc+_QecYc-nR$-G*PG$?@=-L>)6(*|e%Ll7r4) z0>3JrIXvFSkubR(^QAz!BhO2p$$eUBt9asVErNFX}PUBR?D~Wajs`^znCxv>kZa&EYq#jEwfU? zY8nGKtR1sMZdPm_YEEUx;$9fD+3ira!x$Js*}O0DbLja;ugq*U>efsPJbDzWjk!T zxf5m~T{OyQSVfwk5hSWSBI~k>f8~K5Q{_Bz{5&eEq#fNq4Vm7&zK%j5`O>yLDYa#h zgL<|VKNiTodWU9J-joAn1htz|8&u5w*|r*}VunvKsU3GGQBpA9=l!^^)vzXR9+Vg; zm;FSClJIqgR&PpQ#rk~?E3%gi;`d}(x;YWVU;^UhP<0zF7Qlt1*-cEX(1+bK*W4R! z<2APc4X&Bho=BTQ+G>rK%bqV3Q(^2ArR(9yT4`y)oLerJOLOas%Vl%@`bsGrC||ou z#8;M=t`Rj%*_yaX*MmF6}3>q%WL+gG5LNY zF&$VcLx&vE7;hOFyJ3=uMD&DL72`D$9kP&*D52#GM{L*j%_AK@gBtG$CC#Pg+iJt|_p(W0M5XU1_2Qm-U=`OM@w`Hb$I@VU&v9{_fdofX?PD=OWc07ph z#i1rx7UVRDMFLbU(8OY%#}cM()O=0;k_-{)woQt*i6GHyn-FbNXsZWUb_l|vpB8Ns zqHWAggB}Nc-Ou#X#aOZsz@*6Ku(N14RBm-k9s0ne#2uD!Wu~`NqArg%X@4+C2B{#; zl6UFbh5{vsK@d}SNx(8?BU=>~-gGQ?)88t_%H?vXZHT=g)SK7`OfZCTVeC1n3@=b3 zjD#E}Vf5-1$KJJyF(b10Fr>U!fs4H{fT3LvDIC6v# z@n-2O2B?9$banPqVUkj_G=HTOCKuPtmD$zB_3L5IEHAHKon2ZqO6J(}aEsR<2EvQg)%SfPTvMO|dGGsK3G%78p(G=uNw?I{4(y3Q& zn-;drP(ydNzk~PpCfG!WM>S+vXb811t?urWnImlvg!LbXG9 za&k`LWw2KjB;7x44etUuY+H`tAZNyV=PX5RELKITHdC5?QWbU-cp2;VK4Gi+i6jknUo;Js>qWH z#HoL#`#AlB$!<{K?`#ocGqfgZ1`b{+@f1z_ZMT}>okz0Ej&5AvpB%%gCcdwt(Ioz zm#?jxvumProp@vDn*S8>(7qRD%FFX5bD=c5er>fB4qaO-nJ+DZf3a*XEtX5OtJoOk zW|vCy=Ir|N)x|lpR4%PvxgKg{tKnd#W}!Z>%!<8XtouIl&-&<7{q&a?=jTghvs}8i zx?Eo4!x%(S6Iv4LozJ+GsBZR?66EBM7G|6@f{=r=08tG=V|tA+f3r-&lBhz~7f^+(D+>N1>qJG4vCOp>q& z!9N3ghYh~BPY4B&NC6<|QbnC(N7165X8RlB6^4mv6uHUfhrt=u|SR-J2)3mtGz7Vo@3fQ zg>U%-+j%I1CPp;a8PO1aqcIL6r?AA`1llE_S|dznNsIzc%N?j9-}OgDYlZfN4$m0F ze$mJCC4ap04k$Rw+*Lb#6D&)POEHH%!*9@Te`qpev`lzNfjnAUY?RVx^7s&JmGGr# zU=lt)PyXmVc&UN%73C)Bugq&;htegsDqh8hHzLZhP?>%*ly8Lct)sG%dVwK4sy6D) z7k`yrN6K(+3SA13;VG0>T8HoG7ziX@*JcT@^=U z2Yp_#UDG@0=+Kwa+sL_~Fa6#al1=*3M4P@5*6xn-E8@GvpkZ*H0!J9HAzbrep|+?b zdk=q+m{9tnZ&0`NjL407L!J|xZGBRF+>#MUcLTfKCF07lv1cf@e zq~_Jxl`t;I_#xB8H|bTh^3jZ`w8UW?fp6X8Es(<)+!p?{$ZcAFm?0~7b$K2e_WE^B z>tJMR2t#QJn-r6b^cc)n7r`67?I9|?JWPs?#HeBw$Km7RH5GS(MC>v zz*0UNSu3KaB#O`b2n&nDX@dsRUr&vr%7;CXjw@NKyvXnKzZ+sC(p zI7`BZ-U_t6p^ojmMB^`}cK;<+z~4&4^P9X;g^OEw&b9oDB91KhjY7TU__k2_f}l+c z{->r3O`s_VcZdpK*WW21Cc6kUqR?}?3L@qe{FlIA@vB>1Zom+v3;lRT#O(?8tKB{; zuT?mZ0x~cig|fZER>ftG1*O^nSU|zhf&IBq09sH0Dr1R$7QzvX3WG8cc9Rnlm5e-o zO(J&(U%n6UM}%Zf1qcv4kDjx0J90X*J%U%cWrf4NVg>LJ!4tqchC46Ai@N*s?QlrM zYW*P3Q}08&#IHfh;Q6=lBTNGSS;~<@^Pi!FyjMPoB+OBi+EM7m(ps2`bW%jHo&4Pw~h`Ni^;P{%rYui5@qICa1y0_4T=dTABj(faE02h_pg zLfzOwI}ldjguS%1JogC$AMcwqubqj;x)&2d-69_-R;_J@2{QsmjAW0K+w`FtS^-!& zx(DQhD0REs`F*f@i_7ul#++VK%JuO!!qpaG37m7 zRer1|ly`Jp`JtXt{#Q?>6y@g-F29LR519or2^t<@7S4%qnMsaRmla8t>Pj0Cv>)43 z1DX8&?Sw7egZB=%D7L4Qae&!4MrsE`0%ym2+)wSLnQ&z~)C`5x-bGQ8(E};|(=3kg zRfAiUV+p3;je_riAVw~bEg|TZ_VU7fEWmsu+B$+;-5>IYYYMywD@X{{f92*vplpu> zdZ5;nUTB^~7_0?}7o?jDqSYu`?0|3V|KS@4-xt9*@xgo;)j$Aq1%N;pnzD#F!WbZ~ zh287bRteb6EyQyK^%bDhoM1x$g6u+JqXmQzh+(r*y#cI`YDudF5aU4crwa?%dNxo? z2CBjBO+w2GfMBiZP+7hl#@3K%uw151Fezbg6tcX6l`BpfBN0?K?Yh&|7R-R6J( z;H#;FGJgTg{G*igG%_5WyH@(-HTeIli?ijq5+Owki&yx^z-8nSYV=q)GXB_KrR9mcaz(!(hIedZMv`dKyR)og)R99e4MkW&b2<;L;lrb8SYAuH# zi3sts7$-3`2BSK{n7ROW>Ok}OuMcH?f2#PPmdhR;UoEjmH| zH07V6-6M)b3Ac@iDzDjRVU|CYY88n1g zxgd3?3;YexL6e`y6w5W`9<1L@qr-!vAEEmig9E~bJ}Lzo+>PNf|2!rrjNN|H;^f60 znmZi?!ka~8M34g`@Q@a;i{r;I4i)2o%LI33Bmg+s!SQjthnWF({|gj|64oM0EGj3+ z%GD4<@ssElalk41(Zh5{kO^suMBof6_+ykjPRUs$hF-I|N8m=NajTAq=ctV2J;w$H z(!!~SWBOaR_ZW%E5Zd+w?fDqNOuHAz4kE|pG9_O`VoZw4j&noZ_q>n$D!&d6;MfslLS~R=R=_vL$deI}e8@R5rlQ{7ua19` z+LC;S!|&5HGV!^?&>Jrhd!|MwVFdC3Cn^p`UqB1ae}|F-H)ooPPf~K865;R+^`kcC zm6w;-)`d&MKLeI9&MWSwB?6wv<$07EmXQ?GeR@rE7UsIs8O;Z*b~G<8mciV$0OM1f z9Ht|OL_qT4XcwyQn9;~v;oqdUkfel~Sha`D;LyVhjpJX*V`xJVBS1)4v6KQ=B!y4~ z4oeiW;P=OHSCZ=Ym25=s!*YHiJvT6de#Zb_3gh%gc+o-!Yw8Ldi7H-h;~#ekk^C1?45!3VlUcfK(qT-rp$oU#ayHPVh|g7(|0#Ia zKk%LW+thRKef&=zev#%y{cP#h*uDUpM*ZL1;&T&f)x zIl_}PkrU=tbJ1KHFx4(2v$4m4+mL)1o7c+kj0Ukh9ft;9bY?&+MGWoMlui#Y>#hGwDb6#+@5L=!s$H5 z^k5(Yg#F`m#Kw|hGa3pqOaXZPHBjSSsv8594*Q0wrGG*I%a7nK^qj;~05%r)PX$B4 zFi~+%lfJ96Oaxeq@jQywBf%iXI5IA^M}vGY8esb%r!kOO_N_#F4D&S(7d3ZF`m)T2 z;m%IrnG~&0PfB}dgkzh9*`I13!<;`BOaZw&hG!y}g4_CEg7fj8$T=xE-wVd}j-k~g z8|Xml36{U1;H}5}v+Z#=;h$rJL85!?=obS`Y&)NA9~Ui$gX2LqIF2``DP==%sbYic z_Oi`}+b3A8{RsH5q7ztA^!;&cVM#X9J`v3@DA)b-+fM|;n2Tg^0&)51-MAYMGQn|@ z0?gM*|NLH5e{%Z*I3BqdZx@b^PK>A!oMdC&;{klzI6DEyevHn|#K=yzhrcY}d|a&F z#W&<&_)h020joJpGUCr*-JZnr6dS**u!*QI{}K*0P6em-E(521R7lNaaH?~R5EJwL z^dzKl?-{gwmSNiuWPI1RSqx#yREPc?1v{XAN7Vg&R2MwTj=iOQ0UJ7M{>}eTpkNeA z`}F6f&%=FS$KmYg!D-SBAS)mCpO5NmvMs$K(O1ENcx3lqMPMK8BLD^}O&|&9UYltY zHV9uRrbPA1DiHkj<#imAMF9Tug>%IVqUsQ?53*aq#Rcy|;i=+#L=;HPk+P0c6n#8(o3GfdUPxYB^4U{O`JTfe;gB0#7861qg0o zv%)thzPUwb$Pw@%cF^*|Ll7kyW#JZ$AR}N@ErL<+$bBFxC)f(k>=u4-(}43Mj*37n zze0NuS^N=cclv{^_kmCHLJvR{JOAz}+r@}&R_N@V^@`g<_$5w&3Y#{rA@JWt_xv{~ zd4rPQMuK6H>nFUvVj|3T8rN(G_a(yoYnHva1*pL7U;*J+ge_C}V&Xsv_ikvVxHu9r zd4=JAiILmDVHJ*nuus#CmVO~=uSxrsAVQMo(Sy*JS3 z28kOgxGkXaD%yv#ALfsagyP7<_YA~E50!$kP^H@#2tD&P`?e_{v?g3cpkzSryPpmj zHUYYk|AWa9UKxc-;27=Vp(LAjwR?0n)Hzf0aUsSb;^q$jJ+xRCR}RF9YXq^@Y0Oy* zt0e9c2&_IQif}hO*!=;XwTan$NZ=^O8OGz0GbIErO9F=&Mkb{_P|7~>!&e*!w+MkG5Ug%Cn{76R#!JE%#DGP1Aj3&|~h zmC_vHUXzF1!mHPo))!Zn7Uu}%CI>Ri3-pRk<~sK~gulq=sV+L%rVhA?GXE?Q2Ms@pl{KK=hs=Ts99MONK*zAC5zY)uq{00Xl{QeHHVIbL)o)S7z5s ze2_+QnUZhO`&3+|a|p4oSsdsK-MAAcyDA;yi<89&ai`!Y%pNkv;ow7No9Dr1WN?aM zl9bXRHGhrXB`}$vr<_1Z$@JUWFYz9(IEbk5@8K+{h@#;iz3iuS;wHctF<44{r`r=a-c9Ms57abK zqKHg;HwW%+exm6Jsd6Lb%Bx^bD41qibpBs034 zAPiZU>C&JpUKl5p@uDNfbf<aB%8##hN5AeYbZV2L$2%bPKalwaYDVL*!ZV4F! z;O!6WBYhnO^^h#mrB+W^3Rp6Cx`SYn2@n?|qT`7{)QB*=cYY*-f83OFr^V@@IJy)E zj^eD*7@~^9TX1weZ;CoTM&rv0mhOF2VIjg?lb5G`Ga+-1lcqwE_+1KzWO6CFG=jfI-Dkq|PqUDs_gZ~4U)LuCN diff --git a/mace-bench/3rdparty/SevenNet/sevenn/__pycache__/util.cpython-310.pyc b/mace-bench/3rdparty/SevenNet/sevenn/__pycache__/util.cpython-310.pyc deleted file mode 100644 index d944acb195fee14da3596908c9500e1487c23c51..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 9907 zcmb7KS!^6fdhUCAdJZ0^4eO%5=qHgiwpIs~B>siqD z)UpBGC-HTf6{uXc*fDk-eV$8?rJqWVv2U`c{bBYDo)hT#OtKnw64agA`CfKdkv)rE zAQ^2B~(b!pb4)8PV`5xvA>^vZ6R}^;P zcU7$PMXW0AA5D75KHp$3VZL)rT~Lazd_reszGw(zrc#cCalH~o!dz@Ns=hGa2rBiU zsEMH~&1!Yg55hXX$?L2MMC%&oej_ZZ!ia9OT3MxWIzKv)CZ1_L;Xj~=lr=S0?kH`P zTiUkL)?$L`W6G9+FFi81ED1@CDY3q-wGBXQ2}zGB+iKhFVK^NOi>d8Ytn%|cOjTm0 zqD-u9sx!(v&wf{H+wF90#5Ve*F<gt1Y1^9Y ztzJ}azZ7fOG`;nkP65BD&daObieIn!5wDaBOQop1>azkXMI|&ALcdYsrN}Q-v8l30 zomW;WL8*#XkFY#yIaMqyZ4_>9TrE_>LQszieo$|&tQNvbQ1)fFN>pHVKMYPrh1F8P zs(zsoEc;*pD@2kjp+uYK#D!u`=&#SdEiyN6&CYlWi??PM7DRUUgd#ov#+-L)@y6VO zu%bHAAWXR!VTO%rB@#Jmyf*F4&d=Vu`nIr|!BQn)Vn`0Tcw_O>^=^|RL04|vnw}MD zKWNr`TCqqsgYb6K_gj9^aR*CLz2>d(Qe)MVd*r4op%+Hn4?~ge?Sn{Rj!M}N-C^ob zDSLk4^OX&`N$#QUlD()Nm8yMh`+9B5V2Ph}80i6%nIu%G6=Bef4;6k0oAvV>ldE9g zWHa>nwoUE0~{`is~lvgLuaR#PF8bzZ9p2VzdRwmoJ!j{3?#yg8Q#0@=in`)F_8(14euRDZAy8fWj!MC@DVx46w z`2Ich`|7r;Dz^{q>NSjBAH;;E?Zt}K*xn?KmYx4ZN`WDAvz*uYEx%mn%;zFCcS!<- zeQRzZsbyFE$b%et(MH2B>LPuGj?eWvj^)rzpI^c$oRvpVf}90;VrAV*;KGzWgq1W2 zBPi8;k>dWcx9)EUy8#iUnrnJ;QaW)hiV()beq^Czh&P;m@}>qNX%2_~u_ zh94*BQ&c=nMUjeUQHXSzS0bN7OY>9oZO}q^o~l|9V95)!3Av?4!-oJ{*U!**anQ}v zz}|8*C`afWK8}Pv(~4B0=0Z2u*H@_C4n_PqqAj6isRW^-6(_}EqUw16%DrU{&CoYo zcoul!^C(8u5!KMLszv`aOEuK2>hvB%vmr$`$&Y#f&j@74R{0Aw%RoKwz14c;brg!Q zXutau2_AjK$LW5Jb+Cgdo`nMsvoNFEAOi?x#Y4)fs)59)ls^t z7=-SJ(t>IVH)>1uYS<^jG>?I20Z(`U1rE9LP6{Ha!WwF^`a|tDY|GAm;W%m&Q$k%A z>YW`Zm%Hd(gr1Db^;+YVMGW;*TF3#VbsWp4RYL<2EfW{3b&wOjGO>f+M-z1;JG~E2 zIErG0XtLYt2g(O(40}4J{Kg)t9%}p>Kosgd)Lul*&33mI$GJmjhelF;lEyb^e41?Z zt7z<@W92rE&IGAxIfpQD$ZKIym^28 z-5oWv)T~sQSF1C>Dy(unSgx!Tbv}!+gjQ|{Q?~5o=6(rHL}6#JZxa=?1Y5UNN8|G} zh~b|jv0f^b*CEb*hnWr+tw2&w7NkQx4%Ze|04GY1^;k>j=U0KiUqeweBzgQM0j^Uq zN5x*TeJqLI-AY6jT@q75w}f89A$o}g*n?(NUtShAjK5TbWg@Zrqt*5NM4G!Re{H>E z%E+|Ati}440Y47XxT%ph2rmxq@0Q67G7*urm15=!mKd9ny+(eW2}5Ab^P?CO9%_yh z`Bid1W^kkc7;OgC(vn|=awbz9lEg^&WHIre9Cu)s376-PN-0F z>}AH|sILT#X0(ppA^8efXi3tK#-qGG486&}1K<_mNGz!v$U&>z@p)LP_MIZ~GKfX= zNY}l~k`EQ~G!nH~C2s>Pf~_}VHM|^|5#0Ap6|$m4U=d?C>e%R~F3EdtW(x84C*3V`qfKxo=yTF(1R@*ck(^JmH`}vswCqlC<|` z>5No#I_g>zdD>w3m|)xz_CYF|?twW>enD5S!JsCx88-ZK#qU!Yq%L||eFPs;gIZE$ zDBHWWCr+ISKeW=|uK;ZiOZg?hzK2I@$q}VBN%}jnC34cxgD278bvVc2F16lMnU5t( zs`cGOeEz zC(im9o~bv`yku}mJBNf z5jq$MC1C8}AUUe0W+C`_)kry?+7dQON}x)kzKlX-mZ82PkI@kqY1qT0+NYdJ1JKc) z#uNS_3bH2=?BAxwPetmM9xH1GzRhioA0QPRBTz@s4mG?0yk2`6?I(7%LxC?K=L8%Y z&suXZ-bET~ErXat>5d?45?A77f%O`&b1(uLQ$KELK8h&v@qykprytF|LVRK2M@jBKhY`HUZW~3AC?C5r(MRkf$BYy?Kgq37N z#1I16ZWk|cEp+K$ByqVgWg~n5+IL=~BA{ZAea%S{JVYvp38EhWBczxhYm#?M|BOdj z=yNix4x&W|(W0U8Hc*7IjNFW{D>5D9-4ppvtq#3~T!F9{uQ$A^$RKR3L>tmM5rfN> zdtR+sjVcKD%K(?j4tIYURC=ucEqov-Q|~|%d`r2n(zcTwUXz}2+eCP2#Zc$(oQE$< z{w(EOkU5Fd%W#ztZfZdqFB2rT_bKbxFF*}G zP1&2?NEQy2BZI=>$RPDEtvojOUC;>M2yr;%yRmwOKLFs9NBHqyMXhLZQexnCvPt!3 zBx=>Cp(b=}p-}O^f!P7r z%ho}eg6MX^w10qM63(UuW$d2?Z#DTKevuJQ7BLO~2zbJ1l%iFkH+jE_Chky%O$z3W zN`eUokFsRQsxkE-@@cI~q^;?I!)Mw?s(|Sq!E{IFnibO*Lims1SH;FwY8#PGWXHx@ zddp%KOKl@});5_96YIn#O9LAm{8;%h0`8BDHf7gz24Vcc5hN81*5Bd#CDu%XjGC(g;x{pE7Y1T~| zOL+?-Z79*mMl0!epT72QY^zEeiRso6;OCgmRQZq4IuPqSEn8y* zSC#vqtL?n&+*R)?#liWOJwchB3erE&9kzcD=EAG5e?o$Y9H>QEq*6_!qDsvND_eQ^ z_mJ{v9J+fYT9uw-1%x3uy4=dn)b9r5BSMuFkX=Q34L+kVpyMOqQ0aw;&3q)I{ID#v zyGyNu)4f0f*}kN=&t59DhPq=FmP-}*tN=YRUG+-=jooY%I#l;My;LY35_&ys2%AE1 znUc`zp|ESEb)QvuC{idT*d|vM$*xf7aCzK8q#Q!uLlOB_)JeMqKY>-&QF#CX8DB4uB0#>Tv+sl^S#} zQV&TcCOvzv>^CSC0SS&c|2_?5Ngjk0tWjskCxJPsM<`agfHVTbh9?CDYY7`jW%$BL zGK68Jh58bL;B`;dv+$r+DCvlpg+p;-Gx-<|rv5^nVqcd7yTZG_0%|ye!hwit>WDUo zHzo5lO~bc_FYB{xN~ZEu?TaIiY$FS0ZNV;UP}swI7A09{`#tSv{7*5EOqBF_RWjg1 zixlo}L%3lF;YhY+XtaW40gkMO2ZGW~P>_!c<$N3v1Vogw*aflfD6qb;+HEbE1x9K& z20w=u9W5`S#qMi)6)gr@PNM}%d1q8UM>ZcVPoTx=YjM$Hp=AgynZA}4w4~4i4;{)w zg%MGqaG+RZR7w$Q29Ru&+sd=twz{ohrBF{$EE%N!b5JEF1o_uuQ-i`-P;P$@N=QL2 z@Q2YLCHwEIEjt=Q4l+-d1Dje$T?}vhBvBW#Z6Im0pmY`#De$lXGIfiCZ(IocsE-l5 zJ^P z(K-#UUWSwiMCmwF97<>TYrFU|X#6t_1v;`_I(lUAKc|`cT@6XWe|9M-YAvU`3FA{O zi?(5W>Pt6*@0#lD*;;95#a}wre|4w&KL<6(tLG{&B;;s5>r7_GKEKK0e3{m<7P z%^t^SNVk=~dcdE5?Ey`UR*JrQw9mSuwT{i>R;X0CQ>s>2Vdu;gWLBm?C#N;^?&L(d zglM62AAxfP?V%(Q?LJ{v>vzd*C5;_M?#ROIo3r!tvx^>`-ETFkbY}XXXV1~(WJ zT;k8J&$%n!lsnjK#V761E(BVlaNaCco4(ZY?id2A>DO=GxHiA&&D?lv{`!qeGv4)U z^RF-PEf7nI^F|eSEh0tnE__r@-luT+m z&D!xnTF)lR;JJc)6Yo3Kjwp1#h{-MMDw0V-p>?fyg-PC1SOC}Ljb@2X?RZDk$pbq< z=iSawe=wDD5|GB>xE&e@RQ}WHSlRwZrDh3OP6Yuc^JC z&-@x-_-llYI}R)DfTwm^`&nMKt?Z{{&Eb$gG9H0jgp1G+hmOuoAK!>)#5>8iAE82b z-COX)x78501Qe#?Tk@@q*a>y%lY^U72S2t|#7ywU8KuhLrN$;616jlKPT>Sxhi|T3P`-YT@LZfiw4^5-Pqc!~iQYsA{*G$DO~uD3U>WHe*uRHA2v|eD zTuNQh=6~PA`eUk{B2=BQ^fH_}o5G{St)v=l)t*fk1Mqy|aC+UXe1;fs1j8U-i6jO^ zLv$mGyHISue1C2j{2zf(w2=OPh9)%VCJi4a@WKCynlyfLS6w3Mxa{lV{PQjxs*8c{ z#hVw=-#PSO)DS~ccqBs`s6Y^v#oNqYHu9 z-w_C?!QHoqrwQd43SpG%jSYE~=Vs<+7jGe7HFIrxQHHo+-@g!|#C{z$H-mA!&Yd3O zoz69=$Rr@iNjKN|Sif81E1~OjKFif`FKdT8cd5H5J5SuD)L=G~q^I(@smtPVRQ|63 z7gc^kzOEGYD>8_`QhZ(>N|CCEh{kY$q{lCe`GM?Wi!0%9zC#Ws|!eJr^+!WoSOw`ZGtbh#tWn|;#I{TX#N&XkqE8B(P7B>_o zUFIfPe0fVRuN!5MB(InfMKnSEtU3Ih4C0GF^lgqnZ!zqPiZ AQ~&?~ diff --git a/mace-bench/3rdparty/SevenNet/sevenn/_const.py b/mace-bench/3rdparty/SevenNet/sevenn/_const.py index 528a77b..284b64e 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/_const.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/_const.py @@ -1,310 +1,310 @@ -import os -from enum import Enum -from typing import Dict - -import torch - -import sevenn._keys as KEY -from sevenn.nn.activation import ShiftedSoftPlus - -NUM_UNIV_ELEMENT = 119 # Z = 0 ~ 118 - -IMPLEMENTED_RADIAL_BASIS = ['bessel'] -IMPLEMENTED_CUTOFF_FUNCTION = ['poly_cut', 'XPLOR'] -# TODO: support None. This became difficult because of parallel model -IMPLEMENTED_SELF_CONNECTION_TYPE = ['nequip', 'linear'] -IMPLEMENTED_INTERACTION_TYPE = ['nequip'] - -IMPLEMENTED_SHIFT = ['per_atom_energy_mean', 'elemwise_reference_energies'] -IMPLEMENTED_SCALE = ['force_rms', 'per_atom_energy_std', 'elemwise_force_rms'] - -SUPPORTING_METRICS = ['RMSE', 'ComponentRMSE', 'MAE', 'Loss'] -SUPPORTING_ERROR_TYPES = [ - 'TotalEnergy', - 'Energy', - 'Force', - 'Stress', - 'Stress_GPa', - 'TotalLoss', -] - -IMPLEMENTED_MODEL = ['E3_equivariant_model'] - -# string input to real torch function -ACTIVATION = { - 'relu': torch.nn.functional.relu, - 'silu': torch.nn.functional.silu, - 'tanh': torch.tanh, - 'abs': torch.abs, - 'ssp': ShiftedSoftPlus, - 'sigmoid': torch.sigmoid, - 'elu': torch.nn.functional.elu, -} -ACTIVATION_FOR_EVEN = { - 'ssp': ShiftedSoftPlus, - 'silu': torch.nn.functional.silu, -} -ACTIVATION_FOR_ODD = {'tanh': torch.tanh, 'abs': torch.abs} -ACTIVATION_DICT = {'e': ACTIVATION_FOR_EVEN, 'o': ACTIVATION_FOR_ODD} - -_prefix = os.path.abspath(f'{os.path.dirname(__file__)}/pretrained_potentials') -SEVENNET_0_11Jul2024 = f'{_prefix}/SevenNet_0__11Jul2024/checkpoint_sevennet_0.pth' -SEVENNET_0_22May2024 = f'{_prefix}/SevenNet_0__22May2024/checkpoint_sevennet_0.pth' -SEVENNET_l3i5 = f'{_prefix}/SevenNet_l3i5/checkpoint_l3i5.pth' -SEVENNET_MF_0 = f'{_prefix}/SevenNet_MF_0/checkpoint_sevennet_mf_0.pth' -SEVENNET_MF_ompa = f'{_prefix}/SevenNet_MF_ompa/checkpoint_sevennet_mf_ompa.pth' -SEVENNET_omat = f'{_prefix}/SevenNet_omat/checkpoint_sevennet_omat.pth' - -_git_prefix = 'https://github.com/MDIL-SNU/SevenNet/releases/download' -CHECKPOINT_DOWNLOAD_LINKS = { - SEVENNET_MF_ompa: f'{_git_prefix}/v0.11.0.cp/checkpoint_sevennet_mf_ompa.pth', - SEVENNET_omat: f'{_git_prefix}/v0.11.0.cp/checkpoint_sevennet_omat.pth', -} -# to avoid torch script to compile torch_geometry.data -AtomGraphDataType = Dict[str, torch.Tensor] - - -class LossType(Enum): # only used for train_v1, do not use it afterwards - ENERGY = 'energy' # eV or eV/atom - FORCE = 'force' # eV/A - STRESS = 'stress' # kB - - -def error_record_condition(x): - if type(x) is not list: - return False - for v in x: - if type(v) is not list or len(v) != 2: - return False - if v[0] not in SUPPORTING_ERROR_TYPES: - return False - if v[0] == 'TotalLoss': - continue - if v[1] not in SUPPORTING_METRICS: - return False - return True - - -DEFAULT_E3_EQUIVARIANT_MODEL_CONFIG = { - KEY.CUTOFF: 4.5, - KEY.NODE_FEATURE_MULTIPLICITY: 32, - KEY.IRREPS_MANUAL: False, - KEY.LMAX: 1, - KEY.LMAX_EDGE: -1, # -1 means lmax_edge = lmax - KEY.LMAX_NODE: -1, # -1 means lmax_node = lmax - KEY.IS_PARITY: True, - KEY.NUM_CONVOLUTION: 3, - KEY.RADIAL_BASIS: { - KEY.RADIAL_BASIS_NAME: 'bessel', - }, - KEY.CUTOFF_FUNCTION: { - KEY.CUTOFF_FUNCTION_NAME: 'poly_cut', - }, - KEY.ACTIVATION_RADIAL: 'silu', - KEY.ACTIVATION_SCARLAR: {'e': 'silu', 'o': 'tanh'}, - KEY.ACTIVATION_GATE: {'e': 'silu', 'o': 'tanh'}, - KEY.CONVOLUTION_WEIGHT_NN_HIDDEN_NEURONS: [64, 64], - # KEY.AVG_NUM_NEIGH: True, # deprecated - # KEY.TRAIN_AVG_NUM_NEIGH: False, # deprecated - KEY.CONV_DENOMINATOR: 'avg_num_neigh', - KEY.TRAIN_DENOMINTAOR: False, - KEY.TRAIN_SHIFT_SCALE: False, - # KEY.OPTIMIZE_BY_REDUCE: True, # deprecated, always True - KEY.USE_BIAS_IN_LINEAR: False, - KEY.USE_MODAL_NODE_EMBEDDING: False, - KEY.USE_MODAL_SELF_INTER_INTRO: False, - KEY.USE_MODAL_SELF_INTER_OUTRO: False, - KEY.USE_MODAL_OUTPUT_BLOCK: False, - KEY.READOUT_AS_FCN: False, - # Applied af readout as fcn is True - KEY.READOUT_FCN_HIDDEN_NEURONS: [30, 30], - KEY.READOUT_FCN_ACTIVATION: 'relu', - KEY.SELF_CONNECTION_TYPE: 'nequip', - KEY.INTERACTION_TYPE: 'nequip', - KEY._NORMALIZE_SPH: True, - KEY.CUEQUIVARIANCE_CONFIG: {}, -} - - -# Basically, "If provided, it should be type of ..." -MODEL_CONFIG_CONDITION = { - KEY.NODE_FEATURE_MULTIPLICITY: int, - KEY.LMAX: int, - KEY.LMAX_EDGE: int, - KEY.LMAX_NODE: int, - KEY.IS_PARITY: bool, - KEY.RADIAL_BASIS: { - KEY.RADIAL_BASIS_NAME: lambda x: x in IMPLEMENTED_RADIAL_BASIS, - }, - KEY.CUTOFF_FUNCTION: { - KEY.CUTOFF_FUNCTION_NAME: lambda x: x in IMPLEMENTED_CUTOFF_FUNCTION, - }, - KEY.CUTOFF: float, - KEY.NUM_CONVOLUTION: int, - KEY.CONV_DENOMINATOR: lambda x: isinstance(x, float) - or x - in [ - 'avg_num_neigh', - 'sqrt_avg_num_neigh', - ], - KEY.CONVOLUTION_WEIGHT_NN_HIDDEN_NEURONS: list, - KEY.TRAIN_SHIFT_SCALE: bool, - KEY.TRAIN_DENOMINTAOR: bool, - KEY.USE_BIAS_IN_LINEAR: bool, - KEY.USE_MODAL_NODE_EMBEDDING: bool, - KEY.USE_MODAL_SELF_INTER_INTRO: bool, - KEY.USE_MODAL_SELF_INTER_OUTRO: bool, - KEY.USE_MODAL_OUTPUT_BLOCK: bool, - KEY.READOUT_AS_FCN: bool, - KEY.READOUT_FCN_HIDDEN_NEURONS: list, - KEY.READOUT_FCN_ACTIVATION: str, - KEY.ACTIVATION_RADIAL: str, - KEY.SELF_CONNECTION_TYPE: lambda x: ( - x in IMPLEMENTED_SELF_CONNECTION_TYPE - or ( - isinstance(x, list) - and all(sc in IMPLEMENTED_SELF_CONNECTION_TYPE for sc in x) - ) - ), - KEY.INTERACTION_TYPE: lambda x: x in IMPLEMENTED_INTERACTION_TYPE, - KEY._NORMALIZE_SPH: bool, - KEY.CUEQUIVARIANCE_CONFIG: dict, -} - - -def model_defaults(config): - defaults = DEFAULT_E3_EQUIVARIANT_MODEL_CONFIG - - if KEY.READOUT_AS_FCN not in config: - config[KEY.READOUT_AS_FCN] = defaults[KEY.READOUT_AS_FCN] - if config[KEY.READOUT_AS_FCN] is False: - defaults.pop(KEY.READOUT_FCN_ACTIVATION, None) - defaults.pop(KEY.READOUT_FCN_HIDDEN_NEURONS, None) - - return defaults - - -DEFAULT_DATA_CONFIG = { - KEY.DTYPE: 'single', - KEY.DATA_FORMAT: 'ase', - KEY.DATA_FORMAT_ARGS: {}, - KEY.SAVE_DATASET: False, - KEY.SAVE_BY_LABEL: False, - KEY.SAVE_BY_TRAIN_VALID: False, - KEY.RATIO: 0.0, - KEY.BATCH_SIZE: 6, - KEY.PREPROCESS_NUM_CORES: 1, - KEY.COMPUTE_STATISTICS: True, - KEY.DATASET_TYPE: 'graph', - # KEY.USE_SPECIES_WISE_SHIFT_SCALE: False, - KEY.USE_MODAL_WISE_SHIFT: False, - KEY.USE_MODAL_WISE_SCALE: False, - KEY.SHIFT: 'per_atom_energy_mean', - KEY.SCALE: 'force_rms', - # KEY.DATA_SHUFFLE: True, - # KEY.DATA_WEIGHT: False, - # KEY.DATA_MODALITY: False, -} - -DATA_CONFIG_CONDITION = { - KEY.DTYPE: str, - KEY.DATA_FORMAT: str, - KEY.DATA_FORMAT_ARGS: dict, - KEY.SAVE_DATASET: str, - KEY.SAVE_BY_LABEL: bool, - KEY.SAVE_BY_TRAIN_VALID: bool, - KEY.RATIO: float, - KEY.BATCH_SIZE: int, - KEY.PREPROCESS_NUM_CORES: int, - KEY.DATASET_TYPE: lambda x: x in ['graph', 'atoms'], - # KEY.USE_SPECIES_WISE_SHIFT_SCALE: bool, - KEY.SHIFT: lambda x: type(x) in [float, list] or x in IMPLEMENTED_SHIFT, - KEY.SCALE: lambda x: type(x) in [float, list] or x in IMPLEMENTED_SCALE, - KEY.USE_MODAL_WISE_SHIFT: bool, - KEY.USE_MODAL_WISE_SCALE: bool, - # KEY.DATA_SHUFFLE: bool, - KEY.COMPUTE_STATISTICS: bool, - # KEY.DATA_WEIGHT: bool, - # KEY.DATA_MODALITY: bool, -} - - -def data_defaults(config): - defaults = DEFAULT_DATA_CONFIG - if KEY.LOAD_VALIDSET in config: - defaults.pop(KEY.RATIO, None) - return defaults - - -DEFAULT_TRAINING_CONFIG = { - KEY.RANDOM_SEED: 1, - KEY.EPOCH: 300, - KEY.LOSS: 'mse', - KEY.LOSS_PARAM: {}, - KEY.OPTIMIZER: 'adam', - KEY.OPTIM_PARAM: {}, - KEY.SCHEDULER: 'exponentiallr', - KEY.SCHEDULER_PARAM: {}, - KEY.FORCE_WEIGHT: 0.1, - KEY.STRESS_WEIGHT: 1e-6, # SIMPLE-NN default - KEY.PER_EPOCH: 5, - # KEY.USE_TESTSET: False, - KEY.CONTINUE: { - KEY.CHECKPOINT: False, - KEY.RESET_OPTIMIZER: False, - KEY.RESET_SCHEDULER: False, - KEY.RESET_EPOCH: False, - KEY.USE_STATISTIC_VALUES_OF_CHECKPOINT: True, - KEY.USE_STATISTIC_VALUES_FOR_CP_MODAL_ONLY: True, - }, - # KEY.DEFAULT_MODAL: 'common', - KEY.CSV_LOG: 'log.csv', - KEY.NUM_WORKERS: 0, - KEY.IS_TRAIN_STRESS: True, - KEY.TRAIN_SHUFFLE: True, - KEY.ERROR_RECORD: [ - ['Energy', 'RMSE'], - ['Force', 'RMSE'], - ['Stress', 'RMSE'], - ['TotalLoss', 'None'], - ], - KEY.BEST_METRIC: 'TotalLoss', - KEY.USE_WEIGHT: False, - KEY.USE_MODALITY: False, -} - - -TRAINING_CONFIG_CONDITION = { - KEY.RANDOM_SEED: int, - KEY.EPOCH: int, - KEY.FORCE_WEIGHT: float, - KEY.STRESS_WEIGHT: float, - KEY.USE_TESTSET: None, # Not used - KEY.NUM_WORKERS: int, - KEY.PER_EPOCH: int, - KEY.CONTINUE: { - KEY.CHECKPOINT: str, - KEY.RESET_OPTIMIZER: bool, - KEY.RESET_SCHEDULER: bool, - KEY.RESET_EPOCH: bool, - KEY.USE_STATISTIC_VALUES_OF_CHECKPOINT: bool, - KEY.USE_STATISTIC_VALUES_FOR_CP_MODAL_ONLY: bool, - }, - KEY.DEFAULT_MODAL: str, - KEY.IS_TRAIN_STRESS: bool, - KEY.TRAIN_SHUFFLE: bool, - KEY.ERROR_RECORD: error_record_condition, - KEY.BEST_METRIC: str, - KEY.CSV_LOG: str, - KEY.USE_MODALITY: bool, - KEY.USE_WEIGHT: bool, -} - - -def train_defaults(config): - defaults = DEFAULT_TRAINING_CONFIG - if KEY.IS_TRAIN_STRESS not in config: - config[KEY.IS_TRAIN_STRESS] = defaults[KEY.IS_TRAIN_STRESS] - if not config[KEY.IS_TRAIN_STRESS]: - defaults.pop(KEY.STRESS_WEIGHT, None) - return defaults +import os +from enum import Enum +from typing import Dict + +import torch + +import sevenn._keys as KEY +from sevenn.nn.activation import ShiftedSoftPlus + +NUM_UNIV_ELEMENT = 119 # Z = 0 ~ 118 + +IMPLEMENTED_RADIAL_BASIS = ['bessel'] +IMPLEMENTED_CUTOFF_FUNCTION = ['poly_cut', 'XPLOR'] +# TODO: support None. This became difficult because of parallel model +IMPLEMENTED_SELF_CONNECTION_TYPE = ['nequip', 'linear'] +IMPLEMENTED_INTERACTION_TYPE = ['nequip'] + +IMPLEMENTED_SHIFT = ['per_atom_energy_mean', 'elemwise_reference_energies'] +IMPLEMENTED_SCALE = ['force_rms', 'per_atom_energy_std', 'elemwise_force_rms'] + +SUPPORTING_METRICS = ['RMSE', 'ComponentRMSE', 'MAE', 'Loss'] +SUPPORTING_ERROR_TYPES = [ + 'TotalEnergy', + 'Energy', + 'Force', + 'Stress', + 'Stress_GPa', + 'TotalLoss', +] + +IMPLEMENTED_MODEL = ['E3_equivariant_model'] + +# string input to real torch function +ACTIVATION = { + 'relu': torch.nn.functional.relu, + 'silu': torch.nn.functional.silu, + 'tanh': torch.tanh, + 'abs': torch.abs, + 'ssp': ShiftedSoftPlus, + 'sigmoid': torch.sigmoid, + 'elu': torch.nn.functional.elu, +} +ACTIVATION_FOR_EVEN = { + 'ssp': ShiftedSoftPlus, + 'silu': torch.nn.functional.silu, +} +ACTIVATION_FOR_ODD = {'tanh': torch.tanh, 'abs': torch.abs} +ACTIVATION_DICT = {'e': ACTIVATION_FOR_EVEN, 'o': ACTIVATION_FOR_ODD} + +_prefix = os.path.abspath(f'{os.path.dirname(__file__)}/pretrained_potentials') +SEVENNET_0_11Jul2024 = f'{_prefix}/SevenNet_0__11Jul2024/checkpoint_sevennet_0.pth' +SEVENNET_0_22May2024 = f'{_prefix}/SevenNet_0__22May2024/checkpoint_sevennet_0.pth' +SEVENNET_l3i5 = f'{_prefix}/SevenNet_l3i5/checkpoint_l3i5.pth' +SEVENNET_MF_0 = f'{_prefix}/SevenNet_MF_0/checkpoint_sevennet_mf_0.pth' +SEVENNET_MF_ompa = f'{_prefix}/SevenNet_MF_ompa/checkpoint_sevennet_mf_ompa.pth' +SEVENNET_omat = f'{_prefix}/SevenNet_omat/checkpoint_sevennet_omat.pth' + +_git_prefix = 'https://github.com/MDIL-SNU/SevenNet/releases/download' +CHECKPOINT_DOWNLOAD_LINKS = { + SEVENNET_MF_ompa: f'{_git_prefix}/v0.11.0.cp/checkpoint_sevennet_mf_ompa.pth', + SEVENNET_omat: f'{_git_prefix}/v0.11.0.cp/checkpoint_sevennet_omat.pth', +} +# to avoid torch script to compile torch_geometry.data +AtomGraphDataType = Dict[str, torch.Tensor] + + +class LossType(Enum): # only used for train_v1, do not use it afterwards + ENERGY = 'energy' # eV or eV/atom + FORCE = 'force' # eV/A + STRESS = 'stress' # kB + + +def error_record_condition(x): + if type(x) is not list: + return False + for v in x: + if type(v) is not list or len(v) != 2: + return False + if v[0] not in SUPPORTING_ERROR_TYPES: + return False + if v[0] == 'TotalLoss': + continue + if v[1] not in SUPPORTING_METRICS: + return False + return True + + +DEFAULT_E3_EQUIVARIANT_MODEL_CONFIG = { + KEY.CUTOFF: 4.5, + KEY.NODE_FEATURE_MULTIPLICITY: 32, + KEY.IRREPS_MANUAL: False, + KEY.LMAX: 1, + KEY.LMAX_EDGE: -1, # -1 means lmax_edge = lmax + KEY.LMAX_NODE: -1, # -1 means lmax_node = lmax + KEY.IS_PARITY: True, + KEY.NUM_CONVOLUTION: 3, + KEY.RADIAL_BASIS: { + KEY.RADIAL_BASIS_NAME: 'bessel', + }, + KEY.CUTOFF_FUNCTION: { + KEY.CUTOFF_FUNCTION_NAME: 'poly_cut', + }, + KEY.ACTIVATION_RADIAL: 'silu', + KEY.ACTIVATION_SCARLAR: {'e': 'silu', 'o': 'tanh'}, + KEY.ACTIVATION_GATE: {'e': 'silu', 'o': 'tanh'}, + KEY.CONVOLUTION_WEIGHT_NN_HIDDEN_NEURONS: [64, 64], + # KEY.AVG_NUM_NEIGH: True, # deprecated + # KEY.TRAIN_AVG_NUM_NEIGH: False, # deprecated + KEY.CONV_DENOMINATOR: 'avg_num_neigh', + KEY.TRAIN_DENOMINTAOR: False, + KEY.TRAIN_SHIFT_SCALE: False, + # KEY.OPTIMIZE_BY_REDUCE: True, # deprecated, always True + KEY.USE_BIAS_IN_LINEAR: False, + KEY.USE_MODAL_NODE_EMBEDDING: False, + KEY.USE_MODAL_SELF_INTER_INTRO: False, + KEY.USE_MODAL_SELF_INTER_OUTRO: False, + KEY.USE_MODAL_OUTPUT_BLOCK: False, + KEY.READOUT_AS_FCN: False, + # Applied af readout as fcn is True + KEY.READOUT_FCN_HIDDEN_NEURONS: [30, 30], + KEY.READOUT_FCN_ACTIVATION: 'relu', + KEY.SELF_CONNECTION_TYPE: 'nequip', + KEY.INTERACTION_TYPE: 'nequip', + KEY._NORMALIZE_SPH: True, + KEY.CUEQUIVARIANCE_CONFIG: {}, +} + + +# Basically, "If provided, it should be type of ..." +MODEL_CONFIG_CONDITION = { + KEY.NODE_FEATURE_MULTIPLICITY: int, + KEY.LMAX: int, + KEY.LMAX_EDGE: int, + KEY.LMAX_NODE: int, + KEY.IS_PARITY: bool, + KEY.RADIAL_BASIS: { + KEY.RADIAL_BASIS_NAME: lambda x: x in IMPLEMENTED_RADIAL_BASIS, + }, + KEY.CUTOFF_FUNCTION: { + KEY.CUTOFF_FUNCTION_NAME: lambda x: x in IMPLEMENTED_CUTOFF_FUNCTION, + }, + KEY.CUTOFF: float, + KEY.NUM_CONVOLUTION: int, + KEY.CONV_DENOMINATOR: lambda x: isinstance(x, float) + or x + in [ + 'avg_num_neigh', + 'sqrt_avg_num_neigh', + ], + KEY.CONVOLUTION_WEIGHT_NN_HIDDEN_NEURONS: list, + KEY.TRAIN_SHIFT_SCALE: bool, + KEY.TRAIN_DENOMINTAOR: bool, + KEY.USE_BIAS_IN_LINEAR: bool, + KEY.USE_MODAL_NODE_EMBEDDING: bool, + KEY.USE_MODAL_SELF_INTER_INTRO: bool, + KEY.USE_MODAL_SELF_INTER_OUTRO: bool, + KEY.USE_MODAL_OUTPUT_BLOCK: bool, + KEY.READOUT_AS_FCN: bool, + KEY.READOUT_FCN_HIDDEN_NEURONS: list, + KEY.READOUT_FCN_ACTIVATION: str, + KEY.ACTIVATION_RADIAL: str, + KEY.SELF_CONNECTION_TYPE: lambda x: ( + x in IMPLEMENTED_SELF_CONNECTION_TYPE + or ( + isinstance(x, list) + and all(sc in IMPLEMENTED_SELF_CONNECTION_TYPE for sc in x) + ) + ), + KEY.INTERACTION_TYPE: lambda x: x in IMPLEMENTED_INTERACTION_TYPE, + KEY._NORMALIZE_SPH: bool, + KEY.CUEQUIVARIANCE_CONFIG: dict, +} + + +def model_defaults(config): + defaults = DEFAULT_E3_EQUIVARIANT_MODEL_CONFIG + + if KEY.READOUT_AS_FCN not in config: + config[KEY.READOUT_AS_FCN] = defaults[KEY.READOUT_AS_FCN] + if config[KEY.READOUT_AS_FCN] is False: + defaults.pop(KEY.READOUT_FCN_ACTIVATION, None) + defaults.pop(KEY.READOUT_FCN_HIDDEN_NEURONS, None) + + return defaults + + +DEFAULT_DATA_CONFIG = { + KEY.DTYPE: 'single', + KEY.DATA_FORMAT: 'ase', + KEY.DATA_FORMAT_ARGS: {}, + KEY.SAVE_DATASET: False, + KEY.SAVE_BY_LABEL: False, + KEY.SAVE_BY_TRAIN_VALID: False, + KEY.RATIO: 0.0, + KEY.BATCH_SIZE: 6, + KEY.PREPROCESS_NUM_CORES: 1, + KEY.COMPUTE_STATISTICS: True, + KEY.DATASET_TYPE: 'graph', + # KEY.USE_SPECIES_WISE_SHIFT_SCALE: False, + KEY.USE_MODAL_WISE_SHIFT: False, + KEY.USE_MODAL_WISE_SCALE: False, + KEY.SHIFT: 'per_atom_energy_mean', + KEY.SCALE: 'force_rms', + # KEY.DATA_SHUFFLE: True, + # KEY.DATA_WEIGHT: False, + # KEY.DATA_MODALITY: False, +} + +DATA_CONFIG_CONDITION = { + KEY.DTYPE: str, + KEY.DATA_FORMAT: str, + KEY.DATA_FORMAT_ARGS: dict, + KEY.SAVE_DATASET: str, + KEY.SAVE_BY_LABEL: bool, + KEY.SAVE_BY_TRAIN_VALID: bool, + KEY.RATIO: float, + KEY.BATCH_SIZE: int, + KEY.PREPROCESS_NUM_CORES: int, + KEY.DATASET_TYPE: lambda x: x in ['graph', 'atoms'], + # KEY.USE_SPECIES_WISE_SHIFT_SCALE: bool, + KEY.SHIFT: lambda x: type(x) in [float, list] or x in IMPLEMENTED_SHIFT, + KEY.SCALE: lambda x: type(x) in [float, list] or x in IMPLEMENTED_SCALE, + KEY.USE_MODAL_WISE_SHIFT: bool, + KEY.USE_MODAL_WISE_SCALE: bool, + # KEY.DATA_SHUFFLE: bool, + KEY.COMPUTE_STATISTICS: bool, + # KEY.DATA_WEIGHT: bool, + # KEY.DATA_MODALITY: bool, +} + + +def data_defaults(config): + defaults = DEFAULT_DATA_CONFIG + if KEY.LOAD_VALIDSET in config: + defaults.pop(KEY.RATIO, None) + return defaults + + +DEFAULT_TRAINING_CONFIG = { + KEY.RANDOM_SEED: 1, + KEY.EPOCH: 300, + KEY.LOSS: 'mse', + KEY.LOSS_PARAM: {}, + KEY.OPTIMIZER: 'adam', + KEY.OPTIM_PARAM: {}, + KEY.SCHEDULER: 'exponentiallr', + KEY.SCHEDULER_PARAM: {}, + KEY.FORCE_WEIGHT: 0.1, + KEY.STRESS_WEIGHT: 1e-6, # SIMPLE-NN default + KEY.PER_EPOCH: 5, + # KEY.USE_TESTSET: False, + KEY.CONTINUE: { + KEY.CHECKPOINT: False, + KEY.RESET_OPTIMIZER: False, + KEY.RESET_SCHEDULER: False, + KEY.RESET_EPOCH: False, + KEY.USE_STATISTIC_VALUES_OF_CHECKPOINT: True, + KEY.USE_STATISTIC_VALUES_FOR_CP_MODAL_ONLY: True, + }, + # KEY.DEFAULT_MODAL: 'common', + KEY.CSV_LOG: 'log.csv', + KEY.NUM_WORKERS: 0, + KEY.IS_TRAIN_STRESS: True, + KEY.TRAIN_SHUFFLE: True, + KEY.ERROR_RECORD: [ + ['Energy', 'RMSE'], + ['Force', 'RMSE'], + ['Stress', 'RMSE'], + ['TotalLoss', 'None'], + ], + KEY.BEST_METRIC: 'TotalLoss', + KEY.USE_WEIGHT: False, + KEY.USE_MODALITY: False, +} + + +TRAINING_CONFIG_CONDITION = { + KEY.RANDOM_SEED: int, + KEY.EPOCH: int, + KEY.FORCE_WEIGHT: float, + KEY.STRESS_WEIGHT: float, + KEY.USE_TESTSET: None, # Not used + KEY.NUM_WORKERS: int, + KEY.PER_EPOCH: int, + KEY.CONTINUE: { + KEY.CHECKPOINT: str, + KEY.RESET_OPTIMIZER: bool, + KEY.RESET_SCHEDULER: bool, + KEY.RESET_EPOCH: bool, + KEY.USE_STATISTIC_VALUES_OF_CHECKPOINT: bool, + KEY.USE_STATISTIC_VALUES_FOR_CP_MODAL_ONLY: bool, + }, + KEY.DEFAULT_MODAL: str, + KEY.IS_TRAIN_STRESS: bool, + KEY.TRAIN_SHUFFLE: bool, + KEY.ERROR_RECORD: error_record_condition, + KEY.BEST_METRIC: str, + KEY.CSV_LOG: str, + KEY.USE_MODALITY: bool, + KEY.USE_WEIGHT: bool, +} + + +def train_defaults(config): + defaults = DEFAULT_TRAINING_CONFIG + if KEY.IS_TRAIN_STRESS not in config: + config[KEY.IS_TRAIN_STRESS] = defaults[KEY.IS_TRAIN_STRESS] + if not config[KEY.IS_TRAIN_STRESS]: + defaults.pop(KEY.STRESS_WEIGHT, None) + return defaults diff --git a/mace-bench/3rdparty/SevenNet/sevenn/_keys.py b/mace-bench/3rdparty/SevenNet/sevenn/_keys.py index f91b6de..1ff5a61 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/_keys.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/_keys.py @@ -1,226 +1,226 @@ -""" -How to add new feature? - -1. Add new key to this file. -2. Add new key to _const.py -2.1. if the type of input is consistent, - write adequate condition and default to _const.py. -2.2. if the type of input is not consistent, - you must add your own input validation code to - parse_input.py -""" - -from typing import Final - -# see -# https://github.com/pytorch/pytorch/issues/52312 -# for FYI - -# ~~ keys ~~ # -# PyG : primitive key of torch_geometric.data.Data type - -# ==================================================# -# ~~~~~~~~~~~~~~~~~ KEY for data ~~~~~~~~~~~~~~~~~~ # -# ==================================================# -# some raw properties of graph -ATOMIC_NUMBERS: Final[str] = 'atomic_numbers' # (N) -POS: Final[str] = 'pos' # (N, 3) PyG -CELL: Final[str] = 'cell_lattice_vectors' # (3, 3) -CELL_SHIFT: Final[str] = 'pbc_shift' # (N, 3) -CELL_VOLUME: Final[str] = 'cell_volume' - -EDGE_VEC: Final[str] = 'edge_vec' # (N_edge, 3) -EDGE_LENGTH: Final[str] = 'edge_length' # (N_edge, 1) - -# some primary data of graph -EDGE_IDX: Final[str] = 'edge_index' # (2, N_edge) PyG -ATOM_TYPE: Final[str] = 'atom_type' # (N) one-hot index of nodes -NODE_FEATURE: Final[str] = 'x' # (N, ?) PyG -NODE_FEATURE_GHOST: Final[str] = 'x_ghost' -NODE_ATTR: Final[str] = 'node_attr' # (N, N_species) from one_hot -MODAL_ATTR: Final[str] = ( - 'modal_attr' # (1, N_modalities) for handling multi-modal -) -MODAL_TYPE: Final[str] = 'modal_type' # (1) one-hot index of modal -EDGE_ATTR: Final[str] = 'edge_attr' # (from spherical harmonics) -EDGE_EMBEDDING: Final[str] = 'edge_embedding' # (from edge embedding) - -# inputs of loss function -ENERGY: Final[str] = 'total_energy' # (1) -FORCE: Final[str] = 'force_of_atoms' # (N, 3) -STRESS: Final[str] = 'stress' # (6) - -# This is for training, per atom scale. -SCALED_ENERGY: Final[str] = 'scaled_total_energy' - -# general outputs of models -SCALED_ATOMIC_ENERGY: Final[str] = 'scaled_atomic_energy' -ATOMIC_ENERGY: Final[str] = 'atomic_energy' -PRED_TOTAL_ENERGY: Final[str] = 'inferred_total_energy' - -PRED_PER_ATOM_ENERGY: Final[str] = 'inferred_per_atom_energy' -PER_ATOM_ENERGY: Final[str] = 'per_atom_energy' - -PRED_FORCE: Final[str] = 'inferred_force' -SCALED_FORCE: Final[str] = 'scaled_force' - -PRED_STRESS: Final[str] = 'inferred_stress' -SCALED_STRESS: Final[str] = 'scaled_stress' - -# very general data property for AtomGraphData -NUM_ATOMS: Final[str] = 'num_atoms' # int -NUM_GHOSTS: Final[str] = 'num_ghosts' -NLOCAL: Final[str] = 'nlocal' # only for lammps parallel, must be on cpu -USER_LABEL: Final[str] = 'user_label' -DATA_WEIGHT: Final[str] = 'data_weight' # weight for given data -DATA_MODALITY: Final[str] = ( - 'data_modality' # modality of given data. e.g. PBE and SCAN -) -BATCH: Final[str] = 'batch' - -TAG = 'tag' # replace USER_LABEL - -# etc -SELF_CONNECTION_TEMP: Final[str] = 'self_cont_tmp' -BATCH_SIZE: Final[str] = 'batch_size' -INFO: Final[str] = 'data_info' - -# something special -LABEL_NONE: Final[str] = 'No_label' - -# ==================================================# -# ~~~~~~ KEY for train/data configuration ~~~~~~~~ # -# ==================================================# -PREPROCESS_NUM_CORES = 'preprocess_num_cores' -SAVE_DATASET = 'save_dataset_path' -SAVE_BY_LABEL = 'save_by_label' -SAVE_BY_TRAIN_VALID = 'save_by_train_valid' -DATA_FORMAT = 'data_format' -DATA_FORMAT_ARGS = 'data_format_args' -STRUCTURE_LIST = 'structure_list' -LOAD_DATASET = 'load_dataset_path' # not used in v2 -LOAD_TRAINSET = 'load_trainset_path' -LOAD_VALIDSET = 'load_validset_path' -LOAD_TESTSET = 'load_testset_path' -FORMAT_OUTPUTS = 'format_outputs_for_ase' -COMPUTE_STATISTICS = 'compute_statistics' -DATASET_TYPE = 'dataset_type' - -RANDOM_SEED = 'random_seed' -RATIO = 'data_divide_ratio' -USE_TESTSET = 'use_testset' -EPOCH = 'epoch' -LOSS = 'loss' -LOSS_PARAM = 'loss_param' -OPTIMIZER = 'optimizer' -OPTIM_PARAM = 'optim_param' -SCHEDULER = 'scheduler' -SCHEDULER_PARAM = 'scheduler_param' -FORCE_WEIGHT = 'force_loss_weight' -STRESS_WEIGHT = 'stress_loss_weight' -DEVICE = 'device' -DTYPE = 'dtype' - -TRAIN_SHUFFLE = 'train_shuffle' - -IS_TRAIN_STRESS = 'is_train_stress' - -CONTINUE = 'continue' -CHECKPOINT = 'checkpoint' -RESET_OPTIMIZER = 'reset_optimizer' -RESET_SCHEDULER = 'reset_scheduler' -RESET_EPOCH = 'reset_epoch' -USE_STATISTIC_VALUES_OF_CHECKPOINT = 'use_statistic_values_of_checkpoint' -USE_STATISTIC_VALUES_FOR_CP_MODAL_ONLY = ( - 'use_statistic_values_for_cp_modal_only' -) - -CSV_LOG = 'csv_log' - -ERROR_RECORD = 'error_record' -BEST_METRIC = 'best_metric' - -NUM_WORKERS = 'num_workers' # not work - -RANK = 'rank' -LOCAL_RANK = 'local_rank' -WORLD_SIZE = 'world_size' -IS_DDP = 'is_ddp' -DDP_BACKEND = 'ddp_backend' -PER_EPOCH = 'per_epoch' - -USE_WEIGHT = 'use_weight' -USE_MODALITY = 'use_modality' -DEFAULT_MODAL = 'default_modal' - - -# ==================================================# -# ~~~~~~~~ KEY for model configuration ~~~~~~~~~~~ # -# ==================================================# -# ~~ global model configuration ~~ # -# note that these names are directly used for input.yaml for user input -MODEL_TYPE = '_model_type' -CUTOFF = 'cutoff' -CHEMICAL_SPECIES = 'chemical_species' -MODAL_LIST = 'modal_list' -CHEMICAL_SPECIES_BY_ATOMIC_NUMBER = '_chemical_species_by_atomic_number' -NUM_SPECIES = '_number_of_species' -NUM_MODALITIES = '_number_of_modalities' -TYPE_MAP = '_type_map' -MODAL_MAP = '_modal_map' - -# ~~ E3 equivariant model build configuration keys ~~ # -# see model_build default_config for type -IRREPS_MANUAL = 'irreps_manual' -NODE_FEATURE_MULTIPLICITY = 'channel' - -RADIAL_BASIS = 'radial_basis' -BESSEL_BASIS_NUM = 'bessel_basis_num' - -CUTOFF_FUNCTION = 'cutoff_function' -POLY_CUT_P = 'poly_cut_p_value' - -LMAX = 'lmax' -LMAX_EDGE = 'lmax_edge' -LMAX_NODE = 'lmax_node' -IS_PARITY = 'is_parity' -CONVOLUTION_WEIGHT_NN_HIDDEN_NEURONS = 'weight_nn_hidden_neurons' -NUM_CONVOLUTION = 'num_convolution_layer' -ACTIVATION_SCARLAR = 'act_scalar' -ACTIVATION_GATE = 'act_gate' -ACTIVATION_RADIAL = 'act_radial' - -SELF_CONNECTION_TYPE = 'self_connection_type' - -RADIAL_BASIS_NAME = 'radial_basis_name' -CUTOFF_FUNCTION_NAME = 'cutoff_function_name' - -USE_BIAS_IN_LINEAR = 'use_bias_in_linear' - -USE_MODAL_NODE_EMBEDDING = 'use_modal_node_embedding' -USE_MODAL_SELF_INTER_INTRO = 'use_modal_self_inter_intro' -USE_MODAL_SELF_INTER_OUTRO = 'use_modal_self_inter_outro' -USE_MODAL_OUTPUT_BLOCK = 'use_modal_output_block' - -READOUT_AS_FCN = 'readout_as_fcn' -READOUT_FCN_HIDDEN_NEURONS = 'readout_fcn_hidden_neurons' -READOUT_FCN_ACTIVATION = 'readout_fcn_activation' - -AVG_NUM_NEIGH = 'avg_num_neigh' -CONV_DENOMINATOR = 'conv_denominator' -SHIFT = 'shift' -SCALE = 'scale' - -USE_SPECIES_WISE_SHIFT_SCALE = 'use_species_wise_shift_scale' -USE_MODAL_WISE_SHIFT = 'use_modal_wise_shift' -USE_MODAL_WISE_SCALE = 'use_modal_wise_scale' - -TRAIN_SHIFT_SCALE = 'train_shift_scale' -TRAIN_DENOMINTAOR = 'train_denominator' -INTERACTION_TYPE = 'interaction_type' -TRAIN_AVG_NUM_NEIGH = 'train_avg_num_neigh' # deprecated - -CUEQUIVARIANCE_CONFIG = 'cuequivariance_config' - -_NORMALIZE_SPH = '_normalize_sph' -OPTIMIZE_BY_REDUCE = 'optimize_by_reduce' +""" +How to add new feature? + +1. Add new key to this file. +2. Add new key to _const.py +2.1. if the type of input is consistent, + write adequate condition and default to _const.py. +2.2. if the type of input is not consistent, + you must add your own input validation code to + parse_input.py +""" + +from typing import Final + +# see +# https://github.com/pytorch/pytorch/issues/52312 +# for FYI + +# ~~ keys ~~ # +# PyG : primitive key of torch_geometric.data.Data type + +# ==================================================# +# ~~~~~~~~~~~~~~~~~ KEY for data ~~~~~~~~~~~~~~~~~~ # +# ==================================================# +# some raw properties of graph +ATOMIC_NUMBERS: Final[str] = 'atomic_numbers' # (N) +POS: Final[str] = 'pos' # (N, 3) PyG +CELL: Final[str] = 'cell_lattice_vectors' # (3, 3) +CELL_SHIFT: Final[str] = 'pbc_shift' # (N, 3) +CELL_VOLUME: Final[str] = 'cell_volume' + +EDGE_VEC: Final[str] = 'edge_vec' # (N_edge, 3) +EDGE_LENGTH: Final[str] = 'edge_length' # (N_edge, 1) + +# some primary data of graph +EDGE_IDX: Final[str] = 'edge_index' # (2, N_edge) PyG +ATOM_TYPE: Final[str] = 'atom_type' # (N) one-hot index of nodes +NODE_FEATURE: Final[str] = 'x' # (N, ?) PyG +NODE_FEATURE_GHOST: Final[str] = 'x_ghost' +NODE_ATTR: Final[str] = 'node_attr' # (N, N_species) from one_hot +MODAL_ATTR: Final[str] = ( + 'modal_attr' # (1, N_modalities) for handling multi-modal +) +MODAL_TYPE: Final[str] = 'modal_type' # (1) one-hot index of modal +EDGE_ATTR: Final[str] = 'edge_attr' # (from spherical harmonics) +EDGE_EMBEDDING: Final[str] = 'edge_embedding' # (from edge embedding) + +# inputs of loss function +ENERGY: Final[str] = 'total_energy' # (1) +FORCE: Final[str] = 'force_of_atoms' # (N, 3) +STRESS: Final[str] = 'stress' # (6) + +# This is for training, per atom scale. +SCALED_ENERGY: Final[str] = 'scaled_total_energy' + +# general outputs of models +SCALED_ATOMIC_ENERGY: Final[str] = 'scaled_atomic_energy' +ATOMIC_ENERGY: Final[str] = 'atomic_energy' +PRED_TOTAL_ENERGY: Final[str] = 'inferred_total_energy' + +PRED_PER_ATOM_ENERGY: Final[str] = 'inferred_per_atom_energy' +PER_ATOM_ENERGY: Final[str] = 'per_atom_energy' + +PRED_FORCE: Final[str] = 'inferred_force' +SCALED_FORCE: Final[str] = 'scaled_force' + +PRED_STRESS: Final[str] = 'inferred_stress' +SCALED_STRESS: Final[str] = 'scaled_stress' + +# very general data property for AtomGraphData +NUM_ATOMS: Final[str] = 'num_atoms' # int +NUM_GHOSTS: Final[str] = 'num_ghosts' +NLOCAL: Final[str] = 'nlocal' # only for lammps parallel, must be on cpu +USER_LABEL: Final[str] = 'user_label' +DATA_WEIGHT: Final[str] = 'data_weight' # weight for given data +DATA_MODALITY: Final[str] = ( + 'data_modality' # modality of given data. e.g. PBE and SCAN +) +BATCH: Final[str] = 'batch' + +TAG = 'tag' # replace USER_LABEL + +# etc +SELF_CONNECTION_TEMP: Final[str] = 'self_cont_tmp' +BATCH_SIZE: Final[str] = 'batch_size' +INFO: Final[str] = 'data_info' + +# something special +LABEL_NONE: Final[str] = 'No_label' + +# ==================================================# +# ~~~~~~ KEY for train/data configuration ~~~~~~~~ # +# ==================================================# +PREPROCESS_NUM_CORES = 'preprocess_num_cores' +SAVE_DATASET = 'save_dataset_path' +SAVE_BY_LABEL = 'save_by_label' +SAVE_BY_TRAIN_VALID = 'save_by_train_valid' +DATA_FORMAT = 'data_format' +DATA_FORMAT_ARGS = 'data_format_args' +STRUCTURE_LIST = 'structure_list' +LOAD_DATASET = 'load_dataset_path' # not used in v2 +LOAD_TRAINSET = 'load_trainset_path' +LOAD_VALIDSET = 'load_validset_path' +LOAD_TESTSET = 'load_testset_path' +FORMAT_OUTPUTS = 'format_outputs_for_ase' +COMPUTE_STATISTICS = 'compute_statistics' +DATASET_TYPE = 'dataset_type' + +RANDOM_SEED = 'random_seed' +RATIO = 'data_divide_ratio' +USE_TESTSET = 'use_testset' +EPOCH = 'epoch' +LOSS = 'loss' +LOSS_PARAM = 'loss_param' +OPTIMIZER = 'optimizer' +OPTIM_PARAM = 'optim_param' +SCHEDULER = 'scheduler' +SCHEDULER_PARAM = 'scheduler_param' +FORCE_WEIGHT = 'force_loss_weight' +STRESS_WEIGHT = 'stress_loss_weight' +DEVICE = 'device' +DTYPE = 'dtype' + +TRAIN_SHUFFLE = 'train_shuffle' + +IS_TRAIN_STRESS = 'is_train_stress' + +CONTINUE = 'continue' +CHECKPOINT = 'checkpoint' +RESET_OPTIMIZER = 'reset_optimizer' +RESET_SCHEDULER = 'reset_scheduler' +RESET_EPOCH = 'reset_epoch' +USE_STATISTIC_VALUES_OF_CHECKPOINT = 'use_statistic_values_of_checkpoint' +USE_STATISTIC_VALUES_FOR_CP_MODAL_ONLY = ( + 'use_statistic_values_for_cp_modal_only' +) + +CSV_LOG = 'csv_log' + +ERROR_RECORD = 'error_record' +BEST_METRIC = 'best_metric' + +NUM_WORKERS = 'num_workers' # not work + +RANK = 'rank' +LOCAL_RANK = 'local_rank' +WORLD_SIZE = 'world_size' +IS_DDP = 'is_ddp' +DDP_BACKEND = 'ddp_backend' +PER_EPOCH = 'per_epoch' + +USE_WEIGHT = 'use_weight' +USE_MODALITY = 'use_modality' +DEFAULT_MODAL = 'default_modal' + + +# ==================================================# +# ~~~~~~~~ KEY for model configuration ~~~~~~~~~~~ # +# ==================================================# +# ~~ global model configuration ~~ # +# note that these names are directly used for input.yaml for user input +MODEL_TYPE = '_model_type' +CUTOFF = 'cutoff' +CHEMICAL_SPECIES = 'chemical_species' +MODAL_LIST = 'modal_list' +CHEMICAL_SPECIES_BY_ATOMIC_NUMBER = '_chemical_species_by_atomic_number' +NUM_SPECIES = '_number_of_species' +NUM_MODALITIES = '_number_of_modalities' +TYPE_MAP = '_type_map' +MODAL_MAP = '_modal_map' + +# ~~ E3 equivariant model build configuration keys ~~ # +# see model_build default_config for type +IRREPS_MANUAL = 'irreps_manual' +NODE_FEATURE_MULTIPLICITY = 'channel' + +RADIAL_BASIS = 'radial_basis' +BESSEL_BASIS_NUM = 'bessel_basis_num' + +CUTOFF_FUNCTION = 'cutoff_function' +POLY_CUT_P = 'poly_cut_p_value' + +LMAX = 'lmax' +LMAX_EDGE = 'lmax_edge' +LMAX_NODE = 'lmax_node' +IS_PARITY = 'is_parity' +CONVOLUTION_WEIGHT_NN_HIDDEN_NEURONS = 'weight_nn_hidden_neurons' +NUM_CONVOLUTION = 'num_convolution_layer' +ACTIVATION_SCARLAR = 'act_scalar' +ACTIVATION_GATE = 'act_gate' +ACTIVATION_RADIAL = 'act_radial' + +SELF_CONNECTION_TYPE = 'self_connection_type' + +RADIAL_BASIS_NAME = 'radial_basis_name' +CUTOFF_FUNCTION_NAME = 'cutoff_function_name' + +USE_BIAS_IN_LINEAR = 'use_bias_in_linear' + +USE_MODAL_NODE_EMBEDDING = 'use_modal_node_embedding' +USE_MODAL_SELF_INTER_INTRO = 'use_modal_self_inter_intro' +USE_MODAL_SELF_INTER_OUTRO = 'use_modal_self_inter_outro' +USE_MODAL_OUTPUT_BLOCK = 'use_modal_output_block' + +READOUT_AS_FCN = 'readout_as_fcn' +READOUT_FCN_HIDDEN_NEURONS = 'readout_fcn_hidden_neurons' +READOUT_FCN_ACTIVATION = 'readout_fcn_activation' + +AVG_NUM_NEIGH = 'avg_num_neigh' +CONV_DENOMINATOR = 'conv_denominator' +SHIFT = 'shift' +SCALE = 'scale' + +USE_SPECIES_WISE_SHIFT_SCALE = 'use_species_wise_shift_scale' +USE_MODAL_WISE_SHIFT = 'use_modal_wise_shift' +USE_MODAL_WISE_SCALE = 'use_modal_wise_scale' + +TRAIN_SHIFT_SCALE = 'train_shift_scale' +TRAIN_DENOMINTAOR = 'train_denominator' +INTERACTION_TYPE = 'interaction_type' +TRAIN_AVG_NUM_NEIGH = 'train_avg_num_neigh' # deprecated + +CUEQUIVARIANCE_CONFIG = 'cuequivariance_config' + +_NORMALIZE_SPH = '_normalize_sph' +OPTIMIZE_BY_REDUCE = 'optimize_by_reduce' diff --git a/mace-bench/3rdparty/SevenNet/sevenn/atom_graph_data.py b/mace-bench/3rdparty/SevenNet/sevenn/atom_graph_data.py index ee5b7bb..e0de629 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/atom_graph_data.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/atom_graph_data.py @@ -1,75 +1,75 @@ -from typing import Optional - -import torch -import torch_geometric.data - -import sevenn._keys as KEY -import sevenn.util - - -class AtomGraphData(torch_geometric.data.Data): - """ - Args: - x (Tensor, optional): atomic numbers with shape :obj:`[num_nodes, - atomic_numbers]`. (default: :obj:`None`) - edge_index (LongTensor, optional): Graph connectivity in coordinate - format with shape :obj:`[2, num_edges]`. (default: :obj:`None`) - edge_attr (Tensor, optional): Edge feature matrix with shape - :obj:`[num_edges, num_edge_features]`. (default: :obj:`None`) - y_energy: scalar # unit of eV (VASP raw) - y_force: [num_nodes, 3] # unit of eV/A (VASP raw) - y_stress: [6] # [xx, yy, zz, xy, yz, zx] # unit of eV/A^3 (VASP raw) - pos (Tensor, optional): Node position matrix with shape - :obj:`[num_nodes, num_dimensions]`. (default: :obj:`None`) - **kwargs (optional): Additional attributes. - - x, y_force, pos should be aligned with each other. - """ - - def __init__( - self, - x: Optional[torch.Tensor] = None, - edge_index: Optional[torch.Tensor] = None, - pos: Optional[torch.Tensor] = None, - edge_attr: Optional[torch.Tensor] = None, - **kwargs - ): - super(AtomGraphData, self).__init__(x, edge_index, edge_attr, pos=pos) - self[KEY.NODE_ATTR] = x # ? - for k, v in kwargs.items(): - self[k] = v - - def to_numpy_dict(self): - # This is not debugged yet! - dct = { - k: v.detach().cpu().numpy() if type(v) is torch.Tensor else v - for k, v in self.items() - } - return dct - - def fit_dimension(self): - per_atom_keys = [ - KEY.ATOMIC_NUMBERS, - KEY.ATOMIC_ENERGY, - KEY.POS, - KEY.FORCE, - KEY.PRED_FORCE, - ] - natoms = self.num_atoms.item() - for k, v in self.items(): - if not isinstance(v, torch.Tensor): - continue - if natoms == 1 and k in per_atom_keys: - self[k] = v.squeeze().unsqueeze(0) - else: - self[k] = v.squeeze() - return self - - @staticmethod - def from_numpy_dict(dct): - for k, v in dct.items(): - if k == KEY.CELL_SHIFT: - dct[k] = torch.Tensor(v) # this is special - else: - dct[k] = sevenn.util.dtype_correct(v) - return AtomGraphData(**dct) +from typing import Optional + +import torch +import torch_geometric.data + +import sevenn._keys as KEY +import sevenn.util + + +class AtomGraphData(torch_geometric.data.Data): + """ + Args: + x (Tensor, optional): atomic numbers with shape :obj:`[num_nodes, + atomic_numbers]`. (default: :obj:`None`) + edge_index (LongTensor, optional): Graph connectivity in coordinate + format with shape :obj:`[2, num_edges]`. (default: :obj:`None`) + edge_attr (Tensor, optional): Edge feature matrix with shape + :obj:`[num_edges, num_edge_features]`. (default: :obj:`None`) + y_energy: scalar # unit of eV (VASP raw) + y_force: [num_nodes, 3] # unit of eV/A (VASP raw) + y_stress: [6] # [xx, yy, zz, xy, yz, zx] # unit of eV/A^3 (VASP raw) + pos (Tensor, optional): Node position matrix with shape + :obj:`[num_nodes, num_dimensions]`. (default: :obj:`None`) + **kwargs (optional): Additional attributes. + + x, y_force, pos should be aligned with each other. + """ + + def __init__( + self, + x: Optional[torch.Tensor] = None, + edge_index: Optional[torch.Tensor] = None, + pos: Optional[torch.Tensor] = None, + edge_attr: Optional[torch.Tensor] = None, + **kwargs + ): + super(AtomGraphData, self).__init__(x, edge_index, edge_attr, pos=pos) + self[KEY.NODE_ATTR] = x # ? + for k, v in kwargs.items(): + self[k] = v + + def to_numpy_dict(self): + # This is not debugged yet! + dct = { + k: v.detach().cpu().numpy() if type(v) is torch.Tensor else v + for k, v in self.items() + } + return dct + + def fit_dimension(self): + per_atom_keys = [ + KEY.ATOMIC_NUMBERS, + KEY.ATOMIC_ENERGY, + KEY.POS, + KEY.FORCE, + KEY.PRED_FORCE, + ] + natoms = self.num_atoms.item() + for k, v in self.items(): + if not isinstance(v, torch.Tensor): + continue + if natoms == 1 and k in per_atom_keys: + self[k] = v.squeeze().unsqueeze(0) + else: + self[k] = v.squeeze() + return self + + @staticmethod + def from_numpy_dict(dct): + for k, v in dct.items(): + if k == KEY.CELL_SHIFT: + dct[k] = torch.Tensor(v) # this is special + else: + dct[k] = sevenn.util.dtype_correct(v) + return AtomGraphData(**dct) diff --git a/mace-bench/3rdparty/SevenNet/sevenn/calculator.py b/mace-bench/3rdparty/SevenNet/sevenn/calculator.py index 7a7f386..e237886 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/calculator.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/calculator.py @@ -1,846 +1,846 @@ -import ctypes -import os -import pathlib -import warnings -from typing import Any, Dict, Optional, Union - -import numpy as np -import torch -import torch.jit -import torch.jit._script -from ase.calculators.calculator import Calculator, all_changes -from ase.calculators.mixing import SumCalculator -from ase.data import chemical_symbols - -import sevenn._keys as KEY -import sevenn.util as util -from sevenn.atom_graph_data import AtomGraphData -from sevenn.nn.sequential import AtomGraphSequential -from sevenn.train.dataload import unlabeled_atoms_to_graph -import logging - -torch_script_type = torch.jit._script.RecursiveScriptModule - - -class SevenNetCalculator(Calculator): - """Supporting properties: - 'free_energy', 'energy', 'forces', 'stress', 'energies' - free_energy equals energy. 'energies' stores atomic energy. - - Multi-GPU acceleration is not supported with ASE calculator. - You should use LAMMPS for the acceleration. - """ - - def __init__( - self, - model: Union[str, pathlib.PurePath, AtomGraphSequential] = '7net-0', - file_type: str = 'checkpoint', - device: Union[torch.device, str] = 'auto', - modal: Optional[str] = None, - enable_cueq: bool = False, - sevennet_config: Optional[Dict] = None, # Not used in logic, just meta info - **kwargs, - ): - """Initialize SevenNetCalculator. - - Parameters - ---------- - model: str | Path | AtomGraphSequential, default='7net-0' - Name of pretrained models (7net-mf-ompa, 7net-omat, 7net-l3i5, 7net-0) or - path to the checkpoint, deployed model or the model itself - file_type: str, default='checkpoint' - one of 'checkpoint' | 'torchscript' | 'model_instance' - device: str | torch.device, default='auto' - if not given, use CUDA if available - modal: str | None, default=None - modal (fidelity) if given model is multi-modal model. for 7net-mf-ompa, - it should be one of 'mpa' (MPtrj + sAlex) or 'omat24' (OMat24) - case insensitive - enable_cueq: bool, default=False - if True, use cuEquivariant to accelerate inference. - sevennet_config: dict | None, default=None - Not used, but can be used to carry meta information of this calculator - """ - print("&&& Initializing SevenNetCalculator") - super().__init__(**kwargs) - self.sevennet_config = None - - if isinstance(model, pathlib.PurePath): - model = str(model) - - allowed_file_types = ['checkpoint', 'torchscript', 'model_instance'] - file_type = file_type.lower() - if file_type not in allowed_file_types: - raise ValueError(f'file_type not in {allowed_file_types}') - - if enable_cueq and file_type in ['model_instance', 'torchscript']: - warnings.warn( - 'file_type should be checkpoint to enable cueq. cueq set to False' - ) - enable_cueq = False - - if isinstance(device, str): # TODO: do we really need this? - if device == 'auto': - self.device = torch.device( - 'cuda' if torch.cuda.is_available() else 'cpu' - ) - else: - self.device = torch.device(device) - else: - self.device = device - - if file_type == 'checkpoint' and isinstance(model, str): - cp = util.load_checkpoint(model) - - backend = 'e3nn' if not enable_cueq else 'cueq' - model_loaded = cp.build_model(backend) - model_loaded.set_is_batch_data(False) - - self.type_map = cp.config[KEY.TYPE_MAP] - self.cutoff = cp.config[KEY.CUTOFF] - self.sevennet_config = cp.config - - elif file_type == 'torchscript' and isinstance(model, str): - if modal: - raise NotImplementedError() - extra_dict = { - 'chemical_symbols_to_index': b'', - 'cutoff': b'', - 'num_species': b'', - 'model_type': b'', - 'version': b'', - 'dtype': b'', - 'time': b'', - } - model_loaded = torch.jit.load( - model, _extra_files=extra_dict, map_location=self.device - ) - chem_symbols = extra_dict['chemical_symbols_to_index'].decode('utf-8') - sym_to_num = {sym: n for n, sym in enumerate(chemical_symbols)} - self.type_map = { - sym_to_num[sym]: i for i, sym in enumerate(chem_symbols.split()) - } - self.cutoff = float(extra_dict['cutoff'].decode('utf-8')) - - elif isinstance(model, AtomGraphSequential): - if model.type_map is None: - raise ValueError( - 'Model must have the type_map to be used with calculator' - ) - if model.cutoff == 0.0: - raise ValueError('Model cutoff seems not initialized') - model.eval_type_map = torch.tensor(True) # ? - model.set_is_batch_data(False) - model_loaded = model - self.type_map = model.type_map - self.cutoff = model.cutoff - else: - raise ValueError('Unexpected input combinations') - - if self.sevennet_config is None and sevennet_config is not None: - self.sevennet_config = sevennet_config - - self.model = model_loaded - - self.modal = None - if isinstance(self.model, AtomGraphSequential): - modal_map = self.model.modal_map - if modal_map: - modal_ava = list(modal_map.keys()) - if not modal: - raise ValueError(f'modal argument missing (avail: {modal_ava})') - elif modal not in modal_ava: - raise ValueError(f'unknown modal {modal} (not in {modal_ava})') - self.modal = modal - elif not self.model.modal_map and modal: - warnings.warn(f'modal={modal} is ignored as model has no modal_map') - - self.model.to(self.device) - self.model.eval() - self.implemented_properties = [ - 'free_energy', - 'energy', - 'forces', - 'stress', - 'energies', - ] - - def set_atoms(self, atoms): - # called by ase, when atoms.calc = calc - zs = tuple(set(atoms.get_atomic_numbers())) - for z in zs: - if z not in self.type_map: - sp = list(self.type_map.keys()) - raise ValueError( - f'Model do not know atomic number: {z}, (knows: {sp})' - ) - - def output_to_results(self, output): - energy = output[KEY.PRED_TOTAL_ENERGY].detach().cpu().item() - num_atoms = output['num_atoms'].item() - atomic_energies = output[KEY.ATOMIC_ENERGY].detach().cpu().numpy().flatten() - forces = output[KEY.PRED_FORCE].detach().cpu().numpy()[:num_atoms, :] - stress = np.array( - (-output[KEY.PRED_STRESS]) - .detach() - .cpu() - .numpy()[[0, 1, 2, 4, 5, 3]] # as voigt notation - ) - # Store results - return { - 'free_energy': energy, - 'energy': energy, - 'energies': atomic_energies, - 'forces': forces, - 'stress': stress, - 'num_edges': output[KEY.EDGE_IDX].shape[1], - } - - def calculate(self, atoms=None, properties=None, system_changes=all_changes): - # call parent class to set necessary atom attributes - Calculator.calculate(self, atoms, properties, system_changes) - if atoms is None: - raise ValueError('No atoms to evaluate') - data = AtomGraphData.from_numpy_dict( - unlabeled_atoms_to_graph(atoms, self.cutoff) - ) - if self.modal: - data[KEY.DATA_MODALITY] = self.modal - - data.to(self.device) # type: ignore - - if isinstance(self.model, torch_script_type): - data[KEY.NODE_FEATURE] = torch.tensor( - [self.type_map[z.item()] for z in data[KEY.NODE_FEATURE]], - dtype=torch.int64, - device=self.device, - ) - data[KEY.POS].requires_grad_(True) # backward compatibility - data[KEY.EDGE_VEC].requires_grad_(True) # backward compatibility - data = data.to_dict() - del data['data_info'] - - import logging - logging.debug(f"data: {data}") - # logging.debug(f"data[pos]: {data['pos']}") - # logging.debug(f"data[x]: {data['x']}") - logging.debug(f"data[cell_lattice_vectors]: {data['cell_lattice_vectors']}") - logging.debug(f"data[cell_volume]: {data['cell_volume']}") - output = self.model(data) - # logging.info(f"input: {data}") - # logging.info(f"output[{KEY.PRED_TOTAL_ENERGY}] = {output[KEY.PRED_TOTAL_ENERGY]}") - # logging.info(f"output[{KEY.PRED_FORCE}] = {output[KEY.PRED_FORCE]}") - # logging.info(f"output[{KEY.PRED_STRESS}] = {output[KEY.PRED_STRESS]}") - self.results = self.output_to_results(output) - # logging.debug(f"results['energy'] = {self.results['energy']}") - # logging.debug(f"results['forces'] = {self.results['forces']}") - # logging.debug(f"results['stress'] = {self.results['stress']}") - - def predict_one(self, atoms): - if atoms is None: - raise ValueError('No atoms to evaluate') - data = AtomGraphData.from_numpy_dict( - unlabeled_atoms_to_graph(atoms, self.cutoff) - ) - if self.modal: - data[KEY.DATA_MODALITY] = self.modal - - data.to(self.device) # type: ignore - - if isinstance(self.model, torch_script_type): - data[KEY.NODE_FEATURE] = torch.tensor( - [self.type_map[z.item()] for z in data[KEY.NODE_FEATURE]], - dtype=torch.int64, - device=self.device, - ) - data[KEY.POS].requires_grad_(True) # backward compatibility - data[KEY.EDGE_VEC].requires_grad_(True) # backward compatibility - data = data.to_dict() - del data['data_info'] - - return self.model(data) - - - - def predict(self, atoms_list, properties=None): - - # if len(atoms_list) == 1: - # output = self.predict_one(atoms_list[0]) - # predictions = {} - # predictions['energy'] = output[KEY.PRED_TOTAL_ENERGY].to(torch.float64).unsqueeze(0) - # predictions['forces'] = output[KEY.PRED_FORCE].to(torch.float64).unsqueeze(0) - # voigt = (-output[KEY.PRED_STRESS])[[0, 1, 2, 4, 5, 3]].to(torch.float64).unsqueeze(0) - # stress_list = [] - # for i in range(voigt.shape[0]): - # stress_list.append(self._stress2tensor(voigt[i,:])) - # predictions['stress'] = torch.stack(stress_list, dim=0).view(-1,3,3) - # return predictions - - - if not atoms_list: - raise ValueError("Empty atoms_list provided") - - if not isinstance(atoms_list, list): - atoms_list = [atoms_list] - - # Convert atoms to graph data - graph_list = [] - for atoms in atoms_list: - data = AtomGraphData.from_numpy_dict( - unlabeled_atoms_to_graph(atoms, self.cutoff) - ) - if self.modal: - data[KEY.DATA_MODALITY] = self.modal - - if isinstance(self.model, torch_script_type): - data[KEY.NODE_FEATURE] = torch.tensor( - [self.type_map[z.item()] for z in data[KEY.NODE_FEATURE]], - dtype=torch.int64, - device=self.device, - ) - data[KEY.POS].requires_grad_(True) # backward compatibility - data[KEY.EDGE_VEC].requires_grad_(True) # backward compatibility - - graph_list.append(data) - - # Process graphs based on model type - # was_batch_mode = True - if isinstance(self.model, AtomGraphSequential): - # was_batch_mode = self.model.is_batch_data - self.model.set_is_batch_data(True) - self.model.eval() - - # Batch the data if there are multiple atoms - from torch_geometric.loader.dataloader import Collater - batched_data = Collater(graph_list)(graph_list) - batched_data = batched_data.to(self.device) - - import logging - logging.debug(f"batched_data: {batched_data}") - # logging.debug(f"batched_data[pos]: {batched_data['pos']}") - # logging.debug(f"batched_data[x]: {batched_data['x']}") - logging.debug(f"batched_data[cell_lattice_vectors]: {batched_data['cell_lattice_vectors']}") - logging.debug(f"batched_data[cell_volume]: {batched_data['cell_volume']}") - # Run model on batched data - if isinstance(self.model, torch_script_type): - batched_dict = batched_data.to_dict() - if 'data_info' in batched_dict: - del batched_dict['data_info'] - output = self.model(batched_dict) - else: - output = self.model(batched_data) - - # Convert to list of individual outputs using util.to_atom_graph_list - # logging.info(f"input: {batched_data}") - # logging.info(f"output[{KEY.PRED_TOTAL_ENERGY}] = {output[KEY.PRED_TOTAL_ENERGY]}") - # logging.info(f"output[{KEY.PRED_FORCE}] = {output[KEY.PRED_FORCE]}") - # logging.info(f"output[{KEY.PRED_STRESS}] = {output[KEY.PRED_STRESS]}") - - predictions = {} - predictions['energy'] = output[KEY.PRED_TOTAL_ENERGY].to(torch.float64).detach() - predictions['forces'] = output[KEY.PRED_FORCE].to(torch.float64).detach() - voigt = (-output[KEY.PRED_STRESS])[:, [0, 1, 2, 4, 5, 3]].to(torch.float64).detach() - stress_list = [] - for i in range(voigt.shape[0]): - stress_list.append(self._stress2tensor(voigt[i,:])) - predictions['stress'] = torch.stack(stress_list, dim=0).view(-1,3,3).detach() - - # logging.debug(f"predictions['energy'] = {predictions['energy']}") - # logging.debug(f"predictions['forces'] = {predictions['forces']}") - # logging.debug(f"predictions['stress'] = {predictions['stress']}") - return predictions - - def _stress2tensor(self, stress): - tensor = torch.tensor( - [ - [stress[0], stress[5], stress[4]], - [stress[5], stress[1], stress[3]], - [stress[4], stress[3], stress[2]], - ], - device=self.device - ) - return tensor - - -class SevenNetD3Calculator(SumCalculator): - def __init__( - self, - model: Union[str, pathlib.PurePath, AtomGraphSequential] = '7net-0', - file_type: str = 'checkpoint', - device: Union[torch.device, str] = 'auto', - sevennet_config: Optional[Any] = None, # hold meta information - damping_type: str = 'damp_bj', - functional_name: str = 'pbe', - vdw_cutoff: float = 9000, # au^2, 0.52917726 angstrom = 1 au - cn_cutoff: float = 1600, # au^2, 0.52917726 angstrom = 1 au - batch_size=10, - **kwargs, - ): - """Initialize SevenNetD3Calculator. CUDA required. - - Parameters - ---------- - model: str | Path | AtomGraphSequential - Name of pretrained models (7net-mf-ompa, 7net-omat, 7net-l3i5, 7net-0) or - path to the checkpoint, deployed model or the model itself - file_type: str, default='checkpoint' - one of 'checkpoint' | 'torchscript' | 'model_instance' - device: str | torch.device, default='auto' - if not given, use CUDA if available - modal: str | None, default=None - modal (fidelity) if given model is multi-modal model. for 7net-mf-ompa, - it should be one of 'mpa' (MPtrj + sAlex) or 'omat24' (OMat24) - enable_cueq: bool, default=False - if True, use cuEquivariant to accelerate inference. - damping_type: str, default='damp_bj' - Damping type of D3, one of 'damp_bj' | 'damp_zero' - functional_name: str, default='pbe' - Target functional name of D3 parameters. - vdw_cutoff: float, default=9000 - vdw cutoff of D3 calculator in au - cn_cutoff: float, default=1600 - cn cutoff of D3 calculator in au - """ - self.d3_calc = D3Calculator( - damping_type=damping_type, - functional_name=functional_name, - vdw_cutoff=vdw_cutoff, - cn_cutoff=cn_cutoff, - **kwargs, - ) - - self.sevennet_calc = SevenNetCalculator( - model=model, - file_type=file_type, - device=device, - sevennet_config=sevennet_config, - **kwargs, - ) - - super().__init__([self.sevennet_calc, self.d3_calc]) - - self.device = device - self.d3_calcs = [] - for _ in range(batch_size): - self.d3_calcs.append( - D3Calculator( - damping_type=damping_type, - functional_name=functional_name, - vdw_cutoff=vdw_cutoff, - cn_cutoff=cn_cutoff, - **kwargs, - ) - ) - - - def predict(self, atoms_list): - """Predict the energy and forces for a list of atoms. - """ - # Call the predict method of the first calculator (SevenNetCalculator) - predictions = self.sevennet_calc.predict(atoms_list) - - energy_list = [] - forces_list = [] - stress_list = [] - predictions3d = {} - for i, atoms in enumerate(atoms_list): - prediction = self.d3_calcs[i].predict_one(atoms) - energy_list.append(torch.tensor(prediction['energy'])) - forces_list.append(torch.from_numpy(prediction['forces']).to(self.device)) - stress_list.append(self._stress2tensor(torch.from_numpy(prediction['stress']))) - - # Convert lists to tensors - predictions3d['energy'] = torch.stack(energy_list, dim=0).to(self.device) - predictions3d['forces'] = torch.cat(forces_list, dim=0).view(-1, 3) - predictions3d['stress'] = torch.stack(stress_list, dim=0).view(-1, 3, 3) - - predictions['energy'] += predictions3d['energy'].detach() - predictions['forces'] += predictions3d['forces'].detach() - predictions['stress'] += predictions3d['stress'].detach() - - return predictions - - def _stress2tensor(self, stress): - tensor = torch.tensor( - [ - # [stress[0], stress[3], stress[4]], - # [stress[3], stress[1], stress[5]], - # [stress[4], stress[5], stress[2]], - [stress[0], stress[5], stress[4]], - [stress[5], stress[1], stress[3]], - [stress[4], stress[3], stress[2]], - ], - device=self.device - ) - return tensor - - - -def _load(name: str) -> ctypes.CDLL: - from torch.utils.cpp_extension import LIB_EXT, _get_build_directory, load - - # Load the library from the candidate locations - - package_dir = os.path.dirname(os.path.abspath(__file__)) - try: - return ctypes.CDLL(os.path.join(package_dir, f'{name}{LIB_EXT}')) - except OSError: - pass - - cache_dir = _get_build_directory(name, verbose=False) - try: - return ctypes.CDLL(os.path.join(cache_dir, f'{name}{LIB_EXT}')) - except OSError: - pass - - # Compile the library if it is not found - - if os.access(package_dir, os.W_OK): - compile_dir = package_dir - else: - print('Warning: package directory is not writable. Using cache directory.') - compile_dir = cache_dir - - if 'TORCH_CUDA_ARCH_LIST' not in os.environ: - print('Warning: TORCH_CUDA_ARCH_LIST is not set.') - print('Warning: Use default CUDA architectures: 61, 70, 75, 80, 86, 89, 90') - os.environ['TORCH_CUDA_ARCH_LIST'] = '6.1;7.0;7.5;8.0;8.6;8.9;9.0' - - load( - name=name, - sources=[os.path.join(package_dir, 'pair_e3gnn', 'pair_d3_for_ase.cu')], - extra_cuda_cflags=['-O3', '--expt-relaxed-constexpr', '-fmad=false'], - build_directory=compile_dir, - verbose=True, - is_python_module=False, - ) - - return ctypes.CDLL(os.path.join(compile_dir, f'{name}{LIB_EXT}')) - - -class PairD3(ctypes.Structure): - pass # Opaque structure; only used as a pointer - - -class D3Calculator(Calculator): - """ASE calculator for accelerated D3 van der Waals (vdW) correction. - - Example: - from ase.calculators.mixing import SumCalculator - calc_1 = SevenNetCalculator() - calc_2 = D3Calculator() - return SumCalculator([calc_1, calc_2]) - - This calculator interfaces with the `libpaird3.so` library, - which is compiled by nvcc during the package installation. - If you encounter any errors, please verify - the installation process and the compilation options in `setup.py`. - Note: Multi-GPU parallel MD is not supported in this mode. - Note: Cffi could be used, but it was avoided to reduce dependencies. - """ - - # Here, free_energy = energy - implemented_properties = ['free_energy', 'energy', 'forces', 'stress'] - - def __init__( - self, - damping_type: str = 'damp_bj', # damp_bj, damp_zero - functional_name: str = 'pbe', # check the source code - vdw_cutoff: float = 9000, # au^2, 0.52917726 angstrom = 1 au - cn_cutoff: float = 1600, # au^2, 0.52917726 angstrom = 1 au - **kwargs, - ): - super().__init__(**kwargs) - - if not torch.cuda.is_available(): - raise NotImplementedError('CPU + D3 is not implemented yet') - - self.rthr = vdw_cutoff - self.cnthr = cn_cutoff - self.damp_name = damping_type.lower() - self.func_name = functional_name.lower() - - if self.damp_name not in ['damp_bj', 'damp_zero']: - raise ValueError('Error: Invalid damping type.') - - self._lib = _load('pair_d3') - - self._lib.pair_init.restype = ctypes.POINTER(PairD3) - self.pair = self._lib.pair_init() - - self._lib.pair_set_atom.argtypes = [ - ctypes.POINTER(PairD3), # PairD3* pair - ctypes.c_int, # int natoms - ctypes.c_int, # int ntypes - ctypes.POINTER(ctypes.c_int), # int* types - ctypes.POINTER(ctypes.c_double), # double* x - ] - self._lib.pair_set_atom.restype = None - - self._lib.pair_set_domain.argtypes = [ - ctypes.POINTER(PairD3), # PairD3* pair - ctypes.c_int, # int xperiodic - ctypes.c_int, # int yperiodic - ctypes.c_int, # int zperiodic - ctypes.POINTER(ctypes.c_double), # double* boxlo - ctypes.POINTER(ctypes.c_double), # double* boxhi - ctypes.c_double, # double xy - ctypes.c_double, # double xz - ctypes.c_double, # double yz - ] - self._lib.pair_set_domain.restype = None - - self._lib.pair_run_settings.argtypes = [ - ctypes.POINTER(PairD3), # PairD3* pair - ctypes.c_double, # double rthr - ctypes.c_double, # double cnthr - ctypes.c_char_p, # const char* damp_name - ctypes.c_char_p, # const char* func_name - ] - self._lib.pair_run_settings.restype = None - - self._lib.pair_run_coeff.argtypes = [ - ctypes.POINTER(PairD3), # PairD3* pair - ctypes.POINTER(ctypes.c_int), # int* atomic_numbers - ] - self._lib.pair_run_coeff.restype = None - - self._lib.pair_run_compute.argtypes = [ctypes.POINTER(PairD3)] - self._lib.pair_run_compute.restype = None - - self._lib.pair_get_energy.argtypes = [ctypes.POINTER(PairD3)] - self._lib.pair_get_energy.restype = ctypes.c_double - - self._lib.pair_get_force.argtypes = [ctypes.POINTER(PairD3)] - self._lib.pair_get_force.restype = ctypes.POINTER(ctypes.c_double) - - self._lib.pair_get_stress.argtypes = [ctypes.POINTER(PairD3)] - self._lib.pair_get_stress.restype = ctypes.POINTER(ctypes.c_double * 6) - - self._lib.pair_fin.argtypes = [ctypes.POINTER(PairD3)] - self._lib.pair_fin.restype = None - - def _idx_to_numbers(self, Z_of_atoms): - unique_numbers = list(dict.fromkeys(Z_of_atoms)) - return unique_numbers - - def _idx_to_types(self, Z_of_atoms): - unique_numbers = list(dict.fromkeys(Z_of_atoms)) - mapping = {num: idx + 1 for idx, num in enumerate(unique_numbers)} - atom_types = [mapping[num] for num in Z_of_atoms] - return atom_types - - def _convert_domain_ase2lammps(self, cell): - qtrans, ltrans = np.linalg.qr(cell.T, mode='complete') - lammps_cell = ltrans.T - signs = np.sign(np.diag(lammps_cell)) - lammps_cell = lammps_cell * signs - qtrans = qtrans * signs - lammps_cell = lammps_cell[(0, 1, 2, 1, 2, 2), (0, 1, 2, 0, 0, 1)] - rotator = qtrans.T - return lammps_cell, rotator - - def _stress2tensor(self, stress): - tensor = np.array( - [ - [stress[0], stress[3], stress[4]], - [stress[3], stress[1], stress[5]], - [stress[4], stress[5], stress[2]], - ] - ) - return tensor - - def _tensor2stress(self, tensor): - stress = -np.array( - [ - tensor[0, 0], - tensor[1, 1], - tensor[2, 2], - tensor[1, 2], - tensor[0, 2], - tensor[0, 1], - ] - ) - return stress - - def calculate(self, atoms=None, properties=None, system_changes=all_changes): - Calculator.calculate(self, atoms, properties, system_changes) - if atoms is None: - raise ValueError('No atoms to evaluate') - - if atoms.get_cell().sum() == 0: - print( - 'Warning: D3Calculator requires a cell.\n' - 'Warning: An orthogonal cell large enough is generated.' - ) - positions = atoms.get_positions() - min_pos = positions.min(axis=0) - max_pos = positions.max(axis=0) - max_cutoff = np.sqrt(max(self.rthr, self.cnthr)) * 0.52917726 - - cell_lengths = max_pos - min_pos + max_cutoff + 1.0 # extra margin - cell = np.eye(3) * cell_lengths - - atoms.set_cell(cell) - atoms.set_pbc([True, True, True]) # for minus positions - - cell, rotator = self._convert_domain_ase2lammps(atoms.get_cell()) - - Z_of_atoms = atoms.get_atomic_numbers() - natoms = len(atoms) - ntypes = len(set(Z_of_atoms)) - types = (ctypes.c_int * natoms)(*self._idx_to_types(Z_of_atoms)) - - positions = atoms.get_positions() @ rotator.T - x_flat = (ctypes.c_double * (natoms * 3))(*positions.flatten()) - - atomic_numbers = (ctypes.c_int * ntypes)(*self._idx_to_numbers(Z_of_atoms)) - - boxlo = (ctypes.c_double * 3)(0.0, 0.0, 0.0) - boxhi = (ctypes.c_double * 3)(cell[0], cell[1], cell[2]) - xy = cell[3] - xz = cell[4] - yz = cell[5] - xperiodic, yperiodic, zperiodic = atoms.get_pbc() - - lib = self._lib - assert lib is not None - lib.pair_set_atom(self.pair, natoms, ntypes, types, x_flat) - - xperiodic = xperiodic.astype(int) - yperiodic = yperiodic.astype(int) - zperiodic = zperiodic.astype(int) - lib.pair_set_domain( - self.pair, xperiodic, yperiodic, zperiodic, boxlo, boxhi, xy, xz, yz - ) - - lib.pair_run_settings( - self.pair, - self.rthr, - self.cnthr, - self.damp_name.encode('utf-8'), - self.func_name.encode('utf-8'), - ) - - lib.pair_run_coeff(self.pair, atomic_numbers) - lib.pair_run_compute(self.pair) - - result_E = lib.pair_get_energy(self.pair) - - result_F_ptr = lib.pair_get_force(self.pair) - result_F_size = natoms * 3 - result_F = np.ctypeslib.as_array( - result_F_ptr, shape=(result_F_size,) - ).reshape((natoms, 3)) - result_F = np.array(result_F) - result_F = result_F @ rotator - - result_S = lib.pair_get_stress(self.pair) - result_S = np.array(result_S.contents) - result_S = ( - self._tensor2stress(rotator.T @ self._stress2tensor(result_S) @ rotator) - / atoms.get_volume() - ) - - self.results = { - 'free_energy': result_E, - 'energy': result_E, - 'forces': result_F, - 'stress': result_S, - } - - def predict_one(self, atoms): - atoms = atoms.copy() - if atoms is None: - raise ValueError('No atoms to evaluate') - - if atoms.get_cell().sum() == 0: - print( - 'Warning: D3Calculator requires a cell.\n' - 'Warning: An orthogonal cell large enough is generated.' - ) - positions = atoms.get_positions() - min_pos = positions.min(axis=0) - max_pos = positions.max(axis=0) - max_cutoff = np.sqrt(max(self.rthr, self.cnthr)) * 0.52917726 - - cell_lengths = max_pos - min_pos + max_cutoff + 1.0 # extra margin - cell = np.eye(3) * cell_lengths - - atoms.set_cell(cell) - atoms.set_pbc([True, True, True]) # for minus positions - - cell, rotator = self._convert_domain_ase2lammps(atoms.get_cell()) - - Z_of_atoms = atoms.get_atomic_numbers() - natoms = len(atoms) - ntypes = len(set(Z_of_atoms)) - types = (ctypes.c_int * natoms)(*self._idx_to_types(Z_of_atoms)) - - positions = atoms.get_positions() @ rotator.T - x_flat = (ctypes.c_double * (natoms * 3))(*positions.flatten()) - - atomic_numbers = (ctypes.c_int * ntypes)(*self._idx_to_numbers(Z_of_atoms)) - - boxlo = (ctypes.c_double * 3)(0.0, 0.0, 0.0) - boxhi = (ctypes.c_double * 3)(cell[0], cell[1], cell[2]) - xy = cell[3] - xz = cell[4] - yz = cell[5] - xperiodic, yperiodic, zperiodic = atoms.get_pbc() - - lib = self._lib - assert lib is not None - lib.pair_set_atom(self.pair, natoms, ntypes, types, x_flat) - - xperiodic = xperiodic.astype(int) - yperiodic = yperiodic.astype(int) - zperiodic = zperiodic.astype(int) - lib.pair_set_domain( - self.pair, xperiodic, yperiodic, zperiodic, boxlo, boxhi, xy, xz, yz - ) - - lib.pair_run_settings( - self.pair, - self.rthr, - self.cnthr, - self.damp_name.encode('utf-8'), - self.func_name.encode('utf-8'), - ) - - lib.pair_run_coeff(self.pair, atomic_numbers) - lib.pair_run_compute(self.pair) - - result_E = lib.pair_get_energy(self.pair) - - result_F_ptr = lib.pair_get_force(self.pair) - result_F_size = natoms * 3 - result_F = np.ctypeslib.as_array( - result_F_ptr, shape=(result_F_size,) - ).reshape((natoms, 3)) - result_F = np.array(result_F) - result_F = result_F @ rotator - - result_S = lib.pair_get_stress(self.pair) - result_S = np.array(result_S.contents) - result_S = ( - self._tensor2stress(rotator.T @ self._stress2tensor(result_S) @ rotator) - / atoms.get_volume() - ) - - prediction = { - 'free_energy': float(result_E), - 'energy': float(result_E), - 'forces': result_F.copy(), - 'stress': result_S.copy(), - } - - return prediction - - - def __del__(self): - if self._lib is not None: - self._lib.pair_fin(self.pair) - self._lib = None - self.pair = None - +import ctypes +import os +import pathlib +import warnings +from typing import Any, Dict, Optional, Union + +import numpy as np +import torch +import torch.jit +import torch.jit._script +from ase.calculators.calculator import Calculator, all_changes +from ase.calculators.mixing import SumCalculator +from ase.data import chemical_symbols + +import sevenn._keys as KEY +import sevenn.util as util +from sevenn.atom_graph_data import AtomGraphData +from sevenn.nn.sequential import AtomGraphSequential +from sevenn.train.dataload import unlabeled_atoms_to_graph +import logging + +torch_script_type = torch.jit._script.RecursiveScriptModule + + +class SevenNetCalculator(Calculator): + """Supporting properties: + 'free_energy', 'energy', 'forces', 'stress', 'energies' + free_energy equals energy. 'energies' stores atomic energy. + + Multi-GPU acceleration is not supported with ASE calculator. + You should use LAMMPS for the acceleration. + """ + + def __init__( + self, + model: Union[str, pathlib.PurePath, AtomGraphSequential] = '7net-0', + file_type: str = 'checkpoint', + device: Union[torch.device, str] = 'auto', + modal: Optional[str] = None, + enable_cueq: bool = False, + sevennet_config: Optional[Dict] = None, # Not used in logic, just meta info + **kwargs, + ): + """Initialize SevenNetCalculator. + + Parameters + ---------- + model: str | Path | AtomGraphSequential, default='7net-0' + Name of pretrained models (7net-mf-ompa, 7net-omat, 7net-l3i5, 7net-0) or + path to the checkpoint, deployed model or the model itself + file_type: str, default='checkpoint' + one of 'checkpoint' | 'torchscript' | 'model_instance' + device: str | torch.device, default='auto' + if not given, use CUDA if available + modal: str | None, default=None + modal (fidelity) if given model is multi-modal model. for 7net-mf-ompa, + it should be one of 'mpa' (MPtrj + sAlex) or 'omat24' (OMat24) + case insensitive + enable_cueq: bool, default=False + if True, use cuEquivariant to accelerate inference. + sevennet_config: dict | None, default=None + Not used, but can be used to carry meta information of this calculator + """ + print("&&& Initializing SevenNetCalculator") + super().__init__(**kwargs) + self.sevennet_config = None + + if isinstance(model, pathlib.PurePath): + model = str(model) + + allowed_file_types = ['checkpoint', 'torchscript', 'model_instance'] + file_type = file_type.lower() + if file_type not in allowed_file_types: + raise ValueError(f'file_type not in {allowed_file_types}') + + if enable_cueq and file_type in ['model_instance', 'torchscript']: + warnings.warn( + 'file_type should be checkpoint to enable cueq. cueq set to False' + ) + enable_cueq = False + + if isinstance(device, str): # TODO: do we really need this? + if device == 'auto': + self.device = torch.device( + 'cuda' if torch.cuda.is_available() else 'cpu' + ) + else: + self.device = torch.device(device) + else: + self.device = device + + if file_type == 'checkpoint' and isinstance(model, str): + cp = util.load_checkpoint(model) + + backend = 'e3nn' if not enable_cueq else 'cueq' + model_loaded = cp.build_model(backend) + model_loaded.set_is_batch_data(False) + + self.type_map = cp.config[KEY.TYPE_MAP] + self.cutoff = cp.config[KEY.CUTOFF] + self.sevennet_config = cp.config + + elif file_type == 'torchscript' and isinstance(model, str): + if modal: + raise NotImplementedError() + extra_dict = { + 'chemical_symbols_to_index': b'', + 'cutoff': b'', + 'num_species': b'', + 'model_type': b'', + 'version': b'', + 'dtype': b'', + 'time': b'', + } + model_loaded = torch.jit.load( + model, _extra_files=extra_dict, map_location=self.device + ) + chem_symbols = extra_dict['chemical_symbols_to_index'].decode('utf-8') + sym_to_num = {sym: n for n, sym in enumerate(chemical_symbols)} + self.type_map = { + sym_to_num[sym]: i for i, sym in enumerate(chem_symbols.split()) + } + self.cutoff = float(extra_dict['cutoff'].decode('utf-8')) + + elif isinstance(model, AtomGraphSequential): + if model.type_map is None: + raise ValueError( + 'Model must have the type_map to be used with calculator' + ) + if model.cutoff == 0.0: + raise ValueError('Model cutoff seems not initialized') + model.eval_type_map = torch.tensor(True) # ? + model.set_is_batch_data(False) + model_loaded = model + self.type_map = model.type_map + self.cutoff = model.cutoff + else: + raise ValueError('Unexpected input combinations') + + if self.sevennet_config is None and sevennet_config is not None: + self.sevennet_config = sevennet_config + + self.model = model_loaded + + self.modal = None + if isinstance(self.model, AtomGraphSequential): + modal_map = self.model.modal_map + if modal_map: + modal_ava = list(modal_map.keys()) + if not modal: + raise ValueError(f'modal argument missing (avail: {modal_ava})') + elif modal not in modal_ava: + raise ValueError(f'unknown modal {modal} (not in {modal_ava})') + self.modal = modal + elif not self.model.modal_map and modal: + warnings.warn(f'modal={modal} is ignored as model has no modal_map') + + self.model.to(self.device) + self.model.eval() + self.implemented_properties = [ + 'free_energy', + 'energy', + 'forces', + 'stress', + 'energies', + ] + + def set_atoms(self, atoms): + # called by ase, when atoms.calc = calc + zs = tuple(set(atoms.get_atomic_numbers())) + for z in zs: + if z not in self.type_map: + sp = list(self.type_map.keys()) + raise ValueError( + f'Model do not know atomic number: {z}, (knows: {sp})' + ) + + def output_to_results(self, output): + energy = output[KEY.PRED_TOTAL_ENERGY].detach().cpu().item() + num_atoms = output['num_atoms'].item() + atomic_energies = output[KEY.ATOMIC_ENERGY].detach().cpu().numpy().flatten() + forces = output[KEY.PRED_FORCE].detach().cpu().numpy()[:num_atoms, :] + stress = np.array( + (-output[KEY.PRED_STRESS]) + .detach() + .cpu() + .numpy()[[0, 1, 2, 4, 5, 3]] # as voigt notation + ) + # Store results + return { + 'free_energy': energy, + 'energy': energy, + 'energies': atomic_energies, + 'forces': forces, + 'stress': stress, + 'num_edges': output[KEY.EDGE_IDX].shape[1], + } + + def calculate(self, atoms=None, properties=None, system_changes=all_changes): + # call parent class to set necessary atom attributes + Calculator.calculate(self, atoms, properties, system_changes) + if atoms is None: + raise ValueError('No atoms to evaluate') + data = AtomGraphData.from_numpy_dict( + unlabeled_atoms_to_graph(atoms, self.cutoff) + ) + if self.modal: + data[KEY.DATA_MODALITY] = self.modal + + data.to(self.device) # type: ignore + + if isinstance(self.model, torch_script_type): + data[KEY.NODE_FEATURE] = torch.tensor( + [self.type_map[z.item()] for z in data[KEY.NODE_FEATURE]], + dtype=torch.int64, + device=self.device, + ) + data[KEY.POS].requires_grad_(True) # backward compatibility + data[KEY.EDGE_VEC].requires_grad_(True) # backward compatibility + data = data.to_dict() + del data['data_info'] + + import logging + logging.debug(f"data: {data}") + # logging.debug(f"data[pos]: {data['pos']}") + # logging.debug(f"data[x]: {data['x']}") + logging.debug(f"data[cell_lattice_vectors]: {data['cell_lattice_vectors']}") + logging.debug(f"data[cell_volume]: {data['cell_volume']}") + output = self.model(data) + # logging.info(f"input: {data}") + # logging.info(f"output[{KEY.PRED_TOTAL_ENERGY}] = {output[KEY.PRED_TOTAL_ENERGY]}") + # logging.info(f"output[{KEY.PRED_FORCE}] = {output[KEY.PRED_FORCE]}") + # logging.info(f"output[{KEY.PRED_STRESS}] = {output[KEY.PRED_STRESS]}") + self.results = self.output_to_results(output) + # logging.debug(f"results['energy'] = {self.results['energy']}") + # logging.debug(f"results['forces'] = {self.results['forces']}") + # logging.debug(f"results['stress'] = {self.results['stress']}") + + def predict_one(self, atoms): + if atoms is None: + raise ValueError('No atoms to evaluate') + data = AtomGraphData.from_numpy_dict( + unlabeled_atoms_to_graph(atoms, self.cutoff) + ) + if self.modal: + data[KEY.DATA_MODALITY] = self.modal + + data.to(self.device) # type: ignore + + if isinstance(self.model, torch_script_type): + data[KEY.NODE_FEATURE] = torch.tensor( + [self.type_map[z.item()] for z in data[KEY.NODE_FEATURE]], + dtype=torch.int64, + device=self.device, + ) + data[KEY.POS].requires_grad_(True) # backward compatibility + data[KEY.EDGE_VEC].requires_grad_(True) # backward compatibility + data = data.to_dict() + del data['data_info'] + + return self.model(data) + + + + def predict(self, atoms_list, properties=None): + + # if len(atoms_list) == 1: + # output = self.predict_one(atoms_list[0]) + # predictions = {} + # predictions['energy'] = output[KEY.PRED_TOTAL_ENERGY].to(torch.float64).unsqueeze(0) + # predictions['forces'] = output[KEY.PRED_FORCE].to(torch.float64).unsqueeze(0) + # voigt = (-output[KEY.PRED_STRESS])[[0, 1, 2, 4, 5, 3]].to(torch.float64).unsqueeze(0) + # stress_list = [] + # for i in range(voigt.shape[0]): + # stress_list.append(self._stress2tensor(voigt[i,:])) + # predictions['stress'] = torch.stack(stress_list, dim=0).view(-1,3,3) + # return predictions + + + if not atoms_list: + raise ValueError("Empty atoms_list provided") + + if not isinstance(atoms_list, list): + atoms_list = [atoms_list] + + # Convert atoms to graph data + graph_list = [] + for atoms in atoms_list: + data = AtomGraphData.from_numpy_dict( + unlabeled_atoms_to_graph(atoms, self.cutoff) + ) + if self.modal: + data[KEY.DATA_MODALITY] = self.modal + + if isinstance(self.model, torch_script_type): + data[KEY.NODE_FEATURE] = torch.tensor( + [self.type_map[z.item()] for z in data[KEY.NODE_FEATURE]], + dtype=torch.int64, + device=self.device, + ) + data[KEY.POS].requires_grad_(True) # backward compatibility + data[KEY.EDGE_VEC].requires_grad_(True) # backward compatibility + + graph_list.append(data) + + # Process graphs based on model type + # was_batch_mode = True + if isinstance(self.model, AtomGraphSequential): + # was_batch_mode = self.model.is_batch_data + self.model.set_is_batch_data(True) + self.model.eval() + + # Batch the data if there are multiple atoms + from torch_geometric.loader.dataloader import Collater + batched_data = Collater(graph_list)(graph_list) + batched_data = batched_data.to(self.device) + + import logging + logging.debug(f"batched_data: {batched_data}") + # logging.debug(f"batched_data[pos]: {batched_data['pos']}") + # logging.debug(f"batched_data[x]: {batched_data['x']}") + logging.debug(f"batched_data[cell_lattice_vectors]: {batched_data['cell_lattice_vectors']}") + logging.debug(f"batched_data[cell_volume]: {batched_data['cell_volume']}") + # Run model on batched data + if isinstance(self.model, torch_script_type): + batched_dict = batched_data.to_dict() + if 'data_info' in batched_dict: + del batched_dict['data_info'] + output = self.model(batched_dict) + else: + output = self.model(batched_data) + + # Convert to list of individual outputs using util.to_atom_graph_list + # logging.info(f"input: {batched_data}") + # logging.info(f"output[{KEY.PRED_TOTAL_ENERGY}] = {output[KEY.PRED_TOTAL_ENERGY]}") + # logging.info(f"output[{KEY.PRED_FORCE}] = {output[KEY.PRED_FORCE]}") + # logging.info(f"output[{KEY.PRED_STRESS}] = {output[KEY.PRED_STRESS]}") + + predictions = {} + predictions['energy'] = output[KEY.PRED_TOTAL_ENERGY].to(torch.float64).detach() + predictions['forces'] = output[KEY.PRED_FORCE].to(torch.float64).detach() + voigt = (-output[KEY.PRED_STRESS])[:, [0, 1, 2, 4, 5, 3]].to(torch.float64).detach() + stress_list = [] + for i in range(voigt.shape[0]): + stress_list.append(self._stress2tensor(voigt[i,:])) + predictions['stress'] = torch.stack(stress_list, dim=0).view(-1,3,3).detach() + + # logging.debug(f"predictions['energy'] = {predictions['energy']}") + # logging.debug(f"predictions['forces'] = {predictions['forces']}") + # logging.debug(f"predictions['stress'] = {predictions['stress']}") + return predictions + + def _stress2tensor(self, stress): + tensor = torch.tensor( + [ + [stress[0], stress[5], stress[4]], + [stress[5], stress[1], stress[3]], + [stress[4], stress[3], stress[2]], + ], + device=self.device + ) + return tensor + + +class SevenNetD3Calculator(SumCalculator): + def __init__( + self, + model: Union[str, pathlib.PurePath, AtomGraphSequential] = '7net-0', + file_type: str = 'checkpoint', + device: Union[torch.device, str] = 'auto', + sevennet_config: Optional[Any] = None, # hold meta information + damping_type: str = 'damp_bj', + functional_name: str = 'pbe', + vdw_cutoff: float = 9000, # au^2, 0.52917726 angstrom = 1 au + cn_cutoff: float = 1600, # au^2, 0.52917726 angstrom = 1 au + batch_size=10, + **kwargs, + ): + """Initialize SevenNetD3Calculator. CUDA required. + + Parameters + ---------- + model: str | Path | AtomGraphSequential + Name of pretrained models (7net-mf-ompa, 7net-omat, 7net-l3i5, 7net-0) or + path to the checkpoint, deployed model or the model itself + file_type: str, default='checkpoint' + one of 'checkpoint' | 'torchscript' | 'model_instance' + device: str | torch.device, default='auto' + if not given, use CUDA if available + modal: str | None, default=None + modal (fidelity) if given model is multi-modal model. for 7net-mf-ompa, + it should be one of 'mpa' (MPtrj + sAlex) or 'omat24' (OMat24) + enable_cueq: bool, default=False + if True, use cuEquivariant to accelerate inference. + damping_type: str, default='damp_bj' + Damping type of D3, one of 'damp_bj' | 'damp_zero' + functional_name: str, default='pbe' + Target functional name of D3 parameters. + vdw_cutoff: float, default=9000 + vdw cutoff of D3 calculator in au + cn_cutoff: float, default=1600 + cn cutoff of D3 calculator in au + """ + self.d3_calc = D3Calculator( + damping_type=damping_type, + functional_name=functional_name, + vdw_cutoff=vdw_cutoff, + cn_cutoff=cn_cutoff, + **kwargs, + ) + + self.sevennet_calc = SevenNetCalculator( + model=model, + file_type=file_type, + device=device, + sevennet_config=sevennet_config, + **kwargs, + ) + + super().__init__([self.sevennet_calc, self.d3_calc]) + + self.device = device + self.d3_calcs = [] + for _ in range(batch_size): + self.d3_calcs.append( + D3Calculator( + damping_type=damping_type, + functional_name=functional_name, + vdw_cutoff=vdw_cutoff, + cn_cutoff=cn_cutoff, + **kwargs, + ) + ) + + + def predict(self, atoms_list): + """Predict the energy and forces for a list of atoms. + """ + # Call the predict method of the first calculator (SevenNetCalculator) + predictions = self.sevennet_calc.predict(atoms_list) + + energy_list = [] + forces_list = [] + stress_list = [] + predictions3d = {} + for i, atoms in enumerate(atoms_list): + prediction = self.d3_calcs[i].predict_one(atoms) + energy_list.append(torch.tensor(prediction['energy'])) + forces_list.append(torch.from_numpy(prediction['forces']).to(self.device)) + stress_list.append(self._stress2tensor(torch.from_numpy(prediction['stress']))) + + # Convert lists to tensors + predictions3d['energy'] = torch.stack(energy_list, dim=0).to(self.device) + predictions3d['forces'] = torch.cat(forces_list, dim=0).view(-1, 3) + predictions3d['stress'] = torch.stack(stress_list, dim=0).view(-1, 3, 3) + + predictions['energy'] += predictions3d['energy'].detach() + predictions['forces'] += predictions3d['forces'].detach() + predictions['stress'] += predictions3d['stress'].detach() + + return predictions + + def _stress2tensor(self, stress): + tensor = torch.tensor( + [ + # [stress[0], stress[3], stress[4]], + # [stress[3], stress[1], stress[5]], + # [stress[4], stress[5], stress[2]], + [stress[0], stress[5], stress[4]], + [stress[5], stress[1], stress[3]], + [stress[4], stress[3], stress[2]], + ], + device=self.device + ) + return tensor + + + +def _load(name: str) -> ctypes.CDLL: + from torch.utils.cpp_extension import LIB_EXT, _get_build_directory, load + + # Load the library from the candidate locations + + package_dir = os.path.dirname(os.path.abspath(__file__)) + try: + return ctypes.CDLL(os.path.join(package_dir, f'{name}{LIB_EXT}')) + except OSError: + pass + + cache_dir = _get_build_directory(name, verbose=False) + try: + return ctypes.CDLL(os.path.join(cache_dir, f'{name}{LIB_EXT}')) + except OSError: + pass + + # Compile the library if it is not found + + if os.access(package_dir, os.W_OK): + compile_dir = package_dir + else: + print('Warning: package directory is not writable. Using cache directory.') + compile_dir = cache_dir + + if 'TORCH_CUDA_ARCH_LIST' not in os.environ: + print('Warning: TORCH_CUDA_ARCH_LIST is not set.') + print('Warning: Use default CUDA architectures: 61, 70, 75, 80, 86, 89, 90') + os.environ['TORCH_CUDA_ARCH_LIST'] = '6.1;7.0;7.5;8.0;8.6;8.9;9.0' + + load( + name=name, + sources=[os.path.join(package_dir, 'pair_e3gnn', 'pair_d3_for_ase.cu')], + extra_cuda_cflags=['-O3', '--expt-relaxed-constexpr', '-fmad=false'], + build_directory=compile_dir, + verbose=True, + is_python_module=False, + ) + + return ctypes.CDLL(os.path.join(compile_dir, f'{name}{LIB_EXT}')) + + +class PairD3(ctypes.Structure): + pass # Opaque structure; only used as a pointer + + +class D3Calculator(Calculator): + """ASE calculator for accelerated D3 van der Waals (vdW) correction. + + Example: + from ase.calculators.mixing import SumCalculator + calc_1 = SevenNetCalculator() + calc_2 = D3Calculator() + return SumCalculator([calc_1, calc_2]) + + This calculator interfaces with the `libpaird3.so` library, + which is compiled by nvcc during the package installation. + If you encounter any errors, please verify + the installation process and the compilation options in `setup.py`. + Note: Multi-GPU parallel MD is not supported in this mode. + Note: Cffi could be used, but it was avoided to reduce dependencies. + """ + + # Here, free_energy = energy + implemented_properties = ['free_energy', 'energy', 'forces', 'stress'] + + def __init__( + self, + damping_type: str = 'damp_bj', # damp_bj, damp_zero + functional_name: str = 'pbe', # check the source code + vdw_cutoff: float = 9000, # au^2, 0.52917726 angstrom = 1 au + cn_cutoff: float = 1600, # au^2, 0.52917726 angstrom = 1 au + **kwargs, + ): + super().__init__(**kwargs) + + if not torch.cuda.is_available(): + raise NotImplementedError('CPU + D3 is not implemented yet') + + self.rthr = vdw_cutoff + self.cnthr = cn_cutoff + self.damp_name = damping_type.lower() + self.func_name = functional_name.lower() + + if self.damp_name not in ['damp_bj', 'damp_zero']: + raise ValueError('Error: Invalid damping type.') + + self._lib = _load('pair_d3') + + self._lib.pair_init.restype = ctypes.POINTER(PairD3) + self.pair = self._lib.pair_init() + + self._lib.pair_set_atom.argtypes = [ + ctypes.POINTER(PairD3), # PairD3* pair + ctypes.c_int, # int natoms + ctypes.c_int, # int ntypes + ctypes.POINTER(ctypes.c_int), # int* types + ctypes.POINTER(ctypes.c_double), # double* x + ] + self._lib.pair_set_atom.restype = None + + self._lib.pair_set_domain.argtypes = [ + ctypes.POINTER(PairD3), # PairD3* pair + ctypes.c_int, # int xperiodic + ctypes.c_int, # int yperiodic + ctypes.c_int, # int zperiodic + ctypes.POINTER(ctypes.c_double), # double* boxlo + ctypes.POINTER(ctypes.c_double), # double* boxhi + ctypes.c_double, # double xy + ctypes.c_double, # double xz + ctypes.c_double, # double yz + ] + self._lib.pair_set_domain.restype = None + + self._lib.pair_run_settings.argtypes = [ + ctypes.POINTER(PairD3), # PairD3* pair + ctypes.c_double, # double rthr + ctypes.c_double, # double cnthr + ctypes.c_char_p, # const char* damp_name + ctypes.c_char_p, # const char* func_name + ] + self._lib.pair_run_settings.restype = None + + self._lib.pair_run_coeff.argtypes = [ + ctypes.POINTER(PairD3), # PairD3* pair + ctypes.POINTER(ctypes.c_int), # int* atomic_numbers + ] + self._lib.pair_run_coeff.restype = None + + self._lib.pair_run_compute.argtypes = [ctypes.POINTER(PairD3)] + self._lib.pair_run_compute.restype = None + + self._lib.pair_get_energy.argtypes = [ctypes.POINTER(PairD3)] + self._lib.pair_get_energy.restype = ctypes.c_double + + self._lib.pair_get_force.argtypes = [ctypes.POINTER(PairD3)] + self._lib.pair_get_force.restype = ctypes.POINTER(ctypes.c_double) + + self._lib.pair_get_stress.argtypes = [ctypes.POINTER(PairD3)] + self._lib.pair_get_stress.restype = ctypes.POINTER(ctypes.c_double * 6) + + self._lib.pair_fin.argtypes = [ctypes.POINTER(PairD3)] + self._lib.pair_fin.restype = None + + def _idx_to_numbers(self, Z_of_atoms): + unique_numbers = list(dict.fromkeys(Z_of_atoms)) + return unique_numbers + + def _idx_to_types(self, Z_of_atoms): + unique_numbers = list(dict.fromkeys(Z_of_atoms)) + mapping = {num: idx + 1 for idx, num in enumerate(unique_numbers)} + atom_types = [mapping[num] for num in Z_of_atoms] + return atom_types + + def _convert_domain_ase2lammps(self, cell): + qtrans, ltrans = np.linalg.qr(cell.T, mode='complete') + lammps_cell = ltrans.T + signs = np.sign(np.diag(lammps_cell)) + lammps_cell = lammps_cell * signs + qtrans = qtrans * signs + lammps_cell = lammps_cell[(0, 1, 2, 1, 2, 2), (0, 1, 2, 0, 0, 1)] + rotator = qtrans.T + return lammps_cell, rotator + + def _stress2tensor(self, stress): + tensor = np.array( + [ + [stress[0], stress[3], stress[4]], + [stress[3], stress[1], stress[5]], + [stress[4], stress[5], stress[2]], + ] + ) + return tensor + + def _tensor2stress(self, tensor): + stress = -np.array( + [ + tensor[0, 0], + tensor[1, 1], + tensor[2, 2], + tensor[1, 2], + tensor[0, 2], + tensor[0, 1], + ] + ) + return stress + + def calculate(self, atoms=None, properties=None, system_changes=all_changes): + Calculator.calculate(self, atoms, properties, system_changes) + if atoms is None: + raise ValueError('No atoms to evaluate') + + if atoms.get_cell().sum() == 0: + print( + 'Warning: D3Calculator requires a cell.\n' + 'Warning: An orthogonal cell large enough is generated.' + ) + positions = atoms.get_positions() + min_pos = positions.min(axis=0) + max_pos = positions.max(axis=0) + max_cutoff = np.sqrt(max(self.rthr, self.cnthr)) * 0.52917726 + + cell_lengths = max_pos - min_pos + max_cutoff + 1.0 # extra margin + cell = np.eye(3) * cell_lengths + + atoms.set_cell(cell) + atoms.set_pbc([True, True, True]) # for minus positions + + cell, rotator = self._convert_domain_ase2lammps(atoms.get_cell()) + + Z_of_atoms = atoms.get_atomic_numbers() + natoms = len(atoms) + ntypes = len(set(Z_of_atoms)) + types = (ctypes.c_int * natoms)(*self._idx_to_types(Z_of_atoms)) + + positions = atoms.get_positions() @ rotator.T + x_flat = (ctypes.c_double * (natoms * 3))(*positions.flatten()) + + atomic_numbers = (ctypes.c_int * ntypes)(*self._idx_to_numbers(Z_of_atoms)) + + boxlo = (ctypes.c_double * 3)(0.0, 0.0, 0.0) + boxhi = (ctypes.c_double * 3)(cell[0], cell[1], cell[2]) + xy = cell[3] + xz = cell[4] + yz = cell[5] + xperiodic, yperiodic, zperiodic = atoms.get_pbc() + + lib = self._lib + assert lib is not None + lib.pair_set_atom(self.pair, natoms, ntypes, types, x_flat) + + xperiodic = xperiodic.astype(int) + yperiodic = yperiodic.astype(int) + zperiodic = zperiodic.astype(int) + lib.pair_set_domain( + self.pair, xperiodic, yperiodic, zperiodic, boxlo, boxhi, xy, xz, yz + ) + + lib.pair_run_settings( + self.pair, + self.rthr, + self.cnthr, + self.damp_name.encode('utf-8'), + self.func_name.encode('utf-8'), + ) + + lib.pair_run_coeff(self.pair, atomic_numbers) + lib.pair_run_compute(self.pair) + + result_E = lib.pair_get_energy(self.pair) + + result_F_ptr = lib.pair_get_force(self.pair) + result_F_size = natoms * 3 + result_F = np.ctypeslib.as_array( + result_F_ptr, shape=(result_F_size,) + ).reshape((natoms, 3)) + result_F = np.array(result_F) + result_F = result_F @ rotator + + result_S = lib.pair_get_stress(self.pair) + result_S = np.array(result_S.contents) + result_S = ( + self._tensor2stress(rotator.T @ self._stress2tensor(result_S) @ rotator) + / atoms.get_volume() + ) + + self.results = { + 'free_energy': result_E, + 'energy': result_E, + 'forces': result_F, + 'stress': result_S, + } + + def predict_one(self, atoms): + atoms = atoms.copy() + if atoms is None: + raise ValueError('No atoms to evaluate') + + if atoms.get_cell().sum() == 0: + print( + 'Warning: D3Calculator requires a cell.\n' + 'Warning: An orthogonal cell large enough is generated.' + ) + positions = atoms.get_positions() + min_pos = positions.min(axis=0) + max_pos = positions.max(axis=0) + max_cutoff = np.sqrt(max(self.rthr, self.cnthr)) * 0.52917726 + + cell_lengths = max_pos - min_pos + max_cutoff + 1.0 # extra margin + cell = np.eye(3) * cell_lengths + + atoms.set_cell(cell) + atoms.set_pbc([True, True, True]) # for minus positions + + cell, rotator = self._convert_domain_ase2lammps(atoms.get_cell()) + + Z_of_atoms = atoms.get_atomic_numbers() + natoms = len(atoms) + ntypes = len(set(Z_of_atoms)) + types = (ctypes.c_int * natoms)(*self._idx_to_types(Z_of_atoms)) + + positions = atoms.get_positions() @ rotator.T + x_flat = (ctypes.c_double * (natoms * 3))(*positions.flatten()) + + atomic_numbers = (ctypes.c_int * ntypes)(*self._idx_to_numbers(Z_of_atoms)) + + boxlo = (ctypes.c_double * 3)(0.0, 0.0, 0.0) + boxhi = (ctypes.c_double * 3)(cell[0], cell[1], cell[2]) + xy = cell[3] + xz = cell[4] + yz = cell[5] + xperiodic, yperiodic, zperiodic = atoms.get_pbc() + + lib = self._lib + assert lib is not None + lib.pair_set_atom(self.pair, natoms, ntypes, types, x_flat) + + xperiodic = xperiodic.astype(int) + yperiodic = yperiodic.astype(int) + zperiodic = zperiodic.astype(int) + lib.pair_set_domain( + self.pair, xperiodic, yperiodic, zperiodic, boxlo, boxhi, xy, xz, yz + ) + + lib.pair_run_settings( + self.pair, + self.rthr, + self.cnthr, + self.damp_name.encode('utf-8'), + self.func_name.encode('utf-8'), + ) + + lib.pair_run_coeff(self.pair, atomic_numbers) + lib.pair_run_compute(self.pair) + + result_E = lib.pair_get_energy(self.pair) + + result_F_ptr = lib.pair_get_force(self.pair) + result_F_size = natoms * 3 + result_F = np.ctypeslib.as_array( + result_F_ptr, shape=(result_F_size,) + ).reshape((natoms, 3)) + result_F = np.array(result_F) + result_F = result_F @ rotator + + result_S = lib.pair_get_stress(self.pair) + result_S = np.array(result_S.contents) + result_S = ( + self._tensor2stress(rotator.T @ self._stress2tensor(result_S) @ rotator) + / atoms.get_volume() + ) + + prediction = { + 'free_energy': float(result_E), + 'energy': float(result_E), + 'forces': result_F.copy(), + 'stress': result_S.copy(), + } + + return prediction + + + def __del__(self): + if self._lib is not None: + self._lib.pair_fin(self.pair) + self._lib = None + self.pair = None + diff --git a/mace-bench/3rdparty/SevenNet/sevenn/checkpoint.py b/mace-bench/3rdparty/SevenNet/sevenn/checkpoint.py index 859cde2..9d8e284 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/checkpoint.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/checkpoint.py @@ -1,552 +1,552 @@ -import os -import pathlib -import uuid -import warnings -from copy import deepcopy -from datetime import datetime -from typing import Any, Dict, Optional, Union - -import pandas as pd -from packaging.version import Version -from torch import Tensor -from torch import load as torch_load - -import sevenn -import sevenn._const as consts -import sevenn._keys as KEY -import sevenn.scripts.backward_compatibility as compat -from sevenn import model_build -from sevenn.nn.scale import get_resolved_shift_scale -from sevenn.nn.sequential import AtomGraphSequential - - -def assert_atoms(atoms1, atoms2, rtol=1e-5, atol=1e-6): - import numpy as np - - def acl(a, b, rtol=rtol, atol=atol): - return np.allclose(a, b, rtol=rtol, atol=atol) - - assert len(atoms1) == len(atoms2) - assert acl(atoms1.get_cell(), atoms2.get_cell()) - assert acl(atoms1.get_potential_energy(), atoms2.get_potential_energy()) - assert acl(atoms1.get_forces(), atoms2.get_forces(), rtol * 10, atol * 10) - assert acl( - atoms1.get_stress(voigt=False), - atoms2.get_stress(voigt=False), - rtol * 10, - atol * 10, - ) - # assert acl(atoms1.get_potential_energies(), atoms2.get_potential_energies()) - - -def copy_state_dict(state_dict) -> dict: - if isinstance(state_dict, dict): - return {key: copy_state_dict(value) for key, value in state_dict.items()} - elif isinstance(state_dict, list): - return [copy_state_dict(item) for item in state_dict] # type: ignore - elif isinstance(state_dict, Tensor): - return state_dict.clone() # type: ignore - else: - # For non-tensor values (e.g., scalars, None), return as-is - return state_dict - - -def _config_cp_routine(config): - cp_ver = Version(config.get('version', None)) - this_ver = Version(sevenn.__version__) - if cp_ver > this_ver: - warnings.warn(f'The checkpoint version ({cp_ver}) is newer than this source' - f'({this_ver}). This may cause unexpected behaviors') - - defaults = {**consts.model_defaults(config)} - config = compat.patch_old_config(config) # type: ignore - - scaler = model_build.init_shift_scale(config) - shift, scale = get_resolved_shift_scale( - scaler, config.get(KEY.TYPE_MAP), config.get(KEY.MODAL_MAP, None) - ) - config['shift'] = shift - config['scale'] = scale - - for k, v in defaults.items(): - if k in config: - continue - if os.getenv('SEVENN_DEBUG', False): - warnings.warn(f'{k} not in config, use default value {v}', UserWarning) - config[k] = v - - for k, v in config.items(): - if isinstance(v, Tensor): - config[k] = v.cpu() - return config - - -def _convert_e3nn_and_cueq(stct_src, stct_dst, src_config, from_cueq): - """ - manually check keys and assert if something unexpected happens - """ - n_layer = src_config['num_convolution_layer'] - - linear_module_names = [ - 'onehot_to_feature_x', - 'reduce_input_to_hidden', - 'reduce_hidden_to_energy', - ] - convolution_module_names = [] - fc_tensor_product_module_names = [] - for i in range(n_layer): - linear_module_names.append(f'{i}_self_interaction_1') - linear_module_names.append(f'{i}_self_interaction_2') - if src_config.get(KEY.SELF_CONNECTION_TYPE) == 'linear': - linear_module_names.append(f'{i}_self_connection_intro') - elif src_config.get(KEY.SELF_CONNECTION_TYPE) == 'nequip': - fc_tensor_product_module_names.append(f'{i}_self_connection_intro') - convolution_module_names.append(f'{i}_convolution') - - # Rule: those keys can be safely ignored before state dict load, - # except for linear.bias. This should be aborted in advance to - # this function. Others are not parameters but constants. - cue_only_linear_followers = ['linear.f.tp.f_fx.module.c'] - e3nn_only_linear_followers = ['linear.bias', 'linear.output_mask'] - ignores_in_linear = cue_only_linear_followers + e3nn_only_linear_followers - - cue_only_conv_followers = [ - 'convolution.f.tp.f_fx.module.c', - 'convolution.f.tp.module.module.f.module.module._f.data', - ] - e3nn_only_conv_followers = [ - 'convolution._compiled_main_left_right._w3j', - 'convolution.weight', - 'convolution.output_mask', - ] - ignores_in_conv = cue_only_conv_followers + e3nn_only_conv_followers - - cue_only_fc_followers = ['fc_tensor_product.f.tp.f_fx.module.c'] - e3nn_only_fc_followers = [ - 'fc_tensor_product.output_mask', - ] - ignores_in_fc = cue_only_fc_followers + e3nn_only_fc_followers - - updated_keys = [] - for k, v in stct_src.items(): - module_name = k.split('.')[0] - flag = False - if module_name in linear_module_names: - for ignore in ignores_in_linear: - if '.'.join([module_name, ignore]) in k: - flag = True - break - if not flag and k == '.'.join([module_name, 'linear.weight']): - updated_keys.append(k) - stct_dst[k] = v.clone().reshape(stct_dst[k].shape) - flag = True - assert flag, f'Unexpected key from linear: {k}' - elif module_name in convolution_module_names: - for ignore in ignores_in_conv: - if '.'.join([module_name, ignore]) in k: - flag = True - break - if not flag and ( - k.startswith(f'{module_name}.weight_nn') - or k == '.'.join([module_name, 'denominator']) - ): - updated_keys.append(k) - stct_dst[k] = v.clone().reshape(stct_dst[k].shape) - flag = True - assert flag, f'Unexpected key from linear: {k}' - elif module_name in fc_tensor_product_module_names: - for ignore in ignores_in_fc: - if '.'.join([module_name, ignore]) in k: - flag = True - break - if not flag and k == '.'.join([module_name, 'fc_tensor_product.weight']): - updated_keys.append(k) - stct_dst[k] = v.clone().reshape(stct_dst[k].shape) - flag = True - assert flag, f'Unexpected key from fc tensor product: {k}' - else: - # assert k in stct_dst - updated_keys.append(k) - stct_dst[k] = v.clone().reshape(stct_dst[k].shape) - - return stct_dst - - -class SevenNetCheckpoint: - """ - Tool box for checkpoint processed from SevenNet. - """ - - def __init__(self, checkpoint_path: Union[pathlib.Path, str]): - self._checkpoint_path = os.path.abspath(checkpoint_path) - self._config = None - self._epoch = None - self._model_state_dict = None - self._optimizer_state_dict = None - self._scheduler_state_dict = None - self._hash = None - self._time = None - - self._loaded = False - - def __repr__(self) -> str: - cfg = self.config # just alias - if len(cfg) == 0: - return '' - dct = { - 'Sevennet version': cfg.get('version', 'Not found'), - 'When': self.time, - 'Hash': self.hash, - 'Cutoff': cfg.get('cutoff'), - 'Channel': cfg.get('channel'), - 'Lmax': cfg.get('lmax'), - 'Group (parity)': 'O3' if cfg.get('is_parity') else 'SO3', - 'Interaction layers': cfg.get('num_convolution_layer'), - 'Self connection type': cfg.get('self_connection_type', 'nequip'), - 'Last epoch': self.epoch, - 'Elements': len(cfg.get('chemical_species', [])), - } - if cfg.get('use_modality', False): - dct['Modality'] = ', '.join(list(cfg.get('_modal_map', {}).keys())) - - df = pd.DataFrame.from_dict([dct]).T # type: ignore - df.columns = [''] - return df.to_string() - - @property - def checkpoint_path(self) -> str: - return str(self._checkpoint_path) - - @property - def config(self) -> Dict[str, Any]: - if not self._loaded: - self._load() - assert isinstance(self._config, dict) - return deepcopy(self._config) - - @property - def model_state_dict(self) -> Dict[str, Any]: - if not self._loaded: - self._load() - assert isinstance(self._model_state_dict, dict) - return copy_state_dict(self._model_state_dict) - - @property - def optimizer_state_dict(self) -> Dict[str, Any]: - if not self._loaded: - self._load() - assert isinstance(self._optimizer_state_dict, dict) - return copy_state_dict(self._optimizer_state_dict) - - @property - def scheduler_state_dict(self) -> Dict[str, Any]: - if not self._loaded: - self._load() - assert isinstance(self._scheduler_state_dict, dict) - return copy_state_dict(self._scheduler_state_dict) - - @property - def epoch(self) -> Optional[int]: - if not self._loaded: - self._load() - return self._epoch - - @property - def time(self) -> str: - if not self._loaded: - self._load() - assert isinstance(self._time, str) - return self._time - - @property - def hash(self) -> str: - if not self._loaded: - self._load() - assert isinstance(self._hash, str) - return self._hash - - def _load(self) -> None: - assert not self._loaded - cp_path = self.checkpoint_path # just alias - - cp = torch_load(cp_path, weights_only=False, map_location='cpu') - self._config_original = cp.get('config', {}) - self._model_state_dict = cp.get('model_state_dict', {}) - self._optimizer_state_dict = cp.get('optimizer_state_dict', {}) - self._scheduler_state_dict = cp.get('scheduler_state_dict', {}) - self._epoch = cp.get('epoch', None) - self._time = cp.get('time', 'Not found') - self._hash = cp.get('hash', 'Not found') - - if len(self._config_original) == 0: - warnings.warn(f'config is not found from {cp_path}') - self._config = {} - else: - self._config = _config_cp_routine(self._config_original) - - if len(self._model_state_dict) == 0: - warnings.warn(f'model_state_dict is not found from {cp_path}') - - self._loaded = True - - def build_model(self, backend: Optional[str] = None) -> AtomGraphSequential: - from .model_build import build_E3_equivariant_model - - use_cue = not backend or backend.lower() in ['cue', 'cueq'] - try: - cp_using_cue = self.config[KEY.CUEQUIVARIANCE_CONFIG]['use'] - except KeyError: - cp_using_cue = False - - if (not backend) or (use_cue == cp_using_cue): - # backend not given, or checkpoint backend is same as requested - model = build_E3_equivariant_model(self.config) - state_dict = compat.patch_state_dict_if_old( - self.model_state_dict, self.config, model - ) - else: - cfg_new = self.config - cfg_new[KEY.CUEQUIVARIANCE_CONFIG] = {'use': use_cue} - model = build_E3_equivariant_model(cfg_new) - stct_src = compat.patch_state_dict_if_old( - self.model_state_dict, self.config, model - ) - state_dict = _convert_e3nn_and_cueq( - stct_src, model.state_dict(), self.config, from_cueq=cp_using_cue - ) - - missing, not_used = model.load_state_dict(state_dict, strict=False) - if len(not_used) > 0: - warnings.warn(f'Some keys are not used: {not_used}', UserWarning) - - assert len(missing) == 0, f'Missing keys: {missing}' - return model - - def yaml_dict(self, mode: str) -> dict: - """ - Return dict for input.yaml from checkpoint config - Dataset paths and statistic values are removed intentionally - """ - if mode not in ['reproduce', 'continue', 'continue_modal']: - raise ValueError(f'Unknown mode: {mode}') - - ignore = [ - 'when', - KEY.DDP_BACKEND, - KEY.LOCAL_RANK, - KEY.IS_DDP, - KEY.DEVICE, - KEY.MODEL_TYPE, - KEY.SHIFT, - KEY.SCALE, - KEY.CONV_DENOMINATOR, - KEY.SAVE_DATASET, - KEY.SAVE_BY_LABEL, - KEY.SAVE_BY_TRAIN_VALID, - KEY.CONTINUE, - KEY.LOAD_DATASET, # old - ] - - cfg = self.config - len_atoms = len(cfg[KEY.TYPE_MAP]) - - world_size = cfg.pop(KEY.WORLD_SIZE, 1) - cfg[KEY.BATCH_SIZE] = cfg[KEY.BATCH_SIZE] * world_size - cfg[KEY.LOAD_TRAINSET] = '**path_to_training_set**' - - major, minor, _ = cfg.pop('version', '0.0.0').split('.')[:3] - if int(major) == 0 and int(minor) <= 9: - warnings.warn('checkpoint version too old, yaml may wrong') - - ret = {'model': {}, 'train': {}, 'data': {}} - for k, v in cfg.items(): - if k.startswith('_') or k in ignore or k.endswith('set_path'): - continue - if k in consts.DEFAULT_E3_EQUIVARIANT_MODEL_CONFIG: - ret['model'][k] = v - elif k in consts.DEFAULT_TRAINING_CONFIG: - ret['train'][k] = v - elif k in consts.DEFAULT_DATA_CONFIG: - ret['data'][k] = v - - ret['model'][KEY.CHEMICAL_SPECIES] = ( - 'univ' if len_atoms == consts.NUM_UNIV_ELEMENT else 'auto' - ) - ret['data'][KEY.LOAD_TRAINSET] = '**path_to_trainset**' - ret['data'][KEY.LOAD_VALIDSET] = '**path_to_validset**' - - # TODO - ret['data'][KEY.SHIFT] = '**failed to infer shift, should be set**' - ret['data'][KEY.SCALE] = '**failed to infer scale, should be set**' - - if mode.startswith('continue'): - ret['train'].update( - {KEY.CONTINUE: {KEY.CHECKPOINT: self.checkpoint_path}} - ) - modal_names = None - if mode == 'continue_modal' and not cfg.get(KEY.USE_MODALITY, False): - ret['train'][KEY.USE_MODALITY] = True - - # suggest defaults - ret['model'][KEY.USE_MODAL_NODE_EMBEDDING] = False - ret['model'][KEY.USE_MODAL_SELF_INTER_INTRO] = True - ret['model'][KEY.USE_MODAL_SELF_INTER_OUTRO] = True - ret['model'][KEY.USE_MODAL_OUTPUT_BLOCK] = True - - ret['data'][KEY.USE_MODAL_WISE_SHIFT] = True - ret['data'][KEY.USE_MODAL_WISE_SCALE] = False - - modal_names = ['my_modal1', 'my_modal2'] - elif cfg.get(KEY.USE_MODALITY, False): - modal_names = list(cfg[KEY.MODAL_MAP].keys()) - - if modal_names: - ret['data'][KEY.LOAD_TRAINSET] = [ - {'data_modality': mm, 'file_list': [{'file': f'**path_to_{mm}**'}]} - for mm in modal_names - ] - - return ret - - def append_modal( - self, - dst_config, - original_modal_name: str = 'origin', - working_dir: str = os.getcwd(), - ): - """ """ - import sevenn.train.modal_dataset as modal_dataset - from sevenn.model_build import init_shift_scale - from sevenn.scripts.convert_model_modality import _append_modal_weight - - src_config = self.config - src_has_no_modal = not src_config.get(KEY.USE_MODALITY, False) - - # inherit element things first - chem_keys = [ - KEY.TYPE_MAP, - KEY.NUM_SPECIES, - KEY.CHEMICAL_SPECIES, - KEY.CHEMICAL_SPECIES_BY_ATOMIC_NUMBER, - ] - dst_config.update({k: src_config[k] for k in chem_keys}) - - if dst_config[KEY.USE_MODAL_WISE_SHIFT] and ( - KEY.SHIFT not in dst_config or not isinstance(dst_config[KEY.SHIFT], str) - ): - raise ValueError('To use modal wise shift, keyword shift is required') - if dst_config[KEY.USE_MODAL_WISE_SCALE] and ( - KEY.SCALE not in dst_config or not isinstance(dst_config[KEY.SCALE], str) - ): - raise ValueError('To use modal wise scale, keyword scale is required') - - if src_has_no_modal and not dst_config[KEY.USE_MODAL_WISE_SHIFT]: - dst_config[KEY.SHIFT] = src_config[KEY.SHIFT] - if src_has_no_modal and not dst_config[KEY.USE_MODAL_WISE_SCALE]: - dst_config[KEY.SCALE] = src_config[KEY.SCALE] - - # get statistics of given datasets of yaml - # dst_config updated - _ = modal_dataset.from_config(dst_config, working_dir=working_dir) - dst_modal_map = dst_config[KEY.MODAL_MAP] - - found_modal_names = list(dst_modal_map.keys()) - if len(found_modal_names) == 0: - raise ValueError('No modality is found from config') - - # Check difference btw given modals and new modal map - orig_modal_map = src_config.get(KEY.MODAL_MAP, {original_modal_name: 0}) - assert isinstance(orig_modal_map, dict) - new_modal_map = orig_modal_map.copy() - for modal_name in found_modal_names: - if modal_name in orig_modal_map: # duplicate, skipping - continue - new_modal_map[modal_name] = len(new_modal_map) # assign new - print(f'New modals: {list(new_modal_map.keys())}') - - if src_has_no_modal: - append_num = len(new_modal_map) - else: - append_num = len(new_modal_map) - len(orig_modal_map) - if append_num == 0: - raise ValueError('Nothing to append from checkpoint') - - dst_config[KEY.NUM_MODALITIES] = len(new_modal_map) - dst_config[KEY.MODAL_MAP] = new_modal_map - - # update dst_config's shift scales based on src_config - for ss_key, use_mw in ( - (KEY.SHIFT, dst_config[KEY.USE_MODAL_WISE_SHIFT]), - (KEY.SCALE, dst_config[KEY.USE_MODAL_WISE_SCALE]), - ): - if not use_mw: # not using mw ss, just assign - assert not isinstance(dst_config[ss_key], dict) - dst_config[ss_key] = src_config[ss_key] - elif src_has_no_modal: - assert isinstance(dst_config[ss_key], dict) - # mw ss, update by dict but use original_modal_name - dst_config[ss_key].update({original_modal_name: src_config[ss_key]}) - else: - assert isinstance(dst_config[ss_key], dict) - # mw ss, update by dict - dst_config[ss_key].update(src_config[ss_key]) - scaler = init_shift_scale(dst_config) - - # finally, prepare updated continuable state dict using above - orig_model = self.build_model() - orig_state_dict = orig_model.state_dict() - - new_state_dict = copy_state_dict(orig_state_dict) - for stct_key in orig_state_dict: - sp = stct_key.split('.') - k, follower = sp[0], '.'.join(sp[1:]) - if k == 'rescale_atomic_energy' and follower == 'shift': - new_state_dict[stct_key] = scaler.shift.clone() - elif k == 'rescale_atomic_energy' and follower == 'scale': - new_state_dict[stct_key] = scaler.scale.clone() - elif follower == 'linear.weight' and ( # append linear layer - ( - dst_config[KEY.USE_MODAL_NODE_EMBEDDING] - and k.endswith('onehot_to_feature_x') - ) - or ( - dst_config[KEY.USE_MODAL_SELF_INTER_INTRO] - and k.endswith('self_interaction_1') - ) - or ( - dst_config[KEY.USE_MODAL_SELF_INTER_OUTRO] - and k.endswith('self_interaction_2') - ) - or ( - dst_config[KEY.USE_MODAL_OUTPUT_BLOCK] - and k == 'reduce_input_to_hidden' - ) - ): - orig_linear = getattr(orig_model._modules[k], 'linear') - # assert normalization element - new_state_dict[stct_key] = _append_modal_weight( - orig_state_dict, - k, - orig_linear.irreps_in, - orig_linear.irreps_out, - append_num, - ) - - dst_config['version'] = sevenn.__version__ - - return new_state_dict - - def get_checkpoint_dict(self) -> dict: - """ - Return duplicate of this checkpoint with new hash and time. - Convenient for creating variant of the checkpoint - """ - return { - 'config': self.config, - 'epoch': self.epoch, - 'model_state_dict': self.model_state_dict, - 'optimizer_state_dict': self.optimizer_state_dict, - 'scheduler_state_dict': self.scheduler_state_dict, - 'time': datetime.now().strftime('%Y-%m-%d %H:%M'), - 'hash': uuid.uuid4().hex, - } +import os +import pathlib +import uuid +import warnings +from copy import deepcopy +from datetime import datetime +from typing import Any, Dict, Optional, Union + +import pandas as pd +from packaging.version import Version +from torch import Tensor +from torch import load as torch_load + +import sevenn +import sevenn._const as consts +import sevenn._keys as KEY +import sevenn.scripts.backward_compatibility as compat +from sevenn import model_build +from sevenn.nn.scale import get_resolved_shift_scale +from sevenn.nn.sequential import AtomGraphSequential + + +def assert_atoms(atoms1, atoms2, rtol=1e-5, atol=1e-6): + import numpy as np + + def acl(a, b, rtol=rtol, atol=atol): + return np.allclose(a, b, rtol=rtol, atol=atol) + + assert len(atoms1) == len(atoms2) + assert acl(atoms1.get_cell(), atoms2.get_cell()) + assert acl(atoms1.get_potential_energy(), atoms2.get_potential_energy()) + assert acl(atoms1.get_forces(), atoms2.get_forces(), rtol * 10, atol * 10) + assert acl( + atoms1.get_stress(voigt=False), + atoms2.get_stress(voigt=False), + rtol * 10, + atol * 10, + ) + # assert acl(atoms1.get_potential_energies(), atoms2.get_potential_energies()) + + +def copy_state_dict(state_dict) -> dict: + if isinstance(state_dict, dict): + return {key: copy_state_dict(value) for key, value in state_dict.items()} + elif isinstance(state_dict, list): + return [copy_state_dict(item) for item in state_dict] # type: ignore + elif isinstance(state_dict, Tensor): + return state_dict.clone() # type: ignore + else: + # For non-tensor values (e.g., scalars, None), return as-is + return state_dict + + +def _config_cp_routine(config): + cp_ver = Version(config.get('version', None)) + this_ver = Version(sevenn.__version__) + if cp_ver > this_ver: + warnings.warn(f'The checkpoint version ({cp_ver}) is newer than this source' + f'({this_ver}). This may cause unexpected behaviors') + + defaults = {**consts.model_defaults(config)} + config = compat.patch_old_config(config) # type: ignore + + scaler = model_build.init_shift_scale(config) + shift, scale = get_resolved_shift_scale( + scaler, config.get(KEY.TYPE_MAP), config.get(KEY.MODAL_MAP, None) + ) + config['shift'] = shift + config['scale'] = scale + + for k, v in defaults.items(): + if k in config: + continue + if os.getenv('SEVENN_DEBUG', False): + warnings.warn(f'{k} not in config, use default value {v}', UserWarning) + config[k] = v + + for k, v in config.items(): + if isinstance(v, Tensor): + config[k] = v.cpu() + return config + + +def _convert_e3nn_and_cueq(stct_src, stct_dst, src_config, from_cueq): + """ + manually check keys and assert if something unexpected happens + """ + n_layer = src_config['num_convolution_layer'] + + linear_module_names = [ + 'onehot_to_feature_x', + 'reduce_input_to_hidden', + 'reduce_hidden_to_energy', + ] + convolution_module_names = [] + fc_tensor_product_module_names = [] + for i in range(n_layer): + linear_module_names.append(f'{i}_self_interaction_1') + linear_module_names.append(f'{i}_self_interaction_2') + if src_config.get(KEY.SELF_CONNECTION_TYPE) == 'linear': + linear_module_names.append(f'{i}_self_connection_intro') + elif src_config.get(KEY.SELF_CONNECTION_TYPE) == 'nequip': + fc_tensor_product_module_names.append(f'{i}_self_connection_intro') + convolution_module_names.append(f'{i}_convolution') + + # Rule: those keys can be safely ignored before state dict load, + # except for linear.bias. This should be aborted in advance to + # this function. Others are not parameters but constants. + cue_only_linear_followers = ['linear.f.tp.f_fx.module.c'] + e3nn_only_linear_followers = ['linear.bias', 'linear.output_mask'] + ignores_in_linear = cue_only_linear_followers + e3nn_only_linear_followers + + cue_only_conv_followers = [ + 'convolution.f.tp.f_fx.module.c', + 'convolution.f.tp.module.module.f.module.module._f.data', + ] + e3nn_only_conv_followers = [ + 'convolution._compiled_main_left_right._w3j', + 'convolution.weight', + 'convolution.output_mask', + ] + ignores_in_conv = cue_only_conv_followers + e3nn_only_conv_followers + + cue_only_fc_followers = ['fc_tensor_product.f.tp.f_fx.module.c'] + e3nn_only_fc_followers = [ + 'fc_tensor_product.output_mask', + ] + ignores_in_fc = cue_only_fc_followers + e3nn_only_fc_followers + + updated_keys = [] + for k, v in stct_src.items(): + module_name = k.split('.')[0] + flag = False + if module_name in linear_module_names: + for ignore in ignores_in_linear: + if '.'.join([module_name, ignore]) in k: + flag = True + break + if not flag and k == '.'.join([module_name, 'linear.weight']): + updated_keys.append(k) + stct_dst[k] = v.clone().reshape(stct_dst[k].shape) + flag = True + assert flag, f'Unexpected key from linear: {k}' + elif module_name in convolution_module_names: + for ignore in ignores_in_conv: + if '.'.join([module_name, ignore]) in k: + flag = True + break + if not flag and ( + k.startswith(f'{module_name}.weight_nn') + or k == '.'.join([module_name, 'denominator']) + ): + updated_keys.append(k) + stct_dst[k] = v.clone().reshape(stct_dst[k].shape) + flag = True + assert flag, f'Unexpected key from linear: {k}' + elif module_name in fc_tensor_product_module_names: + for ignore in ignores_in_fc: + if '.'.join([module_name, ignore]) in k: + flag = True + break + if not flag and k == '.'.join([module_name, 'fc_tensor_product.weight']): + updated_keys.append(k) + stct_dst[k] = v.clone().reshape(stct_dst[k].shape) + flag = True + assert flag, f'Unexpected key from fc tensor product: {k}' + else: + # assert k in stct_dst + updated_keys.append(k) + stct_dst[k] = v.clone().reshape(stct_dst[k].shape) + + return stct_dst + + +class SevenNetCheckpoint: + """ + Tool box for checkpoint processed from SevenNet. + """ + + def __init__(self, checkpoint_path: Union[pathlib.Path, str]): + self._checkpoint_path = os.path.abspath(checkpoint_path) + self._config = None + self._epoch = None + self._model_state_dict = None + self._optimizer_state_dict = None + self._scheduler_state_dict = None + self._hash = None + self._time = None + + self._loaded = False + + def __repr__(self) -> str: + cfg = self.config # just alias + if len(cfg) == 0: + return '' + dct = { + 'Sevennet version': cfg.get('version', 'Not found'), + 'When': self.time, + 'Hash': self.hash, + 'Cutoff': cfg.get('cutoff'), + 'Channel': cfg.get('channel'), + 'Lmax': cfg.get('lmax'), + 'Group (parity)': 'O3' if cfg.get('is_parity') else 'SO3', + 'Interaction layers': cfg.get('num_convolution_layer'), + 'Self connection type': cfg.get('self_connection_type', 'nequip'), + 'Last epoch': self.epoch, + 'Elements': len(cfg.get('chemical_species', [])), + } + if cfg.get('use_modality', False): + dct['Modality'] = ', '.join(list(cfg.get('_modal_map', {}).keys())) + + df = pd.DataFrame.from_dict([dct]).T # type: ignore + df.columns = [''] + return df.to_string() + + @property + def checkpoint_path(self) -> str: + return str(self._checkpoint_path) + + @property + def config(self) -> Dict[str, Any]: + if not self._loaded: + self._load() + assert isinstance(self._config, dict) + return deepcopy(self._config) + + @property + def model_state_dict(self) -> Dict[str, Any]: + if not self._loaded: + self._load() + assert isinstance(self._model_state_dict, dict) + return copy_state_dict(self._model_state_dict) + + @property + def optimizer_state_dict(self) -> Dict[str, Any]: + if not self._loaded: + self._load() + assert isinstance(self._optimizer_state_dict, dict) + return copy_state_dict(self._optimizer_state_dict) + + @property + def scheduler_state_dict(self) -> Dict[str, Any]: + if not self._loaded: + self._load() + assert isinstance(self._scheduler_state_dict, dict) + return copy_state_dict(self._scheduler_state_dict) + + @property + def epoch(self) -> Optional[int]: + if not self._loaded: + self._load() + return self._epoch + + @property + def time(self) -> str: + if not self._loaded: + self._load() + assert isinstance(self._time, str) + return self._time + + @property + def hash(self) -> str: + if not self._loaded: + self._load() + assert isinstance(self._hash, str) + return self._hash + + def _load(self) -> None: + assert not self._loaded + cp_path = self.checkpoint_path # just alias + + cp = torch_load(cp_path, weights_only=False, map_location='cpu') + self._config_original = cp.get('config', {}) + self._model_state_dict = cp.get('model_state_dict', {}) + self._optimizer_state_dict = cp.get('optimizer_state_dict', {}) + self._scheduler_state_dict = cp.get('scheduler_state_dict', {}) + self._epoch = cp.get('epoch', None) + self._time = cp.get('time', 'Not found') + self._hash = cp.get('hash', 'Not found') + + if len(self._config_original) == 0: + warnings.warn(f'config is not found from {cp_path}') + self._config = {} + else: + self._config = _config_cp_routine(self._config_original) + + if len(self._model_state_dict) == 0: + warnings.warn(f'model_state_dict is not found from {cp_path}') + + self._loaded = True + + def build_model(self, backend: Optional[str] = None) -> AtomGraphSequential: + from .model_build import build_E3_equivariant_model + + use_cue = not backend or backend.lower() in ['cue', 'cueq'] + try: + cp_using_cue = self.config[KEY.CUEQUIVARIANCE_CONFIG]['use'] + except KeyError: + cp_using_cue = False + + if (not backend) or (use_cue == cp_using_cue): + # backend not given, or checkpoint backend is same as requested + model = build_E3_equivariant_model(self.config) + state_dict = compat.patch_state_dict_if_old( + self.model_state_dict, self.config, model + ) + else: + cfg_new = self.config + cfg_new[KEY.CUEQUIVARIANCE_CONFIG] = {'use': use_cue} + model = build_E3_equivariant_model(cfg_new) + stct_src = compat.patch_state_dict_if_old( + self.model_state_dict, self.config, model + ) + state_dict = _convert_e3nn_and_cueq( + stct_src, model.state_dict(), self.config, from_cueq=cp_using_cue + ) + + missing, not_used = model.load_state_dict(state_dict, strict=False) + if len(not_used) > 0: + warnings.warn(f'Some keys are not used: {not_used}', UserWarning) + + assert len(missing) == 0, f'Missing keys: {missing}' + return model + + def yaml_dict(self, mode: str) -> dict: + """ + Return dict for input.yaml from checkpoint config + Dataset paths and statistic values are removed intentionally + """ + if mode not in ['reproduce', 'continue', 'continue_modal']: + raise ValueError(f'Unknown mode: {mode}') + + ignore = [ + 'when', + KEY.DDP_BACKEND, + KEY.LOCAL_RANK, + KEY.IS_DDP, + KEY.DEVICE, + KEY.MODEL_TYPE, + KEY.SHIFT, + KEY.SCALE, + KEY.CONV_DENOMINATOR, + KEY.SAVE_DATASET, + KEY.SAVE_BY_LABEL, + KEY.SAVE_BY_TRAIN_VALID, + KEY.CONTINUE, + KEY.LOAD_DATASET, # old + ] + + cfg = self.config + len_atoms = len(cfg[KEY.TYPE_MAP]) + + world_size = cfg.pop(KEY.WORLD_SIZE, 1) + cfg[KEY.BATCH_SIZE] = cfg[KEY.BATCH_SIZE] * world_size + cfg[KEY.LOAD_TRAINSET] = '**path_to_training_set**' + + major, minor, _ = cfg.pop('version', '0.0.0').split('.')[:3] + if int(major) == 0 and int(minor) <= 9: + warnings.warn('checkpoint version too old, yaml may wrong') + + ret = {'model': {}, 'train': {}, 'data': {}} + for k, v in cfg.items(): + if k.startswith('_') or k in ignore or k.endswith('set_path'): + continue + if k in consts.DEFAULT_E3_EQUIVARIANT_MODEL_CONFIG: + ret['model'][k] = v + elif k in consts.DEFAULT_TRAINING_CONFIG: + ret['train'][k] = v + elif k in consts.DEFAULT_DATA_CONFIG: + ret['data'][k] = v + + ret['model'][KEY.CHEMICAL_SPECIES] = ( + 'univ' if len_atoms == consts.NUM_UNIV_ELEMENT else 'auto' + ) + ret['data'][KEY.LOAD_TRAINSET] = '**path_to_trainset**' + ret['data'][KEY.LOAD_VALIDSET] = '**path_to_validset**' + + # TODO + ret['data'][KEY.SHIFT] = '**failed to infer shift, should be set**' + ret['data'][KEY.SCALE] = '**failed to infer scale, should be set**' + + if mode.startswith('continue'): + ret['train'].update( + {KEY.CONTINUE: {KEY.CHECKPOINT: self.checkpoint_path}} + ) + modal_names = None + if mode == 'continue_modal' and not cfg.get(KEY.USE_MODALITY, False): + ret['train'][KEY.USE_MODALITY] = True + + # suggest defaults + ret['model'][KEY.USE_MODAL_NODE_EMBEDDING] = False + ret['model'][KEY.USE_MODAL_SELF_INTER_INTRO] = True + ret['model'][KEY.USE_MODAL_SELF_INTER_OUTRO] = True + ret['model'][KEY.USE_MODAL_OUTPUT_BLOCK] = True + + ret['data'][KEY.USE_MODAL_WISE_SHIFT] = True + ret['data'][KEY.USE_MODAL_WISE_SCALE] = False + + modal_names = ['my_modal1', 'my_modal2'] + elif cfg.get(KEY.USE_MODALITY, False): + modal_names = list(cfg[KEY.MODAL_MAP].keys()) + + if modal_names: + ret['data'][KEY.LOAD_TRAINSET] = [ + {'data_modality': mm, 'file_list': [{'file': f'**path_to_{mm}**'}]} + for mm in modal_names + ] + + return ret + + def append_modal( + self, + dst_config, + original_modal_name: str = 'origin', + working_dir: str = os.getcwd(), + ): + """ """ + import sevenn.train.modal_dataset as modal_dataset + from sevenn.model_build import init_shift_scale + from sevenn.scripts.convert_model_modality import _append_modal_weight + + src_config = self.config + src_has_no_modal = not src_config.get(KEY.USE_MODALITY, False) + + # inherit element things first + chem_keys = [ + KEY.TYPE_MAP, + KEY.NUM_SPECIES, + KEY.CHEMICAL_SPECIES, + KEY.CHEMICAL_SPECIES_BY_ATOMIC_NUMBER, + ] + dst_config.update({k: src_config[k] for k in chem_keys}) + + if dst_config[KEY.USE_MODAL_WISE_SHIFT] and ( + KEY.SHIFT not in dst_config or not isinstance(dst_config[KEY.SHIFT], str) + ): + raise ValueError('To use modal wise shift, keyword shift is required') + if dst_config[KEY.USE_MODAL_WISE_SCALE] and ( + KEY.SCALE not in dst_config or not isinstance(dst_config[KEY.SCALE], str) + ): + raise ValueError('To use modal wise scale, keyword scale is required') + + if src_has_no_modal and not dst_config[KEY.USE_MODAL_WISE_SHIFT]: + dst_config[KEY.SHIFT] = src_config[KEY.SHIFT] + if src_has_no_modal and not dst_config[KEY.USE_MODAL_WISE_SCALE]: + dst_config[KEY.SCALE] = src_config[KEY.SCALE] + + # get statistics of given datasets of yaml + # dst_config updated + _ = modal_dataset.from_config(dst_config, working_dir=working_dir) + dst_modal_map = dst_config[KEY.MODAL_MAP] + + found_modal_names = list(dst_modal_map.keys()) + if len(found_modal_names) == 0: + raise ValueError('No modality is found from config') + + # Check difference btw given modals and new modal map + orig_modal_map = src_config.get(KEY.MODAL_MAP, {original_modal_name: 0}) + assert isinstance(orig_modal_map, dict) + new_modal_map = orig_modal_map.copy() + for modal_name in found_modal_names: + if modal_name in orig_modal_map: # duplicate, skipping + continue + new_modal_map[modal_name] = len(new_modal_map) # assign new + print(f'New modals: {list(new_modal_map.keys())}') + + if src_has_no_modal: + append_num = len(new_modal_map) + else: + append_num = len(new_modal_map) - len(orig_modal_map) + if append_num == 0: + raise ValueError('Nothing to append from checkpoint') + + dst_config[KEY.NUM_MODALITIES] = len(new_modal_map) + dst_config[KEY.MODAL_MAP] = new_modal_map + + # update dst_config's shift scales based on src_config + for ss_key, use_mw in ( + (KEY.SHIFT, dst_config[KEY.USE_MODAL_WISE_SHIFT]), + (KEY.SCALE, dst_config[KEY.USE_MODAL_WISE_SCALE]), + ): + if not use_mw: # not using mw ss, just assign + assert not isinstance(dst_config[ss_key], dict) + dst_config[ss_key] = src_config[ss_key] + elif src_has_no_modal: + assert isinstance(dst_config[ss_key], dict) + # mw ss, update by dict but use original_modal_name + dst_config[ss_key].update({original_modal_name: src_config[ss_key]}) + else: + assert isinstance(dst_config[ss_key], dict) + # mw ss, update by dict + dst_config[ss_key].update(src_config[ss_key]) + scaler = init_shift_scale(dst_config) + + # finally, prepare updated continuable state dict using above + orig_model = self.build_model() + orig_state_dict = orig_model.state_dict() + + new_state_dict = copy_state_dict(orig_state_dict) + for stct_key in orig_state_dict: + sp = stct_key.split('.') + k, follower = sp[0], '.'.join(sp[1:]) + if k == 'rescale_atomic_energy' and follower == 'shift': + new_state_dict[stct_key] = scaler.shift.clone() + elif k == 'rescale_atomic_energy' and follower == 'scale': + new_state_dict[stct_key] = scaler.scale.clone() + elif follower == 'linear.weight' and ( # append linear layer + ( + dst_config[KEY.USE_MODAL_NODE_EMBEDDING] + and k.endswith('onehot_to_feature_x') + ) + or ( + dst_config[KEY.USE_MODAL_SELF_INTER_INTRO] + and k.endswith('self_interaction_1') + ) + or ( + dst_config[KEY.USE_MODAL_SELF_INTER_OUTRO] + and k.endswith('self_interaction_2') + ) + or ( + dst_config[KEY.USE_MODAL_OUTPUT_BLOCK] + and k == 'reduce_input_to_hidden' + ) + ): + orig_linear = getattr(orig_model._modules[k], 'linear') + # assert normalization element + new_state_dict[stct_key] = _append_modal_weight( + orig_state_dict, + k, + orig_linear.irreps_in, + orig_linear.irreps_out, + append_num, + ) + + dst_config['version'] = sevenn.__version__ + + return new_state_dict + + def get_checkpoint_dict(self) -> dict: + """ + Return duplicate of this checkpoint with new hash and time. + Convenient for creating variant of the checkpoint + """ + return { + 'config': self.config, + 'epoch': self.epoch, + 'model_state_dict': self.model_state_dict, + 'optimizer_state_dict': self.optimizer_state_dict, + 'scheduler_state_dict': self.scheduler_state_dict, + 'time': datetime.now().strftime('%Y-%m-%d %H:%M'), + 'hash': uuid.uuid4().hex, + } diff --git a/mace-bench/3rdparty/SevenNet/sevenn/error_recorder.py b/mace-bench/3rdparty/SevenNet/sevenn/error_recorder.py index 48f4a20..5999aea 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/error_recorder.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/error_recorder.py @@ -1,430 +1,430 @@ -from copy import deepcopy -from typing import Any, Callable, Dict, List, Optional, Tuple - -import torch -import torch.distributed as dist - -import sevenn._keys as KEY -from sevenn.train.loss import LossDefinition - -from .atom_graph_data import AtomGraphData -from .train.optim import loss_dict - -_ERROR_TYPES = { - 'TotalEnergy': { - 'name': 'Energy', - 'ref_key': KEY.ENERGY, - 'pred_key': KEY.PRED_TOTAL_ENERGY, - 'unit': 'eV', - 'vdim': 1, - }, - 'Energy': { # by default per-atom for energy - 'name': 'Energy', - 'ref_key': KEY.ENERGY, - 'pred_key': KEY.PRED_TOTAL_ENERGY, - 'unit': 'eV/atom', - 'per_atom': True, - 'vdim': 1, - }, - 'Force': { - 'name': 'Force', - 'ref_key': KEY.FORCE, - 'pred_key': KEY.PRED_FORCE, - 'unit': 'eV/Å', - 'vdim': 3, - }, - 'Stress': { - 'name': 'Stress', - 'ref_key': KEY.STRESS, - 'pred_key': KEY.PRED_STRESS, - 'unit': 'kbar', - 'coeff': 1602.1766208, - 'vdim': 6, - }, - 'Stress_GPa': { - 'name': 'Stress', - 'ref_key': KEY.STRESS, - 'pred_key': KEY.PRED_STRESS, - 'unit': 'GPa', - 'coeff': 160.21766208, - 'vdim': 6, - }, - 'TotalLoss': { - 'name': 'TotalLoss', - 'unit': None, - }, -} - - -def get_err_type(name: str) -> Dict[str, Any]: - return deepcopy(_ERROR_TYPES[name]) - - -def _get_loss_function_from_name(loss_functions, name): - for loss_def, w in loss_functions: - if loss_def.name.lower() == name.lower(): - return loss_def, w - return None, None - - -class AverageNumber: - def __init__(self): - self._sum = 0.0 - self._count = 0 - - def update(self, values: torch.Tensor): - self._sum += values.sum().item() - self._count += values.numel() - - def _ddp_reduce(self, device): - _sum = torch.tensor(self._sum, device=device) - _count = torch.tensor(self._count, device=device) - dist.all_reduce(_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(_count, op=dist.ReduceOp.SUM) - self._sum = _sum.item() - self._count = _count.item() - - def get(self): - if self._count == 0: - return torch.nan - return self._sum / self._count - - -class ErrorMetric: - """ - Base class for error metrics We always average error by # of structures, - and designed to collect errors in the middle of iteration (by AverageNumber) - """ - - def __init__( - self, - name: str, - ref_key: str, - pred_key: str, - coeff: float = 1.0, - unit: Optional[str] = None, - per_atom: bool = False, - ignore_unlabeled: bool = True, - **kwargs, - ): - self.name = name - self.unit = unit - self.coeff = coeff - self.ref_key = ref_key - self.pred_key = pred_key - self.per_atom = per_atom - self.ignore_unlabeled = ignore_unlabeled - self.value = AverageNumber() - - def update(self, output: AtomGraphData): - raise NotImplementedError - - def _retrieve(self, output: AtomGraphData): - y_ref = output[self.ref_key] * self.coeff - y_pred = output[self.pred_key] * self.coeff - if self.per_atom: - assert y_ref.dim() == 1 and y_pred.dim() == 1 - natoms = output[KEY.NUM_ATOMS] - y_ref = y_ref / natoms - y_pred = y_pred / natoms - if self.ignore_unlabeled: - unlabelled_idx = torch.isnan(y_ref) - y_ref = y_ref[~unlabelled_idx] - y_pred = y_pred[~unlabelled_idx] - return y_ref, y_pred - - def ddp_reduce(self, device): - self.value._ddp_reduce(device) - - def reset(self): - self.value = AverageNumber() - - def get(self): - return self.value.get() - - def key_str(self, with_unit=True): - if self.unit is None or not with_unit: - return self.name - else: - return f'{self.name} ({self.unit})' - - def __str__(self): - return f'{self.key_str()}: {self.value.get():.6f}' - - -class RMSError(ErrorMetric): - """ - Vector squared error - """ - - def __init__(self, vdim: int = 1, **kwargs): - super().__init__(**kwargs) - self.vdim = vdim - self._se = torch.nn.MSELoss(reduction='none') - - def _square_error(self, y_ref, y_pred, vdim: int): - return self._se(y_ref.view(-1, vdim), y_pred.view(-1, vdim)).sum(dim=1) - - def update(self, output: AtomGraphData): - y_ref, y_pred = self._retrieve(output) - se = self._square_error(y_ref, y_pred, self.vdim) - self.value.update(se) - - def get(self): - return self.value.get() ** 0.5 - - -class ComponentRMSError(ErrorMetric): - """ - Ignore vector dim and just average over components - Results smaller error - """ - - def __init__(self, **kwargs): - super().__init__(**kwargs) - self._se = torch.nn.MSELoss(reduction='none') - - def _square_error(self, y_ref, y_pred): - return self._se(y_ref, y_pred) - - def update(self, output: AtomGraphData): - y_ref, y_pred = self._retrieve(output) - y_ref = y_ref.view(-1) - y_pred = y_pred.view(-1) - se = self._square_error(y_ref, y_pred) - self.value.update(se) - - def get(self): - return self.value.get() ** 0.5 - - -class MAError(ErrorMetric): - """ - Average over all component - """ - - def __init__(self, **kwargs): - super().__init__(**kwargs) - - def _square_error(self, y_ref, y_pred): - return torch.abs(y_ref - y_pred) - - def update(self, output: AtomGraphData): - y_ref, y_pred = self._retrieve(output) - y_ref = y_ref.reshape((-1,)) - y_pred = y_pred.reshape((-1,)) - se = self._square_error(y_ref, y_pred) - self.value.update(se) - - -class CustomError(ErrorMetric): - """ - Custom error metric - Args: - func: a function that takes y_ref and y_pred - and returns a list of errors - """ - - def __init__(self, func: Callable, **kwargs): - super().__init__(**kwargs) - self.func = func - - def update(self, output: AtomGraphData): - y_ref, y_pred = self._retrieve(output) - se = self.func(y_ref, y_pred) if len(y_ref) > 0 else torch.tensor([]) - self.value.update(se) - - -class LossError(ErrorMetric): - """ - Error metric that record loss - """ - - def __init__( - self, - name: str, - loss_def: LossDefinition, - **kwargs, - ): - super().__init__( - name, - ignore_unlabeld=loss_def.ignore_unlabeled, - **kwargs, - ) - self.loss_def = loss_def - - def update(self, output: AtomGraphData): - loss = self.loss_def.get_loss(output) # type: ignore - self.value.update(loss) # type: ignore - - -class CombinedError(ErrorMetric): - """ - Combine multiple error metrics with weights - corresponds to a weighted sum of errors (normally used in loss) - """ - - def __init__(self, metrics: List[Tuple[ErrorMetric, float]], **kwargs): - super().__init__(**kwargs) - self.metrics = metrics - assert kwargs['unit'] is None - - def update(self, output: AtomGraphData): - for metric, _ in self.metrics: - metric.update(output) - - def reset(self): - for metric, _ in self.metrics: - metric.reset() - - def ddp_reduce(self, device): # override - for metric, _ in self.metrics: - metric.value._ddp_reduce(device) - - def get(self): - val = 0.0 - for metric, weight in self.metrics: - val += metric.get() * weight - return val - - -class ErrorRecorder: - """ - record errors of a model - """ - - METRIC_DICT = { - 'RMSE': RMSError, - 'ComponentRMSE': ComponentRMSError, - 'MAE': MAError, - 'Loss': LossError, - } - - def __init__(self, metrics: List[ErrorMetric]): - self.history = [] - self.metrics = metrics - - def _update(self, output: AtomGraphData): - for metric in self.metrics: - metric.update(output) - - def update(self, output: AtomGraphData, no_grad=True): - if no_grad: - with torch.no_grad(): - self._update(output) - else: - self._update(output) - - def get_metric_dict(self, with_unit=True): - return {metric.key_str(with_unit): metric.get() for metric in self.metrics} - - def get_current(self): - dct = {} - for metric in self.metrics: - dct[metric.name] = { - 'value': metric.get(), - 'unit': metric.unit, - 'ref_key': metric.ref_key, - 'pred_key': metric.pred_key, - } - return dct - - def get_dct(self, prefix=''): - dct = {} - if prefix.endswith('_') is False and prefix != '': - prefix = prefix + '_' - for metric in self.metrics: - dct[f'{prefix}{metric.name}'] = f'{metric.get():6f}' - return dct - - def get_key_str(self, name: str): - for metric in self.metrics: - if name == metric.name: - return metric.key_str() - return None - - def epoch_forward(self): - self.history.append(self.get_current()) - pretty = self.get_metric_dict(with_unit=True) - for metric in self.metrics: - metric.reset() - return pretty # for print - - @staticmethod - def init_total_loss_metric( - config, - criteria: Optional[Callable] = None, - loss_functions: Optional[List[Tuple[LossDefinition, float]]] = None, - ): - if criteria is None and loss_functions is None: - raise ValueError('both criteria and loss functions not given') - - is_stress = config[KEY.IS_TRAIN_STRESS] - metrics = [] - if criteria is not None: - energy_metric = CustomError(criteria, **get_err_type('Energy')) - metrics.append((energy_metric, 1)) - force_metric = CustomError(criteria, **get_err_type('Force')) - metrics.append((force_metric, config[KEY.FORCE_WEIGHT])) - if is_stress: - stress_metric = CustomError(criteria, **get_err_type('Stress')) - metrics.append((stress_metric, config[KEY.STRESS_WEIGHT])) - else: # TODO: this is hard-coded - for efs in ['Energy', 'Force', 'Stress']: - if efs == 'Stress' and not is_stress: - continue - lf, w = _get_loss_function_from_name(loss_functions, efs) - if lf is None: - raise ValueError(f'{efs} not found from loss_functions') - metric = LossError(loss_def=lf, **get_err_type(efs)) - metrics.append((metric, w)) - - total_loss_metric = CombinedError( - metrics, name='TotalLoss', unit=None, ref_key=None, pred_key=None - ) - return total_loss_metric - - @staticmethod - def from_config(config: dict, loss_functions=None): - loss_cls = loss_dict[config.get(KEY.LOSS, 'mse').lower()] - loss_param = config.get(KEY.LOSS_PARAM, {}) - criteria = loss_cls(**loss_param) if loss_functions is None else None - - err_config = config.get(KEY.ERROR_RECORD, False) - if not err_config: - raise ValueError( - 'No error_record config found. Consider util.get_error_recorder' - ) - err_config_n = [] - if not config.get(KEY.IS_TRAIN_STRESS, True): - for err_type, metric_name in err_config: - if 'Stress' in err_type: - continue - err_config_n.append((err_type, metric_name)) - err_config = err_config_n - - err_metrics = [] - for err_type, metric_name in err_config: - metric_kwargs = get_err_type(err_type) - if err_type == 'TotalLoss': # special case - err_metrics.append( - ErrorRecorder.init_total_loss_metric( - config, criteria, loss_functions - ) - ) - continue - metric_cls = ErrorRecorder.METRIC_DICT[metric_name] - assert isinstance(metric_kwargs['name'], str) - if metric_name == 'Loss': - if loss_functions is not None: - metric_cls = LossError - metric_kwargs['loss_def'], _ = _get_loss_function_from_name( - loss_functions, metric_kwargs['name'] - ) - else: - metric_cls = CustomError - metric_kwargs['func'] = criteria - metric_kwargs.pop('unit', None) - metric_kwargs['name'] += f'_{metric_name}' - err_metrics.append(metric_cls(**metric_kwargs)) - return ErrorRecorder(err_metrics) +from copy import deepcopy +from typing import Any, Callable, Dict, List, Optional, Tuple + +import torch +import torch.distributed as dist + +import sevenn._keys as KEY +from sevenn.train.loss import LossDefinition + +from .atom_graph_data import AtomGraphData +from .train.optim import loss_dict + +_ERROR_TYPES = { + 'TotalEnergy': { + 'name': 'Energy', + 'ref_key': KEY.ENERGY, + 'pred_key': KEY.PRED_TOTAL_ENERGY, + 'unit': 'eV', + 'vdim': 1, + }, + 'Energy': { # by default per-atom for energy + 'name': 'Energy', + 'ref_key': KEY.ENERGY, + 'pred_key': KEY.PRED_TOTAL_ENERGY, + 'unit': 'eV/atom', + 'per_atom': True, + 'vdim': 1, + }, + 'Force': { + 'name': 'Force', + 'ref_key': KEY.FORCE, + 'pred_key': KEY.PRED_FORCE, + 'unit': 'eV/Å', + 'vdim': 3, + }, + 'Stress': { + 'name': 'Stress', + 'ref_key': KEY.STRESS, + 'pred_key': KEY.PRED_STRESS, + 'unit': 'kbar', + 'coeff': 1602.1766208, + 'vdim': 6, + }, + 'Stress_GPa': { + 'name': 'Stress', + 'ref_key': KEY.STRESS, + 'pred_key': KEY.PRED_STRESS, + 'unit': 'GPa', + 'coeff': 160.21766208, + 'vdim': 6, + }, + 'TotalLoss': { + 'name': 'TotalLoss', + 'unit': None, + }, +} + + +def get_err_type(name: str) -> Dict[str, Any]: + return deepcopy(_ERROR_TYPES[name]) + + +def _get_loss_function_from_name(loss_functions, name): + for loss_def, w in loss_functions: + if loss_def.name.lower() == name.lower(): + return loss_def, w + return None, None + + +class AverageNumber: + def __init__(self): + self._sum = 0.0 + self._count = 0 + + def update(self, values: torch.Tensor): + self._sum += values.sum().item() + self._count += values.numel() + + def _ddp_reduce(self, device): + _sum = torch.tensor(self._sum, device=device) + _count = torch.tensor(self._count, device=device) + dist.all_reduce(_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(_count, op=dist.ReduceOp.SUM) + self._sum = _sum.item() + self._count = _count.item() + + def get(self): + if self._count == 0: + return torch.nan + return self._sum / self._count + + +class ErrorMetric: + """ + Base class for error metrics We always average error by # of structures, + and designed to collect errors in the middle of iteration (by AverageNumber) + """ + + def __init__( + self, + name: str, + ref_key: str, + pred_key: str, + coeff: float = 1.0, + unit: Optional[str] = None, + per_atom: bool = False, + ignore_unlabeled: bool = True, + **kwargs, + ): + self.name = name + self.unit = unit + self.coeff = coeff + self.ref_key = ref_key + self.pred_key = pred_key + self.per_atom = per_atom + self.ignore_unlabeled = ignore_unlabeled + self.value = AverageNumber() + + def update(self, output: AtomGraphData): + raise NotImplementedError + + def _retrieve(self, output: AtomGraphData): + y_ref = output[self.ref_key] * self.coeff + y_pred = output[self.pred_key] * self.coeff + if self.per_atom: + assert y_ref.dim() == 1 and y_pred.dim() == 1 + natoms = output[KEY.NUM_ATOMS] + y_ref = y_ref / natoms + y_pred = y_pred / natoms + if self.ignore_unlabeled: + unlabelled_idx = torch.isnan(y_ref) + y_ref = y_ref[~unlabelled_idx] + y_pred = y_pred[~unlabelled_idx] + return y_ref, y_pred + + def ddp_reduce(self, device): + self.value._ddp_reduce(device) + + def reset(self): + self.value = AverageNumber() + + def get(self): + return self.value.get() + + def key_str(self, with_unit=True): + if self.unit is None or not with_unit: + return self.name + else: + return f'{self.name} ({self.unit})' + + def __str__(self): + return f'{self.key_str()}: {self.value.get():.6f}' + + +class RMSError(ErrorMetric): + """ + Vector squared error + """ + + def __init__(self, vdim: int = 1, **kwargs): + super().__init__(**kwargs) + self.vdim = vdim + self._se = torch.nn.MSELoss(reduction='none') + + def _square_error(self, y_ref, y_pred, vdim: int): + return self._se(y_ref.view(-1, vdim), y_pred.view(-1, vdim)).sum(dim=1) + + def update(self, output: AtomGraphData): + y_ref, y_pred = self._retrieve(output) + se = self._square_error(y_ref, y_pred, self.vdim) + self.value.update(se) + + def get(self): + return self.value.get() ** 0.5 + + +class ComponentRMSError(ErrorMetric): + """ + Ignore vector dim and just average over components + Results smaller error + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._se = torch.nn.MSELoss(reduction='none') + + def _square_error(self, y_ref, y_pred): + return self._se(y_ref, y_pred) + + def update(self, output: AtomGraphData): + y_ref, y_pred = self._retrieve(output) + y_ref = y_ref.view(-1) + y_pred = y_pred.view(-1) + se = self._square_error(y_ref, y_pred) + self.value.update(se) + + def get(self): + return self.value.get() ** 0.5 + + +class MAError(ErrorMetric): + """ + Average over all component + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def _square_error(self, y_ref, y_pred): + return torch.abs(y_ref - y_pred) + + def update(self, output: AtomGraphData): + y_ref, y_pred = self._retrieve(output) + y_ref = y_ref.reshape((-1,)) + y_pred = y_pred.reshape((-1,)) + se = self._square_error(y_ref, y_pred) + self.value.update(se) + + +class CustomError(ErrorMetric): + """ + Custom error metric + Args: + func: a function that takes y_ref and y_pred + and returns a list of errors + """ + + def __init__(self, func: Callable, **kwargs): + super().__init__(**kwargs) + self.func = func + + def update(self, output: AtomGraphData): + y_ref, y_pred = self._retrieve(output) + se = self.func(y_ref, y_pred) if len(y_ref) > 0 else torch.tensor([]) + self.value.update(se) + + +class LossError(ErrorMetric): + """ + Error metric that record loss + """ + + def __init__( + self, + name: str, + loss_def: LossDefinition, + **kwargs, + ): + super().__init__( + name, + ignore_unlabeld=loss_def.ignore_unlabeled, + **kwargs, + ) + self.loss_def = loss_def + + def update(self, output: AtomGraphData): + loss = self.loss_def.get_loss(output) # type: ignore + self.value.update(loss) # type: ignore + + +class CombinedError(ErrorMetric): + """ + Combine multiple error metrics with weights + corresponds to a weighted sum of errors (normally used in loss) + """ + + def __init__(self, metrics: List[Tuple[ErrorMetric, float]], **kwargs): + super().__init__(**kwargs) + self.metrics = metrics + assert kwargs['unit'] is None + + def update(self, output: AtomGraphData): + for metric, _ in self.metrics: + metric.update(output) + + def reset(self): + for metric, _ in self.metrics: + metric.reset() + + def ddp_reduce(self, device): # override + for metric, _ in self.metrics: + metric.value._ddp_reduce(device) + + def get(self): + val = 0.0 + for metric, weight in self.metrics: + val += metric.get() * weight + return val + + +class ErrorRecorder: + """ + record errors of a model + """ + + METRIC_DICT = { + 'RMSE': RMSError, + 'ComponentRMSE': ComponentRMSError, + 'MAE': MAError, + 'Loss': LossError, + } + + def __init__(self, metrics: List[ErrorMetric]): + self.history = [] + self.metrics = metrics + + def _update(self, output: AtomGraphData): + for metric in self.metrics: + metric.update(output) + + def update(self, output: AtomGraphData, no_grad=True): + if no_grad: + with torch.no_grad(): + self._update(output) + else: + self._update(output) + + def get_metric_dict(self, with_unit=True): + return {metric.key_str(with_unit): metric.get() for metric in self.metrics} + + def get_current(self): + dct = {} + for metric in self.metrics: + dct[metric.name] = { + 'value': metric.get(), + 'unit': metric.unit, + 'ref_key': metric.ref_key, + 'pred_key': metric.pred_key, + } + return dct + + def get_dct(self, prefix=''): + dct = {} + if prefix.endswith('_') is False and prefix != '': + prefix = prefix + '_' + for metric in self.metrics: + dct[f'{prefix}{metric.name}'] = f'{metric.get():6f}' + return dct + + def get_key_str(self, name: str): + for metric in self.metrics: + if name == metric.name: + return metric.key_str() + return None + + def epoch_forward(self): + self.history.append(self.get_current()) + pretty = self.get_metric_dict(with_unit=True) + for metric in self.metrics: + metric.reset() + return pretty # for print + + @staticmethod + def init_total_loss_metric( + config, + criteria: Optional[Callable] = None, + loss_functions: Optional[List[Tuple[LossDefinition, float]]] = None, + ): + if criteria is None and loss_functions is None: + raise ValueError('both criteria and loss functions not given') + + is_stress = config[KEY.IS_TRAIN_STRESS] + metrics = [] + if criteria is not None: + energy_metric = CustomError(criteria, **get_err_type('Energy')) + metrics.append((energy_metric, 1)) + force_metric = CustomError(criteria, **get_err_type('Force')) + metrics.append((force_metric, config[KEY.FORCE_WEIGHT])) + if is_stress: + stress_metric = CustomError(criteria, **get_err_type('Stress')) + metrics.append((stress_metric, config[KEY.STRESS_WEIGHT])) + else: # TODO: this is hard-coded + for efs in ['Energy', 'Force', 'Stress']: + if efs == 'Stress' and not is_stress: + continue + lf, w = _get_loss_function_from_name(loss_functions, efs) + if lf is None: + raise ValueError(f'{efs} not found from loss_functions') + metric = LossError(loss_def=lf, **get_err_type(efs)) + metrics.append((metric, w)) + + total_loss_metric = CombinedError( + metrics, name='TotalLoss', unit=None, ref_key=None, pred_key=None + ) + return total_loss_metric + + @staticmethod + def from_config(config: dict, loss_functions=None): + loss_cls = loss_dict[config.get(KEY.LOSS, 'mse').lower()] + loss_param = config.get(KEY.LOSS_PARAM, {}) + criteria = loss_cls(**loss_param) if loss_functions is None else None + + err_config = config.get(KEY.ERROR_RECORD, False) + if not err_config: + raise ValueError( + 'No error_record config found. Consider util.get_error_recorder' + ) + err_config_n = [] + if not config.get(KEY.IS_TRAIN_STRESS, True): + for err_type, metric_name in err_config: + if 'Stress' in err_type: + continue + err_config_n.append((err_type, metric_name)) + err_config = err_config_n + + err_metrics = [] + for err_type, metric_name in err_config: + metric_kwargs = get_err_type(err_type) + if err_type == 'TotalLoss': # special case + err_metrics.append( + ErrorRecorder.init_total_loss_metric( + config, criteria, loss_functions + ) + ) + continue + metric_cls = ErrorRecorder.METRIC_DICT[metric_name] + assert isinstance(metric_kwargs['name'], str) + if metric_name == 'Loss': + if loss_functions is not None: + metric_cls = LossError + metric_kwargs['loss_def'], _ = _get_loss_function_from_name( + loss_functions, metric_kwargs['name'] + ) + else: + metric_cls = CustomError + metric_kwargs['func'] = criteria + metric_kwargs.pop('unit', None) + metric_kwargs['name'] += f'_{metric_name}' + err_metrics.append(metric_cls(**metric_kwargs)) + return ErrorRecorder(err_metrics) diff --git a/mace-bench/3rdparty/SevenNet/sevenn/logger.py b/mace-bench/3rdparty/SevenNet/sevenn/logger.py index 1db11c0..1c66af8 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/logger.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/logger.py @@ -1,336 +1,336 @@ -import os -import time -import traceback -from datetime import datetime -from typing import Any, Dict, List, Optional - -from ase.data import atomic_numbers - -import sevenn._keys as KEY -from sevenn import __version__ - -CHEM_SYMBOLS = {v: k for k, v in atomic_numbers.items()} - - -class Singleton(type): - _instances = {} - - def __call__(cls, *args, **kwargs): - if cls not in cls._instances: - cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) - return cls._instances[cls] - - -class Logger(metaclass=Singleton): - SCREEN_WIDTH = 120 # half size of my screen / changed due to stress output - - def __init__( - self, filename: Optional[str] = None, screen: bool = False, rank: int = 0 - ): - self.rank = rank - self._filename = filename - if rank == 0: - # if filename is not None: - # self.logfile = open(filename, 'a', buffering=1) - self.logfile = None - self.files = {} - self.screen = screen - else: - self.logfile = None - self.screen = False - self.timer_dct = {} - self.active = True - - def __enter__(self): - if self.rank != 0: - return self - if self.logfile is None and self._filename is not None: - try: - self.logfile = open( - self._filename, 'a', buffering=1, encoding='utf-8' - ) - except IOError as e: - print(f'Failed to re-open log file {self._filename}: {e}') - self.logfile = None - self.files = {} - return self - - def __exit__(self, exc_type, exc_value, traceback): - if self.rank != 0: - return self - try: - if self.logfile is not None: - self.logfile.close() - self.logfile = None - for f in self.files.values(): - f.close() - except IOError as e: - print(f'Failed to close log files: {e}') - finally: - self.logfile = None - self.files = {} - - def switch_file(self, new_filename: str): - if self.rank != 0: - return self - if self.logfile is not None: - raise ValueError('Current logfile is not yet closed') - self._filename = new_filename - return self - - def write(self, content: str): - if self.rank != 0: - return - # no newline! - if self.logfile is not None and self.active: - self.logfile.write(content) - if self.screen and self.active: - print(content, end='') - - def writeline(self, content: str): - content = content + '\n' - self.write(content) - - def init_csv(self, filename: str, header: list): - """ - Deprecated - """ - if self.rank == 0: - self.files[filename] = open(filename, 'w', buffering=1, encoding='utf-8') - self.files[filename].write(','.join(header) + '\n') - else: - pass - - def append_csv(self, filename: str, content: list, decimal: int = 6): - """ - Deprecated - """ - if self.rank == 0: - if filename not in self.files: - self.files[filename] = open(filename, 'a', buffering=1) - str_content = [] - for c in content: - if isinstance(c, float): - str_content.append(f'{c:.{decimal}f}') - else: - str_content.append(str(c)) - self.files[filename].write(','.join(str_content) + '\n') - else: - pass - - def natoms_write(self, natoms: Dict[str, Dict]): - content = '' - total_natom = {} - for label, natom in natoms.items(): - content += self.format_k_v(label, natom) - for specie, num in natom.items(): - try: - total_natom[specie] += num - except KeyError: - total_natom[specie] = num - content += self.format_k_v('Total, label wise', total_natom) - content += self.format_k_v('Total', sum(total_natom.values())) - self.write(content) - - def statistic_write(self, statistic: Dict[str, Dict]): - content = '' - for label, dct in statistic.items(): - if label.startswith('_'): - continue - if not isinstance(dct, dict): - continue - dct_new = {} - for k, v in dct.items(): - if k.startswith('_'): - continue - if isinstance(v, int): - dct_new[k] = v - else: - dct_new[k] = f'{v:.3f}' - content += self.format_k_v(label, dct_new) - self.write(content) - - # TODO : refactoring!!!, this is not loss, rmse - def epoch_write_specie_wise_loss(self, train_loss, valid_loss): - lb_pad = 21 - fs = 6 - pad = 21 - fs - ln = '-' * fs - total_atom_type = train_loss.keys() - content = '' - - for at in total_atom_type: - t_F = train_loss[at] - v_F = valid_loss[at] - at_sym = CHEM_SYMBOLS[at] - content += '{label:{lb_pad}}{t_E:<{pad}.{fs}s}{v_E:<{pad}.{fs}s}'.format( - label=at_sym, t_E=ln, v_E=ln, lb_pad=lb_pad, pad=pad, fs=fs - ) + '{t_F:<{pad}.{fs}f}{v_F:<{pad}.{fs}f}'.format( - t_F=t_F, v_F=v_F, pad=pad, fs=fs - ) - content += '{t_S:<{pad}.{fs}s}{v_S:<{pad}.{fs}s}'.format( - t_S=ln, v_S=ln, pad=pad, fs=fs - ) - content += '\n' - self.write(content) - - def write_full_table( - self, - dict_list: List[Dict], - row_labels: List[str], - decimal_places: int = 6, - pad: int = 2, - ): - """ - Assume data_list is list of dict with same keys - """ - assert len(dict_list) == len(row_labels) - label_len = max(map(len, row_labels)) - # Extract the column names and create a 2D array of values - col_names = list(dict_list[0].keys()) - - values = [list(d.values()) for d in dict_list] - - # Format the numbers with the given decimal places - formatted_values = [ - [f'{value:.{decimal_places}f}' for value in row] for row in values - ] - - # Calculate padding lengths for each column (with extra padding) - max_col_lengths = [ - max(len(str(value)) for value in col) + pad - for col in zip(col_names, *formatted_values) - ] - - # Create header row and separator - header = ' ' * (label_len + pad) + ' '.join( - col_name.ljust(pad) for col_name, pad in zip(col_names, max_col_lengths) - ) - separator = '-'.join('-' * pad for pad in max_col_lengths) + '-' * ( - label_len + pad - ) - - # Print header and separator - self.writeline(header) - self.writeline(separator) - - # Print the data rows with row labels - for row_label, row in zip(row_labels, formatted_values): - data_row = ' '.join( - value.rjust(pad) for value, pad in zip(row, max_col_lengths) - ) - self.writeline(f'{row_label.ljust(label_len)}{data_row}') - - def format_k_v(self, key: Any, val: Any, write: bool = False): - """ - key and val should be str convertible - """ - MAX_KEY_SIZE = 20 - SEPARATOR = ', ' - EMPTY_PADDING = ' ' * (MAX_KEY_SIZE + 3) - NEW_LINE_LEN = Logger.SCREEN_WIDTH - 5 - key = str(key) - val = str(val) - content = f'{key:<{MAX_KEY_SIZE}}: {val}' - if len(content) > NEW_LINE_LEN: - content = f'{key:<{MAX_KEY_SIZE}}: ' - # septate val by separator - val_list = val.split(SEPARATOR) - current_len = len(content) - for val_compo in val_list: - current_len += len(val_compo) - if current_len > NEW_LINE_LEN: - newline_content = f'{EMPTY_PADDING}{val_compo}{SEPARATOR}' - content += f'\\\n{newline_content}' - current_len = len(newline_content) - else: - content += f'{val_compo}{SEPARATOR}' - - if content.endswith(f'{SEPARATOR}'): - content = content[: -len(SEPARATOR)] - content += '\n' - - if write is False: - return content - else: - self.write(content) - return '' - - def greeting(self): - LOGO_ASCII_FILE = f'{os.path.dirname(__file__)}/logo_ascii' - with open(LOGO_ASCII_FILE, 'r') as logo_f: - logo_ascii = logo_f.read() - content = 'SevenNet: Scalable EquiVariance-Enabled Neural Network\n' - content += f'version {__version__}, {time.ctime()}\n' - self.write(content) - self.write(logo_ascii) - - def bar(self): - content = '-' * Logger.SCREEN_WIDTH + '\n' - self.write(content) - - def print_config( - self, - model_config: Dict[str, Any], - data_config: Dict[str, Any], - train_config: Dict[str, Any], - ): - """ - print some important information from config - """ - content = 'successfully read yaml config!\n\n' + 'from model configuration\n' - for k, v in model_config.items(): - content += self.format_k_v(k, str(v)) - content += '\nfrom train configuration\n' - for k, v in train_config.items(): - content += self.format_k_v(k, str(v)) - content += '\nfrom data configuration\n' - for k, v in data_config.items(): - content += self.format_k_v(k, str(v)) - self.write(content) - - # TODO: This is not good make own exception - def error(self, e: Exception): - content = '' - if type(e) is ValueError: - content += 'Error occurred!\n' - content += str(e) + '\n' - else: - content += 'Unknown error occurred!\n' - content += traceback.format_exc() - self.write(content) - - def timer_start(self, name: str): - self.timer_dct[name] = datetime.now() - - def timer_end(self, name: str, message: str, remove: bool = True): - """ - print f"{message}: {elapsed}" - """ - elapsed = str(datetime.now() - self.timer_dct[name]) - # elapsed = elapsed.strftime('%H-%M-%S') - if remove: - del self.timer_dct[name] - self.write(f'{message}: {elapsed[:-4]}\n') - - # TODO: print it without config - # TODO: refactoring, readout part name :( - def print_model_info(self, model, config): - from functools import partial - - kv_write = partial(self.format_k_v, write=True) - self.writeline('Irreps of features') - kv_write('edge_feature', model.get_irreps_in('edge_embedding', 'irreps_out')) - for i in range(config[KEY.NUM_CONVOLUTION]): - kv_write( - f'{i}th node', - model.get_irreps_in(f'{i}_self_interaction_1'), - ) - i = config[KEY.NUM_CONVOLUTION] - 1 - kv_write( - 'readout irreps', - model.get_irreps_in(f'{i}_equivariant_gate', 'irreps_out'), - ) - - num_weights = sum(p.numel() for p in model.parameters() if p.requires_grad) - self.writeline(f'# learnable parameters: {num_weights}\n') +import os +import time +import traceback +from datetime import datetime +from typing import Any, Dict, List, Optional + +from ase.data import atomic_numbers + +import sevenn._keys as KEY +from sevenn import __version__ + +CHEM_SYMBOLS = {v: k for k, v in atomic_numbers.items()} + + +class Singleton(type): + _instances = {} + + def __call__(cls, *args, **kwargs): + if cls not in cls._instances: + cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) + return cls._instances[cls] + + +class Logger(metaclass=Singleton): + SCREEN_WIDTH = 120 # half size of my screen / changed due to stress output + + def __init__( + self, filename: Optional[str] = None, screen: bool = False, rank: int = 0 + ): + self.rank = rank + self._filename = filename + if rank == 0: + # if filename is not None: + # self.logfile = open(filename, 'a', buffering=1) + self.logfile = None + self.files = {} + self.screen = screen + else: + self.logfile = None + self.screen = False + self.timer_dct = {} + self.active = True + + def __enter__(self): + if self.rank != 0: + return self + if self.logfile is None and self._filename is not None: + try: + self.logfile = open( + self._filename, 'a', buffering=1, encoding='utf-8' + ) + except IOError as e: + print(f'Failed to re-open log file {self._filename}: {e}') + self.logfile = None + self.files = {} + return self + + def __exit__(self, exc_type, exc_value, traceback): + if self.rank != 0: + return self + try: + if self.logfile is not None: + self.logfile.close() + self.logfile = None + for f in self.files.values(): + f.close() + except IOError as e: + print(f'Failed to close log files: {e}') + finally: + self.logfile = None + self.files = {} + + def switch_file(self, new_filename: str): + if self.rank != 0: + return self + if self.logfile is not None: + raise ValueError('Current logfile is not yet closed') + self._filename = new_filename + return self + + def write(self, content: str): + if self.rank != 0: + return + # no newline! + if self.logfile is not None and self.active: + self.logfile.write(content) + if self.screen and self.active: + print(content, end='') + + def writeline(self, content: str): + content = content + '\n' + self.write(content) + + def init_csv(self, filename: str, header: list): + """ + Deprecated + """ + if self.rank == 0: + self.files[filename] = open(filename, 'w', buffering=1, encoding='utf-8') + self.files[filename].write(','.join(header) + '\n') + else: + pass + + def append_csv(self, filename: str, content: list, decimal: int = 6): + """ + Deprecated + """ + if self.rank == 0: + if filename not in self.files: + self.files[filename] = open(filename, 'a', buffering=1) + str_content = [] + for c in content: + if isinstance(c, float): + str_content.append(f'{c:.{decimal}f}') + else: + str_content.append(str(c)) + self.files[filename].write(','.join(str_content) + '\n') + else: + pass + + def natoms_write(self, natoms: Dict[str, Dict]): + content = '' + total_natom = {} + for label, natom in natoms.items(): + content += self.format_k_v(label, natom) + for specie, num in natom.items(): + try: + total_natom[specie] += num + except KeyError: + total_natom[specie] = num + content += self.format_k_v('Total, label wise', total_natom) + content += self.format_k_v('Total', sum(total_natom.values())) + self.write(content) + + def statistic_write(self, statistic: Dict[str, Dict]): + content = '' + for label, dct in statistic.items(): + if label.startswith('_'): + continue + if not isinstance(dct, dict): + continue + dct_new = {} + for k, v in dct.items(): + if k.startswith('_'): + continue + if isinstance(v, int): + dct_new[k] = v + else: + dct_new[k] = f'{v:.3f}' + content += self.format_k_v(label, dct_new) + self.write(content) + + # TODO : refactoring!!!, this is not loss, rmse + def epoch_write_specie_wise_loss(self, train_loss, valid_loss): + lb_pad = 21 + fs = 6 + pad = 21 - fs + ln = '-' * fs + total_atom_type = train_loss.keys() + content = '' + + for at in total_atom_type: + t_F = train_loss[at] + v_F = valid_loss[at] + at_sym = CHEM_SYMBOLS[at] + content += '{label:{lb_pad}}{t_E:<{pad}.{fs}s}{v_E:<{pad}.{fs}s}'.format( + label=at_sym, t_E=ln, v_E=ln, lb_pad=lb_pad, pad=pad, fs=fs + ) + '{t_F:<{pad}.{fs}f}{v_F:<{pad}.{fs}f}'.format( + t_F=t_F, v_F=v_F, pad=pad, fs=fs + ) + content += '{t_S:<{pad}.{fs}s}{v_S:<{pad}.{fs}s}'.format( + t_S=ln, v_S=ln, pad=pad, fs=fs + ) + content += '\n' + self.write(content) + + def write_full_table( + self, + dict_list: List[Dict], + row_labels: List[str], + decimal_places: int = 6, + pad: int = 2, + ): + """ + Assume data_list is list of dict with same keys + """ + assert len(dict_list) == len(row_labels) + label_len = max(map(len, row_labels)) + # Extract the column names and create a 2D array of values + col_names = list(dict_list[0].keys()) + + values = [list(d.values()) for d in dict_list] + + # Format the numbers with the given decimal places + formatted_values = [ + [f'{value:.{decimal_places}f}' for value in row] for row in values + ] + + # Calculate padding lengths for each column (with extra padding) + max_col_lengths = [ + max(len(str(value)) for value in col) + pad + for col in zip(col_names, *formatted_values) + ] + + # Create header row and separator + header = ' ' * (label_len + pad) + ' '.join( + col_name.ljust(pad) for col_name, pad in zip(col_names, max_col_lengths) + ) + separator = '-'.join('-' * pad for pad in max_col_lengths) + '-' * ( + label_len + pad + ) + + # Print header and separator + self.writeline(header) + self.writeline(separator) + + # Print the data rows with row labels + for row_label, row in zip(row_labels, formatted_values): + data_row = ' '.join( + value.rjust(pad) for value, pad in zip(row, max_col_lengths) + ) + self.writeline(f'{row_label.ljust(label_len)}{data_row}') + + def format_k_v(self, key: Any, val: Any, write: bool = False): + """ + key and val should be str convertible + """ + MAX_KEY_SIZE = 20 + SEPARATOR = ', ' + EMPTY_PADDING = ' ' * (MAX_KEY_SIZE + 3) + NEW_LINE_LEN = Logger.SCREEN_WIDTH - 5 + key = str(key) + val = str(val) + content = f'{key:<{MAX_KEY_SIZE}}: {val}' + if len(content) > NEW_LINE_LEN: + content = f'{key:<{MAX_KEY_SIZE}}: ' + # septate val by separator + val_list = val.split(SEPARATOR) + current_len = len(content) + for val_compo in val_list: + current_len += len(val_compo) + if current_len > NEW_LINE_LEN: + newline_content = f'{EMPTY_PADDING}{val_compo}{SEPARATOR}' + content += f'\\\n{newline_content}' + current_len = len(newline_content) + else: + content += f'{val_compo}{SEPARATOR}' + + if content.endswith(f'{SEPARATOR}'): + content = content[: -len(SEPARATOR)] + content += '\n' + + if write is False: + return content + else: + self.write(content) + return '' + + def greeting(self): + LOGO_ASCII_FILE = f'{os.path.dirname(__file__)}/logo_ascii' + with open(LOGO_ASCII_FILE, 'r') as logo_f: + logo_ascii = logo_f.read() + content = 'SevenNet: Scalable EquiVariance-Enabled Neural Network\n' + content += f'version {__version__}, {time.ctime()}\n' + self.write(content) + self.write(logo_ascii) + + def bar(self): + content = '-' * Logger.SCREEN_WIDTH + '\n' + self.write(content) + + def print_config( + self, + model_config: Dict[str, Any], + data_config: Dict[str, Any], + train_config: Dict[str, Any], + ): + """ + print some important information from config + """ + content = 'successfully read yaml config!\n\n' + 'from model configuration\n' + for k, v in model_config.items(): + content += self.format_k_v(k, str(v)) + content += '\nfrom train configuration\n' + for k, v in train_config.items(): + content += self.format_k_v(k, str(v)) + content += '\nfrom data configuration\n' + for k, v in data_config.items(): + content += self.format_k_v(k, str(v)) + self.write(content) + + # TODO: This is not good make own exception + def error(self, e: Exception): + content = '' + if type(e) is ValueError: + content += 'Error occurred!\n' + content += str(e) + '\n' + else: + content += 'Unknown error occurred!\n' + content += traceback.format_exc() + self.write(content) + + def timer_start(self, name: str): + self.timer_dct[name] = datetime.now() + + def timer_end(self, name: str, message: str, remove: bool = True): + """ + print f"{message}: {elapsed}" + """ + elapsed = str(datetime.now() - self.timer_dct[name]) + # elapsed = elapsed.strftime('%H-%M-%S') + if remove: + del self.timer_dct[name] + self.write(f'{message}: {elapsed[:-4]}\n') + + # TODO: print it without config + # TODO: refactoring, readout part name :( + def print_model_info(self, model, config): + from functools import partial + + kv_write = partial(self.format_k_v, write=True) + self.writeline('Irreps of features') + kv_write('edge_feature', model.get_irreps_in('edge_embedding', 'irreps_out')) + for i in range(config[KEY.NUM_CONVOLUTION]): + kv_write( + f'{i}th node', + model.get_irreps_in(f'{i}_self_interaction_1'), + ) + i = config[KEY.NUM_CONVOLUTION] - 1 + kv_write( + 'readout irreps', + model.get_irreps_in(f'{i}_equivariant_gate', 'irreps_out'), + ) + + num_weights = sum(p.numel() for p in model.parameters() if p.requires_grad) + self.writeline(f'# learnable parameters: {num_weights}\n') diff --git a/mace-bench/3rdparty/SevenNet/sevenn/main/sevenn.py b/mace-bench/3rdparty/SevenNet/sevenn/main/sevenn.py index a944add..d96a1bf 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/main/sevenn.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/main/sevenn.py @@ -1,248 +1,248 @@ -import argparse -import os -import sys -import time - -from sevenn import __version__ - -description = 'train a model given the input.yaml' - -input_yaml_help = 'input.yaml for training' -mode_help = 'main training script to run. Default is train.' -working_dir_help = 'path to write output. Default is cwd.' -screen_help = 'print log to stdout' -distributed_help = 'set this flag if it is distributed training' -distributed_backend_help = 'backend for distributed training. Supported: nccl, mpi' - -# Metainfo will be saved to checkpoint -global_config = { - 'version': __version__, - 'when': time.ctime(), - '_model_type': 'E3_equivariant_model', -} - - -def run(args): - """ - main function of sevenn - """ - import random - import sys - - import torch - import torch.distributed as dist - - import sevenn._keys as KEY - from sevenn.logger import Logger - from sevenn.parse_input import read_config_yaml - from sevenn.scripts.train import train, train_v2 - from sevenn.util import unique_filepath - - input_yaml = args.input_yaml - mode = args.mode - working_dir = args.working_dir - log = args.log - screen = args.screen - distributed = args.distributed - distributed_backend = args.distributed_backend - use_cue = args.enable_cueq - - if use_cue: - import sevenn.nn.cue_helper - - if not sevenn.nn.cue_helper.is_cue_available(): - raise ImportError('cuEquivariance not installed.') - - if working_dir is None: - working_dir = os.getcwd() - elif not os.path.isdir(working_dir): - os.makedirs(working_dir, exist_ok=True) - - world_size = 1 - if distributed: - if distributed_backend == 'nccl': - local_rank = int(os.environ['LOCAL_RANK']) - rank = int(os.environ['RANK']) - world_size = int(os.environ['WORLD_SIZE']) - elif distributed_backend == 'mpi': - local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) - rank = int(os.environ['OMPI_COMM_WORLD_RANK']) - world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) - else: - raise ValueError(f'Unknown distributed backend: {distributed_backend}') - - dist.init_process_group( - backend=distributed_backend, world_size=world_size, rank=rank - ) - else: - local_rank, rank, world_size = 0, 0, 1 - - log_fname = unique_filepath(f'{os.path.abspath(working_dir)}/{log}') - with Logger(filename=log_fname, screen=screen, rank=rank) as logger: - logger.greeting() - - if distributed: - logger.writeline( - f'Distributed training enabled, total world size is {world_size}' - ) - - try: - model_config, train_config, data_config = read_config_yaml( - input_yaml, return_separately=True - ) - except Exception as e: - logger.writeline('Failed to parsing input.yaml') - logger.error(e) - sys.exit(1) - - train_config[KEY.IS_DDP] = distributed - train_config[KEY.DDP_BACKEND] = distributed_backend - train_config[KEY.LOCAL_RANK] = local_rank - train_config[KEY.RANK] = rank - train_config[KEY.WORLD_SIZE] = world_size - - if distributed: - torch.cuda.set_device(torch.device('cuda', local_rank)) - - if use_cue: - if KEY.CUEQUIVARIANCE_CONFIG not in model_config: - model_config[KEY.CUEQUIVARIANCE_CONFIG] = {'use': True} - else: - model_config[KEY.CUEQUIVARIANCE_CONFIG].update({'use': True}) - - logger.print_config(model_config, data_config, train_config) - # don't have to distinguish configs inside program - global_config.update(model_config) - global_config.update(train_config) - global_config.update(data_config) - - # Not implemented - if global_config[KEY.DTYPE] == 'double': - raise Exception('double precision is not implemented yet') - # torch.set_default_dtype(torch.double) - - seed = global_config[KEY.RANDOM_SEED] - random.seed(seed) - torch.manual_seed(seed) - - # run train - if mode == 'train_v1': - train(global_config, working_dir) - elif mode == 'train_v2': - train_v2(global_config, working_dir) - - -def cmd_parser_train(parser): - ag = parser - ag.add_argument('input_yaml', help=input_yaml_help, type=str) - ag.add_argument( - '-m', - '--mode', - choices=['train_v1', 'train_v2'], - default='train_v2', - help=mode_help, - type=str, - ) - ag.add_argument( - '-cueq', - '--enable_cueq', - help='(Not stable!) use cuEquivariance for training', - action='store_true', - ) - ag.add_argument( - '-w', - '--working_dir', - nargs='?', - const=os.getcwd(), - help=working_dir_help, - type=str, - ) - ag.add_argument( - '-l', - '--log', - default='log.sevenn', - help='name of logfile, default is log.sevenn', - type=str, - ) - ag.add_argument('-s', '--screen', help=screen_help, action='store_true') - ag.add_argument( - '-d', '--distributed', help=distributed_help, action='store_true' - ) - ag.add_argument( - '--distributed_backend', - help=distributed_backend_help, - type=str, - default='nccl', - choices=['nccl', 'mpi'], - ) - - -def add_parser(subparsers): - ag = subparsers.add_parser('train', help=description) - cmd_parser_train(ag) - - -def set_default_subparser(self, name, args=None, positional_args=0): - """default subparser selection. Call after setup, just before parse_args() - name: is the name of the subparser to call by default - args: if set is the argument list handed to parse_args() - - Hack copied from stack overflow - """ - subparser_found = False - for arg in sys.argv[1:]: - if arg in ['-h', '--help']: # global help if no subparser - break - else: - for x in self._subparsers._actions: - if not isinstance(x, argparse._SubParsersAction): - continue - for sp_name in x._name_parser_map.keys(): - if sp_name in sys.argv[1:]: - subparser_found = True - if not subparser_found: - # insert default in last position before global positional - # arguments, this implies no global options are specified after - # first positional argument - if args is None: - sys.argv.insert(len(sys.argv) - positional_args, name) - else: - args.insert(len(args) - positional_args, name) - - -argparse.ArgumentParser.set_default_subparser = set_default_subparser # type: ignore - - -def main(): - import sevenn.main.sevenn_cp as checkpoint_cmd - import sevenn.main.sevenn_get_model as get_model_cmd - import sevenn.main.sevenn_graph_build as graph_build_cmd - import sevenn.main.sevenn_inference as inference_cmd - import sevenn.main.sevenn_patch_lammps as patch_lammps_cmd - import sevenn.main.sevenn_preset as preset_cmd - - ag = argparse.ArgumentParser(f'SevenNet version={__version__}') - - subparsers = ag.add_subparsers(dest='command', help='Sub-commands') - add_parser(subparsers) # add 'train' - checkpoint_cmd.add_parser(subparsers) - inference_cmd.add_parser(subparsers) - graph_build_cmd.add_parser(subparsers) - preset_cmd.add_parser(subparsers) - get_model_cmd.add_parser(subparsers) - patch_lammps_cmd.add_parser(subparsers) - - ag.set_default_subparser('train') # type: ignore - args = ag.parse_args() - - if args.command is None: # backward compatibility - args.command = 'train' - - if args.command == 'train': - run(args) - elif args.command == 'preset': - preset_cmd.run(args) - - -if __name__ == '__main__': - main() +import argparse +import os +import sys +import time + +from sevenn import __version__ + +description = 'train a model given the input.yaml' + +input_yaml_help = 'input.yaml for training' +mode_help = 'main training script to run. Default is train.' +working_dir_help = 'path to write output. Default is cwd.' +screen_help = 'print log to stdout' +distributed_help = 'set this flag if it is distributed training' +distributed_backend_help = 'backend for distributed training. Supported: nccl, mpi' + +# Metainfo will be saved to checkpoint +global_config = { + 'version': __version__, + 'when': time.ctime(), + '_model_type': 'E3_equivariant_model', +} + + +def run(args): + """ + main function of sevenn + """ + import random + import sys + + import torch + import torch.distributed as dist + + import sevenn._keys as KEY + from sevenn.logger import Logger + from sevenn.parse_input import read_config_yaml + from sevenn.scripts.train import train, train_v2 + from sevenn.util import unique_filepath + + input_yaml = args.input_yaml + mode = args.mode + working_dir = args.working_dir + log = args.log + screen = args.screen + distributed = args.distributed + distributed_backend = args.distributed_backend + use_cue = args.enable_cueq + + if use_cue: + import sevenn.nn.cue_helper + + if not sevenn.nn.cue_helper.is_cue_available(): + raise ImportError('cuEquivariance not installed.') + + if working_dir is None: + working_dir = os.getcwd() + elif not os.path.isdir(working_dir): + os.makedirs(working_dir, exist_ok=True) + + world_size = 1 + if distributed: + if distributed_backend == 'nccl': + local_rank = int(os.environ['LOCAL_RANK']) + rank = int(os.environ['RANK']) + world_size = int(os.environ['WORLD_SIZE']) + elif distributed_backend == 'mpi': + local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) + rank = int(os.environ['OMPI_COMM_WORLD_RANK']) + world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) + else: + raise ValueError(f'Unknown distributed backend: {distributed_backend}') + + dist.init_process_group( + backend=distributed_backend, world_size=world_size, rank=rank + ) + else: + local_rank, rank, world_size = 0, 0, 1 + + log_fname = unique_filepath(f'{os.path.abspath(working_dir)}/{log}') + with Logger(filename=log_fname, screen=screen, rank=rank) as logger: + logger.greeting() + + if distributed: + logger.writeline( + f'Distributed training enabled, total world size is {world_size}' + ) + + try: + model_config, train_config, data_config = read_config_yaml( + input_yaml, return_separately=True + ) + except Exception as e: + logger.writeline('Failed to parsing input.yaml') + logger.error(e) + sys.exit(1) + + train_config[KEY.IS_DDP] = distributed + train_config[KEY.DDP_BACKEND] = distributed_backend + train_config[KEY.LOCAL_RANK] = local_rank + train_config[KEY.RANK] = rank + train_config[KEY.WORLD_SIZE] = world_size + + if distributed: + torch.cuda.set_device(torch.device('cuda', local_rank)) + + if use_cue: + if KEY.CUEQUIVARIANCE_CONFIG not in model_config: + model_config[KEY.CUEQUIVARIANCE_CONFIG] = {'use': True} + else: + model_config[KEY.CUEQUIVARIANCE_CONFIG].update({'use': True}) + + logger.print_config(model_config, data_config, train_config) + # don't have to distinguish configs inside program + global_config.update(model_config) + global_config.update(train_config) + global_config.update(data_config) + + # Not implemented + if global_config[KEY.DTYPE] == 'double': + raise Exception('double precision is not implemented yet') + # torch.set_default_dtype(torch.double) + + seed = global_config[KEY.RANDOM_SEED] + random.seed(seed) + torch.manual_seed(seed) + + # run train + if mode == 'train_v1': + train(global_config, working_dir) + elif mode == 'train_v2': + train_v2(global_config, working_dir) + + +def cmd_parser_train(parser): + ag = parser + ag.add_argument('input_yaml', help=input_yaml_help, type=str) + ag.add_argument( + '-m', + '--mode', + choices=['train_v1', 'train_v2'], + default='train_v2', + help=mode_help, + type=str, + ) + ag.add_argument( + '-cueq', + '--enable_cueq', + help='(Not stable!) use cuEquivariance for training', + action='store_true', + ) + ag.add_argument( + '-w', + '--working_dir', + nargs='?', + const=os.getcwd(), + help=working_dir_help, + type=str, + ) + ag.add_argument( + '-l', + '--log', + default='log.sevenn', + help='name of logfile, default is log.sevenn', + type=str, + ) + ag.add_argument('-s', '--screen', help=screen_help, action='store_true') + ag.add_argument( + '-d', '--distributed', help=distributed_help, action='store_true' + ) + ag.add_argument( + '--distributed_backend', + help=distributed_backend_help, + type=str, + default='nccl', + choices=['nccl', 'mpi'], + ) + + +def add_parser(subparsers): + ag = subparsers.add_parser('train', help=description) + cmd_parser_train(ag) + + +def set_default_subparser(self, name, args=None, positional_args=0): + """default subparser selection. Call after setup, just before parse_args() + name: is the name of the subparser to call by default + args: if set is the argument list handed to parse_args() + + Hack copied from stack overflow + """ + subparser_found = False + for arg in sys.argv[1:]: + if arg in ['-h', '--help']: # global help if no subparser + break + else: + for x in self._subparsers._actions: + if not isinstance(x, argparse._SubParsersAction): + continue + for sp_name in x._name_parser_map.keys(): + if sp_name in sys.argv[1:]: + subparser_found = True + if not subparser_found: + # insert default in last position before global positional + # arguments, this implies no global options are specified after + # first positional argument + if args is None: + sys.argv.insert(len(sys.argv) - positional_args, name) + else: + args.insert(len(args) - positional_args, name) + + +argparse.ArgumentParser.set_default_subparser = set_default_subparser # type: ignore + + +def main(): + import sevenn.main.sevenn_cp as checkpoint_cmd + import sevenn.main.sevenn_get_model as get_model_cmd + import sevenn.main.sevenn_graph_build as graph_build_cmd + import sevenn.main.sevenn_inference as inference_cmd + import sevenn.main.sevenn_patch_lammps as patch_lammps_cmd + import sevenn.main.sevenn_preset as preset_cmd + + ag = argparse.ArgumentParser(f'SevenNet version={__version__}') + + subparsers = ag.add_subparsers(dest='command', help='Sub-commands') + add_parser(subparsers) # add 'train' + checkpoint_cmd.add_parser(subparsers) + inference_cmd.add_parser(subparsers) + graph_build_cmd.add_parser(subparsers) + preset_cmd.add_parser(subparsers) + get_model_cmd.add_parser(subparsers) + patch_lammps_cmd.add_parser(subparsers) + + ag.set_default_subparser('train') # type: ignore + args = ag.parse_args() + + if args.command is None: # backward compatibility + args.command = 'train' + + if args.command == 'train': + run(args) + elif args.command == 'preset': + preset_cmd.run(args) + + +if __name__ == '__main__': + main() diff --git a/mace-bench/3rdparty/SevenNet/sevenn/main/sevenn_cp.py b/mace-bench/3rdparty/SevenNet/sevenn/main/sevenn_cp.py index 29edf5b..319cb1a 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/main/sevenn_cp.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/main/sevenn_cp.py @@ -1,92 +1,92 @@ -import argparse -import os.path as osp - -from sevenn import __version__ - -description = ( - 'tool box for sevennet checkpoints' -) - - -def add_parser(subparsers): - ag = subparsers.add_parser('checkpoint', help=description, aliases=['cp']) - add_args(ag) - - -def add_args(parser): - ag = parser - - ag.add_argument('checkpoint', help='checkpoint or pretrained', type=str) - - group = ag.add_mutually_exclusive_group(required=False) - group.add_argument( - '--get_yaml', - choices=['reproduce', 'continue', 'continue_modal'], - help='create input.yaml based on the given checkpoint', - type=str, - ) - - group.add_argument( - '--append_modal_yaml', - help='append modality with given yaml.', - type=str, - ) - ag.add_argument( - '--original_modal_name', - help=( - 'when the append_modal is used and checkpoint is not multi-modal, ' - + 'used to name previously trained modality. defaults to "origin"' - ), - default='origin', - type=str, - ) - - -def run(args): - import torch - import yaml - - from sevenn.parse_input import read_config_yaml - from sevenn.util import load_checkpoint - - checkpoint = load_checkpoint(args.checkpoint) - if args.get_yaml: - mode = args.get_yaml - cfg = checkpoint.yaml_dict(mode) - print(yaml.dump(cfg, indent=4, sort_keys=False, default_flow_style=False)) - elif args.append_modal_yaml: - dst_yaml = args.append_modal_yaml - if not osp.exists(dst_yaml): - raise FileNotFoundError(f'No yaml file {dst_yaml}') - - dst_config = read_config_yaml(dst_yaml, return_separately=False) - model_state_dict = checkpoint.append_modal( - dst_config, args.original_modal_name - ) - - to_save = checkpoint.get_checkpoint_dict() - to_save.update({'config': dst_config, 'model_state_dict': model_state_dict}) - - torch.save(to_save, 'checkpoint_modal_appended.pth') - print('checkpoint_modal_appended.pth is successfully saved.') - print(f'update continue of {dst_yaml} as blow (recommend) to continue') - cont_dct = { - 'continue': { - 'checkpoint': 'checkpoint_modal_appended.pth', - 'reset_epoch': True, - 'reset_optimizer': True, - 'reset_scheduler': True, - } - } - print( - yaml.dump(cont_dct, indent=4, sort_keys=False, default_flow_style=False) - ) - - else: - print(checkpoint) - - -def main(args=None): - ag = argparse.ArgumentParser(description=description) - add_args(ag) - run(ag.parse_args()) +import argparse +import os.path as osp + +from sevenn import __version__ + +description = ( + 'tool box for sevennet checkpoints' +) + + +def add_parser(subparsers): + ag = subparsers.add_parser('checkpoint', help=description, aliases=['cp']) + add_args(ag) + + +def add_args(parser): + ag = parser + + ag.add_argument('checkpoint', help='checkpoint or pretrained', type=str) + + group = ag.add_mutually_exclusive_group(required=False) + group.add_argument( + '--get_yaml', + choices=['reproduce', 'continue', 'continue_modal'], + help='create input.yaml based on the given checkpoint', + type=str, + ) + + group.add_argument( + '--append_modal_yaml', + help='append modality with given yaml.', + type=str, + ) + ag.add_argument( + '--original_modal_name', + help=( + 'when the append_modal is used and checkpoint is not multi-modal, ' + + 'used to name previously trained modality. defaults to "origin"' + ), + default='origin', + type=str, + ) + + +def run(args): + import torch + import yaml + + from sevenn.parse_input import read_config_yaml + from sevenn.util import load_checkpoint + + checkpoint = load_checkpoint(args.checkpoint) + if args.get_yaml: + mode = args.get_yaml + cfg = checkpoint.yaml_dict(mode) + print(yaml.dump(cfg, indent=4, sort_keys=False, default_flow_style=False)) + elif args.append_modal_yaml: + dst_yaml = args.append_modal_yaml + if not osp.exists(dst_yaml): + raise FileNotFoundError(f'No yaml file {dst_yaml}') + + dst_config = read_config_yaml(dst_yaml, return_separately=False) + model_state_dict = checkpoint.append_modal( + dst_config, args.original_modal_name + ) + + to_save = checkpoint.get_checkpoint_dict() + to_save.update({'config': dst_config, 'model_state_dict': model_state_dict}) + + torch.save(to_save, 'checkpoint_modal_appended.pth') + print('checkpoint_modal_appended.pth is successfully saved.') + print(f'update continue of {dst_yaml} as blow (recommend) to continue') + cont_dct = { + 'continue': { + 'checkpoint': 'checkpoint_modal_appended.pth', + 'reset_epoch': True, + 'reset_optimizer': True, + 'reset_scheduler': True, + } + } + print( + yaml.dump(cont_dct, indent=4, sort_keys=False, default_flow_style=False) + ) + + else: + print(checkpoint) + + +def main(args=None): + ag = argparse.ArgumentParser(description=description) + add_args(ag) + run(ag.parse_args()) diff --git a/mace-bench/3rdparty/SevenNet/sevenn/main/sevenn_get_model.py b/mace-bench/3rdparty/SevenNet/sevenn/main/sevenn_get_model.py index f2b7b0a..f8b78e8 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/main/sevenn_get_model.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/main/sevenn_get_model.py @@ -1,70 +1,70 @@ -import argparse -import os - -from sevenn import __version__ - -description_get_model = ( - 'deploy LAMMPS model from the checkpoint' -) -checkpoint_help = ( - 'path to the checkpoint | SevenNet-0 | 7net-0 |' - ' {SevenNet-0|7net-0}_{11July2024|22May2024}' -) -output_name_help = 'filename prefix' -get_parallel_help = 'deploy parallel model' - - -def add_parser(subparsers): - ag = subparsers.add_parser( - 'get_model', help=description_get_model, aliases=['deploy'] - ) - add_args(ag) - - -def add_args(parser): - ag = parser - ag.add_argument('checkpoint', help=checkpoint_help, type=str) - ag.add_argument( - '-o', '--output_prefix', nargs='?', help=output_name_help, type=str - ) - ag.add_argument( - '-p', '--get_parallel', help=get_parallel_help, action='store_true' - ) - ag.add_argument( - '-m', - '--modal', - help='Modality of multi-modal model', - type=str, - ) - - -def run(args): - import sevenn.util - from sevenn.scripts.deploy import deploy, deploy_parallel - - checkpoint = args.checkpoint - output_prefix = args.output_prefix - get_parallel = args.get_parallel - get_serial = not get_parallel - modal = args.modal - - if output_prefix is None: - output_prefix = 'deployed_parallel' if not get_serial else 'deployed_serial' - - checkpoint_path = None - if os.path.isfile(checkpoint): - checkpoint_path = checkpoint - else: - checkpoint_path = sevenn.util.pretrained_name_to_path(checkpoint) - - if get_serial: - deploy(checkpoint_path, output_prefix, modal) - else: - deploy_parallel(checkpoint_path, output_prefix, modal) - - -# legacy way -def main(): - ag = argparse.ArgumentParser(description=description_get_model) - add_args(ag) - run(ag.parse_args()) +import argparse +import os + +from sevenn import __version__ + +description_get_model = ( + 'deploy LAMMPS model from the checkpoint' +) +checkpoint_help = ( + 'path to the checkpoint | SevenNet-0 | 7net-0 |' + ' {SevenNet-0|7net-0}_{11July2024|22May2024}' +) +output_name_help = 'filename prefix' +get_parallel_help = 'deploy parallel model' + + +def add_parser(subparsers): + ag = subparsers.add_parser( + 'get_model', help=description_get_model, aliases=['deploy'] + ) + add_args(ag) + + +def add_args(parser): + ag = parser + ag.add_argument('checkpoint', help=checkpoint_help, type=str) + ag.add_argument( + '-o', '--output_prefix', nargs='?', help=output_name_help, type=str + ) + ag.add_argument( + '-p', '--get_parallel', help=get_parallel_help, action='store_true' + ) + ag.add_argument( + '-m', + '--modal', + help='Modality of multi-modal model', + type=str, + ) + + +def run(args): + import sevenn.util + from sevenn.scripts.deploy import deploy, deploy_parallel + + checkpoint = args.checkpoint + output_prefix = args.output_prefix + get_parallel = args.get_parallel + get_serial = not get_parallel + modal = args.modal + + if output_prefix is None: + output_prefix = 'deployed_parallel' if not get_serial else 'deployed_serial' + + checkpoint_path = None + if os.path.isfile(checkpoint): + checkpoint_path = checkpoint + else: + checkpoint_path = sevenn.util.pretrained_name_to_path(checkpoint) + + if get_serial: + deploy(checkpoint_path, output_prefix, modal) + else: + deploy_parallel(checkpoint_path, output_prefix, modal) + + +# legacy way +def main(): + ag = argparse.ArgumentParser(description=description_get_model) + add_args(ag) + run(ag.parse_args()) diff --git a/mace-bench/3rdparty/SevenNet/sevenn/main/sevenn_graph_build.py b/mace-bench/3rdparty/SevenNet/sevenn/main/sevenn_graph_build.py index da533c9..fd9eaee 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/main/sevenn_graph_build.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/main/sevenn_graph_build.py @@ -1,130 +1,130 @@ -import argparse -import glob -import os -import sys -from datetime import datetime - -from sevenn import __version__ - -description = 'create `sevenn_data/dataset.pt` from ase readable' - -source_help = 'source data to build graph, knows *' -cutoff_help = 'cutoff radius of edges in Angstrom' -filename_help = ( - 'Name of the dataset, default is graph.pt. ' - + 'The dataset will be written under "sevenn_data", ' - + 'for example, {out}/sevenn_data/graph.pt.' -) -legacy_help = 'build legacy .sevenn_data' - - -def add_parser(subparsers): - ag = subparsers.add_parser('graph_build', help=description) - add_args(ag) - - -def add_args(parser): - ag = parser - ag.add_argument('source', help=source_help, type=str) - ag.add_argument('cutoff', help=cutoff_help, type=float) - ag.add_argument( - '-n', - '--num_cores', - help='number of cores to build graph in parallel', - default=1, - type=int, - ) - ag.add_argument( - '-o', - '--out', - help='Existing path to write outputs.', - type=str, - default='./', - ) - ag.add_argument( - '-f', - '--filename', - help=filename_help, - type=str, - default='graph.pt', - ) - ag.add_argument( - '--legacy', - help=legacy_help, - action='store_true', - ) - ag.add_argument( - '-s', - '--screen', - help='print log to the screen', - action='store_true', - ) - ag.add_argument( - '--kwargs', - nargs=argparse.REMAINDER, - help='will be passed to ase.io.read, or can be used to specify EFS key', - ) - - -def run(args): - import sevenn.scripts.graph_build as graph_build - from sevenn.logger import Logger - - source = glob.glob(args.source) - cutoff = args.cutoff - num_cores = args.num_cores - filename = args.filename - out = args.out - legacy = args.legacy - fmt_kwargs = {} - if args.kwargs: - for kwarg in args.kwargs: - k, v = kwarg.split('=') - fmt_kwargs[k] = v - - if len(source) == 0: - print('Source has zero len, nothing to read') - sys.exit(0) - - if not os.path.isdir(out): - raise NotADirectoryError(f'No such directory: {out}') - - to_be_written = os.path.join(out, 'sevenn_data', filename) - if os.path.isfile(to_be_written): - raise FileExistsError(f'File already exist: {to_be_written}') - - metadata = { - 'sevenn_version': __version__, - 'when': datetime.now().strftime('%Y-%m-%d'), - 'cutoff': cutoff, - } - - with Logger(filename=None, screen=args.screen) as logger: - logger.writeline(description) - - if not legacy: - graph_build.build_sevennet_graph_dataset( - source, - cutoff, - num_cores, - out, - filename, - metadata, - **fmt_kwargs, - ) - else: - out = os.path.join(out, filename.split('.')[0]) - graph_build.build_script( # build .sevenn_data - source, - cutoff, - num_cores, - out, - metadata, - **fmt_kwargs, - ) - - -def main(args=None): - ag = argparse.ArgumentParser(description=description) - add_args(ag) - run(ag.parse_args()) +import argparse +import glob +import os +import sys +from datetime import datetime + +from sevenn import __version__ + +description = 'create `sevenn_data/dataset.pt` from ase readable' + +source_help = 'source data to build graph, knows *' +cutoff_help = 'cutoff radius of edges in Angstrom' +filename_help = ( + 'Name of the dataset, default is graph.pt. ' + + 'The dataset will be written under "sevenn_data", ' + + 'for example, {out}/sevenn_data/graph.pt.' +) +legacy_help = 'build legacy .sevenn_data' + + +def add_parser(subparsers): + ag = subparsers.add_parser('graph_build', help=description) + add_args(ag) + + +def add_args(parser): + ag = parser + ag.add_argument('source', help=source_help, type=str) + ag.add_argument('cutoff', help=cutoff_help, type=float) + ag.add_argument( + '-n', + '--num_cores', + help='number of cores to build graph in parallel', + default=1, + type=int, + ) + ag.add_argument( + '-o', + '--out', + help='Existing path to write outputs.', + type=str, + default='./', + ) + ag.add_argument( + '-f', + '--filename', + help=filename_help, + type=str, + default='graph.pt', + ) + ag.add_argument( + '--legacy', + help=legacy_help, + action='store_true', + ) + ag.add_argument( + '-s', + '--screen', + help='print log to the screen', + action='store_true', + ) + ag.add_argument( + '--kwargs', + nargs=argparse.REMAINDER, + help='will be passed to ase.io.read, or can be used to specify EFS key', + ) + + +def run(args): + import sevenn.scripts.graph_build as graph_build + from sevenn.logger import Logger + + source = glob.glob(args.source) + cutoff = args.cutoff + num_cores = args.num_cores + filename = args.filename + out = args.out + legacy = args.legacy + fmt_kwargs = {} + if args.kwargs: + for kwarg in args.kwargs: + k, v = kwarg.split('=') + fmt_kwargs[k] = v + + if len(source) == 0: + print('Source has zero len, nothing to read') + sys.exit(0) + + if not os.path.isdir(out): + raise NotADirectoryError(f'No such directory: {out}') + + to_be_written = os.path.join(out, 'sevenn_data', filename) + if os.path.isfile(to_be_written): + raise FileExistsError(f'File already exist: {to_be_written}') + + metadata = { + 'sevenn_version': __version__, + 'when': datetime.now().strftime('%Y-%m-%d'), + 'cutoff': cutoff, + } + + with Logger(filename=None, screen=args.screen) as logger: + logger.writeline(description) + + if not legacy: + graph_build.build_sevennet_graph_dataset( + source, + cutoff, + num_cores, + out, + filename, + metadata, + **fmt_kwargs, + ) + else: + out = os.path.join(out, filename.split('.')[0]) + graph_build.build_script( # build .sevenn_data + source, + cutoff, + num_cores, + out, + metadata, + **fmt_kwargs, + ) + + +def main(args=None): + ag = argparse.ArgumentParser(description=description) + add_args(ag) + run(ag.parse_args()) diff --git a/mace-bench/3rdparty/SevenNet/sevenn/main/sevenn_inference.py b/mace-bench/3rdparty/SevenNet/sevenn/main/sevenn_inference.py index 5a9cd90..bfac237 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/main/sevenn_inference.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/main/sevenn_inference.py @@ -1,129 +1,129 @@ -import argparse -import glob -import os -import sys - -description = ( - 'evaluate sevenn_data/ase readable with a model (checkpoint).' -) -checkpoint_help = 'Checkpoint or pre-trained model name' -target_help = 'Target files to evaluate' - - -def add_parser(subparsers): - ag = subparsers.add_parser('inference', help=description, aliases=['inf']) - add_args(ag) - - -def add_args(parser): - ag = parser - ag.add_argument('checkpoint', type=str, help=checkpoint_help) - ag.add_argument('targets', type=str, nargs='+', help=target_help) - ag.add_argument( - '-d', - '--device', - type=str, - default='auto', - help='cpu/cuda/cuda:x', - ) - ag.add_argument( - '-nw', - '--nworkers', - type=int, - default=1, - help='Number of cores to build graph, defaults to 1', - ) - ag.add_argument( - '-o', - '--output', - type=str, - default='./inference_results', - help='A directory name to write outputs', - ) - ag.add_argument( - '-b', - '--batch', - type=int, - default='4', - help='batch size, useful for GPU' - ) - ag.add_argument( - '-s', - '--save_graph', - action='store_true', - help='Additionally, save preprocessed graph as sevenn_data' - ) - ag.add_argument( - '-au', - '--allow_unlabeled', - action='store_true', - help='Allow energy or force unlabeled data' - ) - ag.add_argument( - '-m', - '--modal', - type=str, - default=None, - help='modality for multi-modal inference', - ) - ag.add_argument( - '--kwargs', - nargs=argparse.REMAINDER, - help='will be passed to reader, or can be used to specify EFS key', - ) - - -def run(args): - import torch - - from sevenn.scripts.inference import inference - from sevenn.util import pretrained_name_to_path - - out = args.output - - if os.path.exists(out): - raise FileExistsError(f'Directory {out} already exists') - - device = args.device - if device == 'auto': - device = 'cuda' if torch.cuda.is_available() else 'cpu' - - targets = [] - for target in args.targets: - targets.extend(glob.glob(target)) - - if len(targets) == 0: - print('No targets (data to inference) are found') - sys.exit(0) - - cp = args.checkpoint - if not os.path.isfile(cp): - cp = pretrained_name_to_path(cp) # raises value error - - fmt_kwargs = {} - if args.kwargs: - for kwarg in args.kwargs: - k, v = kwarg.split('=') - fmt_kwargs[k] = v - - if args.save_graph and args.allow_unlabeled: - raise ValueError('save_graph and allow_unlabeled are mutually exclusive') - - inference( - cp, - targets, - out, - args.nworkers, - device, - args.batch, - args.save_graph, - args.allow_unlabeled, - args.modal, - **fmt_kwargs, - ) - - -def main(args=None): - ag = argparse.ArgumentParser(description=description) - add_args(ag) - run(ag.parse_args()) +import argparse +import glob +import os +import sys + +description = ( + 'evaluate sevenn_data/ase readable with a model (checkpoint).' +) +checkpoint_help = 'Checkpoint or pre-trained model name' +target_help = 'Target files to evaluate' + + +def add_parser(subparsers): + ag = subparsers.add_parser('inference', help=description, aliases=['inf']) + add_args(ag) + + +def add_args(parser): + ag = parser + ag.add_argument('checkpoint', type=str, help=checkpoint_help) + ag.add_argument('targets', type=str, nargs='+', help=target_help) + ag.add_argument( + '-d', + '--device', + type=str, + default='auto', + help='cpu/cuda/cuda:x', + ) + ag.add_argument( + '-nw', + '--nworkers', + type=int, + default=1, + help='Number of cores to build graph, defaults to 1', + ) + ag.add_argument( + '-o', + '--output', + type=str, + default='./inference_results', + help='A directory name to write outputs', + ) + ag.add_argument( + '-b', + '--batch', + type=int, + default='4', + help='batch size, useful for GPU' + ) + ag.add_argument( + '-s', + '--save_graph', + action='store_true', + help='Additionally, save preprocessed graph as sevenn_data' + ) + ag.add_argument( + '-au', + '--allow_unlabeled', + action='store_true', + help='Allow energy or force unlabeled data' + ) + ag.add_argument( + '-m', + '--modal', + type=str, + default=None, + help='modality for multi-modal inference', + ) + ag.add_argument( + '--kwargs', + nargs=argparse.REMAINDER, + help='will be passed to reader, or can be used to specify EFS key', + ) + + +def run(args): + import torch + + from sevenn.scripts.inference import inference + from sevenn.util import pretrained_name_to_path + + out = args.output + + if os.path.exists(out): + raise FileExistsError(f'Directory {out} already exists') + + device = args.device + if device == 'auto': + device = 'cuda' if torch.cuda.is_available() else 'cpu' + + targets = [] + for target in args.targets: + targets.extend(glob.glob(target)) + + if len(targets) == 0: + print('No targets (data to inference) are found') + sys.exit(0) + + cp = args.checkpoint + if not os.path.isfile(cp): + cp = pretrained_name_to_path(cp) # raises value error + + fmt_kwargs = {} + if args.kwargs: + for kwarg in args.kwargs: + k, v = kwarg.split('=') + fmt_kwargs[k] = v + + if args.save_graph and args.allow_unlabeled: + raise ValueError('save_graph and allow_unlabeled are mutually exclusive') + + inference( + cp, + targets, + out, + args.nworkers, + device, + args.batch, + args.save_graph, + args.allow_unlabeled, + args.modal, + **fmt_kwargs, + ) + + +def main(args=None): + ag = argparse.ArgumentParser(description=description) + add_args(ag) + run(ag.parse_args()) diff --git a/mace-bench/3rdparty/SevenNet/sevenn/main/sevenn_patch_lammps.py b/mace-bench/3rdparty/SevenNet/sevenn/main/sevenn_patch_lammps.py index 9ede5a6..38f7299 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/main/sevenn_patch_lammps.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/main/sevenn_patch_lammps.py @@ -1,55 +1,55 @@ -import argparse -import os -import subprocess - -from sevenn import __version__ - -# python wrapper of patch_lammps.sh script -# importlib.resources is correct way to do these things -# but it changes so frequently to use -pair_e3gnn_dir = os.path.abspath(f'{os.path.dirname(__file__)}/../pair_e3gnn') - -description = 'patch LAMMPS with e3gnn(7net) pair-styles before compile' - - -def add_parser(subparsers): - ag = subparsers.add_parser('patch_lammps', help=description) - add_args(ag) - - -def add_args(parser): - ag = parser - ag.add_argument('lammps_dir', help='Path to LAMMPS source', type=str) - ag.add_argument('--d3', help='Enable D3 support', action='store_true') - # cxx_standard is detected automatically - - -def run(args): - lammps_dir = os.path.abspath(args.lammps_dir) - - print('Patching LAMMPS with the following settings:') - print(' - LAMMPS source directory:', lammps_dir) - - cxx_standard = '17' # always 17 - - if args.d3: - d3_support = '1' - print(' - D3 support enabled') - else: - d3_support = '0' - print(' - D3 support disabled') - - script = f'{pair_e3gnn_dir}/patch_lammps.sh' - cmd = f'{script} {lammps_dir} {cxx_standard} {d3_support}' - res = subprocess.run(cmd.split()) - return res.returncode # is it meaningless? - - -def main(args=None): - ag = argparse.ArgumentParser(description=description) - add_args(ag) - run(ag.parse_args()) - - -if __name__ == '__main__': - main() +import argparse +import os +import subprocess + +from sevenn import __version__ + +# python wrapper of patch_lammps.sh script +# importlib.resources is correct way to do these things +# but it changes so frequently to use +pair_e3gnn_dir = os.path.abspath(f'{os.path.dirname(__file__)}/../pair_e3gnn') + +description = 'patch LAMMPS with e3gnn(7net) pair-styles before compile' + + +def add_parser(subparsers): + ag = subparsers.add_parser('patch_lammps', help=description) + add_args(ag) + + +def add_args(parser): + ag = parser + ag.add_argument('lammps_dir', help='Path to LAMMPS source', type=str) + ag.add_argument('--d3', help='Enable D3 support', action='store_true') + # cxx_standard is detected automatically + + +def run(args): + lammps_dir = os.path.abspath(args.lammps_dir) + + print('Patching LAMMPS with the following settings:') + print(' - LAMMPS source directory:', lammps_dir) + + cxx_standard = '17' # always 17 + + if args.d3: + d3_support = '1' + print(' - D3 support enabled') + else: + d3_support = '0' + print(' - D3 support disabled') + + script = f'{pair_e3gnn_dir}/patch_lammps.sh' + cmd = f'{script} {lammps_dir} {cxx_standard} {d3_support}' + res = subprocess.run(cmd.split()) + return res.returncode # is it meaningless? + + +def main(args=None): + ag = argparse.ArgumentParser(description=description) + add_args(ag) + run(ag.parse_args()) + + +if __name__ == '__main__': + main() diff --git a/mace-bench/3rdparty/SevenNet/sevenn/main/sevenn_preset.py b/mace-bench/3rdparty/SevenNet/sevenn/main/sevenn_preset.py index fdacea5..f8587d1 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/main/sevenn_preset.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/main/sevenn_preset.py @@ -1,45 +1,45 @@ -import argparse -import os - -from sevenn import __version__ - -description = ( - 'print the selected preset for training. ' - + 'ex) sevennet_preset fine_tune > my_input.yaml' -) - -preset_help = 'Name of preset' - - -def add_parser(subparsers): - ag = subparsers.add_parser('preset', help=description) - add_args(ag) - - -def add_args(parser): - ag = parser - ag.add_argument( - 'preset', choices=[ - 'fine_tune', - 'fine_tune_le', - 'sevennet-0', - 'sevennet-l3i5', - 'base', - 'multi_modal' - ], - help=preset_help - ) - - -def run(args): - preset = args.preset - prefix = os.path.abspath(f'{os.path.dirname(__file__)}/../presets') - with open(f'{prefix}/{preset}.yaml', 'r') as f: - print(f.read()) - - -# When executed as sevenn_preset (legacy way) -def main(args=None): - ag = argparse.ArgumentParser(description=description) - add_args(ag) - run(ag.parse_args()) +import argparse +import os + +from sevenn import __version__ + +description = ( + 'print the selected preset for training. ' + + 'ex) sevennet_preset fine_tune > my_input.yaml' +) + +preset_help = 'Name of preset' + + +def add_parser(subparsers): + ag = subparsers.add_parser('preset', help=description) + add_args(ag) + + +def add_args(parser): + ag = parser + ag.add_argument( + 'preset', choices=[ + 'fine_tune', + 'fine_tune_le', + 'sevennet-0', + 'sevennet-l3i5', + 'base', + 'multi_modal' + ], + help=preset_help + ) + + +def run(args): + preset = args.preset + prefix = os.path.abspath(f'{os.path.dirname(__file__)}/../presets') + with open(f'{prefix}/{preset}.yaml', 'r') as f: + print(f.read()) + + +# When executed as sevenn_preset (legacy way) +def main(args=None): + ag = argparse.ArgumentParser(description=description) + add_args(ag) + run(ag.parse_args()) diff --git a/mace-bench/3rdparty/SevenNet/sevenn/model_build.py b/mace-bench/3rdparty/SevenNet/sevenn/model_build.py index 9113672..2b8701d 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/model_build.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/model_build.py @@ -1,556 +1,556 @@ -import copy -import warnings -from collections import OrderedDict -from typing import List, Literal, Union, overload - -from e3nn.o3 import Irreps - -import sevenn._const as _const -import sevenn._keys as KEY -import sevenn.util as util - -from .nn.convolution import IrrepsConvolution -from .nn.edge_embedding import ( - BesselBasis, - EdgeEmbedding, - PolynomialCutoff, - SphericalEncoding, - XPLORCutoff, -) -from .nn.force_output import ForceStressOutputFromEdge -from .nn.interaction_blocks import NequIP_interaction_block -from .nn.linear import AtomReduce, FCN_e3nn, IrrepsLinear -from .nn.node_embedding import OnehotEmbedding -from .nn.scale import ModalWiseRescale, Rescale, SpeciesWiseRescale -from .nn.self_connection import ( - SelfConnectionIntro, - SelfConnectionLinearIntro, - SelfConnectionOutro, -) -from .nn.sequential import AtomGraphSequential - -# warning from PyTorch, about e3nn type annotations -warnings.filterwarnings( - 'ignore', - message=( - "The TorchScript type system doesn't " 'support instance-level annotations' - ), -) - - -def _insert_after(module_name_after, key_module_pair, layers): - idx = -1 - for i, (key, _) in enumerate(layers): - if key == module_name_after: - idx = i - break - if idx == -1: - return layers # do nothing if not found - layers.insert(idx + 1, key_module_pair) - return layers - - -def init_self_connection(config): - self_connection_type_list = config[KEY.SELF_CONNECTION_TYPE] - num_conv = config[KEY.NUM_CONVOLUTION] - if isinstance(self_connection_type_list, str): - self_connection_type_list = [self_connection_type_list] * num_conv - - io_pair_list = [] - for sc_type in self_connection_type_list: - if sc_type == 'none': - io_pair = None - elif sc_type == 'nequip': - io_pair = SelfConnectionIntro, SelfConnectionOutro - elif sc_type == 'linear': - io_pair = SelfConnectionLinearIntro, SelfConnectionOutro - else: - raise ValueError(f'Unknown self_connection_type found: {sc_type}') - io_pair_list.append(io_pair) - return io_pair_list - - -def init_edge_embedding(config): - _cutoff_param = {'cutoff_length': config[KEY.CUTOFF]} - rbf, env, sph = None, None, None - - rbf_dct = copy.deepcopy(config[KEY.RADIAL_BASIS]) - rbf_dct.update(_cutoff_param) - rbf_name = rbf_dct.pop(KEY.RADIAL_BASIS_NAME) - if rbf_name == 'bessel': - rbf = BesselBasis(**rbf_dct) - - envelop_dct = copy.deepcopy(config[KEY.CUTOFF_FUNCTION]) - envelop_dct.update(_cutoff_param) - envelop_name = envelop_dct.pop(KEY.CUTOFF_FUNCTION_NAME) - if envelop_name == 'poly_cut': - env = PolynomialCutoff(**envelop_dct) - elif envelop_name == 'XPLOR': - env = XPLORCutoff(**envelop_dct) - - lmax_edge = config[KEY.LMAX] - if config[KEY.LMAX_EDGE] > 0: - lmax_edge = config[KEY.LMAX_EDGE] - parity = -1 if config[KEY.IS_PARITY] else 1 - _normalize_sph = config[KEY._NORMALIZE_SPH] - sph = SphericalEncoding(lmax_edge, parity, normalize=_normalize_sph) - - return EdgeEmbedding(basis_module=rbf, cutoff_module=env, spherical_module=sph) - - -def init_feature_reduce(config, irreps_x): - # features per node to scalar per node - layers = OrderedDict() - if config[KEY.READOUT_AS_FCN] is False: - hidden_irreps = Irreps([(irreps_x.dim // 2, (0, 1))]) - layers.update( - { - 'reduce_input_to_hidden': IrrepsLinear( - irreps_x, - hidden_irreps, - data_key_in=KEY.NODE_FEATURE, - biases=config[KEY.USE_BIAS_IN_LINEAR], - ), - 'reduce_hidden_to_energy': IrrepsLinear( - hidden_irreps, - Irreps([(1, (0, 1))]), - data_key_in=KEY.NODE_FEATURE, - data_key_out=KEY.SCALED_ATOMIC_ENERGY, - biases=config[KEY.USE_BIAS_IN_LINEAR], - ), - } - ) - else: - act = _const.ACTIVATION[config[KEY.READOUT_FCN_ACTIVATION]] - hidden_neurons = config[KEY.READOUT_FCN_HIDDEN_NEURONS] - layers.update( - { - 'readout_FCN': FCN_e3nn( - dim_out=1, - hidden_neurons=hidden_neurons, - activation=act, - data_key_in=KEY.NODE_FEATURE, - data_key_out=KEY.SCALED_ATOMIC_ENERGY, - irreps_in=irreps_x, - ) - } - ) - return layers - - -def init_shift_scale(config): - # for mm, ex, shift: modal_idx -> shifts - shift_scale = [] - train_shift_scale = config[KEY.TRAIN_SHIFT_SCALE] - type_map = config[KEY.TYPE_MAP] - - # in case of modal, shift or scale has more dims [][] - # correct typing (I really want static python) - for s in (config[KEY.SHIFT], config[KEY.SCALE]): - if hasattr(s, 'tolist'): # numpy or torch - s = s.tolist() - if isinstance(s, dict): - s = {k: v.tolist() if hasattr(v, 'tolist') else v for k, v in s.items()} - if isinstance(s, list) and len(s) == 1: - s = s[0] - shift_scale.append(s) - shift, scale = shift_scale - - rescale_module = None - if config.get(KEY.USE_MODALITY, False): - rescale_module = ModalWiseRescale.from_mappers( # type: ignore - shift, - scale, - config[KEY.USE_MODAL_WISE_SHIFT], - config[KEY.USE_MODAL_WISE_SCALE], - type_map=type_map, - modal_map=config[KEY.MODAL_MAP], - train_shift_scale=train_shift_scale, - ) - elif all([isinstance(s, float) for s in shift_scale]): - rescale_module = Rescale(shift, scale, train_shift_scale=train_shift_scale) - elif any([isinstance(s, list) for s in shift_scale]): - rescale_module = SpeciesWiseRescale.from_mappers( # type: ignore - shift, scale, type_map=type_map, train_shift_scale=train_shift_scale - ) - else: - raise ValueError('shift, scale should be list of float or float') - - return rescale_module - - -def patch_modality(layers: OrderedDict, config): - """ - Postprocess 7net-model to multimodal model. - 1. prepend modality one-hot embedding layer - 2. patch modalities of IrrepsLinear layers - Modality aware shift scale is handled by init_shift_scale, not here - """ - cfg = config - if not cfg.get(KEY.USE_MODALITY, False): - return layers - - _layers = list(layers.items()) - _layers = _insert_after( - 'onehot_idx_to_onehot', - ( - 'one_hot_modality', - OnehotEmbedding( - num_classes=config[KEY.NUM_MODALITIES], - data_key_x=KEY.MODAL_TYPE, - data_key_out=KEY.MODAL_ATTR, - data_key_save=None, - data_key_additional=None, - ), - ), - _layers, - ) - layers = OrderedDict(_layers) - - num_modal = config[KEY.NUM_MODALITIES] - for k, module in layers.items(): - if not isinstance(module, IrrepsLinear): - continue - if ( - (cfg[KEY.USE_MODAL_NODE_EMBEDDING] and k.endswith('onehot_to_feature_x')) - or ( - cfg[KEY.USE_MODAL_SELF_INTER_INTRO] - and k.endswith('self_interaction_1') - ) - or ( - cfg[KEY.USE_MODAL_SELF_INTER_OUTRO] - and k.endswith('self_interaction_2') - ) - or (cfg[KEY.USE_MODAL_OUTPUT_BLOCK] and k == 'reduce_input_to_hidden') - ): - module.set_num_modalities(num_modal) - return layers - - -def patch_cue(layers: OrderedDict, config): - import sevenn.nn.cue_helper as cue_helper - - cue_cfg = copy.deepcopy(config.get(KEY.CUEQUIVARIANCE_CONFIG, {})) - - if not cue_cfg.pop('use', False): - return layers - - if not cue_helper.is_cue_available(): - warnings.warn( - ( - 'cuEquivariance is requested, but the package is not installed. ' - + 'Fallback to original code.' - ) - ) - return layers - - if not cue_helper.is_cue_cuda_available_model(config): - return layers - - group = 'O3' if config[KEY.IS_PARITY] else 'SO3' - cueq_module_params = dict(layout='mul_ir') - cueq_module_params.update(cue_cfg) - updates = {} - for k, module in layers.items(): - if isinstance(module, (IrrepsLinear, SelfConnectionLinearIntro)): - if k == 'reduce_hidden_to_energy': # TODO: has bug with 0 shape - continue - module_patched = cue_helper.patch_linear( - module, group, **cueq_module_params - ) - updates[k] = module_patched - elif isinstance(module, SelfConnectionIntro): - module_patched = cue_helper.patch_fully_connected( - module, group, **cueq_module_params - ) - updates[k] = module_patched - elif isinstance(module, IrrepsConvolution): - module_patched = cue_helper.patch_convolution( - module, group, **cueq_module_params - ) - updates[k] = module_patched - - layers.update(updates) - return layers - - -def patch_modules(layers: OrderedDict, config): - layers = patch_modality(layers, config) - layers = patch_cue(layers, config) - return layers - - -def _to_parallel_model(layers: OrderedDict, config): - num_classes = layers['onehot_idx_to_onehot'].num_classes - one_hot_irreps = Irreps(f'{num_classes}x0e') - irreps_node_zero = layers['onehot_to_feature_x'].irreps_out - - _layers = list(layers.items()) - layers_list = [] - - num_convolution_layer = config[KEY.NUM_CONVOLUTION] - - def slice_until_this(module_name, layers): - idx = -1 - for i, (key, _) in enumerate(layers): - if key == module_name: - idx = i - break - first_to = layers[: idx + 1] - remain = layers[idx + 1 :] - return first_to, remain - - _layers = _insert_after( - 'onehot_to_feature_x', - ( - 'one_hot_ghost', - OnehotEmbedding( - data_key_x=KEY.NODE_FEATURE_GHOST, - num_classes=num_classes, - data_key_save=None, - data_key_additional=None, - ), - ), - _layers, - ) - _layers = _insert_after( - 'one_hot_ghost', - ( - 'ghost_onehot_to_feature_x', - IrrepsLinear( - irreps_in=one_hot_irreps, - irreps_out=irreps_node_zero, - data_key_in=KEY.NODE_FEATURE_GHOST, - biases=config[KEY.USE_BIAS_IN_LINEAR], - ), - ), - _layers, - ) - _layers = _insert_after( - '0_self_interaction_1', - ( - 'ghost_0_self_interaction_1', - IrrepsLinear( - irreps_node_zero, - irreps_node_zero, - data_key_in=KEY.NODE_FEATURE_GHOST, - biases=config[KEY.USE_BIAS_IN_LINEAR], - ), - ), - _layers, - ) - # assign modules (before first communications) - # initialize edge related to retain position gradients - for i in range(1, num_convolution_layer): - sliced, _layers = slice_until_this(f'{i}_self_interaction_1', _layers) - layers_list.append(OrderedDict(sliced)) - _layers.insert(0, ('edge_embedding', init_edge_embedding(config))) - - layers_list.append(OrderedDict(_layers)) - del layers_list[-1]['force_output'] # done in LAMMPS - return layers_list - - -@overload -def build_E3_equivariant_model( - config: dict, parallel: Literal[False] = False -) -> AtomGraphSequential: # noqa - ... - - -@overload -def build_E3_equivariant_model( - config: dict, parallel: Literal[True] -) -> List[AtomGraphSequential]: # noqa - ... - - -def build_E3_equivariant_model( - config: dict, parallel: bool = False -) -> Union[AtomGraphSequential, List[AtomGraphSequential]]: - """ - output shapes (w/o batch) - - PRED_TOTAL_ENERGY: (), - ATOMIC_ENERGY: (natoms, 1), # intended - PRED_FORCE: (natoms, 3), - PRED_STRESS: (6,), - - for data w/o cell volume, pred_stress has garbage values - """ - layers = OrderedDict() - - cutoff = config[KEY.CUTOFF] - num_species = config[KEY.NUM_SPECIES] - feature_multiplicity = config[KEY.NODE_FEATURE_MULTIPLICITY] - num_convolution_layer = config[KEY.NUM_CONVOLUTION] - interaction_type = config[KEY.INTERACTION_TYPE] - use_bias_in_linear = config[KEY.USE_BIAS_IN_LINEAR] - - lmax_node = config[KEY.LMAX] # ignore second (lmax_edge) - # if config[KEY.LMAX_EDGE] > 0: # not yet used - # _ = config[KEY.LMAX_EDGE] - if config[KEY.LMAX_NODE] > 0: - lmax_node = config[KEY.LMAX_NODE] - - act_radial = _const.ACTIVATION[config[KEY.ACTIVATION_RADIAL]] - self_connection_pair_list = init_self_connection(config) - - irreps_manual = None - if config[KEY.IRREPS_MANUAL] is not False: - irreps_manual = config[KEY.IRREPS_MANUAL] - try: - irreps_manual = [Irreps(irr) for irr in irreps_manual] - assert len(irreps_manual) == num_convolution_layer + 1 - except Exception: - raise RuntimeError('invalid irreps_manual input given') - - conv_denominator = config[KEY.CONV_DENOMINATOR] - if not isinstance(conv_denominator, list): - conv_denominator = [conv_denominator] * num_convolution_layer - train_conv_denominator = config[KEY.TRAIN_DENOMINTAOR] - - edge_embedding = init_edge_embedding(config) - irreps_filter = edge_embedding.spherical.irreps_out - radial_basis_num = edge_embedding.basis_function.num_basis - layers.update({'edge_embedding': edge_embedding}) - - one_hot_irreps = Irreps(f'{num_species}x0e') - irreps_x = ( - Irreps(f'{feature_multiplicity}x0e') - if irreps_manual is None - else irreps_manual[0] - ) - - layers.update( - { - 'onehot_idx_to_onehot': OnehotEmbedding( - num_classes=num_species, - data_key_x=KEY.NODE_FEATURE, - data_key_out=KEY.NODE_FEATURE, - data_key_save=KEY.ATOM_TYPE, # atomic numbers - data_key_additional=KEY.NODE_ATTR, # one-hot embeddings - ), - 'onehot_to_feature_x': IrrepsLinear( - irreps_in=one_hot_irreps, - irreps_out=irreps_x, - data_key_in=KEY.NODE_FEATURE, - biases=use_bias_in_linear, - ), - } - ) - - weight_nn_hidden = config[KEY.CONVOLUTION_WEIGHT_NN_HIDDEN_NEURONS] - weight_nn_layers = [radial_basis_num] + weight_nn_hidden - - param_interaction_block = { - 'irreps_filter': irreps_filter, - 'weight_nn_layers': weight_nn_layers, - 'train_conv_denominator': train_conv_denominator, - 'act_radial': act_radial, - 'bias_in_linear': use_bias_in_linear, - 'num_species': num_species, - 'parallel': parallel, - } - - interaction_builder = None - - if interaction_type in ['nequip']: - act_scalar = {} - act_gate = {} - for k, v in config[KEY.ACTIVATION_SCARLAR].items(): - act_scalar[k] = _const.ACTIVATION_DICT[k][v] - for k, v in config[KEY.ACTIVATION_GATE].items(): - act_gate[k] = _const.ACTIVATION_DICT[k][v] - param_interaction_block.update( - { - 'act_scalar': act_scalar, - 'act_gate': act_gate, - } - ) - - if interaction_type == 'nequip': - interaction_builder = NequIP_interaction_block - else: - raise ValueError(f'Unknown interaction type: {interaction_type}') - - for t in range(num_convolution_layer): - param_interaction_block.update( - { - 'irreps_x': irreps_x, - 't': t, - 'conv_denominator': conv_denominator[t], - 'self_connection_pair': self_connection_pair_list[t], - } - ) - if interaction_type == 'nequip': - parity_mode = 'full' - fix_multiplicity = False - if t == num_convolution_layer - 1: - lmax_node = 0 - parity_mode = 'even' - # TODO: irreps_manual is applicable to both irreps_out_tp and irreps_out - irreps_out = ( - util.infer_irreps_out( - irreps_x, # type: ignore - irreps_filter, - lmax_node, # type: ignore - parity_mode, - fix_multiplicity=feature_multiplicity, - ) - if irreps_manual is None - else irreps_manual[t + 1] - ) - irreps_out_tp = util.infer_irreps_out( - irreps_x, # type: ignore - irreps_filter, - irreps_out.lmax, # type: ignore - parity_mode, - fix_multiplicity, - ) - else: - raise ValueError(f'Unknown interaction type: {interaction_type}') - param_interaction_block.update( - { - 'irreps_out_tp': irreps_out_tp, - 'irreps_out': irreps_out, - } - ) - layers.update(interaction_builder(**param_interaction_block)) - irreps_x = irreps_out - - layers.update(init_feature_reduce(config, irreps_x)) - - layers.update( - { - 'rescale_atomic_energy': init_shift_scale(config), - 'reduce_total_enegy': AtomReduce( - data_key_in=KEY.ATOMIC_ENERGY, - data_key_out=KEY.PRED_TOTAL_ENERGY, - ), - } - ) - - gradient_module = ForceStressOutputFromEdge() - grad_key = gradient_module.get_grad_key() - layers.update({'force_output': gradient_module}) - - common_args = { - 'cutoff': cutoff, - 'type_map': config[KEY.TYPE_MAP], - 'modal_map': config.get(KEY.MODAL_MAP, None), - 'eval_type_map': False if parallel else True, - 'eval_modal_map': False - if not config.get(KEY.USE_MODALITY, False) or parallel - else True, - 'data_key_grad': grad_key, - } - - if parallel: - layers_list = _to_parallel_model(layers, config) - return [ - AtomGraphSequential(patch_modules(layers, config), **common_args) - for layers in layers_list - ] - else: - return AtomGraphSequential(patch_modules(layers, config), **common_args) +import copy +import warnings +from collections import OrderedDict +from typing import List, Literal, Union, overload + +from e3nn.o3 import Irreps + +import sevenn._const as _const +import sevenn._keys as KEY +import sevenn.util as util + +from .nn.convolution import IrrepsConvolution +from .nn.edge_embedding import ( + BesselBasis, + EdgeEmbedding, + PolynomialCutoff, + SphericalEncoding, + XPLORCutoff, +) +from .nn.force_output import ForceStressOutputFromEdge +from .nn.interaction_blocks import NequIP_interaction_block +from .nn.linear import AtomReduce, FCN_e3nn, IrrepsLinear +from .nn.node_embedding import OnehotEmbedding +from .nn.scale import ModalWiseRescale, Rescale, SpeciesWiseRescale +from .nn.self_connection import ( + SelfConnectionIntro, + SelfConnectionLinearIntro, + SelfConnectionOutro, +) +from .nn.sequential import AtomGraphSequential + +# warning from PyTorch, about e3nn type annotations +warnings.filterwarnings( + 'ignore', + message=( + "The TorchScript type system doesn't " 'support instance-level annotations' + ), +) + + +def _insert_after(module_name_after, key_module_pair, layers): + idx = -1 + for i, (key, _) in enumerate(layers): + if key == module_name_after: + idx = i + break + if idx == -1: + return layers # do nothing if not found + layers.insert(idx + 1, key_module_pair) + return layers + + +def init_self_connection(config): + self_connection_type_list = config[KEY.SELF_CONNECTION_TYPE] + num_conv = config[KEY.NUM_CONVOLUTION] + if isinstance(self_connection_type_list, str): + self_connection_type_list = [self_connection_type_list] * num_conv + + io_pair_list = [] + for sc_type in self_connection_type_list: + if sc_type == 'none': + io_pair = None + elif sc_type == 'nequip': + io_pair = SelfConnectionIntro, SelfConnectionOutro + elif sc_type == 'linear': + io_pair = SelfConnectionLinearIntro, SelfConnectionOutro + else: + raise ValueError(f'Unknown self_connection_type found: {sc_type}') + io_pair_list.append(io_pair) + return io_pair_list + + +def init_edge_embedding(config): + _cutoff_param = {'cutoff_length': config[KEY.CUTOFF]} + rbf, env, sph = None, None, None + + rbf_dct = copy.deepcopy(config[KEY.RADIAL_BASIS]) + rbf_dct.update(_cutoff_param) + rbf_name = rbf_dct.pop(KEY.RADIAL_BASIS_NAME) + if rbf_name == 'bessel': + rbf = BesselBasis(**rbf_dct) + + envelop_dct = copy.deepcopy(config[KEY.CUTOFF_FUNCTION]) + envelop_dct.update(_cutoff_param) + envelop_name = envelop_dct.pop(KEY.CUTOFF_FUNCTION_NAME) + if envelop_name == 'poly_cut': + env = PolynomialCutoff(**envelop_dct) + elif envelop_name == 'XPLOR': + env = XPLORCutoff(**envelop_dct) + + lmax_edge = config[KEY.LMAX] + if config[KEY.LMAX_EDGE] > 0: + lmax_edge = config[KEY.LMAX_EDGE] + parity = -1 if config[KEY.IS_PARITY] else 1 + _normalize_sph = config[KEY._NORMALIZE_SPH] + sph = SphericalEncoding(lmax_edge, parity, normalize=_normalize_sph) + + return EdgeEmbedding(basis_module=rbf, cutoff_module=env, spherical_module=sph) + + +def init_feature_reduce(config, irreps_x): + # features per node to scalar per node + layers = OrderedDict() + if config[KEY.READOUT_AS_FCN] is False: + hidden_irreps = Irreps([(irreps_x.dim // 2, (0, 1))]) + layers.update( + { + 'reduce_input_to_hidden': IrrepsLinear( + irreps_x, + hidden_irreps, + data_key_in=KEY.NODE_FEATURE, + biases=config[KEY.USE_BIAS_IN_LINEAR], + ), + 'reduce_hidden_to_energy': IrrepsLinear( + hidden_irreps, + Irreps([(1, (0, 1))]), + data_key_in=KEY.NODE_FEATURE, + data_key_out=KEY.SCALED_ATOMIC_ENERGY, + biases=config[KEY.USE_BIAS_IN_LINEAR], + ), + } + ) + else: + act = _const.ACTIVATION[config[KEY.READOUT_FCN_ACTIVATION]] + hidden_neurons = config[KEY.READOUT_FCN_HIDDEN_NEURONS] + layers.update( + { + 'readout_FCN': FCN_e3nn( + dim_out=1, + hidden_neurons=hidden_neurons, + activation=act, + data_key_in=KEY.NODE_FEATURE, + data_key_out=KEY.SCALED_ATOMIC_ENERGY, + irreps_in=irreps_x, + ) + } + ) + return layers + + +def init_shift_scale(config): + # for mm, ex, shift: modal_idx -> shifts + shift_scale = [] + train_shift_scale = config[KEY.TRAIN_SHIFT_SCALE] + type_map = config[KEY.TYPE_MAP] + + # in case of modal, shift or scale has more dims [][] + # correct typing (I really want static python) + for s in (config[KEY.SHIFT], config[KEY.SCALE]): + if hasattr(s, 'tolist'): # numpy or torch + s = s.tolist() + if isinstance(s, dict): + s = {k: v.tolist() if hasattr(v, 'tolist') else v for k, v in s.items()} + if isinstance(s, list) and len(s) == 1: + s = s[0] + shift_scale.append(s) + shift, scale = shift_scale + + rescale_module = None + if config.get(KEY.USE_MODALITY, False): + rescale_module = ModalWiseRescale.from_mappers( # type: ignore + shift, + scale, + config[KEY.USE_MODAL_WISE_SHIFT], + config[KEY.USE_MODAL_WISE_SCALE], + type_map=type_map, + modal_map=config[KEY.MODAL_MAP], + train_shift_scale=train_shift_scale, + ) + elif all([isinstance(s, float) for s in shift_scale]): + rescale_module = Rescale(shift, scale, train_shift_scale=train_shift_scale) + elif any([isinstance(s, list) for s in shift_scale]): + rescale_module = SpeciesWiseRescale.from_mappers( # type: ignore + shift, scale, type_map=type_map, train_shift_scale=train_shift_scale + ) + else: + raise ValueError('shift, scale should be list of float or float') + + return rescale_module + + +def patch_modality(layers: OrderedDict, config): + """ + Postprocess 7net-model to multimodal model. + 1. prepend modality one-hot embedding layer + 2. patch modalities of IrrepsLinear layers + Modality aware shift scale is handled by init_shift_scale, not here + """ + cfg = config + if not cfg.get(KEY.USE_MODALITY, False): + return layers + + _layers = list(layers.items()) + _layers = _insert_after( + 'onehot_idx_to_onehot', + ( + 'one_hot_modality', + OnehotEmbedding( + num_classes=config[KEY.NUM_MODALITIES], + data_key_x=KEY.MODAL_TYPE, + data_key_out=KEY.MODAL_ATTR, + data_key_save=None, + data_key_additional=None, + ), + ), + _layers, + ) + layers = OrderedDict(_layers) + + num_modal = config[KEY.NUM_MODALITIES] + for k, module in layers.items(): + if not isinstance(module, IrrepsLinear): + continue + if ( + (cfg[KEY.USE_MODAL_NODE_EMBEDDING] and k.endswith('onehot_to_feature_x')) + or ( + cfg[KEY.USE_MODAL_SELF_INTER_INTRO] + and k.endswith('self_interaction_1') + ) + or ( + cfg[KEY.USE_MODAL_SELF_INTER_OUTRO] + and k.endswith('self_interaction_2') + ) + or (cfg[KEY.USE_MODAL_OUTPUT_BLOCK] and k == 'reduce_input_to_hidden') + ): + module.set_num_modalities(num_modal) + return layers + + +def patch_cue(layers: OrderedDict, config): + import sevenn.nn.cue_helper as cue_helper + + cue_cfg = copy.deepcopy(config.get(KEY.CUEQUIVARIANCE_CONFIG, {})) + + if not cue_cfg.pop('use', False): + return layers + + if not cue_helper.is_cue_available(): + warnings.warn( + ( + 'cuEquivariance is requested, but the package is not installed. ' + + 'Fallback to original code.' + ) + ) + return layers + + if not cue_helper.is_cue_cuda_available_model(config): + return layers + + group = 'O3' if config[KEY.IS_PARITY] else 'SO3' + cueq_module_params = dict(layout='mul_ir') + cueq_module_params.update(cue_cfg) + updates = {} + for k, module in layers.items(): + if isinstance(module, (IrrepsLinear, SelfConnectionLinearIntro)): + if k == 'reduce_hidden_to_energy': # TODO: has bug with 0 shape + continue + module_patched = cue_helper.patch_linear( + module, group, **cueq_module_params + ) + updates[k] = module_patched + elif isinstance(module, SelfConnectionIntro): + module_patched = cue_helper.patch_fully_connected( + module, group, **cueq_module_params + ) + updates[k] = module_patched + elif isinstance(module, IrrepsConvolution): + module_patched = cue_helper.patch_convolution( + module, group, **cueq_module_params + ) + updates[k] = module_patched + + layers.update(updates) + return layers + + +def patch_modules(layers: OrderedDict, config): + layers = patch_modality(layers, config) + layers = patch_cue(layers, config) + return layers + + +def _to_parallel_model(layers: OrderedDict, config): + num_classes = layers['onehot_idx_to_onehot'].num_classes + one_hot_irreps = Irreps(f'{num_classes}x0e') + irreps_node_zero = layers['onehot_to_feature_x'].irreps_out + + _layers = list(layers.items()) + layers_list = [] + + num_convolution_layer = config[KEY.NUM_CONVOLUTION] + + def slice_until_this(module_name, layers): + idx = -1 + for i, (key, _) in enumerate(layers): + if key == module_name: + idx = i + break + first_to = layers[: idx + 1] + remain = layers[idx + 1 :] + return first_to, remain + + _layers = _insert_after( + 'onehot_to_feature_x', + ( + 'one_hot_ghost', + OnehotEmbedding( + data_key_x=KEY.NODE_FEATURE_GHOST, + num_classes=num_classes, + data_key_save=None, + data_key_additional=None, + ), + ), + _layers, + ) + _layers = _insert_after( + 'one_hot_ghost', + ( + 'ghost_onehot_to_feature_x', + IrrepsLinear( + irreps_in=one_hot_irreps, + irreps_out=irreps_node_zero, + data_key_in=KEY.NODE_FEATURE_GHOST, + biases=config[KEY.USE_BIAS_IN_LINEAR], + ), + ), + _layers, + ) + _layers = _insert_after( + '0_self_interaction_1', + ( + 'ghost_0_self_interaction_1', + IrrepsLinear( + irreps_node_zero, + irreps_node_zero, + data_key_in=KEY.NODE_FEATURE_GHOST, + biases=config[KEY.USE_BIAS_IN_LINEAR], + ), + ), + _layers, + ) + # assign modules (before first communications) + # initialize edge related to retain position gradients + for i in range(1, num_convolution_layer): + sliced, _layers = slice_until_this(f'{i}_self_interaction_1', _layers) + layers_list.append(OrderedDict(sliced)) + _layers.insert(0, ('edge_embedding', init_edge_embedding(config))) + + layers_list.append(OrderedDict(_layers)) + del layers_list[-1]['force_output'] # done in LAMMPS + return layers_list + + +@overload +def build_E3_equivariant_model( + config: dict, parallel: Literal[False] = False +) -> AtomGraphSequential: # noqa + ... + + +@overload +def build_E3_equivariant_model( + config: dict, parallel: Literal[True] +) -> List[AtomGraphSequential]: # noqa + ... + + +def build_E3_equivariant_model( + config: dict, parallel: bool = False +) -> Union[AtomGraphSequential, List[AtomGraphSequential]]: + """ + output shapes (w/o batch) + + PRED_TOTAL_ENERGY: (), + ATOMIC_ENERGY: (natoms, 1), # intended + PRED_FORCE: (natoms, 3), + PRED_STRESS: (6,), + + for data w/o cell volume, pred_stress has garbage values + """ + layers = OrderedDict() + + cutoff = config[KEY.CUTOFF] + num_species = config[KEY.NUM_SPECIES] + feature_multiplicity = config[KEY.NODE_FEATURE_MULTIPLICITY] + num_convolution_layer = config[KEY.NUM_CONVOLUTION] + interaction_type = config[KEY.INTERACTION_TYPE] + use_bias_in_linear = config[KEY.USE_BIAS_IN_LINEAR] + + lmax_node = config[KEY.LMAX] # ignore second (lmax_edge) + # if config[KEY.LMAX_EDGE] > 0: # not yet used + # _ = config[KEY.LMAX_EDGE] + if config[KEY.LMAX_NODE] > 0: + lmax_node = config[KEY.LMAX_NODE] + + act_radial = _const.ACTIVATION[config[KEY.ACTIVATION_RADIAL]] + self_connection_pair_list = init_self_connection(config) + + irreps_manual = None + if config[KEY.IRREPS_MANUAL] is not False: + irreps_manual = config[KEY.IRREPS_MANUAL] + try: + irreps_manual = [Irreps(irr) for irr in irreps_manual] + assert len(irreps_manual) == num_convolution_layer + 1 + except Exception: + raise RuntimeError('invalid irreps_manual input given') + + conv_denominator = config[KEY.CONV_DENOMINATOR] + if not isinstance(conv_denominator, list): + conv_denominator = [conv_denominator] * num_convolution_layer + train_conv_denominator = config[KEY.TRAIN_DENOMINTAOR] + + edge_embedding = init_edge_embedding(config) + irreps_filter = edge_embedding.spherical.irreps_out + radial_basis_num = edge_embedding.basis_function.num_basis + layers.update({'edge_embedding': edge_embedding}) + + one_hot_irreps = Irreps(f'{num_species}x0e') + irreps_x = ( + Irreps(f'{feature_multiplicity}x0e') + if irreps_manual is None + else irreps_manual[0] + ) + + layers.update( + { + 'onehot_idx_to_onehot': OnehotEmbedding( + num_classes=num_species, + data_key_x=KEY.NODE_FEATURE, + data_key_out=KEY.NODE_FEATURE, + data_key_save=KEY.ATOM_TYPE, # atomic numbers + data_key_additional=KEY.NODE_ATTR, # one-hot embeddings + ), + 'onehot_to_feature_x': IrrepsLinear( + irreps_in=one_hot_irreps, + irreps_out=irreps_x, + data_key_in=KEY.NODE_FEATURE, + biases=use_bias_in_linear, + ), + } + ) + + weight_nn_hidden = config[KEY.CONVOLUTION_WEIGHT_NN_HIDDEN_NEURONS] + weight_nn_layers = [radial_basis_num] + weight_nn_hidden + + param_interaction_block = { + 'irreps_filter': irreps_filter, + 'weight_nn_layers': weight_nn_layers, + 'train_conv_denominator': train_conv_denominator, + 'act_radial': act_radial, + 'bias_in_linear': use_bias_in_linear, + 'num_species': num_species, + 'parallel': parallel, + } + + interaction_builder = None + + if interaction_type in ['nequip']: + act_scalar = {} + act_gate = {} + for k, v in config[KEY.ACTIVATION_SCARLAR].items(): + act_scalar[k] = _const.ACTIVATION_DICT[k][v] + for k, v in config[KEY.ACTIVATION_GATE].items(): + act_gate[k] = _const.ACTIVATION_DICT[k][v] + param_interaction_block.update( + { + 'act_scalar': act_scalar, + 'act_gate': act_gate, + } + ) + + if interaction_type == 'nequip': + interaction_builder = NequIP_interaction_block + else: + raise ValueError(f'Unknown interaction type: {interaction_type}') + + for t in range(num_convolution_layer): + param_interaction_block.update( + { + 'irreps_x': irreps_x, + 't': t, + 'conv_denominator': conv_denominator[t], + 'self_connection_pair': self_connection_pair_list[t], + } + ) + if interaction_type == 'nequip': + parity_mode = 'full' + fix_multiplicity = False + if t == num_convolution_layer - 1: + lmax_node = 0 + parity_mode = 'even' + # TODO: irreps_manual is applicable to both irreps_out_tp and irreps_out + irreps_out = ( + util.infer_irreps_out( + irreps_x, # type: ignore + irreps_filter, + lmax_node, # type: ignore + parity_mode, + fix_multiplicity=feature_multiplicity, + ) + if irreps_manual is None + else irreps_manual[t + 1] + ) + irreps_out_tp = util.infer_irreps_out( + irreps_x, # type: ignore + irreps_filter, + irreps_out.lmax, # type: ignore + parity_mode, + fix_multiplicity, + ) + else: + raise ValueError(f'Unknown interaction type: {interaction_type}') + param_interaction_block.update( + { + 'irreps_out_tp': irreps_out_tp, + 'irreps_out': irreps_out, + } + ) + layers.update(interaction_builder(**param_interaction_block)) + irreps_x = irreps_out + + layers.update(init_feature_reduce(config, irreps_x)) + + layers.update( + { + 'rescale_atomic_energy': init_shift_scale(config), + 'reduce_total_enegy': AtomReduce( + data_key_in=KEY.ATOMIC_ENERGY, + data_key_out=KEY.PRED_TOTAL_ENERGY, + ), + } + ) + + gradient_module = ForceStressOutputFromEdge() + grad_key = gradient_module.get_grad_key() + layers.update({'force_output': gradient_module}) + + common_args = { + 'cutoff': cutoff, + 'type_map': config[KEY.TYPE_MAP], + 'modal_map': config.get(KEY.MODAL_MAP, None), + 'eval_type_map': False if parallel else True, + 'eval_modal_map': False + if not config.get(KEY.USE_MODALITY, False) or parallel + else True, + 'data_key_grad': grad_key, + } + + if parallel: + layers_list = _to_parallel_model(layers, config) + return [ + AtomGraphSequential(patch_modules(layers, config), **common_args) + for layers in layers_list + ] + else: + return AtomGraphSequential(patch_modules(layers, config), **common_args) diff --git a/mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/__init__.cpython-310.pyc b/mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 3dd949272dba255d4cb94e44d5c594209a25eaf5..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 179 zcmd1j<>g`kg0)*$W`O9&AOaaM0yz#qT+9L_QW%06G#UL?G8BP?5yY=B{fzwFRQ=N8 z)FS<~lp=k%%)G>$kksN5{esGpjQqTK=imVS+{ENm-K5mKQpb$HCqQs7Dr=k)Yya(;bz{tyZ zW#Sc>I2VR~^4)i5->>y}JVJnv&v*F-{1u#I#NbRp>=h`EIBwDEkMRl%+>j>X1eo$D z=!$a6<5&8ES24#q%93x$LvxRxd}4$fj4E-s`A4uR$PC0zKsnSv(OvY6_Sn%K6xrhm zff7wtaMJY+5bNmY(z-{;d diff --git a/mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/convolution.cpython-310.pyc b/mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/convolution.cpython-310.pyc deleted file mode 100644 index 56d12b0ab1a2b560af5ce37e9dd76a50a1e3e737..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 4258 zcma)9TW{RP73K^t%e}8wk}V~6yy-PcU8#)&qz|rJ*ODElw&F(0ZIOVbc4k*pxa4}s zbu57k6t#=AKwY3fe}D}1B`Eq*pns!3U|#!_w<5i#jot5%yOL6;K)E}IXAb9bcrM== z#`Su|!1deTf4KXrwqg8{Iwvm&o$sP0zXK75V4*Q&UB-y#A@B11ZVs)kmCMY~9y(no z=dG|bbi3}b+%4yQJFE<=-Rg91ZCLL%a^4A>!&bKiz9igmZrJX&bG{tT4;Q)%;46~v zEsAPRYef5wsEfvZt9wdTMe`oJW{8%UyKjg&Y3-f9XWnDoCBe6hmG(!_+gM@Rx)CJl zCYZUa<1pO65=W8jrBZClbcJi@x>9nK==!#dl33kRu^9JKjIQhbb681sa-)!UIR&ujAdM~d*;MUS!Ud01D5iP@0kJ~_z7QzP1I5{8f~Wkq44%}_?#3SmiYC%Znz+i9$NyV^dGDo#{`7HrbUX43Q1R4T7nz*})t zl^)iJ)M3>(g#ML zi8;}}U+=cWyjZwzbmzpPY>QK}-M7SPv4m8Z|G=<}?gCu9tmle&MwHLq{NO^RgX@GX zI&r`ATt66eURd2(Igcm9;c6#{IyWxgym@OYA3Z2u@lW5mx~<(HPZ95~u1~ps5W>T{ zGVQ{9`po+>80@BA=rBSng()}fs_1NM+*CBYu>%rI9iTC3WDRIMm$tKpa77sqT6w@HjdY#>Cu>4l#jo}NfZ{Dc zYGthfpeD0LJ-5W7F!1D5PEQN6!BV;m`c(Qf=$WjR{3bIF&yp=Ow)YHV=Y)Y6Z4egM zmApL;0Epf|`QjLlv^l;r-tG}T%`Sq&652dk@*Kzj`@Y+H$DXhWpO_O1KnIO~YGmfJ z0WD#RU-H{10Xd%JWW~{Z^&0>l1m-u3+Y5A~z_ke9fo2qwA2Q8G#k4cnNLMZr;!z>E zv`M=YVmd~Rw;nr>4_2SZv{eu52dU?+UJT=&A10T;%og~}JJI`bw5{2y84Bz&S%+_a zL=OtJq0Zgow@|rG?#0j_?g;ok|fy2ioz|4)}+2ALX-yDlr}bBHWotG>;-(zJ(+l zO3YJ>v_ti382f2adDWLlViy=tq1!b5Tf~1R^J^rTCB6C@%`3}jJd_H_t_ab!feY_pnkCpWDM-EXbv?YRGs@o@^(tGcHUk&iKQw5@ECaInnrOQW; zb$f<@NAM$*O6>5vGJO3Pw3=00{z zhZR4IQ{`R24|4o!tO1_?i2sW}cHIW6FwUJvj?==OqfXC$?kBwLqUSzx?BlgCpC3H+ zzZ2q#M5GuY4U;!OaC*WCcg$2dWjJ-QALGomhm#g0ZAbv#xx^G!Kx3_6vtM9@N4!0i zIa3*q0(er&LC@Yi=v+Cj$(`Iw9Y0jk7yBJ~7buY&Jo9EzL!;<(pKlkS;3`55U3tR~ z$8tldSgCCY6lt%1K;%s#ba+&R5i3k}iKo+5L2{qGh3*I~p?X+>)u}D;1>6tLK54y^ zacNQIr4u5Kv{8>BXm$}!Q?Tde`~#SoIg|1$K)9*#a9$d-qz%c`oQDhu4l^$s4}s}9 z4=kU+e0!EUpW25G4zVmP?YYaw0M!pqXnYwbC>m!S^(^#kR2|fU416xH;LabK6DO@g zih52*>C8N=VTF3W7VYvw?9!O-(u5v3GUpl;JDC@;7SI`PJa^KKLDGrh6lbTSq!kBe zkG=RgkMmP^Uhh^~xg!hbU5BmIH6nRR=IChRAg#^sZ@jJBo3E~Kcvm+rZ@>Q9hIj4d zSGTs+D>V3JBKcVrSaGL-&{~>*$#}^q3{vfE-gx!Oupru{>S#yxU_xoIWXeo``#ZkTNz?15&eg(geD^v3_mCyS%;qnr`Im#?9|- ztgm0+yr$i}e|`O}9QzeTr#4ZmwB4s-qph7d4)K+fsvMRS<$CVO%E)EmoL$wC!sP>GhBpA4Iua>>Fe*oD(b&UW3 diff --git a/mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/cue_helper.cpython-310.pyc b/mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/cue_helper.cpython-310.pyc deleted file mode 100644 index ebbfa3761ce71d34b3b6e186c42b690ae59a8b1f..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 5559 zcmaJ_TW=f372cU$E|(Wk7hfXTj>9BP6BdotSU{V)uAA7gQniL`Kz5R{NjEFbP+Dra zOU*9rh@grjGSI%bXr6-r5>OQILm&IvA2P3f@-HY5Aa=hqOG=`gcFCPRGjnEV=3KsW zW|OI@vW4p(|9rf8<*a4>n<~dI50!Utr*AXM5|&_*)y$u)$x!DJ@7hhf>olEip;=&g zpW7`qizC`=mV_;wsNAhID`>MtA*yz#np1{$qgr>mIo+LU&M<3m#>_C=oHNf;&C{aT zJoCOKJW+aRiISY_oZYjVFMMdl6~~hEJ9n5xKT&R;`JENB2bQRa>O-r!Af`kO&vRm0 z%;0%m%!)ZYFUWJ^lq_~$6sM&l&Ma7MN1PQeJaqP0^CJ4@WkF8KnKoarj&UrAbB5z3 z!%;cbb6#99J$2JFeXQq2anba=EZCaWc5U}5pG5r(J!q`vjkm&B2CBjJsWln3!4k_Baje9dO7y~HW6o^4?zU+9 z<5vllcW|c)g|tYOO)iW=A$nWd!h&#I;fTURs8$vZ`&`rSc8YCJ6oqFt=^ggkvV>73 zQO2k;MvdRZjbcd56)^>Y)O7Lok}sFySUcG6Tc*6MT_v-=id!738ec>gaNoe4&ZEF& z{GpXu9k$Ph){q_W_pCjVvc1|OA&xH^=;B_(onAn(VGXTM%DXJXpbo>$h4qMk#kV;I zv&USsNV8tUR^$?Dnu-||kT<{d6U{^Q*ixs^@a;#7n@LwL_EV`A+d?fahj9?ym1(xv z+sQVQ*uSxMXR#Z!uGD#-$;}Q;`3lhNuaMdbNnK6Ah&sYaCR)YN#^IQ z*~UO4x>rDqMYjv#!6>Z`Z*wCom}i@BK>&{#fJ4!gM>eSzs5+@l?fQO{`TpzZN=Zd! zd!7wy<3;72-|HbQ8gU)P7jUb6CZNgptWW0QDOif&0(v?HI}y$!{sG|DwAv?Iw}A}M zN+H*ZBA0#g440~bf?<{j?y@9_Qe9{zeLxnuS%a%r@yK^}Vn3b%NMPlg;7Ey^e_rEf z**tTe4ra#t8Z)gns@n7YIOs~>*JW6LBKi^0Ro~z42T|Um2o{Xt>q;vM(zGk#vO>Lt zW?hU$pj5Do>!akM-`84sdWQ7!!`3eVyUzh`v2y!@0GkLuLs@aj`u z8FO&raOatk!!qe$;v(BNRzj-GAX@br2n2(gf8*Xw|J@JXUAgt{^;KB!FnWrhOsGa-C^o79?#i@OIU;T3!AO9n?BoW^46VOo!qtbVvA=$@9MD=XbzqOzMxB^2D8Ky9Sj z-FiUJ$4ooze%#7;Ng~(92P)_RFR#+%1|V6Dd3a+3 z+m@|fG+@WSM(i5eDET*lFlKE##%1i)0l!ioyi5WDa3xBLdY04?Uh1v>&Fy}8KTu&1 zw$ac_=vQ)Dz&F(jg_x+6!Vdt2^ zn5{4Ptxeh5^1H!~5B>NQeS$26(cHxD49v~2D)ZP`#y@wyQv?(f=0lcB@-h;+2!b$; z%;$issSHtCAQT`=Lzoqgfe1Z^&H3wLkosZlM<%UV9$dMOhI$y+^M-mWh>gLlOOl1C zKM1o;NJMV`xZVz;Xgz3ck(?-txJ!|5LuSS()Gx6HIEvd!;&!-ka=+(CzR~K7;K)G; zE@X5IBU8$GtE>!Iceq+cMZ2L%%5-5vCH)?3)Ly&2q$om9m>i{BKqfstKy-xjEA@8U5%oW8U``Y@38@KWX zixF!Gpxqdnwct(Cld-MEFf$7Ee&gW5;Ay+QN@*&$p zj_UJ_B0IvUsX^f#XV=aONTumVAr?vZBX`%y+@U>m_BeuU5y_6_myqz3g#%ROgsKcP z7!O+3MH$JijI)LdgoGRELfcRJnXZO0gwZz#1HdRbnN0N@TzV7kt-3)J4&dHNtY`hC zmxbMMAbp8J{cdW6)2QgVDA+;%L~jb>EDSO!av_v*jnQ=F5lM4S8B_om6BMJ4Aw(&+sV* zr*#x%x5jJC<+vy`kACjA<&wu|;NfZmOcOdK2d0GC==5J8j&ur&A(ge^mUaeY-^Z~C z$u6`ys!{4YuDOJIR72Wi@DKbhT%1A6#i1kM%nxz0eZ+SQ8P2yjGFe#(@uIM2&NBds zUl}{XQOx{svc;*kV(8kV(yp78VJoh&l}e`yTbY8ba4~(tR%TWYRTai>9iuO+KVWxA zEjRH{AqqX)4kKh~BflhbQslT3kU{W4Ku1ntpmGQcULdRFt52vSHzKClXD3zUU5$e7 zx(KfQmBtdUXANG>^#-gSA9#=9!H6GVZg0t*iFFWc5uliL&_Tu=?>xj{95{xEz+0oE z-O(zrR|8fo&}{vv*<=9f3|{dSZnLg=Yw+6tUDx=;wXE)8%-u#w7a?6Ak(!9hIRW_@ z;*@cfN$*oljMH-zY(r$Jk!)u^g7+qzY221qXZ^G%TOqlv6bQ~js^4R1F6Ek{Fhj37 zl-|e*l^~{1sUQid_o$%MM!PAJEjaWlRqR37GoGpfRGafRHUoTBh&~Kcc~{2ZyrYmt zw=z9_day$F4xF}T ze$IE!v$8kCF0e^=&fVJmPq{jB>OW%qgj18D&~2Q$1LDZ34K)7`pf-KZ2)3OOU^BR^ z4WPC@BN%%>(P{*Lh` zPsnk#Hd2K9lowp)JoB76!qhKmmRHmpc$kDtyV*_;iE4iKBF9yv41_SGZkLiJMPFys zMJmiO%vjRs>z8tg)6HRB5jgAV6kor_YeFDX6oM6f(rVTyJv>U7O_I>0!^g2kA1BJz zrm=VR7LBAlF8}7UL=>G#OxmCO#WGcIQ(^pw;)NpnGP^ux%BQH!q!ixOJXN?xP?_FA z;aMI#&8pnPT}4j8L3zk7&)U_xhl8Q#;8*sF+&jnUfLJyMNE{X~7}LgK67(BmdgkeW E0T=Zs$ksq()#C;rv05BrY{o@=W!>0h43|>b+mf+&+1G)^R8ak8P)NQ z(KYL4*Q#6US?}b!cHLHG*2#CBdO?+qPO)36mzWmVzWIn<)qKm(J=XkOptsA9^hc~d zgF4&Kqb{G-)n`%T_yyDyMm2{}Q}j!yDUE98CgYX;8Pv~A#+#qC%=(AWa%j@B;IrFW zb?zy4r>#}Fw5|ye^b$FDyVnXt*zh{LfVI!$BS!~4JUf2oTZtMqmK6D|CyH|zR zYhCtIZ+){DNGoZGu$MMy!0F2vJtSX4A<#&9b?&pe?(@KC@=vuU^L5{NT&$bE8Cbp* zSWUyv`SxR@p8HfYw7QLj=4E-+zZ2ZV8b#a)l4S4`r-G|DsR*Mxm0p}w(zp_!XXSp- zNMn&yyvWZgIzeq(U>-5L(RqJA$4EFw)on5 z)sSY=?*&5It{X;S>bk1PngWRZlmKRiR(Vm{{2a< z6?cOgHX~|HU(~LIk=I!dlC;*_Oj~i}Ub=m=*7X{}iH#s?v}$LB-}6MeS-VZ!UJKIN zX#Y?FY0nM18-edbrz^cpK}eaW-B=I~v$(?sNA`b?mGRo8L5r1^2-3cY_9>=_w`pgW zXsACkO&e-$O_YZ$Wo^E!ZREgt2kYN?KIaaCh1Yg@W1nJ5{mWT9kYtep6j z8u4Cg56$iTj<)UKF6?NJU|~nJEh{ayOGAt1XzD|wUB*4*8=vW0xnXX5);D(;K2dJS zeCsSgiBfL+5Y}Yx@FUuSwq*~Ew75MtG`8na$A*|?S=-^;3-pBXjUiX6&l ze*XWS>+sMRjg^gkcsLnb?R;?=&H=m6v_(%ko*cY|TFCsewxn$plve&UEMRp<{K8Pj zC?>UG^`w_VR2;PxsSW#ac39RjEU0;w&T81na{K780DY7`%f7|2voL5Fae3|OF0Q9! z@K5nsVb;?0Q{@!=;C|Q$*6TS!iU)f_Gp@hA`}74~sdBM^r%ceB11AdKj>7hdb5VPc zM>#9P+;u0y*gz3V-cdZX8uMwkX&X~H)&TGAW@B2L62E#e3bT^7N;eh5svcZBD=GLKb93H(YR z^^I;<=KE1{uO9@1K<35;j(U6b^3_%M+U1X=6@*dJAJ=@idP!OxLgO82M6u|~g6esF zb?xf4e`wb2MVow`HPOZXC;VkRB%eT7Y>~|Y)A(_VS!{_J%*Jgn>$zjv zYyq`p#&<1lyqMDs&iA#=vRPKQ?$^fT=k+_9^t z#U#(n*nA>#AS-l%(Z`Dc)HFv2+!k- z_>fkiScEw74rcp)ev84qXl=Hj(fq&Q_xSD1Qi<%kUfPnp7oyBpPD&O%Vd4T_oib(T zY$r@oicQYBXi@xRX=Pa7;H@d0PsnNIY$rx!lbl0g{GvFFH>-}w)6m3pas)Me2bo8F z(ex;g5N}db;Yug=()A!pVj)eqhC2v+l)@9qqe$j&dIC0`2BKNH5t!tW4Qz9&wJU z_ZF}R^LNqkZQRKc1Ypj14qoaR=lliN`EzcRpS?aKs8<8`zKfJbS8^F*KkNZz9iTS? zvuP+$0Q>+vg7zTSwtbr<01xN`@OcGz2Y@fi*_&}^Gm5(i7%wT{4z7Ly>Q7}bZ>4E3 zd8bzM#KZ9ZN-XZwdVaHZ`sB$oD<{vKI(=pZkC{K95KO^9&h{{g3&-I0-1}aqA2i4Y z#ut4KdZs&BhnU7Zg;0ch&^<6eMBsHqgH&OT56eKUfwoZ%o=o+GUJJ-}RJ*0{u8;Bl zuXhFMubSJ#QMO9flePkg+dDZaotl(R@59N*7(#rHg!;4-&QkW!Lw!<)a7-{0ev81$ zA{#8ftn3Ng1sUx@sE{)vcX$bRLS0B1Z8)GUMxpr#EL1Xf0GB0-Q%waRRYz)~dZcCK6nG2RyL%vCVX z=RnDeFz`h(@ZG`E0X2-V#{~KOYT;GT6W>8o2D}O81bL({0o-@drHxyyMgh(?o@a0? zKYzU>sJ(%E-$gb_S8^U=8fugZ64XX81J9~-5UF502e3f`W|S8jbz8wK4;4GI@X^id zH*Sp(c1_`945U^Zcl{F^K*fp(JU?v(h$E7fFs=nrt>dMX8f7uZ3H0lEp;&39-7K-o z3^zZb+&J2YE5g$xyCbJLjVpyI7zr??waM@5ceP~xeeI`6dT(n_*cz;Tk7o?w#t~SA zFg$j%%0@OY?5hy)bJVLq#xcqM7uaB7N&{bJ_T=a$=@*$H%pOdj_<>BBQXE>hf&j|K z#I&`$U=pMgd(b~@7Gxpw z57SLqqIAUTgaa=PVTbwgW6&5=BCVCm#5rX5$xE>#S^AhmxXu9~ zVh%)<7Ay@ZJf3n$%@H+{wcy`G@@=3tKZ+f*6&t4>o{A;O0I29>k!WqeQ zv?j<4sSZM!8nJ|&PsWE>ZfkssC7sdxaX%Hb8>yq$*!t8ccPB%sSzp0Hzd?_LvPCCH z2s}!85iE}@S$sA)df+R387jxVG|UsD?UH@ly2l-!MT7@BSKh@*0G#aZ@Sd^S#iEKjUWk z^VjDE^`tHByU4caO0GbRF)=>N4TbPw6{=QJq}*E(QK;nUhuZ&Qd*rQZCUy4 z%1bGmEM^rMy-ZTK(c|o%X=XJCGPSWqBo%06id{0a5sKh-Aq_%JW+UbltCX|ZoBgPv za6L2Pu}MXKvVw8!bsrOn2P7Vnc!k~n5%n4KAV;ArCClq48s4*VI?6a0auM^>t*o3P zohi~~8xG5m=di8o1{ByozwX%d?x1|*I{;Gmg~a#laMxO(~WwY96_ zM^u?tl^536Z;h2k8uui7h8!U zp=QhlXQF5&K9l)Y9o=L}RrS-bv(gSzL0?rCv(w`h_inI>6D#$nlrBr-mPQ=GznC=9 zk4aDft`ag8Dk{EI4ni$yN|}=!K-y_1zPfQ+TP8K zH#CuRN~E+$#I=Xym;*=tnYnW6xtu5v-WzW=KdMShGUNB&JU_qp%{MdKZ#F#w?YH0e zm!DM$`2&^Xtpdu|(A63YoN$_ul=?JfKBK6!#7b@7HnNpeQpb0UY$w&!^<5)Zl3MEd z-pIe6HvFcMon#`N^d~719~ih2NoAU@^9XTn;sdz4dA;Y>K9D%>7BPq=kVwneSy?hx>*aA&*eLpj{t zCVm}CX>d<8dN!}~#5%?fmor zp})8QLre;CzXpZTP4=y%Yo0=fzlWj}Bi2<#2m}nZn<~=*w5=ea&l&F{4VLE_G5~~uEk3fHcSzRRm z3y<-OucSU^#>v#j?10^0rhbIgewny@7n7m`sHqOpQKGU1Y6ng4!|>WK%1h{iu$99H zAVI9g*kIP61!@;f7^_{=6*UABYWHx_b;TSEx)SFiRhlN+L7HY<;-W|_PumD+dqUev zBt3~&K$)@mR1;MH%rWG7AQF-Khl^1jOv)JLI~dL>o%O`%rICkj91n6rQ@Vb1#wlGp zJl_=Tu38X8NvKp1wC%EZgXVjfB~o7+7(N6E&pUM$2CwQ+k9pLA?QqVrnN1s@P0=al zQWr)BvVAaZPg4eV?_gu*bzA1d;vl`T-@xy(=*AtdHEOu!Tj@~cE!b;G zD_@RPt2fM|Jnm<$yx&?9d21m1ut=>Ek5-(udVPug=TP!-pf;08V_h4;7H6>fSHNBd{9GaY%gCy~!tn2I3KUVwM` z5U%t)VANbz|}S`Gw%-!u5MU+*v3&jt9K9;w+a~jJ8xRCGtgniUa)4 zA@XyaXpWc`i6jYvV=2#~i=i3*Ccs^xP;-f0wO=yE{`D;Ok7u*LS!Lp-d-;MR(XTzN zoqT-|XG`Y0(xl<$eHkrxvP?7RRWX}opp@IN)cUg~Yuz++!#qwpD{(IIfY*)kQt04i zuXM#c6(0Cgq`?~0{K3<)X6S!l C{b~dN diff --git a/mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/force_output.cpython-310.pyc b/mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/force_output.cpython-310.pyc deleted file mode 100644 index c4652dd5a76086849b92fefb3d68c50e0c24cd3e..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 5367 zcmb7IOK%(36`uEtPsy?@*?ll+o0MrHHF2A^XwcS`Y3E_fhNYy*pum9Uy_86kL*~v< zj)l?6DY9q{w41C7R00YVT6Ece&|lEy-FD^figXjjZTg)%Bt=FrQ&+e?mMBY z;%3k5`=O7x#q3_tFN6hM;}&!7>sK}AGVgtjdEDq0@0<7aurw|4S%8vYT2h{t6j%`@ z#c4@}=^I+T^bljy*6X?~wg&xSr^loIfN`YfE~SJ1RnZ)7Uv8$&&3i-sP=_i?{M%yC zWUXeBN+)TF&M<9JBeRb?P9$$3!nIImT4*pmG`ZEa5gY`!Z8C#<%tWxb*LFV8+B&nD z^L{Dx^C^Y_^LYW&DYPv{6R^VK2b!gYB^tUcTp5TKzcEUOqjYyQsNuOg5OL-G!NjUzyCNMdmSiC{AGw;Ufnk)4Es(Q9$%i7v=m#EsKSzIE$Dp zZt*nVF02ENmG%*>w4?7X&2~DgBrV!1=?I>VMBF0v$wNnwdZ%F|w;*VHy4td~rSBOk zt}ko*#sQ_x)a=4Kv1!=6uETU@Jg}I#tnFDDHgTkjXH#?E&8!1m)%0FLYp_FWee=-3 z@e#b4avJ$?TNazW-r#ODim@q7sB`11Ns9^wA z-jZg_w?q~Cbz7*FxI2j1( zi^3zrv5q{yiq!6z*`*!VtCz`8<7S`3P=knIVFYxDWHE}~9W{I75;uz2poMJnjmA6D zym@0o&fUD#xEyWX*t~QtYOFVIU42IeD(lLPTWbwWJ{8KY#4-(=&XD*nge;AXy2X3F zD4LniJgJaz2*MRIp@a;>(=5YsEKC5CF^vBirtx3HDnIfH=S`*Ag3^hHs?w9RA)W+F zfF=*YZ`**U0Pqw5o&shgI6PTaaKp=1nJzfusi2E-kuyoLoqQLP!X8Wz9BYi%*;6Ml8l&LB=!l=7Rfq*^Dg;golp zy_zzfqy_M;5n%TcrSuq5%%ED{hg_JVnp)+ZL2uOOEjl8ThvHLgdkKlp!?w%mzQ_XJW=4-^zq{SJsRL8{NiDV+howsO2OnEAj`SY-jTuqg_&o>!P< z&tMiD3Y0sleIGHGIeQ)uFA$e9FLUwp?qCUHKaYvm(Ly2f4)i@Yb5xsuSK5EdoXp)X zf~jq0AD9=kJzw?ZLz7~q%!h-c{l!(=_x?q&?=HNMnR2()OuxrQdeYQ;-Gw zWpa`Zay+l(s7zx!HKHm;fEhyuN*0tXDp^vpysS-HRqHcutzKD|ZUl=3GH;5n(*_cz z3-SZD9v%MsAJ=c+{Q1LA@jRrXQ#bNrh5LrmQUdv8no3M`s0qaBbt0b>>vPAPA!?-Y z6p5!vD9Had#lAy=XhQs$#7`ikGbRXYr^D|`XD3OUtvk}k1mXm^!o`nB%U0R-!ZEvu zXQ+v1Nl*vUZWAYkcA}`s$>Acjc6yyYd_M(SVQHsBh(yZ+6np$^5VGGyNZy7B#uOAN z3UVHKzK0yoptPfxbq~&kXE^#i-WE!Zf6mdoXO+GGVC#^$ZY#)d4cHLhKY-nwz^&Fc;EBIyvaJVsf9 z@Ns6w$3%4z1ur8cH2)c*BCs;wpf+3e1HU4euE~w#R&sSCIv*@IHKL5LcO%cp2nXW)+ZEMUht(GW^N*d=+Y|>*949&XdUP&s3<$?U^U{ z_a#V<7?k!yk(`5FNpWk+^wj8r9suS!gcE(43VuOUbgP1MZmG0G#FlDZx9h?4Ft^B= z_j5-bgytw`9Rx-s^?P0D&M0&g`v5R^-m8(&n#SSlJXUFcxGWJu=>O#Is=f_v5%K8keJ~Uw=bwk=L#YYF} zFuv1i@th%RRHV~YkBwTqZmi^YE4*_aUC9adr@5NRI{k;_&(<&G%t_T1^hAol&DL-vtxn4OIdv=6CZMRFJ0{}LO$TWdkNTl0I$g!c(jy2| z0He6sY4I zm!++|Zt2AN+e+{mKsvcEK0|WK%_ds=3qta1h$@hXf@9F#l5Urvb}uj%jYUP_^B{DP zyP4Cu1uakgi2a;T7Tqe!9XJ6?s9V;ZPmAV)UN)9Y?^Da%ed@`GGokpoCsMrh8kes& zqPH4ra;jg0FZxIsmA|7LS-TL@6#~Vk6TmINP1D>b4a7@UX zlb*WCt;8|*F_tBMF^*S8X{Wc+?WBT;OBVCbGb{A1Dv`E&mHc9W5{C@v4@>vt&0S_Mx@G~el*bwJAgv}-W|B}GLa{vGU diff --git a/mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/interaction_blocks.cpython-310.pyc b/mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/interaction_blocks.cpython-310.pyc deleted file mode 100644 index 5bc82ade64600ebf995b460a7ca00dad1dad9e8d..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 1851 zcmZux&2!sC6qi<#Ez7bke+U2ko)!n+P5cA^(h=a|EhAYIt7DyGu2B{H5ni~qPlNN4}Hg8>BsX9$&)GeQLN}tQ{!*%lDvvrLS>w+YdwkC?s6J7Nm?D zhx1uT`YJE&Q%Yv(=$wQ z%gsQR@ow>2VNHUN5qj<>bLO!bw~8#d+__1D<07-oCS%d0H}eC^Eq~0sbnJ&d<)*NYq1il2eH{29uO9_|>OsX&HqNW@ zJo3^R83!cgD&rRRneUyD1x(UCw~I!Dcm{2oGoJ#OBpf#dxW%;&|2D`~kn{&ofwF8d zC9}T(_S`{d>av}cKSayUvdh#=%gn5DrY(C}?<@4RvRunLnUb}$HH>a+Su3l-Q`#ML zMPVAmm8_msv3^@GGKS3PxC$9HLs=u!Gb=N)Qs&^&mzAvfjry5dtSuK+P_}lViYYQ# z_e?2PR;1r6JWF~l;E+Ay+O1beAMyH=&6^SG$C%b&Oy2V@Zbc9?n!xhvIetORO9FVn zxScPDHBckpx>x%ZDvl$q3MGot4N>`W0lOpz>=}UXvBaAO7#MQSo;(Ys;C(B^Y39UP z4ON2_Xp-+8w$zU<D!?6PF*!s8<{FXNENs97?t$G(#v^lRqKBs%Xxt6 zPh!p^r(9QcL|=k({T?r|#Vm-ADFT;jLg}D*nQX=JX0jzWmfLyfo8sF5_~w7w)!_$R z0$&hG?y&1z%*PBp*MWzt>Mnj{%DVDFh4gY zVd68cA0XqCZwYwx`vAbd0NOuD zLywV;9+eFAH!7irsEq!yO!OzJAb<)zKvncRJp3D~p z>>Z62iA5q0A<9d5i6ZSvd9jqdM9NEkO1)-F~{^68i{F55|5Jj+B}3cexA4pZ(=q!fcfxi^wbL=w{)ao!&!sR;Wy z7xdbtGVi}3qrvXgsEoGu212<-ER#VQ)2{Q+!lY4rgo-e_CO5i_n_Y`DVfP$vaeLe5 zj&Qk);t7xY!WZ6I(t^hWQNx*o?bav8w#n(F&2F8y_`=6Vx4{?r5`LR}S+w|yXkndI ze(GbpyD-G2IcVd#jk64dY{zL-6zzPwEiPnPyPDE2 z^LCEGs1)ryYacthetAQhINO%0C0cJ`lBr-)&4;C`bLc3%CH82bF>BD!%B%&&Mrjz8 zrBtnK*e69tX;LPlP-|&4+6&WYPe`m#lu=eD_?Fn8Q*#Y<3#E=y{05>l_KgGc&@hd$ zH8w&vhQAICxu{!KW{qw1T4M*&9-EMkmhPCL?u~s&pBs=hPNTutg{*TLZIn&O*4X9N zEnB`kHVziX?xBgDyy7xii?k~@Hg)OP8`rizC;fDmlv50$N9Bh&M^c8Ns#RJUKotR1 zp%bdNs6MBWm2jrn@OEB}$OhJwO$gOYwV#H!ZbxzlZd*0g@VhJ0NvlAwL$GKq5#_e@m_L+x8( zc+dE9YkfEGi}hh4(>&%aZ42C`e3iz&9m^z<~!^CC>GD(6j{8xenIjD%m>zRiAx+uqVN0#tKUDEJRRuO&@h(1+S)knD3v1Odu?$C)KqMBJoy;bSHXKxIux%Q=$P_J zs()uDo~BwG;zuY2*&%=iY$^{R9|YFu%(2tXnKWjYK79yDvjcw*4IIai-uz)pYi#}z zNBa}oH%}XPnMRMT(c?^y`u!V(+1VNAcH_^r<0xxqdD)J4qijdC=kRXbc|mkseMb2V zHkW5fxFob&k>hu)>hQCUlWKSZlP;kY6z#|rR^b4FIeOxl0v_S#HbFwUQj|lP9n%U$ zMk?A3{(_3u%K5Qbnmu-QqJy4w);O?p|$E|V>){u&W1JZv3+7q``DWHqt*itef|KCB!nP>Sv=`? zn(|v{DeI@NzptFjm$t6FDbG-cCq=OvfvPlJ*d^kdT%!SpQ(ZY_F5_Kg#ZjqTac2-^ zTv@5eI!*M z3ct31WXNT+@xZPJ_giQ!TejJvwl#Y4n7JRd<}ni!2xgA|8X9C`LV8iVj#OxZbdxg> z{|cxcCWA>~s)L!q4oz(?yD}po{%i$dfD+b>~Ht!jB1wR`MHpC=XM zZRZkv&wqt^+R2x-PS%)hj#iJIdNQ%IOo1nh`T~I`2q}sTkVHLdf_SH)_!b@9Q?`i}pt0<);6)7vpN@d^7^HejhCLT$!lN5}lJ_(VfYEGy> z7HNugkND~-nx{#N!ZwAng2=`@G%;LDfLw3o4)-=v~2wu^3&~DNuxtByNjL$eq3R^RFOhDx-AL)-&R~N?qFU zgZ?lrlR=sYPPr8kj-CSL4*UPo-)#K*)uWI~Y(#r0f4WFaJT!f%PL4~gU69r0av!)R zavusnRmNlM{KwHbO%;;Z21T7!yjGc!aIT$CUZdLUB#0%Z*ggvMdPq;whSYX_nx<1A z&QJcG<>nAq1EtsE(C7HTE2h3r5agCd=ZIA62vZ&4vA%9k{_Ozr42B+YyX;pB_-6hDVJj&Q!f zE#TS~4&AK;`~x`mfO8-BtD5l78ruzx_Zkm?cib*Toi{`scyIC+@QzX7y$QT8D*xJ* zjgW3cqZc($AAMVT=qjn_Zs^=id5E`^-mBJb!nw%8ObjI`K?PAfp!CMUjybkY8)LQu?2N1Q)te-*xC3n>ec$NRPErJ6Tk{02eMJaAO(QA4 zOPz#VWo^e(92Ak*RZ|AyE*cktZv_*0?b8FWs{)ldae8C9+xz~b6UAb^pb3++EwpQhp1@E*XiH+EK>}{Bm1FI z*}ZDz7pi=n+*mdBgkhPa=X*&hed=pe7r68E-vWhl^iP#>G^Zmv*6CbAxo_!QMTdFK q1o|uI)IHO=@-J3-%`0?i)3mZ`__*mcKzER+xfZSh@1s0>#{4g9*B2+B;&gSOCk#Z$EH(y&e#F zzWVz`?*~;v{(*z*j|YRFLM#3TL=Zt!(x)Nqvyk~9tO@n?ltWM@@{d(B2d?}sj&xW&<$g1#8>4!v=h5wugUwZpDPMuR4&WVa-a(TuV z$+I}!Xi_s%9QAkew78cI`@2#VAgXop{vb(Z)X#;4@r_41?|-P`L2oVA@z&8m8n5W6 zWS~2!{rXdz42s_YBS}aF37MdwBi;RyVEBTCj`V~p%c3NGSqA&PNnalK-QWw-rNS4L z=k>560tx4p0qm=ZI-IQf1#wAOgVQz4%;QYXUjC8h`GHi2D$$Y)Kp}b}kD_N*Rh$W4 z#0RqANj4a2uJh>%lyh8=5B2MFF<3nIgH`hpf0E^gT=x=CRPaKpINOta7c?gDD|i=< zf}V+WvYVudKH`Uo?(u<=2T49GIM`Q!Td!;aJBoqqBofvY5+F#rnirio1wpQ!*md$* zlW*VUEAZak;qN@;+uJL5klnruP)$24+~K#c0mFa7?~TFPFx6{v6HM_jwJRO~F!ieu z;$r5nhEpgym~EF2@(h{-tyl%3$tgLb=Y*0IdO{+0!hkv_jvzy#=IzMMU<^|C#06S9 zDFO8a*(4WqqgghkVmOe>_)(N(iH;(Lxo=9=n?z|TZ_QjpZffCVLDMsCA=9o^ zprCFLg~6}tNE$#csHsHbD2k#B^5VDcUf!4OVIft!D^&Y~1Tt_-7P>t+(!D&3-rxLH zyB~MtJG(OL^xF587{p2+wKwH~%r>NM7szC7{6S(y@-paKgClhV{_#y{#VnA(_ozpG zcs=IPfF8f`TAH*jk8fP`f$@~oLzQ(9?90am2HX44iVz4wzE76Ph-tdd_MI~q0pBm3 zvk^dk=A9Gl?W*!Bff*m%Le+OCGT+K7R`xdRL~Rw2(M_UeKz%b%bI_TR&Q+&pSeBWf zJIp%c6fxyIlM$r)#9`eukF}bI4Qdg|0%~y8P2}GKVjM_D;{rxPEu+v9M97pOkq%=e z?!chHzYORcKowlQEXzR02*U>O-97u14=vv3AGIHMsV(OES>BuyN=$x9qgiFwVRRMwl zR^SGoxXi6Him&^EG4>zkuz!0l`L~$t<>)mS zXmH*s9g&2LsGuWo@*3G$966evz)d75gaBq2Js@i7X<%LHf+L}}zYD`V~DZIyH)C`s4gvz~uzPwMDW0hECVjf=+n z-UICgMv-4~sSEK?gT`!JxK#ZJ|GP2Oj5noP0pTktgY^pQiLF-$z-aO%BFT;$nw~Lx zt5Lz8`-tc%J*Ok**CbnViFS^f6i}t)xT@iPJi*)2T_%4aqV!k5{7+PPBli#PAw49J zeBOTf%!jmMka|$z&MKfiI47VFMu2_!SXea2_HsZ1?d8@LaUoj>!>ebp%P)}`FZ@4P z-5%HdSE2|@2SyR5q{VR$qG-n((Jvvzm&u>mQ}`a4It1LyT6c{Dwp~z*U%+$AF+615E+1{4@W-$K8@@QSDT$$kWI33C{@5XAcP~V2d>N`j{Zeuz(HxMl$ zyYVoE9)WCn>c_YQgT&UXux9a{IKxn>zK2sb!!h`5a+{mqe8ykEg5nJzei`zd&Qkwn z!?{WQKeCsztU+C-eg>1xCF8-Lz;}>B$P~umxXRe``L?pQvdo?!oWC5lu9ev;#>eE7 z&?cU#%llc@8tNo%?I&7cjGNl{Dr(_ZP+>~8Q%#(yU(R&$3~C$}Hsk%$R!Lii6joiE vPU>ys(ClecT#EsJbzEB?qUQ<@WP$h$ld^^{aGtwB=bCKJqwHls8!!J0AK?+M diff --git a/mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/scale.cpython-310.pyc b/mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/scale.cpython-310.pyc deleted file mode 100644 index f8329937e1cb613c90e43560555f44c3fe4ee1a1..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 11977 zcmd5?+ix7#d7n8mdz-!RB8sGFIoY_j<4tVJj+&-*9VN1Eu_Mv0B`5GE>SnbwOD?tC zrOyl{Yq4zVP=)D-1WpSy4@m(E$cum$Xdm(yeAOch`RM8O!<)CdMC@ejaZ)Yg?9Y`F7jd zu&efls;UjA>TI}GcOz9z*(_7-bTuu%nT>2UD|t>kx8YU2L3w_oP%TQ_ZI?Ets#6=& z)oG+tffLLev;1_Rg4tlUsSwWuGk(_OH?w}u_inq@BZ24VckBz6U+{~!Ex#CK*XMSe z9lLroDJl6=D49x1=97|Xe+DHpNy$P^<&;0~&jzXBXfPivG##0PJDLU3M*^l7&8&`D z&9pz~AHAKb9`n^@tFrK8Qmc}RoztCLv3s`Fh+_9*D~#go3!71^+o`wX)XN=2Rw{N} zXmmF=TkW8>(e;ChilDw5QT@uT%^+4` z7-u8IY8&;jQ@bf@8P0)vO~;TtX^5m~$Tjvd+8?a+H3ySHP1!+yhh z!+t|G9bfs*?QEnXTHQrHs@_$xT7ELNdv_gOKw9&N z={y4#zjOJ-T6ZHju@wgTMAO$N&bL6rD?u2Y*t`|3bvv~)mtQ=wQEvoKTnjplwG$_` zzggGOtrM4n8$o9!h)#rzbWY%3R01WsJfG9X3b?65-*kPT;Pm@TPrkn-gSw7jG)af?={?lw8w$ z!)n?-*yMJx>iJ%fX(~VO7r-+4L8|DNkSc&NipVK8U7zgWR!eVKu2r4F=Fi5Nmjaa3RQ9ltE6f{CCL=+#^0$Cz2)!eOmPE!1n@DmS~j5tz!64#Kb^ot%$T z(iT5FQgPnLxvpIgZq-_yxR4~eTTwg{>3XYElj=38-AD^POpw8ki&uzx4Ux#|<7n?F z?IFfvtirN&@D>uW%~B2eaG&neEXZKZ`HC%zM9C-^&OOYeHUTg zw?GrMa@UTFI(T)fg(cNib=|KNVy_jpI$>1rG=ezQYrYwJ&?evofFgh2kpJMPS1}7Q( zJ_8wW|EN5ZQtuB?W#5=`UR@s12In+<7^jdWLBcYEzD<%GBVp`0hHAm&InUr1q_I~p zuL^3K$;rNd0?CRp^!P&-?VDZ(?1atrKaH}G$NsqOs@~L)D1*|=>dhYUx~t5L6P zLA})CLap}dR=qta$<%6ow}EWui|1a6=PsW)eevAc+UYAVTzc+I?cB<_moB^#7blYX zAq*KiVWeaCTDRMlv!pp04Ku;gIR;O&PH}MZ8bP}~Jbkj3Cs>6WtdAi$iZ^@|LDq7W zn{{nvyYDEa-cycxSGnp}Dy4p@(xrDwb0ufVkv0beBo8N_@TUk46Oa&+83;)iLNXiV z0`Hg=cw$Ys`_ic(58;`I@XW)?C^QQqKxt(Zt3`joF9fAvDwqzY*&0?r(e#F%Xyg^%9HJrzIx-c3Amc^`p3J2VtN50 zX)&I?ycslFLHOlX7!1xs?_Yj*68q&4QamiTBC7vZv%K1ZIxUlbA1?>5ZBp<@unCZ6 zX}1%UyUlX67HqJ-NJE*+bI81D>V@UH4hH?OIrQmjw8rY4Zp5bL5GVNYGQ>Z|4z%uS zqdN|heDQRmG+G_M)d)h7g$=5Q;X`Fjo*|otFhrp$pb{Xnpb}u0L=IHNjKo~VjM9J# zVg7^2R}u3c9OT662J#8Q1)RjIz(v9CKwr;TMb>`{N#)Y#5+N=d5=ratmq)JIDy zDl-@TG8I%mGY4y_Z*MbS=A!F;D5ycc+E!6!Jsairygrp#+IK{m<&#`BmRk^2RuolM z>ZhR4rbMB^!dtq?tsCq?xYpfj`{iq93wjSd+r`Eh6HwGXW#Dk5>aywjqabQ!TK^t? zj8ym-V^VFxV&ClT;6H<8hl?_O@vUbM50x|fB_nRc?uQ=ZMgtXfoykgbS8 zGcrs&Y2^1FILKy+np`mkt0>EH-M|CKgU4~-=8E?#%WZ7Ao{8K(esO3VqzNNb?+0?? zX?DFNcKu)zT88t;V!~R>W%AVk3~*zFJHa}pQBtB|1M4vGlU`*{A1~irYc;@m9Y1)j z466plZn+L4uGOf4lGg&gTz+!-N%jj{qP5xye42Hpis9B~cdIk%^*^yhc`7Nhqx%#u z&0UzwqV;LFnM#wYPPEE(6uwQiolVh5aqZsmL&!fe1@1?5{Y8#=7&*;J{+*GN-F*;p zKJ{V9xr)i-`Jc~$oQK&l2goUv_K}nNc?++;gy8;Ug7T35gDCT#9q+Sn9{0xx&1>+UcfIS_~+i`)U!w1G+t%uMYnfKJ$el{YxMargGuM^A1XwBMeRnUj3n6G*uX<5-fPDaF z`!0LkvBR%Xbobq_s%;D6G-qbtvii>V?Qaq2=6ZjfaK?FPI={66n;a-TEaSKXXKxbg z==#k88wwgJ-=Ih%B!SO`aHM zZ1xu0v$Hv?ck&RvYwGR>y$^sm;b9CFf?SIz6>?>9LO`w^5be4eBN^sL zO?Fr5Y0mov#^J*0)VP;5?(Kdueg{>AjEh#qlg3(f;{~=SCG-~%0RDG%6!=DErb*R~ zA3MO%`W%Xfl2lk*2=q7`<;P&L&=6wvPuV5sWe7=v$A$5|`Z{t#DoQt>g;bp<{Iegn zJ;gY4mJrU`y^sDH8fzky2@pg@M}Z*<_E|)RAczW%KuF*al-n_2L?1%njuL1D#W%_s zz@*dwCOIoVzK2J@g6=A5{TzNqs3~@Dw1S&Pe+YIOXcf*Z&I~QVeWH*6<-2%8=6hJF zYrFTRmG_?O_LdG?v%y?)P%Df2kI_m$!{DTKjUdhuHPx=X^5Qw2W+u5& zHdd1;n&dG?5{acrwb*HOq7jmc3veXBg#tIhTG!V;t0mw{pJj5iD6w#aYk5*Qy^zxr zNabN|1XaH%!N4g^fGYKl%9MWb@M90d=23RZR`*!Pa|-V+Js@YUflekKj=?AFfr;)1 zv;elUknZ?vW&m_@2IdIJ8O0}oR^HDA1^Ax|ejcD}RwOrFxd2_cQm*7b0RPk^@S2u% zJ^^Ag!HgLPz$rfhR)lE-00rRnJL$`~d5;8hSli4Hw8^MPWi&dA-TRgd*Eux;c>c7s z7D(AF!125wU~k@E3>M&)UI5rT<}U&4EehB>20e8=p1#!e>+Nx<_s^oaE<=C+*3P*L z0__g!nxF+G20h53pSVd-Chk7SJ(C*XGorV#I@h)$u@t2a3{}yP_4dvBE#SxwykMKK zQeja7TvCLZYOYJoHzTMEW(e?BuAdfXBC%J# zpFvpYXAu^Cx*$t}#+(&FQ&%ceV{qybjQ&B;(lpW11O;Z8GeIleWoMfR;!>>@)~?l~ z##)W^gM#`b%MQa=8!YHBNSX`94ntUfi)tbF%qzG=Q)3t^ZJR$}sbiP+o#K5+D$bc{ z!^SWkNbbY%VSKnZZ#R+H531gT8LH=z>xT1DY8NR&E9AoHV6J|uuk=%>^VA!OLAq;+ z0cwc|5~H-Q_R@C-xolZU_ufv7#1-ITvfB`GJv~HRLr+Okb7=G>RmsHRAB3D{@~#pi#$3_F@H$1j9hT%k@qQA=E3Iow#qPJ*m{wuR;2ztv0jzxm-eyy3wQ$eLF zv*Ia+d2e0K!Q{7TgTUy7O86_ec^YX@+&R&hi&H`ZsAB)MiaiEO*%>a2AZN^F(Z3-r z7TwYZoov#yk}>`6L&U;SBu|`bMZrdh83`YXr57RJ7=9nMqAnRM~a%AxJki-EMEHGlCZRGxh75@yu{g`Oq*yHC?y4!ay zsUr;A5;W*PXTTA0Mw9py=EN22 zq0M)aHtUS}u=-@Yb<;yXhaT3^ZhcIVcvLNDdn;-C`b68;OOG9SC#}y&>n!U32zft}HKn*F(6iXbjIMrSyM#KWh<*MLeU=96 z{~_v@Wc{Vitkj@y=b!o|wXa~=#qF}{{BE&{w-VQ&`VLq z`8i`kGmiRmj9T0Qb+4=N9dWMbg=)CVhVQ0Gs_ocs*-KbV`Or6*Dl z^~>Vz$zoPTnbmwhJMjDzjOPd5R>#kNPxaG7|4;oCfK4KZ{$II(;NUerAo}UQB3|FC zVt==yaOk$jt$Sk)uBd?*AKu|w1J0{xdH)9`$r2x@M|W!j&U`V;gLVLqS7gk2dhL^n zFgyZH*boZx7idJE5WjQe#Te!cIMbv2p8R}f1krw^Olv)n@Jc2rAHP^CEL z)#E9b3$56t8DhRPNL+ofON}cIKV4_dYYf`#aB2*6C0!dAo*01i_vzaV8VIV{j=^-H zxccI>z}%bstfr2y!yH)k`Y%{n&VrCx4y%v6|Hd2sC4%g%>t^BW_ zbzJo8zBlU>RMy5tKU?#%-*$1;pZ;a;m0a2iy#to{y4oiRF zcZ6!B-M{U0fY(t2>1{`1E@Hst>nZ&(!ZeV&OE*r3dFeh&fz;PC`Xs{aL@Dq*O4HIy z4*jS1yu_oK8pa2{%)Ies>Hrh06m*CmaT&&`XHQ=_^98M$@hSt^c|qS|dWn%G--xWT zR9u`H0}aH*nG)z;EpCW|W8m|s7RK@a*s+qcAc35Pr_tpVa@(Vb@N#n+dx+10oM~Q% zzUyYpZvl62?md*yKe~8ue~fn&U;3<6mOhAc)VuV;+0z$|ceKS-QFy$+cT}&l3N^{d zMXIkeN#XN>-J|5Q@8b<=K#$y`bdd7T{i!Do(D5%%=p-|d2AILePaIdWHEuudhZuyeM!W|?d zh_)YGG$(EWBS%Nq{#2iX`zN!4aJhUQS5gDp!8jcnbyy109SjiMIpced!N(!PE9(%K z&H=vv0M?d^Ww{9}53ifqp`7Wl?gqXU>6KEHpV8`DyaP8=GsS}Bh zyP3Bb7tka!eUc&LMfy30gdhLnu}v|Wb-3K=$R`+aF*w=jEN?}v_VRiw(v+w$8RTo(Wqy(w#wqzR z600~Y5zErtU&+mcC_dp+@nfVfnd8l`yyP~)(BWAoWy+(qClU!)eb^l^17X!n8G-=; zb{BL^4*ODRCU1T0Qy=V^zUU#3{WJU8r~ZW4t<`fDBt_lYLfV79VJ~OD^L@*#UiSz* zU;h25zqdxnKd`X+l%Vi)Na`xQL)xVj^DJRQt7{GIu01SuOQz0BoME|JHhG%3 z!%EjPc{`~NYuy?pqRdMN^kc#uUOpteEa;$qU>(qIW7*>J3ba&~Eo;jbk5{3kx@=kJ zbe}Y9|Ae)XcGJ=~cgIOG+0N2bM7iJ{k*Z8SkQpCGxpwxX6eFdbdvPj4*`&G_Wy4XN zh+vp;0p%-q^KAIB3`hNs!aVFuMnXF(lJO{yaE8^VJ};=>Lq?D;<)q6v?OMVflpr~r z^(=0|%57XdBuBWgTbH|{(yM(zdX!hVcUbSjo_Ot$bgR5BYP=z8J)5ub^+UT`|AN@0 z+khQj(XZ@_q&MGKoTYndF0%lOea z7p6zkd~CXfkA)xS{;fRoqb!ZWT%?fv@d&;Z90$w&)AQdS!~Dev9bJiyV}m%=jd?a3 z2?;}MZ&4oS+QXFupNUEEM6WONi=Nk$a5@Q+a3UlOr1CJ$W0*ZEuXx>sybY-eNj(CS zlLPXUJ|mRmY({5nW*t~F5?GKYfemF6lxA4t%xtjbSqZGm30R9`p9*|O0&i9ZTb((o zH6u@JI36eaCN)`mw^`REHHIB&H-NRpc@Rhjv+FCpC;(pp#)8otWQz^Q>GzhKkJI;L z4f~gh%~$cY*!R_QTju)JbK?e4qV$dP)t`MH%1@Q9u6V5JYFhzVBnGxbSDDErB33rg zK%5ip6rUbZT?t_HP^lm|BF}%<>H}z6VE+xrh%!!Q!J zABi;Tx9&(j3S~ZN?Tg1EZHv66Fq5|6A6F3opfDH#vsG zr4HsX!Jo}))IOfRej0QJ^jnJ~h$sxPcrcGtwGB!A0?dpI;QYo^R&vVe0sRx4_5jX$ z0B9Hh76v5Q%Xs=tyFGpToX350Tz;5H5%P&Ip2R9wohH)_ z)eoiM!RI3Ww4W=z9;dmG;Jigm)0J0YVcPpmn2g1al$n$_phCWa<{BCViQGW*8ki$0 zZ({z?oAkeiS_`^6%<}KwQ;8_Im zehZgaYv2E9C)nM&+xhjw9bGY{yPeKMX<@6Sa;fb{S(X@>n}d=Lw8(d{xjqlGNF)gi zc(RYLVHY0_@>>Ox!r5G6GuyK%9O+k#v6sv;2mCKqVlSA(p0je}EBB4fy7}syKH)I< zs1BI3`UKjJT1ldgHWz z3t2TmI=yYgrKs^$K1qbHfDKSy=4>!m7t^m(aFv0-p($jj2obn=(2eM1!~72+bHezUzt%W%DI_`PKK-N}*J-C5~rg|laHZ7~Yq&204!Fn#qD(_8IbLzm5(M2ad7(U7QF z7$Q|WiMczN_>?F`+AC@dL0hA2r0esFf?o{^noa@Wlfjm|l&|Ad1_A~VJ{FNB_aG_c z2hS}yvEW7zxN^>p-?-HMmvba8*>L2KA+w0;i#f8#V(d{m#f*+&Myox&IX~Nr>98z+ z2;<9dqWKY;lbyehMWi?REi@-#;A0UvErg`-1PXEkf;+aUd+b@$w=UoD3b#nZ6K}%+ z=SXe2{qI2QV*3}9+A|7d#s*eFX7U~QplKJ|GMt4WDw}A&jpocgFydoK3W0l-UI$W| ze*N;*FG#5&abfZ{8svD{LSv@-F6J-<tBDK@MWmXw&7O`3ZE%aCW>n8qk#Omy9)DZk8ps!7mdx&t%l!N>j}s zmBpPj-OBEm3vz*1bg-_Q`f(m7TZ1^4h$USsu8Ldun?UK3c@;uiU$lT;fU4p`weyLQ tXd`z9pk@IpP+#(=c0VZO{6oAJD!g6P$TeD{HMmwf)-~$959!?-^gmZWbnXBE diff --git a/mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/sequential.cpython-310.pyc b/mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/sequential.cpython-310.pyc deleted file mode 100644 index bba755543c76b1bb1846a1f558fdc2850a48b64a..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6182 zcmbVQ&2JmW72lm*E|(NZOO#|=vJ*4Tmzmg-)1)b2IBpb6PTkhB5i@a2G}x{#LVY#cl zNh0eL2SxzpED-C47~rIZM~0lhzTbVixY22g#cnL*Vw20oD**_qi8xu@+)6e&k$3s# zjm5U#5a-rJ)Yw>jTk=g`CR>X)#a$7tiexdSLbQl~Ebep#;S?@xZmF4rI1*oI`yvrc z`cYs7(>2+z?Z=YZloK~(u$eSa9((LTEq(_DVbo1-)EPI0)ns28O>l`@yM?;VZIR&_ zk!jjm>UQ1v%CL=k7CduR*}K-^et0$7>?Sv)0RBSdfRG=00l(iMVGJKC(`(_4X8;mo z$JjGT1bwq_c#OpG$Xo@-GjVrQNacEcFO{R zX@FyRVk8DwD(_NwlQojc8QaW=G!hC4UipF;Q)s^mEO^_jaZ?e+a#U36K|9 z8u@MEc`EM>6bI#k=iTZ0;h@I?UYe28q1{gdsN#Ux4H1SzhHH#wX^3Y15tImLt%I(? zEXGQY=d+U9pV2)UNKbqLw!5!;b&IqPKN}inx9D}a)pWTdvceU)R=&yR44xGe+!Y11 z6-AL+(3aD6cT~@d5}&{e&%=v&OXzp!jAmKu_7Oh0TdW`9Q=-&lyv(O}?fN7?&u206 zsBod{$N2HxOnr*ap=Dn0Z1WTRB*vA;+FsxnMYO!AE$|c+ z?kzPva)&oZKEnSNT+LVT+=8EkDOEbn%FwG+V4fn(>+U7F9$(NU8a!|$7b*u_GR=&l z1z0)xn9VI z@Y??7g^CYX9yGkD+g=kguDo}t(uu?d)(j#p?pMylLSXmB2pD?;V>l0@b$Gz)Xe|Nf z?CmcITA}Ri!F=qoi3z}v;BZuhi>BWVlb8^z$?iU4VSUentEr)7PcXFI0kNWj#OyI14@k+W@bGi~|YI(g?kGWlex$QQ1Tc)m-^U zYFUV`wJS1*igHqEDR;<$p^FB8RB?tH& zjL4(t0Pr0O37lqP?yxvVae2qwHv5bMLEi$j`WZl*Q)tNa8E2m}c`9-CvfN7Chsb5% zj4f_-4Y{CO?X;EaS<%TPx0g?HtqJY`7I^jnduVRwY0h3zX9Mmp*!IMBp})0*|!0w`dPqo-vylR=KyDtS;E~klA{P`$NJ_& zDie0cHEmyjHk5aWj^h;KbheROt;s4_(3?+XLz;Ac;e4edwMl8Q_hyIkqF%Bz-F)Ti z`XA|*oL0&RJgRwzGfacYv<#uUq?7Sbu5)W|YTr)ZCmHt<{3Vj43^=MRuYduZCy zEf=YJp1@lKE&?dCr<@d1spHuw*6JPJ#gLdlzTle9;O8RA%IieyC@Y(Jv=-Q)m(dvN zFPM|2{iJvD5D$GI1!;!<`dI2g_CYjUrhN(1K4q_y+pL#CqQpGomT`yKMq+N2O(ajI zaU1n}2a+f`g-KPbUh>9;Sz1E&8o)abMDUItmNIP_Wl7fhbha@C-!r}r1@pKB;Mxq7 z7m$0fCwbO;^)Nc0nXPhxM4}iYwNn$2927#QB+8}Sn!Im=7%+S&2OBZt15{v4Aa%~z zw)z%jk6_L0w%yAotZ(mGeVF-HU#+R}uBbTsR>L6_(B-Nzn z4KjSqBcwSi7*Wln4EhO;z6wy!CF=u~Q)c^Vb``ONH=;1Y_?VJcViPlG2F$Wq@1?_- zH8xjS;Dr{+7(<9Tl-z3Vpv06IvIqa!XMy%TRBy5Loby1R`E+7>1xF_v`EpXVH08?a zZor^R?Ri*@4<9o>nnq0445l2LPAA0ThsIq) z7OC7f-_x4(G=9p4$l_t{u9^50o-;_jL^4FvTiYUq{}ncgDKeHwKc4i?9LC4PCaRJ@ zmOzcsK}s?@!Tdne(Sp-xA&|!IPYr}Mvt?}~;I`~N1XzSe3%yR?LdeD8?}4S&M$*}u zj${EH@s!1bHkFZ`RkJz*=_nviVU#|}qz;uO9aYeyJRI2sPWE=HShvQ;JuSt#rxLII z2=XQ7>9fuxyeoMe_}C}CbB8hNxohad)EJ9M%)}!4#%ln-fp8J~%{WU8Ld}e)sSSAw zYgu&8*k(~dpCG`ZU3b{NY#@7K2~I41f(}@lZMocfKn#X0Vs`Fg8u{(jXRa$qQWVPK zqX6r$g3q$sw7K?kzhie8tbp_gF`t4N4#I5VykbG{$V|DqZwL8{VGt*DV{+;Dw4w{( z7>v13z^)%1%C7MQV&P)gY4~A$35DTGa-Me0>&+zv4OszD#pR{ilBdtCS8KOZJs^8Y zm0^xbbhec*0l?{yfv4wLa)_zue@P`eJ*p$#I3B-6V~jg4qr@Bvu*_x7V;ccSi~jpk zruXvyK#imo#zaoM9$7be60&Dm);3O21@^E9k&!hZ4JFIBaJGky$a(Z208@I0I#5ybnNLD^2HhV&e^sL?8OKrCg?= zPq~l@k6fzkG(;T#5&dx)z(F9UpThtd(hk-;btr%K=c^|EuDLL0kuC;olKAj#VwD-`6pq=tPx*;nux9Z#h8pg8$)xc$!7pDFiB>GD&|ALGNAmQ>IS znqiI53g`zi4o7PTXx(OavVF2oopPB{VC+zgv}&sGhvZfITIDYop-nW8+Z)2Ycj7Ri z4Xa`+O9186((5g)@D7b>5Fn<=bpodeyiQD91W4gke*fA=BY1!_`&#c_gc;!i<`PBY YnB1{z(CH#QW#|iE%gN4|PWia`UuJCN#Q*>R diff --git a/mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/util.cpython-310.pyc b/mace-bench/3rdparty/SevenNet/sevenn/nn/__pycache__/util.cpython-310.pyc deleted file mode 100644 index b1cf45e398ed9ed6fdd64dfe1593c33611c3b866..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 535 zcmY+AJx}965Qb+zoIpqrLINsUq!f2s&?0maC&~yBDkKy+-T?=ab=X~>F373ihd@a| z$)DNwD*nPre9Sr&F!K9)X6IRJ#zm*o0*a^S$9%{D{>9>JG^qGRayw)iG#kSyXpV5n zw5VMoKX2$nH}2%A(5w&L<_odV3&PnfB!$n>2BAljOvLlgko1l^Ah}DjifeWSwpFov zc%W(v=DdQKYktM;wr84E{FeWQv*uh-hxi9JhwBZ`7Tn_IH$v3J4)sM<(-qKwvS(>m zc&Y83qAP8jB_uXA7Z{dHCd<{X)zj4and;*oG(Ggn#l$FjNzU$^S^u(UTR ZbUDWTxW8MA9|kTY torch.Tensor: - return torch.nn.functional.softplus(x) - math.log(2.0) +import math + +import torch + + +@torch.jit.script +def ShiftedSoftPlus(x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.softplus(x) - math.log(2.0) diff --git a/mace-bench/3rdparty/SevenNet/sevenn/nn/convolution.py b/mace-bench/3rdparty/SevenNet/sevenn/nn/convolution.py index e62d6ad..d5f1cd7 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/nn/convolution.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/nn/convolution.py @@ -1,141 +1,141 @@ -from typing import List - -import torch -import torch.nn as nn -from e3nn.nn import FullyConnectedNet -from e3nn.o3 import Irreps, TensorProduct -from e3nn.util.jit import compile_mode - -import sevenn._keys as KEY -from sevenn._const import AtomGraphDataType - -from .activation import ShiftedSoftPlus -from .util import broadcast - - -def message_gather( - node_features: torch.Tensor, - edge_dst: torch.Tensor, - message: torch.Tensor -): - index = broadcast(edge_dst, message, 0) - out_shape = [len(node_features)] + list(message.shape[1:]) - out = torch.zeros( - out_shape, - dtype=node_features.dtype, - device=node_features.device - ) - out.scatter_reduce_(0, index, message, reduce='sum') - return out - - -@compile_mode('script') -class IrrepsConvolution(nn.Module): - """ - convolution of (fig 2.b), comm. in LAMMPS - """ - - def __init__( - self, - irreps_x: Irreps, - irreps_filter: Irreps, - irreps_out: Irreps, - weight_layer_input_to_hidden: List[int], - weight_layer_act=ShiftedSoftPlus, - denominator: float = 1.0, - train_denominator: bool = False, - data_key_x: str = KEY.NODE_FEATURE, - data_key_filter: str = KEY.EDGE_ATTR, - data_key_weight_input: str = KEY.EDGE_EMBEDDING, - data_key_edge_idx: str = KEY.EDGE_IDX, - lazy_layer_instantiate: bool = True, - is_parallel: bool = False, - ): - super().__init__() - self.denominator = nn.Parameter( - torch.FloatTensor([denominator]), requires_grad=train_denominator - ) - self.key_x = data_key_x - self.key_filter = data_key_filter - self.key_weight_input = data_key_weight_input - self.key_edge_idx = data_key_edge_idx - self.is_parallel = is_parallel - - instructions = [] - irreps_mid = [] - weight_numel = 0 - for i, (mul_x, ir_x) in enumerate(irreps_x): - for j, (_, ir_filter) in enumerate(irreps_filter): - for ir_out in ir_x * ir_filter: - if ir_out in irreps_out: # here we drop l > lmax - k = len(irreps_mid) - weight_numel += mul_x * 1 # path shape - irreps_mid.append((mul_x, ir_out)) - instructions.append((i, j, k, 'uvu', True)) - - irreps_mid = Irreps(irreps_mid) - irreps_mid, p, _ = irreps_mid.sort() # type: ignore - instructions = [ - (i_in1, i_in2, p[i_out], mode, train) - for i_in1, i_in2, i_out, mode, train in instructions - ] - - # From v0.11.x, to compatible with cuEquivariance - self._instructions_before_sort = instructions - instructions = sorted(instructions, key=lambda x: x[2]) - - self.convolution_kwargs = dict( - irreps_in1=irreps_x, - irreps_in2=irreps_filter, - irreps_out=irreps_mid, - instructions=instructions, - shared_weights=False, - internal_weights=False, - ) - - self.weight_nn_kwargs = dict( - hs=weight_layer_input_to_hidden + [weight_numel], - act=weight_layer_act - ) - - self.convolution = None - self.weight_nn = None - self.layer_instantiated = False - self.convolution_cls = TensorProduct - self.weight_nn_cls = FullyConnectedNet - - if not lazy_layer_instantiate: - self.instantiate() - - self._comm_size = irreps_x.dim # used in parallel - - def instantiate(self): - if self.convolution is not None: - raise ValueError('Convolution layer already exists') - if self.weight_nn is not None: - raise ValueError('Weight_nn layer already exists') - - self.convolution = self.convolution_cls(**self.convolution_kwargs) - self.weight_nn = self.weight_nn_cls(**self.weight_nn_kwargs) - self.layer_instantiated = True - - def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: - assert self.convolution is not None, 'Convolution is not instantiated' - assert self.weight_nn is not None, 'Weight_nn is not instantiated' - weight = self.weight_nn(data[self.key_weight_input]) - x = data[self.key_x] - if self.is_parallel: - x = torch.cat([x, data[KEY.NODE_FEATURE_GHOST]]) - - # note that 1 -> src 0 -> dst - edge_src = data[self.key_edge_idx][1] - edge_dst = data[self.key_edge_idx][0] - - message = self.convolution(x[edge_src], data[self.key_filter], weight) - - x = message_gather(x, edge_dst, message) - x = x.div(self.denominator) - if self.is_parallel: - x = torch.tensor_split(x, data[KEY.NLOCAL])[0] - data[self.key_x] = x - return data +from typing import List + +import torch +import torch.nn as nn +from e3nn.nn import FullyConnectedNet +from e3nn.o3 import Irreps, TensorProduct +from e3nn.util.jit import compile_mode + +import sevenn._keys as KEY +from sevenn._const import AtomGraphDataType + +from .activation import ShiftedSoftPlus +from .util import broadcast + + +def message_gather( + node_features: torch.Tensor, + edge_dst: torch.Tensor, + message: torch.Tensor +): + index = broadcast(edge_dst, message, 0) + out_shape = [len(node_features)] + list(message.shape[1:]) + out = torch.zeros( + out_shape, + dtype=node_features.dtype, + device=node_features.device + ) + out.scatter_reduce_(0, index, message, reduce='sum') + return out + + +@compile_mode('script') +class IrrepsConvolution(nn.Module): + """ + convolution of (fig 2.b), comm. in LAMMPS + """ + + def __init__( + self, + irreps_x: Irreps, + irreps_filter: Irreps, + irreps_out: Irreps, + weight_layer_input_to_hidden: List[int], + weight_layer_act=ShiftedSoftPlus, + denominator: float = 1.0, + train_denominator: bool = False, + data_key_x: str = KEY.NODE_FEATURE, + data_key_filter: str = KEY.EDGE_ATTR, + data_key_weight_input: str = KEY.EDGE_EMBEDDING, + data_key_edge_idx: str = KEY.EDGE_IDX, + lazy_layer_instantiate: bool = True, + is_parallel: bool = False, + ): + super().__init__() + self.denominator = nn.Parameter( + torch.FloatTensor([denominator]), requires_grad=train_denominator + ) + self.key_x = data_key_x + self.key_filter = data_key_filter + self.key_weight_input = data_key_weight_input + self.key_edge_idx = data_key_edge_idx + self.is_parallel = is_parallel + + instructions = [] + irreps_mid = [] + weight_numel = 0 + for i, (mul_x, ir_x) in enumerate(irreps_x): + for j, (_, ir_filter) in enumerate(irreps_filter): + for ir_out in ir_x * ir_filter: + if ir_out in irreps_out: # here we drop l > lmax + k = len(irreps_mid) + weight_numel += mul_x * 1 # path shape + irreps_mid.append((mul_x, ir_out)) + instructions.append((i, j, k, 'uvu', True)) + + irreps_mid = Irreps(irreps_mid) + irreps_mid, p, _ = irreps_mid.sort() # type: ignore + instructions = [ + (i_in1, i_in2, p[i_out], mode, train) + for i_in1, i_in2, i_out, mode, train in instructions + ] + + # From v0.11.x, to compatible with cuEquivariance + self._instructions_before_sort = instructions + instructions = sorted(instructions, key=lambda x: x[2]) + + self.convolution_kwargs = dict( + irreps_in1=irreps_x, + irreps_in2=irreps_filter, + irreps_out=irreps_mid, + instructions=instructions, + shared_weights=False, + internal_weights=False, + ) + + self.weight_nn_kwargs = dict( + hs=weight_layer_input_to_hidden + [weight_numel], + act=weight_layer_act + ) + + self.convolution = None + self.weight_nn = None + self.layer_instantiated = False + self.convolution_cls = TensorProduct + self.weight_nn_cls = FullyConnectedNet + + if not lazy_layer_instantiate: + self.instantiate() + + self._comm_size = irreps_x.dim # used in parallel + + def instantiate(self): + if self.convolution is not None: + raise ValueError('Convolution layer already exists') + if self.weight_nn is not None: + raise ValueError('Weight_nn layer already exists') + + self.convolution = self.convolution_cls(**self.convolution_kwargs) + self.weight_nn = self.weight_nn_cls(**self.weight_nn_kwargs) + self.layer_instantiated = True + + def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: + assert self.convolution is not None, 'Convolution is not instantiated' + assert self.weight_nn is not None, 'Weight_nn is not instantiated' + weight = self.weight_nn(data[self.key_weight_input]) + x = data[self.key_x] + if self.is_parallel: + x = torch.cat([x, data[KEY.NODE_FEATURE_GHOST]]) + + # note that 1 -> src 0 -> dst + edge_src = data[self.key_edge_idx][1] + edge_dst = data[self.key_edge_idx][0] + + message = self.convolution(x[edge_src], data[self.key_filter], weight) + + x = message_gather(x, edge_dst, message) + x = x.div(self.denominator) + if self.is_parallel: + x = torch.tensor_split(x, data[KEY.NLOCAL])[0] + data[self.key_x] = x + return data diff --git a/mace-bench/3rdparty/SevenNet/sevenn/nn/cue_helper.py b/mace-bench/3rdparty/SevenNet/sevenn/nn/cue_helper.py index 1d0d0d6..c798f40 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/nn/cue_helper.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/nn/cue_helper.py @@ -1,189 +1,189 @@ -import itertools -import warnings -from typing import Iterator, Literal, Union - -import e3nn.o3 as o3 -import numpy as np - -from .convolution import IrrepsConvolution -from .linear import IrrepsLinear -from .self_connection import SelfConnectionIntro, SelfConnectionLinearIntro - -try: - import cuequivariance as cue - import cuequivariance_torch as cuet - - _CUE_AVAILABLE = True - - # Obatained from MACE - class O3_e3nn(cue.O3): - def __mul__( # type: ignore - rep1: 'O3_e3nn', rep2: 'O3_e3nn' - ) -> Iterator['O3_e3nn']: - return [ # type: ignore - O3_e3nn(l=ir.l, p=ir.p) for ir in cue.O3.__mul__(rep1, rep2) - ] - - @classmethod - def clebsch_gordan( # type: ignore - cls, rep1: 'O3_e3nn', rep2: 'O3_e3nn', rep3: 'O3_e3nn' - ) -> np.ndarray: - rep1, rep2, rep3 = cls._from(rep1), cls._from(rep2), cls._from(rep3) - - if rep1.p * rep2.p == rep3.p: - return o3.wigner_3j(rep1.l, rep2.l, rep3.l).numpy()[None] * np.sqrt( - rep3.dim - ) - return np.zeros((0, rep1.dim, rep2.dim, rep3.dim)) - - def __lt__( # type: ignore - rep1: 'O3_e3nn', rep2: 'O3_e3nn' - ) -> bool: - rep2 = rep1._from(rep2) # type: ignore - return (rep1.l, rep1.p) < (rep2.l, rep2.p) - - @classmethod - def iterator(cls) -> Iterator['O3_e3nn']: - for l in itertools.count(0): - yield O3_e3nn(l=l, p=1 * (-1) ** l) - yield O3_e3nn(l=l, p=-1 * (-1) ** l) - -except ImportError: - _CUE_AVAILABLE = False - - -def is_cue_available(): - return _CUE_AVAILABLE - - -def cue_needed(func): - def wrapper(*args, **kwargs): - if is_cue_available(): - return func(*args, **kwargs) - else: - raise ImportError('cue is not available') - - return wrapper - - -def _check_may_not_compatible(orig_kwargs, defaults): - for k, v in defaults.items(): - v_given = orig_kwargs.pop(k, v) - if v_given != v: - warnings.warn(f'{k}: {v} is ignored to use cuEquivariance') - - -def is_cue_cuda_available_model(config): - if config.get('use_bias_in_linear', False): - warnings.warn('Bias in linear can not be used with cueq, fallback to e3nn') - return False - else: - return True - - -@cue_needed -def as_cue_irreps(irreps: o3.Irreps, group: Literal['SO3', 'O3']): - """Convert e3nn irreps to given group's cue irreps""" - if group == 'SO3': - assert all(irrep.ir.p == 1 for irrep in irreps) - return cue.Irreps('SO3', str(irreps).replace('e', '')) # type: ignore - elif group == 'O3': - return cue.Irreps(O3_e3nn, str(irreps)) # type: ignore - else: - raise ValueError(f'Unknown group: {group}') - - -@cue_needed -def patch_linear( - module: Union[IrrepsLinear, SelfConnectionLinearIntro], - group: Literal['SO3', 'O3'], - **cue_kwargs, -): - assert not module.layer_instantiated - - module.irreps_in = as_cue_irreps(module.irreps_in, group) # type: ignore - module.irreps_out = as_cue_irreps(module.irreps_out, group) # type: ignore - - orig_kwargs = module.linear_kwargs - - may_not_compatible_default = dict( - f_in=None, - f_out=None, - instructions=None, - biases=False, - path_normalization='element', - _optimize_einsums=None, - ) - # pop may_not_compatible_defaults - _check_may_not_compatible(orig_kwargs, may_not_compatible_default) - - module.linear_cls = cuet.Linear # type: ignore - orig_kwargs.update(**cue_kwargs) - return module - - -@cue_needed -def patch_convolution( - module: IrrepsConvolution, - group: Literal['SO3', 'O3'], - **cue_kwargs, -): - assert not module.layer_instantiated - - # conv_kwargs will be patched in place - conv_kwargs = module.convolution_kwargs - conv_kwargs.update( - dict( - irreps_in1=as_cue_irreps(conv_kwargs.get('irreps_in1'), group), - irreps_in2=as_cue_irreps(conv_kwargs.get('irreps_in2'), group), - filter_irreps_out=as_cue_irreps(conv_kwargs.pop('irreps_out'), group), - ) - ) - - inst_orig = conv_kwargs.pop('instructions') - inst_sorted = sorted(inst_orig, key=lambda x: x[2]) - assert all([a == b for a, b in zip(inst_orig, inst_sorted)]) - - may_not_compatible_default = dict( - in1_var=None, - in2_var=None, - out_var=None, - irrep_normalization=False, - path_normalization='element', - compile_left_right=True, - compile_right=False, - _specialized_code=None, - _optimize_einsums=None, - ) - # pop may_not_compatible_defaults - _check_may_not_compatible(conv_kwargs, may_not_compatible_default) - - module.convolution_cls = cuet.ChannelWiseTensorProduct # type: ignore - conv_kwargs.update(**cue_kwargs) - return module - - -@cue_needed -def patch_fully_connected( - module: SelfConnectionIntro, - group: Literal['SO3', 'O3'], - **cue_kwargs, -): - assert not module.layer_instantiated - - module.irreps_in1 = as_cue_irreps(module.irreps_in1, group) # type: ignore - module.irreps_in2 = as_cue_irreps(module.irreps_in2, group) # type: ignore - module.irreps_out = as_cue_irreps(module.irreps_out, group) # type: ignore - - may_not_compatible_default = dict( - irrep_normalization=None, - path_normalization=None, - ) - # pop may_not_compatible_defaults - _check_may_not_compatible( - module.fc_tensor_product_kwargs, may_not_compatible_default - ) - - module.fc_tensor_product_cls = cuet.FullyConnectedTensorProduct # type: ignore - module.fc_tensor_product_kwargs.update(**cue_kwargs) - return module +import itertools +import warnings +from typing import Iterator, Literal, Union + +import e3nn.o3 as o3 +import numpy as np + +from .convolution import IrrepsConvolution +from .linear import IrrepsLinear +from .self_connection import SelfConnectionIntro, SelfConnectionLinearIntro + +try: + import cuequivariance as cue + import cuequivariance_torch as cuet + + _CUE_AVAILABLE = True + + # Obatained from MACE + class O3_e3nn(cue.O3): + def __mul__( # type: ignore + rep1: 'O3_e3nn', rep2: 'O3_e3nn' + ) -> Iterator['O3_e3nn']: + return [ # type: ignore + O3_e3nn(l=ir.l, p=ir.p) for ir in cue.O3.__mul__(rep1, rep2) + ] + + @classmethod + def clebsch_gordan( # type: ignore + cls, rep1: 'O3_e3nn', rep2: 'O3_e3nn', rep3: 'O3_e3nn' + ) -> np.ndarray: + rep1, rep2, rep3 = cls._from(rep1), cls._from(rep2), cls._from(rep3) + + if rep1.p * rep2.p == rep3.p: + return o3.wigner_3j(rep1.l, rep2.l, rep3.l).numpy()[None] * np.sqrt( + rep3.dim + ) + return np.zeros((0, rep1.dim, rep2.dim, rep3.dim)) + + def __lt__( # type: ignore + rep1: 'O3_e3nn', rep2: 'O3_e3nn' + ) -> bool: + rep2 = rep1._from(rep2) # type: ignore + return (rep1.l, rep1.p) < (rep2.l, rep2.p) + + @classmethod + def iterator(cls) -> Iterator['O3_e3nn']: + for l in itertools.count(0): + yield O3_e3nn(l=l, p=1 * (-1) ** l) + yield O3_e3nn(l=l, p=-1 * (-1) ** l) + +except ImportError: + _CUE_AVAILABLE = False + + +def is_cue_available(): + return _CUE_AVAILABLE + + +def cue_needed(func): + def wrapper(*args, **kwargs): + if is_cue_available(): + return func(*args, **kwargs) + else: + raise ImportError('cue is not available') + + return wrapper + + +def _check_may_not_compatible(orig_kwargs, defaults): + for k, v in defaults.items(): + v_given = orig_kwargs.pop(k, v) + if v_given != v: + warnings.warn(f'{k}: {v} is ignored to use cuEquivariance') + + +def is_cue_cuda_available_model(config): + if config.get('use_bias_in_linear', False): + warnings.warn('Bias in linear can not be used with cueq, fallback to e3nn') + return False + else: + return True + + +@cue_needed +def as_cue_irreps(irreps: o3.Irreps, group: Literal['SO3', 'O3']): + """Convert e3nn irreps to given group's cue irreps""" + if group == 'SO3': + assert all(irrep.ir.p == 1 for irrep in irreps) + return cue.Irreps('SO3', str(irreps).replace('e', '')) # type: ignore + elif group == 'O3': + return cue.Irreps(O3_e3nn, str(irreps)) # type: ignore + else: + raise ValueError(f'Unknown group: {group}') + + +@cue_needed +def patch_linear( + module: Union[IrrepsLinear, SelfConnectionLinearIntro], + group: Literal['SO3', 'O3'], + **cue_kwargs, +): + assert not module.layer_instantiated + + module.irreps_in = as_cue_irreps(module.irreps_in, group) # type: ignore + module.irreps_out = as_cue_irreps(module.irreps_out, group) # type: ignore + + orig_kwargs = module.linear_kwargs + + may_not_compatible_default = dict( + f_in=None, + f_out=None, + instructions=None, + biases=False, + path_normalization='element', + _optimize_einsums=None, + ) + # pop may_not_compatible_defaults + _check_may_not_compatible(orig_kwargs, may_not_compatible_default) + + module.linear_cls = cuet.Linear # type: ignore + orig_kwargs.update(**cue_kwargs) + return module + + +@cue_needed +def patch_convolution( + module: IrrepsConvolution, + group: Literal['SO3', 'O3'], + **cue_kwargs, +): + assert not module.layer_instantiated + + # conv_kwargs will be patched in place + conv_kwargs = module.convolution_kwargs + conv_kwargs.update( + dict( + irreps_in1=as_cue_irreps(conv_kwargs.get('irreps_in1'), group), + irreps_in2=as_cue_irreps(conv_kwargs.get('irreps_in2'), group), + filter_irreps_out=as_cue_irreps(conv_kwargs.pop('irreps_out'), group), + ) + ) + + inst_orig = conv_kwargs.pop('instructions') + inst_sorted = sorted(inst_orig, key=lambda x: x[2]) + assert all([a == b for a, b in zip(inst_orig, inst_sorted)]) + + may_not_compatible_default = dict( + in1_var=None, + in2_var=None, + out_var=None, + irrep_normalization=False, + path_normalization='element', + compile_left_right=True, + compile_right=False, + _specialized_code=None, + _optimize_einsums=None, + ) + # pop may_not_compatible_defaults + _check_may_not_compatible(conv_kwargs, may_not_compatible_default) + + module.convolution_cls = cuet.ChannelWiseTensorProduct # type: ignore + conv_kwargs.update(**cue_kwargs) + return module + + +@cue_needed +def patch_fully_connected( + module: SelfConnectionIntro, + group: Literal['SO3', 'O3'], + **cue_kwargs, +): + assert not module.layer_instantiated + + module.irreps_in1 = as_cue_irreps(module.irreps_in1, group) # type: ignore + module.irreps_in2 = as_cue_irreps(module.irreps_in2, group) # type: ignore + module.irreps_out = as_cue_irreps(module.irreps_out, group) # type: ignore + + may_not_compatible_default = dict( + irrep_normalization=None, + path_normalization=None, + ) + # pop may_not_compatible_defaults + _check_may_not_compatible( + module.fc_tensor_product_kwargs, may_not_compatible_default + ) + + module.fc_tensor_product_cls = cuet.FullyConnectedTensorProduct # type: ignore + module.fc_tensor_product_kwargs.update(**cue_kwargs) + return module diff --git a/mace-bench/3rdparty/SevenNet/sevenn/nn/edge_embedding.py b/mace-bench/3rdparty/SevenNet/sevenn/nn/edge_embedding.py index 8c53339..e738ef5 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/nn/edge_embedding.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/nn/edge_embedding.py @@ -1,217 +1,217 @@ -import math - -import torch -import torch.nn as nn -from e3nn.o3 import Irreps, SphericalHarmonics -from e3nn.util.jit import compile_mode - -import sevenn._keys as KEY -from sevenn._const import AtomGraphDataType - - -@compile_mode('script') -class EdgePreprocess(nn.Module): - """ - preprocessing pos to edge vectors and edge lengths - currently used in sevenn/scripts/deploy for lammps serial model - """ - - def __init__(self, is_stress: bool): - super().__init__() - # controlled by 'AtomGraphSequential' - self.is_stress = is_stress - self._is_batch_data = True - - def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: - if self._is_batch_data: - cell = data[KEY.CELL].view(-1, 3, 3) - else: - cell = data[KEY.CELL].view(3, 3) - cell_shift = data[KEY.CELL_SHIFT] - pos = data[KEY.POS] - - batch = data[KEY.BATCH] # for deploy, must be defined first - if self.is_stress: - if self._is_batch_data: - num_batch = int(batch.max().cpu().item()) + 1 - strain = torch.zeros( - (num_batch, 3, 3), - dtype=pos.dtype, - device=pos.device, - ) - strain.requires_grad_(True) - data['_strain'] = strain - - sym_strain = 0.5 * (strain + strain.transpose(-1, -2)) - pos = pos + torch.bmm( - pos.unsqueeze(-2), sym_strain[batch] - ).squeeze(-2) - cell = cell + torch.bmm(cell, sym_strain) - else: - strain = torch.zeros( - (3, 3), - dtype=pos.dtype, - device=pos.device, - ) - strain.requires_grad_(True) - data['_strain'] = strain - - sym_strain = 0.5 * (strain + strain.transpose(-1, -2)) - pos = pos + torch.mm(pos, sym_strain) - cell = cell + torch.mm(cell, sym_strain) - - idx_src = data[KEY.EDGE_IDX][0] - idx_dst = data[KEY.EDGE_IDX][1] - - edge_vec = pos[idx_dst] - pos[idx_src] - - if self._is_batch_data: - edge_vec = edge_vec + torch.einsum( - 'ni,nij->nj', cell_shift, cell[batch[idx_src]] - ) - else: - edge_vec = edge_vec + torch.einsum( - 'ni,ij->nj', cell_shift, cell.squeeze(0) - ) - data[KEY.EDGE_VEC] = edge_vec - data[KEY.EDGE_LENGTH] = torch.linalg.norm(edge_vec, dim=-1) - return data - - -class BesselBasis(nn.Module): - """ - f : (*, 1) -> (*, bessel_basis_num) - """ - - def __init__( - self, - cutoff_length: float, - bessel_basis_num: int = 8, - trainable_coeff: bool = True, - ): - super().__init__() - self.num_basis = bessel_basis_num - self.prefactor = 2.0 / cutoff_length - self.coeffs = torch.FloatTensor([ - n * math.pi / cutoff_length for n in range(1, bessel_basis_num + 1) - ]) - if trainable_coeff: - self.coeffs = nn.Parameter(self.coeffs) - - def forward(self, r: torch.Tensor) -> torch.Tensor: - ur = r.unsqueeze(-1) # to fit dimension - return self.prefactor * torch.sin(self.coeffs * ur) / ur - - -class PolynomialCutoff(nn.Module): - """ - f : (*, 1) -> (*, 1) - https://arxiv.org/pdf/2003.03123.pdf - """ - - def __init__( - self, - cutoff_length: float, - poly_cut_p_value: int = 6, - ): - super().__init__() - p = poly_cut_p_value - self.cutoff_length = cutoff_length - self.p = p - self.coeff_p0 = (p + 1.0) * (p + 2.0) / 2.0 - self.coeff_p1 = p * (p + 2.0) - self.coeff_p2 = p * (p + 1.0) / 2.0 - - def forward(self, r: torch.Tensor) -> torch.Tensor: - r = r / self.cutoff_length - return ( - 1 - - self.coeff_p0 * torch.pow(r, self.p) - + self.coeff_p1 * torch.pow(r, self.p + 1.0) - - self.coeff_p2 * torch.pow(r, self.p + 2.0) - ) - - -class XPLORCutoff(nn.Module): - """ - https://hoomd-blue.readthedocs.io/en/latest/module-md-pair.html - """ - - def __init__( - self, - cutoff_length: float, - cutoff_on: float, - ): - super().__init__() - self.r_on = cutoff_on - self.r_cut = cutoff_length - assert self.r_on < self.r_cut - - def forward(self, r: torch.Tensor) -> torch.Tensor: - r_sq = r * r - r_on_sq = self.r_on * self.r_on - r_cut_sq = self.r_cut * self.r_cut - return torch.where( - r < self.r_on, - 1.0, - (r_cut_sq - r_sq) ** 2 - * (r_cut_sq + 2 * r_sq - 3 * r_on_sq) - / (r_cut_sq - r_on_sq) ** 3, - ) - - -@compile_mode('script') -class SphericalEncoding(nn.Module): - def __init__( - self, - lmax: int, - parity: int = -1, - normalization: str = 'component', - normalize: bool = True, - ): - super().__init__() - self.lmax = lmax - self.normalization = normalization - self.irreps_in = Irreps('1x1o') if parity == -1 else Irreps('1x1e') - self.irreps_out = Irreps.spherical_harmonics(lmax, parity) - self.sph = SphericalHarmonics( - self.irreps_out, - normalize=normalize, - normalization=normalization, - irreps_in=self.irreps_in, - ) - - def forward(self, r: torch.Tensor) -> torch.Tensor: - return self.sph(r) - - -@compile_mode('script') -class EdgeEmbedding(nn.Module): - """ - embedding layer of |r| by - RadialBasis(|r|)*CutOff(|r|) - f : (N_edge) -> (N_edge, basis_num) - """ - - def __init__( - self, - basis_module: nn.Module, - cutoff_module: nn.Module, - spherical_module: nn.Module, - ): - super().__init__() - self.basis_function = basis_module - self.cutoff_function = cutoff_module - self.spherical = spherical_module - - def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: - rvec = data[KEY.EDGE_VEC] - r = torch.linalg.norm(data[KEY.EDGE_VEC], dim=-1) - data[KEY.EDGE_LENGTH] = r - - data[KEY.EDGE_EMBEDDING] = self.basis_function( - r - ) * self.cutoff_function(r).unsqueeze(-1) - data[KEY.EDGE_ATTR] = self.spherical(rvec) - - return data +import math + +import torch +import torch.nn as nn +from e3nn.o3 import Irreps, SphericalHarmonics +from e3nn.util.jit import compile_mode + +import sevenn._keys as KEY +from sevenn._const import AtomGraphDataType + + +@compile_mode('script') +class EdgePreprocess(nn.Module): + """ + preprocessing pos to edge vectors and edge lengths + currently used in sevenn/scripts/deploy for lammps serial model + """ + + def __init__(self, is_stress: bool): + super().__init__() + # controlled by 'AtomGraphSequential' + self.is_stress = is_stress + self._is_batch_data = True + + def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: + if self._is_batch_data: + cell = data[KEY.CELL].view(-1, 3, 3) + else: + cell = data[KEY.CELL].view(3, 3) + cell_shift = data[KEY.CELL_SHIFT] + pos = data[KEY.POS] + + batch = data[KEY.BATCH] # for deploy, must be defined first + if self.is_stress: + if self._is_batch_data: + num_batch = int(batch.max().cpu().item()) + 1 + strain = torch.zeros( + (num_batch, 3, 3), + dtype=pos.dtype, + device=pos.device, + ) + strain.requires_grad_(True) + data['_strain'] = strain + + sym_strain = 0.5 * (strain + strain.transpose(-1, -2)) + pos = pos + torch.bmm( + pos.unsqueeze(-2), sym_strain[batch] + ).squeeze(-2) + cell = cell + torch.bmm(cell, sym_strain) + else: + strain = torch.zeros( + (3, 3), + dtype=pos.dtype, + device=pos.device, + ) + strain.requires_grad_(True) + data['_strain'] = strain + + sym_strain = 0.5 * (strain + strain.transpose(-1, -2)) + pos = pos + torch.mm(pos, sym_strain) + cell = cell + torch.mm(cell, sym_strain) + + idx_src = data[KEY.EDGE_IDX][0] + idx_dst = data[KEY.EDGE_IDX][1] + + edge_vec = pos[idx_dst] - pos[idx_src] + + if self._is_batch_data: + edge_vec = edge_vec + torch.einsum( + 'ni,nij->nj', cell_shift, cell[batch[idx_src]] + ) + else: + edge_vec = edge_vec + torch.einsum( + 'ni,ij->nj', cell_shift, cell.squeeze(0) + ) + data[KEY.EDGE_VEC] = edge_vec + data[KEY.EDGE_LENGTH] = torch.linalg.norm(edge_vec, dim=-1) + return data + + +class BesselBasis(nn.Module): + """ + f : (*, 1) -> (*, bessel_basis_num) + """ + + def __init__( + self, + cutoff_length: float, + bessel_basis_num: int = 8, + trainable_coeff: bool = True, + ): + super().__init__() + self.num_basis = bessel_basis_num + self.prefactor = 2.0 / cutoff_length + self.coeffs = torch.FloatTensor([ + n * math.pi / cutoff_length for n in range(1, bessel_basis_num + 1) + ]) + if trainable_coeff: + self.coeffs = nn.Parameter(self.coeffs) + + def forward(self, r: torch.Tensor) -> torch.Tensor: + ur = r.unsqueeze(-1) # to fit dimension + return self.prefactor * torch.sin(self.coeffs * ur) / ur + + +class PolynomialCutoff(nn.Module): + """ + f : (*, 1) -> (*, 1) + https://arxiv.org/pdf/2003.03123.pdf + """ + + def __init__( + self, + cutoff_length: float, + poly_cut_p_value: int = 6, + ): + super().__init__() + p = poly_cut_p_value + self.cutoff_length = cutoff_length + self.p = p + self.coeff_p0 = (p + 1.0) * (p + 2.0) / 2.0 + self.coeff_p1 = p * (p + 2.0) + self.coeff_p2 = p * (p + 1.0) / 2.0 + + def forward(self, r: torch.Tensor) -> torch.Tensor: + r = r / self.cutoff_length + return ( + 1 + - self.coeff_p0 * torch.pow(r, self.p) + + self.coeff_p1 * torch.pow(r, self.p + 1.0) + - self.coeff_p2 * torch.pow(r, self.p + 2.0) + ) + + +class XPLORCutoff(nn.Module): + """ + https://hoomd-blue.readthedocs.io/en/latest/module-md-pair.html + """ + + def __init__( + self, + cutoff_length: float, + cutoff_on: float, + ): + super().__init__() + self.r_on = cutoff_on + self.r_cut = cutoff_length + assert self.r_on < self.r_cut + + def forward(self, r: torch.Tensor) -> torch.Tensor: + r_sq = r * r + r_on_sq = self.r_on * self.r_on + r_cut_sq = self.r_cut * self.r_cut + return torch.where( + r < self.r_on, + 1.0, + (r_cut_sq - r_sq) ** 2 + * (r_cut_sq + 2 * r_sq - 3 * r_on_sq) + / (r_cut_sq - r_on_sq) ** 3, + ) + + +@compile_mode('script') +class SphericalEncoding(nn.Module): + def __init__( + self, + lmax: int, + parity: int = -1, + normalization: str = 'component', + normalize: bool = True, + ): + super().__init__() + self.lmax = lmax + self.normalization = normalization + self.irreps_in = Irreps('1x1o') if parity == -1 else Irreps('1x1e') + self.irreps_out = Irreps.spherical_harmonics(lmax, parity) + self.sph = SphericalHarmonics( + self.irreps_out, + normalize=normalize, + normalization=normalization, + irreps_in=self.irreps_in, + ) + + def forward(self, r: torch.Tensor) -> torch.Tensor: + return self.sph(r) + + +@compile_mode('script') +class EdgeEmbedding(nn.Module): + """ + embedding layer of |r| by + RadialBasis(|r|)*CutOff(|r|) + f : (N_edge) -> (N_edge, basis_num) + """ + + def __init__( + self, + basis_module: nn.Module, + cutoff_module: nn.Module, + spherical_module: nn.Module, + ): + super().__init__() + self.basis_function = basis_module + self.cutoff_function = cutoff_module + self.spherical = spherical_module + + def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: + rvec = data[KEY.EDGE_VEC] + r = torch.linalg.norm(data[KEY.EDGE_VEC], dim=-1) + data[KEY.EDGE_LENGTH] = r + + data[KEY.EDGE_EMBEDDING] = self.basis_function( + r + ) * self.cutoff_function(r).unsqueeze(-1) + data[KEY.EDGE_ATTR] = self.spherical(rvec) + + return data diff --git a/mace-bench/3rdparty/SevenNet/sevenn/nn/equivariant_gate.py b/mace-bench/3rdparty/SevenNet/sevenn/nn/equivariant_gate.py index 5c36fe2..8422405 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/nn/equivariant_gate.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/nn/equivariant_gate.py @@ -1,61 +1,61 @@ -from typing import Callable, Dict - -import torch.nn as nn -from e3nn.nn import Gate -from e3nn.o3 import Irreps -from e3nn.util.jit import compile_mode - -import sevenn._keys as KEY -from sevenn._const import AtomGraphDataType - - -@compile_mode('script') -class EquivariantGate(nn.Module): - def __init__( - self, - irreps_x: Irreps, - act_scalar_dict: Dict[int, Callable], - act_gate_dict: Dict[int, Callable], - data_key_x: str = KEY.NODE_FEATURE, - ): - super().__init__() - self.key_x = data_key_x - - parity_mapper = {'e': 1, 'o': -1} - act_scalar_dict = { - parity_mapper[k]: v for k, v in act_scalar_dict.items() - } - act_gate_dict = {parity_mapper[k]: v for k, v in act_gate_dict.items()} - - irreps_gated_elem = [] - irreps_scalars_elem = [] - # non scalar irreps > gated / scalar irreps > scalars - for mul, irreps in irreps_x: - if irreps.l > 0: - irreps_gated_elem.append((mul, irreps)) - else: - irreps_scalars_elem.append((mul, irreps)) - irreps_scalars = Irreps(irreps_scalars_elem) - irreps_gated = Irreps(irreps_gated_elem) - - irreps_gates_parity = 1 if '0e' in irreps_scalars else -1 - irreps_gates = Irreps( - [(mul, (0, irreps_gates_parity)) for mul, _ in irreps_gated] - ) - - act_scalars = [act_scalar_dict[p] for _, (_, p) in irreps_scalars] - act_gates = [act_gate_dict[p] for _, (_, p) in irreps_gates] - - self.gate = Gate( - irreps_scalars, act_scalars, irreps_gates, act_gates, irreps_gated - ) - - def get_gate_irreps_in(self): - """ - user must call this function to get proper irreps in for forward - """ - return self.gate.irreps_in - - def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: - data[self.key_x] = self.gate(data[self.key_x]) - return data +from typing import Callable, Dict + +import torch.nn as nn +from e3nn.nn import Gate +from e3nn.o3 import Irreps +from e3nn.util.jit import compile_mode + +import sevenn._keys as KEY +from sevenn._const import AtomGraphDataType + + +@compile_mode('script') +class EquivariantGate(nn.Module): + def __init__( + self, + irreps_x: Irreps, + act_scalar_dict: Dict[int, Callable], + act_gate_dict: Dict[int, Callable], + data_key_x: str = KEY.NODE_FEATURE, + ): + super().__init__() + self.key_x = data_key_x + + parity_mapper = {'e': 1, 'o': -1} + act_scalar_dict = { + parity_mapper[k]: v for k, v in act_scalar_dict.items() + } + act_gate_dict = {parity_mapper[k]: v for k, v in act_gate_dict.items()} + + irreps_gated_elem = [] + irreps_scalars_elem = [] + # non scalar irreps > gated / scalar irreps > scalars + for mul, irreps in irreps_x: + if irreps.l > 0: + irreps_gated_elem.append((mul, irreps)) + else: + irreps_scalars_elem.append((mul, irreps)) + irreps_scalars = Irreps(irreps_scalars_elem) + irreps_gated = Irreps(irreps_gated_elem) + + irreps_gates_parity = 1 if '0e' in irreps_scalars else -1 + irreps_gates = Irreps( + [(mul, (0, irreps_gates_parity)) for mul, _ in irreps_gated] + ) + + act_scalars = [act_scalar_dict[p] for _, (_, p) in irreps_scalars] + act_gates = [act_gate_dict[p] for _, (_, p) in irreps_gates] + + self.gate = Gate( + irreps_scalars, act_scalars, irreps_gates, act_gates, irreps_gated + ) + + def get_gate_irreps_in(self): + """ + user must call this function to get proper irreps in for forward + """ + return self.gate.irreps_in + + def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: + data[self.key_x] = self.gate(data[self.key_x]) + return data diff --git a/mace-bench/3rdparty/SevenNet/sevenn/nn/force_output.py b/mace-bench/3rdparty/SevenNet/sevenn/nn/force_output.py index d1360c7..f6b90e5 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/nn/force_output.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/nn/force_output.py @@ -1,224 +1,224 @@ -import torch -import torch.nn as nn -from e3nn.util.jit import compile_mode - -import sevenn._keys as KEY -from sevenn._const import AtomGraphDataType - -from .util import broadcast - - -@compile_mode('script') -class ForceOutput(nn.Module): - """ - works when pos.requires_grad_ is True - """ - - def __init__( - self, - data_key_pos: str = KEY.POS, - data_key_energy: str = KEY.PRED_TOTAL_ENERGY, - data_key_force: str = KEY.PRED_FORCE, - ): - super().__init__() - self.key_pos = data_key_pos - self.key_energy = data_key_energy - self.key_force = data_key_force - - def get_grad_key(self): - return self.key_pos - - def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: - pos_tensor = [data[self.key_pos]] - energy = [(data[self.key_energy]).sum()] - - # `materialize_grads` not supported in low version of pytorch - # Also can not be deployed when using it. - # But not using it makes problem in - # force/stress inference in sparse systems - # TODO: use it only in sevennet_calculator? - grad = torch.autograd.grad( - energy, - pos_tensor, - create_graph=self.training, - allow_unused=True, - # materialize_grads=True, - )[0] - - # For torchscript - if grad is not None: - data[self.key_force] = torch.neg(grad) - return data - - -@compile_mode('script') -class ForceStressOutput(nn.Module): - """ - Compute stress and force from positions. - Used in serial torchscipt models - """ - def __init__( - self, - data_key_pos: str = KEY.POS, - data_key_energy: str = KEY.PRED_TOTAL_ENERGY, - data_key_force: str = KEY.PRED_FORCE, - data_key_stress: str = KEY.PRED_STRESS, - data_key_cell_volume: str = KEY.CELL_VOLUME, - ): - - super().__init__() - self.key_pos = data_key_pos - self.key_energy = data_key_energy - self.key_force = data_key_force - self.key_stress = data_key_stress - self.key_cell_volume = data_key_cell_volume - self._is_batch_data = True - - def get_grad_key(self): - return self.key_pos - - def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: - pos_tensor = data[self.key_pos] - energy = [(data[self.key_energy]).sum()] - - # `materialize_grads` not supported in low version of pytorch - # Also can not be deployed when using it. - # But not using it makes problem in - # force/stress inference in sparse systems - # TODO: use it only in sevennet_calculator? - grad = torch.autograd.grad( - energy, - [pos_tensor, data['_strain']], - create_graph=self.training, - allow_unused=True, - # materialize_grads=True, - ) - - # make grad is not Optional[Tensor] - fgrad = grad[0] - if fgrad is not None: - data[self.key_force] = torch.neg(fgrad) - - sgrad = grad[1] - volume = data[self.key_cell_volume] - vlim = 1e-3 # for cell volume = 0 for non PBC structures - if self._is_batch_data: - volume[volume < vlim] = vlim - elif volume < vlim: - volume = torch.tensor(vlim) - - if sgrad is not None: - if self._is_batch_data: - stress = sgrad / volume.view(-1, 1, 1) - stress = torch.neg(stress) - virial_stress = torch.vstack(( - stress[:, 0, 0], - stress[:, 1, 1], - stress[:, 2, 2], - stress[:, 0, 1], - stress[:, 1, 2], - stress[:, 0, 2], - )) - data[self.key_stress] = virial_stress.transpose(0, 1) - else: - stress = sgrad / volume - stress = torch.neg(stress) - virial_stress = torch.stack(( - stress[0, 0], - stress[1, 1], - stress[2, 2], - stress[0, 1], - stress[1, 2], - stress[0, 2], - )) - data[self.key_stress] = virial_stress - - return data - - -@compile_mode('script') -class ForceStressOutputFromEdge(nn.Module): - """ - Compute stress and force from edge. - Used in parallel torchscipt models, and training - """ - def __init__( - self, - data_key_edge: str = KEY.EDGE_VEC, - data_key_edge_idx: str = KEY.EDGE_IDX, - data_key_energy: str = KEY.PRED_TOTAL_ENERGY, - data_key_force: str = KEY.PRED_FORCE, - data_key_stress: str = KEY.PRED_STRESS, - data_key_cell_volume: str = KEY.CELL_VOLUME, - ): - - super().__init__() - self.key_edge = data_key_edge - self.key_edge_idx = data_key_edge_idx - self.key_energy = data_key_energy - self.key_force = data_key_force - self.key_stress = data_key_stress - self.key_cell_volume = data_key_cell_volume - self._is_batch_data = True - - def get_grad_key(self): - return self.key_edge - - def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: - tot_num = torch.sum(data[KEY.NUM_ATOMS]) # ? item? - rij = data[self.key_edge] - energy = [(data[self.key_energy]).sum()] - edge_idx = data[self.key_edge_idx] - - grad = torch.autograd.grad( - energy, - [rij], - create_graph=self.training, - allow_unused=True - ) - - # make grad is not Optional[Tensor] - fij = grad[0] - - if fij is not None: - # compute force - pf = torch.zeros(tot_num, 3, dtype=fij.dtype, device=fij.device) - nf = torch.zeros(tot_num, 3, dtype=fij.dtype, device=fij.device) - _edge_src = broadcast(edge_idx[0], fij, 0) - _edge_dst = broadcast(edge_idx[1], fij, 0) - pf.scatter_reduce_(0, _edge_src, fij, reduce='sum') - nf.scatter_reduce_(0, _edge_dst, fij, reduce='sum') - data[self.key_force] = pf - nf - - # compute virial - diag = rij * fij - s12 = rij[..., 0] * fij[..., 1] - s23 = rij[..., 1] * fij[..., 2] - s31 = rij[..., 2] * fij[..., 0] - # cat last dimension - _virial = torch.cat([ - diag, - s12.unsqueeze(-1), - s23.unsqueeze(-1), - s31.unsqueeze(-1) - ], dim=-1) - - _s = torch.zeros(tot_num, 6, dtype=fij.dtype, device=fij.device) - _edge_dst6 = broadcast(edge_idx[1], _virial, 0) - _s.scatter_reduce_(0, _edge_dst6, _virial, reduce='sum') - - if self._is_batch_data: - batch = data[KEY.BATCH] # for deploy, must be defined first - nbatch = int(batch.max().cpu().item()) + 1 - sout = torch.zeros( - (nbatch, 6), dtype=_virial.dtype, device=_virial.device - ) - _batch = broadcast(batch, _s, 0) - sout.scatter_reduce_(0, _batch, _s, reduce='sum') - else: - sout = torch.sum(_s, dim=0) - - data[self.key_stress] =\ - torch.neg(sout) / data[self.key_cell_volume].unsqueeze(-1) - - return data +import torch +import torch.nn as nn +from e3nn.util.jit import compile_mode + +import sevenn._keys as KEY +from sevenn._const import AtomGraphDataType + +from .util import broadcast + + +@compile_mode('script') +class ForceOutput(nn.Module): + """ + works when pos.requires_grad_ is True + """ + + def __init__( + self, + data_key_pos: str = KEY.POS, + data_key_energy: str = KEY.PRED_TOTAL_ENERGY, + data_key_force: str = KEY.PRED_FORCE, + ): + super().__init__() + self.key_pos = data_key_pos + self.key_energy = data_key_energy + self.key_force = data_key_force + + def get_grad_key(self): + return self.key_pos + + def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: + pos_tensor = [data[self.key_pos]] + energy = [(data[self.key_energy]).sum()] + + # `materialize_grads` not supported in low version of pytorch + # Also can not be deployed when using it. + # But not using it makes problem in + # force/stress inference in sparse systems + # TODO: use it only in sevennet_calculator? + grad = torch.autograd.grad( + energy, + pos_tensor, + create_graph=self.training, + allow_unused=True, + # materialize_grads=True, + )[0] + + # For torchscript + if grad is not None: + data[self.key_force] = torch.neg(grad) + return data + + +@compile_mode('script') +class ForceStressOutput(nn.Module): + """ + Compute stress and force from positions. + Used in serial torchscipt models + """ + def __init__( + self, + data_key_pos: str = KEY.POS, + data_key_energy: str = KEY.PRED_TOTAL_ENERGY, + data_key_force: str = KEY.PRED_FORCE, + data_key_stress: str = KEY.PRED_STRESS, + data_key_cell_volume: str = KEY.CELL_VOLUME, + ): + + super().__init__() + self.key_pos = data_key_pos + self.key_energy = data_key_energy + self.key_force = data_key_force + self.key_stress = data_key_stress + self.key_cell_volume = data_key_cell_volume + self._is_batch_data = True + + def get_grad_key(self): + return self.key_pos + + def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: + pos_tensor = data[self.key_pos] + energy = [(data[self.key_energy]).sum()] + + # `materialize_grads` not supported in low version of pytorch + # Also can not be deployed when using it. + # But not using it makes problem in + # force/stress inference in sparse systems + # TODO: use it only in sevennet_calculator? + grad = torch.autograd.grad( + energy, + [pos_tensor, data['_strain']], + create_graph=self.training, + allow_unused=True, + # materialize_grads=True, + ) + + # make grad is not Optional[Tensor] + fgrad = grad[0] + if fgrad is not None: + data[self.key_force] = torch.neg(fgrad) + + sgrad = grad[1] + volume = data[self.key_cell_volume] + vlim = 1e-3 # for cell volume = 0 for non PBC structures + if self._is_batch_data: + volume[volume < vlim] = vlim + elif volume < vlim: + volume = torch.tensor(vlim) + + if sgrad is not None: + if self._is_batch_data: + stress = sgrad / volume.view(-1, 1, 1) + stress = torch.neg(stress) + virial_stress = torch.vstack(( + stress[:, 0, 0], + stress[:, 1, 1], + stress[:, 2, 2], + stress[:, 0, 1], + stress[:, 1, 2], + stress[:, 0, 2], + )) + data[self.key_stress] = virial_stress.transpose(0, 1) + else: + stress = sgrad / volume + stress = torch.neg(stress) + virial_stress = torch.stack(( + stress[0, 0], + stress[1, 1], + stress[2, 2], + stress[0, 1], + stress[1, 2], + stress[0, 2], + )) + data[self.key_stress] = virial_stress + + return data + + +@compile_mode('script') +class ForceStressOutputFromEdge(nn.Module): + """ + Compute stress and force from edge. + Used in parallel torchscipt models, and training + """ + def __init__( + self, + data_key_edge: str = KEY.EDGE_VEC, + data_key_edge_idx: str = KEY.EDGE_IDX, + data_key_energy: str = KEY.PRED_TOTAL_ENERGY, + data_key_force: str = KEY.PRED_FORCE, + data_key_stress: str = KEY.PRED_STRESS, + data_key_cell_volume: str = KEY.CELL_VOLUME, + ): + + super().__init__() + self.key_edge = data_key_edge + self.key_edge_idx = data_key_edge_idx + self.key_energy = data_key_energy + self.key_force = data_key_force + self.key_stress = data_key_stress + self.key_cell_volume = data_key_cell_volume + self._is_batch_data = True + + def get_grad_key(self): + return self.key_edge + + def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: + tot_num = torch.sum(data[KEY.NUM_ATOMS]) # ? item? + rij = data[self.key_edge] + energy = [(data[self.key_energy]).sum()] + edge_idx = data[self.key_edge_idx] + + grad = torch.autograd.grad( + energy, + [rij], + create_graph=self.training, + allow_unused=True + ) + + # make grad is not Optional[Tensor] + fij = grad[0] + + if fij is not None: + # compute force + pf = torch.zeros(tot_num, 3, dtype=fij.dtype, device=fij.device) + nf = torch.zeros(tot_num, 3, dtype=fij.dtype, device=fij.device) + _edge_src = broadcast(edge_idx[0], fij, 0) + _edge_dst = broadcast(edge_idx[1], fij, 0) + pf.scatter_reduce_(0, _edge_src, fij, reduce='sum') + nf.scatter_reduce_(0, _edge_dst, fij, reduce='sum') + data[self.key_force] = pf - nf + + # compute virial + diag = rij * fij + s12 = rij[..., 0] * fij[..., 1] + s23 = rij[..., 1] * fij[..., 2] + s31 = rij[..., 2] * fij[..., 0] + # cat last dimension + _virial = torch.cat([ + diag, + s12.unsqueeze(-1), + s23.unsqueeze(-1), + s31.unsqueeze(-1) + ], dim=-1) + + _s = torch.zeros(tot_num, 6, dtype=fij.dtype, device=fij.device) + _edge_dst6 = broadcast(edge_idx[1], _virial, 0) + _s.scatter_reduce_(0, _edge_dst6, _virial, reduce='sum') + + if self._is_batch_data: + batch = data[KEY.BATCH] # for deploy, must be defined first + nbatch = int(batch.max().cpu().item()) + 1 + sout = torch.zeros( + (nbatch, 6), dtype=_virial.dtype, device=_virial.device + ) + _batch = broadcast(batch, _s, 0) + sout.scatter_reduce_(0, _batch, _s, reduce='sum') + else: + sout = torch.sum(_s, dim=0) + + data[self.key_stress] =\ + torch.neg(sout) / data[self.key_cell_volume].unsqueeze(-1) + + return data diff --git a/mace-bench/3rdparty/SevenNet/sevenn/nn/interaction_blocks.py b/mace-bench/3rdparty/SevenNet/sevenn/nn/interaction_blocks.py index 24caa82..3f93768 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/nn/interaction_blocks.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/nn/interaction_blocks.py @@ -1,76 +1,76 @@ -from typing import Callable, List, Tuple - -from e3nn.o3 import Irreps - -import sevenn._keys as KEY - -from .convolution import IrrepsConvolution -from .equivariant_gate import EquivariantGate -from .linear import IrrepsLinear - - -def NequIP_interaction_block( - irreps_x: Irreps, - irreps_filter: Irreps, - irreps_out_tp: Irreps, - irreps_out: Irreps, - weight_nn_layers: List[int], - conv_denominator: float, - train_conv_denominator: bool, - self_connection_pair: Tuple[Callable, Callable], - act_scalar: Callable, - act_gate: Callable, - act_radial: Callable, - bias_in_linear: bool, - num_species: int, - t: int, # interaction layer index - data_key_x: str = KEY.NODE_FEATURE, - data_key_weight_input: str = KEY.EDGE_EMBEDDING, - parallel: bool = False, - **conv_kwargs, -): - block = {} - irreps_node_attr = Irreps(f'{num_species}x0e') - sc_intro, sc_outro = self_connection_pair - - gate_layer = EquivariantGate(irreps_out, act_scalar, act_gate) - irreps_for_gate_in = gate_layer.get_gate_irreps_in() - - block[f'{t}_self_connection_intro'] = sc_intro( - irreps_x, - irreps_operand=irreps_node_attr, - irreps_out=irreps_for_gate_in, - ) - - block[f'{t}_self_interaction_1'] = IrrepsLinear( - irreps_x, irreps_x, - data_key_in=data_key_x, - biases=bias_in_linear, - ) - - # convolution part, l>lmax is dropped as defined in irreps_out - block[f'{t}_convolution'] = IrrepsConvolution( - irreps_x=irreps_x, - irreps_filter=irreps_filter, - irreps_out=irreps_out_tp, - data_key_weight_input=data_key_weight_input, - weight_layer_input_to_hidden=weight_nn_layers, - weight_layer_act=act_radial, - denominator=conv_denominator, - train_denominator=train_conv_denominator, - is_parallel=parallel, - **conv_kwargs, - ) - - # irreps of x increase to gate_irreps_in - block[f'{t}_self_interaction_2'] = IrrepsLinear( - irreps_out_tp, - irreps_for_gate_in, - data_key_in=data_key_x, - biases=bias_in_linear, - ) - - block[f'{t}_self_connection_outro'] = sc_outro() - block[f'{t}_equivariant_gate'] = gate_layer - - return block +from typing import Callable, List, Tuple + +from e3nn.o3 import Irreps + +import sevenn._keys as KEY + +from .convolution import IrrepsConvolution +from .equivariant_gate import EquivariantGate +from .linear import IrrepsLinear + + +def NequIP_interaction_block( + irreps_x: Irreps, + irreps_filter: Irreps, + irreps_out_tp: Irreps, + irreps_out: Irreps, + weight_nn_layers: List[int], + conv_denominator: float, + train_conv_denominator: bool, + self_connection_pair: Tuple[Callable, Callable], + act_scalar: Callable, + act_gate: Callable, + act_radial: Callable, + bias_in_linear: bool, + num_species: int, + t: int, # interaction layer index + data_key_x: str = KEY.NODE_FEATURE, + data_key_weight_input: str = KEY.EDGE_EMBEDDING, + parallel: bool = False, + **conv_kwargs, +): + block = {} + irreps_node_attr = Irreps(f'{num_species}x0e') + sc_intro, sc_outro = self_connection_pair + + gate_layer = EquivariantGate(irreps_out, act_scalar, act_gate) + irreps_for_gate_in = gate_layer.get_gate_irreps_in() + + block[f'{t}_self_connection_intro'] = sc_intro( + irreps_x, + irreps_operand=irreps_node_attr, + irreps_out=irreps_for_gate_in, + ) + + block[f'{t}_self_interaction_1'] = IrrepsLinear( + irreps_x, irreps_x, + data_key_in=data_key_x, + biases=bias_in_linear, + ) + + # convolution part, l>lmax is dropped as defined in irreps_out + block[f'{t}_convolution'] = IrrepsConvolution( + irreps_x=irreps_x, + irreps_filter=irreps_filter, + irreps_out=irreps_out_tp, + data_key_weight_input=data_key_weight_input, + weight_layer_input_to_hidden=weight_nn_layers, + weight_layer_act=act_radial, + denominator=conv_denominator, + train_denominator=train_conv_denominator, + is_parallel=parallel, + **conv_kwargs, + ) + + # irreps of x increase to gate_irreps_in + block[f'{t}_self_interaction_2'] = IrrepsLinear( + irreps_out_tp, + irreps_for_gate_in, + data_key_in=data_key_x, + biases=bias_in_linear, + ) + + block[f'{t}_self_connection_outro'] = sc_outro() + block[f'{t}_equivariant_gate'] = gate_layer + + return block diff --git a/mace-bench/3rdparty/SevenNet/sevenn/nn/linear.py b/mace-bench/3rdparty/SevenNet/sevenn/nn/linear.py index 43faa29..b5b87d2 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/nn/linear.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/nn/linear.py @@ -1,180 +1,180 @@ -from typing import Callable, List, Optional - -import torch -import torch.nn as nn -from e3nn.nn import FullyConnectedNet -from e3nn.o3 import Irreps, Linear -from e3nn.util.jit import compile_mode - -import sevenn._keys as KEY -from sevenn._const import AtomGraphDataType - - -@compile_mode('script') -class IrrepsLinear(nn.Module): - """ - wrapper class of e3nn Linear to operate on AtomGraphData - """ - - def __init__( - self, - irreps_in: Irreps, - irreps_out: Irreps, - data_key_in: str, - data_key_out: Optional[str] = None, - data_key_modal_attr: str = KEY.MODAL_ATTR, - num_modalities: int = 0, - lazy_layer_instantiate: bool = True, - **linear_kwargs, - ): - super().__init__() - self.key_input = data_key_in - if data_key_out is None: - self.key_output = data_key_in - else: - self.key_output = data_key_out - self.key_modal_attr = data_key_modal_attr - - self._irreps_in_wo_modal = irreps_in - self.irreps_in = irreps_in - self.irreps_out = irreps_out - self.linear_kwargs = linear_kwargs - - self.linear = None - self.layer_instantiated = False - self.num_modalities = num_modalities - self._is_batch_data = True - - # use getter setter - self.linear_cls = Linear - - if num_modalities > 1: # in case of multi-modal - self.set_num_modalities(num_modalities) - - if not lazy_layer_instantiate: - self.instantiate() - - def instantiate(self): - if self.linear is not None: - raise ValueError('Linear layer already exists') - self.linear = self.linear_cls( - self.irreps_in, self.irreps_out, **self.linear_kwargs - ) - self.layer_instantiated = True - - def set_num_modalities(self, num_modalities): - if self.layer_instantiated: - raise ValueError('Layer already instantiated, can not change modalities') - irreps_in = self._irreps_in_wo_modal + Irreps(f'{num_modalities}x0e') - self.num_modalities = num_modalities - self.irreps_in = irreps_in - - def _patch_modal_to_data(self, data: AtomGraphDataType) -> AtomGraphDataType: - if self._is_batch_data: - batch = data[KEY.BATCH] - batch_modality_onehot = data[self.key_modal_attr].reshape( - -1, self.num_modalities - ) - batch_modality_onehot = batch_modality_onehot.type( - data[self.key_input].dtype - ) - data[self.key_input] = torch.cat( - [data[self.key_input], batch_modality_onehot[batch]], dim=1 - ) - else: - modality_onehot = data[self.key_modal_attr].expand( - len(data[self.key_input]), -1 - ) - modality_onehot = modality_onehot.type(data[self.key_input].dtype) - data[self.key_input] = torch.cat( - [data[self.key_input], modality_onehot], dim=1 - ) - return data - - def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: - assert self.linear is not None, 'Layer is not instantiated' - if self.num_modalities > 1: - data = self._patch_modal_to_data(data) - - data[self.key_output] = self.linear(data[self.key_input]) - return data - - -@compile_mode('script') -class AtomReduce(nn.Module): - """ - atomic energy -> total energy - constant is multiplied to data - """ - - def __init__( - self, - data_key_in: str, - data_key_out: str, - reduce: str = 'sum', - constant: float = 1.0, - ): - super().__init__() - - self.key_input = data_key_in - self.key_output = data_key_out - self.constant = constant - self.reduce = reduce - - # controlled by the upper most wrapper 'AtomGraphSequential' - self._is_batch_data = True - - def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: - if self._is_batch_data: - src = data[self.key_input].squeeze(1) - size = int(data[KEY.BATCH].max()) + 1 - output = torch.zeros( - (size), - dtype=src.dtype, - device=src.device, - ) - output.scatter_reduce_(0, data[KEY.BATCH], src, reduce='sum') - data[self.key_output] = output * self.constant - else: - data[self.key_output] = torch.sum(data[self.key_input]) * self.constant - - return data - - -@compile_mode('script') -class FCN_e3nn(nn.Module): - """ - wrapper class of e3nn FullyConnectedNet - """ - - def __init__( - self, - irreps_in: Irreps, # confirm it is scalar & input size - dim_out: int, - hidden_neurons: List[int], - activation: Callable, - data_key_in: str, - data_key_out: Optional[str] = None, - **e3nn_kwargs, - ): - super().__init__() - self.key_input = data_key_in - self.irreps_in = irreps_in - if data_key_out is None: - self.key_output = data_key_in - else: - self.key_output = data_key_out - - for _, irrep in irreps_in: - assert irrep.is_scalar() - inp_dim = irreps_in.dim - - self.fcn = FullyConnectedNet( - [inp_dim] + hidden_neurons + [dim_out], - activation, - **e3nn_kwargs, - ) - - def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: - data[self.key_output] = self.fcn(data[self.key_input]) - return data +from typing import Callable, List, Optional + +import torch +import torch.nn as nn +from e3nn.nn import FullyConnectedNet +from e3nn.o3 import Irreps, Linear +from e3nn.util.jit import compile_mode + +import sevenn._keys as KEY +from sevenn._const import AtomGraphDataType + + +@compile_mode('script') +class IrrepsLinear(nn.Module): + """ + wrapper class of e3nn Linear to operate on AtomGraphData + """ + + def __init__( + self, + irreps_in: Irreps, + irreps_out: Irreps, + data_key_in: str, + data_key_out: Optional[str] = None, + data_key_modal_attr: str = KEY.MODAL_ATTR, + num_modalities: int = 0, + lazy_layer_instantiate: bool = True, + **linear_kwargs, + ): + super().__init__() + self.key_input = data_key_in + if data_key_out is None: + self.key_output = data_key_in + else: + self.key_output = data_key_out + self.key_modal_attr = data_key_modal_attr + + self._irreps_in_wo_modal = irreps_in + self.irreps_in = irreps_in + self.irreps_out = irreps_out + self.linear_kwargs = linear_kwargs + + self.linear = None + self.layer_instantiated = False + self.num_modalities = num_modalities + self._is_batch_data = True + + # use getter setter + self.linear_cls = Linear + + if num_modalities > 1: # in case of multi-modal + self.set_num_modalities(num_modalities) + + if not lazy_layer_instantiate: + self.instantiate() + + def instantiate(self): + if self.linear is not None: + raise ValueError('Linear layer already exists') + self.linear = self.linear_cls( + self.irreps_in, self.irreps_out, **self.linear_kwargs + ) + self.layer_instantiated = True + + def set_num_modalities(self, num_modalities): + if self.layer_instantiated: + raise ValueError('Layer already instantiated, can not change modalities') + irreps_in = self._irreps_in_wo_modal + Irreps(f'{num_modalities}x0e') + self.num_modalities = num_modalities + self.irreps_in = irreps_in + + def _patch_modal_to_data(self, data: AtomGraphDataType) -> AtomGraphDataType: + if self._is_batch_data: + batch = data[KEY.BATCH] + batch_modality_onehot = data[self.key_modal_attr].reshape( + -1, self.num_modalities + ) + batch_modality_onehot = batch_modality_onehot.type( + data[self.key_input].dtype + ) + data[self.key_input] = torch.cat( + [data[self.key_input], batch_modality_onehot[batch]], dim=1 + ) + else: + modality_onehot = data[self.key_modal_attr].expand( + len(data[self.key_input]), -1 + ) + modality_onehot = modality_onehot.type(data[self.key_input].dtype) + data[self.key_input] = torch.cat( + [data[self.key_input], modality_onehot], dim=1 + ) + return data + + def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: + assert self.linear is not None, 'Layer is not instantiated' + if self.num_modalities > 1: + data = self._patch_modal_to_data(data) + + data[self.key_output] = self.linear(data[self.key_input]) + return data + + +@compile_mode('script') +class AtomReduce(nn.Module): + """ + atomic energy -> total energy + constant is multiplied to data + """ + + def __init__( + self, + data_key_in: str, + data_key_out: str, + reduce: str = 'sum', + constant: float = 1.0, + ): + super().__init__() + + self.key_input = data_key_in + self.key_output = data_key_out + self.constant = constant + self.reduce = reduce + + # controlled by the upper most wrapper 'AtomGraphSequential' + self._is_batch_data = True + + def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: + if self._is_batch_data: + src = data[self.key_input].squeeze(1) + size = int(data[KEY.BATCH].max()) + 1 + output = torch.zeros( + (size), + dtype=src.dtype, + device=src.device, + ) + output.scatter_reduce_(0, data[KEY.BATCH], src, reduce='sum') + data[self.key_output] = output * self.constant + else: + data[self.key_output] = torch.sum(data[self.key_input]) * self.constant + + return data + + +@compile_mode('script') +class FCN_e3nn(nn.Module): + """ + wrapper class of e3nn FullyConnectedNet + """ + + def __init__( + self, + irreps_in: Irreps, # confirm it is scalar & input size + dim_out: int, + hidden_neurons: List[int], + activation: Callable, + data_key_in: str, + data_key_out: Optional[str] = None, + **e3nn_kwargs, + ): + super().__init__() + self.key_input = data_key_in + self.irreps_in = irreps_in + if data_key_out is None: + self.key_output = data_key_in + else: + self.key_output = data_key_out + + for _, irrep in irreps_in: + assert irrep.is_scalar() + inp_dim = irreps_in.dim + + self.fcn = FullyConnectedNet( + [inp_dim] + hidden_neurons + [dim_out], + activation, + **e3nn_kwargs, + ) + + def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: + data[self.key_output] = self.fcn(data[self.key_input]) + return data diff --git a/mace-bench/3rdparty/SevenNet/sevenn/nn/node_embedding.py b/mace-bench/3rdparty/SevenNet/sevenn/nn/node_embedding.py index 747d097..a5f272a 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/nn/node_embedding.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/nn/node_embedding.py @@ -1,91 +1,91 @@ -from typing import Dict, List, Optional - -import torch -import torch.nn as nn -import torch.nn.functional -from ase.symbols import symbols2numbers -from e3nn.util.jit import compile_mode - -import sevenn._keys as KEY -from sevenn._const import AtomGraphDataType - - -# TODO: put this to model_build and do not preprocess data by onehot -@compile_mode('script') -class OnehotEmbedding(nn.Module): - """ - x : tensor of shape (N, 1) - x_after : tensor of shape (N, num_classes) - It overwrite data_key_x - and saves input to data_key_save and output to data_key_additional - I know this is strange but it is for compatibility with previous version - and to specie wise shift scale work - ex) [0 1 1 0] -> [[1, 0] [0, 1] [0, 1] [1, 0]] (num_classes = 2) - """ - - def __init__( - self, - num_classes: int, - data_key_x: str = KEY.NODE_FEATURE, - data_key_out: Optional[str] = None, - data_key_save: Optional[str] = None, - data_key_additional: Optional[str] = None, # additional output - ): - super().__init__() - self.num_classes = num_classes - self.key_x = data_key_x - if data_key_out is None: - self.key_output = data_key_x - else: - self.key_output = data_key_out - self.key_save = data_key_save - self.key_additional_output = data_key_additional - - def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: - inp = data[self.key_x] - embd = torch.nn.functional.one_hot(inp, self.num_classes) - embd = embd.float() - data[self.key_output] = embd - if self.key_additional_output is not None: - data[self.key_additional_output] = embd # for self-connection - if self.key_save is not None: - data[self.key_save] = inp # for elemwise shift scale - return data - - -def get_type_mapper_from_specie(specie_list: List[str]): - """ - from ['Hf', 'O'] - return {72: 0, 8: 1} - """ - specie_list = sorted(specie_list) - type_map = {} - unique_counter = 0 - for specie in specie_list: - atomic_num = symbols2numbers(specie)[0] - if atomic_num in type_map: - continue - type_map[atomic_num] = unique_counter - unique_counter += 1 - return type_map - - -# deprecated -def one_hot_atom_embedding( - atomic_numbers: List[int], type_map: Dict[int, int] -): - """ - atomic numbers from ase.get_atomic_numbers - type_map from get_type_mapper_from_specie() - """ - num_classes = len(type_map) - try: - type_numbers = torch.LongTensor( - [type_map[num] for num in atomic_numbers] - ) - except KeyError as e: - raise ValueError(f'Atomic number {e.args[0]} is not expected') - embd = torch.nn.functional.one_hot(type_numbers, num_classes) - embd = embd.to(torch.get_default_dtype()) - - return embd +from typing import Dict, List, Optional + +import torch +import torch.nn as nn +import torch.nn.functional +from ase.symbols import symbols2numbers +from e3nn.util.jit import compile_mode + +import sevenn._keys as KEY +from sevenn._const import AtomGraphDataType + + +# TODO: put this to model_build and do not preprocess data by onehot +@compile_mode('script') +class OnehotEmbedding(nn.Module): + """ + x : tensor of shape (N, 1) + x_after : tensor of shape (N, num_classes) + It overwrite data_key_x + and saves input to data_key_save and output to data_key_additional + I know this is strange but it is for compatibility with previous version + and to specie wise shift scale work + ex) [0 1 1 0] -> [[1, 0] [0, 1] [0, 1] [1, 0]] (num_classes = 2) + """ + + def __init__( + self, + num_classes: int, + data_key_x: str = KEY.NODE_FEATURE, + data_key_out: Optional[str] = None, + data_key_save: Optional[str] = None, + data_key_additional: Optional[str] = None, # additional output + ): + super().__init__() + self.num_classes = num_classes + self.key_x = data_key_x + if data_key_out is None: + self.key_output = data_key_x + else: + self.key_output = data_key_out + self.key_save = data_key_save + self.key_additional_output = data_key_additional + + def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: + inp = data[self.key_x] + embd = torch.nn.functional.one_hot(inp, self.num_classes) + embd = embd.float() + data[self.key_output] = embd + if self.key_additional_output is not None: + data[self.key_additional_output] = embd # for self-connection + if self.key_save is not None: + data[self.key_save] = inp # for elemwise shift scale + return data + + +def get_type_mapper_from_specie(specie_list: List[str]): + """ + from ['Hf', 'O'] + return {72: 0, 8: 1} + """ + specie_list = sorted(specie_list) + type_map = {} + unique_counter = 0 + for specie in specie_list: + atomic_num = symbols2numbers(specie)[0] + if atomic_num in type_map: + continue + type_map[atomic_num] = unique_counter + unique_counter += 1 + return type_map + + +# deprecated +def one_hot_atom_embedding( + atomic_numbers: List[int], type_map: Dict[int, int] +): + """ + atomic numbers from ase.get_atomic_numbers + type_map from get_type_mapper_from_specie() + """ + num_classes = len(type_map) + try: + type_numbers = torch.LongTensor( + [type_map[num] for num in atomic_numbers] + ) + except KeyError as e: + raise ValueError(f'Atomic number {e.args[0]} is not expected') + embd = torch.nn.functional.one_hot(type_numbers, num_classes) + embd = embd.to(torch.get_default_dtype()) + + return embd diff --git a/mace-bench/3rdparty/SevenNet/sevenn/nn/scale.py b/mace-bench/3rdparty/SevenNet/sevenn/nn/scale.py index f593770..73da835 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/nn/scale.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/nn/scale.py @@ -1,387 +1,387 @@ -from typing import Any, Dict, List, Optional, Union - -import torch -import torch.nn as nn -from e3nn.util.jit import compile_mode - -import sevenn._keys as KEY -from sevenn._const import NUM_UNIV_ELEMENT, AtomGraphDataType - - -def _as_univ( - ss: List[float], type_map: Dict[int, int], default: float -) -> List[float]: - assert len(ss) <= NUM_UNIV_ELEMENT, 'shift scale is too long' - return [ - ss[type_map[z]] if z in type_map else default - for z in range(NUM_UNIV_ELEMENT) - ] - - -@compile_mode('script') -class Rescale(nn.Module): - """ - Scaling and shifting energy (and automatically force and stress) - """ - - def __init__( - self, - shift: float, - scale: float, - data_key_in: str = KEY.SCALED_ATOMIC_ENERGY, - data_key_out: str = KEY.ATOMIC_ENERGY, - train_shift_scale: bool = False, - **kwargs, - ): - assert isinstance(shift, float) and isinstance(scale, float) - super().__init__() - self.shift = nn.Parameter( - torch.FloatTensor([shift]), requires_grad=train_shift_scale - ) - self.scale = nn.Parameter( - torch.FloatTensor([scale]), requires_grad=train_shift_scale - ) - self.key_input = data_key_in - self.key_output = data_key_out - - def get_shift(self) -> float: - return self.shift.detach().cpu().tolist()[0] - - def get_scale(self) -> float: - return self.scale.detach().cpu().tolist()[0] - - def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: - data[self.key_output] = data[self.key_input] * self.scale + self.shift - - return data - - -@compile_mode('script') -class SpeciesWiseRescale(nn.Module): - """ - Scaling and shifting energy (and automatically force and stress) - Use as it is if given list, expand to list if one of them is float - If two lists are given and length is not the same, raise error - """ - - def __init__( - self, - shift: Union[List[float], float], - scale: Union[List[float], float], - data_key_in: str = KEY.SCALED_ATOMIC_ENERGY, - data_key_out: str = KEY.ATOMIC_ENERGY, - data_key_indices: str = KEY.ATOM_TYPE, - train_shift_scale: bool = False, - ): - super().__init__() - assert isinstance(shift, float) or isinstance(shift, list) - assert isinstance(scale, float) or isinstance(scale, list) - - if ( - isinstance(shift, list) - and isinstance(scale, list) - and len(shift) != len(scale) - ): - raise ValueError('List length should be same') - - if isinstance(shift, list): - num_species = len(shift) - elif isinstance(scale, list): - num_species = len(scale) - else: - raise ValueError('Both shift and scale is not a list') - - shift = [shift] * num_species if isinstance(shift, float) else shift - scale = [scale] * num_species if isinstance(scale, float) else scale - - self.shift = nn.Parameter( - torch.FloatTensor(shift), requires_grad=train_shift_scale - ) - self.scale = nn.Parameter( - torch.FloatTensor(scale), requires_grad=train_shift_scale - ) - self.key_input = data_key_in - self.key_output = data_key_out - self.key_indices = data_key_indices - - def get_shift(self, type_map: Optional[Dict[int, int]] = None) -> List[float]: - """ - Return shift in list of float. If type_map is given, return type_map reversed - shift, which index equals atomic_number. 0.0 is assigned for atomis not found - """ - shift = self.shift.detach().cpu().tolist() - if type_map: - shift = _as_univ(shift, type_map, 0.0) - return shift - - def get_scale(self, type_map: Optional[Dict[int, int]] = None) -> List[float]: - """ - Return scale in list of float. If type_map is given, return type_map reversed - scale, which index equals atomic_number. 1.0 is assigned for atomis not found - """ - scale = self.scale.detach().cpu().tolist() - if type_map: - scale = _as_univ(scale, type_map, 1.0) - return scale - - @staticmethod - def from_mappers( - shift: Union[float, List[float]], - scale: Union[float, List[float]], - type_map: Dict[int, int], - **kwargs, - ): - """ - Fit dimensions or mapping raw shift scale values to that is valid under - the given type_map: (atomic_numbers -> type_indices) - """ - shift_scale = [] - n_atom_types = len(type_map) - for s in (shift, scale): - if isinstance(s, list) and len(s) > n_atom_types: - if len(s) != NUM_UNIV_ELEMENT: - raise ValueError('given shift or scale is strange') - s = [s[z] for z in sorted(type_map, key=lambda x: type_map[x])] - # s = [s[z] for z in sorted(type_map, key=type_map.get)] - elif isinstance(s, float): - s = [s] * n_atom_types - elif isinstance(s, list) and len(s) == 1: - s = s * n_atom_types - shift_scale.append(s) - assert all([len(s) == n_atom_types for s in shift_scale]) - shift, scale = shift_scale - return SpeciesWiseRescale(shift, scale, **kwargs) - - def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: - indices = data[self.key_indices] - data[self.key_output] = data[self.key_input] * self.scale[indices].view( - -1, 1 - ) + self.shift[indices].view(-1, 1) - - return data - - -@compile_mode('script') -class ModalWiseRescale(nn.Module): - """ - Scaling and shifting energy (and automatically force and stress) - Given shift or scale is either modal-wise and atom-wise or - not modal-wise but atom-wise. It is always interpreted as atom-wise. - """ - - def __init__( - self, - shift: List[List[float]], - scale: List[List[float]], - data_key_in: str = KEY.SCALED_ATOMIC_ENERGY, - data_key_out: str = KEY.ATOMIC_ENERGY, - data_key_modal_indices: str = KEY.MODAL_TYPE, - data_key_atom_indices: str = KEY.ATOM_TYPE, - use_modal_wise_shift: bool = False, - use_modal_wise_scale: bool = False, - train_shift_scale: bool = False, - ): - super().__init__() - self.shift = nn.Parameter( - torch.FloatTensor(shift), requires_grad=train_shift_scale - ) - self.scale = nn.Parameter( - torch.FloatTensor(scale), requires_grad=train_shift_scale - ) - self.key_input = data_key_in - self.key_output = data_key_out - self.key_atom_indices = data_key_atom_indices - self.key_modal_indices = data_key_modal_indices - self.use_modal_wise_shift = use_modal_wise_shift - self.use_modal_wise_scale = use_modal_wise_scale - self._is_batch_data = True - - def get_shift( - self, - type_map: Optional[Dict[int, int]] = None, - modal_map: Optional[Dict[str, int]] = None, - ) -> Union[List[float], Dict[str, List[float]]]: - """ - Nothing is given: return as it is - type_map is given but not modal wise shift: return univ shift - both type_map and modal_map is given and modal wise shift: return fully - resolved modalwise univ shift - """ - shift = self.shift.detach().cpu().tolist() - if type_map and not self.use_modal_wise_shift: - shift = _as_univ(shift, type_map, 0.0) - elif self.use_modal_wise_shift and modal_map and type_map: - shift = [_as_univ(s, type_map, 0.0) for s in shift] - shift = {modal: shift[idx] for modal, idx in modal_map.items()} - - return shift - - def get_scale( - self, - type_map: Optional[Dict[int, int]] = None, - modal_map: Optional[Dict[str, int]] = None, - ) -> Union[List[float], Dict[str, List[float]]]: - """ - Nothing is given: return as it is - type_map is given but not modal wise scale: return univ scale - both type_map and modal_map is given and modal wise scale: return fully - resolved modalwise univ scale - """ - scale = self.scale.detach().cpu().tolist() - if type_map and not self.use_modal_wise_scale: - scale = _as_univ(scale, type_map, 0.0) - elif self.use_modal_wise_scale and modal_map and type_map: - scale = [_as_univ(s, type_map, 0.0) for s in scale] - scale = {modal: scale[idx] for modal, idx in modal_map.items()} - return scale - - @staticmethod - def from_mappers( - shift: Union[float, List[float], Dict[str, Any]], - scale: Union[float, List[float], Dict[str, Any]], - use_modal_wise_shift: bool, - use_modal_wise_scale: bool, - type_map: Dict[int, int], - modal_map: Dict[str, int], - **kwargs, - ): - """ - Fit dimensions or mapping raw shift scale values to that is valid under - the given type_map: (atomic_numbers -> type_indices) - If given List[float] and its length matches length of _const.NUM_UNIV_ELEMENT - , assume it is element-wise list - otherwise, it is modal-wise list - """ - - def solve_mapper(arr, map): - # value is attr index and never overlap, key is either 'z' or modal str - return [arr[z] for z in sorted(map, key=lambda x: map[x])] - - shift_scale = [] - n_atom_types = len(type_map) - n_modals = len(modal_map) - - for s, use_mw in ( - (shift, use_modal_wise_shift), - (scale, use_modal_wise_scale), - ): - # solve elemewise, or broadcast - if isinstance(s, float): - # given, modal-wise: no, elem-wise: no => broadcast - shape = (n_modals, n_atom_types) if use_mw else (n_atom_types,) - res = torch.full(shape, s).tolist() # TODO: w/o torch - elif isinstance(s, list) and len(s) == NUM_UNIV_ELEMENT: - # given, modal-wise: no, elem-wise: yes(univ) => solve elem map - s = solve_mapper(s, type_map) - res = [s] * n_modals if use_mw else s - elif ( # given, modal-wise: yes, elem-wise: no => broadcast to elemwise - isinstance(s, list) - and isinstance(s[0], float) - and len(s) == n_modals - and use_mw - ): - res = [[v] * n_atom_types for v in s] - elif ( # given, modal-wise: no, elem-wise: yes => as it is - isinstance(s, list) - and isinstance(s[0], float) - and len(s) == n_atom_types - and not use_mw - ): - res = s - elif ( # given, modal-wise: yes, elem-wise: yes => as it is - isinstance(s, list) - and isinstance(s[0], list) - and len(s) == n_modals - and len(s[0]) == n_atom_types - and use_mw - ): - res = s - elif isinstance(s, dict) and use_mw: - # solve modal dict, modal-wise: yes - s = solve_mapper(s, modal_map) - res = [] - for v in s: - if isinstance(v, list) and len(v) == NUM_UNIV_ELEMENT: - # elem-wise: yes(univ) => solve elem map - v = solve_mapper(v, type_map) - elif isinstance(v, float): - # elem-wise: no => broadcast to elemwise - v = [v] * n_atom_types - else: - raise ValueError(f'Invalid shift or scale {s}') - res.append(v) - else: - raise ValueError(f'Invalid shift or scale {s}') - - if use_mw: - assert ( - isinstance(res, list) - and isinstance(res[0], list) - and len(res) == n_modals - ) - assert all([len(r) == n_atom_types for r in res]) # type: ignore - else: - assert ( - isinstance(res, list) - and isinstance(res[0], float) - and len(res) == n_atom_types - ) - shift_scale.append(res) - shift, scale = shift_scale - - return ModalWiseRescale( - shift, - scale, - use_modal_wise_shift=use_modal_wise_shift, - use_modal_wise_scale=use_modal_wise_scale, - **kwargs, - ) - - def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: - if self._is_batch_data: - batch = data[KEY.BATCH] - modal_indices = data[self.key_modal_indices][batch] - else: - modal_indices = data[self.key_modal_indices] - atom_indices = data[self.key_atom_indices] - shift = ( - self.shift[modal_indices, atom_indices] - if self.use_modal_wise_shift - else self.shift[atom_indices] - ) - scale = ( - self.scale[modal_indices, atom_indices] - if self.use_modal_wise_scale - else self.scale[atom_indices] - ) - data[self.key_output] = data[self.key_input] * scale.view( - -1, 1 - ) + shift.view(-1, 1) - - return data - - -def get_resolved_shift_scale( - module: Union[Rescale, SpeciesWiseRescale, ModalWiseRescale], - type_map: Optional[Dict[int, int]] = None, - modal_map: Optional[Dict[str, int]] = None, -): - """ - Return resolved shift and scale from scale modules. For element wise case, - convert to list of floats where idx is atomic number. For modal wise case, return - dictionary of shift scale where key is modal name given in modal_map - - Return: - Tuple of solved shift and scale - """ - - if isinstance(module, Rescale): - return (module.get_shift(), module.get_scale()) - elif isinstance(module, SpeciesWiseRescale): - return (module.get_shift(type_map), module.get_scale(type_map)) - elif isinstance(module, ModalWiseRescale): - return ( - module.get_shift(type_map, modal_map), - module.get_scale(type_map, modal_map), - ) - raise ValueError('Not scale module') +from typing import Any, Dict, List, Optional, Union + +import torch +import torch.nn as nn +from e3nn.util.jit import compile_mode + +import sevenn._keys as KEY +from sevenn._const import NUM_UNIV_ELEMENT, AtomGraphDataType + + +def _as_univ( + ss: List[float], type_map: Dict[int, int], default: float +) -> List[float]: + assert len(ss) <= NUM_UNIV_ELEMENT, 'shift scale is too long' + return [ + ss[type_map[z]] if z in type_map else default + for z in range(NUM_UNIV_ELEMENT) + ] + + +@compile_mode('script') +class Rescale(nn.Module): + """ + Scaling and shifting energy (and automatically force and stress) + """ + + def __init__( + self, + shift: float, + scale: float, + data_key_in: str = KEY.SCALED_ATOMIC_ENERGY, + data_key_out: str = KEY.ATOMIC_ENERGY, + train_shift_scale: bool = False, + **kwargs, + ): + assert isinstance(shift, float) and isinstance(scale, float) + super().__init__() + self.shift = nn.Parameter( + torch.FloatTensor([shift]), requires_grad=train_shift_scale + ) + self.scale = nn.Parameter( + torch.FloatTensor([scale]), requires_grad=train_shift_scale + ) + self.key_input = data_key_in + self.key_output = data_key_out + + def get_shift(self) -> float: + return self.shift.detach().cpu().tolist()[0] + + def get_scale(self) -> float: + return self.scale.detach().cpu().tolist()[0] + + def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: + data[self.key_output] = data[self.key_input] * self.scale + self.shift + + return data + + +@compile_mode('script') +class SpeciesWiseRescale(nn.Module): + """ + Scaling and shifting energy (and automatically force and stress) + Use as it is if given list, expand to list if one of them is float + If two lists are given and length is not the same, raise error + """ + + def __init__( + self, + shift: Union[List[float], float], + scale: Union[List[float], float], + data_key_in: str = KEY.SCALED_ATOMIC_ENERGY, + data_key_out: str = KEY.ATOMIC_ENERGY, + data_key_indices: str = KEY.ATOM_TYPE, + train_shift_scale: bool = False, + ): + super().__init__() + assert isinstance(shift, float) or isinstance(shift, list) + assert isinstance(scale, float) or isinstance(scale, list) + + if ( + isinstance(shift, list) + and isinstance(scale, list) + and len(shift) != len(scale) + ): + raise ValueError('List length should be same') + + if isinstance(shift, list): + num_species = len(shift) + elif isinstance(scale, list): + num_species = len(scale) + else: + raise ValueError('Both shift and scale is not a list') + + shift = [shift] * num_species if isinstance(shift, float) else shift + scale = [scale] * num_species if isinstance(scale, float) else scale + + self.shift = nn.Parameter( + torch.FloatTensor(shift), requires_grad=train_shift_scale + ) + self.scale = nn.Parameter( + torch.FloatTensor(scale), requires_grad=train_shift_scale + ) + self.key_input = data_key_in + self.key_output = data_key_out + self.key_indices = data_key_indices + + def get_shift(self, type_map: Optional[Dict[int, int]] = None) -> List[float]: + """ + Return shift in list of float. If type_map is given, return type_map reversed + shift, which index equals atomic_number. 0.0 is assigned for atomis not found + """ + shift = self.shift.detach().cpu().tolist() + if type_map: + shift = _as_univ(shift, type_map, 0.0) + return shift + + def get_scale(self, type_map: Optional[Dict[int, int]] = None) -> List[float]: + """ + Return scale in list of float. If type_map is given, return type_map reversed + scale, which index equals atomic_number. 1.0 is assigned for atomis not found + """ + scale = self.scale.detach().cpu().tolist() + if type_map: + scale = _as_univ(scale, type_map, 1.0) + return scale + + @staticmethod + def from_mappers( + shift: Union[float, List[float]], + scale: Union[float, List[float]], + type_map: Dict[int, int], + **kwargs, + ): + """ + Fit dimensions or mapping raw shift scale values to that is valid under + the given type_map: (atomic_numbers -> type_indices) + """ + shift_scale = [] + n_atom_types = len(type_map) + for s in (shift, scale): + if isinstance(s, list) and len(s) > n_atom_types: + if len(s) != NUM_UNIV_ELEMENT: + raise ValueError('given shift or scale is strange') + s = [s[z] for z in sorted(type_map, key=lambda x: type_map[x])] + # s = [s[z] for z in sorted(type_map, key=type_map.get)] + elif isinstance(s, float): + s = [s] * n_atom_types + elif isinstance(s, list) and len(s) == 1: + s = s * n_atom_types + shift_scale.append(s) + assert all([len(s) == n_atom_types for s in shift_scale]) + shift, scale = shift_scale + return SpeciesWiseRescale(shift, scale, **kwargs) + + def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: + indices = data[self.key_indices] + data[self.key_output] = data[self.key_input] * self.scale[indices].view( + -1, 1 + ) + self.shift[indices].view(-1, 1) + + return data + + +@compile_mode('script') +class ModalWiseRescale(nn.Module): + """ + Scaling and shifting energy (and automatically force and stress) + Given shift or scale is either modal-wise and atom-wise or + not modal-wise but atom-wise. It is always interpreted as atom-wise. + """ + + def __init__( + self, + shift: List[List[float]], + scale: List[List[float]], + data_key_in: str = KEY.SCALED_ATOMIC_ENERGY, + data_key_out: str = KEY.ATOMIC_ENERGY, + data_key_modal_indices: str = KEY.MODAL_TYPE, + data_key_atom_indices: str = KEY.ATOM_TYPE, + use_modal_wise_shift: bool = False, + use_modal_wise_scale: bool = False, + train_shift_scale: bool = False, + ): + super().__init__() + self.shift = nn.Parameter( + torch.FloatTensor(shift), requires_grad=train_shift_scale + ) + self.scale = nn.Parameter( + torch.FloatTensor(scale), requires_grad=train_shift_scale + ) + self.key_input = data_key_in + self.key_output = data_key_out + self.key_atom_indices = data_key_atom_indices + self.key_modal_indices = data_key_modal_indices + self.use_modal_wise_shift = use_modal_wise_shift + self.use_modal_wise_scale = use_modal_wise_scale + self._is_batch_data = True + + def get_shift( + self, + type_map: Optional[Dict[int, int]] = None, + modal_map: Optional[Dict[str, int]] = None, + ) -> Union[List[float], Dict[str, List[float]]]: + """ + Nothing is given: return as it is + type_map is given but not modal wise shift: return univ shift + both type_map and modal_map is given and modal wise shift: return fully + resolved modalwise univ shift + """ + shift = self.shift.detach().cpu().tolist() + if type_map and not self.use_modal_wise_shift: + shift = _as_univ(shift, type_map, 0.0) + elif self.use_modal_wise_shift and modal_map and type_map: + shift = [_as_univ(s, type_map, 0.0) for s in shift] + shift = {modal: shift[idx] for modal, idx in modal_map.items()} + + return shift + + def get_scale( + self, + type_map: Optional[Dict[int, int]] = None, + modal_map: Optional[Dict[str, int]] = None, + ) -> Union[List[float], Dict[str, List[float]]]: + """ + Nothing is given: return as it is + type_map is given but not modal wise scale: return univ scale + both type_map and modal_map is given and modal wise scale: return fully + resolved modalwise univ scale + """ + scale = self.scale.detach().cpu().tolist() + if type_map and not self.use_modal_wise_scale: + scale = _as_univ(scale, type_map, 0.0) + elif self.use_modal_wise_scale and modal_map and type_map: + scale = [_as_univ(s, type_map, 0.0) for s in scale] + scale = {modal: scale[idx] for modal, idx in modal_map.items()} + return scale + + @staticmethod + def from_mappers( + shift: Union[float, List[float], Dict[str, Any]], + scale: Union[float, List[float], Dict[str, Any]], + use_modal_wise_shift: bool, + use_modal_wise_scale: bool, + type_map: Dict[int, int], + modal_map: Dict[str, int], + **kwargs, + ): + """ + Fit dimensions or mapping raw shift scale values to that is valid under + the given type_map: (atomic_numbers -> type_indices) + If given List[float] and its length matches length of _const.NUM_UNIV_ELEMENT + , assume it is element-wise list + otherwise, it is modal-wise list + """ + + def solve_mapper(arr, map): + # value is attr index and never overlap, key is either 'z' or modal str + return [arr[z] for z in sorted(map, key=lambda x: map[x])] + + shift_scale = [] + n_atom_types = len(type_map) + n_modals = len(modal_map) + + for s, use_mw in ( + (shift, use_modal_wise_shift), + (scale, use_modal_wise_scale), + ): + # solve elemewise, or broadcast + if isinstance(s, float): + # given, modal-wise: no, elem-wise: no => broadcast + shape = (n_modals, n_atom_types) if use_mw else (n_atom_types,) + res = torch.full(shape, s).tolist() # TODO: w/o torch + elif isinstance(s, list) and len(s) == NUM_UNIV_ELEMENT: + # given, modal-wise: no, elem-wise: yes(univ) => solve elem map + s = solve_mapper(s, type_map) + res = [s] * n_modals if use_mw else s + elif ( # given, modal-wise: yes, elem-wise: no => broadcast to elemwise + isinstance(s, list) + and isinstance(s[0], float) + and len(s) == n_modals + and use_mw + ): + res = [[v] * n_atom_types for v in s] + elif ( # given, modal-wise: no, elem-wise: yes => as it is + isinstance(s, list) + and isinstance(s[0], float) + and len(s) == n_atom_types + and not use_mw + ): + res = s + elif ( # given, modal-wise: yes, elem-wise: yes => as it is + isinstance(s, list) + and isinstance(s[0], list) + and len(s) == n_modals + and len(s[0]) == n_atom_types + and use_mw + ): + res = s + elif isinstance(s, dict) and use_mw: + # solve modal dict, modal-wise: yes + s = solve_mapper(s, modal_map) + res = [] + for v in s: + if isinstance(v, list) and len(v) == NUM_UNIV_ELEMENT: + # elem-wise: yes(univ) => solve elem map + v = solve_mapper(v, type_map) + elif isinstance(v, float): + # elem-wise: no => broadcast to elemwise + v = [v] * n_atom_types + else: + raise ValueError(f'Invalid shift or scale {s}') + res.append(v) + else: + raise ValueError(f'Invalid shift or scale {s}') + + if use_mw: + assert ( + isinstance(res, list) + and isinstance(res[0], list) + and len(res) == n_modals + ) + assert all([len(r) == n_atom_types for r in res]) # type: ignore + else: + assert ( + isinstance(res, list) + and isinstance(res[0], float) + and len(res) == n_atom_types + ) + shift_scale.append(res) + shift, scale = shift_scale + + return ModalWiseRescale( + shift, + scale, + use_modal_wise_shift=use_modal_wise_shift, + use_modal_wise_scale=use_modal_wise_scale, + **kwargs, + ) + + def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: + if self._is_batch_data: + batch = data[KEY.BATCH] + modal_indices = data[self.key_modal_indices][batch] + else: + modal_indices = data[self.key_modal_indices] + atom_indices = data[self.key_atom_indices] + shift = ( + self.shift[modal_indices, atom_indices] + if self.use_modal_wise_shift + else self.shift[atom_indices] + ) + scale = ( + self.scale[modal_indices, atom_indices] + if self.use_modal_wise_scale + else self.scale[atom_indices] + ) + data[self.key_output] = data[self.key_input] * scale.view( + -1, 1 + ) + shift.view(-1, 1) + + return data + + +def get_resolved_shift_scale( + module: Union[Rescale, SpeciesWiseRescale, ModalWiseRescale], + type_map: Optional[Dict[int, int]] = None, + modal_map: Optional[Dict[str, int]] = None, +): + """ + Return resolved shift and scale from scale modules. For element wise case, + convert to list of floats where idx is atomic number. For modal wise case, return + dictionary of shift scale where key is modal name given in modal_map + + Return: + Tuple of solved shift and scale + """ + + if isinstance(module, Rescale): + return (module.get_shift(), module.get_scale()) + elif isinstance(module, SpeciesWiseRescale): + return (module.get_shift(type_map), module.get_scale(type_map)) + elif isinstance(module, ModalWiseRescale): + return ( + module.get_shift(type_map, modal_map), + module.get_scale(type_map, modal_map), + ) + raise ValueError('Not scale module') diff --git a/mace-bench/3rdparty/SevenNet/sevenn/nn/self_connection.py b/mace-bench/3rdparty/SevenNet/sevenn/nn/self_connection.py index 40b55a2..ce731b5 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/nn/self_connection.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/nn/self_connection.py @@ -1,128 +1,128 @@ -import torch.nn as nn -from e3nn.o3 import FullyConnectedTensorProduct, Irreps, Linear -from e3nn.util.jit import compile_mode - -import sevenn._keys as KEY -from sevenn._const import AtomGraphDataType - - -@compile_mode('script') -class SelfConnectionIntro(nn.Module): - """ - do TensorProduct of x and some data(here attribute of x) - and save it (to concatenate updated x at SelfConnectionOutro) - """ - - def __init__( - self, - irreps_in: Irreps, - irreps_operand: Irreps, - irreps_out: Irreps, - data_key_x: str = KEY.NODE_FEATURE, - data_key_operand: str = KEY.NODE_ATTR, - lazy_layer_instantiate: bool = True, - **kwargs, # for compatibility - ): - super().__init__() - - self.fc_tensor_product = FullyConnectedTensorProduct( - irreps_in, irreps_operand, irreps_out - ) - self.irreps_in1 = irreps_in - self.irreps_in2 = irreps_operand - self.irreps_out = irreps_out - - self.key_x = data_key_x - self.key_operand = data_key_operand - - self.fc_tensor_product = None - self.layer_instantiated = False - self.fc_tensor_product_cls = FullyConnectedTensorProduct - self.fc_tensor_product_kwargs = kwargs - - if not lazy_layer_instantiate: - self.instantiate() - - def instantiate(self): - if self.fc_tensor_product is not None: - raise ValueError('fc_tensor_product layer already exists') - self.fc_tensor_product = self.fc_tensor_product_cls( - self.irreps_in1, - self.irreps_in2, - self.irreps_out, - shared_weights=True, - internal_weights=None, # same as True - **self.fc_tensor_product_kwargs, - ) - self.layer_instantiated = True - - def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: - assert self.fc_tensor_product is not None, 'Layer is not instantiated' - data[KEY.SELF_CONNECTION_TEMP] = self.fc_tensor_product( - data[self.key_x], data[self.key_operand] - ) - return data - - -@compile_mode('script') -class SelfConnectionLinearIntro(nn.Module): - """ - Linear style self connection update - """ - - def __init__( - self, - irreps_in: Irreps, - irreps_out: Irreps, - data_key_x: str = KEY.NODE_FEATURE, - lazy_layer_instantiate: bool = True, - **kwargs, - ): - super().__init__() - self.irreps_in = irreps_in - self.irreps_out = irreps_out - self.key_x = data_key_x - - self.linear = None - self.layer_instantiated = False - self.linear_cls = Linear - - # TODO: better to have SelfConnectionIntro super class - kwargs.pop('irreps_operand') - self.linear_kwargs = kwargs - - if not lazy_layer_instantiate: - self.instantiate() - - def instantiate(self): - if self.linear is not None: - raise ValueError('Linear layer already exists') - self.linear = self.linear_cls( - self.irreps_in, self.irreps_out, **self.linear_kwargs - ) - self.layer_instantiated = True - - def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: - assert self.linear is not None, 'Layer is not instantiated' - data[KEY.SELF_CONNECTION_TEMP] = self.linear(data[self.key_x]) - return data - - -@compile_mode('script') -class SelfConnectionOutro(nn.Module): - """ - do TensorProduct of x and some data(here attribute of x) - and save it (to concatenate updated x at SelfConnectionOutro) - """ - - def __init__( - self, - data_key_x: str = KEY.NODE_FEATURE, - ): - super().__init__() - self.key_x = data_key_x - - def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: - data[self.key_x] = data[self.key_x] + data[KEY.SELF_CONNECTION_TEMP] - del data[KEY.SELF_CONNECTION_TEMP] - return data +import torch.nn as nn +from e3nn.o3 import FullyConnectedTensorProduct, Irreps, Linear +from e3nn.util.jit import compile_mode + +import sevenn._keys as KEY +from sevenn._const import AtomGraphDataType + + +@compile_mode('script') +class SelfConnectionIntro(nn.Module): + """ + do TensorProduct of x and some data(here attribute of x) + and save it (to concatenate updated x at SelfConnectionOutro) + """ + + def __init__( + self, + irreps_in: Irreps, + irreps_operand: Irreps, + irreps_out: Irreps, + data_key_x: str = KEY.NODE_FEATURE, + data_key_operand: str = KEY.NODE_ATTR, + lazy_layer_instantiate: bool = True, + **kwargs, # for compatibility + ): + super().__init__() + + self.fc_tensor_product = FullyConnectedTensorProduct( + irreps_in, irreps_operand, irreps_out + ) + self.irreps_in1 = irreps_in + self.irreps_in2 = irreps_operand + self.irreps_out = irreps_out + + self.key_x = data_key_x + self.key_operand = data_key_operand + + self.fc_tensor_product = None + self.layer_instantiated = False + self.fc_tensor_product_cls = FullyConnectedTensorProduct + self.fc_tensor_product_kwargs = kwargs + + if not lazy_layer_instantiate: + self.instantiate() + + def instantiate(self): + if self.fc_tensor_product is not None: + raise ValueError('fc_tensor_product layer already exists') + self.fc_tensor_product = self.fc_tensor_product_cls( + self.irreps_in1, + self.irreps_in2, + self.irreps_out, + shared_weights=True, + internal_weights=None, # same as True + **self.fc_tensor_product_kwargs, + ) + self.layer_instantiated = True + + def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: + assert self.fc_tensor_product is not None, 'Layer is not instantiated' + data[KEY.SELF_CONNECTION_TEMP] = self.fc_tensor_product( + data[self.key_x], data[self.key_operand] + ) + return data + + +@compile_mode('script') +class SelfConnectionLinearIntro(nn.Module): + """ + Linear style self connection update + """ + + def __init__( + self, + irreps_in: Irreps, + irreps_out: Irreps, + data_key_x: str = KEY.NODE_FEATURE, + lazy_layer_instantiate: bool = True, + **kwargs, + ): + super().__init__() + self.irreps_in = irreps_in + self.irreps_out = irreps_out + self.key_x = data_key_x + + self.linear = None + self.layer_instantiated = False + self.linear_cls = Linear + + # TODO: better to have SelfConnectionIntro super class + kwargs.pop('irreps_operand') + self.linear_kwargs = kwargs + + if not lazy_layer_instantiate: + self.instantiate() + + def instantiate(self): + if self.linear is not None: + raise ValueError('Linear layer already exists') + self.linear = self.linear_cls( + self.irreps_in, self.irreps_out, **self.linear_kwargs + ) + self.layer_instantiated = True + + def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: + assert self.linear is not None, 'Layer is not instantiated' + data[KEY.SELF_CONNECTION_TEMP] = self.linear(data[self.key_x]) + return data + + +@compile_mode('script') +class SelfConnectionOutro(nn.Module): + """ + do TensorProduct of x and some data(here attribute of x) + and save it (to concatenate updated x at SelfConnectionOutro) + """ + + def __init__( + self, + data_key_x: str = KEY.NODE_FEATURE, + ): + super().__init__() + self.key_x = data_key_x + + def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: + data[self.key_x] = data[self.key_x] + data[KEY.SELF_CONNECTION_TEMP] + del data[KEY.SELF_CONNECTION_TEMP] + return data diff --git a/mace-bench/3rdparty/SevenNet/sevenn/nn/sequential.py b/mace-bench/3rdparty/SevenNet/sevenn/nn/sequential.py index 1a91ae6..c300814 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/nn/sequential.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/nn/sequential.py @@ -1,183 +1,183 @@ -import warnings -from collections import OrderedDict -from typing import Dict, Optional - -import torch -import torch.nn as nn -from e3nn.util.jit import compile_mode - -import sevenn._keys as KEY -from sevenn._const import AtomGraphDataType - - -def _instantiate_modules(modules): - # see IrrepsLinear of linear.py - for module in modules.values(): - if not getattr(module, 'layer_instantiated', True): - module.instantiate() - - -@compile_mode('script') -class _ModalInputPrepare(nn.Module): - - def __init__( - self, - modal_idx: int - ): - super().__init__() - self.modal_idx = modal_idx - - def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: - data[KEY.MODAL_TYPE] = torch.tensor( - self.modal_idx, - dtype=torch.int64, - device=data['x'].device, - ) - return data - - -@compile_mode('script') -class AtomGraphSequential(nn.Sequential): - """ - Wrapper of SevenNet model - - Args: - modules: OrderedDict of nn.Modules - cutoff: not used internally, but makes sense to have - type_map: atomic_numbers => onehot index (see nn/node_embedding.py) - eval_type_map: perform index mapping using type_map defaults to True - data_key_atomic_numbers: used when eval_type_map is True - data_key_node_feature: used when eval_type_map is True - data_key_grad: if given, sets its requires grad True before pred - """ - - def __init__( - self, - modules: Dict[str, nn.Module], - cutoff: float = 0.0, - type_map: Optional[Dict[int, int]] = None, - modal_map: Optional[Dict[str, int]] = None, - eval_type_map: bool = True, - eval_modal_map: bool = False, - data_key_atomic_numbers: str = KEY.ATOMIC_NUMBERS, - data_key_node_feature: str = KEY.NODE_FEATURE, - data_key_grad: Optional[str] = None, - ): - if not isinstance(modules, OrderedDict): # backward compat - modules = OrderedDict(modules) - self.cutoff = cutoff - self.type_map = type_map - self.eval_type_map = eval_type_map - self.is_batch_data = True - - if cutoff == 0.0: - warnings.warn('cutoff is 0.0 or not given', UserWarning) - - if self.type_map is None: - warnings.warn('type_map is not given', UserWarning) - self.eval_type_map = False - else: - z_to_onehot_tensor = torch.neg(torch.ones(120, dtype=torch.long)) - for z, onehot in self.type_map.items(): - z_to_onehot_tensor[z] = onehot - self.z_to_onehot_tensor = z_to_onehot_tensor - - if eval_modal_map and modal_map is None: - raise ValueError('eval_modal_map is True but modal_map is None') - self.eval_modal_map = eval_modal_map - self.modal_map = modal_map - - self.key_atomic_numbers = data_key_atomic_numbers - self.key_node_feature = data_key_node_feature - self.key_grad = data_key_grad - - _instantiate_modules(modules) - super().__init__(modules) - if not isinstance(self._modules, OrderedDict): # backward compat - self._modules = OrderedDict(self._modules) - - def set_is_batch_data(self, flag: bool): - # whether given data is batched or not some module have to change - # its behavior. checking whether data is batched or not inside - # forward function make problem harder when make it into torchscript - for module in self: - try: # Easier to ask for forgiveness than permission. - module._is_batch_data = flag # type: ignore - except AttributeError: - pass - self.is_batch_data = flag - - def get_irreps_in(self, modlue_name: str, attr_key: str = 'irreps_in'): - tg_module = self._modules[modlue_name] - for m in tg_module.modules(): - try: - return repr(m.__getattribute__(attr_key)) - except AttributeError: - pass - return None - - def prepand_module(self, key: str, module: nn.Module): - self._modules.update({key: module}) - self._modules.move_to_end(key, last=False) # type: ignore - - def replace_module(self, key: str, module: nn.Module): - self._modules.update({key: module}) - - def delete_module_by_key(self, key: str): - if key in self._modules.keys(): - del self._modules[key] - - @torch.jit.unused - def _atomic_numbers_to_onehot(self, atomic_numbers: torch.Tensor): - assert atomic_numbers.dtype == torch.int64 - device = atomic_numbers.device - z_to_onehot_tensor = self.z_to_onehot_tensor.to(device) - return torch.index_select( - input=z_to_onehot_tensor, dim=0, index=atomic_numbers - ) - - @torch.jit.unused - def _eval_modal_map(self, data: AtomGraphDataType): - assert self.modal_map is not None - # modal_map: dict[str, int] - if not self.is_batch_data: - modal_idx = self.modal_map[data[KEY.DATA_MODALITY]] # type: ignore - else: - modal_idx = [ - self.modal_map[ii] # type: ignore - for ii in data[KEY.DATA_MODALITY] - ] - modal_idx = torch.tensor( - modal_idx, - dtype=torch.int64, - device=data.x.device, # type: ignore - ) - data[KEY.MODAL_TYPE] = modal_idx - - def _preprocess(self, data: AtomGraphDataType) -> AtomGraphDataType: - if self.eval_type_map: - atomic_numbers = data[self.key_atomic_numbers] - onehot = self._atomic_numbers_to_onehot(atomic_numbers) - data[self.key_node_feature] = onehot - - if self.eval_modal_map: - self._eval_modal_map(data) - - if self.key_grad is not None: - data[self.key_grad].requires_grad_(True) - - return data - - def prepare_modal_deploy(self, modal: str): - if self.modal_map is None: - return - self.eval_modal_map = False - self.set_is_batch_data(False) - modal_idx = self.modal_map[modal] # type: ignore - self.prepand_module('modal_input_prepare', _ModalInputPrepare(modal_idx)) - - def forward(self, input: AtomGraphDataType) -> AtomGraphDataType: - data = self._preprocess(input) - for module in self: - data = module(data) - return data +import warnings +from collections import OrderedDict +from typing import Dict, Optional + +import torch +import torch.nn as nn +from e3nn.util.jit import compile_mode + +import sevenn._keys as KEY +from sevenn._const import AtomGraphDataType + + +def _instantiate_modules(modules): + # see IrrepsLinear of linear.py + for module in modules.values(): + if not getattr(module, 'layer_instantiated', True): + module.instantiate() + + +@compile_mode('script') +class _ModalInputPrepare(nn.Module): + + def __init__( + self, + modal_idx: int + ): + super().__init__() + self.modal_idx = modal_idx + + def forward(self, data: AtomGraphDataType) -> AtomGraphDataType: + data[KEY.MODAL_TYPE] = torch.tensor( + self.modal_idx, + dtype=torch.int64, + device=data['x'].device, + ) + return data + + +@compile_mode('script') +class AtomGraphSequential(nn.Sequential): + """ + Wrapper of SevenNet model + + Args: + modules: OrderedDict of nn.Modules + cutoff: not used internally, but makes sense to have + type_map: atomic_numbers => onehot index (see nn/node_embedding.py) + eval_type_map: perform index mapping using type_map defaults to True + data_key_atomic_numbers: used when eval_type_map is True + data_key_node_feature: used when eval_type_map is True + data_key_grad: if given, sets its requires grad True before pred + """ + + def __init__( + self, + modules: Dict[str, nn.Module], + cutoff: float = 0.0, + type_map: Optional[Dict[int, int]] = None, + modal_map: Optional[Dict[str, int]] = None, + eval_type_map: bool = True, + eval_modal_map: bool = False, + data_key_atomic_numbers: str = KEY.ATOMIC_NUMBERS, + data_key_node_feature: str = KEY.NODE_FEATURE, + data_key_grad: Optional[str] = None, + ): + if not isinstance(modules, OrderedDict): # backward compat + modules = OrderedDict(modules) + self.cutoff = cutoff + self.type_map = type_map + self.eval_type_map = eval_type_map + self.is_batch_data = True + + if cutoff == 0.0: + warnings.warn('cutoff is 0.0 or not given', UserWarning) + + if self.type_map is None: + warnings.warn('type_map is not given', UserWarning) + self.eval_type_map = False + else: + z_to_onehot_tensor = torch.neg(torch.ones(120, dtype=torch.long)) + for z, onehot in self.type_map.items(): + z_to_onehot_tensor[z] = onehot + self.z_to_onehot_tensor = z_to_onehot_tensor + + if eval_modal_map and modal_map is None: + raise ValueError('eval_modal_map is True but modal_map is None') + self.eval_modal_map = eval_modal_map + self.modal_map = modal_map + + self.key_atomic_numbers = data_key_atomic_numbers + self.key_node_feature = data_key_node_feature + self.key_grad = data_key_grad + + _instantiate_modules(modules) + super().__init__(modules) + if not isinstance(self._modules, OrderedDict): # backward compat + self._modules = OrderedDict(self._modules) + + def set_is_batch_data(self, flag: bool): + # whether given data is batched or not some module have to change + # its behavior. checking whether data is batched or not inside + # forward function make problem harder when make it into torchscript + for module in self: + try: # Easier to ask for forgiveness than permission. + module._is_batch_data = flag # type: ignore + except AttributeError: + pass + self.is_batch_data = flag + + def get_irreps_in(self, modlue_name: str, attr_key: str = 'irreps_in'): + tg_module = self._modules[modlue_name] + for m in tg_module.modules(): + try: + return repr(m.__getattribute__(attr_key)) + except AttributeError: + pass + return None + + def prepand_module(self, key: str, module: nn.Module): + self._modules.update({key: module}) + self._modules.move_to_end(key, last=False) # type: ignore + + def replace_module(self, key: str, module: nn.Module): + self._modules.update({key: module}) + + def delete_module_by_key(self, key: str): + if key in self._modules.keys(): + del self._modules[key] + + @torch.jit.unused + def _atomic_numbers_to_onehot(self, atomic_numbers: torch.Tensor): + assert atomic_numbers.dtype == torch.int64 + device = atomic_numbers.device + z_to_onehot_tensor = self.z_to_onehot_tensor.to(device) + return torch.index_select( + input=z_to_onehot_tensor, dim=0, index=atomic_numbers + ) + + @torch.jit.unused + def _eval_modal_map(self, data: AtomGraphDataType): + assert self.modal_map is not None + # modal_map: dict[str, int] + if not self.is_batch_data: + modal_idx = self.modal_map[data[KEY.DATA_MODALITY]] # type: ignore + else: + modal_idx = [ + self.modal_map[ii] # type: ignore + for ii in data[KEY.DATA_MODALITY] + ] + modal_idx = torch.tensor( + modal_idx, + dtype=torch.int64, + device=data.x.device, # type: ignore + ) + data[KEY.MODAL_TYPE] = modal_idx + + def _preprocess(self, data: AtomGraphDataType) -> AtomGraphDataType: + if self.eval_type_map: + atomic_numbers = data[self.key_atomic_numbers] + onehot = self._atomic_numbers_to_onehot(atomic_numbers) + data[self.key_node_feature] = onehot + + if self.eval_modal_map: + self._eval_modal_map(data) + + if self.key_grad is not None: + data[self.key_grad].requires_grad_(True) + + return data + + def prepare_modal_deploy(self, modal: str): + if self.modal_map is None: + return + self.eval_modal_map = False + self.set_is_batch_data(False) + modal_idx = self.modal_map[modal] # type: ignore + self.prepand_module('modal_input_prepare', _ModalInputPrepare(modal_idx)) + + def forward(self, input: AtomGraphDataType) -> AtomGraphDataType: + data = self._preprocess(input) + for module in self: + data = module(data) + return data diff --git a/mace-bench/3rdparty/SevenNet/sevenn/nn/util.py b/mace-bench/3rdparty/SevenNet/sevenn/nn/util.py index 411b6dc..cf29c96 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/nn/util.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/nn/util.py @@ -1,17 +1,17 @@ -import torch - - -def broadcast( - src: torch.Tensor, - other: torch.Tensor, - dim: int -): - if dim < 0: - dim = other.dim() + dim - if src.dim() == 1: - for _ in range(0, dim): - src = src.unsqueeze(0) - for _ in range(src.dim(), other.dim()): - src = src.unsqueeze(-1) - src = src.expand_as(other) - return src +import torch + + +def broadcast( + src: torch.Tensor, + other: torch.Tensor, + dim: int +): + if dim < 0: + dim = other.dim() + dim + if src.dim() == 1: + for _ in range(0, dim): + src = src.unsqueeze(0) + for _ in range(src.dim(), other.dim()): + src = src.unsqueeze(-1) + src = src.expand_as(other) + return src diff --git a/mace-bench/3rdparty/SevenNet/sevenn/pair_e3gnn/patch_lammps.sh b/mace-bench/3rdparty/SevenNet/sevenn/pair_e3gnn/patch_lammps.sh index b111d0e..e6bc90d 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/pair_e3gnn/patch_lammps.sh +++ b/mace-bench/3rdparty/SevenNet/sevenn/pair_e3gnn/patch_lammps.sh @@ -1,154 +1,154 @@ -#!/bin/bash - -lammps_root=$1 -cxx_standard=$2 # 14, 17 -d3_support=$3 # 1, 0 -SCRIPT_DIR=$(dirname "${BASH_SOURCE[0]}") - -########################################### -# Check if the given arguments are valid # -########################################### - -# Check the number of arguments -if [ "$#" -ne 3 ]; then - echo "Usage: sh patch_lammps.sh {lammps_root} {cxx_standard} {d3_support}" - echo " {lammps_root}: Root directory of LAMMPS source" - echo " {cxx_standard}: C++ standard (14, 17)" - echo " {d3_support}: Support for pair_d3 (1, 0)" - exit 1 -fi - -# Check if the lammps_root directory exists -if [ ! -d "$lammps_root" ]; then - echo "Error: No such directory: $lammps_root" - exit 1 -fi - -# Check if the given directory is the root of LAMMPS source -if [ ! -d "$lammps_root/cmake" ] && [ ! -d "$lammps_root/potentials" ]; then - echo "Error: Given $lammps_root is not a root of LAMMPS source" - exit 1 -fi - -# Check if the script is being run from the root of SevenNet -if [ ! -f "${SCRIPT_DIR}/pair_e3gnn.cpp" ]; then - echo "Error: Script executed in a wrong directory" - exit 1 -fi - -# Check if the patch is already applied -if [ -f "$lammps_root/src/pair_e3gnn.cpp" ]; then - echo "----------------------------------------------------------" - echo "Seems like given LAMMPS is already patched." - echo "Try again after removing src/pair_e3gnn.cpp to force patch" - echo "----------------------------------------------------------" - echo "Example build commands, under LAMMPS root" - echo " mkdir build; cd build" - echo " cmake ../cmake -DCMAKE_PREFIX_PATH=$(python -c 'import torch;print(torch.utils.cmake_prefix_path)')" - echo " make -j 4" - exit 0 -fi - -# Check if OpenMPI exists and if it is CUDA-aware -if command -v ompi_info &> /dev/null; then - cuda_support=$(ompi_info --parsable --all | grep mpi_built_with_cuda_support:value) - if [[ -z "$cuda_support" ]]; then - echo "OpenMPI not found, parallel performance is not optimal" - elif [[ "$cuda_support" == *"true" ]]; then - echo "OpenMPI is CUDA aware" - else - echo "This system's OpenMPI is not 'CUDA aware', parallel performance is not optimal" - fi -else - echo "OpenMPI not found, parallel performance is not optimal" -fi - -# Extract LAMMPS version and update -lammps_version=$(grep "#define LAMMPS_VERSION" $lammps_root/src/version.h | awk '{print $3, $4, $5}' | tr -d '"') - -# Combine version and update -detected_version="$lammps_version" -required_version="2 Aug 2023" # Example required version - -# Check if the detected version is compatible -if [[ "$detected_version" != "$required_version" ]]; then - echo "Warning: Detected LAMMPS version ($detected_version) may not be compatible. Required version: $required_version" -fi - -########################################### -# Backup original LAMMPS source code # -########################################### - -# Create a backup directory if it doesn't exist -backup_dir="$lammps_root/_backups" -mkdir -p $backup_dir - -# Copy comm_* from original LAMMPS source as backup -cp $lammps_root/src/comm_brick.cpp $backup_dir/ -cp $lammps_root/src/comm_brick.h $backup_dir/ - -# Copy cmake/CMakeLists.txt from original source as backup -cp $lammps_root/cmake/CMakeLists.txt $backup_dir/CMakeLists.txt - -########################################### -# Patch LAMMPS source code: e3gnn # -########################################### - -# 1. Copy pair_e3gnn files to LAMMPS source -cp $SCRIPT_DIR/{pair_e3gnn,pair_e3gnn_parallel,comm_brick}.cpp $lammps_root/src/ -cp $SCRIPT_DIR/{pair_e3gnn,pair_e3gnn_parallel,comm_brick}.h $lammps_root/src/ - -# 2. Patch cmake/CMakeLists.txt -sed -i "s/set(CMAKE_CXX_STANDARD 11)/set(CMAKE_CXX_STANDARD $cxx_standard)/" $lammps_root/cmake/CMakeLists.txt -cat >> $lammps_root/cmake/CMakeLists.txt << "EOF" - -find_package(Torch REQUIRED) -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}") -target_link_libraries(lammps PUBLIC "${TORCH_LIBRARIES}") -EOF - -########################################### -# Patch LAMMPS source code: d3 # -########################################### - -if [ "$d3_support" -ne 0 ]; then - -# 1. Copy pair_d3 files to LAMMPS source -cp $SCRIPT_DIR/pair_d3.cu $lammps_root/src/ -cp $SCRIPT_DIR/pair_d3.h $lammps_root/src/ -cp $SCRIPT_DIR/pair_d3_pars.h $lammps_root/src/ - -# 2. Patch cmake/CMakeLists.txt -sed -i "s/project(lammps CXX)/project(lammps CXX CUDA)/" $lammps_root/cmake/CMakeLists.txt -sed -i "s/\${LAMMPS_SOURCE_DIR}\/\[\^.\]\*\.cpp/\${LAMMPS_SOURCE_DIR}\/\[\^.\]\*\.cpp \${LAMMPS_SOURCE_DIR}\/\[\^.\]\*\.cu/" $lammps_root/cmake/CMakeLists.txt -cat >> $lammps_root/cmake/CMakeLists.txt << "EOF" - -find_package(CUDA) -set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -fmad=false -O3") -string(REPLACE "-gencode arch=compute_50,code=sm_50" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}") -target_link_libraries(lammps PUBLIC ${CUDA_LIBRARIES} cuda) -EOF - -fi - -########################################### -# Print changes and backup file locations # -########################################### - -# Print changes and backup file locations -echo "Changes made:" -echo " - Original LAMMPS files (src/comm_brick.*, cmake/CMakeList.txt) are in {lammps_root}/_backups" -echo " - Copied contents of pair_e3gnn to $lammps_root/src/" -echo " - Patched CMakeLists.txt: include LibTorch, CXX_STANDARD $cxx_standard" -if [ "$d3_support" -ne 0 ]; then - echo " - Copied contents of pair_d3 to $lammps_root/src/" - echo " - Patched CMakeLists.txt: include CUDA" -fi - -# Provide example cmake command to the user -echo "Example build commands, under LAMMPS root" -echo " mkdir build; cd build" -echo " cmake ../cmake -DCMAKE_PREFIX_PATH=$(python -c 'import torch;print(torch.utils.cmake_prefix_path)')" -echo " make -j 4" - -exit 0 +#!/bin/bash + +lammps_root=$1 +cxx_standard=$2 # 14, 17 +d3_support=$3 # 1, 0 +SCRIPT_DIR=$(dirname "${BASH_SOURCE[0]}") + +########################################### +# Check if the given arguments are valid # +########################################### + +# Check the number of arguments +if [ "$#" -ne 3 ]; then + echo "Usage: sh patch_lammps.sh {lammps_root} {cxx_standard} {d3_support}" + echo " {lammps_root}: Root directory of LAMMPS source" + echo " {cxx_standard}: C++ standard (14, 17)" + echo " {d3_support}: Support for pair_d3 (1, 0)" + exit 1 +fi + +# Check if the lammps_root directory exists +if [ ! -d "$lammps_root" ]; then + echo "Error: No such directory: $lammps_root" + exit 1 +fi + +# Check if the given directory is the root of LAMMPS source +if [ ! -d "$lammps_root/cmake" ] && [ ! -d "$lammps_root/potentials" ]; then + echo "Error: Given $lammps_root is not a root of LAMMPS source" + exit 1 +fi + +# Check if the script is being run from the root of SevenNet +if [ ! -f "${SCRIPT_DIR}/pair_e3gnn.cpp" ]; then + echo "Error: Script executed in a wrong directory" + exit 1 +fi + +# Check if the patch is already applied +if [ -f "$lammps_root/src/pair_e3gnn.cpp" ]; then + echo "----------------------------------------------------------" + echo "Seems like given LAMMPS is already patched." + echo "Try again after removing src/pair_e3gnn.cpp to force patch" + echo "----------------------------------------------------------" + echo "Example build commands, under LAMMPS root" + echo " mkdir build; cd build" + echo " cmake ../cmake -DCMAKE_PREFIX_PATH=$(python -c 'import torch;print(torch.utils.cmake_prefix_path)')" + echo " make -j 4" + exit 0 +fi + +# Check if OpenMPI exists and if it is CUDA-aware +if command -v ompi_info &> /dev/null; then + cuda_support=$(ompi_info --parsable --all | grep mpi_built_with_cuda_support:value) + if [[ -z "$cuda_support" ]]; then + echo "OpenMPI not found, parallel performance is not optimal" + elif [[ "$cuda_support" == *"true" ]]; then + echo "OpenMPI is CUDA aware" + else + echo "This system's OpenMPI is not 'CUDA aware', parallel performance is not optimal" + fi +else + echo "OpenMPI not found, parallel performance is not optimal" +fi + +# Extract LAMMPS version and update +lammps_version=$(grep "#define LAMMPS_VERSION" $lammps_root/src/version.h | awk '{print $3, $4, $5}' | tr -d '"') + +# Combine version and update +detected_version="$lammps_version" +required_version="2 Aug 2023" # Example required version + +# Check if the detected version is compatible +if [[ "$detected_version" != "$required_version" ]]; then + echo "Warning: Detected LAMMPS version ($detected_version) may not be compatible. Required version: $required_version" +fi + +########################################### +# Backup original LAMMPS source code # +########################################### + +# Create a backup directory if it doesn't exist +backup_dir="$lammps_root/_backups" +mkdir -p $backup_dir + +# Copy comm_* from original LAMMPS source as backup +cp $lammps_root/src/comm_brick.cpp $backup_dir/ +cp $lammps_root/src/comm_brick.h $backup_dir/ + +# Copy cmake/CMakeLists.txt from original source as backup +cp $lammps_root/cmake/CMakeLists.txt $backup_dir/CMakeLists.txt + +########################################### +# Patch LAMMPS source code: e3gnn # +########################################### + +# 1. Copy pair_e3gnn files to LAMMPS source +cp $SCRIPT_DIR/{pair_e3gnn,pair_e3gnn_parallel,comm_brick}.cpp $lammps_root/src/ +cp $SCRIPT_DIR/{pair_e3gnn,pair_e3gnn_parallel,comm_brick}.h $lammps_root/src/ + +# 2. Patch cmake/CMakeLists.txt +sed -i "s/set(CMAKE_CXX_STANDARD 11)/set(CMAKE_CXX_STANDARD $cxx_standard)/" $lammps_root/cmake/CMakeLists.txt +cat >> $lammps_root/cmake/CMakeLists.txt << "EOF" + +find_package(Torch REQUIRED) +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}") +target_link_libraries(lammps PUBLIC "${TORCH_LIBRARIES}") +EOF + +########################################### +# Patch LAMMPS source code: d3 # +########################################### + +if [ "$d3_support" -ne 0 ]; then + +# 1. Copy pair_d3 files to LAMMPS source +cp $SCRIPT_DIR/pair_d3.cu $lammps_root/src/ +cp $SCRIPT_DIR/pair_d3.h $lammps_root/src/ +cp $SCRIPT_DIR/pair_d3_pars.h $lammps_root/src/ + +# 2. Patch cmake/CMakeLists.txt +sed -i "s/project(lammps CXX)/project(lammps CXX CUDA)/" $lammps_root/cmake/CMakeLists.txt +sed -i "s/\${LAMMPS_SOURCE_DIR}\/\[\^.\]\*\.cpp/\${LAMMPS_SOURCE_DIR}\/\[\^.\]\*\.cpp \${LAMMPS_SOURCE_DIR}\/\[\^.\]\*\.cu/" $lammps_root/cmake/CMakeLists.txt +cat >> $lammps_root/cmake/CMakeLists.txt << "EOF" + +find_package(CUDA) +set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -fmad=false -O3") +string(REPLACE "-gencode arch=compute_50,code=sm_50" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}") +target_link_libraries(lammps PUBLIC ${CUDA_LIBRARIES} cuda) +EOF + +fi + +########################################### +# Print changes and backup file locations # +########################################### + +# Print changes and backup file locations +echo "Changes made:" +echo " - Original LAMMPS files (src/comm_brick.*, cmake/CMakeList.txt) are in {lammps_root}/_backups" +echo " - Copied contents of pair_e3gnn to $lammps_root/src/" +echo " - Patched CMakeLists.txt: include LibTorch, CXX_STANDARD $cxx_standard" +if [ "$d3_support" -ne 0 ]; then + echo " - Copied contents of pair_d3 to $lammps_root/src/" + echo " - Patched CMakeLists.txt: include CUDA" +fi + +# Provide example cmake command to the user +echo "Example build commands, under LAMMPS root" +echo " mkdir build; cd build" +echo " cmake ../cmake -DCMAKE_PREFIX_PATH=$(python -c 'import torch;print(torch.utils.cmake_prefix_path)')" +echo " make -j 4" + +exit 0 diff --git a/mace-bench/3rdparty/SevenNet/sevenn/parse_input.py b/mace-bench/3rdparty/SevenNet/sevenn/parse_input.py index 62f3167..f0406d8 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/parse_input.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/parse_input.py @@ -1,246 +1,246 @@ -import glob -import os -import warnings -from typing import Any, Callable, Dict - -import torch -import yaml - -import sevenn._const as _const -import sevenn._keys as KEY -import sevenn.util as util - - -def config_initialize( - key: str, - config: Dict, - default: Any, - conditions: Dict, -): - # default value exist & no user input -> return default - if key not in config.keys(): - return default - - # No validation method exist => accept user input - user_input = config[key] - if key in conditions: - condition = conditions[key] - else: - return user_input - - if type(default) is dict and isinstance(condition, dict): - for i_key, val in default.items(): - user_input[i_key] = config_initialize( - i_key, user_input, val, condition - ) - return user_input - elif isinstance(condition, type): - if isinstance(user_input, condition): - return user_input - else: - try: - return condition(user_input) # try type casting - except ValueError: - raise ValueError( - f"Expect '{user_input}' for '{key}' is {condition}" - ) - elif isinstance(condition, Callable) and condition(user_input): - return user_input - else: - raise ValueError( - f"Given input '{user_input}' for '{key}' is not valid" - ) - - -def init_model_config(config: Dict): - # defaults = _const.model_defaults(config) - model_meta = {} - - # init complicated ones - if KEY.CHEMICAL_SPECIES not in config.keys(): - raise ValueError('required key chemical_species not exist') - input_chem = config[KEY.CHEMICAL_SPECIES] - if isinstance(input_chem, str) and input_chem.lower() == 'auto': - model_meta[KEY.CHEMICAL_SPECIES] = 'auto' - model_meta[KEY.NUM_SPECIES] = 'auto' - model_meta[KEY.TYPE_MAP] = 'auto' - elif isinstance(input_chem, str) and 'univ' in input_chem.lower(): - model_meta.update(util.chemical_species_preprocess([], universal=True)) - else: - if isinstance(input_chem, list) and all( - isinstance(x, str) for x in input_chem - ): - pass - elif isinstance(input_chem, str): - input_chem = ( - input_chem.replace('-', ',').replace(' ', ',').split(',') - ) - input_chem = [chem for chem in input_chem if len(chem) != 0] - else: - raise ValueError(f'given {KEY.CHEMICAL_SPECIES} input is strange') - model_meta.update(util.chemical_species_preprocess(input_chem)) - - # deprecation warnings - if KEY.AVG_NUM_NEIGH in config: - warnings.warn( - "key 'avg_num_neigh' is deprecated. Please use 'conv_denominator'." - ' We use the default, the average number of neighbors in the' - ' dataset, if not provided.', - UserWarning, - ) - config.pop(KEY.AVG_NUM_NEIGH) - if KEY.TRAIN_AVG_NUM_NEIGH in config: - warnings.warn( - "key 'train_avg_num_neigh' is deprecated. Please use" - " 'train_denominator'. We overwrite train_denominator as given" - ' train_avg_num_neigh', - UserWarning, - ) - config[KEY.TRAIN_DENOMINTAOR] = config[KEY.TRAIN_AVG_NUM_NEIGH] - config.pop(KEY.TRAIN_AVG_NUM_NEIGH) - if KEY.OPTIMIZE_BY_REDUCE in config: - warnings.warn( - "key 'optimize_by_reduce' is deprecated. Always true", - UserWarning, - ) - config.pop(KEY.OPTIMIZE_BY_REDUCE) - - # init simpler ones - for key, default in _const.DEFAULT_E3_EQUIVARIANT_MODEL_CONFIG.items(): - model_meta[key] = config_initialize( - key, config, default, _const.MODEL_CONFIG_CONDITION - ) - - unknown_keys = [ - key for key in config.keys() if key not in model_meta.keys() - ] - if len(unknown_keys) != 0: - warnings.warn( - f'Unexpected model keys: {unknown_keys} will be ignored', - UserWarning, - ) - - return model_meta - - -def init_train_config(config: Dict): - train_meta = {} - # defaults = _const.train_defaults(config) - - try: - device_input = config[KEY.DEVICE] - train_meta[KEY.DEVICE] = torch.device(device_input) - except KeyError: - train_meta[KEY.DEVICE] = ( - torch.device('cuda') - if torch.cuda.is_available() - else torch.device('cpu') - ) - train_meta[KEY.DEVICE] = str(train_meta[KEY.DEVICE]) - - # init simpler ones - for key, default in _const.DEFAULT_TRAINING_CONFIG.items(): - train_meta[key] = config_initialize( - key, config, default, _const.TRAINING_CONFIG_CONDITION - ) - - if KEY.CONTINUE in config.keys(): - cnt_dct = config[KEY.CONTINUE] - if KEY.CHECKPOINT not in cnt_dct.keys(): - raise ValueError('no checkpoint is given in continue') - checkpoint = cnt_dct[KEY.CHECKPOINT] - if os.path.isfile(checkpoint): - checkpoint_file = checkpoint - else: - checkpoint_file = util.pretrained_name_to_path(checkpoint) - train_meta[KEY.CONTINUE].update({KEY.CHECKPOINT: checkpoint_file}) - - unknown_keys = [ - key for key in config.keys() if key not in train_meta.keys() - ] - if len(unknown_keys) != 0: - warnings.warn( - f'Unexpected train keys: {unknown_keys} will be ignored', - UserWarning, - ) - return train_meta - - -def init_data_config(config: Dict): - data_meta = {} - # defaults = _const.data_defaults(config) - - load_data_keys = [] - for k in config: - if k.startswith('load_') and k.endswith('_path'): - load_data_keys.append(k) - - for load_data_key in load_data_keys: - if load_data_key in config.keys(): - inp = config[load_data_key] - extended = [] - if type(inp) not in [str, list]: - raise ValueError(f'unexpected input {inp} for sturcture_list') - if type(inp) is str: - extended = glob.glob(inp) - elif type(inp) is list: - for i in inp: - if isinstance(i, str): - extended.extend(glob.glob(i)) - elif isinstance(i, dict): - extended.append(i) - if len(extended) == 0: - raise ValueError( - f'Cannot find {inp} for {load_data_key}' - + ' or path is not given' - ) - data_meta[load_data_key] = extended - else: - data_meta[load_data_key] = False - - for key, default in _const.DEFAULT_DATA_CONFIG.items(): - data_meta[key] = config_initialize( - key, config, default, _const.DATA_CONFIG_CONDITION - ) - - unknown_keys = [ - key for key in config.keys() if key not in data_meta.keys() - ] - if len(unknown_keys) != 0: - warnings.warn( - f'Unexpected data keys: {unknown_keys} will be ignored', - UserWarning, - ) - return data_meta - - -def read_config_yaml(filename: str, return_separately: bool = False): - with open(filename, 'r') as fstream: - inputs = yaml.safe_load(fstream) - - model_meta, train_meta, data_meta = {}, {}, {} - for key, config in inputs.items(): - if key == 'model': - model_meta = init_model_config(config) - elif key == 'train': - train_meta = init_train_config(config) - elif key == 'data': - data_meta = init_data_config(config) - else: - raise ValueError(f'Unexpected input {key} given') - - if return_separately: - return model_meta, train_meta, data_meta - else: - model_meta.update(train_meta) - model_meta.update(data_meta) - return model_meta - - -def main(): - filename = './input.yaml' - read_config_yaml(filename) - - -if __name__ == '__main__': - main() +import glob +import os +import warnings +from typing import Any, Callable, Dict + +import torch +import yaml + +import sevenn._const as _const +import sevenn._keys as KEY +import sevenn.util as util + + +def config_initialize( + key: str, + config: Dict, + default: Any, + conditions: Dict, +): + # default value exist & no user input -> return default + if key not in config.keys(): + return default + + # No validation method exist => accept user input + user_input = config[key] + if key in conditions: + condition = conditions[key] + else: + return user_input + + if type(default) is dict and isinstance(condition, dict): + for i_key, val in default.items(): + user_input[i_key] = config_initialize( + i_key, user_input, val, condition + ) + return user_input + elif isinstance(condition, type): + if isinstance(user_input, condition): + return user_input + else: + try: + return condition(user_input) # try type casting + except ValueError: + raise ValueError( + f"Expect '{user_input}' for '{key}' is {condition}" + ) + elif isinstance(condition, Callable) and condition(user_input): + return user_input + else: + raise ValueError( + f"Given input '{user_input}' for '{key}' is not valid" + ) + + +def init_model_config(config: Dict): + # defaults = _const.model_defaults(config) + model_meta = {} + + # init complicated ones + if KEY.CHEMICAL_SPECIES not in config.keys(): + raise ValueError('required key chemical_species not exist') + input_chem = config[KEY.CHEMICAL_SPECIES] + if isinstance(input_chem, str) and input_chem.lower() == 'auto': + model_meta[KEY.CHEMICAL_SPECIES] = 'auto' + model_meta[KEY.NUM_SPECIES] = 'auto' + model_meta[KEY.TYPE_MAP] = 'auto' + elif isinstance(input_chem, str) and 'univ' in input_chem.lower(): + model_meta.update(util.chemical_species_preprocess([], universal=True)) + else: + if isinstance(input_chem, list) and all( + isinstance(x, str) for x in input_chem + ): + pass + elif isinstance(input_chem, str): + input_chem = ( + input_chem.replace('-', ',').replace(' ', ',').split(',') + ) + input_chem = [chem for chem in input_chem if len(chem) != 0] + else: + raise ValueError(f'given {KEY.CHEMICAL_SPECIES} input is strange') + model_meta.update(util.chemical_species_preprocess(input_chem)) + + # deprecation warnings + if KEY.AVG_NUM_NEIGH in config: + warnings.warn( + "key 'avg_num_neigh' is deprecated. Please use 'conv_denominator'." + ' We use the default, the average number of neighbors in the' + ' dataset, if not provided.', + UserWarning, + ) + config.pop(KEY.AVG_NUM_NEIGH) + if KEY.TRAIN_AVG_NUM_NEIGH in config: + warnings.warn( + "key 'train_avg_num_neigh' is deprecated. Please use" + " 'train_denominator'. We overwrite train_denominator as given" + ' train_avg_num_neigh', + UserWarning, + ) + config[KEY.TRAIN_DENOMINTAOR] = config[KEY.TRAIN_AVG_NUM_NEIGH] + config.pop(KEY.TRAIN_AVG_NUM_NEIGH) + if KEY.OPTIMIZE_BY_REDUCE in config: + warnings.warn( + "key 'optimize_by_reduce' is deprecated. Always true", + UserWarning, + ) + config.pop(KEY.OPTIMIZE_BY_REDUCE) + + # init simpler ones + for key, default in _const.DEFAULT_E3_EQUIVARIANT_MODEL_CONFIG.items(): + model_meta[key] = config_initialize( + key, config, default, _const.MODEL_CONFIG_CONDITION + ) + + unknown_keys = [ + key for key in config.keys() if key not in model_meta.keys() + ] + if len(unknown_keys) != 0: + warnings.warn( + f'Unexpected model keys: {unknown_keys} will be ignored', + UserWarning, + ) + + return model_meta + + +def init_train_config(config: Dict): + train_meta = {} + # defaults = _const.train_defaults(config) + + try: + device_input = config[KEY.DEVICE] + train_meta[KEY.DEVICE] = torch.device(device_input) + except KeyError: + train_meta[KEY.DEVICE] = ( + torch.device('cuda') + if torch.cuda.is_available() + else torch.device('cpu') + ) + train_meta[KEY.DEVICE] = str(train_meta[KEY.DEVICE]) + + # init simpler ones + for key, default in _const.DEFAULT_TRAINING_CONFIG.items(): + train_meta[key] = config_initialize( + key, config, default, _const.TRAINING_CONFIG_CONDITION + ) + + if KEY.CONTINUE in config.keys(): + cnt_dct = config[KEY.CONTINUE] + if KEY.CHECKPOINT not in cnt_dct.keys(): + raise ValueError('no checkpoint is given in continue') + checkpoint = cnt_dct[KEY.CHECKPOINT] + if os.path.isfile(checkpoint): + checkpoint_file = checkpoint + else: + checkpoint_file = util.pretrained_name_to_path(checkpoint) + train_meta[KEY.CONTINUE].update({KEY.CHECKPOINT: checkpoint_file}) + + unknown_keys = [ + key for key in config.keys() if key not in train_meta.keys() + ] + if len(unknown_keys) != 0: + warnings.warn( + f'Unexpected train keys: {unknown_keys} will be ignored', + UserWarning, + ) + return train_meta + + +def init_data_config(config: Dict): + data_meta = {} + # defaults = _const.data_defaults(config) + + load_data_keys = [] + for k in config: + if k.startswith('load_') and k.endswith('_path'): + load_data_keys.append(k) + + for load_data_key in load_data_keys: + if load_data_key in config.keys(): + inp = config[load_data_key] + extended = [] + if type(inp) not in [str, list]: + raise ValueError(f'unexpected input {inp} for sturcture_list') + if type(inp) is str: + extended = glob.glob(inp) + elif type(inp) is list: + for i in inp: + if isinstance(i, str): + extended.extend(glob.glob(i)) + elif isinstance(i, dict): + extended.append(i) + if len(extended) == 0: + raise ValueError( + f'Cannot find {inp} for {load_data_key}' + + ' or path is not given' + ) + data_meta[load_data_key] = extended + else: + data_meta[load_data_key] = False + + for key, default in _const.DEFAULT_DATA_CONFIG.items(): + data_meta[key] = config_initialize( + key, config, default, _const.DATA_CONFIG_CONDITION + ) + + unknown_keys = [ + key for key in config.keys() if key not in data_meta.keys() + ] + if len(unknown_keys) != 0: + warnings.warn( + f'Unexpected data keys: {unknown_keys} will be ignored', + UserWarning, + ) + return data_meta + + +def read_config_yaml(filename: str, return_separately: bool = False): + with open(filename, 'r') as fstream: + inputs = yaml.safe_load(fstream) + + model_meta, train_meta, data_meta = {}, {}, {} + for key, config in inputs.items(): + if key == 'model': + model_meta = init_model_config(config) + elif key == 'train': + train_meta = init_train_config(config) + elif key == 'data': + data_meta = init_data_config(config) + else: + raise ValueError(f'Unexpected input {key} given') + + if return_separately: + return model_meta, train_meta, data_meta + else: + model_meta.update(train_meta) + model_meta.update(data_meta) + return model_meta + + +def main(): + filename = './input.yaml' + read_config_yaml(filename) + + +if __name__ == '__main__': + main() diff --git a/mace-bench/3rdparty/SevenNet/sevenn/scripts/__pycache__/__init__.cpython-310.pyc b/mace-bench/3rdparty/SevenNet/sevenn/scripts/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 24894cc923674b16f0810c194bd4278fe4d4d34d..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 184 zcmd1j<>g`kg0)*$W`O9&AOaaM0yz#qT+9L_QW%06G#UL?G8BP?5yY=({fzwFRQ=N8 z)FS<~lp=k%%)G>$kksN5{esGpjQqTK=imVS+{ENm-K5mKPO2Tq(qbkc!NLFlW#=!T diff --git a/mace-bench/3rdparty/SevenNet/sevenn/scripts/__pycache__/backward_compatibility.cpython-310.pyc b/mace-bench/3rdparty/SevenNet/sevenn/scripts/__pycache__/backward_compatibility.cpython-310.pyc deleted file mode 100644 index 141a08b4fe52ae09d43701ef6a9546b040c4a434..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 5093 zcmb7I&2JmW72lct;D;!Rl4w~dLm!OBjJ@*teb$@SGl;k)-5fb}m=FRMz zc{A_*-kS|4CyN?>fBws(&EGF*+CQjq^2gEmC7$SQ5TV(cV1oP1W~$^iSEb%FY?En1 z7se-=Z3$CYD03nw@+kA7Ac`mpVnUQq76rel)hCBb#W&scxVYuI{aWaYT3@=a#IofD zuBf#)-S$>L^ny654TAP&E7)*FqxcxB(3)*V#o>?Y7M{y^qIr;5i&@MQ?LZgoA=~47 zY@{U`X5`J3HR?uc#)H1^rh2c{Pj$?Yno-~P;yO$D_PIxz%tOeR-r{E1a~B7ZD;GOL zF0Ob%%U^S&c(K15Z-#+$`Q|%|y;j?OW!=U47SBu3Z^?Lf@us`&22D3!jHnVUMs4Z! z<7jcc)!y1^N#V4^UcVK4>mKIZZS;53$!%9g9z?2jO_Cl4NR=5((&;>%kv=kFJwZRHzQ;HxUc!aB#igI< z!rEdewbnG0!V7!e&_Ag8)>B<`nyhh5IcQmyX6nubXFOe)7su}f#mGyKAc}E6&Qh&C6?7SDR_|YYeA( z>H3va?}z<#dhN!gt4-(4E6vsGSDR~>NYTry&0Ay6r8RKMPIL9f^-FJGwXZlg-}zPD zO0C2FnxIL|UaK2Q6kdRWIjP}5KV|+IQQKI7v9>oII4CTaHhTu|MlXYuV4!(6&u3VL zSJ@0J;wgc{@i(3n^&&V$W^nmE41UUpN$`=+{tFF45ke1P{?Qn*SQFMBdsDmf!x2yT z=P6YuVXr@`d41NKnE3>V(#BLhZiz*(l? zER_XqWQnN*{fILyX2^XWvQbXVCb?}*-br!?hL{yqn0GF*5(Zv?^^vnooz{}vX>F8G z%tTM%3U`-P%}_NRCVgu7S+R!SN+@ByC!nCVDJD)VJvPt zK2uwlEr3m}=Ssj~ZRypsuin00SK_X>G7#;B!u8KYHIxJ2twEF<)#9)gxI4sqrcb?` zmSL%`Tmhu%y`tBO)(r#QOl+{WIrZ^Oy``)^*dys z0rrWo6u2(K;arXQgc65*%n5#4qM_TNKcKP05)Bx)(%DQF2D-A930oS@W`a&IYfJXD z;-Wa(C4WX}%Nsez@aGyF}y1{I3JTBm|EqbNwGfr>FtKE%JNLtpAj(JTE<| z>O|*sgI8FY%>x$*8!Che9D(fuUTG6-kAH*-goT9S(O*GSW})@ieVu6q?Vfgz-Q)N4 zj-j%Q{D><*Mvzl}{QV<8)`bZ_X5`0-g&e>_D+kZaBaHnF-l`*HEqNNe(*lx!E@%fv zVkD-AY=LQa=O8yn9-gRnq94w1>V1Bn!LPNhxew|0wSBE?-3LA;`8~X+-pDHnZzVbS zd7-He9a%+T$J>CHIOn@hf#50VNU6pwDg1BPMvO?luXi!+NTpYwXbC_N=?mM}Kfu#P zv_`DfX>ybXR7C8Cil(Aa6o6#JSa_CriNfy?fA47>CXjM{0!UKQ%-GyFy8u`*F^^5! zHvp!k!~&gsU_f=K0?CN-8!R>xw#Ta4s36LtBCY1mAMW!wmzWBkKf|U^+-Gr~j-~?u z6_ro&aY0N|?;hTaw(j2EXJQ)iN(m$i?Ngo<@h_c5tpivcP3~*^T$C5I1+81`P9!>7 zGe<20t%-3<*SS6_kETYIq%6u?hWsL#IxH($34ks`k7_cdXytS5fi;>|8lFm~(b7jV z&=2xQv-?(3PNtH|o<6F^rGv?2_Mn_py9fdafIa%+X`Yx%s#WAO(ci9VA2YNju`9@5 zk{J^#7jv(}ZedQ)>@^KLHJKEW2|)hsyEo!ew~`dZj5v+cC?ti0>4(T|lH4+O17loST$e-dB zlG_{j@(op&7g1zd5a`LDgSdDin#R(&UVCEwS1CQLG`13 zlx_y+NmLEww`4jKKo1zt0^o}17QdwdN-!T7_zycp3WG#mfu1#NiY+wJyPh@`uYvfI>jcFuPl`CTxEXBsEMaYNmgJl{pk z!`0C7-A?RC+LOlPC-~9U3a?k6NsV^c-?bNDg|PU63fz&CWl#>S&z?CU)}g!DMkHr3 zVQNvgD^lG@io-Xfv@nj)xFxoHU%o=4xQ8U!$L%qc?aed~e%lWtH|24dTEgA-+8Et# z;W`O^;MPys2AnLm^RysxE4x4oaMZlNAkBXXVo#{r4K1^9RH7SGe)mWTR%KdYAmhhanbF`5#V!Nbf8@Cjj7NIIZxbw)M9SgSYyc&l$ z5U9q+CSSwS>Uw6WE40?Ph^PyJses@;w4+-fB}8|FSv-&EPB(!f=eqS(*{Yz89`Gyd zdE^CoKF`X$3|{f8d2^1HIDeAYON_3$7>hc+QDpp|<>xAF_!8;;f50e9t;l9FBbp#h z4wJfiF3q{;;WpuUrHmxk%`z%$&=<&?(mbU^NRlJDjy_~og)FPW=A^~r)JhVr*hRHV zlxjtCiA0f`$yY(rGG$B0l5PSfj9EvXCxa^W?624A-MVH97lw0$C30E zO|F5YrR?)igIqNN{81q(CCs#-EAz1ylH_D&L^0w7BIlPg8@@;ZG@`4yq2*0v-aN0P Ld@lb|>E{0cm!m&t diff --git a/mace-bench/3rdparty/SevenNet/sevenn/scripts/backward_compatibility.py b/mace-bench/3rdparty/SevenNet/sevenn/scripts/backward_compatibility.py index b01f9b1..b8e81b1 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/scripts/backward_compatibility.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/scripts/backward_compatibility.py @@ -1,184 +1,184 @@ -""" -Debt -keep old pre-trained checkpoints unchanged. -""" - -import copy - -import torch - -import sevenn._keys as KEY - - -def version_tuple(v1): - v1 = tuple(map(int, v1.split('.'))) - return v1 - - -def patch_old_config(config): - version = config.get('version', None) - if not version: - raise ValueError('No version found in config') - - major, minor, _ = version.split('.')[:3] - major, minor = int(major), int(minor) - - if major == 0 and minor <= 9: - if config[KEY.CUTOFF_FUNCTION][KEY.CUTOFF_FUNCTION_NAME] == 'XPLOR': - config[KEY.CUTOFF_FUNCTION].pop('poly_cut_p_value', None) - if KEY.TRAIN_DENOMINTAOR not in config: - config[KEY.TRAIN_DENOMINTAOR] = config.pop('train_avg_num_neigh', False) - _opt = config.pop('optimize_by_reduce', None) - if _opt is False: - raise ValueError( - 'This checkpoint(optimize_by_reduce: False) is no longer supported' - ) - if KEY.CONV_DENOMINATOR not in config: - config[KEY.CONV_DENOMINATOR] = 0.0 - if KEY._NORMALIZE_SPH not in config: - config[KEY._NORMALIZE_SPH] = False - - return config - - -def map_old_model(old_model_state_dict): - """ - For compatibility with old namings (before 'correct' branch merged 2404XX) - Map old model's module names to new model's module names - """ - _old_module_name_mapping = { - 'EdgeEmbedding': 'edge_embedding', - 'reducing nn input to hidden': 'reduce_input_to_hidden', - 'reducing nn hidden to energy': 'reduce_hidden_to_energy', - 'rescale atomic energy': 'rescale_atomic_energy', - } - for i in range(10): - _old_module_name_mapping[f'{i} self connection intro'] = ( - f'{i}_self_connection_intro' - ) - _old_module_name_mapping[f'{i} convolution'] = f'{i}_convolution' - _old_module_name_mapping[f'{i} self interaction 2'] = ( - f'{i}_self_interaction_2' - ) - _old_module_name_mapping[f'{i} equivariant gate'] = f'{i}_equivariant_gate' - - new_model_state_dict = {} - for k, v in old_model_state_dict.items(): - key_name = k.split('.')[0] - follower = '.'.join(k.split('.')[1:]) - if 'denumerator' in follower: - follower = follower.replace('denumerator', 'denominator') - if key_name in _old_module_name_mapping: - new_key_name = _old_module_name_mapping[key_name] + '.' + follower - new_model_state_dict[new_key_name] = v - else: - new_model_state_dict[k] = v - return new_model_state_dict - - -def sort_old_convolution(model_now, state_dict): - from e3nn.o3 import wigner_3j - - """ - Reason1: we have to sort instructions of convolution to be compatible with - cuEquivariance. (therefore, sort weight) - Reason2: some of old convolution module's w3j coeff has flipped sign. This also - has to be fixed to be compatible with cuEquivarinace. - """ - - def patch(stct): - inst_old = copy.copy(conv._instructions_before_sort) - inst_old = [(inst[0], inst[1], inst[2]) for inst in inst_old] - del conv._instructions_before_sort - - conv_args = conv.convolution_kwargs - irreps_in1 = conv_args['irreps_in1'] - irreps_in2 = conv_args['irreps_in2'] - irreps_out = conv_args.get('irreps_out', conv_args.get('filter_irreps_out')) - - inst_sorted = sorted(inst_old, key=lambda x: x[2]) - - inst_sorted = [ - # in1, in2, out, weights - (inst[0], inst[1], inst[2], irreps_in1[inst[0]].mul) - for inst in inst_sorted - ] - - n = len(weight_nn.hs) - 2 - ww_key = f'{conv_key}.weight_nn.layer{n}.weight' - ww = stct[ww_key] - ww_sorted = [None] * len(inst_old) - - _prev_idx = 0 - for ist_src in inst_old: - for j, ist_dst in enumerate(inst_sorted): - if not all(ist_src[ii] == ist_dst[ii] for ii in range(3)): - continue - - numel = ist_dst[3] # weight num - ww_src = ww[:, _prev_idx : _prev_idx + numel] - l1, l2, l3 = ( - irreps_in1[ist_src[0]].ir.l, - irreps_in2[ist_src[1]].ir.l, - irreps_out[ist_src[2]].ir.l, - ) - if l1 > 0 and l2 > 0 and l3 > 0: - w3j_key = f'_w3j_{l1}_{l2}_{l3}' - conv_w3j_key = ( - f'{conv_key}.convolution._compiled_main_left_right.{w3j_key}' - ) - w3j_old = stct[conv_w3j_key] - w3j_now = wigner_3j(l1, l2, l3) - if not torch.allclose(w3j_old.to(w3j_now.device), w3j_now): - assert torch.allclose( - w3j_old.to(w3j_now.device), -1 * w3j_now - ) - ww_src = -1 * ww_src - stct[conv_w3j_key] *= -1 # stct updated - _prev_idx += numel - ww_sorted[j] = ww_src - ww_sorted = torch.cat(ww_sorted, dim=1) # type: ignore - stct[ww_key] = ww_sorted.clone() # stct updated - - conv_dicts = {} - for k, v in state_dict.items(): - key_name = k.split('.')[0] - if key_name.split('_')[1] == 'convolution': - if key_name not in conv_dicts: - conv_dicts[key_name] = {} - conv_dicts[key_name].update({k: v}) - - new_state_dict = {} - new_state_dict.update(state_dict) - for conv_key, conv_state_dict in conv_dicts.items(): - conv = model_now._modules[conv_key] - weight_nn = conv.weight_nn - patch(conv_state_dict) - new_state_dict.update(conv_state_dict) - - return new_state_dict - - -def patch_state_dict_if_old(state_dict, config_cp, now_model): - version = config_cp.get('version', None) - if not version: - raise ValueError('No version found in config') - vs = version.split('.') - vsuffix = '' - if len(vs) == 4: - vsuffix = vs[-1] - vs = version_tuple('.'.join(vs[:3])) - else: - vs = version_tuple('.'.join(vs)) - - if vs < version_tuple('0.10.0'): - state_dict = map_old_model(state_dict) - - # TODO: change version criteria before release!!! - # it causes problem if model is sorted but this function is called - # ... more robust way? idk - if vs < version_tuple('0.11.0') or ( - vs == version_tuple('0.11.0') and vsuffix == 'dev0' - ): - state_dict = sort_old_convolution(now_model, state_dict) - return state_dict +""" +Debt +keep old pre-trained checkpoints unchanged. +""" + +import copy + +import torch + +import sevenn._keys as KEY + + +def version_tuple(v1): + v1 = tuple(map(int, v1.split('.'))) + return v1 + + +def patch_old_config(config): + version = config.get('version', None) + if not version: + raise ValueError('No version found in config') + + major, minor, _ = version.split('.')[:3] + major, minor = int(major), int(minor) + + if major == 0 and minor <= 9: + if config[KEY.CUTOFF_FUNCTION][KEY.CUTOFF_FUNCTION_NAME] == 'XPLOR': + config[KEY.CUTOFF_FUNCTION].pop('poly_cut_p_value', None) + if KEY.TRAIN_DENOMINTAOR not in config: + config[KEY.TRAIN_DENOMINTAOR] = config.pop('train_avg_num_neigh', False) + _opt = config.pop('optimize_by_reduce', None) + if _opt is False: + raise ValueError( + 'This checkpoint(optimize_by_reduce: False) is no longer supported' + ) + if KEY.CONV_DENOMINATOR not in config: + config[KEY.CONV_DENOMINATOR] = 0.0 + if KEY._NORMALIZE_SPH not in config: + config[KEY._NORMALIZE_SPH] = False + + return config + + +def map_old_model(old_model_state_dict): + """ + For compatibility with old namings (before 'correct' branch merged 2404XX) + Map old model's module names to new model's module names + """ + _old_module_name_mapping = { + 'EdgeEmbedding': 'edge_embedding', + 'reducing nn input to hidden': 'reduce_input_to_hidden', + 'reducing nn hidden to energy': 'reduce_hidden_to_energy', + 'rescale atomic energy': 'rescale_atomic_energy', + } + for i in range(10): + _old_module_name_mapping[f'{i} self connection intro'] = ( + f'{i}_self_connection_intro' + ) + _old_module_name_mapping[f'{i} convolution'] = f'{i}_convolution' + _old_module_name_mapping[f'{i} self interaction 2'] = ( + f'{i}_self_interaction_2' + ) + _old_module_name_mapping[f'{i} equivariant gate'] = f'{i}_equivariant_gate' + + new_model_state_dict = {} + for k, v in old_model_state_dict.items(): + key_name = k.split('.')[0] + follower = '.'.join(k.split('.')[1:]) + if 'denumerator' in follower: + follower = follower.replace('denumerator', 'denominator') + if key_name in _old_module_name_mapping: + new_key_name = _old_module_name_mapping[key_name] + '.' + follower + new_model_state_dict[new_key_name] = v + else: + new_model_state_dict[k] = v + return new_model_state_dict + + +def sort_old_convolution(model_now, state_dict): + from e3nn.o3 import wigner_3j + + """ + Reason1: we have to sort instructions of convolution to be compatible with + cuEquivariance. (therefore, sort weight) + Reason2: some of old convolution module's w3j coeff has flipped sign. This also + has to be fixed to be compatible with cuEquivarinace. + """ + + def patch(stct): + inst_old = copy.copy(conv._instructions_before_sort) + inst_old = [(inst[0], inst[1], inst[2]) for inst in inst_old] + del conv._instructions_before_sort + + conv_args = conv.convolution_kwargs + irreps_in1 = conv_args['irreps_in1'] + irreps_in2 = conv_args['irreps_in2'] + irreps_out = conv_args.get('irreps_out', conv_args.get('filter_irreps_out')) + + inst_sorted = sorted(inst_old, key=lambda x: x[2]) + + inst_sorted = [ + # in1, in2, out, weights + (inst[0], inst[1], inst[2], irreps_in1[inst[0]].mul) + for inst in inst_sorted + ] + + n = len(weight_nn.hs) - 2 + ww_key = f'{conv_key}.weight_nn.layer{n}.weight' + ww = stct[ww_key] + ww_sorted = [None] * len(inst_old) + + _prev_idx = 0 + for ist_src in inst_old: + for j, ist_dst in enumerate(inst_sorted): + if not all(ist_src[ii] == ist_dst[ii] for ii in range(3)): + continue + + numel = ist_dst[3] # weight num + ww_src = ww[:, _prev_idx : _prev_idx + numel] + l1, l2, l3 = ( + irreps_in1[ist_src[0]].ir.l, + irreps_in2[ist_src[1]].ir.l, + irreps_out[ist_src[2]].ir.l, + ) + if l1 > 0 and l2 > 0 and l3 > 0: + w3j_key = f'_w3j_{l1}_{l2}_{l3}' + conv_w3j_key = ( + f'{conv_key}.convolution._compiled_main_left_right.{w3j_key}' + ) + w3j_old = stct[conv_w3j_key] + w3j_now = wigner_3j(l1, l2, l3) + if not torch.allclose(w3j_old.to(w3j_now.device), w3j_now): + assert torch.allclose( + w3j_old.to(w3j_now.device), -1 * w3j_now + ) + ww_src = -1 * ww_src + stct[conv_w3j_key] *= -1 # stct updated + _prev_idx += numel + ww_sorted[j] = ww_src + ww_sorted = torch.cat(ww_sorted, dim=1) # type: ignore + stct[ww_key] = ww_sorted.clone() # stct updated + + conv_dicts = {} + for k, v in state_dict.items(): + key_name = k.split('.')[0] + if key_name.split('_')[1] == 'convolution': + if key_name not in conv_dicts: + conv_dicts[key_name] = {} + conv_dicts[key_name].update({k: v}) + + new_state_dict = {} + new_state_dict.update(state_dict) + for conv_key, conv_state_dict in conv_dicts.items(): + conv = model_now._modules[conv_key] + weight_nn = conv.weight_nn + patch(conv_state_dict) + new_state_dict.update(conv_state_dict) + + return new_state_dict + + +def patch_state_dict_if_old(state_dict, config_cp, now_model): + version = config_cp.get('version', None) + if not version: + raise ValueError('No version found in config') + vs = version.split('.') + vsuffix = '' + if len(vs) == 4: + vsuffix = vs[-1] + vs = version_tuple('.'.join(vs[:3])) + else: + vs = version_tuple('.'.join(vs)) + + if vs < version_tuple('0.10.0'): + state_dict = map_old_model(state_dict) + + # TODO: change version criteria before release!!! + # it causes problem if model is sorted but this function is called + # ... more robust way? idk + if vs < version_tuple('0.11.0') or ( + vs == version_tuple('0.11.0') and vsuffix == 'dev0' + ): + state_dict = sort_old_convolution(now_model, state_dict) + return state_dict diff --git a/mace-bench/3rdparty/SevenNet/sevenn/scripts/convert_model_modality.py b/mace-bench/3rdparty/SevenNet/sevenn/scripts/convert_model_modality.py index 99882a6..5581e19 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/scripts/convert_model_modality.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/scripts/convert_model_modality.py @@ -1,301 +1,301 @@ -import math -from typing import List - -import torch -import torch.nn as nn -from e3nn.o3 import Irreps, Linear - -import sevenn._keys as KEY -from sevenn.model_build import build_E3_equivariant_model - -modal_module_dict = { - KEY.USE_MODAL_NODE_EMBEDDING: 'onehot_to_feature_x', - KEY.USE_MODAL_SELF_INTER_INTRO: 'self_interaction_1', - KEY.USE_MODAL_SELF_INTER_OUTRO: 'self_interaction_2', - KEY.USE_MODAL_OUTPUT_BLOCK: 'reduce_input_to_hidden', -} - - -def _get_scalar_index(irreps: Irreps): - scalar_indices = [] - for idx, (_, (l, p)) in enumerate(irreps): # noqa - if ( - l == 0 and p == 1 - ): # get index of parameter for scalar (0e), which is used for modality - scalar_indices.append(idx) - - return scalar_indices - - -def _reshape_weight_of_linear( - irreps_in: Irreps, irreps_out: Irreps, weight: torch.Tensor -) -> List[torch.Tensor]: - linear = Linear(irreps_in, irreps_out) - linear.weight = nn.Parameter(weight) - return list(linear.weight_views()) - - -def _erase_linear_modal_params( - model_state_dct: dict, - erase_modal_indices: List[int], - key: str, - irreps_in: Irreps, - irreps_out: Irreps, -): - orig_input_dim = irreps_in.count('0e') - new_input_dim = orig_input_dim - len(erase_modal_indices) - - orig_weight = model_state_dct[key + '.linear.weight'] - scalar_idx = _get_scalar_index(irreps_in) - linear_weight_list = _reshape_weight_of_linear( - irreps_in, irreps_out, orig_weight - ) - - new_weight_list = [] - - for idx, l_p_weight in enumerate(linear_weight_list[:-1]): - new_weight = torch.reshape(l_p_weight, (1, -1)).squeeze() - if idx in scalar_idx: - new_weight = new_weight * math.sqrt(new_input_dim / orig_input_dim) - - new_weight_list.append(new_weight) - - """ - Following works for normalization = `path`, which is not used in SEVENNet - for l_p_weight in linear_weight_list[:-1]: - new_weight_list.append(torch.reshape(l_p_weight, (1, -1)).squeeze()) - """ - - flattened_weight = torch.cat(new_weight_list) - - return flattened_weight - - -def _get_modal_weight_as_bias( - model_state_dct: dict, - key: str, - ref_index: int, - irreps_in: Irreps, - irreps_out: Irreps, -): - assert ref_index != -1 - input_dim = irreps_in.count('0e') - output_dim = irreps_out.count('0e') - orig_weight = model_state_dct[key + '.linear.weight'] - orig_bias = model_state_dct[key + '.linear.bias'] - if len(orig_bias) == 0: - orig_bias = torch.zeros(output_dim, dtype=orig_weight.dtype) - - modal_weight = _reshape_weight_of_linear( - irreps_in, irreps_out, orig_weight - )[-1] - - new_bias = orig_bias + modal_weight[ref_index] / math.sqrt(input_dim) - - return new_bias - - -def _append_modal_weight( - model_state_dct: dict, # state dict to be targeted - key: str, # linear weight modune name - irreps_in: Irreps, # irreps_in before modality append - irreps_out: Irreps, - append_number: int, -): - # This works for normalization = `element`, default in SEVENNet. - # (normalization = `path` is curruently deprecated in SEVENNet.) - input_dim = irreps_in.count('0e') - output_dim = irreps_out.count('0e') - new_input_dim = input_dim + append_number - orig_weight = model_state_dct[key + '.linear.weight'] - scalar_idx = _get_scalar_index(irreps_in) - linear_weight_list = _reshape_weight_of_linear( - irreps_in, irreps_out, orig_weight - ) - - new_weight_list = [] - - # TODO: combine following as function with _erase_linear_modal_params - - for idx, l_p_weight in enumerate(linear_weight_list): - new_weight = torch.reshape(l_p_weight, (1, -1)).squeeze() - if idx in scalar_idx: - new_weight = new_weight * math.sqrt(new_input_dim / input_dim) - - new_weight_list.append(new_weight) - - flattened_weight_list = [] - for l_p_weight in new_weight_list: - flattened_weight_list.append( - torch.reshape(l_p_weight, (1, -1)).squeeze() - ) - flattened_weight = torch.cat(flattened_weight_list) - - append_weight = torch.cat([ - flattened_weight, - torch.zeros(append_number * output_dim, dtype=flattened_weight.dtype), - ]) # zeros: starting from common model - - return append_weight - - -def get_single_modal_model_dct( - model_state_dct: dict, - config: dict, - ref_modal: str, - from_processing_cp: bool = False, - is_deploy: bool = False, -): - """ - Convert multimodal model state dictionary to single modal model. - Modal is selected by `ref_modal` - - `model_state_dct`: model state dictionary from multimodal checkpoint file - `config`: dictionary containing configuration of the checkpoint model - `ref_modal`: modal that are going to be converted - `from_processing_cp`: if True, use modal_map of the checkpoint file - `is_deploy`: if True, model is build with single-modal shift and scale - """ - if ( - not from_processing_cp and not config[KEY.USE_MODALITY] - ): # model is already single modal - return model_state_dct - - config[KEY.USE_BIAS_IN_LINEAR] = True - config['_deploy'] = is_deploy - - model = build_E3_equivariant_model(config) - del config['_deploy'] - key_add = '_cp' if from_processing_cp else '' - modal_type_dict = config[KEY.MODAL_MAP + key_add] - erase_modal_indices = range(len(modal_type_dict.keys())) # starts with 0 - - if ref_modal != 'common': - try: - ref_modal_index = modal_type_dict[ref_modal] - except: - raise KeyError( - f'{ref_modal} not in modal type. Use one of' - f' {modal_type_dict.keys()}.' - ) - - for module_key in model._modules.keys(): - for ( - use_modal_module_key, - modal_module_name, - ) in modal_module_dict.items(): - irreps_out = Irreps(model.get_irreps_in(module_key, 'irreps_out')) - # TODO: directly using "irreps_in" might not be compatible - # when changing `nn/linear.py` - output_dim = irreps_out.count('0e') - if ( - config[use_modal_module_key] - and modal_module_name in module_key - ): # this module is used for giving modality - - irreps_in = Irreps( - model.get_irreps_in(module_key, 'irreps_in') - ) - - new_bias = ( - torch.zeros(output_dim) - if ref_modal == 'common' - else _get_modal_weight_as_bias( - model_state_dct, - module_key, - ref_modal_index, - irreps_in, # type: ignore - irreps_out, # type: ignore - ) - ) - erased_modal_weight = _erase_linear_modal_params( - model_state_dct, - erase_modal_indices, - module_key, - irreps_in, # type: ignore - irreps_out, # type: ignore - ) - - model_state_dct[module_key + '.linear.weight'] = ( - erased_modal_weight - ) - model_state_dct[module_key + '.linear.bias'] = new_bias - elif modal_module_name in module_key: - model_state_dct[module_key + '.linear.bias'] = torch.zeros( - output_dim, - dtype=model_state_dct[module_key + '.linear.weight'].dtype, - ) - - final_block_key = 'reduce_hidden_to_energy' - model_state_dct[final_block_key + '.linear.bias'] = torch.tensor( - [0], dtype=model_state_dct[final_block_key + '.linear.weight'].dtype - ) - - if config[KEY.USE_MODAL_WISE_SHIFT] or config[KEY.USE_MODAL_WISE_SHIFT]: - rescaler_names = [] - if config[KEY.USE_MODAL_WISE_SHIFT]: - rescaler_names.append('shift') - if config[KEY.USE_MODAL_WISE_SCALE]: - rescaler_names.append('scale') - config[KEY.USE_MODAL_WISE_SHIFT] = False - config[KEY.USE_MODAL_WISE_SCALE] = False - for rescaler_name in rescaler_names: - rescaler_key = 'rescale_atomic_energy.' + rescaler_name - rescaler = model_state_dct[rescaler_key][ref_modal_index] - model_state_dct.update({rescaler_key: rescaler}) - config.update({rescaler_name: rescaler}) - - config[KEY.USE_MODALITY] = False - - return model_state_dct - - -def append_modality_to_model_dct( - model_state_dct: dict, - config: dict, - orig_num_modal: int, - append_modal_length: int, -): - """ - Append modal-wise parameters to the original linear layers. - This enables expanding modal to single/multi modal model checkpoint. - - `model_state_dct`: model state dictionary from multimodal checkpoint file - `config`: dictionary containing configuration of the checkpoint model - + modality appended - `orig_num_modal`: Number of modality used in original checkpoint - `append_modal_length`: Number of modality to be appended in new checkpoint. - """ - config_num_modal = config[KEY.NUM_MODALITIES] - config.update({KEY.NUM_MODALITIES: orig_num_modal, KEY.USE_MODALITY: True}) - - model = build_E3_equivariant_model(config) - - for module_key in model._modules.keys(): - for ( - use_modal_module_key, - modal_module_name, - ) in modal_module_dict.items(): - if ( - config[use_modal_module_key] - and modal_module_name in module_key - ): # this module is used for giving modality - irreps_in = model.get_irreps_in( - module_key, 'irreps_in' - ) - # TODO: directly using "irreps_in" might not be compatible - # when changing `nn/linear.py` - irreps_out = model.get_irreps_in(module_key, 'irreps_out') - irreps_in, irreps_out = Irreps(irreps_in), Irreps(irreps_out) - - append_weight = _append_modal_weight( - model_state_dct, - module_key, - irreps_in, # type: ignore - irreps_out, # type: ignore - append_modal_length, - ) - model_state_dct[module_key + '.linear.weight'] = append_weight - config[KEY.NUM_MODALITIES] = config_num_modal - - return model_state_dct +import math +from typing import List + +import torch +import torch.nn as nn +from e3nn.o3 import Irreps, Linear + +import sevenn._keys as KEY +from sevenn.model_build import build_E3_equivariant_model + +modal_module_dict = { + KEY.USE_MODAL_NODE_EMBEDDING: 'onehot_to_feature_x', + KEY.USE_MODAL_SELF_INTER_INTRO: 'self_interaction_1', + KEY.USE_MODAL_SELF_INTER_OUTRO: 'self_interaction_2', + KEY.USE_MODAL_OUTPUT_BLOCK: 'reduce_input_to_hidden', +} + + +def _get_scalar_index(irreps: Irreps): + scalar_indices = [] + for idx, (_, (l, p)) in enumerate(irreps): # noqa + if ( + l == 0 and p == 1 + ): # get index of parameter for scalar (0e), which is used for modality + scalar_indices.append(idx) + + return scalar_indices + + +def _reshape_weight_of_linear( + irreps_in: Irreps, irreps_out: Irreps, weight: torch.Tensor +) -> List[torch.Tensor]: + linear = Linear(irreps_in, irreps_out) + linear.weight = nn.Parameter(weight) + return list(linear.weight_views()) + + +def _erase_linear_modal_params( + model_state_dct: dict, + erase_modal_indices: List[int], + key: str, + irreps_in: Irreps, + irreps_out: Irreps, +): + orig_input_dim = irreps_in.count('0e') + new_input_dim = orig_input_dim - len(erase_modal_indices) + + orig_weight = model_state_dct[key + '.linear.weight'] + scalar_idx = _get_scalar_index(irreps_in) + linear_weight_list = _reshape_weight_of_linear( + irreps_in, irreps_out, orig_weight + ) + + new_weight_list = [] + + for idx, l_p_weight in enumerate(linear_weight_list[:-1]): + new_weight = torch.reshape(l_p_weight, (1, -1)).squeeze() + if idx in scalar_idx: + new_weight = new_weight * math.sqrt(new_input_dim / orig_input_dim) + + new_weight_list.append(new_weight) + + """ + Following works for normalization = `path`, which is not used in SEVENNet + for l_p_weight in linear_weight_list[:-1]: + new_weight_list.append(torch.reshape(l_p_weight, (1, -1)).squeeze()) + """ + + flattened_weight = torch.cat(new_weight_list) + + return flattened_weight + + +def _get_modal_weight_as_bias( + model_state_dct: dict, + key: str, + ref_index: int, + irreps_in: Irreps, + irreps_out: Irreps, +): + assert ref_index != -1 + input_dim = irreps_in.count('0e') + output_dim = irreps_out.count('0e') + orig_weight = model_state_dct[key + '.linear.weight'] + orig_bias = model_state_dct[key + '.linear.bias'] + if len(orig_bias) == 0: + orig_bias = torch.zeros(output_dim, dtype=orig_weight.dtype) + + modal_weight = _reshape_weight_of_linear( + irreps_in, irreps_out, orig_weight + )[-1] + + new_bias = orig_bias + modal_weight[ref_index] / math.sqrt(input_dim) + + return new_bias + + +def _append_modal_weight( + model_state_dct: dict, # state dict to be targeted + key: str, # linear weight modune name + irreps_in: Irreps, # irreps_in before modality append + irreps_out: Irreps, + append_number: int, +): + # This works for normalization = `element`, default in SEVENNet. + # (normalization = `path` is curruently deprecated in SEVENNet.) + input_dim = irreps_in.count('0e') + output_dim = irreps_out.count('0e') + new_input_dim = input_dim + append_number + orig_weight = model_state_dct[key + '.linear.weight'] + scalar_idx = _get_scalar_index(irreps_in) + linear_weight_list = _reshape_weight_of_linear( + irreps_in, irreps_out, orig_weight + ) + + new_weight_list = [] + + # TODO: combine following as function with _erase_linear_modal_params + + for idx, l_p_weight in enumerate(linear_weight_list): + new_weight = torch.reshape(l_p_weight, (1, -1)).squeeze() + if idx in scalar_idx: + new_weight = new_weight * math.sqrt(new_input_dim / input_dim) + + new_weight_list.append(new_weight) + + flattened_weight_list = [] + for l_p_weight in new_weight_list: + flattened_weight_list.append( + torch.reshape(l_p_weight, (1, -1)).squeeze() + ) + flattened_weight = torch.cat(flattened_weight_list) + + append_weight = torch.cat([ + flattened_weight, + torch.zeros(append_number * output_dim, dtype=flattened_weight.dtype), + ]) # zeros: starting from common model + + return append_weight + + +def get_single_modal_model_dct( + model_state_dct: dict, + config: dict, + ref_modal: str, + from_processing_cp: bool = False, + is_deploy: bool = False, +): + """ + Convert multimodal model state dictionary to single modal model. + Modal is selected by `ref_modal` + + `model_state_dct`: model state dictionary from multimodal checkpoint file + `config`: dictionary containing configuration of the checkpoint model + `ref_modal`: modal that are going to be converted + `from_processing_cp`: if True, use modal_map of the checkpoint file + `is_deploy`: if True, model is build with single-modal shift and scale + """ + if ( + not from_processing_cp and not config[KEY.USE_MODALITY] + ): # model is already single modal + return model_state_dct + + config[KEY.USE_BIAS_IN_LINEAR] = True + config['_deploy'] = is_deploy + + model = build_E3_equivariant_model(config) + del config['_deploy'] + key_add = '_cp' if from_processing_cp else '' + modal_type_dict = config[KEY.MODAL_MAP + key_add] + erase_modal_indices = range(len(modal_type_dict.keys())) # starts with 0 + + if ref_modal != 'common': + try: + ref_modal_index = modal_type_dict[ref_modal] + except: + raise KeyError( + f'{ref_modal} not in modal type. Use one of' + f' {modal_type_dict.keys()}.' + ) + + for module_key in model._modules.keys(): + for ( + use_modal_module_key, + modal_module_name, + ) in modal_module_dict.items(): + irreps_out = Irreps(model.get_irreps_in(module_key, 'irreps_out')) + # TODO: directly using "irreps_in" might not be compatible + # when changing `nn/linear.py` + output_dim = irreps_out.count('0e') + if ( + config[use_modal_module_key] + and modal_module_name in module_key + ): # this module is used for giving modality + + irreps_in = Irreps( + model.get_irreps_in(module_key, 'irreps_in') + ) + + new_bias = ( + torch.zeros(output_dim) + if ref_modal == 'common' + else _get_modal_weight_as_bias( + model_state_dct, + module_key, + ref_modal_index, + irreps_in, # type: ignore + irreps_out, # type: ignore + ) + ) + erased_modal_weight = _erase_linear_modal_params( + model_state_dct, + erase_modal_indices, + module_key, + irreps_in, # type: ignore + irreps_out, # type: ignore + ) + + model_state_dct[module_key + '.linear.weight'] = ( + erased_modal_weight + ) + model_state_dct[module_key + '.linear.bias'] = new_bias + elif modal_module_name in module_key: + model_state_dct[module_key + '.linear.bias'] = torch.zeros( + output_dim, + dtype=model_state_dct[module_key + '.linear.weight'].dtype, + ) + + final_block_key = 'reduce_hidden_to_energy' + model_state_dct[final_block_key + '.linear.bias'] = torch.tensor( + [0], dtype=model_state_dct[final_block_key + '.linear.weight'].dtype + ) + + if config[KEY.USE_MODAL_WISE_SHIFT] or config[KEY.USE_MODAL_WISE_SHIFT]: + rescaler_names = [] + if config[KEY.USE_MODAL_WISE_SHIFT]: + rescaler_names.append('shift') + if config[KEY.USE_MODAL_WISE_SCALE]: + rescaler_names.append('scale') + config[KEY.USE_MODAL_WISE_SHIFT] = False + config[KEY.USE_MODAL_WISE_SCALE] = False + for rescaler_name in rescaler_names: + rescaler_key = 'rescale_atomic_energy.' + rescaler_name + rescaler = model_state_dct[rescaler_key][ref_modal_index] + model_state_dct.update({rescaler_key: rescaler}) + config.update({rescaler_name: rescaler}) + + config[KEY.USE_MODALITY] = False + + return model_state_dct + + +def append_modality_to_model_dct( + model_state_dct: dict, + config: dict, + orig_num_modal: int, + append_modal_length: int, +): + """ + Append modal-wise parameters to the original linear layers. + This enables expanding modal to single/multi modal model checkpoint. + + `model_state_dct`: model state dictionary from multimodal checkpoint file + `config`: dictionary containing configuration of the checkpoint model + + modality appended + `orig_num_modal`: Number of modality used in original checkpoint + `append_modal_length`: Number of modality to be appended in new checkpoint. + """ + config_num_modal = config[KEY.NUM_MODALITIES] + config.update({KEY.NUM_MODALITIES: orig_num_modal, KEY.USE_MODALITY: True}) + + model = build_E3_equivariant_model(config) + + for module_key in model._modules.keys(): + for ( + use_modal_module_key, + modal_module_name, + ) in modal_module_dict.items(): + if ( + config[use_modal_module_key] + and modal_module_name in module_key + ): # this module is used for giving modality + irreps_in = model.get_irreps_in( + module_key, 'irreps_in' + ) + # TODO: directly using "irreps_in" might not be compatible + # when changing `nn/linear.py` + irreps_out = model.get_irreps_in(module_key, 'irreps_out') + irreps_in, irreps_out = Irreps(irreps_in), Irreps(irreps_out) + + append_weight = _append_modal_weight( + model_state_dct, + module_key, + irreps_in, # type: ignore + irreps_out, # type: ignore + append_modal_length, + ) + model_state_dct[module_key + '.linear.weight'] = append_weight + config[KEY.NUM_MODALITIES] = config_num_modal + + return model_state_dct diff --git a/mace-bench/3rdparty/SevenNet/sevenn/scripts/deploy.py b/mace-bench/3rdparty/SevenNet/sevenn/scripts/deploy.py index c069579..51ded15 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/scripts/deploy.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/scripts/deploy.py @@ -1,148 +1,148 @@ -import os -from datetime import datetime -from typing import Optional - -import e3nn.util.jit -import torch -import torch.nn -from ase.data import chemical_symbols - -import sevenn._keys as KEY -from sevenn import __version__ -from sevenn.model_build import build_E3_equivariant_model -from sevenn.util import load_checkpoint - - -def deploy(checkpoint, fname='deployed_serial.pt', modal: Optional[str] = None): - """ - This method is messy to avoid changes in pair_e3gnn.cpp, while - refactoring python part. - If changes the behavior, and accordingly pair_e3gnn.cpp, - we have to recompile LAMMPS (which I always want to procrastinate) - """ - from sevenn.nn.edge_embedding import EdgePreprocess - from sevenn.nn.force_output import ForceStressOutput - - cp = load_checkpoint(checkpoint) - model, config = cp.build_model('e3nn'), cp.config - - model.prepand_module('edge_preprocess', EdgePreprocess(True)) - grad_module = ForceStressOutput() - model.replace_module('force_output', grad_module) - new_grad_key = grad_module.get_grad_key() - model.key_grad = new_grad_key - if hasattr(model, 'eval_type_map'): - setattr(model, 'eval_type_map', False) - - if modal: - model.prepare_modal_deploy(modal) - elif model.modal_map is not None and len(model.modal_map) >= 1: - raise ValueError( - f'Modal is not given. It has: {list(model.modal_map.keys())}' - ) - - model.set_is_batch_data(False) - model.eval() - - model = e3nn.util.jit.script(model) - model = torch.jit.freeze(model) - - # make some config need for md - md_configs = {} - type_map = config[KEY.TYPE_MAP] - chem_list = '' - for Z in type_map.keys(): - chem_list += chemical_symbols[Z] + ' ' - chem_list.strip() - md_configs.update({'chemical_symbols_to_index': chem_list}) - md_configs.update({'cutoff': str(config[KEY.CUTOFF])}) - md_configs.update({'num_species': str(config[KEY.NUM_SPECIES])}) - md_configs.update( - {'model_type': config.pop(KEY.MODEL_TYPE, 'E3_equivariant_model')} - ) - md_configs.update({'version': __version__}) - md_configs.update({'dtype': config.pop(KEY.DTYPE, 'single')}) - md_configs.update({'time': datetime.now().strftime('%Y-%m-%d')}) - - if fname.endswith('.pt') is False: - fname += '.pt' - torch.jit.save(model, fname, _extra_files=md_configs) - - -# TODO: build model only once -def deploy_parallel( - checkpoint, fname='deployed_parallel', modal: Optional[str] = None -): - # Additional layer for ghost atom (and copy parameters from original) - GHOST_LAYERS_KEYS = ['onehot_to_feature_x', '0_self_interaction_1'] - - cp = load_checkpoint(checkpoint) - model, config = cp.build_model('e3nn'), cp.config - config[KEY.CUEQUIVARIANCE_CONFIG] = {'use': False} - model_state_dct = model.state_dict() - - model_list = build_E3_equivariant_model(config, parallel=True) - dct_temp = {} - copy_counter = {gk: 0 for gk in GHOST_LAYERS_KEYS} - for ghost_layer_key in GHOST_LAYERS_KEYS: - for key, val in model_state_dct.items(): - if not key.startswith(ghost_layer_key): - continue - dct_temp.update({f'ghost_{key}': val}) - copy_counter[ghost_layer_key] += 1 - # Ensure reference weights are copied from state dict - assert all(x > 0 for x in copy_counter.values()) - - model_state_dct.update(dct_temp) - - for model_part in model_list: - missing, _ = model_part.load_state_dict(model_state_dct, strict=False) - if hasattr(model_part, 'eval_type_map'): - setattr(model_part, 'eval_type_map', False) - # Ensure all values are inserted - assert len(missing) == 0, missing - - if modal: - model_list[0].prepare_modal_deploy(modal) - elif model_list[0].modal_map is not None: - raise ValueError( - f'Modal is not given. It has: {list(model_list[0].modal_map.keys())}' - ) - - # prepare some extra information for MD - md_configs = {} - type_map = config[KEY.TYPE_MAP] - - chem_list = '' - for Z in type_map.keys(): - chem_list += chemical_symbols[Z] + ' ' - chem_list.strip() - - comm_size = max( - [ - seg._modules[f'{t}_convolution']._comm_size # type: ignore - for t, seg in enumerate(model_list) - ] - ) - - md_configs.update({'chemical_symbols_to_index': chem_list}) - md_configs.update({'cutoff': str(config[KEY.CUTOFF])}) - md_configs.update({'num_species': str(config[KEY.NUM_SPECIES])}) - md_configs.update({'comm_size': str(comm_size)}) - md_configs.update( - {'model_type': config.pop(KEY.MODEL_TYPE, 'E3_equivariant_model')} - ) - md_configs.update({'version': __version__}) - md_configs.update({'dtype': config.pop(KEY.DTYPE, 'single')}) - md_configs.update({'time': datetime.now().strftime('%Y-%m-%d')}) - - os.makedirs(fname) - for idx, model in enumerate(model_list): - fname_full = f'{fname}/deployed_parallel_{idx}.pt' - model.set_is_batch_data(False) - model.eval() - - model = e3nn.util.jit.script(model) - model = torch.jit.freeze(model) - - torch.jit.save(model, fname_full, _extra_files=md_configs) +import os +from datetime import datetime +from typing import Optional + +import e3nn.util.jit +import torch +import torch.nn +from ase.data import chemical_symbols + +import sevenn._keys as KEY +from sevenn import __version__ +from sevenn.model_build import build_E3_equivariant_model +from sevenn.util import load_checkpoint + + +def deploy(checkpoint, fname='deployed_serial.pt', modal: Optional[str] = None): + """ + This method is messy to avoid changes in pair_e3gnn.cpp, while + refactoring python part. + If changes the behavior, and accordingly pair_e3gnn.cpp, + we have to recompile LAMMPS (which I always want to procrastinate) + """ + from sevenn.nn.edge_embedding import EdgePreprocess + from sevenn.nn.force_output import ForceStressOutput + + cp = load_checkpoint(checkpoint) + model, config = cp.build_model('e3nn'), cp.config + + model.prepand_module('edge_preprocess', EdgePreprocess(True)) + grad_module = ForceStressOutput() + model.replace_module('force_output', grad_module) + new_grad_key = grad_module.get_grad_key() + model.key_grad = new_grad_key + if hasattr(model, 'eval_type_map'): + setattr(model, 'eval_type_map', False) + + if modal: + model.prepare_modal_deploy(modal) + elif model.modal_map is not None and len(model.modal_map) >= 1: + raise ValueError( + f'Modal is not given. It has: {list(model.modal_map.keys())}' + ) + + model.set_is_batch_data(False) + model.eval() + + model = e3nn.util.jit.script(model) + model = torch.jit.freeze(model) + + # make some config need for md + md_configs = {} + type_map = config[KEY.TYPE_MAP] + chem_list = '' + for Z in type_map.keys(): + chem_list += chemical_symbols[Z] + ' ' + chem_list.strip() + md_configs.update({'chemical_symbols_to_index': chem_list}) + md_configs.update({'cutoff': str(config[KEY.CUTOFF])}) + md_configs.update({'num_species': str(config[KEY.NUM_SPECIES])}) + md_configs.update( + {'model_type': config.pop(KEY.MODEL_TYPE, 'E3_equivariant_model')} + ) + md_configs.update({'version': __version__}) + md_configs.update({'dtype': config.pop(KEY.DTYPE, 'single')}) + md_configs.update({'time': datetime.now().strftime('%Y-%m-%d')}) + + if fname.endswith('.pt') is False: + fname += '.pt' + torch.jit.save(model, fname, _extra_files=md_configs) + + +# TODO: build model only once +def deploy_parallel( + checkpoint, fname='deployed_parallel', modal: Optional[str] = None +): + # Additional layer for ghost atom (and copy parameters from original) + GHOST_LAYERS_KEYS = ['onehot_to_feature_x', '0_self_interaction_1'] + + cp = load_checkpoint(checkpoint) + model, config = cp.build_model('e3nn'), cp.config + config[KEY.CUEQUIVARIANCE_CONFIG] = {'use': False} + model_state_dct = model.state_dict() + + model_list = build_E3_equivariant_model(config, parallel=True) + dct_temp = {} + copy_counter = {gk: 0 for gk in GHOST_LAYERS_KEYS} + for ghost_layer_key in GHOST_LAYERS_KEYS: + for key, val in model_state_dct.items(): + if not key.startswith(ghost_layer_key): + continue + dct_temp.update({f'ghost_{key}': val}) + copy_counter[ghost_layer_key] += 1 + # Ensure reference weights are copied from state dict + assert all(x > 0 for x in copy_counter.values()) + + model_state_dct.update(dct_temp) + + for model_part in model_list: + missing, _ = model_part.load_state_dict(model_state_dct, strict=False) + if hasattr(model_part, 'eval_type_map'): + setattr(model_part, 'eval_type_map', False) + # Ensure all values are inserted + assert len(missing) == 0, missing + + if modal: + model_list[0].prepare_modal_deploy(modal) + elif model_list[0].modal_map is not None: + raise ValueError( + f'Modal is not given. It has: {list(model_list[0].modal_map.keys())}' + ) + + # prepare some extra information for MD + md_configs = {} + type_map = config[KEY.TYPE_MAP] + + chem_list = '' + for Z in type_map.keys(): + chem_list += chemical_symbols[Z] + ' ' + chem_list.strip() + + comm_size = max( + [ + seg._modules[f'{t}_convolution']._comm_size # type: ignore + for t, seg in enumerate(model_list) + ] + ) + + md_configs.update({'chemical_symbols_to_index': chem_list}) + md_configs.update({'cutoff': str(config[KEY.CUTOFF])}) + md_configs.update({'num_species': str(config[KEY.NUM_SPECIES])}) + md_configs.update({'comm_size': str(comm_size)}) + md_configs.update( + {'model_type': config.pop(KEY.MODEL_TYPE, 'E3_equivariant_model')} + ) + md_configs.update({'version': __version__}) + md_configs.update({'dtype': config.pop(KEY.DTYPE, 'single')}) + md_configs.update({'time': datetime.now().strftime('%Y-%m-%d')}) + + os.makedirs(fname) + for idx, model in enumerate(model_list): + fname_full = f'{fname}/deployed_parallel_{idx}.pt' + model.set_is_batch_data(False) + model.eval() + + model = e3nn.util.jit.script(model) + model = torch.jit.freeze(model) + + torch.jit.save(model, fname_full, _extra_files=md_configs) diff --git a/mace-bench/3rdparty/SevenNet/sevenn/scripts/graph_build.py b/mace-bench/3rdparty/SevenNet/sevenn/scripts/graph_build.py index f036475..af1b162 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/scripts/graph_build.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/scripts/graph_build.py @@ -1,119 +1,119 @@ -import os -from typing import List, Optional - -from sevenn.logger import Logger -from sevenn.train.dataset import AtomGraphDataset -from sevenn.util import unique_filepath - - -def build_sevennet_graph_dataset( - source: List[str], - cutoff: float, - num_cores: int, - out: str, - filename: str, - metadata: Optional[dict] = None, - **fmt_kwargs, -): - from sevenn.train.graph_dataset import SevenNetGraphDataset - - log = Logger() - if metadata is None: - metadata = {} - - log.timer_start('graph_build') - db = SevenNetGraphDataset( - cutoff=cutoff, - root=out, - files=source, - processed_name=filename, - process_num_cores=num_cores, - **fmt_kwargs, - ) - log.timer_end('graph_build', 'graph build time') - log.writeline(f'Graph saved: {db.processed_paths[0]}') - - log.bar() - for k, v in metadata.items(): - log.format_k_v(k, v, write=True) - log.bar() - - log.writeline('Distribution:') - log.statistic_write(db.statistics) - log.format_k_v('# atoms (node)', db.natoms, write=True) - log.format_k_v('# structures (graph)', len(db), write=True) - - -def dataset_finalize(dataset, metadata, out): - """ - Deprecated - """ - natoms = dataset.get_natoms() - species = dataset.get_species() - metadata = { - **metadata, - 'natoms': natoms, - 'species': species, - } - dataset.meta = metadata - - if os.path.isdir(out): - out = os.path.join(out, 'graph_built.sevenn_data') - elif out.endswith('.sevenn_data') is False: - out = out + '.sevenn_data' - out = unique_filepath(out) - - log = Logger() - log.writeline('The metadata of the dataset is...') - for k, v in metadata.items(): - log.format_k_v(k, v, write=True) - dataset.save(out) - log.writeline(f'dataset is saved to {out}') - - return dataset - - -def build_script( - source: List[str], - cutoff: float, - num_cores: int, - out: str, - metadata: Optional[dict] = None, - **fmt_kwargs, -): - """ - Deprecated - """ - from sevenn.train.dataload import file_to_dataset, match_reader - - if metadata is None: - metadata = {} - log = Logger() - - dataset = AtomGraphDataset({}, cutoff) - common_args = { - 'cutoff': cutoff, - 'cores': num_cores, - 'label': 'graph_build', - } - log.timer_start('graph_build') - for path in source: - if os.path.isdir(path): - continue - log.writeline(f'Read: {path}') - basename = os.path.basename(path) - if 'structure_list' in basename: - fmt = 'structure_list' - else: - fmt = 'ase' - reader, rmeta = match_reader(fmt, **fmt_kwargs) - metadata.update(**rmeta) - dataset.augment( - file_to_dataset( - file=path, - reader=reader, - **common_args, - ) - ) - log.timer_end('graph_build', 'graph build time') - dataset_finalize(dataset, metadata, out) +import os +from typing import List, Optional + +from sevenn.logger import Logger +from sevenn.train.dataset import AtomGraphDataset +from sevenn.util import unique_filepath + + +def build_sevennet_graph_dataset( + source: List[str], + cutoff: float, + num_cores: int, + out: str, + filename: str, + metadata: Optional[dict] = None, + **fmt_kwargs, +): + from sevenn.train.graph_dataset import SevenNetGraphDataset + + log = Logger() + if metadata is None: + metadata = {} + + log.timer_start('graph_build') + db = SevenNetGraphDataset( + cutoff=cutoff, + root=out, + files=source, + processed_name=filename, + process_num_cores=num_cores, + **fmt_kwargs, + ) + log.timer_end('graph_build', 'graph build time') + log.writeline(f'Graph saved: {db.processed_paths[0]}') + + log.bar() + for k, v in metadata.items(): + log.format_k_v(k, v, write=True) + log.bar() + + log.writeline('Distribution:') + log.statistic_write(db.statistics) + log.format_k_v('# atoms (node)', db.natoms, write=True) + log.format_k_v('# structures (graph)', len(db), write=True) + + +def dataset_finalize(dataset, metadata, out): + """ + Deprecated + """ + natoms = dataset.get_natoms() + species = dataset.get_species() + metadata = { + **metadata, + 'natoms': natoms, + 'species': species, + } + dataset.meta = metadata + + if os.path.isdir(out): + out = os.path.join(out, 'graph_built.sevenn_data') + elif out.endswith('.sevenn_data') is False: + out = out + '.sevenn_data' + out = unique_filepath(out) + + log = Logger() + log.writeline('The metadata of the dataset is...') + for k, v in metadata.items(): + log.format_k_v(k, v, write=True) + dataset.save(out) + log.writeline(f'dataset is saved to {out}') + + return dataset + + +def build_script( + source: List[str], + cutoff: float, + num_cores: int, + out: str, + metadata: Optional[dict] = None, + **fmt_kwargs, +): + """ + Deprecated + """ + from sevenn.train.dataload import file_to_dataset, match_reader + + if metadata is None: + metadata = {} + log = Logger() + + dataset = AtomGraphDataset({}, cutoff) + common_args = { + 'cutoff': cutoff, + 'cores': num_cores, + 'label': 'graph_build', + } + log.timer_start('graph_build') + for path in source: + if os.path.isdir(path): + continue + log.writeline(f'Read: {path}') + basename = os.path.basename(path) + if 'structure_list' in basename: + fmt = 'structure_list' + else: + fmt = 'ase' + reader, rmeta = match_reader(fmt, **fmt_kwargs) + metadata.update(**rmeta) + dataset.augment( + file_to_dataset( + file=path, + reader=reader, + **common_args, + ) + ) + log.timer_end('graph_build', 'graph build time') + dataset_finalize(dataset, metadata, out) diff --git a/mace-bench/3rdparty/SevenNet/sevenn/scripts/inference.py b/mace-bench/3rdparty/SevenNet/sevenn/scripts/inference.py index 355cc25..93c998f 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/scripts/inference.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/scripts/inference.py @@ -1,227 +1,227 @@ -import csv -import os -from typing import Iterable, List, Optional, Union - -import numpy as np -from torch_geometric.loader import DataLoader -from tqdm import tqdm - -import sevenn._keys as KEY -import sevenn.util as util -from sevenn.atom_graph_data import AtomGraphData -from sevenn.train.graph_dataset import SevenNetGraphDataset -from sevenn.train.modal_dataset import SevenNetMultiModalDataset - - -def write_inference_csv(output_list, out): - for i, output in enumerate(output_list): - output = output.fit_dimension() - output[KEY.STRESS] = output[KEY.STRESS] * 1602.1766208 - output[KEY.PRED_STRESS] = output[KEY.PRED_STRESS] * 1602.1766208 - output_list[i] = output.to_numpy_dict() - - per_graph_keys = [ - KEY.NUM_ATOMS, - KEY.USER_LABEL, - KEY.ENERGY, - KEY.PRED_TOTAL_ENERGY, - KEY.STRESS, - KEY.PRED_STRESS, - ] - - per_atom_keys = [ - KEY.ATOMIC_NUMBERS, - KEY.ATOMIC_ENERGY, - KEY.POS, - KEY.FORCE, - KEY.PRED_FORCE, - ] - - def unfold_dct_val(dct, keys, suffix_list=None): - res = {} - if suffix_list is None: - suffix_list = range(100) - for k in keys: - if k not in dct: - res[k] = '-' - elif isinstance(dct[k], np.ndarray) and dct[k].ndim != 0: - res.update( - {f'{k}_{suffix_list[i]}': v for i, v in enumerate(dct[k])} - ) - else: - res[k] = dct[k] - return res - - def per_atom_dct_list(dct, keys): - sfx_list = ['x', 'y', 'z'] - res = [] - natoms = dct[KEY.NUM_ATOMS] - extracted = {k: dct[k] for k in keys} - for i in range(natoms): - raw = {} - raw.update({k: v[i] for k, v in extracted.items()}) - per_atom_dct = unfold_dct_val(raw, keys, suffix_list=sfx_list) - res.append(per_atom_dct) - return res - - try: - with open(f'{out}/info.csv', 'w', newline='') as f: - header = output_list[0][KEY.INFO].keys() - writer = csv.DictWriter(f, fieldnames=header) - writer.writeheader() - for output in output_list: - writer.writerow(output[KEY.INFO]) - except (KeyError, TypeError, AttributeError, csv.Error) as e: - print(e) - print('failed to write meta data, info.csv is not written') - - with open(f'{out}/per_graph.csv', 'w', newline='') as f: - sfx_list = ['xx', 'yy', 'zz', 'xy', 'yz', 'zx'] # for stress - writer = None - for output in output_list: - cell_dct = {KEY.CELL: output[KEY.CELL]} - cell_dct = unfold_dct_val(cell_dct, [KEY.CELL], ['a', 'b', 'c']) - data = { - **unfold_dct_val(output, per_graph_keys, sfx_list), - **cell_dct, - } - if writer is None: - writer = csv.DictWriter(f, fieldnames=data.keys()) - writer.writeheader() - writer.writerow(data) - - with open(f'{out}/per_atom.csv', 'w', newline='') as f: - writer = None - for i, output in enumerate(output_list): - list_of_dct = per_atom_dct_list(output, per_atom_keys) - for j, dct in enumerate(list_of_dct): - idx_dct = {'stct_id': i, 'atom_id': j} - data = {**idx_dct, **dct} - if writer is None: - writer = csv.DictWriter(f, fieldnames=data.keys()) - writer.writeheader() - writer.writerow(data) - - -def _patch_data_info( - graph_list: Iterable[AtomGraphData], full_file_list: List[str] -) -> None: - keys = set() - for graph, path in zip(graph_list, full_file_list): - if KEY.INFO not in graph: - graph[KEY.INFO] = {} - graph[KEY.INFO].update({'file': os.path.abspath(path)}) - keys.update(graph[KEY.INFO].keys()) - - # save only safe subset of info (for batching) - for graph in graph_list: - info_dict = graph[KEY.INFO] - info_dict.update({k: '' for k in keys if k not in info_dict}) - - -def inference( - checkpoint: str, - targets: Union[str, List[str]], - output_dir: str, - num_workers: int = 1, - device: str = 'cpu', - batch_size: int = 4, - save_graph: bool = False, - allow_unlabeled: bool = False, - modal: Optional[str] = None, - **data_kwargs, -) -> None: - """ - Inference model on the target dataset, writes - per_graph, per_atom inference results in csv format - to the output_dir - If a given target doesn't have EFS key, it puts dummy - values. - - Args: - checkpoint: model checkpoint path, - target: path, or list of path to evaluate. Supports - ASE readable, sevenn_data/*.pt, .sevenn_data, and - structure_list - output_dir: directory to write results - num_workers: number of workers to build graph - device: device to evaluate, defaults to 'auto' - batch_size: batch size for inference - save_grpah: if True, save preprocessed graph to output dir - data_kwargs: keyword arguments used when reading targets, - for example, given index='-1', only the last snapshot - will be evaluated if it was ASE readable. - While this function can handle different types of targets - at once, it will not work smoothly with data_kwargs - - """ - model, _ = util.model_from_checkpoint(checkpoint) - cutoff = model.cutoff - - if modal: - if model.modal_map is None: - raise ValueError('Modality given, but model has no modal_map') - if modal not in model.modal_map: - _modals = list(model.modal_map.keys()) - raise ValueError(f'Unknown modal {modal} (not in {_modals})') - - if isinstance(targets, str): - targets = [targets] - - full_file_list = [] - if save_graph: - dataset = SevenNetGraphDataset( - cutoff=cutoff, - root=output_dir, - files=targets, - process_num_cores=num_workers, - processed_name='saved_graph.pt', - **data_kwargs, - ) - full_file_list = dataset.full_file_list # TODO: not used currently - else: - dataset = [] - for file in targets: - tmplist = SevenNetGraphDataset.file_to_graph_list( - file, - cutoff=cutoff, - num_cores=num_workers, - allow_unlabeled=allow_unlabeled, - **data_kwargs, - ) - dataset.extend(tmplist) - full_file_list.extend([os.path.abspath(file)] * len(tmplist)) - if ( - full_file_list is not None - and len(full_file_list) == len(dataset) - and not isinstance(dataset, SevenNetGraphDataset) - ): - _patch_data_info(dataset, full_file_list) # type: ignore - - if modal: - dataset = SevenNetMultiModalDataset({modal: dataset}) # type: ignore - - loader = DataLoader(dataset, batch_size, shuffle=False) # type: ignore - - model.to(device) - model.set_is_batch_data(True) - model.eval() - - rec = util.get_error_recorder() - output_list = [] - - for batch in tqdm(loader): - batch = batch.to(device) - output = model(batch).detach().cpu() - rec.update(output) - output_list.extend(util.to_atom_graph_list(output)) - - errors = rec.epoch_forward() - - if not os.path.exists(output_dir): - os.makedirs(output_dir) - with open(os.path.join(output_dir, 'errors.txt'), 'w', encoding='utf-8') as f: - for key, val in errors.items(): - f.write(f'{key}: {val}\n') - - write_inference_csv(output_list, output_dir) +import csv +import os +from typing import Iterable, List, Optional, Union + +import numpy as np +from torch_geometric.loader import DataLoader +from tqdm import tqdm + +import sevenn._keys as KEY +import sevenn.util as util +from sevenn.atom_graph_data import AtomGraphData +from sevenn.train.graph_dataset import SevenNetGraphDataset +from sevenn.train.modal_dataset import SevenNetMultiModalDataset + + +def write_inference_csv(output_list, out): + for i, output in enumerate(output_list): + output = output.fit_dimension() + output[KEY.STRESS] = output[KEY.STRESS] * 1602.1766208 + output[KEY.PRED_STRESS] = output[KEY.PRED_STRESS] * 1602.1766208 + output_list[i] = output.to_numpy_dict() + + per_graph_keys = [ + KEY.NUM_ATOMS, + KEY.USER_LABEL, + KEY.ENERGY, + KEY.PRED_TOTAL_ENERGY, + KEY.STRESS, + KEY.PRED_STRESS, + ] + + per_atom_keys = [ + KEY.ATOMIC_NUMBERS, + KEY.ATOMIC_ENERGY, + KEY.POS, + KEY.FORCE, + KEY.PRED_FORCE, + ] + + def unfold_dct_val(dct, keys, suffix_list=None): + res = {} + if suffix_list is None: + suffix_list = range(100) + for k in keys: + if k not in dct: + res[k] = '-' + elif isinstance(dct[k], np.ndarray) and dct[k].ndim != 0: + res.update( + {f'{k}_{suffix_list[i]}': v for i, v in enumerate(dct[k])} + ) + else: + res[k] = dct[k] + return res + + def per_atom_dct_list(dct, keys): + sfx_list = ['x', 'y', 'z'] + res = [] + natoms = dct[KEY.NUM_ATOMS] + extracted = {k: dct[k] for k in keys} + for i in range(natoms): + raw = {} + raw.update({k: v[i] for k, v in extracted.items()}) + per_atom_dct = unfold_dct_val(raw, keys, suffix_list=sfx_list) + res.append(per_atom_dct) + return res + + try: + with open(f'{out}/info.csv', 'w', newline='') as f: + header = output_list[0][KEY.INFO].keys() + writer = csv.DictWriter(f, fieldnames=header) + writer.writeheader() + for output in output_list: + writer.writerow(output[KEY.INFO]) + except (KeyError, TypeError, AttributeError, csv.Error) as e: + print(e) + print('failed to write meta data, info.csv is not written') + + with open(f'{out}/per_graph.csv', 'w', newline='') as f: + sfx_list = ['xx', 'yy', 'zz', 'xy', 'yz', 'zx'] # for stress + writer = None + for output in output_list: + cell_dct = {KEY.CELL: output[KEY.CELL]} + cell_dct = unfold_dct_val(cell_dct, [KEY.CELL], ['a', 'b', 'c']) + data = { + **unfold_dct_val(output, per_graph_keys, sfx_list), + **cell_dct, + } + if writer is None: + writer = csv.DictWriter(f, fieldnames=data.keys()) + writer.writeheader() + writer.writerow(data) + + with open(f'{out}/per_atom.csv', 'w', newline='') as f: + writer = None + for i, output in enumerate(output_list): + list_of_dct = per_atom_dct_list(output, per_atom_keys) + for j, dct in enumerate(list_of_dct): + idx_dct = {'stct_id': i, 'atom_id': j} + data = {**idx_dct, **dct} + if writer is None: + writer = csv.DictWriter(f, fieldnames=data.keys()) + writer.writeheader() + writer.writerow(data) + + +def _patch_data_info( + graph_list: Iterable[AtomGraphData], full_file_list: List[str] +) -> None: + keys = set() + for graph, path in zip(graph_list, full_file_list): + if KEY.INFO not in graph: + graph[KEY.INFO] = {} + graph[KEY.INFO].update({'file': os.path.abspath(path)}) + keys.update(graph[KEY.INFO].keys()) + + # save only safe subset of info (for batching) + for graph in graph_list: + info_dict = graph[KEY.INFO] + info_dict.update({k: '' for k in keys if k not in info_dict}) + + +def inference( + checkpoint: str, + targets: Union[str, List[str]], + output_dir: str, + num_workers: int = 1, + device: str = 'cpu', + batch_size: int = 4, + save_graph: bool = False, + allow_unlabeled: bool = False, + modal: Optional[str] = None, + **data_kwargs, +) -> None: + """ + Inference model on the target dataset, writes + per_graph, per_atom inference results in csv format + to the output_dir + If a given target doesn't have EFS key, it puts dummy + values. + + Args: + checkpoint: model checkpoint path, + target: path, or list of path to evaluate. Supports + ASE readable, sevenn_data/*.pt, .sevenn_data, and + structure_list + output_dir: directory to write results + num_workers: number of workers to build graph + device: device to evaluate, defaults to 'auto' + batch_size: batch size for inference + save_grpah: if True, save preprocessed graph to output dir + data_kwargs: keyword arguments used when reading targets, + for example, given index='-1', only the last snapshot + will be evaluated if it was ASE readable. + While this function can handle different types of targets + at once, it will not work smoothly with data_kwargs + + """ + model, _ = util.model_from_checkpoint(checkpoint) + cutoff = model.cutoff + + if modal: + if model.modal_map is None: + raise ValueError('Modality given, but model has no modal_map') + if modal not in model.modal_map: + _modals = list(model.modal_map.keys()) + raise ValueError(f'Unknown modal {modal} (not in {_modals})') + + if isinstance(targets, str): + targets = [targets] + + full_file_list = [] + if save_graph: + dataset = SevenNetGraphDataset( + cutoff=cutoff, + root=output_dir, + files=targets, + process_num_cores=num_workers, + processed_name='saved_graph.pt', + **data_kwargs, + ) + full_file_list = dataset.full_file_list # TODO: not used currently + else: + dataset = [] + for file in targets: + tmplist = SevenNetGraphDataset.file_to_graph_list( + file, + cutoff=cutoff, + num_cores=num_workers, + allow_unlabeled=allow_unlabeled, + **data_kwargs, + ) + dataset.extend(tmplist) + full_file_list.extend([os.path.abspath(file)] * len(tmplist)) + if ( + full_file_list is not None + and len(full_file_list) == len(dataset) + and not isinstance(dataset, SevenNetGraphDataset) + ): + _patch_data_info(dataset, full_file_list) # type: ignore + + if modal: + dataset = SevenNetMultiModalDataset({modal: dataset}) # type: ignore + + loader = DataLoader(dataset, batch_size, shuffle=False) # type: ignore + + model.to(device) + model.set_is_batch_data(True) + model.eval() + + rec = util.get_error_recorder() + output_list = [] + + for batch in tqdm(loader): + batch = batch.to(device) + output = model(batch).detach().cpu() + rec.update(output) + output_list.extend(util.to_atom_graph_list(output)) + + errors = rec.epoch_forward() + + if not os.path.exists(output_dir): + os.makedirs(output_dir) + with open(os.path.join(output_dir, 'errors.txt'), 'w', encoding='utf-8') as f: + for key, val in errors.items(): + f.write(f'{key}: {val}\n') + + write_inference_csv(output_list, output_dir) diff --git a/mace-bench/3rdparty/SevenNet/sevenn/scripts/processing_continue.py b/mace-bench/3rdparty/SevenNet/sevenn/scripts/processing_continue.py index 8ad9db1..2537684 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/scripts/processing_continue.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/scripts/processing_continue.py @@ -1,273 +1,273 @@ -import os -import warnings - -import torch - -import sevenn._keys as KEY -import sevenn.util as util -from sevenn.logger import Logger -from sevenn.scripts.convert_model_modality import ( - append_modality_to_model_dct, - get_single_modal_model_dct, -) - - -def processing_continue_v2(config): # simpler - """ - Replacement of processing_continue, - Skips model compatibility - """ - log = Logger() - continue_dct = config[KEY.CONTINUE] - log.write('\nContinue found, loading checkpoint\n') - - checkpoint = util.load_checkpoint(continue_dct[KEY.CHECKPOINT]) - model_cp = checkpoint.build_model() - config_cp = checkpoint.config - model_state_dict_cp = model_cp.state_dict() - - optimizer_state_dict_cp = ( - checkpoint.optimizer_state_dict - if not continue_dct[KEY.RESET_OPTIMIZER] - else None - ) - scheduler_state_dict_cp = ( - checkpoint.scheduler_state_dict - if not continue_dct[KEY.RESET_SCHEDULER] - else None - ) - - # use_statistic_value_of_checkpoint always True - # Overwrite config from model state dict, so graph_dataset.from_config - # will not put statistic values to shift, scale, and conv_denominator - config[KEY.SHIFT] = model_state_dict_cp['rescale_atomic_energy.shift'].tolist() - config[KEY.SCALE] = model_state_dict_cp['rescale_atomic_energy.scale'].tolist() - conv_denom = [] - for i in range(config_cp[KEY.NUM_CONVOLUTION]): - conv_denom.append(model_state_dict_cp[f'{i}_convolution.denominator'].item()) - config[KEY.CONV_DENOMINATOR] = conv_denom - log.writeline( - f'{KEY.SHIFT}, {KEY.SCALE}, and {KEY.CONV_DENOMINATOR} are ' - + 'overwritten by model_state_dict of checkpoint' - ) - - chem_keys = [ - KEY.TYPE_MAP, - KEY.NUM_SPECIES, - KEY.CHEMICAL_SPECIES, - KEY.CHEMICAL_SPECIES_BY_ATOMIC_NUMBER, - ] - config.update({k: config_cp[k] for k in chem_keys}) - log.writeline( - 'chemical_species are overwritten by checkpoint. ' - + f'This model knows {config[KEY.NUM_SPECIES]} species' - ) - - if config_cp.get(KEY.USE_MODALITY, False) != config.get(KEY.USE_MODALITY): - raise ValueError('use_modality is not same. Check sevenn_cp') - - modal_map = config_cp.get(KEY.MODAL_MAP, None) # dict | None - if modal_map and len(modal_map) > 0: - modalities = list(modal_map.keys()) - log.writeline(f'Multimodal model found: {modalities}') - log.writeline('use_modality: True') - config[KEY.USE_MODALITY] = True - - from_epoch = checkpoint.epoch or 0 - log.writeline(f'Checkpoint previous epoch was: {from_epoch}') - epoch = 1 if continue_dct[KEY.RESET_EPOCH] else from_epoch + 1 - log.writeline(f'epoch start from {epoch}') - - log.writeline('checkpoint loading successful') - - state_dicts = [ - model_state_dict_cp, - optimizer_state_dict_cp, - scheduler_state_dict_cp, - ] - return state_dicts, epoch - - -def check_config_compatible(config, config_cp): - # TODO: check more - SHOULD_BE_SAME = [ - KEY.NODE_FEATURE_MULTIPLICITY, - KEY.LMAX, - KEY.IS_PARITY, - KEY.RADIAL_BASIS, - KEY.CUTOFF_FUNCTION, - KEY.CUTOFF, - KEY.CONVOLUTION_WEIGHT_NN_HIDDEN_NEURONS, - KEY.NUM_CONVOLUTION, - KEY.USE_BIAS_IN_LINEAR, - KEY.SELF_CONNECTION_TYPE, - ] - for sbs in SHOULD_BE_SAME: - if config[sbs] == config_cp[sbs]: - continue - if sbs == KEY.SELF_CONNECTION_TYPE and config_cp[sbs] == 'MACE': - warnings.warn( - 'We do not support this version of checkpoints to continue ' - "Please use self_connection_type='linear' in input.yaml " - 'and train from scratch', - UserWarning, - ) - raise ValueError( - f'Value of {sbs} should be same. {config[sbs]} != {config_cp[sbs]}' - ) - - try: - cntdct = config[KEY.CONTINUE] - except KeyError: - return - - TRAINABLE_CONFIGS = [KEY.TRAIN_DENOMINTAOR, KEY.TRAIN_SHIFT_SCALE] - if ( - any((not cntdct[KEY.RESET_SCHEDULER], not cntdct[KEY.RESET_OPTIMIZER])) - and all(config[k] == config_cp[k] for k in TRAINABLE_CONFIGS) is False - ): - raise ValueError( - 'reset optimizer and scheduler if you want to change ' - + 'trainable configs' - ) - - # TODO add conition for changed optim/scheduler but not reset - - -def processing_continue(config): - log = Logger() - continue_dct = config[KEY.CONTINUE] - log.write('\nContinue found, loading checkpoint\n') - - checkpoint = torch.load( - continue_dct[KEY.CHECKPOINT], map_location='cpu', weights_only=False - ) - config_cp = checkpoint['config'] - - model_cp, config_cp = util.model_from_checkpoint(checkpoint) - model_state_dict_cp = model_cp.state_dict() - - # it will raise error if not compatible - check_config_compatible(config, config_cp) - log.write('Checkpoint config is compatible\n') - - # for backward compat. - config.update({KEY._NORMALIZE_SPH: config_cp[KEY._NORMALIZE_SPH]}) - - from_epoch = checkpoint['epoch'] - optimizer_state_dict_cp = ( - checkpoint['optimizer_state_dict'] - if not continue_dct[KEY.RESET_OPTIMIZER] - else None - ) - scheduler_state_dict_cp = ( - checkpoint['scheduler_state_dict'] - if not continue_dct[KEY.RESET_SCHEDULER] - else None - ) - - # These could be changed based on given continue_input.yaml - # ex) adapt to statistics of fine-tuning dataset - shift_cp = model_state_dict_cp['rescale_atomic_energy.shift'].numpy() - del model_state_dict_cp['rescale_atomic_energy.shift'] - scale_cp = model_state_dict_cp['rescale_atomic_energy.scale'].numpy() - del model_state_dict_cp['rescale_atomic_energy.scale'] - conv_denominators = [] - for i in range(config_cp[KEY.NUM_CONVOLUTION]): - conv_denominators.append( - (model_state_dict_cp[f'{i}_convolution.denominator']).item() - ) - del model_state_dict_cp[f'{i}_convolution.denominator'] - - # Further handled by processing_dataset.py - config.update({ - KEY.SHIFT + '_cp': shift_cp, - KEY.SCALE + '_cp': scale_cp, - KEY.CONV_DENOMINATOR + '_cp': conv_denominators, - }) - - chem_keys = [ - KEY.TYPE_MAP, - KEY.NUM_SPECIES, - KEY.CHEMICAL_SPECIES, - KEY.CHEMICAL_SPECIES_BY_ATOMIC_NUMBER, - ] - config.update({k: config_cp[k] for k in chem_keys}) - - if ( - KEY.USE_MODALITY in config_cp.keys() and config_cp[KEY.USE_MODALITY] - ): # checkpoint model is multimodal - config.update({ - KEY.MODAL_MAP + '_cp': config_cp[KEY.MODAL_MAP], - KEY.USE_MODALITY + '_cp': True, - KEY.NUM_MODALITIES + '_cp': len(config_cp[KEY.MODAL_MAP]), - }) - else: - config.update({ - KEY.MODAL_MAP + '_cp': {}, - KEY.USE_MODALITY + '_cp': False, - KEY.NUM_MODALITIES + '_cp': 0, - }) - - log.write(f'checkpoint previous epoch was: {from_epoch}\n') - - # decide start epoch - reset_epoch = continue_dct[KEY.RESET_EPOCH] - if reset_epoch: - start_epoch = 1 - log.write('epoch reset to 1\n') - else: - start_epoch = from_epoch + 1 - log.write(f'epoch start from {start_epoch}\n') - - # decide csv file to continue - init_csv = True - csv_fname = config_cp[KEY.CSV_LOG] - if os.path.isfile(csv_fname): - # I hope python compare dict well - if config_cp[KEY.ERROR_RECORD] == config[KEY.ERROR_RECORD]: - log.writeline('Same metric, csv file will be appended') - init_csv = False - else: - log.writeline(f'{csv_fname} file not found, new csv file will be created') - log.writeline('checkpoint loading was successful') - - state_dicts = [ - model_state_dict_cp, - optimizer_state_dict_cp, - scheduler_state_dict_cp, - ] - return state_dicts, start_epoch, init_csv - - -def convert_modality_of_checkpoint_state_dct(config, state_dicts): - # TODO: this requires updating model state dict after seeing dataset - model_state_dict_cp, optimizer_state_dict_cp, scheduler_state_dict_cp = ( - state_dicts - ) - - if config[KEY.USE_MODALITY]: # current model is multimodal - num_modalities_cp = len(config[KEY.MODAL_MAP + '_cp']) - append_modal_length = config[KEY.NUM_MODALITIES] - num_modalities_cp - - model_state_dict_cp = append_modality_to_model_dct( - model_state_dict_cp, config, num_modalities_cp, append_modal_length - ) - - else: # current model is single modal - if config[KEY.USE_MODALITY + '_cp']: # checkpoint model is multimodal - # change model state dict to single modal, default = "common" - model_state_dict_cp = get_single_modal_model_dct( - model_state_dict_cp, - config, - config[KEY.DEFAULT_MODAL], - from_processing_cp=True, - ) - - state_dicts = ( - model_state_dict_cp, - optimizer_state_dict_cp, - scheduler_state_dict_cp, - ) - - return state_dicts +import os +import warnings + +import torch + +import sevenn._keys as KEY +import sevenn.util as util +from sevenn.logger import Logger +from sevenn.scripts.convert_model_modality import ( + append_modality_to_model_dct, + get_single_modal_model_dct, +) + + +def processing_continue_v2(config): # simpler + """ + Replacement of processing_continue, + Skips model compatibility + """ + log = Logger() + continue_dct = config[KEY.CONTINUE] + log.write('\nContinue found, loading checkpoint\n') + + checkpoint = util.load_checkpoint(continue_dct[KEY.CHECKPOINT]) + model_cp = checkpoint.build_model() + config_cp = checkpoint.config + model_state_dict_cp = model_cp.state_dict() + + optimizer_state_dict_cp = ( + checkpoint.optimizer_state_dict + if not continue_dct[KEY.RESET_OPTIMIZER] + else None + ) + scheduler_state_dict_cp = ( + checkpoint.scheduler_state_dict + if not continue_dct[KEY.RESET_SCHEDULER] + else None + ) + + # use_statistic_value_of_checkpoint always True + # Overwrite config from model state dict, so graph_dataset.from_config + # will not put statistic values to shift, scale, and conv_denominator + config[KEY.SHIFT] = model_state_dict_cp['rescale_atomic_energy.shift'].tolist() + config[KEY.SCALE] = model_state_dict_cp['rescale_atomic_energy.scale'].tolist() + conv_denom = [] + for i in range(config_cp[KEY.NUM_CONVOLUTION]): + conv_denom.append(model_state_dict_cp[f'{i}_convolution.denominator'].item()) + config[KEY.CONV_DENOMINATOR] = conv_denom + log.writeline( + f'{KEY.SHIFT}, {KEY.SCALE}, and {KEY.CONV_DENOMINATOR} are ' + + 'overwritten by model_state_dict of checkpoint' + ) + + chem_keys = [ + KEY.TYPE_MAP, + KEY.NUM_SPECIES, + KEY.CHEMICAL_SPECIES, + KEY.CHEMICAL_SPECIES_BY_ATOMIC_NUMBER, + ] + config.update({k: config_cp[k] for k in chem_keys}) + log.writeline( + 'chemical_species are overwritten by checkpoint. ' + + f'This model knows {config[KEY.NUM_SPECIES]} species' + ) + + if config_cp.get(KEY.USE_MODALITY, False) != config.get(KEY.USE_MODALITY): + raise ValueError('use_modality is not same. Check sevenn_cp') + + modal_map = config_cp.get(KEY.MODAL_MAP, None) # dict | None + if modal_map and len(modal_map) > 0: + modalities = list(modal_map.keys()) + log.writeline(f'Multimodal model found: {modalities}') + log.writeline('use_modality: True') + config[KEY.USE_MODALITY] = True + + from_epoch = checkpoint.epoch or 0 + log.writeline(f'Checkpoint previous epoch was: {from_epoch}') + epoch = 1 if continue_dct[KEY.RESET_EPOCH] else from_epoch + 1 + log.writeline(f'epoch start from {epoch}') + + log.writeline('checkpoint loading successful') + + state_dicts = [ + model_state_dict_cp, + optimizer_state_dict_cp, + scheduler_state_dict_cp, + ] + return state_dicts, epoch + + +def check_config_compatible(config, config_cp): + # TODO: check more + SHOULD_BE_SAME = [ + KEY.NODE_FEATURE_MULTIPLICITY, + KEY.LMAX, + KEY.IS_PARITY, + KEY.RADIAL_BASIS, + KEY.CUTOFF_FUNCTION, + KEY.CUTOFF, + KEY.CONVOLUTION_WEIGHT_NN_HIDDEN_NEURONS, + KEY.NUM_CONVOLUTION, + KEY.USE_BIAS_IN_LINEAR, + KEY.SELF_CONNECTION_TYPE, + ] + for sbs in SHOULD_BE_SAME: + if config[sbs] == config_cp[sbs]: + continue + if sbs == KEY.SELF_CONNECTION_TYPE and config_cp[sbs] == 'MACE': + warnings.warn( + 'We do not support this version of checkpoints to continue ' + "Please use self_connection_type='linear' in input.yaml " + 'and train from scratch', + UserWarning, + ) + raise ValueError( + f'Value of {sbs} should be same. {config[sbs]} != {config_cp[sbs]}' + ) + + try: + cntdct = config[KEY.CONTINUE] + except KeyError: + return + + TRAINABLE_CONFIGS = [KEY.TRAIN_DENOMINTAOR, KEY.TRAIN_SHIFT_SCALE] + if ( + any((not cntdct[KEY.RESET_SCHEDULER], not cntdct[KEY.RESET_OPTIMIZER])) + and all(config[k] == config_cp[k] for k in TRAINABLE_CONFIGS) is False + ): + raise ValueError( + 'reset optimizer and scheduler if you want to change ' + + 'trainable configs' + ) + + # TODO add conition for changed optim/scheduler but not reset + + +def processing_continue(config): + log = Logger() + continue_dct = config[KEY.CONTINUE] + log.write('\nContinue found, loading checkpoint\n') + + checkpoint = torch.load( + continue_dct[KEY.CHECKPOINT], map_location='cpu', weights_only=False + ) + config_cp = checkpoint['config'] + + model_cp, config_cp = util.model_from_checkpoint(checkpoint) + model_state_dict_cp = model_cp.state_dict() + + # it will raise error if not compatible + check_config_compatible(config, config_cp) + log.write('Checkpoint config is compatible\n') + + # for backward compat. + config.update({KEY._NORMALIZE_SPH: config_cp[KEY._NORMALIZE_SPH]}) + + from_epoch = checkpoint['epoch'] + optimizer_state_dict_cp = ( + checkpoint['optimizer_state_dict'] + if not continue_dct[KEY.RESET_OPTIMIZER] + else None + ) + scheduler_state_dict_cp = ( + checkpoint['scheduler_state_dict'] + if not continue_dct[KEY.RESET_SCHEDULER] + else None + ) + + # These could be changed based on given continue_input.yaml + # ex) adapt to statistics of fine-tuning dataset + shift_cp = model_state_dict_cp['rescale_atomic_energy.shift'].numpy() + del model_state_dict_cp['rescale_atomic_energy.shift'] + scale_cp = model_state_dict_cp['rescale_atomic_energy.scale'].numpy() + del model_state_dict_cp['rescale_atomic_energy.scale'] + conv_denominators = [] + for i in range(config_cp[KEY.NUM_CONVOLUTION]): + conv_denominators.append( + (model_state_dict_cp[f'{i}_convolution.denominator']).item() + ) + del model_state_dict_cp[f'{i}_convolution.denominator'] + + # Further handled by processing_dataset.py + config.update({ + KEY.SHIFT + '_cp': shift_cp, + KEY.SCALE + '_cp': scale_cp, + KEY.CONV_DENOMINATOR + '_cp': conv_denominators, + }) + + chem_keys = [ + KEY.TYPE_MAP, + KEY.NUM_SPECIES, + KEY.CHEMICAL_SPECIES, + KEY.CHEMICAL_SPECIES_BY_ATOMIC_NUMBER, + ] + config.update({k: config_cp[k] for k in chem_keys}) + + if ( + KEY.USE_MODALITY in config_cp.keys() and config_cp[KEY.USE_MODALITY] + ): # checkpoint model is multimodal + config.update({ + KEY.MODAL_MAP + '_cp': config_cp[KEY.MODAL_MAP], + KEY.USE_MODALITY + '_cp': True, + KEY.NUM_MODALITIES + '_cp': len(config_cp[KEY.MODAL_MAP]), + }) + else: + config.update({ + KEY.MODAL_MAP + '_cp': {}, + KEY.USE_MODALITY + '_cp': False, + KEY.NUM_MODALITIES + '_cp': 0, + }) + + log.write(f'checkpoint previous epoch was: {from_epoch}\n') + + # decide start epoch + reset_epoch = continue_dct[KEY.RESET_EPOCH] + if reset_epoch: + start_epoch = 1 + log.write('epoch reset to 1\n') + else: + start_epoch = from_epoch + 1 + log.write(f'epoch start from {start_epoch}\n') + + # decide csv file to continue + init_csv = True + csv_fname = config_cp[KEY.CSV_LOG] + if os.path.isfile(csv_fname): + # I hope python compare dict well + if config_cp[KEY.ERROR_RECORD] == config[KEY.ERROR_RECORD]: + log.writeline('Same metric, csv file will be appended') + init_csv = False + else: + log.writeline(f'{csv_fname} file not found, new csv file will be created') + log.writeline('checkpoint loading was successful') + + state_dicts = [ + model_state_dict_cp, + optimizer_state_dict_cp, + scheduler_state_dict_cp, + ] + return state_dicts, start_epoch, init_csv + + +def convert_modality_of_checkpoint_state_dct(config, state_dicts): + # TODO: this requires updating model state dict after seeing dataset + model_state_dict_cp, optimizer_state_dict_cp, scheduler_state_dict_cp = ( + state_dicts + ) + + if config[KEY.USE_MODALITY]: # current model is multimodal + num_modalities_cp = len(config[KEY.MODAL_MAP + '_cp']) + append_modal_length = config[KEY.NUM_MODALITIES] - num_modalities_cp + + model_state_dict_cp = append_modality_to_model_dct( + model_state_dict_cp, config, num_modalities_cp, append_modal_length + ) + + else: # current model is single modal + if config[KEY.USE_MODALITY + '_cp']: # checkpoint model is multimodal + # change model state dict to single modal, default = "common" + model_state_dict_cp = get_single_modal_model_dct( + model_state_dict_cp, + config, + config[KEY.DEFAULT_MODAL], + from_processing_cp=True, + ) + + state_dicts = ( + model_state_dict_cp, + optimizer_state_dict_cp, + scheduler_state_dict_cp, + ) + + return state_dicts diff --git a/mace-bench/3rdparty/SevenNet/sevenn/scripts/processing_dataset.py b/mace-bench/3rdparty/SevenNet/sevenn/scripts/processing_dataset.py index fe1e1e4..64e7961 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/scripts/processing_dataset.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/scripts/processing_dataset.py @@ -1,481 +1,481 @@ -import os - -import torch -import torch.distributed as dist - -import sevenn._const as CONST -import sevenn._keys as KEY -from sevenn.logger import Logger -from sevenn.train.dataload import file_to_dataset, match_reader -from sevenn.train.dataset import AtomGraphDataset -from sevenn.util import chemical_species_preprocess, onehot_to_chem - - -def dataset_load(file: str, config): - """ - Wrapping of dataload.file_to_dataset to suppert - graph prebuilt sevenn_data - """ - log = Logger() - log.write(f'Loading {file}\n') - log.timer_start('loading dataset') - - if file.endswith('.sevenn_data'): - dataset = torch.load(file, map_location='cpu', weights_only=False) - else: - reader, _ = match_reader( - config[KEY.DATA_FORMAT], **config[KEY.DATA_FORMAT_ARGS] - ) - dataset = file_to_dataset( - file, - config[KEY.CUTOFF], - config[KEY.PREPROCESS_NUM_CORES], - reader=reader, - use_modality=config[KEY.USE_MODALITY], - use_weight=config[KEY.USE_WEIGHT], - ) - log.format_k_v('loaded dataset size is', dataset.len(), write=True) - log.timer_end('loading dataset', 'data set loading time') - return dataset - - -def calculate_shift_or_scale_from_key( - train_set: AtomGraphDataset, key_given, n_chem -): - _expand = True - use_species_wise_shift_scale = False - if key_given == 'per_atom_energy_mean': - shift_or_scale = train_set.get_per_atom_energy_mean() - elif key_given == 'elemwise_reference_energies': - shift_or_scale = train_set.get_species_ref_energy_by_linear_comb(n_chem) - _expand = False - use_species_wise_shift_scale = True - - elif key_given == 'force_rms': - shift_or_scale = train_set.get_force_rms() - elif key_given == 'per_atom_energy_std': - shift_or_scale = train_set.get_statistics(KEY.PER_ATOM_ENERGY)['Total'][ - 'std' - ] - elif key_given == 'elemwise_force_rms': - shift_or_scale = train_set.get_species_wise_force_rms(n_chem) - _expand = False - use_species_wise_shift_scale = True - - return shift_or_scale, _expand, use_species_wise_shift_scale - - -def handle_shift_scale(config, train_set: AtomGraphDataset, checkpoint_given): - """ - Priority (first comes later to overwrite): - 1. Float given in yaml - 2. Use statistic values of checkpoint == True - 3. Plain options (provided as string) - """ - log = Logger() - shift, scale, conv_denominator = None, None, None - type_map = config[KEY.TYPE_MAP] - n_chem = len(type_map) - chem_strs = onehot_to_chem(list(range(n_chem)), type_map) - - log.writeline('\nCalculating statistic values from dataset') - - shift_given = config[KEY.SHIFT] - scale_given = config[KEY.SCALE] - _expand_shift = True - _expand_scale = True - use_species_wise_shift = False - use_species_wise_scale = False - - use_modal_wise_shift = config[KEY.USE_MODAL_WISE_SHIFT] - use_modal_wise_scale = config[KEY.USE_MODAL_WISE_SCALE] - - if shift_given in CONST.IMPLEMENTED_SHIFT: - shift, _expand_shift, use_species_wise_shift = ( - calculate_shift_or_scale_from_key(train_set, shift_given, n_chem) - ) - - if scale_given in CONST.IMPLEMENTED_SCALE: - scale, _expand_scale, use_species_wise_scale = ( - calculate_shift_or_scale_from_key(train_set, scale_given, n_chem) - ) - - if use_modal_wise_shift or use_modal_wise_scale: - atomdata_dict_sort_by_modal = train_set.get_dict_sort_by_modality() - modal_map = config[KEY.MODAL_MAP] - n_modal = len(modal_map) - cutoff = config[KEY.CUTOFF] - - if use_modal_wise_shift: - shift = torch.zeros((n_modal, n_chem)) - - if use_modal_wise_scale: - scale = torch.zeros((n_modal, n_chem)) - - for modal_key, data_list in atomdata_dict_sort_by_modal.items(): - modal_set = AtomGraphDataset(data_list, cutoff, x_is_one_hot_idx=True) - - if use_modal_wise_shift: - if shift_given == 'elemwise_reference_energies': - modal_shift, _expand_shift, use_species_wise_shift = ( - calculate_shift_or_scale_from_key( - modal_set, shift_given, n_chem - ) - ) - shift[modal_map[modal_key]] = torch.tensor( - modal_shift - ) # this is np.array - elif shift_given in CONST.IMPLEMENTED_SHIFT: - raise NotImplementedError( - 'Currently, modal-wise shift implemented for' - 'species-dependent case only.' - ) - - if use_modal_wise_scale: - if scale_given == 'elemwise_force_rms': - modal_scale, _expand_scale, use_species_wise_scale = ( - calculate_shift_or_scale_from_key( - modal_set, scale_given, n_chem - ) - ) - scale[modal_map[modal_key]] = modal_scale - elif scale_given in CONST.IMPLEMENTED_SCALE: - raise NotImplementedError( - 'Currently, modal-wise scale implemented for' - 'species-dependent case only.' - ) - - avg_num_neigh = train_set.get_avg_num_neigh() - log.format_k_v('Average # of neighbors', f'{avg_num_neigh:.6f}', write=True) - - if config[KEY.CONV_DENOMINATOR] == 'avg_num_neigh': - conv_denominator = avg_num_neigh - elif config[KEY.CONV_DENOMINATOR] == 'sqrt_avg_num_neigh': - conv_denominator = avg_num_neigh ** (0.5) - - if ( - checkpoint_given - and config[KEY.CONTINUE][KEY.USE_STATISTIC_VALUES_OF_CHECKPOINT] - ): - log.writeline( - 'Overwrite shift, scale, conv_denominator from model checkpoint' - ) - # TODO: This needs refactoring - conv_denominator = config[KEY.CONV_DENOMINATOR + '_cp'] - if not (use_modal_wise_shift or use_modal_wise_scale): - # Values extracted from checkpoint in processing_continue.py - if len(list(shift)) > 1: - use_species_wise_shift = True - use_species_wise_scale = True - _expand_shift = _expand_scale = False - else: - shift = shift.item() - scale = scale.item() - else: - # Case of modal wise shift scale - shift_cp = config[KEY.SHIFT + '_cp'] - scale_cp = config[KEY.SCALE + '_cp'] - if not use_modal_wise_shift: - shift = shift_cp - if not use_modal_wise_scale: - scale = scale_cp - modal_map = config[KEY.MODAL_MAP] - modal_map_cp = config[KEY.MODAL_MAP + '_cp'] - - # Extracting shift, scale for modal in checkpoint model. - if config[KEY.USE_MODALITY + '_cp']: # cp model is multimodal - for modal_key_cp, modal_idx_cp in modal_map_cp.items(): - modal_idx = modal_map[modal_key_cp] - if use_modal_wise_shift: - shift[modal_idx] = torch.tensor(shift_cp[modal_idx_cp]) - if use_modal_wise_scale: - scale[modal_idx] = torch.tensor(scale_cp[modal_idx_cp]) - - else: # cp model is single modal - try: - modal_idx = modal_map[config[KEY.DEFAULT_MODAL]] - except: - raise KeyError( - f'{config[KEY.DEFAULT_MODAL]} should be one of' - f' {modal_map.keys()}' - ) - if use_modal_wise_shift: - shift[modal_idx] = torch.tensor(shift_cp) - if use_modal_wise_scale: - scale[modal_idx] = torch.tensor(scale_cp) - - if not config[KEY.CONTINUE][KEY.USE_STATISTIC_VALUES_FOR_CP_MODAL_ONLY]: - # Also overwrite values of new modal to reference value - # For multimodal, set reference modal with KEY.DEFAULT_MODAL - shift_ref = shift_cp - scale_ref = scale_cp - if config[KEY.USE_MODALITY + '_cp']: - try: - modal_idx_cp = modal_map_cp[config[KEY.DEFAULT_MODAL]] - except: - raise KeyError( - f'{config[KEY.DEFAULT_MODAL]} should be one of' - f' {modal_map_cp.keys()}' - ) - shift_ref = shift_cp[modal_idx_cp] - scale_ref = scale_cp[modal_idx_cp] - - for modal_key, modal_idx in modal_map.items(): - if modal_key not in modal_map_cp.keys(): - if use_modal_wise_shift: - shift[modal_idx] = shift_ref - if use_modal_wise_scale: - scale[modal_idx] = scale_ref - - # overwrite shift scale anyway if defined in yaml. - if type(shift_given) in [list, float]: - log.writeline('Overwrite shift to value(s) given in yaml') - _expand_shift = isinstance(shift_given, float) - shift = shift_given - if type(scale_given) in [list, float]: - log.writeline('Overwrite scale to value(s) given in yaml') - _expand_scale = isinstance(scale_given, float) - scale = scale_given - - if isinstance(config[KEY.CONV_DENOMINATOR], float): - log.writeline('Overwrite conv_denominator to value given in yaml') - conv_denominator = config[KEY.CONV_DENOMINATOR] - - if isinstance(conv_denominator, float): - conv_denominator = [conv_denominator] * config[KEY.NUM_CONVOLUTION] - - use_species_wise_shift_scale = use_species_wise_shift or use_species_wise_scale - if use_species_wise_shift_scale: - chem_strs = onehot_to_chem(list(range(n_chem)), type_map) - if _expand_shift: - if use_modal_wise_shift: - shift = torch.full((n_modal, n_chem), shift) - else: - shift = [shift] * n_chem - if _expand_scale: - if use_modal_wise_scale: - scale = torch.full((n_modal, n_chem), scale) - else: - scale = [scale] * n_chem - - Logger().write('Use element-wise shift, scale\n') - if use_modal_wise_shift or use_modal_wise_scale: - for modal_key, modal_idx in modal_map.items(): - Logger().writeline(f'For modal = {modal_key}') - print_shift = shift[modal_idx] if use_modal_wise_shift else shift - print_scale = scale[modal_idx] if use_modal_wise_scale else scale - for cstr, sh, sc in zip(chem_strs, print_shift, print_scale): - Logger().format_k_v(f'{cstr}', f'{sh:.6f}, {sc:.6f}', write=True) - else: - for cstr, sh, sc in zip(chem_strs, shift, scale): - Logger().format_k_v(f'{cstr}', f'{sh:.6f}, {sc:.6f}', write=True) - else: - log.write('Use global shift, scale\n') - log.format_k_v('shift, scale', f'{shift:.6f}, {scale:.6f}', write=True) - - assert isinstance(conv_denominator, list) and all( - isinstance(deno, float) for deno in conv_denominator - ) - log.format_k_v( - '(1st) conv_denominator is', f'{conv_denominator[0]:.6f}', write=True - ) - - config[KEY.USE_SPECIES_WISE_SHIFT_SCALE] = use_species_wise_shift_scale - return shift, scale, conv_denominator - - -# TODO: This is too long -def processing_dataset(config, working_dir): - log = Logger() - prefix = f'{os.path.abspath(working_dir)}/' - is_stress = config[KEY.IS_TRAIN_STRESS] - checkpoint_given = config[KEY.CONTINUE][KEY.CHECKPOINT] is not False - cutoff = config[KEY.CUTOFF] - - log.write('\nInitializing dataset...\n') - - dataset = AtomGraphDataset({}, cutoff) - load_dataset = config[KEY.LOAD_DATASET] - if type(load_dataset) is str: - load_dataset = [load_dataset] - for file in load_dataset: - dataset.augment(dataset_load(file, config)) - - dataset.group_by_key() # apply labels inside original datapoint - dataset.unify_dtypes() # unify dtypes of all data points - - # TODO: I think manual chemical species input is redundant - chem_in_db = dataset.get_species() - if config[KEY.CHEMICAL_SPECIES] == 'auto' and not checkpoint_given: - log.writeline('Auto detect chemical species from dataset') - config.update(chemical_species_preprocess(chem_in_db)) - elif config[KEY.CHEMICAL_SPECIES] == 'auto' and checkpoint_given: - pass # copied from checkpoint in processing_continue.py - elif config[KEY.CHEMICAL_SPECIES] != 'auto' and not checkpoint_given: - pass # processed in parse_input.py - else: # config[KEY.CHEMICAL_SPECIES] != "auto" and checkpoint_given - log.writeline('Ignore chemical species in yaml, use checkpoint') - # already processed in processing_continue.py - - # basic dataset compatibility check with previous model - if checkpoint_given: - chem_from_cp = config[KEY.CHEMICAL_SPECIES] - if not all(chem in chem_from_cp for chem in chem_in_db): - raise ValueError('Chemical species in checkpoint is not compatible') - - # check what modalities are used in dataset - if config[KEY.USE_MODALITY]: - modalities = dataset.get_modalities() - num_modalities = len(modalities) - if num_modalities < 2: - Logger().writeline('Only one modal is given, ignore modality') - config.uptate({KEY.USE_MODALITY: False}) - - else: - modal_map_cp = config[KEY.MODAL_MAP + '_cp'] if checkpoint_given else {} - modal_map = modal_map_cp.copy() - current_idx = len(modal_map_cp) - for modal_key in modalities: - if modal_key not in modal_map.keys(): - modal_map[modal_key] = current_idx - current_idx += 1 - - if config[KEY.IS_DDP]: - # Synchronize modal_map - torch.cuda.set_device(config[KEY.LOCAL_RANK]) - modal_map_bcast = [modal_map] - dist.broadcast_object_list(modal_map_bcast, src=0) - modal_map = modal_map_bcast[0] - - config.update( - { - KEY.NUM_MODALITIES: len(modal_map), - KEY.MODAL_MAP: modal_map, - KEY.MODAL_LIST: list(modal_map.keys()), - } - ) - - dataset.write_modal_attr( - modal_map, - config[KEY.USE_MODAL_WISE_SHIFT] or config[KEY.USE_MODAL_WISE_SCALE], - ) - - # --------------- save dataset regardless of train/valid--------------# - save_dataset = config[KEY.SAVE_DATASET] - save_by_label = config[KEY.SAVE_BY_LABEL] - if save_dataset: - if save_dataset.endswith('.sevenn_data') is False: - save_dataset += '.sevenn_data' - if (save_dataset.startswith('.') or save_dataset.startswith('/')) is False: - save_dataset = prefix + save_dataset # save_data set is plain file name - dataset.save(save_dataset) - log.format_k_v('Dataset saved to', save_dataset, write=True) - # log.write(f"Loaded full dataset saved to : {save_dataset}\n") - if save_by_label: - dataset.save(prefix, by_label=True) - log.format_k_v('Dataset saved by label', prefix, write=True) - # --------------------------------------------------------------------# - - # TODO: testset is not used - ignore_test = not config.get(KEY.USE_TESTSET, False) - if KEY.LOAD_VALIDSET in config and config[KEY.LOAD_VALIDSET]: - train_set = dataset - test_set = AtomGraphDataset([], config[KEY.CUTOFF]) - - log.write('Loading validset from load_validset\n') - valid_set = AtomGraphDataset({}, cutoff) - for file in config[KEY.LOAD_VALIDSET]: - valid_set.augment(dataset_load(file, config)) - valid_set.group_by_key() - valid_set.unify_dtypes() - - # condition: validset labels should be subset of trainset labels - valid_labels = valid_set.user_labels - train_labels = train_set.user_labels - if set(valid_labels).issubset(set(train_labels)) is False: - valid_set = AtomGraphDataset(valid_set.to_list(), cutoff) - valid_set.rewrite_labels_to_data() - train_set = AtomGraphDataset(train_set.to_list(), cutoff) - train_set.rewrite_labels_to_data() - Logger().write('WARNING! validset labels is not subset of trainset\n') - Logger().write('We overwrite all the train, valid labels to default.\n') - Logger().write('Please create validset by sevenn_graph_build with -l\n') - - Logger().write('the validset loaded, load_dataset is now train_set\n') - Logger().write('the ratio will be ignored\n') - - # condition: validset modalities should be subset of trainset modalities - if config[KEY.USE_MODALITY]: - config_modality = config[KEY.MODAL_LIST] - valid_modality = valid_set.get_modalities() - - if set(valid_modality).issubset(set(config_modality)) is False: - raise ValueError('validset modality is not subset of trainset') - - valid_set.write_modal_attr( - config[KEY.MODAL_MAP], - config[KEY.USE_MODAL_WISE_SHIFT] or config[KEY.USE_MODAL_WISE_SCALE], - ) - else: - train_set, valid_set, test_set = dataset.divide_dataset( - config[KEY.RATIO], ignore_test=ignore_test - ) - log.write(f'The dataset divided into train, valid by {KEY.RATIO}\n') - - log.format_k_v('\nloaded trainset size is', train_set.len(), write=True) - log.format_k_v('\nloaded validset size is', valid_set.len(), write=True) - - log.write('Dataset initialization was successful\n') - - log.write('\nNumber of atoms in the train_set:\n') - log.natoms_write(train_set.get_natoms(config[KEY.TYPE_MAP])) - - log.bar() - log.write('Per atom energy(eV/atom) distribution:\n') - log.statistic_write(train_set.get_statistics(KEY.PER_ATOM_ENERGY)) - log.bar() - log.write('Force(eV/Angstrom) distribution:\n') - log.statistic_write(train_set.get_statistics(KEY.FORCE)) - log.bar() - log.write('Stress(eV/Angstrom^3) distribution:\n') - try: - log.statistic_write(train_set.get_statistics(KEY.STRESS)) - except KeyError: - log.write('\n Stress is not included in the train_set\n') - if is_stress: - is_stress = False - log.write('Turn off stress training\n') - log.bar() - - # saved data must have atomic numbers as X not one hot idx - if config[KEY.SAVE_BY_TRAIN_VALID]: - train_set.save(prefix + 'train') - valid_set.save(prefix + 'valid') - log.format_k_v('Dataset saved by train, valid', prefix, write=True) - - # inconsistent .info dict give error when collate - _, _ = train_set.separate_info() - _, _ = valid_set.separate_info() - - if train_set.x_is_one_hot_idx is False: - train_set.x_to_one_hot_idx(config[KEY.TYPE_MAP]) - if valid_set.x_is_one_hot_idx is False: - valid_set.x_to_one_hot_idx(config[KEY.TYPE_MAP]) - - log.format_k_v('training_set size', train_set.len(), write=True) - log.format_k_v('validation_set size', valid_set.len(), write=True) - - shift, scale, conv_denominator = handle_shift_scale( - config, train_set, checkpoint_given - ) - config.update( - { - KEY.SHIFT: shift, - KEY.SCALE: scale, - KEY.CONV_DENOMINATOR: conv_denominator, - } - ) - - data_lists = (train_set.to_list(), valid_set.to_list(), test_set.to_list()) - - return data_lists +import os + +import torch +import torch.distributed as dist + +import sevenn._const as CONST +import sevenn._keys as KEY +from sevenn.logger import Logger +from sevenn.train.dataload import file_to_dataset, match_reader +from sevenn.train.dataset import AtomGraphDataset +from sevenn.util import chemical_species_preprocess, onehot_to_chem + + +def dataset_load(file: str, config): + """ + Wrapping of dataload.file_to_dataset to suppert + graph prebuilt sevenn_data + """ + log = Logger() + log.write(f'Loading {file}\n') + log.timer_start('loading dataset') + + if file.endswith('.sevenn_data'): + dataset = torch.load(file, map_location='cpu', weights_only=False) + else: + reader, _ = match_reader( + config[KEY.DATA_FORMAT], **config[KEY.DATA_FORMAT_ARGS] + ) + dataset = file_to_dataset( + file, + config[KEY.CUTOFF], + config[KEY.PREPROCESS_NUM_CORES], + reader=reader, + use_modality=config[KEY.USE_MODALITY], + use_weight=config[KEY.USE_WEIGHT], + ) + log.format_k_v('loaded dataset size is', dataset.len(), write=True) + log.timer_end('loading dataset', 'data set loading time') + return dataset + + +def calculate_shift_or_scale_from_key( + train_set: AtomGraphDataset, key_given, n_chem +): + _expand = True + use_species_wise_shift_scale = False + if key_given == 'per_atom_energy_mean': + shift_or_scale = train_set.get_per_atom_energy_mean() + elif key_given == 'elemwise_reference_energies': + shift_or_scale = train_set.get_species_ref_energy_by_linear_comb(n_chem) + _expand = False + use_species_wise_shift_scale = True + + elif key_given == 'force_rms': + shift_or_scale = train_set.get_force_rms() + elif key_given == 'per_atom_energy_std': + shift_or_scale = train_set.get_statistics(KEY.PER_ATOM_ENERGY)['Total'][ + 'std' + ] + elif key_given == 'elemwise_force_rms': + shift_or_scale = train_set.get_species_wise_force_rms(n_chem) + _expand = False + use_species_wise_shift_scale = True + + return shift_or_scale, _expand, use_species_wise_shift_scale + + +def handle_shift_scale(config, train_set: AtomGraphDataset, checkpoint_given): + """ + Priority (first comes later to overwrite): + 1. Float given in yaml + 2. Use statistic values of checkpoint == True + 3. Plain options (provided as string) + """ + log = Logger() + shift, scale, conv_denominator = None, None, None + type_map = config[KEY.TYPE_MAP] + n_chem = len(type_map) + chem_strs = onehot_to_chem(list(range(n_chem)), type_map) + + log.writeline('\nCalculating statistic values from dataset') + + shift_given = config[KEY.SHIFT] + scale_given = config[KEY.SCALE] + _expand_shift = True + _expand_scale = True + use_species_wise_shift = False + use_species_wise_scale = False + + use_modal_wise_shift = config[KEY.USE_MODAL_WISE_SHIFT] + use_modal_wise_scale = config[KEY.USE_MODAL_WISE_SCALE] + + if shift_given in CONST.IMPLEMENTED_SHIFT: + shift, _expand_shift, use_species_wise_shift = ( + calculate_shift_or_scale_from_key(train_set, shift_given, n_chem) + ) + + if scale_given in CONST.IMPLEMENTED_SCALE: + scale, _expand_scale, use_species_wise_scale = ( + calculate_shift_or_scale_from_key(train_set, scale_given, n_chem) + ) + + if use_modal_wise_shift or use_modal_wise_scale: + atomdata_dict_sort_by_modal = train_set.get_dict_sort_by_modality() + modal_map = config[KEY.MODAL_MAP] + n_modal = len(modal_map) + cutoff = config[KEY.CUTOFF] + + if use_modal_wise_shift: + shift = torch.zeros((n_modal, n_chem)) + + if use_modal_wise_scale: + scale = torch.zeros((n_modal, n_chem)) + + for modal_key, data_list in atomdata_dict_sort_by_modal.items(): + modal_set = AtomGraphDataset(data_list, cutoff, x_is_one_hot_idx=True) + + if use_modal_wise_shift: + if shift_given == 'elemwise_reference_energies': + modal_shift, _expand_shift, use_species_wise_shift = ( + calculate_shift_or_scale_from_key( + modal_set, shift_given, n_chem + ) + ) + shift[modal_map[modal_key]] = torch.tensor( + modal_shift + ) # this is np.array + elif shift_given in CONST.IMPLEMENTED_SHIFT: + raise NotImplementedError( + 'Currently, modal-wise shift implemented for' + 'species-dependent case only.' + ) + + if use_modal_wise_scale: + if scale_given == 'elemwise_force_rms': + modal_scale, _expand_scale, use_species_wise_scale = ( + calculate_shift_or_scale_from_key( + modal_set, scale_given, n_chem + ) + ) + scale[modal_map[modal_key]] = modal_scale + elif scale_given in CONST.IMPLEMENTED_SCALE: + raise NotImplementedError( + 'Currently, modal-wise scale implemented for' + 'species-dependent case only.' + ) + + avg_num_neigh = train_set.get_avg_num_neigh() + log.format_k_v('Average # of neighbors', f'{avg_num_neigh:.6f}', write=True) + + if config[KEY.CONV_DENOMINATOR] == 'avg_num_neigh': + conv_denominator = avg_num_neigh + elif config[KEY.CONV_DENOMINATOR] == 'sqrt_avg_num_neigh': + conv_denominator = avg_num_neigh ** (0.5) + + if ( + checkpoint_given + and config[KEY.CONTINUE][KEY.USE_STATISTIC_VALUES_OF_CHECKPOINT] + ): + log.writeline( + 'Overwrite shift, scale, conv_denominator from model checkpoint' + ) + # TODO: This needs refactoring + conv_denominator = config[KEY.CONV_DENOMINATOR + '_cp'] + if not (use_modal_wise_shift or use_modal_wise_scale): + # Values extracted from checkpoint in processing_continue.py + if len(list(shift)) > 1: + use_species_wise_shift = True + use_species_wise_scale = True + _expand_shift = _expand_scale = False + else: + shift = shift.item() + scale = scale.item() + else: + # Case of modal wise shift scale + shift_cp = config[KEY.SHIFT + '_cp'] + scale_cp = config[KEY.SCALE + '_cp'] + if not use_modal_wise_shift: + shift = shift_cp + if not use_modal_wise_scale: + scale = scale_cp + modal_map = config[KEY.MODAL_MAP] + modal_map_cp = config[KEY.MODAL_MAP + '_cp'] + + # Extracting shift, scale for modal in checkpoint model. + if config[KEY.USE_MODALITY + '_cp']: # cp model is multimodal + for modal_key_cp, modal_idx_cp in modal_map_cp.items(): + modal_idx = modal_map[modal_key_cp] + if use_modal_wise_shift: + shift[modal_idx] = torch.tensor(shift_cp[modal_idx_cp]) + if use_modal_wise_scale: + scale[modal_idx] = torch.tensor(scale_cp[modal_idx_cp]) + + else: # cp model is single modal + try: + modal_idx = modal_map[config[KEY.DEFAULT_MODAL]] + except: + raise KeyError( + f'{config[KEY.DEFAULT_MODAL]} should be one of' + f' {modal_map.keys()}' + ) + if use_modal_wise_shift: + shift[modal_idx] = torch.tensor(shift_cp) + if use_modal_wise_scale: + scale[modal_idx] = torch.tensor(scale_cp) + + if not config[KEY.CONTINUE][KEY.USE_STATISTIC_VALUES_FOR_CP_MODAL_ONLY]: + # Also overwrite values of new modal to reference value + # For multimodal, set reference modal with KEY.DEFAULT_MODAL + shift_ref = shift_cp + scale_ref = scale_cp + if config[KEY.USE_MODALITY + '_cp']: + try: + modal_idx_cp = modal_map_cp[config[KEY.DEFAULT_MODAL]] + except: + raise KeyError( + f'{config[KEY.DEFAULT_MODAL]} should be one of' + f' {modal_map_cp.keys()}' + ) + shift_ref = shift_cp[modal_idx_cp] + scale_ref = scale_cp[modal_idx_cp] + + for modal_key, modal_idx in modal_map.items(): + if modal_key not in modal_map_cp.keys(): + if use_modal_wise_shift: + shift[modal_idx] = shift_ref + if use_modal_wise_scale: + scale[modal_idx] = scale_ref + + # overwrite shift scale anyway if defined in yaml. + if type(shift_given) in [list, float]: + log.writeline('Overwrite shift to value(s) given in yaml') + _expand_shift = isinstance(shift_given, float) + shift = shift_given + if type(scale_given) in [list, float]: + log.writeline('Overwrite scale to value(s) given in yaml') + _expand_scale = isinstance(scale_given, float) + scale = scale_given + + if isinstance(config[KEY.CONV_DENOMINATOR], float): + log.writeline('Overwrite conv_denominator to value given in yaml') + conv_denominator = config[KEY.CONV_DENOMINATOR] + + if isinstance(conv_denominator, float): + conv_denominator = [conv_denominator] * config[KEY.NUM_CONVOLUTION] + + use_species_wise_shift_scale = use_species_wise_shift or use_species_wise_scale + if use_species_wise_shift_scale: + chem_strs = onehot_to_chem(list(range(n_chem)), type_map) + if _expand_shift: + if use_modal_wise_shift: + shift = torch.full((n_modal, n_chem), shift) + else: + shift = [shift] * n_chem + if _expand_scale: + if use_modal_wise_scale: + scale = torch.full((n_modal, n_chem), scale) + else: + scale = [scale] * n_chem + + Logger().write('Use element-wise shift, scale\n') + if use_modal_wise_shift or use_modal_wise_scale: + for modal_key, modal_idx in modal_map.items(): + Logger().writeline(f'For modal = {modal_key}') + print_shift = shift[modal_idx] if use_modal_wise_shift else shift + print_scale = scale[modal_idx] if use_modal_wise_scale else scale + for cstr, sh, sc in zip(chem_strs, print_shift, print_scale): + Logger().format_k_v(f'{cstr}', f'{sh:.6f}, {sc:.6f}', write=True) + else: + for cstr, sh, sc in zip(chem_strs, shift, scale): + Logger().format_k_v(f'{cstr}', f'{sh:.6f}, {sc:.6f}', write=True) + else: + log.write('Use global shift, scale\n') + log.format_k_v('shift, scale', f'{shift:.6f}, {scale:.6f}', write=True) + + assert isinstance(conv_denominator, list) and all( + isinstance(deno, float) for deno in conv_denominator + ) + log.format_k_v( + '(1st) conv_denominator is', f'{conv_denominator[0]:.6f}', write=True + ) + + config[KEY.USE_SPECIES_WISE_SHIFT_SCALE] = use_species_wise_shift_scale + return shift, scale, conv_denominator + + +# TODO: This is too long +def processing_dataset(config, working_dir): + log = Logger() + prefix = f'{os.path.abspath(working_dir)}/' + is_stress = config[KEY.IS_TRAIN_STRESS] + checkpoint_given = config[KEY.CONTINUE][KEY.CHECKPOINT] is not False + cutoff = config[KEY.CUTOFF] + + log.write('\nInitializing dataset...\n') + + dataset = AtomGraphDataset({}, cutoff) + load_dataset = config[KEY.LOAD_DATASET] + if type(load_dataset) is str: + load_dataset = [load_dataset] + for file in load_dataset: + dataset.augment(dataset_load(file, config)) + + dataset.group_by_key() # apply labels inside original datapoint + dataset.unify_dtypes() # unify dtypes of all data points + + # TODO: I think manual chemical species input is redundant + chem_in_db = dataset.get_species() + if config[KEY.CHEMICAL_SPECIES] == 'auto' and not checkpoint_given: + log.writeline('Auto detect chemical species from dataset') + config.update(chemical_species_preprocess(chem_in_db)) + elif config[KEY.CHEMICAL_SPECIES] == 'auto' and checkpoint_given: + pass # copied from checkpoint in processing_continue.py + elif config[KEY.CHEMICAL_SPECIES] != 'auto' and not checkpoint_given: + pass # processed in parse_input.py + else: # config[KEY.CHEMICAL_SPECIES] != "auto" and checkpoint_given + log.writeline('Ignore chemical species in yaml, use checkpoint') + # already processed in processing_continue.py + + # basic dataset compatibility check with previous model + if checkpoint_given: + chem_from_cp = config[KEY.CHEMICAL_SPECIES] + if not all(chem in chem_from_cp for chem in chem_in_db): + raise ValueError('Chemical species in checkpoint is not compatible') + + # check what modalities are used in dataset + if config[KEY.USE_MODALITY]: + modalities = dataset.get_modalities() + num_modalities = len(modalities) + if num_modalities < 2: + Logger().writeline('Only one modal is given, ignore modality') + config.uptate({KEY.USE_MODALITY: False}) + + else: + modal_map_cp = config[KEY.MODAL_MAP + '_cp'] if checkpoint_given else {} + modal_map = modal_map_cp.copy() + current_idx = len(modal_map_cp) + for modal_key in modalities: + if modal_key not in modal_map.keys(): + modal_map[modal_key] = current_idx + current_idx += 1 + + if config[KEY.IS_DDP]: + # Synchronize modal_map + torch.cuda.set_device(config[KEY.LOCAL_RANK]) + modal_map_bcast = [modal_map] + dist.broadcast_object_list(modal_map_bcast, src=0) + modal_map = modal_map_bcast[0] + + config.update( + { + KEY.NUM_MODALITIES: len(modal_map), + KEY.MODAL_MAP: modal_map, + KEY.MODAL_LIST: list(modal_map.keys()), + } + ) + + dataset.write_modal_attr( + modal_map, + config[KEY.USE_MODAL_WISE_SHIFT] or config[KEY.USE_MODAL_WISE_SCALE], + ) + + # --------------- save dataset regardless of train/valid--------------# + save_dataset = config[KEY.SAVE_DATASET] + save_by_label = config[KEY.SAVE_BY_LABEL] + if save_dataset: + if save_dataset.endswith('.sevenn_data') is False: + save_dataset += '.sevenn_data' + if (save_dataset.startswith('.') or save_dataset.startswith('/')) is False: + save_dataset = prefix + save_dataset # save_data set is plain file name + dataset.save(save_dataset) + log.format_k_v('Dataset saved to', save_dataset, write=True) + # log.write(f"Loaded full dataset saved to : {save_dataset}\n") + if save_by_label: + dataset.save(prefix, by_label=True) + log.format_k_v('Dataset saved by label', prefix, write=True) + # --------------------------------------------------------------------# + + # TODO: testset is not used + ignore_test = not config.get(KEY.USE_TESTSET, False) + if KEY.LOAD_VALIDSET in config and config[KEY.LOAD_VALIDSET]: + train_set = dataset + test_set = AtomGraphDataset([], config[KEY.CUTOFF]) + + log.write('Loading validset from load_validset\n') + valid_set = AtomGraphDataset({}, cutoff) + for file in config[KEY.LOAD_VALIDSET]: + valid_set.augment(dataset_load(file, config)) + valid_set.group_by_key() + valid_set.unify_dtypes() + + # condition: validset labels should be subset of trainset labels + valid_labels = valid_set.user_labels + train_labels = train_set.user_labels + if set(valid_labels).issubset(set(train_labels)) is False: + valid_set = AtomGraphDataset(valid_set.to_list(), cutoff) + valid_set.rewrite_labels_to_data() + train_set = AtomGraphDataset(train_set.to_list(), cutoff) + train_set.rewrite_labels_to_data() + Logger().write('WARNING! validset labels is not subset of trainset\n') + Logger().write('We overwrite all the train, valid labels to default.\n') + Logger().write('Please create validset by sevenn_graph_build with -l\n') + + Logger().write('the validset loaded, load_dataset is now train_set\n') + Logger().write('the ratio will be ignored\n') + + # condition: validset modalities should be subset of trainset modalities + if config[KEY.USE_MODALITY]: + config_modality = config[KEY.MODAL_LIST] + valid_modality = valid_set.get_modalities() + + if set(valid_modality).issubset(set(config_modality)) is False: + raise ValueError('validset modality is not subset of trainset') + + valid_set.write_modal_attr( + config[KEY.MODAL_MAP], + config[KEY.USE_MODAL_WISE_SHIFT] or config[KEY.USE_MODAL_WISE_SCALE], + ) + else: + train_set, valid_set, test_set = dataset.divide_dataset( + config[KEY.RATIO], ignore_test=ignore_test + ) + log.write(f'The dataset divided into train, valid by {KEY.RATIO}\n') + + log.format_k_v('\nloaded trainset size is', train_set.len(), write=True) + log.format_k_v('\nloaded validset size is', valid_set.len(), write=True) + + log.write('Dataset initialization was successful\n') + + log.write('\nNumber of atoms in the train_set:\n') + log.natoms_write(train_set.get_natoms(config[KEY.TYPE_MAP])) + + log.bar() + log.write('Per atom energy(eV/atom) distribution:\n') + log.statistic_write(train_set.get_statistics(KEY.PER_ATOM_ENERGY)) + log.bar() + log.write('Force(eV/Angstrom) distribution:\n') + log.statistic_write(train_set.get_statistics(KEY.FORCE)) + log.bar() + log.write('Stress(eV/Angstrom^3) distribution:\n') + try: + log.statistic_write(train_set.get_statistics(KEY.STRESS)) + except KeyError: + log.write('\n Stress is not included in the train_set\n') + if is_stress: + is_stress = False + log.write('Turn off stress training\n') + log.bar() + + # saved data must have atomic numbers as X not one hot idx + if config[KEY.SAVE_BY_TRAIN_VALID]: + train_set.save(prefix + 'train') + valid_set.save(prefix + 'valid') + log.format_k_v('Dataset saved by train, valid', prefix, write=True) + + # inconsistent .info dict give error when collate + _, _ = train_set.separate_info() + _, _ = valid_set.separate_info() + + if train_set.x_is_one_hot_idx is False: + train_set.x_to_one_hot_idx(config[KEY.TYPE_MAP]) + if valid_set.x_is_one_hot_idx is False: + valid_set.x_to_one_hot_idx(config[KEY.TYPE_MAP]) + + log.format_k_v('training_set size', train_set.len(), write=True) + log.format_k_v('validation_set size', valid_set.len(), write=True) + + shift, scale, conv_denominator = handle_shift_scale( + config, train_set, checkpoint_given + ) + config.update( + { + KEY.SHIFT: shift, + KEY.SCALE: scale, + KEY.CONV_DENOMINATOR: conv_denominator, + } + ) + + data_lists = (train_set.to_list(), valid_set.to_list(), test_set.to_list()) + + return data_lists diff --git a/mace-bench/3rdparty/SevenNet/sevenn/scripts/processing_epoch.py b/mace-bench/3rdparty/SevenNet/sevenn/scripts/processing_epoch.py index 3a88669..ec2d701 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/scripts/processing_epoch.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/scripts/processing_epoch.py @@ -1,182 +1,182 @@ -import os -from copy import deepcopy -from typing import Optional - -import torch -from torch.utils.data.distributed import DistributedSampler - -import sevenn._keys as KEY -from sevenn.error_recorder import ErrorRecorder -from sevenn.logger import Logger -from sevenn.train.trainer import Trainer - - -def processing_epoch_v2( - config: dict, - trainer: Trainer, - loaders: dict, # dict[str, Dataset] - start_epoch: int = 1, - train_loader_key: str = 'trainset', - error_recorder: Optional[ErrorRecorder] = None, - total_epoch: Optional[int] = None, - per_epoch: Optional[int] = None, - best_metric_loader_key: str = 'validset', - best_metric: Optional[str] = None, - write_csv: bool = True, - working_dir: Optional[str] = None, -): - from sevenn.util import unique_filepath - - log = Logger() - write_csv = write_csv and log.rank == 0 - working_dir = working_dir or os.getcwd() - prefix = f'{os.path.abspath(working_dir)}/' - - total_epoch = total_epoch or config[KEY.EPOCH] - per_epoch = per_epoch or config.get(KEY.PER_EPOCH, 10) - best_metric = best_metric or config.get(KEY.BEST_METRIC, 'TotalLoss') - recorder = error_recorder or ErrorRecorder.from_config( - config, trainer.loss_functions - ) - recorders = {k: deepcopy(recorder) for k in loaders} - - best_val = float('inf') - best_key = None - if best_metric_loader_key in recorders: - best_key = recorders[best_metric_loader_key].get_key_str(best_metric) - if best_key is None: - log.writeline( - f'Failed to get error recorder key: {best_metric} or ' - + f'{best_metric_loader_key} is missing. There will be no best ' - + 'checkpoint.' - ) - - csv_path = unique_filepath(f'{prefix}/lc.csv') - if write_csv: - head = ['epoch', 'lr'] - for k, rec in recorders.items(): - head.extend(list(rec.get_dct(prefix=k))) - with open(csv_path, 'w') as f: - f.write(','.join(head) + '\n') - - if start_epoch == 1: - path = f'{prefix}/checkpoint_0.pth' # save first epoch - trainer.write_checkpoint(path, config=config, epoch=0) - - for epoch in range(start_epoch, total_epoch + 1): # one indexing - log.timer_start('epoch') - lr = trainer.get_lr() - log.bar() - log.write(f'Epoch {epoch}/{total_epoch} lr: {lr:8f}\n') - log.bar() - - csv_dct = {'epoch': str(epoch), 'lr': f'{lr:8f}'} - errors = {} - for k, loader in loaders.items(): - is_train = k == train_loader_key - if ( - trainer.distributed - and isinstance(loader.sampler, DistributedSampler) - and is_train - and config.get('train_shuffle', True) - ): - loader.sampler.set_epoch(epoch) - - rec = recorders[k] - trainer.run_one_epoch(loader, is_train, rec) - csv_dct.update(rec.get_dct(prefix=k)) - errors[k] = rec.epoch_forward() - log.write_full_table(list(errors.values()), list(errors)) - trainer.scheduler_step(best_val) - - if write_csv: - with open(csv_path, 'a') as f: - f.write(','.join(list(csv_dct.values())) + '\n') - - if best_key and errors[best_metric_loader_key][best_key] < best_val: - path = f'{prefix}/checkpoint_best.pth' - trainer.write_checkpoint(path, config=config, epoch=epoch) - best_val = errors[best_metric_loader_key][best_key] - log.writeline('Best checkpoint written') - - if epoch % per_epoch == 0: - path = f'{prefix}/checkpoint_{epoch}.pth' - trainer.write_checkpoint(path, config=config, epoch=epoch) - - log.timer_end('epoch', message=f'Epoch {epoch} elapsed') - return trainer - - -def processing_epoch(trainer, config, loaders, start_epoch, init_csv, working_dir): - log = Logger() - prefix = f'{os.path.abspath(working_dir)}/' - train_loader, valid_loader = loaders - - is_distributed = config[KEY.IS_DDP] - rank = config[KEY.RANK] - total_epoch = config[KEY.EPOCH] - per_epoch = config[KEY.PER_EPOCH] - train_recorder = ErrorRecorder.from_config(config) - valid_recorder = ErrorRecorder.from_config(config) - best_metric = config[KEY.BEST_METRIC] - csv_fname = f'{prefix}{config[KEY.CSV_LOG]}' - current_best = float('inf') - - if init_csv: - csv_header = ['Epoch', 'Learning_rate'] - # Assume train valid have the same metrics - for metric in train_recorder.get_metric_dict().keys(): - csv_header.append(f'Train_{metric}') - csv_header.append(f'Valid_{metric}') - log.init_csv(csv_fname, csv_header) - - def write_checkpoint(epoch, is_best=False): - if is_distributed and rank != 0: - return - suffix = '_best' if is_best else f'_{epoch}' - checkpoint = trainer.get_checkpoint_dict() - checkpoint.update({'config': config, 'epoch': epoch}) - torch.save(checkpoint, f'{prefix}/checkpoint{suffix}.pth') - - fin_epoch = total_epoch + start_epoch - for epoch in range(start_epoch, fin_epoch): - lr = trainer.get_lr() - log.timer_start('epoch') - log.bar() - log.write(f'Epoch {epoch}/{fin_epoch - 1} lr: {lr:8f}\n') - log.bar() - - trainer.run_one_epoch( - train_loader, is_train=True, error_recorder=train_recorder - ) - train_err = train_recorder.epoch_forward() - - trainer.run_one_epoch(valid_loader, error_recorder=valid_recorder) - valid_err = valid_recorder.epoch_forward() - - csv_values = [epoch, lr] - for metric in train_err: - csv_values.append(train_err[metric]) - csv_values.append(valid_err[metric]) - log.append_csv(csv_fname, csv_values) - - log.write_full_table([train_err, valid_err], ['Train', 'Valid']) - - val = None - for metric in valid_err: - # loose string comparison, - # e.g. "Energy" in "TotalEnergy" or "Energy_Loss" - if best_metric in metric: - val = valid_err[metric] - break - assert val is not None, f'Metric {best_metric} not found in {valid_err}' - trainer.scheduler_step(val) - - log.timer_end('epoch', message=f'Epoch {epoch} elapsed') - - if val < current_best: - current_best = val - write_checkpoint(epoch, is_best=True) - log.writeline('Best checkpoint written') - if epoch % per_epoch == 0: - write_checkpoint(epoch) +import os +from copy import deepcopy +from typing import Optional + +import torch +from torch.utils.data.distributed import DistributedSampler + +import sevenn._keys as KEY +from sevenn.error_recorder import ErrorRecorder +from sevenn.logger import Logger +from sevenn.train.trainer import Trainer + + +def processing_epoch_v2( + config: dict, + trainer: Trainer, + loaders: dict, # dict[str, Dataset] + start_epoch: int = 1, + train_loader_key: str = 'trainset', + error_recorder: Optional[ErrorRecorder] = None, + total_epoch: Optional[int] = None, + per_epoch: Optional[int] = None, + best_metric_loader_key: str = 'validset', + best_metric: Optional[str] = None, + write_csv: bool = True, + working_dir: Optional[str] = None, +): + from sevenn.util import unique_filepath + + log = Logger() + write_csv = write_csv and log.rank == 0 + working_dir = working_dir or os.getcwd() + prefix = f'{os.path.abspath(working_dir)}/' + + total_epoch = total_epoch or config[KEY.EPOCH] + per_epoch = per_epoch or config.get(KEY.PER_EPOCH, 10) + best_metric = best_metric or config.get(KEY.BEST_METRIC, 'TotalLoss') + recorder = error_recorder or ErrorRecorder.from_config( + config, trainer.loss_functions + ) + recorders = {k: deepcopy(recorder) for k in loaders} + + best_val = float('inf') + best_key = None + if best_metric_loader_key in recorders: + best_key = recorders[best_metric_loader_key].get_key_str(best_metric) + if best_key is None: + log.writeline( + f'Failed to get error recorder key: {best_metric} or ' + + f'{best_metric_loader_key} is missing. There will be no best ' + + 'checkpoint.' + ) + + csv_path = unique_filepath(f'{prefix}/lc.csv') + if write_csv: + head = ['epoch', 'lr'] + for k, rec in recorders.items(): + head.extend(list(rec.get_dct(prefix=k))) + with open(csv_path, 'w') as f: + f.write(','.join(head) + '\n') + + if start_epoch == 1: + path = f'{prefix}/checkpoint_0.pth' # save first epoch + trainer.write_checkpoint(path, config=config, epoch=0) + + for epoch in range(start_epoch, total_epoch + 1): # one indexing + log.timer_start('epoch') + lr = trainer.get_lr() + log.bar() + log.write(f'Epoch {epoch}/{total_epoch} lr: {lr:8f}\n') + log.bar() + + csv_dct = {'epoch': str(epoch), 'lr': f'{lr:8f}'} + errors = {} + for k, loader in loaders.items(): + is_train = k == train_loader_key + if ( + trainer.distributed + and isinstance(loader.sampler, DistributedSampler) + and is_train + and config.get('train_shuffle', True) + ): + loader.sampler.set_epoch(epoch) + + rec = recorders[k] + trainer.run_one_epoch(loader, is_train, rec) + csv_dct.update(rec.get_dct(prefix=k)) + errors[k] = rec.epoch_forward() + log.write_full_table(list(errors.values()), list(errors)) + trainer.scheduler_step(best_val) + + if write_csv: + with open(csv_path, 'a') as f: + f.write(','.join(list(csv_dct.values())) + '\n') + + if best_key and errors[best_metric_loader_key][best_key] < best_val: + path = f'{prefix}/checkpoint_best.pth' + trainer.write_checkpoint(path, config=config, epoch=epoch) + best_val = errors[best_metric_loader_key][best_key] + log.writeline('Best checkpoint written') + + if epoch % per_epoch == 0: + path = f'{prefix}/checkpoint_{epoch}.pth' + trainer.write_checkpoint(path, config=config, epoch=epoch) + + log.timer_end('epoch', message=f'Epoch {epoch} elapsed') + return trainer + + +def processing_epoch(trainer, config, loaders, start_epoch, init_csv, working_dir): + log = Logger() + prefix = f'{os.path.abspath(working_dir)}/' + train_loader, valid_loader = loaders + + is_distributed = config[KEY.IS_DDP] + rank = config[KEY.RANK] + total_epoch = config[KEY.EPOCH] + per_epoch = config[KEY.PER_EPOCH] + train_recorder = ErrorRecorder.from_config(config) + valid_recorder = ErrorRecorder.from_config(config) + best_metric = config[KEY.BEST_METRIC] + csv_fname = f'{prefix}{config[KEY.CSV_LOG]}' + current_best = float('inf') + + if init_csv: + csv_header = ['Epoch', 'Learning_rate'] + # Assume train valid have the same metrics + for metric in train_recorder.get_metric_dict().keys(): + csv_header.append(f'Train_{metric}') + csv_header.append(f'Valid_{metric}') + log.init_csv(csv_fname, csv_header) + + def write_checkpoint(epoch, is_best=False): + if is_distributed and rank != 0: + return + suffix = '_best' if is_best else f'_{epoch}' + checkpoint = trainer.get_checkpoint_dict() + checkpoint.update({'config': config, 'epoch': epoch}) + torch.save(checkpoint, f'{prefix}/checkpoint{suffix}.pth') + + fin_epoch = total_epoch + start_epoch + for epoch in range(start_epoch, fin_epoch): + lr = trainer.get_lr() + log.timer_start('epoch') + log.bar() + log.write(f'Epoch {epoch}/{fin_epoch - 1} lr: {lr:8f}\n') + log.bar() + + trainer.run_one_epoch( + train_loader, is_train=True, error_recorder=train_recorder + ) + train_err = train_recorder.epoch_forward() + + trainer.run_one_epoch(valid_loader, error_recorder=valid_recorder) + valid_err = valid_recorder.epoch_forward() + + csv_values = [epoch, lr] + for metric in train_err: + csv_values.append(train_err[metric]) + csv_values.append(valid_err[metric]) + log.append_csv(csv_fname, csv_values) + + log.write_full_table([train_err, valid_err], ['Train', 'Valid']) + + val = None + for metric in valid_err: + # loose string comparison, + # e.g. "Energy" in "TotalEnergy" or "Energy_Loss" + if best_metric in metric: + val = valid_err[metric] + break + assert val is not None, f'Metric {best_metric} not found in {valid_err}' + trainer.scheduler_step(val) + + log.timer_end('epoch', message=f'Epoch {epoch} elapsed') + + if val < current_best: + current_best = val + write_checkpoint(epoch, is_best=True) + log.writeline('Best checkpoint written') + if epoch % per_epoch == 0: + write_checkpoint(epoch) diff --git a/mace-bench/3rdparty/SevenNet/sevenn/scripts/train.py b/mace-bench/3rdparty/SevenNet/sevenn/scripts/train.py index 888469d..b0dabab 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/scripts/train.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/scripts/train.py @@ -1,139 +1,139 @@ -from typing import List, Optional - -import torch.distributed as dist -from torch.utils.data.distributed import DistributedSampler -from torch_geometric.loader import DataLoader - -import sevenn._keys as KEY -from sevenn.logger import Logger -from sevenn.model_build import build_E3_equivariant_model -from sevenn.scripts.processing_continue import ( - convert_modality_of_checkpoint_state_dct, -) -from sevenn.train.trainer import Trainer - - -def loader_from_config(config, dataset, is_train=False): - batch_size = config[KEY.BATCH_SIZE] - shuffle = is_train and config[KEY.TRAIN_SHUFFLE] - sampler = None - loader_args = { - 'dataset': dataset, - 'batch_size': batch_size, - 'shuffle': shuffle - } - if KEY.NUM_WORKERS in config and config[KEY.NUM_WORKERS] > 0: - loader_args.update({'num_workers': config[KEY.NUM_WORKERS]}) - - if config[KEY.IS_DDP]: - dist.barrier() - sampler = DistributedSampler( - dataset, dist.get_world_size(), dist.get_rank(), shuffle=shuffle - ) - loader_args.update({'sampler': sampler}) - loader_args.pop('shuffle') # sampler is mutually exclusive with shuffle - return DataLoader(**loader_args) - - -def train_v2(config, working_dir: str): - """ - Main program flow, since v0.9.6 - """ - import sevenn.train.atoms_dataset as atoms_dataset - import sevenn.train.graph_dataset as graph_dataset - import sevenn.train.modal_dataset as modal_dataset - - from .processing_continue import processing_continue_v2 - from .processing_epoch import processing_epoch_v2 - - log = Logger() - log.timer_start('total') - - if KEY.LOAD_TRAINSET not in config and KEY.LOAD_DATASET in config: - log.writeline('***************************************************') - log.writeline('For train_v2, please use load_trainset_path instead') - log.writeline('I will assign load_trainset as load_dataset') - log.writeline('***************************************************') - config[KEY.LOAD_TRAINSET] = config.pop(KEY.LOAD_DATASET) - - # config updated - start_epoch = 1 - state_dicts: Optional[List[dict]] = None - if config[KEY.CONTINUE][KEY.CHECKPOINT]: - state_dicts, start_epoch = processing_continue_v2(config) - - if config.get(KEY.USE_MODALITY, False): - datasets = modal_dataset.from_config(config, working_dir) - elif config[KEY.DATASET_TYPE] == 'graph': - datasets = graph_dataset.from_config(config, working_dir) - elif config[KEY.DATASET_TYPE] == 'atoms': - datasets = atoms_dataset.from_config(config, working_dir) - else: - raise ValueError(f'Unknown dataset type: {config[KEY.DATASET_TYPE]}') - loaders = { - k: loader_from_config(config, v, is_train=(k == 'trainset')) - for k, v in datasets.items() - } - - log.write('\nModel building...\n') - model = build_E3_equivariant_model(config) - log.print_model_info(model, config) - - trainer = Trainer.from_config(model, config) - if state_dicts: - trainer.load_state_dicts(*state_dicts, strict=False) - - processing_epoch_v2( - config, trainer, loaders, start_epoch, working_dir=working_dir - ) - log.timer_end('total', message='Total wall time') - - -def train(config, working_dir: str): - """ - Main program flow, until v0.9.5 - """ - from .processing_continue import processing_continue - from .processing_dataset import processing_dataset - from .processing_epoch import processing_epoch - - log = Logger() - log.timer_start('total') - - # config updated - state_dicts: Optional[List[dict]] = None - if config[KEY.CONTINUE][KEY.CHECKPOINT]: - state_dicts, start_epoch, init_csv = processing_continue(config) - else: - start_epoch, init_csv = 1, True - - # config updated - train, valid, _ = processing_dataset(config, working_dir) - datasets = {'dataset': train, 'validset': valid} - loaders = { - k: loader_from_config(config, v, is_train=(k == 'dataset')) - for k, v in datasets.items() - } - loaders = list(loaders.values()) - - log.write('\nModel building...\n') - model = build_E3_equivariant_model(config) - - log.write('Model building was successful\n') - - trainer = Trainer.from_config(model, config) - if state_dicts: - state_dicts = convert_modality_of_checkpoint_state_dct( - config, state_dicts - ) - trainer.load_state_dicts(*state_dicts, strict=False) - - log.print_model_info(model, config) - - Logger().write('Trainer initialized, ready to training\n') - Logger().bar() - log.write('Trainer initialized, ready to training\n') - log.bar() - - processing_epoch(trainer, config, loaders, start_epoch, init_csv, working_dir) - log.timer_end('total', message='Total wall time') +from typing import List, Optional + +import torch.distributed as dist +from torch.utils.data.distributed import DistributedSampler +from torch_geometric.loader import DataLoader + +import sevenn._keys as KEY +from sevenn.logger import Logger +from sevenn.model_build import build_E3_equivariant_model +from sevenn.scripts.processing_continue import ( + convert_modality_of_checkpoint_state_dct, +) +from sevenn.train.trainer import Trainer + + +def loader_from_config(config, dataset, is_train=False): + batch_size = config[KEY.BATCH_SIZE] + shuffle = is_train and config[KEY.TRAIN_SHUFFLE] + sampler = None + loader_args = { + 'dataset': dataset, + 'batch_size': batch_size, + 'shuffle': shuffle + } + if KEY.NUM_WORKERS in config and config[KEY.NUM_WORKERS] > 0: + loader_args.update({'num_workers': config[KEY.NUM_WORKERS]}) + + if config[KEY.IS_DDP]: + dist.barrier() + sampler = DistributedSampler( + dataset, dist.get_world_size(), dist.get_rank(), shuffle=shuffle + ) + loader_args.update({'sampler': sampler}) + loader_args.pop('shuffle') # sampler is mutually exclusive with shuffle + return DataLoader(**loader_args) + + +def train_v2(config, working_dir: str): + """ + Main program flow, since v0.9.6 + """ + import sevenn.train.atoms_dataset as atoms_dataset + import sevenn.train.graph_dataset as graph_dataset + import sevenn.train.modal_dataset as modal_dataset + + from .processing_continue import processing_continue_v2 + from .processing_epoch import processing_epoch_v2 + + log = Logger() + log.timer_start('total') + + if KEY.LOAD_TRAINSET not in config and KEY.LOAD_DATASET in config: + log.writeline('***************************************************') + log.writeline('For train_v2, please use load_trainset_path instead') + log.writeline('I will assign load_trainset as load_dataset') + log.writeline('***************************************************') + config[KEY.LOAD_TRAINSET] = config.pop(KEY.LOAD_DATASET) + + # config updated + start_epoch = 1 + state_dicts: Optional[List[dict]] = None + if config[KEY.CONTINUE][KEY.CHECKPOINT]: + state_dicts, start_epoch = processing_continue_v2(config) + + if config.get(KEY.USE_MODALITY, False): + datasets = modal_dataset.from_config(config, working_dir) + elif config[KEY.DATASET_TYPE] == 'graph': + datasets = graph_dataset.from_config(config, working_dir) + elif config[KEY.DATASET_TYPE] == 'atoms': + datasets = atoms_dataset.from_config(config, working_dir) + else: + raise ValueError(f'Unknown dataset type: {config[KEY.DATASET_TYPE]}') + loaders = { + k: loader_from_config(config, v, is_train=(k == 'trainset')) + for k, v in datasets.items() + } + + log.write('\nModel building...\n') + model = build_E3_equivariant_model(config) + log.print_model_info(model, config) + + trainer = Trainer.from_config(model, config) + if state_dicts: + trainer.load_state_dicts(*state_dicts, strict=False) + + processing_epoch_v2( + config, trainer, loaders, start_epoch, working_dir=working_dir + ) + log.timer_end('total', message='Total wall time') + + +def train(config, working_dir: str): + """ + Main program flow, until v0.9.5 + """ + from .processing_continue import processing_continue + from .processing_dataset import processing_dataset + from .processing_epoch import processing_epoch + + log = Logger() + log.timer_start('total') + + # config updated + state_dicts: Optional[List[dict]] = None + if config[KEY.CONTINUE][KEY.CHECKPOINT]: + state_dicts, start_epoch, init_csv = processing_continue(config) + else: + start_epoch, init_csv = 1, True + + # config updated + train, valid, _ = processing_dataset(config, working_dir) + datasets = {'dataset': train, 'validset': valid} + loaders = { + k: loader_from_config(config, v, is_train=(k == 'dataset')) + for k, v in datasets.items() + } + loaders = list(loaders.values()) + + log.write('\nModel building...\n') + model = build_E3_equivariant_model(config) + + log.write('Model building was successful\n') + + trainer = Trainer.from_config(model, config) + if state_dicts: + state_dicts = convert_modality_of_checkpoint_state_dct( + config, state_dicts + ) + trainer.load_state_dicts(*state_dicts, strict=False) + + log.print_model_info(model, config) + + Logger().write('Trainer initialized, ready to training\n') + Logger().bar() + log.write('Trainer initialized, ready to training\n') + log.bar() + + processing_epoch(trainer, config, loaders, start_epoch, init_csv, working_dir) + log.timer_end('total', message='Total wall time') diff --git a/mace-bench/3rdparty/SevenNet/sevenn/sevenn_logger.py b/mace-bench/3rdparty/SevenNet/sevenn/sevenn_logger.py index e3ee490..7efea4a 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/sevenn_logger.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/sevenn_logger.py @@ -1,6 +1,6 @@ -import warnings - -from .logger import * # noqa: F403 - -warnings.warn('Please use sevenn.logger instead of sevenn.sevenn_logger', - DeprecationWarning, stacklevel=2) +import warnings + +from .logger import * # noqa: F403 + +warnings.warn('Please use sevenn.logger instead of sevenn.sevenn_logger', + DeprecationWarning, stacklevel=2) diff --git a/mace-bench/3rdparty/SevenNet/sevenn/sevennet_calculator.py b/mace-bench/3rdparty/SevenNet/sevenn/sevennet_calculator.py index f7e6725..ba3c787 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/sevennet_calculator.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/sevennet_calculator.py @@ -1,6 +1,6 @@ -import warnings - -from .calculator import * # noqa: F403 - -warnings.warn('Please use sevenn.calculator instead of sevenn.sevennet_calculator', - DeprecationWarning, stacklevel=2) +import warnings + +from .calculator import * # noqa: F403 + +warnings.warn('Please use sevenn.calculator instead of sevenn.sevennet_calculator', + DeprecationWarning, stacklevel=2) diff --git a/mace-bench/3rdparty/SevenNet/sevenn/train/__pycache__/__init__.cpython-310.pyc b/mace-bench/3rdparty/SevenNet/sevenn/train/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 06affaeb2937478a01a3398a013f17cb8fcbecf9..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 182 zcmd1j<>g`kg0)*$W`O9&AOaaM0yz#qT+9L_QW%06G#UL?G8BP?5yY=Z{fzwFRQ=N8 z)FS<~lp=k%%)G>$kksN5{esGpjQqTK=imVS+{ENm-K5mKmR`NOy1KgAe82Q8$zz*sQ=~l}duBEAgCmh1&B&Bjr0sRjczc_DtI4MN zrCZgKNK^+&Xv0nbO(xl2$;@P_-JNW-*jePYL6H0jf&|GXKL*$Y$fg1$yV$(+Bftbf zu)EUO%6HDKZZ>JzlLQDGb?e?+=bpNC?z!iA*%}&hHT?bPckXSCzoBV=M2*3}Q8Zq| z75tW=X+jfvT`Q`)UNq1)>Sn_#S~~Ta^;E+y+KqHE&26jhG&03Z!!5dvY%$x&6?5FL z*N2Kjj4Pkqhl?ZJ9&L;j#~CJ7pJ+@LCmU16sm641x-nCnY0MU98^?;r8pn&r8OE-k zXq+sbzVqQ#+l+7uIu%)jVFp<;JRCXvhh^$DXwShPdDa@ zbExOUP`%K2ruYo%dH>m}A%?}s`&My2X^o08w9fhGiVM3gW;|YBY&=(dPS^bT;=(&x zGh=D~^E-c4*XZ}I*3@rlVnR&5uNA*2ru?Dp7Z{Hvro{~Kd?_=VIhHvlzIO0T9QR)o zCvctgU#z;~lsNr${AZ3_z_WwV zIk707!{`fQg-1QF@K-JIMezdqUlw2TUCi)B@e*eEiukgB9<8s4pF`_ak@GJIUCfGO z;#k!d=Yg5N`>Xiw3w}ntoXnNz^$M<6L9Y?N6<)tje7suF<8-?$!&t5njsMRGPajdZtldRS36;)EH7?#n%5OjJoaWQs1crEsVP=kvf>A^ zy%tJ82;y9e!7BtFV2;ymS!)3JreA9Mwau-KmMqmV6+FrXp{zAGzgTMu|2Ep$4Oy=E z{_S?TDbTdS8=`@`bEOpo>pN{9^`WJ()p%W&+gq2)VfnreLZEl#P;cOev0dqet!lMm z5vh{DA>7lrF5(J)iXzg5KC5*t+zs4QxSP1!p|-6HD>Ovvj?qnr+P;a)ip;zET|iE2 zT_>`F7a~X46I!ICe*?VxlB);t2+Z(q>D zlt@QB_C-Av^6>u!Wrl#Oy#3tW_0l|(%`h2lO$ zpmyIx0b-^^2wLAyp%q)T?MrKgRMEZBDK|r`A{N?gZ&YHdLcEh{_xM20&|oJvn(f%B z37>#>;#9TXDu=PzDBq5q@v>gRNRM4DM-i0b?J$ zySUYA_=_EI{$f?g#g!Vi#=0Mbi|w6otJN%BTzhk|0j4~+;WsN=i!Vq)TW)7@&A;h4 zSN(7?ph|Nwl-M+j0z8Q?77Oj2c(SxfY+Bmr)as(tD2G9%*52{(*5DZwj!wV4F`-+j zIo;CLWf+Fu`{f_zoSdFB71kBb>uq7p)V^;y9X<{cZ%V0BYn8|tJVLgQ3B^QR=z@)^MWCYud2ZbwhpMjM9-itASjy581Ut5dKH?OY6&hn+#mrIu~{X$ZEd-)=ag5@h$VwcL&+Sf0ytjAfdzkT(}TW>7K z86YSTE}*FU&DgrUx^fjF`!y_|+RD6vNp33%{+O+th-2*IIDh!#ViTVfTO>x0GTTh2 z32Vl$`+^A)l){!GRu!)bNK8PQg72n`yx|xw_O_#s>9%R>gTK6P{jr<1bonj@tgpmF zklW3m>Pv{PYAep}l&Z4TC}C$+;$euX)~!2jT$c09Mo_nPgNg1 zRekgnBQ|YtaJo4erS~21e!4jlLY{!%9as_>o{b}vKzBqIZS`&_grEc|m1F5-1vfAc zH%i^rg9g`tbw7JolQ%;64oZUU4*`FM=d11Kd3Qoy?GK|j8;)RiW?^4scz5QyL*0CL zm|=6#5bw?rz>WfTlzUyw+6~7*pYiPp%yujq6E^7a+w>IIWH^Q0ISr_>aAtcp8bx`G z%j4LWGutPk5tJvRJjzqRliJf!J0005J<#HeNQ?Yk1Ff^z(=*#oM8hb*5D~T~!?EzG zXq5IYT2Iqxv>_vogKo8fzjef!dhQL^QGg3@p6iyN0fD9o|&G zZuaWYIB-7;c`&}FK{^by3M1tgOign={H=iBJg6m)<>S03{jejO-qN6%i{~C26W#_O zh1Y6&%DnYtS=2fKBU-*)@!O#XLE!mpxX)9)0>n+>!O(_n?(q%^yhbMoy$#=kNP$Jd zdch+p&v@{zR=DL$uYz|4cnUb^n-E5kFBTZ&8qYECo^G}mn&sxx-mO}=XSn#TMvs~|h*8D<}J)Jr}Sk(rmYy@smRtS&3=HKYlz?MW=L1xVsI2H2U!*re|Zh;^7 zH1i&~(N~-oJ3A%C-EkTuCH|%^#oqVy`(&UM40#MURrFmPIg7s7P6nscP+JY$ee!bE z+i7*Y!1qDHS}^am>ppEtSoO3uF}tAS?>_fdPX~F+9`@M^SiI8|^B4z%?69`r<(Fl= zbhR)KKZ?|g&07un8SAdhGOOBwEo@>F#VOu2Y#PX~&|Gt+0j7a-6cChIZtleH+r%%+ zQo{UNqEG!A6{L#fB`TJwAghb*1i3=B*Qt1dimy{~nTl^v@l7hOP=P3!CRb6w^WzR3aDqb7$7 z(<7glA9;+h{^!^TIdr%PyXAgj8S*w>8M|t`QCK9-U^kOtO(7Mf3@XIcy+0R~3@+#o zE%=KlDA)jR&i+vQQ2)^Q&_oal=EnO@q+NnCBX2Hv9^nfF3ZS}R62SB|LkmW2ii`wd zMFzVYCVWE^h6%hvGNGZ^oc%Q2n^`DQ^L>rR2pAN1491%gF7-wk)rvAOdb8y$EFs?9 z@RG41{JM7&HY;3050ZYsA7&W#j1$FnAZs}|#F8UXg+vt^zBH&38Np+;gc?0Zu`bB= z{?N5$CCGUgiS5tT82!oi!TW3C&B09qW^}+BCqP~*60;DuefXcDYG^A!ELU6qzjRDr zyS=V73hi0Z&A@q@~zzJ=n8B>4hz zWcgd{u8zon{vB=LD-chheGbhO;f5e>v)u_xMp7Y=dW#rkP;lBWgJ~7ekWs5)ZRfWz(TVrbBOX`sU;2m8W^p;!SasuhEXJ zH(@^^oB&s?QEv11Nm^B<2vU`LD?pH8yH#rzT)9IJc+GaHcCc_H?-7*qB@C2=El4R<5QaBsf09o@*nS6K zhX_Yfv>!?PJv1ov3!*_rvUJp|2KB&=ASW)V3-kLX7?wgkT1S&7B7|;s;RtP~goThX z2rldk8j>c%mKNGOFRG*ooS)QP3OVDMN}%uA^e)t?T=j>48wJ>&?Cd7Ui6@rBQ`Q~pBU8sMAGQk zRYqj#y1xV4gltf*x@uBXKNPJcaYq{bBCg;!QAE(p@4bnDmKIsi(3E!2JGz{TtZfQg zVY^aI7#Fl=I!aMWqSq!}56EZHHUSBZPBmdAkST`bc4RW-HwB`&j1|E!aE~2Y`xGog zfFa=?dXsZOyJ6r>2o~_0R2$Ny;FkX*wg{^=q=}Fg^0r#Hyk77dzDrpI-qPCg0vj-1 zYhxSXKz97d7f93%4~K<)K(@p+B~s0kA$>qO@&h>334d(MC*=x5sNF)Wvbo84*o7si z)&}RPgi@jGvU%(RC8h#F&}4bpNZ5S~ zj>iJ&*x)@F8W68w{sT9ZT#cQQa^(3d$eKBzEWaBx?PdASd!R8{Cxk?SEqF=BkG<5@ zRx|Of2vMJ#KNk-}eIA;D^(0z+)_fG$;}G$eeag>0LY)K4qfapvJ-9|*+d%waP(ZNv zO|G5dSPE|hNCDjtgNjS0Jh6FH=EIm=fnU3qa*g-K5?s6)fT@S5QbR zMgM*3{Q(t!or>3}xJ3nHI^u`rN?_?Ern|F9l9_afbBfY26G)Ibkjqq#Af+&3K#Bhf z9zW2}BTrF71iDNCa84i}iqDeJC5wo>+u$V>q{fj5)2a}*kyX>9)UWH`(IF5`L=2IK z2__&GsZA*^&Z_QFmeknwn56F@mj~t`gP)kK)> zY_au1GX`z$;wF#KhQ`MB^;_8Gfg;`!^0>fSB80x&AmWg@G zsmZI9N!-)FZFCVHrf0JnLe2;eSE0#tz=c_XAS{w)_I{3hPrNS=KQmj?yTkBD5q`#d z{~5-QbVoNegp5bg{s#!xjw_sDK0>IvI}r{e9F4&Co{s0g(VfJXO!AjZg`?3_Wa68q zcXJ5)PDc|+?MxuCGqF9nXLV;_wPj&NWp}=>tSHQg1LU_Tggg90T?~a&+td3q(d1pd zrvFgy&W5wm48La+b2!#L{xk6Yi~no9FC}8$msGGa;Wk6$G_? zrQou3w591SZvZfDGuq)K~&3g2*$zyQeqjlF`d1k12c#G3G7gg z%&;SN0S(ygl+vU-1P>h_fyXc;;NGGKahjxOCWO_dk}Q;^{{X2Y&@#RHK>yat1O2_- zGq+?5vpM|f$#m8BXIER2(8k%H6aZ`O7#1lGD+CmsN6xJ&xF(x`xre-bolf~2kv3WW zC}m!JK;k9`U`ni6oLtN4yU!e+Ln4|Eoaq4spG}#s_bIcl6LmreX2hZ>Ig5nQJ|q!w zW7AFg&FiK-5kfBYC>YWA~(Sc!?_+9}yFrbHOpD6vr=RNEm#=kXc8nLq+9(jXHmwiXWmVoKm{~9(9lh zSTtLew){KP{QJ}#KSCK=#W=;^Qq0NnElO~3kOrNU1ebrCMrFcATW!^%OCy33PJl?9 z{h;G?TrUpy5eGcJ=up(LByb)u#lkM;k&w(FTM_aChdEeS7BhsA`x>XGqg`l z<6|R*Q&x|1hzpHablwVK{6{(S(@#b)_85!4Jly}B=>oSppXGjR{Jwi`Oqaih>BVZ( zGS`!Q>HQJIk|j#pEch=dI37niWF6)UHW*ZZE;NMQO*rfnOqzru;FCYXd$~?CiBdK0 zx$#~&MB)M>Ik$?yh~l3h+@0!IdqRuTW}Bb~{(kbR{7bmSSta+_mX*AJ{vo;#Nh^{g z6iFG!6`Vl9VNO>CT&sEwL?EMnFxaWHjogoIR`2OxTbdcRfD-T&P~_0-|LTdV6WZ=+ zTFOu9wu3-%oZ<-Nu0>L}kXJ&WLzr%&h@EfxJ1TNzciQ+apYv^T_R`Y&Qt6%L%ddZZ z{XhmGDR~>tn~=mMN??&VQGzxj|268x9218p1;Oz>di-ywf1LYo(Ye~N1OtT*@+Op2 zoj4ps30OKK1OYLy_5Lm4P{GuJZ~a4}0c}l$P&uE4O=S|p+aEhA+E_W{u^jyrE^Hcj zWjXk0jxPTIL*lG*Q%WQp4y_x>SGw34`!*^9%n#e14z7~UKTci=tzV?GQ--5k6m1o- zu$k4+CQT10^D$87VW@(&8s`s7lh|&_+Gee(c)^myN4#%kOf~YNa6s`7ss0-%DDOu8 zR_rwVuuK#_Y6?IHg_W&@KXjlANSP=hG(zIS8u^4WgOc{VBtFLZY)a|L|3n@CiaJ~< z_foP*IjT2W0!jLCr{dso?~h{Ur28)73TE&oI_--c?Lajt7f)m}zXd)bs&~qI@2skVgv-qC{oR=PtCbuWxx_lQkn%(3D zoPg>MzXSKEi5v(0!YMHnjY9h6BkOzG57_9uA)?{n$NZi#%st(m9KiKY4sbyL{RCl|4NX3p`a3D|Bd?YQ}G&# zVusg)6jm{#@?NC7nAUgbxyvRYA7GTfPaRqKsUi_A$}&wWITIAKs=_u_+-vaX|CnEe zq#ce#Kw%Vz7)ehi#fI28Dg=jFkK!06Nl3)-hDZ*QsqKv9V0_syxyUK9QBt_dLF^9M zm4ToAkjyxoi@5xYfXVd39%_ ze>|1+5w|ACm|1mlvN@gb2yXzjbDpCqzDUI-6vc6 z3Jf7F_z4OJ@sAw%afaOrs((SxWfG^K&+_DqyqWi`F$2%YBBIQpV@+g@v&6U7L(U|^ zb%7=IF*dAHvjI!*Gt$PN8czOC{x@kK{?q^f diff --git a/mace-bench/3rdparty/SevenNet/sevenn/train/__pycache__/dataset.cpython-310.pyc b/mace-bench/3rdparty/SevenNet/sevenn/train/__pycache__/dataset.cpython-310.pyc deleted file mode 100644 index 55d4fb5c33d33ea6f7e3b2dbb20068e7847bc6a8..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 17074 zcmbt*Yit}>nq9rR`a!ZuNwh@T^0;h`EYhY(zdfVX$fFS{k7s1cE77xdd$m2yzD2UB zevnnwlGtu$cUE4EvkNb>k8H5B8-PN9AP%rdHbF4E10;(~uz4Z>f&f7!zy?Wv#J>_? zf*{!$k8;j;tE;=&^vnby(YLFs?!9&Edz|let2HrEF!1+B|LEhbd&`FLU-)D2&&H20 z;}iZm8f7T6Zj|)5Su*k4s#^`aWScx^*PTYDlxesnw~;Mn8@W=hkuT-tnprQD#$^0> zX&m29eWFn;73Er{KG~QmO(p$~HIA1~B*)W@M@o;F#_NW1RrUizW$#<1lghkl%;!GA zdW_Y1GtORVwVRP2%xB`0(EA=?@b`u95oAFe2 z%Wu@Gm3lecYpl2GA?C^4sHsiA%4rAx^7!#(d_v1I4By~Gmn>zLY~R_=lw7pgQVwmt zR6sjc8qZsKTisBOKj9bsNq=g?nl*H*GX62;`p5m^EFQ%_v5`|bmH!}9npOoh_JL7) z1T&o6Fx9v}xnZdZRm8PZI5(4?n^aRcH>-}RH<%<*CZ}l{hLFz9*4>4EhjWeE=2mo06LzCjuea_}o_L!+3BcsVD{rq| zzk1c1+o-oH5uQ=beRbu{H?Ca1R=)Yx%9S@(Zt9ij)vaEQv&pSvSFwB%z$1@oBCZf2-$B#0x<=XPI%v(Vt<1I&Oh(qW-L-1wzTLHP-ib0@ z^S&9mxb{TX*v@v%nu+tdrrnkA$i>;*LH>bppT7$N_*UhcJuTCkwoz%rw1yVDiv7P4OPP zE+mTL1r=3mhkT|)1qbwCP*8Yl{vBqMH@Dtu)gfU(h_#^YFL?JT=L3vf@ceq{d*@2$ zv_^53bN9}9bIlg?ZNslbc+UBrQjiAKN>hv9i$E8kT)5y>fQ7Ycx!G>4`$4#nKH1QH z$eTof?m&TC!D9c8!xM0cn|{*|Ds>Oa%{vl$R1L5i)%W@sRBP_E``qo#4)Np+T<1EG z1AVxD0$BqOrRNO%-8SY6y>rptj$dw6cFv{CUY#EcCLqSaB%0WU>cpwdGkwe(`3b}B)%w{+77@~`^QQpkhmQWR8b&80*- z7I*gIT)A9p)}nIx8Qdp4j>dh&HFIXcbj@+IXu8(4*_j=(L*VmbZ&cO9h~A%r1FE~I z`T`E1q{e~yz*shJ6@Ss%H}|csh2t#L9I8*Hzi8}^?Pj$4cK&xtf=_-J*%9wgRwm2Nr|BXdw@ef3u}gJ*qV* ztgwoZj!Hf7D{4;<7pdZHBEOm5_A3yZL`r1pz}E|bD$({q+HkoN1eNAy-(IipQO;@* z^}9{)Ze=f>ZFSz!(viW#G{RsCw2KS}Cp~#4C9%Z+a$!ZeA zRE}C@cKbT6h4kf*XLI1K_0W0PIXA-DqsPWMxhI6T3XDzv$Zc}H5AMc+@xa)(YT!NK zo!p)gP&a3deTSgGZQe11*-grZ*>x@^-?PSn)wLg(rm>rkU;?(y11ACZL^>*uL)21= zZ3eCOPS22PV+-nPKL?BS;1tQA`HVS<`$7J0^W;P8- z3MJ-AND5@bf!}b)2&TnFIMuXFqiOEuy7mFCbs;Lis08<;Tyqm+5y3dX@O zcwO7p19RVOLZu1u`(|X{hvIidalynqnR$=_yW9uDs@+pjHp+Fa`|#{&2*4Jr^B+^9 zyxVGvYEq9T-|jnD2!MOlmOsiHASv83sQ5oM!6M-@JZfZDo0M0#%ApxQV0D9~}G_tVzIHz@YH@9h6hH2b1AQlh`3dCw4s*w+tFunGOp2Lyg z8qK#hjT6h4*DjaeyuN(-+8b-{YFgP9HAJXV0sb6f$Oudolsi&;ihS07$KSxcLV8~= z)Q0ZQanLGfcAhwrM*VT|F*z>x_`Tx?@sxN3dUCl?wJ~w;FNcE^1?%_$^`c^E+o!4d z3@s*9>FjN+c`CuN$12oSFq{W$zAGEYj61<8YaX|=QUVHbgmEQ5? zwY3}bIc-)haF(A#Q!2veA0V#gf+tf%jh!<5Dad!6*QA4Eepr5oC`jG;>lho-6Ux~T zV%VM{{&NuG0{xl)1Apep5k?Jifpd&VbQbf8|ME>7)U-VJyw<1pt7jYeI5OBg!ZM+ONW!g>(bx;M+V6-v+-yceK!iY!^TYzXn09P+law^Z42p+ygpQ z44Z%rh~*i9P-6EMK}$vp0cAgUg$-3VA^$g+kyr@=__nm^u?uK~l}`Yo^a%B&Oyr=| z?oBDgFD0}-!p_f95sn=Bw{Q#kaBbdo|sq$82UTbF6@)YlA#;&ZWiJ2vkemeayT5{?eRr9a`U z^pGDlI&&kWfstRlSZ^UQ7G7HH51P*osGhiK;;@Zjn&1&Or_p?DYV{g6|1$88aXBQQ z9Jgl7DQnv5OdoaY*rtv>jebK&M1pX|!U^69@@^pUk_A6#(CMdx`t-PNVil|NmR=ME zbl%pBx`tz!`7mj0NW=GVF?7-7AUZ|!h{d7P5RWJv>5>?+a4GaMz>_!e^J6$6_!7DV z3~LVERxbP*&I>CFX6NJxJtWW?x+&wBG?sY@KMrBR1VqHzN5V>3kT#og1`W4>s?cnB zn^>gi_t2z#3Qj80My@jOx#hxWnhcbFYe>$h=%DWF_&U;96)w^#6*>EJ^!WiNpEf!R zBM--b4-z>bYw5t?I&Y%mNt9cF6bS|epW|&A4jd*Xm-T&oBzAPf-*@Ra%zV~;#jP3g z`AGueEM+bHMKok6{rWKbH3R6`_aJ8JLBQrs!+U_u`?i7fzoqOC#G|L@K)OF5-WuB( z(0*!n{6oxEb`=wfhRT2kZs)tH57x8D48}+k`giYF22)ZO(##XoL#NC={8cu@cPC2k z1g)wcB0In=(I8|ipWziUF33_hz22?Tdn2he;$AX^gV?_r>GXtTu7h;M)e+$#RGlEz zX!b;Sa{DHpSj1;`Sb$=OQVr#x1jS|tkgv(+w0hLoA08c_=%w3~!Ee=gs?uCWK@ZeoGhX^Uu6hfYkTAQ17zdSraYeUri-(#p~bJ3+4mrT(45mER_ zGpxlK#L%@?JS*vS$a7hKm)E1aN!lqMt8F$BdBM*EXb}N=e>7~g2Z9E_jRTceRs=g9 znJBh7D)T{R-`Rw;ZyFo+eg@v`!iNYq8B+){xRKM%kRx69zIC5yLGzv6&vmm1ICH`w zZ}(&~%RV0(^vmvB@S(GUCae?Zx*6gK{Ox=5*>lP4_Zu!52ryXgnp&EC95Y0N}(~hgZd*;j(=+b5m`iONHtj3Cfmez z{aqX|{=e-6Ij~Bc-UwEVAxi|-bOFk6hY_T7VUutcDx++;YfqrJBQCNPyo3}F0>Z8{ zV<3<*mECoJ4X(vL4XCROInFgtcWnjsP8&z{LNMc;N_uh>e|2pcnL)?{#%|a&0jrx> zF#i72Yn|B}!k%Q~@b5fy943Mo%Z^)& z`KNHTvoxaYdY+y$&_{iJ{z$X*B5&%U`4%MO))+ECFiTKWeyscUH>gZ-B@>?wxF!y* z)J|Au=y^JvVZezK^vYjkiRa!4ABN`{Vh&wr zDy`TvI-HRancMW@;Y~d-eiN$MV@O8wO2|6G(TPwA4C4&a;PLodCTOA_~$IHRKzk4I2d#Q+1V;K#6i>kJvSMAXF#kMc6-i$rh&V4?Uxdi?+s3y|>>6Tdm_t5X z(sv9~7{uVZr0ygwVRyc8nA6A=%5AVc07Yq7qJiBySVr1?gwtHb34G*PB{V27uKbWf zZGH%TKQFM@k)`pQl2ztmIi(?;i&BdOF|a+ym$lKq)L0s>bD1G><$ zg}GTHuQYqHK)M3#k85rY7`XSjD#6TSigxFzpT5W;;$Z?|pI{N4 zZC0Rw9$*+;2eb3Vl;*I}$jou4&B!w5=x430uH1P2-M;q@oAy`ul#^-jqk9|$r18wq zoZQl|iF_T%^MxJ*#+}amPky$6QSlT7mI~F&`45HQG`^|a{u}&&l@}pLK&m8lrv?01hDL)kh?pBu{m+EnUDjR>QEh? zp3R*)2&P>5cB8`f7=66kmdYsiBe?GPl~vhav%1+y>?edrn2QS9r`*p+?~PeyqM%<3zUPHot#+Nx zJ5vnj0eR_UOt><3kUCz_ws_TvnfNhk>}v0;TG z&if!3_WwkFoLh&7&dt4ILxL_ZXK)o8Tdd|zY-6);!2DrsWA;*}+QOQ9Ufs7iA>-%R ztOnQkZ?*=qg}7#8Bt3tjmu=omBXIiS*rF=CDbMA z`BS)QF`qM!Tga!=bALEv(}zE9({xQ+*2990Jp0i={8Mi+Bquak4C(~;Skgr}hTIzy zkx=hRcwQz7yY#=Hf(OWSJ}~JsN<8gGnFEAuEcv30n>onglgFn(2zUUXcz{q4ifFHX z@(_PP8PkK(!oL+9$Hf#9uU@}#WhKr+&$cQ+_6S!si^;w~TZwQth=P5Nvv*o|lQc8v zk0QIT?L>#2MRrpGjweWV-6v>-zeMs<9US=841yRCA=hEMKP+TVn@3c`hlUTKCzF@_ zkc8nMqWNs}97tLA5KtmyIggll02$#I93b7HfN5}GooyH7CB7maR+$Y8XL9)DRrVtr zDC%w(6p**;9E?3M5=_h>=C%&Tacx44sqv4r37hEjcc|Xb#cQF82daoQlZ-$jJSoCl z#x~u`CL&|(q>&DSnoX3r!{H;q7>bD1l#@{K!V66k;~QE zACCis>TTthWgFid1_G;XRcp)_51|;Vu1)Ozj*nr-H?~h|z=KHv>3;{~5K4+p2-tmx zEBFJRoeVJs$d~|885;w3+m*l%ewSUSCP>IisQN9QAgGlx6}YbV5c5)H;*1!ck_`bk zj8%^gV%0E_Z&QO9p%ydPOM>FkwB?|NlwiQO#f22dcXQkT^LR08oB z2JPpB4z7UUFd3l~2z!-Q#tS+N@as6jHVoO5uD2WHR+e91DZjD&PMjs9pd<6Hkelh6 zrjJe6v02usSdW7zDt9&|dXQ^}K}tN@p&q|VGBYK@aE0aUM@<)LfC9Gge1>m3G<*c8 z7@`PJDIDl}6?f#x-hy6b`-FDp2HZ8thyrG?hXBr>cIn3CdPdy2G_2|OQ7+ssrrWZe z|B%8+n)lSj7H$JZTvw~=oyzhi0<)TP+ILtW9#ygT7&~X+c=hpxec*yg%?5!rGorHl zQPQCZ9qHp?D9R{2>fa?@hfb#sOd(U@cfYSUM|Yd|R@_FfYW&ft0PEmr^_MavvP3$H z>`i>a-$cXsFDVKp#VI;oI16+7z?_8p1(T_xKrxbXjQ#C2m~t#;mHsx96p}MNA7OWU z6R8T!lGi;4u3q9*<@_OpTYA6*b1r=b;y$umy8W`;Vc+n;P|z9e8v?0+3%&3~7(y8H zI=&E-OD*S+1WOg^_b9WBosc4MoKI<3ktTTSXdp;MppV+vDF+X!cj2?kl~i5iTD2Jft|V?-~v|A)QX#~F#%dPOfm;WOyfvuCpOKf#IE20Z925G!NBciG4@$@H8#md#wL zOobnjM^m_v%Q}xD>thwMc}2Fojb>Dbq%n;1Np>F4L?G?`(OglP#9T3Rl7K6AIw~6} z%tF^@`MQg^_`VgMNGl(3tHs5pIYR`EvJ_lwY>AuacNeA$v2I4p*D0Psz6b@b8nhZ2 z^iy~jhUZ)qvAE43s|dK6%6oV^1p9a|ifMEKN#I8Ko%}-jO93 zP#}@(V+pl-HKSJWeTp^%Thw{5m(xP2;0Q%dxY`HZk$F}-OZ{vk%3j!N3xcvpM&F03#7amP=F6$2@<7KH`Mj9W2HhC`f20 z{`vQDDfnk>#L}GQ(O22L!A9^*8sYmmoY0Lu{TSiDBUwoB9DMO<{7cn-bl?5pm?6)Q zK*otB|h;^mcCi3ts=1Y#3Gr9-CpV2mc!9 z{wt|S9x)vI**K)ZQdUnsIShOR}<)2TbsYY7rdI-I#4dhg)(=8 zwy`Os9FLdFh#>39m2A1JTGet{hdbk|*OynyS642ty?tXPDb6Hm>|e%`;*2D4g5TzG zLEg10uU=nW(YpevVFB?H+uEZ**IKuDjxrHDG!MaJ{QV@G$Jv~dD}0wFc$U9kX7fch zZ8ntG;3^wt|AKWk+%T#kh?~=bci6ngrp4wA8&-ni{Oa3pmM^befAgl~<%3`0mA}U3 z_t^YxHh;wCAG4u$1pkzca5O7`^5P2&o}*V@3ckZDf6nHgvmu`bc$ccJ$N-C9;fH8) zV4*V&=h$}XDY_`h|M0kF{n)atPhH3Qk(DX@F!$v2*z}3qvD`6h%y~#x+Crf_SIiY> z^3ZnsWY(JOowX(~!;h?F2Foe@DEH!Nd*)OwkG(uzaUyrBI5z!^kM9xj^{(S9!$vSQ99Roe}bK zUcjtkJe?%R7n58tgw0RVuj5*G7KsojGs~A2wz1?23xffHQ>##4CsMk`xr List[Atoms]: - if isinstance(file, dict): - atoms_list = dataload.dict_reader(file) - elif 'structure_list' in file: - atoms_dct = dataload.structure_list_reader(file) - atoms_list = [] - for lst in atoms_dct.values(): - atoms_list.extend(lst) - else: - atoms_list = dataload.ase_reader(file, **kwargs) - return atoms_list - - def save(self, path): - # Save atoms list as extxyz - write(path, self._atoms_list, format='extxyz') - - def _graph_build(self, atoms): - return dataload.atoms_to_graph( - atoms, self.cutoff, transfer_info=False, y_from_calc=False - ) - - def __len__(self): - return len(self._atoms_list) - - def __getitem__(self, index): - atoms = self._atoms_list[index] - if self.atoms_transform is not None: - atoms = self.atoms_transform(atoms) - - graph = self._graph_build(atoms) - if self.transform is not None: - graph = self.transform(graph) - - if self.use_data_weight: - weight = graph[KEY.INFO].pop( - KEY.DATA_WEIGHT, {'energy': 1.0, 'force': 1.0, 'stress': 1.0} - ) - graph[KEY.DATA_WEIGHT] = weight - - return AtomGraphData.from_numpy_dict(graph) - - @property - def species(self): - self.run_stat() - return [z for z in self.statistics['_natoms'].keys() if z != 'total'] - - @property - def natoms(self): - self.run_stat() - return self.statistics['_natoms'] - - @property - def per_atom_energy_mean(self): - self.run_stat() - return self.statistics[KEY.PER_ATOM_ENERGY]['mean'] - - @property - def elemwise_reference_energies(self): - from sklearn.linear_model import Ridge - - c = self.statistics['_composition'] - y = self.statistics[KEY.ENERGY]['_array'] - zero_indices = np.all(c == 0, axis=0) - c_reduced = c[:, ~zero_indices] - # will not 100% reproduce, as it is sorted by Z - # train/dataset.py was sorted by alphabets of chemical species - coef_reduced = Ridge(alpha=0.1, fit_intercept=False).fit(c_reduced, y).coef_ - full_coeff = np.zeros(NUM_UNIV_ELEMENT) - full_coeff[~zero_indices] = coef_reduced - return full_coeff.tolist() # ex: full_coeff[1] = H_reference_energy - - @property - def force_rms(self): - self.run_stat() - mean = self.statistics[KEY.FORCE]['mean'] - std = self.statistics[KEY.FORCE]['std'] - return float((mean**2 + std**2) ** (0.5)) - - @property - def per_atom_energy_std(self): - self.run_stat() - return self.statistics['per_atom_energy']['std'] - - @property - def avg_num_neigh(self, n_sample=10000): - if self._avg_num_neigh_approx is None: - if len(self) > n_sample: - warnings.warn(_warn_avg_num_neigh) - n_sample = min(len(self), n_sample) - indices = random.sample(range(len(self)), n_sample) - n_neigh = [] - for i in indices: - graph = self[i] - _, nn = np.unique(graph[KEY.EDGE_IDX][0], return_counts=True) - n_neigh.append(nn) - n_neigh = np.concatenate(n_neigh) - self._avg_num_neigh_approx = np.mean(n_neigh) - return self._avg_num_neigh_approx - - @property - def sqrt_avg_num_neigh(self): - self.run_stat() - return self.avg_num_neigh**0.5 - - def run_stat(self): - """ - Loop over dataset and init any statistics might need - Unlink SevenNetGraphDataset, neighbors count is not computed as - it requires to build graph - """ - if self._scanned is True: - return # statistics already computed - y_keys: List[str] = [KEY.ENERGY, KEY.PER_ATOM_ENERGY, KEY.FORCE, KEY.STRESS] - natoms_counter = Counter() - composition = np.zeros((len(self), NUM_UNIV_ELEMENT)) - stats: Dict[str, Dict[str, Any]] = {y: {'_array': []} for y in y_keys} - - for i, atoms in tqdm( - enumerate(self._atoms_list), desc='run_stat', total=len(self) - ): - z = atoms.get_atomic_numbers() - natoms_counter.update(z.tolist()) - composition[i] = np.bincount(z, minlength=NUM_UNIV_ELEMENT) - for y, dct in stats.items(): - if y == KEY.ENERGY: - dct['_array'].append(atoms.info['y_energy']) - elif y == KEY.PER_ATOM_ENERGY: - dct['_array'].append(atoms.info['y_energy'] / len(atoms)) - elif y == KEY.FORCE: - dct['_array'].append(atoms.arrays['y_force'].reshape(-1)) - elif y == KEY.STRESS: - dct['_array'].append(atoms.info['y_stress'].reshape(-1)) - - for y, dct in stats.items(): - if y == KEY.FORCE: - array = np.concatenate(dct['_array']) - else: - array = np.array(dct['_array']).reshape(-1) - dct.update( - { - 'mean': float(np.mean(array)), - 'std': float(np.std(array)), - 'median': float(np.quantile(array, q=0.5)), - 'max': float(np.max(array)), - 'min': float(np.min(array)), - '_array': array, - } - ) - - natoms = {chemical_symbols[int(z)]: cnt for z, cnt in natoms_counter.items()} - natoms['total'] = sum(list(natoms.values())) - self.statistics.update( - { - '_composition': composition, - '_natoms': natoms, - **stats, - } - ) - self._scanned = True - - -# script, return dict of SevenNetAtomsDataset -def from_config( - config: Dict[str, Any], - working_dir: str = os.getcwd(), - dataset_keys: Optional[List[str]] = None, -): - from sevenn.logger import Logger - - log = Logger() - if dataset_keys is None: - dataset_keys = [] - for k in config: - if k.startswith('load_') and k.endswith('_path'): - dataset_keys.append(k) - - if KEY.LOAD_TRAINSET not in dataset_keys: - raise ValueError(f'{KEY.LOAD_TRAINSET} must be present in config') - - # initialize arguments for loading dataset - dataset_args = { - 'cutoff': config[KEY.CUTOFF], - 'use_data_weight': config.get(KEY.USE_WEIGHT, False), - **config[KEY.DATA_FORMAT_ARGS], - } - - datasets = {} - for dk in dataset_keys: - if not (paths := config[dk]): - continue - if isinstance(paths, str): - paths = [paths] - name = '_'.join([nn.strip() for nn in dk.split('_')[1:-1]]) - dataset_args.update({'files': paths}) - datasets[name] = SevenNetAtomsDataset(**dataset_args) - - if not config[KEY.COMPUTE_STATISTICS]: - log.writeline( - ( - 'Computing statistics is skipped, note that if any of other' - 'configurations requires statistics (shift, scale, avg_num_neigh,' - 'chemical_species as auto), SevenNet eventually raise an error!' - ) - ) - return datasets - - train_set = datasets['trainset'] - - chem_species = set(train_set.species) - # print statistics of each dataset - for name, dataset in datasets.items(): - dataset.run_stat() - log.bar() - log.writeline(f'{name} distribution:') - log.statistic_write(dataset.statistics) - log.format_k_v('# atoms (node)', dataset.natoms, write=True) - log.format_k_v('# structures (graph)', len(dataset), write=True) - - chem_species.update(dataset.species) - log.bar() - - # initialize known species from dataset if 'auto' - # sorted to alphabetical order (which is same as before) - chem_keys = [KEY.CHEMICAL_SPECIES, KEY.NUM_SPECIES, KEY.TYPE_MAP] - if all([config[ck] == 'auto' for ck in chem_keys]): # see parse_input.py - log.writeline('Known species are obtained from the dataset') - config.update(util.chemical_species_preprocess(sorted(list(chem_species)))) - - # retrieve shift, scale, conv_denominaotrs from user input (keyword) - init_from_stats = [KEY.SHIFT, KEY.SCALE, KEY.CONV_DENOMINATOR] - for k in init_from_stats: - input = config[k] # statistic key or numbers - # If it is not 'str', 1: It is 'continue' training - # 2: User manually inserted numbers - if isinstance(input, str) and hasattr(train_set, input): - var = getattr(train_set, input) - config.update({k: var}) - log.writeline(f'{k} is obtained from statistics') - elif isinstance(input, str) and not hasattr(train_set, input): - raise NotImplementedError(input) - - return datasets +import os +import random +import warnings +from collections import Counter +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch.utils.data +from ase.atoms import Atoms +from ase.data import chemical_symbols +from ase.io import write +from tqdm import tqdm + +import sevenn._keys as KEY +import sevenn.train.dataload as dataload +import sevenn.util as util +from sevenn._const import NUM_UNIV_ELEMENT +from sevenn.atom_graph_data import AtomGraphData + +_warn_avg_num_neigh = """SevenNetAtomsDataset does not provide correct avg_num_neigh +as it does not build graph. We will compute only random 10000 structures graph to +approximate this value. If you want more precise avg_num_neigh, +use SevenNetGraphDataset. If it is not viable due to memory limit, you +need online algorithm to do this , which is not yet implemented in the SevenNet""" + + +class SevenNetAtomsDataset(torch.utils.data.Dataset): + """ + Args: + cutoff: edge cutoff of given AtomGraphData + files: list of filenames or dict describing how to parse the file + ASE readable (with proper extension), structure_list, .sevenn_data, + dict containing file_list (see dict_reader of train/dataload.py) + info_dict_copy_keys: patch these keys from KEY.INFO to graph when accessing. + default is KEY.DATA_WEIGHT and KEY.DATA_MODALITY, which may accessed + while training. + **process_kwargs: keyword arguments that will be passed into ase.io.read + """ + + def __init__( + self, + cutoff: float, + files: Union[str, List[str]], + atoms_filter: Optional[Callable] = None, + atoms_transform: Optional[Callable] = None, + transform: Optional[Callable] = None, + use_data_weight: bool = False, + **process_kwargs, + ): + self.cutoff = cutoff + if isinstance(files, str): + files = [files] # user convenience + files = [os.path.abspath(file) for file in files] + self._files = files + self.atoms_filter = atoms_filter + self.atoms_transform = atoms_transform + self.transform = transform + self.use_data_weight = use_data_weight + self._scanned = False + self._avg_num_neigh_approx = None + self.statistics = {} + + atoms_list = [] + for file in files: + atoms_list.extend( + SevenNetAtomsDataset.file_to_atoms_list(file, **process_kwargs) + ) + self._atoms_list = atoms_list + + super().__init__() + + @staticmethod + def file_to_atoms_list(file: Union[str, dict], **kwargs) -> List[Atoms]: + if isinstance(file, dict): + atoms_list = dataload.dict_reader(file) + elif 'structure_list' in file: + atoms_dct = dataload.structure_list_reader(file) + atoms_list = [] + for lst in atoms_dct.values(): + atoms_list.extend(lst) + else: + atoms_list = dataload.ase_reader(file, **kwargs) + return atoms_list + + def save(self, path): + # Save atoms list as extxyz + write(path, self._atoms_list, format='extxyz') + + def _graph_build(self, atoms): + return dataload.atoms_to_graph( + atoms, self.cutoff, transfer_info=False, y_from_calc=False + ) + + def __len__(self): + return len(self._atoms_list) + + def __getitem__(self, index): + atoms = self._atoms_list[index] + if self.atoms_transform is not None: + atoms = self.atoms_transform(atoms) + + graph = self._graph_build(atoms) + if self.transform is not None: + graph = self.transform(graph) + + if self.use_data_weight: + weight = graph[KEY.INFO].pop( + KEY.DATA_WEIGHT, {'energy': 1.0, 'force': 1.0, 'stress': 1.0} + ) + graph[KEY.DATA_WEIGHT] = weight + + return AtomGraphData.from_numpy_dict(graph) + + @property + def species(self): + self.run_stat() + return [z for z in self.statistics['_natoms'].keys() if z != 'total'] + + @property + def natoms(self): + self.run_stat() + return self.statistics['_natoms'] + + @property + def per_atom_energy_mean(self): + self.run_stat() + return self.statistics[KEY.PER_ATOM_ENERGY]['mean'] + + @property + def elemwise_reference_energies(self): + from sklearn.linear_model import Ridge + + c = self.statistics['_composition'] + y = self.statistics[KEY.ENERGY]['_array'] + zero_indices = np.all(c == 0, axis=0) + c_reduced = c[:, ~zero_indices] + # will not 100% reproduce, as it is sorted by Z + # train/dataset.py was sorted by alphabets of chemical species + coef_reduced = Ridge(alpha=0.1, fit_intercept=False).fit(c_reduced, y).coef_ + full_coeff = np.zeros(NUM_UNIV_ELEMENT) + full_coeff[~zero_indices] = coef_reduced + return full_coeff.tolist() # ex: full_coeff[1] = H_reference_energy + + @property + def force_rms(self): + self.run_stat() + mean = self.statistics[KEY.FORCE]['mean'] + std = self.statistics[KEY.FORCE]['std'] + return float((mean**2 + std**2) ** (0.5)) + + @property + def per_atom_energy_std(self): + self.run_stat() + return self.statistics['per_atom_energy']['std'] + + @property + def avg_num_neigh(self, n_sample=10000): + if self._avg_num_neigh_approx is None: + if len(self) > n_sample: + warnings.warn(_warn_avg_num_neigh) + n_sample = min(len(self), n_sample) + indices = random.sample(range(len(self)), n_sample) + n_neigh = [] + for i in indices: + graph = self[i] + _, nn = np.unique(graph[KEY.EDGE_IDX][0], return_counts=True) + n_neigh.append(nn) + n_neigh = np.concatenate(n_neigh) + self._avg_num_neigh_approx = np.mean(n_neigh) + return self._avg_num_neigh_approx + + @property + def sqrt_avg_num_neigh(self): + self.run_stat() + return self.avg_num_neigh**0.5 + + def run_stat(self): + """ + Loop over dataset and init any statistics might need + Unlink SevenNetGraphDataset, neighbors count is not computed as + it requires to build graph + """ + if self._scanned is True: + return # statistics already computed + y_keys: List[str] = [KEY.ENERGY, KEY.PER_ATOM_ENERGY, KEY.FORCE, KEY.STRESS] + natoms_counter = Counter() + composition = np.zeros((len(self), NUM_UNIV_ELEMENT)) + stats: Dict[str, Dict[str, Any]] = {y: {'_array': []} for y in y_keys} + + for i, atoms in tqdm( + enumerate(self._atoms_list), desc='run_stat', total=len(self) + ): + z = atoms.get_atomic_numbers() + natoms_counter.update(z.tolist()) + composition[i] = np.bincount(z, minlength=NUM_UNIV_ELEMENT) + for y, dct in stats.items(): + if y == KEY.ENERGY: + dct['_array'].append(atoms.info['y_energy']) + elif y == KEY.PER_ATOM_ENERGY: + dct['_array'].append(atoms.info['y_energy'] / len(atoms)) + elif y == KEY.FORCE: + dct['_array'].append(atoms.arrays['y_force'].reshape(-1)) + elif y == KEY.STRESS: + dct['_array'].append(atoms.info['y_stress'].reshape(-1)) + + for y, dct in stats.items(): + if y == KEY.FORCE: + array = np.concatenate(dct['_array']) + else: + array = np.array(dct['_array']).reshape(-1) + dct.update( + { + 'mean': float(np.mean(array)), + 'std': float(np.std(array)), + 'median': float(np.quantile(array, q=0.5)), + 'max': float(np.max(array)), + 'min': float(np.min(array)), + '_array': array, + } + ) + + natoms = {chemical_symbols[int(z)]: cnt for z, cnt in natoms_counter.items()} + natoms['total'] = sum(list(natoms.values())) + self.statistics.update( + { + '_composition': composition, + '_natoms': natoms, + **stats, + } + ) + self._scanned = True + + +# script, return dict of SevenNetAtomsDataset +def from_config( + config: Dict[str, Any], + working_dir: str = os.getcwd(), + dataset_keys: Optional[List[str]] = None, +): + from sevenn.logger import Logger + + log = Logger() + if dataset_keys is None: + dataset_keys = [] + for k in config: + if k.startswith('load_') and k.endswith('_path'): + dataset_keys.append(k) + + if KEY.LOAD_TRAINSET not in dataset_keys: + raise ValueError(f'{KEY.LOAD_TRAINSET} must be present in config') + + # initialize arguments for loading dataset + dataset_args = { + 'cutoff': config[KEY.CUTOFF], + 'use_data_weight': config.get(KEY.USE_WEIGHT, False), + **config[KEY.DATA_FORMAT_ARGS], + } + + datasets = {} + for dk in dataset_keys: + if not (paths := config[dk]): + continue + if isinstance(paths, str): + paths = [paths] + name = '_'.join([nn.strip() for nn in dk.split('_')[1:-1]]) + dataset_args.update({'files': paths}) + datasets[name] = SevenNetAtomsDataset(**dataset_args) + + if not config[KEY.COMPUTE_STATISTICS]: + log.writeline( + ( + 'Computing statistics is skipped, note that if any of other' + 'configurations requires statistics (shift, scale, avg_num_neigh,' + 'chemical_species as auto), SevenNet eventually raise an error!' + ) + ) + return datasets + + train_set = datasets['trainset'] + + chem_species = set(train_set.species) + # print statistics of each dataset + for name, dataset in datasets.items(): + dataset.run_stat() + log.bar() + log.writeline(f'{name} distribution:') + log.statistic_write(dataset.statistics) + log.format_k_v('# atoms (node)', dataset.natoms, write=True) + log.format_k_v('# structures (graph)', len(dataset), write=True) + + chem_species.update(dataset.species) + log.bar() + + # initialize known species from dataset if 'auto' + # sorted to alphabetical order (which is same as before) + chem_keys = [KEY.CHEMICAL_SPECIES, KEY.NUM_SPECIES, KEY.TYPE_MAP] + if all([config[ck] == 'auto' for ck in chem_keys]): # see parse_input.py + log.writeline('Known species are obtained from the dataset') + config.update(util.chemical_species_preprocess(sorted(list(chem_species)))) + + # retrieve shift, scale, conv_denominaotrs from user input (keyword) + init_from_stats = [KEY.SHIFT, KEY.SCALE, KEY.CONV_DENOMINATOR] + for k in init_from_stats: + input = config[k] # statistic key or numbers + # If it is not 'str', 1: It is 'continue' training + # 2: User manually inserted numbers + if isinstance(input, str) and hasattr(train_set, input): + var = getattr(train_set, input) + config.update({k: var}) + log.writeline(f'{k} is obtained from statistics') + elif isinstance(input, str) and not hasattr(train_set, input): + raise NotImplementedError(input) + + return datasets diff --git a/mace-bench/3rdparty/SevenNet/sevenn/train/collate.py b/mace-bench/3rdparty/SevenNet/sevenn/train/collate.py index e3c902a..3e1ede9 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/train/collate.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/train/collate.py @@ -1,41 +1,41 @@ -from typing import Any, List, Optional, Sequence - -from ase.atoms import Atoms -from torch_geometric.loader.dataloader import Collater - -from sevenn.atom_graph_data import AtomGraphData - -from .dataload import atoms_to_graph - - -class AtomsToGraphCollater(Collater): - - def __init__( - self, - dataset: Sequence[Atoms], - cutoff: float, - transfer_info: bool = False, - follow_batch: Optional[List[str]] = None, - exclude_keys: Optional[List[str]] = None, - y_from_calc: bool = True, - ): - # quite original collator's type mismatch with [] - super().__init__([], follow_batch, exclude_keys) - self.dataset = dataset - self.cutoff = cutoff - self.transfer_info = transfer_info - self.y_from_calc = y_from_calc - - def __call__(self, batch: List[Any]) -> Any: - # build list of graph - graph_list = [] - for stct in batch: - graph = atoms_to_graph( - stct, - self.cutoff, - transfer_info=self.transfer_info, - y_from_calc=self.y_from_calc, - ) - graph = AtomGraphData.from_numpy_dict(graph) - graph_list.append(graph) - return super().__call__(graph_list) +from typing import Any, List, Optional, Sequence + +from ase.atoms import Atoms +from torch_geometric.loader.dataloader import Collater + +from sevenn.atom_graph_data import AtomGraphData + +from .dataload import atoms_to_graph + + +class AtomsToGraphCollater(Collater): + + def __init__( + self, + dataset: Sequence[Atoms], + cutoff: float, + transfer_info: bool = False, + follow_batch: Optional[List[str]] = None, + exclude_keys: Optional[List[str]] = None, + y_from_calc: bool = True, + ): + # quite original collator's type mismatch with [] + super().__init__([], follow_batch, exclude_keys) + self.dataset = dataset + self.cutoff = cutoff + self.transfer_info = transfer_info + self.y_from_calc = y_from_calc + + def __call__(self, batch: List[Any]) -> Any: + # build list of graph + graph_list = [] + for stct in batch: + graph = atoms_to_graph( + stct, + self.cutoff, + transfer_info=self.transfer_info, + y_from_calc=self.y_from_calc, + ) + graph = AtomGraphData.from_numpy_dict(graph) + graph_list.append(graph) + return super().__call__(graph_list) diff --git a/mace-bench/3rdparty/SevenNet/sevenn/train/dataload.py b/mace-bench/3rdparty/SevenNet/sevenn/train/dataload.py index cffa221..48cdc8b 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/train/dataload.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/train/dataload.py @@ -1,609 +1,609 @@ -import copy -import os.path -from functools import partial -from itertools import chain, islice -from typing import Callable, Dict, List, Optional - -import ase -import ase.io -import numpy as np -import torch.multiprocessing as mp -from ase.io.vasp_parsers.vasp_outcar_parsers import ( - Cell, - DefaultParsersContainer, - Energy, - OutcarChunkParser, - PositionsAndForces, - Stress, - outcarchunks, -) -from ase.neighborlist import primitive_neighbor_list -from ase.utils import string2index -from braceexpand import braceexpand -from tqdm import tqdm - -import sevenn._keys as KEY -from sevenn._const import LossType -from sevenn.atom_graph_data import AtomGraphData - -from .dataset import AtomGraphDataset - - -def _graph_build_matscipy(cutoff: float, pbc, cell, pos): - pbc_x = pbc[0] - pbc_y = pbc[1] - pbc_z = pbc[2] - - identity = np.identity(3, dtype=float) - max_positions = np.max(np.absolute(pos)) + 1 - - # Extend cell in non-periodic directions - # For models with more than 5 layers, - # the multiplicative constant needs to be increased. - if not pbc_x: - cell[0, :] = max_positions * 5 * cutoff * identity[0, :] - if not pbc_y: - cell[1, :] = max_positions * 5 * cutoff * identity[1, :] - if not pbc_z: - cell[2, :] = max_positions * 5 * cutoff * identity[2, :] - # it does not have self-interaction - edge_src, edge_dst, edge_vec, shifts = neighbour_list( - quantities='ijDS', - pbc=pbc, - cell=cell, - positions=pos, - cutoff=cutoff, - ) - # dtype issue - edge_src = edge_src.astype(np.int64) - edge_dst = edge_dst.astype(np.int64) - - return edge_src, edge_dst, edge_vec, shifts - - -def _graph_build_ase(cutoff: float, pbc, cell, pos): - # building neighbor list - edge_src, edge_dst, edge_vec, shifts = primitive_neighbor_list( - 'ijDS', pbc, cell, pos, cutoff, self_interaction=True - ) - - is_zero_idx = np.all(edge_vec == 0, axis=1) - is_self_idx = edge_src == edge_dst - non_trivials = ~(is_zero_idx & is_self_idx) - shifts = np.array(shifts[non_trivials]) - - edge_vec = edge_vec[non_trivials] - edge_src = edge_src[non_trivials] - edge_dst = edge_dst[non_trivials] - - return edge_src, edge_dst, edge_vec, shifts - - -_graph_build_f = _graph_build_ase -try: - from matscipy.neighbours import neighbour_list - - _graph_build_f = _graph_build_matscipy -except ImportError: - pass - - -def _correct_scalar(v): - if isinstance(v, np.ndarray): - v = v.squeeze() - assert v.ndim == 0, f'given {v} is not a scalar' - return v - elif isinstance(v, (int, float, np.integer, np.floating)): - return np.array(v) - else: - assert False, f'{type(v)} is not expected' - - -def unlabeled_atoms_to_graph(atoms: ase.Atoms, cutoff: float): - pos = atoms.get_positions() - cell = np.array(atoms.get_cell()) - pbc = atoms.get_pbc() - - edge_src, edge_dst, edge_vec, shifts = _graph_build_f(cutoff, pbc, cell, pos) - - edge_idx = np.array([edge_src, edge_dst]) - - atomic_numbers = atoms.get_atomic_numbers() - - cell = np.array(cell) - vol = _correct_scalar(atoms.cell.volume) - if vol == 0: - vol = np.array(np.finfo(float).eps) - - data = { - KEY.NODE_FEATURE: atomic_numbers, - KEY.ATOMIC_NUMBERS: atomic_numbers, - KEY.POS: pos, - KEY.EDGE_IDX: edge_idx, - KEY.EDGE_VEC: edge_vec, - KEY.CELL: cell, - KEY.CELL_SHIFT: shifts, - KEY.CELL_VOLUME: vol, - KEY.NUM_ATOMS: _correct_scalar(len(atomic_numbers)), - } - data[KEY.INFO] = {} - return data - - -def atoms_to_graph( - atoms: ase.Atoms, - cutoff: float, - transfer_info: bool = True, - y_from_calc: bool = False, - allow_unlabeled: bool = False, -): - """ - From ase atoms, return AtomGraphData as graph based on cutoff radius - Except for energy, force and stress labels must be numpy array type - as other cases are not tested. - Returns 'np.nan' with consistent shape for unlabeled data - (ex. stress of non-pbc system) - - Args: - atoms (Atoms): ase atoms - cutoff (float): cutoff radius - transfer_info (bool): if True, transfer ".info" from atoms to graph, - defaults to True - y_from_calc: if True, get ref values from calculator, defaults to False - Returns: - numpy dict that can be used to initialize AtomGraphData - by AtomGraphData(**atoms_to_graph(atoms, cutoff)) - , for scalar, its shape is (), and types are np.ndarray - Requires grad is handled by 'dataset' not here. - """ - if not y_from_calc: - y_energy = atoms.info['y_energy'] - y_force = atoms.arrays['y_force'] - y_stress = atoms.info.get('y_stress', np.full((6,), np.nan)) - if y_stress.shape == (3, 3): - y_stress = np.array( - [ - y_stress[0][0], - y_stress[1][1], - y_stress[2][2], - y_stress[0][1], - y_stress[1][2], - y_stress[2][0], - ] - ) - else: - y_stress = y_stress.squeeze() - else: - from_calc = _y_from_calc(atoms) - y_energy = from_calc['energy'] - y_force = from_calc['force'] - y_stress = from_calc['stress'] - assert y_stress.shape == (6,), 'If you see this, please raise a issue' - - if not allow_unlabeled and (np.isnan(y_energy) or np.isnan(y_force).any()): - raise ValueError('Unlabeled E or F found, set allow_unlabeled True') - - pos = atoms.get_positions() - cell = np.array(atoms.get_cell()) - pbc = atoms.get_pbc() - - edge_src, edge_dst, edge_vec, shifts = _graph_build_f(cutoff, pbc, cell, pos) - - edge_idx = np.array([edge_src, edge_dst]) - atomic_numbers = atoms.get_atomic_numbers() - - cell = np.array(cell) - vol = _correct_scalar(atoms.cell.volume) - if vol == 0: - vol = np.array(np.finfo(float).eps) - - data = { - KEY.NODE_FEATURE: atomic_numbers, - KEY.ATOMIC_NUMBERS: atomic_numbers, - KEY.POS: pos, - KEY.EDGE_IDX: edge_idx, - KEY.EDGE_VEC: edge_vec, - KEY.ENERGY: _correct_scalar(y_energy), - KEY.FORCE: y_force, - KEY.STRESS: y_stress.reshape(1, 6), # to make batch have (n_node, 6) - KEY.CELL: cell, - KEY.CELL_SHIFT: shifts, - KEY.CELL_VOLUME: vol, - KEY.NUM_ATOMS: _correct_scalar(len(atomic_numbers)), - KEY.PER_ATOM_ENERGY: _correct_scalar(y_energy / len(pos)), - } - - if transfer_info and atoms.info is not None: - info = copy.deepcopy(atoms.info) - # save only metadata - info.pop('y_energy', None) - info.pop('y_force', None) - info.pop('y_stress', None) - data[KEY.INFO] = info - else: - data[KEY.INFO] = {} - - return data - - -def graph_build( - atoms_list: List, - cutoff: float, - num_cores: int = 1, - transfer_info: bool = True, - y_from_calc: bool = False, - allow_unlabeled: bool = False, -) -> List[AtomGraphData]: - """ - parallel version of graph_build - build graph from atoms_list and return list of AtomGraphData - Args: - atoms_list (List): list of ASE atoms - cutoff (float): cutoff radius of graph - num_cores (int): number of cores to use - transfer_info (bool): if True, copy info from atoms to graph, - defaults to True - y_from_calc (bool): Get reference y labels from calculator, defaults to False - Returns: - List[AtomGraphData]: list of AtomGraphData - """ - serial = num_cores == 1 - inputs = [ - (atoms, cutoff, transfer_info, y_from_calc, allow_unlabeled) - for atoms in atoms_list - ] - - if not serial: - pool = mp.Pool(num_cores) - graph_list = pool.starmap( - atoms_to_graph, - tqdm(inputs, total=len(atoms_list), desc=f'graph_build ({num_cores})'), - ) - pool.close() - pool.join() - else: - graph_list = [ - atoms_to_graph(*input_) - for input_ in tqdm(inputs, desc='graph_build (1)') - ] - - graph_list = [AtomGraphData.from_numpy_dict(g) for g in graph_list] - - return graph_list - - -def _y_from_calc(atoms: ase.Atoms): - ret = { - 'energy': np.nan, - 'force': np.full((len(atoms), 3), np.nan), - 'stress': np.full((6,), np.nan), - } - - if atoms.calc is None: - return ret - - try: - ret['energy'] = atoms.get_potential_energy(force_consistent=True) - except NotImplementedError: - ret['energy'] = atoms.get_potential_energy() - - try: - ret['force'] = atoms.get_forces(apply_constraint=False) - except NotImplementedError: - pass - - try: - y_stress = -1 * atoms.get_stress() # it ensures correct shape - ret['stress'] = np.array(y_stress[[0, 1, 2, 5, 3, 4]]) - except RuntimeError: - pass - return ret - - -def _set_atoms_y( - atoms_list: List[ase.Atoms], - energy_key: Optional[str] = None, - force_key: Optional[str] = None, - stress_key: Optional[str] = None, -) -> List[ase.Atoms]: - """ - Define how SevenNet reads ASE.atoms object for its y label - If energy_key, force_key, or stress_key is given, the corresponding - label is obtained from .info dict of Atoms object. These values should - have eV, eV/Angstrom, and eV/Angstrom^3 for energy, force, and stress, - respectively. (stress in Voigt notation) - - Args: - atoms_list (list[ase.Atoms]): target atoms to set y_labels - energy_key (str, optional): key to get energy. Defaults to None. - force_key (str, optional): key to get force. Defaults to None. - stress_key (str, optional): key to get stress. Defaults to None. - - Returns: - list[ase.Atoms]: list of ase.Atoms - - Raises: - RuntimeError: if ase atoms are somewhat imperfect - - Use free_energy: atoms.get_potential_energy(force_consistent=True) - If it is not available, use atoms.get_potential_energy() - If stress is available, initialize stress tensor - Ignore constraints like selective dynamics - """ - for atoms in atoms_list: - from_calc = _y_from_calc(atoms) - if energy_key is not None: - atoms.info['y_energy'] = atoms.info.pop(energy_key) - else: - atoms.info['y_energy'] = from_calc['energy'] - - if force_key is not None: - atoms.arrays['y_force'] = atoms.arrays.pop(force_key) - else: - atoms.arrays['y_force'] = from_calc['force'] - - if stress_key is not None: - y_stress = -1 * atoms.info.pop(stress_key) - atoms.info['y_stress'] = np.array(y_stress[[0, 1, 2, 5, 3, 4]]) - else: - atoms.info['y_stress'] = from_calc['stress'] - - return atoms_list - - -def ase_reader( - filename: str, - energy_key: Optional[str] = None, - force_key: Optional[str] = None, - stress_key: Optional[str] = None, - index: str = ':', - **kwargs, -) -> List[ase.Atoms]: - """ - Wrapper of ase.io.read - """ - atoms_list = ase.io.read(filename, index=index, **kwargs) - if not isinstance(atoms_list, list): - atoms_list = [atoms_list] - - return _set_atoms_y(atoms_list, energy_key, force_key, stress_key) - - -# Reader -def structure_list_reader(filename: str, format_outputs: Optional[str] = None): - """ - Read from structure_list using braceexpand and ASE - - Args: - fname : filename of structure_list - - Returns: - dictionary of lists of ASE structures. - key is title of training data (user-define) - """ - parsers = DefaultParsersContainer( - PositionsAndForces, Stress, Energy, Cell - ).make_parsers() - ocp = OutcarChunkParser(parsers=parsers) - - def parse_label(line): - line = line.strip() - if line.startswith('[') is False: - return False - elif line.endswith(']') is False: - raise ValueError('wrong structure_list title format') - return line[1:-1] - - def parse_fileline(line): - line = line.strip().split() - if len(line) == 1: - line.append(':') - elif len(line) != 2: - raise ValueError('wrong structure_list format') - return line[0], line[1] - - structure_list_file = open(filename, 'r') - lines = structure_list_file.readlines() - - raw_str_dict = {} - label = 'Default' - for line in lines: - if line.strip() == '': - continue - tmp_label = parse_label(line) - if tmp_label: - label = tmp_label - raw_str_dict[label] = [] - continue - elif label in raw_str_dict: - files_expr, index_expr = parse_fileline(line) - raw_str_dict[label].append((files_expr, index_expr)) - else: - raise ValueError('wrong structure_list format') - structure_list_file.close() - - structures_dict = {} - info_dct = {'data_from': 'user_OUTCAR'} - for title, file_lines in raw_str_dict.items(): - stct_lists = [] - for file_line in file_lines: - files_expr, index_expr = file_line - index = string2index(index_expr) - for expanded_filename in list(braceexpand(files_expr)): - f_stream = open(expanded_filename, 'r') - # generator of all outcar ionic steps - gen_all = outcarchunks(f_stream, ocp) - try: # TODO: index may not slice, it can be integer - it_atoms = islice(gen_all, index.start, index.stop, index.step) - except ValueError: - # TODO: support - # negative index - raise ValueError('Negative index is not supported yet') - - info_dct_f = { - **info_dct, - 'file': os.path.abspath(expanded_filename), - } - for idx, o in enumerate(it_atoms): - try: - it_atoms = islice( - gen_all, index.start, index.stop, index.step - ) - except ValueError: - # TODO: support - # negative index - raise ValueError('Negative index is not supported yet') - - info_dct_f = { - **info_dct, - 'file': os.path.abspath(expanded_filename), - } - for idx, o in enumerate(it_atoms): - try: - istep = index.start + idx * index.step # type: ignore - atoms = o.build() - atoms.info = {**info_dct_f, 'ionic_step': istep}.copy() - except TypeError: # it is not slice of ionic steps - atoms = o.build() - atoms.info = info_dct_f.copy() - stct_lists.append(atoms) - f_stream.close() - else: - stct_lists += ase.io.read( - expanded_filename, - index=index_expr, - parallel=False, - ) - structures_dict[title] = stct_lists - return {k: _set_atoms_y(v) for k, v in structures_dict.items()} - - -def dict_reader(data_dict: Dict): - data_dict_cp = copy.deepcopy(data_dict) - - ret = [] - file_list = data_dict_cp.pop('file_list', None) - if file_list is None: - raise KeyError('file_list is not found') - - data_weight_default = { - 'energy': 1.0, - 'force': 1.0, - 'stress': 1.0, - } - data_weight = data_weight_default.copy() - data_weight.update(data_dict_cp.pop(KEY.DATA_WEIGHT, {})) - - for file_dct in file_list: - ftype = file_dct.pop('data_format', 'ase') - files = list(braceexpand(file_dct.pop('file'))) - if ftype == 'ase': - ret.extend(chain(*[ase_reader(f, **file_dct) for f in files])) - elif ftype == 'graph': - continue - else: - raise ValueError(f'{ftype} yet') - - for atoms in ret: - atoms.info.update(data_dict_cp) - atoms.info.update({KEY.DATA_WEIGHT: data_weight}) - return _set_atoms_y(ret) - - -def match_reader(reader_name: str, **kwargs): - reader = None - metadata = {} - if reader_name == 'structure_list': - reader = partial(structure_list_reader, **kwargs) - metadata.update({'origin': 'structure_list'}) - else: - reader = partial(ase_reader, **kwargs) - metadata.update({'origin': 'ase_reader'}) - return reader, metadata - - -def file_to_dataset( - file: str, - cutoff: float, - cores: int = 1, - reader: Callable = ase_reader, - label: Optional[str] = None, - transfer_info: bool = True, - use_weight: bool = False, - use_modality: bool = False, -): - """ - Deprecated - Read file by reader > get list of atoms or dict of atoms - """ - - # expect label: atoms_list dct or atoms or list of atoms - atoms = reader(file) - - if type(atoms) is list: - if label is None: - label = KEY.LABEL_NONE - atoms_dct = {label: atoms} - elif isinstance(atoms, ase.Atoms): - if label is None: - label = KEY.LABEL_NONE - atoms_dct = {label: [atoms]} - elif isinstance(atoms, dict): - atoms_dct = atoms - else: - raise TypeError('The return of reader is not list or dict') - - graph_dct = {} - for label, atoms_list in atoms_dct.items(): - graph_list = graph_build( - atoms_list=atoms_list, - cutoff=cutoff, - num_cores=cores, - transfer_info=transfer_info, - y_from_calc=False, - ) - - label_info = label.split(':') - for graph in graph_list: - graph[KEY.USER_LABEL] = label_info[0].strip() - if use_weight: - find_weight = False - for info in label_info[1:]: - if 'w=' in info.lower(): - weights = info.split('=')[1] - try: - if ',' in weights: - weight_list = list(map(float, weights.split(','))) - else: - weight_list = [float(weights)] * 3 - weight_dict = {} - for idx, loss_type in enumerate(LossType): - weight_dict[loss_type.value] = ( - weight_list[idx] if idx < len(weight_list) else 1 - ) - graph[KEY.DATA_WEIGHT] = weight_dict - find_weight = True - break - except: - raise ValueError( - 'Weight must be a real number, but' - f' {weights} is given for {label}' - ) - if not find_weight: - weight_dict = {} - for loss_type in LossType: - weight_dict[loss_type.value] = 1 - graph[KEY.DATA_WEIGHT] = weight_dict - if use_modality: - find_modality = False - for info in label_info[1:]: - if 'm=' in info.lower(): - graph[KEY.DATA_MODALITY] = (info.split('=')[1]).strip() - find_modality = True - break - if not find_modality: - raise ValueError(f'Modality not given for {label}') - - graph_dct[label_info[0].strip()] = graph_list - db = AtomGraphDataset(graph_dct, cutoff) - return db +import copy +import os.path +from functools import partial +from itertools import chain, islice +from typing import Callable, Dict, List, Optional + +import ase +import ase.io +import numpy as np +import torch.multiprocessing as mp +from ase.io.vasp_parsers.vasp_outcar_parsers import ( + Cell, + DefaultParsersContainer, + Energy, + OutcarChunkParser, + PositionsAndForces, + Stress, + outcarchunks, +) +from ase.neighborlist import primitive_neighbor_list +from ase.utils import string2index +from braceexpand import braceexpand +from tqdm import tqdm + +import sevenn._keys as KEY +from sevenn._const import LossType +from sevenn.atom_graph_data import AtomGraphData + +from .dataset import AtomGraphDataset + + +def _graph_build_matscipy(cutoff: float, pbc, cell, pos): + pbc_x = pbc[0] + pbc_y = pbc[1] + pbc_z = pbc[2] + + identity = np.identity(3, dtype=float) + max_positions = np.max(np.absolute(pos)) + 1 + + # Extend cell in non-periodic directions + # For models with more than 5 layers, + # the multiplicative constant needs to be increased. + if not pbc_x: + cell[0, :] = max_positions * 5 * cutoff * identity[0, :] + if not pbc_y: + cell[1, :] = max_positions * 5 * cutoff * identity[1, :] + if not pbc_z: + cell[2, :] = max_positions * 5 * cutoff * identity[2, :] + # it does not have self-interaction + edge_src, edge_dst, edge_vec, shifts = neighbour_list( + quantities='ijDS', + pbc=pbc, + cell=cell, + positions=pos, + cutoff=cutoff, + ) + # dtype issue + edge_src = edge_src.astype(np.int64) + edge_dst = edge_dst.astype(np.int64) + + return edge_src, edge_dst, edge_vec, shifts + + +def _graph_build_ase(cutoff: float, pbc, cell, pos): + # building neighbor list + edge_src, edge_dst, edge_vec, shifts = primitive_neighbor_list( + 'ijDS', pbc, cell, pos, cutoff, self_interaction=True + ) + + is_zero_idx = np.all(edge_vec == 0, axis=1) + is_self_idx = edge_src == edge_dst + non_trivials = ~(is_zero_idx & is_self_idx) + shifts = np.array(shifts[non_trivials]) + + edge_vec = edge_vec[non_trivials] + edge_src = edge_src[non_trivials] + edge_dst = edge_dst[non_trivials] + + return edge_src, edge_dst, edge_vec, shifts + + +_graph_build_f = _graph_build_ase +try: + from matscipy.neighbours import neighbour_list + + _graph_build_f = _graph_build_matscipy +except ImportError: + pass + + +def _correct_scalar(v): + if isinstance(v, np.ndarray): + v = v.squeeze() + assert v.ndim == 0, f'given {v} is not a scalar' + return v + elif isinstance(v, (int, float, np.integer, np.floating)): + return np.array(v) + else: + assert False, f'{type(v)} is not expected' + + +def unlabeled_atoms_to_graph(atoms: ase.Atoms, cutoff: float): + pos = atoms.get_positions() + cell = np.array(atoms.get_cell()) + pbc = atoms.get_pbc() + + edge_src, edge_dst, edge_vec, shifts = _graph_build_f(cutoff, pbc, cell, pos) + + edge_idx = np.array([edge_src, edge_dst]) + + atomic_numbers = atoms.get_atomic_numbers() + + cell = np.array(cell) + vol = _correct_scalar(atoms.cell.volume) + if vol == 0: + vol = np.array(np.finfo(float).eps) + + data = { + KEY.NODE_FEATURE: atomic_numbers, + KEY.ATOMIC_NUMBERS: atomic_numbers, + KEY.POS: pos, + KEY.EDGE_IDX: edge_idx, + KEY.EDGE_VEC: edge_vec, + KEY.CELL: cell, + KEY.CELL_SHIFT: shifts, + KEY.CELL_VOLUME: vol, + KEY.NUM_ATOMS: _correct_scalar(len(atomic_numbers)), + } + data[KEY.INFO] = {} + return data + + +def atoms_to_graph( + atoms: ase.Atoms, + cutoff: float, + transfer_info: bool = True, + y_from_calc: bool = False, + allow_unlabeled: bool = False, +): + """ + From ase atoms, return AtomGraphData as graph based on cutoff radius + Except for energy, force and stress labels must be numpy array type + as other cases are not tested. + Returns 'np.nan' with consistent shape for unlabeled data + (ex. stress of non-pbc system) + + Args: + atoms (Atoms): ase atoms + cutoff (float): cutoff radius + transfer_info (bool): if True, transfer ".info" from atoms to graph, + defaults to True + y_from_calc: if True, get ref values from calculator, defaults to False + Returns: + numpy dict that can be used to initialize AtomGraphData + by AtomGraphData(**atoms_to_graph(atoms, cutoff)) + , for scalar, its shape is (), and types are np.ndarray + Requires grad is handled by 'dataset' not here. + """ + if not y_from_calc: + y_energy = atoms.info['y_energy'] + y_force = atoms.arrays['y_force'] + y_stress = atoms.info.get('y_stress', np.full((6,), np.nan)) + if y_stress.shape == (3, 3): + y_stress = np.array( + [ + y_stress[0][0], + y_stress[1][1], + y_stress[2][2], + y_stress[0][1], + y_stress[1][2], + y_stress[2][0], + ] + ) + else: + y_stress = y_stress.squeeze() + else: + from_calc = _y_from_calc(atoms) + y_energy = from_calc['energy'] + y_force = from_calc['force'] + y_stress = from_calc['stress'] + assert y_stress.shape == (6,), 'If you see this, please raise a issue' + + if not allow_unlabeled and (np.isnan(y_energy) or np.isnan(y_force).any()): + raise ValueError('Unlabeled E or F found, set allow_unlabeled True') + + pos = atoms.get_positions() + cell = np.array(atoms.get_cell()) + pbc = atoms.get_pbc() + + edge_src, edge_dst, edge_vec, shifts = _graph_build_f(cutoff, pbc, cell, pos) + + edge_idx = np.array([edge_src, edge_dst]) + atomic_numbers = atoms.get_atomic_numbers() + + cell = np.array(cell) + vol = _correct_scalar(atoms.cell.volume) + if vol == 0: + vol = np.array(np.finfo(float).eps) + + data = { + KEY.NODE_FEATURE: atomic_numbers, + KEY.ATOMIC_NUMBERS: atomic_numbers, + KEY.POS: pos, + KEY.EDGE_IDX: edge_idx, + KEY.EDGE_VEC: edge_vec, + KEY.ENERGY: _correct_scalar(y_energy), + KEY.FORCE: y_force, + KEY.STRESS: y_stress.reshape(1, 6), # to make batch have (n_node, 6) + KEY.CELL: cell, + KEY.CELL_SHIFT: shifts, + KEY.CELL_VOLUME: vol, + KEY.NUM_ATOMS: _correct_scalar(len(atomic_numbers)), + KEY.PER_ATOM_ENERGY: _correct_scalar(y_energy / len(pos)), + } + + if transfer_info and atoms.info is not None: + info = copy.deepcopy(atoms.info) + # save only metadata + info.pop('y_energy', None) + info.pop('y_force', None) + info.pop('y_stress', None) + data[KEY.INFO] = info + else: + data[KEY.INFO] = {} + + return data + + +def graph_build( + atoms_list: List, + cutoff: float, + num_cores: int = 1, + transfer_info: bool = True, + y_from_calc: bool = False, + allow_unlabeled: bool = False, +) -> List[AtomGraphData]: + """ + parallel version of graph_build + build graph from atoms_list and return list of AtomGraphData + Args: + atoms_list (List): list of ASE atoms + cutoff (float): cutoff radius of graph + num_cores (int): number of cores to use + transfer_info (bool): if True, copy info from atoms to graph, + defaults to True + y_from_calc (bool): Get reference y labels from calculator, defaults to False + Returns: + List[AtomGraphData]: list of AtomGraphData + """ + serial = num_cores == 1 + inputs = [ + (atoms, cutoff, transfer_info, y_from_calc, allow_unlabeled) + for atoms in atoms_list + ] + + if not serial: + pool = mp.Pool(num_cores) + graph_list = pool.starmap( + atoms_to_graph, + tqdm(inputs, total=len(atoms_list), desc=f'graph_build ({num_cores})'), + ) + pool.close() + pool.join() + else: + graph_list = [ + atoms_to_graph(*input_) + for input_ in tqdm(inputs, desc='graph_build (1)') + ] + + graph_list = [AtomGraphData.from_numpy_dict(g) for g in graph_list] + + return graph_list + + +def _y_from_calc(atoms: ase.Atoms): + ret = { + 'energy': np.nan, + 'force': np.full((len(atoms), 3), np.nan), + 'stress': np.full((6,), np.nan), + } + + if atoms.calc is None: + return ret + + try: + ret['energy'] = atoms.get_potential_energy(force_consistent=True) + except NotImplementedError: + ret['energy'] = atoms.get_potential_energy() + + try: + ret['force'] = atoms.get_forces(apply_constraint=False) + except NotImplementedError: + pass + + try: + y_stress = -1 * atoms.get_stress() # it ensures correct shape + ret['stress'] = np.array(y_stress[[0, 1, 2, 5, 3, 4]]) + except RuntimeError: + pass + return ret + + +def _set_atoms_y( + atoms_list: List[ase.Atoms], + energy_key: Optional[str] = None, + force_key: Optional[str] = None, + stress_key: Optional[str] = None, +) -> List[ase.Atoms]: + """ + Define how SevenNet reads ASE.atoms object for its y label + If energy_key, force_key, or stress_key is given, the corresponding + label is obtained from .info dict of Atoms object. These values should + have eV, eV/Angstrom, and eV/Angstrom^3 for energy, force, and stress, + respectively. (stress in Voigt notation) + + Args: + atoms_list (list[ase.Atoms]): target atoms to set y_labels + energy_key (str, optional): key to get energy. Defaults to None. + force_key (str, optional): key to get force. Defaults to None. + stress_key (str, optional): key to get stress. Defaults to None. + + Returns: + list[ase.Atoms]: list of ase.Atoms + + Raises: + RuntimeError: if ase atoms are somewhat imperfect + + Use free_energy: atoms.get_potential_energy(force_consistent=True) + If it is not available, use atoms.get_potential_energy() + If stress is available, initialize stress tensor + Ignore constraints like selective dynamics + """ + for atoms in atoms_list: + from_calc = _y_from_calc(atoms) + if energy_key is not None: + atoms.info['y_energy'] = atoms.info.pop(energy_key) + else: + atoms.info['y_energy'] = from_calc['energy'] + + if force_key is not None: + atoms.arrays['y_force'] = atoms.arrays.pop(force_key) + else: + atoms.arrays['y_force'] = from_calc['force'] + + if stress_key is not None: + y_stress = -1 * atoms.info.pop(stress_key) + atoms.info['y_stress'] = np.array(y_stress[[0, 1, 2, 5, 3, 4]]) + else: + atoms.info['y_stress'] = from_calc['stress'] + + return atoms_list + + +def ase_reader( + filename: str, + energy_key: Optional[str] = None, + force_key: Optional[str] = None, + stress_key: Optional[str] = None, + index: str = ':', + **kwargs, +) -> List[ase.Atoms]: + """ + Wrapper of ase.io.read + """ + atoms_list = ase.io.read(filename, index=index, **kwargs) + if not isinstance(atoms_list, list): + atoms_list = [atoms_list] + + return _set_atoms_y(atoms_list, energy_key, force_key, stress_key) + + +# Reader +def structure_list_reader(filename: str, format_outputs: Optional[str] = None): + """ + Read from structure_list using braceexpand and ASE + + Args: + fname : filename of structure_list + + Returns: + dictionary of lists of ASE structures. + key is title of training data (user-define) + """ + parsers = DefaultParsersContainer( + PositionsAndForces, Stress, Energy, Cell + ).make_parsers() + ocp = OutcarChunkParser(parsers=parsers) + + def parse_label(line): + line = line.strip() + if line.startswith('[') is False: + return False + elif line.endswith(']') is False: + raise ValueError('wrong structure_list title format') + return line[1:-1] + + def parse_fileline(line): + line = line.strip().split() + if len(line) == 1: + line.append(':') + elif len(line) != 2: + raise ValueError('wrong structure_list format') + return line[0], line[1] + + structure_list_file = open(filename, 'r') + lines = structure_list_file.readlines() + + raw_str_dict = {} + label = 'Default' + for line in lines: + if line.strip() == '': + continue + tmp_label = parse_label(line) + if tmp_label: + label = tmp_label + raw_str_dict[label] = [] + continue + elif label in raw_str_dict: + files_expr, index_expr = parse_fileline(line) + raw_str_dict[label].append((files_expr, index_expr)) + else: + raise ValueError('wrong structure_list format') + structure_list_file.close() + + structures_dict = {} + info_dct = {'data_from': 'user_OUTCAR'} + for title, file_lines in raw_str_dict.items(): + stct_lists = [] + for file_line in file_lines: + files_expr, index_expr = file_line + index = string2index(index_expr) + for expanded_filename in list(braceexpand(files_expr)): + f_stream = open(expanded_filename, 'r') + # generator of all outcar ionic steps + gen_all = outcarchunks(f_stream, ocp) + try: # TODO: index may not slice, it can be integer + it_atoms = islice(gen_all, index.start, index.stop, index.step) + except ValueError: + # TODO: support + # negative index + raise ValueError('Negative index is not supported yet') + + info_dct_f = { + **info_dct, + 'file': os.path.abspath(expanded_filename), + } + for idx, o in enumerate(it_atoms): + try: + it_atoms = islice( + gen_all, index.start, index.stop, index.step + ) + except ValueError: + # TODO: support + # negative index + raise ValueError('Negative index is not supported yet') + + info_dct_f = { + **info_dct, + 'file': os.path.abspath(expanded_filename), + } + for idx, o in enumerate(it_atoms): + try: + istep = index.start + idx * index.step # type: ignore + atoms = o.build() + atoms.info = {**info_dct_f, 'ionic_step': istep}.copy() + except TypeError: # it is not slice of ionic steps + atoms = o.build() + atoms.info = info_dct_f.copy() + stct_lists.append(atoms) + f_stream.close() + else: + stct_lists += ase.io.read( + expanded_filename, + index=index_expr, + parallel=False, + ) + structures_dict[title] = stct_lists + return {k: _set_atoms_y(v) for k, v in structures_dict.items()} + + +def dict_reader(data_dict: Dict): + data_dict_cp = copy.deepcopy(data_dict) + + ret = [] + file_list = data_dict_cp.pop('file_list', None) + if file_list is None: + raise KeyError('file_list is not found') + + data_weight_default = { + 'energy': 1.0, + 'force': 1.0, + 'stress': 1.0, + } + data_weight = data_weight_default.copy() + data_weight.update(data_dict_cp.pop(KEY.DATA_WEIGHT, {})) + + for file_dct in file_list: + ftype = file_dct.pop('data_format', 'ase') + files = list(braceexpand(file_dct.pop('file'))) + if ftype == 'ase': + ret.extend(chain(*[ase_reader(f, **file_dct) for f in files])) + elif ftype == 'graph': + continue + else: + raise ValueError(f'{ftype} yet') + + for atoms in ret: + atoms.info.update(data_dict_cp) + atoms.info.update({KEY.DATA_WEIGHT: data_weight}) + return _set_atoms_y(ret) + + +def match_reader(reader_name: str, **kwargs): + reader = None + metadata = {} + if reader_name == 'structure_list': + reader = partial(structure_list_reader, **kwargs) + metadata.update({'origin': 'structure_list'}) + else: + reader = partial(ase_reader, **kwargs) + metadata.update({'origin': 'ase_reader'}) + return reader, metadata + + +def file_to_dataset( + file: str, + cutoff: float, + cores: int = 1, + reader: Callable = ase_reader, + label: Optional[str] = None, + transfer_info: bool = True, + use_weight: bool = False, + use_modality: bool = False, +): + """ + Deprecated + Read file by reader > get list of atoms or dict of atoms + """ + + # expect label: atoms_list dct or atoms or list of atoms + atoms = reader(file) + + if type(atoms) is list: + if label is None: + label = KEY.LABEL_NONE + atoms_dct = {label: atoms} + elif isinstance(atoms, ase.Atoms): + if label is None: + label = KEY.LABEL_NONE + atoms_dct = {label: [atoms]} + elif isinstance(atoms, dict): + atoms_dct = atoms + else: + raise TypeError('The return of reader is not list or dict') + + graph_dct = {} + for label, atoms_list in atoms_dct.items(): + graph_list = graph_build( + atoms_list=atoms_list, + cutoff=cutoff, + num_cores=cores, + transfer_info=transfer_info, + y_from_calc=False, + ) + + label_info = label.split(':') + for graph in graph_list: + graph[KEY.USER_LABEL] = label_info[0].strip() + if use_weight: + find_weight = False + for info in label_info[1:]: + if 'w=' in info.lower(): + weights = info.split('=')[1] + try: + if ',' in weights: + weight_list = list(map(float, weights.split(','))) + else: + weight_list = [float(weights)] * 3 + weight_dict = {} + for idx, loss_type in enumerate(LossType): + weight_dict[loss_type.value] = ( + weight_list[idx] if idx < len(weight_list) else 1 + ) + graph[KEY.DATA_WEIGHT] = weight_dict + find_weight = True + break + except: + raise ValueError( + 'Weight must be a real number, but' + f' {weights} is given for {label}' + ) + if not find_weight: + weight_dict = {} + for loss_type in LossType: + weight_dict[loss_type.value] = 1 + graph[KEY.DATA_WEIGHT] = weight_dict + if use_modality: + find_modality = False + for info in label_info[1:]: + if 'm=' in info.lower(): + graph[KEY.DATA_MODALITY] = (info.split('=')[1]).strip() + find_modality = True + break + if not find_modality: + raise ValueError(f'Modality not given for {label}') + + graph_dct[label_info[0].strip()] = graph_list + db = AtomGraphDataset(graph_dct, cutoff) + return db diff --git a/mace-bench/3rdparty/SevenNet/sevenn/train/dataset.py b/mace-bench/3rdparty/SevenNet/sevenn/train/dataset.py index 3c31c55..ccf04df 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/train/dataset.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/train/dataset.py @@ -1,496 +1,496 @@ -import itertools -import random -from collections import Counter -from typing import Callable, Dict, List, Optional, Union - -import numpy as np -import torch -from ase.data import chemical_symbols -from sklearn.linear_model import Ridge - -import sevenn._keys as KEY -import sevenn.util as util - - -class AtomGraphDataset: - """ - Deprecated - - class representing dataset of AtomGraphData - the dataset is handled as dict, {label: data} - if given data is List, it stores data as {KEY_DEFAULT: data} - - cutoff is for metadata of the graphs not used for some calc - Every data expected to have one unique cutoff - No validity or check of the condition is done inside the object - - attribute: - dataset (Dict[str, List]): key is data label(str), value is list of data - user_labels (List[str]): list of user labels same as dataset.keys() - meta (Dict, Optional): metadata of dataset - for now, metadata 'might' have following keys: - KEY.CUTOFF (float), KEY.CHEMICAL_SPECIES (Dict) - """ - - DATA_KEY_X = ( - KEY.NODE_FEATURE - ) # atomic_number > one_hot_idx > one_hot_vector - DATA_KEY_ENERGY = KEY.ENERGY - DATA_KEY_FORCE = KEY.FORCE - KEY_DEFAULT = KEY.LABEL_NONE - - def __init__( - self, - dataset: Union[Dict[str, List], List], - cutoff: float, - metadata: Optional[Dict] = None, - x_is_one_hot_idx: bool = False, - ): - """ - Default constructor of AtomGraphDataset - Args: - dataset (Union[Dict[str, List], List]: dataset as dict or pure list - metadata (Dict, Optional): metadata of data - cutoff (float): cutoff radius of graphs inside the dataset - x_is_one_hot_idx (bool): if True, x is one_hot_idx, else 'Z' - - 'x' (node feature) of dataset can have 3 states, atomic_numbers, - one_hot_idx, or one_hot_vector. - - atomic_numbers is general but cannot directly used for input - one_hot_idx is can be input of the model but requires 'type_map' - """ - self.cutoff = cutoff - self.x_is_one_hot_idx = x_is_one_hot_idx - if metadata is None: - metadata = {KEY.CUTOFF: cutoff} - self.meta = metadata - if type(dataset) is list: - self.dataset = {self.KEY_DEFAULT: dataset} - else: - self.dataset = dataset - self.user_labels = list(self.dataset.keys()) - # group_by_key here? or not? - - def rewrite_labels_to_data(self): - """ - Based on self.dataset dict's keys - write data[KEY.USER_LABEL] to correspond to dict's keys - Most of times, it is already correctly written - But required to rewrite if someone rearrange dataset by their own way - """ - for label, data_list in self.dataset.items(): - for data in data_list: - data[KEY.USER_LABEL] = label - - def group_by_key(self, data_key: str = KEY.USER_LABEL): - """ - group dataset list by given key and save it as dict - and change in-place - Args: - data_key (str): data key to group by - - original use is USER_LABEL, but it can be used for other keys - if someone established it from data[KEY.INFO] - """ - data_list = self.to_list() - self.dataset = {} - for datum in data_list: - key = datum[data_key] - if key not in self.dataset: - self.dataset[key] = [] - self.dataset[key].append(datum) - self.user_labels = list(self.dataset.keys()) - - def separate_info(self, data_key: str = KEY.INFO): - """ - Separate info from data and save it as list of dict - to make it compatible with torch_geometric and later training - """ - data_list = self.to_list() - info_list = [] - for datum in data_list: - if data_key in datum is False: - continue - info_list.append(datum[data_key]) - del datum[data_key] # It does change the self.dataset - datum[data_key] = len(info_list) - 1 - self.info_list = info_list - - return (data_list, info_list) - - def get_species(self): - """ - You can also use get_natoms and extract keys from there instead of this - (And it is more efficient) - get chemical species of dataset - return list of SORTED chemical species (as str) - """ - if hasattr(self, 'type_map'): - natoms = self.get_natoms(self.type_map) - else: - natoms = self.get_natoms() - species = set() - for natom_dct in natoms.values(): - species.update(natom_dct.keys()) - species = sorted(list(species)) - return species - - def get_modalities(self): - modalities = set() - for data_list in self.dataset.values(): - datum = data_list[0].to_dict() - if KEY.DATA_MODALITY in datum.keys(): - modalities.add(datum[KEY.DATA_MODALITY]) - else: - return [] - return list(modalities) - - def write_modal_attr( - self, modal_type_mapper: dict, write_modal_type: bool = False - ): - num_modalities = len(modal_type_mapper) - for data_list in self.dataset.values(): - for data in data_list: - tmp_tensor = torch.zeros(num_modalities) - if data[KEY.DATA_MODALITY] != 'common': - modal_idx = modal_type_mapper[data[KEY.DATA_MODALITY]] - tmp_tensor[modal_idx] = 1.0 - if write_modal_type: - data[KEY.MODAL_TYPE] = modal_idx - data[KEY.MODAL_ATTR] = tmp_tensor - - def get_dict_sort_by_modality(self): - dict_sort_by_modality = {} - for data_list in self.dataset.values(): - try: - modal_key = data_list[0].to_dict()[KEY.DATA_MODALITY] - except: # Dataset is not modal - raise ValueError('This dataset has no modality.') - - if modal_key not in dict_sort_by_modality.keys(): - dict_sort_by_modality[modal_key] = [] - dict_sort_by_modality[modal_key].extend(data_list) - - return dict_sort_by_modality - - def len(self): - if ( - len(self.dataset.keys()) == 1 - and list(self.dataset.keys())[0] == AtomGraphDataset.KEY_DEFAULT - ): - return len(self.dataset[AtomGraphDataset.KEY_DEFAULT]) - else: - return {k: len(v) for k, v in self.dataset.items()} - - def get(self, idx: int, key: Optional[str] = None): - if key is None: - key = self.KEY_DEFAULT - return self.dataset[key][idx] - - def items(self): - return self.dataset.items() - - def to_dict(self): - dct_dataset = {} - for label, data_list in self.dataset.items(): - dct_dataset[label] = [datum.to_dict() for datum in data_list] - self.dataset = dct_dataset - return self - - def x_to_one_hot_idx(self, type_map: Dict[int, int]): - """ - type_map is dict of {atomic_number: one_hot_idx} - after this process, the dataset has dependency on type_map - or chemical species user want to consider - """ - assert self.x_is_one_hot_idx is False - for data_list in self.dataset.values(): - for datum in data_list: - datum[self.DATA_KEY_X] = torch.LongTensor( - [type_map[z.item()] for z in datum[self.DATA_KEY_X]] - ) - self.type_map = type_map - self.x_is_one_hot_idx = True - - def toggle_requires_grad_of_data( - self, key: str, requires_grad_value: bool - ): - """ - set requires_grad of specific key of data(pos, edge_vec, ...) - """ - for data_list in self.dataset.values(): - for datum in data_list: - datum[key].requires_grad_(requires_grad_value) - - def divide_dataset( - self, - ratio: float, - constant_ratio_btw_labels: bool = True, - ignore_test: bool = True - ): - """ - divide dataset into 1-2*ratio : ratio : ratio - return divided AtomGraphDataset - returned value lost its dict key and became {KEY_DEFAULT: datalist} - but KEY.USER_LABEL of each data is preserved - """ - - def divide(ratio: float, data_list: List, ignore_test=True): - if ratio > 0.5: - raise ValueError('Ratio must not exceed 0.5') - data_len = len(data_list) - random.shuffle(data_list) - n_validation = int(data_len * ratio) - if n_validation == 0: - raise ValueError( - '# of validation set is 0, increase your dataset' - ) - - if ignore_test: - test_list = [] - n_train = data_len - n_validation - train_list = data_list[0:n_train] - valid_list = data_list[n_train:] - else: - n_train = data_len - 2 * n_validation - train_list = data_list[0:n_train] - valid_list = data_list[n_train : n_train + n_validation] - test_list = data_list[n_train + n_validation : data_len] - return train_list, valid_list, test_list - - lists = ([], [], []) # train, valid, test - if constant_ratio_btw_labels: - for data_list in self.dataset.values(): - for store, divided in zip(lists, divide(ratio, data_list)): - store.extend(divided) - else: - lists = divide(ratio, self.to_list()) - - dbs = tuple( - AtomGraphDataset(data, self.cutoff, self.meta) for data in lists - ) - for db in dbs: - db.group_by_key() - return dbs - - def to_list(self): - return list(itertools.chain(*self.dataset.values())) - - def get_natoms(self, type_map: Optional[Dict[int, int]] = None): - """ - if x_is_one_hot_idx, type_map is required - type_map: Z->one_hot_index(node_feature) - return Dict{label: {symbol, natom}]} - """ - assert not (self.x_is_one_hot_idx is True and type_map is None) - natoms = {} - for label, data in self.dataset.items(): - natoms[label] = Counter() - for datum in data: - if self.x_is_one_hot_idx and type_map is not None: - Zs = util.onehot_to_chem(datum[self.DATA_KEY_X], type_map) - else: - Zs = [ - chemical_symbols[z] - for z in datum[self.DATA_KEY_X].tolist() - ] - cnt = Counter(Zs) - natoms[label] += cnt - natoms[label] = dict(natoms[label]) - return natoms - - def get_per_atom_mean(self, key: str, key_num_atoms: str = KEY.NUM_ATOMS): - """ - return per_atom mean of given data key - """ - eng_list = torch.Tensor( - [x[key] / x[key_num_atoms] for x in self.to_list()] - ) - return float(torch.mean(eng_list)) - - def get_per_atom_energy_mean(self): - """ - alias for get_per_atom_mean(KEY.ENERGY) - """ - return self.get_per_atom_mean(self.DATA_KEY_ENERGY) - - def get_species_ref_energy_by_linear_comb(self, num_chem_species: int): - """ - Total energy as y, composition as c_i, - solve linear regression of y = c_i*X - sklearn LinearRegression as solver - - x should be one-hot-indexed - give num_chem_species if possible - """ - assert self.x_is_one_hot_idx is True - data_list = self.to_list() - - c = torch.zeros((len(data_list), num_chem_species)) - for idx, datum in enumerate(data_list): - c[idx] = torch.bincount( - datum[self.DATA_KEY_X], minlength=num_chem_species - ) - y = torch.Tensor([x[self.DATA_KEY_ENERGY] for x in data_list]) - c = c.numpy() - y = y.numpy() - - # tweak to fine tune training from many-element to small element - zero_indices = np.all(c == 0, axis=0) - c_reduced = c[:, ~zero_indices] - full_coeff = np.zeros(num_chem_species) - coef_reduced = ( - Ridge(alpha=0.1, fit_intercept=False).fit(c_reduced, y).coef_ - ) - full_coeff[~zero_indices] = coef_reduced - - return full_coeff - - def get_force_rms(self): - force_list = [] - for x in self.to_list(): - force_list.extend( - x[self.DATA_KEY_FORCE] - .reshape( - -1, - ) - .tolist() - ) - force_list = torch.Tensor(force_list) - return float(torch.sqrt(torch.mean(torch.pow(force_list, 2)))) - - def get_species_wise_force_rms(self, num_chem_species: int): - """ - Return force rms for each species - Averaged by each components (x, y, z) - """ - assert self.x_is_one_hot_idx is True - data_list = self.to_list() - - atomx = torch.concat([d[self.DATA_KEY_X] for d in data_list]) - force = torch.concat([d[self.DATA_KEY_FORCE] for d in data_list]) - - index = atomx.repeat_interleave(3, 0).reshape(force.shape) - rms = torch.zeros( - (num_chem_species, 3), - dtype=force.dtype, - device=force.device - ) - rms.scatter_reduce_( - 0, index, force.square(), - reduce='mean', include_self=False - ) - return torch.sqrt(rms.mean(dim=1)) - - def get_avg_num_neigh(self): - n_neigh = [] - for _, data_list in self.dataset.items(): - for data in data_list: - n_neigh.extend( - np.unique(data[KEY.EDGE_IDX][0], return_counts=True)[1] - ) - - avg_num_neigh = np.average(n_neigh) - return avg_num_neigh - - def get_statistics(self, key: str): - """ - return dict of statistics of given key (energy, force, stress) - key of dict is its label and _total for total statistics - value of dict is dict of statistics (mean, std, median, max, min) - """ - - def _get_statistic_dict(tensor_list): - data_list = torch.cat( - [ - tensor.reshape( - -1, - ) - for tensor in tensor_list - ] - ) - data_list = data_list[~torch.isnan(data_list)] - return { - 'mean': float(torch.mean(data_list)), - 'std': float(torch.std(data_list)), - 'median': float(torch.median(data_list)), - 'max': ( - torch.nan - if data_list.numel() == 0 - else float(torch.max(data_list)) - ), - 'min': ( - torch.nan - if data_list.numel() == 0 - else float(torch.min(data_list)) - ), - } - - res = {} - for label, values in self.dataset.items(): - # flatten list of torch.Tensor (values) - tensor_list = [x[key] for x in values] - res[label] = _get_statistic_dict(tensor_list) - tensor_list = [x[key] for x in self.to_list()] - res['Total'] = _get_statistic_dict(tensor_list) - return res - - def augment(self, dataset, validator: Optional[Callable] = None): - """check meta compatibility here - dataset(AtomGraphDataset): data to augment - validator(Callable, Optional): function(self, dataset) -> bool - - if validator is None, by default it checks - whether cutoff & chemical_species are same before augment - - check consistent data type, float, double, long integer etc - """ - - def default_validator(db1, db2): - cut_consis = db1.cutoff == db2.cutoff - # compare unordered lists - x_is_not_onehot = (not db1.x_is_one_hot_idx) and ( - not db2.x_is_one_hot_idx - ) - return cut_consis and x_is_not_onehot - - if validator is None: - validator = default_validator - if not validator(self, dataset): - raise ValueError('given datasets are not compatible check cutoffs') - for key, val in dataset.items(): - if key in self.dataset: - self.dataset[key].extend(val) - else: - self.dataset.update({key: val}) - self.user_labels = list(self.dataset.keys()) - - def unify_dtypes( - self, - float_dtype: torch.dtype = torch.float32, - int_dtype: torch.dtype = torch.int64 - ): - data_list = self.to_list() - for datum in data_list: - for k, v in list(datum.items()): - datum[k] = util.dtype_correct(v, float_dtype, int_dtype) - - def delete_data_key(self, key: str): - for data in self.to_list(): - del data[key] - - # TODO: this by_label is not straightforward - def save(self, path: str, by_label: bool = False): - if by_label: - for label, data in self.dataset.items(): - torch.save( - AtomGraphDataset( - {label: data}, self.cutoff, metadata=self.meta - ), - f'{path}/{label}.sevenn_data', - ) - else: - if path.endswith('.sevenn_data') is False: - path += '.sevenn_data' - torch.save(self, path) +import itertools +import random +from collections import Counter +from typing import Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from ase.data import chemical_symbols +from sklearn.linear_model import Ridge + +import sevenn._keys as KEY +import sevenn.util as util + + +class AtomGraphDataset: + """ + Deprecated + + class representing dataset of AtomGraphData + the dataset is handled as dict, {label: data} + if given data is List, it stores data as {KEY_DEFAULT: data} + + cutoff is for metadata of the graphs not used for some calc + Every data expected to have one unique cutoff + No validity or check of the condition is done inside the object + + attribute: + dataset (Dict[str, List]): key is data label(str), value is list of data + user_labels (List[str]): list of user labels same as dataset.keys() + meta (Dict, Optional): metadata of dataset + for now, metadata 'might' have following keys: + KEY.CUTOFF (float), KEY.CHEMICAL_SPECIES (Dict) + """ + + DATA_KEY_X = ( + KEY.NODE_FEATURE + ) # atomic_number > one_hot_idx > one_hot_vector + DATA_KEY_ENERGY = KEY.ENERGY + DATA_KEY_FORCE = KEY.FORCE + KEY_DEFAULT = KEY.LABEL_NONE + + def __init__( + self, + dataset: Union[Dict[str, List], List], + cutoff: float, + metadata: Optional[Dict] = None, + x_is_one_hot_idx: bool = False, + ): + """ + Default constructor of AtomGraphDataset + Args: + dataset (Union[Dict[str, List], List]: dataset as dict or pure list + metadata (Dict, Optional): metadata of data + cutoff (float): cutoff radius of graphs inside the dataset + x_is_one_hot_idx (bool): if True, x is one_hot_idx, else 'Z' + + 'x' (node feature) of dataset can have 3 states, atomic_numbers, + one_hot_idx, or one_hot_vector. + + atomic_numbers is general but cannot directly used for input + one_hot_idx is can be input of the model but requires 'type_map' + """ + self.cutoff = cutoff + self.x_is_one_hot_idx = x_is_one_hot_idx + if metadata is None: + metadata = {KEY.CUTOFF: cutoff} + self.meta = metadata + if type(dataset) is list: + self.dataset = {self.KEY_DEFAULT: dataset} + else: + self.dataset = dataset + self.user_labels = list(self.dataset.keys()) + # group_by_key here? or not? + + def rewrite_labels_to_data(self): + """ + Based on self.dataset dict's keys + write data[KEY.USER_LABEL] to correspond to dict's keys + Most of times, it is already correctly written + But required to rewrite if someone rearrange dataset by their own way + """ + for label, data_list in self.dataset.items(): + for data in data_list: + data[KEY.USER_LABEL] = label + + def group_by_key(self, data_key: str = KEY.USER_LABEL): + """ + group dataset list by given key and save it as dict + and change in-place + Args: + data_key (str): data key to group by + + original use is USER_LABEL, but it can be used for other keys + if someone established it from data[KEY.INFO] + """ + data_list = self.to_list() + self.dataset = {} + for datum in data_list: + key = datum[data_key] + if key not in self.dataset: + self.dataset[key] = [] + self.dataset[key].append(datum) + self.user_labels = list(self.dataset.keys()) + + def separate_info(self, data_key: str = KEY.INFO): + """ + Separate info from data and save it as list of dict + to make it compatible with torch_geometric and later training + """ + data_list = self.to_list() + info_list = [] + for datum in data_list: + if data_key in datum is False: + continue + info_list.append(datum[data_key]) + del datum[data_key] # It does change the self.dataset + datum[data_key] = len(info_list) - 1 + self.info_list = info_list + + return (data_list, info_list) + + def get_species(self): + """ + You can also use get_natoms and extract keys from there instead of this + (And it is more efficient) + get chemical species of dataset + return list of SORTED chemical species (as str) + """ + if hasattr(self, 'type_map'): + natoms = self.get_natoms(self.type_map) + else: + natoms = self.get_natoms() + species = set() + for natom_dct in natoms.values(): + species.update(natom_dct.keys()) + species = sorted(list(species)) + return species + + def get_modalities(self): + modalities = set() + for data_list in self.dataset.values(): + datum = data_list[0].to_dict() + if KEY.DATA_MODALITY in datum.keys(): + modalities.add(datum[KEY.DATA_MODALITY]) + else: + return [] + return list(modalities) + + def write_modal_attr( + self, modal_type_mapper: dict, write_modal_type: bool = False + ): + num_modalities = len(modal_type_mapper) + for data_list in self.dataset.values(): + for data in data_list: + tmp_tensor = torch.zeros(num_modalities) + if data[KEY.DATA_MODALITY] != 'common': + modal_idx = modal_type_mapper[data[KEY.DATA_MODALITY]] + tmp_tensor[modal_idx] = 1.0 + if write_modal_type: + data[KEY.MODAL_TYPE] = modal_idx + data[KEY.MODAL_ATTR] = tmp_tensor + + def get_dict_sort_by_modality(self): + dict_sort_by_modality = {} + for data_list in self.dataset.values(): + try: + modal_key = data_list[0].to_dict()[KEY.DATA_MODALITY] + except: # Dataset is not modal + raise ValueError('This dataset has no modality.') + + if modal_key not in dict_sort_by_modality.keys(): + dict_sort_by_modality[modal_key] = [] + dict_sort_by_modality[modal_key].extend(data_list) + + return dict_sort_by_modality + + def len(self): + if ( + len(self.dataset.keys()) == 1 + and list(self.dataset.keys())[0] == AtomGraphDataset.KEY_DEFAULT + ): + return len(self.dataset[AtomGraphDataset.KEY_DEFAULT]) + else: + return {k: len(v) for k, v in self.dataset.items()} + + def get(self, idx: int, key: Optional[str] = None): + if key is None: + key = self.KEY_DEFAULT + return self.dataset[key][idx] + + def items(self): + return self.dataset.items() + + def to_dict(self): + dct_dataset = {} + for label, data_list in self.dataset.items(): + dct_dataset[label] = [datum.to_dict() for datum in data_list] + self.dataset = dct_dataset + return self + + def x_to_one_hot_idx(self, type_map: Dict[int, int]): + """ + type_map is dict of {atomic_number: one_hot_idx} + after this process, the dataset has dependency on type_map + or chemical species user want to consider + """ + assert self.x_is_one_hot_idx is False + for data_list in self.dataset.values(): + for datum in data_list: + datum[self.DATA_KEY_X] = torch.LongTensor( + [type_map[z.item()] for z in datum[self.DATA_KEY_X]] + ) + self.type_map = type_map + self.x_is_one_hot_idx = True + + def toggle_requires_grad_of_data( + self, key: str, requires_grad_value: bool + ): + """ + set requires_grad of specific key of data(pos, edge_vec, ...) + """ + for data_list in self.dataset.values(): + for datum in data_list: + datum[key].requires_grad_(requires_grad_value) + + def divide_dataset( + self, + ratio: float, + constant_ratio_btw_labels: bool = True, + ignore_test: bool = True + ): + """ + divide dataset into 1-2*ratio : ratio : ratio + return divided AtomGraphDataset + returned value lost its dict key and became {KEY_DEFAULT: datalist} + but KEY.USER_LABEL of each data is preserved + """ + + def divide(ratio: float, data_list: List, ignore_test=True): + if ratio > 0.5: + raise ValueError('Ratio must not exceed 0.5') + data_len = len(data_list) + random.shuffle(data_list) + n_validation = int(data_len * ratio) + if n_validation == 0: + raise ValueError( + '# of validation set is 0, increase your dataset' + ) + + if ignore_test: + test_list = [] + n_train = data_len - n_validation + train_list = data_list[0:n_train] + valid_list = data_list[n_train:] + else: + n_train = data_len - 2 * n_validation + train_list = data_list[0:n_train] + valid_list = data_list[n_train : n_train + n_validation] + test_list = data_list[n_train + n_validation : data_len] + return train_list, valid_list, test_list + + lists = ([], [], []) # train, valid, test + if constant_ratio_btw_labels: + for data_list in self.dataset.values(): + for store, divided in zip(lists, divide(ratio, data_list)): + store.extend(divided) + else: + lists = divide(ratio, self.to_list()) + + dbs = tuple( + AtomGraphDataset(data, self.cutoff, self.meta) for data in lists + ) + for db in dbs: + db.group_by_key() + return dbs + + def to_list(self): + return list(itertools.chain(*self.dataset.values())) + + def get_natoms(self, type_map: Optional[Dict[int, int]] = None): + """ + if x_is_one_hot_idx, type_map is required + type_map: Z->one_hot_index(node_feature) + return Dict{label: {symbol, natom}]} + """ + assert not (self.x_is_one_hot_idx is True and type_map is None) + natoms = {} + for label, data in self.dataset.items(): + natoms[label] = Counter() + for datum in data: + if self.x_is_one_hot_idx and type_map is not None: + Zs = util.onehot_to_chem(datum[self.DATA_KEY_X], type_map) + else: + Zs = [ + chemical_symbols[z] + for z in datum[self.DATA_KEY_X].tolist() + ] + cnt = Counter(Zs) + natoms[label] += cnt + natoms[label] = dict(natoms[label]) + return natoms + + def get_per_atom_mean(self, key: str, key_num_atoms: str = KEY.NUM_ATOMS): + """ + return per_atom mean of given data key + """ + eng_list = torch.Tensor( + [x[key] / x[key_num_atoms] for x in self.to_list()] + ) + return float(torch.mean(eng_list)) + + def get_per_atom_energy_mean(self): + """ + alias for get_per_atom_mean(KEY.ENERGY) + """ + return self.get_per_atom_mean(self.DATA_KEY_ENERGY) + + def get_species_ref_energy_by_linear_comb(self, num_chem_species: int): + """ + Total energy as y, composition as c_i, + solve linear regression of y = c_i*X + sklearn LinearRegression as solver + + x should be one-hot-indexed + give num_chem_species if possible + """ + assert self.x_is_one_hot_idx is True + data_list = self.to_list() + + c = torch.zeros((len(data_list), num_chem_species)) + for idx, datum in enumerate(data_list): + c[idx] = torch.bincount( + datum[self.DATA_KEY_X], minlength=num_chem_species + ) + y = torch.Tensor([x[self.DATA_KEY_ENERGY] for x in data_list]) + c = c.numpy() + y = y.numpy() + + # tweak to fine tune training from many-element to small element + zero_indices = np.all(c == 0, axis=0) + c_reduced = c[:, ~zero_indices] + full_coeff = np.zeros(num_chem_species) + coef_reduced = ( + Ridge(alpha=0.1, fit_intercept=False).fit(c_reduced, y).coef_ + ) + full_coeff[~zero_indices] = coef_reduced + + return full_coeff + + def get_force_rms(self): + force_list = [] + for x in self.to_list(): + force_list.extend( + x[self.DATA_KEY_FORCE] + .reshape( + -1, + ) + .tolist() + ) + force_list = torch.Tensor(force_list) + return float(torch.sqrt(torch.mean(torch.pow(force_list, 2)))) + + def get_species_wise_force_rms(self, num_chem_species: int): + """ + Return force rms for each species + Averaged by each components (x, y, z) + """ + assert self.x_is_one_hot_idx is True + data_list = self.to_list() + + atomx = torch.concat([d[self.DATA_KEY_X] for d in data_list]) + force = torch.concat([d[self.DATA_KEY_FORCE] for d in data_list]) + + index = atomx.repeat_interleave(3, 0).reshape(force.shape) + rms = torch.zeros( + (num_chem_species, 3), + dtype=force.dtype, + device=force.device + ) + rms.scatter_reduce_( + 0, index, force.square(), + reduce='mean', include_self=False + ) + return torch.sqrt(rms.mean(dim=1)) + + def get_avg_num_neigh(self): + n_neigh = [] + for _, data_list in self.dataset.items(): + for data in data_list: + n_neigh.extend( + np.unique(data[KEY.EDGE_IDX][0], return_counts=True)[1] + ) + + avg_num_neigh = np.average(n_neigh) + return avg_num_neigh + + def get_statistics(self, key: str): + """ + return dict of statistics of given key (energy, force, stress) + key of dict is its label and _total for total statistics + value of dict is dict of statistics (mean, std, median, max, min) + """ + + def _get_statistic_dict(tensor_list): + data_list = torch.cat( + [ + tensor.reshape( + -1, + ) + for tensor in tensor_list + ] + ) + data_list = data_list[~torch.isnan(data_list)] + return { + 'mean': float(torch.mean(data_list)), + 'std': float(torch.std(data_list)), + 'median': float(torch.median(data_list)), + 'max': ( + torch.nan + if data_list.numel() == 0 + else float(torch.max(data_list)) + ), + 'min': ( + torch.nan + if data_list.numel() == 0 + else float(torch.min(data_list)) + ), + } + + res = {} + for label, values in self.dataset.items(): + # flatten list of torch.Tensor (values) + tensor_list = [x[key] for x in values] + res[label] = _get_statistic_dict(tensor_list) + tensor_list = [x[key] for x in self.to_list()] + res['Total'] = _get_statistic_dict(tensor_list) + return res + + def augment(self, dataset, validator: Optional[Callable] = None): + """check meta compatibility here + dataset(AtomGraphDataset): data to augment + validator(Callable, Optional): function(self, dataset) -> bool + + if validator is None, by default it checks + whether cutoff & chemical_species are same before augment + + check consistent data type, float, double, long integer etc + """ + + def default_validator(db1, db2): + cut_consis = db1.cutoff == db2.cutoff + # compare unordered lists + x_is_not_onehot = (not db1.x_is_one_hot_idx) and ( + not db2.x_is_one_hot_idx + ) + return cut_consis and x_is_not_onehot + + if validator is None: + validator = default_validator + if not validator(self, dataset): + raise ValueError('given datasets are not compatible check cutoffs') + for key, val in dataset.items(): + if key in self.dataset: + self.dataset[key].extend(val) + else: + self.dataset.update({key: val}) + self.user_labels = list(self.dataset.keys()) + + def unify_dtypes( + self, + float_dtype: torch.dtype = torch.float32, + int_dtype: torch.dtype = torch.int64 + ): + data_list = self.to_list() + for datum in data_list: + for k, v in list(datum.items()): + datum[k] = util.dtype_correct(v, float_dtype, int_dtype) + + def delete_data_key(self, key: str): + for data in self.to_list(): + del data[key] + + # TODO: this by_label is not straightforward + def save(self, path: str, by_label: bool = False): + if by_label: + for label, data in self.dataset.items(): + torch.save( + AtomGraphDataset( + {label: data}, self.cutoff, metadata=self.meta + ), + f'{path}/{label}.sevenn_data', + ) + else: + if path.endswith('.sevenn_data') is False: + path += '.sevenn_data' + torch.save(self, path) diff --git a/mace-bench/3rdparty/SevenNet/sevenn/train/graph_dataset.py b/mace-bench/3rdparty/SevenNet/sevenn/train/graph_dataset.py index fd8d395..fc32d14 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/train/graph_dataset.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/train/graph_dataset.py @@ -1,707 +1,707 @@ -import os -import warnings -from collections import Counter -from copy import deepcopy -from datetime import datetime -from typing import Any, Callable, Dict, List, Optional, Tuple, Union - -import numpy as np -import torch -import torch.serialization -import torch.utils.data -import yaml -from ase.data import chemical_symbols -from torch_geometric.data import Data -from torch_geometric.data.in_memory_dataset import InMemoryDataset -from tqdm import tqdm - -import sevenn._keys as KEY -import sevenn.train.dataload as dataload -import sevenn.util as util -from sevenn import __version__ -from sevenn._const import NUM_UNIV_ELEMENT -from sevenn.atom_graph_data import AtomGraphData -from sevenn.logger import Logger - -if torch.__version__.split()[0] >= '2.4.0': - # load graph without error - torch.serialization.add_safe_globals([AtomGraphData]) - -# warning from PyG, for later torch versions -warnings.filterwarnings( - 'ignore', - message='You are using `torch.load` with `weights_only=False`', -) - - -def _tag_graphs(graph_list: List[AtomGraphData], tag: str): - """ - WIP: To be used - """ - for g in graph_list: - g[KEY.TAG] = tag - return graph_list - - -def pt_to_args(pt_filename: str): - """ - Return arg dict of root and processed_name from path to .pt - Usage: - dataset = SevenNetGraphDataset( - **pt_to_args({path}/sevenn_data/dataset.pt) - ) - """ - processed_dir, basename = os.path.split(pt_filename) - return { - 'root': os.path.dirname(processed_dir), - 'processed_name': os.path.basename(basename), - } - - -def _run_stat( - graph_list, - y_keys: List[str] = [KEY.ENERGY, KEY.PER_ATOM_ENERGY, KEY.FORCE, KEY.STRESS], -) -> Dict[str, Any]: - """ - Loop over dataset and init any statistics might need - """ - n_neigh = [] - natoms_counter = Counter() - composition = torch.zeros((len(graph_list), NUM_UNIV_ELEMENT)) - stats: Dict[str, Any] = {y: {'_array': []} for y in y_keys} - - for i, graph in tqdm( - enumerate(graph_list), desc='run_stat', total=len(graph_list) - ): - z_tensor = graph[KEY.ATOMIC_NUMBERS] - natoms_counter.update(z_tensor.tolist()) - composition[i] = torch.bincount(z_tensor, minlength=NUM_UNIV_ELEMENT) - n_neigh.append(torch.unique(graph[KEY.EDGE_IDX][0], return_counts=True)[1]) - for y, dct in stats.items(): - dct['_array'].append( - graph[y].reshape( - -1, - ) - ) - - stats.update({'num_neighbor': {'_array': n_neigh}}) - for y, dct in stats.items(): - array = torch.cat(dct['_array']) - if array.dtype == torch.int64: # because of n_neigh - array = array.to(torch.float) - try: - median = torch.quantile(array, q=0.5) - except RuntimeError: - warnings.warn(f'skip median due to too large tensor size: {y}') - median = torch.nan - dct.update( - { - 'mean': float(torch.mean(array)), - 'std': float(torch.std(array, correction=0)), - 'median': float(median), - 'max': float(torch.max(array)), - 'min': float(torch.min(array)), - 'count': array.numel(), - '_array': array, - } - ) - - natoms = {chemical_symbols[int(z)]: cnt for z, cnt in natoms_counter.items()} - natoms['total'] = sum(list(natoms.values())) - stats.update({'_composition': composition, 'natoms': natoms}) - return stats - - -def _elemwise_reference_energies(composition: np.ndarray, energies: np.ndarray): - from sklearn.linear_model import Ridge - - c = composition - y = energies - zero_indices = np.all(c == 0, axis=0) - c_reduced = c[:, ~zero_indices] - # will not 100% reproduce, as it is sorted by Z - # train/dataset.py was sorted by alphabets of chemical species - coef_reduced = Ridge(alpha=0.1, fit_intercept=False).fit(c_reduced, y).coef_ - full_coeff = np.zeros(NUM_UNIV_ELEMENT) - full_coeff[~zero_indices] = coef_reduced - return full_coeff.tolist() # ex: full_coeff[1] = H_reference_energy - - -class SevenNetGraphDataset(InMemoryDataset): - """ - Replacement of AtomGraphDataset. (and .sevenn_data) - Extends InMemoryDataset of PyG. From given 'files', and 'cutoff', - build graphs for training SevenNet model. Preprocessed graphs are saved to - f'{root}/sevenn_data/{processed_name}.pt - - TODO: Save meta info (cutoff) by overriding .save and .load - TODO: 'tag' is not used yet, but initialized - 'tag' is replacement for 'label', and each datapoint has it as integer - 'tag' is usually parsed from if the structure_list of load_dataset - - Args: - root: path to save/load processed PyG dataset - cutoff: edge cutoff of given AtomGraphData - files: list of filenames or dict describing how to parse the file - ASE readable (with proper extension), structure_list, .sevenn_data, - dict containing file_list (see dict_reader of train/dataload.py) - process_num_cores: # of cpu cores to build graph - processed_name: save as {root}/sevenn_data/{processed_name}.pt - pre_transfrom: optional transform for each graph: def (graph) -> graph - pre_filter: optional filtering function for each graph: def (graph) -> graph - force_reload: if True, reload dataset from files even if there exist - {root}/sevenn_data/{processed_name} - **process_kwargs: keyword arguments that will be passed into ase.io.read - """ - - def __init__( - self, - cutoff: float, - root: Optional[str] = None, - files: Optional[Union[str, List[Any]]] = None, - process_num_cores: int = 1, - processed_name: str = 'graph.pt', - transform: Optional[Callable] = None, - pre_transform: Optional[Callable] = None, - pre_filter: Optional[Callable] = None, - use_data_weight: bool = False, - log: bool = True, - force_reload: bool = False, - drop_info: bool = True, - **process_kwargs, - ): - self.cutoff = cutoff - if files is None: - files = [] - elif isinstance(files, str): - files = [files] # user convenience - - _files = [] - for f in files: - if isinstance(f, str): - f = os.path.abspath(f) - _files.append(f) - self._files = _files - - self._full_file_list = [] - if not processed_name.endswith('.pt'): - processed_name += '.pt' - self._processed_names = [ - processed_name, # {root}/sevenn_data/{name}.pt - processed_name.replace('.pt', '.yaml'), - ] - - root = root or './' - _pdir = os.path.join(root, 'sevenn_data') - _pt = os.path.join(_pdir, self._processed_names[0]) - if not os.path.exists(_pt) and len(self._files) == 0: - raise ValueError( - ( - f'{_pt} not found and no files to process. ' - + 'If you copied only .pt file, please copy ' - + 'whole sevenn_data dir without changing its name.' - + ' They all work together.' - ) - ) - - _yam = os.path.join(_pdir, self._processed_names[1]) - if not os.path.exists(_yam) and len(self._files) == 0: - raise ValueError(f'{_yam} not found and no files to process') - - self.process_num_cores = process_num_cores - self.process_kwargs = process_kwargs - self.use_data_weight = use_data_weight - self.drop_info = drop_info - - self.tag_map = {} - self.statistics = {} - self.finalized = False - - super().__init__( - root, - transform, - pre_transform, - pre_filter, - log=log, - force_reload=force_reload, - ) # Internally calls 'process' - self.load(self.processed_paths[0]) # load pt, saved after process - - def load(self, path: str, data_cls=Data) -> None: - super().load(path, data_cls) - - if len(self) == 0: - warnings.warn(f'No graphs found {self.processed_paths[0]}') - if len(self.statistics) == 0: - # dataset is loaded from existing pt file. - self._load_meta() - - def _load_meta(self) -> None: - with open(self.processed_paths[1], 'r') as f: - meta = yaml.safe_load(f) - - if meta['sevennet_version'] == '0.10.0': - self._save_meta(list(self)) - with open(self.processed_paths[1], 'r') as f: - meta = yaml.safe_load(f) - - cutoff = float(meta['cutoff']) - if float(meta['cutoff']) != self.cutoff: - warnings.warn( - ( - 'Loaded dataset is built with different cutoff length: ' - + f'{cutoff} != {self.cutoff}, dataset cutoff will be' - + f' overwritten to {cutoff}' - ) - ) - self.cutoff = cutoff - self._files = meta['files'] - self.statistics = meta['statistics'] - - def __getitem__(self, idx): - graph = super().__getitem__(idx) - if self.drop_info: - graph.pop(KEY.INFO, None) # type: ignore - return graph - - @property - def raw_file_names(self) -> List[Any]: - return self._files - - @property - def processed_file_names(self) -> List[str]: - return self._processed_names - - @property - def processed_dir(self) -> str: - return os.path.join(self.root, 'sevenn_data') - - @property - def full_file_list(self) -> Union[List[str], None]: - return self._full_file_list - - def process(self): - graph_list: List[AtomGraphData] = [] - for file in self.raw_file_names: - tmplist = SevenNetGraphDataset.file_to_graph_list( - file=file, - cutoff=self.cutoff, - num_cores=self.process_num_cores, - **self.process_kwargs, - ) - if isinstance(file, str) and self._full_file_list is not None: - self._full_file_list.extend([os.path.abspath(file)] * len(tmplist)) - else: - self._full_file_list = None - graph_list.extend(tmplist) - - processed_graph_list = [] - for data in graph_list: - if self.pre_filter is not None and not self.pre_filter(data): - continue - if self.pre_transform is not None: - data = self.pre_transform(data) - if self.use_data_weight: - # pop data weight from info, and assign to graph - weight = data[KEY.INFO].pop( - KEY.DATA_WEIGHT, {'energy': 1.0, 'force': 1.0, 'stress': 1.0} - ) - data[KEY.DATA_WEIGHT] = weight - processed_graph_list.append(data) - - if len(processed_graph_list) == 0: - # Can not save at all if there is no graph (error in PyG), raise an error - raise ValueError('Zero graph found after filtering') - - # save graphs, handled by torch_geometrics - self.save(processed_graph_list, self.processed_paths[0]) - self._save_meta(processed_graph_list) - if self.log: - Logger().writeline(f'Dataset is saved: {self.processed_paths[0]}') - - def _save_meta(self, graph_list) -> None: - stats = _run_stat(graph_list) - stats['elemwise_reference_energies'] = _elemwise_reference_energies( - stats['_composition'].numpy(), stats[KEY.ENERGY]['_array'].numpy() - ) - self.statistics = stats - - stats_save = {} - for label, dct in self.statistics.items(): - if label.startswith('_'): - continue - stats_save[label] = {} - if not isinstance(dct, dict): - stats_save[label] = dct - else: - for k, v in dct.items(): - if k.startswith('_'): - continue - stats_save[label][k] = v - - meta = { - 'sevennet_version': __version__, - 'cutoff': self.cutoff, - 'when': datetime.now().strftime('%Y-%m-%d %H:%M'), - 'files': self._files, - 'statistics': stats_save, - 'species': self.species, - 'num_graphs': self.statistics[KEY.ENERGY]['count'], - 'per_atom_energy_mean': self.per_atom_energy_mean, - 'force_rms': self.force_rms, - 'per_atom_energy_std': self.per_atom_energy_std, - 'avg_num_neigh': self.avg_num_neigh, - 'sqrt_avg_num_neigh': self.sqrt_avg_num_neigh, - } - - with open(self.processed_paths[1], 'w') as f: - yaml.dump(meta, f, default_flow_style=False) - - @property - def species(self): - return [z for z in self.statistics['natoms'].keys() if z != 'total'] - - @property - def natoms(self): - return self.statistics['natoms'] - - @property - def per_atom_energy_mean(self): - return self.statistics[KEY.PER_ATOM_ENERGY]['mean'] - - @property - def elemwise_reference_energies(self): - return self.statistics['elemwise_reference_energies'] - - @property - def force_rms(self): - mean = self.statistics[KEY.FORCE]['mean'] - std = self.statistics[KEY.FORCE]['std'] - return float((mean**2 + std**2) ** (0.5)) - - @property - def per_atom_energy_std(self): - return self.statistics['per_atom_energy']['std'] - - @property - def avg_num_neigh(self): - return self.statistics['num_neighbor']['mean'] - - @property - def sqrt_avg_num_neigh(self): - return self.avg_num_neigh**0.5 - - @staticmethod - def _read_sevenn_data(filename: str) -> Tuple[List[AtomGraphData], float]: - # backward compatibility - from sevenn.train.dataset import AtomGraphDataset - - dataset = torch.load(filename, map_location='cpu', weights_only=False) - if isinstance(dataset, AtomGraphDataset): - graph_list = [] - for _, graphs in dataset.dataset.items(): # type: ignore - # TODO: transfer label to tag (who gonna need this?) - graph_list.extend(graphs) - return graph_list, dataset.cutoff - else: - raise ValueError(f'Not sevenn_data type: {type(dataset)}') - - @staticmethod - def _read_structure_list( - filename: str, cutoff: float, num_cores: int = 1 - ) -> List[AtomGraphData]: - datadct = dataload.structure_list_reader(filename) - graph_list = [] - for tag, atoms_list in datadct.items(): - tmp = dataload.graph_build(atoms_list, cutoff, num_cores) - graph_list.extend(_tag_graphs(tmp, tag)) - return graph_list - - @staticmethod - def _read_ase_readable( - filename: str, - cutoff: float, - num_cores: int = 1, - tag: str = '', - transfer_info: bool = True, - allow_unlabeled: bool = False, - **ase_kwargs, - ) -> List[AtomGraphData]: - pbc_override = ase_kwargs.pop('pbc', None) - atoms_list = dataload.ase_reader(filename, **ase_kwargs) - for atoms in atoms_list: - if pbc_override is not None: - atoms.pbc = pbc_override - graph_list = dataload.graph_build( - atoms_list, - cutoff, - num_cores, - transfer_info=transfer_info, - allow_unlabeled=allow_unlabeled, - ) - if tag != '': - graph_list = _tag_graphs(graph_list, tag) - return graph_list - - @staticmethod - def _read_graph_dataset( - filename: str, cutoff: float, **kwargs - ) -> List[AtomGraphData]: - meta_f = filename.replace('.pt', '.yaml') - orig_cutoff = cutoff - if not os.path.exists(filename): - raise FileNotFoundError(f'No such file: {filename}') - if not os.path.exists(meta_f): - warnings.warn('No meta info found, beware of cutoff...') - else: - with open(meta_f, 'r') as f: - meta = yaml.safe_load(f) - orig_cutoff = float(meta['cutoff']) - if orig_cutoff != cutoff: - warnings.warn( - f'{filename} has different cutoff length: ' - + f'{cutoff} != {orig_cutoff}' - ) - ds_args: dict[str, Any] = dict({'cutoff': orig_cutoff}) - ds_args.update(pt_to_args(filename)) - ds_args.update(kwargs) - dataset = SevenNetGraphDataset(**ds_args) - # TODO: hard coded. consult with inference.py - glist = [g.fit_dimension() for g in dataset] # type: ignore - for g in glist: - if KEY.STRESS in g: - # (1, 6) is what we want - g[KEY.STRESS] = g[KEY.STRESS].unsqueeze(0) - return glist - - @staticmethod - def _read_dict( - data_dict: dict, - cutoff: float, - num_cores: int = 1, - ): - # logic same as the dataload dict_reader, but handles graphs - data_dict_cp = deepcopy(data_dict) - file_list = data_dict_cp.get('file_list', None) - if file_list is None: - raise KeyError('file_list is not found') - - data_weight_default = { - 'energy': 1.0, - 'force': 1.0, - 'stress': 1.0, - } - data_weight = data_weight_default.copy() - data_weight.update(data_dict_cp.pop(KEY.DATA_WEIGHT, {})) - - graph_list = [] - for file_dct in file_list: - ftype = file_dct.pop('data_format', 'ase') - if ftype != 'graph': - continue - graph_list.extend( - SevenNetGraphDataset._read_graph_dataset( - file_dct.get('file'), cutoff=cutoff - ) - ) - for graph in graph_list: - if KEY.INFO not in graph: - graph[KEY.INFO] = {} - graph[KEY.INFO].update(data_dict_cp) - graph[KEY.INFO].update({KEY.DATA_WEIGHT: data_weight}) - - atoms_list = dataload.dict_reader(data_dict) - graph_list.extend(dataload.graph_build(atoms_list, cutoff, num_cores)) - return graph_list - - @staticmethod - def file_to_graph_list( - file: Union[str, dict], cutoff: float, num_cores: int = 1, **kwargs - ) -> List[AtomGraphData]: - """ - kwargs: if file is ase readable, passed to ase.io.read - """ - if isinstance(file, str) and not os.path.isfile(file): - raise ValueError(f'No such file: {file}') - graph_list: List[AtomGraphData] - if isinstance(file, dict): - graph_list = SevenNetGraphDataset._read_dict( - file, cutoff, num_cores, **kwargs - ) - elif file.endswith('.pt'): - graph_list = SevenNetGraphDataset._read_graph_dataset(file, cutoff) - elif file.endswith('.sevenn_data'): - graph_list, cutoff_other = SevenNetGraphDataset._read_sevenn_data(file) - if cutoff_other != cutoff: - warnings.warn(f'Given {file} has different {cutoff_other}!') - cutoff = cutoff_other - elif 'structure_list' in file: - graph_list = SevenNetGraphDataset._read_structure_list( - file, cutoff, num_cores - ) - else: - graph_list = SevenNetGraphDataset._read_ase_readable( - file, cutoff, num_cores, **kwargs - ) - return graph_list - - -def from_single_path( - path: Union[str, List], override_data_weight: bool = True, **dataset_kwargs -) -> Union[SevenNetGraphDataset, None]: - """ - Convenient routine for loading a single .pt dataset. - If given dict and it has data_weight, apply it using transform - """ - data_weight = {'energy': 1.0, 'force': 1.0, 'stress': 1.0} - spath = _extract_single_path(path) - if spath is None: - return None - - if isinstance(spath, str): - if not spath.endswith('.pt'): - return None - dataset_kwargs.update(pt_to_args(spath)) - elif isinstance(spath, dict): - file = _extract_file_from_dict(spath) - if file is None or not file.endswith('.pt'): - return None - dataset_kwargs.update(pt_to_args(file)) - data_weight_user = spath.get(KEY.DATA_WEIGHT, None) - if data_weight_user is not None: - data_weight.update(data_weight_user) - else: - return None - - if override_data_weight: - dataset_kwargs['transform'] = _chain_data_weight_override( - dataset_kwargs.get('transform'), data_weight - ) - - return SevenNetGraphDataset(**dataset_kwargs) - - -def _extract_single_path(path: Union[str, List]) -> Union[str, dict, None]: - """Extracts a single path from the input, - ensuring it's either a single string or list with one item.""" - if isinstance(path, list): - return path[0] if len(path) == 1 else None - return path if isinstance(path, (str, dict)) else None - - -def _extract_file_from_dict(path_dict: dict) -> Union[str, None]: - """Extracts a single file path from the dictionary, ensuring it's valid.""" - file_list = path_dict.get('file_list', None) - if file_list and len(file_list) == 1: - file = file_list[0].get('file', None) - return file if isinstance(file, str) else None - return None - - -def _chain_data_weight_override(transform_func, data_weight): - """Creates a transform function that overrides the data weight.""" - - def chained_transform(graph): - graph = transform_func(graph) if transform_func is not None else graph - graph[KEY.INFO].pop(KEY.DATA_WEIGHT, None) - graph[KEY.DATA_WEIGHT] = data_weight - return graph - - return chained_transform - - -# script, return dict of SevenNetGraphDataset -def from_config( - config: Dict[str, Any], - working_dir: str = os.getcwd(), - dataset_keys: Optional[List[str]] = None, -): - log = Logger() - if dataset_keys is None: - dataset_keys = [] - for k in config: - if k.startswith('load_') and k.endswith('_path'): - dataset_keys.append(k) - - if KEY.LOAD_TRAINSET not in dataset_keys: - raise ValueError(f'{KEY.LOAD_TRAINSET} must be present in config') - - # initialize arguments for loading dataset - dataset_args = { - 'cutoff': config[KEY.CUTOFF], - 'root': working_dir, - 'process_num_cores': config.get(KEY.PREPROCESS_NUM_CORES, 1), - 'use_data_weight': config.get(KEY.USE_WEIGHT, False), - **config.get(KEY.DATA_FORMAT_ARGS, {}), - } - - datasets = {} - for dk in dataset_keys: - if not (paths := config[dk]): - continue - if isinstance(paths, str): - paths = [paths] - name = '_'.join([nn.strip() for nn in dk.split('_')[1:-1]]) - if (dataset := from_single_path(paths, **dataset_args)) is not None: - datasets[name] = dataset - else: - dataset_args.update({'files': paths, 'processed_name': name}) - dataset_path = os.path.join(working_dir, 'sevenn_data', f'{name}.pt') - if os.path.exists(dataset_path) and 'force_reload' not in dataset_args: - log.writeline( - f'Dataset will be loaded from {dataset_path}, without update. ' - + 'If you have changed your files to read, put force_reload=True' - + ' under the data_format_args key' - ) - datasets[name] = SevenNetGraphDataset(**dataset_args) - - train_set = datasets['trainset'] - - chem_species = set(train_set.species) - # print statistics of each dataset - for name, dataset in datasets.items(): - log.bar() - log.writeline(f'{name} distribution:') - log.statistic_write(dataset.statistics) - log.format_k_v('# structures (graph)', len(dataset), write=True) - - chem_species.update(dataset.species) - log.bar() - - # initialize known species from dataset if 'auto' - # sorted to alphabetical order (which is same as before) - chem_keys = [KEY.CHEMICAL_SPECIES, KEY.NUM_SPECIES, KEY.TYPE_MAP] - if all([config[ck] == 'auto' for ck in chem_keys]): # see parse_input.py - log.writeline('Known species are obtained from the dataset') - config.update(util.chemical_species_preprocess(sorted(list(chem_species)))) - - # retrieve shift, scale, conv_denominaotrs from user input (keyword) - init_from_stats = [KEY.SHIFT, KEY.SCALE, KEY.CONV_DENOMINATOR] - for k in init_from_stats: - input = config[k] # statistic key or numbers - # If it is not 'str', 1: It is 'continue' training - # 2: User manually inserted numbers - if isinstance(input, str) and hasattr(train_set, input): - var = getattr(train_set, input) - config.update({k: var}) - log.writeline(f'{k} is obtained from statistics') - elif isinstance(input, str) and not hasattr(train_set, input): - raise NotImplementedError(input) - - if 'validset' not in datasets and config.get(KEY.RATIO, 0.0) > 0.0: - log.writeline('Use validation set as random split from the training set') - log.writeline( - 'Note that statistics, shift, scale, and conv_denominator are ' - + 'computed before random split.\n If you want these after random ' - + 'split, please preprocess dataset and set it as load_trainset_path ' - + 'and load_validset_path explicitly.' - ) - - ratio = float(config[KEY.RATIO]) - train, valid = torch.utils.data.random_split( - datasets['trainset'], (1.0 - ratio, ratio) - ) - datasets['trainset'] = train - datasets['validset'] = valid - - return datasets +import os +import warnings +from collections import Counter +from copy import deepcopy +from datetime import datetime +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.serialization +import torch.utils.data +import yaml +from ase.data import chemical_symbols +from torch_geometric.data import Data +from torch_geometric.data.in_memory_dataset import InMemoryDataset +from tqdm import tqdm + +import sevenn._keys as KEY +import sevenn.train.dataload as dataload +import sevenn.util as util +from sevenn import __version__ +from sevenn._const import NUM_UNIV_ELEMENT +from sevenn.atom_graph_data import AtomGraphData +from sevenn.logger import Logger + +if torch.__version__.split()[0] >= '2.4.0': + # load graph without error + torch.serialization.add_safe_globals([AtomGraphData]) + +# warning from PyG, for later torch versions +warnings.filterwarnings( + 'ignore', + message='You are using `torch.load` with `weights_only=False`', +) + + +def _tag_graphs(graph_list: List[AtomGraphData], tag: str): + """ + WIP: To be used + """ + for g in graph_list: + g[KEY.TAG] = tag + return graph_list + + +def pt_to_args(pt_filename: str): + """ + Return arg dict of root and processed_name from path to .pt + Usage: + dataset = SevenNetGraphDataset( + **pt_to_args({path}/sevenn_data/dataset.pt) + ) + """ + processed_dir, basename = os.path.split(pt_filename) + return { + 'root': os.path.dirname(processed_dir), + 'processed_name': os.path.basename(basename), + } + + +def _run_stat( + graph_list, + y_keys: List[str] = [KEY.ENERGY, KEY.PER_ATOM_ENERGY, KEY.FORCE, KEY.STRESS], +) -> Dict[str, Any]: + """ + Loop over dataset and init any statistics might need + """ + n_neigh = [] + natoms_counter = Counter() + composition = torch.zeros((len(graph_list), NUM_UNIV_ELEMENT)) + stats: Dict[str, Any] = {y: {'_array': []} for y in y_keys} + + for i, graph in tqdm( + enumerate(graph_list), desc='run_stat', total=len(graph_list) + ): + z_tensor = graph[KEY.ATOMIC_NUMBERS] + natoms_counter.update(z_tensor.tolist()) + composition[i] = torch.bincount(z_tensor, minlength=NUM_UNIV_ELEMENT) + n_neigh.append(torch.unique(graph[KEY.EDGE_IDX][0], return_counts=True)[1]) + for y, dct in stats.items(): + dct['_array'].append( + graph[y].reshape( + -1, + ) + ) + + stats.update({'num_neighbor': {'_array': n_neigh}}) + for y, dct in stats.items(): + array = torch.cat(dct['_array']) + if array.dtype == torch.int64: # because of n_neigh + array = array.to(torch.float) + try: + median = torch.quantile(array, q=0.5) + except RuntimeError: + warnings.warn(f'skip median due to too large tensor size: {y}') + median = torch.nan + dct.update( + { + 'mean': float(torch.mean(array)), + 'std': float(torch.std(array, correction=0)), + 'median': float(median), + 'max': float(torch.max(array)), + 'min': float(torch.min(array)), + 'count': array.numel(), + '_array': array, + } + ) + + natoms = {chemical_symbols[int(z)]: cnt for z, cnt in natoms_counter.items()} + natoms['total'] = sum(list(natoms.values())) + stats.update({'_composition': composition, 'natoms': natoms}) + return stats + + +def _elemwise_reference_energies(composition: np.ndarray, energies: np.ndarray): + from sklearn.linear_model import Ridge + + c = composition + y = energies + zero_indices = np.all(c == 0, axis=0) + c_reduced = c[:, ~zero_indices] + # will not 100% reproduce, as it is sorted by Z + # train/dataset.py was sorted by alphabets of chemical species + coef_reduced = Ridge(alpha=0.1, fit_intercept=False).fit(c_reduced, y).coef_ + full_coeff = np.zeros(NUM_UNIV_ELEMENT) + full_coeff[~zero_indices] = coef_reduced + return full_coeff.tolist() # ex: full_coeff[1] = H_reference_energy + + +class SevenNetGraphDataset(InMemoryDataset): + """ + Replacement of AtomGraphDataset. (and .sevenn_data) + Extends InMemoryDataset of PyG. From given 'files', and 'cutoff', + build graphs for training SevenNet model. Preprocessed graphs are saved to + f'{root}/sevenn_data/{processed_name}.pt + + TODO: Save meta info (cutoff) by overriding .save and .load + TODO: 'tag' is not used yet, but initialized + 'tag' is replacement for 'label', and each datapoint has it as integer + 'tag' is usually parsed from if the structure_list of load_dataset + + Args: + root: path to save/load processed PyG dataset + cutoff: edge cutoff of given AtomGraphData + files: list of filenames or dict describing how to parse the file + ASE readable (with proper extension), structure_list, .sevenn_data, + dict containing file_list (see dict_reader of train/dataload.py) + process_num_cores: # of cpu cores to build graph + processed_name: save as {root}/sevenn_data/{processed_name}.pt + pre_transfrom: optional transform for each graph: def (graph) -> graph + pre_filter: optional filtering function for each graph: def (graph) -> graph + force_reload: if True, reload dataset from files even if there exist + {root}/sevenn_data/{processed_name} + **process_kwargs: keyword arguments that will be passed into ase.io.read + """ + + def __init__( + self, + cutoff: float, + root: Optional[str] = None, + files: Optional[Union[str, List[Any]]] = None, + process_num_cores: int = 1, + processed_name: str = 'graph.pt', + transform: Optional[Callable] = None, + pre_transform: Optional[Callable] = None, + pre_filter: Optional[Callable] = None, + use_data_weight: bool = False, + log: bool = True, + force_reload: bool = False, + drop_info: bool = True, + **process_kwargs, + ): + self.cutoff = cutoff + if files is None: + files = [] + elif isinstance(files, str): + files = [files] # user convenience + + _files = [] + for f in files: + if isinstance(f, str): + f = os.path.abspath(f) + _files.append(f) + self._files = _files + + self._full_file_list = [] + if not processed_name.endswith('.pt'): + processed_name += '.pt' + self._processed_names = [ + processed_name, # {root}/sevenn_data/{name}.pt + processed_name.replace('.pt', '.yaml'), + ] + + root = root or './' + _pdir = os.path.join(root, 'sevenn_data') + _pt = os.path.join(_pdir, self._processed_names[0]) + if not os.path.exists(_pt) and len(self._files) == 0: + raise ValueError( + ( + f'{_pt} not found and no files to process. ' + + 'If you copied only .pt file, please copy ' + + 'whole sevenn_data dir without changing its name.' + + ' They all work together.' + ) + ) + + _yam = os.path.join(_pdir, self._processed_names[1]) + if not os.path.exists(_yam) and len(self._files) == 0: + raise ValueError(f'{_yam} not found and no files to process') + + self.process_num_cores = process_num_cores + self.process_kwargs = process_kwargs + self.use_data_weight = use_data_weight + self.drop_info = drop_info + + self.tag_map = {} + self.statistics = {} + self.finalized = False + + super().__init__( + root, + transform, + pre_transform, + pre_filter, + log=log, + force_reload=force_reload, + ) # Internally calls 'process' + self.load(self.processed_paths[0]) # load pt, saved after process + + def load(self, path: str, data_cls=Data) -> None: + super().load(path, data_cls) + + if len(self) == 0: + warnings.warn(f'No graphs found {self.processed_paths[0]}') + if len(self.statistics) == 0: + # dataset is loaded from existing pt file. + self._load_meta() + + def _load_meta(self) -> None: + with open(self.processed_paths[1], 'r') as f: + meta = yaml.safe_load(f) + + if meta['sevennet_version'] == '0.10.0': + self._save_meta(list(self)) + with open(self.processed_paths[1], 'r') as f: + meta = yaml.safe_load(f) + + cutoff = float(meta['cutoff']) + if float(meta['cutoff']) != self.cutoff: + warnings.warn( + ( + 'Loaded dataset is built with different cutoff length: ' + + f'{cutoff} != {self.cutoff}, dataset cutoff will be' + + f' overwritten to {cutoff}' + ) + ) + self.cutoff = cutoff + self._files = meta['files'] + self.statistics = meta['statistics'] + + def __getitem__(self, idx): + graph = super().__getitem__(idx) + if self.drop_info: + graph.pop(KEY.INFO, None) # type: ignore + return graph + + @property + def raw_file_names(self) -> List[Any]: + return self._files + + @property + def processed_file_names(self) -> List[str]: + return self._processed_names + + @property + def processed_dir(self) -> str: + return os.path.join(self.root, 'sevenn_data') + + @property + def full_file_list(self) -> Union[List[str], None]: + return self._full_file_list + + def process(self): + graph_list: List[AtomGraphData] = [] + for file in self.raw_file_names: + tmplist = SevenNetGraphDataset.file_to_graph_list( + file=file, + cutoff=self.cutoff, + num_cores=self.process_num_cores, + **self.process_kwargs, + ) + if isinstance(file, str) and self._full_file_list is not None: + self._full_file_list.extend([os.path.abspath(file)] * len(tmplist)) + else: + self._full_file_list = None + graph_list.extend(tmplist) + + processed_graph_list = [] + for data in graph_list: + if self.pre_filter is not None and not self.pre_filter(data): + continue + if self.pre_transform is not None: + data = self.pre_transform(data) + if self.use_data_weight: + # pop data weight from info, and assign to graph + weight = data[KEY.INFO].pop( + KEY.DATA_WEIGHT, {'energy': 1.0, 'force': 1.0, 'stress': 1.0} + ) + data[KEY.DATA_WEIGHT] = weight + processed_graph_list.append(data) + + if len(processed_graph_list) == 0: + # Can not save at all if there is no graph (error in PyG), raise an error + raise ValueError('Zero graph found after filtering') + + # save graphs, handled by torch_geometrics + self.save(processed_graph_list, self.processed_paths[0]) + self._save_meta(processed_graph_list) + if self.log: + Logger().writeline(f'Dataset is saved: {self.processed_paths[0]}') + + def _save_meta(self, graph_list) -> None: + stats = _run_stat(graph_list) + stats['elemwise_reference_energies'] = _elemwise_reference_energies( + stats['_composition'].numpy(), stats[KEY.ENERGY]['_array'].numpy() + ) + self.statistics = stats + + stats_save = {} + for label, dct in self.statistics.items(): + if label.startswith('_'): + continue + stats_save[label] = {} + if not isinstance(dct, dict): + stats_save[label] = dct + else: + for k, v in dct.items(): + if k.startswith('_'): + continue + stats_save[label][k] = v + + meta = { + 'sevennet_version': __version__, + 'cutoff': self.cutoff, + 'when': datetime.now().strftime('%Y-%m-%d %H:%M'), + 'files': self._files, + 'statistics': stats_save, + 'species': self.species, + 'num_graphs': self.statistics[KEY.ENERGY]['count'], + 'per_atom_energy_mean': self.per_atom_energy_mean, + 'force_rms': self.force_rms, + 'per_atom_energy_std': self.per_atom_energy_std, + 'avg_num_neigh': self.avg_num_neigh, + 'sqrt_avg_num_neigh': self.sqrt_avg_num_neigh, + } + + with open(self.processed_paths[1], 'w') as f: + yaml.dump(meta, f, default_flow_style=False) + + @property + def species(self): + return [z for z in self.statistics['natoms'].keys() if z != 'total'] + + @property + def natoms(self): + return self.statistics['natoms'] + + @property + def per_atom_energy_mean(self): + return self.statistics[KEY.PER_ATOM_ENERGY]['mean'] + + @property + def elemwise_reference_energies(self): + return self.statistics['elemwise_reference_energies'] + + @property + def force_rms(self): + mean = self.statistics[KEY.FORCE]['mean'] + std = self.statistics[KEY.FORCE]['std'] + return float((mean**2 + std**2) ** (0.5)) + + @property + def per_atom_energy_std(self): + return self.statistics['per_atom_energy']['std'] + + @property + def avg_num_neigh(self): + return self.statistics['num_neighbor']['mean'] + + @property + def sqrt_avg_num_neigh(self): + return self.avg_num_neigh**0.5 + + @staticmethod + def _read_sevenn_data(filename: str) -> Tuple[List[AtomGraphData], float]: + # backward compatibility + from sevenn.train.dataset import AtomGraphDataset + + dataset = torch.load(filename, map_location='cpu', weights_only=False) + if isinstance(dataset, AtomGraphDataset): + graph_list = [] + for _, graphs in dataset.dataset.items(): # type: ignore + # TODO: transfer label to tag (who gonna need this?) + graph_list.extend(graphs) + return graph_list, dataset.cutoff + else: + raise ValueError(f'Not sevenn_data type: {type(dataset)}') + + @staticmethod + def _read_structure_list( + filename: str, cutoff: float, num_cores: int = 1 + ) -> List[AtomGraphData]: + datadct = dataload.structure_list_reader(filename) + graph_list = [] + for tag, atoms_list in datadct.items(): + tmp = dataload.graph_build(atoms_list, cutoff, num_cores) + graph_list.extend(_tag_graphs(tmp, tag)) + return graph_list + + @staticmethod + def _read_ase_readable( + filename: str, + cutoff: float, + num_cores: int = 1, + tag: str = '', + transfer_info: bool = True, + allow_unlabeled: bool = False, + **ase_kwargs, + ) -> List[AtomGraphData]: + pbc_override = ase_kwargs.pop('pbc', None) + atoms_list = dataload.ase_reader(filename, **ase_kwargs) + for atoms in atoms_list: + if pbc_override is not None: + atoms.pbc = pbc_override + graph_list = dataload.graph_build( + atoms_list, + cutoff, + num_cores, + transfer_info=transfer_info, + allow_unlabeled=allow_unlabeled, + ) + if tag != '': + graph_list = _tag_graphs(graph_list, tag) + return graph_list + + @staticmethod + def _read_graph_dataset( + filename: str, cutoff: float, **kwargs + ) -> List[AtomGraphData]: + meta_f = filename.replace('.pt', '.yaml') + orig_cutoff = cutoff + if not os.path.exists(filename): + raise FileNotFoundError(f'No such file: {filename}') + if not os.path.exists(meta_f): + warnings.warn('No meta info found, beware of cutoff...') + else: + with open(meta_f, 'r') as f: + meta = yaml.safe_load(f) + orig_cutoff = float(meta['cutoff']) + if orig_cutoff != cutoff: + warnings.warn( + f'{filename} has different cutoff length: ' + + f'{cutoff} != {orig_cutoff}' + ) + ds_args: dict[str, Any] = dict({'cutoff': orig_cutoff}) + ds_args.update(pt_to_args(filename)) + ds_args.update(kwargs) + dataset = SevenNetGraphDataset(**ds_args) + # TODO: hard coded. consult with inference.py + glist = [g.fit_dimension() for g in dataset] # type: ignore + for g in glist: + if KEY.STRESS in g: + # (1, 6) is what we want + g[KEY.STRESS] = g[KEY.STRESS].unsqueeze(0) + return glist + + @staticmethod + def _read_dict( + data_dict: dict, + cutoff: float, + num_cores: int = 1, + ): + # logic same as the dataload dict_reader, but handles graphs + data_dict_cp = deepcopy(data_dict) + file_list = data_dict_cp.get('file_list', None) + if file_list is None: + raise KeyError('file_list is not found') + + data_weight_default = { + 'energy': 1.0, + 'force': 1.0, + 'stress': 1.0, + } + data_weight = data_weight_default.copy() + data_weight.update(data_dict_cp.pop(KEY.DATA_WEIGHT, {})) + + graph_list = [] + for file_dct in file_list: + ftype = file_dct.pop('data_format', 'ase') + if ftype != 'graph': + continue + graph_list.extend( + SevenNetGraphDataset._read_graph_dataset( + file_dct.get('file'), cutoff=cutoff + ) + ) + for graph in graph_list: + if KEY.INFO not in graph: + graph[KEY.INFO] = {} + graph[KEY.INFO].update(data_dict_cp) + graph[KEY.INFO].update({KEY.DATA_WEIGHT: data_weight}) + + atoms_list = dataload.dict_reader(data_dict) + graph_list.extend(dataload.graph_build(atoms_list, cutoff, num_cores)) + return graph_list + + @staticmethod + def file_to_graph_list( + file: Union[str, dict], cutoff: float, num_cores: int = 1, **kwargs + ) -> List[AtomGraphData]: + """ + kwargs: if file is ase readable, passed to ase.io.read + """ + if isinstance(file, str) and not os.path.isfile(file): + raise ValueError(f'No such file: {file}') + graph_list: List[AtomGraphData] + if isinstance(file, dict): + graph_list = SevenNetGraphDataset._read_dict( + file, cutoff, num_cores, **kwargs + ) + elif file.endswith('.pt'): + graph_list = SevenNetGraphDataset._read_graph_dataset(file, cutoff) + elif file.endswith('.sevenn_data'): + graph_list, cutoff_other = SevenNetGraphDataset._read_sevenn_data(file) + if cutoff_other != cutoff: + warnings.warn(f'Given {file} has different {cutoff_other}!') + cutoff = cutoff_other + elif 'structure_list' in file: + graph_list = SevenNetGraphDataset._read_structure_list( + file, cutoff, num_cores + ) + else: + graph_list = SevenNetGraphDataset._read_ase_readable( + file, cutoff, num_cores, **kwargs + ) + return graph_list + + +def from_single_path( + path: Union[str, List], override_data_weight: bool = True, **dataset_kwargs +) -> Union[SevenNetGraphDataset, None]: + """ + Convenient routine for loading a single .pt dataset. + If given dict and it has data_weight, apply it using transform + """ + data_weight = {'energy': 1.0, 'force': 1.0, 'stress': 1.0} + spath = _extract_single_path(path) + if spath is None: + return None + + if isinstance(spath, str): + if not spath.endswith('.pt'): + return None + dataset_kwargs.update(pt_to_args(spath)) + elif isinstance(spath, dict): + file = _extract_file_from_dict(spath) + if file is None or not file.endswith('.pt'): + return None + dataset_kwargs.update(pt_to_args(file)) + data_weight_user = spath.get(KEY.DATA_WEIGHT, None) + if data_weight_user is not None: + data_weight.update(data_weight_user) + else: + return None + + if override_data_weight: + dataset_kwargs['transform'] = _chain_data_weight_override( + dataset_kwargs.get('transform'), data_weight + ) + + return SevenNetGraphDataset(**dataset_kwargs) + + +def _extract_single_path(path: Union[str, List]) -> Union[str, dict, None]: + """Extracts a single path from the input, + ensuring it's either a single string or list with one item.""" + if isinstance(path, list): + return path[0] if len(path) == 1 else None + return path if isinstance(path, (str, dict)) else None + + +def _extract_file_from_dict(path_dict: dict) -> Union[str, None]: + """Extracts a single file path from the dictionary, ensuring it's valid.""" + file_list = path_dict.get('file_list', None) + if file_list and len(file_list) == 1: + file = file_list[0].get('file', None) + return file if isinstance(file, str) else None + return None + + +def _chain_data_weight_override(transform_func, data_weight): + """Creates a transform function that overrides the data weight.""" + + def chained_transform(graph): + graph = transform_func(graph) if transform_func is not None else graph + graph[KEY.INFO].pop(KEY.DATA_WEIGHT, None) + graph[KEY.DATA_WEIGHT] = data_weight + return graph + + return chained_transform + + +# script, return dict of SevenNetGraphDataset +def from_config( + config: Dict[str, Any], + working_dir: str = os.getcwd(), + dataset_keys: Optional[List[str]] = None, +): + log = Logger() + if dataset_keys is None: + dataset_keys = [] + for k in config: + if k.startswith('load_') and k.endswith('_path'): + dataset_keys.append(k) + + if KEY.LOAD_TRAINSET not in dataset_keys: + raise ValueError(f'{KEY.LOAD_TRAINSET} must be present in config') + + # initialize arguments for loading dataset + dataset_args = { + 'cutoff': config[KEY.CUTOFF], + 'root': working_dir, + 'process_num_cores': config.get(KEY.PREPROCESS_NUM_CORES, 1), + 'use_data_weight': config.get(KEY.USE_WEIGHT, False), + **config.get(KEY.DATA_FORMAT_ARGS, {}), + } + + datasets = {} + for dk in dataset_keys: + if not (paths := config[dk]): + continue + if isinstance(paths, str): + paths = [paths] + name = '_'.join([nn.strip() for nn in dk.split('_')[1:-1]]) + if (dataset := from_single_path(paths, **dataset_args)) is not None: + datasets[name] = dataset + else: + dataset_args.update({'files': paths, 'processed_name': name}) + dataset_path = os.path.join(working_dir, 'sevenn_data', f'{name}.pt') + if os.path.exists(dataset_path) and 'force_reload' not in dataset_args: + log.writeline( + f'Dataset will be loaded from {dataset_path}, without update. ' + + 'If you have changed your files to read, put force_reload=True' + + ' under the data_format_args key' + ) + datasets[name] = SevenNetGraphDataset(**dataset_args) + + train_set = datasets['trainset'] + + chem_species = set(train_set.species) + # print statistics of each dataset + for name, dataset in datasets.items(): + log.bar() + log.writeline(f'{name} distribution:') + log.statistic_write(dataset.statistics) + log.format_k_v('# structures (graph)', len(dataset), write=True) + + chem_species.update(dataset.species) + log.bar() + + # initialize known species from dataset if 'auto' + # sorted to alphabetical order (which is same as before) + chem_keys = [KEY.CHEMICAL_SPECIES, KEY.NUM_SPECIES, KEY.TYPE_MAP] + if all([config[ck] == 'auto' for ck in chem_keys]): # see parse_input.py + log.writeline('Known species are obtained from the dataset') + config.update(util.chemical_species_preprocess(sorted(list(chem_species)))) + + # retrieve shift, scale, conv_denominaotrs from user input (keyword) + init_from_stats = [KEY.SHIFT, KEY.SCALE, KEY.CONV_DENOMINATOR] + for k in init_from_stats: + input = config[k] # statistic key or numbers + # If it is not 'str', 1: It is 'continue' training + # 2: User manually inserted numbers + if isinstance(input, str) and hasattr(train_set, input): + var = getattr(train_set, input) + config.update({k: var}) + log.writeline(f'{k} is obtained from statistics') + elif isinstance(input, str) and not hasattr(train_set, input): + raise NotImplementedError(input) + + if 'validset' not in datasets and config.get(KEY.RATIO, 0.0) > 0.0: + log.writeline('Use validation set as random split from the training set') + log.writeline( + 'Note that statistics, shift, scale, and conv_denominator are ' + + 'computed before random split.\n If you want these after random ' + + 'split, please preprocess dataset and set it as load_trainset_path ' + + 'and load_validset_path explicitly.' + ) + + ratio = float(config[KEY.RATIO]) + train, valid = torch.utils.data.random_split( + datasets['trainset'], (1.0 - ratio, ratio) + ) + datasets['trainset'] = train + datasets['validset'] = valid + + return datasets diff --git a/mace-bench/3rdparty/SevenNet/sevenn/train/loss.py b/mace-bench/3rdparty/SevenNet/sevenn/train/loss.py index a223c26..7aae162 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/train/loss.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/train/loss.py @@ -1,223 +1,223 @@ -from typing import Any, Callable, Dict, Optional, Tuple - -import torch - -import sevenn._keys as KEY - - -class LossDefinition: - """ - Base class for loss definition - weights are defined in outside of the class - """ - - def __init__( - self, - name: str, - unit: Optional[str] = None, - criterion: Optional[Callable] = None, - ref_key: Optional[str] = None, - pred_key: Optional[str] = None, - use_weight: bool = False, - ignore_unlabeled: bool = True, - ): - self.name = name - self.unit = unit - self.criterion = criterion - self.ref_key = ref_key - self.pred_key = pred_key - self.use_weight = use_weight - self.ignore_unlabeled = ignore_unlabeled - - def __repr__(self): - return self.name - - def assign_criteria(self, criterion: Callable): - if self.criterion is not None: - raise ValueError('Loss uses its own criterion.') - self.criterion = criterion - - def _preprocess( - self, batch_data: Dict[str, Any], model: Optional[Callable] = None - ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - if self.pred_key is None or self.ref_key is None: - raise NotImplementedError('LossDefinition is not implemented.') - pred = torch.reshape(batch_data[self.pred_key], (-1,)) - ref = torch.reshape(batch_data[self.ref_key], (-1,)) - return pred, ref, None - - def _ignore_unlabeled(self, pred, ref, data_weights=None): - unlabeled = torch.isnan(ref) - pred = pred[~unlabeled] - ref = ref[~unlabeled] - if data_weights is not None: - data_weights = data_weights[~unlabeled] - return pred, ref, data_weights - - def get_loss(self, batch_data: Dict[str, Any], model: Optional[Callable] = None): - """ - Function that return scalar - """ - if self.criterion is None: - raise NotImplementedError('LossDefinition has no criterion.') - pred, ref, w_tensor = self._preprocess(batch_data, model) - - if self.ignore_unlabeled: - pred, ref, w_tensor = self._ignore_unlabeled(pred, ref, w_tensor) - - if len(pred) == 0: - assert self.ref_key is not None - return torch.zeros(1, device=batch_data[self.ref_key].device) - - loss = self.criterion(pred, ref) - if self.use_weight: - loss = torch.mean(loss * w_tensor) - return loss - - -class PerAtomEnergyLoss(LossDefinition): - """ - Loss for per atom energy - """ - - def __init__( - self, - name: str = 'Energy', - unit: str = 'eV/atom', - criterion: Optional[Callable] = None, - ref_key: str = KEY.ENERGY, - pred_key: str = KEY.PRED_TOTAL_ENERGY, - **kwargs, - ): - super().__init__( - name=name, - unit=unit, - criterion=criterion, - ref_key=ref_key, - pred_key=pred_key, - **kwargs, - ) - - def _preprocess( - self, batch_data: Dict[str, Any], model: Optional[Callable] = None - ): - num_atoms = batch_data[KEY.NUM_ATOMS] - assert isinstance(self.pred_key, str) and isinstance(self.ref_key, str) - pred = batch_data[self.pred_key] / num_atoms - ref = batch_data[self.ref_key] / num_atoms - w_tensor = None - - if self.use_weight: - loss_type = self.name.lower() - weight = batch_data[KEY.DATA_WEIGHT][loss_type] - w_tensor = torch.repeat_interleave(weight, 1) - - return pred, ref, w_tensor - - -class ForceLoss(LossDefinition): - """ - Loss for force - """ - - def __init__( - self, - name: str = 'Force', - unit: str = 'eV/A', - criterion: Optional[Callable] = None, - ref_key: str = KEY.FORCE, - pred_key: str = KEY.PRED_FORCE, - **kwargs, - ): - super().__init__( - name=name, - unit=unit, - criterion=criterion, - ref_key=ref_key, - pred_key=pred_key, - **kwargs, - ) - - def _preprocess( - self, batch_data: Dict[str, Any], model: Optional[Callable] = None - ): - assert isinstance(self.pred_key, str) and isinstance(self.ref_key, str) - pred = torch.reshape(batch_data[self.pred_key], (-1,)) - ref = torch.reshape(batch_data[self.ref_key], (-1,)) - w_tensor = None - - if self.use_weight: - loss_type = self.name.lower() - weight = batch_data[KEY.DATA_WEIGHT][loss_type] - w_tensor = weight[batch_data[KEY.BATCH]] - w_tensor = torch.repeat_interleave(w_tensor, 3) - - return pred, ref, w_tensor - - -class StressLoss(LossDefinition): - """ - Loss for stress this is kbar - """ - - def __init__( - self, - name: str = 'Stress', - unit: str = 'kbar', - criterion: Optional[Callable] = None, - ref_key: str = KEY.STRESS, - pred_key: str = KEY.PRED_STRESS, - **kwargs, - ): - super().__init__( - name=name, - unit=unit, - criterion=criterion, - ref_key=ref_key, - pred_key=pred_key, - **kwargs, - ) - self.TO_KB = 1602.1766208 # eV/A^3 to kbar - - def _preprocess( - self, batch_data: Dict[str, Any], model: Optional[Callable] = None - ): - assert isinstance(self.pred_key, str) and isinstance(self.ref_key, str) - - pred = torch.reshape(batch_data[self.pred_key] * self.TO_KB, (-1,)) - ref = torch.reshape(batch_data[self.ref_key] * self.TO_KB, (-1,)) - w_tensor = None - - if self.use_weight: - loss_type = self.name.lower() - weight = batch_data[KEY.DATA_WEIGHT][loss_type] - w_tensor = torch.repeat_interleave(weight, 6) - - return pred, ref, w_tensor - - -def get_loss_functions_from_config(config: Dict[str, Any]): - from sevenn.train.optim import loss_dict - - loss_functions = [] # list of tuples (loss_definition, weight) - - loss = loss_dict[config[KEY.LOSS].lower()] - loss_param = config.get(KEY.LOSS_PARAM, {}) - - use_weight = config.get(KEY.USE_WEIGHT, False) - if use_weight: - loss_param['reduction'] = 'none' - criterion = loss(**loss_param) - - commons = {'use_weight': use_weight} - - loss_functions.append((PerAtomEnergyLoss(**commons), 1.0)) - loss_functions.append((ForceLoss(**commons), config[KEY.FORCE_WEIGHT])) - if config[KEY.IS_TRAIN_STRESS]: - loss_functions.append((StressLoss(**commons), config[KEY.STRESS_WEIGHT])) - - for loss_function, _ in loss_functions: # why do these? - if loss_function.criterion is None: - loss_function.assign_criteria(criterion) - - return loss_functions +from typing import Any, Callable, Dict, Optional, Tuple + +import torch + +import sevenn._keys as KEY + + +class LossDefinition: + """ + Base class for loss definition + weights are defined in outside of the class + """ + + def __init__( + self, + name: str, + unit: Optional[str] = None, + criterion: Optional[Callable] = None, + ref_key: Optional[str] = None, + pred_key: Optional[str] = None, + use_weight: bool = False, + ignore_unlabeled: bool = True, + ): + self.name = name + self.unit = unit + self.criterion = criterion + self.ref_key = ref_key + self.pred_key = pred_key + self.use_weight = use_weight + self.ignore_unlabeled = ignore_unlabeled + + def __repr__(self): + return self.name + + def assign_criteria(self, criterion: Callable): + if self.criterion is not None: + raise ValueError('Loss uses its own criterion.') + self.criterion = criterion + + def _preprocess( + self, batch_data: Dict[str, Any], model: Optional[Callable] = None + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + if self.pred_key is None or self.ref_key is None: + raise NotImplementedError('LossDefinition is not implemented.') + pred = torch.reshape(batch_data[self.pred_key], (-1,)) + ref = torch.reshape(batch_data[self.ref_key], (-1,)) + return pred, ref, None + + def _ignore_unlabeled(self, pred, ref, data_weights=None): + unlabeled = torch.isnan(ref) + pred = pred[~unlabeled] + ref = ref[~unlabeled] + if data_weights is not None: + data_weights = data_weights[~unlabeled] + return pred, ref, data_weights + + def get_loss(self, batch_data: Dict[str, Any], model: Optional[Callable] = None): + """ + Function that return scalar + """ + if self.criterion is None: + raise NotImplementedError('LossDefinition has no criterion.') + pred, ref, w_tensor = self._preprocess(batch_data, model) + + if self.ignore_unlabeled: + pred, ref, w_tensor = self._ignore_unlabeled(pred, ref, w_tensor) + + if len(pred) == 0: + assert self.ref_key is not None + return torch.zeros(1, device=batch_data[self.ref_key].device) + + loss = self.criterion(pred, ref) + if self.use_weight: + loss = torch.mean(loss * w_tensor) + return loss + + +class PerAtomEnergyLoss(LossDefinition): + """ + Loss for per atom energy + """ + + def __init__( + self, + name: str = 'Energy', + unit: str = 'eV/atom', + criterion: Optional[Callable] = None, + ref_key: str = KEY.ENERGY, + pred_key: str = KEY.PRED_TOTAL_ENERGY, + **kwargs, + ): + super().__init__( + name=name, + unit=unit, + criterion=criterion, + ref_key=ref_key, + pred_key=pred_key, + **kwargs, + ) + + def _preprocess( + self, batch_data: Dict[str, Any], model: Optional[Callable] = None + ): + num_atoms = batch_data[KEY.NUM_ATOMS] + assert isinstance(self.pred_key, str) and isinstance(self.ref_key, str) + pred = batch_data[self.pred_key] / num_atoms + ref = batch_data[self.ref_key] / num_atoms + w_tensor = None + + if self.use_weight: + loss_type = self.name.lower() + weight = batch_data[KEY.DATA_WEIGHT][loss_type] + w_tensor = torch.repeat_interleave(weight, 1) + + return pred, ref, w_tensor + + +class ForceLoss(LossDefinition): + """ + Loss for force + """ + + def __init__( + self, + name: str = 'Force', + unit: str = 'eV/A', + criterion: Optional[Callable] = None, + ref_key: str = KEY.FORCE, + pred_key: str = KEY.PRED_FORCE, + **kwargs, + ): + super().__init__( + name=name, + unit=unit, + criterion=criterion, + ref_key=ref_key, + pred_key=pred_key, + **kwargs, + ) + + def _preprocess( + self, batch_data: Dict[str, Any], model: Optional[Callable] = None + ): + assert isinstance(self.pred_key, str) and isinstance(self.ref_key, str) + pred = torch.reshape(batch_data[self.pred_key], (-1,)) + ref = torch.reshape(batch_data[self.ref_key], (-1,)) + w_tensor = None + + if self.use_weight: + loss_type = self.name.lower() + weight = batch_data[KEY.DATA_WEIGHT][loss_type] + w_tensor = weight[batch_data[KEY.BATCH]] + w_tensor = torch.repeat_interleave(w_tensor, 3) + + return pred, ref, w_tensor + + +class StressLoss(LossDefinition): + """ + Loss for stress this is kbar + """ + + def __init__( + self, + name: str = 'Stress', + unit: str = 'kbar', + criterion: Optional[Callable] = None, + ref_key: str = KEY.STRESS, + pred_key: str = KEY.PRED_STRESS, + **kwargs, + ): + super().__init__( + name=name, + unit=unit, + criterion=criterion, + ref_key=ref_key, + pred_key=pred_key, + **kwargs, + ) + self.TO_KB = 1602.1766208 # eV/A^3 to kbar + + def _preprocess( + self, batch_data: Dict[str, Any], model: Optional[Callable] = None + ): + assert isinstance(self.pred_key, str) and isinstance(self.ref_key, str) + + pred = torch.reshape(batch_data[self.pred_key] * self.TO_KB, (-1,)) + ref = torch.reshape(batch_data[self.ref_key] * self.TO_KB, (-1,)) + w_tensor = None + + if self.use_weight: + loss_type = self.name.lower() + weight = batch_data[KEY.DATA_WEIGHT][loss_type] + w_tensor = torch.repeat_interleave(weight, 6) + + return pred, ref, w_tensor + + +def get_loss_functions_from_config(config: Dict[str, Any]): + from sevenn.train.optim import loss_dict + + loss_functions = [] # list of tuples (loss_definition, weight) + + loss = loss_dict[config[KEY.LOSS].lower()] + loss_param = config.get(KEY.LOSS_PARAM, {}) + + use_weight = config.get(KEY.USE_WEIGHT, False) + if use_weight: + loss_param['reduction'] = 'none' + criterion = loss(**loss_param) + + commons = {'use_weight': use_weight} + + loss_functions.append((PerAtomEnergyLoss(**commons), 1.0)) + loss_functions.append((ForceLoss(**commons), config[KEY.FORCE_WEIGHT])) + if config[KEY.IS_TRAIN_STRESS]: + loss_functions.append((StressLoss(**commons), config[KEY.STRESS_WEIGHT])) + + for loss_function, _ in loss_functions: # why do these? + if loss_function.criterion is None: + loss_function.assign_criteria(criterion) + + return loss_functions diff --git a/mace-bench/3rdparty/SevenNet/sevenn/train/modal_dataset.py b/mace-bench/3rdparty/SevenNet/sevenn/train/modal_dataset.py index 8b749d2..c5bcd91 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/train/modal_dataset.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/train/modal_dataset.py @@ -1,365 +1,365 @@ -import bisect -import os -from copy import deepcopy -from typing import Any, Dict, List, Optional - -import numpy as np -from torch.utils.data import ConcatDataset, Dataset - -import sevenn._keys as KEY -import sevenn.util as util -from sevenn.logger import Logger - - -def _arrange_paths_by_modality(paths: List[dict]): - modal_dct = {} - for path in paths: - if isinstance(path, dict): - if KEY.DATA_MODALITY not in path: - raise ValueError(f'{KEY.DATA_MODALITY} is missing') - modal = path.pop(KEY.DATA_MODALITY) - else: - raise TypeError(f'{path} is not dict or str') - if modal not in modal_dct: - modal_dct[modal] = [] - modal_dct[modal].append(path) - return modal_dct - - -def combined_variance( - means: np.ndarray, stds: np.ndarray, sample_sizes: np.ndarray, ddof: int = 0 -) -> float: - """ - Calculate the combined variance for multiple datasets. - """ - assert len(means) == len(stds) and len(stds) == len(sample_sizes) - # Total number of samples - total_samples = np.sum(sample_sizes) - - # Combined mean - combined_mean = np.sum(sample_sizes * means) / total_samples - - # Combined variance calculation - variance_terms = (sample_sizes - ddof) * (stds**2) - mean_diff_terms = sample_sizes * ((means - combined_mean) ** 2) - combined_variance = (np.sum(variance_terms) + np.sum(mean_diff_terms)) / ( - total_samples - ddof - ) - - return combined_variance - - -def combined_std( - means: List[float], stds: List[float], sample_sizes: List[int] -) -> float: - """ - Calculate the combined std for multiple datasets. - """ - assert len(means) == len(stds) and len(stds) == len(sample_sizes) - means_arr = np.array(means) - stds_arr = np.array(stds) - sample_sizes_arr = np.array(sample_sizes) - - cv = combined_variance(means_arr, stds_arr, sample_sizes_arr) - return np.sqrt(cv) - - -def combined_mean(means: List[float], sample_sizes: List[int]) -> float: - """ - Calculate the combined mean for multiple datasets. - """ - assert len(means) == len(sample_sizes) - means_arr = np.array(means) - sample_sizes_arr = np.array(sample_sizes) - - return np.sum(sample_sizes_arr * means_arr) / np.sum(sample_sizes_arr) - - -def combined_rms( - means: List[float], stds: List[float], sample_sizes: List[int] -) -> float: - """ - Calculate the combined RMS for multiple datasets. - """ - assert len(means) == len(stds) and len(stds) == len(sample_sizes) - means_arr = np.array(means) - stds_arr = np.array(stds) - sample_sizes_arr = np.array(sample_sizes) - - cm = combined_mean(means, sample_sizes) - cv = combined_variance(means_arr, stds_arr, sample_sizes_arr) - - # Combined RMS calculation - return np.sqrt(cm**2 + cv) - - -class SevenNetMultiModalDataset(ConcatDataset): - def __init__( - self, - modal_dataset_dict: Dict[str, Dataset], - ): - datasets = [] - modals = [] - for modal, dataset in modal_dataset_dict.items(): - modals.append(modal) - datasets.append(dataset) - self.modals = modals - super().__init__(datasets) - - def __getitem__(self, idx): - graph = super().__getitem__(idx) - dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) - modality = self.modals[dataset_idx] - graph[KEY.DATA_MODALITY] = modality - return graph - - def _modal_wise_property(self, attribute_name: str): - dct = {} - for modal, dataset in zip(self.modals, self.datasets): - try: - if hasattr(dataset, attribute_name): - dct[modal] = getattr(dataset, attribute_name) - except AttributeError: - dct[modal] = None - return dct - - @property - def dataset_dict(self): - arr = {} - for idx, modality in enumerate(self.modals): - arr[modality] = self.datasets[idx] - return arr - - @property - def species(self): - dct = self._modal_wise_property('species') - tot = set() - for sp in dct.values(): - tot.update(sp) - dct['total'] = list(tot) - return dct - - @property - def natoms(self): - return self._modal_wise_property('natoms') - - @property - def per_atom_energy_mean(self): - dct = self._modal_wise_property('per_atom_energy_mean') - try: - means = [] - sample_sizes = [] - for modality, mean in dct.items(): - means.append(mean) - sample_sizes.append( - self.statistics[modality][KEY.PER_ATOM_ENERGY]['count'] - ) - cm = combined_mean(means, sample_sizes) - dct['total'] = cm - except KeyError: - pass - return dct - - @property - def elemwise_reference_energies(self): - # total is not supported (it is expensive and complex, but useless) - return self._modal_wise_property('elemwise_reference_energies') - - @property - def force_rms(self): - dct = self._modal_wise_property('force_rms') - try: - means = [] - sample_sizes = [] - stds = [] - for modality in dct: - means.append(self.statistics[modality][KEY.FORCE]['mean']) - sample_sizes.append(self.statistics[modality][KEY.FORCE]['count']) - stds.append(self.statistics[modality][KEY.FORCE]['std']) - cm = combined_rms(means, stds, sample_sizes) - dct['total'] = cm - except KeyError: - pass - return dct - - @property - def per_atom_energy_std(self): - dct = self._modal_wise_property('per_atom_energy_std') - try: - means = [] - sample_sizes = [] - stds = [] - for modality in dct: - means.append(self.statistics[modality][KEY.PER_ATOM_ENERGY]['mean']) - sample_sizes.append( - self.statistics[modality][KEY.PER_ATOM_ENERGY]['count'] - ) - stds.append(self.statistics[modality][KEY.PER_ATOM_ENERGY]['std']) - cm = combined_std(means, stds, sample_sizes) - dct['total'] = cm - except KeyError: - pass - return dct - - @property - def avg_num_neigh(self): - dct = self._modal_wise_property('avg_num_neigh') - try: - means = [] - sample_sizes = [] - for modality, mean in dct.items(): - means.append(mean) - sample_sizes.append( - self.statistics[modality]['num_neighbor']['count'] - ) - cm = combined_mean(means, sample_sizes) - dct['total'] = cm - except KeyError: - pass - return dct - - @property - def sqrt_avg_num_neigh(self): - avg_nn = self.avg_num_neigh - return {k: v**0.5 for k, v in avg_nn.items()} - - @property - def statistics(self): - return self._modal_wise_property('statistics') - - @staticmethod - def as_graph_dataset( - paths: List[dict], - **graph_dataset_kwargs, - ): - import sevenn.train.graph_dataset as gd - - modal_paths = _arrange_paths_by_modality(paths) - dataset_dct = {} - for modality, paths in modal_paths.items(): - kwargs = deepcopy(graph_dataset_kwargs) - if (dataset := gd.from_single_path(paths, **kwargs)) is None: - pname = kwargs.pop('processed_name', 'graph').replace('.pt', '') - dataset = gd.SevenNetGraphDataset( - files=paths, - processed_name=f'{pname}_{modality}.pt', - **kwargs, - ) - dataset_dct[modality] = dataset - return SevenNetMultiModalDataset(dataset_dct) - - -def from_config( - config: Dict[str, Any], - working_dir: str = os.getcwd(), - dataset_keys: Optional[List[str]] = None, -): - log = Logger() - if dataset_keys is None: - dataset_keys = [ - k for k in config if (k.startswith('load_') and k.endswith('_path')) - ] - - if KEY.LOAD_TRAINSET not in dataset_keys: - raise ValueError(f'{KEY.LOAD_TRAINSET} must be present in config') - - dataset_args = { - 'cutoff': config[KEY.CUTOFF], - 'root': working_dir, - 'process_num_cores': config.get(KEY.PREPROCESS_NUM_CORES, 1), - 'use_data_weight': config.get(KEY.USE_WEIGHT, False), - **config[KEY.DATA_FORMAT_ARGS], - } - - datasets = {} - for dk in dataset_keys: - if not (paths := config[dk]): - continue - if isinstance(paths, str): - paths = [paths] - name = '_'.join([nn.strip() for nn in dk.split('_')[1:-1]]) - dataset_args.update({'processed_name': name}) - datasets[name] = SevenNetMultiModalDataset.as_graph_dataset( - paths, # type: ignore - **dataset_args, - ) - - train_set = datasets['trainset'] - - modals_dataset = set() - chem_species = set() - # print statistics of each dataset - for name, dataset in datasets.items(): - for idx, modality in enumerate(dataset.modals): - log.bar() - log.writeline(f'{name} - {modality} distribution:') - log.statistic_write(dataset.statistics[modality]) - log.format_k_v( - '# structures (graph)', len(dataset.datasets[idx]), write=True - ) - modals_dataset.update([modality]) - chem_species.update(dataset.species['total']) - log.bar() - - if (modal_map := config.get(KEY.MODAL_MAP, None)) is None: - modals = sorted(list(modals_dataset)) - modal_map = {modal: i for i, modal in enumerate(modals)} - config[KEY.MODAL_MAP] = modal_map - - modals = list(modal_map.keys()) - if not modals_dataset.issubset(modal_map): - raise ValueError( - f'Found modalities in datasets: {modals_dataset} are not subset of' - + f' {modals}. Use sevenn_cp tool to append/assign modality' - ) - - log.writeline(f'Modalities of this model: {modals}') - - config[KEY.NUM_MODALITIES] = len(modal_map) - - # initialize known species from dataset if 'auto' - # sorted to alphabetical order (which is same as before) - chem_keys = [KEY.CHEMICAL_SPECIES, KEY.NUM_SPECIES, KEY.TYPE_MAP] - if all([config[ck] == 'auto' for ck in chem_keys]): # see parse_input.py - log.writeline('Known species are obtained from the dataset') - config.update(util.chemical_species_preprocess(sorted(list(chem_species)))) - - # retrieve shift, scale, conv_denominaotrs from user input (keyword) - init_from_stats_candid = [KEY.SHIFT, KEY.SCALE, KEY.CONV_DENOMINATOR] - init_from_stats = [ - k for k in init_from_stats_candid if isinstance(config[k], str) - ] - - for k in init_from_stats: - input = config[k] - if not hasattr(train_set, input): - raise NotImplementedError(input) - modal_stat = getattr(train_set, input) - try: - if k == KEY.CONV_DENOMINATOR and 'total' in modal_stat: - # conv_denominator is not modal-wise - var = modal_stat['total'] - elif k == KEY.SHIFT and config[KEY.USE_MODAL_WISE_SHIFT]: - modal_stat.pop('total', None) - var = modal_stat - elif k == KEY.SHIFT and not config[KEY.USE_MODAL_WISE_SHIFT]: - var = modal_stat['total'] - elif k == KEY.SCALE and config[KEY.USE_MODAL_WISE_SCALE]: - modal_stat.pop('total', None) - var = modal_stat - elif k == KEY.SCALE and not config[KEY.USE_MODAL_WISE_SCALE]: - var = modal_stat['total'] - else: - raise NotImplementedError(f'Failed to init {k} from statistics') - except KeyError as e: - if e.args[0] == 'total': - raise NotImplementedError( - f'{k}: {input} does not support total statistics. ' - + f'Set use_modal_wise_{k} True or specify numbers manually' - ) - else: - raise e - config.update({k: var}) - log.writeline(f'{k} is obtained from statistics') - - return datasets +import bisect +import os +from copy import deepcopy +from typing import Any, Dict, List, Optional + +import numpy as np +from torch.utils.data import ConcatDataset, Dataset + +import sevenn._keys as KEY +import sevenn.util as util +from sevenn.logger import Logger + + +def _arrange_paths_by_modality(paths: List[dict]): + modal_dct = {} + for path in paths: + if isinstance(path, dict): + if KEY.DATA_MODALITY not in path: + raise ValueError(f'{KEY.DATA_MODALITY} is missing') + modal = path.pop(KEY.DATA_MODALITY) + else: + raise TypeError(f'{path} is not dict or str') + if modal not in modal_dct: + modal_dct[modal] = [] + modal_dct[modal].append(path) + return modal_dct + + +def combined_variance( + means: np.ndarray, stds: np.ndarray, sample_sizes: np.ndarray, ddof: int = 0 +) -> float: + """ + Calculate the combined variance for multiple datasets. + """ + assert len(means) == len(stds) and len(stds) == len(sample_sizes) + # Total number of samples + total_samples = np.sum(sample_sizes) + + # Combined mean + combined_mean = np.sum(sample_sizes * means) / total_samples + + # Combined variance calculation + variance_terms = (sample_sizes - ddof) * (stds**2) + mean_diff_terms = sample_sizes * ((means - combined_mean) ** 2) + combined_variance = (np.sum(variance_terms) + np.sum(mean_diff_terms)) / ( + total_samples - ddof + ) + + return combined_variance + + +def combined_std( + means: List[float], stds: List[float], sample_sizes: List[int] +) -> float: + """ + Calculate the combined std for multiple datasets. + """ + assert len(means) == len(stds) and len(stds) == len(sample_sizes) + means_arr = np.array(means) + stds_arr = np.array(stds) + sample_sizes_arr = np.array(sample_sizes) + + cv = combined_variance(means_arr, stds_arr, sample_sizes_arr) + return np.sqrt(cv) + + +def combined_mean(means: List[float], sample_sizes: List[int]) -> float: + """ + Calculate the combined mean for multiple datasets. + """ + assert len(means) == len(sample_sizes) + means_arr = np.array(means) + sample_sizes_arr = np.array(sample_sizes) + + return np.sum(sample_sizes_arr * means_arr) / np.sum(sample_sizes_arr) + + +def combined_rms( + means: List[float], stds: List[float], sample_sizes: List[int] +) -> float: + """ + Calculate the combined RMS for multiple datasets. + """ + assert len(means) == len(stds) and len(stds) == len(sample_sizes) + means_arr = np.array(means) + stds_arr = np.array(stds) + sample_sizes_arr = np.array(sample_sizes) + + cm = combined_mean(means, sample_sizes) + cv = combined_variance(means_arr, stds_arr, sample_sizes_arr) + + # Combined RMS calculation + return np.sqrt(cm**2 + cv) + + +class SevenNetMultiModalDataset(ConcatDataset): + def __init__( + self, + modal_dataset_dict: Dict[str, Dataset], + ): + datasets = [] + modals = [] + for modal, dataset in modal_dataset_dict.items(): + modals.append(modal) + datasets.append(dataset) + self.modals = modals + super().__init__(datasets) + + def __getitem__(self, idx): + graph = super().__getitem__(idx) + dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) + modality = self.modals[dataset_idx] + graph[KEY.DATA_MODALITY] = modality + return graph + + def _modal_wise_property(self, attribute_name: str): + dct = {} + for modal, dataset in zip(self.modals, self.datasets): + try: + if hasattr(dataset, attribute_name): + dct[modal] = getattr(dataset, attribute_name) + except AttributeError: + dct[modal] = None + return dct + + @property + def dataset_dict(self): + arr = {} + for idx, modality in enumerate(self.modals): + arr[modality] = self.datasets[idx] + return arr + + @property + def species(self): + dct = self._modal_wise_property('species') + tot = set() + for sp in dct.values(): + tot.update(sp) + dct['total'] = list(tot) + return dct + + @property + def natoms(self): + return self._modal_wise_property('natoms') + + @property + def per_atom_energy_mean(self): + dct = self._modal_wise_property('per_atom_energy_mean') + try: + means = [] + sample_sizes = [] + for modality, mean in dct.items(): + means.append(mean) + sample_sizes.append( + self.statistics[modality][KEY.PER_ATOM_ENERGY]['count'] + ) + cm = combined_mean(means, sample_sizes) + dct['total'] = cm + except KeyError: + pass + return dct + + @property + def elemwise_reference_energies(self): + # total is not supported (it is expensive and complex, but useless) + return self._modal_wise_property('elemwise_reference_energies') + + @property + def force_rms(self): + dct = self._modal_wise_property('force_rms') + try: + means = [] + sample_sizes = [] + stds = [] + for modality in dct: + means.append(self.statistics[modality][KEY.FORCE]['mean']) + sample_sizes.append(self.statistics[modality][KEY.FORCE]['count']) + stds.append(self.statistics[modality][KEY.FORCE]['std']) + cm = combined_rms(means, stds, sample_sizes) + dct['total'] = cm + except KeyError: + pass + return dct + + @property + def per_atom_energy_std(self): + dct = self._modal_wise_property('per_atom_energy_std') + try: + means = [] + sample_sizes = [] + stds = [] + for modality in dct: + means.append(self.statistics[modality][KEY.PER_ATOM_ENERGY]['mean']) + sample_sizes.append( + self.statistics[modality][KEY.PER_ATOM_ENERGY]['count'] + ) + stds.append(self.statistics[modality][KEY.PER_ATOM_ENERGY]['std']) + cm = combined_std(means, stds, sample_sizes) + dct['total'] = cm + except KeyError: + pass + return dct + + @property + def avg_num_neigh(self): + dct = self._modal_wise_property('avg_num_neigh') + try: + means = [] + sample_sizes = [] + for modality, mean in dct.items(): + means.append(mean) + sample_sizes.append( + self.statistics[modality]['num_neighbor']['count'] + ) + cm = combined_mean(means, sample_sizes) + dct['total'] = cm + except KeyError: + pass + return dct + + @property + def sqrt_avg_num_neigh(self): + avg_nn = self.avg_num_neigh + return {k: v**0.5 for k, v in avg_nn.items()} + + @property + def statistics(self): + return self._modal_wise_property('statistics') + + @staticmethod + def as_graph_dataset( + paths: List[dict], + **graph_dataset_kwargs, + ): + import sevenn.train.graph_dataset as gd + + modal_paths = _arrange_paths_by_modality(paths) + dataset_dct = {} + for modality, paths in modal_paths.items(): + kwargs = deepcopy(graph_dataset_kwargs) + if (dataset := gd.from_single_path(paths, **kwargs)) is None: + pname = kwargs.pop('processed_name', 'graph').replace('.pt', '') + dataset = gd.SevenNetGraphDataset( + files=paths, + processed_name=f'{pname}_{modality}.pt', + **kwargs, + ) + dataset_dct[modality] = dataset + return SevenNetMultiModalDataset(dataset_dct) + + +def from_config( + config: Dict[str, Any], + working_dir: str = os.getcwd(), + dataset_keys: Optional[List[str]] = None, +): + log = Logger() + if dataset_keys is None: + dataset_keys = [ + k for k in config if (k.startswith('load_') and k.endswith('_path')) + ] + + if KEY.LOAD_TRAINSET not in dataset_keys: + raise ValueError(f'{KEY.LOAD_TRAINSET} must be present in config') + + dataset_args = { + 'cutoff': config[KEY.CUTOFF], + 'root': working_dir, + 'process_num_cores': config.get(KEY.PREPROCESS_NUM_CORES, 1), + 'use_data_weight': config.get(KEY.USE_WEIGHT, False), + **config[KEY.DATA_FORMAT_ARGS], + } + + datasets = {} + for dk in dataset_keys: + if not (paths := config[dk]): + continue + if isinstance(paths, str): + paths = [paths] + name = '_'.join([nn.strip() for nn in dk.split('_')[1:-1]]) + dataset_args.update({'processed_name': name}) + datasets[name] = SevenNetMultiModalDataset.as_graph_dataset( + paths, # type: ignore + **dataset_args, + ) + + train_set = datasets['trainset'] + + modals_dataset = set() + chem_species = set() + # print statistics of each dataset + for name, dataset in datasets.items(): + for idx, modality in enumerate(dataset.modals): + log.bar() + log.writeline(f'{name} - {modality} distribution:') + log.statistic_write(dataset.statistics[modality]) + log.format_k_v( + '# structures (graph)', len(dataset.datasets[idx]), write=True + ) + modals_dataset.update([modality]) + chem_species.update(dataset.species['total']) + log.bar() + + if (modal_map := config.get(KEY.MODAL_MAP, None)) is None: + modals = sorted(list(modals_dataset)) + modal_map = {modal: i for i, modal in enumerate(modals)} + config[KEY.MODAL_MAP] = modal_map + + modals = list(modal_map.keys()) + if not modals_dataset.issubset(modal_map): + raise ValueError( + f'Found modalities in datasets: {modals_dataset} are not subset of' + + f' {modals}. Use sevenn_cp tool to append/assign modality' + ) + + log.writeline(f'Modalities of this model: {modals}') + + config[KEY.NUM_MODALITIES] = len(modal_map) + + # initialize known species from dataset if 'auto' + # sorted to alphabetical order (which is same as before) + chem_keys = [KEY.CHEMICAL_SPECIES, KEY.NUM_SPECIES, KEY.TYPE_MAP] + if all([config[ck] == 'auto' for ck in chem_keys]): # see parse_input.py + log.writeline('Known species are obtained from the dataset') + config.update(util.chemical_species_preprocess(sorted(list(chem_species)))) + + # retrieve shift, scale, conv_denominaotrs from user input (keyword) + init_from_stats_candid = [KEY.SHIFT, KEY.SCALE, KEY.CONV_DENOMINATOR] + init_from_stats = [ + k for k in init_from_stats_candid if isinstance(config[k], str) + ] + + for k in init_from_stats: + input = config[k] + if not hasattr(train_set, input): + raise NotImplementedError(input) + modal_stat = getattr(train_set, input) + try: + if k == KEY.CONV_DENOMINATOR and 'total' in modal_stat: + # conv_denominator is not modal-wise + var = modal_stat['total'] + elif k == KEY.SHIFT and config[KEY.USE_MODAL_WISE_SHIFT]: + modal_stat.pop('total', None) + var = modal_stat + elif k == KEY.SHIFT and not config[KEY.USE_MODAL_WISE_SHIFT]: + var = modal_stat['total'] + elif k == KEY.SCALE and config[KEY.USE_MODAL_WISE_SCALE]: + modal_stat.pop('total', None) + var = modal_stat + elif k == KEY.SCALE and not config[KEY.USE_MODAL_WISE_SCALE]: + var = modal_stat['total'] + else: + raise NotImplementedError(f'Failed to init {k} from statistics') + except KeyError as e: + if e.args[0] == 'total': + raise NotImplementedError( + f'{k}: {input} does not support total statistics. ' + + f'Set use_modal_wise_{k} True or specify numbers manually' + ) + else: + raise e + config.update({k: var}) + log.writeline(f'{k} is obtained from statistics') + + return datasets diff --git a/mace-bench/3rdparty/SevenNet/sevenn/train/optim.py b/mace-bench/3rdparty/SevenNet/sevenn/train/optim.py index 7f98943..10e7579 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/train/optim.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/train/optim.py @@ -1,23 +1,23 @@ -import torch.nn as nn -import torch.optim.lr_scheduler as scheduler -from torch.optim import adagrad, adam, adamw, radam, sgd - -optim_dict = { - 'sgd': sgd.SGD, - 'adagrad': adagrad.Adagrad, - 'adam': adam.Adam, - 'adamw': adamw.AdamW, - 'radam': radam.RAdam, -} - - -scheduler_dict = { - 'steplr': scheduler.StepLR, - 'multisteplr': scheduler.MultiStepLR, - 'exponentiallr': scheduler.ExponentialLR, - 'cosineannealinglr': scheduler.CosineAnnealingLR, - 'reducelronplateau': scheduler.ReduceLROnPlateau, - 'linearlr': scheduler.LinearLR, -} - -loss_dict = {'mse': nn.MSELoss, 'huber': nn.HuberLoss} +import torch.nn as nn +import torch.optim.lr_scheduler as scheduler +from torch.optim import adagrad, adam, adamw, radam, sgd + +optim_dict = { + 'sgd': sgd.SGD, + 'adagrad': adagrad.Adagrad, + 'adam': adam.Adam, + 'adamw': adamw.AdamW, + 'radam': radam.RAdam, +} + + +scheduler_dict = { + 'steplr': scheduler.StepLR, + 'multisteplr': scheduler.MultiStepLR, + 'exponentiallr': scheduler.ExponentialLR, + 'cosineannealinglr': scheduler.CosineAnnealingLR, + 'reducelronplateau': scheduler.ReduceLROnPlateau, + 'linearlr': scheduler.LinearLR, +} + +loss_dict = {'mse': nn.MSELoss, 'huber': nn.HuberLoss} diff --git a/mace-bench/3rdparty/SevenNet/sevenn/train/trainer.py b/mace-bench/3rdparty/SevenNet/sevenn/train/trainer.py index 962598e..cb2eb91 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/train/trainer.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/train/trainer.py @@ -1,230 +1,230 @@ -import os -import uuid -from datetime import datetime -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union - -import torch -import torch.distributed as dist -import torch.nn -from torch.nn.parallel import DistributedDataParallel as DDP -from tqdm import tqdm - -import sevenn._keys as KEY -from sevenn.error_recorder import ErrorRecorder -from sevenn.train.loss import LossDefinition - -from .loss import get_loss_functions_from_config -from .optim import optim_dict, scheduler_dict - - -class Trainer: - """ - Training routine specialized for this package. Depends on 'sevenn.train.loss' - - Args: - model: model to train - loss_functions: List of tuples of [LossDefinition, float]. 'float' is for - loss weight for each Loss function - optimizer_cls: torch optimizer class to initialize - optimizer_args: optimizer keyword argument except 'param' - scheduler_cls: torch scheduler class to initialize, can be None - optimizer_args: optimizer keyword argument except 'optimizer' - device: device to train model, defaults to 'auto' - distributed: whether this is distributed training - distributed_backend: torch DDP backend. Should be one of 'nccl', 'mpi' - """ - - def __init__( - self, - model: torch.nn.Module, - loss_functions: List[Tuple[LossDefinition, float]], - optimizer_cls, - optimizer_args: Optional[dict] = None, - scheduler_cls=None, - scheduler_args: Optional[dict] = None, - device: Union[torch.device, str] = 'auto', - distributed: bool = False, - distributed_backend: str = 'nccl', - ): - if device == 'auto': - device = 'cuda' if torch.cuda.is_available() else 'cpu' - if distributed_backend == 'mpi': - device = 'cpu' - - if distributed: - local_rank = int(os.environ['LOCAL_RANK']) - self.rank = local_rank - if distributed_backend == 'nccl': - device = torch.device('cuda', local_rank) - self.model = DDP(model.to(device), device_ids=[device]) - elif distributed_backend == 'mpi': - self.model = DDP(model.to(device)) - else: - raise ValueError(f'Unknown DDP backend: {distributed_backend}') - dist.barrier() - self.model.module.set_is_batch_data(True) - else: - self.model = model.to(device) - self.model.set_is_batch_data(True) - self.rank = 0 - - self.device = device - self.distributed = distributed - - param = [p for p in self.model.parameters() if p.requires_grad] - self.optimizer = optimizer_cls(param, **optimizer_args) - if scheduler_cls is not None: - self.scheduler = scheduler_cls(self.optimizer, **scheduler_args) - else: - self.scheduler = None - self.loss_functions = loss_functions - - @staticmethod - def from_config(model: torch.nn.Module, config: Dict[str, Any]) -> 'Trainer': - trainer = Trainer( - model, - loss_functions=get_loss_functions_from_config(config), - optimizer_cls=optim_dict[config.get(KEY.OPTIMIZER, 'adam').lower()], - optimizer_args=config.get(KEY.OPTIM_PARAM, {}), - scheduler_cls=scheduler_dict[ - config.get(KEY.SCHEDULER, 'exponentiallr').lower() - ], - scheduler_args=config.get(KEY.SCHEDULER_PARAM, {}), - device=config.get(KEY.DEVICE, 'auto'), - distributed=config.get(KEY.IS_DDP, False), - distributed_backend=config.get(KEY.DDP_BACKEND, 'nccl'), - ) - return trainer - - @staticmethod - def args_from_checkpoint(checkpoint: str) -> Tuple[Dict, Dict, Dict]: - """ - Usage: - trainer_args, optim_stct, scheduler_stct = args_from_checkpoint('7net-0') - # Do what you want to do here - trainer = Trainer(**trainer_args) - trainer.load_state_dict( - optimizer_state_dict=optim_stct, - scheduler_state_dict=scheduler_stct, - """ - from sevenn.util import load_checkpoint - - cp = load_checkpoint(checkpoint) - - model = cp.build_model() - config = cp.config - optimizer_cls = optim_dict[config[KEY.OPTIMIZER].lower()] - scheduler_cls = scheduler_dict[config[KEY.SCHEDULER].lower()] - loss_functions = get_loss_functions_from_config(config) - - return ( - { - 'model': model, - 'loss_functions': loss_functions, - 'optimizer_cls': optimizer_cls, - 'optimizer_args': config[KEY.OPTIM_PARAM], - 'scheduler_cls': scheduler_cls, - 'scheduler_args': config[KEY.SCHEDULER_PARAM], - }, - cp.optimizer_state_dict, - cp.scheduler_state_dict, - ) - - def run_one_epoch( - self, - loader: Iterable, - is_train: bool = False, - error_recorder: Optional[ErrorRecorder] = None, - wrap_tqdm: Union[bool, int] = False, - ) -> None: - """ - Run single epoch with given dataloader - Args: - loader: iterable yieds AtomGraphData - is_train: if true, do backward() and optimizer step - error_recorder: ErrorRecorder instance to compute errors (RMSEm MAE, ..) - wrap_tqdm: wrap given dataloader with tqdm for progress bar - """ - if is_train: - self.model.train() - else: - self.model.eval() - - if wrap_tqdm: - total_len = wrap_tqdm if isinstance(wrap_tqdm, int) else None - loader = tqdm(loader, total=total_len) - for _, batch in enumerate(loader): - if is_train: - self.optimizer.zero_grad() - batch = batch.to(self.device, non_blocking=True) - output = self.model(batch) - if error_recorder is not None: - error_recorder.update(output) - if is_train: - total_loss = torch.tensor([0.0], device=self.device) - for loss_def, w in self.loss_functions: - indv_loss = loss_def.get_loss(output, self.model) - if indv_loss is not None: - total_loss += (indv_loss * w) - total_loss.backward() - self.optimizer.step() - - if self.distributed and error_recorder is not None: - self.recorder_all_reduce(error_recorder) - - def scheduler_step(self, metric: Optional[float] = None) -> None: - if self.scheduler is None: - return - if isinstance(self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): - assert isinstance(metric, float) - self.scheduler.step(metric) - else: - self.scheduler.step() - - def get_lr(self) -> float: - return float(self.optimizer.param_groups[0]['lr']) - - def recorder_all_reduce(self, recorder: ErrorRecorder) -> None: - for metric in recorder.metrics: - # metric.value._ddp_reduce(self.device) - metric.ddp_reduce(self.device) - - def get_checkpoint_dict(self) -> dict: - if self.distributed: - model_state_dct = self.model.module.state_dict() - else: - model_state_dct = self.model.state_dict() - return { - 'model_state_dict': model_state_dct, - 'optimizer_state_dict': self.optimizer.state_dict(), - 'scheduler_state_dict': self.scheduler.state_dict() - if self.scheduler is not None - else None, - 'time': datetime.now().strftime('%Y-%m-%d %H:%M'), - 'hash': uuid.uuid4().hex, - } - - def write_checkpoint(self, path: str, **extra) -> None: - if self.distributed and self.rank != 0: - return - cp = self.get_checkpoint_dict() - cp.update(**extra) - torch.save(cp, path) - - def load_state_dicts( - self, - model_state_dict: Optional[Dict] = None, - optimizer_state_dict: Optional[Dict] = None, - scheduler_state_dict: Optional[Dict] = None, - strict: bool = True, - ) -> None: - if model_state_dict is not None: - if self.distributed: - self.model.module.load_state_dict(model_state_dict, strict=strict) - else: - self.model.load_state_dict(model_state_dict, strict=strict) - - if optimizer_state_dict is not None: - self.optimizer.load_state_dict(optimizer_state_dict) - if scheduler_state_dict is not None and self.scheduler is not None: - self.scheduler.load_state_dict(scheduler_state_dict) +import os +import uuid +from datetime import datetime +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union + +import torch +import torch.distributed as dist +import torch.nn +from torch.nn.parallel import DistributedDataParallel as DDP +from tqdm import tqdm + +import sevenn._keys as KEY +from sevenn.error_recorder import ErrorRecorder +from sevenn.train.loss import LossDefinition + +from .loss import get_loss_functions_from_config +from .optim import optim_dict, scheduler_dict + + +class Trainer: + """ + Training routine specialized for this package. Depends on 'sevenn.train.loss' + + Args: + model: model to train + loss_functions: List of tuples of [LossDefinition, float]. 'float' is for + loss weight for each Loss function + optimizer_cls: torch optimizer class to initialize + optimizer_args: optimizer keyword argument except 'param' + scheduler_cls: torch scheduler class to initialize, can be None + optimizer_args: optimizer keyword argument except 'optimizer' + device: device to train model, defaults to 'auto' + distributed: whether this is distributed training + distributed_backend: torch DDP backend. Should be one of 'nccl', 'mpi' + """ + + def __init__( + self, + model: torch.nn.Module, + loss_functions: List[Tuple[LossDefinition, float]], + optimizer_cls, + optimizer_args: Optional[dict] = None, + scheduler_cls=None, + scheduler_args: Optional[dict] = None, + device: Union[torch.device, str] = 'auto', + distributed: bool = False, + distributed_backend: str = 'nccl', + ): + if device == 'auto': + device = 'cuda' if torch.cuda.is_available() else 'cpu' + if distributed_backend == 'mpi': + device = 'cpu' + + if distributed: + local_rank = int(os.environ['LOCAL_RANK']) + self.rank = local_rank + if distributed_backend == 'nccl': + device = torch.device('cuda', local_rank) + self.model = DDP(model.to(device), device_ids=[device]) + elif distributed_backend == 'mpi': + self.model = DDP(model.to(device)) + else: + raise ValueError(f'Unknown DDP backend: {distributed_backend}') + dist.barrier() + self.model.module.set_is_batch_data(True) + else: + self.model = model.to(device) + self.model.set_is_batch_data(True) + self.rank = 0 + + self.device = device + self.distributed = distributed + + param = [p for p in self.model.parameters() if p.requires_grad] + self.optimizer = optimizer_cls(param, **optimizer_args) + if scheduler_cls is not None: + self.scheduler = scheduler_cls(self.optimizer, **scheduler_args) + else: + self.scheduler = None + self.loss_functions = loss_functions + + @staticmethod + def from_config(model: torch.nn.Module, config: Dict[str, Any]) -> 'Trainer': + trainer = Trainer( + model, + loss_functions=get_loss_functions_from_config(config), + optimizer_cls=optim_dict[config.get(KEY.OPTIMIZER, 'adam').lower()], + optimizer_args=config.get(KEY.OPTIM_PARAM, {}), + scheduler_cls=scheduler_dict[ + config.get(KEY.SCHEDULER, 'exponentiallr').lower() + ], + scheduler_args=config.get(KEY.SCHEDULER_PARAM, {}), + device=config.get(KEY.DEVICE, 'auto'), + distributed=config.get(KEY.IS_DDP, False), + distributed_backend=config.get(KEY.DDP_BACKEND, 'nccl'), + ) + return trainer + + @staticmethod + def args_from_checkpoint(checkpoint: str) -> Tuple[Dict, Dict, Dict]: + """ + Usage: + trainer_args, optim_stct, scheduler_stct = args_from_checkpoint('7net-0') + # Do what you want to do here + trainer = Trainer(**trainer_args) + trainer.load_state_dict( + optimizer_state_dict=optim_stct, + scheduler_state_dict=scheduler_stct, + """ + from sevenn.util import load_checkpoint + + cp = load_checkpoint(checkpoint) + + model = cp.build_model() + config = cp.config + optimizer_cls = optim_dict[config[KEY.OPTIMIZER].lower()] + scheduler_cls = scheduler_dict[config[KEY.SCHEDULER].lower()] + loss_functions = get_loss_functions_from_config(config) + + return ( + { + 'model': model, + 'loss_functions': loss_functions, + 'optimizer_cls': optimizer_cls, + 'optimizer_args': config[KEY.OPTIM_PARAM], + 'scheduler_cls': scheduler_cls, + 'scheduler_args': config[KEY.SCHEDULER_PARAM], + }, + cp.optimizer_state_dict, + cp.scheduler_state_dict, + ) + + def run_one_epoch( + self, + loader: Iterable, + is_train: bool = False, + error_recorder: Optional[ErrorRecorder] = None, + wrap_tqdm: Union[bool, int] = False, + ) -> None: + """ + Run single epoch with given dataloader + Args: + loader: iterable yieds AtomGraphData + is_train: if true, do backward() and optimizer step + error_recorder: ErrorRecorder instance to compute errors (RMSEm MAE, ..) + wrap_tqdm: wrap given dataloader with tqdm for progress bar + """ + if is_train: + self.model.train() + else: + self.model.eval() + + if wrap_tqdm: + total_len = wrap_tqdm if isinstance(wrap_tqdm, int) else None + loader = tqdm(loader, total=total_len) + for _, batch in enumerate(loader): + if is_train: + self.optimizer.zero_grad() + batch = batch.to(self.device, non_blocking=True) + output = self.model(batch) + if error_recorder is not None: + error_recorder.update(output) + if is_train: + total_loss = torch.tensor([0.0], device=self.device) + for loss_def, w in self.loss_functions: + indv_loss = loss_def.get_loss(output, self.model) + if indv_loss is not None: + total_loss += (indv_loss * w) + total_loss.backward() + self.optimizer.step() + + if self.distributed and error_recorder is not None: + self.recorder_all_reduce(error_recorder) + + def scheduler_step(self, metric: Optional[float] = None) -> None: + if self.scheduler is None: + return + if isinstance(self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): + assert isinstance(metric, float) + self.scheduler.step(metric) + else: + self.scheduler.step() + + def get_lr(self) -> float: + return float(self.optimizer.param_groups[0]['lr']) + + def recorder_all_reduce(self, recorder: ErrorRecorder) -> None: + for metric in recorder.metrics: + # metric.value._ddp_reduce(self.device) + metric.ddp_reduce(self.device) + + def get_checkpoint_dict(self) -> dict: + if self.distributed: + model_state_dct = self.model.module.state_dict() + else: + model_state_dct = self.model.state_dict() + return { + 'model_state_dict': model_state_dct, + 'optimizer_state_dict': self.optimizer.state_dict(), + 'scheduler_state_dict': self.scheduler.state_dict() + if self.scheduler is not None + else None, + 'time': datetime.now().strftime('%Y-%m-%d %H:%M'), + 'hash': uuid.uuid4().hex, + } + + def write_checkpoint(self, path: str, **extra) -> None: + if self.distributed and self.rank != 0: + return + cp = self.get_checkpoint_dict() + cp.update(**extra) + torch.save(cp, path) + + def load_state_dicts( + self, + model_state_dict: Optional[Dict] = None, + optimizer_state_dict: Optional[Dict] = None, + scheduler_state_dict: Optional[Dict] = None, + strict: bool = True, + ) -> None: + if model_state_dict is not None: + if self.distributed: + self.model.module.load_state_dict(model_state_dict, strict=strict) + else: + self.model.load_state_dict(model_state_dict, strict=strict) + + if optimizer_state_dict is not None: + self.optimizer.load_state_dict(optimizer_state_dict) + if scheduler_state_dict is not None and self.scheduler is not None: + self.scheduler.load_state_dict(scheduler_state_dict) diff --git a/mace-bench/3rdparty/SevenNet/sevenn/util.py b/mace-bench/3rdparty/SevenNet/sevenn/util.py index 7afeb4e..632a1d2 100644 --- a/mace-bench/3rdparty/SevenNet/sevenn/util.py +++ b/mace-bench/3rdparty/SevenNet/sevenn/util.py @@ -1,330 +1,330 @@ -import os -import os.path as osp -import pathlib -import shutil -from typing import Dict, List, Tuple, Union - -import numpy as np -import requests -import torch -import torch.nn -from e3nn.o3 import FullTensorProduct, Irreps -from tqdm import tqdm - -import sevenn._const as _const -import sevenn._keys as KEY - - -def to_atom_graph_list(atom_graph_batch): - """ - torch_geometric batched data to separate list - original to_data_list() by PyG is not enough since - it doesn't handle inferred tensors - """ - is_stress = KEY.PRED_STRESS in atom_graph_batch - - data_list = atom_graph_batch.to_data_list() - - indices = atom_graph_batch[KEY.NUM_ATOMS].tolist() - - atomic_energy_list = torch.split(atom_graph_batch[KEY.ATOMIC_ENERGY], indices) - inferred_total_energy_list = torch.unbind( - atom_graph_batch[KEY.PRED_TOTAL_ENERGY] - ) - inferred_force_list = torch.split(atom_graph_batch[KEY.PRED_FORCE], indices) - - inferred_stress_list = None - if is_stress: - inferred_stress_list = torch.unbind(atom_graph_batch[KEY.PRED_STRESS]) - - for i, data in enumerate(data_list): - data[KEY.ATOMIC_ENERGY] = atomic_energy_list[i] - data[KEY.PRED_TOTAL_ENERGY] = inferred_total_energy_list[i] - data[KEY.PRED_FORCE] = inferred_force_list[i] - # To fit with KEY.STRESS (ref) format - if is_stress and inferred_stress_list is not None: - data[KEY.PRED_STRESS] = torch.unsqueeze(inferred_stress_list[i], 0) - return data_list - - -def error_recorder_from_loss_functions(loss_functions): - from .error_recorder import ErrorRecorder, MAError, RMSError, get_err_type - from .train.loss import ForceLoss, PerAtomEnergyLoss, StressLoss - - metrics = [] - for loss_function, _ in loss_functions: - ref_key = loss_function.ref_key - pred_key = loss_function.pred_key - # unit = loss_function.unit - criterion = loss_function.criterion - name = loss_function.name - base = None - if type(loss_function) is PerAtomEnergyLoss: - base = get_err_type('Energy') - elif type(loss_function) is ForceLoss: - base = get_err_type('Force') - elif type(loss_function) is StressLoss: - base = get_err_type('Stress') - else: - base = {} - base['name'] = name - base['ref_key'] = ref_key - base['pred_key'] = pred_key - if type(criterion) is torch.nn.MSELoss: - base['name'] = base['name'] + '_RMSE' - metrics.append(RMSError(**base)) - elif type(criterion) is torch.nn.L1Loss: - metrics.append(MAError(**base)) - return ErrorRecorder(metrics) - - -def onehot_to_chem(one_hot_indices: List[int], type_map: Dict[int, int]): - from ase.data import chemical_symbols - - type_map_rev = {v: k for k, v in type_map.items()} - return [chemical_symbols[type_map_rev[x]] for x in one_hot_indices] - - -def model_from_checkpoint( - checkpoint: str, -) -> Tuple[torch.nn.Module, Dict]: - cp = load_checkpoint(checkpoint) - model = cp.build_model() - - return model, cp.config - - -def model_from_checkpoint_with_backend( - checkpoint: str, - backend: str = 'e3nn', -) -> Tuple[torch.nn.Module, Dict]: - cp = load_checkpoint(checkpoint) - model = cp.build_model(backend) - - return model, cp.config - - -def unlabeled_atoms_to_input(atoms, cutoff: float, grad_key: str = KEY.EDGE_VEC): - from .atom_graph_data import AtomGraphData - from .train.dataload import unlabeled_atoms_to_graph - - atom_graph = AtomGraphData.from_numpy_dict( - unlabeled_atoms_to_graph(atoms, cutoff) - ) - atom_graph[grad_key].requires_grad_(True) - atom_graph[KEY.BATCH] = torch.zeros([0]) - return atom_graph - - -def chemical_species_preprocess(input_chem: List[str], universal: bool = False): - from ase.data import atomic_numbers, chemical_symbols - - from .nn.node_embedding import get_type_mapper_from_specie - - config = {} - if not universal: - input_chem = list(set(input_chem)) - chemical_specie = sorted([x.strip() for x in input_chem]) - config[KEY.CHEMICAL_SPECIES] = chemical_specie - config[KEY.CHEMICAL_SPECIES_BY_ATOMIC_NUMBER] = [ - atomic_numbers[x] for x in chemical_specie - ] - config[KEY.NUM_SPECIES] = len(chemical_specie) - config[KEY.TYPE_MAP] = get_type_mapper_from_specie(chemical_specie) - else: - config[KEY.CHEMICAL_SPECIES] = chemical_symbols - len_univ = len(chemical_symbols) - config[KEY.CHEMICAL_SPECIES_BY_ATOMIC_NUMBER] = list(range(len_univ)) - config[KEY.NUM_SPECIES] = len_univ - config[KEY.TYPE_MAP] = {z: z for z in range(len_univ)} - return config - - -def dtype_correct( - v: Union[np.ndarray, torch.Tensor, int, float], - float_dtype: torch.dtype = torch.float32, - int_dtype: torch.dtype = torch.int64, -): - if isinstance(v, np.ndarray): - if np.issubdtype(v.dtype, np.floating): - return torch.from_numpy(v).to(float_dtype) - elif np.issubdtype(v.dtype, np.integer): - return torch.from_numpy(v).to(int_dtype) - elif isinstance(v, torch.Tensor): - if v.dtype.is_floating_point: - return v.to(float_dtype) # convert to specified float dtype - else: # assuming non-floating point tensors are integers - return v.to(int_dtype) # convert to specified int dtype - else: # scalar values - if isinstance(v, int): - return torch.tensor(v, dtype=int_dtype) - elif isinstance(v, float): - return torch.tensor(v, dtype=float_dtype) - else: # Not numeric - return v - - -def infer_irreps_out( - irreps_x: Irreps, - irreps_operand: Irreps, - drop_l: Union[bool, int] = False, - parity_mode: str = 'full', - fix_multiplicity: Union[bool, int] = False, -): - assert parity_mode in ['full', 'even', 'sph'] - # (mul, (ir, p)) - irreps_out = FullTensorProduct(irreps_x, irreps_operand).irreps_out.simplify() - new_irreps_elem = [] - for mul, (l, p) in irreps_out: # noqa - elem = (mul, (l, p)) - if drop_l is not False and l > drop_l: - continue - if parity_mode == 'even' and p == -1: - continue - elif parity_mode == 'sph' and p != (-1) ** l: - continue - if fix_multiplicity: - elem = (fix_multiplicity, (l, p)) - new_irreps_elem.append(elem) - return Irreps(new_irreps_elem) - - -def download_checkpoint(path: str, url: str): - fname = osp.basename(path) - temp_path = path + '.partial' - try: - # raises permission error if fails - os.makedirs(osp.dirname(path), exist_ok=True) - response = requests.get(url, stream=True, timeout=30) - response.raise_for_status() # Raise exception for bad status codes - - total_size = int(response.headers.get('content-length', 0)) - block_size = 1024 # 1 KB chunks - - progress_bar = tqdm( - total=total_size, - unit='B', - unit_scale=True, - desc=f'Downloading {fname}', - ) - - with open(temp_path, 'wb') as file: - for data in response.iter_content(block_size): - progress_bar.update(len(data)) - file.write(data) - progress_bar.close() - - shutil.move(temp_path, path) - print(f'Checkpoint downloaded: {path}') - return path - except PermissionError: - raise - except Exception as e: - # Clean up partial downloads on failure - # May not work as errors handled internally by tqdm etc. - print(f'Download failed: {str(e)}') - if os.path.exists(temp_path): - print(f'Cleaning up partial download: {temp_path}') - os.remove(temp_path) - raise - - -def pretrained_name_to_path(name: str) -> str: - name = name.lower() - heads = ['sevennet', '7net'] - checkpoint_path = None - url = None - - if ( # TODO: regex - name in [f'{n}-0_11july2024' for n in heads] - or name in [f'{n}-0_11jul2024' for n in heads] - or name in ['sevennet-0', '7net-0'] - ): - checkpoint_path = _const.SEVENNET_0_11Jul2024 - elif name in [f'{n}-0_22may2024' for n in heads]: - checkpoint_path = _const.SEVENNET_0_22May2024 - elif name in [f'{n}-l3i5' for n in heads]: - checkpoint_path = _const.SEVENNET_l3i5 - elif name in [f'{n}-mf-0' for n in heads]: - checkpoint_path = _const.SEVENNET_MF_0 - elif name in [f'{n}-mf-ompa' for n in heads]: - checkpoint_path = _const.SEVENNET_MF_ompa - elif name in [f'{n}-omat' for n in heads]: - checkpoint_path = _const.SEVENNET_omat - else: - raise ValueError('Not a valid pretrained model name') - url = _const.CHECKPOINT_DOWNLOAD_LINKS.get(checkpoint_path) - - paths = [ - checkpoint_path, - checkpoint_path.replace(_const._prefix, osp.expanduser('~/.cache/sevennet')), - ] - - for path in paths: - if osp.exists(path): - return path - - # File not found check url and try download - if url is None: - raise FileNotFoundError(checkpoint_path) - - try: - return download_checkpoint(paths[0], url) # 7net package path - except PermissionError: - return download_checkpoint(paths[1], url) # ~/.cache - - -def load_checkpoint(checkpoint: Union[pathlib.Path, str]): - from sevenn.checkpoint import SevenNetCheckpoint - suggests = ['7net-0, 7net-l3i5, 7net-mf-ompa, 7net-omat'] - if osp.isfile(checkpoint): - checkpoint_path = checkpoint - else: - try: - checkpoint_path = pretrained_name_to_path(str(checkpoint)) - except ValueError: - raise ValueError( - f'Given {checkpoint} is not exists and not a pre-trained name.\n' - f'Valid pretrained model names: {suggests}' - ) - return SevenNetCheckpoint(checkpoint_path) - - -def unique_filepath(filepath: str) -> str: - if not os.path.isfile(filepath): - return filepath - else: - dirname = os.path.dirname(filepath) - fname = os.path.basename(filepath) - name, ext = os.path.splitext(fname) - cnt = 0 - new_name = f'{name}{cnt}{ext}' - new_path = os.path.join(dirname, new_name) - while os.path.exists(new_path): - cnt += 1 - new_name = f'{name}{cnt}{ext}' - new_path = os.path.join(dirname, new_name) - return new_path - - -def get_error_recorder( - recorder_tuples: List[Tuple[str, str]] = [ - ('Energy', 'RMSE'), - ('Force', 'RMSE'), - ('Stress', 'RMSE'), - ('Energy', 'MAE'), - ('Force', 'MAE'), - ('Stress', 'MAE'), - ], -): - # TODO add criterion argument and loss recorder selections - import sevenn.error_recorder as error_recorder - - config = recorder_tuples - err_metrics = [] - for err_type, metric_name in config: - metric_kwargs = error_recorder.get_err_type(err_type).copy() - metric_kwargs['name'] += f'_{metric_name}' - metric_cls = error_recorder.ErrorRecorder.METRIC_DICT[metric_name] - err_metrics.append(metric_cls(**metric_kwargs)) - return error_recorder.ErrorRecorder(err_metrics) +import os +import os.path as osp +import pathlib +import shutil +from typing import Dict, List, Tuple, Union + +import numpy as np +import requests +import torch +import torch.nn +from e3nn.o3 import FullTensorProduct, Irreps +from tqdm import tqdm + +import sevenn._const as _const +import sevenn._keys as KEY + + +def to_atom_graph_list(atom_graph_batch): + """ + torch_geometric batched data to separate list + original to_data_list() by PyG is not enough since + it doesn't handle inferred tensors + """ + is_stress = KEY.PRED_STRESS in atom_graph_batch + + data_list = atom_graph_batch.to_data_list() + + indices = atom_graph_batch[KEY.NUM_ATOMS].tolist() + + atomic_energy_list = torch.split(atom_graph_batch[KEY.ATOMIC_ENERGY], indices) + inferred_total_energy_list = torch.unbind( + atom_graph_batch[KEY.PRED_TOTAL_ENERGY] + ) + inferred_force_list = torch.split(atom_graph_batch[KEY.PRED_FORCE], indices) + + inferred_stress_list = None + if is_stress: + inferred_stress_list = torch.unbind(atom_graph_batch[KEY.PRED_STRESS]) + + for i, data in enumerate(data_list): + data[KEY.ATOMIC_ENERGY] = atomic_energy_list[i] + data[KEY.PRED_TOTAL_ENERGY] = inferred_total_energy_list[i] + data[KEY.PRED_FORCE] = inferred_force_list[i] + # To fit with KEY.STRESS (ref) format + if is_stress and inferred_stress_list is not None: + data[KEY.PRED_STRESS] = torch.unsqueeze(inferred_stress_list[i], 0) + return data_list + + +def error_recorder_from_loss_functions(loss_functions): + from .error_recorder import ErrorRecorder, MAError, RMSError, get_err_type + from .train.loss import ForceLoss, PerAtomEnergyLoss, StressLoss + + metrics = [] + for loss_function, _ in loss_functions: + ref_key = loss_function.ref_key + pred_key = loss_function.pred_key + # unit = loss_function.unit + criterion = loss_function.criterion + name = loss_function.name + base = None + if type(loss_function) is PerAtomEnergyLoss: + base = get_err_type('Energy') + elif type(loss_function) is ForceLoss: + base = get_err_type('Force') + elif type(loss_function) is StressLoss: + base = get_err_type('Stress') + else: + base = {} + base['name'] = name + base['ref_key'] = ref_key + base['pred_key'] = pred_key + if type(criterion) is torch.nn.MSELoss: + base['name'] = base['name'] + '_RMSE' + metrics.append(RMSError(**base)) + elif type(criterion) is torch.nn.L1Loss: + metrics.append(MAError(**base)) + return ErrorRecorder(metrics) + + +def onehot_to_chem(one_hot_indices: List[int], type_map: Dict[int, int]): + from ase.data import chemical_symbols + + type_map_rev = {v: k for k, v in type_map.items()} + return [chemical_symbols[type_map_rev[x]] for x in one_hot_indices] + + +def model_from_checkpoint( + checkpoint: str, +) -> Tuple[torch.nn.Module, Dict]: + cp = load_checkpoint(checkpoint) + model = cp.build_model() + + return model, cp.config + + +def model_from_checkpoint_with_backend( + checkpoint: str, + backend: str = 'e3nn', +) -> Tuple[torch.nn.Module, Dict]: + cp = load_checkpoint(checkpoint) + model = cp.build_model(backend) + + return model, cp.config + + +def unlabeled_atoms_to_input(atoms, cutoff: float, grad_key: str = KEY.EDGE_VEC): + from .atom_graph_data import AtomGraphData + from .train.dataload import unlabeled_atoms_to_graph + + atom_graph = AtomGraphData.from_numpy_dict( + unlabeled_atoms_to_graph(atoms, cutoff) + ) + atom_graph[grad_key].requires_grad_(True) + atom_graph[KEY.BATCH] = torch.zeros([0]) + return atom_graph + + +def chemical_species_preprocess(input_chem: List[str], universal: bool = False): + from ase.data import atomic_numbers, chemical_symbols + + from .nn.node_embedding import get_type_mapper_from_specie + + config = {} + if not universal: + input_chem = list(set(input_chem)) + chemical_specie = sorted([x.strip() for x in input_chem]) + config[KEY.CHEMICAL_SPECIES] = chemical_specie + config[KEY.CHEMICAL_SPECIES_BY_ATOMIC_NUMBER] = [ + atomic_numbers[x] for x in chemical_specie + ] + config[KEY.NUM_SPECIES] = len(chemical_specie) + config[KEY.TYPE_MAP] = get_type_mapper_from_specie(chemical_specie) + else: + config[KEY.CHEMICAL_SPECIES] = chemical_symbols + len_univ = len(chemical_symbols) + config[KEY.CHEMICAL_SPECIES_BY_ATOMIC_NUMBER] = list(range(len_univ)) + config[KEY.NUM_SPECIES] = len_univ + config[KEY.TYPE_MAP] = {z: z for z in range(len_univ)} + return config + + +def dtype_correct( + v: Union[np.ndarray, torch.Tensor, int, float], + float_dtype: torch.dtype = torch.float32, + int_dtype: torch.dtype = torch.int64, +): + if isinstance(v, np.ndarray): + if np.issubdtype(v.dtype, np.floating): + return torch.from_numpy(v).to(float_dtype) + elif np.issubdtype(v.dtype, np.integer): + return torch.from_numpy(v).to(int_dtype) + elif isinstance(v, torch.Tensor): + if v.dtype.is_floating_point: + return v.to(float_dtype) # convert to specified float dtype + else: # assuming non-floating point tensors are integers + return v.to(int_dtype) # convert to specified int dtype + else: # scalar values + if isinstance(v, int): + return torch.tensor(v, dtype=int_dtype) + elif isinstance(v, float): + return torch.tensor(v, dtype=float_dtype) + else: # Not numeric + return v + + +def infer_irreps_out( + irreps_x: Irreps, + irreps_operand: Irreps, + drop_l: Union[bool, int] = False, + parity_mode: str = 'full', + fix_multiplicity: Union[bool, int] = False, +): + assert parity_mode in ['full', 'even', 'sph'] + # (mul, (ir, p)) + irreps_out = FullTensorProduct(irreps_x, irreps_operand).irreps_out.simplify() + new_irreps_elem = [] + for mul, (l, p) in irreps_out: # noqa + elem = (mul, (l, p)) + if drop_l is not False and l > drop_l: + continue + if parity_mode == 'even' and p == -1: + continue + elif parity_mode == 'sph' and p != (-1) ** l: + continue + if fix_multiplicity: + elem = (fix_multiplicity, (l, p)) + new_irreps_elem.append(elem) + return Irreps(new_irreps_elem) + + +def download_checkpoint(path: str, url: str): + fname = osp.basename(path) + temp_path = path + '.partial' + try: + # raises permission error if fails + os.makedirs(osp.dirname(path), exist_ok=True) + response = requests.get(url, stream=True, timeout=30) + response.raise_for_status() # Raise exception for bad status codes + + total_size = int(response.headers.get('content-length', 0)) + block_size = 1024 # 1 KB chunks + + progress_bar = tqdm( + total=total_size, + unit='B', + unit_scale=True, + desc=f'Downloading {fname}', + ) + + with open(temp_path, 'wb') as file: + for data in response.iter_content(block_size): + progress_bar.update(len(data)) + file.write(data) + progress_bar.close() + + shutil.move(temp_path, path) + print(f'Checkpoint downloaded: {path}') + return path + except PermissionError: + raise + except Exception as e: + # Clean up partial downloads on failure + # May not work as errors handled internally by tqdm etc. + print(f'Download failed: {str(e)}') + if os.path.exists(temp_path): + print(f'Cleaning up partial download: {temp_path}') + os.remove(temp_path) + raise + + +def pretrained_name_to_path(name: str) -> str: + name = name.lower() + heads = ['sevennet', '7net'] + checkpoint_path = None + url = None + + if ( # TODO: regex + name in [f'{n}-0_11july2024' for n in heads] + or name in [f'{n}-0_11jul2024' for n in heads] + or name in ['sevennet-0', '7net-0'] + ): + checkpoint_path = _const.SEVENNET_0_11Jul2024 + elif name in [f'{n}-0_22may2024' for n in heads]: + checkpoint_path = _const.SEVENNET_0_22May2024 + elif name in [f'{n}-l3i5' for n in heads]: + checkpoint_path = _const.SEVENNET_l3i5 + elif name in [f'{n}-mf-0' for n in heads]: + checkpoint_path = _const.SEVENNET_MF_0 + elif name in [f'{n}-mf-ompa' for n in heads]: + checkpoint_path = _const.SEVENNET_MF_ompa + elif name in [f'{n}-omat' for n in heads]: + checkpoint_path = _const.SEVENNET_omat + else: + raise ValueError('Not a valid pretrained model name') + url = _const.CHECKPOINT_DOWNLOAD_LINKS.get(checkpoint_path) + + paths = [ + checkpoint_path, + checkpoint_path.replace(_const._prefix, osp.expanduser('~/.cache/sevennet')), + ] + + for path in paths: + if osp.exists(path): + return path + + # File not found check url and try download + if url is None: + raise FileNotFoundError(checkpoint_path) + + try: + return download_checkpoint(paths[0], url) # 7net package path + except PermissionError: + return download_checkpoint(paths[1], url) # ~/.cache + + +def load_checkpoint(checkpoint: Union[pathlib.Path, str]): + from sevenn.checkpoint import SevenNetCheckpoint + suggests = ['7net-0, 7net-l3i5, 7net-mf-ompa, 7net-omat'] + if osp.isfile(checkpoint): + checkpoint_path = checkpoint + else: + try: + checkpoint_path = pretrained_name_to_path(str(checkpoint)) + except ValueError: + raise ValueError( + f'Given {checkpoint} is not exists and not a pre-trained name.\n' + f'Valid pretrained model names: {suggests}' + ) + return SevenNetCheckpoint(checkpoint_path) + + +def unique_filepath(filepath: str) -> str: + if not os.path.isfile(filepath): + return filepath + else: + dirname = os.path.dirname(filepath) + fname = os.path.basename(filepath) + name, ext = os.path.splitext(fname) + cnt = 0 + new_name = f'{name}{cnt}{ext}' + new_path = os.path.join(dirname, new_name) + while os.path.exists(new_path): + cnt += 1 + new_name = f'{name}{cnt}{ext}' + new_path = os.path.join(dirname, new_name) + return new_path + + +def get_error_recorder( + recorder_tuples: List[Tuple[str, str]] = [ + ('Energy', 'RMSE'), + ('Force', 'RMSE'), + ('Stress', 'RMSE'), + ('Energy', 'MAE'), + ('Force', 'MAE'), + ('Stress', 'MAE'), + ], +): + # TODO add criterion argument and loss recorder selections + import sevenn.error_recorder as error_recorder + + config = recorder_tuples + err_metrics = [] + for err_type, metric_name in config: + metric_kwargs = error_recorder.get_err_type(err_type).copy() + metric_kwargs['name'] += f'_{metric_name}' + metric_cls = error_recorder.ErrorRecorder.METRIC_DICT[metric_name] + err_metrics.append(metric_cls(**metric_kwargs)) + return error_recorder.ErrorRecorder(err_metrics) diff --git a/mace-bench/3rdparty/SevenNet/tests/data/inferences/snet0_on_hfo2/errors.txt b/mace-bench/3rdparty/SevenNet/tests/data/inferences/snet0_on_hfo2/errors.txt index 81bd34c..a7123ab 100644 --- a/mace-bench/3rdparty/SevenNet/tests/data/inferences/snet0_on_hfo2/errors.txt +++ b/mace-bench/3rdparty/SevenNet/tests/data/inferences/snet0_on_hfo2/errors.txt @@ -1,6 +1,6 @@ -Energy_RMSE (eV/atom): 18.84848889028682 -Force_RMSE (eV/Å): 0.2622841142173583 -Stress_RMSE (kbar): 163.7362768581691 -Energy_MAE (eV/atom): 18.848487854003906 -Force_MAE (eV/Å): 0.116698424021403 -Stress_MAE (kbar): 47.33086649576823 +Energy_RMSE (eV/atom): 18.84848889028682 +Force_RMSE (eV/Å): 0.2622841142173583 +Stress_RMSE (kbar): 163.7362768581691 +Energy_MAE (eV/atom): 18.848487854003906 +Force_MAE (eV/Å): 0.116698424021403 +Stress_MAE (kbar): 47.33086649576823 diff --git a/mace-bench/3rdparty/SevenNet/tests/lammps_tests/conftest.py b/mace-bench/3rdparty/SevenNet/tests/lammps_tests/conftest.py index d4efed3..8b7257d 100644 --- a/mace-bench/3rdparty/SevenNet/tests/lammps_tests/conftest.py +++ b/mace-bench/3rdparty/SevenNet/tests/lammps_tests/conftest.py @@ -1,24 +1,24 @@ -import pytest - - -def pytest_addoption(parser): - parser.addoption('--lammps_cmd', default=None, help='Lammps binary to test') - parser.addoption( - '--mpirun_cmd', default=None, help='mpirun binary to test parallel' - ) - - -@pytest.fixture -def lammps_cmd(request): - bin = request.config.getoption('lammps_cmd') - if bin is None: - pytest.skip('No LAMMPS binary given, skipping test') - return bin - - -@pytest.fixture -def mpirun_cmd(request): - bin = request.config.getoption('mpirun_cmd') - if bin is None: - pytest.skip('No mpirun cmd given, skipping test') - return bin +import pytest + + +def pytest_addoption(parser): + parser.addoption('--lammps_cmd', default=None, help='Lammps binary to test') + parser.addoption( + '--mpirun_cmd', default=None, help='mpirun binary to test parallel' + ) + + +@pytest.fixture +def lammps_cmd(request): + bin = request.config.getoption('lammps_cmd') + if bin is None: + pytest.skip('No LAMMPS binary given, skipping test') + return bin + + +@pytest.fixture +def mpirun_cmd(request): + bin = request.config.getoption('mpirun_cmd') + if bin is None: + pytest.skip('No mpirun cmd given, skipping test') + return bin diff --git a/mace-bench/3rdparty/SevenNet/tests/lammps_tests/test_lammps.py b/mace-bench/3rdparty/SevenNet/tests/lammps_tests/test_lammps.py index 25b0614..54f53a5 100644 --- a/mace-bench/3rdparty/SevenNet/tests/lammps_tests/test_lammps.py +++ b/mace-bench/3rdparty/SevenNet/tests/lammps_tests/test_lammps.py @@ -1,467 +1,467 @@ -import copy -import logging -import pathlib -import subprocess - -import ase.calculators.lammps -import ase.io.lammpsdata -import numpy as np -import pytest -import torch -from ase.build import bulk, surface -from ase.calculators.singlepoint import SinglePointCalculator - -import sevenn -from sevenn.calculator import SevenNetCalculator -from sevenn.model_build import build_E3_equivariant_model -from sevenn.nn.cue_helper import is_cue_available -from sevenn.scripts.deploy import deploy, deploy_parallel -from sevenn.util import chemical_species_preprocess, pretrained_name_to_path - -logger = logging.getLogger('test_lammps') - -cutoff = 4.0 - -lmp_script_path = str( - (pathlib.Path(__file__).parent / 'scripts' / 'skel.lmp').resolve() -) - -data_root = (pathlib.Path(__file__).parent.parent / 'data').resolve() -cp_0_path = str(data_root / 'checkpoints' / 'cp_0.pth') # knows Hf, O -cp_mf_path = pretrained_name_to_path('7net-mf-0') - - -@pytest.fixture(scope='module') -def serial_potential_path(tmp_path_factory): - tmp = tmp_path_factory.mktemp('serial_potential') - pot_path = str(tmp / 'deployed_serial.pt') - deploy(cp_0_path, pot_path) - return pot_path - - -@pytest.fixture(scope='module') -def parallel_potential_path(tmp_path_factory): - tmp = tmp_path_factory.mktemp('paralllel_potential') - pot_path = str(tmp / 'deployed_parallel') - deploy_parallel(cp_0_path, pot_path) - return ' '.join(['3', pot_path]) - - -@pytest.fixture(scope='module') -def serial_modal_potential_path(tmp_path_factory): - tmp = tmp_path_factory.mktemp('serial_modal_potential') - pot_path = str(tmp / 'deployed_serial.pt') - deploy(cp_mf_path, pot_path, 'PBE') - return pot_path - - -@pytest.fixture(scope='module') -def parallel_modal_potential_path(tmp_path_factory): - tmp = tmp_path_factory.mktemp('paralllel_modal_potential') - pot_path = str(tmp / 'deployed_parallel') - deploy_parallel(cp_mf_path, pot_path, 'PBE') - return ' '.join(['5', pot_path]) - - -@pytest.fixture(scope='module') -def ref_calculator(): - return SevenNetCalculator(cp_0_path) - - -@pytest.fixture(scope='module') -def ref_modal_calculator(): - return SevenNetCalculator(cp_mf_path, modal='PBE') - - -def get_model_config(): - config = { - 'cutoff': cutoff, - 'channel': 8, - 'lmax': 2, - 'is_parity': True, - 'num_convolution_layer': 3, - 'self_connection_type': 'linear', # not NequIp - 'interaction_type': 'nequip', - 'radial_basis': { - 'radial_basis_name': 'bessel', - }, - 'cutoff_function': {'cutoff_function_name': 'poly_cut'}, - 'weight_nn_hidden_neurons': [64, 64], - 'act_radial': 'silu', - 'act_scalar': {'e': 'silu', 'o': 'tanh'}, - 'act_gate': {'e': 'silu', 'o': 'tanh'}, - 'conv_denominator': 30.0, - 'train_denominator': False, - 'shift': -10.0, - 'scale': 10.0, - 'train_shift_scale': False, - 'irreps_manual': False, - 'lmax_edge': -1, - 'lmax_node': -1, - 'readout_as_fcn': False, - 'use_bias_in_linear': False, - '_normalize_sph': True, - } - config.update(chemical_species_preprocess(['Hf', 'O'])) - return config - - -def get_model(config_overwrite=None, use_cueq=False, cueq_config=None): - cf = get_model_config() - if config_overwrite is not None: - cf.update(config_overwrite) - - cueq_config = cueq_config or {'cuequivariance_config': {'use': use_cueq}} - cf.update(cueq_config) - - model = build_E3_equivariant_model(cf, parallel=False) - assert not isinstance(model, list) - return model - - -def hfo2_bulk(replicate=(2, 2, 2), a=4.0): - atoms = bulk('HfO', 'rocksalt', a, orthorhombic=True) - atoms = atoms * replicate - atoms.rattle(stdev=0.10) - return atoms - - -def hf_surface(replicate=(3, 3, 1), layers=4, vacuum=0.5): - atoms = surface('Al', (1, 0, 0), layers=layers, vacuum=vacuum) - atoms.set_atomic_numbers([72] * len(atoms)) # Hf - atoms = atoms * replicate - atoms.rattle(stdev=0.10) - return atoms - - -def get_system(system_name, **kwargs): - if system_name == 'bulk': - return hfo2_bulk(**kwargs) - elif system_name == 'surface': - return hf_surface(**kwargs) - else: - raise ValueError() - - -def assert_atoms(atoms1, atoms2, rtol=1e-5, atol=1e-6): - def acl(a, b, rtol=rtol, atol=atol): - return np.allclose(a, b, rtol=rtol, atol=atol) - - assert len(atoms1) == len(atoms2) - assert acl(atoms1.get_cell(), atoms2.get_cell()) - assert acl(atoms1.get_potential_energy(), atoms2.get_potential_energy()) - assert acl(atoms1.get_forces(), atoms2.get_forces(), rtol * 10, atol * 10) - assert acl( - atoms1.get_stress(voigt=False), - atoms2.get_stress(voigt=False), - rtol * 10, - atol * 10, - ) - # assert acl(atoms1.get_potential_energies(), atoms2.get_potential_energies()) - - -def _lammps_results_to_atoms(lammps_log, force_dump): - with open(lammps_log, 'r') as f: - lines = f.readlines() - lmp_log = None - for i, line in enumerate(lines): - if not line.startswith('Per MPI rank memory allocation'): - continue - lmp_log = { - k: eval(v) for k, v in zip(lines[i + 1].split(), lines[i + 2].split()) - } - break - - assert lmp_log is not None and 'PotEng' in lmp_log - - latoms_list = ase.io.read(force_dump, format='lammps-dump-text', index=':') - assert isinstance(latoms_list, list) - latoms = latoms_list[0] - assert latoms.calc is not None - latoms.calc.results['energy'] = lmp_log['PotEng'] - latoms.calc.results['free_energy'] = lmp_log['PotEng'] - latoms.info = { - 'data_from': 'lammps', - 'lmp_log': lmp_log, - 'lmp_dump': force_dump, - } - # atomic energy read - latoms.calc.results['energies'] = latoms.arrays['c_pa'][:, 0] - stress = np.array( - [ - [lmp_log['Pxx'], lmp_log['Pxy'], lmp_log['Pxz']], - [lmp_log['Pxy'], lmp_log['Pyy'], lmp_log['Pyz']], - [lmp_log['Pxz'], lmp_log['Pyz'], lmp_log['Pzz']], - ] - ) - stress = -1 * stress / 1602.1766208 / 1000 # convert bars to eV/A^3 - latoms.calc.results['stress'] = stress - - return latoms - - -def _run_lammps(atoms, pair_style, potential, wd, command, test_name): - wd = wd.resolve() - pbc = atoms.get_pbc() - pbc_str = ' '.join(['p' if x else 'f' for x in pbc]) - chem = list(set(atoms.get_chemical_symbols())) - # Way to ase handle lammps structure - - prism = ase.calculators.lammps.coordinatetransform.Prism( - atoms.get_cell(), pbc=pbc - ) - lmp_stct = wd / 'lammps_structure' - ase.io.lammpsdata.write_lammps_data( - lmp_stct, atoms, prismobj=prism, specorder=chem - ) - - with open(lmp_script_path, 'r') as f: - cont = f.read() - - lammps_log = str(wd / 'log.lammps') - force_dump = str(wd / 'force.dump') - - var_dct = {} - var_dct['__ELEMENT__'] = ' '.join(chem) - var_dct['__LMP_STCT__'] = str(lmp_stct.resolve()) - var_dct['__PAIR_STYLE__'] = pair_style - var_dct['__POTENTIALS__'] = potential - var_dct['__BOUNDARY__'] = pbc_str - var_dct['__FORCE_DUMP_PATH__'] = force_dump - for key, val in var_dct.items(): - cont = cont.replace(key, val) - - input_script_path = str(wd / 'in.lmp') - with open(input_script_path, 'w') as f: - f.write(cont) - - command = f'{command} -in {input_script_path} -log {lammps_log}' - subprocess_routine(command.split(), test_name) - - lmp_atoms = _lammps_results_to_atoms(lammps_log, force_dump) - assert lmp_atoms.calc is not None - - rot_mat = prism.rot_mat - results = copy.deepcopy(lmp_atoms.calc.results) - r_force = np.dot(results['forces'], rot_mat.T) - results['forces'] = r_force - if 'stress' in results: - # see ase.calculators.lammpsrun.py - stress_tensor = results['stress'] - stress_atoms = np.dot(np.dot(rot_mat, stress_tensor), rot_mat.T) - results['stress'] = stress_atoms - r_cell = lmp_atoms.get_cell() @ rot_mat.T - lmp_atoms.set_cell(r_cell, scale_atoms=True) - lmp_atoms = SinglePointCalculator(lmp_atoms, **results).get_atoms() - - return lmp_atoms - - -def serial_lammps_run(atoms, potential, wd, test_name, lammps_cmd): - command = lammps_cmd - return _run_lammps(atoms, 'e3gnn', potential, wd, command, test_name) - - -def parallel_lammps_run( - atoms, potential, wd, test_name, ncores, lammps_cmd, mpirun_cmd -): - command = f'{mpirun_cmd} -np {ncores} {lammps_cmd}' - return _run_lammps(atoms, 'e3gnn/parallel', potential, wd, command, test_name) - - -def subprocess_routine(cmd, name): - res = subprocess.run(cmd, capture_output=True, timeout=30) - if res.returncode != 0: - logger.error(f'Subprocess {name} failed return code: {res.returncode}') - logger.error(res.stderr.decode('utf-8')) - raise RuntimeError(f'{name} failed') - - logger.info(f'stdout of {name}:') - logger.info(res.stdout.decode('utf-8')) - - -@pytest.mark.parametrize( - 'system', - ['bulk', 'surface'], -) -def test_serial(system, serial_potential_path, ref_calculator, lammps_cmd, tmp_path): - atoms = get_system(system) - atoms_lammps = serial_lammps_run( - atoms=atoms, - potential=serial_potential_path, - wd=tmp_path, - test_name='serial lmp test', - lammps_cmd=lammps_cmd, - ) - atoms.calc = ref_calculator - assert_atoms(atoms, atoms_lammps) - - -@pytest.mark.parametrize( - 'system,ncores', - [ - ('bulk', 1), - ('bulk', 2), - ('bulk', 4), - ('surface', 1), - ('surface', 2), - ('surface', 3), - ('surface', 4), - ], -) -def test_parallel( - system, - ncores, - parallel_potential_path, - ref_calculator, - lammps_cmd, - mpirun_cmd, - tmp_path, -): - if system == 'bulk': - rep = (6, 6, 3) - elif system == 'surface': - rep = (4, 4, 1) - else: - assert False - atoms = get_system(system, replicate=rep) - atoms_lammps = parallel_lammps_run( - atoms=atoms, - potential=parallel_potential_path, - wd=tmp_path, - test_name='parallel lmp test', - lammps_cmd=lammps_cmd, - mpirun_cmd=mpirun_cmd, - ncores=ncores, - ) - atoms.calc = ref_calculator - assert_atoms(atoms, atoms_lammps) - - -@pytest.mark.parametrize( - 'system', - ['bulk', 'surface'], -) -def test_modal_serial( - system, serial_modal_potential_path, ref_modal_calculator, lammps_cmd, tmp_path -): - atoms = get_system(system) - atoms_lammps = serial_lammps_run( - atoms=atoms, - potential=serial_modal_potential_path, - wd=tmp_path, - test_name='serial lmp test', - lammps_cmd=lammps_cmd, - ) - atoms.calc = ref_modal_calculator - assert_atoms(atoms, atoms_lammps) - - -@pytest.mark.parametrize( - 'system,ncores', - [ - ('bulk', 2), - ('surface', 2), - ], -) -def test_modal_parallel( - system, - ncores, - parallel_modal_potential_path, - ref_modal_calculator, - lammps_cmd, - mpirun_cmd, - tmp_path, -): - if system == 'bulk': - rep = (6, 6, 3) - elif system == 'surface': - rep = (4, 4, 1) - else: - assert False - atoms = get_system(system, replicate=rep) - atoms_lammps = parallel_lammps_run( - atoms=atoms, - potential=parallel_modal_potential_path, - wd=tmp_path, - test_name='parallel lmp test', - lammps_cmd=lammps_cmd, - mpirun_cmd=mpirun_cmd, - ncores=ncores, - ) - atoms.calc = ref_modal_calculator - assert_atoms(atoms, atoms_lammps) - - -@pytest.mark.filterwarnings('ignore:.*is not found from.*') -@pytest.mark.skipif(not is_cue_available(), reason='cueq not available') -def test_cueq_serial(lammps_cmd, tmp_path): - """ - TODO: Use already saved cueq enabled checkpoint after cueq becomes stable - """ - cueq = True - model = get_model(use_cueq=cueq) - ref_calc = SevenNetCalculator(model, file_type='model_instance') - atoms = get_system('bulk') - - cfg = get_model_config() - cfg.update( - {'cuequivariance_config': {'use': cueq}, 'version': sevenn.__version__} - ) - - cp_path = str(tmp_path / 'cp.pth') - torch.save( - {'model_state_dict': model.state_dict(), 'config': cfg}, - cp_path, - ) - - pot_path = str(tmp_path / 'deployed_from_cueq_serial.pt') - deploy(cp_path, pot_path) - - atoms_lammps = serial_lammps_run( - atoms=atoms, - potential=pot_path, - wd=tmp_path, - test_name='cueq checkpoint serial lmp run test', - lammps_cmd=lammps_cmd, - ) - atoms.calc = ref_calc - assert_atoms(atoms, atoms_lammps) - - -@pytest.mark.filterwarnings('ignore:.*is not found from.*') -@pytest.mark.skipif(not is_cue_available(), reason='cueq not available') -def test_cueq_parallel(lammps_cmd, mpirun_cmd, tmp_path): - """ - TODO: Use already saved cueq enabled checkpoint after cueq becomes stable - """ - cueq = True - model = get_model(use_cueq=cueq) - ref_calc = SevenNetCalculator(model, file_type='model_instance') - atoms = get_system('surface', replicate=(4, 4, 1)) - - cfg = get_model_config() - cfg.update( - {'cuequivariance_config': {'use': cueq}, 'version': sevenn.__version__} - ) - - cp_path = str(tmp_path / 'cp.pth') - torch.save( - {'model_state_dict': model.state_dict(), 'config': cfg}, - cp_path, - ) - - pot_path = str(tmp_path / 'deployed_from_cueq_parallel') - deploy_parallel(cp_path, pot_path) - - atoms_lammps = parallel_lammps_run( - atoms=atoms, - potential=' '.join([str(cfg['num_convolution_layer']), pot_path]), - wd=tmp_path, - test_name='cueq checkpoint parallel lmp run test', - lammps_cmd=lammps_cmd, - mpirun_cmd=mpirun_cmd, - ncores=2, - ) - atoms.calc = ref_calc - assert_atoms(atoms, atoms_lammps) +import copy +import logging +import pathlib +import subprocess + +import ase.calculators.lammps +import ase.io.lammpsdata +import numpy as np +import pytest +import torch +from ase.build import bulk, surface +from ase.calculators.singlepoint import SinglePointCalculator + +import sevenn +from sevenn.calculator import SevenNetCalculator +from sevenn.model_build import build_E3_equivariant_model +from sevenn.nn.cue_helper import is_cue_available +from sevenn.scripts.deploy import deploy, deploy_parallel +from sevenn.util import chemical_species_preprocess, pretrained_name_to_path + +logger = logging.getLogger('test_lammps') + +cutoff = 4.0 + +lmp_script_path = str( + (pathlib.Path(__file__).parent / 'scripts' / 'skel.lmp').resolve() +) + +data_root = (pathlib.Path(__file__).parent.parent / 'data').resolve() +cp_0_path = str(data_root / 'checkpoints' / 'cp_0.pth') # knows Hf, O +cp_mf_path = pretrained_name_to_path('7net-mf-0') + + +@pytest.fixture(scope='module') +def serial_potential_path(tmp_path_factory): + tmp = tmp_path_factory.mktemp('serial_potential') + pot_path = str(tmp / 'deployed_serial.pt') + deploy(cp_0_path, pot_path) + return pot_path + + +@pytest.fixture(scope='module') +def parallel_potential_path(tmp_path_factory): + tmp = tmp_path_factory.mktemp('paralllel_potential') + pot_path = str(tmp / 'deployed_parallel') + deploy_parallel(cp_0_path, pot_path) + return ' '.join(['3', pot_path]) + + +@pytest.fixture(scope='module') +def serial_modal_potential_path(tmp_path_factory): + tmp = tmp_path_factory.mktemp('serial_modal_potential') + pot_path = str(tmp / 'deployed_serial.pt') + deploy(cp_mf_path, pot_path, 'PBE') + return pot_path + + +@pytest.fixture(scope='module') +def parallel_modal_potential_path(tmp_path_factory): + tmp = tmp_path_factory.mktemp('paralllel_modal_potential') + pot_path = str(tmp / 'deployed_parallel') + deploy_parallel(cp_mf_path, pot_path, 'PBE') + return ' '.join(['5', pot_path]) + + +@pytest.fixture(scope='module') +def ref_calculator(): + return SevenNetCalculator(cp_0_path) + + +@pytest.fixture(scope='module') +def ref_modal_calculator(): + return SevenNetCalculator(cp_mf_path, modal='PBE') + + +def get_model_config(): + config = { + 'cutoff': cutoff, + 'channel': 8, + 'lmax': 2, + 'is_parity': True, + 'num_convolution_layer': 3, + 'self_connection_type': 'linear', # not NequIp + 'interaction_type': 'nequip', + 'radial_basis': { + 'radial_basis_name': 'bessel', + }, + 'cutoff_function': {'cutoff_function_name': 'poly_cut'}, + 'weight_nn_hidden_neurons': [64, 64], + 'act_radial': 'silu', + 'act_scalar': {'e': 'silu', 'o': 'tanh'}, + 'act_gate': {'e': 'silu', 'o': 'tanh'}, + 'conv_denominator': 30.0, + 'train_denominator': False, + 'shift': -10.0, + 'scale': 10.0, + 'train_shift_scale': False, + 'irreps_manual': False, + 'lmax_edge': -1, + 'lmax_node': -1, + 'readout_as_fcn': False, + 'use_bias_in_linear': False, + '_normalize_sph': True, + } + config.update(chemical_species_preprocess(['Hf', 'O'])) + return config + + +def get_model(config_overwrite=None, use_cueq=False, cueq_config=None): + cf = get_model_config() + if config_overwrite is not None: + cf.update(config_overwrite) + + cueq_config = cueq_config or {'cuequivariance_config': {'use': use_cueq}} + cf.update(cueq_config) + + model = build_E3_equivariant_model(cf, parallel=False) + assert not isinstance(model, list) + return model + + +def hfo2_bulk(replicate=(2, 2, 2), a=4.0): + atoms = bulk('HfO', 'rocksalt', a, orthorhombic=True) + atoms = atoms * replicate + atoms.rattle(stdev=0.10) + return atoms + + +def hf_surface(replicate=(3, 3, 1), layers=4, vacuum=0.5): + atoms = surface('Al', (1, 0, 0), layers=layers, vacuum=vacuum) + atoms.set_atomic_numbers([72] * len(atoms)) # Hf + atoms = atoms * replicate + atoms.rattle(stdev=0.10) + return atoms + + +def get_system(system_name, **kwargs): + if system_name == 'bulk': + return hfo2_bulk(**kwargs) + elif system_name == 'surface': + return hf_surface(**kwargs) + else: + raise ValueError() + + +def assert_atoms(atoms1, atoms2, rtol=1e-5, atol=1e-6): + def acl(a, b, rtol=rtol, atol=atol): + return np.allclose(a, b, rtol=rtol, atol=atol) + + assert len(atoms1) == len(atoms2) + assert acl(atoms1.get_cell(), atoms2.get_cell()) + assert acl(atoms1.get_potential_energy(), atoms2.get_potential_energy()) + assert acl(atoms1.get_forces(), atoms2.get_forces(), rtol * 10, atol * 10) + assert acl( + atoms1.get_stress(voigt=False), + atoms2.get_stress(voigt=False), + rtol * 10, + atol * 10, + ) + # assert acl(atoms1.get_potential_energies(), atoms2.get_potential_energies()) + + +def _lammps_results_to_atoms(lammps_log, force_dump): + with open(lammps_log, 'r') as f: + lines = f.readlines() + lmp_log = None + for i, line in enumerate(lines): + if not line.startswith('Per MPI rank memory allocation'): + continue + lmp_log = { + k: eval(v) for k, v in zip(lines[i + 1].split(), lines[i + 2].split()) + } + break + + assert lmp_log is not None and 'PotEng' in lmp_log + + latoms_list = ase.io.read(force_dump, format='lammps-dump-text', index=':') + assert isinstance(latoms_list, list) + latoms = latoms_list[0] + assert latoms.calc is not None + latoms.calc.results['energy'] = lmp_log['PotEng'] + latoms.calc.results['free_energy'] = lmp_log['PotEng'] + latoms.info = { + 'data_from': 'lammps', + 'lmp_log': lmp_log, + 'lmp_dump': force_dump, + } + # atomic energy read + latoms.calc.results['energies'] = latoms.arrays['c_pa'][:, 0] + stress = np.array( + [ + [lmp_log['Pxx'], lmp_log['Pxy'], lmp_log['Pxz']], + [lmp_log['Pxy'], lmp_log['Pyy'], lmp_log['Pyz']], + [lmp_log['Pxz'], lmp_log['Pyz'], lmp_log['Pzz']], + ] + ) + stress = -1 * stress / 1602.1766208 / 1000 # convert bars to eV/A^3 + latoms.calc.results['stress'] = stress + + return latoms + + +def _run_lammps(atoms, pair_style, potential, wd, command, test_name): + wd = wd.resolve() + pbc = atoms.get_pbc() + pbc_str = ' '.join(['p' if x else 'f' for x in pbc]) + chem = list(set(atoms.get_chemical_symbols())) + # Way to ase handle lammps structure + + prism = ase.calculators.lammps.coordinatetransform.Prism( + atoms.get_cell(), pbc=pbc + ) + lmp_stct = wd / 'lammps_structure' + ase.io.lammpsdata.write_lammps_data( + lmp_stct, atoms, prismobj=prism, specorder=chem + ) + + with open(lmp_script_path, 'r') as f: + cont = f.read() + + lammps_log = str(wd / 'log.lammps') + force_dump = str(wd / 'force.dump') + + var_dct = {} + var_dct['__ELEMENT__'] = ' '.join(chem) + var_dct['__LMP_STCT__'] = str(lmp_stct.resolve()) + var_dct['__PAIR_STYLE__'] = pair_style + var_dct['__POTENTIALS__'] = potential + var_dct['__BOUNDARY__'] = pbc_str + var_dct['__FORCE_DUMP_PATH__'] = force_dump + for key, val in var_dct.items(): + cont = cont.replace(key, val) + + input_script_path = str(wd / 'in.lmp') + with open(input_script_path, 'w') as f: + f.write(cont) + + command = f'{command} -in {input_script_path} -log {lammps_log}' + subprocess_routine(command.split(), test_name) + + lmp_atoms = _lammps_results_to_atoms(lammps_log, force_dump) + assert lmp_atoms.calc is not None + + rot_mat = prism.rot_mat + results = copy.deepcopy(lmp_atoms.calc.results) + r_force = np.dot(results['forces'], rot_mat.T) + results['forces'] = r_force + if 'stress' in results: + # see ase.calculators.lammpsrun.py + stress_tensor = results['stress'] + stress_atoms = np.dot(np.dot(rot_mat, stress_tensor), rot_mat.T) + results['stress'] = stress_atoms + r_cell = lmp_atoms.get_cell() @ rot_mat.T + lmp_atoms.set_cell(r_cell, scale_atoms=True) + lmp_atoms = SinglePointCalculator(lmp_atoms, **results).get_atoms() + + return lmp_atoms + + +def serial_lammps_run(atoms, potential, wd, test_name, lammps_cmd): + command = lammps_cmd + return _run_lammps(atoms, 'e3gnn', potential, wd, command, test_name) + + +def parallel_lammps_run( + atoms, potential, wd, test_name, ncores, lammps_cmd, mpirun_cmd +): + command = f'{mpirun_cmd} -np {ncores} {lammps_cmd}' + return _run_lammps(atoms, 'e3gnn/parallel', potential, wd, command, test_name) + + +def subprocess_routine(cmd, name): + res = subprocess.run(cmd, capture_output=True, timeout=30) + if res.returncode != 0: + logger.error(f'Subprocess {name} failed return code: {res.returncode}') + logger.error(res.stderr.decode('utf-8')) + raise RuntimeError(f'{name} failed') + + logger.info(f'stdout of {name}:') + logger.info(res.stdout.decode('utf-8')) + + +@pytest.mark.parametrize( + 'system', + ['bulk', 'surface'], +) +def test_serial(system, serial_potential_path, ref_calculator, lammps_cmd, tmp_path): + atoms = get_system(system) + atoms_lammps = serial_lammps_run( + atoms=atoms, + potential=serial_potential_path, + wd=tmp_path, + test_name='serial lmp test', + lammps_cmd=lammps_cmd, + ) + atoms.calc = ref_calculator + assert_atoms(atoms, atoms_lammps) + + +@pytest.mark.parametrize( + 'system,ncores', + [ + ('bulk', 1), + ('bulk', 2), + ('bulk', 4), + ('surface', 1), + ('surface', 2), + ('surface', 3), + ('surface', 4), + ], +) +def test_parallel( + system, + ncores, + parallel_potential_path, + ref_calculator, + lammps_cmd, + mpirun_cmd, + tmp_path, +): + if system == 'bulk': + rep = (6, 6, 3) + elif system == 'surface': + rep = (4, 4, 1) + else: + assert False + atoms = get_system(system, replicate=rep) + atoms_lammps = parallel_lammps_run( + atoms=atoms, + potential=parallel_potential_path, + wd=tmp_path, + test_name='parallel lmp test', + lammps_cmd=lammps_cmd, + mpirun_cmd=mpirun_cmd, + ncores=ncores, + ) + atoms.calc = ref_calculator + assert_atoms(atoms, atoms_lammps) + + +@pytest.mark.parametrize( + 'system', + ['bulk', 'surface'], +) +def test_modal_serial( + system, serial_modal_potential_path, ref_modal_calculator, lammps_cmd, tmp_path +): + atoms = get_system(system) + atoms_lammps = serial_lammps_run( + atoms=atoms, + potential=serial_modal_potential_path, + wd=tmp_path, + test_name='serial lmp test', + lammps_cmd=lammps_cmd, + ) + atoms.calc = ref_modal_calculator + assert_atoms(atoms, atoms_lammps) + + +@pytest.mark.parametrize( + 'system,ncores', + [ + ('bulk', 2), + ('surface', 2), + ], +) +def test_modal_parallel( + system, + ncores, + parallel_modal_potential_path, + ref_modal_calculator, + lammps_cmd, + mpirun_cmd, + tmp_path, +): + if system == 'bulk': + rep = (6, 6, 3) + elif system == 'surface': + rep = (4, 4, 1) + else: + assert False + atoms = get_system(system, replicate=rep) + atoms_lammps = parallel_lammps_run( + atoms=atoms, + potential=parallel_modal_potential_path, + wd=tmp_path, + test_name='parallel lmp test', + lammps_cmd=lammps_cmd, + mpirun_cmd=mpirun_cmd, + ncores=ncores, + ) + atoms.calc = ref_modal_calculator + assert_atoms(atoms, atoms_lammps) + + +@pytest.mark.filterwarnings('ignore:.*is not found from.*') +@pytest.mark.skipif(not is_cue_available(), reason='cueq not available') +def test_cueq_serial(lammps_cmd, tmp_path): + """ + TODO: Use already saved cueq enabled checkpoint after cueq becomes stable + """ + cueq = True + model = get_model(use_cueq=cueq) + ref_calc = SevenNetCalculator(model, file_type='model_instance') + atoms = get_system('bulk') + + cfg = get_model_config() + cfg.update( + {'cuequivariance_config': {'use': cueq}, 'version': sevenn.__version__} + ) + + cp_path = str(tmp_path / 'cp.pth') + torch.save( + {'model_state_dict': model.state_dict(), 'config': cfg}, + cp_path, + ) + + pot_path = str(tmp_path / 'deployed_from_cueq_serial.pt') + deploy(cp_path, pot_path) + + atoms_lammps = serial_lammps_run( + atoms=atoms, + potential=pot_path, + wd=tmp_path, + test_name='cueq checkpoint serial lmp run test', + lammps_cmd=lammps_cmd, + ) + atoms.calc = ref_calc + assert_atoms(atoms, atoms_lammps) + + +@pytest.mark.filterwarnings('ignore:.*is not found from.*') +@pytest.mark.skipif(not is_cue_available(), reason='cueq not available') +def test_cueq_parallel(lammps_cmd, mpirun_cmd, tmp_path): + """ + TODO: Use already saved cueq enabled checkpoint after cueq becomes stable + """ + cueq = True + model = get_model(use_cueq=cueq) + ref_calc = SevenNetCalculator(model, file_type='model_instance') + atoms = get_system('surface', replicate=(4, 4, 1)) + + cfg = get_model_config() + cfg.update( + {'cuequivariance_config': {'use': cueq}, 'version': sevenn.__version__} + ) + + cp_path = str(tmp_path / 'cp.pth') + torch.save( + {'model_state_dict': model.state_dict(), 'config': cfg}, + cp_path, + ) + + pot_path = str(tmp_path / 'deployed_from_cueq_parallel') + deploy_parallel(cp_path, pot_path) + + atoms_lammps = parallel_lammps_run( + atoms=atoms, + potential=' '.join([str(cfg['num_convolution_layer']), pot_path]), + wd=tmp_path, + test_name='cueq checkpoint parallel lmp run test', + lammps_cmd=lammps_cmd, + mpirun_cmd=mpirun_cmd, + ncores=2, + ) + atoms.calc = ref_calc + assert_atoms(atoms, atoms_lammps) diff --git a/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_calculator.py b/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_calculator.py index f2e6f8e..c249492 100644 --- a/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_calculator.py +++ b/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_calculator.py @@ -1,217 +1,217 @@ -import copy - -import numpy as np -import pytest -from ase.build import bulk, molecule - -from sevenn.calculator import D3Calculator, SevenNetCalculator -from sevenn.nn.cue_helper import is_cue_available -from sevenn.scripts.deploy import deploy -from sevenn.util import ( - model_from_checkpoint, - model_from_checkpoint_with_backend, - pretrained_name_to_path, -) - - -@pytest.fixture -def atoms_pbc(): - atoms1 = bulk('NaCl', 'rocksalt', a=5.63) - atoms1.set_cell([[1.0, 2.815, 2.815], [2.815, 0.0, 2.815], [2.815, 2.815, 0.0]]) - atoms1.set_positions([[0.0, 0.0, 0.0], [2.815, 0.0, 0.0]]) - return atoms1 - - -@pytest.fixture -def atoms_mol(): - atoms2 = molecule('H2O') - atoms2.set_positions([[0.0, 0.2, 0.12], [0.0, 0.76, -0.48], [0.0, -0.76, -0.48]]) - return atoms2 - - -@pytest.fixture(scope='module') -def sevennet_0_cal(): - return SevenNetCalculator('7net-0_11July2024') - - -@pytest.fixture(scope='module') -def sevennet_0_cueq_cal(): - cpp = pretrained_name_to_path('7net-0_11July2024') - model, _ = model_from_checkpoint_with_backend(cpp, 'cueq') - return SevenNetCalculator(model) - - -@pytest.fixture(scope='module') -def d3_cal(): - try: - return D3Calculator() - except NotImplementedError as e: - pytest.skip(f'{e}') - - -def test_sevennet_0_cal_pbc(atoms_pbc, sevennet_0_cal): - atoms1_ref = { - 'energy': -3.779199, - 'energies': [-1.8493923, -1.9298072], - 'force': [ - [12.666697, 0.04726403, 0.04775861], - [-12.666697, -0.04726403, -0.04775861], - ], - 'stress': [ - [ - -0.6439122, - -0.03643947, - -0.03643981, - 0.00599139, - 0.04544507, - 0.04543639, - ] - ], - } - - atoms_pbc.calc = sevennet_0_cal - assert np.allclose(atoms_pbc.get_potential_energy(), atoms1_ref['energy']) - assert np.allclose( - atoms_pbc.get_potential_energy(force_consistent=True), atoms1_ref['energy'] - ) - assert np.allclose(atoms_pbc.get_forces(), atoms1_ref['force']) - assert np.allclose(atoms_pbc.get_stress(), atoms1_ref['stress']) - assert np.allclose(atoms_pbc.get_potential_energies(), atoms1_ref['energies']) - - -def test_sevennet_0_cal_mol(atoms_mol, sevennet_0_cal): - atoms2_ref = { - 'energy': -12.782808303833008, - 'energies': [-6.2493525, -3.141562, -3.3918958], - 'force': [ - [0.0, -1.3619621e01, 7.5937047e00], - [0.0, 9.3918495e00, -1.0172190e01], - [0.0, 4.2277718e00, 2.5784855e00], - ], - } - atoms_mol.calc = sevennet_0_cal - assert np.allclose(atoms_mol.get_potential_energy(), atoms2_ref['energy']) - assert np.allclose( - atoms_mol.get_potential_energy(force_consistent=True), atoms2_ref['energy'] - ) - assert np.allclose(atoms_mol.get_forces(), atoms2_ref['force']) - assert np.allclose(atoms_mol.get_potential_energies(), atoms2_ref['energies']) - - -def test_sevennet_0_cal_deployed_consistency(tmp_path, atoms_pbc): - fname = str(tmp_path / '7net_0.pt') - deploy(pretrained_name_to_path('7net-0_11July2024'), fname) - - calc_script = SevenNetCalculator(fname, file_type='torchscript') - calc_cp = SevenNetCalculator(pretrained_name_to_path('7net-0_11July2024')) - - atoms_pbc.calc = calc_cp - atoms_pbc.get_potential_energy() - res_cp = copy.copy(atoms_pbc.calc.results) - - atoms_pbc.calc = calc_script - atoms_pbc.get_potential_energy() - res_script = copy.copy(atoms_pbc.calc.results) - - for k in res_cp: - assert np.allclose(res_cp[k], res_script[k]) - - -def test_sevennet_0_cal_as_instance_consistency(atoms_pbc): - model, _ = model_from_checkpoint( - pretrained_name_to_path('7net-0_11July2024') - ) - - calc_cp = SevenNetCalculator(pretrained_name_to_path('7net-0_11July2024')) - calc_instance = SevenNetCalculator(model, file_type='model_instance') - - atoms_pbc.calc = calc_cp - atoms_pbc.get_potential_energy() - res_cp = copy.copy(atoms_pbc.calc.results) - - atoms_pbc.calc = calc_instance - atoms_pbc.get_potential_energy() - res_script = copy.copy(atoms_pbc.calc.results) - - for k in res_cp: - assert np.allclose(res_cp[k], res_script[k]) - - -@pytest.mark.skipif(not is_cue_available(), reason='cueq not available') -def test_sevennet_0_cal_cueq(atoms_pbc, sevennet_0_cueq_cal): - atoms1_ref = { - 'energy': -3.779199, - 'energies': [-1.8493923, -1.9298072], - 'force': [ - [12.666697, 0.04726403, 0.04775861], - [-12.666697, -0.04726403, -0.04775861], - ], - 'stress': [ - [ - -0.6439122, - -0.03643947, - -0.03643981, - 0.00599139, - 0.04544507, - 0.04543639, - ] - ], - } - - atoms_pbc.calc = sevennet_0_cueq_cal - - assert np.allclose(atoms_pbc.get_potential_energy(), atoms1_ref['energy']) - assert np.allclose( - atoms_pbc.get_potential_energy(force_consistent=True), atoms1_ref['energy'] - ) - assert np.allclose(atoms_pbc.get_forces(), atoms1_ref['force']) - assert np.allclose(atoms_pbc.get_stress(), atoms1_ref['stress']) - assert np.allclose(atoms_pbc.get_potential_energies(), atoms1_ref['energies']) - - -def test_d3_cal_pbc(atoms_pbc, d3_cal): - atoms1_ref = { - 'energy': -0.531393751583389, - 'force': [ - [-0.00570205, 0.00107457, 0.00107459], - [0.00570205, -0.00107457, -0.00107459], - ], - 'stress': [ - [ - 1.52403705e-02, - 1.50417333e-02, - 1.50417321e-02, - -3.22684163e-05, - -5.05532863e-05, - -5.05586994e-05, - ] - ], - } - - atoms_pbc.calc = d3_cal - - assert np.allclose(atoms_pbc.get_potential_energy(), atoms1_ref['energy']) - assert np.allclose( - atoms_pbc.get_potential_energy(force_consistent=True), atoms1_ref['energy'] - ) - assert np.allclose(atoms_pbc.get_forces(), atoms1_ref['force']) - assert np.allclose(atoms_pbc.get_stress(), atoms1_ref['stress']) - - -def test_d3_cal_mol(atoms_mol, d3_cal): - atoms2_ref = { - 'energy': -0.009889134535170716, - 'force': [ - [0.0, 2.04263840e-03, 1.27477674e-03], - [0.0, -9.90038901e-05, 1.18046682e-06], - [0.0, -1.94363451e-03, -1.27595721e-03], - ], - } - - atoms_mol.calc = d3_cal - - assert np.allclose(atoms_mol.get_potential_energy(), atoms2_ref['energy']) - assert np.allclose( - atoms_mol.get_potential_energy(force_consistent=True), atoms2_ref['energy'] - ) - assert np.allclose(atoms_mol.get_forces(), atoms2_ref['force']) +import copy + +import numpy as np +import pytest +from ase.build import bulk, molecule + +from sevenn.calculator import D3Calculator, SevenNetCalculator +from sevenn.nn.cue_helper import is_cue_available +from sevenn.scripts.deploy import deploy +from sevenn.util import ( + model_from_checkpoint, + model_from_checkpoint_with_backend, + pretrained_name_to_path, +) + + +@pytest.fixture +def atoms_pbc(): + atoms1 = bulk('NaCl', 'rocksalt', a=5.63) + atoms1.set_cell([[1.0, 2.815, 2.815], [2.815, 0.0, 2.815], [2.815, 2.815, 0.0]]) + atoms1.set_positions([[0.0, 0.0, 0.0], [2.815, 0.0, 0.0]]) + return atoms1 + + +@pytest.fixture +def atoms_mol(): + atoms2 = molecule('H2O') + atoms2.set_positions([[0.0, 0.2, 0.12], [0.0, 0.76, -0.48], [0.0, -0.76, -0.48]]) + return atoms2 + + +@pytest.fixture(scope='module') +def sevennet_0_cal(): + return SevenNetCalculator('7net-0_11July2024') + + +@pytest.fixture(scope='module') +def sevennet_0_cueq_cal(): + cpp = pretrained_name_to_path('7net-0_11July2024') + model, _ = model_from_checkpoint_with_backend(cpp, 'cueq') + return SevenNetCalculator(model) + + +@pytest.fixture(scope='module') +def d3_cal(): + try: + return D3Calculator() + except NotImplementedError as e: + pytest.skip(f'{e}') + + +def test_sevennet_0_cal_pbc(atoms_pbc, sevennet_0_cal): + atoms1_ref = { + 'energy': -3.779199, + 'energies': [-1.8493923, -1.9298072], + 'force': [ + [12.666697, 0.04726403, 0.04775861], + [-12.666697, -0.04726403, -0.04775861], + ], + 'stress': [ + [ + -0.6439122, + -0.03643947, + -0.03643981, + 0.00599139, + 0.04544507, + 0.04543639, + ] + ], + } + + atoms_pbc.calc = sevennet_0_cal + assert np.allclose(atoms_pbc.get_potential_energy(), atoms1_ref['energy']) + assert np.allclose( + atoms_pbc.get_potential_energy(force_consistent=True), atoms1_ref['energy'] + ) + assert np.allclose(atoms_pbc.get_forces(), atoms1_ref['force']) + assert np.allclose(atoms_pbc.get_stress(), atoms1_ref['stress']) + assert np.allclose(atoms_pbc.get_potential_energies(), atoms1_ref['energies']) + + +def test_sevennet_0_cal_mol(atoms_mol, sevennet_0_cal): + atoms2_ref = { + 'energy': -12.782808303833008, + 'energies': [-6.2493525, -3.141562, -3.3918958], + 'force': [ + [0.0, -1.3619621e01, 7.5937047e00], + [0.0, 9.3918495e00, -1.0172190e01], + [0.0, 4.2277718e00, 2.5784855e00], + ], + } + atoms_mol.calc = sevennet_0_cal + assert np.allclose(atoms_mol.get_potential_energy(), atoms2_ref['energy']) + assert np.allclose( + atoms_mol.get_potential_energy(force_consistent=True), atoms2_ref['energy'] + ) + assert np.allclose(atoms_mol.get_forces(), atoms2_ref['force']) + assert np.allclose(atoms_mol.get_potential_energies(), atoms2_ref['energies']) + + +def test_sevennet_0_cal_deployed_consistency(tmp_path, atoms_pbc): + fname = str(tmp_path / '7net_0.pt') + deploy(pretrained_name_to_path('7net-0_11July2024'), fname) + + calc_script = SevenNetCalculator(fname, file_type='torchscript') + calc_cp = SevenNetCalculator(pretrained_name_to_path('7net-0_11July2024')) + + atoms_pbc.calc = calc_cp + atoms_pbc.get_potential_energy() + res_cp = copy.copy(atoms_pbc.calc.results) + + atoms_pbc.calc = calc_script + atoms_pbc.get_potential_energy() + res_script = copy.copy(atoms_pbc.calc.results) + + for k in res_cp: + assert np.allclose(res_cp[k], res_script[k]) + + +def test_sevennet_0_cal_as_instance_consistency(atoms_pbc): + model, _ = model_from_checkpoint( + pretrained_name_to_path('7net-0_11July2024') + ) + + calc_cp = SevenNetCalculator(pretrained_name_to_path('7net-0_11July2024')) + calc_instance = SevenNetCalculator(model, file_type='model_instance') + + atoms_pbc.calc = calc_cp + atoms_pbc.get_potential_energy() + res_cp = copy.copy(atoms_pbc.calc.results) + + atoms_pbc.calc = calc_instance + atoms_pbc.get_potential_energy() + res_script = copy.copy(atoms_pbc.calc.results) + + for k in res_cp: + assert np.allclose(res_cp[k], res_script[k]) + + +@pytest.mark.skipif(not is_cue_available(), reason='cueq not available') +def test_sevennet_0_cal_cueq(atoms_pbc, sevennet_0_cueq_cal): + atoms1_ref = { + 'energy': -3.779199, + 'energies': [-1.8493923, -1.9298072], + 'force': [ + [12.666697, 0.04726403, 0.04775861], + [-12.666697, -0.04726403, -0.04775861], + ], + 'stress': [ + [ + -0.6439122, + -0.03643947, + -0.03643981, + 0.00599139, + 0.04544507, + 0.04543639, + ] + ], + } + + atoms_pbc.calc = sevennet_0_cueq_cal + + assert np.allclose(atoms_pbc.get_potential_energy(), atoms1_ref['energy']) + assert np.allclose( + atoms_pbc.get_potential_energy(force_consistent=True), atoms1_ref['energy'] + ) + assert np.allclose(atoms_pbc.get_forces(), atoms1_ref['force']) + assert np.allclose(atoms_pbc.get_stress(), atoms1_ref['stress']) + assert np.allclose(atoms_pbc.get_potential_energies(), atoms1_ref['energies']) + + +def test_d3_cal_pbc(atoms_pbc, d3_cal): + atoms1_ref = { + 'energy': -0.531393751583389, + 'force': [ + [-0.00570205, 0.00107457, 0.00107459], + [0.00570205, -0.00107457, -0.00107459], + ], + 'stress': [ + [ + 1.52403705e-02, + 1.50417333e-02, + 1.50417321e-02, + -3.22684163e-05, + -5.05532863e-05, + -5.05586994e-05, + ] + ], + } + + atoms_pbc.calc = d3_cal + + assert np.allclose(atoms_pbc.get_potential_energy(), atoms1_ref['energy']) + assert np.allclose( + atoms_pbc.get_potential_energy(force_consistent=True), atoms1_ref['energy'] + ) + assert np.allclose(atoms_pbc.get_forces(), atoms1_ref['force']) + assert np.allclose(atoms_pbc.get_stress(), atoms1_ref['stress']) + + +def test_d3_cal_mol(atoms_mol, d3_cal): + atoms2_ref = { + 'energy': -0.009889134535170716, + 'force': [ + [0.0, 2.04263840e-03, 1.27477674e-03], + [0.0, -9.90038901e-05, 1.18046682e-06], + [0.0, -1.94363451e-03, -1.27595721e-03], + ], + } + + atoms_mol.calc = d3_cal + + assert np.allclose(atoms_mol.get_potential_energy(), atoms2_ref['energy']) + assert np.allclose( + atoms_mol.get_potential_energy(force_consistent=True), atoms2_ref['energy'] + ) + assert np.allclose(atoms_mol.get_forces(), atoms2_ref['force']) diff --git a/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_cli.py b/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_cli.py index ad0e80c..0bdeeab 100644 --- a/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_cli.py +++ b/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_cli.py @@ -1,233 +1,233 @@ -import csv -import os -import pathlib -from unittest import mock - -import ase.io -import numpy as np -import pytest -import yaml -from ase.build import bulk - -from sevenn.calculator import SevenNetCalculator -from sevenn.logger import Logger -from sevenn.main.sevenn import main as sevenn_main -from sevenn.main.sevenn_get_model import main as get_model_main -from sevenn.main.sevenn_graph_build import main as graph_build_main -from sevenn.main.sevenn_inference import main as inference_main -from sevenn.util import pretrained_name_to_path - -main = os.path.abspath(f'{os.path.dirname(__file__)}/../../sevenn/main/') -preset = os.path.abspath(f'{os.path.dirname(__file__)}/../../sevenn/presets/') -file_path = pathlib.Path(__file__).parent.resolve() - -data_root = (pathlib.Path(__file__).parent.parent / 'data').resolve() -hfo2_path = str(data_root / 'systems' / 'hfo2.extxyz') -hfo2_7net_0_inference_path = data_root / 'inferences' / 'snet0_on_hfo2' -cp_0_path = str(data_root / 'checkpoints' / 'cp_0.pth') - -Logger() # init - - -@pytest.fixture -def atoms_hfo(): - atoms1 = bulk('HfO', 'rocksalt', a=5.63) - atoms1.set_cell([[1.0, 2.815, 2.815], [2.815, 0.0, 2.815], [2.815, 2.815, 0.0]]) - atoms1.set_positions([[0.0, 0.0, 0.0], [2.815, 0.0, 0.0]]) - return atoms1 - - -@pytest.fixture(scope='module') -def sevennet_0_cal(): - return SevenNetCalculator('7net-0_11July2024') - - -def test_get_model_serial(tmp_path, capsys): - output_file = tmp_path / 'mypot.pt' - cp = pretrained_name_to_path('7net-0') - cli_args = ['-o', str(output_file), cp] - with mock.patch('sys.argv', [f'{main}/sevenn_get_model.py'] + cli_args): - get_model_main() - _ = capsys.readouterr() # not used - assert output_file.is_file(), '.pt file is not written' - - -def test_get_model_parallel(tmp_path, capsys): - output_dir = tmp_path / 'my_parallel' - cp = pretrained_name_to_path('7net-0') - expected_file_cnt = 5 # 5 interaction layers - cli_args = ['-o', str(output_dir), '-p', cp] - with mock.patch('sys.argv', [f'{main}/sevenn_get_model.py'] + cli_args): - # with pytest.raises(SystemExit): - get_model_main() - _ = capsys.readouterr() # not used - assert output_dir.is_dir(), 'parallel model directory not exist' - for i in range(expected_file_cnt): - assert (output_dir / f'deployed_parallel_{i}.pt').is_file() - - -@pytest.mark.parametrize('source', [(hfo2_path)]) -def test_graph_build(source, tmp_path): - output_dir = tmp_path / 'sevenn_data' - output_f = output_dir / 'my_graph.pt' - output_yml = output_dir / 'my_graph.yaml' - cli_args = ['-o', str(tmp_path), '-f', 'my_graph.pt', source, '4.0'] - with mock.patch('sys.argv', [f'{main}/sevenn_graph_build.py'] + cli_args): - graph_build_main() - - assert output_dir.is_dir() - assert output_f.is_file() - assert output_yml.is_file() - - -@pytest.mark.parametrize( - 'batch,device,save_graph', - [ - (1, 'cpu', False), - (2, 'cpu', False), - (1, 'cpu', True), - ], -) -def test_inference(batch, device, save_graph, tmp_path): - checkpoint = '7net-0' - target = hfo2_path - ref_path = hfo2_7net_0_inference_path - - output_dir = tmp_path / 'inference_results' - files = ['info.csv', 'per_graph.csv', 'per_atom.csv', 'errors.txt'] - cli_args = [ - '--output', - str(output_dir), - '--device', - device, - '--batch', - str(batch), - checkpoint, - target, - ] - if save_graph: - cli_args.append('--save_graph') - with mock.patch('sys.argv', [f'{main}/sevenn_inference.py'] + cli_args): - inference_main() - - assert output_dir.is_dir() - for f in files: - assert (output_dir / f).is_file() - with open(output_dir / 'errors.txt', 'r', encoding='utf-8') as f: - errors = [float(ll.split(':')[-1].strip()) for ll in f.readlines()] - with open(ref_path / 'errors.txt', 'r', encoding='utf-8') as f: - errors_ref = [float(ll.split(':')[-1].strip()) for ll in f.readlines()] - assert np.allclose(np.array(errors), np.array(errors_ref)) - - """ - # TODO: commented out as currently SevenNetGraphDataset can't do this - with open(output_dir / 'info.csv', 'r') as f: - reader = csv.DictReader(f) - for dct in reader: - assert dct['file'] == hfo2_path - assert reader.line_num == 3 - """ - - if save_graph: - assert (output_dir / 'sevenn_data').is_dir() - assert (output_dir / 'sevenn_data' / 'saved_graph.pt').is_file() - assert (output_dir / 'sevenn_data' / 'saved_graph.yaml').is_file() - - -def test_inference_unlabeled(atoms_hfo, tmp_path): - labeled = str(hfo2_path) - unlabeled = str(tmp_path / 'unlabeled.xyz') - ase.io.write(unlabeled, atoms_hfo) - - output_dir = tmp_path / 'inference_results' - cli_args = [ - '--output', - str(output_dir), - '--allow_unlabeled', - cp_0_path, - labeled, - unlabeled, - ] - with mock.patch('sys.argv', [f'{main}/sevenn_inference.py'] + cli_args): - inference_main() - - with open(output_dir / 'info.csv', 'r') as f: - reader = csv.DictReader(f) - for dct in reader: - assert dct['file'] in [labeled, unlabeled] - assert reader.line_num == 4 - - -def test_inference_labeled_w_kwargs(atoms_hfo, tmp_path): - atoms_hfo.info['my_energy'] = 1.0 - atoms_hfo.arrays['my_force'] = np.full((len(atoms_hfo), 3), 7.7) - # this should be considered as Voigt, xx, yy, zz, yz, zx, xy - atoms_hfo.info['my_stress'] = np.array([1, 2, 3, 4, 5, 6]) - - unlabeled = str(tmp_path / 'unlabeled.xyz') - ase.io.write(unlabeled, atoms_hfo) - - output_dir = tmp_path / 'inference_results' - cli_args = [ - '--output', - str(output_dir), - cp_0_path, - unlabeled, - '--kwargs', - 'energy_key=my_energy', - 'force_key=my_force', - 'stress_key=my_stress', - ] - with mock.patch('sys.argv', [f'{main}/sevenn_inference.py'] + cli_args): - inference_main() - - per_graph = None - with open(output_dir / 'per_graph.csv', 'r') as f: - reader = csv.DictReader(f) - for dct in reader: - per_graph = dct - assert reader.line_num == 2 - assert per_graph is not None - - stress_coeff = -1602.1766208 - assert np.allclose(float(per_graph['stress_yy']), 2 * stress_coeff) - assert np.allclose(float(per_graph['stress_yz']), 4 * stress_coeff) - assert np.allclose(float(per_graph['stress_zx']), 5 * stress_coeff) - assert np.allclose(float(per_graph['stress_xy']), 6 * stress_coeff) - - -@pytest.mark.parametrize( - 'preset_name,mode,data_path', - [ - ('fine_tune', 'train_v2', hfo2_path), - ('base', 'train_v2', hfo2_path), - ('sevennet-0', 'train_v1', hfo2_path), - ], -) -def test_sevenn_preset(preset_name, mode, data_path, tmp_path): - preset_path = os.path.join(preset, preset_name + '.yaml') - with open(preset_path, 'r') as f: - cfg = yaml.safe_load(f) - - cfg['train']['epoch'] = 1 - if mode == 'train_v2': - cfg['data']['load_trainset_path'] = data_path - cfg['data'].pop('load_testset_path', None) - elif mode == 'train_v1': - cfg['data']['load_dataset_path'] = data_path - else: - assert False - cfg['data']['load_validset_path'] = data_path - - input_yam = str(tmp_path / 'input.yaml') - with open(input_yam, 'w') as f: - yaml.dump(cfg, f) - - Logger().switch_file(str(tmp_path / 'log.sevenn')) - cli_args = ['train', '-w', str(tmp_path), '-m', mode, input_yam] - with mock.patch('sys.argv', [f'{main}/sevenn.py'] + cli_args): - sevenn_main() - - assert (tmp_path / 'lc.csv').is_file() or (tmp_path / 'log.csv').is_file() - assert (tmp_path / 'log.sevenn').is_file() - assert (tmp_path / 'checkpoint_best.pth').is_file() +import csv +import os +import pathlib +from unittest import mock + +import ase.io +import numpy as np +import pytest +import yaml +from ase.build import bulk + +from sevenn.calculator import SevenNetCalculator +from sevenn.logger import Logger +from sevenn.main.sevenn import main as sevenn_main +from sevenn.main.sevenn_get_model import main as get_model_main +from sevenn.main.sevenn_graph_build import main as graph_build_main +from sevenn.main.sevenn_inference import main as inference_main +from sevenn.util import pretrained_name_to_path + +main = os.path.abspath(f'{os.path.dirname(__file__)}/../../sevenn/main/') +preset = os.path.abspath(f'{os.path.dirname(__file__)}/../../sevenn/presets/') +file_path = pathlib.Path(__file__).parent.resolve() + +data_root = (pathlib.Path(__file__).parent.parent / 'data').resolve() +hfo2_path = str(data_root / 'systems' / 'hfo2.extxyz') +hfo2_7net_0_inference_path = data_root / 'inferences' / 'snet0_on_hfo2' +cp_0_path = str(data_root / 'checkpoints' / 'cp_0.pth') + +Logger() # init + + +@pytest.fixture +def atoms_hfo(): + atoms1 = bulk('HfO', 'rocksalt', a=5.63) + atoms1.set_cell([[1.0, 2.815, 2.815], [2.815, 0.0, 2.815], [2.815, 2.815, 0.0]]) + atoms1.set_positions([[0.0, 0.0, 0.0], [2.815, 0.0, 0.0]]) + return atoms1 + + +@pytest.fixture(scope='module') +def sevennet_0_cal(): + return SevenNetCalculator('7net-0_11July2024') + + +def test_get_model_serial(tmp_path, capsys): + output_file = tmp_path / 'mypot.pt' + cp = pretrained_name_to_path('7net-0') + cli_args = ['-o', str(output_file), cp] + with mock.patch('sys.argv', [f'{main}/sevenn_get_model.py'] + cli_args): + get_model_main() + _ = capsys.readouterr() # not used + assert output_file.is_file(), '.pt file is not written' + + +def test_get_model_parallel(tmp_path, capsys): + output_dir = tmp_path / 'my_parallel' + cp = pretrained_name_to_path('7net-0') + expected_file_cnt = 5 # 5 interaction layers + cli_args = ['-o', str(output_dir), '-p', cp] + with mock.patch('sys.argv', [f'{main}/sevenn_get_model.py'] + cli_args): + # with pytest.raises(SystemExit): + get_model_main() + _ = capsys.readouterr() # not used + assert output_dir.is_dir(), 'parallel model directory not exist' + for i in range(expected_file_cnt): + assert (output_dir / f'deployed_parallel_{i}.pt').is_file() + + +@pytest.mark.parametrize('source', [(hfo2_path)]) +def test_graph_build(source, tmp_path): + output_dir = tmp_path / 'sevenn_data' + output_f = output_dir / 'my_graph.pt' + output_yml = output_dir / 'my_graph.yaml' + cli_args = ['-o', str(tmp_path), '-f', 'my_graph.pt', source, '4.0'] + with mock.patch('sys.argv', [f'{main}/sevenn_graph_build.py'] + cli_args): + graph_build_main() + + assert output_dir.is_dir() + assert output_f.is_file() + assert output_yml.is_file() + + +@pytest.mark.parametrize( + 'batch,device,save_graph', + [ + (1, 'cpu', False), + (2, 'cpu', False), + (1, 'cpu', True), + ], +) +def test_inference(batch, device, save_graph, tmp_path): + checkpoint = '7net-0' + target = hfo2_path + ref_path = hfo2_7net_0_inference_path + + output_dir = tmp_path / 'inference_results' + files = ['info.csv', 'per_graph.csv', 'per_atom.csv', 'errors.txt'] + cli_args = [ + '--output', + str(output_dir), + '--device', + device, + '--batch', + str(batch), + checkpoint, + target, + ] + if save_graph: + cli_args.append('--save_graph') + with mock.patch('sys.argv', [f'{main}/sevenn_inference.py'] + cli_args): + inference_main() + + assert output_dir.is_dir() + for f in files: + assert (output_dir / f).is_file() + with open(output_dir / 'errors.txt', 'r', encoding='utf-8') as f: + errors = [float(ll.split(':')[-1].strip()) for ll in f.readlines()] + with open(ref_path / 'errors.txt', 'r', encoding='utf-8') as f: + errors_ref = [float(ll.split(':')[-1].strip()) for ll in f.readlines()] + assert np.allclose(np.array(errors), np.array(errors_ref)) + + """ + # TODO: commented out as currently SevenNetGraphDataset can't do this + with open(output_dir / 'info.csv', 'r') as f: + reader = csv.DictReader(f) + for dct in reader: + assert dct['file'] == hfo2_path + assert reader.line_num == 3 + """ + + if save_graph: + assert (output_dir / 'sevenn_data').is_dir() + assert (output_dir / 'sevenn_data' / 'saved_graph.pt').is_file() + assert (output_dir / 'sevenn_data' / 'saved_graph.yaml').is_file() + + +def test_inference_unlabeled(atoms_hfo, tmp_path): + labeled = str(hfo2_path) + unlabeled = str(tmp_path / 'unlabeled.xyz') + ase.io.write(unlabeled, atoms_hfo) + + output_dir = tmp_path / 'inference_results' + cli_args = [ + '--output', + str(output_dir), + '--allow_unlabeled', + cp_0_path, + labeled, + unlabeled, + ] + with mock.patch('sys.argv', [f'{main}/sevenn_inference.py'] + cli_args): + inference_main() + + with open(output_dir / 'info.csv', 'r') as f: + reader = csv.DictReader(f) + for dct in reader: + assert dct['file'] in [labeled, unlabeled] + assert reader.line_num == 4 + + +def test_inference_labeled_w_kwargs(atoms_hfo, tmp_path): + atoms_hfo.info['my_energy'] = 1.0 + atoms_hfo.arrays['my_force'] = np.full((len(atoms_hfo), 3), 7.7) + # this should be considered as Voigt, xx, yy, zz, yz, zx, xy + atoms_hfo.info['my_stress'] = np.array([1, 2, 3, 4, 5, 6]) + + unlabeled = str(tmp_path / 'unlabeled.xyz') + ase.io.write(unlabeled, atoms_hfo) + + output_dir = tmp_path / 'inference_results' + cli_args = [ + '--output', + str(output_dir), + cp_0_path, + unlabeled, + '--kwargs', + 'energy_key=my_energy', + 'force_key=my_force', + 'stress_key=my_stress', + ] + with mock.patch('sys.argv', [f'{main}/sevenn_inference.py'] + cli_args): + inference_main() + + per_graph = None + with open(output_dir / 'per_graph.csv', 'r') as f: + reader = csv.DictReader(f) + for dct in reader: + per_graph = dct + assert reader.line_num == 2 + assert per_graph is not None + + stress_coeff = -1602.1766208 + assert np.allclose(float(per_graph['stress_yy']), 2 * stress_coeff) + assert np.allclose(float(per_graph['stress_yz']), 4 * stress_coeff) + assert np.allclose(float(per_graph['stress_zx']), 5 * stress_coeff) + assert np.allclose(float(per_graph['stress_xy']), 6 * stress_coeff) + + +@pytest.mark.parametrize( + 'preset_name,mode,data_path', + [ + ('fine_tune', 'train_v2', hfo2_path), + ('base', 'train_v2', hfo2_path), + ('sevennet-0', 'train_v1', hfo2_path), + ], +) +def test_sevenn_preset(preset_name, mode, data_path, tmp_path): + preset_path = os.path.join(preset, preset_name + '.yaml') + with open(preset_path, 'r') as f: + cfg = yaml.safe_load(f) + + cfg['train']['epoch'] = 1 + if mode == 'train_v2': + cfg['data']['load_trainset_path'] = data_path + cfg['data'].pop('load_testset_path', None) + elif mode == 'train_v1': + cfg['data']['load_dataset_path'] = data_path + else: + assert False + cfg['data']['load_validset_path'] = data_path + + input_yam = str(tmp_path / 'input.yaml') + with open(input_yam, 'w') as f: + yaml.dump(cfg, f) + + Logger().switch_file(str(tmp_path / 'log.sevenn')) + cli_args = ['train', '-w', str(tmp_path), '-m', mode, input_yam] + with mock.patch('sys.argv', [f'{main}/sevenn.py'] + cli_args): + sevenn_main() + + assert (tmp_path / 'lc.csv').is_file() or (tmp_path / 'log.csv').is_file() + assert (tmp_path / 'log.sevenn').is_file() + assert (tmp_path / 'checkpoint_best.pth').is_file() diff --git a/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_cueq.py b/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_cueq.py index a9c3b15..4f1d5bd 100644 --- a/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_cueq.py +++ b/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_cueq.py @@ -1,282 +1,282 @@ -# TODO: add gradient test from total loss after double precision. -# so far, it is empirically checked by seeing learning curves -import copy - -import numpy as np -import pytest -import torch -from ase.build import bulk -from torch_geometric.loader.dataloader import Collater - -import sevenn -import sevenn.train.dataload as dl -from sevenn.atom_graph_data import AtomGraphData -from sevenn.calculator import SevenNetCalculator -from sevenn.model_build import build_E3_equivariant_model -from sevenn.nn.cue_helper import is_cue_available -from sevenn.nn.sequential import AtomGraphSequential -from sevenn.util import ( - chemical_species_preprocess, - model_from_checkpoint_with_backend, -) - -cutoff = 4.0 - -_atoms = bulk('NaCl', 'rocksalt', a=4.00) * (2, 2, 2) -_avg_num_neigh = 30.0 -_atoms.rattle() - -_graph = AtomGraphData.from_numpy_dict(dl.unlabeled_atoms_to_graph(_atoms, cutoff)) - - -def get_graphs(batched): - # batch size 2 - cloned = [_graph.clone().to('cuda'), _graph.clone().to('cuda')] - if not batched: - return cloned - else: - return Collater(cloned)(cloned) - - -def get_model_config(): - config = { - 'cutoff': cutoff, - 'channel': 32, - 'lmax': 2, - 'is_parity': True, - 'num_convolution_layer': 3, - 'self_connection_type': 'nequip', # not NequIp - 'interaction_type': 'nequip', - 'radial_basis': { - 'radial_basis_name': 'bessel', - }, - 'cutoff_function': {'cutoff_function_name': 'poly_cut'}, - 'weight_nn_hidden_neurons': [64, 64], - 'act_radial': 'silu', - 'act_scalar': {'e': 'silu', 'o': 'tanh'}, - 'act_gate': {'e': 'silu', 'o': 'tanh'}, - 'conv_denominator': _avg_num_neigh, - 'train_denominator': False, - 'shift': -10.0, - 'scale': 10.0, - 'train_shift_scale': False, - 'irreps_manual': False, - 'lmax_edge': -1, - 'lmax_node': -1, - 'readout_as_fcn': False, - 'use_bias_in_linear': False, - '_normalize_sph': True, - } - chems = set() - chems.update(_atoms.get_chemical_symbols()) - config.update(**chemical_species_preprocess(list(chems))) - return config - - -def get_model(config_overwrite=None, use_cueq=False, cueq_config=None): - cf = get_model_config() - if config_overwrite is not None: - cf.update(config_overwrite) - - cueq_config = cueq_config or {'cuequivariance_config': {'use': use_cueq}} - cf.update(cueq_config) - - model = build_E3_equivariant_model(cf, parallel=False) - assert isinstance(model, AtomGraphSequential) - model.to('cuda') - return model - - -@pytest.mark.skipif( - not is_cue_available() or not torch.cuda.is_available(), - reason='cueq or gpu is not available', -) -@pytest.mark.parametrize( - 'cf', - [ - ({}), - ({'self_connection_type': 'linear'}), - ({'is_parity': False}), - ({'channel': 8}), - ({'lmax': 3}), - ({'num_interaction_layer': 2}), - ({'num_interaction_layer': 4}), - ], -) -def test_model_output(cf): - torch.manual_seed(777) - model_e3nn = get_model(cf) - torch.manual_seed(777) - model_cueq = get_model(cf, use_cueq=True) - - model_e3nn.set_is_batch_data(True) - model_cueq.set_is_batch_data(True) - - e3nn_out = model_e3nn._preprocess(get_graphs(batched=True)) - cueq_out = model_cueq._preprocess(get_graphs(batched=True)) - - for k, e3nn_f in model_e3nn._modules.items(): - cueq_f = model_cueq._modules[k] - e3nn_out = e3nn_f(e3nn_out) # type: ignore - cueq_out = cueq_f(cueq_out) # type: ignore - assert torch.allclose(e3nn_out.x, cueq_out.x, atol=1e-6), ( - f'{k} \n\n {e3nn_f} \n\n {cueq_f}' - ) - - assert torch.allclose( - e3nn_out.inferred_total_energy, cueq_out.inferred_total_energy - ) - assert torch.allclose(e3nn_out.atomic_energy, cueq_out.atomic_energy) - assert torch.allclose( - e3nn_out.inferred_force, cueq_out.inferred_force, atol=1e-5 - ) - assert torch.allclose( - e3nn_out.inferred_stress, cueq_out.inferred_stress, atol=1e-5 - ) - - -@pytest.mark.filterwarnings('ignore:.*is not found from.*') -@pytest.mark.skipif( - not is_cue_available() or not torch.cuda.is_available(), - reason='cueq or gpu is not available', -) -@pytest.mark.parametrize( - 'start_from_cueq', - [ - (True), - (False), - ], -) -def test_checkpoint_convert(tmp_path, start_from_cueq): - torch.manual_seed(123) - model_from = get_model(use_cueq=start_from_cueq) - - cfg = get_model_config() - cfg.update( - { - 'cuequivariance_config': {'use': start_from_cueq}, - 'version': sevenn.__version__, - } - ) - torch.save( - {'model_state_dict': model_from.state_dict(), 'config': cfg}, - tmp_path / 'cp_from.pth', - ) - - backend = 'e3nn' if start_from_cueq else 'cueq' - model_to, _ = model_from_checkpoint_with_backend( - str(tmp_path / 'cp_from.pth'), backend - ) - model_to.to('cuda') - - model_from.set_is_batch_data(True) - model_to.set_is_batch_data(True) - - from_out = model_from(get_graphs(batched=True)) - to_out = model_to(get_graphs(batched=True)) - - assert torch.allclose( - from_out.inferred_total_energy, to_out.inferred_total_energy - ) - assert torch.allclose(from_out.atomic_energy, to_out.atomic_energy) - assert torch.allclose(from_out.inferred_force, to_out.inferred_force, atol=1e-5) - assert torch.allclose( - from_out.inferred_stress, to_out.inferred_stress, atol=1e-5 - ) - - -@pytest.mark.filterwarnings('ignore:.*is not found from.*') -@pytest.mark.skipif( - not is_cue_available() or not torch.cuda.is_available(), - reason='cueq or gpu is not available', -) -@pytest.mark.parametrize( - 'start_from_cueq', - [ - (True), - (False), - ], -) -def test_checkpoint_convert_no_batch(tmp_path, start_from_cueq): - torch.manual_seed(123) - model_from = get_model(use_cueq=start_from_cueq) - - cfg = get_model_config() - cfg.update( - { - 'cuequivariance_config': {'use': start_from_cueq}, - 'version': sevenn.__version__, - } - ) - torch.save( - {'model_state_dict': model_from.state_dict(), 'config': cfg}, - tmp_path / 'cp_from.pth', - ) - - backend = 'e3nn' if start_from_cueq else 'cueq' - model_to, _ = model_from_checkpoint_with_backend( - str(tmp_path / 'cp_from.pth'), backend - ) - model_to.to('cuda') - - model_from.set_is_batch_data(False) - model_to.set_is_batch_data(False) - - from_out = model_from(get_graphs(batched=False)[0]) - to_out = model_to(get_graphs(batched=False)[0]) - - assert torch.allclose( - from_out.inferred_total_energy, to_out.inferred_total_energy - ) - assert torch.allclose(from_out.atomic_energy, to_out.atomic_energy) - assert torch.allclose(from_out.inferred_force, to_out.inferred_force, atol=1e-5) - assert torch.allclose( - from_out.inferred_stress, to_out.inferred_stress, atol=1e-5 - ) - - -def assert_atoms(atoms1, atoms2, rtol=1e-5, atol=1e-6): - def acl(a, b, rtol=rtol, atol=atol): - return np.allclose(a, b, rtol=rtol, atol=atol) - - assert len(atoms1) == len(atoms2) - assert acl(atoms1.get_cell(), atoms2.get_cell()) - assert acl(atoms1.get_potential_energy(), atoms2.get_potential_energy()) - assert acl(atoms1.get_forces(), atoms2.get_forces(), rtol * 10, atol * 10) - assert acl( - atoms1.get_stress(voigt=False), - atoms2.get_stress(voigt=False), - rtol * 10, - atol * 10, - ) - # assert acl(atoms1.get_potential_energies(), atoms2.get_potential_energies()) - - -@pytest.mark.filterwarnings('ignore:.*is not found from.*') -@pytest.mark.skipif( - not is_cue_available() or not torch.cuda.is_available(), - reason='cueq or gpu is not available', -) -def test_calculator(tmp_path): - cueq = True - model = get_model(use_cueq=cueq) - ref_calc = SevenNetCalculator(model, file_type='model_instance') - atoms = copy.deepcopy(_atoms) - atoms.calc = ref_calc - - cfg = get_model_config() - cfg.update( - {'cuequivariance_config': {'use': cueq}, 'version': sevenn.__version__} - ) - - cp_path = str(tmp_path / 'cp.pth') - torch.save( - {'model_state_dict': model.state_dict(), 'config': cfg}, - cp_path, - ) - - calc2 = SevenNetCalculator(cp_path, enable_cueq=False) - atoms2 = copy.deepcopy(_atoms) - atoms2.calc = calc2 - - assert_atoms(atoms, atoms2) +# TODO: add gradient test from total loss after double precision. +# so far, it is empirically checked by seeing learning curves +import copy + +import numpy as np +import pytest +import torch +from ase.build import bulk +from torch_geometric.loader.dataloader import Collater + +import sevenn +import sevenn.train.dataload as dl +from sevenn.atom_graph_data import AtomGraphData +from sevenn.calculator import SevenNetCalculator +from sevenn.model_build import build_E3_equivariant_model +from sevenn.nn.cue_helper import is_cue_available +from sevenn.nn.sequential import AtomGraphSequential +from sevenn.util import ( + chemical_species_preprocess, + model_from_checkpoint_with_backend, +) + +cutoff = 4.0 + +_atoms = bulk('NaCl', 'rocksalt', a=4.00) * (2, 2, 2) +_avg_num_neigh = 30.0 +_atoms.rattle() + +_graph = AtomGraphData.from_numpy_dict(dl.unlabeled_atoms_to_graph(_atoms, cutoff)) + + +def get_graphs(batched): + # batch size 2 + cloned = [_graph.clone().to('cuda'), _graph.clone().to('cuda')] + if not batched: + return cloned + else: + return Collater(cloned)(cloned) + + +def get_model_config(): + config = { + 'cutoff': cutoff, + 'channel': 32, + 'lmax': 2, + 'is_parity': True, + 'num_convolution_layer': 3, + 'self_connection_type': 'nequip', # not NequIp + 'interaction_type': 'nequip', + 'radial_basis': { + 'radial_basis_name': 'bessel', + }, + 'cutoff_function': {'cutoff_function_name': 'poly_cut'}, + 'weight_nn_hidden_neurons': [64, 64], + 'act_radial': 'silu', + 'act_scalar': {'e': 'silu', 'o': 'tanh'}, + 'act_gate': {'e': 'silu', 'o': 'tanh'}, + 'conv_denominator': _avg_num_neigh, + 'train_denominator': False, + 'shift': -10.0, + 'scale': 10.0, + 'train_shift_scale': False, + 'irreps_manual': False, + 'lmax_edge': -1, + 'lmax_node': -1, + 'readout_as_fcn': False, + 'use_bias_in_linear': False, + '_normalize_sph': True, + } + chems = set() + chems.update(_atoms.get_chemical_symbols()) + config.update(**chemical_species_preprocess(list(chems))) + return config + + +def get_model(config_overwrite=None, use_cueq=False, cueq_config=None): + cf = get_model_config() + if config_overwrite is not None: + cf.update(config_overwrite) + + cueq_config = cueq_config or {'cuequivariance_config': {'use': use_cueq}} + cf.update(cueq_config) + + model = build_E3_equivariant_model(cf, parallel=False) + assert isinstance(model, AtomGraphSequential) + model.to('cuda') + return model + + +@pytest.mark.skipif( + not is_cue_available() or not torch.cuda.is_available(), + reason='cueq or gpu is not available', +) +@pytest.mark.parametrize( + 'cf', + [ + ({}), + ({'self_connection_type': 'linear'}), + ({'is_parity': False}), + ({'channel': 8}), + ({'lmax': 3}), + ({'num_interaction_layer': 2}), + ({'num_interaction_layer': 4}), + ], +) +def test_model_output(cf): + torch.manual_seed(777) + model_e3nn = get_model(cf) + torch.manual_seed(777) + model_cueq = get_model(cf, use_cueq=True) + + model_e3nn.set_is_batch_data(True) + model_cueq.set_is_batch_data(True) + + e3nn_out = model_e3nn._preprocess(get_graphs(batched=True)) + cueq_out = model_cueq._preprocess(get_graphs(batched=True)) + + for k, e3nn_f in model_e3nn._modules.items(): + cueq_f = model_cueq._modules[k] + e3nn_out = e3nn_f(e3nn_out) # type: ignore + cueq_out = cueq_f(cueq_out) # type: ignore + assert torch.allclose(e3nn_out.x, cueq_out.x, atol=1e-6), ( + f'{k} \n\n {e3nn_f} \n\n {cueq_f}' + ) + + assert torch.allclose( + e3nn_out.inferred_total_energy, cueq_out.inferred_total_energy + ) + assert torch.allclose(e3nn_out.atomic_energy, cueq_out.atomic_energy) + assert torch.allclose( + e3nn_out.inferred_force, cueq_out.inferred_force, atol=1e-5 + ) + assert torch.allclose( + e3nn_out.inferred_stress, cueq_out.inferred_stress, atol=1e-5 + ) + + +@pytest.mark.filterwarnings('ignore:.*is not found from.*') +@pytest.mark.skipif( + not is_cue_available() or not torch.cuda.is_available(), + reason='cueq or gpu is not available', +) +@pytest.mark.parametrize( + 'start_from_cueq', + [ + (True), + (False), + ], +) +def test_checkpoint_convert(tmp_path, start_from_cueq): + torch.manual_seed(123) + model_from = get_model(use_cueq=start_from_cueq) + + cfg = get_model_config() + cfg.update( + { + 'cuequivariance_config': {'use': start_from_cueq}, + 'version': sevenn.__version__, + } + ) + torch.save( + {'model_state_dict': model_from.state_dict(), 'config': cfg}, + tmp_path / 'cp_from.pth', + ) + + backend = 'e3nn' if start_from_cueq else 'cueq' + model_to, _ = model_from_checkpoint_with_backend( + str(tmp_path / 'cp_from.pth'), backend + ) + model_to.to('cuda') + + model_from.set_is_batch_data(True) + model_to.set_is_batch_data(True) + + from_out = model_from(get_graphs(batched=True)) + to_out = model_to(get_graphs(batched=True)) + + assert torch.allclose( + from_out.inferred_total_energy, to_out.inferred_total_energy + ) + assert torch.allclose(from_out.atomic_energy, to_out.atomic_energy) + assert torch.allclose(from_out.inferred_force, to_out.inferred_force, atol=1e-5) + assert torch.allclose( + from_out.inferred_stress, to_out.inferred_stress, atol=1e-5 + ) + + +@pytest.mark.filterwarnings('ignore:.*is not found from.*') +@pytest.mark.skipif( + not is_cue_available() or not torch.cuda.is_available(), + reason='cueq or gpu is not available', +) +@pytest.mark.parametrize( + 'start_from_cueq', + [ + (True), + (False), + ], +) +def test_checkpoint_convert_no_batch(tmp_path, start_from_cueq): + torch.manual_seed(123) + model_from = get_model(use_cueq=start_from_cueq) + + cfg = get_model_config() + cfg.update( + { + 'cuequivariance_config': {'use': start_from_cueq}, + 'version': sevenn.__version__, + } + ) + torch.save( + {'model_state_dict': model_from.state_dict(), 'config': cfg}, + tmp_path / 'cp_from.pth', + ) + + backend = 'e3nn' if start_from_cueq else 'cueq' + model_to, _ = model_from_checkpoint_with_backend( + str(tmp_path / 'cp_from.pth'), backend + ) + model_to.to('cuda') + + model_from.set_is_batch_data(False) + model_to.set_is_batch_data(False) + + from_out = model_from(get_graphs(batched=False)[0]) + to_out = model_to(get_graphs(batched=False)[0]) + + assert torch.allclose( + from_out.inferred_total_energy, to_out.inferred_total_energy + ) + assert torch.allclose(from_out.atomic_energy, to_out.atomic_energy) + assert torch.allclose(from_out.inferred_force, to_out.inferred_force, atol=1e-5) + assert torch.allclose( + from_out.inferred_stress, to_out.inferred_stress, atol=1e-5 + ) + + +def assert_atoms(atoms1, atoms2, rtol=1e-5, atol=1e-6): + def acl(a, b, rtol=rtol, atol=atol): + return np.allclose(a, b, rtol=rtol, atol=atol) + + assert len(atoms1) == len(atoms2) + assert acl(atoms1.get_cell(), atoms2.get_cell()) + assert acl(atoms1.get_potential_energy(), atoms2.get_potential_energy()) + assert acl(atoms1.get_forces(), atoms2.get_forces(), rtol * 10, atol * 10) + assert acl( + atoms1.get_stress(voigt=False), + atoms2.get_stress(voigt=False), + rtol * 10, + atol * 10, + ) + # assert acl(atoms1.get_potential_energies(), atoms2.get_potential_energies()) + + +@pytest.mark.filterwarnings('ignore:.*is not found from.*') +@pytest.mark.skipif( + not is_cue_available() or not torch.cuda.is_available(), + reason='cueq or gpu is not available', +) +def test_calculator(tmp_path): + cueq = True + model = get_model(use_cueq=cueq) + ref_calc = SevenNetCalculator(model, file_type='model_instance') + atoms = copy.deepcopy(_atoms) + atoms.calc = ref_calc + + cfg = get_model_config() + cfg.update( + {'cuequivariance_config': {'use': cueq}, 'version': sevenn.__version__} + ) + + cp_path = str(tmp_path / 'cp.pth') + torch.save( + {'model_state_dict': model.state_dict(), 'config': cfg}, + cp_path, + ) + + calc2 = SevenNetCalculator(cp_path, enable_cueq=False) + atoms2 = copy.deepcopy(_atoms) + atoms2.calc = calc2 + + assert_atoms(atoms, atoms2) diff --git a/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_data.py b/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_data.py index b966b2a..fda1eff 100644 --- a/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_data.py +++ b/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_data.py @@ -1,521 +1,521 @@ -import logging -import os -import os.path as osp -import uuid -from collections import Counter -from copy import deepcopy -from typing import Literal - -import ase.calculators.singlepoint as singlepoint -import ase.io -import numpy as np -import pytest -import torch -from ase import Atoms -from ase.build import bulk, molecule -from torch_geometric.loader import DataLoader - -import sevenn._keys as KEY -import sevenn.train.dataload as dl -import sevenn.train.graph_dataset as ds -import sevenn.train.modal_dataset as modal_dataset -from sevenn._const import NUM_UNIV_ELEMENT -from sevenn.atom_graph_data import AtomGraphData -from sevenn.util import model_from_checkpoint, pretrained_name_to_path - -cutoff = 4.0 -lattice_constant = 3.35 - -_samples = { - 'bulk': bulk('NaCl', 'rocksalt', a=5.63), - 'mol': molecule('H2O'), - 'isolated': molecule('H'), - 'small_bulk': Atoms( - symbols='Cu', - positions=[ - (0, 0, 0), # Atom at the corner of the cube - ], - cell=[ - [lattice_constant, 0, 0], - [0, lattice_constant, 0], - [0, 0, lattice_constant], - ], - pbc=True, # Periodic boundary conditions - ), -} - - -_nedges_c4 = {'bulk': 36, 'mol': 6, 'isolated': 0, 'small_bulk': 18} - - -def get_atoms( - atoms_type: Literal['bulk', 'mol', 'isolated', 'small_bulk'], - init_y_as: Literal['calc', 'info', 'none'], -): - """ - Return atoms w, w/o reference values with its - # of edges for 4.0 cutoff length - """ - assert atoms_type in _samples - atoms = deepcopy(_samples[atoms_type]) - natoms = len(atoms) - if init_y_as == 'calc': - results = { - 'energy': np.random.rand(1), - 'forces': np.random.rand(natoms, 3), - 'stress': np.random.rand(6), - } - if not atoms.pbc.all(): - del results['stress'] - calc = singlepoint.SinglePointCalculator(atoms, **results) - atoms = calc.get_atoms() - elif init_y_as == 'info': - atoms.info['y_energy'] = np.random.rand(1) - atoms.arrays['y_force'] = np.random.rand(natoms, 3) - atoms.info['y_stress'] = np.random.rand(6) - if not atoms.pbc.all(): - del atoms.info['y_stress'] - return atoms, _nedges_c4[atoms_type] - - -@pytest.mark.parametrize('init_y_as', ['calc', 'info']) -@pytest.mark.parametrize('atoms_type', ['bulk', 'mol', 'isolated']) -def test_atoms_to_graph(atoms_type, init_y_as): - atoms, nedges = get_atoms(atoms_type, init_y_as) - is_stress = atoms.pbc.all() - y_from_calc = init_y_as == 'calc' - - graph = dl.atoms_to_graph(atoms, cutoff=cutoff, y_from_calc=y_from_calc) - - essential = { - 'atomic_numbers': ((len(atoms),), int), - 'pos': ((len(atoms), 3), float), - 'edge_index': ((2, nedges), int), - 'edge_vec': ((nedges, 3), float), - 'total_energy': ((), float), - 'force_of_atoms': ((len(atoms), 3), float), - 'cell_volume': ((), float), - 'num_atoms': ((), int), - 'per_atom_energy': ((), float), - 'stress': ((1, 6), float), - } - - for k, (shape, dtype) in essential.items(): - assert k in graph, f'{k} missing in graph' - assert isinstance( - graph[k], np.ndarray - ), f'{k}: {type(graph[k])} is not np.ndarray' - assert graph[k].shape == shape, f'{k} shape {graph[k].shape} != {shape}' - if not is_stress and k == 'stress': - assert np.isnan(graph[k]).all() - else: - assert graph[k].dtype == dtype, f'{k} dtype {graph[k].dtype} != {dtype}' - - assert graph['per_atom_energy'] == (graph['total_energy'] / len(atoms)) - assert graph['num_atoms'] == len(atoms) - if not is_stress: - assert graph['cell_volume'] == np.finfo(float).eps - - -@pytest.mark.parametrize('atoms_type', ['bulk', 'mol', 'isolated']) -def test_unlabeled_atoms_to_graph(atoms_type): - atoms, nedges = get_atoms(atoms_type, 'none') - - graph = dl.unlabeled_atoms_to_graph(atoms, cutoff=cutoff) - - essential = { - 'atomic_numbers': ((len(atoms),), int), - 'pos': ((len(atoms), 3), float), - 'edge_index': ((2, nedges), int), - 'edge_vec': ((nedges, 3), float), - 'cell_volume': ((), float), - 'num_atoms': ((), int), - } - - for k, (shape, dtype) in essential.items(): - assert k in graph, f'{k} missing in graph' - assert isinstance( - graph[k], np.ndarray - ), f'{k}: {type(graph[k])} is not np.ndarray' - assert graph[k].dtype == dtype, f'{k} dtype {graph[k].dtype} != {dtype}' - assert graph[k].shape == shape, f'{k} shape {graph[k].shape} != {shape}' - - assert graph['num_atoms'] == len(atoms) - if not atoms.pbc.all(): - assert graph['cell_volume'] == np.finfo(float).eps - - -@pytest.mark.parametrize('init_y_as', ['calc', 'info']) -@pytest.mark.parametrize('atoms_type', ['bulk', 'mol', 'isolated']) -def test_atom_graph_data(atoms_type, init_y_as): - atoms, nedges = get_atoms(atoms_type, init_y_as) - y_from_calc = init_y_as == 'calc' - is_stress = atoms.pbc.all() - np_graph = dl.atoms_to_graph(atoms, cutoff=cutoff, y_from_calc=y_from_calc) - graph = AtomGraphData.from_numpy_dict(np_graph) - - essential = { - 'atomic_numbers': ((len(atoms),), int), - 'edge_index': ((2, nedges), int), - 'edge_vec': ((nedges, 3), float), - } - auxilaray = { - 'x': ((len(atoms),), int), - 'pos': ((len(atoms), 3), float), - 'num_atoms': ((), int), - 'cell_volume': ((), float), - 'total_energy': ((), float), - 'per_atom_energy': ((), float), - 'force_of_atoms': ((len(atoms), 3), float), - 'stress': ((1, 6), float), - } - - for k, (shape, dtype) in essential.items(): - assert k in graph, f'{k} missing in graph' - assert isinstance( - graph[k], torch.Tensor - ), f'{k}: {type(graph[k])} is not an tensor' - assert graph[k].is_floating_point() == (dtype is float) - assert graph[k].shape == shape, f'{k} shape {graph[k].shape} != {shape}' - - for k, (shape, dtype) in auxilaray.items(): - if k not in graph: - continue - assert isinstance( - graph[k], torch.Tensor - ), f'{k}: {type(graph[k])} is not an tensor' - assert graph[k].shape == shape, f'{k} shape {graph[k].shape} != {shape}' - if not is_stress and k == 'stress': - assert torch.isnan(graph[k]).all() - else: - assert graph[k].is_floating_point() == (dtype is float) - - -def test_graph_build(): - """ - Compare parallel implementation, should preserve order - """ - atoms_list = [ - get_atoms(t, 'calc')[0] # type: ignore - for t in list(_samples.keys()) - ] - one_core = dl.graph_build(atoms_list, cutoff, num_cores=1, y_from_calc=True) - two_core = dl.graph_build(atoms_list, cutoff, num_cores=2, y_from_calc=True) - - assert len(one_core) == len(two_core) - for g1, g2 in zip(one_core, two_core): - assert set(g1.keys()) == set(g2.keys()) - for k in g1.keys(): - if not isinstance(g1[k], torch.Tensor): - continue - if k == 'stress': # TODO: robust way to test it - assert torch.allclose(g1[k], g2[k]) or ( - torch.isnan(g1[k]).all() == torch.isnan(g2[k]).all() - ) - else: - assert torch.allclose(g1[k], g2[k]) - - -@pytest.fixture(scope='module') -def graph_dataset_tuple(): - tmpdir = os.getenv('TMPDIR', '/tmp') - randstr = uuid.uuid4().hex - assert os.access(tmpdir, os.W_OK), f'{tmpdir} is not writable' - - root = tmpdir - files = f'{root}/{randstr}.extxyz' - atoms_list = [ - get_atoms(atype, 'calc')[0] # type: ignore - for atype in ['bulk', 'mol', 'isolated'] - ] - ase.io.write(files, atoms_list, 'extxyz') - - dataset = ds.SevenNetGraphDataset( - cutoff=cutoff, - root=root, - files=files, - processed_name=f'{randstr}.pt', - ) - assert os.path.isfile(f'{root}/sevenn_data/{randstr}.pt'), 'dataset not written' - return dataset, atoms_list - - -def test_sevenn_graph_dataset_properties(graph_dataset_tuple): - dataset, atoms_list = graph_dataset_tuple - - species = set() - natoms = Counter() - elist = [] - e_per_list = [] - flist = [] - slist = [] - for at in atoms_list: - chems = at.get_chemical_symbols() - species.update(chems) - natoms.update(chems) - elist.append(at.get_potential_energy()) - e_per_list.append(at.get_potential_energy() / len(at)) - flist.extend(at.get_forces()) - try: - slist.append(at.get_stress()) - except NotImplementedError: - slist.append(np.full(6, np.nan)) - - elist = np.array(elist) - e_per_list = np.array(e_per_list) - flist = np.array(flist) - slist = np.array(slist) - - natoms['total'] = sum([cnt for cnt in list(natoms.values())]) - - assert set(dataset.species) == species - assert dataset.natoms == natoms - assert np.allclose(dataset.per_atom_energy_mean, e_per_list.mean()) - assert np.allclose(dataset.force_rms, np.sqrt((flist**2).mean())) - - -def test_sevenn_graph_dataset_elemwise_energies(graph_dataset_tuple): - logger = logging.getLogger(__name__) - - dataset, atoms_list = graph_dataset_tuple - - ref_e = dataset.elemwise_reference_energies - assert len(ref_e) == NUM_UNIV_ELEMENT - z_set = set() - for atoms in atoms_list: - inferred_e = 0 - atomic_numbers = atoms.get_atomic_numbers() - z_set.update(atomic_numbers) - for z in atomic_numbers: - inferred_e += ref_e[z] - # it never be same, but should be similar - logger.info('elemwise energy should be similar:') - logger.info(f'{inferred_e:4f} {atoms.get_potential_energy()[0]:4f}') - - for z in range(NUM_UNIV_ELEMENT): - if z not in z_set: - assert ref_e[z] == 0 - - -def test_sevenn_graph_dataset_statistics(graph_dataset_tuple): - dataset, atoms_list = graph_dataset_tuple - - elist = [] - e_per_list = [] - flist = [] - slist = [] - for at in atoms_list: - elist.append(at.get_potential_energy()) - e_per_list.append(at.get_potential_energy() / len(at)) - flist.extend(at.get_forces()) - try: - slist.append(at.get_stress()) - except NotImplementedError: - slist.append(np.full(6, np.nan)) - - dct = { - 'total_energy': np.array(elist), - 'per_atom_energy': np.array(e_per_list), - 'force_of_atoms': np.array(flist).flatten(), - # 'stress': np.array(slist), # TODO: it may have nan - } - - for key in dct: - assert np.allclose(dataset.statistics[key]['mean'], dct[key].mean()), key - assert np.allclose(dataset.statistics[key]['std'], dct[key].std(ddof=0)), key - assert np.allclose( - dataset.statistics[key]['median'], np.median(dct[key]) - ), key - assert np.allclose(dataset.statistics[key]['max'], dct[key].max()), key - assert np.allclose(dataset.statistics[key]['min'], dct[key].min()), key - - -def test_sevenn_mm_dataset_statistics(tmp_path): - - files = osp.join(tmp_path, 'gd_one.extxyz') - atoms_list1 = [ - get_atoms(atype, 'calc')[0] # type: ignore - for atype in ['bulk', 'bulk', 'bulk', 'bulk'] - ] - ase.io.write(files, atoms_list1, 'extxyz') - - gd1 = ds.SevenNetGraphDataset( - cutoff=cutoff, - root=tmp_path, - files=files, - processed_name='gd_one.pt', - ) - - files = osp.join(tmp_path, 'gd_two.extxyz') - atoms_list2 = [ - get_atoms(atype, 'calc')[0] # type: ignore - for atype in ['mol', 'mol', 'bulk'] - ] - ase.io.write(files, atoms_list2, 'extxyz') - - gd2 = ds.SevenNetGraphDataset( - cutoff=cutoff, - root=tmp_path, - files=files, - processed_name='gd_two.pt', - ) - - ref = ds.SevenNetGraphDataset( - cutoff=cutoff, - root=tmp_path, - files=[gd1.processed_paths[0], gd2.processed_paths[0]], - processed_name='combined.pt', - ) - - mm = modal_dataset.SevenNetMultiModalDataset( - {'modal1': gd1, 'modal2': gd2} - ) - - assert np.allclose(ref.per_atom_energy_mean, mm.per_atom_energy_mean['total']) - assert np.allclose(ref.avg_num_neigh, mm.avg_num_neigh['total']) - assert np.allclose(ref.force_rms, mm.force_rms['total']) - assert set(ref.species) == set(mm.species['total']) - - -@pytest.mark.parametrize( - 'a_types,init_ys', [(['bulk', 'mol', 'isolated'], ['calc', 'calc', 'calc'])] -) -def test_7net_graph_dataset_batch_shape(a_types, init_ys, tmp_path): - assert len(a_types) == len(init_ys) - n_graph = len(a_types) - atoms_list = [] - tot_edges = 0 - tot_atoms = 0 - for a_type, init_y in zip(a_types, init_ys): - atoms, n_edge = get_atoms(a_type, init_y) - tot_edges += n_edge - tot_atoms += len(atoms) - atoms_list.append(atoms) - ase.io.write(tmp_path / 'tmp', atoms_list, format='extxyz') - dataset = ds.SevenNetGraphDataset(cutoff, tmp_path, str(tmp_path / 'tmp')) - loader = DataLoader(dataset, batch_size=n_graph) - graph = next(iter(loader)) - - essential = { - 'x': ((tot_atoms,), int), - 'atomic_numbers': ((tot_atoms,), int), - 'pos': ((tot_atoms, 3), float), - 'edge_index': ((2, tot_edges), int), - 'edge_vec': ((tot_edges, 3), float), - 'total_energy': ((n_graph,), float), - 'force_of_atoms': ((tot_atoms, 3), float), - 'cell_volume': ((n_graph,), float), - 'num_atoms': ((n_graph,), int), - 'per_atom_energy': ((n_graph,), float), - 'stress': ((n_graph, 6), float), - 'batch': ((tot_atoms,), int), # from PyG - } - - for k, (shape, dtype) in essential.items(): - assert k in graph, f'{k} missing in graph' - assert isinstance( - graph[k], torch.Tensor - ), f'{k}: {type(graph[k])} is not an tensor' - assert graph[k].is_floating_point() == (dtype is float) - assert graph[k].shape == shape, f'{k} shape {graph[k].shape} != {shape}' - - -@pytest.mark.parametrize('atoms_type', ['bulk', 'mol', 'isolated', 'small_bulk']) -def test_graph_build_ase_and_matscipy(atoms_type): - atoms, _ = get_atoms(atoms_type, 'calc') - atoms.rattle() - pos = atoms.get_positions() - cell = np.array(atoms.get_cell()) - pbc = atoms.get_pbc() - - # graph build check - # ase graph build - edge_src_ase, edge_dst_ase, edge_vec_ase, shifts_ase = dl._graph_build_ase( - cutoff, pbc, cell, pos - ) - # matscipy graph build - edge_src_matsci, edge_dst_matsci, edge_vec_matsci, shifts_matsci = ( - dl._graph_build_matscipy(cutoff, pbc, cell, pos) - ) - - # sort the graph - sorted_indices_ase = np.lexsort( - (edge_vec_ase[:, 2], edge_vec_ase[:, 1], edge_vec_ase[:, 0]) - ) - sorted_indices_matsci = np.lexsort( - (edge_vec_matsci[:, 2], edge_vec_matsci[:, 1], edge_vec_matsci[:, 0]) - ) - sorted_vec_ase = edge_vec_ase[sorted_indices_ase] - sorted_vec_matsci = edge_vec_matsci[sorted_indices_matsci] - sorted_src_ase = edge_src_ase[sorted_indices_ase] - sorted_dst_ase = edge_dst_ase[sorted_indices_ase] - sorted_src_matsci = edge_src_matsci[sorted_indices_matsci] - sorted_dst_matsci = edge_dst_matsci[sorted_indices_matsci] - sorted_shift_ase = shifts_ase[sorted_indices_ase] - sorted_shift_matsci = shifts_matsci[sorted_indices_matsci] - - # compare the result - assert np.allclose(sorted_vec_ase, sorted_vec_matsci) - assert np.array_equal(sorted_src_ase, sorted_src_matsci) - assert np.array_equal(sorted_dst_ase, sorted_dst_matsci) - assert np.array_equal(sorted_shift_ase, sorted_shift_matsci) - - # energy test - model, _ = model_from_checkpoint(pretrained_name_to_path('7net-0_11July2024')) - model.eval() - model.set_is_batch_data(False) - - # for ase energy - edge_idx_ase = np.array([edge_src_ase, edge_dst_ase]) - atomic_numbers = atoms.get_atomic_numbers() - cell = np.array(cell) - vol = dl._correct_scalar(atoms.cell.volume) - if vol == 0: - vol = np.array(np.finfo(float).eps) - - data_ase = { - KEY.NODE_FEATURE: atomic_numbers, - KEY.ATOMIC_NUMBERS: atomic_numbers, - KEY.POS: pos, - KEY.EDGE_IDX: edge_idx_ase, - KEY.EDGE_VEC: edge_vec_ase, - KEY.CELL: cell, - KEY.CELL_SHIFT: shifts_ase, - KEY.CELL_VOLUME: vol, - KEY.NUM_ATOMS: dl._correct_scalar(len(atomic_numbers)), - } - data_ase[KEY.INFO] = {} - atom_graph_data_ase = AtomGraphData.from_numpy_dict(data_ase) - output_ase = model(atom_graph_data_ase) - ase_pred_energy = output_ase[KEY.PRED_TOTAL_ENERGY] - ase_pred_force = output_ase[KEY.PRED_FORCE] - ase_pred_stress = output_ase[KEY.PRED_STRESS] - - # for matsci energy - edge_idx_matsci = np.array([edge_src_matsci, edge_dst_matsci]) - atomic_numbers = atoms.get_atomic_numbers() - cell = np.array(cell) - vol = dl._correct_scalar(atoms.cell.volume) - if vol == 0: - vol = np.array(np.finfo(float).eps) - - data_matsci = { - KEY.NODE_FEATURE: atomic_numbers, - KEY.ATOMIC_NUMBERS: atomic_numbers, - KEY.POS: pos, - KEY.EDGE_IDX: edge_idx_matsci, - KEY.EDGE_VEC: edge_vec_matsci, - KEY.CELL: cell, - KEY.CELL_SHIFT: shifts_matsci, - KEY.CELL_VOLUME: vol, - KEY.NUM_ATOMS: dl._correct_scalar(len(atomic_numbers)), - } - data_matsci[KEY.INFO] = {} - atom_graph_data_matsci = AtomGraphData.from_numpy_dict(data_matsci) - output_matsci = model(atom_graph_data_matsci) - matsci_pred_energy = output_matsci[KEY.PRED_TOTAL_ENERGY] - matsci_pred_force = output_matsci[KEY.PRED_FORCE] - matsci_pred_stress = output_matsci[KEY.PRED_STRESS] - assert torch.equal(ase_pred_energy, matsci_pred_energy) - assert torch.allclose(ase_pred_force, matsci_pred_force, atol=1e-06) - assert torch.allclose(ase_pred_stress, matsci_pred_stress) +import logging +import os +import os.path as osp +import uuid +from collections import Counter +from copy import deepcopy +from typing import Literal + +import ase.calculators.singlepoint as singlepoint +import ase.io +import numpy as np +import pytest +import torch +from ase import Atoms +from ase.build import bulk, molecule +from torch_geometric.loader import DataLoader + +import sevenn._keys as KEY +import sevenn.train.dataload as dl +import sevenn.train.graph_dataset as ds +import sevenn.train.modal_dataset as modal_dataset +from sevenn._const import NUM_UNIV_ELEMENT +from sevenn.atom_graph_data import AtomGraphData +from sevenn.util import model_from_checkpoint, pretrained_name_to_path + +cutoff = 4.0 +lattice_constant = 3.35 + +_samples = { + 'bulk': bulk('NaCl', 'rocksalt', a=5.63), + 'mol': molecule('H2O'), + 'isolated': molecule('H'), + 'small_bulk': Atoms( + symbols='Cu', + positions=[ + (0, 0, 0), # Atom at the corner of the cube + ], + cell=[ + [lattice_constant, 0, 0], + [0, lattice_constant, 0], + [0, 0, lattice_constant], + ], + pbc=True, # Periodic boundary conditions + ), +} + + +_nedges_c4 = {'bulk': 36, 'mol': 6, 'isolated': 0, 'small_bulk': 18} + + +def get_atoms( + atoms_type: Literal['bulk', 'mol', 'isolated', 'small_bulk'], + init_y_as: Literal['calc', 'info', 'none'], +): + """ + Return atoms w, w/o reference values with its + # of edges for 4.0 cutoff length + """ + assert atoms_type in _samples + atoms = deepcopy(_samples[atoms_type]) + natoms = len(atoms) + if init_y_as == 'calc': + results = { + 'energy': np.random.rand(1), + 'forces': np.random.rand(natoms, 3), + 'stress': np.random.rand(6), + } + if not atoms.pbc.all(): + del results['stress'] + calc = singlepoint.SinglePointCalculator(atoms, **results) + atoms = calc.get_atoms() + elif init_y_as == 'info': + atoms.info['y_energy'] = np.random.rand(1) + atoms.arrays['y_force'] = np.random.rand(natoms, 3) + atoms.info['y_stress'] = np.random.rand(6) + if not atoms.pbc.all(): + del atoms.info['y_stress'] + return atoms, _nedges_c4[atoms_type] + + +@pytest.mark.parametrize('init_y_as', ['calc', 'info']) +@pytest.mark.parametrize('atoms_type', ['bulk', 'mol', 'isolated']) +def test_atoms_to_graph(atoms_type, init_y_as): + atoms, nedges = get_atoms(atoms_type, init_y_as) + is_stress = atoms.pbc.all() + y_from_calc = init_y_as == 'calc' + + graph = dl.atoms_to_graph(atoms, cutoff=cutoff, y_from_calc=y_from_calc) + + essential = { + 'atomic_numbers': ((len(atoms),), int), + 'pos': ((len(atoms), 3), float), + 'edge_index': ((2, nedges), int), + 'edge_vec': ((nedges, 3), float), + 'total_energy': ((), float), + 'force_of_atoms': ((len(atoms), 3), float), + 'cell_volume': ((), float), + 'num_atoms': ((), int), + 'per_atom_energy': ((), float), + 'stress': ((1, 6), float), + } + + for k, (shape, dtype) in essential.items(): + assert k in graph, f'{k} missing in graph' + assert isinstance( + graph[k], np.ndarray + ), f'{k}: {type(graph[k])} is not np.ndarray' + assert graph[k].shape == shape, f'{k} shape {graph[k].shape} != {shape}' + if not is_stress and k == 'stress': + assert np.isnan(graph[k]).all() + else: + assert graph[k].dtype == dtype, f'{k} dtype {graph[k].dtype} != {dtype}' + + assert graph['per_atom_energy'] == (graph['total_energy'] / len(atoms)) + assert graph['num_atoms'] == len(atoms) + if not is_stress: + assert graph['cell_volume'] == np.finfo(float).eps + + +@pytest.mark.parametrize('atoms_type', ['bulk', 'mol', 'isolated']) +def test_unlabeled_atoms_to_graph(atoms_type): + atoms, nedges = get_atoms(atoms_type, 'none') + + graph = dl.unlabeled_atoms_to_graph(atoms, cutoff=cutoff) + + essential = { + 'atomic_numbers': ((len(atoms),), int), + 'pos': ((len(atoms), 3), float), + 'edge_index': ((2, nedges), int), + 'edge_vec': ((nedges, 3), float), + 'cell_volume': ((), float), + 'num_atoms': ((), int), + } + + for k, (shape, dtype) in essential.items(): + assert k in graph, f'{k} missing in graph' + assert isinstance( + graph[k], np.ndarray + ), f'{k}: {type(graph[k])} is not np.ndarray' + assert graph[k].dtype == dtype, f'{k} dtype {graph[k].dtype} != {dtype}' + assert graph[k].shape == shape, f'{k} shape {graph[k].shape} != {shape}' + + assert graph['num_atoms'] == len(atoms) + if not atoms.pbc.all(): + assert graph['cell_volume'] == np.finfo(float).eps + + +@pytest.mark.parametrize('init_y_as', ['calc', 'info']) +@pytest.mark.parametrize('atoms_type', ['bulk', 'mol', 'isolated']) +def test_atom_graph_data(atoms_type, init_y_as): + atoms, nedges = get_atoms(atoms_type, init_y_as) + y_from_calc = init_y_as == 'calc' + is_stress = atoms.pbc.all() + np_graph = dl.atoms_to_graph(atoms, cutoff=cutoff, y_from_calc=y_from_calc) + graph = AtomGraphData.from_numpy_dict(np_graph) + + essential = { + 'atomic_numbers': ((len(atoms),), int), + 'edge_index': ((2, nedges), int), + 'edge_vec': ((nedges, 3), float), + } + auxilaray = { + 'x': ((len(atoms),), int), + 'pos': ((len(atoms), 3), float), + 'num_atoms': ((), int), + 'cell_volume': ((), float), + 'total_energy': ((), float), + 'per_atom_energy': ((), float), + 'force_of_atoms': ((len(atoms), 3), float), + 'stress': ((1, 6), float), + } + + for k, (shape, dtype) in essential.items(): + assert k in graph, f'{k} missing in graph' + assert isinstance( + graph[k], torch.Tensor + ), f'{k}: {type(graph[k])} is not an tensor' + assert graph[k].is_floating_point() == (dtype is float) + assert graph[k].shape == shape, f'{k} shape {graph[k].shape} != {shape}' + + for k, (shape, dtype) in auxilaray.items(): + if k not in graph: + continue + assert isinstance( + graph[k], torch.Tensor + ), f'{k}: {type(graph[k])} is not an tensor' + assert graph[k].shape == shape, f'{k} shape {graph[k].shape} != {shape}' + if not is_stress and k == 'stress': + assert torch.isnan(graph[k]).all() + else: + assert graph[k].is_floating_point() == (dtype is float) + + +def test_graph_build(): + """ + Compare parallel implementation, should preserve order + """ + atoms_list = [ + get_atoms(t, 'calc')[0] # type: ignore + for t in list(_samples.keys()) + ] + one_core = dl.graph_build(atoms_list, cutoff, num_cores=1, y_from_calc=True) + two_core = dl.graph_build(atoms_list, cutoff, num_cores=2, y_from_calc=True) + + assert len(one_core) == len(two_core) + for g1, g2 in zip(one_core, two_core): + assert set(g1.keys()) == set(g2.keys()) + for k in g1.keys(): + if not isinstance(g1[k], torch.Tensor): + continue + if k == 'stress': # TODO: robust way to test it + assert torch.allclose(g1[k], g2[k]) or ( + torch.isnan(g1[k]).all() == torch.isnan(g2[k]).all() + ) + else: + assert torch.allclose(g1[k], g2[k]) + + +@pytest.fixture(scope='module') +def graph_dataset_tuple(): + tmpdir = os.getenv('TMPDIR', '/tmp') + randstr = uuid.uuid4().hex + assert os.access(tmpdir, os.W_OK), f'{tmpdir} is not writable' + + root = tmpdir + files = f'{root}/{randstr}.extxyz' + atoms_list = [ + get_atoms(atype, 'calc')[0] # type: ignore + for atype in ['bulk', 'mol', 'isolated'] + ] + ase.io.write(files, atoms_list, 'extxyz') + + dataset = ds.SevenNetGraphDataset( + cutoff=cutoff, + root=root, + files=files, + processed_name=f'{randstr}.pt', + ) + assert os.path.isfile(f'{root}/sevenn_data/{randstr}.pt'), 'dataset not written' + return dataset, atoms_list + + +def test_sevenn_graph_dataset_properties(graph_dataset_tuple): + dataset, atoms_list = graph_dataset_tuple + + species = set() + natoms = Counter() + elist = [] + e_per_list = [] + flist = [] + slist = [] + for at in atoms_list: + chems = at.get_chemical_symbols() + species.update(chems) + natoms.update(chems) + elist.append(at.get_potential_energy()) + e_per_list.append(at.get_potential_energy() / len(at)) + flist.extend(at.get_forces()) + try: + slist.append(at.get_stress()) + except NotImplementedError: + slist.append(np.full(6, np.nan)) + + elist = np.array(elist) + e_per_list = np.array(e_per_list) + flist = np.array(flist) + slist = np.array(slist) + + natoms['total'] = sum([cnt for cnt in list(natoms.values())]) + + assert set(dataset.species) == species + assert dataset.natoms == natoms + assert np.allclose(dataset.per_atom_energy_mean, e_per_list.mean()) + assert np.allclose(dataset.force_rms, np.sqrt((flist**2).mean())) + + +def test_sevenn_graph_dataset_elemwise_energies(graph_dataset_tuple): + logger = logging.getLogger(__name__) + + dataset, atoms_list = graph_dataset_tuple + + ref_e = dataset.elemwise_reference_energies + assert len(ref_e) == NUM_UNIV_ELEMENT + z_set = set() + for atoms in atoms_list: + inferred_e = 0 + atomic_numbers = atoms.get_atomic_numbers() + z_set.update(atomic_numbers) + for z in atomic_numbers: + inferred_e += ref_e[z] + # it never be same, but should be similar + logger.info('elemwise energy should be similar:') + logger.info(f'{inferred_e:4f} {atoms.get_potential_energy()[0]:4f}') + + for z in range(NUM_UNIV_ELEMENT): + if z not in z_set: + assert ref_e[z] == 0 + + +def test_sevenn_graph_dataset_statistics(graph_dataset_tuple): + dataset, atoms_list = graph_dataset_tuple + + elist = [] + e_per_list = [] + flist = [] + slist = [] + for at in atoms_list: + elist.append(at.get_potential_energy()) + e_per_list.append(at.get_potential_energy() / len(at)) + flist.extend(at.get_forces()) + try: + slist.append(at.get_stress()) + except NotImplementedError: + slist.append(np.full(6, np.nan)) + + dct = { + 'total_energy': np.array(elist), + 'per_atom_energy': np.array(e_per_list), + 'force_of_atoms': np.array(flist).flatten(), + # 'stress': np.array(slist), # TODO: it may have nan + } + + for key in dct: + assert np.allclose(dataset.statistics[key]['mean'], dct[key].mean()), key + assert np.allclose(dataset.statistics[key]['std'], dct[key].std(ddof=0)), key + assert np.allclose( + dataset.statistics[key]['median'], np.median(dct[key]) + ), key + assert np.allclose(dataset.statistics[key]['max'], dct[key].max()), key + assert np.allclose(dataset.statistics[key]['min'], dct[key].min()), key + + +def test_sevenn_mm_dataset_statistics(tmp_path): + + files = osp.join(tmp_path, 'gd_one.extxyz') + atoms_list1 = [ + get_atoms(atype, 'calc')[0] # type: ignore + for atype in ['bulk', 'bulk', 'bulk', 'bulk'] + ] + ase.io.write(files, atoms_list1, 'extxyz') + + gd1 = ds.SevenNetGraphDataset( + cutoff=cutoff, + root=tmp_path, + files=files, + processed_name='gd_one.pt', + ) + + files = osp.join(tmp_path, 'gd_two.extxyz') + atoms_list2 = [ + get_atoms(atype, 'calc')[0] # type: ignore + for atype in ['mol', 'mol', 'bulk'] + ] + ase.io.write(files, atoms_list2, 'extxyz') + + gd2 = ds.SevenNetGraphDataset( + cutoff=cutoff, + root=tmp_path, + files=files, + processed_name='gd_two.pt', + ) + + ref = ds.SevenNetGraphDataset( + cutoff=cutoff, + root=tmp_path, + files=[gd1.processed_paths[0], gd2.processed_paths[0]], + processed_name='combined.pt', + ) + + mm = modal_dataset.SevenNetMultiModalDataset( + {'modal1': gd1, 'modal2': gd2} + ) + + assert np.allclose(ref.per_atom_energy_mean, mm.per_atom_energy_mean['total']) + assert np.allclose(ref.avg_num_neigh, mm.avg_num_neigh['total']) + assert np.allclose(ref.force_rms, mm.force_rms['total']) + assert set(ref.species) == set(mm.species['total']) + + +@pytest.mark.parametrize( + 'a_types,init_ys', [(['bulk', 'mol', 'isolated'], ['calc', 'calc', 'calc'])] +) +def test_7net_graph_dataset_batch_shape(a_types, init_ys, tmp_path): + assert len(a_types) == len(init_ys) + n_graph = len(a_types) + atoms_list = [] + tot_edges = 0 + tot_atoms = 0 + for a_type, init_y in zip(a_types, init_ys): + atoms, n_edge = get_atoms(a_type, init_y) + tot_edges += n_edge + tot_atoms += len(atoms) + atoms_list.append(atoms) + ase.io.write(tmp_path / 'tmp', atoms_list, format='extxyz') + dataset = ds.SevenNetGraphDataset(cutoff, tmp_path, str(tmp_path / 'tmp')) + loader = DataLoader(dataset, batch_size=n_graph) + graph = next(iter(loader)) + + essential = { + 'x': ((tot_atoms,), int), + 'atomic_numbers': ((tot_atoms,), int), + 'pos': ((tot_atoms, 3), float), + 'edge_index': ((2, tot_edges), int), + 'edge_vec': ((tot_edges, 3), float), + 'total_energy': ((n_graph,), float), + 'force_of_atoms': ((tot_atoms, 3), float), + 'cell_volume': ((n_graph,), float), + 'num_atoms': ((n_graph,), int), + 'per_atom_energy': ((n_graph,), float), + 'stress': ((n_graph, 6), float), + 'batch': ((tot_atoms,), int), # from PyG + } + + for k, (shape, dtype) in essential.items(): + assert k in graph, f'{k} missing in graph' + assert isinstance( + graph[k], torch.Tensor + ), f'{k}: {type(graph[k])} is not an tensor' + assert graph[k].is_floating_point() == (dtype is float) + assert graph[k].shape == shape, f'{k} shape {graph[k].shape} != {shape}' + + +@pytest.mark.parametrize('atoms_type', ['bulk', 'mol', 'isolated', 'small_bulk']) +def test_graph_build_ase_and_matscipy(atoms_type): + atoms, _ = get_atoms(atoms_type, 'calc') + atoms.rattle() + pos = atoms.get_positions() + cell = np.array(atoms.get_cell()) + pbc = atoms.get_pbc() + + # graph build check + # ase graph build + edge_src_ase, edge_dst_ase, edge_vec_ase, shifts_ase = dl._graph_build_ase( + cutoff, pbc, cell, pos + ) + # matscipy graph build + edge_src_matsci, edge_dst_matsci, edge_vec_matsci, shifts_matsci = ( + dl._graph_build_matscipy(cutoff, pbc, cell, pos) + ) + + # sort the graph + sorted_indices_ase = np.lexsort( + (edge_vec_ase[:, 2], edge_vec_ase[:, 1], edge_vec_ase[:, 0]) + ) + sorted_indices_matsci = np.lexsort( + (edge_vec_matsci[:, 2], edge_vec_matsci[:, 1], edge_vec_matsci[:, 0]) + ) + sorted_vec_ase = edge_vec_ase[sorted_indices_ase] + sorted_vec_matsci = edge_vec_matsci[sorted_indices_matsci] + sorted_src_ase = edge_src_ase[sorted_indices_ase] + sorted_dst_ase = edge_dst_ase[sorted_indices_ase] + sorted_src_matsci = edge_src_matsci[sorted_indices_matsci] + sorted_dst_matsci = edge_dst_matsci[sorted_indices_matsci] + sorted_shift_ase = shifts_ase[sorted_indices_ase] + sorted_shift_matsci = shifts_matsci[sorted_indices_matsci] + + # compare the result + assert np.allclose(sorted_vec_ase, sorted_vec_matsci) + assert np.array_equal(sorted_src_ase, sorted_src_matsci) + assert np.array_equal(sorted_dst_ase, sorted_dst_matsci) + assert np.array_equal(sorted_shift_ase, sorted_shift_matsci) + + # energy test + model, _ = model_from_checkpoint(pretrained_name_to_path('7net-0_11July2024')) + model.eval() + model.set_is_batch_data(False) + + # for ase energy + edge_idx_ase = np.array([edge_src_ase, edge_dst_ase]) + atomic_numbers = atoms.get_atomic_numbers() + cell = np.array(cell) + vol = dl._correct_scalar(atoms.cell.volume) + if vol == 0: + vol = np.array(np.finfo(float).eps) + + data_ase = { + KEY.NODE_FEATURE: atomic_numbers, + KEY.ATOMIC_NUMBERS: atomic_numbers, + KEY.POS: pos, + KEY.EDGE_IDX: edge_idx_ase, + KEY.EDGE_VEC: edge_vec_ase, + KEY.CELL: cell, + KEY.CELL_SHIFT: shifts_ase, + KEY.CELL_VOLUME: vol, + KEY.NUM_ATOMS: dl._correct_scalar(len(atomic_numbers)), + } + data_ase[KEY.INFO] = {} + atom_graph_data_ase = AtomGraphData.from_numpy_dict(data_ase) + output_ase = model(atom_graph_data_ase) + ase_pred_energy = output_ase[KEY.PRED_TOTAL_ENERGY] + ase_pred_force = output_ase[KEY.PRED_FORCE] + ase_pred_stress = output_ase[KEY.PRED_STRESS] + + # for matsci energy + edge_idx_matsci = np.array([edge_src_matsci, edge_dst_matsci]) + atomic_numbers = atoms.get_atomic_numbers() + cell = np.array(cell) + vol = dl._correct_scalar(atoms.cell.volume) + if vol == 0: + vol = np.array(np.finfo(float).eps) + + data_matsci = { + KEY.NODE_FEATURE: atomic_numbers, + KEY.ATOMIC_NUMBERS: atomic_numbers, + KEY.POS: pos, + KEY.EDGE_IDX: edge_idx_matsci, + KEY.EDGE_VEC: edge_vec_matsci, + KEY.CELL: cell, + KEY.CELL_SHIFT: shifts_matsci, + KEY.CELL_VOLUME: vol, + KEY.NUM_ATOMS: dl._correct_scalar(len(atomic_numbers)), + } + data_matsci[KEY.INFO] = {} + atom_graph_data_matsci = AtomGraphData.from_numpy_dict(data_matsci) + output_matsci = model(atom_graph_data_matsci) + matsci_pred_energy = output_matsci[KEY.PRED_TOTAL_ENERGY] + matsci_pred_force = output_matsci[KEY.PRED_FORCE] + matsci_pred_stress = output_matsci[KEY.PRED_STRESS] + assert torch.equal(ase_pred_energy, matsci_pred_energy) + assert torch.allclose(ase_pred_force, matsci_pred_force, atol=1e-06) + assert torch.allclose(ase_pred_stress, matsci_pred_stress) diff --git a/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_errors.py b/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_errors.py index 455fd76..36024a8 100644 --- a/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_errors.py +++ b/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_errors.py @@ -1,285 +1,285 @@ -# test_errors: error recorder.py, loss.py -from copy import deepcopy - -import numpy as np -import pytest -import torch -import torch.nn -from torch import tensor - -import sevenn.error_recorder as erc -import sevenn.train.loss as loss -from sevenn.atom_graph_data import AtomGraphData -from sevenn.train.optim import loss_dict - -_default_config = { - 'loss': 'mse', - 'loss_param': {}, - 'error_record': [ - ('Energy', 'RMSE'), - ('Force', 'RMSE'), - ('Stress', 'RMSE'), - ('Energy', 'MAE'), - ('Force', 'MAE'), - ('Stress', 'MAE'), - ('TotalLoss', 'None'), - ], - 'is_train_stress': True, - 'force_loss_weight': 1.0, - 'stress_loss_weight': 0.001, -} - -_erc_test_params = [ - ('TotalEnergy', 4, 3), - ('Energy', 4, 3), - ('Force', 4, 3), - ('Stress', 4, 3), - ('Stress_GPa', 4, 3), - ('Energy', 4, 1), - ('Energy', 1, 1), - ('Force', 1, 3), - ('Stress', 1, 3), -] - - -def acl(a, b): - return torch.allclose(a, b, atol=1e-6) - - -def config(**overwrite): # to make it read-only - cf = deepcopy(_default_config) - for k, v in overwrite.items(): - cf[k] = v - return cf - - -def test_per_atom_energy_loss(): - loss_f = loss.PerAtomEnergyLoss(criterion=torch.nn.MSELoss()) - ref = torch.rand(2) - pred = torch.rand(2) - natoms = torch.randint(1, 10, (2,)) - tmp = AtomGraphData( - total_energy=ref, - inferred_total_energy=pred, - num_atoms=natoms, - ).to_dict() - ret = loss_f.get_loss(tmp) - assert loss_f.criterion is not None - assert torch.allclose(loss_f.criterion((ref / natoms), (pred / natoms)), ret) - - -def test_force_loss(): - loss_f = loss.ForceLoss(criterion=torch.nn.MSELoss()) - ref = torch.rand((4, 3)) - pred = torch.rand((4, 3)) - batch = tensor([0, 0, 0, 1]) - tmp = AtomGraphData( - force_of_atoms=ref, - inferred_force=pred, - batch=batch, - ).to_dict() - ret = loss_f.get_loss(tmp) - assert loss_f.criterion is not None - assert torch.allclose(loss_f.criterion(ref.reshape(-1), pred.reshape(-1)), ret) - - -def test_stress_loss(): - loss_f = loss.StressLoss(criterion=torch.nn.MSELoss()) - ref = torch.rand((2, 6)) - pred = torch.rand((2, 6)) - tmp = AtomGraphData( - stress=ref, - inferred_stress=pred, - ).to_dict() - ret = loss_f.get_loss(tmp) - KB = 1602.1766208 - assert loss_f.criterion is not None - assert torch.allclose( - loss_f.criterion(ref.reshape(-1) * KB, pred.reshape(-1) * KB), ret - ) - - -@pytest.mark.parametrize('conf', [config(), config(is_train_stress=False)]) -def test_loss_from_config(conf): - loss_functions = loss.get_loss_functions_from_config(conf) - - if conf['is_train_stress']: - assert len(loss_functions) == 3 - else: - assert len(loss_functions) == 2 - - for loss_def, w in loss_functions: - assert isinstance(loss_def, loss.LossDefinition) - if isinstance(loss_def, loss.PerAtomEnergyLoss): - assert w == 1.0 - elif isinstance(loss_def, loss.ForceLoss): - assert w == conf['force_loss_weight'] - elif isinstance(loss_def, loss.StressLoss): - assert w == conf['stress_loss_weight'] - else: - raise ValueError(f'Unexpected loss function: {loss_def}') - - -@pytest.mark.parametrize('err_type,ndata,natoms', _erc_test_params) -def test_rms_error(err_type, ndata, natoms): - err_dct = erc.get_err_type(err_type) - err = erc.RMSError(**err_dct) - ref = torch.rand((ndata, err.vdim)).squeeze(1) - pred = torch.rand((ndata, err.vdim)).squeeze(1) - natoms = torch.tensor([natoms] * ndata) - _data = { - err_dct['ref_key']: ref, - err_dct['pred_key']: pred, - 'num_atoms': natoms, - } - - tmp = AtomGraphData(**_data) - err.update(tmp) - - _ref = ref * err.coeff - _pred = pred * err.coeff - if 'per_atom' in err_dct and err_dct['per_atom']: - # natoms = natoms.unsqueeze(-1) - _ref = _ref / natoms - _pred = _pred / natoms - val = torch.sqrt(((_ref - _pred) ** 2).sum() / ndata) # not ndata*natoms - assert np.allclose(err.get(), val.item()) - err.update(tmp) - assert np.allclose(err.get(), val.item()) - - -@pytest.mark.parametrize('err_type,ndata,natoms', _erc_test_params) -def test_mae_error(err_type, ndata, natoms): - err_dct = erc.get_err_type(err_type) - vdim = err_dct['vdim'] - err = erc.MAError(**err_dct) - ref = torch.rand((ndata, vdim)).squeeze(1) - pred = torch.rand((ndata, vdim)).squeeze(1) - natoms = torch.tensor([natoms] * ndata) - _data = { - err_dct['ref_key']: ref, - err_dct['pred_key']: pred, - 'num_atoms': natoms, - } - - tmp = AtomGraphData(**_data) - err.update(tmp) - - _ref = ref * err.coeff - _pred = pred * err.coeff - if 'per_atom' in err_dct and err_dct['per_atom']: - _ref /= natoms - _pred /= natoms - - val = abs(_ref - _pred).sum() / (ndata * vdim) - assert np.allclose(err.get(), val.item()) - err.update(tmp) - assert np.allclose(err.get(), val.item()) - - -# TODO: test_component_rms_error - - -@pytest.mark.parametrize('err_type,ndata,natoms', _erc_test_params) -def test_custom_error(err_type, ndata, natoms): - def func(a, b): - return a * b - - err_dct = erc.get_err_type(err_type) - vdim = err_dct['vdim'] - err = erc.CustomError(func, **err_dct) - ref = torch.rand((ndata, vdim)).squeeze(1) - pred = torch.rand((ndata, vdim)).squeeze(1) - natoms = torch.tensor([natoms] * ndata) - _data = { - err_dct['ref_key']: ref, - err_dct['pred_key']: pred, - 'num_atoms': natoms, - } - - _ref = ref * err.coeff - _pred = pred * err.coeff - if 'per_atom' in err_dct and err_dct['per_atom']: - _ref /= natoms - _pred /= natoms - - tmp = AtomGraphData(**_data) - err.update(tmp) - val = func(_ref, _pred).mean() - assert np.allclose(err.get(), val.item()) - err.update(tmp) - assert np.allclose(err.get(), val.item()) - - -@pytest.mark.parametrize('conf', [config(), config(is_train_stress=False)]) -def test_total_loss_metric_from_config(conf): - def func(a, b): - return a * b - - err = erc.ErrorRecorder.init_total_loss_metric(conf, func) - ndata = 3 - natoms = 4 - - e1, e2 = torch.rand(ndata), torch.rand(ndata) - f1, f2 = torch.rand(ndata * natoms, 3), torch.rand(ndata * natoms, 3) - s1, s2 = torch.rand((ndata, 6)), torch.rand((ndata, 6)) - _data = { - 'total_energy': e1, - 'inferred_total_energy': e2, - 'force_of_atoms': f1, - 'inferred_force': f2, - 'stress': s1, - 'inferred_stress': s2, - 'num_atoms': torch.tensor([natoms] * ndata), - } - - tmp = AtomGraphData(**_data) - err.update(tmp) - - val = (func(e1 / natoms, e2 / natoms)).mean() + conf['force_loss_weight'] * func( - f1, f2 - ).mean() - if conf['is_train_stress']: - KB = 1602.1766208 - val += conf['stress_loss_weight'] * func(s1 * KB, s2 * KB).mean() - - assert np.allclose(err.get(), val.item()) - err.update(tmp) - assert np.allclose(err.get(), val.item()) - - -@pytest.mark.parametrize( - 'conf', [config(), config(is_train_stress=False), config(loss='huber')] -) -def test_error_recorder_from_config(conf): - recorder = erc.ErrorRecorder.from_config(conf) - - total_loss_flag = False - for metric in recorder.metrics: - if conf['is_train_stress'] is False: - assert 'stress' not in metric.name - if metric.name == 'TotalLoss': - total_loss_flag = True - for loss_metric, _ in metric.metrics: # type: ignore - assert isinstance(loss_metric.func, loss_dict[conf['loss']]) - assert total_loss_flag - - -@pytest.mark.parametrize( - 'conf', [config(), config(is_train_stress=False), config(loss='huber')] -) -def test_error_recorder_from_config_and_loss_functions(conf): - loss_functions = loss.get_loss_functions_from_config(conf) - recorder = erc.ErrorRecorder.from_config(conf, loss_functions) - - total_loss_flag = False - for metric in recorder.metrics: - if conf['is_train_stress'] is False: - assert 'stress' not in metric.name - if metric.name == 'TotalLoss': - total_loss_flag = True - for loss_metric, _ in metric.metrics: # type: ignore - assert isinstance( - loss_metric.loss_def.criterion, loss_dict[conf['loss']] - ) - assert total_loss_flag +# test_errors: error recorder.py, loss.py +from copy import deepcopy + +import numpy as np +import pytest +import torch +import torch.nn +from torch import tensor + +import sevenn.error_recorder as erc +import sevenn.train.loss as loss +from sevenn.atom_graph_data import AtomGraphData +from sevenn.train.optim import loss_dict + +_default_config = { + 'loss': 'mse', + 'loss_param': {}, + 'error_record': [ + ('Energy', 'RMSE'), + ('Force', 'RMSE'), + ('Stress', 'RMSE'), + ('Energy', 'MAE'), + ('Force', 'MAE'), + ('Stress', 'MAE'), + ('TotalLoss', 'None'), + ], + 'is_train_stress': True, + 'force_loss_weight': 1.0, + 'stress_loss_weight': 0.001, +} + +_erc_test_params = [ + ('TotalEnergy', 4, 3), + ('Energy', 4, 3), + ('Force', 4, 3), + ('Stress', 4, 3), + ('Stress_GPa', 4, 3), + ('Energy', 4, 1), + ('Energy', 1, 1), + ('Force', 1, 3), + ('Stress', 1, 3), +] + + +def acl(a, b): + return torch.allclose(a, b, atol=1e-6) + + +def config(**overwrite): # to make it read-only + cf = deepcopy(_default_config) + for k, v in overwrite.items(): + cf[k] = v + return cf + + +def test_per_atom_energy_loss(): + loss_f = loss.PerAtomEnergyLoss(criterion=torch.nn.MSELoss()) + ref = torch.rand(2) + pred = torch.rand(2) + natoms = torch.randint(1, 10, (2,)) + tmp = AtomGraphData( + total_energy=ref, + inferred_total_energy=pred, + num_atoms=natoms, + ).to_dict() + ret = loss_f.get_loss(tmp) + assert loss_f.criterion is not None + assert torch.allclose(loss_f.criterion((ref / natoms), (pred / natoms)), ret) + + +def test_force_loss(): + loss_f = loss.ForceLoss(criterion=torch.nn.MSELoss()) + ref = torch.rand((4, 3)) + pred = torch.rand((4, 3)) + batch = tensor([0, 0, 0, 1]) + tmp = AtomGraphData( + force_of_atoms=ref, + inferred_force=pred, + batch=batch, + ).to_dict() + ret = loss_f.get_loss(tmp) + assert loss_f.criterion is not None + assert torch.allclose(loss_f.criterion(ref.reshape(-1), pred.reshape(-1)), ret) + + +def test_stress_loss(): + loss_f = loss.StressLoss(criterion=torch.nn.MSELoss()) + ref = torch.rand((2, 6)) + pred = torch.rand((2, 6)) + tmp = AtomGraphData( + stress=ref, + inferred_stress=pred, + ).to_dict() + ret = loss_f.get_loss(tmp) + KB = 1602.1766208 + assert loss_f.criterion is not None + assert torch.allclose( + loss_f.criterion(ref.reshape(-1) * KB, pred.reshape(-1) * KB), ret + ) + + +@pytest.mark.parametrize('conf', [config(), config(is_train_stress=False)]) +def test_loss_from_config(conf): + loss_functions = loss.get_loss_functions_from_config(conf) + + if conf['is_train_stress']: + assert len(loss_functions) == 3 + else: + assert len(loss_functions) == 2 + + for loss_def, w in loss_functions: + assert isinstance(loss_def, loss.LossDefinition) + if isinstance(loss_def, loss.PerAtomEnergyLoss): + assert w == 1.0 + elif isinstance(loss_def, loss.ForceLoss): + assert w == conf['force_loss_weight'] + elif isinstance(loss_def, loss.StressLoss): + assert w == conf['stress_loss_weight'] + else: + raise ValueError(f'Unexpected loss function: {loss_def}') + + +@pytest.mark.parametrize('err_type,ndata,natoms', _erc_test_params) +def test_rms_error(err_type, ndata, natoms): + err_dct = erc.get_err_type(err_type) + err = erc.RMSError(**err_dct) + ref = torch.rand((ndata, err.vdim)).squeeze(1) + pred = torch.rand((ndata, err.vdim)).squeeze(1) + natoms = torch.tensor([natoms] * ndata) + _data = { + err_dct['ref_key']: ref, + err_dct['pred_key']: pred, + 'num_atoms': natoms, + } + + tmp = AtomGraphData(**_data) + err.update(tmp) + + _ref = ref * err.coeff + _pred = pred * err.coeff + if 'per_atom' in err_dct and err_dct['per_atom']: + # natoms = natoms.unsqueeze(-1) + _ref = _ref / natoms + _pred = _pred / natoms + val = torch.sqrt(((_ref - _pred) ** 2).sum() / ndata) # not ndata*natoms + assert np.allclose(err.get(), val.item()) + err.update(tmp) + assert np.allclose(err.get(), val.item()) + + +@pytest.mark.parametrize('err_type,ndata,natoms', _erc_test_params) +def test_mae_error(err_type, ndata, natoms): + err_dct = erc.get_err_type(err_type) + vdim = err_dct['vdim'] + err = erc.MAError(**err_dct) + ref = torch.rand((ndata, vdim)).squeeze(1) + pred = torch.rand((ndata, vdim)).squeeze(1) + natoms = torch.tensor([natoms] * ndata) + _data = { + err_dct['ref_key']: ref, + err_dct['pred_key']: pred, + 'num_atoms': natoms, + } + + tmp = AtomGraphData(**_data) + err.update(tmp) + + _ref = ref * err.coeff + _pred = pred * err.coeff + if 'per_atom' in err_dct and err_dct['per_atom']: + _ref /= natoms + _pred /= natoms + + val = abs(_ref - _pred).sum() / (ndata * vdim) + assert np.allclose(err.get(), val.item()) + err.update(tmp) + assert np.allclose(err.get(), val.item()) + + +# TODO: test_component_rms_error + + +@pytest.mark.parametrize('err_type,ndata,natoms', _erc_test_params) +def test_custom_error(err_type, ndata, natoms): + def func(a, b): + return a * b + + err_dct = erc.get_err_type(err_type) + vdim = err_dct['vdim'] + err = erc.CustomError(func, **err_dct) + ref = torch.rand((ndata, vdim)).squeeze(1) + pred = torch.rand((ndata, vdim)).squeeze(1) + natoms = torch.tensor([natoms] * ndata) + _data = { + err_dct['ref_key']: ref, + err_dct['pred_key']: pred, + 'num_atoms': natoms, + } + + _ref = ref * err.coeff + _pred = pred * err.coeff + if 'per_atom' in err_dct and err_dct['per_atom']: + _ref /= natoms + _pred /= natoms + + tmp = AtomGraphData(**_data) + err.update(tmp) + val = func(_ref, _pred).mean() + assert np.allclose(err.get(), val.item()) + err.update(tmp) + assert np.allclose(err.get(), val.item()) + + +@pytest.mark.parametrize('conf', [config(), config(is_train_stress=False)]) +def test_total_loss_metric_from_config(conf): + def func(a, b): + return a * b + + err = erc.ErrorRecorder.init_total_loss_metric(conf, func) + ndata = 3 + natoms = 4 + + e1, e2 = torch.rand(ndata), torch.rand(ndata) + f1, f2 = torch.rand(ndata * natoms, 3), torch.rand(ndata * natoms, 3) + s1, s2 = torch.rand((ndata, 6)), torch.rand((ndata, 6)) + _data = { + 'total_energy': e1, + 'inferred_total_energy': e2, + 'force_of_atoms': f1, + 'inferred_force': f2, + 'stress': s1, + 'inferred_stress': s2, + 'num_atoms': torch.tensor([natoms] * ndata), + } + + tmp = AtomGraphData(**_data) + err.update(tmp) + + val = (func(e1 / natoms, e2 / natoms)).mean() + conf['force_loss_weight'] * func( + f1, f2 + ).mean() + if conf['is_train_stress']: + KB = 1602.1766208 + val += conf['stress_loss_weight'] * func(s1 * KB, s2 * KB).mean() + + assert np.allclose(err.get(), val.item()) + err.update(tmp) + assert np.allclose(err.get(), val.item()) + + +@pytest.mark.parametrize( + 'conf', [config(), config(is_train_stress=False), config(loss='huber')] +) +def test_error_recorder_from_config(conf): + recorder = erc.ErrorRecorder.from_config(conf) + + total_loss_flag = False + for metric in recorder.metrics: + if conf['is_train_stress'] is False: + assert 'stress' not in metric.name + if metric.name == 'TotalLoss': + total_loss_flag = True + for loss_metric, _ in metric.metrics: # type: ignore + assert isinstance(loss_metric.func, loss_dict[conf['loss']]) + assert total_loss_flag + + +@pytest.mark.parametrize( + 'conf', [config(), config(is_train_stress=False), config(loss='huber')] +) +def test_error_recorder_from_config_and_loss_functions(conf): + loss_functions = loss.get_loss_functions_from_config(conf) + recorder = erc.ErrorRecorder.from_config(conf, loss_functions) + + total_loss_flag = False + for metric in recorder.metrics: + if conf['is_train_stress'] is False: + assert 'stress' not in metric.name + if metric.name == 'TotalLoss': + total_loss_flag = True + for loss_metric, _ in metric.metrics: # type: ignore + assert isinstance( + loss_metric.loss_def.criterion, loss_dict[conf['loss']] + ) + assert total_loss_flag diff --git a/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_modal.py b/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_modal.py index 3b384bd..000476f 100644 --- a/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_modal.py +++ b/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_modal.py @@ -1,136 +1,136 @@ -# # deploy is test on lammps -# test append modality -# from no modality model to modality yes model -# from modality model to more modality model -# different shift scale settings -# test modality options (check num param) -# calculators with modality - -import copy -# + modal checkpoint continue and test_train -# + sevenn_cp test things in test_cli -import pathlib - -import pytest -from ase.build import bulk - -import sevenn.train.graph_dataset as graph_ds -import sevenn.util as util -from sevenn.calculator import SevenNetCalculator -from sevenn.model_build import build_E3_equivariant_model - -cutoff = 5.0 -data_root = (pathlib.Path(__file__).parent.parent / 'data').resolve() -hfo2_path = str(data_root / 'systems' / 'hfo2.extxyz') -sevennet_0_path = util.pretrained_name_to_path('7net-0_11July2024') - - -@pytest.fixture(scope='module') -def graph_dataset_path(tmp_path_factory): - gd_path = tmp_path_factory.mktemp('gd') - ds = graph_ds.SevenNetGraphDataset( - cutoff=cutoff, root=str(gd_path), files=[hfo2_path], processed_name='tmp.pt' - ) - return ds.processed_paths[0] - - -_modal_cfg = { - 'use_modal_node_embedding': False, - 'use_modal_self_inter_intro': True, - 'use_modal_self_inter_outro': True, - 'use_modal_output_block': True, - 'use_modality': True, - 'use_modal_wise_shift': True, # T/F should be tested - 'use_modal_wise_scale': False, # T/F should be tested - 'load_trainset_path': [ - { - 'data_modality': 'modal_new', - 'file_list': [{'file': hfo2_path}], - } - ], -} - - -@pytest.fixture(scope='module') -def snet_0_cp(): - return util.load_checkpoint(sevennet_0_path) - - -@pytest.fixture(scope='module') -def snet_0_calc(): - return SevenNetCalculator() - - -@pytest.fixture() -def bulk_atoms(): - atoms = bulk('Si') * 3 - atoms.rattle() - return atoms - - -def assert_atoms(atoms1, atoms2, rtol=1e-5, atol=1e-6): - import numpy as np - - def acl(a, b, rtol=rtol, atol=atol): - return np.allclose(a, b, rtol=rtol, atol=atol) - - assert len(atoms1) == len(atoms2) - assert acl(atoms1.get_cell(), atoms2.get_cell()) - assert acl(atoms1.get_potential_energy(), atoms2.get_potential_energy()) - assert acl(atoms1.get_forces(), atoms2.get_forces(), rtol * 10, atol * 10) - assert acl( - atoms1.get_stress(voigt=False), - atoms2.get_stress(voigt=False), - rtol * 10, - atol * 10, - ) - # assert acl(atoms1.get_potential_energies(), atoms2.get_potential_energies()) - - -def get_modal_cfg(overwrite=None): - modal_cfg = copy.deepcopy(_modal_cfg).copy() - if overwrite: - modal_cfg.update(overwrite) - return modal_cfg - - -@pytest.mark.parametrize( - 'cfg_overwrite', - [ - ({}), - ({'use_modal_wise_scale': True}), - ({'use_modal_wise_shift': False}), - ({'use_modal_self_inter_intro': False}), - ], -) -def test_append_modal_sevennet_0( - cfg_overwrite, - snet_0_cp, - snet_0_calc, - bulk_atoms, - graph_dataset_path, - tmp_path, -): - modal_cfg = snet_0_cp.config - modal_cfg.pop('load_dataset_path') - modal_cfg.pop('load_validset_path') - modal_cfg.update(get_modal_cfg(cfg_overwrite)) - modal_cfg['shift'] = 'elemwise_reference_energies' - modal_cfg['scale'] = 'per_atom_energy_std' - modal_cfg['load_trainset_path'][0]['file_list'] = [{'file': graph_dataset_path}] - - new_state_dict = snet_0_cp.append_modal( - modal_cfg, original_modal_name='pbe', working_dir=tmp_path - ) - sevennet_0_w_modal = build_E3_equivariant_model(modal_cfg) - sevennet_0_w_modal.load_state_dict(new_state_dict, strict=True) - - atoms1 = bulk_atoms - atoms2 = copy.deepcopy(atoms1) - - atoms1.calc = snet_0_calc - atoms2.calc = SevenNetCalculator( - model=sevennet_0_w_modal, file_type='model_instance', modal='pbe' - ) - - assert_atoms(atoms1, atoms2) +# # deploy is test on lammps +# test append modality +# from no modality model to modality yes model +# from modality model to more modality model +# different shift scale settings +# test modality options (check num param) +# calculators with modality + +import copy +# + modal checkpoint continue and test_train +# + sevenn_cp test things in test_cli +import pathlib + +import pytest +from ase.build import bulk + +import sevenn.train.graph_dataset as graph_ds +import sevenn.util as util +from sevenn.calculator import SevenNetCalculator +from sevenn.model_build import build_E3_equivariant_model + +cutoff = 5.0 +data_root = (pathlib.Path(__file__).parent.parent / 'data').resolve() +hfo2_path = str(data_root / 'systems' / 'hfo2.extxyz') +sevennet_0_path = util.pretrained_name_to_path('7net-0_11July2024') + + +@pytest.fixture(scope='module') +def graph_dataset_path(tmp_path_factory): + gd_path = tmp_path_factory.mktemp('gd') + ds = graph_ds.SevenNetGraphDataset( + cutoff=cutoff, root=str(gd_path), files=[hfo2_path], processed_name='tmp.pt' + ) + return ds.processed_paths[0] + + +_modal_cfg = { + 'use_modal_node_embedding': False, + 'use_modal_self_inter_intro': True, + 'use_modal_self_inter_outro': True, + 'use_modal_output_block': True, + 'use_modality': True, + 'use_modal_wise_shift': True, # T/F should be tested + 'use_modal_wise_scale': False, # T/F should be tested + 'load_trainset_path': [ + { + 'data_modality': 'modal_new', + 'file_list': [{'file': hfo2_path}], + } + ], +} + + +@pytest.fixture(scope='module') +def snet_0_cp(): + return util.load_checkpoint(sevennet_0_path) + + +@pytest.fixture(scope='module') +def snet_0_calc(): + return SevenNetCalculator() + + +@pytest.fixture() +def bulk_atoms(): + atoms = bulk('Si') * 3 + atoms.rattle() + return atoms + + +def assert_atoms(atoms1, atoms2, rtol=1e-5, atol=1e-6): + import numpy as np + + def acl(a, b, rtol=rtol, atol=atol): + return np.allclose(a, b, rtol=rtol, atol=atol) + + assert len(atoms1) == len(atoms2) + assert acl(atoms1.get_cell(), atoms2.get_cell()) + assert acl(atoms1.get_potential_energy(), atoms2.get_potential_energy()) + assert acl(atoms1.get_forces(), atoms2.get_forces(), rtol * 10, atol * 10) + assert acl( + atoms1.get_stress(voigt=False), + atoms2.get_stress(voigt=False), + rtol * 10, + atol * 10, + ) + # assert acl(atoms1.get_potential_energies(), atoms2.get_potential_energies()) + + +def get_modal_cfg(overwrite=None): + modal_cfg = copy.deepcopy(_modal_cfg).copy() + if overwrite: + modal_cfg.update(overwrite) + return modal_cfg + + +@pytest.mark.parametrize( + 'cfg_overwrite', + [ + ({}), + ({'use_modal_wise_scale': True}), + ({'use_modal_wise_shift': False}), + ({'use_modal_self_inter_intro': False}), + ], +) +def test_append_modal_sevennet_0( + cfg_overwrite, + snet_0_cp, + snet_0_calc, + bulk_atoms, + graph_dataset_path, + tmp_path, +): + modal_cfg = snet_0_cp.config + modal_cfg.pop('load_dataset_path') + modal_cfg.pop('load_validset_path') + modal_cfg.update(get_modal_cfg(cfg_overwrite)) + modal_cfg['shift'] = 'elemwise_reference_energies' + modal_cfg['scale'] = 'per_atom_energy_std' + modal_cfg['load_trainset_path'][0]['file_list'] = [{'file': graph_dataset_path}] + + new_state_dict = snet_0_cp.append_modal( + modal_cfg, original_modal_name='pbe', working_dir=tmp_path + ) + sevennet_0_w_modal = build_E3_equivariant_model(modal_cfg) + sevennet_0_w_modal.load_state_dict(new_state_dict, strict=True) + + atoms1 = bulk_atoms + atoms2 = copy.deepcopy(atoms1) + + atoms1.calc = snet_0_calc + atoms2.calc = SevenNetCalculator( + model=sevennet_0_w_modal, file_type='model_instance', modal='pbe' + ) + + assert_atoms(atoms1, atoms2) diff --git a/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_model.py b/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_model.py index 843a355..d75976f 100644 --- a/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_model.py +++ b/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_model.py @@ -1,213 +1,213 @@ -import pytest -import torch -from ase.build import bulk, molecule -from ase.data import chemical_symbols -from torch_geometric.loader.dataloader import Collater - -import sevenn.train.dataload as dl -from sevenn.atom_graph_data import AtomGraphData -from sevenn.model_build import build_E3_equivariant_model -from sevenn.nn.sequential import AtomGraphSequential -from sevenn.util import chemical_species_preprocess - -cutoff = 4.0 - - -_samples = { - 'bulk': bulk('NaCl', 'rocksalt', a=5.63), - 'mol': molecule('H2O'), - 'isolated': molecule('H'), -} -n_samples = len(_samples) -n_atoms_total = sum([len(at) for at in _samples.values()]) - -_graph_list = [ - AtomGraphData.from_numpy_dict(dl.unlabeled_atoms_to_graph(at, cutoff)) - for at in list(_samples.values()) -] - - -def test_chemical_species_preprocess(): - chems = ['He', 'H', 'Be', 'H'] - cf = chemical_species_preprocess(chems, universal=False) - assert cf['chemical_species'] == ['Be', 'H', 'He'] - assert cf['_number_of_species'] == 3 - assert cf['_type_map'] == {4: 0, 1: 1, 2: 2} - - cf = chemical_species_preprocess(chems, universal=True) - assert cf['chemical_species'] == chemical_symbols - assert cf['_number_of_species'] == len(chemical_symbols) - assert len(cf['_type_map']) == len(chemical_symbols) - for z, node_idx in cf['_type_map'].items(): - assert z == node_idx - - -def get_graphs(batched): - cloned = [g.clone() for g in _graph_list] - if not batched: - return cloned - else: - return Collater(cloned)(cloned) - - -def get_model_config(): - config = { - 'cutoff': cutoff, - 'channel': 4, - 'radial_basis': { - 'radial_basis_name': 'bessel', - }, - 'cutoff_function': {'cutoff_function_name': 'poly_cut'}, - 'interaction_type': 'nequip', - 'lmax': 2, - 'is_parity': True, - 'num_convolution_layer': 3, - 'weight_nn_hidden_neurons': [64, 64], - 'act_radial': 'silu', - 'act_scalar': {'e': 'silu', 'o': 'tanh'}, - 'act_gate': {'e': 'silu', 'o': 'tanh'}, - 'conv_denominator': 30.0, - 'train_denominator': False, - 'self_connection_type': 'nequip', - 'shift': -10.0, - 'scale': 10.0, - 'train_shift_scale': False, - 'irreps_manual': False, - 'lmax_edge': -1, - 'lmax_node': -1, - 'readout_as_fcn': False, - 'use_bias_in_linear': False, - '_normalize_sph': True, - } - chems = set() - for at in list(_samples.values()): - chems.update(at.get_chemical_symbols()) - config.update(**chemical_species_preprocess(list(chems))) - return config - - -def get_model(config_overwrite={}): - cf = get_model_config() - cf.update(**config_overwrite) - model = build_E3_equivariant_model(cf, parallel=False) - assert isinstance(model, AtomGraphSequential) - return model - - -@pytest.mark.parametrize('batched', [False, True]) -@pytest.mark.parametrize('cf', [{}]) -def test_shape(cf, batched): - model = get_model(cf) - model.set_is_batch_data(batched) - - graph = get_graphs(batched) - if not batched: - output_shapes = { - 'inferred_total_energy': (), - 'inferred_stress': (6,), - } - for g in graph: - natoms = g['num_atoms'] - output_shapes.update( - { - 'atomic_energy': (natoms, 1), # intended - 'inferred_force': (natoms, 3), - } - ) - output = model(g) - for k, shape in output_shapes.items(): - assert output[k].shape == shape, f'{k}: {output[k].shape} != {shape}' - else: - output_shapes = { - 'inferred_total_energy': (n_samples,), - 'atomic_energy': (n_atoms_total, 1), # intended - 'inferred_force': (n_atoms_total, 3), - 'inferred_stress': (n_samples, 6), - } - output = model(graph) - for k, shape in output_shapes.items(): - assert output[k].shape == shape, f'{k}: {output[k].shape} != {shape}' - - -def test_batch(): - model = get_model() - model.set_is_batch_data(False) - - graph_list = get_graphs(batched=False) - output_list = [model(g) for g in graph_list] - - model.set_is_batch_data(True) - graph_batch = get_graphs(batched=True) - output_batched = model(graph_batch) - - e_concat = torch.concat( - [g['inferred_total_energy'].unsqueeze(-1) for g in output_list] - ) - ae_concat = torch.concat([g['atomic_energy'].squeeze(1) for g in output_list]) - f_concat = torch.concat([g['inferred_force'] for g in output_list]) - s_concat = torch.stack([g['inferred_stress'] for g in output_list]) - - assert torch.allclose(e_concat, output_batched['inferred_total_energy']) - assert torch.allclose(ae_concat, output_batched['atomic_energy'].squeeze(1)) - assert torch.allclose( - torch.round(f_concat, decimals=5), - torch.round(output_batched['inferred_force'], decimals=5), - atol=1e-5, - ) - - assert torch.allclose( # TODO, hard-coded, assumes the first structure is bulk - torch.round(s_concat[0], decimals=5), - torch.round(output_batched['inferred_stress'][0], decimals=5), - ) - - -_n_param_tests = [ - ({}, 20642), - ({'train_denominator': True}, 20642 + 3), - ({'train_shift_scale': True}, 20642 + 2), - ({'shift': [1.0] * 4}, 20642), - ({'scale': [1.0] * 4, 'train_shift_scale': True}, 20642 + 8), - ({'num_convolution_layer': 4}, 33458), - ({'lmax': 3}, 26866), - ({'channel': 2}, 16883), - ({'is_parity': False}, 20386), - ({'self_connection_type': 'linear'}, 20114), -] - - -@pytest.mark.parametrize('cf,ref', _n_param_tests) -def test_num_params(cf, ref): - model = get_model(cf) - param = sum([p.numel() for p in model.parameters() if p.requires_grad]) - assert param == ref, f'ref: {ref} != given: {param}' - - -_n_modal_param_tests = [ - ({}, 20642), - ({'use_modal_node_embedding': True}, 20642 + 8), - ({'use_modal_self_inter_intro': True}, 20642 + 2 * 4 * 3), - ({'use_modal_self_inter_outro': True}, 20642 + 2 * (12 + 20 + 4)), - ({'use_modal_output_block': True}, 20642 + 2 * 4 / 2), -] - - -@pytest.mark.parametrize('cf,ref', _n_modal_param_tests) -def test_modal_num_params(cf, ref): - modal_cfg = { - 'use_modality': True, - '_number_of_modalities': 2, - '_modal_map': {'x1': 0, 'x2': 1}, - 'use_modal_node_embedding': False, - 'use_modal_self_inter_intro': False, - 'use_modal_self_inter_outro': False, - 'use_modal_output_block': False, - 'use_modal_wise_shift': False, - 'use_modal_wise_scale': False, - } - modal_cfg.update(cf) - model = get_model(modal_cfg) - param = sum([p.numel() for p in model.parameters() if p.requires_grad]) - assert param == ref, f'ref: {ref} != given: {param}' - - -# TODO: test_irreps, test_gard, test_equivariance +import pytest +import torch +from ase.build import bulk, molecule +from ase.data import chemical_symbols +from torch_geometric.loader.dataloader import Collater + +import sevenn.train.dataload as dl +from sevenn.atom_graph_data import AtomGraphData +from sevenn.model_build import build_E3_equivariant_model +from sevenn.nn.sequential import AtomGraphSequential +from sevenn.util import chemical_species_preprocess + +cutoff = 4.0 + + +_samples = { + 'bulk': bulk('NaCl', 'rocksalt', a=5.63), + 'mol': molecule('H2O'), + 'isolated': molecule('H'), +} +n_samples = len(_samples) +n_atoms_total = sum([len(at) for at in _samples.values()]) + +_graph_list = [ + AtomGraphData.from_numpy_dict(dl.unlabeled_atoms_to_graph(at, cutoff)) + for at in list(_samples.values()) +] + + +def test_chemical_species_preprocess(): + chems = ['He', 'H', 'Be', 'H'] + cf = chemical_species_preprocess(chems, universal=False) + assert cf['chemical_species'] == ['Be', 'H', 'He'] + assert cf['_number_of_species'] == 3 + assert cf['_type_map'] == {4: 0, 1: 1, 2: 2} + + cf = chemical_species_preprocess(chems, universal=True) + assert cf['chemical_species'] == chemical_symbols + assert cf['_number_of_species'] == len(chemical_symbols) + assert len(cf['_type_map']) == len(chemical_symbols) + for z, node_idx in cf['_type_map'].items(): + assert z == node_idx + + +def get_graphs(batched): + cloned = [g.clone() for g in _graph_list] + if not batched: + return cloned + else: + return Collater(cloned)(cloned) + + +def get_model_config(): + config = { + 'cutoff': cutoff, + 'channel': 4, + 'radial_basis': { + 'radial_basis_name': 'bessel', + }, + 'cutoff_function': {'cutoff_function_name': 'poly_cut'}, + 'interaction_type': 'nequip', + 'lmax': 2, + 'is_parity': True, + 'num_convolution_layer': 3, + 'weight_nn_hidden_neurons': [64, 64], + 'act_radial': 'silu', + 'act_scalar': {'e': 'silu', 'o': 'tanh'}, + 'act_gate': {'e': 'silu', 'o': 'tanh'}, + 'conv_denominator': 30.0, + 'train_denominator': False, + 'self_connection_type': 'nequip', + 'shift': -10.0, + 'scale': 10.0, + 'train_shift_scale': False, + 'irreps_manual': False, + 'lmax_edge': -1, + 'lmax_node': -1, + 'readout_as_fcn': False, + 'use_bias_in_linear': False, + '_normalize_sph': True, + } + chems = set() + for at in list(_samples.values()): + chems.update(at.get_chemical_symbols()) + config.update(**chemical_species_preprocess(list(chems))) + return config + + +def get_model(config_overwrite={}): + cf = get_model_config() + cf.update(**config_overwrite) + model = build_E3_equivariant_model(cf, parallel=False) + assert isinstance(model, AtomGraphSequential) + return model + + +@pytest.mark.parametrize('batched', [False, True]) +@pytest.mark.parametrize('cf', [{}]) +def test_shape(cf, batched): + model = get_model(cf) + model.set_is_batch_data(batched) + + graph = get_graphs(batched) + if not batched: + output_shapes = { + 'inferred_total_energy': (), + 'inferred_stress': (6,), + } + for g in graph: + natoms = g['num_atoms'] + output_shapes.update( + { + 'atomic_energy': (natoms, 1), # intended + 'inferred_force': (natoms, 3), + } + ) + output = model(g) + for k, shape in output_shapes.items(): + assert output[k].shape == shape, f'{k}: {output[k].shape} != {shape}' + else: + output_shapes = { + 'inferred_total_energy': (n_samples,), + 'atomic_energy': (n_atoms_total, 1), # intended + 'inferred_force': (n_atoms_total, 3), + 'inferred_stress': (n_samples, 6), + } + output = model(graph) + for k, shape in output_shapes.items(): + assert output[k].shape == shape, f'{k}: {output[k].shape} != {shape}' + + +def test_batch(): + model = get_model() + model.set_is_batch_data(False) + + graph_list = get_graphs(batched=False) + output_list = [model(g) for g in graph_list] + + model.set_is_batch_data(True) + graph_batch = get_graphs(batched=True) + output_batched = model(graph_batch) + + e_concat = torch.concat( + [g['inferred_total_energy'].unsqueeze(-1) for g in output_list] + ) + ae_concat = torch.concat([g['atomic_energy'].squeeze(1) for g in output_list]) + f_concat = torch.concat([g['inferred_force'] for g in output_list]) + s_concat = torch.stack([g['inferred_stress'] for g in output_list]) + + assert torch.allclose(e_concat, output_batched['inferred_total_energy']) + assert torch.allclose(ae_concat, output_batched['atomic_energy'].squeeze(1)) + assert torch.allclose( + torch.round(f_concat, decimals=5), + torch.round(output_batched['inferred_force'], decimals=5), + atol=1e-5, + ) + + assert torch.allclose( # TODO, hard-coded, assumes the first structure is bulk + torch.round(s_concat[0], decimals=5), + torch.round(output_batched['inferred_stress'][0], decimals=5), + ) + + +_n_param_tests = [ + ({}, 20642), + ({'train_denominator': True}, 20642 + 3), + ({'train_shift_scale': True}, 20642 + 2), + ({'shift': [1.0] * 4}, 20642), + ({'scale': [1.0] * 4, 'train_shift_scale': True}, 20642 + 8), + ({'num_convolution_layer': 4}, 33458), + ({'lmax': 3}, 26866), + ({'channel': 2}, 16883), + ({'is_parity': False}, 20386), + ({'self_connection_type': 'linear'}, 20114), +] + + +@pytest.mark.parametrize('cf,ref', _n_param_tests) +def test_num_params(cf, ref): + model = get_model(cf) + param = sum([p.numel() for p in model.parameters() if p.requires_grad]) + assert param == ref, f'ref: {ref} != given: {param}' + + +_n_modal_param_tests = [ + ({}, 20642), + ({'use_modal_node_embedding': True}, 20642 + 8), + ({'use_modal_self_inter_intro': True}, 20642 + 2 * 4 * 3), + ({'use_modal_self_inter_outro': True}, 20642 + 2 * (12 + 20 + 4)), + ({'use_modal_output_block': True}, 20642 + 2 * 4 / 2), +] + + +@pytest.mark.parametrize('cf,ref', _n_modal_param_tests) +def test_modal_num_params(cf, ref): + modal_cfg = { + 'use_modality': True, + '_number_of_modalities': 2, + '_modal_map': {'x1': 0, 'x2': 1}, + 'use_modal_node_embedding': False, + 'use_modal_self_inter_intro': False, + 'use_modal_self_inter_outro': False, + 'use_modal_output_block': False, + 'use_modal_wise_shift': False, + 'use_modal_wise_scale': False, + } + modal_cfg.update(cf) + model = get_model(modal_cfg) + param = sum([p.numel() for p in model.parameters() if p.requires_grad]) + assert param == ref, f'ref: {ref} != given: {param}' + + +# TODO: test_irreps, test_gard, test_equivariance diff --git a/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_pretrained.py b/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_pretrained.py index d82d7ab..4676ca4 100644 --- a/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_pretrained.py +++ b/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_pretrained.py @@ -1,344 +1,344 @@ -# test_pretrained: output consistency for pretrained models - -import pytest -import torch -from ase.build import bulk, molecule - -import sevenn._keys as KEY -from sevenn.atom_graph_data import AtomGraphData -from sevenn.train.dataload import unlabeled_atoms_to_graph -from sevenn.util import model_from_checkpoint, pretrained_name_to_path - - -def acl(a, b, atol=1e-6): - return torch.allclose(a, b, atol=atol) - - -@pytest.fixture -def atoms_pbc(): - atoms1 = bulk('NaCl', 'rocksalt', a=5.63) - atoms1.set_cell([[1.0, 2.815, 2.815], [2.815, 0.0, 2.815], [2.815, 2.815, 0.0]]) - atoms1.set_positions([[0.0, 0.0, 0.0], [2.815, 0.0, 0.0]]) - return atoms1 - - -@pytest.fixture -def atoms_mol(): - atoms2 = molecule('H2O') - atoms2.set_positions([[0.0, 0.2, 0.12], [0.0, 0.76, -0.48], [0.0, -0.76, -0.48]]) - return atoms2 - - -def test_7net0_22May2024(atoms_pbc, atoms_mol): - """ - Reference from v0.9.3.post1 with SevenNetCalculator - """ - cp_path = pretrained_name_to_path('7net-0_22May2024') - model, config = model_from_checkpoint(cp_path) - cutoff = config['cutoff'] - - g1 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_pbc, cutoff)) - g2 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_mol, cutoff)) - - model.set_is_batch_data(False) - g1 = model(g1) - g2 = model(g2) - - g1_ref_e = torch.tensor([-3.4140868186950684]) - g1_ref_f = torch.tensor( - [ - [1.2628037e01, 7.5093508e-03, 1.3480943e-02], - [-1.2628037e01, -7.5093508e-03, -1.3480917e-02], - ] - ) - g1_ref_s = -1 * torch.tensor( - [-0.65014917, -0.01990843, -0.02000658, 0.03286226, 0.00589222, 0.03291973] - ) - - g2_ref_e = torch.tensor([-12.808363914489746]) - g2_ref_f = torch.tensor( - [ - [9.31322575e-10, -1.30241165e01, 6.93116236e00], - [-1.39698386e-09, 9.28001022e00, -9.51867390e00], - [5.23868948e-10, 3.74410582e00, 2.58751225e00], - ] - ) - - assert acl(g1.inferred_total_energy, g1_ref_e) - assert acl(g1.inferred_force, g1_ref_f) - assert acl(g1.inferred_stress, g1_ref_s) - - assert acl(g2.inferred_total_energy, g2_ref_e) - assert acl(g2.inferred_force, g2_ref_f) - - -def test_7net0_11July2024(atoms_pbc, atoms_mol): - """ - Reference from v0.9.3.post1 with SevenNetCalculator - """ - cp_path = pretrained_name_to_path('7net-0_11July2024') - model, config = model_from_checkpoint(cp_path) - cutoff = config['cutoff'] - - g1 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_pbc, cutoff)) - g2 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_mol, cutoff)) - - model.set_is_batch_data(False) - g1 = model(g1) - g2 = model(g2) - - model.set_is_batch_data(True) - - g1_ref_e = torch.tensor([-3.779199]) - g1_ref_f = torch.tensor( - [ - [12.666697, 0.04726403, 0.04775861], - [-12.666697, -0.04726403, -0.04775861], - ] - ) - g1_ref_s = -1 * torch.tensor( - # xx, yy, zz, xy, yz, zx - [-0.6439122, -0.03643947, -0.03643981, 0.04543639, 0.00599139, 0.04544507] - ) - - g2_ref_e = torch.tensor([-12.782808303833008]) - g2_ref_f = torch.tensor( - [ - [0.0, -1.3619621e01, 7.5937047e00], - [0.0, 9.3918495e00, -1.0172190e01], - [0.0, 4.2277718e00, 2.5784855e00], - ] - ) - - assert acl(g1.inferred_total_energy, g1_ref_e) - assert acl(g1.inferred_force, g1_ref_f) - assert acl(g1.inferred_stress, g1_ref_s) - - assert acl(g2.inferred_total_energy, g2_ref_e) - assert acl(g2.inferred_force, g2_ref_f) - - -def test_7net_l3i5(atoms_pbc, atoms_mol): - """ - Reference from v0.9.3.post1 with SevenNetCalculator - """ - cp_path = pretrained_name_to_path('7net-l3i5') - model, config = model_from_checkpoint(cp_path) - cutoff = config['cutoff'] - - g1 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_pbc, cutoff)) - g2 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_mol, cutoff)) - - model.set_is_batch_data(False) - g1 = model(g1) - g2 = model(g2) - - model.set_is_batch_data(True) - - g1_ref_e = torch.tensor([-3.611131191253662]) - g1_ref_f = torch.tensor( - [ - [13.430887, 0.08655541, 0.08754013], - [-13.430886, -0.08655544, -0.08754011], - ] - ) - g1_ref_s = -1 * torch.tensor( - # xx, yy, zz, xy, yz, zx - [-0.6818918, -0.04104544, -0.04107663, 0.04794561, 0.00565416, 0.04793138] - ) - - g2_ref_e = torch.tensor([-12.700481414794922]) - g2_ref_f = torch.tensor( - [ - [0.0, -1.4547814e01, 8.1347866], - [0.0, 1.0308369e01, -1.0880318e01], - [0.0, 4.2394452, 2.7455316], - ] - ) - - assert acl(g1.inferred_total_energy, g1_ref_e) - assert acl(g1.inferred_force, g1_ref_f, 1e-5) - assert acl(g1.inferred_stress, g1_ref_s, 1e-5) - - assert acl(g2.inferred_total_energy, g2_ref_e) - assert acl(g2.inferred_force, g2_ref_f) - - -def test_7net_mf_0(atoms_pbc, atoms_mol): - cp_path = pretrained_name_to_path('7net-mf-0') - model, config = model_from_checkpoint(cp_path) - cutoff = config['cutoff'] - - g1 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_pbc, cutoff)) - g2 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_mol, cutoff)) - - g1[KEY.DATA_MODALITY] = 'R2SCAN' - g2[KEY.DATA_MODALITY] = 'R2SCAN' - - model.set_is_batch_data(False) - g1 = model(g1) - g2 = model(g2) - - model.set_is_batch_data(True) - - g1_ref_e = torch.tensor([-11.607587814331055]) - g1_ref_f = torch.tensor( - [ - [8.512259, 0.07307914, 0.06676716], - [-8.512257, -0.07307915, -0.06676716], - ] - ) - g1_ref_s = -1 * torch.tensor( - # xx, yy, zz, xy, yz, zx - [-0.4516204, -0.02483013, -0.02485001, 0.03247492, 0.00259375, 0.03250402] - ) - - g2_ref_e = torch.tensor([-14.172412872314453]) - g2_ref_f = torch.tensor( - [ - [4.6566129e-10, -1.3429364e01, 6.9344816e00], - [2.3283064e-09, 8.9132404e00, -9.6807365e00], - [-2.7939677e-09, 4.5161238e00, 2.7462559e00], - ] - ) - - assert acl(g1.inferred_total_energy, g1_ref_e) - assert acl(g1.inferred_force, g1_ref_f) - assert acl(g1.inferred_stress, g1_ref_s) - - assert acl(g2.inferred_total_energy, g2_ref_e) - assert acl(g2.inferred_force, g2_ref_f) - - -def test_7net_mf_ompa_mpa(atoms_pbc, atoms_mol): - cp_path = pretrained_name_to_path('7net-mf-ompa') - model, config = model_from_checkpoint(cp_path) - cutoff = config['cutoff'] - - g1 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_pbc, cutoff)) - g2 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_mol, cutoff)) - - # mpa - g1[KEY.DATA_MODALITY] = 'mpa' - g2[KEY.DATA_MODALITY] = 'mpa' - - model.set_is_batch_data(False) - g1 = model(g1) - g2 = model(g2) - - model.set_is_batch_data(True) - - g1_ref_e = torch.tensor([-3.490943193435669]) - g1_ref_f = torch.tensor( - [ - [1.2680445e01, -2.7985498e-04, -2.7979910e-04], - [-1.2680446e01, 2.7984008e-04, 2.7981028e-04], - ] - ) - g1_ref_s = -1 * torch.tensor( - # xx, yy, zz, xy, yz, zx - [-0.6481662, -0.02462837, -0.02462837, 0.02693467, 0.00459635, 0.02693467] - ) - - g2_ref_e = torch.tensor([-12.597525596618652]) - g2_ref_f = torch.tensor( - [ - [0.0, -12.245223, 7.26795], - [0.0, 8.816763, -9.423925], - [0.0, 3.4284601, 2.1559749], - ] - ) - assert acl(g1.inferred_total_energy, g1_ref_e) - assert acl(g1.inferred_force, g1_ref_f) - assert acl(g1.inferred_stress, g1_ref_s) - - assert acl(g2.inferred_total_energy, g2_ref_e) - assert acl(g2.inferred_force, g2_ref_f) - - -def test_7net_mf_ompa_omat(atoms_pbc, atoms_mol): - cp_path = pretrained_name_to_path('7net-mf-ompa') - model, config = model_from_checkpoint(cp_path) - cutoff = config['cutoff'] - - g1 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_pbc, cutoff)) - g2 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_mol, cutoff)) - - # mpa - g1[KEY.DATA_MODALITY] = 'omat24' - g2[KEY.DATA_MODALITY] = 'omat24' - - model.set_is_batch_data(False) - g1 = model(g1) - g2 = model(g2) - - model.set_is_batch_data(True) - - g1_ref_e = torch.tensor([-3.5094668865203857]) - g1_ref_f = torch.tensor( - [ - [1.2562084e01, -1.4219694e-03, -1.4219843e-03], - [-1.2562084e01, 1.4219508e-03, 1.4219955e-03], - ] - ) - g1_ref_s = -1 * torch.tensor( - # xx, yy, zz, xy, yz, zx - [-0.6430905, -0.0254128, -0.02541281, 0.0268343, 0.00460021, 0.0268343] - ) - - g2_ref_e = torch.tensor([-12.6202974319458]) - g2_ref_f = torch.tensor( - [ - [0.0, -12.205926, 7.2050343], - [0.0, 8.790399, -9.368677], - [0.0, 3.4155273, 2.163643], - ] - ) - assert acl(g1.inferred_total_energy, g1_ref_e) - assert acl(g1.inferred_force, g1_ref_f) - assert acl(g1.inferred_stress, g1_ref_s) - - assert acl(g2.inferred_total_energy, g2_ref_e) - assert acl(g2.inferred_force, g2_ref_f) - - -def test_7net_omat(atoms_pbc, atoms_mol): - cp_path = pretrained_name_to_path('7net-omat') - model, config = model_from_checkpoint(cp_path) - cutoff = config['cutoff'] - - g1 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_pbc, cutoff)) - g2 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_mol, cutoff)) - - model.set_is_batch_data(False) - g1 = model(g1) - g2 = model(g2) - - model.set_is_batch_data(True) - - g1_ref_e = torch.tensor([-3.5033323764801025]) - g1_ref_f = torch.tensor( - [ - [12.533154, 0.02358698, 0.02358694], - [-12.533153, -0.02358699, -0.02358697], - ] - ) - g1_ref_s = -1 * torch.tensor( - # xx, yy, zz, xy, yz, zx - [-0.6420925, -0.02781446, -0.02781446, 0.02575445, 0.00381664, 0.02575445] - ) - - g2_ref_e = torch.tensor([-12.403768539428711]) - g2_ref_f = torch.tensor( - [ - [0, -12.848297, 7.11432], - [0.0, 9.265477, -9.564951], - [0.0, 3.58282, 2.4506311], - ] - ) - assert acl(g1.inferred_total_energy, g1_ref_e) - assert acl(g1.inferred_force, g1_ref_f) - assert acl(g1.inferred_stress, g1_ref_s) - - assert acl(g2.inferred_total_energy, g2_ref_e) - assert acl(g2.inferred_force, g2_ref_f) +# test_pretrained: output consistency for pretrained models + +import pytest +import torch +from ase.build import bulk, molecule + +import sevenn._keys as KEY +from sevenn.atom_graph_data import AtomGraphData +from sevenn.train.dataload import unlabeled_atoms_to_graph +from sevenn.util import model_from_checkpoint, pretrained_name_to_path + + +def acl(a, b, atol=1e-6): + return torch.allclose(a, b, atol=atol) + + +@pytest.fixture +def atoms_pbc(): + atoms1 = bulk('NaCl', 'rocksalt', a=5.63) + atoms1.set_cell([[1.0, 2.815, 2.815], [2.815, 0.0, 2.815], [2.815, 2.815, 0.0]]) + atoms1.set_positions([[0.0, 0.0, 0.0], [2.815, 0.0, 0.0]]) + return atoms1 + + +@pytest.fixture +def atoms_mol(): + atoms2 = molecule('H2O') + atoms2.set_positions([[0.0, 0.2, 0.12], [0.0, 0.76, -0.48], [0.0, -0.76, -0.48]]) + return atoms2 + + +def test_7net0_22May2024(atoms_pbc, atoms_mol): + """ + Reference from v0.9.3.post1 with SevenNetCalculator + """ + cp_path = pretrained_name_to_path('7net-0_22May2024') + model, config = model_from_checkpoint(cp_path) + cutoff = config['cutoff'] + + g1 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_pbc, cutoff)) + g2 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_mol, cutoff)) + + model.set_is_batch_data(False) + g1 = model(g1) + g2 = model(g2) + + g1_ref_e = torch.tensor([-3.4140868186950684]) + g1_ref_f = torch.tensor( + [ + [1.2628037e01, 7.5093508e-03, 1.3480943e-02], + [-1.2628037e01, -7.5093508e-03, -1.3480917e-02], + ] + ) + g1_ref_s = -1 * torch.tensor( + [-0.65014917, -0.01990843, -0.02000658, 0.03286226, 0.00589222, 0.03291973] + ) + + g2_ref_e = torch.tensor([-12.808363914489746]) + g2_ref_f = torch.tensor( + [ + [9.31322575e-10, -1.30241165e01, 6.93116236e00], + [-1.39698386e-09, 9.28001022e00, -9.51867390e00], + [5.23868948e-10, 3.74410582e00, 2.58751225e00], + ] + ) + + assert acl(g1.inferred_total_energy, g1_ref_e) + assert acl(g1.inferred_force, g1_ref_f) + assert acl(g1.inferred_stress, g1_ref_s) + + assert acl(g2.inferred_total_energy, g2_ref_e) + assert acl(g2.inferred_force, g2_ref_f) + + +def test_7net0_11July2024(atoms_pbc, atoms_mol): + """ + Reference from v0.9.3.post1 with SevenNetCalculator + """ + cp_path = pretrained_name_to_path('7net-0_11July2024') + model, config = model_from_checkpoint(cp_path) + cutoff = config['cutoff'] + + g1 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_pbc, cutoff)) + g2 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_mol, cutoff)) + + model.set_is_batch_data(False) + g1 = model(g1) + g2 = model(g2) + + model.set_is_batch_data(True) + + g1_ref_e = torch.tensor([-3.779199]) + g1_ref_f = torch.tensor( + [ + [12.666697, 0.04726403, 0.04775861], + [-12.666697, -0.04726403, -0.04775861], + ] + ) + g1_ref_s = -1 * torch.tensor( + # xx, yy, zz, xy, yz, zx + [-0.6439122, -0.03643947, -0.03643981, 0.04543639, 0.00599139, 0.04544507] + ) + + g2_ref_e = torch.tensor([-12.782808303833008]) + g2_ref_f = torch.tensor( + [ + [0.0, -1.3619621e01, 7.5937047e00], + [0.0, 9.3918495e00, -1.0172190e01], + [0.0, 4.2277718e00, 2.5784855e00], + ] + ) + + assert acl(g1.inferred_total_energy, g1_ref_e) + assert acl(g1.inferred_force, g1_ref_f) + assert acl(g1.inferred_stress, g1_ref_s) + + assert acl(g2.inferred_total_energy, g2_ref_e) + assert acl(g2.inferred_force, g2_ref_f) + + +def test_7net_l3i5(atoms_pbc, atoms_mol): + """ + Reference from v0.9.3.post1 with SevenNetCalculator + """ + cp_path = pretrained_name_to_path('7net-l3i5') + model, config = model_from_checkpoint(cp_path) + cutoff = config['cutoff'] + + g1 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_pbc, cutoff)) + g2 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_mol, cutoff)) + + model.set_is_batch_data(False) + g1 = model(g1) + g2 = model(g2) + + model.set_is_batch_data(True) + + g1_ref_e = torch.tensor([-3.611131191253662]) + g1_ref_f = torch.tensor( + [ + [13.430887, 0.08655541, 0.08754013], + [-13.430886, -0.08655544, -0.08754011], + ] + ) + g1_ref_s = -1 * torch.tensor( + # xx, yy, zz, xy, yz, zx + [-0.6818918, -0.04104544, -0.04107663, 0.04794561, 0.00565416, 0.04793138] + ) + + g2_ref_e = torch.tensor([-12.700481414794922]) + g2_ref_f = torch.tensor( + [ + [0.0, -1.4547814e01, 8.1347866], + [0.0, 1.0308369e01, -1.0880318e01], + [0.0, 4.2394452, 2.7455316], + ] + ) + + assert acl(g1.inferred_total_energy, g1_ref_e) + assert acl(g1.inferred_force, g1_ref_f, 1e-5) + assert acl(g1.inferred_stress, g1_ref_s, 1e-5) + + assert acl(g2.inferred_total_energy, g2_ref_e) + assert acl(g2.inferred_force, g2_ref_f) + + +def test_7net_mf_0(atoms_pbc, atoms_mol): + cp_path = pretrained_name_to_path('7net-mf-0') + model, config = model_from_checkpoint(cp_path) + cutoff = config['cutoff'] + + g1 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_pbc, cutoff)) + g2 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_mol, cutoff)) + + g1[KEY.DATA_MODALITY] = 'R2SCAN' + g2[KEY.DATA_MODALITY] = 'R2SCAN' + + model.set_is_batch_data(False) + g1 = model(g1) + g2 = model(g2) + + model.set_is_batch_data(True) + + g1_ref_e = torch.tensor([-11.607587814331055]) + g1_ref_f = torch.tensor( + [ + [8.512259, 0.07307914, 0.06676716], + [-8.512257, -0.07307915, -0.06676716], + ] + ) + g1_ref_s = -1 * torch.tensor( + # xx, yy, zz, xy, yz, zx + [-0.4516204, -0.02483013, -0.02485001, 0.03247492, 0.00259375, 0.03250402] + ) + + g2_ref_e = torch.tensor([-14.172412872314453]) + g2_ref_f = torch.tensor( + [ + [4.6566129e-10, -1.3429364e01, 6.9344816e00], + [2.3283064e-09, 8.9132404e00, -9.6807365e00], + [-2.7939677e-09, 4.5161238e00, 2.7462559e00], + ] + ) + + assert acl(g1.inferred_total_energy, g1_ref_e) + assert acl(g1.inferred_force, g1_ref_f) + assert acl(g1.inferred_stress, g1_ref_s) + + assert acl(g2.inferred_total_energy, g2_ref_e) + assert acl(g2.inferred_force, g2_ref_f) + + +def test_7net_mf_ompa_mpa(atoms_pbc, atoms_mol): + cp_path = pretrained_name_to_path('7net-mf-ompa') + model, config = model_from_checkpoint(cp_path) + cutoff = config['cutoff'] + + g1 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_pbc, cutoff)) + g2 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_mol, cutoff)) + + # mpa + g1[KEY.DATA_MODALITY] = 'mpa' + g2[KEY.DATA_MODALITY] = 'mpa' + + model.set_is_batch_data(False) + g1 = model(g1) + g2 = model(g2) + + model.set_is_batch_data(True) + + g1_ref_e = torch.tensor([-3.490943193435669]) + g1_ref_f = torch.tensor( + [ + [1.2680445e01, -2.7985498e-04, -2.7979910e-04], + [-1.2680446e01, 2.7984008e-04, 2.7981028e-04], + ] + ) + g1_ref_s = -1 * torch.tensor( + # xx, yy, zz, xy, yz, zx + [-0.6481662, -0.02462837, -0.02462837, 0.02693467, 0.00459635, 0.02693467] + ) + + g2_ref_e = torch.tensor([-12.597525596618652]) + g2_ref_f = torch.tensor( + [ + [0.0, -12.245223, 7.26795], + [0.0, 8.816763, -9.423925], + [0.0, 3.4284601, 2.1559749], + ] + ) + assert acl(g1.inferred_total_energy, g1_ref_e) + assert acl(g1.inferred_force, g1_ref_f) + assert acl(g1.inferred_stress, g1_ref_s) + + assert acl(g2.inferred_total_energy, g2_ref_e) + assert acl(g2.inferred_force, g2_ref_f) + + +def test_7net_mf_ompa_omat(atoms_pbc, atoms_mol): + cp_path = pretrained_name_to_path('7net-mf-ompa') + model, config = model_from_checkpoint(cp_path) + cutoff = config['cutoff'] + + g1 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_pbc, cutoff)) + g2 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_mol, cutoff)) + + # mpa + g1[KEY.DATA_MODALITY] = 'omat24' + g2[KEY.DATA_MODALITY] = 'omat24' + + model.set_is_batch_data(False) + g1 = model(g1) + g2 = model(g2) + + model.set_is_batch_data(True) + + g1_ref_e = torch.tensor([-3.5094668865203857]) + g1_ref_f = torch.tensor( + [ + [1.2562084e01, -1.4219694e-03, -1.4219843e-03], + [-1.2562084e01, 1.4219508e-03, 1.4219955e-03], + ] + ) + g1_ref_s = -1 * torch.tensor( + # xx, yy, zz, xy, yz, zx + [-0.6430905, -0.0254128, -0.02541281, 0.0268343, 0.00460021, 0.0268343] + ) + + g2_ref_e = torch.tensor([-12.6202974319458]) + g2_ref_f = torch.tensor( + [ + [0.0, -12.205926, 7.2050343], + [0.0, 8.790399, -9.368677], + [0.0, 3.4155273, 2.163643], + ] + ) + assert acl(g1.inferred_total_energy, g1_ref_e) + assert acl(g1.inferred_force, g1_ref_f) + assert acl(g1.inferred_stress, g1_ref_s) + + assert acl(g2.inferred_total_energy, g2_ref_e) + assert acl(g2.inferred_force, g2_ref_f) + + +def test_7net_omat(atoms_pbc, atoms_mol): + cp_path = pretrained_name_to_path('7net-omat') + model, config = model_from_checkpoint(cp_path) + cutoff = config['cutoff'] + + g1 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_pbc, cutoff)) + g2 = AtomGraphData.from_numpy_dict(unlabeled_atoms_to_graph(atoms_mol, cutoff)) + + model.set_is_batch_data(False) + g1 = model(g1) + g2 = model(g2) + + model.set_is_batch_data(True) + + g1_ref_e = torch.tensor([-3.5033323764801025]) + g1_ref_f = torch.tensor( + [ + [12.533154, 0.02358698, 0.02358694], + [-12.533153, -0.02358699, -0.02358697], + ] + ) + g1_ref_s = -1 * torch.tensor( + # xx, yy, zz, xy, yz, zx + [-0.6420925, -0.02781446, -0.02781446, 0.02575445, 0.00381664, 0.02575445] + ) + + g2_ref_e = torch.tensor([-12.403768539428711]) + g2_ref_f = torch.tensor( + [ + [0, -12.848297, 7.11432], + [0.0, 9.265477, -9.564951], + [0.0, 3.58282, 2.4506311], + ] + ) + assert acl(g1.inferred_total_energy, g1_ref_e) + assert acl(g1.inferred_force, g1_ref_f) + assert acl(g1.inferred_stress, g1_ref_s) + + assert acl(g2.inferred_total_energy, g2_ref_e) + assert acl(g2.inferred_force, g2_ref_f) diff --git a/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_shift_scale.py b/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_shift_scale.py index 2fe3245..3d5a4eb 100644 --- a/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_shift_scale.py +++ b/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_shift_scale.py @@ -1,494 +1,494 @@ -import pytest -import torch - -import sevenn._keys as KEY -from sevenn._const import NUM_UNIV_ELEMENT, AtomGraphDataType -from sevenn.nn.scale import ( - ModalWiseRescale, - Rescale, - SpeciesWiseRescale, - get_resolved_shift_scale, -) - -################################################################################ -# Tests for Rescale # -################################################################################ - - -@pytest.mark.parametrize('shift,scale', [(0.0, 1.0), (1.0, 2.0), (-5.0, 10.0)]) -def test_rescale_init(shift, scale): - """ - Test that Rescale can be initialized properly without errors - and that parameters are set correctly. - """ - module = Rescale(shift=shift, scale=scale) - assert module.shift.item() == shift - assert module.scale.item() == scale - assert module.key_input == KEY.SCALED_ATOMIC_ENERGY - assert module.key_output == KEY.ATOMIC_ENERGY - - -def test_rescale_forward(): - """ - Test that Rescale forward pass correctly applies: - output = input * scale + shift - """ - # Setup - shift, scale = 1.0, 2.0 - module = Rescale(shift=shift, scale=scale) - # Make some fake data - input_data = torch.tensor([[1.0], [2.0], [3.0]], dtype=torch.float) - data: AtomGraphDataType = {KEY.SCALED_ATOMIC_ENERGY: input_data.clone()} - - # Forward - out_data = module(data) - - # Check correctness - expected_output = input_data * scale + shift - assert torch.allclose(out_data[KEY.ATOMIC_ENERGY], expected_output) - - -def test_rescale_get_shift_and_scale(): - """ - Test get_shift() and get_scale() methods in Rescale. - """ - module = Rescale(shift=1.5, scale=3.5) - assert module.get_shift() == pytest.approx(1.5) - assert module.get_scale() == pytest.approx(3.5) - - -################################################################################ -# Tests for SpeciesWiseRescale # -################################################################################ - - -def test_specieswise_rescale_init_float(): - """ - Test SpeciesWiseRescale when both shift and scale are floats - (should expand to same length lists). - """ - module = SpeciesWiseRescale(shift=[1.0, -1.0], scale=2.0) - # Expect a parameter of length = 1 in this scenario, but can differ - # if we raise an error for "Both shift and scale is not a list". - # Usually, you'd specify a known number of species or do from_mappers. - # The code as-is throws ValueError if both are float. Let's do from_mappers: - # We'll do direct init if your code allows it. If not, use from_mappers. - assert module.shift.shape == module.scale.shape - # They must be single-parameter (or expanded) if not from mappers. - - -def test_specieswise_rescale_init_list(): - """ - Test initialization with list-based shift/scale of same length. - """ - shift = [1.0, 2.0, 3.0] - scale = [2.0, 3.0, 4.0] - module = SpeciesWiseRescale(shift=shift, scale=scale) - assert len(module.shift) == 3 - assert len(module.scale) == 3 - assert torch.allclose(module.shift, torch.tensor([1.0, 2.0, 3.0])) - assert torch.allclose(module.scale, torch.tensor([2.0, 3.0, 4.0])) - - -def test_specieswise_rescale_forward(): - """ - Test that SpeciesWiseRescale forward pass applies: - output[i] = input[i]*scale[atom_type[i]] + shift[atom_type[i]] - """ - # Suppose we have two species types: - # 0 -> shift=1, scale=2, 1 -> shift=5, scale=10 - # (we'll pass them as lists in the correct order) - shift = [1.0, 5.0] - scale = [2.0, 10.0] - module = SpeciesWiseRescale( - shift=shift, - scale=scale, - data_key_in='in', - data_key_out='out', - data_key_indices='z', - ) - - # Create mock data - # Suppose we have three atoms: species => [0, 1, 0] - # input => [ [1.], [1.], [3.] ] - data: AtomGraphDataType = { - 'z': torch.tensor([0, 1, 0], dtype=torch.long), - 'in': torch.tensor([[1.0], [1.0], [3.0]], dtype=torch.float), - } - - out = module(data) - # Now let's manually compute expected: - # For atom 0: scale=2, shift=1, input=1 => 1*2+1=3 - # For atom 1: scale=10, shift=5, input=1 => 1*10+5=15 - # For atom 2: scale=2, shift=1, input=3 => 3*2+1=7 - expected = torch.tensor([[3.0], [15.0], [7.0]]) - - assert torch.allclose(out['out'], expected) - - -def test_specieswise_rescale_get_shift_scale(): - """ - Test get_shift() and get_scale() with/without type_map. - """ - shift = [1.0, 2.0] - scale = [3.0, 4.0] - module = SpeciesWiseRescale(shift=shift, scale=scale) - - # Without type_map - # Should return the raw parameter values (list form). - s = module.get_shift() - sc = module.get_scale() - assert s == [1.0, 2.0] - assert sc == [3.0, 4.0] - - # With a type_map (example: atomic_number 1 -> 0, 8 -> 1) - type_map = {1: 0, 8: 1} # hydrogen, oxygen - s_univ = module.get_shift(type_map) - sc_univ = module.get_scale(type_map) - # In this small example with NUM_UNIV_ELEMENT = 2, the _as_univ will produce - # a list of length = NUM_UNIV_ELEMENT. If your real NUM_UNIV_ELEMENT is bigger, - # the rest would be padded with default values. - # For demonstration let's assume it returns [1.0, 2.0]. - # Check at least the known mapped portion: - assert len(s_univ) == NUM_UNIV_ELEMENT - assert len(sc_univ) == NUM_UNIV_ELEMENT - assert s_univ[1] == 1.0 # atomic_number=1 -> idx=0 -> shift=1.0 - assert s_univ[8] == 2.0 - - -################################################################################ -# Tests for ModalWiseRescale # -################################################################################ - - -def test_modalwise_rescale_init(): - """ - Basic sanity check for ModalWiseRescale initialization with - certain shapes. - """ - # Suppose we have 2 modals, 3 species => shift, scale is shape [2,3] - shift = [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]] - scale = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] - module = ModalWiseRescale( - shift=shift, - scale=scale, - use_modal_wise_shift=True, - use_modal_wise_scale=True, - ) - # Check shape - assert module.shift.shape == torch.Size([2, 3]) - assert module.scale.shape == torch.Size([2, 3]) - - -def test_modalwise_rescale_forward(): - """ - Test that the forward pass of ModalWiseRescale matches - output[i] = input[i] * scale[modal_i, atom_i] + shift[modal_i, atom_i] - when both use_modal_wise_{shift,scale} are True. - """ - shift = [[0.0, 10.0], [5.0, 15.0]] # shape [2 (modals), 2 (species)] - scale = [[1.0, 2.0], [10.0, 20.0]] - module = ModalWiseRescale( - shift=shift, - scale=scale, - data_key_in='in', - data_key_out='out', - data_key_modal_indices='modal_idx', - data_key_atom_indices='atom_idx', - use_modal_wise_shift=True, - use_modal_wise_scale=True, - ) - - data: AtomGraphDataType = { - 'in': torch.tensor([[1.0], [1.0], [2.0], [2.0]]), - 'modal_idx': torch.tensor([0, 1], dtype=torch.long), - 'atom_idx': torch.tensor([0, 1, 0, 1], dtype=torch.long), - 'batch': torch.tensor([0, 0, 1, 1], dtype=torch.long), - } - - out = module(data) - # i=0 => modal_idx=0, atom_idx=0 => shift=0.0, scale=1.0 => out=1*1+0=1 - # i=1 => modal_idx=0, atom_idx=1 => shift=10.0, scale=2.0 => out=1*2+10=12 - # i=2 => modal_idx=1, atom_idx=0 => shift=5.0, scale=10.0 => out=2*10+5=25 - # i=3 => modal_idx=1, atom_idx=1 => shift=15.0, scale=20.0 => out=2*20+15=55 - expected = torch.tensor([[1.0], [12.0], [25.0], [55.0]]) - assert torch.allclose(out['out'], expected) - - -def test_modalwise_rescale_get_shift_scale(): - """ - Test get_shift() and get_scale() with type_map and modal_map. - """ - # Setup - shift = [[0.0, 10.0], [5.0, 15.0]] - scale = [[1.0, 2.0], [10.0, 20.0]] - mod = ModalWiseRescale( - shift=shift, - scale=scale, - use_modal_wise_shift=True, - use_modal_wise_scale=True, - ) - - # Suppose we have type_map and modal_map - type_map = {1: 0, 8: 1} # Example: H->0, O->1 - modal_map = {'a': 0, 'b': 1} - - # get_shift, get_scale - s = mod.get_shift(type_map=type_map, modal_map=modal_map) - sc = mod.get_scale(type_map=type_map, modal_map=modal_map) - # Expect dict with keys "ambient", "pressure". - # Example: s["ambient"] = [ shift(0,0), shift(0,1) ] mapped to H,O - # s["pressure"] = [ shift(1,0), shift(1,1) ] - assert isinstance(s, dict) and isinstance(sc, dict) - assert set(s.keys()) == {'a', 'b'} - assert set(sc.keys()) == {'a', 'b'} - - -################################################################################ -# Tests for get_resolved_shift_scale function # -################################################################################ - - -def test_get_resolved_shift_scale_rescale(): - """ - Test get_resolved_shift_scale for a Rescale instance. - """ - from_m = Rescale(shift=2.0, scale=5.0) - shift, scale = get_resolved_shift_scale(from_m) - assert shift == 2.0 - assert scale == 5.0 - - -def test_get_resolved_shift_scale_specieswise(): - """ - Test get_resolved_shift_scale for a SpeciesWiseRescale instance. - """ - shift_list = [1.0, 2.0] - scale_list = [3.0, 4.0] - module = SpeciesWiseRescale(shift=shift_list, scale=scale_list) - type_map = {1: 0, 8: 1} - s, sc = get_resolved_shift_scale(module, type_map=type_map) - # The result should be extended to NUM_UNIV_ELEMENT length in real usage, - # but at least the first few should match shift_list, scale_list mapped. - assert isinstance(s, list) - assert isinstance(sc, list) - # Check mapped values - assert s[1] == shift_list[0] - assert s[8] == shift_list[1] - assert sc[1] == scale_list[0] - assert sc[8] == scale_list[1] - - -def test_get_resolved_shift_scale_modalwise(): - """ - Test get_resolved_shift_scale for a ModalWiseRescale instance. - """ - shift = [[0.0, 10.0], [5.0, 15.0]] - scale = [[1.0, 2.0], [10.0, 20.0]] - mmod = ModalWiseRescale( - shift=shift, - scale=scale, - use_modal_wise_shift=True, - use_modal_wise_scale=True, - ) - type_map = {1: 0, 8: 1} - modal_map = {'a': 0, 'b': 1} - s, sc = get_resolved_shift_scale(mmod, type_map=type_map, modal_map=modal_map) - # We expect dictionaries - assert isinstance(s, dict) and isinstance(sc, dict) - # Keys "a", "pressure" - assert 'a' in s - assert 'b' in s - # Check one example - # s["a"] => [0.0, 10.0] - # sc["a"] => [1.0, 2.0] - assert s['a'][1] == 0.0 - assert s['a'][8] == 10.0 - assert sc['a'][1] == 1.0 - assert sc['a'][8] == 2.0 - - -################################################################################ -# Tests for from_mappers function # -################################################################################ - - -@pytest.mark.parametrize( - 'shift, scale, type_map, expected_shift, expected_scale', - [ - # Both shift and scale are floats -> broadcast to each species - ( - 2.0, - 3.0, - {1: 0, 8: 1}, # e.g., H -> index 0, O -> index 1 - [2.0, 2.0], # broadcast - [3.0, 3.0], - ), - # shift, scale are same-length lists => directly used - ( - [0.5, 0.6], - [1.0, 1.1], - {1: 0, 8: 1}, - [0.5, 0.6], - [1.0, 1.1], - ), - # shift, scale are entire "universal" length (NUM_UNIV_ELEMENT=118), - # but we only map out the subset for the actual species in type_map - ( - [0.1] * NUM_UNIV_ELEMENT, - [1.1] * NUM_UNIV_ELEMENT, - {1: 0, 8: 1}, - [0.1, 0.1], - [1.1, 1.1], - ), - # shift is a list, scale is float => shift is used directly, scale broadcast - ( - [1.0, 2.0], - 5.0, - {6: 0, 14: 1}, # C -> 0, Si -> 1 - [1.0, 2.0], - [5.0, 5.0], - ), - ], -) -def test_specieswise_rescale_from_mappers( - shift, scale, type_map, expected_shift, expected_scale -): - """ - Test SpeciesWiseRescale.from_mappers with various combinations of - shift/scale (float, list, universal list) and a given type_map. - """ - module = SpeciesWiseRescale.from_mappers( # type: ignore - shift=shift, - scale=scale, - type_map=type_map, - ) - # Check that the module's internal shift and scale have the correct shape - # The length must match number of species in type_map - assert module.shift.shape[0] == len(type_map) - assert module.scale.shape[0] == len(type_map) - - # Check that the content matches expected - actual_shift = module.shift.detach().cpu().tolist() - actual_scale = module.scale.detach().cpu().tolist() - - assert pytest.approx(actual_shift) == expected_shift - assert pytest.approx(actual_scale) == expected_scale - - -@pytest.mark.parametrize( - 'shift, scale, use_modal_wise_shift, use_modal_wise_scale, ' - 'type_map, modal_map, expected_shift, expected_scale', - [ - # Example 1: single float for shift/scale, - # broadcast over 2 modals and 2 species - ( - 1.0, - 2.0, - True, # shift depends on modal - True, # scale depends on modal - {1: 0, 8: 1}, - {'modA': 0, 'modB': 1}, - # expect 2D => [2 modals x 2 species] - [[1.0, 1.0], [1.0, 1.0]], - [[2.0, 2.0], [2.0, 2.0]], - ), - # Example 2: shift/scale are universal element-lists => use_modal=False => 1D - ( - [0.5] * NUM_UNIV_ELEMENT, - [1.5] * NUM_UNIV_ELEMENT, - False, # shift is not modal-wise - False, # scale is not modal-wise - {6: 0, 14: 1}, # e.g. C->0, Si->1 - {'modA': 0, 'modB': 1}, - # 1D => length = n_atom_types(=2) - [0.5, 0.5], - [1.5, 1.5], - ), - # Example 3: shift is dict of modals -> each is float - # => broadcast for each species - ( - {'modA': 0.0, 'modB': 2.0}, - {'modA': 1.0, 'modB': 3.0}, - True, - True, - {1: 0, 8: 1}, - {'modA': 0, 'modB': 1}, - # shift => shape [2 modals, 2 species] - [[0.0, 0.0], [2.0, 2.0]], - [[1.0, 1.0], [3.0, 3.0]], - ), - # Example 4: already in "modal-wise + species-wise" shape, direct pass - ( - [[0.0, 10.0], [5.0, 15.0]], - [[1.0, 2.0], [10.0, 20.0]], - True, - True, - {1: 0, 8: 1}, - {'modA': 0, 'modB': 1}, - [[0.0, 10.0], [5.0, 15.0]], - [[1.0, 2.0], [10.0, 20.0]], - ), - # Example 5: shift is a list of floats (one per modal), - # but we want modal-wise => broadcast for each species - ( - [0.0, 10.0], # length=2 => same as #modals - [1.0, 2.0], - True, - True, - {1: 0, 8: 1}, - {'modA': 0, 'modB': 1}, - [[0.0, 0.0], [10.0, 10.0]], - [[1.0, 1.0], [2.0, 2.0]], - ), - ], -) -def test_modalwise_rescale_from_mappers( - shift, - scale, - use_modal_wise_shift, - use_modal_wise_scale, - type_map, - modal_map, - expected_shift, - expected_scale, -): - """ - Test ModalWiseRescale.from_mappers for different shapes of shift/scale, - combined with type_map and modal_map. - """ - - module = ModalWiseRescale.from_mappers( # type: ignore - shift=shift, - scale=scale, - use_modal_wise_shift=use_modal_wise_shift, - use_modal_wise_scale=use_modal_wise_scale, - type_map=type_map, - modal_map=modal_map, - ) - # Check shape of the resulting shift, scale - # If modal-wise, we expect a 2D shape: [n_modals, n_species] - # Otherwise, a 1D shape: [n_species] - if use_modal_wise_shift: - assert module.shift.dim() == 2 - assert module.shift.shape[0] == len(modal_map) - assert module.shift.shape[1] == len(type_map) - else: - assert module.shift.dim() == 1 - assert module.shift.shape[0] == len(type_map) - - # Similarly for scale - if use_modal_wise_scale: - assert module.scale.dim() == 2 - assert module.scale.shape[0] == len(modal_map) - assert module.scale.shape[1] == len(type_map) - else: - assert module.scale.dim() == 1 - assert module.scale.shape[0] == len(type_map) - - # Verify the content matches our expectation - actual_shift = module.shift.detach().cpu().tolist() - actual_scale = module.scale.detach().cpu().tolist() - - assert actual_shift == expected_shift - assert actual_scale == expected_scale +import pytest +import torch + +import sevenn._keys as KEY +from sevenn._const import NUM_UNIV_ELEMENT, AtomGraphDataType +from sevenn.nn.scale import ( + ModalWiseRescale, + Rescale, + SpeciesWiseRescale, + get_resolved_shift_scale, +) + +################################################################################ +# Tests for Rescale # +################################################################################ + + +@pytest.mark.parametrize('shift,scale', [(0.0, 1.0), (1.0, 2.0), (-5.0, 10.0)]) +def test_rescale_init(shift, scale): + """ + Test that Rescale can be initialized properly without errors + and that parameters are set correctly. + """ + module = Rescale(shift=shift, scale=scale) + assert module.shift.item() == shift + assert module.scale.item() == scale + assert module.key_input == KEY.SCALED_ATOMIC_ENERGY + assert module.key_output == KEY.ATOMIC_ENERGY + + +def test_rescale_forward(): + """ + Test that Rescale forward pass correctly applies: + output = input * scale + shift + """ + # Setup + shift, scale = 1.0, 2.0 + module = Rescale(shift=shift, scale=scale) + # Make some fake data + input_data = torch.tensor([[1.0], [2.0], [3.0]], dtype=torch.float) + data: AtomGraphDataType = {KEY.SCALED_ATOMIC_ENERGY: input_data.clone()} + + # Forward + out_data = module(data) + + # Check correctness + expected_output = input_data * scale + shift + assert torch.allclose(out_data[KEY.ATOMIC_ENERGY], expected_output) + + +def test_rescale_get_shift_and_scale(): + """ + Test get_shift() and get_scale() methods in Rescale. + """ + module = Rescale(shift=1.5, scale=3.5) + assert module.get_shift() == pytest.approx(1.5) + assert module.get_scale() == pytest.approx(3.5) + + +################################################################################ +# Tests for SpeciesWiseRescale # +################################################################################ + + +def test_specieswise_rescale_init_float(): + """ + Test SpeciesWiseRescale when both shift and scale are floats + (should expand to same length lists). + """ + module = SpeciesWiseRescale(shift=[1.0, -1.0], scale=2.0) + # Expect a parameter of length = 1 in this scenario, but can differ + # if we raise an error for "Both shift and scale is not a list". + # Usually, you'd specify a known number of species or do from_mappers. + # The code as-is throws ValueError if both are float. Let's do from_mappers: + # We'll do direct init if your code allows it. If not, use from_mappers. + assert module.shift.shape == module.scale.shape + # They must be single-parameter (or expanded) if not from mappers. + + +def test_specieswise_rescale_init_list(): + """ + Test initialization with list-based shift/scale of same length. + """ + shift = [1.0, 2.0, 3.0] + scale = [2.0, 3.0, 4.0] + module = SpeciesWiseRescale(shift=shift, scale=scale) + assert len(module.shift) == 3 + assert len(module.scale) == 3 + assert torch.allclose(module.shift, torch.tensor([1.0, 2.0, 3.0])) + assert torch.allclose(module.scale, torch.tensor([2.0, 3.0, 4.0])) + + +def test_specieswise_rescale_forward(): + """ + Test that SpeciesWiseRescale forward pass applies: + output[i] = input[i]*scale[atom_type[i]] + shift[atom_type[i]] + """ + # Suppose we have two species types: + # 0 -> shift=1, scale=2, 1 -> shift=5, scale=10 + # (we'll pass them as lists in the correct order) + shift = [1.0, 5.0] + scale = [2.0, 10.0] + module = SpeciesWiseRescale( + shift=shift, + scale=scale, + data_key_in='in', + data_key_out='out', + data_key_indices='z', + ) + + # Create mock data + # Suppose we have three atoms: species => [0, 1, 0] + # input => [ [1.], [1.], [3.] ] + data: AtomGraphDataType = { + 'z': torch.tensor([0, 1, 0], dtype=torch.long), + 'in': torch.tensor([[1.0], [1.0], [3.0]], dtype=torch.float), + } + + out = module(data) + # Now let's manually compute expected: + # For atom 0: scale=2, shift=1, input=1 => 1*2+1=3 + # For atom 1: scale=10, shift=5, input=1 => 1*10+5=15 + # For atom 2: scale=2, shift=1, input=3 => 3*2+1=7 + expected = torch.tensor([[3.0], [15.0], [7.0]]) + + assert torch.allclose(out['out'], expected) + + +def test_specieswise_rescale_get_shift_scale(): + """ + Test get_shift() and get_scale() with/without type_map. + """ + shift = [1.0, 2.0] + scale = [3.0, 4.0] + module = SpeciesWiseRescale(shift=shift, scale=scale) + + # Without type_map + # Should return the raw parameter values (list form). + s = module.get_shift() + sc = module.get_scale() + assert s == [1.0, 2.0] + assert sc == [3.0, 4.0] + + # With a type_map (example: atomic_number 1 -> 0, 8 -> 1) + type_map = {1: 0, 8: 1} # hydrogen, oxygen + s_univ = module.get_shift(type_map) + sc_univ = module.get_scale(type_map) + # In this small example with NUM_UNIV_ELEMENT = 2, the _as_univ will produce + # a list of length = NUM_UNIV_ELEMENT. If your real NUM_UNIV_ELEMENT is bigger, + # the rest would be padded with default values. + # For demonstration let's assume it returns [1.0, 2.0]. + # Check at least the known mapped portion: + assert len(s_univ) == NUM_UNIV_ELEMENT + assert len(sc_univ) == NUM_UNIV_ELEMENT + assert s_univ[1] == 1.0 # atomic_number=1 -> idx=0 -> shift=1.0 + assert s_univ[8] == 2.0 + + +################################################################################ +# Tests for ModalWiseRescale # +################################################################################ + + +def test_modalwise_rescale_init(): + """ + Basic sanity check for ModalWiseRescale initialization with + certain shapes. + """ + # Suppose we have 2 modals, 3 species => shift, scale is shape [2,3] + shift = [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]] + scale = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] + module = ModalWiseRescale( + shift=shift, + scale=scale, + use_modal_wise_shift=True, + use_modal_wise_scale=True, + ) + # Check shape + assert module.shift.shape == torch.Size([2, 3]) + assert module.scale.shape == torch.Size([2, 3]) + + +def test_modalwise_rescale_forward(): + """ + Test that the forward pass of ModalWiseRescale matches + output[i] = input[i] * scale[modal_i, atom_i] + shift[modal_i, atom_i] + when both use_modal_wise_{shift,scale} are True. + """ + shift = [[0.0, 10.0], [5.0, 15.0]] # shape [2 (modals), 2 (species)] + scale = [[1.0, 2.0], [10.0, 20.0]] + module = ModalWiseRescale( + shift=shift, + scale=scale, + data_key_in='in', + data_key_out='out', + data_key_modal_indices='modal_idx', + data_key_atom_indices='atom_idx', + use_modal_wise_shift=True, + use_modal_wise_scale=True, + ) + + data: AtomGraphDataType = { + 'in': torch.tensor([[1.0], [1.0], [2.0], [2.0]]), + 'modal_idx': torch.tensor([0, 1], dtype=torch.long), + 'atom_idx': torch.tensor([0, 1, 0, 1], dtype=torch.long), + 'batch': torch.tensor([0, 0, 1, 1], dtype=torch.long), + } + + out = module(data) + # i=0 => modal_idx=0, atom_idx=0 => shift=0.0, scale=1.0 => out=1*1+0=1 + # i=1 => modal_idx=0, atom_idx=1 => shift=10.0, scale=2.0 => out=1*2+10=12 + # i=2 => modal_idx=1, atom_idx=0 => shift=5.0, scale=10.0 => out=2*10+5=25 + # i=3 => modal_idx=1, atom_idx=1 => shift=15.0, scale=20.0 => out=2*20+15=55 + expected = torch.tensor([[1.0], [12.0], [25.0], [55.0]]) + assert torch.allclose(out['out'], expected) + + +def test_modalwise_rescale_get_shift_scale(): + """ + Test get_shift() and get_scale() with type_map and modal_map. + """ + # Setup + shift = [[0.0, 10.0], [5.0, 15.0]] + scale = [[1.0, 2.0], [10.0, 20.0]] + mod = ModalWiseRescale( + shift=shift, + scale=scale, + use_modal_wise_shift=True, + use_modal_wise_scale=True, + ) + + # Suppose we have type_map and modal_map + type_map = {1: 0, 8: 1} # Example: H->0, O->1 + modal_map = {'a': 0, 'b': 1} + + # get_shift, get_scale + s = mod.get_shift(type_map=type_map, modal_map=modal_map) + sc = mod.get_scale(type_map=type_map, modal_map=modal_map) + # Expect dict with keys "ambient", "pressure". + # Example: s["ambient"] = [ shift(0,0), shift(0,1) ] mapped to H,O + # s["pressure"] = [ shift(1,0), shift(1,1) ] + assert isinstance(s, dict) and isinstance(sc, dict) + assert set(s.keys()) == {'a', 'b'} + assert set(sc.keys()) == {'a', 'b'} + + +################################################################################ +# Tests for get_resolved_shift_scale function # +################################################################################ + + +def test_get_resolved_shift_scale_rescale(): + """ + Test get_resolved_shift_scale for a Rescale instance. + """ + from_m = Rescale(shift=2.0, scale=5.0) + shift, scale = get_resolved_shift_scale(from_m) + assert shift == 2.0 + assert scale == 5.0 + + +def test_get_resolved_shift_scale_specieswise(): + """ + Test get_resolved_shift_scale for a SpeciesWiseRescale instance. + """ + shift_list = [1.0, 2.0] + scale_list = [3.0, 4.0] + module = SpeciesWiseRescale(shift=shift_list, scale=scale_list) + type_map = {1: 0, 8: 1} + s, sc = get_resolved_shift_scale(module, type_map=type_map) + # The result should be extended to NUM_UNIV_ELEMENT length in real usage, + # but at least the first few should match shift_list, scale_list mapped. + assert isinstance(s, list) + assert isinstance(sc, list) + # Check mapped values + assert s[1] == shift_list[0] + assert s[8] == shift_list[1] + assert sc[1] == scale_list[0] + assert sc[8] == scale_list[1] + + +def test_get_resolved_shift_scale_modalwise(): + """ + Test get_resolved_shift_scale for a ModalWiseRescale instance. + """ + shift = [[0.0, 10.0], [5.0, 15.0]] + scale = [[1.0, 2.0], [10.0, 20.0]] + mmod = ModalWiseRescale( + shift=shift, + scale=scale, + use_modal_wise_shift=True, + use_modal_wise_scale=True, + ) + type_map = {1: 0, 8: 1} + modal_map = {'a': 0, 'b': 1} + s, sc = get_resolved_shift_scale(mmod, type_map=type_map, modal_map=modal_map) + # We expect dictionaries + assert isinstance(s, dict) and isinstance(sc, dict) + # Keys "a", "pressure" + assert 'a' in s + assert 'b' in s + # Check one example + # s["a"] => [0.0, 10.0] + # sc["a"] => [1.0, 2.0] + assert s['a'][1] == 0.0 + assert s['a'][8] == 10.0 + assert sc['a'][1] == 1.0 + assert sc['a'][8] == 2.0 + + +################################################################################ +# Tests for from_mappers function # +################################################################################ + + +@pytest.mark.parametrize( + 'shift, scale, type_map, expected_shift, expected_scale', + [ + # Both shift and scale are floats -> broadcast to each species + ( + 2.0, + 3.0, + {1: 0, 8: 1}, # e.g., H -> index 0, O -> index 1 + [2.0, 2.0], # broadcast + [3.0, 3.0], + ), + # shift, scale are same-length lists => directly used + ( + [0.5, 0.6], + [1.0, 1.1], + {1: 0, 8: 1}, + [0.5, 0.6], + [1.0, 1.1], + ), + # shift, scale are entire "universal" length (NUM_UNIV_ELEMENT=118), + # but we only map out the subset for the actual species in type_map + ( + [0.1] * NUM_UNIV_ELEMENT, + [1.1] * NUM_UNIV_ELEMENT, + {1: 0, 8: 1}, + [0.1, 0.1], + [1.1, 1.1], + ), + # shift is a list, scale is float => shift is used directly, scale broadcast + ( + [1.0, 2.0], + 5.0, + {6: 0, 14: 1}, # C -> 0, Si -> 1 + [1.0, 2.0], + [5.0, 5.0], + ), + ], +) +def test_specieswise_rescale_from_mappers( + shift, scale, type_map, expected_shift, expected_scale +): + """ + Test SpeciesWiseRescale.from_mappers with various combinations of + shift/scale (float, list, universal list) and a given type_map. + """ + module = SpeciesWiseRescale.from_mappers( # type: ignore + shift=shift, + scale=scale, + type_map=type_map, + ) + # Check that the module's internal shift and scale have the correct shape + # The length must match number of species in type_map + assert module.shift.shape[0] == len(type_map) + assert module.scale.shape[0] == len(type_map) + + # Check that the content matches expected + actual_shift = module.shift.detach().cpu().tolist() + actual_scale = module.scale.detach().cpu().tolist() + + assert pytest.approx(actual_shift) == expected_shift + assert pytest.approx(actual_scale) == expected_scale + + +@pytest.mark.parametrize( + 'shift, scale, use_modal_wise_shift, use_modal_wise_scale, ' + 'type_map, modal_map, expected_shift, expected_scale', + [ + # Example 1: single float for shift/scale, + # broadcast over 2 modals and 2 species + ( + 1.0, + 2.0, + True, # shift depends on modal + True, # scale depends on modal + {1: 0, 8: 1}, + {'modA': 0, 'modB': 1}, + # expect 2D => [2 modals x 2 species] + [[1.0, 1.0], [1.0, 1.0]], + [[2.0, 2.0], [2.0, 2.0]], + ), + # Example 2: shift/scale are universal element-lists => use_modal=False => 1D + ( + [0.5] * NUM_UNIV_ELEMENT, + [1.5] * NUM_UNIV_ELEMENT, + False, # shift is not modal-wise + False, # scale is not modal-wise + {6: 0, 14: 1}, # e.g. C->0, Si->1 + {'modA': 0, 'modB': 1}, + # 1D => length = n_atom_types(=2) + [0.5, 0.5], + [1.5, 1.5], + ), + # Example 3: shift is dict of modals -> each is float + # => broadcast for each species + ( + {'modA': 0.0, 'modB': 2.0}, + {'modA': 1.0, 'modB': 3.0}, + True, + True, + {1: 0, 8: 1}, + {'modA': 0, 'modB': 1}, + # shift => shape [2 modals, 2 species] + [[0.0, 0.0], [2.0, 2.0]], + [[1.0, 1.0], [3.0, 3.0]], + ), + # Example 4: already in "modal-wise + species-wise" shape, direct pass + ( + [[0.0, 10.0], [5.0, 15.0]], + [[1.0, 2.0], [10.0, 20.0]], + True, + True, + {1: 0, 8: 1}, + {'modA': 0, 'modB': 1}, + [[0.0, 10.0], [5.0, 15.0]], + [[1.0, 2.0], [10.0, 20.0]], + ), + # Example 5: shift is a list of floats (one per modal), + # but we want modal-wise => broadcast for each species + ( + [0.0, 10.0], # length=2 => same as #modals + [1.0, 2.0], + True, + True, + {1: 0, 8: 1}, + {'modA': 0, 'modB': 1}, + [[0.0, 0.0], [10.0, 10.0]], + [[1.0, 1.0], [2.0, 2.0]], + ), + ], +) +def test_modalwise_rescale_from_mappers( + shift, + scale, + use_modal_wise_shift, + use_modal_wise_scale, + type_map, + modal_map, + expected_shift, + expected_scale, +): + """ + Test ModalWiseRescale.from_mappers for different shapes of shift/scale, + combined with type_map and modal_map. + """ + + module = ModalWiseRescale.from_mappers( # type: ignore + shift=shift, + scale=scale, + use_modal_wise_shift=use_modal_wise_shift, + use_modal_wise_scale=use_modal_wise_scale, + type_map=type_map, + modal_map=modal_map, + ) + # Check shape of the resulting shift, scale + # If modal-wise, we expect a 2D shape: [n_modals, n_species] + # Otherwise, a 1D shape: [n_species] + if use_modal_wise_shift: + assert module.shift.dim() == 2 + assert module.shift.shape[0] == len(modal_map) + assert module.shift.shape[1] == len(type_map) + else: + assert module.shift.dim() == 1 + assert module.shift.shape[0] == len(type_map) + + # Similarly for scale + if use_modal_wise_scale: + assert module.scale.dim() == 2 + assert module.scale.shape[0] == len(modal_map) + assert module.scale.shape[1] == len(type_map) + else: + assert module.scale.dim() == 1 + assert module.scale.shape[0] == len(type_map) + + # Verify the content matches our expectation + actual_shift = module.shift.detach().cpu().tolist() + actual_scale = module.scale.detach().cpu().tolist() + + assert actual_shift == expected_shift + assert actual_scale == expected_scale diff --git a/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_train.py b/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_train.py index f758a6c..089fe82 100644 --- a/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_train.py +++ b/mace-bench/3rdparty/SevenNet/tests/unit_tests/test_train.py @@ -1,402 +1,402 @@ -import pathlib - -import ase.io -import numpy as np -import pytest -import torch -from torch_geometric.loader import DataLoader - -import sevenn.train.graph_dataset as graph_ds -from sevenn._const import NUM_UNIV_ELEMENT -from sevenn.error_recorder import ErrorRecorder -from sevenn.logger import Logger -from sevenn.scripts.processing_continue import processing_continue_v2 -from sevenn.scripts.processing_epoch import processing_epoch_v2 -from sevenn.train.dataload import graph_build -from sevenn.train.graph_dataset import from_config as dataset_from_config -from sevenn.train.loss import get_loss_functions_from_config -from sevenn.train.trainer import Trainer -from sevenn.util import ( - chemical_species_preprocess, - get_error_recorder, - pretrained_name_to_path, -) - -cutoff = 4.0 - -data_root = (pathlib.Path(__file__).parent.parent / 'data').resolve() - -hfo2_path = str(data_root / 'systems' / 'hfo2.extxyz') -cp_0_path = str(data_root / 'checkpoints' / 'cp_0.pth') -sevennet_0_path = pretrained_name_to_path('7net-0_11July2024') - -known_elements = ['Hf', 'O'] -_elemwise_ref_energy_dct = {72: -17.379337, 8: -34.7499924} - -Logger() # init - - -@pytest.fixture() -def HfO2_atoms(): - atoms = ase.io.read(hfo2_path) - return atoms - - -@pytest.fixture(scope='module') -def HfO2_loader(): - atoms = ase.io.read(hfo2_path, index=':') - assert isinstance(atoms, list) - graphs = graph_build(atoms, cutoff, y_from_calc=True) - return DataLoader(graphs, batch_size=2) - - -@pytest.fixture(scope='module') -def graph_dataset_path(tmp_path_factory): - gd_path = tmp_path_factory.mktemp('gd') - ds = graph_ds.SevenNetGraphDataset( - cutoff=cutoff, root=str(gd_path), files=[hfo2_path], processed_name='tmp.pt' - ) - return ds.processed_paths[0] - - -def get_model_config(): - config = { - 'cutoff': cutoff, - 'channel': 4, - 'radial_basis': { - 'radial_basis_name': 'bessel', - }, - 'cutoff_function': {'cutoff_function_name': 'poly_cut'}, - 'interaction_type': 'nequip', - 'lmax': 2, - 'is_parity': True, - 'num_convolution_layer': 3, - 'weight_nn_hidden_neurons': [64, 64], - 'act_radial': 'silu', - 'act_scalar': {'e': 'silu', 'o': 'tanh'}, - 'act_gate': {'e': 'silu', 'o': 'tanh'}, - 'conv_denominator': 'avg_num_neigh', - 'train_denominator': False, - 'self_connection_type': 'nequip', - 'train_shift_scale': False, - 'irreps_manual': False, - 'lmax_edge': -1, - 'lmax_node': -1, - 'readout_as_fcn': False, - 'use_bias_in_linear': False, - '_normalize_sph': True, - } - config.update(**chemical_species_preprocess(known_elements)) - return config - - -def get_train_config(): - config = { - 'random_seed': 1, - 'epoch': 2, - 'loss': 'mse', - 'loss_param': {}, - 'optimizer': 'adam', - 'optim_param': {}, - 'scheduler': 'exponentiallr', - 'scheduler_param': {'gamma': 0.99}, - 'force_loss_weight': 1.0, - 'stress_loss_weight': 0.1, - 'per_epoch': 1, - 'continue': { - 'checkpoint': False, - 'reset_optimizer': False, - 'reset_scheduler': False, - 'reset_epoch': False, - }, - 'is_train_stress': True, - 'train_shuffle': True, - 'best_metric': 'TotalLoss', - 'error_record': [ - ('Energy', 'RMSE'), - ('Force', 'RMSE'), - ('Stress', 'RMSE'), - ('TotalLoss', 'None'), - ], - 'use_modality': False, - 'use_weight': False, - 'device': 'cpu', - 'is_ddp': False, - } - return config - - -def get_data_config(): - config = { - 'batch_size': 2, - 'shift': 'per_atom_energy_mean', - 'scale': 'force_rms', - 'preprocess_num_cores': 1, - 'data_format_args': {}, - 'load_trainset_path': hfo2_path, - } - return config - - -def get_config(overwrite=None): - cf = {} - cf.update(get_model_config()) - cf.update(get_train_config()) - cf.update(get_data_config()) - if overwrite: - cf.update(overwrite) - return cf - - -def test_processing_continue_v2_7net0(tmp_path): - cp = torch.load(sevennet_0_path, weights_only=False, map_location='cpu') - - cfg = get_config( - { - 'continue': { - 'checkpoint': sevennet_0_path, - 'reset_optimizer': False, - 'reset_scheduler': True, - 'reset_epoch': False, - } - } - ) - shift_ref = cp['model_state_dict']['rescale_atomic_energy.shift'].cpu().numpy() - scale_ref = np.array([1.73] * 89) - conv_denominator_ref = np.array([35.989574] * 5) - - with Logger().switch_file(str(tmp_path / 'log.sevenn')): - state_dicts, epoch = processing_continue_v2(cfg) - assert epoch == 601 - assert np.allclose(np.array(cfg['shift']), shift_ref) - assert np.allclose(np.array(cfg['shift'])[0], -5.062768) - assert np.allclose(np.array(cfg['scale']), scale_ref) - assert np.allclose(np.array(cfg['conv_denominator']), conv_denominator_ref) - assert cfg['_number_of_species'] == 89 - assert cfg['_type_map'][89] == 0 # Ac - assert cfg['_type_map'][40] == 88 # Zr - assert state_dicts[2] is None # scheduler reset - - -@pytest.mark.parametrize( - 'cfg_overwrite,ds_names', - [ - ({}, ['trainset']), - ({'load_myset_path': hfo2_path}, ['trainset', 'myset']), - ], -) -def test_dataset_from_config(cfg_overwrite, ds_names, tmp_path): - cfg = get_config(cfg_overwrite) - with Logger().switch_file(str(tmp_path / 'log.sevenn')): - datasets = dataset_from_config(cfg, tmp_path) - - assert set(ds_names) == set(datasets.keys()) - for ds_name in ds_names: - assert (tmp_path / 'sevenn_data' / f'{ds_name}.pt').is_file() - assert (tmp_path / 'sevenn_data' / f'{ds_name}.yaml').is_file() - - -def test_dataset_from_config_as_it_is_load(graph_dataset_path, tmp_path): - cfg = get_config({'load_trainset_path': graph_dataset_path}) - new_wd = tmp_path / 'tmp_wd' - with Logger().switch_file(str(tmp_path / 'log.sevenn')): - _ = dataset_from_config(cfg, str(new_wd)) - print((tmp_path / 'tmp_wd' / 'sevenn_data')) - assert not (tmp_path / 'tmp_wd' / 'sevenn_data').is_dir() - - -@pytest.mark.parametrize( - 'cfg_overwrite,shift,scale,conv', - [ - ( - {}, - -28.978, - 0.113304, - 25.333333, - ), - ( - { - 'shift': -1.2345678, - }, - -1.234567, - 0.113304, - 25.333333, - ), - ( - { - 'conv_denominator': 'sqrt_avg_num_neigh', - }, - -28.978, - 0.113304, - 25.333333**0.5, - ), - ( - { - 'shift': 'force_rms', - }, - 0.113304, - 0.113304, - 25.333333, - ), - ( - { - 'shift': 'elemwise_reference_energies', - }, - [ - 0.0 - if z not in _elemwise_ref_energy_dct - else _elemwise_ref_energy_dct[z] - for z in range(NUM_UNIV_ELEMENT) - ], - 0.113304, - 25.333333, - ), - ], -) -def test_dataset_from_config_statistics_init( - cfg_overwrite, shift, scale, conv, tmp_path -): - cfg = get_config(cfg_overwrite) - with Logger().switch_file(str(tmp_path / 'log.sevenn')): - _ = dataset_from_config(cfg, tmp_path) - - assert np.allclose(cfg['shift'], shift) - assert np.allclose(cfg['scale'], scale) - assert np.allclose(cfg['conv_denominator'], conv) - - -def test_dataset_from_config_chem_auto(tmp_path): - cfg = get_config( - { - 'chemical_species': 'auto', - '_number_of_species': 'auto', - '_type_map': 'auto', - } - ) - with Logger().switch_file(str(tmp_path / 'log.sevenn')): - _ = dataset_from_config(cfg, tmp_path) - assert cfg['chemical_species'] == ['Hf', 'O'] - assert cfg['_number_of_species'] == 2 - assert cfg['_type_map'] == {72: 0, 8: 1} - - -def test_run_one_epoch(HfO2_loader): - trainer_args, _, _ = Trainer.args_from_checkpoint(cp_0_path) - trainer = Trainer(**trainer_args) - erc = get_error_recorder() - - ref1 = { - 'Energy_RMSE': '28.977758', - 'Force_RMSE': '0.214107', - 'Stress_RMSE': '190.014237', - } - - ref2 = { - 'Energy_RMSE': '28.977878', - 'Force_RMSE': '0.213105', - 'Stress_RMSE': '188.772557', - } - - trainer.run_one_epoch(HfO2_loader, is_train=False, error_recorder=erc) - ret1 = erc.get_dct() - erc.epoch_forward() - - for k in ref1: - assert np.allclose(float(ret1[k]), float(ref1[k])) - - trainer.run_one_epoch(HfO2_loader, is_train=True, error_recorder=erc) - erc.epoch_forward() - - trainer.run_one_epoch(HfO2_loader, is_train=False, error_recorder=erc) - ret2 = erc.get_dct() - erc.epoch_forward() - - for k in ref2: - assert np.allclose(float(ret2[k]), float(ref2[k])) - - -def test_processing_epoch_v2(HfO2_loader, tmp_path): - trainer_args, _, _ = Trainer.args_from_checkpoint(cp_0_path) - trainer = Trainer(**trainer_args) - erc = get_error_recorder() - start_epoch = 10 - total_epoch = 12 - per_epoch = 1 - best_metric = 'Energy_RMSE' - best_metric_loader_key = 'myset' - loaders = {'trainset': HfO2_loader, 'myset': HfO2_loader} - - with Logger().switch_file(str(tmp_path / 'log.sevenn')): - processing_epoch_v2( - config={}, - trainer=trainer, - loaders=loaders, - start_epoch=start_epoch, - error_recorder=erc, - total_epoch=total_epoch, - per_epoch=per_epoch, - best_metric_loader_key=best_metric_loader_key, - best_metric=best_metric, - working_dir=tmp_path, - ) - assert (tmp_path / 'checkpoint_10.pth').is_file() - assert (tmp_path / 'checkpoint_11.pth').is_file() - assert (tmp_path / 'checkpoint_12.pth').is_file() - assert (tmp_path / 'checkpoint_best.pth').is_file() - assert (tmp_path / 'lc.csv').is_file() - with open(tmp_path / 'lc.csv', 'r') as f: - lines = f.readlines() - heads = [ll.strip() for ll in lines[0].split(',')] - assert 'epoch' in heads - assert 'lr' in heads - assert 'trainset_Energy_RMSE' in heads - assert 'myset_Stress_MAE' in heads - lasts = [ll.strip() for ll in lines[-1].split(',')] - assert lasts[0] == '12' - assert lasts[1] == '0.000980' # lr - assert lasts[-2] == '0.087873' # myset Force MAE - - -def test_data_weight(graph_dataset_path, tmp_path): - cfg = get_config( - { - 'load_trainset_path': [{ - 'file_list': [{'file': graph_dataset_path}], - 'data_weight': {'energy': 0.1, 'force': 3.0, 'stress': 1.0}, - }], - 'error_record': [ - ('Energy', 'Loss'), - ('Force', 'Loss'), - ('Stress', 'Loss'), - ('TotalLoss', 'None'), - ], - 'use_weight': True - } - ) - trainer_args, _, _ = Trainer.args_from_checkpoint(cp_0_path) - trainer_args['loss_functions'] = get_loss_functions_from_config(cfg) - trainer = Trainer(**trainer_args) - erc = ErrorRecorder.from_config(cfg, trainer.loss_functions) - - db = graph_ds.from_config(cfg, working_dir=tmp_path)['trainset'] - loader_w_weight = DataLoader(db, batch_size=len(db)) - - trainer.run_one_epoch(loader_w_weight, False, erc) - loss = erc.epoch_forward() - assert np.allclose(loss['Energy_Loss'], 839.7104492 * 0.1) - assert np.allclose(loss['Force_Loss'], 0.0152806 * 3.0) - assert np.allclose(loss['Stress_Loss'], 6017.568847 * 1.0) - - -def _write_empty_checkpoint(): - from sevenn.model_build import build_E3_equivariant_model - - # Function I used to make empty checkpoint, to write the test - cfg = get_config({'shift': 0.0, 'scale': 1.0, 'conv_denominator': 5.0}) - model = build_E3_equivariant_model(cfg) - trainer = Trainer.from_config(model, cfg) # type: ignore - trainer.write_checkpoint('./cp_0.pth', config=cfg, epoch=0) - - -if __name__ == '__main__': - _write_empty_checkpoint() +import pathlib + +import ase.io +import numpy as np +import pytest +import torch +from torch_geometric.loader import DataLoader + +import sevenn.train.graph_dataset as graph_ds +from sevenn._const import NUM_UNIV_ELEMENT +from sevenn.error_recorder import ErrorRecorder +from sevenn.logger import Logger +from sevenn.scripts.processing_continue import processing_continue_v2 +from sevenn.scripts.processing_epoch import processing_epoch_v2 +from sevenn.train.dataload import graph_build +from sevenn.train.graph_dataset import from_config as dataset_from_config +from sevenn.train.loss import get_loss_functions_from_config +from sevenn.train.trainer import Trainer +from sevenn.util import ( + chemical_species_preprocess, + get_error_recorder, + pretrained_name_to_path, +) + +cutoff = 4.0 + +data_root = (pathlib.Path(__file__).parent.parent / 'data').resolve() + +hfo2_path = str(data_root / 'systems' / 'hfo2.extxyz') +cp_0_path = str(data_root / 'checkpoints' / 'cp_0.pth') +sevennet_0_path = pretrained_name_to_path('7net-0_11July2024') + +known_elements = ['Hf', 'O'] +_elemwise_ref_energy_dct = {72: -17.379337, 8: -34.7499924} + +Logger() # init + + +@pytest.fixture() +def HfO2_atoms(): + atoms = ase.io.read(hfo2_path) + return atoms + + +@pytest.fixture(scope='module') +def HfO2_loader(): + atoms = ase.io.read(hfo2_path, index=':') + assert isinstance(atoms, list) + graphs = graph_build(atoms, cutoff, y_from_calc=True) + return DataLoader(graphs, batch_size=2) + + +@pytest.fixture(scope='module') +def graph_dataset_path(tmp_path_factory): + gd_path = tmp_path_factory.mktemp('gd') + ds = graph_ds.SevenNetGraphDataset( + cutoff=cutoff, root=str(gd_path), files=[hfo2_path], processed_name='tmp.pt' + ) + return ds.processed_paths[0] + + +def get_model_config(): + config = { + 'cutoff': cutoff, + 'channel': 4, + 'radial_basis': { + 'radial_basis_name': 'bessel', + }, + 'cutoff_function': {'cutoff_function_name': 'poly_cut'}, + 'interaction_type': 'nequip', + 'lmax': 2, + 'is_parity': True, + 'num_convolution_layer': 3, + 'weight_nn_hidden_neurons': [64, 64], + 'act_radial': 'silu', + 'act_scalar': {'e': 'silu', 'o': 'tanh'}, + 'act_gate': {'e': 'silu', 'o': 'tanh'}, + 'conv_denominator': 'avg_num_neigh', + 'train_denominator': False, + 'self_connection_type': 'nequip', + 'train_shift_scale': False, + 'irreps_manual': False, + 'lmax_edge': -1, + 'lmax_node': -1, + 'readout_as_fcn': False, + 'use_bias_in_linear': False, + '_normalize_sph': True, + } + config.update(**chemical_species_preprocess(known_elements)) + return config + + +def get_train_config(): + config = { + 'random_seed': 1, + 'epoch': 2, + 'loss': 'mse', + 'loss_param': {}, + 'optimizer': 'adam', + 'optim_param': {}, + 'scheduler': 'exponentiallr', + 'scheduler_param': {'gamma': 0.99}, + 'force_loss_weight': 1.0, + 'stress_loss_weight': 0.1, + 'per_epoch': 1, + 'continue': { + 'checkpoint': False, + 'reset_optimizer': False, + 'reset_scheduler': False, + 'reset_epoch': False, + }, + 'is_train_stress': True, + 'train_shuffle': True, + 'best_metric': 'TotalLoss', + 'error_record': [ + ('Energy', 'RMSE'), + ('Force', 'RMSE'), + ('Stress', 'RMSE'), + ('TotalLoss', 'None'), + ], + 'use_modality': False, + 'use_weight': False, + 'device': 'cpu', + 'is_ddp': False, + } + return config + + +def get_data_config(): + config = { + 'batch_size': 2, + 'shift': 'per_atom_energy_mean', + 'scale': 'force_rms', + 'preprocess_num_cores': 1, + 'data_format_args': {}, + 'load_trainset_path': hfo2_path, + } + return config + + +def get_config(overwrite=None): + cf = {} + cf.update(get_model_config()) + cf.update(get_train_config()) + cf.update(get_data_config()) + if overwrite: + cf.update(overwrite) + return cf + + +def test_processing_continue_v2_7net0(tmp_path): + cp = torch.load(sevennet_0_path, weights_only=False, map_location='cpu') + + cfg = get_config( + { + 'continue': { + 'checkpoint': sevennet_0_path, + 'reset_optimizer': False, + 'reset_scheduler': True, + 'reset_epoch': False, + } + } + ) + shift_ref = cp['model_state_dict']['rescale_atomic_energy.shift'].cpu().numpy() + scale_ref = np.array([1.73] * 89) + conv_denominator_ref = np.array([35.989574] * 5) + + with Logger().switch_file(str(tmp_path / 'log.sevenn')): + state_dicts, epoch = processing_continue_v2(cfg) + assert epoch == 601 + assert np.allclose(np.array(cfg['shift']), shift_ref) + assert np.allclose(np.array(cfg['shift'])[0], -5.062768) + assert np.allclose(np.array(cfg['scale']), scale_ref) + assert np.allclose(np.array(cfg['conv_denominator']), conv_denominator_ref) + assert cfg['_number_of_species'] == 89 + assert cfg['_type_map'][89] == 0 # Ac + assert cfg['_type_map'][40] == 88 # Zr + assert state_dicts[2] is None # scheduler reset + + +@pytest.mark.parametrize( + 'cfg_overwrite,ds_names', + [ + ({}, ['trainset']), + ({'load_myset_path': hfo2_path}, ['trainset', 'myset']), + ], +) +def test_dataset_from_config(cfg_overwrite, ds_names, tmp_path): + cfg = get_config(cfg_overwrite) + with Logger().switch_file(str(tmp_path / 'log.sevenn')): + datasets = dataset_from_config(cfg, tmp_path) + + assert set(ds_names) == set(datasets.keys()) + for ds_name in ds_names: + assert (tmp_path / 'sevenn_data' / f'{ds_name}.pt').is_file() + assert (tmp_path / 'sevenn_data' / f'{ds_name}.yaml').is_file() + + +def test_dataset_from_config_as_it_is_load(graph_dataset_path, tmp_path): + cfg = get_config({'load_trainset_path': graph_dataset_path}) + new_wd = tmp_path / 'tmp_wd' + with Logger().switch_file(str(tmp_path / 'log.sevenn')): + _ = dataset_from_config(cfg, str(new_wd)) + print((tmp_path / 'tmp_wd' / 'sevenn_data')) + assert not (tmp_path / 'tmp_wd' / 'sevenn_data').is_dir() + + +@pytest.mark.parametrize( + 'cfg_overwrite,shift,scale,conv', + [ + ( + {}, + -28.978, + 0.113304, + 25.333333, + ), + ( + { + 'shift': -1.2345678, + }, + -1.234567, + 0.113304, + 25.333333, + ), + ( + { + 'conv_denominator': 'sqrt_avg_num_neigh', + }, + -28.978, + 0.113304, + 25.333333**0.5, + ), + ( + { + 'shift': 'force_rms', + }, + 0.113304, + 0.113304, + 25.333333, + ), + ( + { + 'shift': 'elemwise_reference_energies', + }, + [ + 0.0 + if z not in _elemwise_ref_energy_dct + else _elemwise_ref_energy_dct[z] + for z in range(NUM_UNIV_ELEMENT) + ], + 0.113304, + 25.333333, + ), + ], +) +def test_dataset_from_config_statistics_init( + cfg_overwrite, shift, scale, conv, tmp_path +): + cfg = get_config(cfg_overwrite) + with Logger().switch_file(str(tmp_path / 'log.sevenn')): + _ = dataset_from_config(cfg, tmp_path) + + assert np.allclose(cfg['shift'], shift) + assert np.allclose(cfg['scale'], scale) + assert np.allclose(cfg['conv_denominator'], conv) + + +def test_dataset_from_config_chem_auto(tmp_path): + cfg = get_config( + { + 'chemical_species': 'auto', + '_number_of_species': 'auto', + '_type_map': 'auto', + } + ) + with Logger().switch_file(str(tmp_path / 'log.sevenn')): + _ = dataset_from_config(cfg, tmp_path) + assert cfg['chemical_species'] == ['Hf', 'O'] + assert cfg['_number_of_species'] == 2 + assert cfg['_type_map'] == {72: 0, 8: 1} + + +def test_run_one_epoch(HfO2_loader): + trainer_args, _, _ = Trainer.args_from_checkpoint(cp_0_path) + trainer = Trainer(**trainer_args) + erc = get_error_recorder() + + ref1 = { + 'Energy_RMSE': '28.977758', + 'Force_RMSE': '0.214107', + 'Stress_RMSE': '190.014237', + } + + ref2 = { + 'Energy_RMSE': '28.977878', + 'Force_RMSE': '0.213105', + 'Stress_RMSE': '188.772557', + } + + trainer.run_one_epoch(HfO2_loader, is_train=False, error_recorder=erc) + ret1 = erc.get_dct() + erc.epoch_forward() + + for k in ref1: + assert np.allclose(float(ret1[k]), float(ref1[k])) + + trainer.run_one_epoch(HfO2_loader, is_train=True, error_recorder=erc) + erc.epoch_forward() + + trainer.run_one_epoch(HfO2_loader, is_train=False, error_recorder=erc) + ret2 = erc.get_dct() + erc.epoch_forward() + + for k in ref2: + assert np.allclose(float(ret2[k]), float(ref2[k])) + + +def test_processing_epoch_v2(HfO2_loader, tmp_path): + trainer_args, _, _ = Trainer.args_from_checkpoint(cp_0_path) + trainer = Trainer(**trainer_args) + erc = get_error_recorder() + start_epoch = 10 + total_epoch = 12 + per_epoch = 1 + best_metric = 'Energy_RMSE' + best_metric_loader_key = 'myset' + loaders = {'trainset': HfO2_loader, 'myset': HfO2_loader} + + with Logger().switch_file(str(tmp_path / 'log.sevenn')): + processing_epoch_v2( + config={}, + trainer=trainer, + loaders=loaders, + start_epoch=start_epoch, + error_recorder=erc, + total_epoch=total_epoch, + per_epoch=per_epoch, + best_metric_loader_key=best_metric_loader_key, + best_metric=best_metric, + working_dir=tmp_path, + ) + assert (tmp_path / 'checkpoint_10.pth').is_file() + assert (tmp_path / 'checkpoint_11.pth').is_file() + assert (tmp_path / 'checkpoint_12.pth').is_file() + assert (tmp_path / 'checkpoint_best.pth').is_file() + assert (tmp_path / 'lc.csv').is_file() + with open(tmp_path / 'lc.csv', 'r') as f: + lines = f.readlines() + heads = [ll.strip() for ll in lines[0].split(',')] + assert 'epoch' in heads + assert 'lr' in heads + assert 'trainset_Energy_RMSE' in heads + assert 'myset_Stress_MAE' in heads + lasts = [ll.strip() for ll in lines[-1].split(',')] + assert lasts[0] == '12' + assert lasts[1] == '0.000980' # lr + assert lasts[-2] == '0.087873' # myset Force MAE + + +def test_data_weight(graph_dataset_path, tmp_path): + cfg = get_config( + { + 'load_trainset_path': [{ + 'file_list': [{'file': graph_dataset_path}], + 'data_weight': {'energy': 0.1, 'force': 3.0, 'stress': 1.0}, + }], + 'error_record': [ + ('Energy', 'Loss'), + ('Force', 'Loss'), + ('Stress', 'Loss'), + ('TotalLoss', 'None'), + ], + 'use_weight': True + } + ) + trainer_args, _, _ = Trainer.args_from_checkpoint(cp_0_path) + trainer_args['loss_functions'] = get_loss_functions_from_config(cfg) + trainer = Trainer(**trainer_args) + erc = ErrorRecorder.from_config(cfg, trainer.loss_functions) + + db = graph_ds.from_config(cfg, working_dir=tmp_path)['trainset'] + loader_w_weight = DataLoader(db, batch_size=len(db)) + + trainer.run_one_epoch(loader_w_weight, False, erc) + loss = erc.epoch_forward() + assert np.allclose(loss['Energy_Loss'], 839.7104492 * 0.1) + assert np.allclose(loss['Force_Loss'], 0.0152806 * 3.0) + assert np.allclose(loss['Stress_Loss'], 6017.568847 * 1.0) + + +def _write_empty_checkpoint(): + from sevenn.model_build import build_E3_equivariant_model + + # Function I used to make empty checkpoint, to write the test + cfg = get_config({'shift': 0.0, 'scale': 1.0, 'conv_denominator': 5.0}) + model = build_E3_equivariant_model(cfg) + trainer = Trainer.from_config(model, cfg) # type: ignore + trainer.write_checkpoint('./cp_0.pth', config=cfg, epoch=0) + + +if __name__ == '__main__': + _write_empty_checkpoint() diff --git a/mace-bench/3rdparty/mace/mace/__init__.py b/mace-bench/3rdparty/mace/mace/__init__.py index 711d144..490c4c0 100644 --- a/mace-bench/3rdparty/mace/mace/__init__.py +++ b/mace-bench/3rdparty/mace/mace/__init__.py @@ -1,5 +1,5 @@ -import os - -from .__version__ import __version__ - -os.environ["TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD"] = "1" +import os + +from .__version__ import __version__ + +os.environ["TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD"] = "1" diff --git a/mace-bench/3rdparty/mace/mace/__pycache__/__init__.cpython-310.pyc b/mace-bench/3rdparty/mace/mace/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index ad042d7871652a5f00a02c0a2bd7d6b845651629..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 285 zcmYk0!AiqG5Qb+r+7`7wgLun9gLo~X)M|y2s40lK3^BWs4asiU#6o; z@Z_YuI56MLKQr);8IKQu;_G{zzf=9`#sBa~TvIlu1ObATkbvOAD_O-7#z5>zD7e_4 z&@Kjlu%SB$H;m{rcZj%>woyhSI(CtqETiY?9PfxT3?qD*J>1WiF-GCy6&KOXt?xN* z8r$2tq0}p7jb4KtlHTI?DacJFgDU;ZQ&Xsvx=fzc))X>pgSu;T@|wnvyZM{cSsq+k kQKz=;wzC}&l~yg{dEL3uq%z{8lvn-J&44h1BYw>P07%hFSpWb4 diff --git a/mace-bench/3rdparty/mace/mace/__pycache__/__init__.cpython-313.pyc b/mace-bench/3rdparty/mace/mace/__pycache__/__init__.cpython-313.pyc deleted file mode 100644 index 7895a9a9967da9436428f55384836e3af3f7eb44..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 315 zcmey&%ge<81amD{WTXM<#~=<2Fhd!i&47%l48aV+jNS}I48csn%-)P%j75wJ48bfh z3_&a~4G21zRUODlXVheS2@>>s$p|Dh8E#mC=bG`yt{;veMf5$^^h zUE}@yg`kg0)*$W*7nK#~=Y)eG7|H% zG82KUncyyOghPO2Tq;$jvc!NS1B1OTxHIpzQW diff --git a/mace-bench/3rdparty/mace/mace/__pycache__/__version__.cpython-313.pyc b/mace-bench/3rdparty/mace/mace/__pycache__/__version__.cpython-313.pyc deleted file mode 100644 index b95961fd2994bd9f45690e25ef385f17125df39a..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 224 zcmey&%ge<81amD{WEcVI#~=<2FhUuh`GAb648aUV48e^0j75y;Oq$G9YzBJ9dWOcg zxZ~r?Qj3Z+^Yh~4{WO`1n1G6JvB$?J=H$f3uVnZPGUS%0enx(7s(x-_RYqcdR%RlQ zm7J6dIgoYIBatBQ%ZAE?TXleCWG8q dEC?h%Ff%eT-ry5!;C{d@(V^eSUc?F%1pxe6JFNf! diff --git a/mace-bench/3rdparty/mace/mace/__version__.py b/mace-bench/3rdparty/mace/mace/__version__.py index ee8fdbf..343c940 100644 --- a/mace-bench/3rdparty/mace/mace/__version__.py +++ b/mace-bench/3rdparty/mace/mace/__version__.py @@ -1,3 +1,3 @@ -__version__ = "0.3.13" - -__all__ = ["__version__"] +__version__ = "0.3.13" + +__all__ = ["__version__"] diff --git a/mace-bench/3rdparty/mace/mace/calculators/__init__.py b/mace-bench/3rdparty/mace/mace/calculators/__init__.py index 7f5a559..8511eb9 100644 --- a/mace-bench/3rdparty/mace/mace/calculators/__init__.py +++ b/mace-bench/3rdparty/mace/mace/calculators/__init__.py @@ -1,11 +1,11 @@ -from .foundations_models import mace_anicc, mace_mp, mace_off -from .lammps_mace import LAMMPS_MACE -from .mace import MACECalculator - -__all__ = [ - "MACECalculator", - "LAMMPS_MACE", - "mace_mp", - "mace_off", - "mace_anicc", -] +from .foundations_models import mace_anicc, mace_mp, mace_off +from .lammps_mace import LAMMPS_MACE +from .mace import MACECalculator + +__all__ = [ + "MACECalculator", + "LAMMPS_MACE", + "mace_mp", + "mace_off", + "mace_anicc", +] diff --git a/mace-bench/3rdparty/mace/mace/calculators/__pycache__/__init__.cpython-310.pyc b/mace-bench/3rdparty/mace/mace/calculators/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index fe16df03b666be8e459a7b528b576eb91ad3b9a9..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 412 zcmYjNJx{|h5ViBsq%CDY{D};ZfrSAfs)~(Pm9kkvma$!u%8nx^9m>yP>%hpL@yf)& zA7J9_O2Aou_wJqjbheqzCJ5r~{ZM~F{_w^Bh@iNG-CqL;B(Opn6RaZ=>6paWYmrKH zO47h%HPRW$FnUHJdBr3bX@!c>Cx+`q>?WFVNm<))?y|sj=kh?zW>a7{eV8wo>xwSt zi+dso6`-%(h{`xV-4B*-Ac?-F~P|OBdDps6WY; s*X5NJ9kateXh-0`kLXKEo3T`f|bb?_;KSj4|3;+NC diff --git a/mace-bench/3rdparty/mace/mace/calculators/__pycache__/__init__.cpython-313.pyc b/mace-bench/3rdparty/mace/mace/calculators/__pycache__/__init__.cpython-313.pyc deleted file mode 100644 index ae361bd38b101f99adba5713cb075aba54828ae2..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 418 zcmYL_y-ve05XbF&mR1FciLnDq2FSny5=E*imXsnO4`3PBC8^>#a?+t}eF#Q&cpO$s z6$29k5~44_+2!Mg|NZ`V_W9&>x7$V#k0(#r2jn+ne#7~J^)bK$N>PeOD8ULRj&c%L zxjj>JMqc78-?BRjl2C;`6ryfM5f-+8MnKPzC%c9aP3vxP*uHagD(J!68bn6rrn-}~{?Z6Y}Tk1WKPZwkSiuy08ze4*f)cf*+ Kjjq diff --git a/mace-bench/3rdparty/mace/mace/calculators/__pycache__/foundations_models.cpython-310.pyc b/mace-bench/3rdparty/mace/mace/calculators/__pycache__/foundations_models.cpython-310.pyc deleted file mode 100644 index bb99d2344cf448d25f3fcc8dad6b1338b701a038..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 12018 zcmd5?TW=gkcJA)!xesq5b+cr^Q^Y#4i$?^N~ND9Uy$ z1=t~`SY1jvs77k$rm0g~q#CWoHe<4`SL2&;w3*dJExDPLZP99~mflRG9^>(9rZ%)Wq-(ae zGrT`yj~?ioW3qn29^;7vV{@E8&69S_p0LfbaaQ9gd!lUeG|zk%-8{*4KE#Ll2p{ER z{Dhs{IW?j2k#b_-j&XZ}PvCpfKFv?z|7m>B*k|}zT&L`z@(KPJKZnuICiUb{@_cfH z|LoZNJN;AqasCAQKbCwFcZ~37IL*ND?es|Sr}!j#oa6esmVNdc+R4?d9++#EU%~(A zJ9WpcQ>9sVd=HhO^=56+s+O8n%Xh^$*klc@!z&Am*LuyZ=gV%h&MhA|d&Qc|?W&ir zStWa>*05#@Gc}t#&00?W5Jc6#n6Uk(sF!HO&L@e=FusfUde7pQHXi6LO=Ns+N8dO6 z$bpV}WF3!{N4UwO^q@F4m8WRq)s}xbw5Q{~QQSAx*7wb}*4FvhS?$26qkYe48|wO| zAujk)e&SPIryk>Hshp7I$+KD|vKK+m2u67gy-ratDMgH*KG63}er7+;&+^AUQ>}@% zaiA}0pX&H6+BUz?__=oU3+=1OUQEf}jveUUU!$IEoBTX~{Br~2MfnqL;|t>}Lyn%> zN$+PcXY9cEg|QdkOYkS#3C#aj{t##Ar%}IO>A35uc1&GC@)#cEq&0)hN-=zv+}ri4 zYjKacf>qs;<=YkDiHnXzEN)@|wv*9{*pOu}eolD8v7R0u9Mb>B#`GnY%=Zk6P zHt4BZ)$A2UO2T|s4yU?c4$o=X%xl;sr|e)h^v!k0kpkIet6B9Sgv(aly3KA^oKl4~ zy-+Tt8*0i>m>W_v`s9HigR}E+l{RNnkfCH@ReYK`W!4b(Eyrzo)m^OK;x^CTzhV+A zYpfn-)Z9u*Th)s1H@qwP{I=s)np?S&Tgzk5yr$!KbU82Vs%=5?d`AHJOR?4{7V>k2 zx%rvdxtUq$JE{Kc+>0}B6^eG<7Tdc;yWy59GjkU%t5eXjAJJ@u`M%j^!)azOccyvw zA*Ug&DXy&fVyB4p&E-1dzV~qBZmE5({@x{h@Q_QI@0_5S!WI;^&mFWL zN|P*{y9fB0)GS!c`u$J>f+tKF^A80cyzhQfrMH*&If~%RvmOUdUEc>A{ z{h(F)sZ)tLuVmE+XUh%+QK{cSOg*C@Ds|&9p|rlRb(jp*lB%7?xx-Yqb?z{!)Tyx5 z-#jk}#vKlY{nVx&j;UwYTYKSfqQ5DWv5?KIiZt#Q87hXTAebOVs2HVUjEWOf5FnJ3 ziHT0F^_TB>PJNrH$=S-v@LDye#Ma%i zf7`;0Z#gAc8=FlntiP3IRdqGZJlkf+=;5}t<%-!n`mz+;P>b*@PDhemx5=zh$!_@4 zjr+E!0Vz4Wyv9>*Yu+xd!HvoT?-V!1nxC8wTg z2zSRW`T54KUvcZj#q~ASzGc@-mHfQm4NLgDvP~7GE>`vOeMch%rMbp#kbG$yKKV{V zyu3~;AeAycF|~wl+#A)+d#zl@w-;&g!K zC1sFL>75&IvBQCuwq4~5MR2ZrkfW&=Pih&dhsykJy{tTv!1e~E0{eh_U4V89@S zov`X0sxAU^2e2?mYysiat(qMqYSvAdYT*U3rl?k(tspM!kD9jU2Wd0_zzfH|We1rX z&ARW@>}x`}BAXNmY-wPkhZhX@_O6JA&2i{u-+oe?7bK;w6){sVKG?q~st>LAXoS_AroKHbLRol-E?cUfxs zZn+$cDqd9-)@`+6C3+m44_QpI#Bbv3t${SOrWD8Wz_HrkFZ((NlM!cpgRb_HeuPKb z$qw77kNDtI=oQ65f5wlsH834h@e**WL+;gQ@x$~mBC|s zsgF~)_1oZCjU_+1liE)oXkwAa>pD-sd>A|lPPT#>Q<9SzJdNvjaGjCYaq&ysm1!qo zV1_<7)-~{=3A)0a!(T?)iGS38qwl5L5wF%xuWPLY(eLVTZDcQ_YoTPLyT9&}Z0x9H zm`m|NHTU0OzOlN2^2FXyJA>Kuw%HzPN82$z{$(_rao8Vkr`yBZU}xZBC!29>xTHWh z>Hvn#63CAwPil`bv)dpUXa!yIh2_*iWR$0%!^fFhmq&dC>`geZH*8JWK{oZ`Wd;sH zzQoHaqO}2Xr|yN34+#nnorxrV*Ec#&BgfW{cN#}{4uhv-z~pLNPtCLH`gV^ZUN1B0 zO5mTG(z!|>Q*PEvL;-td%kgcw6cX%~1=88Dpww~#DcsQ&PIt+pRo@LeTb}2Z9Q5RB z+d6bsi8sw)zaXurh-ulaR^8h;D?M3fy1spDmtC*!f_YxGd~j83nypows?%t&t9Bhk zZJK@1wCZfJ;#l?VX|_;x>TD&)u2o?Wdy+|7d-vd$!`WJnz3JY%drU9Cxg)7U%$d;3%QFgT)y-|zEaN5VvK19-*Kf28Y*hM zGKK@$&6jpX)7vd6|BCRBlBHir%a?AKEjeDv#m~DW=>4m{QQ7s_wFUN?>vAa3_wR9c zeA?87YNKMY1@@MW1uoQ^TNUdTCUBrnZwk0qJdFjmhCT8~#RhhDzVKph_QK`4!o$Qf za%}1NXhj)L&{u&3k92oHYD>mh!own!5eiyo9Ur8^FCCPrPO8KLWoO%~0N~+N%J9i9 zo0N!pa_~`3Dl|Ds0i8*SttSbbc7}%$k$2%xp@aNgg!W)?mafdo-P|AES(0w31ANH@ z&_R;|4{pq4smZNLinds{fKse23M0|;qtAd7J8lvg4r0>BbaGT!{Qix)cBfQ<5wd6Gz=*!mS?>q|8$r+K z!HeuYj@~Q^Ah0R9u>*UzD4_!r7Q(b1!zsJ!R+NqY5i=ac8{hwIRBz=HtlOcYVkxk? z3+vYq$?Jrg2Y2zvb9(;yo1jJrSt2Isk;BAOC#c=$V96$oOc7f8ICa{M`4?ShhSsCOM;gKo6dAWhuwr!6WNm$njEu z-DPT4qbQ^4IE7zFHq+8zNjBw)m*|Q9oC?Zifa6Xar+2NLH5Dar0163PA)SJVFfaR3 zdnwj%yupHLI3Bb$%5~eHLEzE0FjBrZUbgk3+Op&1ko)(*W1Wqv<-CHG1W`pq3=oZ& zG9Hp+Wha6NaBYynI)v{Ll2uISdD;UC4+JsEhP_~H84MJ_ahdc`(!(GL!-nVs_EF5q zNu=Cjnd%wGD|THFc$ID||VVR37mEon{A8Hu<0rFPbhF#bUIC884;N4BZ^|Hp{M4l_wn#DsG3B@?yIheBsv`u7RZzrADWpX+xM zY@bJv&{Ao*C)?5e2-0a$9`B~pV!pW(-H#z7s_(^RJx-~veI(@$be=k^?IqfAki>L5 zA=6<=*(ce?uU-lDR69MYI2dbG5i z+|&38BBNg!dnrHNmISgsphcM(jJ?)u^BB0S*C+sn&y^U5)W#Enpl)MR!% zL8MS&C-f%;d8AAa6afs$DtTZ-Imw1X9yC8Ymz%vfH#_^sj@V_kBHTd-^2i*TSA>vO(g1u7~!a;dx zvgUm`ibYu4)(<8$p)L8Xs!J(Rr!Lud5lwURDlsIFbAN{P6$JKvj;NCT<3>jB?;BWg zOO6pzQlQU_(L1q*Nb|8TFTD{5Enp?20oh;6@#HQRu42l3cVqrB0kp%+7 z0x>RRfe?wOtV$~%0{@=6t-QT-?X6?*uY&wyiMH$-idBiG#A|d-fnEiIu2YQwwRnSy z6)Gt6l07SKP@S?NfytecFS2y~3>D8(@h4Qwpg?x8QFVOr95v*qkdQA=wE`8hR1g9b z8&teQ#k*7x)D!PfK|oKuPsImR5WLF{9ff8D!^EFbF;B$>DlSrSiHa9c1QCk+e@K_V zq=HtW3Yh^w>H@hNoTL9hyFy=QM}d%lroxba?7Ku9!F|K{E)zWkT$RCh*iaaZ`cMd; zFd7B;C0gslY9`QIBJA76YmqKiiwI8A$d&#AwGuph_d^RN2u8-i`*#$GX6sGlGCS`$ zA^vXc>AuM$5<1YEPti8sEAlqf`$tsYi=ZBF%Xgq8bTWCYtpjGozmj($o^H0r>Ak4W zv^~V&cMxqyhS{Qb_n^M(Ov~`r8!leLgV})u zixtW{qrQd!u0s)R6=59+8gtjlxnety*z#=Q(tQ46t}uJ?;{4p4+V58O{-kT(syi18 zg_}j`5sHh8;j0brD+5Q{R&txjbFxykDe1dHtM|t?29w?o6h(QD*7%AHcl6>Qmd{$F zh%1FJk0F+x%jH_7AAocD0T-I@QO@HM%nw-WfCV5nN;x?)xZah#!Ph)R}RZ6wrq#1Pk=)R?TyTNG7$UP=9a73Y_Gw)?ga%`g!3LhklSQ7c` zdzr{ELJ1T;FyZ15UO=1p2)AbSRf4gUv6h%&8F-Dn*y?ijyxfFf7{cI1fL!<H50hGJ3FXgu0~F-3*MERi0g1si!jYHai(EEK0Pf1kqJA{2uGQuxj=jc<*pfyj$-XvB^G i#y|dmO3=nCD!;3W#34{-~KQ4ANhO$ diff --git a/mace-bench/3rdparty/mace/mace/calculators/__pycache__/foundations_models.cpython-313.pyc b/mace-bench/3rdparty/mace/mace/calculators/__pycache__/foundations_models.cpython-313.pyc deleted file mode 100644 index 2fd1d3854b730e59ae0f75ad30a93cd31a017eda..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 15702 zcmd5@Yj6|Um2SP?vTR9yVW5$DiLoWyvH=@nFn$0A+l(6xp4c-=TWVPXS<-Y%fR$uI z&2H5~_5o~m3i4wsP?fDPRa+JEE9`D+ce3nM=GT@Kvt@0^shX-xwrYN@2qc?z{_Huo zyY<4rytXPAeEat8`#AS`&iT&K6PwLULHOIjd$Io>q^MuvixQl2;>q__6!k8}P>kvn z#Y;~WuZDN^Db1vo*V34$Ii;J_^Lmn2oigwS$kU!OPMUZV$y*gfKWQ+r9DTjA5TN%@We!;b%?N{p{ zlQA#2`n8OuojR{Quj1W|im@^_#?CY_4yJLzG}~3(mXz*>ij2MASl3e1f{Ss&pKGCo zal^A2{#q7VnO1l=FWCBnPbp($I^;52 zmQ-H)6C`rls}i+m*h~zbx?nP%N@60LjAsN$*qH3{D4Pgp6Kp2Me}a<=x%R1%(eX-l zZeKK&O-9%ZR2D+ZsR)-4{L5^Z>t0T?-96pQTqK@d_K^>wPI~BhE|cYxVR@OzblD-+ zz~7Unun6x`LD)iv;Fa}J)I2GG(X6~cz`^QZZ)6qAab?mQ(>94nl{EqJpYflhSVGIj3S8kpyM)%s`!_dTxi5J2VGlut1;i zCk3WSp^1hT87J0-r%Qgi@!6|MYUVmgzqNInD|KHoD>k{ist41e)J|%d*4c)tPjV8s z-b>nTe706lt+eZ;2IS95PEr@t+9{?zsIyZ6O;Z1o)^YH4A*ff{Qr0lCRI?|V>MFpi zhYD($ElkItny#MttwHr+HPBI}7rw%}N@2EHNmZV+ZI!f=n%k|cbQX4F`ZBXEXe28I zJ9vJJQi|Cgtm%nURqX&esIR3WsNYF>cd(6cAk1T_OUXovjRp%*CYBW++iV_cNYCNj`*ik?m-Ir>s89*)sjL8e7I)lzprz$`h> zG7ZpqZI2OoL|la43;l*!7(t|qjz{S<&s~hCvO;17M#@IGi1*pidDYV!R;fSfO35W+ znM_(3^81(KnOJtw7fvnvVQz$MJfqw^ex6HkEKuuLIwxO3%jr;$zptk+(B0eD-3zx9 zUR1q(2fI)Agt#QfFRg^QbSfO{?&}|v?yMa5CH2-5sOqg(?xuH8>E^XB*bVNE(A1d> ze?A1`>+>n?p8I0$E=u#pE+si39uk(>L_(gIxi7QfYA%sm`hqR@NxOg$+a2MOLOiq5 zEsdZXHYX@+Pd_+zPM4#>>K*wpzrO9<{x^QrS(SZ3RcTrwQ<2lz;u z{18qFa{P+GsczsSe~Ld9M=c$*KouPtzOjQ6fH7w+{!vYZuiiLDZ%}&=d7>0R=gNSfpXC#Y z_@Zdwxwo>MkP$7A0qPqc=Pq)hbvB#K#Fx2oo=@>!(BENtL@ksMM0;h{Ly#dlVV_s6 z1m+2%iR|4F^dvfK%ZCUZjgqWbH3S|(0MUrYDYPss3D_eFtR(O>e%zwWJt8-6r2x%SGb!YilN8@A_-+e>8QHc zv{G!07hTImcdF=qyQHRCM%0ifHosMDeXH1(ft+n8)e!HPQ9s^dx9YEqm3C6r#**4* zH9yhVt>%)3GTE*huNiN+Z@*b+?RsSFF42_TS?Z*8&eC4W*nB%#XzN}zbgvt2(2(t% z`l&{3GM9`{?`o)E+*&kR^ESF*+(8nZ1>??=LAynt*S3_*P<_?VCD-Jnnoo^Xqw|*G zreV$DDL6cLj;uNYS56k&JHEH)yL%qvW?y%@ZymUKV9nWCaCY9Ae&h^1(LnYSN-Z_3 zr#g1se&x=q`8`AV-G}n6hx1zZ#}?c51J@2*AHFu6Z+~UY(p9iPzvk_ksMbgL(VlBjeCVmX5r(;|VMw#kWHw8ls>& zWCDE?qI&VXr~$?d>LG@)qoO_n`UTvJHkoyVA{jV*qB#-=y@wY-{D7Fd}ek&X;f7Af>3Y zvv@z(1*WW;Dt9mogPbp`#g*Dg)4V~c%xHq9T7`1PKVOC_Ds^VdAcRtLlRF>P*fMw! zq(FJsRupCq!3`6PP4yL(LSbYN96@TntcK5)Rq9GFv&c#7LOr)(fm1Ms85Ls$*hS5l z%Fsz=oMvV206I}KmQ6}nt4bIY3R$42iauoMb-_c;-DGTo>Yddpyld_z^kVmEg2pk` zepS*E)F>??Hw;>K!YU&Ic*;K6khH=aR4A~npe`*}r8cSC_@cUAclLWQA!guvpmk-#fF`hnZB@WS`VwgH zkPA3KhG*kRKnNr;3^`GRP9+K7n*s=gj|1DL&q(3|z4zcC4GKOAYJ%5iJ`3oGR1c&l zS@Vy`QW`-{PM%eyF(1urE{)Xz2sDV0G$^}LyMoU=oyu?$#&R-B6WI&|zATYeiO?3! zCc}uKRQPs0!x8e4po=V^HE9+u9nvr%s(w^m1btzXD{?_r5K`edl!{2xq+mc2NnJEB zZ=g_+P@QNhkw{&FdrYFSyDTIxuFxkFD*)RZV>1BquwC?-SUeF=r|Dx{5)h9r`t>ZE zq(@_MHo4SAk0jzrddf$SCqNRAC_>gc_VDh-c!WOVqfezSKD-+iy68zZnG$03Y@A!< z1TK;2qPYyv+a>Vq|c!imaKFLv(MCZ{Wb-{saD4)ZYvByJ!%UhRPCmOlm6% zO#xEm53lf9VI?eypMZ2rbST8J86t$zV{svzf=??*;Ikv0jjagu_y~PGm5RV7dGVZ* zws8_i66qLAkI<(%7}Q8IyBK3HLgR7RGpBfvAz|i6=rb@c0&b(4Sq$_X^!4@+_Vs)o zu+R5AI4=qkUPeOP1l*TpN?8It%A22TG}wGeS`n-yJW^4|q?1R4+-Q7Bh=C>qr;Ol= zE3}8Go}S8*c_fv~gXpG*D4-tH;FPA}G9WeJpg;lgotMvcZ4ZQ2iM>CIca9OZuV{vN zl=1KoLTe|(6V66h4?>abMG(8$MZ{Bat!e~@%{n0XC{5&v@{!)G16exp$jz2z3{X>3 zDcCtxDDAlz2V?@Rk}fzFper&69ZvB)Xi9JlgxVoYz#X6vEr~LOTtI!#!kA_0hE}d8 z&~%@@RiLLz^R=ye7SDUg_{IWzk4?f|OD3m03XqhZot*K@;TT*dTsLVHj7NBO6$0SQ z<9a-QaP>V%!WX@lOk~Y`jS{hp=2@^X2{fFZW2u;QzrLjPdO7aJYax}Eu5$L0ycuB# zaBGe*;{azX7P8uzt6G+ zVzZ%_K#kmhgpEAQril)Wt6fXA#LhOF=wu2^7-&6|uA9WlhHq*!U6|ZJ$yCjyMCwh8 zT0Fp_CY;X7iopKkpYf?Cw4iWlPo6u5uAx{RrMI?dvT9ZRseB^^hg%x`8`gJ`RE`l zSM@+PrvhLI>oS`T5pxlow)Zrmg1dM6bOYkfvCxvg(w;YWP@1yM=pgg=+#~GONy8b3 zk%&4;p;Lpxrk5^?_b9JR)PTS#nqd%pMnDmhsQr95PJjlh=?P{dh>a6~lYnG5ft!$k z7@`SoFTmVj5_wcmL>(cTpTe{?E`-W_Le!$kiZ|ki9i&hQXUbPe211p)fqxClSxzpe zQ+$T>M0O581KB)UtGx{bJ*ir%G5mz1sXj?oE-M7?yIh~~-?iy985z(f!)?38rJ zpTp9Ih(w2gXe{2|RY3ww3^Gv0pNA40{0Vm@aO8}tS@#-f>5dX*bDN(~0AxIEr|gY4 z+<*SnHJhhk^V~VOX6q{0y4KtJ?*~`g4qcsGw>IWoCqLZxVMo5{wg0f5``FoXTlbwq zYtCH-=dLy9o`Q4Fdh?!p1NZjidq?wo$5xxi^N#Ui+qP?yfOFVgYqsOx-}QriclWJ% z1`D3SHP4}f=g@cc@4K%|6b-K1rd7iZKo^SE6$u*AQ%1LviIB;(1bqC?NEut#8(i0O z-_G6MS7;c@8;3r2Yy(uGQD1azd+^3%4W_2glIN9^P~O^HY-r6JTT5D<&iPB@dNt_O_a6mwl@#Rj>VVP`hNeEQBhX+6G{b>yKZmKZ#lJUKb@Zq z=8eHWlr)(8ho^m%wdDz=(K(?&(Rvv$kp}Ze4NWDT+T?^n@L@y1{rS%DxSsl{$_Meo z0sZ(+>I1!be23-(hjqL|^FfCW(>u-MM>HS!tm8wP4~BG*4r`&N5}E zk_whvEaMDIFo%^Mdxwx`cHElN5BU3?Bl46<=R|=>BhbZ?RtipC?y64M&T~+ASj%RzV$B-h3 zUUe$v2zuqLMz1zSu9V59X7?#IE6-Wv8g-$*SAtwM19H^@$W<95BB)cPZC_m@@CvAt zifKjE32>&iGMZJM6Q;egPEcaYCM7y5CFXmT9++)&h@fo+g|QkNXq4gSSM3Y~i5Ndkd|AgeDRxi*kvS?%L<1S}n20qBa7d0#aKa#h^ zXKeXpG!vYtu4S*L-V*^)a2`GSMAZ?Z2=T}%&hmIG^1C6*{awvn1j-KBIYBi94Nx-q zg9iQoWcF9Cz)xfK&rRZ$2$y9IN`5?7i<&lo&&zRu`%tUp z0ayk<2C)6^!EyU#4ZgC~re?;*Pj8~gO1e9L3a9KeMAJkw;-~R_79*n2oWT@o#{8=o zoyF*@5P92pbO7ZSAQH8aco_T?@g2D&@5KnYB;NxO__w4J@eGeFnMckhsc(Il3SiWa z5o&Y%8yKOF10TW&xikMYj8JLgS&SAjLOIH7t5>s5;3w32`283izzF=?DSikebbSyt zU}@$-5~KJSL;}j0RZ0|1OKKDt&B5OQ&J&-80{@OB8$jI#BXi$QNsY?8_ODv@|D^K) z_fPxQhGq&wGexWGe>Cbt=HE~dfw|MwT5L`iTi$~Bu|eyyUmYviDOX#*V|dj$4EEzj z_bua1<8Akyxx4m94*y5)_L70JZFy>@+O~Yp@Lj`N>%KzkzI#VjTMu15x$fEdgFScm zta&aOYe5@4s2-pU5{Kf8;o^?r?+l<^CpW zUyV&LkjCbs>ucqgmRt6l_BF@8f@9yk;YW@sD4uWa&ijt!yN~9ZM;L}=uH#+}v^3$o0Ff*>bU;=-k`uJal7DXf6RcIK<8wWo6^a>>oX}hdG zwmq;(>f^~7S$zZ-0joQ&b^ZS7o|kBL8^Si&81Ia$j{B$&v|SMY%x6Af(){e8^@LXQ zp;iZJJ_<#`xV4pl00F)~*#;X5oa0qtKghYt*U#Y=4|Zlmf*>SS6-zOrQDi`n;ibRJ3A1IJwUPs)byw0#YHlQ`W6XmZn3YN|7G<<(i&jQVWd5?owAn9K zK<0PImgcgK13BX=Ze+??{ z7njsWrccs)Pju0vUG#JpJ@XmOtY4m>zL~*-0Lw&Y07DwkDGNp~vx$1fFS3cOWRH+N z=}1cxfOB&6M4U@R=n()a!EzghAE{8KES58p%7wru^Zv#E{?IsD8#yvu`yj|U^J zM9cAqcfrYz9u*$mO|Hbdh)HHL1;2d(?h6!*#?XNf(r3V+7)PtBWc8~tU`A4LUy5Jy zgPpUtXMeyy;Oprf7zp(BnWuB!XWilaY%)I3({mw2grd;ssO+CQC-Hcku5c;~&Si8s zktI;HqwhfXX?i2y|b;^w|uZbAU}m_WlQZuFvPogNH#~N~sc?PIcpJxmh@P|csOD(}auVwm^2+%@A(45OwH9?YQ zH$>+&zTRj~vy0vvr?eoS>FKFsw4cst;pruAPk`ZTz%~JHmYU^wlE_A0ooE61dr8>fbbQOXe3FdfrlSK@R<;JXD*?;yx2)Z z8`xOkaK(kPuFx<{sPGOnbcHHz-CfukxM5yz-+DWIXXw7B&>6UYzR-DewSDA?xx&%8{Ob$(BX8v22a67cZDKnBc^St2o(N?ph~%Rj6^ z2^hA2hTD;ApCG=LXpd8-MOlI>X%@v-ucJyx1?Ec;^gWy4(0Nf3O;HR5v^@@o8w6}3 zverqfb>itFRFH&Hg9Q5{-2NfqEU@8ls(t?>Q>%;#Ee#Es< zK?K5QD`jjd7~0m|Ew|3yJa=p1=EAzQIo~=8kaDy6Dhiz@d(j#ta60%4TJ4H&-|c;C zmR$wQu60-Qt>ZV3-Gt-Whfc| diff --git a/mace-bench/3rdparty/mace/mace/calculators/__pycache__/lammps_mace.cpython-310.pyc b/mace-bench/3rdparty/mace/mace/calculators/__pycache__/lammps_mace.cpython-310.pyc deleted file mode 100644 index a6b198ba203b9013730ec7d32f36396b1a006933..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 2326 zcmZuz&2Jnv6t_Jd@9gYHJ_u@qph8e8?V$IQEjNtj>&*$Us`w0Ez zn8haq<{N1G0T6)%&QXDfxFADP1j7KkF_P0Eg>R6D1sgIyM)Ot?4I_*s6X7AgjRX^| zCrGp;owO(MVQ`3t9YOA+LG%KU(IBw&R+bt|?_}E8=)0-Osxry%4zTT}RWZ$S$%{%z z;5#}^jFF1#y0A>ADw~?rfh<1VGsY_}(GUwXBu~*U7Fduc(J*KSA{JCq8A>*3?UHpQ zd`~bLIW=Lt?Ii2S$#x?ViT0Cl7(Yc64LhKC$9C@A*xtT(pKsr|`E5!TcBjzqoHhL$ zbPPH~$M^(c)Zhkj(xBr&;090V34t-k5#YkNTmT$%oM|;!&_klH2sDBfJo2+c16FL% zgiNC2PD75nBAj!At)mlifW7tslx<4(x&(h0hSehg=UXg2|0-n%Uj93{SM z`-!QFEahcg>`0|;sCbb)w#&fttORQ(sq+X3#xfDwQuldr3J-WB6lQ-a6^4^;4OVQZ zYv63eVP$4Gx1sY~ZkLoCfp@`kJN53aRH_XNlnX%RtYur{M1yCQa?Z1%sZ=_)%*aw# z%7$4fZJYBnPqgOzIePii=C~^4W|16>lWLMB&`agjsj4P2HJj6YGpL^o@#Zxp0L$$Au{lYTJgsxEr_!5wQWR4SrcC9=bYHE267Hz_GLRTE-1l#cBTO-Y z_8~O-h6VIhKZ3E-gSCYi+c-DqmZc&wiS6eVY{o0aN8<|8X31qrZ6yCWQ(2O0%ak;= zDpMDb=g$&!B4{_E=^wyn5MX#69itOS<%?)WMKBASK=lQMFPcEmiSRlE=>$Sl1oO8E z1v_Fh)<9aXqgktI2~;C>70ST;jee!~9da9_5zLC=+cv>5Z6F+HvF~Bgfnde6PSa}0 z8k%)Y*mPk{7jm)JP|-i4vmVT3P45Jc@T?D7J!{&q6IRbmFlkMqrUyw8HSG~;`d%|C zmX6pDWEKNz>}S@16$5%)hVS-oI$m}NvJS{PP*1aa@&WxEjVQe6SL?Vpcn17WP~+Bq*W zJ1>|N-TxrcROY!2 zcM_A172F00U<0jEb;02TA}eeJgzM};4w!9KwSj1ZD9cL!?Ye~H=c8SSJvgMx!s!@_ zxRMD(<{6jddG(0bWv!*~bz{)BA=Da~dlUI=RqwlP;{&NG&GYP`RLl*us}geC@it0o zQ@Kqob#$<#K7tv_C71fl5pT}e)iA%WR#zQD04jRhs4v~5YmR*G$QM9{D+_i#v*V@{ ze&xs~Z^@_7vn{YKR1|Ru82t;7Wd`{RZw!qryzFHJ`Ao=sjJ=Ln1ata0f{`v}r%^~x zy8(k$?Cq);*2Fj_Zf2J-;&9I~=}Xe0f6wLbP!HZ)tQKdw4e-{W>xyv+6%vYZ%-uy= z!+pwam-D?k$>$PZos_%B!0S+jz4_ULdHJO~Rh27ON7R;^(K{=gE@U`gu*=78iZhOK zMaETleGy26x`bSyuYtay0r`he@_IR3edT9?ehh99$Opf@3fywPm)D^uua)ISZL)l0 zk{RXxN$jdGsRoqN8)yFm^Y9>LY1y_vpvPWN%{$NgExJB8!PhP=wYzTCP(-@U%)tlv EUs3&L+yDRo diff --git a/mace-bench/3rdparty/mace/mace/calculators/__pycache__/lammps_mace.cpython-313.pyc b/mace-bench/3rdparty/mace/mace/calculators/__pycache__/lammps_mace.cpython-313.pyc deleted file mode 100644 index 41757dccad7c2a2c8b7d0d2dd9d3d4ccb8d2ace4..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 4061 zcmbssOK{u9bwPXxg5XE8NmyVLrsgqNj=_NpMEB-&$M(+jr8OzD>G@zZ5nFPW-wNuoTlP%4&nLsN>S1r6}1RZwkP!-`cY zdYo0jdf6^CMf!`_M0S1$z}uki9Kw{0a7d;wok!!89Lk_&9y9$&j>z;JJIBq1XOitA z8t{0C1wdx@-j&(8h|Kxi!rmtccSFWFFNfhyW|_iKna5%t{fxq~!$=am=+w&>FTNow z7hit$?9&Kv6he*L{h81-DkRX_X8*&n&}SU-Bo6Pyi}PZ{t##w1XA0pGrMvF zNDfbgX8Hnbca`~oKal5txxMro*yReKz+?8I^YBegT8`vtqvih>;d@2VAX)_PyyV4I zyHwN*idiYnYuNHwtQ6IEy(Hj@ZbBHU1rn;Tv81UBmdB9g>3?8AsIm+8N?F6ixz0%) z9&1$q^LPa~bz4z9mV~t7C9t*#VGkiaU%7f!!#D;E7;+fG#5=t3l4?Q7V@%%Yv33a; zmOReZOsj-FRyP;4cRhMhgY03H%tbF!F|F$rO}nXiLRrOXQ3KJI7sJ~1ijFl)S;Xps z#CnXS8CP+d$mFiwQ1K!t;h*;+ic&CC%Tg4Je6v9q?|J#$QmLrrit5cJwY03O02Q>O zWn5a;3U&jT*UZ9F?gU-{G4_g2YfJ^zC{zqECAM;gS}c|=uwp?QE3e=bxqp9K5*Q22 zL*m~02Tb3-hqn0Gt+2x%c*JMiROYte@?wqeb@<+mzN7amwcM*t?$t;9Yc2BnFF(>g zx%Tn3NBr>?ZKE&uS*mvIb?4aYkNERjiS(Vy`eH3H>?DTWo`Ln@>e<^DT8CY+b4%>_ zVCwxTH@SaJSQTo?K_@v_OAb59;g8<=ZH`QD zCg0df_pB|iE^o7REdC9L5(jSS>(padw1~gU!#Tl+BC|o&Mth}f zNNBAjU<)|ySa-n?Ow;v-XCb$aNzx?~WFb@`i*9rd1 zh3Ryg^j3M;mH1r?<)cmRsLvgpe(MLhBY}2iboyM|>Hk5KJP*aIO%pjuENlw@r!KL4 zI4|U(l_0(&bzJi3vW=e-4?lMq6Qe?U<87=$o`r-M3&Q3qWKIodVc&}fG|PsRXnFes zd_lL$2IQuqX4+oVgaoZY;zHg^v>nG(6F_4Ss*v>(!@h);(@d&y;kM6vbkFJ^uz4w2jhR_ zK7VmDd%hZe+&kj*p4#BMxB5oiv5DwZYRR&jvsOJ zIqpqze36Qan*nu2)o} zxsy^tnlkJMR0u5+lnpAhR|ALUV|zwR^C1ti3)g*&I#&)%E5JLT+? zHsX>K$$k?h2X_#eeIxl-OCt1&D`iN8zP6b33JukKm1Eio(;TbVx-qt_+nCf`uhVbN zY_Nef*8H{E$O|FU^tg+@fqmJZqQ*Ubc=9hcOn(Ny1wx1wte?YFrzwj1D;jao$X7`G uGaCLo%5HNg6t6CAQe9t%6p9k=99bJ(9bL<<=5`RlcW4wkuSGt|*xr(&kdKATkIB-I&2F zj@j3ZT@HeYR9UKaS><);BR)!`JtMg)-c%`nByvtw#g!;T#1r0a8isH9 zX3MDPU$bK3-D+8Ft74hVOSMvMyJAb)Zl&8!#c5|M8N8+NxRdV>v zwDRro%6Pj_DM(AVHPJ3sijp2{O}0ywlBBb(srGbbTGF}Jq4rE=R?_*_TzkGUFX{2t z;r5Zrk@nHb(e|;*G07{mj<@fs+$HIWme)Q}IU#AYb$8`%j9+Zs)4sQIujEal{e6}D z+V@xPPwJekob*dUZuQh^*`LDi>4xo3`-fh&D-V3y@MrwlR}FtQ7{8vXobk;UjPl%D z+`YxJ8QaUP?h=2WuSF|mE4H6)hEbe;vD55!V)K!Tcj>vGIR8xbBH)GU;&at6T=>kV zpSkox^|{4ozgT_txpPmWP4;}PRqwTGQCG#eTB}v9uhcrrL5OY-H+pz~?3Kr=QMY=v z+gy&SAFGCu3PP4z-N!hzUyEw^O{rezwi!v2wJMBQ#HjMKfS* zP$6g5?D)-k5XM=gyPcrA(#0M~!`M+l)Ki^05Nq_YkV@e>k0<;QfXEmao932r-CVUc z%z+uDR_)7{DsZ#qEseKBzV(6;Ieuzje$9B(x^DgjvuyVde>#Y~XeIFcXnig4x(#{J zV`mzzZY}!QM`E3L?9qPV#m?nU_gW{ZUCzXI2Mdj}Sal6sgZ0E|=_yXj4yk!ePaOt` zv!AcEdcg&yx~gocS!7kR9H%PvZyU-();nK(Xa#%xP`lP&sdZPIHKgjngKMh08q}kQ z*4CqyZl`+wh39qtQqZZdJoK3IF@Cfzc^cG5C=p7=!rFQ~u^dEIS(zU46vhkl00k>+ z+9scZsqRKb9nYRm4sTE531fi3AW>H=->js@EPPXAHooaGXDov^cWlf*?Pmi!NC%E@ zHB$42pTk=wa977xv-5`jO-k~pGalspLXc}@{RzMLYPyp5Cy`qU#z9Qu{&Y}ixc(u3 z2IMs1&-!z(8kM3y?;pnRq;L6FP->+7BmU7>&B~PjfPc(Cj?}b&#`pXaNFDMY^zZiX zL23r6d;R;6n)UDZPon)Ca!&bWV1ys55rR z4_ntR2Yc%GLcU*XwtTN%Q@+>fMzsk1!fQ2AllW3Sy&$t0b>|0HL14+W^b5NN#d~OX zs^0T!oL_ydS01$qI)PeVul734NQgA}{8>*3;tb~Axe8S0m<_NBv}~wuJ6Wf0=Se#b z5`2EtWh>|`M=LvBp8Mf?)XjE!Jl9#q{&w5psAE59)OxK*&L^`X7`sa#yLH2LEhMrmoRoTqQs?xm=I4MILbOX z$?!9cgo}rKBYXquL1e6&1EXO!jY0m`%%3(*<0js#mJiNdO$~DD$iVf@%T_r5S*Q!- z&A#y(l!Zb5s-e#M_Qv=CIXy!?8rhra!T8spF1%@Np`StC2Zz}(J1%&;BPE$PQ|v*v z7+{Q^0semx}BwGBHfw^pMn5d6-QvX}1Kg+f0Jj}0j!)%!T zoIn4nJ20PyWJKOn=eX>Kss0F~Tha7&EBu{-g}yAb!>&tTG5R_z`)9F@E@6+@hmwEf zvi-WLve6;`DBrQJV_X-`!S#>htUl1Yg&nx->n6{qlh^yiR>@9!%TjU^ieAq|{z>$9 z;-;acvT58fZkRV9H^8+ydt-HKIDgssnNCVhZAO>8>~2g9CN^hpN}sQrS0425K|7qW zZnsg4<~HXC#Vu1(hX=($(LXc~jM&21;TN!4$b`)!tTEIJqNBQw?Ao#D_~u>g9eZF6 zY(Q_220Rhny?GDhNHP}tT)mewm74brGN^NZANaL8C1oe^e(DQGr(hdV`E_&s=S+iN z|D<5gsV^EG^BS-jcz4P-wycfG!DO;O-xy4y?di^4emNQKRI($daXQB8X?=3o$!}vd zr;&4NxDu9fbhgJXKOpaCSmQcW@QCj+JK$ew;P?+-w$RHTqUOSYXZuicwxtbcfbz?! z*DW=P(uW5QN*|Frk4pNnH#33-j{gz#VX2h=(c#VvPp?zJ2p^NuN&n;fOR>vm298j4 z4((52lpkS_g2$gg?LC;lc|UEuT#&PWaLYj2*_Z-eOjQ?Nx2`;b=K+kS=kiqa(CWh- z3-A%_L&`r3_^80gd;?PdM;8@!q5kpny$e?$5L@eB*jrocswnhC_9qNc)U{TXT(bYk zlRGT)BsGCvJLp6nWVs(e3$E9qz=sCV3Tlwnl-RrKhbQ|JWYbhT8$GDE!9xFmr7o0( zU6s&j$KRb1p7fTRSA)(OukXkt7y6&L@Cw8)rVFK+B5PN3)P`8r8A!K#VdO0ZR5Y$O z{lLFiPREYs>v683f}lEl=};w2UYrKhe(2?nHW}hsC@pKfD5xeXyv{}ica+LpZNg}1 zg_q*&j<|v-5n&b=FYFU!Wk)@Ptyd2dJVNj&!D9r}?E9bkG}a7t=W4Ch^b;LRi`O4U zOk%GJ)Q^Ve8Tg8ZMBh8#?X~>Q$q3>eyS#JAcy9%e!TlqP)V%6AAY^zQDEFas(rO0& z0b|5LpHp~3D*DSHwwK4io+8#*$6GKCc?3`VGTuW9SjVue{I;{*e}HS)(N7nuFshnS zwYu;)2Wo{+ATey+FL80S%XZfEmh8*XpXhZu&Cc>J?Nm15S`*0G&m@g*gFHF4PXLtd z*cEJP(tgu;Q-$FOtP)3^y6DGFfDxT z_nqwn%c1w=1=gMckmY1(#IA}-t0ymgW6o(+ZP#8=_oAZuEP*!yqk+GJtdJ)1jM2Z3 zyRWSbX(^f_=>cv6VLu@d+k3}6xW}VOQI(TPIoxAv9J{9kw?Pvg0=+@8*}v0-{t0?` zn-iKg;>^=pQ;+SHpmvauMx!1bj+*6GY#Li2Bj`oeAQ_#`hD7P7Z=uz_7SO7eW8z-k zX&$s)WgbV3{;xcP;XUa7*SfHvuGK=}5<~8IMtEv57pPAvpO6g`;?d~(5iUi0^) z9oUpK7^ET+=r}eP#HaD0Sp4>GgwaR07hxQ~+rx4{6o3Di$M@p|G`@fMj;_F55^<`wpr~DB`wimlY z#;zI9G3I4j-Q{H*Q=GY0Q)J)jF3EKdv)m7|c3x`Dr8uMh?1!jX`n*8c=Q<$)vKPn0%Z0=R~G;diO>Eu0? z(RqGqSDrhtH&d`t9oV3*Z@*!SRb~2VjPjXb{ZX%DtPRh}SW_(Mfel49Q}b}%Blp^Q z6~LbHz!t);3MpU2?IltzVKPI{`AFdyH9jkdrA3JFpdZ>Z_NViRdC4akZ-+abg#?ez zc$!CWo{2O|EU-JwSo_k>dQlH-QdTBvYO?;(%G5cWpfLKk$PuAHl2rDYG0;k=C##PW z(4-mandyF@x}orb*y*j|7A7cXRiD*r1VkqFWdL}+o4AkZHmb=4?lg>at*O^z!xI4U zOtmg2SJj6!Y@fo}hU9a0)&((_goJDov?h}GK%_1S$Ts`;?ML`K_E@*j+weme_M9^SB$6Uig)-Vh~|$gFDR ztf}p=${2-0m6qP^l)8dx$4VmEOeC8Aew43;)=3nG}TsgmKYU|LwDOeD^6 zOjxBdXj;~`ADt%jM5)P0IeA2Q3Z3%u$N?jDh50(bBU^vpv$idlSmpe$>2YAw@%`8& zlNoim2aQNle1uohBG$^}22sXThv3OH=Z#J9O5tu6y}&zfKm~gdDhf|gD!`g_CzEln4XEcF=(j7V%h;JF1xj)iYQ{NjoeqCzm;SWSz_cmjFOAdNf+ z&MpgXK3s3k8*qnbqRgg?HycjLF+V*I&2)nfv@Fw`Ig~iegA+W{K)VgPDf4+_Biph3 z%myfLa~wHvbFLP&dwMIikrQWQ_PPmY+b_OkZRFvUAB~^PxG)$WIGdAO#(>nD$D`fG z?ydC3L{bMXOB?PxIARNP#>Wi{PG*iJEBfYO!XLv7r_n>s$7SIl|G0s?98zO*!+gH! z{QRgcQc2!8E*uLzQ?W0klfICWa*i|sLHKQy7DuH>jY=m``aP7El74=7klQ@umzLp_ z-JH?+>{HNg!ktD0!FBBJV>sfxoMUn}*nw|GcO|P~-p(HEY3m=J zGLFdQ1=c^lXU@7kXQtcl%2^>ZH zsieJ}VA39E{h{q={bMq3!8$m$4<)Mx#vK@O)-Uc_qpr&p>bm<^IO`wB>1a-Y3j)y8 z30xG|2~C~w3t)nG)e!%(OOZ>$0>lxVSbYm4eJ0VIgeRZ%AmkP(Z-#fE`>XBPjkvOe zefTom-6K0I8+}SVsNoY~H&v zuEm303dbykEH1!x=nE{MKY5Hqts4!h^*0f_WcKQp2(Vp-`bB_?aBUWn1tc*uv5U|c zSg4)l%G9tV@Rx(C_H~MJt^L_|Py`pD35pUZ>Y!+Xrub1S6Wnn`#0;J2m7!Ui_|k8~ z%jP%R+DZN-riPw=6+f}t>nt^KPQIe#P)`$lh2R3gM+rVga1o$xHS-nlKuoPF!QLVlL|8Ni`MsV!Z*4E&EP=A&c>~>J=h`&?xnYQ=@ z;xzhPTNl?roPr-&jj?VT&WieFwT^VTq-FTeGC#Y6mT_S(hE?)nrlX;Xsw0(d30tZ} z{A&e5yakmaEefHHs@{)EM&Fe@tlq#dl^oD(sA0!uQ=8r6+8Q4Ykb*03Xm(?b)u@t{ z2}pXMx~0B>UOz>(;(s9Fj=>JK@P)!Y4V4{sIP6v1vSAb3+a=hrZVGA^R4uJ~Iq$j| zShCq^a}M?McxWlZrpA|6Z$71et*l#s)( z^}f4Z$n2CB(iZxcwy2#&Z3}Z7)wQknvfJ6r40=ymBgfsFvz>Q1WEP#xZVi{P}3~gzy$$~?8_;2EP~<8@|&jC zu3nSt=Z59me)?4x_0qh)+^}GIq;aWi={MxV5`jbdlTbt5dARQfyx$pHd}mxRMBh0f zRF2y5JKl4=QrzWp^XO~j3KLWGfUA+p;1m}{Cy}@sOZ0cE_lb-v7VJ3~vNR)BgiZUwxkWPZHn_ZVp8TA>LBqdGXX@H~P$uZ&bSl zg(Z^d30oVQ2W(k!lLdx)#J|N#95K%ATK#WhtdNU>Hw#|}^hfVf-n28$4+`*!xuA~z z;Sb-wx3=-wV=&OXPU8{8KZ^e8Ln69EO#xmch%a$8Ny#)zfJ5{fr4q!40gZ?eL^MDp zhFP%c0MYH6w(kJ39~juLp=2{dlY>OjF#U}0$}PMmiFD*{jzv`Sb?cn;kzXCZ&KAVY z!fO`b1a9HRW)WF4WgIqMOUZqHjy=Jgzzl?-oSP|{Bv5`Vb>)V2BXz?@{6}8K8;&-) z$p{7Wru|hIEEpkuBaJzY@2-QnAO--l|EYn@31ms9K%w^6u|*ejTr1pHjP?w|x@xY$ z6(8!@wH<2w|J8>^@v!ZlJ*cDb@z zwRJ??9dKpoY-(FcyzQP&=^~IBJot#*_$9SD7WJ{6afS}B&e}r97f*?IY9#K&b%yI? zXa|TkV2m3290mth8*mLGf*~x^Pd3WlGYxD3I~%&)LhljACh)f4OFIL;>>&UmP}iEF z4p^f%Z&CdQkX&Cj{|%;ajIi7Q4sZ`rgK*5f+~cOfyF8YofAjUvJwVJ z-iNfPfgaL3Z14LRQ}{;z(K}G+EH889fX_j*WHH1C?0P~b<-_pUcM&e}m=I}E2CJ0n+uk7LGz&9cn zI~rBuJP&So#JFdb>*sNd2|+=23s=>MlE4}NKC@=nW`qf&*6KG=xNp+ioHRE;%M|T@ zgbMTC9j?ho*Kd1eBIKa)Rm+E5{HjG)2G4u1ytr*Hs=v*#{{aB@kb3Cl09Pj-Ge;XZ zhGPmjh+}y0W0j}(2?zCeSp7Ky3WHem!QE7$b>}R~D$Zx*m{fxmUFe^P?!`r%(`t4C zkk{X1w%y`YzUc93(XUmD*{QJC)`BLpqr+AvM*FF`i{8xNVuRey351mr^mDi$!7Fz5 z-9s&S{{*)=vYe9ohfMt@!9OP8Hs}cPUEO_`$3sh{l-ZG7MUhw3CXeIJrG%Rk%YyUI zMuY?8SqbtCdgyl67OA%1;YYE!Yk)ebVE7@Ha{ooB{Z%BE@%7C&tc{e29|*2tFe5x; za00Gck`MI<`8EYfl#-B0Fn`~HpEw=a5&Xj9Cw4$v8G6d_>-rh^RGiH*cuo)|0ie^Xn)syRyA;R$F`&Rx)y&OE!{n z5O)pYJP*~+v&X+i@GSz5K$a(T#s9@{^Gjqg92d8YpdC<(4OZufgZ%$YkGnZ-zd7&E zE3*;$~})VP};t-OhPa2qWT@o zRSWfh%ajn$z5~t2DLC%+xgA5L*hYv(D;~$r<39+I_P*~sO!fQhjmNliML)d%aF+Gm z#v3gC1A=|W<}pu>E#H8P3Goqfc2Du_j-srHbD;S5qjMO3rsVXd=+=Jx6&`+qHF_<4O`FK4|~gi z>pvyz3^~|Xd4pRpjUHf_ZaRozXX}ig#hr2{awGf)XOICzd^_CEh-=?65e<$S(2~ZH z;~-t&bpaj!enF3d)C5vR=;XyM&TbNjV$yMnt2!IE%u}~AxI`EO8=PL$#+4Z#LX1SS zC$SQ1E7i57I>HzH8VM4g(vu7wYm`%$E?t7B&Czx7x4hab{A*aX6Ev4smY|3O3*` zF+++(ME+M6WyP(B*wEk=gt9P_GF)jkqEI}8<;mC;BGG7l7yhI?} zZ-y!DNz*sD4E98n&hIl{_>(ZCX}rev)n;&Qcn|y^S%Cw@`Qd!ze{y+}04{_ZaGQnk z_@3>GQ&D>@&J6!jmr4FlS8?hZ?>+V@fHTG!c{dL(ZvdfnU{3i%B5es=)o2Q22lJ1Fj5(gPr_mjHRS@OHq!rTul&*$=ss6 zKLL6u-ZJq|6=QW=@O^BvK&1fhxEzAQkR4=s{R66)24m!WihK6-~sVVc0Pqi{WF4pPVg@XewRR`ipP|kA3ZZ0I6p1)AKrHk3Gaf+ zwbWbwQ;B}{FFC}&BKX$?|AycWV}F@r$0$sMv=9^B=bEDa8&m(CU_@I(iu-?7Qr|-x z(D2xZ`2PekQTn0fJ~3XA)-*{L8R{i5YoKNdbxq8BwSygjQqF0$>)){hq0cv%5<&9s znfiHvN8^z!67bo1} zP#(>UkF3p?%rB`buC7Ei~@7AkTMa4ul2^s`8fiOd>(}|tWNo&35FSF=50Q@JJ z_XOAZ{tE=+sdSn82EpGY_@@N_f#5$8$g#f5l(!cMncRasDv9Aw0=Nb|){{eHbN8?P z{XG2GV0iC?5sG8ea`oQ#QyiuIp!zfT z!Jn%`9Wj|W$@nVV%sJG`ZFW`oLZ^*%2X}Psg z+z;rdsHHECP;WIChOAj)QW=n{sBWt-0aQv-OCM;s zAph!C1r}tbK>qEm>8!{$lRIss2z~2#&k_sfcG@gxXIog+U6Qv^;CK9PyMo|j%uKh3 zJ8t*C%onu85?VdP0Y#8*F@--)L2)Dyse3o`xchfH{Mdb5FX;)?H^Kz~*Ku+Gkrz$Q aMf_n=?nUwz^LSfAv4>L2P0yvKv;P-N|A(Id diff --git a/mace-bench/3rdparty/mace/mace/calculators/__pycache__/mace.cpython-313.pyc b/mace-bench/3rdparty/mace/mace/calculators/__pycache__/mace.cpython-313.pyc deleted file mode 100644 index aebabb3272caef29c67a20f0255b8730c15b8cf1..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 29046 zcmdUYd2n3EdFLCm4<01^~29Eu`HkB7hjga{0vHv>w9Y+1Xt zm4LDnk+NI_Wv>Y(P8r-dWhhl!ro5?$%4TEPm8|D+7|r+vikL*Hl=ByOWF2MC-`8*6 zo565+$o6Jyn_$0wue-m#zW(~_@B6yvzQtnX;0&jh&R&}0xWA(p*=f^(`@D+dzQOr9 zziNQvrKgHl;k{s>pud1G@X|N+fVyA9Ygkw_pzYW3x_&*c$Gi5FYRD!7%{=|OCcH4H zg}jkHO}q(D{eZdO!dv=__#y_)FktPs@irDN9I*F0cn1p`2b}#b-o?VE0e8QLFJ@u$ zKuLcoU&_LkfwKN`zP!JJujsGjE7`ZAfvWy$zM6%t1K$1`zJ`TW1GRiD%C`;F_1E+D z?3*3=8~BF)M!qqfripLzJ5HH~p3SsG4h;>-q0H0oMA~DScRAFp@w>{n7c_ip30KSc z-DRBLbIP(&EyeR~<0@bA&*^2*rxG>elhbGD`D`e5)>j~E4oyU3qBanjn2w05b)xs! zk)b^Yg8Sg_4Gta&KDYPa{sYJS!6So*o(~>6va1hyjC(?pBeRpC*fcMiLX(rhk+Y%5 zcsL5MWn;5=-|+H=U~D>getKd&7Tgq!#`tiQVhg4>P|>5ISO`xwKO4cr9)3B-hel$- ziOBg7KM{(=d}`4O*du3yeufXvg!nLm78%Uc z^k{ezAGEteu@P!wDJnWOF%pc-PMrz!QL5QQWOQOA92JcSPe;PRv(xAUmQmF4;n*x6 z83CdyPJSffn0~M%BPtt+aHd3V`szO(b$EVuy=Zly-4lW zk4;X8Vw<`}DX?MvTu~r$E;9X6B%Ru)7c~)dvS>t)hcGPYOHs=piCQ)~d^t+zE8vO7 zuXzpxi`h|%QLL=e!8GibO@JLL0 zITMbIobBAekD}1n1@=wyMid;3vJV|I7ewoLI2L5hk;>ugQBE`dM+e}&#(ipcUkWEZ zje@5!;b~6TeMx(#VDEgbKc%<76}WQx^67-WHf1fzd0jEOF1@<=!s4m8dws&VA+FhQ zAKX92ky6F$iBgy;)ts81%4H;_80J)DlXqYf2m({N_$FoO7B;9tBFX0!f8(# zP8ANhWWbr{kS7zCLzL&~Q!AxViTYD!zu}a*+vqPO>hXe#xA=|tZaP)uH{-wMl(pO7 zFJhF%=C_t{yxnhu@9?Yr>Ql~cwcn0774Pzw`5i3e_LuveEadT5Fq)*|ixC1vLb?*a z2f0h}Ud-Oh{3U$(xY}1K79HNTXK#k4{Cqtc6?zLeLf^f+{CmBbNMnc}kMRP`CMHL{BO!j&8<~!UVxX1Y$qA$+d=0jAu$nMXqv7))Ea~bW73pTn#R6m2uY?3uo&=I z41(vv7t)D$@qFllcWfezP6W*%avb$epP@d1mu@r$|W!4xkYUO8SETTGM&033e zETw*fIs>v}!{T$;%#hx}&0rBhjD!PpgL*Xc+HYGq~Z zLp}Yr&NX#9QjJ(Ui$y(AXUEz)ujxZ)z<}O%t?k1ze}G=2RhKz%6tws<_fV@^krHDn zxqwRU-$>DLjq0=uOO>^3SQ)GUumX&C%l<*^W|UD zR&I#cnjGgZ>v;?FH;DRn%2EDu;BlU`z@v%`9%0rR{FQ9>SA8DL2fkJvRYuUu9jR4t z9a#7|W>9r!7L-8YnJ=tqgEvr^&Bvm)+)#_$!psxnZ&St}xE`@(jel1rogAv=hC%Bj zH{30U1pIr*3lC?Cl%Up;@Qb2ImfWFQO~zM2eiDHDfK|U zjXQAGq2MfNZNxGBoLpAGA-4zPV^c6s`HK1~wCKdlBydx5RgbU5DLGx{2}laXc*H8% z=;vg57jXuhXAGF%PUKrBLpkzgv>o#{*P(?jrG=Oi5m&%96wXbR*)jW@Xx$mbNDx&H zxX^ZtWtuv>vP;aOlHAalJP=jVw>0I&4m1nbn)BCoEruaY9=OtIuBnOAdw$;|^0eeH zZ=*)4qn8^Vms^08d8nIIs*Dj4cR=TFWh=CO4XtboxIx?7*U)x%L>EB%o|mvX{M$9D zI|4eS?qqqoSa{tAy<7$;j?TZHm4Q*tU#m4V*O{h16LL#3&n!xV)iphuYiQ_^O(<;x zD-E>zbCjk~fzMM~7FBp5lTQv|9_zAtA;U%Yt?9k)tTcJYY)$WAMFee~lFMML%f_t* zkEi~?eyqs(`s-0Wbf6|Z3e$2)nWsDo;E#+*<-t1TncO1BAI6S^N;+I8J(bZ1?a5k6 zc{zRN2`*KkmP4`3cR7T)qU=fhoA5@g26OD~8ToA-9GJv~jmaTfo{(dVtKo*fDCg31 zA=DgvcLU?CmAP2W?JpQSJH~BOS94>kG-OvEcx%e=Z5|Xgv_qV$+B3WNMeNxpFLJj!F8;pV)1 zre`NdBTX?#DA3E9LB{)J7~+Gu@<9>|jbMQA!5e|>Gs=KWPJ~AxUsjEY=Bdz3aB>eS<49!L&6Or*OSrmi&(genPPM^+ny?`e)=4qbzG@^kGzywLF zrJ)F(4?*rO4M&iqTB1?vau#8fx|hOp^~XqxPNmII&4G~NNa)NYgj1NwQUSPJZa8QMnCAzo4C74>~#CKVDjXTzaU zQ9mlZ&VBm;UcHb9y)+F8+DoA*8wMHu+8A<2roEI&o(in{sKQNYAWa?KXQc_0&IdeT z8Dy}K0!3%SBNJm6SfyCph3Q!ohpC82%)1WnuJfUZN$LTXm&tlCzQ#b528FWB>8596 zm_8!#0;4F&$)n*(h^lA~N6~L1XDKlt&Uwh6A}xb7a#He+xuSGGq_fXect@v!GSp`e zrS{Gh%c8y*!A;{00E?=Qj6i*^We*bv=A>an3Yg#xjf|j@Jj1{1DjpjkPGKp0+jtU^ zh=w5a9uu)(kVW$(3=s8`)8pePSJb~0;t4f*tRVOdg$+c0Ksi<)Hp}=;6yuqg0wE*% z79I^^C}zTZjBry8(L1Qf%dv1|R5VVYL!+@!WF#!Aftx%n0HPXVFrHW*(Fj$HB&Uz^ zUizRR0Ul3`4}we~HN#ItVxm5cx;(fGj>ndX6HORM!g|7Ju^=`rY8h6FYRFW@g2;@h z4UxD}R0BOl%?r~L5mCz|syuilWM^F<3F~;`pm}1Qd6FcE_FYU2Imq-x#|TZtqBM0& z>yAVyCiHCZ4ev>!y6(IQd*77hF|2t5OAV}55G7&fiVhVijHQmL>aD$DNJF~ zkOX50`ieyjaCnUOQB$dxB@9rSFIv_V%TH30Au7q7?rAn!{66|*l4uC4j;K2aycv&* zPI;vR{R@IVMnmIa(Iz1e6O;x?@5qC><0Ng6j*VJWRLzM-27~pM)H!JZYg)C$v`tYG z-jPr@kK^kI>LAsvn5cI~SCv>C(dQim`I+KR-TZZz5KEcqsxbK#sHRY-J znknThp4WWM_}3;+%2~PSNjRJ0=7VwH!32C~Q>wgr{!q$Pys#-@s<~r3va04R-c=oE zD;KP_E4Jb*);F#3x=r!g&C6ejd-vbA9Y|HxFP;*r)?YS#>hY#LyW+b8@m2aMe}{P%}uLCoXzvb+{L+g*@opn(%LInds8-7($*^2 zT9dZ*f^GeB|INbqhW^{OfmG*)OPbrZwtr4piWi(_JpPV zUV)xKd;h#gCqjpFS72O|Ge+xmYmOu2@SKTE7!bT3ZEc>!)>1OQu9! zciiUv)ai*AHzl0SOS=-z_LZ{wc+>t5O*g+BZwMsHj>lcc?^JbMHm%sp7Y4rPO4>UF zdq=9eWnt=)=~JuYQg6aqy|^x6tzYqY<2Bn8o*i-9j+LsK#nwbs$9&({j@&hJbshJ$ z>hhw$t!=;5|3J;zJy(ox8W&F9wr%@aZTnq{n(zCqfmI_{i(zcxUnHGcCP6?VsDePhSP9ScVjrm7XQ zJ?`pScKx9G-RgMPuA9Yi$HCj?{yWyvc;(iFb!$%Ni>`rxQqz#@I7h|8mlF2YoCM2< zmbc%w?Y?6|XVx*`TW(a`D2hAw-!>n>=vLKU8-9EEzC~YNbkEFH`7Z5!V4}8K-n6W^ z%D;8^wyR~?ee<~BI+QAHNL4jG&}-}t^iGxUF6BY>2UcyIx1NTy;!ZW5x6Ku(ji>C& z@LR+4dsAln8wV~PxODs@bH!aX=cz&5`b|G5|HJZxecKar?FvyYouPnZwV3+HZN`JcNE3Dgl+pzCbPs8Hqch9|ZZn@^2 z2z8JtWhzfi`>p~xo2_!OFk$tjdUw3v`@`Nu)&3v$E;lU9Ui;G3 zFMa#f_j-k@{g*V~Fs1BOi#|f;p6&0q{t%OL-w#`t-AMMz)mOg#<@Z{Js(qI3=NA@w+-E(vV`D{?YGT4P{A~2vpHPZ47|o1u32<1ofmAIlD4gaZR?H3k8Hbs zR@t1M%)n$r25(ENlBGVO)b~+o>m_Zf5#C3(y8CLbtW82=1J~5Od?C@eleGMAp8nJD zkETADB22qeR*8o-Ra!Bx`K_YQ?j2QwRD5=C6K2b2_q?2A`)BuDoaLE+eqiQoRrfg+ zFsHtKv3I^N32ag0^?zU@mY!C9cCU@AKA`&SZ>#owcFzD9QIb6VS#{O(+ciI_UVl=p z`SY^+lLpOCx4KVS4Dp((lMYU3DLL7yxuw&=|I5;{lUsU*Vx5R7iSm$;aU*FF5`rA3$A99q?k{7v@eQtvi)f%`Q_(&m^cJ@9 zcuFq6V)2AGpP`DsU~54{`^3C;5J_a$C0!D7t0TH6=Y|6EQTa6y{gd)TjR>2B$3QRy zG(*%oIWC|L=mL7EEi@gPEOCy?ukA%jQieWF$&V~o9VyITu3r}@?8;VpXzDqies0yc zhYv$o-~|&Mo#vyoy5^=3ne{os{%UlH7Uvhpt;SQ?AhZM<#Xl(JUE4OM&t+Au;A6WP(G zl98LY8M%4VROy1zXa1?urD8;;APWLm^;`#ytpoX?d&qTuHHeIQn8uE|AxH#s-HX#SHcz_3eb+UR^uVGMyI0Y8hNL0E=)_!D z)JMRC}U0jotmNbe*ffy9x>1w;n4 zpo)6x6;un2WEBxjR3_wltWYg$Ta++Uq7V{wuu3(h4LVC8R^^ff23(2v56X0-ocRsS`-SfnBFbT$ z*CDF7G3n_NJY7joui)wZ)aJTky=;~CdMQ^~s=|Bi(A7iOwO;GF+LJ1)N|tp9WgV-9 zoV{(8136e#>)k{SN=j3%(v+uswU{e@W|cD+8|O_A$~e3Gtz!$V33p?{)|9lh2)34` z?(g?r?@ib?&uddw*Sz^*fL^_GtPZ29` zO`6K_mogVGm<4k~(%dbWyO)pMHuumclVGk-nl}jM4a?ZFY)x6q;)XJ5EHr>6JB2v^ z85GM;HY!s_MKi2G8Ol$DRkA7={mu3R%u1OY@``n5vH`LXHb)<*|kG zQs*zAvBi^(F@?bqqE(X1cOHOLczzAT2omh&IKP%+loAZ6V;+(&rOIbC2H3i62q%8C zo7E89US5JPV{|JQF&-)K>mx>_E6=GEUi0@AzCwAJ>++d$$rk?+(;DS?p0>$tcXInx z2mTYs&0-$T`~^yS!dPT2(0OKVEvH~nQNTiDFjSXY?~JNHK)sM>dBoTHhjZos>NjlF z!q7`eo{yen)e0)-ZK|P0CAkuSI<}}Cc-oHkM++vEg&(Cp8hfWOVpe}4!!n-&N(sn) z5it3UlqO(qBH!va(c5!{RNm0z^RbjlKJ!zQ#L_4^ER^F6wuvZ}g`ZF!OOwu>zvb3m zKnzI`ZBJ1jmgcb#?F5mG=k}`(A7C|guv$Y27%&adBAw&>j!wlGWvD$T60c}0gYdYv z%J}c4)J`gIM5@8XN0#b*VyO(y)0K+0{;{NHXU450+h*=@__kji`U9}68M@ac=gK@W z=hD~)_RiOx>cFR*L=L5#PNLn$!!F+W!%oMJXZN?XnW#}!agz!s~$CvvUu?tIexszS3UO+%F$zHzLqvV>5xU>Ow!Whg(#|H7fE!| zhs@;YjJXQVgh51p8b*1F1q%t=VH8L94lvb#tq<%>M!jb)u=9tIcZc~1EJ`lOu%Z{D zu`py!(#dB9p7faxveAKPFU(*@!d?=6;q;(%W*P<(kA|I=m!}2e2hn!^JLFv_kG5R= z61;t|3M)!iG<~>QG~f(1Oj}^dh&odktezl?R=AVlm$Hzhec ztzteEGeLtf3$utx7XmXM*&Zfs*@t7+>d}cQ$z+eDfq?WM@emEOkuwt*hf#hP0j6>0 ze}lX|HJ3;9R!%$$5$rMDD5Z;lBKpGVl|kfcb9^xc?2Ke1>nM#1ui%y*EHj|COM>KSw?!-%v9OR_Msk!ptFelXo1 zddYILi(jT4S$Df~-NnH>m9_JOcRXcD&pN@gE@ig8apdBWg=5!( zSA*XhS#J3K7s!x@B&Is9w05!nyRGlECQCXl9+=;iTVPU&RCLZC{+XjV zWpiJ#T(*4OnyPM3l~zhB>{NYovVNyfzcba`mTc}5n)~jWbk3srJr9bw^6F$+>#efZ zL|Oa%0ISUFM;1*>Q-ZfwFhBEKN1$M`x!(B7#jnI`wLCD%_U+jj}=yAtht zE}0hk1zXdS3ew!vy6)xE!n!>-$Axu=lU)I!E0E}V4k-?ksCS)UYfrUyEWLbvTe5YB z(7Gehx(jg!1Y5&mNU$}hoSsWBU*0A-+ftqDm$zL1aMrTjGe0(c1D7;`L?q=D{t?Op@eH^(zRc3?Pq!0gp&4T zNw-kaohaG#3(9+UK*g1Iu0mifSz18~OIAL>zT2jrPccqpPSm?NCzP+hf$r-`+O`X} z?Frk?q;0QY+nccMPmwU+Cpdg})hc@nS%=x*ICSyQLJ4$nFjZ*n{C?Z@w&jQqko2082r z+MLO!R=GRaOl%W8ZD@zPJp<)54BTKoV{=~Ga`DT8wRy$2XT?=|rT@+Tg;$og2-V#; zfRE24T{{KW&V*~XghmNh|0AF;g#=G0710IjoEFC=4QuNO!PbaETqRd_T;8#GG~sMq zSvRm!0+XG$3m03L&I+}gZyf#5EtDMjq_pCy^PBEZN~^BAzgf(bFYAP|b!d9&df)<0 z<1V4BYq>@!+ipC7P zX1~ySAW?nr!xshD(Z`mepu+PYEvse@B$7>>oe9&<40=BXQ&FYtsRD~iM#302Qne(+ zY8I+C+_2va2(E#()mhccxJal3I~^dJyYUtjSMFxmZyC&Q6un-w(6eL~%DV-_#@{aD z8?a!lQDhU}5BVQ=p970UdaDu0rLp(Hr(8FJ^?aDpGUr!pIKnFAg4#KME?oBV2ZVEh_C>2mU@ z$^E*m8o!=od^oNIo{kO?MKmZ;gIZ{nT42X89muCrmrEFO%RP{JDshU0^x-&^3Q1_X zRH`G!F4y@D>}VM(8PG@)`+9`x|HZqD1{!x6F$Ro(7XU=P{E7Hv#?1RH{33;G$RlhMorBY{ zgBb&R$zYr@zQo!}j4vd|oW#(+Pa)z!`9CF(SlhHlTY}5B4%8i@AzgNq{~2B%(qWtE z-kfw2u`Gv`%ZmF8V&=5)ps_UBE^Y%Obn@y+p|TU!oeRN)qdjhJXAxk2uJm8-2cwp@ zmK{r&>Q@{M@x}uQ$3d7?GsD+sE>Y42J% zi-V=C5kL+b61L`~txK?VC2Z?c-uhH^y}~uY!Ue`HM$TIDz@)X5-ZcZjs+B9WuX2EQ z*TL;k4X6^%fj?V$aN0ExH{!5 zU)XiIH*T&8fKm{DUj=+keu}S;D{b(hO|v#E}uddC;k;O^4ON=BOpv=JX9vb z&ph)fXXQG|&xj#l7^=$TkwZ`z8-~2{Tjm)k^czUC*UzsQK1R&4B}0IC5RqQlXb-0+tHBvQ+4lrNX3yEZHHSDfCqk$i32?b}-?j z+t$L-5q@F@7aB#SJDf7B#?y+fl)LewxUX&!7b~5?MR;v=vDGnt7M7K?Hc?SZx;2H; zeM<6`MzR2lWCoL=3%FtGA#+aYGYQ1fcj~PHMmAQg4Ws*XePu$P{H(XkXLWM?|E!)p23adCfpC+|S zG*fAEy>M=ejzIcKr3JQzz|q3Zi{!mP9xc=&lgMN@Qj%V=rJD$=s5{8x=E5ld9K9GI z2EuJ*IBut#oQQ;Ry3;tpqJt!q(M+O5Nvfu0(k`A%8Tm7mQyrb1;`6Zq)+{NLRq03x zdHBDj{C`hghC?uAAu0cJiu;rb@KEcO{F#1G1w{u0ZW0D%J{vPom{9Fp7>%UrW{Zsw zmQ1DCb$yu}(T7aui1o;d|2two{*kdC*twYPzX}7gQeL$XqZ0~C>#n|r( zh#N{^ox9*j!2GJzH{YMjrlbxY{%QG7$`hWzPs(p9luARPm`kTQ-ar@Hb?c&}a zm(QExp1`N>;w#U;`TVz{i{0PteW!Qn#do$R>w1K`9_;Yzwk7NOgu1?D-4UVgNTT}L zg!?G=DYh~)--BsgP3w27-l(Kn3ol;Y9yixW=4+qb+sqXo#~}(FRXT!)8CAi-3*?K^V*Dr8+Wj?} zKkabuZ`Ax3w`YH==Fg1V_jhT?XhU%{t?2#!f5jyJ25!^B{Y0vc0`MkuH1sn)AC6~`Gd5DC+IWC}G!}IJ!ZL*I?td#@GZ6PwJ zGo;ur%c&_~c&rf_ZGh`DWyvAfRny@M@KM> zamVrpvBDG0_;BL69*I147YNV~HUp(-R zCF$KPcsJiDNO*gb-ra(Ccf#8@Z@P3qFjS_zbyA9M!P}kiZc2atYg1{ezG=zu&f#SJ z7NLGiqP}N-V8I}mYe2nij^UkDIUK;BCpu@G5}z!w$L2BE4d9LqiO}VqB*>23P1NpxAv>s@yPOU%(_PkJK6~Ej zNw_=abzd{$h67T7UWr_eEFSyr$#+gBJYD#1x}#7>MoF3kchizC;qC-WTI6_Inqzzn z%}ILq3*P;IuKlUy$CiZm_#;SA(z{jgZvCT%_uJoVPk0YJnj$6Zw+Z#z{wVzZ)O%Bj z`a@{~btmf-1o}d8G(8y+s$xV4VupeNt{<0ldU0k1Q&i&t|6>C27 z0(rM~7BFYJalA^MKA)n9)MQv0TUwl8P)kTnsR{i~jQANE4lC>->LD>ImP~#*lwnjJ zK`@LspQ|5M2*s0xU;%wr3NR?S&ToRyAaKfc za;}IW`)q^1h>@Vj^e_6(%F=+b&vm$**CwASM=_!45|!r4`GuNvB^J2u8va z<7dx6)Dw)Ig(Tzb^yDaL$e9s#-wu;9@DqT<Xz@_vJod;bGTzLs)>r0HQgnuV^cGtUjAhUz5KHbx)E08d|V(;XKTfZ z91d^|X$}x&v?AGS+{cFXcMaUeepSlsfRVvUbNf>G_sycM7OL_qEtgwhX0g#I7}m|3 z7Y{=nVJ%-+cX2vdR3{YGC5js6)v2<^#TT!3&YLeCy=YC9Rnh}S9nQ9;eR1s4+xn7i z`(bbLp*GPr7(a0;d16$+pCdf4hjs*a9lUYq^+OBA-_w5Ic-@H8TS@aa!MqJt4dz`d zU<~`{v|ggbhwBZT#$V_-ljUOV{INHNFAjgT6DrO^+XEv9i;!0G`dq!_g_d$;VB(fGD^Evyh{fkEFhDn;d;50e21u@Ur~0){>)9{)~zZmws~*J zlniFE44WP&(Pw9_^BJ7lvBx(Sw2Zfr zrx1O4DR^umz7Tn^xD+a@(i1q}vR}KjfT%T3i-okN*yy@5cFh}>BPI%9t5FLjP-F@j zUK+~smniDLlgAbqChcH6^df!xpXB`xdA|!!ERs&>lEq|@Ru&z5U!TQq6Kf*z+i727 zyDe#s8bnxtnlygQI6 z-8o;Bvb&S^I>BDIxDU7a;3_9~*&VDmuVAUUQ(C=f7E0UiNb5~1{feuz7Au9a&O6eQ zb4OZo?v`3S#(5)Mv{sZfRNXREeb2o#`}?r;+_dWb9FmN($LuN|v5k22zTqKB#=(Mz zBpLfUnIvOhw*gkV5yGVNpyKY8us&63BoK@5kVr`Io%ca%Jm zkV)L0X>gEnfB%3laH~7`dYTyn2#W zqyU?pq_8;{d~r53nU1jrgJTnXG)9V~$TXt$!Qklh2rhl3rIshbH&Z(CMCzo8SP+Sb zy2A{d#1IZr)C=(NQ(a^g{W`t>9(jL2o{_xw$Xg-r0eQ6ft1v&jxOvF+^wrx`UXQ#yXVsA!p3V`u5P)u=tZ$R`o0`%qd-~g)C^~Dr;9wENJE`YW~T>g1dF)+G_B7 zt7>|-!gi6Kz1!LIPNC(27Vr0Z+*GA{V2M8+kqm&&;FC0@Wpo5Ww!y|;BrDA9>e(1VFpE?W%EDAhvQyX4 zRg*mNP%v-Dg|}=85X&goF)}&PA+y%(MoM-aB}uUOC}nW4R1#K3JJ=7jF-0Zrq@?WZ zhue52akf()QdFe#uqeQ0S95z2i(Vzs9L|C-xii}QAiR@Nh2{imQKN#1+ z$i1}Oux$<7%>00gZY4COpXrizKE%{Wh}cQr8tAo-N-;bmwPrg{M(p_GqVw=@?Uzb* zzlc*=@9Q{~@ng>YF<0^lS9NDy_iF{eY5X_FxT#&>I{uDx{1s>Zgxmfxx9(%k`S+9z z;jT}(9)aumm}~r)TmM&_1^zEIYE{8~Bd5X@b$V5I25nb4ye`dt|CQ^nJfPqM*Ab(t O{L;yP;^@h0`hNhSp-p-K diff --git a/mace-bench/3rdparty/mace/mace/calculators/foundations_models.py b/mace-bench/3rdparty/mace/mace/calculators/foundations_models.py index f4666ea..75e47c1 100644 --- a/mace-bench/3rdparty/mace/mace/calculators/foundations_models.py +++ b/mace-bench/3rdparty/mace/mace/calculators/foundations_models.py @@ -1,339 +1,339 @@ -import os -import urllib.request -from pathlib import Path -from typing import Union - -import torch -from ase import units -from ase.calculators.mixing import SumCalculator - -from .mace import MACECalculator - -module_dir = os.path.dirname(__file__) -local_model_path = os.path.join( - module_dir, "foundations_models/mace-mpa-0-medium.model" -) - - -def download_mace_mp_checkpoint(model: Union[str, Path] = None) -> str: - """ - Downloads or locates the MACE-MP checkpoint file. - - Args: - model (str, optional): Path to the model or size specification. - Defaults to None which uses the medium model. - - Returns: - str: Path to the downloaded (or cached, if previously loaded) checkpoint file. - """ - if model in (None, "medium-mpa-0") and os.path.isfile(local_model_path): - return local_model_path - - urls = { - "small": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2023-12-10-mace-128-L0_energy_epoch-249.model", - "medium": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2023-12-03-mace-128-L1_epoch-199.model", - "large": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/MACE_MPtrj_2022.9.model", - "small-0b": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b/mace_agnesi_small.model", - "medium-0b": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b/mace_agnesi_medium.model", - "small-0b2": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b2/mace-small-density-agnesi-stress.model", - "medium-0b2": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b2/mace-medium-density-agnesi-stress.model", - "large-0b2": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b2/mace-large-density-agnesi-stress.model", - "medium-0b3": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b3/mace-mp-0b3-medium.model", - "medium-mpa-0": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mpa_0/mace-mpa-0-medium.model", - "medium-omat-0": "https://github.com/ACEsuit/mace-mp/releases/download/mace_omat_0/mace-omat-0-medium.model", - "mace-matpes-pbe-0": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_matpes_0/MACE-matpes-pbe-omat-ft.model", - "mace-matpes-r2scan-0": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_matpes_0/MACE-matpes-r2scan-omat-ft.model", - } - - checkpoint_url = ( - urls.get(model, urls["medium-mpa-0"]) - if model - in ( - None, - "small", - "medium", - "large", - "small-0b", - "medium-0b", - "small-0b2", - "medium-0b2", - "large-0b2", - "medium-0b3", - "medium-mpa-0", - "medium-omat-0", - ) - else model - ) - - if checkpoint_url == urls["medium-mpa-0"]: - print( - "Using medium MPA-0 model as default MACE-MP model, to use previous (before 3.10) default model please specify 'medium' as model argument" - ) - ASL_checkpoint_urls = { - urls["medium-omat-0"], - urls["mace-matpes-pbe-0"], - urls["mace-matpes-r2scan-0"], - } - if checkpoint_url in ASL_checkpoint_urls: - print( - "Using model under Academic Software License (ASL) license, see https://github.com/gabor1/ASL \n To use this model you accept the terms of the license." - ) - - cache_dir = os.path.expanduser("~/.cache/mace") - checkpoint_url_name = "".join( - c for c in os.path.basename(checkpoint_url) if c.isalnum() or c in "_" - ) - cached_model_path = f"{cache_dir}/{checkpoint_url_name}" - - if not os.path.isfile(cached_model_path): - os.makedirs(cache_dir, exist_ok=True) - print(f"Downloading MACE model from {checkpoint_url!r}") - _, http_msg = urllib.request.urlretrieve(checkpoint_url, cached_model_path) - if "Content-Type: text/html" in http_msg: - raise RuntimeError( - f"Model download failed, please check the URL {checkpoint_url}" - ) - print(f"Cached MACE model to {cached_model_path}") - - return cached_model_path - - -def mace_mp( - model: Union[str, Path] = None, - device: str = "", - default_dtype: str = "float32", - dispersion: bool = False, - damping: str = "bj", # choices: ["zero", "bj", "zerom", "bjm"] - dispersion_xc: str = "pbe", - dispersion_cutoff: float = 40.0 * units.Bohr, - return_raw_model: bool = False, - **kwargs, -) -> MACECalculator: - """ - Constructs a MACECalculator with a pretrained model based on the Materials Project (89 elements). - The model is released under the MIT license. See https://github.com/ACEsuit/mace-mp for all models. - Note: - If you are using this function, please cite the relevant paper for the Materials Project, - any paper associated with the MACE model, and also the following: - - MACE-MP by Ilyes Batatia, Philipp Benner, Yuan Chiang, Alin M. Elena, - Dávid P. Kovács, Janosh Riebesell, et al., 2023, arXiv:2401.00096 - - MACE-Universal by Yuan Chiang, 2023, Hugging Face, Revision e5ebd9b, - DOI: 10.57967/hf/1202, URL: https://huggingface.co/cyrusyc/mace-universal - - Matbench Discovery by Janosh Riebesell, Rhys EA Goodall, Philipp Benner, Yuan Chiang, - Alpha A Lee, Anubhav Jain, Kristin A Persson, 2023, arXiv:2308.14920 - - Args: - model (str, optional): Path to the model. Defaults to None which first checks for - a local model and then downloads the default model from figshare. Specify "small", - "medium" or "large" to download a smaller or larger model from figshare. - device (str, optional): Device to use for the model. Defaults to "cuda" if available. - default_dtype (str, optional): Default dtype for the model. Defaults to "float32". - dispersion (bool, optional): Whether to use D3 dispersion corrections. Defaults to False. - damping (str): The damping function associated with the D3 correction. Defaults to "bj" for D3(BJ). - dispersion_xc (str, optional): Exchange-correlation functional for D3 dispersion corrections. - dispersion_cutoff (float, optional): Cutoff radius in Bohr for D3 dispersion corrections. - return_raw_model (bool, optional): Whether to return the raw model or an ASE calculator. Defaults to False. - **kwargs: Passed to MACECalculator and TorchDFTD3Calculator. - - Returns: - MACECalculator: trained on the MPtrj dataset (unless model otherwise specified). - """ - try: - if model in ( - None, - "small", - "medium", - "large", - "medium-mpa-0", - "small-0b", - "medium-0b", - "small-0b2", - "medium-0b2", - "medium-0b3", - "large-0b2", - "medium-omat-0", - ) or str(model).startswith("https:"): - model_path = download_mace_mp_checkpoint(model) - print(f"Using Materials Project MACE for MACECalculator with {model_path}") - else: - if not Path(model).exists(): - raise FileNotFoundError(f"{model} not found locally") - model_path = model - except Exception as exc: - raise RuntimeError("Model download failed and no local model found") from exc - - device = device or ("cuda" if torch.cuda.is_available() else "cpu") - if default_dtype == "float64": - print( - "Using float64 for MACECalculator, which is slower but more accurate. Recommended for geometry optimization." - ) - if default_dtype == "float32": - print( - "Using float32 for MACECalculator, which is faster but less accurate. Recommended for MD. Use float64 for geometry optimization." - ) - - if return_raw_model: - return torch.load(model_path, map_location=device) - - mace_calc = MACECalculator( - model_paths=model_path, device=device, default_dtype=default_dtype, **kwargs - ) - - if not dispersion: - return mace_calc - - try: - from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator - except ImportError as exc: - raise RuntimeError( - "Please install torch-dftd to use dispersion corrections (see https://github.com/pfnet-research/torch-dftd)" - ) from exc - - print("Using TorchDFTD3Calculator for D3 dispersion corrections") - dtype = torch.float32 if default_dtype == "float32" else torch.float64 - d3_calc = TorchDFTD3Calculator( - device=device, - damping=damping, - dtype=dtype, - xc=dispersion_xc, - cutoff=dispersion_cutoff, - **kwargs, - ) - - return SumCalculator([mace_calc, d3_calc]) - - -def mace_off( - model: Union[str, Path] = None, - device: str = "", - default_dtype: str = "float64", - return_raw_model: bool = False, - **kwargs, -) -> MACECalculator: - """ - Constructs a MACECalculator with a pretrained model based on the MACE-OFF23 models. - The model is released under the ASL license. - Note: - If you are using this function, please cite the relevant paper by Kovacs et.al., arXiv:2312.15211 - - Args: - model (str, optional): Path to the model. Defaults to None which first checks for - a local model and then downloads the default medium model from https://github.com/ACEsuit/mace-off. - Specify "small", "medium" or "large" to download a smaller or larger model. - device (str, optional): Device to use for the model. Defaults to "cuda". - default_dtype (str, optional): Default dtype for the model. Defaults to "float64". - return_raw_model (bool, optional): Whether to return the raw model or an ASE calculator. Defaults to False. - **kwargs: Passed to MACECalculator. - - Returns: - MACECalculator: trained on the MACE-OFF23 dataset - """ - try: - if model in (None, "small", "medium", "large") or str(model).startswith( - "https:" - ): - urls = dict( - small="https://github.com/ACEsuit/mace-off/blob/main/mace_off23/MACE-OFF23_small.model?raw=true", - medium="https://github.com/ACEsuit/mace-off/raw/main/mace_off23/MACE-OFF23_medium.model?raw=true", - large="https://github.com/ACEsuit/mace-off/blob/main/mace_off23/MACE-OFF23_large.model?raw=true", - ) - checkpoint_url = ( - urls.get(model, urls["medium"]) - if model in (None, "small", "medium", "large") - else model - ) - cache_dir = os.path.expanduser("~/.cache/mace") - checkpoint_url_name = os.path.basename(checkpoint_url).split("?")[0] - cached_model_path = f"{cache_dir}/{checkpoint_url_name}" - if not os.path.isfile(cached_model_path): - os.makedirs(cache_dir, exist_ok=True) - # download and save to disk - print(f"Downloading MACE model from {checkpoint_url!r}") - print( - "The model is distributed under the Academic Software License (ASL) license, see https://github.com/gabor1/ASL \n To use the model you accept the terms of the license." - ) - print( - "ASL is based on the Gnu Public License, but does not permit commercial use" - ) - urllib.request.urlretrieve(checkpoint_url, cached_model_path) - print(f"Cached MACE model to {cached_model_path}") - model = cached_model_path - msg = f"Using MACE-OFF23 MODEL for MACECalculator with {model}" - print(msg) - else: - if not Path(model).exists(): - raise FileNotFoundError(f"{model} not found locally") - except Exception as exc: - raise RuntimeError("Model download failed and no local model found") from exc - - device = device or ("cuda" if torch.cuda.is_available() else "cpu") - - if return_raw_model: - return torch.load(model, map_location=device) - - if default_dtype == "float64": - print( - "Using float64 for MACECalculator, which is slower but more accurate. Recommended for geometry optimization." - ) - if default_dtype == "float32": - print( - "Using float32 for MACECalculator, which is faster but less accurate. Recommended for MD. Use float64 for geometry optimization." - ) - mace_calc = MACECalculator( - model_paths=model, device=device, default_dtype=default_dtype, **kwargs - ) - return mace_calc - - -def mace_anicc( - device: str = "cuda", - model_path: str = None, - return_raw_model: bool = False, -) -> MACECalculator: - """ - Constructs a MACECalculator with a pretrained model based on the ANI (H, C, N, O). - The model is released under the MIT license. - Note: - If you are using this function, please cite the relevant paper associated with the MACE model, ANI dataset, and also the following: - - "Evaluation of the MACE Force Field Architecture by Dávid Péter Kovács, Ilyes Batatia, Eszter Sára Arany, and Gábor Csányi, The Journal of Chemical Physics, 2023, URL: https://doi.org/10.1063/5.0155322 - """ - if model_path is None: - model_path = os.path.join( - module_dir, "foundations_models/ani500k_large_CC.model" - ) - print( - "Using ANI couple cluster model for MACECalculator, see https://doi.org/10.1063/5.0155322" - ) - - if not os.path.exists(model_path): - model_dir = os.path.dirname(model_path) - os.makedirs(model_dir, exist_ok=True) - - # Download the model - print(f"Model not found at {model_path}. Downloading...") - model_url = "https://github.com/ACEsuit/mace/raw/main/mace/calculators/foundations_models/ani500k_large_CC.model" - - try: - - def report_progress(block_num, block_size, total_size): - downloaded = block_num * block_size - percent = min(100, downloaded * 100 / total_size) - if total_size > 0: - print( - f"\rDownloading model: {percent:.1f}% ({downloaded / 1024 / 1024:.1f} MB / {total_size / 1024 / 1024:.1f} MB)", - end="", - ) - - urllib.request.urlretrieve( - model_url, model_path, reporthook=report_progress - ) - print("\nDownload complete!") - - except Exception as e: - raise RuntimeError(f"Failed to download model: {e}") from e - - if return_raw_model: - return torch.load(model_path, map_location=device) - return MACECalculator( - model_paths=model_path, device=device, default_dtype="float64" - ) +import os +import urllib.request +from pathlib import Path +from typing import Union + +import torch +from ase import units +from ase.calculators.mixing import SumCalculator + +from .mace import MACECalculator + +module_dir = os.path.dirname(__file__) +local_model_path = os.path.join( + module_dir, "foundations_models/mace-mpa-0-medium.model" +) + + +def download_mace_mp_checkpoint(model: Union[str, Path] = None) -> str: + """ + Downloads or locates the MACE-MP checkpoint file. + + Args: + model (str, optional): Path to the model or size specification. + Defaults to None which uses the medium model. + + Returns: + str: Path to the downloaded (or cached, if previously loaded) checkpoint file. + """ + if model in (None, "medium-mpa-0") and os.path.isfile(local_model_path): + return local_model_path + + urls = { + "small": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2023-12-10-mace-128-L0_energy_epoch-249.model", + "medium": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2023-12-03-mace-128-L1_epoch-199.model", + "large": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/MACE_MPtrj_2022.9.model", + "small-0b": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b/mace_agnesi_small.model", + "medium-0b": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b/mace_agnesi_medium.model", + "small-0b2": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b2/mace-small-density-agnesi-stress.model", + "medium-0b2": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b2/mace-medium-density-agnesi-stress.model", + "large-0b2": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b2/mace-large-density-agnesi-stress.model", + "medium-0b3": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b3/mace-mp-0b3-medium.model", + "medium-mpa-0": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mpa_0/mace-mpa-0-medium.model", + "medium-omat-0": "https://github.com/ACEsuit/mace-mp/releases/download/mace_omat_0/mace-omat-0-medium.model", + "mace-matpes-pbe-0": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_matpes_0/MACE-matpes-pbe-omat-ft.model", + "mace-matpes-r2scan-0": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_matpes_0/MACE-matpes-r2scan-omat-ft.model", + } + + checkpoint_url = ( + urls.get(model, urls["medium-mpa-0"]) + if model + in ( + None, + "small", + "medium", + "large", + "small-0b", + "medium-0b", + "small-0b2", + "medium-0b2", + "large-0b2", + "medium-0b3", + "medium-mpa-0", + "medium-omat-0", + ) + else model + ) + + if checkpoint_url == urls["medium-mpa-0"]: + print( + "Using medium MPA-0 model as default MACE-MP model, to use previous (before 3.10) default model please specify 'medium' as model argument" + ) + ASL_checkpoint_urls = { + urls["medium-omat-0"], + urls["mace-matpes-pbe-0"], + urls["mace-matpes-r2scan-0"], + } + if checkpoint_url in ASL_checkpoint_urls: + print( + "Using model under Academic Software License (ASL) license, see https://github.com/gabor1/ASL \n To use this model you accept the terms of the license." + ) + + cache_dir = os.path.expanduser("~/.cache/mace") + checkpoint_url_name = "".join( + c for c in os.path.basename(checkpoint_url) if c.isalnum() or c in "_" + ) + cached_model_path = f"{cache_dir}/{checkpoint_url_name}" + + if not os.path.isfile(cached_model_path): + os.makedirs(cache_dir, exist_ok=True) + print(f"Downloading MACE model from {checkpoint_url!r}") + _, http_msg = urllib.request.urlretrieve(checkpoint_url, cached_model_path) + if "Content-Type: text/html" in http_msg: + raise RuntimeError( + f"Model download failed, please check the URL {checkpoint_url}" + ) + print(f"Cached MACE model to {cached_model_path}") + + return cached_model_path + + +def mace_mp( + model: Union[str, Path] = None, + device: str = "", + default_dtype: str = "float32", + dispersion: bool = False, + damping: str = "bj", # choices: ["zero", "bj", "zerom", "bjm"] + dispersion_xc: str = "pbe", + dispersion_cutoff: float = 40.0 * units.Bohr, + return_raw_model: bool = False, + **kwargs, +) -> MACECalculator: + """ + Constructs a MACECalculator with a pretrained model based on the Materials Project (89 elements). + The model is released under the MIT license. See https://github.com/ACEsuit/mace-mp for all models. + Note: + If you are using this function, please cite the relevant paper for the Materials Project, + any paper associated with the MACE model, and also the following: + - MACE-MP by Ilyes Batatia, Philipp Benner, Yuan Chiang, Alin M. Elena, + Dávid P. Kovács, Janosh Riebesell, et al., 2023, arXiv:2401.00096 + - MACE-Universal by Yuan Chiang, 2023, Hugging Face, Revision e5ebd9b, + DOI: 10.57967/hf/1202, URL: https://huggingface.co/cyrusyc/mace-universal + - Matbench Discovery by Janosh Riebesell, Rhys EA Goodall, Philipp Benner, Yuan Chiang, + Alpha A Lee, Anubhav Jain, Kristin A Persson, 2023, arXiv:2308.14920 + + Args: + model (str, optional): Path to the model. Defaults to None which first checks for + a local model and then downloads the default model from figshare. Specify "small", + "medium" or "large" to download a smaller or larger model from figshare. + device (str, optional): Device to use for the model. Defaults to "cuda" if available. + default_dtype (str, optional): Default dtype for the model. Defaults to "float32". + dispersion (bool, optional): Whether to use D3 dispersion corrections. Defaults to False. + damping (str): The damping function associated with the D3 correction. Defaults to "bj" for D3(BJ). + dispersion_xc (str, optional): Exchange-correlation functional for D3 dispersion corrections. + dispersion_cutoff (float, optional): Cutoff radius in Bohr for D3 dispersion corrections. + return_raw_model (bool, optional): Whether to return the raw model or an ASE calculator. Defaults to False. + **kwargs: Passed to MACECalculator and TorchDFTD3Calculator. + + Returns: + MACECalculator: trained on the MPtrj dataset (unless model otherwise specified). + """ + try: + if model in ( + None, + "small", + "medium", + "large", + "medium-mpa-0", + "small-0b", + "medium-0b", + "small-0b2", + "medium-0b2", + "medium-0b3", + "large-0b2", + "medium-omat-0", + ) or str(model).startswith("https:"): + model_path = download_mace_mp_checkpoint(model) + print(f"Using Materials Project MACE for MACECalculator with {model_path}") + else: + if not Path(model).exists(): + raise FileNotFoundError(f"{model} not found locally") + model_path = model + except Exception as exc: + raise RuntimeError("Model download failed and no local model found") from exc + + device = device or ("cuda" if torch.cuda.is_available() else "cpu") + if default_dtype == "float64": + print( + "Using float64 for MACECalculator, which is slower but more accurate. Recommended for geometry optimization." + ) + if default_dtype == "float32": + print( + "Using float32 for MACECalculator, which is faster but less accurate. Recommended for MD. Use float64 for geometry optimization." + ) + + if return_raw_model: + return torch.load(model_path, map_location=device) + + mace_calc = MACECalculator( + model_paths=model_path, device=device, default_dtype=default_dtype, **kwargs + ) + + if not dispersion: + return mace_calc + + try: + from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator + except ImportError as exc: + raise RuntimeError( + "Please install torch-dftd to use dispersion corrections (see https://github.com/pfnet-research/torch-dftd)" + ) from exc + + print("Using TorchDFTD3Calculator for D3 dispersion corrections") + dtype = torch.float32 if default_dtype == "float32" else torch.float64 + d3_calc = TorchDFTD3Calculator( + device=device, + damping=damping, + dtype=dtype, + xc=dispersion_xc, + cutoff=dispersion_cutoff, + **kwargs, + ) + + return SumCalculator([mace_calc, d3_calc]) + + +def mace_off( + model: Union[str, Path] = None, + device: str = "", + default_dtype: str = "float64", + return_raw_model: bool = False, + **kwargs, +) -> MACECalculator: + """ + Constructs a MACECalculator with a pretrained model based on the MACE-OFF23 models. + The model is released under the ASL license. + Note: + If you are using this function, please cite the relevant paper by Kovacs et.al., arXiv:2312.15211 + + Args: + model (str, optional): Path to the model. Defaults to None which first checks for + a local model and then downloads the default medium model from https://github.com/ACEsuit/mace-off. + Specify "small", "medium" or "large" to download a smaller or larger model. + device (str, optional): Device to use for the model. Defaults to "cuda". + default_dtype (str, optional): Default dtype for the model. Defaults to "float64". + return_raw_model (bool, optional): Whether to return the raw model or an ASE calculator. Defaults to False. + **kwargs: Passed to MACECalculator. + + Returns: + MACECalculator: trained on the MACE-OFF23 dataset + """ + try: + if model in (None, "small", "medium", "large") or str(model).startswith( + "https:" + ): + urls = dict( + small="https://github.com/ACEsuit/mace-off/blob/main/mace_off23/MACE-OFF23_small.model?raw=true", + medium="https://github.com/ACEsuit/mace-off/raw/main/mace_off23/MACE-OFF23_medium.model?raw=true", + large="https://github.com/ACEsuit/mace-off/blob/main/mace_off23/MACE-OFF23_large.model?raw=true", + ) + checkpoint_url = ( + urls.get(model, urls["medium"]) + if model in (None, "small", "medium", "large") + else model + ) + cache_dir = os.path.expanduser("~/.cache/mace") + checkpoint_url_name = os.path.basename(checkpoint_url).split("?")[0] + cached_model_path = f"{cache_dir}/{checkpoint_url_name}" + if not os.path.isfile(cached_model_path): + os.makedirs(cache_dir, exist_ok=True) + # download and save to disk + print(f"Downloading MACE model from {checkpoint_url!r}") + print( + "The model is distributed under the Academic Software License (ASL) license, see https://github.com/gabor1/ASL \n To use the model you accept the terms of the license." + ) + print( + "ASL is based on the Gnu Public License, but does not permit commercial use" + ) + urllib.request.urlretrieve(checkpoint_url, cached_model_path) + print(f"Cached MACE model to {cached_model_path}") + model = cached_model_path + msg = f"Using MACE-OFF23 MODEL for MACECalculator with {model}" + print(msg) + else: + if not Path(model).exists(): + raise FileNotFoundError(f"{model} not found locally") + except Exception as exc: + raise RuntimeError("Model download failed and no local model found") from exc + + device = device or ("cuda" if torch.cuda.is_available() else "cpu") + + if return_raw_model: + return torch.load(model, map_location=device) + + if default_dtype == "float64": + print( + "Using float64 for MACECalculator, which is slower but more accurate. Recommended for geometry optimization." + ) + if default_dtype == "float32": + print( + "Using float32 for MACECalculator, which is faster but less accurate. Recommended for MD. Use float64 for geometry optimization." + ) + mace_calc = MACECalculator( + model_paths=model, device=device, default_dtype=default_dtype, **kwargs + ) + return mace_calc + + +def mace_anicc( + device: str = "cuda", + model_path: str = None, + return_raw_model: bool = False, +) -> MACECalculator: + """ + Constructs a MACECalculator with a pretrained model based on the ANI (H, C, N, O). + The model is released under the MIT license. + Note: + If you are using this function, please cite the relevant paper associated with the MACE model, ANI dataset, and also the following: + - "Evaluation of the MACE Force Field Architecture by Dávid Péter Kovács, Ilyes Batatia, Eszter Sára Arany, and Gábor Csányi, The Journal of Chemical Physics, 2023, URL: https://doi.org/10.1063/5.0155322 + """ + if model_path is None: + model_path = os.path.join( + module_dir, "foundations_models/ani500k_large_CC.model" + ) + print( + "Using ANI couple cluster model for MACECalculator, see https://doi.org/10.1063/5.0155322" + ) + + if not os.path.exists(model_path): + model_dir = os.path.dirname(model_path) + os.makedirs(model_dir, exist_ok=True) + + # Download the model + print(f"Model not found at {model_path}. Downloading...") + model_url = "https://github.com/ACEsuit/mace/raw/main/mace/calculators/foundations_models/ani500k_large_CC.model" + + try: + + def report_progress(block_num, block_size, total_size): + downloaded = block_num * block_size + percent = min(100, downloaded * 100 / total_size) + if total_size > 0: + print( + f"\rDownloading model: {percent:.1f}% ({downloaded / 1024 / 1024:.1f} MB / {total_size / 1024 / 1024:.1f} MB)", + end="", + ) + + urllib.request.urlretrieve( + model_url, model_path, reporthook=report_progress + ) + print("\nDownload complete!") + + except Exception as e: + raise RuntimeError(f"Failed to download model: {e}") from e + + if return_raw_model: + return torch.load(model_path, map_location=device) + return MACECalculator( + model_paths=model_path, device=device, default_dtype="float64" + ) diff --git a/mace-bench/3rdparty/mace/mace/calculators/lammps_mace.py b/mace-bench/3rdparty/mace/mace/calculators/lammps_mace.py index 4a1edc0..4211c37 100644 --- a/mace-bench/3rdparty/mace/mace/calculators/lammps_mace.py +++ b/mace-bench/3rdparty/mace/mace/calculators/lammps_mace.py @@ -1,105 +1,105 @@ -from typing import Dict, List, Optional - -import torch -from e3nn.util.jit import compile_mode - -from mace.tools.scatter import scatter_sum - - -@compile_mode("script") -class LAMMPS_MACE(torch.nn.Module): - def __init__(self, model, **kwargs): - super().__init__() - self.model = model - self.register_buffer("atomic_numbers", model.atomic_numbers) - self.register_buffer("r_max", model.r_max) - self.register_buffer("num_interactions", model.num_interactions) - if not hasattr(model, "heads"): - model.heads = [None] - self.register_buffer( - "head", - torch.tensor( - self.model.heads.index(kwargs.get("head", self.model.heads[-1])), - dtype=torch.long, - ).unsqueeze(0), - ) - - for param in self.model.parameters(): - param.requires_grad = False - - def forward( - self, - data: Dict[str, torch.Tensor], - local_or_ghost: torch.Tensor, - compute_virials: bool = False, - ) -> Dict[str, Optional[torch.Tensor]]: - num_graphs = data["ptr"].numel() - 1 - compute_displacement = False - if compute_virials: - compute_displacement = True - data["head"] = self.head - out = self.model( - data, - training=False, - compute_force=False, - compute_virials=False, - compute_stress=False, - compute_displacement=compute_displacement, - ) - node_energy = out["node_energy"] - if node_energy is None: - return { - "total_energy_local": None, - "node_energy": None, - "forces": None, - "virials": None, - } - positions = data["positions"] - displacement = out["displacement"] - forces: Optional[torch.Tensor] = torch.zeros_like(positions) - virials: Optional[torch.Tensor] = torch.zeros_like(data["cell"]) - # accumulate energies of local atoms - node_energy_local = node_energy * local_or_ghost - total_energy_local = scatter_sum( - src=node_energy_local, index=data["batch"], dim=-1, dim_size=num_graphs - ) - # compute partial forces and (possibly) partial virials - grad_outputs: List[Optional[torch.Tensor]] = [ - torch.ones_like(total_energy_local) - ] - if compute_virials and displacement is not None: - forces, virials = torch.autograd.grad( - outputs=[total_energy_local], - inputs=[positions, displacement], - grad_outputs=grad_outputs, - retain_graph=False, - create_graph=False, - allow_unused=True, - ) - if forces is not None: - forces = -1 * forces - else: - forces = torch.zeros_like(positions) - if virials is not None: - virials = -1 * virials - else: - virials = torch.zeros_like(displacement) - else: - forces = torch.autograd.grad( - outputs=[total_energy_local], - inputs=[positions], - grad_outputs=grad_outputs, - retain_graph=False, - create_graph=False, - allow_unused=True, - )[0] - if forces is not None: - forces = -1 * forces - else: - forces = torch.zeros_like(positions) - return { - "total_energy_local": total_energy_local, - "node_energy": node_energy, - "forces": forces, - "virials": virials, - } +from typing import Dict, List, Optional + +import torch +from e3nn.util.jit import compile_mode + +from mace.tools.scatter import scatter_sum + + +@compile_mode("script") +class LAMMPS_MACE(torch.nn.Module): + def __init__(self, model, **kwargs): + super().__init__() + self.model = model + self.register_buffer("atomic_numbers", model.atomic_numbers) + self.register_buffer("r_max", model.r_max) + self.register_buffer("num_interactions", model.num_interactions) + if not hasattr(model, "heads"): + model.heads = [None] + self.register_buffer( + "head", + torch.tensor( + self.model.heads.index(kwargs.get("head", self.model.heads[-1])), + dtype=torch.long, + ).unsqueeze(0), + ) + + for param in self.model.parameters(): + param.requires_grad = False + + def forward( + self, + data: Dict[str, torch.Tensor], + local_or_ghost: torch.Tensor, + compute_virials: bool = False, + ) -> Dict[str, Optional[torch.Tensor]]: + num_graphs = data["ptr"].numel() - 1 + compute_displacement = False + if compute_virials: + compute_displacement = True + data["head"] = self.head + out = self.model( + data, + training=False, + compute_force=False, + compute_virials=False, + compute_stress=False, + compute_displacement=compute_displacement, + ) + node_energy = out["node_energy"] + if node_energy is None: + return { + "total_energy_local": None, + "node_energy": None, + "forces": None, + "virials": None, + } + positions = data["positions"] + displacement = out["displacement"] + forces: Optional[torch.Tensor] = torch.zeros_like(positions) + virials: Optional[torch.Tensor] = torch.zeros_like(data["cell"]) + # accumulate energies of local atoms + node_energy_local = node_energy * local_or_ghost + total_energy_local = scatter_sum( + src=node_energy_local, index=data["batch"], dim=-1, dim_size=num_graphs + ) + # compute partial forces and (possibly) partial virials + grad_outputs: List[Optional[torch.Tensor]] = [ + torch.ones_like(total_energy_local) + ] + if compute_virials and displacement is not None: + forces, virials = torch.autograd.grad( + outputs=[total_energy_local], + inputs=[positions, displacement], + grad_outputs=grad_outputs, + retain_graph=False, + create_graph=False, + allow_unused=True, + ) + if forces is not None: + forces = -1 * forces + else: + forces = torch.zeros_like(positions) + if virials is not None: + virials = -1 * virials + else: + virials = torch.zeros_like(displacement) + else: + forces = torch.autograd.grad( + outputs=[total_energy_local], + inputs=[positions], + grad_outputs=grad_outputs, + retain_graph=False, + create_graph=False, + allow_unused=True, + )[0] + if forces is not None: + forces = -1 * forces + else: + forces = torch.zeros_like(positions) + return { + "total_energy_local": total_energy_local, + "node_energy": node_energy, + "forces": forces, + "virials": virials, + } diff --git a/mace-bench/3rdparty/mace/mace/calculators/lammps_mliap_mace.py b/mace-bench/3rdparty/mace/mace/calculators/lammps_mliap_mace.py index 036931c..f40e9b8 100644 --- a/mace-bench/3rdparty/mace/mace/calculators/lammps_mliap_mace.py +++ b/mace-bench/3rdparty/mace/mace/calculators/lammps_mliap_mace.py @@ -1,214 +1,214 @@ -import logging -import os -import sys -import time -from contextlib import contextmanager -from typing import Dict, Tuple - -import torch -from ase.data import chemical_symbols -from e3nn.util.jit import compile_mode - -try: - from lammps.mliap.mliap_unified_abc import MLIAPUnified -except ImportError: - - class MLIAPUnified: - def __init__(self): - pass - - -class MACELammpsConfig: - """Configuration settings for MACE-LAMMPS integration.""" - - def __init__(self): - self.debug_time = self._get_env_bool("MACE_TIME", False) - self.debug_profile = self._get_env_bool("MACE_PROFILE", False) - self.profile_start_step = int(os.environ.get("MACE_PROFILE_START", "5")) - self.profile_end_step = int(os.environ.get("MACE_PROFILE_END", "10")) - self.allow_cpu = self._get_env_bool("MACE_ALLOW_CPU", False) - self.force_cpu = self._get_env_bool("MACE_FORCE_CPU", False) - - @staticmethod - def _get_env_bool(var_name: str, default: bool) -> bool: - return os.environ.get(var_name, str(default)).lower() in ( - "true", - "1", - "t", - "yes", - ) - - -@contextmanager -def timer(name: str, enabled: bool = True): - """Context manager for timing code blocks.""" - if not enabled: - yield - return - - start = time.perf_counter() - try: - yield - finally: - elapsed = time.perf_counter() - start - logging.info(f"Timer - {name}: {elapsed*1000:.3f} ms") - - -@compile_mode("script") -class MACEEdgeForcesWrapper(torch.nn.Module): - """Wrapper that adds per-pair force computation to a MACE model.""" - - def __init__(self, model: torch.nn.Module, **kwargs): - super().__init__() - self.model = model - self.register_buffer("atomic_numbers", model.atomic_numbers) - self.register_buffer("r_max", model.r_max) - self.register_buffer("num_interactions", model.num_interactions) - - if not hasattr(model, "heads"): - model.heads = ["Default"] - - head_name = kwargs.get("head", model.heads[-1]) - head_idx = model.heads.index(head_name) - self.register_buffer("head", torch.tensor([head_idx], dtype=torch.long)) - - for p in self.model.parameters(): - p.requires_grad = False - - def forward( - self, data: Dict[str, torch.Tensor] - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Compute energies and per-pair forces.""" - data["head"] = self.head - - out = self.model( - data, - training=False, - compute_force=False, - compute_virials=False, - compute_stress=False, - compute_displacement=False, - compute_edge_forces=True, - lammps_mliap=True, - ) - - node_energy = out["node_energy"] - pair_forces = out["edge_forces"] - total_energy = out["energy"][0] - - if pair_forces is None: - pair_forces = torch.zeros_like(data["vectors"]) - - return total_energy, node_energy, pair_forces - - -class LAMMPS_MLIAP_MACE(MLIAPUnified): - """MACE integration for LAMMPS using the MLIAP interface.""" - - def __init__(self, model, **kwargs): - super().__init__() - self.config = MACELammpsConfig() - self.model = MACEEdgeForcesWrapper(model, **kwargs) - self.element_types = [chemical_symbols[s] for s in model.atomic_numbers] - self.num_species = len(self.element_types) - self.rcutfac = 0.5 * float(model.r_max) - self.ndescriptors = 1 - self.nparams = 1 - self.dtype = model.r_max.dtype - self.device = "cpu" - self.initialized = False - self.step = 0 - - def _initialize_device(self, data): - using_kokkos = "kokkos" in data.__class__.__module__.lower() - - if using_kokkos and not self.config.force_cpu: - device = torch.as_tensor(data.elems).device - if device.type == "cpu" and not self.config.allow_cpu: - raise ValueError( - "GPU requested but tensor is on CPU. Set MACE_ALLOW_CPU=true to allow CPU computation." - ) - else: - device = torch.device("cpu") - - self.device = device - self.model = self.model.to(device) - logging.info(f"MACE model initialized on device: {device}") - self.initialized = True - - def compute_forces(self, data): - natoms = data.nlocal - ntotal = data.ntotal - nghosts = ntotal - natoms - npairs = data.npairs - species = torch.as_tensor(data.elems, dtype=torch.int64) - - if not self.initialized: - self._initialize_device(data) - - self.step += 1 - self._manage_profiling() - - if natoms == 0 or npairs <= 1: - return - - with timer("total_step", enabled=self.config.debug_time): - with timer("prepare_batch", enabled=self.config.debug_time): - batch = self._prepare_batch(data, natoms, nghosts, species) - - with timer("model_forward", enabled=self.config.debug_time): - _, atom_energies, pair_forces = self.model(batch) - - if self.device.type != "cpu": - torch.cuda.synchronize() - - with timer("update_lammps", enabled=self.config.debug_time): - self._update_lammps_data(data, atom_energies, pair_forces, natoms) - - def _prepare_batch(self, data, natoms, nghosts, species): - """Prepare the input batch for the MACE model.""" - return { - "vectors": torch.as_tensor(data.rij).to(self.dtype).to(self.device), - "node_attrs": torch.nn.functional.one_hot( - species.to(self.device), num_classes=self.num_species - ).to(self.dtype), - "edge_index": torch.stack( - [ - torch.as_tensor(data.pair_j, dtype=torch.int64).to(self.device), - torch.as_tensor(data.pair_i, dtype=torch.int64).to(self.device), - ], - dim=0, - ), - "batch": torch.zeros(natoms, dtype=torch.int64, device=self.device), - "lammps_class": data, - "natoms": (natoms, nghosts), - } - - def _update_lammps_data(self, data, atom_energies, pair_forces, natoms): - """Update LAMMPS data structures with computed energies and forces.""" - if self.dtype == torch.float32: - pair_forces = pair_forces.double() - eatoms = torch.as_tensor(data.eatoms) - eatoms.copy_(atom_energies[:natoms]) - data.energy = torch.sum(atom_energies[:natoms]) - data.update_pair_forces_gpu(pair_forces) - - def _manage_profiling(self): - if not self.config.debug_profile: - return - - if self.step == self.config.profile_start_step: - logging.info(f"Starting CUDA profiler at step {self.step}") - torch.cuda.profiler.start() - - if self.step == self.config.profile_end_step: - logging.info(f"Stopping CUDA profiler at step {self.step}") - torch.cuda.profiler.stop() - logging.info("Profiling complete. Exiting.") - sys.exit() - - def compute_descriptors(self, data): - pass - - def compute_gradients(self, data): - pass +import logging +import os +import sys +import time +from contextlib import contextmanager +from typing import Dict, Tuple + +import torch +from ase.data import chemical_symbols +from e3nn.util.jit import compile_mode + +try: + from lammps.mliap.mliap_unified_abc import MLIAPUnified +except ImportError: + + class MLIAPUnified: + def __init__(self): + pass + + +class MACELammpsConfig: + """Configuration settings for MACE-LAMMPS integration.""" + + def __init__(self): + self.debug_time = self._get_env_bool("MACE_TIME", False) + self.debug_profile = self._get_env_bool("MACE_PROFILE", False) + self.profile_start_step = int(os.environ.get("MACE_PROFILE_START", "5")) + self.profile_end_step = int(os.environ.get("MACE_PROFILE_END", "10")) + self.allow_cpu = self._get_env_bool("MACE_ALLOW_CPU", False) + self.force_cpu = self._get_env_bool("MACE_FORCE_CPU", False) + + @staticmethod + def _get_env_bool(var_name: str, default: bool) -> bool: + return os.environ.get(var_name, str(default)).lower() in ( + "true", + "1", + "t", + "yes", + ) + + +@contextmanager +def timer(name: str, enabled: bool = True): + """Context manager for timing code blocks.""" + if not enabled: + yield + return + + start = time.perf_counter() + try: + yield + finally: + elapsed = time.perf_counter() - start + logging.info(f"Timer - {name}: {elapsed*1000:.3f} ms") + + +@compile_mode("script") +class MACEEdgeForcesWrapper(torch.nn.Module): + """Wrapper that adds per-pair force computation to a MACE model.""" + + def __init__(self, model: torch.nn.Module, **kwargs): + super().__init__() + self.model = model + self.register_buffer("atomic_numbers", model.atomic_numbers) + self.register_buffer("r_max", model.r_max) + self.register_buffer("num_interactions", model.num_interactions) + + if not hasattr(model, "heads"): + model.heads = ["Default"] + + head_name = kwargs.get("head", model.heads[-1]) + head_idx = model.heads.index(head_name) + self.register_buffer("head", torch.tensor([head_idx], dtype=torch.long)) + + for p in self.model.parameters(): + p.requires_grad = False + + def forward( + self, data: Dict[str, torch.Tensor] + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute energies and per-pair forces.""" + data["head"] = self.head + + out = self.model( + data, + training=False, + compute_force=False, + compute_virials=False, + compute_stress=False, + compute_displacement=False, + compute_edge_forces=True, + lammps_mliap=True, + ) + + node_energy = out["node_energy"] + pair_forces = out["edge_forces"] + total_energy = out["energy"][0] + + if pair_forces is None: + pair_forces = torch.zeros_like(data["vectors"]) + + return total_energy, node_energy, pair_forces + + +class LAMMPS_MLIAP_MACE(MLIAPUnified): + """MACE integration for LAMMPS using the MLIAP interface.""" + + def __init__(self, model, **kwargs): + super().__init__() + self.config = MACELammpsConfig() + self.model = MACEEdgeForcesWrapper(model, **kwargs) + self.element_types = [chemical_symbols[s] for s in model.atomic_numbers] + self.num_species = len(self.element_types) + self.rcutfac = 0.5 * float(model.r_max) + self.ndescriptors = 1 + self.nparams = 1 + self.dtype = model.r_max.dtype + self.device = "cpu" + self.initialized = False + self.step = 0 + + def _initialize_device(self, data): + using_kokkos = "kokkos" in data.__class__.__module__.lower() + + if using_kokkos and not self.config.force_cpu: + device = torch.as_tensor(data.elems).device + if device.type == "cpu" and not self.config.allow_cpu: + raise ValueError( + "GPU requested but tensor is on CPU. Set MACE_ALLOW_CPU=true to allow CPU computation." + ) + else: + device = torch.device("cpu") + + self.device = device + self.model = self.model.to(device) + logging.info(f"MACE model initialized on device: {device}") + self.initialized = True + + def compute_forces(self, data): + natoms = data.nlocal + ntotal = data.ntotal + nghosts = ntotal - natoms + npairs = data.npairs + species = torch.as_tensor(data.elems, dtype=torch.int64) + + if not self.initialized: + self._initialize_device(data) + + self.step += 1 + self._manage_profiling() + + if natoms == 0 or npairs <= 1: + return + + with timer("total_step", enabled=self.config.debug_time): + with timer("prepare_batch", enabled=self.config.debug_time): + batch = self._prepare_batch(data, natoms, nghosts, species) + + with timer("model_forward", enabled=self.config.debug_time): + _, atom_energies, pair_forces = self.model(batch) + + if self.device.type != "cpu": + torch.cuda.synchronize() + + with timer("update_lammps", enabled=self.config.debug_time): + self._update_lammps_data(data, atom_energies, pair_forces, natoms) + + def _prepare_batch(self, data, natoms, nghosts, species): + """Prepare the input batch for the MACE model.""" + return { + "vectors": torch.as_tensor(data.rij).to(self.dtype).to(self.device), + "node_attrs": torch.nn.functional.one_hot( + species.to(self.device), num_classes=self.num_species + ).to(self.dtype), + "edge_index": torch.stack( + [ + torch.as_tensor(data.pair_j, dtype=torch.int64).to(self.device), + torch.as_tensor(data.pair_i, dtype=torch.int64).to(self.device), + ], + dim=0, + ), + "batch": torch.zeros(natoms, dtype=torch.int64, device=self.device), + "lammps_class": data, + "natoms": (natoms, nghosts), + } + + def _update_lammps_data(self, data, atom_energies, pair_forces, natoms): + """Update LAMMPS data structures with computed energies and forces.""" + if self.dtype == torch.float32: + pair_forces = pair_forces.double() + eatoms = torch.as_tensor(data.eatoms) + eatoms.copy_(atom_energies[:natoms]) + data.energy = torch.sum(atom_energies[:natoms]) + data.update_pair_forces_gpu(pair_forces) + + def _manage_profiling(self): + if not self.config.debug_profile: + return + + if self.step == self.config.profile_start_step: + logging.info(f"Starting CUDA profiler at step {self.step}") + torch.cuda.profiler.start() + + if self.step == self.config.profile_end_step: + logging.info(f"Stopping CUDA profiler at step {self.step}") + torch.cuda.profiler.stop() + logging.info("Profiling complete. Exiting.") + sys.exit() + + def compute_descriptors(self, data): + pass + + def compute_gradients(self, data): + pass diff --git a/mace-bench/3rdparty/mace/mace/calculators/mace.py b/mace-bench/3rdparty/mace/mace/calculators/mace.py index 31dbeb1..794065b 100644 --- a/mace-bench/3rdparty/mace/mace/calculators/mace.py +++ b/mace-bench/3rdparty/mace/mace/calculators/mace.py @@ -1,705 +1,705 @@ -########################################################################################### -# The ASE Calculator for MACE -# Authors: Ilyes Batatia, David Kovacs -# This program is distributed under the MIT License (see MIT.md) -########################################################################################### - -import logging - -# pylint: disable=wrong-import-position -import os -from glob import glob -from pathlib import Path -from typing import List, Union - -os.environ["TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD"] = "1" - -import numpy as np -import torch -from ase.calculators.calculator import Calculator, all_changes -from ase.stress import full_3x3_to_voigt_6_stress -from e3nn import o3 - -from mace import data -from mace.cli.convert_e3nn_cueq import run as run_e3nn_to_cueq -from mace.modules.utils import extract_invariant -from mace.tools import torch_geometric, torch_tools, utils -from mace.tools.compile import prepare -from mace.tools.scripts_utils import extract_model -import random -from mace.tools.torch_geometric.batch import Batch - -from mace.tools import ( - atomic_numbers_to_indices, - to_one_hot, -) - -import time - - -def get_model_dtype(model: torch.nn.Module) -> torch.dtype: - """Get the dtype of the model""" - mode_dtype = next(model.parameters()).dtype - if mode_dtype == torch.float64: - return "float64" - if mode_dtype == torch.float32: - return "float32" - raise ValueError(f"Unknown dtype {mode_dtype}") - - -class MACECalculator(Calculator): - """MACE ASE Calculator - args: - model_paths: str, path to model or models if a committee is produced - to make a committee use a wild card notation like mace_*.model - device: str, device to run on (cuda or cpu) - energy_units_to_eV: float, conversion factor from model energy units to eV - length_units_to_A: float, conversion factor from model length units to Angstroms - default_dtype: str, default dtype of model - charges_key: str, Array field of atoms object where atomic charges are stored - model_type: str, type of model to load - Options: [MACE, DipoleMACE, EnergyDipoleMACE] - - Dipoles are returned in units of Debye - """ - - def __init__( - self, - model_paths: Union[list, str, None] = None, - models: Union[List[torch.nn.Module], torch.nn.Module, None] = None, - device: str = "cpu", - energy_units_to_eV: float = 1.0, - length_units_to_A: float = 1.0, - default_dtype="", - charges_key="Qs", - model_type="MACE", - compile_mode=None, - fullgraph=True, - enable_cueq=False, - **kwargs, - ): - Calculator.__init__(self, **kwargs) - self.device = device - self.dtype=None - if enable_cueq: - assert model_type == "MACE", "CuEq only supports MACE models" - compile_mode = None - if "model_path" in kwargs: - deprecation_message = ( - "'model_path' argument is deprecated, please use 'model_paths'" - ) - if model_paths is None: - logging.warning(f"{deprecation_message} in the future.") - model_paths = kwargs["model_path"] - else: - raise ValueError( - f"both 'model_path' and 'model_paths' given, {deprecation_message} only." - ) - - if (model_paths is None) == (models is None): - raise ValueError( - "Exactly one of 'model_paths' or 'models' must be provided" - ) - - self.results = {} - - self.model_type = model_type - self.compute_atomic_stresses = False - - if model_type == "MACE": - self.implemented_properties = [ - "energy", - "free_energy", - "node_energy", - "forces", - "stress", - ] - if kwargs.get("compute_atomic_stresses", False): - self.implemented_properties.extend(["stresses", "virials"]) - self.compute_atomic_stresses = True - elif model_type == "DipoleMACE": - self.implemented_properties = ["dipole"] - elif model_type == "EnergyDipoleMACE": - self.implemented_properties = [ - "energy", - "free_energy", - "node_energy", - "forces", - "stress", - "dipole", - ] - else: - raise ValueError( - f"Give a valid model_type: [MACE, DipoleMACE, EnergyDipoleMACE], {model_type} not supported" - ) - - if model_paths is not None: - if isinstance(model_paths, str): - # Find all models that satisfy the wildcard (e.g. mace_model_*.pt) - model_paths_glob = glob(model_paths) - - if len(model_paths_glob) == 0: - raise ValueError(f"Couldn't find MACE model files: {model_paths}") - - model_paths = model_paths_glob - elif isinstance(model_paths, Path): - model_paths = [model_paths] - - if len(model_paths) == 0: - raise ValueError("No mace file names supplied") - self.num_models = len(model_paths) - - # Load models from files - self.models = [ - torch.load(f=model_path, map_location=device) - for model_path in model_paths - ] - - elif models is not None: - if not isinstance(models, list): - models = [models] - - if len(models) == 0: - raise ValueError("No models supplied") - - self.models = models - self.num_models = len(models) - - if self.num_models > 1: - print(f"Running committee mace with {self.num_models} models") - - if model_type in ["MACE", "EnergyDipoleMACE"]: - self.implemented_properties.extend( - ["energies", "energy_var", "forces_comm", "stress_var"] - ) - elif model_type == "DipoleMACE": - self.implemented_properties.extend(["dipole_var"]) - - if compile_mode is not None: - print(f"Torch compile is enabled with mode: {compile_mode}") - self.models = [ - torch.compile( - prepare(extract_model)(model=model, map_location=device), - mode=compile_mode, - fullgraph=fullgraph, - ) - for model in self.models - ] - self.use_compile = True - else: - self.use_compile = False - - # Ensure all models are on the same device - for model in self.models: - model.to(device) - - r_maxs = [model.r_max.cpu() for model in self.models] - r_maxs = np.array(r_maxs) - if not np.all(r_maxs == r_maxs[0]): - raise ValueError(f"committee r_max are not all the same {' '.join(r_maxs)}") - self.r_max = float(r_maxs[0]) - - self.device = torch_tools.init_device(device) - self.energy_units_to_eV = energy_units_to_eV - self.length_units_to_A = length_units_to_A - self.z_table = utils.AtomicNumberTable( - [int(z) for z in self.models[0].atomic_numbers] - ) - self.charges_key = charges_key - - try: - self.available_heads: List[str] = self.models[0].heads # type: ignore - except AttributeError: - self.available_heads = ["Default"] - kwarg_head = kwargs.get("head", None) - if kwarg_head is not None: - self.head = kwarg_head - else: - self.head = [head for head in self.available_heads if head.lower() == "default"] - if len(self.head) == 0: - raise ValueError( - "Head keyword was not provided, and no head in the model is 'default'. " - "Please provide a head keyword to specify the head you want to use. " - f"Available heads are: {self.available_heads}" - ) - self.head = self.head[0] - - print("Using head", self.head, "out of", self.available_heads) - - model_dtype = get_model_dtype(self.models[0]) - if default_dtype == "": - print( - f"No dtype selected, switching to {model_dtype} to match model dtype." - ) - default_dtype = model_dtype - if model_dtype != default_dtype: - print( - f"Default dtype {default_dtype} does not match model dtype {model_dtype}, converting models to {default_dtype}." - ) - if default_dtype == "float64": - self.models = [model.double() for model in self.models] - elif default_dtype == "float32": - self.models = [model.float() for model in self.models] - torch_tools.set_default_dtype(default_dtype) - if enable_cueq: - print("Converting models to CuEq for acceleration") - self.models = [ - run_e3nn_to_cueq(model, device=device).to(device) - for model in self.models - ] - for model in self.models: - for param in model.parameters(): - param.requires_grad = False - - self.dtype = torch.float64 if default_dtype == "float64" else torch.float32 - - self.model_time = 0.0 - self.calc_time = 0.0 - - def _create_result_tensors( - self, model_type: str, num_models: int, num_atoms: int - ) -> dict: - """ - Create tensors to store the results of the committee - :param model_type: str, type of model to load - Options: [MACE, DipoleMACE, EnergyDipoleMACE] - :param num_models: int, number of models in the committee - :return: tuple of torch tensors - """ - dict_of_tensors = {} - if model_type in ["MACE", "EnergyDipoleMACE"]: - energies = torch.zeros(num_models, device=self.device) - node_energy = torch.zeros(num_models, num_atoms, device=self.device) - forces = torch.zeros(num_models, num_atoms, 3, device=self.device) - stress = torch.zeros(num_models, 3, 3, device=self.device) - dict_of_tensors.update( - { - "energies": energies, - "node_energy": node_energy, - "forces": forces, - "stress": stress, - } - ) - if model_type in ["EnergyDipoleMACE", "DipoleMACE"]: - dipole = torch.zeros(num_models, 3, device=self.device) - dict_of_tensors.update({"dipole": dipole}) - return dict_of_tensors - - def _atoms_to_batch(self, atoms): - keyspec = data.KeySpecification( - info_keys={}, arrays_keys={"charges": self.charges_key} - ) - config = data.config_from_atoms( - atoms, key_specification=keyspec, head_name=self.head - ) - data_loader = torch_geometric.dataloader.DataLoader( - dataset=[ - data.AtomicData.from_config( - config, - z_table=self.z_table, - cutoff=self.r_max, - heads=self.available_heads, - ) - ], - batch_size=1, - shuffle=False, - drop_last=False, - ) - batch = next(iter(data_loader)).to(self.device) - return batch - - def _clone_batch(self, batch): - batch_clone = batch.clone() - if self.use_compile: - batch_clone["node_attrs"].requires_grad_(True) - batch_clone["positions"].requires_grad_(True) - return batch_clone - - # pylint: disable=dangerous-default-value - def calculate(self, atoms=None, properties=None, system_changes=all_changes): - """ - Calculate properties. - :param atoms: ase.Atoms object - :param properties: [str], properties to be computed, used by ASE internally - :param system_changes: [str], system changes since last calculation, used by ASE internally - :return: - """ - # call to base-class to set atoms attribute - calc_start_t = time.perf_counter() - Calculator.calculate(self, atoms) - - batch_base = self._atoms_to_batch(atoms) - - if self.model_type in ["MACE", "EnergyDipoleMACE"]: - batch = self._clone_batch(batch_base) - node_heads = batch["head"][batch["batch"]] - num_atoms_arange = torch.arange(batch["positions"].shape[0]) - node_e0 = self.models[0].atomic_energies_fn(batch["node_attrs"])[ - num_atoms_arange, node_heads - ] - compute_stress = not self.use_compile - else: - compute_stress = False - - ret_tensors = self._create_result_tensors( - self.model_type, self.num_models, len(atoms) - ) - for i, model in enumerate(self.models): - batch = self._clone_batch(batch_base) - # print(f'@@@File: {__file__}, batch.to_dict(): {batch.to_dict()}') - # set_seed(0) - model_start_t = time.perf_counter() - out = model( - batch.to_dict(), - compute_stress=compute_stress, - training=self.use_compile, - compute_edge_forces=self.compute_atomic_stresses, - compute_atomic_stresses=self.compute_atomic_stresses, - ) - model_end_t = time.perf_counter() - self.model_time += (model_end_t - model_start_t) - # print(f'&&& batch.positions: {batch["positions"]}') - # print(f'&&& batch.stress: {batch["stress"]}') - # print(f'compute_stress: {compute_stress}') - # for k,v in batch.to_dict().items(): - # print(f'&&& batch.to_dict(): {k} {v}') - # print("=======") - # print(f'&&& out["forces"]: {out["forces"]}') - # print(f'&&& training: {self.use_compile}') - # print(f'@@@File: {__file__}, out: {out}') - if self.model_type in ["MACE", "EnergyDipoleMACE"]: - ret_tensors["energies"][i] = out["energy"].detach() - ret_tensors["node_energy"][i] = (out["node_energy"] - node_e0).detach() - ret_tensors["forces"][i] = out["forces"].detach() - if out["stress"] is not None: - ret_tensors["stress"][i] = out["stress"].detach() - if self.model_type in ["DipoleMACE", "EnergyDipoleMACE"]: - ret_tensors["dipole"][i] = out["dipole"].detach() - if self.model_type in ["MACE"]: - if out["atomic_stresses"] is not None: - ret_tensors.setdefault("atomic_stresses", []).append( - out["atomic_stresses"].detach() - ) - if out["atomic_virials"] is not None: - ret_tensors.setdefault("atomic_virials", []).append( - out["atomic_virials"].detach() - ) - - self.results = {} - if self.model_type in ["MACE", "EnergyDipoleMACE"]: - self.results["energy"] = ( - torch.mean(ret_tensors["energies"], dim=0).cpu().item() - * self.energy_units_to_eV - ) - self.results["free_energy"] = self.results["energy"] - self.results["node_energy"] = ( - torch.mean(ret_tensors["node_energy"], dim=0).cpu().numpy() - ) - self.results["forces"] = ( - torch.mean(ret_tensors["forces"], dim=0).cpu().numpy() - * self.energy_units_to_eV - / self.length_units_to_A - ) - if self.num_models > 1: - self.results["energies"] = ( - ret_tensors["energies"].cpu().numpy() * self.energy_units_to_eV - ) - self.results["energy_var"] = ( - torch.var(ret_tensors["energies"], dim=0, unbiased=False) - .cpu() - .item() - * self.energy_units_to_eV - ) - self.results["forces_comm"] = ( - ret_tensors["forces"].cpu().numpy() - * self.energy_units_to_eV - / self.length_units_to_A - ) - if out["stress"] is not None: - self.results["stress"] = full_3x3_to_voigt_6_stress( - torch.mean(ret_tensors["stress"], dim=0).cpu().numpy() - * self.energy_units_to_eV - / self.length_units_to_A**3 - ) - if self.num_models > 1: - self.results["stress_var"] = full_3x3_to_voigt_6_stress( - torch.var(ret_tensors["stress"], dim=0, unbiased=False) - .cpu() - .numpy() - * self.energy_units_to_eV - / self.length_units_to_A**3 - ) - if "atomic_stresses" in ret_tensors: - self.results["stresses"] = ( - torch.mean(torch.stack(ret_tensors["atomic_stresses"]), dim=0) - .cpu() - .numpy() - * self.energy_units_to_eV - / self.length_units_to_A**3 - ) - if "atomic_virials" in ret_tensors: - self.results["virials"] = ( - torch.mean(torch.stack(ret_tensors["atomic_virials"]), dim=0) - .cpu() - .numpy() - * self.energy_units_to_eV - ) - if self.model_type in ["DipoleMACE", "EnergyDipoleMACE"]: - self.results["dipole"] = ( - torch.mean(ret_tensors["dipole"], dim=0).cpu().numpy() - ) - if self.num_models > 1: - self.results["dipole_var"] = ( - torch.var(ret_tensors["dipole"], dim=0, unbiased=False) - .cpu() - .numpy() - ) - - calc_end_t = time.perf_counter() - self.calc_time += (calc_end_t - calc_start_t) - - def get_hessian(self, atoms=None): - if atoms is None and self.atoms is None: - raise ValueError("atoms not set") - if atoms is None: - atoms = self.atoms - if self.model_type != "MACE": - raise NotImplementedError("Only implemented for MACE models") - batch = self._atoms_to_batch(atoms) - hessians = [ - model( - self._clone_batch(batch).to_dict(), - compute_hessian=True, - compute_stress=False, - training=self.use_compile, - )["hessian"] - for model in self.models - ] - hessians = [hessian.detach().cpu().numpy() for hessian in hessians] - if self.num_models == 1: - return hessians[0] - return hessians - - def get_descriptors(self, atoms=None, invariants_only=True, num_layers=-1): - """Extracts the descriptors from MACE model. - :param atoms: ase.Atoms object - :param invariants_only: bool, if True only the invariant descriptors are returned - :param num_layers: int, number of layers to extract descriptors from, if -1 all layers are used - :return: np.ndarray (num_atoms, num_interactions, invariant_features) of invariant descriptors if num_models is 1 or list[np.ndarray] otherwise - """ - if atoms is None and self.atoms is None: - raise ValueError("atoms not set") - if atoms is None: - atoms = self.atoms - if self.model_type != "MACE": - raise NotImplementedError("Only implemented for MACE models") - num_interactions = int(self.models[0].num_interactions) - if num_layers == -1: - num_layers = num_interactions - batch = self._atoms_to_batch(atoms) - descriptors = [model(batch.to_dict())["node_feats"] for model in self.models] - - irreps_out = o3.Irreps(str(self.models[0].products[0].linear.irreps_out)) - l_max = irreps_out.lmax - num_invariant_features = irreps_out.dim // (l_max + 1) ** 2 - per_layer_features = [irreps_out.dim for _ in range(num_interactions)] - per_layer_features[-1] = ( - num_invariant_features # Equivariant features not created for the last layer - ) - - if invariants_only: - descriptors = [ - extract_invariant( - descriptor, - num_layers=num_layers, - num_features=num_invariant_features, - l_max=l_max, - ) - for descriptor in descriptors - ] - to_keep = np.sum(per_layer_features[:num_layers]) - descriptors = [ - descriptor[:, :to_keep].detach().cpu().numpy() for descriptor in descriptors - ] - - if self.num_models == 1: - return descriptors[0] - return descriptors - - - def predict(self, atoms_list, compute_stress=False): - predictions = {'energy': [], 'forces': []} - - configs = [data.config_from_atoms(atoms, charges_key=self.charges_key) for atoms in atoms_list] - data_loader = torch_geometric.dataloader.DataLoader( - dataset=[ - data.AtomicData.from_config( - config, z_table=self.z_table, cutoff=self.r_max, heads=self.heads - ) - for config in configs - ], - batch_size=len(atoms_list), - shuffle=False, - drop_last=False, - ) - - # get the first batch of data_loader - batch_base = next(iter(data_loader)).to(self.device) - - # calculate node_e0 - # batch = self._clone_batch(batch_base) - # node_heads = batch["head"][batch["batch"]] - # num_atoms_arange = torch.arange(batch["positions"].shape[0]) - # node_e0 = self.models[0].atomic_energies_fn(batch["node_attrs"])[ - # num_atoms_arange, node_heads - # ] - - # set_seed(0) - out = self.models[0]( - batch_base.to_dict(), - compute_stress=compute_stress, # TODO: DO WE NEED TO COMPUTE STRESS? - training=self.use_compile, - ) - # print(f'&&& batch.positions: {batch["positions"]}') - # print(f'&&& batch.stress: {batch["stress"]}') - # print(f'&&& batch.to_dict(): {k} {v}') - # print("=======") - # print(f'&&& out["forces"]: {out["forces"]}') - # print(f'&&& training: {self.use_compile}') - predictions["energy"] = out["energy"].unsqueeze(-1).detach() - predictions["forces"] = out["forces"].detach() - if compute_stress: - predictions["stress"] = out["stress"].detach() - - # print(f'&&& predictions["forces"] in predict: {predictions["forces"]}') - - return predictions - - def fast_predict(self, gbatch, compute_stress=False): - gbatch.pos = gbatch.pos.to(self.dtype) - gbatch.cell = gbatch.cell.to(self.dtype) - - predictions = {'energy': [], 'forces': []} - batch_base = self.convert_batch(gbatch) - out = self.models[0]( - batch_base.to_dict(), - compute_stress=compute_stress, - training=self.use_compile, - ) - predictions["energy"] = out["energy"].unsqueeze(-1).detach().to(torch.float64) - predictions["forces"] = out["forces"].detach().to(torch.float64) - if compute_stress: - predictions["stress"] = out["stress"].detach().to(torch.float64) - - return predictions - - - def convert_batch(self, gbatch): - from batchopt import radius_graph_pbc_cuda - # edge_indices, cell_offsets, num_neighbors = radius_graph_pbc_mem_effi( - # from batchopt.pbc_graph_legacy import radius_graph_pbc - # edge_indices, cell_offsets, num_neighbors = radius_graph_pbc( - edge_indices, cell_offsets, num_neighbors = radius_graph_pbc_cuda( - gbatch, - radius=4.5, - max_num_neighbors_threshold=float('inf'), - pbc=[True, True, True], - dtype=self.dtype - ) - - tmp = edge_indices[0].clone() - edge_indices[0] = edge_indices[1] - edge_indices[1] = tmp - - # Create a one-hot matrix with number of columns equal to max atomic number + 1 - indices = atomic_numbers_to_indices(gbatch["atomic_numbers"].to("cpu"), z_table=self.z_table) - one_hot = to_one_hot( - torch.tensor(indices, dtype=torch.long).unsqueeze(-1), - num_classes=len(self.z_table), - ).to(self.device) - - cbatch = Batch( - positions = gbatch["pos"].clone(), - cell = gbatch["cell"].view(-1, 3), - batch = gbatch["batch"], - ptr = gbatch["ptr"], - edge_index = edge_indices, - unit_shifts = cell_offsets, - node_attrs = one_hot, - ) - - return cbatch - - - def predict_debug(self, atoms_list, gbatch, compute_stress=False): - predictions = {'energy': [], 'forces': []} - - configs = [data.config_from_atoms(atoms, charges_key=self.charges_key) for atoms in atoms_list] - data_loader = torch_geometric.dataloader.DataLoader( - dataset=[ - data.AtomicData.from_config( - config, z_table=self.z_table, cutoff=self.r_max, heads=self.heads - ) - for config in configs - ], - batch_size=len(atoms_list), - shuffle=False, - drop_last=False, - ) - - # get the first batch of data_loader - # batch_base = next(iter(data_loader)).to(self.device) - batch_base_tmp = next(iter(data_loader)).to(self.device) - batch2 = self.convert_batch(gbatch) - batch_base = Batch( - # positions = batch_base_tmp["positions"], - positions = batch2["positions"], - # node_attrs = batch_base_tmp["node_attrs"], - node_attrs = batch2["node_attrs"], - # cell = batch_base_tmp["cell"], - cell = batch2["cell"], - edge_index = batch2["edge_index"], - unit_shifts = batch2["unit_shifts"], - # batch = batch_base_tmp["batch"], - batch = batch2["batch"], - # ptr = batch_base_tmp["ptr"], - ptr = batch2["ptr"], - ) - - torch.set_printoptions(threshold=float('inf')) - - # print(f'batch2["edge_index"]: {batch2["edge_index"]}') - # print(f'batch2["unit_shifts"]: {batch2["unit_shifts"]}') - # print(f'batch_base_tmp["edge_index"]: {batch_base_tmp["edge_index"]}') - # print(f'batch_base_tmp["unit_shifts"]: {batch_base_tmp["unit_shifts"]}') - - # calculate node_e0 - # batch = self._clone_batch(batch_base) - # node_heads = batch["head"][batch["batch"]] - # num_atoms_arange = torch.arange(batch["positions"].shape[0]) - # node_e0 = self.models[0].atomic_energies_fn(batch["node_attrs"])[ - # num_atoms_arange, node_heads - # ] - - # set_seed(0) - out = self.models[0]( - batch_base.to_dict(), - compute_stress=compute_stress, # TODO: DO WE NEED TO COMPUTE STRESS? - training=self.use_compile, - ) - # print(f'&&& batch.positions: {batch["positions"]}') - # print(f'&&& batch.cell: {batch["cell"]}') - # print(f'&&& batch.stress: {batch["stress"]}') - # for k,v in batch.to_dict().items(): - # print(f'&&& batch.to_dict(): {k} {v}') - # print("=======") - # print(f'&&& out["forces"]: {out["forces"]}') - # print(f'&&& training: {self.use_compile}') - predictions["energy"] = out["energy"].unsqueeze(-1).detach() - predictions["forces"] = out["forces"].detach() - if compute_stress: - predictions["stress"] = out["stress"].detach() - - # print(f'&&& predictions["forces"] in predict: {predictions["forces"]}') - +########################################################################################### +# The ASE Calculator for MACE +# Authors: Ilyes Batatia, David Kovacs +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +import logging + +# pylint: disable=wrong-import-position +import os +from glob import glob +from pathlib import Path +from typing import List, Union + +os.environ["TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD"] = "1" + +import numpy as np +import torch +from ase.calculators.calculator import Calculator, all_changes +from ase.stress import full_3x3_to_voigt_6_stress +from e3nn import o3 + +from mace import data +from mace.cli.convert_e3nn_cueq import run as run_e3nn_to_cueq +from mace.modules.utils import extract_invariant +from mace.tools import torch_geometric, torch_tools, utils +from mace.tools.compile import prepare +from mace.tools.scripts_utils import extract_model +import random +from mace.tools.torch_geometric.batch import Batch + +from mace.tools import ( + atomic_numbers_to_indices, + to_one_hot, +) + +import time + + +def get_model_dtype(model: torch.nn.Module) -> torch.dtype: + """Get the dtype of the model""" + mode_dtype = next(model.parameters()).dtype + if mode_dtype == torch.float64: + return "float64" + if mode_dtype == torch.float32: + return "float32" + raise ValueError(f"Unknown dtype {mode_dtype}") + + +class MACECalculator(Calculator): + """MACE ASE Calculator + args: + model_paths: str, path to model or models if a committee is produced + to make a committee use a wild card notation like mace_*.model + device: str, device to run on (cuda or cpu) + energy_units_to_eV: float, conversion factor from model energy units to eV + length_units_to_A: float, conversion factor from model length units to Angstroms + default_dtype: str, default dtype of model + charges_key: str, Array field of atoms object where atomic charges are stored + model_type: str, type of model to load + Options: [MACE, DipoleMACE, EnergyDipoleMACE] + + Dipoles are returned in units of Debye + """ + + def __init__( + self, + model_paths: Union[list, str, None] = None, + models: Union[List[torch.nn.Module], torch.nn.Module, None] = None, + device: str = "cpu", + energy_units_to_eV: float = 1.0, + length_units_to_A: float = 1.0, + default_dtype="", + charges_key="Qs", + model_type="MACE", + compile_mode=None, + fullgraph=True, + enable_cueq=False, + **kwargs, + ): + Calculator.__init__(self, **kwargs) + self.device = device + self.dtype=None + if enable_cueq: + assert model_type == "MACE", "CuEq only supports MACE models" + compile_mode = None + if "model_path" in kwargs: + deprecation_message = ( + "'model_path' argument is deprecated, please use 'model_paths'" + ) + if model_paths is None: + logging.warning(f"{deprecation_message} in the future.") + model_paths = kwargs["model_path"] + else: + raise ValueError( + f"both 'model_path' and 'model_paths' given, {deprecation_message} only." + ) + + if (model_paths is None) == (models is None): + raise ValueError( + "Exactly one of 'model_paths' or 'models' must be provided" + ) + + self.results = {} + + self.model_type = model_type + self.compute_atomic_stresses = False + + if model_type == "MACE": + self.implemented_properties = [ + "energy", + "free_energy", + "node_energy", + "forces", + "stress", + ] + if kwargs.get("compute_atomic_stresses", False): + self.implemented_properties.extend(["stresses", "virials"]) + self.compute_atomic_stresses = True + elif model_type == "DipoleMACE": + self.implemented_properties = ["dipole"] + elif model_type == "EnergyDipoleMACE": + self.implemented_properties = [ + "energy", + "free_energy", + "node_energy", + "forces", + "stress", + "dipole", + ] + else: + raise ValueError( + f"Give a valid model_type: [MACE, DipoleMACE, EnergyDipoleMACE], {model_type} not supported" + ) + + if model_paths is not None: + if isinstance(model_paths, str): + # Find all models that satisfy the wildcard (e.g. mace_model_*.pt) + model_paths_glob = glob(model_paths) + + if len(model_paths_glob) == 0: + raise ValueError(f"Couldn't find MACE model files: {model_paths}") + + model_paths = model_paths_glob + elif isinstance(model_paths, Path): + model_paths = [model_paths] + + if len(model_paths) == 0: + raise ValueError("No mace file names supplied") + self.num_models = len(model_paths) + + # Load models from files + self.models = [ + torch.load(f=model_path, map_location=device) + for model_path in model_paths + ] + + elif models is not None: + if not isinstance(models, list): + models = [models] + + if len(models) == 0: + raise ValueError("No models supplied") + + self.models = models + self.num_models = len(models) + + if self.num_models > 1: + print(f"Running committee mace with {self.num_models} models") + + if model_type in ["MACE", "EnergyDipoleMACE"]: + self.implemented_properties.extend( + ["energies", "energy_var", "forces_comm", "stress_var"] + ) + elif model_type == "DipoleMACE": + self.implemented_properties.extend(["dipole_var"]) + + if compile_mode is not None: + print(f"Torch compile is enabled with mode: {compile_mode}") + self.models = [ + torch.compile( + prepare(extract_model)(model=model, map_location=device), + mode=compile_mode, + fullgraph=fullgraph, + ) + for model in self.models + ] + self.use_compile = True + else: + self.use_compile = False + + # Ensure all models are on the same device + for model in self.models: + model.to(device) + + r_maxs = [model.r_max.cpu() for model in self.models] + r_maxs = np.array(r_maxs) + if not np.all(r_maxs == r_maxs[0]): + raise ValueError(f"committee r_max are not all the same {' '.join(r_maxs)}") + self.r_max = float(r_maxs[0]) + + self.device = torch_tools.init_device(device) + self.energy_units_to_eV = energy_units_to_eV + self.length_units_to_A = length_units_to_A + self.z_table = utils.AtomicNumberTable( + [int(z) for z in self.models[0].atomic_numbers] + ) + self.charges_key = charges_key + + try: + self.available_heads: List[str] = self.models[0].heads # type: ignore + except AttributeError: + self.available_heads = ["Default"] + kwarg_head = kwargs.get("head", None) + if kwarg_head is not None: + self.head = kwarg_head + else: + self.head = [head for head in self.available_heads if head.lower() == "default"] + if len(self.head) == 0: + raise ValueError( + "Head keyword was not provided, and no head in the model is 'default'. " + "Please provide a head keyword to specify the head you want to use. " + f"Available heads are: {self.available_heads}" + ) + self.head = self.head[0] + + print("Using head", self.head, "out of", self.available_heads) + + model_dtype = get_model_dtype(self.models[0]) + if default_dtype == "": + print( + f"No dtype selected, switching to {model_dtype} to match model dtype." + ) + default_dtype = model_dtype + if model_dtype != default_dtype: + print( + f"Default dtype {default_dtype} does not match model dtype {model_dtype}, converting models to {default_dtype}." + ) + if default_dtype == "float64": + self.models = [model.double() for model in self.models] + elif default_dtype == "float32": + self.models = [model.float() for model in self.models] + torch_tools.set_default_dtype(default_dtype) + if enable_cueq: + print("Converting models to CuEq for acceleration") + self.models = [ + run_e3nn_to_cueq(model, device=device).to(device) + for model in self.models + ] + for model in self.models: + for param in model.parameters(): + param.requires_grad = False + + self.dtype = torch.float64 if default_dtype == "float64" else torch.float32 + + self.model_time = 0.0 + self.calc_time = 0.0 + + def _create_result_tensors( + self, model_type: str, num_models: int, num_atoms: int + ) -> dict: + """ + Create tensors to store the results of the committee + :param model_type: str, type of model to load + Options: [MACE, DipoleMACE, EnergyDipoleMACE] + :param num_models: int, number of models in the committee + :return: tuple of torch tensors + """ + dict_of_tensors = {} + if model_type in ["MACE", "EnergyDipoleMACE"]: + energies = torch.zeros(num_models, device=self.device) + node_energy = torch.zeros(num_models, num_atoms, device=self.device) + forces = torch.zeros(num_models, num_atoms, 3, device=self.device) + stress = torch.zeros(num_models, 3, 3, device=self.device) + dict_of_tensors.update( + { + "energies": energies, + "node_energy": node_energy, + "forces": forces, + "stress": stress, + } + ) + if model_type in ["EnergyDipoleMACE", "DipoleMACE"]: + dipole = torch.zeros(num_models, 3, device=self.device) + dict_of_tensors.update({"dipole": dipole}) + return dict_of_tensors + + def _atoms_to_batch(self, atoms): + keyspec = data.KeySpecification( + info_keys={}, arrays_keys={"charges": self.charges_key} + ) + config = data.config_from_atoms( + atoms, key_specification=keyspec, head_name=self.head + ) + data_loader = torch_geometric.dataloader.DataLoader( + dataset=[ + data.AtomicData.from_config( + config, + z_table=self.z_table, + cutoff=self.r_max, + heads=self.available_heads, + ) + ], + batch_size=1, + shuffle=False, + drop_last=False, + ) + batch = next(iter(data_loader)).to(self.device) + return batch + + def _clone_batch(self, batch): + batch_clone = batch.clone() + if self.use_compile: + batch_clone["node_attrs"].requires_grad_(True) + batch_clone["positions"].requires_grad_(True) + return batch_clone + + # pylint: disable=dangerous-default-value + def calculate(self, atoms=None, properties=None, system_changes=all_changes): + """ + Calculate properties. + :param atoms: ase.Atoms object + :param properties: [str], properties to be computed, used by ASE internally + :param system_changes: [str], system changes since last calculation, used by ASE internally + :return: + """ + # call to base-class to set atoms attribute + calc_start_t = time.perf_counter() + Calculator.calculate(self, atoms) + + batch_base = self._atoms_to_batch(atoms) + + if self.model_type in ["MACE", "EnergyDipoleMACE"]: + batch = self._clone_batch(batch_base) + node_heads = batch["head"][batch["batch"]] + num_atoms_arange = torch.arange(batch["positions"].shape[0]) + node_e0 = self.models[0].atomic_energies_fn(batch["node_attrs"])[ + num_atoms_arange, node_heads + ] + compute_stress = not self.use_compile + else: + compute_stress = False + + ret_tensors = self._create_result_tensors( + self.model_type, self.num_models, len(atoms) + ) + for i, model in enumerate(self.models): + batch = self._clone_batch(batch_base) + # print(f'@@@File: {__file__}, batch.to_dict(): {batch.to_dict()}') + # set_seed(0) + model_start_t = time.perf_counter() + out = model( + batch.to_dict(), + compute_stress=compute_stress, + training=self.use_compile, + compute_edge_forces=self.compute_atomic_stresses, + compute_atomic_stresses=self.compute_atomic_stresses, + ) + model_end_t = time.perf_counter() + self.model_time += (model_end_t - model_start_t) + # print(f'&&& batch.positions: {batch["positions"]}') + # print(f'&&& batch.stress: {batch["stress"]}') + # print(f'compute_stress: {compute_stress}') + # for k,v in batch.to_dict().items(): + # print(f'&&& batch.to_dict(): {k} {v}') + # print("=======") + # print(f'&&& out["forces"]: {out["forces"]}') + # print(f'&&& training: {self.use_compile}') + # print(f'@@@File: {__file__}, out: {out}') + if self.model_type in ["MACE", "EnergyDipoleMACE"]: + ret_tensors["energies"][i] = out["energy"].detach() + ret_tensors["node_energy"][i] = (out["node_energy"] - node_e0).detach() + ret_tensors["forces"][i] = out["forces"].detach() + if out["stress"] is not None: + ret_tensors["stress"][i] = out["stress"].detach() + if self.model_type in ["DipoleMACE", "EnergyDipoleMACE"]: + ret_tensors["dipole"][i] = out["dipole"].detach() + if self.model_type in ["MACE"]: + if out["atomic_stresses"] is not None: + ret_tensors.setdefault("atomic_stresses", []).append( + out["atomic_stresses"].detach() + ) + if out["atomic_virials"] is not None: + ret_tensors.setdefault("atomic_virials", []).append( + out["atomic_virials"].detach() + ) + + self.results = {} + if self.model_type in ["MACE", "EnergyDipoleMACE"]: + self.results["energy"] = ( + torch.mean(ret_tensors["energies"], dim=0).cpu().item() + * self.energy_units_to_eV + ) + self.results["free_energy"] = self.results["energy"] + self.results["node_energy"] = ( + torch.mean(ret_tensors["node_energy"], dim=0).cpu().numpy() + ) + self.results["forces"] = ( + torch.mean(ret_tensors["forces"], dim=0).cpu().numpy() + * self.energy_units_to_eV + / self.length_units_to_A + ) + if self.num_models > 1: + self.results["energies"] = ( + ret_tensors["energies"].cpu().numpy() * self.energy_units_to_eV + ) + self.results["energy_var"] = ( + torch.var(ret_tensors["energies"], dim=0, unbiased=False) + .cpu() + .item() + * self.energy_units_to_eV + ) + self.results["forces_comm"] = ( + ret_tensors["forces"].cpu().numpy() + * self.energy_units_to_eV + / self.length_units_to_A + ) + if out["stress"] is not None: + self.results["stress"] = full_3x3_to_voigt_6_stress( + torch.mean(ret_tensors["stress"], dim=0).cpu().numpy() + * self.energy_units_to_eV + / self.length_units_to_A**3 + ) + if self.num_models > 1: + self.results["stress_var"] = full_3x3_to_voigt_6_stress( + torch.var(ret_tensors["stress"], dim=0, unbiased=False) + .cpu() + .numpy() + * self.energy_units_to_eV + / self.length_units_to_A**3 + ) + if "atomic_stresses" in ret_tensors: + self.results["stresses"] = ( + torch.mean(torch.stack(ret_tensors["atomic_stresses"]), dim=0) + .cpu() + .numpy() + * self.energy_units_to_eV + / self.length_units_to_A**3 + ) + if "atomic_virials" in ret_tensors: + self.results["virials"] = ( + torch.mean(torch.stack(ret_tensors["atomic_virials"]), dim=0) + .cpu() + .numpy() + * self.energy_units_to_eV + ) + if self.model_type in ["DipoleMACE", "EnergyDipoleMACE"]: + self.results["dipole"] = ( + torch.mean(ret_tensors["dipole"], dim=0).cpu().numpy() + ) + if self.num_models > 1: + self.results["dipole_var"] = ( + torch.var(ret_tensors["dipole"], dim=0, unbiased=False) + .cpu() + .numpy() + ) + + calc_end_t = time.perf_counter() + self.calc_time += (calc_end_t - calc_start_t) + + def get_hessian(self, atoms=None): + if atoms is None and self.atoms is None: + raise ValueError("atoms not set") + if atoms is None: + atoms = self.atoms + if self.model_type != "MACE": + raise NotImplementedError("Only implemented for MACE models") + batch = self._atoms_to_batch(atoms) + hessians = [ + model( + self._clone_batch(batch).to_dict(), + compute_hessian=True, + compute_stress=False, + training=self.use_compile, + )["hessian"] + for model in self.models + ] + hessians = [hessian.detach().cpu().numpy() for hessian in hessians] + if self.num_models == 1: + return hessians[0] + return hessians + + def get_descriptors(self, atoms=None, invariants_only=True, num_layers=-1): + """Extracts the descriptors from MACE model. + :param atoms: ase.Atoms object + :param invariants_only: bool, if True only the invariant descriptors are returned + :param num_layers: int, number of layers to extract descriptors from, if -1 all layers are used + :return: np.ndarray (num_atoms, num_interactions, invariant_features) of invariant descriptors if num_models is 1 or list[np.ndarray] otherwise + """ + if atoms is None and self.atoms is None: + raise ValueError("atoms not set") + if atoms is None: + atoms = self.atoms + if self.model_type != "MACE": + raise NotImplementedError("Only implemented for MACE models") + num_interactions = int(self.models[0].num_interactions) + if num_layers == -1: + num_layers = num_interactions + batch = self._atoms_to_batch(atoms) + descriptors = [model(batch.to_dict())["node_feats"] for model in self.models] + + irreps_out = o3.Irreps(str(self.models[0].products[0].linear.irreps_out)) + l_max = irreps_out.lmax + num_invariant_features = irreps_out.dim // (l_max + 1) ** 2 + per_layer_features = [irreps_out.dim for _ in range(num_interactions)] + per_layer_features[-1] = ( + num_invariant_features # Equivariant features not created for the last layer + ) + + if invariants_only: + descriptors = [ + extract_invariant( + descriptor, + num_layers=num_layers, + num_features=num_invariant_features, + l_max=l_max, + ) + for descriptor in descriptors + ] + to_keep = np.sum(per_layer_features[:num_layers]) + descriptors = [ + descriptor[:, :to_keep].detach().cpu().numpy() for descriptor in descriptors + ] + + if self.num_models == 1: + return descriptors[0] + return descriptors + + + def predict(self, atoms_list, compute_stress=False): + predictions = {'energy': [], 'forces': []} + + configs = [data.config_from_atoms(atoms, charges_key=self.charges_key) for atoms in atoms_list] + data_loader = torch_geometric.dataloader.DataLoader( + dataset=[ + data.AtomicData.from_config( + config, z_table=self.z_table, cutoff=self.r_max, heads=self.heads + ) + for config in configs + ], + batch_size=len(atoms_list), + shuffle=False, + drop_last=False, + ) + + # get the first batch of data_loader + batch_base = next(iter(data_loader)).to(self.device) + + # calculate node_e0 + # batch = self._clone_batch(batch_base) + # node_heads = batch["head"][batch["batch"]] + # num_atoms_arange = torch.arange(batch["positions"].shape[0]) + # node_e0 = self.models[0].atomic_energies_fn(batch["node_attrs"])[ + # num_atoms_arange, node_heads + # ] + + # set_seed(0) + out = self.models[0]( + batch_base.to_dict(), + compute_stress=compute_stress, # TODO: DO WE NEED TO COMPUTE STRESS? + training=self.use_compile, + ) + # print(f'&&& batch.positions: {batch["positions"]}') + # print(f'&&& batch.stress: {batch["stress"]}') + # print(f'&&& batch.to_dict(): {k} {v}') + # print("=======") + # print(f'&&& out["forces"]: {out["forces"]}') + # print(f'&&& training: {self.use_compile}') + predictions["energy"] = out["energy"].unsqueeze(-1).detach() + predictions["forces"] = out["forces"].detach() + if compute_stress: + predictions["stress"] = out["stress"].detach() + + # print(f'&&& predictions["forces"] in predict: {predictions["forces"]}') + + return predictions + + def fast_predict(self, gbatch, compute_stress=False): + gbatch.pos = gbatch.pos.to(self.dtype) + gbatch.cell = gbatch.cell.to(self.dtype) + + predictions = {'energy': [], 'forces': []} + batch_base = self.convert_batch(gbatch) + out = self.models[0]( + batch_base.to_dict(), + compute_stress=compute_stress, + training=self.use_compile, + ) + predictions["energy"] = out["energy"].unsqueeze(-1).detach().to(torch.float64) + predictions["forces"] = out["forces"].detach().to(torch.float64) + if compute_stress: + predictions["stress"] = out["stress"].detach().to(torch.float64) + + return predictions + + + def convert_batch(self, gbatch): + from batchopt import radius_graph_pbc_cuda + # edge_indices, cell_offsets, num_neighbors = radius_graph_pbc_mem_effi( + # from batchopt.pbc_graph_legacy import radius_graph_pbc + # edge_indices, cell_offsets, num_neighbors = radius_graph_pbc( + edge_indices, cell_offsets, num_neighbors = radius_graph_pbc_cuda( + gbatch, + radius=4.5, + max_num_neighbors_threshold=float('inf'), + pbc=[True, True, True], + dtype=self.dtype + ) + + tmp = edge_indices[0].clone() + edge_indices[0] = edge_indices[1] + edge_indices[1] = tmp + + # Create a one-hot matrix with number of columns equal to max atomic number + 1 + indices = atomic_numbers_to_indices(gbatch["atomic_numbers"].to("cpu"), z_table=self.z_table) + one_hot = to_one_hot( + torch.tensor(indices, dtype=torch.long).unsqueeze(-1), + num_classes=len(self.z_table), + ).to(self.device) + + cbatch = Batch( + positions = gbatch["pos"].clone(), + cell = gbatch["cell"].view(-1, 3), + batch = gbatch["batch"], + ptr = gbatch["ptr"], + edge_index = edge_indices, + unit_shifts = cell_offsets, + node_attrs = one_hot, + ) + + return cbatch + + + def predict_debug(self, atoms_list, gbatch, compute_stress=False): + predictions = {'energy': [], 'forces': []} + + configs = [data.config_from_atoms(atoms, charges_key=self.charges_key) for atoms in atoms_list] + data_loader = torch_geometric.dataloader.DataLoader( + dataset=[ + data.AtomicData.from_config( + config, z_table=self.z_table, cutoff=self.r_max, heads=self.heads + ) + for config in configs + ], + batch_size=len(atoms_list), + shuffle=False, + drop_last=False, + ) + + # get the first batch of data_loader + # batch_base = next(iter(data_loader)).to(self.device) + batch_base_tmp = next(iter(data_loader)).to(self.device) + batch2 = self.convert_batch(gbatch) + batch_base = Batch( + # positions = batch_base_tmp["positions"], + positions = batch2["positions"], + # node_attrs = batch_base_tmp["node_attrs"], + node_attrs = batch2["node_attrs"], + # cell = batch_base_tmp["cell"], + cell = batch2["cell"], + edge_index = batch2["edge_index"], + unit_shifts = batch2["unit_shifts"], + # batch = batch_base_tmp["batch"], + batch = batch2["batch"], + # ptr = batch_base_tmp["ptr"], + ptr = batch2["ptr"], + ) + + torch.set_printoptions(threshold=float('inf')) + + # print(f'batch2["edge_index"]: {batch2["edge_index"]}') + # print(f'batch2["unit_shifts"]: {batch2["unit_shifts"]}') + # print(f'batch_base_tmp["edge_index"]: {batch_base_tmp["edge_index"]}') + # print(f'batch_base_tmp["unit_shifts"]: {batch_base_tmp["unit_shifts"]}') + + # calculate node_e0 + # batch = self._clone_batch(batch_base) + # node_heads = batch["head"][batch["batch"]] + # num_atoms_arange = torch.arange(batch["positions"].shape[0]) + # node_e0 = self.models[0].atomic_energies_fn(batch["node_attrs"])[ + # num_atoms_arange, node_heads + # ] + + # set_seed(0) + out = self.models[0]( + batch_base.to_dict(), + compute_stress=compute_stress, # TODO: DO WE NEED TO COMPUTE STRESS? + training=self.use_compile, + ) + # print(f'&&& batch.positions: {batch["positions"]}') + # print(f'&&& batch.cell: {batch["cell"]}') + # print(f'&&& batch.stress: {batch["stress"]}') + # for k,v in batch.to_dict().items(): + # print(f'&&& batch.to_dict(): {k} {v}') + # print("=======") + # print(f'&&& out["forces"]: {out["forces"]}') + # print(f'&&& training: {self.use_compile}') + predictions["energy"] = out["energy"].unsqueeze(-1).detach() + predictions["forces"] = out["forces"].detach() + if compute_stress: + predictions["stress"] = out["stress"].detach() + + # print(f'&&& predictions["forces"] in predict: {predictions["forces"]}') + return predictions \ No newline at end of file diff --git a/mace-bench/3rdparty/mace/mace/cli/__pycache__/__init__.cpython-310.pyc b/mace-bench/3rdparty/mace/mace/cli/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index a48194f211a8fc749c31e217d020af909043411b..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 178 zcmd1j<>g`kg0)*$W`O9&AOaaM0yz#qT+9L_QW%06G#UL?G8BP?5yY=h{fzwFRQ=q< zs*J?^tjt6pD>+rSpeR2pHMvB;pt2+*KQG=nH~=D^l$w{Ep>JH2Qjl0wQVAA=5Xm{2 i`tk9Zd6^~g@p=W7w>WHa^HWN5Qtd$26*B<|76t$|*DUz} diff --git a/mace-bench/3rdparty/mace/mace/cli/__pycache__/__init__.cpython-313.pyc b/mace-bench/3rdparty/mace/mace/cli/__pycache__/__init__.cpython-313.pyc deleted file mode 100644 index 3d6064764790e90fb44232574b7b8ba5c19458ad..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 171 zcmey&%ge<81oJIdWPs?$AOZ#$p^VQgK*m&tbOudEzm*I{OhDdekkl+rSpeR2pHMs=BNlML2&d@h5N-0PzDyamEL5Sp>O#S%y%)HE!_;|g7 h%3B;Zx%nxjIjMFKEC5|YD<%K{ diff --git a/mace-bench/3rdparty/mace/mace/cli/__pycache__/convert_e3nn_cueq.cpython-310.pyc b/mace-bench/3rdparty/mace/mace/cli/__pycache__/convert_e3nn_cueq.cpython-310.pyc deleted file mode 100644 index d710de37d4ef25fe426f3525f6ce5faf22ec3066..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 5946 zcma)A&5snv74NF<>FMeD*j-qbU2H?gU-8Cfz$TYi1la;Ba%>aJAS)ULcY3F4W_y_F z9;$m5w$&>+St7%MNX{*akRZy1LrzitnZD*^DYqz!gE8dys%M5>z)5D+we?-EUj5$h zy{fp`bPfFe{MUP%m(~p9Z&cX(tE2D^zNEzrLl}Y$jZXQWb(k*m4#$5!G)Gp)8dW-# zk=?Ob*{9m6qRkvSqgtn?>#U(Ws(0$hS7bxjvLc&nM$Zye;as;mEm1?Qi*MFy%Z6<9 z_?m%Uj;Mrrp~XMg>^3MriJxk zn50GJ;$%FO>#SHlJ9+QQB-{#A7(`upHja8>AN6bUTB?F>>UVXGKMK0iAH_lr3pbjK z{9&*yRZ`eWrV|x)q4)WZ@)dk0zT_ej!2}nku=ZH`8aE74T{U*uIedR)!ugVk z+L!DLqsN4c9!`pOQQzg$cAP;usvL3HxL~X|vp0SsQ*VgfcyZ6WEVmOc-3(GMl2UkS z>}^O7JBX5=R7#5M`3S=MaHJ03PczT_q1=wB#}eE7qzFk zvS*UL7^Lwi?D{g2svpXPx;#;pttuIac#{3tEZ9JiA?W1(XD2q}kvuU9vdthKgaLA0d3>zm zf$XLy#@p#;9QkK2d|cLV$f&z{;*=8OK&9Kdu1vba@B|RBB~|LnQ&Hr1C-O>ryj{4b zN%P%!G(K|>i~fW5vT9_n^kwSL)b7)vv`-J?ZZJ&FATt~B6+t9hN>DHTWHaogZJmFj ziW;6$SQSHQ#pFdKT47C67!bu)4JeIk)tR>|Q`K3jhK_0pN#;y7=ijeR4_+(l9#;)C z7c9KT6wP1iu|>HBBrEt5+Te@ER|M*(WU^7IW%>*mEz|2a!7VpP9vou7L^fEN5SuS6Dc{r>`!su))uc0ps`dBZ5aA$&MW5ndLC6;Uq_AX zup`Ba~$B<#VulqINS4_Gvt9z`dz_|NzeRg(9DHWNKnoB}kNjxOB z!MV%%qh-z?x+eeDwfFg^8`3$NrdWDv1B4ZgJ`830@4+&r z;x@m}tNh<94*%M&+1z?){Y#O;i^?dt=6_h!yD^w(7|^+O2}iRZ?U+7Im_J7*XM%ys zi3a&D`!$Y8m?*!j2@)g<7_poSdq$x5e5o1Ho* z0iF@a&$|qq;#xgS6G;D(N4WrLX113vOkMo zw6!ZuGL5}bXRZBWgE$YKkJI-@@P#7@s}k=ig0_vo(oX6pkRKxR$|j;d@gZ!8F+`?YR&Xcg%!Hh>IQ!VP*bl`qG2^h(T`D;5P^XWEY7TNc!fV$X8e)!C_9Dub*$ijhX~JbV;W|N zl}r(i$_*U-t__p|p9iQk1}bT+x|_rZQ8!3sgs%jhQ_B5)5`I7Y6oIO+2!jQ`Ix4sJ z1L~(~>e&B~|4sBvNHGnk!7bKcu4O&oRr~7$uJunvhgdYGz$~M@#uS;lv=5r3R62eQ z%?(&tZrm~M8h|C!{{czBlo99vhn1V~a(QLTP|xK6;F^)!Hw*x(Xz3lLX&0s9OQJ9Q^w7co`Vnr=pR(pGPN?G08!d*r?22XttE1zFqDBL6kOK3+? zMitdz-0x$JqIxw@5i+F}TbL~=%plq>YLh6rGLbTqg$enJNT8QHgqsk|!8SQ!Y_gr#U}ylBTmbE8-%iF9k)5 z02`9jx?EDW*8UG^VohoDQL_;AeKaLMMdH-pR!wrPEry8BbeZ+2Q8@}X%RhJc5yW{d z(?aVqoNkL*oO)W1=(Ew!qM+>&=Z%N0ilRC1*UM{7NqlB^Jh|8* z68K^xWxHH8Q$8?rcxJ)iWvz(k#*PINvr3W~*gF-FLN%|D*@JwD7H#-vJO~VGAV_%X zH3J0Y3Vk(3TD@hV);VpY^*aqw(xM~bmu^B+-0JU5%s(dQfg3Kc-!1zSmz)pib|T&An!8fV|O_{ zg~0F^&cnquKrcRBu^!&t9UZ*G~#g@j>1g(!Eo5A)17kq&XAori;HebrW47lC1I6BN1YzOQ+U8L z^-79G!s}iXYq}^{J3DkC*rJK(g=0Dso&`$FKUBD35=Ou^9z9?ksZwuX7llP|FD$TF zTDZ6j2MCGioueYsZJ5pymi2|~1(PBEQMx8qWDcwkWpA+)QgLn?{|kmlK6hX z;qVD+1&@=$8VBiS;Uwc>n9?)L89gKkw&Z%FsE5&bl1`sm8c^Efk}c#`2yq+bqf6OG zn`6=3Csbi0u)sUc9s||(s>4`|g8e7y&Qz1=bBK|U`MLE)7)^`0%;7G~Y<|yS?gQ6Z zf$34Z#avG2wTNKCVv1<8i*NoTlAc@50EN`J!w@C_mjfNE3iB2)rb{LItGCQ)8BQ4w zLG^jLAz%Swp)B7qF$;s0u>nGkfsGH*?Z6R>x2!2;X{CA}y?ge( z5($8`wJC?=tabeOUL9vI%|?{9-V?CX6p=MGVs*YfPH_btVzHt+J%h|Se!M&>^#KXd zJg1M1#F!*>VI84X524lfo+a)Ez=`yGcrTIJtDoq%wJCH7tawzSw=N)saEQCst^Dz}w(SVoo(W>h->=Rg~>L>b8 zpE(`Y)N`rA(j7H*>qJwxMdbb$ted=sglJlmE{NCucCjR2op&l6I4b|#2WktQ3&+RB zAdGw;fjyDER^r9=RW+spnV0${lEQ}hK@Eyl!PQVl$~=k+`y)NO zw(3$(GwCCchVSpOu*Z6FPM0#Bp%^Ur42c8nkfU;VgE!cU$yc>=A_s-9#odRz*+8y= jhYQ%@x45`G;1%~9zG|b^ezNvQ^d8|pa85a^at^UB+;gZ z%na!VchRaT7Ey{~qohc~1r}Ik(I46$Ewm^K7$^!j+x$q2{@|Ey$ep-tuyI@XA8o6P zb$|8TJ0wLbinqN4@7%ff+;d;&p6{G9&s;7Cg7nkT)%oA=Lg;J!(n!vF;n|;22z`V^ zBvR*)YCb8IBJUQ}0#D01dV*1z39D+Iu&Fk|q_L}ZXrs@u6P(JCI_8{X!l^o;Y@Ojn z+l+O_WraZwGNOGOy3VL>k%f8={*89q3_rskv|uVi$%+m_=@FeX%#3TsJ|)y=+43;y zMgDhEw2XGpA-V`>ujmG@gDvdp#i?E2tP?%|584+;Ya^otdJ+DK0nrQZlo}L$+mPBm z(;@m_zk-m&A~e(9j}Fy`7SJ3uM}>lq@%x1kZ7?TOiJZZlPiZ;BI-Sp^oWhP0c}lmwY>?N{L)NL2BX)QbLX|WRh~)a47kOcv`w6tD0d`A> z5^(Cx z^)|U$F1K)ty)dgcY2e4yXLgp%%tdDAj?lv`y+lvTS~h0DlNp#j+FHhTw~j5E?Rc$J zz3e4rUvm#AKZ2bBu~z69eTkm5oBb!x)%h@rMWGdH2sU95_Q~9<*+$9ojZ&R{7KQwU zS5C`0Aq`R?WM+jM@*Pdc%}Y5!k>#Y2%LvzG0VF}uW@S~ClZ74ygdi?2T$7W@lrk4i zrxaOI!?)zr+C8V?**XC2RxiBejs=^JAv6vRvde?Zt-dKh! zrQ~ER{;7Ric&%_R`inn&aA3U5oxXoyd^Ngsc3D~$i|(>x zw|-z;=T6t0KwIMx*Hd%hk{g#ixPKdpQsdM;>2vRA&N6rYm(IHfp!f9^?+5AS@2$MH z8eNU7&XpY_pE`B!e9eoyv?E`~a^Zou|G%vc&h~3>|5EfTcXu&bjFftZOZ@&CZSy%E zox4=JJoV47m0i>S^x8dX*;SN^Vo4ZTdwVT$x1;Qi{^M)9Yr19yQdpL!2X`+jC4Q)8 z$5m{jDpTUaH4axfkf*aqEo;Sai4Uzip=v4m%!)Y269&ppaOqdx{%3ejPyBvZB=TEr z&!0Mw;|%rusSP^uY@ZHvjd6}okNC$J-!DCb zQ2qzT3b|ojkZ#A%8_q-qZdY2u0++y2*+>}4h%)_aD13xslt>X}4=S-m0a_Z$N2Y*+ zCKUUCqRA;{5t&U2mu{*_8$y{#3=LaADS}L5RjV)xedUDNN3;%GhKXuXm>4tII&M=* zZ;W48J~XHnwh%cf!#!F66(QV!g$h|IrE0=8NdwiAQ3SGfK~j=J%km#%OhDh()Y_St{E* z9=dmw`p3%dlO^_K)$TLv_71N(?q2?d`@!C0<>2xAdylPLF7EvE>-yefOO8_Tcs1C! z=KbWtTC_AUQtCgf2M;efs-CVA+x5_~eT84Wx;lBcx72&4>^NIu&OU!?N9|FH%y88e z^)ZH1%j8vn>LkW`XHHV_WbsP^ucDEf5dCYY`KaYYuMEVPLv^W9tT8Kq&RYoOVT-eK zSMf%ZG=+85rqqwF0w|d&W*3RPG^7GRnSGPf=sJwThAFHFP!HT;vlR{invU53n3)(0 zxi!YQQN)HP3{VGqjwNz(dmZX%g@fK;AdFfR$CXd&Jpnv-1O4RhnN)KukwZcYlE@Wy zP1i9@(C#cOfc&HqAan|`ap2Agbp=i&3foFx@Z2s~!u#QWKn1i$g`MGMPiz!5dTJO0 zHeLWt>a0N1)5+9Aq4(SRnRFKb2(Sbh)?7wS%wtWUB&3|Iz~2;5r`iMUDy$!E7S(OI zj>s#61=65ysBM^+<<^3R?*@@tGc#qk4`xcBJihf}RBX5&fWP)L$QBW3aG$VZU7IL- zk1bAo?ce^MeTjPL@>N{D_g%e!58L{RxfSWhw|~;N&e>e%?B(-yDq36O`MaJXcHj!!WQfDd@}a&E_xr4~Tr z)ER0yW%4*)3Pj87iKiVX(EI!u1H?q~|F%E6k2m^E(|pYO25^^PaDirE`93VU^<_N8lR^k|{Cy==`5`WTCma0t#EG!F;5$)Q;en63xxhT;(I zF~=LwO4bH@7;~g7ldvv=i=5(!agE)Oeh96=;mjZ5y=jc0RzwrC@0J@A4Y zVAE_<0y$u@97LK}kI#6MG>yN&Wov2Th%`AGeYct|NSpIzGsfDvv6~=^e_TJ!j>eQl zzR?PM6NWu?!X4{uv~E0OP9if{i)~79=f0Ib-IUqM{pQF(jkV*Z6l=A3*-C~PCv#B3%HhDu5iOg$r*nx$29V675? z|Fc&hJ4y{80b6Vs<=~cPywetZV{W34tgvSwS6H6r&|!L(3Q-q+jsSH4Zw%_TL1XhA zGZTbGQ*Coc=Dec$Dk$Uj=|*>?d`lQ7i5FhWQc2Mcx=O(l&|II2z(lRt1oAG zrJHkc9KcYp!+I?PF$_;SBPHV-V~1QSR?RBD<(mPwiNuDhc@AVYZH-s(k%@a{xUm?e z@T^UDYY?-;s!fF(fix_{E$}QHH&43Br{PknCxH+G4@mnB$l7ngzVsn~U&R;FeW9{% z-{QHduearG-G+j@R<4%=2Noy4^z~ z-4EP{^&KN+_u-23aM^kI?wQZ&zrWZRqS802_l@4&QSLiZ9}WDcw*GY%b#@oOQ|<_r zn7|{}xpc0=_vn02neQ#LeSbNya-*_qMBg=1-t}6!f3$SuyxxCaXV2GYWc7o2(7(SN zoc>vFFv#fY_>Ykme-D{#A8m)wm z>7ip~&#{MX9Ur?sbd`kf+)dn#m3qe>w4JODgjY3vAhPT%j+NVlM?m8F&{J(6T1~9R zYKV&TR(nIM1A1?`(tA+vJy_a)Xzl8~?$1R%dbtvPOOL))I)0@_Q^D_2zoBWcTbwA+ zzaBtt|MGz!U0I^4qo*pPQ~Kysm3KX%Ew4NN%h9bnx@+K11}p3io!zkl+29xK$P*g5 z+rf#pdMk{eGr~&uzc543pYBBMF-pVE&S#9{^daUmpAGT^0~L*AhXU3X&sG7PTaEb` z+A!H0@v$a!fqM8|hQ}3)NP&^fT(rc%A>1%EV$A;qW~cC^Hzmc|GA6dFH>G}nv<7qt zs0G+J3@cmACXyKMR=@>jmrbNGsDK%P6(C6>0}(_LGAqYbFrLhG|X8>NjXr$|5;AleDA&*O+=1R{wo`-5GAO$pXcVwSDO zNXyAK7ShtIA6@yNA%ht9LY@oYAn^dSaSUG=Kn&M>Dw&j(cnU%xS&ew)(}pdj#uxHw z!+tSBB3dEFu!D1SEiESjy`(!39W<Z-k-e54q7z{vo&Tv4?P67`UhtI9a+#P@`FAN(`%Ck~F4NqL9;!6$P zz^PNQ8K;hr_mnc5F}Qdh|^6#F1 z=X}X~;sG;O1$^~~?)yW!Ke9RwS;ar1`$yJpmi>dt z`|bng(tq0hRZn+`?S5qUwk*u!egD>bw@mn4ah5%=tX|eV!%IwcTW_(w`1VTAio4V? z49jJl?_PZ8VyS)K>iN~9rNHYCm^Wah?@qikvGfC2XVvL>2HS@cIBTS)?mq&X3c9~h zkETF}Y)%{Bh7|1d)-_Kn6@*pbTg|8iw9(e>rY5zI(%(wKik#>WoiCx_MVIInJ#`Aa zf=#L6q=FV|y0>8BW5<0{g4!)m25Ps!7O2IskxfZ#8)b+-@?KE?0NM(jm*5uR?;hmB z2$RvVQ|2uPb;Zof>1@G0G}L;33ImNEX05PEr4T=%;3C?Hl*8Fv$f90{!3=x-eihiE zq53VP6WM&hb&_03JdM021N9A{CSCC8t@jiHexL*7_$+t;a$(mL`5aeQOIQ)X_Y?eC z2ebgzS-@{L1>gO^@Hanh5S|c$!b?{roL$2ima!0LuwbBrw$fz79aZP@ICOdmzp4fg zR(~8?NgYvX8faa0P9xWROLj0O;iVc3Ce*NiHbSZlwv^R8)cg(nwZDRF5j}JTE3U!& zuEDZv*W!4Uaa5QNo#`mLO0SHUaPf4ruF7^+U9VQ#`f7G$?R02VPE7-n-WlSymvbxTcUb6Px~OQkgvIA}mFbqX`X zkpwfoXq%Q54T9bn&JByA7`92modAe=Z1@21Dipw{c$`2QS(&ND*(2t3j;avuK%j4s z7KR+GUyAw?dA>yMFA?|eXi!IkUnBlM(CELSqcvNQvaJl%5Ij}~SGAvq)^V}6-$6}K z;F{EEQrIrk3*iH1;VUb*v~!tW7MI!ghIQ+DKJ1Mfr7p!rl=s@1Ux|0<5AEd%nU&qILKG^f|@CU;+8s5KZ zdAGuQoeM6ROmNBMflDfde0^1C8%%AT+fTJUou$Ys5yLMP%5>MS81{$b;K(rB$bSAW DRppr} diff --git a/mace-bench/3rdparty/mace/mace/cli/__pycache__/visualise_train.cpython-310.pyc b/mace-bench/3rdparty/mace/mace/cli/__pycache__/visualise_train.cpython-310.pyc deleted file mode 100644 index 48639ccc5b834099115dce8a9a6079d5895a8e97..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 11810 zcmbVSX^b4lb?&=+dd^+Wo+PD_WZ53uT1uoOO12eO@;lR3<^)-3BIW+3R9TcRx0tY zT2WEf+FD1i=pCbCbj*sW#%)%|t~d#vs<@IbUCGGbY$YpybMjYj=R1W;K~;RGHMKp> z3}!af9TjDD`w+8s)JoAWZqN8?YnEyo3bWr-D|5_YsfNZ}mTqWujb-+iK$Cq_1vGpnETSALvS{2iW`B5-5)O+15R+d)b4&?l1cH)ip?b2wca3yX*ve z_^MfX2YUom_hVcb1xEEKI|&)@#K<3Gk7MMI`-LsN@-C)cP|6>;LGOIItPA6Gvla^D z`DPFb=Og{F+3R}kvMQ{ze#o1(VXEHihD$+n=-+U_BHWG__S?O%-CQZVL;Jpm>L*XG zJPd?Adi><^$AFxy(-+9`n)l%G8jw255BfkJfB4apD-6hEwMQO(cm;@m;^fKNaUd(T z6?UAEhfc2ewR%|-7ViXp)gNZgpS|#G>DkiD9sk8wu8->L!-A|oBkN1YN>9r+zw7hH z#<2LT{4Mo;Uh=|Tr}T2nYu0w2= zfZ?SjM-Xy92!?Z#`hqNs2lR;(#FHkTY|LXNz-CH{(~64rZx&)CmSp)1+P!et(#jqT@B?Gc?3Q0FYmh)9MWNMlmDj zcJ_tQI8OU38$OA0x!LH%+q5U1u;30qi8{rGI#QT<;AEAnrNL}oOUYmv_(&v2sG+wSO z% zE3|s-1ZRD$$cC(oL9N|9w$=;=Sav8w`3{!)8^Wnp@&3YU)k9xF5x~k?>cOw84(BJ{ z^U{b-q&q!^v*q`DwN>a8DhPvn-AhoremLOWngOAs&lC`H8{}*GO*jUj5-BZ}UQ3G} zyqK1;Wo{a>-n^~ejC7`6()nHVNoXiT6yKQ%~Y6e<;YQB)-`R@joenA*^&Asg*j0QF!hE>j=?oeRnS&|bC9Jk%z2LG zuNj-^D9zLbr8R}x3^@eX^vw+FvTmIRZ5A-A`KSOF0*=Nt9b?LFFGMc6N0%R= z6)oSh%<)o`8iCP=juAZpJ+wn!ML(~Yn`)%KsDM)=PW~_&cA^GkT?|fw zM5LZ*z_hOS`IEm+^2im?K4U0`>h2ETHMz^^;4eMV?$w}8pIjPqigc~lg$jTk@mX2B zK|AV(1%SS-wzmN#BJ6szLC(f-{w(duQoH9dpOxwyHe7Dj(oieCduWwNLzQn7h+7p7 zIVsXdp&8$xWoUfy3-q~u3c|CiezUO}RyizG4+?5E=;JV+AulIb^|-$w%$2rR!`^3J zu3=P9|0YI zoWT1De1O;vS|##i6RRYhT4|=Rv+ycl%q|bSE+)|Q17Xna5ju&t<@0@?aP=wGE6Vr~qiXp=zoTe{X5_v})`+CQgawsLpOq zceNbard4Zq)+nm;YC+R<7c{QnXvC$Bzs}8~V`v4nq!qM`YW&@>c&AfV*3K-Kb7T%^ zY}Kl8t5s}}0jzdaWU5unxt+A=IC)&Bm7_Bv^kx@s|4Ogd=A`NQV^nE|@QYW4*6qUo zmty!8Y9x!c50FOSWq?oO324VTik8N_f2e8N+m^2VOf$5%G*i2!S=vp_)_$rv+E28U z_G8V}eq^QHAKp7J`GwKv6f_9A%AU&i6AGm^6`hE-#)s@X(eO_`43(A$nt3)qWov+C!4s4K47H#-*ho3~)zl2Op0=yKwdtyK>1|?@?`aZ)oArcA5yTS?4{+muybG+RkB`8y98^sO#!<8Cam|iN*21I z0n>JrO%oOq_Id%X8_}UW;u+ycEKhGfvbSxF!igM6f!7Qv@DavR(8E}Y11SYap}JU#vnQqH zz;l^`51l3tn$3`1hwK!S$q9_WWpiv^+Hu-BkV<36ZcT;h$RN+=m2>cB+|6`oL@ql# zr=T6a2kNr$99)*Vr139DHavsbWX$;}-^#SI+c|cG9R)>tUMVVdtSjWFw$UoF{NB7G z_%XCP?5?5$-zAN)-hGWcmHc=f`IZ^w;C00F%fm-1p>z+e+2zk;7C+Z4rN?paGT*|s z5QdMfj<8D>OsV+gcEaiKy{>@LVR2dS^+Wz!V3D2*HgYcr_%EVd*7!%L`lA5D>`2L# zo=1CGl}%(Q8^ipI6nc;a@5)C0JQ#3{dZ+)YE_H@{+U8=&+d!3+M;IE5$1dj)x>f_tuYBQZ%S4sR*Gqxtu;NgwQYf8x@zPUyWTHyF+22w?3(n z>Y)BBJi(^`a;B@o%&FR21uLhGaK|xV0n^t7SVM!B*8zx5GQE<3_FFxE%jodYb-xq3(lwF4Do7t{~5{@NG=4KYLtGlN1 zrOQR8;NaWL!omMAv!XQkJ{*}i&S_GY%!aCpF;ZGt9B3$sYvyLIdkI`Fxatf64Y=wF z7pcUIn2WGUu1CO?2iNN?FS%Y%xbmQlM@ZP@2p7Rs z0M}PpL2`XH;VO)|2%F%Fax$Fy3n-{m?;?)%BgeLv~p zwv}VkQ64HgUb!G%xzDNGjMt8^g!8u5JHYv3yn3L&H(o!&63*LJ@BruI@d|?eRJ?|S zC7idd;sMSh@hXD;<#-(lOE_;^$^D%1N`jivTCzh}%fE)%F}KYfX&``EF}9^`7D78| z!H!K@fT3jyeK}DfLf$-vaY)PNq)|yhKJA7SjFG!TCMgY>Fo6`9LQaM)Mh-bw(oC7# zS$24T8G-x}$UDmJBKpgwG$wa1LkUt?C=G&8*QJfPslJ!fZG@QSeiH1AlagY(&wU0b z-)G}GB(sr9GVCbG7by^ylXP2GeGw*!@^T|H zWu*)lEo--+7Z3+3uNpt#H1{*2;I-S*3Rh*$g9mY#dkl?nhVnSD17?f%HYH4@+4?3B zm?cG;C3Dk+NiwVIWdr64J~d3$N8wY$cZFhgLTy`TT4W=?q`#rlK@E*9WgA&K#81d` zf%<;Rvn4fzL2)3DrouEK>xcZE!PFf@V8GB*f?0$y8^$U_syxa-nPe~D9~u07=(MD| zkUlIF1`J{rSvLC#QXrWfL|C74F6R!6-;v{Ik%h=BwR~ncbAHgJlr^HOk zVIZ=U!)~Ib+Th3=*l09eJ$jF18U^dG5KV}RC0;psCMMl z9gZy*f#M{rDx7q7QhmF@pPgC^kxm4636H9QACdQu7~$uYWUVQNMHtWuW=g{w}* zpz!r7q6KC#8@~uWSqi`QXOR}DIM{?P{|4BHAB$gdsn+&_po9%of+}nysZV=zY+y>T z1EsNzDa9IEQYSBo(nh5uvper?Fm@*Ugks)M2v`7~0IQd$5usN0(hlak4=T zT4!BkX)o-Nh@T>Wllj*sGn{-`8ICAtH|Z*{du{+R559-J!!_7$8S4oNl}d8)Wt52T6hfyqoT5 zJ%A;1PQhnL?qpFJi!CDok4!Y2Jejdfd5qCLBJ{M>v7{)VFDO@hT@|TnRUXr7HC7%c zQ7!qtODbatB^N2!0^-JrFO2i}tIz7QDt#Amp*ErF#}{g~k&b~(h0ZNK&)^B>WzwuA zQ2?ia6G?jT0Iu!SvO3OmR6Rna* zY1t^rSTiX*aciU>MnY7cgw%*)Tt_|O-!@YICCG9UjCx8ZJ*Cl8M#_?xO(SnUXT#h>T@8iHxs72Cl0y_peDAhY}fwAfqT{ zOeZpCqzpuovaE08%p;+w?&wJWDy_MgHZsQZpC8eW>uBt$xL-sx4tidoo`K^M660S? z`d);*!&2Tt()ST53+clBeeYag{5OIBAs{`dw+Qy(h-5wB)BBR#Z2YUWKH$NWZGL; z1G#vVmkyqUm?L^do!2Sv4d+_J1xtcY{^qo4wWuCRS~Ku*=2hqBaCY)^$Kr20+jLfwuJwolIVH$Jigrr{L{hmP*&7g_A%ve2s5172;UG?yv4(p?43Q-&bX4N524-W@3CmXTmxdaA zW*NFXOfRB{K5_=`tmq}E890f+p-_iMCztmC?*Z?JcHN!^#qVb}Q1>q6=tcn#oEfu*x@gF%a*`E@XH1&k09<$&^SF!|gaV;sX#S+%`BuG?BUhP^+M>x*5?AFJ3|}(A`s-F^&+-BUGr*25-Gi9IYkv$qN@A0W$nlAL^oMySi`x= z1CxY|pL~PB5J1?~D(ltY<~&AKHoYhLvS^GVn|MJr9~p+k6QUT8lpzYu&29B*$a+V6yGUh3(XUgy9zd|g33xEJ`gYG%yMRI(lOd3&4?fo?r|52g< zM_BPM|0MDL6A@`$Tzd8q9eDx&IMsfNYNzDI^V0r*h-G9|Y!(Af$^g2lc23@H$Qgoz>>FFFq5~1yw%TI z&)Ni=n;K=jO-i%*x3@JPS7Mxo)9$4A&~h4T(%v?uEdOqzIGq9`I7S`b9r-dP1~t2s zc>b-_a#Ot&r@uf8l^nwvdR3f}Gxe(aRUBKG0vPiJ8bZqjx?crvYd*!FUx z8eG&RjZn+EM~&Y71ZY#hHIB001Dmo)f@=ai&G6R(xLk;ftLNx`mTM*W8n(pOHd@2A z!yne~D5~orb=~z8*UxnV%#KmB_vqtodyhY!$90XWM_r@M{VGCBH(>99l#%O2oR?^! zh1&_(&X%ZIU%KZoHoFKtZDbsG!#MVh+V-er3_|-P?b-7H6GI-gq&XfMk4u_!p;%mE z&d2?sy}S28WbXw2gUIgjz|P&{5Sf7VP97oy{ZH*33qoYy_@1Zw#~{M@?cF=R8zN)l zW5L}R*|m3!AD{53B%LrF`>mB4yxKFr8rupZ;ep|_(n6r`mlO*5W>pnut_y{jd zUYm2AAaC~!FSrBo=(Kx8uB)Ah3gi5okc`8}ma2KUSD4lUwbTYh%+MwP1r0 z5ODVZ@=T;gEH3b|*jy!n=13w(&^P+94sI!eMXsaEf|n_fG64h-dS;mpK!@`**anY4 z#|uaR7%jp~%vrIhpzlBs%na|(HsHi!r4b+=yNRu2>+^Faf?6I^!HFa~AOQR}(yi&) zJQWqBZ-YLCj^q5;H6BqJ1wQO2;1hEWFCPxXLzj8?pwCSzjI@#o35I5(VScXiIC&RL zNDhq->7JgIy9BjIrgKFGY(egDd0_)cJm4>!5y6E+Trk{19cCCL!^LPk5Jr+D({sFV z1SXdxstOYFs9g>Tp+Gn$Rc1?|mCXeVK?T{8F%Kdc<*pH=204x(HhBr=3&(+;+=D}o z&!azq;B9J@LgpfVrbaeDY;AZ~y^}2z>aOHw^Qst?TtQUnq zC=!ZHJ_p}@oEIK^6Y37Z2f)W>!||9u7!oAy6dwr2Bx5`<7Uuo&YcsrLi(@tarn z!=N97Js3O%fyX3iVtjbwcdr@4VQ}Xk=|b;Rbeiv-4$MsjqL)Geh>i2zGeYzdKOUD8 zV=%5$y?sJ(2FCFk$&-WeaH#ikC^id|1)DCJxt^J8665#7%*XwHsj6rZ^e80(pd$7( z1lOsIi>htkV5-s;uD6V7SIb-0w0q|UQr9@(b%L`II*|(F!a@0c>uX^SOJn%Oj%M)EW^d^H3NYc7J7UX?<#r-hYLn0?@-r zYD5uHDK9dWq6fq#m5UlFQ?j8Sb|s#FD*|6Rh072hB`2aAK|VkpsH^N!N*26v2xeh? zXRbndiM^$qIz^ub(KBKg$?Y=U+5!aGC{U_K!FE6uG5U<&hQ}c|osYJFKdCDz$x# zGD_K+=8<-#Z1^vSZ}{}e`fyVto?OYs@QY9r8-NF>r@WnVyf7RCLDk2WQ>US45sS~_ z#eJmcDP0rxY&-NDgln#1KvDvZ0U90Hf zs<`R_nowS1ErDv|(oir%HDsP@d{&6p_UR%vm?zx%iXNYhs{?_%VL%O|lwTpZJr+Gp zFw@G2yJ1xg!nTS%o1{GMApt+}ak`)iX5jw%s0D&%fO{Lp8b%bp?r7O%>TNBI4Sdrw zhP~H&4=C6WDOIgeYJPTE|JJFGnxFkYG2DRufyG8krtxS5cpETCKIlg`geFn$o~s(ftoJ4^?HWd)~Ym#;>92kc! zHyDUb@xjSH@_c$|&c?+9lf3)ll_&_7z&&EOwUTx`8jcE*Aq?G!#jk~VIeH}&jKdBd z2+vFfByBh_1{=8nSob7wmMc*u8jbLhW(?{<<7ybk2mklCzx}O5`{y)0Gm*(q7mzM# z<47>q#tFB-7Yf^eQl<@hLv$aYQ&xGTW4rXLl;3XzD3yM-4x@(Nh zj>(xZ%pJo;5rf|F5vhU*Gi0LtL4Jmh1d&FmhGgex25Bccog)AR0!bgZdKm{476K2; zFDU6_fy+F!0>d4S#&}PeFo1o027^Nw9EPAs+(M-;D{c|}j>uoIvS2DeN_%-W5P=Sa z_?V=@)g!49yw87}vH4?Z*1C9Z;au9@x%%{K_nNTYvhG~JoUoltT;MWl z%3KX?nJpQHs&w6q+=#4PSf$q+*DtI)Kh!Km63)T-Q|XG@_!oR*$Ul ziHeRy#ew;g>8gem`j#{LhBd*6BYJL{%o`>|W54{L5{?^iW1X<-bPb}gJq zu{bp=&IGir|-i^!9*@I zE8}Zd=ARYWL-!!(%o}Hx#utY^J3|8j!a-OOAUhN3n0&NEfF>U@5tfNijo*g+LE=$*6)X99?q6n{r-uN5CgMCgpWj_SPi?cfZMSE z;v;o!aL66;&B6yfZq==2$wma4n<96hlG2 z0b@q0rXgd(7)v>-GG>ffC}(xXiZL5i=E#&`%uZF;X38<T7zl*vcRCCyAAJ|!H1=q5IcVlj@` zqoc4tzYP|{2wf-)b2O*o)B|c26;b(Az>ppOQD$=h`B=%qDi}P1@Cuz^pYkDOhwXMibW&9WI;HH;SMc?W?=uG zL4G4y&vNI7j`M(uKY>h5(#{C5&OKV;29PD3gg~MvBn{yU!uf@oGVX*W%a&gs`t|1qSlFrZcCozu5v2wUwu-c~do_6H z((Oy{hgZLma2)tTttzu;)KK`FFHBH%!_b?p{hJu>6Mv*TYE{3z=SY?2r$!A#e_FNk zD69V&qlWm;Of<$>61N(U*2Aa;uJSPfd60k>1DT7$q*3Ep5gnowg=>tg@_z?g1XvZy zWbv3&?LRH6Kwg;7;s-3*>I*qLy%*~jJ*e>X=GKe1u*EE-vf?A!VIQYM0AH?6$W!Rh zKZy?b2AP6Z{xCXpK8iCS9TB}xU+Nq6fkBGUspE|J=7_;+1osK@a3+!qO4fh|R=(bA zQcCh~&ioj>g{XRZZ|+MhfVSqpIKXGkmngB^C^9~#B+c3^&Dt$!COrlfQKp$GNwckh zW*J_lnfVelqvTZdfSQl0oGGu$LEp-YfFQ{^H}$@v1VkmOY@7>KESNnq{t-Rw)P|8t z1vW00#8%XKTeLK)ILn)(yuW)TxpEUQwUnlEQHm}wrb|0DoOrV-- z0lRhES+tf~6~q4nj1MYOzp{DT@yaa4HKRC(sUf@i(gO z?*9sDc;aWUZEplN9Z)MX_W@y$eIpngRo!XIOG6Nr*8 z`9UcSOvgMnVFmI^aVW`HU`sa=;{h?I3$Gy(b>Ql5Q2Br;Ke9N94hf|DI1rNhCD81R z*nPiDrG6T1^Pu0M*?{o;fy6nv&gjk!eCpd z{HUxmY1r`#1G`W@Ke5Em`;&&cUl<$>U<+j4uNtSM=Ll+iLDSg5>yG_BFSe(JSTX~w>^S7d6IzqYCq z>w4E9*!vfk?*{(tTH>M)lrb#}%9vIMQnptA`rvDWOT8aB*XY~T*9S#SAL78g)(uV7 zeIvWrvC#2m=k;T0jsEr1ubti&wK<2Hd41@$p+tGh2PfBd-9B@DNYp%qNFZ}*&xWS@ zKD2Z8wX4FaA zdTz5{KXcDfw>rHp*yGMSi{j;t_XebegS@mA&N`HrUqt zjCGM+VAHH~>59lUzO4gsYx<@58J7VwoC_(!5=bEA8EQ|CFTk)->VpW3xZA3 z9ZLSVdpsHjlUw1@!LUZr7r>0IKJCi##piuG+2p+wcqnKfShwg=Bo1Q!h8>;2zt7P= zBWjWXFxroRBq=`}Iq8qEoH}AYZz{+qya@8O5%ci=r$NF*L+xNE@C`Hy>p30ay={=z z`Ah__8@QIXB$qSE(j!L>27#{;Su!j2n1Ry)Cm+y8Ea&e4USkd~7|plEE6a1lWRCoj za8WaczCyhJ5b#=Zcr9N7uZ3Vz@XD#j@xB0ftvS5bFM-!e@G9dVr(}t>klr1D*GBNR z!uZ;mg^`fK^vQPf~UnS;qpH657#b7-c_;*hX@C(l)rs?z9bV0+YjC2GFqQ6+b?k54^$T z5Ra$O>x;LL*;so2Ue*2XuZ|>QYFT@e9UI6 z4ZzhXrTI4;Xgv;BT1Ro5`9$`q3In0&YENw`0Tkw z&ecY;&j@@F?Dc#wmj^Cd;2tkrt;CP9xBXDJFrFOcs%3$ki8cOeKBRrA4TR zl<*v8(NRfr5e{91_c6-{!Q=#a{^k%!%|yB;?Tlct9|9miKNW}s!(cXeRa*m+gDqHe)BuLeP{x@O)C zue$D3?QXGlcdE93;oSW3wAH!X^5(1eI(M#>-S#AG?dj^;8@gW>CYN++YkAVz@N-)g z)UT|WKXu<~Uz&RJ)s=m#my^wV(yrQfI&XER9j@hlqN63zIv_gs7o}<%)2>}~`X;)|3<^BVwW=kx;)6{H<<#(N$ zEfH-=PU)$63{Q4yk}x$Ly~`6xTkFF~n3|r?!7;4dxioh3r5i8(V0z_<=xke6-7(!Z zy>DI%h#mWq_JMglocV17k!FmG`UQR3>{yyiu=TmO(?vAA3RH+}L?QZHuyKuOSk;xR>RJ1CYVR3w@0sM@XA>7*z>`C#d~&Fx45mf%f_d4vvS+0|;q2UC zc3^XurCsKiUs#H6Fl|`^mNYAx)noJe4QA(OehpRKLo^`C76^xlfxu5(2ZWe>T1?K> z`fxKshPg>GLf;~y7)3xp0sjfQl$l#$vtT#`MJl2NB}bcgv>OB*$N)iA0@E?TWYBax zOvLX9lNWePd0nCRl1~r*BkDeN-pDzU6N;3UONmClTn$js7-)9_+C|Tz5@J*VEJ8d% zhypJn8V;wxlL!R!oB#!*lzqUUPb*^377#E#+d$~@5S$*AGS+8kqJ}?EYDLV0Q}USl zG&#Wu82SjBZpc&0$LmSr6NpAn6re!;{Me^Q!4%+1h)~KrxTzz|;AdIdzT8@2O05;2 zt_3`GhfdB_cxNN%XbOj2S&_yKX2UtgjqVFzHxm*-laRy!c5BH1rqQpKZ^J{|?jdI(L~SH5nHZhyS}UA{ z$}tNBn*|bC7=f>Y5-|8N3L`rp0B}n6o9=;K$oz&mQ3WS=H5j|x1KXy&fqo4l=z#X54fIKx{*%rI zTPWhX27%i z``eq3IyK90$PMgT2G&5xO3I1Nv9!w$7;pjgRW5IW@SC zGlG`WxF|u#X(0s56@(0@7Ysh?AT1c-b7utDIH=I4K+W8X@;Buk4+ml~H+;Wtki)~^ zw2WIpPP=h~65`G&IBr=zAkf)w0$q;WStsT>a^ldlX>xSXAmMHwL--60hCj%VpNq7Q z!t+jcPDEz|o&R$WH`N(YN585^%T-&W?kqa8)wnj*B`t%Mn;^C!gscu*JD1yxayhk6G{-mY9e<3fCr;~00-z$v0!>&w7AUhDDJpM(cwDT5iQPE_`+t3D(|UdR%V zsO(+0e|Q3E^5(L@!}LSax#x0OB`S8UjeSTub3O-^ub69dDNlVqZ|;LVcMkmNfpv9q z$AJ%h*f^KPvGn^Ded~i0ch3Fkx%Kj7*TD~YM0jy4gdZ`SENUQcRvzTNgjYA|vzn+< zQt(D+Lk()F1$jyoxub?c$B~E{IJF|WD8=wD#ozOqA?&)UQU_~X&ed@2pHtX2Q-CrX z9eWo_K+y3*>%|YJ?FM}_rpjM61L{!kE_zB*SF$yQTM#~h@_x}%5|c_kLfBL{r%&b) zl&o8l3ZJ?gH0@3J?0Gy&tR%iHEstR+J^r~%?qfU|e>h4je#XFpO~${({iWRa=j;M7 z{#C^owpbmsw?u23`V2?L#m^RfCI_+wbxZeFuRvsbTU(5sqH?`OBY)A8>ut_<0ljq> z0YS1PjdZDdTM9j7wKaW7Oe*;px}a|9-sWytjs29&ZOh~M$Npt@ZrdI&%k>tG1LT3; zDkc*p3*KAwwnS?M6qmfWRT7g*K4vedt1fk)3+}(sVEM`Jdf?hjdC{M>=qcF(?JYs8 zPxIIpxU&cuBunDUQv4W((z9exdN%emGE2K3ZzVTN{g0RBW-0frK+n;;zvwARDeW!M znnDj*llzL$K(ZtzmF(79PB~L1_W&b#;h3FFi0X6 zH^#~^utOkeE{FIl0-lixRTv-qi|je!dSaAik{Quu!u{1yT*dV@?GBa$y% zvBPUk<;6MvzG7E*NjFOdPjVJZpK`tq#D`c40x*)nm8Hga25t?kTv+igA6(sYXW;h0 z8n-sK`ixkECp{qMq?jrYPE$;S$TTGI2J|kd#oE! zwyY-?eO9vU$B#_`^1+14R+7pIoc3dszg)QTE$Q5eq1N+It!F6 z4!tq74Blvs-;+CCK-Y=)2H%CjkjJoKY@c5}Bes$;Y}%b_+AB8gU01CiS${d%bR>aJ zZZNQDK1{J)BHOjr8Y3-V&1aV|{KW{1ErXqNk8n&?}w{p*c^( z*yJ2)@;XUDUQP??SAj8ABj1PisD_4wcH9;Xk}(tu!42a8-Z_N}KH#mOjnB@6d0_@f zlT~s9K$3gUQvvw74&6m!&L%eE?6xW~XiFfKiD!u%RYNA+4g*s~(S;$&?1vwl0GAXv zy}+OQs4Lz~Hx>8ep85_D16~je1+Jx2KXUwNcEy}B^@ygPd*HRO@6B(o&`E1!y0Y<} zz4E5#`<~^~NqcL`-Ywd@SI3g}UGvB8l~>&y`2N81i^=l#RC%vh-n*tsmiH|j|E0Nd znfX!cdtL8#rOdsexi@WhrR+_jy=i54(%uGFILm8OIo6*#npCI8{3|4E0KArnz2$ zbh#xJCB7DbwdH%G#5Pee>C=H5rn*N36lk&(l+Yr-6a^(THK3q`M!1QB1s~%xWGT>Q zDJY>G<|ug5%;w}11=Uyi3E+hQSXk$vGG$FcVpr z2{{ROUkCu+MM0x=HhEdba5;~GL&cl!cKtt{pN*%R@TxcmPr z7`*G$eY;~`j}nt@@!;aYbges8+XGicR^zGO!(#7Yyw`tGJbE!%>suIvbK|DAR8zm$ z)StFFQnpUf*138-)pbhjIt6$6n=NoXVg58+p{{RE*EOfxdeUvZ4-MKz%lxUO({LgT zN8FD!bLZ9Duin+94nHRzelA&i0UNmwm$d7eA8OV4itHa$@tMVEHg)U}h~ak~sUu$T zh&NfwJ>D^oX!ET0r*@nWcbs@a$C|KXXFxCbk{#PTdL5#zW3?sKc}(m)mTzP!^n~$a z!P9NtLgYnq-`!&pzKh??>ln;quz=I9_xppZf>f?|M2#4(Vps9%4V`nhKOSDJe5Mf$VDbeGmge>O_nwB28{ z?AKOiCFE$|B9aUQk`8{8433ipWWw^je`MoQS^nn?;rFh{k5=g9x9}2y zgvQ|q0N_sb3>-fPV!|ZmevecozYfq-_#0&`$&xkjn1DPSKAZ99$i*%)3X%alMb3e1 z2{?xsr)R?8)pNgS~pHb5^{Rw6I zgtGjCY8R>Y|3Woxtq{^2@RvHtImJO;E{NPgN z`j5=@%daHOUDu6i%9sT?b@Tj<^BYvlXBv}HbzPUKr;Ol&dMv>lPaAe5x=$yVGa0Sg zu7hi4#@&g&3kf&_DJXF0J~L_zMiOMqrbgXou7adn2fsgNY)UjAN-&2DtCbOaI}^JG z6UE{{NkkQB2Cvkq%h1xR3^H7r94ByYt^wDyzTlNqnX-sG~JzsTV|aLof$Pr)YiXq z?$)_%VrO4Aa`Z$ta-K_EeBs|-lT-J#hR?NFD-%?`prUu>Fku25N;6uFz)os&Wb`D; zP$o;pK%z#ftYV3~VaS+BlBIOU&&?$I9q<^ee`wl4!;iDMGHQ6=H`x~VE$mx7v~Vb+ zg*2eFJD0|n` argparse.Namespace: - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - parser.add_argument("--config", help="path to XYZ configurations", required=True) - parser.add_argument( - "--config_index", help="index of configuration", type=int, default=-1 - ) - parser.add_argument( - "--error_threshold", help="error threshold", type=float, default=0.1 - ) - parser.add_argument("--temperature_K", help="temperature", type=float, default=300) - parser.add_argument("--friction", help="friction", type=float, default=0.01) - parser.add_argument("--timestep", help="timestep", type=float, default=1) - parser.add_argument("--nsteps", help="number of steps", type=int, default=1000) - parser.add_argument( - "--nprint", help="number of steps between prints", type=int, default=10 - ) - parser.add_argument( - "--nsave", help="number of steps between saves", type=int, default=10 - ) - parser.add_argument( - "--ncheckerror", help="number of steps between saves", type=int, default=10 - ) - - parser.add_argument( - "--model", - help="path to model. Use wildcards to add multiple models as committee eg " - "(`mace_*.model` to load mace_1.model, mace_2.model) ", - required=True, - ) - parser.add_argument("--output", help="output path", required=True) - parser.add_argument( - "--device", - help="select device", - type=str, - choices=["cpu", "cuda"], - default="cuda", - ) - parser.add_argument( - "--default_dtype", - help="set default dtype", - type=str, - choices=["float32", "float64"], - default="float64", - ) - parser.add_argument( - "--compute_stress", - help="compute stress", - action="store_true", - default=False, - ) - parser.add_argument( - "--info_prefix", - help="prefix for energy, forces and stress keys", - type=str, - default="MACE_", - ) - return parser.parse_args() - - -def printenergy(dyn, start_time=None): # store a reference to atoms in the definition. - """Function to print the potential, kinetic and total energy.""" - a = dyn.atoms - epot = a.get_potential_energy() / len(a) - ekin = a.get_kinetic_energy() / len(a) - if start_time is None: - elapsed_time = 0 - else: - elapsed_time = time.time() - start_time - forces_var = np.var(a.calc.results["forces_comm"], axis=0) - print( - "%.1fs: Energy per atom: Epot = %.3feV Ekin = %.3feV (T=%3.0fK) " # pylint: disable=C0209 - "Etot = %.3feV t=%.1ffs Eerr = %.3feV Ferr = %.3feV/A" - % ( - elapsed_time, - epot, - ekin, - ekin / (1.5 * units.kB), - epot + ekin, - dyn.get_time() / units.fs, - a.calc.results["energy_var"], - np.max(np.linalg.norm(forces_var, axis=1)), - ), - flush=True, - ) - - -def save_config(dyn, fname): - atomsi = dyn.atoms - ens = atomsi.get_potential_energy() - frcs = atomsi.get_forces() - - atomsi.info.update( - { - "mlff_energy": ens, - "time": np.round(dyn.get_time() / units.fs, 5), - "mlff_energy_var": atomsi.calc.results["energy_var"], - } - ) - atomsi.arrays.update( - { - "mlff_forces": frcs, - "mlff_forces_var": np.var(atomsi.calc.results["forces_comm"], axis=0), - } - ) - - ase.io.write(fname, atomsi, append=True) - - -def stop_error(dyn, threshold, reg=0.2): - atomsi = dyn.atoms - force_var = np.var(atomsi.calc.results["forces_comm"], axis=0) - force = atomsi.get_forces() - ferr = np.sqrt(np.sum(force_var, axis=1)) - ferr_rel = ferr / (np.linalg.norm(force, axis=1) + reg) - - if np.max(ferr_rel) > threshold: - print( - "Error too large {:.3}. Stopping t={:.2} fs.".format( # pylint: disable=C0209 - np.max(ferr_rel), dyn.get_time() / units.fs - ), - flush=True, - ) - dyn.max_steps = 0 - - -def main() -> None: - args = parse_args() - run(args) - - -def run(args: argparse.Namespace) -> None: - mace_fname = args.model - atoms_fname = args.config - atoms_index = args.config_index - - mace_calc = MACECalculator( - model_paths=mace_fname, - device=args.device, - default_dtype=args.default_dtype, - ) - - NSTEPS = args.nsteps - - if os.path.exists(args.output): - print("Trajectory exists. Continuing from last step.") - atoms = ase.io.read(args.output, index=-1) - len_save = len(ase.io.read(args.output, ":")) - print("Last step: ", atoms.info["time"], "Number of configs: ", len_save) - NSTEPS -= len_save * args.nsave - else: - atoms = ase.io.read(atoms_fname, index=atoms_index) - MaxwellBoltzmannDistribution(atoms, temperature_K=args.temperature_K) - - atoms.calc = mace_calc - - # We want to run MD with constant energy using the Langevin algorithm - # with a time step of 5 fs, the temperature T and the friction - # coefficient to 0.02 atomic units. - dyn = Langevin( - atoms=atoms, - timestep=args.timestep * units.fs, - temperature_K=args.temperature_K, - friction=args.friction, - ) - - dyn.attach(printenergy, interval=args.nsave, dyn=dyn, start_time=time.time()) - dyn.attach(save_config, interval=args.nsave, dyn=dyn, fname=args.output) - dyn.attach( - stop_error, interval=args.ncheckerror, dyn=dyn, threshold=args.error_threshold - ) - # Now run the dynamics - dyn.run(NSTEPS) - - -if __name__ == "__main__": - main() +"""Demonstrates active learning molecular dynamics with constant temperature.""" + +import argparse +import os +import time + +import ase.io +import numpy as np +from ase import units +from ase.md.langevin import Langevin +from ase.md.velocitydistribution import MaxwellBoltzmannDistribution + +from mace.calculators.mace import MACECalculator + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("--config", help="path to XYZ configurations", required=True) + parser.add_argument( + "--config_index", help="index of configuration", type=int, default=-1 + ) + parser.add_argument( + "--error_threshold", help="error threshold", type=float, default=0.1 + ) + parser.add_argument("--temperature_K", help="temperature", type=float, default=300) + parser.add_argument("--friction", help="friction", type=float, default=0.01) + parser.add_argument("--timestep", help="timestep", type=float, default=1) + parser.add_argument("--nsteps", help="number of steps", type=int, default=1000) + parser.add_argument( + "--nprint", help="number of steps between prints", type=int, default=10 + ) + parser.add_argument( + "--nsave", help="number of steps between saves", type=int, default=10 + ) + parser.add_argument( + "--ncheckerror", help="number of steps between saves", type=int, default=10 + ) + + parser.add_argument( + "--model", + help="path to model. Use wildcards to add multiple models as committee eg " + "(`mace_*.model` to load mace_1.model, mace_2.model) ", + required=True, + ) + parser.add_argument("--output", help="output path", required=True) + parser.add_argument( + "--device", + help="select device", + type=str, + choices=["cpu", "cuda"], + default="cuda", + ) + parser.add_argument( + "--default_dtype", + help="set default dtype", + type=str, + choices=["float32", "float64"], + default="float64", + ) + parser.add_argument( + "--compute_stress", + help="compute stress", + action="store_true", + default=False, + ) + parser.add_argument( + "--info_prefix", + help="prefix for energy, forces and stress keys", + type=str, + default="MACE_", + ) + return parser.parse_args() + + +def printenergy(dyn, start_time=None): # store a reference to atoms in the definition. + """Function to print the potential, kinetic and total energy.""" + a = dyn.atoms + epot = a.get_potential_energy() / len(a) + ekin = a.get_kinetic_energy() / len(a) + if start_time is None: + elapsed_time = 0 + else: + elapsed_time = time.time() - start_time + forces_var = np.var(a.calc.results["forces_comm"], axis=0) + print( + "%.1fs: Energy per atom: Epot = %.3feV Ekin = %.3feV (T=%3.0fK) " # pylint: disable=C0209 + "Etot = %.3feV t=%.1ffs Eerr = %.3feV Ferr = %.3feV/A" + % ( + elapsed_time, + epot, + ekin, + ekin / (1.5 * units.kB), + epot + ekin, + dyn.get_time() / units.fs, + a.calc.results["energy_var"], + np.max(np.linalg.norm(forces_var, axis=1)), + ), + flush=True, + ) + + +def save_config(dyn, fname): + atomsi = dyn.atoms + ens = atomsi.get_potential_energy() + frcs = atomsi.get_forces() + + atomsi.info.update( + { + "mlff_energy": ens, + "time": np.round(dyn.get_time() / units.fs, 5), + "mlff_energy_var": atomsi.calc.results["energy_var"], + } + ) + atomsi.arrays.update( + { + "mlff_forces": frcs, + "mlff_forces_var": np.var(atomsi.calc.results["forces_comm"], axis=0), + } + ) + + ase.io.write(fname, atomsi, append=True) + + +def stop_error(dyn, threshold, reg=0.2): + atomsi = dyn.atoms + force_var = np.var(atomsi.calc.results["forces_comm"], axis=0) + force = atomsi.get_forces() + ferr = np.sqrt(np.sum(force_var, axis=1)) + ferr_rel = ferr / (np.linalg.norm(force, axis=1) + reg) + + if np.max(ferr_rel) > threshold: + print( + "Error too large {:.3}. Stopping t={:.2} fs.".format( # pylint: disable=C0209 + np.max(ferr_rel), dyn.get_time() / units.fs + ), + flush=True, + ) + dyn.max_steps = 0 + + +def main() -> None: + args = parse_args() + run(args) + + +def run(args: argparse.Namespace) -> None: + mace_fname = args.model + atoms_fname = args.config + atoms_index = args.config_index + + mace_calc = MACECalculator( + model_paths=mace_fname, + device=args.device, + default_dtype=args.default_dtype, + ) + + NSTEPS = args.nsteps + + if os.path.exists(args.output): + print("Trajectory exists. Continuing from last step.") + atoms = ase.io.read(args.output, index=-1) + len_save = len(ase.io.read(args.output, ":")) + print("Last step: ", atoms.info["time"], "Number of configs: ", len_save) + NSTEPS -= len_save * args.nsave + else: + atoms = ase.io.read(atoms_fname, index=atoms_index) + MaxwellBoltzmannDistribution(atoms, temperature_K=args.temperature_K) + + atoms.calc = mace_calc + + # We want to run MD with constant energy using the Langevin algorithm + # with a time step of 5 fs, the temperature T and the friction + # coefficient to 0.02 atomic units. + dyn = Langevin( + atoms=atoms, + timestep=args.timestep * units.fs, + temperature_K=args.temperature_K, + friction=args.friction, + ) + + dyn.attach(printenergy, interval=args.nsave, dyn=dyn, start_time=time.time()) + dyn.attach(save_config, interval=args.nsave, dyn=dyn, fname=args.output) + dyn.attach( + stop_error, interval=args.ncheckerror, dyn=dyn, threshold=args.error_threshold + ) + # Now run the dynamics + dyn.run(NSTEPS) + + +if __name__ == "__main__": + main() diff --git a/mace-bench/3rdparty/mace/mace/cli/convert_cueq_e3nn.py b/mace-bench/3rdparty/mace/mace/cli/convert_cueq_e3nn.py index 5aa2056..c2399ca 100644 --- a/mace-bench/3rdparty/mace/mace/cli/convert_cueq_e3nn.py +++ b/mace-bench/3rdparty/mace/mace/cli/convert_cueq_e3nn.py @@ -1,208 +1,208 @@ -import argparse -import logging -import os -from typing import Dict, List, Tuple - -import torch - -from mace.tools.scripts_utils import extract_config_mace_model - - -def get_transfer_keys(num_layers: int) -> List[str]: - """Get list of keys that need to be transferred""" - return [ - "node_embedding.linear.weight", - "radial_embedding.bessel_fn.bessel_weights", - "atomic_energies_fn.atomic_energies", - "readouts.0.linear.weight", - *[f"readouts.{j}.linear.weight" for j in range(num_layers - 1)], - "scale_shift.scale", - "scale_shift.shift", - *[f"readouts.{num_layers-1}.linear_{i}.weight" for i in range(1, 3)], - ] + [ - s - for j in range(num_layers) - for s in [ - f"interactions.{j}.linear_up.weight", - *[f"interactions.{j}.conv_tp_weights.layer{i}.weight" for i in range(4)], - f"interactions.{j}.linear.weight", - f"interactions.{j}.skip_tp.weight", - f"products.{j}.linear.weight", - ] - ] - - -def get_kmax_pairs( - max_L: int, correlation: int, num_layers: int -) -> List[Tuple[int, int]]: - """Determine kmax pairs based on max_L and correlation""" - if correlation == 2: - raise NotImplementedError("Correlation 2 not supported yet") - if correlation == 3: - kmax_pairs = [[i, max_L] for i in range(num_layers - 1)] - kmax_pairs = kmax_pairs + [[num_layers - 1, 0]] - return kmax_pairs - raise NotImplementedError(f"Correlation {correlation} not supported") - - -def transfer_symmetric_contractions( - source_dict: Dict[str, torch.Tensor], - target_dict: Dict[str, torch.Tensor], - max_L: int, - correlation: int, - num_layers: int, -): - """Transfer symmetric contraction weights from CuEq to E3nn format""" - kmax_pairs = get_kmax_pairs(max_L, correlation, num_layers) - - for i, kmax in kmax_pairs: - # Get the combined weight tensor from source - wm = source_dict[f"products.{i}.symmetric_contractions.weight"] - - # Get split sizes based on target dimensions - splits = [] - for k in range(kmax + 1): - for suffix in ["_max", ".0", ".1"]: - key = f"products.{i}.symmetric_contractions.contractions.{k}.weights{suffix}" - target_shape = target_dict[key].shape - splits.append(target_shape[1]) - - # Split the weights using the calculated sizes - weights_split = torch.split(wm, splits, dim=1) - - # Assign back to target dictionary - idx = 0 - for k in range(kmax + 1): - target_dict[ - f"products.{i}.symmetric_contractions.contractions.{k}.weights_max" - ] = weights_split[idx] - target_dict[ - f"products.{i}.symmetric_contractions.contractions.{k}.weights.0" - ] = weights_split[idx + 1] - target_dict[ - f"products.{i}.symmetric_contractions.contractions.{k}.weights.1" - ] = weights_split[idx + 2] - idx += 3 - - -def transfer_weights( - source_model: torch.nn.Module, - target_model: torch.nn.Module, - max_L: int, - correlation: int, - num_layers: int, -): - """Transfer weights from CuEq to E3nn format""" - # Get state dicts - source_dict = source_model.state_dict() - target_dict = target_model.state_dict() - - # Transfer main weights - transfer_keys = get_transfer_keys(num_layers) - for key in transfer_keys: - if key in source_dict: # Check if key exists - target_dict[key] = source_dict[key] - else: - logging.warning(f"Key {key} not found in source model") - - # Transfer symmetric contractions - transfer_symmetric_contractions( - source_dict, target_dict, max_L, correlation, num_layers - ) - - # Unsqueeze linear and skip_tp layers - for key in source_dict.keys(): - if any(x in key for x in ["linear", "skip_tp"]) and "weight" in key: - target_dict[key] = target_dict[key].squeeze(0) - - # Transfer remaining matching keys - transferred_keys = set(transfer_keys) - remaining_keys = ( - set(source_dict.keys()) & set(target_dict.keys()) - transferred_keys - ) - remaining_keys = {k for k in remaining_keys if "symmetric_contraction" not in k} - - if remaining_keys: - for key in remaining_keys: - if source_dict[key].shape == target_dict[key].shape: - logging.debug(f"Transferring additional key: {key}") - target_dict[key] = source_dict[key] - else: - logging.warning( - f"Shape mismatch for key {key}: " - f"source {source_dict[key].shape} vs target {target_dict[key].shape}" - ) - - # Transfer avg_num_neighbors - for i in range(2): - target_model.interactions[i].avg_num_neighbors = source_model.interactions[ - i - ].avg_num_neighbors - - # Load state dict into target model - target_model.load_state_dict(target_dict) - - -def run(input_model, output_model="_e3nn.model", device="cpu", return_model=True): - - # Load CuEq model - if isinstance(input_model, str): - source_model = torch.load(input_model, map_location=device) - else: - source_model = input_model - default_dtype = next(source_model.parameters()).dtype - torch.set_default_dtype(default_dtype) - # Extract configuration - config = extract_config_mace_model(source_model) - - # Get max_L and correlation from config - max_L = config["hidden_irreps"].lmax - correlation = config["correlation"] - - # Remove CuEq config - config.pop("cueq_config", None) - - # Create new model without CuEq config - logging.info("Creating new model without CuEq settings") - target_model = source_model.__class__(**config) - - # Transfer weights with proper remapping - num_layers = config["num_interactions"] - transfer_weights(source_model, target_model, max_L, correlation, num_layers) - - if return_model: - return target_model - - # Save model - if isinstance(input_model, str): - base = os.path.splitext(input_model)[0] - output_model = f"{base}.{output_model}" - logging.warning(f"Saving E3nn model to {output_model}") - torch.save(target_model, output_model) - return None - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("input_model", help="Path to input CuEq model") - parser.add_argument( - "--output_model", help="Path to output E3nn model", default="e3nn_model.pt" - ) - parser.add_argument("--device", default="cpu", help="Device to use") - parser.add_argument( - "--return_model", - action="store_false", - help="Return model instead of saving to file", - ) - args = parser.parse_args() - - run( - input_model=args.input_model, - output_model=args.output_model, - device=args.device, - return_model=args.return_model, - ) - - -if __name__ == "__main__": - main() +import argparse +import logging +import os +from typing import Dict, List, Tuple + +import torch + +from mace.tools.scripts_utils import extract_config_mace_model + + +def get_transfer_keys(num_layers: int) -> List[str]: + """Get list of keys that need to be transferred""" + return [ + "node_embedding.linear.weight", + "radial_embedding.bessel_fn.bessel_weights", + "atomic_energies_fn.atomic_energies", + "readouts.0.linear.weight", + *[f"readouts.{j}.linear.weight" for j in range(num_layers - 1)], + "scale_shift.scale", + "scale_shift.shift", + *[f"readouts.{num_layers-1}.linear_{i}.weight" for i in range(1, 3)], + ] + [ + s + for j in range(num_layers) + for s in [ + f"interactions.{j}.linear_up.weight", + *[f"interactions.{j}.conv_tp_weights.layer{i}.weight" for i in range(4)], + f"interactions.{j}.linear.weight", + f"interactions.{j}.skip_tp.weight", + f"products.{j}.linear.weight", + ] + ] + + +def get_kmax_pairs( + max_L: int, correlation: int, num_layers: int +) -> List[Tuple[int, int]]: + """Determine kmax pairs based on max_L and correlation""" + if correlation == 2: + raise NotImplementedError("Correlation 2 not supported yet") + if correlation == 3: + kmax_pairs = [[i, max_L] for i in range(num_layers - 1)] + kmax_pairs = kmax_pairs + [[num_layers - 1, 0]] + return kmax_pairs + raise NotImplementedError(f"Correlation {correlation} not supported") + + +def transfer_symmetric_contractions( + source_dict: Dict[str, torch.Tensor], + target_dict: Dict[str, torch.Tensor], + max_L: int, + correlation: int, + num_layers: int, +): + """Transfer symmetric contraction weights from CuEq to E3nn format""" + kmax_pairs = get_kmax_pairs(max_L, correlation, num_layers) + + for i, kmax in kmax_pairs: + # Get the combined weight tensor from source + wm = source_dict[f"products.{i}.symmetric_contractions.weight"] + + # Get split sizes based on target dimensions + splits = [] + for k in range(kmax + 1): + for suffix in ["_max", ".0", ".1"]: + key = f"products.{i}.symmetric_contractions.contractions.{k}.weights{suffix}" + target_shape = target_dict[key].shape + splits.append(target_shape[1]) + + # Split the weights using the calculated sizes + weights_split = torch.split(wm, splits, dim=1) + + # Assign back to target dictionary + idx = 0 + for k in range(kmax + 1): + target_dict[ + f"products.{i}.symmetric_contractions.contractions.{k}.weights_max" + ] = weights_split[idx] + target_dict[ + f"products.{i}.symmetric_contractions.contractions.{k}.weights.0" + ] = weights_split[idx + 1] + target_dict[ + f"products.{i}.symmetric_contractions.contractions.{k}.weights.1" + ] = weights_split[idx + 2] + idx += 3 + + +def transfer_weights( + source_model: torch.nn.Module, + target_model: torch.nn.Module, + max_L: int, + correlation: int, + num_layers: int, +): + """Transfer weights from CuEq to E3nn format""" + # Get state dicts + source_dict = source_model.state_dict() + target_dict = target_model.state_dict() + + # Transfer main weights + transfer_keys = get_transfer_keys(num_layers) + for key in transfer_keys: + if key in source_dict: # Check if key exists + target_dict[key] = source_dict[key] + else: + logging.warning(f"Key {key} not found in source model") + + # Transfer symmetric contractions + transfer_symmetric_contractions( + source_dict, target_dict, max_L, correlation, num_layers + ) + + # Unsqueeze linear and skip_tp layers + for key in source_dict.keys(): + if any(x in key for x in ["linear", "skip_tp"]) and "weight" in key: + target_dict[key] = target_dict[key].squeeze(0) + + # Transfer remaining matching keys + transferred_keys = set(transfer_keys) + remaining_keys = ( + set(source_dict.keys()) & set(target_dict.keys()) - transferred_keys + ) + remaining_keys = {k for k in remaining_keys if "symmetric_contraction" not in k} + + if remaining_keys: + for key in remaining_keys: + if source_dict[key].shape == target_dict[key].shape: + logging.debug(f"Transferring additional key: {key}") + target_dict[key] = source_dict[key] + else: + logging.warning( + f"Shape mismatch for key {key}: " + f"source {source_dict[key].shape} vs target {target_dict[key].shape}" + ) + + # Transfer avg_num_neighbors + for i in range(2): + target_model.interactions[i].avg_num_neighbors = source_model.interactions[ + i + ].avg_num_neighbors + + # Load state dict into target model + target_model.load_state_dict(target_dict) + + +def run(input_model, output_model="_e3nn.model", device="cpu", return_model=True): + + # Load CuEq model + if isinstance(input_model, str): + source_model = torch.load(input_model, map_location=device) + else: + source_model = input_model + default_dtype = next(source_model.parameters()).dtype + torch.set_default_dtype(default_dtype) + # Extract configuration + config = extract_config_mace_model(source_model) + + # Get max_L and correlation from config + max_L = config["hidden_irreps"].lmax + correlation = config["correlation"] + + # Remove CuEq config + config.pop("cueq_config", None) + + # Create new model without CuEq config + logging.info("Creating new model without CuEq settings") + target_model = source_model.__class__(**config) + + # Transfer weights with proper remapping + num_layers = config["num_interactions"] + transfer_weights(source_model, target_model, max_L, correlation, num_layers) + + if return_model: + return target_model + + # Save model + if isinstance(input_model, str): + base = os.path.splitext(input_model)[0] + output_model = f"{base}.{output_model}" + logging.warning(f"Saving E3nn model to {output_model}") + torch.save(target_model, output_model) + return None + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("input_model", help="Path to input CuEq model") + parser.add_argument( + "--output_model", help="Path to output E3nn model", default="e3nn_model.pt" + ) + parser.add_argument("--device", default="cpu", help="Device to use") + parser.add_argument( + "--return_model", + action="store_false", + help="Return model instead of saving to file", + ) + args = parser.parse_args() + + run( + input_model=args.input_model, + output_model=args.output_model, + device=args.device, + return_model=args.return_model, + ) + + +if __name__ == "__main__": + main() diff --git a/mace-bench/3rdparty/mace/mace/cli/convert_device.py b/mace-bench/3rdparty/mace/mace/cli/convert_device.py index 3366bfd..69735b7 100644 --- a/mace-bench/3rdparty/mace/mace/cli/convert_device.py +++ b/mace-bench/3rdparty/mace/mace/cli/convert_device.py @@ -1,31 +1,31 @@ -from argparse import ArgumentParser - -import torch - - -def main(): - parser = ArgumentParser() - parser.add_argument( - "--target_device", - "-t", - help="device to convert to, usually 'cpu' or 'cuda'", - default="cpu", - ) - parser.add_argument( - "--output_file", - "-o", - help="name for output model, defaults to model_file.target_device", - ) - parser.add_argument("model_file", help="input model file path") - args = parser.parse_args() - - if args.output_file is None: - args.output_file = args.model_file + "." + args.target_device - - model = torch.load(args.model_file, weights_only=False) - model.to(args.target_device) - torch.save(model, args.output_file) - - -if __name__ == "__main__": - main() +from argparse import ArgumentParser + +import torch + + +def main(): + parser = ArgumentParser() + parser.add_argument( + "--target_device", + "-t", + help="device to convert to, usually 'cpu' or 'cuda'", + default="cpu", + ) + parser.add_argument( + "--output_file", + "-o", + help="name for output model, defaults to model_file.target_device", + ) + parser.add_argument("model_file", help="input model file path") + args = parser.parse_args() + + if args.output_file is None: + args.output_file = args.model_file + "." + args.target_device + + model = torch.load(args.model_file, weights_only=False) + model.to(args.target_device) + torch.save(model, args.output_file) + + +if __name__ == "__main__": + main() diff --git a/mace-bench/3rdparty/mace/mace/cli/convert_e3nn_cueq.py b/mace-bench/3rdparty/mace/mace/cli/convert_e3nn_cueq.py index 5e82233..299291f 100644 --- a/mace-bench/3rdparty/mace/mace/cli/convert_e3nn_cueq.py +++ b/mace-bench/3rdparty/mace/mace/cli/convert_e3nn_cueq.py @@ -1,204 +1,204 @@ -import argparse -import logging -import os -from typing import Dict, List, Tuple - -import torch - -from mace.modules.wrapper_ops import CuEquivarianceConfig -from mace.tools.scripts_utils import extract_config_mace_model - - -def get_transfer_keys(num_layers: int) -> List[str]: - """Get list of keys that need to be transferred""" - return [ - "node_embedding.linear.weight", - "radial_embedding.bessel_fn.bessel_weights", - "atomic_energies_fn.atomic_energies", - "readouts.0.linear.weight", - *[f"readouts.{j}.linear.weight" for j in range(num_layers - 1)], - "scale_shift.scale", - "scale_shift.shift", - *[f"readouts.{num_layers-1}.linear_{i}.weight" for i in range(1, 3)], - ] + [ - s - for j in range(num_layers) - for s in [ - f"interactions.{j}.linear_up.weight", - *[f"interactions.{j}.conv_tp_weights.layer{i}.weight" for i in range(4)], - f"interactions.{j}.linear.weight", - f"interactions.{j}.skip_tp.weight", - f"products.{j}.linear.weight", - ] - ] - - -def get_kmax_pairs( - max_L: int, correlation: int, num_layers: int -) -> List[Tuple[int, int]]: - """Determine kmax pairs based on max_L and correlation""" - if correlation == 2: - raise NotImplementedError("Correlation 2 not supported yet") - if correlation == 3: - kmax_pairs = [[i, max_L] for i in range(num_layers - 1)] - kmax_pairs = kmax_pairs + [[num_layers - 1, 0]] - return kmax_pairs - raise NotImplementedError(f"Correlation {correlation} not supported") - - -def transfer_symmetric_contractions( - source_dict: Dict[str, torch.Tensor], - target_dict: Dict[str, torch.Tensor], - max_L: int, - correlation: int, - num_layers: int, -): - """Transfer symmetric contraction weights""" - kmax_pairs = get_kmax_pairs(max_L, correlation, num_layers) - - for i, kmax in kmax_pairs: - wm = torch.concatenate( - [ - source_dict[ - f"products.{i}.symmetric_contractions.contractions.{k}.weights{j}" - ] - for k in range(kmax + 1) - for j in ["_max", ".0", ".1"] - ], - dim=1, - ) - target_dict[f"products.{i}.symmetric_contractions.weight"] = wm - - -def transfer_weights( - source_model: torch.nn.Module, - target_model: torch.nn.Module, - max_L: int, - correlation: int, - num_layers: int, -): - """Transfer weights with proper remapping""" - # Get source state dict - source_dict = source_model.state_dict() - target_dict = target_model.state_dict() - - # Transfer main weights - transfer_keys = get_transfer_keys(num_layers) - for key in transfer_keys: - if key in source_dict: # Check if key exists - target_dict[key] = source_dict[key] - else: - logging.warning(f"Key {key} not found in source model") - - # Transfer symmetric contractions - transfer_symmetric_contractions( - source_dict, target_dict, max_L, correlation, num_layers - ) - - # Unsqueeze linear and skip_tp layers - for key in source_dict.keys(): - if any(x in key for x in ["linear", "skip_tp"]) and "weight" in key: - target_dict[key] = target_dict[key].unsqueeze(0) - - transferred_keys = set(transfer_keys) - remaining_keys = ( - set(source_dict.keys()) & set(target_dict.keys()) - transferred_keys - ) - remaining_keys = {k for k in remaining_keys if "symmetric_contraction" not in k} - if remaining_keys: - for key in remaining_keys: - if source_dict[key].shape == target_dict[key].shape: - logging.debug(f"Transferring additional key: {key}") - target_dict[key] = source_dict[key] - else: - logging.warning( - f"Shape mismatch for key {key}: " - f"source {source_dict[key].shape} vs target {target_dict[key].shape}" - ) - # Transfer avg_num_neighbors - for i in range(2): - target_model.interactions[i].avg_num_neighbors = source_model.interactions[ - i - ].avg_num_neighbors - - # Load state dict into target model - target_model.load_state_dict(target_dict) - - -def run( - input_model, - output_model="_cueq.model", - device="cpu", - return_model=True, -): - # Setup logging - - # Load original model - # logging.warning(f"Loading model") - # check if input_model is a path or a model - if isinstance(input_model, str): - source_model = torch.load(input_model, map_location=device) - else: - source_model = input_model - default_dtype = next(source_model.parameters()).dtype - torch.set_default_dtype(default_dtype) - # Extract configuration - config = extract_config_mace_model(source_model) - - # Get max_L and correlation from config - max_L = config["hidden_irreps"].lmax - correlation = config["correlation"] - - # Add cuequivariance config - config["cueq_config"] = CuEquivarianceConfig( - enabled=True, - layout="ir_mul", - group="O3_e3nn", - optimize_all=True, - ) - - # Create new model with cuequivariance config - logging.info("Creating new model with cuequivariance settings") - target_model = source_model.__class__(**config).to(device) - - # Transfer weights with proper remapping - num_layers = config["num_interactions"] - transfer_weights(source_model, target_model, max_L, correlation, num_layers) - - if return_model: - return target_model - - if isinstance(input_model, str): - base = os.path.splitext(input_model)[0] - output_model = f"{base}.{output_model}" - logging.warning(f"Saving CuEq model to {output_model}") - torch.save(target_model, output_model) - return None - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("input_model", help="Path to input MACE model") - parser.add_argument( - "--output_model", - help="Path to output cuequivariance model", - default="cueq_model.pt", - ) - parser.add_argument("--device", default="cpu", help="Device to use") - parser.add_argument( - "--return_model", - action="store_false", - help="Return model instead of saving to file", - ) - args = parser.parse_args() - - run( - input_model=args.input_model, - output_model=args.output_model, - device=args.device, - return_model=args.return_model, - ) - - -if __name__ == "__main__": - main() +import argparse +import logging +import os +from typing import Dict, List, Tuple + +import torch + +from mace.modules.wrapper_ops import CuEquivarianceConfig +from mace.tools.scripts_utils import extract_config_mace_model + + +def get_transfer_keys(num_layers: int) -> List[str]: + """Get list of keys that need to be transferred""" + return [ + "node_embedding.linear.weight", + "radial_embedding.bessel_fn.bessel_weights", + "atomic_energies_fn.atomic_energies", + "readouts.0.linear.weight", + *[f"readouts.{j}.linear.weight" for j in range(num_layers - 1)], + "scale_shift.scale", + "scale_shift.shift", + *[f"readouts.{num_layers-1}.linear_{i}.weight" for i in range(1, 3)], + ] + [ + s + for j in range(num_layers) + for s in [ + f"interactions.{j}.linear_up.weight", + *[f"interactions.{j}.conv_tp_weights.layer{i}.weight" for i in range(4)], + f"interactions.{j}.linear.weight", + f"interactions.{j}.skip_tp.weight", + f"products.{j}.linear.weight", + ] + ] + + +def get_kmax_pairs( + max_L: int, correlation: int, num_layers: int +) -> List[Tuple[int, int]]: + """Determine kmax pairs based on max_L and correlation""" + if correlation == 2: + raise NotImplementedError("Correlation 2 not supported yet") + if correlation == 3: + kmax_pairs = [[i, max_L] for i in range(num_layers - 1)] + kmax_pairs = kmax_pairs + [[num_layers - 1, 0]] + return kmax_pairs + raise NotImplementedError(f"Correlation {correlation} not supported") + + +def transfer_symmetric_contractions( + source_dict: Dict[str, torch.Tensor], + target_dict: Dict[str, torch.Tensor], + max_L: int, + correlation: int, + num_layers: int, +): + """Transfer symmetric contraction weights""" + kmax_pairs = get_kmax_pairs(max_L, correlation, num_layers) + + for i, kmax in kmax_pairs: + wm = torch.concatenate( + [ + source_dict[ + f"products.{i}.symmetric_contractions.contractions.{k}.weights{j}" + ] + for k in range(kmax + 1) + for j in ["_max", ".0", ".1"] + ], + dim=1, + ) + target_dict[f"products.{i}.symmetric_contractions.weight"] = wm + + +def transfer_weights( + source_model: torch.nn.Module, + target_model: torch.nn.Module, + max_L: int, + correlation: int, + num_layers: int, +): + """Transfer weights with proper remapping""" + # Get source state dict + source_dict = source_model.state_dict() + target_dict = target_model.state_dict() + + # Transfer main weights + transfer_keys = get_transfer_keys(num_layers) + for key in transfer_keys: + if key in source_dict: # Check if key exists + target_dict[key] = source_dict[key] + else: + logging.warning(f"Key {key} not found in source model") + + # Transfer symmetric contractions + transfer_symmetric_contractions( + source_dict, target_dict, max_L, correlation, num_layers + ) + + # Unsqueeze linear and skip_tp layers + for key in source_dict.keys(): + if any(x in key for x in ["linear", "skip_tp"]) and "weight" in key: + target_dict[key] = target_dict[key].unsqueeze(0) + + transferred_keys = set(transfer_keys) + remaining_keys = ( + set(source_dict.keys()) & set(target_dict.keys()) - transferred_keys + ) + remaining_keys = {k for k in remaining_keys if "symmetric_contraction" not in k} + if remaining_keys: + for key in remaining_keys: + if source_dict[key].shape == target_dict[key].shape: + logging.debug(f"Transferring additional key: {key}") + target_dict[key] = source_dict[key] + else: + logging.warning( + f"Shape mismatch for key {key}: " + f"source {source_dict[key].shape} vs target {target_dict[key].shape}" + ) + # Transfer avg_num_neighbors + for i in range(2): + target_model.interactions[i].avg_num_neighbors = source_model.interactions[ + i + ].avg_num_neighbors + + # Load state dict into target model + target_model.load_state_dict(target_dict) + + +def run( + input_model, + output_model="_cueq.model", + device="cpu", + return_model=True, +): + # Setup logging + + # Load original model + # logging.warning(f"Loading model") + # check if input_model is a path or a model + if isinstance(input_model, str): + source_model = torch.load(input_model, map_location=device) + else: + source_model = input_model + default_dtype = next(source_model.parameters()).dtype + torch.set_default_dtype(default_dtype) + # Extract configuration + config = extract_config_mace_model(source_model) + + # Get max_L and correlation from config + max_L = config["hidden_irreps"].lmax + correlation = config["correlation"] + + # Add cuequivariance config + config["cueq_config"] = CuEquivarianceConfig( + enabled=True, + layout="ir_mul", + group="O3_e3nn", + optimize_all=True, + ) + + # Create new model with cuequivariance config + logging.info("Creating new model with cuequivariance settings") + target_model = source_model.__class__(**config).to(device) + + # Transfer weights with proper remapping + num_layers = config["num_interactions"] + transfer_weights(source_model, target_model, max_L, correlation, num_layers) + + if return_model: + return target_model + + if isinstance(input_model, str): + base = os.path.splitext(input_model)[0] + output_model = f"{base}.{output_model}" + logging.warning(f"Saving CuEq model to {output_model}") + torch.save(target_model, output_model) + return None + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("input_model", help="Path to input MACE model") + parser.add_argument( + "--output_model", + help="Path to output cuequivariance model", + default="cueq_model.pt", + ) + parser.add_argument("--device", default="cpu", help="Device to use") + parser.add_argument( + "--return_model", + action="store_false", + help="Return model instead of saving to file", + ) + args = parser.parse_args() + + run( + input_model=args.input_model, + output_model=args.output_model, + device=args.device, + return_model=args.return_model, + ) + + +if __name__ == "__main__": + main() diff --git a/mace-bench/3rdparty/mace/mace/cli/create_lammps_model.py b/mace-bench/3rdparty/mace/mace/cli/create_lammps_model.py index f4ac867..7af81f2 100644 --- a/mace-bench/3rdparty/mace/mace/cli/create_lammps_model.py +++ b/mace-bench/3rdparty/mace/mace/cli/create_lammps_model.py @@ -1,114 +1,114 @@ -# pylint: disable=wrong-import-position -import argparse -import copy -import os - -os.environ["TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD"] = "1" - -import torch -from e3nn.util import jit - -from mace.calculators import LAMMPS_MACE -from mace.calculators.lammps_mliap_mace import LAMMPS_MLIAP_MACE -from mace.cli.convert_e3nn_cueq import run as run_e3nn_to_cueq - - -def parse_args(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - parser.add_argument( - "model_path", - type=str, - help="Path to the model to be converted to LAMMPS", - ) - parser.add_argument( - "--head", - type=str, - nargs="?", - help="Head of the model to be converted to LAMMPS", - default=None, - ) - parser.add_argument( - "--dtype", - type=str, - nargs="?", - help="Data type of the model to be converted to LAMMPS", - default="float64", - ) - parser.add_argument( - "--format", - type=str, - help="Old libtorch format, or new mliap format", - default="libtorch", - ) - return parser.parse_args() - - -def select_head(model): - if hasattr(model, "heads"): - heads = model.heads - else: - heads = [None] - - if len(heads) == 1: - print(f"Only one head found in the model: {heads[0]}. Skipping selection.") - return heads[0] - - print("Available heads in the model:") - for i, head in enumerate(heads): - print(f"{i + 1}: {head}") - - # Ask the user to select a head - selected = input( - f"Select a head by number (Defaulting to head: {len(heads)}, press Enter to accept): " - ) - - if selected.isdigit() and 1 <= int(selected) <= len(heads): - return heads[int(selected) - 1] - if selected == "": - print("No head selected. Proceeding without specifying a head.") - return None - print(f"No valid selection made. Defaulting to the last head: {heads[-1]}") - return heads[-1] - - -def main(): - args = parse_args() - model_path = args.model_path # takes model name as command-line input - model = torch.load( - model_path, - map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu"), - ) - if args.dtype == "float64": - model = model.double().to("cpu") - elif args.dtype == "float32": - print("Converting model to float32, this may cause loss of precision.") - model = model.float().to("cpu") - - if args.format == "mliap": - # Enabling cuequivariance by default. TODO: switch? - model = run_e3nn_to_cueq(copy.deepcopy(model)) - model.lammps_mliap = True - - if args.head is None: - head = select_head(model) - else: - head = args.head - print( - f"Selected head: {head} from command line in the list available heads: {model.heads}" - ) - - lammps_class = LAMMPS_MLIAP_MACE if args.format == "mliap" else LAMMPS_MACE - lammps_model = ( - lammps_class(model, head=head) if head is not None else lammps_class(model) - ) - if args.format == "mliap": - torch.save(lammps_model, model_path + "-mliap_lammps.pt") - else: - lammps_model_compiled = jit.compile(lammps_model) - lammps_model_compiled.save(model_path + "-lammps.pt") - - -if __name__ == "__main__": - main() +# pylint: disable=wrong-import-position +import argparse +import copy +import os + +os.environ["TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD"] = "1" + +import torch +from e3nn.util import jit + +from mace.calculators import LAMMPS_MACE +from mace.calculators.lammps_mliap_mace import LAMMPS_MLIAP_MACE +from mace.cli.convert_e3nn_cueq import run as run_e3nn_to_cueq + + +def parse_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "model_path", + type=str, + help="Path to the model to be converted to LAMMPS", + ) + parser.add_argument( + "--head", + type=str, + nargs="?", + help="Head of the model to be converted to LAMMPS", + default=None, + ) + parser.add_argument( + "--dtype", + type=str, + nargs="?", + help="Data type of the model to be converted to LAMMPS", + default="float64", + ) + parser.add_argument( + "--format", + type=str, + help="Old libtorch format, or new mliap format", + default="libtorch", + ) + return parser.parse_args() + + +def select_head(model): + if hasattr(model, "heads"): + heads = model.heads + else: + heads = [None] + + if len(heads) == 1: + print(f"Only one head found in the model: {heads[0]}. Skipping selection.") + return heads[0] + + print("Available heads in the model:") + for i, head in enumerate(heads): + print(f"{i + 1}: {head}") + + # Ask the user to select a head + selected = input( + f"Select a head by number (Defaulting to head: {len(heads)}, press Enter to accept): " + ) + + if selected.isdigit() and 1 <= int(selected) <= len(heads): + return heads[int(selected) - 1] + if selected == "": + print("No head selected. Proceeding without specifying a head.") + return None + print(f"No valid selection made. Defaulting to the last head: {heads[-1]}") + return heads[-1] + + +def main(): + args = parse_args() + model_path = args.model_path # takes model name as command-line input + model = torch.load( + model_path, + map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu"), + ) + if args.dtype == "float64": + model = model.double().to("cpu") + elif args.dtype == "float32": + print("Converting model to float32, this may cause loss of precision.") + model = model.float().to("cpu") + + if args.format == "mliap": + # Enabling cuequivariance by default. TODO: switch? + model = run_e3nn_to_cueq(copy.deepcopy(model)) + model.lammps_mliap = True + + if args.head is None: + head = select_head(model) + else: + head = args.head + print( + f"Selected head: {head} from command line in the list available heads: {model.heads}" + ) + + lammps_class = LAMMPS_MLIAP_MACE if args.format == "mliap" else LAMMPS_MACE + lammps_model = ( + lammps_class(model, head=head) if head is not None else lammps_class(model) + ) + if args.format == "mliap": + torch.save(lammps_model, model_path + "-mliap_lammps.pt") + else: + lammps_model_compiled = jit.compile(lammps_model) + lammps_model_compiled.save(model_path + "-lammps.pt") + + +if __name__ == "__main__": + main() diff --git a/mace-bench/3rdparty/mace/mace/cli/eval_configs.py b/mace-bench/3rdparty/mace/mace/cli/eval_configs.py index d4ec3c7..d00c54c 100644 --- a/mace-bench/3rdparty/mace/mace/cli/eval_configs.py +++ b/mace-bench/3rdparty/mace/mace/cli/eval_configs.py @@ -1,165 +1,165 @@ -########################################################################################### -# Script for evaluating configurations contained in an xyz file with a trained model -# Authors: Ilyes Batatia, Gregor Simm -# This program is distributed under the MIT License (see MIT.md) -########################################################################################### - -import argparse - -import ase.data -import ase.io -import numpy as np -import torch - -from mace import data -from mace.tools import torch_geometric, torch_tools, utils - - -def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - parser.add_argument("--configs", help="path to XYZ configurations", required=True) - parser.add_argument("--model", help="path to model", required=True) - parser.add_argument("--output", help="output path", required=True) - parser.add_argument( - "--device", - help="select device", - type=str, - choices=["cpu", "cuda"], - default="cpu", - ) - parser.add_argument( - "--default_dtype", - help="set default dtype", - type=str, - choices=["float32", "float64"], - default="float64", - ) - parser.add_argument("--batch_size", help="batch size", type=int, default=64) - parser.add_argument( - "--compute_stress", - help="compute stress", - action="store_true", - default=False, - ) - parser.add_argument( - "--return_contributions", - help="model outputs energy contributions for each body order, only supported for MACE, not ScaleShiftMACE", - action="store_true", - default=False, - ) - parser.add_argument( - "--info_prefix", - help="prefix for energy, forces and stress keys", - type=str, - default="MACE_", - ) - parser.add_argument( - "--head", - help="Model head used for evaluation", - type=str, - required=False, - default=None, - ) - return parser.parse_args() - - -def main() -> None: - args = parse_args() - run(args) - - -def run(args: argparse.Namespace) -> None: - torch_tools.set_default_dtype(args.default_dtype) - device = torch_tools.init_device(args.device) - - # Load model - model = torch.load(f=args.model, map_location=args.device) - model = model.to( - args.device - ) # shouldn't be necessary but seems to help with CUDA problems - - for param in model.parameters(): - param.requires_grad = False - - # Load data and prepare input - atoms_list = ase.io.read(args.configs, index=":") - if args.head is not None: - for atoms in atoms_list: - atoms.info["head"] = args.head - configs = [data.config_from_atoms(atoms) for atoms in atoms_list] - - z_table = utils.AtomicNumberTable([int(z) for z in model.atomic_numbers]) - - try: - heads = model.heads - except AttributeError: - heads = None - - data_loader = torch_geometric.dataloader.DataLoader( - dataset=[ - data.AtomicData.from_config( - config, z_table=z_table, cutoff=float(model.r_max), heads=heads - ) - for config in configs - ], - batch_size=args.batch_size, - shuffle=False, - drop_last=False, - ) - - # Collect data - energies_list = [] - contributions_list = [] - stresses_list = [] - forces_collection = [] - - for batch in data_loader: - batch = batch.to(device) - output = model(batch.to_dict(), compute_stress=args.compute_stress) - energies_list.append(torch_tools.to_numpy(output["energy"])) - if args.compute_stress: - stresses_list.append(torch_tools.to_numpy(output["stress"])) - - if args.return_contributions: - contributions_list.append(torch_tools.to_numpy(output["contributions"])) - - forces = np.split( - torch_tools.to_numpy(output["forces"]), - indices_or_sections=batch.ptr[1:], - axis=0, - ) - forces_collection.append(forces[:-1]) # drop last as its empty - - energies = np.concatenate(energies_list, axis=0) - forces_list = [ - forces for forces_list in forces_collection for forces in forces_list - ] - assert len(atoms_list) == len(energies) == len(forces_list) - if args.compute_stress: - stresses = np.concatenate(stresses_list, axis=0) - assert len(atoms_list) == stresses.shape[0] - - if args.return_contributions: - contributions = np.concatenate(contributions_list, axis=0) - assert len(atoms_list) == contributions.shape[0] - - # Store data in atoms objects - for i, (atoms, energy, forces) in enumerate(zip(atoms_list, energies, forces_list)): - atoms.calc = None # crucial - atoms.info[args.info_prefix + "energy"] = energy - atoms.arrays[args.info_prefix + "forces"] = forces - - if args.compute_stress: - atoms.info[args.info_prefix + "stress"] = stresses[i] - - if args.return_contributions: - atoms.info[args.info_prefix + "BO_contributions"] = contributions[i] - - # Write atoms to output path - ase.io.write(args.output, images=atoms_list, format="extxyz") - - -if __name__ == "__main__": - main() +########################################################################################### +# Script for evaluating configurations contained in an xyz file with a trained model +# Authors: Ilyes Batatia, Gregor Simm +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +import argparse + +import ase.data +import ase.io +import numpy as np +import torch + +from mace import data +from mace.tools import torch_geometric, torch_tools, utils + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("--configs", help="path to XYZ configurations", required=True) + parser.add_argument("--model", help="path to model", required=True) + parser.add_argument("--output", help="output path", required=True) + parser.add_argument( + "--device", + help="select device", + type=str, + choices=["cpu", "cuda"], + default="cpu", + ) + parser.add_argument( + "--default_dtype", + help="set default dtype", + type=str, + choices=["float32", "float64"], + default="float64", + ) + parser.add_argument("--batch_size", help="batch size", type=int, default=64) + parser.add_argument( + "--compute_stress", + help="compute stress", + action="store_true", + default=False, + ) + parser.add_argument( + "--return_contributions", + help="model outputs energy contributions for each body order, only supported for MACE, not ScaleShiftMACE", + action="store_true", + default=False, + ) + parser.add_argument( + "--info_prefix", + help="prefix for energy, forces and stress keys", + type=str, + default="MACE_", + ) + parser.add_argument( + "--head", + help="Model head used for evaluation", + type=str, + required=False, + default=None, + ) + return parser.parse_args() + + +def main() -> None: + args = parse_args() + run(args) + + +def run(args: argparse.Namespace) -> None: + torch_tools.set_default_dtype(args.default_dtype) + device = torch_tools.init_device(args.device) + + # Load model + model = torch.load(f=args.model, map_location=args.device) + model = model.to( + args.device + ) # shouldn't be necessary but seems to help with CUDA problems + + for param in model.parameters(): + param.requires_grad = False + + # Load data and prepare input + atoms_list = ase.io.read(args.configs, index=":") + if args.head is not None: + for atoms in atoms_list: + atoms.info["head"] = args.head + configs = [data.config_from_atoms(atoms) for atoms in atoms_list] + + z_table = utils.AtomicNumberTable([int(z) for z in model.atomic_numbers]) + + try: + heads = model.heads + except AttributeError: + heads = None + + data_loader = torch_geometric.dataloader.DataLoader( + dataset=[ + data.AtomicData.from_config( + config, z_table=z_table, cutoff=float(model.r_max), heads=heads + ) + for config in configs + ], + batch_size=args.batch_size, + shuffle=False, + drop_last=False, + ) + + # Collect data + energies_list = [] + contributions_list = [] + stresses_list = [] + forces_collection = [] + + for batch in data_loader: + batch = batch.to(device) + output = model(batch.to_dict(), compute_stress=args.compute_stress) + energies_list.append(torch_tools.to_numpy(output["energy"])) + if args.compute_stress: + stresses_list.append(torch_tools.to_numpy(output["stress"])) + + if args.return_contributions: + contributions_list.append(torch_tools.to_numpy(output["contributions"])) + + forces = np.split( + torch_tools.to_numpy(output["forces"]), + indices_or_sections=batch.ptr[1:], + axis=0, + ) + forces_collection.append(forces[:-1]) # drop last as its empty + + energies = np.concatenate(energies_list, axis=0) + forces_list = [ + forces for forces_list in forces_collection for forces in forces_list + ] + assert len(atoms_list) == len(energies) == len(forces_list) + if args.compute_stress: + stresses = np.concatenate(stresses_list, axis=0) + assert len(atoms_list) == stresses.shape[0] + + if args.return_contributions: + contributions = np.concatenate(contributions_list, axis=0) + assert len(atoms_list) == contributions.shape[0] + + # Store data in atoms objects + for i, (atoms, energy, forces) in enumerate(zip(atoms_list, energies, forces_list)): + atoms.calc = None # crucial + atoms.info[args.info_prefix + "energy"] = energy + atoms.arrays[args.info_prefix + "forces"] = forces + + if args.compute_stress: + atoms.info[args.info_prefix + "stress"] = stresses[i] + + if args.return_contributions: + atoms.info[args.info_prefix + "BO_contributions"] = contributions[i] + + # Write atoms to output path + ase.io.write(args.output, images=atoms_list, format="extxyz") + + +if __name__ == "__main__": + main() diff --git a/mace-bench/3rdparty/mace/mace/cli/fine_tuning_select.py b/mace-bench/3rdparty/mace/mace/cli/fine_tuning_select.py index 7dcfaba..59a0fb0 100644 --- a/mace-bench/3rdparty/mace/mace/cli/fine_tuning_select.py +++ b/mace-bench/3rdparty/mace/mace/cli/fine_tuning_select.py @@ -1,494 +1,494 @@ -########################################################################################### -# This program is distributed under the MIT License (see MIT.md) -########################################################################################### -from __future__ import annotations - -import argparse -import logging -from dataclasses import dataclass -from enum import Enum -from typing import List, Tuple, Union - -import ase.data -import ase.io -import numpy as np -import torch - -from mace.calculators import MACECalculator, mace_mp - -try: - import fpsample # type: ignore -except ImportError: - pass - - -class FilteringType(Enum): - NONE = "none" - COMBINATIONS = "combinations" - EXCLUSIVE = "exclusive" - INCLUSIVE = "inclusive" - - -class SubselectType(Enum): - FPS = "fps" - RANDOM = "random" - - -@dataclass -class SelectionSettings: - configs_pt: str - output: str - configs_ft: str | None = None - atomic_numbers: List[int] | None = None - num_samples: int | None = None - subselect: SubselectType = SubselectType.FPS - model: str = "small" - descriptors: str | None = None - device: str = "cpu" - default_dtype: str = "float64" - head_pt: str | None = None - head_ft: str | None = None - filtering_type: FilteringType = FilteringType.COMBINATIONS - weight_ft: float = 1.0 - weight_pt: float = 1.0 - seed: int = 42 - - -def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - parser.add_argument( - "--configs_pt", - help="path to XYZ configurations for the pretraining", - required=True, - ) - parser.add_argument( - "--configs_ft", - help="path or list of paths to XYZ configurations for the finetuning", - required=False, - default=None, - ) - parser.add_argument( - "--num_samples", - help="number of samples to select for the pretraining", - type=int, - required=False, - default=None, - ) - parser.add_argument( - "--subselect", - help="method to subselect the configurations of the pretraining set", - type=SubselectType, - choices=list(SubselectType), - default=SubselectType.FPS, - ) - parser.add_argument( - "--model", help="path to model", default="small", required=False - ) - parser.add_argument("--output", help="output path", required=True) - parser.add_argument( - "--descriptors", help="path to descriptors", required=False, default=None - ) - parser.add_argument( - "--device", - help="select device", - type=str, - choices=["cpu", "cuda"], - default="cpu", - ) - parser.add_argument( - "--default_dtype", - help="set default dtype", - type=str, - choices=["float32", "float64"], - default="float64", - ) - parser.add_argument( - "--head_pt", - help="level of head for the pretraining set", - type=str, - default=None, - ) - parser.add_argument( - "--head_ft", - help="level of head for the finetuning set", - type=str, - default=None, - ) - parser.add_argument( - "--filtering_type", - help="filtering type", - type=FilteringType, - choices=list(FilteringType), - default=FilteringType.NONE, - ) - parser.add_argument( - "--weight_ft", - help="weight for the finetuning set", - type=float, - default=1.0, - ) - parser.add_argument( - "--weight_pt", - help="weight for the pretraining set", - type=float, - default=1.0, - ) - parser.add_argument("--seed", help="random seed", type=int, default=42) - return parser.parse_args() - - -def calculate_descriptors(atoms: List[ase.Atoms], calc: MACECalculator) -> None: - logging.info("Calculating descriptors") - for mol in atoms: - descriptors = calc.get_descriptors(mol.copy(), invariants_only=True) - # average descriptors over atoms for each element - descriptors_dict = { - element: np.mean(descriptors[mol.symbols == element], axis=0) - for element in np.unique(mol.symbols) - } - mol.info["mace_descriptors"] = descriptors_dict - - -def filter_atoms( - atoms: ase.Atoms, - element_subset: List[str], - filtering_type: FilteringType = FilteringType.COMBINATIONS, -) -> bool: - """ - Filters atoms based on the provided filtering type and element subset. - - Parameters: - atoms (ase.Atoms): The atoms object to filter. - element_subset (list): The list of elements to consider during filtering. - filtering_type (FilteringType): The type of filtering to apply. - Can be one of the following `FilteringType` enum members: - - `FilteringType.NONE`: No filtering is applied. - - `FilteringType.COMBINATIONS`: Return true if `atoms` is composed of combinations of elements in the subset, false otherwise. I.e. does not require all of the specified elements to be present. - - `FilteringType.EXCLUSIVE`: Return true if `atoms` contains *only* elements in the subset, false otherwise. - - `FilteringType.INCLUSIVE`: Return true if `atoms` contains all elements in the subset, false otherwise. I.e. allows additional elements. - - Returns: - bool: True if the atoms pass the filter, False otherwise. - """ - if filtering_type == FilteringType.NONE: - return True - if filtering_type == FilteringType.COMBINATIONS: - atom_symbols = np.unique(atoms.symbols) - return all( - x in element_subset for x in atom_symbols - ) # atoms must *only* contain elements in the subset - if filtering_type == FilteringType.EXCLUSIVE: - atom_symbols = set(list(atoms.symbols)) - return atom_symbols == set(element_subset) - if filtering_type == FilteringType.INCLUSIVE: - atom_symbols = np.unique(atoms.symbols) - return all( - x in atom_symbols for x in element_subset - ) # atoms must *at least* contain elements in the subset - raise ValueError( - f"Filtering type {filtering_type} not recognised. Must be one of {list(FilteringType)}." - ) - - -class FPS: - def __init__(self, atoms_list: List[ase.Atoms], n_samples: int): - self.n_samples = n_samples - self.atoms_list = atoms_list - self.species = np.unique([x.symbol for atoms in atoms_list for x in atoms]) # type: ignore - self.species_dict = {x: i for i, x in enumerate(self.species)} - # start from a random configuration - self.list_index = [np.random.randint(0, len(atoms_list))] - self.assemble_descriptors() - - def run( - self, - ) -> List[int]: - """ - Run the farthest point sampling algorithm. - """ - descriptor_dataset_reshaped = ( - self.descriptors_dataset.reshape( # pylint: disable=E1121 - (len(self.atoms_list), -1) - ) - ) - logging.info(f"{descriptor_dataset_reshaped.shape}") - logging.info(f"n_samples: {self.n_samples}") - self.list_index = fpsample.fps_npdu_kdtree_sampling( - descriptor_dataset_reshaped, - self.n_samples, - ) - return self.list_index - - def assemble_descriptors(self) -> None: - """ - Assemble the descriptors for all the configurations. - """ - self.descriptors_dataset: np.ndarray = 10e10 * np.ones( - ( - len(self.atoms_list), - len(self.species), - len(list(self.atoms_list[0].info["mace_descriptors"].values())[0]), - ), - dtype=np.float32, - ).astype(np.float32) - - for i, atoms in enumerate(self.atoms_list): - descriptors = atoms.info["mace_descriptors"] - for z in descriptors: - self.descriptors_dataset[i, self.species_dict[z]] = np.array( - descriptors[z] - ).astype(np.float32) - - -def _load_calc( - model: str, device: str, default_dtype: str, subselect: SubselectType -) -> Union[MACECalculator, None]: - if subselect == SubselectType.RANDOM: - return None - if model in ["small", "medium", "large"]: - calc = mace_mp(model, device=device, default_dtype=default_dtype) - else: - calc = MACECalculator( - model_paths=model, - device=device, - default_dtype=default_dtype, - ) - return calc - - -def _get_finetuning_elements( - atoms: List[ase.Atoms], atomic_numbers: List[int] | None -) -> List[str]: - if atoms: - logging.debug( - "Using elements from the finetuning configurations for filtering." - ) - species = np.unique([x.symbol for atoms in atoms for x in atoms]).tolist() # type: ignore - elif atomic_numbers is not None and atomic_numbers: - logging.debug("Using the supplied atomic numbers for filtering.") - species = [ase.data.chemical_symbols[z] for z in atomic_numbers] - else: - species = [] - return species - - -def _read_finetuning_configs( - configs_ft: Union[str, list[str], None], -) -> List[ase.Atoms]: - if isinstance(configs_ft, str): - path = configs_ft - return ase.io.read(path, index=":") # type: ignore - if isinstance(configs_ft, list): - assert all(isinstance(x, str) for x in configs_ft) - atoms_list_ft = [] - for path in configs_ft: - atoms_list_ft += ase.io.read(path, index=":") - return atoms_list_ft - if configs_ft is None: - return [] - raise ValueError(f"Invalid type for configs_ft: {type(configs_ft)}") - - -def _filter_pretraining_data( - atoms: list[ase.Atoms], - filtering_type: FilteringType, - all_species_ft: List[str], -) -> Tuple[List[ase.Atoms], List[ase.Atoms], list[bool]]: - logging.info( - "Filtering configurations based on the finetuning set, " - f"filtering type: {filtering_type}, elements: {all_species_ft}" - ) - passes_filter = [filter_atoms(x, all_species_ft, filtering_type) for x in atoms] - assert len(passes_filter) == len(atoms), "Filtering failed" - filtered_atoms = [x for x, passes in zip(atoms, passes_filter) if passes] - remaining_atoms = [x for x, passes in zip(atoms, passes_filter) if not passes] - return filtered_atoms, remaining_atoms, passes_filter - - -def _get_random_configs( - num_samples: int, - atoms: List[ase.Atoms], -) -> list[ase.Atoms]: - if num_samples > len(atoms): - raise ValueError( - f"Requested more samples ({num_samples}) than available in the remaining set ({len(atoms)})" - ) - indices = np.random.choice(list(range(len(atoms))), num_samples, replace=False) - return [atoms[i] for i in indices] - - -def _load_descriptors( - atoms: List[ase.Atoms], - passes_filter: List[bool], - descriptors_path: str | None, - calc: MACECalculator | None, - full_data_length: int, -) -> None: - if descriptors_path is not None: - logging.info(f"Loading descriptors from {descriptors_path}") - descriptors = np.load(descriptors_path, allow_pickle=True) - assert sum(passes_filter) == len(atoms) - if len(descriptors) != full_data_length: - raise ValueError( - f"Length of the descriptors ({len(descriptors)}) does not match the length of the data ({full_data_length})" - "Please provide descriptors for all configurations" - ) - required_descriptors = [ - descriptors[i] for i, passes in enumerate(passes_filter) if passes - ] - for i, atoms_ in enumerate(atoms): - atoms_.info["mace_descriptors"] = required_descriptors[i] - else: - logging.info("Calculating descriptors") - if calc is None: - raise ValueError("MACECalculator must be provided to calculate descriptors") - calculate_descriptors(atoms, calc) - - -def _maybe_save_descriptors( - atoms: List[ase.Atoms], - output_path: str, -) -> None: - """ - Save the descriptors if they are present in the atoms objects. - Also, delete the descriptors from the atoms objects. - """ - if all("mace_descriptors" in x.info for x in atoms): - descriptor_save_path = output_path.replace(".xyz", "_descriptors.npy") - logging.info(f"Saving descriptors at {descriptor_save_path}") - descriptors_list = [x.info["mace_descriptors"] for x in atoms] - np.save(descriptor_save_path, descriptors_list, allow_pickle=True) - for x in atoms: - del x.info["mace_descriptors"] - - -def _maybe_fps(atoms: List[ase.Atoms], num_samples: int) -> List[ase.Atoms]: - try: - fps_pt = FPS(atoms, num_samples) - idx_pt = fps_pt.run() - logging.info(f"Selected {len(idx_pt)} configurations") - return [atoms[i] for i in idx_pt] - except Exception as e: # pylint: disable=W0703 - logging.error(f"FPS failed, selecting random configurations instead: {e}") - return _get_random_configs(num_samples, atoms) - - -def _subsample_data( - filtered_atoms: List[ase.Atoms], - remaining_atoms: List[ase.Atoms], - passes_filter: List[bool], - num_samples: int | None, - subselect: SubselectType, - descriptors_path: str | None, - calc: MACECalculator | None, -) -> List[ase.Atoms]: - if num_samples is None or num_samples == len(filtered_atoms): - logging.info( - f"No subsampling, keeping all {len(filtered_atoms)} filtered configurations" - ) - return filtered_atoms - if num_samples > len(filtered_atoms): - num_sample_randomly = num_samples - len(filtered_atoms) - logging.info( - f"Number of configurations after filtering {len(filtered_atoms)} " - f"is less than the number of samples {num_samples}, " - f"selecting {num_sample_randomly} random configurations for the rest." - ) - return filtered_atoms + _get_random_configs( - num_sample_randomly, remaining_atoms - ) - if num_samples == 0: - raise ValueError("Number of samples must be greater than 0") - if subselect == SubselectType.FPS: - _load_descriptors( - filtered_atoms, - passes_filter, - descriptors_path, - calc, - full_data_length=len(filtered_atoms) + len(remaining_atoms), - ) - logging.info("Selecting configurations using Farthest Point Sampling") - return _maybe_fps(filtered_atoms, num_samples) - if subselect == SubselectType.RANDOM: - return _get_random_configs(num_samples, filtered_atoms) - raise ValueError(f"Invalid subselect type: {subselect}") - - -def _write_metadata( - atoms: list[ase.Atoms], pretrained: bool, config_weight: float, head: str | None -) -> None: - for a in atoms: - a.info["pretrained"] = pretrained - a.info["config_weight"] = config_weight - if head is not None: - a.info["head"] = head - - -def select_samples( - settings: SelectionSettings, -) -> None: - np.random.seed(settings.seed) - torch.manual_seed(settings.seed) - calc = _load_calc( - settings.model, settings.device, settings.default_dtype, settings.subselect - ) - atoms_list_ft = _read_finetuning_configs(settings.configs_ft) - all_species_ft = _get_finetuning_elements(atoms_list_ft, settings.atomic_numbers) - - if settings.filtering_type is not FilteringType.NONE and not all_species_ft: - raise ValueError( - "Filtering types other than NONE require elements for filtering. They can be specified via the `--atomic_numbers` flag." - ) - - atoms_list_pt: list[ase.Atoms] = ase.io.read(settings.configs_pt, index=":") # type: ignore - filtered_pt_atoms, remaining_atoms, passes_filter = _filter_pretraining_data( - atoms_list_pt, settings.filtering_type, all_species_ft - ) - - subsampled_atoms = _subsample_data( - filtered_pt_atoms, - remaining_atoms, - passes_filter, - settings.num_samples, - settings.subselect, - settings.descriptors, - calc, - ) - _maybe_save_descriptors(subsampled_atoms, settings.output) - - _write_metadata( - subsampled_atoms, - pretrained=True, - config_weight=settings.weight_pt, - head=settings.head_pt, - ) - _write_metadata( - atoms_list_ft, - pretrained=False, - config_weight=settings.weight_ft, - head=settings.head_ft, - ) - - logging.info("Saving the selected configurations") - ase.io.write(settings.output, subsampled_atoms, format="extxyz") - - logging.info("Saving a combined XYZ file") - atoms_fps_pt_ft = subsampled_atoms + atoms_list_ft - - ase.io.write( - settings.output.replace(".xyz", "_combined.xyz"), - atoms_fps_pt_ft, - format="extxyz", - ) - - -def main(): - args = parse_args() - settings = SelectionSettings(**vars(args)) - select_samples(settings) - - -if __name__ == "__main__": - main() +########################################################################################### +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### +from __future__ import annotations + +import argparse +import logging +from dataclasses import dataclass +from enum import Enum +from typing import List, Tuple, Union + +import ase.data +import ase.io +import numpy as np +import torch + +from mace.calculators import MACECalculator, mace_mp + +try: + import fpsample # type: ignore +except ImportError: + pass + + +class FilteringType(Enum): + NONE = "none" + COMBINATIONS = "combinations" + EXCLUSIVE = "exclusive" + INCLUSIVE = "inclusive" + + +class SubselectType(Enum): + FPS = "fps" + RANDOM = "random" + + +@dataclass +class SelectionSettings: + configs_pt: str + output: str + configs_ft: str | None = None + atomic_numbers: List[int] | None = None + num_samples: int | None = None + subselect: SubselectType = SubselectType.FPS + model: str = "small" + descriptors: str | None = None + device: str = "cpu" + default_dtype: str = "float64" + head_pt: str | None = None + head_ft: str | None = None + filtering_type: FilteringType = FilteringType.COMBINATIONS + weight_ft: float = 1.0 + weight_pt: float = 1.0 + seed: int = 42 + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--configs_pt", + help="path to XYZ configurations for the pretraining", + required=True, + ) + parser.add_argument( + "--configs_ft", + help="path or list of paths to XYZ configurations for the finetuning", + required=False, + default=None, + ) + parser.add_argument( + "--num_samples", + help="number of samples to select for the pretraining", + type=int, + required=False, + default=None, + ) + parser.add_argument( + "--subselect", + help="method to subselect the configurations of the pretraining set", + type=SubselectType, + choices=list(SubselectType), + default=SubselectType.FPS, + ) + parser.add_argument( + "--model", help="path to model", default="small", required=False + ) + parser.add_argument("--output", help="output path", required=True) + parser.add_argument( + "--descriptors", help="path to descriptors", required=False, default=None + ) + parser.add_argument( + "--device", + help="select device", + type=str, + choices=["cpu", "cuda"], + default="cpu", + ) + parser.add_argument( + "--default_dtype", + help="set default dtype", + type=str, + choices=["float32", "float64"], + default="float64", + ) + parser.add_argument( + "--head_pt", + help="level of head for the pretraining set", + type=str, + default=None, + ) + parser.add_argument( + "--head_ft", + help="level of head for the finetuning set", + type=str, + default=None, + ) + parser.add_argument( + "--filtering_type", + help="filtering type", + type=FilteringType, + choices=list(FilteringType), + default=FilteringType.NONE, + ) + parser.add_argument( + "--weight_ft", + help="weight for the finetuning set", + type=float, + default=1.0, + ) + parser.add_argument( + "--weight_pt", + help="weight for the pretraining set", + type=float, + default=1.0, + ) + parser.add_argument("--seed", help="random seed", type=int, default=42) + return parser.parse_args() + + +def calculate_descriptors(atoms: List[ase.Atoms], calc: MACECalculator) -> None: + logging.info("Calculating descriptors") + for mol in atoms: + descriptors = calc.get_descriptors(mol.copy(), invariants_only=True) + # average descriptors over atoms for each element + descriptors_dict = { + element: np.mean(descriptors[mol.symbols == element], axis=0) + for element in np.unique(mol.symbols) + } + mol.info["mace_descriptors"] = descriptors_dict + + +def filter_atoms( + atoms: ase.Atoms, + element_subset: List[str], + filtering_type: FilteringType = FilteringType.COMBINATIONS, +) -> bool: + """ + Filters atoms based on the provided filtering type and element subset. + + Parameters: + atoms (ase.Atoms): The atoms object to filter. + element_subset (list): The list of elements to consider during filtering. + filtering_type (FilteringType): The type of filtering to apply. + Can be one of the following `FilteringType` enum members: + - `FilteringType.NONE`: No filtering is applied. + - `FilteringType.COMBINATIONS`: Return true if `atoms` is composed of combinations of elements in the subset, false otherwise. I.e. does not require all of the specified elements to be present. + - `FilteringType.EXCLUSIVE`: Return true if `atoms` contains *only* elements in the subset, false otherwise. + - `FilteringType.INCLUSIVE`: Return true if `atoms` contains all elements in the subset, false otherwise. I.e. allows additional elements. + + Returns: + bool: True if the atoms pass the filter, False otherwise. + """ + if filtering_type == FilteringType.NONE: + return True + if filtering_type == FilteringType.COMBINATIONS: + atom_symbols = np.unique(atoms.symbols) + return all( + x in element_subset for x in atom_symbols + ) # atoms must *only* contain elements in the subset + if filtering_type == FilteringType.EXCLUSIVE: + atom_symbols = set(list(atoms.symbols)) + return atom_symbols == set(element_subset) + if filtering_type == FilteringType.INCLUSIVE: + atom_symbols = np.unique(atoms.symbols) + return all( + x in atom_symbols for x in element_subset + ) # atoms must *at least* contain elements in the subset + raise ValueError( + f"Filtering type {filtering_type} not recognised. Must be one of {list(FilteringType)}." + ) + + +class FPS: + def __init__(self, atoms_list: List[ase.Atoms], n_samples: int): + self.n_samples = n_samples + self.atoms_list = atoms_list + self.species = np.unique([x.symbol for atoms in atoms_list for x in atoms]) # type: ignore + self.species_dict = {x: i for i, x in enumerate(self.species)} + # start from a random configuration + self.list_index = [np.random.randint(0, len(atoms_list))] + self.assemble_descriptors() + + def run( + self, + ) -> List[int]: + """ + Run the farthest point sampling algorithm. + """ + descriptor_dataset_reshaped = ( + self.descriptors_dataset.reshape( # pylint: disable=E1121 + (len(self.atoms_list), -1) + ) + ) + logging.info(f"{descriptor_dataset_reshaped.shape}") + logging.info(f"n_samples: {self.n_samples}") + self.list_index = fpsample.fps_npdu_kdtree_sampling( + descriptor_dataset_reshaped, + self.n_samples, + ) + return self.list_index + + def assemble_descriptors(self) -> None: + """ + Assemble the descriptors for all the configurations. + """ + self.descriptors_dataset: np.ndarray = 10e10 * np.ones( + ( + len(self.atoms_list), + len(self.species), + len(list(self.atoms_list[0].info["mace_descriptors"].values())[0]), + ), + dtype=np.float32, + ).astype(np.float32) + + for i, atoms in enumerate(self.atoms_list): + descriptors = atoms.info["mace_descriptors"] + for z in descriptors: + self.descriptors_dataset[i, self.species_dict[z]] = np.array( + descriptors[z] + ).astype(np.float32) + + +def _load_calc( + model: str, device: str, default_dtype: str, subselect: SubselectType +) -> Union[MACECalculator, None]: + if subselect == SubselectType.RANDOM: + return None + if model in ["small", "medium", "large"]: + calc = mace_mp(model, device=device, default_dtype=default_dtype) + else: + calc = MACECalculator( + model_paths=model, + device=device, + default_dtype=default_dtype, + ) + return calc + + +def _get_finetuning_elements( + atoms: List[ase.Atoms], atomic_numbers: List[int] | None +) -> List[str]: + if atoms: + logging.debug( + "Using elements from the finetuning configurations for filtering." + ) + species = np.unique([x.symbol for atoms in atoms for x in atoms]).tolist() # type: ignore + elif atomic_numbers is not None and atomic_numbers: + logging.debug("Using the supplied atomic numbers for filtering.") + species = [ase.data.chemical_symbols[z] for z in atomic_numbers] + else: + species = [] + return species + + +def _read_finetuning_configs( + configs_ft: Union[str, list[str], None], +) -> List[ase.Atoms]: + if isinstance(configs_ft, str): + path = configs_ft + return ase.io.read(path, index=":") # type: ignore + if isinstance(configs_ft, list): + assert all(isinstance(x, str) for x in configs_ft) + atoms_list_ft = [] + for path in configs_ft: + atoms_list_ft += ase.io.read(path, index=":") + return atoms_list_ft + if configs_ft is None: + return [] + raise ValueError(f"Invalid type for configs_ft: {type(configs_ft)}") + + +def _filter_pretraining_data( + atoms: list[ase.Atoms], + filtering_type: FilteringType, + all_species_ft: List[str], +) -> Tuple[List[ase.Atoms], List[ase.Atoms], list[bool]]: + logging.info( + "Filtering configurations based on the finetuning set, " + f"filtering type: {filtering_type}, elements: {all_species_ft}" + ) + passes_filter = [filter_atoms(x, all_species_ft, filtering_type) for x in atoms] + assert len(passes_filter) == len(atoms), "Filtering failed" + filtered_atoms = [x for x, passes in zip(atoms, passes_filter) if passes] + remaining_atoms = [x for x, passes in zip(atoms, passes_filter) if not passes] + return filtered_atoms, remaining_atoms, passes_filter + + +def _get_random_configs( + num_samples: int, + atoms: List[ase.Atoms], +) -> list[ase.Atoms]: + if num_samples > len(atoms): + raise ValueError( + f"Requested more samples ({num_samples}) than available in the remaining set ({len(atoms)})" + ) + indices = np.random.choice(list(range(len(atoms))), num_samples, replace=False) + return [atoms[i] for i in indices] + + +def _load_descriptors( + atoms: List[ase.Atoms], + passes_filter: List[bool], + descriptors_path: str | None, + calc: MACECalculator | None, + full_data_length: int, +) -> None: + if descriptors_path is not None: + logging.info(f"Loading descriptors from {descriptors_path}") + descriptors = np.load(descriptors_path, allow_pickle=True) + assert sum(passes_filter) == len(atoms) + if len(descriptors) != full_data_length: + raise ValueError( + f"Length of the descriptors ({len(descriptors)}) does not match the length of the data ({full_data_length})" + "Please provide descriptors for all configurations" + ) + required_descriptors = [ + descriptors[i] for i, passes in enumerate(passes_filter) if passes + ] + for i, atoms_ in enumerate(atoms): + atoms_.info["mace_descriptors"] = required_descriptors[i] + else: + logging.info("Calculating descriptors") + if calc is None: + raise ValueError("MACECalculator must be provided to calculate descriptors") + calculate_descriptors(atoms, calc) + + +def _maybe_save_descriptors( + atoms: List[ase.Atoms], + output_path: str, +) -> None: + """ + Save the descriptors if they are present in the atoms objects. + Also, delete the descriptors from the atoms objects. + """ + if all("mace_descriptors" in x.info for x in atoms): + descriptor_save_path = output_path.replace(".xyz", "_descriptors.npy") + logging.info(f"Saving descriptors at {descriptor_save_path}") + descriptors_list = [x.info["mace_descriptors"] for x in atoms] + np.save(descriptor_save_path, descriptors_list, allow_pickle=True) + for x in atoms: + del x.info["mace_descriptors"] + + +def _maybe_fps(atoms: List[ase.Atoms], num_samples: int) -> List[ase.Atoms]: + try: + fps_pt = FPS(atoms, num_samples) + idx_pt = fps_pt.run() + logging.info(f"Selected {len(idx_pt)} configurations") + return [atoms[i] for i in idx_pt] + except Exception as e: # pylint: disable=W0703 + logging.error(f"FPS failed, selecting random configurations instead: {e}") + return _get_random_configs(num_samples, atoms) + + +def _subsample_data( + filtered_atoms: List[ase.Atoms], + remaining_atoms: List[ase.Atoms], + passes_filter: List[bool], + num_samples: int | None, + subselect: SubselectType, + descriptors_path: str | None, + calc: MACECalculator | None, +) -> List[ase.Atoms]: + if num_samples is None or num_samples == len(filtered_atoms): + logging.info( + f"No subsampling, keeping all {len(filtered_atoms)} filtered configurations" + ) + return filtered_atoms + if num_samples > len(filtered_atoms): + num_sample_randomly = num_samples - len(filtered_atoms) + logging.info( + f"Number of configurations after filtering {len(filtered_atoms)} " + f"is less than the number of samples {num_samples}, " + f"selecting {num_sample_randomly} random configurations for the rest." + ) + return filtered_atoms + _get_random_configs( + num_sample_randomly, remaining_atoms + ) + if num_samples == 0: + raise ValueError("Number of samples must be greater than 0") + if subselect == SubselectType.FPS: + _load_descriptors( + filtered_atoms, + passes_filter, + descriptors_path, + calc, + full_data_length=len(filtered_atoms) + len(remaining_atoms), + ) + logging.info("Selecting configurations using Farthest Point Sampling") + return _maybe_fps(filtered_atoms, num_samples) + if subselect == SubselectType.RANDOM: + return _get_random_configs(num_samples, filtered_atoms) + raise ValueError(f"Invalid subselect type: {subselect}") + + +def _write_metadata( + atoms: list[ase.Atoms], pretrained: bool, config_weight: float, head: str | None +) -> None: + for a in atoms: + a.info["pretrained"] = pretrained + a.info["config_weight"] = config_weight + if head is not None: + a.info["head"] = head + + +def select_samples( + settings: SelectionSettings, +) -> None: + np.random.seed(settings.seed) + torch.manual_seed(settings.seed) + calc = _load_calc( + settings.model, settings.device, settings.default_dtype, settings.subselect + ) + atoms_list_ft = _read_finetuning_configs(settings.configs_ft) + all_species_ft = _get_finetuning_elements(atoms_list_ft, settings.atomic_numbers) + + if settings.filtering_type is not FilteringType.NONE and not all_species_ft: + raise ValueError( + "Filtering types other than NONE require elements for filtering. They can be specified via the `--atomic_numbers` flag." + ) + + atoms_list_pt: list[ase.Atoms] = ase.io.read(settings.configs_pt, index=":") # type: ignore + filtered_pt_atoms, remaining_atoms, passes_filter = _filter_pretraining_data( + atoms_list_pt, settings.filtering_type, all_species_ft + ) + + subsampled_atoms = _subsample_data( + filtered_pt_atoms, + remaining_atoms, + passes_filter, + settings.num_samples, + settings.subselect, + settings.descriptors, + calc, + ) + _maybe_save_descriptors(subsampled_atoms, settings.output) + + _write_metadata( + subsampled_atoms, + pretrained=True, + config_weight=settings.weight_pt, + head=settings.head_pt, + ) + _write_metadata( + atoms_list_ft, + pretrained=False, + config_weight=settings.weight_ft, + head=settings.head_ft, + ) + + logging.info("Saving the selected configurations") + ase.io.write(settings.output, subsampled_atoms, format="extxyz") + + logging.info("Saving a combined XYZ file") + atoms_fps_pt_ft = subsampled_atoms + atoms_list_ft + + ase.io.write( + settings.output.replace(".xyz", "_combined.xyz"), + atoms_fps_pt_ft, + format="extxyz", + ) + + +def main(): + args = parse_args() + settings = SelectionSettings(**vars(args)) + select_samples(settings) + + +if __name__ == "__main__": + main() diff --git a/mace-bench/3rdparty/mace/mace/cli/plot_train.py b/mace-bench/3rdparty/mace/mace/cli/plot_train.py index 4e27372..238bd09 100644 --- a/mace-bench/3rdparty/mace/mace/cli/plot_train.py +++ b/mace-bench/3rdparty/mace/mace/cli/plot_train.py @@ -1,342 +1,342 @@ -import argparse -import dataclasses -import glob -import json -import os -import re -from typing import List - -import matplotlib.pyplot as plt -import pandas as pd - -plt.rcParams.update({"font.size": 8}) -plt.style.use("seaborn-v0_8-paper") - - -colors = [ - "#1f77b4", # muted blue - "#d62728", # brick red - "#ff7f0e", # safety orange - "#2ca02c", # cooked asparagus green - "#9467bd", # muted purple - "#8c564b", # chestnut brown - "#e377c2", # raspberry yogurt pink - "#7f7f7f", # middle gray - "#bcbd22", # curry yellow-green - "#17becf", # blue-teal -] - - -@dataclasses.dataclass -class RunInfo: - name: str - seed: int - - -name_re = re.compile(r"(?P.+)_run-(?P\d+)_train.txt") - - -def parse_path(path: str) -> RunInfo: - match = name_re.match(os.path.basename(path)) - if not match: - raise RuntimeError(f"Cannot parse {path}") - - return RunInfo(name=match.group("name"), seed=int(match.group("seed"))) - - -def parse_training_results(path: str) -> List[dict]: - run_info = parse_path(path) - results = [] - with open(path, mode="r", encoding="utf-8") as f: - for line in f: - d = json.loads(line) - d["name"] = run_info.name - d["seed"] = run_info.seed - results.append(d) - - return results - - -def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser( - description="Plot mace training statistics", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - parser.add_argument( - "--path", help="Path to results file (.txt) or directory.", required=True - ) - parser.add_argument( - "--min_epoch", help="Minimum epoch.", default=0, type=int, required=False - ) - parser.add_argument( - "--start_stage_two", - "--start_swa", - help="Epoch that stage two (swa) loss began. Plots dashed line on plot to indicate. If None then assumed tag not used in training.", - default=None, - type=int, - required=False, - dest="start_swa", - ) - parser.add_argument( - "--linear", - help="Whether to plot linear instead of log scales.", - default=False, - required=False, - action="store_true", - ) - parser.add_argument( - "--error_bars", - help="Whether to plot standard deviations.", - default=False, - required=False, - action="store_true", - ) - parser.add_argument( - "--keys", - help="Comma-separated list of keys to plot.", - default="rmse_e,rmse_f", - type=str, - required=False, - ) - - parser.add_argument( - "--output_format", - help="What file type to save plot as", - default="png", - type=str, - required=False, - ) - - parser.add_argument( - "--heads", - help="Comma-separated name of the heads used for multihead training", - default=None, - type=str, - required=False, - ) - - return parser.parse_args() - - -def plot( - data: pd.DataFrame, - min_epoch: int, - output_path: str, - output_format: str, - linear: bool, - start_swa: int, - error_bars: bool, - keys: str, - heads: str, -) -> None: - """ - Plots train,validation loss and errors as a function of epoch. - min_epoch: minimum epoch to plot. - output_path: path to save the plot. - output_format: format to save the plot. - start_swa: whether to plot a dashed line to show epoch when stage two loss (swa) begins. - error_bars: whether to plot standard deviation of loss. - linear: whether to plot in linear scale or logscale (default). - keys: Values to plot. - heads: Heads used for multihead training. - """ - - labels = { - "mae_e": "MAE E [meV]", - "mae_e_per_atom": "MAE E/atom [meV]", - "rmse_e": "RMSE E [meV]", - "rmse_e_per_atom": "RMSE E/atom [meV]", - "q95_e": "Q95 E [meV]", - "mae_f": "MAE F [meV / A]", - "rel_mae_f": "Relative MAE F [meV / A]", - "rmse_f": "RMSE F [meV / A]", - "rel_rmse_f": "Relative RMSE F [meV / A]", - "q95_f": "Q95 F [meV / A]", - "mae_stress": "MAE Stress", - "rmse_stress": "RMSE Stress [meV / A^3]", - "rmse_virials_per_atom": " RMSE virials/atom [meV]", - "mae_virials": "MAE Virials [meV]", - "rmse_mu_per_atom": "RMSE MU/atom [mDebye]", - } - - data = data[data["epoch"] > min_epoch] - if heads is None: - data = ( - data.groupby(["name", "mode", "epoch"]).agg(["mean", "std"]).reset_index() - ) - - valid_data = data[data["mode"] == "eval"] - valid_data_dict = {"default": valid_data} - train_data = data[data["mode"] == "opt"] - else: - heads = heads.split(",") - # Separate eval and opt data - valid_data = ( - data[data["mode"] == "eval"] - .groupby(["name", "mode", "epoch", "head"]) - .agg(["mean", "std"]) - .reset_index() - ) - train_data = ( - data[data["mode"] == "opt"] - .groupby(["name", "mode", "epoch"]) - .agg(["mean", "std"]) - .reset_index() - ) - valid_data_dict = { - head: valid_data[valid_data["head"] == head] for head in heads - } - - for head, valid_data in valid_data_dict.items(): - fig, axes = plt.subplots( - nrows=1, ncols=2, figsize=(10, 3), constrained_layout=True - ) - - # ---- Plot loss ---- - ax = axes[0] - ax.plot( - train_data["epoch"], - train_data["loss"]["mean"], - color=colors[1], - linewidth=1, - ) - ax.set_ylabel("Training Loss", color=colors[1]) - ax.set_yscale("log") - - ax2 = ax.twinx() - ax2.plot( - valid_data["epoch"], - valid_data["loss"]["mean"], - color=colors[0], - linewidth=1, - ) - ax2.set_ylabel("Validation Loss", color=colors[0]) - - if not linear: - ax.set_yscale("log") - ax2.set_yscale("log") - - if error_bars: - ax.fill_between( - train_data["epoch"], - train_data["loss"]["mean"] - train_data["loss"]["std"], - train_data["loss"]["mean"] + train_data["loss"]["std"], - alpha=0.3, - color=colors[1], - ) - ax.fill_between( - valid_data["epoch"], - valid_data["loss"]["mean"] - valid_data["loss"]["std"], - valid_data["loss"]["mean"] + valid_data["loss"]["std"], - alpha=0.3, - color=colors[0], - ) - - if start_swa is not None: - ax.axvline( - start_swa, - color="black", - linestyle="dashed", - linewidth=1, - alpha=0.6, - label="Stage Two Starts", - ) - - ax.set_xlabel("Epoch") - ax.set_ylabel("Loss") - ax.legend(loc="upper right", fontsize=4) - ax.grid(True, linestyle="--", alpha=0.5) - - # ---- Plot selected keys ---- - ax = axes[1] - twin_axes = [] - for i, key in enumerate(keys.split(",")): - color = colors[(i + 3)] - label = labels.get(key, key) - - if i == 0: - main_ax = ax - else: - main_ax = ax.twinx() - main_ax.spines.right.set_position(("outward", 40 * (i - 1))) - twin_axes.append(main_ax) - - main_ax.plot( - valid_data["epoch"], - valid_data[key]["mean"] * 1e3, - color=color, - label=label, - linewidth=1, - ) - - if error_bars: - main_ax.fill_between( - valid_data["epoch"], - (valid_data[key]["mean"] - valid_data[key]["std"]) * 1e3, - (valid_data[key]["mean"] + valid_data[key]["std"]) * 1e3, - alpha=0.3, - color=color, - ) - - main_ax.set_ylabel(label, color=color) - main_ax.tick_params(axis="y", colors=color) - - if start_swa is not None: - ax.axvline( - start_swa, - color="black", - linestyle="dashed", - linewidth=1, - alpha=0.6, - label="Stage Two Starts", - ) - - ax.set_xlabel("Epoch") - ax.set_xlim(left=min_epoch) - ax.grid(True, linestyle="--", alpha=0.5) - - fig.savefig( - f"{output_path}_{head}.{output_format}", dpi=300, bbox_inches="tight" - ) - plt.close(fig) - - -def get_paths(path: str) -> List[str]: - if os.path.isfile(path): - return [path] - paths = glob.glob(os.path.join(path, "*_train.txt")) - - if len(paths) == 0: - raise RuntimeError(f"Cannot find results in '{path}'") - - return paths - - -def main() -> None: - args = parse_args() - run(args) - - -def run(args: argparse.Namespace) -> None: - data = pd.DataFrame( - results - for path in get_paths(args.path) - for results in parse_training_results(path) - ) - - for name, group in data.groupby("name"): - plot( - group, - min_epoch=args.min_epoch, - output_path=name, - output_format=args.output_format, - linear=args.linear, - start_swa=args.start_swa, - error_bars=args.error_bars, - keys=args.keys, - heads=args.heads, - ) - - -if __name__ == "__main__": - main() +import argparse +import dataclasses +import glob +import json +import os +import re +from typing import List + +import matplotlib.pyplot as plt +import pandas as pd + +plt.rcParams.update({"font.size": 8}) +plt.style.use("seaborn-v0_8-paper") + + +colors = [ + "#1f77b4", # muted blue + "#d62728", # brick red + "#ff7f0e", # safety orange + "#2ca02c", # cooked asparagus green + "#9467bd", # muted purple + "#8c564b", # chestnut brown + "#e377c2", # raspberry yogurt pink + "#7f7f7f", # middle gray + "#bcbd22", # curry yellow-green + "#17becf", # blue-teal +] + + +@dataclasses.dataclass +class RunInfo: + name: str + seed: int + + +name_re = re.compile(r"(?P.+)_run-(?P\d+)_train.txt") + + +def parse_path(path: str) -> RunInfo: + match = name_re.match(os.path.basename(path)) + if not match: + raise RuntimeError(f"Cannot parse {path}") + + return RunInfo(name=match.group("name"), seed=int(match.group("seed"))) + + +def parse_training_results(path: str) -> List[dict]: + run_info = parse_path(path) + results = [] + with open(path, mode="r", encoding="utf-8") as f: + for line in f: + d = json.loads(line) + d["name"] = run_info.name + d["seed"] = run_info.seed + results.append(d) + + return results + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Plot mace training statistics", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--path", help="Path to results file (.txt) or directory.", required=True + ) + parser.add_argument( + "--min_epoch", help="Minimum epoch.", default=0, type=int, required=False + ) + parser.add_argument( + "--start_stage_two", + "--start_swa", + help="Epoch that stage two (swa) loss began. Plots dashed line on plot to indicate. If None then assumed tag not used in training.", + default=None, + type=int, + required=False, + dest="start_swa", + ) + parser.add_argument( + "--linear", + help="Whether to plot linear instead of log scales.", + default=False, + required=False, + action="store_true", + ) + parser.add_argument( + "--error_bars", + help="Whether to plot standard deviations.", + default=False, + required=False, + action="store_true", + ) + parser.add_argument( + "--keys", + help="Comma-separated list of keys to plot.", + default="rmse_e,rmse_f", + type=str, + required=False, + ) + + parser.add_argument( + "--output_format", + help="What file type to save plot as", + default="png", + type=str, + required=False, + ) + + parser.add_argument( + "--heads", + help="Comma-separated name of the heads used for multihead training", + default=None, + type=str, + required=False, + ) + + return parser.parse_args() + + +def plot( + data: pd.DataFrame, + min_epoch: int, + output_path: str, + output_format: str, + linear: bool, + start_swa: int, + error_bars: bool, + keys: str, + heads: str, +) -> None: + """ + Plots train,validation loss and errors as a function of epoch. + min_epoch: minimum epoch to plot. + output_path: path to save the plot. + output_format: format to save the plot. + start_swa: whether to plot a dashed line to show epoch when stage two loss (swa) begins. + error_bars: whether to plot standard deviation of loss. + linear: whether to plot in linear scale or logscale (default). + keys: Values to plot. + heads: Heads used for multihead training. + """ + + labels = { + "mae_e": "MAE E [meV]", + "mae_e_per_atom": "MAE E/atom [meV]", + "rmse_e": "RMSE E [meV]", + "rmse_e_per_atom": "RMSE E/atom [meV]", + "q95_e": "Q95 E [meV]", + "mae_f": "MAE F [meV / A]", + "rel_mae_f": "Relative MAE F [meV / A]", + "rmse_f": "RMSE F [meV / A]", + "rel_rmse_f": "Relative RMSE F [meV / A]", + "q95_f": "Q95 F [meV / A]", + "mae_stress": "MAE Stress", + "rmse_stress": "RMSE Stress [meV / A^3]", + "rmse_virials_per_atom": " RMSE virials/atom [meV]", + "mae_virials": "MAE Virials [meV]", + "rmse_mu_per_atom": "RMSE MU/atom [mDebye]", + } + + data = data[data["epoch"] > min_epoch] + if heads is None: + data = ( + data.groupby(["name", "mode", "epoch"]).agg(["mean", "std"]).reset_index() + ) + + valid_data = data[data["mode"] == "eval"] + valid_data_dict = {"default": valid_data} + train_data = data[data["mode"] == "opt"] + else: + heads = heads.split(",") + # Separate eval and opt data + valid_data = ( + data[data["mode"] == "eval"] + .groupby(["name", "mode", "epoch", "head"]) + .agg(["mean", "std"]) + .reset_index() + ) + train_data = ( + data[data["mode"] == "opt"] + .groupby(["name", "mode", "epoch"]) + .agg(["mean", "std"]) + .reset_index() + ) + valid_data_dict = { + head: valid_data[valid_data["head"] == head] for head in heads + } + + for head, valid_data in valid_data_dict.items(): + fig, axes = plt.subplots( + nrows=1, ncols=2, figsize=(10, 3), constrained_layout=True + ) + + # ---- Plot loss ---- + ax = axes[0] + ax.plot( + train_data["epoch"], + train_data["loss"]["mean"], + color=colors[1], + linewidth=1, + ) + ax.set_ylabel("Training Loss", color=colors[1]) + ax.set_yscale("log") + + ax2 = ax.twinx() + ax2.plot( + valid_data["epoch"], + valid_data["loss"]["mean"], + color=colors[0], + linewidth=1, + ) + ax2.set_ylabel("Validation Loss", color=colors[0]) + + if not linear: + ax.set_yscale("log") + ax2.set_yscale("log") + + if error_bars: + ax.fill_between( + train_data["epoch"], + train_data["loss"]["mean"] - train_data["loss"]["std"], + train_data["loss"]["mean"] + train_data["loss"]["std"], + alpha=0.3, + color=colors[1], + ) + ax.fill_between( + valid_data["epoch"], + valid_data["loss"]["mean"] - valid_data["loss"]["std"], + valid_data["loss"]["mean"] + valid_data["loss"]["std"], + alpha=0.3, + color=colors[0], + ) + + if start_swa is not None: + ax.axvline( + start_swa, + color="black", + linestyle="dashed", + linewidth=1, + alpha=0.6, + label="Stage Two Starts", + ) + + ax.set_xlabel("Epoch") + ax.set_ylabel("Loss") + ax.legend(loc="upper right", fontsize=4) + ax.grid(True, linestyle="--", alpha=0.5) + + # ---- Plot selected keys ---- + ax = axes[1] + twin_axes = [] + for i, key in enumerate(keys.split(",")): + color = colors[(i + 3)] + label = labels.get(key, key) + + if i == 0: + main_ax = ax + else: + main_ax = ax.twinx() + main_ax.spines.right.set_position(("outward", 40 * (i - 1))) + twin_axes.append(main_ax) + + main_ax.plot( + valid_data["epoch"], + valid_data[key]["mean"] * 1e3, + color=color, + label=label, + linewidth=1, + ) + + if error_bars: + main_ax.fill_between( + valid_data["epoch"], + (valid_data[key]["mean"] - valid_data[key]["std"]) * 1e3, + (valid_data[key]["mean"] + valid_data[key]["std"]) * 1e3, + alpha=0.3, + color=color, + ) + + main_ax.set_ylabel(label, color=color) + main_ax.tick_params(axis="y", colors=color) + + if start_swa is not None: + ax.axvline( + start_swa, + color="black", + linestyle="dashed", + linewidth=1, + alpha=0.6, + label="Stage Two Starts", + ) + + ax.set_xlabel("Epoch") + ax.set_xlim(left=min_epoch) + ax.grid(True, linestyle="--", alpha=0.5) + + fig.savefig( + f"{output_path}_{head}.{output_format}", dpi=300, bbox_inches="tight" + ) + plt.close(fig) + + +def get_paths(path: str) -> List[str]: + if os.path.isfile(path): + return [path] + paths = glob.glob(os.path.join(path, "*_train.txt")) + + if len(paths) == 0: + raise RuntimeError(f"Cannot find results in '{path}'") + + return paths + + +def main() -> None: + args = parse_args() + run(args) + + +def run(args: argparse.Namespace) -> None: + data = pd.DataFrame( + results + for path in get_paths(args.path) + for results in parse_training_results(path) + ) + + for name, group in data.groupby("name"): + plot( + group, + min_epoch=args.min_epoch, + output_path=name, + output_format=args.output_format, + linear=args.linear, + start_swa=args.start_swa, + error_bars=args.error_bars, + keys=args.keys, + heads=args.heads, + ) + + +if __name__ == "__main__": + main() diff --git a/mace-bench/3rdparty/mace/mace/cli/preprocess_data.py b/mace-bench/3rdparty/mace/mace/cli/preprocess_data.py index 11ba17b..64f1740 100644 --- a/mace-bench/3rdparty/mace/mace/cli/preprocess_data.py +++ b/mace-bench/3rdparty/mace/mace/cli/preprocess_data.py @@ -1,300 +1,300 @@ -# This file loads an xyz dataset and prepares -# new hdf5 file that is ready for training with on-the-fly dataloading - -import argparse -import ast -import json -import logging -import multiprocessing as mp -import os -import random -from functools import partial -from glob import glob -from typing import List, Tuple - -import h5py -import numpy as np -import tqdm - -from mace import data, tools -from mace.data import KeySpecification, update_keyspec_from_kwargs -from mace.data.utils import save_configurations_as_HDF5 -from mace.modules import compute_statistics -from mace.tools import torch_geometric -from mace.tools.scripts_utils import get_atomic_energies, get_dataset_from_xyz -from mace.tools.utils import AtomicNumberTable - - -def compute_stats_target( - file: str, - z_table: AtomicNumberTable, - r_max: float, - atomic_energies: Tuple, - batch_size: int, -): - train_dataset = data.HDF5Dataset(file, z_table=z_table, r_max=r_max) - train_loader = torch_geometric.dataloader.DataLoader( - dataset=train_dataset, - batch_size=batch_size, - shuffle=False, - drop_last=False, - ) - - avg_num_neighbors, mean, std = compute_statistics(train_loader, atomic_energies) - output = [avg_num_neighbors, mean, std] - return output - - -def pool_compute_stats(inputs: List): - path_to_files, z_table, r_max, atomic_energies, batch_size, num_process = inputs - - with mp.Pool(processes=num_process) as pool: - re = [ - pool.apply_async( - compute_stats_target, - args=( - file, - z_table, - r_max, - atomic_energies, - batch_size, - ), - ) - for file in glob(path_to_files + "/*") - ] - - pool.close() - pool.join() - - results = [r.get() for r in tqdm.tqdm(re)] - - if not results: - raise ValueError( - "No results were computed. Check if the input files exist and are readable." - ) - - # Separate avg_num_neighbors, mean, and std - avg_num_neighbors = np.mean([r[0] for r in results]) - means = np.array([r[1] for r in results]) - stds = np.array([r[2] for r in results]) - - # Compute averages - mean = np.mean(means, axis=0).item() - std = np.mean(stds, axis=0).item() - - return avg_num_neighbors, mean, std - - -def split_array(a: np.ndarray, max_size: int): - drop_last = False - if len(a) % 2 == 1: - a = np.append(a, a[-1]) - drop_last = True - factors = get_prime_factors(len(a)) - max_factor = 1 - for i in range(1, len(factors) + 1): - for j in range(0, len(factors) - i + 1): - if np.prod(factors[j : j + i]) <= max_size: - test = np.prod(factors[j : j + i]) - max_factor = max(test, max_factor) - return np.array_split(a, max_factor), drop_last - - -def get_prime_factors(n: int): - factors = [] - for i in range(2, n + 1): - while n % i == 0: - factors.append(i) - n = n / i - return factors - - -# Define Task for Multiprocessiing -def multi_train_hdf5(process, args, split_train, drop_last): - with h5py.File(args.h5_prefix + "train/train_" + str(process) + ".h5", "w") as f: - f.attrs["drop_last"] = drop_last - save_configurations_as_HDF5(split_train[process], process, f) - - -def multi_valid_hdf5(process, args, split_valid, drop_last): - with h5py.File(args.h5_prefix + "val/val_" + str(process) + ".h5", "w") as f: - f.attrs["drop_last"] = drop_last - save_configurations_as_HDF5(split_valid[process], process, f) - - -def multi_test_hdf5(process, name, args, split_test, drop_last): - with h5py.File( - args.h5_prefix + "test/" + name + "_" + str(process) + ".h5", "w" - ) as f: - f.attrs["drop_last"] = drop_last - save_configurations_as_HDF5(split_test[process], process, f) - - -def main() -> None: - """ - This script loads an xyz dataset and prepares - new hdf5 file that is ready for training with on-the-fly dataloading - """ - args = tools.build_preprocess_arg_parser().parse_args() - run(args) - - -def run(args: argparse.Namespace): - """ - This script loads an xyz dataset and prepares - new hdf5 file that is ready for training with on-the-fly dataloading - """ - - # currently support only command line property_key syntax - args.key_specification = KeySpecification() - update_keyspec_from_kwargs(args.key_specification, vars(args)) - - # Setup - tools.set_seeds(args.seed) - random.seed(args.seed) - logging.basicConfig( - level=logging.INFO, - format="%(asctime)s %(levelname)-8s %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - handlers=[logging.StreamHandler()], - ) - - try: - config_type_weights = ast.literal_eval(args.config_type_weights) - assert isinstance(config_type_weights, dict) - except Exception as e: # pylint: disable=W0703 - logging.warning( - f"Config type weights not specified correctly ({e}), using Default" - ) - config_type_weights = {"Default": 1.0} - - folders = ["train", "val", "test"] - for sub_dir in folders: - if not os.path.exists(args.h5_prefix + sub_dir): - os.makedirs(args.h5_prefix + sub_dir) - - # Data preparation - collections, atomic_energies_dict = get_dataset_from_xyz( - work_dir=args.work_dir, - train_path=args.train_file, - valid_path=args.valid_file, - valid_fraction=args.valid_fraction, - config_type_weights=config_type_weights, - test_path=args.test_file, - seed=args.seed, - key_specification=args.key_specification, - head_name=None, - ) - - # Atomic number table - # yapf: disable - if args.atomic_numbers is None: - z_table = tools.get_atomic_number_table_from_zs( - z - for configs in (collections.train, collections.valid) - for config in configs - for z in config.atomic_numbers - ) - else: - logging.info("Using atomic numbers from command line argument") - zs_list = ast.literal_eval(args.atomic_numbers) - assert isinstance(zs_list, list) - z_table = tools.get_atomic_number_table_from_zs(zs_list) - - logging.info("Preparing training set") - if args.shuffle: - random.shuffle(collections.train) - - # split collections.train into batches and save them to hdf5 - split_train = np.array_split(collections.train,args.num_process) - drop_last = False - if len(collections.train) % 2 == 1: - drop_last = True - - multi_train_hdf5_ = partial(multi_train_hdf5, args=args, split_train=split_train, drop_last=drop_last) - processes = [] - for i in range(args.num_process): - p = mp.Process(target=multi_train_hdf5_, args=[i]) - p.start() - processes.append(p) - - for i in processes: - i.join() - - if args.compute_statistics: - logging.info("Computing statistics") - if len(atomic_energies_dict) == 0: - atomic_energies_dict = get_atomic_energies(args.E0s, collections.train, z_table) - - # Remove atomic energies if element not in z_table - removed_atomic_energies = {} - for z in list(atomic_energies_dict): - if z not in z_table.zs: - removed_atomic_energies[z] = atomic_energies_dict.pop(z) - if len(removed_atomic_energies) > 0: - logging.warning("Atomic energies for elements not present in the atomic number table have been removed.") - logging.warning(f"Removed atomic energies (eV): {str(removed_atomic_energies)}") - logging.warning("To include these elements in the model, specify all atomic numbers explicitly using the --atomic_numbers argument.") - - atomic_energies: np.ndarray = np.array( - [atomic_energies_dict[z] for z in z_table.zs] - ) - logging.info(f"Atomic Energies: {atomic_energies.tolist()}") - _inputs = [args.h5_prefix+'train', z_table, args.r_max, atomic_energies, args.batch_size, args.num_process] - avg_num_neighbors, mean, std=pool_compute_stats(_inputs) - logging.info(f"Average number of neighbors: {avg_num_neighbors}") - logging.info(f"Mean: {mean}") - logging.info(f"Standard deviation: {std}") - - # save the statistics as a json - statistics = { - "atomic_energies": str(atomic_energies_dict), - "avg_num_neighbors": avg_num_neighbors, - "mean": mean, - "std": std, - "atomic_numbers": str([int(z) for z in z_table.zs]), - "r_max": args.r_max, - } - - with open(args.h5_prefix + "statistics.json", "w") as f: # pylint: disable=W1514 - json.dump(statistics, f) - - logging.info("Preparing validation set") - if args.shuffle: - random.shuffle(collections.valid) - split_valid = np.array_split(collections.valid, args.num_process) - drop_last = False - if len(collections.valid) % 2 == 1: - drop_last = True - - multi_valid_hdf5_ = partial(multi_valid_hdf5, args=args, split_valid=split_valid, drop_last=drop_last) - processes = [] - for i in range(args.num_process): - p = mp.Process(target=multi_valid_hdf5_, args=[i]) - p.start() - processes.append(p) - - for i in processes: - i.join() - - if args.test_file is not None: - logging.info("Preparing test sets") - for name, subset in collections.tests: - drop_last = False - if len(subset) % 2 == 1: - drop_last = True - split_test = np.array_split(subset, args.num_process) - multi_test_hdf5_ = partial(multi_test_hdf5, args=args, split_test=split_test, drop_last=drop_last) - - processes = [] - for i in range(args.num_process): - p = mp.Process(target=multi_test_hdf5_, args=[i, name]) - p.start() - processes.append(p) - - for i in processes: - i.join() - - -if __name__ == "__main__": - main() +# This file loads an xyz dataset and prepares +# new hdf5 file that is ready for training with on-the-fly dataloading + +import argparse +import ast +import json +import logging +import multiprocessing as mp +import os +import random +from functools import partial +from glob import glob +from typing import List, Tuple + +import h5py +import numpy as np +import tqdm + +from mace import data, tools +from mace.data import KeySpecification, update_keyspec_from_kwargs +from mace.data.utils import save_configurations_as_HDF5 +from mace.modules import compute_statistics +from mace.tools import torch_geometric +from mace.tools.scripts_utils import get_atomic_energies, get_dataset_from_xyz +from mace.tools.utils import AtomicNumberTable + + +def compute_stats_target( + file: str, + z_table: AtomicNumberTable, + r_max: float, + atomic_energies: Tuple, + batch_size: int, +): + train_dataset = data.HDF5Dataset(file, z_table=z_table, r_max=r_max) + train_loader = torch_geometric.dataloader.DataLoader( + dataset=train_dataset, + batch_size=batch_size, + shuffle=False, + drop_last=False, + ) + + avg_num_neighbors, mean, std = compute_statistics(train_loader, atomic_energies) + output = [avg_num_neighbors, mean, std] + return output + + +def pool_compute_stats(inputs: List): + path_to_files, z_table, r_max, atomic_energies, batch_size, num_process = inputs + + with mp.Pool(processes=num_process) as pool: + re = [ + pool.apply_async( + compute_stats_target, + args=( + file, + z_table, + r_max, + atomic_energies, + batch_size, + ), + ) + for file in glob(path_to_files + "/*") + ] + + pool.close() + pool.join() + + results = [r.get() for r in tqdm.tqdm(re)] + + if not results: + raise ValueError( + "No results were computed. Check if the input files exist and are readable." + ) + + # Separate avg_num_neighbors, mean, and std + avg_num_neighbors = np.mean([r[0] for r in results]) + means = np.array([r[1] for r in results]) + stds = np.array([r[2] for r in results]) + + # Compute averages + mean = np.mean(means, axis=0).item() + std = np.mean(stds, axis=0).item() + + return avg_num_neighbors, mean, std + + +def split_array(a: np.ndarray, max_size: int): + drop_last = False + if len(a) % 2 == 1: + a = np.append(a, a[-1]) + drop_last = True + factors = get_prime_factors(len(a)) + max_factor = 1 + for i in range(1, len(factors) + 1): + for j in range(0, len(factors) - i + 1): + if np.prod(factors[j : j + i]) <= max_size: + test = np.prod(factors[j : j + i]) + max_factor = max(test, max_factor) + return np.array_split(a, max_factor), drop_last + + +def get_prime_factors(n: int): + factors = [] + for i in range(2, n + 1): + while n % i == 0: + factors.append(i) + n = n / i + return factors + + +# Define Task for Multiprocessiing +def multi_train_hdf5(process, args, split_train, drop_last): + with h5py.File(args.h5_prefix + "train/train_" + str(process) + ".h5", "w") as f: + f.attrs["drop_last"] = drop_last + save_configurations_as_HDF5(split_train[process], process, f) + + +def multi_valid_hdf5(process, args, split_valid, drop_last): + with h5py.File(args.h5_prefix + "val/val_" + str(process) + ".h5", "w") as f: + f.attrs["drop_last"] = drop_last + save_configurations_as_HDF5(split_valid[process], process, f) + + +def multi_test_hdf5(process, name, args, split_test, drop_last): + with h5py.File( + args.h5_prefix + "test/" + name + "_" + str(process) + ".h5", "w" + ) as f: + f.attrs["drop_last"] = drop_last + save_configurations_as_HDF5(split_test[process], process, f) + + +def main() -> None: + """ + This script loads an xyz dataset and prepares + new hdf5 file that is ready for training with on-the-fly dataloading + """ + args = tools.build_preprocess_arg_parser().parse_args() + run(args) + + +def run(args: argparse.Namespace): + """ + This script loads an xyz dataset and prepares + new hdf5 file that is ready for training with on-the-fly dataloading + """ + + # currently support only command line property_key syntax + args.key_specification = KeySpecification() + update_keyspec_from_kwargs(args.key_specification, vars(args)) + + # Setup + tools.set_seeds(args.seed) + random.seed(args.seed) + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)-8s %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + handlers=[logging.StreamHandler()], + ) + + try: + config_type_weights = ast.literal_eval(args.config_type_weights) + assert isinstance(config_type_weights, dict) + except Exception as e: # pylint: disable=W0703 + logging.warning( + f"Config type weights not specified correctly ({e}), using Default" + ) + config_type_weights = {"Default": 1.0} + + folders = ["train", "val", "test"] + for sub_dir in folders: + if not os.path.exists(args.h5_prefix + sub_dir): + os.makedirs(args.h5_prefix + sub_dir) + + # Data preparation + collections, atomic_energies_dict = get_dataset_from_xyz( + work_dir=args.work_dir, + train_path=args.train_file, + valid_path=args.valid_file, + valid_fraction=args.valid_fraction, + config_type_weights=config_type_weights, + test_path=args.test_file, + seed=args.seed, + key_specification=args.key_specification, + head_name=None, + ) + + # Atomic number table + # yapf: disable + if args.atomic_numbers is None: + z_table = tools.get_atomic_number_table_from_zs( + z + for configs in (collections.train, collections.valid) + for config in configs + for z in config.atomic_numbers + ) + else: + logging.info("Using atomic numbers from command line argument") + zs_list = ast.literal_eval(args.atomic_numbers) + assert isinstance(zs_list, list) + z_table = tools.get_atomic_number_table_from_zs(zs_list) + + logging.info("Preparing training set") + if args.shuffle: + random.shuffle(collections.train) + + # split collections.train into batches and save them to hdf5 + split_train = np.array_split(collections.train,args.num_process) + drop_last = False + if len(collections.train) % 2 == 1: + drop_last = True + + multi_train_hdf5_ = partial(multi_train_hdf5, args=args, split_train=split_train, drop_last=drop_last) + processes = [] + for i in range(args.num_process): + p = mp.Process(target=multi_train_hdf5_, args=[i]) + p.start() + processes.append(p) + + for i in processes: + i.join() + + if args.compute_statistics: + logging.info("Computing statistics") + if len(atomic_energies_dict) == 0: + atomic_energies_dict = get_atomic_energies(args.E0s, collections.train, z_table) + + # Remove atomic energies if element not in z_table + removed_atomic_energies = {} + for z in list(atomic_energies_dict): + if z not in z_table.zs: + removed_atomic_energies[z] = atomic_energies_dict.pop(z) + if len(removed_atomic_energies) > 0: + logging.warning("Atomic energies for elements not present in the atomic number table have been removed.") + logging.warning(f"Removed atomic energies (eV): {str(removed_atomic_energies)}") + logging.warning("To include these elements in the model, specify all atomic numbers explicitly using the --atomic_numbers argument.") + + atomic_energies: np.ndarray = np.array( + [atomic_energies_dict[z] for z in z_table.zs] + ) + logging.info(f"Atomic Energies: {atomic_energies.tolist()}") + _inputs = [args.h5_prefix+'train', z_table, args.r_max, atomic_energies, args.batch_size, args.num_process] + avg_num_neighbors, mean, std=pool_compute_stats(_inputs) + logging.info(f"Average number of neighbors: {avg_num_neighbors}") + logging.info(f"Mean: {mean}") + logging.info(f"Standard deviation: {std}") + + # save the statistics as a json + statistics = { + "atomic_energies": str(atomic_energies_dict), + "avg_num_neighbors": avg_num_neighbors, + "mean": mean, + "std": std, + "atomic_numbers": str([int(z) for z in z_table.zs]), + "r_max": args.r_max, + } + + with open(args.h5_prefix + "statistics.json", "w") as f: # pylint: disable=W1514 + json.dump(statistics, f) + + logging.info("Preparing validation set") + if args.shuffle: + random.shuffle(collections.valid) + split_valid = np.array_split(collections.valid, args.num_process) + drop_last = False + if len(collections.valid) % 2 == 1: + drop_last = True + + multi_valid_hdf5_ = partial(multi_valid_hdf5, args=args, split_valid=split_valid, drop_last=drop_last) + processes = [] + for i in range(args.num_process): + p = mp.Process(target=multi_valid_hdf5_, args=[i]) + p.start() + processes.append(p) + + for i in processes: + i.join() + + if args.test_file is not None: + logging.info("Preparing test sets") + for name, subset in collections.tests: + drop_last = False + if len(subset) % 2 == 1: + drop_last = True + split_test = np.array_split(subset, args.num_process) + multi_test_hdf5_ = partial(multi_test_hdf5, args=args, split_test=split_test, drop_last=drop_last) + + processes = [] + for i in range(args.num_process): + p = mp.Process(target=multi_test_hdf5_, args=[i, name]) + p.start() + processes.append(p) + + for i in processes: + i.join() + + +if __name__ == "__main__": + main() diff --git a/mace-bench/3rdparty/mace/mace/cli/run_train.py b/mace-bench/3rdparty/mace/mace/cli/run_train.py index 00a23fb..977ec45 100644 --- a/mace-bench/3rdparty/mace/mace/cli/run_train.py +++ b/mace-bench/3rdparty/mace/mace/cli/run_train.py @@ -1,1007 +1,1007 @@ -########################################################################################### -# Training script for MACE -# Authors: Ilyes Batatia, Gregor Simm, David Kovacs -# This program is distributed under the MIT License (see MIT.md) -########################################################################################### - -import ast -import glob -import json -import logging -import os -from copy import deepcopy -from pathlib import Path -from typing import List, Optional - -import torch.distributed -import torch.nn.functional -from e3nn.util import jit -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.optim import LBFGS -from torch.utils.data import ConcatDataset -from torch_ema import ExponentialMovingAverage - -import mace -from mace import data, tools -from mace.calculators.foundations_models import mace_mp, mace_off -from mace.cli.convert_cueq_e3nn import run as run_cueq_to_e3nn -from mace.cli.convert_e3nn_cueq import run as run_e3nn_to_cueq -from mace.cli.visualise_train import TrainingPlotter -from mace.data import KeySpecification, update_keyspec_from_kwargs -from mace.tools import torch_geometric -from mace.tools.model_script_utils import configure_model -from mace.tools.multihead_tools import ( - HeadConfig, - assemble_mp_data, - dict_head_to_dataclass, - prepare_default_head, - prepare_pt_head, -) -from mace.tools.run_train_utils import ( - combine_datasets, - load_dataset_for_path, - normalize_file_paths, -) -from mace.tools.scripts_utils import ( - LRScheduler, - SubsetCollection, - check_path_ase_read, - convert_to_json_format, - dict_to_array, - extract_config_mace_model, - get_atomic_energies, - get_avg_num_neighbors, - get_config_type_weights, - get_dataset_from_xyz, - get_files_with_suffix, - get_loss_fn, - get_optimizer, - get_params_options, - get_swa, - print_git_commit, - remove_pt_head, - setup_wandb, -) -from mace.tools.slurm_distributed import DistributedEnvironment -from mace.tools.tables_utils import create_error_table -from mace.tools.utils import AtomicNumberTable - - -def main() -> None: - """ - This script runs the training/fine tuning for mace - """ - args = tools.build_default_arg_parser().parse_args() - run(args) - - -def run(args) -> None: - """ - This script runs the training/fine tuning for mace - """ - tag = tools.get_tag(name=args.name, seed=args.seed) - args, input_log_messages = tools.check_args(args) - - # default keyspec to update using heads dictionary - args.key_specification = KeySpecification() - update_keyspec_from_kwargs(args.key_specification, vars(args)) - - if args.device == "xpu": - try: - import intel_extension_for_pytorch as ipex - except ImportError as e: - raise ImportError( - "Error: Intel extension for PyTorch not found, but XPU device was specified" - ) from e - if args.distributed: - try: - distr_env = DistributedEnvironment() - except Exception as e: # pylint: disable=W0703 - logging.error(f"Failed to initialize distributed environment: {e}") - return - world_size = distr_env.world_size - local_rank = distr_env.local_rank - rank = distr_env.rank - if rank == 0: - print(distr_env) - torch.distributed.init_process_group(backend="nccl") - else: - rank = int(0) - - # Setup - tools.set_seeds(args.seed) - tools.setup_logger(level=args.log_level, tag=tag, directory=args.log_dir, rank=rank) - logging.info("===========VERIFYING SETTINGS===========") - for message, loglevel in input_log_messages: - logging.log(level=loglevel, msg=message) - - if args.distributed: - torch.cuda.set_device(local_rank) - logging.info(f"Process group initialized: {torch.distributed.is_initialized()}") - logging.info(f"Processes: {world_size}") - - try: - logging.info(f"MACE version: {mace.__version__}") - except AttributeError: - logging.info("Cannot find MACE version, please install MACE via pip") - logging.debug(f"Configuration: {args}") - - tools.set_default_dtype(args.default_dtype) - device = tools.init_device(args.device) - commit = print_git_commit() - model_foundation: Optional[torch.nn.Module] = None - foundation_model_avg_num_neighbors = 0 - if args.foundation_model is not None: - if args.foundation_model in ["small", "medium", "large"]: - logging.info( - f"Using foundation model mace-mp-0 {args.foundation_model} as initial checkpoint." - ) - calc = mace_mp( - model=args.foundation_model, - device=args.device, - default_dtype=args.default_dtype, - ) - model_foundation = calc.models[0] - elif args.foundation_model in ["small_off", "medium_off", "large_off"]: - model_type = args.foundation_model.split("_")[0] - logging.info( - f"Using foundation model mace-off-2023 {model_type} as initial checkpoint. ASL license." - ) - calc = mace_off( - model=model_type, - device=args.device, - default_dtype=args.default_dtype, - ) - model_foundation = calc.models[0] - else: - model_foundation = torch.load( - args.foundation_model, map_location=args.device - ) - logging.info( - f"Using foundation model {args.foundation_model} as initial checkpoint." - ) - args.r_max = model_foundation.r_max.item() - foundation_model_avg_num_neighbors = model_foundation.interactions[ - 0 - ].avg_num_neighbors - if ( - args.foundation_model not in ["small", "medium", "large"] - and args.pt_train_file is None - ): - logging.warning( - "Using multiheads finetuning with a foundation model that is not a Materials Project model, need to provied a path to a pretraining file with --pt_train_file." - ) - args.multiheads_finetuning = False - if args.multiheads_finetuning: - assert ( - args.E0s != "average" - ), "average atomic energies cannot be used for multiheads finetuning" - # check that the foundation model has a single head, if not, use the first head - if not args.force_mh_ft_lr: - logging.info( - "Multihead finetuning mode, setting learning rate to 0.0001 and EMA to True. To use a different learning rate, set --force_mh_ft_lr=True." - ) - args.lr = 0.0001 - args.ema = True - args.ema_decay = 0.99999 - logging.info( - "Using multiheads finetuning mode, setting learning rate to 0.0001 and EMA to True" - ) - if hasattr(model_foundation, "heads"): - if len(model_foundation.heads) > 1: - logging.warning( - "Mutlihead finetuning with models with more than one head is not supported, using the first head as foundation head." - ) - model_foundation = remove_pt_head( - model_foundation, args.foundation_head - ) - else: - args.multiheads_finetuning = False - - if args.heads is not None: - args.heads = ast.literal_eval(args.heads) - for _, head_dict in args.heads.items(): - # priority is global args < head property_key values < head info_keys+arrays_keys - head_keyspec = deepcopy(args.key_specification) - update_keyspec_from_kwargs(head_keyspec, head_dict) - head_keyspec.update( - info_keys=head_dict.get("info_keys", {}), - arrays_keys=head_dict.get("arrays_keys", {}), - ) - head_dict["key_specification"] = head_keyspec - else: - args.heads = prepare_default_head(args) - if args.multiheads_finetuning: - pt_keyspec = ( - args.heads["pt_head"]["key_specification"] - if "pt_head" in args.heads - else deepcopy(args.key_specification) - ) - args.heads["pt_head"] = prepare_pt_head( - args, pt_keyspec, foundation_model_avg_num_neighbors - ) - - logging.info("===========LOADING INPUT DATA===========") - heads = list(args.heads.keys()) - logging.info(f"Using heads: {heads}") - logging.info("Using the key specifications to parse data:") - for name, head_dict in args.heads.items(): - head_keyspec = head_dict["key_specification"] - logging.info(f"{name}: {head_keyspec}") - - head_configs: List[HeadConfig] = [] - for head, head_args in args.heads.items(): - logging.info(f"============= Processing head {head} ===========") - head_config = dict_head_to_dataclass(head_args, head, args) - - # Handle train_file and valid_file - normalize to lists - if hasattr(head_config, "train_file") and head_config.train_file is not None: - head_config.train_file = normalize_file_paths(head_config.train_file) - if hasattr(head_config, "valid_file") and head_config.valid_file is not None: - head_config.valid_file = normalize_file_paths(head_config.valid_file) - if hasattr(head_config, "test_file") and head_config.test_file is not None: - head_config.test_file = normalize_file_paths(head_config.test_file) - - if ( - head_config.statistics_file is not None - and head_config.head_name != "pt_head" - ): - with open(head_config.statistics_file, "r") as f: # pylint: disable=W1514 - statistics = json.load(f) - logging.info("Using statistics json file") - head_config.atomic_numbers = statistics["atomic_numbers"] - head_config.mean = statistics["mean"] - head_config.std = statistics["std"] - head_config.avg_num_neighbors = statistics["avg_num_neighbors"] - head_config.compute_avg_num_neighbors = False - if isinstance(statistics["atomic_energies"], str) and statistics[ - "atomic_energies" - ].endswith(".json"): - with open(statistics["atomic_energies"], "r", encoding="utf-8") as f: - atomic_energies = json.load(f) - head_config.E0s = atomic_energies - head_config.atomic_energies_dict = ast.literal_eval(atomic_energies) - else: - head_config.E0s = statistics["atomic_energies"] - head_config.atomic_energies_dict = ast.literal_eval( - statistics["atomic_energies"] - ) - if head_config.train_file == ["mp"]: - assert ( - head_config.head_name == "pt_head" - ), "Only pt_head should use mp as train_file" - logging.info( - "Using the full Materials Project data for replay. You can construct a different subset using `fine_tuning_select.py` script." - ) - collections = assemble_mp_data(args, head_config, tag) - head_config.collections = collections - elif any(check_path_ase_read(f) for f in head_config.train_file): - train_files_ase_list = [ - f for f in head_config.train_file if check_path_ase_read(f) - ] - valid_files_ase_list = None - test_files_ase_list = None - if head_config.valid_file: - valid_files_ase_list = [ - f for f in head_config.valid_file if check_path_ase_read(f) - ] - if head_config.test_file: - test_files_ase_list = [ - f for f in head_config.test_file if check_path_ase_read(f) - ] - config_type_weights = get_config_type_weights( - head_config.config_type_weights - ) - collections, atomic_energies_dict = get_dataset_from_xyz( - work_dir=args.work_dir, - train_path=train_files_ase_list, - valid_path=valid_files_ase_list, - valid_fraction=head_config.valid_fraction, - config_type_weights=config_type_weights, - test_path=test_files_ase_list, - seed=args.seed, - key_specification=head_config.key_specification, - head_name=head_config.head_name, - keep_isolated_atoms=head_config.keep_isolated_atoms, - ) - head_config.collections = SubsetCollection( - train=collections.train, - valid=collections.valid, - tests=collections.tests, - ) - head_config.atomic_energies_dict = atomic_energies_dict - logging.info( - f"Total number of configurations: train={len(collections.train)}, valid={len(collections.valid)}, " - f"tests=[{', '.join([name + ': ' + str(len(test_configs)) for name, test_configs in collections.tests])}]," - ) - head_configs.append(head_config) - - if all( - check_path_ase_read(head_config.train_file[0]) for head_config in head_configs - ): - size_collections_train = sum( - len(head_config.collections.train) for head_config in head_configs - ) - size_collections_valid = sum( - len(head_config.collections.valid) for head_config in head_configs - ) - if size_collections_train < args.batch_size: - logging.error( - f"Batch size ({args.batch_size}) is larger than the number of training data ({size_collections_train})" - ) - if size_collections_valid < args.valid_batch_size: - logging.warning( - f"Validation batch size ({args.valid_batch_size}) is larger than the number of validation data ({size_collections_valid})" - ) - - if args.multiheads_finetuning: - logging.info( - "==================Using multiheads finetuning mode==================" - ) - args.loss = "universal" - - all_ase_readable = all( - all(check_path_ase_read(f) for f in head_config.train_file) - for head_config in head_configs - ) - head_config_pt = filter(lambda x: x.head_name == "pt_head", head_configs) - head_config_pt = next(head_config_pt, None) - assert head_config_pt is not None, "Pretraining head not found" - if all_ase_readable: - ratio_pt_ft = size_collections_train / len(head_config_pt.collections.train) - if ratio_pt_ft < 0.1: - logging.warning( - f"Ratio of the number of configurations in the training set and the in the pt_train_file is {ratio_pt_ft}, " - f"increasing the number of configurations in the fine-tuning heads by {int(0.1 / ratio_pt_ft)}" - ) - for head_config in head_configs: - if head_config.head_name == "pt_head": - continue - head_config.collections.train += ( - head_config.collections.train * int(0.1 / ratio_pt_ft) - ) - logging.info( - f"Total number of configurations in pretraining: train={len(head_config_pt.collections.train)}, valid={len(head_config_pt.collections.valid)}" - ) - else: - logging.debug( - "Using LMDB/HDF5 datasets for pretraining or fine-tuning - skipping ratio check" - ) - - # Atomic number table - # yapf: disable - for head_config in head_configs: - if head_config.atomic_numbers is None: - assert all(check_path_ase_read(f) for f in head_config.train_file), "Must specify atomic_numbers when using .h5 or .aselmdb train_file input" - z_table_head = tools.get_atomic_number_table_from_zs( - z - for configs in (head_config.collections.train, head_config.collections.valid) - for config in configs - for z in config.atomic_numbers - ) - head_config.atomic_numbers = z_table_head.zs - head_config.z_table = z_table_head - else: - if head_config.statistics_file is None: - logging.info("Using atomic numbers from command line argument") - else: - logging.info("Using atomic numbers from statistics file") - zs_list = ast.literal_eval(head_config.atomic_numbers) - assert isinstance(zs_list, list) - z_table_head = tools.AtomicNumberTable(zs_list) - head_config.atomic_numbers = zs_list - head_config.z_table = z_table_head - # yapf: enable - all_atomic_numbers = set() - for head_config in head_configs: - all_atomic_numbers.update(head_config.atomic_numbers) - z_table = AtomicNumberTable(sorted(list(all_atomic_numbers))) - if args.foundation_model_elements and model_foundation: - z_table = AtomicNumberTable(sorted(model_foundation.atomic_numbers.tolist())) - logging.info(f"Atomic Numbers used: {z_table.zs}") - - # Atomic energies - atomic_energies_dict = {} - for head_config in head_configs: - if head_config.atomic_energies_dict is None or len(head_config.atomic_energies_dict) == 0: - assert head_config.E0s is not None, "Atomic energies must be provided" - if all(check_path_ase_read(f) for f in head_config.train_file) and head_config.E0s.lower() != "foundation": - atomic_energies_dict[head_config.head_name] = get_atomic_energies( - head_config.E0s, head_config.collections.train, head_config.z_table - ) - elif head_config.E0s.lower() == "foundation": - assert args.foundation_model is not None - z_table_foundation = AtomicNumberTable( - [int(z) for z in model_foundation.atomic_numbers] - ) - foundation_atomic_energies = model_foundation.atomic_energies_fn.atomic_energies - if foundation_atomic_energies.ndim > 1: - foundation_atomic_energies = foundation_atomic_energies.squeeze() - if foundation_atomic_energies.ndim == 2: - foundation_atomic_energies = foundation_atomic_energies[0] - logging.info("Foundation model has multiple heads, using the first head as foundation E0s.") - atomic_energies_dict[head_config.head_name] = { - z: foundation_atomic_energies[ - z_table_foundation.z_to_index(z) - ].item() - for z in z_table.zs - } - else: - atomic_energies_dict[head_config.head_name] = get_atomic_energies(head_config.E0s, None, head_config.z_table) - else: - atomic_energies_dict[head_config.head_name] = head_config.atomic_energies_dict - - # Atomic energies for multiheads finetuning - if args.multiheads_finetuning: - assert ( - model_foundation is not None - ), "Model foundation must be provided for multiheads finetuning" - z_table_foundation = AtomicNumberTable( - [int(z) for z in model_foundation.atomic_numbers] - ) - foundation_atomic_energies = model_foundation.atomic_energies_fn.atomic_energies - if foundation_atomic_energies.ndim > 1: - foundation_atomic_energies = foundation_atomic_energies.squeeze() - if foundation_atomic_energies.ndim == 2: - foundation_atomic_energies = foundation_atomic_energies[0] - logging.info("Foundation model has multiple heads, using the first head as foundation E0s.") - atomic_energies_dict["pt_head"] = { - z: foundation_atomic_energies[ - z_table_foundation.z_to_index(z) - ].item() - for z in z_table.zs - } - heads = sorted(heads, key=lambda x: -1000 if x == "pt_head" else 0) - # Padding atomic energies if keeping all elements of the foundation model - if args.foundation_model_elements and model_foundation: - atomic_energies_dict_padded = {} - for head_name, head_energies in atomic_energies_dict.items(): - energy_head_padded = {} - for z in z_table.zs: - energy_head_padded[z] = head_energies.get(z, 0.0) - atomic_energies_dict_padded[head_name] = energy_head_padded - atomic_energies_dict = atomic_energies_dict_padded - - if args.model == "AtomicDipolesMACE": - atomic_energies = None - dipole_only = True - args.compute_dipole = True - args.compute_energy = False - args.compute_forces = False - args.compute_virials = False - args.compute_stress = False - else: - dipole_only = False - if args.model == "EnergyDipolesMACE": - args.compute_dipole = True - args.compute_energy = True - args.compute_forces = True - args.compute_virials = False - args.compute_stress = False - else: - args.compute_energy = True - args.compute_dipole = False - # atomic_energies: np.ndarray = np.array( - # [atomic_energies_dict[z] for z in z_table.zs] - # ) - atomic_energies = dict_to_array(atomic_energies_dict, heads) - for head_config in head_configs: - try: - logging.info(f"Atomic Energies used (z: eV) for head {head_config.head_name}: " + "{" + ", ".join([f"{z}: {atomic_energies_dict[head_config.head_name][z]}" for z in head_config.z_table.zs]) + "}") - except KeyError as e: - raise KeyError(f"Atomic number {e} not found in atomic_energies_dict for head {head_config.head_name}, add E0s for this atomic number") from e - - # Load datasets for each head, supporting multiple files per head - valid_sets = {head: [] for head in heads} - train_sets = {head: [] for head in heads} - - for head_config in head_configs: - train_datasets = [] - - logging.info(f"Processing datasets for head '{head_config.head_name}'") - ase_files = [f for f in head_config.train_file if check_path_ase_read(f)] - non_ase_files = [f for f in head_config.train_file if not check_path_ase_read(f)] - - if ase_files: - dataset = load_dataset_for_path( - file_path=ase_files, - r_max=args.r_max, - z_table=z_table, - head_config=head_config, - heads=heads, - collection=head_config.collections.train, - ) - train_datasets.append(dataset) - logging.debug(f"Successfully loaded dataset from ASE files: {ase_files}") - - for file in non_ase_files: - dataset = load_dataset_for_path( - file_path=file, - r_max=args.r_max, - z_table=z_table, - head_config=head_config, - heads=heads, - ) - train_datasets.append(dataset) - logging.debug(f"Successfully loaded dataset from non-ASE file: {file}") - - if not train_datasets: - raise ValueError(f"No valid training datasets found for head {head_config.head_name}") - - train_sets[head_config.head_name] = combine_datasets(train_datasets, head_config.head_name) - - if head_config.valid_file: - valid_datasets = [] - - valid_ase_files = [f for f in head_config.valid_file if check_path_ase_read(f)] - valid_non_ase_files = [f for f in head_config.valid_file if not check_path_ase_read(f)] - - if valid_ase_files: - valid_dataset = load_dataset_for_path( - file_path=valid_ase_files, - r_max=args.r_max, - z_table=z_table, - head_config=head_config, - heads=heads, - collection=head_config.collections.valid, - ) - valid_datasets.append(valid_dataset) - logging.debug(f"Successfully loaded validation dataset from ASE files: {valid_ase_files}") - for valid_file in valid_non_ase_files: - valid_dataset = load_dataset_for_path( - file_path=valid_file, - r_max=args.r_max, - z_table=z_table, - head_config=head_config, - heads=heads, - ) - valid_datasets.append(valid_dataset) - logging.debug(f"Successfully loaded validation dataset from {valid_file}") - - # Combine validation datasets - if valid_datasets: - valid_sets[head_config.head_name] = combine_datasets(valid_datasets, f"{head_config.head_name}_valid") - logging.info(f"Combined validation datasets for {head_config.head_name}") - - # If no valid file is provided but collection exist, use the validation set from the collection - if head_config.valid_file is None and head_config.collections.valid: - valid_sets[head_config.head_name] = [ - data.AtomicData.from_config( - config, z_table=z_table, cutoff=args.r_max, heads=heads - ) - for config in head_config.collections.valid - ] - if not valid_sets[head_config.head_name]: - raise ValueError(f"No valid datasets found for head {head_config.head_name}, please provide a valid_file or a valid_fraction") - - # Create data loader for this head - if isinstance(train_sets[head_config.head_name], list): - dataset_size = len(train_sets[head_config.head_name]) - else: - dataset_size = len(train_sets[head_config.head_name]) - logging.info(f"Head '{head_config.head_name}' training dataset size: {dataset_size}") - - train_loader_head = torch_geometric.dataloader.DataLoader( - dataset=train_sets[head_config.head_name], - batch_size=args.batch_size, - shuffle=True, - drop_last=(not args.lbfgs), - pin_memory=args.pin_memory, - num_workers=args.num_workers, - generator=torch.Generator().manual_seed(args.seed), - ) - head_config.train_loader = train_loader_head - - # concatenate all the trainsets - train_set = ConcatDataset([train_sets[head] for head in heads]) - train_sampler, valid_sampler = None, None - if args.distributed: - train_sampler = torch.utils.data.distributed.DistributedSampler( - train_set, - num_replicas=world_size, - rank=rank, - shuffle=True, - drop_last=(not args.lbfgs), - seed=args.seed, - ) - valid_samplers = {} - for head, valid_set in valid_sets.items(): - valid_sampler = torch.utils.data.distributed.DistributedSampler( - valid_set, - num_replicas=world_size, - rank=rank, - shuffle=True, - drop_last=True, - seed=args.seed, - ) - valid_samplers[head] = valid_sampler - train_loader = torch_geometric.dataloader.DataLoader( - dataset=train_set, - batch_size=args.batch_size, - sampler=train_sampler, - shuffle=(train_sampler is None), - drop_last=(train_sampler is None and not args.lbfgs), - pin_memory=args.pin_memory, - num_workers=args.num_workers, - generator=torch.Generator().manual_seed(args.seed), - ) - valid_loaders = {heads[i]: None for i in range(len(heads))} - if not isinstance(valid_sets, dict): - valid_sets = {"Default": valid_sets} - for head, valid_set in valid_sets.items(): - valid_loaders[head] = torch_geometric.dataloader.DataLoader( - dataset=valid_set, - batch_size=args.valid_batch_size, - sampler=valid_samplers[head] if args.distributed else None, - shuffle=False, - drop_last=False, - pin_memory=args.pin_memory, - num_workers=args.num_workers, - generator=torch.Generator().manual_seed(args.seed), - ) - - loss_fn = get_loss_fn(args, dipole_only, args.compute_dipole) - args.avg_num_neighbors = get_avg_num_neighbors(head_configs, args, train_loader, device) - - # Model - model, output_args = configure_model(args, train_loader, atomic_energies, model_foundation, heads, z_table, head_configs) - model.to(device) - - logging.debug(model) - logging.info(f"Total number of parameters: {tools.count_parameters(model)}") - logging.info("") - logging.info("===========OPTIMIZER INFORMATION===========") - logging.info(f"Using {args.optimizer.upper()} as parameter optimizer") - logging.info(f"Batch size: {args.batch_size}") - if args.ema: - logging.info(f"Using Exponential Moving Average with decay: {args.ema_decay}") - logging.info( - f"Number of gradient updates: {int(args.max_num_epochs*len(train_set)/args.batch_size)}" - ) - logging.info(f"Learning rate: {args.lr}, weight decay: {args.weight_decay}") - logging.info(loss_fn) - - # Cueq - if args.enable_cueq: - logging.info("Converting model to CUEQ for accelerated training") - assert model.__class__.__name__ in ["MACE", "ScaleShiftMACE"] - model = run_e3nn_to_cueq(deepcopy(model), device=device) - # Optimizer - param_options = get_params_options(args, model) - optimizer: torch.optim.Optimizer - optimizer = get_optimizer(args, param_options) - if args.device == "xpu": - logging.info("Optimzing model and optimzier for XPU") - model, optimizer = ipex.optimize(model, optimizer=optimizer) - logger = tools.MetricsLogger( - directory=args.results_dir, tag=tag + "_train" - ) # pylint: disable=E1123 - - lr_scheduler = LRScheduler(optimizer, args) - - swa: Optional[tools.SWAContainer] = None - swas = [False] - if args.swa: - swa, swas = get_swa(args, model, optimizer, swas, dipole_only) - - checkpoint_handler = tools.CheckpointHandler( - directory=args.checkpoints_dir, - tag=tag, - keep=args.keep_checkpoints, - swa_start=args.start_swa, - ) - - start_epoch = 0 - restart_lbfgs = False - opt_start_epoch = None - if args.restart_latest: - try: - opt_start_epoch = checkpoint_handler.load_latest( - state=tools.CheckpointState(model, optimizer, lr_scheduler), - swa=True, - device=device, - ) - except Exception: # pylint: disable=W0703 - try: - opt_start_epoch = checkpoint_handler.load_latest( - state=tools.CheckpointState(model, optimizer, lr_scheduler), - swa=False, - device=device, - ) - except Exception: # pylint: disable=W0703 - restart_lbfgs = True - if opt_start_epoch is not None: - start_epoch = opt_start_epoch - - ema: Optional[ExponentialMovingAverage] = None - if args.ema: - ema = ExponentialMovingAverage(model.parameters(), decay=args.ema_decay) - else: - for group in optimizer.param_groups: - group["lr"] = args.lr - - if args.lbfgs: - logging.info("Switching optimizer to LBFGS") - optimizer = LBFGS(model.parameters(), - history_size=200, - max_iter=20, - line_search_fn="strong_wolfe") - if restart_lbfgs: - opt_start_epoch = checkpoint_handler.load_latest( - state=tools.CheckpointState(model, optimizer, lr_scheduler), - swa=False, - device=device, - ) - if opt_start_epoch is not None: - start_epoch = opt_start_epoch - - if args.wandb: - setup_wandb(args) - if args.distributed: - distributed_model = DDP(model, device_ids=[local_rank]) - else: - distributed_model = None - - - train_valid_data_loader = {} - for head_config in head_configs: - data_loader_name = "train_" + head_config.head_name - train_valid_data_loader[data_loader_name] = head_config.train_loader - for head, valid_loader in valid_loaders.items(): - data_load_name = "valid_" + head - train_valid_data_loader[data_load_name] = valid_loader - - if args.plot and args.plot_frequency > 0: - try: - plotter = TrainingPlotter( - results_dir=logger.path, - heads=heads, - table_type=args.error_table, - train_valid_data=train_valid_data_loader, - test_data={}, - output_args=output_args, - device=device, - plot_frequency=args.plot_frequency, - distributed=args.distributed, - swa_start=swa.start if swa else None - ) - except Exception as e: # pylint: disable=W0718 - logging.debug(f"Creating Plotter failed: {e}") - else: - plotter = None - - if args.dry_run: - logging.info("DRY RUN mode enabled. Stopping now.") - return - - - tools.train( - model=model, - loss_fn=loss_fn, - train_loader=train_loader, - valid_loaders=valid_loaders, - optimizer=optimizer, - lr_scheduler=lr_scheduler, - checkpoint_handler=checkpoint_handler, - eval_interval=args.eval_interval, - start_epoch=start_epoch, - max_num_epochs=args.max_num_epochs, - logger=logger, - patience=args.patience, - save_all_checkpoints=args.save_all_checkpoints, - output_args=output_args, - device=device, - swa=swa, - ema=ema, - max_grad_norm=args.clip_grad, - log_errors=args.error_table, - log_wandb=args.wandb, - distributed=args.distributed, - distributed_model=distributed_model, - plotter=plotter, - train_sampler=train_sampler, - rank=rank, - ) - - logging.info("") - logging.info("===========RESULTS===========") - - train_valid_data_loader = {} - for head_config in head_configs: - data_loader_name = "train_" + head_config.head_name - train_valid_data_loader[data_loader_name] = head_config.train_loader - for head, valid_loader in valid_loaders.items(): - data_load_name = "valid_" + head - train_valid_data_loader[data_load_name] = valid_loader - test_sets = {} - stop_first_test = False - test_data_loader = {} - if all( - head_config.test_file == head_configs[0].test_file - for head_config in head_configs - ) and head_configs[0].test_file is not None: - stop_first_test = True - if all( - head_config.test_dir == head_configs[0].test_dir - for head_config in head_configs - ) and head_configs[0].test_dir is not None: - stop_first_test = True - for head_config in head_configs: - if all(check_path_ase_read(f) for f in head_config.train_file): - for name, subset in head_config.collections.tests: - test_sets[name] = [ - data.AtomicData.from_config( - config, z_table=z_table, cutoff=args.r_max, heads=heads - ) - for config in subset - ] - if head_config.test_dir is not None: - if not args.multi_processed_test: - test_files = get_files_with_suffix(head_config.test_dir, "_test.h5") - for test_file in test_files: - name = os.path.splitext(os.path.basename(test_file))[0] - test_sets[name] = data.HDF5Dataset( - test_file, r_max=args.r_max, z_table=z_table, heads=heads, head=head_config.head_name - ) - else: - test_folders = glob(head_config.test_dir + "/*") - for folder in test_folders: - name = os.path.splitext(os.path.basename(test_file))[0] - test_sets[name] = data.dataset_from_sharded_hdf5( - folder, r_max=args.r_max, z_table=z_table, heads=heads, head=head_config.head_name - ) - for test_name, test_set in test_sets.items(): - test_sampler = None - if args.distributed: - test_sampler = torch.utils.data.distributed.DistributedSampler( - test_set, - num_replicas=world_size, - rank=rank, - shuffle=True, - drop_last=True, - seed=args.seed, - ) - try: - drop_last = test_set.drop_last - except AttributeError as e: # pylint: disable=W0612 - drop_last = False - test_loader = torch_geometric.dataloader.DataLoader( - test_set, - batch_size=args.valid_batch_size, - shuffle=(test_sampler is None), - drop_last=drop_last, - num_workers=args.num_workers, - pin_memory=args.pin_memory, - ) - test_data_loader[test_name] = test_loader - if stop_first_test: - break - - for swa_eval in swas: - epoch = checkpoint_handler.load_latest( - state=tools.CheckpointState(model, optimizer, lr_scheduler), - swa=swa_eval, - device=device, - ) - model.to(device) - if args.distributed: - distributed_model = DDP(model, device_ids=[local_rank]) - model_to_evaluate = model if not args.distributed else distributed_model - if swa_eval: - logging.info(f"Loaded Stage two model from epoch {epoch} for evaluation") - else: - logging.info(f"Loaded Stage one model from epoch {epoch} for evaluation") - - if rank == 0: - # Save entire model - if swa_eval: - model_path = Path(args.checkpoints_dir) / (tag + "_stagetwo.model") - else: - model_path = Path(args.checkpoints_dir) / (tag + ".model") - logging.info(f"Saving model to {model_path}") - model_to_save = deepcopy(model) - if args.enable_cueq: - print("RUNING CUEQ TO E3NN") - print("swa_eval", swa_eval) - model_to_save = run_cueq_to_e3nn(deepcopy(model), device=device) - if args.save_cpu: - model_to_save = model_to_save.to("cpu") - torch.save(model_to_save, model_path) - extra_files = { - "commit.txt": commit.encode("utf-8") if commit is not None else b"", - "config.yaml": json.dumps( - convert_to_json_format(extract_config_mace_model(model)) - ), - } - if swa_eval: - torch.save( - model_to_save, Path(args.model_dir) / (args.name + "_stagetwo.model") - ) - try: - path_complied = Path(args.model_dir) / ( - args.name + "_stagetwo_compiled.model" - ) - logging.info(f"Compiling model, saving metadata {path_complied}") - model_compiled = jit.compile(deepcopy(model_to_save)) - torch.jit.save( - model_compiled, - path_complied, - _extra_files=extra_files, - ) - except Exception as e: # pylint: disable=W0718 - pass - else: - torch.save(model_to_save, Path(args.model_dir) / (args.name + ".model")) - try: - path_complied = Path(args.model_dir) / ( - args.name + "_compiled.model" - ) - logging.info(f"Compiling model, saving metadata to {path_complied}") - model_compiled = jit.compile(deepcopy(model_to_save)) - torch.jit.save( - model_compiled, - path_complied, - _extra_files=extra_files, - ) - except Exception as e: # pylint: disable=W0718 - pass - - logging.info("Computing metrics for training, validation, and test sets") - for param in model.parameters(): - param.requires_grad = False - skip_heads = args.skip_evaluate_heads.split(",") if args.skip_evaluate_heads else [] - if skip_heads: - logging.info(f"Skipping evaluation for heads: {skip_heads}") - table_train_valid = create_error_table( - table_type=args.error_table, - all_data_loaders=train_valid_data_loader, - model=model_to_evaluate, - loss_fn=loss_fn, - output_args=output_args, - log_wandb=args.wandb, - device=device, - distributed=args.distributed, - skip_heads=skip_heads, - ) - logging.info("Error-table on TRAIN and VALID:\n" + str(table_train_valid)) - - if test_data_loader: - table_test = create_error_table( - table_type=args.error_table, - all_data_loaders=test_data_loader, - model=model_to_evaluate, - loss_fn=loss_fn, - output_args=output_args, - log_wandb=args.wandb, - device=device, - distributed=args.distributed, - ) - logging.info("Error-table on TEST:\n" + str(table_test)) - if args.plot: - try: - plotter = TrainingPlotter( - results_dir=logger.path, - heads=heads, - table_type=args.error_table, - train_valid_data=train_valid_data_loader, - test_data=test_data_loader, - output_args=output_args, - device=device, - plot_frequency=args.plot_frequency, - distributed=args.distributed, - swa_start=swa.start if swa else None - ) - plotter.plot(epoch, model_to_evaluate, rank) - except Exception as e: # pylint: disable=W0718 - logging.debug(f"Plotting failed: {e}") - - if args.distributed: - torch.distributed.barrier() - - logging.info("Done") - if args.distributed: - torch.distributed.destroy_process_group() - - -if __name__ == "__main__": - main() +########################################################################################### +# Training script for MACE +# Authors: Ilyes Batatia, Gregor Simm, David Kovacs +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +import ast +import glob +import json +import logging +import os +from copy import deepcopy +from pathlib import Path +from typing import List, Optional + +import torch.distributed +import torch.nn.functional +from e3nn.util import jit +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.optim import LBFGS +from torch.utils.data import ConcatDataset +from torch_ema import ExponentialMovingAverage + +import mace +from mace import data, tools +from mace.calculators.foundations_models import mace_mp, mace_off +from mace.cli.convert_cueq_e3nn import run as run_cueq_to_e3nn +from mace.cli.convert_e3nn_cueq import run as run_e3nn_to_cueq +from mace.cli.visualise_train import TrainingPlotter +from mace.data import KeySpecification, update_keyspec_from_kwargs +from mace.tools import torch_geometric +from mace.tools.model_script_utils import configure_model +from mace.tools.multihead_tools import ( + HeadConfig, + assemble_mp_data, + dict_head_to_dataclass, + prepare_default_head, + prepare_pt_head, +) +from mace.tools.run_train_utils import ( + combine_datasets, + load_dataset_for_path, + normalize_file_paths, +) +from mace.tools.scripts_utils import ( + LRScheduler, + SubsetCollection, + check_path_ase_read, + convert_to_json_format, + dict_to_array, + extract_config_mace_model, + get_atomic_energies, + get_avg_num_neighbors, + get_config_type_weights, + get_dataset_from_xyz, + get_files_with_suffix, + get_loss_fn, + get_optimizer, + get_params_options, + get_swa, + print_git_commit, + remove_pt_head, + setup_wandb, +) +from mace.tools.slurm_distributed import DistributedEnvironment +from mace.tools.tables_utils import create_error_table +from mace.tools.utils import AtomicNumberTable + + +def main() -> None: + """ + This script runs the training/fine tuning for mace + """ + args = tools.build_default_arg_parser().parse_args() + run(args) + + +def run(args) -> None: + """ + This script runs the training/fine tuning for mace + """ + tag = tools.get_tag(name=args.name, seed=args.seed) + args, input_log_messages = tools.check_args(args) + + # default keyspec to update using heads dictionary + args.key_specification = KeySpecification() + update_keyspec_from_kwargs(args.key_specification, vars(args)) + + if args.device == "xpu": + try: + import intel_extension_for_pytorch as ipex + except ImportError as e: + raise ImportError( + "Error: Intel extension for PyTorch not found, but XPU device was specified" + ) from e + if args.distributed: + try: + distr_env = DistributedEnvironment() + except Exception as e: # pylint: disable=W0703 + logging.error(f"Failed to initialize distributed environment: {e}") + return + world_size = distr_env.world_size + local_rank = distr_env.local_rank + rank = distr_env.rank + if rank == 0: + print(distr_env) + torch.distributed.init_process_group(backend="nccl") + else: + rank = int(0) + + # Setup + tools.set_seeds(args.seed) + tools.setup_logger(level=args.log_level, tag=tag, directory=args.log_dir, rank=rank) + logging.info("===========VERIFYING SETTINGS===========") + for message, loglevel in input_log_messages: + logging.log(level=loglevel, msg=message) + + if args.distributed: + torch.cuda.set_device(local_rank) + logging.info(f"Process group initialized: {torch.distributed.is_initialized()}") + logging.info(f"Processes: {world_size}") + + try: + logging.info(f"MACE version: {mace.__version__}") + except AttributeError: + logging.info("Cannot find MACE version, please install MACE via pip") + logging.debug(f"Configuration: {args}") + + tools.set_default_dtype(args.default_dtype) + device = tools.init_device(args.device) + commit = print_git_commit() + model_foundation: Optional[torch.nn.Module] = None + foundation_model_avg_num_neighbors = 0 + if args.foundation_model is not None: + if args.foundation_model in ["small", "medium", "large"]: + logging.info( + f"Using foundation model mace-mp-0 {args.foundation_model} as initial checkpoint." + ) + calc = mace_mp( + model=args.foundation_model, + device=args.device, + default_dtype=args.default_dtype, + ) + model_foundation = calc.models[0] + elif args.foundation_model in ["small_off", "medium_off", "large_off"]: + model_type = args.foundation_model.split("_")[0] + logging.info( + f"Using foundation model mace-off-2023 {model_type} as initial checkpoint. ASL license." + ) + calc = mace_off( + model=model_type, + device=args.device, + default_dtype=args.default_dtype, + ) + model_foundation = calc.models[0] + else: + model_foundation = torch.load( + args.foundation_model, map_location=args.device + ) + logging.info( + f"Using foundation model {args.foundation_model} as initial checkpoint." + ) + args.r_max = model_foundation.r_max.item() + foundation_model_avg_num_neighbors = model_foundation.interactions[ + 0 + ].avg_num_neighbors + if ( + args.foundation_model not in ["small", "medium", "large"] + and args.pt_train_file is None + ): + logging.warning( + "Using multiheads finetuning with a foundation model that is not a Materials Project model, need to provied a path to a pretraining file with --pt_train_file." + ) + args.multiheads_finetuning = False + if args.multiheads_finetuning: + assert ( + args.E0s != "average" + ), "average atomic energies cannot be used for multiheads finetuning" + # check that the foundation model has a single head, if not, use the first head + if not args.force_mh_ft_lr: + logging.info( + "Multihead finetuning mode, setting learning rate to 0.0001 and EMA to True. To use a different learning rate, set --force_mh_ft_lr=True." + ) + args.lr = 0.0001 + args.ema = True + args.ema_decay = 0.99999 + logging.info( + "Using multiheads finetuning mode, setting learning rate to 0.0001 and EMA to True" + ) + if hasattr(model_foundation, "heads"): + if len(model_foundation.heads) > 1: + logging.warning( + "Mutlihead finetuning with models with more than one head is not supported, using the first head as foundation head." + ) + model_foundation = remove_pt_head( + model_foundation, args.foundation_head + ) + else: + args.multiheads_finetuning = False + + if args.heads is not None: + args.heads = ast.literal_eval(args.heads) + for _, head_dict in args.heads.items(): + # priority is global args < head property_key values < head info_keys+arrays_keys + head_keyspec = deepcopy(args.key_specification) + update_keyspec_from_kwargs(head_keyspec, head_dict) + head_keyspec.update( + info_keys=head_dict.get("info_keys", {}), + arrays_keys=head_dict.get("arrays_keys", {}), + ) + head_dict["key_specification"] = head_keyspec + else: + args.heads = prepare_default_head(args) + if args.multiheads_finetuning: + pt_keyspec = ( + args.heads["pt_head"]["key_specification"] + if "pt_head" in args.heads + else deepcopy(args.key_specification) + ) + args.heads["pt_head"] = prepare_pt_head( + args, pt_keyspec, foundation_model_avg_num_neighbors + ) + + logging.info("===========LOADING INPUT DATA===========") + heads = list(args.heads.keys()) + logging.info(f"Using heads: {heads}") + logging.info("Using the key specifications to parse data:") + for name, head_dict in args.heads.items(): + head_keyspec = head_dict["key_specification"] + logging.info(f"{name}: {head_keyspec}") + + head_configs: List[HeadConfig] = [] + for head, head_args in args.heads.items(): + logging.info(f"============= Processing head {head} ===========") + head_config = dict_head_to_dataclass(head_args, head, args) + + # Handle train_file and valid_file - normalize to lists + if hasattr(head_config, "train_file") and head_config.train_file is not None: + head_config.train_file = normalize_file_paths(head_config.train_file) + if hasattr(head_config, "valid_file") and head_config.valid_file is not None: + head_config.valid_file = normalize_file_paths(head_config.valid_file) + if hasattr(head_config, "test_file") and head_config.test_file is not None: + head_config.test_file = normalize_file_paths(head_config.test_file) + + if ( + head_config.statistics_file is not None + and head_config.head_name != "pt_head" + ): + with open(head_config.statistics_file, "r") as f: # pylint: disable=W1514 + statistics = json.load(f) + logging.info("Using statistics json file") + head_config.atomic_numbers = statistics["atomic_numbers"] + head_config.mean = statistics["mean"] + head_config.std = statistics["std"] + head_config.avg_num_neighbors = statistics["avg_num_neighbors"] + head_config.compute_avg_num_neighbors = False + if isinstance(statistics["atomic_energies"], str) and statistics[ + "atomic_energies" + ].endswith(".json"): + with open(statistics["atomic_energies"], "r", encoding="utf-8") as f: + atomic_energies = json.load(f) + head_config.E0s = atomic_energies + head_config.atomic_energies_dict = ast.literal_eval(atomic_energies) + else: + head_config.E0s = statistics["atomic_energies"] + head_config.atomic_energies_dict = ast.literal_eval( + statistics["atomic_energies"] + ) + if head_config.train_file == ["mp"]: + assert ( + head_config.head_name == "pt_head" + ), "Only pt_head should use mp as train_file" + logging.info( + "Using the full Materials Project data for replay. You can construct a different subset using `fine_tuning_select.py` script." + ) + collections = assemble_mp_data(args, head_config, tag) + head_config.collections = collections + elif any(check_path_ase_read(f) for f in head_config.train_file): + train_files_ase_list = [ + f for f in head_config.train_file if check_path_ase_read(f) + ] + valid_files_ase_list = None + test_files_ase_list = None + if head_config.valid_file: + valid_files_ase_list = [ + f for f in head_config.valid_file if check_path_ase_read(f) + ] + if head_config.test_file: + test_files_ase_list = [ + f for f in head_config.test_file if check_path_ase_read(f) + ] + config_type_weights = get_config_type_weights( + head_config.config_type_weights + ) + collections, atomic_energies_dict = get_dataset_from_xyz( + work_dir=args.work_dir, + train_path=train_files_ase_list, + valid_path=valid_files_ase_list, + valid_fraction=head_config.valid_fraction, + config_type_weights=config_type_weights, + test_path=test_files_ase_list, + seed=args.seed, + key_specification=head_config.key_specification, + head_name=head_config.head_name, + keep_isolated_atoms=head_config.keep_isolated_atoms, + ) + head_config.collections = SubsetCollection( + train=collections.train, + valid=collections.valid, + tests=collections.tests, + ) + head_config.atomic_energies_dict = atomic_energies_dict + logging.info( + f"Total number of configurations: train={len(collections.train)}, valid={len(collections.valid)}, " + f"tests=[{', '.join([name + ': ' + str(len(test_configs)) for name, test_configs in collections.tests])}]," + ) + head_configs.append(head_config) + + if all( + check_path_ase_read(head_config.train_file[0]) for head_config in head_configs + ): + size_collections_train = sum( + len(head_config.collections.train) for head_config in head_configs + ) + size_collections_valid = sum( + len(head_config.collections.valid) for head_config in head_configs + ) + if size_collections_train < args.batch_size: + logging.error( + f"Batch size ({args.batch_size}) is larger than the number of training data ({size_collections_train})" + ) + if size_collections_valid < args.valid_batch_size: + logging.warning( + f"Validation batch size ({args.valid_batch_size}) is larger than the number of validation data ({size_collections_valid})" + ) + + if args.multiheads_finetuning: + logging.info( + "==================Using multiheads finetuning mode==================" + ) + args.loss = "universal" + + all_ase_readable = all( + all(check_path_ase_read(f) for f in head_config.train_file) + for head_config in head_configs + ) + head_config_pt = filter(lambda x: x.head_name == "pt_head", head_configs) + head_config_pt = next(head_config_pt, None) + assert head_config_pt is not None, "Pretraining head not found" + if all_ase_readable: + ratio_pt_ft = size_collections_train / len(head_config_pt.collections.train) + if ratio_pt_ft < 0.1: + logging.warning( + f"Ratio of the number of configurations in the training set and the in the pt_train_file is {ratio_pt_ft}, " + f"increasing the number of configurations in the fine-tuning heads by {int(0.1 / ratio_pt_ft)}" + ) + for head_config in head_configs: + if head_config.head_name == "pt_head": + continue + head_config.collections.train += ( + head_config.collections.train * int(0.1 / ratio_pt_ft) + ) + logging.info( + f"Total number of configurations in pretraining: train={len(head_config_pt.collections.train)}, valid={len(head_config_pt.collections.valid)}" + ) + else: + logging.debug( + "Using LMDB/HDF5 datasets for pretraining or fine-tuning - skipping ratio check" + ) + + # Atomic number table + # yapf: disable + for head_config in head_configs: + if head_config.atomic_numbers is None: + assert all(check_path_ase_read(f) for f in head_config.train_file), "Must specify atomic_numbers when using .h5 or .aselmdb train_file input" + z_table_head = tools.get_atomic_number_table_from_zs( + z + for configs in (head_config.collections.train, head_config.collections.valid) + for config in configs + for z in config.atomic_numbers + ) + head_config.atomic_numbers = z_table_head.zs + head_config.z_table = z_table_head + else: + if head_config.statistics_file is None: + logging.info("Using atomic numbers from command line argument") + else: + logging.info("Using atomic numbers from statistics file") + zs_list = ast.literal_eval(head_config.atomic_numbers) + assert isinstance(zs_list, list) + z_table_head = tools.AtomicNumberTable(zs_list) + head_config.atomic_numbers = zs_list + head_config.z_table = z_table_head + # yapf: enable + all_atomic_numbers = set() + for head_config in head_configs: + all_atomic_numbers.update(head_config.atomic_numbers) + z_table = AtomicNumberTable(sorted(list(all_atomic_numbers))) + if args.foundation_model_elements and model_foundation: + z_table = AtomicNumberTable(sorted(model_foundation.atomic_numbers.tolist())) + logging.info(f"Atomic Numbers used: {z_table.zs}") + + # Atomic energies + atomic_energies_dict = {} + for head_config in head_configs: + if head_config.atomic_energies_dict is None or len(head_config.atomic_energies_dict) == 0: + assert head_config.E0s is not None, "Atomic energies must be provided" + if all(check_path_ase_read(f) for f in head_config.train_file) and head_config.E0s.lower() != "foundation": + atomic_energies_dict[head_config.head_name] = get_atomic_energies( + head_config.E0s, head_config.collections.train, head_config.z_table + ) + elif head_config.E0s.lower() == "foundation": + assert args.foundation_model is not None + z_table_foundation = AtomicNumberTable( + [int(z) for z in model_foundation.atomic_numbers] + ) + foundation_atomic_energies = model_foundation.atomic_energies_fn.atomic_energies + if foundation_atomic_energies.ndim > 1: + foundation_atomic_energies = foundation_atomic_energies.squeeze() + if foundation_atomic_energies.ndim == 2: + foundation_atomic_energies = foundation_atomic_energies[0] + logging.info("Foundation model has multiple heads, using the first head as foundation E0s.") + atomic_energies_dict[head_config.head_name] = { + z: foundation_atomic_energies[ + z_table_foundation.z_to_index(z) + ].item() + for z in z_table.zs + } + else: + atomic_energies_dict[head_config.head_name] = get_atomic_energies(head_config.E0s, None, head_config.z_table) + else: + atomic_energies_dict[head_config.head_name] = head_config.atomic_energies_dict + + # Atomic energies for multiheads finetuning + if args.multiheads_finetuning: + assert ( + model_foundation is not None + ), "Model foundation must be provided for multiheads finetuning" + z_table_foundation = AtomicNumberTable( + [int(z) for z in model_foundation.atomic_numbers] + ) + foundation_atomic_energies = model_foundation.atomic_energies_fn.atomic_energies + if foundation_atomic_energies.ndim > 1: + foundation_atomic_energies = foundation_atomic_energies.squeeze() + if foundation_atomic_energies.ndim == 2: + foundation_atomic_energies = foundation_atomic_energies[0] + logging.info("Foundation model has multiple heads, using the first head as foundation E0s.") + atomic_energies_dict["pt_head"] = { + z: foundation_atomic_energies[ + z_table_foundation.z_to_index(z) + ].item() + for z in z_table.zs + } + heads = sorted(heads, key=lambda x: -1000 if x == "pt_head" else 0) + # Padding atomic energies if keeping all elements of the foundation model + if args.foundation_model_elements and model_foundation: + atomic_energies_dict_padded = {} + for head_name, head_energies in atomic_energies_dict.items(): + energy_head_padded = {} + for z in z_table.zs: + energy_head_padded[z] = head_energies.get(z, 0.0) + atomic_energies_dict_padded[head_name] = energy_head_padded + atomic_energies_dict = atomic_energies_dict_padded + + if args.model == "AtomicDipolesMACE": + atomic_energies = None + dipole_only = True + args.compute_dipole = True + args.compute_energy = False + args.compute_forces = False + args.compute_virials = False + args.compute_stress = False + else: + dipole_only = False + if args.model == "EnergyDipolesMACE": + args.compute_dipole = True + args.compute_energy = True + args.compute_forces = True + args.compute_virials = False + args.compute_stress = False + else: + args.compute_energy = True + args.compute_dipole = False + # atomic_energies: np.ndarray = np.array( + # [atomic_energies_dict[z] for z in z_table.zs] + # ) + atomic_energies = dict_to_array(atomic_energies_dict, heads) + for head_config in head_configs: + try: + logging.info(f"Atomic Energies used (z: eV) for head {head_config.head_name}: " + "{" + ", ".join([f"{z}: {atomic_energies_dict[head_config.head_name][z]}" for z in head_config.z_table.zs]) + "}") + except KeyError as e: + raise KeyError(f"Atomic number {e} not found in atomic_energies_dict for head {head_config.head_name}, add E0s for this atomic number") from e + + # Load datasets for each head, supporting multiple files per head + valid_sets = {head: [] for head in heads} + train_sets = {head: [] for head in heads} + + for head_config in head_configs: + train_datasets = [] + + logging.info(f"Processing datasets for head '{head_config.head_name}'") + ase_files = [f for f in head_config.train_file if check_path_ase_read(f)] + non_ase_files = [f for f in head_config.train_file if not check_path_ase_read(f)] + + if ase_files: + dataset = load_dataset_for_path( + file_path=ase_files, + r_max=args.r_max, + z_table=z_table, + head_config=head_config, + heads=heads, + collection=head_config.collections.train, + ) + train_datasets.append(dataset) + logging.debug(f"Successfully loaded dataset from ASE files: {ase_files}") + + for file in non_ase_files: + dataset = load_dataset_for_path( + file_path=file, + r_max=args.r_max, + z_table=z_table, + head_config=head_config, + heads=heads, + ) + train_datasets.append(dataset) + logging.debug(f"Successfully loaded dataset from non-ASE file: {file}") + + if not train_datasets: + raise ValueError(f"No valid training datasets found for head {head_config.head_name}") + + train_sets[head_config.head_name] = combine_datasets(train_datasets, head_config.head_name) + + if head_config.valid_file: + valid_datasets = [] + + valid_ase_files = [f for f in head_config.valid_file if check_path_ase_read(f)] + valid_non_ase_files = [f for f in head_config.valid_file if not check_path_ase_read(f)] + + if valid_ase_files: + valid_dataset = load_dataset_for_path( + file_path=valid_ase_files, + r_max=args.r_max, + z_table=z_table, + head_config=head_config, + heads=heads, + collection=head_config.collections.valid, + ) + valid_datasets.append(valid_dataset) + logging.debug(f"Successfully loaded validation dataset from ASE files: {valid_ase_files}") + for valid_file in valid_non_ase_files: + valid_dataset = load_dataset_for_path( + file_path=valid_file, + r_max=args.r_max, + z_table=z_table, + head_config=head_config, + heads=heads, + ) + valid_datasets.append(valid_dataset) + logging.debug(f"Successfully loaded validation dataset from {valid_file}") + + # Combine validation datasets + if valid_datasets: + valid_sets[head_config.head_name] = combine_datasets(valid_datasets, f"{head_config.head_name}_valid") + logging.info(f"Combined validation datasets for {head_config.head_name}") + + # If no valid file is provided but collection exist, use the validation set from the collection + if head_config.valid_file is None and head_config.collections.valid: + valid_sets[head_config.head_name] = [ + data.AtomicData.from_config( + config, z_table=z_table, cutoff=args.r_max, heads=heads + ) + for config in head_config.collections.valid + ] + if not valid_sets[head_config.head_name]: + raise ValueError(f"No valid datasets found for head {head_config.head_name}, please provide a valid_file or a valid_fraction") + + # Create data loader for this head + if isinstance(train_sets[head_config.head_name], list): + dataset_size = len(train_sets[head_config.head_name]) + else: + dataset_size = len(train_sets[head_config.head_name]) + logging.info(f"Head '{head_config.head_name}' training dataset size: {dataset_size}") + + train_loader_head = torch_geometric.dataloader.DataLoader( + dataset=train_sets[head_config.head_name], + batch_size=args.batch_size, + shuffle=True, + drop_last=(not args.lbfgs), + pin_memory=args.pin_memory, + num_workers=args.num_workers, + generator=torch.Generator().manual_seed(args.seed), + ) + head_config.train_loader = train_loader_head + + # concatenate all the trainsets + train_set = ConcatDataset([train_sets[head] for head in heads]) + train_sampler, valid_sampler = None, None + if args.distributed: + train_sampler = torch.utils.data.distributed.DistributedSampler( + train_set, + num_replicas=world_size, + rank=rank, + shuffle=True, + drop_last=(not args.lbfgs), + seed=args.seed, + ) + valid_samplers = {} + for head, valid_set in valid_sets.items(): + valid_sampler = torch.utils.data.distributed.DistributedSampler( + valid_set, + num_replicas=world_size, + rank=rank, + shuffle=True, + drop_last=True, + seed=args.seed, + ) + valid_samplers[head] = valid_sampler + train_loader = torch_geometric.dataloader.DataLoader( + dataset=train_set, + batch_size=args.batch_size, + sampler=train_sampler, + shuffle=(train_sampler is None), + drop_last=(train_sampler is None and not args.lbfgs), + pin_memory=args.pin_memory, + num_workers=args.num_workers, + generator=torch.Generator().manual_seed(args.seed), + ) + valid_loaders = {heads[i]: None for i in range(len(heads))} + if not isinstance(valid_sets, dict): + valid_sets = {"Default": valid_sets} + for head, valid_set in valid_sets.items(): + valid_loaders[head] = torch_geometric.dataloader.DataLoader( + dataset=valid_set, + batch_size=args.valid_batch_size, + sampler=valid_samplers[head] if args.distributed else None, + shuffle=False, + drop_last=False, + pin_memory=args.pin_memory, + num_workers=args.num_workers, + generator=torch.Generator().manual_seed(args.seed), + ) + + loss_fn = get_loss_fn(args, dipole_only, args.compute_dipole) + args.avg_num_neighbors = get_avg_num_neighbors(head_configs, args, train_loader, device) + + # Model + model, output_args = configure_model(args, train_loader, atomic_energies, model_foundation, heads, z_table, head_configs) + model.to(device) + + logging.debug(model) + logging.info(f"Total number of parameters: {tools.count_parameters(model)}") + logging.info("") + logging.info("===========OPTIMIZER INFORMATION===========") + logging.info(f"Using {args.optimizer.upper()} as parameter optimizer") + logging.info(f"Batch size: {args.batch_size}") + if args.ema: + logging.info(f"Using Exponential Moving Average with decay: {args.ema_decay}") + logging.info( + f"Number of gradient updates: {int(args.max_num_epochs*len(train_set)/args.batch_size)}" + ) + logging.info(f"Learning rate: {args.lr}, weight decay: {args.weight_decay}") + logging.info(loss_fn) + + # Cueq + if args.enable_cueq: + logging.info("Converting model to CUEQ for accelerated training") + assert model.__class__.__name__ in ["MACE", "ScaleShiftMACE"] + model = run_e3nn_to_cueq(deepcopy(model), device=device) + # Optimizer + param_options = get_params_options(args, model) + optimizer: torch.optim.Optimizer + optimizer = get_optimizer(args, param_options) + if args.device == "xpu": + logging.info("Optimzing model and optimzier for XPU") + model, optimizer = ipex.optimize(model, optimizer=optimizer) + logger = tools.MetricsLogger( + directory=args.results_dir, tag=tag + "_train" + ) # pylint: disable=E1123 + + lr_scheduler = LRScheduler(optimizer, args) + + swa: Optional[tools.SWAContainer] = None + swas = [False] + if args.swa: + swa, swas = get_swa(args, model, optimizer, swas, dipole_only) + + checkpoint_handler = tools.CheckpointHandler( + directory=args.checkpoints_dir, + tag=tag, + keep=args.keep_checkpoints, + swa_start=args.start_swa, + ) + + start_epoch = 0 + restart_lbfgs = False + opt_start_epoch = None + if args.restart_latest: + try: + opt_start_epoch = checkpoint_handler.load_latest( + state=tools.CheckpointState(model, optimizer, lr_scheduler), + swa=True, + device=device, + ) + except Exception: # pylint: disable=W0703 + try: + opt_start_epoch = checkpoint_handler.load_latest( + state=tools.CheckpointState(model, optimizer, lr_scheduler), + swa=False, + device=device, + ) + except Exception: # pylint: disable=W0703 + restart_lbfgs = True + if opt_start_epoch is not None: + start_epoch = opt_start_epoch + + ema: Optional[ExponentialMovingAverage] = None + if args.ema: + ema = ExponentialMovingAverage(model.parameters(), decay=args.ema_decay) + else: + for group in optimizer.param_groups: + group["lr"] = args.lr + + if args.lbfgs: + logging.info("Switching optimizer to LBFGS") + optimizer = LBFGS(model.parameters(), + history_size=200, + max_iter=20, + line_search_fn="strong_wolfe") + if restart_lbfgs: + opt_start_epoch = checkpoint_handler.load_latest( + state=tools.CheckpointState(model, optimizer, lr_scheduler), + swa=False, + device=device, + ) + if opt_start_epoch is not None: + start_epoch = opt_start_epoch + + if args.wandb: + setup_wandb(args) + if args.distributed: + distributed_model = DDP(model, device_ids=[local_rank]) + else: + distributed_model = None + + + train_valid_data_loader = {} + for head_config in head_configs: + data_loader_name = "train_" + head_config.head_name + train_valid_data_loader[data_loader_name] = head_config.train_loader + for head, valid_loader in valid_loaders.items(): + data_load_name = "valid_" + head + train_valid_data_loader[data_load_name] = valid_loader + + if args.plot and args.plot_frequency > 0: + try: + plotter = TrainingPlotter( + results_dir=logger.path, + heads=heads, + table_type=args.error_table, + train_valid_data=train_valid_data_loader, + test_data={}, + output_args=output_args, + device=device, + plot_frequency=args.plot_frequency, + distributed=args.distributed, + swa_start=swa.start if swa else None + ) + except Exception as e: # pylint: disable=W0718 + logging.debug(f"Creating Plotter failed: {e}") + else: + plotter = None + + if args.dry_run: + logging.info("DRY RUN mode enabled. Stopping now.") + return + + + tools.train( + model=model, + loss_fn=loss_fn, + train_loader=train_loader, + valid_loaders=valid_loaders, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + checkpoint_handler=checkpoint_handler, + eval_interval=args.eval_interval, + start_epoch=start_epoch, + max_num_epochs=args.max_num_epochs, + logger=logger, + patience=args.patience, + save_all_checkpoints=args.save_all_checkpoints, + output_args=output_args, + device=device, + swa=swa, + ema=ema, + max_grad_norm=args.clip_grad, + log_errors=args.error_table, + log_wandb=args.wandb, + distributed=args.distributed, + distributed_model=distributed_model, + plotter=plotter, + train_sampler=train_sampler, + rank=rank, + ) + + logging.info("") + logging.info("===========RESULTS===========") + + train_valid_data_loader = {} + for head_config in head_configs: + data_loader_name = "train_" + head_config.head_name + train_valid_data_loader[data_loader_name] = head_config.train_loader + for head, valid_loader in valid_loaders.items(): + data_load_name = "valid_" + head + train_valid_data_loader[data_load_name] = valid_loader + test_sets = {} + stop_first_test = False + test_data_loader = {} + if all( + head_config.test_file == head_configs[0].test_file + for head_config in head_configs + ) and head_configs[0].test_file is not None: + stop_first_test = True + if all( + head_config.test_dir == head_configs[0].test_dir + for head_config in head_configs + ) and head_configs[0].test_dir is not None: + stop_first_test = True + for head_config in head_configs: + if all(check_path_ase_read(f) for f in head_config.train_file): + for name, subset in head_config.collections.tests: + test_sets[name] = [ + data.AtomicData.from_config( + config, z_table=z_table, cutoff=args.r_max, heads=heads + ) + for config in subset + ] + if head_config.test_dir is not None: + if not args.multi_processed_test: + test_files = get_files_with_suffix(head_config.test_dir, "_test.h5") + for test_file in test_files: + name = os.path.splitext(os.path.basename(test_file))[0] + test_sets[name] = data.HDF5Dataset( + test_file, r_max=args.r_max, z_table=z_table, heads=heads, head=head_config.head_name + ) + else: + test_folders = glob(head_config.test_dir + "/*") + for folder in test_folders: + name = os.path.splitext(os.path.basename(test_file))[0] + test_sets[name] = data.dataset_from_sharded_hdf5( + folder, r_max=args.r_max, z_table=z_table, heads=heads, head=head_config.head_name + ) + for test_name, test_set in test_sets.items(): + test_sampler = None + if args.distributed: + test_sampler = torch.utils.data.distributed.DistributedSampler( + test_set, + num_replicas=world_size, + rank=rank, + shuffle=True, + drop_last=True, + seed=args.seed, + ) + try: + drop_last = test_set.drop_last + except AttributeError as e: # pylint: disable=W0612 + drop_last = False + test_loader = torch_geometric.dataloader.DataLoader( + test_set, + batch_size=args.valid_batch_size, + shuffle=(test_sampler is None), + drop_last=drop_last, + num_workers=args.num_workers, + pin_memory=args.pin_memory, + ) + test_data_loader[test_name] = test_loader + if stop_first_test: + break + + for swa_eval in swas: + epoch = checkpoint_handler.load_latest( + state=tools.CheckpointState(model, optimizer, lr_scheduler), + swa=swa_eval, + device=device, + ) + model.to(device) + if args.distributed: + distributed_model = DDP(model, device_ids=[local_rank]) + model_to_evaluate = model if not args.distributed else distributed_model + if swa_eval: + logging.info(f"Loaded Stage two model from epoch {epoch} for evaluation") + else: + logging.info(f"Loaded Stage one model from epoch {epoch} for evaluation") + + if rank == 0: + # Save entire model + if swa_eval: + model_path = Path(args.checkpoints_dir) / (tag + "_stagetwo.model") + else: + model_path = Path(args.checkpoints_dir) / (tag + ".model") + logging.info(f"Saving model to {model_path}") + model_to_save = deepcopy(model) + if args.enable_cueq: + print("RUNING CUEQ TO E3NN") + print("swa_eval", swa_eval) + model_to_save = run_cueq_to_e3nn(deepcopy(model), device=device) + if args.save_cpu: + model_to_save = model_to_save.to("cpu") + torch.save(model_to_save, model_path) + extra_files = { + "commit.txt": commit.encode("utf-8") if commit is not None else b"", + "config.yaml": json.dumps( + convert_to_json_format(extract_config_mace_model(model)) + ), + } + if swa_eval: + torch.save( + model_to_save, Path(args.model_dir) / (args.name + "_stagetwo.model") + ) + try: + path_complied = Path(args.model_dir) / ( + args.name + "_stagetwo_compiled.model" + ) + logging.info(f"Compiling model, saving metadata {path_complied}") + model_compiled = jit.compile(deepcopy(model_to_save)) + torch.jit.save( + model_compiled, + path_complied, + _extra_files=extra_files, + ) + except Exception as e: # pylint: disable=W0718 + pass + else: + torch.save(model_to_save, Path(args.model_dir) / (args.name + ".model")) + try: + path_complied = Path(args.model_dir) / ( + args.name + "_compiled.model" + ) + logging.info(f"Compiling model, saving metadata to {path_complied}") + model_compiled = jit.compile(deepcopy(model_to_save)) + torch.jit.save( + model_compiled, + path_complied, + _extra_files=extra_files, + ) + except Exception as e: # pylint: disable=W0718 + pass + + logging.info("Computing metrics for training, validation, and test sets") + for param in model.parameters(): + param.requires_grad = False + skip_heads = args.skip_evaluate_heads.split(",") if args.skip_evaluate_heads else [] + if skip_heads: + logging.info(f"Skipping evaluation for heads: {skip_heads}") + table_train_valid = create_error_table( + table_type=args.error_table, + all_data_loaders=train_valid_data_loader, + model=model_to_evaluate, + loss_fn=loss_fn, + output_args=output_args, + log_wandb=args.wandb, + device=device, + distributed=args.distributed, + skip_heads=skip_heads, + ) + logging.info("Error-table on TRAIN and VALID:\n" + str(table_train_valid)) + + if test_data_loader: + table_test = create_error_table( + table_type=args.error_table, + all_data_loaders=test_data_loader, + model=model_to_evaluate, + loss_fn=loss_fn, + output_args=output_args, + log_wandb=args.wandb, + device=device, + distributed=args.distributed, + ) + logging.info("Error-table on TEST:\n" + str(table_test)) + if args.plot: + try: + plotter = TrainingPlotter( + results_dir=logger.path, + heads=heads, + table_type=args.error_table, + train_valid_data=train_valid_data_loader, + test_data=test_data_loader, + output_args=output_args, + device=device, + plot_frequency=args.plot_frequency, + distributed=args.distributed, + swa_start=swa.start if swa else None + ) + plotter.plot(epoch, model_to_evaluate, rank) + except Exception as e: # pylint: disable=W0718 + logging.debug(f"Plotting failed: {e}") + + if args.distributed: + torch.distributed.barrier() + + logging.info("Done") + if args.distributed: + torch.distributed.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/mace-bench/3rdparty/mace/mace/cli/select_head.py b/mace-bench/3rdparty/mace/mace/cli/select_head.py index adc568d..305ab52 100644 --- a/mace-bench/3rdparty/mace/mace/cli/select_head.py +++ b/mace-bench/3rdparty/mace/mace/cli/select_head.py @@ -1,60 +1,60 @@ -from argparse import ArgumentParser - -import torch - -from mace.tools.scripts_utils import remove_pt_head - - -def main(): - parser = ArgumentParser() - grp = parser.add_mutually_exclusive_group() - grp.add_argument( - "--head_name", - "-n", - help="name of the head to extract", - default=None, - ) - grp.add_argument( - "--list_heads", - "-l", - action="store_true", - help="list names of the heads", - ) - parser.add_argument( - "--target_device", - "-d", - help="target device, defaults to model's current device", - ) - parser.add_argument( - "--output_file", - "-o", - help="name for output model, defaults to model.head_name, followed by .target_device if specified", - ) - parser.add_argument("model_file", help="input model file path") - args = parser.parse_args() - - model = torch.load(args.model_file, map_location=args.target_device) - torch.set_default_dtype(next(model.parameters()).dtype) - - if args.list_heads: - print("Available heads:") - print("\n".join([" " + h for h in model.heads])) - else: - - if args.output_file is None: - args.output_file = ( - args.model_file - + "." - + args.head_name - + ("." + args.target_device if (args.target_device is not None) else "") - ) - - model_single = remove_pt_head(model, args.head_name) - if args.target_device is not None: - target_device = str(next(model.parameters()).device) - model_single.to(target_device) - torch.save(model_single, args.output_file) - - -if __name__ == "__main__": - main() +from argparse import ArgumentParser + +import torch + +from mace.tools.scripts_utils import remove_pt_head + + +def main(): + parser = ArgumentParser() + grp = parser.add_mutually_exclusive_group() + grp.add_argument( + "--head_name", + "-n", + help="name of the head to extract", + default=None, + ) + grp.add_argument( + "--list_heads", + "-l", + action="store_true", + help="list names of the heads", + ) + parser.add_argument( + "--target_device", + "-d", + help="target device, defaults to model's current device", + ) + parser.add_argument( + "--output_file", + "-o", + help="name for output model, defaults to model.head_name, followed by .target_device if specified", + ) + parser.add_argument("model_file", help="input model file path") + args = parser.parse_args() + + model = torch.load(args.model_file, map_location=args.target_device) + torch.set_default_dtype(next(model.parameters()).dtype) + + if args.list_heads: + print("Available heads:") + print("\n".join([" " + h for h in model.heads])) + else: + + if args.output_file is None: + args.output_file = ( + args.model_file + + "." + + args.head_name + + ("." + args.target_device if (args.target_device is not None) else "") + ) + + model_single = remove_pt_head(model, args.head_name) + if args.target_device is not None: + target_device = str(next(model.parameters()).device) + model_single.to(target_device) + torch.save(model_single, args.output_file) + + +if __name__ == "__main__": + main() diff --git a/mace-bench/3rdparty/mace/mace/cli/visualise_train.py b/mace-bench/3rdparty/mace/mace/cli/visualise_train.py index 679d863..c380c70 100644 --- a/mace-bench/3rdparty/mace/mace/cli/visualise_train.py +++ b/mace-bench/3rdparty/mace/mace/cli/visualise_train.py @@ -1,640 +1,640 @@ -import json -import logging -from typing import Dict, List, Optional - -import matplotlib.pyplot as plt -import numpy as np -import pandas as pd -import torch -import torch.distributed -from torchmetrics import Metric - -plt.rcParams.update({"font.size": 8}) -mpl_logger = logging.getLogger("matplotlib") -mpl_logger.setLevel(logging.WARNING) # Only show WARNING and above - -colors = [ - "#1f77b4", # muted blue - "#d62728", # brick red - "#7f7f7f", # middle gray - "#2ca02c", # cooked asparagus green - "#ff7f0e", # safety orange - "#9467bd", # muted purple - "#8c564b", # chestnut brown - "#e377c2", # raspberry yogurt pink - "#bcbd22", # curry yellow-green - "#17becf", # blue-teal -] - -error_type = { - "TotalRMSE": ( - [("rmse_e", "RMSE E [meV]"), ("rmse_f", "RMSE F [meV / A]")], - [("energy", "Energy per atom [eV]"), ("force", "Force [eV / A]")], - ), - "PerAtomRMSE": ( - [("rmse_e_per_atom", "RMSE E/atom [meV]"), ("rmse_f", "RMSE F [meV / A]")], - [("energy", "Energy per atom [eV]"), ("force", "Force [eV / A]")], - ), - "PerAtomRMSEstressvirials": ( - [ - ("rmse_e_per_atom", "RMSE E/atom [meV]"), - ("rmse_f", "RMSE F [meV / A]"), - ("rmse_stress", "RMSE Stress [meV / A^3]"), - ], - [ - ("energy", "Energy per atom [eV]"), - ("force", "Force [eV / A]"), - ("stress", "Stress [eV / A^3]"), - ], - ), - "PerAtomMAEstressvirials": ( - [ - ("mae_e_per_atom", "MAE E/atom [meV]"), - ("mae_f", "MAE F [meV / A]"), - ("mae_stress", "MAE Stress [meV / A^3]"), - ], - [ - ("energy", "Energy per atom [eV]"), - ("force", "Force [eV / A]"), - ("stress", "Stress [eV / A^3]"), - ], - ), - "TotalMAE": ( - [("mae_e", "MAE E [meV]"), ("mae_f", "MAE F [meV / A]")], - [("energy", "Energy per atom [eV]"), ("force", "Force [eV / A]")], - ), - "PerAtomMAE": ( - [("mae_e_per_atom", "MAE E/atom [meV]"), ("mae_f", "MAE F [meV / A]")], - [("energy", "Energy per atom [eV]"), ("force", "Force [eV / A]")], - ), - "DipoleRMSE": ( - [ - ("rmse_mu_per_atom", "RMSE MU/atom [mDebye]"), - ("rel_rmse_f", "Relative MU RMSE [%]"), - ], - [("dipole", "Dipole per atom [Debye]")], - ), - "DipoleMAE": ( - [("mae_mu", "MAE MU [mDebye]"), ("rel_mae_f", "Relative MU MAE [%]")], - [("dipole", "Dipole per atom [Debye]")], - ), - "EnergyDipoleRMSE": ( - [ - ("rmse_e_per_atom", "RMSE E/atom [meV]"), - ("rmse_f", "RMSE F [meV / A]"), - ("rmse_mu_per_atom", "RMSE MU/atom [mDebye]"), - ], - [ - ("energy", "Energy per atom [eV]"), - ("force", "Force [eV / A]"), - ("dipole", "Dipole per atom [Debye]"), - ], - ), -} - - -class TrainingPlotter: - def __init__( - self, - results_dir: str, - heads: List[str], - table_type: str, - train_valid_data: Dict, - test_data: Dict, - output_args: str, - device: str, - plot_frequency: int, - distributed: bool = False, - swa_start: Optional[int] = None, - ): - self.results_dir = results_dir - self.heads = heads - self.table_type = table_type - self.train_valid_data = train_valid_data - self.test_data = test_data - self.output_args = output_args - self.device = device - self.plot_frequency = plot_frequency - self.distributed = distributed - self.swa_start = swa_start - - def plot(self, model_epoch: str, model: torch.nn.Module, rank: int) -> None: - - # All ranks process data through model_inference - train_valid_dict = model_inference( - self.train_valid_data, - model, - self.output_args, - self.device, - self.distributed, - ) - test_dict = model_inference( - self.test_data, model, self.output_args, self.device, self.distributed - ) - - # Only rank 0 creates and saves plots - if rank != 0: - return - - data = pd.DataFrame( - results for results in parse_training_results(self.results_dir) - ) - labels, quantities = error_type[self.table_type] - - for head in self.heads: - fig = plt.figure(layout="constrained", figsize=(10, 6)) - fig.suptitle( - f"Model loaded from epoch {model_epoch} ({head} head)", fontsize=16 - ) - - subfigs = fig.subfigures(2, 1, height_ratios=[1, 1], hspace=0.05) - axsTop = subfigs[0].subplots(1, 2, sharey=False) - axsBottom = subfigs[1].subplots(1, len(quantities), sharey=False) - - plot_epoch_dependence(axsTop, data, head, model_epoch, labels) - - # Use the pre-computed results for plotting - plot_inference_from_results( - axsBottom, train_valid_dict, test_dict, head, quantities - ) - - if self.swa_start is not None: - # Add vertical lines to both axes - for ax in axsTop: - ax.axvline( - self.swa_start, - color="black", - linestyle="dashed", - linewidth=1, - alpha=0.6, - label="Stage Two Starts", - ) - stage = "stage_two" if self.swa_start < model_epoch else "stage_one" - else: - stage = "stage_one" - axsTop[0].legend(loc="best") - # Save the figure using the appropriate stage in the filename - filename = f"{self.results_dir[:-4]}_{head}_{stage}.png" - - fig.savefig(filename, dpi=300, bbox_inches="tight") - plt.close(fig) - - -def parse_training_results(path: str) -> List[dict]: - results = [] - with open(path, mode="r", encoding="utf-8") as f: - for line in f: - try: - d = json.loads(line.strip()) # Ensure it's valid JSON - results.append(d) - except json.JSONDecodeError: - print( - f"Skipping invalid line: {line.strip()}" - ) # Handle non-JSON lines gracefully - return results - - -def plot_epoch_dependence( - axes: np.ndarray, data: pd.DataFrame, head: str, model_epoch: str, labels: List[str] -) -> None: - - valid_data = ( - data[data["mode"] == "eval"] - .groupby(["mode", "epoch", "head"]) - .agg(["mean", "std"]) - .reset_index() - ) - valid_data = valid_data[valid_data["head"] == head] - train_data = ( - data[data["mode"] == "opt"] - .groupby(["mode", "epoch"]) - .agg(["mean", "std"]) - .reset_index() - ) - - # ---- Plot loss ---- - ax = axes[0] - ax.plot( - train_data["epoch"], train_data["loss"]["mean"], color=colors[1], linewidth=1 - ) - ax.set_ylabel("Training Loss", color=colors[1]) - ax.set_yscale("log") - - ax2 = ax.twinx() - ax2.plot( - valid_data["epoch"], valid_data["loss"]["mean"], color=colors[0], linewidth=1 - ) - ax2.set_ylabel("Validation Loss", color=colors[0]) - ax2.set_yscale("log") - - ax.axvline( - model_epoch, - color="black", - linestyle="solid", - linewidth=1, - alpha=0.8, - label="Loaded Model", - ) - ax.set_xlabel("Epoch") - ax.grid(True, linestyle="--", alpha=0.5) - - # ---- Plot selected keys ---- - ax = axes[1] - twin_axes = [] - for i, label in enumerate(labels): - color = colors[(i + 3)] - key, axis_label = label - if i == 0: - main_ax = ax - else: - main_ax = ax.twinx() - main_ax.spines.right.set_position(("outward", 60 * (i - 1))) - twin_axes.append(main_ax) - - main_ax.plot( - valid_data["epoch"], - valid_data[key]["mean"] * 1e3, - color=color, - label=label, - linewidth=1, - ) - main_ax.set_yscale("log") - main_ax.set_ylabel(axis_label, color=color) - main_ax.tick_params(axis="y", colors=color) - ax.axvline( - model_epoch, - color="black", - linestyle="solid", - linewidth=1, - alpha=0.8, - label="Loaded Model", - ) - ax.set_xlabel("Epoch") - ax.grid(True, linestyle="--", alpha=0.5) - - -# INFERENCE========= - - -def plot_inference_from_results( - axes: np.ndarray, - train_valid_dict: dict, - test_dict: dict, - head: str, - quantities: List[str], -) -> None: - - for ax, quantity in zip(axes, quantities): - key, label = quantity - - # Store legend handles to avoid duplicates - legend_labels = {} - - # Plot train/valid data (each entry keeps its own name) - for name, result in train_valid_dict.items(): - if "train" in name: - fixed_color_train_valid = colors[1] - marker = "x" - else: - fixed_color_train_valid = colors[0] - marker = "+" - if head not in name: - continue - - # Initialize scatter to None - scatter = None - - if key == "energy" and "energy" in result: - scatter = ax.scatter( - result["energy"]["reference_per_atom"], - result["energy"]["predicted_per_atom"], - marker=marker, - color=fixed_color_train_valid, - label=name, - ) - - elif key == "force" and "forces" in result: - scatter = ax.scatter( - result["forces"]["reference"], - result["forces"]["predicted"], - marker=marker, - color=fixed_color_train_valid, - label=name, - ) - - elif key == "stress" and "stress" in result: - scatter = ax.scatter( - result["stress"]["reference"], - result["stress"]["predicted"], - marker=marker, - color=fixed_color_train_valid, - label=name, - ) - - elif key == "virials" and "virials" in result: - scatter = ax.scatter( - result["virials"]["reference_per_atom"], - result["virials"]["predicted_per_atom"], - marker=marker, - color=fixed_color_train_valid, - label=name, - ) - - elif key == "dipole" and "dipole" in result: - scatter = ax.scatter( - result["dipole"]["reference_per_atom"], - result["dipole"]["predicted_per_atom"], - marker=marker, - color=fixed_color_train_valid, - label=name, - ) - - # Add each train/valid dataset's name to the legend if scatter was assigned - if scatter is not None: - legend_labels[name] = scatter - - fixed_color_test = colors[2] # Color for test dataset - - # Plot test data (single legend entry) - for name, result in test_dict.items(): - # Initialize scatter to None to avoid possibly used before assignment - scatter = None - - if key == "energy" and "energy" in result: - scatter = ax.scatter( - result["energy"]["reference_per_atom"], - result["energy"]["predicted_per_atom"], - marker="o", - color=fixed_color_test, - label="Test", - ) - - elif key == "force" and "forces" in result: - scatter = ax.scatter( - result["forces"]["reference"], - result["forces"]["predicted"], - marker="o", - color=fixed_color_test, - label="Test", - ) - - elif key == "stress" and "stress" in result: - scatter = ax.scatter( - result["stress"]["reference"], - result["stress"]["predicted"], - marker="o", - color=fixed_color_test, - label="Test", - ) - - elif key == "virials" and "virials" in result: - scatter = ax.scatter( - result["virials"]["reference_per_atom"], - result["virials"]["predicted_per_atom"], - marker="o", - color=fixed_color_test, - label="Test", - ) - - elif key == "dipole" and "dipole" in result: - scatter = ax.scatter( - result["dipole"]["reference_per_atom"], - result["dipole"]["predicted_per_atom"], - marker="o", - color=fixed_color_test, - label="Test", - ) - - # Only add to legend_labels if scatter was assigned - if scatter is not None: - legend_labels["Test"] = scatter - - # Add diagonal line for guide - min_val = min(ax.get_xlim()[0], ax.get_ylim()[0]) - max_val = max(ax.get_xlim()[1], ax.get_ylim()[1]) - ax.plot( - [min_val, max_val], - [min_val, max_val], - linestyle="--", - color="black", - alpha=0.7, - ) - - # Set legend with unique entries (Test + individual train/valid names) - if legend_labels: - ax.legend( - handles=legend_labels.values(), labels=legend_labels.keys(), loc="best" - ) - ax.set_xlabel(f"Reference {label}") - ax.set_ylabel(f"MACE {label}") - ax.grid(True, linestyle="--", alpha=0.5) - - -def model_inference( - all_data_loaders: dict, - model: torch.nn.Module, - output_args: Dict[str, bool], - device: str, - distributed: bool = False, -): - - for param in model.parameters(): - param.requires_grad = False - - results_dict = {} - - for name in all_data_loaders: - data_loader = all_data_loaders[name] - logging.debug(f"Running inference on {name} dataset") - scatter_metric = InferenceMetric().to(device) - - for batch in data_loader: - batch = batch.to(device) - batch_dict = batch.to_dict() - output = model( - batch_dict, - training=False, - compute_force=output_args.get("forces", False), - compute_virials=output_args.get("virials", False), - compute_stress=output_args.get("stress", False), - ) - - results = scatter_metric(batch, output) - - if distributed: - torch.distributed.barrier() - - results = scatter_metric.compute() - results_dict[name] = results - scatter_metric.reset() - - del data_loader - - for param in model.parameters(): - param.requires_grad = True - - return results_dict - - -def to_numpy(tensor: torch.Tensor) -> np.ndarray: - return tensor.cpu().detach().numpy() - - -class InferenceMetric(Metric): - """Metric class for collecting reference and predicted values for scatterplot visualization.""" - - def __init__(self): - super().__init__() - # Raw values - self.add_state("ref_energies", default=[], dist_reduce_fx="cat") - self.add_state("pred_energies", default=[], dist_reduce_fx="cat") - self.add_state("ref_forces", default=[], dist_reduce_fx="cat") - self.add_state("pred_forces", default=[], dist_reduce_fx="cat") - self.add_state("ref_stress", default=[], dist_reduce_fx="cat") - self.add_state("pred_stress", default=[], dist_reduce_fx="cat") - self.add_state("ref_virials", default=[], dist_reduce_fx="cat") - self.add_state("pred_virials", default=[], dist_reduce_fx="cat") - self.add_state("ref_dipole", default=[], dist_reduce_fx="cat") - self.add_state("pred_dipole", default=[], dist_reduce_fx="cat") - - # Per-atom normalized values - self.add_state("ref_energies_per_atom", default=[], dist_reduce_fx="cat") - self.add_state("pred_energies_per_atom", default=[], dist_reduce_fx="cat") - self.add_state("ref_virials_per_atom", default=[], dist_reduce_fx="cat") - self.add_state("pred_virials_per_atom", default=[], dist_reduce_fx="cat") - self.add_state("ref_dipole_per_atom", default=[], dist_reduce_fx="cat") - self.add_state("pred_dipole_per_atom", default=[], dist_reduce_fx="cat") - - # Store atom counts for each configuration - self.add_state("atom_counts", default=[], dist_reduce_fx="cat") - - # Counters - self.add_state("n_energy", default=torch.tensor(0.0), dist_reduce_fx="sum") - self.add_state("n_forces", default=torch.tensor(0.0), dist_reduce_fx="sum") - self.add_state("n_stress", default=torch.tensor(0.0), dist_reduce_fx="sum") - self.add_state("n_virials", default=torch.tensor(0.0), dist_reduce_fx="sum") - self.add_state("n_dipole", default=torch.tensor(0.0), dist_reduce_fx="sum") - - def update(self, batch, output): # pylint: disable=arguments-differ - """Update metric states with new batch data.""" - # Calculate number of atoms per configuration - atoms_per_config = batch.ptr[1:] - batch.ptr[:-1] - self.atom_counts.append(atoms_per_config) - - # Energy - if output.get("energy") is not None and batch.energy is not None: - self.n_energy += 1.0 - self.ref_energies.append(batch.energy) - self.pred_energies.append(output["energy"]) - # Per-atom normalization - self.ref_energies_per_atom.append(batch.energy / atoms_per_config) - self.pred_energies_per_atom.append(output["energy"] / atoms_per_config) - - # Forces - if output.get("forces") is not None and batch.forces is not None: - self.n_forces += 1.0 - self.ref_forces.append(batch.forces) - self.pred_forces.append(output["forces"]) - - # Stress - if output.get("stress") is not None and batch.stress is not None: - self.n_stress += 1.0 - self.ref_stress.append(batch.stress) - self.pred_stress.append(output["stress"]) - - # Virials - if output.get("virials") is not None and batch.virials is not None: - self.n_virials += 1.0 - self.ref_virials.append(batch.virials) - self.pred_virials.append(output["virials"]) - # Per-atom normalization - atoms_per_config_3d = atoms_per_config.view(-1, 1, 1) - self.ref_virials_per_atom.append(batch.virials / atoms_per_config_3d) - self.pred_virials_per_atom.append(output["virials"] / atoms_per_config_3d) - - # Dipole - if output.get("dipole") is not None and batch.dipole is not None: - self.n_dipole += 1.0 - self.ref_dipole.append(batch.dipole) - self.pred_dipole.append(output["dipole"]) - atoms_per_config_3d = atoms_per_config.view(-1, 1) - self.ref_dipole_per_atom.append(batch.dipole / atoms_per_config_3d) - self.pred_dipole_per_atom.append(output["dipole"] / atoms_per_config_3d) - - def _process_data(self, ref_list, pred_list): - # Handle different possible states of ref_list and pred_list in distributed mode - - # Check if this is a list type object - if isinstance(ref_list, (list, tuple)): - if len(ref_list) == 0: - return None, None - ref = torch.cat(ref_list).reshape(-1) - pred = torch.cat(pred_list).reshape(-1) - # Handle case where ref_list is already a tensor (happens after reset in distributed mode) - elif isinstance(ref_list, torch.Tensor): - ref = ref_list.reshape(-1) - pred = pred_list.reshape(-1) - # Handle other possible types - else: - return None, None - return to_numpy(ref), to_numpy(pred) - - def compute(self): - """Compute final results for scatterplot.""" - results = {} - - # Process energies - if self.n_energy: - ref_e, pred_e = self._process_data(self.ref_energies, self.pred_energies) - ref_e_pa, pred_e_pa = self._process_data( - self.ref_energies_per_atom, self.pred_energies_per_atom - ) - results["energy"] = { - "reference": ref_e, - "predicted": pred_e, - "reference_per_atom": ref_e_pa, - "predicted_per_atom": pred_e_pa, - } - - # Process forces - if self.n_forces: - ref_f, pred_f = self._process_data(self.ref_forces, self.pred_forces) - results["forces"] = { - "reference": ref_f, - "predicted": pred_f, - } - - # Process stress - if self.n_stress: - ref_s, pred_s = self._process_data(self.ref_stress, self.pred_stress) - results["stress"] = { - "reference": ref_s, - "predicted": pred_s, - } - - # Process virials - if self.n_virials: - ref_v, pred_v = self._process_data(self.ref_virials, self.pred_virials) - ref_v_pa, pred_v_pa = self._process_data( - self.ref_virials_per_atom, self.pred_virials_per_atom - ) - results["virials"] = { - "reference": ref_v, - "predicted": pred_v, - "reference_per_atom": ref_v_pa, - "predicted_per_atom": pred_v_pa, - } - - # Process dipoles - if self.n_dipole: - ref_d, pred_d = self._process_data(self.ref_dipole, self.pred_dipole) - ref_d_pa, pred_d_pa = self._process_data( - self.ref_dipole_per_atom, self.pred_dipole_per_atom - ) - results["dipole"] = { - "reference": ref_d, - "predicted": pred_d, - "reference_per_atom": ref_d_pa, - "predicted_per_atom": pred_d_pa, - } - return results +import json +import logging +from typing import Dict, List, Optional + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import torch +import torch.distributed +from torchmetrics import Metric + +plt.rcParams.update({"font.size": 8}) +mpl_logger = logging.getLogger("matplotlib") +mpl_logger.setLevel(logging.WARNING) # Only show WARNING and above + +colors = [ + "#1f77b4", # muted blue + "#d62728", # brick red + "#7f7f7f", # middle gray + "#2ca02c", # cooked asparagus green + "#ff7f0e", # safety orange + "#9467bd", # muted purple + "#8c564b", # chestnut brown + "#e377c2", # raspberry yogurt pink + "#bcbd22", # curry yellow-green + "#17becf", # blue-teal +] + +error_type = { + "TotalRMSE": ( + [("rmse_e", "RMSE E [meV]"), ("rmse_f", "RMSE F [meV / A]")], + [("energy", "Energy per atom [eV]"), ("force", "Force [eV / A]")], + ), + "PerAtomRMSE": ( + [("rmse_e_per_atom", "RMSE E/atom [meV]"), ("rmse_f", "RMSE F [meV / A]")], + [("energy", "Energy per atom [eV]"), ("force", "Force [eV / A]")], + ), + "PerAtomRMSEstressvirials": ( + [ + ("rmse_e_per_atom", "RMSE E/atom [meV]"), + ("rmse_f", "RMSE F [meV / A]"), + ("rmse_stress", "RMSE Stress [meV / A^3]"), + ], + [ + ("energy", "Energy per atom [eV]"), + ("force", "Force [eV / A]"), + ("stress", "Stress [eV / A^3]"), + ], + ), + "PerAtomMAEstressvirials": ( + [ + ("mae_e_per_atom", "MAE E/atom [meV]"), + ("mae_f", "MAE F [meV / A]"), + ("mae_stress", "MAE Stress [meV / A^3]"), + ], + [ + ("energy", "Energy per atom [eV]"), + ("force", "Force [eV / A]"), + ("stress", "Stress [eV / A^3]"), + ], + ), + "TotalMAE": ( + [("mae_e", "MAE E [meV]"), ("mae_f", "MAE F [meV / A]")], + [("energy", "Energy per atom [eV]"), ("force", "Force [eV / A]")], + ), + "PerAtomMAE": ( + [("mae_e_per_atom", "MAE E/atom [meV]"), ("mae_f", "MAE F [meV / A]")], + [("energy", "Energy per atom [eV]"), ("force", "Force [eV / A]")], + ), + "DipoleRMSE": ( + [ + ("rmse_mu_per_atom", "RMSE MU/atom [mDebye]"), + ("rel_rmse_f", "Relative MU RMSE [%]"), + ], + [("dipole", "Dipole per atom [Debye]")], + ), + "DipoleMAE": ( + [("mae_mu", "MAE MU [mDebye]"), ("rel_mae_f", "Relative MU MAE [%]")], + [("dipole", "Dipole per atom [Debye]")], + ), + "EnergyDipoleRMSE": ( + [ + ("rmse_e_per_atom", "RMSE E/atom [meV]"), + ("rmse_f", "RMSE F [meV / A]"), + ("rmse_mu_per_atom", "RMSE MU/atom [mDebye]"), + ], + [ + ("energy", "Energy per atom [eV]"), + ("force", "Force [eV / A]"), + ("dipole", "Dipole per atom [Debye]"), + ], + ), +} + + +class TrainingPlotter: + def __init__( + self, + results_dir: str, + heads: List[str], + table_type: str, + train_valid_data: Dict, + test_data: Dict, + output_args: str, + device: str, + plot_frequency: int, + distributed: bool = False, + swa_start: Optional[int] = None, + ): + self.results_dir = results_dir + self.heads = heads + self.table_type = table_type + self.train_valid_data = train_valid_data + self.test_data = test_data + self.output_args = output_args + self.device = device + self.plot_frequency = plot_frequency + self.distributed = distributed + self.swa_start = swa_start + + def plot(self, model_epoch: str, model: torch.nn.Module, rank: int) -> None: + + # All ranks process data through model_inference + train_valid_dict = model_inference( + self.train_valid_data, + model, + self.output_args, + self.device, + self.distributed, + ) + test_dict = model_inference( + self.test_data, model, self.output_args, self.device, self.distributed + ) + + # Only rank 0 creates and saves plots + if rank != 0: + return + + data = pd.DataFrame( + results for results in parse_training_results(self.results_dir) + ) + labels, quantities = error_type[self.table_type] + + for head in self.heads: + fig = plt.figure(layout="constrained", figsize=(10, 6)) + fig.suptitle( + f"Model loaded from epoch {model_epoch} ({head} head)", fontsize=16 + ) + + subfigs = fig.subfigures(2, 1, height_ratios=[1, 1], hspace=0.05) + axsTop = subfigs[0].subplots(1, 2, sharey=False) + axsBottom = subfigs[1].subplots(1, len(quantities), sharey=False) + + plot_epoch_dependence(axsTop, data, head, model_epoch, labels) + + # Use the pre-computed results for plotting + plot_inference_from_results( + axsBottom, train_valid_dict, test_dict, head, quantities + ) + + if self.swa_start is not None: + # Add vertical lines to both axes + for ax in axsTop: + ax.axvline( + self.swa_start, + color="black", + linestyle="dashed", + linewidth=1, + alpha=0.6, + label="Stage Two Starts", + ) + stage = "stage_two" if self.swa_start < model_epoch else "stage_one" + else: + stage = "stage_one" + axsTop[0].legend(loc="best") + # Save the figure using the appropriate stage in the filename + filename = f"{self.results_dir[:-4]}_{head}_{stage}.png" + + fig.savefig(filename, dpi=300, bbox_inches="tight") + plt.close(fig) + + +def parse_training_results(path: str) -> List[dict]: + results = [] + with open(path, mode="r", encoding="utf-8") as f: + for line in f: + try: + d = json.loads(line.strip()) # Ensure it's valid JSON + results.append(d) + except json.JSONDecodeError: + print( + f"Skipping invalid line: {line.strip()}" + ) # Handle non-JSON lines gracefully + return results + + +def plot_epoch_dependence( + axes: np.ndarray, data: pd.DataFrame, head: str, model_epoch: str, labels: List[str] +) -> None: + + valid_data = ( + data[data["mode"] == "eval"] + .groupby(["mode", "epoch", "head"]) + .agg(["mean", "std"]) + .reset_index() + ) + valid_data = valid_data[valid_data["head"] == head] + train_data = ( + data[data["mode"] == "opt"] + .groupby(["mode", "epoch"]) + .agg(["mean", "std"]) + .reset_index() + ) + + # ---- Plot loss ---- + ax = axes[0] + ax.plot( + train_data["epoch"], train_data["loss"]["mean"], color=colors[1], linewidth=1 + ) + ax.set_ylabel("Training Loss", color=colors[1]) + ax.set_yscale("log") + + ax2 = ax.twinx() + ax2.plot( + valid_data["epoch"], valid_data["loss"]["mean"], color=colors[0], linewidth=1 + ) + ax2.set_ylabel("Validation Loss", color=colors[0]) + ax2.set_yscale("log") + + ax.axvline( + model_epoch, + color="black", + linestyle="solid", + linewidth=1, + alpha=0.8, + label="Loaded Model", + ) + ax.set_xlabel("Epoch") + ax.grid(True, linestyle="--", alpha=0.5) + + # ---- Plot selected keys ---- + ax = axes[1] + twin_axes = [] + for i, label in enumerate(labels): + color = colors[(i + 3)] + key, axis_label = label + if i == 0: + main_ax = ax + else: + main_ax = ax.twinx() + main_ax.spines.right.set_position(("outward", 60 * (i - 1))) + twin_axes.append(main_ax) + + main_ax.plot( + valid_data["epoch"], + valid_data[key]["mean"] * 1e3, + color=color, + label=label, + linewidth=1, + ) + main_ax.set_yscale("log") + main_ax.set_ylabel(axis_label, color=color) + main_ax.tick_params(axis="y", colors=color) + ax.axvline( + model_epoch, + color="black", + linestyle="solid", + linewidth=1, + alpha=0.8, + label="Loaded Model", + ) + ax.set_xlabel("Epoch") + ax.grid(True, linestyle="--", alpha=0.5) + + +# INFERENCE========= + + +def plot_inference_from_results( + axes: np.ndarray, + train_valid_dict: dict, + test_dict: dict, + head: str, + quantities: List[str], +) -> None: + + for ax, quantity in zip(axes, quantities): + key, label = quantity + + # Store legend handles to avoid duplicates + legend_labels = {} + + # Plot train/valid data (each entry keeps its own name) + for name, result in train_valid_dict.items(): + if "train" in name: + fixed_color_train_valid = colors[1] + marker = "x" + else: + fixed_color_train_valid = colors[0] + marker = "+" + if head not in name: + continue + + # Initialize scatter to None + scatter = None + + if key == "energy" and "energy" in result: + scatter = ax.scatter( + result["energy"]["reference_per_atom"], + result["energy"]["predicted_per_atom"], + marker=marker, + color=fixed_color_train_valid, + label=name, + ) + + elif key == "force" and "forces" in result: + scatter = ax.scatter( + result["forces"]["reference"], + result["forces"]["predicted"], + marker=marker, + color=fixed_color_train_valid, + label=name, + ) + + elif key == "stress" and "stress" in result: + scatter = ax.scatter( + result["stress"]["reference"], + result["stress"]["predicted"], + marker=marker, + color=fixed_color_train_valid, + label=name, + ) + + elif key == "virials" and "virials" in result: + scatter = ax.scatter( + result["virials"]["reference_per_atom"], + result["virials"]["predicted_per_atom"], + marker=marker, + color=fixed_color_train_valid, + label=name, + ) + + elif key == "dipole" and "dipole" in result: + scatter = ax.scatter( + result["dipole"]["reference_per_atom"], + result["dipole"]["predicted_per_atom"], + marker=marker, + color=fixed_color_train_valid, + label=name, + ) + + # Add each train/valid dataset's name to the legend if scatter was assigned + if scatter is not None: + legend_labels[name] = scatter + + fixed_color_test = colors[2] # Color for test dataset + + # Plot test data (single legend entry) + for name, result in test_dict.items(): + # Initialize scatter to None to avoid possibly used before assignment + scatter = None + + if key == "energy" and "energy" in result: + scatter = ax.scatter( + result["energy"]["reference_per_atom"], + result["energy"]["predicted_per_atom"], + marker="o", + color=fixed_color_test, + label="Test", + ) + + elif key == "force" and "forces" in result: + scatter = ax.scatter( + result["forces"]["reference"], + result["forces"]["predicted"], + marker="o", + color=fixed_color_test, + label="Test", + ) + + elif key == "stress" and "stress" in result: + scatter = ax.scatter( + result["stress"]["reference"], + result["stress"]["predicted"], + marker="o", + color=fixed_color_test, + label="Test", + ) + + elif key == "virials" and "virials" in result: + scatter = ax.scatter( + result["virials"]["reference_per_atom"], + result["virials"]["predicted_per_atom"], + marker="o", + color=fixed_color_test, + label="Test", + ) + + elif key == "dipole" and "dipole" in result: + scatter = ax.scatter( + result["dipole"]["reference_per_atom"], + result["dipole"]["predicted_per_atom"], + marker="o", + color=fixed_color_test, + label="Test", + ) + + # Only add to legend_labels if scatter was assigned + if scatter is not None: + legend_labels["Test"] = scatter + + # Add diagonal line for guide + min_val = min(ax.get_xlim()[0], ax.get_ylim()[0]) + max_val = max(ax.get_xlim()[1], ax.get_ylim()[1]) + ax.plot( + [min_val, max_val], + [min_val, max_val], + linestyle="--", + color="black", + alpha=0.7, + ) + + # Set legend with unique entries (Test + individual train/valid names) + if legend_labels: + ax.legend( + handles=legend_labels.values(), labels=legend_labels.keys(), loc="best" + ) + ax.set_xlabel(f"Reference {label}") + ax.set_ylabel(f"MACE {label}") + ax.grid(True, linestyle="--", alpha=0.5) + + +def model_inference( + all_data_loaders: dict, + model: torch.nn.Module, + output_args: Dict[str, bool], + device: str, + distributed: bool = False, +): + + for param in model.parameters(): + param.requires_grad = False + + results_dict = {} + + for name in all_data_loaders: + data_loader = all_data_loaders[name] + logging.debug(f"Running inference on {name} dataset") + scatter_metric = InferenceMetric().to(device) + + for batch in data_loader: + batch = batch.to(device) + batch_dict = batch.to_dict() + output = model( + batch_dict, + training=False, + compute_force=output_args.get("forces", False), + compute_virials=output_args.get("virials", False), + compute_stress=output_args.get("stress", False), + ) + + results = scatter_metric(batch, output) + + if distributed: + torch.distributed.barrier() + + results = scatter_metric.compute() + results_dict[name] = results + scatter_metric.reset() + + del data_loader + + for param in model.parameters(): + param.requires_grad = True + + return results_dict + + +def to_numpy(tensor: torch.Tensor) -> np.ndarray: + return tensor.cpu().detach().numpy() + + +class InferenceMetric(Metric): + """Metric class for collecting reference and predicted values for scatterplot visualization.""" + + def __init__(self): + super().__init__() + # Raw values + self.add_state("ref_energies", default=[], dist_reduce_fx="cat") + self.add_state("pred_energies", default=[], dist_reduce_fx="cat") + self.add_state("ref_forces", default=[], dist_reduce_fx="cat") + self.add_state("pred_forces", default=[], dist_reduce_fx="cat") + self.add_state("ref_stress", default=[], dist_reduce_fx="cat") + self.add_state("pred_stress", default=[], dist_reduce_fx="cat") + self.add_state("ref_virials", default=[], dist_reduce_fx="cat") + self.add_state("pred_virials", default=[], dist_reduce_fx="cat") + self.add_state("ref_dipole", default=[], dist_reduce_fx="cat") + self.add_state("pred_dipole", default=[], dist_reduce_fx="cat") + + # Per-atom normalized values + self.add_state("ref_energies_per_atom", default=[], dist_reduce_fx="cat") + self.add_state("pred_energies_per_atom", default=[], dist_reduce_fx="cat") + self.add_state("ref_virials_per_atom", default=[], dist_reduce_fx="cat") + self.add_state("pred_virials_per_atom", default=[], dist_reduce_fx="cat") + self.add_state("ref_dipole_per_atom", default=[], dist_reduce_fx="cat") + self.add_state("pred_dipole_per_atom", default=[], dist_reduce_fx="cat") + + # Store atom counts for each configuration + self.add_state("atom_counts", default=[], dist_reduce_fx="cat") + + # Counters + self.add_state("n_energy", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("n_forces", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("n_stress", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("n_virials", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("n_dipole", default=torch.tensor(0.0), dist_reduce_fx="sum") + + def update(self, batch, output): # pylint: disable=arguments-differ + """Update metric states with new batch data.""" + # Calculate number of atoms per configuration + atoms_per_config = batch.ptr[1:] - batch.ptr[:-1] + self.atom_counts.append(atoms_per_config) + + # Energy + if output.get("energy") is not None and batch.energy is not None: + self.n_energy += 1.0 + self.ref_energies.append(batch.energy) + self.pred_energies.append(output["energy"]) + # Per-atom normalization + self.ref_energies_per_atom.append(batch.energy / atoms_per_config) + self.pred_energies_per_atom.append(output["energy"] / atoms_per_config) + + # Forces + if output.get("forces") is not None and batch.forces is not None: + self.n_forces += 1.0 + self.ref_forces.append(batch.forces) + self.pred_forces.append(output["forces"]) + + # Stress + if output.get("stress") is not None and batch.stress is not None: + self.n_stress += 1.0 + self.ref_stress.append(batch.stress) + self.pred_stress.append(output["stress"]) + + # Virials + if output.get("virials") is not None and batch.virials is not None: + self.n_virials += 1.0 + self.ref_virials.append(batch.virials) + self.pred_virials.append(output["virials"]) + # Per-atom normalization + atoms_per_config_3d = atoms_per_config.view(-1, 1, 1) + self.ref_virials_per_atom.append(batch.virials / atoms_per_config_3d) + self.pred_virials_per_atom.append(output["virials"] / atoms_per_config_3d) + + # Dipole + if output.get("dipole") is not None and batch.dipole is not None: + self.n_dipole += 1.0 + self.ref_dipole.append(batch.dipole) + self.pred_dipole.append(output["dipole"]) + atoms_per_config_3d = atoms_per_config.view(-1, 1) + self.ref_dipole_per_atom.append(batch.dipole / atoms_per_config_3d) + self.pred_dipole_per_atom.append(output["dipole"] / atoms_per_config_3d) + + def _process_data(self, ref_list, pred_list): + # Handle different possible states of ref_list and pred_list in distributed mode + + # Check if this is a list type object + if isinstance(ref_list, (list, tuple)): + if len(ref_list) == 0: + return None, None + ref = torch.cat(ref_list).reshape(-1) + pred = torch.cat(pred_list).reshape(-1) + # Handle case where ref_list is already a tensor (happens after reset in distributed mode) + elif isinstance(ref_list, torch.Tensor): + ref = ref_list.reshape(-1) + pred = pred_list.reshape(-1) + # Handle other possible types + else: + return None, None + return to_numpy(ref), to_numpy(pred) + + def compute(self): + """Compute final results for scatterplot.""" + results = {} + + # Process energies + if self.n_energy: + ref_e, pred_e = self._process_data(self.ref_energies, self.pred_energies) + ref_e_pa, pred_e_pa = self._process_data( + self.ref_energies_per_atom, self.pred_energies_per_atom + ) + results["energy"] = { + "reference": ref_e, + "predicted": pred_e, + "reference_per_atom": ref_e_pa, + "predicted_per_atom": pred_e_pa, + } + + # Process forces + if self.n_forces: + ref_f, pred_f = self._process_data(self.ref_forces, self.pred_forces) + results["forces"] = { + "reference": ref_f, + "predicted": pred_f, + } + + # Process stress + if self.n_stress: + ref_s, pred_s = self._process_data(self.ref_stress, self.pred_stress) + results["stress"] = { + "reference": ref_s, + "predicted": pred_s, + } + + # Process virials + if self.n_virials: + ref_v, pred_v = self._process_data(self.ref_virials, self.pred_virials) + ref_v_pa, pred_v_pa = self._process_data( + self.ref_virials_per_atom, self.pred_virials_per_atom + ) + results["virials"] = { + "reference": ref_v, + "predicted": pred_v, + "reference_per_atom": ref_v_pa, + "predicted_per_atom": pred_v_pa, + } + + # Process dipoles + if self.n_dipole: + ref_d, pred_d = self._process_data(self.ref_dipole, self.pred_dipole) + ref_d_pa, pred_d_pa = self._process_data( + self.ref_dipole_per_atom, self.pred_dipole_per_atom + ) + results["dipole"] = { + "reference": ref_d, + "predicted": pred_d, + "reference_per_atom": ref_d_pa, + "predicted_per_atom": pred_d_pa, + } + return results diff --git a/mace-bench/3rdparty/mace/mace/data/__init__.py b/mace-bench/3rdparty/mace/mace/data/__init__.py index ad58cca..8629cf5 100644 --- a/mace-bench/3rdparty/mace/mace/data/__init__.py +++ b/mace-bench/3rdparty/mace/mace/data/__init__.py @@ -1,40 +1,40 @@ -from .atomic_data import AtomicData -from .hdf5_dataset import HDF5Dataset, dataset_from_sharded_hdf5 -from .lmdb_dataset import LMDBDataset -from .neighborhood import get_neighborhood -from .utils import ( - Configuration, - Configurations, - KeySpecification, - compute_average_E0s, - config_from_atoms, - config_from_atoms_list, - load_from_xyz, - random_train_valid_split, - save_AtomicData_to_HDF5, - save_configurations_as_HDF5, - save_dataset_as_HDF5, - test_config_types, - update_keyspec_from_kwargs, -) - -__all__ = [ - "get_neighborhood", - "Configuration", - "Configurations", - "random_train_valid_split", - "load_from_xyz", - "test_config_types", - "config_from_atoms", - "config_from_atoms_list", - "AtomicData", - "compute_average_E0s", - "save_dataset_as_HDF5", - "HDF5Dataset", - "dataset_from_sharded_hdf5", - "save_AtomicData_to_HDF5", - "save_configurations_as_HDF5", - "KeySpecification", - "update_keyspec_from_kwargs", - "LMDBDataset", -] +from .atomic_data import AtomicData +from .hdf5_dataset import HDF5Dataset, dataset_from_sharded_hdf5 +from .lmdb_dataset import LMDBDataset +from .neighborhood import get_neighborhood +from .utils import ( + Configuration, + Configurations, + KeySpecification, + compute_average_E0s, + config_from_atoms, + config_from_atoms_list, + load_from_xyz, + random_train_valid_split, + save_AtomicData_to_HDF5, + save_configurations_as_HDF5, + save_dataset_as_HDF5, + test_config_types, + update_keyspec_from_kwargs, +) + +__all__ = [ + "get_neighborhood", + "Configuration", + "Configurations", + "random_train_valid_split", + "load_from_xyz", + "test_config_types", + "config_from_atoms", + "config_from_atoms_list", + "AtomicData", + "compute_average_E0s", + "save_dataset_as_HDF5", + "HDF5Dataset", + "dataset_from_sharded_hdf5", + "save_AtomicData_to_HDF5", + "save_configurations_as_HDF5", + "KeySpecification", + "update_keyspec_from_kwargs", + "LMDBDataset", +] diff --git a/mace-bench/3rdparty/mace/mace/data/__pycache__/__init__.cpython-310.pyc b/mace-bench/3rdparty/mace/mace/data/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 13866301413816fc28ea594c20e7a3a84a5072f6..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 934 zcmZXS%We}f6hLS4ZZeZjlD>d1$O3|dx~mY-f{=I!sS*osEO+c=+}h)j?G&csYuI(c zlArOG6`uerd_65GVxk<~d+qU@Yg?u1F!Im0@7Kkzb`<@9s?9HEg+P$E;5$}t+tZL}>jl*tL2 z$SIo28Jfu*v?F)XuG~X=akPqP?`w?qY5z2u4}Qj8VII5TbE_oJPlzS+mK(i3d36Y} zVb!NfF!J7Ms4Q*Nsh z%Z$Z0Mz*e)arf7ipT+P6tBr54F}nCf^vcYqy5~1u_W=XI7?1*{fDs@8z-7^JpY#Oa z?^C2ZfNekq*aFM|U61$k12lpi;yDJ(umVcpz(VfNbh8RdgrsMG>+rry?jB(9bltV( z!rajfpA1$T(_kO^*r)~P&=L5#KE^}{jISfT559ok9~Xs^Y$3^2LDV@XKFZm{N~?30 z+eKB|LY4UV^g~lWV`W|}o@iPTZR=2L0KG0S<|VfnKdNdsIFgF41$zp2+`wigQ5Glu IcN8btAFq`NtpET3 diff --git a/mace-bench/3rdparty/mace/mace/data/__pycache__/__init__.cpython-313.pyc b/mace-bench/3rdparty/mace/mace/data/__pycache__/__init__.cpython-313.pyc deleted file mode 100644 index fd358bc00d63ef15eabe4c0ce1cf835fee7c4b0f..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 956 zcmZXS%Zd~+6oymP*Xio&`&_&r1LC4ZZ9!%shQoU@Qmcvxig!)%t%B_s5h0i81P?5yIOSLWZ31rhe6-S&SU9h0qNaj*i|^n2KQ!YWm`#Zm?m{HfS024VnhNf&M%Q zJcFJ=*`Q+3G3XkYp9TYis>bW#j`GY6@sJpowt{M!2eYw$({&ZqOlHx+-+8lNH7`D( z^{Faz$rEwGu07JaT06DwZ69lE*0b%%_)|4PXp#`}1)b>-sbCamGd7yhlbB|QoN5)a z+c{*1ER;oXz|t@t-GeBnAQyJ7faz^S2v4~rCVev}-? jIYrmMqDQCb#woga?zSBF9If!nsGqxi&;3w?t9rv9i|r5@ diff --git a/mace-bench/3rdparty/mace/mace/data/__pycache__/atomic_data.cpython-310.pyc b/mace-bench/3rdparty/mace/mace/data/__pycache__/atomic_data.cpython-310.pyc deleted file mode 100644 index 79291dc9ce890a0d9d182b87be237bc8800505da..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 5568 zcmb_gU5p!76`p&2Jsyw$f7ZL3q-oL?TJw`Y3#f>wfKmb~XsMJGsD+q%eaCBWZO``3 zIO(p&s>B9`mms88JoO=4QKU%ak%zvcN<8*;)F(hfQ41|U1=^$y-?=lM_0m+aNbM2K%wt>$te|3HE)lVA6?`W|2$zkvn{QRE*NJFxbF=idc2r=wg%`TODhh@8rjJ#~v3tg>kV{giKbNqtdwCDQi48s*J0hDl@)rNK57~ z8#3<}Hfx*p-At#ExAO&AlSNm!ja5UI+@>tM=4J)_rM@XEvU)k&Y01oaqg}rN-;6ek zEa|#iJ#TBf%_HmETfxwqbVreO-hFE7PI~TII~&!$8hGPjZ*4kWch!aN`pAt|yPD)o zG|hJc&lyhSu;==b4cwc!&cF+zO5my9z}axUu^XsiFREPhh8qEikGmi|eS^hbR5#qf znYhD^!Mdjgo+sg1>8v;D4>zW&OHTF3-^EWm8Th}44cF)}X>|BwhMUFL-)6mi|uc7Ul^h6BQQfwu~>M7Ppv1W=brC2M)mQ$>dV#O3Ir5K_ud2c1fs&lM= z-xZ_JZPU(svj7}Lb zdrxG?k@#>o=teeT&)HDjt$`n9*SkS)5ZSKWa48Vor@{85>vV%a`JJMsUvzs2$N6aF zPBsFl=C?e5NCDV^AS1X6rDdV{XoAjwM<%|TR(HJte{Hk1xyJM&>VF*hHVVyp9EIWayTi-Fsfy`_ay zu?*bAFN2@|G&hWpZz2J$8ndO4#hBzza@7bJjdvN1{G~^XAF?pBX@pBM(h4cn9ob02^|wUFKKE;Q(q1%3#<0B=2SK36k0KWO>JoEjj#c#mF&R{&1q`R+hLR5+^5xjo9IK7f^g|sChd>`7Qzf*F%*EM zFbh}?O~6XX0jpsSuohZ?^)L_E2yMV-SO8oKi-4`L1h^cQ0awBb;A&WXmW8#j9yY?J zJo3DFkMC#pt=h zI2VlH7M#B?GWNnu=qbkp8rGG{O_W`-xfvwH&Nn$6=nVxQQ?0U zRsNQ!@i%#$Gx}Pux0ZJr`}zLVd<)hhv!@$_D7WJTx*IWj)4=QZqb%h^q+aC)Q#I*j z;r{$XyKCY1P5k`(!3hN&BbAeEgUv)>ARQtF1{}}KF4IpzY3-``zL=S_+{~KggDkja z$Ytg#pV>h!ur~9d4QNYom0v5!>>dk>A%~s`iNa1NOp0<_`tZUUhB<8DpQYDJCHgNP`U`^=FIoE3F z4y`>7pNr)4E}M}Llxmpzv9^QJzWl_}mN1vOg;^Q$jcZL=*h6OptzRTs32k5U5Zy|m zTTFC+y`T%)hw4@n-BP0a+XY?FK2*1s=whwO=bth~Yo6byeyDamk*lQICl<6pUC=HN z2kTya0r9g>pEa=Fjm;(OW9>f|U>M4AwT)=CdQAA@4$|ZfYG9P@i-Z z*QLbD!rFq`(o|%c`5Zx8*0g<}sGzUtBEG82_YqyPkLn_ROc(AusBp{UVGZ!kunu@4 zYyh4Nn}9{#HWb1Z@KU%8SPoYJE8!|&Jv{O($~sE=-D^?qVYlC%jslcVUH3kRbQ)#0 z)_b}Zx0!kX1l--SGT27p{CGSnP#M?N3x%B8fTZUC0jU0{FX5p0s4-(LA2X)y|`QSBJ(ak@~>Pv5A2JXc7 zR3t{;WFyK?Cq8b}?vATIp@sAV8crNZ05hEaOFX^dd&$;?TXMJDg4x z^&QI^Y$rDWKgvK}(N#FI6W$P)k1Sr}`6>^3`)HKujeK=KaJ@T<_Dp@ARLmzY;j)tS zRtMXvC^x^)=aS2O?5pm`Xx4t*bN!d?)}j3kEwOH96m=U>p|3oIiLUM2AoverW39Z2 zW&n)^v+yn7Dw~`&naRxexj4p(>=;@OD_cj~QDx$`#raL!vPJWDmGk$iws^0~%-dDL zZ&r(kznCRyX%`=Q4H3)C&vr5?UQbKENN*Gk;)H;be|C4C&KZ>S!W2d#yIxQ@0o2|7g)Xz@3W7lya+i}L8 zoQ?=DI?hwm?kMJ{hjfHoh=aJ|INiy_3-n(x)Pw4do*Kodpbs|w-x+vP6+uzdcSYI$ z$m<4({vocoYeXp}9V}g_lP)&i{nMmGA)rbCblUyr0c^u&ALWbw^N`-tLhosz_q5P^ zTIfA3^qv-aPYb=Lh2GOb?`fg;w9tE6SU?L4Xf^p8)>8R({ta%5n;cd-wgu*IndoSF zuKAAk^uoC)M?JIe29do^|3EnY5GU^oaK0x1<{!c-U!29yr<+S)Y_d%rqUj8In33$N zKzikC*M<5Xxax6$DCZBR{XYIu$V(i02fcrQ)JHvRWYb%*yb@Ov`(fbU(zxo7^1q#` zicTnyabks&I8;7ulZlfQf6ZIU0in$=oTK{<{TZkz+aZN*Zw)6K@n+K`qN>hqr>4Ph zUQ=DzawAc*b~!mytehA&K0keYCRWoBsmc C#A9Co diff --git a/mace-bench/3rdparty/mace/mace/data/__pycache__/atomic_data.cpython-313.pyc b/mace-bench/3rdparty/mace/mace/data/__pycache__/atomic_data.cpython-313.pyc deleted file mode 100644 index 2a3c8c64cdc1257c9f81740f92ce4f77c0eea25e..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 14075 zcmdrzTWlLwc9-Ok96o3GrX)%t^{_=-lI8eq%MZnt?Ufum)_AfBq+NodD2ZN~REAWt z4xCN;wN|%)A3$Jvw}53dXq3W0Tp&PGARlQ_p#9`%K1d8)G{7!4$XCa)vI+L5J@@jV zAv2=Qt&IR3K!@j^bMCq4p1E`G<6J#)IIIYsdxzHMM|L9gdt9hI<}&i+&wzY@IK=73 z(U?xqg)y!l*N+(l!x$w{q-+>B3MQaakn*9b-2Z1OAEUK8hNMVG0|I>A3@h%#SeMpP%$Gs)y)BC~iks+Z{ti@8)L z9bb@XF8StCGMz}Cj~ZpyNG@|Fl{mk2WjZNbicc>j<(4=h#nOb9jpZ`2RC*?rNM>a# z;F)wXHlNAKOfDlN=3{fo%#~zLNF`+EtxRezhsCeNf$+<(btNmcSEkH zA=lfG>ubpMH{=Eya)X%5QQD)DGj2hrx1d|LphGOt`X(T9Q zCY!>sJS&^B^QqZfR<a+Vad1g#F3L7r5dqHv81S4Ea|Bfon>{EvZJK6QgoJ;SIX3UGCq^= zm0gB^Oz?+i>?ePu5IUy&0P*?>Os{)59rk?M$LqG{a{9xsBNT5zXtKNpKuBADAD%Da zDb<+KBNRYM9dFRm&HFgRUZS}wM`_7%6d4-^Y(rJj_^gsDi#7Ydg~jISwwm6Wwc4xJ z)z#QW*VUz)sB5XKYiXjcwXUwUiMqDBy0#|j+Ux4to2cuU4!~G)Xw~41#H(gJhImIM z8o+3B*0txAoa6^9 zj@S*1Y%jdm@83tt5b|{vygj0~XZ51!-G1HjCDT%1Iz^^)RVOlCt8t0h{@~<8Yhmbw zICMhlANk^<)PFippUyLw9$Jg94ob{Wo*v3GLq*1SYEbNm#6#kOy48Mee-a^(=B?sC8lRhCo$1H9nCXQ(6juj z4?U}|eKd7{N{YPnFfK(7=0K70EMNQE?CRc+4&OigyZ0VOq|QBgdQYC&Q)EIS z)APWv_DZ4uxY&PO>O1kph}1Wlr$_V5=>Lt|E;79jJZnEM42*~aBU1F`FXB@4OrAcI zXU>!~uQ>{RN5sA(pS?wzy!6Gp5`BrZ+qf?DW+iv;oxAU>y{763bbr)n9WKpQsi)+KAdL`7a2dii|ro`+#h&&Na`BS)5CdYn5Zc=ur@75 zc0ar*MfT_E{ds1;y24K$4oh9-mKDm+T8s9;N{eCu>`QxaWrtz}Ofexx5U6I1)5zhw zIjdMOZbgpPl_QD`<96im+&rZ?Fz!T-7O2Z$oGr<_FzznNcrfloj?l`_6d%U@B|L!f zUjhGTkxy($^;5ni4 z;KK8b)A=M0DRY3n3Y;Y_j>Y{Ff^Hh;SiZ&gX)&|3F! z)FJ%@j<$6VXRO#aiDw1!@cIgkGl3Ph4y4kp-8waaCx^NR`i5tVx~HOBrE?2=({|o6 z4fB#pQoD7FN+Mj>92C1yWzmJbaA8;_t*MWrcfr!&qHX|2fJ*kMnkVqgS@&=>NAD|- zue3$wW6k$o@lDzRzQa>+-821CjYXQ)Yu9Y19~ZW(q&3^%)T1q2XCxGBJ)U-gO?Vkm z_e@0V2vFo~6(3BGJX2qr)~o7U_=oM+TK?QU!8G9s`wQ|!Rqv1L3APDOI9`w^s(OD^ zPq>=!g!2V?qN?{t^@O_#PcSdY6IH!Ud18l3S~Ft8jKW@om4{01snI9xQOy^g35{e^ z#Rr@VZ?EIK@K#yX!gI3(&xC2ug&ogq583X=U1aO7wd!zHEKhryc<1g)t>-*9RQ0NS z_tdxZ+~)=_c%uJ@Gm_>kS&iY(=!vHFHsy&vmGs~bdjP9b2S=-9p5|KvImUffo10xR zG_AJ@%k|GND%4--(>$S1`1el6qRpR65W^->P%q3-eyJg(mt?-HZJQ-U~M2z3RL<{pvHV z+n@C-tm)>4gJ9#(2Fz*2v;k^LuB-LT#{eV0806-@n1jz6$ z0Ns3R43ZB$D1s)dm!Hfiz6YVO`V2&~AF5R$Rf+YSc>qq|R36yk{if}m=nQ@S+gWtf z(1~Vs(eCrI`BZW?zO;~oGy};G9D$idHY`pjNX{av6Al9bE-o2oa#tb2fFlW%Y&hLP z&iNtXfzQ{{?CAAYigevUEHri%|3=E#q3A107u~E^lf?)Q) z;f+Af#*OT*<)cMk$IbEUqmPD8zzKH2zfJUSD|%WUvB4Gpz1F*}65I30)v~hf-p;!_ zCD*pc7UOQKNDo}Mu8eQkk*9y{bfJGl>>v4$1EQO|ZhqtrJ+ixQocrr@D}e{bPpltX zMf;wD{itX^`q{K(A9)mPTRHPP`(vZQW4&%xXoKZt-6K!%&hYKwRpUpt`?j^ypNxMz zE`<;NoAKYR|6&zAajuZc7Ek_FVEcEeKY=>%Zd$I92+e11y@Xk;D8t$ki3J} zEs$(r{WsoSeph7rHe2!EYrot6AXyk3T^}5k2G6Y9&z4(SNj&fsq9@j)Cm??#x=t!) zWMLrF!1@c!Hj&v@V0MVijj%zD2VNCj7x8HFZ#a>ix&Cf}>4HS3WWz~GTVF90eK7jT zxqmn(yTW-kq8cqQA_tt0k-puz{K|$2IbAo~3$7l~)gw82*Ibfg2UIPD_K2Z94@acX z{_C&c(O)~PUYM(JVzt2ZK>vc(6&Kq_yGHf)gBd&K(g*l*j5-qwP*SM>H4 zyiw5`UF(y)dp_S+I5M$*WI{S}30{%Q{lq{yoZr!i2`;BrbbqsezkJ-ai>tcT-48|! z{U_J^PfGoxqU*H6;Ldf_H~fag>=KzH8x0zcO5VMnj~0fx^Vr1Y4{Zx%w9OJZKVWpQ_HI|atq0fzlgBR8hUXTu6tRE<<@G82_;x^dFEMh(1 zARlXe6bRk9cKh1u*r&;Vy8Mrq*8?Mp0my&wx7}H|y|B8!(7i|OR=>G*Kd%@;0zNEn z3z<|*ss`3QXB9K%(x|25p8dD>)k&3Q>%j@d0^G+|-QUP8W7w11nD*>&<_o(qkm^zyLjJOa+l~C`OPWCfF=H% z#cSSw?HHc*pg(#J&9R2B7#%=?tlf0yqb_{GB^|$#jKySYEOsR`v$TM5do1?mQhcGz zVPdh_l#tCWq|(WBCKeM;<8_DWQfc`H+;}>j$&nlQa8E^`fLpphVk1jhxYj?PnGqbA zjC0YladsgS&%wGc1=lTqiYY$0-hve*$PLld^inPfnq)`$@(aH630H7*0?vhIarXNJ zS^Ee$#Y_NRn`QqQ*16yR8Wr)^e#L_kvj2B-& zk*KyJUe7F_lc-=3FEeh1B&xk=Wp4B@Lk+)*yD7Jc!FpwCc}k*MRB~E0yac)h_H`Du zFc{0{Zh-}16}L)LsM9)Q!w5}RjH=*^^ z(0U1N8*;a<^xo^g+b_DdE87XJuSPCPXgkz$-|Zx{SM|u`f(bg(``9+4H(I}B+Z7{5 z;0?ID6*EC;aJOP1s1?Dwnp*yqLZspPiq*ftV9pZ*DEgBgqr$<;SeDKB5>_^ulda?) zbu62L-O5p5C-A3LdGP)UZ@Hel4B!JaS)n7XD-30Cohv;4BV7?<3eu*xL%WUeVTDunmc}p&}b7dfG8`w<#86->o2{ z-2}jBA|8tBFP)ACWD5T`l7PDp+sLbnsq|ba?BH(7EZJWUF6B}S*+Fb;DPURgmAJuN zCIg#29ESw#L9!zy-6ESo*{=QvBeJ=bjYLr?e(lBjDcolgc4R{;ofCe6_27304n3U8QBQ#$L- z&Vs8~boCZoJ4Dxx4TNdR>u^EJRdjkDTWmU05$;M|79A1A0EI7Ek79&hxHe_7{d)4R WCJRQNX!I3~VbK`=lMOh;wEqP{OEyyg diff --git a/mace-bench/3rdparty/mace/mace/data/__pycache__/hdf5_dataset.cpython-310.pyc b/mace-bench/3rdparty/mace/mace/data/__pycache__/hdf5_dataset.cpython-310.pyc deleted file mode 100644 index fbb873cbac8e95d14cab4e4202055fcc5c435029..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 3229 zcmZ`*&5t8F6}MdAhwFswAI25DA8lvr8SWhbfb zN-ASJJDVsub&Ea0h3g)=5tlh|C7sB8JX!?&p1QE0(Ba#w|eOA&DOPT9?rI&iZ^Gkmeq`@dmL&p!w z#wbc7*AL6)C{AP7ZM@X|}$B9=_FWFivT>4%~%I#+(WBX&gh zilklPJtw{0e}OKfN3B07t26w5RA_VBW48Uc8t0jLl9^0PV;iduh{V4z)u_mk0V9;_ z7lWzFOi_)2v0G1{PDf`_eVm<@GRHGEfBtHqpFl^FlnRnE>GkR7q)!DE?5dgiG7z5d zugE1$LlKA&bZ&@-h~Nvi6fu08pA(;?G2EkNTOU4o`f#OW9&DV}fPMt61x^1Ih#`h9 z$R)dGH(>k)^f{OtsBeM=IS;QJH*lz+2V@zV2sw-JfYe-&YSz5<_`JD~YvTn2+FE!( zlZ6koJ#U!~u0G!p>;QHIUiWuoocP4-&Yn>?8{CKT-QSZjeMuH@I^PonuH&C73Yyxc zP}PK&nKn;*cWj`i6RA`aXClsvabY;O&3;izKFQ3`1|0j=FTsb_ulpBrrma`Xv5n6% zlMlHr=F*117~l%UN7-k#G3N$7WxJX4Ef?V29M09&KdY+J2AMHRD-3WOJ(IIvDOD*O zUcAiIK=&Hf*Rt%ZHcYFxklaOb56Rbnq|N$h(B&1iP0rocoWCOfNsousNFI-}`7o=_ ziwt_Xd@xbfxy;S+WM+odm_L60+nRqS$NBL1eI+KDGBd}k3F>=16#a)za|qa@$xJ~K z5>!dw2h#CE8qz(wPw&DP{|tHwjVXH_v-$o;_#CZmo1+ufKkQ@RK8B_bfq-uzb*DtN z4ZQ~MpL>GdU=+9+@FF|4RP}t-0b%P8A5Lc0e+mwAc6vqC*MYCT`9E@qy)F=BX+(>d zby-N~-A&!lt6O?85cHrqz0mFk&!?Ak32C)}I6eXDv~h|+YWt$d4Mq@tIqxy2iDg{8 zp^Z&AA4sE3W+dm|g}rOFXyo~H{`RI?n}S>Fq1muG_3%**?$B=Qy|iVTNOKNBs;A#oRt8?~o;{dw0na^6eaKV%UOR(t8AQ z++X;D!Pr}kg97p%T4MnKbU>bvPddxUc=c!z9RM)9Yz`OE33>5t;0Ie$Fxrr6IH{i) zALIZ0bDQn17&HD-7K5SDwzGnfu`2@FHYb%Xus&!T)-txW+Rn@-(t4A#-1@mJ z%U-0==jsT_yGV|a`~b;&NPdXqM?fH>WF~BYR6~7|{cKvAp69T~hWXS~{l5A!NN&X= zrV6I@VyE!#H(YxQq5Y%Z}c&^up3l-wfFJ1d}PR zMw3Us!JRQK@00nv+c68tR2Vtp{OF{t@~qU4j&9fW+UgFhau8rWZIn%><4KlZ@K3XH zDiszo^*$2BRO`WJ4q|Pquc{HR?O7u)Vz3SHDBWw-*&P&PAwRRN+M8~WCMzG*Q{F&L z(oj#&;9Hv>t!_O)Z7`>c1>4%F4{w^FEHfNq!i%m=W2$2A5#fL+>g ztpnfi-*F5S#_TmpcBuc=UJ!oiH|F>Lj|HLR1A>jS5g?O|IUiMG3U!bpX1$nZWj&)1 zaaR=f=}9O9k_D?WDtr-zvmQgAH0AH%dp<0McSB6=U1PZeguR`OvXK$AFx z#^^#57<$PRVAM5(`bo!tSW5zz#5WkbU}_&WpI&;UP zD91t41vs}mJ3BkOJ3I3&cPcBL1j_31_1XWp3HclwX0a73JMTf|HW7))jgVnNNf_&v z5pLK*Eqa|F;VBP1VMG|VQtPmd+H}5kq+-}k?Rwod;uv;Pr(UlZaSgkvJ4_l#3lZ%> zB045*eY`OrtrVS;7ST277TuGcUYl4MB$oy16+J;heKUN-`y~oSI3~=bGE?|HpH$S* zh=o<2$)w{kbugyJ6j@~z#R`ZxUsf}7$v85gg>uu$nJkT|$xIp;RYvz{c5X_hZ^Wij za@;V+phe5b&P8bd1omQ*P);P&B65_U6neQR>E%QV*lEAaQR}2l1qOo7XF_5^mI2bsDUWz72&1DdAZKCl3b$f1F0+eDS# z1oCWMQ5(wr5RG7=!_@VCT7Q_nv1(J_?zbJxfthrlNA7qyRoad5u2g9^#$QTRnxjHw zqRFg3dzE7$vw!z1T5#p&e!-kRvv2%js$|&=X(kgbW={o)fn{(*_;UIEWHdNUPV)Q7 zG#3#@DIPgVHshAp{m6YT)nc09bgCnRk7rV+Tu1P9h2j+|E7f4FL4T{6O3KOO> znH00eRFx{!qfb2|FAPzdq0DyWofw@_z_kTMPEAvE6B|HrUqM>Ly~mG zlBC$7D{PcS_FCs`W=`&$i(Q+IWiBUUP>aj$^E7i=j;lsvN>0aTJ9}tiK1S6Aoo5iV zwsSTy-KX0PF`;9Af#ON9+OjLxVQv|ox$+5+Zx+cD8*zB9Us>+e?2-HSwnCugx~EVV z$O+e-TeS_#wX1=Zz-nkEln+GmwMW*wwc7TaV=LISoLEh+B-b5U=-7i$zZUAxhmPli z12E3Jfb!VYIblqJDcip!)ZCiN9 zfvu+Iwbpx)yODdHcRM%Qw3bu(rqj8>>!aH?Qs49o*EdhB#MiJxILCoPQ%i2}m!sP} z)c>uZll^(%(3!*JlfzYKUgSSH!2yjIZ!mNkkp-i~&LWW8c!%NpbcsJl$Znh&E5(sr zxfUF0W{X`mM2Ws)d;)(i@t-{l#6im_<7h;nXgCvQ`{oyza2DKF_nr?p8V0_C=7;G4 zptKeU^Oc_?IySnMaumoSdF1lm7+4z6TuomDnpOk94y^UwJ9YQeMo&J_4?8$q@YdY8 zwsdXN9@qhSvS({>TlB3NfC_}%(zb#5rJ#$WjACB&G^B+nCvt~?9p&^z>Fa^yBqWk? z6|X0LB`IRj7sM=5LP3RLV^z#8Ni(vls4-QBJKRvdyz(A66rzNIERwCts++BM+HbYz zE5kVpxTFNB&Z%J6YAiKe@B5D+gUEl5(-N)@|N7OM^%CfK(#S#x={{5k-L##AyiaJs_kzQbYDVLf&z!(@c zo+=QR9Ei=vFj72Kfa`yX1HQ9z+GAJ_xTT_PSG>)4j~C>U4VNs{-I>K04WrGN#WD3v~!Sqo=dx)M*P*RskwFBj`@s|M0Bh$ zn<2woMP2H`yv+FdsW=njaw-+E(-W{j`XZ7Skeoz`XsRREpK z)3Izyjqo}sV77Qx%}h_zS3ojSN$YWx9^!dLrOcg8&&T3dq_<gK$uQKw8o2+=npEXE1+<2ZRyrYHWiDfnl6u2kT%6J3XXre$DcLFb8f(Jk zFdlJ=&5&f=6~?{ITxGaO!I)OWEyhPrRLrFZZZrWxX8qof36>ijT`*Sp6U1i7TmHt? zpRN2%^LON2h5CJZy}l#oE_kYMB$txQ-rpTqzmh*PpfwC=o)bB~P*r>L$G@#vo6I*K z)9Q|CRXsV|qq_Q>v(VVIX>Tah1eTB9uQ`-EvlTkB{`|&Kt?l&3yw*1O@tD?jE+2X| zH~juwY8O8r^j7%e$LVyaOOo z_5MO*^JdF2t+8k06|J$q;NQ1Aq4^{0Et^%j`UUOaz(>clgQxNhr?-6jmM^Uht-QH@Ol#`M`#QI(nwF>6F0G{3$FxvazN-6i zm8063vu)QBclC`EOD8l}s1OXTwym_ScCK`;KbH@7=gwXqDfojA{B4@Q?Sa4lGk<@< zUA@iQ>Ya~j8}6LHb-qwje`nyv?w+vNGH*9J7=&=G(EdUdZxuxvnL^j;&ymcVhO5-~zLHizloDM@pg zL>6GYguvjfY%Epm@k#*csiLNmX*r#dB#L>-fVOt?a_iTJqGQn>ByRv=b%3t{TqNTX zBKPD}R+SadURgxnc%uQFaI^)4^_>5bfwu5*p}#ON$SJ7ro1c?M1A|}lK5Nr(R+a?cBuoyjzxDhD!NP19HW@9uVCnS8%&?%6`<5bQA(H|=S0_&D*Z^7jZEe-+l zTRiu7UtLZp*nKy=OWxnO)&@5lUf(#RHN5`mb3p!PAA>tw<^Q54yw;GfIkI@Z;Huvg z>c0&a1(*RBxpNiB-m}(m0cF|#G)cqSp!AR9z; zQ3hP;S?JJ-1Qn$h2*Brqog3CaW*;0psU19-8@h3R>HKp2eRuQ6eTzd|HU68~Uk$AJ z^AJ9JH-%&CQ6TyrMXHKN59u(Oo}su+6cHQqsTmrd?Z~Rhl+uAmsCxv(2NJ8p$E(il zC`%mrq^SX$4K3VJ;X(wlgwGhO+ zz{DuP3SKl(dE^HcvzYw|1vYb-3s2`UvQgl|fj!mu{EJr`>xi8W9~cK_1EzQg zq{K6b;T#`3rL%*+tslD*56R3b-8mM^x6ng;-%HT3CvEAJ=!DGNIl@SK(g(MsbP`;8 zKjWvESx<0@Xsskb>!k&>QQAP8%r5N*h+ivhW^W)^X))A;?3{94ckc!MP+5XU5EHkG zo2p43=AvKl>>28|l_N%JA{8lwvPYpDC^tSzgcP6{k7CaCy_8RBk|&Z|ByWUPSc8@L&u{E0T6>y-Ph*cnU?2~iY#%+8NyZT42;_2#{ z8rV9GOoYMwbZDNyya7|_u9J|IAwJPXkN2Rgtu)U0c~EpQ-++g?Rz)^JR34>ZLFpwB zgx;x+rz@|!HTp>FPKOJOc4oXjg)>Z>;1$iFe1>R9O1y!f9LXJ&U$K$2ChuT$Y+*D# zCvApH{3H5_oM0$HyR`K44!R4)bZmWRRb{2QLwty&GjmHa^X8}|L;rxxp=%C^_&}Gm zgfeB;3E_vD8`hbvdFE(-*nndFrs9|b1xtoa*&4QoOV<#C9Tyh;g@tP@I$*JEEWDQ% zhYJhug|+O_v}vOB+agNK{}-j)siLK01Uqi>-%B*xES(AhAoH!g>C!Ie z1;3i)@MrPS)#m`j&`(jAM#|%{00#UaKe?K-SY*I#ecm%iZh0YI9ESxl*_Q<@cQ45&y| zlhXU+Fs(EOfcP75Bz|4rq_x04oq+}ZRTt9Oz%^PJjMgC$8mxWP#D2SDnbqzRUxGE| z081gx56}ccF@)W;uuJSN>CV*Jti0`LV#YAU9@9HR z8fc|PeQ4P&Rp22aNGp+teIU|GNuQcNwQtSl3CAnhx+UsL**@frl&l&xFFkk0HpC%P zRJBL;oqO&%_uO;uoO91NJ3gNq0j&>iOmDgndQ6!z@ukkr40LWF5s6F;Ve83Y8}POm z8|5(9O2^n3AGKq<+2>-8s1rM*F6=USKIV>ku*dA%W8SC_`&v;mYDc0Yh(zbdTt92A zg%3GNOoSQY7?llKR}AR7&l%Z_nxq3^8>xvY^Wvv@ zi~zlA5}E$esK&?ndNdwJNV6x({vlzn&C+Cn+ODFwy=>=bNg{)d_rY2dQWy_vMZHLK zB%I}G2{Q~R7e_y7L8!kZXhjKsMW?ph6M4}V=X;sBW0X0=i0mK}M`;Kji!QdNE7mUd zrs=Lty7*3-9*apL7w3kcCxbN3i5!hJu;H_$2u(pqz08O06NHHnR>w(rm3HfO6PJNy zO;2TU<~9l+AiO@CRt;jy>V%t>jcH<6uBy7BgJ5n}!I=6*oRXw!sv$|3Zl7?IYSJLQ z85@MhQbxW?91D^mUrH;)e);Dzp3-p*Xog*c)0K3JxKe6bk*IwfgmDb#h43M2%QtBS ztSQuSz#J<@re%q^Bq^DebzPEldSb05x<)SaOlLDnPexvtma{Xe485dsau#Q2l%!!X zmlQ2I-E$gG&dS)BH^;1oF58pNOkR>o4s_4X<0iU#U{6$bmp4w2S^q6G&#$3kbKAOq z&A-9sn@=qI9tYYMN4M%5RtHuFR!3GwZgmvuyYls2o8GR?`mU!eYH0tuIk@QCa{87o z-|Wjd!(Ta%Z#A@i5y*u&8P%o?`FNy({_N zv73#zJh_7>7P;RaxZg+H9dHFD_of5A59ojA1T;5Lq9ThUp^^+BDp5JuArYAgxF6WG zqTz!e3^VvJJajt06}Ll69Fhc5Nh0+xAB{OkU!_i8_BFKSuUqb0jjTj|_p4232zS6> za_`dg?U^nb5WF`~2;x9_+sbnQ2~=cpc6am$RBDYq$I^;TpD8AV&ay{PiuvE#pZI09{jhIh z!9Rv^7PqPz{3eyFtxu?rgiFflBynIx2k@r@2+$NfHBan-6Osz>x&b1J2x~7Z^OF7& zkaUzVBjKhMd6Mw7)x)e=t5-q_7jNqF-l32%aE!YG&)6yl+*nzZl*UDCoWR!Q|p60D#aN(Rbm)5~SGBB1xG z?t~h04Q<!=!_Lv7zv*WGdSoq9XgQf{IhpsLEDj9+ap4aOg@NRs`EmM9$V$sO(~9PZD-ujb4?Q{Qye zZ))3%-rB8DcOf+VFf^PGjXdSqI`^V|+l#z)*M~kCx*5)U-YNz|>&MrQulKC=Y;@iY z=YzwG?=Ho*IhU(pE7(y8_C5^u7Hb-d&4<>#Yu;j0%j&|)La{NpdU@sYYIY@iYwE6R z;}`kHp{Fjs<-lTWJAk}3*GI08P{ZA8-5V|WVEgU_FxXHd<-9mKktx4gB> z7jBK+8rZBm{+0IxW!sjG= z{k8X;J?chZ2HjBy`(?+$sEvJKt@%%UFTbq~v5%bSO z41eTb_Ab6<8$n}YNmDiK#}r}_UYpARVjraa8Ut>r>Fzm0P3zqh&SQ!oNdx^HF!}C1 zsc2=je^+KsGp_^lVjxYX$jD}?G3#cK;L_8rE2d?SG$pIBE~Iej6Tu5*tN5E2Ot#?AR%`y!0H{^PZ>K>tvf1ne8M{hquJ&#b|Gy5RZvfTA8f`@76 EKX({LQ~&?~ diff --git a/mace-bench/3rdparty/mace/mace/data/__pycache__/neighborhood.cpython-310.pyc b/mace-bench/3rdparty/mace/mace/data/__pycache__/neighborhood.cpython-310.pyc deleted file mode 100644 index 3019675b7040ba04f1d83a6a61e4ee81b7824bde..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 1590 zcmZ7$O>f*pbiVAhz1~eq)1pYFLgj)%O_1OiMTk~8AQ6y)I9Q;S;~DRI*K4~oW3(ID z5|DDti9gUtTzcThng20YPWcNhUvG9xzQ%fAGjHCT_vLCZ2oS8_et$Xn%tz=?D_mb2 zAbbj-7JxY7xIl9p<2i{5wlXPb%z)CunY*!zkx$@bJmijG(;jz)$9n<`B7FO0bQ^6W z?!gY60_VdCaMvMY+b9co!28dgIN;<6MS~0Qjv`{bFKeAwWl|XD>!vQmVT8@F6!~~^ zTs3l3Q{_VOE_wf zbUE%wOD(VNwe$r#M;gB%oSb9J&xm}W@sx6^8E0ptquOc7DdrSTKGZIExceMzJw>PZ z6syOdqn~g~AuD7?)F&&6aL7Ujx3~w?1?mIs-LUF`PYZsF&ha_s-rK12TStvsp9fp0 z#Q>0Hr~>+#WMmtGOt0iU)Viv>q3hk!y_7+>r+&Gi>)+D7l=qN^8nuw07T&ogER64vb!6o8-mUdq6N(r?YH5Ih7fDt-p?S< z-3CY@oX~U9ZMGZjrsXpRT`@e2=)u83L@p@QP1;6YA?OK*>t!ucsO04$z}cEB3-1E@ z5w+;Meup)@USmzBG$W_v8-&p_s6N=ohY>SDuJTgpq)dfj$5mArQr5WT2@n$>fL&&CsjFma`e@D|5%jiWcOpq>qP41 zYHy8rqLbY>T`eb7#dqqZ@%P4}6pLE!Z`lYKAo9aQbn@O<=+O=3y`oB!LhbKd-jFsC z7b|4baAETh(|mdqIVO14B&CLV5lY$_ke)?WL%M(#4RSt;Y? zT-d8i<77pZXhRE8f*MSjqb5l$44o$n<0Z$cDjF@MeKZ{DB%K+`D;?dra&+Z9zNcj) zMoJXfC@>Ay*zBX^?+a^st1?2S@en^J+5r9WL?BOBa z!i>F-8D0Y+1pohquUX)c``E=EUFp)0g+zV;h9Y`+co6lC3r(DtV`*z9ZI2tLZ07a4 z2=U!{qE(vL%bm*~pjv0RSMo&4WGUaV7Y7zuv)3ZA-6~h%czg4TH-*@@-Ju=<+C&}+ bF>&of*=vT$zwqpTHYxerW8|;{`xd(`73{LKWia6oU<7ERC&khEi(nt0K_edRgR|JhC zcmH?p1W!Fo1m!>vTH==&99(4M(hdze&Jyu*!PZSf%@g6ZVj-`k;tYuxn!cEuH;Xu( z*Dad_3Z|t~DU0|rMcZ6hAY5TSL->rA&l4Zk>>@TY)Qh)7AUJsT-UN6DDQt#?5-`jLXXY&%~d-DJWP>jwl>OdC;Asc44A*nFUXl5HT zI(@`^K?Vm=T|;+-7of>_}mi1kNsMhPG0`=g7dKW~YC%e>QJs)Vy_Wq#@J^VORqIcF-EC#!_Dm ze17FY&(_Ny^B>QDlKUuE9sSGWzVp?-^Ir}8V$st4eR!0?H257_z8#mM=TA6&NsbF7 z_(oARY`9)E3yZ+V@R>rGmJ<5`z&Y{iif+Zjgj0=G+y*=%lr(HwMAWkyRb3^*Lf%wu z!sRss#NbjEtejfV2)C@R5K*1C%zV+-u_`V3^*;v~1RwQ0+Ce8$qJ<*7bU1 zpv-OirMKR!*y}5|R{l|XF#X=lKWDZUYJ;aA4^BFRleNLA+Oer`TH0=2fBX8z*n?M` zmgmd-4^pITz4NEa8}I!2p6tX@Uk=n_smITxoM%#1K2?=cb*bfM;znYlt0whU`QED3 zyB%t&FgKEGuhjp8_^tJ|3*U5httW3M?~UJ|x;wSGT01dT>l}Z0#pyh=_UaQ6h1%-T z{>>{+bZD!~iH<%z`DpA@_0hyfBTn>unJ-_t&6Hoe5&Twa{XdHR@Z5O4qZ>l6huiAy zC+aQHpCob6=kR@N$;#Cyoha0|eZ1E>{z5rY7rXBD-yga=w0Yto4($3+d$-fpU-s2IPBh7*9$HIoOVP@;jq6qERFyxq2PcZ+Tq>1}OT-7amToL!_dcT&Mg*g{TsRnw zj%C%hGJ0Wkr12-l+V6Y6kySBPS8*TcVER>byXAd^;U5XwZz%EXvY9RBwR89^@aVu= zWdOSz!!X~W{_j!yztPBU2#t)_`v&Vhy+27)%vc?orgWpU%K^O|YTFe6L+-)$%HRgS zk=^`Twfm&g9(MxI@A~PMh@_6n_UbD}~WON_%#yW_EVx z*sbp2atA$*<7Hq62w^$!LlQy2QV5a=iE}%5fCO+}0^}u6Mj>x`2=o#J$wMHS(&qcC zdiD}!>&Y|dn!1m_|Kt1rnqIN!8~A+x2k$Lje9SQZg9^Jp6DT~3Cpu{v1~Zu1HmdY% zwat!IwK{gyHmPj4old5j>9|$5ldWbuUe)X5s=1C|^;Nss&Q}X+d@=nkRZFToRxRV# zX^(d%suP{b>ZED#^3qiGfJzVY>FNy2R1fh(%x#&~!`yn#U|HtfGMKkvZJ5;~EXR*9 zpS%1hKiaSk87$9_G+b6-#amAGbEVJ%9`LRPr^*Af@92@5Wo8TwdBtMFA zKH9Q0N5f}R>;T4}%$d0i`vNbo zTCq&SW=Cns2Ls+H4C(vuV59swU`#e^tu+u!t z&d_<7iP~9q7PSwvC)qi69=Xr6M`({|bJzv;C~8i#$JpbyjOrO?&Ks4d-Y4#vt5`C} zYH_XJu0@f|G+MmPDz3CIb=ReHrB#om^I|KCrT5ZG-0F2}ZRySP)jsdmxy-!UUukp9 zFn%fSbz1egerJJ;S8EGxF7sD-qtqxj! zHJMShUNgAO9cJD(s~J!redF9?*2SHc+@2QuqLFiQ4$Co1Q8pa|irah3ph||hpmLrU zG?#)yi_8IKd|Ae>=T~^W)o9ggbnJ>L$Cy4)*r?Uxo>-T;R=3d$m+5@@nh>@1NEPPl zHjXg*P@bO4c%sut5}d4W3O_NHj7>8!@0g~MSgLLcbr$L@RcFi_@0oKITa+Hdw6nmRdFB>io=#=q>Sje0F6$UhH+l z%kwYm`USki#j_U#Td9e7UDfG?_~C3nZndM?m31+RnUCX%NdI`5g6WwB(=i9r`-wEG z_gFX4Z}fpKG8{bW{8Pvz#)dI8W6&WsH?2FyhBa>p(p@;!VLgE&Bd8xXM6VMruhm2| z5_xoySx}pSC&6-dy&Xxf)s5mBSWz8cCkLrxQ4LMsSoAQeBHD#(mCV7hPwab7=ZZBq zS1Cy^47;@s4@2pPVW-FXZORwIaJ64+>lQ&TRywpH{VFn4wxd|cu`sN4yS-R3PZWmg zEr~4J6**;IiCc%`g)ni1s_FghO6U}#gyb-uXaW9Z4U@N~tUt7N3l zyIeHa=}k)?1mlqso@)L^OSEcjUGQ0Jr3bQ8CvTD0nC{-_i8@z3K*mKdhc3K-cWY~_ zV(kb;#ZFk$8>=rgy;no;hV&-x&Rt zP$#Ho9nYQ*amTZGqFYFItF$GDXP5~U=Hi!BSr$erN2(uXC@@v_)3S5XU`k2xqH1?l zxx@>6OqH`Jd(|?OTZZK@cDy=)*)q&WeZD%W@&y_L(idg%a5yz9RZowN-3xx4H z1{bKsp*L5V5=2DtIZBREwbfmb*)CHrLY$zQ+{>fI1qqS}C?TnZRciNYvA94L#KzLO zjMc@Xl(k>R7J!EgIeF#kwM(zP_-c6hrMYV_JQu$D`pZ|v<22?Al-x_eX3_8^JP~Oo z-|#KZET`XsRX|Qq51DDjyc5aNjas|Kpp{?%KqFGZoJnu3CzeCj5_Q_^=tH_PhliLT zx&n>Jpz9qJmkp6;*0K?mzHEFGx}Vb6__ncO4xPk_(JOJLP?~`x##U@6&W4FLH+GgX zn=aXgVK&KbW)qjJ1oWs2=`u(vH%w+@_UB`7Gsj@v5_8I6?%NJS-%+2pjdvUx*L^nj z6PtO7^|rxs@7U7@^MMr#x9#s78hShZA>;lqmw21`B!`~Gca&{fed9Wiq=WAUFlE;} zZ$ztWn3cY-Gyl>g?Hbjlq)-1ieNOi(0Eo9U>kFP#%GLjY
X)6?gac}O{sHLvVtO*EZ#n4zo`kPcCM&4^J(On$FvGY9JM{GY?5;DVFn!YsOfGU&6Ldv)7m<^%y;uhCX~(y&pejyw(<6OxhMhLv)1m%1D&SD90`jxfd^l|sgwRLW&M5$Qg_Mze1S zH#UY=6UIA%0KHj+0e3dd*h-)P5(`3Z-Bu#M>L&zl?ZFw?fE&>40d0x|q!#g60`L(S zGVTR?OhiNobODPAvTuT93qbA2JydOY{pt(PJ^$)FjOz;C5R>ES*%Z1`WT@ap(=2=E ztLAI%cBnhU15rJ9cwDF#bMO;$DUu4z77;@7varq2-i~Q)f1SJTmg9*Qt+3u5 z5TsPZoP~bgO-w?jgo2ueGw8!B6tQThC%d zezS0gdLDek8s>)u-8(5T6YJ!tSEA}}m|~gvi5~YFt@Md;MeJ(-xFW`VYkS;kTmo)0 zHplK5nA^i_c9I24xh#J>2R14UPJ&Mq8||>qe&(6Grv_27L@-q;H;)4cp_4DGsfV7Pq56rbeskV1XQZUbbv-O z8rEdfI+;zSQxL;ca%w|@z@a!#2>}rC2qg+@JWe?hQH5D5Wx*-0bX`Uv9%wgp2=IGM zKD5j>c^m>wNslrFuEZ~5jI>b{uh3}n{4`)ys-ApFjZTVc=m?@&+(f5e!xKG^L<=0i?^eMsnI%9a2b$1v9CHRbal)E_e)Is7 zcXF0xneMi5YB6>HPtdHg5c3G9R?sop!!)Gy2n{MZZ6E_hXX1hMK@uJw+KIgcY&C7* z7b#(=!~!RugydO?QhW)?s*~7FU3ATEmIH+!sRT2qWugP4a1@~#C7B9+a zeiP=f9*6XLTlG-+Uxe)^mpNYvTT!o#@B&lKQJtth`Vbb`)|78SxKeQi)nQsJXg_nw zX7(Win$KKvn1h_B>OfgwM{0%tIP{?h?PW_`kAY*M+3+TS2-R378vv>R&An`+x$<=vfCprhRwvC`;Y6SoOF1lv+foVQGI9uiWD z%Za;Tg09cOUXXRMh*Ak0!+1P_)nQ}HR`0*4ZgYZ_@0e_S*&3FUGR7Z{C-EMp?ig%> z{EFc?bU*bwaL2e|puG?uRC8siH7O?JuuwlzXPr#OV7sQ?QL@AijM|dY=5&%*^7SLE zUc~AX%e2$e80E2p)RJW0GNbcJi6Rf!wYPi6Z<(tX;u-iQHh~&8ecKsMVD1U!70k?i zXg&pRpCnGYIL9VURaP|*dYmurYY81pT}i-Jq@R|8nD2!32}VfK*=v*#AhpAP!U zk=Th%1gCWVRM3ioZZBx}2(tpcH+ZnN$VnG;I{mH+62g60;PDzq?N={dz8V11(kipT zbsh*#9RZ5c6>0%%H5y#-ZX7_Xob4w4|<|nhjp>^@A=v&Hyd_ zhzDoR>}>RlySq{c4Kvq+c##M7#a@J6VWEIj{H*kM<{ccqR)dmIlcp*n(9=E@d`5EZ z#AAX}I{&jI*J#BNxzh5z$hEUkO|D%-e>%Bz=X;S$w|r)D?L@nRQ#${%B-d!g5xLUx zy~wq*QBAI0Lw`ECbZ4;+kyM7UT6iHfF_&O~1ewgi*yY{9sNmG#x}rpXl5$H z8qnpg-mG(Vl_E1-h@Ox>b=FKI2-WEdO+e|lppw!UH^D%e;|SJq#P7Ix8D$Yu^7GU( zLGz^MeH74hAoP*=IqIW82{}WRNin2~?Q`j4(#4DTWlCz4JV41qlsrhuDN4wTk|vi^ z`#=PuUjR1KR2zj9Qt9<$f)pA}0MM|T)u|EPRb`6xWO2`aWeT^W{qU!V_#Xw{a2;@S zeef^XM)W--WzPrbp|AG*oGsr5=mZQRKzh*hzF~dr+2j*`wiY#F%eP$kl6cIr<(r2s zYs>%<3e6IDoigaYiD1_CIxiVx3nu%JIQ%v0(uz%rNXpsdS%Bfj&^|lsrhd68aV~H3Btt% zYRxI$uaXW?>PP%KMpyFUw`jWGMpDfo=BBx?>QWXH zo0p{dCGjGyV=hT^P_=O^LBh0O<`9}vln+y?hv;|$dRv8_ysQTq^KlV4f$uQ{zbxw@ z#~&F_v%Gu$;HE`4HmT$<`|y(g|=-im{piN{uO$$bsF zUMR{>Q&={_bK@$N{E1BnUd{btPWI7sVUMkhP6nI ze}>3Lp=we~M4*#f;tx^Wb5l&W2lOaa@NdY#ONEb>nA+2Ur;9KNh4bJ+;GP&cTooY) z518SS2su!KWmGPM9Mk~nfS4Jsh>*)72XRBrLk?F($mNiOKY$$F22cdaSjr=Z3nSzT z$Q4zth#b5S)FG13##C+$xw6WYk%I?-)^X$}RBi&fNtK&KZVDGrMEe7Cf}+tf2m1u( zX7E3K&YC>b(b#j;=0=OH2HfL3aCIE&f(oNZpGLrZ#3Ir+KE)IOr*}_9ZA8;ZjVz`B zCT%T^t`^hy&Ns1`5`D2j%VyJN>FQfY*dl$ms629X%6hJ~WK{1z-)9m77A>k3k zg2c)`eiMIy-dBme9zv$9Y~jcze%ddtjA@iC>BvKR7Z~kf#sWHQhh6?11@i z0pqXc6kF!!swpwTg<4!+49_Vx-5=yrul7+8V^_G3B7K#DplFQNA$fH%L02n%zJCH0 zxcdpS{5{Qblajxsm=s#p}UBsDB4GnCi>mUBmVb4;-AoMPxKCB4GWLbX=P** z=mGru1(0PPx=m?g>CE-;X8_s~eE5IF&wd{N2h{aw8#Gnf6jUZPPD~0(p`W>H*AO1Q zpK_I14aW1Ngw@|HcqC|~{=0^YxE&z6x>(lt%r3QMDIvon=!!*p6x5~(G~#!u*y@S* zs4WBKg}=op3+~e&Y-ZzLuN?_e@ak_W;z`OqMae}4MnS6^i(fi-KcmM^`Bav zrgs$m5Q*oO%>am&{@H+lqU93Vc0WF~odn*# zi&5xFe^aqO!k<>|S{eVx&LP6d&Rt-GPqbO*Xr2igKSjOJ??bAgIk3)0uA0O8F}C&L XMArYnx~^K$gVsEF%y%cI3zPp3EG=M2 diff --git a/mace-bench/3rdparty/mace/mace/data/__pycache__/utils.cpython-313.pyc b/mace-bench/3rdparty/mace/mace/data/__pycache__/utils.cpython-313.pyc deleted file mode 100644 index ebfab693d457d33ebbae7721e564bdef5ae0f23f..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 17743 zcmch8TWlLwmS7dD_pYmJ>^UB#JD_mMkSvnTgWohD}S9L`No-Dk`?s zNLm?S5h*`ZrdH`|Y0R_K_Vigw$x9iWjvI(DXK zx*O;{=T;R-DRkoO$6P_T?(^Jp&%O8DbIv_{V7FT+2=5K;E+0QaQNO~B7E0B`gTK;I z)VmZ<@$?iW;8!SSAf@#t$m?tfQh1AniR>4O4 z+tsf_aFD!HaKV?IDw!-5N+-*NGLG`QX4^Fs@?W9i5#?@wrBKD|g=&8_Z|J3k8o#cJ zYNB{!4aJ*|0OlBAY9t6A-t71A7Qew?>#rTu>7kUj`aOdN-d00hWQ97lw;g&rzSO(k zZ}it=AEAMF`pvw{&+#SxM!wWv3s5XjYgAwkTKF;ouc@A@r}$339O`Vd;|dmqF~Oh$ zaB}-i1S?K~K*}o#+y;PFh0re2LW|$il$#At^ZG%CuPzRm|7xi90JTPi+RA(U4KM>E z%%N7BV;+w|maiqVIfSzzGdj%I0d6<~;z)2O7}xvT_y+h#Y8d>Deswnf29+)jo>Rwc z=NnR8@ON)GcKHI&e` zM0>kGLlVz;b+UOO7z@s?1VvHSFNQ-a3tofFj6|-=?5Xg4OlHr9#h7e*b~P4`MuID{ zi4VQH7K+S=Wc|6d)s+xHC>e=GFNNo4)-GKL3Fm?rRzkA%RA@1{wi25PT@_`cn)$V| zs@bRJ3zyfVW*#IU@h;`31e&J=9Zw5PE9GaKD8HW94bg%DsBZF`6e_{aV%ng6K|SuR z!g?0#$x1+5KdxM^maGH&`N*9L2X7LbOM0(GcEOzZ)zEx+F+3kc zqI^~Z19@rLxuDEBuo#?=MTM)fIUHGx23|oDSc8HPyeg8+%sc}fw18>+fA9uG?@~Vc zf`_7f)T~xt@SQE}r2L_|b5j+0OKEJrVedM*XmmxIxZ;UJ{uLtU#v^kQf}rerR_URmxL z5EfR0LhLFjQzG1UJ!`S>irBq+Rj9;KD-V*kTS21%D#c?EZBSVYWvh^^T*l_QHnTO8 zwlzw&#)S8d?Qq8GmaI)sp0?FWw%W9jUUc&8xJ+G=;l>P<{py zKzU$^+Hyze^SPEj`n<{_DQdO|W%Io>X4parZU6~%t5NVL%MM#6uqZ??1zx!v6qZE6 z22Erm?2ZLEW^js_`4v$%g(G4t2&YTLF*#1C##VOHPu*UkT>m99?W^yAE+Ud_gZjrZ z_d7>6O?yu7Zu?(Pr3WWI8k|TCo=G_;lcq@lHX8K<%B!21@w#MFAP@;&3IzhPB@nn2 zU07Sev@H;LbuGA});I%!#jqg8R>F}`BpL_^NDi6B2`HPxfO$rY2{JSdMk3J|IUyo6 z#hoPJSyz_YN;uA?P;5E6AT(mPia-FS9t+P0g0YwozOWVxVZUbVf|wFg2rf)vgojT2 zE<|rp|H!r8T9CM78)r6S@zI^K`01@TQtYvRXllF7N=+v=rsC!C`JI9I@^*cSJ&`q0 zJp-BE!TUzDzWiR9XXjL+EHR&`x-lb_b!PQYm^D!D`kjTuz%G+GdLtrLbbWzsvu?^z za$l#@Ptgw;NIoD*1vkWqET#y{EYxGpAcSITLS#O-d7u#KL$+K367NzyATK;mM|8Z7 zXL@uIrk9Ga@UQpj%Ba(D4kL!SAa45wfH4hH$ebJZ?A$oY11yF7ZMuS@tP~u5<|})= zP^1Sno7>_5NDXsiynaXr@PHk74KrT7Yzaj|!qQdT!m@@H6$Cp#KhWmQTRq?<;x!hWVmd=}JH>Y++cV>4? zDO)2paKj$a+uoXab7tqIjTwo3I%6%}m?3*%9>&fETx<*vZbS6t%Wam3mZU{4@*8*^ zE-hSG2KACN`OSWddhr=}aX7*ciWqRzv8}QolJ8+XgU!z*)gff>}W+G)~p#&~J z18*VYlnP}4%fMSnolPhwIXlTYgo-7D*D2daqmjk%(wd;$c+bEHvLi^Y{Xm4=`=V@K zjf!Es5JlMnSHNmWh=oI<>{8QL1D8YLrRA77BeV0Nl@*y;y)Z905RyTuk*)J<An% zbhcv+i<^n?0IyqUffAt=qeB=Wg0e2MDjOpUWS0oY`DF8R8X5u~H^GY$P9MaZm1r;~ zAo7G0$?PcLA&g=j^Bf>^3R7jL#>Pi{XU_#jpPd;$GZ8rV(sN_N81@;*=rBgj80DGL zam-?bpv5pmZ&CNk-Pb3#CsS;F##xnOt1_={!TXiY6?p}H2#`xw~iml94b^tFcOR*gpS2eV6zROlC zSm4;QJ{${O5@pA0!Ikg=Ts0ttz+&Lt9a@l0m!rZffrYRzk33roW=P-`eeh?%VL*SO zeYB`Mp&JIhCTN5WMN$huyhv|qpISI3xcBd7<7qljp*ZXLkA#ZSk4yC|xPy@NyaMZIep9j9I1Rukjd^!ikn+F)= zZoCB~a@!E2qar4yFK}*qF`vx%(brFj;hsS9F!-zA8O^Ii(E_yh9lTO48=IOj5u;M2n*vbhhxi7wz3kvj4i?ukYYmO zFt_gPQ4Ww=5?E)uV^?A_eR*BK7F+B(3X;AFWZLKgbdzbz`q9T5N(IM>!Pi0yFx)Wb zwaLs%C?YfA2&h>rsAlP5CQBh<0_Fko1DTa<)zk_B#xsh`Ym1AZU>R4UOG}6Wi}$3= z0=tdK`pW|F0IyX*iB@H>cvpJ8adS(X#nu{8-Z3kTPk;-BNAu zt+J%&Nb*P|8GS8FQCH|u9eh^lQ}~I}WB9R;>%L%gX3IkhDRdRi*MeN4l7Z-3gkoNUu!2zb!vlSPWhJ0A2D1ujDWWu@oG0ZtxxWOAkY9KW z;laOn1|m324$9<_j7@v>LrIh89uZzkJe&HAsVr@(kxVsvwz`a~BGWXGX?ZeJ(|q5o zw>$ocF`6x~GfQi;rIf8Q$+{ImVGTg;TP)N<$zs{q;*0fl$}gBIaAF4n+#PK0?%-gV z*{`|T5sw4G$EvqYQ|`9OZ2*ucayLC-P;#IvWHnV@yVpRr02xOK#npI#13K5i@Fh4# zXot$|fl%!3M!inEpChJk&<|t`t)EX<46@IZZ-KH%#0+|)1^%t)(PBV?kDf&(`9RR< zv9De2fku$k+M&z>N3|8kuziCW7<@L6E{p?A9?!Y|+y=0wAtD+pJzHedS+rXm2-*q& z7Dp*+4OVgW7mtsL&(6~tE$99UF_=phTfbv)a}hmz5Uw(litBI=*CLes2DowurUcod zt6B`#B7MIK*YnNVSim@`Sy;D&L8GV7;xqfKKE`MA8GJ_GJY;UsBrbN5@>8((Q#N4&^-UoN zFk}X@0$xUnK?)DNA}tMJis(^x(p6I;y$)dws}3-DgEGD{j~u27HOOS1#!fgRnZ+AY zHY*n_8ZS`iRRt^cqu4M+?1Ir2{EI(@4SA6Je_By_eeU~nnJ16`&FepUJ!KmE(d)Y} z{BhZ}s;#QDnUl<1!hOfw{=?TL)7S$>=QeE2JTy>NS6qMHzHNWUnX+&jqkC5SwWDtz zjgMWQ*`7(2w}5~K(pYKb^^Wb1>)qSkiMm_nRO#@>RK{LWLTcxL(|3GmVGduen+Ce*Mk*TWz;nZo6+^PWDcGSe9y>PMWGxrh2%# zO6F*gBWt_To+l*F6KT(J$#eX!=j4_Xx_h|yhHnhNck;%`UGJ@PDbMhxMKV=^qW#|F zjmeLlhcd10Nn35o)^y)U)joyd6o5G@d5)$%Cne9xyPl^tFvo8kPkTBgPv`ELTd|bq zcn;9n8)rXuw(kQpLW7uStIkxm0=d$qtx{=gs#+dc(hb-bb0oX%Uz-wcB=@HpwVMf&&97|nwb8vb=eR)P}{VSG%n zcLAP8CJlAD3l6esz?uj<_YAU(;2~(Gn?+0xguE&ara!t^i@DJrdazIlog}b~2LM4f z4=bXdtH8l{@`zXrs*lmG8Bl9s=PMTjNOa74E03B5hqNN<+XXo13tTk9sT>7uL|vjx z2R?(DQC_ydM1#VU5XMr2@D_YP=|PHLxEg@>EF+Q3Vj>u!UKW0oRp?g}0#e-2&;o*d8movS3VtXFFZQ0Q;#BvN{CM z;LZS(4LzGP`Q%0iBMYfDJ)*Ck)S4erYaq22xYC&6IoS1QLHj_RgoN25f+jwN=l3C) zL>_^08$j?GKn!UF%&qm5&sxmfl)GN7vo3Yef)7tDu+wl!lf7=@*~iAwQh=lNQFJSU zV-cfL0gk8@9teshIESN|*FQFn_5vIokDw#5RUe=u+ENdM0vwCc(SbA7a6FHee-eC# z^Z5Egf+qIm%uTSPeTE)7VgNze?qjv`2|5S@J#bpbd01H`;}T z*wG6FFRMlc-Z^}lqE67-NI}4j!+ld^M#V%m-j#z?f{-FEpDPFJGVE(40to=qtEKZ_ zjfQ+Fo(TJC;WV$hA9Zy3|cby${{vV3iB#e%%SIVW@aN5 z1mz=QE5Hx-<-q{>nNhhaGzo*jw&4RrSpboXB9JIt*S`*QvXlxtJcZCqI7B2jK7iuPJC z>Mn-3%gZ5@PA*+qixAHcuqs~&#V&`Sbb4fTj01Bnj?&G&7~+HwHUyKBIzo_J2rn*% zgis{Lfr@Y#8lWu~Ez`LXLHBB#9UE6}@vCj!+)L3lE&`7`9Mr6dA+D<{kLZcQrr1ES z=UwGu%OP%lIVu8HFc62jNw;jt!(Fc#4}t(efU4pHr+^*i9+O=8XJ4*ON&jn-OB+!m zmzpm|t~{a&xeB^|J-L*|#mJ@9JT|%VkLp~TlK$5umo}nCE;V0_TzNzlausy_dU7d^ zg}32I5>Z{Yol(Wl5fDxVR9)7cqlK@`T-*AKgc6s6!Yi;LgCf^f*r$yH9}#XL8p4wp zL%kpd&pmL=)cxAceJdnH)eOA%tgOI!#DlWXO{@^>J<98mCI)jTB`eKAM&{7Qn}0z< zp&HD_#A192=RnNH)xzt@Vsv46ff4Yf8o0h{S@Je5f z66Uch^C9A0C+kCk5EWz#HdZ!~fMyA$Kp~)P2?I+~e1wP{LpF!t0X_r|^C96S9Q}_l z`WAL8ML4SHuRYSUXn+tvd7}iPCZVwz>?_+?kein>&F^E`I~c8C)Qb^{0|H)Y0^C#- zv5d%cNG?C1V8BB&s9Kd+DhBYSExHy%bwklKK*d_PBA07)Cxxm$EZg_tE35V+QAEP- zx0#TfVoA5dyFR{ngLBoDtLwroX!}$67yCfq+n^r0sM3n-?rryV&$eghrQP|sCyDP( znUbb-$ziGF@a~*c@`Rf2mP)#JuSg|Fi_|;0w6jBUcI-YOIr}$GXN)EBxsR};cJsylkqE(jr-J979V@ZvJ>2Prd@nu?ET4`lkd;moVj%()iUyT zFH0?amZ9v8{{&7rEqh#ZVl3TqTxvOd=6kDb)o-gC93UEPwadw1=w>&X2^%G3L>1s*e#PI$;{NxYh}w(o)7e)jFN$?C~F zwyFQ)Y|2#kZchG+yvQ`PU$bmk;zv`aTJrV+#y4kc##EZA?a9=1W=cKqHsdsGOu#bR zxI|^r+MZKX#ZZ{|EPu_7omFeu!h@W<~CtdBo z%j%)&?;g65!X1U-ZLv8`^hTlbUtx~U%B z2fcM?dv!nUwoDz?|MY14RGa>1ZF)%mtjBV;OaHT>_OtE!pSSBF{qtVSbcO!s!_H}! z{ueGir18aUWd)o;(Je@MIF^m;ffzao7x#?&5!?h|Z~(bX3BJ7AqAMuPxxGf1uY1>- zYt9?4;2jEY8WJAy3Y0T)!5foNYn%5n0pb7OfPiY4BW8x8#tRx%pt%=N@i~tV=_G(M zT*UOCDH;H^TU9cz6kv#!;sc?HQb#VLy7)|}Bx*3{#(m&xPt-5q-KdP9_>bRus$h(q z`jH=Kwj!U_Qs5Z~uUSKkYMLSMP2XU4KnL=2vpn~wggUzfFS~+|&w_p-UxO36)E)?+ zFu^+lC|Nnhqu7fVC`@{dCLH(&Ml9a|Z?lFsZ0O6+tZ^ z?SsNqY&+H~T7*9ZD$IZvie3PbXF-viX7fV3HZrtTm5Y^75RItd$P3<^pjj-T6DJ%< zO}#;tKm1hago+6mQ`WD8JS7T#NDEJ6^p{wwTNm-^XDt#7M4}OF zqCCVHz)Lu|vLx$Q#F+T1Y&{!}jI1n;p@buR8+-pbMBuCPLrnh&BH0YyFe;BQU=je+ zBQpJra1m?ii!!|~GcYaTKSI8^iwgKD5V{9cng_7?kD+-Z4PAyH0(rnem6RvTTT`w> z8)x?Hu6X75{2SwYCd)O`+orfTS<;d+wQh`P1=T5&2Xgk(_|RQ@-6oxJa+0$-?d+DE zC@DCPZqi%qUU}Wl(t8(gT)bQE-8!{7vgfW!yAOTjK9p=bDzzO;w@pfIlOG19wr5lB z=Qc;;C7TQ1n|Q!b6?jEE>a!-RzGKhfx)$CFUyE)AON#d zEieb&iE$e{*JjZ@%&yThzxfty;~Eoy=={iQ@hXKwe`TlS9O5*k)w8xcD1&Mp8?boqZdFf;xPfWPPK2&@VyVR zOReSdrCJL567wuTZqmx~U#Klr`?lmuwG`BrskMi+viuin%dvKe@>a|US0N>vK?Q{y zVga6?ETM&^P=Gia5+4At$t{N$;Z4p8rn6W;$&jai*@8VlLyQTc;+}0+ZCt9$xLwhc zRqq1XuG%QnT8H){1=T0vIv_&sPhkg9GOE?GLHYdxsGG{K6o_%qTQ1w?L93620!u=4 zZB_Ug_Wn6W#Bt#lnEIa>J%iEz!srx6NsJ_nKEf!4(f*NzJ6QHHMt3n9hX|f9kyj{} zopJdH{L}{g7?>qtRm9_k>_c%A_z+uRG=|KgncQg^OG(;NFInmnmAfUo%+Bj6OBeV{ zR(h_#vHb=(lUpUWJX2Nsp7VwiayE&rI8aw!b$x1k3hGR6PJ{n>y4ovMdm(F=Sa+t> zecii_Im4S%2cYchhcBL8k_SfKbV}gz{&o3UE?wIr)%I+R!mFsnR%iA>dvvCv`nrGH z52KjhoY~)JUriqEKblhBMjaqA<$uN??&FOpmJ;6h=U>SiHAxcm4)v-*PNbEVQ4@4d zk_U;|$08?%M+2r9r_DbyMT`a4oW2;x)JDK(O*~oq+85!W1??41Yb=IAUNQ#%H2vfH z+-r|Qa18^WB-il2FTf}FP&v1Sm`($?Bu^noVG>}ZBD}c(W5VAHAnJ>8>w@-#;*nnW03XQ1b8e{Wq2!b91nm?_gd&6fEjG{6;NcX z-@+(^(LZ4HUm%hV%3%YZymo;wP{qH(q-Euj{I7%pq;ORj#*TV|hwM<8wy4&Lc-bQt z)PiFB^AK5T|GCbIRG&f@KXS>A{ag~fEK-(U@P(?V-+3cl(Y<8`xxS_`(U>UPv2TKH z$(=51mdctF(_mPF-%qyo-g2k=#-+aTRO`g1DOu7enVR;v&fW3!;p5Wb<0Y`%5Mo8D^~t2I4!x8Z$}{+> z=TPFc_rG)VJ1Nfq0CR&w(_T$$V)6ZpH!r4Y`k<&{%ehzAzU%%^bw8*})jbI%Rp9=c zscuYHcT3gX$)1r^^)s6eG9r$Zu6a_bdGc<}5E-$lYd4f?8r+=0-?+hm!=FOQV!Hcj zsr%_v(+EJTyJKqjypD4B5N5g0GYQD2Fh`V{@VBxTfJx$y1bG_$ui+{ zW6Il2lt%b91c=<^x1%9ZKnAS*2Gs!WKwxJVki`fnZwf|?kS_|L!c&3;BH4tFj0iw6 z+Z!+&jtc(`>+~R=!0(VJFm1t~F?YwJ;9n@90f+onS2&HSGZ;-_^j(b3Vl<7>s~915 zlpmbcVT!mDI!5>1Ig0Ha?epYU6S5vP2~lzS7O^xAk<4Cz!HFT8Xs<*!Aa}liTbj_4 z5Svgc~CN6~QvQVa^!MZ>_(%zQJY;)(z&KyCF%HWg7Z#jeN;m8_lX+y?ZmSynU+M#^=k0{@fQ{;P3Z`mA-mv0VBdiO&;G|Zl&sV9$RD(df> z%jrt;Tf6HoZNIeb%rcO_SJ{@;!xv^y)0AjQh`VEnb*ZUesvgK1vD`$t%CguKTm=o~ zUsxe^zn=obcgOL}6T<{-7h*+}>MR3k0usJOL1MWA5>o)O>ay*WyCKujooVrAxYqmjb2RPEI7_n(Cdkm3fUyB> zBT1Ub%x=_Y%_M7~>YK7wlC)7Z^;rS|eo0`k0^OLnFK&^no}HJI)=o)(_zNd#7}0rX z4xAga41Rk@$QQ9}-dx-4kZ=*9B72;ndlN12D-*i+Rx|nBZh45ghqiLMHrwr@pP_F# UvJ`$lj6EdZ2Zw07pX}BD2l{rzMgRZ+ diff --git a/mace-bench/3rdparty/mace/mace/data/atomic_data.py b/mace-bench/3rdparty/mace/mace/data/atomic_data.py index 26ebe1b..14dd39d 100644 --- a/mace-bench/3rdparty/mace/mace/data/atomic_data.py +++ b/mace-bench/3rdparty/mace/mace/data/atomic_data.py @@ -1,300 +1,300 @@ -########################################################################################### -# Atomic Data Class for handling molecules as graphs -# Authors: Ilyes Batatia, Gregor Simm -# This program is distributed under the MIT License (see MIT.md) -########################################################################################### - -from copy import deepcopy -from typing import Optional, Sequence - -import torch.utils.data - -from mace.tools import ( - AtomicNumberTable, - atomic_numbers_to_indices, - to_one_hot, - torch_geometric, - voigt_to_matrix, -) - -from .neighborhood import get_neighborhood -from .utils import Configuration - - -class AtomicData(torch_geometric.data.Data): - num_graphs: torch.Tensor - batch: torch.Tensor - edge_index: torch.Tensor - node_attrs: torch.Tensor - edge_vectors: torch.Tensor - edge_lengths: torch.Tensor - positions: torch.Tensor - shifts: torch.Tensor - unit_shifts: torch.Tensor - cell: torch.Tensor - forces: torch.Tensor - energy: torch.Tensor - stress: torch.Tensor - virials: torch.Tensor - dipole: torch.Tensor - charges: torch.Tensor - weight: torch.Tensor - energy_weight: torch.Tensor - forces_weight: torch.Tensor - stress_weight: torch.Tensor - virials_weight: torch.Tensor - dipole_weight: torch.Tensor - charges_weight: torch.Tensor - - def __init__( - self, - edge_index: torch.Tensor, # [2, n_edges] - node_attrs: torch.Tensor, # [n_nodes, n_node_feats] - positions: torch.Tensor, # [n_nodes, 3] - shifts: torch.Tensor, # [n_edges, 3], - unit_shifts: torch.Tensor, # [n_edges, 3] - cell: Optional[torch.Tensor], # [3,3] - weight: Optional[torch.Tensor], # [,] - head: Optional[torch.Tensor], # [,] - energy_weight: Optional[torch.Tensor], # [,] - forces_weight: Optional[torch.Tensor], # [,] - stress_weight: Optional[torch.Tensor], # [,] - virials_weight: Optional[torch.Tensor], # [,] - dipole_weight: Optional[torch.Tensor], # [,] - charges_weight: Optional[torch.Tensor], # [,] - forces: Optional[torch.Tensor], # [n_nodes, 3] - energy: Optional[torch.Tensor], # [, ] - stress: Optional[torch.Tensor], # [1,3,3] - virials: Optional[torch.Tensor], # [1,3,3] - dipole: Optional[torch.Tensor], # [, 3] - charges: Optional[torch.Tensor], # [n_nodes, ] - ): - # Check shapes - num_nodes = node_attrs.shape[0] - - assert edge_index.shape[0] == 2 and len(edge_index.shape) == 2 - assert positions.shape == (num_nodes, 3) - assert shifts.shape[1] == 3 - assert unit_shifts.shape[1] == 3 - assert len(node_attrs.shape) == 2 - assert weight is None or len(weight.shape) == 0 - assert head is None or len(head.shape) == 0 - assert energy_weight is None or len(energy_weight.shape) == 0 - assert forces_weight is None or len(forces_weight.shape) == 0 - assert stress_weight is None or len(stress_weight.shape) == 0 - assert virials_weight is None or len(virials_weight.shape) == 0 - assert dipole_weight is None or dipole_weight.shape == (1, 3), dipole_weight - assert charges_weight is None or len(charges_weight.shape) == 0 - assert cell is None or cell.shape == (3, 3) - assert forces is None or forces.shape == (num_nodes, 3) - assert energy is None or len(energy.shape) == 0 - assert stress is None or stress.shape == (1, 3, 3) - assert virials is None or virials.shape == (1, 3, 3) - assert dipole is None or dipole.shape[-1] == 3 - assert charges is None or charges.shape == (num_nodes,) - # Aggregate data - data = { - "num_nodes": num_nodes, - "edge_index": edge_index, - "positions": positions, - "shifts": shifts, - "unit_shifts": unit_shifts, - "cell": cell, - "node_attrs": node_attrs, - "weight": weight, - "head": head, - "energy_weight": energy_weight, - "forces_weight": forces_weight, - "stress_weight": stress_weight, - "virials_weight": virials_weight, - "dipole_weight": dipole_weight, - "charges_weight": charges_weight, - "forces": forces, - "energy": energy, - "stress": stress, - "virials": virials, - "dipole": dipole, - "charges": charges, - } - super().__init__(**data) - - @classmethod - def from_config( - cls, - config: Configuration, - z_table: AtomicNumberTable, - cutoff: float, - heads: Optional[list] = None, - **kwargs, # pylint: disable=unused-argument - ) -> "AtomicData": - if heads is None: - heads = ["Default"] - edge_index, shifts, unit_shifts, cell = get_neighborhood( - positions=config.positions, - cutoff=cutoff, - pbc=deepcopy(config.pbc), - cell=deepcopy(config.cell), - ) - indices = atomic_numbers_to_indices(config.atomic_numbers, z_table=z_table) - one_hot = to_one_hot( - torch.tensor(indices, dtype=torch.long).unsqueeze(-1), - num_classes=len(z_table), - ) - try: - head = torch.tensor(heads.index(config.head), dtype=torch.long) - except ValueError: - head = torch.tensor(len(heads) - 1, dtype=torch.long) - - cell = ( - torch.tensor(cell, dtype=torch.get_default_dtype()) - if cell is not None - else torch.tensor( - 3 * [0.0, 0.0, 0.0], dtype=torch.get_default_dtype() - ).view(3, 3) - ) - - num_atoms = len(config.atomic_numbers) - - weight = ( - torch.tensor(config.weight, dtype=torch.get_default_dtype()) - if config.weight is not None - else torch.tensor(1.0, dtype=torch.get_default_dtype()) - ) - - energy_weight = ( - torch.tensor( - config.property_weights.get("energy"), dtype=torch.get_default_dtype() - ) - if config.property_weights.get("energy") is not None - else torch.tensor(1.0, dtype=torch.get_default_dtype()) - ) - - forces_weight = ( - torch.tensor( - config.property_weights.get("forces"), dtype=torch.get_default_dtype() - ) - if config.property_weights.get("forces") is not None - else torch.tensor(1.0, dtype=torch.get_default_dtype()) - ) - - stress_weight = ( - torch.tensor( - config.property_weights.get("stress"), dtype=torch.get_default_dtype() - ) - if config.property_weights.get("stress") is not None - else torch.tensor(1.0, dtype=torch.get_default_dtype()) - ) - - virials_weight = ( - torch.tensor( - config.property_weights.get("virials"), dtype=torch.get_default_dtype() - ) - if config.property_weights.get("virials") is not None - else torch.tensor(1.0, dtype=torch.get_default_dtype()) - ) - - dipole_weight = ( - torch.tensor( - config.property_weights.get("dipole"), dtype=torch.get_default_dtype() - ) - if config.property_weights.get("dipole") is not None - else torch.tensor([[1.0, 1.0, 1.0]], dtype=torch.get_default_dtype()) - ) - if len(dipole_weight.shape) == 0: - dipole_weight = dipole_weight * torch.tensor( - [[1.0, 1.0, 1.0]], dtype=torch.get_default_dtype() - ) - elif len(dipole_weight.shape) == 1: - dipole_weight = dipole_weight.unsqueeze(0) - - charges_weight = ( - torch.tensor( - config.property_weights.get("charges"), dtype=torch.get_default_dtype() - ) - if config.property_weights.get("charges") is not None - else torch.tensor(1.0, dtype=torch.get_default_dtype()) - ) - - forces = ( - torch.tensor( - config.properties.get("forces"), dtype=torch.get_default_dtype() - ) - if config.properties.get("forces") is not None - else torch.zeros(num_atoms, 3, dtype=torch.get_default_dtype()) - ) - energy = ( - torch.tensor( - config.properties.get("energy"), dtype=torch.get_default_dtype() - ) - if config.properties.get("energy") is not None - else torch.tensor(0.0, dtype=torch.get_default_dtype()) - ) - stress = ( - voigt_to_matrix( - torch.tensor( - config.properties.get("stress"), dtype=torch.get_default_dtype() - ) - ).unsqueeze(0) - if config.properties.get("stress") is not None - else torch.zeros(1, 3, 3, dtype=torch.get_default_dtype()) - ) - virials = ( - voigt_to_matrix( - torch.tensor( - config.properties.get("virials"), dtype=torch.get_default_dtype() - ) - ).unsqueeze(0) - if config.properties.get("virials") is not None - else torch.zeros(1, 3, 3, dtype=torch.get_default_dtype()) - ) - dipole = ( - torch.tensor( - config.properties.get("dipole"), dtype=torch.get_default_dtype() - ).unsqueeze(0) - if config.properties.get("dipole") is not None - else torch.zeros(1, 3, dtype=torch.get_default_dtype()) - ) - charges = ( - torch.tensor( - config.properties.get("charges"), dtype=torch.get_default_dtype() - ) - if config.properties.get("charges") is not None - else torch.zeros(num_atoms, dtype=torch.get_default_dtype()) - ) - - return cls( - edge_index=torch.tensor(edge_index, dtype=torch.long), - positions=torch.tensor(config.positions, dtype=torch.get_default_dtype()), - shifts=torch.tensor(shifts, dtype=torch.get_default_dtype()), - unit_shifts=torch.tensor(unit_shifts, dtype=torch.get_default_dtype()), - cell=cell, - node_attrs=one_hot, - weight=weight, - head=head, - energy_weight=energy_weight, - forces_weight=forces_weight, - stress_weight=stress_weight, - virials_weight=virials_weight, - dipole_weight=dipole_weight, - charges_weight=charges_weight, - forces=forces, - energy=energy, - stress=stress, - virials=virials, - dipole=dipole, - charges=charges, - ) - - -def get_data_loader( - dataset: Sequence[AtomicData], - batch_size: int, - shuffle=True, - drop_last=False, -) -> torch.utils.data.DataLoader: - return torch_geometric.dataloader.DataLoader( - dataset=dataset, - batch_size=batch_size, - shuffle=shuffle, - drop_last=drop_last, - ) +########################################################################################### +# Atomic Data Class for handling molecules as graphs +# Authors: Ilyes Batatia, Gregor Simm +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +from copy import deepcopy +from typing import Optional, Sequence + +import torch.utils.data + +from mace.tools import ( + AtomicNumberTable, + atomic_numbers_to_indices, + to_one_hot, + torch_geometric, + voigt_to_matrix, +) + +from .neighborhood import get_neighborhood +from .utils import Configuration + + +class AtomicData(torch_geometric.data.Data): + num_graphs: torch.Tensor + batch: torch.Tensor + edge_index: torch.Tensor + node_attrs: torch.Tensor + edge_vectors: torch.Tensor + edge_lengths: torch.Tensor + positions: torch.Tensor + shifts: torch.Tensor + unit_shifts: torch.Tensor + cell: torch.Tensor + forces: torch.Tensor + energy: torch.Tensor + stress: torch.Tensor + virials: torch.Tensor + dipole: torch.Tensor + charges: torch.Tensor + weight: torch.Tensor + energy_weight: torch.Tensor + forces_weight: torch.Tensor + stress_weight: torch.Tensor + virials_weight: torch.Tensor + dipole_weight: torch.Tensor + charges_weight: torch.Tensor + + def __init__( + self, + edge_index: torch.Tensor, # [2, n_edges] + node_attrs: torch.Tensor, # [n_nodes, n_node_feats] + positions: torch.Tensor, # [n_nodes, 3] + shifts: torch.Tensor, # [n_edges, 3], + unit_shifts: torch.Tensor, # [n_edges, 3] + cell: Optional[torch.Tensor], # [3,3] + weight: Optional[torch.Tensor], # [,] + head: Optional[torch.Tensor], # [,] + energy_weight: Optional[torch.Tensor], # [,] + forces_weight: Optional[torch.Tensor], # [,] + stress_weight: Optional[torch.Tensor], # [,] + virials_weight: Optional[torch.Tensor], # [,] + dipole_weight: Optional[torch.Tensor], # [,] + charges_weight: Optional[torch.Tensor], # [,] + forces: Optional[torch.Tensor], # [n_nodes, 3] + energy: Optional[torch.Tensor], # [, ] + stress: Optional[torch.Tensor], # [1,3,3] + virials: Optional[torch.Tensor], # [1,3,3] + dipole: Optional[torch.Tensor], # [, 3] + charges: Optional[torch.Tensor], # [n_nodes, ] + ): + # Check shapes + num_nodes = node_attrs.shape[0] + + assert edge_index.shape[0] == 2 and len(edge_index.shape) == 2 + assert positions.shape == (num_nodes, 3) + assert shifts.shape[1] == 3 + assert unit_shifts.shape[1] == 3 + assert len(node_attrs.shape) == 2 + assert weight is None or len(weight.shape) == 0 + assert head is None or len(head.shape) == 0 + assert energy_weight is None or len(energy_weight.shape) == 0 + assert forces_weight is None or len(forces_weight.shape) == 0 + assert stress_weight is None or len(stress_weight.shape) == 0 + assert virials_weight is None or len(virials_weight.shape) == 0 + assert dipole_weight is None or dipole_weight.shape == (1, 3), dipole_weight + assert charges_weight is None or len(charges_weight.shape) == 0 + assert cell is None or cell.shape == (3, 3) + assert forces is None or forces.shape == (num_nodes, 3) + assert energy is None or len(energy.shape) == 0 + assert stress is None or stress.shape == (1, 3, 3) + assert virials is None or virials.shape == (1, 3, 3) + assert dipole is None or dipole.shape[-1] == 3 + assert charges is None or charges.shape == (num_nodes,) + # Aggregate data + data = { + "num_nodes": num_nodes, + "edge_index": edge_index, + "positions": positions, + "shifts": shifts, + "unit_shifts": unit_shifts, + "cell": cell, + "node_attrs": node_attrs, + "weight": weight, + "head": head, + "energy_weight": energy_weight, + "forces_weight": forces_weight, + "stress_weight": stress_weight, + "virials_weight": virials_weight, + "dipole_weight": dipole_weight, + "charges_weight": charges_weight, + "forces": forces, + "energy": energy, + "stress": stress, + "virials": virials, + "dipole": dipole, + "charges": charges, + } + super().__init__(**data) + + @classmethod + def from_config( + cls, + config: Configuration, + z_table: AtomicNumberTable, + cutoff: float, + heads: Optional[list] = None, + **kwargs, # pylint: disable=unused-argument + ) -> "AtomicData": + if heads is None: + heads = ["Default"] + edge_index, shifts, unit_shifts, cell = get_neighborhood( + positions=config.positions, + cutoff=cutoff, + pbc=deepcopy(config.pbc), + cell=deepcopy(config.cell), + ) + indices = atomic_numbers_to_indices(config.atomic_numbers, z_table=z_table) + one_hot = to_one_hot( + torch.tensor(indices, dtype=torch.long).unsqueeze(-1), + num_classes=len(z_table), + ) + try: + head = torch.tensor(heads.index(config.head), dtype=torch.long) + except ValueError: + head = torch.tensor(len(heads) - 1, dtype=torch.long) + + cell = ( + torch.tensor(cell, dtype=torch.get_default_dtype()) + if cell is not None + else torch.tensor( + 3 * [0.0, 0.0, 0.0], dtype=torch.get_default_dtype() + ).view(3, 3) + ) + + num_atoms = len(config.atomic_numbers) + + weight = ( + torch.tensor(config.weight, dtype=torch.get_default_dtype()) + if config.weight is not None + else torch.tensor(1.0, dtype=torch.get_default_dtype()) + ) + + energy_weight = ( + torch.tensor( + config.property_weights.get("energy"), dtype=torch.get_default_dtype() + ) + if config.property_weights.get("energy") is not None + else torch.tensor(1.0, dtype=torch.get_default_dtype()) + ) + + forces_weight = ( + torch.tensor( + config.property_weights.get("forces"), dtype=torch.get_default_dtype() + ) + if config.property_weights.get("forces") is not None + else torch.tensor(1.0, dtype=torch.get_default_dtype()) + ) + + stress_weight = ( + torch.tensor( + config.property_weights.get("stress"), dtype=torch.get_default_dtype() + ) + if config.property_weights.get("stress") is not None + else torch.tensor(1.0, dtype=torch.get_default_dtype()) + ) + + virials_weight = ( + torch.tensor( + config.property_weights.get("virials"), dtype=torch.get_default_dtype() + ) + if config.property_weights.get("virials") is not None + else torch.tensor(1.0, dtype=torch.get_default_dtype()) + ) + + dipole_weight = ( + torch.tensor( + config.property_weights.get("dipole"), dtype=torch.get_default_dtype() + ) + if config.property_weights.get("dipole") is not None + else torch.tensor([[1.0, 1.0, 1.0]], dtype=torch.get_default_dtype()) + ) + if len(dipole_weight.shape) == 0: + dipole_weight = dipole_weight * torch.tensor( + [[1.0, 1.0, 1.0]], dtype=torch.get_default_dtype() + ) + elif len(dipole_weight.shape) == 1: + dipole_weight = dipole_weight.unsqueeze(0) + + charges_weight = ( + torch.tensor( + config.property_weights.get("charges"), dtype=torch.get_default_dtype() + ) + if config.property_weights.get("charges") is not None + else torch.tensor(1.0, dtype=torch.get_default_dtype()) + ) + + forces = ( + torch.tensor( + config.properties.get("forces"), dtype=torch.get_default_dtype() + ) + if config.properties.get("forces") is not None + else torch.zeros(num_atoms, 3, dtype=torch.get_default_dtype()) + ) + energy = ( + torch.tensor( + config.properties.get("energy"), dtype=torch.get_default_dtype() + ) + if config.properties.get("energy") is not None + else torch.tensor(0.0, dtype=torch.get_default_dtype()) + ) + stress = ( + voigt_to_matrix( + torch.tensor( + config.properties.get("stress"), dtype=torch.get_default_dtype() + ) + ).unsqueeze(0) + if config.properties.get("stress") is not None + else torch.zeros(1, 3, 3, dtype=torch.get_default_dtype()) + ) + virials = ( + voigt_to_matrix( + torch.tensor( + config.properties.get("virials"), dtype=torch.get_default_dtype() + ) + ).unsqueeze(0) + if config.properties.get("virials") is not None + else torch.zeros(1, 3, 3, dtype=torch.get_default_dtype()) + ) + dipole = ( + torch.tensor( + config.properties.get("dipole"), dtype=torch.get_default_dtype() + ).unsqueeze(0) + if config.properties.get("dipole") is not None + else torch.zeros(1, 3, dtype=torch.get_default_dtype()) + ) + charges = ( + torch.tensor( + config.properties.get("charges"), dtype=torch.get_default_dtype() + ) + if config.properties.get("charges") is not None + else torch.zeros(num_atoms, dtype=torch.get_default_dtype()) + ) + + return cls( + edge_index=torch.tensor(edge_index, dtype=torch.long), + positions=torch.tensor(config.positions, dtype=torch.get_default_dtype()), + shifts=torch.tensor(shifts, dtype=torch.get_default_dtype()), + unit_shifts=torch.tensor(unit_shifts, dtype=torch.get_default_dtype()), + cell=cell, + node_attrs=one_hot, + weight=weight, + head=head, + energy_weight=energy_weight, + forces_weight=forces_weight, + stress_weight=stress_weight, + virials_weight=virials_weight, + dipole_weight=dipole_weight, + charges_weight=charges_weight, + forces=forces, + energy=energy, + stress=stress, + virials=virials, + dipole=dipole, + charges=charges, + ) + + +def get_data_loader( + dataset: Sequence[AtomicData], + batch_size: int, + shuffle=True, + drop_last=False, +) -> torch.utils.data.DataLoader: + return torch_geometric.dataloader.DataLoader( + dataset=dataset, + batch_size=batch_size, + shuffle=shuffle, + drop_last=drop_last, + ) diff --git a/mace-bench/3rdparty/mace/mace/data/hdf5_dataset.py b/mace-bench/3rdparty/mace/mace/data/hdf5_dataset.py index b374885..ab6aa7c 100644 --- a/mace-bench/3rdparty/mace/mace/data/hdf5_dataset.py +++ b/mace-bench/3rdparty/mace/mace/data/hdf5_dataset.py @@ -1,97 +1,97 @@ -from glob import glob -from typing import List - -import h5py -from torch.utils.data import ConcatDataset, Dataset - -from mace.data.atomic_data import AtomicData -from mace.data.utils import Configuration -from mace.tools.utils import AtomicNumberTable - - -class HDF5Dataset(Dataset): - def __init__( - self, file_path, r_max, z_table, atomic_dataclass=AtomicData, **kwargs - ): - super(HDF5Dataset, self).__init__() # pylint: disable=super-with-arguments - self.file_path = file_path - self._file = None - batch_key = list(self.file.keys())[0] - self.batch_size = len(self.file[batch_key].keys()) - self.length = len(self.file.keys()) * self.batch_size - self.r_max = r_max - self.z_table = z_table - self.atomic_dataclass = atomic_dataclass - try: - self.drop_last = bool(self.file.attrs["drop_last"]) - except KeyError: - self.drop_last = False - self.kwargs = kwargs - - @property - def file(self): - if self._file is None: - # If a file has not already been opened, open one here - self._file = h5py.File(self.file_path, "r") - return self._file - - def __getstate__(self): - _d = dict(self.__dict__) - - # An opened h5py.File cannot be pickled, so we must exclude it from the state - _d["_file"] = None - return _d - - def __len__(self): - return self.length - - def __getitem__(self, index): - # compute the index of the batch - batch_index = index // self.batch_size - config_index = index % self.batch_size - grp = self.file["config_batch_" + str(batch_index)] - subgrp = grp["config_" + str(config_index)] - - properties = {} - property_weights = {} - for key in subgrp["properties"]: - properties[key] = unpack_value(subgrp["properties"][key][()]) - for key in subgrp["property_weights"]: - property_weights[key] = unpack_value(subgrp["property_weights"][key][()]) - - config = Configuration( - atomic_numbers=subgrp["atomic_numbers"][()], - positions=subgrp["positions"][()], - properties=properties, - weight=unpack_value(subgrp["weight"][()]), - property_weights=property_weights, - config_type=unpack_value(subgrp["config_type"][()]), - pbc=unpack_value(subgrp["pbc"][()]), - cell=unpack_value(subgrp["cell"][()]), - ) - if config.head is None: - config.head = self.kwargs.get("head") - atomic_data = self.atomic_dataclass.from_config( - config, - z_table=self.z_table, - cutoff=self.r_max, - heads=self.kwargs.get("heads", ["Default"]), - **{k: v for k, v in self.kwargs.items() if k != "heads"}, - ) - return atomic_data - - -def dataset_from_sharded_hdf5( - files: List, z_table: AtomicNumberTable, r_max: float, **kwargs -): - files = glob(files + "/*") - datasets = [] - for file in files: - datasets.append(HDF5Dataset(file, z_table=z_table, r_max=r_max, **kwargs)) - full_dataset = ConcatDataset(datasets) - return full_dataset - - -def unpack_value(value): - value = value.decode("utf-8") if isinstance(value, bytes) else value - return None if str(value) == "None" else value +from glob import glob +from typing import List + +import h5py +from torch.utils.data import ConcatDataset, Dataset + +from mace.data.atomic_data import AtomicData +from mace.data.utils import Configuration +from mace.tools.utils import AtomicNumberTable + + +class HDF5Dataset(Dataset): + def __init__( + self, file_path, r_max, z_table, atomic_dataclass=AtomicData, **kwargs + ): + super(HDF5Dataset, self).__init__() # pylint: disable=super-with-arguments + self.file_path = file_path + self._file = None + batch_key = list(self.file.keys())[0] + self.batch_size = len(self.file[batch_key].keys()) + self.length = len(self.file.keys()) * self.batch_size + self.r_max = r_max + self.z_table = z_table + self.atomic_dataclass = atomic_dataclass + try: + self.drop_last = bool(self.file.attrs["drop_last"]) + except KeyError: + self.drop_last = False + self.kwargs = kwargs + + @property + def file(self): + if self._file is None: + # If a file has not already been opened, open one here + self._file = h5py.File(self.file_path, "r") + return self._file + + def __getstate__(self): + _d = dict(self.__dict__) + + # An opened h5py.File cannot be pickled, so we must exclude it from the state + _d["_file"] = None + return _d + + def __len__(self): + return self.length + + def __getitem__(self, index): + # compute the index of the batch + batch_index = index // self.batch_size + config_index = index % self.batch_size + grp = self.file["config_batch_" + str(batch_index)] + subgrp = grp["config_" + str(config_index)] + + properties = {} + property_weights = {} + for key in subgrp["properties"]: + properties[key] = unpack_value(subgrp["properties"][key][()]) + for key in subgrp["property_weights"]: + property_weights[key] = unpack_value(subgrp["property_weights"][key][()]) + + config = Configuration( + atomic_numbers=subgrp["atomic_numbers"][()], + positions=subgrp["positions"][()], + properties=properties, + weight=unpack_value(subgrp["weight"][()]), + property_weights=property_weights, + config_type=unpack_value(subgrp["config_type"][()]), + pbc=unpack_value(subgrp["pbc"][()]), + cell=unpack_value(subgrp["cell"][()]), + ) + if config.head is None: + config.head = self.kwargs.get("head") + atomic_data = self.atomic_dataclass.from_config( + config, + z_table=self.z_table, + cutoff=self.r_max, + heads=self.kwargs.get("heads", ["Default"]), + **{k: v for k, v in self.kwargs.items() if k != "heads"}, + ) + return atomic_data + + +def dataset_from_sharded_hdf5( + files: List, z_table: AtomicNumberTable, r_max: float, **kwargs +): + files = glob(files + "/*") + datasets = [] + for file in files: + datasets.append(HDF5Dataset(file, z_table=z_table, r_max=r_max, **kwargs)) + full_dataset = ConcatDataset(datasets) + return full_dataset + + +def unpack_value(value): + value = value.decode("utf-8") if isinstance(value, bytes) else value + return None if str(value) == "None" else value diff --git a/mace-bench/3rdparty/mace/mace/data/lmdb_dataset.py b/mace-bench/3rdparty/mace/mace/data/lmdb_dataset.py index 342b179..e1ebbda 100644 --- a/mace-bench/3rdparty/mace/mace/data/lmdb_dataset.py +++ b/mace-bench/3rdparty/mace/mace/data/lmdb_dataset.py @@ -1,69 +1,69 @@ -import os - -import numpy as np -from torch.utils.data import Dataset - -from mace.data.atomic_data import AtomicData -from mace.data.utils import KeySpecification, config_from_atoms -from mace.tools.default_keys import DefaultKeys -from mace.tools.fairchem_dataset import AseDBDataset - - -class LMDBDataset(Dataset): - def __init__(self, file_path, r_max, z_table, **kwargs): - dataset_paths = file_path.split(":") # using : split multiple paths - # make sure each of the path exist - for path in dataset_paths: - assert os.path.exists(path) - config_kwargs = {} - super(LMDBDataset, self).__init__() # pylint: disable=super-with-arguments - self.AseDB = AseDBDataset(config=dict(src=dataset_paths, **config_kwargs)) - self.r_max = r_max - self.z_table = z_table - - self.kwargs = kwargs - self.transform = kwargs["transform"] if "transform" in kwargs else None - - def __len__(self): - return len(self.AseDB) - - def __getitem__(self, index): - try: - atoms = self.AseDB.get_atoms(self.AseDB.ids[index]) - except Exception as e: # pylint: disable=broad-except - print(f"Error in index {index}") - print(e) - return None - assert np.sum(atoms.get_cell() == atoms.cell) == 9 - - if hasattr(atoms, "calc") and hasattr(atoms.calc, "results"): - if "energy" in atoms.calc.results: - atoms.info[DefaultKeys.ENERGY.value] = atoms.calc.results["energy"] - if "forces" in atoms.calc.results: - atoms.arrays[DefaultKeys.FORCES.value] = atoms.calc.results["forces"] - if "stress" in atoms.calc.results: - atoms.info[DefaultKeys.STRESS.value] = atoms.calc.results["stress"] - - config = config_from_atoms( - atoms, - key_specification=KeySpecification.from_defaults(), - ) - - # Set head if not already set - if config.head == "Default": - config.head = self.kwargs.get("head", "Default") - - try: - atomic_data = AtomicData.from_config( - config, - z_table=self.z_table, - cutoff=self.r_max, - heads=self.kwargs.get("heads", ["Default"]), - ) - except Exception as e: # pylint: disable=broad-except - print(f"Error in index {index}") - print(e) - - if self.transform: - atomic_data = self.transform(atomic_data) - return atomic_data +import os + +import numpy as np +from torch.utils.data import Dataset + +from mace.data.atomic_data import AtomicData +from mace.data.utils import KeySpecification, config_from_atoms +from mace.tools.default_keys import DefaultKeys +from mace.tools.fairchem_dataset import AseDBDataset + + +class LMDBDataset(Dataset): + def __init__(self, file_path, r_max, z_table, **kwargs): + dataset_paths = file_path.split(":") # using : split multiple paths + # make sure each of the path exist + for path in dataset_paths: + assert os.path.exists(path) + config_kwargs = {} + super(LMDBDataset, self).__init__() # pylint: disable=super-with-arguments + self.AseDB = AseDBDataset(config=dict(src=dataset_paths, **config_kwargs)) + self.r_max = r_max + self.z_table = z_table + + self.kwargs = kwargs + self.transform = kwargs["transform"] if "transform" in kwargs else None + + def __len__(self): + return len(self.AseDB) + + def __getitem__(self, index): + try: + atoms = self.AseDB.get_atoms(self.AseDB.ids[index]) + except Exception as e: # pylint: disable=broad-except + print(f"Error in index {index}") + print(e) + return None + assert np.sum(atoms.get_cell() == atoms.cell) == 9 + + if hasattr(atoms, "calc") and hasattr(atoms.calc, "results"): + if "energy" in atoms.calc.results: + atoms.info[DefaultKeys.ENERGY.value] = atoms.calc.results["energy"] + if "forces" in atoms.calc.results: + atoms.arrays[DefaultKeys.FORCES.value] = atoms.calc.results["forces"] + if "stress" in atoms.calc.results: + atoms.info[DefaultKeys.STRESS.value] = atoms.calc.results["stress"] + + config = config_from_atoms( + atoms, + key_specification=KeySpecification.from_defaults(), + ) + + # Set head if not already set + if config.head == "Default": + config.head = self.kwargs.get("head", "Default") + + try: + atomic_data = AtomicData.from_config( + config, + z_table=self.z_table, + cutoff=self.r_max, + heads=self.kwargs.get("heads", ["Default"]), + ) + except Exception as e: # pylint: disable=broad-except + print(f"Error in index {index}") + print(e) + + if self.transform: + atomic_data = self.transform(atomic_data) + return atomic_data diff --git a/mace-bench/3rdparty/mace/mace/data/neighborhood.py b/mace-bench/3rdparty/mace/mace/data/neighborhood.py index cd46352..0372896 100644 --- a/mace-bench/3rdparty/mace/mace/data/neighborhood.py +++ b/mace-bench/3rdparty/mace/mace/data/neighborhood.py @@ -1,66 +1,66 @@ -from typing import Optional, Tuple - -import numpy as np -from matscipy.neighbours import neighbour_list - - -def get_neighborhood( - positions: np.ndarray, # [num_positions, 3] - cutoff: float, - pbc: Optional[Tuple[bool, bool, bool]] = None, - cell: Optional[np.ndarray] = None, # [3, 3] - true_self_interaction=False, -) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: - if pbc is None: - pbc = (False, False, False) - - if cell is None or cell.any() == np.zeros((3, 3)).any(): - cell = np.identity(3, dtype=float) - - assert len(pbc) == 3 and all(isinstance(i, (bool, np.bool_)) for i in pbc) - assert cell.shape == (3, 3) - - pbc_x = pbc[0] - pbc_y = pbc[1] - pbc_z = pbc[2] - identity = np.identity(3, dtype=float) - max_positions = np.max(np.absolute(positions)) + 1 - # Extend cell in non-periodic directions - # For models with more than 5 layers, the multiplicative constant needs to be increased. - # temp_cell = np.copy(cell) - if not pbc_x: - cell[0, :] = max_positions * 5 * cutoff * identity[0, :] - if not pbc_y: - cell[1, :] = max_positions * 5 * cutoff * identity[1, :] - if not pbc_z: - cell[2, :] = max_positions * 5 * cutoff * identity[2, :] - - sender, receiver, unit_shifts = neighbour_list( - quantities="ijS", - pbc=pbc, - cell=cell, - positions=positions, - cutoff=cutoff, - # self_interaction=True, # we want edges from atom to itself in different periodic images - # use_scaled_positions=False, # positions are not scaled positions - ) - - if not true_self_interaction: - # Eliminate self-edges that don't cross periodic boundaries - true_self_edge = sender == receiver - true_self_edge &= np.all(unit_shifts == 0, axis=1) - keep_edge = ~true_self_edge - - # Note: after eliminating self-edges, it can be that no edges remain in this system - sender = sender[keep_edge] - receiver = receiver[keep_edge] - unit_shifts = unit_shifts[keep_edge] - - # Build output - edge_index = np.stack((sender, receiver)) # [2, n_edges] - - # From the docs: With the shift vector S, the distances D between atoms can be computed from - # D = positions[j]-positions[i]+S.dot(cell) - shifts = np.dot(unit_shifts, cell) # [n_edges, 3] - - return edge_index, shifts, unit_shifts, cell +from typing import Optional, Tuple + +import numpy as np +from matscipy.neighbours import neighbour_list + + +def get_neighborhood( + positions: np.ndarray, # [num_positions, 3] + cutoff: float, + pbc: Optional[Tuple[bool, bool, bool]] = None, + cell: Optional[np.ndarray] = None, # [3, 3] + true_self_interaction=False, +) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + if pbc is None: + pbc = (False, False, False) + + if cell is None or cell.any() == np.zeros((3, 3)).any(): + cell = np.identity(3, dtype=float) + + assert len(pbc) == 3 and all(isinstance(i, (bool, np.bool_)) for i in pbc) + assert cell.shape == (3, 3) + + pbc_x = pbc[0] + pbc_y = pbc[1] + pbc_z = pbc[2] + identity = np.identity(3, dtype=float) + max_positions = np.max(np.absolute(positions)) + 1 + # Extend cell in non-periodic directions + # For models with more than 5 layers, the multiplicative constant needs to be increased. + # temp_cell = np.copy(cell) + if not pbc_x: + cell[0, :] = max_positions * 5 * cutoff * identity[0, :] + if not pbc_y: + cell[1, :] = max_positions * 5 * cutoff * identity[1, :] + if not pbc_z: + cell[2, :] = max_positions * 5 * cutoff * identity[2, :] + + sender, receiver, unit_shifts = neighbour_list( + quantities="ijS", + pbc=pbc, + cell=cell, + positions=positions, + cutoff=cutoff, + # self_interaction=True, # we want edges from atom to itself in different periodic images + # use_scaled_positions=False, # positions are not scaled positions + ) + + if not true_self_interaction: + # Eliminate self-edges that don't cross periodic boundaries + true_self_edge = sender == receiver + true_self_edge &= np.all(unit_shifts == 0, axis=1) + keep_edge = ~true_self_edge + + # Note: after eliminating self-edges, it can be that no edges remain in this system + sender = sender[keep_edge] + receiver = receiver[keep_edge] + unit_shifts = unit_shifts[keep_edge] + + # Build output + edge_index = np.stack((sender, receiver)) # [2, n_edges] + + # From the docs: With the shift vector S, the distances D between atoms can be computed from + # D = positions[j]-positions[i]+S.dot(cell) + shifts = np.dot(unit_shifts, cell) # [n_edges, 3] + + return edge_index, shifts, unit_shifts, cell diff --git a/mace-bench/3rdparty/mace/mace/data/utils.py b/mace-bench/3rdparty/mace/mace/data/utils.py index 6afa3bb..947ee60 100644 --- a/mace-bench/3rdparty/mace/mace/data/utils.py +++ b/mace-bench/3rdparty/mace/mace/data/utils.py @@ -1,368 +1,368 @@ -########################################################################################### -# Data parsing utilities -# Authors: Ilyes Batatia, Gregor Simm and David Kovacs -# This program is distributed under the MIT License (see MIT.md) -########################################################################################### - -import logging -from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional, Sequence, Tuple - -import ase.data -import ase.io -import h5py -import numpy as np - -from mace.tools import AtomicNumberTable, DefaultKeys - -Positions = np.ndarray # [..., 3] -Cell = np.ndarray # [3,3] -Pbc = tuple # (3,) - -DEFAULT_CONFIG_TYPE = "Default" -DEFAULT_CONFIG_TYPE_WEIGHTS = {DEFAULT_CONFIG_TYPE: 1.0} - - -@dataclass -class KeySpecification: - info_keys: Dict[str, str] = field(default_factory=dict) - arrays_keys: Dict[str, str] = field(default_factory=dict) - - def update( - self, - info_keys: Optional[Dict[str, str]] = None, - arrays_keys: Optional[Dict[str, str]] = None, - ): - if info_keys is not None: - self.info_keys.update(info_keys) - if arrays_keys is not None: - self.arrays_keys.update(arrays_keys) - return self - - @classmethod - def from_defaults(cls): - instance = cls() - return update_keyspec_from_kwargs(instance, DefaultKeys.keydict()) - - -def update_keyspec_from_kwargs( - keyspec: KeySpecification, keydict: Dict[str, str] -) -> KeySpecification: - # convert command line style property_key arguments into a keyspec - infos = ["energy_key", "stress_key", "virials_key", "dipole_key", "head_key"] - arrays = ["forces_key", "charges_key"] - info_keys = {} - arrays_keys = {} - for key in infos: - if key in keydict: - info_keys[key[:-4]] = keydict[key] - for key in arrays: - if key in keydict: - arrays_keys[key[:-4]] = keydict[key] - keyspec.update(info_keys=info_keys, arrays_keys=arrays_keys) - return keyspec - - -@dataclass -class Configuration: - atomic_numbers: np.ndarray - positions: Positions # Angstrom - properties: Dict[str, Any] - property_weights: Dict[str, float] - cell: Optional[Cell] = None - pbc: Optional[Pbc] = None - - weight: float = 1.0 # weight of config in loss - config_type: str = DEFAULT_CONFIG_TYPE # config_type of config - head: str = "Default" # head used to compute the config - - -Configurations = List[Configuration] - - -def random_train_valid_split( - items: Sequence, valid_fraction: float, seed: int, work_dir: str -) -> Tuple[List, List]: - assert 0.0 < valid_fraction < 1.0 - - size = len(items) - train_size = size - int(valid_fraction * size) - - indices = list(range(size)) - rng = np.random.default_rng(seed) - rng.shuffle(indices) - if len(indices[train_size:]) < 10: - logging.info( - f"Using random {100 * valid_fraction:.0f}% of training set for validation with following indices: {indices[train_size:]}" - ) - else: - # Save indices to file - with open(work_dir + f"/valid_indices_{seed}.txt", "w", encoding="utf-8") as f: - for index in indices[train_size:]: - f.write(f"{index}\n") - - logging.info( - f"Using random {100 * valid_fraction:.0f}% of training set for validation with indices saved in: {work_dir}/valid_indices_{seed}.txt" - ) - - return ( - [items[i] for i in indices[:train_size]], - [items[i] for i in indices[train_size:]], - ) - - -def config_from_atoms_list( - atoms_list: List[ase.Atoms], - key_specification: KeySpecification, - config_type_weights: Optional[Dict[str, float]] = None, - head_name: str = "Default", -) -> Configurations: - """Convert list of ase.Atoms into Configurations""" - if config_type_weights is None: - config_type_weights = DEFAULT_CONFIG_TYPE_WEIGHTS - - all_configs = [] - for atoms in atoms_list: - all_configs.append( - config_from_atoms( - atoms, - key_specification=key_specification, - config_type_weights=config_type_weights, - head_name=head_name, - ) - ) - return all_configs - - -def config_from_atoms( - atoms: ase.Atoms, - key_specification: KeySpecification = KeySpecification(), - config_type_weights: Optional[Dict[str, float]] = None, - head_name: str = "Default", -) -> Configuration: - """Convert ase.Atoms to Configuration""" - if config_type_weights is None: - config_type_weights = DEFAULT_CONFIG_TYPE_WEIGHTS - - atomic_numbers = np.array( - [ase.data.atomic_numbers[symbol] for symbol in atoms.symbols] - ) - pbc = tuple(atoms.get_pbc()) - cell = np.array(atoms.get_cell()) - config_type = atoms.info.get("config_type", "Default") - weight = atoms.info.get("config_weight", 1.0) * config_type_weights.get( - config_type, 1.0 - ) - - properties = {} - property_weights = {} - for name in list(key_specification.arrays_keys) + list(key_specification.info_keys): - property_weights[name] = atoms.info.get(f"config_{name}_weight", 1.0) - - for name, atoms_key in key_specification.info_keys.items(): - properties[name] = atoms.info.get(atoms_key, None) - if not atoms_key in atoms.info: - property_weights[name] = 0.0 - - for name, atoms_key in key_specification.arrays_keys.items(): - properties[name] = atoms.arrays.get(atoms_key, None) - if not atoms_key in atoms.arrays: - property_weights[name] = 0.0 - - return Configuration( - atomic_numbers=atomic_numbers, - positions=atoms.get_positions(), - properties=properties, - weight=weight, - property_weights=property_weights, - head=head_name, - config_type=config_type, - pbc=pbc, - cell=cell, - ) - - -def test_config_types( - test_configs: Configurations, -) -> List[Tuple[str, List[Configuration]]]: - """Split test set based on config_type-s""" - test_by_ct = [] - all_cts = [] - for conf in test_configs: - config_type_name = conf.config_type + "_" + conf.head - if config_type_name not in all_cts: - all_cts.append(config_type_name) - test_by_ct.append((config_type_name, [conf])) - else: - ind = all_cts.index(config_type_name) - test_by_ct[ind][1].append(conf) - return test_by_ct - - -def load_from_xyz( - file_path: str, - key_specification: KeySpecification, - head_name: str = "Default", - config_type_weights: Optional[Dict] = None, - extract_atomic_energies: bool = False, - keep_isolated_atoms: bool = False, -) -> Tuple[Dict[int, float], Configurations]: - atoms_list = ase.io.read(file_path, index=":") - energy_key = key_specification.info_keys["energy"] - forces_key = key_specification.arrays_keys["forces"] - stress_key = key_specification.info_keys["stress"] - head_key = key_specification.info_keys["head"] - if energy_key == "energy": - logging.warning( - "Since ASE version 3.23.0b1, using energy_key 'energy' is no longer safe when communicating between MACE and ASE. We recommend using a different key, rewriting 'energy' to 'REF_energy'. You need to use --energy_key='REF_energy' to specify the chosen key name." - ) - key_specification.info_keys["energy"] = "REF_energy" - for atoms in atoms_list: - try: - atoms.info["REF_energy"] = atoms.get_potential_energy() - except Exception as e: # pylint: disable=W0703 - logging.error(f"Failed to extract energy: {e}") - atoms.info["REF_energy"] = None - if forces_key == "forces": - logging.warning( - "Since ASE version 3.23.0b1, using forces_key 'forces' is no longer safe when communicating between MACE and ASE. We recommend using a different key, rewriting 'forces' to 'REF_forces'. You need to use --forces_key='REF_forces' to specify the chosen key name." - ) - key_specification.arrays_keys["forces"] = "REF_forces" - for atoms in atoms_list: - try: - atoms.arrays["REF_forces"] = atoms.get_forces() - except Exception as e: # pylint: disable=W0703 - logging.error(f"Failed to extract forces: {e}") - atoms.arrays["REF_forces"] = None - if stress_key == "stress": - logging.warning( - "Since ASE version 3.23.0b1, using stress_key 'stress' is no longer safe when communicating between MACE and ASE. We recommend using a different key, rewriting 'stress' to 'REF_stress'. You need to use --stress_key='REF_stress' to specify the chosen key name." - ) - key_specification.info_keys["stress"] = "REF_stress" - for atoms in atoms_list: - try: - atoms.info["REF_stress"] = atoms.get_stress() - except Exception as e: # pylint: disable=W0703 - atoms.info["REF_stress"] = None - if not isinstance(atoms_list, list): - atoms_list = [atoms_list] - - atomic_energies_dict = {} - if extract_atomic_energies: - atoms_without_iso_atoms = [] - - for idx, atoms in enumerate(atoms_list): - atoms.info[head_key] = head_name - isolated_atom_config = ( - len(atoms) == 1 and atoms.info.get("config_type") == "IsolatedAtom" - ) - if isolated_atom_config: - atomic_number = int(atoms.get_atomic_numbers()[0]) - if energy_key in atoms.info.keys(): - atomic_energies_dict[atomic_number] = float(atoms.info[energy_key]) - else: - logging.warning( - f"Configuration '{idx}' is marked as 'IsolatedAtom' " - "but does not contain an energy. Zero energy will be used." - ) - atomic_energies_dict[atomic_number] = 0.0 - else: - atoms_without_iso_atoms.append(atoms) - - if len(atomic_energies_dict) > 0: - logging.info("Using isolated atom energies from training file") - if not keep_isolated_atoms: - atoms_list = atoms_without_iso_atoms - - for atoms in atoms_list: - atoms.info[head_key] = head_name - - configs = config_from_atoms_list( - atoms_list, - config_type_weights=config_type_weights, - key_specification=key_specification, - head_name=head_name, - ) - return atomic_energies_dict, configs - - -def compute_average_E0s( - collections_train: Configurations, z_table: AtomicNumberTable -) -> Dict[int, float]: - """ - Function to compute the average interaction energy of each chemical element - returns dictionary of E0s - """ - len_train = len(collections_train) - len_zs = len(z_table) - A = np.zeros((len_train, len_zs)) - B = np.zeros(len_train) - for i in range(len_train): - B[i] = collections_train[i].properties["energy"] - for j, z in enumerate(z_table.zs): - A[i, j] = np.count_nonzero(collections_train[i].atomic_numbers == z) - try: - E0s = np.linalg.lstsq(A, B, rcond=None)[0] - atomic_energies_dict = {} - for i, z in enumerate(z_table.zs): - atomic_energies_dict[z] = E0s[i] - except np.linalg.LinAlgError: - logging.error( - "Failed to compute E0s using least squares regression, using the same for all atoms" - ) - atomic_energies_dict = {} - for i, z in enumerate(z_table.zs): - atomic_energies_dict[z] = 0.0 - return atomic_energies_dict - - -def save_dataset_as_HDF5(dataset: List, out_name: str) -> None: - with h5py.File(out_name, "w") as f: - for i, data in enumerate(dataset): - save_AtomicData_to_HDF5(data, i, f) - - -def save_AtomicData_to_HDF5(data, i, h5_file) -> None: - grp = h5_file.create_group(f"config_{i}") - grp["num_nodes"] = data.num_nodes - grp["edge_index"] = data.edge_index - grp["positions"] = data.positions - grp["shifts"] = data.shifts - grp["unit_shifts"] = data.unit_shifts - grp["cell"] = data.cell - grp["node_attrs"] = data.node_attrs - grp["weight"] = data.weight - grp["energy_weight"] = data.energy_weight - grp["forces_weight"] = data.forces_weight - grp["stress_weight"] = data.stress_weight - grp["virials_weight"] = data.virials_weight - grp["forces"] = data.forces - grp["energy"] = data.energy - grp["stress"] = data.stress - grp["virials"] = data.virials - grp["dipole"] = data.dipole - grp["charges"] = data.charges - grp["head"] = data.head - - -def save_configurations_as_HDF5(configurations: Configurations, _, h5_file) -> None: - grp = h5_file.create_group("config_batch_0") - for j, config in enumerate(configurations): - subgroup_name = f"config_{j}" - subgroup = grp.create_group(subgroup_name) - subgroup["atomic_numbers"] = write_value(config.atomic_numbers) - subgroup["positions"] = write_value(config.positions) - properties_subgrp = subgroup.create_group("properties") - for key, value in config.properties.items(): - properties_subgrp[key] = write_value(value) - subgroup["cell"] = write_value(config.cell) - subgroup["pbc"] = write_value(config.pbc) - subgroup["weight"] = write_value(config.weight) - weights_subgrp = subgroup.create_group("property_weights") - for key, value in config.property_weights.items(): - weights_subgrp[key] = write_value(value) - subgroup["config_type"] = write_value(config.config_type) - - -def write_value(value): - return value if value is not None else "None" +########################################################################################### +# Data parsing utilities +# Authors: Ilyes Batatia, Gregor Simm and David Kovacs +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +import logging +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Sequence, Tuple + +import ase.data +import ase.io +import h5py +import numpy as np + +from mace.tools import AtomicNumberTable, DefaultKeys + +Positions = np.ndarray # [..., 3] +Cell = np.ndarray # [3,3] +Pbc = tuple # (3,) + +DEFAULT_CONFIG_TYPE = "Default" +DEFAULT_CONFIG_TYPE_WEIGHTS = {DEFAULT_CONFIG_TYPE: 1.0} + + +@dataclass +class KeySpecification: + info_keys: Dict[str, str] = field(default_factory=dict) + arrays_keys: Dict[str, str] = field(default_factory=dict) + + def update( + self, + info_keys: Optional[Dict[str, str]] = None, + arrays_keys: Optional[Dict[str, str]] = None, + ): + if info_keys is not None: + self.info_keys.update(info_keys) + if arrays_keys is not None: + self.arrays_keys.update(arrays_keys) + return self + + @classmethod + def from_defaults(cls): + instance = cls() + return update_keyspec_from_kwargs(instance, DefaultKeys.keydict()) + + +def update_keyspec_from_kwargs( + keyspec: KeySpecification, keydict: Dict[str, str] +) -> KeySpecification: + # convert command line style property_key arguments into a keyspec + infos = ["energy_key", "stress_key", "virials_key", "dipole_key", "head_key"] + arrays = ["forces_key", "charges_key"] + info_keys = {} + arrays_keys = {} + for key in infos: + if key in keydict: + info_keys[key[:-4]] = keydict[key] + for key in arrays: + if key in keydict: + arrays_keys[key[:-4]] = keydict[key] + keyspec.update(info_keys=info_keys, arrays_keys=arrays_keys) + return keyspec + + +@dataclass +class Configuration: + atomic_numbers: np.ndarray + positions: Positions # Angstrom + properties: Dict[str, Any] + property_weights: Dict[str, float] + cell: Optional[Cell] = None + pbc: Optional[Pbc] = None + + weight: float = 1.0 # weight of config in loss + config_type: str = DEFAULT_CONFIG_TYPE # config_type of config + head: str = "Default" # head used to compute the config + + +Configurations = List[Configuration] + + +def random_train_valid_split( + items: Sequence, valid_fraction: float, seed: int, work_dir: str +) -> Tuple[List, List]: + assert 0.0 < valid_fraction < 1.0 + + size = len(items) + train_size = size - int(valid_fraction * size) + + indices = list(range(size)) + rng = np.random.default_rng(seed) + rng.shuffle(indices) + if len(indices[train_size:]) < 10: + logging.info( + f"Using random {100 * valid_fraction:.0f}% of training set for validation with following indices: {indices[train_size:]}" + ) + else: + # Save indices to file + with open(work_dir + f"/valid_indices_{seed}.txt", "w", encoding="utf-8") as f: + for index in indices[train_size:]: + f.write(f"{index}\n") + + logging.info( + f"Using random {100 * valid_fraction:.0f}% of training set for validation with indices saved in: {work_dir}/valid_indices_{seed}.txt" + ) + + return ( + [items[i] for i in indices[:train_size]], + [items[i] for i in indices[train_size:]], + ) + + +def config_from_atoms_list( + atoms_list: List[ase.Atoms], + key_specification: KeySpecification, + config_type_weights: Optional[Dict[str, float]] = None, + head_name: str = "Default", +) -> Configurations: + """Convert list of ase.Atoms into Configurations""" + if config_type_weights is None: + config_type_weights = DEFAULT_CONFIG_TYPE_WEIGHTS + + all_configs = [] + for atoms in atoms_list: + all_configs.append( + config_from_atoms( + atoms, + key_specification=key_specification, + config_type_weights=config_type_weights, + head_name=head_name, + ) + ) + return all_configs + + +def config_from_atoms( + atoms: ase.Atoms, + key_specification: KeySpecification = KeySpecification(), + config_type_weights: Optional[Dict[str, float]] = None, + head_name: str = "Default", +) -> Configuration: + """Convert ase.Atoms to Configuration""" + if config_type_weights is None: + config_type_weights = DEFAULT_CONFIG_TYPE_WEIGHTS + + atomic_numbers = np.array( + [ase.data.atomic_numbers[symbol] for symbol in atoms.symbols] + ) + pbc = tuple(atoms.get_pbc()) + cell = np.array(atoms.get_cell()) + config_type = atoms.info.get("config_type", "Default") + weight = atoms.info.get("config_weight", 1.0) * config_type_weights.get( + config_type, 1.0 + ) + + properties = {} + property_weights = {} + for name in list(key_specification.arrays_keys) + list(key_specification.info_keys): + property_weights[name] = atoms.info.get(f"config_{name}_weight", 1.0) + + for name, atoms_key in key_specification.info_keys.items(): + properties[name] = atoms.info.get(atoms_key, None) + if not atoms_key in atoms.info: + property_weights[name] = 0.0 + + for name, atoms_key in key_specification.arrays_keys.items(): + properties[name] = atoms.arrays.get(atoms_key, None) + if not atoms_key in atoms.arrays: + property_weights[name] = 0.0 + + return Configuration( + atomic_numbers=atomic_numbers, + positions=atoms.get_positions(), + properties=properties, + weight=weight, + property_weights=property_weights, + head=head_name, + config_type=config_type, + pbc=pbc, + cell=cell, + ) + + +def test_config_types( + test_configs: Configurations, +) -> List[Tuple[str, List[Configuration]]]: + """Split test set based on config_type-s""" + test_by_ct = [] + all_cts = [] + for conf in test_configs: + config_type_name = conf.config_type + "_" + conf.head + if config_type_name not in all_cts: + all_cts.append(config_type_name) + test_by_ct.append((config_type_name, [conf])) + else: + ind = all_cts.index(config_type_name) + test_by_ct[ind][1].append(conf) + return test_by_ct + + +def load_from_xyz( + file_path: str, + key_specification: KeySpecification, + head_name: str = "Default", + config_type_weights: Optional[Dict] = None, + extract_atomic_energies: bool = False, + keep_isolated_atoms: bool = False, +) -> Tuple[Dict[int, float], Configurations]: + atoms_list = ase.io.read(file_path, index=":") + energy_key = key_specification.info_keys["energy"] + forces_key = key_specification.arrays_keys["forces"] + stress_key = key_specification.info_keys["stress"] + head_key = key_specification.info_keys["head"] + if energy_key == "energy": + logging.warning( + "Since ASE version 3.23.0b1, using energy_key 'energy' is no longer safe when communicating between MACE and ASE. We recommend using a different key, rewriting 'energy' to 'REF_energy'. You need to use --energy_key='REF_energy' to specify the chosen key name." + ) + key_specification.info_keys["energy"] = "REF_energy" + for atoms in atoms_list: + try: + atoms.info["REF_energy"] = atoms.get_potential_energy() + except Exception as e: # pylint: disable=W0703 + logging.error(f"Failed to extract energy: {e}") + atoms.info["REF_energy"] = None + if forces_key == "forces": + logging.warning( + "Since ASE version 3.23.0b1, using forces_key 'forces' is no longer safe when communicating between MACE and ASE. We recommend using a different key, rewriting 'forces' to 'REF_forces'. You need to use --forces_key='REF_forces' to specify the chosen key name." + ) + key_specification.arrays_keys["forces"] = "REF_forces" + for atoms in atoms_list: + try: + atoms.arrays["REF_forces"] = atoms.get_forces() + except Exception as e: # pylint: disable=W0703 + logging.error(f"Failed to extract forces: {e}") + atoms.arrays["REF_forces"] = None + if stress_key == "stress": + logging.warning( + "Since ASE version 3.23.0b1, using stress_key 'stress' is no longer safe when communicating between MACE and ASE. We recommend using a different key, rewriting 'stress' to 'REF_stress'. You need to use --stress_key='REF_stress' to specify the chosen key name." + ) + key_specification.info_keys["stress"] = "REF_stress" + for atoms in atoms_list: + try: + atoms.info["REF_stress"] = atoms.get_stress() + except Exception as e: # pylint: disable=W0703 + atoms.info["REF_stress"] = None + if not isinstance(atoms_list, list): + atoms_list = [atoms_list] + + atomic_energies_dict = {} + if extract_atomic_energies: + atoms_without_iso_atoms = [] + + for idx, atoms in enumerate(atoms_list): + atoms.info[head_key] = head_name + isolated_atom_config = ( + len(atoms) == 1 and atoms.info.get("config_type") == "IsolatedAtom" + ) + if isolated_atom_config: + atomic_number = int(atoms.get_atomic_numbers()[0]) + if energy_key in atoms.info.keys(): + atomic_energies_dict[atomic_number] = float(atoms.info[energy_key]) + else: + logging.warning( + f"Configuration '{idx}' is marked as 'IsolatedAtom' " + "but does not contain an energy. Zero energy will be used." + ) + atomic_energies_dict[atomic_number] = 0.0 + else: + atoms_without_iso_atoms.append(atoms) + + if len(atomic_energies_dict) > 0: + logging.info("Using isolated atom energies from training file") + if not keep_isolated_atoms: + atoms_list = atoms_without_iso_atoms + + for atoms in atoms_list: + atoms.info[head_key] = head_name + + configs = config_from_atoms_list( + atoms_list, + config_type_weights=config_type_weights, + key_specification=key_specification, + head_name=head_name, + ) + return atomic_energies_dict, configs + + +def compute_average_E0s( + collections_train: Configurations, z_table: AtomicNumberTable +) -> Dict[int, float]: + """ + Function to compute the average interaction energy of each chemical element + returns dictionary of E0s + """ + len_train = len(collections_train) + len_zs = len(z_table) + A = np.zeros((len_train, len_zs)) + B = np.zeros(len_train) + for i in range(len_train): + B[i] = collections_train[i].properties["energy"] + for j, z in enumerate(z_table.zs): + A[i, j] = np.count_nonzero(collections_train[i].atomic_numbers == z) + try: + E0s = np.linalg.lstsq(A, B, rcond=None)[0] + atomic_energies_dict = {} + for i, z in enumerate(z_table.zs): + atomic_energies_dict[z] = E0s[i] + except np.linalg.LinAlgError: + logging.error( + "Failed to compute E0s using least squares regression, using the same for all atoms" + ) + atomic_energies_dict = {} + for i, z in enumerate(z_table.zs): + atomic_energies_dict[z] = 0.0 + return atomic_energies_dict + + +def save_dataset_as_HDF5(dataset: List, out_name: str) -> None: + with h5py.File(out_name, "w") as f: + for i, data in enumerate(dataset): + save_AtomicData_to_HDF5(data, i, f) + + +def save_AtomicData_to_HDF5(data, i, h5_file) -> None: + grp = h5_file.create_group(f"config_{i}") + grp["num_nodes"] = data.num_nodes + grp["edge_index"] = data.edge_index + grp["positions"] = data.positions + grp["shifts"] = data.shifts + grp["unit_shifts"] = data.unit_shifts + grp["cell"] = data.cell + grp["node_attrs"] = data.node_attrs + grp["weight"] = data.weight + grp["energy_weight"] = data.energy_weight + grp["forces_weight"] = data.forces_weight + grp["stress_weight"] = data.stress_weight + grp["virials_weight"] = data.virials_weight + grp["forces"] = data.forces + grp["energy"] = data.energy + grp["stress"] = data.stress + grp["virials"] = data.virials + grp["dipole"] = data.dipole + grp["charges"] = data.charges + grp["head"] = data.head + + +def save_configurations_as_HDF5(configurations: Configurations, _, h5_file) -> None: + grp = h5_file.create_group("config_batch_0") + for j, config in enumerate(configurations): + subgroup_name = f"config_{j}" + subgroup = grp.create_group(subgroup_name) + subgroup["atomic_numbers"] = write_value(config.atomic_numbers) + subgroup["positions"] = write_value(config.positions) + properties_subgrp = subgroup.create_group("properties") + for key, value in config.properties.items(): + properties_subgrp[key] = write_value(value) + subgroup["cell"] = write_value(config.cell) + subgroup["pbc"] = write_value(config.pbc) + subgroup["weight"] = write_value(config.weight) + weights_subgrp = subgroup.create_group("property_weights") + for key, value in config.property_weights.items(): + weights_subgrp[key] = write_value(value) + subgroup["config_type"] = write_value(config.config_type) + + +def write_value(value): + return value if value is not None else "None" diff --git a/mace-bench/3rdparty/mace/mace/modules/__init__.py b/mace-bench/3rdparty/mace/mace/modules/__init__.py index d816220..40a29d3 100644 --- a/mace-bench/3rdparty/mace/mace/modules/__init__.py +++ b/mace-bench/3rdparty/mace/mace/modules/__init__.py @@ -1,100 +1,100 @@ -from typing import Callable, Dict, Optional, Type - -import torch - -from .blocks import ( - AtomicEnergiesBlock, - EquivariantProductBasisBlock, - InteractionBlock, - LinearDipoleReadoutBlock, - LinearNodeEmbeddingBlock, - LinearReadoutBlock, - NonLinearDipoleReadoutBlock, - NonLinearReadoutBlock, - RadialEmbeddingBlock, - RealAgnosticAttResidualInteractionBlock, - RealAgnosticDensityInteractionBlock, - RealAgnosticDensityResidualInteractionBlock, - RealAgnosticInteractionBlock, - RealAgnosticResidualInteractionBlock, - ScaleShiftBlock, -) -from .loss import ( - DipoleSingleLoss, - UniversalLoss, - WeightedEnergyForcesDipoleLoss, - WeightedEnergyForcesL1L2Loss, - WeightedEnergyForcesLoss, - WeightedEnergyForcesStressLoss, - WeightedEnergyForcesVirialsLoss, - WeightedForcesLoss, - WeightedHuberEnergyForcesStressLoss, -) -from .models import MACE, AtomicDipolesMACE, EnergyDipolesMACE, ScaleShiftMACE -from .radial import BesselBasis, GaussianBasis, PolynomialCutoff, ZBLBasis -from .symmetric_contraction import SymmetricContraction -from .utils import ( - compute_avg_num_neighbors, - compute_fixed_charge_dipole, - compute_mean_rms_energy_forces, - compute_mean_std_atomic_inter_energy, - compute_rms_dipoles, - compute_statistics, -) - -interaction_classes: Dict[str, Type[InteractionBlock]] = { - "RealAgnosticResidualInteractionBlock": RealAgnosticResidualInteractionBlock, - "RealAgnosticAttResidualInteractionBlock": RealAgnosticAttResidualInteractionBlock, - "RealAgnosticInteractionBlock": RealAgnosticInteractionBlock, - "RealAgnosticDensityInteractionBlock": RealAgnosticDensityInteractionBlock, - "RealAgnosticDensityResidualInteractionBlock": RealAgnosticDensityResidualInteractionBlock, -} - -scaling_classes: Dict[str, Callable] = { - "std_scaling": compute_mean_std_atomic_inter_energy, - "rms_forces_scaling": compute_mean_rms_energy_forces, - "rms_dipoles_scaling": compute_rms_dipoles, -} - -gate_dict: Dict[str, Optional[Callable]] = { - "abs": torch.abs, - "tanh": torch.tanh, - "silu": torch.nn.functional.silu, - "None": None, -} - -__all__ = [ - "AtomicEnergiesBlock", - "RadialEmbeddingBlock", - "ZBLBasis", - "LinearNodeEmbeddingBlock", - "LinearReadoutBlock", - "EquivariantProductBasisBlock", - "ScaleShiftBlock", - "LinearDipoleReadoutBlock", - "NonLinearDipoleReadoutBlock", - "InteractionBlock", - "NonLinearReadoutBlock", - "PolynomialCutoff", - "BesselBasis", - "GaussianBasis", - "MACE", - "ScaleShiftMACE", - "AtomicDipolesMACE", - "EnergyDipolesMACE", - "WeightedEnergyForcesLoss", - "WeightedForcesLoss", - "WeightedEnergyForcesVirialsLoss", - "WeightedEnergyForcesStressLoss", - "DipoleSingleLoss", - "WeightedEnergyForcesDipoleLoss", - "WeightedHuberEnergyForcesStressLoss", - "UniversalLoss", - "WeightedEnergyForcesL1L2Loss", - "SymmetricContraction", - "interaction_classes", - "compute_mean_std_atomic_inter_energy", - "compute_avg_num_neighbors", - "compute_statistics", - "compute_fixed_charge_dipole", -] +from typing import Callable, Dict, Optional, Type + +import torch + +from .blocks import ( + AtomicEnergiesBlock, + EquivariantProductBasisBlock, + InteractionBlock, + LinearDipoleReadoutBlock, + LinearNodeEmbeddingBlock, + LinearReadoutBlock, + NonLinearDipoleReadoutBlock, + NonLinearReadoutBlock, + RadialEmbeddingBlock, + RealAgnosticAttResidualInteractionBlock, + RealAgnosticDensityInteractionBlock, + RealAgnosticDensityResidualInteractionBlock, + RealAgnosticInteractionBlock, + RealAgnosticResidualInteractionBlock, + ScaleShiftBlock, +) +from .loss import ( + DipoleSingleLoss, + UniversalLoss, + WeightedEnergyForcesDipoleLoss, + WeightedEnergyForcesL1L2Loss, + WeightedEnergyForcesLoss, + WeightedEnergyForcesStressLoss, + WeightedEnergyForcesVirialsLoss, + WeightedForcesLoss, + WeightedHuberEnergyForcesStressLoss, +) +from .models import MACE, AtomicDipolesMACE, EnergyDipolesMACE, ScaleShiftMACE +from .radial import BesselBasis, GaussianBasis, PolynomialCutoff, ZBLBasis +from .symmetric_contraction import SymmetricContraction +from .utils import ( + compute_avg_num_neighbors, + compute_fixed_charge_dipole, + compute_mean_rms_energy_forces, + compute_mean_std_atomic_inter_energy, + compute_rms_dipoles, + compute_statistics, +) + +interaction_classes: Dict[str, Type[InteractionBlock]] = { + "RealAgnosticResidualInteractionBlock": RealAgnosticResidualInteractionBlock, + "RealAgnosticAttResidualInteractionBlock": RealAgnosticAttResidualInteractionBlock, + "RealAgnosticInteractionBlock": RealAgnosticInteractionBlock, + "RealAgnosticDensityInteractionBlock": RealAgnosticDensityInteractionBlock, + "RealAgnosticDensityResidualInteractionBlock": RealAgnosticDensityResidualInteractionBlock, +} + +scaling_classes: Dict[str, Callable] = { + "std_scaling": compute_mean_std_atomic_inter_energy, + "rms_forces_scaling": compute_mean_rms_energy_forces, + "rms_dipoles_scaling": compute_rms_dipoles, +} + +gate_dict: Dict[str, Optional[Callable]] = { + "abs": torch.abs, + "tanh": torch.tanh, + "silu": torch.nn.functional.silu, + "None": None, +} + +__all__ = [ + "AtomicEnergiesBlock", + "RadialEmbeddingBlock", + "ZBLBasis", + "LinearNodeEmbeddingBlock", + "LinearReadoutBlock", + "EquivariantProductBasisBlock", + "ScaleShiftBlock", + "LinearDipoleReadoutBlock", + "NonLinearDipoleReadoutBlock", + "InteractionBlock", + "NonLinearReadoutBlock", + "PolynomialCutoff", + "BesselBasis", + "GaussianBasis", + "MACE", + "ScaleShiftMACE", + "AtomicDipolesMACE", + "EnergyDipolesMACE", + "WeightedEnergyForcesLoss", + "WeightedForcesLoss", + "WeightedEnergyForcesVirialsLoss", + "WeightedEnergyForcesStressLoss", + "DipoleSingleLoss", + "WeightedEnergyForcesDipoleLoss", + "WeightedHuberEnergyForcesStressLoss", + "UniversalLoss", + "WeightedEnergyForcesL1L2Loss", + "SymmetricContraction", + "interaction_classes", + "compute_mean_std_atomic_inter_energy", + "compute_avg_num_neighbors", + "compute_statistics", + "compute_fixed_charge_dipole", +] diff --git a/mace-bench/3rdparty/mace/mace/modules/__pycache__/__init__.cpython-310.pyc b/mace-bench/3rdparty/mace/mace/modules/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index fb761d632b28e4aa34bd81e2821cb4a84baca5d4..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 2297 zcmZ{lNpssa6o5&J+APc3>_lGTC3fN{jxF0s)6LD)TbCKv^+ZiO#f1ZrkVF?CRRGFV z{Vh55mP3#IGr0DYztCl-c%Ws=wIdC`A>YRL9v+CeSe&r*@Ap66_9~N>^(Q=x|1$dF zbA5@9ZCS6ZzzS??F&o*;K@Lly6myYlgWREMlnzpXOEWBsvMh&kERXVR0!^?2DzHg3 z$%?4RrqC2Cp%N>jGOM5ptD-8ap&FY;(`*LK=vb*BP4BR~=q{_HI-5naY!1z_c{I-! z&;q-M?y*I*$d=HOvCYtBwt`mJeRQ9#qE)tr){Hz$AFzk$AzMf5>=Amz8mM99Ir^Aw zpbc}Mr%%`>+GJa3i#1V`ZKG|rgLc?1+O;h*P0FN3Dx?}ryh)>{#39d!OY*Z;$C(b{R>gywYeSH4O`KN&sOucc@p0P_2`!Y^x&+1x1 z`8OMV7C0@dQF#X|&~R1m5L1fVlqmNo^rgyu+m~S+VXEBkhJAASPOpj8D5=`Mj9KU( zM}&7nA`WQm|D+aklOH@awJ@4VV`}&)YEK)Dyg+=dLaW3n{DZZu+=V%Rc&Mi0g}iX4=>&=39?U>fxV?o6Yg zc`79&S0&HGQ53@n#*y$mz5_=ixScF z^garu=WX?eDz~S*rKK<602JFAb5_nda-LM`g}OayFWb(9V<-Q)znwhX-TBwex^_y;J zQ+jJolH#6^R-}}v*nWJd5TjJkucYd!V zB9inL7ws?5njGzWBet>G7jTp0wzcXGD6-}xJs_UH8Kg4WD2H8 z7P3MfELjO2W9TUdEc^iiSrZx(O}P}#OE;Vg zWIN|&MQ^buC&OKlg|Fl}lWPof$xFrf6Ayf{?ae0MVSud6f-9x@CSa zRPTxMj8A$)E$gcBzvz?P##q`mE!)wxCCAyIwqCBQ#<|odIt}kp)7G8FIiu+dj6N7~ zsMGjds7$B!gE3-7O*Lq-qHno8WEurnZP>a90FmBNSYTLqd*RuiEuUXYv-Jy!Qv#pr*s$wGN zn-2?I8+l2-k z-lwIRC!S{hSu^ulb9%iw_w8Xg@UeGaO6@+~o7i*qGxG-{3;P2L2feqCqyTqFO6kSs zz+!VSci1IQd0)hK)4ST<;%@aIkvWoNZqXw}62IJed8f$-@{-{2;N^qYk&%PQ$m>Yv xAd+cD;)fys4eyK4&-1&!-LLm52fdT~@yUbm)RE-(c>n1c4|{*=g8RJY{{W<8@i+hg diff --git a/mace-bench/3rdparty/mace/mace/modules/__pycache__/blocks.cpython-310.pyc b/mace-bench/3rdparty/mace/mace/modules/__pycache__/blocks.cpython-310.pyc deleted file mode 100644 index 7da39a02fd504793d0c0ed9557483b8cf3022a7b..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 20740 zcmd^n32+=&dfxOgJ@?=s2#^FxffOZiBns5I+JjO=QKBM8!YWH@H#NKRWH{Xe%xDhi zb%PR+fwM8SUd2@`$8loEUate%ISN(YWaGqTCvmy#@|DYRF5ByN&PqA4ayp^+5G9fM zzW?>~%;2EN)w*&us(-w@U%&VN-~Ye=y`FBNkWJwF(ob$K*=vc!zh`9lau9hEw|~)0 zB;15+v=VKjY#0oiEwgQvt+rjZ+sSgW?UbE%s+?-4%W28CTA6Yd<*Zh&oX3B=RcIH> z#r9ZvO!AYh@%BV{Lc&gKvc0Fgr#)4kYVR%YZBLh{+cV`E$xF5NwfC3zOE}#+(4H;N zN;uOx*gjM~gmBi)wGOxMDBsbZE6=r$l#fVSzIC*HXZgS9+L3G%}4P6QFrc#&GLuh(ns8*D1Fp>NJitaIQ>rd7}AeL=^u{M z?{e=(`rT3bN8~uo2`!L$d*ShWHM$4V`$- z`AgMhuhLM;TlT{{g5^qgC5WPx*BZf6rP0Aqp_k~2kAWIHz1Z>m#(7ok`1P)8hqAyeGG(-j=$J#t#-QYMzwWjCFs`cVc}f29t;|D z{F-Vk2Q|*q@Rg1t{__ZUiL&7)%4WuJbFS%GP207E+FsIg_9yDj=M#0qO}fq-g>uSG zdFf`xO_SDbH{)jCu*=!cC+tKy2cF1>`}F89VEUeEU+`SF(OG<|)vdj<5N7naU<|@+ z6vXg^x!Q{NDu$dhTeFAy$#4?4h5PDZkRWj>v0+?E7>U5_89lRSxyEIym#A27;#^|W zT-dbK9LlL9Bu7Cu?J()DaKh6SJhu^4DyoR|&}j*ilbAX;?dsZ6wcBh|5vqCjFRN}7{C9GB6>Q(BoH_TRPQT!FYD*^{R_=0D1*?)4 z5in@W^G{yjc=+?ntLixVau4pSoRX|#8jg{_v36v)F`VzeM40VhfGc24<%f>)f)&-N zK^4Ro!yLCWxUW)|^b#A1o^i#%I2j9@hAJVgWU71dPdz|#e@x~P3R7Um*Q&~W2oayH zW$m@OUHcUkRu7}#W}buHskXgJCCpYR$#`po^Oef0E7g|HQS4OcJSU=}sLxcMgd-9r zgRZJAg^m_d^(fN{(S+2zRtvK@=2#y+b%K<>ir9|TD!OvirLHqGm{WxnUw#dpn| zj!nKNC3hE#3#i@~3(wlHU5uM+*3DQjC?v@d@o?M}#Uo?71uYzI2Kf+@{ZNA7K9?(w!Js=4pCZYL&w;=2o$qOMwczIXZPKMu0ASGMR>KqfEBsoeV z!z807W5sd)7}HOJtj%uk$bi?3*TWjNyzZm9y2Dc)s2nVP~5P~cUbDd`D zQzU~alOAwZ&*IkHm0UZpLm!6uN_}Xwk-@w_fsBu^C5}NlkxrV%mYFZw?{V(a-~`*A zd{?nX4fn`ZjAH&p1h&t=Yt=2+7HCWgbuDlR{$uVz`f|{wEL_;3=>)Q89Z-VrgIlTN6#Q=S1k544h_4QmjSD z!DW_BqZG^7aj6V3xMVFTN+6JB4*x(UaV44hJ_B`RcAK>7#2~6f^;E^-+cdxv8l`s{t!3#&9Lza*9PL z=<0EjL6-)?%TZdu?GrvAylE5jVZS{9Zu4C`X})7R<~1i(eB0b>->?nyMjE$mPF`26 zzfF{>PI+yyZ~)Kw7{rr+aK}r<7LJDSGz2-*%tDkwVc}4S?Q%{8G!F(Z^evpH8_V66 zHw@ipBhk#cl2z$;TB|olZl5?bEQP8Vm>aNau3-8rcE^#Jsbc}y0nA~;>aoP?aYW+m zh0uQR^#{Axau2@#fH!yF93xr?Qv*UQy zTq0{_`*>`T7!ei|_4^mY7H;9zEy^RZOhnWMTr-c-!cLz>KoZ4COGRbVqJi{6BrT;K=Je8UCKin#-Ai9Grcm0=Tu$PNnKz9lgm)u%C4pEz zDD=>eX1R`;~?9fDhz}SvCsb$M=%qOAXcD@<|nbRd)0KYM1lE0duKsf zE?Mf+h^WtigvPMtFJqNCwdi%c*O%23Kg6avOXg<}B-S3g&B5xovaQ$gyi@(kygRzr zZzlBo4%Aw|t5uX#FN2f}ZIXdvB*WEkKRyKa&#?#8^jb?FVMy#M`&qOUS}m^=PI|FtC8%_utr{I|wWa&Mka8h< ze3VP{7t2G;j8bMt^D71s<+6C4gocN#KrpjnqY1+coO2=aUATQ(7umF9XysxX;|4Un zMLK0c<0EbwdD?346jRz_Gm*}m{CjK1Ze_N$<+h_9ijB4d=*3o}4RHuG*UqxM!7MMK z)L>bA`UORlZV%cqqi#_1w69xK;>!k!2L0-}|PQkvVUDZr8rEWrHnWD0= ze10Pjl{K`4f_+7xhX+X%8hl8xos|vuoDSgIsy);3)MCT)_3C)6Dv1?{yrgb(dc=x~ zd^Mz~FomuWphrE>jpP;h$uFB%OnR*KvgH+&-!#Ic8>}vSo6bO%!ZGD70!iTnys%QQ zdrEAB&Hyxc_(F?cu>0!! znl=AOUF3Vy)Vk<2Dr^;a>b_rQIwfZ9V>=BL{8rwws#JE;(L?I`n(;Dx&33cf=zy1~ z{KDj->aHyB^vGXAwx0%Z5~g|JhI3=>f!lrbTJCsXo*w`7n)z}Gi_9Kc)JCw+knl1M zA2FjN4lhq)=D|>K!N&r*rHgb0>UaJHRQ74~+E{UCP5U}<-Pw)rb(QT#d%UO&$10U- zr_&9@LjW(W;t*gXVL427+^SMl?OfH1&+lNzcU^%tk^47r`;#EqMAkfPni;M!Z8Lj) zJgJz!4Fz*ZZ72T@8dMzfpT;a0^fj%DK9x%<(gm(@U!{Wm5xoO3V);{mDIXSJ^t(ZZ02N+Qn7k9(6CZhntj1&4zp^;KF37e?dNJB22{YI`xv=VE10{6yPD^NWizKQP8q1~p0O-|zG!$m^_*Lp1 zHH)~SL7?`L{ArRyB=j$%aT!q*3{%|ZaT{&OjTkoUk@_x@&y&1KBDL>hXkZKy!hQ|6 z&xd5wY3c&Gi};_#osN7oJF!5l9>TomE_xN1bc?}~FW&qxD@l!x>%ETsIju9n5Ene< z*J$hXFQRD-(;q@aY#*-Ak^x{@Tve0W{^5R0&%R=dYs#WMwC?nrUUG^G)Lf7kgzx#! zpb_-~iM;XmAk_EM{aL0`7pU(e`Ewwl6Y-^^`Em<&1m1v&UA+NG_>tjtrdI$Pab~?Z>~)Sk z&Z%LKGu&GAJEz6(JnSi$3=a7Z-JSo5EzCTz=9co-Gu#;;%=s! z+6Itdy$y?kj%w>N)5wkhk;pApN~$@ur8fhAn4upbISO(fEAVl+D?Nqn=*nweV{s|K zb}M#o6;xRimy)IQ5I%6-rZ`vOwBO%b=Eh}zbJ-V9wz$=F<;Xw%L#+1&66uCC6Zd5& z3-RkH>$9WqPs7bN_)EWIihoq6q~FP=*6zC5ybUlyZy3t-g(mieDTXy-d;j7D5E;(k z_G$9D#yT(_H~>s=O}MO&!TSii2(D@gHBEJpDNV~iApdQ3(6!dNF96Sqp#3q#lkg6m z?IqR8z}ZN_t^&SafqNUVbkD-R8VvBynQoE`g^kRW1TZDBk?k29IX4Nvr*j1^mQIw0 z@`$1Ryedad^D;`qzk`zI3~;C)10tms&g!`i9T>diQ5tw}nIKqS2|#g#!<5(IAq6+g zcX2+U-B`o6YO7VAjAQXsm$R+vDv_RmK}#geh}vB4`W~!~ZfDI?UBMLfqtbF6vVc{B zk_j*I!4X6J7{cl=k#HKyY03(B!OLU)P(}B-oD&Gw8@)SW;}kqgJ{z&&l+B9Hz<$OJ z-nNwoTJo=G91z^Z_VV?7vZ!^+v77QuA4RS63kx?}juJsX&~o$#(ZE-6%kp1i>9G61 zkEV{e|IN3}bnzSJp}}5WgsN<18>;R_A}{~k1N~h%0#rC5Cc|)glr)U_8pfWJ^3{ur6{35G z_<5m=lL*Z+Xawy& z>_b-I?*$40$rpi9KWeNaj?_#q1DXW_!0#I^>&=Cd_12Skiu#E@88P_2#8g4&zsAsC zClQ?H3#u=ZaA6?t5;_$h=PZ`7n<+Nl?bO`~;~%Buirtsz9^~MURv%+|V{QM|*^BXf zWEqQ=@WvXtuM*RP#CIC{ECvoXEp&tDaI{G@u;)IbfKfwk+xz)`v}4rY1o=2S#rd(# zwVCY?mjO^juQ1=Lw%g0TJ|q(sqF9HfrXS|&4F%r<;L=kg<|Y?uv68ufh`@Nl$MV3> z1yCFt=9Pp3paaS+cv+L!Ccu6u^g5gwu0dZCLf8mvI#@zj>nY?ysq|7joZ;G+p&o+t z2A2%*?Qj_nAxJI=^1~o(v*V8IU2N1m?UFY_i)MRP!@1RL1aPeSuv>U;3MZGsBr)B` zAtD~^k_Ob@VJjAng{eIZC2^pvwF)~yxI&Qqr;x6Gn&fYjR9P(5sRUT9wnA&U>TW^u zu&`9^xHu!DNAVmdf&QH7lvumq1fsa}-ei)YwZq#cW~91W=JT8m*4@=sje^#{80LKj zx6dUFM&E!shq`q!VxI#O~Hy1477GKv~(4FI^=G-fuw=C&=#GK6e zT%$8rtNPwtw?4OwgK#(lCE>-Tt{+52+!ci}or7u%8uvf@;!ATcJUs_j^NKbZ{rTA6 zFGI=KP~uP#9*r3_9K#7Z=~_h^_v?M8-G?Je+s0-~;o;?xU|qCS5|ER1$+i_RnwKF{ ze~UeOn1n98{$h{XS7bu{J%;{1$@h_rc3YM*t$tVS%-X(fUF|Dfb75f9W0Tuk(b$a zve(Qhy`L`@3;ZqE=Qnn4b?R-!K0uDekh_7ahWN-umH#AezXj4)`CLL_)6gzknPkSy zIPN4oGB!-qw9t&utk9g$yeNHO3Y2jH%9~0b2X@4-&3f3`0}@R}hj(yl5-77LJczRZ zEqEEg?~U5lxz)2dg4zV?3~rIee~Ew?+Bh5HL%l(j$ypI#pPUoHDu>3mUgvP_LsU#O zVHh`4mxz?uhr~u`KNZsK!r_8A4DQQRnoy2SEdAE4AP<8X2YC=L&fq!NDji!iW)i`8 zGvAzOPF|q`-h*`3niBU;qB$k&tU(X<2Gb}pg?lFYZ@A87s_T8h{-G4Mb%5)!&D4Un z#?@cMo1VqF%>3|CSak)dYJ=n`2zEsIuQWTb)AiCy43$iD{s;ch8P*WpEoyrl(J+M- z*~Ln*Eb3gl#PkAHHo(G((X+z~I3Y|OALjM4j~0#B()u7<2hK!imTkY$S_#LZw&OM8 z5q6+NQXu&fTT1z_G?v)`VtAN}P6co3gBIof(bsmvRlmjpnMP+>yN`zCF4oBF5V@-=FAUfw-QBQAdr=y>=ctGhqfgFaoX1Hd8j}mNM202-(Pkw8DHWV3Apb zoD_SI+g=W~fxSM45>9Wdmxpbj!-j+*3s{-{0!mYW9eC52$(?CeMVZtTn+{k0mISzv*4Y!4TeeKvyxhv9l z>L2hWE%mSQZ{aKq^v4hwI&4P6bVz@R>RMo$s6A0*V5L^V+v z?Wk%-hV>KpfWXDM>i;iP4wW8N&OY3Im9syf|B1?oDO5SLR5|KbF*T)wJE;&YgF7e> zEse4c{O71G8bkdHhGZT1>kQHJqke6Ym4Fqr<6w2M<&Kf#g4u{4$C7^nRD2w@CgI$(HFRN{z!V z_a8wZ9YJdTeL0i+IK_U7RFU(oJv-72^`0M~2; zu0?_?Ywr%OE#&?|f$J_-{D?jDG?u+MaelkV`VDmZ{}E$8S}eG!ZA*k*Pdc8_@6CdlHcXNW}_(*7+FLt9|P{ z3F~|u@xDS}d;%?GdJ_WtIy_JanchHAWP1}(2;(j`fS?jsGKX5U-5Jg^H?rNGA=_R3 zTfDpaDGNj^Ql|gStMnaYSNC{ZD4^WULSD^vMsdK&U&S{rrekMU#Z7erz|D zvPxb(h1z|UveinTg4I5-(+Adi9v_0l_WFqRu^SH#YZmoiIV}H;WT59HB(}}(A>+g7 z+11$cKD)c+J**tW@@DN_m4o`v$c5Wwr!MbiH@%-1=Ko6I-q|GH9k>q!vs=vNgJLfK z502-d{{zfrk+k?SRg^vDpVl={=Z>7^T04$Tx%1 zw_~0MNuV{oAOpmvhXkK%44LfXIE*!d(1?#mCOi?F2Zge|NkMEK9w0VQI!5wrZ(z>n zdXqr!i3q(}D=#V9pub97@_Az4b7pom@_(27 z`2hTnfrrAngmsNg$Va2&+R*1f|3$PA|Fmzhh5%4kyTUytib2>K+e5fRz^SCEOY^LJfDuUJ}e5!O5gIrF2 z=n;JOH)3tmsCmmG;jXY4-y^u5MV5RPQl1_<)8%$wi@(Y(vIKqYH+PhTxyTdGx_b9~ zi`V4W*~ZNsCIg_S!vsILLFc({cUO1G4BpPKqc{z7>@3}~TO2G$+QV~CcX%szePU&76bo};0==5JCJnjPeNx%)}N*&j=ucd0}1?pfzOY>2bj!{ee4N{X@?7-z5=~ zI&ZkYU9UW_%Sq&tI7GI3}!zcj-A6-Yu>q~Mm_rZiX^{W!+~_4Ni}pUz8I!ciU0kl zK!%R~?D};MS=+VbyB(YY-S*v%y`XNOB}GS(613|eguR4&nIKt`D%;qDWyx{m{z>#h zpZFI_)!C32puwhUlzqFKDCXGq~laDCBz_Y*;}_VUw!| zcOeC|n9_Gh6!~ zvD3)@Gv^_;F~0*5=oml>zKk;Ps|1s-$&m+BNV@){h6gcG#zL9o7vvy#(~fK0CpDa? zhCNinj%#rL4lIRld~j+96!s-(=8F!PDXY)v-25vY^%^>=UMG?1T#G}4*{vY$d+~Pq zL!1*^$D;p1%1$5Qs8!S*86_A<@(IK_KIj%7{{8X8Q^!Rd9$!nJm=lS3{F+h1DNu4| znABU%J%;T@;I%aepWu@o1BoWSJcikezOs?lMiMAyLrZUA2u|krkRc`!C|c7u@U8xP z#gzVkAh{EJM<3i|x9+|D06Rl*(YQ?qgWhN3Gj=~Lavizr>_HlVbYj{(n8SxOOf5Yg zTGb0R#c{riDKeCfm5-_)rZ#-j{P?<^Gqv(v;fE2Z-r2!_fDW>?7Gv za)4x(uOni1ZNHoYhinW-VJie)QIpVFiE1aL@y0h)K_uu|b?D|zFfnNPuB zGN#Nq!^ut?&Yo$*oN|nPhI0pEQ@Bsyp2d9+?#HK%$)_D_(r{kJ{q4A0xVyL)4;cRo DQrjhw diff --git a/mace-bench/3rdparty/mace/mace/modules/__pycache__/blocks.cpython-313.pyc b/mace-bench/3rdparty/mace/mace/modules/__pycache__/blocks.cpython-313.pyc deleted file mode 100644 index 953da4c8f0fe8e515db82c7cb2040303eb9ca22a..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 39273 zcmeHw32?27Vw*K+oVY;PsXUcqRT8u*1tSQ~p|WmC1FAcwXqTf5)i z@0`IKk|O1DDyvsOPj|oWe*OCR`~Uv`?|!QukCVf7W&7eO^I4AjD+2k2JF0@-8UU^@J^&N9dYq){F;w= z2E4p?z{mU8ean%Ofl|Jd`K?F%17&>KKsjGNP{CIWRPvPrReTk@XFF0oP{Y?Szx_z< zKpkJl{Ej2_0}Xrw{LYZ;NaH{g-!u^50|U)`GmCQ{X&GqcTL;?swt;rOeV~Ky80h3X z2fFyK05@9N!i7Bj2JsE6IfT5mT*x=-9;%n#TJtkBqzOpgev|0wm%@CvPMIYrv-Cd7 zTsK;yQ>q`Omfd@)N{!d+lwFRpEAG8)euGY#l_<07KFaLTDYF`7*4#&#y*g#qqRhJc zD08Dune`~M;XcaTq*G=i%51uiGB+2=JF#|@51m>Yhuo*v#)bktMt-YKZJSZsmeI{? z=F2vnw5>?nrjzyoowV&p+o6+oyH46pr0vp4yF({!5NW%0((cqryAEmB7irDAp&7d- zO~DNxQey@Uf;W0HmFA;k>8V)y)bx1JDwy^r&j|MY(L^G8G7%HZN8_oqU_U;Sj!!3} z3Bht=b_UUwa1!CcpiwX;lY()23(9hjO;63l6S2tD^mr_25z6<^KK$Hl{P`#!k0!@r z`=^r=@kycX;A|ps27zR3EFBv^5lg0~`A7Nb@!7GoU_Ba7#-hAX7dkVA8u9TlBxLo& z)1Kg?Gcz$>@Z{yegUeF!shLE4;!Gqy5s@l^Xjf`1noh^~NNRQpLC4X($BsQ3iX40N zLn6UJli-+&rk;(Qibcl-FCR;tiq6C$ah{LOq=csQOk{dCEd?Vl#M7rD@g$lGwM11? zXei&_$z&`QKfyx6}H-tFe7&7pt zMp~1Mqvj#DwraC+SPzTDn@6pqHY~JUpEb<&8A6s??zH{1`LuzzhpeNHq4NB4C={~k z$M=~-cH}nmu8@Owhq>Jb-qXefT|%|kr-K+24^N$pjgQBZllv0WW6yrrj^YG|*y3ni z!6|vs@Pcb>HufAklg7-LSyK%w{Im?NJ_Kiv;Vaw-e#PsAIm|tcw~-f{05?MK#;#$1 z*c38sH-))`w1{%tBzM&~c-6$WA_d<@4%O(YS+JyLX`tC7s8u{2iSRze3)Tc1oEX_= zbn`?(Ho*~zj3uI}R3wt3rv`AnA{^~KH9ZyUor<146`ejEkHR+=+c3jVpT_7H11Dq2 zu~WTU`0<%2pFYFlBnNXM5li)+q{d41%$(uZQ8k*2wP%l9-A*J+HN*MwE8NF!&fz_0 z%h;m>ZIK7&!s`h7B@j)5cvj3|=+x>k%6?@*D8OZ-$eiG6Mm=7e9gxQ|5&zKQ07{ zzK)cFWi3OVSGW~BS5kid7tj6T{En=zW69pZZ$do3l^p8yt6u7}WOOPPi3rXJ5hAo8 z{O(BPx!Gt!y5oyPCgOZ5P5qgiM)$gqSFk?9ran&`DDNSMHMb=_&5xZDtRm6jw^7mx z^l>_xj*qcMi=UiL$I$J9M;ZZRu|y&gK@0RCaf;gNA@VO+`eLgf5pmGR9$-H;v*|I3fZ~J#-+y1nTjp9DCG)APNmt>d854M z($R}YS1j-&kJER){#^Yn8$us(aKr|p=Z@iDxfDwBw}>jFB>J=h+^SLT;gG4%7&7ap z3t5W9TMKB7O{F!UU7nC#pRR-)$frw*XqJVX`uF;*As102P%>A@t)WL=!7mQDA&`9# z8j%uhM^S=tV{DL3YKbP%0>CK3&yN1^G*$T^oS;RVD2x^zrdS#y;s{d6i8zYbu&i7A z7RvnJiJZ5a|2viW&xVI!cu}y>DwyK8Am*UYttDUuwzoBAhk zXs$WAK+Bagub){g|Mi!bx9-rJcIB095*D9!^;~enO`FBic*9k4e%HBOD;D@M_e!eg zEmz!^-5Fo|3f;z}m8PQup;Zo!SaYh?82**ZEBRK-@dEcU9Yf|m(`_c6K2uG8gh@ihN_h}Q9L0J^6B28vO-?{dziL9Y1(>1 zCA^oh5k9aHexYJ;I=ObVd^m{G_(SAWz>(&!^VrcxWlqjK83hkWGfHqIXQvpumlCHn zQF{4O*oq$_rhu)m6|xn>frPR?4nL|g6eA9~xZ#o^v+P%Y!&Zp}p*CRGh+Fmz$%8nUZtrv-R@~AC#W0)Nv9#@F3rX=U|m?#gAYH5J&}`d^ZB( z>gPEXA0Lk;rN-pLbb~QN_5@>wBAa;PHwEWj@Wsz#Nt{j!p5$~gBHoUtfhDArn}c=@ zi3?FlC&;0$5esf7A6ql&tw1U0dPPI@8g<8jLiZj+uG9n^;(PoWzGu0i6Wq^oP0Op^ z>(z~5fpY8jWY+I_j~JmEFhZ|-KDKg=fh$kH{`9pIx&A|${zLDj-Co4J4NC8unP;Sso%2;)croJ*tT-s>Pwl9T`b~;!^;>>3lgq!aKuSLwI9R3a?ymQYZW7L z%?j7Z-2uE>!ED+_?bNDkfT{yP)j33?tKb4q)s0txsvAJn!&Vrt7N{PG&rBy`dF@wz+aQ-4$|i=D6>@^b!aOLLH@&!N`lqf?!cFie_slF zvW;8k4!n9?gu-tfzk2+=x_ow+KSFh5oUpRUBZi1H{oe~GEP-GuSPOzhFew3Go%fVT z-z$`@1kNiQQ7nSL8a^0nH*04?+gFVT2ZO%b6DYxaa(X&ZfO7AiIMJ-;$)P!sBHAdj z)eS9-o%U!b68xY@m9;B2@@Ww)g*kE>4VK3FiYxV(>sL5<7u(-je|7yW@_s~)NKEKy z^sihp&$N0CQHAU>BPQ;9k%WPPwt<)chcZIrWbCerQk`Ivbs1V9RNZ(G8YeNcyeH(+ zu(BSZPFd;JcKV0Egm+@JhF4i6Wwwkd^AV|tku%|c^;h8*P#f^-`I0a#%NjQ4Q68mC zJxWkOd8lRB1ccKI2Wn%MW7S{8D)rUG%H=4nHi=OwkDCd7ziWaE+6MVa6vLk)hnA5s z5>3b#&01PuRtfC1VjL1Uz*xts%-Y-^hFHJX_&oO&OMp9ux2CqAG;A|X7(nh1A=BVh zQeYT+HVUSx*@R$-|h4Y5&H(w<$AF-H@rJv~as*YT7Tu~i>( zuW(C&!@s}l?aFIq3#|+3->&=pT|Y@t8-Agz!T@o&SCZR9-;#(zhU@~2jK~4Ebdz9s zN^sCE(W|{NF0v7&8a9G);b;(vt!{%)FeF5@X+DI)1XCiG6#Owo#gvXD!Is67BBy4T zl|q6?s)2NuF(e9mAN1&gQibC3&(njADB>*D<@4mcM9vH3Q2(W9R5GY4Iax z%-+YLtn08Po`l{ zuHjIo;ZU~W5fzB*g7=yx1)-z-OD|u1dAYjvrp-*~{>z@9thREVCk;PY1#%q5kCUs} znIa*_w}X`j`<>r0dB`odk^4a9!JW?U`aR^f+Ti}DcJCpBy9ifE>l#~w$Kknuhz}~H ztduBFy0mbc$yb9W{tLLL0jbX-xR$kSz1`@OxXa(8oYWTJF5M+LcWcJox^UplV{aV0 z)|_qY%ewpL^ntQftOJ{;zeKg6DXOW`6oPj|%k;K+qGn9oKszWLb zhZs^XCo-g-HVs1xzXl>xh#>NcgM3b6R9)nAt9Y1luW;lXH5e@cnNx*V=2YSRh=U_? zs#J|J{41AP@~wUco+KVP4bTC)#QL5OKhqL>!e zP@hmw1FqG|xK>vPyf3V%>>7~l5X$zZp|u`+I2q$74CxV5C-O5uT z+L40Ilh@!W;8juKd|p+57K99U;(x;06gsPUi^3U^@U)+nWn5TUnGow$_^%oS%Xs?C zOzf&vU5W%BADe^~1S|5uHuomoj5|Za`r?l_)L| z^y>Md%zB-6RF>m^gXF0P;9x=5f}Bg5dsJumP!!DWimr(D2hO zA(M{-kB&LUcywb@L0&WrLuPVuAIyrx!yA)|G${#-2)!=SP=F`*Q)0jl`5-O+86eha} zn%VpjER+27>|r z=hws63=3lm+rC=+b~qE*nXTNFDcSW->VuMf-)q?a5oKF7SsI;7_FAkZWmT7YF81X7 z!Hhq+RNB3`ZfWb`rT$0W^Dp_1F4>R%ToF>x**eIbrNJIb&tr@|V9FOvRR5T$G1<+r`z1v8k zzeu;qA>y8T0?v>BDR-ly_R{FZ(X6?7#cZ=U7cGk?-#UHu^gHc;u>M=?mwFCnx(?lf z6xg}qH(4GrD3W6eNRnbSD+0HiHp?-?3X&c(h_s5HG=_iWBG#X-)vv=TTsoPM$4f#U zm^N9k#91J`u|jxb?JE@9*dVsiT|^|$!^G#eM0$A zbR70w1x8Q*4X-0u_{dcBMHt>dA$gKnjR_UxhvZqZ3X6=1W??R#6%wI!MevVf86sP{ z2yEh#%<3#pBBXGm^`KP+tE_d+Dr*iINOlqhAW~IJ!+uGzNHmXK$Ae+CvXYX96uWd% zmLsK^r3BT~$e7oZ34Cs0>42+8JaUN|q)*%+yN*{7mRFu$D#z=kRV5O`Wg#0=k1AXM zC1YcQs(a=t6;YPsIfwK$t;I09k$;IgNB zejw}VxZcoxts>KKOpZH}^>kiuSa+>G(=Z4j`rxvwVWB$f3SRH%&vtB^JMgFO)@4uA z!q%(@QCoglevVXDYYVpg1t@u@K_EK)ca$kvL3E_I2sj3|RyS91yuTc70|_7C^eX{ zdAX#1u^X7a-aB^`JZ4^&Wnaa)mzV6KC?4-aFCe}Yd!KJ1h{KBxX9<`8UnW{{R&6UdXJ^)MrGZl#Kpvl1%Aq1%k1f(vuHvIQpiCe zNDUgvzUP(;p^u=y5H-fEUU^HGEOcD`J5*F}>4Lhm`jY=|Bsc zsH+G$LdHHA%K-P0#z8XA3Ayz35-`~*qRC*TdWz&Krsc5HE*7Tc@M*Lhr9!=G{GFEU zyP55?2pNBb7Z?;undRw3ESgS5L@Ax%Nk@4?R8r72HVw1PglJh}olV6cSQOl}R>l&s zsaP^C+VY4nM>{WuA)%DSw5NbMt*DDAC;89h`>*7*!#Q!)B$U8hAjZSSKJr2=K6xq)8*-SW z^I&+TFiQy|nr_Oqj-3BO4wg_(s7%QQ^JFuPi6)w~49-;hOqLf9(za?)_T zRKJYjGBpiS;Ef9}r8V#@7*au_g>=N+6kuw!P3%mV9Lhvggg$@YLMDz@?m9*M2@rr|ihMDP9 z3~+Edc{avRvjv&Yu#zW0WFgWLiOPCg=T#R=U zcM=p~UD|F-KqMB1%wSVaIACF5s8lm&U#@MM^NH474LMh5#?=W^h@7t}Lx1a*%c^o^ zt(mgcT-iGOEtgm4%G)yKZMpJI_#>0Phco>T=lX~7ccZEG!eLE@o=jQKwXSU0j^&Q- z3&;P$RrVLY#_J6&S0b-R7N2^j^qsxghCLUY*K3=uYveyU_CKT&$f316^{?TGOjvb2nvU6{)mb`k)HPh$ei?fh;0HHFx}}R} zR%{f41xU?`gM3b|v|`0YJ~vm@G@rUSwc??WYKpevqi_l5^xi5ZA4krPdXz8kcYq(| zdnzu3Ng|PPx2)JGN@dw7{5}Xz8e9)FD#nkbl$<7E(7)pEFd(n87lzxJ{vQ{Z!cOi%n!=E%j5t z4jdbJ+4hZc)VWGORneMP7fD}JB)+yte6hE5hw7+yN*h*$>iJ40`Kbae4+tfX5GyD0 zV2s_3Zvior*C80F>{A7f9;3RR1C#RuvcS>?1 zSiXxig30~0OFpqej2$AhS-=oOXqK;n87@&7Vu(&CAG?M(s!Z_kLka3bG5*nnF5^2q zqgtqk?Zlq-f)RVnkhBlKl2AoS7xIinbQW{DshP6II2bg2_z<4U|1a6e6T|#~V&5a@ zf0Fap5 zMK%b;AXjTO{>_^|us1HRbsM_Lf?agQyRM3it3jw~$4XM$_Ns5WuHyrH^|IZ4-gV9; z8K>tdGVn;T|Gn^N^*%9!xOom6X$Pm%k6%0e(~95PkFAgBU9P2@LZ-cxa96$_Sfh9Eg63cQs>Ir zGG%Qr9q~EmOdmsXRCnpg*PdK(E;Vhu7R^*VFz38p(ReBHT4dp=Y{i~A=d!P6$zCH4 z^4I7!tx4#iqKPquW+IaO^uJmOhYewj6Bv0xz+fhTK%Hbfn>xYICef;~K>8H^nR+>p zo(_<)dEnVtY=*^C;Q=;f0`X*EESic1rY8b3*uajB#LPc=YC4scl8n#t=)?dR^mKf~ zp-01kV+R7TjG7fCDyg19QL`jztmy@K6SXYW7TqKEj1T+XS;Zo(=#&sx(>DbvCa4S& zhog7wTx&4X(l__~g~`;gdV3{r@KU$*{9LSVM=F}Y}-f&-=AmWJ{F z8ws!pg-JQ6Xr7Hm8}jjXa%k;K(S#O*Ewd*rZ%h87>lga@x0nVf1+G~%h&lay-A=f6w&Nrc2xg+-|5A>TT3 zsDpT-#G)GYA@Z5Yp@}6j`UZMq3pww|g(PzL+h04a<^jSPZQ1yB4~7Y=s3Rb>@$-?|QER z1yC*Q5JSuG1b$RN7A8|jMJ5z-%J*Cpa?2qPg}hRAR(%MuH~c#)TbwY|O(5h1SU&u} zxW?d7E{dgV^@nis&xZ*4$UGQwS~Dq6nFQK0YSCSgtVQB&MdIy6;vGfeMbKiG;Dig% zLV_%#E@tzjUn}f62hrsE{2>A~w6ohC^70;>W&&9=w(BIIpab2Dj}=bj_%G5F*n+%+ zc?=QFX_`f=#NL7-LMrjnQxc71B8DKzA!;CAs?HR0ARQh-KbFagb>(xC{Fvr2yHx$D zeo2Vc;U*BP!#~J!ImX-gM-(mUCn_0>RS5cBqti0)Jp33 z`tQS3n(TaAATm=ulCGsbSKXDV?t(rBw&CViy->%r?Kh~&@-or}+*O;o&6%+W=A2-a ztK0QbR+5gVF0c%P+-kXepB?t&=2zWV*DGpr<=vU`?p%3)ro4a7alNTI*Rp{$7RKmTp5>DR4+kORlP+{qj^)7N2$8QB32z6 z4y!vB#^0QNWBT2`-2PAoe^n=7Z7$k3h>Cq3?o&iXAJG6IXf)=4BVWYqu@&jxqW=IB zEFzvvI^biAU(y*i4#}xCzfw9^*mlwaiil$d@WUO)E|Mg6g6VEBS8N26n4P*-=fS52Xo9s=&i z(MQ1AL`oThh9gi|N3T)@^)^VIDaDTo7N#iSe~c)>4)`5OL6ZX!)ZbG=(thxacdsK~ z1vykVk(FOhK4R~9GVc57quHj!CGeB#dY(OS%1f3c`o>1 z27mqskq|f5UtB+rY+L4RH$Zxho;x~Um323R#H?6MCD>UYTWVjmi@nVHyPdjHT}s$! zpk5aH+eQ6t0+3{bjGV!P&=Bg_+!TUCJCz$Hhha)H`I6+Q;4n?WP2}{Fo*%{R-iK|;Yjv{^#MJrq)L4~10RLm?ISPzd-?g=jvC zfi-}r9xA1fpY~|b(*O^hzPZ`+FP(d7{=|~6Ipb)#RYCEzS3`UOsBFqsz@@xdI)pi#XL%V)3&^K9lb*o2`# z*FPfIqz~KmVbezO?X@mg0;0eL7q%V~Y$B6ic8&DQKBVuG{pxSgQV2FF1}2lR^umM; z*6YC_k!R2WK;pXc*LWCTNz<874_`$-!a003`Plw}8uHbW!xkZR{ZXKdcrFfOIta@! z&Vz464n3r@PRj=rqc9R;`7IEBrX&4O%D*u5doC(ZejT)daFpag7dY(cl`G)#@FACE=Pq3a;^~`xu6Q zs^FRu+m2edzgm$VScUA!Nh1 zqI+eG!pOkg0u_c)O7DPC@&_2y&LA<}kAl;Wtv#6AHPCgWopc_u$pLf*&gAP=Uzp;Rh+BcyfplO)a9MweX{J-Ey#(!H;$g z{D5BrKPaTa4+^R9gF-6&ppaJv8$Jq^Xu(J+#fb2Od>Z%>zA}1wbjcUYIJ$3D&^->0 z2tlao`Vi!Yc$a@%5QI(yCTgs!oyikh9)MaK$3m__8xLUa-41Ph^qGdvo;^Ol)(I8RrU|;`dN`0zicdCso(vzkeJ@ z5Kh|vEfG)#=meV_V#to*>KbGxSs-u2Jt0rVRE(ZV$nFdg)Ps2Ly&}8&RXJU+s=svh zwX+LtZ}z^?d#yX$visfAOv}D()&4p6ZPiir*P7pnRc3QC;y5(BRrmPJyy%C&iINKtNG8 z?IvZ@Q6s0Ky`To|!LLDk6jIS1g;cahArY0#dcY@&GBRZB>Z zs;!UoK2tX7PSau%;rMT(^FKp@m_ghuTH-UE6SV|gpXr>$c3|d?7#W4v8lBV669p=! zej~I_vkMGPm+V)65v%_$r1`3s(?RpK8w57k@DD)q#a1|C)o!o;+N=ZhiqSj?^1X?c zNc}$Evk2tF+Fu*cv9J5ga_?1_Y`OOd^kDY^{9q<{PlkVY1N??D3Vyc0@8_XH8zHqb zyYyt13>xY{y`s>NO!0J-U29ZwC7J`lG$2ro4VQ$SOjSqVkiLbZ2ig1@a9AWw0XWnv zr4VB3q$yD65tS*p!hQulmTKT5)k+T`j5(d;$;4wpheMnRs2QpP&+5JPn z5AbF@Q|0Za=P`ZW0rD~EN9cMj^wVO$EuVzIN*M5mL}4GtL+%v=e!AMdPZbf)KTynS zQAT`ZC7ExxScnO4ryx`@;vLk3gt3`=usG5?Oi9QgW?#|kJxECJQ7zKz)*wCjHAs&_ zD$=8niu5R?B0UPJNRLAKNRMJfq(?pt(#zB9(LD~1sMn*a>!ZE@&-8ko+DW)Koi4`V z#4Dh?yB(@4H`A&&f_4?JyWD_TGjjnLJg zwVrou9YyeqS}1(InqP0TRtVe4^tRaRrKM9B$pRi*$pUhzdK-uGim@;936+{d2FIsgP(Ef!ET_0hn+YnZWeT>B zG5Mz8lD3#pUhy!I0#nMnu%e_c6d~g|SLqWfn%7=9kZaqTY1{hlf!u);nFA*v_NUjCk31~C z!RKzEYz#({t6l zZnCWm=4tKRM9^*_nW(jMi(0!d#WG5Df_x0D{U1tHM?qC-$5b!GiHdu}?hmXHPNNU{ zj>jV6SyZz#f=?49nkbYRyv!X4s{j}?Pd6scHUu}1$wtI0$wstGd8f6n`fz}`5BoH5W-^_N zfishN_7MR+>oEGTfNnfKp{^H?C<)7T9PmF(dAI~ZBJkuRN)??J9 zP~?NkJLB8u{Ze4f$9l|z{4+FEe-#P%2>C9+hr==1E?9Pp#S&L5Z%%ndv@e6xf!)!eg4nj(Vo_($CWb)D&;7+`q4IYg_&(`m|-5xL1Q}n zJu69`J7Q+$=9fh6EPuPlEFF&^f3s)PxY#+=1M-KEjvzgM5)paFO0kWnLRDdBNYwp{ z)EB#v5$DLc!Rgh%@2;OeapkGYPi5U*i{%-2&$ZHwd-FHKINbI9{)gWk|D)62KAr7< zY-#B6O#c%Z_Y*MWl0JkqCN%?l6XHWy!>%9}{17+9qVpZS9rMa?G_cR9qYXkvUA{-R zMEb^0-dB4{*L<}{`C!5Ex+1mG;h~gT;Tu3~m%a`>zt6YPDe)7(jZPUucX}HVxabfD zJ)sT0Nn{QiHQy()D&Hr|BMj!fm&58T^slviIGxHxd>_{HPd^7e(XrLyi78|SEoQ1nwT(z)hiN1fMt zmp1Ktx9hzJmj*_bMzBxf=_UKqKl!MPtJ?OHkDPE*1k>N~?C-IEr^&N_i(Nd8&aJzj z{kur2``rj?$lIK5_B48uCx@0<@q6Jo-OTK;IOk(49DWubzC}MPl_o?%J;D4t|lq9rtwro2l*(oC^0 zXrdoreKU~DZL1O;|yf_<0RaC`W($bmmD^#@oOS4;+j;YVK1(Daww zo*cL5uQ>moakYQRRsA{Fam#Eq7=Of(g9Katoa;v9Uvnoi+=*L8o56@rU>S@z%Q>TW z?o`%L^|7(U&^bSQ<;BY{UU~WQ%c~rMIMvNqy+Su`cBtt}4fHki>dVzvnl3l3atKQ4 z;KQ+W#)|J5JW`SMhQR#8R}No3eC7D%`u7y93P;Ee~~eDDnkq+$~zBrt#=*m=dk diff --git a/mace-bench/3rdparty/mace/mace/modules/__pycache__/irreps_tools.cpython-310.pyc b/mace-bench/3rdparty/mace/mace/modules/__pycache__/irreps_tools.cpython-310.pyc deleted file mode 100644 index f7d4c4f4b908409ea738ec04bb4e3fa22abae401..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 3331 zcmb7G&5s;M6|ax}n%SAz^?I@65ZOMMgdt>&V^9bf%aUzKP}oGVL1<*tYERYd_PD2e zwz|j8E9k6VJjtwocQ(!RMxR?zTm|h6i^(+2bUZ1PhP-xjc6)o4K1xiQ z4O3(Ew7hb}yl=ZzVm5UZHyt)Gx|*4MC~4cpw~RmOmfFg{gD5_in{4i(`W36$9>2xz zU%4md0ctpp(8gWc+mKtR?fFt2)k|v(`ZMQ4A@73M`pZDtCzJI{($$9TjY~KC@8*wF-_av(QTR06) z71X}+RREobPsQA?d$m6zeMTF1t(-lY4LR;{^aWr-)YGWtw?Q^>srxJU{5hYCxtx1g zfY;)Wtd{3k?O`eZ&teF&1|4tKGZw?(6)3wn`o0%+R9AlHb|yZoh#$hZo$V6^VTa1 zQ`vy1R{do|>M$)L1a8}MCQ2S;da%6EB|wf5xiLO zBVkCFv^n|rkaT+VHD+h*CK}2IAaGkzlPkUh*95xZ&cHY1(~yIC2!w|Zg}G3ReFjGi z8@E(&BhGTZ-KqWZU2@&85JqeM75fXq*_V9VtHi#n;lhA`xAD`vyV1qGs4k}Y#lC0G zV_)0(e3n<~MBg^1FwQvU0^0kLc@MwF3`aVfq{@&hnzw1DKPqOh=eMY#PlVj=7@|&P zX|5Be_ywY`U_kjMh$lMS;~hb&^H9Xsp1;bWK51wS05LS_v>Jjx@limaulzZPW&^I+ zK(sibzR+^)DXD!OoMQ^=8wD-`QM~C6Lgb8S+)^IVk*6ZnLNrH=^4MSrs#&r~<-XoMds1!dM<=sErE_@o3DA>YlLZ~3JY z6r7Mw_AR7-xM3)W8)wP@Mg$=L2S5{lu1>Y zW_mkIb|_pTf!J_isK)923Z~hza+p;8miaNwzd`dw`h}GgDE$XXHQZ{P-8uMAsM{Z< zI#aLvc{nP}3K5IactW$J))3O!FS?#dzNHP&oLh$KQsg|2VCv=3?T?6TyMf90drs z*ceR3#tz(+kO38;0N+}?i3&+7`P7?BjPYuTzGe)y!Y$~@Rnr5RT!4?{HN1KDeZ>JU zaUk~tS`7~{S)+Z8lD*%6=s6(?2{};uQ99J-4HCmfFaEAJMQLuK$^04w>DeSXQf1gY zhP3&JM*Hrj=9e_=S2WC&6u0J;PoIrA!mW|0NX}Nsc zWDt?##qkS=n8srCB*MA#${pnxhF^v|Jb{|kuan-&OCW1O1k9oc;SmguYDctRymc{H JJ;N`(`yYx3EXDu; diff --git a/mace-bench/3rdparty/mace/mace/modules/__pycache__/irreps_tools.cpython-313.pyc b/mace-bench/3rdparty/mace/mace/modules/__pycache__/irreps_tools.cpython-313.pyc deleted file mode 100644 index d9952a6f58a04d837ee593b5827f6072247dcabd..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 5857 zcma)AO>i4Y74DhQ=*N;Q*^({)C6S#twz9DuJI?NU<75+@WH(N1vzkoM5`!2^6MM4q zNa-1E)?q1hqT~Q&d5bM86jgCfQHsMFibM7U>;Z}sRyo9CFjXlQ4msgqL!hVvUe8GK z4>F|7?tcB=>z@93?|ZMc;dE9ZDE9_(Q*k#!pOa3h=o08`{Dea2W5gkjx`c))ObyeR z9%eA(Cp`L+37ep0F0sRAY}R?EOO|0Pw(2x{$u?YttAI9-n}hYG$+ue0vIpO~2j4c% zRF;fW{WJ>`_glHDCUnDu9h|+gCb#$871!8r<{bKr&T(rS;+!YoL`Z2s(5r4joNL?^ zY%zG->)`(FJ%c{M6x*T*yE!+EZu@Udu4WH@?RYI`t1N@Pd-RohCh36p8JRS|Xr{}O zqH5N$get{lA*Qicl8Kl&8lW^fejI8?BtD&xVq$nY9u=YQ9ZH_RnUvlauq4P4aVRcd zmnJofgt3@VdP_u~=2QhdDXL*Zpqa6#Cb1kT?oi*618UOW_{^{(N@jYHY=VFGTJaPS z;!){DJ?QNSLP;7S*-Yxh2ivr$ZCaO$S~gwlhoNtG!PUd)K^W->eRdK)8HMp9V8z$U z>)}nM8DuAK3KC=4RYD*3(eIRHp zQCn>r+P0!r@;#m*6>b_cYwf3s#s#h`hAccl~d*D1^KeAh{ zk!xj-?B(5j4e#P>$tHOZ@9k&5NRQMRtCCsJs0B^8?7ikJeTio|25i|hK+A0CPyO^E znAwjMv+ONtUq=CU^oqtL-%m!$mj}b)pyBM{lF4m<4A;p~`LJH4Kbn`oe5EnO!XYc(S(528J$c{s$-+cT)Lg~2lSRgeV>-4kPopG_nNEU}N*LCyQ8zRy zOk6H3Mkhth3Ab$PA%(BhY#GUkCFJ@tZRyyia~8_YY7rq5?iWZM_Li^+8cRk<>3=KZWV|Bc0H z?)-|i@3+SW*IC5YY$6!2S%th!*`p6f^PZEb%TMif%l3oIri1_Zn+t>rx&6>kevXn^o&t`)$e9ei^+u=&Ti{4M9o+6? zFbgT6$?X8efvft_7)_vuXFClQBoN?xbtmvuLx|?7YIK|_ftE;-QP)hxK`2O8ysH2} zmEjP)*))3L%5AG(j;nr2_6JP33G_8bFe$6jw0IulIM$s4Hv>Ol#l6rN=3lbaX(@`^ z3HKlw$zG2q!H?PqqXR142}`zV*keshl0`wc`C^#CurowvR^m{l(9;@EYWR2Frdey6 zD!A*??B~s$vzO9C1()}`-z>PBR^6?6ckANu#jeNhV}*v64EtMm-E$ji3A}J1&`++_ z?q6(Q^yREe`<7hG2hOb2o?UjFeO`y^nl~Geqk3-O&OqkND|Y{~$&X0}1704%QbLUY zT`ra)web-UUxgbVN0{Oera1~T`w>^gF%Ys#NmY3m0>B+Nk6Q>{ax6#pn|ahvab^hO z)*Gf96t+TwWD9z?4JeWvTbbW);;ecKVCQVu!J|_YcEa7+wQ4LXQ$j*4p5wD2m}peL zBh4O3iZ{a%JRn>! zRvcA!EXV>b8gX_KJ)}khRT`@#0iQK%7{FOl!{IF_FVwgf%Lv*}~ySOi+|?SRqGbxK2v<9Q-Dk)VK4_yGtuL4?^W8RDVyQ zO*67p&sy@Xwfcqy=_6^;yi(thHhsrlXg>Vd+EDP;L&sKXe9a4A|K8VU-vTFd z*Djdn&FQlRhkNeg?8Q|_OWx6vedn>`z)lYL#L@Z}3u-u6^dMD~Ayn=_OtJ&r%>D@Y z1rH8NrU5kpwSpknUG%i?{I=8%1XX4laNN`_5)BMt!FF$8dGhi9T#1yga|7P%f zfaYiah7zFg7eMgrUxc;;eh}Q&nA=hUpCtp%K&4rx1VvC)Yz(i82{R<4ghW@f6usVv z&#VfLHl;BUK@He2L4Eu-8Ba?eXbi~@*$F|7Oc`;xjfe?4Kya4CSX3c*R|?5~BPLs} z$MG!zN45H0=u+epe^t&xtVp4PvpP)|s@-V@(!GVj`9X-%>{@+e_Rzh+-N53N2k+j0 z_ep)v>;-@lZ`1dCvlBVyq2>PI^8TLX=AOLwSo(ayQJtmcJ7!1HUs!Y1XVmOhR$Z^; zU9S{8zE#h`Jo)P>_?lOJ9eH2Js;@8a>sy*y@trRC+aEaZJ9FZr))oKYUG{0Y&lvrd z)_Ws&M{?~;!b;2Q^XH#7?7Qc@>s%Clbs$gvw$`%h>CSt)mpWHGZ>%-@?+tu^V3EsBTW-HbJtcW&m+OxCm5mUp!;TibQ8gglEFh<@ot9MEmx6z~GoB&f*3 zv=|O+Rbldehu0C%j&S&9Qiv5Lu5kFegcX(KQgVDp{2`!;$=gSd7nod~X1=0>GY0rU zS{=j|RZyi!nAn;$kyJ$pH=45;c_Lyg77l~GUI1nNeRswPKLcb=QW1kwu0r+qZ;{u+ z7pW`MbH>iL6-cC+m)3Q<`_<*wF6Fx~>%FHo=L;4z))A?in%Kix@m}n1Y#jlaOMdd< z!w+8&@-L*)k4XOzCT_2prbGb(BfQwcgi>~7GUX6+0jzU-9uq8v1el>*04OR372|Qc zhfMzkil@6^zL39C@Zp#Nlgw;oP?ho-0#TL1S^3Kqf*QBp^a!$r>Q`3{I51naxr0eJLVhl7s~;hC#8n zMI2Y~2(YjS6}+IaAj50cVmw<#{0`A(^&7+jQJWxYrfETWPh+I$2MT=4fFT2dY(%18 zFn$jtSgIeqSH*)WM5_%ZQSG}4VDl|5w&fO6*-gHaeR9~>W<_2d6 zvsAV>E5J)&)!v!6cjjt7@jdjd*n3mw3U!~rrU%i@Hok>XYBqr$$)?JLO-#84D zcwCnC3{Z0tvOB3tvF;m^iive<4dlm=Zo`q4?pv7r)&S3+gkpGUGnwZYsmL=(qh->_8@(elyQD zZ?d4geui>Ut=Wlt(p_mC0r`mjsqIsnPX27-m(tIrO~QSC2+`H)sTIn%Nq16*GyFnm zKC}>?4{spo!c(7aSSNz@s}xIlGj$6s^DXNL$VL9a`2BI6d=$}}pNYSW{XF(pBGh;N EALeDH7ytkO diff --git a/mace-bench/3rdparty/mace/mace/modules/__pycache__/loss.cpython-310.pyc b/mace-bench/3rdparty/mace/mace/modules/__pycache__/loss.cpython-310.pyc deleted file mode 100644 index 1bd846dfc236402135178704736eabf7f7c1c20d..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 13402 zcmcIq-IE;GRqwCq>6x8bt>m@5`W{QRtg+>lB|Ek;7IMjooq)YgB0C{lkg3t!-rX6^ z&TMY?O3`KnGL|bLAifDC#vmnpEk$_afhvj@c;Fx4fqpX-1Xn>-s+I@JQ7nGH)6?^% zp53V;^iy}|CLzzuu z)m$`9=B=i+Xvx!Fbrv&HXE)u|>|$2(PSab>E#{EVDEFqZm{-rKtZy}@l;@j%VV|LL zEG@YzucmHh7K=Fx6a_!0ihg#4##7U37igxprI}H?L9=UHnmy_P(9CR0Gpim1&F*b! z_Ns?Kvu9hHIkgWo4{S@bUmXC=?6x!q)gjP4xGl|Lbp$kfg{CB2epu3@{zD_Y%&TMS z5zx-PYN$um@tcM^4$6Iv{qNX|2T=EzI)SPo0Pt`-KsVhP{_aOw}Q65P_Kn3WzJW_+DeUtgBM_Iyo@`TMG_jf%{xXr zho^N1;J)mjb8F#46K!QXa#TGCqv?83skiE3z1pmA_$r!S_QT56wr;9QfQn5+=RhvW zM{k{4X|MXHR;wE;)pny^MXu&Qxu)9lff->Uqhy2Ds$b?+0b~$Xqg@JI&S?;Xp$51wC@<#t&VZq!PCYwgQpX^jcf;{ zhKD<+tXuZ=OlUUp9i~%_g0gQ}rqOXa)~wOV955~#c(Zs1&$6@eZ}}3g&-rS-<_D!} ztK>KRRlgNJd9@z+C62!o$`a1!h4Rx&B}}JO4@z(9bssON5?4gmFRh2ZV*S9noGgVa zerajFRU;xwwf5>7sDc;buJ~d3z}eD?X1i8xR)Y1_(&MH6I+jbPO3TgmB{a3xS1aA2 z1@Y@|pl{7lo-EbFQp{cq79Z69@_MtXORIjhH7}!H*t8;7`B&;SKgy`^+L{kETJ5#+ zv@Dk9O6V*T4@s24*ZgK=V?0qNZ0p)e7`8*szq+~xjP=2*xf4WF@mB;*51}I1i)2r} zV0!X%%)IH|+hen=5Ap?D`%96thSzE~%E{#*)5BW{1Lge}_UoT%P+xkJg5xK#a*Q?r>?`%5y5bE?{CY)wG;hyNlqu)nnRN~iE#n-y6 zAvEmsX}l6}PeE<`irKn^m2AtZav!MBb5AWg9N02MKk-Zk`$i3>U@ zIH*UE0CC5ee~ifqCbDF9^->^%)@9a>EYGpwfq}N&gpdv`)T?+e9+^EdvM13pG%_w& zvS`7lP~bw{hF(*avhLW_S`BNU>;P*{+M>yVWw2nGE*7z1kD@VqBTNVjvv`YcySJx0dS5L6<&yFewjrBSJjFw1dM;OTZn!__G*Mum{PEJ!4vy zg_WF5R+5k(;un=pk6#*v9r0vkeL#zHW_)l8SG&h5;IC_w)UXWh{4xM`ft zUa4zrt3f#%EUs{oW|Ni>8$O(?&CEv z6Uf`+UGDbeMoflRFN8rVa@#l9*b7rUzDj&<7sQmF?;J2t?<9cdF3 z8@JtIes@uv9YitJjzBqbm$0|^LD`E1b+`7{{AvixCiHdFuU_%vB_Ow>+*3eZn4hMr8C*j-v}y&5aLhxyuw?9Pb-CJmHvydA|;{04|+J37qm=sZe8li;%P7A(Dj3?+2{ zLj&6pBr!&ibsPHb{;dcbk&|phV2i?d}qH`Au{c zP#%uj`(ZHX3&;-5j~jSTR=xyt`~OQ;u=wsWxbhv^9?m`k)@%IqRED-|T!Jl(Hf-%K z#*yo|kVx3Yuxf=CU~z9**S(IXvX_nPIeBJ~dY6rtjO%&Hv7D3g6#AHgzRT@{_UmL+ zexCtbINQm+$XT)lHPQlGNLuo}mZ@$Fp3*YaYsq!;=Z((Pd1IH+rGWz-_uFrI5#|SX zxK0)ZNFgkA3SF&Rj13T3_f6;f;6AP&oLgR!>)+2UfBRd!{yE9~%{dVTy$fWUrap(q zhlU!|2bE_1vL87jLnRW3 zW+|pAcm0!SkFrZuN*(+h2vcMQel*3GE3r^TR_!$IGq|5l)%V;W6g)Y4jFr{E@4ok4 zbRC>Q($_KrIcF~{-WiDL9@xAy`0bb9yjdW$-{P?Huz`~?OwknX6;{?S`MRR~W>~FJ zGbI<@HVTLccmcfZSkzO1s!6DJHx0N~Vo4eeu`;2tEI>A+GJv5gPnKMP9n?N@ECZx~ zr3av80U3{wxlR{1(=h?;JVm0Br32S-X}2;}UdK@YtjgbVI{A)E*!8HgLWNYQ0L?B{ zc8JOoDx^Zi=uX_1H$;VR;EVhczgX;1Q8)ncjxDah^n(5*m?Q8L%<5NJq=~0L&E%(% z5H~RGi9>i?Iq_S#v~aq}Zn512p!#P(QqIQ!9mb)4oxRcCkE~WJ$|vq)wHf8(rK8&v z!*WW0p2S~ZLbEIOLKc`q2GF&KX%M(^K&I5>&Kb*D}5 z-n1Q~Z5k-^%vm#k&$CCvkfUQ?+=yWpK8U;qicw7#?il_6#%ArS3m6i4|> z)%7Nx^wKvG@Cp3pQtZu0mu8f!RBFvCJm*Thj`Wdr9m&7o-lE~1H{m`(D>xs+lfu*c zc;n%5ESgV7HlpE%R0GulRRvYU7V!{LWZ);iV`DEzl%|`ft!Rn1ht1f@%o-1oGv)#; z75Z?RmeIgmT)oPQVY@k+lCVJReRc~m8b@K9MfKIHR{w=gwm8HBVgR$|MrrI^yD!uz zpvlF2qV~3qqCmT*VZH~o$Jr?BJx9@NgV%wa^S+B2Wfnf#bz;arKKqe9|Mb#FGbc;& zSi8pQ*^kUJ!USz}OANHCt3H;EIEevb#vB-tM^M5P)|JYCgK*13GHbce_}0B`*D-!) z(~40~=p!%S*YvD5i~1Mwi1L-nss!_Z#zLinct*23E`msNoiznRTcma_wcAYrY2?Pi zOnrroZ?m!3g<{P|ET{`dnM;c#0?Zn6AK>2FkHjM-Wh7-QUL`RaWM(qQp4{C9 z%@#|E+(F(B+&|KZB#jn!B54xm8k97k(4=%-G-N?BE)gh>3l?*VEa`PLjfjmbpAFQD ztzUocJ@+G!A_O?Dl`9I)Ha$DmOvyAE=h zjvMbkv8=cQI}oE#eI+b&M))R8enL`Gh`Jk59N`)OmM8%##6x5dH(zc-2t`{Sbekt} zuZ`n7FlL;rQS-lyTejRZS>wdxZHZHtT;_cb8MWlbp^uYL{4_JPW4v>8Kf4o(?R}49(lO9pYeHcUKLbv$okwJ9ZOdC zQ*^Ah(FEWB>g->=`oYG{b6ty+x<0v>@2P}KC=93sI9e5+dLd#LI>0s3+wPEEYQrvN zi?SW;3f-@DTekb|D%p_KZ?}nMB>qWdtDNair*q(^#qCAl}Q{ zE(7@kmgs=CI5`2^CQ=!clEi=u=JD;0IS`m)z8BoQjzFb;6UTy9LkxTgO*seP=_WY^ zW;fu?-g0n=2!EE|pByJ9Z_qzQ@CANJpE3?6lhPgNn?9r{KO=n+esbkZ*jEWX7V1B8%#z9|2FC_ zWAOS9klC0J>s<}$jB3a`{J;dZ&27*d}K-(~KPnWU(SfTQiIiV>%EKvnD=8pK3O zfx#^5Q*NKZH}I8hm4ey8e+&>)#W@|$=%8U6Hbj}HM`ptv|0d2#vJ&ZrNb76Rvj|im zLW?)3P9ShpEjVTBmqZmziF)x63-v@5cwrGr!K6~bJK0#hh`3I5Y?Xy7%Jurjh!g@^ zXoa?qa|3rkZ2Va7{6KJe= ziD>Sj2XR6N`^ZcVi5zIEY5WrXWr;T9_UPXA+`jHAxSrLgv!vyIDla>r)gBy{j~|!* z`ddsT)?fDtV63V1pOEDfD5gIp9T`(XWz-PVZjRj9|B2xGePq(z#sSXHPlV_4 z{Vvb$lcnwel4x8Uen|IuWF>B3K%=958+9LWw6c{?!1S(gywAQS5RN|s-QRP$OUMie z$DS17C}>iT6^($iw(ZE=fBN08RTK!8*;4X3UhnxvZJA(XR+J^et>DSJTLfl4g$;EsR z;tB=rGa;aLlSzuEEOFXvXxw;oyH8GpUTUudz52`iplsldevm;pcH_hQ z-$UyWcK$0QV<~3GBlrjNwKeRAa;g=dXYH21HqQ_Ie><8QD9FE}qdfmEKOf={ERa*N z#r^#XRO7$Vm1Q4)u*YA(Yw_6`T~GbXOuo)UVj@SGdxD9qiUgpaW${HO=b4Cdcni5G zv(#)?Lpes*bY&vj_~h@WAQwjP zpGfF!AYP5zb4!kO=zg2SO4&86=@hyYdnJ`FB^qI+I09BNuMjiHvhZ{9D+z fe&jmVT_=u@CYX}c|~XGBpXM6Y!$eKBA1h` z?C0EZFpV~{(&N@b1Q*|bv8lx zjt~+;renghCT0py*nG^)%=~G2*2--Bo#mMQ**uoV!`5StXPwN6u&rERuFz8H^C3Czl zZ7*HKtZ{|Plv>KlglZwQyj%#Cqr~RStW`s;o77O|x!EFzSp}!A0CnYKP*-y5Dp0R@ z4C*RQy%N-c$Dm%psaJuz`Z1_ia_ZHfu6YdV0H>}6^_s_^Ud5@`g1YW8sH-`35Y+1) zgL-wBIavQ84SvuhIZpH?yL;o&o)$!0Cu8wMFFVlPkwnb4Kbq_~*O8eqf?!63rp?1w z;C)9pt);b6N{*`*HY;y-o9dKMv`i2JdOB+k5uzEokZH%)1ffO7>0IymSkw9F<#W;ASGuDJ zb;KI_Snn&bj-(tp6N`78Yud;<`=Ts)iN904-p>A>SfZ(?H<4)UyTq&P?Cgug;?Xlb zu}@1e-XYmew_t}w;ck&%1MS(R?e5hHB1nnZbyg+k{qUPBNlud*T+PtK3+ZO-r|)rXJv~m3lA7RSV(ut@$}G z5k;wmkG>6=wp-#h-S;kgG*7kArq5x()mt$gb6v>1-5e0Y`y0UQ+iVT^q8e&LMQ*;# zmt%rjE{m?#SZ9AnED?yt1F@di`B*&JaG^U93(#l>l6++~x?D%k1~Bk}?nK}u>yIIG zXMmps(;jCIm`-NddBB+tbHXZzzFv`7OTz32OoE3w1n3c&G#cL(ZvdOM;$kwpLb zz`8(Y5d;HGfv%q3Gf0Z}pN}X_O1NAnP$r|6)d#wh0eJ$VfvAYZy83&fEO0&+jW@c+ z&62${_Ih_mOtN(*FZIQM3Fh9uppVae=0Jxqnh%VGGs%WZVm*=t&5&%#Ue<9=vdbUM zXlPh5WpYJ(dLoSXljM;5>qMVq3AH^ZS*g;Xo%tyE z@A6#VOs4$Rcm35Ve~sv`xmUbARa|$sxNf?5-4$25u>3~jP~&uA^_9GIQN@j2L%Y7c z=ZfQAS=CL?_dID&(ajShCnmQ2#h&l)d0;UwcU|$$+6AZgy8oJg*!ynHdky0asdYQW zbvy4mcBTELD1Y=TqJPtrV^i8)bba^L-QUa z!ob@GDVRvwO!BhlOw}2QG1kifm{cTNhm0HXOHg1yKo8{?Uw`x3o5RnH zI!2xwJ1|i;-Xd1+nD*}+IC9TjGI(*?T{~s1<-L;C6&gJLPZR`r6k74jbsu`Qg9Av9$;V?UJQ0$t3%QSa;XCq~w)*Ga^U%C@}(ad5CN|LnJwAT<7ZdA@%sp;2W;^gpoQH>{;}8m#-N+lt8#WD_2i?ORH@ilI6iuT7G#o)EzU}(5(=%vwLST)|%&}KEets zvuB73&oU+34U#aC^})jew+ki9Z@f74;*H2qWVCgxU~J!XapOQsx^iWzvO%nDNL4nA zmCX~)6R%CKny%b6WnF&XQIdA8q?H~}R{CIKc-zodMw4SlZoMUzZk~2+q3N-8No#zU zEd3h7`y$4eq|wKn9YU)H^XL=BVeZK+#V8GPQ)VfPXqio$eHsgC1)Ioqb1)Mwk!H<7 z`)Lw*xC@_|YYyhU-p!ybC4zY}x@<>&$)~(jCm{zPf}9LbLH_KS$aWHCi8}pqN+8>V z%xZ{_aaN-;n=0d3s1(5VBYzCP#5?fjnTAz?R8^B$)s(8*Bvx&jXq_yW+&5hX1*<#_ z##A~~I$SVZI9M^9xOrjZ!f0a5Jp86u@Kh>)gP6ZzqIx1aSumZyllT8t5SSd7zm8SEcN@ifMq*PNXT6{2bXD> zI5aU5-g!nM_?Yn)vky zXv<4QL+WfYL!>s6pJg(36s@?3U*dguke?kwNkyu-{%&!7S{AtmpSf{(=aph0dU09fLyNh@{u8_ZtKEa<-#9;G6KqA0zpJWKl^gF?ZUlYGv0QX4 zAKrL#>&Vv84WrG&JI9LND;+PLD4M7qs}w6XqlBvL63QV7I!iwV+&_o!TS&ucB^fBk z8RSK~kog)BY=ft1Kv&;P`PYj6wPTjCBjesl z>ty%`KC$l52m7W%ZBzcWQ~nng(1)liIP6R!Fqb&BIDN>FdFg?}iYmA>0x9EVnXUam z^U+L5dd-IfCf&m9?C$IBi3M%4ID80I=BPS`lJLChj9erOp{(!Xmmracvzfx;dAP+BtbitUNO9Kf1smX}sarF;ol9Zt6%Q z+2u1pV9O%7*5||(t)5 zb+we_ddPvjgkd%MnZ z_fLDfzBfYdZ`HsL_HsVK$l^b4V$b36VJCSXkXMhPVw{B>qbQ+q&?_U>L4-461mjS4 znBJ9>r?ZS+jpPjYM#E1O-TWXQRZDu$ur@BbgI1$xXx&ndM04 zjt%%V;Vk- zT!SIFax^qncI%~7?RK$t`<TxMuj`NTcYg z8+hiP)0c8Czw2B+yn5QXVqpKh!qOY-zO`={!*3z`UO7w4@G>iTn zk1`_31)H)uBiv+K)*|Qn&%{`yGuD%gc4)jEooIK>(ZVmq-LDmzh zC6Cuqs@1|lH$=7EeuA|)cr9gGE%Kols%7j6*1~%zyj-ir%R^MlIIo3(n7i;hI#~C$ zur|EqFfAM4PEnL(J@#etUW7;ZiaiOM$JpX>@_izLKJ0SQLXhQ90&+&<5_HKxo{iV+ zB}#9HCz<1M$)#@5VTZ~kFDlxn;z~l$&CXKhF7nQiN8$&2nY<`@XUOXyk2ngjdK3FK z1QN7Gpu=hzsmt_8E!U)|7Yy0|6VaXHSUuwr{N=xU@aFOF9Ve`OF;%-otlcvCN@`Dw zxTggeS!I79aG-=LSX9Pg@WruKv2gw4Kw#SKy}EU<_3ExGD}kyS%04yiT!T39F?=qdV^DFCT0n_Lhd$UlPP4u$sGqjD`v$BAy3HK4rpt) zn0~*LwyPBPTj!|kX}OTFvo>V3vP?)a~OiI zS_kdEs|$M0`Ya39XV&FdLVaBVZNaU__d&>`9cy-&!gk4)=kG*ni z+k<;_zP06CGtgbd;wZo|L0fbX2B2756*^*N1lL$gIucbdmAKyHL)LC?}ucIRa>wtUwcjPavA`m_C?BvI`fs zS+va>>uJivUDkW3T!u0O-<~m6cyTBLb<*}q4Th-&RvOx)91A7K|3SeC+C||4a_RQTqC3w_mmIm`zE{#T zVVT@GU9$IzJMAw@`B#blRinp5f73)ws`-G}eBjO-VsoqLZ=G_q%1EMh9*QGdFYnkM z1m-gA7+G+A5XOwVHDvZfOE&ISo@U)>N2l6ae&(;s+G*bXGfU&fvp>tP56Io67(aJQ zrXUUm44{FuHsaCqu}DNW2qr!V3HmB!%Ne;~W@T7G^UZNN%Phzq?t?cV-1iqUCv6fF zLmfA>7x3U7!UTsc>}!I616U%gZ$lk(!#c1;KMcrAPTpOw5%|gqnq_RJFNI0h0G8R;k(6_;fiD}W+@J>( z*tGyv601Nc!|wVypmxg9^eD8_dgahcYw1Zs>qaYtFl)^eM*`v8Rr zY~(3HGgb$YGE^83y^M=aeEGw5pc)X;g(Y+}Emas43xiVy>&BWUYf?Loi#v{g5SrTl z{M7TUQ~uDDBP6eH4uB>Z0O=$gA09SLp;35x$a#lel1Ixe@f|sdG!5-`9jXhYP#G-Os=u;D+=cv!B6UeE*avZEK`W7usWXC4ndXy()E639v z6y+$Z;0;gDBn=>4!y7%~uAP4ZXz1&sC3p>`EE#b*D!4 zg1y~Xnv^*J(hLSWBfulwV5r;d83AABAa<#AGzoIHY34Repr`Rj6HyS%C+y!VX`a|P*)d(R|BCy*CqLy`A$nGfR*0VZ zw68Gb3y8kJ=uXksh~p>Tru*5iY-Vp0U{q1OI#8YFs4Fe_J){OdNv>r)gt9_lJcK;W zw$TA)!QR5YiT21G?spKn3@st+RN8LN_{MEI#+S1m4O!sAYf)y0L0XF^QSoD`2~(R6 z+&P&#cuG8YYASqs>OlKc`%9!aL~!OOTXV2vG>1j8(D$fU$Rn=tkPM2Pg?fpF!kY;w zEYxLlaTZEXVxja@St!C93q|1{B(~ddSaH+vd(cZOi_v4837I1`gJ)BEc}`4FOL|e zjOw%KN;3d9)MwV^_~i9jb)^~Cr*F{`>gy8N_t9tzX0}(b(+H=7)~5Cu;S2(rlOJ2}kWa z@PXQJ5WL@3T%cDvTqcb{2M+0%#LB6wKrF#l(I5I^` z5)2%%5ad^*m4o9FR$RaGs^1X~sYctZOIe?+ReaYB?cvXPTcSS#+o83@uDc%g$y&m9 z*3c6EoVO);x;{y0Ey+9aXHTpxy3xrQE62QJrD3I!y>$t5C3{rntMk+9|M}|7Lm&9m zS($mtA}7-FGw;+;M%+Ytq53xSl#h}$;01Jhr|#egzoElkK>~Y~yaVKsxrD8NhmFI` zdFnX5?Io{|JdTa4DU=PGIErc{$Tm|>O`2tYj%dOGPmy8+Ix+UC#lWf}%N`(9{U><* z99(hPoaaZcIRDP=yN5#wbt84D@&>WIAyvLjEZ>&SFB&{DZ*)1X z0j0|W>5}qP$r`a_&5Tp1*ghlJDtuRt&3c52HKVUimv6mtZ0_QQlcPsQ-kL1AlS~~x zB_2M7twAcm%uBF8xe~4(`^P8qNr5@~WuXsw96D^=cnf}^6$&;adV zJ=`QlHm_d^@))RWw?xI^(-!2}`)D5fD|p7zcuCfvYH~L%K0~=X{c|LDw?Ov$#B#UE zXq&^(vWkNqA(OW6ZG*f|8@V|yZeLJyGqoD%YoA1WP^HmUe@%^!YE))T0P zFw2A?(;1Z6W(nm%gK<8guX{ek5AE=}TQlK#zsRNgqIa|}a-I17C)B3J^tjWwca$Fg zV)yNQZFk|j3|3v6gU>(1OJ!|NN+T&Bw3@D@&Mv%lsp-%HTnlIc*z)8eKCKv?M&Oe* z0=_%3oszip&#fs3DSDW^L*&h=FEFIyMB$uT!Y-s|E(d3w;ov96bJYmw_?j94bwZ~} z6lb#pMQ8X;;t0pkKMU&&>A>n#V1pRgkP2)O1Gq*rxnVLo9e8@mQKcmzd@oR&3T(X_ z*qZkEQ=Tf(Q>EP9d<}n8;El1+d#7*)Iy7-+tX*8W9rw^yxqjj*x*8nZ|BVLRWOJ_g z7rWr`8)_4^bfZel-#iKZWM4YJmjB&>{4JAEQT8uRMQJ7~xsVKfx-yWetQRZm)5}U9 zI5ow_K$TI#oK{H(rU#`nvvm^>mw$tuWSl7OtKAnXHL-v6Dv1$2%7 zL3y}GOHQwApi3#%Rp3Ae)PM}l=(EG z3+~B#_)%Qt&p*tuDQ~z1tXuz$id6E~H26`}I`6K0>86R$9m{m-fh(TH?aHqbeXB-~ zi@x>vBM{yVKTo^z!AE!H3+e+(XRC*I<@0F9+wwGo#`d>GtVo&h{g?*j6YR?$ywjOF z^kwnTm#1EQY3kt1Q!htoUp|_QySQJ(QbKQ-4yy`@)iJ$A5k9iXrat={?%>^ z^20myB$1n*s>qG7CUPVEaenUkS1SJtg%gJ7U**ZgQRCmJH-9B=W7zb6rFbz{);?Qt zo3E5hXr&4z+|b0$OBsA5B!kH(>&Ahhz?GQ*8eaZCnh1Z20v3;#e}dd|;U)eo3;_&x z+1JC2Pn|0GYM;Z*r(PZ-c9!FoboTV`L5#oZl(}BUAv(J+7nF15D`zsm%^~I*o;}s| zx`O5&uhfWaM#J{f`vm(g(fzLa~KPBPIO!hU3kgDXgT==S)Fc!}tJNdldiGy<#CC z;5g+4^aU4G*;@!1UHYw^ua`CC4N!0YIu*8PNtKk2r6Vt}l?1;Wr=9CJ+FsFJ%$3jW zTiF)SGm;y1=7HZtG_wt@7i4SZ?E(b`pF$7l?%VE|Tzw}rU2>G}+ct6mw`9-H*L_>| z4eFmk@(iWhg(ilKLd`TWzC(rZiNR6myA)kO(tSXA$m8qymR10N|20H|HTX;Jm-?VP z@++?L<<mvNG!%Hsue+!LC+=xx^E3#5~<^^8kZ^TEsV)zdYxC`4M zU!?U>^}k2n?~=EgymjPp9?U1xZi?6T!c}tYlCK2sVgHUOL?8Td4qzM15hl|| zKEYJ^4}#}MLg0T3t9~TZ|6h;EwD*BynW!qBh_OK0*Znud#SmW`C%Tsg9GR-pK7O~B+DZMz-56}}z06`2(f z1^Hf+YpnLYy79XAn#P-E1w>~Lm|ILH*Wke$M~05vI5Bi$R-pLo8FMuXh}}MS>)h?$ zTfMUaqO;GM3Q+c7|BZ`77jL{hga%N2w$89Fv~6_r?QOTV-QIg^??=>BjyeAyUnj-n diff --git a/mace-bench/3rdparty/mace/mace/modules/__pycache__/models.cpython-310.pyc b/mace-bench/3rdparty/mace/mace/modules/__pycache__/models.cpython-310.pyc deleted file mode 100644 index d411204f33eaa55f26939aa388421530caee96e3..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 15897 zcmdU0X_O<^S*XWUExqsEvuMVi@kW3{$cbZnk{Em9*u$ncQ8cRRZpo6m=T)`i z8LDt_JA?xa7!G?J(nAs_14&pDAZ$s1utNwO4ix*&;9x>Fw{;t5@&c`+c9*vsqKY=Yt=}n@=9FJvZ^SSVyPXatJc)6R@1t% zTCA(r^lrQs?;16uo2VsJmeD%NT1x(!H50$FPP&_^Wn`Y-$#!$KoD9c1`EH?BkYS@U z(k<4C2q&y$r_?Ri%H2w>(jBdhcE@UC-SOIZccM1YovclEr)pE(>DqL6rZywXr#iFU zx!PQJzBb=os4a9CYm42b+ERD9hC$mo%iL05typO*^Ppm7Y_ol!eQ+zbrPdA&^0HPA zc{%%#Og}tG&szng7l!Io2YDk_5qZU-yd#!&mr^ZV;`~(;LF~rf*&umSqtj`e?%08T zbF1kF`fV-84U(_ha9iuWMkmluo!zj5_*eE2Jy}%)ZT))4Omn@v(dyXs?z&|oo~*s< zwpTTrmVi zJ@*c=Zf!PQ*+!UKcx?}jG@2Yom@~~j(w-yu+}%-T{$N)m|ME1ZOdx) zR>PcPSZ7!1sgvuyXKku7$lO(@d}qUIH9C7*&)?N-bnLs>5O#f4G&a@(!)c1vhTG&b zAO0MRBF-^9B3r4cmQvGFYD!DRQaY|!O0|wyaXV%iwr(YGC2^(fxNS@*mKpt~k(R-g zwG-_m)+c4&PPI*}$(S`_BOj%UHp+zI6%Fx{owalLE!%mkVi&AY#AQjOjUimX@3>4u z2;m8YN9=-KWNpMJ5zkL3D_6Z;Sy8PiYx==#t%R}9;2Qc`N3B^|uiUQ4e&^8NJgx;? zi@26>E#o?X>maU4T!&xTtp<_XTEevs;gnDhy-2p6y}%H#+s^rn|nf5+qQrZg)CCzO}Q;>&=c6 zRQ4ob9SOXU1-3Jk8!?lClMx!O_XnnKeFtvzP=QesxyY7lc z&skX)-M|#`EvcjTARUEWxqqzTb!D@g_L+Kfy|>a@4KkfZx4YrgyPZ~Jqd7LTIF}I4 z;ELgL9@7-XRkoCE^@5@*p0cf3svB$TUKTQLY3H>IT9h7>>B!wSE?}1*%F$&`Vpol_ zrEDtVS}VS)`em=|DfNV#^m1-$OB1i~a;vJBC@X%>8%2)krnfU*(lS=GRrP}EkI8sy zJB!~OPlOBF1$B@+?&iHQZ``X`sq-<^F1RC}>SiZL3LEP__y_l6bucA!F%0iC1>2_v&%+pa~(%aLgKjlq&Q)Oj)9Cw_UP_`#7DE>5R zOnQ?RP7NzBM=<6U<=tLf#cAq>OTAfnqwp8J1;i`k z%01dcswiS4DfYOf`iniyOZkgl9QTsKjPPC-hwlZw&21m>W@M|2GB@WfSjjQ$CwQ*e z{)ka-3FYQlZdY%OXuMT#F&wLBV%9O+UjEAQM5`MV4X^s>Xs;xC3CX2+SymQ$xHA<*M znc@bD>mY*Ieb?|DS2s=tDV|{1*LvtANcHeuH#)6zc987{<*slscIlO@eM?_dQ^80N z2Qbbz7iU+r5rs)BcEjCCl4soRQZi{*sjSO`(~a5@N7&8P6xUPmc1{GzJ8h?BZ8kbp zL%f>Fw^AIZs6qtBreoJ}{;~8WY$HnbNDy~6F`+@Sju~pX^?HyO_9}KV;FZ&xD=Wx~ z3`NBkBUPlBWxO`rjN{Gq zO&rW|85Kol<>XA+eXJrznOKN+%t3OHz@x+I-W6%)kL=!o>nlBREz?WR#+nTrfa3*d zqr2W~HJw1m*n#7mR(GS*S~(jO@9GzNO}~uDv37CK@qlInt%ucfI}lmNmT)k$>*R?6 zvV>EeX9LNN5W6`+QiPz)35>?ZhTXG7i3MW7FgP0f>wcUHf>qac-@#-flGzYT8$$ej9M`_{LdxZM|2&>8?A%^wV~)xpw?| zVFBy8XJuL>u`shJ?AqJ=7g>mRQ<9ZD@w_5GG?fTYT|3C1#MQ< zqHppUE#41nYVxUaLRT{=lhKN5S=A9rqHa<(HGBrO`XS?qLRybyST4#nqB6#lg=9u! zY*f`Q>z4~g4q=Spvaa`Q5C0lZWYS^W9JG*0=1_yFtXmAjPbM!LD9Kinm-RnBmnBR) z)_>Jc1-jL68$l9)1uuVZHOLOYI$n7@$oHcWHV<2hno377L15&;20=s4dD&$QI@aX;;)+p3>| z)ch=D3`iy2*8MzW+%G^H{s?5kFG42$5@ZTEC*3yvidSi;{ZYt_KL$jV^;8Q;7pHvI zD--kj<8BW4MYUo;+d9tm_<0@ZMA^>wO9F`~+XZg~m}Cl}5ugUy1|;|aPP^g-9XO=q z0hItxJ%BRheap;v8Sxfx#!|dd3tZuC?0p8k&qTcgk;&d?d!}329`(l2TlXk5I@~;n`;v7 zT-1Y@Bbq#%E&IM8hBGoqLew216;r-<(`9^Yl>VoW^53L2mWy8t%(7W@dPtAOET_+mORB>y@aq&C)_ z7I_sY+^%gT5z+oYZ-N;Lk4^X#JM;1lj6T3-jUYX=xyS*FA4CQK8;G48mUuf$zYC(8 z5$|ALIz%1hfoHk*0j;!7Z-!$rK;W=Vi3=>4=x=~VwDI-%VL!dkDf^q4c&DruZGdt* zs>@=AiK&)@IV95yNvs@Oy&>U1NR_3c!jVmghfyWGkwm2VkarNRgOEk;I+O2t3vi+% zzMJ*Fi{dl|*`ugYyours6hsU`Vy)rud5bqP>Ge!9&)H%fV{F}H2PqrV2dxwMZtPqO zniQ5=rC6hAv3wk858A63)LTIes;aAs8DOlT9wMY-ej#Hrok+(7{wf1;Wz?LOQ^$c~kh2epVk?Eu z>p7w}hKYHSQkqX{shvKCL&oFhG7|Yjw_Qz~5c9a}lVXAL#Qf zv4DPheZ3=A7}qIC$0RZhUqa)QWLx5sESrsBG55iGeduAng_Zb3#T-Nxm%|q%sc5C) z``0gFCjU!|Y0qf7_B1f;Q-+~EX(Y7&)RNj4Fl$ffruH9tTKl}7(H_^b+P@n)?Q^-j z_HRZ(`>Zyi{i{~g9y3eYzZhlhGiF75G&idKvoV&tq};E~4)?^c{`t(F-Zb>44fUp> zHx0cNGR4Q`$eS4SuFqr?bG`|IzBblUp+nW|*orRooH0v(FjLc0@sxpU=xddspM_qQ z+E=OPOxh`EYE#hCrmQrs46ZD$94=^Qp_@%XH_I9(>kRaqMXLloXIknxc`8?e+>U~q z3u-_X<=YNw4ej^dG%CZJeB}KPA2%=JTW1-9+PpT~uADK1q`3hLRu)07+;Kn6B7eYKBRU`C%8|Q*Tt1{k zcd8fF;}Cn7$X^!SB>5 z`m@05bI*d+Uk}tc53Ii6FZxSBlVkdVm85pC~ab(RaUdNRW5 z2T=c$vc87jz4aGS`v6j*u?1Eq_71W7EJ91W@w!Cf=OxmfvrM3M%|94n?>TSQ(wP!U z@B?Lg9vq;2LGup*Q@7$UkkjR2Y=~fbn!MUVF__VY>5|05!|st{3&%X-_GdwP<$K3_fvcy#Sc(yLqup) zp6#_=SS~vB<@qiT^qoWg{j4pIc6r=OFhtN@Jt$G;O94y&Gh!4!&YlNoQTzm>KS}XZ z6gI_ffGB>NX&<5ZFvTB11o37EDpm0_Oc>xq@w13luX~XwQGAq*y#$mfevS?NJO!t6 z2O^4JK(v+#;a>#&@-h~JFd{(-`^3js@8cAop!j8q@1^)vh#EA;& zOT3#xKI)yP{M=AZ0)#hm=l)~V-vi^4S3)?Z?1FLe0gTg!VBA|x%>b->A@GfWjIeMV z5U#8ywee?xZ^;(}-$wVGL=lh^zlNKxX2q{_?tX*fHz|ILf=5~itt7}IwGqEZ@w*fT z#V08~MX`gA<-sV=#NS889h}+Q5CEj&2Oui<%8W6j&p+$peJ0 z5OxN?CH%@#2$vDg!aSWLZ~_oomQbr=jl!gyPmK+Ktz~JH&e(W#r>Hv&s3bd=jgSW3rHqXc~(d~a};N?R=~ zw6qIlQ4ZH`(6f6^Lp$0Fix^ z#6)}CpK>RFOH$DGC1H~-LU%hI)t!#&PKI%u{Yo9$)b`Z&v^#^-J#Cd>!zMnSh21j| z*+Fam478&Ot4uqnc+`dEbxXAp_vp^w-KtoXEfq5jD`;dHot2i+S?Kgi7(;j$vpaWz z-bR6Npi4zB6S%`!juSo0aZ-bd+-c_S^a-usVl>KmlqZ(qNHOcsP=0#9@t^thuw7Gp z3D(<=O;dasqu6*tYbLa3FK776A%N!>8D63wbqH7UD-b#*Udi~46gNXay?27)S5v%( z;F9_d*6XGi%FUI#X84B)wFwCSwhn@O*uX9%1jj49RvA$7^xNaF@l+-e8mgckx@=rF3>eui>wk7`@`|m0hi*Ox+kAkNvPJPA zL^B1^|4br~#B~!c=RQmld`VR7zXtbTPnY40498`-D8q&fkIS&+THEl6l<&MsrZ5zX zBL+Vl%NSSu6u1p|duTcGO*e*csQafEh-yKvKsE&?mxi56+19-@_GNgqY{%i5A?rc1 ztkO1MEy{rp=G_EhNnB{XZEnS2x2b{$7W@&|u?kTw6MIxHtV?ei^kBliHG-X}6qW!3 zi}@z(P4KMT#@>#8%U&KHJY~5@HQa1F=jG*|ulW_k^6>xQ3849-ZUMWbYSBL}k6p9? zmNrKJ9Sa9JWO#?jMP3nk|_%-+{-6VNTzg3>THA926WC@JeAP4 zU`TZZEe8+x)5PB(gk9+yneY(BcTjvM#d(VFq1Yi+Z)X%wIAmZ!BJ?zLU`P3TCdzhY znGYfgU5XL;Aof-H(I!?c203BF5fj#Kx(r(NAWj{4C)|9vp!Vp{_ai<*!a1{P+t8%9 zSmE^)lKgyxp$x@1M0N7{aGjm?&7p^LLEA>C^B&Afkce(Ue2keNr}zW~2?*Td$UUMG zL1mY*qn|APNtO&X&e8NFPdDJ307uinfUX(dmP1%sZa<$yoiz=w7k92qIs^D0>TnBc z4jWELLCj~&ctkH?K}S9|qk;8XHJ^fZkeh>QJd=yVLXOP~R6=f7VJYKj%?MjUsi%_$ zH!I`6bp71KmD@&N$s3EbyU_MVs@=-IMB%G2?)`MQq1x79Zj54uLZ`sp4(V!#SOw?e z*$Ks;Aom}+QGY7}AQTnev;OGw)y-^h1(_9m-LZ@jb*~h>)alWR&(&4jy99 zLO9Apgrjl=;UL%^-9tFWB;ly+A{+%tIJP3fQGf$zUb<#hKs_)Xv{IICjcKnSsmD0z z!YCX<$0PbM4(c%h>M;>fk4Hc~0Oi}0{v_Ny;W8?z$ApCRQ&HWis4l4p+LF{`21hO& zK1nF7!hm|r4^fZ1eUO16>H!sGL_PjyKs_c<8@!s19B|B>k@RB*^rMJzgm*Fn`Vrj; zdK`y~DV#b$Kc;Y(Geh))RA?HxQxb{{`iXn$TV>_Rq0_(-hp%UJ?tK z{a)9 zBS^ONa*i$cas@3Y+swr{;elo%g#UJgrNa};&$y;!46I+kXhwj$$s^MJT7J$Yqrk`^ z*;|O;i;{ih;SEP0w><7=1fiVd5$G){tF-fAA4M+i-~nMG_2HmEQ* zC-cy&%$t{a^HFR;=2fEDqRd;2VoMm&QaGY*94`E4Lp1BlIAm07RL)&E+ZZ+S#_&X7 zJ@XEL2`#&@pP>}<`GmLzBEK?g9JaW)f6zbVANH$0eB5BSn}wI%A=sVLsJ~=QJe=~6 zct?OrCQl9#DDl@^?^z0Ro=Y)^swTb^KVUZYy>KZB(wp??4w+Tx8griINrK=YE(y`+ zneZ+Onf1Ybz9e6PPXD5iVto^H{+8nJC?xm#eugBu`Uixn$w);+I$#y-4e2TB9 zxF4dL3QNA53GbnJFU2Di@1u~l|B)dQlIoPW#PB~;kV(OtI&_}?Sw`a0T}j@#c$`J1 zDdYhlc~*_l=Skbdqbw{b+kmYFsm)%W%#pK6b5-@|iy&?vU~PGCdkjtckU8;-Ec{Cp zpP~2|#coPhOG-)yRu&nSq`45DN#X4N3v0;P{S}026S8n*C*-e0CrP1n@enKj#mTp3Bg2htAMR`8A%&#I<-x-ALu&K=h>cG~=Lw zX}S+Z?nI!4VW}OeXrLZ!P|raZ3(!E(A<2xDSu&P`>rf8$b|_>9ZICwZp^&H-j)=MZ z7QO(A6@~l1#uIw_+{BfKQ=epNCr^T3hR#FPv6sXzf1A_#I}|#_3I$HA7sE6^hiV@| zUFWR`gx)~wCTMUnE%ukHJf`bFAHR1yA={iGK|^7P@LV$HKpjg#QsK z${ZKIR?aav&+vb^QvHJeDDpqIj{MJukZR!qyC0z-DU*NIkjh~BAB|A!Jcf)Q&b}R~ zzmeXJ;pQ^MI&$loVW`QHM);Qz>Ea}QpUHnh!TA_^iRA0VHz`QIJiLWha7%ZNKr9SYre#|i1M5D7_?$ss);6U#|j zcl%fGqt(m}-bj*{pZ#o<06!&!JXYJG*Ch zyT9+w3;;sV%T}|Sb2<{f`R<)>?!9v#Gxz&_-@Ux$a9BAUSGF&nJ@Bg>_xBXgj(U~3 z^{$TNzRIP!wC)Hus1tRAbz9$7#VC78waCebTB5yn7rjke6UGu8cc|Z!Dg{}FexSnTf~;ZRHLgLq{t^3cp3Kp4CfYM~yTcNYgpeaffVY4Otgt-5(?WRE;z} zNYnc<((F1@m-4+&Pna^v`aQyV*}N|^Ihi>%nUxLuC&p%F!_kS{tZY6pGdnRYWF}?9 z$@4Q=*~kkB9#84yy6JvU``Gl<%*14NbZUA$i?I1qdynqT}*yX&CN}{ zPYsf?%Dz3b(^C^;2ZXG6W+I#0J2^e}f*d{Y;@re%GU7x=n0;ED9-kYVWoal^-G>Dv zk{P3NC^6xq6GAp4?w^>Mp3DwqGvm{9v+J%LM_~?3oyv}nPY7p}D_-TEb<&N;r-i#F z6)0U?caMK4Gd_`-tV-RL9?MK-(`P46D+Ni_%Z*esbFL-I5T@TC)?;M z9vEro^>ME9U%D+_H)2TZM~rC$en$N2N9spRMkJ9ol{hov zE%;eS%tPMFtQ0kAJ>Ru#HfV^>R6n6#5v+(oM*(1G2u!3 z5To9~;{6DF7#Cny7|+5%gmK4+Pf3f#g%IZ^ZeRn?4d~M02zSnW&TvjA`YY>0;nVHu z29`o#Xv3YQS4$W{xf=0{;upg&j$adg3H+Mz3*nb!sW?Y!ro&3Zo^~LZZb5mdG@*2> z80NWMINs$*aFWIl6dlq$sg#`OFua`;JZLix>r^MYg7kxz;jl;Jc;|5Gg`(9*r+P7iI_8~yI^NWS zHV~YA5NVWh9WK3ChQeBE(gnI1|f9cJ1j$u@pZhO_wk;zb-NKVrP`?|NA2h3hFY~$ z%IBf>GOuw?OrHAdw6Y2fyt`E60gfA{mkm33ecHZVhcX~mInoYA@0fw(`Y;`QMkjcOAJgP{oA$J*wNl^V>yZxif(YNhM|lfx=e_6=p6wPPRx!GE zXsMLXL%Yhn#__Ro4@WMgewOwM@w9JvPf4y`E6d}E%HYkDb_j7kHndlhD4&I<;R9NT zZ$b&3O5O2hK8}+5x7Q1ak5JOYx{{W9{KKcDwHKbZ)Z_7%Sd<~45kn8xvz ziuSHKriM?HWc^zGtsgynprkCQ%1MG7ex`I6-VHr?V`=yLxTkc7&;x~LzL__r19WZ2 zb7&72#$CCsEJ}TAcL~W8b-YDL@gD%E-^7PU^wF-v;dteTG4P1%r|B+_f1J;}Edj>OCsK=+~PM^jlBUHUH=xMWA zAvcX6X4zW@WDgZ)T&es~rB%>i94a|CAt)1HSr1Wh8)8M8^T_pwmC3hpYh>{&nK9z89Hnm|>+6i?6%I(y& zFfo>s4JbQcgWSZ_%;d!B^RhRs2_4d8M7+Z|*;}O!AX&LiK<_yQ0AMo9&LWaa&moUW+p3)i+&Q&LuH08lo8HkMVgAFT2Z>}#)8Jp%p2uJkTdAq%k7E4wg(1G0fCN7hZqo^=WtyJi`M zhL_3ZMn`irBop^PUXhRXot>V__DyA8KAV|7H<1B1mhC|%=KxVCfm2yw>}+4ZI1bQZ z_B@MI32V??A3;Qux!#%cB7OE`2Pv$*+IXOWmP=uF^y62!+nYIy^MXk-hi{l8MYDay z9F@$`qCNWc?3KyaCl`1BdEonv-)ofO$8PKEEY`Jp&f?Hyd|%;T8~M`6Lfg0IR{9RU z(|2&W?~oKYe5YXi8^*WnZ`unzyZ)U2{_yvPrNlFuklk}(S1}O3cqD(I7>MKzziM5z z1unkujs6u|hh*z0Hnz|A6uNduksbLX#b{^VaKqg2USsUa_Sd&Be17TKLi^4_+pgut z-34=aH5$Ki`1Qk!j+M^GrOwB{yZ28H{o$dvpSiC4qvOk+N3K6x=r~zu=a-|;7R(K+ z@vX&Rdok2m47c1g8Jet=aN}(&*Sw|Z4HZMp#pGkSRBYOG(^lWiuEuUVxmXj+!M~E) zEv0s^qz+1{gN4qg-p&=;k1R)zBHY_?x?zLp$>JV!UHJU9f59AEb#)ix{giUJ zWns@{XR)EJ=xd@3940r4c-zf2#IKkyn^(f!Qn-8Z>oE2JJ@ZrXY6gvBxBPY{lLx*;^O*8}{y19|Q+T zJle6STikQiC`EgVu~aeCNVhsnerJCFnj5#idhEin6?;OmCl>rS?Ct9dS9JSdb6j$) zxLYN6>#9Fh45wBbqVxRMA`9oFhK(!X9a4D5a(HJk99<1Zu2^2Td?URW`1MaM99_~E zk^{@(Evw9!ZcZQl^zI{)VR-`#i3_&bA3JFn#m zoAxbt?0=hTG7$VwZ}52k+cP@O+j9$V3{EzTXJ#|98Q2fLtHK%CL9+-3yW(q_m0em; zMOkI17NNkcvR@0Y!%j=LodtXbSX%aJV#9bsQ2mwA&3;7j- zh_9h)2_dKkHNY2Mmu3j5Hr3`GykEgzh464p%M{AAF0Ch}C2fGRYTRxBsKE`DlwnPd zT4jBDLIiTnB^cyN@ zxKuZ>`^q+yWaEH4;(c{O6Iv^VbD|Gw@Gwoh*fCs!ZcyJe5v@89TLY3oEX~WRj%tifl^L&*S@aq+W8hde>=fVfNr#Q zC;3B=4WX7T+s!PWh|p0ZRZ{h9rP_*ABYX#Pvu@X;Pc{#2uX-c|`G{7R365`Lclfnl zR530(E84fS)V?$>Y#0|QzNu!~sfx4{H6=RB634D8%4J3C*MO@~!ltwt5MrIsRUx_a z&L@LpmnzwPKgo*gwew6UDXHmI_U0M@333MH!OGafhlffKOu39Duqk%0tE5Q?8-_84 zm4ol*-MowU0LUdcm%aoh0#vr!<}eI}UZGFeC~Oior#Yb?zX9ITsot5vTY}$w3RCzB zaAtF+F^LzAZA`=$M~x@xtD&YvonYTuH!(1&CRy1q2D6#? zvy{Zo5qX6OtCwFO?iWF>*2}hv=>jXT_y#Tj-GJ4L6-)dgNxun_vWve)*KCT4PDZlx zG@F5HKXGbKDN8-9TbOu=-ylJgHl?Xor!t!HOpDhHZrZd&Twu4UQyNwssa7#gk(P-Z z>X6J}iiMArZ8XDlXT^@ljL12bgm?wFC=X&%t^^Ope-H9HpoKdxj`g!U5qFqXEQU|)D}^os4ozy)ts!Gm)cbP zC6Z&kUOzVblK5K`{5Fv#A`3)H)gyj~g4DR;Zxi_)BBe@@O-f5|lTq?VQjm4<_(_Av5ddo8Kr$7b4NLvYhp#$OZ#m`crwhB-|HmUH5Y*Nva3ZMec z2$K_IS@CtcOD$%Obx#qWM^3iSPR|1MQ3@k|lcE%4Nx>lq_*+Dzh_^o{{sxhMLxj4m z_?tw&1(Kr|=KZmU8WQ5mxKlP^FrZ(J%XS7NsG+~0lt09$8!$qYa{)DQ7o*+9##9kn zbVD07XO}gvUvqMS@N35|9fQ8x7+-1ZlN$S=trrpb=}Vu^TZ^H{N~l8$bwIK7hg3mt z;I-~c-FZ{7#JIe#zI5TGyrFzwEU^;nmts&kBgJsnN_d+T-j;V1L#dU}Rw=YKZ-=_- z3$1ulk~g*3y7+miYv;0e7od&#a|L@R^kttlZ@OL2IU9=Z(ERMeQ09(D7Iy zwvDQB<7Pk@?`?j3)$Y1zSaCN)FJ2gUYu6jQmRf(m=Q};i_9s8k*Y#T^dvAV^ zoBzbZ8KjnIT{tQ=Z9-N*a)sxgS@3*~uTkJ1#yX3Q=pSv`|E2}I)Ywyu#TS|{KTC>o zy!)od5_euWo*$gA`@oOqMq%r^R!qy=vn9e`78b~aJ&>uc~wddM1*CSHkWd2~$?tk^fg%k4!Z`fOk zb|lg!+1nP*{J`FOvmeiUdn+Lpw%2V7hnA93cuU@~8j8Ml{!8ZP_FgIR)OBR|bl$d#dmCPFD6|Y*OG*v9^3Gx~_L^`>SPr%qQrlNj zPfDpLms5MC;NHCLy>0slD+x9cXyR`ydIH5jjGz={T!AL~Z-(sKt&(|5zW$U|${f-S`>0(R5HY!0{`4T&`C~MF_zc zGE-R~_|{SKdIMGs-ab0|A`AlR71yYO&Q8KlO_;`IBS$hYI`B#AS38bDr}M&gnp&WG_^)nbIVZq6C2Y27z@U^VB>uMf?*;3hGl*?UX5!_EIz@xLKk0> zoQbt&7Ea=}wHC&^H$m!4lsv0xYH4$m7X zN6D33>t=BqNMwAWe^I}%?XrOP^I8v!@8v>G^Rt(P4_PIyaV+v>0~cv3#*;-v{jwdeNy*u^W~69TzAv<9Vyu}9N7pQjwW7w? zY>c(jQ#lyxF`Wh0OEpE2~G(m``AI&S?bNVSil>QiVKFfd;`gW_u^8f$pk zdEs*if5VZQeub*1UHym^zKGVe89xhtR{U)E!MiZ+z|V=Fi=}cGKSUZ^N?P!=bE}?q z_OwTYI}1GRoX~=Nva{^$Ni&a$r!j^~<}5bFg7FXTpw)5Y9ZG&qrGs)Ki15~4kShN& zRLJuf`S2R5+sN_y{kp9>SmEgoc0A_=SXc+YhEjU4-|CO2>M-+Vc^h+--Bm7S5TC{! z%DW;Aj=AOHoyD%hG|Ohc|B*!ArK|PoT%ElfC_9>ey`sDzo*{wI`ll}$aE4F|xi(1r z7&9+BGvK^q4ol|b{ML8O$<@ZLqPGKwI25|2KtDXo}KvQI1t~vlURRE{tw*}m@F{} zv^v!!QjPF!d@K1z0s^fKQ&vKt@tQCt0kjdm1~6IyU;fJnhA;7<0esnB(XuOdw6f9+d zE)b%bYG)AM$3dd*4abSUOIekmIPu?7@b`%PJ`tu~O%TVls27O)Ln41bN_TM7~etI+53j{Cg1cT|)}WoDc?Mlgf#IMH1g6!kW6=%C&GZ(pYI0 z_Hc(XWn&6quxiStLS@g>jO}}vK;9cFbN_TgW!6Uwm6;zkRQBWdnp)RGVo>B7QiV_o zpkwdGyz8eK^NOz6laf8Tu;Yfkw`dQq*qbDK(+`0vS{DYF-F=0PPrlt%*mzWO9|L&j zaNg`8L~8&VFriWlPh6{$8XnI(SK)g@t|V=5b-&rYIQh0tO76X`lfr|5RKwAgaHkaR zT#Q|_E{FF4{_@7=rvO)WTz&rSy6c~kywBxL?}i%YpStvUR4CxBsY_E|o?Z!Vkb)bQ zgS|_!l}-DlP5YNO9bDOTOxkp8dD96gctVA{BKL;7Ruabxj~)M~2fdQ18lV>tFW_Ex zf?hiIq7@5xU8G#1L~oaOh%6JiLF8|UkQS$aNN0&l6Zs*LA`uf2dTWR>k-sIf3Q`Fa z{fHuokmj2s5dw+~79bSBUyL}TC&!W>nuv?u6V`aDHmpI>Dy+d+1=e7!0&6fYy6dP!|U_W4WZE zV!5Po-9kxi@uad`(s+-hku@O?02;J3(pp*?X~ZI9+lbwWf~7qpHf6q!1(CK92aEH8 zQ>9thkFXO9DqRZTL2(1E42B7mmOQ%G12O4jnwCsDMm)7v5UC(cXvCKeUr}#)3jG zyhEcOp@h+OB_y?r9S@(9)=ucV=JK*J z_+ag(2x>^!Dd^20wi2N?fbtc765L6V>?7hnA_qa>BzB1S!$by&93gU)2#wJ4daZ-v zF@&!gYpv1}kJBaWQlf0hF*+?wo)=Fb?tgANIi0}X8ROUmC81o*C9viuk(o_QW-~d^ zU7Hf?mer)1D~v2dRD?8XA_0(8LIGQNdbVa+l_!54{$L~=xCiOdoC43W z^?uYvMf0EdM)s}AHz$|2{JHV_j_*07_Tx9qCo0w@-D4?{yGe4j7tHMiS38_-uo&o( ztPv{o?}sT(TOd?lwLs{m#i*?gvOmD;p#MjUfNs`vp7y(~{DH>}TlV9Bgv)-EB|mTg zW9xmKA9=mcjo6l@slqe76nplD`4jJXg4&v%AKba7=RJRD-m&a&U-T^dyVmq|o~Ly0 z`VQVQ5>xbdz4wup{Gf8!dLRG&*83>SeDK6wfe!TCRh2H*5B) zvo&p1u zQs!!{@hO=|oiH=HhFw|$NVA4Mf>`>lR2_)*S4APHt*^kk3Vex~Ingb6uy)9enY!9W zycM$}r&dDR%!K))Td|VJ11T$XDpeAOF}+~ry%ej}<4{TWD(9ZIYge(IjBBW&d`076 zK2v)J=0Eu8(*pS6(#*J)wuL-hOs}XD0u_=qbTCK;RLS6Qb1BClp2Epm)rC2z8>YEX zj z-zSbuawdp72O^ttFV1DN@CAQ?qSz#hS(yHSf(|0o*r~{a7?jFyAu7637jF>v9U`S# zkxl9&RGWHUhnQ+2QYGZTI(L0_v$ve6M@?xy(`xmYPi3$a9qxg0J-nr9FE{Facc1lC zSaTZgobRXxowI1CoD1X?4l66&u?6EhuC{`?t>9`ay5b9M@3=Y&=8l4^qv(n)1m4jm zCPi1v!rVKq4F&Utf@{OOW_x}_GDlX-Dai~UrUs7N}*KI7n(O;+KqM2QTUC9 z^QJY$U~?BY=6g6X<3?W~Z(8ly2w!hy{p;)n+tdr^;8gy|d}yHq zvyOPGAh*?qAd$MmK%Fl8n;S~TbIP;#vS>i z@U?#R_=V&1JAVK_?q=k9yPa!pC)ZnCnfKx=+HA!_WXu8;BS)cS+qF|t!yZ^ULdl{# zUi1bpe)dvN(d{V)nlOb6Tl04KQTh9d{_uRmr4ywJROLd=5b|5|W_*>ytqW(RP%rZ? zhqr8e!%d^!Yt7fgU)&wcJJ>fFdsVQ(G(x?bjVLzMXjrk4m(q*>0sQgf@PSpGT|;*_ zB0NX8nTdFaSc!a+NIenyvN6BayLzlvSRDSEgk>UsOQhP^aFZg55X_tVGY~R1baCyy z_^I+NN@U%=#0S{9;6hlnpzmEm|7UCm|7UCm|7UCm|7U? zCbv-!W4+`y%JNwA-^EunK!Ue8qRK3hT6GM+(m`4%9k*uC%(aXv>85|;T__FYC#srM zoY=?GRWhlNam7(p6pbkC)Q%b1fiL@7QMK5D&Of6H2yTX*>uGj<8rKwbC4?-H*!Avb^N0iwQ)rCwg zZf0r;SC}6*Xwv2L5UBuF2OnN%$OvnfVQOiBsij{F*EF>>!26NjJynMgsn~@PJ54fE zO9KPy8$Ut`8`qVvWVCzulr&HQbs5u&3>ABa#+T2v_Ij)|jilXGrj|Y0y>~ISM6qE( zwW(!a>7M(WTFUJ}o5K4Ir9BOW7&EoRn5o5!x>ou`tlTH?gd4S-%I8{7SoV;H;{KQ> ziyjuMFtyNz9NV;O<#WyS8nxaFQ_K42$Jls?nm?fB!^dE1p}N1*)ba>g@nKCZ;w_X5 zQ~Va3Dr_uIB3Nx+5pN?>v7Ctioy0#N0+56ge+-gpxwoxD{2wG)RtP>t!H2hVxahWV zB3UA*i7@@;;Vm9bkKE!>n#6p3V@DGn?CxqmU0G-AxMPA=vUOY|<^MD$Q9p(|XXj7c z)`1niKLP6oRB+9-Vg3KtwDF+(oBb46Ii8hbpSWRu?on7dk_(#({N3JM|EMS`qp-MJaNY+j>ph_kyBx|RcR3%#~>YkP+vXNA4 zOaF2Urt8&l)>@$`>e4~=9iWXPmJ%dG7?D_F$hR@8TkEiQ2TXdb7{bU2}&7>Ij zeuX0c50QT)a+%22i7>OC4#}j<%8ozGbhtp5-X!v~M1GFQD?~6Hxs*kb{3b=bMP!l4 zRU*GZWPwN>6^%?`sYcO2ypagm@UU&OvP<)KDPUwnQ{k$e$50 zQJMaNxW6RweIl=e$Y!=$jJ9u9){U{%fY{YZsa-SO#%lMkNVq{+oW=lK`cpRLZNij= zWfl^r>K3g;Y(&;Ki>w;`5)ABHhJkVj+g@g~_?|eE-Ajf zVBd~uX5}7FY=@`hwU-gic6h42Z0$h|s24-`Hlf~gCn($+2wfatTQALjlD1v)`hT)E z)o|1N%gdpzmC$A>w0UXsGB(RP^lr5E&1V*M3nwoh&mY4cN!TA58(E#X8oRdldIU8V zbmkAjcI60^H%e-3d2`Rgi*JYv2QEKRWt8nBqwK(acTH+dy>;ZxBZZB7uh&VfM+)X< zbswcdV*53bZHTm*XnSb8uDaz{a@qSaO}y`Q_OsoXRC7AEWby;5KQwmn~ zHgz{9&%z#z!OCrvls%NNiY%BUYq8JDti|vdC|Qe5${Y6#y_`mf)G`^@F>A51IZ~?O z<6Db=KpCwNq0h2nJf;s}d8eZXw;Q|Yg-3)eSh?Q@qwHqX-A6{_wwE-cF)^yqn6Zk{ zn6Zk{n6Z^cV-`_qG-eTs(U`G{(U`GiqcIDXjm9jf7>yZQHX5^_Vl-xKh0&OW$!NTg ze5>n?E~%kst(nCo*O`f1SX}EJW@7fxYaF}!vc7sDn6ZkPn6ZkP zn6ZkPn6ZkPn6ZkPn6ZkPn6Y;?6VugO98twgOpTc8#vjHwKZ6~Y&sj1*BWw>wf4VZq{+5MEMm5C|f_Yr~DCj3bDgyi46YMW-h-9*m zjg)teIVve@;MrUzN|J1*LoQ_u^pxwBKZU{eRATEbnI8n})Z%+2^?4#MgUIHcO6hlt zov1DR=W>4!vZmMRbRRl7-G1E)w@>2s{XOUXTh9J>obT_r_J80y|BmbW&|uTmeaMwa z*hE2)n_;fbnLoR%3*D|8Gw7`OEw667u)lZLX)5xp*mjjXWW{uS3{7kmN`R?{Q)=16^6_VZP;6lStF>xgBQ@P7e< CuGFLe diff --git a/mace-bench/3rdparty/mace/mace/modules/__pycache__/radial.cpython-310.pyc b/mace-bench/3rdparty/mace/mace/modules/__pycache__/radial.cpython-310.pyc deleted file mode 100644 index efa36c8671fee63eb7ed4b48dad94a62dd663259..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 10527 zcmbVSYm6LMR<3tfKW99ijK>eWmB;Q*d+nKVoQK!3Gj@y}vYVZF;gg{;Upfu zbE><#y2rCws8-#oI``JCd(S=ReCM3oezBO7aQx1@ckAv6N&07cNFO>L&f^RuS(04I zm77u}zRDGupOuQrSFNJqsxWXd&b){6@L{c~5PNHryRX3T` z?Q^G4H5^8S0oo8__`6{>!#(`b5j%Xd9I=Yy&fhMus4jaFm?Rnh2#Rcf3* zh?#T*mr-~UF|eYzG8kBNHBVbpzAmlF%%oyPcQc-`4yI7Krn$rdtzx=7-AeZBk|tGh z7*Z+9zu*Od*L=YV8o~XUoQ31k^$jO%_-$+M*)kLEnfT)6NEdd?xmDF@5dDM3M3pMe z;2S7H>5jA|Z)3efWnJybhow+k*Il`%^psHBGPbeEZ)?-irq)x%r%;#KGJD##+>`8V z+vrJKIakHAB1?R)p+&wY_utgtR+1WnYC>hJ;AZG6+hc8939~boKfsW1{QJ^fEz*O~ z5n-fyZI|x7P7vvCc)R08rB&vEebZ~y>R}*gMt5aV!X?Ufgtz8YLtpg!lFpZlcz>hg ziO96=M!ON(cBHgBkCTLud8Z_F>60r}DWyl&X;xkl5?DK?!g%VE?k&91LO=is( zHyz=g#6v*eos>&*cYk`s{T5Z&v*e(A5*Z#PF$ANM*5z-*-qBcbRYDbauwuC9yLBCR z+B+tMRA0WInPbg5cfWAlO4jyV7f;FZp1UuX%Z7Lyy_3|nAv?Ami-X$q;;BYO;a%To z2rsZ}SSm;`PkL;Ucw2ZKVcSomuYksVR54^@N7)%)o48q75g9bPd^ax^=<7L~tQO<} zQeUub*otO6nkZx2u3rUNUu*hK$U+jYhOxrJEIx-0#3|~f*ca2P*KA@y!!bTaO=vX) z*%n&7fH20CG-X+`kIXC~7 zlroORsC2RJt=_KIy&G&Fg@G?HH<Pxc?V2YJ zf)Zsn+QIb=&+B@^q82o&fdbHdOl3YNqR$<^fXZMB1r!=;JWI;FqdhQ;oZLMy%85jC zS*65h@zUOkJCE9)ycMeh=DNvhu>~&lHF5>)8vP-))9u$?d{T&0-(Ut|-1pyyPx-TDxk$RjSNuPLm;@Le( zk7FY|1=_1P*&H2FprpH3qOYfNESXP@_A=C`=TJcmOzTsAQWr$8;v6{=eR>|`5oaU+ z<2oDcWMnb>LR5Uw*$9G$(`Gg5KE__ffMzk2S!)~ZDtR42>~g$`?_Q48OiPW7#zmYb zbql_RVyI^1aKPe1gy0rvv7^Cm#u3!$>qbv=RabkP)r?&Y)74{F!)3v9GqciR_>lWi zk~_jmm|d>nBYX(5Lo2YV1P{?Z3`shiEM^As~A|T3di2*Hi$ooZBgAhak-ZE;HvpNo{G+ zWb8?<1Ns)3qjhOk+LS|ii_sQWt;b4B2pp^GOK6Ev^F(6w^=@^<7eI~Fj(_v6GDP2T z`6dm#hLhzx2b;mN`giJC$(;b=WJcU^ z6Ye;f7z^2$%XawhKJ zX5ACs61R%v4(sSW(#c(TNTS4wdLB1JMIGWHR8<9Al*F%BT$OMV9zT$gRhT_~Rz(=m zSI|!khNK+jOiH~HFcd$5sz`0LLsql_2Cz2q`z!-25-bQLHMeoHStJS6fCJ`E_u!DD zF^~UG%0udndBzG9ar0{Rh?#aqv?d6?X_=sO&>u;GSEnYi^}5s zWwaKAQz=R2v6G0ui8D}8K-*_`jd_4xv1V00hf#i-pYvCGFN8NWyaq5_!SRp&bAy5{j%ZO#eub4H+@#f?Y57qtHpUg5<% zDnC@<5@^3deU~0;`x1*6P1v^6Zu=oe?}z|XW|1T%Dj)(58&!mzb>9`lC3eQeo8{N&b0nNDux5zGg$(w!=9m0QP19!6#)-4Cbfmb=Ot^?VNtz~SPXpt_D zSV0XGQnftI|F2LCtss<5y&;JboE6nowq?$Ez(on+s??xDzb=RRRt7e3nDx+HgKA+N z^>sKeiPhBH4ADcR#gG#jeM=E5s4Hw0sRe~PZVvCkBH5)BJ)fdakhO*F(%b_2!aJ~8 zaXv9g=f)U9D^c}tIEYPuog&x0(0!Dx%Z`OZu-Xh}~`mF`x zA=vXB)CW(Y$jK(@6eUY^(yjvCJ(x9^hNKF-x(E1WW^q@5e>bJg$lZMdaUVseDAz`2 z(MD2H1X0d&YaS9+u6HXcA@9*@RBfagRw4g!0>U;+VRnngR01u(t!_@Z#a8v zi8uhqY+{*%bVf6z;#9ZMy0?R)_UP}5`Y-Jr&%qS%2e0_W$&~+Bg5(u)p>F#XPBFW1YU!;QNm27lu`&m5ci%2e_uPD<`8W755sz;%I+g($ZoTsTF&JiW`yaRLnkVJBSAAjY{T*O#&IGS;Y!^t>|B|A;%B% zZvF_>0eSak#)Mx++F`2D6G}BQq$@k7KBLUwX$Am4!vH^ROv^|;$}=)D7n6yG416I7 zb#l->8R?5SgBcW9__5sGI~tZ&CA6>hFdUuu#O9T~d@!fnyJlg#B34l#gIvKGcld{mDlKwWfs~ zITFOKgFHZN3U{(fR?TWz-JPaw7(!_ZY1RzJZrWp@|v@a6gGv?=)#VyQeh{J-WADfq(xQo$OY!LAV3OKI& z>3*7u(LgN*R?1-rn~E_=<$n2FWUL06k%SO`_fdNBA;|hKpvPK}IfTuJLovFJAnB}p zMAkm;&W;J*D42>-8@L-Y3-}PZ`D{eKft&v?W+EX)8yYbtsXfLdAWwpl$RA@H0cfcA z7%?gE!sbHqj9MhIVr)GiZJFD&3udHHNw^rDSfhe;8|h@WGV4TO>|p*g*BL+(PDgnd z{04T85GUE?!kz#jTGUmW3ZDv!9>jGNkt*!PN-hcD79VEqKSslVg1J#^KMT)D!A4S{ zZ{=ah4CQ~i2gX-zF9V;f&2e-+YT#Yv+wi%G>wJd*m9q4p>AX2(klA<6-B;(ITI*`Z zt?TEg?DTO!PvO9*q>3l;rSTb3E_Vii6A<7Wqg?Jo&YHUg{OFu4A6uM%`q)>A5~80ur`f4H$GEap zF8|!|IPj}Cs`Y;7&D(MJ4bD%;0n8id5%Pu{+S#+zJF;hQ-0D-Ff8G*{eeeMLa7>=Z z_BnrNBs?T=NI+sVoBbm|LJ>z{qK&N#gw@>2Zo^fkO(TAT(`O;Vz;&4MG)%bU#pvC? zOc-+k@yL)(#45~Uj(6_LXRMd~o7g^j3|q;q&SMCFkxq*zL=EZS$qd_H4Lz>|@D84{ z8f(_F-}a6p5aeC&fO+%G!kkz<0p^c(gJp~NudTUN-*4i}uh)^Dz_-EBmUGRsuo)e1 zl%tDtm%^_;fd3SmfDfcGEDnDQ89jt&s7DCa z55@>iG92PCtRTe6_GnK=cdldni7;Hm|4}wV5ApgUZO1^Z;9A9@B3vo*#aON8bIBXK z#b+3CT66p_i1~Qe7i+qVU=-bjqOppv1lAmQexmM&7XESIy9l=_In)1KKQbMYlVJ$E ztMf|;CNM}uw^TNed7uufe!CreiNR1>Sa3%I;9b34;A+~2r}#KE;aBifOL{-c{P9cwRk(gCcJN6t)L7L6%BG(s8NX3_1Dmnk;f^d|HoN^|H zqjTaH(0F%BnT?j#2))*xE`~>=rJ(Z&w*N*)IQZ)a@(7z~a2bx=0)1vI@;ps6cGlTD z{5(CrNyU#5e)iDr2#sJ7H@vTAC^Ms;NA5%A@F=Apj!p`CyF4X+mOe%{Intqs_`lI3 zMBue8dZ*g$`HisAoL_H*g60`b(chZpL*Pd+-~T6_AZX>>Tx5KS|E|LS8{y4v-sR(5 zHHW|K0P&YoYfnxn{|zNF&&Gduc^;FJf;kkXVJfj46cj_9k&R8Vw%Ji0%;x z)!HU^D~aS}19H}pa9lxbT(RVG1$k?CEoZk1Z)&SdDvzu7I2qjzsZv!Q4u2$9y7C&= z-ny0heW#~;dPXf{*4&!|I(_=|>F=CAkMH|_-*?WexLi&O!p*(Y{k`=R^-p-C25Tm< zGHRlzHz+^lr;k&Da-{{D+?xb5xmpAZTusNBV^+b+VHxvr+cCRf=csemD#~x!O)J#O zDY#?SQ-13?OLvi4RrcCll1;0r-DV}VJCh+47|^rX$r{jf08QsNpy@K8SpYQI zZ$Q&+K+^>@-QR$w$AG2>XnMZ^O|JpXLZDgn4QLidXkYPtoCRGzT4uwEfx&2u4-O=H zdAKhS!=Yr77lPu@fNT}RLUb@0&P*9{v1D#mh9U7HOzCqJPGiC3r(qhKU4v4aV9W@Q3SQr3O}G zJN*WA7IR5B%><}#GZdw+4~{ykz1gjMV+;)S>1#OsCUX0Rw*4Y1x&h9 zY;Fm>tJ3Q9oAz4jp$&0Mz+96l7}}Ooi+lC^-T3u+A)xV!02RZpCgEgJfa)$)@3L3F z8OlSCVBn}V%2;$?QKppD$5qfHx)NXgoy=a-L+i8Zs2SF$8=rE0OO!P9y-gV-m22wg zJwodul+WJvnL=}^W5yygVlpHoWiub|#WX)C%1m$a${;Vh&yy(`yv#=<{Yg>488Jf( zI6SgrP~iJQ;bcO{(4wC9x#0bwL0*vU!C*8VO$LLqDLyEh2BR_q1;hQaJr<3NgP|}l z7e@GGuop-S#o$V+GsWYwgu|Co=*%2Qa2%g9FBRi5pC>jr)kMRl_M=%(Ug+wtJ6mjZv7_Q1E8v7FieB(f9 zq(79n5Dh^p%-0VJi3@xRB7s@u@E8CDRtN)~|K2_GZ>|i|ZWya=mFPSOFx-Uv7w)9@r_IGM-)ugMMrK;xH z7gJTwj6U;s?B+$*n`SwQ<>t3E&9;Ba9$d7uX?vw)uUue@zaN-5_xic1ZSM`G8#_N~ z>`XNtkxGw#ZZvk2?qUD+%~<75&Ds6_b8LKR#DA3nkf!wVsp znj*Fl=7jJB=G*!b!sU?AD;MiWGn3f_CB!<2U@S{2CXT*-bh0rkJi@*=mRwYWo#i$7u_$Z3OPLHP? zvD2l{%{48-5^P^G!;B`Q<)g{0@tN*pH+Igr1gueZ22rzsr4Q|5 zgFz6`G3EPYPcYaQ6~rX4I6j_$$F^XwHxY&e(-%vGl7vAhbJBX6FfpMKC<=R#Y$V!2-gdigrtWi0KCoM`5JcqD3BfSX(t^4u|ErZG z_o06N{FQJ&e@W4k_CpH=R0qiH3gd!OgQgnILr$J*1XDj(HG-u8d?P>j>9d-mWVDz- zB!e>KHvvYXbR`WHC!9PO{N}ypZkLkQ$5}9xRU~@2n6{jIyfJM#g9vB3WIBEySBp@L z5pq+XLqNSknP;0ZMHrwJBx*du3zd)~JBH%o#UY*_;f1YO0^16llJG=KE4gHr;h3YC zR3;jcZtHTpUlVW#p{RHUqFe#Dkd-^0j5wJroIE;B-+1mvrRk<4pEMmwH64@K<7qY^ zv4I6|;fmSha(~V+AdFaQ{np=eFO;tTUS?7v%NA-O$`LR)Y=n>g1tfrv(%pJ43wam_ z$HXqm3u;%#*Wjo|v~(PnR=KI4)+p(9V}prcH1s`1>f69!O zl4~!Iy`1*!l03U+52QSM=k0q}i#=RVWVrev@%Y3ZtY2-V0^v+RA-Bs4ooeY4o`gJM z7eq2$2NND=IAQ2y1t@*o3VAvKq%h{yQ(h52D0xIA`qkp+qvXf0{QQ;q&NK68UYKWl zmQc@9#EVI(6z2$tny26*5N>uD3F92jaQJ8O2}bxdkHb5d4ani)vP7aXGqX?S@Q_qF zJR~1D&B);)QaC&k48xxmG!8F7mLUvwJboA~I0*nQfFW|qu4h9-q8JUu71rJZ?HF0_ zR5LlE*mHeD@h~D{(W!8Jwb}R`$ScX*DuV9Iwu%N!VWZv%I;{mu4; zqSDFTKPye|eg2cZ&!_fwFPlx3&Ut(J7gj_{)xLL*+&Yr3Zjq{67QEc71yIzsJ)j3| zb0cQjvSg=Rg(_;=HNX4tr)=j!X~kq2(f*NiKur`Z=a_B-E6q8Z-@F}OT=S^~pZZz9<(xfZf&@Ex#sujAQ==hO zzs)!wOpsuD1QVq5T)~>}uF+ED{zhN0W?gnaSjsP0RK-_d{Z24@y8H!#n{a&(Oi7pQ zeO^;r4h z4t{Dsq&;5K*!4?I2_t5p)tEwK;-gSGQBqMK%L~BTH3pu21d6C8__Mj)NHH(ythd;! zc|k=Ms+X&p642%Z2G62!C>9=yg_3-bk6+?r2{5=4@ru1Ho4}Ywq%*Q*0eRGFK3=}6 z$1|+_Di*JW2vmPhv0_e{x_G;GrgpaXgTD9sK8U^_{RQ)}8(Q)9X>mg4?%qX1zJA-^Rpwz3c0*4Xnl9!b$V^vtut$UX(oC+lSvd`h%m>v5)8v52vaR zO74TpHnZyxjpph?K=a=6$)Ov2CC|2b`!;3P=oc08K*DSLA@Tna7mk=~2HQsLb&9BH z6Pls=9C%207a65ic(e*08gVZLAYRAf4G_U%Dk__37;i`y`J^J>d|~bM*6H5ceKUR2 zJ+nP?y>~8rcwz3`-NO0of|>zHcptmO=UrP7BFh}rAGAD%;_(F94*?m0*)*#op=Knr zifIwtD*cIG0cQ}ww!$GyJ%iCF4;5ZEhqKv`3S;87-^iB*+_ z#FOh#pJcC~{-{Gu8EL3hR3rE=PSLtj?(ExYN~Z$ldvlHIwCboCDpcS-&bG>{){?#Y zzfLVS9Sy@;&K9W@jrPcNL;vOn9mPlK1yfCiP$&Uomi@BweEJmCjh2DDpnJpcJN$qa z{7z7HGeQAO7h3Cd?>?^D3RczTQ5zgVU!2j1v$}Sy@eOM0RaIln8S8o|t%B9GKpA&; z5shKw0MQb-sL1t#i;RaYE_`o;BEytS@R$#PAF5YDIk^7AiAy1Hz9l*ISVtRVI_$GV zZvB1p|7<<{aYtk@X#3lWQ;Qvuo)DfBc>-#^PTzIQPQT$O1%5IfLkkE+i zA|cgrT0k8v^7HApyPhuo#}1hZ1;MN!7leY^CIJv+I2RQsxh?$}u&pfg(<!RrT4Ik@1RI#0Rk+K|0Xx#p*= zCqs3sKNPV*Ma1iV&s#iUAGfEy+a>SzoNWM%8Od55&M(|(*Q(fbv;0PRx?+b^vE#kN zw@+gyEf_{MO0bCa~W>7KWAf_ym$WcP7s$@6EGRm%1_}1@og9;K;v$zsB-ea_{LG|Lh;7Q(0JRZ zePL7Gdx6{M-aR*0dv}Xee?r>yTzb^+39|8 zqNLXEk=%RGRM5Kk)kUDjR&0CmSJ44eK`UAcSg;h7Y?$A|e?Oq1ncXm!x*FN(F{9?maRe!kM$aBAMOXWp(j zF<^y~MMqR9lQ0bLcKVb}PJ)r&MKd=DXqQzyFgVvAW=i-DmcaiABJwxnQi|{xIC_ z4H(}z-?2u%bB+80<9x6Quo;^GZ1e}4fLmu1@XDS8kvK0#Ps7$95m?Z`{YCiBk?$zE zxKmHb=IS|-=i!ey)j8N{-c9ziW=kkGB`z9w9yoMZvk{057h-kXu|v;uPwsqj7gy7t zOb&|s8XJ2P(T0Q&X>8ir(6qB)?ghQ`CUY5^YC>4Ju9{5JNQcr2|4f`+Ll7_FZ=N%-v?G=J{0VDS-5y_r76)?{zNOz;uuW z@Ohf^SK#zNhHm|@!0E8v$#8?8hRSe?=;qX@16{NOAjUL09i{HE0BzE$Lw3UVpuBJmqj88}e=t$v&?%_k zudA3MtBrs#Z3W5&0Ggt35C+_eLD-B#>}x(OoC~sqKg7HVj3zONLyQGMQQ6(bG_;IQ&Rh>`R;(! zen#4QMyh-P5LRi~_zR=XMb@?GEt#;6TkjREoAABvn>=>gI%7>0H3IM{+__j>^8KNS zE3aRfT7Ub@%o)W=J`+r>Ym79Be7j|)Wp?Xy-E8k%^PQFtTRz%4SNBox-RA$^@|P_z zWOwUU?W;8_6PgG!Y!rBNXW$~di_tVhoyPlY-^S+{EyACOU`5>|hLCjqLU9HDFIu-; zP|9@EfKrwiOaMe-3zuvpX@}$k2T4*SZdk{ZYj_6IIy{4SSv-SxSv-SxSv&)G51!23 zy^GB>hBHR{&&VNG_v6{U`>1fIIkO844!y1ga31j(?IX9q=?AAN?s-GNBoz2v0!#Lv zxzG+KyZnj1qz03sl`tJg{@8gs6z{Kx9mn865eE_p@P}%Ui2$N~PUK~$YleMX8#ma* z?ci#L0ZR=w`D%AI?5;hFIZ2>V6N(M?hib{g0iW;NbI?HZ6S_4=oSH+7)dmjK?&53-HKhtd)gTIPL`~@jj%NfM^{`1RtI|W*bhbpjGcw(F zm#a^+SDhmN2tX%d@i8OsI1hm5N||0T`ZeU_m~S69zl|oK>so{?#Z<-hkryy8qWa0jum4*Zb6_pSgW( z=2*IJzf`wBRo9-X=s>TU1{1!ndjCuG*eM!!o(5G5t<5>oj}n`#Wb6gd24L45PFpx% z0FEw_0{56s(KNxsynsj7zqV7%+S5y@+Io0$wZl3w&>8|Ze7eRn;v7NcH0}sEfUPp2 zCKw6lA^0f796d1DGiom=vcV7?jW~MVM)WQ#a|ezC^BuoEdLW&Pu|}>;g9< zysjRP3Mb-m z#U>_ZmRXJ~a}|{X?<@AUZR&FhBHcg;O$E3PPL#_Ze~Q#;pJWl@x5g?}^~NB@hv3zQJnhxul+~ ziNg^nzvW>YmU_m)xd?rYG*vmI)#2<31+4fq;Kb*ckLyJ}ne%TQkpKz5>>!YhY$Js! z@KY$Kuw0xBvITVuU%}eA496RMSTLp#HOce_Fiz~q}`h&_om5_*;i8TP8_9@N%Po%KEZmf?;77Lu^ZvIAiyW*xN|amBMMF% z@U*xKulwKHeRJQN`=nn~vP2=RL>f?Z-&pz|ldh)n`IyDFpA+onV*4)*Cu; z4&Z~l!!u-VeH~BEkMU)UP)QaMR1t(^fwSgpiL(C_E=v?fbrnqUBK$l8`feCV>yQuL z*;U_-1Mi%w@1_9nTHMSw4KzVo_bT3HK@;9(K@;9(K@;3P;ArMJjZUEBnH{3oc3+Y3 zUHp#Uf=Fh7VZ$$XpbDpWr(OK6c)VdK8I3huh$aP`XmSz$Do8^Tv?H-0^J78+&O&0) zl&#N^UjraV=E*Tla?nWeU8C2DKuXeU!ezXz!RQHyWcz;QhXmS%zk*xnQc64kG5ibx zO|Q5q+VfA8^B<_9f1ucZq@MT}s&?5*ncSoODSF-KrtP%z`nHLh@tTRo@x~PjZeeGv zseB0^ET5*c1gqm1)2*XJ*M`T2ue~}3W$`{!7L&_oOj)|E q&`y4$f4qMpF`ihV;5JhdQp=ZgM32(+aa!M0l0gbDn@E~;@&5&!@x{35UbTYRlD2cxwg8; z$I@etHFEM`DLf<4E<#8aLgEF9KY~Y|>J=XFD@f!bHovcWW_G-Ghy#tOq{CTK+110_~fUpE>S#4HlZC>YXyKXaT<1MG|n6F!R z&DS$uyOnGEb>EC}TKRUNUNE@ZDz;1YlEJ-Jxm~GOfamH{>z43;&Fa&#B=Xzrc}o;T z@wz37vLM~&OtZS}Y_t08xVI$A=q(TXZ1l|ue%`88Zey=j&DPFR)JU}d?OqafJ7G(^ zuXlhf*O=zrQ-GzjT_K;BomZkOQHSb{Zo3z?WY9)C>gCsic9^K>N}y!ex=oU*l{CAo z*%i$$Y1YwfOEaZeqS;l=`kHk$>uDBiFK(!)mo#W2lfT(vCB6^L|Ig>m+M6_ zB}>h+fKZ#|HAl>d>UF1HxnVh0eG0;w({tyqw%ZbHHqLfCi3%GeEuAM>bUU(>=$vfH zHjrhVkCc+VIEXsBFhslkL>C)fw6{Vtan^JMovvz!t!OK2toFiWV_(~J)JdemxdfMG zw7!wVx)N`MN{Vs4)akc_;o^{~l1X268uQb=m1GKEvwkw27Me5Lv?I_zjL}U$^#Sw3yXoyRW^MjJSUADlvN{CL z)nR~|ZE)?7_z89-ny5!RbOh=#per>`9igV|I0GD^I)=J>n7|{{lN%mMt*YI)-;+xF z0Zu(ift3Pa7ZcD|ISZx}D(c`B$^{RoQ85 zoIIsOFI34@(>4T#kN9Lfx<0`KHw(S1s)n^m^zjlvq3CmmQ7Q3~DGo27=CA_(B~Cpa z_n5;=yuduh(Nn?KF_`m#ch|YQb^JkBbzyumbQ9DJK!fhX=;r{TV(ti?{|)N~TjPQW zex3BmTC%p;z)o0`r?#+#bCZmyyWC(S3ixsqI(RZ*@jSpf6!JoDz*6>>^%mRaKj)j6 zroN0W18Q9Jg=fe#j8FI$SK}RK1HXdNx^TJ$T_AI__#L87OU+wb&peRVwXV7xDzR|7 z)op~WcoE2G^6k6$!D`zgYD4Ltp`AQU3)qb?sX6M4)G*$FcA)lgwvDe*kJ(FZNdGmo z#3ZzN9KzPI|F)HyyLbJTmzQhPbX7WGTLytH1msGf?Z8Vxu-Ok=S&z9y>PcFcR?#+W zp(4iBDFSp<+F9*(TSkd8HoO(tiJ{YHsk1y(=!R^ya3{xNSI^RX;zb<B2Sn85XRu! ziL>JhF6`@8%IIpCRyY8d>s;dRZ%+VBd(lSb64w{!1i7NJcOkkYp zcTFQf{g9E#!~5r^#iO7*m`dR+gA$GBiFEKEjkmbI82kz%ScF4LDJO3o(W7zQqEP-V?JjoCu6#vZ2ufA~R zxpOP59A(wh@>jlI_9y|KGCu6>&kb+A0lCLO#9tDV)L1h_2?J0=3_a zsBPt;CynS+*&;m$WNJbecJw2h4#=npMTz&T$*zMa@q6e}-v>aPHjc=2lIsX_r3zO2Yim-e1tNNqy5=VW?0A?tV+noWp7zmhbotd#AvOgm zO2AB@MAj@Zp=FE|E83FUX|=;E+HNDw=WADxmS1hJc3V*+pw|JZ%iUfQWLG;t_W@L+ zHs%f6;$W@c!ea~eFFHCGZ}z1G(eiU=R?fZ3Wht%sk2SG=;GXu8cgWk)Mi91l4ZF6f=Dxd$2SJp*u0)=IBDWJ=DGBz_El z@RkbRXPV|^27x2I&aQ*Z)k$2m;+3U|H|(-zT2WoX z#?(au<^kmjFrAlBDyhOmYI`f{Wy#0}jftq+xfIFEcvaa5d$N|_pOxD581q9K=Vc01 zpB`Q0(X&{6X52DTr@D{3nPrE2nV?Z765n4?;zt4Ktszq(D3CYJs&~0fYK24m$ahPu z%C}}Gk7$?-tE-rMpLHc$Zw@E3^CaF~$|Vr2QMPlKo~HP+TE;MVU`NnXY?-oOF*IT^ z;x}`xiTwUphh+M9l=k6(6}3bZ&Jp+~fw9R|&r|gU0<>Kdm8cg9 zqf2Ho>Fo4gqUxULo~5dhA7z^4;o~O(MoF~u{wRUw!vxwx0_`J#=6>b9`E)+ZP-_RY z2UAXti`-GR1&XB9CIMskee*E305j^yQ=LxsxHd!RLO+RG3k?J=>$0QhD(LcFdoxND z1!THRk5>x`o(kfH#=3DAs!!x8>NI{rd)bSdN#{-KJi)2qnK;${>CE^oQbZe*I?v^N W?&EIbxr+Ya)n}f4h diff --git a/mace-bench/3rdparty/mace/mace/modules/__pycache__/symmetric_contraction.cpython-313.pyc b/mace-bench/3rdparty/mace/mace/modules/__pycache__/symmetric_contraction.cpython-313.pyc deleted file mode 100644 index 16ef822000ec5469fc71ac8ec2675ecc160a82cb..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 11475 zcmdTqX>1!;dT%&9#Zw2Vi;}2Yqz>AaHyVowl#Ra4`GrP8euWLWXfnXQjUwcG#7n%?2pOh0 zYFNQ3hG~v=V4q?{$tlIBic^VCwfLk*G{ag>E54%~(G3@IMIx*k(GMFqg9xifjKe0* z1h9rPyNOqOnc^(dh6du*^}?Q5DhRZSs)$!VT{NxoG)TSqwWle+CBePQ%ri5@l@?gl z0IM3mz?$@tHOuA{PSZ!&>7)haxlk}BXfG|sLJ>9)7F0eKfN>`!C?efZ8_q^%nF|a% z7FrIm*bYV(7DHjizX1KvHv0SvffyHB_H#@i{1H;cX%MJ41Zr8JUKc1RcMS+S6>S&M)EyjY<>7=?e)7bx40A2;+O%smt5>DZzIJ%a||GdRsdRplb zSs;HpAztOL=~qpwr!~{sY2CD@N1=jAQteew>%1EH(*l|F*zeV4p(21ZVxQit=qd4$ z9?A>IzoxyWyhd>bAcE0jKfFJ-yr#naJxZ?`W-B?f*TPwRWs6WICMHw)%6chBp2=s>o_@+=Q4*1y!w-5-2Zc3dc zi-Fkukv>a^jWHaM+5ZL;nwyVB1ygiBz%jGgwt-z*@H1g%fni}(%`vehjt%PaE@=Q5 zIcoo11tDjs?~$qO;1#`(d=2o`o9S^7pE3w&4QiF;qqulb*QzjYih0|ealOANY znh+n&>W&eX_K`xYS4xIJ^RPvef9fl0g;`^OH9Y|wXtHZSk54VP0`mRUB4$*qdh+XG zZDu<%xfsr3eNa89nJSY)NwLITJ#-b86 z=7>`RQiWP9!UB^Enhdu)X|5Tj3Q80stXT%STEs3CB@XF!#i{0i5lBWX{UC}QqK9)} z)qo?KjA(&etD;Mb3@2#)AibfO-_IdG3+juaM2qWis6IfjqaijL3$Q^(P{o!O!;G{Z zA_+PoTyS1cu%H0OBC|_jW;7Iyp&rhv3J}!5;v&P&3i><+BGAzoCn!Q3=woG+317|a zD4n?$mgkU=x!t5IC{ej^F5F$`_Xop)Xw>hIqToAn<(4qzo{uar?uEe3`9S1aC;(89 z>0IO@*O*{TYRoWfaNga`%`OJG*bT8SQ;9c;x}&n@^XJ{CYw-rxiM!VnQjA?PI+5q2 zeek{g^cFd=5uNGV8eUtmt*v^jujTbE>4wgPq3*H1hSxWz>s<*$?PGI8qUCJTd^Vvy zn=qeEo9h$J14;8hLOYN!4?L-?-l)7=8ShJ0cCIMDYuc%@Z#eHdZsGhJG{axPtEUl~c))NlB|?_brWotL!mcBYz7@y(~w^_^+ELt1&ys3|V_nARChPia!++y@L24hhSl00lFa0`?NL zqC`Ev48ZqL_OSOzDJ8GaQH7gC6`=j#a~$bzAd@As?wnAJ!P=+;StUzPWi%@2^{m#; zEeCpC>7~4iK1CC(=72wzz6u7)_(%T@1q=;(Hnac&*#lF!<4_57Fc5R9xIXO2UI9Oh zS-B9DAfM4Fst%EFxzjjLb2Y-f5#VM8dqM8Xj9!LUqo<&{Mbf=}sos|#^}d`mUw}IP z+NyqS@B>Rd-}~~4GGV^3V=jC1#&>S4mBml+=9YxE7{6aA@}rRj0uQjk@4vnj2xmsje*e`F7mbBOpu!>< z#lm&rLU>t$hVMLw#N(dB3eO=ZXCjfXC{_~s)K?fb3UYoP2dyCAF|Zqg{y+@ukeQ_z z10pXNGtwVq!ePJvKE>g8Itmg;21HzPD1u^z#!B=jP<{HxWYHEI zD?O<)JubE0>AT&xs{(jeO|0d2F5bSltHG9**s9mMqvoSvTwkc8F#arFGIvgt^ zwe`Dtgbbv{v1>%gM5-I(J?mY&W^7qVWqrJC-LzYbtrAjRw-&v*934oi;4?6C5e1hOpthh7seS!dFi#Y@B{}O=0R_T|sNWuElfc2@J zRwLQK1_k@nll{%8Htn!FL|G|0<&3eZta(w0It^Kam(@<8Ac)1Q@X=lce#T^V2r3d37pt4f8f&0bvj%5o z8a#XopU$WC(tS#=5)mUBW@vRBDPEONXCmDSUy-~6m^t0jt@2J>z(et5&9~f=dFEo) z$ofjF6w7NrW0hw4z1-7H@+eX&_vdg6m{og~y&0}3B^Q)KKLLZT;Y{7U<1MVgu>|MJu%As48?T`-rkT;tPEx#hhJ zpTnz=sg1k^CxXwed_UlQbC&mk>sj;EvFvJQ+!y#BI4;BYX0Ogo;VQqa{c-Pw@;v@H zA)hw)^vDFtbrD}Vc;ccKlyA;jrch^z#q+!j<({A!EP4Cjhy6%_H;it|XM#UH(i+d^ zGf4Q6ck#S%WD+#JStsLJ>rYMt-xP3ZduVz6XRR=W-zTvYUdd*3@Fc&@Gc4HhxaAWtoZvwQ<&# z`^x*~=~A$h@wQgbx(-PP<=37I`7H_dmC2`cpzEB~UAZ?4&a&oa<(>bxr+0}517Jr= z!IKx~PJTf-_P3Ia;ZtJJ49lz)zNTtOvnZ8-k8swegFey5&DZa9WKIqNXs*hT`O84g z;&CVBd|>~N&}rqaK;=~Cx!@SP6I8RY8xW@(Kvc1cYakDh^S_ku5uwy0| z>X>01oo5^~!C-zM13d@4*FFYH4hP0PxXh6P>7jyt&lRPN$w_Lmu@{XtC{5(PVq75i z>5TKvK)`Sb!j#t{kd;y8;vXFPUmQBDLK(!|qKrSfG$SQiAWs96a6Jr$+>(MJ+SB2{ z!pv;ojL?!7I?RUHx=x29!9X~A1_0^3KZ5yDbRRywMc!Gu_st)CGhI=&p;^}?96ia3 z-mU8oFCWl~Hk;^5I~9To^F(36#4HCEAg~FM&_#&J+%#m6$#kBX2?~mt;7uK%H5kI? zp)eHMOTP)MW^JrV%BPUiY)ypwU09!+h6?xs{X{VkCLyr$C<3L(Zit-|BRg3na$m=J z0jT~0{<6q9e{b{$qgmwiB`Z!n9DJ`Hkkc$72c!Pkf@Be|IRU^|V4kcYH%ZCv(Vr&#fXS@Ct ziR?`rI#5y!R~3vbL6}`kEXf;)fp95ag|@n2hIbYrm^J5@Qe0A~0i0zb6wfg+|15Jg zuoQ+T3aX&u0_+^aVR}R`bIcrs!x_#$vvl<;!*OU6aNRg$%Zmg?GYKx(La-j{myaYE zCjuPA-C!wcpXaQEp^hNzwm%@;2g`3Q>f+l)>iD6)Q<4c2A&JOv{ zPY#YvjGh-vb6j9?-j7S6>N1&f#myuK8AJ-~&;#Dqr542A{j8W2T1SSC7&XA|{~U^~>G z6&JyJs6UJH(%k)W@0R2`oH96ggCpLz>DfG&s6W2-{0D~9duCGFaZpBD zJMPWDJ)dgn<6HWYE&aHC<4&{d-sIbpspjYT=I4{mr?9i(Kus!Y?kwM4PL+4^<(-=| z$#O_3eAn=#*tYtO_+Xq%6}wgjA47hjcD;6!N>;j3m2SS$ovb{*)xcNwt_-K^;0*W9 zynTkRJH9o@*Y!fyB5kX>!`^1&my@>El&yoebtG+Fn;dUDv2x+D*|Ke}d*X6$QSY98 zsQanu$ENqGWX~{vZ20}l{IT(E*QI3H#3R?G_~o}I*8*=%@vcjJ*+ja$BVp^@qbY0G zlaAwC_3s8BI)B>r{6Ql26=TA&*cfgLGM;({r*WX%>Q*V8P@3_p{JbNmI z)l2Qtq`3K@hyWUf-X2Od_47^r2MUGNwofaYOA?m)0~2ZONwr*j)N(Oh)&8WSWkxji$?+6E;XhQP$RisNL=uOWMXCb&Tb(I>y_^ z5v!L(th#_7ZB?r;?;v5e0zYd0tHy_+pRzw@-wz~DP5du(fi1n%s0cBf&>dvkQ+WXp6+lzeLi#uhn0QsCSg8Y!9 z`LW+D&1KT+)ks=hE9}<@vf0v|Y94yjJe00-XZG8vac!PUb-ciLypVYDQo=rwu4`Ef zZPu?Z08wf>fV`DepQuT7!y#t-JXh^zc4~p08X%=n44S5$x3s4$CwR+=MEA?vmJ4Z1 zS<2$#Ev}U1IBz+gIB{;mURRKk_K(4iwm8z& z4e{>yOKYcB&V6?j?i1(Mzp--epA7b06^wrN^ml0D8lgUWdXiW#Q=k3P=K1XDDPkQ% zaQw5UT2k&6d!rbp{fTw5M*BDYrIQWnpF`cM{ri&A$qw~Dm`f**YyZ(wI(br?fS!JB z&Sp7e8lu6CmdgHJFz&?+B^voh3OHYxFme6>^yZuHnIYNmhS36HI$xgQ2-(c+uQR65 zj^tMREJRQUyN3Kfh%o|vGqflrm?HSjA`UBG4>51RZz|>k zi;PouH1A8}PchIr5!8|`Ay{NPFZ<)B;%tv>SqNG%*dz-N>6tYqr0nK-T&^_Ve8?n+ zw_zaqcThoc*i@1-)$^u$$kD$!{_XL#o>X-^U)`QGbck7dI}hJYgAxAk>s02V75czH z2B~3c$6#JnqM)tKrfU26+CKaS`1pQb((ua1v|^Ad;ti))26^-8RdvGEzXu=mHW2bV zrrV~qSgOv=*SWXq`MU0{0AKf9((-)TVoO_!_ch826BrUjq{4CgTPs8VT+$dHN|qe^ z&e^?ETyxh(ETwNQe|ve2-q5Y<;`Qr>_)OC5OlX}_?(+}eQimK96cbU%1^|FOURJ~k zv!EQTml@bg`6$d$p??Fcpcv;cs|bFfau$DgXvBo}Y{qtvg0C)S^YIW-^o|@3i4e#r zf;FEIS;PF_@!bZfZjl{x>FPP@%bF}}gIjH0+?v_Cl5m~k&HV{&zvP=}MQ-9mz7l!w zHJB{%<28iZvAPA7lfD88c?*Zp9c~D#i&za~h3Y}_Uq%ob#R@Scdgt62LMT#Va$REc zaco|~3O&6196IJL#DptiCVCGlDTm(B0;%(~t834ElTx*3)92WL{CSy8%%*F^CM3`+ zcXc9IM5^m{^&)7HQ{hIjX(A<6Yt0)S>m7V?>#n#cj=xefaaDt{D7^A=(w36Wu@V_i*{wu|mI3 zP)k3Ml75ze2fI(CL!QP_?PLxpK&K1r1yDEg_^wbPqcv`rq=-LKmx(6LVwMa3=- j?R}HVz3Nd=r&pWrwB2sIn^z_WmJ|(Rl$YU%&#q4Fb@%8&jND1H$C$e&aoSVmI5?{v>- zXC)j`smyMlxsTJgZ=ZY4cg{U``V$krfzMaIdSUfDPa4K|SlRx$0hLE_h4YSKC_|YY zqsDKmV|J~Y)wOGO*Qq((TrJmiYp$8{}*XS52_A z?5VPvd@)zsS11&w)Cb47riFKAbjPgR<+2{+NcS4U?`3tV=|3#Y-=Gc%XFBV2#;$suI{2b<&a72lgXacS z5uO_hW?>5N#?KVq@R_kUOy6yK_!jEMQHc_IJ+7N@@x2e>x)~SeyyM+}RDC%InarzO z)UA-o>(y=YCaDZ=lyY?>iF6)Ad1T`+TO!lj)g6$>O_=?is+LJM9p$i$vVd6~2@c83 zhLYr+C3Tm&8&AGj-J|Y((Wu>`-XQ$<;i>oIdSh?^tN(y{6Gp#PJ&3;B+ST^$>dk@I zzC%Vk6daN{xaz2S2wZn!-FB&mgFV+g%TM?LFyv8`k)Z^edvTd9R z%)(J`RmahO*S7Ytc?O)TgSU3o6Y6c?yn7ob&J8;g~wKt&t9`#<--z)X|s8ga>&#I>aE4ZIyQf4E)c^a#; zBzkinR^)x^{aBGVD)VWh`rOOJjKyj$aUO3qqr`c#6-J4_*ysjoX>HI667SuEsMYT^ zI!R6%s#an)6Emoqi5K#;MfnI3^;bWbC^W;+E2CT0qg+V`T@$n%ZrW>tZtGAN4I*8VEuW57Y z*(Zl@57#kV;c<}2*fe8f%Um%x?8sA@3c-r)Qz>M&yf>HsDC!l{j{%|(^I%t zPm|1$>;XwU^0k6qbivTCN5jSUA6e~pgCpI>+0{nB-D;rJ3=R);za2EABZKv5wco2B zd-^?T`-z~}Ts?BHR)dC))}<{YeYMsJ!Xs-@s}n8^){}Cx-^H5;^<_*t2zTRI;Q^46 z?V1k$tnAOZXnD>LJl{7xQy&7u@uZ|$;h=+QbTQAwX$GCn)#1WO2VvNB!8!aCQ5Y@d zKccd2Dc$W$8Cxq>?8aO^Sl4GM>$)zD1^J2 zkL@jU)7vsO3kszTKlWl@xN<7DW$!mO3K+}B<;8_9Q@JONm@N)#7^M*BMQz?EZBUzy zygbd&A3{q}<D;pxzwrUbdl9-r|y;!)xjNATLvibs6NlxGMEc>#?dWb$uIq zNU9}GQR~-{93-J{*Ahpr|Xosr~93?ZXmjrlAAb--Q(I(&u*)t zFXS7PJZgzM(h~oo@0HB5x!asE4?H9cZajn;i>fvf@85ZXUY9*+*dMbD! z@c}leS%pdN1W{y6aTpKfaR~-|4Fy;^?RpwYV^T@55Ws}Ch0BKZI~!JP3j$EBghQFj zz=;uujU~-*QfW-YGP#(woN}?W)*fTd*zjU1r>pP{<$;+NVl5xF3$nZn8@Nm)E!ruXjjZG-SmibBJB`m#yv)r&gVSK_|wullhdcX*46OIB1;XomB zux01$9w=1qq@izAZqHQtja+24U8s4)$sMGHcG_9qY7rl|UOiF~Bpd&sj)K8$hX1~3@sIk++ z`jdzp7-^)!xOQMZy@&pgoPN$X%hrrkTYJOf#27HxU&U?(FK?E8r4ioXy4JbHjwP;6xiWN(klUpVtK zhHYCzJ_Ikrd*XsV78m)1@S%4gQaxvW)QqgS7!zK~J@~wha_sBpC8{;`^D?@F(cc>v zIP%71JSp>g4jye1UTjkCnZyXCcp@&xlgl|-^QwE@Sm9Q@sRRtb4T>Oz1x656dvS5` z?Cr-TM67hz!l-g0sAOVPm8f5V!B?Wyppwa0fDOaY?Mt*MMo^vFHm*1y2~?gLK3CUR zdX_}y@*GRglZa~F!_vJZ9|TE?+m>&KP9;SGP!`313CtnA$JFdFm`&YE`J&@l*3iXR za6hG2eBXf?mrTbhW%>u~yX57djURTbvZ?EQ{)hPd0+FPFP0o0fJApoWB~hMlpn$Ld z0Rt4lLO9@JLESj@1`?zoSjlH?*kHon0NByyAtdqkAT(jLiU6XpX%EAbv`u-4YjX%; zTzVCZoP(3d!O-E+HuH#mJ>|m8i&%V&D|~5;T*7E2jLP_zQExu>0V+OXV7BArg&NIRaAJz@BuaB&8*S*`S!uePxDa3$Suj%M`i zXnF3e+>8I`M+1zpkEVDHf91AC?BFV`d*d za~{Tsh&`XiRoL!v1DOMYda zFu;-zqb@F8;&95tpj+076Vxq|Hpye;@`j?ThV|@b8`G5F&C7e{9Gs^EnR}*Z=>hJ&Cf;<=D6(S+M#w^xX`V!+pl28` zE0nNP!X7GZKZ|36W@aTqC74B#nOzgly#UmhLgiY*9r4?4VQN6QrY_-u_ zVGGU9#nKw`rwCT44= z5T}C|uy1Te*pHuXAT!X54pNNt(T=%~+cJof+k9UFT`~i~^MC+vzbqRTifzd1 zAXqM9#3I9N8JQ%`sfhN;>$I1msknsx5`!-S z#|3W@EA~;mX^I~WmSjb_nEEF{F4+16%S{qON_ARK;tT1Kz6yYd zuL!0?JD3QTtR|Al1wIZ5PT)Hyl;mZa!CeWX{=sTJxGO{;CPi-U!rJsT1L>D|qr=p) zUgIQoAl#v0s7?+z5+(D%{IZ!yGI@cjN^j{qj)jc_HVO1E&Xp<4xqJZD&vM!J5x&ja z^C<>?iVyo4kTI=f+(x4pAmmzL=gzGez)&<0xdR0KOfs($6-fK-9 zW}d?Z-wacKf;qW1QmZideA|NY6e1g@m8YOVu~UM$nchVi;)Xwlxxof8!J zvGsxOAT@p2HI3W9{%Q0kPIyX3$o-yLYv@4#0&k#$#Qy9fEE5UDIFhNWYO^kuo)psn zJ~jKF1!Lk13`o{CRYFnRasAW3#35xqd|G1bZn#qk502>|7*Y^g&rT+urS?6Fn^P-y zpmUz((jlexz0KB3Hge9$%9P=_VDk>50COJh%DIA#oTWd9LB=%IL&M1IJE*`+|3A&% zXcwUCF0#U0<6p!B5h4C>%v~f|^WE8xE6NOxFW9Eonh;j$AC%AD$Pzo8(}?E5*(t%GkiiKk4QRupuBskDYP#fshcP=2QPLSC;@`%(cSwHk zP{-*B$2dPrnKReI%oo*{If!g4^kp=UDeNvZP-(?a#kliFLn~GyFDp?4uzG;k0+Ut( zuRh=vXJZ1d-Zpqe8SqpPcrBv;YBa_2Rd@wF5;Os?MZha^3Ie2zro0SZ392|#1H9&0 z;&H7##{0Za?FiDcouzgG8zFMV{h(4%kE4M=Me49 zK@azA8SzX!E5|y>sLs=Kte3nze*yJguu}b5dR{3nNHKcMKX@*`q$X@+ax=berBlj zNEAni7F*5m3)J=}!Fv_VUmMP+GMgBIx_uSY9hY`8ct^$zxQ9 zjhxj^ljM$Ux3B{tHX;3v`X)|+fI&F&<9>jDSa+N%1IG$H58_r38#6nL-!kN(_eGw( zHRyNJ-PpFJk_D2rkyYZdq<%LiYg5C4I&b3-14c1*)|5(jhhRsEY$ABZ_|KFq_Z5!q z(f#2Ff=Lzqdz@Fkx!TYx*orx*4L=mp$ot)e3d7y>5Dn+5R%GrsB5OahAr1h;!D)gV zupt*fn5GD_6R4HXsP%9E)TJ*ybk!2DJ6*ttnhrv9B_x)`S;1FWo+cT^wTabT+o?MH zhpMyOdLdBQ>2vzotA3kr$}#u3JUDc(=-)@BiEHd*qxLASkgpaPHB%Wi>wQLG2TmKQ zpa93YNAa4qeBcFzKu`~DHrnLXlIgc|(}TS+avJYMDq;nRgHEH{9fb8}rxAuo|6n4i z^(PZqPmgg^xLEZQuU_xrH#hZq;@9h4`Sk_LrFtE)S10Y!T!CySw-oe3oMuhd>x~}L zj*>|U>vjFdWZ^0$Hnu+7#3U+l6MO#vjSW zokt0hfMigHtdKz-wdY19cT~!cO5UhcK&gN`eMqzjWho`4C%7j()`whQz%H@{eHVGk zDReS5%G=h{ZLJ(dr@HbC-?LA03rMyvX80KM zGf>`xHr%o2yq!fbCB#JYe4l6<@QB|OA*#ZgV-r_Gnr~9=J?}&B_bnbzV|)T$;-(0Z zLWiV*meDNR{i-Ra+;SE!`C$DZx|Z}wGW;Ql;3h2~O#s}_5<>&z{G>}7q+bFCNfZgp zQooR%7AEOpgZ?A7`uZ1Gn)9Jf0CSM2y}km^hdT6Vh?B_)_~*wihY0rgho zbRtbr93?lQH#N97v-B{@9Uw_DOi^})xr-w@{7lpmX?kmuUg0TD%+jo_*5Uzp~wLn87_i_ zindAU?o)(w$w2WA*7Yg25&@Gb`NiTw`lCQWu*tC58muog@dHg{_OZJfdGN&NFAoV1jJ!3I0knOk-q$<-g$^ z50TtQ@^zBON&X%rakF1vF0qC2U3xsb!oJ)x^3ahSc1Vz=KP;{Cj-p73|2VMv$uwD= z=jnxRtm{v4luwgQ;P zt?ye!>w8uy)pAnI>5BiK)`Q#cw{X9M`-gL}Zo{)*-O0gsDZBnfZ{|S3n!-pjSKsz8 zSvOD$u!YOW+g`HDV|{mQ@0*{pCWgJjdpB!7p%aUomWpfCJ8_|n}NQi<>g7g5CE!o+n zle8w4he&lDLGE@%H)mriO|#l`+f~|aqa7uwNUjNJcI zJIB4riJWLS#z}NG9y1&@O2(rm$#m2#nU7i|%TcRj4XSx;63_D5B|AF{>})#bIO>#~ zM_rPOfh_*MOOp&g82o5)X+k;Kl!<9ck9oZY~KDX#CK2hVEBjVF<0eyZ%+oFGsQh|{sv3iYr&XKjAQg;sZYbiZh z4LzCEj9mN5^RC4P_iM58S*+Esam|rgi%zU!9ktzGtXc~GAL&E=iaxXv#g_lXgP)sr z*P>DPBk7>wk#~u{8}Vx@(s?t^E%*^7Jc7FZq3`us*lMX0TZ^luU8FhH<11^?eZw00 z+a9qV)*A7(*U2Nstj}0n@$GC~tYaMP(3-Tp)^n^|<2gDoYn>w_P5l$}3rm4GwCt@_BHm9M@uE?+8Em zoOV?{v#048=_vNnDsE=&b%TRdi(5u&|Bg><6}Pf_dj7%sy^Xck`-S_xoz=7X3-{ZJ zF*S>QthOy*xVBH3y%Cy0Goya1*w5PB`i0x1HCUz$v$%uRyG`7=2HmcOH`bDcBq!IB zFM_y>wYI&EYv;t>4V<`V#N2OE=~n7vse6&SuOhWy8=9G(h$;N3=~R3w z5uH#hEFol645Nx67BVP&YATYLnVh~5GAYjFXf&0INs-A|G@&@Ot7K}d;$mhJr7TZI zQ={iTHlSTx`46O~CgY<|C1TQeJeGWNVruk7#Tui_3yOVuDjBENlW0qd#uM?xxMG!J zsTnCTS{x`gTmY*aUElw=2;SsQvqTC8jiqQQ($0mserg`aW<$bMk=v}MR)&MN$}2S< zM!B#7KU3HkHi`z(xZ5f+D3&mX&7(D3We@m=%cBPdtp^>Ph%!B(>kcktIXQI0s@SGx zQqwc3q+*RH*x5NQMaLpqsuKf%v4|jX`kdk%m15CUOuLFsOiaBTnMuqfV`CCM{KtQb z1M_Q1O-ZBY6#G;nmW)iqU&PEtXHrvCvtnkUa4{xLsrNz_$&0d*j{-jh^#~N6^2K9` zRFalqkPhY8=5te%vCWgwi|3+KFT|sWjmCPWrKuNUqbc>`Oe`^aZu53&Y&t5XF0iy> zFf}$a5le2KNyR6Uz0(&I_vq9lo;Vgci_XT9QYV_hXmZJy5X^E9tlaip-`o6Kn{Thr zZab3MdUW>q9Z%r$SLeT)^R&yJ_Pn=tfzMaf&%K(jZp&9Uv;X>ni>q2!;A~aaLKSBh z?s@H1*4e{_0LNF!wpvt>^R&L_XgY&au$ygE&e=)|8vr!_3uWD!@kYNur&-iQfYApElauRwPN)+%el1Jpht_9&ta3+?jRA_ zszZa@gG6?#4zW}+Aq-zyqjwh3qA}4P^gm&*7>B^B3US$9xmK(6?uYGKe?{vajBL2F zM*Wd3mC$;E5w?9kzL2nm`Bf!oC9FCiIBZ()61->%+lZ6k%+on+?=yj`@F#~PDl0W3 z_?Y;lMAP!;_yr?N#wSB2b}en9I~0VRQVYe&8XyHJXrrJVfnq)%kG-r|C*lwZce95s+SaIMs=WEYWQVQCj`Qr*dVuz9eJs)yM|P2<(XVj@=) zl50X5+sJq}tl}HB8^Jeft*@V&9Y7^+f6mqXo~t=Mk#%jF9r(~zo$itwdh#CsaIht+UB73*YS*dX2mBw#`uYc*SFD;$VHuv5P$j$vX&wqa=_xP#z9zT_N z=DF;4d&So6l8FRO^p2pBIlZdAx$#e0u5I90?Y&;f;C&ps0C;~*gDpit-N^g zA``iPi7|eY3mZp!SBO0*e0(&<;KtB63gNi>xXNw*7tqH=;Y_UiOl-7g|C!N{U1crX zkwUDIq3-U zAI2|fVeAE{f_J|D8X=KaFTI*~)}+1jC-Y8M-q(0_^Wx^bzbWTmC;QjkvsJ+lhP=D>^3?oP-ctumB70iz*~|g!?D2xYS-d%O zt88vvGQV%`%r^(~d>~ibBiHsUpSjtdt?ifj{(B~4i#20zC|Egf!#AF~e*CTD$h8o^ zWs%$V-!~bXt@k*i*E)Ly8Jty@chB!$7)*Q92eOWq*+Y4U=j)wUyRUSoQ(50W*|Bf- z&>ffW^2ntT)RE`?dEU9OWBz%WZ@h0ayREYW=##TLJ&|#2%$PSaRwDIbOhzkLWQj-o z76$6ehA-Q|C%D1lT0-X10h%tRPTGcSdZ1lob`MazEo|sBiYDkVJiteZhv@MQ@S76E z(OcK(3$uimy1^n^PZH)9lo!o=0D_}pJ+ z39Mo3C|?YO{k2TRXTm?zCHLz{W)(`-9x+ z67&kAEsaMctgYP1XU}qvnOeECh7eXdK?b(4M=Pg|m?Yt|`zI0XH+~w*9VAY&>JSY> zOcV?m!`NH^ssygwjm|;`VvPDeG`1lEwXSnKPV$4vHZP3}q;*-#2p~AcZ4yeR{m!PsaOY&TFI>w$>F~ueF}Zsu1uC9afx!Gl^7uGWL`tO-V}i zLr*m^H8m{}iw*H?RkATroFvg>6HJjwN;Gwf=}6?HM5Z)p9R(*4BzXkEuP}!AggSl? zH$c6Ro=3(6mB=9Y@3S0OLQT`v(~GB_HUd${)w5l2oD0B zqv~?#QfT3LR%pu!9kS4|^xQkozx{kx*gku(V6Q#>{`4vQ&CebFMU0y1bZG_a;UGPu zg9vj2p$gUXP*I-dYt;e1Q&N@6i%VoUgG>~92u4%9mgD z#K|?j7vxU7x>jvv?ucB!%yAClOy!ZGg|q6Yqg-zEh0_5TsDBP6(apK6{Fyn z@Js$C0=67H-r1+~)s5w_I<1 zt97|PTfe=MwFJ7rRRTi6%{iR+JT^w}Ra~Iy>f?)#XZ;;>JcwC#H0GWD%O~egexpCv zxaqydP0L5JjXOSY?!2oer+0ne?6}v+)pR~+;(%0My!2waE9-7Y{~W&Ald3LYI}e-9 zl$2Cl8b@KQRq<$-ns!!7(8kEGu0OaxZ8wHZEpHb!g0zXnzXvZ_EOEzuEUN|JV`CAdD=pX8BUkv-tc!yBG!G% zC~^sRSSXbaqO|!8B<+gcnO8im`fPTnLZ*Z#>^Xy)MN8Pj)C>=+(KG!0;Sa-}J_}nC zA*-(19m2CS-5auWrG+nKI(hQs#f<~Z@D`k$Nv48lV!`6Hjs;UwK?uiS>Rc>XoO33n zQ4J~E%uvVVL8z*tUV1_X!KIkiu5^}S;}o#IoTFHr0=DvK>?D%B(isGbqk^20#0{8K zst8p1=LX$o4J#l0QN)%YCw*v)6rd|MF|{=BMwp@5=k@^EIvc`p&$!I^P(| z*RLXg>g9I)Kw zhRkt!^UkV*hieGVS<^MJD6RtzN9lCUhqju$r&{*3W%yH>b*HijJ#BYs2%B%+{sNsJ>R1a}dI=l^yN*hNkqZa>K^CV}EnEreV&J zcX-nS%cnDrT^aK(6?wmiM^)@3-lH(5ue3*kD%F+e1Py%}QiDQj7&exaK5GR9M~eWH zDdJGbZ7c4vjTyjR;5gjlhhHK7D_T?w-0(GBfwoFz)^dw{LI_**06WxQyJ#=<&QhWL z>Hv8?@fEgcVo6WE)>iId2^`}giic>`N+uk`-=&4UMh`2kqt4-Pun`a{da4WeEAH!y zRY7ZA6qpUNvc`bc$MP9svQ(lpUV>;1+lga@g+3CTj*EXK98?YNL86=ZXt2oOgVDrT zaH_cRGjTz?tJ>sZ$>3ROYBES`GpMq9mcJ-GdW8YHJ3Vz~YIlF|nZe+?{%yN@g+aD! zliVZF>Sa4g!Kt(A6{Yk9)s3K14^Ee;xjVQW|If36E3%fl^HAR4hLhOm2*%F_6H}=Y z6AK9!g^Tg-i!b!-zZidkX*D;@iYY0LDi+n|Wdc4|cmygWCoAd4ST(BUR62l|YWtA7 z5rt)wsUwFd13@L}AO#F~^ipgy1zRZCO2IY?+7X0Y5-nJX*(?}moTN)KTrgusk?pB- zq4;QOil)toYSvH!m1$CJG!{P}lN3iqYVkrHB55b8)^4lzRO&AcX6aBlX(Mq`ss^@_ zA}vI;VU_^nxF6z|O1nz+U2%uKXbF~Rvh+-2G}EN zuxr!QwKBa4>2)we1OS8f-8RU+1GC5S{+3+Z4!Lc|P3hLQY}zyEfd>^q1%z8VA$ z&nlAzvro*Az$PohCOejOt;@N(WLH<#70S7`%dYJ=pUk>;&JN^#b&EYYUx)1LSQ=dR zE+5GHda2T8*|!OORt=?hc-Lk7CHumrEZ?#ekonG~SHJsIuJ6!$eTQzJ&h|YEg~M;1 zJw~Pz_ksg(D-8y`A!FX4GT_5B!1lzHG#Q}L%{&$ z;x*!A2tct-#1iAFb4kTrWEK&sB{rsOf@>mLWB@?R46YO@)*$^ORBr+uCJ9RbY`aTV zwY;Z3-8TO-vxi|#lih7gN92xv+1-E3D!cd3K7GgSzdSKNk#o1o?$#yS^5CsN#@(8s zTOZo0^7X;x(YH^?_1os$c~5i3vp&PGzw7X29E}-sBb)0Lh7dFL0>At8+@CI7Vd^mK z!8mEmyfifg{-`c?=-aZyk}V5XXwz`9rY@8S)7G(BoxY0)#1w`pDvbrnDjTKUb1+&4 z+^{K+ZJY66JQ<{|neO1hbJ4^&r2yR}rh>)$%xDoZE5^h$wssO1Fl$relufa{9F=Gf z9{ctt?Cd8o{9I6F9G=MNx!CB7n9oFnDwm+;a@90HgyvvIvDx1^|Fh@kMi-u)KbN!B z(#|XIx$L>*$vaOyFd6LvmR5bkoM~a}{NZ_5-r=4-swyRh;#R~&^g;}(2{U4x#g_LB zti_6DA~G3$WwbPOXv>Zto07jqgem2u@KRS-9heOKZ37w)8@Cx%mkDiRf)!XCig(FC z2K61h6sz`;HHIy~K?Yi9qOpfpHv_PmvEYZr)#>WBdiUX8oPdR!DaEBH)F%#XtdRYT zq&np{f2RBC}jxaLyjgn{96#ef{X%OIN%LCl*E) zPGqnd=4pT5ybf%8Ve7&RAGq4)40+yhS(q1KOH-=6dHnn9(pwjIFL`BO`#qDvyY8;9 z1tuPE>)rNs%P+ls;pVgQ`Wg2m)*Ett&S;60Po<^IIPW0T*+ zx!YlDa(XZCzO;KmOrMb*Z5eZ$N>YxoXet_+n2KVHK=BmC9rHPX3(5)$4<=6~rngk^ zTp}Z4;JXb8BWyk<{LEq4Ll1wSk%kfA%c94sHR4KC!s?jbgp&PuklJ2#5V$^Fg1Upw zRkv};+SVzlwT*gp0H#X(*#EK&>eyVc;m?`rne^@f}J;8#)NR1HmKQgx>1V_ zzlFQD-KIhC?bQd|4~kTxofD3VIy4V^?sL}RJl&?X6n2(sb`F1^g!c{pBv$z~WKd=J z6vaw{T)KkG8zyO*QeUEg#Cpgp(R^Z`cWNvK%$JgqDvD(|bT4HicndojV;GA0Tnw&4 zrsyd*3T1&TwWSt0n^0<^=f@f11VbVJ zG3K2FA2d5P2S+8KPwJ3jW8I8xVdl8$bCKAVq~d^O4pB0nCl&8Q6-i4}&`cg_(k~#I zbRj4KUW%wJtwS?jQBJ_!F%1D1d&Hd4qyX4xgtgvCZ^V=J5Y|IFK;Ct0O=IjuQwM|#!i}5AHl5t^jDfUkM?f6a8 zP2W;NuIYm{fN#n1odB$P+OC@)Ngt72oy(1~vrm(%c3+OKlli)I_4S5p4Il6w_xzl% z`9Tc_d!zj=d!~I?wr+RETzvw7cn`ff*M|KXhf8QK3#=2T)|>X_K6J`S0H0<5II&KF!dU+fOR*%N{hi%VUgIP%T~$lK!-V!gFyvD z&}XZzE}7kZXD-wC}PT0XcuxYR5A zwm{jdZn!$KI099my60x29N0H|BJZe!4$5e91vwHkWAO3U5CYMo+v>4e;zR0N@Ut?> z|Cgkw!Fz<}WYqz=NQl6ux6h-I59ASzEjxy<>XRsHD9J1>F90SXf)c^%ksfPXr=-?) zm@i6y0WIhdUs3w&-z!Uh#~RY#rYrpgM$nGWB>ioJI?`P0Jurhs7c=k`Ijnj)hTqT| zXHv?`B~SdfBC;rta9w%qwHP6bB3f`(%!PhV4Oh56HK(OmC;zEAt;L+yGD10BLMFrS z7O_eRd1(D61FiO~+BxB_sH6PU!EzZb2FJ-#8x=il*Y20k7$28bf0+78$f>h@UyD85 z8?98rbGl2*750=rZub_;IVfrYx639sZ9BGO<`I3}^E740c zQ7I<m}TZ&DiT z54EP4q{-wXwmn)_Fkj?Vfz<`gd-{8*gZNi+3BhW_Qd_Px(<$n)MJqzSyw zx8!*MoO+qBPj?W)k{dRCz<1xPCydj?f$<-?awOBzd-H%C*fD#8jD}yj^reMY;gsOo z2p6nXpohFK>#x+O_bhLfy*;x>0rLn1Ej@o5yO9V<6G#Z)8a9yK zKtk}#&;#)^Ku6Pk75*LOQf;LOWQFmsqO6W3qMg$=4mHW7onh%AoXw7KX#-IYjc7G61K!XLgW^(Oa{~Fm z)hnrj{VS*l6wIvcs%{SFZ8+Vl3ON%{B*zppbIbLT{d_ETHaatbGbJ%iT`|W~u}MIG z5@L#rw%HP~`1rXqaQskJZ3FyoHM>1L@zoxv`t^$359J^|w0Yz~l3FG7(ni_$HZeNSb796St^RTX{E&q*o=<@A1PP6FqY)BdSBZN;3n(Mivqv^tlb3Z#`v_$yl2@Nh|plyYD7 zN1jwTjOXB{hf7zK?SA;~BScD|JXYa=*4M$!57(iIl&~-CXD@CDC#7#7o1WolNcwkl zhXjoD&k%&_zK}M>B-FpB%5?N9={t1ew2X|Qm|61lKu^){U-#;Y6L~Jw(_J}w|ts3gL7yKhUoZc zQhFCfewD0(x+zw8hLb}&V6mxa@h*bkuOJtEx#QDQ6Y6Gm1*>5UhVd50Wo{#@YPqqfpck)nQ!#qi zy|lelOk&(&m3OyN@@^*eP0RS7QS1A}2~vnaaf3Hg&Fpa%(&FVTz7g+ArSTTlwc{V~ zA?!zEAy~=tI({jhb}qNc?rxa7v4zr4Z#R4+cKwC7UigjbZ#VsV)6Ma0Xn(c^zTHn` zg#$TZKo$nF!r{s-l?wL=+vvHdG>&cS(#)YdCFSg^xOkHruF|73vE-^xo(ZC{tmF_P zD^mhp(GRmj^|E95$w$|SH;~C6f#$9FKF|1@V;6q(wg368$I9Xd9tm4fn?&lGL?p_# zEZ~IyhZOfwz*s!(MJh(PBoi)Uj9TT-jT(PG8-FD>1_u-xh~oT6GV=b2rsn`6m>wr5 z_;c0k3`{QGS;CT07+*(dKfYVuC^)v3j)r8C)@5Z*S~gPjMb zwZIuIR#j0zO{K$BjNKnZAa&kHu8J&nGlkm|Zy_N-v zt9U0t_Q>Q7eAmON@R3LYU-O7W6d@9sWFK5W+!=|$syR`-%h2sTb}bCm5Q$Isv+4wNfLxd@%H z8I!UUqktYHN!aiWF7GQE_GI6Nta(GebJOgJh5ocZJ({jteEb9RCbc}38pW@4lqSH5 zy7VSM9cKCjm1v@Y9nsywz+E=%m6#Rh7n$Jv4O-O%2*B??qL?zv!fYBHu zh#%7BDoXhgrBOhgPd<+bD8=Ldxl{Qn3h?^=#NtHO+`gK6z+hjSe=TcnELxMZ=K6{l z?Y7T%XU#SFs_LuGMQ7IBl&=eYU=FD?M(rtPat=}g8gE(g#;~^ZCMV+2=@tH+WHcg* zy#Mpy!=L5|sOPh($v{ofV6U7S642ek95KKsI552_N%V*_``Qe}Pt~19s2*$JfMRY< z5|j5@sg?i2YE!-Undd3$Rh>hyWhmf$N|AYQ%dXy=B%BJ$K1UE746^hvz&Qt7d- zk1fbyg&k#6nF!{H9ywU;I(CWDJsmV7I^qB7sOY?&-|{oBPw{dF4bO>s@pC`Lviu|Du(wLM2+Q zq`gPapF_`XwMM^)GK6sXT{2+Vuu)qR?2WXt*EFP8g~m!|)0_ub44-*W*P~di zzIm3YN_bI!RoI&(>a8N)jB^d0E372NIVjbG98{lJyW3Q>F4g0b9q?4GE6?-kby_;U z`NN0pVIk}YyTWet3Gal7^}7Y&u7;BXs&Y&iO0lY#Km`jlAriq>wht0ANKa6F69q8} z_9GaQzDCjaDPTf@hDUWBfr(#zc>%7Y>=sL4_nG3qm+mvOZo|# z`&$x(DXeu^9o!u6dgH?D7jm{{+18x1waT_uDCToca;@<#4E@6Xf{C-c=AN5(E)1o+ z7N38QU-yyASNwd<*wW5xlUY}H!NyrUpEx;u;Gv-<{mQl8WiK{Qmxpd{xbfv%2jrgp z3$}c94c@?R3|`;&*1l!yt>f~}Be|W=$oOx3HpAEDYa8dBWV1hZ<=7op)k4iLoKCl8 zoK1LRW8uQ$=A5rf_I2g`wQ2j}<2ipw_J{CpTeUD}hY7!`ntd+ARZHKX$T(XbS%!T$ zL*V?iYPDP7)pIkTy4KY_!OVm{%y|<&n<(z zVbAT|xr3sN|HdIa_nqpxbRZqO)|ji`DA#Yy)^{(zB-d}rR&T{#Tk+iflk)zP&?)MKIp-&^Xf%E^@xhiuciQ&S3x-X>rDnNlbH2Jh-`Kk3lpA~U zjh*?LhV;(GNxYZX?v|Y`a|h;LTIk02Mz(Idxi#bLnLD%qN5AHDYU%K`uik0gN!jpN^l`a1L?h2H?#Gxn;kCxPIU9!E6U3OD&ANN$OdmOF;X|GG<0}Rs`tXsj5l)Kq zkM|mp*w}Kf+Cln{qp1Ic3_~F!zRO@cIWROdG%ye{4vfF~!TP^&^m_M)%mX+=yQF!x zGO*i3%aRb01VF?|777^Tvrz_GaMKs0Hz<`Lp;E;@PuHt{363E!PjWOqeW7<0U;o1! z4cG$XDKn{wia_6(>rDY!CM6O(im#l|Tm1AOb4XSFPzjqU!HO)bt;l!`AFiXX!t|wrGU&1Ou1Fr7lW*L z8AD<`Vik&3_mxt0Aa7v{SJRMh4Bm4!8G1?u5x=V<^nwL(*i7gRvPEzry5i~65WwdJJ6(ZP{RIcb zoLn_N;zuzzXLZc=U*0#rFI@*GzwWoY<@(Jx>uz=aPWNwj%Uhq$IG&a*hkxRsy!+P~ zHjyt2)PaHt@w@6rxe6A<(Xhk6(0_H`;yzgj7HpKnbI1UlpkSvff%66m4vIO+SrIK3 zTy#~@gX-%|*P51Q-g)iq*D{XnvZe2VhjL&)+-Q5F{`LBt!7CfQ1!JQjP#OTlO9Oy7 z`T@U2jM;q~uI`&SqmSh;>|0ewz~C>{fq1D7#4GB+Rk03A-^F!pg5x82JD3ZYO1`Nr zAMCtWyVJ0{R2SlO=t zT%fMtp;#4Wzu={qk0#Ggu>hr2Q>=z-Xem%X3Uzq?yxBPyCf8*dF4UhhUa*@jO?d6Y zYWzt()t54~8Sp(f_W8Af3GutCqF1mWj;F)RUudXcqbr`P@)ztB!;9AH6g%k3sf~e) zt|}(HOmZ2=X4$gkfrm0YV}OG;-X_22@EDp(1c!L3X~fYqd@5-1*ysvP2MTtI37p_5 zP)RgfqHRW58AqpVS^pCk?)`O#)!-@kT!wb|aD(Uy9HKY3v(xQScFGKfA5dx`*kq`t z0cc!oRKGu$^LESL?t%&FcdKi!{`BHc)7q?jUm0%7~-oNqon5yE<+n@4QJLG&TF1Par!J}IP z(!t=$BW6KDp}}sj7VQfK6XK=CggDwLl9i3Fz+W)u6cdVz{(%Em59*AD?zHfL!|8z_ z7&b2Kp-0(RXyKaM!9eqE>+yx5dyc0KXi0F-B^RE$dVKMC#uJo3q|snhd%Z{^f%0FpZW`$INPS*8H&2tTDKJD@SNe09G>~UbIvHO*Ifn2 zU;p;HGrFWG{~%-jaA0hsWq*f=6i^c34SCP8abNf}Vsw$(Jd`iu$%X7n9iEMB6iIP+eg`GS4m8#G$a=g_)DvA13VQP2n zw@Orr+~-Q%Xu8`q)E<$ke-Ep`_Y!qwvuFTH>5|*A3BU} zwCv9iOo59PPx}In_4G*d3}!Mb(mz!q!?T$K-IRJo>K1g{bEWP;*SxA6Uy*uE>aNu5 zQm;zAA@!QnSEOEt-iTJ7E8c3f3P+t27j}mCzZ%A02Rsgvkl||%;=|E;aUSx3zDs%! z?wK3c48oB0SaE$<*enV5dn^)Gf7tV5F05W~ln!&@7H2=pxzHYM3GAwlO#ob>APPS?JYeuG&7MN+HISeelq8zv~Y?)pq(!r;iu!;~f`sDX4AMzw?vN-P`3LCF<!2c%RBx(q+{sD--AB11N);M)?0${7k&y!B{;nOu2 z21R220d-!45Y_|fjRG2yRR+3S7U0As?Q`u-M-?8gvT z-S>`WsM>o&)BYoyL5M*Qe%3p~p|jM#88wP!1rOvyUh8{%?bj~+>YGWVLEE3;`s)c07DrTlXa zo`?D9*JW-E5X=)mcWYx|z$Pu=#?CBsFI;W!PC*?fAI<3>%ViX`I6*YPdniDhSk#c9 zQXZre93-Z{h5}6A8QT)mI0Mrp3DY((eP@bkdk)j21Je@Ur0c-94SbtpD|cU3Ujg4F zBlU#?tk=5r66-Vq46uNAz1=craF@syW|SWdm~i_+-tnalN_>nt&OWHsIEj84>{=M) zU7k4a3G&{8^CTk*bjds-qobNHc!K4uWA1XL^5u#Z5dBN$NJeonXd4A|WZW zIHP5}k*M9n4cX5iwCZxsOj)B^s`bXGmdugtIUZeH_=b{12o$@!Z3Ka{StJRh(Uc{h z&ZZK5>+fi`L_kR;yO@jdrX-UAGFc!KNOr8*+FQ$Myy0};DN<2L+D6NM195^>lo@Kv zDFr%Npi_|_YLOdNpO^WeF0SqlQE{TYN|4l9wBVRg`!uo67+%oqi~)P?lk2@0B17_J zd5~T~lWgij0h?|EFK2mT?%=`XX71t z;;rG%&~CjO-8f@GO62o(n(+pSS(cjR1Nx{Vw2Wpt1F`XT9>SCl2*Q@#cq`p-3cl0A zW*d(=317fcmjz9k6|4xkfLX`}_O!G?R*=x8Jylg)UAE_$Aof&FST?o*?6NEESsEAB zFh#!a$pTT-l40M+Mx)P?Jd^S1R9QnM_B?3UCuI%J04>R~tXZFec1c;26Hv{rD+?Jl zufD7y_)7`kRsh^(WsdDtCL5eM@yfjpN;Pi{HCVYVL`**9jPYM$qFz1a@nK9g8&&t5 zsuBJZh?X&P0sn-0X?UxCGDpZ-SVG^VL5t8AR?L0uzzQn;lm<}LZdnBnW~I{f1>YL- zU*W!tCid-RQc);@0`&VAD(+dibYs~fPoPqgpry6jt^ZL%2^}Tw|5HN!0yF&!O__2; zNtwk0iekPjplEFCvMkm_H45@T{zbqtVIgP6$sw1?M#m^lE))iB=w(3huG6Az6x*F_ zBP+L|RB@N9&1qe{x*V8e$CGV@hjP0o8dRLmw%T&-ybFt+zD(s@E-=-`qWYjeNO^vr z^OTFqbhbrBDQjd(T}2$E4_g*cg!5@J<-00IWE4{k#W7Y@hc-!Qn&w~|NR~t8Fs?K% LH$G^5xOVHm?R_j^ diff --git a/mace-bench/3rdparty/mace/mace/modules/__pycache__/wrapper_ops.cpython-313.pyc b/mace-bench/3rdparty/mace/mace/modules/__pycache__/wrapper_ops.cpython-313.pyc deleted file mode 100644 index 246f6bf3a7d1e6a8d3bb5a57898959c04d5d5947..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 7737 zcmcJUU2q#$6@ag#)!*vJvSrJ$)mW14G*z+e)J}h#rlqMHm)Ne7S~e}UXO^|Rww20~ zdUr*}F2k5UK&DeBNtuw82SNu1k{5;t24)x@=nMlhz%Y&dpkZqhN~q)w)@Pr4@D)Xmb4Nq)jZJyFu5*Y;8$dz+uO zq?|M`Z<+2lI<9^mAiX5z?j|YzajUL{*eaRr??K`+nJ2I;gj!P;YCA+y zR?4Mp)IRT+cc!?*P$Olhu6Z})9F-i;a!!_WQ4h zR%XexSyW9tQb~#hLgp&_nGjPe5d!V-O%2a1$r_7I=G;c%96Py~Yu?J(+0feT3u=0e z_SMgn<=Cr)hz=N)-_V}!+NGx^?Q431cj(W(-*3?co6riQF{+Sd_!qY&sS`BgaoB_K zsrkodPtFJ@zIfu))QRV(PEsEzDGp0>E=!82z}nhid(Z&tu_c@kPL{GuGOkBddwFOj zzbXx_it8(4{%l5sR9Z?DX#T8}R`kq*luNG+jnKt{NR>6FtBCw!F)PVK=k#+fKuZh$O4EBqIy~`9KUe+ijO{)Kl= z-|CM2u|H*e@zWX&%f}kDNqJ_wxJ6W=h)2YGh$9xG9J`if{+teB^Wa=2&?SNLE0+_ zZxqFBrKCj=mNHaU@HFJ|Q0x?h#e7;2^yT0d@_9Hkf&fP&57Q4PRE9F_Mzg^|V@soV z-3mI|h9Zcf9YqIF$@3#yT!-4#voUqCA9WAKHcnrJ z7BMyU@Q2)itDt;Xjr4!Wg|2d2T))~K*?8&V(Je0WZ_Z_p=|jNY>i}!)^z9s!)B_gr z+5?shY?&v;&D)!vD*G;aM6wh9yXFDNU~I@y9&OxN?yQCKDV}&W(Cx~dBT!*| zbXro1G$%);jIsg-NUXIi(K#4m}`F8lG8;0k%)s z-fFam_v5)}cxG+r0FB!F$sC>)_B0>=tYy|ZYnug&KFamOc~8#7Z8{hSp_S$ZUqe^| z{eh@?d&ae#&34U z)ks1;7`y9nwEN4%;q(71Kmu(S*4|!Iy&bAIq=w__!Ngs+J;0ZV-OcOg9=F;J+ssVz zhL=VlO`|AcD0)%!p}+~D=zG$B6u4t`Ts?^?^miz}rv_$hdQq^s0{cq7LH^8*mAxd| zuf`6kPaM1J3)}hI!Hz4VmqyEWNS7U?qw~tCOQ*_C%(+O{!Ll1uJPAk29!z;jsOQ6; z;hP;JJJ`BRKhL-Ek7(TWLzFQ7JHZoXOYY6x5;8@jDt1C^K zh|A<2t4R;D4Xk-HIWSq!7h17~?VL6l0>eO0#0LMj9YC+kqzU2v0$g$~e1vjO!u_7r zxF6x1dlK&Y10bC9GZLBDs&lNHT$BnR**D)Ra904mIg3C2QR8_uDHLE z!OS~43YD}2m3{mQ+&Jp?G2^n%8$8CkyJc7MG`*s=f*UL>N=srf3!m7(PN(Bg5uE5J zu>P|sK97Qd`#DTC&>OeAPVZ<`cfSy2T1VYD41~`>k^EZ_fbh_L5&lvs@e&}s6g;vS zJkku<`_v_#ps(sU@=65<%YU ze_e+v;P*KZ{LWhNgDo8&PsV-wAwPpYQbD%?%v;?R|F{|y1yW5u2Lkc`V%bj~830cd z{8F?4t-W9Y+Cz-_c1YJ)0L-}zd*H?#PdX!4l}ov@2XkJu4n9oz(K@tXDp0i!_~tSJ zp<4$WNu71@g94xJZ-Cg%I;6Z0(l&rqaDr7ZoL+B=H=SMx41q=PY2ooA#7-a{kb}#b zv{<(n@2tT{*OT|!hdtWp)?$we$;P8ne@fDL!89Ho63i(4d^flE7XW+5!p~6dX%v3I zYTVB#IQBFOUw;5b!SNZ5Li6_fHww+$SB-+BsZro&G0tcmb?)DBh88w-w3+M7HBg)m z!@^SZQt$Pb+5Ht+O;4dfcew$}48Cz^Q|{JpCK)$V1-UiE?b~B}8vOk#w3L4W0{Gj0 zU;d7l661`&Bb&jI{j#=My$1fGZ{6ZT)Vct*?F_Yv&6Y$JwYr1taUf?MwVb7+ICIBP zH2BGEUQyujW|w{h<*Pt7(@$=R)KC)AHI&3$m6Dj_k&7Nod6A1gO!=!^#CMem2%U@A zaUB&Ka4`1|8-`1=!J&A#vllzZaum#oJ9@VED^|LOz4 z-~aj6U*i)uyZzgz+hIl<9(K#O%$c6Y4BbOS8$R>K~{yRC^QI5l@A zdX90TN)qElZ4dx5evAOObD2sKMV|(u;WV?H!GIA)Sn2biqhlZ-99az+oIx#C2N$ao z(>Mm&iUPkqWoX4fLIYYgM}`Uz-q%rm5-W^@h&y!7HtC>gOoF0+ru8*D29*&0C-s0X>_LD~4x2IH5;_NBPdG#{&|Kv-KIOfR#a5HvrA zI;+1<>0z<}80VxJK2$0Ya%=p*4~| torch.Tensor: # [n_nodes, irreps] - return self.linear(node_attrs) - - -@compile_mode("script") -class LinearReadoutBlock(torch.nn.Module): - def __init__( - self, - irreps_in: o3.Irreps, - irrep_out: o3.Irreps = o3.Irreps("0e"), - cueq_config: Optional[CuEquivarianceConfig] = None, - ): - super().__init__() - self.linear = Linear( - irreps_in=irreps_in, irreps_out=irrep_out, cueq_config=cueq_config - ) - - def forward( - self, - x: torch.Tensor, - heads: Optional[torch.Tensor] = None, # pylint: disable=unused-argument - ) -> torch.Tensor: # [n_nodes, irreps] # [..., ] - return self.linear(x) # [n_nodes, 1] - - -@simplify_if_compile -@compile_mode("script") -class NonLinearReadoutBlock(torch.nn.Module): - def __init__( - self, - irreps_in: o3.Irreps, - MLP_irreps: o3.Irreps, - gate: Optional[Callable], - irrep_out: o3.Irreps = o3.Irreps("0e"), - num_heads: int = 1, - cueq_config: Optional[CuEquivarianceConfig] = None, - ): - super().__init__() - self.hidden_irreps = MLP_irreps - self.num_heads = num_heads - self.linear_1 = Linear( - irreps_in=irreps_in, irreps_out=self.hidden_irreps, cueq_config=cueq_config - ) - self.non_linearity = nn.Activation(irreps_in=self.hidden_irreps, acts=[gate]) - self.linear_2 = Linear( - irreps_in=self.hidden_irreps, irreps_out=irrep_out, cueq_config=cueq_config - ) - - def forward( - self, x: torch.Tensor, heads: Optional[torch.Tensor] = None - ) -> torch.Tensor: # [n_nodes, irreps] # [..., ] - x = self.non_linearity(self.linear_1(x)) - if hasattr(self, "num_heads"): - if self.num_heads > 1 and heads is not None: - x = mask_head(x, heads, self.num_heads) - return self.linear_2(x) # [n_nodes, len(heads)] - - -@compile_mode("script") -class LinearDipoleReadoutBlock(torch.nn.Module): - def __init__( - self, - irreps_in: o3.Irreps, - dipole_only: bool = False, - cueq_config: Optional[CuEquivarianceConfig] = None, - ): - super().__init__() - if dipole_only: - self.irreps_out = o3.Irreps("1x1o") - else: - self.irreps_out = o3.Irreps("1x0e + 1x1o") - self.linear = Linear( - irreps_in=irreps_in, irreps_out=self.irreps_out, cueq_config=cueq_config - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: # [n_nodes, irreps] # [..., ] - return self.linear(x) # [n_nodes, 1] - - -@compile_mode("script") -class NonLinearDipoleReadoutBlock(torch.nn.Module): - def __init__( - self, - irreps_in: o3.Irreps, - MLP_irreps: o3.Irreps, - gate: Callable, - dipole_only: bool = False, - cueq_config: Optional[CuEquivarianceConfig] = None, - ): - super().__init__() - self.hidden_irreps = MLP_irreps - if dipole_only: - self.irreps_out = o3.Irreps("1x1o") - else: - self.irreps_out = o3.Irreps("1x0e + 1x1o") - irreps_scalars = o3.Irreps( - [(mul, ir) for mul, ir in MLP_irreps if ir.l == 0 and ir in self.irreps_out] - ) - irreps_gated = o3.Irreps( - [(mul, ir) for mul, ir in MLP_irreps if ir.l > 0 and ir in self.irreps_out] - ) - irreps_gates = o3.Irreps([mul, "0e"] for mul, _ in irreps_gated) - self.equivariant_nonlin = nn.Gate( - irreps_scalars=irreps_scalars, - act_scalars=[gate for _, ir in irreps_scalars], - irreps_gates=irreps_gates, - act_gates=[gate] * len(irreps_gates), - irreps_gated=irreps_gated, - ) - self.irreps_nonlin = self.equivariant_nonlin.irreps_in.simplify() - self.linear_1 = Linear( - irreps_in=irreps_in, irreps_out=self.irreps_nonlin, cueq_config=cueq_config - ) - self.linear_2 = Linear( - irreps_in=self.hidden_irreps, - irreps_out=self.irreps_out, - cueq_config=cueq_config, - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: # [n_nodes, irreps] # [..., ] - x = self.equivariant_nonlin(self.linear_1(x)) - return self.linear_2(x) # [n_nodes, 1] - - -@compile_mode("script") -class AtomicEnergiesBlock(torch.nn.Module): - atomic_energies: torch.Tensor - - def __init__(self, atomic_energies: Union[np.ndarray, torch.Tensor]): - super().__init__() - # assert len(atomic_energies.shape) == 1 - - self.register_buffer( - "atomic_energies", - torch.tensor(atomic_energies, dtype=torch.get_default_dtype()), - ) # [n_elements, n_heads] - - def forward( - self, x: torch.Tensor # one-hot of elements [..., n_elements] - ) -> torch.Tensor: # [..., ] - return torch.matmul(x, torch.atleast_2d(self.atomic_energies).T) - - def __repr__(self): - formatted_energies = ", ".join( - [ - "[" + ", ".join([f"{x:.4f}" for x in group]) + "]" - for group in torch.atleast_2d(self.atomic_energies) - ] - ) - return f"{self.__class__.__name__}(energies=[{formatted_energies}])" - - -@compile_mode("script") -class RadialEmbeddingBlock(torch.nn.Module): - def __init__( - self, - r_max: float, - num_bessel: int, - num_polynomial_cutoff: int, - radial_type: str = "bessel", - distance_transform: str = "None", - ): - super().__init__() - if radial_type == "bessel": - self.bessel_fn = BesselBasis(r_max=r_max, num_basis=num_bessel) - elif radial_type == "gaussian": - self.bessel_fn = GaussianBasis(r_max=r_max, num_basis=num_bessel) - elif radial_type == "chebyshev": - self.bessel_fn = ChebychevBasis(r_max=r_max, num_basis=num_bessel) - if distance_transform == "Agnesi": - self.distance_transform = AgnesiTransform() - elif distance_transform == "Soft": - self.distance_transform = SoftTransform() - self.cutoff_fn = PolynomialCutoff(r_max=r_max, p=num_polynomial_cutoff) - self.out_dim = num_bessel - - def forward( - self, - edge_lengths: torch.Tensor, # [n_edges, 1] - node_attrs: torch.Tensor, - edge_index: torch.Tensor, - atomic_numbers: torch.Tensor, - ): - cutoff = self.cutoff_fn(edge_lengths) # [n_edges, 1] - if hasattr(self, "distance_transform"): - edge_lengths = self.distance_transform( - edge_lengths, node_attrs, edge_index, atomic_numbers - ) - radial = self.bessel_fn(edge_lengths) # [n_edges, n_basis] - return radial * cutoff # [n_edges, n_basis] - - -@compile_mode("script") -class EquivariantProductBasisBlock(torch.nn.Module): - def __init__( - self, - node_feats_irreps: o3.Irreps, - target_irreps: o3.Irreps, - correlation: int, - use_sc: bool = True, - num_elements: Optional[int] = None, - cueq_config: Optional[CuEquivarianceConfig] = None, - ) -> None: - super().__init__() - - self.use_sc = use_sc - self.symmetric_contractions = SymmetricContractionWrapper( - irreps_in=node_feats_irreps, - irreps_out=target_irreps, - correlation=correlation, - num_elements=num_elements, - cueq_config=cueq_config, - ) - # Update linear - self.linear = Linear( - target_irreps, - target_irreps, - internal_weights=True, - shared_weights=True, - cueq_config=cueq_config, - ) - self.cueq_config = cueq_config - - def forward( - self, - node_feats: torch.Tensor, - sc: Optional[torch.Tensor], - node_attrs: torch.Tensor, - ) -> torch.Tensor: - use_cueq = False - use_cueq_mul_ir = False - if hasattr(self, "cueq_config"): - if self.cueq_config is not None: - if self.cueq_config.enabled and ( - self.cueq_config.optimize_all or self.cueq_config.optimize_symmetric - ): - use_cueq = True - if self.cueq_config.layout_str == "mul_ir": - use_cueq_mul_ir = True - if use_cueq: - if use_cueq_mul_ir: - node_feats = torch.transpose(node_feats, 1, 2) - index_attrs = torch.nonzero(node_attrs)[:, 1].int() - node_feats = self.symmetric_contractions( - node_feats.flatten(1), - index_attrs, - ) - else: - node_feats = self.symmetric_contractions(node_feats, node_attrs) - if self.use_sc and sc is not None: - return self.linear(node_feats) + sc - return self.linear(node_feats) - - -@compile_mode("script") -class InteractionBlock(torch.nn.Module): - def __init__( - self, - node_attrs_irreps: o3.Irreps, - node_feats_irreps: o3.Irreps, - edge_attrs_irreps: o3.Irreps, - edge_feats_irreps: o3.Irreps, - target_irreps: o3.Irreps, - hidden_irreps: o3.Irreps, - avg_num_neighbors: float, - radial_MLP: Optional[List[int]] = None, - cueq_config: Optional[CuEquivarianceConfig] = None, - ) -> None: - super().__init__() - self.node_attrs_irreps = node_attrs_irreps - self.node_feats_irreps = node_feats_irreps - self.edge_attrs_irreps = edge_attrs_irreps - self.edge_feats_irreps = edge_feats_irreps - self.target_irreps = target_irreps - self.hidden_irreps = hidden_irreps - self.avg_num_neighbors = avg_num_neighbors - if radial_MLP is None: - radial_MLP = [64, 64, 64] - self.radial_MLP = radial_MLP - self.cueq_config = cueq_config - self._setup() - - @abstractmethod - def _setup(self) -> None: - raise NotImplementedError - - def handle_lammps( - self, - node_feats: torch.Tensor, - lammps_class: Optional[Any], - lammps_natoms: Tuple[int, int], - first_layer: bool, - ) -> torch.Tensor: # noqa: D401 – internal helper - if lammps_class is None or first_layer or torch.jit.is_scripting(): - return node_feats - _, n_total = lammps_natoms - pad = torch.zeros( - (n_total, node_feats.shape[1]), - dtype=node_feats.dtype, - device=node_feats.device, - ) - node_feats = torch.cat((node_feats, pad), dim=0) - node_feats = LAMMPS_MP.apply(node_feats, lammps_class) - return node_feats - - def truncate_ghosts( - self, tensor: torch.Tensor, n_real: Optional[int] = None - ) -> torch.Tensor: - """Truncate the tensor to only keep the real atoms in case of presence of ghost atoms during multi-GPU MD simulations.""" - return tensor[:n_real] if n_real is not None else tensor - - @abstractmethod - def forward( - self, - node_attrs: torch.Tensor, - node_feats: torch.Tensor, - edge_attrs: torch.Tensor, - edge_feats: torch.Tensor, - edge_index: torch.Tensor, - ) -> torch.Tensor: - raise NotImplementedError - - -nonlinearities = {1: torch.nn.functional.silu, -1: torch.tanh} - - -@compile_mode("script") -class RealAgnosticInteractionBlock(InteractionBlock): - def _setup(self) -> None: - if not hasattr(self, "cueq_config"): - self.cueq_config = None - # First linear - self.linear_up = Linear( - self.node_feats_irreps, - self.node_feats_irreps, - internal_weights=True, - shared_weights=True, - cueq_config=self.cueq_config, - ) - # TensorProduct - irreps_mid, instructions = tp_out_irreps_with_instructions( - self.node_feats_irreps, - self.edge_attrs_irreps, - self.target_irreps, - ) - self.conv_tp = TensorProduct( - self.node_feats_irreps, - self.edge_attrs_irreps, - irreps_mid, - instructions=instructions, - shared_weights=False, - internal_weights=False, - cueq_config=self.cueq_config, - ) - - # Convolution weights - input_dim = self.edge_feats_irreps.num_irreps - self.conv_tp_weights = nn.FullyConnectedNet( - [input_dim] + self.radial_MLP + [self.conv_tp.weight_numel], - torch.nn.functional.silu, - ) - - # Linear - self.irreps_out = self.target_irreps - self.linear = Linear( - irreps_mid, - self.irreps_out, - internal_weights=True, - shared_weights=True, - cueq_config=self.cueq_config, - ) - - # Selector TensorProduct - self.skip_tp = FullyConnectedTensorProduct( - self.irreps_out, - self.node_attrs_irreps, - self.irreps_out, - cueq_config=self.cueq_config, - ) - self.reshape = reshape_irreps(self.irreps_out, cueq_config=self.cueq_config) - - def forward( - self, - node_attrs: torch.Tensor, - node_feats: torch.Tensor, - edge_attrs: torch.Tensor, - edge_feats: torch.Tensor, - edge_index: torch.Tensor, - lammps_natoms: Tuple[int, int] = (0, 0), - lammps_class: Optional[Any] = None, - first_layer: bool = False, - ) -> Tuple[torch.Tensor, None]: - sender = edge_index[0] - receiver = edge_index[1] - num_nodes = node_feats.shape[0] - n_real = lammps_natoms[0] if lammps_class is not None else None - node_feats = self.linear_up(node_feats) - node_feats = self.handle_lammps( - node_feats, - lammps_class=lammps_class, - lammps_natoms=lammps_natoms, - first_layer=first_layer, - ) - tp_weights = self.conv_tp_weights(edge_feats) - mji = self.conv_tp( - node_feats[sender], edge_attrs, tp_weights - ) # [n_edges, irreps] - message = scatter_sum( - src=mji, index=receiver, dim=0, dim_size=num_nodes - ) # [n_nodes, irreps] - message = self.truncate_ghosts(message, n_real) - node_attrs = self.truncate_ghosts(node_attrs, n_real) - message = self.linear(message) / self.avg_num_neighbors - message = self.skip_tp(message, node_attrs) - return ( - self.reshape(message), - None, - ) # [n_nodes, channels, (lmax + 1)**2] - - -@compile_mode("script") -class RealAgnosticResidualInteractionBlock(InteractionBlock): - def _setup(self) -> None: - if not hasattr(self, "cueq_config"): - self.cueq_config = None - # First linear - self.linear_up = Linear( - self.node_feats_irreps, - self.node_feats_irreps, - internal_weights=True, - shared_weights=True, - cueq_config=self.cueq_config, - ) - # TensorProduct - irreps_mid, instructions = tp_out_irreps_with_instructions( - self.node_feats_irreps, - self.edge_attrs_irreps, - self.target_irreps, - ) - self.conv_tp = TensorProduct( - self.node_feats_irreps, - self.edge_attrs_irreps, - irreps_mid, - instructions=instructions, - shared_weights=False, - internal_weights=False, - cueq_config=self.cueq_config, - ) - - # Convolution weights - input_dim = self.edge_feats_irreps.num_irreps - self.conv_tp_weights = nn.FullyConnectedNet( - [input_dim] + self.radial_MLP + [self.conv_tp.weight_numel], - torch.nn.functional.silu, # gate - ) - - # Linear - self.irreps_out = self.target_irreps - self.linear = Linear( - irreps_mid, - self.irreps_out, - internal_weights=True, - shared_weights=True, - cueq_config=self.cueq_config, - ) - - # Selector TensorProduct - self.skip_tp = FullyConnectedTensorProduct( - self.node_feats_irreps, - self.node_attrs_irreps, - self.hidden_irreps, - cueq_config=self.cueq_config, - ) - self.reshape = reshape_irreps(self.irreps_out, cueq_config=self.cueq_config) - - def forward( - self, - node_attrs: torch.Tensor, - node_feats: torch.Tensor, - edge_attrs: torch.Tensor, - edge_feats: torch.Tensor, - edge_index: torch.Tensor, - lammps_class: Optional[Any] = None, - lammps_natoms: Tuple[int, int] = (0, 0), - first_layer: bool = False, - ) -> Tuple[torch.Tensor, torch.Tensor]: - sender = edge_index[0] - receiver = edge_index[1] - num_nodes = node_feats.shape[0] - n_real = lammps_natoms[0] if lammps_class is not None else None - sc = self.skip_tp(node_feats, node_attrs) - node_feats = self.linear_up(node_feats) - node_feats = self.handle_lammps( - node_feats, - lammps_class=lammps_class, - lammps_natoms=lammps_natoms, - first_layer=first_layer, - ) - tp_weights = self.conv_tp_weights(edge_feats) - mji = self.conv_tp( - node_feats[sender], edge_attrs, tp_weights - ) # [n_edges, irreps] - message = scatter_sum( - src=mji, index=receiver, dim=0, dim_size=num_nodes - ) # [n_nodes, irreps] - message = self.truncate_ghosts(message, n_real) - node_attrs = self.truncate_ghosts(node_attrs, n_real) - sc = self.truncate_ghosts(sc, n_real) - message = self.linear(message) / self.avg_num_neighbors - return ( - self.reshape(message), - sc, - ) # [n_nodes, channels, (lmax + 1)**2] - - -@compile_mode("script") -class RealAgnosticDensityInteractionBlock(InteractionBlock): - def _setup(self) -> None: - if not hasattr(self, "cueq_config"): - self.cueq_config = None - # First linear - self.linear_up = Linear( - self.node_feats_irreps, - self.node_feats_irreps, - internal_weights=True, - shared_weights=True, - cueq_config=self.cueq_config, - ) - # TensorProduct - irreps_mid, instructions = tp_out_irreps_with_instructions( - self.node_feats_irreps, - self.edge_attrs_irreps, - self.target_irreps, - ) - self.conv_tp = TensorProduct( - self.node_feats_irreps, - self.edge_attrs_irreps, - irreps_mid, - instructions=instructions, - shared_weights=False, - internal_weights=False, - cueq_config=self.cueq_config, - ) - - # Convolution weights - input_dim = self.edge_feats_irreps.num_irreps - self.conv_tp_weights = nn.FullyConnectedNet( - [input_dim] + self.radial_MLP + [self.conv_tp.weight_numel], - torch.nn.functional.silu, - ) - - # Linear - self.irreps_out = self.target_irreps - self.linear = Linear( - irreps_mid, - self.irreps_out, - internal_weights=True, - shared_weights=True, - cueq_config=self.cueq_config, - ) - - # Selector TensorProduct - self.skip_tp = FullyConnectedTensorProduct( - self.irreps_out, - self.node_attrs_irreps, - self.irreps_out, - cueq_config=self.cueq_config, - ) - - # Density normalization - self.density_fn = nn.FullyConnectedNet( - [input_dim] - + [ - 1, - ], - torch.nn.functional.silu, - ) - # Reshape - self.reshape = reshape_irreps(self.irreps_out, cueq_config=self.cueq_config) - - def forward( - self, - node_attrs: torch.Tensor, - node_feats: torch.Tensor, - edge_attrs: torch.Tensor, - edge_feats: torch.Tensor, - edge_index: torch.Tensor, - lammps_class: Optional[Any] = None, - lammps_natoms: Tuple[int, int] = (0, 0), - first_layer: bool = False, - ) -> Tuple[torch.Tensor, None]: - sender = edge_index[0] - receiver = edge_index[1] - num_nodes = node_feats.shape[0] - n_real = lammps_natoms[0] if lammps_class is not None else None - node_feats = self.linear_up(node_feats) - node_feats = self.handle_lammps( - node_feats, - lammps_class=lammps_class, - lammps_natoms=lammps_natoms, - first_layer=first_layer, - ) - tp_weights = self.conv_tp_weights(edge_feats) - edge_density = torch.tanh(self.density_fn(edge_feats) ** 2) - mji = self.conv_tp( - node_feats[sender], edge_attrs, tp_weights - ) # [n_edges, irreps] - density = scatter_sum( - src=edge_density, index=receiver, dim=0, dim_size=num_nodes - ) # [n_nodes, 1] - message = scatter_sum( - src=mji, index=receiver, dim=0, dim_size=num_nodes - ) # [n_nodes, irreps] - message = self.truncate_ghosts(message, n_real) - node_attrs = self.truncate_ghosts(node_attrs, n_real) - density = self.truncate_ghosts(density, n_real) - message = self.linear(message) / (density + 1) - message = self.skip_tp(message, node_attrs) - return ( - self.reshape(message), - None, - ) # [n_nodes, channels, (lmax + 1)**2] - - -@compile_mode("script") -class RealAgnosticDensityResidualInteractionBlock(InteractionBlock): - def _setup(self) -> None: - if not hasattr(self, "cueq_config"): - self.cueq_config = None - - # First linear - self.linear_up = Linear( - self.node_feats_irreps, - self.node_feats_irreps, - internal_weights=True, - shared_weights=True, - cueq_config=self.cueq_config, - ) - # TensorProduct - irreps_mid, instructions = tp_out_irreps_with_instructions( - self.node_feats_irreps, - self.edge_attrs_irreps, - self.target_irreps, - ) - self.conv_tp = TensorProduct( - self.node_feats_irreps, - self.edge_attrs_irreps, - irreps_mid, - instructions=instructions, - shared_weights=False, - internal_weights=False, - cueq_config=self.cueq_config, - ) - - # Convolution weights - input_dim = self.edge_feats_irreps.num_irreps - self.conv_tp_weights = nn.FullyConnectedNet( - [input_dim] + self.radial_MLP + [self.conv_tp.weight_numel], - torch.nn.functional.silu, # gate - ) - - # Linear - self.irreps_out = self.target_irreps - self.linear = Linear( - irreps_mid, - self.irreps_out, - internal_weights=True, - shared_weights=True, - cueq_config=self.cueq_config, - ) - - # Selector TensorProduct - self.skip_tp = FullyConnectedTensorProduct( - self.node_feats_irreps, - self.node_attrs_irreps, - self.hidden_irreps, - cueq_config=self.cueq_config, - ) - - # Density normalization - self.density_fn = nn.FullyConnectedNet( - [input_dim] - + [ - 1, - ], - torch.nn.functional.silu, - ) - - # Reshape - self.reshape = reshape_irreps(self.irreps_out, cueq_config=self.cueq_config) - - def forward( - self, - node_attrs: torch.Tensor, - node_feats: torch.Tensor, - edge_attrs: torch.Tensor, - edge_feats: torch.Tensor, - edge_index: torch.Tensor, - lammps_class: Optional[Any] = None, - lammps_natoms: Tuple[int, int] = (0, 0), - first_layer: bool = False, - ) -> Tuple[torch.Tensor, torch.Tensor]: - sender = edge_index[0] - receiver = edge_index[1] - num_nodes = node_feats.shape[0] - n_real = lammps_natoms[0] if lammps_class is not None else None - sc = self.skip_tp(node_feats, node_attrs) - node_feats = self.linear_up(node_feats) - node_feats = self.handle_lammps( - node_feats, - lammps_class=lammps_class, - lammps_natoms=lammps_natoms, - first_layer=first_layer, - ) - tp_weights = self.conv_tp_weights(edge_feats) - edge_density = torch.tanh(self.density_fn(edge_feats) ** 2) - mji = self.conv_tp( - node_feats[sender], edge_attrs, tp_weights - ) # [n_edges, irreps] - density = scatter_sum( - src=edge_density, index=receiver, dim=0, dim_size=num_nodes - ) # [n_nodes, 1] - message = scatter_sum( - src=mji, index=receiver, dim=0, dim_size=num_nodes - ) # [n_nodes, irreps] - message = self.truncate_ghosts(message, n_real) - node_attrs = self.truncate_ghosts(node_attrs, n_real) - density = self.truncate_ghosts(density, n_real) - sc = self.truncate_ghosts(sc, n_real) - message = self.linear(message) / (density + 1) - return ( - self.reshape(message), - sc, - ) # [n_nodes, channels, (lmax + 1)**2] - - -@compile_mode("script") -class RealAgnosticAttResidualInteractionBlock(InteractionBlock): - def _setup(self) -> None: - if not hasattr(self, "cueq_config"): - self.cueq_config = None - self.node_feats_down_irreps = o3.Irreps("64x0e") - # First linear - self.linear_up = Linear( - self.node_feats_irreps, - self.node_feats_irreps, - internal_weights=True, - shared_weights=True, - cueq_config=self.cueq_config, - ) - # TensorProduct - irreps_mid, instructions = tp_out_irreps_with_instructions( - self.node_feats_irreps, - self.edge_attrs_irreps, - self.target_irreps, - ) - self.conv_tp = TensorProduct( - self.node_feats_irreps, - self.edge_attrs_irreps, - irreps_mid, - instructions=instructions, - shared_weights=False, - internal_weights=False, - cueq_config=self.cueq_config, - ) - - # Convolution weights - self.linear_down = Linear( - self.node_feats_irreps, - self.node_feats_down_irreps, - internal_weights=True, - shared_weights=True, - cueq_config=self.cueq_config, - ) - input_dim = ( - self.edge_feats_irreps.num_irreps - + 2 * self.node_feats_down_irreps.num_irreps - ) - self.conv_tp_weights = nn.FullyConnectedNet( - [input_dim] + 3 * [256] + [self.conv_tp.weight_numel], - torch.nn.functional.silu, - ) - - # Linear - self.irreps_out = self.target_irreps - self.linear = Linear( - irreps_mid, - self.irreps_out, - internal_weights=True, - shared_weights=True, - cueq_config=self.cueq_config, - ) - - self.reshape = reshape_irreps(self.irreps_out, cueq_config=self.cueq_config) - - # Skip connection. - self.skip_linear = Linear( - self.node_feats_irreps, self.hidden_irreps, cueq_config=self.cueq_config - ) - - # pylint: disable=unused-argument - def forward( - self, - node_attrs: torch.Tensor, - node_feats: torch.Tensor, - edge_attrs: torch.Tensor, - edge_feats: torch.Tensor, - edge_index: torch.Tensor, - lammps_class: Optional[Any] = None, - lammps_natoms: Tuple[int, int] = (0, 0), - first_layer: bool = False, - ) -> Tuple[torch.Tensor, None]: - sender = edge_index[0] - receiver = edge_index[1] - num_nodes = node_feats.shape[0] - sc = self.skip_linear(node_feats) - node_feats_up = self.linear_up(node_feats) - node_feats_down = self.linear_down(node_feats) - augmented_edge_feats = torch.cat( - [ - edge_feats, - node_feats_down[sender], - node_feats_down[receiver], - ], - dim=-1, - ) - tp_weights = self.conv_tp_weights(augmented_edge_feats) - mji = self.conv_tp( - node_feats_up[sender], edge_attrs, tp_weights - ) # [n_edges, irreps] - message = scatter_sum( - src=mji, index=receiver, dim=0, dim_size=num_nodes - ) # [n_nodes, irreps] - message = self.linear(message) / self.avg_num_neighbors - return ( - self.reshape(message), - sc, - ) # [n_nodes, channels, (lmax + 1)**2] - - -@compile_mode("script") -class ScaleShiftBlock(torch.nn.Module): - def __init__(self, scale: float, shift: float): - super().__init__() - self.register_buffer( - "scale", - torch.tensor(scale, dtype=torch.get_default_dtype()), - ) - self.register_buffer( - "shift", - torch.tensor(shift, dtype=torch.get_default_dtype()), - ) - - def forward(self, x: torch.Tensor, head: torch.Tensor) -> torch.Tensor: - return ( - torch.atleast_1d(self.scale)[head] * x + torch.atleast_1d(self.shift)[head] - ) - - def __repr__(self): - formatted_scale = ( - ", ".join([f"{x:.4f}" for x in self.scale]) - if self.scale.numel() > 1 - else f"{self.scale.item():.4f}" - ) - formatted_shift = ( - ", ".join([f"{x:.4f}" for x in self.shift]) - if self.shift.numel() > 1 - else f"{self.shift.item():.4f}" - ) - return f"{self.__class__.__name__}(scale={formatted_scale}, shift={formatted_shift})" +########################################################################################### +# Elementary Block for Building O(3) Equivariant Higher Order Message Passing Neural Network +# Authors: Ilyes Batatia, Gregor Simm +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +from abc import abstractmethod +from typing import Any, Callable, List, Optional, Tuple, Union + +import numpy as np +import torch.nn.functional +from e3nn import nn, o3 +from e3nn.util.jit import compile_mode + +from mace.modules.wrapper_ops import ( + CuEquivarianceConfig, + FullyConnectedTensorProduct, + Linear, + SymmetricContractionWrapper, + TensorProduct, +) +from mace.tools.compile import simplify_if_compile +from mace.tools.scatter import scatter_sum +from mace.tools.utils import LAMMPS_MP + +from .irreps_tools import mask_head, reshape_irreps, tp_out_irreps_with_instructions +from .radial import ( + AgnesiTransform, + BesselBasis, + ChebychevBasis, + GaussianBasis, + PolynomialCutoff, + SoftTransform, +) + + +@compile_mode("script") +class LinearNodeEmbeddingBlock(torch.nn.Module): + def __init__( + self, + irreps_in: o3.Irreps, + irreps_out: o3.Irreps, + cueq_config: Optional[CuEquivarianceConfig] = None, + ): + super().__init__() + self.linear = Linear( + irreps_in=irreps_in, irreps_out=irreps_out, cueq_config=cueq_config + ) + + def forward( + self, + node_attrs: torch.Tensor, + ) -> torch.Tensor: # [n_nodes, irreps] + return self.linear(node_attrs) + + +@compile_mode("script") +class LinearReadoutBlock(torch.nn.Module): + def __init__( + self, + irreps_in: o3.Irreps, + irrep_out: o3.Irreps = o3.Irreps("0e"), + cueq_config: Optional[CuEquivarianceConfig] = None, + ): + super().__init__() + self.linear = Linear( + irreps_in=irreps_in, irreps_out=irrep_out, cueq_config=cueq_config + ) + + def forward( + self, + x: torch.Tensor, + heads: Optional[torch.Tensor] = None, # pylint: disable=unused-argument + ) -> torch.Tensor: # [n_nodes, irreps] # [..., ] + return self.linear(x) # [n_nodes, 1] + + +@simplify_if_compile +@compile_mode("script") +class NonLinearReadoutBlock(torch.nn.Module): + def __init__( + self, + irreps_in: o3.Irreps, + MLP_irreps: o3.Irreps, + gate: Optional[Callable], + irrep_out: o3.Irreps = o3.Irreps("0e"), + num_heads: int = 1, + cueq_config: Optional[CuEquivarianceConfig] = None, + ): + super().__init__() + self.hidden_irreps = MLP_irreps + self.num_heads = num_heads + self.linear_1 = Linear( + irreps_in=irreps_in, irreps_out=self.hidden_irreps, cueq_config=cueq_config + ) + self.non_linearity = nn.Activation(irreps_in=self.hidden_irreps, acts=[gate]) + self.linear_2 = Linear( + irreps_in=self.hidden_irreps, irreps_out=irrep_out, cueq_config=cueq_config + ) + + def forward( + self, x: torch.Tensor, heads: Optional[torch.Tensor] = None + ) -> torch.Tensor: # [n_nodes, irreps] # [..., ] + x = self.non_linearity(self.linear_1(x)) + if hasattr(self, "num_heads"): + if self.num_heads > 1 and heads is not None: + x = mask_head(x, heads, self.num_heads) + return self.linear_2(x) # [n_nodes, len(heads)] + + +@compile_mode("script") +class LinearDipoleReadoutBlock(torch.nn.Module): + def __init__( + self, + irreps_in: o3.Irreps, + dipole_only: bool = False, + cueq_config: Optional[CuEquivarianceConfig] = None, + ): + super().__init__() + if dipole_only: + self.irreps_out = o3.Irreps("1x1o") + else: + self.irreps_out = o3.Irreps("1x0e + 1x1o") + self.linear = Linear( + irreps_in=irreps_in, irreps_out=self.irreps_out, cueq_config=cueq_config + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: # [n_nodes, irreps] # [..., ] + return self.linear(x) # [n_nodes, 1] + + +@compile_mode("script") +class NonLinearDipoleReadoutBlock(torch.nn.Module): + def __init__( + self, + irreps_in: o3.Irreps, + MLP_irreps: o3.Irreps, + gate: Callable, + dipole_only: bool = False, + cueq_config: Optional[CuEquivarianceConfig] = None, + ): + super().__init__() + self.hidden_irreps = MLP_irreps + if dipole_only: + self.irreps_out = o3.Irreps("1x1o") + else: + self.irreps_out = o3.Irreps("1x0e + 1x1o") + irreps_scalars = o3.Irreps( + [(mul, ir) for mul, ir in MLP_irreps if ir.l == 0 and ir in self.irreps_out] + ) + irreps_gated = o3.Irreps( + [(mul, ir) for mul, ir in MLP_irreps if ir.l > 0 and ir in self.irreps_out] + ) + irreps_gates = o3.Irreps([mul, "0e"] for mul, _ in irreps_gated) + self.equivariant_nonlin = nn.Gate( + irreps_scalars=irreps_scalars, + act_scalars=[gate for _, ir in irreps_scalars], + irreps_gates=irreps_gates, + act_gates=[gate] * len(irreps_gates), + irreps_gated=irreps_gated, + ) + self.irreps_nonlin = self.equivariant_nonlin.irreps_in.simplify() + self.linear_1 = Linear( + irreps_in=irreps_in, irreps_out=self.irreps_nonlin, cueq_config=cueq_config + ) + self.linear_2 = Linear( + irreps_in=self.hidden_irreps, + irreps_out=self.irreps_out, + cueq_config=cueq_config, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: # [n_nodes, irreps] # [..., ] + x = self.equivariant_nonlin(self.linear_1(x)) + return self.linear_2(x) # [n_nodes, 1] + + +@compile_mode("script") +class AtomicEnergiesBlock(torch.nn.Module): + atomic_energies: torch.Tensor + + def __init__(self, atomic_energies: Union[np.ndarray, torch.Tensor]): + super().__init__() + # assert len(atomic_energies.shape) == 1 + + self.register_buffer( + "atomic_energies", + torch.tensor(atomic_energies, dtype=torch.get_default_dtype()), + ) # [n_elements, n_heads] + + def forward( + self, x: torch.Tensor # one-hot of elements [..., n_elements] + ) -> torch.Tensor: # [..., ] + return torch.matmul(x, torch.atleast_2d(self.atomic_energies).T) + + def __repr__(self): + formatted_energies = ", ".join( + [ + "[" + ", ".join([f"{x:.4f}" for x in group]) + "]" + for group in torch.atleast_2d(self.atomic_energies) + ] + ) + return f"{self.__class__.__name__}(energies=[{formatted_energies}])" + + +@compile_mode("script") +class RadialEmbeddingBlock(torch.nn.Module): + def __init__( + self, + r_max: float, + num_bessel: int, + num_polynomial_cutoff: int, + radial_type: str = "bessel", + distance_transform: str = "None", + ): + super().__init__() + if radial_type == "bessel": + self.bessel_fn = BesselBasis(r_max=r_max, num_basis=num_bessel) + elif radial_type == "gaussian": + self.bessel_fn = GaussianBasis(r_max=r_max, num_basis=num_bessel) + elif radial_type == "chebyshev": + self.bessel_fn = ChebychevBasis(r_max=r_max, num_basis=num_bessel) + if distance_transform == "Agnesi": + self.distance_transform = AgnesiTransform() + elif distance_transform == "Soft": + self.distance_transform = SoftTransform() + self.cutoff_fn = PolynomialCutoff(r_max=r_max, p=num_polynomial_cutoff) + self.out_dim = num_bessel + + def forward( + self, + edge_lengths: torch.Tensor, # [n_edges, 1] + node_attrs: torch.Tensor, + edge_index: torch.Tensor, + atomic_numbers: torch.Tensor, + ): + cutoff = self.cutoff_fn(edge_lengths) # [n_edges, 1] + if hasattr(self, "distance_transform"): + edge_lengths = self.distance_transform( + edge_lengths, node_attrs, edge_index, atomic_numbers + ) + radial = self.bessel_fn(edge_lengths) # [n_edges, n_basis] + return radial * cutoff # [n_edges, n_basis] + + +@compile_mode("script") +class EquivariantProductBasisBlock(torch.nn.Module): + def __init__( + self, + node_feats_irreps: o3.Irreps, + target_irreps: o3.Irreps, + correlation: int, + use_sc: bool = True, + num_elements: Optional[int] = None, + cueq_config: Optional[CuEquivarianceConfig] = None, + ) -> None: + super().__init__() + + self.use_sc = use_sc + self.symmetric_contractions = SymmetricContractionWrapper( + irreps_in=node_feats_irreps, + irreps_out=target_irreps, + correlation=correlation, + num_elements=num_elements, + cueq_config=cueq_config, + ) + # Update linear + self.linear = Linear( + target_irreps, + target_irreps, + internal_weights=True, + shared_weights=True, + cueq_config=cueq_config, + ) + self.cueq_config = cueq_config + + def forward( + self, + node_feats: torch.Tensor, + sc: Optional[torch.Tensor], + node_attrs: torch.Tensor, + ) -> torch.Tensor: + use_cueq = False + use_cueq_mul_ir = False + if hasattr(self, "cueq_config"): + if self.cueq_config is not None: + if self.cueq_config.enabled and ( + self.cueq_config.optimize_all or self.cueq_config.optimize_symmetric + ): + use_cueq = True + if self.cueq_config.layout_str == "mul_ir": + use_cueq_mul_ir = True + if use_cueq: + if use_cueq_mul_ir: + node_feats = torch.transpose(node_feats, 1, 2) + index_attrs = torch.nonzero(node_attrs)[:, 1].int() + node_feats = self.symmetric_contractions( + node_feats.flatten(1), + index_attrs, + ) + else: + node_feats = self.symmetric_contractions(node_feats, node_attrs) + if self.use_sc and sc is not None: + return self.linear(node_feats) + sc + return self.linear(node_feats) + + +@compile_mode("script") +class InteractionBlock(torch.nn.Module): + def __init__( + self, + node_attrs_irreps: o3.Irreps, + node_feats_irreps: o3.Irreps, + edge_attrs_irreps: o3.Irreps, + edge_feats_irreps: o3.Irreps, + target_irreps: o3.Irreps, + hidden_irreps: o3.Irreps, + avg_num_neighbors: float, + radial_MLP: Optional[List[int]] = None, + cueq_config: Optional[CuEquivarianceConfig] = None, + ) -> None: + super().__init__() + self.node_attrs_irreps = node_attrs_irreps + self.node_feats_irreps = node_feats_irreps + self.edge_attrs_irreps = edge_attrs_irreps + self.edge_feats_irreps = edge_feats_irreps + self.target_irreps = target_irreps + self.hidden_irreps = hidden_irreps + self.avg_num_neighbors = avg_num_neighbors + if radial_MLP is None: + radial_MLP = [64, 64, 64] + self.radial_MLP = radial_MLP + self.cueq_config = cueq_config + self._setup() + + @abstractmethod + def _setup(self) -> None: + raise NotImplementedError + + def handle_lammps( + self, + node_feats: torch.Tensor, + lammps_class: Optional[Any], + lammps_natoms: Tuple[int, int], + first_layer: bool, + ) -> torch.Tensor: # noqa: D401 – internal helper + if lammps_class is None or first_layer or torch.jit.is_scripting(): + return node_feats + _, n_total = lammps_natoms + pad = torch.zeros( + (n_total, node_feats.shape[1]), + dtype=node_feats.dtype, + device=node_feats.device, + ) + node_feats = torch.cat((node_feats, pad), dim=0) + node_feats = LAMMPS_MP.apply(node_feats, lammps_class) + return node_feats + + def truncate_ghosts( + self, tensor: torch.Tensor, n_real: Optional[int] = None + ) -> torch.Tensor: + """Truncate the tensor to only keep the real atoms in case of presence of ghost atoms during multi-GPU MD simulations.""" + return tensor[:n_real] if n_real is not None else tensor + + @abstractmethod + def forward( + self, + node_attrs: torch.Tensor, + node_feats: torch.Tensor, + edge_attrs: torch.Tensor, + edge_feats: torch.Tensor, + edge_index: torch.Tensor, + ) -> torch.Tensor: + raise NotImplementedError + + +nonlinearities = {1: torch.nn.functional.silu, -1: torch.tanh} + + +@compile_mode("script") +class RealAgnosticInteractionBlock(InteractionBlock): + def _setup(self) -> None: + if not hasattr(self, "cueq_config"): + self.cueq_config = None + # First linear + self.linear_up = Linear( + self.node_feats_irreps, + self.node_feats_irreps, + internal_weights=True, + shared_weights=True, + cueq_config=self.cueq_config, + ) + # TensorProduct + irreps_mid, instructions = tp_out_irreps_with_instructions( + self.node_feats_irreps, + self.edge_attrs_irreps, + self.target_irreps, + ) + self.conv_tp = TensorProduct( + self.node_feats_irreps, + self.edge_attrs_irreps, + irreps_mid, + instructions=instructions, + shared_weights=False, + internal_weights=False, + cueq_config=self.cueq_config, + ) + + # Convolution weights + input_dim = self.edge_feats_irreps.num_irreps + self.conv_tp_weights = nn.FullyConnectedNet( + [input_dim] + self.radial_MLP + [self.conv_tp.weight_numel], + torch.nn.functional.silu, + ) + + # Linear + self.irreps_out = self.target_irreps + self.linear = Linear( + irreps_mid, + self.irreps_out, + internal_weights=True, + shared_weights=True, + cueq_config=self.cueq_config, + ) + + # Selector TensorProduct + self.skip_tp = FullyConnectedTensorProduct( + self.irreps_out, + self.node_attrs_irreps, + self.irreps_out, + cueq_config=self.cueq_config, + ) + self.reshape = reshape_irreps(self.irreps_out, cueq_config=self.cueq_config) + + def forward( + self, + node_attrs: torch.Tensor, + node_feats: torch.Tensor, + edge_attrs: torch.Tensor, + edge_feats: torch.Tensor, + edge_index: torch.Tensor, + lammps_natoms: Tuple[int, int] = (0, 0), + lammps_class: Optional[Any] = None, + first_layer: bool = False, + ) -> Tuple[torch.Tensor, None]: + sender = edge_index[0] + receiver = edge_index[1] + num_nodes = node_feats.shape[0] + n_real = lammps_natoms[0] if lammps_class is not None else None + node_feats = self.linear_up(node_feats) + node_feats = self.handle_lammps( + node_feats, + lammps_class=lammps_class, + lammps_natoms=lammps_natoms, + first_layer=first_layer, + ) + tp_weights = self.conv_tp_weights(edge_feats) + mji = self.conv_tp( + node_feats[sender], edge_attrs, tp_weights + ) # [n_edges, irreps] + message = scatter_sum( + src=mji, index=receiver, dim=0, dim_size=num_nodes + ) # [n_nodes, irreps] + message = self.truncate_ghosts(message, n_real) + node_attrs = self.truncate_ghosts(node_attrs, n_real) + message = self.linear(message) / self.avg_num_neighbors + message = self.skip_tp(message, node_attrs) + return ( + self.reshape(message), + None, + ) # [n_nodes, channels, (lmax + 1)**2] + + +@compile_mode("script") +class RealAgnosticResidualInteractionBlock(InteractionBlock): + def _setup(self) -> None: + if not hasattr(self, "cueq_config"): + self.cueq_config = None + # First linear + self.linear_up = Linear( + self.node_feats_irreps, + self.node_feats_irreps, + internal_weights=True, + shared_weights=True, + cueq_config=self.cueq_config, + ) + # TensorProduct + irreps_mid, instructions = tp_out_irreps_with_instructions( + self.node_feats_irreps, + self.edge_attrs_irreps, + self.target_irreps, + ) + self.conv_tp = TensorProduct( + self.node_feats_irreps, + self.edge_attrs_irreps, + irreps_mid, + instructions=instructions, + shared_weights=False, + internal_weights=False, + cueq_config=self.cueq_config, + ) + + # Convolution weights + input_dim = self.edge_feats_irreps.num_irreps + self.conv_tp_weights = nn.FullyConnectedNet( + [input_dim] + self.radial_MLP + [self.conv_tp.weight_numel], + torch.nn.functional.silu, # gate + ) + + # Linear + self.irreps_out = self.target_irreps + self.linear = Linear( + irreps_mid, + self.irreps_out, + internal_weights=True, + shared_weights=True, + cueq_config=self.cueq_config, + ) + + # Selector TensorProduct + self.skip_tp = FullyConnectedTensorProduct( + self.node_feats_irreps, + self.node_attrs_irreps, + self.hidden_irreps, + cueq_config=self.cueq_config, + ) + self.reshape = reshape_irreps(self.irreps_out, cueq_config=self.cueq_config) + + def forward( + self, + node_attrs: torch.Tensor, + node_feats: torch.Tensor, + edge_attrs: torch.Tensor, + edge_feats: torch.Tensor, + edge_index: torch.Tensor, + lammps_class: Optional[Any] = None, + lammps_natoms: Tuple[int, int] = (0, 0), + first_layer: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + sender = edge_index[0] + receiver = edge_index[1] + num_nodes = node_feats.shape[0] + n_real = lammps_natoms[0] if lammps_class is not None else None + sc = self.skip_tp(node_feats, node_attrs) + node_feats = self.linear_up(node_feats) + node_feats = self.handle_lammps( + node_feats, + lammps_class=lammps_class, + lammps_natoms=lammps_natoms, + first_layer=first_layer, + ) + tp_weights = self.conv_tp_weights(edge_feats) + mji = self.conv_tp( + node_feats[sender], edge_attrs, tp_weights + ) # [n_edges, irreps] + message = scatter_sum( + src=mji, index=receiver, dim=0, dim_size=num_nodes + ) # [n_nodes, irreps] + message = self.truncate_ghosts(message, n_real) + node_attrs = self.truncate_ghosts(node_attrs, n_real) + sc = self.truncate_ghosts(sc, n_real) + message = self.linear(message) / self.avg_num_neighbors + return ( + self.reshape(message), + sc, + ) # [n_nodes, channels, (lmax + 1)**2] + + +@compile_mode("script") +class RealAgnosticDensityInteractionBlock(InteractionBlock): + def _setup(self) -> None: + if not hasattr(self, "cueq_config"): + self.cueq_config = None + # First linear + self.linear_up = Linear( + self.node_feats_irreps, + self.node_feats_irreps, + internal_weights=True, + shared_weights=True, + cueq_config=self.cueq_config, + ) + # TensorProduct + irreps_mid, instructions = tp_out_irreps_with_instructions( + self.node_feats_irreps, + self.edge_attrs_irreps, + self.target_irreps, + ) + self.conv_tp = TensorProduct( + self.node_feats_irreps, + self.edge_attrs_irreps, + irreps_mid, + instructions=instructions, + shared_weights=False, + internal_weights=False, + cueq_config=self.cueq_config, + ) + + # Convolution weights + input_dim = self.edge_feats_irreps.num_irreps + self.conv_tp_weights = nn.FullyConnectedNet( + [input_dim] + self.radial_MLP + [self.conv_tp.weight_numel], + torch.nn.functional.silu, + ) + + # Linear + self.irreps_out = self.target_irreps + self.linear = Linear( + irreps_mid, + self.irreps_out, + internal_weights=True, + shared_weights=True, + cueq_config=self.cueq_config, + ) + + # Selector TensorProduct + self.skip_tp = FullyConnectedTensorProduct( + self.irreps_out, + self.node_attrs_irreps, + self.irreps_out, + cueq_config=self.cueq_config, + ) + + # Density normalization + self.density_fn = nn.FullyConnectedNet( + [input_dim] + + [ + 1, + ], + torch.nn.functional.silu, + ) + # Reshape + self.reshape = reshape_irreps(self.irreps_out, cueq_config=self.cueq_config) + + def forward( + self, + node_attrs: torch.Tensor, + node_feats: torch.Tensor, + edge_attrs: torch.Tensor, + edge_feats: torch.Tensor, + edge_index: torch.Tensor, + lammps_class: Optional[Any] = None, + lammps_natoms: Tuple[int, int] = (0, 0), + first_layer: bool = False, + ) -> Tuple[torch.Tensor, None]: + sender = edge_index[0] + receiver = edge_index[1] + num_nodes = node_feats.shape[0] + n_real = lammps_natoms[0] if lammps_class is not None else None + node_feats = self.linear_up(node_feats) + node_feats = self.handle_lammps( + node_feats, + lammps_class=lammps_class, + lammps_natoms=lammps_natoms, + first_layer=first_layer, + ) + tp_weights = self.conv_tp_weights(edge_feats) + edge_density = torch.tanh(self.density_fn(edge_feats) ** 2) + mji = self.conv_tp( + node_feats[sender], edge_attrs, tp_weights + ) # [n_edges, irreps] + density = scatter_sum( + src=edge_density, index=receiver, dim=0, dim_size=num_nodes + ) # [n_nodes, 1] + message = scatter_sum( + src=mji, index=receiver, dim=0, dim_size=num_nodes + ) # [n_nodes, irreps] + message = self.truncate_ghosts(message, n_real) + node_attrs = self.truncate_ghosts(node_attrs, n_real) + density = self.truncate_ghosts(density, n_real) + message = self.linear(message) / (density + 1) + message = self.skip_tp(message, node_attrs) + return ( + self.reshape(message), + None, + ) # [n_nodes, channels, (lmax + 1)**2] + + +@compile_mode("script") +class RealAgnosticDensityResidualInteractionBlock(InteractionBlock): + def _setup(self) -> None: + if not hasattr(self, "cueq_config"): + self.cueq_config = None + + # First linear + self.linear_up = Linear( + self.node_feats_irreps, + self.node_feats_irreps, + internal_weights=True, + shared_weights=True, + cueq_config=self.cueq_config, + ) + # TensorProduct + irreps_mid, instructions = tp_out_irreps_with_instructions( + self.node_feats_irreps, + self.edge_attrs_irreps, + self.target_irreps, + ) + self.conv_tp = TensorProduct( + self.node_feats_irreps, + self.edge_attrs_irreps, + irreps_mid, + instructions=instructions, + shared_weights=False, + internal_weights=False, + cueq_config=self.cueq_config, + ) + + # Convolution weights + input_dim = self.edge_feats_irreps.num_irreps + self.conv_tp_weights = nn.FullyConnectedNet( + [input_dim] + self.radial_MLP + [self.conv_tp.weight_numel], + torch.nn.functional.silu, # gate + ) + + # Linear + self.irreps_out = self.target_irreps + self.linear = Linear( + irreps_mid, + self.irreps_out, + internal_weights=True, + shared_weights=True, + cueq_config=self.cueq_config, + ) + + # Selector TensorProduct + self.skip_tp = FullyConnectedTensorProduct( + self.node_feats_irreps, + self.node_attrs_irreps, + self.hidden_irreps, + cueq_config=self.cueq_config, + ) + + # Density normalization + self.density_fn = nn.FullyConnectedNet( + [input_dim] + + [ + 1, + ], + torch.nn.functional.silu, + ) + + # Reshape + self.reshape = reshape_irreps(self.irreps_out, cueq_config=self.cueq_config) + + def forward( + self, + node_attrs: torch.Tensor, + node_feats: torch.Tensor, + edge_attrs: torch.Tensor, + edge_feats: torch.Tensor, + edge_index: torch.Tensor, + lammps_class: Optional[Any] = None, + lammps_natoms: Tuple[int, int] = (0, 0), + first_layer: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + sender = edge_index[0] + receiver = edge_index[1] + num_nodes = node_feats.shape[0] + n_real = lammps_natoms[0] if lammps_class is not None else None + sc = self.skip_tp(node_feats, node_attrs) + node_feats = self.linear_up(node_feats) + node_feats = self.handle_lammps( + node_feats, + lammps_class=lammps_class, + lammps_natoms=lammps_natoms, + first_layer=first_layer, + ) + tp_weights = self.conv_tp_weights(edge_feats) + edge_density = torch.tanh(self.density_fn(edge_feats) ** 2) + mji = self.conv_tp( + node_feats[sender], edge_attrs, tp_weights + ) # [n_edges, irreps] + density = scatter_sum( + src=edge_density, index=receiver, dim=0, dim_size=num_nodes + ) # [n_nodes, 1] + message = scatter_sum( + src=mji, index=receiver, dim=0, dim_size=num_nodes + ) # [n_nodes, irreps] + message = self.truncate_ghosts(message, n_real) + node_attrs = self.truncate_ghosts(node_attrs, n_real) + density = self.truncate_ghosts(density, n_real) + sc = self.truncate_ghosts(sc, n_real) + message = self.linear(message) / (density + 1) + return ( + self.reshape(message), + sc, + ) # [n_nodes, channels, (lmax + 1)**2] + + +@compile_mode("script") +class RealAgnosticAttResidualInteractionBlock(InteractionBlock): + def _setup(self) -> None: + if not hasattr(self, "cueq_config"): + self.cueq_config = None + self.node_feats_down_irreps = o3.Irreps("64x0e") + # First linear + self.linear_up = Linear( + self.node_feats_irreps, + self.node_feats_irreps, + internal_weights=True, + shared_weights=True, + cueq_config=self.cueq_config, + ) + # TensorProduct + irreps_mid, instructions = tp_out_irreps_with_instructions( + self.node_feats_irreps, + self.edge_attrs_irreps, + self.target_irreps, + ) + self.conv_tp = TensorProduct( + self.node_feats_irreps, + self.edge_attrs_irreps, + irreps_mid, + instructions=instructions, + shared_weights=False, + internal_weights=False, + cueq_config=self.cueq_config, + ) + + # Convolution weights + self.linear_down = Linear( + self.node_feats_irreps, + self.node_feats_down_irreps, + internal_weights=True, + shared_weights=True, + cueq_config=self.cueq_config, + ) + input_dim = ( + self.edge_feats_irreps.num_irreps + + 2 * self.node_feats_down_irreps.num_irreps + ) + self.conv_tp_weights = nn.FullyConnectedNet( + [input_dim] + 3 * [256] + [self.conv_tp.weight_numel], + torch.nn.functional.silu, + ) + + # Linear + self.irreps_out = self.target_irreps + self.linear = Linear( + irreps_mid, + self.irreps_out, + internal_weights=True, + shared_weights=True, + cueq_config=self.cueq_config, + ) + + self.reshape = reshape_irreps(self.irreps_out, cueq_config=self.cueq_config) + + # Skip connection. + self.skip_linear = Linear( + self.node_feats_irreps, self.hidden_irreps, cueq_config=self.cueq_config + ) + + # pylint: disable=unused-argument + def forward( + self, + node_attrs: torch.Tensor, + node_feats: torch.Tensor, + edge_attrs: torch.Tensor, + edge_feats: torch.Tensor, + edge_index: torch.Tensor, + lammps_class: Optional[Any] = None, + lammps_natoms: Tuple[int, int] = (0, 0), + first_layer: bool = False, + ) -> Tuple[torch.Tensor, None]: + sender = edge_index[0] + receiver = edge_index[1] + num_nodes = node_feats.shape[0] + sc = self.skip_linear(node_feats) + node_feats_up = self.linear_up(node_feats) + node_feats_down = self.linear_down(node_feats) + augmented_edge_feats = torch.cat( + [ + edge_feats, + node_feats_down[sender], + node_feats_down[receiver], + ], + dim=-1, + ) + tp_weights = self.conv_tp_weights(augmented_edge_feats) + mji = self.conv_tp( + node_feats_up[sender], edge_attrs, tp_weights + ) # [n_edges, irreps] + message = scatter_sum( + src=mji, index=receiver, dim=0, dim_size=num_nodes + ) # [n_nodes, irreps] + message = self.linear(message) / self.avg_num_neighbors + return ( + self.reshape(message), + sc, + ) # [n_nodes, channels, (lmax + 1)**2] + + +@compile_mode("script") +class ScaleShiftBlock(torch.nn.Module): + def __init__(self, scale: float, shift: float): + super().__init__() + self.register_buffer( + "scale", + torch.tensor(scale, dtype=torch.get_default_dtype()), + ) + self.register_buffer( + "shift", + torch.tensor(shift, dtype=torch.get_default_dtype()), + ) + + def forward(self, x: torch.Tensor, head: torch.Tensor) -> torch.Tensor: + return ( + torch.atleast_1d(self.scale)[head] * x + torch.atleast_1d(self.shift)[head] + ) + + def __repr__(self): + formatted_scale = ( + ", ".join([f"{x:.4f}" for x in self.scale]) + if self.scale.numel() > 1 + else f"{self.scale.item():.4f}" + ) + formatted_shift = ( + ", ".join([f"{x:.4f}" for x in self.shift]) + if self.shift.numel() > 1 + else f"{self.shift.item():.4f}" + ) + return f"{self.__class__.__name__}(scale={formatted_scale}, shift={formatted_shift})" diff --git a/mace-bench/3rdparty/mace/mace/modules/irreps_tools.py b/mace-bench/3rdparty/mace/mace/modules/irreps_tools.py index c338801..6677b1b 100644 --- a/mace-bench/3rdparty/mace/mace/modules/irreps_tools.py +++ b/mace-bench/3rdparty/mace/mace/modules/irreps_tools.py @@ -1,116 +1,116 @@ -########################################################################################### -# Elementary tools for handling irreducible representations -# Authors: Ilyes Batatia, Gregor Simm -# This program is distributed under the MIT License (see MIT.md) -########################################################################################### - -from typing import List, Optional, Tuple - -import torch -from e3nn import o3 -from e3nn.util.jit import compile_mode - -from mace.modules.wrapper_ops import CuEquivarianceConfig - - -# Based on mir-group/nequip -def tp_out_irreps_with_instructions( - irreps1: o3.Irreps, irreps2: o3.Irreps, target_irreps: o3.Irreps -) -> Tuple[o3.Irreps, List]: - trainable = True - - # Collect possible irreps and their instructions - irreps_out_list: List[Tuple[int, o3.Irreps]] = [] - instructions = [] - for i, (mul, ir_in) in enumerate(irreps1): - for j, (_, ir_edge) in enumerate(irreps2): - for ir_out in ir_in * ir_edge: # | l1 - l2 | <= l <= l1 + l2 - if ir_out in target_irreps: - k = len(irreps_out_list) # instruction index - irreps_out_list.append((mul, ir_out)) - instructions.append((i, j, k, "uvu", trainable)) - - # We sort the output irreps of the tensor product so that we can simplify them - # when they are provided to the second o3.Linear - irreps_out = o3.Irreps(irreps_out_list) - irreps_out, permut, _ = irreps_out.sort() - - # Permute the output indexes of the instructions to match the sorted irreps: - instructions = [ - (i_in1, i_in2, permut[i_out], mode, train) - for i_in1, i_in2, i_out, mode, train in instructions - ] - - instructions = sorted(instructions, key=lambda x: x[2]) - - return irreps_out, instructions - - -def linear_out_irreps(irreps: o3.Irreps, target_irreps: o3.Irreps) -> o3.Irreps: - # Assuming simplified irreps - irreps_mid = [] - for _, ir_in in irreps: - found = False - - for mul, ir_out in target_irreps: - if ir_in == ir_out: - irreps_mid.append((mul, ir_out)) - found = True - break - - if not found: - raise RuntimeError(f"{ir_in} not in {target_irreps}") - - return o3.Irreps(irreps_mid) - - -@compile_mode("script") -class reshape_irreps(torch.nn.Module): - def __init__( - self, irreps: o3.Irreps, cueq_config: Optional[CuEquivarianceConfig] = None - ) -> None: - super().__init__() - self.irreps = o3.Irreps(irreps) - self.cueq_config = cueq_config - self.dims = [] - self.muls = [] - for mul, ir in self.irreps: - d = ir.dim - self.dims.append(d) - self.muls.append(mul) - - def forward(self, tensor: torch.Tensor) -> torch.Tensor: - ix = 0 - out = [] - batch, _ = tensor.shape - for mul, d in zip(self.muls, self.dims): - field = tensor[:, ix : ix + mul * d] # [batch, sample, mul * repr] - ix += mul * d - if hasattr(self, "cueq_config"): - if self.cueq_config is not None: - if self.cueq_config.layout_str == "mul_ir": - field = field.reshape(batch, mul, d) - else: - field = field.reshape(batch, d, mul) - else: - field = field.reshape(batch, mul, d) - else: - field = field.reshape(batch, mul, d) - out.append(field) - - if hasattr(self, "cueq_config"): - if self.cueq_config is not None: # pylint: disable=no-else-return - if self.cueq_config.layout_str == "mul_ir": - return torch.cat(out, dim=-1) - return torch.cat(out, dim=-2) - else: - return torch.cat(out, dim=-1) - return torch.cat(out, dim=-1) - - -def mask_head(x: torch.Tensor, head: torch.Tensor, num_heads: int) -> torch.Tensor: - mask = torch.zeros(x.shape[0], x.shape[1] // num_heads, num_heads, device=x.device) - idx = torch.arange(mask.shape[0], device=x.device) - mask[idx, :, head] = 1 - mask = mask.permute(0, 2, 1).reshape(x.shape) - return x * mask +########################################################################################### +# Elementary tools for handling irreducible representations +# Authors: Ilyes Batatia, Gregor Simm +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +from typing import List, Optional, Tuple + +import torch +from e3nn import o3 +from e3nn.util.jit import compile_mode + +from mace.modules.wrapper_ops import CuEquivarianceConfig + + +# Based on mir-group/nequip +def tp_out_irreps_with_instructions( + irreps1: o3.Irreps, irreps2: o3.Irreps, target_irreps: o3.Irreps +) -> Tuple[o3.Irreps, List]: + trainable = True + + # Collect possible irreps and their instructions + irreps_out_list: List[Tuple[int, o3.Irreps]] = [] + instructions = [] + for i, (mul, ir_in) in enumerate(irreps1): + for j, (_, ir_edge) in enumerate(irreps2): + for ir_out in ir_in * ir_edge: # | l1 - l2 | <= l <= l1 + l2 + if ir_out in target_irreps: + k = len(irreps_out_list) # instruction index + irreps_out_list.append((mul, ir_out)) + instructions.append((i, j, k, "uvu", trainable)) + + # We sort the output irreps of the tensor product so that we can simplify them + # when they are provided to the second o3.Linear + irreps_out = o3.Irreps(irreps_out_list) + irreps_out, permut, _ = irreps_out.sort() + + # Permute the output indexes of the instructions to match the sorted irreps: + instructions = [ + (i_in1, i_in2, permut[i_out], mode, train) + for i_in1, i_in2, i_out, mode, train in instructions + ] + + instructions = sorted(instructions, key=lambda x: x[2]) + + return irreps_out, instructions + + +def linear_out_irreps(irreps: o3.Irreps, target_irreps: o3.Irreps) -> o3.Irreps: + # Assuming simplified irreps + irreps_mid = [] + for _, ir_in in irreps: + found = False + + for mul, ir_out in target_irreps: + if ir_in == ir_out: + irreps_mid.append((mul, ir_out)) + found = True + break + + if not found: + raise RuntimeError(f"{ir_in} not in {target_irreps}") + + return o3.Irreps(irreps_mid) + + +@compile_mode("script") +class reshape_irreps(torch.nn.Module): + def __init__( + self, irreps: o3.Irreps, cueq_config: Optional[CuEquivarianceConfig] = None + ) -> None: + super().__init__() + self.irreps = o3.Irreps(irreps) + self.cueq_config = cueq_config + self.dims = [] + self.muls = [] + for mul, ir in self.irreps: + d = ir.dim + self.dims.append(d) + self.muls.append(mul) + + def forward(self, tensor: torch.Tensor) -> torch.Tensor: + ix = 0 + out = [] + batch, _ = tensor.shape + for mul, d in zip(self.muls, self.dims): + field = tensor[:, ix : ix + mul * d] # [batch, sample, mul * repr] + ix += mul * d + if hasattr(self, "cueq_config"): + if self.cueq_config is not None: + if self.cueq_config.layout_str == "mul_ir": + field = field.reshape(batch, mul, d) + else: + field = field.reshape(batch, d, mul) + else: + field = field.reshape(batch, mul, d) + else: + field = field.reshape(batch, mul, d) + out.append(field) + + if hasattr(self, "cueq_config"): + if self.cueq_config is not None: # pylint: disable=no-else-return + if self.cueq_config.layout_str == "mul_ir": + return torch.cat(out, dim=-1) + return torch.cat(out, dim=-2) + else: + return torch.cat(out, dim=-1) + return torch.cat(out, dim=-1) + + +def mask_head(x: torch.Tensor, head: torch.Tensor, num_heads: int) -> torch.Tensor: + mask = torch.zeros(x.shape[0], x.shape[1] // num_heads, num_heads, device=x.device) + idx = torch.arange(mask.shape[0], device=x.device) + mask[idx, :, head] = 1 + mask = mask.permute(0, 2, 1).reshape(x.shape) + return x * mask diff --git a/mace-bench/3rdparty/mace/mace/modules/loss.py b/mace-bench/3rdparty/mace/mace/modules/loss.py index 19ad76a..ff567e3 100644 --- a/mace-bench/3rdparty/mace/mace/modules/loss.py +++ b/mace-bench/3rdparty/mace/mace/modules/loss.py @@ -1,566 +1,566 @@ -########################################################################################### -# Implementation of different loss functions -# Authors: Ilyes Batatia, Gregor Simm -# This program is distributed under the MIT License (see MIT.md) -########################################################################################### - -from typing import Optional - -import torch -import torch.distributed as dist - -from mace.tools import TensorDict -from mace.tools.torch_geometric import Batch - - -# ------------------------------------------------------------------------------ -# Helper function for loss reduction that handles DDP correction -# ------------------------------------------------------------------------------ -def is_ddp_enabled(): - return dist.is_initialized() and dist.get_world_size() > 1 - - -def reduce_loss(raw_loss: torch.Tensor, ddp: Optional[bool] = None) -> torch.Tensor: - """ - Reduces an element-wise loss tensor. - - If ddp is True and distributed is initialized, the function computes: - - loss = (local_sum * world_size) / global_num_elements - - Otherwise, it returns the regular mean. - """ - ddp = is_ddp_enabled() if ddp is None else ddp - if ddp and dist.is_initialized(): - world_size = dist.get_world_size() - n_local = raw_loss.numel() - loss_sum = raw_loss.sum() - total_samples = torch.tensor( - n_local, device=raw_loss.device, dtype=raw_loss.dtype - ) - dist.all_reduce(total_samples, op=dist.ReduceOp.SUM) - return loss_sum * world_size / total_samples - return raw_loss.mean() - - -# ------------------------------------------------------------------------------ -# Energy Loss Functions -# ------------------------------------------------------------------------------ - - -def mean_squared_error_energy( - ref: Batch, pred: TensorDict, ddp: Optional[bool] = None -) -> torch.Tensor: - raw_loss = torch.square(ref["energy"] - pred["energy"]) - return reduce_loss(raw_loss, ddp) - - -def weighted_mean_squared_error_energy( - ref: Batch, pred: TensorDict, ddp: Optional[bool] = None -) -> torch.Tensor: - # Calculate per-graph number of atoms. - num_atoms = ref.ptr[1:] - ref.ptr[:-1] # shape: [n_graphs] - raw_loss = ( - ref.weight - * ref.energy_weight - * torch.square((ref["energy"] - pred["energy"]) / num_atoms) - ) - return reduce_loss(raw_loss, ddp) - - -def weighted_mean_absolute_error_energy( - ref: Batch, pred: TensorDict, ddp: Optional[bool] = None -) -> torch.Tensor: - num_atoms = ref.ptr[1:] - ref.ptr[:-1] - raw_loss = ( - ref.weight - * ref.energy_weight - * torch.abs((ref["energy"] - pred["energy"]) / num_atoms) - ) - return reduce_loss(raw_loss, ddp) - - -# ------------------------------------------------------------------------------ -# Stress and Virials Loss Functions -# ------------------------------------------------------------------------------ - - -def weighted_mean_squared_stress( - ref: Batch, pred: TensorDict, ddp: Optional[bool] = None -) -> torch.Tensor: - configs_weight = ref.weight.view(-1, 1, 1) - configs_stress_weight = ref.stress_weight.view(-1, 1, 1) - raw_loss = ( - configs_weight - * configs_stress_weight - * torch.square(ref["stress"] - pred["stress"]) - ) - return reduce_loss(raw_loss, ddp) - - -def weighted_mean_squared_virials( - ref: Batch, pred: TensorDict, ddp: Optional[bool] = None -) -> torch.Tensor: - configs_weight = ref.weight.view(-1, 1, 1) - configs_virials_weight = ref.virials_weight.view(-1, 1, 1) - num_atoms = (ref.ptr[1:] - ref.ptr[:-1]).view(-1, 1, 1) - raw_loss = ( - configs_weight - * configs_virials_weight - * torch.square((ref["virials"] - pred["virials"]) / num_atoms) - ) - return reduce_loss(raw_loss, ddp) - - -# ------------------------------------------------------------------------------ -# Forces Loss Functions -# ------------------------------------------------------------------------------ - - -def mean_squared_error_forces( - ref: Batch, pred: TensorDict, ddp: Optional[bool] = None -) -> torch.Tensor: - # Repeat per-graph weights to per-atom level. - configs_weight = torch.repeat_interleave( - ref.weight, ref.ptr[1:] - ref.ptr[:-1] - ).unsqueeze(-1) - configs_forces_weight = torch.repeat_interleave( - ref.forces_weight, ref.ptr[1:] - ref.ptr[:-1] - ).unsqueeze(-1) - raw_loss = ( - configs_weight - * configs_forces_weight - * torch.square(ref["forces"] - pred["forces"]) - ) - return reduce_loss(raw_loss, ddp) - - -def mean_normed_error_forces( - ref: Batch, pred: TensorDict, ddp: Optional[bool] = None -) -> torch.Tensor: - raw_loss = torch.linalg.vector_norm(ref["forces"] - pred["forces"], ord=2, dim=-1) - return reduce_loss(raw_loss, ddp) - - -# ------------------------------------------------------------------------------ -# Dipole Loss Function -# ------------------------------------------------------------------------------ - - -def weighted_mean_squared_error_dipole( - ref: Batch, pred: TensorDict, ddp: Optional[bool] = None -) -> torch.Tensor: - num_atoms = (ref.ptr[1:] - ref.ptr[:-1]).unsqueeze(-1) - raw_loss = torch.square((ref["dipole"] - pred["dipole"]) / num_atoms) - return reduce_loss(raw_loss, ddp) - - -# ------------------------------------------------------------------------------ -# Conditional Losses for Forces -# ------------------------------------------------------------------------------ - - -def conditional_mse_forces( - ref: Batch, pred: TensorDict, ddp: Optional[bool] = None -) -> torch.Tensor: - configs_weight = torch.repeat_interleave( - ref.weight, ref.ptr[1:] - ref.ptr[:-1] - ).unsqueeze(-1) - configs_forces_weight = torch.repeat_interleave( - ref.forces_weight, ref.ptr[1:] - ref.ptr[:-1] - ).unsqueeze(-1) - # Define multiplication factors for different regimes. - factors = torch.tensor( - [1.0, 0.7, 0.4, 0.1], device=ref["forces"].device, dtype=ref["forces"].dtype - ) - err = ref["forces"] - pred["forces"] - se = torch.zeros_like(err) - norm_forces = torch.norm(ref["forces"], dim=-1) - c1 = norm_forces < 100 - c2 = (norm_forces >= 100) & (norm_forces < 200) - c3 = (norm_forces >= 200) & (norm_forces < 300) - se[c1] = torch.square(err[c1]) * factors[0] - se[c2] = torch.square(err[c2]) * factors[1] - se[c3] = torch.square(err[c3]) * factors[2] - se[~(c1 | c2 | c3)] = torch.square(err[~(c1 | c2 | c3)]) * factors[3] - raw_loss = configs_weight * configs_forces_weight * se - return reduce_loss(raw_loss, ddp) - - -def conditional_huber_forces( - ref_forces: torch.Tensor, - pred_forces: torch.Tensor, - huber_delta: float, - ddp: Optional[bool] = None, -) -> torch.Tensor: - factors = huber_delta * torch.tensor( - [1.0, 0.7, 0.4, 0.1], device=ref_forces.device, dtype=ref_forces.dtype - ) - norm_forces = torch.norm(ref_forces, dim=-1) - c1 = norm_forces < 100 - c2 = (norm_forces >= 100) & (norm_forces < 200) - c3 = (norm_forces >= 200) & (norm_forces < 300) - c4 = ~(c1 | c2 | c3) - se = torch.zeros_like(pred_forces) - se[c1] = torch.nn.functional.huber_loss( - ref_forces[c1], pred_forces[c1], reduction="none", delta=factors[0] - ) - se[c2] = torch.nn.functional.huber_loss( - ref_forces[c2], pred_forces[c2], reduction="none", delta=factors[1] - ) - se[c3] = torch.nn.functional.huber_loss( - ref_forces[c3], pred_forces[c3], reduction="none", delta=factors[2] - ) - se[c4] = torch.nn.functional.huber_loss( - ref_forces[c4], pred_forces[c4], reduction="none", delta=factors[3] - ) - return reduce_loss(se, ddp) - - -# ------------------------------------------------------------------------------ -# Loss Modules Combining Multiple Quantities -# ------------------------------------------------------------------------------ - - -class WeightedEnergyForcesLoss(torch.nn.Module): - def __init__(self, energy_weight=1.0, forces_weight=1.0) -> None: - super().__init__() - self.register_buffer( - "energy_weight", - torch.tensor(energy_weight, dtype=torch.get_default_dtype()), - ) - self.register_buffer( - "forces_weight", - torch.tensor(forces_weight, dtype=torch.get_default_dtype()), - ) - - def forward( - self, ref: Batch, pred: TensorDict, ddp: Optional[bool] = None - ) -> torch.Tensor: - loss_energy = weighted_mean_squared_error_energy(ref, pred, ddp) - loss_forces = mean_squared_error_forces(ref, pred, ddp) - return self.energy_weight * loss_energy + self.forces_weight * loss_forces - - def __repr__(self): - return ( - f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, " - f"forces_weight={self.forces_weight:.3f})" - ) - - -class WeightedForcesLoss(torch.nn.Module): - def __init__(self, forces_weight=1.0) -> None: - super().__init__() - self.register_buffer( - "forces_weight", - torch.tensor(forces_weight, dtype=torch.get_default_dtype()), - ) - - def forward( - self, ref: Batch, pred: TensorDict, ddp: Optional[bool] = None - ) -> torch.Tensor: - loss_forces = mean_squared_error_forces(ref, pred, ddp) - return self.forces_weight * loss_forces - - def __repr__(self): - return f"{self.__class__.__name__}(forces_weight={self.forces_weight:.3f})" - - -class WeightedEnergyForcesStressLoss(torch.nn.Module): - def __init__(self, energy_weight=1.0, forces_weight=1.0, stress_weight=1.0) -> None: - super().__init__() - self.register_buffer( - "energy_weight", - torch.tensor(energy_weight, dtype=torch.get_default_dtype()), - ) - self.register_buffer( - "forces_weight", - torch.tensor(forces_weight, dtype=torch.get_default_dtype()), - ) - self.register_buffer( - "stress_weight", - torch.tensor(stress_weight, dtype=torch.get_default_dtype()), - ) - - def forward( - self, ref: Batch, pred: TensorDict, ddp: Optional[bool] = None - ) -> torch.Tensor: - loss_energy = weighted_mean_squared_error_energy(ref, pred, ddp) - loss_forces = mean_squared_error_forces(ref, pred, ddp) - loss_stress = weighted_mean_squared_stress(ref, pred, ddp) - return ( - self.energy_weight * loss_energy - + self.forces_weight * loss_forces - + self.stress_weight * loss_stress - ) - - def __repr__(self): - return ( - f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, " - f"forces_weight={self.forces_weight:.3f}, stress_weight={self.stress_weight:.3f})" - ) - - -class WeightedHuberEnergyForcesStressLoss(torch.nn.Module): - def __init__( - self, energy_weight=1.0, forces_weight=1.0, stress_weight=1.0, huber_delta=0.01 - ) -> None: - super().__init__() - # We store the huber_delta rather than a loss with fixed reduction. - self.huber_delta = huber_delta - self.register_buffer( - "energy_weight", - torch.tensor(energy_weight, dtype=torch.get_default_dtype()), - ) - self.register_buffer( - "forces_weight", - torch.tensor(forces_weight, dtype=torch.get_default_dtype()), - ) - self.register_buffer( - "stress_weight", - torch.tensor(stress_weight, dtype=torch.get_default_dtype()), - ) - - def forward( - self, ref: Batch, pred: TensorDict, ddp: Optional[bool] = None - ) -> torch.Tensor: - num_atoms = ref.ptr[1:] - ref.ptr[:-1] - if ddp: - loss_energy = torch.nn.functional.huber_loss( - ref["energy"] / num_atoms, - pred["energy"] / num_atoms, - reduction="none", - delta=self.huber_delta, - ) - loss_energy = reduce_loss(loss_energy, ddp) - loss_forces = torch.nn.functional.huber_loss( - ref["forces"], pred["forces"], reduction="none", delta=self.huber_delta - ) - loss_forces = reduce_loss(loss_forces, ddp) - loss_stress = torch.nn.functional.huber_loss( - ref["stress"], pred["stress"], reduction="none", delta=self.huber_delta - ) - loss_stress = reduce_loss(loss_stress, ddp) - else: - loss_energy = torch.nn.functional.huber_loss( - ref["energy"] / num_atoms, - pred["energy"] / num_atoms, - reduction="mean", - delta=self.huber_delta, - ) - loss_forces = torch.nn.functional.huber_loss( - ref["forces"], pred["forces"], reduction="mean", delta=self.huber_delta - ) - loss_stress = torch.nn.functional.huber_loss( - ref["stress"], pred["stress"], reduction="mean", delta=self.huber_delta - ) - return ( - self.energy_weight * loss_energy - + self.forces_weight * loss_forces - + self.stress_weight * loss_stress - ) - - def __repr__(self): - return ( - f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, " - f"forces_weight={self.forces_weight:.3f}, stress_weight={self.stress_weight:.3f})" - ) - - -class UniversalLoss(torch.nn.Module): - def __init__( - self, energy_weight=1.0, forces_weight=1.0, stress_weight=1.0, huber_delta=0.01 - ) -> None: - super().__init__() - self.huber_delta = huber_delta - self.register_buffer( - "energy_weight", - torch.tensor(energy_weight, dtype=torch.get_default_dtype()), - ) - self.register_buffer( - "forces_weight", - torch.tensor(forces_weight, dtype=torch.get_default_dtype()), - ) - self.register_buffer( - "stress_weight", - torch.tensor(stress_weight, dtype=torch.get_default_dtype()), - ) - - def forward( - self, ref: Batch, pred: TensorDict, ddp: Optional[bool] = None - ) -> torch.Tensor: - num_atoms = ref.ptr[1:] - ref.ptr[:-1] - configs_stress_weight = ref.stress_weight.view(-1, 1, 1) - configs_energy_weight = ref.energy_weight - configs_forces_weight = torch.repeat_interleave( - ref.forces_weight, ref.ptr[1:] - ref.ptr[:-1] - ).unsqueeze(-1) - if ddp: - loss_energy = torch.nn.functional.huber_loss( - configs_energy_weight * ref["energy"] / num_atoms, - configs_energy_weight * pred["energy"] / num_atoms, - reduction="none", - delta=self.huber_delta, - ) - loss_energy = reduce_loss(loss_energy, ddp) - loss_forces = conditional_huber_forces( - configs_forces_weight * ref["forces"], - configs_forces_weight * pred["forces"], - huber_delta=self.huber_delta, - ddp=ddp, - ) - loss_stress = torch.nn.functional.huber_loss( - configs_stress_weight * ref["stress"], - configs_stress_weight * pred["stress"], - reduction="none", - delta=self.huber_delta, - ) - loss_stress = reduce_loss(loss_stress, ddp) - else: - loss_energy = torch.nn.functional.huber_loss( - configs_energy_weight * ref["energy"] / num_atoms, - configs_energy_weight * pred["energy"] / num_atoms, - reduction="mean", - delta=self.huber_delta, - ) - loss_forces = conditional_huber_forces( - configs_forces_weight * ref["forces"], - configs_forces_weight * pred["forces"], - huber_delta=self.huber_delta, - ddp=ddp, - ) - loss_stress = torch.nn.functional.huber_loss( - configs_stress_weight * ref["stress"], - configs_stress_weight * pred["stress"], - reduction="mean", - delta=self.huber_delta, - ) - return ( - self.energy_weight * loss_energy - + self.forces_weight * loss_forces - + self.stress_weight * loss_stress - ) - - def __repr__(self): - return ( - f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, " - f"forces_weight={self.forces_weight:.3f}, stress_weight={self.stress_weight:.3f})" - ) - - -class WeightedEnergyForcesVirialsLoss(torch.nn.Module): - def __init__( - self, energy_weight=1.0, forces_weight=1.0, virials_weight=1.0 - ) -> None: - super().__init__() - self.register_buffer( - "energy_weight", - torch.tensor(energy_weight, dtype=torch.get_default_dtype()), - ) - self.register_buffer( - "forces_weight", - torch.tensor(forces_weight, dtype=torch.get_default_dtype()), - ) - self.register_buffer( - "virials_weight", - torch.tensor(virials_weight, dtype=torch.get_default_dtype()), - ) - - def forward( - self, ref: Batch, pred: TensorDict, ddp: Optional[bool] = None - ) -> torch.Tensor: - loss_energy = weighted_mean_squared_error_energy(ref, pred, ddp) - loss_forces = mean_squared_error_forces(ref, pred, ddp) - loss_virials = weighted_mean_squared_virials(ref, pred, ddp) - return ( - self.energy_weight * loss_energy - + self.forces_weight * loss_forces - + self.virials_weight * loss_virials - ) - - def __repr__(self): - return ( - f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, " - f"forces_weight={self.forces_weight:.3f}, virials_weight={self.virials_weight:.3f})" - ) - - -class DipoleSingleLoss(torch.nn.Module): - def __init__(self, dipole_weight=1.0) -> None: - super().__init__() - self.register_buffer( - "dipole_weight", - torch.tensor(dipole_weight, dtype=torch.get_default_dtype()), - ) - - def forward( - self, ref: Batch, pred: TensorDict, ddp: Optional[bool] = None - ) -> torch.Tensor: - loss = ( - weighted_mean_squared_error_dipole(ref, pred, ddp) * 100.0 - ) # scale adjustment - return self.dipole_weight * loss - - def __repr__(self): - return f"{self.__class__.__name__}(dipole_weight={self.dipole_weight:.3f})" - - -class WeightedEnergyForcesDipoleLoss(torch.nn.Module): - def __init__(self, energy_weight=1.0, forces_weight=1.0, dipole_weight=1.0) -> None: - super().__init__() - self.register_buffer( - "energy_weight", - torch.tensor(energy_weight, dtype=torch.get_default_dtype()), - ) - self.register_buffer( - "forces_weight", - torch.tensor(forces_weight, dtype=torch.get_default_dtype()), - ) - self.register_buffer( - "dipole_weight", - torch.tensor(dipole_weight, dtype=torch.get_default_dtype()), - ) - - def forward( - self, ref: Batch, pred: TensorDict, ddp: Optional[bool] = None - ) -> torch.Tensor: - loss_energy = weighted_mean_squared_error_energy(ref, pred, ddp) - loss_forces = mean_squared_error_forces(ref, pred, ddp) - loss_dipole = weighted_mean_squared_error_dipole(ref, pred, ddp) * 100.0 - return ( - self.energy_weight * loss_energy - + self.forces_weight * loss_forces - + self.dipole_weight * loss_dipole - ) - - def __repr__(self): - return ( - f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, " - f"forces_weight={self.forces_weight:.3f}, dipole_weight={self.dipole_weight:.3f})" - ) - - -class WeightedEnergyForcesL1L2Loss(torch.nn.Module): - def __init__(self, energy_weight=1.0, forces_weight=1.0) -> None: - super().__init__() - self.register_buffer( - "energy_weight", - torch.tensor(energy_weight, dtype=torch.get_default_dtype()), - ) - self.register_buffer( - "forces_weight", - torch.tensor(forces_weight, dtype=torch.get_default_dtype()), - ) - - def forward( - self, ref: Batch, pred: TensorDict, ddp: Optional[bool] = None - ) -> torch.Tensor: - loss_energy = weighted_mean_absolute_error_energy(ref, pred, ddp) - loss_forces = mean_normed_error_forces(ref, pred, ddp) - return self.energy_weight * loss_energy + self.forces_weight * loss_forces - - def __repr__(self): - return ( - f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, " - f"forces_weight={self.forces_weight:.3f})" - ) +########################################################################################### +# Implementation of different loss functions +# Authors: Ilyes Batatia, Gregor Simm +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +from typing import Optional + +import torch +import torch.distributed as dist + +from mace.tools import TensorDict +from mace.tools.torch_geometric import Batch + + +# ------------------------------------------------------------------------------ +# Helper function for loss reduction that handles DDP correction +# ------------------------------------------------------------------------------ +def is_ddp_enabled(): + return dist.is_initialized() and dist.get_world_size() > 1 + + +def reduce_loss(raw_loss: torch.Tensor, ddp: Optional[bool] = None) -> torch.Tensor: + """ + Reduces an element-wise loss tensor. + + If ddp is True and distributed is initialized, the function computes: + + loss = (local_sum * world_size) / global_num_elements + + Otherwise, it returns the regular mean. + """ + ddp = is_ddp_enabled() if ddp is None else ddp + if ddp and dist.is_initialized(): + world_size = dist.get_world_size() + n_local = raw_loss.numel() + loss_sum = raw_loss.sum() + total_samples = torch.tensor( + n_local, device=raw_loss.device, dtype=raw_loss.dtype + ) + dist.all_reduce(total_samples, op=dist.ReduceOp.SUM) + return loss_sum * world_size / total_samples + return raw_loss.mean() + + +# ------------------------------------------------------------------------------ +# Energy Loss Functions +# ------------------------------------------------------------------------------ + + +def mean_squared_error_energy( + ref: Batch, pred: TensorDict, ddp: Optional[bool] = None +) -> torch.Tensor: + raw_loss = torch.square(ref["energy"] - pred["energy"]) + return reduce_loss(raw_loss, ddp) + + +def weighted_mean_squared_error_energy( + ref: Batch, pred: TensorDict, ddp: Optional[bool] = None +) -> torch.Tensor: + # Calculate per-graph number of atoms. + num_atoms = ref.ptr[1:] - ref.ptr[:-1] # shape: [n_graphs] + raw_loss = ( + ref.weight + * ref.energy_weight + * torch.square((ref["energy"] - pred["energy"]) / num_atoms) + ) + return reduce_loss(raw_loss, ddp) + + +def weighted_mean_absolute_error_energy( + ref: Batch, pred: TensorDict, ddp: Optional[bool] = None +) -> torch.Tensor: + num_atoms = ref.ptr[1:] - ref.ptr[:-1] + raw_loss = ( + ref.weight + * ref.energy_weight + * torch.abs((ref["energy"] - pred["energy"]) / num_atoms) + ) + return reduce_loss(raw_loss, ddp) + + +# ------------------------------------------------------------------------------ +# Stress and Virials Loss Functions +# ------------------------------------------------------------------------------ + + +def weighted_mean_squared_stress( + ref: Batch, pred: TensorDict, ddp: Optional[bool] = None +) -> torch.Tensor: + configs_weight = ref.weight.view(-1, 1, 1) + configs_stress_weight = ref.stress_weight.view(-1, 1, 1) + raw_loss = ( + configs_weight + * configs_stress_weight + * torch.square(ref["stress"] - pred["stress"]) + ) + return reduce_loss(raw_loss, ddp) + + +def weighted_mean_squared_virials( + ref: Batch, pred: TensorDict, ddp: Optional[bool] = None +) -> torch.Tensor: + configs_weight = ref.weight.view(-1, 1, 1) + configs_virials_weight = ref.virials_weight.view(-1, 1, 1) + num_atoms = (ref.ptr[1:] - ref.ptr[:-1]).view(-1, 1, 1) + raw_loss = ( + configs_weight + * configs_virials_weight + * torch.square((ref["virials"] - pred["virials"]) / num_atoms) + ) + return reduce_loss(raw_loss, ddp) + + +# ------------------------------------------------------------------------------ +# Forces Loss Functions +# ------------------------------------------------------------------------------ + + +def mean_squared_error_forces( + ref: Batch, pred: TensorDict, ddp: Optional[bool] = None +) -> torch.Tensor: + # Repeat per-graph weights to per-atom level. + configs_weight = torch.repeat_interleave( + ref.weight, ref.ptr[1:] - ref.ptr[:-1] + ).unsqueeze(-1) + configs_forces_weight = torch.repeat_interleave( + ref.forces_weight, ref.ptr[1:] - ref.ptr[:-1] + ).unsqueeze(-1) + raw_loss = ( + configs_weight + * configs_forces_weight + * torch.square(ref["forces"] - pred["forces"]) + ) + return reduce_loss(raw_loss, ddp) + + +def mean_normed_error_forces( + ref: Batch, pred: TensorDict, ddp: Optional[bool] = None +) -> torch.Tensor: + raw_loss = torch.linalg.vector_norm(ref["forces"] - pred["forces"], ord=2, dim=-1) + return reduce_loss(raw_loss, ddp) + + +# ------------------------------------------------------------------------------ +# Dipole Loss Function +# ------------------------------------------------------------------------------ + + +def weighted_mean_squared_error_dipole( + ref: Batch, pred: TensorDict, ddp: Optional[bool] = None +) -> torch.Tensor: + num_atoms = (ref.ptr[1:] - ref.ptr[:-1]).unsqueeze(-1) + raw_loss = torch.square((ref["dipole"] - pred["dipole"]) / num_atoms) + return reduce_loss(raw_loss, ddp) + + +# ------------------------------------------------------------------------------ +# Conditional Losses for Forces +# ------------------------------------------------------------------------------ + + +def conditional_mse_forces( + ref: Batch, pred: TensorDict, ddp: Optional[bool] = None +) -> torch.Tensor: + configs_weight = torch.repeat_interleave( + ref.weight, ref.ptr[1:] - ref.ptr[:-1] + ).unsqueeze(-1) + configs_forces_weight = torch.repeat_interleave( + ref.forces_weight, ref.ptr[1:] - ref.ptr[:-1] + ).unsqueeze(-1) + # Define multiplication factors for different regimes. + factors = torch.tensor( + [1.0, 0.7, 0.4, 0.1], device=ref["forces"].device, dtype=ref["forces"].dtype + ) + err = ref["forces"] - pred["forces"] + se = torch.zeros_like(err) + norm_forces = torch.norm(ref["forces"], dim=-1) + c1 = norm_forces < 100 + c2 = (norm_forces >= 100) & (norm_forces < 200) + c3 = (norm_forces >= 200) & (norm_forces < 300) + se[c1] = torch.square(err[c1]) * factors[0] + se[c2] = torch.square(err[c2]) * factors[1] + se[c3] = torch.square(err[c3]) * factors[2] + se[~(c1 | c2 | c3)] = torch.square(err[~(c1 | c2 | c3)]) * factors[3] + raw_loss = configs_weight * configs_forces_weight * se + return reduce_loss(raw_loss, ddp) + + +def conditional_huber_forces( + ref_forces: torch.Tensor, + pred_forces: torch.Tensor, + huber_delta: float, + ddp: Optional[bool] = None, +) -> torch.Tensor: + factors = huber_delta * torch.tensor( + [1.0, 0.7, 0.4, 0.1], device=ref_forces.device, dtype=ref_forces.dtype + ) + norm_forces = torch.norm(ref_forces, dim=-1) + c1 = norm_forces < 100 + c2 = (norm_forces >= 100) & (norm_forces < 200) + c3 = (norm_forces >= 200) & (norm_forces < 300) + c4 = ~(c1 | c2 | c3) + se = torch.zeros_like(pred_forces) + se[c1] = torch.nn.functional.huber_loss( + ref_forces[c1], pred_forces[c1], reduction="none", delta=factors[0] + ) + se[c2] = torch.nn.functional.huber_loss( + ref_forces[c2], pred_forces[c2], reduction="none", delta=factors[1] + ) + se[c3] = torch.nn.functional.huber_loss( + ref_forces[c3], pred_forces[c3], reduction="none", delta=factors[2] + ) + se[c4] = torch.nn.functional.huber_loss( + ref_forces[c4], pred_forces[c4], reduction="none", delta=factors[3] + ) + return reduce_loss(se, ddp) + + +# ------------------------------------------------------------------------------ +# Loss Modules Combining Multiple Quantities +# ------------------------------------------------------------------------------ + + +class WeightedEnergyForcesLoss(torch.nn.Module): + def __init__(self, energy_weight=1.0, forces_weight=1.0) -> None: + super().__init__() + self.register_buffer( + "energy_weight", + torch.tensor(energy_weight, dtype=torch.get_default_dtype()), + ) + self.register_buffer( + "forces_weight", + torch.tensor(forces_weight, dtype=torch.get_default_dtype()), + ) + + def forward( + self, ref: Batch, pred: TensorDict, ddp: Optional[bool] = None + ) -> torch.Tensor: + loss_energy = weighted_mean_squared_error_energy(ref, pred, ddp) + loss_forces = mean_squared_error_forces(ref, pred, ddp) + return self.energy_weight * loss_energy + self.forces_weight * loss_forces + + def __repr__(self): + return ( + f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, " + f"forces_weight={self.forces_weight:.3f})" + ) + + +class WeightedForcesLoss(torch.nn.Module): + def __init__(self, forces_weight=1.0) -> None: + super().__init__() + self.register_buffer( + "forces_weight", + torch.tensor(forces_weight, dtype=torch.get_default_dtype()), + ) + + def forward( + self, ref: Batch, pred: TensorDict, ddp: Optional[bool] = None + ) -> torch.Tensor: + loss_forces = mean_squared_error_forces(ref, pred, ddp) + return self.forces_weight * loss_forces + + def __repr__(self): + return f"{self.__class__.__name__}(forces_weight={self.forces_weight:.3f})" + + +class WeightedEnergyForcesStressLoss(torch.nn.Module): + def __init__(self, energy_weight=1.0, forces_weight=1.0, stress_weight=1.0) -> None: + super().__init__() + self.register_buffer( + "energy_weight", + torch.tensor(energy_weight, dtype=torch.get_default_dtype()), + ) + self.register_buffer( + "forces_weight", + torch.tensor(forces_weight, dtype=torch.get_default_dtype()), + ) + self.register_buffer( + "stress_weight", + torch.tensor(stress_weight, dtype=torch.get_default_dtype()), + ) + + def forward( + self, ref: Batch, pred: TensorDict, ddp: Optional[bool] = None + ) -> torch.Tensor: + loss_energy = weighted_mean_squared_error_energy(ref, pred, ddp) + loss_forces = mean_squared_error_forces(ref, pred, ddp) + loss_stress = weighted_mean_squared_stress(ref, pred, ddp) + return ( + self.energy_weight * loss_energy + + self.forces_weight * loss_forces + + self.stress_weight * loss_stress + ) + + def __repr__(self): + return ( + f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, " + f"forces_weight={self.forces_weight:.3f}, stress_weight={self.stress_weight:.3f})" + ) + + +class WeightedHuberEnergyForcesStressLoss(torch.nn.Module): + def __init__( + self, energy_weight=1.0, forces_weight=1.0, stress_weight=1.0, huber_delta=0.01 + ) -> None: + super().__init__() + # We store the huber_delta rather than a loss with fixed reduction. + self.huber_delta = huber_delta + self.register_buffer( + "energy_weight", + torch.tensor(energy_weight, dtype=torch.get_default_dtype()), + ) + self.register_buffer( + "forces_weight", + torch.tensor(forces_weight, dtype=torch.get_default_dtype()), + ) + self.register_buffer( + "stress_weight", + torch.tensor(stress_weight, dtype=torch.get_default_dtype()), + ) + + def forward( + self, ref: Batch, pred: TensorDict, ddp: Optional[bool] = None + ) -> torch.Tensor: + num_atoms = ref.ptr[1:] - ref.ptr[:-1] + if ddp: + loss_energy = torch.nn.functional.huber_loss( + ref["energy"] / num_atoms, + pred["energy"] / num_atoms, + reduction="none", + delta=self.huber_delta, + ) + loss_energy = reduce_loss(loss_energy, ddp) + loss_forces = torch.nn.functional.huber_loss( + ref["forces"], pred["forces"], reduction="none", delta=self.huber_delta + ) + loss_forces = reduce_loss(loss_forces, ddp) + loss_stress = torch.nn.functional.huber_loss( + ref["stress"], pred["stress"], reduction="none", delta=self.huber_delta + ) + loss_stress = reduce_loss(loss_stress, ddp) + else: + loss_energy = torch.nn.functional.huber_loss( + ref["energy"] / num_atoms, + pred["energy"] / num_atoms, + reduction="mean", + delta=self.huber_delta, + ) + loss_forces = torch.nn.functional.huber_loss( + ref["forces"], pred["forces"], reduction="mean", delta=self.huber_delta + ) + loss_stress = torch.nn.functional.huber_loss( + ref["stress"], pred["stress"], reduction="mean", delta=self.huber_delta + ) + return ( + self.energy_weight * loss_energy + + self.forces_weight * loss_forces + + self.stress_weight * loss_stress + ) + + def __repr__(self): + return ( + f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, " + f"forces_weight={self.forces_weight:.3f}, stress_weight={self.stress_weight:.3f})" + ) + + +class UniversalLoss(torch.nn.Module): + def __init__( + self, energy_weight=1.0, forces_weight=1.0, stress_weight=1.0, huber_delta=0.01 + ) -> None: + super().__init__() + self.huber_delta = huber_delta + self.register_buffer( + "energy_weight", + torch.tensor(energy_weight, dtype=torch.get_default_dtype()), + ) + self.register_buffer( + "forces_weight", + torch.tensor(forces_weight, dtype=torch.get_default_dtype()), + ) + self.register_buffer( + "stress_weight", + torch.tensor(stress_weight, dtype=torch.get_default_dtype()), + ) + + def forward( + self, ref: Batch, pred: TensorDict, ddp: Optional[bool] = None + ) -> torch.Tensor: + num_atoms = ref.ptr[1:] - ref.ptr[:-1] + configs_stress_weight = ref.stress_weight.view(-1, 1, 1) + configs_energy_weight = ref.energy_weight + configs_forces_weight = torch.repeat_interleave( + ref.forces_weight, ref.ptr[1:] - ref.ptr[:-1] + ).unsqueeze(-1) + if ddp: + loss_energy = torch.nn.functional.huber_loss( + configs_energy_weight * ref["energy"] / num_atoms, + configs_energy_weight * pred["energy"] / num_atoms, + reduction="none", + delta=self.huber_delta, + ) + loss_energy = reduce_loss(loss_energy, ddp) + loss_forces = conditional_huber_forces( + configs_forces_weight * ref["forces"], + configs_forces_weight * pred["forces"], + huber_delta=self.huber_delta, + ddp=ddp, + ) + loss_stress = torch.nn.functional.huber_loss( + configs_stress_weight * ref["stress"], + configs_stress_weight * pred["stress"], + reduction="none", + delta=self.huber_delta, + ) + loss_stress = reduce_loss(loss_stress, ddp) + else: + loss_energy = torch.nn.functional.huber_loss( + configs_energy_weight * ref["energy"] / num_atoms, + configs_energy_weight * pred["energy"] / num_atoms, + reduction="mean", + delta=self.huber_delta, + ) + loss_forces = conditional_huber_forces( + configs_forces_weight * ref["forces"], + configs_forces_weight * pred["forces"], + huber_delta=self.huber_delta, + ddp=ddp, + ) + loss_stress = torch.nn.functional.huber_loss( + configs_stress_weight * ref["stress"], + configs_stress_weight * pred["stress"], + reduction="mean", + delta=self.huber_delta, + ) + return ( + self.energy_weight * loss_energy + + self.forces_weight * loss_forces + + self.stress_weight * loss_stress + ) + + def __repr__(self): + return ( + f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, " + f"forces_weight={self.forces_weight:.3f}, stress_weight={self.stress_weight:.3f})" + ) + + +class WeightedEnergyForcesVirialsLoss(torch.nn.Module): + def __init__( + self, energy_weight=1.0, forces_weight=1.0, virials_weight=1.0 + ) -> None: + super().__init__() + self.register_buffer( + "energy_weight", + torch.tensor(energy_weight, dtype=torch.get_default_dtype()), + ) + self.register_buffer( + "forces_weight", + torch.tensor(forces_weight, dtype=torch.get_default_dtype()), + ) + self.register_buffer( + "virials_weight", + torch.tensor(virials_weight, dtype=torch.get_default_dtype()), + ) + + def forward( + self, ref: Batch, pred: TensorDict, ddp: Optional[bool] = None + ) -> torch.Tensor: + loss_energy = weighted_mean_squared_error_energy(ref, pred, ddp) + loss_forces = mean_squared_error_forces(ref, pred, ddp) + loss_virials = weighted_mean_squared_virials(ref, pred, ddp) + return ( + self.energy_weight * loss_energy + + self.forces_weight * loss_forces + + self.virials_weight * loss_virials + ) + + def __repr__(self): + return ( + f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, " + f"forces_weight={self.forces_weight:.3f}, virials_weight={self.virials_weight:.3f})" + ) + + +class DipoleSingleLoss(torch.nn.Module): + def __init__(self, dipole_weight=1.0) -> None: + super().__init__() + self.register_buffer( + "dipole_weight", + torch.tensor(dipole_weight, dtype=torch.get_default_dtype()), + ) + + def forward( + self, ref: Batch, pred: TensorDict, ddp: Optional[bool] = None + ) -> torch.Tensor: + loss = ( + weighted_mean_squared_error_dipole(ref, pred, ddp) * 100.0 + ) # scale adjustment + return self.dipole_weight * loss + + def __repr__(self): + return f"{self.__class__.__name__}(dipole_weight={self.dipole_weight:.3f})" + + +class WeightedEnergyForcesDipoleLoss(torch.nn.Module): + def __init__(self, energy_weight=1.0, forces_weight=1.0, dipole_weight=1.0) -> None: + super().__init__() + self.register_buffer( + "energy_weight", + torch.tensor(energy_weight, dtype=torch.get_default_dtype()), + ) + self.register_buffer( + "forces_weight", + torch.tensor(forces_weight, dtype=torch.get_default_dtype()), + ) + self.register_buffer( + "dipole_weight", + torch.tensor(dipole_weight, dtype=torch.get_default_dtype()), + ) + + def forward( + self, ref: Batch, pred: TensorDict, ddp: Optional[bool] = None + ) -> torch.Tensor: + loss_energy = weighted_mean_squared_error_energy(ref, pred, ddp) + loss_forces = mean_squared_error_forces(ref, pred, ddp) + loss_dipole = weighted_mean_squared_error_dipole(ref, pred, ddp) * 100.0 + return ( + self.energy_weight * loss_energy + + self.forces_weight * loss_forces + + self.dipole_weight * loss_dipole + ) + + def __repr__(self): + return ( + f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, " + f"forces_weight={self.forces_weight:.3f}, dipole_weight={self.dipole_weight:.3f})" + ) + + +class WeightedEnergyForcesL1L2Loss(torch.nn.Module): + def __init__(self, energy_weight=1.0, forces_weight=1.0) -> None: + super().__init__() + self.register_buffer( + "energy_weight", + torch.tensor(energy_weight, dtype=torch.get_default_dtype()), + ) + self.register_buffer( + "forces_weight", + torch.tensor(forces_weight, dtype=torch.get_default_dtype()), + ) + + def forward( + self, ref: Batch, pred: TensorDict, ddp: Optional[bool] = None + ) -> torch.Tensor: + loss_energy = weighted_mean_absolute_error_energy(ref, pred, ddp) + loss_forces = mean_normed_error_forces(ref, pred, ddp) + return self.energy_weight * loss_energy + self.forces_weight * loss_forces + + def __repr__(self): + return ( + f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, " + f"forces_weight={self.forces_weight:.3f})" + ) diff --git a/mace-bench/3rdparty/mace/mace/modules/models.py b/mace-bench/3rdparty/mace/mace/modules/models.py index c6ba5bc..b551f8b 100644 --- a/mace-bench/3rdparty/mace/mace/modules/models.py +++ b/mace-bench/3rdparty/mace/mace/modules/models.py @@ -1,947 +1,947 @@ -########################################################################################### -# Implementation of MACE models and other models based E(3)-Equivariant MPNNs -# Authors: Ilyes Batatia, Gregor Simm -# This program is distributed under the MIT License (see MIT.md) -########################################################################################### - -from typing import Any, Callable, Dict, List, Optional, Type, Union - -import numpy as np -import torch -from e3nn import o3 -from e3nn.util.jit import compile_mode - -from mace.modules.radial import ZBLBasis -from mace.tools.scatter import scatter_sum - -from .blocks import ( - AtomicEnergiesBlock, - EquivariantProductBasisBlock, - InteractionBlock, - LinearDipoleReadoutBlock, - LinearNodeEmbeddingBlock, - LinearReadoutBlock, - NonLinearDipoleReadoutBlock, - NonLinearReadoutBlock, - RadialEmbeddingBlock, - ScaleShiftBlock, -) -from .utils import ( - compute_fixed_charge_dipole, - get_atomic_virials_stresses, - get_edge_vectors_and_lengths, - get_outputs, - get_symmetric_displacement, - prepare_graph, -) - -# pylint: disable=C0302 - - -@compile_mode("script") -class MACE(torch.nn.Module): - def __init__( - self, - r_max: float, - num_bessel: int, - num_polynomial_cutoff: int, - max_ell: int, - interaction_cls: Type[InteractionBlock], - interaction_cls_first: Type[InteractionBlock], - num_interactions: int, - num_elements: int, - hidden_irreps: o3.Irreps, - MLP_irreps: o3.Irreps, - atomic_energies: np.ndarray, - avg_num_neighbors: float, - atomic_numbers: List[int], - correlation: Union[int, List[int]], - gate: Optional[Callable], - pair_repulsion: bool = False, - distance_transform: str = "None", - radial_MLP: Optional[List[int]] = None, - radial_type: Optional[str] = "bessel", - heads: Optional[List[str]] = None, - cueq_config: Optional[Dict[str, Any]] = None, - lammps_mliap: Optional[bool] = False, - ): - super().__init__() - self.register_buffer( - "atomic_numbers", torch.tensor(atomic_numbers, dtype=torch.int64) - ) - self.register_buffer( - "r_max", torch.tensor(r_max, dtype=torch.get_default_dtype()) - ) - self.register_buffer( - "num_interactions", torch.tensor(num_interactions, dtype=torch.int64) - ) - if heads is None: - heads = ["Default"] - self.heads = heads - if isinstance(correlation, int): - correlation = [correlation] * num_interactions - self.lammps_mliap = lammps_mliap - # Embedding - node_attr_irreps = o3.Irreps([(num_elements, (0, 1))]) - node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))]) - self.node_embedding = LinearNodeEmbeddingBlock( - irreps_in=node_attr_irreps, - irreps_out=node_feats_irreps, - cueq_config=cueq_config, - ) - self.radial_embedding = RadialEmbeddingBlock( - r_max=r_max, - num_bessel=num_bessel, - num_polynomial_cutoff=num_polynomial_cutoff, - radial_type=radial_type, - distance_transform=distance_transform, - ) - edge_feats_irreps = o3.Irreps(f"{self.radial_embedding.out_dim}x0e") - if pair_repulsion: - self.pair_repulsion_fn = ZBLBasis(p=num_polynomial_cutoff) - self.pair_repulsion = True - - sh_irreps = o3.Irreps.spherical_harmonics(max_ell) - num_features = hidden_irreps.count(o3.Irrep(0, 1)) - interaction_irreps = (sh_irreps * num_features).sort()[0].simplify() - self.spherical_harmonics = o3.SphericalHarmonics( - sh_irreps, normalize=True, normalization="component" - ) - if radial_MLP is None: - radial_MLP = [64, 64, 64] - # Interactions and readout - self.atomic_energies_fn = AtomicEnergiesBlock(atomic_energies) - - inter = interaction_cls_first( - node_attrs_irreps=node_attr_irreps, - node_feats_irreps=node_feats_irreps, - edge_attrs_irreps=sh_irreps, - edge_feats_irreps=edge_feats_irreps, - target_irreps=interaction_irreps, - hidden_irreps=hidden_irreps, - avg_num_neighbors=avg_num_neighbors, - radial_MLP=radial_MLP, - cueq_config=cueq_config, - ) - self.interactions = torch.nn.ModuleList([inter]) - - # Use the appropriate self connection at the first layer for proper E0 - use_sc_first = False - if "Residual" in str(interaction_cls_first): - use_sc_first = True - - node_feats_irreps_out = inter.target_irreps - prod = EquivariantProductBasisBlock( - node_feats_irreps=node_feats_irreps_out, - target_irreps=hidden_irreps, - correlation=correlation[0], - num_elements=num_elements, - use_sc=use_sc_first, - cueq_config=cueq_config, - ) - self.products = torch.nn.ModuleList([prod]) - - self.readouts = torch.nn.ModuleList() - self.readouts.append( - LinearReadoutBlock( - hidden_irreps, o3.Irreps(f"{len(heads)}x0e"), cueq_config - ) - ) - - for i in range(num_interactions - 1): - if i == num_interactions - 2: - hidden_irreps_out = str( - hidden_irreps[0] - ) # Select only scalars for last layer - else: - hidden_irreps_out = hidden_irreps - inter = interaction_cls( - node_attrs_irreps=node_attr_irreps, - node_feats_irreps=hidden_irreps, - edge_attrs_irreps=sh_irreps, - edge_feats_irreps=edge_feats_irreps, - target_irreps=interaction_irreps, - hidden_irreps=hidden_irreps_out, - avg_num_neighbors=avg_num_neighbors, - radial_MLP=radial_MLP, - cueq_config=cueq_config, - ) - self.interactions.append(inter) - prod = EquivariantProductBasisBlock( - node_feats_irreps=interaction_irreps, - target_irreps=hidden_irreps_out, - correlation=correlation[i + 1], - num_elements=num_elements, - use_sc=True, - cueq_config=cueq_config, - ) - self.products.append(prod) - if i == num_interactions - 2: - self.readouts.append( - NonLinearReadoutBlock( - hidden_irreps_out, - (len(heads) * MLP_irreps).simplify(), - gate, - o3.Irreps(f"{len(heads)}x0e"), - len(heads), - cueq_config, - ) - ) - else: - self.readouts.append( - LinearReadoutBlock( - hidden_irreps, o3.Irreps(f"{len(heads)}x0e"), cueq_config - ) - ) - - def forward( - self, - data: Dict[str, torch.Tensor], - training: bool = False, - compute_force: bool = True, - compute_virials: bool = False, - compute_stress: bool = False, - compute_displacement: bool = False, - compute_hessian: bool = False, - compute_edge_forces: bool = False, - compute_atomic_stresses: bool = False, - lammps_mliap: bool = False, - ) -> Dict[str, Optional[torch.Tensor]]: - # Setup - ctx = prepare_graph( - data, - compute_virials=compute_virials, - compute_stress=compute_stress, - compute_displacement=compute_displacement, - lammps_mliap=lammps_mliap, - ) - is_lammps = ctx.is_lammps - num_atoms_arange = ctx.num_atoms_arange - num_graphs = ctx.num_graphs - displacement = ctx.displacement - positions = ctx.positions - vectors = ctx.vectors - lengths = ctx.lengths - cell = ctx.cell - node_heads = ctx.node_heads - interaction_kwargs = ctx.interaction_kwargs - lammps_natoms = interaction_kwargs.lammps_natoms - lammps_class = interaction_kwargs.lammps_class - - # Atomic energies - node_e0 = self.atomic_energies_fn(data["node_attrs"])[ - num_atoms_arange, node_heads - ] - e0 = scatter_sum( - src=node_e0, index=data["batch"], dim=0, dim_size=num_graphs - ) # [n_graphs, n_heads] - # Embeddings - node_feats = self.node_embedding(data["node_attrs"]) - edge_attrs = self.spherical_harmonics(vectors) - edge_feats = self.radial_embedding( - lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers - ) - if hasattr(self, "pair_repulsion"): - pair_node_energy = self.pair_repulsion_fn( - lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers - ) - if is_lammps: - pair_node_energy = pair_node_energy[: lammps_natoms[0]] - pair_energy = scatter_sum( - src=pair_node_energy, index=data["batch"], dim=-1, dim_size=num_graphs - ) # [n_graphs,] - else: - pair_node_energy = torch.zeros_like(node_e0) - pair_energy = torch.zeros_like(e0) - - # Interactions - energies = [e0, pair_energy] - node_energies_list = [node_e0, pair_node_energy] - node_feats_concat: List[torch.Tensor] = [] - - for i, (interaction, product, readout) in enumerate( - zip(self.interactions, self.products, self.readouts) - ): - node_attrs_slice = data["node_attrs"] - if is_lammps and i > 0: - node_attrs_slice = node_attrs_slice[: lammps_natoms[0]] - node_feats, sc = interaction( - node_attrs=node_attrs_slice, - node_feats=node_feats, - edge_attrs=edge_attrs, - edge_feats=edge_feats, - edge_index=data["edge_index"], - first_layer=(i == 0), - lammps_class=lammps_class, - lammps_natoms=lammps_natoms, - ) - if is_lammps and i == 0: - node_attrs_slice = node_attrs_slice[: lammps_natoms[0]] - node_feats = product( - node_feats=node_feats, sc=sc, node_attrs=node_attrs_slice - ) - node_feats_concat.append(node_feats) - node_es = readout(node_feats, node_heads)[num_atoms_arange, node_heads] - energy = scatter_sum(node_es, data["batch"], dim=0, dim_size=num_graphs) - energies.append(energy) - node_energies_list.append(node_es) - - contributions = torch.stack(energies, dim=-1) - total_energy = torch.sum(contributions, dim=-1) - node_energy = torch.sum(torch.stack(node_energies_list, dim=-1), dim=-1) - node_feats_out = torch.cat(node_feats_concat, dim=-1) - node_energy = node_e0.double() + pair_node_energy.double() - - forces, virials, stress, hessian, edge_forces = get_outputs( - energy=total_energy, - positions=positions, - displacement=displacement, - vectors=vectors, - cell=cell, - training=training, - compute_force=compute_force, - compute_virials=compute_virials, - compute_stress=compute_stress, - compute_hessian=compute_hessian, - compute_edge_forces=compute_edge_forces, - ) - - atomic_virials: Optional[torch.Tensor] = None - atomic_stresses: Optional[torch.Tensor] = None - if compute_atomic_stresses and edge_forces is not None: - atomic_virials, atomic_stresses = get_atomic_virials_stresses( - edge_forces=edge_forces, - edge_index=data["edge_index"], - vectors=vectors, - num_atoms=positions.shape[0], - batch=data["batch"], - cell=cell, - ) - return { - "energy": total_energy, - "node_energy": node_energy, - "contributions": contributions, - "forces": forces, - "edge_forces": edge_forces, - "virials": virials, - "stress": stress, - "atomic_virials": atomic_virials, - "atomic_stresses": atomic_stresses, - "displacement": displacement, - "hessian": hessian, - "node_feats": node_feats_out, - } - - -@compile_mode("script") -class ScaleShiftMACE(MACE): - def __init__( - self, - atomic_inter_scale: float, - atomic_inter_shift: float, - **kwargs, - ): - super().__init__(**kwargs) - self.scale_shift = ScaleShiftBlock( - scale=atomic_inter_scale, shift=atomic_inter_shift - ) - - def forward( - self, - data: Dict[str, torch.Tensor], - training: bool = False, - compute_force: bool = True, - compute_virials: bool = False, - compute_stress: bool = False, - compute_displacement: bool = False, - compute_hessian: bool = False, - compute_edge_forces: bool = False, - compute_atomic_stresses: bool = False, - lammps_mliap: bool = False, - ) -> Dict[str, Optional[torch.Tensor]]: - # Setup - ctx = prepare_graph( - data, - compute_virials=compute_virials, - compute_stress=compute_stress, - compute_displacement=compute_displacement, - lammps_mliap=lammps_mliap, - ) - - is_lammps = ctx.is_lammps - num_atoms_arange = ctx.num_atoms_arange - num_graphs = ctx.num_graphs - displacement = ctx.displacement - positions = ctx.positions - vectors = ctx.vectors - lengths = ctx.lengths - cell = ctx.cell - node_heads = ctx.node_heads - interaction_kwargs = ctx.interaction_kwargs - lammps_natoms = interaction_kwargs.lammps_natoms - lammps_class = interaction_kwargs.lammps_class - - # Atomic energies - node_e0 = self.atomic_energies_fn(data["node_attrs"])[ - num_atoms_arange, node_heads - ] - e0 = scatter_sum( - src=node_e0, index=data["batch"], dim=0, dim_size=num_graphs - ) # [n_graphs, num_heads] - - # Embeddings - node_feats = self.node_embedding(data["node_attrs"]) - edge_attrs = self.spherical_harmonics(vectors) - edge_feats = self.radial_embedding( - lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers - ) - - if hasattr(self, "pair_repulsion"): - pair_node_energy = self.pair_repulsion_fn( - lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers - ) - if is_lammps: - pair_node_energy = pair_node_energy[: lammps_natoms[0]] - else: - pair_node_energy = torch.zeros_like(node_e0) - - # Interactions - node_es_list = [pair_node_energy] - node_feats_list: List[torch.Tensor] = [] - - for i, (interaction, product, readout) in enumerate( - zip(self.interactions, self.products, self.readouts) - ): - node_attrs_slice = data["node_attrs"] - if is_lammps and i > 0: - node_attrs_slice = node_attrs_slice[: lammps_natoms[0]] - node_feats, sc = interaction( - node_attrs=node_attrs_slice, - node_feats=node_feats, - edge_attrs=edge_attrs, - edge_feats=edge_feats, - edge_index=data["edge_index"], - first_layer=(i == 0), - lammps_class=lammps_class, - lammps_natoms=lammps_natoms, - ) - if is_lammps and i == 0: - node_attrs_slice = node_attrs_slice[: lammps_natoms[0]] - node_feats = product( - node_feats=node_feats, sc=sc, node_attrs=node_attrs_slice - ) - node_feats_list.append(node_feats) - node_es_list.append( - readout(node_feats, node_heads)[num_atoms_arange, node_heads] - ) - - node_feats_out = torch.cat(node_feats_list, dim=-1) - node_inter_es = torch.sum(torch.stack(node_es_list, dim=0), dim=0) - node_inter_es = self.scale_shift(node_inter_es, node_heads) - inter_e = scatter_sum(node_inter_es, data["batch"], dim=-1, dim_size=num_graphs) - - total_energy = e0 + inter_e - node_energy = node_e0.clone().double() + node_inter_es.clone().double() - - forces, virials, stress, hessian, edge_forces = get_outputs( - energy=inter_e, - positions=positions, - displacement=displacement, - vectors=vectors, - cell=cell, - training=training, - compute_force=compute_force, - compute_virials=compute_virials, - compute_stress=compute_stress, - compute_hessian=compute_hessian, - compute_edge_forces=compute_edge_forces or compute_atomic_stresses, - ) - - atomic_virials: Optional[torch.Tensor] = None - atomic_stresses: Optional[torch.Tensor] = None - if compute_atomic_stresses and edge_forces is not None: - atomic_virials, atomic_stresses = get_atomic_virials_stresses( - edge_forces=edge_forces, - edge_index=data["edge_index"], - vectors=vectors, - num_atoms=positions.shape[0], - batch=data["batch"], - cell=cell, - ) - return { - "energy": total_energy, - "node_energy": node_energy, - "interaction_energy": inter_e, - "forces": forces, - "edge_forces": edge_forces, - "virials": virials, - "stress": stress, - "atomic_virials": atomic_virials, - "atomic_stresses": atomic_stresses, - "hessian": hessian, - "displacement": displacement, - "node_feats": node_feats_out, - } - - -@compile_mode("script") -class AtomicDipolesMACE(torch.nn.Module): - def __init__( - self, - r_max: float, - num_bessel: int, - num_polynomial_cutoff: int, - max_ell: int, - interaction_cls: Type[InteractionBlock], - interaction_cls_first: Type[InteractionBlock], - num_interactions: int, - num_elements: int, - hidden_irreps: o3.Irreps, - MLP_irreps: o3.Irreps, - avg_num_neighbors: float, - atomic_numbers: List[int], - correlation: int, - gate: Optional[Callable], - atomic_energies: Optional[ - None - ], # Just here to make it compatible with energy models, MUST be None - radial_type: Optional[str] = "bessel", - radial_MLP: Optional[List[int]] = None, - cueq_config: Optional[Dict[str, Any]] = None, # pylint: disable=unused-argument - ): - super().__init__() - self.register_buffer( - "atomic_numbers", torch.tensor(atomic_numbers, dtype=torch.int64) - ) - self.register_buffer("r_max", torch.tensor(r_max, dtype=torch.float64)) - self.register_buffer( - "num_interactions", torch.tensor(num_interactions, dtype=torch.int64) - ) - assert atomic_energies is None - - # Embedding - node_attr_irreps = o3.Irreps([(num_elements, (0, 1))]) - node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))]) - self.node_embedding = LinearNodeEmbeddingBlock( - irreps_in=node_attr_irreps, irreps_out=node_feats_irreps - ) - self.radial_embedding = RadialEmbeddingBlock( - r_max=r_max, - num_bessel=num_bessel, - num_polynomial_cutoff=num_polynomial_cutoff, - radial_type=radial_type, - ) - edge_feats_irreps = o3.Irreps(f"{self.radial_embedding.out_dim}x0e") - - sh_irreps = o3.Irreps.spherical_harmonics(max_ell) - num_features = hidden_irreps.count(o3.Irrep(0, 1)) - interaction_irreps = (sh_irreps * num_features).sort()[0].simplify() - self.spherical_harmonics = o3.SphericalHarmonics( - sh_irreps, normalize=True, normalization="component" - ) - if radial_MLP is None: - radial_MLP = [64, 64, 64] - - # Interactions and readouts - inter = interaction_cls_first( - node_attrs_irreps=node_attr_irreps, - node_feats_irreps=node_feats_irreps, - edge_attrs_irreps=sh_irreps, - edge_feats_irreps=edge_feats_irreps, - target_irreps=interaction_irreps, - hidden_irreps=hidden_irreps, - avg_num_neighbors=avg_num_neighbors, - radial_MLP=radial_MLP, - ) - self.interactions = torch.nn.ModuleList([inter]) - - # Use the appropriate self connection at the first layer - use_sc_first = False - if "Residual" in str(interaction_cls_first): - use_sc_first = True - - node_feats_irreps_out = inter.target_irreps - prod = EquivariantProductBasisBlock( - node_feats_irreps=node_feats_irreps_out, - target_irreps=hidden_irreps, - correlation=correlation, - num_elements=num_elements, - use_sc=use_sc_first, - ) - self.products = torch.nn.ModuleList([prod]) - - self.readouts = torch.nn.ModuleList() - self.readouts.append(LinearDipoleReadoutBlock(hidden_irreps, dipole_only=True)) - - for i in range(num_interactions - 1): - if i == num_interactions - 2: - assert ( - len(hidden_irreps) > 1 - ), "To predict dipoles use at least l=1 hidden_irreps" - hidden_irreps_out = str( - hidden_irreps[1] - ) # Select only l=1 vectors for last layer - else: - hidden_irreps_out = hidden_irreps - inter = interaction_cls( - node_attrs_irreps=node_attr_irreps, - node_feats_irreps=hidden_irreps, - edge_attrs_irreps=sh_irreps, - edge_feats_irreps=edge_feats_irreps, - target_irreps=interaction_irreps, - hidden_irreps=hidden_irreps_out, - avg_num_neighbors=avg_num_neighbors, - radial_MLP=radial_MLP, - ) - self.interactions.append(inter) - prod = EquivariantProductBasisBlock( - node_feats_irreps=interaction_irreps, - target_irreps=hidden_irreps_out, - correlation=correlation, - num_elements=num_elements, - use_sc=True, - ) - self.products.append(prod) - if i == num_interactions - 2: - self.readouts.append( - NonLinearDipoleReadoutBlock( - hidden_irreps_out, MLP_irreps, gate, dipole_only=True - ) - ) - else: - self.readouts.append( - LinearDipoleReadoutBlock(hidden_irreps, dipole_only=True) - ) - - def forward( - self, - data: Dict[str, torch.Tensor], - training: bool = False, # pylint: disable=W0613 - compute_force: bool = False, - compute_virials: bool = False, - compute_stress: bool = False, - compute_displacement: bool = False, - compute_edge_forces: bool = False, # pylint: disable=W0613 - compute_atomic_stresses: bool = False, # pylint: disable=W0613 - ) -> Dict[str, Optional[torch.Tensor]]: - assert compute_force is False - assert compute_virials is False - assert compute_stress is False - assert compute_displacement is False - # Setup - data["node_attrs"].requires_grad_(True) - data["positions"].requires_grad_(True) - num_graphs = data["ptr"].numel() - 1 - - # Embeddings - node_feats = self.node_embedding(data["node_attrs"]) - vectors, lengths = get_edge_vectors_and_lengths( - positions=data["positions"], - edge_index=data["edge_index"], - shifts=data["shifts"], - ) - edge_attrs = self.spherical_harmonics(vectors) - edge_feats = self.radial_embedding( - lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers - ) - - # Interactions - dipoles = [] - for interaction, product, readout in zip( - self.interactions, self.products, self.readouts - ): - node_feats, sc = interaction( - node_attrs=data["node_attrs"], - node_feats=node_feats, - edge_attrs=edge_attrs, - edge_feats=edge_feats, - edge_index=data["edge_index"], - ) - node_feats = product( - node_feats=node_feats, - sc=sc, - node_attrs=data["node_attrs"], - ) - node_dipoles = readout(node_feats).squeeze(-1) # [n_nodes,3] - dipoles.append(node_dipoles) - - # Compute the dipoles - contributions_dipoles = torch.stack( - dipoles, dim=-1 - ) # [n_nodes,3,n_contributions] - atomic_dipoles = torch.sum(contributions_dipoles, dim=-1) # [n_nodes,3] - total_dipole = scatter_sum( - src=atomic_dipoles, - index=data["batch"], - dim=0, - dim_size=num_graphs, - ) # [n_graphs,3] - baseline = compute_fixed_charge_dipole( - charges=data["charges"], - positions=data["positions"], - batch=data["batch"], - num_graphs=num_graphs, - ) # [n_graphs,3] - total_dipole = total_dipole + baseline - - output = { - "dipole": total_dipole, - "atomic_dipoles": atomic_dipoles, - } - return output - - -@compile_mode("script") -class EnergyDipolesMACE(torch.nn.Module): - def __init__( - self, - r_max: float, - num_bessel: int, - num_polynomial_cutoff: int, - max_ell: int, - interaction_cls: Type[InteractionBlock], - interaction_cls_first: Type[InteractionBlock], - num_interactions: int, - num_elements: int, - hidden_irreps: o3.Irreps, - MLP_irreps: o3.Irreps, - avg_num_neighbors: float, - atomic_numbers: List[int], - correlation: int, - gate: Optional[Callable], - atomic_energies: Optional[np.ndarray], - radial_MLP: Optional[List[int]] = None, - cueq_config: Optional[Dict[str, Any]] = None, # pylint: disable=unused-argument - ): - super().__init__() - self.register_buffer( - "atomic_numbers", torch.tensor(atomic_numbers, dtype=torch.int64) - ) - self.register_buffer("r_max", torch.tensor(r_max, dtype=torch.float64)) - self.register_buffer( - "num_interactions", torch.tensor(num_interactions, dtype=torch.int64) - ) - # Embedding - node_attr_irreps = o3.Irreps([(num_elements, (0, 1))]) - node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))]) - self.node_embedding = LinearNodeEmbeddingBlock( - irreps_in=node_attr_irreps, irreps_out=node_feats_irreps - ) - self.radial_embedding = RadialEmbeddingBlock( - r_max=r_max, - num_bessel=num_bessel, - num_polynomial_cutoff=num_polynomial_cutoff, - ) - edge_feats_irreps = o3.Irreps(f"{self.radial_embedding.out_dim}x0e") - - sh_irreps = o3.Irreps.spherical_harmonics(max_ell) - num_features = hidden_irreps.count(o3.Irrep(0, 1)) - interaction_irreps = (sh_irreps * num_features).sort()[0].simplify() - self.spherical_harmonics = o3.SphericalHarmonics( - sh_irreps, normalize=True, normalization="component" - ) - if radial_MLP is None: - radial_MLP = [64, 64, 64] - # Interactions and readouts - self.atomic_energies_fn = AtomicEnergiesBlock(atomic_energies) - - inter = interaction_cls_first( - node_attrs_irreps=node_attr_irreps, - node_feats_irreps=node_feats_irreps, - edge_attrs_irreps=sh_irreps, - edge_feats_irreps=edge_feats_irreps, - target_irreps=interaction_irreps, - hidden_irreps=hidden_irreps, - avg_num_neighbors=avg_num_neighbors, - radial_MLP=radial_MLP, - ) - self.interactions = torch.nn.ModuleList([inter]) - - # Use the appropriate self connection at the first layer - use_sc_first = False - if "Residual" in str(interaction_cls_first): - use_sc_first = True - - node_feats_irreps_out = inter.target_irreps - prod = EquivariantProductBasisBlock( - node_feats_irreps=node_feats_irreps_out, - target_irreps=hidden_irreps, - correlation=correlation, - num_elements=num_elements, - use_sc=use_sc_first, - ) - self.products = torch.nn.ModuleList([prod]) - - self.readouts = torch.nn.ModuleList() - self.readouts.append(LinearDipoleReadoutBlock(hidden_irreps, dipole_only=False)) - - for i in range(num_interactions - 1): - if i == num_interactions - 2: - assert ( - len(hidden_irreps) > 1 - ), "To predict dipoles use at least l=1 hidden_irreps" - hidden_irreps_out = str( - hidden_irreps[:2] - ) # Select scalars and l=1 vectors for last layer - else: - hidden_irreps_out = hidden_irreps - inter = interaction_cls( - node_attrs_irreps=node_attr_irreps, - node_feats_irreps=hidden_irreps, - edge_attrs_irreps=sh_irreps, - edge_feats_irreps=edge_feats_irreps, - target_irreps=interaction_irreps, - hidden_irreps=hidden_irreps_out, - avg_num_neighbors=avg_num_neighbors, - radial_MLP=radial_MLP, - ) - self.interactions.append(inter) - prod = EquivariantProductBasisBlock( - node_feats_irreps=interaction_irreps, - target_irreps=hidden_irreps_out, - correlation=correlation, - num_elements=num_elements, - use_sc=True, - ) - self.products.append(prod) - if i == num_interactions - 2: - self.readouts.append( - NonLinearDipoleReadoutBlock( - hidden_irreps_out, MLP_irreps, gate, dipole_only=False - ) - ) - else: - self.readouts.append( - LinearDipoleReadoutBlock(hidden_irreps, dipole_only=False) - ) - - def forward( - self, - data: Dict[str, torch.Tensor], - training: bool = False, - compute_force: bool = True, - compute_virials: bool = False, - compute_stress: bool = False, - compute_displacement: bool = False, - compute_edge_forces: bool = False, # pylint: disable=W0613 - compute_atomic_stresses: bool = False, # pylint: disable=W0613 - ) -> Dict[str, Optional[torch.Tensor]]: - # Setup - data["node_attrs"].requires_grad_(True) - data["positions"].requires_grad_(True) - num_graphs = data["ptr"].numel() - 1 - num_atoms_arange = torch.arange(data["positions"].shape[0]) - displacement = torch.zeros( - (num_graphs, 3, 3), - dtype=data["positions"].dtype, - device=data["positions"].device, - ) - if compute_virials or compute_stress or compute_displacement: - ( - data["positions"], - data["shifts"], - displacement, - ) = get_symmetric_displacement( - positions=data["positions"], - unit_shifts=data["unit_shifts"], - cell=data["cell"], - edge_index=data["edge_index"], - num_graphs=num_graphs, - batch=data["batch"], - ) - - # Atomic energies - node_e0 = self.atomic_energies_fn(data["node_attrs"])[ - num_atoms_arange, data["head"][data["batch"]] - ] - e0 = scatter_sum( - src=node_e0, index=data["batch"], dim=-1, dim_size=num_graphs - ) # [n_graphs,] - - # Embeddings - node_feats = self.node_embedding(data["node_attrs"]) - vectors, lengths = get_edge_vectors_and_lengths( - positions=data["positions"], - edge_index=data["edge_index"], - shifts=data["shifts"], - ) - edge_attrs = self.spherical_harmonics(vectors) - edge_feats = self.radial_embedding( - lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers - ) - - # Interactions - energies = [e0] - node_energies_list = [node_e0] - dipoles = [] - for interaction, product, readout in zip( - self.interactions, self.products, self.readouts - ): - node_feats, sc = interaction( - node_attrs=data["node_attrs"], - node_feats=node_feats, - edge_attrs=edge_attrs, - edge_feats=edge_feats, - edge_index=data["edge_index"], - ) - node_feats = product( - node_feats=node_feats, - sc=sc, - node_attrs=data["node_attrs"], - ) - node_out = readout(node_feats).squeeze(-1) # [n_nodes, ] - # node_energies = readout(node_feats).squeeze(-1) # [n_nodes, ] - node_energies = node_out[:, 0] - energy = scatter_sum( - src=node_energies, index=data["batch"], dim=-1, dim_size=num_graphs - ) # [n_graphs,] - energies.append(energy) - node_dipoles = node_out[:, 1:] - dipoles.append(node_dipoles) - - # Compute the energies and dipoles - contributions = torch.stack(energies, dim=-1) - total_energy = torch.sum(contributions, dim=-1) # [n_graphs, ] - node_energy_contributions = torch.stack(node_energies_list, dim=-1) - node_energy = torch.sum(node_energy_contributions, dim=-1) # [n_nodes, ] - contributions_dipoles = torch.stack( - dipoles, dim=-1 - ) # [n_nodes,3,n_contributions] - atomic_dipoles = torch.sum(contributions_dipoles, dim=-1) # [n_nodes,3] - total_dipole = scatter_sum( - src=atomic_dipoles, - index=data["batch"].unsqueeze(-1), - dim=0, - dim_size=num_graphs, - ) # [n_graphs,3] - baseline = compute_fixed_charge_dipole( - charges=data["charges"], - positions=data["positions"], - batch=data["batch"], - num_graphs=num_graphs, - ) # [n_graphs,3] - total_dipole = total_dipole + baseline - - forces, virials, stress, _, _ = get_outputs( - energy=total_energy, - positions=data["positions"], - displacement=displacement, - cell=data["cell"], - training=training, - compute_force=compute_force, - compute_virials=compute_virials, - compute_stress=compute_stress, - ) - - output = { - "energy": total_energy, - "node_energy": node_energy, - "contributions": contributions, - "forces": forces, - "virials": virials, - "stress": stress, - "displacement": displacement, - "dipole": total_dipole, - "atomic_dipoles": atomic_dipoles, - } - return output +########################################################################################### +# Implementation of MACE models and other models based E(3)-Equivariant MPNNs +# Authors: Ilyes Batatia, Gregor Simm +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +from typing import Any, Callable, Dict, List, Optional, Type, Union + +import numpy as np +import torch +from e3nn import o3 +from e3nn.util.jit import compile_mode + +from mace.modules.radial import ZBLBasis +from mace.tools.scatter import scatter_sum + +from .blocks import ( + AtomicEnergiesBlock, + EquivariantProductBasisBlock, + InteractionBlock, + LinearDipoleReadoutBlock, + LinearNodeEmbeddingBlock, + LinearReadoutBlock, + NonLinearDipoleReadoutBlock, + NonLinearReadoutBlock, + RadialEmbeddingBlock, + ScaleShiftBlock, +) +from .utils import ( + compute_fixed_charge_dipole, + get_atomic_virials_stresses, + get_edge_vectors_and_lengths, + get_outputs, + get_symmetric_displacement, + prepare_graph, +) + +# pylint: disable=C0302 + + +@compile_mode("script") +class MACE(torch.nn.Module): + def __init__( + self, + r_max: float, + num_bessel: int, + num_polynomial_cutoff: int, + max_ell: int, + interaction_cls: Type[InteractionBlock], + interaction_cls_first: Type[InteractionBlock], + num_interactions: int, + num_elements: int, + hidden_irreps: o3.Irreps, + MLP_irreps: o3.Irreps, + atomic_energies: np.ndarray, + avg_num_neighbors: float, + atomic_numbers: List[int], + correlation: Union[int, List[int]], + gate: Optional[Callable], + pair_repulsion: bool = False, + distance_transform: str = "None", + radial_MLP: Optional[List[int]] = None, + radial_type: Optional[str] = "bessel", + heads: Optional[List[str]] = None, + cueq_config: Optional[Dict[str, Any]] = None, + lammps_mliap: Optional[bool] = False, + ): + super().__init__() + self.register_buffer( + "atomic_numbers", torch.tensor(atomic_numbers, dtype=torch.int64) + ) + self.register_buffer( + "r_max", torch.tensor(r_max, dtype=torch.get_default_dtype()) + ) + self.register_buffer( + "num_interactions", torch.tensor(num_interactions, dtype=torch.int64) + ) + if heads is None: + heads = ["Default"] + self.heads = heads + if isinstance(correlation, int): + correlation = [correlation] * num_interactions + self.lammps_mliap = lammps_mliap + # Embedding + node_attr_irreps = o3.Irreps([(num_elements, (0, 1))]) + node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))]) + self.node_embedding = LinearNodeEmbeddingBlock( + irreps_in=node_attr_irreps, + irreps_out=node_feats_irreps, + cueq_config=cueq_config, + ) + self.radial_embedding = RadialEmbeddingBlock( + r_max=r_max, + num_bessel=num_bessel, + num_polynomial_cutoff=num_polynomial_cutoff, + radial_type=radial_type, + distance_transform=distance_transform, + ) + edge_feats_irreps = o3.Irreps(f"{self.radial_embedding.out_dim}x0e") + if pair_repulsion: + self.pair_repulsion_fn = ZBLBasis(p=num_polynomial_cutoff) + self.pair_repulsion = True + + sh_irreps = o3.Irreps.spherical_harmonics(max_ell) + num_features = hidden_irreps.count(o3.Irrep(0, 1)) + interaction_irreps = (sh_irreps * num_features).sort()[0].simplify() + self.spherical_harmonics = o3.SphericalHarmonics( + sh_irreps, normalize=True, normalization="component" + ) + if radial_MLP is None: + radial_MLP = [64, 64, 64] + # Interactions and readout + self.atomic_energies_fn = AtomicEnergiesBlock(atomic_energies) + + inter = interaction_cls_first( + node_attrs_irreps=node_attr_irreps, + node_feats_irreps=node_feats_irreps, + edge_attrs_irreps=sh_irreps, + edge_feats_irreps=edge_feats_irreps, + target_irreps=interaction_irreps, + hidden_irreps=hidden_irreps, + avg_num_neighbors=avg_num_neighbors, + radial_MLP=radial_MLP, + cueq_config=cueq_config, + ) + self.interactions = torch.nn.ModuleList([inter]) + + # Use the appropriate self connection at the first layer for proper E0 + use_sc_first = False + if "Residual" in str(interaction_cls_first): + use_sc_first = True + + node_feats_irreps_out = inter.target_irreps + prod = EquivariantProductBasisBlock( + node_feats_irreps=node_feats_irreps_out, + target_irreps=hidden_irreps, + correlation=correlation[0], + num_elements=num_elements, + use_sc=use_sc_first, + cueq_config=cueq_config, + ) + self.products = torch.nn.ModuleList([prod]) + + self.readouts = torch.nn.ModuleList() + self.readouts.append( + LinearReadoutBlock( + hidden_irreps, o3.Irreps(f"{len(heads)}x0e"), cueq_config + ) + ) + + for i in range(num_interactions - 1): + if i == num_interactions - 2: + hidden_irreps_out = str( + hidden_irreps[0] + ) # Select only scalars for last layer + else: + hidden_irreps_out = hidden_irreps + inter = interaction_cls( + node_attrs_irreps=node_attr_irreps, + node_feats_irreps=hidden_irreps, + edge_attrs_irreps=sh_irreps, + edge_feats_irreps=edge_feats_irreps, + target_irreps=interaction_irreps, + hidden_irreps=hidden_irreps_out, + avg_num_neighbors=avg_num_neighbors, + radial_MLP=radial_MLP, + cueq_config=cueq_config, + ) + self.interactions.append(inter) + prod = EquivariantProductBasisBlock( + node_feats_irreps=interaction_irreps, + target_irreps=hidden_irreps_out, + correlation=correlation[i + 1], + num_elements=num_elements, + use_sc=True, + cueq_config=cueq_config, + ) + self.products.append(prod) + if i == num_interactions - 2: + self.readouts.append( + NonLinearReadoutBlock( + hidden_irreps_out, + (len(heads) * MLP_irreps).simplify(), + gate, + o3.Irreps(f"{len(heads)}x0e"), + len(heads), + cueq_config, + ) + ) + else: + self.readouts.append( + LinearReadoutBlock( + hidden_irreps, o3.Irreps(f"{len(heads)}x0e"), cueq_config + ) + ) + + def forward( + self, + data: Dict[str, torch.Tensor], + training: bool = False, + compute_force: bool = True, + compute_virials: bool = False, + compute_stress: bool = False, + compute_displacement: bool = False, + compute_hessian: bool = False, + compute_edge_forces: bool = False, + compute_atomic_stresses: bool = False, + lammps_mliap: bool = False, + ) -> Dict[str, Optional[torch.Tensor]]: + # Setup + ctx = prepare_graph( + data, + compute_virials=compute_virials, + compute_stress=compute_stress, + compute_displacement=compute_displacement, + lammps_mliap=lammps_mliap, + ) + is_lammps = ctx.is_lammps + num_atoms_arange = ctx.num_atoms_arange + num_graphs = ctx.num_graphs + displacement = ctx.displacement + positions = ctx.positions + vectors = ctx.vectors + lengths = ctx.lengths + cell = ctx.cell + node_heads = ctx.node_heads + interaction_kwargs = ctx.interaction_kwargs + lammps_natoms = interaction_kwargs.lammps_natoms + lammps_class = interaction_kwargs.lammps_class + + # Atomic energies + node_e0 = self.atomic_energies_fn(data["node_attrs"])[ + num_atoms_arange, node_heads + ] + e0 = scatter_sum( + src=node_e0, index=data["batch"], dim=0, dim_size=num_graphs + ) # [n_graphs, n_heads] + # Embeddings + node_feats = self.node_embedding(data["node_attrs"]) + edge_attrs = self.spherical_harmonics(vectors) + edge_feats = self.radial_embedding( + lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers + ) + if hasattr(self, "pair_repulsion"): + pair_node_energy = self.pair_repulsion_fn( + lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers + ) + if is_lammps: + pair_node_energy = pair_node_energy[: lammps_natoms[0]] + pair_energy = scatter_sum( + src=pair_node_energy, index=data["batch"], dim=-1, dim_size=num_graphs + ) # [n_graphs,] + else: + pair_node_energy = torch.zeros_like(node_e0) + pair_energy = torch.zeros_like(e0) + + # Interactions + energies = [e0, pair_energy] + node_energies_list = [node_e0, pair_node_energy] + node_feats_concat: List[torch.Tensor] = [] + + for i, (interaction, product, readout) in enumerate( + zip(self.interactions, self.products, self.readouts) + ): + node_attrs_slice = data["node_attrs"] + if is_lammps and i > 0: + node_attrs_slice = node_attrs_slice[: lammps_natoms[0]] + node_feats, sc = interaction( + node_attrs=node_attrs_slice, + node_feats=node_feats, + edge_attrs=edge_attrs, + edge_feats=edge_feats, + edge_index=data["edge_index"], + first_layer=(i == 0), + lammps_class=lammps_class, + lammps_natoms=lammps_natoms, + ) + if is_lammps and i == 0: + node_attrs_slice = node_attrs_slice[: lammps_natoms[0]] + node_feats = product( + node_feats=node_feats, sc=sc, node_attrs=node_attrs_slice + ) + node_feats_concat.append(node_feats) + node_es = readout(node_feats, node_heads)[num_atoms_arange, node_heads] + energy = scatter_sum(node_es, data["batch"], dim=0, dim_size=num_graphs) + energies.append(energy) + node_energies_list.append(node_es) + + contributions = torch.stack(energies, dim=-1) + total_energy = torch.sum(contributions, dim=-1) + node_energy = torch.sum(torch.stack(node_energies_list, dim=-1), dim=-1) + node_feats_out = torch.cat(node_feats_concat, dim=-1) + node_energy = node_e0.double() + pair_node_energy.double() + + forces, virials, stress, hessian, edge_forces = get_outputs( + energy=total_energy, + positions=positions, + displacement=displacement, + vectors=vectors, + cell=cell, + training=training, + compute_force=compute_force, + compute_virials=compute_virials, + compute_stress=compute_stress, + compute_hessian=compute_hessian, + compute_edge_forces=compute_edge_forces, + ) + + atomic_virials: Optional[torch.Tensor] = None + atomic_stresses: Optional[torch.Tensor] = None + if compute_atomic_stresses and edge_forces is not None: + atomic_virials, atomic_stresses = get_atomic_virials_stresses( + edge_forces=edge_forces, + edge_index=data["edge_index"], + vectors=vectors, + num_atoms=positions.shape[0], + batch=data["batch"], + cell=cell, + ) + return { + "energy": total_energy, + "node_energy": node_energy, + "contributions": contributions, + "forces": forces, + "edge_forces": edge_forces, + "virials": virials, + "stress": stress, + "atomic_virials": atomic_virials, + "atomic_stresses": atomic_stresses, + "displacement": displacement, + "hessian": hessian, + "node_feats": node_feats_out, + } + + +@compile_mode("script") +class ScaleShiftMACE(MACE): + def __init__( + self, + atomic_inter_scale: float, + atomic_inter_shift: float, + **kwargs, + ): + super().__init__(**kwargs) + self.scale_shift = ScaleShiftBlock( + scale=atomic_inter_scale, shift=atomic_inter_shift + ) + + def forward( + self, + data: Dict[str, torch.Tensor], + training: bool = False, + compute_force: bool = True, + compute_virials: bool = False, + compute_stress: bool = False, + compute_displacement: bool = False, + compute_hessian: bool = False, + compute_edge_forces: bool = False, + compute_atomic_stresses: bool = False, + lammps_mliap: bool = False, + ) -> Dict[str, Optional[torch.Tensor]]: + # Setup + ctx = prepare_graph( + data, + compute_virials=compute_virials, + compute_stress=compute_stress, + compute_displacement=compute_displacement, + lammps_mliap=lammps_mliap, + ) + + is_lammps = ctx.is_lammps + num_atoms_arange = ctx.num_atoms_arange + num_graphs = ctx.num_graphs + displacement = ctx.displacement + positions = ctx.positions + vectors = ctx.vectors + lengths = ctx.lengths + cell = ctx.cell + node_heads = ctx.node_heads + interaction_kwargs = ctx.interaction_kwargs + lammps_natoms = interaction_kwargs.lammps_natoms + lammps_class = interaction_kwargs.lammps_class + + # Atomic energies + node_e0 = self.atomic_energies_fn(data["node_attrs"])[ + num_atoms_arange, node_heads + ] + e0 = scatter_sum( + src=node_e0, index=data["batch"], dim=0, dim_size=num_graphs + ) # [n_graphs, num_heads] + + # Embeddings + node_feats = self.node_embedding(data["node_attrs"]) + edge_attrs = self.spherical_harmonics(vectors) + edge_feats = self.radial_embedding( + lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers + ) + + if hasattr(self, "pair_repulsion"): + pair_node_energy = self.pair_repulsion_fn( + lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers + ) + if is_lammps: + pair_node_energy = pair_node_energy[: lammps_natoms[0]] + else: + pair_node_energy = torch.zeros_like(node_e0) + + # Interactions + node_es_list = [pair_node_energy] + node_feats_list: List[torch.Tensor] = [] + + for i, (interaction, product, readout) in enumerate( + zip(self.interactions, self.products, self.readouts) + ): + node_attrs_slice = data["node_attrs"] + if is_lammps and i > 0: + node_attrs_slice = node_attrs_slice[: lammps_natoms[0]] + node_feats, sc = interaction( + node_attrs=node_attrs_slice, + node_feats=node_feats, + edge_attrs=edge_attrs, + edge_feats=edge_feats, + edge_index=data["edge_index"], + first_layer=(i == 0), + lammps_class=lammps_class, + lammps_natoms=lammps_natoms, + ) + if is_lammps and i == 0: + node_attrs_slice = node_attrs_slice[: lammps_natoms[0]] + node_feats = product( + node_feats=node_feats, sc=sc, node_attrs=node_attrs_slice + ) + node_feats_list.append(node_feats) + node_es_list.append( + readout(node_feats, node_heads)[num_atoms_arange, node_heads] + ) + + node_feats_out = torch.cat(node_feats_list, dim=-1) + node_inter_es = torch.sum(torch.stack(node_es_list, dim=0), dim=0) + node_inter_es = self.scale_shift(node_inter_es, node_heads) + inter_e = scatter_sum(node_inter_es, data["batch"], dim=-1, dim_size=num_graphs) + + total_energy = e0 + inter_e + node_energy = node_e0.clone().double() + node_inter_es.clone().double() + + forces, virials, stress, hessian, edge_forces = get_outputs( + energy=inter_e, + positions=positions, + displacement=displacement, + vectors=vectors, + cell=cell, + training=training, + compute_force=compute_force, + compute_virials=compute_virials, + compute_stress=compute_stress, + compute_hessian=compute_hessian, + compute_edge_forces=compute_edge_forces or compute_atomic_stresses, + ) + + atomic_virials: Optional[torch.Tensor] = None + atomic_stresses: Optional[torch.Tensor] = None + if compute_atomic_stresses and edge_forces is not None: + atomic_virials, atomic_stresses = get_atomic_virials_stresses( + edge_forces=edge_forces, + edge_index=data["edge_index"], + vectors=vectors, + num_atoms=positions.shape[0], + batch=data["batch"], + cell=cell, + ) + return { + "energy": total_energy, + "node_energy": node_energy, + "interaction_energy": inter_e, + "forces": forces, + "edge_forces": edge_forces, + "virials": virials, + "stress": stress, + "atomic_virials": atomic_virials, + "atomic_stresses": atomic_stresses, + "hessian": hessian, + "displacement": displacement, + "node_feats": node_feats_out, + } + + +@compile_mode("script") +class AtomicDipolesMACE(torch.nn.Module): + def __init__( + self, + r_max: float, + num_bessel: int, + num_polynomial_cutoff: int, + max_ell: int, + interaction_cls: Type[InteractionBlock], + interaction_cls_first: Type[InteractionBlock], + num_interactions: int, + num_elements: int, + hidden_irreps: o3.Irreps, + MLP_irreps: o3.Irreps, + avg_num_neighbors: float, + atomic_numbers: List[int], + correlation: int, + gate: Optional[Callable], + atomic_energies: Optional[ + None + ], # Just here to make it compatible with energy models, MUST be None + radial_type: Optional[str] = "bessel", + radial_MLP: Optional[List[int]] = None, + cueq_config: Optional[Dict[str, Any]] = None, # pylint: disable=unused-argument + ): + super().__init__() + self.register_buffer( + "atomic_numbers", torch.tensor(atomic_numbers, dtype=torch.int64) + ) + self.register_buffer("r_max", torch.tensor(r_max, dtype=torch.float64)) + self.register_buffer( + "num_interactions", torch.tensor(num_interactions, dtype=torch.int64) + ) + assert atomic_energies is None + + # Embedding + node_attr_irreps = o3.Irreps([(num_elements, (0, 1))]) + node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))]) + self.node_embedding = LinearNodeEmbeddingBlock( + irreps_in=node_attr_irreps, irreps_out=node_feats_irreps + ) + self.radial_embedding = RadialEmbeddingBlock( + r_max=r_max, + num_bessel=num_bessel, + num_polynomial_cutoff=num_polynomial_cutoff, + radial_type=radial_type, + ) + edge_feats_irreps = o3.Irreps(f"{self.radial_embedding.out_dim}x0e") + + sh_irreps = o3.Irreps.spherical_harmonics(max_ell) + num_features = hidden_irreps.count(o3.Irrep(0, 1)) + interaction_irreps = (sh_irreps * num_features).sort()[0].simplify() + self.spherical_harmonics = o3.SphericalHarmonics( + sh_irreps, normalize=True, normalization="component" + ) + if radial_MLP is None: + radial_MLP = [64, 64, 64] + + # Interactions and readouts + inter = interaction_cls_first( + node_attrs_irreps=node_attr_irreps, + node_feats_irreps=node_feats_irreps, + edge_attrs_irreps=sh_irreps, + edge_feats_irreps=edge_feats_irreps, + target_irreps=interaction_irreps, + hidden_irreps=hidden_irreps, + avg_num_neighbors=avg_num_neighbors, + radial_MLP=radial_MLP, + ) + self.interactions = torch.nn.ModuleList([inter]) + + # Use the appropriate self connection at the first layer + use_sc_first = False + if "Residual" in str(interaction_cls_first): + use_sc_first = True + + node_feats_irreps_out = inter.target_irreps + prod = EquivariantProductBasisBlock( + node_feats_irreps=node_feats_irreps_out, + target_irreps=hidden_irreps, + correlation=correlation, + num_elements=num_elements, + use_sc=use_sc_first, + ) + self.products = torch.nn.ModuleList([prod]) + + self.readouts = torch.nn.ModuleList() + self.readouts.append(LinearDipoleReadoutBlock(hidden_irreps, dipole_only=True)) + + for i in range(num_interactions - 1): + if i == num_interactions - 2: + assert ( + len(hidden_irreps) > 1 + ), "To predict dipoles use at least l=1 hidden_irreps" + hidden_irreps_out = str( + hidden_irreps[1] + ) # Select only l=1 vectors for last layer + else: + hidden_irreps_out = hidden_irreps + inter = interaction_cls( + node_attrs_irreps=node_attr_irreps, + node_feats_irreps=hidden_irreps, + edge_attrs_irreps=sh_irreps, + edge_feats_irreps=edge_feats_irreps, + target_irreps=interaction_irreps, + hidden_irreps=hidden_irreps_out, + avg_num_neighbors=avg_num_neighbors, + radial_MLP=radial_MLP, + ) + self.interactions.append(inter) + prod = EquivariantProductBasisBlock( + node_feats_irreps=interaction_irreps, + target_irreps=hidden_irreps_out, + correlation=correlation, + num_elements=num_elements, + use_sc=True, + ) + self.products.append(prod) + if i == num_interactions - 2: + self.readouts.append( + NonLinearDipoleReadoutBlock( + hidden_irreps_out, MLP_irreps, gate, dipole_only=True + ) + ) + else: + self.readouts.append( + LinearDipoleReadoutBlock(hidden_irreps, dipole_only=True) + ) + + def forward( + self, + data: Dict[str, torch.Tensor], + training: bool = False, # pylint: disable=W0613 + compute_force: bool = False, + compute_virials: bool = False, + compute_stress: bool = False, + compute_displacement: bool = False, + compute_edge_forces: bool = False, # pylint: disable=W0613 + compute_atomic_stresses: bool = False, # pylint: disable=W0613 + ) -> Dict[str, Optional[torch.Tensor]]: + assert compute_force is False + assert compute_virials is False + assert compute_stress is False + assert compute_displacement is False + # Setup + data["node_attrs"].requires_grad_(True) + data["positions"].requires_grad_(True) + num_graphs = data["ptr"].numel() - 1 + + # Embeddings + node_feats = self.node_embedding(data["node_attrs"]) + vectors, lengths = get_edge_vectors_and_lengths( + positions=data["positions"], + edge_index=data["edge_index"], + shifts=data["shifts"], + ) + edge_attrs = self.spherical_harmonics(vectors) + edge_feats = self.radial_embedding( + lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers + ) + + # Interactions + dipoles = [] + for interaction, product, readout in zip( + self.interactions, self.products, self.readouts + ): + node_feats, sc = interaction( + node_attrs=data["node_attrs"], + node_feats=node_feats, + edge_attrs=edge_attrs, + edge_feats=edge_feats, + edge_index=data["edge_index"], + ) + node_feats = product( + node_feats=node_feats, + sc=sc, + node_attrs=data["node_attrs"], + ) + node_dipoles = readout(node_feats).squeeze(-1) # [n_nodes,3] + dipoles.append(node_dipoles) + + # Compute the dipoles + contributions_dipoles = torch.stack( + dipoles, dim=-1 + ) # [n_nodes,3,n_contributions] + atomic_dipoles = torch.sum(contributions_dipoles, dim=-1) # [n_nodes,3] + total_dipole = scatter_sum( + src=atomic_dipoles, + index=data["batch"], + dim=0, + dim_size=num_graphs, + ) # [n_graphs,3] + baseline = compute_fixed_charge_dipole( + charges=data["charges"], + positions=data["positions"], + batch=data["batch"], + num_graphs=num_graphs, + ) # [n_graphs,3] + total_dipole = total_dipole + baseline + + output = { + "dipole": total_dipole, + "atomic_dipoles": atomic_dipoles, + } + return output + + +@compile_mode("script") +class EnergyDipolesMACE(torch.nn.Module): + def __init__( + self, + r_max: float, + num_bessel: int, + num_polynomial_cutoff: int, + max_ell: int, + interaction_cls: Type[InteractionBlock], + interaction_cls_first: Type[InteractionBlock], + num_interactions: int, + num_elements: int, + hidden_irreps: o3.Irreps, + MLP_irreps: o3.Irreps, + avg_num_neighbors: float, + atomic_numbers: List[int], + correlation: int, + gate: Optional[Callable], + atomic_energies: Optional[np.ndarray], + radial_MLP: Optional[List[int]] = None, + cueq_config: Optional[Dict[str, Any]] = None, # pylint: disable=unused-argument + ): + super().__init__() + self.register_buffer( + "atomic_numbers", torch.tensor(atomic_numbers, dtype=torch.int64) + ) + self.register_buffer("r_max", torch.tensor(r_max, dtype=torch.float64)) + self.register_buffer( + "num_interactions", torch.tensor(num_interactions, dtype=torch.int64) + ) + # Embedding + node_attr_irreps = o3.Irreps([(num_elements, (0, 1))]) + node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))]) + self.node_embedding = LinearNodeEmbeddingBlock( + irreps_in=node_attr_irreps, irreps_out=node_feats_irreps + ) + self.radial_embedding = RadialEmbeddingBlock( + r_max=r_max, + num_bessel=num_bessel, + num_polynomial_cutoff=num_polynomial_cutoff, + ) + edge_feats_irreps = o3.Irreps(f"{self.radial_embedding.out_dim}x0e") + + sh_irreps = o3.Irreps.spherical_harmonics(max_ell) + num_features = hidden_irreps.count(o3.Irrep(0, 1)) + interaction_irreps = (sh_irreps * num_features).sort()[0].simplify() + self.spherical_harmonics = o3.SphericalHarmonics( + sh_irreps, normalize=True, normalization="component" + ) + if radial_MLP is None: + radial_MLP = [64, 64, 64] + # Interactions and readouts + self.atomic_energies_fn = AtomicEnergiesBlock(atomic_energies) + + inter = interaction_cls_first( + node_attrs_irreps=node_attr_irreps, + node_feats_irreps=node_feats_irreps, + edge_attrs_irreps=sh_irreps, + edge_feats_irreps=edge_feats_irreps, + target_irreps=interaction_irreps, + hidden_irreps=hidden_irreps, + avg_num_neighbors=avg_num_neighbors, + radial_MLP=radial_MLP, + ) + self.interactions = torch.nn.ModuleList([inter]) + + # Use the appropriate self connection at the first layer + use_sc_first = False + if "Residual" in str(interaction_cls_first): + use_sc_first = True + + node_feats_irreps_out = inter.target_irreps + prod = EquivariantProductBasisBlock( + node_feats_irreps=node_feats_irreps_out, + target_irreps=hidden_irreps, + correlation=correlation, + num_elements=num_elements, + use_sc=use_sc_first, + ) + self.products = torch.nn.ModuleList([prod]) + + self.readouts = torch.nn.ModuleList() + self.readouts.append(LinearDipoleReadoutBlock(hidden_irreps, dipole_only=False)) + + for i in range(num_interactions - 1): + if i == num_interactions - 2: + assert ( + len(hidden_irreps) > 1 + ), "To predict dipoles use at least l=1 hidden_irreps" + hidden_irreps_out = str( + hidden_irreps[:2] + ) # Select scalars and l=1 vectors for last layer + else: + hidden_irreps_out = hidden_irreps + inter = interaction_cls( + node_attrs_irreps=node_attr_irreps, + node_feats_irreps=hidden_irreps, + edge_attrs_irreps=sh_irreps, + edge_feats_irreps=edge_feats_irreps, + target_irreps=interaction_irreps, + hidden_irreps=hidden_irreps_out, + avg_num_neighbors=avg_num_neighbors, + radial_MLP=radial_MLP, + ) + self.interactions.append(inter) + prod = EquivariantProductBasisBlock( + node_feats_irreps=interaction_irreps, + target_irreps=hidden_irreps_out, + correlation=correlation, + num_elements=num_elements, + use_sc=True, + ) + self.products.append(prod) + if i == num_interactions - 2: + self.readouts.append( + NonLinearDipoleReadoutBlock( + hidden_irreps_out, MLP_irreps, gate, dipole_only=False + ) + ) + else: + self.readouts.append( + LinearDipoleReadoutBlock(hidden_irreps, dipole_only=False) + ) + + def forward( + self, + data: Dict[str, torch.Tensor], + training: bool = False, + compute_force: bool = True, + compute_virials: bool = False, + compute_stress: bool = False, + compute_displacement: bool = False, + compute_edge_forces: bool = False, # pylint: disable=W0613 + compute_atomic_stresses: bool = False, # pylint: disable=W0613 + ) -> Dict[str, Optional[torch.Tensor]]: + # Setup + data["node_attrs"].requires_grad_(True) + data["positions"].requires_grad_(True) + num_graphs = data["ptr"].numel() - 1 + num_atoms_arange = torch.arange(data["positions"].shape[0]) + displacement = torch.zeros( + (num_graphs, 3, 3), + dtype=data["positions"].dtype, + device=data["positions"].device, + ) + if compute_virials or compute_stress or compute_displacement: + ( + data["positions"], + data["shifts"], + displacement, + ) = get_symmetric_displacement( + positions=data["positions"], + unit_shifts=data["unit_shifts"], + cell=data["cell"], + edge_index=data["edge_index"], + num_graphs=num_graphs, + batch=data["batch"], + ) + + # Atomic energies + node_e0 = self.atomic_energies_fn(data["node_attrs"])[ + num_atoms_arange, data["head"][data["batch"]] + ] + e0 = scatter_sum( + src=node_e0, index=data["batch"], dim=-1, dim_size=num_graphs + ) # [n_graphs,] + + # Embeddings + node_feats = self.node_embedding(data["node_attrs"]) + vectors, lengths = get_edge_vectors_and_lengths( + positions=data["positions"], + edge_index=data["edge_index"], + shifts=data["shifts"], + ) + edge_attrs = self.spherical_harmonics(vectors) + edge_feats = self.radial_embedding( + lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers + ) + + # Interactions + energies = [e0] + node_energies_list = [node_e0] + dipoles = [] + for interaction, product, readout in zip( + self.interactions, self.products, self.readouts + ): + node_feats, sc = interaction( + node_attrs=data["node_attrs"], + node_feats=node_feats, + edge_attrs=edge_attrs, + edge_feats=edge_feats, + edge_index=data["edge_index"], + ) + node_feats = product( + node_feats=node_feats, + sc=sc, + node_attrs=data["node_attrs"], + ) + node_out = readout(node_feats).squeeze(-1) # [n_nodes, ] + # node_energies = readout(node_feats).squeeze(-1) # [n_nodes, ] + node_energies = node_out[:, 0] + energy = scatter_sum( + src=node_energies, index=data["batch"], dim=-1, dim_size=num_graphs + ) # [n_graphs,] + energies.append(energy) + node_dipoles = node_out[:, 1:] + dipoles.append(node_dipoles) + + # Compute the energies and dipoles + contributions = torch.stack(energies, dim=-1) + total_energy = torch.sum(contributions, dim=-1) # [n_graphs, ] + node_energy_contributions = torch.stack(node_energies_list, dim=-1) + node_energy = torch.sum(node_energy_contributions, dim=-1) # [n_nodes, ] + contributions_dipoles = torch.stack( + dipoles, dim=-1 + ) # [n_nodes,3,n_contributions] + atomic_dipoles = torch.sum(contributions_dipoles, dim=-1) # [n_nodes,3] + total_dipole = scatter_sum( + src=atomic_dipoles, + index=data["batch"].unsqueeze(-1), + dim=0, + dim_size=num_graphs, + ) # [n_graphs,3] + baseline = compute_fixed_charge_dipole( + charges=data["charges"], + positions=data["positions"], + batch=data["batch"], + num_graphs=num_graphs, + ) # [n_graphs,3] + total_dipole = total_dipole + baseline + + forces, virials, stress, _, _ = get_outputs( + energy=total_energy, + positions=data["positions"], + displacement=displacement, + cell=data["cell"], + training=training, + compute_force=compute_force, + compute_virials=compute_virials, + compute_stress=compute_stress, + ) + + output = { + "energy": total_energy, + "node_energy": node_energy, + "contributions": contributions, + "forces": forces, + "virials": virials, + "stress": stress, + "displacement": displacement, + "dipole": total_dipole, + "atomic_dipoles": atomic_dipoles, + } + return output diff --git a/mace-bench/3rdparty/mace/mace/modules/radial.py b/mace-bench/3rdparty/mace/mace/modules/radial.py index b78dd4e..ff69b43 100644 --- a/mace-bench/3rdparty/mace/mace/modules/radial.py +++ b/mace-bench/3rdparty/mace/mace/modules/radial.py @@ -1,358 +1,358 @@ -########################################################################################### -# Radial basis and cutoff -# Authors: Ilyes Batatia, Gregor Simm -# This program is distributed under the MIT License (see MIT.md) -########################################################################################### - -import logging - -import ase -import numpy as np -import torch -from e3nn.util.jit import compile_mode - -from mace.tools.scatter import scatter_sum - - -@compile_mode("script") -class BesselBasis(torch.nn.Module): - """ - Equation (7) - """ - - def __init__(self, r_max: float, num_basis=8, trainable=False): - super().__init__() - - bessel_weights = ( - np.pi - / r_max - * torch.linspace( - start=1.0, - end=num_basis, - steps=num_basis, - dtype=torch.get_default_dtype(), - ) - ) - if trainable: - self.bessel_weights = torch.nn.Parameter(bessel_weights) - else: - self.register_buffer("bessel_weights", bessel_weights) - - self.register_buffer( - "r_max", torch.tensor(r_max, dtype=torch.get_default_dtype()) - ) - self.register_buffer( - "prefactor", - torch.tensor(np.sqrt(2.0 / r_max), dtype=torch.get_default_dtype()), - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: # [..., 1] - numerator = torch.sin(self.bessel_weights * x) # [..., num_basis] - return self.prefactor * (numerator / x) - - def __repr__(self): - return ( - f"{self.__class__.__name__}(r_max={self.r_max}, num_basis={len(self.bessel_weights)}, " - f"trainable={self.bessel_weights.requires_grad})" - ) - - -@compile_mode("script") -class ChebychevBasis(torch.nn.Module): - """ - Equation (7) - """ - - def __init__(self, r_max: float, num_basis=8): - super().__init__() - self.register_buffer( - "n", - torch.arange(1, num_basis + 1, dtype=torch.get_default_dtype()).unsqueeze( - 0 - ), - ) - self.num_basis = num_basis - self.r_max = r_max - - def forward(self, x: torch.Tensor) -> torch.Tensor: # [..., 1] - x = x.repeat(1, self.num_basis) - n = self.n.repeat(len(x), 1) - return torch.special.chebyshev_polynomial_t(x, n) - - def __repr__(self): - return ( - f"{self.__class__.__name__}(r_max={self.r_max}, num_basis={self.num_basis}," - ) - - -@compile_mode("script") -class GaussianBasis(torch.nn.Module): - """ - Gaussian basis functions - """ - - def __init__(self, r_max: float, num_basis=128, trainable=False): - super().__init__() - gaussian_weights = torch.linspace( - start=0.0, end=r_max, steps=num_basis, dtype=torch.get_default_dtype() - ) - if trainable: - self.gaussian_weights = torch.nn.Parameter( - gaussian_weights, requires_grad=True - ) - else: - self.register_buffer("gaussian_weights", gaussian_weights) - self.coeff = -0.5 / (r_max / (num_basis - 1)) ** 2 - - def forward(self, x: torch.Tensor) -> torch.Tensor: # [..., 1] - x = x - self.gaussian_weights - return torch.exp(self.coeff * torch.pow(x, 2)) - - -@compile_mode("script") -class PolynomialCutoff(torch.nn.Module): - """Polynomial cutoff function that goes from 1 to 0 as x goes from 0 to r_max. - Equation (8) -- TODO: from where? - """ - - p: torch.Tensor - r_max: torch.Tensor - - def __init__(self, r_max: float, p=6): - super().__init__() - self.register_buffer("p", torch.tensor(p, dtype=torch.int)) - self.register_buffer( - "r_max", torch.tensor(r_max, dtype=torch.get_default_dtype()) - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.calculate_envelope(x, self.r_max, self.p.to(torch.int)) - - @staticmethod - def calculate_envelope( - x: torch.Tensor, r_max: torch.Tensor, p: torch.Tensor - ) -> torch.Tensor: - r_over_r_max = x / r_max - envelope = ( - 1.0 - - ((p + 1.0) * (p + 2.0) / 2.0) * torch.pow(r_over_r_max, p) - + p * (p + 2.0) * torch.pow(r_over_r_max, p + 1) - - (p * (p + 1.0) / 2) * torch.pow(r_over_r_max, p + 2) - ) - return envelope * (x < r_max) - - def __repr__(self): - return f"{self.__class__.__name__}(p={self.p}, r_max={self.r_max})" - - -@compile_mode("script") -class ZBLBasis(torch.nn.Module): - """Implementation of the Ziegler-Biersack-Littmark (ZBL) potential - with a polynomial cutoff envelope. - """ - - p: torch.Tensor - - def __init__(self, p=6, trainable=False, **kwargs): - super().__init__() - if "r_max" in kwargs: - logging.warning( - "r_max is deprecated. r_max is determined from the covalent radii." - ) - - # Pre-calculate the p coefficients for the ZBL potential - self.register_buffer( - "c", - torch.tensor( - [0.1818, 0.5099, 0.2802, 0.02817], dtype=torch.get_default_dtype() - ), - ) - self.register_buffer("p", torch.tensor(p, dtype=torch.int)) - self.register_buffer( - "covalent_radii", - torch.tensor( - ase.data.covalent_radii, - dtype=torch.get_default_dtype(), - ), - ) - if trainable: - self.a_exp = torch.nn.Parameter(torch.tensor(0.300, requires_grad=True)) - self.a_prefactor = torch.nn.Parameter( - torch.tensor(0.4543, requires_grad=True) - ) - else: - self.register_buffer("a_exp", torch.tensor(0.300)) - self.register_buffer("a_prefactor", torch.tensor(0.4543)) - - def forward( - self, - x: torch.Tensor, - node_attrs: torch.Tensor, - edge_index: torch.Tensor, - atomic_numbers: torch.Tensor, - ) -> torch.Tensor: - sender = edge_index[0] - receiver = edge_index[1] - node_atomic_numbers = atomic_numbers[torch.argmax(node_attrs, dim=1)].unsqueeze( - -1 - ) - Z_u = node_atomic_numbers[sender] - Z_v = node_atomic_numbers[receiver] - a = ( - self.a_prefactor - * 0.529 - / (torch.pow(Z_u, self.a_exp) + torch.pow(Z_v, self.a_exp)) - ) - r_over_a = x / a - phi = ( - self.c[0] * torch.exp(-3.2 * r_over_a) - + self.c[1] * torch.exp(-0.9423 * r_over_a) - + self.c[2] * torch.exp(-0.4028 * r_over_a) - + self.c[3] * torch.exp(-0.2016 * r_over_a) - ) - v_edges = (14.3996 * Z_u * Z_v) / x * phi - r_max = self.covalent_radii[Z_u] + self.covalent_radii[Z_v] - envelope = PolynomialCutoff.calculate_envelope(x, r_max, self.p) - v_edges = 0.5 * v_edges * envelope - V_ZBL = scatter_sum(v_edges, receiver, dim=0, dim_size=node_attrs.size(0)) - return V_ZBL.squeeze(-1) - - def __repr__(self): - return f"{self.__class__.__name__}(c={self.c})" - - -@compile_mode("script") -class AgnesiTransform(torch.nn.Module): - """Agnesi transform - see section on Radial transformations in - ACEpotentials.jl, JCP 2023 (https://doi.org/10.1063/5.0158783). - """ - - def __init__( - self, - q: float = 0.9183, - p: float = 4.5791, - a: float = 1.0805, - trainable=False, - ): - super().__init__() - self.register_buffer("q", torch.tensor(q, dtype=torch.get_default_dtype())) - self.register_buffer("p", torch.tensor(p, dtype=torch.get_default_dtype())) - self.register_buffer("a", torch.tensor(a, dtype=torch.get_default_dtype())) - self.register_buffer( - "covalent_radii", - torch.tensor( - ase.data.covalent_radii, - dtype=torch.get_default_dtype(), - ), - ) - if trainable: - self.a = torch.nn.Parameter(torch.tensor(1.0805, requires_grad=True)) - self.q = torch.nn.Parameter(torch.tensor(0.9183, requires_grad=True)) - self.p = torch.nn.Parameter(torch.tensor(4.5791, requires_grad=True)) - - def forward( - self, - x: torch.Tensor, - node_attrs: torch.Tensor, - edge_index: torch.Tensor, - atomic_numbers: torch.Tensor, - ) -> torch.Tensor: - sender = edge_index[0] - receiver = edge_index[1] - node_atomic_numbers = atomic_numbers[torch.argmax(node_attrs, dim=1)].unsqueeze( - -1 - ) - Z_u = node_atomic_numbers[sender] - Z_v = node_atomic_numbers[receiver] - r_0: torch.Tensor = 0.5 * (self.covalent_radii[Z_u] + self.covalent_radii[Z_v]) - r_over_r_0 = x / r_0 - return ( - 1 - + ( - self.a - * torch.pow(r_over_r_0, self.q) - / (1 + torch.pow(r_over_r_0, self.q - self.p)) - ) - ).reciprocal_() - - def __repr__(self): - return ( - f"{self.__class__.__name__}(a={self.a:.4f}, q={self.q:.4f}, p={self.p:.4f})" - ) - - -@compile_mode("script") -class SoftTransform(torch.nn.Module): - """ - Tanh-based smooth transformation: - T(x) = p1 + (x - p1)*0.5*[1 + tanh(alpha*(x - m))], - which smoothly transitions from ~p1 for x << p1 to ~x for x >> r0. - """ - - def __init__(self, alpha: float = 4.0, trainable=False): - """ - Args: - p1 (float): Lower "clamp" point. - alpha (float): Steepness; if None, defaults to ~6/(r0-p1). - trainable (bool): Whether to make parameters trainable. - """ - super().__init__() - # Initialize parameters - self.register_buffer( - "alpha", torch.tensor(alpha, dtype=torch.get_default_dtype()) - ) - if trainable: - self.alpha = torch.nn.Parameter(self.alpha.clone()) - self.register_buffer( - "covalent_radii", - torch.tensor( - ase.data.covalent_radii, - dtype=torch.get_default_dtype(), - ), - ) - - def compute_r_0( - self, - node_attrs: torch.Tensor, - edge_index: torch.Tensor, - atomic_numbers: torch.Tensor, - ) -> torch.Tensor: - """ - Compute r_0 based on atomic information. - - Args: - node_attrs (torch.Tensor): Node attributes (one-hot encoding of atomic numbers). - edge_index (torch.Tensor): Edge index indicating connections. - atomic_numbers (torch.Tensor): Atomic numbers. - - Returns: - torch.Tensor: r_0 values for each edge. - """ - sender = edge_index[0] - receiver = edge_index[1] - node_atomic_numbers = atomic_numbers[torch.argmax(node_attrs, dim=1)].unsqueeze( - -1 - ) - Z_u = node_atomic_numbers[sender] - Z_v = node_atomic_numbers[receiver] - r_0: torch.Tensor = self.covalent_radii[Z_u] + self.covalent_radii[Z_v] - return r_0 - - def forward( - self, - x: torch.Tensor, - node_attrs: torch.Tensor, - edge_index: torch.Tensor, - atomic_numbers: torch.Tensor, - ) -> torch.Tensor: - - r_0 = self.compute_r_0(node_attrs, edge_index, atomic_numbers) - p_0 = (3 / 4) * r_0 - p_1 = (4 / 3) * r_0 - m = 0.5 * (p_0 + p_1) - alpha = self.alpha / (p_1 - p_0) - s_x = 0.5 * (1.0 + torch.tanh(alpha * (x - m))) - return p_0 + (x - p_0) * s_x - - def __repr__(self): - return f"{self.__class__.__name__}(alpha={self.alpha.item():.4f})" +########################################################################################### +# Radial basis and cutoff +# Authors: Ilyes Batatia, Gregor Simm +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +import logging + +import ase +import numpy as np +import torch +from e3nn.util.jit import compile_mode + +from mace.tools.scatter import scatter_sum + + +@compile_mode("script") +class BesselBasis(torch.nn.Module): + """ + Equation (7) + """ + + def __init__(self, r_max: float, num_basis=8, trainable=False): + super().__init__() + + bessel_weights = ( + np.pi + / r_max + * torch.linspace( + start=1.0, + end=num_basis, + steps=num_basis, + dtype=torch.get_default_dtype(), + ) + ) + if trainable: + self.bessel_weights = torch.nn.Parameter(bessel_weights) + else: + self.register_buffer("bessel_weights", bessel_weights) + + self.register_buffer( + "r_max", torch.tensor(r_max, dtype=torch.get_default_dtype()) + ) + self.register_buffer( + "prefactor", + torch.tensor(np.sqrt(2.0 / r_max), dtype=torch.get_default_dtype()), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: # [..., 1] + numerator = torch.sin(self.bessel_weights * x) # [..., num_basis] + return self.prefactor * (numerator / x) + + def __repr__(self): + return ( + f"{self.__class__.__name__}(r_max={self.r_max}, num_basis={len(self.bessel_weights)}, " + f"trainable={self.bessel_weights.requires_grad})" + ) + + +@compile_mode("script") +class ChebychevBasis(torch.nn.Module): + """ + Equation (7) + """ + + def __init__(self, r_max: float, num_basis=8): + super().__init__() + self.register_buffer( + "n", + torch.arange(1, num_basis + 1, dtype=torch.get_default_dtype()).unsqueeze( + 0 + ), + ) + self.num_basis = num_basis + self.r_max = r_max + + def forward(self, x: torch.Tensor) -> torch.Tensor: # [..., 1] + x = x.repeat(1, self.num_basis) + n = self.n.repeat(len(x), 1) + return torch.special.chebyshev_polynomial_t(x, n) + + def __repr__(self): + return ( + f"{self.__class__.__name__}(r_max={self.r_max}, num_basis={self.num_basis}," + ) + + +@compile_mode("script") +class GaussianBasis(torch.nn.Module): + """ + Gaussian basis functions + """ + + def __init__(self, r_max: float, num_basis=128, trainable=False): + super().__init__() + gaussian_weights = torch.linspace( + start=0.0, end=r_max, steps=num_basis, dtype=torch.get_default_dtype() + ) + if trainable: + self.gaussian_weights = torch.nn.Parameter( + gaussian_weights, requires_grad=True + ) + else: + self.register_buffer("gaussian_weights", gaussian_weights) + self.coeff = -0.5 / (r_max / (num_basis - 1)) ** 2 + + def forward(self, x: torch.Tensor) -> torch.Tensor: # [..., 1] + x = x - self.gaussian_weights + return torch.exp(self.coeff * torch.pow(x, 2)) + + +@compile_mode("script") +class PolynomialCutoff(torch.nn.Module): + """Polynomial cutoff function that goes from 1 to 0 as x goes from 0 to r_max. + Equation (8) -- TODO: from where? + """ + + p: torch.Tensor + r_max: torch.Tensor + + def __init__(self, r_max: float, p=6): + super().__init__() + self.register_buffer("p", torch.tensor(p, dtype=torch.int)) + self.register_buffer( + "r_max", torch.tensor(r_max, dtype=torch.get_default_dtype()) + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.calculate_envelope(x, self.r_max, self.p.to(torch.int)) + + @staticmethod + def calculate_envelope( + x: torch.Tensor, r_max: torch.Tensor, p: torch.Tensor + ) -> torch.Tensor: + r_over_r_max = x / r_max + envelope = ( + 1.0 + - ((p + 1.0) * (p + 2.0) / 2.0) * torch.pow(r_over_r_max, p) + + p * (p + 2.0) * torch.pow(r_over_r_max, p + 1) + - (p * (p + 1.0) / 2) * torch.pow(r_over_r_max, p + 2) + ) + return envelope * (x < r_max) + + def __repr__(self): + return f"{self.__class__.__name__}(p={self.p}, r_max={self.r_max})" + + +@compile_mode("script") +class ZBLBasis(torch.nn.Module): + """Implementation of the Ziegler-Biersack-Littmark (ZBL) potential + with a polynomial cutoff envelope. + """ + + p: torch.Tensor + + def __init__(self, p=6, trainable=False, **kwargs): + super().__init__() + if "r_max" in kwargs: + logging.warning( + "r_max is deprecated. r_max is determined from the covalent radii." + ) + + # Pre-calculate the p coefficients for the ZBL potential + self.register_buffer( + "c", + torch.tensor( + [0.1818, 0.5099, 0.2802, 0.02817], dtype=torch.get_default_dtype() + ), + ) + self.register_buffer("p", torch.tensor(p, dtype=torch.int)) + self.register_buffer( + "covalent_radii", + torch.tensor( + ase.data.covalent_radii, + dtype=torch.get_default_dtype(), + ), + ) + if trainable: + self.a_exp = torch.nn.Parameter(torch.tensor(0.300, requires_grad=True)) + self.a_prefactor = torch.nn.Parameter( + torch.tensor(0.4543, requires_grad=True) + ) + else: + self.register_buffer("a_exp", torch.tensor(0.300)) + self.register_buffer("a_prefactor", torch.tensor(0.4543)) + + def forward( + self, + x: torch.Tensor, + node_attrs: torch.Tensor, + edge_index: torch.Tensor, + atomic_numbers: torch.Tensor, + ) -> torch.Tensor: + sender = edge_index[0] + receiver = edge_index[1] + node_atomic_numbers = atomic_numbers[torch.argmax(node_attrs, dim=1)].unsqueeze( + -1 + ) + Z_u = node_atomic_numbers[sender] + Z_v = node_atomic_numbers[receiver] + a = ( + self.a_prefactor + * 0.529 + / (torch.pow(Z_u, self.a_exp) + torch.pow(Z_v, self.a_exp)) + ) + r_over_a = x / a + phi = ( + self.c[0] * torch.exp(-3.2 * r_over_a) + + self.c[1] * torch.exp(-0.9423 * r_over_a) + + self.c[2] * torch.exp(-0.4028 * r_over_a) + + self.c[3] * torch.exp(-0.2016 * r_over_a) + ) + v_edges = (14.3996 * Z_u * Z_v) / x * phi + r_max = self.covalent_radii[Z_u] + self.covalent_radii[Z_v] + envelope = PolynomialCutoff.calculate_envelope(x, r_max, self.p) + v_edges = 0.5 * v_edges * envelope + V_ZBL = scatter_sum(v_edges, receiver, dim=0, dim_size=node_attrs.size(0)) + return V_ZBL.squeeze(-1) + + def __repr__(self): + return f"{self.__class__.__name__}(c={self.c})" + + +@compile_mode("script") +class AgnesiTransform(torch.nn.Module): + """Agnesi transform - see section on Radial transformations in + ACEpotentials.jl, JCP 2023 (https://doi.org/10.1063/5.0158783). + """ + + def __init__( + self, + q: float = 0.9183, + p: float = 4.5791, + a: float = 1.0805, + trainable=False, + ): + super().__init__() + self.register_buffer("q", torch.tensor(q, dtype=torch.get_default_dtype())) + self.register_buffer("p", torch.tensor(p, dtype=torch.get_default_dtype())) + self.register_buffer("a", torch.tensor(a, dtype=torch.get_default_dtype())) + self.register_buffer( + "covalent_radii", + torch.tensor( + ase.data.covalent_radii, + dtype=torch.get_default_dtype(), + ), + ) + if trainable: + self.a = torch.nn.Parameter(torch.tensor(1.0805, requires_grad=True)) + self.q = torch.nn.Parameter(torch.tensor(0.9183, requires_grad=True)) + self.p = torch.nn.Parameter(torch.tensor(4.5791, requires_grad=True)) + + def forward( + self, + x: torch.Tensor, + node_attrs: torch.Tensor, + edge_index: torch.Tensor, + atomic_numbers: torch.Tensor, + ) -> torch.Tensor: + sender = edge_index[0] + receiver = edge_index[1] + node_atomic_numbers = atomic_numbers[torch.argmax(node_attrs, dim=1)].unsqueeze( + -1 + ) + Z_u = node_atomic_numbers[sender] + Z_v = node_atomic_numbers[receiver] + r_0: torch.Tensor = 0.5 * (self.covalent_radii[Z_u] + self.covalent_radii[Z_v]) + r_over_r_0 = x / r_0 + return ( + 1 + + ( + self.a + * torch.pow(r_over_r_0, self.q) + / (1 + torch.pow(r_over_r_0, self.q - self.p)) + ) + ).reciprocal_() + + def __repr__(self): + return ( + f"{self.__class__.__name__}(a={self.a:.4f}, q={self.q:.4f}, p={self.p:.4f})" + ) + + +@compile_mode("script") +class SoftTransform(torch.nn.Module): + """ + Tanh-based smooth transformation: + T(x) = p1 + (x - p1)*0.5*[1 + tanh(alpha*(x - m))], + which smoothly transitions from ~p1 for x << p1 to ~x for x >> r0. + """ + + def __init__(self, alpha: float = 4.0, trainable=False): + """ + Args: + p1 (float): Lower "clamp" point. + alpha (float): Steepness; if None, defaults to ~6/(r0-p1). + trainable (bool): Whether to make parameters trainable. + """ + super().__init__() + # Initialize parameters + self.register_buffer( + "alpha", torch.tensor(alpha, dtype=torch.get_default_dtype()) + ) + if trainable: + self.alpha = torch.nn.Parameter(self.alpha.clone()) + self.register_buffer( + "covalent_radii", + torch.tensor( + ase.data.covalent_radii, + dtype=torch.get_default_dtype(), + ), + ) + + def compute_r_0( + self, + node_attrs: torch.Tensor, + edge_index: torch.Tensor, + atomic_numbers: torch.Tensor, + ) -> torch.Tensor: + """ + Compute r_0 based on atomic information. + + Args: + node_attrs (torch.Tensor): Node attributes (one-hot encoding of atomic numbers). + edge_index (torch.Tensor): Edge index indicating connections. + atomic_numbers (torch.Tensor): Atomic numbers. + + Returns: + torch.Tensor: r_0 values for each edge. + """ + sender = edge_index[0] + receiver = edge_index[1] + node_atomic_numbers = atomic_numbers[torch.argmax(node_attrs, dim=1)].unsqueeze( + -1 + ) + Z_u = node_atomic_numbers[sender] + Z_v = node_atomic_numbers[receiver] + r_0: torch.Tensor = self.covalent_radii[Z_u] + self.covalent_radii[Z_v] + return r_0 + + def forward( + self, + x: torch.Tensor, + node_attrs: torch.Tensor, + edge_index: torch.Tensor, + atomic_numbers: torch.Tensor, + ) -> torch.Tensor: + + r_0 = self.compute_r_0(node_attrs, edge_index, atomic_numbers) + p_0 = (3 / 4) * r_0 + p_1 = (4 / 3) * r_0 + m = 0.5 * (p_0 + p_1) + alpha = self.alpha / (p_1 - p_0) + s_x = 0.5 * (1.0 + torch.tanh(alpha * (x - m))) + return p_0 + (x - p_0) * s_x + + def __repr__(self): + return f"{self.__class__.__name__}(alpha={self.alpha.item():.4f})" diff --git a/mace-bench/3rdparty/mace/mace/modules/symmetric_contraction.py b/mace-bench/3rdparty/mace/mace/modules/symmetric_contraction.py index 577713c..9db75da 100644 --- a/mace-bench/3rdparty/mace/mace/modules/symmetric_contraction.py +++ b/mace-bench/3rdparty/mace/mace/modules/symmetric_contraction.py @@ -1,233 +1,233 @@ -########################################################################################### -# Implementation of the symmetric contraction algorithm presented in the MACE paper -# (Batatia et al, MACE: Higher Order Equivariant Message Passing Neural Networks for Fast and Accurate Force Fields , Eq.10 and 11) -# Authors: Ilyes Batatia -# This program is distributed under the MIT License (see MIT.md) -########################################################################################### - -from typing import Dict, Optional, Union - -import opt_einsum_fx -import torch -import torch.fx -from e3nn import o3 -from e3nn.util.codegen import CodeGenMixin -from e3nn.util.jit import compile_mode - -from mace.tools.cg import U_matrix_real - -BATCH_EXAMPLE = 10 -ALPHABET = ["w", "x", "v", "n", "z", "r", "t", "y", "u", "o", "p", "s"] - - -@compile_mode("script") -class SymmetricContraction(CodeGenMixin, torch.nn.Module): - def __init__( - self, - irreps_in: o3.Irreps, - irreps_out: o3.Irreps, - correlation: Union[int, Dict[str, int]], - irrep_normalization: str = "component", - path_normalization: str = "element", - internal_weights: Optional[bool] = None, - shared_weights: Optional[bool] = None, - num_elements: Optional[int] = None, - ) -> None: - super().__init__() - - if irrep_normalization is None: - irrep_normalization = "component" - - if path_normalization is None: - path_normalization = "element" - - assert irrep_normalization in ["component", "norm", "none"] - assert path_normalization in ["element", "path", "none"] - - self.irreps_in = o3.Irreps(irreps_in) - self.irreps_out = o3.Irreps(irreps_out) - - del irreps_in, irreps_out - - if not isinstance(correlation, tuple): - corr = correlation - correlation = {} - for irrep_out in self.irreps_out: - correlation[irrep_out] = corr - - assert shared_weights or not internal_weights - - if internal_weights is None: - internal_weights = True - - self.internal_weights = internal_weights - self.shared_weights = shared_weights - - del internal_weights, shared_weights - - self.contractions = torch.nn.ModuleList() - for irrep_out in self.irreps_out: - self.contractions.append( - Contraction( - irreps_in=self.irreps_in, - irrep_out=o3.Irreps(str(irrep_out.ir)), - correlation=correlation[irrep_out], - internal_weights=self.internal_weights, - num_elements=num_elements, - weights=self.shared_weights, - ) - ) - - def forward(self, x: torch.Tensor, y: torch.Tensor): - outs = [contraction(x, y) for contraction in self.contractions] - return torch.cat(outs, dim=-1) - - -@compile_mode("script") -class Contraction(torch.nn.Module): - def __init__( - self, - irreps_in: o3.Irreps, - irrep_out: o3.Irreps, - correlation: int, - internal_weights: bool = True, - num_elements: Optional[int] = None, - weights: Optional[torch.Tensor] = None, - ) -> None: - super().__init__() - - self.num_features = irreps_in.count((0, 1)) - self.coupling_irreps = o3.Irreps([irrep.ir for irrep in irreps_in]) - self.correlation = correlation - dtype = torch.get_default_dtype() - for nu in range(1, correlation + 1): - U_matrix = U_matrix_real( - irreps_in=self.coupling_irreps, - irreps_out=irrep_out, - correlation=nu, - dtype=dtype, - )[-1] - self.register_buffer(f"U_matrix_{nu}", U_matrix) - - # Tensor contraction equations - self.contractions_weighting = torch.nn.ModuleList() - self.contractions_features = torch.nn.ModuleList() - - # Create weight for product basis - self.weights = torch.nn.ParameterList([]) - - for i in range(correlation, 0, -1): - # Shapes definying - num_params = self.U_tensors(i).size()[-1] - num_equivariance = 2 * irrep_out.lmax + 1 - num_ell = self.U_tensors(i).size()[-2] - - if i == correlation: - parse_subscript_main = ( - [ALPHABET[j] for j in range(i + min(irrep_out.lmax, 1) - 1)] - + ["ik,ekc,bci,be -> bc"] - + [ALPHABET[j] for j in range(i + min(irrep_out.lmax, 1) - 1)] - ) - graph_module_main = torch.fx.symbolic_trace( - lambda x, y, w, z: torch.einsum( - "".join(parse_subscript_main), x, y, w, z - ) - ) - - # Optimizing the contractions - self.graph_opt_main = opt_einsum_fx.optimize_einsums_full( - model=graph_module_main, - example_inputs=( - torch.randn( - [num_equivariance] + [num_ell] * i + [num_params] - ).squeeze(0), - torch.randn((num_elements, num_params, self.num_features)), - torch.randn((BATCH_EXAMPLE, self.num_features, num_ell)), - torch.randn((BATCH_EXAMPLE, num_elements)), - ), - ) - # Parameters for the product basis - w = torch.nn.Parameter( - torch.randn((num_elements, num_params, self.num_features)) - / num_params - ) - self.weights_max = w - else: - # Generate optimized contractions equations - parse_subscript_weighting = ( - [ALPHABET[j] for j in range(i + min(irrep_out.lmax, 1))] - + ["k,ekc,be->bc"] - + [ALPHABET[j] for j in range(i + min(irrep_out.lmax, 1))] - ) - parse_subscript_features = ( - ["bc"] - + [ALPHABET[j] for j in range(i - 1 + min(irrep_out.lmax, 1))] - + ["i,bci->bc"] - + [ALPHABET[j] for j in range(i - 1 + min(irrep_out.lmax, 1))] - ) - - # Symbolic tracing of contractions - graph_module_weighting = torch.fx.symbolic_trace( - lambda x, y, z: torch.einsum( - "".join(parse_subscript_weighting), x, y, z - ) - ) - graph_module_features = torch.fx.symbolic_trace( - lambda x, y: torch.einsum("".join(parse_subscript_features), x, y) - ) - - # Optimizing the contractions - graph_opt_weighting = opt_einsum_fx.optimize_einsums_full( - model=graph_module_weighting, - example_inputs=( - torch.randn( - [num_equivariance] + [num_ell] * i + [num_params] - ).squeeze(0), - torch.randn((num_elements, num_params, self.num_features)), - torch.randn((BATCH_EXAMPLE, num_elements)), - ), - ) - graph_opt_features = opt_einsum_fx.optimize_einsums_full( - model=graph_module_features, - example_inputs=( - torch.randn( - [BATCH_EXAMPLE, self.num_features, num_equivariance] - + [num_ell] * i - ).squeeze(2), - torch.randn((BATCH_EXAMPLE, self.num_features, num_ell)), - ), - ) - self.contractions_weighting.append(graph_opt_weighting) - self.contractions_features.append(graph_opt_features) - # Parameters for the product basis - w = torch.nn.Parameter( - torch.randn((num_elements, num_params, self.num_features)) - / num_params - ) - self.weights.append(w) - if not internal_weights: - self.weights = weights[:-1] - self.weights_max = weights[-1] - - def forward(self, x: torch.Tensor, y: torch.Tensor): - out = self.graph_opt_main( - self.U_tensors(self.correlation), - self.weights_max, - x, - y, - ) - for i, (weight, contract_weights, contract_features) in enumerate( - zip(self.weights, self.contractions_weighting, self.contractions_features) - ): - c_tensor = contract_weights( - self.U_tensors(self.correlation - i - 1), - weight, - y, - ) - c_tensor = c_tensor + out - out = contract_features(c_tensor, x) - - return out.view(out.shape[0], -1) - - def U_tensors(self, nu: int): - return dict(self.named_buffers())[f"U_matrix_{nu}"] +########################################################################################### +# Implementation of the symmetric contraction algorithm presented in the MACE paper +# (Batatia et al, MACE: Higher Order Equivariant Message Passing Neural Networks for Fast and Accurate Force Fields , Eq.10 and 11) +# Authors: Ilyes Batatia +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +from typing import Dict, Optional, Union + +import opt_einsum_fx +import torch +import torch.fx +from e3nn import o3 +from e3nn.util.codegen import CodeGenMixin +from e3nn.util.jit import compile_mode + +from mace.tools.cg import U_matrix_real + +BATCH_EXAMPLE = 10 +ALPHABET = ["w", "x", "v", "n", "z", "r", "t", "y", "u", "o", "p", "s"] + + +@compile_mode("script") +class SymmetricContraction(CodeGenMixin, torch.nn.Module): + def __init__( + self, + irreps_in: o3.Irreps, + irreps_out: o3.Irreps, + correlation: Union[int, Dict[str, int]], + irrep_normalization: str = "component", + path_normalization: str = "element", + internal_weights: Optional[bool] = None, + shared_weights: Optional[bool] = None, + num_elements: Optional[int] = None, + ) -> None: + super().__init__() + + if irrep_normalization is None: + irrep_normalization = "component" + + if path_normalization is None: + path_normalization = "element" + + assert irrep_normalization in ["component", "norm", "none"] + assert path_normalization in ["element", "path", "none"] + + self.irreps_in = o3.Irreps(irreps_in) + self.irreps_out = o3.Irreps(irreps_out) + + del irreps_in, irreps_out + + if not isinstance(correlation, tuple): + corr = correlation + correlation = {} + for irrep_out in self.irreps_out: + correlation[irrep_out] = corr + + assert shared_weights or not internal_weights + + if internal_weights is None: + internal_weights = True + + self.internal_weights = internal_weights + self.shared_weights = shared_weights + + del internal_weights, shared_weights + + self.contractions = torch.nn.ModuleList() + for irrep_out in self.irreps_out: + self.contractions.append( + Contraction( + irreps_in=self.irreps_in, + irrep_out=o3.Irreps(str(irrep_out.ir)), + correlation=correlation[irrep_out], + internal_weights=self.internal_weights, + num_elements=num_elements, + weights=self.shared_weights, + ) + ) + + def forward(self, x: torch.Tensor, y: torch.Tensor): + outs = [contraction(x, y) for contraction in self.contractions] + return torch.cat(outs, dim=-1) + + +@compile_mode("script") +class Contraction(torch.nn.Module): + def __init__( + self, + irreps_in: o3.Irreps, + irrep_out: o3.Irreps, + correlation: int, + internal_weights: bool = True, + num_elements: Optional[int] = None, + weights: Optional[torch.Tensor] = None, + ) -> None: + super().__init__() + + self.num_features = irreps_in.count((0, 1)) + self.coupling_irreps = o3.Irreps([irrep.ir for irrep in irreps_in]) + self.correlation = correlation + dtype = torch.get_default_dtype() + for nu in range(1, correlation + 1): + U_matrix = U_matrix_real( + irreps_in=self.coupling_irreps, + irreps_out=irrep_out, + correlation=nu, + dtype=dtype, + )[-1] + self.register_buffer(f"U_matrix_{nu}", U_matrix) + + # Tensor contraction equations + self.contractions_weighting = torch.nn.ModuleList() + self.contractions_features = torch.nn.ModuleList() + + # Create weight for product basis + self.weights = torch.nn.ParameterList([]) + + for i in range(correlation, 0, -1): + # Shapes definying + num_params = self.U_tensors(i).size()[-1] + num_equivariance = 2 * irrep_out.lmax + 1 + num_ell = self.U_tensors(i).size()[-2] + + if i == correlation: + parse_subscript_main = ( + [ALPHABET[j] for j in range(i + min(irrep_out.lmax, 1) - 1)] + + ["ik,ekc,bci,be -> bc"] + + [ALPHABET[j] for j in range(i + min(irrep_out.lmax, 1) - 1)] + ) + graph_module_main = torch.fx.symbolic_trace( + lambda x, y, w, z: torch.einsum( + "".join(parse_subscript_main), x, y, w, z + ) + ) + + # Optimizing the contractions + self.graph_opt_main = opt_einsum_fx.optimize_einsums_full( + model=graph_module_main, + example_inputs=( + torch.randn( + [num_equivariance] + [num_ell] * i + [num_params] + ).squeeze(0), + torch.randn((num_elements, num_params, self.num_features)), + torch.randn((BATCH_EXAMPLE, self.num_features, num_ell)), + torch.randn((BATCH_EXAMPLE, num_elements)), + ), + ) + # Parameters for the product basis + w = torch.nn.Parameter( + torch.randn((num_elements, num_params, self.num_features)) + / num_params + ) + self.weights_max = w + else: + # Generate optimized contractions equations + parse_subscript_weighting = ( + [ALPHABET[j] for j in range(i + min(irrep_out.lmax, 1))] + + ["k,ekc,be->bc"] + + [ALPHABET[j] for j in range(i + min(irrep_out.lmax, 1))] + ) + parse_subscript_features = ( + ["bc"] + + [ALPHABET[j] for j in range(i - 1 + min(irrep_out.lmax, 1))] + + ["i,bci->bc"] + + [ALPHABET[j] for j in range(i - 1 + min(irrep_out.lmax, 1))] + ) + + # Symbolic tracing of contractions + graph_module_weighting = torch.fx.symbolic_trace( + lambda x, y, z: torch.einsum( + "".join(parse_subscript_weighting), x, y, z + ) + ) + graph_module_features = torch.fx.symbolic_trace( + lambda x, y: torch.einsum("".join(parse_subscript_features), x, y) + ) + + # Optimizing the contractions + graph_opt_weighting = opt_einsum_fx.optimize_einsums_full( + model=graph_module_weighting, + example_inputs=( + torch.randn( + [num_equivariance] + [num_ell] * i + [num_params] + ).squeeze(0), + torch.randn((num_elements, num_params, self.num_features)), + torch.randn((BATCH_EXAMPLE, num_elements)), + ), + ) + graph_opt_features = opt_einsum_fx.optimize_einsums_full( + model=graph_module_features, + example_inputs=( + torch.randn( + [BATCH_EXAMPLE, self.num_features, num_equivariance] + + [num_ell] * i + ).squeeze(2), + torch.randn((BATCH_EXAMPLE, self.num_features, num_ell)), + ), + ) + self.contractions_weighting.append(graph_opt_weighting) + self.contractions_features.append(graph_opt_features) + # Parameters for the product basis + w = torch.nn.Parameter( + torch.randn((num_elements, num_params, self.num_features)) + / num_params + ) + self.weights.append(w) + if not internal_weights: + self.weights = weights[:-1] + self.weights_max = weights[-1] + + def forward(self, x: torch.Tensor, y: torch.Tensor): + out = self.graph_opt_main( + self.U_tensors(self.correlation), + self.weights_max, + x, + y, + ) + for i, (weight, contract_weights, contract_features) in enumerate( + zip(self.weights, self.contractions_weighting, self.contractions_features) + ): + c_tensor = contract_weights( + self.U_tensors(self.correlation - i - 1), + weight, + y, + ) + c_tensor = c_tensor + out + out = contract_features(c_tensor, x) + + return out.view(out.shape[0], -1) + + def U_tensors(self, nu: int): + return dict(self.named_buffers())[f"U_matrix_{nu}"] diff --git a/mace-bench/3rdparty/mace/mace/modules/utils.py b/mace-bench/3rdparty/mace/mace/modules/utils.py index 59da118..6a5a8e0 100644 --- a/mace-bench/3rdparty/mace/mace/modules/utils.py +++ b/mace-bench/3rdparty/mace/mace/modules/utils.py @@ -1,582 +1,582 @@ -########################################################################################### -# Utilities -# Authors: Ilyes Batatia, Gregor Simm and David Kovacs -# This program is distributed under the MIT License (see MIT.md) -########################################################################################### - -import logging -from typing import Dict, List, NamedTuple, Optional, Tuple - -import numpy as np -import torch -import torch.utils.data -from scipy.constants import c, e - -from mace.tools import to_numpy -from mace.tools.scatter import scatter_mean, scatter_std, scatter_sum -from mace.tools.torch_geometric.batch import Batch - -from .blocks import AtomicEnergiesBlock - - -def compute_forces( - energy: torch.Tensor, positions: torch.Tensor, training: bool = True -) -> torch.Tensor: - grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(energy)] - gradient = torch.autograd.grad( - outputs=[energy], # [n_graphs, ] - inputs=[positions], # [n_nodes, 3] - grad_outputs=grad_outputs, - retain_graph=training, # Make sure the graph is not destroyed during training - create_graph=training, # Create graph for second derivative - allow_unused=True, # For complete dissociation turn to true - )[ - 0 - ] # [n_nodes, 3] - if gradient is None: - return torch.zeros_like(positions) - return -1 * gradient - - -def compute_forces_virials( - energy: torch.Tensor, - positions: torch.Tensor, - displacement: torch.Tensor, - cell: torch.Tensor, - training: bool = True, - compute_stress: bool = False, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: - grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(energy)] - forces, virials = torch.autograd.grad( - outputs=[energy], # [n_graphs, ] - inputs=[positions, displacement], # [n_nodes, 3] - grad_outputs=grad_outputs, - retain_graph=training, # Make sure the graph is not destroyed during training - create_graph=training, # Create graph for second derivative - allow_unused=True, - ) - stress = torch.zeros_like(displacement) - if compute_stress and virials is not None: - cell = cell.view(-1, 3, 3) - volume = torch.linalg.det(cell).abs().unsqueeze(-1) - stress = virials / volume.view(-1, 1, 1) - stress = torch.where(torch.abs(stress) < 1e10, stress, torch.zeros_like(stress)) - if forces is None: - forces = torch.zeros_like(positions) - if virials is None: - virials = torch.zeros((1, 3, 3)) - - return -1 * forces, -1 * virials, stress - - -def get_symmetric_displacement( - positions: torch.Tensor, - unit_shifts: torch.Tensor, - cell: Optional[torch.Tensor], - edge_index: torch.Tensor, - num_graphs: int, - batch: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - if cell is None: - cell = torch.zeros( - num_graphs * 3, - 3, - dtype=positions.dtype, - device=positions.device, - ) - sender = edge_index[0] - displacement = torch.zeros( - (num_graphs, 3, 3), - dtype=positions.dtype, - device=positions.device, - ) - displacement.requires_grad_(True) - symmetric_displacement = 0.5 * ( - displacement + displacement.transpose(-1, -2) - ) # From https://github.com/mir-group/nequip - positions = positions + torch.einsum( - "be,bec->bc", positions, symmetric_displacement[batch] - ) - cell = cell.view(-1, 3, 3) - cell = cell + torch.matmul(cell, symmetric_displacement) - shifts = torch.einsum( - "be,bec->bc", - unit_shifts, - cell[batch[sender]], - ) - return positions, shifts, displacement - - -@torch.jit.unused -def compute_hessians_vmap( - forces: torch.Tensor, - positions: torch.Tensor, -) -> torch.Tensor: - forces_flatten = forces.view(-1) - num_elements = forces_flatten.shape[0] - - def get_vjp(v): - return torch.autograd.grad( - -1 * forces_flatten, - positions, - v, - retain_graph=True, - create_graph=False, - allow_unused=False, - ) - - I_N = torch.eye(num_elements).to(forces.device) - try: - chunk_size = 1 if num_elements < 64 else 16 - gradient = torch.vmap(get_vjp, in_dims=0, out_dims=0, chunk_size=chunk_size)( - I_N - )[0] - except RuntimeError: - gradient = compute_hessians_loop(forces, positions) - if gradient is None: - return torch.zeros((positions.shape[0], forces.shape[0], 3, 3)) - return gradient - - -@torch.jit.unused -def compute_hessians_loop( - forces: torch.Tensor, - positions: torch.Tensor, -) -> torch.Tensor: - hessian = [] - for grad_elem in forces.view(-1): - hess_row = torch.autograd.grad( - outputs=[-1 * grad_elem], - inputs=[positions], - grad_outputs=torch.ones_like(grad_elem), - retain_graph=True, - create_graph=False, - allow_unused=False, - )[0] - hess_row = hess_row.detach() # this makes it very slow? but needs less memory - if hess_row is None: - hessian.append(torch.zeros_like(positions)) - else: - hessian.append(hess_row) - hessian = torch.stack(hessian) - return hessian - - -def get_outputs( - energy: torch.Tensor, - positions: torch.Tensor, - cell: torch.Tensor, - displacement: Optional[torch.Tensor], - vectors: Optional[torch.Tensor] = None, - training: bool = False, - compute_force: bool = True, - compute_virials: bool = True, - compute_stress: bool = True, - compute_hessian: bool = False, - compute_edge_forces: bool = False, -) -> Tuple[ - Optional[torch.Tensor], - Optional[torch.Tensor], - Optional[torch.Tensor], - Optional[torch.Tensor], - Optional[torch.Tensor], -]: - if (compute_virials or compute_stress) and displacement is not None: - forces, virials, stress = compute_forces_virials( - energy=energy, - positions=positions, - displacement=displacement, - cell=cell, - compute_stress=compute_stress, - training=(training or compute_hessian or compute_edge_forces), - ) - elif compute_force: - forces, virials, stress = ( - compute_forces( - energy=energy, - positions=positions, - training=(training or compute_hessian or compute_edge_forces), - ), - None, - None, - ) - else: - forces, virials, stress = (None, None, None) - if compute_hessian: - assert forces is not None, "Forces must be computed to get the hessian" - hessian = compute_hessians_vmap(forces, positions) - else: - hessian = None - if compute_edge_forces and vectors is not None: - edge_forces = compute_forces( - energy=energy, - positions=vectors, - training=(training or compute_hessian), - ) - if edge_forces is not None: - edge_forces = -1 * edge_forces # Match LAMMPS sign convention - else: - edge_forces = None - return forces, virials, stress, hessian, edge_forces - - -def get_atomic_virials_stresses( - edge_forces: torch.Tensor, # [n_edges, 3] - edge_index: torch.Tensor, # [2, n_edges] - vectors: torch.Tensor, # [n_edges, 3] - num_atoms: int, - batch: torch.Tensor, - cell: torch.Tensor, # [n_graphs, 3, 3] -) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - """ - Compute atomic virials and optionally atomic stresses from edge forces and vectors. - From pobo95 PR #528. - Returns: - Tuple of: - - Atomic virials [num_atoms, 3, 3] - - Atomic stresses [num_atoms, 3, 3] (None if not computed) - """ - edge_virial = torch.einsum("zi,zj->zij", edge_forces, vectors) - atom_virial_sender = scatter_sum( - src=edge_virial, index=edge_index[0], dim=0, dim_size=num_atoms - ) - atom_virial_receiver = scatter_sum( - src=edge_virial, index=edge_index[1], dim=0, dim_size=num_atoms - ) - atom_virial = (atom_virial_sender + atom_virial_receiver) / 2 - atom_virial = (atom_virial + atom_virial.transpose(-1, -2)) / 2 - atom_stress = None - cell = cell.view(-1, 3, 3) - volume = torch.linalg.det(cell).abs().unsqueeze(-1) - atom_volume = volume[batch].view(-1, 1, 1) - atom_stress = atom_virial / atom_volume - atom_stress = torch.where( - torch.abs(atom_stress) < 1e10, atom_stress, torch.zeros_like(atom_stress) - ) - return -1 * atom_virial, atom_stress - - -def get_edge_vectors_and_lengths( - positions: torch.Tensor, # [n_nodes, 3] - edge_index: torch.Tensor, # [2, n_edges] - shifts: torch.Tensor, # [n_edges, 3] - normalize: bool = False, - eps: float = 1e-9, -) -> Tuple[torch.Tensor, torch.Tensor]: - sender = edge_index[0] - receiver = edge_index[1] - vectors = positions[receiver] - positions[sender] + shifts # [n_edges, 3] - lengths = torch.linalg.norm(vectors, dim=-1, keepdim=True) # [n_edges, 1] - if normalize: - vectors_normed = vectors / (lengths + eps) - return vectors_normed, lengths - - return vectors, lengths - - -def _check_non_zero(std): - if np.any(std == 0): - logging.warning( - "Standard deviation of the scaling is zero, Changing to no scaling" - ) - std[std == 0] = 1 - return std - - -def extract_invariant(x: torch.Tensor, num_layers: int, num_features: int, l_max: int): - out = [] - out.append(x[:, :num_features]) - for i in range(1, num_layers): - out.append( - x[ - :, - i - * (l_max + 1) ** 2 - * num_features : (i * (l_max + 1) ** 2 + 1) - * num_features, - ] - ) - return torch.cat(out, dim=-1) - - -def compute_mean_std_atomic_inter_energy( - data_loader: torch.utils.data.DataLoader, - atomic_energies: np.ndarray, -) -> Tuple[float, float]: - atomic_energies_fn = AtomicEnergiesBlock(atomic_energies=atomic_energies) - - avg_atom_inter_es_list = [] - head_list = [] - - for batch in data_loader: - node_e0 = atomic_energies_fn(batch.node_attrs) - graph_e0s = scatter_sum( - src=node_e0, index=batch.batch, dim=0, dim_size=batch.num_graphs - )[torch.arange(batch.num_graphs), batch.head] - graph_sizes = batch.ptr[1:] - batch.ptr[:-1] - avg_atom_inter_es_list.append( - (batch.energy - graph_e0s) / graph_sizes - ) # {[n_graphs], } - head_list.append(batch.head) - - avg_atom_inter_es = torch.cat(avg_atom_inter_es_list) # [total_n_graphs] - head = torch.cat(head_list, dim=0) # [total_n_graphs] - # mean = to_numpy(torch.mean(avg_atom_inter_es)).item() - # std = to_numpy(torch.std(avg_atom_inter_es)).item() - mean = to_numpy(scatter_mean(src=avg_atom_inter_es, index=head, dim=0).squeeze(-1)) - std = to_numpy(scatter_std(src=avg_atom_inter_es, index=head, dim=0).squeeze(-1)) - std = _check_non_zero(std) - - return mean, std - - -def _compute_mean_std_atomic_inter_energy( - batch: Batch, - atomic_energies_fn: AtomicEnergiesBlock, -) -> Tuple[torch.Tensor, torch.Tensor]: - head = batch.head - node_e0 = atomic_energies_fn(batch.node_attrs) - graph_e0s = scatter_sum( - src=node_e0, index=batch.batch, dim=0, dim_size=batch.num_graphs - )[torch.arange(batch.num_graphs), head] - graph_sizes = batch.ptr[1:] - batch.ptr[:-1] - atom_energies = (batch.energy - graph_e0s) / graph_sizes - return atom_energies - - -def compute_mean_rms_energy_forces( - data_loader: torch.utils.data.DataLoader, - atomic_energies: np.ndarray, -) -> Tuple[float, float]: - atomic_energies_fn = AtomicEnergiesBlock(atomic_energies=atomic_energies) - - atom_energy_list = [] - forces_list = [] - head_list = [] - head_batch = [] - - for batch in data_loader: - head = batch.head - node_e0 = atomic_energies_fn(batch.node_attrs) - graph_e0s = scatter_sum( - src=node_e0, index=batch.batch, dim=0, dim_size=batch.num_graphs - )[torch.arange(batch.num_graphs), head] - graph_sizes = batch.ptr[1:] - batch.ptr[:-1] - atom_energy_list.append( - (batch.energy - graph_e0s) / graph_sizes - ) # {[n_graphs], } - forces_list.append(batch.forces) # {[n_graphs*n_atoms,3], } - head_list.append(head) - head_batch.append(head[batch.batch]) - - atom_energies = torch.cat(atom_energy_list, dim=0) # [total_n_graphs] - forces = torch.cat(forces_list, dim=0) # {[total_n_graphs*n_atoms,3], } - head = torch.cat(head_list, dim=0) # [total_n_graphs] - head_batch = torch.cat(head_batch, dim=0) # [total_n_graphs] - - # mean = to_numpy(torch.mean(atom_energies)).item() - # rms = to_numpy(torch.sqrt(torch.mean(torch.square(forces)))).item() - mean = to_numpy(scatter_mean(src=atom_energies, index=head, dim=0).squeeze(-1)) - rms = to_numpy( - torch.sqrt( - scatter_mean(src=torch.square(forces), index=head_batch, dim=0).mean(-1) - ) - ) - rms = _check_non_zero(rms) - - return mean, rms - - -def _compute_mean_rms_energy_forces( - batch: Batch, - atomic_energies_fn: AtomicEnergiesBlock, -) -> Tuple[torch.Tensor, torch.Tensor]: - head = batch.head - node_e0 = atomic_energies_fn(batch.node_attrs) - graph_e0s = scatter_sum( - src=node_e0, index=batch.batch, dim=0, dim_size=batch.num_graphs - )[torch.arange(batch.num_graphs), head] - graph_sizes = batch.ptr[1:] - batch.ptr[:-1] - atom_energies = (batch.energy - graph_e0s) / graph_sizes # {[n_graphs], } - forces = batch.forces # {[n_graphs*n_atoms,3], } - - return atom_energies, forces - - -def compute_avg_num_neighbors(data_loader: torch.utils.data.DataLoader) -> float: - num_neighbors = [] - for batch in data_loader: - _, receivers = batch.edge_index - _, counts = torch.unique(receivers, return_counts=True) - num_neighbors.append(counts) - - avg_num_neighbors = torch.mean( - torch.cat(num_neighbors, dim=0).type(torch.get_default_dtype()) - ) - return to_numpy(avg_num_neighbors).item() - - -def compute_statistics( - data_loader: torch.utils.data.DataLoader, - atomic_energies: np.ndarray, -) -> Tuple[float, float, float, float]: - atomic_energies_fn = AtomicEnergiesBlock(atomic_energies=atomic_energies) - - atom_energy_list = [] - forces_list = [] - num_neighbors = [] - head_list = [] - head_batch = [] - - for batch in data_loader: - head = batch.head - node_e0 = atomic_energies_fn(batch.node_attrs) - graph_e0s = scatter_sum( - src=node_e0, index=batch.batch, dim=0, dim_size=batch.num_graphs - )[torch.arange(batch.num_graphs), head] - graph_sizes = batch.ptr[1:] - batch.ptr[:-1] - atom_energy_list.append( - (batch.energy - graph_e0s) / graph_sizes - ) # {[n_graphs], } - forces_list.append(batch.forces) # {[n_graphs*n_atoms,3], } - head_list.append(head) # {[n_graphs], } - head_batch.append(head[batch.batch]) - _, receivers = batch.edge_index - _, counts = torch.unique(receivers, return_counts=True) - num_neighbors.append(counts) - - atom_energies = torch.cat(atom_energy_list, dim=0) # [total_n_graphs] - forces = torch.cat(forces_list, dim=0) # {[total_n_graphs*n_atoms,3], } - head = torch.cat(head_list, dim=0) # [total_n_graphs] - head_batch = torch.cat(head_batch, dim=0) # [total_n_graphs] - - # mean = to_numpy(torch.mean(atom_energies)).item() - mean = to_numpy(scatter_mean(src=atom_energies, index=head, dim=0).squeeze(-1)) - rms = to_numpy( - torch.sqrt( - scatter_mean(src=torch.square(forces), index=head_batch, dim=0).mean(-1) - ) - ) - - avg_num_neighbors = torch.mean( - torch.cat(num_neighbors, dim=0).type(torch.get_default_dtype()) - ) - - return to_numpy(avg_num_neighbors).item(), mean, rms - - -def compute_rms_dipoles( - data_loader: torch.utils.data.DataLoader, -) -> Tuple[float, float]: - dipoles_list = [] - for batch in data_loader: - dipoles_list.append(batch.dipole) # {[n_graphs,3], } - - dipoles = torch.cat(dipoles_list, dim=0) # {[total_n_graphs,3], } - rms = to_numpy(torch.sqrt(torch.mean(torch.square(dipoles)))).item() - rms = _check_non_zero(rms) - return rms - - -def compute_fixed_charge_dipole( - charges: torch.Tensor, - positions: torch.Tensor, - batch: torch.Tensor, - num_graphs: int, -) -> torch.Tensor: - mu = positions * charges.unsqueeze(-1) / (1e-11 / c / e) # [N_atoms,3] - return scatter_sum( - src=mu, index=batch.unsqueeze(-1), dim=0, dim_size=num_graphs - ) # [N_graphs,3] - - -class InteractionKwargs(NamedTuple): - lammps_class: Optional[torch.Tensor] - lammps_natoms: Tuple[int, int] = (0, 0) - - -class GraphContext(NamedTuple): - is_lammps: bool - num_graphs: int - num_atoms_arange: torch.Tensor - displacement: Optional[torch.Tensor] - positions: torch.Tensor - vectors: torch.Tensor - lengths: torch.Tensor - cell: torch.Tensor - node_heads: torch.Tensor - interaction_kwargs: InteractionKwargs - - -def prepare_graph( - data: Dict[str, torch.Tensor], - compute_virials: bool = False, - compute_stress: bool = False, - compute_displacement: bool = False, - lammps_mliap: bool = False, -) -> GraphContext: - if torch.jit.is_scripting(): - lammps_mliap = False - - node_heads = ( - data["head"][data["batch"]] - if "head" in data - else torch.zeros_like(data["batch"]) - ) - - if lammps_mliap: - n_real, n_total = data["natoms"][0], data["natoms"][1] - num_graphs = 2 - num_atoms_arange = torch.arange(n_real, device=data["node_attrs"].device) - displacement = None - positions = torch.zeros( - (int(n_real), 3), - dtype=data["vectors"].dtype, - device=data["vectors"].device, - ) - cell = torch.zeros( - (num_graphs, 3, 3), - dtype=data["vectors"].dtype, - device=data["vectors"].device, - ) - vectors = data["vectors"].requires_grad_(True) - lengths = torch.linalg.vector_norm(vectors, dim=1, keepdim=True) - ikw = InteractionKwargs(data["lammps_class"], (n_real, n_total)) - else: - data["positions"].requires_grad_(True) - positions = data["positions"] - cell = data["cell"] - num_atoms_arange = torch.arange(positions.shape[0], device=positions.device) - num_graphs = int(data["ptr"].numel() - 1) - displacement = torch.zeros( - (num_graphs, 3, 3), dtype=positions.dtype, device=positions.device - ) - if compute_virials or compute_stress or compute_displacement: - p, s, displacement = get_symmetric_displacement( - positions=positions, - unit_shifts=data["unit_shifts"], - cell=cell, - edge_index=data["edge_index"], - num_graphs=num_graphs, - batch=data["batch"], - ) - data["positions"], data["shifts"] = p, s - vectors, lengths = get_edge_vectors_and_lengths( - positions=data["positions"], - edge_index=data["edge_index"], - shifts=data["shifts"], - ) - ikw = InteractionKwargs(None, (0, 0)) - - return GraphContext( - is_lammps=lammps_mliap, - num_graphs=num_graphs, - num_atoms_arange=num_atoms_arange, - displacement=displacement, - positions=positions, - vectors=vectors, - lengths=lengths, - cell=cell, - node_heads=node_heads, - interaction_kwargs=ikw, - ) +########################################################################################### +# Utilities +# Authors: Ilyes Batatia, Gregor Simm and David Kovacs +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +import logging +from typing import Dict, List, NamedTuple, Optional, Tuple + +import numpy as np +import torch +import torch.utils.data +from scipy.constants import c, e + +from mace.tools import to_numpy +from mace.tools.scatter import scatter_mean, scatter_std, scatter_sum +from mace.tools.torch_geometric.batch import Batch + +from .blocks import AtomicEnergiesBlock + + +def compute_forces( + energy: torch.Tensor, positions: torch.Tensor, training: bool = True +) -> torch.Tensor: + grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(energy)] + gradient = torch.autograd.grad( + outputs=[energy], # [n_graphs, ] + inputs=[positions], # [n_nodes, 3] + grad_outputs=grad_outputs, + retain_graph=training, # Make sure the graph is not destroyed during training + create_graph=training, # Create graph for second derivative + allow_unused=True, # For complete dissociation turn to true + )[ + 0 + ] # [n_nodes, 3] + if gradient is None: + return torch.zeros_like(positions) + return -1 * gradient + + +def compute_forces_virials( + energy: torch.Tensor, + positions: torch.Tensor, + displacement: torch.Tensor, + cell: torch.Tensor, + training: bool = True, + compute_stress: bool = False, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(energy)] + forces, virials = torch.autograd.grad( + outputs=[energy], # [n_graphs, ] + inputs=[positions, displacement], # [n_nodes, 3] + grad_outputs=grad_outputs, + retain_graph=training, # Make sure the graph is not destroyed during training + create_graph=training, # Create graph for second derivative + allow_unused=True, + ) + stress = torch.zeros_like(displacement) + if compute_stress and virials is not None: + cell = cell.view(-1, 3, 3) + volume = torch.linalg.det(cell).abs().unsqueeze(-1) + stress = virials / volume.view(-1, 1, 1) + stress = torch.where(torch.abs(stress) < 1e10, stress, torch.zeros_like(stress)) + if forces is None: + forces = torch.zeros_like(positions) + if virials is None: + virials = torch.zeros((1, 3, 3)) + + return -1 * forces, -1 * virials, stress + + +def get_symmetric_displacement( + positions: torch.Tensor, + unit_shifts: torch.Tensor, + cell: Optional[torch.Tensor], + edge_index: torch.Tensor, + num_graphs: int, + batch: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if cell is None: + cell = torch.zeros( + num_graphs * 3, + 3, + dtype=positions.dtype, + device=positions.device, + ) + sender = edge_index[0] + displacement = torch.zeros( + (num_graphs, 3, 3), + dtype=positions.dtype, + device=positions.device, + ) + displacement.requires_grad_(True) + symmetric_displacement = 0.5 * ( + displacement + displacement.transpose(-1, -2) + ) # From https://github.com/mir-group/nequip + positions = positions + torch.einsum( + "be,bec->bc", positions, symmetric_displacement[batch] + ) + cell = cell.view(-1, 3, 3) + cell = cell + torch.matmul(cell, symmetric_displacement) + shifts = torch.einsum( + "be,bec->bc", + unit_shifts, + cell[batch[sender]], + ) + return positions, shifts, displacement + + +@torch.jit.unused +def compute_hessians_vmap( + forces: torch.Tensor, + positions: torch.Tensor, +) -> torch.Tensor: + forces_flatten = forces.view(-1) + num_elements = forces_flatten.shape[0] + + def get_vjp(v): + return torch.autograd.grad( + -1 * forces_flatten, + positions, + v, + retain_graph=True, + create_graph=False, + allow_unused=False, + ) + + I_N = torch.eye(num_elements).to(forces.device) + try: + chunk_size = 1 if num_elements < 64 else 16 + gradient = torch.vmap(get_vjp, in_dims=0, out_dims=0, chunk_size=chunk_size)( + I_N + )[0] + except RuntimeError: + gradient = compute_hessians_loop(forces, positions) + if gradient is None: + return torch.zeros((positions.shape[0], forces.shape[0], 3, 3)) + return gradient + + +@torch.jit.unused +def compute_hessians_loop( + forces: torch.Tensor, + positions: torch.Tensor, +) -> torch.Tensor: + hessian = [] + for grad_elem in forces.view(-1): + hess_row = torch.autograd.grad( + outputs=[-1 * grad_elem], + inputs=[positions], + grad_outputs=torch.ones_like(grad_elem), + retain_graph=True, + create_graph=False, + allow_unused=False, + )[0] + hess_row = hess_row.detach() # this makes it very slow? but needs less memory + if hess_row is None: + hessian.append(torch.zeros_like(positions)) + else: + hessian.append(hess_row) + hessian = torch.stack(hessian) + return hessian + + +def get_outputs( + energy: torch.Tensor, + positions: torch.Tensor, + cell: torch.Tensor, + displacement: Optional[torch.Tensor], + vectors: Optional[torch.Tensor] = None, + training: bool = False, + compute_force: bool = True, + compute_virials: bool = True, + compute_stress: bool = True, + compute_hessian: bool = False, + compute_edge_forces: bool = False, +) -> Tuple[ + Optional[torch.Tensor], + Optional[torch.Tensor], + Optional[torch.Tensor], + Optional[torch.Tensor], + Optional[torch.Tensor], +]: + if (compute_virials or compute_stress) and displacement is not None: + forces, virials, stress = compute_forces_virials( + energy=energy, + positions=positions, + displacement=displacement, + cell=cell, + compute_stress=compute_stress, + training=(training or compute_hessian or compute_edge_forces), + ) + elif compute_force: + forces, virials, stress = ( + compute_forces( + energy=energy, + positions=positions, + training=(training or compute_hessian or compute_edge_forces), + ), + None, + None, + ) + else: + forces, virials, stress = (None, None, None) + if compute_hessian: + assert forces is not None, "Forces must be computed to get the hessian" + hessian = compute_hessians_vmap(forces, positions) + else: + hessian = None + if compute_edge_forces and vectors is not None: + edge_forces = compute_forces( + energy=energy, + positions=vectors, + training=(training or compute_hessian), + ) + if edge_forces is not None: + edge_forces = -1 * edge_forces # Match LAMMPS sign convention + else: + edge_forces = None + return forces, virials, stress, hessian, edge_forces + + +def get_atomic_virials_stresses( + edge_forces: torch.Tensor, # [n_edges, 3] + edge_index: torch.Tensor, # [2, n_edges] + vectors: torch.Tensor, # [n_edges, 3] + num_atoms: int, + batch: torch.Tensor, + cell: torch.Tensor, # [n_graphs, 3, 3] +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Compute atomic virials and optionally atomic stresses from edge forces and vectors. + From pobo95 PR #528. + Returns: + Tuple of: + - Atomic virials [num_atoms, 3, 3] + - Atomic stresses [num_atoms, 3, 3] (None if not computed) + """ + edge_virial = torch.einsum("zi,zj->zij", edge_forces, vectors) + atom_virial_sender = scatter_sum( + src=edge_virial, index=edge_index[0], dim=0, dim_size=num_atoms + ) + atom_virial_receiver = scatter_sum( + src=edge_virial, index=edge_index[1], dim=0, dim_size=num_atoms + ) + atom_virial = (atom_virial_sender + atom_virial_receiver) / 2 + atom_virial = (atom_virial + atom_virial.transpose(-1, -2)) / 2 + atom_stress = None + cell = cell.view(-1, 3, 3) + volume = torch.linalg.det(cell).abs().unsqueeze(-1) + atom_volume = volume[batch].view(-1, 1, 1) + atom_stress = atom_virial / atom_volume + atom_stress = torch.where( + torch.abs(atom_stress) < 1e10, atom_stress, torch.zeros_like(atom_stress) + ) + return -1 * atom_virial, atom_stress + + +def get_edge_vectors_and_lengths( + positions: torch.Tensor, # [n_nodes, 3] + edge_index: torch.Tensor, # [2, n_edges] + shifts: torch.Tensor, # [n_edges, 3] + normalize: bool = False, + eps: float = 1e-9, +) -> Tuple[torch.Tensor, torch.Tensor]: + sender = edge_index[0] + receiver = edge_index[1] + vectors = positions[receiver] - positions[sender] + shifts # [n_edges, 3] + lengths = torch.linalg.norm(vectors, dim=-1, keepdim=True) # [n_edges, 1] + if normalize: + vectors_normed = vectors / (lengths + eps) + return vectors_normed, lengths + + return vectors, lengths + + +def _check_non_zero(std): + if np.any(std == 0): + logging.warning( + "Standard deviation of the scaling is zero, Changing to no scaling" + ) + std[std == 0] = 1 + return std + + +def extract_invariant(x: torch.Tensor, num_layers: int, num_features: int, l_max: int): + out = [] + out.append(x[:, :num_features]) + for i in range(1, num_layers): + out.append( + x[ + :, + i + * (l_max + 1) ** 2 + * num_features : (i * (l_max + 1) ** 2 + 1) + * num_features, + ] + ) + return torch.cat(out, dim=-1) + + +def compute_mean_std_atomic_inter_energy( + data_loader: torch.utils.data.DataLoader, + atomic_energies: np.ndarray, +) -> Tuple[float, float]: + atomic_energies_fn = AtomicEnergiesBlock(atomic_energies=atomic_energies) + + avg_atom_inter_es_list = [] + head_list = [] + + for batch in data_loader: + node_e0 = atomic_energies_fn(batch.node_attrs) + graph_e0s = scatter_sum( + src=node_e0, index=batch.batch, dim=0, dim_size=batch.num_graphs + )[torch.arange(batch.num_graphs), batch.head] + graph_sizes = batch.ptr[1:] - batch.ptr[:-1] + avg_atom_inter_es_list.append( + (batch.energy - graph_e0s) / graph_sizes + ) # {[n_graphs], } + head_list.append(batch.head) + + avg_atom_inter_es = torch.cat(avg_atom_inter_es_list) # [total_n_graphs] + head = torch.cat(head_list, dim=0) # [total_n_graphs] + # mean = to_numpy(torch.mean(avg_atom_inter_es)).item() + # std = to_numpy(torch.std(avg_atom_inter_es)).item() + mean = to_numpy(scatter_mean(src=avg_atom_inter_es, index=head, dim=0).squeeze(-1)) + std = to_numpy(scatter_std(src=avg_atom_inter_es, index=head, dim=0).squeeze(-1)) + std = _check_non_zero(std) + + return mean, std + + +def _compute_mean_std_atomic_inter_energy( + batch: Batch, + atomic_energies_fn: AtomicEnergiesBlock, +) -> Tuple[torch.Tensor, torch.Tensor]: + head = batch.head + node_e0 = atomic_energies_fn(batch.node_attrs) + graph_e0s = scatter_sum( + src=node_e0, index=batch.batch, dim=0, dim_size=batch.num_graphs + )[torch.arange(batch.num_graphs), head] + graph_sizes = batch.ptr[1:] - batch.ptr[:-1] + atom_energies = (batch.energy - graph_e0s) / graph_sizes + return atom_energies + + +def compute_mean_rms_energy_forces( + data_loader: torch.utils.data.DataLoader, + atomic_energies: np.ndarray, +) -> Tuple[float, float]: + atomic_energies_fn = AtomicEnergiesBlock(atomic_energies=atomic_energies) + + atom_energy_list = [] + forces_list = [] + head_list = [] + head_batch = [] + + for batch in data_loader: + head = batch.head + node_e0 = atomic_energies_fn(batch.node_attrs) + graph_e0s = scatter_sum( + src=node_e0, index=batch.batch, dim=0, dim_size=batch.num_graphs + )[torch.arange(batch.num_graphs), head] + graph_sizes = batch.ptr[1:] - batch.ptr[:-1] + atom_energy_list.append( + (batch.energy - graph_e0s) / graph_sizes + ) # {[n_graphs], } + forces_list.append(batch.forces) # {[n_graphs*n_atoms,3], } + head_list.append(head) + head_batch.append(head[batch.batch]) + + atom_energies = torch.cat(atom_energy_list, dim=0) # [total_n_graphs] + forces = torch.cat(forces_list, dim=0) # {[total_n_graphs*n_atoms,3], } + head = torch.cat(head_list, dim=0) # [total_n_graphs] + head_batch = torch.cat(head_batch, dim=0) # [total_n_graphs] + + # mean = to_numpy(torch.mean(atom_energies)).item() + # rms = to_numpy(torch.sqrt(torch.mean(torch.square(forces)))).item() + mean = to_numpy(scatter_mean(src=atom_energies, index=head, dim=0).squeeze(-1)) + rms = to_numpy( + torch.sqrt( + scatter_mean(src=torch.square(forces), index=head_batch, dim=0).mean(-1) + ) + ) + rms = _check_non_zero(rms) + + return mean, rms + + +def _compute_mean_rms_energy_forces( + batch: Batch, + atomic_energies_fn: AtomicEnergiesBlock, +) -> Tuple[torch.Tensor, torch.Tensor]: + head = batch.head + node_e0 = atomic_energies_fn(batch.node_attrs) + graph_e0s = scatter_sum( + src=node_e0, index=batch.batch, dim=0, dim_size=batch.num_graphs + )[torch.arange(batch.num_graphs), head] + graph_sizes = batch.ptr[1:] - batch.ptr[:-1] + atom_energies = (batch.energy - graph_e0s) / graph_sizes # {[n_graphs], } + forces = batch.forces # {[n_graphs*n_atoms,3], } + + return atom_energies, forces + + +def compute_avg_num_neighbors(data_loader: torch.utils.data.DataLoader) -> float: + num_neighbors = [] + for batch in data_loader: + _, receivers = batch.edge_index + _, counts = torch.unique(receivers, return_counts=True) + num_neighbors.append(counts) + + avg_num_neighbors = torch.mean( + torch.cat(num_neighbors, dim=0).type(torch.get_default_dtype()) + ) + return to_numpy(avg_num_neighbors).item() + + +def compute_statistics( + data_loader: torch.utils.data.DataLoader, + atomic_energies: np.ndarray, +) -> Tuple[float, float, float, float]: + atomic_energies_fn = AtomicEnergiesBlock(atomic_energies=atomic_energies) + + atom_energy_list = [] + forces_list = [] + num_neighbors = [] + head_list = [] + head_batch = [] + + for batch in data_loader: + head = batch.head + node_e0 = atomic_energies_fn(batch.node_attrs) + graph_e0s = scatter_sum( + src=node_e0, index=batch.batch, dim=0, dim_size=batch.num_graphs + )[torch.arange(batch.num_graphs), head] + graph_sizes = batch.ptr[1:] - batch.ptr[:-1] + atom_energy_list.append( + (batch.energy - graph_e0s) / graph_sizes + ) # {[n_graphs], } + forces_list.append(batch.forces) # {[n_graphs*n_atoms,3], } + head_list.append(head) # {[n_graphs], } + head_batch.append(head[batch.batch]) + _, receivers = batch.edge_index + _, counts = torch.unique(receivers, return_counts=True) + num_neighbors.append(counts) + + atom_energies = torch.cat(atom_energy_list, dim=0) # [total_n_graphs] + forces = torch.cat(forces_list, dim=0) # {[total_n_graphs*n_atoms,3], } + head = torch.cat(head_list, dim=0) # [total_n_graphs] + head_batch = torch.cat(head_batch, dim=0) # [total_n_graphs] + + # mean = to_numpy(torch.mean(atom_energies)).item() + mean = to_numpy(scatter_mean(src=atom_energies, index=head, dim=0).squeeze(-1)) + rms = to_numpy( + torch.sqrt( + scatter_mean(src=torch.square(forces), index=head_batch, dim=0).mean(-1) + ) + ) + + avg_num_neighbors = torch.mean( + torch.cat(num_neighbors, dim=0).type(torch.get_default_dtype()) + ) + + return to_numpy(avg_num_neighbors).item(), mean, rms + + +def compute_rms_dipoles( + data_loader: torch.utils.data.DataLoader, +) -> Tuple[float, float]: + dipoles_list = [] + for batch in data_loader: + dipoles_list.append(batch.dipole) # {[n_graphs,3], } + + dipoles = torch.cat(dipoles_list, dim=0) # {[total_n_graphs,3], } + rms = to_numpy(torch.sqrt(torch.mean(torch.square(dipoles)))).item() + rms = _check_non_zero(rms) + return rms + + +def compute_fixed_charge_dipole( + charges: torch.Tensor, + positions: torch.Tensor, + batch: torch.Tensor, + num_graphs: int, +) -> torch.Tensor: + mu = positions * charges.unsqueeze(-1) / (1e-11 / c / e) # [N_atoms,3] + return scatter_sum( + src=mu, index=batch.unsqueeze(-1), dim=0, dim_size=num_graphs + ) # [N_graphs,3] + + +class InteractionKwargs(NamedTuple): + lammps_class: Optional[torch.Tensor] + lammps_natoms: Tuple[int, int] = (0, 0) + + +class GraphContext(NamedTuple): + is_lammps: bool + num_graphs: int + num_atoms_arange: torch.Tensor + displacement: Optional[torch.Tensor] + positions: torch.Tensor + vectors: torch.Tensor + lengths: torch.Tensor + cell: torch.Tensor + node_heads: torch.Tensor + interaction_kwargs: InteractionKwargs + + +def prepare_graph( + data: Dict[str, torch.Tensor], + compute_virials: bool = False, + compute_stress: bool = False, + compute_displacement: bool = False, + lammps_mliap: bool = False, +) -> GraphContext: + if torch.jit.is_scripting(): + lammps_mliap = False + + node_heads = ( + data["head"][data["batch"]] + if "head" in data + else torch.zeros_like(data["batch"]) + ) + + if lammps_mliap: + n_real, n_total = data["natoms"][0], data["natoms"][1] + num_graphs = 2 + num_atoms_arange = torch.arange(n_real, device=data["node_attrs"].device) + displacement = None + positions = torch.zeros( + (int(n_real), 3), + dtype=data["vectors"].dtype, + device=data["vectors"].device, + ) + cell = torch.zeros( + (num_graphs, 3, 3), + dtype=data["vectors"].dtype, + device=data["vectors"].device, + ) + vectors = data["vectors"].requires_grad_(True) + lengths = torch.linalg.vector_norm(vectors, dim=1, keepdim=True) + ikw = InteractionKwargs(data["lammps_class"], (n_real, n_total)) + else: + data["positions"].requires_grad_(True) + positions = data["positions"] + cell = data["cell"] + num_atoms_arange = torch.arange(positions.shape[0], device=positions.device) + num_graphs = int(data["ptr"].numel() - 1) + displacement = torch.zeros( + (num_graphs, 3, 3), dtype=positions.dtype, device=positions.device + ) + if compute_virials or compute_stress or compute_displacement: + p, s, displacement = get_symmetric_displacement( + positions=positions, + unit_shifts=data["unit_shifts"], + cell=cell, + edge_index=data["edge_index"], + num_graphs=num_graphs, + batch=data["batch"], + ) + data["positions"], data["shifts"] = p, s + vectors, lengths = get_edge_vectors_and_lengths( + positions=data["positions"], + edge_index=data["edge_index"], + shifts=data["shifts"], + ) + ikw = InteractionKwargs(None, (0, 0)) + + return GraphContext( + is_lammps=lammps_mliap, + num_graphs=num_graphs, + num_atoms_arange=num_atoms_arange, + displacement=displacement, + positions=positions, + vectors=vectors, + lengths=lengths, + cell=cell, + node_heads=node_heads, + interaction_kwargs=ikw, + ) diff --git a/mace-bench/3rdparty/mace/mace/modules/wrapper_ops.py b/mace-bench/3rdparty/mace/mace/modules/wrapper_ops.py index ee03ef7..ca05219 100644 --- a/mace-bench/3rdparty/mace/mace/modules/wrapper_ops.py +++ b/mace-bench/3rdparty/mace/mace/modules/wrapper_ops.py @@ -1,192 +1,192 @@ -""" -Wrapper class for o3.Linear that optionally uses cuet.Linear -""" - -import dataclasses -from typing import List, Optional - -import torch -from e3nn import o3 - -from mace.modules.symmetric_contraction import SymmetricContraction -from mace.tools.cg import O3_e3nn - -try: - import cuequivariance as cue - import cuequivariance_torch as cuet - - CUET_AVAILABLE = True -except ImportError: - CUET_AVAILABLE = False - - -@dataclasses.dataclass -class CuEquivarianceConfig: - """Configuration for cuequivariance acceleration""" - - enabled: bool = False - layout: str = "mul_ir" # One of: mul_ir, ir_mul - layout_str: str = "mul_ir" - group: str = "O3" - optimize_all: bool = False # Set to True to enable all optimizations - optimize_linear: bool = False - optimize_channelwise: bool = False - optimize_symmetric: bool = False - optimize_fctp: bool = False - - def __post_init__(self): - if self.enabled and CUET_AVAILABLE: - self.layout_str = self.layout - self.layout = getattr(cue, self.layout) - self.group = ( - O3_e3nn if self.group == "O3_e3nn" else getattr(cue, self.group) - ) - if not CUET_AVAILABLE: - self.enabled = False - - -class Linear: - """Returns either a cuet.Linear or o3.Linear based on config""" - - def __new__( - cls, - irreps_in: o3.Irreps, - irreps_out: o3.Irreps, - shared_weights: bool = True, - internal_weights: bool = True, - cueq_config: Optional[CuEquivarianceConfig] = None, - ): - if ( - CUET_AVAILABLE - and cueq_config is not None - and cueq_config.enabled - and (cueq_config.optimize_all or cueq_config.optimize_linear) - ): - return cuet.Linear( - cue.Irreps(cueq_config.group, irreps_in), - cue.Irreps(cueq_config.group, irreps_out), - layout=cueq_config.layout, - shared_weights=shared_weights, - use_fallback=True, - ) - - return o3.Linear( - irreps_in, - irreps_out, - shared_weights=shared_weights, - internal_weights=internal_weights, - ) - - -class TensorProduct: - """Wrapper around o3.TensorProduct/cuet.ChannelwiseTensorProduct""" - - def __new__( - cls, - irreps_in1: o3.Irreps, - irreps_in2: o3.Irreps, - irreps_out: o3.Irreps, - instructions: Optional[List] = None, - shared_weights: bool = False, - internal_weights: bool = False, - cueq_config: Optional[CuEquivarianceConfig] = None, - ): - if ( - CUET_AVAILABLE - and cueq_config is not None - and cueq_config.enabled - and (cueq_config.optimize_all or cueq_config.optimize_channelwise) - ): - return cuet.ChannelWiseTensorProduct( - cue.Irreps(cueq_config.group, irreps_in1), - cue.Irreps(cueq_config.group, irreps_in2), - cue.Irreps(cueq_config.group, irreps_out), - layout=cueq_config.layout, - shared_weights=shared_weights, - internal_weights=internal_weights, - dtype=torch.get_default_dtype(), - math_dtype=torch.get_default_dtype(), - ) - - return o3.TensorProduct( - irreps_in1, - irreps_in2, - irreps_out, - instructions=instructions, - shared_weights=shared_weights, - internal_weights=internal_weights, - ) - - -class FullyConnectedTensorProduct: - """Wrapper around o3.FullyConnectedTensorProduct/cuet.FullyConnectedTensorProduct""" - - def __new__( - cls, - irreps_in1: o3.Irreps, - irreps_in2: o3.Irreps, - irreps_out: o3.Irreps, - shared_weights: bool = True, - internal_weights: bool = True, - cueq_config: Optional[CuEquivarianceConfig] = None, - ): - if ( - CUET_AVAILABLE - and cueq_config is not None - and cueq_config.enabled - and (cueq_config.optimize_all or cueq_config.optimize_fctp) - ): - return cuet.FullyConnectedTensorProduct( - cue.Irreps(cueq_config.group, irreps_in1), - cue.Irreps(cueq_config.group, irreps_in2), - cue.Irreps(cueq_config.group, irreps_out), - layout=cueq_config.layout, - shared_weights=shared_weights, - internal_weights=internal_weights, - use_fallback=True, - ) - - return o3.FullyConnectedTensorProduct( - irreps_in1, - irreps_in2, - irreps_out, - shared_weights=shared_weights, - internal_weights=internal_weights, - ) - - -class SymmetricContractionWrapper: - """Wrapper around SymmetricContraction/cuet.SymmetricContraction""" - - def __new__( - cls, - irreps_in: o3.Irreps, - irreps_out: o3.Irreps, - correlation: int, - num_elements: Optional[int] = None, - cueq_config: Optional[CuEquivarianceConfig] = None, - ): - if ( - CUET_AVAILABLE - and cueq_config is not None - and cueq_config.enabled - and (cueq_config.optimize_all or cueq_config.optimize_symmetric) - ): - return cuet.SymmetricContraction( - cue.Irreps(cueq_config.group, irreps_in), - cue.Irreps(cueq_config.group, irreps_out), - layout_in=cue.ir_mul, - layout_out=cueq_config.layout, - contraction_degree=correlation, - num_elements=num_elements, - original_mace=True, - dtype=torch.get_default_dtype(), - math_dtype=torch.get_default_dtype(), - ) - - return SymmetricContraction( - irreps_in=irreps_in, - irreps_out=irreps_out, - correlation=correlation, - num_elements=num_elements, - ) +""" +Wrapper class for o3.Linear that optionally uses cuet.Linear +""" + +import dataclasses +from typing import List, Optional + +import torch +from e3nn import o3 + +from mace.modules.symmetric_contraction import SymmetricContraction +from mace.tools.cg import O3_e3nn + +try: + import cuequivariance as cue + import cuequivariance_torch as cuet + + CUET_AVAILABLE = True +except ImportError: + CUET_AVAILABLE = False + + +@dataclasses.dataclass +class CuEquivarianceConfig: + """Configuration for cuequivariance acceleration""" + + enabled: bool = False + layout: str = "mul_ir" # One of: mul_ir, ir_mul + layout_str: str = "mul_ir" + group: str = "O3" + optimize_all: bool = False # Set to True to enable all optimizations + optimize_linear: bool = False + optimize_channelwise: bool = False + optimize_symmetric: bool = False + optimize_fctp: bool = False + + def __post_init__(self): + if self.enabled and CUET_AVAILABLE: + self.layout_str = self.layout + self.layout = getattr(cue, self.layout) + self.group = ( + O3_e3nn if self.group == "O3_e3nn" else getattr(cue, self.group) + ) + if not CUET_AVAILABLE: + self.enabled = False + + +class Linear: + """Returns either a cuet.Linear or o3.Linear based on config""" + + def __new__( + cls, + irreps_in: o3.Irreps, + irreps_out: o3.Irreps, + shared_weights: bool = True, + internal_weights: bool = True, + cueq_config: Optional[CuEquivarianceConfig] = None, + ): + if ( + CUET_AVAILABLE + and cueq_config is not None + and cueq_config.enabled + and (cueq_config.optimize_all or cueq_config.optimize_linear) + ): + return cuet.Linear( + cue.Irreps(cueq_config.group, irreps_in), + cue.Irreps(cueq_config.group, irreps_out), + layout=cueq_config.layout, + shared_weights=shared_weights, + use_fallback=True, + ) + + return o3.Linear( + irreps_in, + irreps_out, + shared_weights=shared_weights, + internal_weights=internal_weights, + ) + + +class TensorProduct: + """Wrapper around o3.TensorProduct/cuet.ChannelwiseTensorProduct""" + + def __new__( + cls, + irreps_in1: o3.Irreps, + irreps_in2: o3.Irreps, + irreps_out: o3.Irreps, + instructions: Optional[List] = None, + shared_weights: bool = False, + internal_weights: bool = False, + cueq_config: Optional[CuEquivarianceConfig] = None, + ): + if ( + CUET_AVAILABLE + and cueq_config is not None + and cueq_config.enabled + and (cueq_config.optimize_all or cueq_config.optimize_channelwise) + ): + return cuet.ChannelWiseTensorProduct( + cue.Irreps(cueq_config.group, irreps_in1), + cue.Irreps(cueq_config.group, irreps_in2), + cue.Irreps(cueq_config.group, irreps_out), + layout=cueq_config.layout, + shared_weights=shared_weights, + internal_weights=internal_weights, + dtype=torch.get_default_dtype(), + math_dtype=torch.get_default_dtype(), + ) + + return o3.TensorProduct( + irreps_in1, + irreps_in2, + irreps_out, + instructions=instructions, + shared_weights=shared_weights, + internal_weights=internal_weights, + ) + + +class FullyConnectedTensorProduct: + """Wrapper around o3.FullyConnectedTensorProduct/cuet.FullyConnectedTensorProduct""" + + def __new__( + cls, + irreps_in1: o3.Irreps, + irreps_in2: o3.Irreps, + irreps_out: o3.Irreps, + shared_weights: bool = True, + internal_weights: bool = True, + cueq_config: Optional[CuEquivarianceConfig] = None, + ): + if ( + CUET_AVAILABLE + and cueq_config is not None + and cueq_config.enabled + and (cueq_config.optimize_all or cueq_config.optimize_fctp) + ): + return cuet.FullyConnectedTensorProduct( + cue.Irreps(cueq_config.group, irreps_in1), + cue.Irreps(cueq_config.group, irreps_in2), + cue.Irreps(cueq_config.group, irreps_out), + layout=cueq_config.layout, + shared_weights=shared_weights, + internal_weights=internal_weights, + use_fallback=True, + ) + + return o3.FullyConnectedTensorProduct( + irreps_in1, + irreps_in2, + irreps_out, + shared_weights=shared_weights, + internal_weights=internal_weights, + ) + + +class SymmetricContractionWrapper: + """Wrapper around SymmetricContraction/cuet.SymmetricContraction""" + + def __new__( + cls, + irreps_in: o3.Irreps, + irreps_out: o3.Irreps, + correlation: int, + num_elements: Optional[int] = None, + cueq_config: Optional[CuEquivarianceConfig] = None, + ): + if ( + CUET_AVAILABLE + and cueq_config is not None + and cueq_config.enabled + and (cueq_config.optimize_all or cueq_config.optimize_symmetric) + ): + return cuet.SymmetricContraction( + cue.Irreps(cueq_config.group, irreps_in), + cue.Irreps(cueq_config.group, irreps_out), + layout_in=cue.ir_mul, + layout_out=cueq_config.layout, + contraction_degree=correlation, + num_elements=num_elements, + original_mace=True, + dtype=torch.get_default_dtype(), + math_dtype=torch.get_default_dtype(), + ) + + return SymmetricContraction( + irreps_in=irreps_in, + irreps_out=irreps_out, + correlation=correlation, + num_elements=num_elements, + ) diff --git a/mace-bench/3rdparty/mace/mace/tools/__init__.py b/mace-bench/3rdparty/mace/mace/tools/__init__.py index 5dda7f3..0fa6b07 100644 --- a/mace-bench/3rdparty/mace/mace/tools/__init__.py +++ b/mace-bench/3rdparty/mace/mace/tools/__init__.py @@ -1,73 +1,73 @@ -from .arg_parser import build_default_arg_parser, build_preprocess_arg_parser -from .arg_parser_tools import check_args -from .cg import U_matrix_real -from .checkpoint import CheckpointHandler, CheckpointIO, CheckpointState -from .default_keys import DefaultKeys -from .finetuning_utils import load_foundations, load_foundations_elements -from .torch_tools import ( - TensorDict, - cartesian_to_spherical, - count_parameters, - init_device, - init_wandb, - set_default_dtype, - set_seeds, - spherical_to_cartesian, - to_numpy, - to_one_hot, - voigt_to_matrix, -) -from .train import SWAContainer, evaluate, train -from .utils import ( - AtomicNumberTable, - MetricsLogger, - atomic_numbers_to_indices, - compute_c, - compute_mae, - compute_q95, - compute_rel_mae, - compute_rel_rmse, - compute_rmse, - get_atomic_number_table_from_zs, - get_tag, - setup_logger, -) - -__all__ = [ - "TensorDict", - "AtomicNumberTable", - "atomic_numbers_to_indices", - "to_numpy", - "to_one_hot", - "build_default_arg_parser", - "check_args", - "DefaultKeys", - "set_seeds", - "init_device", - "setup_logger", - "get_tag", - "count_parameters", - "MetricsLogger", - "get_atomic_number_table_from_zs", - "train", - "evaluate", - "SWAContainer", - "CheckpointHandler", - "CheckpointIO", - "CheckpointState", - "set_default_dtype", - "compute_mae", - "compute_rel_mae", - "compute_rmse", - "compute_rel_rmse", - "compute_q95", - "compute_c", - "U_matrix_real", - "spherical_to_cartesian", - "cartesian_to_spherical", - "voigt_to_matrix", - "init_wandb", - "load_foundations", - "load_foundations_elements", - "build_preprocess_arg_parser", -] +from .arg_parser import build_default_arg_parser, build_preprocess_arg_parser +from .arg_parser_tools import check_args +from .cg import U_matrix_real +from .checkpoint import CheckpointHandler, CheckpointIO, CheckpointState +from .default_keys import DefaultKeys +from .finetuning_utils import load_foundations, load_foundations_elements +from .torch_tools import ( + TensorDict, + cartesian_to_spherical, + count_parameters, + init_device, + init_wandb, + set_default_dtype, + set_seeds, + spherical_to_cartesian, + to_numpy, + to_one_hot, + voigt_to_matrix, +) +from .train import SWAContainer, evaluate, train +from .utils import ( + AtomicNumberTable, + MetricsLogger, + atomic_numbers_to_indices, + compute_c, + compute_mae, + compute_q95, + compute_rel_mae, + compute_rel_rmse, + compute_rmse, + get_atomic_number_table_from_zs, + get_tag, + setup_logger, +) + +__all__ = [ + "TensorDict", + "AtomicNumberTable", + "atomic_numbers_to_indices", + "to_numpy", + "to_one_hot", + "build_default_arg_parser", + "check_args", + "DefaultKeys", + "set_seeds", + "init_device", + "setup_logger", + "get_tag", + "count_parameters", + "MetricsLogger", + "get_atomic_number_table_from_zs", + "train", + "evaluate", + "SWAContainer", + "CheckpointHandler", + "CheckpointIO", + "CheckpointState", + "set_default_dtype", + "compute_mae", + "compute_rel_mae", + "compute_rmse", + "compute_rel_rmse", + "compute_q95", + "compute_c", + "U_matrix_real", + "spherical_to_cartesian", + "cartesian_to_spherical", + "voigt_to_matrix", + "init_wandb", + "load_foundations", + "load_foundations_elements", + "build_preprocess_arg_parser", +] diff --git a/mace-bench/3rdparty/mace/mace/tools/__pycache__/__init__.cpython-310.pyc b/mace-bench/3rdparty/mace/mace/tools/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 6c34fb5d37ef0a4f1790cd0a65bdb7075a3659c9..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 1507 zcmZ9MOOG5i5XZY`AM@(jnc4U3`@HXR?P}HOv?Blf{>Qt?pNmoS7a1QuHE(?F zU-P#|7G<%FFoqawP>ac~$vK>d`7oZ73%Cf2VLUJE*nkE$p@}VM1>b^fV**JSFUk%s z!4h_%8|=F5VITUq49mCzE4T`)xCU#u4(qr98@LIZIDi3e!4_`AcKD_tcW@VWaS!%z z2X=z5DfjUJ4)72T@d%Fa7>@A-PVf{?@gCg6`*0s0zyln@5Ff%rJcF}18b?{{w-_E} z?aOGG{1f|9hc&l3Ds!2!OkD9&TE_L5O}REgyX|Ubs>M{RR2cKY5Xgo zktzh6DC@e{Dj!=~S&8KJNiJVKy-p!&$^YlfDBcWgN zkrb};TzH97^Sv66{YL{HR?vQaiJ|a$k@-4|Yo-dPr4=l7?b{)8F*APt;xpI1ozX)2 z#005ReMJt9M!4k8L)QM-*Z;pwnWeVK3TB)1ppjlbfd zWawS_8(_;KFUG92xsP_2hihA<)5-fyf`%dGph!?D7j@Kf!=hGWIQ3boW{IT+nM3GL;Khc@*M0*o_6@s4U%wJ^RmF$V{qsEu0vQmn# a=!a#fxL%aRNxbUcUi@ji+W95XbrCMcb1ld69QpUgTI=>IOj%0g3`aa&V|D2$Z5nf_|cW#aUDowFjCk-8m65+)tTY5Fzf8;PM#N_z>84iB`CSrOsC9! z@SQ!|sqi_N<5j3SKi{eGI@I|*%<~0U;ES-xmtcuEpuv}6nXkYKUxih^25WpB*4>&y zXM=CTCf|ZBz5yF9R_tu^CN%jD?C@RK<$JKl_hFy6pv4d1fZu^T{16U#8`}IX+~xP+ zUXvZMZt0B&_xt6x|EGtowsUsnBnzXS=&4hgMMg;Ii=hOqz^;!*hMl#JR_w*HichiRCYW0~|KOz2NS&tKc>#OoSq6ecM@8&%O! z=Ne0^MXBtG(=KcjBNxGjMAY@gh_=S4itoe zjBG8yY=%0Nu`&v@Er&^Hu>G?zP`2cp@3Gz!Ti1%>BUyUpe5h=ZLRzVwwvF4Iw88Ce zwt!xe#lv%3LMu&_7^KEl&(g4OXy!;H)W`3B_vC4s7#SugO+lT>C_@$Oi~$U_eOrHG z(l`uWW$}rExAG)Xw(?S;#DRX1_WQVKRk|6pA5PJfK1_PphqlEajfa_0BCzGL8A~;B zK0N-)R>vctA}maiI-^w-fLJT*-$uyW?qmP|+f*1TOq@a*i;vpoX@Ze`>!VWHP(-fK z_C8ebvcWoGldwc+5T*z-1bUsYK&TN82rWXDutC@-tPxfTD}+3uKqwOC3EPAn!X9Cl z&?IaT@bNQPB+L<(33Wo5;1i}1ui6i6Y4RLKmC}H%O`@oI8fn*n%>{jz`{rf1tN&Sf z)*oM}_jtu!%eHokH^5{`nDj+vLdUKRl_zC3g8Vk0>Hjs#v@e1+l*d^d%!Kk~3 z$>}ep-yr^)@;vX7y}V$D7i{yA9e-lS7p!&3o?qvDFL%wx_}f`)%zOBWeR-X$(AXGP QWc{<%D(8O8!$ZvaFVx?>$p8QV diff --git a/mace-bench/3rdparty/mace/mace/tools/__pycache__/arg_parser.cpython-310.pyc b/mace-bench/3rdparty/mace/mace/tools/__pycache__/arg_parser.cpython-310.pyc deleted file mode 100644 index f9591a2a4a813f6ad62488e98974052c57245440..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 15563 zcmcIr3w#_$dB2xVZ&|YCSNz&I4we$1U$LFUacoPrw6}a?P2lR*zme#v=^Hijt3nVs-9_AbZa>d2msw@je=gYyoZe?H%L2%S92

+u8)Vjcsd!Y-c;XbTP(uzBkTxG3`3Wf%9-X+YQV+ z*q&m7?NQ)&vKO(vY={llxxI_s4a|KCZyC1V>p~wJJvzV+qK}6Z$YFL5Z}$k=J*v<@ z#_nbJfuH*UKVEO+0d~UcW)CX$PqI^J^)xU?fH_*{`wTmamc|s!adr-nhZM+p_Tr{K zJj^E8Bxp@VTD`!g(Z(Zy&qUgot;6Rc@JHEWXz8NT(o5K-CPv@0 zMM!|X@6 zG#Yd5Ci_wLW9)sHx9?-`ulN2r_Tw1upTKzkqyqgZ_S3x8M(v+rA83O7Ec>~p7Ji<6 zunF=D>=(VX*K@h|a-YK6FR@?d&~*{$udrX`(EbSY*VwOf=s*Pe8|*hZbbSQ+A@*Aw zx*-DnF#Bx|9gIMKhy5=4^Lq+czt29xJ_=d<0c6pzFn`GY2(tc|g7`&RB zT~AB9_~+~|n#SZWS-}1Zu)k)1qu_ml{cRK8-?2}!zX$A7>>m`oe`NpEg!gIo&up3F z{fmP4uk15Tc%Nne#y-bkpI7j{!2TWm{ST!d|H=LfSpTg+zR3OuIR6V7tVHJjm)I>} zep$hMp1t68zJ*rozEG2pi{nYG0#5>-inawv8=f?tcCQ<02cFIbzsPmr>E>lUkyd&E z?}MFAHP#wH*5SeGh*Vi0kz0>v1E(?=p|X*~sSleX@Xdg1;rLr4@Y?|2hG#pi5cu7J zXJoTv}yvf7tavqU^v3TU4Y-srK_x|fMxLP zUj=>u@PpAg3&q(2*A0Q9#c`?Ft}g@rN>W*tTxxA&jZm_OkJ7wMT zLcmH#CL@sNRy6UVQ7z50HN9ALjH2!tn$0~{_hvFkc}rCz+h@8trHLR_GOX&&xYLg2 z{d6W%(aVOv&bV4N9J6dxJdFUu38x3;L5E|!qGmb^a``f8fBIrr!FKp zs?$yQ?M%4GD_xn)qU~JCG1FPu>l&Vh7bEZ4&XR^+X}V=`)?E$M0I`_ixIuis-Gz_02i$QX&e-^7r>Wy&tds;?1+y} zG1K)NbH3(b1n=}5-K=Q%Uq(mG%(>}DHMokp2Ia%J7(#qOmy6Lda-LH&#^o3XX+2Nd zA;AoWfriR577WYpq~YPOLF)X}_yy9es%3lrmT9D#l$>T4G~T*dDf;~aBUd0DGb;Hd z|KcnZg=%>y*Xr=&+TMeEx!TwjYe~~R%_|w20mUIrqMmq0(QvdWng~=~14~e}Db1+b z`I5Uoh#w>&7>;8*IZvOr4FBFah>zIeK#c>ns$;WS-e8d?z%g12xgPEex~C0i#IwsY zle1&l&Nn@m%satM#vn-h$8H*NlH43sFoDxv7U1!azZ5e{TcqefG zSjBLPON~NjH|a#XoDF^#*LscgB1lb+oEaNVXVaq>=B5nKxdWKlzS+EP8M7s`;8Ah5 zzmXSc*#YX8nHNpC9BDvpm4k-mgeQ%3ZbPyH7MxsJzq+!8^h~eh4J}{u>_R~sf|iX` zio>4<8vX`P(>ORfgIIE0fp=0$;~;?AyzZK=R;X3DqG5|Y4>=kZ+T4r!K}S)qxvr^K zf=@wiaz8-)w>&2n|S1! z>6UHVE6HvzQ{aN@w{AK)$EemU7t_YStL7S*Wu}AqAYmG;ds$mFQ8zm}p;c`UCK%H3 z2gRf%VPX#I6_*wR{}55A3%L&GEUY2H)RbK@AioTF%BDw(6~i^Nsad<=(OA&jam1|l zx02;2*(gA5-m>#of;}^aZlSL2dS-qG6j)8S&Nu3g64r2=v!5E-t0Iqh-dgys*Dccn@?=B~)66Be;9mlYAvXcHTB?D81C5{Y>q327Qr7syy zsAr0}U4K(U&s6v_#g+V?OlERoI%lGds_PGxOvVh@Kb|Gd$Wj4;Y&$#%GP^pnYhJeP z(6Orr4XQ4u1Cad9izb8G$!b!2sG#Si$ByjN#Pf1mb3{8~vpZyaw2n&-Dp-y+%zV#q z{XQ}8@+G}eF)Y{L-sq5l$;TMk_IfE>HPFTs2hgU>C`cbTnRrH}&av zy_5bvc&c@Q;5@>Afxdcz#jSHR!|Tjd!DQ*EU{1Kg+`skV*77>=3;iS0I;Ib#-f$AR z9VZ0gPI6Zo-h>M)kB067Nq-X*p4_e+oDFCe*v?U*e|*%&nh$RV9l&5xR~AbK6hvqV zC(6CFMUdkZ>IBH{kj9b24L91~Jyl=jB#4vawaI72h>|I&nw4DHD8nW9w@s5b$lqH~ z!@;c&&I-mw=Y9g|IhIp*Wzz)KeHs%1m&>e6Gq#LA`Fkd*X9o9Qv?{lQ*h0|sR>sh8 z!_@Sce|#2Gmgx@TzOv97pmG)E8jkf0{yZ5unm-W|aQVp$fg2cQF|r3PyD%4gpM-6Y z7ix4JtC@`zurYSfUD#1b`9*3#is^ow}wec zJ1N({eVhX{3D8*0q3MK8ijnbZG_}ZgD#1#-D~G8)YKdo$(*`eO)~qmdfiJmIuV~W* z9{tTynRpenUcq?;QS}dhWo;6kA|^3cB&v;l0c59ig@FFvDd`cApb0|iPA&+kr_u|u zG>c^Kyj<1G@v4U}@;sL=Dpzq-LSqd9ayoMGEZc2(dqr3d{C)lukt2n6^>WpM^{smT zP(!CBG#vjX%4RY9l!?iejt{xom>fx2R~l-p5&qqeMXf!0Bot<~+F0-ur4!rw z`t&*RuwhhrtRMu1Vr|qBd;L{rmcKERxq|w+aOUR7IJy35PBa8^i&eKJWFj$2`$)$N zHH0W2Ye)soH*6NHv%{w>Y^N|~xhuxf$}o-si7m$f61K6lGsQj)#(`^SC`YW8v?_9~ zba}$(w2&7Pj~`rv?mHztH$^abdBajVQcrWIMS@*wYEtn z>yurT1Nq5tor{>-S0KSVC9YgNWKlC)G@kAivzjw#g@h^QcE#7JR7+71`0bJerbV_rG+0=x$Qt=Qp;t(`phK87gAI2D38Ub=6w@W!Ek}{U z5kb3@evlU84AMeAL0hSY#WCoFMMgNn)vfHHR2x}c;AQ)*L)KwYA_{0qNfgn#zft^> znu!99n3MMM(rI!#-9c2U=PUT{-T>k;@RI7`Zp-lI(nImqY>YQrMFyZlovCBBdt^ znXkDgLX@Z9d|GrV+6Jt0IR7N`n;P;9=N{J`Y}!d8P1v`o1JP7=ZgomuX9r??PG{Gz z!U3c&4MD5)r4iu#OFfF>Mx-AtQlzht`I?y*LXVOL%hCw7$~-hcD$h~m2+z@?(T2oB zI$xjON64{(IU%Evj?#k&|#9gsWW9G7kKdJb3v?xkszTEKefS|-fc`NH4-!1Q) z>^^>I|G_87+vk=ghY+}@`#bp$58lEHGaxX%!zbf%Z#rW9x*w^3#>+v6hg#%dOfp~3g$^9H!kMv*o#L0zl2 zSp4Wm=f)l@sziafqjN39i_$HhQ1!<$=QZ(@W6il|`-!>r^X2cNb z{G-rll;ZC+9i$&uw}!}{TDumlJYUuJ6>r`A!pA4)r%%D|*IN(+iGj)MOa4}=XgsG4 zO&rF~2Zs~h(s1$mXFqV`iyyvmlEO@S*)1aY;&Z=RX0QpQ5D4E6)Fns*7|P!92j*np zvq3Wy|Q zeW^$$V(>W(6xKS5$IyFzDF(hvtvCn_gpXDP(XHzAFaM%N!C2;bJI zPy~iteF1Sl%gU+nL4vkmW1JA;<4Q?W9Uj(@%r|reCuj&&h+?Q>coIrO##_Pxq_53# z>XLmZ+&m^4Vnp!p5H%qTXj0yS+s%#C=5rJ$rc>fGfFXiZ)L@!x1{pHuEajY9#m~&t zDx6=}LzG&pRY8?+E#*Yup)BmO$j5Z|({cv0h%>odzGl3_-#x}NI7x|wg7thJdqAQ|RKaJ@{9%v9J$)hGW7o;7^i8fl#UqNfs7haK{1y7HF&3S}%5N*ZEC1QEH z8qGuSZB_DEDX|t}y<0SKA^?eT9bu#%f;>WVm$5(hcdf#kEaS1kAV%>8QHRcbyd{6< zDpf=guR(!3flr4u{_fSVqeuCqf}#eU)|RkpAVhd@4wH(%*H?bNRiaZ|tbVOjhr5Fw zG74efKHElSVj}2P&NzaEFAqoBJW zGQ2tH6aEl(jNsNag1!-_SSv#+)BL(bFnDQXa^eiXMuBjNA)y_0{2d{_bfnyeVE)I$ z%M~~rU@S<&{S3O#m#enpjnP4I&;xJ|GSmjLin^kDZ@6KQb{Jq-W=5#%A$*_WbmPiLoH{(Adb?p#984BQxiK zw}DU194DB=0W}ZZq-Ym=GvO@q26(rfn=$9j$gW&EP_oO$0jvTg-M(z{`C5IDjAlfZ>7>5@$?I+q)69*w+VO} z^wM8SwXc>#$*uM2R6?O3@u)dJW3; zsQqhdORX7wS9xzxS8<|zMsq6O+V358J{25Q!c?xv`^zag4pDe>E+UWr7v*LOYOAuW zlYYg&21l}NRar%8ti@9Z=b9yXRo{6FN*_Y}Rp3&Zsy=+9-qDv^DV6>}RF60}E2)le z7L@3((&L6+Dq2Fk4q~?V@mbx&Q=0oJP~dg}e2Ovw<9$mDh?<5qNaFNP^@&0U>Z|#* z%!mAim|;DtL+@&v0@930Ef8jd=e21H%w`I!;I}I{obOd|c>h8^1-{Cm`i>F4m0It- z74|ldr~ViDu4ELsZ{Bw${6-{wTM~Z7LEo0p{rO%6*2mU)o%q*vqx;zy>)#f88h2^K z+r5mwH0fc3ks2G>rWPujaeFq+wkkEnww00>F4fkf#xHH>m4_}zz);8{q zUtXs`@s$aO_D7(&k;|b25$G_xi$m8(pm(!<9J(O_r7u&cwZRCKzD{YD(jkT~R3L>T zb;%rM$8c}2!#yEGI zBHYpUD-rJSozp|&5~`DimN|6uY0l}SSBk?py_7tF#2=gsgZFekE|0~X(^P(f$~Q*J z=~@6C`HY9h((;_zpCg|W+g1mAPC9mD%8pXv&`zHpFz%y6^F_LELx<<62j5d(+#=ep zMjL+w`}Z%C2j&b=@*pKADLF;S2ogj|ur=bx`XO7+1}Yq+WFuiM%+SN+v0pqT~W4)08|y$qXg4l+01` zC?$_ka*2{GC8Ly_K|&Fi@ahJ3dGbKp-;J{lx*S3=dI{sal7!^hm9v9r?4+bbgbp|2 zWh7fkRKu>1;slCZ<$aZ!<1Y78?RzP?j}p4&;vA>sE=tZ)GKPdUeQv2%DB#d<`>c53 zgH_+{g*>ez>|-<7v&oyooNKpgnc}Kvy#hsb;6*E!MO=q-bCDW-2_@r{oI^6anS001 zRRSzg;!{F*lAUXmJVD7TDS48TS0M=!HbO4g*5XQZt>QUPQ2|$>r|In(N?t?BYbkjh zC9kLCjg-8ak~dItof7iJMe-KB1>1z{3{8^COnV3ETK6U4Jx}4n_gD9x8$LBwpt>i_ z*OYn=E1lx8tc@quFsd7_Nm2iyLFF6g$3(elw7_|3_;G9Yt(9lV{VIKS-cCb6gO;bu z^x+@9lQpA>1_29H8atiAh~3? zxHg9?WkK8v;!8pN&>X;tij7?_g~5V&A&6I;+tJdvzd3r-grY8aGR&yaWx8esmI`6y zASHIqmc57z*LAKC02~fD@bE=^5CoS1NCFf=@S7qVxjC=?2M@b{TxuP6Ta1&`;O^kDu~ z8#j-B1UH}ZFc0%y_Q?9LSN7t$>T=anKG`RDn2-4{`=6?ot8ribyl0}pL}L9Os$w-= zo)`RbtqVjghX}C1^L2poJs%_@4*6Lf&#z~}E|1(0_YT%SB*f!`UM+AnmrrIhV(Ot6 zImoC#C(Ve3RQ|Fwr)V{@lrPAcC_$Tlb-3~2dZe~fzk1Lw=DF<2c&>XUD1XHta{Vi^ z@Hkat=1m05JL{KK9jnTCuX`uztvvIN`KC(TVSZN4YL58@4`7>22}J-!ZrBdixo+4lwq@P0-E6A` zJM3XSPkY%mCLDwGI&gVA>$PgIKGr{6#rj=q>|i_DE;hghUB+WKd(4J)k1LOd*xsqu z(m4BFVTaj1n@sF?g+0IyS~P$chro-&u01=#j^bYGcF@?mYuWpanu$2d3M1Dd(jp45_@XRal6dM*cH$?z6J-bvTL?pJ&hU@ zYxIh(sxi4njqB_gTi>2_?c00UbL)nEFPpN#KJSY6ee8Q|{fxN6{sSi3V5ePSqbz2F zb=N@>o3YXB4jX5Qb;Bmv3pN^VU<^}hj7XYgY&Ei|kz1q2i%ho76vdS%dBCn$4;0u< z8|;89-dT3b2J5b0=9p^h;Q)KdMGJf{dzpF4p#P!EtbRZHfQ=L12mSd$7wGS2AF^To z0T(OzgY3gLNY5I0Kf*q0sBBDH}ihYw-FXvrlgT_fNB*DZ}mN8-0eowgJqa zWj|}<-_N;@`OmXou*t_Sy2Abvd)-E{`;5KKewqCW`y4!x&#_;1k)zMEUxN<%0(8(n zafSPJ_D{<=;Kr>lvM;S0_MfqTzV1H#3--pkVgDuj4U2}U4%5?^=rZrD--&MP-n?@i zhdb@SUAnQ4dcGY-n4MD-*c6z-)CQ8U&T!P1I$D>zxfZ@ zyiK<(xZ-?^eU1IQGU$KL{==~v$92qoR#aqmv9H^pAFw}i#r?A`>*Vw4f>nxzq#W6@9fWP(C)bZoPC@9MH%#xEAF@1 zJ2v@vjFi3TxaBa87oO$rvyT(}b0j7Vg)iq#t;94_!0+?O62Fm(6xJF-hqsC(v)p&A^x}bar z*WL})7)FhKw%NJX?AZ_a1Go;ZDRGAYb=abXxvmredj!|fb@2Ofz#Ur$+;PA?u@1Ns zfO~Qsa3=wGYHGJ>ec3nH7`@YgA6W<9Gk`n0UOy7B=Wva#1MhjjUD&|LUIhFl3m#8T ztf zd+p_w;^DQ&yczE{&xQK&!i?uh-!{*TcVGem`*E4D-LBF)qS?$$G9H;prld$tlod&i zL{p-oXq#rTa$3yi&5Y&1(2x#1xh&{60@V`oSwW1)WhpM^B_Ydq1(e%9&7YskNm^r4 ziP)OZ{0S+QyX}=}`$Mf68p?=iN!=vf%1LrEEoJfoWeiryVU%kDS$eUMl%?1*AQef9 zsr9m$iDlCQJv@Agy5QrWGBHl8jp zNYNX)Y%-HqICM`rv{h`N4h;>-l7c-H9lWg^+$d5v4GqPzvzb&@jOlPING%nr7_pdi zGZ~fCkRqkP06~8o^lQFouAup&g_x-M(m6%*-O3eYTEPwa?F@}3vZ$`e6i?<2>kN*> zXmHgg&RRW7;8|Mj4Cu-qJYd`$Ii#Xfv80lhlhcJfMsP=77LyqP|7mnIIdtLLbs;9^ zMFGNxaglWPX+Z^}B}MXbK{{`Yu~s8SY0F6`2E#x@m6C2sDYcG=cOXKISl4Rvl z*}U3u4XI#oPRPy(y!B)zu7-7p$PCGtl!?x%&oT%K)ykt>DCHj)cJ156#U`6c%?V;& z$R{L0g5Z!OQBU$xT#|)xng~=~07($EaY4#uqX}iN=G{k3kYqV4NAluyN>Yzbf_+2| z&lGs3kdw2qLR5;CH35pzno0FwomPKMl1K8{^u!f5s?|+q^J0qMcG&X4ZAnpXCgr4< zQnYrX ziIgcUunZF9NLswL+(B|CW}=c1E#$K^Gr|C*Y$OvO{1(9Q*Ls@9K^9|RON0w>52Z8? zdZsWfDoI6{DP*{yVHV&nXG%=55=2;gZRTQk$XL z3h9W^7{{v$$eD(uIJWnYRJd7H??Xhd`j}O91&v%Xm2@_nPZ-^PNPD+jzQv>*k)>QA zrC{2qy9dnme*StJO}@{KOkurZAn#MsH~S6Ox!h z-K>&NMkjzFRuEH{th#3iYOq=!p+%!91EEFc}Y&J`<3|`1$lbUpLT!=~F-S{81&DEpObU58YELGM{v|UK;PSdZ zQPx|(l1vpau9v_vI6-12q4~sVh0lOx46K5o0};zxP2FQ;koh1Ef@3)_tyxYh5p)vP zH)24*sj64lteHU)CfDBAWw4-Q4N<=fR5gKYfdEtyf<;!v1AUaTxqqv|t>ty#7pjM@iI_g%y5%HtJx&0+JIP&XcoPb= zJQ~^!lo_AP(na@ zha`?1Zn)8E?|8||8Ce`1uT4HHMwC=RE}4m>r8Hb}b;~vK2KoIeP;+oggR_ir5xJj0 zat?Eoz-KR0p^9?hr(2Px^Pip@a>Y2p-7a_;>AeIm`y=`NtR%2@BW9o5+DH{{z z826QRsR1Y#Q7+-IXYl7q$S}xxuQiAXjIz`Vd7QL zdIo2{2&;O)U7B=H5tEn;64l1K0KBtxg)-G$!AoTo%iD`XbKr036rm%5R>X8J1?`*5s{@uy8@cBA zufUuP?}8o$0xmC+bRsg7kECSv$a&J_y2hpXM^uo?u!h3yn<&SPb+~|fFlS0(T3@^o zAx~bxq|9ZW-^hq-~X`7f_==hM^8#6}IsB0){*a-EpXUerbdZY_XrP?rfHlfp(^`+@^ z;<6;=c&xz8GTCa?(bxJFW|rDEG;{;?Bj(JFkaAMg(;TUp$u(9fKmL@_GZS>KUc0Dxy8z6L4rxFRWB8##f*Y&FO`!y9eV)vx?W+Vh~xH z>~=nopA6SIj;VbE9Nb~R8pgvYT1gg-r*b2ii%4W4p^CYlQPd~TX47dA!6^z~kb+Jb zwuqXB<;@CI%O)W3+YJ<$7Fw;Tuwbi58>tPZS4dAVRIN7toY4K=jttHvN~x zOfL|KI744PcbeQzWmYtuRijcTjXwP5NPafEqUd}@e1bRw>qUON%vB5Qa9k25XS2cp zG|bIpwxA&V0_QmeE~C9ct!Y)OrYpz2IC2>@e~jE4bjj%a>MrR8yh~D52FnHx-bb=5pXNn#+sM zSR6O;{N3cf4DB+slM!X-rdw}BSw77WcZHe8%#B(8q|LV@qi6tgbHl>}Iwl^cnc2!^A{RI-~hUkc&YXC7yd;3&?yR=)80Tj zsUq~yPcIMfTMkdUQ?G4+aGyc*i!m_`H|4`Lk*n!pR%=iYd5aZNP=*p;lTS;aLO6qT z1rgM>Qj77geue)2>=aqSX)zy7L=;e>2Kl`}_Ye7&1mbqOGhqX2v$+%&yy}O_mesB? z!!RR;Nc$f%R-+Vur|Dq$aV2et=TpwL$nvbI`{I9f`|Y=`OkX>t)mSa)1Bro&igRkG zA!z(b7#KT%l@B&2@^gdnkN^3XUi-`6eeEQLnZ&dbNAN}Eezoy{MIePh_;R4cK^j1l z_p;hFX#`G7Sz6Ns#w99^O0MDfu86mi=|b8TJJ(B+9t6uNB4WyldUQ1? zc&tbWq{9U=c4J6hjDaJ9xemviw>1`E&)CB^$_8*B#qIM^lySD~CDJTi+Lb(`so8aVWgrxTfzMrSF+%NHTm zPNSl>QVw6%*mDsWQpB5x`=wG5dwwhBu7bumA;iaplBPO5tO4WMk`?TrAylCkLln)O zP#QAc5;h=JfpP2x`4G5iOf0dS1VNQ}foM~~3l3s|44+bYmzl=0Z0 z=Arn4UWfL5@^flWg(`XxuR(!39iBF8)ZR+a8}+Ztk_la_u6& zLg1k?e_IP-_$cVEBN=Sh8g+jND@IVO;7G%W951B7m23RCgx31p$d$3P{1^qoC6bXF zD8X+v;SEPhxd`=t-aK4^&4E}9>OU`Q^_SAQtehXEjbg0<*%9F2EonZ?c`X2-1Jel9 zYSny~U8apbt!8w5bmGEuS~WC6LDFi@U!6ER%Cy>NE=^n-8Dm-vo17SB_&s;&+SRd9 z&3|!phM(8-}kpK z_}gC%%

q%ZKLa;bPGrG;?2j`Od+6KAt_cR8s;k23m`uBM(C1Vz>qPsyl$Mx@M`t z2GBC!b?{E#m-9N$lqps>*f(^(7&L_$;7tC!}(JB#5?2f+4X=umOf z7Q$4C6|aU~zl@$9;Y^(PhRNGvAY2UXdJt+WhV6Z{Q_xim&8*_F9RMw#2f#a3!KUES zW=~Vk(l*|eFuJnS)|IAWsIwT}{7zLZ;B*=r0Mca#neRAor|q84#xVPEp(TAhI<)XX z2R?KcL&Kchriy$h1^{r}6#yMR=EBKhpr;r*jgDyH-40EF&H>)Uk$0+U&{=(GS9W%* z4XU~XRS%X}Gol!P1fRaE%uil-;>mo^sk_m6W*0I$&${RPp19k0??s2|n~R~NoYZ!Q zz5rk+2k3DCfQ0Vh08RqbQ4Bq{3cx-|AsG8y0U)FLxP#}z4&&ZplF_e)&pN<>;|Cvv znu}p4+1SAWR!jyvsgMnH766rJSus!ouyZ496LFK_O`9F~34pzrjl>YZRMrY;j89@N(bKGSJMqQ;EQ+NzLg?$chFNQlDBp24OAJh%`Z{BB@2bjJF~z|7ulk|8XFh5O;13*ns(WPs*FCh(dxgSTSm> zG=&4w^-fiN$-l!bj@!+cDMST;((>680asCp5e5BsoVelv;d0}E$GqXcMA8S4z%b!dfL$Wj#zZ* z#|Uk=jgWo*V7y^TNl7{Yfa$vSgQ?x`V0|#maIqW}&Z~{DKY!Qv<^*>&ua(^R#X$E$ z-|@RGZ%$Y2<6@w3p>6-2Y6o9^G0?mi=vWMmqv6GH>td*_;&1?FVcYS$Z41{b`$LO? zmc`I8_wFilU@_3P&^`Rc=)&1&9LPjDeAG89`>2b7o`oIox89T;YWJXd-i69Kaxt(8 z6ue#KX?$Om?t?Cdn%}N82>`Yj*s@3#nrzi#sKbFFO^cyH7k8Fg3=O#gw3#$Jv8-!h zJM}s5&}R@w3j_>XMWKUWESjd$0ib^|bZ&7|x0Bc|209mnk+1mXuR9GITIh4_qp%i0 z*J9|XD+csZ7p%%wmd1-bX~DS@5X}>-ny}o3#X$QaIa@2dPfIT@1~Ar7uEJo=e@;V0 z7Z?XlcP@tZ^08`m0Kgn1Uwkp_1hAPC=>mYUIQ1aZ^r)(#WF?xQ-}^lc%^$e(op&Sc zo`&{E9$$6vK~v}brh&Jb21;Rtw|!O3!H?A}`B1v7(DA(wj#OoWo^J>DWRCm3jdyI0 z`=ZD#{gpm-^*Ma5>i_*uU0JU&;gwVR>{Yeni_G-dE9)~_9(?%94Z4wSvR<2;YBBj^ zy&^V2oAfLGdOe4Y92OqJhrABseI@#1O)ei>pwDKNK1IM<*PyG7ZQejH+VNti)#pxE zI?6wX)$IzqmG#(Y-sTD`;H6KChF(`#d?w3+ohp5nq0&qC6BRzJ;EEfc*Q$tH;6qfF zY#29QgRO|Ww+tUX7Ogq^X`ebyEB92>Q$u#L!`?}8-J2h%E!$p6)sLw)0o=F6dj7xgzoWqK1y zrmrYY%0!SHqJ+MbC{G|!TPMvEb#!RmIIAu{OXZBp+sex6tUjH?Ja3*AHx7@hle7_p zy<{01!E_cjX4~d9U}*cDpZeWHC)sD|B^Wxl+=@WE@i>9E_aVK@Ah%I6 zNy&9eo}uJDNN{uk`<48xy-AkbPK9JWdBv^dStQWBa|Ga%5$5C6`L-?_ZWG&`}F)ZQ0)7P-OZLo1CR2N6Gsr`5q*L?R?8h z{%Oki8A?7w$!nB+mXe>P{NXh3Z`E?}57Mc7-+-qC(eJzNQA<6cwt=|AdBi`ClS|P08^{uT| zTjT?Qh(6u*E;c#s_5h~egmH@fVhJAw%!6>tkDvL8u}_Zu*!XKtFSPFbT6kAQkDZ64 z*Apz}+e~f&O++u;H`@?9ptVEMVQEkdG`Vl@tQoxkd;b59-EAlwhj8nffpctC?RQxa z76V|Z9;LP~=d3j+_Mg_Vui_LkUx5d^KrW$fM3oSPhV(dvu{RXhbdJ93tIUFbgR9K3 zKF?SB_FQT4eYM$(l&hntu^@sD@pwp6L6P9A%1CIbIaMmXURpYF_ks5B`5(S5(MR8w zxQ;I{AD{JjL@W`epojFoi5m0H`t1d5BfdL*u&T5e^GH?OJaJFSaqx}vHyqV|dOHuV zTO~7GX&Wz;4e;{?^uiv09~dXyNX=?Netd@?c;K2EuYemT9_V}_&&RhSNChj@D}ifP z`3tz^$s43W$sUmgt*vf-u&MQx>mU2x`M{R%Jg9Gbgl+lOVtx3P^B+6)ptPTdFM3m{z+t4WwJO}LMP2M?~LZn$U8vId9}U# zWiC3{6eTwcHG)YcnStMdlbV`eKbVrr&f=Q}-0J=&s>{EH#6EW1U{Z3I#_lQHe8BUd zw&|4}pKiTh+p|#HQ{2@0k%nSz!-rq`zL&oL<@t^yU#mS*+|;rvr}aUw`F^nftziHB zj+0*to|^ZcdIW5q!I0ki8@ShMa2lF?G#W6i!8~Guw>}iz!gD|2ODL4dAnV=%xBpzk zh|Jp_`KwgKO?u!24hQ1w@kx0Bc_bAId1gyKuh;vi(c=w&-Bb5Ro-Kds*|Ajfa=o|t zwSoJ+#};~zz3ZW;rPCdWDrCOr{*Dt1J5Idop{J#1ssrBESCu6XeqOuvh3e0~|6RI$ Mw7J%M+RNMiTmOALh5!Hn diff --git a/mace-bench/3rdparty/mace/mace/tools/__pycache__/arg_parser_tools.cpython-310.pyc b/mace-bench/3rdparty/mace/mace/tools/__pycache__/arg_parser_tools.cpython-310.pyc deleted file mode 100644 index dedac14ba795f34b778dbb67a1fbe4a526b4e244..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 2874 zcmb_e+iP4!7@s-k?AdEBO}A~*iwp&^!8U0PEmCT*L@OyJim8Z8<#>8#cK2+~&6%@J zcc~zz=!*}<*IGzklsxz%2qNN(e}M1j)u;ICo6yGZo4sVy8Z79-%=vEfn{U2b5|53Q z2>ibLVY~VL6d}JOGx{@tc^#HI0~;rt2Bg;e(i#Qc2#nCInUrvoTlY!L;&heF=61lE ztjyBPh!;GHO#g>?HmoIB>b^-xgM32okrp7giJVMn%h)u5x45xo+#?PFI;V9rCuv#68dBvR#qC?*lT(CMf>nM% zqvQP(RDRn((WCGLP-$gz4CQI1HMThpRC$Y5$>zisX-#e(+9F%n)C=S@$!Ar$1QHrH^>IvpuDnXsBf+lKDI$WqtS0eU;e;5wJ%7O z`1qQoegeA*w9~rrwVe7DB$Ft4K7YZBLkD`XGxXw1z{hCP2UMPdGgAkhc|eslIF%md zhyT%Td=K10WSkwKue8G>vxfz}Dqmbw0{R$wh8Mx_}z9o9_b$SBt zXr1K8juSpzB^%fm(`bQG^akv)lCVRjp>nV`Pd+4%tiAft`V)IRKdEnrimMyYucx{{ zmKgk&n!;OSew4IR=E_Dp6j7?qvUbAVRIs!uLgv?5Bs`&%D?6+nOLoT%e9qj6GcS&m zuTnVEgMgDLmLhG-h`B6nCxJk#)b%4jYUnEgl$}XpJmtGV(3xdn%>BAA08=)UMsQgG zxGb)-Aa1Y_FK-Anr_Zls76?^V_R#kv_QB&UFNJCcsj3xuydDLy%hjyeow)GIt9Q=} z_A*1_LcH@S451yIPTk8_g8=cHKIbBGd?`f&?nmv=@tSTFiJ<36=-zd1fMbHGM0ncG z?xpt>Jmw(VG-jR@7`}bjmsv*Vo|%j=56LR5KVcM;Ce;0hIhc;LMZPjVhCFSPZJL>3 zJJ>cdWA1#G(|$^v;ClPZnQz9Sm=E1<(~Voc3zR3$B{FUaFP%?1X)}(TrPX(O^4lWv zn)3^iC$3C8TGrdrI1bc2gx>*;64KFvxula7uPBjvahNP#LN{@mPm%5^9eW4!HFqV5 zJvUH`b3;p+2dnPq7~x?o!#3oz>E6*r#grQg#(k(Z<;M}!o9f9A2>+VstjCfo2C?m> zp@X%F8MrMZrg3-=)*q@MZ+_Sg*$t`F{o+1k?aXb#l34kv4?aN+iH4Bf)4g;L^idq0 zi$ud6wK*__vN*_!EF(2GWOceGkv#<4RU|Oo{~77c{|#w*fF$ux_08_Z)f5_z-CU2E z4}IW40HC0uK^0Re)Q$stLpUku*FnZb-H!y{vEc9ymUQ>h(`JK^2GGlpAUDDcIhYlZ z#9k9xuMwxLYqCYweQoddbsvf-0*%Il@$#$#i%%h4n>|&_!-R0SFX6utm05|KHZ!XQ zlOiiX9j$3i5vGm^kgQFNOcE_9_os zJ{XIz!!=p4AL4a4GZk#s^2)WhVOW4_e1B|bjm*MUtyM5AgK4@qMI_c{R@&=n>|gaW zY*wyhP0`I?vd4t-!Rj@&;fhDt1 zS}{%=2DNq#vr3I!!?2%BnbWiiSOd$#G%kgw3AbR6x6HxG2ST`=T1 uc))Bgxp4|o!L#+|X3 zV4+R+r5m(SBf47weSuV!LfW*6`jkqQ+CDW~sag+CWos5`k-B}z8(s3UZ#{R+*oF*g zs-#l1qw(A`_ndpqx!*bG-s{k^HV(NY@ff8(w{u=lehx`FyY6(6OcI91+g3K~|l0u0ti5Qi9Oe0%MCbS&*PK zI^VulpN$1RW7O0O>^e0%-@aD6rUJV;gu!k}-kwIiR=ef`ySW&GUC76s7t^Z7sC;KX zK`nwLA_#v`9C8ohXw~>Y=W+mOyVl=p9=@(`yVo$jt99VLwVQq5RX|4X)hzgG zG%DX7z?r8rs~3;*x$_z{3iRk;(6LpCO_0aFI!!j;c5IP8l<(6hAIxm$rO9i0w;^lr z9o#7o5hxzR%ZPgU4rx@rJz~^!A>y3A28(`NoP?ijH@&>}vBR3x7o)4jL?;I|Kg)>o zM#?Z}0jnXNSt$(boggf6Pfc^3pGXlaB2dk9sq6_l?dC<8YgP1&sd< z)|WRptH)qR&t5OiiT`WG+3^-$xaz-}FTQuQIpy1J`-ow9Oux#;@9b8s4AiRrnB6MV z@BANU^c|_w;os`x#4PWI?f9p=adulH{no;GzC^ZsN3aEC+wwWBQRA>@`(1!Y?)0+D zE?rWv6Cc$^=i6=jJmEW%=V+96IG}%{Moh;SrXJS2W12jUldnsA%eVg0COG^1_uw;6 z!3ncklhB{1zAn6r(hT^ce=;f4*99k&a(#VHqTO78lQw0|HE6ftq;Iw2$0MAJ_5Ci@c^aa@oj%R~U@n?XiChXDCono5N zFr)xXf#&fX0@!ahDlj1k+oTXf2@~o@i4%G2Qyjp|14q>d8Tq^e-8(d$Zrra}AiDxV zPJ|B}l46e1v5;cE0$HYzfb>FzFF`2F`#xdN2e8=3UC0M7x6I@f4Gp&z>tkq*6Qk)rWU7`W)^2wybp%%4Lv%Ms_VaLenRe3#j0cr+;wZU?xoJf&ZXmv$5+e` z9QPa#r&F~(SZuE@%)4irgwJG;0ELNlXo5oUDa5Lw-)w}nB2CuHWbJy%;95!5ylc+2QDkyBpP7oRwhb!J z-qZKEJl#Jg_m3s}E+#MjC|PppIr(Q!!clVH_`v#W>qAev%`3NgpBy@sDmlG+=+r`d zzWTFj`Oqo3xgl~Oi&}n%_2-I)Zn_+k+#;z)|ypo!^)8dUH7`= z!#x>SS=xnU7g}yjxjad7DA_QSf~U&^oZE@n#KI3hxt=7wN#sqD-kaBPJYE%?XVmMO zO|5K&`+Hwgu2%vlLk&JFCfxOFo`$>qzwBQLrkhX7%_pCDdQ#N8tDc^Pj@zTRMrBWr zOuf5lGP-T+Hv7MejBW>xbu|pF+&e$|v(bfEx~^5OYfV+Or7Jq+iq4znpV~53$L*R~ zfW2OD{s$C?jj%bhhHID2Y1As?8yns&+!d8j$+lOvjSI5` zFC^+BQJ*QR0<~m@EZ3@`7iTv$M-^0q6U$%w_*%y9y4^k7E!&~&(Q+vUtW)-MK5CUc z#~=Mz_MA+)dv1Dfo!ydK@TT{*$opDS_LdB_H%-;aRNZ>z;Ny`mE`5F}IS9oRNL~)6 zFH7=eDH(_*%j0Y1we!Ps!(h5}4-}%q35DpW-Wpl%*JJN>gT=meyDMpTC(Z6}zv(v; zrM-r4zp2bNq}DjzGf4Q`<>v=0#*F0O+bhN_15NUhyeQ!vB~yOa_DDOTzgjVf`0TwplD04EsfV G=>HE;)-1ENbYJi zsP5_Me*N`6e(yaDtJRW$=U@K)(N5td!}uTic>XEk<874qtIRNjAy{a%^EYcV{mt7P zZyuTht8ESJwmop#4$El@ZCBG2+eN(1urw&Q%gmVMy_I%VSnb*mji_uHa%uk?%%GpJ z+qEAV5xZwd-mlM=g(C`we716E2=}hhUX{kIA&R1O$lGUhyDTbbuNA9BR%GJ5s7YIv z`!Ad|#FAX^TB0tNAK2})`8X(E5i8P>>qOJ#pjtil)sQdhd9FAk)-dlSSyK;1Q>=@# z4_N!$JBD~sy!5~jFUe~E~ste6H@l`^v;V3J1hJ(>Cl2MWs z0;S|AjJe34YFg zY)mWDs@}VpR#KL7>|J7G zezKP&rG7bG?pM+pN)d(bVog_4CS2V1pVAeKnxeRByk+R#5_*4)-df&EH=@yo9#7b; zD#~enRuk4COIHp}%vtF#(HcD_DxwNbSBR&-Vl?v??EA*YSobNX-kGj~tMxg3?HGME zrx#198L*Gpw4r&rO6xR?7YuN0Rakh>zcrUv(uSy~wHAY?N)3 z|6N1XgXlgeOj%;gSjrCBFZdp&Ju+Hc5nbk7*IN6CXZ+%qGAGwzwTvNM{r~%yc7_9a zY2Z(G{9!-v(dx)IM{3xYo#fJJKiL^Z-sbjudH)?5b#^XoC^3SG?CZX~N`}KQzSQYm z9PMYN%OR|j4Cl%lbWA$)nlU-=-3xk=R9@7-csU$){4lpMN- zs)3P=<4#ZSqlhQ!1yp}YjdfIyiGKhITO|7a^&CDWOpOpK%7lMBH=9OSK{LDG&dNGO?$4^Nd&J?2P8$SRR@4@0B4z~4(Wu^!Q zY88|xj++%P_hcmRkJObdT899If7mc4FQ3Y>1?}e^L^WxTSE+i9s#CU9U%6xwks3$W!4ug(V72Tx)JF_%km!2)2 z8$fZcbsPPRg_$FRC>{^8LO1ji0De|bGT!k=SaBlNFji$+(SaXGq-dQ{HGFDaQRk_p zHH1JuGnGs-CNmaf=3pEu_RAvpIi5 z4~JtcyP2_Z#;%>Y1sA`VOYOKN?X0bX-1;09LqUL|HtP-~vfBLI`YB~8h zXX~8ixon*^Km)2H%s8@`dsH#!-<&_OOy^JBaz1?x^RN!*n8Eo9!^jg};D2zB_#f(P z{J-j7)Thdd`IWsO%1ZMl-K5+Z0)D`1a+C#}dhw9KQsAT=9=+W*o8}98d!; z3T}nLExahWFdrBR7Z$laxKk^2;7$$IOw6g1I_j?y3&lFX z*3?C))U=GGuz!;o(;@0EN+8RH~{hDM=)dRxrd-CbwvR| z56ZLzXIW0%S+R%nPAjPkM|c&kvP#Y}t)}d*p+3T!+}cNp+b_YbF3rk^#;hXTLk5@N ziXxnQB`rHr!d1fYrIj9sqbNVW2If}ND(BPr3eytSd7icm zCsNvaOrUbis>v7^o+KtX(4M45&;|Dlz}8%^3LY_Gc6shsTGnHlms|Zgv>qSf$?ZRf zWA(Q*ehMDnL%&9VFBp?oZ+QbhQNevrNk2Tz8$2e<2Ht;tr(;Ce2_l3grM@Qe0@F>Bq7O^H2ORY<9+2 zD!A*bz>hleZdRKpN_4(35%O(~eko3xgb2cenS@qraFZYw7;D~_XiQ(d4VEUWL* z?oXL>0-@UFaV+(AUZ$CsI8r#6_<+0`JXzDegO+yS$^ z;YPa<`Z3X+5NTFDCc-}E#B`=ovCdrJgbQpToFUAp5y~)!tsHR^SaJk3dE$P^|5zzn zti)=(&dy=9t`Q7~#vES340RjJDRPJ%6V>7&|3g5cfn$aSOwc6oNO)pW9DT^A8cg&z z)89~N80K4uSS`fWCZK=;Fw7YMtG7XUkjzMIlwwP7Iti?)PAjj3TWl@DpYRbdLnT3Q7*&*t$ zWDXqqSVbMex5cx-NUdxXOXleu)P&7A`I$qV!eK%4vthaedL1&6lr^}P$hjdhr05X$ ze7FErhEVovAj;UNW$KgZPql+S^&(ZL;OW1iU*qXHWAgf3V=j6mde1uzQL*xmX%GJH zMpl^9VP8SZ%Tx9slm0D=-ecXIES=axvu&e8er54cn{KVym7tLtGhz}f2yZ@rn}^Ia zxjAK1J~cHrc%LeJS?9%6nV>+G}CZ6{n7>Y?WC3BrkYKvp#gP#crNPZof@hwz8 z)krEJLEz^)v=@h=Ff3)1lMCx1ryNK3nkRerV5u7Yo&^gz=us+wKXZx!GB2GrW_D*d zjuNf-aQS-}&u@0(J}%9mzkhyD=PRM&AJ+I9B3bLnq`tVmAVZ|bvnoH%WK#q&5sQmB9%a^BTPp|4ZB}8VT=+G*nio}ygOcuFH9g?+c$FHLCPJJJgSsBq5QaT-IWd#}C4b%|+ zwTB$e4u|)oQtuO4iF86FW5kL&5!Cr*9_jr9ZMZ>1W&mB)!3m+e3aHZF#2zUcI)&|- zdpXy>DP_z*?Nk|>FkXzFD{6d$4_N&<8;M`VJ5z#|C?JWvNv zGG*6h-6jLsX$-1SNT!YjrD;SqZbYS>O7>(XN}Gq1X=6h0M97&*n8ulorZX*Q%59YX zXurL~0Td0@Gi`T>eSY`tx8Lr5ujQ)2phb|bwayKz$`SevZpcZKt*lm%2wg>9x)r>lo95|mC9iZ!Fe+Xx!f1F6l++Pzw~p62(FSQ$&l|i7-q=ri zobrVE(o>tS%QdgktJ(#W;N;mTm22}|*@joW7izqvA9)(qj!X3##?eG)GXU0!3qy;WmN z_){Uw{bk||YOf_bLwUbZUZNj%&s*-b!U(}v*b%I=qzrj&{RYpTPg1*!lKS`!Lxc!h zieYo>-OyhJ-`jT)5|k4S5!I;Qq(Xim6?T{Aja8`UmxE|ORfPr#SH%Mi*yExE_2DSX z`=eoAP#g_Lq5{ou!7%3{1Tx%oszWF`=5A~EG2V7x8`FN$*VZWz+g%EQ*dZvQ{5UI6 z6Kq5vCIrfN>ck~=cx;Q)&kscerSEXh3Fg!#OO)jzVSbA<67};@T!hC2RgmY|u}DNP zaAAJb9|~UbM`6x@LLy$wO_1Xyul^Jka1}Agf=(+UWFv}@7S!cHh*R|Cz&+VRmg|5@ z4Pq2BJ5G?tfY~}=1P9+SvA=(S0g;W(LU+USdx# zLjJ2T6B>tY;S`OCQ}(H4NRN1p<(5=pnmzi*`*4e)ykslMsk{`YW)!ggL4?L(&0|;l zFrUQ4Xyln;+ynDz`-$JED8Njq|ESSlH19dN3Xh@!E76gdTZEG0(I5DW=CP_)9@ zJBh9z!Rp~`jQtx|Byz+c_TH9x()-)3K!(^i4Ko$PegoGTDEGoX_u)|*ER-kb!K#1@v5xz!*H$EA@ z@UCWY0S_d6YtGyTc^&x@#DJ)0N#s|7bu|)RO&MbJ z79KJM}v76wm z`TN7D8B1ZR@kztC2t#-Yu~zt|W#RI-4Nij9t7ml0PWNB_ey)Wb+ zJv-n(AXF6OT6PqlnF!XqlPGaAasF$kR-uluQRmnRR|&6#mY_T=+9U!kmI4)GIe`YN zC@?Hg>;%hW9U;)XpBrKYIuZ&7Sl%dhiH41!5Dgn`o(w^g-N&X8fd~moHpoTBM+Mbj z$R7opNKo-?WY|9j$S$#bIKsod0wTavin0Sh6IScuec-VY7J0{4Vh}RO3lz^r1%efb zpg@g|hXleWkdcs}&mK-+h#ic|Egv5o8jfOL<@HSml40#3suTG6BJEh%HmUtUa{{sRB{#H$lCwx5}t^A{&O7mY1x zl{Kky{(SrXl>NZB4RbYbn6KHtVo%h(S#iq{Gc4K*2HXS zV^NG+(igve_|?Pl!Re8ykvTf4-xed%dP9u(ptSOGXRK|hq->_?%V**QuIxzE+^D}^ zKRc2tamC4Wk>!ok33r0O@%;7Y7tFUyZ`!a;lWj}&+tX@m+E$DIjFyiz^d@CeT@zEy zbY%3%V2lw)B*iiqR^GK!QTh1dj3PB@XGhdJswAm`Y zJ-R>`BnaGY?8&I$6FK3!=*d_izq;1$#By|90y`>ab#cTR8?`WO&f91N3mMxYoQff{ zp2sGzaj^&$lb4jO2Jm{+Ij;vigm8>ed#SA;1icE;>V%oR^l`$gY$m~P0+}rOO$zWl zlmUS1!{Sm*oRU#CWb1yA8F}k0*9)xc(`7+WFVN^!wNiE9nD=3MC#K#=Z~HE+yC88I zM&ngrqNSYEfhVBv!*+p~JXrOuOEsu5#|6&7Xu$jVV%|UK{q-W)IMS z6{(mQgZw4%ACJO$Vf4=xGloV|^rH0t7iu^|@4s@zobK2c0ILzoxU`?^X3*y>HT58> zMtxYdi3tJ-Pr<%u3+yfSp+2#<5wZ#_0Z!}I)}g#5?^VldWV9@VpYZq~kqbs>dmi8e z@Bpj+2l$|i3ZED4od1I4BKhOE_XK2@o~>1+$JXCr~(t({?hz|T1Y zVRmpZ7zl!?7=cIx>*T^wXK-{Z#EwE2JFvwGReywaf;x`+Lm_8i98d)>`1zop3$S~g zW5KZ}jEM%2%P5))Jht@tosbE-wXqQ+1W8W$+^5}#kGc;YZ5NE#a2=1bKE{V961Apa z=?O;<=Sgbkd5|32F)pPPEfJ`IA5?Y(*2rFB1c>n277H1Jv1Y;b=q;oF|#LCTtC^pWGWVI z1V_qLyJ)OSypS~Roa|gyndAPHs&wg@Lw7yjnMm0WzcX>CI#HeQU*9;ll`vg*sm7#X{m))`W|@&uW3ZeIem z3c7PFM}a6qe_%4mTpQGZ32_ST0VVz>Uhe!@xHY&X>D8}Q`ROxP*XRfignMb zmL27Z?|fiZ1?@&hhGj zB};X>$Z_Y~3PoCtA0yIW1VwJD$&kpUfinZQ?kAibUZ9|JWGI4G8r-J1`z#2Yiu<(NPwp zOTZ^dNFVsvFcS5Fj0!@u83zl%5u*J(@1NkY$`>6rQL6H1py^Vt6Pg0Sh+<9@cp?Hk zv09d5qQ5B1I?3(n#aMCnCj27rK{kofdTjYyOQ*Z9biZ-&Eo$!YH;pN4OH8v=VxK;9 z<;>SrZ&~Kfe6uQ5(i+o#P*QcdJ?4g?8S|e$7wZ5==Jl@2UGdZJ>noN@%kG!D@0GgY z4lj`3Y5MM-xA(wJwj2N1Xt~U0Xk;h>W83b$U%KgD>83gIrgC1nSjsH4-rj#}|3c&4 zBgvh|p}puib+43(b)`YRUbuW=ruelN(&e=Y_w3%dIy>E$nCbp2{fUND$x}bmZ&|T` z*8`HVs5qwPLDQoE{BlXY9IJ0Y*HzSy@Ea1m9}zr+Ht`D){1p8PFTy5Uv(iiBR}P-; zSLQ~QB1FN1-mLbjFuZJ5!JAVarxviOMQl32mY;*f;lwrH=)GFrFr;(od8}<7dk>m23jBkRz{Ym1bcG;bmJgkg5HKTq=1CqbS0oT)E+rqtHS#!UW*%udW%|Lo-j>bB;V=1$#To9H<^j7-3#Pq}`lBWU(1;8^ng5zVh_U$%NvD`nr0q?q2Nv}QB-y<7c;d?>UZ~GQ<ctdaQU<0CTY9^m>d5FT$dw??7Vmo{e z!eW(HE0+Snn?!j7qq znM2n)W;+%wwTUPWEzL!5UYNhISid`I*}bUWEwO5?biwv!E&Nu$C{~=GD)iw^EN4-I z6$z0Ph{fa-0lept@IFe2FLuxi?~d<5g!QzIuIV`qN0!oSNM-HI93LaF;k7f#b^8?nVoXa4>9aP$nc}Vhb0@PeOG*m zlRqoj1YuKKWr9eYoU2(Vo_AdTOw!hLuc%4t^vvyCsGfIS|7_CM{C-jMM>MiEJc?k^ z_llcx(*zG(fqx({3TmH^^N+Hyky_9tV3vh&UGMXqANPl{BPO435FE&8C557+%EC?z*CKaDF}gdK6nj)sk)QCGT={Rf+hz!FPLFB8DeI8{f*<%xL2Ya6CHqGF3E>@72D(M?Jr(2M7b1_-Sblz&u$3Kzm6Mm zoK4~E3eLWavnTNge=RES&m%#t7iqaP_;_H%8Rpl8N_^tuh%*=wLoEIa{$R*|HpFgo zwmlY6ary%RHiVBR#6_JKgAj5#%Q^>w5ez)A4Wva0%-+)jE0HQaE`mP=T`ngMsfI%E zk_KN!_}8Eh-j2Z7hbO}zHzXY)EW3DoGx_`=01)2j2>A6VLlFeAYD9$P161?@()N^7k5HC6XCRSEUt=4VoxXD5|#_RBV= z&9=1GxT0w%h^Nwq;&{!pYsz)mm!WXSvc>i>jVqa6G6$P3vaXmK2~*l+jeBQmuDNDi z@y{krbs6|a5Yuu|36AVvYn^S4bIGEO85%l4u$xOVDqK?|Yw5=tsQvo@ffVM43L{VFn4acvNKw?+etCqIqaE)q6|bGx-q5ma$%;2v&u*5o*B)p+XigQ$ z8O}`4RyXxA7$QIeZP$;sE`V#v1W-+Mh@9C8;% zhUnMz)m8Ogz4v>s9+Snz$ieSl|KsV-f4}57|3j78pNGl~Jn6rIC`XBoQ_tU`F1Rl1 zlHYFK#ani~u3z`NK|Sae>V)I>+fQ1(Y$J;^ZT)Vd zbjKDG?5EorXk$wq(@2{;TJ3eTsZ|$j*l6_Pu5L7J)M#Kh+EFew8Xxb)ol%dUC8oJ! zWv^$0@A444)M&)LUV>G?Xo@dRxl_+Lw&#$0mEoZ8^!3 zq?29)OC!$L`VUQrMON{oREV6CEQuAtttaGO#E2U-l*m}_+Eqe(N@{=0BaM3OL0^JV ztFNI&x}c;gY=tURMeYq%0cQ}|Gt)Ep)?TZlAag(E!w8Jd_Druy$}s*>ne_iAp7a%v zq0@E_#38-q=5jNWr(`aIyoe`#0VH%Jwjv8s;O}Vi$mQoe?u1^G;BU~!YanC2COm;&GnSBk z14n$kfgn1-+XIsL5OrWzrdxeK&UW7EZ6((0qh#GI-nN16C(WJIy6OG6RxQl)?m`pH z3f8qQpvaIoPi64)=y(%PO3EBLp(NES2}A8C7$P{ifhYZ6kTF9=sa{bo4Aa9ytOPJl znC4c*hUwP>^*I$$n|5K!1m=FdfLf@bYg>{5-Gkd$!uJIo=>@cUYFV9m?A1%$l8;+a zXVJ1)6sn?Db(!ZphxKT!P@Pv7Fz*uY<|1Z+-l|J@FSpO2MDwobGQb3`2=fQ(1@$7v zpT+pk@yfGfCnGWLCG|4KttyE_xok_*LwV=D+cjHMEd!`b%tLb(GO^yS)_q&VnUEQn z%#hwReUR~w$`a2_JggW)F?1TRLr^hvLE&gY-Ju5xMFI5{^#*K%Zf#I#@Y%mB>JyPyXKX)56%|AT*=2_<{Wr=&n@SX*aw`W)foz^L4m#6 ziKz)=&(hQchuaY8N725RE-@5`@*m~L(0*0~L|HOP&7Wbc_1cJXSY@ix%C$eYV(Y2M z1v&vx8?AQ`jy02R|N4it3N05n%OW@$eDSy}XdKIzl{Z~m?gMa{sMG9p1TVq;&3x{+QkXt-Zws_R z)T)xXx5c>A$(hJ!T&D%tc$OP;T;8GS$!(Q|CnEd>M`d^L@@Gnh_u8aAjeitk74Wbr z{S1Zt{73}*GII|YMcV!^U_3CH0B`pbAARnzKH$wvpEq<53PXue{?Hqh0_E=)hJO0h zVUT{A&}bO!BM{p!WFd^!KPalgp~xaGg@-WNVL=t2P+%O@2Bjap_Ko+~;kvG`ym@2=gxlN&EV2n?B&Z`xq0)VeTpG~s5PGQe4Bv@MYC z;h=NVC#y$D9%s!Rc5&9tZUT>N2&CD?Ljv6Gp2+?1LuVMa zqYZ#tR(u!>hyHlo1)AZ9qu%=tz6(?dG5+Yg{HyBXNh~8cU@Rm2VJs``3lafY!YmL7 z^IeQE-ypJ1WCB=h^!Lz|=14^rIaoh#kA&JC#O}bW=H$-9hx{jj7J)hMvVsEdvb2v# z+mwf*jRR{tIDk3e60GmQ7s4v*wGa!Qw(p1N<@;9PkndR)hAuaVo%xxtJAK2*oac56|+nB^K>fwsR z@d}Nsm&wZzO0njp_0{RgB0}3r*wxfuMqGygRjXN5opxXT%x;3;`jyWf(o1-hi>rVA z{7I`|Rb}g9J7X{SB1csx$5POmfwlA&NC;;L5ClLHEqPflqZR51oU6!uxjNQ0{w~*A5Sgs6|5$b)UbLL*VQ9%k~2G zQqF@e=2@w|NX}LHd`-u!4dkWhhFBZCatnu>L5nJg`De5jmfq2CR|aP{;(OFP)r_|* zcVJ7Qen7!ivWvjYO*Y%+ufSBb%=ZWLMU>4CK&-Fy=H51YRtxvnP3~#)73!snMec^b zMkUf-8+PMejUXsBOH}pJ_@1r?xxVtPpWwp$h{)d%Aul^`HwtY($CJ{3glFK6bsJ1DhNgEe4^&4=C{janI1jRl_kQj$v?2xOZ_L z1Y&h-NNPK!M6;_B_wq`urs)LOebfC ze){%FV9+9VW0TFF$-vAI8mHtZ?j!N{f;9=Ok6F4$Rmw2SebPR zWOtRT<~^)3SbZ04&g%?hKpJ98-5BP9u$(fE2$#r-+R_JmNG*0pFhz=!y97R9(gm|k zgpPkKKJFsn)|aNL$)Gi_5%lT5zfdu(f-mWn;KNYvD=jKu( z`L*u~CgW(&`~{U*ohe#kj5ee&92e&C0Uf+~m53t3%+R)tPh^J(^Tzl#B?NNj4G>x) z{XPgXN79>SjK)1zbm`bgAEPLVFZ&&MUc}^_9yuPz(P7Xr zkT&wQE`!>)kR14D@l(o$K+PY)7#gr(+)2p~L#s&GR<;t(!`Z@Sp7kAF@2wY`?bHUl zaJk#*yoS*k{fgG2MVEm%gBPX-I@;zWi*Y;8Vosp!n42J_pywvdRskWzp&TmiP;|FK zNR6Bh%<$07gXr<@6O83-r0G$KbPab+1{%&BPnKgZHO~MI|1}d=G(TnZeS(EBA5XBX-&@{1R>eNFaluA?4lyo-)4_ zX-s4)AUGZ7imin?y-cm6yB70DXq%%cWroSsq#lF?wwWMSqI z8S^;v&7zsQ#5v|@3c9(a0my2B**fJwi7m+b1q`a04rtZAUAt{|sZ^dvK889v=N$R2 zP>Iwg$Gase@uv8UNDQA2YhCpIC!X{+Na!F&^G^Y;BE@Ci?{JxY9v~?eBKZ{r(KC5= z+)AmkkkVv9`I(_PtE+ZF#aY}$^p)zA1CabqPYc-~d)TLFnXVFsAR^y7on0PpEb#xw z8vLIqL&uuK{i`S;NjsjH{|Ue#k)u4}1}{HjRD2#Br-tkKKjW^`6^Jc53@AMwii=P< U{6|ZAE+Uh_TLCRNf45TlBc?lk5C8xG diff --git a/mace-bench/3rdparty/mace/mace/tools/__pycache__/checkpoint.cpython-313.pyc b/mace-bench/3rdparty/mace/mace/tools/__pycache__/checkpoint.cpython-313.pyc deleted file mode 100644 index ca87854b4283d311643682e2bae93aa307f2ae9f..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 11960 zcmb_CTX0*)b$fB|1uot%K17P5KnfBC@g<9-EmE>9icQ1_mdmSNvz=uK+^PScIh0aY9AOY?K zp<2sPjy>zBjunC2ZR%*<7R61kdPR%gv^~f z+=5y|WB~TwST>PH5)f-iK;Na)=DPAA)XwG%Feeq}sN-|67y;5$fGlEhGAw48vlQm2 zi;Y=f%vKz84|6{IW8B;C6I7vZG?G$8b1H9dLKGBNr)g7fr1?6x08@Cz|JBMlWztr>!0x{WC-Y*uoVUzPtOjx zo*si0{O7EwczVS^)zf2uML;Y8ky?iJK7(rS9f?IQk0(bHsX!{6iaka|syLR6#^S0m z3B(y2os3b{5~o9QWF!`yh{xz%Afr{4hK3U1vDna%Y8o1X1uz4B%h1r*Cc^RjjD2Y6 z;wY6<@zF#qk%Vb6m86joRhLMp`crI;3NYbBA_=?0LC6ThwpERIG;DWj!S>%e4{zz+NHR5Wn6m|*WQnXy%Z&dAzH^- zMi|AidIqZXoD~uRx{w&)L;6l3Ak>meLQw(?`8hG959mV@NT{Jx3P|f%XQVgvfy}IowqP0B*ghwhLgz8{NaNoiB4LK`y}I;-D^|>(+VjJip-Sz)BuhcICN)t^I=M1Z#!$Yln$X^s}S!adao3 zEjg9^k=f(PaFoe-Y{R+T<=MGW)MRY2jI{80@@J?mEFodT(b;s$vR`7wr8X!MQU0wE$&lxzVbji2*VkvR*`Bm z(k?~Xwa~MeP?`__Ryv{yJ1QlgnWCzvPOMPDQ#VwqMN?v9vD8R1%Cv%NT6`r?@vOy)LZ{`$eCJ@F)>$@MS4Es+3s(W*QBK&Ndi8~d5!6@S;GrU8 zhD;AVW{g4x=!wiM>8S*YN>W{GBzP^H8abA@n0(v{h*e=69#uXS9#-|S@nmE~R(bhK zn32TqGnOSW4-!PHaWxNSj_!tDSvu~)39N7@8LcNMc$4*8$FfGtxB9ad`?pSHjn?;$ zX%s-fbrdke0G8DrkdXBhs3(DHP~V`u;a`NW)q}bQ7ogLmsp?A{H3T*Vj5rz+K*^g> zfy1Z-&&Cjwolac<9Gy!7H3cky#Sk)wEGSa|4bBdXTiI9v*2czcuyako9*K@eg}Iq7vK^RSHKND=KxB>hQelXPj|=;W{A0cYlI6C!# z%KZAuKozYT*7<5x%i8H5JN+09iC;BFM=99tB)v*;Kou^>V&ke2&K^|+P0>aejo>R4 z3c8%gm8<8Wx=GGC;)NM<&J)jTF+$Gmj^~|I_FT-Z+O(_b5DkbC5A>>{93O?F2U>V+>jwXgxH@Hdh7;~Cp zZ>PZZ1Xsu>s)8JgU)+F-A&~J^K;acs9{6ITiP01|PnD&#Yby++w<5Pg#q3kt=M$km~CpF60TdbHI3J8%k@oD!UJh@b{n+6l{^}&L^ACiC_kwXL>m2|8#%W=E~R_6k7w}vNk@l zIcIl#_cGwJ)&efO3vgLofXnXA*mfwk9e~l=l;ibz4`!VUX@J*d+%nG;;~+e0r5ryD z6nUWl^{lfma0+@4SfapP?Z6Es_$6(ov~GL`Cl#={$yrX;wG&zACjG!(39bZR8U7t8 zsQr>n%JVAlc7FY9x>F#FX5t|UU4je3PG|Yzp3Sf)|C~LoC7b2DLB8+>rSv%E1M2ax zAAAY!C+ia6D9bw5<1Q9@{6%t?^FuhOOUkqpFskCH3<{U29@P+sNEx0~;S$7&s(w^P zWd-byjIMGQRbx0FXUa`mETeh=j?cSejv-vn7+nKWx2$LN4bvF^$k(BImprsPXRb_N znfb=_H|EdWK7Z?c#@?dXTe9xTsbi12Ui#@9fAhwD;g`0LY!A9lukyM|liGE9k-L6q z=IHd%nUm8e=PPd4->P45er#!0x=!bKVszyQfUW3BWAl=|_pfT^EB^YzkB09K|0r=c zaqmo~?a-X?!y3ijo0fY2_=y2P|M-cXIQlplJ+_bBoxM8YXV#s)2H|JjdZ-z(BBgAB zVH{)j?4Qc)MHQIP4Hi}2Kv`Bst29t1S{%sM zgj10b=DDikFilR3({4l}X_iPsQt&TSk;aB&Z)29vxD@V?!4-ch>zM8g#VTH>yr3Z86teP|5FwgfdIkr!US!eZJ-Hp2WktJu-ls;>7 zPw~IESNz^xldayIt*D(k3|4fi_mRCJ>#WQ=D%r}zOOD2e?yA{|xp!{7b7${Tb<2{w zHS2J%7)6ZR#e&rM)q#PT6%i7A zPyj0;7Oe>A{baBxBJc*Mz)k_1qiWtzK#a?b_MFo!>M4lJf+Fk%c8To;f;NcX#!66M z5RCwbE5Z;J4I%@vQcg8NMi^C7nH^CzJyWF!Yf8%l;@(aY2KbkPda%;qdnJrP;{fJ+ zHk1Md!a&cUVW6lS*1&?s!WVc5`&Q|>P(2zYf?N^%CPM1TU{TqCnqXP)Ze79@G?mN4 zf|s%e&GK~t{pUgFEjHNcd{4Yc8pvQp;X4c@>;`F_ar_@%40P%$Zh zc4P4p>W8tr9P^s*a;h*20e@*)PTC(2k6nm{52!useFt=tg1?D=?T34YU$@)Oim=CEuJeQpl{jS zsO$kvd?6l-wt0F%4~J>YlNyP6iczQsh9D=NicwD(NQHCS+m)bg`;!SoN8k(SD1HXg zRnyM&G7iR}`U{*u;d__dPyOpXx!^0SQ8X$X;fQ}Ys zqV-Yr;qmcUBC7IZ;kQ-O*>HR!c9_y6W%dYc6AN${qzyekG%bwedEc?3{Gv>k0@U(X;cy^4tJ#vNX;W(2mTGoS+5WAkAMBCGr1tbLdNY!NeXyyv0CJ#+fn>8z_`vG1q-Kki@bxp(${-!J_i`P03Dbl^2i_j*5e99%JyiWmON1iR*R2CM6#uO(a2rR)o=@VZ^Dv|~rEia2&;?7I~EuC%o4 ziGl3!WgVSL&j0}Iu%#`HuyYB(A3r(F!FPj;4IlI0PufUD{VIXD@}Z-8&i=kVTk*|L z!RospS6Q+s=LCrVkx}>FtnKp%e`o63E_}pq!g{M7>VMp0?%PFvzRPoX6aR~94r-Pt zR*jf30%akmBBj;`<4xwUn+kRs1~!(0rjJ5k1~v+n;)YUy2p>)bwWiZ-EN6|@szGg3vW6id~e3_{6oTr@CEX1h^p~RM#yQeJLur*@9K1HikGAa~t ziqEBoE+{_p+pE@a7J4kAFS94!R^t`NAXsOu44@I(P#Vd(%7BP)=U!V&3PMUx@1Lyi z18tQ8-RObrU5N$$rtf+@XTY9Fv9~eeO1mQBpPle22_^p^v zuc=~i(eeqWKf~S-R-I7E7~E-Ejww0%ZFnxd1O%@g@%$yi{s1bL9duO8T%EZJ@%Edy z-~358>XcL0PNnO5?sb0n@&_;9Yfsl6e_%QB&|;f8cJ0_~bZ+d%*aM4q8FESP-HLnn zV%3NBAJi|o`@y--hwopX{`!31ZT~I*LjRJzYf5-1*=MZRtm#e73w;YOrmOZmkUAbp zrWx}!^K9Ro|Av3z%x|PU58t7`5S3hc2nK=tXeWsvry64TceTKx};=LkdMHZ zJj>z6`>$>e++`lEFvN{z?-kGO!P_y2^LRg8PvMT>n|ucvpN@VV8dbpiqkIO0&9hmJ z=qT<;bj_}q*-evXSJX}G>`2MX0<4koVj_P68GE~8Z(r<9+uPIjy;DNgY@gnrbyqLj z-7{CGug+IYT}|6t(^4y=NqI6n0(8$NLqRhzQb2${h&qrJ&_^655ITw(r5N{InFWHD zOHpi)<6^3!TxtU3xGWC1ssuTPxibEJCmxdMpLCTvsA#f1P4pss&1A(%IH-*>3a5w? zsanRuU=YAd#9tmjxEx`F0~7EQgR#8FgM!Lj7fl{f_UJXN^1(ZEg;3TQ1E}^jRR9Tj z@L|_616#R4{*VcN4D02uLB&pg^VUplmr~oc7+tE}H)UTo*DB`5hpzf`1DN&27wCIX;XDp2hSYeUJu$9k;t#Dh>i(k-(6}z{2dkM}6`L1#-s|{r>k|UqoS!2{ zxD(t8ZxOd-YrQ$vwt>Tj&j${%mTe8sxq5Pn`wGAs#I~#(av3)cKms269sBMb{_x}n zCzXz4Y+%{woZUWey3;#vS6n+4qwiA@z;YF2-z&J0DO*vEh)vlB-@^7EwcKq{8eU=p z%SPMOJwW9|A9f5Q~u z3%@qO%gTC)f&s?KV2p5aX~6Rin4niS(qywbnAKwsFtH0xz+_=~TnvvTWPL6MJ*;Jy zps>~!uu=Ok?{k+5T0cZ%kTZLHANFB4h5mIx7;ca3wTj01CK<%&ij^R@8*aDyxEJyI zuQncEf@_Z zI%@5jvBdCJxwcgY<>Z#d=UWCZB1dbvMeWmGJ#nQEz=2dY#(_UWYCr2F1MBAx?n9KR)~= zXjc`4AAxoA9!O(peu=Kc?`Vei0YWnJF|El++Jd9`Brqdn@vGR6X3*!9C*mpXXEbQf zo=*KT9b9@}Q{coB-{oI$Wj|CvWE*i+%~jv1&N#h_)4SkKJH2V=en_*Lozn-* zqN&yB4m8h5UPbaQbfsTAo5t?zS$pMF-p@m$TAuo7lfM-8y-0mV@v~oo;*~)@PjNk? zIGu)lX(_3hZCWF>p@e6dOMp?JX3vTqx{PuGa*jpw$q|WKhJTU5D$L1yp<<-1Z@m5b zt=DgdZiN=S^P%*X17DDo;O?-OMD~DwBQ$>ceSF}L_>a){vF9jj(U}QH)1X;&X4_Fb zDJskobIo`T-}Bk#`Xi<0uH+?XTY9m-FgJV_14d<*dtJBN&>&eSD~7t;Z!&Rc?CHp zv%pTfyBUO8RiC;#j=43|8Wh!}Mmgr*pf9PK+3#V8*zZlLgaC|;=o7(DUJ8iW@4~d- zc|n3b@B1){u6K%5X}|VE+yfLN3J(5d%zEZ{j^kEsgmeFn*nUTh@b|BzZHct~fmr^7 zY+JTEXSdBx&bBMo#&=CX2%R;E9q+W99vFH9Dd+n*Wo_nG1&F-2&L6thzj%VJ^ynP?P_nS!<7wZJt zufH9ReuvQ?I62xJm|TWO{skQ)j7B7;9;G-oA|p0EQ;*HaifdjiwmqAYeFxrJ(Om3! z^OW2p$$~|AeewY%_+yqgf1f0D4|UX{g}C80;ze&!>#?I|yyPv#Ew82LbJ4MQ*;}TB z*OtjSjfHd|t?$2D(_pFSYo z8Fmacz#G(Kv1NArvE`k8MY_UHfb=b9-68E$e}dPfP4i}uCW=2)v7h)uE@L6so z34T^i;+-@KgT4}ez&in&Q#A?$6CM=G525>(C~^QPS|g8$qRe3C511(HE(BfCDXHe* zSZ)92X5o=ry5y>nuUr;NbVb69yAdDy!Nk2a>7^nVxk1XfJLCx$+9h`^Idj9rzPP*d7!%4M2#S7n|E#%@*SQ8qwWR4GGC~PmKB{!(@1uMG|s|^cd|*oQo7Jbv-)T= zn0ySrC=uw!H+8d4ZQ7)w1=C(`4aP~3Hz9@TUVj#STW~cNNnpU5s$r-tc<^6-4BeE9 zc}1R5W$c>IAkc{oGpm}Z+7m$J_hdjBq~fu2Kpz=X143{5-$CZK;Ijw}7ZRJuK`P=R z0h!?7A|$|fZ(P~D=Ef--N8E)4uEH#IN>}T37uppulpA`4Gjzp0cSG;k@&i~uX>Yh! zw7mEhrZ^DXSNMU%g}fE3XDVKfgL`fV!h7CDB&UfVfwnC$f5{&XW zB&mGgE`>Q2FIv1l9LM<USVH-_q&55@`klWx5U^PHq zs#?WSn;?*=5({QZuw14Dih=q+!;&(g8so0u!%!+NT>mKeDbE0IR-0665lD>>gN%o-_`$&D_OnZ{!(hFz8hvMUvmY!^%zsW^Hsf@4By zcjS6HSg*`lN6rU}Usvj-MdxM}45f*b{ld~JjBhBZ*~X1oz#&4#b{z*ejh6d&uHU%z z`Sq>){oB_*y?&>6`+jcvi~+Y}ofHt!IMEO}Ug5kS4uE(yF?@n~N7k%2Eh^rH!AtO{ zfdC&3k1ZG|IH;h603uI~XXGmRa`g-2k@3h>=G2_h14Eou*3^U*ma0w3)R-i%ffq#eaDu=N3Bc&G(?!7sle0u<)oYh}49b>G0};(cT@2 v#*2b;E*ABA30YRY4_$*aXp0(fS~MU!#`BgQHH`)};M%k5*k?|fr6C0SpVJ^HXCx)PZtwrVS>3&)mXQbdUXs6{EpqAAvYpu7%1D6vk{b;WQ0`zX0`5Ht<@TOYFQWAoL;1p`0*>G6EGaqs2Kf;RtNyu2CQ7537>)69V6HvWK|3;W8t3D-xz0?2CNNs| z=Yq@e>|sxU-S1J+4+Gl099Ztj1>MoUWg!N^yGEt+pzR>l0%ec}ZG)-~>lbCsWNKi0E?qIIBwdxSRAl3VDg&jE$r?2- z5XEu^t3+2S=~rp7CR3}y=bX+m49!d{MzyADBvotJ-LBXLCQ5mqsYN_mn$3f|l5oyk!F(wfjBh zdL5&;#`bUu?(XoY|DMxQ#5=QyJpOrfRLBVsI6)6xx|kQoP*RjjAi^vdgJB&5R&}#v z(5i!%8YSQzz>AFMPo18@Rijwf2!;*VVxUR6UJ8d#(Xu%m#?Y8y@%TL7IV&r$uaTU_ zr}(^c2ScJliEJ^(ItF+-yZP~FJ%Yzq0SyzF1*Du?J60JD0dR|VkC|yb6Gj3e7lZ&)YQsHoxyWdsh85O2syJA3SKu|CZU!V!_ z*<%te7(^0;n`^eP;`xPr85j=2Z|5R(CxsU!01;ud=o>+K^fG|%j4&nWB2)tSJLW=d zZwmOi=sDYQ?J$}^9#D(m5BRgoS)taW38cE5PWbmaO_{Amu z^@GoeMiU?gJcPFnM-s(=S&ph>^_A!Xt&i>7w4?5QU4?4QnqfbLH346aaSs45}bj0_k@bR`} zzz!4}x?DB9s8#g>c(PWZtl-)b&%K>2Rq&W@t+Mz_`nUqL3urnBoyj741!GqeR{=T< zvjc4T9CRS`zxJWvzU^T9{ov4@;Lsl`SF0aYKMPJgVl((o5dX37#?bYld%^T0e|Y`q zFPEN(et*x74+T1#jB2){XxG}p4lqM@01OGiqkt<~jY+|NJ1;{A5NIuA+3a&I-->1o zkbu|~WDA(Nxa1*;DJiR_MJSf{#;R9}ukH!WH=S4OdsJsy3=YakjV$=A==Y^u(EDl@=u3 zn~hv4##>1SG6*YvyV7|kdoKE02c}y_1xq-`!cgxMh1|LMw|{zW_Ix2b^X9qSV)nf4 zlZ!>DjHBQhglCpq8L~Ir)lJt>bKCFLY46_9K7qeQ97z8rYn=CyhLabgs?Q z5mUaSb_2U(y}_h%J4wh0F+vlu$$Sx3rW7UUaJ`tN)5L-LohqTz3;tw&hUid?jRH@ zz&sC9s7sXWm7sNUoZ$ zAon8^Tak(Fp3!SFcYBVk&u&M%HfFCK{yds|6z$%~|K_d75!9Dv2{Ym1J}E3F57=S% z^{%O_lmW~R@#@Q-h?He^kg`z&-%=(eW#M7Roa9k(t3{hj&>^-MIag(%n4#35w!|Kk zlvNqZ01gIa6`a=-${ta6&|y`&?AS2ipu>(iU-xxmOU!t7Xc2n;Z-i4)CH8XR6~*>} z?YWyMi+XOnTZjj@lOTexJ9*_KJq7C!6v+G=bWNWi2oKSrhiLE%G}4rOLgJw`@-%>i zS>X$G;2|3QGTyh@wYhp#{Yc$BuoWL)i#>>TZj5dEHj-P>{cE9S0EzL(qF>WZ?1l<-cp}1 zvN_fxBsQ|YKmY#uCIa%>D>qJDKk+pqziGZHqF~n!5`@HqSm%4QOo7C=CVQp-F7U=M Q`gAz>#)R+F9|+L@2N{>xEdT%j diff --git a/mace-bench/3rdparty/mace/mace/tools/__pycache__/default_keys.cpython-310.pyc b/mace-bench/3rdparty/mace/mace/tools/__pycache__/default_keys.cpython-310.pyc deleted file mode 100644 index c8990a318593c50705118bb82efa41760871d75f..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 874 zcmY*YPjAyO6t^8W4Q=)Zu@AsC(pV2joF>H7W`D%gO)Wyg(-dW~w`5D=&Q69pwWsY{ zAS5{Mz>zQGD<>{|fSq`53lv9rdVhX?@AsZ%)w|6afj1;#Hn!Z97cI#hywSg%@bhLj>}S3ah1tZgGs}(E)pAC zUVrPn0WKTqkukaRc-_twH+T4eVtESx(=W(4^ zZ^c-%)*#%CL@Ct8c&?vIm57C3S4wEj+7@$@sVvU5ad|e9xiB;pF`tK$G*%Y^LP}J% zDj!#QmQ*LOdGk4l)BoCVfVpj>`)6DknxYEz(COL~Pw*xFLcXF;Fv8DdO2Bp|XhOEp zHJuWs!!D$A%lJf8c{M46jrug>R((tzyxPCxX6JW{%5%|W7BFi<3Oe=~C$?M$e zeiU8TalVMSVLu$~eK79s(V!a+jXQid2!}({e19-F*f|78@8J08Ff?>O-07KScYkNF z2RL8bOO_M@I^^o9t!im5iFIq*>ALXi4GgoG$U2@O`a#1|Jg?!{fubS+HdMgX fRrZ``n5E3KLN>kew&!iVhS!9&E#xDgY&gFGLq*?) diff --git a/mace-bench/3rdparty/mace/mace/tools/__pycache__/default_keys.cpython-313.pyc b/mace-bench/3rdparty/mace/mace/tools/__pycache__/default_keys.cpython-313.pyc deleted file mode 100644 index 03c5c91a09c8d8f7f7930a7dad03a8a7abfe2fd5..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 1106 zcmZ8h&1(}|6hHHsn2%Opsf|%27D0w;D;5z!gksXBR;%QyyoZ&`G`Vf2Z!%NQoha$5 zE6MV0T)dUIQvVyp#S0f63gW_z_{ekbxtUnhyEyYZzjN;SoS9*DqrK;=C0!>;fwl_d|}3Yv*>JEF5mrV*okrQSoOQc<@T+HprLbo{-xAMXS{p{6*u zi}8+VCaJP1qGo$x5nH=HCcP1tiX@IhxzI|7ugj71yFHa5FiJZ0!&>IYE-E_TAA_{pCTq@@($!xqr%YZ_0CT%jaGnn}6AQJ+s)a4^EwVapL)je*JCXPw9k@+(OF);BvMX~!)rQF&^b^OWR%l3=jeO@!nekCP3{TFhvn_csz6$V()i zz@T_+fB6S7aeRMeP@E<%e`LQp(CT~V$k240($E6;w-HKX>L}};qg&qU(qO1%)Vo%; ypidF?3!J9>xeS%Pl1}I{q*40o^jDzADAzZU(;J`w diff --git a/mace-bench/3rdparty/mace/mace/tools/__pycache__/finetuning_utils.cpython-310.pyc b/mace-bench/3rdparty/mace/mace/tools/__pycache__/finetuning_utils.cpython-310.pyc deleted file mode 100644 index afc0e6aa06ed82d55148984ef09aac468911ad00..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 3891 zcma)9&5zs06`vVW6h&#ZTCKbvcAacuC$*aHuATO1*okAfA4rWf_K9dP;BrPQaw$?C zQtRx3M1!u4rpP6|^b!Nb|fu8m+Xd63L-wY{PS$o@3;E40)eg58i zGi0=|U?I5v^!1beA6`f3A96B(RblcT-0@FgV2I%W+47$RctC7|QBAAq@T;%|W^m+d zzRs%UzsX6@U^Qkvs@N4)XN^b5HW&yim_!dVU04ZrG}evN7CD6O4Gne|)L!%PQhGd2B+ zl+NlPYx(u8k?3h1<_j=y_zPJxY5I%Dt*MsdrL@K>W0KaH!HmbGN7m7CdyI~k$Dli1 zxPVx78C^hQ4Dw)?no4iSKj*J7bB_oy(`We|e^tsUUYOCr6{BwI>bOamYOR%lILn@PT& zE*%lyWJcO}Oy(j}-T_~!*H?VheDKv-8Vc{umY7+DjGT(3z^aHuqE+8VUt;mrDFik^ z?v2y}tgW$Dcsk1k66nZzP8E&1SvB?^%GKbb!FZ7<$uKz#UXGqBi7E}bIm#})1 zH6_|FVWsGR?g37S)WG!tgj$dZ2Sv;3ija>HMN?TAX zZ%=V4^C{7pR@-+t>Dx1VSF+_Zw0%EahVCh={(ouof%`gS@4;W-{^FFz`OFu7K<0t( zFx9>!;tN3VitJR`v(n$J12yx4vXok~|Bky%c53wq!*khxS*NtrQ5{vvsQ$RZ8ldq9 z{yR(}^TX_%vhrQp${ek*ys9i0Tqbn>*_KyJ%d4ejC8K)3u)JBuvNqEwDNE_OB9GO@beR?D6Z8#fh``+-vB$`cnk=W>YA>T|!p%?CMDGe^3 z?&a_=4*(tRdvM1c7p_H*~un z4fyuJIqW-;?>X?~^39=$eC{UO!-J$Bh4h1a9~JZ+9=iSQ9|<;eL~@{L#gIf%5N}Io z3t#Cd@q&13c#vDy120b8XfV97EWMO2chKQx07%TAR+I;Pz{4cox*kNX6T~;*VJazP z&@V*}#}`sMSKvWKUd{05=on%jm$9oLyV$l4X%bN{Wc*W61MVj{uL4P)j|CVFAc_ST zF2JpvTtO!jxbE%(0j|H_wClHc9P{9o6MONAggvhW&H)!)l3)IqJHh+AVH78x`!SC_ zHgbZWg$WmqD-Yva5Yt}%&OFy`9>!jBa2E543>h?EdWFns#X_duHk310g*@AH%M*eR zV;YT;?zOxk=)n0@sP~1YquwNFs9q)DEuqgH7Ux>P!<-z(b~A+c3CiI;gE6SO+z246 zj>wIN+}rIZc_r>ULvFW(!#pRLW2}NL%Dh3o91r_kc#uTecf=qHJr~T-09=4~l1RAy zoP=Rs`^XW_fCEOkzVGped9@cfNdj(E+#m|MUE5K3(;o0r>IP06Q<|HUh9E|%-7MD& z$H0Ve!d)&*8Gj4HgM7~t5a*5gnO!TQqNAa`n6^WE?n$S3Gx5+4j8Hb8-bili(9@M3;u*GsiBJ9%{) zhbbT<0A`#UkVEbywpn_6O-9c=Hvt?h6`{)VZ2euk9s}&0#(l3RQza)dV=(Kk3t6*5 zc8xG#up6oeOJjCz6qj=4V<8Kk(GTr<@ph?7#jc0^AuYSdUX^adL+*M!p6};k!ROVoYVHoo|+Y|CxqqYK`8AQoV+RAJ4tUYPerZgRBt&)CzqE8H~h1(ND zZxfR=KvtLWJ~gTYylYng39XG^u>x3tzmq1&kcmkErBbE=m=xA^Vmvpr7tgvTZUW-2 z+ybYUYzKLN0C)Tz3@JWFV+5ZF>%Sry$+VQD=orIi2pOY&BrYKv%~+xhsAQdb&F7Kd?tJ3g=|8#l^W5!WT$f^b zMTr%)q>gG;ORdwj^jGP(Tx#hp2)m8RXF3*llmmC}-&Y4gw<<0H!EVWOeM=qFTk339 zhk>{ZYq{|Y*tFmvlmIF5Im<81o7bU5fbE7nnBv!9Fj0dL$d&x*|Aso5KrrL0ANcr`Tzg` diff --git a/mace-bench/3rdparty/mace/mace/tools/__pycache__/finetuning_utils.cpython-313.pyc b/mace-bench/3rdparty/mace/mace/tools/__pycache__/finetuning_utils.cpython-313.pyc deleted file mode 100644 index 29e44bbc813cc85a3bebeabf784cd008ea461caa..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 10867 zcmd5?T~HfYc5eNlmbxYM|HGdJW&i;a2(XO}{=qf|+rWT?JWf12iYyu!{mlO=ny7nnK-BQzn zku`vw>@-66+`n_q{m!}P-rHB3I-M4S=l;_T~}V0F{*&N1Y?S> z$=sTkJi7t6j5Bs649fg`>5eCed(>)7mS9C?$ydNs*RPkvC0% zW*~Xz;|Ga5AHQj535NMC-!2QtR8SZsJb(gC2DSO*ZyUK8!{WF6_HT3AhJiJM0j;Ce zFmkZV&c|P}vjpVpbCA$&jk*{(lzqFkzvA0vwBd?NEcI83l$844lW388OH6A;K2jPa z0`fzEC;tR>0YyO3B5#xPIQ!i~;kK_s*_^*&Td|N4>n{J+13nT^dccQxO2h_`*`+_k zG$Qvm*rQRbmr}o1qD9KzEK*V$Q1SC9Jc42SLO-8(LQ3XS5F z*h)%+W^p#>Lp-1e>RaU7X8A5`kfTi@YZTN*vP=KM_X&~R%*uB={vKIB%w2xfeaEwI zw=3%|EM90XlW;}+Zc^^fJ29WwjV!%74psT;;bsVEJOHGCUv{sW!EXnO%I_D7@5Ga}yeE%< z*Yqg*m-pk(IkBgZXP=WiRipw$p_Skr#=U~FC9mvJ_8-m0y{#QffZVo*{tLes2^zC2 zC%8kfDC0Zt%K{qOBCWka%eH>U`L%anKHn>%C8gOt%mt}Lr2Ks%C8a@=;I#u%iAagj zA-mdUV#MxZ?>o{ewzc1vkjq897a}Ze{j=gWD552$K`W2wpC3JcUfp;|`+uJCV_3;g zy~RrYW-Gtjm3&Wfa^|b7|6ens3o7ce=IN1#{|jt``SZM>$ehme?GEy;u?1!VZG%<( z?Imn^e^z-&f~0iX^#`SLkqT&XPW2iw#%#sWeV&ME+dk%)O_aV@v?U9^0Y!}vu`Y>Y zmQw#^i597VPH6p?VqQvV`x`WSSNt7&=AZ99QXxNc?irzc623ov=G=#$Ilsft-e}Dn zKg%BvGCe<(JVvDU91ltg+G)2E^4ZB=I3DDS2Kma~HbMmLLC5yDW^~CuY+14^Tjs$l z$CfhxN1~mi^y;r5cB;fxAbb&WdE5E_O6nt0`QmI|46|>g3++rHBIMC7)-1{jDtDpd zRk5{{{y*sG6Qu>QYIxtCdFMWwL%H>LPHOw&;SpDIjB@4NBe>$DuCPnEUT{U}WPB?# z8fRRiQJV53C+R4C&8ubj7r8J02cCgz8F4R8a>QvWk)UFy!-;6(C4c_nRI%`FigC-i zYCjc@bzP(5iDWe5rxMYT$#CpEours>gulW#9g9b9a1FVw&Qf$Dn!Np{#xHpz0_THy z`s`%P=y+v{1ZA3E@~U&%D8o<_iBNnp>2BdvOlUlOi-FS<#_7Yt#Y>15Ze9|c0+8TT zV^nwq=om$F@;eDmPs3GAh=RMBkr4m{cmq0NhAaAnie4Lo_KC6Z1jT9Zgp%=4lpdjO zaRw$l5)H?4I;lZ;Xe2t$nG+LZ6cdd=?O2!@kJHfz-06g3pc7o}#F@w#C#PvnGZ<#V z;}n3#DL;-foTfO6xj*Eybxt z88`!m0UQ%t5!jOoCpkQ;t&bo0NHmg!fop{k3Krk0Z{dgpAWnr6W6@FGVtj@-9OT@k z4DY`TzawGz`#+A1!F&m2g>jn6gxD`bX8M9=p*~J_hf{;yg)2KEb61z~9)hMm35hj! zUZunY6^T-bZFhM>MPw{Y(^M?ChARd?ibkNdICf!F8L#4q9CneiI9ZgFUFWod z;|or}Y51y)*=0Al%8Ysg|k zFx8Vb+2o|7gwCoxElup^YpVE(wtlf?5?gy1$*&CnGQ}3_Xy3->A)ineS+CD&BEF2c0N8T(!5bB%t+de;uBBBJ?m0BPGd2$kmO zrb1m`G}X;29Z2b5l?RaWz^bxl-GJdbR$Y$N<*d3EscTua52<~u`WR9lTOLoTFRmRs z`P;gGsAG@yp<{h%0G{+B(#w+Vh-_!c9z^!AWIrPNQ{?-gREKt2#o`qRugKuIiXVAZ zgz3#cX`P*9?H*+Jq|9Dc*EB6#^BsQXc;aAvXOQoVz_%We^$&cA^swX+L>^(u6No&) zl4lWlHbtHT<}CvA5`>p5_z+&s;=-ILw(LV?$Jw%eRMwv^tJzd3oF%Vuti-iwTOyW9mWbaBE$KHjN@w$HEmm@1 z(fXNvQ>V7tUMVo6V?o6_>XCyVo5RC8jw8qMbsHkvSh53=9W2>}$gUK52C#10 ziFGx?t5@*`z?w9&L^&eL7s7~eu|z#0>RF-{5v}W(qFuX*DM;-r9BVuJto=zl+uDs< zyVHhB*5E}3?}MAj;CpU#%-QeT7HwasKUb$M4VllTT3KoO2FJ{n+0ZY-E2diTTBG|l ziJ6>p_D`M9jkY=Kr}nhD9+}&h1ITo0!-yNZWSa_Yn|`{RB`XkFv2YWSl`QE-r29*< z8Nl6x64tVXn{!vPx@c6OixzbG@z6r?Z;UIZ`VC!?vGKM3&DxkXf;L26ZC1%l{=o5*pLb(}?xvuTG5T+nKRId3nY8@(S~ z41V?jIE2vz4yUusc4r)T=_azcSxXzTv^^e6ktbR5+(x$?(>K4m1mEUa-@VJzvUFwD zQuvp)+3tm|AD@3PG_6@P+GcOF#>(lgKNzf%Drw7mti_8gybUeB=hg;fZFto8uy5H2 z5z*TDwbeCuW$D1v%BMpsm4_iBS=&K7Tj^Zz(~M>iGIgeV12Y3KcljuiaXa4Oepd-I ze3cx0WrFq#zMo#6m2I1c0s<|_(efB3(4ptfCgf~-Mm`~zhd{FPlYZyd^!IdF@ z$`Ak?Z510<%vANN4r^$BRQIrst?xwjox<#?L1YbpOS)OI8IjFmygH7^;{Y?vo(7mb z_w+OR*_((DWr_grV0`QYwf zJlC0Kf_FbyCooO@EBy}rJ^A%p#cS8Kn64s6CQG&;vLzE6mirKSo+XD7ISeprY!C-Z z4&1L;tauDUoWW^*+SBsrqlX_YcOlOS*3*MLJ*;N{ zc?MX|F!BtiJeOAR`ke!{&Rm`!1+N}n9A2q9y4;c~Ke3B}Ho!nZSntDLw($gNJdq~7 zFi_8SLsTKWY87{b^4`|()GIBjoLN)6u%?)1>gEs69ll?^SiSJK53Z(4e9vm0)U1)k zvqXw4`&neEAyvVP<<}qMOzmCweD_@Ur#)$X>5?AlnjW8lb!~$v@@iq&d(#)(NL`Ki zPxHSMmUPV#w&o0~Im6bRM>Xf!nn6@ExMncVUQ8KWAiq}swcb2K-Mt12pw7gSa(fb37-7%=-eS%QFl*<)+`IQ~zbk_ovo!k_5Ghu^pJ z=eW(^!8x9n+OY`~}DwGjMf4c`7vLjQ?_mtB6o-;PIY>F19EJ07u4xOoE}{qRa4 zAPdSnKqCM9xqf(C0PhhBa1CaD|D@i*7_vcy_=d%W^$%y?j>CE0EgxWbd2S^f+u&3; z5`x`s=1`d6v&aiVk z%urDL&+wn%>tPp+W3V#Vk0Sfguk0N&#I!O^6tm`o$b2wmevjqX$@-MJ@k^pHUF=-4 zd|CWn+F)f3E@W_}3{^{Q$WWJElPAA29NJK0j!s}n8mIa;fjGJ)z7%ylT^eINm*#?}=w zne5*%@;@=}IxMXAwBE#jw8Td7U(032G-0^cKGS~he5mYJAHj~&pEsVB7U<3u+Tyk&>*eQwy3PB@scQiXA z$t1P*8rnzsD3C*d90CM!nm8Y+*{yBZfF+Q6Bp)Bi_kHhC z<@$O=@O<&*lgYmZg#5=go_|iE@h24h4Jtu)h@gVa=$!2^YGpp-J6w2@Pd(wgI*|Tu zAbKKvCh zS=m3Ce@>jo{HU8R-g<-C-O2-TL2O~>n&7v|@N)WJO2|0Xc^2y|HAbpL*Yminq)JU* z7CPQl{vJkS!4!5IQhj;8z;%b)SY*3n)2QoGjsj;9;O|7Jk^8y3ti|vMEC9ASCUKb{x z*@5CpWT&S461q6xZ04L5>wI3#q`DH{lW`%XfNrki3PxjiWKtHfmL@KD<9jNt;9?z_ zNoqG-{KjCNyLArDl|}1^Ik( zzuk_$0>R0Uw&AChMK+z?Y`Ioh>z6o@YPrbg8MivC$HUi_8z_2;s$qie6BRXdL6&H3 zkt5a+D+{*I)Z2y%+OUW0&*X2tBaRX1FX<7vd*0AaubcUT26r02#yD@Bc0<5R@1_DT z$ycs=hG*z0A04f)(qd1xeemjTE#(7<@_v;TB1tuL?EaPAbf-;|CqyCY$!5aw;=p9R_fc8@F$D{bP`2GPg6V z6%RZZH*N@J<`n#W+v36fb~c4ciksP`-uWoKICBE&&{1wzrCzv*0&9i5Uh+ zXxHAS>YYv&>>@b|44WbjTzrwK2bcST!{t2_G(9-nLg_vBjzVKwE66cs3BXokeT;=o z-_l_b3cCxm7X2oK6hSlUz+2k&)*8I$@P&;vXM;E1Lmv?Ps%*l=nw`HGtgLWrMfh6; zed|kNiGWxQx>$8stMl-}kU9kR9)JDi;=6zUBm3ZFx*>Lpif&Jbpq5C-nFy8i>;l zsADJd1z_uFa^$;+24x%Hrp_fndi6mCNR1Dg{Hl{3O?-o&h>kXvA}E2QNq`kVEm$HB z#vWq3*Pz}f*9mCRYq-l2_%6aFQMd5zH@zjLpstPNX#{Qj=sd>H&#S?0{oPqvD%>OM zxakXlYbu%Jx<*jPIO0o4JMaHLdOM6y=Q@EAT$z|#l{VxDM%5B<9ppj+V+hGGd@aUp z-{sE}c-4D$<(I0fCdHnFhD`t?QPq+6e%gH z>=grg1(D}oo_o*Xx##=t;S;}~L(uM=*joCt51}te$8JGOc=B%wq1#A67zvc%I7`hq zu!AQueU`?w;9Q_*S$k6JE-*WcyB3(2T!Q-e^At2coK&o15sEpIq0GCkj`FFQrWCxk3Nyo+>X zulOY-K|<8zB(Eh!U6*iN%Vl^qD`8QWRYl`3U^T<*Do!rNEe&&sThjHcHZ?i9DCPelUL!hX%XHH zlrriSiPx7T4Gc=6p2L#H%L+VzF`hWgV=1dP7SAcVoF)&HF^)_XFG-*aFPKi7!QwSJ zlTAx_j9->`MUqmWD{Fif7Udx#ORB0+A;cP|FNwrqWmE?{mvb^`S}T$` z*xp88e36%~>XM?7Et%j1N#eUi#)`TfnMqmGa*{Ur^3mh39G>95A{dTP275kBc2-QE zGiVJb4Mv4G!Umm^GfBFIb3_0}`cGy|o`^XFYATG7*n?1vnuW2>cFY=xImSU!95J+; zmf)D0gVjb5nr~VasF>rJqp?g6IP1-uDJhL-v3gNT>WxkU@_%XaC{ATX ztgo13jf<|TX>HO15kS!3;t8yZsidgsn4A^lul)utc~o9b?esB5u@{N_!$s50VuxDTCa(Vc7&lOJA!o3^gYvY?#HGJ@9^f8UXLy!F7($S6M zYsc45l%+fJ2l9uP{^tKdfvN@etY(Ut@*CB_&`lat11~=0{Hx=|@iJ582J_LHH;|us z0$xCzWI*#(gI068!9v!?H5p>%!cr=iH1B|7B}qbRHzJ#_Z=t zsCg@u5Sj}?0$S;)P%%2W-kzS_Q_)WQWc|9Q80Y|I3!$`#IWaCK!6#7 zV=iIN7@YUmuESifEw}FWT-&BV6Ua4Y&%=S=XtM!ZYTXZ_xnItPT76rBcLQ*DHoN=li%x_gl#|Pt`Hx9$ao1rmDSagi+m@29$Vn;UHFM5zre#ex zoaW9lN-iU%4H{_KpfloCgE7w#gSjSQRm1yX9sCSj44-|Nh^bUO;%*>j4|z-y0F(E_ zz+MOgu^>DE7me%-fe+k#6T)kDY^F2!!W{<-s+!s`!w`%1aZ zqgCH%{xl%q`a2JTJ*DWznYAai4OPKooEm z3~vXAD#4*mdMjED9{!v?{P5dDnCz~>=9GjiYEtScAduHJf$&>MeEd50fY_SkZYN}u`yGl4Qg>I)sj22^vWuSJ7>8O`am^F9Guv!Cw5_p& z|LaUb-D+7#db3Wp7O`#u4N32w_uHd$m*-Ld1gFBz+qH8CQMBn&b0em%r%kumB9&Q^ zjvY;CMVgIyW=V#2v>4mm3vQ0<4o3UjGNa|4`_niOisEha><@FpG3POIfItIKGcPnm zB=V30>};RaZM(HRnA2$M{UnJg)7t{ARGu~Fw}xrE^| zAH;!RE@M67+nFgwhyhPhh0-q$M-4ti`i$8(XdTveW+YKD=qn=D2=AIqVqH23MNi=$ zg33|&Ykhz%lTcXi^&^}}LtS%N8Z^`B1qFSIIhI;bTHMj4* zmFp`v-YF%k?)|mB`!}Z7rq|DI9#}hH@I3N{R;O-Gm86Y}YZuqkTf?73?ndrMKb!e< zrsDlc{?r3sU)j0&^NR0{{OLz+)b)VtEhRUHE8K~Ev_)y#zVV80d`qnQ4gS%Ga^n7__ZLG=-ZF3_PZluZ`%10lPL*C@Cq?M--D}d0xj8{<69WC~**|D=+I|J+;Ms!Ka3%{rC=9&7i;bg7N;m^L2UMbF|fX1>W@ z-5AN_v0d1(J7d9U^U$~4gPX}?uC{22@aCwE1Kdz>EJOJLX8|cY+89e~F_z*6Y=!P;DaZ?{LYZuLH=tk3i0X3Y_YzSBqE6^sRUEXDpnDN=I(I zTig5G$L^2ZzxP!3Mhc!5_~oj5=>G%%J?Fips_%6`bYFk|bm>4bQgQBWf!?fRw%r32 z_rSJ$u;L!9vEFUAzryyHd;iJ~)m_LJ%+J(m((Jt*p{?zmxGKhjWQ3fxP%i4&++V|) zfDaXuK+S;OMW(??oJ9wxzc(4}kd+hI}j?_GX)q&zbod)?Kudd5BWoNvRFNUT;BHs|mmo|PpoWzq*^Ir3GJm@E=IRdnQBzEVl5Ql+Y-oQmVbj;kDZm6KR@l}*csQnD?D z2axakU-ujfP?mRVy8GS#d+&e$`>*#p`Fu8pzhC^7H&@^Mkyz{x85sWCi@;;}d5_0q zF(>92tyo3Bjfx>*vtr6SUWv=ws#x+)R1$cbt>jv&k}_B(-b$}!Dw(xxCA&6O8C%O$ za%=fYeyva`NV?G)uN0;1L}fzWC3#z|$+fAjhdbupF&|qt=VQ*an;VTm+>CpdGwbei z=8%>}c>Wf77Tmj?J?=qgFI#l;?g8d??{@EEt-9{z8E2of|J6k0kaL@3zZ$EQodeG8 zc;DmP;Z8I^&^(NiSwNFvnPtnl)42;U-|JMIgQ)e8)GYbXpR#igYFf+~&L{h&KY%vT z57ZLnI_%twTt`qt)WT>ji_Q_GA9ap8_W_Rkoco;z@V@_+Ek5WxguD-+_6L_u=V9j& z!1j>yLFYp#@vw8ujWr)}z7HdC9KQ&HwDV!-1oC|_GvRzH%o9QKD4_gM^O*A(M*kq7 zM%m@8bJBSnt$g3Eyu;&?bWS->pw8oAo+k%mChek5jL%cfN08^kIuGD|+Ia@y6V9_3 z#qXDq)N`J6PCL&buueBTKS+Oc-EVZ-wN{Wg zx3S)GgTxDMgi*xoJWw`+#M$#FPoD|$Coj3Gw&FU^bsV=bL8TrWuY zoleU`sp9v$n`hVEdSkg!ud!ytoH4zlgZd~wO|5T^da_ejvKq)jCBrerj4F@6D z^Sz~VCP-JS?b@1Ktp?d@RrdhjxoY+0jao~ms2qR^3V?dKv9h70CtfgKt=8J@j<18& zs!AZ2N-=BP^VOR%wTBUWy#1*otDQCX$XcztTI)0$HGI|Gd)HN`3FwZjZ~Cj9cJu0>W8L?v8-AnZEv|2>2{gfGyfhx!Sk}xM3x>o7 zDb5TABFKdw(%B%XTz^Bg>m~|Bf1t7`9E>YDZ^l>tbDq|;dMwA^mqu6NefJ8ie^ zd&lfL)1p6iHHqIOejfLhl~^zKQg+Mm(XF1*L`UXh*NmSqUpCL`Sp#C5?R8A%##&Is zT&=s9jp}8$v9juWZ<>L*cr-}X1u5P%f(&!W=<0cA6G8TP3+sw7ow%Fr5{SLA?x9v^ zr5a5LYVZl~;_+6eUTb+L@DUa2nuqPKd5fNKok&k+@F6`(ZLt@anEBXn9}IN|a;Q7` zILL<}C4p0AJPzUaEkZRDd+Yj*8#mrIUMi=8Y}nfd7PEEH-SmQZqwT8$EOk2`K?)sM zX|z`YtI=NWl;dh2llJ3L$+B<2f28WRq*WP>satiM^?3z6vPQ{B8;i!GSu_fUx*t(* z{WkjfRFHbYU9N4kK%Xvm)WxdPP(fB!b9EhrEXc}|*53vFy{u|_)lUHWsy9r6f=NAP zRey8cjb<~*kRIqNEYEeFVDh5tu2&mgr&aS^r&{wnYhI99b!(2yab0*9fXF}go8eRV zdB20v^s#n)W7|~ep0Q%IF)>?N#F-K|(=#1o!%%ojnz=QGw4|in*-J_q-jbHw%B{q< z^7s|-8*jw6iug_VW-q&KU?raYl(|*fn&jGZ%--aeV{gd%@Dt5sZ?c!@Sx&r{?qwY7 zY7#l8wx)aeom6j{Yp6G+uJopM(z2p&L+Z@dY;UGlykw|VAXKq817e!!P4Laod~7Gv zo0XD?O90KB@p&vk-1*~IN6NsPNn(q|=>pWXT`FSTdQa@_w*Z(D?7o5D*pGO}HTl4LCj7iQZ zNa@5*zL)RCd#PURnuWTjx8?=nah8y}i~Pa}FZ6O>G}I&@%zY0K?&*yK!aX5`dz^_8 z2vI{H!aoA-D)h#$8CwfzZ`>)dJQWd7<`8*c!4n{!Vvljx8$?tNgGnYfiSo!*;EW6s#pXByXO* ztU7D{qV zCdE-}(p|%ZRZb7n+*<{?vtznSRNVqRay>6|(ik)zX zU9m~+`qkY;6pz`vIW@a;fxL-S*TQtmq=1*`*U=}K-(|!E5&Bc#FG>P?%svBD0cJ2D zt;i@H85$Cq?tYs_U=1iet)jjGLy{w$N*!ijCw575i27srWH_39O$?5wbB_ zAX&p+)3{fHUU-_oJg9xS0o0U!hKpusy^Ktk z*N0WtG&*|ibNTyttH1NT3tU}nYFru4&vgNLUceA@$ zAkGFzA26=F5K=n(1&`p=w6J79dGwgwjW0gD+>P7ri`~7068A}g#V|kffukyMbjaeV zDLl#(LAHVFJipeiyJ`VZLB?%wtU(a-U3HYmw+*qU z;k)2tlC|}91ck>6;w`tW?qdaWN|t05Q{QM;WhH`l0?iF3WVH`bAy4=~kul2q)k7$) z9_Hf_J}A2h9_h*cJ_e8T(Z`<3m~KO8b7f37r32YMVeu*xwN@*HMzHJYtQF4)g zi(+Q<3hARQ`?BGBm7C)0m^9xDj7DIw^-8Mdg1xN!mAr;8tT0HqdZICqTy&d(QLW@h zp_m#PBCmR3Q~1juHyj&G?S_kO1xXE^=9c^Hc`&u&`h6iGn#p@HjNZS&W9mRUZr(^6 z$s49=zL!o7zO9sH;6H8|#dr~4=Fork?QGgYY+C*cX34P4Y@9K}`Lge1Q>HP8y0Zz( zNE*n&{5P`ktYMi+NyEC4j$7}Vmi4ZB0l<$4W|TMk z-1Daq=#Gm>24M`snF&rhzHNH>XJfAz=fPK>jaf0@+B{-FA~a&%DN2Qo(KEpXnQz3& z%Q$hzYM|`-*eAfjZkeA!9pGZ*>#}h1d;SB{1Y-sh4A_LRfv=4QQ<-^$VNeV(aMm>mL{xN>|`{tud z5V+zsXkT(zX;{v!stbfr6egfpA%bL{yi#{XGou%J>T*qy-j?H)ENFDN^UGoj40ksk z%@U7vH%DZiub}~N4iD4Ha!QkhcPs;wc%#JW{;c_qdI^bj{Q7^CI{Q=ALZG|gKtt1t zZ9$f98jk72uVUaSO*f&J(IUHJadNa^(lgIPbjJjyHXp(S#xa58GJ%O+Jjx@}hP2Pb zog}7BzGT``OS0*lD;t%0&#kHYDz+_!S)oKI+#)u^?$oIb3_WP_N5Cs`5*t3~uKlD^ z*d=TAy6eeQFjcwh9j`$y^Ozl2+(FCfAimP@E7ln#1r}KViukOgl3aDKc&jy;0`(+1 zuPS(yEmg&H?C2gs>kYtA)c)H4o9lyBo452A)H2T#TzMC!j%Gx17YQu`R<%Iw`!Nzu|;z#pP{n zS$+Z8zYcQx~Usiw~byrYN2QKzY@;vUrGGDv6bm%AidqO6co-vz2u%*(}K7=+Gs)xEG;Uv)3(lng#E zYSy)R;RLn2AtEH%MyYoL-F zt!ZRMayYn$$#9U(fi?#Nf;}rmJtbyH$f09mD2h&eD8@p}yptHjV2F}SLCJySN;Z$I zPuruk8;qY^X}ezITt6bIBZB;X$pOp zh#-FDsH=6m2~}OIT?w*acB>)2O!FHG-_14zAxwUK!|yCF!vImcQgvIcU>vj-JE@4r z=(!h6?~14{Hxzc8!ZnR zs>^NIGzz1UdTL8yz0!-S#}yvzisOP;mpSi;m5-O;KrkkGs~%H>gf5tnkn)&`1=K5e zK*tNAL1p-vd`3n>%O7cliR^6no_dvO31Ozyv;0m&E^{_GHZ$Jl+L{Y>aGg8|us|jS z13eN-g*;49I0>X}T-fkk*|3AsP`0Y$t|-@4J+@E`N6W= zL6~H$2b119bOx|UwW`?8*E;P+9X6|Ki&&N$#s>8x(r0pjHK?357F5yWAm7HM==g@e z2^$i~g_ypK<^JeLGBa9Lud||feSITH00CgK`C&$kqjQ(4{(8j01?H+(&Wce?ts#TD z#K+&~<4b&SN-9O6Ox3j(w9_&fFlqO{)sHeCO%1zO_Wb*<*8hdv-jCstBa?5YGx!JN zUWm_#Xt6(G8VhFfMlN15OvGgIESTo?LL6d;WxktBnD6A0=Gzb%{$Dn2erGOYetRZs z{*5_Su+r}(Sq47~Qi}P_V#)lA+@$&E$tm;CtZDNb)(mSi|8!>F{NwaO8r&zsM*8cP zt^OtZ@Krv30goWJR$H&qKEm}fVz%N=M9Z5nL}aM^m=A*6gQ(ybSHb!_=2faM<4d3~ ztV{hXJj#ZC#e;k_Ffy*RU4pR;EXvR7Uo-nIPA(VaXO6o;{5@KgOi;$;4fU(|400rJ zRlifEN@@gFF>@e`!1{B5JlZN$&ga$-*8pr2hAR4TC%$dG0#4eoHe%|*Zovmj2UA&$ zTR-#BytZv@fuC+40YjDCd?4hf6DXS^OYn;6SmdI~J9lA_Yno72gSC+|(|rctN=?a!kK24F*{>U_L{Mi1Cx`-uW&7f9AlDW zyjheS6B3sT<4QwudBj;{h;}9&D#+QtDQBU9*q<}LU~DCOv7PB_v8@!;<>_8(XGRPb z{~YS_%oZ4;UglCv{fi#pfLXw zz=0x~i-p+M80sCHif!e3xgL#>c`%u*>kk=rrWRrnH!X4R^s+m1U@pU%yoT|2tj_~3 z+NXEsF#-`>vtf&Kqa}KYYY<&^8U0?4?{ju%=!853`bzKAAy`R6k4+`X=DI92ErGMST)98)eofJW+ zug{e<&6#o=S3l0j7y0-+A6-5Md-P8u*d5obLG}Jc&4IkdO{zN=@k@(G7j=d(S#2_T zp`UzTm|VRt%y6r;sO}Y%#ozyVRFB;r&ApKelc^1|1MCSbjGM2XKpNx*zq9@fxn%7f z(`Y;H6}{{JGnN&OC9~1?Ufyust_u$f*N0U}Z^^&N96!N_bo?jzB^=7n@QXZVkc9~2 zYh|i>ni1eFYwK>1Qf~OM4CWU~>0n$`)r0Kig64pJ4W%k+qD8ggToJ7b*Z`tvsLcW^ z1ja|W&WIJtV6ADe_<5O4ig8-|IUqrdy1RpUT^|g>TqtDP^%W`)PV1_RW!0tokB5-|KvMUsVSrt z;}k=SCi$%*epx6_wQ3f;SQ;gSb4p*Q0yc;5BKR)L)M0XCIfRR5_Ik;J`ZaEu>Nfz$ z2*<&d*yoN8BY@o%!f8YHxNSprSggt&@Jl=KYf!m@=9@bf`2=D zB7+;4K0QC`43nh-9~w>$_}kGzQ@@1t*Ev$;mDuyFA#pcRjy|lXkr6P(`tAv!9@xoD zABrsBY9^LVsITS#YZCh^cs=a*mY)KBq1(kB;90;UMe(A0KwpTM<&2=Cs8_|IB}lcM zno_k*^~*>OlIPsE*HJ;z@6dp(R};-G@CM3K8w+`H9TGqsHdJvY+qg;XFcWS?WC5EV zr8luxe}Vlc2UrA7U<6Fe=w~fM{SKl=#(~6wRP7Xg-p`_x9tY@EF>_Ebi6VhKZZl3E zmpqlkQ}41@;xT_=ycOm@)X#58e#SHZgZ=!77qt=`1~77P<3bp`my-;`$uKhkbg`ceeP2Q?v53GYN$c0$rhl2$_6_)Zb2lVQope#}%DGbQ<_CI2*1!G}6< zsSw!dcERv$7@kF5u8+AT^;t|(_wnPRUa=2Sb+m|OKv{6{0_2s8-r^az)`GI8T)V*!63f^Ce07A) zQOTfwiVs$>KhP5BNA@`F;gnmcD2Tdmz9-XY(rb^*)R!uw?=H{3L`oX!nV%A=$Q|>rF%R zn+eIb1+Ba$+li&J+0AM@!=iu1e;djs`hqO7%W&h(RYit{5)5c6v|IgeK%!~=m>&9g zv*D{>VJs(HIef4)fyGLK1UGF@^*00cUy%kAe_z0^Ob)f7+k;uPptlIr4Pn0KRK|5! zIDpOCT5I==6(prUUH^siR6UB|ZJz0yl=#oH#I@bilM6VOlpXlPT7`t*;7`ZWH2eL{bWR85p8mb%YB z6ZzYm)GjGc_FYn*JnEH&kERb!B=8a8B3mfY^CtUhkP5vex{sZL(=0rkha4z&x1nvS zEf=;^8*C~}fgWT{s0d+7=(M{hPIWd~j=+GnP?tO8is5NrnE`0@Qd((Ta@&V(F^kyU zwB4!KHs)!?nYPuOG3KpR$HbQ?iD zaZMbx(gQb#(Xd?7#@+h>ASv8sU_y@Ha2XZ^s)pEEMN4A+4RUATK+;$Xwbht*$D@=S zSnw8atFNQ5CZZI9RDZLAy}#8`=NQ)?I(h-1+DcJ?hFH@-xsd8seR^{#;*GX0lgw6b z@+J|#fnB&Ud5h7dI>&3G3Q8Y(lhbe0${_~w$qnoF!XKC-e?V`0of1qo!7WV_Ud-Ed zYp(bw)|~klkae(WrmrV0x_?-i!gof?nb03wH`L#txjK^q1 z_O=c6{l2k-Esum+h%n8|AjQpudYaoIdHiN(8-yGiA96u+sJBTgihjuhEujuy_A!_ z81oulZ*YJL@;m7}7)O;V_NiAqMWgyLt&)D2(fOQ07)0m}LiuQ=}}gBi!ho3{NT9 zw*UyLEv@8={0AdfQ9VN88K3tD*hWJ#C1ZL@gs}kI){N}}TfI$q-w)^o0!@)nlS4S) zTu++Td+OgI>MgQd-Me0BZ+H+G`dim6IwNj{_bqq#eU}cyrLd&FMOcpVK?Cgu$-=FTZkT2oF#!D6-aFCtHN zI1j#I#JLKCQNOM+8M$CH+#B*GTr_qEU-yC;} zBYm7e-mfBW2}Z9sAn&luKN2q7fFWtnr*UU8>JwTvohkhHN3uM1IGS0twAknTChmsVb5;i*d&y!;5(IvQGa91AWeQKco@NZgo7* zu8Zm&=7S3^4-z=qt6Ve3PY8B;18V~)Z$3m>JtDzWI2vKMb>dUtG5bALILjPUFQR1% z{cKYU{d4ETF?~{eV4spm|E+SJl_8c)GH`c9 znV#mr4%au7bB@*JadK6UDoY1@qk_X_PTzO>en~CraXSl!z2!=49D-Yeq_Z*;T>nTD zRECc%3*cBBoNQq4-^GNDkt&60he(U)!a(0z`tMo~X8#-9YK~@qODl)Le!XWcW`B1m zbSV4#Q3$E;TDNEaXe7KR`wbmVUbm*Q-^_{*!!m`D%Vob!kK8c*Ju6ipMLR@rGaWdU zt~I)#!j^Zr=GDoMMSm2Zxitv^dl5j>>Wk%=#uA2Z6}IX44s?iGjD`bZ3`=&u59=eB zsgqa$moycFi(0H|v}X}Vt7!aULOtYLI|&>XLkxswAynXmz_Y<|7j-(&8zLluHR!mJPdJZdVyBE8?a?J6$tFyLEPkF+aBF_ka1c>Stvfl3*6=j z5R_Mwga>z3#8|uDhz7N*y-0wPlv|k)`0ngs4ce{X2xp)VhniLk^mz%~9p^{>X4D1{{#x1Csd8~r+H(1pHflnO8rjm{SU1DTYNwS=ULmLQP zx2!{`-HgM3iy8v`eAh~vM)KR~F?9?b7*T6;x1jP~b_?q6;jLrh2MbmI+JI>&xO$muXNzGI@}>etchAStG0$T)iGfW@oVHY=rJufi~H*d^7gwmUjn{VZxxKaI!A z2VR@`z2E%yUp|V3a~#7Rm{?<}+L{OP2nV??z#Q+D^Xli>YLRAbho>R%2Kn(s85b1V zFvwQ>z~WecKLhwWihQDi`Y=(l->y2r;G=BOVk5&STq#L^MqFkq(-G_=(J~zBKNGsV zG+QOJ(P?|TQM^PH*P|9?`6Xie<0#5QOf;?l=jlHPYdLW7Ju@lzE^i&C9*gl|`ZugS z!dlUf@FVQaQ+SM^DTMTvQP&7F>5z!m51L$)YJ!2;o>hyuO&9t`+2I(abT1O3R`gg z_p!^=>`$*ZaE3OF8?6G{@<^&&U?RSfpB&Ek(x>Z$EBcG$gB$f({Ji~mh{ii|1_INl zoq-~)H{>rOo}dU;6jjNAaWOEox4*`my@*tA;qm9h68v!eLiYfXsL#y~DXRJkZe3?q z6JZjlAR+mK#7T_pdG!q2eHM>k>>1o3qI~hU5?_JJgk(7zj_^|$;lN^4IjN~T(Mwk> z6~x^u^-4kGbcE8ou|7Wn@E@=%cLS&_oNTRJ_ODGKeL;vWU2e4kwgeXvD_$b^eg+MW zj58HqB*B8uU&Y73#^t~s)q?XC39=8J11(A*njn$5Kcc8fGRPi^b$N*6B6P{ z+%^Pke-%WShWPC`Pyc}Nj^mz+1Xx3|@j29!hGG_^8^&`@${+FbTE0g*nbHYi1aT1L zYQl)&9tNX12}U(_9GX6m-SkrTPJLGPyf8za@HdY)fT3WTfO+9bU}x<~nrv_-5m74o<2UP{ z#82vvEFiquqY1ZSoK1{H?lM{%o6su#ag^ZQLSl}tg4Nl8laacDSQub&#Yk&8Fgj3b z`iL1qY-V7KV#A@PDra=>`P+eB3=dPl0?Pv)YKD(lK3HLp56^$paX12JKn6_4d15Zd zEy#}y!VMxM7j7EfI_eY11xgdgNeUc8NI_@Hs{wcx)`$1Iru4+#AQK(L+mtPGTz`9h!_VhGz4_t86Z!-Ab`S?1k8GH8ZN0-9sz^P_9%)=}q#UVtmR9AV_$5Rwe z;gsKoEm_hkd7UK$uaeV24P+&&L)<7U8U4+BWa4V09i(u>&l;{#D5v#${(I%hU2(Y3jGvu_8A&n8P~C)$@?pg>LGwlo(+Cn@chcruj{iVvw%H~XKVC7 zIj-b8#gtk`3qcxyf||gpl=L>1mm61ZLQ&2X(Nv6zuf<-&g&kr`Ns|n`4x=y)$y@+8 zTN$_$3U=R#54lLFK$Z3m+_hl0hRYS)#^bMJFT7?BUn^v(|BL+4>U(O5A6B!|fNO!a z{`~@CV0c&o@M@J0K{yVF+&MH!{O0bP#6ZA_4Qyw(OnNB(2vtg3Ro&6S5*=KLghm=1 zl@u8mjEuV2Z$Z9f)I|+^Q+1_iVezB`y!F>ma%3!N?BjsQSW^CoY*Md>S6|WRUd{R# z5+N`ojUWUZUHZ!*x7dn?b9;;Z+;9{a$!%?eBu9B)%784JJ0U+yDf;Dy=+8zUe?JXN z6NqrnWJ~nLN1Yiz`#M6nPR82CeAE8rkD<04xnYpx10e0Z^d}|c9W*5P^14=Xo6yAbO_4FGRD_|S=`&k!m^#;04tCAd6= zIs}(tF(dAdT!|$u13k%Is88eZra>|S2LU;4zYAFx0MobFX_&e{&!Jfeo%#CjG(ZEXwN#CPezjOyzhHoy@eQ0xGbo#5Kzy9{34RsG{ zxFxh_k>Cvvpmt0cTKwuykR@8{LJfD#Wf|!Q*vQLhygaV$PUjG*ap3py1!)(T(j_<6 zaYu)%7Mb=0GK0c_^}_XA2`D8&Tb`8!K4i!E5B#E|SUIQG*|E#`4#uEBhnF&r+iZFj z%p~+GLaA}YO)Ap~IVkIKuZzYe^)c3Tf{#8q4>G(KNBp3JtadvVpJ?fv45N@1sMCkg za*p$LJL;K%P|hEBY~HU>k`TE#5_a=F%TirHTu0WCKa`I>hM(*U5%&VWAd(gf_nKxw zQj)J4xY-oOW!&Wm0)s_47L1)f6EUPp-G&O5M$AvFmfGA?th$W{Y~>Jgs(eF!uEOUx z6nGZsWCq)Kl1@&Lt6@e#IxbL!(8D}5NiBwXK5k4M3&TZ;qRsiB!%jEdegnIwxcnWs6Kk+j-BVJtpuRNK-i7|N5Y*g&Gw#}3aK z-x;Q10?E1SRN=mpOIrBhz<|q3i2hDS1Dt9~pWXD-=b37$SJ-nI^=*8Ck@em6Xh435 zY0S5qw!enB6C8^YV5jWPtNzH`;_HTWqgxyv0oFLODUjV?hF+{JgliWUV3^~3xcpon}4E&c%{e1MIm*%+M`O1Z3J>_PrNE)3w#loV29E|?7AT;cpnL@=3i=7nqmzY>(_zRN!xZFT?Xd1yTz#mU#M4=p zcWceL(9PVl2zggN6Zb44Mx@>DK^$3v>`09S?@6P_?!dndQzrMGY}yCa=P;z@1TL6D z-3eWFVo|C~_bawo?>QVX=9&tU4N5ejkB98O)=H~$VSpbvD8pMzgz5<+4*}`N`!h$0 zq$20#S0B=>JC57H#=>5OVR7MTITym<1e2wc7jQQfyHb%H07?e~i?O?LYf|GD+hb|? zpVXZ5M68F=3#$HjWQJbX`XBldAQ`nLCXOhZIHGJOtal1&x#R|9hlE%C6uL2jb|b5d z%n?`FHQb@1eS|UHSuoFBT|hV@$95dkjtvo)#fle}6{H~vLXqjYL1qb9?XBZ5fbcC4 zFf|>*#SC(W1bR@2n?Z)iue0^Jklmf1h}eyOUawCwI?oE>`u6d>r6|doJz@;|c1;A(x`Savx~7#ks#Sty4w!wn#&a++r-o5;2yQ`4n@H>~&_6^ylbgxZ9D(ANq3-A7_xK?D z2+cNbsFtg!Tp z9PiwMN&%M%f&ZKm#RBj3noF7!(G6hM-*9E{@P4VkGR$|9IJ{rN;e9wkeaFP%eG`ZG zXU4LB9s9xbOnQ8Oe(%#}3MV2^QhqpYW`7DTXZJo($ffTyZ5W68t1W3|-!&h%aFG;p tqn#w$!8r=7!0bEblEfkf7gS$27a5}bpG2rQxv)Q-EKH{Fn91b!{(le~IcWd@ diff --git a/mace-bench/3rdparty/mace/mace/tools/__pycache__/scripts_utils.cpython-313.pyc b/mace-bench/3rdparty/mace/mace/tools/__pycache__/scripts_utils.cpython-313.pyc deleted file mode 100644 index 3a2e0c207675385de131d252027055bfe1a80b11..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 45514 zcmd7533yz`btZay-_eZ*&_Fk~#*W5H65s}|Ah>`42#RR7M2nPKO`rin0vph6KqBE- zrv1%~K{*LQCw_!%&j^Y&5#(r|;UwQPobj7TcAO}atuLEVx3A%t@%Vdw{^t8Kph%Cm zO!EFy_uk&n-0ZxUcS+o?Tg$0ar?yk4PMx}Ivza+OukT*D@b;@5_uuG4c`9Vb-8Xd{ z_gT)*`Bj4)Fa4@`75i56YWAz)HSAZ*YuT@k*Wp(^s2?!!1}~-23>pVaylKG9n+Ghs zWx&c?2W-4;pnxx6;i^G9?_g;Q`9k*VWWU%D(l3Wdm;B zJ>cOz1Lb@<^D_)q3{>)!%xxU38mQ*0z1%tV$U&u8>EAN0mh+oDoZq}%E%|Pj1Nj=i z<(zxua5}Ey@>|c<`E7nxx87gi;hxv>^;Mkm&%fVqN2udm!MVC~mUDF=8`rI-x4-b5 zb+a!Eah_}N7oDs07bAw5`IOu@uIpT*zw}&_zl?A=XFFFzY0fpCYmjnD=hyA>y9vAJ zReZDG#tg1fZ`0GLEt^NkQxBb$TYLuaM zp}gJQYJU@b^V|K+fV#u)JI9Sglv+yv`diR4TJiTdG&1_z80=lBlivG6ejM7hqlWl2 zR*Gw6R&k|&{T=8fO{~oW_inSlQ!e+Oe-TWw+ppfgg_U&gePM6SZ10SUk5~@0?QIP9 zzWc#O9d@x8`~BVMH``g?QAP;8e+LUY;NOY*J;+9-T^N;iE29$M>+eCxA^)C?e$a#- zxtFEu^Y26G;R%)R(2t0?!#27sJ7zXbuJ?V!&+; z249>DP01nlU~qhrkIqh^_>t*gkhc-c0#Iapa$=5WREmm-6^cZrXQfZX(jyjcq^O$c zEH9P>gHZ|`4TffC`N{Khzyh(#5aBNVIPp~Hh3Sjo&WoY>3!&-fCqrE}@t z>EnDjGJ2u2iyxZ_@w1nipY&q2*BKqjJx{%|h59jM!5>CJc75_-M zd`XHIAg0ft<^0BWlqI4*Yh{JPYotJNrGG(1wNf%KH-QpMr5?|4`Aoy}YVTQ|8ic4H zo1B@R3eVfTqth2>rqPRLqdnevbsPSAyg&MeJ0G$PGUsTz7W0~ z6*ZHQSy3ICK_8mCC_3cu%fXk!lM@$aqoQGIdSYTSG9hXwBjeLP9bbnG`FhILj39J5 zoPLQDzBAo|+y7M;)l=D_pYJySRVZ87r*6GLS5( zSUPq6=|oA3P|}hp=@d#jKPcI{U`kqwmyB_XH)(Y(4K9D|5B*n1epE&=_^$pki`b2+F=nux8n-uz=|Dg8;IXZqQI#TrJ2oy zqx3$i?x&yf<7&jR9aq0PaGPbxzf8dnQ$4U0GjOv1ojGntvfUHpz=-HsS_nFl&is zHtb670?9qjhuH8e7NVAe(aiZqERrfHID2^}oT_INjoA_iZ8RJn6N_F5hi8J5(dnts zYh#5^Xu1#%jj?JT&7672%|@bOj1HgW&dPn3!3Rsb?wm5o;$GvEi=(fA%e0#DTxU^q3Ja zD}4~-*w;C34kd{=BSjAz&HU;>Q32PL(Sy&XXf=}R_c9UigWg^g&;_)9O~4p1`L(mNC{gh8a)II$H3vbH;*r9f$oR|{-FM2Ra79W3u7G0?>L~==I=O%& z;F9L~K%wL!A2x?zq|1E*r5P9?{yz0SY0;*q5jd)eGXCSlZfWd0TcV^7l=uw+r(aKF zAVzHC+0s-%pz$;pH;Q==NLhQ^pidIc3B&O>;>Gi~_$cql`ARgo23sWPe7>Bd#Xr z!P2}S1EX|FJvkBaS<9V?pPAV%bC;}~B@uT91|@!b8iNdfx8ITWN5Am|tf?CKL{<%0 z6I7(L*vnlxhY5uZze46)0LobnMbY^OD z)=Ml8f01Rw5~t_f9xtsIy+|;HMWT0l+)FcY)H^$!LNbj2OPijfACNX(th}j}S-r=5 zRG!}{XN4-P{Di?=G|bA90c~Cmo|`xuG=~e}*$ZKw=trCnXe4nyUdH*LDR>dAG%*Gy#S8U# zhq1(ERe>nv%`mt)H#Ixi&OWj54NpbQ<`K4vlba$FClD=Z-48-Y4U`&byEk{4?R|L? z&5e~wZu}YTKC3#S3^xf%b~RE#CP%$t$QUND0Pr@W^-?qNqH$nd?G3#Ynw$!qp9=fT z!}F!o;iz^}t&93nkM}fWBqnmQmc;KtU*>!85=+CE5Os8xaf(=kE7|b-DarwQ5i9s3 z@EqK{?KH^?SVs%bdPP;|yz?|x_EX+W|D#xK^isN1Cj24#Bv|wICrQ-6x+dy;);nJ; z0m$~A-#a{=F;=jcSgE`;ITjvklREiqRv~E!L0v~@QHdS%Zn;tU*=Uf7+kAOeP)@F< z(=;40vM$sY5@BU9mvSo(P9&1fWZ7`o81<~{OASyqy}Ff=tKV=(k}UJkyVM=i7H{aeDsN2dNmf?ljeLqLm(G4r9Ox8D(!&QC9W~em_ z#;OIwV&#T|GkFq5uVC~pSKcyq{!bL4C|TF?*1k9P{qBM5_C>>zD_P`;)$WQH?Mm+0 z6)SF8*(nrvC3oypoVOi~%P&JhmvHP596MHDy5-oPB(J!mb9JlW=>A(r=N%hY-*pdJ zRW3cVa&fJ9E%dd%we9b!{-o~R7vhITV!i%Y+vyLCfjjM7am&4JoX#C{x5U~G-qQ90 zA!E^csdu^jgVOe^#};}QX4k9h64hI8R&Pyus*?3PleJyR=Juqg?v6$8EfA~~i1RTJ z(N>kP)(h79sN|ab*rcTFuJxUzVldYM}MsP`2WOtVXA5qRaVm5$QI=n$vQ#Flb%M2SdFYmP>kq8n<%(7J+SavGU+Y?HSbOnLir#(to7J&{ zr(&HWA6Wf&O1QH2dmhePpQzq`vwC}M$I*E8v6!&}*u*N^W1R;4ToQSm#$A&M&CO-yg(8pCrNgVXN>KMs}`qhK0 z`YC)u%?Q6?=_0_*DTl)+FjKO?tiH*Z8$tR1sA*WV;b{ubN;>u5r~kw@*{ejjMe6`EK(uO#jlT z%UIDgyGpA72t+T^0)h4xydi!9(_duPJm22OSlH|Y#01HrGBA;&(>yHYY>;=&TVLEbYG|V=W zq{-Lta)>7$2%AM_@J%w8%_cxm6^>G+VO*C6@KPGUr9x)O?ReVP{09D_JYHA04WmKV zxn5P1sOq>`)v+=kuj+{z%h&Dhn5X@n*@V4Yuy@Cd^xm*?R*bq;KfPz+tR;6jEz)7Y zwZCdl7VN=T))m`79J8LhN6il7wAZ&%7pv;HWh_rxs*vm#8y2Mad9)bWezJWV_xtAF z2F=&2dcB(WygE2X;h%bFA<5bxhn7ig(2rfLK`w&bf_(cCezjk-8#NzLjmQ#;6eKuC zs72>I`Sv0pWq(+!(^9L;l+*>V;zavR$>h;~>(G8R4~3))XtE)p)goQ5Wq4Tt#A3*e=BG)offvZW}v^ZDYwIkZ18ipQ`JrB*yS3ue%Jgy+Gwhek)k zQPv(Qln>8LM<;0s+v63r#L`35ubBW7rX4|$sHHWCsF#3<`U~Mp(F-Bg2Ki$kB>y~K zJ}plzNYoNAevZ5^;uWO|Kr57+A@&EbUl0UnG)6*+zl3msB2MFl_NO-KT;DM|Z`+;M zE?vE}TEB29X5YD>O%|5ia9?*P3TuSInnYocjb5&gsR3+P8GbCAF%6p`e) zE<c^V-cJBlVm8{8H@3%^?bA@}8e@3+K0? zM%p&Vd{B)9RsMo3jA;>-^NeiC@xbTEw#|MDMu2GtCix(BZ%#ErOSPQmAs9OVotAYu z&j<+$avtDPh}g8?%X#D$l;caENaMUu8mTt&rzt|3$LtTtsW^-Jv>eTOGN@YeuvVPa zYrHvt@F};ZgPTJ}4sG^Ra#ntUt?Tj2DEE2vm4JczBA}H1(onkCm~{W~;PB>Bj2zwU z=P#o^C{v?6e3T_UYoL-{4SMf{#%B?NmxZjeFrMM!- zYVMeN_`j!CA~nNefqyhK74~149G^YW+jm6Ng^5|4uRB4~UlO`ADXTX-mUlCog02WH z8Z6nUB@mke_TvoL4@`9vq}(GZ=g^lHU%FJ7KEPqa#4s#c4uLxiPaO(HCn32z&pt%k zvCtfN#!y7^HcrT&P_()bK7V=iLii;q#02O0%TatY>(*pZY0^@du+$5d`h=xJuyiCWI|a+m*q))dWr#JKJfUj`;+BDA zksAw(guPj?Hz(}d1pBsxeUD(@6Ep5%%}LbJ5((N)W76`T??e7Sg4)zSzYT?bxS5Bj zxwIwBK1CfLycoJ9njuaP%KfW^+)%^19KkjVI?L$X?DY6JG(ba_g5jwt(T z%n=+z4F1wYkRU}!wd?#e9~BGaIPg6m#h(_%EjLi{kT+z!H zFNUGHHcB!Tv^-WRAavH3LD7LL=v5R-+3H(#W<(2)g(vuMn7=_eIAm&tLC!JSJrhk* z;{dJ10@bkNQFI2ok5CDo3Sw=3aXK!T6tyyIW^3U8BP%n>vqDe< zlUhL#B^GTFREp$3IeQtJw-|Yu?y+h-MSW#g4!QDQBtV+cnK@C11^{i=H_5|}!o3un zol)dbqWVJAXXY;>jQ_XvTBO&1qSyaUFRBvJ!3O!@#VKfv$;~qD=6{#s(aIu9^+nPI z>5)WR?)wfC{9hsZ-%+{*R_&M}4AxJ5{1YEf7^($Bb*!!{Zs zVHe_BY6MG-wBZrAG_P!pTiWhvIFsd?=~dHj*DZ!`M6O4cM+8@M!nH$i?T9;f-ZJjG zZ7ZO?j*@#&@hiUJ{)`(EvqV+5P}RNqWV~wcpJ?7a@=dc)H4v*De7`GJellM8#Fc>! zwN5vp`hnHHsJl^cy=GQi5{@H+^JRvxqh+9v^P6q^Q0DCK@ z;tg|kfj()pFP7XWzg{khR^o+q8ya}uHWvT(=u$(%vqkW1NqBk$Pfx;gT<{!^y9PFN zh=5X>yvrqtnw>(;&O}YGP}7^J=@V-D5;groO@F**V8cMsjGW22xIIzSCKRT`yQcu&xVBptupK`p{8pUpm!u{4hl5~H*6GGz?sUI zjwC9!2^HHC6+4BBo$-p@8+Ho9R!Zp?G!(c+sMr!O+qzLm-cHU`_66tb?$5d3>{_0C z>ym6teg6I4jS`CE;!Jjw&4D?gL9jO@>@9-5C2nt9c`SW~VQthEVin_z8ASa%549SQ3J!Fm8YHf83Gc8a?VR9FlrT+L{s zD}{oqHR0MLxb`Gm`vuqjwHM!Y#$1Otx)AeTH;1m!5cBPgd-i?1YU^r4Vq3qktv_CM zJXuxu`s8bqpMN1y)hSeU#;dmG@hNnDxDQ~i9Njpm;;haqCw_YG2ppcfoC@ta=_+Sk z<&@xRlJZWt;6C+f%Z7%+e;#$BgMV}H-jiM2-*mM=OYel(&@8I{<` zM#ncGrxmrdN6yC(GfG^6A{MnOky-F9gfxQ7jFP|#OpK*$-DtaJe$~8qN-%DFw*osk zw0Wb}=_TI7O1wvKb~;FUCs~Cw0st%I*Y(+jq$NQvV;jN@dX{3d2h5~aai@Qkbr~2~ z&@BY{;#ndjd;F=rmoOynqgW|HR%S|tLeiPJ3VaasqLj|5l3zr9vI4t9d2!0C*{4dU zJe!hN2GToISoq|_Ai(rcJ5|u8RiOZsBQ>v_C!a)9N;yJ`6FCp`O=-!4XM}i^oF^qi zD9;JNXDTd@+@~u;4OzU!@;xhxLZ3gm{tZ;>$m9uV1zz$ zp0xfclRYadu8dgFM09WA0@}l>9jXYHt0S!$fwEI6MZ^GIN@Kt<(tiIqMcgq}|3@Wl z%9+pkp|?xALe#F${0`1ss41zIJ9CwgJ%PG~+>^x^Lx1u+9vHg=F$$#^$5f$1XiMgN z(i%xh`KbMlB+?PfnbvlYFk(IP7ldEL25F{*w*h$ng};bm$sxsZ2xLjfe=DNrJpPhL z&085S{93wJDPI-+$Vwd|UH6YXN4?qc*jE6*rUkQ(ZZo526DVRvV-R!*2BjEE3 z;8QB&bMbz3c~q)A2K{8L(XJEX!ZPJ6iC3%DkXu>ER?OUez zE3qGn-Xc9nTauI3z|Nf+(>wS1U_guatH`1cOADMnih4s{rX0z1BS};y3BgGBGUTOg z%_&)pl1;we^Ov#3c@bKBq>cy;%NZPoQ}hQ3M3#cDl%zRZ$mWCMA`;#&Q(E!>ue4UD zJQ;hNmBXKBan9)(848_R%u@E0RYp6!$H$pMWgF8Vg)%U*WCt4|_=~hZahj=y0@*Ar z&}>9|yuIE?_+@VfQJCiEbOd?`1TU?hKqUvP&}1|$E0jJ-Dr;=*s>_@wt!`L_Dr1NsL<4T$(Cv52Bs+YLW>Y$lkB(?ep z)lRffoW3>8%mx{JNmQv3)g4>-rvOz{UC7p;Fzubd;`h=FzhA6jtqB^eQrfg^G8F}y za9p$kV15J!E>`Zp*>doW`sJE;d*64zH~f0$CLwPxN9-?XSQ!ZcWyd&VSkPV-?M$aN z9o<(T2{QhcQW>e?1SiKXF--@oHK=LJN;Z&aK$gLTM~XzlgnlUrv-uxT@&fu!g9ZUG zE{Hm)a)d65g`;Tw5$Jpeaom!%P|<|A{4C`s8lXu_j%Wb@$-)r%N0NyYWVcZMF5wIIe zU?|_vYD`B|>!syFx&IyC>XEMwer51h%fWcLKVIs;+P`7cm`&j5Fvi9|}ubo{f!>Tn_0guH}LF>M4v`MO#3F8*QxMj6W zF!p@UQ(=+hUtN31}w+@`h!l_l<%-tXmC# z^@Xpz5Z`u4@E(d;E0bkjp{#r1ICNG8SJ%R^q}{uGIBsu-uIJiTp>eO^*}G_fEX2{W zayo9`nrv!&>tkOEOK`bVloj-KCuQ&acZ$rlw zR>vLnsB>q>YH_@9yPP&z>>{zs>w~Wi#><*xEj^^!>HNO4Xz|SFUR>KRLA$z znlmCS%`Kl>9*%80601BKcOHBHRLnZeVxENLp%cC8j~a%mxc6$xhxT%RS!EvDt@+Dl z>rl7mFS~W*-fKQluK6pyYbnRCmU!x8BljOIk{{k{a4>oztK3JS0oa8L z3cHZOW~S?)=7vC3$Qr^y zxE|IJpeKqhMtw-$mW8zC!9u`%4Fdu}Sz;u^k~Y~N4$Fv)bDj)X+J6DOm<1M;QG)fW z;l*HO%JHOAp2~m%%ymsD&Dm6gN4|P0Lq-fei-EPO^dN>ICx&4Y3}a3V zrKnOe8}TZx^e@Y~Cuq{1&?~iU#F;rXzMp zEptLLV=`Kz35h0BOi0d(8=coXg}q8@3i1My5i9gzcRRJ5VU3M-n!+1m0uI;_r|^Z;AaKzyWYBG z)f3z67q*^`w+5~Z#vDz8wt2l0XRTI_V`DX5xf3x>g4VO%=v!%tZ9gWo^v4^IQ%qW3 zt=BdzKfhWfG#-rC_9AA1psil7Yj|t-@9bWAdd)7h_s8oHdC@Ltz3blEw@kldT4`B5 zC$t=jd;1g|n%Y*Lj_n#0+K1vzCkS$#psio`RKMQ*T64nFE_m7#o^66>8{+E)ZRvVL z(_3eL=gi9UvE2hg$6&l+h|<&u+S+7&2aah8^#{qrD`;!hYwDMuUbP7g`{OkSklY5Q zFDiJn98)!+LW&A#_6M7>+lee6a6r=$vU?5?eyj73HCJ&)AY;w8PKhfCpLK z?0~^7uyyj2SA&a2rkv-Dv;(Pu7?!wR_T)Rng4Kf-tB2H3_;ABf{w%Xu%8<#{8d@k- zrnpZfS4B#w;OAjW@`T-5nREdddAB+`kdS1|mK z5dkN~zeVxb5~e644=Fz!zABK`HGqx%WSoYvOW#aBbUAc!D%V_90SifuLOfOZU*IqL z3FL#}$Jf3#CV2W{XznGR8y~s;ksHrme>Q0}t+k<_yXvvM>(XBtMn^BLO1Ro?y4vn= z>Qeo>x%SQOw|agDM+cE z+|nLQT099$rC_OCnu=RmW7<}!ZBkoQ9%4gT+dPLivCPL*=MRCAQ^ND=7Sh^OsXf#$ zmv1Giv;&#Bh8oX9VWu{cz1Tv=RxIt9JiYiqMItYVmpr={- z;X=_zz{wE6QIPaO5O3x>45oLP5-vED z36}Ab##M1C=@sgZp*;nR5zR?8z0cMui36Im^@>};GL=TIanwGrPBZZ^Yb4VytirYu zji>C9x0SRRF)>!*=#DO9TS<$pDHs#%)uy+mqG~V9O+^e5{Ww>{wWV4F>G+N#teW%U z-&4RrIxhN|gm{)yv?agbigrpA$!O^qX4qCO_$Cxao%CKU~L$^lXu95)?E zg%ts8K*4V*g$M+c+WAA=Smw7aJx^ZWaz=cdLO z4#*3zuuUTYhdqg?Wb7l28ktg=n0zT5Y4b8|b?>~ph zAVQ-fjCb(!y1Ci$_FYg7Ho|`9G{w!@?xFdw3x9VpAZ2d%lwd1Sm0`XukFhv;+ox3C z_TgyT=WP|~>>L3seiARZzfJD{g;ypo$$tkPVj~@#mDN!d<;Of5K0ii_iPj(iZTM_( znrheg8xAcUfy$q|Zh8AV^@*}h*dY)b4*l$|7O_5AS*@C{CUU0t%_bu_?>jj*bvZ@J zgA{>u{!{xsSuKJYF3LW0j7F2$K72*eDb1}^5=?wAql#!{PA2?V%?9O~4I(FYj8n9Z zz&yp|MQIz5>SHWQTXLcnb{`|W9mSO9)&YQ#xcL-wj^oQyQ~Y@fNDM?jYgY0NkaXu zn4>FJ+QqCJ9JtwZAX!|`B;a5THx5Hr4K^sMH@0(*qHDi)_1Bh;CG7Qry?)(Mdc*n| zYpiwyzbgw|0Jc_m_6Bp8CDLvEt_2J;RBf(?ZYb zTRl%A|2CJ{^W@tDD}9i>_kG^D-0(TO(DNjigAyNDfi~AqbewC)#uhpA23sL`{fu7_I&+tqP$lq?~N7H8|ss+fcFfXy*z0hM#=4F@FynoPw(x9#f`h1#^zpk)h1mq zM(||p^i#2?g0bS~5UtdCyR-~PZ42EHnIqCa-gPR7-zwRz;Qotyg^d##4J$?BRC-;rYmLolyZyL&ubQ`}}ts!@N zZ|PAh_bsRS$X3m_DjY}JG~a5|k-J-u;BWPqkD3gBYBe8i)%>Z;akNqMr;R$eN2yCH z4-Gb$vbeji;GL-$=vQG9#4hvwCOT;OV32ySTAXE1O?L{a z87$0Q5cBNUDKNlB*C=)F*XK_gF$4^LLl4>r>AArzo$7}CFr#(XDp1nDXx%2nHrulp z<$wV!92h$3$)|ORbe4c2o4uw~K5Mqx2LWMxz%Dp4)IK1%oh}@aGsygrXMYsL!jU0l zn*DhX<^wXFke&i_jOtDwH~M1H{$SGg;gb4OWaQd$Xy6Qs<|NpBqxt zVX!-uLz%Z#>=*~%4k~RbG~Wr|QYsJHg3Riui_GHjz*=j<>VpZK)#AUkZvU?pa=APe zs`=htnR(1p8fgDWdO(UzHN{MxM0DA7dsb;vZQx9HX}!b?VvvMy?nCRx6QD-^P`QoN zZlVOg4TXIG9Xvz_TCqqhpSIro{p!|53l143>}`U*?XHfqwJZ#RE+1HZVs)#m?^D>K zx>3}k;=fBN$MBkO2t?+h5ErM|>^ziWcHFj(&sRQxny}&KfN#Pp=HIlhv-osaKBL5)-`zYYg9=M84TeCoZSxcOAVaO$Svl+6FdN|+sx zin>I_zMB>M)~4bWgE6C<8QU*USjyMkRj*fkt^#T?ZHbosLd*Vm6oZ4ZAkAtymCsnsC!*_LR zC^mq5bC@Aps;R%jq22y{b?aRneqqF?#vC(N;&6_!a_Pmm!5gdJ4Q8_b01=K^_UVds zAAi3t)|EExW9F#3eCe_rRXrPdh?B+1#2nH6ecP!w#CK*sXLZWJJdFm*S$c9Vr=- z;+gR)pP=?;<(h{9VY%jmuZ4LcMMjkYS8>W`vdN4JTC?=zGt}~Ur9xS0TPNr7>vuzn zEX&Js!sru>0EXI7HZGWPbDzh?1tW6q<2OD87Zby!H;XQq!*ZX030y#}JX@`h@Nv9X-_;|TKG0iDF4ud6?Zt5>0Ed zgYUITi4~lRWopR9mI9AL7cEs%{wXT+IC7?KZOW6eLt@!eGEI*#dLCtD9HQVrQvVx$UNN8_ahmJZ*xfZS8Pk@363U7+Y0E<^{vX z1&Gaxt_*L0xi=SFb3@tpV%${6(zV@mwykw1_M8y#-#LuP&&iR4smPOYQ#Fgc`=)bu zDwDLMyjr>?>5f6mvO6b3@TVYLh?^udcinXEN+(qCP5CGgo{%9NOF@{3n<^O&`))e- zrQ+vtlvO0knuW6FJ0=}NhD3%`_T5U+Ux=G(7^J3~&Zay7Qb^@O+93gIl>x#cXS$tL9>*uMnQmxaoY;jrzah3 zz@ZuEKIxOJUfM)VO%UmLqqSf`Qu+DmG-)2@(yybFc(|_wF_kT@05_G&@#M+;(DOs< z`8wDzp#ESLRLD}WRhbWGY=;^1Lx$ufC8_gbOBeTdG3`0TPRSJ0W%rxelnPl@T4B|n zx|;-NfAgo3*p5Y{N->Z5q>yaF&BM|GY94kBzIm7{zYPdvri%=(@L%e)cFJ@HFre7Xr8CpR?tSQ2numVq#ktR8)2}`k0(yMezX<)vI^E;8Ke9EK z2cyTJpZ7lWD`)hBzF|iHZQNY;$EKfmGYnv$XaZye`YW*X`qg{&FGJN|u~nXYDT8J$ z8AXzvk4H0qC4>|II&`@I7?!f5>gCSTF$4CT?NnT8C9kuz7tNjk*}cs(BI0y(mA`s7 z&8c}t=K!qpp&jVYjI!H(YKD)_xA|wG`QSbM@-#NedECcC@+BOcon=}Ov_U?35f`tL z;qLI*X}%A+|A>}fKOzQUeg{(Xv)(ieMJ5dmGdhBcRSH?)B0(iA8;@Ryws|K25;tn> z_p-y0{2(Uu`OZE>WSj9Yv^AOukDX|hRw+2niuDe0&>j2`Qt>A!X<<6==Fm8fJ$mOW zy;5F2X;DNeyo!b_M8Qs|2(Yz_Y^&TW`N^5Q%$7IgSJdY9X0P_(S;F)Np#H&(hkNNV z55(n95m{lVj-a-l}A}*GB1g&ikFjOWfN5r zwLecFpGg6dNGPY`&!_w);wb)D(QrI1^!TUgMO(By&Hg;tE0|)G7G_j(MlF7vBGEcc zgnl7gqR9y3`c1IFBr?s?j4;rO4I%s`qzR)tMR$M~#cB;KcIi^XO$#izpN!R;*+N&ADi;nfIgr5jWJZL9Dtn);aY4v$3Ibv2&5wbo3`UCobK^KtCJU zER(KeIHvvxeLbyyR*fi6t7quC2Ks^LN7OIU&x`8A8vGp59M>Sq5zRRbqRcgrytuvEv{-{#_{D9Fk6yRArIViA4r#tR)1Ng{r{+t zeT3y7QN)|XqAY(-C1 z*>Tg|vAJw%pNthfYq&6d#~>j7*9%}dE?!W(?kvUOY1jk3vFG}p zq^CStQO{&EI8kUXx)!+_Njj?T6qzmhD}Au^X0`$-mE>@1hDt5CQGpa``Ok5iGy8djOc37*b2_5QGb7Mgr{c9ho6IUqYFmx3*I!!bY@|^0>&8-dCaQM{ z)w|-=JuJ9B;RRE*H5F+Ki`1Q{-Xm1+iC6E7d-i9-s6!U7cf8h-@N@~DuDEA=61Q4q zL&y_F#-Y^%g#bv79oKeV-I;V0->_b{CNb_^zIu77Ay(kMv#+{Df8_)-$B0|j%Gkii zUj?<%>QkpH?DF{gBo>(ZOw&PUp2W$DLTQ}ENdX=DvnJmd zC1YzyWBZU2=NXwcml!7wQo=yl@K)w2htye;!2M4*=p(vusOb!gy3knY;>)5{SwfAI z-C?Q~Y_2(3STHka{-e)jem!3&@8iEC7JXoWKsO`YVjzFDr(8gXOz}d)SM;241gwlX;f6bAgW#x3#0)<887l^ zAK7SBN$o!HbgK04LB}W^BVYpxcEYfBS}KJG&3du>#^CkAL~+ASoF6UTv0}q(!L)&$ zlM*s6Q%YMWRadK)_Qy>=nDRB1e%HDan`v#k6RrD&*8R6y58iV7Z?+zMqkDO)WIyWj z^-DFMs~1`iE))oEf3mnSQM~JB@h+iwU~NBM01gYp1ROgkI9tYYufJJbFBETI{uo{h zrlf2PrI2%!L-E>Fy6&vpP;-UDYSJ4jxLxFe%C)VORIW{>|9BUC-f2~|9;JIvt3TAn zy;o`NGwa`L*W>+-Hglgr^Nk(WKDFkXY8~8J>UyO7!Ul;(5D41T_>`V}vK=)4jN@du zDTy0YM_~K}2PA1_%bpAyo^*UFkMWXp2k+mdFjX1s$V0T2m3h?pcyRRc!H2;VJBC@e7v*6vV03neUk9ll?d;@XHNZZC4iW8{ zZrYMMQ_paL0D)9HAAaB~pUMLA@iKB}z2IA$<-tCf9FH1D&I83FotHcF;s~gh`{b4H zGOG_0G*r#kOZV7Ho82(?7a8*&n#528DHw*3sp*LcY;m)(Ycvxy4QBM*rP_NI?d=@L zI%&aaJtt`{ z>!iaq#3jQ-JG(kI7=$ov4u-<`--0hJw`g&B$GE7PW+q@{do5RZsyh}#Yj##%#)}$a)&_}Rpmqud z&rO-q{u_@K4#$(cm*K#B;Sz1~q~w>yB?62(cZ&Ht2MKc^QRg%T|T{*%LU3&CZ`u!4`HPAoa}Bqo9v zaNh}FfVY5JE$*wvC5_a?M4J>vLC-3EN=S-k>5Ex|Gf8eBD{A<;h-iq;;q)JWS((LF z2njj4q<|}$NdJ~K3;rJoPdZA)=z`lCln`dAC(p!Y$!Hh+R{{GT(VAvnRKC{osi9B8 zoZ8M?+Fjq#I&K%0Tp3z-6fd6sE%TM*w_VjYrms(b;A%=+_ZfNp>DQi4Sejw~o(#X8 zd+prkpItcq>ESy%uC)1{QLAhGf>u(DxQE*#8^J%jJU5>H%+t#zp=95cVX&k)NOj9z zzo1E0G`u#vV2vBycQl-%;Rjk%LR)-OTMXtk0W$@TS{BzU*kN6*@LKR{Fo}~&PS=f& z>m46j^tj}-fUDd3spFqGp3s&E+LFZ!H?`HQm5N4?7V|w0xap&F3UbEBpLE6^{N<`!T3t~HIq&SKcSGXUvwtuZnct0d8=M*J+K`5v!3^Q z-XBl&1>(kt#;iEP#UXCHEKvb!E$bN5zg|BpJnvg*Z% zTv)yEcxMj0(iyD>i&~U9TM<1dw5OY9A8h;!`*f7ls?GbrJm^C%eI|7_V9M&QT94Kz}4>r??db6)IH{F8zM#8(f&u^;mY&ZB8YXp17} z31}Zh*aJJP^)h!tLnC_v8ii&!B#wzs@ug*2swY9*gq=^CblHk= z(kgNDiF2}Vp;YqExMMGMydH+#g5+52d&S5;A`E5|? z0qGEbr_8rLo&VO#d=yY}i&(Jig;+5638bvh!*2Ln0cM*jz${sv!1zRzI}wn6>Eu;< z%-H0j#iSL50K{{qVD9zBN%=5ZwsPZDJbR=uw=8oelfnV=|MLK29idmZB zDIa5M@~N^GHL3C*0*H+AmfXL*Ic0+kmq?ybHohE1;q#V9F=|B*T%x)f@KnlY}vmTHB6hTxMKNa?V*2>>AFgP~v1EZNm1H{KnO0f&6Oi9Vh|qoHR`;=ZL4U&1FIf8JlwZ@u%l7|~ZdJQ{ zTEMk(?$)pAgtGl{c7M>iv*foLt_(6$hfTPy*3P>I zuAH!JI;?^$!deD1H9z~Xl5_9-+4sx#|Lj9E=h+jbLHdshjG_vK6&E zoOL{i_#$|6ci%ysW57|5;3M~ewq?o-F%*PSepp~}Qz_s4Fytdf7?ab>vwVk!@}1Al zHyfUD=G``d+COlip@luTVlZ!s@XfOU;@9p5e~^h~ zdbafI@?#9$Gkw|&OISr0OxliInt^F|tja=DLnBZuVY(o?i4aydVxe^CaA@RY7D=f- zn|sE3ovi|cVsS>}?ofQ&(4M1TMGe4Qg}d_gs@JL(w7*_(8-jXRqQe0|lkJwVf`oex zrg9O)onq^fo#=#j7xXd`-lKx|D9k`7S`P`WhmfM?7@hU2DOfoCx7JFeEO1_X_Uf~- z>TPk`ww$B`f_DH(lcnW1Ucntv?#hIFi{PfatZD`UL@NWqDy^tGR*1 z`Dk{lXpTjiR2NT2@=G^oJ{HU~P7=crOo*xU!;%1+8VzI)I;OeAS_<7OAYBg_*G`l$cd9IgYG|PtHbgOJ5?I4pi8RIOyUvo zS%uPQiwcF+Tc#I^O$v}XP2H3WD@mOAv6LJ&ao}$#GdU(LOp0$~m^=!Y0t@3ouO0(njabOD_H& z0DG8f;23fO*I3}V_WY~QFO|m&8m=7sPGQ-CZoSxbA&U`9+&}bSSCXR8lLyXd90XC3WINNv-@w zw2uKLwboZsE5DEvJUT9^b-t2X`Nd|hI-nuRZ$Qi5I`-DnTRs>~d?92j(v$Dx21$(+ zc<6y7!!V>z&IAlf7*^o23?nW7pc6abpG=wj(*$~QK90?A-UB2;tW8Z=pm1_@Ufa@f zVf(xu@3C?AJ9TmF{Jg2916wJ~scY$g58FtYZ^8vM*hQdyR2;@+GQPCoG4wwO(x>^$ z-X=bQEPT3o4Z+q)pmiM#oG}GhOK?x)-dAeGqJEMj(z2xqw(QhSOiiCpcXnKR!b6dT z@lZ_kh{5%LARDPCG5IXJ#Q4H?$@C}mPoR%YrINpyA&IH6^YUjQ1+zk)AIppZq~uR6 zE1(hBdF4l1Sr$oy1!YM9Ui+1lFgU05?@MN#JW75J8H(V6n zwdibx_Z*zN{B`r^%*(qzC~I5Pt~*K=$No197JlybA2_RGHU00Gy>E_<_+!Im&vc;cIQm%xf|6mQnq-b^N(e8{B zGmJc1q8C9vNOMX+rEHmLMI*L8pav8Ti>6@=%h4I=9Wdbrc2pz|1G5Xe?*Mv4*k2J% zGeJ3M(4Q!~^S00`S+e}*5~aM4lU5&nOn8S)eGgH;!ZW0+ZT*u$kQDQO>z zW>IW^p%G6_o|pUJ?ByBSGzD0Wr=gLjI^tox6LEPhm6ZvtwC9nOA0^KSy)My9M+DHHU^d#N zo1<=#_g#7k^!g`ynF&ZSy(a0?N3K44{S&>$=*3ovrjyE1Q>IAIRLGdx2iqiLyB|yj z#h5zASFj<2r}mJ;{E5cBlI8gQyaS=gf{VU`CWG5&RI0lLoXYVX&iozD`5n&k9j@Rz zoawt<+4s1f1lRN5$kX~AuI#&9(f7C=Kjaz&uHm~}<99h9db!$`P`Lz^3p=8!tuR2T z+Im+5i@_W0(`r;zF4re&w+pq~H#mH}>t{djpT0+58?{_(XL9RyY^UMgp=3$bosto? zN}nt!+|ZEYHim`KrS4_TYr7XCf}>$WN506k(7jZ$6kR^NQuN00rB?)Z2X05&Ft9)) zR8KZc%xUJznl~)WY2`e%%k{5SZP?gn0p}=N(k+cHcduyP*u4}H%G$B9ykTd74$e}r zQOKN5&QrBf#GJ*jvbe#}gJ!PY60?*qomn1R=~~sT^azbRgvy;U%T7VJ>&K-O{ln+A zR#o+fRc4iI!*fKXaxMqnI`hVvw}Nj3H#quUD_$FYw`=Xf*Q)NZ?>kR`VEV;AqM&{` z%6?YozIy2^m%jS(uYCN+HPr@$5YvzBkUP99f9aSNqo0y4C1v^V`pEXz)!`KKqHY37tpKdG6`p zv(Z?t+PZi_y1M!vhmVcDRGU(zu7ZU4qKf~T`yM$zv^rE}KWX%;w%pb(l=9R8D+BY~jR}bH#pS!zsD(S+R{|C1hr>g(} diff --git a/mace-bench/3rdparty/mace/mace/tools/__pycache__/torch_tools.cpython-310.pyc b/mace-bench/3rdparty/mace/mace/tools/__pycache__/torch_tools.cpython-310.pyc deleted file mode 100644 index 8020bd538330553a86befe7e659f0c6835a65c58..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 4686 zcmZ`+&2QYs73U1e{h&TA%d(|7Hf^VkH?id;PSaLZ45xCECaqPsiUV|;bW3q&wOsF# zo8hc%EmlAhr)W?hz4VwCkPbcQ&}09V9(XHy(m6K+Mf!U~YQ2)vU2w>G@6DT;H}C!4 zn`COrv+(=JKMz{>PFdD}s5AZ}?rl8j@0n!@OR$brr#J7gF0XS&-*(6DI(4V()?L$Q zol?DQWGnRw-cG05_39q8q$4L}r9Ii?i&lL~xb!Vba$1xPwInK{`h`=U5gz)p?Ngvn zHytq{CPAGOQ*ug7gL4}6sb)#ch*{9*#VOJt7$oP!X;99Hd84^3%k2eF&xi$3&x*5h z!pw9Uv~yw+v~!SK#Iq!p{>tmmiRVbdXgV*R2mgGosma9!aS{FJM}1#hLjS_3e_32X z|Dt#S8or0;`*>cIzPKtUnx3eMm%zOwIcA_;Uj`L6fZmtIwJ)5ld^Y|LVH`$bYrwYqQ^Rkui-cK02FH`APFg_qe2x z+py7!uM-uv{JfX*b2}BYj43=e-}>RV_#LpUln*^bP*^oPvqjX+T&z?+%G~Xs)0eeU z=BBNnCo|W~mAM0{l2lD#K;~>m@?lmQ`z)cOEzQV%zJMd` z@qO!#^_bBgPGPSf8%kbC-j~vI9?fSM@^N*SX={%S;c@%yLyK9TZ~>1xVCyg~k9(Pe z?FwTV!~tJ^C1X8foV;0po~*jM1;2aLQ@36q-82dRbjccA36p-T8^)C+jGZp8c9JmY zq_>twBj^C5SZ8+H?`Gap(Hc|dpy%XNvnN$vp&4mwOH5tHE8|+2gFY5@Bw!a$N^Ab3 z)y&OSf!VCwY|omEhLCrh#*8P*9%O7=5u}e<&RYd+=rHN_ZjpMkk*n6AmWNq`nDCfJ?(k#Q8;!yo zZlq;5t#L)BFk85yhI2{dz2^vf zp9yCXk!mE}ed~dhUV7j9oaxeDS-7B9+U#pflnlSRZ+&h3%`rA7&E7u36hS*d*5AO^ zMfm|&8<3rBPnmh&!@L!m*XXO}BlnPA3waOn0pzDiI@cvU;u-mgk^CglYE$c(6ZSYkttXF!OXSs5i5RXcYEpJJQUr% zx8ZLF;R6{9KT5}R4VLbwl(0zP-Fx@_APi-PvKlmI_M={ZP|XF_?%h{6FllYZ_?Y?u zzO%_FZH(5XE)vO2d!0yUlf^L_1j)*sWNQo4WKI+}ldN1U8M(i&lmu$!(!`E*o;fPT za#Qs~s8v6rshN76zE8OF1Pv-Q5I-ectgty&VLo@*Ip%!p*&cfH%)`H8T9127{gjmQ ztlaD*fxh`#@$ts%nJe^8@7Y{HSZn-`atHyrMszNoh)O$;!BZ4qLNK)M5@tBC1W2S#B3T!ZsR`JaRbTs1NG22B-p>QOeL@{FlEBRMI#+p4? zF-dgs8Pw3>;DBz(y1gM)HNAERLk>a+28Wr`YY28I@3vgESN`_W0Xy6ennYKbS zl0-FIB^NMS34-1FOfB^Al0NX-{ITH!UbSX$d=U;C{fEu?BCQ-tNA zT*^H?k65ZsGVepGeBV|oQR+2X!3p1(G3Fz`$gL(IlTx^PB?s{1uz5ZYIO6|x+|m=T zGRZjL$oW?NhLo3!k~b>>86&-uIdRaHSqb-~X0(-6MWj#=Cu%1ogB<-#bf9DK=`uc0 zr&=9r*B!cTWJ;=^hW5}A9G9FDU0#6>&bl#d<_4Wk@(_EjW1KP`{1JV9jD{+HMVBj- zj1pk!K18*#scwVy-2>F`Xy_(2C#>Qbpf=-o!A=Q9E6m}XJ^g!6R0#61xS=y9k03_6 zismx|zJNOez}kMy)r>}Y^caOwjD$Wal}wFo`i#=7xq0%z-+_j72O3V(AazfF0@P2W;mX%4Qq3i<}x{j28Mv1~d+pj-|2ucx{SA;c2Y znJZ!xgO7ZS$;TID^sm?4B9m-An@J)h?NkSGC{fF6WfBF7Qj=3|!$K!bKXq(7ggxmr z8vbdeOgYixB|gp77D%-h6qz{liZykjO+^Qgm5{t(97QLgOccIKUsQscb=J%BW*jd^ ziAfWxX~ukz-`KtUeFqh`sS@%l7mwqtbk|(F%xxksr>Q#5%g}Nx0;Pf-Qy7|Lt=5Qn zhnjb(d50QG1_o^N@NuYIH%^qBgElp04&#%iP_L}!SH4>WE-4ksmC_`eWffjwsD=5G dy>RNBGv_S0^Un0yIeXTbt9+qY6<+UU~=-QGWvK1?~V#QY4NS3T+c_MZb>ASrNxHbsH9f28Sv!q!Ct#6SLHuk@4qbMHB` z+!dwJar<=#;LMphXV1)>Ip@ro!!^HOB2ey~d@wa!MaZ|=s|aoK3$%PRI|ol-sKWxFafD^S%isvW@n z=*HI;1f!i6ywNJ#+IiZb)FNND#_fzpUDiOoCr^+DOnf{}<`U zYC@%95Qv;isE4MKx~wK5$!IJwDPN1}Q}U@2q4dZ_>Bv`Ta#}uB6kLzgJrkzkX}Q?% zmiG_I*X1KYS!WWq&?TI@W#~+(R8C&8cv@6Zbrw#eTh@~_G9{ZmsoOk_BIgoD_8ddc z{}&3Jugm=8If^OhFdWg?wBdx6sn-qXY&f1$gD%6VO@(Jv!^!3_obxJ8YSafG7{Y8! zy=I77B&_Qy9S;iB4Z{?rNWt|HgHKLrc#bj@<3h(&a$4<}4$n`8lUHM5fFkOl8JfJR zMs%|?0pgtMI7Xu&B7KgHnU$VQ#9G9=mV?ar~MmL3#3p< z{M(iW7YEa=>7yB6^Zg&)Z_fHU7W$rgDwY~-n@N>p7bDbY(qVUpjN1o`u&7 zF5-0G8n`*I{Caxw!CMOhS)t<_uYalQcGq&>hp%OX`ZeHA4A*ornu@EDqH^IF!JRX8 zzX-rxGHT;Sag-H>t4A|H&%C8r=2=55$WmOZ1qI}XHcg0fJqEwE38-3$%xtWN{A3T< z5jSd2Fk)MjJw!mKxpMYpO*TvJI%P7rj}R~Log`h3E8IhF2o!`*%oqY3S(KeG_zS-M zu))n}IEgINI%o{fsYx)C>oc^+XpJOO34NR?yb3C!wV#S7BjLE#({2;B!8ck9RL|ku zGS$7051o0y{pRS$UBB*1AGz1H(v`1i&D6BQ@OMm?J~-6Dp{m9by5Z1L(}uLcKFHB} zWQW2puBLM95-4m7e3*i3wjJ6mXy;V7V%UK-=LXl#zzowLn(S_WkLVuU`3M_Y?2XAG&|r`Q?!hdslw; z;d@K^qQ3mbM>ig|J@%dWs{65bC?gDgm-7y>S$@;f?J@*SRim4f7B|U6mH+^(1ZuOC zQ`2&m5+asJWHM@B1ut;iCA5=80Xb`sV_z~R_U<%hI; zix!C1`u+*od&|PYNMa%XgvAuxdl*q4f%c7zCL+gII{1jpj(~~94!5! zTQ(fiGunJ3Fd7_I~M&lfJuV^=9TGbfDKVcQ&v)A85}8+Vg?K*}&n9IGi~;oP#=WxVV5! z?U5(qQHxPuF3^$@TWm(6f624x$&1ZdvH6K8Zy|docyFEm;QVsK6QSv;w{q#!?Ne!I z&fA<3n%VsUQH%RZJQ>!H9W4SUIt^!3pPShnVhbR7c5U_m@Uq~cg^VISD$Di;Uo^90V>U`amy0pH)MMKyQg*09^YGkA;R$`tPT* zjfcJ#+S%qdU5L00wu4Imls*y^JJL&V5St^-mSQ=$1P^%Rv!z_-75*f|gh9vrQR!?l z0m`q-+RT(nV-Rr5iKNb~4PJ_J#Cp@)qtgF4V^>dKjXmTH2i0dDIw*!#Mjch7sfZdK zF9wPj_@^SeX80ghu(<#=+JjLI*T1T13{m{Tkcwl#RKrcdJJ5_hx7$m@$kXOU$AD-A zerrddS|IBpsSiG=%k4ap5w@>ZwdSi{%~rjdt2({Vzv^wudkBcn@T}r!&Io z|9^?wJF#Vp61O*Nk4jX54bqEPp)%07q1qxMn-_}6WF4_Ubx9TG4n4q7=EIZ5UH#}3NJHVvG#@g-TCQYtCPKS#r%VW@)K zbF@o}8+nexdX6rR`msX%b41X$paMz^dguETW|PI-l02Q#ba_I>L=sEtmr%1QKB!+YeGev!rCMi7$10}A~2rzM>0-I8zF(8N{ z(SJb&i5edfJ-4piytceI=iadzsQge|b&G7E?HB)1yhy&Os$M>FXWR07>8d+=`mOYP zD;G1BEvx^+Xu&P`D+&^++4k|iJNy3OcUMXU;`T_8Cn`=4v%&03vx~FK;dIrnZe+xU zXVSJ+=|X1st^Dv$v%^3A^IS%Jd(FXnT?N8BUCi2rpV2v4hq3lghW zVv~j^8l#ZiO47Lqo|jedY}h@3j`J=&j1)3J63&pboaYZgwkE`f93esBA=Bx?`zt66 zF)c8h;dnfG4G!okaH8*G6~}5CDop25%*{fUM8(v=HF)$+(4PRR{09_YIFOa_qmEQ=STm3itS1apQ1GTIEDy)6AtAXtWKXC^Bo4AYO z+T^HEcU$4Nw#on*FGPUMQtn;SV!vh}!DWjxRFVgje4%22qetrdRV@$GxpI}*b`3e}EsUIndQL#RyLik-0 ze#j|Jjmi^qW?EO0dg-LrE#YCXuteU^5~1y^a|3SJAS8e{cod1iF##KvmyRU>4*@w$ zTDlNvDCjIkK$G8(fMO7_YkD{lQ6bHyQx@+8dDHAsg*N5<{F~NK=Ky~jY2Jhi^LpzQ z#8;K~wmkN>iLLtPQESrFd+DqCpuh2;V@a6qSpjj?h=I!c2y_RkwO9 zk!X)4S%gEc!?@wPX#Te$nJxYU!d@9HsSGL4gy9-w315~vWDf_HBx9K%!<~qRDGkrD zvLObfaiVrbJO<56#=$G&u2;nPl#vzXTLqVkvgxU~wx1^!0%{GBv^ORE1$e19XY-}tMRTbENG-?(!lQ{9^Nw=GBoR|VIQcHZ-? z_zDETN5@$6=Q0JbP5>+%l{i=0Q6SLV2TssDGJ9WMWW7H})^V^l-`&r2p>d!&=gu`BbRTR! z^|I_cjtA@*XoDfp?PCo@w(;}uZe5A`$$SpjeEzv32*Yc?mmi`o|;N+ z<3FPHl;%C_P5Zg};re6VO#I&NPkXa|(tq5aUKBH$w*%=<)ZgOmHr@USM)81}8 zB|P`}d%ZpWoOeIyyeTG5x)mrX5IEhEW& ziLoT_2~S_#;l0Is64Kswmz<~kJ@vQOj}2t?d5-qFFHgyM!gu^5JSG;62|q>eY44fq zhI{-j{-gc)+Pta#kD#|($(Zuq>K%cMyXTf@XH8M`tszGF&sda+Le)mdHYqu(m z$UMKg((ohmg%;wdx$PwKW$nUh;Cm-4LFKuMtTY;a16k&oXHUQ5+?ZmOap8auJiWf9L;TxWQMDtC&1>P3k3A~eayPg7V9ML=z6KX{n z4D|`jBSr&FHJdcjQL0m2_Ptf=s@VxDGKi9mcBfNbY%R<=kzFpgDowv!j+}BC6;KXz zrd+ZR|4}z`cuwHy{0EsGJ=AGL zLj!M&XlUY%(G1197V5Wj9XX!1s>x~3>{%_-vqIxF?RC+!!i1N+DCB9Bx1r`z6Iw3` z+KFj9(APz1y{38T&_tLCH81;`;pM`lm&dyh+Wa1S-61KSc1{DwA~=2+T*kpgTh*D~ zl6zc+q^7=3NfTj;B~HGc*(C9oz-=m)Nc6oVQW=uye4P@vg$_%c-XL)Xef$?lA>%-O zBziXHOe09fg{0V^cgFOvkGtY=+#Qc&DIUlB2E7{{Psn;zm;1nJkCGP0?j`U3Sl+OQ z_ma1F(6>=}53s!FAa7sH>FZe&`@ia%IKVu+;JIb1?=csf$7ka@n`53o44w~;v~w`> z_u%=^rZL~h^C0v5G2!gC$b3N4v&T^G_{yiw8@@qI_ylnR5;= zcC)3X|4ni3>7!knqNPKn;rvow{~o$6PJd8p*0U8kS^t&c^M zTp_MLXhFBsFA`VXTmVg^64AD}C1;PH+RzYp95%45k_KF(93>71304U4#mE_!?Ixdz z2WRYH2~0PPN`^$LWDLjTu!&SK?wmV8%r-YSN{^(DjtkX{3m=VA=V8JcJfCXT{BEj0 zIY%RB*wx79lBle@;CJn~wJ3J7w$g6+%%#AC+gu%tKzBUmclL!r+R<*Z)I8~5xavoR zQ?SiTSBD(Bv%^mEA~`$cb#yKfB^vD|GTBQ@Ksb?6Yc00tGLeHNT=UAb$f69b=d!{$ zoyel1F!m1KWnap+?8rk{s?k`#3PCQm`fi?J6Rr62%Cf9Y8#&AohxWM!2sWTS~oB);?g2yC0Iw1dae9a!uP{tSJeZr`yD&_c!MzefB3 z@*N{LD9`c4keBmSXhYvcwYdbzCRxUp$;f6&#iCB&k`R zQA~|5F_2NZy&3>?mMd}z^XvJSYgJ!gJglCmT#Ub)ql%_v2)e+4`* zL}_mb(xZt^<+5J}NG}g%ccSEqLeA0HNLhk|z>irq0v!d)BSThN7ppT6-v1o{k;lWh zs%U6hpw;!>42Ro2y*KNL-VQI(+Zh&R(F3Y&h~BPnjP3vKAh8YvixZO+Ngv3@U z+#x^6Enu6X`iY+HKo57Krqk0Arq_Y@xP^+xO{8VlQ@1pWhBrL2?-O^ej*mh0pJ+MLYUqSLTb9&%oEnTZ=y?ujm zeRwdgd+#1sv`0xT4o4Vm<`E8yKwCw`F!El0MnhNtG8p@$5$+31N3=_S7fzzY36z+u z&kWjsqMHgQ*J)<hnl-dw_Er7rv%1!U6Q!HO4D9fV;eNv6G>$K7?**Fd z?f=kzV03<8(XVJU!!IJ;z~41+YNa04Ue!C#T|o`{doRrgwetPpuJxVaKAy3DkKPQo zhcg#Vy1CD^=8A_(B_nn^J>wNw}y7|dLf$E#Y8v@yg&JRg1I5Dx4Y$xlY z)=D^mI=`W22hO;srs&9H&%lYn%5HHZcljIZt|Gi)vI3_?N#g1aC>it zcHV&hruhedxV$#=4|gj5;Y4re79;ohQ5oSlT77$T74j(U#)x1~b>7Q;p5|3fsB30H z(xLqP>F)j)Ij$z8t?&CPLOSm6Otfs)|U`ayn%7A#XdFqtftrFWKKD(AWe$d@@28I)Mmplkt zQ4w1J`67BDzZn7C1j)F5_rM8gZ!*t}Yf2YjNj(yZrY-rU^H`G`k0F>dRK<6I zb|;njWoRN+*s67S zzF?P1wI+`8m;EM9ChSd#O&@U`^d1Q3-gSU*$So20k);Ynsf)h9Qf^c_L3iwY?Aa-i z=cM7oJ5z$6VTJq-i;>CpO`f75IE^4OYk}W{fhFs?3YKivWr?zkPT;SwN#`+-##UD_ ztr#14tT2r}K+u7RiAxqk?HvtR+r>Sl4_8xmXjdo4vegn z1>*K@3V6A^hp6W%I7h*G3SOXq*UUE&HCVI^8svK^jh7WKu=i1vm({ltRi@wr6jUgn zRpzD}?JIsKz$WN)B(JbSAHIEFBo~0O5}vNGhv83xdb6z$lOrN`I3{|Sl761*+Rr=cdUeSTf97F3%!V%J9PVH zaYrPKJG=hI+N*j&?U@2!wCawXFwi#xC4j9HK(#`A#WE?6_}O3CMZHNH`V`t^CHm2Y=hB;o!X*;PaR%R7U&3~vuo_G)|3 zR)`W%igKRPcx{tao3n>}BJ)Bes4lCmxdo>=GjeLMfY>Ooa*~{vP%D*-bdV;kcVTg< zqjrC4kUBRuwUjE;%WfJy>7N>9k9sul+43Pyc96D!4t8jB7V=JiVQyH& zLU;bl&d{+g$BS^|f~J`BMcbi#EX0?MH<>GotJ+}O|Iu1X;(=RM>$DIA+PZcNj=(i- zT@P^xctryxZfURRH?$l24RIrJ!?T(`PB0B~zi zJTOD&ny$Ep_B3u2KPh@?ti<%906;Most`K(k;bOxjN|Ei0KuAB2gddUoded4HQm$K zEH5#mAv8ENfo`pvW=~raA)4wXZwj0S^mUztuO+xK^MR-*>PFpMGuEwJlxp7;z+rS9 zlUg|c2Cibr{Lt+KY)oV=;{L+N=`22ON0vH;zhU5pA~uX}wtuIC7#`}!8-u;no*rO5(HYnivslu&TQJlK z!!uuT*63uM?4{SR$m`aV+AGebZEHe4i8{$O2W!~E8qFYvd!$gkoqC#9F4kmbJ%i2W zp!_kE$a=Oa-@^I*mDI*9cq4hkzLC1&ET+Ab=U^*lQR{oxY*ynJ*DRcOv$ynKZY{N* z_e@ggEVHnd@=TnpA+5DNbypfnqXo?4gDoeH*)R;(!z?y>?otC=9Vce=G2P26jyR1K z`#dkb(lUp#`6fq~pMp9VXLT?U<_7A(fjY=S9TdV`SYUOKTj=J`R$eL{Z_t{;@#2|M z4LVT6sRT~jMCt6ae4cqjXHbrPB-$&Co;PR_x_u%g$65qx0J=fmk8#1NS&pMem^B^^ zy$-#A(ZDSfjRyJ!mBszi0?rGyy4lag99nf~9nU0Pg#4suVQyI8@Qfeo3EW^|zA)xQ z!Sm2dteLoBGS-shUq_hYknXF}&{@V>0>)Q(8LS5rCR)^;HD^}CJq%*&X;!CtXyO#C zUpl;|6ze5O&eXH@oN6snv)pzeKg4a$hdRFt^)YxYXho5h%d2xB4;4l&(UH~j@*^0c z&*I^qT*pgZr{JR$d^-gY z7?D=sw>oVpU!(e;q!fo-ayW*rEO+FqlzM=I@1)@4oKH?NXv$SM>g?zF{capQ&b@FJ zH?1b?yvPa2o2+0lYm4ur>fepPja43N!L+j4LnLt@728S&GQ_YtwI$rPyG6xadEOMx zRwkPD3F(~AU`=)20>WVXT7XM*0uSIR^hqD04wWX-0TFFMij!btw=__m3VGb6LglBx z{Ji`DgmaUuK-F>U3KDqWBcr;5xev+u4qTr_y zz?F^DT4a*Z>98i2KT3H&MgdKq{Ba6ir(g=fJy6gVs(FakI^gf8H7-k|7-Z`DN5{}H;j@tnZZ`FR8(BOs^&vxm(enq~cx0W8fR z1jyLH_O=2>H@JLO$k=9-OF2ZrEC3b3&4bzr0xP5fIP8NJ2WZ>|E)u}b>0aj2bZ8MU zQ8C3kOimLbPXd=)>q%AxqwTV}UC)K-FvE&7wQzb)S2ljDJxnI}Jn%Aa&nz8-Yq;8q zF~`~CCr+KA6DDs`DynvBGKMgbiL!ohx7&o_1Me=LMxSa9W?l~Wq+PkZ#G5rrR90Ww z40oHTmILtX&??J+e_g~9yqnVG|J{~z^Uoxp-?mO z7a@QLyh@<8|09~(kigR+5FcpQfTUwpu9X2&>j}UIfCn-{9##gq7CF`2GL_ca-DQ(Y z8rn9jo7xn2wyc%jR4TJ6p53xmZd0lJmZb`tN{wxb7q_f6zNys2rugJA&Zozz1^KfW zTKOXgV8&LPfWApjeRYD54sqiW+m7<$OOZ{chHO#foPr+b4~V?T#_b9mivaiuKnz^H z#L>a6VU#}I85S3d@NB6vpz;dN53@5WOC?8z82%6Qvns!TESTC59g57HT^*5`Xs+Ux zOBE%uI3Nc#$QX83;V7d|iEy15;L{*7Z&9k^dE^g*Z)iV_6xjg`lF{)Si(CTsKUFSQ z;pn5!iCFEj#{E+iJWTs=9t3$`rueS_AIKD*0t^>Aln+qCQdj(bNyFQmM*(=`6_cN#GDGKJH~bMsGHnS z7Z-0_<1?sRj4A5#cwBcJbtkxP5p@AD2YsI0n2P%M==1s%U^@;kw-O`0$2A{G!$rdw zYHo{ZNQ3O@SoSnz&#;tj@d#{JF|-6abkLI7jj3OyB|GSdqFO;az5mQP3yeVFrsfO$ zCAFsh?d^k_bK=T1d(yM=3kM2N%23V;EyQy z3ZfJ4&> zN}JNBM?!L!?gi{I?(D!-y(*j0OGv?G|Ce#PijDoeI(?Jl8`@BQ5CZNpWVD4(;OXo@ z0LLOu7H}*&fdCAmQx42ns9&S=g}#8Rn-8HB{10^N&6?*o5LXaVsLuWsC9uL7Xq^Zk zqXln!(4=l!C_qTg*1Ls4L*oKf7}LQo18UH^zef&kPvK<3qJ;XW<4k@{CwCMn7yu}( zZpgO+)|sHNwQ)LAG&W~15!fL<_$GLFL<5~*8g9`b6{o{B3R0&bUS@&t`rqCj~k`$Z?pv%1JZ#g8XRAaSR+SaI~0XhIuGD z%e54p<61@eQ*o_v)Jn$m1ZW!^lFT8`928w(+UE2z=BCn%%xzMBKISk54gj`1j@b4+ zC#G@V*h_;m9DV5DI5^m0n&}D9IXL@L-1AA6qe@LNtx9bJogb8%M$CzM&7cLCLsHcY zbGI+3eR2%*FQ1~|CgLsZV3U=YI5Qc4pd_E=`hy0bo<+=>t8+8F zI~1T{1x3dXrI!AhDAG-=ljJl}vGsUtJq43b1ao3R9z-(C)`C0>;`Hg$3m{2(Q69ff z;oBgfQifoDN1g0HfdY{e-?uet7icWWeR7p9=_TnpNhnHk<4I4-Bn8|tlRuj0F7XgV zWBg_K{K!W|Zk*J@N8Is03&<20xM@|CublZ>#GM>TnWxJ|_CND?=jxO6f-1nt^X%%$ zf>X3(bMmNsnImB-E>gyKQ!-CgG8~Vu^JRgGOw(8D@FdsfFV{Nw7QTjWKiTOn$s;SP z6pT|{XZRtNoS;;?u#HrA;L68G>uPtgXK5Q1;7!E)XSiW#PXTH7XM;auSJ&5rWrSl! zfN%JDgrd~D`=6f4_fXaB$bU+ld0y_LG~)uKvx5?@GF-d@PLTKx1!V6d^I5*DdYKZ) zAHhdib)Xzt$DxU|kKhss{~+)-T0EWiA+#;>!;@K8-`~?NffK6#C+M6^xrT$A&kU}D zZj0947evNq;U&3a3Gu%oA->`m;>(sPzGPeCwn&Qq!6E#MA|?J?IO4xVTKp&Y{D+ej zUodjwO_3M>ZWYA8S!3c~t)loBF)ltYCd5CBN%2o&O8ld>O?=Lp7XM()IG@!%EEa@r zd|_0U1z7^J9LTaD3%9Q<%ZO!JkYz!Z$T@F{$NM!cw32F4EBiGK)GVf){}u=PrKm4k bmFoAUn0EeCd|l5S6?hbo)cGvlwjgrYGU254~slgNN#>t-SW|Nj>TeFir z$=>^_tE*|XgzWL`1Kzvuf8Tw-`|f?OxN9~WDKKwztekyygrfcmJ(6ip19vZKDeBu) zfC?y1Qmp7ISOxJbStYp2ld2vytM+1?>ZGPe%W8Xctgc7T>U#{Vp~uJ?drYi}#4AqH zteNDsuomK4iK{+o>#?(Tl16jV(Nn+{5LkP%u&0PE>T$A8;@6!l?s2g$0_#t@drH_6 z0vk@2_IOwifsH52ddk^y0-H`&^mth>z;wWTva+X&ts=1HWOYvsThmj^))K$+5M`8+)4Ark-ZDxn~QzrKg2$@lr#S|BOT;*>zD>RKQ+B1spq+V(5-^ zB->g}l~aM;fr1k183nryeYrn@jzHm%#Xl;SCqe!(36kAF(dS6(9C8JUhjc^jL#}qE z1_}wdAonwJX%gh0lpxs+xF3)XHp=Zr>OF{7NCmm|nKlPn1B!M} zU|YJ~`+s3AZX2rdzbN&b?D{XsumpWB%E+zXIA&z}JcuH#L&l+fWSnWmahA3nf$c~Q zw(A#0i_hEpU%ZXj>&-BB9@JK6pq;e!&@Z&D8++#S_6HrhVB|F-=Pvm z3wvZ-;oJMS=;`+<1XE;oa&lxk5(}S?`Lu$nD{?_lADI}53F_X7XiU%#%)}<9BcVw_ zGdMdl85T551pJV*_z1+a6T`Ey@aU0HEc9524NXplCm~AHd+6w~fKMrC0#A1J`UP{> z^IDl{`q*?ZGCMVM0jOaXtBLkbkB>uY)5!GH3^X@56$;Cq=Q?%>Rw;%J zPm)+$b^tpS4GVN81bo05$R^z})Lq<}z75+UvI4*cC|Ff3HKZQW1eEm_QF=<_Kt6)40y37ij33!_eEYDIU63GMfzx> zu@D;*bd%H3XmBj@77Q6$&n_#1{Nhm!s z9SpK&tN=nok;pVq8ir65LTLXAXFcl7zu`A>4AoMY9Y^E z_@jgqThBu8Zk-CvoefPtGZ6x4B)nyYoqi@f5)(bc;mF9@t=rkr8ED4^5++))>FLSn z)))?+mYEA|1(ZiDBt*0ltglnIO|FF6#ajy#>O$UR`4>IB*>U;Q!l}!_h2RHj58Eo` z83A@KV=Mq+kKX++utNfqr(hIf|DEa*d+*faV&5@JKaSNL!zgy7L%kHE9)XeqnuvBl z6VNhh3$ZQlC)_Ijb8$W#f*_)=(-O>H=*HiLN~%x4Id(KX;$xHZ;I@hNHab7 z%d+AhjfUmveyvFNt2f=Rl^N>g_HQ}%yV&&9I9Y`cUOp&j zNzB-sevcQ++2ftlHG6TN?VMBZZrMFH=VfVS*0q{<+ts`!kQUwbJ%0vQ21pH%yQBFIj+mj}az1a(eZlDT{Rf@qt zjv^UFOqgz*P>6-)qAQrPLbK_DIn#+d1;Ze)ER&=^#a;`~8KlnX6pV6H1wBa#nZ+iA z=k#f8A=#0Mnd!+eDaA&JJ2ji3z+6GPygnuq);Xv1P8|skUkD4fZs5_zFUTd#mCB7I zgQQO`tJ9|vRFl)=$hD1+17|F#Cn96hz_=M<2v3X#agYf#j_07r8%G5#CIx=a2yQUU zveRsoMRAK@m$D$MPS`%ditILy_~6W0HUyllj${aCfH5c>i%64RFpIu4IR(Awotljz z_vg(r>^7{@ELI7r0Ml{yG=y!1zv!RAqH~E#*(gudrQSs+2dtT=a)(XTds zz31hg`HQ#I9$rVk{8hfBdck_zQ?mL9^wvl?UgzyOpE2 zadhWqNV{s{wfi_%=X)y7b$H(VC1D*w1gDi3(fO)~XGD1j(LJO(O0Yh*0qd;^{Z^oK z+}*;_TRD0!!Fns_XkCR)-YNF^InGhPQq4KGfEP&J%N65@DBhJz>T}C;t2ObuUGZJV zZxqKLXKq;HPYlOLo{2v@9Uq~h2u&FohOuELFbtAG&-l1L3GT<xcW84DL* zNEmREWKW75?`FMvIzBWKA3YmCJrSRHK7Q^+sN$gV2-b8?(Tz^GI9ts9+4Wvb7^=zS zd~io?Rx#B2Yv$999x6%)Qz5l13L)#Jz1?0d{af$Gm>Cn!@q1 zNPIdP569xM7vl6qalT^T#fyCjgO3pL!7Cj9Ijeib6>kFD_zdVUz=?`W%7pSf`pzp~ zLdSeCJx2|cu~9rh<6)bv!IhsT$8lKoakG*30nrx~4AUTpnwpq{7Dd?vd4eH! z9QMW0@beQRVK#sXRnc=HK^2|~31+NioDGczBh&1ZV8C7T9F#dMm_}uxm{1rEJs%E& zP%J3d9TjvlqNGZ&XQf3s3W#3JA_bzN5KK_Bp~$l%Zm3)`HUNlp;i3Vh3%EIo#-BOq zOCcZ+aBmV{5>P}+1Imaepo)|+j)rvB5XcGwAcH!ojFdC>;SyLQG7=8IUM^pPKvAP+ z>;a7x-4k9_;JIy@xb^g16=dS93u({2V8TFKYK){BN-~ICn^}RZ7}kn1h8{WzV=GJ zGQ4gE8EQbmI3qQTjf^atA5QVPMk;5(a2mO7Vx+K0?H8gDa`QPxDiPXU ztAx?tN=526opZ z8lVlV382V_A4&BMVe;w>WBSQ*KMutl^JJYgb|HP55}|(t>`EB-o*%-B7*ny0V(X1n$bx8!c+v5f#MF8m-I|_-wgB&oM1{KZrn}(|6;_H zi5M}3k+S}`?M0V77RISc>hw;DwmK>VTwkO< ztAyQBoa_Q`;b7`f^C6d00X-0@@2Wqh=>Ihl8dHo>ovI3IOi@WSQ_=_-1x?DB(o0pq zO5=qUvHv$WOVj_|O;fxrrtrD@Sy*JK!J~8aOcbOnUgQUX)r@&ZCMRZQFy!1s?5sBi zLLXq`M#7zQs+R3z0yS6ODVsrm*NOfU{m0~x&bf9_MY5n&^%7PjERzUN zqTcC2eP$Wu)h~O+sQb=b7kImU1L)IFx z9t%xOhC!&xB4ae?3B;zwE=eCEM6!j(famwlwRy?eAxioSxPR}&6zCA24^Lrd0$1-P z2MjDV(|VA9G(eIhN7`K`(RpM?L~d?WEgbLUeJ2ZpcZsY z(2!FE+m9p|z-DVEV&N$`#^4!c7K$X|Fcx_aK^={SXGnh%nhW;X8R%A^4d~&a8^a*N ztHN{!7#-=;kvet(1kuQ2fW8&+05v%}LC~{{N!VtHBt0VNhC?hn0mm%|oMKRWpItJ{ ze~R6X1zGSo6C7j1&&|R)3oZ|!Uodo^9|@D21A=BWJUlzj_F-Z*o`_gv34IM&BV`Cn zdm1e=T!IMw5?Ud&hS3^9i;S63gffGc9L?Bq3?sve3^2TmB@U}egr?9U18f>09A$!O za{6328iPp?jk08f*`%XlaAbNGPlG_Yb0JhgvoUrKGkz5N3ThXNxKhfEKtS`K+ zE97+Ew8SA^ynpShHv;kfkH?GrxAcLJ)TXcZeyw-WaZ6peF2Zs9wp;49jOW-b_3?}c zV|cpwQg=#AHFqSN_H#}9KWIA0>FeWq_tL)ghF0ED%ompOt{UD|%@?@un6y44r>~v2 zE_pt&PzEPYmo5#h>|Be*yPxymwIj+t&1(UjNaQO%fjBJ+U4r& zzBhe|k|t2ogV4%a#oIhdTP0_!bi@m2a&I$n0{mhJd@ zx%Z8+Yh#Nhlp-5=j~6Vb`|7h-o}E9&)7Hx;7EUBlL|eb2xkb0g#TQ&PUoj_bRh+Hr zmaXP9Ekzgev^z;xb9D7`_bs}SH@ROQT=so8m@M1DmF-BFcJOr3<$;BPTXfmFzF2O) zvjSltLP%&GK?yLoVN1s6)sQw zW<|2HgRAUFSUY%|`>N%N<(ADW*vt9K>O~8`>qv4}54Wr5gI&E`;r4i8?efL-_I-$$ zx1O(QAX`*?e0^2 z)=m|cF3m2VOV)RC@LSROo;^`t9pGIRZ_wB1WhUO-xfZ+SIsy$Zs=-b#y?0na zIVwIou7F;LfqvC+#jw1MFU$sAeU?9H#u@pr;E{gIhk0?J;Da*>C0ksX+ zPrP~JTK7jeiQ|tw8~091smiK;QrrZoz17!V0xZkCNlzQ+X}jfV|ERM0H_K8w%3b%d zktzrBwIn^;InVa>x*el>gQx7&BUJFuimx>`6_%Pm*ypO*dzyS}LEV=V>a%+;85ZRK2B*Q>Uz*1p%9sCt~Q zZs6;-V&$z+vb!W@q*@LsmJIjIkp44E-T}z#&5QO0}S!hz-so#-YtHk>QC%NXg9Wi{Udk5<8I}Tnw0?mnb+g@ zsQ%nn?%$%~+(y4n!&STeN)5M93$Ztq8VI?mQz2|L`rW#l9;3fecXNx;zgH(RD7c=3 zGBzBWWg{bSFv=K&ahtnar9e>|W6B(*O%d~H2onvXk`>_6aZO@CL%{5w#uFqlaPU_! zN=C)#7!9Lk3~(H!1BwnK3XEp{A)sWe=?g*w=?g-1Ktt5wAo5}Kew|c~>@rrFDITZb za*h&2RB(`s*iM6ZP&A}gn_O!^+o9jAMsgSiHevHx`+!?RF%~Gvb^ShL)}AhNe`H@K z8#dXTk<}D8XR;BifF)zwF^uf2$W#0Dayf!#I20Q>D>4RoO7dw0;{@>EuuTJx&7$-U zn3~YDsANW-Yj|utD)JtNOsLN;m;VL1R8U!?z!o=lk&#jaL2^73lQgRAH9#Z^C-wU` za&gE06^LTd>W7gO{Sg45+bpL*DOtYj3KQ?_ztsN`C@X!hGf7=Fr>kCfH}DQG**L3t z*g0Us5O)r7&#)KWDKeE97cBFNl#8mUnKvyS<@BX|Ma|+dPVeDO*2Qh#=-_F`<(`F} z#TVYEt6@=~i|%SEXXSkNdRZOsg#EaJcb0(4nGnRB5?Ip|Uw=VED=T+ks&_Z1o?PL5ZSZA4xEx2TL+Ig2{~932i<@ zol%@oo>86AoY9`4H;A&O^jU5dg`)#aL_etXV<|a?$%u%jK)2)1%sM%lRKgSuvyg_l zs0HDiiqW*|0t&bNE~XVwlWjrYy0l z6q)79I57GxMi4(KTa*#gseje3tf73We&mER2FZ&@oV$Mq)_&c$sX<-7TZJIaS9U18 z)S$MC8dL=ogZh99DFm)M?P}14#X%FH2Gwvg5R{k!U5656xQcw_mdP-v8OCJ}cuvvy zwGt$|`LqW4B<+J*fA$ibE{`q(9yrF*y}v3zq^%43n@(|MI}qG7|Mf{G1~L4|svU;?d=PbU~a z88uDxK!ScaH1h1Z5IZU!ptbzWw32Jf@!8s!s z$&yEo1LFCPtece(B`$+ufTB3)yxvQqA!~}D8-X%r*|1QOS?XH$O-_%%1)%*c>9|E) zZoUS2e|m}H_4cH`6qbn;jAt5YFgD4L$1nFU^e5>?j&5Apk!;$_ zHSJ9_b>1+?G3bfg)`EExStZ_5vOfT91T`r4!agkvVoFL-#-;@w^eP^R*h?6q6%V&; zF?coz8IWKP+#CcWIlTq3Z1&${N>Uf->)?<^3SWc3C=R(_0^@~WGzY~c`gP2J*4L0` z7XkPuI8_x=4(C<(tM1oI;W#BWb@7P=?YY#=+nra-u9Ur2ajEB?+M+4EOM!LAOVI`2 z=(u!*SDU}y|8hT^qHgPqFTcd=O_z-e#&6P}XeqmUsqi(+x~=f4;Z?(Hrln^%Tf-fV z%C5b10=O~SHh&S(3W! zWj&2hzWFT8+<i`uy<&B?JU5`3attEV0xV9%H0WSVemc{>41n_LDMuej(2>;` z)HkGKLNF1saKf9{IR;ic$Dj^YHD!d(N@>eU2Ht!cC1$fcfKjf=r|-gb6jm)dpWaKN z5YRzxSJ8%cmE_ieH@5?Nv0Ve8r=7CHhNffSez3Hq0dr|*9POrnfoyE%tg?An^$lsp zvI}5UUhSA!Vk{$;LAyVr9g{H$XrvUd#bs9fd>TuCrJPHG20>d4x>BIZz~kjVSAM|G zEtM*}5>?afg!wyAwn<7oVYPupVa)Od9I^C01~j4`rtO2#+XSQ6`uV8FiRiQa=yU`FNjzI-u!LK;Ox==3VCkT38PVzc#1^Oow4a!euW%oi!G;mu1ckvt})#;^96%qZUqHi?4-Wv&jDYxI1lNG#q?+E14`aisf2aN$U83*$C zI|wVWTyuuUl`+V-HZqD$@_6O^66CLvAlZc*0x&8Gw`&!(O=R-GJW@s3kwNx1AnE^* zlCysR5c@r}ehaPNM(a9SZ=>}NT7L+Z&mm~WCg4s|mZ;w_%)*_k5ztSt{}z%6YSa(0 z?;@y;g(J~vmi>>I|36`f5wBN)vTNpSl>OHjip(ziyJ-Cz5|0<2fMcG8OILb7aR)%X z5n%d!c0ogUU%?o|8`Yxlop5l-hq3>TWEYuV!VO0;RF)UipcfVxN<6+3T>-EFX^Flz zIx&v#6dhvO;vFhcmY;PHPczKuh+H1Bg@QN3FpCm{``G3`g)Jls0xe{le`0bd>o0%hfAmT+xnsWBSz27rL(=zjA!(i9}(|d^hhdf1~VL z*>@}Ed;X%RgtwLO?lQjIx6;6sZ|A*DD#)yG0wev zzK3^LC+l}{^}FJGdb#>u&fPcPbDMT1=?XaS-=ZtwT)@#auk_r(^L+}KTK6Ne`hR0` zW=XIcIJ#k_37$(tb1;1_`7S0y>vwQaCc*mWpTWVzm31&dbxhya z`go)J_1ZTYuQeu1I=GUKM9JQ@ajvBMy^*-P=SFp0-w&c`jTXe_C8cq-^T&42yoT48 zEQOQhEgbynTW;!0U{E^BK+T~qSTC=BWAZzb@utrA#<=o6I5Mv5%}KqR)4P-U8ctus z(-q(APS)(;YIY=R_H#A+lQoC9nnNIIGSq?^2JXOXsq$@UwMf~nxW43XS~|YTTt5BE z=_UK}&YR}Owb-Sj_taIIl6w?bD8JViNV2N}&(-}`_Al2b3L58+pd_nzp?Ast&4H}Z zMbtIg)vh>Qf0N#M+u{7^L9-H6Em!wm+51}Oa#y0D7G~1MfWIiM`fkVKkvnFppcq~p zI7^d7wOmo{2SxQDER*#yqI&2Z<6B0U?R0&D-t)1FayCFSZSKo2F1)xjp0L)CBNYpD-tjW!4cXbQjdDe}McrbH$FN0LCEKNG$+~96nMtKnql&zzUfW zBnY^&z{<#gSnJd)lJuQOw}04oSZpWrP?3_D_EFo$=#U8psl5*QY$(~9i#e5%Q^Lmm zXuYG)*E^?vNEmEG|50F`$-b9wXa6BoKM2xtT(bl-s^BN!VO3fp)z@{nyBD?FWb4MH z;{L5m3%CRgqz%V5NK?p$OO`aUnvtamSn#aXjFoq4edzgdvV{w((Cm34MMD5UXT!A(q>_tVy_97R>Yd zE7f<3V4(xyd6DPpz?FeSVeQL(myWN4u*dhR@3p37Bj;$mba-C-@<|>S*O&W#dEmC) zxuFqNawA{~pbO_a<7$^Ew}ED7j5ML|{yqeqfoYum;y<7qQs+A@g@vMCaqB+tmRb`J{v;pocoF0RwAes7?iI z2T81vy$uCSY=izKLlyO88KVV=tXT74?^k3Cqy_=~r=f0Xwgc;Lo6O{e*`v#pk^w*n zC9gwy4pkr+$s7g)uvXx8y$6(}`O-?tUK@WtoOD}3u=7x?4V$JgerfG8ZCV$-X$td~ zhK*&@x~yMXUA9f@vTvHg@ugu?uxVX|o2Dp|Q-H44+0Q0mPO*Os791vLr{G&Hy3z1h zXm&CtSnyp+`UDpoJ1^+*z=tP7!PpJwDe?)FQ9%#tPq;(~>K4^VCSBOG+B&2CrNt-8Y+6rwZPf3I;JLf{qNB z$pIgE(L`tAQXGEO1Rl1=-~%JL-3o@#=qTR(f_q?i*dS*KxY(zTh9}3;DT2X~$q*B)HKo=u!c>fo(+oj#NMJl!pR<`IhV8h@934HT{j?=qt=}Y+H@^8|pL9r}Y7B9Zd zyfgIup|$G7mi;$o=PjI`k)r0_-to@9@9$evC7Sl#U@$5mMLjInO50i(6CX(@ju$ts z9(s=y-6!S$>f7OWCci(qwms2uApRIB`w1y7w(MWiU;+E3kh$CF#>Lgc ziTd3)ijdteme_rRnKbOh752SaEarH+7@jVFTk{V6eR?&JXy~|6n=jK#tIGFwV5Sr4 zO!4BaYmOV;8044GdGYPpcP@Va;@V-js}q0x35?UMs?o15(HfD~Hxd+qp*$<9lv!|+QY}>? zyP#>s?J~!tbITi}GM|)E-u8x`dy0BS5YRNLXetfNe96q%ogFF@J3g zyoE*10$FBB+DtZM(Dme+FIFnlTbNrCXqnUjN{tlA7YAR{Q4@?%9+`zzewt=ps137( zrq26o8xE&wnnx;`p=pNRK+`^BO z-%hD4+5KXr{#^anA=fTJvbJ_(-{GY}UH%@14P~MxBm=AR4n{9 zL=f!)Jd7q^Gx3#)!U@7m{!fgB&IXQh9G;;lf!AUG7l!>EhG|1HGvUZ6`}ZUP-1!rW z{0B^sLhBA%q=t_Xx{KCtq4fz`pQ7~{Sc3ZbiSRjAjiukiz@MV^f6)3FMjK`$QFw|F zo(seKk$=Kq6v7hS0=fMW#f=vQ6e72VLSDg|;nu`!CMS_s`vb7xj*PY7^2LRV%g*_W zaqD*RvWxW<>ry+XulSY~xd`(s=A|N0#Lww{-!k7e*)Q)|*t2*xVXELvO{v5j8T^@L4U@6McRo@ywU*AGslJ>$qy5*wlo;N+K+ST)1>p`ya;4Qif;*)e0Bu&yj zj`rc*63-RSvKFr4cw0Gp>nihur`~;P?YT8}_38JRKYHr-p1Ki@Kjwe$47Ve|*#<$d zYqp>e!to$nlC@vu;AO4#V~N z_u-l#wner11rTcw_3m-hMdJ`#8tGHGjf!5E1TDd`b{DP}@!@;-)6T z)#4)@IIXe22BUq0cm^i};cz1WfN%z8m1wecn37;s+Vt@K`NQw~fzUSd9m9ws>oRgC z6y!#Vet4Aj9jFfOzTsCB$T{nOgI`p-D8<=p0+gc2{HA9t+;W?az$2oVV9U(fbjTD= zTa?Vg3CmQp7=C>)X{zE(RqLQAJMrp?r7%&Ht%Gz?^Np9{`f`!s&omh_X2@|HxQlP& z^JGO9NvOQ+ecn$iW-qh(M`ZVBUv`hk?m!=Pt!X|h_Mxc~*NC;EWXnF|M}2Bj$jae6{EnKIe= z3?!~X!f!)9BZ-qKlbz3$ri@q0DZ3fUbU$RwRLSWjDAPu5RvR-1AZQ1j(~j&oa}KWA zah5c$HBwEoo52-Up9l6w9+o_VtL}lhq@v}@GRM7$KC6Y9qYbpsn$yDUI9pC!=1>9o z>^ZH^&gV$CJ`Zazp&0tzhqN9S#T=7qJ!r=suqL$RWC3p`m+^tQ^048VT&66VW$v*< z+Y3nB<@LfO<$$90D&E7v{fH=>`sM1~i;m~MvNNBQ`yokB%Mb%=aGc-=)7wF=z<9@5nMiS`2b&BdiCO!i_1r@ z_r2MdC~liS&UqZnnUx4 z<2GQ|E7v^@E3Sm64Wji6ee1Q`;%$c$wMQ@-_`9l&xuNEysT6+UQ75wSOQUa$UmH)@ z>*f#LkIlxc>lH04=MxqCpdJ%PdvdDvwywGo-u)0s!{gA6IZ34zNmB)1T7lgBiaOcQ z!8LRwO83n7uETAh-h{V9%2~N(WinB@7rbU*taD;XN#3N%3om4mnO!MN`u1|Zy$R30 z`98kZM_BoqWLqEC)|aU5k3;h-n^&IBDzB<#<*SLRPAJ|2SBch}cExx1Cz=M3T5x}& zAtyPh&yzHHc+k8yUulM(J%8()`u3Z<>GBQ zT%Xzhix!?gSp!&tc2K-&r^S#LFhmy_g(vu-3oMzxI2qWj7;a8$9!A1&(=O;!u>TC9 zQJCXY`i9VL09mw9CmVedfS>+LDy1ye3@Da0*Y$7euUp@=-lNd}ahq3TNgYwBG!;t> z9EbM3xo@>_HT;9gcPF`RhvU_Uxw0cE3Sw_K{=DRmOFlvHp0P>eO;KQpO~#sh%6>t2 zMARUGJfVLa!obuVJbLtKf4_Kp5z^q?rF`A+oixxV5%ozpD&Fh*Aw&qq^sD&EiD4Fx zC2;-l0)9$^RfC7crJE(>(IKQqi;zx3KG;u)Mu;rf$=BptvOcgT1x4U<@9BS3z#5?d z!6YUnABJE-C=8#P$O>t}PXLg2Jmg#TEJ*TUMG_N`;1+UK61i!?lN+2}KZ{pI1OrJD zgm1{PHuTfPpZ@3ui@T`c!jIU)Yn_Rf=O?1@LH!B%elU5w#5y3FWzj;(Hqa{bK|2=r zJn~tEbaUWI6nx%Z+z^1fpjc#u#dV)7`SKc%+W_duq}~uegnt@SZ$S$ODkunJ@UH*} zg#^QZ^cg-jh;inu|I*OwXPYogJq)#zdz*086h4m-6TeUJD_9^3V+741@}~234EQcu zFQIiCty;8@5g|K(px+1YEa4vv>}U7D_{Jd~{e2h<@Nih6xND&l1wW=NKc-9{QdNIR z?MYI5{)%$^n4*77xjv+Jd`Ol3CAITIs{TW2?}t>&hg8L1Q|*6^fqOorD(@+E3gtZt zEO^cfDX#)NEHyUka9PsiU{b$PXQDY&_xv%r`!ZAfeVC-BTMaz6IUuy zr6kBh0oriSE>%Xn<&?cJRY5>6WppeybB2mkCGl5*KMq(Kno`vypoS{+EOlQwkg6r# zI?7ggS??9X+asV=3%(A+j*m{G>mg#L{!!w|V;=(9=q+tKHs zw5ItTi#3bUrNfKoue8LWPEF+}1sHcq@w6fbRSejGD+Z*APQY9YNnkdLBw!4-i8QRi`+o$zj*|S983s2vYZ=w-%c@NG&8n%tR@K7SY}uhxbwao5 zhPi4k^r~K%uja!-wZO7o(7E@3CE->W$?G z5SJJD7$C<-W*7N5#-@*$UE;?8d2VF(1fRs%OmA#TOzJaTewsAhtb)( z%TMrGjGh#%d5Saq4PZAX3StgWbI#x=g(qh5J=Oc3#&>ScsC)bjKl|9Jo_^2Z=lJ|% zgU>%Qt7rK0B1drSeg-?ez+c4fXZcH-!Ub^#RRrR*x}t!PWxHyX86c^{-g1}iP0 zygM6dBaVWW@|MMBTSPUXoDU-mJ%yZ&GE-g}`%ydG*isG`tu#=s6lq&VHG=nl)OJwb zK}n`jrN$0JF}ubt+hse}uCZ+FvC5vU%xFW|p$H;n1uKcN8>tBQm?Zt$GbB{c_|Lt| zYjG$phe2m8h?|W7t(sWeka1Jg(#soL=~^85?=0WV##ey!+T}MS-w0&7rN^==jpJ5w zxt%sz$r8{XtHt3)I~9Hy2s-9V6V)`6&*7sAX&_fc>faAqH5%{#EHnm$y^79&Fp#y& zw#`9*Kxo(G?1XU&CNNDzTb0#6{^R_=#+D|pJ>}Ryj-%SMG{N!&+SRe%*-T zHmYNNq@`#P;}%TRAOpU$m43n9B2pgf+(! zK1l}AdW{6p|BazR+#r0CDL5D6#;&n#ckLZmT+`ZR3}a23gv6~U=C%Wf^Qlp1+~&?> zuLpP2T+`$36SnPk-OL#=1Fd}5p&3|H=(@UPa-+@UFS>B-_HH%O^gr6rSP8w^D z@07YG@q*{SFn7)E956c8&A|hBU2m|nH?nhv#s0Mgy*97c=CQWWEez>FyN_34$Khk3 z15)o{S)dk&+4f7hSn0gJa3M%)X(JTnWNG076uy>}UwwnmFI>=^CP|l*E9izINrF{$ zI+F_@EiQzM3z+o&m4(|^7M7vdIUb~<9;Sz^gA|8UEAdeM3hJ;<7GYK-Is8l!U=}|7 z_fmgAYQN=OZ3W>953apUt7#k0=$uUuIX3)Oyt*pn($!X63tGuFG_t8c#7$Bug7Ds6 zre4Yu4;xA9$Lk$?3ClVKpNp0DDz%Tqb(u~h$JRBPM#`Gzy3_M?WlQBUro?CfgGRKf zoEz`H{o#9bE_5-L{D3CQ*Rj(_v?UQxG|Ch5TWF{}6#SMZLU|zR7HPG7Q$udThX*O5 ztI7wQ-)yv~Ilio^dOwIb{DrcTt%NA$aXVE-Y%p9@z@q*%GftEj2I~Trf;hPmq-)A< z#*IjgVPDakRvwcVak9Kb6YDu$o z2=D;VLm5Y2qkXkp3EhDRX*eMsdATC9Et8qdeKzHmtOBF|gjr(Ky>>yjiWn{7@BLeT z56fyOBR}#1*HMxmp%MlqmPjg@u!N0N$|PGyGURYu*yuUjMbFhqqEXH9ys$OI?j?sF zFN%C`WSo~Ux4@4dYO5fu3{d>_K#-aLW-JwFop0H;hNX1Hpkjf?eqx zKeX$W`H}sIekutSQl_6{pHG}^gqR-~%-Dnr?op*l0BPpZtBNHD3Cdqx6_I$jA+HJA ziBL3uHf3}!9YO$%g+KBAeXgdz6uUc6WSApKQIg*{cn%n+S95~RqdJ;#dX91$5f=~t zj|Rafb=lm2$_n;G9MwXr(jmtOmA18l-qT0^-B~4w7T)Q@R1T-9@jxp!CG61bL+DUE z)vDBYkvm}i?ogRBW)LoYygnYFq>;<;>JU!))F!~0b7Y+(FdI&+l*h;%C?fg3D)@dF z^LC5cMc+rl*~%c&#fHjiM5(+>C*PxrJh1$+27uY}6Y36GHNZ>W0)T)z0C>g}D_Nd3 zWzAR`qG=mZNaDso!6`vWjYkI5(YRrJlG{eK2X{=0yPHmGrdDcqq44zY8=nJLE>hfM zIM+wOsmC%ET-hx|E#!)%TUDpy(=686hqUJ7Im(kp1cs~$momRDF&{3t4C5g&`%d~7n+Cy0ajYf${dK=%uvz`#K$KbcE(Guy0_Q zBnjh&(v!Dej}B5XH;PVAH5xUMw6Rp1XVLPQY@t^}8TpYvxQ?Rz!O-Dz^RavSJs&J*jiSu)@qCv}@SXgXbQw3DJSjtdDheuyl^ z#{ZF(T^U|8l8X;8PXCWkH0569p8N(5Q!MB>?XpDvE5~{xTLT@34$MBa`JJVs^b~r8;cvQ>Zdh7P>yUYIVyCb26F0WMHqwVNs&}AQ+ z+vc{_HJiwz=Aef@>q4$R0t+;g+_K%Zb{x27(o?nz*>%Yf%*xZR@xxJ0vKDLzm4lzA zj856OwBf}Io2!BdH_|P?)mRtR31s{a0?B=(BWpo~OCQ2iEq$nLyl^E|r!Imtk-w#7 z7A_$BSf<62>n;B?VO~V#nMFnlJ;7|&nH-S2)SEX#HW}f7Y>KELyZ1(>NhC!!@epMa z=}lV`2}7(89IK=0k-x)Ijbe2|;&LSuF`!10mBDj8!ABKWWx#zff> zj39!Li-%4fLZqY>n^e)WrhXvQ?{fO_OTRtoS0Je^NM8v^Ov?VCa`TS9lV!OQB^o+n z%LmjVXRmEWc|kjkiOqZpyQt67$JjbVL437zM9;JJslOV31vGN-MR z_OwG~)}FC#hpvz&v%dnzzIILXnQNKR diff --git a/mace-bench/3rdparty/mace/mace/tools/__pycache__/utils.cpython-313.pyc b/mace-bench/3rdparty/mace/mace/tools/__pycache__/utils.cpython-313.pyc deleted file mode 100644 index ec8c28f4740a70032a5e0cf7bf3661908309aa0f..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 11054 zcmcgSYfxKfcK1s6N>>sRAOQyPep@zR8^17X6Wj44#sMeD>^LzSm=xC`YdMI^gxqhDT{RU#_HxgsNK!ko1G4-2?xj~(0Aw?A4N~~_Zn>%akFDAt> z#*Y*YHE0`em2SgG=|0YFbTU^sVt*`@@Pye0w~5YnJa)FahnYf^Jr?T8dsyh~^2g3r z_b8%J6_16oj+D|H8{9Suxy@ZXQt2)kse<0heFk?a?RAc@LkG5Sq}ALk{4szn`-pj@ z3gC4-47nX6!bl~pm1Sz>P^-eeeS*6J*Sx}#no6dLac^@n?#lf@Z%m}tT{U7jnI zu(p$)pw3w79)BFj#CQyuJ7Kf(+hrCa4>`}Jy*G*ZY0fNad1sfqSA zL0|KTb?8*iMA{g7?$N%?imhX;tK~id47v=8{#fXy!kzGYBZ_b)A`$UYKvINr;}L&2 zBnA}0EnSVJvIqs zx0#DtKQolZZl(h&P1)mI_J~?D9Gz$oMGS{&ZIhbB`Uo$0@qdxWWY`F{J3l zOR~cGBU11lODq5;<8lqknO&E|L1|Y|oV+ZCulPl%d8N*A622mNBWmLkpy2YZy~H;z zlE_Wkr)H6GI3Vwui1-6?7a+sp4F|_3B9bR4N~8i-2Zh4OO^{77>lW*+gYO*tpnLAX zdk12z?;lB68k3f`xTP&&*}lSU-vlT|F-Jr)CPh5g#eg>(OBr0I0O}8bOIE|7R)%zc zT>1>_+-xV@r4uOO)^)=DfMWW=vD?3H`wvT3Y2SSx43c8V?&&GDNF~%2OBz3+FbT@Y zTn!Bkq#Q?zHQW;^K=EG(z1!elMri>2+so$K-)q0yw9s~^ZPA%rz_u5c#_u9(B+$b*#2??C1Er%V124aKe zJ)V?6p>Xmw5+S8nHOSW{L?RL7R8j_+DitbDi!wM3$`WaY(Hey4fGnG#559NsZui2W zJBJpB6V*Etj*f(-Gim9LTe=gLeJk8P6=~V*oI%TNBI3&VC!`N8$fqR*TT5; zgZDkKHau7CHJ4uHeDk8fXfSgG6 zBWj-8@QT~OiXLWd%w9dvNH;jid&oqn^B&Ym7l4pmkmVfDh!0{!93$N@f(MjQSxqr1 zn|=Db(fzTqm|?|Izrxk42jK(Kbt$0eVTNMz`3bz@VFE@D5kptK9L&^Gh68^L=X*u} zbOLGGG{g;aLpVDxJz3l}ay4XIIzrwO&;%P4IM}T}qzmz|4_YdDX;{CHb#q`z3I*3J z$^(ZcSXZe3X&OR8Mk!0)#i0T^FkwVo3K}>+0G$j7LCwA z$V;%BQ9Ef!7#5z258-#jM>getw_BOAZe@DvtxTC)nTO4VZk4q{0lPz+4IE&8>W$5X zc-9TxW!5T%xx~B$F~;DZ%Q86G-MU?ry%B#-q-+u`~0%C6meeGRSKsidRcaml^V z3`(*rjzMFxwDqOV)?jBV06o{+I?&tdhF8Go6C=`SFp^`wp&mq*-bDH6V{bF(!HMhz zlaVSXg2QAFWGhf44%KI>OV!0ZusShCwz61>w?uL%^sh0X<`iP zg-?DMRz<}+1Z94@h!>z~C|mWV`+9$d_oOd`ORpGJbcUeW(w&4NC~#T`_?cf1V;o5dc-s zqmm}i(veeO%!?85J{6Pb^PTbsuxfQvv^*zdm_QK9&A7-f4DPa`@-WyLbnmLfpds(pt zMSlnh>!HgTM=yha0V~=lUA6`U7@%D;>TGaKUWI>5X=6H@LCDJhf_x?Wzyb!IcxYp6 z>TKd-L1}r{@M{2iCZzhX$s!GFj5NTGE&j4yQ_*AMpB8!_=9L zB8Imlxw1G{wpPAS$YhBjGqOEu@4J zb>iLjh0giTHDNpLsf`P@F~8Pw#sz1*`plZpzqQZye-@fI8kwd&E5;hij9$QaCTRk( z`6_hXW=0r{=s`4(g-D)mgd5@UiQfk=F^$<6+|UCpp0ev{iY zVovv{@idDYB7S2{ED{d-y@M0MOAxt9 zIt5Vy=4i1!x^M2_?7<&=e?_QPDVjMBKx8bsqtJkOExUDCx6X+wvcQ2I882CTeZV$V z{f9}%Bd6+Z$UzRShpJ?=1+r$QfhuAL^BlaP;nM5OZN7mS02|K&BV)`#{V1!NBczeQ#u~m97qR`g~yZiylBu8>@Oph?gO4ycJzHVwp{wNJ)e#J zx6DE3pV>-S;1(WUC{U(IcruC;5P50!1eBm_7p4Rq^5K+v$bex=TRs|&@910heLnWt zSYk(CqP1_O<(ZY@qbtHuN*IOrhkVkF!kpazi#+bh8fKAxpoMy)Td&W&9{XORc;||+ zQx)Y`VR#S}4`>54l0eGL9#1grn+RaN$m4;~Par)~?D34klNky4LsBU0@eouDirycJ zkV{Y_h(R)rS%eM%LlAUpm2%`8G2-`n;0Y)GOA!8)VLEcSdI~%w8PPJ-cyd2vZ!sG> zmM@R>EIvDb@Tz z_q!jRSStUkvx(Z>3Co@pZjUMxwwxVg#R$BIyA+lVv^}5hKm&p~0L)7oa)1E3>;R{R ztaRwcbe($g{z-{49w2E379#IJ24Uf}(0E2exyY)@y+&HP{9Z{WKZ2PP@Gmz*1{Ros zag@(>yw?$XEm^-OUcV<{@18pQFP4hv8wpF>3fD&Q;sw}DLiLGh6W!5#3j*u|&aEf> z80X@Z68L;_T_Um=I2rPWeG<9=+pO-Y@%sVkEaGHkfT5YR*WOc*> z>K>YMPQ2xC7YH~ERcmv$$_uQvy2g_>#q5{;A<%sAaZC*@8bMb|;69KyV4@<72Ew2f z(TSop#TfF5gorm4LnMqs4v%q|fOlayFNHw(i5X{cSL86bD*f(rgK0hiuT>fqCjyc5 zlcBKT><5GVB~Gi%=dP~wq}O3o{tL)3lvf@VA#~}ScP1?@aZAg(qau3k{R7hj+DP?$ zb<)xrx3sQ1Dr2SZ_e`JDMly3!js!n>IuoxdD>{-RnOY2E^Ay16OXdj3CH!-49=^ze9S>jDd9dUSVDlT(-;8lgv*`sl ze8C^%U1r5PAi)#ol~q6XKGi?StEwZ3EI;T)5!O;hq<;)Ox0!59UJK$6GXYJDFc%S? zrXlV8S_hPczV8%v`U~5+8VkdYnKMF!1tPFqSL8;-#Zylh7#I9xGz!QdY)W3h>_x~x z--=kz+&4$hQX@gB$iYv>(Lfa~?}iL)m8zz9ZK=xY*yw!GH21D`y{uyH^z7-m=VzZ^ zEc>|b!@A|N&uc%eO_V(|&Hd1%zC>(pKpC|a)suErPeM~og^YWAPYS&`>PG^2$TQD3 zz!lJg_!A8Xj_P%29_R`Mwk~Eg-4KS5@)GRNUM$f%-G)P1#zf)u+?fC*{lN(QXhb6% zx)6;_2czkEj2w*W8kl;f0sh9ABl>z~lywP%=Sd&zsIcNBKM@)2>_I;f%s%+A58tuF zK2S;!($RYZvcj5OCY8?TV1~*7TzDM(HKe=3zeeDDIr#~#``~-OOeIZYg}CQ8Q^$`y}$qg@fzX(rC+^Yu2^cylh;x_IzyKSqYda7xJGfTVlkSh6cz57ip0OQDT6=zN?fD-Xarv9xV?Fv`_Lz?w zjBrBArys$F^eQfML4EV5I8NJEP}g%`rK)efE_N2AWIUw^*pU5LffUeUcM}j zONtRJK@1xx97J?QG7q4NSqhFvZh8X#tCCUz@w?YV;`4y}d07lWL>TOCZ{&u;!H+9s zc#tJAB9p(xX?h4J%7_$<(`lY?db(1YL1b6DD+ZAADafXnuPdB0mT5j!Q9XBc_UdBO zYM8eiOtxpL>Nr6VIq^0IhcHY~$CLBl=+o#Q{MajZ? zTFTG|bta_`j6VHCbQZ4lg%w_*-^yHQF0Ayl-wFZ#;WusjAQqyAhfH4^^0Iuv`%v=!E#HO5t2eiZ==tf% z6`?BI?hp#;0SG8l7d9wGE{A>Om#`=K6=qYI<*9+6V;^QHTrw)-x0n=P`aoyk9h-W_ zVBf&nMhVm&=s?nG-zL`7bKq208&#?_GOp_1e&_|$@a8yZXH^eTjfDI_P|&D9PVj|N zc>2pLYILf+58e-0lHl6FkCCVyN$_X{jSl)NBvOa9X3SbK+kqK=NF`MNp}k7Ktt#A^ z!BgjGXj*mHr*Q<`BytzCC}v0yMGz+RoGRJt5)c@7-XqzTEwJ@2Bbdm#q(J--elG_s4V#=6N$t-oH43llSA~AvSh= zp?|&~Cl9g9$3O4?G+qC4KTaNEadOwz&Ry!xT^j<+?uri0y*&FeuDNSbqQx?gOYg$b zhA6xOP{iD{Sajn^!-kouYh4(bA4yf!rs|th&gusy7u%Aultdj-Z?tmenGHR(zc!Ud zn`de_c&Kj}n3Bqve&+QJBen#_;)w2@>HG%geS3kmu^kT>$TpfdwtGofcHFbV!@_p2 z^qzf)<%W%8+f!gy_0RTi=&|nVrRCSW@$Hx2eEA^{eGhGBc26{t+}0l7*8Y%z#wEw;=fj#})*0pJY_qmxNqfAc{Tl{4DD(awX|Cm& diff --git a/mace-bench/3rdparty/mace/mace/tools/arg_parser.py b/mace-bench/3rdparty/mace/mace/tools/arg_parser.py index 193b1c3..c9d537e 100644 --- a/mace-bench/3rdparty/mace/mace/tools/arg_parser.py +++ b/mace-bench/3rdparty/mace/mace/tools/arg_parser.py @@ -1,971 +1,971 @@ -########################################################################################### -# Parsing functionalities -# Authors: Ilyes Batatia, Gregor Simm, David Kovacs -# This program is distributed under the MIT License (see MIT.md) -########################################################################################### - -import argparse -import os -from typing import Optional - -from .default_keys import DefaultKeys - - -def build_default_arg_parser() -> argparse.ArgumentParser: - try: - import configargparse - - parser = configargparse.ArgumentParser( - config_file_parser_class=configargparse.YAMLConfigFileParser, - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - parser.add( - "--config", - type=str, - is_config_file=True, - help="config file to aggregate options", - ) - except ImportError: - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - - # Name and seed - parser.add_argument("--name", help="experiment name", required=True) - parser.add_argument("--seed", help="random seed", type=int, default=123) - - # Directories - parser.add_argument( - "--work_dir", - help="set directory for all files and folders", - type=str, - default=".", - ) - parser.add_argument( - "--log_dir", help="directory for log files", type=str, default=None - ) - parser.add_argument( - "--model_dir", help="directory for final model", type=str, default=None - ) - parser.add_argument( - "--checkpoints_dir", - help="directory for checkpoint files", - type=str, - default=None, - ) - parser.add_argument( - "--results_dir", help="directory for results", type=str, default=None - ) - parser.add_argument( - "--downloads_dir", help="directory for downloads", type=str, default=None - ) - - # Device and logging - parser.add_argument( - "--device", - help="select device", - type=str, - choices=["cpu", "cuda", "mps", "xpu"], - default="cpu", - ) - parser.add_argument( - "--default_dtype", - help="set default dtype", - type=str, - choices=["float32", "float64"], - default="float64", - ) - parser.add_argument( - "--distributed", - help="train in multi-GPU data parallel mode", - action="store_true", - default=False, - ) - parser.add_argument("--log_level", help="log level", type=str, default="INFO") - - parser.add_argument( - "--plot", - help="Plot results of training", - type=str2bool, - default=True, - ) - - parser.add_argument( - "--plot_frequency", - help="Set plotting frequency: '0' for only at the end or an integer N to plot every N epochs.", - type=int, - default="0", - ) - - parser.add_argument( - "--error_table", - help="Type of error table produced at the end of the training", - type=str, - choices=[ - "PerAtomRMSE", - "TotalRMSE", - "PerAtomRMSEstressvirials", - "PerAtomMAEstressvirials", - "PerAtomMAE", - "TotalMAE", - "DipoleRMSE", - "DipoleMAE", - "EnergyDipoleRMSE", - ], - default="PerAtomRMSE", - ) - - # Model - parser.add_argument( - "--model", - help="model type", - default="MACE", - choices=[ - "BOTNet", - "MACE", - "ScaleShiftMACE", - "ScaleShiftBOTNet", - "AtomicDipolesMACE", - "EnergyDipolesMACE", - ], - ) - parser.add_argument( - "--r_max", help="distance cutoff (in Ang)", type=float, default=5.0 - ) - parser.add_argument( - "--radial_type", - help="type of radial basis functions", - type=str, - default="bessel", - choices=["bessel", "gaussian", "chebyshev"], - ) - parser.add_argument( - "--num_radial_basis", - help="number of radial basis functions", - type=int, - default=8, - ) - parser.add_argument( - "--num_cutoff_basis", - help="number of basis functions for smooth cutoff", - type=int, - default=5, - ) - parser.add_argument( - "--pair_repulsion", - help="use pair repulsion term with ZBL potential", - action="store_true", - default=False, - ) - parser.add_argument( - "--distance_transform", - help="use distance transform for radial basis functions", - default="None", - choices=["None", "Agnesi", "Soft"], - ) - parser.add_argument( - "--interaction", - help="name of interaction block", - type=str, - default="RealAgnosticResidualInteractionBlock", - choices=[ - "RealAgnosticResidualInteractionBlock", - "RealAgnosticAttResidualInteractionBlock", - "RealAgnosticInteractionBlock", - "RealAgnosticDensityInteractionBlock", - "RealAgnosticDensityResidualInteractionBlock", - ], - ) - parser.add_argument( - "--interaction_first", - help="name of interaction block", - type=str, - default="RealAgnosticResidualInteractionBlock", - choices=[ - "RealAgnosticResidualInteractionBlock", - "RealAgnosticInteractionBlock", - "RealAgnosticDensityInteractionBlock", - "RealAgnosticDensityResidualInteractionBlock", - ], - ) - parser.add_argument( - "--max_ell", help=r"highest \ell of spherical harmonics", type=int, default=3 - ) - parser.add_argument( - "--correlation", help="correlation order at each layer", type=int, default=3 - ) - parser.add_argument( - "--num_interactions", help="number of interactions", type=int, default=2 - ) - parser.add_argument( - "--MLP_irreps", - help="hidden irreps of the MLP in last readout", - type=str, - default="16x0e", - ) - parser.add_argument( - "--radial_MLP", - help="width of the radial MLP", - type=str, - default="[64, 64, 64]", - ) - parser.add_argument( - "--hidden_irreps", - help="irreps for hidden node states", - type=str, - default=None, - ) - # add option to specify irreps by channel number and max L - parser.add_argument( - "--num_channels", - help="number of embedding channels", - type=int, - default=None, - ) - parser.add_argument( - "--max_L", - help="max L equivariance of the message", - type=int, - default=None, - ) - parser.add_argument( - "--gate", - help="non linearity for last readout", - type=str, - default="silu", - choices=["silu", "tanh", "abs", "None"], - ) - parser.add_argument( - "--scaling", - help="type of scaling to the output", - type=str, - default="rms_forces_scaling", - choices=["std_scaling", "rms_forces_scaling", "no_scaling"], - ) - parser.add_argument( - "--avg_num_neighbors", - help="normalization factor for the message", - type=float, - default=1, - ) - parser.add_argument( - "--compute_avg_num_neighbors", - help="normalization factor for the message", - type=str2bool, - default=True, - ) - parser.add_argument( - "--compute_stress", - help="Select True to compute stress", - type=str2bool, - default=False, - ) - parser.add_argument( - "--compute_forces", - help="Select True to compute forces", - type=str2bool, - default=True, - ) - - # Dataset - parser.add_argument( - "--train_file", - help="Training set file, format is .xyz or .h5", - type=str, - required=False, - ) - parser.add_argument( - "--valid_file", - help="Validation set .xyz or .h5 file", - default=None, - type=str, - required=False, - ) - parser.add_argument( - "--valid_fraction", - help="Fraction of training set used for validation", - type=float, - default=0.1, - required=False, - ) - parser.add_argument( - "--test_file", - help="Test set .xyz pt .h5 file", - type=str, - ) - parser.add_argument( - "--test_dir", - help="Path to directory with test files named as test_*.h5", - type=str, - default=None, - required=False, - ) - parser.add_argument( - "--multi_processed_test", - help="Boolean value for whether the test data was multiprocessed", - type=str2bool, - default=False, - required=False, - ) - parser.add_argument( - "--num_workers", - help="Number of workers for data loading", - type=int, - default=0, - ) - parser.add_argument( - "--pin_memory", - help="Pin memory for data loading", - default=True, - type=str2bool, - ) - parser.add_argument( - "--atomic_numbers", - help="List of atomic numbers", - type=str, - default=None, - required=False, - ) - parser.add_argument( - "--mean", - help="Mean energy per atom of training set", - type=float, - default=None, - required=False, - ) - parser.add_argument( - "--std", - help="Standard deviation of force components in the training set", - type=float, - default=None, - required=False, - ) - parser.add_argument( - "--statistics_file", - help="json file containing statistics of training set", - type=str, - default=None, - required=False, - ) - parser.add_argument( - "--E0s", - help="Dictionary of isolated atom energies", - type=str, - default=None, - required=False, - ) - - # Fine-tuning - parser.add_argument( - "--foundation_filter_elements", - help="Filter element during fine-tuning", - type=str2bool, - default=True, - required=False, - ) - parser.add_argument( - "--heads", - help="Dict of heads: containing individual files and E0s", - type=str, - default=None, - required=False, - ) - parser.add_argument( - "--multiheads_finetuning", - help="Boolean value for whether the model is multiheaded", - type=str2bool, - default=True, - ) - parser.add_argument( - "--foundation_head", - help="Name of the head to use for fine-tuning", - type=str, - default=None, - required=False, - ) - parser.add_argument( - "--weight_pt_head", - help="Weight of the pretrained head in the loss function", - type=float, - default=1.0, - ) - parser.add_argument( - "--num_samples_pt", - help="Number of samples in the pretrained head", - type=int, - default=10000, - ) - parser.add_argument( - "--force_mh_ft_lr", - help="Force the multiheaded fine-tuning to use arg_parser lr", - type=str2bool, - default=False, - ) - parser.add_argument( - "--subselect_pt", - help="Method to subselect the configurations of the pretraining set", - choices=["fps", "random"], - default="random", - ) - parser.add_argument( - "--filter_type_pt", - help="Filtering method for collecting the pretraining set", - choices=["none", "combinations", "inclusive", "exclusive"], - default="none", - ) - parser.add_argument( - "--pt_train_file", - help="Training set file for the pretrained head", - type=str, - default=None, - ) - parser.add_argument( - "--pt_valid_file", - help="Validation set file for the pretrained head", - type=str, - default=None, - ) - parser.add_argument( - "--foundation_model_elements", - help="Keep all elements of the foundation model during fine-tuning", - type=str2bool, - default=False, - ) - parser.add_argument( - "--keep_isolated_atoms", - help="Keep isolated atoms in the dataset, useful for transfer learning", - type=str2bool, - default=False, - ) - - # Keys - parser.add_argument( - "--energy_key", - help="Key of reference energies in training xyz", - type=str, - default=DefaultKeys.ENERGY.value, - ) - parser.add_argument( - "--forces_key", - help="Key of reference forces in training xyz", - type=str, - default=DefaultKeys.FORCES.value, - ) - parser.add_argument( - "--virials_key", - help="Key of reference virials in training xyz", - type=str, - default=DefaultKeys.VIRIALS.value, - ) - parser.add_argument( - "--stress_key", - help="Key of reference stress in training xyz", - type=str, - default=DefaultKeys.STRESS.value, - ) - parser.add_argument( - "--dipole_key", - help="Key of reference dipoles in training xyz", - type=str, - default=DefaultKeys.DIPOLE.value, - ) - parser.add_argument( - "--head_key", - help="Key of head in training xyz", - type=str, - default=DefaultKeys.HEAD.value, - ) - parser.add_argument( - "--charges_key", - help="Key of atomic charges in training xyz", - type=str, - default=DefaultKeys.CHARGES.value, - ) - parser.add_argument( - "--skip_evaluate_heads", - help="Comma-separated list of heads to skip during final evaluation", - type=str, - default="pt_head", - ) - - # Loss and optimization - parser.add_argument( - "--loss", - help="type of loss", - default="weighted", - choices=[ - "ef", - "weighted", - "forces_only", - "virials", - "stress", - "dipole", - "huber", - "universal", - "energy_forces_dipole", - "l1l2energyforces", - ], - ) - parser.add_argument( - "--forces_weight", help="weight of forces loss", type=float, default=100.0 - ) - parser.add_argument( - "--swa_forces_weight", - "--stage_two_forces_weight", - help="weight of forces loss after starting Stage Two (previously called swa)", - type=float, - default=100.0, - dest="swa_forces_weight", - ) - parser.add_argument( - "--energy_weight", help="weight of energy loss", type=float, default=1.0 - ) - parser.add_argument( - "--swa_energy_weight", - "--stage_two_energy_weight", - help="weight of energy loss after starting Stage Two (previously called swa)", - type=float, - default=1000.0, - dest="swa_energy_weight", - ) - parser.add_argument( - "--virials_weight", help="weight of virials loss", type=float, default=1.0 - ) - parser.add_argument( - "--swa_virials_weight", - "--stage_two_virials_weight", - help="weight of virials loss after starting Stage Two (previously called swa)", - type=float, - default=10.0, - dest="swa_virials_weight", - ) - parser.add_argument( - "--stress_weight", help="weight of stress loss", type=float, default=1.0 - ) - parser.add_argument( - "--swa_stress_weight", - "--stage_two_stress_weight", - help="weight of stress loss after starting Stage Two (previously called swa)", - type=float, - default=10.0, - dest="swa_stress_weight", - ) - parser.add_argument( - "--dipole_weight", help="weight of dipoles loss", type=float, default=1.0 - ) - parser.add_argument( - "--swa_dipole_weight", - "--stage_two_dipole_weight", - help="weight of dipoles after starting Stage Two (previously called swa)", - type=float, - default=1.0, - dest="swa_dipole_weight", - ) - parser.add_argument( - "--config_type_weights", - help="String of dictionary containing the weights for each config type", - type=str, - default='{"Default":1.0}', - ) - parser.add_argument( - "--huber_delta", - help="delta parameter for huber loss", - type=float, - default=0.01, - ) - parser.add_argument( - "--optimizer", - help="Optimizer for parameter optimization", - type=str, - default="adam", - choices=["adam", "adamw", "schedulefree"], - ) - parser.add_argument( - "--beta", - help="Beta parameter for the optimizer", - type=float, - default=0.9, - ) - parser.add_argument("--batch_size", help="batch size", type=int, default=10) - parser.add_argument( - "--valid_batch_size", help="Validation batch size", type=int, default=10 - ) - parser.add_argument( - "--lr", help="Learning rate of optimizer", type=float, default=0.01 - ) - parser.add_argument( - "--swa_lr", - "--stage_two_lr", - help="Learning rate of optimizer in Stage Two (previously called swa)", - type=float, - default=1e-3, - dest="swa_lr", - ) - parser.add_argument( - "--weight_decay", help="weight decay (L2 penalty)", type=float, default=5e-7 - ) - parser.add_argument( - "--amsgrad", - help="use amsgrad variant of optimizer", - action="store_true", - default=True, - ) - parser.add_argument( - "--scheduler", help="Type of scheduler", type=str, default="ReduceLROnPlateau" - ) - parser.add_argument( - "--lr_factor", help="Learning rate factor", type=float, default=0.8 - ) - parser.add_argument( - "--scheduler_patience", help="Learning rate factor", type=int, default=50 - ) - parser.add_argument( - "--lr_scheduler_gamma", - help="Gamma of learning rate scheduler", - type=float, - default=0.9993, - ) - parser.add_argument( - "--swa", - "--stage_two", - help="use Stage Two loss weight, which decreases the learning rate and increases the energy weight at the end of the training to help converge them", - action="store_true", - default=False, - dest="swa", - ) - parser.add_argument( - "--start_swa", - "--start_stage_two", - help="Number of epochs before changing to Stage Two loss weights", - type=int, - default=None, - dest="start_swa", - ) - parser.add_argument( - "--lbfgs", - help="Switch to L-BFGS optimizer", - action="store_true", - default=False, - ) - parser.add_argument( - "--ema", - help="use Exponential Moving Average", - action="store_true", - default=False, - ) - parser.add_argument( - "--ema_decay", - help="Exponential Moving Average decay", - type=float, - default=0.99, - ) - parser.add_argument( - "--max_num_epochs", help="Maximum number of epochs", type=int, default=2048 - ) - parser.add_argument( - "--patience", - help="Maximum number of consecutive epochs of increasing loss", - type=int, - default=2048, - ) - parser.add_argument( - "--foundation_model", - help="Path to the foundation model for transfer learning", - type=str, - default=None, - ) - parser.add_argument( - "--foundation_model_readout", - help="Use readout of foundation model for transfer learning", - action="store_false", - default=True, - ) - parser.add_argument( - "--eval_interval", help="evaluate model every epochs", type=int, default=1 - ) - parser.add_argument( - "--keep_checkpoints", - help="keep all checkpoints", - action="store_true", - default=False, - ) - parser.add_argument( - "--save_all_checkpoints", - help="save all checkpoints", - action="store_true", - default=False, - ) - parser.add_argument( - "--restart_latest", - help="restart optimizer from latest checkpoint", - action="store_true", - default=False, - ) - parser.add_argument( - "--save_cpu", - help="Save a model to be loaded on cpu", - action="store_true", - default=False, - ) - parser.add_argument( - "--clip_grad", - help="Gradient Clipping Value", - type=check_float_or_none, - default=10.0, - ) - parser.add_argument( - "--dry_run", - help="Run all steps upto training to test settings.", - action="store_true", - default=False, - ) - # option for cuequivariance acceleration - parser.add_argument( - "--enable_cueq", - help="Enable cuequivariance acceleration", - type=str2bool, - default=False, - ) - # options for using Weights and Biases for experiment tracking - # to install see https://wandb.ai - parser.add_argument( - "--wandb", - help="Use Weights and Biases for experiment tracking", - action="store_true", - default=False, - ) - parser.add_argument( - "--wandb_dir", - help="An absolute path to a directory where Weights and Biases metadata will be stored", - type=str, - default=None, - ) - parser.add_argument( - "--wandb_project", - help="Weights and Biases project name", - type=str, - default="", - ) - parser.add_argument( - "--wandb_entity", - help="Weights and Biases entity name", - type=str, - default="", - ) - parser.add_argument( - "--wandb_name", - help="Weights and Biases experiment name", - type=str, - default="", - ) - parser.add_argument( - "--wandb_log_hypers", - help="The hyperparameters to log in Weights and Biases", - type=list, - default=[ - "num_channels", - "max_L", - "correlation", - "lr", - "swa_lr", - "weight_decay", - "batch_size", - "max_num_epochs", - "start_swa", - "energy_weight", - "forces_weight", - ], - ) - return parser - - -def build_preprocess_arg_parser() -> argparse.ArgumentParser: - try: - import configargparse - - parser = configargparse.ArgumentParser( - config_file_parser_class=configargparse.YAMLConfigFileParser, - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - parser.add( - "--config", - type=str, - is_config_file=True, - help="config file to aggregate options", - ) - except ImportError: - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - parser.add_argument( - "--train_file", - help="Training set h5 file", - type=str, - default=None, - required=True, - ) - parser.add_argument( - "--valid_file", - help="Training set xyz file", - type=str, - default=None, - required=False, - ) - parser.add_argument( - "--num_process", - help="The user defined number of processes to use, as well as the number of files created.", - type=int, - default=int(os.cpu_count() / 4), - ) - parser.add_argument( - "--valid_fraction", - help="Fraction of training set used for validation", - type=float, - default=0.1, - required=False, - ) - parser.add_argument( - "--test_file", - help="Test set xyz file", - type=str, - default=None, - required=False, - ) - parser.add_argument( - "--work_dir", - help="set directory for all files and folders", - type=str, - default=".", - ) - parser.add_argument( - "--h5_prefix", - help="Prefix for h5 files when saving", - type=str, - default="", - ) - parser.add_argument( - "--r_max", help="distance cutoff (in Ang)", type=float, default=5.0 - ) - parser.add_argument( - "--config_type_weights", - help="String of dictionary containing the weights for each config type", - type=str, - default='{"Default":1.0}', - ) - parser.add_argument( - "--energy_key", - help="Key of reference energies in training xyz", - type=str, - default=DefaultKeys.ENERGY.value, - ) - parser.add_argument( - "--forces_key", - help="Key of reference forces in training xyz", - type=str, - default=DefaultKeys.FORCES.value, - ) - parser.add_argument( - "--virials_key", - help="Key of reference virials in training xyz", - type=str, - default=DefaultKeys.VIRIALS.value, - ) - parser.add_argument( - "--stress_key", - help="Key of reference stress in training xyz", - type=str, - default=DefaultKeys.STRESS.value, - ) - parser.add_argument( - "--dipole_key", - help="Key of reference dipoles in training xyz", - type=str, - default=DefaultKeys.DIPOLE.value, - ) - parser.add_argument( - "--charges_key", - help="Key of atomic charges in training xyz", - type=str, - default=DefaultKeys.CHARGES.value, - ) - parser.add_argument( - "--atomic_numbers", - help="List of atomic numbers", - type=str, - default=None, - required=False, - ) - parser.add_argument( - "--compute_statistics", - help="Compute statistics for the dataset", - action="store_true", - default=False, - ) - parser.add_argument( - "--batch_size", - help="batch size to compute average number of neighbours", - type=int, - default=16, - ) - - parser.add_argument( - "--scaling", - help="type of scaling to the output", - type=str, - default="rms_forces_scaling", - choices=["std_scaling", "rms_forces_scaling", "no_scaling"], - ) - parser.add_argument( - "--E0s", - help="Dictionary of isolated atom energies", - type=str, - default=None, - required=False, - ) - parser.add_argument( - "--shuffle", - help="Shuffle the training dataset", - type=str2bool, - default=True, - ) - parser.add_argument( - "--seed", - help="Random seed for splitting training and validation sets", - type=int, - default=123, - ) - parser.add_argument( - "--head_key", - help="Key of head in training xyz", - type=str, - default=DefaultKeys.HEAD.value, - ) - parser.add_argument( - "--heads", - help="Dict of heads: containing individual files and E0s", - type=str, - default=None, - required=False, - ) - return parser - - -def check_float_or_none(value: str) -> Optional[float]: - try: - return float(value) - except ValueError: - if value != "None": - raise argparse.ArgumentTypeError( - f"{value} is an invalid value (float or None)" - ) from None - return None - - -def str2bool(value): - if isinstance(value, bool): - return value - if value.lower() in ("yes", "true", "t", "y", "1"): - return True - if value.lower() in ("no", "false", "f", "n", "0"): - return False - raise argparse.ArgumentTypeError("Boolean value expected.") +########################################################################################### +# Parsing functionalities +# Authors: Ilyes Batatia, Gregor Simm, David Kovacs +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +import argparse +import os +from typing import Optional + +from .default_keys import DefaultKeys + + +def build_default_arg_parser() -> argparse.ArgumentParser: + try: + import configargparse + + parser = configargparse.ArgumentParser( + config_file_parser_class=configargparse.YAMLConfigFileParser, + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add( + "--config", + type=str, + is_config_file=True, + help="config file to aggregate options", + ) + except ImportError: + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + # Name and seed + parser.add_argument("--name", help="experiment name", required=True) + parser.add_argument("--seed", help="random seed", type=int, default=123) + + # Directories + parser.add_argument( + "--work_dir", + help="set directory for all files and folders", + type=str, + default=".", + ) + parser.add_argument( + "--log_dir", help="directory for log files", type=str, default=None + ) + parser.add_argument( + "--model_dir", help="directory for final model", type=str, default=None + ) + parser.add_argument( + "--checkpoints_dir", + help="directory for checkpoint files", + type=str, + default=None, + ) + parser.add_argument( + "--results_dir", help="directory for results", type=str, default=None + ) + parser.add_argument( + "--downloads_dir", help="directory for downloads", type=str, default=None + ) + + # Device and logging + parser.add_argument( + "--device", + help="select device", + type=str, + choices=["cpu", "cuda", "mps", "xpu"], + default="cpu", + ) + parser.add_argument( + "--default_dtype", + help="set default dtype", + type=str, + choices=["float32", "float64"], + default="float64", + ) + parser.add_argument( + "--distributed", + help="train in multi-GPU data parallel mode", + action="store_true", + default=False, + ) + parser.add_argument("--log_level", help="log level", type=str, default="INFO") + + parser.add_argument( + "--plot", + help="Plot results of training", + type=str2bool, + default=True, + ) + + parser.add_argument( + "--plot_frequency", + help="Set plotting frequency: '0' for only at the end or an integer N to plot every N epochs.", + type=int, + default="0", + ) + + parser.add_argument( + "--error_table", + help="Type of error table produced at the end of the training", + type=str, + choices=[ + "PerAtomRMSE", + "TotalRMSE", + "PerAtomRMSEstressvirials", + "PerAtomMAEstressvirials", + "PerAtomMAE", + "TotalMAE", + "DipoleRMSE", + "DipoleMAE", + "EnergyDipoleRMSE", + ], + default="PerAtomRMSE", + ) + + # Model + parser.add_argument( + "--model", + help="model type", + default="MACE", + choices=[ + "BOTNet", + "MACE", + "ScaleShiftMACE", + "ScaleShiftBOTNet", + "AtomicDipolesMACE", + "EnergyDipolesMACE", + ], + ) + parser.add_argument( + "--r_max", help="distance cutoff (in Ang)", type=float, default=5.0 + ) + parser.add_argument( + "--radial_type", + help="type of radial basis functions", + type=str, + default="bessel", + choices=["bessel", "gaussian", "chebyshev"], + ) + parser.add_argument( + "--num_radial_basis", + help="number of radial basis functions", + type=int, + default=8, + ) + parser.add_argument( + "--num_cutoff_basis", + help="number of basis functions for smooth cutoff", + type=int, + default=5, + ) + parser.add_argument( + "--pair_repulsion", + help="use pair repulsion term with ZBL potential", + action="store_true", + default=False, + ) + parser.add_argument( + "--distance_transform", + help="use distance transform for radial basis functions", + default="None", + choices=["None", "Agnesi", "Soft"], + ) + parser.add_argument( + "--interaction", + help="name of interaction block", + type=str, + default="RealAgnosticResidualInteractionBlock", + choices=[ + "RealAgnosticResidualInteractionBlock", + "RealAgnosticAttResidualInteractionBlock", + "RealAgnosticInteractionBlock", + "RealAgnosticDensityInteractionBlock", + "RealAgnosticDensityResidualInteractionBlock", + ], + ) + parser.add_argument( + "--interaction_first", + help="name of interaction block", + type=str, + default="RealAgnosticResidualInteractionBlock", + choices=[ + "RealAgnosticResidualInteractionBlock", + "RealAgnosticInteractionBlock", + "RealAgnosticDensityInteractionBlock", + "RealAgnosticDensityResidualInteractionBlock", + ], + ) + parser.add_argument( + "--max_ell", help=r"highest \ell of spherical harmonics", type=int, default=3 + ) + parser.add_argument( + "--correlation", help="correlation order at each layer", type=int, default=3 + ) + parser.add_argument( + "--num_interactions", help="number of interactions", type=int, default=2 + ) + parser.add_argument( + "--MLP_irreps", + help="hidden irreps of the MLP in last readout", + type=str, + default="16x0e", + ) + parser.add_argument( + "--radial_MLP", + help="width of the radial MLP", + type=str, + default="[64, 64, 64]", + ) + parser.add_argument( + "--hidden_irreps", + help="irreps for hidden node states", + type=str, + default=None, + ) + # add option to specify irreps by channel number and max L + parser.add_argument( + "--num_channels", + help="number of embedding channels", + type=int, + default=None, + ) + parser.add_argument( + "--max_L", + help="max L equivariance of the message", + type=int, + default=None, + ) + parser.add_argument( + "--gate", + help="non linearity for last readout", + type=str, + default="silu", + choices=["silu", "tanh", "abs", "None"], + ) + parser.add_argument( + "--scaling", + help="type of scaling to the output", + type=str, + default="rms_forces_scaling", + choices=["std_scaling", "rms_forces_scaling", "no_scaling"], + ) + parser.add_argument( + "--avg_num_neighbors", + help="normalization factor for the message", + type=float, + default=1, + ) + parser.add_argument( + "--compute_avg_num_neighbors", + help="normalization factor for the message", + type=str2bool, + default=True, + ) + parser.add_argument( + "--compute_stress", + help="Select True to compute stress", + type=str2bool, + default=False, + ) + parser.add_argument( + "--compute_forces", + help="Select True to compute forces", + type=str2bool, + default=True, + ) + + # Dataset + parser.add_argument( + "--train_file", + help="Training set file, format is .xyz or .h5", + type=str, + required=False, + ) + parser.add_argument( + "--valid_file", + help="Validation set .xyz or .h5 file", + default=None, + type=str, + required=False, + ) + parser.add_argument( + "--valid_fraction", + help="Fraction of training set used for validation", + type=float, + default=0.1, + required=False, + ) + parser.add_argument( + "--test_file", + help="Test set .xyz pt .h5 file", + type=str, + ) + parser.add_argument( + "--test_dir", + help="Path to directory with test files named as test_*.h5", + type=str, + default=None, + required=False, + ) + parser.add_argument( + "--multi_processed_test", + help="Boolean value for whether the test data was multiprocessed", + type=str2bool, + default=False, + required=False, + ) + parser.add_argument( + "--num_workers", + help="Number of workers for data loading", + type=int, + default=0, + ) + parser.add_argument( + "--pin_memory", + help="Pin memory for data loading", + default=True, + type=str2bool, + ) + parser.add_argument( + "--atomic_numbers", + help="List of atomic numbers", + type=str, + default=None, + required=False, + ) + parser.add_argument( + "--mean", + help="Mean energy per atom of training set", + type=float, + default=None, + required=False, + ) + parser.add_argument( + "--std", + help="Standard deviation of force components in the training set", + type=float, + default=None, + required=False, + ) + parser.add_argument( + "--statistics_file", + help="json file containing statistics of training set", + type=str, + default=None, + required=False, + ) + parser.add_argument( + "--E0s", + help="Dictionary of isolated atom energies", + type=str, + default=None, + required=False, + ) + + # Fine-tuning + parser.add_argument( + "--foundation_filter_elements", + help="Filter element during fine-tuning", + type=str2bool, + default=True, + required=False, + ) + parser.add_argument( + "--heads", + help="Dict of heads: containing individual files and E0s", + type=str, + default=None, + required=False, + ) + parser.add_argument( + "--multiheads_finetuning", + help="Boolean value for whether the model is multiheaded", + type=str2bool, + default=True, + ) + parser.add_argument( + "--foundation_head", + help="Name of the head to use for fine-tuning", + type=str, + default=None, + required=False, + ) + parser.add_argument( + "--weight_pt_head", + help="Weight of the pretrained head in the loss function", + type=float, + default=1.0, + ) + parser.add_argument( + "--num_samples_pt", + help="Number of samples in the pretrained head", + type=int, + default=10000, + ) + parser.add_argument( + "--force_mh_ft_lr", + help="Force the multiheaded fine-tuning to use arg_parser lr", + type=str2bool, + default=False, + ) + parser.add_argument( + "--subselect_pt", + help="Method to subselect the configurations of the pretraining set", + choices=["fps", "random"], + default="random", + ) + parser.add_argument( + "--filter_type_pt", + help="Filtering method for collecting the pretraining set", + choices=["none", "combinations", "inclusive", "exclusive"], + default="none", + ) + parser.add_argument( + "--pt_train_file", + help="Training set file for the pretrained head", + type=str, + default=None, + ) + parser.add_argument( + "--pt_valid_file", + help="Validation set file for the pretrained head", + type=str, + default=None, + ) + parser.add_argument( + "--foundation_model_elements", + help="Keep all elements of the foundation model during fine-tuning", + type=str2bool, + default=False, + ) + parser.add_argument( + "--keep_isolated_atoms", + help="Keep isolated atoms in the dataset, useful for transfer learning", + type=str2bool, + default=False, + ) + + # Keys + parser.add_argument( + "--energy_key", + help="Key of reference energies in training xyz", + type=str, + default=DefaultKeys.ENERGY.value, + ) + parser.add_argument( + "--forces_key", + help="Key of reference forces in training xyz", + type=str, + default=DefaultKeys.FORCES.value, + ) + parser.add_argument( + "--virials_key", + help="Key of reference virials in training xyz", + type=str, + default=DefaultKeys.VIRIALS.value, + ) + parser.add_argument( + "--stress_key", + help="Key of reference stress in training xyz", + type=str, + default=DefaultKeys.STRESS.value, + ) + parser.add_argument( + "--dipole_key", + help="Key of reference dipoles in training xyz", + type=str, + default=DefaultKeys.DIPOLE.value, + ) + parser.add_argument( + "--head_key", + help="Key of head in training xyz", + type=str, + default=DefaultKeys.HEAD.value, + ) + parser.add_argument( + "--charges_key", + help="Key of atomic charges in training xyz", + type=str, + default=DefaultKeys.CHARGES.value, + ) + parser.add_argument( + "--skip_evaluate_heads", + help="Comma-separated list of heads to skip during final evaluation", + type=str, + default="pt_head", + ) + + # Loss and optimization + parser.add_argument( + "--loss", + help="type of loss", + default="weighted", + choices=[ + "ef", + "weighted", + "forces_only", + "virials", + "stress", + "dipole", + "huber", + "universal", + "energy_forces_dipole", + "l1l2energyforces", + ], + ) + parser.add_argument( + "--forces_weight", help="weight of forces loss", type=float, default=100.0 + ) + parser.add_argument( + "--swa_forces_weight", + "--stage_two_forces_weight", + help="weight of forces loss after starting Stage Two (previously called swa)", + type=float, + default=100.0, + dest="swa_forces_weight", + ) + parser.add_argument( + "--energy_weight", help="weight of energy loss", type=float, default=1.0 + ) + parser.add_argument( + "--swa_energy_weight", + "--stage_two_energy_weight", + help="weight of energy loss after starting Stage Two (previously called swa)", + type=float, + default=1000.0, + dest="swa_energy_weight", + ) + parser.add_argument( + "--virials_weight", help="weight of virials loss", type=float, default=1.0 + ) + parser.add_argument( + "--swa_virials_weight", + "--stage_two_virials_weight", + help="weight of virials loss after starting Stage Two (previously called swa)", + type=float, + default=10.0, + dest="swa_virials_weight", + ) + parser.add_argument( + "--stress_weight", help="weight of stress loss", type=float, default=1.0 + ) + parser.add_argument( + "--swa_stress_weight", + "--stage_two_stress_weight", + help="weight of stress loss after starting Stage Two (previously called swa)", + type=float, + default=10.0, + dest="swa_stress_weight", + ) + parser.add_argument( + "--dipole_weight", help="weight of dipoles loss", type=float, default=1.0 + ) + parser.add_argument( + "--swa_dipole_weight", + "--stage_two_dipole_weight", + help="weight of dipoles after starting Stage Two (previously called swa)", + type=float, + default=1.0, + dest="swa_dipole_weight", + ) + parser.add_argument( + "--config_type_weights", + help="String of dictionary containing the weights for each config type", + type=str, + default='{"Default":1.0}', + ) + parser.add_argument( + "--huber_delta", + help="delta parameter for huber loss", + type=float, + default=0.01, + ) + parser.add_argument( + "--optimizer", + help="Optimizer for parameter optimization", + type=str, + default="adam", + choices=["adam", "adamw", "schedulefree"], + ) + parser.add_argument( + "--beta", + help="Beta parameter for the optimizer", + type=float, + default=0.9, + ) + parser.add_argument("--batch_size", help="batch size", type=int, default=10) + parser.add_argument( + "--valid_batch_size", help="Validation batch size", type=int, default=10 + ) + parser.add_argument( + "--lr", help="Learning rate of optimizer", type=float, default=0.01 + ) + parser.add_argument( + "--swa_lr", + "--stage_two_lr", + help="Learning rate of optimizer in Stage Two (previously called swa)", + type=float, + default=1e-3, + dest="swa_lr", + ) + parser.add_argument( + "--weight_decay", help="weight decay (L2 penalty)", type=float, default=5e-7 + ) + parser.add_argument( + "--amsgrad", + help="use amsgrad variant of optimizer", + action="store_true", + default=True, + ) + parser.add_argument( + "--scheduler", help="Type of scheduler", type=str, default="ReduceLROnPlateau" + ) + parser.add_argument( + "--lr_factor", help="Learning rate factor", type=float, default=0.8 + ) + parser.add_argument( + "--scheduler_patience", help="Learning rate factor", type=int, default=50 + ) + parser.add_argument( + "--lr_scheduler_gamma", + help="Gamma of learning rate scheduler", + type=float, + default=0.9993, + ) + parser.add_argument( + "--swa", + "--stage_two", + help="use Stage Two loss weight, which decreases the learning rate and increases the energy weight at the end of the training to help converge them", + action="store_true", + default=False, + dest="swa", + ) + parser.add_argument( + "--start_swa", + "--start_stage_two", + help="Number of epochs before changing to Stage Two loss weights", + type=int, + default=None, + dest="start_swa", + ) + parser.add_argument( + "--lbfgs", + help="Switch to L-BFGS optimizer", + action="store_true", + default=False, + ) + parser.add_argument( + "--ema", + help="use Exponential Moving Average", + action="store_true", + default=False, + ) + parser.add_argument( + "--ema_decay", + help="Exponential Moving Average decay", + type=float, + default=0.99, + ) + parser.add_argument( + "--max_num_epochs", help="Maximum number of epochs", type=int, default=2048 + ) + parser.add_argument( + "--patience", + help="Maximum number of consecutive epochs of increasing loss", + type=int, + default=2048, + ) + parser.add_argument( + "--foundation_model", + help="Path to the foundation model for transfer learning", + type=str, + default=None, + ) + parser.add_argument( + "--foundation_model_readout", + help="Use readout of foundation model for transfer learning", + action="store_false", + default=True, + ) + parser.add_argument( + "--eval_interval", help="evaluate model every epochs", type=int, default=1 + ) + parser.add_argument( + "--keep_checkpoints", + help="keep all checkpoints", + action="store_true", + default=False, + ) + parser.add_argument( + "--save_all_checkpoints", + help="save all checkpoints", + action="store_true", + default=False, + ) + parser.add_argument( + "--restart_latest", + help="restart optimizer from latest checkpoint", + action="store_true", + default=False, + ) + parser.add_argument( + "--save_cpu", + help="Save a model to be loaded on cpu", + action="store_true", + default=False, + ) + parser.add_argument( + "--clip_grad", + help="Gradient Clipping Value", + type=check_float_or_none, + default=10.0, + ) + parser.add_argument( + "--dry_run", + help="Run all steps upto training to test settings.", + action="store_true", + default=False, + ) + # option for cuequivariance acceleration + parser.add_argument( + "--enable_cueq", + help="Enable cuequivariance acceleration", + type=str2bool, + default=False, + ) + # options for using Weights and Biases for experiment tracking + # to install see https://wandb.ai + parser.add_argument( + "--wandb", + help="Use Weights and Biases for experiment tracking", + action="store_true", + default=False, + ) + parser.add_argument( + "--wandb_dir", + help="An absolute path to a directory where Weights and Biases metadata will be stored", + type=str, + default=None, + ) + parser.add_argument( + "--wandb_project", + help="Weights and Biases project name", + type=str, + default="", + ) + parser.add_argument( + "--wandb_entity", + help="Weights and Biases entity name", + type=str, + default="", + ) + parser.add_argument( + "--wandb_name", + help="Weights and Biases experiment name", + type=str, + default="", + ) + parser.add_argument( + "--wandb_log_hypers", + help="The hyperparameters to log in Weights and Biases", + type=list, + default=[ + "num_channels", + "max_L", + "correlation", + "lr", + "swa_lr", + "weight_decay", + "batch_size", + "max_num_epochs", + "start_swa", + "energy_weight", + "forces_weight", + ], + ) + return parser + + +def build_preprocess_arg_parser() -> argparse.ArgumentParser: + try: + import configargparse + + parser = configargparse.ArgumentParser( + config_file_parser_class=configargparse.YAMLConfigFileParser, + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add( + "--config", + type=str, + is_config_file=True, + help="config file to aggregate options", + ) + except ImportError: + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--train_file", + help="Training set h5 file", + type=str, + default=None, + required=True, + ) + parser.add_argument( + "--valid_file", + help="Training set xyz file", + type=str, + default=None, + required=False, + ) + parser.add_argument( + "--num_process", + help="The user defined number of processes to use, as well as the number of files created.", + type=int, + default=int(os.cpu_count() / 4), + ) + parser.add_argument( + "--valid_fraction", + help="Fraction of training set used for validation", + type=float, + default=0.1, + required=False, + ) + parser.add_argument( + "--test_file", + help="Test set xyz file", + type=str, + default=None, + required=False, + ) + parser.add_argument( + "--work_dir", + help="set directory for all files and folders", + type=str, + default=".", + ) + parser.add_argument( + "--h5_prefix", + help="Prefix for h5 files when saving", + type=str, + default="", + ) + parser.add_argument( + "--r_max", help="distance cutoff (in Ang)", type=float, default=5.0 + ) + parser.add_argument( + "--config_type_weights", + help="String of dictionary containing the weights for each config type", + type=str, + default='{"Default":1.0}', + ) + parser.add_argument( + "--energy_key", + help="Key of reference energies in training xyz", + type=str, + default=DefaultKeys.ENERGY.value, + ) + parser.add_argument( + "--forces_key", + help="Key of reference forces in training xyz", + type=str, + default=DefaultKeys.FORCES.value, + ) + parser.add_argument( + "--virials_key", + help="Key of reference virials in training xyz", + type=str, + default=DefaultKeys.VIRIALS.value, + ) + parser.add_argument( + "--stress_key", + help="Key of reference stress in training xyz", + type=str, + default=DefaultKeys.STRESS.value, + ) + parser.add_argument( + "--dipole_key", + help="Key of reference dipoles in training xyz", + type=str, + default=DefaultKeys.DIPOLE.value, + ) + parser.add_argument( + "--charges_key", + help="Key of atomic charges in training xyz", + type=str, + default=DefaultKeys.CHARGES.value, + ) + parser.add_argument( + "--atomic_numbers", + help="List of atomic numbers", + type=str, + default=None, + required=False, + ) + parser.add_argument( + "--compute_statistics", + help="Compute statistics for the dataset", + action="store_true", + default=False, + ) + parser.add_argument( + "--batch_size", + help="batch size to compute average number of neighbours", + type=int, + default=16, + ) + + parser.add_argument( + "--scaling", + help="type of scaling to the output", + type=str, + default="rms_forces_scaling", + choices=["std_scaling", "rms_forces_scaling", "no_scaling"], + ) + parser.add_argument( + "--E0s", + help="Dictionary of isolated atom energies", + type=str, + default=None, + required=False, + ) + parser.add_argument( + "--shuffle", + help="Shuffle the training dataset", + type=str2bool, + default=True, + ) + parser.add_argument( + "--seed", + help="Random seed for splitting training and validation sets", + type=int, + default=123, + ) + parser.add_argument( + "--head_key", + help="Key of head in training xyz", + type=str, + default=DefaultKeys.HEAD.value, + ) + parser.add_argument( + "--heads", + help="Dict of heads: containing individual files and E0s", + type=str, + default=None, + required=False, + ) + return parser + + +def check_float_or_none(value: str) -> Optional[float]: + try: + return float(value) + except ValueError: + if value != "None": + raise argparse.ArgumentTypeError( + f"{value} is an invalid value (float or None)" + ) from None + return None + + +def str2bool(value): + if isinstance(value, bool): + return value + if value.lower() in ("yes", "true", "t", "y", "1"): + return True + if value.lower() in ("no", "false", "f", "n", "0"): + return False + raise argparse.ArgumentTypeError("Boolean value expected.") diff --git a/mace-bench/3rdparty/mace/mace/tools/arg_parser_tools.py b/mace-bench/3rdparty/mace/mace/tools/arg_parser_tools.py index 21a23ff..be714b2 100644 --- a/mace-bench/3rdparty/mace/mace/tools/arg_parser_tools.py +++ b/mace-bench/3rdparty/mace/mace/tools/arg_parser_tools.py @@ -1,122 +1,122 @@ -import logging -import os - -from e3nn import o3 - - -def check_args(args): - """ - Check input arguments, update them if necessary for valid and consistent inputs, and return a tuple containing - the (potentially) modified args and a list of log messages. - """ - log_messages = [] - - # Directories - # Use work_dir for all other directories as well, unless they were specified by the user - if args.log_dir is None: - args.log_dir = os.path.join(args.work_dir, "logs") - if args.model_dir is None: - args.model_dir = args.work_dir - if args.checkpoints_dir is None: - args.checkpoints_dir = os.path.join(args.work_dir, "checkpoints") - if args.results_dir is None: - args.results_dir = os.path.join(args.work_dir, "results") - if args.downloads_dir is None: - args.downloads_dir = os.path.join(args.work_dir, "downloads") - - # Model - # Check if hidden_irreps, num_channels and max_L are consistent - if args.hidden_irreps is None and args.num_channels is None and args.max_L is None: - args.hidden_irreps, args.num_channels, args.max_L = "128x0e + 128x1o", 128, 1 - elif ( - args.hidden_irreps is not None - and args.num_channels is not None - and args.max_L is not None - ): - args.hidden_irreps = o3.Irreps( - (args.num_channels * o3.Irreps.spherical_harmonics(args.max_L)) - .sort() - .irreps.simplify() - ) - log_messages.append( - ( - "All of hidden_irreps, num_channels and max_L are specified", - logging.WARNING, - ) - ) - log_messages.append( - ( - f"Using num_channels and max_L to create hidden_irreps: {args.hidden_irreps}.", - logging.WARNING, - ) - ) - assert ( - len({irrep.mul for irrep in o3.Irreps(args.hidden_irreps)}) == 1 - ), "All channels must have the same dimension, use the num_channels and max_L keywords to specify the number of channels and the maximum L" - elif args.num_channels is not None and args.max_L is not None: - assert args.num_channels > 0, "num_channels must be positive integer" - assert args.max_L >= 0, "max_L must be non-negative integer" - args.hidden_irreps = o3.Irreps( - (args.num_channels * o3.Irreps.spherical_harmonics(args.max_L)) - .sort() - .irreps.simplify() - ) - assert ( - len({irrep.mul for irrep in o3.Irreps(args.hidden_irreps)}) == 1 - ), "All channels must have the same dimension, use the num_channels and max_L keywords to specify the number of channels and the maximum L" - elif args.hidden_irreps is not None: - assert ( - len({irrep.mul for irrep in o3.Irreps(args.hidden_irreps)}) == 1 - ), "All channels must have the same dimension, use the num_channels and max_L keywords to specify the number of channels and the maximum L" - - args.num_channels = list( - {irrep.mul for irrep in o3.Irreps(args.hidden_irreps)} - )[0] - args.max_L = o3.Irreps(args.hidden_irreps).lmax - elif args.max_L is not None and args.num_channels is None: - assert args.max_L >= 0, "max_L must be non-negative integer" - args.num_channels = 128 - args.hidden_irreps = o3.Irreps( - (args.num_channels * o3.Irreps.spherical_harmonics(args.max_L)) - .sort() - .irreps.simplify() - ) - elif args.max_L is None and args.num_channels is not None: - assert args.num_channels > 0, "num_channels must be positive integer" - args.max_L = 1 - args.hidden_irreps = o3.Irreps( - (args.num_channels * o3.Irreps.spherical_harmonics(args.max_L)) - .sort() - .irreps.simplify() - ) - - # Loss and optimization - # Check Stage Two loss start - if args.start_swa is not None: - args.swa = True - log_messages.append( - ( - "Stage Two is activated as start_stage_two was defined", - logging.INFO, - ) - ) - - if args.swa: - if args.start_swa is None: - args.start_swa = max(1, args.max_num_epochs // 4 * 3) - if args.start_swa > args.max_num_epochs: - log_messages.append( - ( - f"start_stage_two must be less than max_num_epochs, got {args.start_swa} > {args.max_num_epochs}", - logging.WARNING, - ) - ) - log_messages.append( - ( - "Stage Two will not start, as start_stage_two > max_num_epochs", - logging.WARNING, - ) - ) - args.swa = False - - return args, log_messages +import logging +import os + +from e3nn import o3 + + +def check_args(args): + """ + Check input arguments, update them if necessary for valid and consistent inputs, and return a tuple containing + the (potentially) modified args and a list of log messages. + """ + log_messages = [] + + # Directories + # Use work_dir for all other directories as well, unless they were specified by the user + if args.log_dir is None: + args.log_dir = os.path.join(args.work_dir, "logs") + if args.model_dir is None: + args.model_dir = args.work_dir + if args.checkpoints_dir is None: + args.checkpoints_dir = os.path.join(args.work_dir, "checkpoints") + if args.results_dir is None: + args.results_dir = os.path.join(args.work_dir, "results") + if args.downloads_dir is None: + args.downloads_dir = os.path.join(args.work_dir, "downloads") + + # Model + # Check if hidden_irreps, num_channels and max_L are consistent + if args.hidden_irreps is None and args.num_channels is None and args.max_L is None: + args.hidden_irreps, args.num_channels, args.max_L = "128x0e + 128x1o", 128, 1 + elif ( + args.hidden_irreps is not None + and args.num_channels is not None + and args.max_L is not None + ): + args.hidden_irreps = o3.Irreps( + (args.num_channels * o3.Irreps.spherical_harmonics(args.max_L)) + .sort() + .irreps.simplify() + ) + log_messages.append( + ( + "All of hidden_irreps, num_channels and max_L are specified", + logging.WARNING, + ) + ) + log_messages.append( + ( + f"Using num_channels and max_L to create hidden_irreps: {args.hidden_irreps}.", + logging.WARNING, + ) + ) + assert ( + len({irrep.mul for irrep in o3.Irreps(args.hidden_irreps)}) == 1 + ), "All channels must have the same dimension, use the num_channels and max_L keywords to specify the number of channels and the maximum L" + elif args.num_channels is not None and args.max_L is not None: + assert args.num_channels > 0, "num_channels must be positive integer" + assert args.max_L >= 0, "max_L must be non-negative integer" + args.hidden_irreps = o3.Irreps( + (args.num_channels * o3.Irreps.spherical_harmonics(args.max_L)) + .sort() + .irreps.simplify() + ) + assert ( + len({irrep.mul for irrep in o3.Irreps(args.hidden_irreps)}) == 1 + ), "All channels must have the same dimension, use the num_channels and max_L keywords to specify the number of channels and the maximum L" + elif args.hidden_irreps is not None: + assert ( + len({irrep.mul for irrep in o3.Irreps(args.hidden_irreps)}) == 1 + ), "All channels must have the same dimension, use the num_channels and max_L keywords to specify the number of channels and the maximum L" + + args.num_channels = list( + {irrep.mul for irrep in o3.Irreps(args.hidden_irreps)} + )[0] + args.max_L = o3.Irreps(args.hidden_irreps).lmax + elif args.max_L is not None and args.num_channels is None: + assert args.max_L >= 0, "max_L must be non-negative integer" + args.num_channels = 128 + args.hidden_irreps = o3.Irreps( + (args.num_channels * o3.Irreps.spherical_harmonics(args.max_L)) + .sort() + .irreps.simplify() + ) + elif args.max_L is None and args.num_channels is not None: + assert args.num_channels > 0, "num_channels must be positive integer" + args.max_L = 1 + args.hidden_irreps = o3.Irreps( + (args.num_channels * o3.Irreps.spherical_harmonics(args.max_L)) + .sort() + .irreps.simplify() + ) + + # Loss and optimization + # Check Stage Two loss start + if args.start_swa is not None: + args.swa = True + log_messages.append( + ( + "Stage Two is activated as start_stage_two was defined", + logging.INFO, + ) + ) + + if args.swa: + if args.start_swa is None: + args.start_swa = max(1, args.max_num_epochs // 4 * 3) + if args.start_swa > args.max_num_epochs: + log_messages.append( + ( + f"start_stage_two must be less than max_num_epochs, got {args.start_swa} > {args.max_num_epochs}", + logging.WARNING, + ) + ) + log_messages.append( + ( + "Stage Two will not start, as start_stage_two > max_num_epochs", + logging.WARNING, + ) + ) + args.swa = False + + return args, log_messages diff --git a/mace-bench/3rdparty/mace/mace/tools/cg.py b/mace-bench/3rdparty/mace/mace/tools/cg.py index 5570be0..471adac 100644 --- a/mace-bench/3rdparty/mace/mace/tools/cg.py +++ b/mace-bench/3rdparty/mace/mace/tools/cg.py @@ -1,211 +1,211 @@ -########################################################################################### -# Higher Order Real Clebsch Gordan (based on e3nn by Mario Geiger) -# Authors: Ilyes Batatia -# This program is distributed under the MIT License (see MIT.md) -########################################################################################### - -import collections -import itertools -import os -from typing import Iterator, List, Union - -import numpy as np -import torch -from e3nn import o3 - -try: - import cuequivariance as cue - - CUET_AVAILABLE = True -except ImportError: - CUET_AVAILABLE = False - -USE_CUEQ_CG = os.environ.get("MACE_USE_CUEQ_CG", "0").lower() in ( - "1", - "true", - "yes", - "y", -) - -_TP = collections.namedtuple("_TP", "op, args") -_INPUT = collections.namedtuple("_INPUT", "tensor, start, stop") - - -def _wigner_nj( - irrepss: List[o3.Irreps], - normalization: str = "component", - filter_ir_mid=None, - dtype=None, -): - irrepss = [o3.Irreps(irreps) for irreps in irrepss] - if filter_ir_mid is not None: - filter_ir_mid = [o3.Irrep(ir) for ir in filter_ir_mid] - - if len(irrepss) == 1: - (irreps,) = irrepss - ret = [] - e = torch.eye(irreps.dim, dtype=dtype) - i = 0 - for mul, ir in irreps: - for _ in range(mul): - sl = slice(i, i + ir.dim) - ret += [(ir, _INPUT(0, sl.start, sl.stop), e[sl])] - i += ir.dim - return ret - - *irrepss_left, irreps_right = irrepss - ret = [] - for ir_left, path_left, C_left in _wigner_nj( - irrepss_left, - normalization=normalization, - filter_ir_mid=filter_ir_mid, - dtype=dtype, - ): - i = 0 - for mul, ir in irreps_right: - for ir_out in ir_left * ir: - if filter_ir_mid is not None and ir_out not in filter_ir_mid: - continue - - C = o3.wigner_3j(ir_out.l, ir_left.l, ir.l, dtype=dtype) - if normalization == "component": - C *= ir_out.dim**0.5 - if normalization == "norm": - C *= ir_left.dim**0.5 * ir.dim**0.5 - - C = torch.einsum("jk,ijl->ikl", C_left.flatten(1), C) - C = C.reshape( - ir_out.dim, *(irreps.dim for irreps in irrepss_left), ir.dim - ) - for u in range(mul): - E = torch.zeros( - ir_out.dim, - *(irreps.dim for irreps in irrepss_left), - irreps_right.dim, - dtype=dtype, - ) - sl = slice(i + u * ir.dim, i + (u + 1) * ir.dim) - E[..., sl] = C - ret += [ - ( - ir_out, - _TP( - op=(ir_left, ir, ir_out), - args=( - path_left, - _INPUT(len(irrepss_left), sl.start, sl.stop), - ), - ), - E, - ) - ] - i += mul * ir.dim - return sorted(ret, key=lambda x: x[0]) - - -def U_matrix_real( - irreps_in: Union[str, o3.Irreps], - irreps_out: Union[str, o3.Irreps], - correlation: int, - normalization: str = "component", - filter_ir_mid=None, - dtype=None, - use_cueq_cg=None, -): - irreps_out = o3.Irreps(irreps_out) - irrepss = [o3.Irreps(irreps_in)] * correlation - - if correlation == 4: - filter_ir_mid = [(i, 1 if i % 2 == 0 else -1) for i in range(12)] - - if use_cueq_cg is None: - use_cueq_cg = USE_CUEQ_CG - if use_cueq_cg and CUET_AVAILABLE: - return compute_U_cueq(irreps_in, irreps_out=irreps_out, correlation=correlation) - - try: - wigners = _wigner_nj(irrepss, normalization, filter_ir_mid, dtype) - except NotImplementedError as e: - if CUET_AVAILABLE: - return compute_U_cueq( - irreps_in, irreps_out=irreps_out, correlation=correlation - ) - raise NotImplementedError( - "The requested Clebsch-Gordan coefficients are not implemented, please install cuequivariance; pip install cuequivariance" - ) from e - - current_ir = wigners[0][0] - out = [] - stack = torch.tensor([]) - - for ir, _, base_o3 in wigners: - if ir in irreps_out and ir == current_ir: - stack = torch.cat((stack, base_o3.squeeze().unsqueeze(-1)), dim=-1) - last_ir = current_ir - elif ir in irreps_out and ir != current_ir: - if len(stack) != 0: - out += [last_ir, stack] - stack = base_o3.squeeze().unsqueeze(-1) - current_ir, last_ir = ir, ir - else: - current_ir = ir - out += [last_ir, stack] - return out - - -if CUET_AVAILABLE: - - def compute_U_cueq(irreps_in, irreps_out, correlation=2): - U = [] - irreps_in = cue.Irreps(O3_e3nn, str(irreps_in)) - irreps_out = cue.Irreps(O3_e3nn, str(irreps_out)) - for _, ir in irreps_out: - ir_str = str(ir) - U.append(ir_str) - U_matrix = cue.reduced_symmetric_tensor_product_basis( - irreps_in, correlation, keep_ir=ir, layout=cue.ir_mul - ).array - U_matrix = U_matrix.reshape(ir.dim, *([irreps_in.dim] * correlation), -1) - if ir.dim == 1: - U_matrix = U_matrix[0] - U.append(torch.tensor(U_matrix)) - return U - - class O3_e3nn(cue.O3): - def __mul__( # pylint: disable=no-self-argument - rep1: "O3_e3nn", rep2: "O3_e3nn" - ) -> Iterator["O3_e3nn"]: - return [O3_e3nn(l=ir.l, p=ir.p) for ir in cue.O3.__mul__(rep1, rep2)] - - @classmethod - def clebsch_gordan( - cls, rep1: "O3_e3nn", rep2: "O3_e3nn", rep3: "O3_e3nn" - ) -> np.ndarray: - rep1, rep2, rep3 = cls._from(rep1), cls._from(rep2), cls._from(rep3) - - if rep1.p * rep2.p == rep3.p: - return o3.wigner_3j(rep1.l, rep2.l, rep3.l).numpy()[None] * np.sqrt( - rep3.dim - ) - return np.zeros((0, rep1.dim, rep2.dim, rep3.dim)) - - def __lt__( # pylint: disable=no-self-argument - rep1: "O3_e3nn", rep2: "O3_e3nn" - ) -> bool: - rep2 = rep1._from(rep2) - return (rep1.l, rep1.p) < (rep2.l, rep2.p) - - @classmethod - def iterator(cls) -> Iterator["O3_e3nn"]: - for l in itertools.count(0): - yield O3_e3nn(l=l, p=1 * (-1) ** l) - yield O3_e3nn(l=l, p=-1 * (-1) ** l) - -else: - - class O3_e3nn: - pass - - print( - "cuequivariance or cuequivariance_torch is not available. Cuequivariance acceleration will be disabled." - ) +########################################################################################### +# Higher Order Real Clebsch Gordan (based on e3nn by Mario Geiger) +# Authors: Ilyes Batatia +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +import collections +import itertools +import os +from typing import Iterator, List, Union + +import numpy as np +import torch +from e3nn import o3 + +try: + import cuequivariance as cue + + CUET_AVAILABLE = True +except ImportError: + CUET_AVAILABLE = False + +USE_CUEQ_CG = os.environ.get("MACE_USE_CUEQ_CG", "0").lower() in ( + "1", + "true", + "yes", + "y", +) + +_TP = collections.namedtuple("_TP", "op, args") +_INPUT = collections.namedtuple("_INPUT", "tensor, start, stop") + + +def _wigner_nj( + irrepss: List[o3.Irreps], + normalization: str = "component", + filter_ir_mid=None, + dtype=None, +): + irrepss = [o3.Irreps(irreps) for irreps in irrepss] + if filter_ir_mid is not None: + filter_ir_mid = [o3.Irrep(ir) for ir in filter_ir_mid] + + if len(irrepss) == 1: + (irreps,) = irrepss + ret = [] + e = torch.eye(irreps.dim, dtype=dtype) + i = 0 + for mul, ir in irreps: + for _ in range(mul): + sl = slice(i, i + ir.dim) + ret += [(ir, _INPUT(0, sl.start, sl.stop), e[sl])] + i += ir.dim + return ret + + *irrepss_left, irreps_right = irrepss + ret = [] + for ir_left, path_left, C_left in _wigner_nj( + irrepss_left, + normalization=normalization, + filter_ir_mid=filter_ir_mid, + dtype=dtype, + ): + i = 0 + for mul, ir in irreps_right: + for ir_out in ir_left * ir: + if filter_ir_mid is not None and ir_out not in filter_ir_mid: + continue + + C = o3.wigner_3j(ir_out.l, ir_left.l, ir.l, dtype=dtype) + if normalization == "component": + C *= ir_out.dim**0.5 + if normalization == "norm": + C *= ir_left.dim**0.5 * ir.dim**0.5 + + C = torch.einsum("jk,ijl->ikl", C_left.flatten(1), C) + C = C.reshape( + ir_out.dim, *(irreps.dim for irreps in irrepss_left), ir.dim + ) + for u in range(mul): + E = torch.zeros( + ir_out.dim, + *(irreps.dim for irreps in irrepss_left), + irreps_right.dim, + dtype=dtype, + ) + sl = slice(i + u * ir.dim, i + (u + 1) * ir.dim) + E[..., sl] = C + ret += [ + ( + ir_out, + _TP( + op=(ir_left, ir, ir_out), + args=( + path_left, + _INPUT(len(irrepss_left), sl.start, sl.stop), + ), + ), + E, + ) + ] + i += mul * ir.dim + return sorted(ret, key=lambda x: x[0]) + + +def U_matrix_real( + irreps_in: Union[str, o3.Irreps], + irreps_out: Union[str, o3.Irreps], + correlation: int, + normalization: str = "component", + filter_ir_mid=None, + dtype=None, + use_cueq_cg=None, +): + irreps_out = o3.Irreps(irreps_out) + irrepss = [o3.Irreps(irreps_in)] * correlation + + if correlation == 4: + filter_ir_mid = [(i, 1 if i % 2 == 0 else -1) for i in range(12)] + + if use_cueq_cg is None: + use_cueq_cg = USE_CUEQ_CG + if use_cueq_cg and CUET_AVAILABLE: + return compute_U_cueq(irreps_in, irreps_out=irreps_out, correlation=correlation) + + try: + wigners = _wigner_nj(irrepss, normalization, filter_ir_mid, dtype) + except NotImplementedError as e: + if CUET_AVAILABLE: + return compute_U_cueq( + irreps_in, irreps_out=irreps_out, correlation=correlation + ) + raise NotImplementedError( + "The requested Clebsch-Gordan coefficients are not implemented, please install cuequivariance; pip install cuequivariance" + ) from e + + current_ir = wigners[0][0] + out = [] + stack = torch.tensor([]) + + for ir, _, base_o3 in wigners: + if ir in irreps_out and ir == current_ir: + stack = torch.cat((stack, base_o3.squeeze().unsqueeze(-1)), dim=-1) + last_ir = current_ir + elif ir in irreps_out and ir != current_ir: + if len(stack) != 0: + out += [last_ir, stack] + stack = base_o3.squeeze().unsqueeze(-1) + current_ir, last_ir = ir, ir + else: + current_ir = ir + out += [last_ir, stack] + return out + + +if CUET_AVAILABLE: + + def compute_U_cueq(irreps_in, irreps_out, correlation=2): + U = [] + irreps_in = cue.Irreps(O3_e3nn, str(irreps_in)) + irreps_out = cue.Irreps(O3_e3nn, str(irreps_out)) + for _, ir in irreps_out: + ir_str = str(ir) + U.append(ir_str) + U_matrix = cue.reduced_symmetric_tensor_product_basis( + irreps_in, correlation, keep_ir=ir, layout=cue.ir_mul + ).array + U_matrix = U_matrix.reshape(ir.dim, *([irreps_in.dim] * correlation), -1) + if ir.dim == 1: + U_matrix = U_matrix[0] + U.append(torch.tensor(U_matrix)) + return U + + class O3_e3nn(cue.O3): + def __mul__( # pylint: disable=no-self-argument + rep1: "O3_e3nn", rep2: "O3_e3nn" + ) -> Iterator["O3_e3nn"]: + return [O3_e3nn(l=ir.l, p=ir.p) for ir in cue.O3.__mul__(rep1, rep2)] + + @classmethod + def clebsch_gordan( + cls, rep1: "O3_e3nn", rep2: "O3_e3nn", rep3: "O3_e3nn" + ) -> np.ndarray: + rep1, rep2, rep3 = cls._from(rep1), cls._from(rep2), cls._from(rep3) + + if rep1.p * rep2.p == rep3.p: + return o3.wigner_3j(rep1.l, rep2.l, rep3.l).numpy()[None] * np.sqrt( + rep3.dim + ) + return np.zeros((0, rep1.dim, rep2.dim, rep3.dim)) + + def __lt__( # pylint: disable=no-self-argument + rep1: "O3_e3nn", rep2: "O3_e3nn" + ) -> bool: + rep2 = rep1._from(rep2) + return (rep1.l, rep1.p) < (rep2.l, rep2.p) + + @classmethod + def iterator(cls) -> Iterator["O3_e3nn"]: + for l in itertools.count(0): + yield O3_e3nn(l=l, p=1 * (-1) ** l) + yield O3_e3nn(l=l, p=-1 * (-1) ** l) + +else: + + class O3_e3nn: + pass + + print( + "cuequivariance or cuequivariance_torch is not available. Cuequivariance acceleration will be disabled." + ) diff --git a/mace-bench/3rdparty/mace/mace/tools/checkpoint.py b/mace-bench/3rdparty/mace/mace/tools/checkpoint.py index 2925140..81161cc 100644 --- a/mace-bench/3rdparty/mace/mace/tools/checkpoint.py +++ b/mace-bench/3rdparty/mace/mace/tools/checkpoint.py @@ -1,227 +1,227 @@ -########################################################################################### -# Checkpointing -# Authors: Gregor Simm -# This program is distributed under the MIT License (see MIT.md) -########################################################################################### - -import dataclasses -import logging -import os -import re -from typing import Dict, List, Optional, Tuple - -import torch - -from .torch_tools import TensorDict - -Checkpoint = Dict[str, TensorDict] - - -@dataclasses.dataclass -class CheckpointState: - model: torch.nn.Module - optimizer: torch.optim.Optimizer - lr_scheduler: torch.optim.lr_scheduler.ExponentialLR - - -class CheckpointBuilder: - @staticmethod - def create_checkpoint(state: CheckpointState) -> Checkpoint: - return { - "model": state.model.state_dict(), - "optimizer": state.optimizer.state_dict(), - "lr_scheduler": state.lr_scheduler.state_dict(), - } - - @staticmethod - def load_checkpoint( - state: CheckpointState, checkpoint: Checkpoint, strict: bool - ) -> None: - state.model.load_state_dict(checkpoint["model"], strict=strict) # type: ignore - state.optimizer.load_state_dict(checkpoint["optimizer"]) - state.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"]) - - -@dataclasses.dataclass -class CheckpointPathInfo: - path: str - tag: str - epochs: int - swa: bool - - -class CheckpointIO: - def __init__( - self, directory: str, tag: str, keep: bool = False, swa_start: int = None - ) -> None: - self.directory = directory - self.tag = tag - self.keep = keep - self.old_path: Optional[str] = None - self.swa_start = swa_start - - self._epochs_string = "_epoch-" - self._filename_extension = "pt" - - def _get_checkpoint_filename(self, epochs: int, swa_start=None) -> str: - if swa_start is not None and epochs >= swa_start: - return ( - self.tag - + self._epochs_string - + str(epochs) - + "_swa" - + "." - + self._filename_extension - ) - return ( - self.tag - + self._epochs_string - + str(epochs) - + "." - + self._filename_extension - ) - - def _list_file_paths(self) -> List[str]: - if not os.path.isdir(self.directory): - return [] - all_paths = [ - os.path.join(self.directory, f) for f in os.listdir(self.directory) - ] - return [path for path in all_paths if os.path.isfile(path)] - - def _parse_checkpoint_path(self, path: str) -> Optional[CheckpointPathInfo]: - filename = os.path.basename(path) - regex = re.compile( - rf"^(?P.+){self._epochs_string}(?P\d+)\.{self._filename_extension}$" - ) - regex2 = re.compile( - rf"^(?P.+){self._epochs_string}(?P\d+)_swa\.{self._filename_extension}$" - ) - match = regex.match(filename) - match2 = regex2.match(filename) - swa = False - if not match: - if not match2: - return None - match = match2 - swa = True - - return CheckpointPathInfo( - path=path, - tag=match.group("tag"), - epochs=int(match.group("epochs")), - swa=swa, - ) - - def _get_latest_checkpoint_path(self, swa) -> Optional[str]: - all_file_paths = self._list_file_paths() - checkpoint_info_list = [ - self._parse_checkpoint_path(path) for path in all_file_paths - ] - selected_checkpoint_info_list = [ - info for info in checkpoint_info_list if info and info.tag == self.tag - ] - - if len(selected_checkpoint_info_list) == 0: - logging.warning( - f"Cannot find checkpoint with tag '{self.tag}' in '{self.directory}'" - ) - return None - - selected_checkpoint_info_list_swa = [] - selected_checkpoint_info_list_no_swa = [] - - for ckp in selected_checkpoint_info_list: - if ckp.swa: - selected_checkpoint_info_list_swa.append(ckp) - else: - selected_checkpoint_info_list_no_swa.append(ckp) - if swa: - try: - latest_checkpoint_info = max( - selected_checkpoint_info_list_swa, key=lambda info: info.epochs - ) - except ValueError: - logging.warning( - "No SWA checkpoint found, while SWA is enabled. Compare the swa_start parameter and the latest checkpoint." - ) - else: - latest_checkpoint_info = max( - selected_checkpoint_info_list_no_swa, key=lambda info: info.epochs - ) - return latest_checkpoint_info.path - - def save( - self, checkpoint: Checkpoint, epochs: int, keep_last: bool = False - ) -> None: - if not self.keep and self.old_path and not keep_last: - logging.debug(f"Deleting old checkpoint file: {self.old_path}") - os.remove(self.old_path) - - filename = self._get_checkpoint_filename(epochs, self.swa_start) - path = os.path.join(self.directory, filename) - logging.debug(f"Saving checkpoint: {path}") - os.makedirs(self.directory, exist_ok=True) - torch.save(obj=checkpoint, f=path) - self.old_path = path - - def load_latest( - self, swa: Optional[bool] = False, device: Optional[torch.device] = None - ) -> Optional[Tuple[Checkpoint, int]]: - path = self._get_latest_checkpoint_path(swa=swa) - if path is None: - return None - - return self.load(path, device=device) - - def load( - self, path: str, device: Optional[torch.device] = None - ) -> Tuple[Checkpoint, int]: - checkpoint_info = self._parse_checkpoint_path(path) - - if checkpoint_info is None: - raise RuntimeError(f"Cannot find path '{path}'") - - logging.info(f"Loading checkpoint: {checkpoint_info.path}") - return ( - torch.load(f=checkpoint_info.path, map_location=device), - checkpoint_info.epochs, - ) - - -class CheckpointHandler: - def __init__(self, *args, **kwargs) -> None: - self.io = CheckpointIO(*args, **kwargs) - self.builder = CheckpointBuilder() - - def save( - self, state: CheckpointState, epochs: int, keep_last: bool = False - ) -> None: - checkpoint = self.builder.create_checkpoint(state) - self.io.save(checkpoint, epochs, keep_last) - - def load_latest( - self, - state: CheckpointState, - swa: Optional[bool] = False, - device: Optional[torch.device] = None, - strict=False, - ) -> Optional[int]: - result = self.io.load_latest(swa=swa, device=device) - if result is None: - return None - - checkpoint, epochs = result - self.builder.load_checkpoint(state=state, checkpoint=checkpoint, strict=strict) - return epochs - - def load( - self, - state: CheckpointState, - path: str, - strict=False, - device: Optional[torch.device] = None, - ) -> int: - checkpoint, epochs = self.io.load(path, device=device) - self.builder.load_checkpoint(state=state, checkpoint=checkpoint, strict=strict) - return epochs +########################################################################################### +# Checkpointing +# Authors: Gregor Simm +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +import dataclasses +import logging +import os +import re +from typing import Dict, List, Optional, Tuple + +import torch + +from .torch_tools import TensorDict + +Checkpoint = Dict[str, TensorDict] + + +@dataclasses.dataclass +class CheckpointState: + model: torch.nn.Module + optimizer: torch.optim.Optimizer + lr_scheduler: torch.optim.lr_scheduler.ExponentialLR + + +class CheckpointBuilder: + @staticmethod + def create_checkpoint(state: CheckpointState) -> Checkpoint: + return { + "model": state.model.state_dict(), + "optimizer": state.optimizer.state_dict(), + "lr_scheduler": state.lr_scheduler.state_dict(), + } + + @staticmethod + def load_checkpoint( + state: CheckpointState, checkpoint: Checkpoint, strict: bool + ) -> None: + state.model.load_state_dict(checkpoint["model"], strict=strict) # type: ignore + state.optimizer.load_state_dict(checkpoint["optimizer"]) + state.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"]) + + +@dataclasses.dataclass +class CheckpointPathInfo: + path: str + tag: str + epochs: int + swa: bool + + +class CheckpointIO: + def __init__( + self, directory: str, tag: str, keep: bool = False, swa_start: int = None + ) -> None: + self.directory = directory + self.tag = tag + self.keep = keep + self.old_path: Optional[str] = None + self.swa_start = swa_start + + self._epochs_string = "_epoch-" + self._filename_extension = "pt" + + def _get_checkpoint_filename(self, epochs: int, swa_start=None) -> str: + if swa_start is not None and epochs >= swa_start: + return ( + self.tag + + self._epochs_string + + str(epochs) + + "_swa" + + "." + + self._filename_extension + ) + return ( + self.tag + + self._epochs_string + + str(epochs) + + "." + + self._filename_extension + ) + + def _list_file_paths(self) -> List[str]: + if not os.path.isdir(self.directory): + return [] + all_paths = [ + os.path.join(self.directory, f) for f in os.listdir(self.directory) + ] + return [path for path in all_paths if os.path.isfile(path)] + + def _parse_checkpoint_path(self, path: str) -> Optional[CheckpointPathInfo]: + filename = os.path.basename(path) + regex = re.compile( + rf"^(?P.+){self._epochs_string}(?P\d+)\.{self._filename_extension}$" + ) + regex2 = re.compile( + rf"^(?P.+){self._epochs_string}(?P\d+)_swa\.{self._filename_extension}$" + ) + match = regex.match(filename) + match2 = regex2.match(filename) + swa = False + if not match: + if not match2: + return None + match = match2 + swa = True + + return CheckpointPathInfo( + path=path, + tag=match.group("tag"), + epochs=int(match.group("epochs")), + swa=swa, + ) + + def _get_latest_checkpoint_path(self, swa) -> Optional[str]: + all_file_paths = self._list_file_paths() + checkpoint_info_list = [ + self._parse_checkpoint_path(path) for path in all_file_paths + ] + selected_checkpoint_info_list = [ + info for info in checkpoint_info_list if info and info.tag == self.tag + ] + + if len(selected_checkpoint_info_list) == 0: + logging.warning( + f"Cannot find checkpoint with tag '{self.tag}' in '{self.directory}'" + ) + return None + + selected_checkpoint_info_list_swa = [] + selected_checkpoint_info_list_no_swa = [] + + for ckp in selected_checkpoint_info_list: + if ckp.swa: + selected_checkpoint_info_list_swa.append(ckp) + else: + selected_checkpoint_info_list_no_swa.append(ckp) + if swa: + try: + latest_checkpoint_info = max( + selected_checkpoint_info_list_swa, key=lambda info: info.epochs + ) + except ValueError: + logging.warning( + "No SWA checkpoint found, while SWA is enabled. Compare the swa_start parameter and the latest checkpoint." + ) + else: + latest_checkpoint_info = max( + selected_checkpoint_info_list_no_swa, key=lambda info: info.epochs + ) + return latest_checkpoint_info.path + + def save( + self, checkpoint: Checkpoint, epochs: int, keep_last: bool = False + ) -> None: + if not self.keep and self.old_path and not keep_last: + logging.debug(f"Deleting old checkpoint file: {self.old_path}") + os.remove(self.old_path) + + filename = self._get_checkpoint_filename(epochs, self.swa_start) + path = os.path.join(self.directory, filename) + logging.debug(f"Saving checkpoint: {path}") + os.makedirs(self.directory, exist_ok=True) + torch.save(obj=checkpoint, f=path) + self.old_path = path + + def load_latest( + self, swa: Optional[bool] = False, device: Optional[torch.device] = None + ) -> Optional[Tuple[Checkpoint, int]]: + path = self._get_latest_checkpoint_path(swa=swa) + if path is None: + return None + + return self.load(path, device=device) + + def load( + self, path: str, device: Optional[torch.device] = None + ) -> Tuple[Checkpoint, int]: + checkpoint_info = self._parse_checkpoint_path(path) + + if checkpoint_info is None: + raise RuntimeError(f"Cannot find path '{path}'") + + logging.info(f"Loading checkpoint: {checkpoint_info.path}") + return ( + torch.load(f=checkpoint_info.path, map_location=device), + checkpoint_info.epochs, + ) + + +class CheckpointHandler: + def __init__(self, *args, **kwargs) -> None: + self.io = CheckpointIO(*args, **kwargs) + self.builder = CheckpointBuilder() + + def save( + self, state: CheckpointState, epochs: int, keep_last: bool = False + ) -> None: + checkpoint = self.builder.create_checkpoint(state) + self.io.save(checkpoint, epochs, keep_last) + + def load_latest( + self, + state: CheckpointState, + swa: Optional[bool] = False, + device: Optional[torch.device] = None, + strict=False, + ) -> Optional[int]: + result = self.io.load_latest(swa=swa, device=device) + if result is None: + return None + + checkpoint, epochs = result + self.builder.load_checkpoint(state=state, checkpoint=checkpoint, strict=strict) + return epochs + + def load( + self, + state: CheckpointState, + path: str, + strict=False, + device: Optional[torch.device] = None, + ) -> int: + checkpoint, epochs = self.io.load(path, device=device) + self.builder.load_checkpoint(state=state, checkpoint=checkpoint, strict=strict) + return epochs diff --git a/mace-bench/3rdparty/mace/mace/tools/compile.py b/mace-bench/3rdparty/mace/mace/tools/compile.py index 59f7450..0328206 100644 --- a/mace-bench/3rdparty/mace/mace/tools/compile.py +++ b/mace-bench/3rdparty/mace/mace/tools/compile.py @@ -1,95 +1,95 @@ -from contextlib import contextmanager -from functools import wraps -from typing import Callable, Tuple - -try: - import torch._dynamo as dynamo -except ImportError: - dynamo = None -from e3nn import get_optimization_defaults, set_optimization_defaults -from torch import autograd, nn -from torch.fx import symbolic_trace - -ModuleFactory = Callable[..., nn.Module] -TypeTuple = Tuple[type, ...] - - -@contextmanager -def disable_e3nn_codegen(): - """Context manager that disables the legacy PyTorch code generation used in e3nn.""" - init_val = get_optimization_defaults()["jit_script_fx"] - set_optimization_defaults(jit_script_fx=False) - yield - set_optimization_defaults(jit_script_fx=init_val) - - -def prepare(func: ModuleFactory, allow_autograd: bool = True) -> ModuleFactory: - """Function transform that prepares a MACE module for torch.compile - - Args: - func (ModuleFactory): A function that creates an nn.Module - allow_autograd (bool, optional): Force inductor compiler to inline call to - `torch.autograd.grad`. Defaults to True. - - Returns: - ModuleFactory: Decorated function that creates a torch.compile compatible module - """ - if allow_autograd: - dynamo.allow_in_graph(autograd.grad) - else: - dynamo.disallow_in_graph(autograd.grad) - - @wraps(func) - def wrapper(*args, **kwargs): - with disable_e3nn_codegen(): - model = func(*args, **kwargs) - - model = simplify(model) - return model - - return wrapper - - -_SIMPLIFY_REGISTRY = set() - - -def simplify_if_compile(module: nn.Module) -> nn.Module: - """Decorator to register a module for symbolic simplification - - The decorated module will be simplifed using `torch.fx.symbolic_trace`. - This constrains the module to not have any dynamic control flow, see: - - https://pytorch.org/docs/stable/fx.html#limitations-of-symbolic-tracing - - Args: - module (nn.Module): the module to register - - Returns: - nn.Module: registered module - """ - _SIMPLIFY_REGISTRY.add(module) - return module - - -def simplify(module: nn.Module) -> nn.Module: - """Recursively searches for registered modules to simplify with - `torch.fx.symbolic_trace` to support compiling with the PyTorch Dynamo compiler. - - Modules are registered with the `simplify_if_compile` decorator and - - Args: - module (nn.Module): the module to simplify - - Returns: - nn.Module: the simplified module - """ - simplify_types = tuple(_SIMPLIFY_REGISTRY) - - for name, child in module.named_children(): - if isinstance(child, simplify_types): - traced = symbolic_trace(child) - setattr(module, name, traced) - else: - simplify(child) - - return module +from contextlib import contextmanager +from functools import wraps +from typing import Callable, Tuple + +try: + import torch._dynamo as dynamo +except ImportError: + dynamo = None +from e3nn import get_optimization_defaults, set_optimization_defaults +from torch import autograd, nn +from torch.fx import symbolic_trace + +ModuleFactory = Callable[..., nn.Module] +TypeTuple = Tuple[type, ...] + + +@contextmanager +def disable_e3nn_codegen(): + """Context manager that disables the legacy PyTorch code generation used in e3nn.""" + init_val = get_optimization_defaults()["jit_script_fx"] + set_optimization_defaults(jit_script_fx=False) + yield + set_optimization_defaults(jit_script_fx=init_val) + + +def prepare(func: ModuleFactory, allow_autograd: bool = True) -> ModuleFactory: + """Function transform that prepares a MACE module for torch.compile + + Args: + func (ModuleFactory): A function that creates an nn.Module + allow_autograd (bool, optional): Force inductor compiler to inline call to + `torch.autograd.grad`. Defaults to True. + + Returns: + ModuleFactory: Decorated function that creates a torch.compile compatible module + """ + if allow_autograd: + dynamo.allow_in_graph(autograd.grad) + else: + dynamo.disallow_in_graph(autograd.grad) + + @wraps(func) + def wrapper(*args, **kwargs): + with disable_e3nn_codegen(): + model = func(*args, **kwargs) + + model = simplify(model) + return model + + return wrapper + + +_SIMPLIFY_REGISTRY = set() + + +def simplify_if_compile(module: nn.Module) -> nn.Module: + """Decorator to register a module for symbolic simplification + + The decorated module will be simplifed using `torch.fx.symbolic_trace`. + This constrains the module to not have any dynamic control flow, see: + + https://pytorch.org/docs/stable/fx.html#limitations-of-symbolic-tracing + + Args: + module (nn.Module): the module to register + + Returns: + nn.Module: registered module + """ + _SIMPLIFY_REGISTRY.add(module) + return module + + +def simplify(module: nn.Module) -> nn.Module: + """Recursively searches for registered modules to simplify with + `torch.fx.symbolic_trace` to support compiling with the PyTorch Dynamo compiler. + + Modules are registered with the `simplify_if_compile` decorator and + + Args: + module (nn.Module): the module to simplify + + Returns: + nn.Module: the simplified module + """ + simplify_types = tuple(_SIMPLIFY_REGISTRY) + + for name, child in module.named_children(): + if isinstance(child, simplify_types): + traced = symbolic_trace(child) + setattr(module, name, traced) + else: + simplify(child) + + return module diff --git a/mace-bench/3rdparty/mace/mace/tools/default_keys.py b/mace-bench/3rdparty/mace/mace/tools/default_keys.py index f062948..769867d 100644 --- a/mace-bench/3rdparty/mace/mace/tools/default_keys.py +++ b/mace-bench/3rdparty/mace/mace/tools/default_keys.py @@ -1,21 +1,21 @@ -from __future__ import annotations - -from enum import Enum - - -class DefaultKeys(Enum): - ENERGY = "REF_energy" - FORCES = "REF_forces" - STRESS = "REF_stress" - VIRIALS = "REF_virials" - DIPOLE = "dipole" - HEAD = "head" - CHARGES = "REF_charges" - - @staticmethod - def keydict() -> dict[str, str]: - key_dict = {} - for member in DefaultKeys: - key_name = f"{member.name.lower()}_key" - key_dict[key_name] = member.value - return key_dict +from __future__ import annotations + +from enum import Enum + + +class DefaultKeys(Enum): + ENERGY = "REF_energy" + FORCES = "REF_forces" + STRESS = "REF_stress" + VIRIALS = "REF_virials" + DIPOLE = "dipole" + HEAD = "head" + CHARGES = "REF_charges" + + @staticmethod + def keydict() -> dict[str, str]: + key_dict = {} + for member in DefaultKeys: + key_name = f"{member.name.lower()}_key" + key_dict[key_name] = member.value + return key_dict diff --git a/mace-bench/3rdparty/mace/mace/tools/fairchem_dataset/__init__.py b/mace-bench/3rdparty/mace/mace/tools/fairchem_dataset/__init__.py index fb7c72a..5163777 100644 --- a/mace-bench/3rdparty/mace/mace/tools/fairchem_dataset/__init__.py +++ b/mace-bench/3rdparty/mace/mace/tools/fairchem_dataset/__init__.py @@ -1,3 +1,3 @@ -from .lmdb_dataset_tools import AseDBDataset - -__all__ = ["AseDBDataset"] +from .lmdb_dataset_tools import AseDBDataset + +__all__ = ["AseDBDataset"] diff --git a/mace-bench/3rdparty/mace/mace/tools/fairchem_dataset/__pycache__/__init__.cpython-310.pyc b/mace-bench/3rdparty/mace/mace/tools/fairchem_dataset/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index e9ea4f9d9d0a340ef61b4b146cfb85cc56ecf628..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 274 zcmYjLI}XAy3{46W2$h9_J9K~ytO)U2SP&9h7b`S@v`U&P2}8LN10$E|%ET3zNLwNC zBtP5l*|PO;I6xkc=Uwc4e$~xC0^cls*s&*u7?w!k7-PhO3yuTUIijfh#@>u@6|OAb zuD8@t%blrL_faPvNySdUS{0nu(l%bz1E5j@+>z<}*@ixX2(36#^d_jzQtGROPYRYfU^V9)Mi?Zs^m1bAT7xnkH J{vbqy;0v@^O4R@W diff --git a/mace-bench/3rdparty/mace/mace/tools/fairchem_dataset/__pycache__/__init__.cpython-313.pyc b/mace-bench/3rdparty/mace/mace/tools/fairchem_dataset/__pycache__/__init__.cpython-313.pyc deleted file mode 100644 index 1a9234ca0cdf6d1a333108228eaf3d0fb41ca264..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 276 zcmey&%ge<81oJIdWOxGU#~=<2FhLog#ej^d48aV+jNS}hj75wJ48ctLj73c8%$h7O z8G(|TjJJ3ki&I^kToOwXi&IOAn1O~wAe^MsyyOgh^p|m|2vZk(vv$O+P+9GcU6wK3=b&@)n0pZhlH>PO4oIC(t&KlZyp_#0O?Z VM#j4gIuE!dFK|gUvKO%f1p$vrOrih) diff --git a/mace-bench/3rdparty/mace/mace/tools/fairchem_dataset/__pycache__/lmdb_dataset_tools.cpython-310.pyc b/mace-bench/3rdparty/mace/mace/tools/fairchem_dataset/__pycache__/lmdb_dataset_tools.cpython-310.pyc deleted file mode 100644 index 0a8a893bb9bc8b2bfb1fbb1b3a1c45ac184835e7..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 26544 zcma)ld5|2}d0$`CbMA@7;v_KOi2;EHNx5VRAsK)GK|vyyFaS!{Q1E1Sx_5U5b1Yu> z0@%^4MF*5cxlE}NS0rCbc4fygty0Q=I8L0ziLWGf6332n*lEXc$}U&Rjw@9urhKd< zBER4Fy63_|@+{`f>v#1#zVCbA^&UDC6FCE)ul}QVRxfTE#vk%w_>;lIS;H`cA2toc zGs=c%dREhHS!K&)+-}-sTYjCgBfqI~N`Bq4i(jXiZe_}uR<@iqjcyJpsb;Q~FXvl@ za=|n{X0+!W!=KnVYa0ARzKeY1wI<7xrr}R|wwHd}EayLNw9VVh&orl6Q{^csn{7_F zX38@X&oyUTd&+xSbLF|#-tykKck|_W)UliU;@|!8?}7ODVELfrA1WV~--pT%$?uW! z5r40j_X^&`+gACgcfy^*|iNA6nX6L@;m zdkjyH;ps^{J?=e$rziYJ(WX;AiIn4B2`MH2Bx;}X?5jrURQ0>CK)GwH^`O}5c%jDU?F!YETRej zPN=H2(D#bVs?(AdmaBDDTlHIL*YlUqjx~OPQ^VF_I;t|iIE<@zmWT5B1XxK^3`^`6IR1|ryVSotjIq1!ue>jx)jKqwfu0k;~_a! zt75*QDf#g#YpS#6t8k-aDHrWT*``{rNFJI@x2ku7de=wPS!s5b-ofwXYPedm(RzC$ z%ABt@o7JVJAEmEttoa|Us>r!ksdcWQLbleqRaNzBJB;!-{EbShy0%ttuk0N3>b3B7 zbn0XgL9vI)2;}+t^_s<@#Ge#?`8`iGUqz&6UNEko-n4pF&$wl%r+a44I$&(tq0zHL z^RA`ldiILBY4>cD+q3EPoSt>VQV-pnN6ffo29JE$_?u?WU@e5svj%dko_)(iTS$G* zxQUhlv}Sb^Kzqe+S6hG^0B5nY)W8Z9F^nLWE8=rbtpsNz#wV7k=rU)9EA$HtjUNRt zeymXy>lv*1wfgb~rbXJ;;3}^r7PS4_MLq<@Pp&EdRt2q`VIm`&x*vRdoKmb++r=e- zzf34Wh4tCxlYlb?+^M$Kn*Patrv`2MZC|Zy^fTfH&lF#;by}@XyQ1??7CW_Po{Ar? z58f}x4WO*8E9JMt<_1vi<~m^3E8br9+r`^{ah+&c3@}cBc34#_es~&i<2R*et@UPD zN7s-O+z6x_fYV$Z<%6)=uKAG-NGYYF{LnJx+dxX z4zE}x;nzhfpjYhxAlpi)U_843F}-)7ywFDg9f>7qFY(|!$QU*j?lXQf(7geLtcLNM zz54R9A490#T!?qr3P618pv~yA?$*U2Y$?9vw-){I)T`%ikma70(-FByqfdWI; z)a9B4!Nx1K<8ENWfL!af za9x4W#GQ&KLwXEs9U^sOFjqWIm@AzrURp+b#U|byVoK3V^sBhaRpl6yA)YLH9Xy9@ zA5*oV`!FCSL!jGl`GT&Y;fi%+q4ZRtZ=h?ydv1vsFS;Nz&TOf>in*p#f8~CQJyZsovR&$8WFJYpa~M0f8eZ)oX+VRKSpA+FAfUV0%ekqOgM) zB{#C#Ymtqnq6}zWgg{11|(KG*oPSQapF}46{PBi zkQR_dc4i%O-kLKVOFfFm8onbRcH=A&zJQ2tluaM;cDMBpIE7$9Fs3Z4fa=ALD@ zE+aDN7I^l(p5vv!(K_C|m+`WQrM!LQcoB2`v^U{Rz6p9%&LF;DbH8}Xdb8f1H;r=6 zo5Ro!MEMs$C*v`7kMp4zKgC?8;33Er2~$E|79tk}X1Tr+IozX1xDimi0rA|z1Go>P zVyu89R!k4^26>Vbq1D6F0VA}*j~os`UK=2)2hiaamN(mzVal^S`%NQsb;-NnYVMkM zCD%jwN}6@HGG5Aaw=K08X18)-zEO~tQU?q#?PcD!yzBvE+uXFmY-6G^*|R*?%e@Wm zrIPlXs|LTNeD~_5c0H_Dn_%DuC^pKu5D)a2i^n-;U~)Vgaq_wnhO&=1*osxHMyN_Y zho0hg0|)%RdduBSIm@MRd+MiFA7+TMTN~Mm0RVoY5HakGm;Q57qt;HzathX!5U%q$^%xVFc z*8vYuuiXHTnyPjD5WMoEghK2d%yMm#h%Fg10Ra#YI3($_pXPlpMXLE*7%A;;R=?UZJL;E;1OS&*^#__~QrPi#+Mr zv^6?QasJL;qL~(d+G`7Kuc}mah2yk|B*8x9$pR@(vPkyy`?yfgiltU(($ z4~LnzK$*N0WT;n2g*=z7{s<`0mu$YJj&=Nuu?*D*$TZ4&piZP^kjU(omfJ(67b&)9!I23IyUuL z<1Jy3Wu&0x7^zU~ZSx1Ln*uT;tq|V1e$oo52zQ995?8$i@4>uy#+q5~Do? zTqYD->tGxyy|%G9p$~j35($sVa&}L=47v(#H<=}2vy+XA3W{Q@9<-`qZFM20=IRW3 z7`fn@phS`(tLHeHZ)bA49#kkeL@vZSh^*=nNo{skpa)us(zmOs%?L!$V0}rbfqGe2 z)P0Lm!D>}V#91V((+Gy1M==h!*We)l~mRX~|Z%Xx3ow#>{u3o*y(?%j3L?-TeyLo^+E7LkPhs>7ti zJ@6%;H9}}nLJQo=et~ePQQyq;EaLH1Kt6RP;3x1c*^3VZ^mm5=9hK1Vg8<#B)V(_+ zkT#(AGDokWpFe}I`Y3|#$>C9wgH9k|2+RktZOaVH4vwuDApef7x3o)W3wxt?N3V#_ z_64pB!)^XFl7&ysK@`tnX);)#?*0dL0XjClr&El7$5Jn_)$<6DE^%&%N&2fy&nwau z^#p_C2qNsQUZ+;6sB=tuiNS{%xC|B<5M)%H!C=FYm`qj)87ue(TmW+0IV(4xvn@Mk zT6V!KJf>KphVRIS1PUKB_yEEY?t5B*Am`=ZER-ET<;_sj{H6mA+w~^BDR9{7&w##_ zGXRL$$i2G01YxH8h43p~-5-808Y!iu?#EicK%FXkA62)Ds+LLLIn zi|m7t3Q%dBKw$sjgaTC-@DrK!Ar7lg6Mh3(GA>F3@wkSOcUOXt(%k$q>guVhGbXD? zw9)e?Gb!(3?yut;a9VQa5zyoAo?YHm%IMZgY(rY>FeXsbhh4SCv>OOY*)hCVO=bzs z^GyRJk0xSZs`<@k0#E`~isht}EyR8fUjZ%R{G??~6wGN`F{g&_$S0+rG$T2T>sc#n zXC3}|AND>XD@_cafaLu$@P1Y#?^nD%-W(-w?-fYq^N6Lhmbd7=D&>!QMU>B=CCj_) zeTVmm_b9ASS?@9LagoKZcqQ)y>g2tv-ZfaKP9e8|_-o#2?r@X)JofS!c8u4@93lg94XBxBK`Tos=f|VM%dco#(-PDH-Tx;dZK80aVG0@^C_3ukzJ|<8q90f)HlR=|E>&wc zz)eU+u~N~7wZ5y7rv{1857s&m3ZPYb`K8M*KeO<3V)lvkmDgzhg8BgrDp-K0q8^o? zdbcCE*9gwZCxZyTQ=~6HjtD7r!=#D?^caShp0ncg+%KC~$ZDs5-25DbiXODh6;SVn zt)2><2C28vNJ9+aAH@?`P2WaZ1>!;@)5zl4d7IQd*Pb6KPirr0xNoz?q_mZk7R60o zLa0`@CbntiO#@Cc-Ou-x9aIE!Nm?B+r?dc2r;7F9KUgNQY%m&9n~E~DR6v-6Mi#mz zZwXv1buU!>_M3yYa=nFWFv$3KYW`ZhUBW(6^-!;v%x7`ASzXaOU8rY@7q7we;Rawt z>k1Y*zs07bEbM8OYOv9+DHzB_^-#M3i9U5()kD7tqOL!gk^?(8%92f1-;G4|DF$yK zSS(Gdn|u&bM0H2f7R(rt1H)ijYeMcIE3*CetthqRgGY|+@D7y?UI3en552eVfv*d& z7g7PjC`ja4w#S;P%ixm?dJJ@bsaKBN8{oB90wv^abo>18C>oqa;7-C2qW`{brQEMs zuDj#rUCRN(9H-+qvoo7>oIy(FYttFmcDd#59Y*~A#=A55#Blok!eDKsMLWiIR-N><%ZEYOe0=$Rx=?MIXL%6xgKAU7%2m zdrJ$ji|C{)o!ikA>}Y!Ql{I(*1d7Z!^taF`?c^!5CX~&pdpkL>nDKKh-WO5YFbMfs ze1SBeROng4Q<6$;+olm(4JcZ+QrpHRDH3QHp3{vC#HIs1bJJ~4>y)ib4|cLm7gWrx zWW(H6-orEEg}Sks?zxQ#C}j?9W_lS9e3NHG-SGnAj{XfNd#;ySG49$>C7-{!&~tkk zn5xsg3{2GvXfy4(&(dml4gE@cX~b<+1|5Y;7;$%N%FE*I*uvx{6!51iX@N&Ea)uF<=(_TV>8#D#`B)7sh;yA#@3uSxy=!Qqo&d+_`Eiu zBY*E8Kb+s%7w&Hy=%o=J^ronU+J+wv&F+muFt7h2f42_z9Bo5nZEsq)-LIS5rS3zR zhu=iqhqjI+b+?WR2vWg4vluzAU48M2Xn!KJqJ9WV7^TFr8l_rcz2&PfAW40Z=~Q@N zzk`h?3`QWGP)B?U8N3ne=syWpHj%^mFQiZp_Qf_*2AbdkrYJ<*DyG!(8>Q9@SL zua7YuhH= zO#256#&(sTMA|XZoy&+!XA0J|<$`X}(PgAzG&&Vm{S zay3kC@8zhHfp#E?Lr(y0?+6Xk!%(a22FSWGLhng?$SDcErwk4lhfqX*ZlJ1<)!=8D~+<9QAgsx$bEP1MCpH%oL0+nF{44z=YyNlu~+D zrH6xA_eN|T*g6JFdI1mOtuza7B^hha z0x`9?8N?(IWA2H}UiMm7;|5dda*XNh@*ssxPR@nAO?Y$mD{rncNLjZIUhJMB!Eg$)#Cz0{X%u>n}Y z2YU`}543w=O93Kq&*YA3lz;HE0LHXFf^S8NN z)2MmS8e6C%qeLWLsD4Ear$}|pa$24jK|TX7%;tKl9f$x|Bfck}y@+u4sgdFN9gN|7 z$Vbu#+Nz7Wo@5ssBKMvZ{Mz7D!J(>$ywqk6&p^db6SdQA%S&(OR*X%^(VLW_iJpKA z@VEyzCwddm%BF$tlf6kH6VLVvAc1yoVp$WuQZK({^(Kd(8+I=z=0t<;3_$mMukh_? zzn4<4BktlawNJl2-^=4Y=}%ajQ+$`W8C;j1r-F9R%^RM(AKbclj(}$-AH0dTZ<4!p zP_wf+)0^(0OpgZ6PiumUyi{)n967wJ(W;4Z8TP+70WS6nFsxp^`8l?A{lX{g&B@K# z%{`C?Cc`x3fUky`t!!_ik?T!vTU&WZ3UlqpWDdYVZ_Z&3Ch?x%>dl5AT#d=bls5&V zXA+-2U@Tj=jNn&EJiVE%Y3q* z1HEa_#e7ZQ?8??Tx2XOj78)D@Ryn4QjCFGX9iu=AVW?ulj(O^uP|%C_ZQNv!g*DC~ z(i3ME5L|3b(jFj^VEbX?mPIlb!$dWGIgt-Q(YCD2c&U{gS+AZK<|nKG(I5Jr|9R zhx1Fy9S|ESXagI?eZ~Ve9WY||8E$eo+i>yPqI4{lKh%Fi`SG4Ug|y#CPu0_ibjJjb z`=&f`w<*nl&C|JJI@PTCe_{JVW4S$PW%_x>jv$CqaPx&8FbyA)B`g+PgVtMP6RfDk ziJV5K-Y)GG<|K0H`Ko?{tq?fXZ!!2!4E{3%d8JU|-{#}*B8VI?tpQYDdNbGCD*;3O zD09uWX0v~fDgT2hDN%=vPs|@6Ynw~xA!5DZ6msDDGT2+8LI(qnKk$G#@PgCWb7vs{ zIMy^o1M-Fx4`6{63l2mE)MBc2Z&Cz;9G)J4w?97FjmC+t;3qZp>oR1Sz&#r9m6};hp zaS8}^u06(3BvnUx$dY}g??ZSXC*hf3`;H%n z_AAm2jPYE5iiZA_(?Ovo2YWoduV<~C)!jdW-b2$m!t`@)viJfe%z$_&)<=$!^I~FN zDBcKP)kuRfxw{EBTod7T65(bryl*byP`F7(_X})fm(l$NN5{<{o&+`-{>*z1UfmNY zC+mY%)S?-79;~U_qeZkWXui0PkUe>}2kZAz!qIv9XbL=&p`yU4mD^N{!<~cT6EqE4 z9Ov@ME)&Z$n$SFmdy9uw&TF-UIM0s{@n_MZg;fw+>8wKphRStJBGvUr7LEYz;QJ9S z9Dzsx7q0dmw-A76&s;VE#1@Sb*zR2s%57||_RXmt1`n?Zu+sFAz%+asjz>@|l1Rm- zogzM0eRze_Fd)0!@e+|(a5pw8=yd3HicsQ*)jJEt;%j0$fGr2jp+ZOXO$V^p9QT$= zoK?v!!m%-aA>YMMyd(b@Mgx~aH>Tx};87@eEQ&{W8mj{6*>`ZU!ySwrRXZyg|wN5c$YOsc5B;uCD<5aNwKjT+!xJ#4MOSxu4bKkHV@<43ub#i0IJO0YqdYJnHRwPU zz$aBhUED!-dtgZ#C{phCKvp|E? z5d9E^ko58UP-B0eeqD|CKfRuT;Oy$2kTK1f_is_L%q_3;h5CF);lkWB-Z)HDHPwF7+!6{sV(Q zKoDh;Ghmtt_*15VK4N21N?_*yvM6Ce-C{7`4-{hpDYp^hScI6->jGSYssE;Ta%XZj zysmQ7U$ZXa@YdHe6S+(wQ<$;#BxU$p_%rJw{IVR7&^&&#a0klFPAIlq!*}FkdO^EjYTIVHIU0V&h7hY^b z(^7+j!$^9u&Xa*OO^d-=A7{~%1`bT04+ZgSd+@DaS%45yOAqxngSdsSY^Kj35_3WH zwStQ;LW`ClSoLX%hsA$wR`^5Mv+xXULgL<(hS*o}y%n5^NvR4XCb$^H#|;_7kqtOl zGX^~e%?x=ZNG%2**~}87^Pi&=>Mt1lC4(IXUt{of2D0r9)>xMJ95NO#h+q;CH>drz zsP@NRH$O60ef#lumd97ZnBr-5gfWUEhax_(anJvi)Fq-!SG{x3NAlDnj z@AHL(R=9x48)LLn?=eFl_#R^d#uOQ``@$3$ds!S{Rt~AU_h00rRFpa_K19?A@KG9; zORm9L0zK6q1;@4L@t3pbVdvtPqA%Q&$9^f^!rHZW6e;4?MsR{OB8RWw#2+962%@JZ zWFl&v)hz8<4QRa(!*36IM(7=Fbwtu>13!Xvl$3OEA5xYF8E^!=pT=^}NI6L(-r;Tr z=SOX@?0`Pxc5n}NQJWvNa#j}xu^XJW@)-1l%5=o`Wqv21O$O=Hu=tq`-+=y@rc zZgBL$#i0cl8+bRw$Kffo>3+9Me<2;a%{w5p_s~Zu6ZMIvTvI?`0H29d_q7}33u|E2 za5|DcSUkY0z_>wB;5p3Lg-A;Z!9p~rfe#WC&v7NRFvrvwuq7S?w>`Gz-F;%6px#~h z(SEPG)|ujBDW}Qd8LUj?)Z5D)xXaVh3r$OW%0~SR^Cr*Xz)c-ovagLDg)89pVM&%J z^<6BI!qp4wD-i3>;z@XzD9g9W8Nza24_H6s^x$|MK;-H;Jd5ts#sE~|UJj!};QTf; zS+J>T;TIdNe4(zvmg;_E-W8*5ycHMTD>{yydBgfkC*3_X3Zf%uRdO$qms`|qSNg%+ zMxMCCz6|wG_72%L{tSJ4lAUCda0SvWz3b=xGB;Z|tyt#4`}@Pl9p2xuj{|yTPv_ol zU`MBSZ6=x|b-R|Fj1IczUV$4}3;t(gdp8bUkCvp&h%;05V43p7(488KBO) zD|f-3xn^rU&R~Z&qnZBU>;M)99+pg05Ogpanxy`=`X$Z}TZaK_-)Pqcdnxmi594S{ z(+zH_J&M_I#6PExmiOm?BNLL%*-yG5w2uk8FGB3#-~dlOIMi+>eR1qv`l8EdkZ>Oi z5(h^ID({lf?nlNC^9|$%2n{)_U@NBHxBis2@G*l=ARJzQ-@$>ZeIfz#Sk(m9Up#r_ z?CXAS%9|E%e4c)#k1nhO@VBLpF3!G6Eqdu@%6aj7EnwabL{sMiUr>RQ54N}ZKha0o zJ;nQ#{-rR!#n*)NkI9OB8v7MM^7Wf*%j?bJdI-B-2+>%(tK!HZ2%5A-PgWsx@dq*QngoNYFFk4MTPbjkoHS3`k1tRV zSzAOif&hjk_?-y)JOtK9473LESv)ss`i=FaT2*iXzO1)UKwS8FV1Z(^xML=#*~8m( z@We2OPr>6n=fZf8j$=PQQ>+e|JmoyPoP*%8Ch-`BcQGyi!09x7V3X(LM}38d=Ww9S zfYN4pVYks9*GPCqZWY1nWm@}2SFEjeLk{XN^~jo0t)P8VaV}Wc{nKoBpzEM6N$#1^ zcb{NprN5X>*cfq}3se*`NqPY2KW&$+L6sEPBK@ktXz3&kW-ivS@N~VbN7(5zaD1Z- zaJL!Nb8x9Yd4ZhOPqZW)4IOiN7{98IjA?bp;ABoTrNO38xON^pkFL`8_4?}o2klFE zvUsk&as8zD7@p&PeSIJ(k8m|4NW@;7*h@h(@f8%u4$=G8;i|AU0gpfEf8~0ad00IW ze-?IB{H9@|S_H3#bMJ9sX5gJ!b zR2Q#l0cr;EDA%96pgeiQ_i^h;8-|7!t|SuXySuj7DZ=JbhZi$dm_SIt8q{CMil1m- z=&Td=g!|E;AX-sX!@|K40Jv}(8d2UZ1R$v5NHxK!~t9O2SRYQ^f_0@)eQ{I7_$S`LBxG6gQ2kesSaGq*A0rC~#ABR(sh2JK!7Zcux|Ii~bd zjlEm*+c+?L*Zct3lFACkJ^g*P=`TZ;Q4|gq^EIUXE*4X}be_I{k^11Yn^LnpH%5mvvQcV_ z>}7boAYxoA&nLco%5PThz^`6QvKLE-wNqy*;GJUXUG{I00iiR>_DQD(Spi4^O+g7< zOE)CbDguiV&h}Bp%4|k7CDh(%A>xl1gA@c@VIZiZF^R7e$#=AaII<(RNa5}U={Ize zXW7OogMoZ1FWk#Cnb(0T_84CCJH!-XcFro|KvT|!nhiL?^RIKxJotZVWajN7lmG;4 zU~#p;=ii4d{e_U%?_{-O3>Xr15UoFu?|qkyTL|3^UC-d-(0%!eT&o})(l~>&P#otS z0?p0z`3}u;B%ATVJ7%nm!D4TaJkYSj#+2ejHYB_cO+4 z%w>!EnB+(~ZO38yEfX%E&;n`m>7vk|`>cl?0tvJD#E3mAgTtfd7TyXj#~QXoL0yT$ ziX{6TOI>38EQ0|ak#Zm6!~mC!{YLkx-6jRf2)T)5H{uPigPUwpUW9HLYMH?i1n_^+ zmw3QopXw9sOiIj6SOn>jp!XG45*A?r`P#C;CTH2VX+$Q)u>?Nouw);Eqe=JhH|r^A ztk8I|5(&wFh|Rvhx2z#Z%sa>XsU&};iCtyeM_KC}1CnCJjZtwFig>FCT6*-im@zO`dgu*pZ`ymEry`$a%?;vXByeVw@#i)QeypY^7 z+WqQ4^u3@ZM_zS=Q|7ot3<4{#mXw}Z6c!;7@Cu;>b7^3$uJ|JEP)bOIU1?LuLKTcK zY{#Ko))}7|C9=~&pRIiO>Z^+am`jwp;VUTB^@?+DII$Sw@<_+<*jJ*CH*#B#c8z<$L~t=UYB>t)8Kt5 z2R;>jkaYH7(;~)K^5Sdx=%(GJ*Jd7lpo~K+%7m99Uy48U>0`#tRcLP~H@@G5;>a|@ zDfmyi(#QC!nvbKzZOjYmWmpew&!EK_&@woYiI!_d-bjlG{YoE)y=D)*o8ft&ewHQc z2Bas(X3-0$F(-5JW&IqVwQ?GJJxF~lt=&fAnfx-Wb3BWMy7cD62{N3QA}_nOf19H` z5Srn^5F9E;E;Mm+QRJJhb^1j2fG`Qzjr1-T?@-)5cFbc@AqHLHCA)Li0C{rhZJbPi z?j8{jhkE#g-uYqxXG4U#W4JIih;q;wz>#;nnO|1e(09s%+A5@pVAwD`X6kyJ_d=92E=z8#{$-4TZc zQz$d^gfw?(HhGs0zJd-4gD*5L=DIbt{~+uAG=tL!;C3JIM6L^uC!7jbS*E#Av-0+r z*I8;0uo35ymV$`q#7JQ(-ja5DdIVXbI#+XyeH%-@O!4wkeLia1eB7Rf>m+W# zdiYzf@Q?$-eZ2V;`oYCLPuMG?cesy~cm&u3F9NDY`BQk1dD@(j4sK&Vq(*ptlT$Ho zq=LixI0`(?wKR%zjy${yEJTdw%YotVgS^ROD0n7b%ee+#;P(PjatM>$9CM-9fsGV2 z;_zmgC42P#UJCO%=B2I>Tv90UY)@4nZDX6vc7OgQs``d@wK2_PtZiR|FE5m<-kpJ8(YcW; zLW}qyyLcSspplfL5~No#Y7dl}v2ZoGN0~LH180Uz!=via@lMFb%X%U{>r+5q0LXy# zc@ICTZvc(*%eXB&mZle#C~N2r0K@TK_=@3T;>#QuRD)C{Vw*4RiQJ`n0IXHd5k04b zamH+mWl!!D(b>hs2hqck?rN1oeR8DNytkhz@}c zFTnZ`zd2Z-qyV&Khtye?aF~mON@v*;S;?4M+yVJ%f)?B0K8v#=h6j=+7I2byri?hOYiJa-AfeSU;G48&Nxp#LFn$Y&t|??(j4 z1Io^OzcnEMD!}kd_X^wsPNbdh**IJ0P}D#S$|uHPK6+Zfn@*DhhrwhM1k{OOCRY>@ zVtJL{0to2`I0*Sd#I%djSPL&$%PF##8Tf zF5&fs_ss(9$3TYrEMbhGw@Y%m=Q3!ydLLg#AFrJ`f({PF^XTuwg?LXNk(xPseInYx z1K2FGJs>~a#J3>C@i-UNdmfDG!7&n1Fgs7u$Mr6Z?Yxi`*JwZ)%_Hj|G*5<@dng#8 zY^$RMk`jn@Af7zKDdZ(igTgIx5uQ>BiBV-gKe2pE-qp?EG*gLIyE?OAU+M;WpNDhBb*@4QaO!Xm0cKLjt?U9 zi(HOrG@F6zFx?$8*zV@6Ijg&GWKjmY_pZFhd2UY+b7ldXAlh)8#N2_I1vE}#hAp1w zya~Y>aEOC5bX!9g1+-+ZmxNdgLqrX^d+|mdoFIIetH#*9lwQ=fJY8w!MUfbQb`-JwY+!i`;aY)4G z`>|@s8CuLz+S*^d4Qn3np09V-0~nLm`)U!o>+=d8UdGeEbQ!1F!Tj=UxHs5vy8g#I z>&bCJaJ)1(G(E~r%UZd$1b>U5ga1bYOfM~0u~iDY#|<{9TkrQA$xzcQ!2ddt^TaFT znD#7kHKSJH%E+Q!VN6K99Pq!$$5Bo)+OO|a?cYA%KbG+ou3-T+oD8*j)3#&I*`f^V zOQ;JEs*64CME})&X)qD;0fn*XPd+NKVhbX?oM40YB9*GIAPqPuagOD_%jxtbRU-V0 z8i4dG;ywXmn_fvp98#|__$Y(H97s%f%BPXR>*5DD)TMH#GS88LoV=HroJ=W}Ku1Km zO!=Z@UAqX%D=K3d$srEsFoTDZSZP6e$GfSQ*Lia}?#i;5d5-;{)=Kl~KZh7hL2$A{ z7qm_T95L^!=k@(!2GalQMIKsNm^Zac3ymt8ollc<#@z<`Dqt6M1OTGAd}@LL{Un3{ z;XmIP(BUW-KY)OMlMk%<7;eQt)ORtZri=P?UT)E|6uJBeO8$8Ce5`z{4 zAq!tYEXpKAUQ@q6V=i3`M71K5x!XG6Jf23Z=TTHUu{uT}&vQ)t!^RAI6xn-NqP|gZ m3Uc;X=Otl*F0Mxk_ofd`JJVO@4;r5`K861br!!la+xY*jd4E6v diff --git a/mace-bench/3rdparty/mace/mace/tools/fairchem_dataset/__pycache__/lmdb_dataset_tools.cpython-313.pyc b/mace-bench/3rdparty/mace/mace/tools/fairchem_dataset/__pycache__/lmdb_dataset_tools.cpython-313.pyc deleted file mode 100644 index 0e2a72d88f0e8e0082f2d81eea9f9e095b525f2e..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 42257 zcmch=33OZ6nI`zK5(Gd1B*1+G7fB?=MJpvyD@AITNSPpH#+1zlL6Cw(0?Y%bg|QuX zWs(MMSBrG)B9*vWl;er0R8Gv~$5gzE{i=C2`_=Fo_N(Q!_*EavI;!J!N3;2CJ6Az(XpZTR z8hFD|BX49S=#H6>nt3yeXCJd1&E<1hTz@R@sFk-Kwehy2`Fy_kwF16?m0l?R7Ky*b z;%^CG!k(A%W$d?{FK53Me8uVfYR+MB7#%q~lzgSb?l6^bXVttNwW_K>-GqORYKNK9 z&noz;A`Wj@N;pUE>FTX2p)dFvmYRoD>uH^wTFX*xNXfPFuRwa` z=|;KKrV&*YcRDAFbLDaz>uD`U2_tmW#GCBgX%qTE<;dEh;0={r*9XJgZdD~Wtf;SX zJ)___eeW4>z&_?5ni%!i-F{!t<@E*Z!80Cvd%)AVzta_T1w28!d(;&O*j>IMyEhoH z4|&Etz9Em#?ezqj^+$qs6ydoTj%!{XV<&Ub9}`&3i|@KG!HK0>SttB%BGh={uyTz0O^4Kj8_w8tq4X z?q<1;uHj+tsMi&w8VZdJ_$PQbHFC&96(04vJ-z@cGJ&S>timUb^h(M0VL#ICULQTs zA3xG@pv!TJRfc4p!CwRs`gk+7|h5kyI+0W3xIPV|#@WD&KYvNB(i!hjeAbH_=LK5 z!0qqF3)yb}c^B_>`GRh794ls}(Wj)+ z`_$_Yo+#!x1LkyJjg;fWQ9M&-b0@I&=!ui7MgO{+9;7C&*@pg;+@^?^4SNm1~A0}L7Ff0E>*D+la0p$a*d6T zdK%d~$+A2?4?l7#wV*b8zuP}H=JyQ<&l>H1_ofZv&C|j+*{a5DbWiZS#}^#EgthkB z2~5@@`-L+epZ$WzK0)iz9zeS=r-Lqj#1m}6gz=2B+Kf$%2EC{bo&?SX^q8D_FY*S0 zE}z?zP+{WMYZ8W(4ow(RIxvwfPu_%unmgbh9uWH>kT6nAn%n^zZWSh_fffmJ3{b@Yw^~NWh6bgZ z0hXtE{8FM2eS$G#55)OE;}zgbQJ3}j3j`6oz^#;U1y!N!y9V1&3^nn}hDhbsu(52( zTpTWG_-04Uydh%V5U<@5uGt#4*gvpum^Q|(l~HTsPi+hJmF)@&wjx>I%>z7#QY7qMgp{GHqrvb zN@@o)A!F4HMmrCsv zatrf@^0NUEx(0Ouqyv>wW_1c8eH=aP7e<4?XsprTl@A?Y>lG{-Bzg(w{ga2-ls zJ)){tCkmauz{L2tp9e|UDFE#QG+Td3CEeh$J|HE#V}v%V*=>&zlE}lMEaCcdT(6mmwfbt5Yq@E%vmi%StIC(3_9BS%KhZcOA(xIcc)E1vI z<%_hK)2^QKv;^Ebl@5&~dxtW$ly{;kV+W?6%Knx>~%!<0xpa5 zh%h9QCwRuW=t1`JZnr#J@FC(%H`C18jz|q>Cjx~7m2Ux+uSnq!6jE!>bFXRa+zIrd z0l7!GZK`3#4Mi7EbOLBhW%GuFdK_F#fM$~fy|Q<4w~b)N(4q;dN_!^7a1azyeuFfhWqhOjoZLf0fTthT%rrSUW$cq4*@ zCP;`op$#y&#_ysu6~Hj9~!J7A zOOJ&*|M>WFHfL-SkW1>+qc{UU15|{F zAhQMrfLnn#tgE||CB8tJO$c7#;%4+!WyD+=)>R6162Tsxava53ooHMM&}OO=0n=o? zvIj8U5Qzj`;PRWf=fzL{AfCskKah`UR8{Fe6~qoqc}h`G|QPxi1YoCUbpjDH0B~gGwu|o-VT-k+nq!|iA>QJv1rnnaK+Bf3iyYX&91#*@r zIoYZK@cICIK;x_$05S1Plo_%BwCF=~?&V7_UYfJTit8iA_20Ed%v-~{tq-@2)*)-# z3kYB1S^$tT&ygNHPro2`lo8tCCgO!M@leT@_vq3~Anqsg_XMQpP&!mQloD99bKMt_ z8^V?J3&8XYfT_}{?2!>*GLJ+$)QL$h0AA#zxtex4)J=e~K2^5{$U%%dRiCCzCidz4 zz54*9-s+Z~0B|DU-!CIwr;HvP8YUQ8`^jk5q0QW1CeD$yS(WV94OnX+Rx~u#t=uW~ zDNR7NfeR?nuTAI=wNsrkiZbd>b{<9%Vlcd=Y!Ti9DdBfI8M;2HJOKiS+7eLuG}1Q* zlqXumRFzNTQ~NY2U&u@nXatnslL5V_yi4?T+LcbA3_-kB$T_l|k15f_ZB|P2#-osP z639xsgkDP-Cz&XrQ6d=uzNJqss}?;W^v19dV?8j;)f>7JIWn9gl-bMg#@8l|Qev}j ze3EfOlg%CeiBSmF!0!!X{mZmSYEH(e2Z(K7mr$%rXnl+$nru2D!cvIjm}p>Z0*oce z+7|$Q zQ&ESEd|t#O;6+9-Xu0CjP);gJ-@DKmHEfI+cE^pyF{3?Vw68|$-I1a#x3)wqhd)v& zb4(C#6_;PtzpkHF$E}64yJmLH?wQ#WudTa&?Ao!#=1A?nnG<(w>lXGzYIjVZh+FM3 zYhA=z7qd1*tqpNoRm@fwvDJMu@2$db6fTy(+Z(CxdEe%^m%|m;Jg{&^bLhFKp?p3! zYN%c^S}%9M*d5Q?cgyyplG`P>jPJIGEnN>-Ewr4iAarC|$Ca#`|Ll9k^$#>EoBp4R z>!&r#DlRAQ)lIW|UfFYdceMJ6AM9SdxNzXjy|;Hqs-K8jd)_nl{It0KKIIST9za-^ zclo&&pNku9(|r+RMa)V0Pv%gj;E&6*Lf@MybeM)gRr|stUTZbb%PsDaQ zB0C-7w%)M8xuT(G$boEH&Qdg|ivEiD%6Zox9?=oQ&Cn$X~{lPG!GL3#3CKS3Y4Og0`!mW=967sjlVvHcD7Lp>}~ zZ7QB-ETgSKBtvFXuVIUUcOn^N#L&f*CBWz+^#Br+5@kX7VI+Bqv?vlqsq;Ve$s!63 z5Lp7U$zY}_8#PqL%j{RXu5>MIiIi;!ogj+2=JjCATpKah{&f4!XyLB!s}|MsLwAbS ze@_)D+!eO$!hDWc*MyC0Kq`KGU&9scX7#ygYB#CtRY`iD!3^kvSV}@ojMoiiLd$y~ zO7rojQ4c(uq=u?VSaw6W`E5RB?1#*)(e4{>_6@ms-gW8e`W(IyS$Uu|j)%0CA-uMgSO~d^nzh5&Q)@5PXUIse01` zwM}ze@t{$mu|Lo%G;1-bwfXmy3Qg;Z3bAF1v9uLsjiyLUL+qoY3a!Q_luxC)@h4y6 z%u=in6nqBrG+kUIDS4$|N0yV@0wp#?C_1QuHK(->ogj9RFcP0d!WSNb7e3p`?N;!5 zP^Z}mtz%*kg0)|M4%6a?jqqHR0l_K_#sl`;b zKJzz_L*`3H#MPt&Mwx&jfQX|4Tl@*~Ow#m)fJ1w*_> zbtL5Sf-qTVjwQ@O-^#4{+ei%TLhzdxxCcfq+ccFG(G|@VUah!NaYt7dFRTFbSY&@W zhsBmkE9NeIbyrAzM^_XtD4+Gs_(JMGHatw+15pW`PjmfI%wHyTh{Sj)dDbpA{b|jo zup(?tBrH-xq-pSD%9V%U1@5lddimL@XJh6ysf?98l*mKu0Gk}YK?GBS>dP*GO?hA* z0-HeU0Kf7yklul6lNM~+)t3hVtC9gMh-v?ZOl*i#MdJig6(S0eRE0KH`ZkE;Pya>|Mtvo=PojSIc+x}#fq{IX6Rw6MCKqbCA9?}8kje-`)khD>0af1a{3 zkj((~Y8ai!5m%twGdemj09nQtk)6%oeL|<1M?p#a0_8*co`mpFA5 zD$@9m|i3a(;Y=5Ob680!2qaubqRFmC}E|LcpbO_ zrDU_z(qv9Y8426jGrc35iQUT`dRC^jB@Jevt`NN*;?NH|5^uD0b91^29IXqgSjHhdGsv zFJu_62pF$6VLX09Fy$G9`rwyQd}H$ND#IVBtDso%f`Lde-(d=cLHk)~1MPz@_c^dU zdS(Gc_CXUF4QSkuX?gW zXRh-uFEd}UlghG9e~Jn_pBX+Q>$yoG*l2>2^YIFN_S{=sT5ECO3MZM zg{*qrP|weYSv||C9U{a|vW+V$o4q)5kwoST2d1A77jA+A;_w3;eiccsz11igerX=P;Mww`zQ{mHgAN@Vbm$eE~&8aL2f%>$4r$z}H zv8&%E#d->)Sn?NU)AD;x?Wl}&HgLWyr?#ggnNx~Er?4J61)WppU`&xi#pL9W$sCn=BiVYyfwr`yy zi>0dM{-e}w4joI?_P`J${c>c>J(BIrmfx#E+1X$#^)g!thS&d%o~rcsA;l1rz*wRe zMaEL^%yMMYSNp4_H+pKNSn?M-l`Myz+9%b5?pbZsZhR37GXo_|x*CT8UBUSb-5Zcc zs&=wqE>l~bY7^IkaY?@%M(OL)FPSKl;S8A}Kcam4eA@NN<@b}6f?Yiuq*(G7;stIh&BvB zrhp`3|9^?Wgt3V$;P!f$`a#7(=0ZfH8uT7dn46@}z}#6#CxJVEfGVI3f_HQW z1yz3*OcY2a4dJs;4G~kqa2wXN2^H*I6Y9xP?_felRW3 zcK}^1tu5>y&AFVj*}(x8*~Nbyhp0oi)rDX3aC^sHG~T`4F-z&+Cq>r>>m3`qY)D zV&xkmnt4w&uQAjC`t{`_ONGTmmChGL z3+qB1FCV^ZvV>G|Yr*WUS9gUv?wa$aH%^_Ov%d7)ylOu9HN!&SyCXNR!LzTN zowv@Px>g!3XbP!92d4~o4LMV7@!b5`vYE2be%K|=Rn455I(QeV!tFELNnJR6&TbLe#!3YTX{{gjg?PtqB`z!qyr{uCj9&iDaBI#?85*q03!UU31%_#v1SlmruNS z;`Oyx8^6*xANX!-Z1af-{+Dz`jor}yF^0itn(dqEg9;H!x66m84qxt`>Ym$w$5<)d zAD=oNw-nJB&1Xd|bxRi8Y|bk=uMb}x{mSS<>)UH%ZM_lvuW&{!r%)(Ucp6Qy+^R@! z6{(DjmYA_DVl0apYa+&)c@({lrP19Tw9G2bzVl(e zKqimlDEfEd3mh3ud>k@jO)4xfG_YrL;!Amp5ieI-lu@t&;OKy1O*>ZK#-GPq9C13`9~xsDp!va!y2i8^|Lf& z2a!|Y%!o^&B`Dbh%oj$D_2Lp-vJ}i3Uop=0#>$%_p1#%&Jm*lIN#Cc9d6VOILnbhIoJA_Emt|Wzk0abRfv8>M$v88~{gps_4 z7x{mnfQ=sw6%@$~#j!`U6a(ls!c>9=fVwXYKM4h1q7giWr!R0HSPOthQb1r}&U3Z< zO83H{XxWxK#;vK3I^H+h?`b$&A#?@VIWb)kdkx&y?l*=c?u3tzjsXU+68a&JOcm;mc`mWNARrq3IPWW)a*f>p@qUnrGPdOt%f4H@yjKw7%MNyKaa9@8=K4Lm2q57X4nC2PTSkmP`0 z-aX$QWwsn)>za68!ED`3-Q32R^;kE+R~M%)lJamab=^eGTSIC{9)Y@+@``2~Uuk?j zc=hvN`TRmKTDc{fx0NWmk~s$`B29SBzHn)KG;e>{xc?XC(zwxl`Ows%>4S5fcZ^kc zZH1xEsq7_l-sSUC=PyrAO~%aC5p(tY=BRmHShwym3++X;=6At+SXLPpj$rVEg+i1_ z+VwCfl^2SP88ZP!;7A!WX3}yzVxd@*28a$-axVlHjQV$8FdE=<5#T^DhJ__u5G*jk z%D=(V$(r2XK5>zDqGKzuyvYp8224JtaG*VWkks;*8tvF-0c96bY44)A@7U{xj)JYZ zF8=>UZTV#iWQGQdYn~F=l42`NURV4*dceql`xINDfH6T2DE8>(l+ZxT@43i7c0na_ zC9-f}@Dk|-JQouctLX=EdCAXTp>=i{-xe$=r1cZLdhyD|1 z%SZ?RX_nHPdGqRU?cQ+dzGz;1*x3HiQu+{9fT~Zim}D4_3eq$N%u=kph@~NJYe`Di zuq24^-k=aRvy4DKGx$1Sbe4meS_qOp7s_yGd?XPAlvYD#*sXMA`Sh?uH{cJF+8%UA z`qh^wO;1ohdqDZ*ICC;mIaxA}GC56d;0C8LBh{mF=4?r&T^f)AV@7WhO#A6X0=L zG^ndHcZ&+r%*{f*d`f)5S3R$ooOj{14ZqW-qQ{R6hshs3q0Td?Wg^a=^EAG0cd-HRH`&*`{LyD1|hST=&HbbAyX&c1Ma9YrxHfOHXck(!3CY<#im|HwjErxWsmY=UMf}y7Zaf%>_=hyc7!59@zRbd z#l;k(xP)%ksB3I+$h9Z2b~Tj8E;Wxkn6d>=%y%18IshODXutTx;%C3p%V-wJ6SCkY z!^1jgCW5O?9ma(6cM7JL9FK# zw5gCb6CUjCVoFmO(r%}3{XYm&H&*d&NJy;vg!V{#4}V4tB8?H-QnKc+;Le>@;i#YCJ6W^$(~RzXiS@YmrQ zO=Mx;+90|MdoL%(66&*l41K*@BX z3v)FuJx|idtaxGRT>ISq>5DV!yCs#=S$D1Z5DFF)g%07#?6Wh^&Ue0F*fOnxjEKo7 zhrT?98*9PL&bX~~wri$qzI{G0zx{n%<6T=3ZGQv>u7BUQ{+@v=+5quWPVVI+Q%B}T zZtMvii5N-ZG}RsI{Ds+u;vu|zCS15>TKz#u#f)}Zfp5H;eNV-eR^Bz{QI`DD=|gj_ znPZ^?|I~T-C*~tR==}Ea#g3bKHyzW~SKYJ2GsCmzX3ov}XZ-Vnk%Ib#%~4Cs?aq+; zJ@b*cF(;%To5k&~Z6~h6GTk~gIp>U;tAA=KAdaJHrfFWa*f`x3&f6R6_<3O|Yu0Of z61HM6h98-4o5Lr2Bm11;`cu)GQ{l2x;hIzN+O^jYT{{$6-|;X+RsN&-QDJ%b|Lvci<_@Qe1;a$2P?%aX+JKH@qd!H-??yOcEXhLJBkFbQmIr^vzz+v}SF2@UWeBbfm|BsufryJ?K2 zl=r{VWGBi2B2Mi4u%+!BrK(F8p(4OK#Yz!UAJq)^5O79DiAg4k4g$?g%$rON%9vUK zsSE%$jN7*7eXc2Bf$Z|HE`T!^j2@v0i^vNnBKweEV^ z8)e@dddvR}|IIVe=0nk{!`NeOB6IWSr=Axi08w*oSXV3X|Afk=i`W-7fx?Ka_VMtC zn==3FQzAtIKno-7dYrsl9#2W$-HVq6<~K$$HmGrmr3t;4kS2{OlMGR}Ns>|;Q@#qp zZ)o^B@dV6iaoNkJ)mkQFLCjblF_zDbM2%}7t7ZC#v`oVAphWqS*EccXk6PcEIK59& zOQf!1neJ==jSOB&xH4JzdK>H<#`T9KYb!v!&}D0fX{3lH!a zeu18m>0Uxjy23Q^D>>4@XVVSgD9u3Q)`7P%Zbr_SPn$R_xuN!WQQ1|)SD3?y?TZC> zine1nsU>&%i5cxH3Q%9gJMLO-p~H}nPH!Q-@k`Hzb!88)JHmkIWyK0fTmG%q-C)bV z#d-s_l<{|p2o!tEdSt;RS|?r-?+J;CQnua#G3BpLkr_PJxLlS&z&g#Tge0U&n5o#? zPL5=-283$_n~v<@qkKwOvuIp;bSy*4uo+Ip#h9Z|FAL`WWJissxnR1b_!ufqM92e1X63q?bZW_z#S9-!2~IJ^Y1t=r@+6QS67DmNu0K;K)T-f?|cJnpih)pp!&#?QKNsjVbckN6{VZrG!Rsk-;FGig>nb^uD0z zOIXCOkYzT5iu{Xsx?1b^HKe2)0ZZ20MYPI3hYVO{c^o#i!r98zT~~H3hAX<7b427PdG-6n{RJrErvtN1ko1JfUf1^9R>v**BM6|Li1m#2-3_I*==AT$l zzh%5(yi>U~YVCWka_d~xJ2kh4Vh2w}4xWnce=@S;$=P=JON?0iu!T?e@q5<307)Ef zZ!OzjtNua079sYwGbGOX>_6e5zzxq*Y#u?qrmOeh!Gq|-KgHc-F6?vc&)Uh!DSp@{ zsy8VsHH-h><=EHd7%PZ9ousrsMgRmu@k!Gvu791F%(4g`WAw}#VzLsttY&5zg-=p^ zXs{XJO-Sm2O~hYP$%Go1O#Zh9GK7+9N4WmCdYEfT2`Sa+>O$)S^d4jRSuq>2lmyW} z#USCK8Nk0xLGn#SFCHXc{1-J}q^E(uLi-NzU;J;Gvj}-hX1x!xz8;UUrgHG}xDiHgOJ_oJ>@xk-7jQ9qs0WAw* zck_Rw*kuGvu(eN+Rm@N#1yOqN_jm)l`sK_C{FFK1DV_6ICyQ%uk0<}Vv4X5_4eEI*bTHq}`vjYIGts8DEZv~4>70XEbX z<0B;%%Z7G^ChxAfd09gd6pgej7S(a)lI3g`)pI$8LQa&9GILle$HL=EuC5WA)oCA} z>28jN_V^JETe61s_?b{ebIC{a?0$(>qkmx0YN~%_F>C66Rh^@$TVBf*SHw$7@qWCZ zFkV_2FRy}y*>ZlTLQ{%D%=OFI_JpX8%PUyUW>GztU$kssQ6pz5S$M?$%%vetA@snpn4uxs43Ofh_c(y}#c#5BYpHqhqAG=&jO z;XP%RMo&+QmMMm=rQBLgu9zFKUs(*A@?V#jG@nu2=O|<}0UKa923WoT59v~@>_PbW z&F12^a@vZR%{Cw^;2j;?53#*Gi`S7?wL}gB5xxFHylebSz`Gu$Pj(*gL6PUij=EJ3 z?OyWoOZ$k}t|Gx1q%e`ey9y4Uf$n@m@Jcq)+~t;O0MwHz&1g~L(^oTj*fWl)D_q!1 zE;bLEn0?9ZMA)<@Q<+Zf?xcTL$gI=A3m$e!PRR(wdYxdlvYpoD4hq^RAj*T^MZsLFw99Y2!k*xWBwi+F7oW z8jWA-0Ng89?jw8+gg!jxb4Nk$+hN90Et|dj)Q{hqKx{xJBNIuOk@7Zi?r9hC!A^B@ z08(8g4Vx;?XUZ zpMc0Mdbli^!Tof#DSfzv47DkHIG5P}a-X6$_J?!Hwb{+Yay1Xx&YVllN+ORQxnRr>-?x(T2pde{c*?XomA^$L z=_osGCSU-WS@siX7id2veU)G|ybiE0iU^8trw0sf@2408-W`-r6~M*=5zN<8b-U;} ztNmZmEqjr@L)jk&^NB1N;#2P@@>hd;F?k(;!K_1V1M5SCU>*28@_eQii8eNqtqBzLWB9rdqI#eQ;V; zl?ex?yqoE};=62zyvm_&R^cR!TJ{diNRqyi`)p2@Q-XkU&jEH-q`1^i&qo+0NufYo zn^l!!ahEhw-J--4`~)b`8tKqiqAHbBAz2X=PA4saSM8-&`qN;p~cY1_3Upn&Oj zSRtP(qnkW-2m)f6l4Xa%5Zr;72WOkO&yje>4G9Vyz0;<2^0^Pc-6XUmSMuWKdqJ8R zXinG|)bT@{Oiqg>VpC74A`E455st#W7C{>syi67{`%$vBV`3q>R+Gv$dly``Y)^$q*s*^jm^K_4zUTIJuPXQp<^g*YoNVeEn25?iF$9Y0K2G&ta9VgxeoW4L3Y zgG4rcjd?D|&Mnqp;YKesj?NlMXoWa-=3jKPJ=!A4D^R>A(3nK&YB0D84EqTf5}OUO z7rSjrV+Cs?1#4pk4UvL|g&om?ts%p&0ICY-cD~hcqaj+lC2lR6ZJTMEI~TDwgpCc- z!^Rtp(b86WNV{(gaGeu2HWDDpzpB5Yhoe2|Jum$9g^hK>BjXig)KDwEl`Rwr4z92% zUQmty_p;Ol`Vj5RH|C>^H!mzq+$xT=9*eCz9$9xB99+(E#Zs|-K7VmXq;7Y#c+a%< zr-kKTzBp}Owg9OA=0OgZTe)O`zZ#Xzw8E*89b55pi)c4a*uL#%_TuMmRo`;lIuNcn z8nYY=TaKY#7TbScHdEH$G6#7-RM0$gvGQ5qRdopy z(IdUmp%?-0S5VMo`w4ib7=s&(H0~4!ZmiBu$^!Tx_6CMYR+@4QRM#q(^He#8G-_Gp zoQjOxCY}kG$YEj+W3GH$M*xN@7J((l?Q|e|S;9=BVwJBD$ zeQ_|P{)w*QVFeG06{<=gn-o0ETA-+1ICVtsCu%sO#MUF5R3{>7QtL!pB;ia1ewh*v zS4?;?PZc9>i>+CUO+>jo?S~R3n(h-O+8N5Z0VOcwCPBVYLbrqTfC3s))_C&!bnJmz zq1o{fgiWIBJ<73aCaLiBIxxz}_eKOFG_&(++Z&t0EKn-Iy@92Rsbo(YQjE=lwd`;<|}Ia9BjHfnXgi zo}9%dCvpLOfFKvQvq4+n866*<7`0CXVR#*c8bR6F@U z;sGr5u}m$j6PB-VmKHwEx>2@($2ba}q5@tREVA(-^^4$c2d5J-;aCfm8_;Ir%{u*?f znR`*MKO%QlfSm;+X^^M|gRa@BBHqjeVHS?R!2LRW7Q2j84h_ZniKJ(j!w5~aiZs>G@QGl6 z2{SSX9x~b(rR+dL3Di1v^{0wXK5bc9$yP}{fHHMbnUHWMg;W`RrDv?6&Y_a}N+zTM zEmvtpWubudtrX7`DVk4+Gm4Wrq*&6s2xFbWXccS}6C`r+vqf&Frdk z(TXamSO8^(!$8tY=HDBmPnSo5y@fel37u=x$N=-lM6}MqUP1Uu%}(f(2sdFm=keg2 zCLgTb#&D(~(-}>UclqtG`t*t}v9S2CTA^UAvY+;K`X^}hGQwV1P|Wy%cVihqONSNO z0)qrPo)DG_E}VQp=}^!3m`BwBld#Z!FRXSdhgNo*ohKokV7Qx`;^KD|07j4$r6DkxBKdPd%A`2^4SWOWWml*LcXalUE*R)1QeLE0I=S71B)?Ms&qX5YemepZZ`>`d zyt?zs&Z~Q`>|NM>vo>0~mx;NTO31M#OhDhNyis|x>=qxbKN>ALhBwT2E1U1=%E(Wa zcyvd&bp3)eX4o7tY+kCYo$rmU*&121HCoxq)V|oDAG0<^tW68UQ7g8&y{sd5t!($s zS9h~-eP}R+g0wI0rQLJw*kHeuR|a3ewlj+O8DH2RjGUPWkDOmtDRM6=;KJH!f-gz7 zkK34mDx+o30OeYMrBxnO6_4?tx*1D?)&RQzJY_&9rQvxB5Rn$4^xH~D1SHKT@V4JB z@ZvCsz-Y}ZXJ&64>ct67vNY+)%q$uMy zluVmU94QEbfIG=wGD-C_zDt{}AKeN}V!Ftr=~I`eN^VujB-MwUK@E?zSQ&Ilz#7tw zjDR!KX4Sqe@>}HKDG5-MW37_Phbt#p>hXv#R0}mp4IP@}oa-bTq1x1RAwJlFCs3$m zlEnAdOZoA&4DZ5~R0bJ&4`xI9CFw3wXHyfT*Vt-qmGa^nEmA!B>(olEhoqnDA>JYF z8qCMa!0HjMzJHe~PqL3uhTMnJJMdC0*CLrlCt6AG$>nCsf1ONJ-PTH&GxtzTiEhu8VMeD9Ggx)|c_r1~|&5 z!-4rGyq5Heoye9+Y5`p^WYCWMf=#27av!AILllhDEt3gcq}U*(F&V=yiZMA=4m~(Z z!4V1wh4BodQ3C}z{^|-QOhQF*S{19eByS2(kuOp}6HE}fv#(-Q1>b+G+FJr75H#q-Kd3DkJoNecV=r!yewV)x|CO z_f*P!6Ff3rExl5Dwem`3xPI@gyl8PJOAj4@pu}vue0J(A9ZWdqobSAT_}bwG$F&nt z(tlPnBYV4}B_~72 zR$<`Hi@?A;BZi&v`leX@zDWH(9Bol_yXf7T+vU;v?$EJ2#YnMNGM6z-43<*S(Xj4xUHC0~1e~zYNar#u07ux@%iAd^1kW8>csg4_;g8vXR zJ8dI_|Bz4J=KcdvF~K?fAOQ0Qi^3FE*>5LK6m1n~REhh1h}k3?K>h`^ zNfHpXr#F~Pj}mLr8=R3SRiinp8}I#tH2&v3a8yfTcoB+V`#BvkEgEm~E;?t{DvzWt)Qx1Li9@Gpn>%%%AV45td&(Kx?66s!f(tnW zge&X>UmTN5OgUAPULz9{K|JEIr&yKX3xZEx(c)8Lb5;!k&e%!v>x2L#cv&Wyf*AO1WAgphfj z0w@k~WQBO-3B2Bss)nzjCj>>rgOYeBB`}2Orx?SEUqp<#+G1!>Bx z^EwZs82CK0fn)p7S~RmQVqLeOj#@W`ba8_fC@gL*4Rr&2Pj8PJO0kVkIEtY2&Fbq7 z*BWqwVa(7RF*Gk(^Jllevi-I-T6W+E)|>6I^5#f+^Fr|b@@==Rk+K6(>%p+`pkT?4 zQ)lLDqSm#7n-4mHp!}fVCG2)V$^Kj&Fg{^#fx2J>$TeyuqZ+y?MnnX>^O5+gR?^M*lY!M`u*0!Ql3$}x{Wr{&7rM2B# zop#@-(dd6|(rfBgI0S-ZooeAm1>{RD#1t!M5y_x;$#e=ZX<#59$H4?SDROl2{NJ+* ztWkFH+3W;hJ>_wj=~x0d^EWv3I1gA`Z%G(2Sp`4-Bp9#6?3}zNNX0uvktjs2K0t4V z#{f(Tozo+dCu5-LWK#*ClKr}C#KR<{H29J%T+bl23lblg-v=+zbJ6!d!C!J0MrvE zRw9jIi^8Yu{>&$ook@D~sItKhW$bk^f9AeN9SR@!med02*M|chA;6XEMM#mkU{O^N zUrG=*?aDX`8AZgfcaeG4WFZbsiojb$+fr>XhWi8r3n^BncXi2kQEPebpr+NyFGw*m zB4*or<$SEIsWyZe+K2Xkg(N-w1kc$8KS|jV2_TRBHjI46uWsO&m>O^6IWeYbSH}KR z_ece&eI-8wx0eeiV-lRr<6b2Mt|J6)>r6X_zc#d?Z*5fprnt4q^!simHTBM0`x#GimDu z?gF<)+Zq`@r08S?+h*nst`pc95j|2t%wb)^AZ%t7Bu3ZJ0zm~)kwSrXQtIDRK-~sE zSphoRLyNttbQ~xlw}cj!5jbyyCz65xi~=dDp}_!iFHD$uU}#W~76tK$nGW$!QS24U zsl|43aF9HiHwq%AbrkzP6;(q4qky&Ui81m800P78T2@TpVUtWowlO|QCO(=oGOTNW+5;{?GT~6J;W_8b z*qgmELqo*S5HBhT9a$=>m^(l3x$@lmMNQLcqKmi9Y@6LZvwPkav#yO;*T!w-F?vAe6^S*8GQhwp=lQU1wJ~Q*od}l1bK9XNAzcAnXzOCWI zlC|Ob{kOK7ef?e=Kpj%ZQa z)R7O_lY#o;|zH=|?Tgt6cFnt4(E7iYhH z_Etss2}iWGH)?mjZ#e~59{Ck8dMc~F+IFRFer&NPTDo1>rI3^x)x;~Sm?DN`OU3w# zfw?4nwJx-O$y_ka&uxvFt6)|-KNMTr8d=+VvpTwV7kwTMr(>4th^2ad)AiPCt?yeJ zPy(cLa{A)nUwFUcjkG#$)yDSsME3W**XoGY^dejcTtE2cK|G1n^uBL#!d;5B6z)>i zt-pTmjdRi5O>dl=3%pu0?f#RxYv)4RaPFq1s+#M?Zxk=IM#<=AN2F>;v}z}{q~b%f z_03K&fHdH^2gY zp1Cw5WxRsiKS>U!@FjYfcJ;sO?O4hCK zP&D(HMJt+kw@S2F?k?Fz@k* z>MOx8fh=0#vnU_C??*;GC#k*={*XZuIR1KTLObXUfDqxoihKz(6L$J943L>NXp=-% z5c^0V4iJc*zry$kQY0tk)(W261^R^TSbmQ(v{P^Z0SF=(Sp@h~l)02dMKF03zZ_^L zSuW@gst|vTa;WJkh54Q*{sDH>cw;~~SuCL@`)a-f zB_$gUI!my$hP*gol+4NLa1HiJegV0eVI&P($|cxaGi5|Us*ojVMMv_`z@OuFXd?1B z`1UP-rFUTkrQ)TX@%*A#>BdOu#>KTaH-5VzQo1{m zzng9Sn9F_XbI{*}K8F{uHH4jZzEBXgZ4YG=VlIi{-~~%n)KU}DKbr;dvLXuMZ_SA2JjO_b3(y@On+8a6}>8x}tk+uDWyX}Mz++aeX)?ijb@ zL!coHE2zPKMW~Yqh)Xz;9Ad|B9$GvSZS0ITo?I%boX?IHZJ5@?YwPDaQKxuK-K>5_ zKer`nvE#@Hyb59f4%K$f?3~>@vvX#%;02 z-I2!K?>Fv^Sn8(rackxEwrg!LdWkn}dh4khPu(;|nhs1Km_0gkbZ#D=>25d?(01p^VV$uc5$rih>khD`w4CfNO3$uf(2`WLK9cYGN_)loP z%uzu2n(d1gw$;-DmTaAF(DQFl@GequC^;Ym7!ztBTQtIQ`5H1X0eX@xNCPLAY0;d( zTX3ob7khnEebd8{obr&8j{hi~DV;kN$*rA7&=}H^?0oR`#+YFpyidfz8eme*orfqq zY+Q>|KFAOD^zcj1hjrylF1y;?oN9(|DOP@g2-$EG<^~M_*JeGaW%@-saHN3&EvIy- z7$#avgXvH*sYNrezC2x=vP3TvwdPPeG&u1R?tQ5mLRl{G_DM5K!QR=Cy#e(h3v50}J8dr){3q`MZ^?i-ZK{Fx@vA65fwOo#>^Mk8v*3eC zLB~m642;7n`j`JNE_PyLFO&*2`aDro7;N@fm>EB!WHtwATOKj6a0iqH<73+2&Ck#T z$O6+!)d)~;VRi^}W0oEjt^%bJi#{|sOVi>D$PN>D9cL^eO$rW1JRZ-hh~-sB@~Y=s z!|V4%;qfplZpaH4)JF~VaicX{*bp@~&`FMmzHE$HY9f}J`Hsb^u%#w!*|OLZv1||T z>W*5vKeSfPXGN@s7Hh-SLpN~{^r5iv(1)Tc_2*xFer_aUZe1ut5Z1K{`vFqth&-3Z zQNu?822f3yEj~rc9Dy`anp4tZXS59@Ynjb>B?=5OBfLP8B;g>4Z{fS}`@&W#69upY zQG-r{k*Yz5w^2Hx1D1`pRlO~yex14+HAPpaaf{nx=9;LvCT__i{`61u_ch?{9_|ep zO#U6~jedl}m5JYhtVg!vvJ)3YZ>L?tfeo2<&x!IvwsNsTWDxngmXir|N`0LKv2Z3h zp?A2*P$9WnP15`&9T`dTA$E8%G~vRAE&B!7`oq)Ud9QyW0MqjcSy4?M1nDg02^4$b zj29NKOnrUZ+Kj69E1&`kx=6fK4br#h7C7>zMTE-YgSE0oEe} z(oQ+q?C$viU`G+!%PEQsV3-vi%|R8n|}dDcAtbZ;Be$ z0N96)fPk^sa763jnZt8W%p4P_qL{5YVr!1tHbl)E!@7-Zc7g^b(pA7=tVtw0Bg$BA z*(AkYMu1fCKce>mzloJB9 zvVy1!anKkEhBN9wnW~crN}i!fqD8SSK$tx|K}Y^J1@BP6^q}jhB$BNP0Q5i70}4pZ z7ogqa-HDlO^QENOlPX??uc}5Ol#$KV7fv zJ&3-mrHZ~xbzDQS&rv{NMsP*c%=EzUM?5e*5uD&X0|Pu-f%9wV1vVl7o^Im^V9brp z<^(5r0=y7XPWWJc;vWq#<83~R5)H5?!Zr_Kzom}yGFfXD2{Un$9B-sdT3{h80X~PG z1O0NmnF2y#2}8-yZrbiCS?^4O*%*gK+4bYr4+2AfC+Bv zDF(Kd<6oeFfHMDM1PPrqPXzqgN)2Y)vd&SA4eIKFhGT63TI}K9Ll*q!0(OKjlU69; z2}hy+sDM+{{G7}AfHQr-Wq-gKe#X@-t2K(+pQ$%|q~{a`|G*XfjH{1u^*`tIA8>ge zaJe5)ZUc)WV!Ws2bmq|cd0QlF?F-l_U!o{n-o+hJbSaiqMT*?G0mo|5@4EHEZ~K9d z=>9>gLXkap_5p{V2fA*B!Zuy=fTQ0BL&`!0_=f%m9Q{7n;?O8QqnO@2+d9)qbX&~2 z4yFQ^c@P4E2(T42Sd z-z+ int: - return self.num_samples - - def metadata_hasattr(self, attr) -> bool: - return attr in self._metadata - - @cached_property - def indices(self): - return np.arange(self.num_samples, dtype=int) - - @cached_property - def _metadata(self) -> dict[str, np.ndarray]: - # logic to read metadata file here - metadata_npzs = [] - if self.config.get("metadata_path", None) is not None: - metadata_npzs.append( - np.load(self.config["metadata_path"], allow_pickle=True) - ) - - else: - for path in self.paths: - if path.is_file(): - metadata_file = path.parent / "metadata.npz" - else: - metadata_file = path / "metadata.npz" - if metadata_file.is_file(): - metadata_npzs.append(np.load(metadata_file, allow_pickle=True)) - - if len(metadata_npzs) == 0: - logging.warning( - f"Could not find dataset metadata.npz files in '{self.paths}'" - ) - return {} - - metadata = { - field: np.concatenate([metadata[field] for metadata in metadata_npzs]) - for field in metadata_npzs[0] - } - - assert np.issubdtype( - metadata["natoms"].dtype, np.integer - ), f"Metadata natoms must be an integer type! not {metadata['natoms'].dtype}" - assert metadata["natoms"].shape[0] == len( - self - ), "Loaded metadata and dataset size mismatch." - - return metadata - - def get_metadata(self, attr, idx): - if attr in self._metadata: - metadata_attr = self._metadata[attr] - if isinstance(idx, list): - return [metadata_attr[_idx] for _idx in idx] - return metadata_attr[idx] - return None - - -class Subset(BaseDataset): - """A subset that also takes metadata if given.""" - - def __init__( - self, - dataset: BaseDataset, - indices: list[int], - metadata: dict[str, np.ndarray], - ) -> None: - super().__init__(dataset.config) - self.dataset = dataset - self.metadata = metadata - self.indices = indices - self.num_samples = len(indices) - self.config = dataset.config - - @cached_property - def _metadata(self) -> dict[str, np.ndarray]: - return self.dataset._metadata # pylint: disable=protected-access - - def get_metadata(self, attr, idx): - if isinstance(idx, list): - return self.dataset.get_metadata(attr, [[self.indices[i] for i in idx]]) - return self.dataset.get_metadata(attr, self.indices[idx]) - - -class LMDBDatabase(ase.db.core.Database): - """ - This module is modified from the ASE db json backend - and is thus licensed under the corresponding LGPL2.1 license. - - The ASE notice for the LGPL2.1 license is available here: - https://gitlab.com/ase/ase/-/blob/master/LICENSE - """ - - def __init__( # pylint: disable=keyword-arg-before-vararg - self, - filename: str | Path | None = None, - create_indices: bool = True, - use_lock_file: bool = False, - serial: bool = False, - readonly: bool = False, # Moved after *args to make it keyword-only - *args, - **kwargs, - ) -> None: - """ - For the most part, this is identical to the standard ase db initiation - arguments, except that we add a readonly flag. - """ - super().__init__( - Path(filename), - create_indices, - use_lock_file, - serial, - *args, - **kwargs, - ) - - # Add a readonly mode for when we're only training - # to make sure there's no parallel locks - self.readonly = readonly - - if self.readonly: - # Open a new env - self.env = lmdb.open( - str(self.filename), - subdir=False, - meminit=False, - map_async=True, - readonly=True, - lock=False, - ) - - # Open a transaction and keep it open for fast read/writes! - self.txn = self.env.begin(write=False) - - else: - # Open a new env with write access - self.env = lmdb.open( - str(self.filename), - map_size=1099511627776 * 2, - subdir=False, - meminit=False, - map_async=True, - ) - - self.txn = self.env.begin(write=True) - - # Load all ids based on keys in the DB. - self.ids = [] - self.deleted_ids = [] - self._load_ids() - - def __enter__(self) -> "LMDBDatabase": - return self - - def __exit__(self, exc_type, exc_value, tb) -> None: - self.close() - - def close(self) -> None: - # Close the lmdb environment and transaction - self.txn.commit() - self.env.close() - - def _write( - self, - atoms: ase.Atoms | ase.db.row.AtomsRow, - key_value_pairs: dict, - data: dict | None, - id: int | None = None, # pylint: disable=redefined-builtin - ) -> None: - # Call parent method with the original parameter name - super()._write(atoms, key_value_pairs, data) - - mtime = ase.db.core.now() - - if isinstance(atoms, ase.db.row.AtomsRow): - row = atoms - else: - row = ase.db.row.AtomsRow(atoms) - row.ctime = mtime - row.user = os.getenv("USER") - - dct = {} - for key in row.__dict__: - # Use getattr to avoid accessing protected member directly - if key[0] == "_" or key == "id" or key in getattr(row, "_keys", []): - continue - dct[key] = row[key] - - dct["mtime"] = mtime - - if key_value_pairs: - dct["key_value_pairs"] = key_value_pairs - - if data: - dct["data"] = data - - constraints = row.get("constraints") - if constraints: - dct["constraints"] = [constraint.todict() for constraint in constraints] - - # json doesn't like Cell objects, so make it an array - dct["cell"] = np.asarray(dct["cell"]) - - if id is None: - id = self._nextid - nextid = id + 1 - else: - data = self.txn.get(f"{id}".encode("ascii")) - assert data is not None - - # Add the new entry - self.txn.put( - f"{id}".encode("ascii"), - zlib.compress(orjson.dumps(dct, option=orjson.OPT_SERIALIZE_NUMPY)), - ) - # only append if idx is not in ids - if id not in self.ids: - self.ids.append(id) - self.txn.put( - "nextid".encode("ascii"), - zlib.compress(orjson.dumps(nextid, option=orjson.OPT_SERIALIZE_NUMPY)), - ) - # check if id is in removed ids and remove accordingly - if id in self.deleted_ids: - self.deleted_ids.remove(id) - self._write_deleted_ids() - - return id - - def _update( - self, - idx: int, - key_value_pairs: dict | None = None, - data: dict | None = None, - ): - # hack this to play nicely with ASE code - row = self._get_row(idx, include_data=True) - if data is not None or key_value_pairs is not None: - self._write( - atoms=row, key_value_pairs=key_value_pairs, data=data, id=idx - ) # Fixed E1123 by using id=idx - - def _write_deleted_ids(self): - self.txn.put( - "deleted_ids".encode("ascii"), - zlib.compress( - orjson.dumps(self.deleted_ids, option=orjson.OPT_SERIALIZE_NUMPY) - ), - ) - - def delete(self, ids: list[int]) -> None: - for idx in ids: - self.txn.delete(f"{idx}".encode("ascii")) - self.ids.remove(idx) - - self.deleted_ids += ids - self._write_deleted_ids() - - def _get_row(self, idx: int, include_data: bool = True): - if idx is None: - assert len(self.ids) == 1 - idx = self.ids[0] - data = self.txn.get(f"{idx}".encode("ascii")) - - if data is not None: - dct = orjson.loads(zlib.decompress(data)) - else: - raise KeyError(f"Id {idx} missing from the database!") - - if not include_data: - dct.pop("data", None) - - dct["id"] = idx - return ase.db.row.AtomsRow(dct) - - def _get_row_by_index(self, index: int, include_data: bool = True): - """Auxiliary function to get the ith entry, rather than a specific id""" - data = self.txn.get(f"{self.ids[index]}".encode("ascii")) - - if data is not None: - dct = orjson.loads(zlib.decompress(data)) - else: - raise KeyError(f"Id {id} missing from the database!") - - if not include_data: - dct.pop("data", None) - - dct["id"] = id - return ase.db.row.AtomsRow(dct) - - def _select( - self, - keys, - cmps: list[tuple[str, str, str]], - explain: bool = False, - _verbosity: int = 0, # Unused parameter marked with underscore - limit: int | None = None, - offset: int = 0, - sort: str | None = None, - include_data: bool = True, - _columns: str = "all", # Unused parameter marked with underscore - ): - if explain: - yield {"explain": (0, 0, 0, "scan table")} - return - - if sort is not None: - if sort[0] == "-": - reverse = True - sort = sort[1:] - else: - reverse = False - - rows = [] - missing = [] - for row in self._select(keys, cmps): - key = row.get(sort) - if key is None: - missing.append((0, row)) - else: - rows.append((key, row)) - - rows.sort(reverse=reverse, key=lambda x: x[0]) - rows += missing - - if limit: - rows = rows[offset : offset + limit] - for _, row in rows: - yield row - return - - if not limit: - limit = -offset - 1 - - cmps = [(key, ase.db.core.ops[op], val) for key, op, val in cmps] - n = 0 - for idx in self.ids: - if n - offset == limit: - return - row = self._get_row(idx, include_data=include_data) - - for key in keys: - if key not in row: - break - else: - for key, op, val in cmps: - if isinstance(key, int): - value = np.equal(row.numbers, key).sum() - else: - value = row.get(key) - if key == "pbc": - assert op in [ase.db.core.ops["="], ase.db.core.ops["!="]] - value = "".join("FT"[x] for x in value) - if value is None or not op(value, val): - break - else: - if n >= offset: - yield row - n += 1 - - @property - def metadata(self): - """Override abstract metadata method from Database class.""" - return self.db_metadata - - @property - def db_metadata(self): - """Load the metadata from the DB if present""" - if self._metadata is None: - metadata = self.txn.get("metadata".encode("ascii")) - if metadata is None: - self._metadata = {} - else: - self._metadata = orjson.loads(zlib.decompress(metadata)) - - return self._metadata.copy() - - @db_metadata.setter - def db_metadata(self, dct): - self._metadata = dct - - # Put the updated metadata dictionary - self.txn.put( - "metadata".encode("ascii"), - zlib.compress(orjson.dumps(dct, option=orjson.OPT_SERIALIZE_NUMPY)), - ) - - @property - def _nextid(self): - """Get the id of the next row to be written""" - # Get the nextid - nextid_data = self.txn.get("nextid".encode("ascii")) - if nextid_data: - return orjson.loads(zlib.decompress(nextid_data)) - return 1 # Removed unnecessary else (R1705) - - def count(self, selection=None, **kwargs) -> int: - """Count rows. - - See the select() method for the selection syntax. Use db.count() or - len(db) to count all rows. - """ - if selection is not None: - n = 0 - for _row in self.select(selection, **kwargs): - n += 1 - return n - return len(self.ids) - - def _load_ids(self) -> None: - """Load ids from the DB - - Since ASE db ids are mostly 1-N integers, but can be missing entries - if ids have been deleted. To save space and operating under the assumption - that there will probably not be many deletions in most OCP datasets, - we just store the deleted ids. - """ - # Load the deleted ids - deleted_ids_data = self.txn.get("deleted_ids".encode("ascii")) - if deleted_ids_data is not None: - self.deleted_ids = orjson.loads(zlib.decompress(deleted_ids_data)) - - # Reconstruct the full id list - self.ids = [i for i in range(1, self._nextid) if i not in set(self.deleted_ids)] - - -# Placeholder for AtomsToGraphs class -# This is a minimal implementation without the full functionality -class AtomsToGraphs: - """Enhanced AtomsToGraphs implementation with proper property handling.""" - - def __init__( - self, - r_edges=False, - r_pbc=True, - r_energy=False, - r_forces=False, - r_stress=False, - r_data_keys=None, - **kwargs, - ): - self.r_edges = r_edges - self.r_pbc = r_pbc - self.r_energy = r_energy - self.r_forces = r_forces - self.r_stress = r_stress - self.r_data_keys = r_data_keys or {} - self.kwargs = kwargs - - def convert(self, atoms, sid=None): - """ - Convert ASE atoms to graph data format with proper property handling. - """ - from mace.tools.torch_geometric.data import Data - - # Create a minimal data object with required properties - data = Data() - - # Set positions - data.pos = torch.tensor(atoms.get_positions(), dtype=torch.float) - - # Set atomic numbers - data.atomic_numbers = torch.tensor(atoms.get_atomic_numbers(), dtype=torch.long) - - # Set cell if available - if atoms.cell is not None: - data.cell = torch.tensor(atoms.get_cell(), dtype=torch.float) - - # Set PBC if requested - if self.r_pbc: - data.pbc = torch.tensor(atoms.get_pbc(), dtype=torch.bool) - - # Set energy if requested - if self.r_energy: - energy = self._get_property(atoms, "energy") - if energy is not None: - data.energy = torch.tensor(energy, dtype=torch.float) - - # Set forces if requested - if self.r_forces: - forces = self._get_property(atoms, "forces") - if forces is not None: - data.forces = torch.tensor(forces, dtype=torch.float) - - # Set stress if requested - if self.r_stress: - stress = self._get_property(atoms, "stress") - if stress is not None: - data.stress = torch.tensor(stress, dtype=torch.float) - - # Set sid if provided - if sid is not None: - data.sid = sid - - return data - - def _get_property(self, atoms, prop_name): - """Get property from atoms, checking custom names first then standard methods.""" - # Check if we have a custom name for this property - custom_name = self.r_data_keys.get(prop_name) - - # Try custom name in info dict - if custom_name and custom_name in atoms.info: - return atoms.info[custom_name] - - # Try custom name in arrays dict - if custom_name and custom_name in atoms.arrays: - return atoms.arrays[custom_name] - - # Try standard name in info dict - if prop_name in atoms.info: - return atoms.info[prop_name] - - # Try standard name in arrays dict - if prop_name in atoms.arrays: - return atoms.arrays[prop_name] - - # Try standard ASE methods - method_map = { - "energy": "get_potential_energy", - "forces": "get_forces", - "stress": "get_stress", - } - - if prop_name in method_map and hasattr(atoms, method_map[prop_name]): - try: - method = getattr(atoms, method_map[prop_name]) - return method() - except ( - AttributeError, - RuntimeError, - ) as exc: # Fixed W0718 by specifying exceptions - logging.debug(f"Error getting property {prop_name}: {exc}") - # Removed unnecessary pass (W0107) - - return None - - -# Placeholder for DataTransforms class -class DataTransforms: - """Minimal implementation of DataTransforms to satisfy dependencies.""" - - def __init__(self, transforms_config=None): - self.transforms_config = transforms_config or {} - - def __call__(self, data): - """Apply transforms to data""" - # No transforms applied in this minimal implementation - return data - - -class AseAtomsDataset(BaseDataset, ABC): - """ - This is an abstract Dataset that includes helpful utilities for turning - ASE atoms objects into OCP-usable data objects. This should not be instantiated directly - as get_atoms_object and load_dataset_get_ids are not implemented in this base class. - - Derived classes must add at least two things: - self.get_atoms_object(id): a function that takes an identifier and returns a corresponding atoms object - - self.load_dataset_get_ids(config: dict): This function is responsible for any initialization/loads - of the dataset and importantly must return a list of all possible identifiers that can be passed into - self.get_atoms_object(id) - - Identifiers need not be any particular type. - """ - - def __init__( - self, - config: dict, - atoms_transform: Callable[[ase.Atoms, Any], ase.Atoms] = apply_one_tags, - ) -> None: - super().__init__(config) - - a2g_args = config.get("a2g_args", {}) or {} - - # set default to False if not set by user, assuming otf_graph will be used - if "r_edges" not in a2g_args: - a2g_args["r_edges"] = False - - # Make sure we always include PBC info in the resulting atoms objects - a2g_args["r_pbc"] = True - self.a2g = AtomsToGraphs(**a2g_args) - - self.key_mapping = self.config.get("key_mapping", None) - self.transforms = DataTransforms(self.config.get("transforms", {})) - - self.atoms_transform = atoms_transform - - if self.config.get("keep_in_memory", False): - self.__getitem__ = cache(self.__getitem__) - - self.ids = self._load_dataset_get_ids(config) - self.num_samples = len(self.ids) - - if len(self.ids) == 0: - raise ValueError( - rf"No valid ase data found! \n" - f"Double check that the src path and/or glob search pattern gives ASE compatible data: {config['src']}" - ) - - def __getitem__(self, idx): # pylint: disable=method-hidden - # Handle slicing - if isinstance(idx, slice): - return [self[i] for i in range(*idx.indices(len(self)))] - - # Get atoms object via derived class method - atoms = self.get_atoms(self.ids[idx]) - - # Transform atoms object - if self.atoms_transform is not None: - atoms = self.atoms_transform( - atoms, **self.config.get("atoms_transform_args", {}) - ) - - sid = atoms.info.get("sid", self.ids[idx]) - fid = atoms.info.get("fid", torch.tensor([0])) - - # Convert to data object - data_object = self.a2g.convert(atoms, sid) - data_object.fid = fid - data_object.natoms = len(atoms) - - # apply linear reference - if self.a2g.r_energy is True and self.lin_ref is not None: - data_object.energy -= sum(self.lin_ref[data_object.atomic_numbers.long()]) - - # Transform data object - data_object = self.transforms(data_object) - - if self.key_mapping is not None: - data_object = rename_data_object_keys(data_object, self.key_mapping) - - if self.config.get("include_relaxed_energy", False): - data_object.energy_relaxed = self.get_relaxed_energy(self.ids[idx]) - - return data_object - - @abstractmethod - def get_atoms(self, idx: str | int) -> ase.Atoms: - # This function should return an ASE atoms object. - raise NotImplementedError( - "Returns an ASE atoms object. Derived classes should implement this function." - ) - - @abstractmethod - def _load_dataset_get_ids(self, config): - # This function should return a list of ids that can be used to index into the database - raise NotImplementedError( - "Every ASE dataset needs to declare a function to load the dataset and return a list of ids." - ) - - def get_relaxed_energy(self, identifier): - raise NotImplementedError( - "Reading relaxed energy from trajectory or file is not implemented with this dataset. " - "If relaxed energies are saved with the atoms info dictionary, they can be used by passing the keys in " - "the r_data_keys argument under a2g_args." - ) - - def get_metadata(self, attr, idx): - # try the parent method - metadata = super().get_metadata(attr, idx) - if metadata is not None: - return metadata - # try to resolve it here - if attr != "natoms": - return None - if isinstance(idx, (list, np.ndarray)): - return np.array([self.get_metadata(attr, i) for i in idx]) - return len(self.get_atoms(idx)) - - -class AseDBDataset(AseAtomsDataset): - """ - This Dataset connects to an ASE Database, allowing the storage of atoms objects - with a variety of backends including JSON, SQLite, and database server options. - """ - - def _load_dataset_get_ids(self, config: dict) -> list[int]: - if isinstance(config["src"], list): - filepaths = [] - for path in sorted(config["src"]): - if os.path.isdir(path): - filepaths.extend(sorted(glob(f"{path}/*"))) - elif os.path.isfile(path): - filepaths.append(path) - else: - raise RuntimeError(f"Error reading dataset in {path}!") - elif os.path.isfile(config["src"]): - filepaths = [config["src"]] - elif os.path.isdir(config["src"]): - filepaths = sorted(glob(f'{config["src"]}/*')) - else: - filepaths = sorted(glob(config["src"])) - - self.dbs = [] - - for path in filepaths: - try: - self.dbs.append(self.connect_db(path, config.get("connect_args", {}))) - except ValueError: - logging.debug( - f"Tried to connect to {path} but it's not an ASE database!" - ) - - self.select_args = config.get("select_args", {}) - if self.select_args is None: - self.select_args = {} - - # Get all unique IDs from the databases - self.db_ids = [] - for db in self.dbs: - if hasattr(db, "ids") and self.select_args == {}: - self.db_ids.append(db.ids) - else: - # this is the slow alternative - self.db_ids.append([row.id for row in db.select(**self.select_args)]) - - idlens = [len(ids) for ids in self.db_ids] - self._idlen_cumulative = np.cumsum(idlens).tolist() - - return list(range(sum(idlens))) - - def get_atoms(self, idx: int) -> ase.Atoms: - """Get atoms object corresponding to datapoint idx. - Args: - idx (int): index in dataset - - Returns: - atoms: ASE atoms corresponding to datapoint idx - """ - # Figure out which db this should be indexed from - db_idx = bisect.bisect(self._idlen_cumulative, idx) - - # Extract index of element within that db - el_idx = idx - if db_idx != 0: - el_idx = idx - self._idlen_cumulative[db_idx - 1] - assert el_idx >= 0 - - # Use a wrapper method to avoid protected access warning - atoms_row = self.get_row_from_db(db_idx, el_idx) - - # Convert to atoms object - atoms = atoms_row.toatoms() - - # Put data back into atoms info - if isinstance(atoms_row.data, dict): - atoms.info.update(atoms_row.data) - - # Add key-value pairs directly to atoms.info - if hasattr(atoms_row, "key_value_pairs") and atoms_row.key_value_pairs: - atoms.info.update(atoms_row.key_value_pairs) - - # Create a SinglePointCalculator to attach energy, forces and stress to atoms - calc_kwargs = {} - - # Check for energy, forces, stress in atoms_row and store in info & calc_kwargs - for prop in ["energy", "forces", "stress", "free_energy"]: - if hasattr(atoms_row, prop) and getattr(atoms_row, prop) is not None: - value = getattr(atoms_row, prop) - calc_kwargs[prop] = value - atoms.info[prop] = value - - # If we have custom data mappings, copy the standard properties to the custom names - a2g_args = self.config.get("a2g_args", {}) or {} - r_data_keys = a2g_args.get("r_data_keys", {}) - if r_data_keys: - # Map from standard names to custom names (in reverse of how they'll be used) - for custom_key, standard_key in r_data_keys.items(): - if standard_key in atoms.info: - atoms.info[custom_key] = atoms.info[standard_key] - elif standard_key in atoms.arrays: - atoms.arrays[custom_key] = atoms.arrays[standard_key] - - # Create calculator if we have any properties - if calc_kwargs: - from ase.calculators.singlepoint import SinglePointCalculator - - calc = SinglePointCalculator(atoms, **calc_kwargs) - atoms.calc = calc - - return atoms - - def get_row_from_db(self, db_idx, el_idx): - """Get a row from the database at the given indices.""" - db = self.dbs[db_idx] - row_id = self.db_ids[db_idx][el_idx] - if isinstance(db, LMDBDatabase): - return db._get_row(row_id) # pylint: disable=protected-access - return db.get(row_id) - - @staticmethod - def connect_db( - address: str | Path, connect_args: dict | None = None - ) -> ase.db.core.Database: - if connect_args is None: - connect_args = {} - db_type = connect_args.get("type", "extract_from_name") - if db_type in ("lmdb", "aselmdb") or ( - db_type == "extract_from_name" - and str(address).rsplit(".", maxsplit=1)[-1] in ("lmdb", "aselmdb") - ): - return LMDBDatabase(address, readonly=True, **connect_args) - - return ase.db.connect(address, **connect_args) - - def __del__(self): - for db in self.dbs: - if hasattr(db, "close"): - db.close() - - def sample_property_metadata( - self, - ) -> dict: # Removed unused argument num_samples (W0613) - """ - Sample property metadata from the database. - - This method was previously using the copy module which is now removed. - """ - logging.warning( - "You specified a folder of ASE dbs, so it's impossible to know which metadata to use. Using the first!" - ) - if self.dbs[0].metadata == {}: - return {} - - # Fixed unnecessary comprehension (R1721) - return dict(self.dbs[0].metadata.items()) +""" +This module contains the AseDBDataset class and its dependencies. +It is extracted from the fairchem codebase and adapted to remove dependencies on fairchem. + +Original code copyright: +Copyright (c) Meta, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +from __future__ import annotations + +import bisect +import logging +import os +import zlib +from abc import ABC, abstractmethod + +try: + from functools import cache, cached_property +except ImportError: + from functools import cached_property, lru_cache + + cache = lru_cache(maxsize=None) +from glob import glob +from pathlib import Path +from typing import Any, Callable, TypeVar + +import ase +import ase.db.core +import ase.db.row +import ase.io +import lmdb +import numpy as np +import orjson +import torch + +# Type variable for generic dataset return type +T_co = TypeVar("T_co", covariant=True) + + +def rename_data_object_keys(data_object, key_mapping: dict[str, str | list[str]]): + """Rename data object keys + + Args: + data_object: data object + key_mapping: dictionary specifying keys to rename and new names {prev_key: new_key} + + new_key can be a list of new keys, for example, + prev_key: energy + new_key: [common_energy, oc20_energy] + + This is currently required when we use a single target/label for multiple tasks + """ + for _property in key_mapping: + # catch for test data not containing labels + if _property in data_object: + list_of_new_keys = key_mapping[_property] + if isinstance(list_of_new_keys, str): + list_of_new_keys = [list_of_new_keys] + for new_property in list_of_new_keys: + if new_property == _property: + continue + assert new_property not in data_object + data_object[new_property] = data_object[_property] + if _property not in list_of_new_keys: + del data_object[_property] + return data_object + + +def apply_one_tags( + atoms: ase.Atoms, skip_if_nonzero: bool = True, skip_always: bool = False +): + """ + This function will apply tags of 1 to an ASE atoms object. + It is used as an atoms_transform in the datasets contained in this file. + + Certain models will treat atoms differently depending on their tags. + For example, GemNet-OC by default will only compute triplet and quadruplet interactions + for atoms with non-zero tags. This model throws an error if there are no tagged atoms. + For this reason, the default behavior is to tag atoms in structures with no tags. + + args: + skip_if_nonzero (bool): If at least one atom has a nonzero tag, do not tag any atoms + + skip_always (bool): Do not apply any tags. This arg exists so that this function can be disabled + without needing to pass a callable (which is currently difficult to do with main.py) + """ + if skip_always: + return atoms + + if np.all(atoms.get_tags() == 0) or not skip_if_nonzero: + atoms.set_tags(np.ones(len(atoms))) + + return atoms + + +class UnsupportedDatasetError(ValueError): + pass + + +class BaseDataset(ABC): + """Base Dataset class for all ASE datasets.""" + + def __init__(self, config: dict): + """Initialize + + Args: + config (dict): dataset configuration + """ + self.config = config + self.paths = [] + + if "src" in self.config: + if isinstance(config["src"], str): + self.paths = [Path(self.config["src"])] + else: + self.paths = tuple(Path(path) for path in sorted(config["src"])) + + self.lin_ref = None + if self.config.get("lin_ref", False): + lin_ref = torch.tensor( + np.load(self.config["lin_ref"], allow_pickle=True)["coeff"] + ) + self.lin_ref = torch.nn.Parameter(lin_ref, requires_grad=False) + + def __len__(self) -> int: + return self.num_samples + + def metadata_hasattr(self, attr) -> bool: + return attr in self._metadata + + @cached_property + def indices(self): + return np.arange(self.num_samples, dtype=int) + + @cached_property + def _metadata(self) -> dict[str, np.ndarray]: + # logic to read metadata file here + metadata_npzs = [] + if self.config.get("metadata_path", None) is not None: + metadata_npzs.append( + np.load(self.config["metadata_path"], allow_pickle=True) + ) + + else: + for path in self.paths: + if path.is_file(): + metadata_file = path.parent / "metadata.npz" + else: + metadata_file = path / "metadata.npz" + if metadata_file.is_file(): + metadata_npzs.append(np.load(metadata_file, allow_pickle=True)) + + if len(metadata_npzs) == 0: + logging.warning( + f"Could not find dataset metadata.npz files in '{self.paths}'" + ) + return {} + + metadata = { + field: np.concatenate([metadata[field] for metadata in metadata_npzs]) + for field in metadata_npzs[0] + } + + assert np.issubdtype( + metadata["natoms"].dtype, np.integer + ), f"Metadata natoms must be an integer type! not {metadata['natoms'].dtype}" + assert metadata["natoms"].shape[0] == len( + self + ), "Loaded metadata and dataset size mismatch." + + return metadata + + def get_metadata(self, attr, idx): + if attr in self._metadata: + metadata_attr = self._metadata[attr] + if isinstance(idx, list): + return [metadata_attr[_idx] for _idx in idx] + return metadata_attr[idx] + return None + + +class Subset(BaseDataset): + """A subset that also takes metadata if given.""" + + def __init__( + self, + dataset: BaseDataset, + indices: list[int], + metadata: dict[str, np.ndarray], + ) -> None: + super().__init__(dataset.config) + self.dataset = dataset + self.metadata = metadata + self.indices = indices + self.num_samples = len(indices) + self.config = dataset.config + + @cached_property + def _metadata(self) -> dict[str, np.ndarray]: + return self.dataset._metadata # pylint: disable=protected-access + + def get_metadata(self, attr, idx): + if isinstance(idx, list): + return self.dataset.get_metadata(attr, [[self.indices[i] for i in idx]]) + return self.dataset.get_metadata(attr, self.indices[idx]) + + +class LMDBDatabase(ase.db.core.Database): + """ + This module is modified from the ASE db json backend + and is thus licensed under the corresponding LGPL2.1 license. + + The ASE notice for the LGPL2.1 license is available here: + https://gitlab.com/ase/ase/-/blob/master/LICENSE + """ + + def __init__( # pylint: disable=keyword-arg-before-vararg + self, + filename: str | Path | None = None, + create_indices: bool = True, + use_lock_file: bool = False, + serial: bool = False, + readonly: bool = False, # Moved after *args to make it keyword-only + *args, + **kwargs, + ) -> None: + """ + For the most part, this is identical to the standard ase db initiation + arguments, except that we add a readonly flag. + """ + super().__init__( + Path(filename), + create_indices, + use_lock_file, + serial, + *args, + **kwargs, + ) + + # Add a readonly mode for when we're only training + # to make sure there's no parallel locks + self.readonly = readonly + + if self.readonly: + # Open a new env + self.env = lmdb.open( + str(self.filename), + subdir=False, + meminit=False, + map_async=True, + readonly=True, + lock=False, + ) + + # Open a transaction and keep it open for fast read/writes! + self.txn = self.env.begin(write=False) + + else: + # Open a new env with write access + self.env = lmdb.open( + str(self.filename), + map_size=1099511627776 * 2, + subdir=False, + meminit=False, + map_async=True, + ) + + self.txn = self.env.begin(write=True) + + # Load all ids based on keys in the DB. + self.ids = [] + self.deleted_ids = [] + self._load_ids() + + def __enter__(self) -> "LMDBDatabase": + return self + + def __exit__(self, exc_type, exc_value, tb) -> None: + self.close() + + def close(self) -> None: + # Close the lmdb environment and transaction + self.txn.commit() + self.env.close() + + def _write( + self, + atoms: ase.Atoms | ase.db.row.AtomsRow, + key_value_pairs: dict, + data: dict | None, + id: int | None = None, # pylint: disable=redefined-builtin + ) -> None: + # Call parent method with the original parameter name + super()._write(atoms, key_value_pairs, data) + + mtime = ase.db.core.now() + + if isinstance(atoms, ase.db.row.AtomsRow): + row = atoms + else: + row = ase.db.row.AtomsRow(atoms) + row.ctime = mtime + row.user = os.getenv("USER") + + dct = {} + for key in row.__dict__: + # Use getattr to avoid accessing protected member directly + if key[0] == "_" or key == "id" or key in getattr(row, "_keys", []): + continue + dct[key] = row[key] + + dct["mtime"] = mtime + + if key_value_pairs: + dct["key_value_pairs"] = key_value_pairs + + if data: + dct["data"] = data + + constraints = row.get("constraints") + if constraints: + dct["constraints"] = [constraint.todict() for constraint in constraints] + + # json doesn't like Cell objects, so make it an array + dct["cell"] = np.asarray(dct["cell"]) + + if id is None: + id = self._nextid + nextid = id + 1 + else: + data = self.txn.get(f"{id}".encode("ascii")) + assert data is not None + + # Add the new entry + self.txn.put( + f"{id}".encode("ascii"), + zlib.compress(orjson.dumps(dct, option=orjson.OPT_SERIALIZE_NUMPY)), + ) + # only append if idx is not in ids + if id not in self.ids: + self.ids.append(id) + self.txn.put( + "nextid".encode("ascii"), + zlib.compress(orjson.dumps(nextid, option=orjson.OPT_SERIALIZE_NUMPY)), + ) + # check if id is in removed ids and remove accordingly + if id in self.deleted_ids: + self.deleted_ids.remove(id) + self._write_deleted_ids() + + return id + + def _update( + self, + idx: int, + key_value_pairs: dict | None = None, + data: dict | None = None, + ): + # hack this to play nicely with ASE code + row = self._get_row(idx, include_data=True) + if data is not None or key_value_pairs is not None: + self._write( + atoms=row, key_value_pairs=key_value_pairs, data=data, id=idx + ) # Fixed E1123 by using id=idx + + def _write_deleted_ids(self): + self.txn.put( + "deleted_ids".encode("ascii"), + zlib.compress( + orjson.dumps(self.deleted_ids, option=orjson.OPT_SERIALIZE_NUMPY) + ), + ) + + def delete(self, ids: list[int]) -> None: + for idx in ids: + self.txn.delete(f"{idx}".encode("ascii")) + self.ids.remove(idx) + + self.deleted_ids += ids + self._write_deleted_ids() + + def _get_row(self, idx: int, include_data: bool = True): + if idx is None: + assert len(self.ids) == 1 + idx = self.ids[0] + data = self.txn.get(f"{idx}".encode("ascii")) + + if data is not None: + dct = orjson.loads(zlib.decompress(data)) + else: + raise KeyError(f"Id {idx} missing from the database!") + + if not include_data: + dct.pop("data", None) + + dct["id"] = idx + return ase.db.row.AtomsRow(dct) + + def _get_row_by_index(self, index: int, include_data: bool = True): + """Auxiliary function to get the ith entry, rather than a specific id""" + data = self.txn.get(f"{self.ids[index]}".encode("ascii")) + + if data is not None: + dct = orjson.loads(zlib.decompress(data)) + else: + raise KeyError(f"Id {id} missing from the database!") + + if not include_data: + dct.pop("data", None) + + dct["id"] = id + return ase.db.row.AtomsRow(dct) + + def _select( + self, + keys, + cmps: list[tuple[str, str, str]], + explain: bool = False, + _verbosity: int = 0, # Unused parameter marked with underscore + limit: int | None = None, + offset: int = 0, + sort: str | None = None, + include_data: bool = True, + _columns: str = "all", # Unused parameter marked with underscore + ): + if explain: + yield {"explain": (0, 0, 0, "scan table")} + return + + if sort is not None: + if sort[0] == "-": + reverse = True + sort = sort[1:] + else: + reverse = False + + rows = [] + missing = [] + for row in self._select(keys, cmps): + key = row.get(sort) + if key is None: + missing.append((0, row)) + else: + rows.append((key, row)) + + rows.sort(reverse=reverse, key=lambda x: x[0]) + rows += missing + + if limit: + rows = rows[offset : offset + limit] + for _, row in rows: + yield row + return + + if not limit: + limit = -offset - 1 + + cmps = [(key, ase.db.core.ops[op], val) for key, op, val in cmps] + n = 0 + for idx in self.ids: + if n - offset == limit: + return + row = self._get_row(idx, include_data=include_data) + + for key in keys: + if key not in row: + break + else: + for key, op, val in cmps: + if isinstance(key, int): + value = np.equal(row.numbers, key).sum() + else: + value = row.get(key) + if key == "pbc": + assert op in [ase.db.core.ops["="], ase.db.core.ops["!="]] + value = "".join("FT"[x] for x in value) + if value is None or not op(value, val): + break + else: + if n >= offset: + yield row + n += 1 + + @property + def metadata(self): + """Override abstract metadata method from Database class.""" + return self.db_metadata + + @property + def db_metadata(self): + """Load the metadata from the DB if present""" + if self._metadata is None: + metadata = self.txn.get("metadata".encode("ascii")) + if metadata is None: + self._metadata = {} + else: + self._metadata = orjson.loads(zlib.decompress(metadata)) + + return self._metadata.copy() + + @db_metadata.setter + def db_metadata(self, dct): + self._metadata = dct + + # Put the updated metadata dictionary + self.txn.put( + "metadata".encode("ascii"), + zlib.compress(orjson.dumps(dct, option=orjson.OPT_SERIALIZE_NUMPY)), + ) + + @property + def _nextid(self): + """Get the id of the next row to be written""" + # Get the nextid + nextid_data = self.txn.get("nextid".encode("ascii")) + if nextid_data: + return orjson.loads(zlib.decompress(nextid_data)) + return 1 # Removed unnecessary else (R1705) + + def count(self, selection=None, **kwargs) -> int: + """Count rows. + + See the select() method for the selection syntax. Use db.count() or + len(db) to count all rows. + """ + if selection is not None: + n = 0 + for _row in self.select(selection, **kwargs): + n += 1 + return n + return len(self.ids) + + def _load_ids(self) -> None: + """Load ids from the DB + + Since ASE db ids are mostly 1-N integers, but can be missing entries + if ids have been deleted. To save space and operating under the assumption + that there will probably not be many deletions in most OCP datasets, + we just store the deleted ids. + """ + # Load the deleted ids + deleted_ids_data = self.txn.get("deleted_ids".encode("ascii")) + if deleted_ids_data is not None: + self.deleted_ids = orjson.loads(zlib.decompress(deleted_ids_data)) + + # Reconstruct the full id list + self.ids = [i for i in range(1, self._nextid) if i not in set(self.deleted_ids)] + + +# Placeholder for AtomsToGraphs class +# This is a minimal implementation without the full functionality +class AtomsToGraphs: + """Enhanced AtomsToGraphs implementation with proper property handling.""" + + def __init__( + self, + r_edges=False, + r_pbc=True, + r_energy=False, + r_forces=False, + r_stress=False, + r_data_keys=None, + **kwargs, + ): + self.r_edges = r_edges + self.r_pbc = r_pbc + self.r_energy = r_energy + self.r_forces = r_forces + self.r_stress = r_stress + self.r_data_keys = r_data_keys or {} + self.kwargs = kwargs + + def convert(self, atoms, sid=None): + """ + Convert ASE atoms to graph data format with proper property handling. + """ + from mace.tools.torch_geometric.data import Data + + # Create a minimal data object with required properties + data = Data() + + # Set positions + data.pos = torch.tensor(atoms.get_positions(), dtype=torch.float) + + # Set atomic numbers + data.atomic_numbers = torch.tensor(atoms.get_atomic_numbers(), dtype=torch.long) + + # Set cell if available + if atoms.cell is not None: + data.cell = torch.tensor(atoms.get_cell(), dtype=torch.float) + + # Set PBC if requested + if self.r_pbc: + data.pbc = torch.tensor(atoms.get_pbc(), dtype=torch.bool) + + # Set energy if requested + if self.r_energy: + energy = self._get_property(atoms, "energy") + if energy is not None: + data.energy = torch.tensor(energy, dtype=torch.float) + + # Set forces if requested + if self.r_forces: + forces = self._get_property(atoms, "forces") + if forces is not None: + data.forces = torch.tensor(forces, dtype=torch.float) + + # Set stress if requested + if self.r_stress: + stress = self._get_property(atoms, "stress") + if stress is not None: + data.stress = torch.tensor(stress, dtype=torch.float) + + # Set sid if provided + if sid is not None: + data.sid = sid + + return data + + def _get_property(self, atoms, prop_name): + """Get property from atoms, checking custom names first then standard methods.""" + # Check if we have a custom name for this property + custom_name = self.r_data_keys.get(prop_name) + + # Try custom name in info dict + if custom_name and custom_name in atoms.info: + return atoms.info[custom_name] + + # Try custom name in arrays dict + if custom_name and custom_name in atoms.arrays: + return atoms.arrays[custom_name] + + # Try standard name in info dict + if prop_name in atoms.info: + return atoms.info[prop_name] + + # Try standard name in arrays dict + if prop_name in atoms.arrays: + return atoms.arrays[prop_name] + + # Try standard ASE methods + method_map = { + "energy": "get_potential_energy", + "forces": "get_forces", + "stress": "get_stress", + } + + if prop_name in method_map and hasattr(atoms, method_map[prop_name]): + try: + method = getattr(atoms, method_map[prop_name]) + return method() + except ( + AttributeError, + RuntimeError, + ) as exc: # Fixed W0718 by specifying exceptions + logging.debug(f"Error getting property {prop_name}: {exc}") + # Removed unnecessary pass (W0107) + + return None + + +# Placeholder for DataTransforms class +class DataTransforms: + """Minimal implementation of DataTransforms to satisfy dependencies.""" + + def __init__(self, transforms_config=None): + self.transforms_config = transforms_config or {} + + def __call__(self, data): + """Apply transforms to data""" + # No transforms applied in this minimal implementation + return data + + +class AseAtomsDataset(BaseDataset, ABC): + """ + This is an abstract Dataset that includes helpful utilities for turning + ASE atoms objects into OCP-usable data objects. This should not be instantiated directly + as get_atoms_object and load_dataset_get_ids are not implemented in this base class. + + Derived classes must add at least two things: + self.get_atoms_object(id): a function that takes an identifier and returns a corresponding atoms object + + self.load_dataset_get_ids(config: dict): This function is responsible for any initialization/loads + of the dataset and importantly must return a list of all possible identifiers that can be passed into + self.get_atoms_object(id) + + Identifiers need not be any particular type. + """ + + def __init__( + self, + config: dict, + atoms_transform: Callable[[ase.Atoms, Any], ase.Atoms] = apply_one_tags, + ) -> None: + super().__init__(config) + + a2g_args = config.get("a2g_args", {}) or {} + + # set default to False if not set by user, assuming otf_graph will be used + if "r_edges" not in a2g_args: + a2g_args["r_edges"] = False + + # Make sure we always include PBC info in the resulting atoms objects + a2g_args["r_pbc"] = True + self.a2g = AtomsToGraphs(**a2g_args) + + self.key_mapping = self.config.get("key_mapping", None) + self.transforms = DataTransforms(self.config.get("transforms", {})) + + self.atoms_transform = atoms_transform + + if self.config.get("keep_in_memory", False): + self.__getitem__ = cache(self.__getitem__) + + self.ids = self._load_dataset_get_ids(config) + self.num_samples = len(self.ids) + + if len(self.ids) == 0: + raise ValueError( + rf"No valid ase data found! \n" + f"Double check that the src path and/or glob search pattern gives ASE compatible data: {config['src']}" + ) + + def __getitem__(self, idx): # pylint: disable=method-hidden + # Handle slicing + if isinstance(idx, slice): + return [self[i] for i in range(*idx.indices(len(self)))] + + # Get atoms object via derived class method + atoms = self.get_atoms(self.ids[idx]) + + # Transform atoms object + if self.atoms_transform is not None: + atoms = self.atoms_transform( + atoms, **self.config.get("atoms_transform_args", {}) + ) + + sid = atoms.info.get("sid", self.ids[idx]) + fid = atoms.info.get("fid", torch.tensor([0])) + + # Convert to data object + data_object = self.a2g.convert(atoms, sid) + data_object.fid = fid + data_object.natoms = len(atoms) + + # apply linear reference + if self.a2g.r_energy is True and self.lin_ref is not None: + data_object.energy -= sum(self.lin_ref[data_object.atomic_numbers.long()]) + + # Transform data object + data_object = self.transforms(data_object) + + if self.key_mapping is not None: + data_object = rename_data_object_keys(data_object, self.key_mapping) + + if self.config.get("include_relaxed_energy", False): + data_object.energy_relaxed = self.get_relaxed_energy(self.ids[idx]) + + return data_object + + @abstractmethod + def get_atoms(self, idx: str | int) -> ase.Atoms: + # This function should return an ASE atoms object. + raise NotImplementedError( + "Returns an ASE atoms object. Derived classes should implement this function." + ) + + @abstractmethod + def _load_dataset_get_ids(self, config): + # This function should return a list of ids that can be used to index into the database + raise NotImplementedError( + "Every ASE dataset needs to declare a function to load the dataset and return a list of ids." + ) + + def get_relaxed_energy(self, identifier): + raise NotImplementedError( + "Reading relaxed energy from trajectory or file is not implemented with this dataset. " + "If relaxed energies are saved with the atoms info dictionary, they can be used by passing the keys in " + "the r_data_keys argument under a2g_args." + ) + + def get_metadata(self, attr, idx): + # try the parent method + metadata = super().get_metadata(attr, idx) + if metadata is not None: + return metadata + # try to resolve it here + if attr != "natoms": + return None + if isinstance(idx, (list, np.ndarray)): + return np.array([self.get_metadata(attr, i) for i in idx]) + return len(self.get_atoms(idx)) + + +class AseDBDataset(AseAtomsDataset): + """ + This Dataset connects to an ASE Database, allowing the storage of atoms objects + with a variety of backends including JSON, SQLite, and database server options. + """ + + def _load_dataset_get_ids(self, config: dict) -> list[int]: + if isinstance(config["src"], list): + filepaths = [] + for path in sorted(config["src"]): + if os.path.isdir(path): + filepaths.extend(sorted(glob(f"{path}/*"))) + elif os.path.isfile(path): + filepaths.append(path) + else: + raise RuntimeError(f"Error reading dataset in {path}!") + elif os.path.isfile(config["src"]): + filepaths = [config["src"]] + elif os.path.isdir(config["src"]): + filepaths = sorted(glob(f'{config["src"]}/*')) + else: + filepaths = sorted(glob(config["src"])) + + self.dbs = [] + + for path in filepaths: + try: + self.dbs.append(self.connect_db(path, config.get("connect_args", {}))) + except ValueError: + logging.debug( + f"Tried to connect to {path} but it's not an ASE database!" + ) + + self.select_args = config.get("select_args", {}) + if self.select_args is None: + self.select_args = {} + + # Get all unique IDs from the databases + self.db_ids = [] + for db in self.dbs: + if hasattr(db, "ids") and self.select_args == {}: + self.db_ids.append(db.ids) + else: + # this is the slow alternative + self.db_ids.append([row.id for row in db.select(**self.select_args)]) + + idlens = [len(ids) for ids in self.db_ids] + self._idlen_cumulative = np.cumsum(idlens).tolist() + + return list(range(sum(idlens))) + + def get_atoms(self, idx: int) -> ase.Atoms: + """Get atoms object corresponding to datapoint idx. + Args: + idx (int): index in dataset + + Returns: + atoms: ASE atoms corresponding to datapoint idx + """ + # Figure out which db this should be indexed from + db_idx = bisect.bisect(self._idlen_cumulative, idx) + + # Extract index of element within that db + el_idx = idx + if db_idx != 0: + el_idx = idx - self._idlen_cumulative[db_idx - 1] + assert el_idx >= 0 + + # Use a wrapper method to avoid protected access warning + atoms_row = self.get_row_from_db(db_idx, el_idx) + + # Convert to atoms object + atoms = atoms_row.toatoms() + + # Put data back into atoms info + if isinstance(atoms_row.data, dict): + atoms.info.update(atoms_row.data) + + # Add key-value pairs directly to atoms.info + if hasattr(atoms_row, "key_value_pairs") and atoms_row.key_value_pairs: + atoms.info.update(atoms_row.key_value_pairs) + + # Create a SinglePointCalculator to attach energy, forces and stress to atoms + calc_kwargs = {} + + # Check for energy, forces, stress in atoms_row and store in info & calc_kwargs + for prop in ["energy", "forces", "stress", "free_energy"]: + if hasattr(atoms_row, prop) and getattr(atoms_row, prop) is not None: + value = getattr(atoms_row, prop) + calc_kwargs[prop] = value + atoms.info[prop] = value + + # If we have custom data mappings, copy the standard properties to the custom names + a2g_args = self.config.get("a2g_args", {}) or {} + r_data_keys = a2g_args.get("r_data_keys", {}) + if r_data_keys: + # Map from standard names to custom names (in reverse of how they'll be used) + for custom_key, standard_key in r_data_keys.items(): + if standard_key in atoms.info: + atoms.info[custom_key] = atoms.info[standard_key] + elif standard_key in atoms.arrays: + atoms.arrays[custom_key] = atoms.arrays[standard_key] + + # Create calculator if we have any properties + if calc_kwargs: + from ase.calculators.singlepoint import SinglePointCalculator + + calc = SinglePointCalculator(atoms, **calc_kwargs) + atoms.calc = calc + + return atoms + + def get_row_from_db(self, db_idx, el_idx): + """Get a row from the database at the given indices.""" + db = self.dbs[db_idx] + row_id = self.db_ids[db_idx][el_idx] + if isinstance(db, LMDBDatabase): + return db._get_row(row_id) # pylint: disable=protected-access + return db.get(row_id) + + @staticmethod + def connect_db( + address: str | Path, connect_args: dict | None = None + ) -> ase.db.core.Database: + if connect_args is None: + connect_args = {} + db_type = connect_args.get("type", "extract_from_name") + if db_type in ("lmdb", "aselmdb") or ( + db_type == "extract_from_name" + and str(address).rsplit(".", maxsplit=1)[-1] in ("lmdb", "aselmdb") + ): + return LMDBDatabase(address, readonly=True, **connect_args) + + return ase.db.connect(address, **connect_args) + + def __del__(self): + for db in self.dbs: + if hasattr(db, "close"): + db.close() + + def sample_property_metadata( + self, + ) -> dict: # Removed unused argument num_samples (W0613) + """ + Sample property metadata from the database. + + This method was previously using the copy module which is now removed. + """ + logging.warning( + "You specified a folder of ASE dbs, so it's impossible to know which metadata to use. Using the first!" + ) + if self.dbs[0].metadata == {}: + return {} + + # Fixed unnecessary comprehension (R1721) + return dict(self.dbs[0].metadata.items()) diff --git a/mace-bench/3rdparty/mace/mace/tools/finetuning_utils.py b/mace-bench/3rdparty/mace/mace/tools/finetuning_utils.py index f76aa90..8df0b0d 100644 --- a/mace-bench/3rdparty/mace/mace/tools/finetuning_utils.py +++ b/mace-bench/3rdparty/mace/mace/tools/finetuning_utils.py @@ -1,204 +1,204 @@ -import torch - -from mace.tools.utils import AtomicNumberTable - - -def load_foundations_elements( - model: torch.nn.Module, - model_foundations: torch.nn.Module, - table: AtomicNumberTable, - load_readout=False, - use_shift=True, - use_scale=True, - max_L=2, -): - """ - Load the foundations of a model into a model for fine-tuning. - """ - assert model_foundations.r_max == model.r_max - z_table = AtomicNumberTable([int(z) for z in model_foundations.atomic_numbers]) - model_heads = model.heads - new_z_table = table - num_species_foundations = len(z_table.zs) - num_channels_foundation = ( - model_foundations.node_embedding.linear.weight.shape[0] - // num_species_foundations - ) - indices_weights = [z_table.z_to_index(z) for z in new_z_table.zs] - num_radial = model.radial_embedding.out_dim - num_species = len(indices_weights) - max_ell = model.spherical_harmonics._lmax # pylint: disable=protected-access - model.node_embedding.linear.weight = torch.nn.Parameter( - model_foundations.node_embedding.linear.weight.view( - num_species_foundations, -1 - )[indices_weights, :] - .flatten() - .clone() - / (num_species_foundations / num_species) ** 0.5 - ) - if model.radial_embedding.bessel_fn.__class__.__name__ == "BesselBasis": - model.radial_embedding.bessel_fn.bessel_weights = torch.nn.Parameter( - model_foundations.radial_embedding.bessel_fn.bessel_weights.clone() - ) - for i in range(int(model.num_interactions)): - model.interactions[i].linear_up.weight = torch.nn.Parameter( - model_foundations.interactions[i].linear_up.weight.clone() - ) - model.interactions[i].avg_num_neighbors = model_foundations.interactions[ - i - ].avg_num_neighbors - for j in range(4): # Assuming 4 layers in conv_tp_weights, - layer_name = f"layer{j}" - if j == 0: - getattr(model.interactions[i].conv_tp_weights, layer_name).weight = ( - torch.nn.Parameter( - getattr( - model_foundations.interactions[i].conv_tp_weights, - layer_name, - ) - .weight[:num_radial, :] - .clone() - ) - ) - else: - getattr(model.interactions[i].conv_tp_weights, layer_name).weight = ( - torch.nn.Parameter( - getattr( - model_foundations.interactions[i].conv_tp_weights, - layer_name, - ).weight.clone() - ) - ) - - model.interactions[i].linear.weight = torch.nn.Parameter( - model_foundations.interactions[i].linear.weight.clone() - ) - if model.interactions[i].__class__.__name__ in [ - "RealAgnosticResidualInteractionBlock", - "RealAgnosticDensityResidualInteractionBlock", - ]: - model.interactions[i].skip_tp.weight = torch.nn.Parameter( - model_foundations.interactions[i] - .skip_tp.weight.reshape( - num_channels_foundation, - num_species_foundations, - num_channels_foundation, - )[:, indices_weights, :] - .flatten() - .clone() - / (num_species_foundations / num_species) ** 0.5 - ) - else: - model.interactions[i].skip_tp.weight = torch.nn.Parameter( - model_foundations.interactions[i] - .skip_tp.weight.reshape( - num_channels_foundation, - (max_ell + 1), - num_species_foundations, - num_channels_foundation, - )[:, :, indices_weights, :] - .flatten() - .clone() - / (num_species_foundations / num_species) ** 0.5 - ) - if model.interactions[i].__class__.__name__ in [ - "RealAgnosticDensityInteractionBlock", - "RealAgnosticDensityResidualInteractionBlock", - ]: - # Assuming only 1 layer in density_fn - getattr(model.interactions[i].density_fn, "layer0").weight = ( - torch.nn.Parameter( - getattr( - model_foundations.interactions[i].density_fn, - "layer0", - ).weight.clone() - ) - ) - # Transferring products - for i in range(2): # Assuming 2 products modules - max_range = max_L + 1 if i == 0 else 1 - for j in range(max_range): # Assuming 3 contractions in symmetric_contractions - model.products[i].symmetric_contractions.contractions[j].weights_max = ( - torch.nn.Parameter( - model_foundations.products[i] - .symmetric_contractions.contractions[j] - .weights_max[indices_weights, :, :] - .clone() - ) - ) - - for k in range(2): # Assuming 2 weights in each contraction - model.products[i].symmetric_contractions.contractions[j].weights[k] = ( - torch.nn.Parameter( - model_foundations.products[i] - .symmetric_contractions.contractions[j] - .weights[k][indices_weights, :, :] - .clone() - ) - ) - - model.products[i].linear.weight = torch.nn.Parameter( - model_foundations.products[i].linear.weight.clone() - ) - - if load_readout: - # Transferring readouts - model_readouts_zero_linear_weight = model.readouts[0].linear.weight.clone() - model_readouts_zero_linear_weight = ( - model_foundations.readouts[0] - .linear.weight.view(num_channels_foundation, -1) - .repeat(1, len(model_heads)) - .flatten() - .clone() - ) - model.readouts[0].linear.weight = torch.nn.Parameter( - model_readouts_zero_linear_weight - ) - - shape_input_1 = ( - model_foundations.readouts[1].linear_1.__dict__["irreps_out"].num_irreps - ) - shape_output_1 = model.readouts[1].linear_1.__dict__["irreps_out"].num_irreps - model_readouts_one_linear_1_weight = model.readouts[1].linear_1.weight.clone() - model_readouts_one_linear_1_weight = ( - model_foundations.readouts[1] - .linear_1.weight.view(num_channels_foundation, -1) - .repeat(1, len(model_heads)) - .flatten() - .clone() - ) - model.readouts[1].linear_1.weight = torch.nn.Parameter( - model_readouts_one_linear_1_weight - ) - model_readouts_one_linear_2_weight = model.readouts[1].linear_2.weight.clone() - model_readouts_one_linear_2_weight = model_foundations.readouts[ - 1 - ].linear_2.weight.view(shape_input_1, -1).repeat( - len(model_heads), len(model_heads) - ).flatten().clone() / ( - ((shape_input_1) / (shape_output_1)) ** 0.5 - ) - model.readouts[1].linear_2.weight = torch.nn.Parameter( - model_readouts_one_linear_2_weight - ) - if model_foundations.scale_shift is not None: - if use_scale: - model.scale_shift.scale = model_foundations.scale_shift.scale.repeat( - len(model_heads) - ).clone() - if use_shift: - model.scale_shift.shift = model_foundations.scale_shift.shift.repeat( - len(model_heads) - ).clone() - return model - - -def load_foundations( - model, - model_foundations, -): - for name, param in model_foundations.named_parameters(): - if name in model.state_dict().keys(): - if "readouts" not in name: - model.state_dict()[name].copy_(param) - return model +import torch + +from mace.tools.utils import AtomicNumberTable + + +def load_foundations_elements( + model: torch.nn.Module, + model_foundations: torch.nn.Module, + table: AtomicNumberTable, + load_readout=False, + use_shift=True, + use_scale=True, + max_L=2, +): + """ + Load the foundations of a model into a model for fine-tuning. + """ + assert model_foundations.r_max == model.r_max + z_table = AtomicNumberTable([int(z) for z in model_foundations.atomic_numbers]) + model_heads = model.heads + new_z_table = table + num_species_foundations = len(z_table.zs) + num_channels_foundation = ( + model_foundations.node_embedding.linear.weight.shape[0] + // num_species_foundations + ) + indices_weights = [z_table.z_to_index(z) for z in new_z_table.zs] + num_radial = model.radial_embedding.out_dim + num_species = len(indices_weights) + max_ell = model.spherical_harmonics._lmax # pylint: disable=protected-access + model.node_embedding.linear.weight = torch.nn.Parameter( + model_foundations.node_embedding.linear.weight.view( + num_species_foundations, -1 + )[indices_weights, :] + .flatten() + .clone() + / (num_species_foundations / num_species) ** 0.5 + ) + if model.radial_embedding.bessel_fn.__class__.__name__ == "BesselBasis": + model.radial_embedding.bessel_fn.bessel_weights = torch.nn.Parameter( + model_foundations.radial_embedding.bessel_fn.bessel_weights.clone() + ) + for i in range(int(model.num_interactions)): + model.interactions[i].linear_up.weight = torch.nn.Parameter( + model_foundations.interactions[i].linear_up.weight.clone() + ) + model.interactions[i].avg_num_neighbors = model_foundations.interactions[ + i + ].avg_num_neighbors + for j in range(4): # Assuming 4 layers in conv_tp_weights, + layer_name = f"layer{j}" + if j == 0: + getattr(model.interactions[i].conv_tp_weights, layer_name).weight = ( + torch.nn.Parameter( + getattr( + model_foundations.interactions[i].conv_tp_weights, + layer_name, + ) + .weight[:num_radial, :] + .clone() + ) + ) + else: + getattr(model.interactions[i].conv_tp_weights, layer_name).weight = ( + torch.nn.Parameter( + getattr( + model_foundations.interactions[i].conv_tp_weights, + layer_name, + ).weight.clone() + ) + ) + + model.interactions[i].linear.weight = torch.nn.Parameter( + model_foundations.interactions[i].linear.weight.clone() + ) + if model.interactions[i].__class__.__name__ in [ + "RealAgnosticResidualInteractionBlock", + "RealAgnosticDensityResidualInteractionBlock", + ]: + model.interactions[i].skip_tp.weight = torch.nn.Parameter( + model_foundations.interactions[i] + .skip_tp.weight.reshape( + num_channels_foundation, + num_species_foundations, + num_channels_foundation, + )[:, indices_weights, :] + .flatten() + .clone() + / (num_species_foundations / num_species) ** 0.5 + ) + else: + model.interactions[i].skip_tp.weight = torch.nn.Parameter( + model_foundations.interactions[i] + .skip_tp.weight.reshape( + num_channels_foundation, + (max_ell + 1), + num_species_foundations, + num_channels_foundation, + )[:, :, indices_weights, :] + .flatten() + .clone() + / (num_species_foundations / num_species) ** 0.5 + ) + if model.interactions[i].__class__.__name__ in [ + "RealAgnosticDensityInteractionBlock", + "RealAgnosticDensityResidualInteractionBlock", + ]: + # Assuming only 1 layer in density_fn + getattr(model.interactions[i].density_fn, "layer0").weight = ( + torch.nn.Parameter( + getattr( + model_foundations.interactions[i].density_fn, + "layer0", + ).weight.clone() + ) + ) + # Transferring products + for i in range(2): # Assuming 2 products modules + max_range = max_L + 1 if i == 0 else 1 + for j in range(max_range): # Assuming 3 contractions in symmetric_contractions + model.products[i].symmetric_contractions.contractions[j].weights_max = ( + torch.nn.Parameter( + model_foundations.products[i] + .symmetric_contractions.contractions[j] + .weights_max[indices_weights, :, :] + .clone() + ) + ) + + for k in range(2): # Assuming 2 weights in each contraction + model.products[i].symmetric_contractions.contractions[j].weights[k] = ( + torch.nn.Parameter( + model_foundations.products[i] + .symmetric_contractions.contractions[j] + .weights[k][indices_weights, :, :] + .clone() + ) + ) + + model.products[i].linear.weight = torch.nn.Parameter( + model_foundations.products[i].linear.weight.clone() + ) + + if load_readout: + # Transferring readouts + model_readouts_zero_linear_weight = model.readouts[0].linear.weight.clone() + model_readouts_zero_linear_weight = ( + model_foundations.readouts[0] + .linear.weight.view(num_channels_foundation, -1) + .repeat(1, len(model_heads)) + .flatten() + .clone() + ) + model.readouts[0].linear.weight = torch.nn.Parameter( + model_readouts_zero_linear_weight + ) + + shape_input_1 = ( + model_foundations.readouts[1].linear_1.__dict__["irreps_out"].num_irreps + ) + shape_output_1 = model.readouts[1].linear_1.__dict__["irreps_out"].num_irreps + model_readouts_one_linear_1_weight = model.readouts[1].linear_1.weight.clone() + model_readouts_one_linear_1_weight = ( + model_foundations.readouts[1] + .linear_1.weight.view(num_channels_foundation, -1) + .repeat(1, len(model_heads)) + .flatten() + .clone() + ) + model.readouts[1].linear_1.weight = torch.nn.Parameter( + model_readouts_one_linear_1_weight + ) + model_readouts_one_linear_2_weight = model.readouts[1].linear_2.weight.clone() + model_readouts_one_linear_2_weight = model_foundations.readouts[ + 1 + ].linear_2.weight.view(shape_input_1, -1).repeat( + len(model_heads), len(model_heads) + ).flatten().clone() / ( + ((shape_input_1) / (shape_output_1)) ** 0.5 + ) + model.readouts[1].linear_2.weight = torch.nn.Parameter( + model_readouts_one_linear_2_weight + ) + if model_foundations.scale_shift is not None: + if use_scale: + model.scale_shift.scale = model_foundations.scale_shift.scale.repeat( + len(model_heads) + ).clone() + if use_shift: + model.scale_shift.shift = model_foundations.scale_shift.shift.repeat( + len(model_heads) + ).clone() + return model + + +def load_foundations( + model, + model_foundations, +): + for name, param in model_foundations.named_parameters(): + if name in model.state_dict().keys(): + if "readouts" not in name: + model.state_dict()[name].copy_(param) + return model diff --git a/mace-bench/3rdparty/mace/mace/tools/model_script_utils.py b/mace-bench/3rdparty/mace/mace/tools/model_script_utils.py index e577524..c9de08b 100644 --- a/mace-bench/3rdparty/mace/mace/tools/model_script_utils.py +++ b/mace-bench/3rdparty/mace/mace/tools/model_script_utils.py @@ -1,265 +1,265 @@ -import ast -import logging - -import numpy as np -from e3nn import o3 - -from mace import modules -from mace.tools.finetuning_utils import load_foundations_elements -from mace.tools.scripts_utils import extract_config_mace_model -from mace.tools.utils import AtomicNumberTable - - -def configure_model( - args, - train_loader, - atomic_energies, - model_foundation=None, - heads=None, - z_table=None, - head_configs=None, -): - # Selecting outputs - compute_virials = args.loss == "virials" - compute_stress = args.loss in ("stress", "huber", "universal") - - if compute_virials: - args.compute_virials = True - args.error_table = "PerAtomRMSEstressvirials" - elif compute_stress: - args.compute_stress = True - args.error_table = "PerAtomRMSEstressvirials" - - output_args = { - "energy": args.compute_energy, - "forces": args.compute_forces, - "virials": compute_virials, - "stress": compute_stress, - "dipoles": args.compute_dipole, - } - logging.info( - f"During training the following quantities will be reported: {', '.join([f'{report}' for report, value in output_args.items() if value])}" - ) - logging.info("===========MODEL DETAILS===========") - - if args.scaling == "no_scaling": - args.std = 1.0 - if head_configs is not None: - for head_config in head_configs: - head_config.std = 1.0 - logging.info("No scaling selected") - - if ( - head_configs is not None - and args.std is not None - and not isinstance(args.std, list) - ): - atomic_inter_scale = [] - for head_config in head_configs: - if hasattr(head_config, "std") and head_config.std is not None: - atomic_inter_scale.append(head_config.std) - elif args.std is not None: - atomic_inter_scale.append( - args.std if isinstance(args.std, float) else 1.0 - ) - args.std = atomic_inter_scale - - elif (args.mean is None or args.std is None) and args.model != "AtomicDipolesMACE": - args.mean, args.std = modules.scaling_classes[args.scaling]( - train_loader, atomic_energies - ) - - # Build model - if model_foundation is not None and args.model in ["MACE", "ScaleShiftMACE"]: - logging.info("Loading FOUNDATION model") - model_config_foundation = extract_config_mace_model(model_foundation) - model_config_foundation["atomic_energies"] = atomic_energies - - if args.foundation_model_elements: - foundation_z_table = AtomicNumberTable( - [int(z) for z in model_foundation.atomic_numbers] - ) - model_config_foundation["atomic_numbers"] = foundation_z_table.zs - model_config_foundation["num_elements"] = len(foundation_z_table) - z_table = foundation_z_table - logging.info( - f"Using all elements from foundation model: {foundation_z_table.zs}" - ) - else: - model_config_foundation["atomic_numbers"] = z_table.zs - model_config_foundation["num_elements"] = len(z_table) - logging.info(f"Using filtered elements: {z_table.zs}") - - args.max_L = model_config_foundation["hidden_irreps"].lmax - - if args.model == "MACE" and model_foundation.__class__.__name__ == "MACE": - model_config_foundation["atomic_inter_shift"] = [0.0] * len(heads) - else: - model_config_foundation["atomic_inter_shift"] = ( - _determine_atomic_inter_shift(args.mean, heads) - ) - model_config_foundation["atomic_inter_scale"] = [1.0] * len(heads) - args.avg_num_neighbors = model_config_foundation["avg_num_neighbors"] - args.model = "FoundationMACE" - model_config_foundation["heads"] = heads - model_config = model_config_foundation - - logging.info("Model configuration extracted from foundation model") - logging.info("Using universal loss function for fine-tuning") - logging.info( - f"Message passing with hidden irreps {model_config_foundation['hidden_irreps']})" - ) - logging.info( - f"{model_config_foundation['num_interactions']} layers, each with correlation order: {model_config_foundation['correlation']} (body order: {model_config_foundation['correlation']+1}) and spherical harmonics up to: l={model_config_foundation['max_ell']}" - ) - logging.info( - f"Radial cutoff: {model_config_foundation['r_max']} A (total receptive field for each atom: {model_config_foundation['r_max'] * model_config_foundation['num_interactions']} A)" - ) - logging.info( - f"Distance transform for radial basis functions: {model_config_foundation['distance_transform']}" - ) - else: - logging.info("Building model") - logging.info( - f"Message passing with {args.num_channels} channels and max_L={args.max_L} ({args.hidden_irreps})" - ) - logging.info( - f"{args.num_interactions} layers, each with correlation order: {args.correlation} (body order: {args.correlation+1}) and spherical harmonics up to: l={args.max_ell}" - ) - logging.info( - f"{args.num_radial_basis} radial and {args.num_cutoff_basis} basis functions" - ) - logging.info( - f"Radial cutoff: {args.r_max} A (total receptive field for each atom: {args.r_max * args.num_interactions} A)" - ) - logging.info( - f"Distance transform for radial basis functions: {args.distance_transform}" - ) - - assert ( - len({irrep.mul for irrep in o3.Irreps(args.hidden_irreps)}) == 1 - ), "All channels must have the same dimension, use the num_channels and max_L keywords to specify the number of channels and the maximum L" - - logging.info(f"Hidden irreps: {args.hidden_irreps}") - - model_config = dict( - r_max=args.r_max, - num_bessel=args.num_radial_basis, - num_polynomial_cutoff=args.num_cutoff_basis, - max_ell=args.max_ell, - interaction_cls=modules.interaction_classes[args.interaction], - num_interactions=args.num_interactions, - num_elements=len(z_table), - hidden_irreps=o3.Irreps(args.hidden_irreps), - atomic_energies=atomic_energies, - avg_num_neighbors=args.avg_num_neighbors, - atomic_numbers=z_table.zs, - ) - model_config_foundation = None - - model = _build_model(args, model_config, model_config_foundation, heads) - - if model_foundation is not None: - model = load_foundations_elements( - model, - model_foundation, - z_table, - load_readout=args.foundation_filter_elements, - max_L=args.max_L, - ) - - return model, output_args - - -def _determine_atomic_inter_shift(mean, heads): - if isinstance(mean, np.ndarray): - if mean.size == 1: - return mean.item() - if mean.size == len(heads): - return mean.tolist() - logging.info("Mean not in correct format, using default value of 0.0") - return [0.0] * len(heads) - if isinstance(mean, list) and len(mean) == len(heads): - return mean - if isinstance(mean, float): - return [mean] * len(heads) - logging.info("Mean not in correct format, using default value of 0.0") - return [0.0] * len(heads) - - -def _build_model( - args, model_config, model_config_foundation, heads -): # pylint: disable=too-many-return-statements - if args.model == "MACE": - if args.interaction_first not in [ - "RealAgnosticInteractionBlock", - "RealAgnosticDensityInteractionBlock", - ]: - args.interaction_first = "RealAgnosticInteractionBlock" - return modules.ScaleShiftMACE( - **model_config, - pair_repulsion=args.pair_repulsion, - distance_transform=args.distance_transform, - correlation=args.correlation, - gate=modules.gate_dict[args.gate], - interaction_cls_first=modules.interaction_classes[args.interaction_first], - MLP_irreps=o3.Irreps(args.MLP_irreps), - atomic_inter_scale=args.std, - atomic_inter_shift=[0.0] * len(heads), - radial_MLP=ast.literal_eval(args.radial_MLP), - radial_type=args.radial_type, - heads=heads, - ) - if args.model == "ScaleShiftMACE": - return modules.ScaleShiftMACE( - **model_config, - pair_repulsion=args.pair_repulsion, - distance_transform=args.distance_transform, - correlation=args.correlation, - gate=modules.gate_dict[args.gate], - interaction_cls_first=modules.interaction_classes[args.interaction_first], - MLP_irreps=o3.Irreps(args.MLP_irreps), - atomic_inter_scale=args.std, - atomic_inter_shift=args.mean, - radial_MLP=ast.literal_eval(args.radial_MLP), - radial_type=args.radial_type, - heads=heads, - ) - if args.model == "FoundationMACE": - return modules.ScaleShiftMACE(**model_config_foundation) - if args.model == "ScaleShiftBOTNet": - # say it is deprecated - raise RuntimeError("ScaleShiftBOTNet is deprecated, use MACE instead") - if args.model == "BOTNet": - raise RuntimeError("BOTNet is deprecated, use MACE instead") - if args.model == "AtomicDipolesMACE": - assert args.loss == "dipole", "Use dipole loss with AtomicDipolesMACE model" - assert ( - args.error_table == "DipoleRMSE" - ), "Use error_table DipoleRMSE with AtomicDipolesMACE model" - return modules.AtomicDipolesMACE( - **model_config, - correlation=args.correlation, - gate=modules.gate_dict[args.gate], - interaction_cls_first=modules.interaction_classes[ - "RealAgnosticInteractionBlock" - ], - MLP_irreps=o3.Irreps(args.MLP_irreps), - ) - if args.model == "EnergyDipolesMACE": - assert ( - args.loss == "energy_forces_dipole" - ), "Use energy_forces_dipole loss with EnergyDipolesMACE model" - assert ( - args.error_table == "EnergyDipoleRMSE" - ), "Use error_table EnergyDipoleRMSE with AtomicDipolesMACE model" - return modules.EnergyDipolesMACE( - **model_config, - correlation=args.correlation, - gate=modules.gate_dict[args.gate], - interaction_cls_first=modules.interaction_classes[ - "RealAgnosticInteractionBlock" - ], - MLP_irreps=o3.Irreps(args.MLP_irreps), - ) - raise RuntimeError(f"Unknown model: '{args.model}'") +import ast +import logging + +import numpy as np +from e3nn import o3 + +from mace import modules +from mace.tools.finetuning_utils import load_foundations_elements +from mace.tools.scripts_utils import extract_config_mace_model +from mace.tools.utils import AtomicNumberTable + + +def configure_model( + args, + train_loader, + atomic_energies, + model_foundation=None, + heads=None, + z_table=None, + head_configs=None, +): + # Selecting outputs + compute_virials = args.loss == "virials" + compute_stress = args.loss in ("stress", "huber", "universal") + + if compute_virials: + args.compute_virials = True + args.error_table = "PerAtomRMSEstressvirials" + elif compute_stress: + args.compute_stress = True + args.error_table = "PerAtomRMSEstressvirials" + + output_args = { + "energy": args.compute_energy, + "forces": args.compute_forces, + "virials": compute_virials, + "stress": compute_stress, + "dipoles": args.compute_dipole, + } + logging.info( + f"During training the following quantities will be reported: {', '.join([f'{report}' for report, value in output_args.items() if value])}" + ) + logging.info("===========MODEL DETAILS===========") + + if args.scaling == "no_scaling": + args.std = 1.0 + if head_configs is not None: + for head_config in head_configs: + head_config.std = 1.0 + logging.info("No scaling selected") + + if ( + head_configs is not None + and args.std is not None + and not isinstance(args.std, list) + ): + atomic_inter_scale = [] + for head_config in head_configs: + if hasattr(head_config, "std") and head_config.std is not None: + atomic_inter_scale.append(head_config.std) + elif args.std is not None: + atomic_inter_scale.append( + args.std if isinstance(args.std, float) else 1.0 + ) + args.std = atomic_inter_scale + + elif (args.mean is None or args.std is None) and args.model != "AtomicDipolesMACE": + args.mean, args.std = modules.scaling_classes[args.scaling]( + train_loader, atomic_energies + ) + + # Build model + if model_foundation is not None and args.model in ["MACE", "ScaleShiftMACE"]: + logging.info("Loading FOUNDATION model") + model_config_foundation = extract_config_mace_model(model_foundation) + model_config_foundation["atomic_energies"] = atomic_energies + + if args.foundation_model_elements: + foundation_z_table = AtomicNumberTable( + [int(z) for z in model_foundation.atomic_numbers] + ) + model_config_foundation["atomic_numbers"] = foundation_z_table.zs + model_config_foundation["num_elements"] = len(foundation_z_table) + z_table = foundation_z_table + logging.info( + f"Using all elements from foundation model: {foundation_z_table.zs}" + ) + else: + model_config_foundation["atomic_numbers"] = z_table.zs + model_config_foundation["num_elements"] = len(z_table) + logging.info(f"Using filtered elements: {z_table.zs}") + + args.max_L = model_config_foundation["hidden_irreps"].lmax + + if args.model == "MACE" and model_foundation.__class__.__name__ == "MACE": + model_config_foundation["atomic_inter_shift"] = [0.0] * len(heads) + else: + model_config_foundation["atomic_inter_shift"] = ( + _determine_atomic_inter_shift(args.mean, heads) + ) + model_config_foundation["atomic_inter_scale"] = [1.0] * len(heads) + args.avg_num_neighbors = model_config_foundation["avg_num_neighbors"] + args.model = "FoundationMACE" + model_config_foundation["heads"] = heads + model_config = model_config_foundation + + logging.info("Model configuration extracted from foundation model") + logging.info("Using universal loss function for fine-tuning") + logging.info( + f"Message passing with hidden irreps {model_config_foundation['hidden_irreps']})" + ) + logging.info( + f"{model_config_foundation['num_interactions']} layers, each with correlation order: {model_config_foundation['correlation']} (body order: {model_config_foundation['correlation']+1}) and spherical harmonics up to: l={model_config_foundation['max_ell']}" + ) + logging.info( + f"Radial cutoff: {model_config_foundation['r_max']} A (total receptive field for each atom: {model_config_foundation['r_max'] * model_config_foundation['num_interactions']} A)" + ) + logging.info( + f"Distance transform for radial basis functions: {model_config_foundation['distance_transform']}" + ) + else: + logging.info("Building model") + logging.info( + f"Message passing with {args.num_channels} channels and max_L={args.max_L} ({args.hidden_irreps})" + ) + logging.info( + f"{args.num_interactions} layers, each with correlation order: {args.correlation} (body order: {args.correlation+1}) and spherical harmonics up to: l={args.max_ell}" + ) + logging.info( + f"{args.num_radial_basis} radial and {args.num_cutoff_basis} basis functions" + ) + logging.info( + f"Radial cutoff: {args.r_max} A (total receptive field for each atom: {args.r_max * args.num_interactions} A)" + ) + logging.info( + f"Distance transform for radial basis functions: {args.distance_transform}" + ) + + assert ( + len({irrep.mul for irrep in o3.Irreps(args.hidden_irreps)}) == 1 + ), "All channels must have the same dimension, use the num_channels and max_L keywords to specify the number of channels and the maximum L" + + logging.info(f"Hidden irreps: {args.hidden_irreps}") + + model_config = dict( + r_max=args.r_max, + num_bessel=args.num_radial_basis, + num_polynomial_cutoff=args.num_cutoff_basis, + max_ell=args.max_ell, + interaction_cls=modules.interaction_classes[args.interaction], + num_interactions=args.num_interactions, + num_elements=len(z_table), + hidden_irreps=o3.Irreps(args.hidden_irreps), + atomic_energies=atomic_energies, + avg_num_neighbors=args.avg_num_neighbors, + atomic_numbers=z_table.zs, + ) + model_config_foundation = None + + model = _build_model(args, model_config, model_config_foundation, heads) + + if model_foundation is not None: + model = load_foundations_elements( + model, + model_foundation, + z_table, + load_readout=args.foundation_filter_elements, + max_L=args.max_L, + ) + + return model, output_args + + +def _determine_atomic_inter_shift(mean, heads): + if isinstance(mean, np.ndarray): + if mean.size == 1: + return mean.item() + if mean.size == len(heads): + return mean.tolist() + logging.info("Mean not in correct format, using default value of 0.0") + return [0.0] * len(heads) + if isinstance(mean, list) and len(mean) == len(heads): + return mean + if isinstance(mean, float): + return [mean] * len(heads) + logging.info("Mean not in correct format, using default value of 0.0") + return [0.0] * len(heads) + + +def _build_model( + args, model_config, model_config_foundation, heads +): # pylint: disable=too-many-return-statements + if args.model == "MACE": + if args.interaction_first not in [ + "RealAgnosticInteractionBlock", + "RealAgnosticDensityInteractionBlock", + ]: + args.interaction_first = "RealAgnosticInteractionBlock" + return modules.ScaleShiftMACE( + **model_config, + pair_repulsion=args.pair_repulsion, + distance_transform=args.distance_transform, + correlation=args.correlation, + gate=modules.gate_dict[args.gate], + interaction_cls_first=modules.interaction_classes[args.interaction_first], + MLP_irreps=o3.Irreps(args.MLP_irreps), + atomic_inter_scale=args.std, + atomic_inter_shift=[0.0] * len(heads), + radial_MLP=ast.literal_eval(args.radial_MLP), + radial_type=args.radial_type, + heads=heads, + ) + if args.model == "ScaleShiftMACE": + return modules.ScaleShiftMACE( + **model_config, + pair_repulsion=args.pair_repulsion, + distance_transform=args.distance_transform, + correlation=args.correlation, + gate=modules.gate_dict[args.gate], + interaction_cls_first=modules.interaction_classes[args.interaction_first], + MLP_irreps=o3.Irreps(args.MLP_irreps), + atomic_inter_scale=args.std, + atomic_inter_shift=args.mean, + radial_MLP=ast.literal_eval(args.radial_MLP), + radial_type=args.radial_type, + heads=heads, + ) + if args.model == "FoundationMACE": + return modules.ScaleShiftMACE(**model_config_foundation) + if args.model == "ScaleShiftBOTNet": + # say it is deprecated + raise RuntimeError("ScaleShiftBOTNet is deprecated, use MACE instead") + if args.model == "BOTNet": + raise RuntimeError("BOTNet is deprecated, use MACE instead") + if args.model == "AtomicDipolesMACE": + assert args.loss == "dipole", "Use dipole loss with AtomicDipolesMACE model" + assert ( + args.error_table == "DipoleRMSE" + ), "Use error_table DipoleRMSE with AtomicDipolesMACE model" + return modules.AtomicDipolesMACE( + **model_config, + correlation=args.correlation, + gate=modules.gate_dict[args.gate], + interaction_cls_first=modules.interaction_classes[ + "RealAgnosticInteractionBlock" + ], + MLP_irreps=o3.Irreps(args.MLP_irreps), + ) + if args.model == "EnergyDipolesMACE": + assert ( + args.loss == "energy_forces_dipole" + ), "Use energy_forces_dipole loss with EnergyDipolesMACE model" + assert ( + args.error_table == "EnergyDipoleRMSE" + ), "Use error_table EnergyDipoleRMSE with AtomicDipolesMACE model" + return modules.EnergyDipolesMACE( + **model_config, + correlation=args.correlation, + gate=modules.gate_dict[args.gate], + interaction_cls_first=modules.interaction_classes[ + "RealAgnosticInteractionBlock" + ], + MLP_irreps=o3.Irreps(args.MLP_irreps), + ) + raise RuntimeError(f"Unknown model: '{args.model}'") diff --git a/mace-bench/3rdparty/mace/mace/tools/multihead_tools.py b/mace-bench/3rdparty/mace/mace/tools/multihead_tools.py index 1a12416..f321af3 100644 --- a/mace-bench/3rdparty/mace/mace/tools/multihead_tools.py +++ b/mace-bench/3rdparty/mace/mace/tools/multihead_tools.py @@ -1,200 +1,200 @@ -import argparse -import ast -import dataclasses -import logging -import os -import urllib.request -from pathlib import Path -from typing import Any, Dict, List, Optional, Union - -import torch - -from mace.cli.fine_tuning_select import ( - FilteringType, - SelectionSettings, - SubselectType, - select_samples, -) -from mace.data import KeySpecification -from mace.tools.scripts_utils import SubsetCollection, get_dataset_from_xyz - - -@dataclasses.dataclass -class HeadConfig: - head_name: str - key_specification: KeySpecification - train_file: Optional[Union[str, List[str]]] = None - valid_file: Optional[Union[str, List[str]]] = None - test_file: Optional[str] = None - test_dir: Optional[str] = None - E0s: Optional[Any] = None - statistics_file: Optional[str] = None - valid_fraction: Optional[float] = None - config_type_weights: Optional[Dict[str, float]] = None - keep_isolated_atoms: Optional[bool] = None - atomic_numbers: Optional[Union[List[int], List[str]]] = None - mean: Optional[float] = None - std: Optional[float] = None - avg_num_neighbors: Optional[float] = None - compute_avg_num_neighbors: Optional[bool] = None - collections: Optional[SubsetCollection] = None - train_loader: Optional[torch.utils.data.DataLoader] = None - z_table: Optional[Any] = None - atomic_energies_dict: Optional[Dict[str, float]] = None - - -def dict_head_to_dataclass( - head: Dict[str, Any], head_name: str, args: argparse.Namespace -) -> HeadConfig: - """Convert head dictionary to HeadConfig dataclass.""" - # parser+head args that have no defaults but are required - if (args.train_file is None) and (head.get("train_file", None) is None): - raise ValueError( - "train file is not set in the head config yaml or via command line args" - ) - - return HeadConfig( - head_name=head_name, - train_file=head.get("train_file", args.train_file), - valid_file=head.get("valid_file", args.valid_file), - test_file=head.get("test_file", None), - test_dir=head.get("test_dir", None), - E0s=head.get("E0s", args.E0s), - statistics_file=head.get("statistics_file", args.statistics_file), - valid_fraction=head.get("valid_fraction", args.valid_fraction), - config_type_weights=head.get("config_type_weights", args.config_type_weights), - compute_avg_num_neighbors=head.get( - "compute_avg_num_neighbors", args.compute_avg_num_neighbors - ), - atomic_numbers=head.get("atomic_numbers", args.atomic_numbers), - mean=head.get("mean", args.mean), - std=head.get("std", args.std), - avg_num_neighbors=head.get("avg_num_neighbors", args.avg_num_neighbors), - key_specification=head["key_specification"], - keep_isolated_atoms=head.get("keep_isolated_atoms", args.keep_isolated_atoms), - ) - - -def prepare_default_head(args: argparse.Namespace) -> Dict[str, Any]: - """Prepare a default head from args.""" - return { - "Default": { - "train_file": args.train_file, - "valid_file": args.valid_file, - "test_file": args.test_file, - "test_dir": args.test_dir, - "E0s": args.E0s, - "statistics_file": args.statistics_file, - "key_specification": args.key_specification, - "valid_fraction": args.valid_fraction, - "config_type_weights": args.config_type_weights, - "keep_isolated_atoms": args.keep_isolated_atoms, - } - } - - -def prepare_pt_head( - args: argparse.Namespace, - pt_keyspec: KeySpecification, - foundation_model_num_neighbours: float, -) -> Dict[str, Any]: - """Prepare a pretraining head from args.""" - if ( - args.foundation_model in ["small", "medium", "large"] - or args.pt_train_file == "mp" - ): - logging.info( - "Using foundation model for multiheads finetuning with Materials Project data" - ) - pt_keyspec.update( - info_keys={"energy": "energy", "stress": "stress"}, - arrays_keys={"forces": "forces"}, - ) - pt_head = { - "train_file": "mp", - "E0s": "foundation", - "statistics_file": None, - "key_specification": pt_keyspec, - "avg_num_neighbors": foundation_model_num_neighbours, - "compute_avg_num_neighbors": False, - } - else: - pt_head = { - "train_file": args.pt_train_file, - "valid_file": args.pt_valid_file, - "E0s": "foundation", - "statistics_file": args.statistics_file, - "valid_fraction": args.valid_fraction, - "key_specification": pt_keyspec, - "avg_num_neighbors": foundation_model_num_neighbours, - "keep_isolated_atoms": args.keep_isolated_atoms, - "compute_avg_num_neighbors": False, - } - - return pt_head - - -def assemble_mp_data( - args: argparse.Namespace, - head_config_pt: HeadConfig, - tag: str, -) -> SubsetCollection: - """Assemble Materials Project data for fine-tuning.""" - try: - checkpoint_url = "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b/mp_traj_combined.xyz" - cache_dir = ( - Path(os.environ.get("XDG_CACHE_HOME", "~/")).expanduser() / ".cache/mace" - ) - checkpoint_url_name = "".join( - c for c in os.path.basename(checkpoint_url) if c.isalnum() or c in "_" - ) - cached_dataset_path = f"{cache_dir}/{checkpoint_url_name}" - if not os.path.isfile(cached_dataset_path): - os.makedirs(cache_dir, exist_ok=True) - # download and save to disk - logging.info("Downloading MP structures for finetuning") - _, http_msg = urllib.request.urlretrieve( - checkpoint_url, cached_dataset_path - ) - if "Content-Type: text/html" in http_msg: - raise RuntimeError( - f"Dataset download failed, please check the URL {checkpoint_url}" - ) - logging.info(f"Materials Project dataset to {cached_dataset_path}") - output = f"mp_finetuning-{tag}.xyz" - atomic_numbers = ( - ast.literal_eval(args.atomic_numbers) - if args.atomic_numbers is not None - else None - ) - settings = SelectionSettings( - configs_pt=cached_dataset_path, - output=f"mp_finetuning-{tag}.xyz", - atomic_numbers=atomic_numbers, - num_samples=args.num_samples_pt, - seed=args.seed, - head_pt="pbe_mp", - weight_pt=args.weight_pt_head, - filtering_type=FilteringType(args.filter_type_pt), - subselect=SubselectType(args.subselect_pt), - default_dtype=args.default_dtype, - ) - select_samples(settings) - head_config_pt.train_file = [output] - collections_mp, _ = get_dataset_from_xyz( - work_dir=args.work_dir, - train_path=output, - valid_path=None, - valid_fraction=args.valid_fraction, - config_type_weights=None, - test_path=None, - seed=args.seed, - key_specification=head_config_pt.key_specification, - head_name="pt_head", - keep_isolated_atoms=args.keep_isolated_atoms, - ) - return collections_mp - except Exception as exc: - raise RuntimeError( - "Model or descriptors download failed and no local model found" - ) from exc +import argparse +import ast +import dataclasses +import logging +import os +import urllib.request +from pathlib import Path +from typing import Any, Dict, List, Optional, Union + +import torch + +from mace.cli.fine_tuning_select import ( + FilteringType, + SelectionSettings, + SubselectType, + select_samples, +) +from mace.data import KeySpecification +from mace.tools.scripts_utils import SubsetCollection, get_dataset_from_xyz + + +@dataclasses.dataclass +class HeadConfig: + head_name: str + key_specification: KeySpecification + train_file: Optional[Union[str, List[str]]] = None + valid_file: Optional[Union[str, List[str]]] = None + test_file: Optional[str] = None + test_dir: Optional[str] = None + E0s: Optional[Any] = None + statistics_file: Optional[str] = None + valid_fraction: Optional[float] = None + config_type_weights: Optional[Dict[str, float]] = None + keep_isolated_atoms: Optional[bool] = None + atomic_numbers: Optional[Union[List[int], List[str]]] = None + mean: Optional[float] = None + std: Optional[float] = None + avg_num_neighbors: Optional[float] = None + compute_avg_num_neighbors: Optional[bool] = None + collections: Optional[SubsetCollection] = None + train_loader: Optional[torch.utils.data.DataLoader] = None + z_table: Optional[Any] = None + atomic_energies_dict: Optional[Dict[str, float]] = None + + +def dict_head_to_dataclass( + head: Dict[str, Any], head_name: str, args: argparse.Namespace +) -> HeadConfig: + """Convert head dictionary to HeadConfig dataclass.""" + # parser+head args that have no defaults but are required + if (args.train_file is None) and (head.get("train_file", None) is None): + raise ValueError( + "train file is not set in the head config yaml or via command line args" + ) + + return HeadConfig( + head_name=head_name, + train_file=head.get("train_file", args.train_file), + valid_file=head.get("valid_file", args.valid_file), + test_file=head.get("test_file", None), + test_dir=head.get("test_dir", None), + E0s=head.get("E0s", args.E0s), + statistics_file=head.get("statistics_file", args.statistics_file), + valid_fraction=head.get("valid_fraction", args.valid_fraction), + config_type_weights=head.get("config_type_weights", args.config_type_weights), + compute_avg_num_neighbors=head.get( + "compute_avg_num_neighbors", args.compute_avg_num_neighbors + ), + atomic_numbers=head.get("atomic_numbers", args.atomic_numbers), + mean=head.get("mean", args.mean), + std=head.get("std", args.std), + avg_num_neighbors=head.get("avg_num_neighbors", args.avg_num_neighbors), + key_specification=head["key_specification"], + keep_isolated_atoms=head.get("keep_isolated_atoms", args.keep_isolated_atoms), + ) + + +def prepare_default_head(args: argparse.Namespace) -> Dict[str, Any]: + """Prepare a default head from args.""" + return { + "Default": { + "train_file": args.train_file, + "valid_file": args.valid_file, + "test_file": args.test_file, + "test_dir": args.test_dir, + "E0s": args.E0s, + "statistics_file": args.statistics_file, + "key_specification": args.key_specification, + "valid_fraction": args.valid_fraction, + "config_type_weights": args.config_type_weights, + "keep_isolated_atoms": args.keep_isolated_atoms, + } + } + + +def prepare_pt_head( + args: argparse.Namespace, + pt_keyspec: KeySpecification, + foundation_model_num_neighbours: float, +) -> Dict[str, Any]: + """Prepare a pretraining head from args.""" + if ( + args.foundation_model in ["small", "medium", "large"] + or args.pt_train_file == "mp" + ): + logging.info( + "Using foundation model for multiheads finetuning with Materials Project data" + ) + pt_keyspec.update( + info_keys={"energy": "energy", "stress": "stress"}, + arrays_keys={"forces": "forces"}, + ) + pt_head = { + "train_file": "mp", + "E0s": "foundation", + "statistics_file": None, + "key_specification": pt_keyspec, + "avg_num_neighbors": foundation_model_num_neighbours, + "compute_avg_num_neighbors": False, + } + else: + pt_head = { + "train_file": args.pt_train_file, + "valid_file": args.pt_valid_file, + "E0s": "foundation", + "statistics_file": args.statistics_file, + "valid_fraction": args.valid_fraction, + "key_specification": pt_keyspec, + "avg_num_neighbors": foundation_model_num_neighbours, + "keep_isolated_atoms": args.keep_isolated_atoms, + "compute_avg_num_neighbors": False, + } + + return pt_head + + +def assemble_mp_data( + args: argparse.Namespace, + head_config_pt: HeadConfig, + tag: str, +) -> SubsetCollection: + """Assemble Materials Project data for fine-tuning.""" + try: + checkpoint_url = "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b/mp_traj_combined.xyz" + cache_dir = ( + Path(os.environ.get("XDG_CACHE_HOME", "~/")).expanduser() / ".cache/mace" + ) + checkpoint_url_name = "".join( + c for c in os.path.basename(checkpoint_url) if c.isalnum() or c in "_" + ) + cached_dataset_path = f"{cache_dir}/{checkpoint_url_name}" + if not os.path.isfile(cached_dataset_path): + os.makedirs(cache_dir, exist_ok=True) + # download and save to disk + logging.info("Downloading MP structures for finetuning") + _, http_msg = urllib.request.urlretrieve( + checkpoint_url, cached_dataset_path + ) + if "Content-Type: text/html" in http_msg: + raise RuntimeError( + f"Dataset download failed, please check the URL {checkpoint_url}" + ) + logging.info(f"Materials Project dataset to {cached_dataset_path}") + output = f"mp_finetuning-{tag}.xyz" + atomic_numbers = ( + ast.literal_eval(args.atomic_numbers) + if args.atomic_numbers is not None + else None + ) + settings = SelectionSettings( + configs_pt=cached_dataset_path, + output=f"mp_finetuning-{tag}.xyz", + atomic_numbers=atomic_numbers, + num_samples=args.num_samples_pt, + seed=args.seed, + head_pt="pbe_mp", + weight_pt=args.weight_pt_head, + filtering_type=FilteringType(args.filter_type_pt), + subselect=SubselectType(args.subselect_pt), + default_dtype=args.default_dtype, + ) + select_samples(settings) + head_config_pt.train_file = [output] + collections_mp, _ = get_dataset_from_xyz( + work_dir=args.work_dir, + train_path=output, + valid_path=None, + valid_fraction=args.valid_fraction, + config_type_weights=None, + test_path=None, + seed=args.seed, + key_specification=head_config_pt.key_specification, + head_name="pt_head", + keep_isolated_atoms=args.keep_isolated_atoms, + ) + return collections_mp + except Exception as exc: + raise RuntimeError( + "Model or descriptors download failed and no local model found" + ) from exc diff --git a/mace-bench/3rdparty/mace/mace/tools/run_train_utils.py b/mace-bench/3rdparty/mace/mace/tools/run_train_utils.py index cb1d568..ce37e0e 100644 --- a/mace-bench/3rdparty/mace/mace/tools/run_train_utils.py +++ b/mace-bench/3rdparty/mace/mace/tools/run_train_utils.py @@ -1,217 +1,217 @@ -import logging -import os -from pathlib import Path -from typing import Any, List, Optional, Union - -import torch -from torch.utils.data import ConcatDataset - -from mace import data -from mace.tools.scripts_utils import check_path_ase_read -from mace.tools.torch_geometric.dataset import Dataset -from mace.tools.utils import AtomicNumberTable - - -def normalize_file_paths(file_paths: Union[str, List[str]]) -> List[str]: - """ - Normalize file paths to a list format. - - Args: - file_paths: Either a string or a list of strings representing file paths - - Returns: - A list of file paths - """ - if isinstance(file_paths, str): - return [file_paths] - if isinstance(file_paths, list): - return file_paths - raise ValueError(f"Unexpected file paths format: {type(file_paths)}") - - -def load_dataset_for_path( - file_path: Union[str, Path, List[str]], - r_max: float, - z_table: AtomicNumberTable, - heads: List[str], - head_config: Any, - collection: Optional[Any] = None, -) -> Union[Dataset, List]: - """ - Load a dataset from a file path based on its format. - - Args: - file_path: Path to the dataset file - r_max: Cutoff radius - z_table: Atomic number table - heads: List of head names - head_name: Current head name - **kwargs: Additional arguments - - Returns: - Loaded dataset - """ - if isinstance(file_path, list): - if len(file_path) == 1: - file_path = file_path[0] - if isinstance(file_path, list): - is_ase_readable = all(check_path_ase_read(p) for p in file_path) - if not is_ase_readable: - raise ValueError( - "Not all paths in the list are ASE readable, not supported" - ) - if isinstance(file_path, str): - is_ase_readable = check_path_ase_read(file_path) - - if is_ase_readable: - assert ( - collection is not None - ), "Collection must be provided for ASE readable files" - return [ - data.AtomicData.from_config( - config, z_table=z_table, cutoff=r_max, heads=heads - ) - for config in collection - ] - - filepath = Path(file_path) - if filepath.is_dir(): - - if filepath.name.endswith("_lmdb") or any( - f.endswith(".lmdb") or f.endswith(".aselmdb") for f in os.listdir(filepath) - ): - logging.info(f"Loading LMDB dataset from {file_path}") - return data.LMDBDataset( - file_path, - r_max=r_max, - z_table=z_table, - heads=heads, - head=head_config.head_name, - ) - - h5_files = list(filepath.glob("*.h5")) + list(filepath.glob("*.hdf5")) - if h5_files: - logging.info(f"Loading HDF5 dataset from directory {file_path}") - try: - return data.dataset_from_sharded_hdf5( - file_path, - r_max=r_max, - z_table=z_table, - heads=heads, - head=head_config.head_name, - ) - except Exception as e: - logging.error(f"Error loading sharded HDF5 dataset: {e}") - raise - - if "lmdb" in str(filepath).lower() or "aselmdb" in str(filepath).lower(): - logging.info(f"Loading LMDB dataset based on path name: {file_path}") - return data.LMDBDataset( - file_path, - r_max=r_max, - z_table=z_table, - heads=heads, - head=head_config.head_name, - ) - - logging.info(f"Attempting to load directory as HDF5 dataset: {file_path}") - try: - return data.dataset_from_sharded_hdf5( - file_path, - r_max=r_max, - z_table=z_table, - heads=heads, - head=head_config.head_name, - ) - except Exception as e: - logging.error(f"Error loading as sharded HDF5: {e}") - raise - - suffix = filepath.suffix.lower() - if suffix in (".h5", ".hdf5"): - logging.info(f"Loading single HDF5 file: {file_path}") - return data.HDF5Dataset( - file_path, - r_max=r_max, - z_table=z_table, - heads=heads, - head=head_config.head_name, - ) - - if suffix in (".lmdb", ".aselmdb", ".db"): - logging.info(f"Loading single LMDB file: {file_path}") - return data.LMDBDataset( - file_path, - r_max=r_max, - z_table=z_table, - heads=heads, - head=head_config.head_name, - ) - - logging.info(f"Attempting to load as LMDB: {file_path}") - return data.LMDBDataset( - file_path, - r_max=r_max, - z_table=z_table, - heads=heads, - head=head_config.head_name, - ) - - -def combine_datasets(datasets, head_name): - """ - Combine multiple datasets which might be of different types. - - Args: - datasets: List of datasets (can be mixed types) - head_name: Name of the current head - - Returns: - Combined dataset - """ - if not datasets: - return [] - - if all(isinstance(ds, list) for ds in datasets): - logging.info(f"Combining {len(datasets)} list datasets for head '{head_name}'") - return [item for sublist in datasets for item in sublist] - - if all(not isinstance(ds, list) for ds in datasets): - logging.info( - f"Combining {len(datasets)} Dataset objects for head '{head_name}'" - ) - return ConcatDataset(datasets) if len(datasets) > 1 else datasets[0] - - logging.info(f"Converting mixed dataset types for head '{head_name}'") - - try: - all_items = [] - for ds in datasets: - if isinstance(ds, list): - all_items.extend(ds) - else: - all_items.extend([ds[i] for i in range(len(ds))]) - return all_items - except Exception as e: # pylint: disable=W0703 - logging.warning(f"Failed to convert mixed datasets to list: {e}") - - try: - dataset_objects = [] - for ds in datasets: - if isinstance(ds, list): - from torch.utils.data import TensorDataset - - # Convert list to a Dataset - dataset_objects.append( - TensorDataset(*[torch.tensor([i]) for i in range(len(ds))]) - ) - else: - dataset_objects.append(ds) - return ConcatDataset(dataset_objects) - except Exception as e: # pylint: disable=W0703 - logging.warning(f"Failed to convert mixed datasets to ConcatDataset: {e}") - - logging.warning( - "Could not combine datasets of different types. Using only the first dataset." - ) - return datasets[0] +import logging +import os +from pathlib import Path +from typing import Any, List, Optional, Union + +import torch +from torch.utils.data import ConcatDataset + +from mace import data +from mace.tools.scripts_utils import check_path_ase_read +from mace.tools.torch_geometric.dataset import Dataset +from mace.tools.utils import AtomicNumberTable + + +def normalize_file_paths(file_paths: Union[str, List[str]]) -> List[str]: + """ + Normalize file paths to a list format. + + Args: + file_paths: Either a string or a list of strings representing file paths + + Returns: + A list of file paths + """ + if isinstance(file_paths, str): + return [file_paths] + if isinstance(file_paths, list): + return file_paths + raise ValueError(f"Unexpected file paths format: {type(file_paths)}") + + +def load_dataset_for_path( + file_path: Union[str, Path, List[str]], + r_max: float, + z_table: AtomicNumberTable, + heads: List[str], + head_config: Any, + collection: Optional[Any] = None, +) -> Union[Dataset, List]: + """ + Load a dataset from a file path based on its format. + + Args: + file_path: Path to the dataset file + r_max: Cutoff radius + z_table: Atomic number table + heads: List of head names + head_name: Current head name + **kwargs: Additional arguments + + Returns: + Loaded dataset + """ + if isinstance(file_path, list): + if len(file_path) == 1: + file_path = file_path[0] + if isinstance(file_path, list): + is_ase_readable = all(check_path_ase_read(p) for p in file_path) + if not is_ase_readable: + raise ValueError( + "Not all paths in the list are ASE readable, not supported" + ) + if isinstance(file_path, str): + is_ase_readable = check_path_ase_read(file_path) + + if is_ase_readable: + assert ( + collection is not None + ), "Collection must be provided for ASE readable files" + return [ + data.AtomicData.from_config( + config, z_table=z_table, cutoff=r_max, heads=heads + ) + for config in collection + ] + + filepath = Path(file_path) + if filepath.is_dir(): + + if filepath.name.endswith("_lmdb") or any( + f.endswith(".lmdb") or f.endswith(".aselmdb") for f in os.listdir(filepath) + ): + logging.info(f"Loading LMDB dataset from {file_path}") + return data.LMDBDataset( + file_path, + r_max=r_max, + z_table=z_table, + heads=heads, + head=head_config.head_name, + ) + + h5_files = list(filepath.glob("*.h5")) + list(filepath.glob("*.hdf5")) + if h5_files: + logging.info(f"Loading HDF5 dataset from directory {file_path}") + try: + return data.dataset_from_sharded_hdf5( + file_path, + r_max=r_max, + z_table=z_table, + heads=heads, + head=head_config.head_name, + ) + except Exception as e: + logging.error(f"Error loading sharded HDF5 dataset: {e}") + raise + + if "lmdb" in str(filepath).lower() or "aselmdb" in str(filepath).lower(): + logging.info(f"Loading LMDB dataset based on path name: {file_path}") + return data.LMDBDataset( + file_path, + r_max=r_max, + z_table=z_table, + heads=heads, + head=head_config.head_name, + ) + + logging.info(f"Attempting to load directory as HDF5 dataset: {file_path}") + try: + return data.dataset_from_sharded_hdf5( + file_path, + r_max=r_max, + z_table=z_table, + heads=heads, + head=head_config.head_name, + ) + except Exception as e: + logging.error(f"Error loading as sharded HDF5: {e}") + raise + + suffix = filepath.suffix.lower() + if suffix in (".h5", ".hdf5"): + logging.info(f"Loading single HDF5 file: {file_path}") + return data.HDF5Dataset( + file_path, + r_max=r_max, + z_table=z_table, + heads=heads, + head=head_config.head_name, + ) + + if suffix in (".lmdb", ".aselmdb", ".db"): + logging.info(f"Loading single LMDB file: {file_path}") + return data.LMDBDataset( + file_path, + r_max=r_max, + z_table=z_table, + heads=heads, + head=head_config.head_name, + ) + + logging.info(f"Attempting to load as LMDB: {file_path}") + return data.LMDBDataset( + file_path, + r_max=r_max, + z_table=z_table, + heads=heads, + head=head_config.head_name, + ) + + +def combine_datasets(datasets, head_name): + """ + Combine multiple datasets which might be of different types. + + Args: + datasets: List of datasets (can be mixed types) + head_name: Name of the current head + + Returns: + Combined dataset + """ + if not datasets: + return [] + + if all(isinstance(ds, list) for ds in datasets): + logging.info(f"Combining {len(datasets)} list datasets for head '{head_name}'") + return [item for sublist in datasets for item in sublist] + + if all(not isinstance(ds, list) for ds in datasets): + logging.info( + f"Combining {len(datasets)} Dataset objects for head '{head_name}'" + ) + return ConcatDataset(datasets) if len(datasets) > 1 else datasets[0] + + logging.info(f"Converting mixed dataset types for head '{head_name}'") + + try: + all_items = [] + for ds in datasets: + if isinstance(ds, list): + all_items.extend(ds) + else: + all_items.extend([ds[i] for i in range(len(ds))]) + return all_items + except Exception as e: # pylint: disable=W0703 + logging.warning(f"Failed to convert mixed datasets to list: {e}") + + try: + dataset_objects = [] + for ds in datasets: + if isinstance(ds, list): + from torch.utils.data import TensorDataset + + # Convert list to a Dataset + dataset_objects.append( + TensorDataset(*[torch.tensor([i]) for i in range(len(ds))]) + ) + else: + dataset_objects.append(ds) + return ConcatDataset(dataset_objects) + except Exception as e: # pylint: disable=W0703 + logging.warning(f"Failed to convert mixed datasets to ConcatDataset: {e}") + + logging.warning( + "Could not combine datasets of different types. Using only the first dataset." + ) + return datasets[0] diff --git a/mace-bench/3rdparty/mace/mace/tools/scatter.py b/mace-bench/3rdparty/mace/mace/tools/scatter.py index cf7a5ec..7e1139a 100644 --- a/mace-bench/3rdparty/mace/mace/tools/scatter.py +++ b/mace-bench/3rdparty/mace/mace/tools/scatter.py @@ -1,112 +1,112 @@ -"""basic scatter_sum operations from torch_scatter from -https://github.com/mir-group/pytorch_runstats/blob/main/torch_runstats/scatter_sum.py -Using code from https://github.com/rusty1s/pytorch_scatter, but cut down to avoid a dependency. -PyTorch plans to move these features into the main repo, but until then, -to make installation simpler, we need this pure python set of wrappers -that don't require installing PyTorch C++ extensions. -See https://github.com/pytorch/pytorch/issues/63780. -""" - -from typing import Optional - -import torch - - -def _broadcast(src: torch.Tensor, other: torch.Tensor, dim: int): - if dim < 0: - dim = other.dim() + dim - if src.dim() == 1: - for _ in range(0, dim): - src = src.unsqueeze(0) - for _ in range(src.dim(), other.dim()): - src = src.unsqueeze(-1) - src = src.expand_as(other) - return src - - -def scatter_sum( - src: torch.Tensor, - index: torch.Tensor, - dim: int = -1, - out: Optional[torch.Tensor] = None, - dim_size: Optional[int] = None, - reduce: str = "sum", -) -> torch.Tensor: - assert reduce == "sum" # for now, TODO - index = _broadcast(index, src, dim) - if out is None: - size = list(src.size()) - if dim_size is not None: - size[dim] = dim_size - elif index.numel() == 0: - size[dim] = 0 - else: - size[dim] = int(index.max()) + 1 - out = torch.zeros(size, dtype=src.dtype, device=src.device) - return out.scatter_add_(dim, index, src) - else: - return out.scatter_add_(dim, index, src) - - -def scatter_std( - src: torch.Tensor, - index: torch.Tensor, - dim: int = -1, - out: Optional[torch.Tensor] = None, - dim_size: Optional[int] = None, - unbiased: bool = True, -) -> torch.Tensor: - if out is not None: - dim_size = out.size(dim) - - if dim < 0: - dim = src.dim() + dim - - count_dim = dim - if index.dim() <= dim: - count_dim = index.dim() - 1 - - ones = torch.ones(index.size(), dtype=src.dtype, device=src.device) - count = scatter_sum(ones, index, count_dim, dim_size=dim_size) - - index = _broadcast(index, src, dim) - tmp = scatter_sum(src, index, dim, dim_size=dim_size) - count = _broadcast(count, tmp, dim).clamp(1) - mean = tmp.div(count) - - var = src - mean.gather(dim, index) - var = var * var - out = scatter_sum(var, index, dim, out, dim_size) - - if unbiased: - count = count.sub(1).clamp_(1) - out = out.div(count + 1e-6).sqrt() - - return out - - -def scatter_mean( - src: torch.Tensor, - index: torch.Tensor, - dim: int = -1, - out: Optional[torch.Tensor] = None, - dim_size: Optional[int] = None, -) -> torch.Tensor: - out = scatter_sum(src, index, dim, out, dim_size) - dim_size = out.size(dim) - - index_dim = dim - if index_dim < 0: - index_dim = index_dim + src.dim() - if index.dim() <= index_dim: - index_dim = index.dim() - 1 - - ones = torch.ones(index.size(), dtype=src.dtype, device=src.device) - count = scatter_sum(ones, index, index_dim, None, dim_size) - count[count < 1] = 1 - count = _broadcast(count, out, dim) - if out.is_floating_point(): - out.true_divide_(count) - else: - out.div_(count, rounding_mode="floor") - return out +"""basic scatter_sum operations from torch_scatter from +https://github.com/mir-group/pytorch_runstats/blob/main/torch_runstats/scatter_sum.py +Using code from https://github.com/rusty1s/pytorch_scatter, but cut down to avoid a dependency. +PyTorch plans to move these features into the main repo, but until then, +to make installation simpler, we need this pure python set of wrappers +that don't require installing PyTorch C++ extensions. +See https://github.com/pytorch/pytorch/issues/63780. +""" + +from typing import Optional + +import torch + + +def _broadcast(src: torch.Tensor, other: torch.Tensor, dim: int): + if dim < 0: + dim = other.dim() + dim + if src.dim() == 1: + for _ in range(0, dim): + src = src.unsqueeze(0) + for _ in range(src.dim(), other.dim()): + src = src.unsqueeze(-1) + src = src.expand_as(other) + return src + + +def scatter_sum( + src: torch.Tensor, + index: torch.Tensor, + dim: int = -1, + out: Optional[torch.Tensor] = None, + dim_size: Optional[int] = None, + reduce: str = "sum", +) -> torch.Tensor: + assert reduce == "sum" # for now, TODO + index = _broadcast(index, src, dim) + if out is None: + size = list(src.size()) + if dim_size is not None: + size[dim] = dim_size + elif index.numel() == 0: + size[dim] = 0 + else: + size[dim] = int(index.max()) + 1 + out = torch.zeros(size, dtype=src.dtype, device=src.device) + return out.scatter_add_(dim, index, src) + else: + return out.scatter_add_(dim, index, src) + + +def scatter_std( + src: torch.Tensor, + index: torch.Tensor, + dim: int = -1, + out: Optional[torch.Tensor] = None, + dim_size: Optional[int] = None, + unbiased: bool = True, +) -> torch.Tensor: + if out is not None: + dim_size = out.size(dim) + + if dim < 0: + dim = src.dim() + dim + + count_dim = dim + if index.dim() <= dim: + count_dim = index.dim() - 1 + + ones = torch.ones(index.size(), dtype=src.dtype, device=src.device) + count = scatter_sum(ones, index, count_dim, dim_size=dim_size) + + index = _broadcast(index, src, dim) + tmp = scatter_sum(src, index, dim, dim_size=dim_size) + count = _broadcast(count, tmp, dim).clamp(1) + mean = tmp.div(count) + + var = src - mean.gather(dim, index) + var = var * var + out = scatter_sum(var, index, dim, out, dim_size) + + if unbiased: + count = count.sub(1).clamp_(1) + out = out.div(count + 1e-6).sqrt() + + return out + + +def scatter_mean( + src: torch.Tensor, + index: torch.Tensor, + dim: int = -1, + out: Optional[torch.Tensor] = None, + dim_size: Optional[int] = None, +) -> torch.Tensor: + out = scatter_sum(src, index, dim, out, dim_size) + dim_size = out.size(dim) + + index_dim = dim + if index_dim < 0: + index_dim = index_dim + src.dim() + if index.dim() <= index_dim: + index_dim = index.dim() - 1 + + ones = torch.ones(index.size(), dtype=src.dtype, device=src.device) + count = scatter_sum(ones, index, index_dim, None, dim_size) + count[count < 1] = 1 + count = _broadcast(count, out, dim) + if out.is_floating_point(): + out.true_divide_(count) + else: + out.div_(count, rounding_mode="floor") + return out diff --git a/mace-bench/3rdparty/mace/mace/tools/scripts_utils.py b/mace-bench/3rdparty/mace/mace/tools/scripts_utils.py index 91fc674..bb7f79e 100644 --- a/mace-bench/3rdparty/mace/mace/tools/scripts_utils.py +++ b/mace-bench/3rdparty/mace/mace/tools/scripts_utils.py @@ -1,888 +1,888 @@ -########################################################################################### -# Training utils -# Authors: David Kovacs, Ilyes Batatia -# This program is distributed under the MIT License (see MIT.md) -########################################################################################### - -import argparse -import ast -import dataclasses -import json -import logging -import os -from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Union - -import numpy as np -import torch -import torch.distributed -from e3nn import o3 -from torch.optim.swa_utils import SWALR, AveragedModel - -from mace import data, modules, tools -from mace.data import KeySpecification -from mace.tools.train import SWAContainer - - -@dataclasses.dataclass -class SubsetCollection: - train: data.Configurations - valid: data.Configurations - tests: List[Tuple[str, data.Configurations]] - - -def log_dataset_contents(dataset: data.Configurations, dataset_name: str) -> None: - log_string = f"{dataset_name} [" - for prop_name in dataset[0].properties.keys(): - if prop_name == "dipole": - log_string += f"{prop_name} components: {int(np.sum([np.sum(config.property_weights[prop_name]) for config in dataset]))}, " - else: - log_string += f"{prop_name}: {int(np.sum([config.property_weights[prop_name] for config in dataset]))}, " - log_string = log_string[:-2] + "]" - logging.info(log_string) - - -def get_dataset_from_xyz( - work_dir: str, - train_path: Union[str, List[str]], - valid_path: Optional[Union[str, List[str]]], - valid_fraction: float, - key_specification: KeySpecification, - config_type_weights: Optional[Dict] = None, - test_path: Optional[Union[str, List[str]]] = None, - seed: int = 1234, - keep_isolated_atoms: bool = False, - head_name: str = "Default", -) -> Tuple[SubsetCollection, Optional[Dict[int, float]]]: - """ - Load training, validation, and test datasets from xyz files. - - Args: - work_dir: Working directory for saving split information - train_path: Path or list of paths to training xyz files - valid_path: Path or list of paths to validation xyz files - valid_fraction: Fraction of training data to use for validation if valid_path is None - config_type_weights: Dictionary of weights for each configuration type - key_specification: KeySpecification object for loading data - test_path: Path or list of paths to test xyz files - seed: Random seed for train/validation split - keep_isolated_atoms: Whether to keep isolated atoms in the dataset - head_name: Name of the head for multi-head models - - Returns: - Tuple containing: - - SubsetCollection with train, valid, and test configurations - - Dictionary of atomic energies (or None if not available) - """ - # Convert input paths to lists if they're not already - train_paths = [train_path] if isinstance(train_path, str) else train_path - valid_paths = ( - [valid_path] - if isinstance(valid_path, str) and valid_path is not None - else valid_path - ) - test_paths = ( - [test_path] - if isinstance(test_path, str) and test_path is not None - else test_path - ) - - # Initialize collections and atomic energies tracking - all_train_configs = [] - all_valid_configs = [] - all_test_configs = [] - - # For tracking atomic energies across files - atomic_energies_values = {} # Element Z -> list of energy values - atomic_energies_counts = {} # Element Z -> count of files with this element - - # Process training files - for i, path in enumerate(train_paths): - logging.debug(f"Loading training file: {path}") - ae_dict, train_configs = data.load_from_xyz( - file_path=path, - config_type_weights=config_type_weights, - key_specification=key_specification, - extract_atomic_energies=True, # Extract from all files to average - keep_isolated_atoms=keep_isolated_atoms, - head_name=head_name, - ) - all_train_configs.extend(train_configs) - - # Track atomic energies from each file for averaging - if ae_dict: - for element, energy in ae_dict.items(): - if element not in atomic_energies_values: - atomic_energies_values[element] = [] - atomic_energies_counts[element] = 0 - - atomic_energies_values[element].append(energy) - atomic_energies_counts[element] += 1 - - log_dataset_contents(train_configs, f"Training set {i+1}/{len(train_paths)}") - - # Log total training set info - log_dataset_contents(all_train_configs, "Total Training set") - - # Process validation files if provided - if valid_paths: - for i, path in enumerate(valid_paths): - _, valid_configs = data.load_from_xyz( - file_path=path, - config_type_weights=config_type_weights, - key_specification=key_specification, - extract_atomic_energies=False, - head_name=head_name, - ) - all_valid_configs.extend(valid_configs) - log_dataset_contents( - valid_configs, f"Validation set {i+1}/{len(valid_paths)}" - ) - - # Log total validation set info - log_dataset_contents(all_valid_configs, "Total Validation set") - train_configs = all_train_configs - valid_configs = all_valid_configs - else: - # Split training data if no validation files are provided - logging.info("No validation set provided, splitting training data instead.") - train_configs, valid_configs = data.random_train_valid_split( - all_train_configs, valid_fraction, seed, work_dir - ) - log_dataset_contents(train_configs, "Random Split Training set") - log_dataset_contents(valid_configs, "Random Split Validation set") - - test_configs_by_type = [] - if test_paths: - for i, path in enumerate(test_paths): - _, test_configs = data.load_from_xyz( - file_path=path, - config_type_weights=config_type_weights, - key_specification=key_specification, - extract_atomic_energies=False, - head_name=head_name, - ) - all_test_configs.extend(test_configs) - - log_dataset_contents(test_configs, f"Test set {i+1}/{len(test_paths)}") - - # Create list of tuples (config_type, list(Atoms)) - test_configs_by_type = data.test_config_types(all_test_configs) - log_dataset_contents(all_test_configs, "Total Test set") - - atomic_energies_dict = {} - for element, values in atomic_energies_values.items(): - if atomic_energies_counts[element] > 1: - atomic_energies_dict[element] = sum(values) / len(values) - logging.debug( - f"Element {element} found in {atomic_energies_counts[element]} files. Using average E0: {atomic_energies_dict[element]:.6f} eV" - ) - else: - atomic_energies_dict[element] = values[0] - logging.debug( - f"Element {element} found in 1 file. Using E0: {atomic_energies_dict[element]:.6f} eV" - ) - - return ( - SubsetCollection( - train=train_configs, valid=valid_configs, tests=test_configs_by_type - ), - atomic_energies_dict if atomic_energies_dict else None, - ) - - -def get_config_type_weights(ct_weights): - """ - Get config type weights from command line argument - """ - try: - config_type_weights = ast.literal_eval(ct_weights) - assert isinstance(config_type_weights, dict) - except Exception as e: # pylint: disable=W0703 - logging.warning( - f"Config type weights not specified correctly ({e}), using Default" - ) - config_type_weights = {"Default": 1.0} - return config_type_weights - - -def print_git_commit(): - try: - import git - - repo = git.Repo(search_parent_directories=True) - commit = repo.head.commit.hexsha - logging.debug(f"Current Git commit: {commit}") - return commit - except Exception as e: # pylint: disable=W0703 - logging.debug(f"Error accessing Git repository: {e}") - return "None" - - -def extract_config_mace_model(model: torch.nn.Module) -> Dict[str, Any]: - if model.__class__.__name__ != "ScaleShiftMACE": - return {"error": "Model is not a ScaleShiftMACE model"} - - def radial_to_name(radial_type): - if radial_type == "BesselBasis": - return "bessel" - if radial_type == "GaussianBasis": - return "gaussian" - if radial_type == "ChebychevBasis": - return "chebyshev" - return radial_type - - def radial_to_transform(radial): - if not hasattr(radial, "distance_transform"): - return None - if radial.distance_transform.__class__.__name__ == "AgnesiTransform": - return "Agnesi" - if radial.distance_transform.__class__.__name__ == "SoftTransform": - return "Soft" - return radial.distance_transform.__class__.__name__ - - scale = model.scale_shift.scale - shift = model.scale_shift.shift - heads = model.heads if hasattr(model, "heads") else ["default"] - model_mlp_irreps = ( - o3.Irreps(str(model.readouts[-1].hidden_irreps)) - if model.num_interactions.item() > 1 - else 1 - ) - mlp_irreps = o3.Irreps(f"{model_mlp_irreps.count((0, 1)) // len(heads)}x0e") - try: - correlation = ( - len(model.products[0].symmetric_contractions.contractions[0].weights) + 1 - ) - except AttributeError: - correlation = model.products[0].symmetric_contractions.contraction_degree - config = { - "r_max": model.r_max.item(), - "num_bessel": len(model.radial_embedding.bessel_fn.bessel_weights), - "num_polynomial_cutoff": model.radial_embedding.cutoff_fn.p.item(), - "max_ell": model.spherical_harmonics._lmax, # pylint: disable=protected-access - "interaction_cls": model.interactions[-1].__class__, - "interaction_cls_first": model.interactions[0].__class__, - "num_interactions": model.num_interactions.item(), - "num_elements": len(model.atomic_numbers), - "hidden_irreps": o3.Irreps(str(model.products[0].linear.irreps_out)), - "MLP_irreps": (mlp_irreps if model.num_interactions.item() > 1 else 1), - "gate": ( - model.readouts[-1] # pylint: disable=protected-access - .non_linearity._modules["acts"][0] - .f - if model.num_interactions.item() > 1 - else None - ), - "atomic_energies": model.atomic_energies_fn.atomic_energies.cpu().numpy(), - "avg_num_neighbors": model.interactions[0].avg_num_neighbors, - "atomic_numbers": model.atomic_numbers, - "correlation": correlation, - "radial_type": radial_to_name( - model.radial_embedding.bessel_fn.__class__.__name__ - ), - "radial_MLP": model.interactions[0].conv_tp_weights.hs[1:-1], - "pair_repulsion": hasattr(model, "pair_repulsion_fn"), - "distance_transform": radial_to_transform(model.radial_embedding), - "atomic_inter_scale": scale.cpu().numpy(), - "atomic_inter_shift": shift.cpu().numpy(), - "heads": heads, - } - return config - - -def extract_load(f: str, map_location: str = "cpu") -> torch.nn.Module: - return extract_model( - torch.load(f=f, map_location=map_location), map_location=map_location - ) - - -def remove_pt_head( - model: torch.nn.Module, head_to_keep: Optional[str] = None -) -> torch.nn.Module: - """Converts a multihead MACE model to a single head model by removing the pretraining head. - - Args: - model (ScaleShiftMACE): The multihead MACE model to convert - head_to_keep (Optional[str]): The name of the head to keep. If None, keeps the first non-PT head. - - Returns: - ScaleShiftMACE: A new MACE model with only the specified head - - Raises: - ValueError: If the model is not a multihead model or if the specified head is not found - """ - if not hasattr(model, "heads") or len(model.heads) <= 1: - raise ValueError("Model must be a multihead model with more than one head") - - # Get index of head to keep - if head_to_keep is None: - # Find first non-PT head - try: - head_idx = next(i for i, h in enumerate(model.heads) if h != "pt_head") - except StopIteration as e: - raise ValueError("No non-PT head found in model") from e - else: - try: - head_idx = model.heads.index(head_to_keep) - except ValueError as e: - raise ValueError(f"Head {head_to_keep} not found in model") from e - - # Extract config and modify for single head - model_config = extract_config_mace_model(model) - model_config["heads"] = [model.heads[head_idx]] - model_config["atomic_energies"] = ( - model.atomic_energies_fn.atomic_energies[head_idx] - .unsqueeze(0) - .detach() - .cpu() - .numpy() - ) - model_config["atomic_inter_scale"] = model.scale_shift.scale[head_idx].item() - model_config["atomic_inter_shift"] = model.scale_shift.shift[head_idx].item() - mlp_count_irreps = model_config["MLP_irreps"].count((0, 1)) - # model_config["MLP_irreps"] = o3.Irreps(f"{mlp_count_irreps}x0e") - - new_model = model.__class__(**model_config) - state_dict = model.state_dict() - new_state_dict = {} - - for name, param in state_dict.items(): - if "atomic_energies" in name: - new_state_dict[name] = param[head_idx : head_idx + 1] - elif "scale" in name or "shift" in name: - new_state_dict[name] = param[head_idx : head_idx + 1] - elif "readouts" in name: - channels_per_head = param.shape[0] // len(model.heads) - start_idx = head_idx * channels_per_head - end_idx = start_idx + channels_per_head - if "linear_2.weight" in name: - end_idx = start_idx + channels_per_head // 2 - # if ( - # "readouts.0.linear.weight" in name - # or "readouts.1.linear_2.weight" in name - # ): - # new_state_dict[name] = param[start_idx:end_idx] / ( - # len(model.heads) ** 0.5 - # ) - if "readouts.0.linear.weight" in name: - new_state_dict[name] = param.reshape(-1, len(model.heads))[ - :, head_idx - ].flatten() - elif "readouts.1.linear_1.weight" in name: - new_state_dict[name] = param.reshape( - -1, len(model.heads), mlp_count_irreps - )[:, head_idx, :].flatten() - elif "readouts.1.linear_2.weight" in name: - new_state_dict[name] = param.reshape( - len(model.heads), -1, len(model.heads) - )[head_idx, :, head_idx].flatten() / (len(model.heads) ** 0.5) - else: - new_state_dict[name] = param[start_idx:end_idx] - - else: - new_state_dict[name] = param - - # Load state dict into new model - new_model.load_state_dict(new_state_dict) - - return new_model - - -def extract_model(model: torch.nn.Module, map_location: str = "cpu") -> torch.nn.Module: - model_copy = model.__class__(**extract_config_mace_model(model)) - model_copy.load_state_dict(model.state_dict()) - return model_copy.to(map_location) - - -def convert_to_json_format(dict_input): - for key, value in dict_input.items(): - if isinstance(value, (np.ndarray, torch.Tensor)): - dict_input[key] = value.tolist() - # # check if the value is a class and convert it to a string - elif hasattr(value, "__class__"): - dict_input[key] = str(value) - return dict_input - - -def convert_from_json_format(dict_input): - dict_output = dict_input.copy() - if ( - dict_input["interaction_cls"] - == "" - ): - dict_output["interaction_cls"] = ( - modules.blocks.RealAgnosticResidualInteractionBlock - ) - if ( - dict_input["interaction_cls"] - == "" - ): - dict_output["interaction_cls"] = modules.blocks.RealAgnosticInteractionBlock - if ( - dict_input["interaction_cls_first"] - == "" - ): - dict_output["interaction_cls_first"] = ( - modules.blocks.RealAgnosticResidualInteractionBlock - ) - if ( - dict_input["interaction_cls_first"] - == "" - ): - dict_output["interaction_cls_first"] = ( - modules.blocks.RealAgnosticInteractionBlock - ) - dict_output["r_max"] = float(dict_input["r_max"]) - dict_output["num_bessel"] = int(dict_input["num_bessel"]) - dict_output["num_polynomial_cutoff"] = float(dict_input["num_polynomial_cutoff"]) - dict_output["max_ell"] = int(dict_input["max_ell"]) - dict_output["num_interactions"] = int(dict_input["num_interactions"]) - dict_output["num_elements"] = int(dict_input["num_elements"]) - dict_output["hidden_irreps"] = o3.Irreps(dict_input["hidden_irreps"]) - dict_output["MLP_irreps"] = o3.Irreps(dict_input["MLP_irreps"]) - dict_output["avg_num_neighbors"] = float(dict_input["avg_num_neighbors"]) - dict_output["gate"] = torch.nn.functional.silu - dict_output["atomic_energies"] = np.array(dict_input["atomic_energies"]) - dict_output["atomic_numbers"] = dict_input["atomic_numbers"] - dict_output["correlation"] = int(dict_input["correlation"]) - dict_output["radial_type"] = dict_input["radial_type"] - dict_output["radial_MLP"] = ast.literal_eval(dict_input["radial_MLP"]) - dict_output["pair_repulsion"] = ast.literal_eval(dict_input["pair_repulsion"]) - dict_output["distance_transform"] = dict_input["distance_transform"] - dict_output["atomic_inter_scale"] = float(dict_input["atomic_inter_scale"]) - dict_output["atomic_inter_shift"] = float(dict_input["atomic_inter_shift"]) - - return dict_output - - -def load_from_json(f: str, map_location: str = "cpu") -> torch.nn.Module: - extra_files_extract = {"commit.txt": None, "config.json": None} - model_jit_load = torch.jit.load( - f, _extra_files=extra_files_extract, map_location=map_location - ) - model_load_yaml = modules.ScaleShiftMACE( - **convert_from_json_format(json.loads(extra_files_extract["config.json"])) - ) - model_load_yaml.load_state_dict(model_jit_load.state_dict()) - return model_load_yaml.to(map_location) - - -def get_atomic_energies(E0s, train_collection, z_table) -> dict: - if E0s is not None: - logging.info( - "Isolated Atomic Energies (E0s) not in training file, using command line argument" - ) - if E0s.lower() == "average": - logging.info( - "Computing average Atomic Energies using least squares regression" - ) - # catch if colections.train not defined above - try: - assert train_collection is not None - atomic_energies_dict = data.compute_average_E0s( - train_collection, z_table - ) - except Exception as e: - raise RuntimeError( - f"Could not compute average E0s if no training xyz given, error {e} occured" - ) from e - else: - if E0s.endswith(".json"): - logging.info(f"Loading atomic energies from {E0s}") - with open(E0s, "r", encoding="utf-8") as f: - atomic_energies_dict = json.load(f) - atomic_energies_dict = { - int(key): value for key, value in atomic_energies_dict.items() - } - else: - try: - atomic_energies_eval = ast.literal_eval(E0s) - if not all( - isinstance(value, dict) - for value in atomic_energies_eval.values() - ): - atomic_energies_dict = atomic_energies_eval - else: - atomic_energies_dict = atomic_energies_eval - assert isinstance(atomic_energies_dict, dict) - except Exception as e: - raise RuntimeError( - f"E0s specified invalidly, error {e} occured" - ) from e - else: - raise RuntimeError( - "E0s not found in training file and not specified in command line" - ) - return atomic_energies_dict - - -def get_avg_num_neighbors(head_configs, args, train_loader, device): - if all(head_config.compute_avg_num_neighbors for head_config in head_configs): - logging.info("Computing average number of neighbors") - avg_num_neighbors = modules.compute_avg_num_neighbors(train_loader) - if args.distributed: - num_graphs = torch.tensor(len(train_loader.dataset)).to(device) - num_neighbors = num_graphs * torch.tensor(avg_num_neighbors).to(device) - torch.distributed.all_reduce(num_graphs, op=torch.distributed.ReduceOp.SUM) - torch.distributed.all_reduce( - num_neighbors, op=torch.distributed.ReduceOp.SUM - ) - avg_num_neighbors_out = (num_neighbors / num_graphs).item() - else: - avg_num_neighbors_out = avg_num_neighbors - else: - assert any( - head_config.avg_num_neighbors is not None for head_config in head_configs - ), "Average number of neighbors must be provided in the configuration" - avg_num_neighbors_out = max( - head_config.avg_num_neighbors - for head_config in head_configs - if head_config.avg_num_neighbors is not None - ) - if avg_num_neighbors_out < 2 or avg_num_neighbors_out > 100: - logging.warning( - f"Unusual average number of neighbors: {avg_num_neighbors_out:.1f}" - ) - else: - logging.info(f"Average number of neighbors: {avg_num_neighbors_out}") - return avg_num_neighbors_out - - -def get_loss_fn( - args: argparse.Namespace, - dipole_only: bool, - compute_dipole: bool, -) -> torch.nn.Module: - if args.loss == "weighted": - loss_fn = modules.WeightedEnergyForcesLoss( - energy_weight=args.energy_weight, forces_weight=args.forces_weight - ) - elif args.loss == "forces_only": - loss_fn = modules.WeightedForcesLoss(forces_weight=args.forces_weight) - elif args.loss == "virials": - loss_fn = modules.WeightedEnergyForcesVirialsLoss( - energy_weight=args.energy_weight, - forces_weight=args.forces_weight, - virials_weight=args.virials_weight, - ) - elif args.loss == "stress": - loss_fn = modules.WeightedEnergyForcesStressLoss( - energy_weight=args.energy_weight, - forces_weight=args.forces_weight, - stress_weight=args.stress_weight, - ) - elif args.loss == "huber": - loss_fn = modules.WeightedHuberEnergyForcesStressLoss( - energy_weight=args.energy_weight, - forces_weight=args.forces_weight, - stress_weight=args.stress_weight, - huber_delta=args.huber_delta, - ) - elif args.loss == "universal": - loss_fn = modules.UniversalLoss( - energy_weight=args.energy_weight, - forces_weight=args.forces_weight, - stress_weight=args.stress_weight, - huber_delta=args.huber_delta, - ) - elif args.loss == "l1l2energyforces": - loss_fn = modules.WeightedEnergyForcesL1L2Loss( - energy_weight=args.energy_weight, - forces_weight=args.forces_weight, - ) - elif args.loss == "dipole": - assert ( - dipole_only is True - ), "dipole loss can only be used with AtomicDipolesMACE model" - loss_fn = modules.DipoleSingleLoss( - dipole_weight=args.dipole_weight, - ) - elif args.loss == "energy_forces_dipole": - assert dipole_only is False and compute_dipole is True - loss_fn = modules.WeightedEnergyForcesDipoleLoss( - energy_weight=args.energy_weight, - forces_weight=args.forces_weight, - dipole_weight=args.dipole_weight, - ) - else: - loss_fn = modules.WeightedEnergyForcesLoss(energy_weight=1.0, forces_weight=1.0) - return loss_fn - - -def get_swa( - args: argparse.Namespace, - model: torch.nn.Module, - optimizer: torch.optim.Optimizer, - swas: List[bool], - dipole_only: bool = False, -): - assert dipole_only is False, "Stage Two for dipole fitting not implemented" - swas.append(True) - if args.start_swa is None: - args.start_swa = max(1, args.max_num_epochs // 4 * 3) - else: - if args.start_swa >= args.max_num_epochs: - logging.warning( - f"Start Stage Two must be less than max_num_epochs, got {args.start_swa} > {args.max_num_epochs}" - ) - swas[-1] = False - if args.loss == "forces_only": - raise ValueError("Can not select Stage Two with forces only loss.") - if args.loss == "virials": - loss_fn_energy = modules.WeightedEnergyForcesVirialsLoss( - energy_weight=args.swa_energy_weight, - forces_weight=args.swa_forces_weight, - virials_weight=args.swa_virials_weight, - ) - logging.info( - f"Stage Two (after {args.start_swa} epochs) with loss function: {loss_fn_energy}, energy weight : {args.swa_energy_weight}, forces weight : {args.swa_forces_weight}, virials_weight: {args.swa_virials_weight} and learning rate : {args.swa_lr}" - ) - elif args.loss == "stress": - loss_fn_energy = modules.WeightedEnergyForcesStressLoss( - energy_weight=args.swa_energy_weight, - forces_weight=args.swa_forces_weight, - stress_weight=args.swa_stress_weight, - ) - logging.info( - f"Stage Two (after {args.start_swa} epochs) with loss function: {loss_fn_energy}, energy weight : {args.swa_energy_weight}, forces weight : {args.swa_forces_weight}, stress weight : {args.swa_stress_weight} and learning rate : {args.swa_lr}" - ) - elif args.loss == "energy_forces_dipole": - loss_fn_energy = modules.WeightedEnergyForcesDipoleLoss( - args.swa_energy_weight, - forces_weight=args.swa_forces_weight, - dipole_weight=args.swa_dipole_weight, - ) - logging.info( - f"Stage Two (after {args.start_swa} epochs) with loss function: {loss_fn_energy}, with energy weight : {args.swa_energy_weight}, forces weight : {args.swa_forces_weight}, dipole weight : {args.swa_dipole_weight} and learning rate : {args.swa_lr}" - ) - elif args.loss == "universal": - loss_fn_energy = modules.UniversalLoss( - energy_weight=args.swa_energy_weight, - forces_weight=args.swa_forces_weight, - stress_weight=args.swa_stress_weight, - huber_delta=args.huber_delta, - ) - logging.info( - f"Stage Two (after {args.start_swa} epochs) with loss function: {loss_fn_energy}, with energy weight : {args.swa_energy_weight}, forces weight : {args.swa_forces_weight}, stress weight : {args.swa_stress_weight} and learning rate : {args.swa_lr}" - ) - else: - loss_fn_energy = modules.WeightedEnergyForcesLoss( - energy_weight=args.swa_energy_weight, - forces_weight=args.swa_forces_weight, - ) - logging.info( - f"Stage Two (after {args.start_swa} epochs) with loss function: {loss_fn_energy}, with energy weight : {args.swa_energy_weight}, forces weight : {args.swa_forces_weight} and learning rate : {args.swa_lr}" - ) - swa = SWAContainer( - model=AveragedModel(model), - scheduler=SWALR( - optimizer=optimizer, - swa_lr=args.swa_lr, - anneal_epochs=1, - anneal_strategy="linear", - ), - start=args.start_swa, - loss_fn=loss_fn_energy, - ) - return swa, swas - - -def get_params_options( - args: argparse.Namespace, model: torch.nn.Module -) -> Dict[str, Any]: - decay_interactions = {} - no_decay_interactions = {} - for name, param in model.interactions.named_parameters(): - if "linear.weight" in name or "skip_tp_full.weight" in name: - decay_interactions[name] = param - else: - no_decay_interactions[name] = param - - param_options = dict( - params=[ - { - "name": "embedding", - "params": model.node_embedding.parameters(), - "weight_decay": 0.0, - }, - { - "name": "interactions_decay", - "params": list(decay_interactions.values()), - "weight_decay": args.weight_decay, - }, - { - "name": "interactions_no_decay", - "params": list(no_decay_interactions.values()), - "weight_decay": 0.0, - }, - { - "name": "products", - "params": model.products.parameters(), - "weight_decay": args.weight_decay, - }, - { - "name": "readouts", - "params": model.readouts.parameters(), - "weight_decay": 0.0, - }, - ], - lr=args.lr, - amsgrad=args.amsgrad, - betas=(args.beta, 0.999), - ) - return param_options - - -def get_optimizer( - args: argparse.Namespace, param_options: Dict[str, Any] -) -> torch.optim.Optimizer: - if args.optimizer == "adamw": - optimizer = torch.optim.AdamW(**param_options) - elif args.optimizer == "schedulefree": - try: - from schedulefree import adamw_schedulefree - except ImportError as exc: - raise ImportError( - "`schedulefree` is not installed. Please install it via `pip install schedulefree` or `pip install mace-torch[schedulefree]`" - ) from exc - _param_options = {k: v for k, v in param_options.items() if k != "amsgrad"} - optimizer = adamw_schedulefree.AdamWScheduleFree(**_param_options) - else: - optimizer = torch.optim.Adam(**param_options) - return optimizer - - -def setup_wandb(args: argparse.Namespace): - logging.info("Using Weights and Biases for logging") - import wandb - - wandb_config = {} - args_dict = vars(args) - - for key, value in args_dict.items(): - if isinstance(value, np.ndarray): - args_dict[key] = value.tolist() - - class CustomEncoder(json.JSONEncoder): - def default(self, o): - if isinstance(o, KeySpecification): - return o.__dict__ - return super().default(o) - - args_dict_json = json.dumps(args_dict, cls=CustomEncoder) - for key in args.wandb_log_hypers: - wandb_config[key] = args_dict[key] - tools.init_wandb( - project=args.wandb_project, - entity=args.wandb_entity, - name=args.wandb_name, - config=wandb_config, - directory=args.wandb_dir, - ) - wandb.run.summary["params"] = args_dict_json - - -def get_files_with_suffix(dir_path: str, suffix: str) -> List[str]: - return [ - os.path.join(dir_path, f) for f in os.listdir(dir_path) if f.endswith(suffix) - ] - - -def dict_to_array(input_data, heads): - if all(isinstance(value, np.ndarray) for value in input_data.values()): - return np.array([input_data[head] for head in heads]) - if not all(isinstance(value, dict) for value in input_data.values()): - return np.array([[input_data[head]] for head in heads]) - unique_keys = set() - for inner_dict in input_data.values(): - unique_keys.update(inner_dict.keys()) - unique_keys = list(unique_keys) - sorted_keys = sorted([int(key) for key in unique_keys]) - result_array = np.zeros((len(input_data), len(sorted_keys))) - for _, (head_name, inner_dict) in enumerate(input_data.items()): - for key, value in inner_dict.items(): - key_index = sorted_keys.index(int(key)) - head_index = heads.index(head_name) - result_array[head_index][key_index] = value - return result_array - - -class LRScheduler: - def __init__(self, optimizer, args) -> None: - self.scheduler = args.scheduler - self._optimizer_type = ( - args.optimizer - ) # Schedulefree does not need an optimizer but checkpoint handler does. - if args.scheduler == "ExponentialLR": - self.lr_scheduler = torch.optim.lr_scheduler.ExponentialLR( - optimizer=optimizer, gamma=args.lr_scheduler_gamma - ) - elif args.scheduler == "ReduceLROnPlateau": - self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( - optimizer=optimizer, - factor=args.lr_factor, - patience=args.scheduler_patience, - ) - else: - raise RuntimeError(f"Unknown scheduler: '{args.scheduler}'") - - def step(self, metrics=None, epoch=None): # pylint: disable=E1123 - if self._optimizer_type == "schedulefree": - return # In principle, schedulefree optimizer can be used with a scheduler but the paper suggests it's not necessary - if self.scheduler == "ExponentialLR": - self.lr_scheduler.step(epoch=epoch) - elif self.scheduler == "ReduceLROnPlateau": - self.lr_scheduler.step( # pylint: disable=E1123 - metrics=metrics, epoch=epoch - ) - - def __getattr__(self, name): - if name == "step": - return self.step - return getattr(self.lr_scheduler, name) - - -def check_folder_subfolder(folder_path): - entries = os.listdir(folder_path) - for entry in entries: - full_path = os.path.join(folder_path, entry) - if os.path.isdir(full_path): - return True - return False - - -def check_path_ase_read(filename: Optional[str]) -> bool: - if filename is None: - return False - filepath = Path(filename) - if filepath.is_dir(): - num_h5_files = len(list(filepath.glob("*.h5"))) - num_hdf5_files = len(list(filepath.glob("*.hdf5"))) - num_ldb_files = len(list(filepath.glob("*.lmdb"))) - num_aselmbd_files = len(list(filepath.glob("*.aselmdb"))) - num_mdb_files = len(list(filepath.glob("*.mdb"))) - if ( - num_h5_files - + num_hdf5_files - + num_ldb_files - + num_aselmbd_files - + num_mdb_files - == 0 - ): - # print all the files in the directory extension in the directory for debugging - for file in os.listdir(filepath): - print(file) - raise RuntimeError(f"No supported files found in directory '{filename}'") - return False - if filepath.suffix in (".h5", ".hdf5", ".lmdb", ".aselmdb", ".mdb"): - return False - return True - - -def dict_to_namespace(dictionary): - # Convert the dictionary into an argparse.Namespace - namespace = argparse.Namespace() - for key, value in dictionary.items(): - setattr(namespace, key, value) - return namespace +########################################################################################### +# Training utils +# Authors: David Kovacs, Ilyes Batatia +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +import argparse +import ast +import dataclasses +import json +import logging +import os +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.distributed +from e3nn import o3 +from torch.optim.swa_utils import SWALR, AveragedModel + +from mace import data, modules, tools +from mace.data import KeySpecification +from mace.tools.train import SWAContainer + + +@dataclasses.dataclass +class SubsetCollection: + train: data.Configurations + valid: data.Configurations + tests: List[Tuple[str, data.Configurations]] + + +def log_dataset_contents(dataset: data.Configurations, dataset_name: str) -> None: + log_string = f"{dataset_name} [" + for prop_name in dataset[0].properties.keys(): + if prop_name == "dipole": + log_string += f"{prop_name} components: {int(np.sum([np.sum(config.property_weights[prop_name]) for config in dataset]))}, " + else: + log_string += f"{prop_name}: {int(np.sum([config.property_weights[prop_name] for config in dataset]))}, " + log_string = log_string[:-2] + "]" + logging.info(log_string) + + +def get_dataset_from_xyz( + work_dir: str, + train_path: Union[str, List[str]], + valid_path: Optional[Union[str, List[str]]], + valid_fraction: float, + key_specification: KeySpecification, + config_type_weights: Optional[Dict] = None, + test_path: Optional[Union[str, List[str]]] = None, + seed: int = 1234, + keep_isolated_atoms: bool = False, + head_name: str = "Default", +) -> Tuple[SubsetCollection, Optional[Dict[int, float]]]: + """ + Load training, validation, and test datasets from xyz files. + + Args: + work_dir: Working directory for saving split information + train_path: Path or list of paths to training xyz files + valid_path: Path or list of paths to validation xyz files + valid_fraction: Fraction of training data to use for validation if valid_path is None + config_type_weights: Dictionary of weights for each configuration type + key_specification: KeySpecification object for loading data + test_path: Path or list of paths to test xyz files + seed: Random seed for train/validation split + keep_isolated_atoms: Whether to keep isolated atoms in the dataset + head_name: Name of the head for multi-head models + + Returns: + Tuple containing: + - SubsetCollection with train, valid, and test configurations + - Dictionary of atomic energies (or None if not available) + """ + # Convert input paths to lists if they're not already + train_paths = [train_path] if isinstance(train_path, str) else train_path + valid_paths = ( + [valid_path] + if isinstance(valid_path, str) and valid_path is not None + else valid_path + ) + test_paths = ( + [test_path] + if isinstance(test_path, str) and test_path is not None + else test_path + ) + + # Initialize collections and atomic energies tracking + all_train_configs = [] + all_valid_configs = [] + all_test_configs = [] + + # For tracking atomic energies across files + atomic_energies_values = {} # Element Z -> list of energy values + atomic_energies_counts = {} # Element Z -> count of files with this element + + # Process training files + for i, path in enumerate(train_paths): + logging.debug(f"Loading training file: {path}") + ae_dict, train_configs = data.load_from_xyz( + file_path=path, + config_type_weights=config_type_weights, + key_specification=key_specification, + extract_atomic_energies=True, # Extract from all files to average + keep_isolated_atoms=keep_isolated_atoms, + head_name=head_name, + ) + all_train_configs.extend(train_configs) + + # Track atomic energies from each file for averaging + if ae_dict: + for element, energy in ae_dict.items(): + if element not in atomic_energies_values: + atomic_energies_values[element] = [] + atomic_energies_counts[element] = 0 + + atomic_energies_values[element].append(energy) + atomic_energies_counts[element] += 1 + + log_dataset_contents(train_configs, f"Training set {i+1}/{len(train_paths)}") + + # Log total training set info + log_dataset_contents(all_train_configs, "Total Training set") + + # Process validation files if provided + if valid_paths: + for i, path in enumerate(valid_paths): + _, valid_configs = data.load_from_xyz( + file_path=path, + config_type_weights=config_type_weights, + key_specification=key_specification, + extract_atomic_energies=False, + head_name=head_name, + ) + all_valid_configs.extend(valid_configs) + log_dataset_contents( + valid_configs, f"Validation set {i+1}/{len(valid_paths)}" + ) + + # Log total validation set info + log_dataset_contents(all_valid_configs, "Total Validation set") + train_configs = all_train_configs + valid_configs = all_valid_configs + else: + # Split training data if no validation files are provided + logging.info("No validation set provided, splitting training data instead.") + train_configs, valid_configs = data.random_train_valid_split( + all_train_configs, valid_fraction, seed, work_dir + ) + log_dataset_contents(train_configs, "Random Split Training set") + log_dataset_contents(valid_configs, "Random Split Validation set") + + test_configs_by_type = [] + if test_paths: + for i, path in enumerate(test_paths): + _, test_configs = data.load_from_xyz( + file_path=path, + config_type_weights=config_type_weights, + key_specification=key_specification, + extract_atomic_energies=False, + head_name=head_name, + ) + all_test_configs.extend(test_configs) + + log_dataset_contents(test_configs, f"Test set {i+1}/{len(test_paths)}") + + # Create list of tuples (config_type, list(Atoms)) + test_configs_by_type = data.test_config_types(all_test_configs) + log_dataset_contents(all_test_configs, "Total Test set") + + atomic_energies_dict = {} + for element, values in atomic_energies_values.items(): + if atomic_energies_counts[element] > 1: + atomic_energies_dict[element] = sum(values) / len(values) + logging.debug( + f"Element {element} found in {atomic_energies_counts[element]} files. Using average E0: {atomic_energies_dict[element]:.6f} eV" + ) + else: + atomic_energies_dict[element] = values[0] + logging.debug( + f"Element {element} found in 1 file. Using E0: {atomic_energies_dict[element]:.6f} eV" + ) + + return ( + SubsetCollection( + train=train_configs, valid=valid_configs, tests=test_configs_by_type + ), + atomic_energies_dict if atomic_energies_dict else None, + ) + + +def get_config_type_weights(ct_weights): + """ + Get config type weights from command line argument + """ + try: + config_type_weights = ast.literal_eval(ct_weights) + assert isinstance(config_type_weights, dict) + except Exception as e: # pylint: disable=W0703 + logging.warning( + f"Config type weights not specified correctly ({e}), using Default" + ) + config_type_weights = {"Default": 1.0} + return config_type_weights + + +def print_git_commit(): + try: + import git + + repo = git.Repo(search_parent_directories=True) + commit = repo.head.commit.hexsha + logging.debug(f"Current Git commit: {commit}") + return commit + except Exception as e: # pylint: disable=W0703 + logging.debug(f"Error accessing Git repository: {e}") + return "None" + + +def extract_config_mace_model(model: torch.nn.Module) -> Dict[str, Any]: + if model.__class__.__name__ != "ScaleShiftMACE": + return {"error": "Model is not a ScaleShiftMACE model"} + + def radial_to_name(radial_type): + if radial_type == "BesselBasis": + return "bessel" + if radial_type == "GaussianBasis": + return "gaussian" + if radial_type == "ChebychevBasis": + return "chebyshev" + return radial_type + + def radial_to_transform(radial): + if not hasattr(radial, "distance_transform"): + return None + if radial.distance_transform.__class__.__name__ == "AgnesiTransform": + return "Agnesi" + if radial.distance_transform.__class__.__name__ == "SoftTransform": + return "Soft" + return radial.distance_transform.__class__.__name__ + + scale = model.scale_shift.scale + shift = model.scale_shift.shift + heads = model.heads if hasattr(model, "heads") else ["default"] + model_mlp_irreps = ( + o3.Irreps(str(model.readouts[-1].hidden_irreps)) + if model.num_interactions.item() > 1 + else 1 + ) + mlp_irreps = o3.Irreps(f"{model_mlp_irreps.count((0, 1)) // len(heads)}x0e") + try: + correlation = ( + len(model.products[0].symmetric_contractions.contractions[0].weights) + 1 + ) + except AttributeError: + correlation = model.products[0].symmetric_contractions.contraction_degree + config = { + "r_max": model.r_max.item(), + "num_bessel": len(model.radial_embedding.bessel_fn.bessel_weights), + "num_polynomial_cutoff": model.radial_embedding.cutoff_fn.p.item(), + "max_ell": model.spherical_harmonics._lmax, # pylint: disable=protected-access + "interaction_cls": model.interactions[-1].__class__, + "interaction_cls_first": model.interactions[0].__class__, + "num_interactions": model.num_interactions.item(), + "num_elements": len(model.atomic_numbers), + "hidden_irreps": o3.Irreps(str(model.products[0].linear.irreps_out)), + "MLP_irreps": (mlp_irreps if model.num_interactions.item() > 1 else 1), + "gate": ( + model.readouts[-1] # pylint: disable=protected-access + .non_linearity._modules["acts"][0] + .f + if model.num_interactions.item() > 1 + else None + ), + "atomic_energies": model.atomic_energies_fn.atomic_energies.cpu().numpy(), + "avg_num_neighbors": model.interactions[0].avg_num_neighbors, + "atomic_numbers": model.atomic_numbers, + "correlation": correlation, + "radial_type": radial_to_name( + model.radial_embedding.bessel_fn.__class__.__name__ + ), + "radial_MLP": model.interactions[0].conv_tp_weights.hs[1:-1], + "pair_repulsion": hasattr(model, "pair_repulsion_fn"), + "distance_transform": radial_to_transform(model.radial_embedding), + "atomic_inter_scale": scale.cpu().numpy(), + "atomic_inter_shift": shift.cpu().numpy(), + "heads": heads, + } + return config + + +def extract_load(f: str, map_location: str = "cpu") -> torch.nn.Module: + return extract_model( + torch.load(f=f, map_location=map_location), map_location=map_location + ) + + +def remove_pt_head( + model: torch.nn.Module, head_to_keep: Optional[str] = None +) -> torch.nn.Module: + """Converts a multihead MACE model to a single head model by removing the pretraining head. + + Args: + model (ScaleShiftMACE): The multihead MACE model to convert + head_to_keep (Optional[str]): The name of the head to keep. If None, keeps the first non-PT head. + + Returns: + ScaleShiftMACE: A new MACE model with only the specified head + + Raises: + ValueError: If the model is not a multihead model or if the specified head is not found + """ + if not hasattr(model, "heads") or len(model.heads) <= 1: + raise ValueError("Model must be a multihead model with more than one head") + + # Get index of head to keep + if head_to_keep is None: + # Find first non-PT head + try: + head_idx = next(i for i, h in enumerate(model.heads) if h != "pt_head") + except StopIteration as e: + raise ValueError("No non-PT head found in model") from e + else: + try: + head_idx = model.heads.index(head_to_keep) + except ValueError as e: + raise ValueError(f"Head {head_to_keep} not found in model") from e + + # Extract config and modify for single head + model_config = extract_config_mace_model(model) + model_config["heads"] = [model.heads[head_idx]] + model_config["atomic_energies"] = ( + model.atomic_energies_fn.atomic_energies[head_idx] + .unsqueeze(0) + .detach() + .cpu() + .numpy() + ) + model_config["atomic_inter_scale"] = model.scale_shift.scale[head_idx].item() + model_config["atomic_inter_shift"] = model.scale_shift.shift[head_idx].item() + mlp_count_irreps = model_config["MLP_irreps"].count((0, 1)) + # model_config["MLP_irreps"] = o3.Irreps(f"{mlp_count_irreps}x0e") + + new_model = model.__class__(**model_config) + state_dict = model.state_dict() + new_state_dict = {} + + for name, param in state_dict.items(): + if "atomic_energies" in name: + new_state_dict[name] = param[head_idx : head_idx + 1] + elif "scale" in name or "shift" in name: + new_state_dict[name] = param[head_idx : head_idx + 1] + elif "readouts" in name: + channels_per_head = param.shape[0] // len(model.heads) + start_idx = head_idx * channels_per_head + end_idx = start_idx + channels_per_head + if "linear_2.weight" in name: + end_idx = start_idx + channels_per_head // 2 + # if ( + # "readouts.0.linear.weight" in name + # or "readouts.1.linear_2.weight" in name + # ): + # new_state_dict[name] = param[start_idx:end_idx] / ( + # len(model.heads) ** 0.5 + # ) + if "readouts.0.linear.weight" in name: + new_state_dict[name] = param.reshape(-1, len(model.heads))[ + :, head_idx + ].flatten() + elif "readouts.1.linear_1.weight" in name: + new_state_dict[name] = param.reshape( + -1, len(model.heads), mlp_count_irreps + )[:, head_idx, :].flatten() + elif "readouts.1.linear_2.weight" in name: + new_state_dict[name] = param.reshape( + len(model.heads), -1, len(model.heads) + )[head_idx, :, head_idx].flatten() / (len(model.heads) ** 0.5) + else: + new_state_dict[name] = param[start_idx:end_idx] + + else: + new_state_dict[name] = param + + # Load state dict into new model + new_model.load_state_dict(new_state_dict) + + return new_model + + +def extract_model(model: torch.nn.Module, map_location: str = "cpu") -> torch.nn.Module: + model_copy = model.__class__(**extract_config_mace_model(model)) + model_copy.load_state_dict(model.state_dict()) + return model_copy.to(map_location) + + +def convert_to_json_format(dict_input): + for key, value in dict_input.items(): + if isinstance(value, (np.ndarray, torch.Tensor)): + dict_input[key] = value.tolist() + # # check if the value is a class and convert it to a string + elif hasattr(value, "__class__"): + dict_input[key] = str(value) + return dict_input + + +def convert_from_json_format(dict_input): + dict_output = dict_input.copy() + if ( + dict_input["interaction_cls"] + == "" + ): + dict_output["interaction_cls"] = ( + modules.blocks.RealAgnosticResidualInteractionBlock + ) + if ( + dict_input["interaction_cls"] + == "" + ): + dict_output["interaction_cls"] = modules.blocks.RealAgnosticInteractionBlock + if ( + dict_input["interaction_cls_first"] + == "" + ): + dict_output["interaction_cls_first"] = ( + modules.blocks.RealAgnosticResidualInteractionBlock + ) + if ( + dict_input["interaction_cls_first"] + == "" + ): + dict_output["interaction_cls_first"] = ( + modules.blocks.RealAgnosticInteractionBlock + ) + dict_output["r_max"] = float(dict_input["r_max"]) + dict_output["num_bessel"] = int(dict_input["num_bessel"]) + dict_output["num_polynomial_cutoff"] = float(dict_input["num_polynomial_cutoff"]) + dict_output["max_ell"] = int(dict_input["max_ell"]) + dict_output["num_interactions"] = int(dict_input["num_interactions"]) + dict_output["num_elements"] = int(dict_input["num_elements"]) + dict_output["hidden_irreps"] = o3.Irreps(dict_input["hidden_irreps"]) + dict_output["MLP_irreps"] = o3.Irreps(dict_input["MLP_irreps"]) + dict_output["avg_num_neighbors"] = float(dict_input["avg_num_neighbors"]) + dict_output["gate"] = torch.nn.functional.silu + dict_output["atomic_energies"] = np.array(dict_input["atomic_energies"]) + dict_output["atomic_numbers"] = dict_input["atomic_numbers"] + dict_output["correlation"] = int(dict_input["correlation"]) + dict_output["radial_type"] = dict_input["radial_type"] + dict_output["radial_MLP"] = ast.literal_eval(dict_input["radial_MLP"]) + dict_output["pair_repulsion"] = ast.literal_eval(dict_input["pair_repulsion"]) + dict_output["distance_transform"] = dict_input["distance_transform"] + dict_output["atomic_inter_scale"] = float(dict_input["atomic_inter_scale"]) + dict_output["atomic_inter_shift"] = float(dict_input["atomic_inter_shift"]) + + return dict_output + + +def load_from_json(f: str, map_location: str = "cpu") -> torch.nn.Module: + extra_files_extract = {"commit.txt": None, "config.json": None} + model_jit_load = torch.jit.load( + f, _extra_files=extra_files_extract, map_location=map_location + ) + model_load_yaml = modules.ScaleShiftMACE( + **convert_from_json_format(json.loads(extra_files_extract["config.json"])) + ) + model_load_yaml.load_state_dict(model_jit_load.state_dict()) + return model_load_yaml.to(map_location) + + +def get_atomic_energies(E0s, train_collection, z_table) -> dict: + if E0s is not None: + logging.info( + "Isolated Atomic Energies (E0s) not in training file, using command line argument" + ) + if E0s.lower() == "average": + logging.info( + "Computing average Atomic Energies using least squares regression" + ) + # catch if colections.train not defined above + try: + assert train_collection is not None + atomic_energies_dict = data.compute_average_E0s( + train_collection, z_table + ) + except Exception as e: + raise RuntimeError( + f"Could not compute average E0s if no training xyz given, error {e} occured" + ) from e + else: + if E0s.endswith(".json"): + logging.info(f"Loading atomic energies from {E0s}") + with open(E0s, "r", encoding="utf-8") as f: + atomic_energies_dict = json.load(f) + atomic_energies_dict = { + int(key): value for key, value in atomic_energies_dict.items() + } + else: + try: + atomic_energies_eval = ast.literal_eval(E0s) + if not all( + isinstance(value, dict) + for value in atomic_energies_eval.values() + ): + atomic_energies_dict = atomic_energies_eval + else: + atomic_energies_dict = atomic_energies_eval + assert isinstance(atomic_energies_dict, dict) + except Exception as e: + raise RuntimeError( + f"E0s specified invalidly, error {e} occured" + ) from e + else: + raise RuntimeError( + "E0s not found in training file and not specified in command line" + ) + return atomic_energies_dict + + +def get_avg_num_neighbors(head_configs, args, train_loader, device): + if all(head_config.compute_avg_num_neighbors for head_config in head_configs): + logging.info("Computing average number of neighbors") + avg_num_neighbors = modules.compute_avg_num_neighbors(train_loader) + if args.distributed: + num_graphs = torch.tensor(len(train_loader.dataset)).to(device) + num_neighbors = num_graphs * torch.tensor(avg_num_neighbors).to(device) + torch.distributed.all_reduce(num_graphs, op=torch.distributed.ReduceOp.SUM) + torch.distributed.all_reduce( + num_neighbors, op=torch.distributed.ReduceOp.SUM + ) + avg_num_neighbors_out = (num_neighbors / num_graphs).item() + else: + avg_num_neighbors_out = avg_num_neighbors + else: + assert any( + head_config.avg_num_neighbors is not None for head_config in head_configs + ), "Average number of neighbors must be provided in the configuration" + avg_num_neighbors_out = max( + head_config.avg_num_neighbors + for head_config in head_configs + if head_config.avg_num_neighbors is not None + ) + if avg_num_neighbors_out < 2 or avg_num_neighbors_out > 100: + logging.warning( + f"Unusual average number of neighbors: {avg_num_neighbors_out:.1f}" + ) + else: + logging.info(f"Average number of neighbors: {avg_num_neighbors_out}") + return avg_num_neighbors_out + + +def get_loss_fn( + args: argparse.Namespace, + dipole_only: bool, + compute_dipole: bool, +) -> torch.nn.Module: + if args.loss == "weighted": + loss_fn = modules.WeightedEnergyForcesLoss( + energy_weight=args.energy_weight, forces_weight=args.forces_weight + ) + elif args.loss == "forces_only": + loss_fn = modules.WeightedForcesLoss(forces_weight=args.forces_weight) + elif args.loss == "virials": + loss_fn = modules.WeightedEnergyForcesVirialsLoss( + energy_weight=args.energy_weight, + forces_weight=args.forces_weight, + virials_weight=args.virials_weight, + ) + elif args.loss == "stress": + loss_fn = modules.WeightedEnergyForcesStressLoss( + energy_weight=args.energy_weight, + forces_weight=args.forces_weight, + stress_weight=args.stress_weight, + ) + elif args.loss == "huber": + loss_fn = modules.WeightedHuberEnergyForcesStressLoss( + energy_weight=args.energy_weight, + forces_weight=args.forces_weight, + stress_weight=args.stress_weight, + huber_delta=args.huber_delta, + ) + elif args.loss == "universal": + loss_fn = modules.UniversalLoss( + energy_weight=args.energy_weight, + forces_weight=args.forces_weight, + stress_weight=args.stress_weight, + huber_delta=args.huber_delta, + ) + elif args.loss == "l1l2energyforces": + loss_fn = modules.WeightedEnergyForcesL1L2Loss( + energy_weight=args.energy_weight, + forces_weight=args.forces_weight, + ) + elif args.loss == "dipole": + assert ( + dipole_only is True + ), "dipole loss can only be used with AtomicDipolesMACE model" + loss_fn = modules.DipoleSingleLoss( + dipole_weight=args.dipole_weight, + ) + elif args.loss == "energy_forces_dipole": + assert dipole_only is False and compute_dipole is True + loss_fn = modules.WeightedEnergyForcesDipoleLoss( + energy_weight=args.energy_weight, + forces_weight=args.forces_weight, + dipole_weight=args.dipole_weight, + ) + else: + loss_fn = modules.WeightedEnergyForcesLoss(energy_weight=1.0, forces_weight=1.0) + return loss_fn + + +def get_swa( + args: argparse.Namespace, + model: torch.nn.Module, + optimizer: torch.optim.Optimizer, + swas: List[bool], + dipole_only: bool = False, +): + assert dipole_only is False, "Stage Two for dipole fitting not implemented" + swas.append(True) + if args.start_swa is None: + args.start_swa = max(1, args.max_num_epochs // 4 * 3) + else: + if args.start_swa >= args.max_num_epochs: + logging.warning( + f"Start Stage Two must be less than max_num_epochs, got {args.start_swa} > {args.max_num_epochs}" + ) + swas[-1] = False + if args.loss == "forces_only": + raise ValueError("Can not select Stage Two with forces only loss.") + if args.loss == "virials": + loss_fn_energy = modules.WeightedEnergyForcesVirialsLoss( + energy_weight=args.swa_energy_weight, + forces_weight=args.swa_forces_weight, + virials_weight=args.swa_virials_weight, + ) + logging.info( + f"Stage Two (after {args.start_swa} epochs) with loss function: {loss_fn_energy}, energy weight : {args.swa_energy_weight}, forces weight : {args.swa_forces_weight}, virials_weight: {args.swa_virials_weight} and learning rate : {args.swa_lr}" + ) + elif args.loss == "stress": + loss_fn_energy = modules.WeightedEnergyForcesStressLoss( + energy_weight=args.swa_energy_weight, + forces_weight=args.swa_forces_weight, + stress_weight=args.swa_stress_weight, + ) + logging.info( + f"Stage Two (after {args.start_swa} epochs) with loss function: {loss_fn_energy}, energy weight : {args.swa_energy_weight}, forces weight : {args.swa_forces_weight}, stress weight : {args.swa_stress_weight} and learning rate : {args.swa_lr}" + ) + elif args.loss == "energy_forces_dipole": + loss_fn_energy = modules.WeightedEnergyForcesDipoleLoss( + args.swa_energy_weight, + forces_weight=args.swa_forces_weight, + dipole_weight=args.swa_dipole_weight, + ) + logging.info( + f"Stage Two (after {args.start_swa} epochs) with loss function: {loss_fn_energy}, with energy weight : {args.swa_energy_weight}, forces weight : {args.swa_forces_weight}, dipole weight : {args.swa_dipole_weight} and learning rate : {args.swa_lr}" + ) + elif args.loss == "universal": + loss_fn_energy = modules.UniversalLoss( + energy_weight=args.swa_energy_weight, + forces_weight=args.swa_forces_weight, + stress_weight=args.swa_stress_weight, + huber_delta=args.huber_delta, + ) + logging.info( + f"Stage Two (after {args.start_swa} epochs) with loss function: {loss_fn_energy}, with energy weight : {args.swa_energy_weight}, forces weight : {args.swa_forces_weight}, stress weight : {args.swa_stress_weight} and learning rate : {args.swa_lr}" + ) + else: + loss_fn_energy = modules.WeightedEnergyForcesLoss( + energy_weight=args.swa_energy_weight, + forces_weight=args.swa_forces_weight, + ) + logging.info( + f"Stage Two (after {args.start_swa} epochs) with loss function: {loss_fn_energy}, with energy weight : {args.swa_energy_weight}, forces weight : {args.swa_forces_weight} and learning rate : {args.swa_lr}" + ) + swa = SWAContainer( + model=AveragedModel(model), + scheduler=SWALR( + optimizer=optimizer, + swa_lr=args.swa_lr, + anneal_epochs=1, + anneal_strategy="linear", + ), + start=args.start_swa, + loss_fn=loss_fn_energy, + ) + return swa, swas + + +def get_params_options( + args: argparse.Namespace, model: torch.nn.Module +) -> Dict[str, Any]: + decay_interactions = {} + no_decay_interactions = {} + for name, param in model.interactions.named_parameters(): + if "linear.weight" in name or "skip_tp_full.weight" in name: + decay_interactions[name] = param + else: + no_decay_interactions[name] = param + + param_options = dict( + params=[ + { + "name": "embedding", + "params": model.node_embedding.parameters(), + "weight_decay": 0.0, + }, + { + "name": "interactions_decay", + "params": list(decay_interactions.values()), + "weight_decay": args.weight_decay, + }, + { + "name": "interactions_no_decay", + "params": list(no_decay_interactions.values()), + "weight_decay": 0.0, + }, + { + "name": "products", + "params": model.products.parameters(), + "weight_decay": args.weight_decay, + }, + { + "name": "readouts", + "params": model.readouts.parameters(), + "weight_decay": 0.0, + }, + ], + lr=args.lr, + amsgrad=args.amsgrad, + betas=(args.beta, 0.999), + ) + return param_options + + +def get_optimizer( + args: argparse.Namespace, param_options: Dict[str, Any] +) -> torch.optim.Optimizer: + if args.optimizer == "adamw": + optimizer = torch.optim.AdamW(**param_options) + elif args.optimizer == "schedulefree": + try: + from schedulefree import adamw_schedulefree + except ImportError as exc: + raise ImportError( + "`schedulefree` is not installed. Please install it via `pip install schedulefree` or `pip install mace-torch[schedulefree]`" + ) from exc + _param_options = {k: v for k, v in param_options.items() if k != "amsgrad"} + optimizer = adamw_schedulefree.AdamWScheduleFree(**_param_options) + else: + optimizer = torch.optim.Adam(**param_options) + return optimizer + + +def setup_wandb(args: argparse.Namespace): + logging.info("Using Weights and Biases for logging") + import wandb + + wandb_config = {} + args_dict = vars(args) + + for key, value in args_dict.items(): + if isinstance(value, np.ndarray): + args_dict[key] = value.tolist() + + class CustomEncoder(json.JSONEncoder): + def default(self, o): + if isinstance(o, KeySpecification): + return o.__dict__ + return super().default(o) + + args_dict_json = json.dumps(args_dict, cls=CustomEncoder) + for key in args.wandb_log_hypers: + wandb_config[key] = args_dict[key] + tools.init_wandb( + project=args.wandb_project, + entity=args.wandb_entity, + name=args.wandb_name, + config=wandb_config, + directory=args.wandb_dir, + ) + wandb.run.summary["params"] = args_dict_json + + +def get_files_with_suffix(dir_path: str, suffix: str) -> List[str]: + return [ + os.path.join(dir_path, f) for f in os.listdir(dir_path) if f.endswith(suffix) + ] + + +def dict_to_array(input_data, heads): + if all(isinstance(value, np.ndarray) for value in input_data.values()): + return np.array([input_data[head] for head in heads]) + if not all(isinstance(value, dict) for value in input_data.values()): + return np.array([[input_data[head]] for head in heads]) + unique_keys = set() + for inner_dict in input_data.values(): + unique_keys.update(inner_dict.keys()) + unique_keys = list(unique_keys) + sorted_keys = sorted([int(key) for key in unique_keys]) + result_array = np.zeros((len(input_data), len(sorted_keys))) + for _, (head_name, inner_dict) in enumerate(input_data.items()): + for key, value in inner_dict.items(): + key_index = sorted_keys.index(int(key)) + head_index = heads.index(head_name) + result_array[head_index][key_index] = value + return result_array + + +class LRScheduler: + def __init__(self, optimizer, args) -> None: + self.scheduler = args.scheduler + self._optimizer_type = ( + args.optimizer + ) # Schedulefree does not need an optimizer but checkpoint handler does. + if args.scheduler == "ExponentialLR": + self.lr_scheduler = torch.optim.lr_scheduler.ExponentialLR( + optimizer=optimizer, gamma=args.lr_scheduler_gamma + ) + elif args.scheduler == "ReduceLROnPlateau": + self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer=optimizer, + factor=args.lr_factor, + patience=args.scheduler_patience, + ) + else: + raise RuntimeError(f"Unknown scheduler: '{args.scheduler}'") + + def step(self, metrics=None, epoch=None): # pylint: disable=E1123 + if self._optimizer_type == "schedulefree": + return # In principle, schedulefree optimizer can be used with a scheduler but the paper suggests it's not necessary + if self.scheduler == "ExponentialLR": + self.lr_scheduler.step(epoch=epoch) + elif self.scheduler == "ReduceLROnPlateau": + self.lr_scheduler.step( # pylint: disable=E1123 + metrics=metrics, epoch=epoch + ) + + def __getattr__(self, name): + if name == "step": + return self.step + return getattr(self.lr_scheduler, name) + + +def check_folder_subfolder(folder_path): + entries = os.listdir(folder_path) + for entry in entries: + full_path = os.path.join(folder_path, entry) + if os.path.isdir(full_path): + return True + return False + + +def check_path_ase_read(filename: Optional[str]) -> bool: + if filename is None: + return False + filepath = Path(filename) + if filepath.is_dir(): + num_h5_files = len(list(filepath.glob("*.h5"))) + num_hdf5_files = len(list(filepath.glob("*.hdf5"))) + num_ldb_files = len(list(filepath.glob("*.lmdb"))) + num_aselmbd_files = len(list(filepath.glob("*.aselmdb"))) + num_mdb_files = len(list(filepath.glob("*.mdb"))) + if ( + num_h5_files + + num_hdf5_files + + num_ldb_files + + num_aselmbd_files + + num_mdb_files + == 0 + ): + # print all the files in the directory extension in the directory for debugging + for file in os.listdir(filepath): + print(file) + raise RuntimeError(f"No supported files found in directory '{filename}'") + return False + if filepath.suffix in (".h5", ".hdf5", ".lmdb", ".aselmdb", ".mdb"): + return False + return True + + +def dict_to_namespace(dictionary): + # Convert the dictionary into an argparse.Namespace + namespace = argparse.Namespace() + for key, value in dictionary.items(): + setattr(namespace, key, value) + return namespace diff --git a/mace-bench/3rdparty/mace/mace/tools/slurm_distributed.py b/mace-bench/3rdparty/mace/mace/tools/slurm_distributed.py index 578f5a0..9b7c77b 100644 --- a/mace-bench/3rdparty/mace/mace/tools/slurm_distributed.py +++ b/mace-bench/3rdparty/mace/mace/tools/slurm_distributed.py @@ -1,40 +1,40 @@ -########################################################################################### -# Slurm environment setup for distributed training. -# This code is refactored from rsarm's contribution at: -# https://github.com/Lumi-supercomputer/lumi-reframe-tests/blob/main/checks/apps/deeplearning/pytorch/src/pt_distr_env.py -# This program is distributed under the MIT License (see MIT.md) -########################################################################################### - -import os - -import hostlist - - -class DistributedEnvironment: - def __init__(self): - self._setup_distr_env() - self.master_addr = os.environ["MASTER_ADDR"] - self.master_port = os.environ["MASTER_PORT"] - self.world_size = int(os.environ["WORLD_SIZE"]) - self.local_rank = int(os.environ["LOCAL_RANK"]) - self.rank = int(os.environ["RANK"]) - - def _setup_distr_env(self): - hostname = hostlist.expand_hostlist(os.environ["SLURM_JOB_NODELIST"])[0] - os.environ["MASTER_ADDR"] = hostname - os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "33333") - os.environ["WORLD_SIZE"] = os.environ.get( - "SLURM_NTASKS", - str( - int(os.environ["SLURM_NTASKS_PER_NODE"]) - * int(os.environ["SLURM_NNODES"]) - ), - ) - os.environ["LOCAL_RANK"] = os.environ["SLURM_LOCALID"] - os.environ["RANK"] = os.environ["SLURM_PROCID"] - - def __repr__(self): - return ( - f"DistributedEnvironment(master_addr={self.master_addr}, master_port={self.master_port}, " - f"world_size={self.world_size}, local_rank={self.local_rank}, rank={self.rank})" - ) +########################################################################################### +# Slurm environment setup for distributed training. +# This code is refactored from rsarm's contribution at: +# https://github.com/Lumi-supercomputer/lumi-reframe-tests/blob/main/checks/apps/deeplearning/pytorch/src/pt_distr_env.py +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +import os + +import hostlist + + +class DistributedEnvironment: + def __init__(self): + self._setup_distr_env() + self.master_addr = os.environ["MASTER_ADDR"] + self.master_port = os.environ["MASTER_PORT"] + self.world_size = int(os.environ["WORLD_SIZE"]) + self.local_rank = int(os.environ["LOCAL_RANK"]) + self.rank = int(os.environ["RANK"]) + + def _setup_distr_env(self): + hostname = hostlist.expand_hostlist(os.environ["SLURM_JOB_NODELIST"])[0] + os.environ["MASTER_ADDR"] = hostname + os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "33333") + os.environ["WORLD_SIZE"] = os.environ.get( + "SLURM_NTASKS", + str( + int(os.environ["SLURM_NTASKS_PER_NODE"]) + * int(os.environ["SLURM_NNODES"]) + ), + ) + os.environ["LOCAL_RANK"] = os.environ["SLURM_LOCALID"] + os.environ["RANK"] = os.environ["SLURM_PROCID"] + + def __repr__(self): + return ( + f"DistributedEnvironment(master_addr={self.master_addr}, master_port={self.master_port}, " + f"world_size={self.world_size}, local_rank={self.local_rank}, rank={self.rank})" + ) diff --git a/mace-bench/3rdparty/mace/mace/tools/tables_utils.py b/mace-bench/3rdparty/mace/mace/tools/tables_utils.py index bbd581b..04ff640 100644 --- a/mace-bench/3rdparty/mace/mace/tools/tables_utils.py +++ b/mace-bench/3rdparty/mace/mace/tools/tables_utils.py @@ -1,246 +1,246 @@ -import logging -from typing import Dict, List, Optional - -import torch -from prettytable import PrettyTable - -from mace.tools import evaluate - - -def custom_key(key): - """ - Helper function to sort the keys of the data loader dictionary - to ensure that the training set, and validation set - are evaluated first - """ - if key == "train": - return (0, key) - if key == "valid": - return (1, key) - return (2, key) - - -def create_error_table( - table_type: str, - all_data_loaders: dict, - model: torch.nn.Module, - loss_fn: torch.nn.Module, - output_args: Dict[str, bool], - log_wandb: bool, - device: str, - distributed: bool = False, - skip_heads: Optional[List[str]] = None, -) -> PrettyTable: - if log_wandb: - import wandb - skip_heads = skip_heads or [] - table = PrettyTable() - if table_type == "TotalRMSE": - table.field_names = [ - "config_type", - "RMSE E / meV", - "RMSE F / meV / A", - "relative F RMSE %", - ] - elif table_type == "PerAtomRMSE": - table.field_names = [ - "config_type", - "RMSE E / meV / atom", - "RMSE F / meV / A", - "relative F RMSE %", - ] - elif table_type == "PerAtomRMSEstressvirials": - table.field_names = [ - "config_type", - "RMSE E / meV / atom", - "RMSE F / meV / A", - "relative F RMSE %", - "RMSE Stress (Virials) / meV / A (A^3)", - ] - elif table_type == "PerAtomMAEstressvirials": - table.field_names = [ - "config_type", - "MAE E / meV / atom", - "MAE F / meV / A", - "relative F MAE %", - "MAE Stress (Virials) / meV / A (A^3)", - ] - elif table_type == "TotalMAE": - table.field_names = [ - "config_type", - "MAE E / meV", - "MAE F / meV / A", - "relative F MAE %", - ] - elif table_type == "PerAtomMAE": - table.field_names = [ - "config_type", - "MAE E / meV / atom", - "MAE F / meV / A", - "relative F MAE %", - ] - elif table_type == "DipoleRMSE": - table.field_names = [ - "config_type", - "RMSE MU / mDebye / atom", - "relative MU RMSE %", - ] - elif table_type == "DipoleMAE": - table.field_names = [ - "config_type", - "MAE MU / mDebye / atom", - "relative MU MAE %", - ] - elif table_type == "EnergyDipoleRMSE": - table.field_names = [ - "config_type", - "RMSE E / meV / atom", - "RMSE F / meV / A", - "rel F RMSE %", - "RMSE MU / mDebye / atom", - "rel MU RMSE %", - ] - - for name in sorted(all_data_loaders, key=custom_key): - if any(skip_head in name for skip_head in skip_heads): - logging.info(f"Skipping evaluation of {name} (in skip_heads list)") - continue - data_loader = all_data_loaders[name] - logging.info(f"Evaluating {name} ...") - _, metrics = evaluate( - model, - loss_fn=loss_fn, - data_loader=data_loader, - output_args=output_args, - device=device, - ) - if distributed: - torch.distributed.barrier() - - del data_loader - torch.cuda.empty_cache() - if log_wandb: - wandb_log_dict = { - name - + "_final_rmse_e_per_atom": metrics["rmse_e_per_atom"] - * 1e3, # meV / atom - name + "_final_rmse_f": metrics["rmse_f"] * 1e3, # meV / A - name + "_final_rel_rmse_f": metrics["rel_rmse_f"], - } - wandb.log(wandb_log_dict) - if table_type == "TotalRMSE": - table.add_row( - [ - name, - f"{metrics['rmse_e'] * 1000:8.1f}", - f"{metrics['rmse_f'] * 1000:8.1f}", - f"{metrics['rel_rmse_f']:8.2f}", - ] - ) - elif table_type == "PerAtomRMSE": - table.add_row( - [ - name, - f"{metrics['rmse_e_per_atom'] * 1000:8.1f}", - f"{metrics['rmse_f'] * 1000:8.1f}", - f"{metrics['rel_rmse_f']:8.2f}", - ] - ) - elif ( - table_type == "PerAtomRMSEstressvirials" - and metrics["rmse_stress"] is not None - ): - table.add_row( - [ - name, - f"{metrics['rmse_e_per_atom'] * 1000:8.1f}", - f"{metrics['rmse_f'] * 1000:8.1f}", - f"{metrics['rel_rmse_f']:8.2f}", - f"{metrics['rmse_stress'] * 1000:8.1f}", - ] - ) - elif ( - table_type == "PerAtomRMSEstressvirials" - and metrics["rmse_virials"] is not None - ): - table.add_row( - [ - name, - f"{metrics['rmse_e_per_atom'] * 1000:8.1f}", - f"{metrics['rmse_f'] * 1000:8.1f}", - f"{metrics['rel_rmse_f']:8.2f}", - f"{metrics['rmse_virials'] * 1000:8.1f}", - ] - ) - elif ( - table_type == "PerAtomMAEstressvirials" - and metrics["mae_stress"] is not None - ): - table.add_row( - [ - name, - f"{metrics['mae_e_per_atom'] * 1000:8.1f}", - f"{metrics['mae_f'] * 1000:8.1f}", - f"{metrics['rel_mae_f']:8.2f}", - f"{metrics['mae_stress'] * 1000:8.1f}", - ] - ) - elif ( - table_type == "PerAtomMAEstressvirials" - and metrics["mae_virials"] is not None - ): - table.add_row( - [ - name, - f"{metrics['mae_e_per_atom'] * 1000:8.1f}", - f"{metrics['mae_f'] * 1000:8.1f}", - f"{metrics['rel_mae_f']:8.2f}", - f"{metrics['mae_virials'] * 1000:8.1f}", - ] - ) - elif table_type == "TotalMAE": - table.add_row( - [ - name, - f"{metrics['mae_e'] * 1000:8.1f}", - f"{metrics['mae_f'] * 1000:8.1f}", - f"{metrics['rel_mae_f']:8.2f}", - ] - ) - elif table_type == "PerAtomMAE": - table.add_row( - [ - name, - f"{metrics['mae_e_per_atom'] * 1000:8.1f}", - f"{metrics['mae_f'] * 1000:8.1f}", - f"{metrics['rel_mae_f']:8.2f}", - ] - ) - elif table_type == "DipoleRMSE": - table.add_row( - [ - name, - f"{metrics['rmse_mu_per_atom'] * 1000:8.2f}", - f"{metrics['rel_rmse_mu']:8.1f}", - ] - ) - elif table_type == "DipoleMAE": - table.add_row( - [ - name, - f"{metrics['mae_mu_per_atom'] * 1000:8.2f}", - f"{metrics['rel_mae_mu']:8.1f}", - ] - ) - elif table_type == "EnergyDipoleRMSE": - table.add_row( - [ - name, - f"{metrics['rmse_e_per_atom'] * 1000:8.1f}", - f"{metrics['rmse_f'] * 1000:8.1f}", - f"{metrics['rel_rmse_f']:8.1f}", - f"{metrics['rmse_mu_per_atom'] * 1000:8.1f}", - f"{metrics['rel_rmse_mu']:8.1f}", - ] - ) - return table +import logging +from typing import Dict, List, Optional + +import torch +from prettytable import PrettyTable + +from mace.tools import evaluate + + +def custom_key(key): + """ + Helper function to sort the keys of the data loader dictionary + to ensure that the training set, and validation set + are evaluated first + """ + if key == "train": + return (0, key) + if key == "valid": + return (1, key) + return (2, key) + + +def create_error_table( + table_type: str, + all_data_loaders: dict, + model: torch.nn.Module, + loss_fn: torch.nn.Module, + output_args: Dict[str, bool], + log_wandb: bool, + device: str, + distributed: bool = False, + skip_heads: Optional[List[str]] = None, +) -> PrettyTable: + if log_wandb: + import wandb + skip_heads = skip_heads or [] + table = PrettyTable() + if table_type == "TotalRMSE": + table.field_names = [ + "config_type", + "RMSE E / meV", + "RMSE F / meV / A", + "relative F RMSE %", + ] + elif table_type == "PerAtomRMSE": + table.field_names = [ + "config_type", + "RMSE E / meV / atom", + "RMSE F / meV / A", + "relative F RMSE %", + ] + elif table_type == "PerAtomRMSEstressvirials": + table.field_names = [ + "config_type", + "RMSE E / meV / atom", + "RMSE F / meV / A", + "relative F RMSE %", + "RMSE Stress (Virials) / meV / A (A^3)", + ] + elif table_type == "PerAtomMAEstressvirials": + table.field_names = [ + "config_type", + "MAE E / meV / atom", + "MAE F / meV / A", + "relative F MAE %", + "MAE Stress (Virials) / meV / A (A^3)", + ] + elif table_type == "TotalMAE": + table.field_names = [ + "config_type", + "MAE E / meV", + "MAE F / meV / A", + "relative F MAE %", + ] + elif table_type == "PerAtomMAE": + table.field_names = [ + "config_type", + "MAE E / meV / atom", + "MAE F / meV / A", + "relative F MAE %", + ] + elif table_type == "DipoleRMSE": + table.field_names = [ + "config_type", + "RMSE MU / mDebye / atom", + "relative MU RMSE %", + ] + elif table_type == "DipoleMAE": + table.field_names = [ + "config_type", + "MAE MU / mDebye / atom", + "relative MU MAE %", + ] + elif table_type == "EnergyDipoleRMSE": + table.field_names = [ + "config_type", + "RMSE E / meV / atom", + "RMSE F / meV / A", + "rel F RMSE %", + "RMSE MU / mDebye / atom", + "rel MU RMSE %", + ] + + for name in sorted(all_data_loaders, key=custom_key): + if any(skip_head in name for skip_head in skip_heads): + logging.info(f"Skipping evaluation of {name} (in skip_heads list)") + continue + data_loader = all_data_loaders[name] + logging.info(f"Evaluating {name} ...") + _, metrics = evaluate( + model, + loss_fn=loss_fn, + data_loader=data_loader, + output_args=output_args, + device=device, + ) + if distributed: + torch.distributed.barrier() + + del data_loader + torch.cuda.empty_cache() + if log_wandb: + wandb_log_dict = { + name + + "_final_rmse_e_per_atom": metrics["rmse_e_per_atom"] + * 1e3, # meV / atom + name + "_final_rmse_f": metrics["rmse_f"] * 1e3, # meV / A + name + "_final_rel_rmse_f": metrics["rel_rmse_f"], + } + wandb.log(wandb_log_dict) + if table_type == "TotalRMSE": + table.add_row( + [ + name, + f"{metrics['rmse_e'] * 1000:8.1f}", + f"{metrics['rmse_f'] * 1000:8.1f}", + f"{metrics['rel_rmse_f']:8.2f}", + ] + ) + elif table_type == "PerAtomRMSE": + table.add_row( + [ + name, + f"{metrics['rmse_e_per_atom'] * 1000:8.1f}", + f"{metrics['rmse_f'] * 1000:8.1f}", + f"{metrics['rel_rmse_f']:8.2f}", + ] + ) + elif ( + table_type == "PerAtomRMSEstressvirials" + and metrics["rmse_stress"] is not None + ): + table.add_row( + [ + name, + f"{metrics['rmse_e_per_atom'] * 1000:8.1f}", + f"{metrics['rmse_f'] * 1000:8.1f}", + f"{metrics['rel_rmse_f']:8.2f}", + f"{metrics['rmse_stress'] * 1000:8.1f}", + ] + ) + elif ( + table_type == "PerAtomRMSEstressvirials" + and metrics["rmse_virials"] is not None + ): + table.add_row( + [ + name, + f"{metrics['rmse_e_per_atom'] * 1000:8.1f}", + f"{metrics['rmse_f'] * 1000:8.1f}", + f"{metrics['rel_rmse_f']:8.2f}", + f"{metrics['rmse_virials'] * 1000:8.1f}", + ] + ) + elif ( + table_type == "PerAtomMAEstressvirials" + and metrics["mae_stress"] is not None + ): + table.add_row( + [ + name, + f"{metrics['mae_e_per_atom'] * 1000:8.1f}", + f"{metrics['mae_f'] * 1000:8.1f}", + f"{metrics['rel_mae_f']:8.2f}", + f"{metrics['mae_stress'] * 1000:8.1f}", + ] + ) + elif ( + table_type == "PerAtomMAEstressvirials" + and metrics["mae_virials"] is not None + ): + table.add_row( + [ + name, + f"{metrics['mae_e_per_atom'] * 1000:8.1f}", + f"{metrics['mae_f'] * 1000:8.1f}", + f"{metrics['rel_mae_f']:8.2f}", + f"{metrics['mae_virials'] * 1000:8.1f}", + ] + ) + elif table_type == "TotalMAE": + table.add_row( + [ + name, + f"{metrics['mae_e'] * 1000:8.1f}", + f"{metrics['mae_f'] * 1000:8.1f}", + f"{metrics['rel_mae_f']:8.2f}", + ] + ) + elif table_type == "PerAtomMAE": + table.add_row( + [ + name, + f"{metrics['mae_e_per_atom'] * 1000:8.1f}", + f"{metrics['mae_f'] * 1000:8.1f}", + f"{metrics['rel_mae_f']:8.2f}", + ] + ) + elif table_type == "DipoleRMSE": + table.add_row( + [ + name, + f"{metrics['rmse_mu_per_atom'] * 1000:8.2f}", + f"{metrics['rel_rmse_mu']:8.1f}", + ] + ) + elif table_type == "DipoleMAE": + table.add_row( + [ + name, + f"{metrics['mae_mu_per_atom'] * 1000:8.2f}", + f"{metrics['rel_mae_mu']:8.1f}", + ] + ) + elif table_type == "EnergyDipoleRMSE": + table.add_row( + [ + name, + f"{metrics['rmse_e_per_atom'] * 1000:8.1f}", + f"{metrics['rmse_f'] * 1000:8.1f}", + f"{metrics['rel_rmse_f']:8.1f}", + f"{metrics['rmse_mu_per_atom'] * 1000:8.1f}", + f"{metrics['rel_rmse_mu']:8.1f}", + ] + ) + return table diff --git a/mace-bench/3rdparty/mace/mace/tools/torch_geometric/__init__.py b/mace-bench/3rdparty/mace/mace/tools/torch_geometric/__init__.py index 329f8dd..486f0d0 100644 --- a/mace-bench/3rdparty/mace/mace/tools/torch_geometric/__init__.py +++ b/mace-bench/3rdparty/mace/mace/tools/torch_geometric/__init__.py @@ -1,7 +1,7 @@ -from .batch import Batch -from .data import Data -from .dataloader import DataLoader -from .dataset import Dataset -from .seed import seed_everything - -__all__ = ["Batch", "Data", "Dataset", "DataLoader", "seed_everything"] +from .batch import Batch +from .data import Data +from .dataloader import DataLoader +from .dataset import Dataset +from .seed import seed_everything + +__all__ = ["Batch", "Data", "Dataset", "DataLoader", "seed_everything"] diff --git a/mace-bench/3rdparty/mace/mace/tools/torch_geometric/__pycache__/__init__.cpython-310.pyc b/mace-bench/3rdparty/mace/mace/tools/torch_geometric/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 735bf63e361ff3a6ac40b3983fa5225113bd74ef..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 438 zcmYk2u};G<5Qgm}Zrad@=g0sVSP??Nz<>}^w=9+$TS+QAjvRwj9*3<1BQIk{CSHMw zJC^}l{_nrj>ARC{Sr!D&`^TaGg!qNY|C2y+4R5{&kwj8W3~eYypXkJ-O&WcwGs7Ad zeWn*CZ}RAwUYb?2qNF2o@kX0M<{MHizbW7<8rY5SZ4V~96J9_pFg{ozm4htDY*rrZ z>8zCG>P5M|?}xFgn8V#%io#Gblm*CRbrRSX!LcP|5-71T6litC$NrFG0VYI5$Z$wx zgq(AsHRlK7N(7_)baiiys*QN75fj`gh?+rGYVauUw( vW9=y3T0Pg^y0+(?0-|?ATXQ~)gXjEw+K1((vGQ4~EBvQ<4YDR{`gfCG5%Y04 diff --git a/mace-bench/3rdparty/mace/mace/tools/torch_geometric/__pycache__/__init__.cpython-313.pyc b/mace-bench/3rdparty/mace/mace/tools/torch_geometric/__pycache__/__init__.cpython-313.pyc deleted file mode 100644 index b98d9d71e3b8a1c87147e2c555a75e1c6fe9cbd1..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 441 zcmXw#F-rq66vva^rD;z}aS;WF>UM+P;3SBH2+ExZ(p3oQUFco8HsxIu>gI=VcGa&G zxkV>8pSAiyg*17I5MD|nu|}@mXjv%0 zj=Qc=265OZRS5N{U}ef`F^w4n=z!2g*9CLIF>d26%O+0jM{HQ2#ShxZVeR#P-g0+= z8@$4XZi#nTE`+RWA>PRvmBijBrJk!~E}u%NXO%?C)zK34Oyy?NOw=MTlM^VG63oiq zZNg|>H;K`Zmtv|=W1z|tA*w}Xgg9QVZ0~HYi-%gB!!|B+pVRDN`~;NJmh84B0gXLR+>4YCCl2Svf&z$W54sx{L7CY9sclnP5U`jRzC(Re~e$UheBwAh1!sH z86%p9e5iMIMeCu_Eul>h&8~@mBQ%HQZW(eKo zE2nz5CiL4{+kT8yv^LAF+y2qmk9t06?@nU0q!Gx^}>XE{unjuHl#ZT3F4ja7y$yd2ZDsIX7p;c|_%^?; z8CrK8JKe}juXt(i(A&dqUUo*I7er1R3TJej9>yT|2YxSghOzL&AnH4wvrogF_`vB) zZ*-VAUgE@&@9YPC+rj6W2$EhLMVRG_b}^{qd@V>1otwSTOOl&=X)Jq(Zr_iGeky}r zN8tQAbclP-zE6v!hrZ&q?oR{kgC-3lpRd5tCGoNl52z;L(!UNfPv5IE?mi9f&X#im0bapD;LAd=-JZUp6((p>>neUp ziXzogmT9Mazy%MqO#dhLEhu=VPxUkHRqd_jn|xwSN|_Ff+7)6RlKend`RA+(`WJK(1jK*`&41JNG=Kd7?Ya0 zt|X80-{Qt0^z6paJ2~{?LEwSv`M)=kG3oVF`}2Q>GF+I?qqENfV51Ha@licQw!7PFYmO>VN~*JdY;SIBxg^Hc57{m(z!AV?Q( z;g@vKqx69Zc1auXQwH78PuaG1$|i>3nQ@oNi$bUWWb4b?+oU-fw4)fWPr1^W&54XF-qGdfP%jH(Sr-O0$b|kplNuSEPF8{r6)wVUsOS_rIoCa+NWH8FKY~{X>Cx4^xs44TG~MCcd|9ufiAzBHL|rW%#EsQMk8H2 z;OPIydAq1hnpu6?%x0^c@yYsv_fN9Mbo~tB1QzuZ>|;G!&ulu4#fYC~>%tgo@~48K zw=r!AcE%uSV{lqa?kO)kRZD{UEDihxNhJv_@2-no#SZ?RO>@Gt|Zm z@;*{>T|6%r?xZyM#Nxb50|zupI;repan?s6bR45jq2FrfKN3d7sRf$;3DxJIucrF` zZPw?Y&%v#g+-GnLCHGlqNXh+VRz>+-VJVYM@ZdP<&qJ5hy-&8XEwWrlwW)NE?iboc zx`ZA4O0bA!>q||1hf9vz{#3RF9fCxpgE-~?Q2I5~X^dG{8nJZ-%bSnLO4H}h$dj*o zUq%eSH4kiv%=1`X#P9oS#z#d?XEka% z0m8NPJ&Lt}6clO?!s8`}muJBmA^!OKY$V}|MKo4!b{3Np=OC6v?1kHRf@mqc3e*R2 z7{>QpHDzzP#@7Pi2mgLA93x(P(XoTciPJ}{j=*Ck?yTRXzP*kle76iYyN@~QId|R1 z?>~9+q({B;5A~2Mzm4B}6bBH#$z3w2ht_A`;Ya+o_J|c6DJ(tWIq!U0UIMAYR8c&Q zYHsg{pqKXI;pnB8P=AkjZ)hjiipW8iSVUO5z|Kw>_ke<4>MSSyR>9;mwA!;&kO=aV z3rbaKarr3}e?$Z@xb_&^d?E?E=FH8_;!N8-FNqW&H#dcU7m-~yFra-kH%OpdPy94D z=l}^j&`uD__0W&xCb7Lj#X51U_=xVlgi_>o0FaiX9$>blQg75f` zmB_}yI1C|v9QE_^p_h0mzzl6lR;XCu4cVgV1uE!@^4iuByK$B_EAIIaQy^W?Y z$twy4DgL!pmCS3aN+vtRNUfk$K}l~GfGIZ%u#{VKpww+Fp_k&!g502WB%uzhJ08lL zM3rYV=-7HW8{~RA9I4&pmh=yl{>VBlt!{pC{dwcOdXW+&|BB*rWrLZ##T=j-{=nun z)?hZi25*2e@NJ>TVl{nJQT#yzHA8Q*P2B+1LTTfl=9sL(FM~r9^krbA8ruMU8ND0K zdcyz5E*G@%gg>YmSO+M|W}B>rAAfL#Us)h33+-xEtXXh@EDx;G$;L<73Jycu&rRv4 zV;S|xs^%ZM<}!W?c)ba`&a?p_0fK%dWdhKPz!d?i35PxFcQm<)5qbtt1Sig5{s z0!%1p=C*dGgU_?eQ5f-&b7(vPy@`P^g1k={fH?^O9ae(mhLX`2>Q-@Lv| zzY-M=K6c!3ycF&lY4!)aKCX7pQ2~ykOA#-MNVNN?Zi|u8iRd^)j-+mvHlt0eR$iv`!^3&9!9N-@i^;s&&3o6&ruE;H-|D1}iP@!BJ zxhDCCD3DjeQz1gYC(B)G{ebhS_WjiL?|VZy>3=|fvV+1lIpf@T-zsmyZ&>7abbRsO zg1>0MC)L$I2!yH8<%DTCIpr99QYp$%X& zq{&EWSQM(>5_+G>*Z>Lkz=Ab_o>tDZsSWSIqV8hM=VxQ87z3Kd3;?$U|K{rWSJe3W zfD)S5)QtLI4KqNmaJ{iI%AbuQzM$2p8WI>-^vaA)H3Ohpjuf?i%99Z+wy8?ZJ{7w=MW6-T5>7PgL;_mm9n;g10uT;GTnAu|BD~{C-?{6Bfgt;) z2LtC?A8LN}gkNoc2-1)~lODcDpaLnW_EveqzM)NVe`oC53T-Lv0_1uVBOY=29RB2M zM6*IomqE|oA>&hk>lYZH0N18=@_eB@%GIgdtpd3IEIbcxc>enyg_0$HUpx78b? zGTq!zC=4ho@`k?aVv|ezQ@itML4_%6tGoDjG$(W(HY{ip5qy(hX50AR#{UL~z3H4; z+`8Glf(aEk`swrs63o7MCoKQ2ix-2{qXO z90Wh?g^O6>eNNfn6f|-V%0^@M#^6?z0=K`>Fc2%y`E*Ghidg3fmk`xvh z0_l9N554=j{TE(1o(1sbK`f7#PUzRP7s_I6-Gr(WUT;xMY!>#sjQDN7yW4Kbh`8gm z7BBtd5blItXphETSU49-qC|`&O1*UERrSPLJjKg5Ftad7dI2u%kl?f=qbV$stWPE= zatkl7cwJ79)iu&{vSQWq!*KRuTd0mQUzL*5J2XA7Cgkfa&0vi(Z#7v(gRk2jExx8ExqrwKm^jHJZL=|J&bORl4E$SUDfaD}AL%Y?>|!%DRS jTDy=(9q!_Ub8DwKr!mgrctOv;j_NK61jrbM;S5?GRu!3Dhw zNEW9}{nM0c)Cn2K73gFlC{b%D(M+guJCeIW%=D*sNbG%_d+xc9^St+w-EN~GT|bzf{b?6P{SpgWGUYRmrXlkV<){4g z3FZcFaE=mbKP4Lcv}p7jM8oD~OwdYB8;tN5&l zPpm9zH~4L&-EqzdwH6qm9Lt3fs{9qAD?q(Ki|$U!Q>mC<;^Lv%&no~o_f?qeVRk;s zMFch$4YTu?(bads{i=A#i|hUM5P?8C;U*%^_WpOsinVq*f&PDN&HJ>iHH ziU|Vr;KLpjV1Rus5}##e0C3rlQ>6QE9t0XUnG_%J)mi~NA?6{rn|Bk@Q~=wl^76y-Gn zu?GN)Jdh^E#RL+^UW{-!6i_ckW`sFjh_e@g?3id9hA9u9C$l`?$6lI^z-;HZd5PsY z$O*6zY&=Ft5_t)ZLOgOYKNK4Y1D?oqgck>Fq5N?pyNkP54*!q7fS*J3JJckp#-B-p zmJLvU0!#u2*07rj7{}=SGz(u}nbxj^w=ArcP;6;mY6&q<0s7%V_zE!h&qI3E{9 zToa`-7{t>PBnJfmY{8%ujf8kf&4&_m5)2kpDuTff7Y~LbbJ*f4#u^2NhdRK7BdG^E z81$4WjKoK$MOZpYtj4j-uV3QC8A&m|&Rs%ra!jyKWJ#Z-n`QMgB@31SscwdTuZNK9BsLs zo==_a51QE}|J&AwR;saQqk^h!x@yb0Yp+_?ceF1u_bpAi#vO~yRfqOzX0f&**WQKS z)*M28VrhB=BU7QmVS@s0KOX%@7C07g&{2kxG+AJX0PUyo{P~Le<)QZF;ftV+T3tV! z6D=<=enWtH9-uimzJ8;c4+AUIQ9K<0Hqw|xFcrot9Q6EPNJD`U^g694uvZI33Fd&= z&-hIYHOllg@qx+6ru=(*baJ zvPc5Qi zvRP|?ViZ`FfCB_4_JY$m@t?TDLgjdCQU6N4|Ktu0(~#i_Tt)r;R>2vt``R>U{W~!i zFylICy@jfhk)FODuoAEy9Xcj0r=8Zk_=k0b-LDE%P5eEcgl%F6yn(6^3w>$mTj|@W zVb{O@!j21-0MVtlXgQ&5eD5~L1HJ+r9<8VT4OD4d0xS5GQ6^5~3i^8WURo|tsr3dW z_syq5In_4xMBhF&jj?(-ChuqC9*KY^N3-al$QIg{T3U2AVipSGs@N%u43}s0$$7rXP*Q^eq7H9^& zwK)NYUm+TtNLryTP^VgKuwS)XYBgOxg60V!-SjvesM@x6YbrTa$2Zf1Zc(5PbOT_< zl(&3dEs=8D)YHU%d!uB$ttE1uwj24;zNMZGSh15T6~(`rjO+w!i-mu|JYAIYRC4V! zMU83>D0t19AFcY=(_(B6>hexEx+Ux^c;Vp3+P2|2t4_K!V0(2dkCrX=XxRw3&?3*F z4-D=-y1>)X%LBH-ZAk)@;`8Pw+1o+IHaG3>QiJbv|S%5xyvhwXgY&$<7;s)z}*Qw93tQdm9&;RGk zFTea8{=XE_idX2MVpa$87!*@D4nY~k9Of^A|0lWtz|*ZTxQvQX;^PW~*wMobM?&cE z8Kb-)qV*^uYl=09(aM7l$cwNtiY)>W1S!rzSV%8JYf?Yw05CNb>w24J+;7Vpd&h#iF^>N~P|g6U=l( zVQ?G~?PM4?F(;x%Qp)l@aM?l%eL*qC=jO?56pP4D6F*)=-$KHPsIK&rzaopcDv?)I z&!3E0tvbY|wP>A6AA@9ji+b#$oYmhQO&susx_e}IPuATpyZb+NdsTmW z@6zGLOUcf^`HR$GN=!+q1L^oD9fP@+wk6ZswudIFdiP@s#df3)e(xKLru&xWTx(Cd zYUR(et^4HGeQJ%nF8N&2bKl(o7#o@vk9_O+LnBq)`JldGNm?{5@>ePzRMjk+|D|hJ zrmFXc;iQzLm!$9X+z2n4*Q$D-Qk`?xELsp<<$4{PYRK00UOBpWc)hM^$(XBaTE4u* zY?!IKmgUz|?aQ$xROt3#w#h3u zc~=MTw!MEK*V?hNkYv7CuW3y&H>KN0QqgS9pj8@M7t5dhUa>v2@&f#2b{v4hFn!&enP5I`4Y- z?$zoKn(j2c-|}hq$jZ^=;bc6wvnO>VU6tOG?ns|rYadFE=DfRC5C7!iE!W+BKXcz3 zP98~htQ^lZx8Wz-+%Grxud_Sfv;EMP9!o#BdicY@M`!PyUHjT;nSJfIMnlVv$7NK{ zi}Xe%)xI-jN}Wy$Nn@_JZqw_2cgzPcx}Y ztMrF8A2r@<{Kd<1v;TKSLqp?3j;89HAJ5X1vu-J#3}zkuvZFue+>v$el$|?M;cV9d zx$D3uFloK+zO(m1ee?ComCD6ea*q1@j?VSE=4%VzSx6mB0X2m&>ilq6oHf_V^$of2 z?7=;}Q->vzm>0Z;K>1z3Ub8E`0KD_T0Cy4^)N+26ES58~*MvFlKeb^6|wZP)=4YA6OjA)$YjF?viWaWM^vkEyA&E zUm02ER?8NTt~(l%oj1){N4M}Z96-I}HC_x~jsC|BR_m8xC|BSk+q*aO4N z&ui>~5e9^6_)Wzu^6`WypzT#iWsoz%|D&_;`woP0eT61H6M{IM9&OXTe2k(!D>^0^ zz=|{sgI?ks;E_;1)XSSvfJCgo5~Kyhwy711zJjH#H)^&NM3U%Z;Pwc{2_vdufdL#C zE8dg>q%{N}6Ngo}@f(wKKQ z2cZjMaz4~@o6^MTiHjt*iF*Y}V8#zW8z;f`M~pl1a-fU>yhX8#g3sqS5?=#QVvG?n z!+tSjtqaj8LKgWcnYm?~tIWa))Az$NiFOM~PwWSrp@0QWl@^aXhU)_B(OT4>(fjM3 z=oV2I(Ebzo{u3_+NL&;mBVdnVGyv8Y?Y~V)-O(wS_p{mqhG?bMF<^k3G-ly}PrqG~ zh*iC9vz_JQU=72?V%~@bW zH)g}=Z?Ar>zIL{pDN zBTS`NrEH21ad0t5u$0hxiId>^EM6cof!B2G`IpFjMk$#+9?ca%q2h56OINAzCBUkk ziLc=g z#E2-7*uH4PiU;6FF@Y~2f-3>WyEqON5(=$?eX5$JIrkTAyzA_G_=H07qxq08x5dgXFUGb(n%t zwEFd1m*mFh01^zW=U1LjNvnf*p3gLmERVru()RnNLdlV(vGp!by7Nbo)jdBFl0zx_ z-RJIh62lSZdD3^E3G5Cd zH)$|lL1R5J9z@duv~P>sBwY*oocf4bj_3fE^hN)2Q_?iD4?X2gNoylierKChp4kSK z1udZuNT`G`H_Y%^RH7riDfz)xIr!{ee`s!_N-5q%$n5}_KBSV>PsMz z?Md1U83y4!-affgYxgsgc%4C#ejDo3OiQf^zHsAU8Z>I$40bgQH?s!G;DrbwL|S2m zV(gX?Z-dt7Ylf3BTeMpsvIrJOKE`O6rab0TFT>L$3nv9hFo^`v782t^6ytJneuft% zco0M0MX){H!`+@fqOto}0uF5?sx_d7pdd;YW`+3yAq;AhyNs`VNNA14SQHB}Ka72` zoXE2mxo9Meb_+hZV0&gjuXisPx;>s+5w$O)|wU5Q~Sgy=*!P zAufo>jbIgCDJW$^On~57Ofd-aP*BN%hap0I|B#467~*5Mq1+ zH^&EqVi2G_#v;vr5-KF@3z}PG;j2^-yDJE9J>YGI;!>4%aos;*m%k!fd%e77DfIn) z+445Iylvf8``twH)$75PV5Y7=9g^$zW?X}RJ^CwqUDEjd(^-3`Z0}6Ff7)=X;UD-9 zF5J10@w}LG)m-ac?tQyI>*@mKmT~oL8@^k1b+6W}xw_X}&t+XBvTG#cI-GSKlU>I$ zuHzaS-}SRAXS4Pm+1``(|MbkQGat5p)N`*V)BncX9Yw zZj)VoS=W&48p^o#XI(GIt`{<{5e?Nu*4`!CyKXPs?f6j2dPimNXs*HuMu)p)!(eb% z|Ejt^>3;i_Y;~7h-L>ZKPKg^PLsjo1BZGzyW3PNzK@IH5IQIP5pQiIpf7PeXj$3}& zv1iS)=l35w_DHCMf6?A}Y=kLM&8T3BYW{b~koQ%vXKH$;at79$1ZA$KNkmQ(yu^YQ zovAxZMMOdQP8QUxuiy}9IN-V7lTQS|`2)qjg$r0b7C^x0n^Q4_hd$-85k3EzJf>1OVo%lWpH~?j-2sc{5PI*y4m_ihws)REz2%hJRw+Ls5Z(y&y6wX3JaU^44{C{r7893(rdccm&x;tfeXUd;(d)6vEs@+xCJ{YTDcg3MlVs}x#;%Eo7V$x(L z$XL2W#yYo(X6VaWB4g=-jHN-w(hHwF;olLr2jt_n{XEI`|IVp{LlA})QzbKzmj+c{ z5>bh3Tac9TSW(D5FDZ-bHUL;cFADCGGkNZbUI9l!>l^tVi zmN8)IF$wuE*|?7h74JD#5{CY7$UH4_H3jSk%mNUbQzzDMaWCJf0>%e1w#`YyuANsuiIstv_+ag6__c-Qg=^nf{zj&*8-zg`LhF5N6@9;QHLq1P6EOge zIq~P~U|izwt5b7f5p8_+0t|&ExexOgPMz{pi*r!M$Dr_LIS~a*03U$9k>H}LU7$FE z!Rd%7#o;+KFT|i3uH(WnxQGWcgjC)6}uiNtvs1L;HW`dqmCCruIihC~p`ryAd|-F;yp}4GKP~ z_+$KROw)~~y(#Mk1)p?yb?DZG$C&-i0kdh>hSOv^LMOwip%tjbJOm%il^X`DsU&i2@{P(xwfHAgEaMA$lkf6iAUg^r?L;(6=C%*P;)8Hc%8sV)gs}Gc&tG zE@h<+nuk)_GiT?V^Pm6x=f9t`-MxEj2LAs3XE&FBamq0MonJ=(9DF&2PxL$rWhk?4 zT+~nVqAB04i)l@HDlhij?!t=Sirtpial4)Ny1U?`ebEm?U%3nG zZV*TAwGF&dLCcS3U9Y3uWxvz%-BoYh4gILsjtQ{W@vp7{5J!1&Yx&rfAo2lh*}EKc z!&%pFEH!4?vE|`?yzGa*>t78bfDfW>+v7mNuHl|*`|Q)|cA`K5Zvbuwtsn++ksrrF zXNlu?tZK0v;(ZVmmW35~Q4}n7I4X$UxGT*Kkff=nYuO}L`d7k}yTNzSFHtCSV9glc zGGgO~IWV`(aAsiMu(k|kU9zNZZy6hQXD)AlICkO^`joefO>@&k+o`xRFyqpu70pVm zx@m>apx49z?Mu2(Lt1KRdG(vdqA9yu1(ZEo#tVS6;c%C4Ob(na^MIi$2aM|mp49`K zds@ai82c3km_IVq#E)&fcVdS#Y}nV09~lFDRBd4Y$oR1X4%svJzUhWw52M8*`{S}J z_}qg}`E)-+I#mAoE=d-ENx6L=^P>X6UblH|@k~ zJeio667w<|YG56|6E~wcylK4OI=FnA3RfH}69^ zWmc_ymh^0)Q{j(0dkUZE6!zOUs8*m-%!~GfHTqATP8Eyem$YKpG%uE+ZmRDXP%aay zhG$jnQImL1sJ&_m&zibRP2;)8pHz3NeeZx_FYZ{e)M`oJu3ABs6VD2m-<1} zkEpuT521ckJt6hOYE~Ts^m|o99mn%N^`!b3p7*P#)YEuApq^15$MZq;y!r$6NhnlT zeM+6cn}^h=)w6g$tWK)u@cgLyj5>wqBZ^#qCUIT>uX!i2E2rJHwgCaL zp=`02ORndFE0)^6yA*nB%kGsRUUn~gVbJYGH5Y$Jx@&RJ?Rf3F>&4*YgMGu<&L^^n`+wte?`8Z!K`AzaH@Oh=%SsH`WSy?J{cn5+VTn=KG zl8$@k3tw=dr-5S*i#P~H3dlY+D~NHTJ4T4K8%v)T0cCgk6fhZ~mCQXvaD`goJjQr#YyHr=~6`&bZfQg2w-3hxo2y=Q=?;c3zuC463WQQl?NoDand?fqF_NMtYb5N09=LVI$*Q@B|@cXcr zGS0(Khe0UKB@>wgy$|i!gNLo# z`qq4-e|iXv4ER3)124wBn$(7~f02jsUG!^hDa9$`^uag@aqEM{38Ly?p(U4ib4!iV^V z{*U%`#NnZ@lT@29E&alSp6wk zSl*%FzSEWGN@ZnJs3=oe?_hl|z$F8h+H-~OHhNIp&f^p9M-kB5uV`JcL^ix`tQZ?+jGzsnoVjITqb&)sEM>og5R0VwE=AaZ2ovlDdhG?6Tm|VQ zQ4!f+qN4JKc^wMsP#Osh(Y|SoQ0dd?uA9V*=cZi4M<_F6^dHL>((zV(B>}?-ee&3( zZYK_cg(%t6>b8L$SqvOC98|Y8?LNclPUdsU>>^ToL9Ai-t3;uUqHIngcV(rS@v_UzmfAlOp83yGF6cpu#dT!rRClV`|Z_L%rqyoo>ji8k5xF~{c z+rFoXz%FLM)NF1LUE(ll^``~Zk;X#&p0TtbI-Mgrur_TI(jdq-&S%JuLy-DedqbD0 z+){**cMybAG85sXBprtPE^%4{ywO?X{dR`;^bqg7aR-5kyWq{BPZM}kYo6H`{I(xU z*5tzpO!^}3AiCiTfHVd{h9yGSucQ4v61Fx(FmK!;14NLYh)P{#i0Y+E3A7JVR;gJj6|N zPOtMhbPUG`-M5K1S9%}v0K!6)r*X2j0}0v-QIRx6g2W6L#(dn@twYQN*&z_Y!i~}I zX@bO=3rVxYYBt~f;MT2MBWM;mKBR8=4nBhBlMpM~L8{1RGo)N+(jqc0V{`GA_HS`4 zh5Dd9WjbbGv{{48EG(&DBr;Z8L81sdW%iF{J4IoJ^-uGL+_yFOu=Ht0Hmq)i_hUkt zeuV0RoT!~jvx)Rav-v&tdId$Qy3^M?WPm%A%K*K<72`6Mv)MnFA*O}%LI~MJJ`rxz zE}8~u#T;csp6Hgx8B|g<&qB#@GBVk0B65x45J%1C4}g>GNzLxxoob(9Yk{^zpFG>* zkVjCwErLwy!Zn>B+fZ*FG8U~b8m94PIaNvi%dsDdZ_&L>M}I7}89!$#cGg`F{I*I{ zp+T50E#I6}%$Tkvve>xB+)XPfBZGw?WcQ`h_X0A_=3jFSbUm!MrtE3c8S$y4y|8ey zI2XReLC>Ig8_SHU)M8wG2MY%Q_R{zcu%qDtJ_nGVln}L)5dXH55c!ItR4*STOV5v| zMotG1DgYs&1SOj?h8lU-mlwAI))CuMaM>rp!A3|gE_W3ob?gg?;~}8|7e3vp9MBj6 znF}v~7zOJ*&e(zINJqAt&A%c2C_dI(`%6`1^sKj>4{~Q#_t|0asGvk>4x$npLtkgp z3~S6+DOzdA`#eNHlOkcHEWBq*#=Wu}w3Z8|k2*T1V7^K%f3)SAAh6Qpuol?2BU$__ zI393Y^p(k;av0!84IQj9I&Cx<5f=hhzP>O*Y0#V~1>oBdX)(8lXqkr-`uHJQ?)H#u z%T6n}4`HB?WX^FUyi_&F{38f57lL#%B6%J>pdk!yqprl z`<`vhA6MsH4|(!bTMwm~U|#OH5l)ljkTP!qf9;r8AvR04O-6!0S!DC{M368|@6t)0v?1}8Gwu0$)FMHZP>O2pAe zacXSaqCe^Ow4p=RPVc$rB7Y?Z%JH(t)b5+uEyx6>XY{DL;GyA4!Una#<1F$#SW*(R zl-Ni_i|$BDs5RqvsYmAO<)i{C)8OdA)bHUNEaq8wD3FR7hR$PhXrg=m1sx+E8rJSE zBh2TS;iP<=N}5N?m;T|Nc++5}GyFWj!FA@M;Qb2@qUOZsmx7MJO#dIOyn^AOO2#)C?O#8L->185nA$nC%VyfF)zzHCy z^;X5?!eaJ10eDB!^YD-~7Op#zKAk5Cx!-f}A#*lJ>`td2Dz+8Lu9MI*HLrj~4Da?; z_T!v*k}ieL9dgg5OJQ5cc?=%K$S`z)A;88UQ$EcaHp75ffI(R?XiPwGeCl@p$ap%- zSxIIQ5brNH<|O%xWgG+W;abLqjgm_j5dVqJ;uY2G{lz6n^NuCJRLK(Fn;kB}2rKl1 z^fGXD!^XkDG<{-u1Ix+fAWQM_U6!H`RuN4iG!ui824b3h=0t$Q1$iQ@Dwo)+@&%6QmP%_Mn;v~&fMXt-dfdSVcayp$P z@_rMyZJJGQZLPiDY(fdsHw-6GAij4Nd+_Z=E$193nWdR*4-Xpshtn)Z3gSbvV5A5q z5%eBFMGgw~WCsO;B(4W-6%Gtm(J$np!Y+#;HYf0PKd^a^tA5~i*yQH}^@fqfp$-qoo4JSxbxIR-e4-Su2a;g^0A>(U) z&f*h2hJvY!@A9ybH=}q|WMG$Y*=mbP4IUZcjfmn;Ph*n;7#ym7?u;D8YoCx`=jksn z+oO+KhvAgI7ea%eJL{ZB!@XUuVm9ro0mI~QdjmOe&P~hs%+iSI;jYz7nt_Cdk|TQv zZ^xe0m4t;3lc?P4>C!$b>6-ADG!>KsH=MjgLcT{~rq0r5R7)`Se`~&K$We^l|M`oZ z2$xxS*2e5ru++m{8RTJBySO6J_TXvpq*Q0h^1wXvwo_u|i@rV4#*;)s1YCXq67)KA zSh&sz1HLO##S}KkBIj|;9>2#_8X0A4=tU*9;o@-m^*DNHp3Q|NsXp8GRu`1_+~WWy z5>qw$4+(vTV@Cy8_Xf0((EeyA$dCtVAI+fcclvj2gI3^uDnx$n!bh&2amz*(A2!Y! z*|=Yv^FwURDp+3bDd}qi3=fWDyOe(!Q4y(!u!LWifGaAM&E|1x?oxvJL}K8>#<3hxTO zHYg(G|9wXPkPtlX?%cL;9X*go5dFp5*g2=u zFltV}j>FyWz-g{qI5h~#ipgXLS92F|arP4Oy4Ve;)l;6{jz2~FHirLn#+y?{|6WZ^ zbl+u13M~H?fr;4>H7Kx!<=;lbCkU2djpIz3)-B7#T@QWGMQjSpkDg+-&)}JF)(xyn z--%eUZs6`VkN?H}MF%o;H&7T@Te z%r_Dr67dqED5T@I4BX1LZ{TLhmN{~{0tR;mBgMT*N-f;2@j+e8o>YklYub|%g5<^_ zq$hQn1fgvy8)fUhTjg8*2Q;+_WrW;AU&^8P*lOciJPsB=$$8JQc#y@IcqRNA8lKC@ zUbf1(W$b)_;~;J21=|AMN3DFObbfeeg0L3gx~Tenio#gEQv+U2?R@O2em(~ZY4xM6 z^Ki(>CZx7p)cja2^bOGZEFp;&(w2D$sX%5j9`%s0_b-9BI8(gZc7U}NGqzT298DjH z9cA7q@tc0XWyu9MW2G!-=ncFr=e--~ZRNf3KJSh64`9DS7`$qSDB6ymmv2>tpC#~n z_zFgsZs2kr`c$`!Ed))tN{4SUb`!OUyk^{}0g7|U374fme(HepQI4EM&FVPHaqxy0 zv(U?kuJFhZCY6`4uiNTx_%+{@Orh2G^H(;vz=jcFB_Zejd{j4u~3)in8) z%*6GwR?ac~C;66PiXUW*#q?hbnEnbg)po&D)8*HIx>xAJ8YrjATcpHZ)IeEa`)k11 zGo-6-1AVidoZ{&3;l3GKgu-ckN<-k=OE%vD&ZNc#@6y&B%3UcG?rH{w^}Dt~c^**q z=1|hs9LjVG#UFvPg!=S0C;^~MkkE%zCruH zf>gg-Cx4G)XSGl8Bx=dzt2m>@X}w&vP09?g@GFmF>@>`YJ;t6TkkD@@_G-{U;a%0Y za3!+LPalw`WPp)Opp*U;2ddsxUEStd&Z@~{%<k5}owW2YD9EK$G6O+Spjx1apyocv8rUd80KYa6wC$V5fG95RU- zejNob(ecJ1F0kA9DMk1#Hn=FVnB(Kp3FmEc*A74I>p;AUDsX*~#~k^igxmvsdB*NLiKEmfVWV>$R%My-T$qB>@x;@Z;z36dvIwi`Q8&w7R%|^tTK=zHP2XOCWg(jY(Pm23o+3 eTMx2udg~Yl8PPO~Dt?2Bf0k9UZq+7c?)h)@uBP?? diff --git a/mace-bench/3rdparty/mace/mace/tools/torch_geometric/__pycache__/data.cpython-313.pyc b/mace-bench/3rdparty/mace/mace/tools/torch_geometric/__pycache__/data.cpython-313.pyc deleted file mode 100644 index a0ef969ca973a60fa900aefae0179abf01228386..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 22430 zcmdsfdr(_fn&-WGfFvFQgt5V3u6Y?@kd48P!3LWcKV#!G0x#Vp76J*7BP4RKa1352 zp5B>hoMg%{N!GY0Gsc}%8M?dHcxHCmPIYakd$+b`pM@-UAnI&wdb@YG@*jcB)OxCR zdw<`#4?Pf$lgUig?y)(#k8{3r&i6Xs`Mz`RT2_{YL%P*8H~io*$Nh$G)MQjH)=uE! z9nQ%)`R6!M{^doU-Rndh`_+qjr%p6D)13Mu?gB3ww{T^g(@?}YjW49_)yt5?^da7! z{u$kLxVQ5h*EQ@F0%IfnUQq}R3W1==nI&_mWQnc%KyXyQw1UokMuu@1-&W z!GPrRpx*%O_YL?$=$zyYg?xb_s=sbjgFz9`eNw`(FaozE`Gx{i%NG(tLH61Kl(^*P z1{BOO{1K(-e(hiB7(MuIT^(o?(D$;}ly%}waw_Ax`Mp3krE2+yb60VA+Nn}lc}hQf zm20Q2xrgK2d8}M|Tmj`$Ef+8x=g*qIMi0t4uGg%>?affH(5o!=+)Vd-6-=lj_H06X@g;NATFE8d1qnT~spVo=+@;|Y?Z1%#HD4{g4(eHk z$<6Bf*snS^8+XolE=|(!;v}7oI{`>z7{lnfZhdc|TDs?)Q?E)>9UeZfJA>Y4Z~_-d z4;TY!PJ`3fxm7J4`|VCwF-EnVX_R;C*>yVQ7{`u;ahTM)>&5|IaGJXLY~&4k)XRub zF@~Xlpqf;ODK7b@ z9~xZwoCGr#xRcVC+dEu|NS(~D>G?`*lf<-MZ0m)|R0&d$5;yXJf2rPuYJn6s~! zzFs<$zhvGPHgEeRJNNC1snCt_>G8KNNAk)e*%jAN#x~XdZpD(hG;A(?zj@B{mq(Y% zo5SVJRJvg*^j6c<&PevQq=L&WZ9n_YgYW$Gg-1E3K6v3?*TT1^ir)3j@pJCTw%w0% zcRzSx(iqM;wVZ96^4_>GePMP}q+mxld&gWwINPyM6wW?ye;}OQra~>3JGh*aJ5@1N zF!l1}OV{+j&Bo0ezDfP3*;^)Gm?@vxG~=Gt&x|c*R4y7SAFo-t+~;^H4aoV|bz6@X zrVZdqON#V!Yv+Ksceocgny)Zd@$x*ijf?;1te}~SW~~>}oW>$e2a4&=v?5M4In$Aw zohIZKrx|&M(}Fzng)C=AT4i<5 zwOX&W3#ToiOqMf0?tU|1Y>KSvYcw)@uxj%yP} z{cg;nLBB^B9S;o$aVbyWBbeGDJVSz8fQ%0Ly~2>_9vv25@r8zki*C^u9Fr{7!O@T} z7;yXTf;$9J=pPGtrDltzO&r1x3EU;2x{K+7T0tw@EOY?4Fz9uM#xQM^r5-0SDRF887Km$K5{9Qr5F1gLbb;`=!V_UtW@j>2E2WCR>$iZ^14VhpdY7$ zfuR(AInIU!>mUH}_g(abV59_uBWKSFumI3Et!*~)gbwbhWgVhMo~k2+qm=kK3J4qH zL+e2tURay4=ARZH<0+e$Q2={~Oc){L7KgmBqu6-eV!tmWx<%MZL>hssW4v#w2ZNYUHsrTs z^wRCQfS%WUV4hE z{g|p@vpAS#YtNteLuV>w`uv(L z-ZSc_NsXm$Q>SLLs?TCJI;=z(3Cv3Tez)1lGm9NP?3wjz{X7HvwP&{5yaUzGz>dt| z1g=}x{|xLuT_;%LfFy&=+HxJ^^zP zH24_Sm-4#_U7us3Ee4DP>T7_pXGQr-QA?~VQ8T+HnTqn`Q9Tq~G)>k)Q3Ew15>=vE zE|)eYmn&*yTrLtb!XgHjM=j^v{xR=SQ4EUq4A{qB|DcF9A14wYh}D$YDXXDuCuL-X zMAKe+1^QYNX~DrvVu@gqU@A!zsF=xxY0wu!%UL9;4)yj3UP}K7*;Q_(fXlPJnYLV1 z@uqdTu&?b^x7Bg@vjsh1+w?aP(+ zk8;XBEf%JnZ<#+gb4ArFnOtT;IAiUg3;c*(=#xW==R`>#X69<+kO=-n$p(FGPg>O5Kv> zlI^#?bMrehr{{E3u{N%vp4;BEn6Y&Y1L6iW8O9CCEy?j(2eNmt&MDJu#()khSWWiv zT1l~0-U+OFnano!Jlm~jatS2Ma#_LI`H5XeL>M!0g}4(`@SqqR!6U`6qBRW}OX9f0 z&NHU48>9LGzZA8^USa72^2AY-lnkS!Oh+@pbg|`8HY22V{}KvH4rEuk&(pXZ+Z$t( zhEEI2mkRCSLiY6i?aSIQhArE2>>Vq5R1Vk%l=CKYUVU3{?klL8!tqmp`YI z^tj5WjL+VIE~x*V_)Sg$zE@+9P6CsZ;0oihI9AGlMReqp0Ved4PFe1wo#5_nJrM|tdKC}4 zfPNpR-~rcIIP0L!XmDM$zP0P41}t8sXvRR$?Ghxs*g2ksjp`7M*? zkF8to86rh{|J*QVo7KPDbk7hj+Pi4ow`kh;_+imr*6I(l51JB0i3ZI0qX-x8aJ?Gu z^zcqzb#*{2bQ3kl$&MuFng`=&Y8uy?^`V&SKu4dW!_G(3VRq8Xd0mWUB%**i>zt5IDF# zbrTM!I)+sJrU^(kS1*$GR>lFhA{$2>>qm{OfRw=e7!QgjjYoU(d+3~$i45a0S*Hv$ zdw*%#E)PayUA#5mkC9^wE)oZ$vMO(p}a4HL>2~UAq1`YfYgQf{NR2}n_ zF{NiWPE>Vzv9JwbHt_oe5=n_zUtT_IOe8rvMCxr+=W>1ae}4D7-`RCENt0%Gxx{w7 zvB$9wdyv39?1`*c9HdaLp@uYw3=?@8Csf?Ja`VbkX??h~{!wW|#CqmoX~T^E{*im* zH}v7shDpPs^~}emruQJ;N!_U*wDbGzqo}iJ5h>S(aVu@!_ z=;nA<##>6N3Ch^EPDLiE5m-=4g8w>qN4K8FJ=%rbZ-UI(4>jx!_)e z@{ng6tYweSD+M@H86+*e2*f6jk=56J4S(t6A8=^|BQZ+ zwCRf{Ltc?tO2Lb?xJ*e!^1rB)My)XJ^ZGqXkj5ukG-4}~MFAulFBrkmF|UiJ!RL&{ z?gscGGZ;OHH-p*jO$=EgYx^@`>Wefkzi6i7mi?yvBU=T`+`LUQN2bT89FH=ERXvwq z`It?N->s_V6WE*(NXFq>2w6FX4XwtL;ZDi332QK7MV6k-_-A2y0Hs*#1#~?Eu@lox z&}ltFw~n-ihG1}e0)l_C9>E&MBxS0qA$`%88p{?S-Uq@!@iT8;3M^Yr4kiJTJq<&? zi(Voeco=^1;h+Z|B+!@e?S^9w%QS`LmK1{9A$l<^RpqtgTDE&g*fg^T)EfZH#8FSa zy@f*39%MwlzqN36^>6Q!>G#N1vEX@VYi4VeKf1hJP&li<*)$b+lv4p+mVe+e3k`g? zazFp0to?jK>`t4Xo(JG~rS(+|M^%-nT;K)=4C&yBHMmP~sp<#D?W%u^{Eosa!@hyx z1k-}%5REJrS)-8YBx4zN5sAS7WGldepRXXQDh$xFgk*=sP`#4FVsSXgEu?vGy5h(M z!co&c2$Kk(t}rHnW{Ba+QCev@I0i4V8q_6Xiarm5FBZ^8RcfrAxQ>vc50MqPt|%^x z7os>eO1>y+$Sp_+@%jdR1I!>{aI_IREPt0-%{;gNlorwU&oD*)j)kAdgeK=sR3~~N z?p`-sNFos_VSbFlxB@mXp~f4F+LJjdFCc@s=VP_`w(f@ckIgfVx$;PkV<}@-IAhnc zZPT(fXT9~PF~%^ee`AaR0|rWgU?-&t&Z)y}qId5|Vi;cl3#M;rkXWo9vmLAoBL>}AR-DO=Aj z>Esa^h%*h?-Cv=wM8XR;!TyM<_xtboe&Ack{!kBJ#?tnq;q6B!L)Txu25aH;U)LjB zKdoxZQ8}&^KL=pBq?kXNevsK9g;7+&d=^ffQ%@1&y$19N39|&oBQM9{)uX8)DS@e| z2T7?XPi3mhiPg#t%orD0MnEN5VxsXcWP1WiiFwEY1#lT=D6gP60TY8 zcLg_Ov1hm+0%fW;g=)bfN?v%zEJhYeP$xZRj3;PAQes^RCZ<+l-R2f(GWZmIH2J_) zJ{Afhk|5v7x=`8gQY+NZHkF!MXmcoIf~N*~)T#JVpWsFaOqmen=*r1}AR)5McD6{u zaPSo`R4sI{vVx#}GO7}mr(9)Q4Y?te_OXhQtxe;6WVrKa#?EU7YjP39r)BEP><{O8m56#v` zSc7m5=j1A5}tlv zom_?7-F_A;A)#X?Ep(?RIM$E-1k{Xsgc&ZdKgI$+hyV*7!c2_|dT4Chfz^S=%~5?I zFtMJI<-ideNgO+~`_F)D$%zd9@ocVW+l+tixsM9>e>ulC{TeLpo&6^t71pVR!Ln~P zKg=v4hIV{vwXJ6m%}AU?iyL{j`avNuLJ||uZ7SwK$rSTDqIoGHsB989Px5yCM0IM~ zv*@xqIl^Hz4DVDe`T>;FPYmzkdECakR9K{rRgXhByZ>7%Nj%7W51BdFo31s5GXxT8 zsD~S^)2*`=OJ#e)WqTI(-EaI^(}Si++0lsY*ka}}c$?Ovl;nl}*2t-`5lU%6KE$U+ zNh+kOGNNprgfx-0y+Qqy=#L>7=-s;kdgZMyaPGkA!R!FV#UYCsQ{&=g3=8jC`qW#eXWJtA zm6IkwA9+r$!TVeO0&leVD{)nb$lz-c0jR zbKk4$`q)RVt?$ublgF7!Wz}6iif1eSkpX>CoOUddwPdXfTPqh$mFs)`{|n=g8dGI7 z1>=x~{$-3qAYxdJF;2pBGohAj1XXZbueGwdm~l?WbjMPDRXD#Yl3zUs=BXp*Il0jI zlh%8!OAW`u4aYuwInrE4OA|OtjoP$K7^39f4k(GiJR1e`I^o1jPCa{6O^Rs3CMDR3z#i5!@$F`+_0-Z@ zTaDD#llO$2Ga#&}hF7s$fKBR0dJaFiKM1}eE|I|r6nw?LUZ^czi%B5xNYlRjW>7JJ zOl-e_NPQLaD8s0+@ex_n>~gtBNB!f7Jz$WW5}uYVCB#s$Nnvmd{swY;izUb;8a%C; zWY4-dMP+0IQiGI3eDnrhT;)E=$hm&#^+Pj_OBn*@fRAmPXAEyOFJ_i}l7*{x$KLnM zj4Wl@rZ^5p}Ee%Ab;`KJdT z*-m_L@LuCW{Z!kHSEpZ{bw@Tk9_Bk99DHayv6y+{Q`_dr=N_+`$=UZC+X)s>`?^Y~Pv%H-okzAk{HpF?JSqkx{nCYP3J<{)Sw??-! z>ck2=4*G-eR!FV5kc*HragCfUSGn1?JEw1-e!qjMYk(#5N)5rc<2qrHG=m3GYZCZ! z0!AEz3B!cd+4YInC&ER07Ol+-T@S6zi>79IDxguTi6~CN`n7$?&hzK>8&kcT@76bB zBRIA=n?QYC9QAKM?-fDa5l9x>Wh`$VBu~LE(A?v_2uBE!1RfOHI;x6Uj1dy$>q{7c zq&E&&f{4X)Ni0YuC6b+FiIJVyyj8dHkh}OsemB>pYv=dzG4ks|LB&9frpaRy_W+Q* zl)rja6mTPcA6c}JQ79B#Pk0AOEA`=jG0M09pna|&T+zI+?_q^ZCGk)3Qr(5}K0zcV zkQ*YR34KDJ+|qlvi>50r*$FP0g^$4m7KI6ZS=78}efB<>;Mqmfv+FH0g7@{r{`Zdr zM@gZQY{samXh0>v1)X~2`EErQIR2q&`!z*cKQKBb-UJ@vBw40mkt}7SW9y07NXXWt z_vr1T?;mH#WwaCjjG$3>)X8}s*HM!SZZ}n4C#^ddtqpUYht`HgQ-e$>J!+{*Ba|Yo zPhee0t1eDj^>Nbr2Tw_~?)J8SEUIY|S%IM{5()eyr!Kw=3aE(!F}ep8zJw4PcFY(C#yoCy1xF>se}OE8bZ(VNfzsK_sh6x}~AzQCmyg^oZS{)!L+l?Vw)oXx%_? z6h3=7^5$VX;n*1)huute_kbTqN?~_5qivCvM9lb(nd(vF05%AF!Cn9pH4#L*0;4fM zCRyXmE|(cimHEHKW5r5i)mw@D0X&3L6B_~%vRbn42wQhx@r>=XXKtQZD&7??-u0-s z{tKf%EBn6{*H5=kwXNv6+`RXVcdS3KF7Qjjp0KcIs{Y0^)6cx!^i#vnvOdWA&>h)- z;$hB-kBjS9sot92NL*wzTOP0Ja4*q<>@V7~pEX+&3?Sl)uc!IH_jF2R)Q1B-o_{Pg z@->ye1-*31wrc7*KsMxJrPk|<=RjBSU0^e&9oA0(AK<-WwZ&i^r<${B%>?Su+L_7w zhmf?5O_gDscHla- zHN4zo%9Syuc=L^nQTMgR_>{d#ImXz~w5xsu$Tz?o8R$t#)tt(n$zGsfY;UyqP#L^wYVNoM`S zxah*(n4#^ zq}2btZ`2woH9hvKC#U`p z!%qqt^bfH&De9!^X-wYG*A&ngZdwy?leqzIQq%^_uiTH67$9eLQur1kvOj zQdpC*bC~!fk-S?;J}@dj!^60G6w*fU?dj@?v>eS`SK6m4`yh*&ux zLPe!o`2Oh#Fg`g&)j`pd9SPg;)X@?g6v7cH+EJ@*gBkRRz_j-1fw-{))F}o^{h)z_ zFC?%X(yUKk71jv`OFlOI_kv{4mG{m#hy%%<7Iz_wW_DtOFm{-;^Tp9L@|oka8@2R4 zw%4GNN*aMeuRp3E@dc2%FUd!b*{*~dD!vn$;(LgiQAqy01}@!~z$3wXzw%Da?V5`5&fj2&B~au;i7fEHN5j^!r&v;s=3^V zwPv|kxOMjCS=lWUDXyO~CqJ1se_Cx{%(3H*b<0fs8&_t{cRS`g?td#>+ZqvCl|CKe z)lQwid4A3vF4?tE7%4cG+$Cjmj)J-?*RH(r+U)j-wR|>oXZ-f~omX$ay5NrN*nht* zvZEz=wBNq@?OOmIDXEJTI8pzH>e)wG0bTJF0eHHRbC!^&U}s920A?Lya2&fhz~ zwD&}K?}^CXQ;~w_l1Gm{H{pzu*@`=k+m3MAK^pzVYZtFyzIGV|dlesn+ID+e#JWpq zq+V@=NY!=s-2Az_z4N{IbANO`QgARC^!{u6v4e89{Z7a2j&MaQfxdL@(sht-=35bK z>1^Yj*4wQSYkg|iJ@Y+xUz~sOeto2_B~s9u4E(^g1L2G^5FdmO?`R|7=<|}b6sN~# zAqQo%o;!itfrzy+HF$J;sjeklM`)}&94TmvYe3KC?1f{o`dNN4qwERJ=Klyq5r4h{ zhG7cFQYr_tYS`nHH#r$uV&~X(w9!I2zl|l=7~qYT$CktRaGT~@R?P`)`@jNA+h(CT zneo9igX(DZ946BVu4`i)e7KVtd}pubfbMa|L51l{Q0en(0X^JZDV!}>*y~amU)|c$ zMTJLu){@!>y!GwV@qs2i#)QR9T_?2+u6N&P3bf(M>Qn~!|px%bRM{kxXg#vkW@w8inj z^KkQ#h^>9n^gpbdKW=V)WIOah#r(JDoNu3;8Tcr_`a#8Hd)Rhp*|wE!Yi^h`+%?ae z7b<>ajo9`tX72xW`-w-klfP)cH@+ZIl?DD${+^$=Q>By3w$deARoGTFE6p9fdusmF zLMT#uFk)+2%xt0Wt8HoLAFtu~;$ivc6fe;t_J_qs@-06v+I*xa?H5G_M}#!HN!}bz z{41La%o3ofUH-lZo9<~(2WTcI#J!H$cZmNE&!QHW3!euY^Ak;Q0`=uFw_l-<)#YL) z6Mc3s5X5U7tMLSJtVV30iWw}{A)gME;~j_bSUgVIY05e&JBKW4!m%;zmJE%{;b#}U z=%)-TRZhG>nMm0fWz2BFQLqqBu8F@#Pkx`W?@{(TWq(8&{01D`-5}B~2Jt=0C^RbG zrHle5A`x5sx0L;7WYJAHppT>cz5!Rv=EPXAv{l^g0$r7<8FfSI0Mq;-_qlGX(X`Ua zWmvJ_bvdtK*^<4SnFBStoR#}UZj~{6<$0b<%l%xJZp>fP<7$PjR`Ymc{c4)t_{^$B zXRKZ|>5OG8L%=RA`*WQ^E}XqWSF4#i7+P{qnyD`9Nwe}y zXRK15;Y#cADPAjhieJq%8MmzFv!NP{2UatU#v}Y{w!zr=MP{1OvYL~QERVf$^Txct zIi@!oTjj_GVHw&28k+CHm*O#Z zJbL$NSp;3Ql>oC3Hc<4jD9{3Z(L)~l&+KcT^4vG44l=XHPdf>6#Ln%Qt)@;9=ET=2Y72Efd)pV`7M#(RTsEF!05!Ex+yab~qQG-!U z`kl+2<<9z<(_9zMT~c5E7&IYuX3Fn}yYATjlWk&Q;`FkCJ*+J+a5H%24qNwBp^cu6u90c z7j*xEnOX}rXD7iCYG`WmL6r1`3_g^n8jq@F(o(CiTOo@JqAtVhX}s0WQn{6egLas8 z;t+a~e66Rlj*RlH-f7;>y20MvUl;R7V7K~bXj<>5`rwr?_A zq{vTuGQz3JgZ(nJb!ZlVF&UB%R{jFk8?i?ObOI}$`y-0_oYB9rdtik+Ti}6SKrdMy z8_+pM9^OMbq$6^b{MGKUD6$0Z7;C(4Ty**cH4fmeSj7FNx|6_Fi?X!0dkx0_L*^S~ z@cNwdqcBN=ps|x=(9^q(9T7*x4w##4E<{mv1@;5xOBU?v5{%S0kc@>*x1unO!mg9S zx4)7wQx00OOoV>H7zn{PkRM9=6B?uZ5KJ)y1OAp!a*yF%RETb?SZ8b zZl^*V5=~3s0t9X(o$I(AWQ<} zHj7E7WCMut4rEtnN+EzvsRa=vH(qY%j1zZrd!T{fdap4n)PXz=KP5LGfAL5fs88nxA&uS$cuR+6~ z3vB~_uR^;*FSCU>R`6MfW2i2+J!I|!8Cu}X1E$VqDi+Gk08@^N%y0v@BOjj*)}~+? zFYebpjI!M@l|f*tL6BympJ2Zh1o!%3QmiPv%Vq`jtzEUV z5zH<;2zY$-0gzdugCxf=%^g#0Er^%W6`q7tCq3cIvh-!;D#Q^{ne|&$8=I

3q2f zeU~(sATHHRbq2QZF?jYiNBr<2x$u*A9JP5W)v?sYLUx4tOLK%8533xW?I;Pg-adRT zNi~pB1KC3oKjeAFL!LsOdd)_bhTV#*e8MscyZlIU8Ru=Oc-FFTpK|}h7qUO(nOacf z^FMu8=2B%S@vsnnTOI3dyAe-;+Jd5dbF6WFn;*30blRS`FbA9j&BH-WyIgQX9`DtCfI|Bd87*r2MkQ@eshV*ma=#^6=(VRelmD2^e|G2_-2x<|lZVb9rB?Q!aa1 z)NZ_>#5<5b?*b6$`2T)}SV3jE@ zL%QsgC1e(CXGpGr^{$`vB)6{9$mbmy3Y=W9X}&AlXJTEs~8_*bd5>3M2=NA3-pe zH!+UhhNdxr`W26Q4m8FfWkAw+=DY0wtW^DEu#T6tFxj-ot=H6>DC`GFeh38WjzDA) z2tZpX4D`muo1?J&RhZ+|cafk}vpm)(l2|m&h45Yc1gjwRBOncE`Wg_QKqC8Ek;>es zjImFd!=5qFuv%iDRJ|)t?e~zhSJ8NYbNC~q8{tu86K#=AJ;<)`7pl<3YK5*dYc}T1 zK}2COVP1ZI+QJ7z2>1Z~t}qwdnm~OWd$vTdzOWWR51En$p4DyckkdS6{GB4l?PA5N Taej?#P@mOk4I5o?Ub^-f>hPt; diff --git a/mace-bench/3rdparty/mace/mace/tools/torch_geometric/__pycache__/dataloader.cpython-313.pyc b/mace-bench/3rdparty/mace/mace/tools/torch_geometric/__pycache__/dataloader.cpython-313.pyc deleted file mode 100644 index b26fde472f4457c90cbe798e07f29bc14519d5f7..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 4816 zcmahNTWs6b^^zhbS(f~W{7~W~zD}LQOcOb29)%sps2k@|JN8=B>H^p;O-q!elqTIv zDyg$!2r!^%v7dBU)&wX}cRx#j0&9zY6xcq7A?UDuNpT0eM2Ze*_mMv*ZrA2-=a8ag z%MEq`9Uh*0dG0yqp4ZiSEEXX!-tW7e+Z`q38+`GHSVq=A2jm8kiOh|X5sq;qJmV?8 z^P_qTJ1CViva~d;RRxRkZk~Vc+wY3Kv zvf&)@dDTwm01<{&TZLx`pO$8)y-huci*c-@csIy2VVq1DFLNv)^E(I=z{$L# z%%ZSPRC^<>7c!dinzm@6)|RGcSsXrREdVTJOeu9}#6zNih_p0l+n~y0hnt@OFUz3U6Nr3Q3n;1aX3T?h9oFo(BD#gcu>V}X1?mcC`FVIak?T%8q zV*+)m8CxW(2Z@^fMJeuhS&tf?+x!(|d+A78Uyn?Ji{ zKbYCT?FsumL5su_xx$I1xEdKMA-_e2B&_D+ zy&gRT^x;$hz2rrHtd5X+LSe=bd>IvfjFzbZLojMj6NAnjTcqd(OLY3PDP9(~tN@v& zx~ePu9YIENjFZJ3&xVZ zNXb~EC;A+DHjw2Q2H70~DUeSdb~uZ*!h1RWU1R9)PTjzqW@v9LutBGLLy}6Xx~?e6 z0o?@58cbGJ;gMMF0DgQxu90HrsZTm?4}Bmkz4XV>Qs|S8#}*3x>(&H!wG#`)1veSe ztXVKJG;ds1^?Zh+#{E=kaUyvuZ&Z(9k`lbU?`|lHwm(w zVa#Mw(6x>udNoAQbZy=VKm~CG6opY(ZejAAfzm;#XFbU9o~1Trc#nj(w=TnZ{s<|) zvE}+pZ@=-DaKFCwdgi_1mHMtyeb@bXVkN%+UVQ&*Q|s!ku9aQM(yn9$c^DE}8rFyq zYxpig8WJ~!_q6x*Tlz}u=~C?J`|&+X$Cu+$F)H1Uwd34Mth*HJ2A1y*{B~eT{4n*W zsXJ4pcwaHv=dDPu#P*b8dnzlIbPOboM) z!PQONP-4ha;$25!hXJ=RlznFim0KsLsS-$K3RGZXVZpjazHE$NKl7Vk6vI2%b|Bp0 z-Im1l@o_1PErX%XYtS-86lLDb6m-O+it>6v)yp%D3M30=*?Qj43{z1Uh9koq!JY-+ zv?>b3l$}p2=<@m5f~{Gw3T1cqbX}KWopu4rWuekJ0{4%X;BVM2j%c|NUoa>PuERHur%RMj%UCc9N-m$y(Rq!{ zX_hx(&SFPrQA{jUjr6B=)w23$YPz>1GD#wTXm4kzZBmuaLj!ln%?L?%XX>*`s|KCb zsFt^L8lz^`O={HV{HrIjo}nhI2=LR(&TF>DOq63L85v@8R(}LD4oU}t(Vnt|Qa_!{ zX%(`oo2NQRj;5d>nme6zzjw7zto&7t_CS(6NX_ys02rM&uh4nbSfrLZ5Bbuf3!s$y z0h$4sc|1^#p7KqhpZ0?>`e%;JNG^XXSIA~{U_J}2@h0vsWGM;`oV+h4SwWjQSP_q2 z$?G~8)8&FYFs;lhLsgqIXwFNVsOK#cxEL110EWpM)0PVq>&J@&vS2l@qnrU#k zu_&B>e$JTJ4BJz)viD_h6qA9gn2c&F_z_e9C|s_vSdF%dFm4*Y1RL)iPm4>xH!}=!aHr-{!IK!K5T&W0%+-;L&tcZ^ z#>E$_SGnpHWipf^uRvD91f1eifnx!gn?j-7aq*1b3?2sF!`umu!ZlFh$DK$y+bUT@ z5~~6_qE!HQ_Oc}AGAZmO@aF+%rx6SzsN4^opl4gB4yrq53dpCZ7`i45uW@qOIdP0d z&vKQU-4HCbCICR)U)x5aiJK>u>$_H)+OGey*wnoiB)cTMbR8efgJ z0elz=)GZ9| zQk%3^kEG~E(htbtzXX7%D{c&pVj~_cL?--MVI|eswA6n zs&#U<7s(y9WH$gX!GVY1A-iSW?L!&1-d W&0VGDu9fD4rRIa*!%e|6$o~Mh4oD~f diff --git a/mace-bench/3rdparty/mace/mace/tools/torch_geometric/__pycache__/dataset.cpython-310.pyc b/mace-bench/3rdparty/mace/mace/tools/torch_geometric/__pycache__/dataset.cpython-310.pyc deleted file mode 100644 index 466c5e1e552f982edc5717131a772edba00120af..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 10260 zcmc&)%X1q?dY=~_1VE4iDN59X8p*au$Q0<;T3%Ywu57IwB@V5a+DZy*XE;PRNI-)b zxMv`V2$a+=nb=#EY}FoeO16qT>X2h9haB=3oPEs+_#BmEYI%06Wy1 zH9a$5cYpo$*WY{SO->dR{QmwQo^1d7WkvZndKvv?@bW`k{(quj3R4}W9^KWBR@c?& z*{B<`FH_Ipt~uFGuAY;9x|8or)F(QHdZAOS7dxeTsWVxh?3C-}cnG)N&YA0+s-IF7UgoFy44-b#^IX%ISL$azQdpi%>?v%5SK4#! z+4k98y?&Mz_yjA`7*=AFdq({nEAt{hKd%6~0O%B(2J`|uA;+Buq{3zZF`2rqRA+yU z{VHo!HO#N`FM8Z<@M<>HZ@N2S{zKbw>`jM<#>Xu`2=kwGgO=yoPMEpZ>jIqlv5R+1 z&EDg#?}=ZlK!sQ1mL1rHv18xotR?>ugwOX90j8VhUI4DpTqjQ|mf6n8q_a%X2(C zukd_RpI4e1)0wen?5g#N=k$YzvcJG`EWfAJi)?}wa4!MdB<2@c2|F&cNj?e46nIc( zQ-DlQ#DDA>_;Z5RWEI>ixX;A*+34QPf2lN8c9PBQP1a|CWfpi(vD3hEQc{)0`#d`% z>6l9)XW2OkIVB-P%XzjSA*WNw1!hXfd(* zKwh1AlfB2@;;+Ty^O$p4&WYx|&EDC|)ECi9`~Mz$7ju@rR1BqFg>-#C%#kekJg~cl zqQuYUCLm_Rv3=icdctf7ZU-%Q%iI!nciUvq48K+|@muHI+}P%Vo7cC4pzE)$taNt* zPc*i#ZgH=}1JP>KfPe+toOunu*78=kyW-dZ_k$JJ!)-+(lB;kl#$MYFI?jy^%aRme zu4y+KUZ-ojJ48v)!xAmqsTJhhn_|mfl}|*H@VvlW^0BhJwQPE^R#aEbPpJa}&l{M% zV?LB*6Yh9X)8F=b4l_4-ii`01_5;pp=2C)b)m-&9+p8OEp366?>Ciyft`B~7%%w)G znPaQpG@CuQK`Od392#ufff+y#eY547wwq$3g`;I8Gp!*riP+jLG#e(!z+NzonSuvB z;rga+rpP%P!}fp+AB)$_d-0Bw=?`0ugFTN<#4C2fO^!)B#Cq<>=S_6cZ;A1y-E)HH zFrzEDl@juu88Sw{?S40gj3G&B;>`RG3>m|<$w?HYbeMq$!Se4zN?y*FX3GhwJSVQNV0(2o@;%$em8<|!f$V{ye(MQ7Qv2$ zMNNS7@K+9L_)1~}yE|c-c0oI}XzTtGF-1dT82vf20WjTiqr)N3Wu}4eIp^QDyCxEu5HWt0u~JtD?0pf49s zc#l|khFCb25Eu8nn6@YC+Vc-fcer~UIDIL7YQLDE20YYT>`{X{k{>-MN}R}D^&=h( zx187%xId^be-s}H5N$+TtZHHgSi}M~ha^Ye;=vCHE_a!d0_a+3koylOQbQk3j`^an zK(CNDnBWFw70HjD>C$oeeKdixtL&=>*hioZ)Ii_WL?tlxGq|z?^#Fg5)qy%ttlU6D zn;+_!itsUvAA!FQ__Ty8>&Rmem+5O_9u$bP zXsRc|jNj{W5#}wc<+cLL3JWd2<@$k*Y#`JT=!ClGcf$gIg!qW{d>ZQ_5b4^%wlE3N zMQScl^9mYZlGdL4;e=($Ko3ZP9v* zER59XK*OtY*R-#?ct@JA|}+GZ`*DyOdfIesZvu z^&~Pop~2u?57`|#zX=>qOK9?X0kW*AWs>3fgq24K7-J{xiY_V4Sv*8y3$MI%(`jKh zly9wecUFnq)r~(pf^DN#RV5k69(0MuC#?Dc=%rX^_5SQAdC7n=%0>@bkfHxy9v){~ zFFHKxPAL*MfcK@xH#P;wC#0pnjj7^?)JQ9Hlb$}HX52Lh8ux(Ux6q6q%NXJ#niu^pgrb;*EbLeedNdSDjA-r1igfKWK+a7YRJJq(L$~EMyNJN1VKhsh3gKgL6j@s z_zR4amgS7ne=U&!vMiG9+I7c6qUztkLo%bPg}K19=v~aoRiwQSGqhPh;j?dn5o+MZG+JxbR=gWeiw8c<@Eam zy+h5=f=l1_F<^g1@IOG)e?8gokh&MJSgg^+)6{%Ijg;UaSEmjk`4VFy-WuTTspl*; zwlgv@q|nmLyov`{YbnnnNAU&X8Kn&itpd7lB#}qtJ2PQXp0&jvkr4bLCJVB5Qs|`=|e8gNc0^bE8E?a=DOeykVwPcqO>zNO2 z;6y$|RbWCx18o%KDDxk|B+G5QvF)}{eqN47`|#Z|eU9&pJDj|4Keq_8DEy!=gKh{C zqHNmOzoZ+BVE=!+YPLYqi}sZ4XKvvnE>=&Bm}oK_ zVOHvpA7*5N9~x9(2s4)8T_G;ffc!&SxIp5E20h8(T&$xl!^>I)K}!zrn)IRJ;H7eYj_#IEDZTe2DK{^uEmzT{eytY-lvK;T{CvUBO z@AHji(>5tiz&%|dpzlu9ad+n)eFBIleC)Yf(I+UBtOdFwzI=eoZyh1?oNs= z8s&Ctwh;CX798ali0o~m`W2bCRJf1}NLKn?-e@&jBPH-Sp{vQvVy#*dmmw#pT)07~ zx!n-Spga{ZA;GNG33D!!D~I`(?;}#e%LQ}_DYqACSP4k^BMX)SRlh_pWx?V@8aS?& zVbMYrKY;z|Sk^zG-=~q~lJ+dGY2TK#f{G_x6&)FcT6l&_@6ROeeq`>kqOTLew;+2E zP31uOxq{>v3INeAeKbJv){3VV{uX{h&} z4AKpjNrH6jwHTaW*`P8&5Z|6bgpi9;(qMLfb|94+T+ytPq`vUQx9M}J?ORfSrG$(b zz9cy!yX(`b%4j|N7?C>Jh>(b)tVs$xIMM*#Uq{FEN^0Tk}JOnL@rI{3lNju~~ElKa_ zqX41|ZsJQoi-HN1fe7&yzLqWawZ&?6F4~3Em<$;Q9;r!2A`(BQ&BD96TY}MJu8Gfk zSUk)I9+`;Hpe3MBp4;arrdt>!HfdxATKDcdf_Nbo2~fC08YL6+&>$+sO#&MYue&3w zI-^<^=2%n$k2GzYu*m3~sQ6c0K2?+Rh(!zP9Gu)7Do_>mG@ee&r&(=Q)pQMRF#E0X ztdPw@?}n-J=z-i0i;ehP!`b#S`=cCXo#pEckwpXWcLZ=Pho$G>|LWQ0m{iMT%Y2@;VXnM?$I zk!`&vdVJijrpA|$(T^NA>S!R3YBil+fKNo88-<g+8AWymGOi&W_af z;{Ib*oD8>50TjhHyEwA@hA1JFkZ-%H3)@Flsg2Nr@{WNj%UuSJGP5S=|0Y7^Yjx>s zwF+WoltW@YtXafWZ5SJ|6$OHah({rc;zP^QAr~_3J4dvqpBa+ws6l>1=xwqteVf}bLC5# zQO2x-UihYRN_*A#R>NBbZ*xeHF@&0luq@3$sIe>ICeGe2+ zsY^!n#@g}jkmP0^mrr6N-C-^&UQr-C;t(P5OBf;UQ6noUhe#)fSRvBM3Ct<|3nS%+ z`#{%4-4IbtO(KVHzT+n4kg-;}#?IWBZc1ew=^FV2vJ}YGDAFs%W6E9%C z@qyuhOtpb;OY|=v{mG^4wJSfUe&*M%EZwN~jnCfv?5(OEYJxt-^fslOg0jCnhO$2b z4F5Wu|sg#lK~ta0mElQvW1wzG9Qok`N|O#5kPzw8W(G8MgXWA2$tm0@TtK?7GW%9Zs}r&%kp^#R9VallD41YGNcGZwZX zP%`Ko+(`4GeiQ25ynb`Q!+UN-O)0ODMnqrP9|%ydPgLyz`PB(I9+rJYs<}HpsoHx& zu~_JIOja$=MU|v#KQW$+CgP!(YCSnIj%sUv9KSKN=%gH166_B+3;{na2SdpaK^zOc zB1fW33DX>U;-`uq%ZE*Z@G?$sP+(#}V5We`%mEX#3|a%`!JdknxLDr)dj5^12w?ara&$7%~_1CqaI8G#)?4g zTlx-W4Yy51%rHMjFUp)aAIsc z6rZFOBqsnq8jATHj&3%hbU7p>mPjO}W(BjwN7|%B_R{&fq&{lEeUif%H7T9tOC^}u z$s}boF%gSMr)3U@U@4*3P4xvt64P*O_bdMi>iE`{R8gc^8`v|tWtv_StsWJ40B$NDiOA~88R!QzS( zk_^Cd#DMgg%oI%OmriDPs1H3Gjm5CX{6YAM9mtZ5L6d}HS;2V{0w^^{()OA8I$Q_--v55$924X}n0}aDb zh%5biJP-kDkw%GOl3#kFG6~sbH?lW{0li+3MV~nJG7)7wG$u=lVZJUP4|f5m#+3Rp zvb3Pcv0*<8oekpZ4uKEEBG^kb?np2yV2JWcbX?n9;%q#Y2u1w2Mb+DzNFEu3m>rYj zNjdT*#u5zG7J{bG;v{vTeH2`&eJnIJ8cLjrhENL2+s0XfZkYDtGz9u+`%V@a53%GV zuhX(5u&%VPAk%i8YmHB;Zdx(zDoEQ@)GE5ZV6S{9xq8x9mPw?a7w+2yPu+rj%QB(r z5pX&DsJO%q>4M7n+1<+58bFM`!4_LZAk_D%h>dVz@I&;f|A!OE3_9z&=E%vT`rskL&SDpGC_^}V56%8GC zbLddb(a5nOGGdvle6u%>v2JNzObX zCv(fgj|0@r*uH8Z1wbt}M1fjTAf;3X{RFXsS%*ox(i4?A$NDw)v0u#RStqUi!b@n^ zRcA>UFh=%^14TNB&MY=`DeE$sy#n95{B~DfL$+8a^oyN@S0w!#7f+SO zY+9WloRVV345zloijDfGW5#>|pLi#=3XYD6Kg=`D^sn?#JJ2WC`Q zxRMOPU{lRd7FBaX8CMpMi^Cynn>-E&U1Qa#Efz*O5Z>ab^TXl{{zY4 z7K+^$Y$udYNy_Y6tPF_28<A{y5Sb)|^5Wp$^V z*vMAAjDZTVR7g2b`Q>vL&*4mquNTjCrais|yYCCyMzJ9S*&!@u)uDiW=jkhIr7oBX zs!u;?FHfdD?F;ty)q5d|;_Uts3Pgtik*eb2K$hKt(TD+)j&g&ggGu*`YaLrdo=^yw z(Z>)A0GmyH2CLR<1i=mO*iN%f#ZZr}Q@i`JrX49Z@rkk1P;{Xv55w@DV9=An%A|NA zB4_n=LqlgoO==lN1gSZw4^YLXsclWV$Q=zS7%hd8Fp*DBBxR{-)qp0SRkfiX^*}38fV?-51VI zTkbl`XF3*Y+cM5=Dd)CDZ*|7&OL=_@F5g}EhDEn$#&XSl#hoc_N|iP(c$$|4(btBo zWZSaY>~Sm!=3>X{1E*u<2To${L5v4COPp%%jzxhZa&~l$Pj(R!x`tj~we}&uPvrYs zMY1mHYaqWOfA@Ort5}2r;i+#tY+bN#UA^$XbxM^ZBM+lgYBrF?z6_LnCao3> z0j)BUA?zi}0+f*)WhW`?M^-?O)JZL%#{yul(V$b5tw&a1oRG%F>F2vBfQRSUgs>Jn z&@Xgc>=(Vl3ns8wJ5BNPrnPH6YhY z!$iLwWD$}%M%oKh?wnP<+QQQn6hGbgWFi=YtD2E=$mI$|} zZ~-kvQ_71;rts<1GM(Iws&AmDLT3V1rAR2NnYCnUT2nQx>C(1o>u-1NNqhGGtaCo} zhWVQPJ9a1^8?JAdd-XS^zIQrPp1lk9y?^<@vv-+p!FOA`?UujML*XuuH4kNl^;#g; zhCZ)(j!Q5E3^CHz^N#K~{R|;1JT6ae67A5c^+Q|H@GTnp6f*9g(2nkjRLv9V(j7Ru zz5COi1HY)e)2$y}rnWUz+dBW{-;}n$Uzze8Sg;?!;T>Sa<_pe^h=*UMZ=&$|FOwmK zars8_;IYSfO6A z{b0!x?FOh zvyWs>r|i)(S>2kO%nUAcQK%=A1qW-3ohmKxUirt9Q@1i88qs~8?e=tml6i81aNMW{-GL&NZ@V<0)XH5afRCDHuwQN=JNq$9)9 zWMV=(8v;yd5EP3hFw`-GSRDlv@*weOoE#pFM-c>X(|RkgRk(jX`;EvknXHW|TNAVs zrX;7hXATS(t~lD*e;FIo>i++BH5|o~9=0caFJI*jX9Jz&QU98;yXQ45llR>3pRyjr zNo?_zt&q-Y5tlWJYUR-c)j|;r)f!~-IOF~%)qXa_;sBzk7AkT76MF@O%tn!^)^V&f zsfj7@RQLiYZUuJ-cb|rN4HxnxZ*_%@k<&EKtCZcro2XTegFirOxjdH-Upzc>GE>=< zs%%>DJ(+gxn-=dn-Iu#AcFlw?KXu`$8~rz5{O*hM;`d(qiG1f^rsLUE$Fphwq4dT> zcU={ik6%2#WHNi}Vc0b^XX-ms^_>fKyY7rE+011J?whS8ZaA(=O7mx#-<)>sedr|b zo_Oa(+J7*;@nCL>vUVMst9mAxscubGw`QDeDQ6p(wI$OD$XPCPGiCLuvicA04J374 z7hIRzhUX(w(~_!bS!g4b1{SEj(OAF5xGbVWty4ctm?q}u{RGwU3X z&xI)5#RJN?oyJ9mADtLK>!ECg@}k;;LHPIKuQVK9*@}13Tk#^}zM+|pYx}P3y9(4G zKvk!0C!eVsipsA+ItAMi2aJ9XH+eb?(?%Clkuby{;AshNc0(NKVq5PS0-J@u#c*)z zV{$&ifY=42E&>N&0Yo|mp8D2sy|0=jgweSXYp6GCDp1d2Z9-p}K2z?CT0-!Oq#t!Z z_eSXl9w$X;4Pg?-dir4i_QM234?5lz31TLTL!+0WF66rAR5)@ihGH0UBrb*N;fEGi zqqvL&@j&6)$#F90lQO)hZ4!S?pv8ZUs9e_3f?MTh^a>0nrfLf_f{!B(2%}22s z1yarDcdWghkfg+-VdB)<`O;9E6q3lW#n+Lfw^vm0{Nzb` zo09GLTp~WAJ;TBq3x;NG-_IbH;y7@TnI+yp<`(5PK*icqL7z?9DgA(jxt)`bAVPy%}aS)mo^4C0MR5~6Ml9KfqvfO(bw zPD6f(4Ax(bP*Odk+z7l~cdO;CmW2)5GA@70<-c3HZg%I}MVc>ezUQZhZXf#bk@w1e z+^aR1ZaLm^WJ+68rL8|Vz3cp$^Sz$T_NP1drZ%*v zz1yb`F0S8vbN}`Isr66Hbl{h_dioe5348Xv`}#Yt|2mT1eN3+`5-Msk-iDO7VJ-lM z3+Lc!#zaP;8UL5*{UG=G84Lxr*`!}NXq?oSC$Jmis;^=MfO={R@w~X{kmAUh zPpd{9<4nq6VtPL?6(df!78@eQ723v`h$lBjFM;i?!CLekJ&eKVYyocT{9QkI{iB~_ z{cEr&zW$?p?88>N22%=%0n=U){xhNPi`cngIjt23eT37@!P%gnQttZ&pJr|?>KF$g2bQ<6fRKKv3-lQX96$Zd43vPq-pq`OR+AsXNX2Y)w|8QbQ zKi>(wvv4<&cQ$T{n@`jY)a!5pmVvE$IVW`c8uhx|7d{Sa&c`8(1G=^&^ckgGZ}M^I zH#%-P)eyt_wA61gyu)UAhl_fjrKiOqCCCLKmpoGwC*rY5U8XB-T&6;kY;Njm^0j3p zC-06nZl(B>{KYbVh6o{;l}aSddPBF6L(e)h$W~VCQ*EZ{XG?J!+0)=$R`m%*e>m-swJElpX8xoxKO#t#HD2q4H2WNEuy{K zQi$KfdeY-sd^JX7Ctv7(fR4(4K+YgmTwZmp?MmBK|Fn(5^5Gku8RwRibIW4Y#w)Rl zhdwE(o;BZiAyd+jDruPa{bBf^Z4VN-fv28>s~BeSi0Odza>@P_IVd}DpS&!Drx-5)O+>6R5I;7 zsrH`Vw))dm?LXQ%7n*(bjj7u^rw=VuwND>Hn8CYo$z<|Yd{R*}>%Dp;Q_+~JXk73% z&9Nn`sjPL`Z1K961Pg3XmjG47_Z{Et!K>}lw#D^zAKGgcJyp~0zkFIPRBz|+G5)j7 zedH;N>eLk7AZTZ-;N=q$2T|JcYsfy^Q=9jC$hft}?Il&8VofVjo+6LnazK1e2w}!k zJ32`9HX-NM_hqNtE2XWU%XmR&>6=FL{zr@k-|N!zUw zN-HnFe)07i;kkzIMbn<I{wL9*s#Po9}-^wTNi1oR>2c(fa;XE=IIL*GkgP>p|+!d7{M5VYtxu;+?7*8kl8Z#K5vKIj1LE73p zU-5&l+$sCPOKNimM%)X>CaJ^)*Xt;$#avFEBmwmg)IyZa zE>T8y0=rJx95S^s7{mj%WHcP47Yh6-hk~t@WZ$EO^y?o`|fzf=3(mloQ(Q;pB8==!O{WUc$uj%>*;TI*&4_XT{Gicm45ED89`ChyZ{ z$t7&AUupLU@}nzN^>eM`mk z(;-xDoIM3D|7RMwBv7_pv>q2{YiPpbBF5A1L~F~ecNsM;+1;RIG;N<%sA;>=v_0Eo zw>su_E(`c%n^0P+-~dh5#@X;PHqe-DLJ5nbxgIjgZ0(xum@{ABOW5kljzlSoSc~;( zajs%6{8sI))Dt_v+T5CRaOV;Y8gQCe=KkFsz)i2pAtu6OT%4c@Ynr=F-JLZlo zc(Z}~Be?-9FU=d%T6auT0^29*D024zrObon=M5vOm$J1KOKHmokGX#O9 z>?0B58L?pwh%g6Cy_)Yy69TFkf5D{1C*qg^LVz~K!pQRZzZQI;2!w|)wF?FQvcu(> zp7^J(*|G}$`u^c_;LC&#j~1KRW;F50Af(AV7C|^D_&UXa4d*eK%LC zZ%W-oc09j5eWn0=#Em)jUO<6dH`F?kAHv2duS^s&hOm41dr^Xl1W{<|RW0O#ZS>0_8P(3P1!q? zF>3UL6EX7RK{HbPq0?bbzA$|z$Kmm0LO||lNEs*prsy$EHEy}pa7^2cH%D6P5^XHqttv!~~9s=3-X|i}f z`vuy87D2g*Y{@K&;zxq>W1;dRq4*=g^RdwI-@^Wku>W^L`A0(a$3o5Tg{qH*_4i!3 zaI;l&<+s+owQl~@LV0J(wd;KG{UXII?wOV^zjEP~jC)hcy(!~vOt~BH!=)+iS&G># zV)e{UDD-0W?4JAdS#q1j9&si)+jC_SF9K1XJ9(dfmt1D?scA?ad}fqC(T5+&{{mM` B&eH$@ diff --git a/mace-bench/3rdparty/mace/mace/tools/torch_geometric/__pycache__/seed.cpython-310.pyc b/mace-bench/3rdparty/mace/mace/tools/torch_geometric/__pycache__/seed.cpython-310.pyc deleted file mode 100644 index e81ebced3f665c683e5d8030ba76fe8291990903..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 631 zcmYjPy>8qv43_+J=Q<5i^a(V^K^qt7W)ww>?7;z-=@y^PQqDGP$)I%w9CYj}bia-n zI`(C>cIqp1DkT@m5fJ^6CO=Z7^!a&8P=5Y;tap@<-_AId5RFfG#w%2U2-=a&{H2?~ z+ulPvy&@#NM zM^wqT1L#&(CdUh5qvg$dHT!o`RdxOz#K6H!-Zv|wbvNd|iE=NI(Lht>zUg7I7Y>y2 ziv;Au*)82t5>)B(B~9ro3TJ3!3ve5=v7Zz~h#_{45Xa#oe9I@wJt~`g=(RX>@}pb1 MXH7^FB=mLo57~00YybcN diff --git a/mace-bench/3rdparty/mace/mace/tools/torch_geometric/__pycache__/seed.cpython-313.pyc b/mace-bench/3rdparty/mace/mace/tools/torch_geometric/__pycache__/seed.cpython-313.pyc deleted file mode 100644 index 3bbd4145a0f0b10e6d2c204c5d12eb3b00eaf83d..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 844 zcmZ8e&ubG=5PrM+lGaUQQ3yzd^6&>3(dJ-1xQOVn4TO?Z5Go{_-RJIx>?FpCPP8P`B-sqg8w+R`>ik3zvws`QU+YzP{2Zd*8YZW zX0A2Q5OR;aCC@Lr89AW&Q8Nciu_>osvn$=j{II#2qeZ-?w5ZK?v#67(S-#38TUNsR ztA}4&Jc(K2awb>6K7_y=9&j)dFK`%`fgSoR82DWdG4lduL}?O&R`4TCLRe*;cz=nO5J{>TfGs=@YA&E~NOC@Aw=dZW4`iZd~g`w?mo zi9#kMO1>Ek%zj%|3Th2(Q>-tg+Gf9BRkE&FJ^>T&YHsNBnr|MuX4vygL>7M84xwE0I4ctP9`tfAlve(`H` zpUB^vq)-~wKi9`}LMvyqINBNRj2gqnczZ(coY5O6wDOZy#`c8XMK;POxe_hUbH`W+F7@EPeRVok)N^NjXW&=Qe? zwkGHS$+K60NcIbg!&O5)g8l%yxe3COF>T>g5j+p@EZ7(BALi7h@~kv=pq}J`3zTu` zgg=p@)CPPZj4z4q!2Z{wH&PSXoA8SfSL2d{R>;pAt;Vvjz2@AGR6Tfb@YB+MEbC&_ z`%;UBYkTf(OJS9&OwTG^j0PtX5UnnY-qe-Z{lP2Ow@zKR+xmcB_V7?EYoH zgVK^YylVk3yG7s#zQ8Pi-vXHoSy5w9x?p%RR7k zT7nkaY3U_}F`1_eGK=lnI2*6Gq>V-LJeYNCw~a)48NAquGqGL5Y{L89yrfYWf~WaRav`^2Qn`9*Y!Mdp)r;FXi#X$tsG42r&2?Jr5LP` zI}p99Ja&nezfGmFE&+pTWbGm^<5OK)=^_9SdFn#&y3R75E`&I`6fa)$39MXDa=pU*U+~pYy`P6JfmmRUUD7X`hp1$787N*=d89vs z;Omc3e1hUO3jBRuOZ)oOh5 zlLrvda*>=efEFYY15FLIcvUi@PdV^xNec0K2Wuk*{(BQ3?x6z6PRoicgaP_tl!sh$ z@);SUDl(MWIYVMG>xNDJqIcDhne0&(TQsj)Wt&j&DZJ;)VfJOF<}F&8MYCBmy#?C@ ztz^AjC*Fcp@;ifh%PrM1lcZcXiGMS&b%gJE4$b%;Dbq5I6 z$!TX~H(EF*EO|8-6B)_K)$&^Q4dL7G^U>__{gKc_CZblrk2wDcyiP5K&^QuxE5moK zz$qR?7=FA`A1H&UA{lz~k7;kobv(07v0t-r$#Z?n^|4pMri=60&#~`e6Vtj?vMY88 zS3IX|kqKQtN2+v6$Iu$Y!DAH8s9QCr@Em}EWg>8Oq7Qim1fT18RXjEi=Y1S|b>H?} z(*canDtOA9Uzoab!E~r~WgIWrj)Ps#$8}i6hsN5=tUfpt;8py~Jv3-lF08ln!OaU*8c528Js2_{*#9ea{V(O+}BWb=v zwQTw(lPZ++jMs)Y!Urr_4!?L2pvYTb0#+gV5JU?->qp&vtG8BewM(1n;q`Yn(kHgl zCpXe3@1NgHzw<-SMtZXK>24ymdiw6^wF{ex6RqhdntrEwySX;HtsULaj{cfR?IwF3 z^^UIBwvw60>AvN}@0z|F>snrH=eA-)Pt@e{r7gApNvdZ%HU3j-d^0uC`eZlJvwC*r z>~>;kBQdl-@+dL%DDfU($GeuxYtvitW9=^<$KtEningYH8~-}~O=2td<_iVs1J9x; z+1;9Xp@9x@;qWEpn|ct(=O>j1llru*{qude3{n5>9T4mSwkmlg%P+q}?ci;allf1~ zf%ijxJs+!b$VR34;^P}WF-txL&Nb~iHsKAhuUr34LjL6s@hmWP7=nG12(14pGs&!> zFb(z`&`|rzu3zAM)U8fIOvf=)G6EGRn2RAUPJp2N`Po0*i7>thNMsyb{7<=Pq369w z?^zwZJJ{}T)2;Z(`Za!}k}JuzOW|P}-Hwm^93Oc}<1nG!JE{DrpOZ90nlp65KRDAo zc!D{}f=v3rMEG@2_`OaHE`ryYlPXyKdqg-?hj_8~7rk=BvCfhV*x*V=bQHuhMUtc) i6#oSs*+Kda>e@l6KckW~`aE`A>i+8^NtccXD1QS?G7QfE diff --git a/mace-bench/3rdparty/mace/mace/tools/torch_geometric/batch.py b/mace-bench/3rdparty/mace/mace/tools/torch_geometric/batch.py index 93ff7e2..be5ec9d 100644 --- a/mace-bench/3rdparty/mace/mace/tools/torch_geometric/batch.py +++ b/mace-bench/3rdparty/mace/mace/tools/torch_geometric/batch.py @@ -1,257 +1,257 @@ -from collections.abc import Sequence -from typing import List - -import numpy as np -import torch -from torch import Tensor - -from .data import Data -from .dataset import IndexType - - -class Batch(Data): - r"""A plain old python object modeling a batch of graphs as one big - (disconnected) graph. With :class:`torch_geometric.data.Data` being the - base class, all its methods can also be used here. - In addition, single graphs can be reconstructed via the assignment vector - :obj:`batch`, which maps each node to its respective graph identifier. - """ - - def __init__(self, batch=None, ptr=None, **kwargs): - super(Batch, self).__init__(**kwargs) - - for key, item in kwargs.items(): - if key == "num_nodes": - self.__num_nodes__ = item - else: - self[key] = item - - self.batch = batch - self.ptr = ptr - self.__data_class__ = Data - self.__slices__ = None - self.__cumsum__ = None - self.__cat_dims__ = None - self.__num_nodes_list__ = None - self.__num_graphs__ = None - - @classmethod - def from_data_list(cls, data_list, follow_batch=[], exclude_keys=[]): - r"""Constructs a batch object from a python list holding - :class:`torch_geometric.data.Data` objects. - The assignment vector :obj:`batch` is created on the fly. - Additionally, creates assignment batch vectors for each key in - :obj:`follow_batch`. - Will exclude any keys given in :obj:`exclude_keys`.""" - - keys = list(set(data_list[0].keys) - set(exclude_keys)) - assert "batch" not in keys and "ptr" not in keys - - batch = cls() - for key in data_list[0].__dict__.keys(): - if key[:2] != "__" and key[-2:] != "__": - batch[key] = None - - batch.__num_graphs__ = len(data_list) - batch.__data_class__ = data_list[0].__class__ - for key in keys + ["batch"]: - batch[key] = [] - batch["ptr"] = [0] - - device = None - slices = {key: [0] for key in keys} - cumsum = {key: [0] for key in keys} - cat_dims = {} - num_nodes_list = [] - for i, data in enumerate(data_list): - for key in keys: - item = data[key] - - # Increase values by `cumsum` value. - cum = cumsum[key][-1] - if isinstance(item, Tensor) and item.dtype != torch.bool: - if not isinstance(cum, int) or cum != 0: - item = item + cum - elif isinstance(item, (int, float)): - item = item + cum - - # Gather the size of the `cat` dimension. - size = 1 - cat_dim = data.__cat_dim__(key, data[key]) - # 0-dimensional tensors have no dimension along which to - # concatenate, so we set `cat_dim` to `None`. - if isinstance(item, Tensor) and item.dim() == 0: - cat_dim = None - cat_dims[key] = cat_dim - - # Add a batch dimension to items whose `cat_dim` is `None`: - if isinstance(item, Tensor) and cat_dim is None: - cat_dim = 0 # Concatenate along this new batch dimension. - item = item.unsqueeze(0) - device = item.device - elif isinstance(item, Tensor): - size = item.size(cat_dim) - device = item.device - - batch[key].append(item) # Append item to the attribute list. - - slices[key].append(size + slices[key][-1]) - inc = data.__inc__(key, item) - if isinstance(inc, (tuple, list)): - inc = torch.tensor(inc) - cumsum[key].append(inc + cumsum[key][-1]) - - if key in follow_batch: - if isinstance(size, Tensor): - for j, size in enumerate(size.tolist()): - tmp = f"{key}_{j}_batch" - batch[tmp] = [] if i == 0 else batch[tmp] - batch[tmp].append( - torch.full((size,), i, dtype=torch.long, device=device) - ) - else: - tmp = f"{key}_batch" - batch[tmp] = [] if i == 0 else batch[tmp] - batch[tmp].append( - torch.full((size,), i, dtype=torch.long, device=device) - ) - - if hasattr(data, "__num_nodes__"): - num_nodes_list.append(data.__num_nodes__) - else: - num_nodes_list.append(None) - - num_nodes = data.num_nodes - if num_nodes is not None: - item = torch.full((num_nodes,), i, dtype=torch.long, device=device) - batch.batch.append(item) - batch.ptr.append(batch.ptr[-1] + num_nodes) - - batch.batch = None if len(batch.batch) == 0 else batch.batch - batch.ptr = None if len(batch.ptr) == 1 else batch.ptr - batch.__slices__ = slices - batch.__cumsum__ = cumsum - batch.__cat_dims__ = cat_dims - batch.__num_nodes_list__ = num_nodes_list - - ref_data = data_list[0] - for key in batch.keys: - items = batch[key] - item = items[0] - cat_dim = ref_data.__cat_dim__(key, item) - cat_dim = 0 if cat_dim is None else cat_dim - if isinstance(item, Tensor): - batch[key] = torch.cat(items, cat_dim) - elif isinstance(item, (int, float)): - batch[key] = torch.tensor(items) - - # if torch_geometric.is_debug_enabled(): - # batch.debug() - - return batch.contiguous() - - def get_example(self, idx: int) -> Data: - r"""Reconstructs the :class:`torch_geometric.data.Data` object at index - :obj:`idx` from the batch object. - The batch object must have been created via :meth:`from_data_list` in - order to be able to reconstruct the initial objects.""" - - if self.__slices__ is None: - raise RuntimeError( - ( - "Cannot reconstruct data list from batch because the batch " - "object was not created using `Batch.from_data_list()`." - ) - ) - - data = self.__data_class__() - idx = self.num_graphs + idx if idx < 0 else idx - - for key in self.__slices__.keys(): - item = self[key] - if self.__cat_dims__[key] is None: - # The item was concatenated along a new batch dimension, - # so just index in that dimension: - item = item[idx] - else: - # Narrow the item based on the values in `__slices__`. - if isinstance(item, Tensor): - dim = self.__cat_dims__[key] - start = self.__slices__[key][idx] - end = self.__slices__[key][idx + 1] - item = item.narrow(dim, start, end - start) - else: - start = self.__slices__[key][idx] - end = self.__slices__[key][idx + 1] - item = item[start:end] - item = item[0] if len(item) == 1 else item - - # Decrease its value by `cumsum` value: - cum = self.__cumsum__[key][idx] - if isinstance(item, Tensor): - if not isinstance(cum, int) or cum != 0: - item = item - cum - elif isinstance(item, (int, float)): - item = item - cum - - data[key] = item - - if self.__num_nodes_list__[idx] is not None: - data.num_nodes = self.__num_nodes_list__[idx] - - return data - - def index_select(self, idx: IndexType) -> List[Data]: - if isinstance(idx, slice): - idx = list(range(self.num_graphs)[idx]) - - elif isinstance(idx, Tensor) and idx.dtype == torch.long: - idx = idx.flatten().tolist() - - elif isinstance(idx, Tensor) and idx.dtype == torch.bool: - idx = idx.flatten().nonzero(as_tuple=False).flatten().tolist() - - elif isinstance(idx, np.ndarray) and idx.dtype == np.int64: - idx = idx.flatten().tolist() - - elif isinstance(idx, np.ndarray) and idx.dtype == np.bool: - idx = idx.flatten().nonzero()[0].flatten().tolist() - - elif isinstance(idx, Sequence) and not isinstance(idx, str): - pass - - else: - raise IndexError( - f"Only integers, slices (':'), list, tuples, torch.tensor and " - f"np.ndarray of dtype long or bool are valid indices (got " - f"'{type(idx).__name__}')" - ) - - return [self.get_example(i) for i in idx] - - def __getitem__(self, idx): - if isinstance(idx, str): - return super(Batch, self).__getitem__(idx) - elif isinstance(idx, (int, np.integer)): - return self.get_example(idx) - else: - return self.index_select(idx) - - def to_data_list(self) -> List[Data]: - r"""Reconstructs the list of :class:`torch_geometric.data.Data` objects - from the batch object. - The batch object must have been created via :meth:`from_data_list` in - order to be able to reconstruct the initial objects.""" - return [self.get_example(i) for i in range(self.num_graphs)] - - @property - def num_graphs(self) -> int: - """Returns the number of graphs in the batch.""" - if self.__num_graphs__ is not None: - return self.__num_graphs__ - elif self.ptr is not None: - return self.ptr.numel() - 1 - elif self.batch is not None: - return int(self.batch.max()) + 1 - else: - raise ValueError +from collections.abc import Sequence +from typing import List + +import numpy as np +import torch +from torch import Tensor + +from .data import Data +from .dataset import IndexType + + +class Batch(Data): + r"""A plain old python object modeling a batch of graphs as one big + (disconnected) graph. With :class:`torch_geometric.data.Data` being the + base class, all its methods can also be used here. + In addition, single graphs can be reconstructed via the assignment vector + :obj:`batch`, which maps each node to its respective graph identifier. + """ + + def __init__(self, batch=None, ptr=None, **kwargs): + super(Batch, self).__init__(**kwargs) + + for key, item in kwargs.items(): + if key == "num_nodes": + self.__num_nodes__ = item + else: + self[key] = item + + self.batch = batch + self.ptr = ptr + self.__data_class__ = Data + self.__slices__ = None + self.__cumsum__ = None + self.__cat_dims__ = None + self.__num_nodes_list__ = None + self.__num_graphs__ = None + + @classmethod + def from_data_list(cls, data_list, follow_batch=[], exclude_keys=[]): + r"""Constructs a batch object from a python list holding + :class:`torch_geometric.data.Data` objects. + The assignment vector :obj:`batch` is created on the fly. + Additionally, creates assignment batch vectors for each key in + :obj:`follow_batch`. + Will exclude any keys given in :obj:`exclude_keys`.""" + + keys = list(set(data_list[0].keys) - set(exclude_keys)) + assert "batch" not in keys and "ptr" not in keys + + batch = cls() + for key in data_list[0].__dict__.keys(): + if key[:2] != "__" and key[-2:] != "__": + batch[key] = None + + batch.__num_graphs__ = len(data_list) + batch.__data_class__ = data_list[0].__class__ + for key in keys + ["batch"]: + batch[key] = [] + batch["ptr"] = [0] + + device = None + slices = {key: [0] for key in keys} + cumsum = {key: [0] for key in keys} + cat_dims = {} + num_nodes_list = [] + for i, data in enumerate(data_list): + for key in keys: + item = data[key] + + # Increase values by `cumsum` value. + cum = cumsum[key][-1] + if isinstance(item, Tensor) and item.dtype != torch.bool: + if not isinstance(cum, int) or cum != 0: + item = item + cum + elif isinstance(item, (int, float)): + item = item + cum + + # Gather the size of the `cat` dimension. + size = 1 + cat_dim = data.__cat_dim__(key, data[key]) + # 0-dimensional tensors have no dimension along which to + # concatenate, so we set `cat_dim` to `None`. + if isinstance(item, Tensor) and item.dim() == 0: + cat_dim = None + cat_dims[key] = cat_dim + + # Add a batch dimension to items whose `cat_dim` is `None`: + if isinstance(item, Tensor) and cat_dim is None: + cat_dim = 0 # Concatenate along this new batch dimension. + item = item.unsqueeze(0) + device = item.device + elif isinstance(item, Tensor): + size = item.size(cat_dim) + device = item.device + + batch[key].append(item) # Append item to the attribute list. + + slices[key].append(size + slices[key][-1]) + inc = data.__inc__(key, item) + if isinstance(inc, (tuple, list)): + inc = torch.tensor(inc) + cumsum[key].append(inc + cumsum[key][-1]) + + if key in follow_batch: + if isinstance(size, Tensor): + for j, size in enumerate(size.tolist()): + tmp = f"{key}_{j}_batch" + batch[tmp] = [] if i == 0 else batch[tmp] + batch[tmp].append( + torch.full((size,), i, dtype=torch.long, device=device) + ) + else: + tmp = f"{key}_batch" + batch[tmp] = [] if i == 0 else batch[tmp] + batch[tmp].append( + torch.full((size,), i, dtype=torch.long, device=device) + ) + + if hasattr(data, "__num_nodes__"): + num_nodes_list.append(data.__num_nodes__) + else: + num_nodes_list.append(None) + + num_nodes = data.num_nodes + if num_nodes is not None: + item = torch.full((num_nodes,), i, dtype=torch.long, device=device) + batch.batch.append(item) + batch.ptr.append(batch.ptr[-1] + num_nodes) + + batch.batch = None if len(batch.batch) == 0 else batch.batch + batch.ptr = None if len(batch.ptr) == 1 else batch.ptr + batch.__slices__ = slices + batch.__cumsum__ = cumsum + batch.__cat_dims__ = cat_dims + batch.__num_nodes_list__ = num_nodes_list + + ref_data = data_list[0] + for key in batch.keys: + items = batch[key] + item = items[0] + cat_dim = ref_data.__cat_dim__(key, item) + cat_dim = 0 if cat_dim is None else cat_dim + if isinstance(item, Tensor): + batch[key] = torch.cat(items, cat_dim) + elif isinstance(item, (int, float)): + batch[key] = torch.tensor(items) + + # if torch_geometric.is_debug_enabled(): + # batch.debug() + + return batch.contiguous() + + def get_example(self, idx: int) -> Data: + r"""Reconstructs the :class:`torch_geometric.data.Data` object at index + :obj:`idx` from the batch object. + The batch object must have been created via :meth:`from_data_list` in + order to be able to reconstruct the initial objects.""" + + if self.__slices__ is None: + raise RuntimeError( + ( + "Cannot reconstruct data list from batch because the batch " + "object was not created using `Batch.from_data_list()`." + ) + ) + + data = self.__data_class__() + idx = self.num_graphs + idx if idx < 0 else idx + + for key in self.__slices__.keys(): + item = self[key] + if self.__cat_dims__[key] is None: + # The item was concatenated along a new batch dimension, + # so just index in that dimension: + item = item[idx] + else: + # Narrow the item based on the values in `__slices__`. + if isinstance(item, Tensor): + dim = self.__cat_dims__[key] + start = self.__slices__[key][idx] + end = self.__slices__[key][idx + 1] + item = item.narrow(dim, start, end - start) + else: + start = self.__slices__[key][idx] + end = self.__slices__[key][idx + 1] + item = item[start:end] + item = item[0] if len(item) == 1 else item + + # Decrease its value by `cumsum` value: + cum = self.__cumsum__[key][idx] + if isinstance(item, Tensor): + if not isinstance(cum, int) or cum != 0: + item = item - cum + elif isinstance(item, (int, float)): + item = item - cum + + data[key] = item + + if self.__num_nodes_list__[idx] is not None: + data.num_nodes = self.__num_nodes_list__[idx] + + return data + + def index_select(self, idx: IndexType) -> List[Data]: + if isinstance(idx, slice): + idx = list(range(self.num_graphs)[idx]) + + elif isinstance(idx, Tensor) and idx.dtype == torch.long: + idx = idx.flatten().tolist() + + elif isinstance(idx, Tensor) and idx.dtype == torch.bool: + idx = idx.flatten().nonzero(as_tuple=False).flatten().tolist() + + elif isinstance(idx, np.ndarray) and idx.dtype == np.int64: + idx = idx.flatten().tolist() + + elif isinstance(idx, np.ndarray) and idx.dtype == np.bool: + idx = idx.flatten().nonzero()[0].flatten().tolist() + + elif isinstance(idx, Sequence) and not isinstance(idx, str): + pass + + else: + raise IndexError( + f"Only integers, slices (':'), list, tuples, torch.tensor and " + f"np.ndarray of dtype long or bool are valid indices (got " + f"'{type(idx).__name__}')" + ) + + return [self.get_example(i) for i in idx] + + def __getitem__(self, idx): + if isinstance(idx, str): + return super(Batch, self).__getitem__(idx) + elif isinstance(idx, (int, np.integer)): + return self.get_example(idx) + else: + return self.index_select(idx) + + def to_data_list(self) -> List[Data]: + r"""Reconstructs the list of :class:`torch_geometric.data.Data` objects + from the batch object. + The batch object must have been created via :meth:`from_data_list` in + order to be able to reconstruct the initial objects.""" + return [self.get_example(i) for i in range(self.num_graphs)] + + @property + def num_graphs(self) -> int: + """Returns the number of graphs in the batch.""" + if self.__num_graphs__ is not None: + return self.__num_graphs__ + elif self.ptr is not None: + return self.ptr.numel() - 1 + elif self.batch is not None: + return int(self.batch.max()) + 1 + else: + raise ValueError diff --git a/mace-bench/3rdparty/mace/mace/tools/torch_geometric/data.py b/mace-bench/3rdparty/mace/mace/tools/torch_geometric/data.py index 0c41726..4e1ab30 100644 --- a/mace-bench/3rdparty/mace/mace/tools/torch_geometric/data.py +++ b/mace-bench/3rdparty/mace/mace/tools/torch_geometric/data.py @@ -1,441 +1,441 @@ -import collections -import copy -import re - -import torch - -# from ..utils.num_nodes import maybe_num_nodes - -__num_nodes_warn_msg__ = ( - "The number of nodes in your data object can only be inferred by its {} " - "indices, and hence may result in unexpected batch-wise behavior, e.g., " - "in case there exists isolated nodes. Please consider explicitly setting " - "the number of nodes for this data object by assigning it to " - "data.num_nodes." -) - - -def size_repr(key, item, indent=0): - indent_str = " " * indent - if torch.is_tensor(item) and item.dim() == 0: - out = item.item() - elif torch.is_tensor(item): - out = str(list(item.size())) - elif isinstance(item, list) or isinstance(item, tuple): - out = str([len(item)]) - elif isinstance(item, dict): - lines = [indent_str + size_repr(k, v, 2) for k, v in item.items()] - out = "{\n" + ",\n".join(lines) + "\n" + indent_str + "}" - elif isinstance(item, str): - out = f'"{item}"' - else: - out = str(item) - - return f"{indent_str}{key}={out}" - - -class Data(object): - r"""A plain old python object modeling a single graph with various - (optional) attributes: - - Args: - x (Tensor, optional): Node feature matrix with shape :obj:`[num_nodes, - num_node_features]`. (default: :obj:`None`) - edge_index (LongTensor, optional): Graph connectivity in COO format - with shape :obj:`[2, num_edges]`. (default: :obj:`None`) - edge_attr (Tensor, optional): Edge feature matrix with shape - :obj:`[num_edges, num_edge_features]`. (default: :obj:`None`) - y (Tensor, optional): Graph or node targets with arbitrary shape. - (default: :obj:`None`) - pos (Tensor, optional): Node position matrix with shape - :obj:`[num_nodes, num_dimensions]`. (default: :obj:`None`) - normal (Tensor, optional): Normal vector matrix with shape - :obj:`[num_nodes, num_dimensions]`. (default: :obj:`None`) - face (LongTensor, optional): Face adjacency matrix with shape - :obj:`[3, num_faces]`. (default: :obj:`None`) - - The data object is not restricted to these attributes and can be extended - by any other additional data. - - Example:: - - data = Data(x=x, edge_index=edge_index) - data.train_idx = torch.tensor([...], dtype=torch.long) - data.test_mask = torch.tensor([...], dtype=torch.bool) - """ - - def __init__( - self, - x=None, - edge_index=None, - edge_attr=None, - y=None, - pos=None, - normal=None, - face=None, - **kwargs, - ): - self.x = x - self.edge_index = edge_index - self.edge_attr = edge_attr - self.y = y - self.pos = pos - self.normal = normal - self.face = face - for key, item in kwargs.items(): - if key == "num_nodes": - self.__num_nodes__ = item - else: - self[key] = item - - if edge_index is not None and edge_index.dtype != torch.long: - raise ValueError( - ( - f"Argument `edge_index` needs to be of type `torch.long` but " - f"found type `{edge_index.dtype}`." - ) - ) - - if face is not None and face.dtype != torch.long: - raise ValueError( - ( - f"Argument `face` needs to be of type `torch.long` but found " - f"type `{face.dtype}`." - ) - ) - - @classmethod - def from_dict(cls, dictionary): - r"""Creates a data object from a python dictionary.""" - data = cls() - - for key, item in dictionary.items(): - data[key] = item - - return data - - def to_dict(self): - return {key: item for key, item in self} - - def to_namedtuple(self): - keys = self.keys - DataTuple = collections.namedtuple("DataTuple", keys) - return DataTuple(*[self[key] for key in keys]) - - def __getitem__(self, key): - r"""Gets the data of the attribute :obj:`key`.""" - return getattr(self, key, None) - - def __setitem__(self, key, value): - """Sets the attribute :obj:`key` to :obj:`value`.""" - setattr(self, key, value) - - def __delitem__(self, key): - r"""Delete the data of the attribute :obj:`key`.""" - return delattr(self, key) - - @property - def keys(self): - r"""Returns all names of graph attributes.""" - keys = [key for key in self.__dict__.keys() if self[key] is not None] - keys = [key for key in keys if key[:2] != "__" and key[-2:] != "__"] - return keys - - def __len__(self): - r"""Returns the number of all present attributes.""" - return len(self.keys) - - def __contains__(self, key): - r"""Returns :obj:`True`, if the attribute :obj:`key` is present in the - data.""" - return key in self.keys - - def __iter__(self): - r"""Iterates over all present attributes in the data, yielding their - attribute names and content.""" - for key in sorted(self.keys): - yield key, self[key] - - def __call__(self, *keys): - r"""Iterates over all attributes :obj:`*keys` in the data, yielding - their attribute names and content. - If :obj:`*keys` is not given this method will iterative over all - present attributes.""" - for key in sorted(self.keys) if not keys else keys: - if key in self: - yield key, self[key] - - def __cat_dim__(self, key, value): - r"""Returns the dimension for which :obj:`value` of attribute - :obj:`key` will get concatenated when creating batches. - - .. note:: - - This method is for internal use only, and should only be overridden - if the batch concatenation process is corrupted for a specific data - attribute. - """ - if bool(re.search("(index|face)", key)): - return -1 - return 0 - - def __inc__(self, key, value): - r"""Returns the incremental count to cumulatively increase the value - of the next attribute of :obj:`key` when creating batches. - - .. note:: - - This method is for internal use only, and should only be overridden - if the batch concatenation process is corrupted for a specific data - attribute. - """ - # Only `*index*` and `*face*` attributes should be cumulatively summed - # up when creating batches. - return self.num_nodes if bool(re.search("(index|face)", key)) else 0 - - @property - def num_nodes(self): - r"""Returns or sets the number of nodes in the graph. - - .. note:: - The number of nodes in your data object is typically automatically - inferred, *e.g.*, when node features :obj:`x` are present. - In some cases however, a graph may only be given by its edge - indices :obj:`edge_index`. - PyTorch Geometric then *guesses* the number of nodes - according to :obj:`edge_index.max().item() + 1`, but in case there - exists isolated nodes, this number has not to be correct and can - therefore result in unexpected batch-wise behavior. - Thus, we recommend to set the number of nodes in your data object - explicitly via :obj:`data.num_nodes = ...`. - You will be given a warning that requests you to do so. - """ - if hasattr(self, "__num_nodes__"): - return self.__num_nodes__ - for key, item in self("x", "pos", "normal", "batch"): - return item.size(self.__cat_dim__(key, item)) - if hasattr(self, "adj"): - return self.adj.size(0) - if hasattr(self, "adj_t"): - return self.adj_t.size(1) - # if self.face is not None: - # logging.warning(__num_nodes_warn_msg__.format("face")) - # return maybe_num_nodes(self.face) - # if self.edge_index is not None: - # logging.warning(__num_nodes_warn_msg__.format("edge")) - # return maybe_num_nodes(self.edge_index) - return None - - @num_nodes.setter - def num_nodes(self, num_nodes): - self.__num_nodes__ = num_nodes - - @property - def num_edges(self): - """ - Returns the number of edges in the graph. - For undirected graphs, this will return the number of bi-directional - edges, which is double the amount of unique edges. - """ - for key, item in self("edge_index", "edge_attr"): - return item.size(self.__cat_dim__(key, item)) - for key, item in self("adj", "adj_t"): - return item.nnz() - return None - - @property - def num_faces(self): - r"""Returns the number of faces in the mesh.""" - if self.face is not None: - return self.face.size(self.__cat_dim__("face", self.face)) - return None - - @property - def num_node_features(self): - r"""Returns the number of features per node in the graph.""" - if self.x is None: - return 0 - return 1 if self.x.dim() == 1 else self.x.size(1) - - @property - def num_features(self): - r"""Alias for :py:attr:`~num_node_features`.""" - return self.num_node_features - - @property - def num_edge_features(self): - r"""Returns the number of features per edge in the graph.""" - if self.edge_attr is None: - return 0 - return 1 if self.edge_attr.dim() == 1 else self.edge_attr.size(1) - - def __apply__(self, item, func): - if torch.is_tensor(item): - return func(item) - elif isinstance(item, (tuple, list)): - return [self.__apply__(v, func) for v in item] - elif isinstance(item, dict): - return {k: self.__apply__(v, func) for k, v in item.items()} - else: - return item - - def apply(self, func, *keys): - r"""Applies the function :obj:`func` to all tensor attributes - :obj:`*keys`. If :obj:`*keys` is not given, :obj:`func` is applied to - all present attributes. - """ - for key, item in self(*keys): - self[key] = self.__apply__(item, func) - return self - - def contiguous(self, *keys): - r"""Ensures a contiguous memory layout for all attributes :obj:`*keys`. - If :obj:`*keys` is not given, all present attributes are ensured to - have a contiguous memory layout.""" - return self.apply(lambda x: x.contiguous(), *keys) - - def to(self, device, *keys, **kwargs): - r"""Performs tensor dtype and/or device conversion to all attributes - :obj:`*keys`. - If :obj:`*keys` is not given, the conversion is applied to all present - attributes.""" - return self.apply(lambda x: x.to(device, **kwargs), *keys) - - def cpu(self, *keys): - r"""Copies all attributes :obj:`*keys` to CPU memory. - If :obj:`*keys` is not given, the conversion is applied to all present - attributes.""" - return self.apply(lambda x: x.cpu(), *keys) - - def cuda(self, device=None, non_blocking=False, *keys): - r"""Copies all attributes :obj:`*keys` to CUDA memory. - If :obj:`*keys` is not given, the conversion is applied to all present - attributes.""" - return self.apply( - lambda x: x.cuda(device=device, non_blocking=non_blocking), *keys - ) - - def clone(self): - r"""Performs a deep-copy of the data object.""" - return self.__class__.from_dict( - { - k: v.clone() if torch.is_tensor(v) else copy.deepcopy(v) - for k, v in self.__dict__.items() - } - ) - - def pin_memory(self, *keys): - r"""Copies all attributes :obj:`*keys` to pinned memory. - If :obj:`*keys` is not given, the conversion is applied to all present - attributes.""" - return self.apply(lambda x: x.pin_memory(), *keys) - - def debug(self): - if self.edge_index is not None: - if self.edge_index.dtype != torch.long: - raise RuntimeError( - ( - "Expected edge indices of dtype {}, but found dtype " " {}" - ).format(torch.long, self.edge_index.dtype) - ) - - if self.face is not None: - if self.face.dtype != torch.long: - raise RuntimeError( - ( - "Expected face indices of dtype {}, but found dtype " " {}" - ).format(torch.long, self.face.dtype) - ) - - if self.edge_index is not None: - if self.edge_index.dim() != 2 or self.edge_index.size(0) != 2: - raise RuntimeError( - ( - "Edge indices should have shape [2, num_edges] but found" - " shape {}" - ).format(self.edge_index.size()) - ) - - if self.edge_index is not None and self.num_nodes is not None: - if self.edge_index.numel() > 0: - min_index = self.edge_index.min() - max_index = self.edge_index.max() - else: - min_index = max_index = 0 - if min_index < 0 or max_index > self.num_nodes - 1: - raise RuntimeError( - ( - "Edge indices must lay in the interval [0, {}]" - " but found them in the interval [{}, {}]" - ).format(self.num_nodes - 1, min_index, max_index) - ) - - if self.face is not None: - if self.face.dim() != 2 or self.face.size(0) != 3: - raise RuntimeError( - ( - "Face indices should have shape [3, num_faces] but found" - " shape {}" - ).format(self.face.size()) - ) - - if self.face is not None and self.num_nodes is not None: - if self.face.numel() > 0: - min_index = self.face.min() - max_index = self.face.max() - else: - min_index = max_index = 0 - if min_index < 0 or max_index > self.num_nodes - 1: - raise RuntimeError( - ( - "Face indices must lay in the interval [0, {}]" - " but found them in the interval [{}, {}]" - ).format(self.num_nodes - 1, min_index, max_index) - ) - - if self.edge_index is not None and self.edge_attr is not None: - if self.edge_index.size(1) != self.edge_attr.size(0): - raise RuntimeError( - ( - "Edge indices and edge attributes hold a differing " - "number of edges, found {} and {}" - ).format(self.edge_index.size(), self.edge_attr.size()) - ) - - if self.x is not None and self.num_nodes is not None: - if self.x.size(0) != self.num_nodes: - raise RuntimeError( - ( - "Node features should hold {} elements in the first " - "dimension but found {}" - ).format(self.num_nodes, self.x.size(0)) - ) - - if self.pos is not None and self.num_nodes is not None: - if self.pos.size(0) != self.num_nodes: - raise RuntimeError( - ( - "Node positions should hold {} elements in the first " - "dimension but found {}" - ).format(self.num_nodes, self.pos.size(0)) - ) - - if self.normal is not None and self.num_nodes is not None: - if self.normal.size(0) != self.num_nodes: - raise RuntimeError( - ( - "Node normals should hold {} elements in the first " - "dimension but found {}" - ).format(self.num_nodes, self.normal.size(0)) - ) - - def __repr__(self): - cls = str(self.__class__.__name__) - has_dict = any([isinstance(item, dict) for _, item in self]) - - if not has_dict: - info = [size_repr(key, item) for key, item in self] - return "{}({})".format(cls, ", ".join(info)) - else: - info = [size_repr(key, item, indent=2) for key, item in self] - return "{}(\n{}\n)".format(cls, ",\n".join(info)) +import collections +import copy +import re + +import torch + +# from ..utils.num_nodes import maybe_num_nodes + +__num_nodes_warn_msg__ = ( + "The number of nodes in your data object can only be inferred by its {} " + "indices, and hence may result in unexpected batch-wise behavior, e.g., " + "in case there exists isolated nodes. Please consider explicitly setting " + "the number of nodes for this data object by assigning it to " + "data.num_nodes." +) + + +def size_repr(key, item, indent=0): + indent_str = " " * indent + if torch.is_tensor(item) and item.dim() == 0: + out = item.item() + elif torch.is_tensor(item): + out = str(list(item.size())) + elif isinstance(item, list) or isinstance(item, tuple): + out = str([len(item)]) + elif isinstance(item, dict): + lines = [indent_str + size_repr(k, v, 2) for k, v in item.items()] + out = "{\n" + ",\n".join(lines) + "\n" + indent_str + "}" + elif isinstance(item, str): + out = f'"{item}"' + else: + out = str(item) + + return f"{indent_str}{key}={out}" + + +class Data(object): + r"""A plain old python object modeling a single graph with various + (optional) attributes: + + Args: + x (Tensor, optional): Node feature matrix with shape :obj:`[num_nodes, + num_node_features]`. (default: :obj:`None`) + edge_index (LongTensor, optional): Graph connectivity in COO format + with shape :obj:`[2, num_edges]`. (default: :obj:`None`) + edge_attr (Tensor, optional): Edge feature matrix with shape + :obj:`[num_edges, num_edge_features]`. (default: :obj:`None`) + y (Tensor, optional): Graph or node targets with arbitrary shape. + (default: :obj:`None`) + pos (Tensor, optional): Node position matrix with shape + :obj:`[num_nodes, num_dimensions]`. (default: :obj:`None`) + normal (Tensor, optional): Normal vector matrix with shape + :obj:`[num_nodes, num_dimensions]`. (default: :obj:`None`) + face (LongTensor, optional): Face adjacency matrix with shape + :obj:`[3, num_faces]`. (default: :obj:`None`) + + The data object is not restricted to these attributes and can be extended + by any other additional data. + + Example:: + + data = Data(x=x, edge_index=edge_index) + data.train_idx = torch.tensor([...], dtype=torch.long) + data.test_mask = torch.tensor([...], dtype=torch.bool) + """ + + def __init__( + self, + x=None, + edge_index=None, + edge_attr=None, + y=None, + pos=None, + normal=None, + face=None, + **kwargs, + ): + self.x = x + self.edge_index = edge_index + self.edge_attr = edge_attr + self.y = y + self.pos = pos + self.normal = normal + self.face = face + for key, item in kwargs.items(): + if key == "num_nodes": + self.__num_nodes__ = item + else: + self[key] = item + + if edge_index is not None and edge_index.dtype != torch.long: + raise ValueError( + ( + f"Argument `edge_index` needs to be of type `torch.long` but " + f"found type `{edge_index.dtype}`." + ) + ) + + if face is not None and face.dtype != torch.long: + raise ValueError( + ( + f"Argument `face` needs to be of type `torch.long` but found " + f"type `{face.dtype}`." + ) + ) + + @classmethod + def from_dict(cls, dictionary): + r"""Creates a data object from a python dictionary.""" + data = cls() + + for key, item in dictionary.items(): + data[key] = item + + return data + + def to_dict(self): + return {key: item for key, item in self} + + def to_namedtuple(self): + keys = self.keys + DataTuple = collections.namedtuple("DataTuple", keys) + return DataTuple(*[self[key] for key in keys]) + + def __getitem__(self, key): + r"""Gets the data of the attribute :obj:`key`.""" + return getattr(self, key, None) + + def __setitem__(self, key, value): + """Sets the attribute :obj:`key` to :obj:`value`.""" + setattr(self, key, value) + + def __delitem__(self, key): + r"""Delete the data of the attribute :obj:`key`.""" + return delattr(self, key) + + @property + def keys(self): + r"""Returns all names of graph attributes.""" + keys = [key for key in self.__dict__.keys() if self[key] is not None] + keys = [key for key in keys if key[:2] != "__" and key[-2:] != "__"] + return keys + + def __len__(self): + r"""Returns the number of all present attributes.""" + return len(self.keys) + + def __contains__(self, key): + r"""Returns :obj:`True`, if the attribute :obj:`key` is present in the + data.""" + return key in self.keys + + def __iter__(self): + r"""Iterates over all present attributes in the data, yielding their + attribute names and content.""" + for key in sorted(self.keys): + yield key, self[key] + + def __call__(self, *keys): + r"""Iterates over all attributes :obj:`*keys` in the data, yielding + their attribute names and content. + If :obj:`*keys` is not given this method will iterative over all + present attributes.""" + for key in sorted(self.keys) if not keys else keys: + if key in self: + yield key, self[key] + + def __cat_dim__(self, key, value): + r"""Returns the dimension for which :obj:`value` of attribute + :obj:`key` will get concatenated when creating batches. + + .. note:: + + This method is for internal use only, and should only be overridden + if the batch concatenation process is corrupted for a specific data + attribute. + """ + if bool(re.search("(index|face)", key)): + return -1 + return 0 + + def __inc__(self, key, value): + r"""Returns the incremental count to cumulatively increase the value + of the next attribute of :obj:`key` when creating batches. + + .. note:: + + This method is for internal use only, and should only be overridden + if the batch concatenation process is corrupted for a specific data + attribute. + """ + # Only `*index*` and `*face*` attributes should be cumulatively summed + # up when creating batches. + return self.num_nodes if bool(re.search("(index|face)", key)) else 0 + + @property + def num_nodes(self): + r"""Returns or sets the number of nodes in the graph. + + .. note:: + The number of nodes in your data object is typically automatically + inferred, *e.g.*, when node features :obj:`x` are present. + In some cases however, a graph may only be given by its edge + indices :obj:`edge_index`. + PyTorch Geometric then *guesses* the number of nodes + according to :obj:`edge_index.max().item() + 1`, but in case there + exists isolated nodes, this number has not to be correct and can + therefore result in unexpected batch-wise behavior. + Thus, we recommend to set the number of nodes in your data object + explicitly via :obj:`data.num_nodes = ...`. + You will be given a warning that requests you to do so. + """ + if hasattr(self, "__num_nodes__"): + return self.__num_nodes__ + for key, item in self("x", "pos", "normal", "batch"): + return item.size(self.__cat_dim__(key, item)) + if hasattr(self, "adj"): + return self.adj.size(0) + if hasattr(self, "adj_t"): + return self.adj_t.size(1) + # if self.face is not None: + # logging.warning(__num_nodes_warn_msg__.format("face")) + # return maybe_num_nodes(self.face) + # if self.edge_index is not None: + # logging.warning(__num_nodes_warn_msg__.format("edge")) + # return maybe_num_nodes(self.edge_index) + return None + + @num_nodes.setter + def num_nodes(self, num_nodes): + self.__num_nodes__ = num_nodes + + @property + def num_edges(self): + """ + Returns the number of edges in the graph. + For undirected graphs, this will return the number of bi-directional + edges, which is double the amount of unique edges. + """ + for key, item in self("edge_index", "edge_attr"): + return item.size(self.__cat_dim__(key, item)) + for key, item in self("adj", "adj_t"): + return item.nnz() + return None + + @property + def num_faces(self): + r"""Returns the number of faces in the mesh.""" + if self.face is not None: + return self.face.size(self.__cat_dim__("face", self.face)) + return None + + @property + def num_node_features(self): + r"""Returns the number of features per node in the graph.""" + if self.x is None: + return 0 + return 1 if self.x.dim() == 1 else self.x.size(1) + + @property + def num_features(self): + r"""Alias for :py:attr:`~num_node_features`.""" + return self.num_node_features + + @property + def num_edge_features(self): + r"""Returns the number of features per edge in the graph.""" + if self.edge_attr is None: + return 0 + return 1 if self.edge_attr.dim() == 1 else self.edge_attr.size(1) + + def __apply__(self, item, func): + if torch.is_tensor(item): + return func(item) + elif isinstance(item, (tuple, list)): + return [self.__apply__(v, func) for v in item] + elif isinstance(item, dict): + return {k: self.__apply__(v, func) for k, v in item.items()} + else: + return item + + def apply(self, func, *keys): + r"""Applies the function :obj:`func` to all tensor attributes + :obj:`*keys`. If :obj:`*keys` is not given, :obj:`func` is applied to + all present attributes. + """ + for key, item in self(*keys): + self[key] = self.__apply__(item, func) + return self + + def contiguous(self, *keys): + r"""Ensures a contiguous memory layout for all attributes :obj:`*keys`. + If :obj:`*keys` is not given, all present attributes are ensured to + have a contiguous memory layout.""" + return self.apply(lambda x: x.contiguous(), *keys) + + def to(self, device, *keys, **kwargs): + r"""Performs tensor dtype and/or device conversion to all attributes + :obj:`*keys`. + If :obj:`*keys` is not given, the conversion is applied to all present + attributes.""" + return self.apply(lambda x: x.to(device, **kwargs), *keys) + + def cpu(self, *keys): + r"""Copies all attributes :obj:`*keys` to CPU memory. + If :obj:`*keys` is not given, the conversion is applied to all present + attributes.""" + return self.apply(lambda x: x.cpu(), *keys) + + def cuda(self, device=None, non_blocking=False, *keys): + r"""Copies all attributes :obj:`*keys` to CUDA memory. + If :obj:`*keys` is not given, the conversion is applied to all present + attributes.""" + return self.apply( + lambda x: x.cuda(device=device, non_blocking=non_blocking), *keys + ) + + def clone(self): + r"""Performs a deep-copy of the data object.""" + return self.__class__.from_dict( + { + k: v.clone() if torch.is_tensor(v) else copy.deepcopy(v) + for k, v in self.__dict__.items() + } + ) + + def pin_memory(self, *keys): + r"""Copies all attributes :obj:`*keys` to pinned memory. + If :obj:`*keys` is not given, the conversion is applied to all present + attributes.""" + return self.apply(lambda x: x.pin_memory(), *keys) + + def debug(self): + if self.edge_index is not None: + if self.edge_index.dtype != torch.long: + raise RuntimeError( + ( + "Expected edge indices of dtype {}, but found dtype " " {}" + ).format(torch.long, self.edge_index.dtype) + ) + + if self.face is not None: + if self.face.dtype != torch.long: + raise RuntimeError( + ( + "Expected face indices of dtype {}, but found dtype " " {}" + ).format(torch.long, self.face.dtype) + ) + + if self.edge_index is not None: + if self.edge_index.dim() != 2 or self.edge_index.size(0) != 2: + raise RuntimeError( + ( + "Edge indices should have shape [2, num_edges] but found" + " shape {}" + ).format(self.edge_index.size()) + ) + + if self.edge_index is not None and self.num_nodes is not None: + if self.edge_index.numel() > 0: + min_index = self.edge_index.min() + max_index = self.edge_index.max() + else: + min_index = max_index = 0 + if min_index < 0 or max_index > self.num_nodes - 1: + raise RuntimeError( + ( + "Edge indices must lay in the interval [0, {}]" + " but found them in the interval [{}, {}]" + ).format(self.num_nodes - 1, min_index, max_index) + ) + + if self.face is not None: + if self.face.dim() != 2 or self.face.size(0) != 3: + raise RuntimeError( + ( + "Face indices should have shape [3, num_faces] but found" + " shape {}" + ).format(self.face.size()) + ) + + if self.face is not None and self.num_nodes is not None: + if self.face.numel() > 0: + min_index = self.face.min() + max_index = self.face.max() + else: + min_index = max_index = 0 + if min_index < 0 or max_index > self.num_nodes - 1: + raise RuntimeError( + ( + "Face indices must lay in the interval [0, {}]" + " but found them in the interval [{}, {}]" + ).format(self.num_nodes - 1, min_index, max_index) + ) + + if self.edge_index is not None and self.edge_attr is not None: + if self.edge_index.size(1) != self.edge_attr.size(0): + raise RuntimeError( + ( + "Edge indices and edge attributes hold a differing " + "number of edges, found {} and {}" + ).format(self.edge_index.size(), self.edge_attr.size()) + ) + + if self.x is not None and self.num_nodes is not None: + if self.x.size(0) != self.num_nodes: + raise RuntimeError( + ( + "Node features should hold {} elements in the first " + "dimension but found {}" + ).format(self.num_nodes, self.x.size(0)) + ) + + if self.pos is not None and self.num_nodes is not None: + if self.pos.size(0) != self.num_nodes: + raise RuntimeError( + ( + "Node positions should hold {} elements in the first " + "dimension but found {}" + ).format(self.num_nodes, self.pos.size(0)) + ) + + if self.normal is not None and self.num_nodes is not None: + if self.normal.size(0) != self.num_nodes: + raise RuntimeError( + ( + "Node normals should hold {} elements in the first " + "dimension but found {}" + ).format(self.num_nodes, self.normal.size(0)) + ) + + def __repr__(self): + cls = str(self.__class__.__name__) + has_dict = any([isinstance(item, dict) for _, item in self]) + + if not has_dict: + info = [size_repr(key, item) for key, item in self] + return "{}({})".format(cls, ", ".join(info)) + else: + info = [size_repr(key, item, indent=2) for key, item in self] + return "{}(\n{}\n)".format(cls, ",\n".join(info)) diff --git a/mace-bench/3rdparty/mace/mace/tools/torch_geometric/dataloader.py b/mace-bench/3rdparty/mace/mace/tools/torch_geometric/dataloader.py index 9953c14..396b7e7 100644 --- a/mace-bench/3rdparty/mace/mace/tools/torch_geometric/dataloader.py +++ b/mace-bench/3rdparty/mace/mace/tools/torch_geometric/dataloader.py @@ -1,87 +1,87 @@ -from collections.abc import Mapping, Sequence -from typing import List, Optional, Union - -import torch.utils.data -from torch.utils.data.dataloader import default_collate - -from .batch import Batch -from .data import Data -from .dataset import Dataset - - -class Collater: - def __init__(self, follow_batch, exclude_keys): - self.follow_batch = follow_batch - self.exclude_keys = exclude_keys - - def __call__(self, batch): - elem = batch[0] - if isinstance(elem, Data): - return Batch.from_data_list( - batch, - follow_batch=self.follow_batch, - exclude_keys=self.exclude_keys, - ) - elif isinstance(elem, torch.Tensor): - return default_collate(batch) - elif isinstance(elem, float): - return torch.tensor(batch, dtype=torch.float) - elif isinstance(elem, int): - return torch.tensor(batch) - elif isinstance(elem, str): - return batch - elif isinstance(elem, Mapping): - return {key: self([data[key] for data in batch]) for key in elem} - elif isinstance(elem, tuple) and hasattr(elem, "_fields"): - return type(elem)(*(self(s) for s in zip(*batch))) - elif isinstance(elem, Sequence) and not isinstance(elem, str): - return [self(s) for s in zip(*batch)] - - raise TypeError(f"DataLoader found invalid type: {type(elem)}") - - def collate(self, batch): # Deprecated... - return self(batch) - - -class DataLoader(torch.utils.data.DataLoader): - r"""A data loader which merges data objects from a - :class:`torch_geometric.data.Dataset` to a mini-batch. - Data objects can be either of type :class:`~torch_geometric.data.Data` or - :class:`~torch_geometric.data.HeteroData`. - Args: - dataset (Dataset): The dataset from which to load the data. - batch_size (int, optional): How many samples per batch to load. - (default: :obj:`1`) - shuffle (bool, optional): If set to :obj:`True`, the data will be - reshuffled at every epoch. (default: :obj:`False`) - follow_batch (List[str], optional): Creates assignment batch - vectors for each key in the list. (default: :obj:`None`) - exclude_keys (List[str], optional): Will exclude each key in the - list. (default: :obj:`None`) - **kwargs (optional): Additional arguments of - :class:`torch.utils.data.DataLoader`. - """ - - def __init__( - self, - dataset: Dataset, - batch_size: int = 1, - shuffle: bool = False, - follow_batch: Optional[List[str]] = [None], - exclude_keys: Optional[List[str]] = [None], - **kwargs, - ): - if "collate_fn" in kwargs: - del kwargs["collate_fn"] - - # Save for PyTorch Lightning < 1.6: - self.follow_batch = follow_batch - self.exclude_keys = exclude_keys - - super().__init__( - dataset, - batch_size, - shuffle, - collate_fn=Collater(follow_batch, exclude_keys), - **kwargs, - ) +from collections.abc import Mapping, Sequence +from typing import List, Optional, Union + +import torch.utils.data +from torch.utils.data.dataloader import default_collate + +from .batch import Batch +from .data import Data +from .dataset import Dataset + + +class Collater: + def __init__(self, follow_batch, exclude_keys): + self.follow_batch = follow_batch + self.exclude_keys = exclude_keys + + def __call__(self, batch): + elem = batch[0] + if isinstance(elem, Data): + return Batch.from_data_list( + batch, + follow_batch=self.follow_batch, + exclude_keys=self.exclude_keys, + ) + elif isinstance(elem, torch.Tensor): + return default_collate(batch) + elif isinstance(elem, float): + return torch.tensor(batch, dtype=torch.float) + elif isinstance(elem, int): + return torch.tensor(batch) + elif isinstance(elem, str): + return batch + elif isinstance(elem, Mapping): + return {key: self([data[key] for data in batch]) for key in elem} + elif isinstance(elem, tuple) and hasattr(elem, "_fields"): + return type(elem)(*(self(s) for s in zip(*batch))) + elif isinstance(elem, Sequence) and not isinstance(elem, str): + return [self(s) for s in zip(*batch)] + + raise TypeError(f"DataLoader found invalid type: {type(elem)}") + + def collate(self, batch): # Deprecated... + return self(batch) + + +class DataLoader(torch.utils.data.DataLoader): + r"""A data loader which merges data objects from a + :class:`torch_geometric.data.Dataset` to a mini-batch. + Data objects can be either of type :class:`~torch_geometric.data.Data` or + :class:`~torch_geometric.data.HeteroData`. + Args: + dataset (Dataset): The dataset from which to load the data. + batch_size (int, optional): How many samples per batch to load. + (default: :obj:`1`) + shuffle (bool, optional): If set to :obj:`True`, the data will be + reshuffled at every epoch. (default: :obj:`False`) + follow_batch (List[str], optional): Creates assignment batch + vectors for each key in the list. (default: :obj:`None`) + exclude_keys (List[str], optional): Will exclude each key in the + list. (default: :obj:`None`) + **kwargs (optional): Additional arguments of + :class:`torch.utils.data.DataLoader`. + """ + + def __init__( + self, + dataset: Dataset, + batch_size: int = 1, + shuffle: bool = False, + follow_batch: Optional[List[str]] = [None], + exclude_keys: Optional[List[str]] = [None], + **kwargs, + ): + if "collate_fn" in kwargs: + del kwargs["collate_fn"] + + # Save for PyTorch Lightning < 1.6: + self.follow_batch = follow_batch + self.exclude_keys = exclude_keys + + super().__init__( + dataset, + batch_size, + shuffle, + collate_fn=Collater(follow_batch, exclude_keys), + **kwargs, + ) diff --git a/mace-bench/3rdparty/mace/mace/tools/torch_geometric/dataset.py b/mace-bench/3rdparty/mace/mace/tools/torch_geometric/dataset.py index 7b4db34..b4aeb2b 100644 --- a/mace-bench/3rdparty/mace/mace/tools/torch_geometric/dataset.py +++ b/mace-bench/3rdparty/mace/mace/tools/torch_geometric/dataset.py @@ -1,280 +1,280 @@ -import copy -import os.path as osp -import re -import warnings -from collections.abc import Sequence -from typing import Any, Callable, List, Optional, Tuple, Union - -import numpy as np -import torch.utils.data -from torch import Tensor - -from .data import Data -from .utils import makedirs - -IndexType = Union[slice, Tensor, np.ndarray, Sequence] - - -class Dataset(torch.utils.data.Dataset): - r"""Dataset base class for creating graph datasets. - See `here `__ for the accompanying tutorial. - - Args: - root (string, optional): Root directory where the dataset should be - saved. (optional: :obj:`None`) - transform (callable, optional): A function/transform that takes in an - :obj:`torch_geometric.data.Data` object and returns a transformed - version. The data object will be transformed before every access. - (default: :obj:`None`) - pre_transform (callable, optional): A function/transform that takes in - an :obj:`torch_geometric.data.Data` object and returns a - transformed version. The data object will be transformed before - being saved to disk. (default: :obj:`None`) - pre_filter (callable, optional): A function that takes in an - :obj:`torch_geometric.data.Data` object and returns a boolean - value, indicating whether the data object should be included in the - final dataset. (default: :obj:`None`) - """ - - @property - def raw_file_names(self) -> Union[str, List[str], Tuple]: - r"""The name of the files to find in the :obj:`self.raw_dir` folder in - order to skip the download.""" - raise NotImplementedError - - @property - def processed_file_names(self) -> Union[str, List[str], Tuple]: - r"""The name of the files to find in the :obj:`self.processed_dir` - folder in order to skip the processing.""" - raise NotImplementedError - - def download(self): - r"""Downloads the dataset to the :obj:`self.raw_dir` folder.""" - raise NotImplementedError - - def process(self): - r"""Processes the dataset to the :obj:`self.processed_dir` folder.""" - raise NotImplementedError - - def len(self) -> int: - raise NotImplementedError - - def get(self, idx: int) -> Data: - r"""Gets the data object at index :obj:`idx`.""" - raise NotImplementedError - - def __init__( - self, - root: Optional[str] = None, - transform: Optional[Callable] = None, - pre_transform: Optional[Callable] = None, - pre_filter: Optional[Callable] = None, - ): - super().__init__() - - if isinstance(root, str): - root = osp.expanduser(osp.normpath(root)) - - self.root = root - self.transform = transform - self.pre_transform = pre_transform - self.pre_filter = pre_filter - self._indices: Optional[Sequence] = None - - if "download" in self.__class__.__dict__.keys(): - self._download() - - if "process" in self.__class__.__dict__.keys(): - self._process() - - def indices(self) -> Sequence: - return range(self.len()) if self._indices is None else self._indices - - @property - def raw_dir(self) -> str: - return osp.join(self.root, "raw") - - @property - def processed_dir(self) -> str: - return osp.join(self.root, "processed") - - @property - def num_node_features(self) -> int: - r"""Returns the number of features per node in the dataset.""" - data = self[0] - if hasattr(data, "num_node_features"): - return data.num_node_features - raise AttributeError( - f"'{data.__class__.__name__}' object has no " - f"attribute 'num_node_features'" - ) - - @property - def num_features(self) -> int: - r"""Alias for :py:attr:`~num_node_features`.""" - return self.num_node_features - - @property - def num_edge_features(self) -> int: - r"""Returns the number of features per edge in the dataset.""" - data = self[0] - if hasattr(data, "num_edge_features"): - return data.num_edge_features - raise AttributeError( - f"'{data.__class__.__name__}' object has no " - f"attribute 'num_edge_features'" - ) - - @property - def raw_paths(self) -> List[str]: - r"""The filepaths to find in order to skip the download.""" - files = to_list(self.raw_file_names) - return [osp.join(self.raw_dir, f) for f in files] - - @property - def processed_paths(self) -> List[str]: - r"""The filepaths to find in the :obj:`self.processed_dir` - folder in order to skip the processing.""" - files = to_list(self.processed_file_names) - return [osp.join(self.processed_dir, f) for f in files] - - def _download(self): - if files_exist(self.raw_paths): # pragma: no cover - return - - makedirs(self.raw_dir) - self.download() - - def _process(self): - f = osp.join(self.processed_dir, "pre_transform.pt") - if osp.exists(f) and torch.load(f) != _repr(self.pre_transform): - warnings.warn( - f"The `pre_transform` argument differs from the one used in " - f"the pre-processed version of this dataset. If you want to " - f"make use of another pre-processing technique, make sure to " - f"sure to delete '{self.processed_dir}' first" - ) - - f = osp.join(self.processed_dir, "pre_filter.pt") - if osp.exists(f) and torch.load(f) != _repr(self.pre_filter): - warnings.warn( - "The `pre_filter` argument differs from the one used in the " - "pre-processed version of this dataset. If you want to make " - "use of another pre-fitering technique, make sure to delete " - "'{self.processed_dir}' first" - ) - - if files_exist(self.processed_paths): # pragma: no cover - return - - print("Processing...") - - makedirs(self.processed_dir) - self.process() - - path = osp.join(self.processed_dir, "pre_transform.pt") - torch.save(_repr(self.pre_transform), path) - path = osp.join(self.processed_dir, "pre_filter.pt") - torch.save(_repr(self.pre_filter), path) - - print("Done!") - - def __len__(self) -> int: - r"""The number of examples in the dataset.""" - return len(self.indices()) - - def __getitem__( - self, - idx: Union[int, np.integer, IndexType], - ) -> Union["Dataset", Data]: - r"""In case :obj:`idx` is of type integer, will return the data object - at index :obj:`idx` (and transforms it in case :obj:`transform` is - present). - In case :obj:`idx` is a slicing object, *e.g.*, :obj:`[2:5]`, a list, a - tuple, a PyTorch :obj:`LongTensor` or a :obj:`BoolTensor`, or a numpy - :obj:`np.array`, will return a subset of the dataset at the specified - indices.""" - if ( - isinstance(idx, (int, np.integer)) - or (isinstance(idx, Tensor) and idx.dim() == 0) - or (isinstance(idx, np.ndarray) and np.isscalar(idx)) - ): - data = self.get(self.indices()[idx]) - data = data if self.transform is None else self.transform(data) - return data - - else: - return self.index_select(idx) - - def index_select(self, idx: IndexType) -> "Dataset": - indices = self.indices() - - if isinstance(idx, slice): - indices = indices[idx] - - elif isinstance(idx, Tensor) and idx.dtype == torch.long: - return self.index_select(idx.flatten().tolist()) - - elif isinstance(idx, Tensor) and idx.dtype == torch.bool: - idx = idx.flatten().nonzero(as_tuple=False) - return self.index_select(idx.flatten().tolist()) - - elif isinstance(idx, np.ndarray) and idx.dtype == np.int64: - return self.index_select(idx.flatten().tolist()) - - elif isinstance(idx, np.ndarray) and idx.dtype == np.bool: - idx = idx.flatten().nonzero()[0] - return self.index_select(idx.flatten().tolist()) - - elif isinstance(idx, Sequence) and not isinstance(idx, str): - indices = [indices[i] for i in idx] - - else: - raise IndexError( - f"Only integers, slices (':'), list, tuples, torch.tensor and " - f"np.ndarray of dtype long or bool are valid indices (got " - f"'{type(idx).__name__}')" - ) - - dataset = copy.copy(self) - dataset._indices = indices - return dataset - - def shuffle( - self, - return_perm: bool = False, - ) -> Union["Dataset", Tuple["Dataset", Tensor]]: - r"""Randomly shuffles the examples in the dataset. - - Args: - return_perm (bool, optional): If set to :obj:`True`, will return - the random permutation used to shuffle the dataset in addition. - (default: :obj:`False`) - """ - perm = torch.randperm(len(self)) - dataset = self.index_select(perm) - return (dataset, perm) if return_perm is True else dataset - - def __repr__(self) -> str: - arg_repr = str(len(self)) if len(self) > 1 else "" - return f"{self.__class__.__name__}({arg_repr})" - - -def to_list(value: Any) -> Sequence: - if isinstance(value, Sequence) and not isinstance(value, str): - return value - else: - return [value] - - -def files_exist(files: List[str]) -> bool: - # NOTE: We return `False` in case `files` is empty, leading to a - # re-processing of files on every instantiation. - return len(files) != 0 and all([osp.exists(f) for f in files]) - - -def _repr(obj: Any) -> str: - if obj is None: - return "None" - return re.sub("(<.*?)\\s.*(>)", r"\1\2", obj.__repr__()) +import copy +import os.path as osp +import re +import warnings +from collections.abc import Sequence +from typing import Any, Callable, List, Optional, Tuple, Union + +import numpy as np +import torch.utils.data +from torch import Tensor + +from .data import Data +from .utils import makedirs + +IndexType = Union[slice, Tensor, np.ndarray, Sequence] + + +class Dataset(torch.utils.data.Dataset): + r"""Dataset base class for creating graph datasets. + See `here `__ for the accompanying tutorial. + + Args: + root (string, optional): Root directory where the dataset should be + saved. (optional: :obj:`None`) + transform (callable, optional): A function/transform that takes in an + :obj:`torch_geometric.data.Data` object and returns a transformed + version. The data object will be transformed before every access. + (default: :obj:`None`) + pre_transform (callable, optional): A function/transform that takes in + an :obj:`torch_geometric.data.Data` object and returns a + transformed version. The data object will be transformed before + being saved to disk. (default: :obj:`None`) + pre_filter (callable, optional): A function that takes in an + :obj:`torch_geometric.data.Data` object and returns a boolean + value, indicating whether the data object should be included in the + final dataset. (default: :obj:`None`) + """ + + @property + def raw_file_names(self) -> Union[str, List[str], Tuple]: + r"""The name of the files to find in the :obj:`self.raw_dir` folder in + order to skip the download.""" + raise NotImplementedError + + @property + def processed_file_names(self) -> Union[str, List[str], Tuple]: + r"""The name of the files to find in the :obj:`self.processed_dir` + folder in order to skip the processing.""" + raise NotImplementedError + + def download(self): + r"""Downloads the dataset to the :obj:`self.raw_dir` folder.""" + raise NotImplementedError + + def process(self): + r"""Processes the dataset to the :obj:`self.processed_dir` folder.""" + raise NotImplementedError + + def len(self) -> int: + raise NotImplementedError + + def get(self, idx: int) -> Data: + r"""Gets the data object at index :obj:`idx`.""" + raise NotImplementedError + + def __init__( + self, + root: Optional[str] = None, + transform: Optional[Callable] = None, + pre_transform: Optional[Callable] = None, + pre_filter: Optional[Callable] = None, + ): + super().__init__() + + if isinstance(root, str): + root = osp.expanduser(osp.normpath(root)) + + self.root = root + self.transform = transform + self.pre_transform = pre_transform + self.pre_filter = pre_filter + self._indices: Optional[Sequence] = None + + if "download" in self.__class__.__dict__.keys(): + self._download() + + if "process" in self.__class__.__dict__.keys(): + self._process() + + def indices(self) -> Sequence: + return range(self.len()) if self._indices is None else self._indices + + @property + def raw_dir(self) -> str: + return osp.join(self.root, "raw") + + @property + def processed_dir(self) -> str: + return osp.join(self.root, "processed") + + @property + def num_node_features(self) -> int: + r"""Returns the number of features per node in the dataset.""" + data = self[0] + if hasattr(data, "num_node_features"): + return data.num_node_features + raise AttributeError( + f"'{data.__class__.__name__}' object has no " + f"attribute 'num_node_features'" + ) + + @property + def num_features(self) -> int: + r"""Alias for :py:attr:`~num_node_features`.""" + return self.num_node_features + + @property + def num_edge_features(self) -> int: + r"""Returns the number of features per edge in the dataset.""" + data = self[0] + if hasattr(data, "num_edge_features"): + return data.num_edge_features + raise AttributeError( + f"'{data.__class__.__name__}' object has no " + f"attribute 'num_edge_features'" + ) + + @property + def raw_paths(self) -> List[str]: + r"""The filepaths to find in order to skip the download.""" + files = to_list(self.raw_file_names) + return [osp.join(self.raw_dir, f) for f in files] + + @property + def processed_paths(self) -> List[str]: + r"""The filepaths to find in the :obj:`self.processed_dir` + folder in order to skip the processing.""" + files = to_list(self.processed_file_names) + return [osp.join(self.processed_dir, f) for f in files] + + def _download(self): + if files_exist(self.raw_paths): # pragma: no cover + return + + makedirs(self.raw_dir) + self.download() + + def _process(self): + f = osp.join(self.processed_dir, "pre_transform.pt") + if osp.exists(f) and torch.load(f) != _repr(self.pre_transform): + warnings.warn( + f"The `pre_transform` argument differs from the one used in " + f"the pre-processed version of this dataset. If you want to " + f"make use of another pre-processing technique, make sure to " + f"sure to delete '{self.processed_dir}' first" + ) + + f = osp.join(self.processed_dir, "pre_filter.pt") + if osp.exists(f) and torch.load(f) != _repr(self.pre_filter): + warnings.warn( + "The `pre_filter` argument differs from the one used in the " + "pre-processed version of this dataset. If you want to make " + "use of another pre-fitering technique, make sure to delete " + "'{self.processed_dir}' first" + ) + + if files_exist(self.processed_paths): # pragma: no cover + return + + print("Processing...") + + makedirs(self.processed_dir) + self.process() + + path = osp.join(self.processed_dir, "pre_transform.pt") + torch.save(_repr(self.pre_transform), path) + path = osp.join(self.processed_dir, "pre_filter.pt") + torch.save(_repr(self.pre_filter), path) + + print("Done!") + + def __len__(self) -> int: + r"""The number of examples in the dataset.""" + return len(self.indices()) + + def __getitem__( + self, + idx: Union[int, np.integer, IndexType], + ) -> Union["Dataset", Data]: + r"""In case :obj:`idx` is of type integer, will return the data object + at index :obj:`idx` (and transforms it in case :obj:`transform` is + present). + In case :obj:`idx` is a slicing object, *e.g.*, :obj:`[2:5]`, a list, a + tuple, a PyTorch :obj:`LongTensor` or a :obj:`BoolTensor`, or a numpy + :obj:`np.array`, will return a subset of the dataset at the specified + indices.""" + if ( + isinstance(idx, (int, np.integer)) + or (isinstance(idx, Tensor) and idx.dim() == 0) + or (isinstance(idx, np.ndarray) and np.isscalar(idx)) + ): + data = self.get(self.indices()[idx]) + data = data if self.transform is None else self.transform(data) + return data + + else: + return self.index_select(idx) + + def index_select(self, idx: IndexType) -> "Dataset": + indices = self.indices() + + if isinstance(idx, slice): + indices = indices[idx] + + elif isinstance(idx, Tensor) and idx.dtype == torch.long: + return self.index_select(idx.flatten().tolist()) + + elif isinstance(idx, Tensor) and idx.dtype == torch.bool: + idx = idx.flatten().nonzero(as_tuple=False) + return self.index_select(idx.flatten().tolist()) + + elif isinstance(idx, np.ndarray) and idx.dtype == np.int64: + return self.index_select(idx.flatten().tolist()) + + elif isinstance(idx, np.ndarray) and idx.dtype == np.bool: + idx = idx.flatten().nonzero()[0] + return self.index_select(idx.flatten().tolist()) + + elif isinstance(idx, Sequence) and not isinstance(idx, str): + indices = [indices[i] for i in idx] + + else: + raise IndexError( + f"Only integers, slices (':'), list, tuples, torch.tensor and " + f"np.ndarray of dtype long or bool are valid indices (got " + f"'{type(idx).__name__}')" + ) + + dataset = copy.copy(self) + dataset._indices = indices + return dataset + + def shuffle( + self, + return_perm: bool = False, + ) -> Union["Dataset", Tuple["Dataset", Tensor]]: + r"""Randomly shuffles the examples in the dataset. + + Args: + return_perm (bool, optional): If set to :obj:`True`, will return + the random permutation used to shuffle the dataset in addition. + (default: :obj:`False`) + """ + perm = torch.randperm(len(self)) + dataset = self.index_select(perm) + return (dataset, perm) if return_perm is True else dataset + + def __repr__(self) -> str: + arg_repr = str(len(self)) if len(self) > 1 else "" + return f"{self.__class__.__name__}({arg_repr})" + + +def to_list(value: Any) -> Sequence: + if isinstance(value, Sequence) and not isinstance(value, str): + return value + else: + return [value] + + +def files_exist(files: List[str]) -> bool: + # NOTE: We return `False` in case `files` is empty, leading to a + # re-processing of files on every instantiation. + return len(files) != 0 and all([osp.exists(f) for f in files]) + + +def _repr(obj: Any) -> str: + if obj is None: + return "None" + return re.sub("(<.*?)\\s.*(>)", r"\1\2", obj.__repr__()) diff --git a/mace-bench/3rdparty/mace/mace/tools/torch_geometric/seed.py b/mace-bench/3rdparty/mace/mace/tools/torch_geometric/seed.py index 6819fda..be27fca 100644 --- a/mace-bench/3rdparty/mace/mace/tools/torch_geometric/seed.py +++ b/mace-bench/3rdparty/mace/mace/tools/torch_geometric/seed.py @@ -1,17 +1,17 @@ -import random - -import numpy as np -import torch - - -def seed_everything(seed: int): - r"""Sets the seed for generating random numbers in :pytorch:`PyTorch`, - :obj:`numpy` and Python. - - Args: - seed (int): The desired seed. - """ - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) +import random + +import numpy as np +import torch + + +def seed_everything(seed: int): + r"""Sets the seed for generating random numbers in :pytorch:`PyTorch`, + :obj:`numpy` and Python. + + Args: + seed (int): The desired seed. + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) diff --git a/mace-bench/3rdparty/mace/mace/tools/torch_geometric/utils.py b/mace-bench/3rdparty/mace/mace/tools/torch_geometric/utils.py index 3efc354..f53b8f8 100644 --- a/mace-bench/3rdparty/mace/mace/tools/torch_geometric/utils.py +++ b/mace-bench/3rdparty/mace/mace/tools/torch_geometric/utils.py @@ -1,54 +1,54 @@ -import os -import os.path as osp -import ssl -import urllib -import zipfile - - -def makedirs(dir): - os.makedirs(dir, exist_ok=True) - - -def download_url(url, folder, log=True): - r"""Downloads the content of an URL to a specific folder. - - Args: - url (string): The url. - folder (string): The folder. - log (bool, optional): If :obj:`False`, will not print anything to the - console. (default: :obj:`True`) - """ - - filename = url.rpartition("/")[2].split("?")[0] - path = osp.join(folder, filename) - - if osp.exists(path): # pragma: no cover - if log: - print("Using exist file", filename) - return path - - if log: - print("Downloading", url) - - makedirs(folder) - - context = ssl._create_unverified_context() - data = urllib.request.urlopen(url, context=context) - - with open(path, "wb") as f: - f.write(data.read()) - - return path - - -def extract_zip(path, folder, log=True): - r"""Extracts a zip archive to a specific folder. - - Args: - path (string): The path to the tar archive. - folder (string): The folder. - log (bool, optional): If :obj:`False`, will not print anything to the - console. (default: :obj:`True`) - """ - with zipfile.ZipFile(path, "r") as f: - f.extractall(folder) +import os +import os.path as osp +import ssl +import urllib +import zipfile + + +def makedirs(dir): + os.makedirs(dir, exist_ok=True) + + +def download_url(url, folder, log=True): + r"""Downloads the content of an URL to a specific folder. + + Args: + url (string): The url. + folder (string): The folder. + log (bool, optional): If :obj:`False`, will not print anything to the + console. (default: :obj:`True`) + """ + + filename = url.rpartition("/")[2].split("?")[0] + path = osp.join(folder, filename) + + if osp.exists(path): # pragma: no cover + if log: + print("Using exist file", filename) + return path + + if log: + print("Downloading", url) + + makedirs(folder) + + context = ssl._create_unverified_context() + data = urllib.request.urlopen(url, context=context) + + with open(path, "wb") as f: + f.write(data.read()) + + return path + + +def extract_zip(path, folder, log=True): + r"""Extracts a zip archive to a specific folder. + + Args: + path (string): The path to the tar archive. + folder (string): The folder. + log (bool, optional): If :obj:`False`, will not print anything to the + console. (default: :obj:`True`) + """ + with zipfile.ZipFile(path, "r") as f: + f.extractall(folder) diff --git a/mace-bench/3rdparty/mace/mace/tools/torch_tools.py b/mace-bench/3rdparty/mace/mace/tools/torch_tools.py index 380a1da..2ab339e 100644 --- a/mace-bench/3rdparty/mace/mace/tools/torch_tools.py +++ b/mace-bench/3rdparty/mace/mace/tools/torch_tools.py @@ -1,153 +1,153 @@ -########################################################################################### -# Tools for torch -# Authors: Ilyes Batatia, Gregor Simm -# This program is distributed under the MIT License (see MIT.md) -########################################################################################### - -import logging -from contextlib import contextmanager -from typing import Dict, Union - -import numpy as np -import torch -from e3nn.io import CartesianTensor - -TensorDict = Dict[str, torch.Tensor] - - -def to_one_hot(indices: torch.Tensor, num_classes: int) -> torch.Tensor: - """ - Generates one-hot encoding with classes from - :param indices: (N x 1) tensor - :param num_classes: number of classes - :param device: torch device - :return: (N x num_classes) tensor - """ - shape = indices.shape[:-1] + (num_classes,) - oh = torch.zeros(shape, device=indices.device).view(shape) - - # scatter_ is the in-place version of scatter - oh.scatter_(dim=-1, index=indices, value=1) - - return oh.view(*shape) - - -def count_parameters(module: torch.nn.Module) -> int: - return int(sum(np.prod(p.shape) for p in module.parameters())) - - -def tensor_dict_to_device(td: TensorDict, device: torch.device) -> TensorDict: - return {k: v.to(device) if v is not None else None for k, v in td.items()} - - -def set_seeds(seed: int) -> None: - np.random.seed(seed) - torch.manual_seed(seed) - - -def to_numpy(t: torch.Tensor) -> np.ndarray: - return t.cpu().detach().numpy() - - -def init_device(device_str: str) -> torch.device: - if "cuda" in device_str: - assert torch.cuda.is_available(), "No CUDA device available!" - if ":" in device_str: - # Check if the desired device is available - assert int(device_str.split(":")[-1]) < torch.cuda.device_count() - logging.info( - f"CUDA version: {torch.version.cuda}, CUDA device: {torch.cuda.current_device()}" - ) - torch.cuda.init() - return torch.device(device_str) - if device_str == "mps": - assert torch.backends.mps.is_available(), "No MPS backend is available!" - logging.info("Using MPS GPU acceleration") - return torch.device("mps") - if device_str == "xpu": - torch.xpu.is_available() - return torch.device("xpu") - - logging.info("Using CPU") - return torch.device("cpu") - - -dtype_dict = {"float32": torch.float32, "float64": torch.float64} - - -def set_default_dtype(dtype: str) -> None: - torch.set_default_dtype(dtype_dict[dtype]) - - -def spherical_to_cartesian(t: torch.Tensor): - """ - Convert spherical notation to cartesian notation - """ - stress_cart_tensor = CartesianTensor("ij=ji") - stress_rtp = stress_cart_tensor.reduced_tensor_products() - return stress_cart_tensor.to_cartesian(t, rtp=stress_rtp) - - -def cartesian_to_spherical(t: torch.Tensor): - """ - Convert cartesian notation to spherical notation - """ - stress_cart_tensor = CartesianTensor("ij=ji") - stress_rtp = stress_cart_tensor.reduced_tensor_products() - return stress_cart_tensor.to_cartesian(t, rtp=stress_rtp) - - -def voigt_to_matrix(t: torch.Tensor): - """ - Convert voigt notation to matrix notation - :param t: (6,) tensor or (3, 3) tensor or (9,) tensor - :return: (3, 3) tensor - """ - if t.shape == (3, 3): - return t - if t.shape == (6,): - return torch.tensor( - [ - [t[0], t[5], t[4]], - [t[5], t[1], t[3]], - [t[4], t[3], t[2]], - ], - dtype=t.dtype, - ) - if t.shape == (9,): - return t.view(3, 3) - - raise ValueError( - f"Stress tensor must be of shape (6,) or (3, 3), or (9,) but has shape {t.shape}" - ) - - -def init_wandb(project: str, entity: str, name: str, config: dict, directory: str): - import wandb - - wandb.init( - project=project, - entity=entity, - name=name, - config=config, - dir=directory, - resume="allow", - ) - - -@contextmanager -def default_dtype(dtype: Union[torch.dtype, str]): - """Context manager for configuring the default_dtype used by torch - - Args: - dtype (torch.dtype|str): the default dtype to use within this context manager - """ - init = torch.get_default_dtype() - if isinstance(dtype, str): - set_default_dtype(dtype) - else: - torch.set_default_dtype(dtype) - - yield - - torch.set_default_dtype(init) +########################################################################################### +# Tools for torch +# Authors: Ilyes Batatia, Gregor Simm +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +import logging +from contextlib import contextmanager +from typing import Dict, Union + +import numpy as np +import torch +from e3nn.io import CartesianTensor + +TensorDict = Dict[str, torch.Tensor] + + +def to_one_hot(indices: torch.Tensor, num_classes: int) -> torch.Tensor: + """ + Generates one-hot encoding with classes from + :param indices: (N x 1) tensor + :param num_classes: number of classes + :param device: torch device + :return: (N x num_classes) tensor + """ + shape = indices.shape[:-1] + (num_classes,) + oh = torch.zeros(shape, device=indices.device).view(shape) + + # scatter_ is the in-place version of scatter + oh.scatter_(dim=-1, index=indices, value=1) + + return oh.view(*shape) + + +def count_parameters(module: torch.nn.Module) -> int: + return int(sum(np.prod(p.shape) for p in module.parameters())) + + +def tensor_dict_to_device(td: TensorDict, device: torch.device) -> TensorDict: + return {k: v.to(device) if v is not None else None for k, v in td.items()} + + +def set_seeds(seed: int) -> None: + np.random.seed(seed) + torch.manual_seed(seed) + + +def to_numpy(t: torch.Tensor) -> np.ndarray: + return t.cpu().detach().numpy() + + +def init_device(device_str: str) -> torch.device: + if "cuda" in device_str: + assert torch.cuda.is_available(), "No CUDA device available!" + if ":" in device_str: + # Check if the desired device is available + assert int(device_str.split(":")[-1]) < torch.cuda.device_count() + logging.info( + f"CUDA version: {torch.version.cuda}, CUDA device: {torch.cuda.current_device()}" + ) + torch.cuda.init() + return torch.device(device_str) + if device_str == "mps": + assert torch.backends.mps.is_available(), "No MPS backend is available!" + logging.info("Using MPS GPU acceleration") + return torch.device("mps") + if device_str == "xpu": + torch.xpu.is_available() + return torch.device("xpu") + + logging.info("Using CPU") + return torch.device("cpu") + + +dtype_dict = {"float32": torch.float32, "float64": torch.float64} + + +def set_default_dtype(dtype: str) -> None: + torch.set_default_dtype(dtype_dict[dtype]) + + +def spherical_to_cartesian(t: torch.Tensor): + """ + Convert spherical notation to cartesian notation + """ + stress_cart_tensor = CartesianTensor("ij=ji") + stress_rtp = stress_cart_tensor.reduced_tensor_products() + return stress_cart_tensor.to_cartesian(t, rtp=stress_rtp) + + +def cartesian_to_spherical(t: torch.Tensor): + """ + Convert cartesian notation to spherical notation + """ + stress_cart_tensor = CartesianTensor("ij=ji") + stress_rtp = stress_cart_tensor.reduced_tensor_products() + return stress_cart_tensor.to_cartesian(t, rtp=stress_rtp) + + +def voigt_to_matrix(t: torch.Tensor): + """ + Convert voigt notation to matrix notation + :param t: (6,) tensor or (3, 3) tensor or (9,) tensor + :return: (3, 3) tensor + """ + if t.shape == (3, 3): + return t + if t.shape == (6,): + return torch.tensor( + [ + [t[0], t[5], t[4]], + [t[5], t[1], t[3]], + [t[4], t[3], t[2]], + ], + dtype=t.dtype, + ) + if t.shape == (9,): + return t.view(3, 3) + + raise ValueError( + f"Stress tensor must be of shape (6,) or (3, 3), or (9,) but has shape {t.shape}" + ) + + +def init_wandb(project: str, entity: str, name: str, config: dict, directory: str): + import wandb + + wandb.init( + project=project, + entity=entity, + name=name, + config=config, + dir=directory, + resume="allow", + ) + + +@contextmanager +def default_dtype(dtype: Union[torch.dtype, str]): + """Context manager for configuring the default_dtype used by torch + + Args: + dtype (torch.dtype|str): the default dtype to use within this context manager + """ + init = torch.get_default_dtype() + if isinstance(dtype, str): + set_default_dtype(dtype) + else: + torch.set_default_dtype(dtype) + + yield + + torch.set_default_dtype(init) diff --git a/mace-bench/3rdparty/mace/mace/tools/train.py b/mace-bench/3rdparty/mace/mace/tools/train.py index 0c3916b..c7c17e1 100644 --- a/mace-bench/3rdparty/mace/mace/tools/train.py +++ b/mace-bench/3rdparty/mace/mace/tools/train.py @@ -1,669 +1,669 @@ -########################################################################################### -# Training script -# Authors: Ilyes Batatia, Gregor Simm, David Kovacs -# This program is distributed under the MIT License (see MIT.md) -########################################################################################### - -import dataclasses -import logging -import time -from contextlib import nullcontext -from typing import Any, Dict, List, Optional, Tuple, Union - -import numpy as np -import torch -import torch.distributed -from torch.nn.parallel import DistributedDataParallel -from torch.optim import LBFGS -from torch.optim.swa_utils import SWALR, AveragedModel -from torch.utils.data import DataLoader -from torch.utils.data.distributed import DistributedSampler -from torch_ema import ExponentialMovingAverage -from torchmetrics import Metric - -from mace.cli.visualise_train import TrainingPlotter - -from . import torch_geometric -from .checkpoint import CheckpointHandler, CheckpointState -from .torch_tools import to_numpy -from .utils import ( - MetricsLogger, - compute_mae, - compute_q95, - compute_rel_mae, - compute_rel_rmse, - compute_rmse, -) - - -@dataclasses.dataclass -class SWAContainer: - model: AveragedModel - scheduler: SWALR - start: int - loss_fn: torch.nn.Module - - -def valid_err_log( - valid_loss, - eval_metrics, - logger, - log_errors, - epoch=None, - valid_loader_name="Default", -): - eval_metrics["mode"] = "eval" - eval_metrics["epoch"] = epoch - eval_metrics["head"] = valid_loader_name - logger.log(eval_metrics) - if epoch is None: - inintial_phrase = "Initial" - else: - inintial_phrase = f"Epoch {epoch}" - if log_errors == "PerAtomRMSE": - error_e = eval_metrics["rmse_e_per_atom"] * 1e3 - error_f = eval_metrics["rmse_f"] * 1e3 - logging.info( - f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, RMSE_E_per_atom={error_e:8.2f} meV, RMSE_F={error_f:8.2f} meV / A" - ) - elif ( - log_errors == "PerAtomRMSEstressvirials" - and eval_metrics["rmse_stress"] is not None - ): - error_e = eval_metrics["rmse_e_per_atom"] * 1e3 - error_f = eval_metrics["rmse_f"] * 1e3 - error_stress = eval_metrics["rmse_stress"] * 1e3 - logging.info( - f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, RMSE_E_per_atom={error_e:8.2f} meV, RMSE_F={error_f:8.2f} meV / A, RMSE_stress={error_stress:8.2f} meV / A^3", - ) - elif ( - log_errors == "PerAtomRMSEstressvirials" - and eval_metrics["rmse_virials_per_atom"] is not None - ): - error_e = eval_metrics["rmse_e_per_atom"] * 1e3 - error_f = eval_metrics["rmse_f"] * 1e3 - error_virials = eval_metrics["rmse_virials_per_atom"] * 1e3 - logging.info( - f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, RMSE_E_per_atom={error_e:8.2f} meV, RMSE_F={error_f:8.2f} meV / A, RMSE_virials_per_atom={error_virials:8.2f} meV", - ) - elif ( - log_errors == "PerAtomMAEstressvirials" - and eval_metrics["mae_stress_per_atom"] is not None - ): - error_e = eval_metrics["mae_e_per_atom"] * 1e3 - error_f = eval_metrics["mae_f"] * 1e3 - error_stress = eval_metrics["mae_stress"] * 1e3 - logging.info( - f"{inintial_phrase}: loss={valid_loss:8.8f}, MAE_E_per_atom={error_e:8.2f} meV, MAE_F={error_f:8.2f} meV / A, MAE_stress={error_stress:8.2f} meV / A^3" - ) - elif ( - log_errors == "PerAtomMAEstressvirials" - and eval_metrics["mae_virials_per_atom"] is not None - ): - error_e = eval_metrics["mae_e_per_atom"] * 1e3 - error_f = eval_metrics["mae_f"] * 1e3 - error_virials = eval_metrics["mae_virials"] * 1e3 - logging.info( - f"{inintial_phrase}: loss={valid_loss:8.8f}, MAE_E_per_atom={error_e:8.2f} meV, MAE_F={error_f:8.2f} meV / A, MAE_virials={error_virials:8.2f} meV" - ) - elif log_errors == "TotalRMSE": - error_e = eval_metrics["rmse_e"] * 1e3 - error_f = eval_metrics["rmse_f"] * 1e3 - logging.info( - f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, RMSE_E={error_e:8.2f} meV, RMSE_F={error_f:8.2f} meV / A", - ) - elif log_errors == "PerAtomMAE": - error_e = eval_metrics["mae_e_per_atom"] * 1e3 - error_f = eval_metrics["mae_f"] * 1e3 - logging.info( - f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, MAE_E_per_atom={error_e:8.2f} meV, MAE_F={error_f:8.2f} meV / A", - ) - elif log_errors == "TotalMAE": - error_e = eval_metrics["mae_e"] * 1e3 - error_f = eval_metrics["mae_f"] * 1e3 - logging.info( - f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, MAE_E={error_e:8.2f} meV, MAE_F={error_f:8.2f} meV / A", - ) - elif log_errors == "DipoleRMSE": - error_mu = eval_metrics["rmse_mu_per_atom"] * 1e3 - logging.info( - f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, RMSE_MU_per_atom={error_mu:8.2f} mDebye", - ) - elif log_errors == "EnergyDipoleRMSE": - error_e = eval_metrics["rmse_e_per_atom"] * 1e3 - error_f = eval_metrics["rmse_f"] * 1e3 - error_mu = eval_metrics["rmse_mu_per_atom"] * 1e3 - logging.info( - f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, RMSE_E_per_atom={error_e:8.2f} meV, RMSE_F={error_f:8.2f} meV / A, RMSE_Mu_per_atom={error_mu:8.2f} mDebye", - ) - - -def train( - model: torch.nn.Module, - loss_fn: torch.nn.Module, - train_loader: DataLoader, - valid_loaders: Dict[str, DataLoader], - optimizer: torch.optim.Optimizer, - lr_scheduler: torch.optim.lr_scheduler.ExponentialLR, - start_epoch: int, - max_num_epochs: int, - patience: int, - checkpoint_handler: CheckpointHandler, - logger: MetricsLogger, - eval_interval: int, - output_args: Dict[str, bool], - device: torch.device, - log_errors: str, - swa: Optional[SWAContainer] = None, - ema: Optional[ExponentialMovingAverage] = None, - max_grad_norm: Optional[float] = 10.0, - log_wandb: bool = False, - distributed: bool = False, - save_all_checkpoints: bool = False, - plotter: TrainingPlotter = None, - distributed_model: Optional[DistributedDataParallel] = None, - train_sampler: Optional[DistributedSampler] = None, - rank: Optional[int] = 0, -): - lowest_loss = np.inf - valid_loss = np.inf - patience_counter = 0 - swa_start = True - keep_last = False - if log_wandb: - import wandb - - if max_grad_norm is not None: - logging.info(f"Using gradient clipping with tolerance={max_grad_norm:.3f}") - - logging.info("") - logging.info("===========TRAINING===========") - logging.info("Started training, reporting errors on validation set") - logging.info("Loss metrics on validation set") - epoch = start_epoch - - # log validation loss before _any_ training - for valid_loader_name, valid_loader in valid_loaders.items(): - valid_loss_head, eval_metrics = evaluate( - model=model, - loss_fn=loss_fn, - data_loader=valid_loader, - output_args=output_args, - device=device, - ) - valid_err_log( - valid_loss_head, eval_metrics, logger, log_errors, None, valid_loader_name - ) - valid_loss = valid_loss_head # consider only the last head for the checkpoint - - while epoch < max_num_epochs: - # LR scheduler and SWA update - if swa is None or epoch < swa.start: - if epoch > start_epoch: - lr_scheduler.step( - metrics=valid_loss - ) # Can break if exponential LR, TODO fix that! - else: - if swa_start: - logging.info("Changing loss based on Stage Two Weights") - lowest_loss = np.inf - swa_start = False - keep_last = True - loss_fn = swa.loss_fn - swa.model.update_parameters(model) - if epoch > start_epoch: - swa.scheduler.step() - - # Train - if distributed: - train_sampler.set_epoch(epoch) - if "ScheduleFree" in type(optimizer).__name__: - optimizer.train() - train_one_epoch( - model=model, - loss_fn=loss_fn, - data_loader=train_loader, - optimizer=optimizer, - epoch=epoch, - output_args=output_args, - max_grad_norm=max_grad_norm, - ema=ema, - logger=logger, - device=device, - distributed=distributed, - distributed_model=distributed_model, - rank=rank, - ) - if distributed: - torch.distributed.barrier() - - # Validate - if epoch % eval_interval == 0: - model_to_evaluate = ( - model if distributed_model is None else distributed_model - ) - param_context = ( - ema.average_parameters() if ema is not None else nullcontext() - ) - if "ScheduleFree" in type(optimizer).__name__: - optimizer.eval() - with param_context: - wandb_log_dict = {} - for valid_loader_name, valid_loader in valid_loaders.items(): - valid_loss_head, eval_metrics = evaluate( - model=model_to_evaluate, - loss_fn=loss_fn, - data_loader=valid_loader, - output_args=output_args, - device=device, - ) - if rank == 0: - valid_err_log( - valid_loss_head, - eval_metrics, - logger, - log_errors, - epoch, - valid_loader_name, - ) - if log_wandb: - wandb_log_dict[valid_loader_name] = { - "epoch": epoch, - "valid_loss": valid_loss_head, - "valid_rmse_e_per_atom": eval_metrics[ - "rmse_e_per_atom" - ], - "valid_rmse_f": eval_metrics["rmse_f"], - } - if plotter and epoch % plotter.plot_frequency == 0: - try: - plotter.plot(epoch, model_to_evaluate, rank) - except Exception as e: # pylint: disable=broad-except - logging.debug(f"Plotting failed: {e}") - valid_loss = ( - valid_loss_head # consider only the last head for the checkpoint - ) - if log_wandb: - wandb.log(wandb_log_dict) - if rank == 0: - if valid_loss >= lowest_loss: - patience_counter += 1 - if patience_counter >= patience: - if swa is not None and epoch < swa.start: - logging.info( - f"Stopping optimization after {patience_counter} epochs without improvement and starting Stage Two" - ) - epoch = swa.start - else: - logging.info( - f"Stopping optimization after {patience_counter} epochs without improvement" - ) - break - if save_all_checkpoints: - param_context = ( - ema.average_parameters() - if ema is not None - else nullcontext() - ) - with param_context: - checkpoint_handler.save( - state=CheckpointState(model, optimizer, lr_scheduler), - epochs=epoch, - keep_last=True, - ) - else: - lowest_loss = valid_loss - patience_counter = 0 - param_context = ( - ema.average_parameters() if ema is not None else nullcontext() - ) - with param_context: - checkpoint_handler.save( - state=CheckpointState(model, optimizer, lr_scheduler), - epochs=epoch, - keep_last=keep_last, - ) - keep_last = False or save_all_checkpoints - if distributed: - torch.distributed.barrier() - epoch += 1 - - logging.info("Training complete") - - -def train_one_epoch( - model: torch.nn.Module, - loss_fn: torch.nn.Module, - data_loader: DataLoader, - optimizer: torch.optim.Optimizer, - epoch: int, - output_args: Dict[str, bool], - max_grad_norm: Optional[float], - ema: Optional[ExponentialMovingAverage], - logger: MetricsLogger, - device: torch.device, - distributed: bool, - distributed_model: Optional[DistributedDataParallel] = None, - rank: Optional[int] = 0, -) -> None: - model_to_train = model if distributed_model is None else distributed_model - - if isinstance(optimizer, LBFGS): - _, opt_metrics = take_step_lbfgs( - model=model_to_train, - loss_fn=loss_fn, - data_loader=data_loader, - optimizer=optimizer, - ema=ema, - output_args=output_args, - max_grad_norm=max_grad_norm, - device=device, - distributed=distributed, - rank=rank, - ) - opt_metrics["mode"] = "opt" - opt_metrics["epoch"] = epoch - if rank == 0: - logger.log(opt_metrics) - else: - for batch in data_loader: - _, opt_metrics = take_step( - model=model_to_train, - loss_fn=loss_fn, - batch=batch, - optimizer=optimizer, - ema=ema, - output_args=output_args, - max_grad_norm=max_grad_norm, - device=device, - ) - opt_metrics["mode"] = "opt" - opt_metrics["epoch"] = epoch - if rank == 0: - logger.log(opt_metrics) - - -def take_step( - model: torch.nn.Module, - loss_fn: torch.nn.Module, - batch: torch_geometric.batch.Batch, - optimizer: torch.optim.Optimizer, - ema: Optional[ExponentialMovingAverage], - output_args: Dict[str, bool], - max_grad_norm: Optional[float], - device: torch.device, -) -> Tuple[float, Dict[str, Any]]: - start_time = time.time() - batch = batch.to(device) - batch_dict = batch.to_dict() - - def closure(): - optimizer.zero_grad(set_to_none=True) - output = model( - batch_dict, - training=True, - compute_force=output_args["forces"], - compute_virials=output_args["virials"], - compute_stress=output_args["stress"], - ) - loss = loss_fn(pred=output, ref=batch) - loss.backward() - if max_grad_norm is not None: - torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm) - - return loss - - loss = closure() - optimizer.step() - - if ema is not None: - ema.update() - - loss_dict = { - "loss": to_numpy(loss), - "time": time.time() - start_time, - } - - return loss, loss_dict - - -def take_step_lbfgs( - model: torch.nn.Module, - loss_fn: torch.nn.Module, - data_loader: DataLoader, - optimizer: torch.optim.Optimizer, - ema: Optional[ExponentialMovingAverage], - output_args: Dict[str, bool], - max_grad_norm: Optional[float], - device: torch.device, - distributed: bool, - rank: int, -) -> Tuple[float, Dict[str, Any]]: - start_time = time.time() - logging.debug( - f"Max Allocated: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB" - ) - - total_sample_count = 0 - for batch in data_loader: - total_sample_count += batch.num_graphs - - if distributed: - global_sample_count = torch.tensor(total_sample_count, device=device) - torch.distributed.all_reduce( - global_sample_count, op=torch.distributed.ReduceOp.SUM - ) - total_sample_count = global_sample_count.item() - - signal = torch.zeros(1, device=device) if distributed else None - - def closure(): - if distributed: - if rank == 0: - signal.fill_(1) - torch.distributed.broadcast(signal, src=0) - - for param in model.parameters(): - torch.distributed.broadcast(param.data, src=0) - - optimizer.zero_grad(set_to_none=True) - total_loss = torch.tensor(0.0, device=device) - - # Process each batch and then collect the results we pass to the optimizer - for batch in data_loader: - batch = batch.to(device) - batch_dict = batch.to_dict() - output = model( - batch_dict, - training=True, - compute_force=output_args["forces"], - compute_virials=output_args["virials"], - compute_stress=output_args["stress"], - ) - batch_loss = loss_fn(pred=output, ref=batch) - batch_loss = batch_loss * (batch.num_graphs / total_sample_count) - - batch_loss.backward() - total_loss += batch_loss - - if max_grad_norm is not None: - torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm) - - if distributed: - torch.distributed.all_reduce(total_loss, op=torch.distributed.ReduceOp.SUM) - return total_loss - - if distributed: - if rank == 0: - loss = optimizer.step(closure) - signal.fill_(0) - torch.distributed.broadcast(signal, src=0) - else: - while True: - # Other ranks wait for signals from rank 0 - torch.distributed.broadcast(signal, src=0) - if signal.item() == 0: - break - if signal.item() == 1: - loss = closure() - - for param in model.parameters(): - torch.distributed.broadcast(param.data, src=0) - else: - loss = optimizer.step(closure) - - if ema is not None: - ema.update() - - loss_dict = { - "loss": to_numpy(loss), - "time": time.time() - start_time, - } - - return loss, loss_dict - - -def evaluate( - model: torch.nn.Module, - loss_fn: torch.nn.Module, - data_loader: DataLoader, - output_args: Dict[str, bool], - device: torch.device, -) -> Tuple[float, Dict[str, Any]]: - for param in model.parameters(): - param.requires_grad = False - - metrics = MACELoss(loss_fn=loss_fn).to(device) - - start_time = time.time() - for batch in data_loader: - batch = batch.to(device) - batch_dict = batch.to_dict() - output = model( - batch_dict, - training=False, - compute_force=output_args["forces"], - compute_virials=output_args["virials"], - compute_stress=output_args["stress"], - ) - avg_loss, aux = metrics(batch, output) - - avg_loss, aux = metrics.compute() - aux["time"] = time.time() - start_time - metrics.reset() - - for param in model.parameters(): - param.requires_grad = True - - return avg_loss, aux - - -class MACELoss(Metric): - def __init__(self, loss_fn: torch.nn.Module): - super().__init__() - self.loss_fn = loss_fn - self.add_state("total_loss", default=torch.tensor(0.0), dist_reduce_fx="sum") - self.add_state("num_data", default=torch.tensor(0.0), dist_reduce_fx="sum") - self.add_state("E_computed", default=torch.tensor(0.0), dist_reduce_fx="sum") - self.add_state("delta_es", default=[], dist_reduce_fx="cat") - self.add_state("delta_es_per_atom", default=[], dist_reduce_fx="cat") - self.add_state("Fs_computed", default=torch.tensor(0.0), dist_reduce_fx="sum") - self.add_state("fs", default=[], dist_reduce_fx="cat") - self.add_state("delta_fs", default=[], dist_reduce_fx="cat") - self.add_state( - "stress_computed", default=torch.tensor(0.0), dist_reduce_fx="sum" - ) - self.add_state("delta_stress", default=[], dist_reduce_fx="cat") - self.add_state( - "virials_computed", default=torch.tensor(0.0), dist_reduce_fx="sum" - ) - self.add_state("delta_virials", default=[], dist_reduce_fx="cat") - self.add_state("delta_virials_per_atom", default=[], dist_reduce_fx="cat") - self.add_state("Mus_computed", default=torch.tensor(0.0), dist_reduce_fx="sum") - self.add_state("mus", default=[], dist_reduce_fx="cat") - self.add_state("delta_mus", default=[], dist_reduce_fx="cat") - self.add_state("delta_mus_per_atom", default=[], dist_reduce_fx="cat") - - def update(self, batch, output): # pylint: disable=arguments-differ - loss = self.loss_fn(pred=output, ref=batch) - self.total_loss += loss - self.num_data += batch.num_graphs - - if output.get("energy") is not None and batch.energy is not None: - self.E_computed += 1.0 - self.delta_es.append(batch.energy - output["energy"]) - self.delta_es_per_atom.append( - (batch.energy - output["energy"]) / (batch.ptr[1:] - batch.ptr[:-1]) - ) - if output.get("forces") is not None and batch.forces is not None: - self.Fs_computed += 1.0 - self.fs.append(batch.forces) - self.delta_fs.append(batch.forces - output["forces"]) - if output.get("stress") is not None and batch.stress is not None: - self.stress_computed += 1.0 - self.delta_stress.append(batch.stress - output["stress"]) - if output.get("virials") is not None and batch.virials is not None: - self.virials_computed += 1.0 - self.delta_virials.append(batch.virials - output["virials"]) - self.delta_virials_per_atom.append( - (batch.virials - output["virials"]) - / (batch.ptr[1:] - batch.ptr[:-1]).view(-1, 1, 1) - ) - if output.get("dipole") is not None and batch.dipole is not None: - self.Mus_computed += 1.0 - self.mus.append(batch.dipole) - self.delta_mus.append(batch.dipole - output["dipole"]) - self.delta_mus_per_atom.append( - (batch.dipole - output["dipole"]) - / (batch.ptr[1:] - batch.ptr[:-1]).unsqueeze(-1) - ) - - def convert(self, delta: Union[torch.Tensor, List[torch.Tensor]]) -> np.ndarray: - if isinstance(delta, list): - delta = torch.cat(delta) - return to_numpy(delta) - - def compute(self): - aux = {} - aux["loss"] = to_numpy(self.total_loss / self.num_data).item() - if self.E_computed: - delta_es = self.convert(self.delta_es) - delta_es_per_atom = self.convert(self.delta_es_per_atom) - aux["mae_e"] = compute_mae(delta_es) - aux["mae_e_per_atom"] = compute_mae(delta_es_per_atom) - aux["rmse_e"] = compute_rmse(delta_es) - aux["rmse_e_per_atom"] = compute_rmse(delta_es_per_atom) - aux["q95_e"] = compute_q95(delta_es) - if self.Fs_computed: - fs = self.convert(self.fs) - delta_fs = self.convert(self.delta_fs) - aux["mae_f"] = compute_mae(delta_fs) - aux["rel_mae_f"] = compute_rel_mae(delta_fs, fs) - aux["rmse_f"] = compute_rmse(delta_fs) - aux["rel_rmse_f"] = compute_rel_rmse(delta_fs, fs) - aux["q95_f"] = compute_q95(delta_fs) - if self.stress_computed: - delta_stress = self.convert(self.delta_stress) - aux["mae_stress"] = compute_mae(delta_stress) - aux["rmse_stress"] = compute_rmse(delta_stress) - aux["q95_stress"] = compute_q95(delta_stress) - if self.virials_computed: - delta_virials = self.convert(self.delta_virials) - delta_virials_per_atom = self.convert(self.delta_virials_per_atom) - aux["mae_virials"] = compute_mae(delta_virials) - aux["rmse_virials"] = compute_rmse(delta_virials) - aux["rmse_virials_per_atom"] = compute_rmse(delta_virials_per_atom) - aux["q95_virials"] = compute_q95(delta_virials) - if self.Mus_computed: - mus = self.convert(self.mus) - delta_mus = self.convert(self.delta_mus) - delta_mus_per_atom = self.convert(self.delta_mus_per_atom) - aux["mae_mu"] = compute_mae(delta_mus) - aux["mae_mu_per_atom"] = compute_mae(delta_mus_per_atom) - aux["rel_mae_mu"] = compute_rel_mae(delta_mus, mus) - aux["rmse_mu"] = compute_rmse(delta_mus) - aux["rmse_mu_per_atom"] = compute_rmse(delta_mus_per_atom) - aux["rel_rmse_mu"] = compute_rel_rmse(delta_mus, mus) - aux["q95_mu"] = compute_q95(delta_mus) - - return aux["loss"], aux +########################################################################################### +# Training script +# Authors: Ilyes Batatia, Gregor Simm, David Kovacs +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +import dataclasses +import logging +import time +from contextlib import nullcontext +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.distributed +from torch.nn.parallel import DistributedDataParallel +from torch.optim import LBFGS +from torch.optim.swa_utils import SWALR, AveragedModel +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler +from torch_ema import ExponentialMovingAverage +from torchmetrics import Metric + +from mace.cli.visualise_train import TrainingPlotter + +from . import torch_geometric +from .checkpoint import CheckpointHandler, CheckpointState +from .torch_tools import to_numpy +from .utils import ( + MetricsLogger, + compute_mae, + compute_q95, + compute_rel_mae, + compute_rel_rmse, + compute_rmse, +) + + +@dataclasses.dataclass +class SWAContainer: + model: AveragedModel + scheduler: SWALR + start: int + loss_fn: torch.nn.Module + + +def valid_err_log( + valid_loss, + eval_metrics, + logger, + log_errors, + epoch=None, + valid_loader_name="Default", +): + eval_metrics["mode"] = "eval" + eval_metrics["epoch"] = epoch + eval_metrics["head"] = valid_loader_name + logger.log(eval_metrics) + if epoch is None: + inintial_phrase = "Initial" + else: + inintial_phrase = f"Epoch {epoch}" + if log_errors == "PerAtomRMSE": + error_e = eval_metrics["rmse_e_per_atom"] * 1e3 + error_f = eval_metrics["rmse_f"] * 1e3 + logging.info( + f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, RMSE_E_per_atom={error_e:8.2f} meV, RMSE_F={error_f:8.2f} meV / A" + ) + elif ( + log_errors == "PerAtomRMSEstressvirials" + and eval_metrics["rmse_stress"] is not None + ): + error_e = eval_metrics["rmse_e_per_atom"] * 1e3 + error_f = eval_metrics["rmse_f"] * 1e3 + error_stress = eval_metrics["rmse_stress"] * 1e3 + logging.info( + f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, RMSE_E_per_atom={error_e:8.2f} meV, RMSE_F={error_f:8.2f} meV / A, RMSE_stress={error_stress:8.2f} meV / A^3", + ) + elif ( + log_errors == "PerAtomRMSEstressvirials" + and eval_metrics["rmse_virials_per_atom"] is not None + ): + error_e = eval_metrics["rmse_e_per_atom"] * 1e3 + error_f = eval_metrics["rmse_f"] * 1e3 + error_virials = eval_metrics["rmse_virials_per_atom"] * 1e3 + logging.info( + f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, RMSE_E_per_atom={error_e:8.2f} meV, RMSE_F={error_f:8.2f} meV / A, RMSE_virials_per_atom={error_virials:8.2f} meV", + ) + elif ( + log_errors == "PerAtomMAEstressvirials" + and eval_metrics["mae_stress_per_atom"] is not None + ): + error_e = eval_metrics["mae_e_per_atom"] * 1e3 + error_f = eval_metrics["mae_f"] * 1e3 + error_stress = eval_metrics["mae_stress"] * 1e3 + logging.info( + f"{inintial_phrase}: loss={valid_loss:8.8f}, MAE_E_per_atom={error_e:8.2f} meV, MAE_F={error_f:8.2f} meV / A, MAE_stress={error_stress:8.2f} meV / A^3" + ) + elif ( + log_errors == "PerAtomMAEstressvirials" + and eval_metrics["mae_virials_per_atom"] is not None + ): + error_e = eval_metrics["mae_e_per_atom"] * 1e3 + error_f = eval_metrics["mae_f"] * 1e3 + error_virials = eval_metrics["mae_virials"] * 1e3 + logging.info( + f"{inintial_phrase}: loss={valid_loss:8.8f}, MAE_E_per_atom={error_e:8.2f} meV, MAE_F={error_f:8.2f} meV / A, MAE_virials={error_virials:8.2f} meV" + ) + elif log_errors == "TotalRMSE": + error_e = eval_metrics["rmse_e"] * 1e3 + error_f = eval_metrics["rmse_f"] * 1e3 + logging.info( + f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, RMSE_E={error_e:8.2f} meV, RMSE_F={error_f:8.2f} meV / A", + ) + elif log_errors == "PerAtomMAE": + error_e = eval_metrics["mae_e_per_atom"] * 1e3 + error_f = eval_metrics["mae_f"] * 1e3 + logging.info( + f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, MAE_E_per_atom={error_e:8.2f} meV, MAE_F={error_f:8.2f} meV / A", + ) + elif log_errors == "TotalMAE": + error_e = eval_metrics["mae_e"] * 1e3 + error_f = eval_metrics["mae_f"] * 1e3 + logging.info( + f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, MAE_E={error_e:8.2f} meV, MAE_F={error_f:8.2f} meV / A", + ) + elif log_errors == "DipoleRMSE": + error_mu = eval_metrics["rmse_mu_per_atom"] * 1e3 + logging.info( + f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, RMSE_MU_per_atom={error_mu:8.2f} mDebye", + ) + elif log_errors == "EnergyDipoleRMSE": + error_e = eval_metrics["rmse_e_per_atom"] * 1e3 + error_f = eval_metrics["rmse_f"] * 1e3 + error_mu = eval_metrics["rmse_mu_per_atom"] * 1e3 + logging.info( + f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.8f}, RMSE_E_per_atom={error_e:8.2f} meV, RMSE_F={error_f:8.2f} meV / A, RMSE_Mu_per_atom={error_mu:8.2f} mDebye", + ) + + +def train( + model: torch.nn.Module, + loss_fn: torch.nn.Module, + train_loader: DataLoader, + valid_loaders: Dict[str, DataLoader], + optimizer: torch.optim.Optimizer, + lr_scheduler: torch.optim.lr_scheduler.ExponentialLR, + start_epoch: int, + max_num_epochs: int, + patience: int, + checkpoint_handler: CheckpointHandler, + logger: MetricsLogger, + eval_interval: int, + output_args: Dict[str, bool], + device: torch.device, + log_errors: str, + swa: Optional[SWAContainer] = None, + ema: Optional[ExponentialMovingAverage] = None, + max_grad_norm: Optional[float] = 10.0, + log_wandb: bool = False, + distributed: bool = False, + save_all_checkpoints: bool = False, + plotter: TrainingPlotter = None, + distributed_model: Optional[DistributedDataParallel] = None, + train_sampler: Optional[DistributedSampler] = None, + rank: Optional[int] = 0, +): + lowest_loss = np.inf + valid_loss = np.inf + patience_counter = 0 + swa_start = True + keep_last = False + if log_wandb: + import wandb + + if max_grad_norm is not None: + logging.info(f"Using gradient clipping with tolerance={max_grad_norm:.3f}") + + logging.info("") + logging.info("===========TRAINING===========") + logging.info("Started training, reporting errors on validation set") + logging.info("Loss metrics on validation set") + epoch = start_epoch + + # log validation loss before _any_ training + for valid_loader_name, valid_loader in valid_loaders.items(): + valid_loss_head, eval_metrics = evaluate( + model=model, + loss_fn=loss_fn, + data_loader=valid_loader, + output_args=output_args, + device=device, + ) + valid_err_log( + valid_loss_head, eval_metrics, logger, log_errors, None, valid_loader_name + ) + valid_loss = valid_loss_head # consider only the last head for the checkpoint + + while epoch < max_num_epochs: + # LR scheduler and SWA update + if swa is None or epoch < swa.start: + if epoch > start_epoch: + lr_scheduler.step( + metrics=valid_loss + ) # Can break if exponential LR, TODO fix that! + else: + if swa_start: + logging.info("Changing loss based on Stage Two Weights") + lowest_loss = np.inf + swa_start = False + keep_last = True + loss_fn = swa.loss_fn + swa.model.update_parameters(model) + if epoch > start_epoch: + swa.scheduler.step() + + # Train + if distributed: + train_sampler.set_epoch(epoch) + if "ScheduleFree" in type(optimizer).__name__: + optimizer.train() + train_one_epoch( + model=model, + loss_fn=loss_fn, + data_loader=train_loader, + optimizer=optimizer, + epoch=epoch, + output_args=output_args, + max_grad_norm=max_grad_norm, + ema=ema, + logger=logger, + device=device, + distributed=distributed, + distributed_model=distributed_model, + rank=rank, + ) + if distributed: + torch.distributed.barrier() + + # Validate + if epoch % eval_interval == 0: + model_to_evaluate = ( + model if distributed_model is None else distributed_model + ) + param_context = ( + ema.average_parameters() if ema is not None else nullcontext() + ) + if "ScheduleFree" in type(optimizer).__name__: + optimizer.eval() + with param_context: + wandb_log_dict = {} + for valid_loader_name, valid_loader in valid_loaders.items(): + valid_loss_head, eval_metrics = evaluate( + model=model_to_evaluate, + loss_fn=loss_fn, + data_loader=valid_loader, + output_args=output_args, + device=device, + ) + if rank == 0: + valid_err_log( + valid_loss_head, + eval_metrics, + logger, + log_errors, + epoch, + valid_loader_name, + ) + if log_wandb: + wandb_log_dict[valid_loader_name] = { + "epoch": epoch, + "valid_loss": valid_loss_head, + "valid_rmse_e_per_atom": eval_metrics[ + "rmse_e_per_atom" + ], + "valid_rmse_f": eval_metrics["rmse_f"], + } + if plotter and epoch % plotter.plot_frequency == 0: + try: + plotter.plot(epoch, model_to_evaluate, rank) + except Exception as e: # pylint: disable=broad-except + logging.debug(f"Plotting failed: {e}") + valid_loss = ( + valid_loss_head # consider only the last head for the checkpoint + ) + if log_wandb: + wandb.log(wandb_log_dict) + if rank == 0: + if valid_loss >= lowest_loss: + patience_counter += 1 + if patience_counter >= patience: + if swa is not None and epoch < swa.start: + logging.info( + f"Stopping optimization after {patience_counter} epochs without improvement and starting Stage Two" + ) + epoch = swa.start + else: + logging.info( + f"Stopping optimization after {patience_counter} epochs without improvement" + ) + break + if save_all_checkpoints: + param_context = ( + ema.average_parameters() + if ema is not None + else nullcontext() + ) + with param_context: + checkpoint_handler.save( + state=CheckpointState(model, optimizer, lr_scheduler), + epochs=epoch, + keep_last=True, + ) + else: + lowest_loss = valid_loss + patience_counter = 0 + param_context = ( + ema.average_parameters() if ema is not None else nullcontext() + ) + with param_context: + checkpoint_handler.save( + state=CheckpointState(model, optimizer, lr_scheduler), + epochs=epoch, + keep_last=keep_last, + ) + keep_last = False or save_all_checkpoints + if distributed: + torch.distributed.barrier() + epoch += 1 + + logging.info("Training complete") + + +def train_one_epoch( + model: torch.nn.Module, + loss_fn: torch.nn.Module, + data_loader: DataLoader, + optimizer: torch.optim.Optimizer, + epoch: int, + output_args: Dict[str, bool], + max_grad_norm: Optional[float], + ema: Optional[ExponentialMovingAverage], + logger: MetricsLogger, + device: torch.device, + distributed: bool, + distributed_model: Optional[DistributedDataParallel] = None, + rank: Optional[int] = 0, +) -> None: + model_to_train = model if distributed_model is None else distributed_model + + if isinstance(optimizer, LBFGS): + _, opt_metrics = take_step_lbfgs( + model=model_to_train, + loss_fn=loss_fn, + data_loader=data_loader, + optimizer=optimizer, + ema=ema, + output_args=output_args, + max_grad_norm=max_grad_norm, + device=device, + distributed=distributed, + rank=rank, + ) + opt_metrics["mode"] = "opt" + opt_metrics["epoch"] = epoch + if rank == 0: + logger.log(opt_metrics) + else: + for batch in data_loader: + _, opt_metrics = take_step( + model=model_to_train, + loss_fn=loss_fn, + batch=batch, + optimizer=optimizer, + ema=ema, + output_args=output_args, + max_grad_norm=max_grad_norm, + device=device, + ) + opt_metrics["mode"] = "opt" + opt_metrics["epoch"] = epoch + if rank == 0: + logger.log(opt_metrics) + + +def take_step( + model: torch.nn.Module, + loss_fn: torch.nn.Module, + batch: torch_geometric.batch.Batch, + optimizer: torch.optim.Optimizer, + ema: Optional[ExponentialMovingAverage], + output_args: Dict[str, bool], + max_grad_norm: Optional[float], + device: torch.device, +) -> Tuple[float, Dict[str, Any]]: + start_time = time.time() + batch = batch.to(device) + batch_dict = batch.to_dict() + + def closure(): + optimizer.zero_grad(set_to_none=True) + output = model( + batch_dict, + training=True, + compute_force=output_args["forces"], + compute_virials=output_args["virials"], + compute_stress=output_args["stress"], + ) + loss = loss_fn(pred=output, ref=batch) + loss.backward() + if max_grad_norm is not None: + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm) + + return loss + + loss = closure() + optimizer.step() + + if ema is not None: + ema.update() + + loss_dict = { + "loss": to_numpy(loss), + "time": time.time() - start_time, + } + + return loss, loss_dict + + +def take_step_lbfgs( + model: torch.nn.Module, + loss_fn: torch.nn.Module, + data_loader: DataLoader, + optimizer: torch.optim.Optimizer, + ema: Optional[ExponentialMovingAverage], + output_args: Dict[str, bool], + max_grad_norm: Optional[float], + device: torch.device, + distributed: bool, + rank: int, +) -> Tuple[float, Dict[str, Any]]: + start_time = time.time() + logging.debug( + f"Max Allocated: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB" + ) + + total_sample_count = 0 + for batch in data_loader: + total_sample_count += batch.num_graphs + + if distributed: + global_sample_count = torch.tensor(total_sample_count, device=device) + torch.distributed.all_reduce( + global_sample_count, op=torch.distributed.ReduceOp.SUM + ) + total_sample_count = global_sample_count.item() + + signal = torch.zeros(1, device=device) if distributed else None + + def closure(): + if distributed: + if rank == 0: + signal.fill_(1) + torch.distributed.broadcast(signal, src=0) + + for param in model.parameters(): + torch.distributed.broadcast(param.data, src=0) + + optimizer.zero_grad(set_to_none=True) + total_loss = torch.tensor(0.0, device=device) + + # Process each batch and then collect the results we pass to the optimizer + for batch in data_loader: + batch = batch.to(device) + batch_dict = batch.to_dict() + output = model( + batch_dict, + training=True, + compute_force=output_args["forces"], + compute_virials=output_args["virials"], + compute_stress=output_args["stress"], + ) + batch_loss = loss_fn(pred=output, ref=batch) + batch_loss = batch_loss * (batch.num_graphs / total_sample_count) + + batch_loss.backward() + total_loss += batch_loss + + if max_grad_norm is not None: + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm) + + if distributed: + torch.distributed.all_reduce(total_loss, op=torch.distributed.ReduceOp.SUM) + return total_loss + + if distributed: + if rank == 0: + loss = optimizer.step(closure) + signal.fill_(0) + torch.distributed.broadcast(signal, src=0) + else: + while True: + # Other ranks wait for signals from rank 0 + torch.distributed.broadcast(signal, src=0) + if signal.item() == 0: + break + if signal.item() == 1: + loss = closure() + + for param in model.parameters(): + torch.distributed.broadcast(param.data, src=0) + else: + loss = optimizer.step(closure) + + if ema is not None: + ema.update() + + loss_dict = { + "loss": to_numpy(loss), + "time": time.time() - start_time, + } + + return loss, loss_dict + + +def evaluate( + model: torch.nn.Module, + loss_fn: torch.nn.Module, + data_loader: DataLoader, + output_args: Dict[str, bool], + device: torch.device, +) -> Tuple[float, Dict[str, Any]]: + for param in model.parameters(): + param.requires_grad = False + + metrics = MACELoss(loss_fn=loss_fn).to(device) + + start_time = time.time() + for batch in data_loader: + batch = batch.to(device) + batch_dict = batch.to_dict() + output = model( + batch_dict, + training=False, + compute_force=output_args["forces"], + compute_virials=output_args["virials"], + compute_stress=output_args["stress"], + ) + avg_loss, aux = metrics(batch, output) + + avg_loss, aux = metrics.compute() + aux["time"] = time.time() - start_time + metrics.reset() + + for param in model.parameters(): + param.requires_grad = True + + return avg_loss, aux + + +class MACELoss(Metric): + def __init__(self, loss_fn: torch.nn.Module): + super().__init__() + self.loss_fn = loss_fn + self.add_state("total_loss", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("num_data", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("E_computed", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("delta_es", default=[], dist_reduce_fx="cat") + self.add_state("delta_es_per_atom", default=[], dist_reduce_fx="cat") + self.add_state("Fs_computed", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("fs", default=[], dist_reduce_fx="cat") + self.add_state("delta_fs", default=[], dist_reduce_fx="cat") + self.add_state( + "stress_computed", default=torch.tensor(0.0), dist_reduce_fx="sum" + ) + self.add_state("delta_stress", default=[], dist_reduce_fx="cat") + self.add_state( + "virials_computed", default=torch.tensor(0.0), dist_reduce_fx="sum" + ) + self.add_state("delta_virials", default=[], dist_reduce_fx="cat") + self.add_state("delta_virials_per_atom", default=[], dist_reduce_fx="cat") + self.add_state("Mus_computed", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("mus", default=[], dist_reduce_fx="cat") + self.add_state("delta_mus", default=[], dist_reduce_fx="cat") + self.add_state("delta_mus_per_atom", default=[], dist_reduce_fx="cat") + + def update(self, batch, output): # pylint: disable=arguments-differ + loss = self.loss_fn(pred=output, ref=batch) + self.total_loss += loss + self.num_data += batch.num_graphs + + if output.get("energy") is not None and batch.energy is not None: + self.E_computed += 1.0 + self.delta_es.append(batch.energy - output["energy"]) + self.delta_es_per_atom.append( + (batch.energy - output["energy"]) / (batch.ptr[1:] - batch.ptr[:-1]) + ) + if output.get("forces") is not None and batch.forces is not None: + self.Fs_computed += 1.0 + self.fs.append(batch.forces) + self.delta_fs.append(batch.forces - output["forces"]) + if output.get("stress") is not None and batch.stress is not None: + self.stress_computed += 1.0 + self.delta_stress.append(batch.stress - output["stress"]) + if output.get("virials") is not None and batch.virials is not None: + self.virials_computed += 1.0 + self.delta_virials.append(batch.virials - output["virials"]) + self.delta_virials_per_atom.append( + (batch.virials - output["virials"]) + / (batch.ptr[1:] - batch.ptr[:-1]).view(-1, 1, 1) + ) + if output.get("dipole") is not None and batch.dipole is not None: + self.Mus_computed += 1.0 + self.mus.append(batch.dipole) + self.delta_mus.append(batch.dipole - output["dipole"]) + self.delta_mus_per_atom.append( + (batch.dipole - output["dipole"]) + / (batch.ptr[1:] - batch.ptr[:-1]).unsqueeze(-1) + ) + + def convert(self, delta: Union[torch.Tensor, List[torch.Tensor]]) -> np.ndarray: + if isinstance(delta, list): + delta = torch.cat(delta) + return to_numpy(delta) + + def compute(self): + aux = {} + aux["loss"] = to_numpy(self.total_loss / self.num_data).item() + if self.E_computed: + delta_es = self.convert(self.delta_es) + delta_es_per_atom = self.convert(self.delta_es_per_atom) + aux["mae_e"] = compute_mae(delta_es) + aux["mae_e_per_atom"] = compute_mae(delta_es_per_atom) + aux["rmse_e"] = compute_rmse(delta_es) + aux["rmse_e_per_atom"] = compute_rmse(delta_es_per_atom) + aux["q95_e"] = compute_q95(delta_es) + if self.Fs_computed: + fs = self.convert(self.fs) + delta_fs = self.convert(self.delta_fs) + aux["mae_f"] = compute_mae(delta_fs) + aux["rel_mae_f"] = compute_rel_mae(delta_fs, fs) + aux["rmse_f"] = compute_rmse(delta_fs) + aux["rel_rmse_f"] = compute_rel_rmse(delta_fs, fs) + aux["q95_f"] = compute_q95(delta_fs) + if self.stress_computed: + delta_stress = self.convert(self.delta_stress) + aux["mae_stress"] = compute_mae(delta_stress) + aux["rmse_stress"] = compute_rmse(delta_stress) + aux["q95_stress"] = compute_q95(delta_stress) + if self.virials_computed: + delta_virials = self.convert(self.delta_virials) + delta_virials_per_atom = self.convert(self.delta_virials_per_atom) + aux["mae_virials"] = compute_mae(delta_virials) + aux["rmse_virials"] = compute_rmse(delta_virials) + aux["rmse_virials_per_atom"] = compute_rmse(delta_virials_per_atom) + aux["q95_virials"] = compute_q95(delta_virials) + if self.Mus_computed: + mus = self.convert(self.mus) + delta_mus = self.convert(self.delta_mus) + delta_mus_per_atom = self.convert(self.delta_mus_per_atom) + aux["mae_mu"] = compute_mae(delta_mus) + aux["mae_mu_per_atom"] = compute_mae(delta_mus_per_atom) + aux["rel_mae_mu"] = compute_rel_mae(delta_mus, mus) + aux["rmse_mu"] = compute_rmse(delta_mus) + aux["rmse_mu_per_atom"] = compute_rmse(delta_mus_per_atom) + aux["rel_rmse_mu"] = compute_rel_rmse(delta_mus, mus) + aux["q95_mu"] = compute_q95(delta_mus) + + return aux["loss"], aux diff --git a/mace-bench/3rdparty/mace/mace/tools/utils.py b/mace-bench/3rdparty/mace/mace/tools/utils.py index 12878e1..1b1b55b 100644 --- a/mace-bench/3rdparty/mace/mace/tools/utils.py +++ b/mace-bench/3rdparty/mace/mace/tools/utils.py @@ -1,166 +1,166 @@ -########################################################################################### -# Statistics utilities -# Authors: Ilyes Batatia, Gregor Simm, David Kovacs -# This program is distributed under the MIT License (see MIT.md) -########################################################################################### - -import json -import logging -import os -import sys -from typing import Any, Dict, Iterable, Optional, Sequence, Union - -import numpy as np -import torch - -from .torch_tools import to_numpy - - -def compute_mae(delta: np.ndarray) -> float: - return np.mean(np.abs(delta)).item() - - -def compute_rel_mae(delta: np.ndarray, target_val: np.ndarray) -> float: - target_norm = np.mean(np.abs(target_val)) - return np.mean(np.abs(delta)).item() / (target_norm + 1e-9) * 100 - - -def compute_rmse(delta: np.ndarray) -> float: - return np.sqrt(np.mean(np.square(delta))).item() - - -def compute_rel_rmse(delta: np.ndarray, target_val: np.ndarray) -> float: - target_norm = np.sqrt(np.mean(np.square(target_val))).item() - return np.sqrt(np.mean(np.square(delta))).item() / (target_norm + 1e-9) * 100 - - -def compute_q95(delta: np.ndarray) -> float: - return np.percentile(np.abs(delta), q=95) - - -def compute_c(delta: np.ndarray, eta: float) -> float: - return np.mean(np.abs(delta) < eta).item() - - -def get_tag(name: str, seed: int) -> str: - return f"{name}_run-{seed}" - - -def setup_logger( - level: Union[int, str] = logging.INFO, - tag: Optional[str] = None, - directory: Optional[str] = None, - rank: Optional[int] = 0, -): - # Create a logger - logger = logging.getLogger() - logger.setLevel(logging.DEBUG) # Set to DEBUG to capture all levels - - # Create formatters - formatter = logging.Formatter( - "%(asctime)s.%(msecs)03d %(levelname)s: %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - ) - - # Add filter for rank - logger.addFilter(lambda _: rank == 0) - - # Create console handler - ch = logging.StreamHandler(stream=sys.stdout) - ch.setLevel(level) - ch.setFormatter(formatter) - logger.addHandler(ch) - - if directory is not None and tag is not None: - os.makedirs(name=directory, exist_ok=True) - - # Create file handler for non-debug logs - main_log_path = os.path.join(directory, f"{tag}.log") - fh_main = logging.FileHandler(main_log_path) - fh_main.setLevel(level) - fh_main.setFormatter(formatter) - logger.addHandler(fh_main) - - # Create file handler for debug logs - debug_log_path = os.path.join(directory, f"{tag}_debug.log") - fh_debug = logging.FileHandler(debug_log_path) - fh_debug.setLevel(logging.DEBUG) - fh_debug.setFormatter(formatter) - fh_debug.addFilter(lambda record: record.levelno >= logging.DEBUG) - logger.addHandler(fh_debug) - - -class AtomicNumberTable: - def __init__(self, zs: Sequence[int]): - self.zs = zs - - def __len__(self) -> int: - return len(self.zs) - - def __str__(self): - return f"AtomicNumberTable: {tuple(s for s in self.zs)}" - - def index_to_z(self, index: int) -> int: - return self.zs[index] - - def z_to_index(self, atomic_number: str) -> int: - return self.zs.index(atomic_number) - - -def get_atomic_number_table_from_zs(zs: Iterable[int]) -> AtomicNumberTable: - z_set = set() - for z in zs: - z_set.add(z) - return AtomicNumberTable(sorted(list(z_set))) - - -def atomic_numbers_to_indices( - atomic_numbers: np.ndarray, z_table: AtomicNumberTable -) -> np.ndarray: - to_index_fn = np.vectorize(z_table.z_to_index) - return to_index_fn(atomic_numbers) - - -class UniversalEncoder(json.JSONEncoder): - def default(self, o): - if isinstance(o, np.integer): - return int(o) - if isinstance(o, np.floating): - return float(o) - if isinstance(o, np.ndarray): - return o.tolist() - if isinstance(o, torch.Tensor): - return to_numpy(o) - return json.JSONEncoder.default(self, o) - - -class MetricsLogger: - def __init__(self, directory: str, tag: str) -> None: - self.directory = directory - self.filename = tag + ".txt" - self.path = os.path.join(self.directory, self.filename) - - def log(self, d: Dict[str, Any]) -> None: - os.makedirs(name=self.directory, exist_ok=True) - with open(self.path, mode="a", encoding="utf-8") as f: - f.write(json.dumps(d, cls=UniversalEncoder)) - f.write("\n") - - -# pylint: disable=abstract-method, arguments-differ -class LAMMPS_MP(torch.autograd.Function): - @staticmethod - def forward(ctx, *args): - feats, data = args # unpack - ctx.vec_len = feats.shape[-1] - ctx.data = data - out = torch.empty_like(feats) - data.forward_exchange(feats, out, ctx.vec_len) - return out - - @staticmethod - def backward(ctx, *grad_outputs): - (grad,) = grad_outputs # unpack - gout = torch.empty_like(grad) - ctx.data.reverse_exchange(grad, gout, ctx.vec_len) - return gout, None +########################################################################################### +# Statistics utilities +# Authors: Ilyes Batatia, Gregor Simm, David Kovacs +# This program is distributed under the MIT License (see MIT.md) +########################################################################################### + +import json +import logging +import os +import sys +from typing import Any, Dict, Iterable, Optional, Sequence, Union + +import numpy as np +import torch + +from .torch_tools import to_numpy + + +def compute_mae(delta: np.ndarray) -> float: + return np.mean(np.abs(delta)).item() + + +def compute_rel_mae(delta: np.ndarray, target_val: np.ndarray) -> float: + target_norm = np.mean(np.abs(target_val)) + return np.mean(np.abs(delta)).item() / (target_norm + 1e-9) * 100 + + +def compute_rmse(delta: np.ndarray) -> float: + return np.sqrt(np.mean(np.square(delta))).item() + + +def compute_rel_rmse(delta: np.ndarray, target_val: np.ndarray) -> float: + target_norm = np.sqrt(np.mean(np.square(target_val))).item() + return np.sqrt(np.mean(np.square(delta))).item() / (target_norm + 1e-9) * 100 + + +def compute_q95(delta: np.ndarray) -> float: + return np.percentile(np.abs(delta), q=95) + + +def compute_c(delta: np.ndarray, eta: float) -> float: + return np.mean(np.abs(delta) < eta).item() + + +def get_tag(name: str, seed: int) -> str: + return f"{name}_run-{seed}" + + +def setup_logger( + level: Union[int, str] = logging.INFO, + tag: Optional[str] = None, + directory: Optional[str] = None, + rank: Optional[int] = 0, +): + # Create a logger + logger = logging.getLogger() + logger.setLevel(logging.DEBUG) # Set to DEBUG to capture all levels + + # Create formatters + formatter = logging.Formatter( + "%(asctime)s.%(msecs)03d %(levelname)s: %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + + # Add filter for rank + logger.addFilter(lambda _: rank == 0) + + # Create console handler + ch = logging.StreamHandler(stream=sys.stdout) + ch.setLevel(level) + ch.setFormatter(formatter) + logger.addHandler(ch) + + if directory is not None and tag is not None: + os.makedirs(name=directory, exist_ok=True) + + # Create file handler for non-debug logs + main_log_path = os.path.join(directory, f"{tag}.log") + fh_main = logging.FileHandler(main_log_path) + fh_main.setLevel(level) + fh_main.setFormatter(formatter) + logger.addHandler(fh_main) + + # Create file handler for debug logs + debug_log_path = os.path.join(directory, f"{tag}_debug.log") + fh_debug = logging.FileHandler(debug_log_path) + fh_debug.setLevel(logging.DEBUG) + fh_debug.setFormatter(formatter) + fh_debug.addFilter(lambda record: record.levelno >= logging.DEBUG) + logger.addHandler(fh_debug) + + +class AtomicNumberTable: + def __init__(self, zs: Sequence[int]): + self.zs = zs + + def __len__(self) -> int: + return len(self.zs) + + def __str__(self): + return f"AtomicNumberTable: {tuple(s for s in self.zs)}" + + def index_to_z(self, index: int) -> int: + return self.zs[index] + + def z_to_index(self, atomic_number: str) -> int: + return self.zs.index(atomic_number) + + +def get_atomic_number_table_from_zs(zs: Iterable[int]) -> AtomicNumberTable: + z_set = set() + for z in zs: + z_set.add(z) + return AtomicNumberTable(sorted(list(z_set))) + + +def atomic_numbers_to_indices( + atomic_numbers: np.ndarray, z_table: AtomicNumberTable +) -> np.ndarray: + to_index_fn = np.vectorize(z_table.z_to_index) + return to_index_fn(atomic_numbers) + + +class UniversalEncoder(json.JSONEncoder): + def default(self, o): + if isinstance(o, np.integer): + return int(o) + if isinstance(o, np.floating): + return float(o) + if isinstance(o, np.ndarray): + return o.tolist() + if isinstance(o, torch.Tensor): + return to_numpy(o) + return json.JSONEncoder.default(self, o) + + +class MetricsLogger: + def __init__(self, directory: str, tag: str) -> None: + self.directory = directory + self.filename = tag + ".txt" + self.path = os.path.join(self.directory, self.filename) + + def log(self, d: Dict[str, Any]) -> None: + os.makedirs(name=self.directory, exist_ok=True) + with open(self.path, mode="a", encoding="utf-8") as f: + f.write(json.dumps(d, cls=UniversalEncoder)) + f.write("\n") + + +# pylint: disable=abstract-method, arguments-differ +class LAMMPS_MP(torch.autograd.Function): + @staticmethod + def forward(ctx, *args): + feats, data = args # unpack + ctx.vec_len = feats.shape[-1] + ctx.data = data + out = torch.empty_like(feats) + data.forward_exchange(feats, out, ctx.vec_len) + return out + + @staticmethod + def backward(ctx, *grad_outputs): + (grad,) = grad_outputs # unpack + gout = torch.empty_like(grad) + ctx.data.reverse_exchange(grad, gout, ctx.vec_len) + return gout, None diff --git a/mace-bench/3rdparty/mace/scripts/eval_configs.py b/mace-bench/3rdparty/mace/scripts/eval_configs.py index 350d804..d2f4e21 100644 --- a/mace-bench/3rdparty/mace/scripts/eval_configs.py +++ b/mace-bench/3rdparty/mace/scripts/eval_configs.py @@ -1,6 +1,6 @@ -## Wrapper for mace.cli.eval_configs.main ## - -from mace.cli.eval_configs import main - -if __name__ == "__main__": - main() +## Wrapper for mace.cli.eval_configs.main ## + +from mace.cli.eval_configs import main + +if __name__ == "__main__": + main() diff --git a/mace-bench/3rdparty/mace/scripts/preprocess_data.py b/mace-bench/3rdparty/mace/scripts/preprocess_data.py index be11345..3c2c288 100644 --- a/mace-bench/3rdparty/mace/scripts/preprocess_data.py +++ b/mace-bench/3rdparty/mace/scripts/preprocess_data.py @@ -1,6 +1,6 @@ -## Wrapper for mace.cli.run_train.main ## - -from mace.cli.preprocess_data import main - -if __name__ == "__main__": - main() +## Wrapper for mace.cli.run_train.main ## + +from mace.cli.preprocess_data import main + +if __name__ == "__main__": + main() diff --git a/mace-bench/3rdparty/mace/scripts/run_checks.sh b/mace-bench/3rdparty/mace/scripts/run_checks.sh index e2e073b..bd1214a 100644 --- a/mace-bench/3rdparty/mace/scripts/run_checks.sh +++ b/mace-bench/3rdparty/mace/scripts/run_checks.sh @@ -1,9 +1,9 @@ -# Format -python -m black . -python -m isort . - -# Check -python -m pylint --rcfile=pyproject.toml mace tests scripts - -# Tests -python -m pytest tests +# Format +python -m black . +python -m isort . + +# Check +python -m pylint --rcfile=pyproject.toml mace tests scripts + +# Tests +python -m pytest tests diff --git a/mace-bench/3rdparty/mace/scripts/run_train.py b/mace-bench/3rdparty/mace/scripts/run_train.py index 77d53b0..d14952d 100644 --- a/mace-bench/3rdparty/mace/scripts/run_train.py +++ b/mace-bench/3rdparty/mace/scripts/run_train.py @@ -1,6 +1,6 @@ -## Wrapper for mace.cli.run_train.main ## - -from mace.cli.run_train import main - -if __name__ == "__main__": - main() +## Wrapper for mace.cli.run_train.main ## + +from mace.cli.run_train import main + +if __name__ == "__main__": + main() diff --git a/mace-bench/3rdparty/mace/tests/__init__.py b/mace-bench/3rdparty/mace/tests/__init__.py index 9ff3a03..96ae777 100644 --- a/mace-bench/3rdparty/mace/tests/__init__.py +++ b/mace-bench/3rdparty/mace/tests/__init__.py @@ -1,3 +1,3 @@ -import os - -os.environ["TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD"] = "1" +import os + +os.environ["TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD"] = "1" diff --git a/mace-bench/3rdparty/mace/tests/modules/test_radial.py b/mace-bench/3rdparty/mace/tests/modules/test_radial.py index 402dc46..3aef254 100644 --- a/mace-bench/3rdparty/mace/tests/modules/test_radial.py +++ b/mace-bench/3rdparty/mace/tests/modules/test_radial.py @@ -1,95 +1,95 @@ -import pytest -import torch - -from mace.modules.radial import AgnesiTransform, ZBLBasis - - -@pytest.fixture -def zbl_basis(): - return ZBLBasis(p=6, trainable=False) - - -def test_zbl_basis_initialization(zbl_basis): - assert zbl_basis.p == torch.tensor(6.0) - assert torch.allclose(zbl_basis.c, torch.tensor([0.1818, 0.5099, 0.2802, 0.02817])) - - assert zbl_basis.a_exp == torch.tensor(0.300) - assert zbl_basis.a_prefactor == torch.tensor(0.4543) - assert not zbl_basis.a_exp.requires_grad - assert not zbl_basis.a_prefactor.requires_grad - - -def test_trainable_zbl_basis_initialization(zbl_basis): - zbl_basis = ZBLBasis(p=6, trainable=True) - assert zbl_basis.p == torch.tensor(6.0) - assert torch.allclose(zbl_basis.c, torch.tensor([0.1818, 0.5099, 0.2802, 0.02817])) - - assert zbl_basis.a_exp == torch.tensor(0.300) - assert zbl_basis.a_prefactor == torch.tensor(0.4543) - assert zbl_basis.a_exp.requires_grad - assert zbl_basis.a_prefactor.requires_grad - - -def test_forward(zbl_basis): - x = torch.tensor([1.0, 1.0, 2.0]).unsqueeze(-1) # [n_edges] - node_attrs = torch.tensor( - [[1, 0], [0, 1]] - ) # [n_nodes, n_node_features] - one_hot encoding of atomic numbers - edge_index = torch.tensor([[0, 1, 1], [1, 0, 1]]) # [2, n_edges] - atomic_numbers = torch.tensor([1, 6]) # [n_nodes] - output = zbl_basis(x, node_attrs, edge_index, atomic_numbers) - - assert output.shape == torch.Size([node_attrs.shape[0]]) - assert torch.is_tensor(output) - assert torch.allclose( - output, - torch.tensor([0.0031, 0.0031], dtype=torch.get_default_dtype()), - rtol=1e-2, - ) - - -@pytest.fixture -def agnesi(): - return AgnesiTransform(trainable=False) - - -def test_agnesi_transform_initialization(agnesi: AgnesiTransform): - assert agnesi.q.item() == pytest.approx(0.9183, rel=1e-4) - assert agnesi.p.item() == pytest.approx(4.5791, rel=1e-4) - assert agnesi.a.item() == pytest.approx(1.0805, rel=1e-4) - assert not agnesi.a.requires_grad - assert not agnesi.q.requires_grad - assert not agnesi.p.requires_grad - - -def test_trainable_agnesi_transform_initialization(): - agnesi = AgnesiTransform(trainable=True) - - assert agnesi.q.item() == pytest.approx(0.9183, rel=1e-4) - assert agnesi.p.item() == pytest.approx(4.5791, rel=1e-4) - assert agnesi.a.item() == pytest.approx(1.0805, rel=1e-4) - assert agnesi.a.requires_grad - assert agnesi.q.requires_grad - assert agnesi.p.requires_grad - - -def test_agnesi_transform_forward(): - agnesi = AgnesiTransform() - x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.get_default_dtype()).unsqueeze(-1) - node_attrs = torch.tensor([[0, 1], [1, 0], [0, 1]], dtype=torch.get_default_dtype()) - edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]]) - atomic_numbers = torch.tensor([1, 6, 8]) - output = agnesi(x, node_attrs, edge_index, atomic_numbers) - assert output.shape == x.shape - assert torch.is_tensor(output) - assert torch.allclose( - output, - torch.tensor( - [0.3646, 0.2175, 0.2089], dtype=torch.get_default_dtype() - ).unsqueeze(-1), - rtol=1e-2, - ) - - -if __name__ == "__main__": - pytest.main([__file__]) +import pytest +import torch + +from mace.modules.radial import AgnesiTransform, ZBLBasis + + +@pytest.fixture +def zbl_basis(): + return ZBLBasis(p=6, trainable=False) + + +def test_zbl_basis_initialization(zbl_basis): + assert zbl_basis.p == torch.tensor(6.0) + assert torch.allclose(zbl_basis.c, torch.tensor([0.1818, 0.5099, 0.2802, 0.02817])) + + assert zbl_basis.a_exp == torch.tensor(0.300) + assert zbl_basis.a_prefactor == torch.tensor(0.4543) + assert not zbl_basis.a_exp.requires_grad + assert not zbl_basis.a_prefactor.requires_grad + + +def test_trainable_zbl_basis_initialization(zbl_basis): + zbl_basis = ZBLBasis(p=6, trainable=True) + assert zbl_basis.p == torch.tensor(6.0) + assert torch.allclose(zbl_basis.c, torch.tensor([0.1818, 0.5099, 0.2802, 0.02817])) + + assert zbl_basis.a_exp == torch.tensor(0.300) + assert zbl_basis.a_prefactor == torch.tensor(0.4543) + assert zbl_basis.a_exp.requires_grad + assert zbl_basis.a_prefactor.requires_grad + + +def test_forward(zbl_basis): + x = torch.tensor([1.0, 1.0, 2.0]).unsqueeze(-1) # [n_edges] + node_attrs = torch.tensor( + [[1, 0], [0, 1]] + ) # [n_nodes, n_node_features] - one_hot encoding of atomic numbers + edge_index = torch.tensor([[0, 1, 1], [1, 0, 1]]) # [2, n_edges] + atomic_numbers = torch.tensor([1, 6]) # [n_nodes] + output = zbl_basis(x, node_attrs, edge_index, atomic_numbers) + + assert output.shape == torch.Size([node_attrs.shape[0]]) + assert torch.is_tensor(output) + assert torch.allclose( + output, + torch.tensor([0.0031, 0.0031], dtype=torch.get_default_dtype()), + rtol=1e-2, + ) + + +@pytest.fixture +def agnesi(): + return AgnesiTransform(trainable=False) + + +def test_agnesi_transform_initialization(agnesi: AgnesiTransform): + assert agnesi.q.item() == pytest.approx(0.9183, rel=1e-4) + assert agnesi.p.item() == pytest.approx(4.5791, rel=1e-4) + assert agnesi.a.item() == pytest.approx(1.0805, rel=1e-4) + assert not agnesi.a.requires_grad + assert not agnesi.q.requires_grad + assert not agnesi.p.requires_grad + + +def test_trainable_agnesi_transform_initialization(): + agnesi = AgnesiTransform(trainable=True) + + assert agnesi.q.item() == pytest.approx(0.9183, rel=1e-4) + assert agnesi.p.item() == pytest.approx(4.5791, rel=1e-4) + assert agnesi.a.item() == pytest.approx(1.0805, rel=1e-4) + assert agnesi.a.requires_grad + assert agnesi.q.requires_grad + assert agnesi.p.requires_grad + + +def test_agnesi_transform_forward(): + agnesi = AgnesiTransform() + x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.get_default_dtype()).unsqueeze(-1) + node_attrs = torch.tensor([[0, 1], [1, 0], [0, 1]], dtype=torch.get_default_dtype()) + edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]]) + atomic_numbers = torch.tensor([1, 6, 8]) + output = agnesi(x, node_attrs, edge_index, atomic_numbers) + assert output.shape == x.shape + assert torch.is_tensor(output) + assert torch.allclose( + output, + torch.tensor( + [0.3646, 0.2175, 0.2089], dtype=torch.get_default_dtype() + ).unsqueeze(-1), + rtol=1e-2, + ) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/mace-bench/3rdparty/mace/tests/test_benchmark.py b/mace-bench/3rdparty/mace/tests/test_benchmark.py index 6e5f11c..fae104b 100644 --- a/mace-bench/3rdparty/mace/tests/test_benchmark.py +++ b/mace-bench/3rdparty/mace/tests/test_benchmark.py @@ -1,121 +1,121 @@ -import json -import os -from pathlib import Path -from typing import List, Optional - -import pandas as pd -import pytest -import torch -from ase import build - -from mace import data as mace_data -from mace.calculators.foundations_models import mace_mp -from mace.tools import AtomicNumberTable, torch_geometric, torch_tools - - -def is_mace_full_bench(): - return os.environ.get("MACE_FULL_BENCH", "0") == "1" - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda is not available") -@pytest.mark.benchmark(warmup=True, warmup_iterations=4, min_rounds=8) -@pytest.mark.parametrize("size", (3, 5, 7, 9)) -@pytest.mark.parametrize("dtype", ["float32", "float64"]) -@pytest.mark.parametrize("compile_mode", [None, "default"]) -def test_inference( - benchmark, size: int, dtype: str, compile_mode: Optional[str], device: str = "cuda" -): - if not is_mace_full_bench() and compile_mode is not None: - pytest.skip("Skipping long running benchmark, set MACE_FULL_BENCH=1 to execute") - - with torch_tools.default_dtype(dtype): - model = load_mace_mp_medium(dtype, compile_mode, device) - batch = create_batch(size, model, device) - log_bench_info(benchmark, dtype, compile_mode, batch) - - def func(): - torch.cuda.synchronize() - model(batch, training=compile_mode is not None, compute_force=True) - - torch.cuda.empty_cache() - benchmark(func) - - -def load_mace_mp_medium(dtype, compile_mode, device): - calc = mace_mp( - model="medium", - default_dtype=dtype, - device=device, - compile_mode=compile_mode, - fullgraph=False, - ) - model = calc.models[0].to(device) - return model - - -def create_batch(size: int, model: torch.nn.Module, device: str) -> dict: - cutoff = model.r_max.item() - z_table = AtomicNumberTable([int(z) for z in model.atomic_numbers]) - atoms = build.bulk("C", "diamond", a=3.567, cubic=True) - atoms = atoms.repeat((size, size, size)) - config = mace_data.config_from_atoms(atoms) - dataset = [mace_data.AtomicData.from_config(config, z_table=z_table, cutoff=cutoff)] - data_loader = torch_geometric.dataloader.DataLoader( - dataset=dataset, - batch_size=1, - shuffle=False, - drop_last=False, - ) - batch = next(iter(data_loader)) - batch.to(device) - return batch.to_dict() - - -def log_bench_info(benchmark, dtype, compile_mode, batch): - benchmark.extra_info["num_atoms"] = int(batch["positions"].shape[0]) - benchmark.extra_info["num_edges"] = int(batch["edge_index"].shape[1]) - benchmark.extra_info["dtype"] = dtype - benchmark.extra_info["is_compiled"] = compile_mode is not None - benchmark.extra_info["device_name"] = torch.cuda.get_device_name() - - -def process_benchmark_file(bench_file: Path) -> pd.DataFrame: - with open(bench_file, "r", encoding="utf-8") as f: - bench_data = json.load(f) - - records = [] - for bench in bench_data["benchmarks"]: - record = {**bench["extra_info"], **bench["stats"]} - records.append(record) - - result_df = pd.DataFrame(records) - result_df["ns/day (1 fs/step)"] = 0.086400 / result_df["median"] - result_df["Steps per day"] = result_df["ops"] * 86400 - columns = [ - "num_atoms", - "num_edges", - "dtype", - "is_compiled", - "device_name", - "median", - "Steps per day", - "ns/day (1 fs/step)", - ] - return result_df[columns] - - -def read_bench_results(result_files: List[str]) -> pd.DataFrame: - return pd.concat([process_benchmark_file(Path(f)) for f in result_files]) - - -if __name__ == "__main__": - # Print to stdout a csv of the benchmark metrics - import subprocess - - result = subprocess.run( - ["pytest-benchmark", "list"], capture_output=True, text=True, check=True - ) - - bench_files = result.stdout.strip().split("\n") - bench_results = read_bench_results(bench_files) - print(bench_results.to_csv(index=False)) +import json +import os +from pathlib import Path +from typing import List, Optional + +import pandas as pd +import pytest +import torch +from ase import build + +from mace import data as mace_data +from mace.calculators.foundations_models import mace_mp +from mace.tools import AtomicNumberTable, torch_geometric, torch_tools + + +def is_mace_full_bench(): + return os.environ.get("MACE_FULL_BENCH", "0") == "1" + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda is not available") +@pytest.mark.benchmark(warmup=True, warmup_iterations=4, min_rounds=8) +@pytest.mark.parametrize("size", (3, 5, 7, 9)) +@pytest.mark.parametrize("dtype", ["float32", "float64"]) +@pytest.mark.parametrize("compile_mode", [None, "default"]) +def test_inference( + benchmark, size: int, dtype: str, compile_mode: Optional[str], device: str = "cuda" +): + if not is_mace_full_bench() and compile_mode is not None: + pytest.skip("Skipping long running benchmark, set MACE_FULL_BENCH=1 to execute") + + with torch_tools.default_dtype(dtype): + model = load_mace_mp_medium(dtype, compile_mode, device) + batch = create_batch(size, model, device) + log_bench_info(benchmark, dtype, compile_mode, batch) + + def func(): + torch.cuda.synchronize() + model(batch, training=compile_mode is not None, compute_force=True) + + torch.cuda.empty_cache() + benchmark(func) + + +def load_mace_mp_medium(dtype, compile_mode, device): + calc = mace_mp( + model="medium", + default_dtype=dtype, + device=device, + compile_mode=compile_mode, + fullgraph=False, + ) + model = calc.models[0].to(device) + return model + + +def create_batch(size: int, model: torch.nn.Module, device: str) -> dict: + cutoff = model.r_max.item() + z_table = AtomicNumberTable([int(z) for z in model.atomic_numbers]) + atoms = build.bulk("C", "diamond", a=3.567, cubic=True) + atoms = atoms.repeat((size, size, size)) + config = mace_data.config_from_atoms(atoms) + dataset = [mace_data.AtomicData.from_config(config, z_table=z_table, cutoff=cutoff)] + data_loader = torch_geometric.dataloader.DataLoader( + dataset=dataset, + batch_size=1, + shuffle=False, + drop_last=False, + ) + batch = next(iter(data_loader)) + batch.to(device) + return batch.to_dict() + + +def log_bench_info(benchmark, dtype, compile_mode, batch): + benchmark.extra_info["num_atoms"] = int(batch["positions"].shape[0]) + benchmark.extra_info["num_edges"] = int(batch["edge_index"].shape[1]) + benchmark.extra_info["dtype"] = dtype + benchmark.extra_info["is_compiled"] = compile_mode is not None + benchmark.extra_info["device_name"] = torch.cuda.get_device_name() + + +def process_benchmark_file(bench_file: Path) -> pd.DataFrame: + with open(bench_file, "r", encoding="utf-8") as f: + bench_data = json.load(f) + + records = [] + for bench in bench_data["benchmarks"]: + record = {**bench["extra_info"], **bench["stats"]} + records.append(record) + + result_df = pd.DataFrame(records) + result_df["ns/day (1 fs/step)"] = 0.086400 / result_df["median"] + result_df["Steps per day"] = result_df["ops"] * 86400 + columns = [ + "num_atoms", + "num_edges", + "dtype", + "is_compiled", + "device_name", + "median", + "Steps per day", + "ns/day (1 fs/step)", + ] + return result_df[columns] + + +def read_bench_results(result_files: List[str]) -> pd.DataFrame: + return pd.concat([process_benchmark_file(Path(f)) for f in result_files]) + + +if __name__ == "__main__": + # Print to stdout a csv of the benchmark metrics + import subprocess + + result = subprocess.run( + ["pytest-benchmark", "list"], capture_output=True, text=True, check=True + ) + + bench_files = result.stdout.strip().split("\n") + bench_results = read_bench_results(bench_files) + print(bench_results.to_csv(index=False)) diff --git a/mace-bench/3rdparty/mace/tests/test_calculator.py b/mace-bench/3rdparty/mace/tests/test_calculator.py index 8491538..e9a6546 100644 --- a/mace-bench/3rdparty/mace/tests/test_calculator.py +++ b/mace-bench/3rdparty/mace/tests/test_calculator.py @@ -1,689 +1,689 @@ -import os -import subprocess -import sys -from pathlib import Path - -import ase.io -import numpy as np -import pytest -import torch -from ase import build -from ase.atoms import Atoms -from ase.calculators.test import gradient_test -from ase.constraints import ExpCellFilter - -from mace.calculators import mace_mp, mace_off -from mace.calculators.mace import MACECalculator -from mace.modules.models import ScaleShiftMACE - -try: - import cuequivariance as cue # pylint: disable=unused-import - - CUET_AVAILABLE = True -except ImportError: - CUET_AVAILABLE = False - -pytest_mace_dir = Path(__file__).parent.parent -run_train = Path(__file__).parent.parent / "mace" / "cli" / "run_train.py" - - -@pytest.fixture(scope="module", name="fitting_configs") -def fitting_configs_fixture(): - water = Atoms( - numbers=[8, 1, 1], - positions=[[0, -2.0, 0], [1, 0, 0], [0, 1, 0]], - cell=[4] * 3, - pbc=[True] * 3, - ) - fit_configs = [ - Atoms(numbers=[8], positions=[[0, 0, 0]], cell=[6] * 3), - Atoms(numbers=[1], positions=[[0, 0, 0]], cell=[6] * 3), - ] - fit_configs[0].info["REF_energy"] = 1.0 - fit_configs[0].info["config_type"] = "IsolatedAtom" - fit_configs[1].info["REF_energy"] = -0.5 - fit_configs[1].info["config_type"] = "IsolatedAtom" - - np.random.seed(5) - for _ in range(20): - c = water.copy() - c.positions += np.random.normal(0.1, size=c.positions.shape) - c.info["REF_energy"] = np.random.normal(0.1) - c.info["REF_dipole"] = np.random.normal(0.1, size=3) - c.new_array("REF_forces", np.random.normal(0.1, size=c.positions.shape)) - c.new_array("Qs", np.random.normal(0.1, size=c.positions.shape[0])) - c.info["REF_stress"] = np.random.normal(0.1, size=6) - fit_configs.append(c) - - return fit_configs - - -@pytest.fixture(scope="module", name="trained_model") -def trained_model_fixture(tmp_path_factory, fitting_configs): - _mace_params = { - "name": "MACE", - "valid_fraction": 0.05, - "energy_weight": 1.0, - "forces_weight": 10.0, - "stress_weight": 1.0, - "model": "MACE", - "hidden_irreps": "128x0e", - "r_max": 3.5, - "batch_size": 5, - "max_num_epochs": 10, - "swa": None, - "start_swa": 5, - "ema": None, - "ema_decay": 0.99, - "amsgrad": None, - "restart_latest": None, - "device": "cpu", - "seed": 5, - "loss": "stress", - "energy_key": "REF_energy", - "forces_key": "REF_forces", - "stress_key": "REF_stress", - "eval_interval": 2, - } - - tmp_path = tmp_path_factory.mktemp("run_") - - ase.io.write(tmp_path / "fit.xyz", fitting_configs) - - mace_params = _mace_params.copy() - mace_params["checkpoints_dir"] = str(tmp_path) - mace_params["model_dir"] = str(tmp_path) - mace_params["train_file"] = tmp_path / "fit.xyz" - - # make sure run_train.py is using the mace that is currently being tested - run_env = os.environ.copy() - sys.path.insert(0, str(Path(__file__).parent.parent)) - run_env["PYTHONPATH"] = ":".join(sys.path) - print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) - - cmd = ( - sys.executable - + " " - + str(run_train) - + " " - + " ".join( - [ - (f"--{k}={v}" if v is not None else f"--{k}") - for k, v in mace_params.items() - ] - ) - ) - - p = subprocess.run(cmd.split(), env=run_env, check=True) - - assert p.returncode == 0 - - return MACECalculator(model_paths=tmp_path / "MACE.model", device="cpu") - - -@pytest.fixture(scope="module", name="trained_equivariant_model") -def trained_model_equivariant_fixture(tmp_path_factory, fitting_configs): - _mace_params = { - "name": "MACE", - "valid_fraction": 0.05, - "energy_weight": 1.0, - "forces_weight": 10.0, - "stress_weight": 1.0, - "model": "MACE", - "hidden_irreps": "16x0e+16x1o", - "r_max": 3.5, - "batch_size": 5, - "max_num_epochs": 10, - "swa": None, - "start_swa": 5, - "ema": None, - "ema_decay": 0.99, - "amsgrad": None, - "restart_latest": None, - "device": "cpu", - "seed": 5, - "loss": "stress", - "energy_key": "REF_energy", - "forces_key": "REF_forces", - "stress_key": "REF_stress", - "eval_interval": 2, - } - - tmp_path = tmp_path_factory.mktemp("run_") - - ase.io.write(tmp_path / "fit.xyz", fitting_configs) - - mace_params = _mace_params.copy() - mace_params["checkpoints_dir"] = str(tmp_path) - mace_params["model_dir"] = str(tmp_path) - mace_params["train_file"] = tmp_path / "fit.xyz" - - # make sure run_train.py is using the mace that is currently being tested - run_env = os.environ.copy() - sys.path.insert(0, str(Path(__file__).parent.parent)) - run_env["PYTHONPATH"] = ":".join(sys.path) - print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) - - cmd = ( - sys.executable - + " " - + str(run_train) - + " " - + " ".join( - [ - (f"--{k}={v}" if v is not None else f"--{k}") - for k, v in mace_params.items() - ] - ) - ) - - p = subprocess.run(cmd.split(), env=run_env, check=True) - - assert p.returncode == 0 - - return MACECalculator(model_paths=tmp_path / "MACE.model", device="cpu") - - -@pytest.fixture(scope="module", name="trained_equivariant_model_cueq") -def trained_model_equivariant_fixture_cueq(tmp_path_factory, fitting_configs): - _mace_params = { - "name": "MACE", - "valid_fraction": 0.05, - "energy_weight": 1.0, - "forces_weight": 10.0, - "stress_weight": 1.0, - "model": "MACE", - "hidden_irreps": "16x0e+16x1o", - "r_max": 3.5, - "batch_size": 5, - "max_num_epochs": 10, - "swa": None, - "start_swa": 5, - "ema": None, - "ema_decay": 0.99, - "amsgrad": None, - "restart_latest": None, - "device": "cpu", - "seed": 5, - "loss": "stress", - "energy_key": "REF_energy", - "forces_key": "REF_forces", - "stress_key": "REF_stress", - "eval_interval": 2, - } - - tmp_path = tmp_path_factory.mktemp("run_") - - ase.io.write(tmp_path / "fit.xyz", fitting_configs) - - mace_params = _mace_params.copy() - mace_params["checkpoints_dir"] = str(tmp_path) - mace_params["model_dir"] = str(tmp_path) - mace_params["train_file"] = tmp_path / "fit.xyz" - - # make sure run_train.py is using the mace that is currently being tested - run_env = os.environ.copy() - sys.path.insert(0, str(Path(__file__).parent.parent)) - run_env["PYTHONPATH"] = ":".join(sys.path) - print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) - - cmd = ( - sys.executable - + " " - + str(run_train) - + " " - + " ".join( - [ - (f"--{k}={v}" if v is not None else f"--{k}") - for k, v in mace_params.items() - ] - ) - ) - - p = subprocess.run(cmd.split(), env=run_env, check=True) - - assert p.returncode == 0 - - return MACECalculator( - model_paths=tmp_path / "MACE.model", device="cpu", enable_cueq=True - ) - - -@pytest.fixture(scope="module", name="trained_dipole_model") -def trained_dipole_fixture(tmp_path_factory, fitting_configs): - _mace_params = { - "name": "MACE", - "valid_fraction": 0.05, - "energy_weight": 1.0, - "forces_weight": 10.0, - "stress_weight": 1.0, - "model": "AtomicDipolesMACE", - "num_channels": 8, - "max_L": 2, - "r_max": 3.5, - "batch_size": 5, - "max_num_epochs": 10, - "ema": None, - "ema_decay": 0.99, - "amsgrad": None, - "restart_latest": None, - "device": "cpu", - "seed": 5, - "loss": "dipole", - "energy_key": "", - "forces_key": "", - "stress_key": "", - "dipole_key": "REF_dipole", - "error_table": "DipoleRMSE", - "eval_interval": 2, - } - - tmp_path = tmp_path_factory.mktemp("run_") - - ase.io.write(tmp_path / "fit.xyz", fitting_configs) - - mace_params = _mace_params.copy() - mace_params["checkpoints_dir"] = str(tmp_path) - mace_params["model_dir"] = str(tmp_path) - mace_params["train_file"] = tmp_path / "fit.xyz" - - # make sure run_train.py is using the mace that is currently being tested - run_env = os.environ.copy() - sys.path.insert(0, str(Path(__file__).parent.parent)) - run_env["PYTHONPATH"] = ":".join(sys.path) - print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) - - cmd = ( - sys.executable - + " " - + str(run_train) - + " " - + " ".join( - [ - (f"--{k}={v}" if v is not None else f"--{k}") - for k, v in mace_params.items() - ] - ) - ) - - p = subprocess.run(cmd.split(), env=run_env, check=True) - - assert p.returncode == 0 - - return MACECalculator( - model_paths=tmp_path / "MACE.model", device="cpu", model_type="DipoleMACE" - ) - - -@pytest.fixture(scope="module", name="trained_energy_dipole_model") -def trained_energy_dipole_fixture(tmp_path_factory, fitting_configs): - _mace_params = { - "name": "MACE", - "valid_fraction": 0.05, - "energy_weight": 1.0, - "forces_weight": 10.0, - "stress_weight": 1.0, - "model": "EnergyDipolesMACE", - "num_channels": 32, - "max_L": 1, - "r_max": 3.5, - "batch_size": 5, - "max_num_epochs": 10, - "ema": None, - "ema_decay": 0.99, - "amsgrad": None, - "restart_latest": None, - "device": "cpu", - "seed": 5, - "loss": "energy_forces_dipole", - "energy_key": "REF_energy", - "forces_key": "", - "stress_key": "", - "dipole_key": "REF_dipole", - "error_table": "EnergyDipoleRMSE", - "eval_interval": 2, - } - - tmp_path = tmp_path_factory.mktemp("run_") - - ase.io.write(tmp_path / "fit.xyz", fitting_configs) - - mace_params = _mace_params.copy() - mace_params["checkpoints_dir"] = str(tmp_path) - mace_params["model_dir"] = str(tmp_path) - mace_params["train_file"] = tmp_path / "fit.xyz" - - # make sure run_train.py is using the mace that is currently being tested - run_env = os.environ.copy() - sys.path.insert(0, str(Path(__file__).parent.parent)) - run_env["PYTHONPATH"] = ":".join(sys.path) - print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) - - cmd = ( - sys.executable - + " " - + str(run_train) - + " " - + " ".join( - [ - (f"--{k}={v}" if v is not None else f"--{k}") - for k, v in mace_params.items() - ] - ) - ) - - p = subprocess.run(cmd.split(), env=run_env, check=True) - - assert p.returncode == 0 - - return MACECalculator( - model_paths=tmp_path / "MACE.model", device="cpu", model_type="EnergyDipoleMACE" - ) - - -@pytest.fixture(scope="module", name="trained_committee") -def trained_committee_fixture(tmp_path_factory, fitting_configs): - _seeds = [5, 6, 7] - _model_paths = [] - for seed in _seeds: - _mace_params = { - "name": f"MACE{seed}", - "valid_fraction": 0.05, - "energy_weight": 1.0, - "forces_weight": 10.0, - "stress_weight": 1.0, - "model": "MACE", - "hidden_irreps": "16x0e", - "r_max": 3.5, - "batch_size": 5, - "max_num_epochs": 10, - "swa": None, - "start_swa": 5, - "ema": None, - "ema_decay": 0.99, - "amsgrad": None, - "restart_latest": None, - "device": "cpu", - "seed": seed, - "loss": "stress", - "energy_key": "REF_energy", - "forces_key": "REF_forces", - "stress_key": "REF_stress", - "eval_interval": 2, - } - - tmp_path = tmp_path_factory.mktemp(f"run{seed}_") - - ase.io.write(tmp_path / "fit.xyz", fitting_configs) - - mace_params = _mace_params.copy() - mace_params["checkpoints_dir"] = str(tmp_path) - mace_params["model_dir"] = str(tmp_path) - mace_params["train_file"] = tmp_path / "fit.xyz" - - # make sure run_train.py is using the mace that is currently being tested - run_env = os.environ.copy() - sys.path.insert(0, str(Path(__file__).parent.parent)) - run_env["PYTHONPATH"] = ":".join(sys.path) - print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) - - cmd = ( - sys.executable - + " " - + str(run_train) - + " " - + " ".join( - [ - (f"--{k}={v}" if v is not None else f"--{k}") - for k, v in mace_params.items() - ] - ) - ) - - p = subprocess.run(cmd.split(), env=run_env, check=True) - - assert p.returncode == 0 - - _model_paths.append(tmp_path / f"MACE{seed}.model") - - return MACECalculator(model_paths=_model_paths, device="cpu") - - -def test_calculator_node_energy(fitting_configs, trained_model): - for at in fitting_configs: - trained_model.calculate(at) - node_energies = trained_model.results["node_energy"] - batch = trained_model._atoms_to_batch(at) # pylint: disable=protected-access - node_heads = batch["head"][batch["batch"]] - num_atoms_arange = torch.arange(batch["positions"].shape[0]) - node_e0 = ( - trained_model.models[0].atomic_energies_fn(batch["node_attrs"]).detach() - ) - node_e0 = node_e0[num_atoms_arange, node_heads].cpu().numpy() - energy_via_nodes = np.sum(node_energies + node_e0) - energy = trained_model.results["energy"] - np.testing.assert_allclose(energy, energy_via_nodes, atol=1e-6) - - -def test_calculator_forces(fitting_configs, trained_model): - at = fitting_configs[2].copy() - at.calc = trained_model - - # test just forces - grads = gradient_test(at) - - assert np.allclose(grads[0], grads[1]) - - -def test_calculator_stress(fitting_configs, trained_model): - at = fitting_configs[2].copy() - at.calc = trained_model - - # test forces and stress - at_wrapped = ExpCellFilter(at) - grads = gradient_test(at_wrapped) - - assert np.allclose(grads[0], grads[1]) - - -def test_calculator_committee(fitting_configs, trained_committee): - at = fitting_configs[2].copy() - at.calc = trained_committee - - # test just forces - grads = gradient_test(at) - - assert np.allclose(grads[0], grads[1]) - - E = at.get_potential_energy() - energies = at.calc.results["energies"] - energies_var = at.calc.results["energy_var"] - forces_var = np.var(at.calc.results["forces_comm"], axis=0) - assert np.allclose(E, np.mean(energies)) - assert np.allclose(energies_var, np.var(energies)) - assert forces_var.shape == at.calc.results["forces"].shape - - -def test_calculator_from_model(fitting_configs, trained_committee): - # test single model - test_calculator_forces( - fitting_configs, - trained_model=MACECalculator(models=trained_committee.models[0], device="cpu"), - ) - - # test committee model - test_calculator_committee( - fitting_configs, - trained_committee=MACECalculator(models=trained_committee.models, device="cpu"), - ) - - -def test_calculator_dipole(fitting_configs, trained_dipole_model): - at = fitting_configs[2].copy() - at.calc = trained_dipole_model - - dip = at.get_dipole_moment() - - assert len(dip) == 3 - - -def test_calculator_energy_dipole(fitting_configs, trained_energy_dipole_model): - at = fitting_configs[2].copy() - at.calc = trained_energy_dipole_model - - grads = gradient_test(at) - dip = at.get_dipole_moment() - - assert np.allclose(grads[0], grads[1]) - assert len(dip) == 3 - - -def test_calculator_descriptor(fitting_configs, trained_equivariant_model): - at = fitting_configs[2].copy() - at_rotated = fitting_configs[2].copy() - at_rotated.rotate(90, "x") - calc = trained_equivariant_model - - desc_invariant = calc.get_descriptors(at, invariants_only=True) - desc_invariant_rotated = calc.get_descriptors(at_rotated, invariants_only=True) - desc_invariant_single_layer = calc.get_descriptors( - at, invariants_only=True, num_layers=1 - ) - desc_invariant_single_layer_rotated = calc.get_descriptors( - at_rotated, invariants_only=True, num_layers=1 - ) - desc = calc.get_descriptors(at, invariants_only=False) - desc_single_layer = calc.get_descriptors(at, invariants_only=False, num_layers=1) - desc_rotated = calc.get_descriptors(at_rotated, invariants_only=False) - desc_rotated_single_layer = calc.get_descriptors( - at_rotated, invariants_only=False, num_layers=1 - ) - - assert desc_invariant.shape[0] == 3 - assert desc_invariant.shape[1] == 32 - assert desc_invariant_single_layer.shape[0] == 3 - assert desc_invariant_single_layer.shape[1] == 16 - assert desc.shape[0] == 3 - assert desc.shape[1] == 80 - assert desc_single_layer.shape[0] == 3 - assert desc_single_layer.shape[1] == 16 * 4 - assert desc_rotated_single_layer.shape[0] == 3 - assert desc_rotated_single_layer.shape[1] == 16 * 4 - - np.testing.assert_allclose(desc_invariant, desc_invariant_rotated, atol=1e-6) - np.testing.assert_allclose( - desc_invariant_single_layer, desc_invariant[:, :16], atol=1e-6 - ) - np.testing.assert_allclose( - desc_invariant_single_layer_rotated, desc_invariant[:, :16], atol=1e-6 - ) - np.testing.assert_allclose( - desc_single_layer[:, :16], desc_rotated_single_layer[:, :16], atol=1e-6 - ) - assert not np.allclose( - desc_single_layer[:, 16:], desc_rotated_single_layer[:, 16:], atol=1e-6 - ) - assert not np.allclose(desc, desc_rotated, atol=1e-6) - - -@pytest.mark.skipif(not CUET_AVAILABLE, reason="cuequivariance not installed") -def test_calculator_descriptor_cueq(fitting_configs, trained_equivariant_model_cueq): - at = fitting_configs[2].copy() - at_rotated = fitting_configs[2].copy() - at_rotated.rotate(90, "x") - calc = trained_equivariant_model_cueq - - desc_invariant = calc.get_descriptors(at, invariants_only=True) - desc_invariant_rotated = calc.get_descriptors(at_rotated, invariants_only=True) - desc_invariant_single_layer = calc.get_descriptors( - at, invariants_only=True, num_layers=1 - ) - desc_invariant_single_layer_rotated = calc.get_descriptors( - at_rotated, invariants_only=True, num_layers=1 - ) - desc = calc.get_descriptors(at, invariants_only=False) - desc_single_layer = calc.get_descriptors(at, invariants_only=False, num_layers=1) - desc_rotated = calc.get_descriptors(at_rotated, invariants_only=False) - desc_rotated_single_layer = calc.get_descriptors( - at_rotated, invariants_only=False, num_layers=1 - ) - - assert desc_invariant.shape[0] == 3 - assert desc_invariant.shape[1] == 32 - assert desc_invariant_single_layer.shape[0] == 3 - assert desc_invariant_single_layer.shape[1] == 16 - assert desc.shape[0] == 3 - assert desc.shape[1] == 80 - assert desc_single_layer.shape[0] == 3 - assert desc_single_layer.shape[1] == 16 * 4 - assert desc_rotated_single_layer.shape[0] == 3 - assert desc_rotated_single_layer.shape[1] == 16 * 4 - - np.testing.assert_allclose(desc_invariant, desc_invariant_rotated, atol=1e-6) - np.testing.assert_allclose( - desc_invariant_single_layer, desc_invariant[:, :16], atol=1e-6 - ) - np.testing.assert_allclose( - desc_invariant_single_layer_rotated, desc_invariant[:, :16], atol=1e-6 - ) - np.testing.assert_allclose( - desc_single_layer[:, :16], desc_rotated_single_layer[:, :16], atol=1e-6 - ) - assert not np.allclose( - desc_single_layer[:, 16:], desc_rotated_single_layer[:, 16:], atol=1e-6 - ) - assert not np.allclose(desc, desc_rotated, atol=1e-6) - - -def test_mace_mp(capsys: pytest.CaptureFixture): - mp_mace = mace_mp() - assert isinstance(mp_mace, MACECalculator) - assert mp_mace.model_type == "MACE" - assert len(mp_mace.models) == 1 - assert isinstance(mp_mace.models[0], ScaleShiftMACE) - - _, stderr = capsys.readouterr() - assert stderr == "" - - -def test_mace_off(): - mace_off_model = mace_off(model="small", device="cpu") - assert isinstance(mace_off_model, MACECalculator) - assert mace_off_model.model_type == "MACE" - assert len(mace_off_model.models) == 1 - assert isinstance(mace_off_model.models[0], ScaleShiftMACE) - - atoms = build.molecule("H2O") - atoms.calc = mace_off_model - - E = atoms.get_potential_energy() - - assert np.allclose(E, -2081.116128586803, atol=1e-9) - - -@pytest.mark.skipif(not CUET_AVAILABLE, reason="cuequivariance not installed") -def test_mace_off_cueq(model="medium", device="cpu"): - mace_off_model = mace_off(model=model, device=device, enable_cueq=True) - assert isinstance(mace_off_model, MACECalculator) - assert mace_off_model.model_type == "MACE" - assert len(mace_off_model.models) == 1 - assert isinstance(mace_off_model.models[0], ScaleShiftMACE) - - atoms = build.molecule("H2O") - atoms.calc = mace_off_model - - E = atoms.get_potential_energy() - - assert np.allclose(E, -2081.116128586803, atol=1e-9) - - -def test_mace_mp_stresses(model="medium", device="cpu"): - atoms = build.bulk("Al", "fcc", a=4.05, cubic=True) - atoms = atoms.repeat((2, 2, 2)) - mace_mp_model = mace_mp(model=model, device=device, compute_atomic_stresses=True) - atoms.set_calculator(mace_mp_model) - stress = atoms.get_stress() - stresses = atoms.get_stresses() - assert stress.shape == (6,) - assert stresses.shape == (32, 6) - assert np.allclose(stress, stresses.sum(axis=0), atol=1e-6) +import os +import subprocess +import sys +from pathlib import Path + +import ase.io +import numpy as np +import pytest +import torch +from ase import build +from ase.atoms import Atoms +from ase.calculators.test import gradient_test +from ase.constraints import ExpCellFilter + +from mace.calculators import mace_mp, mace_off +from mace.calculators.mace import MACECalculator +from mace.modules.models import ScaleShiftMACE + +try: + import cuequivariance as cue # pylint: disable=unused-import + + CUET_AVAILABLE = True +except ImportError: + CUET_AVAILABLE = False + +pytest_mace_dir = Path(__file__).parent.parent +run_train = Path(__file__).parent.parent / "mace" / "cli" / "run_train.py" + + +@pytest.fixture(scope="module", name="fitting_configs") +def fitting_configs_fixture(): + water = Atoms( + numbers=[8, 1, 1], + positions=[[0, -2.0, 0], [1, 0, 0], [0, 1, 0]], + cell=[4] * 3, + pbc=[True] * 3, + ) + fit_configs = [ + Atoms(numbers=[8], positions=[[0, 0, 0]], cell=[6] * 3), + Atoms(numbers=[1], positions=[[0, 0, 0]], cell=[6] * 3), + ] + fit_configs[0].info["REF_energy"] = 1.0 + fit_configs[0].info["config_type"] = "IsolatedAtom" + fit_configs[1].info["REF_energy"] = -0.5 + fit_configs[1].info["config_type"] = "IsolatedAtom" + + np.random.seed(5) + for _ in range(20): + c = water.copy() + c.positions += np.random.normal(0.1, size=c.positions.shape) + c.info["REF_energy"] = np.random.normal(0.1) + c.info["REF_dipole"] = np.random.normal(0.1, size=3) + c.new_array("REF_forces", np.random.normal(0.1, size=c.positions.shape)) + c.new_array("Qs", np.random.normal(0.1, size=c.positions.shape[0])) + c.info["REF_stress"] = np.random.normal(0.1, size=6) + fit_configs.append(c) + + return fit_configs + + +@pytest.fixture(scope="module", name="trained_model") +def trained_model_fixture(tmp_path_factory, fitting_configs): + _mace_params = { + "name": "MACE", + "valid_fraction": 0.05, + "energy_weight": 1.0, + "forces_weight": 10.0, + "stress_weight": 1.0, + "model": "MACE", + "hidden_irreps": "128x0e", + "r_max": 3.5, + "batch_size": 5, + "max_num_epochs": 10, + "swa": None, + "start_swa": 5, + "ema": None, + "ema_decay": 0.99, + "amsgrad": None, + "restart_latest": None, + "device": "cpu", + "seed": 5, + "loss": "stress", + "energy_key": "REF_energy", + "forces_key": "REF_forces", + "stress_key": "REF_stress", + "eval_interval": 2, + } + + tmp_path = tmp_path_factory.mktemp("run_") + + ase.io.write(tmp_path / "fit.xyz", fitting_configs) + + mace_params = _mace_params.copy() + mace_params["checkpoints_dir"] = str(tmp_path) + mace_params["model_dir"] = str(tmp_path) + mace_params["train_file"] = tmp_path / "fit.xyz" + + # make sure run_train.py is using the mace that is currently being tested + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) + + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + + p = subprocess.run(cmd.split(), env=run_env, check=True) + + assert p.returncode == 0 + + return MACECalculator(model_paths=tmp_path / "MACE.model", device="cpu") + + +@pytest.fixture(scope="module", name="trained_equivariant_model") +def trained_model_equivariant_fixture(tmp_path_factory, fitting_configs): + _mace_params = { + "name": "MACE", + "valid_fraction": 0.05, + "energy_weight": 1.0, + "forces_weight": 10.0, + "stress_weight": 1.0, + "model": "MACE", + "hidden_irreps": "16x0e+16x1o", + "r_max": 3.5, + "batch_size": 5, + "max_num_epochs": 10, + "swa": None, + "start_swa": 5, + "ema": None, + "ema_decay": 0.99, + "amsgrad": None, + "restart_latest": None, + "device": "cpu", + "seed": 5, + "loss": "stress", + "energy_key": "REF_energy", + "forces_key": "REF_forces", + "stress_key": "REF_stress", + "eval_interval": 2, + } + + tmp_path = tmp_path_factory.mktemp("run_") + + ase.io.write(tmp_path / "fit.xyz", fitting_configs) + + mace_params = _mace_params.copy() + mace_params["checkpoints_dir"] = str(tmp_path) + mace_params["model_dir"] = str(tmp_path) + mace_params["train_file"] = tmp_path / "fit.xyz" + + # make sure run_train.py is using the mace that is currently being tested + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) + + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + + p = subprocess.run(cmd.split(), env=run_env, check=True) + + assert p.returncode == 0 + + return MACECalculator(model_paths=tmp_path / "MACE.model", device="cpu") + + +@pytest.fixture(scope="module", name="trained_equivariant_model_cueq") +def trained_model_equivariant_fixture_cueq(tmp_path_factory, fitting_configs): + _mace_params = { + "name": "MACE", + "valid_fraction": 0.05, + "energy_weight": 1.0, + "forces_weight": 10.0, + "stress_weight": 1.0, + "model": "MACE", + "hidden_irreps": "16x0e+16x1o", + "r_max": 3.5, + "batch_size": 5, + "max_num_epochs": 10, + "swa": None, + "start_swa": 5, + "ema": None, + "ema_decay": 0.99, + "amsgrad": None, + "restart_latest": None, + "device": "cpu", + "seed": 5, + "loss": "stress", + "energy_key": "REF_energy", + "forces_key": "REF_forces", + "stress_key": "REF_stress", + "eval_interval": 2, + } + + tmp_path = tmp_path_factory.mktemp("run_") + + ase.io.write(tmp_path / "fit.xyz", fitting_configs) + + mace_params = _mace_params.copy() + mace_params["checkpoints_dir"] = str(tmp_path) + mace_params["model_dir"] = str(tmp_path) + mace_params["train_file"] = tmp_path / "fit.xyz" + + # make sure run_train.py is using the mace that is currently being tested + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) + + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + + p = subprocess.run(cmd.split(), env=run_env, check=True) + + assert p.returncode == 0 + + return MACECalculator( + model_paths=tmp_path / "MACE.model", device="cpu", enable_cueq=True + ) + + +@pytest.fixture(scope="module", name="trained_dipole_model") +def trained_dipole_fixture(tmp_path_factory, fitting_configs): + _mace_params = { + "name": "MACE", + "valid_fraction": 0.05, + "energy_weight": 1.0, + "forces_weight": 10.0, + "stress_weight": 1.0, + "model": "AtomicDipolesMACE", + "num_channels": 8, + "max_L": 2, + "r_max": 3.5, + "batch_size": 5, + "max_num_epochs": 10, + "ema": None, + "ema_decay": 0.99, + "amsgrad": None, + "restart_latest": None, + "device": "cpu", + "seed": 5, + "loss": "dipole", + "energy_key": "", + "forces_key": "", + "stress_key": "", + "dipole_key": "REF_dipole", + "error_table": "DipoleRMSE", + "eval_interval": 2, + } + + tmp_path = tmp_path_factory.mktemp("run_") + + ase.io.write(tmp_path / "fit.xyz", fitting_configs) + + mace_params = _mace_params.copy() + mace_params["checkpoints_dir"] = str(tmp_path) + mace_params["model_dir"] = str(tmp_path) + mace_params["train_file"] = tmp_path / "fit.xyz" + + # make sure run_train.py is using the mace that is currently being tested + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) + + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + + p = subprocess.run(cmd.split(), env=run_env, check=True) + + assert p.returncode == 0 + + return MACECalculator( + model_paths=tmp_path / "MACE.model", device="cpu", model_type="DipoleMACE" + ) + + +@pytest.fixture(scope="module", name="trained_energy_dipole_model") +def trained_energy_dipole_fixture(tmp_path_factory, fitting_configs): + _mace_params = { + "name": "MACE", + "valid_fraction": 0.05, + "energy_weight": 1.0, + "forces_weight": 10.0, + "stress_weight": 1.0, + "model": "EnergyDipolesMACE", + "num_channels": 32, + "max_L": 1, + "r_max": 3.5, + "batch_size": 5, + "max_num_epochs": 10, + "ema": None, + "ema_decay": 0.99, + "amsgrad": None, + "restart_latest": None, + "device": "cpu", + "seed": 5, + "loss": "energy_forces_dipole", + "energy_key": "REF_energy", + "forces_key": "", + "stress_key": "", + "dipole_key": "REF_dipole", + "error_table": "EnergyDipoleRMSE", + "eval_interval": 2, + } + + tmp_path = tmp_path_factory.mktemp("run_") + + ase.io.write(tmp_path / "fit.xyz", fitting_configs) + + mace_params = _mace_params.copy() + mace_params["checkpoints_dir"] = str(tmp_path) + mace_params["model_dir"] = str(tmp_path) + mace_params["train_file"] = tmp_path / "fit.xyz" + + # make sure run_train.py is using the mace that is currently being tested + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) + + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + + p = subprocess.run(cmd.split(), env=run_env, check=True) + + assert p.returncode == 0 + + return MACECalculator( + model_paths=tmp_path / "MACE.model", device="cpu", model_type="EnergyDipoleMACE" + ) + + +@pytest.fixture(scope="module", name="trained_committee") +def trained_committee_fixture(tmp_path_factory, fitting_configs): + _seeds = [5, 6, 7] + _model_paths = [] + for seed in _seeds: + _mace_params = { + "name": f"MACE{seed}", + "valid_fraction": 0.05, + "energy_weight": 1.0, + "forces_weight": 10.0, + "stress_weight": 1.0, + "model": "MACE", + "hidden_irreps": "16x0e", + "r_max": 3.5, + "batch_size": 5, + "max_num_epochs": 10, + "swa": None, + "start_swa": 5, + "ema": None, + "ema_decay": 0.99, + "amsgrad": None, + "restart_latest": None, + "device": "cpu", + "seed": seed, + "loss": "stress", + "energy_key": "REF_energy", + "forces_key": "REF_forces", + "stress_key": "REF_stress", + "eval_interval": 2, + } + + tmp_path = tmp_path_factory.mktemp(f"run{seed}_") + + ase.io.write(tmp_path / "fit.xyz", fitting_configs) + + mace_params = _mace_params.copy() + mace_params["checkpoints_dir"] = str(tmp_path) + mace_params["model_dir"] = str(tmp_path) + mace_params["train_file"] = tmp_path / "fit.xyz" + + # make sure run_train.py is using the mace that is currently being tested + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) + + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + + p = subprocess.run(cmd.split(), env=run_env, check=True) + + assert p.returncode == 0 + + _model_paths.append(tmp_path / f"MACE{seed}.model") + + return MACECalculator(model_paths=_model_paths, device="cpu") + + +def test_calculator_node_energy(fitting_configs, trained_model): + for at in fitting_configs: + trained_model.calculate(at) + node_energies = trained_model.results["node_energy"] + batch = trained_model._atoms_to_batch(at) # pylint: disable=protected-access + node_heads = batch["head"][batch["batch"]] + num_atoms_arange = torch.arange(batch["positions"].shape[0]) + node_e0 = ( + trained_model.models[0].atomic_energies_fn(batch["node_attrs"]).detach() + ) + node_e0 = node_e0[num_atoms_arange, node_heads].cpu().numpy() + energy_via_nodes = np.sum(node_energies + node_e0) + energy = trained_model.results["energy"] + np.testing.assert_allclose(energy, energy_via_nodes, atol=1e-6) + + +def test_calculator_forces(fitting_configs, trained_model): + at = fitting_configs[2].copy() + at.calc = trained_model + + # test just forces + grads = gradient_test(at) + + assert np.allclose(grads[0], grads[1]) + + +def test_calculator_stress(fitting_configs, trained_model): + at = fitting_configs[2].copy() + at.calc = trained_model + + # test forces and stress + at_wrapped = ExpCellFilter(at) + grads = gradient_test(at_wrapped) + + assert np.allclose(grads[0], grads[1]) + + +def test_calculator_committee(fitting_configs, trained_committee): + at = fitting_configs[2].copy() + at.calc = trained_committee + + # test just forces + grads = gradient_test(at) + + assert np.allclose(grads[0], grads[1]) + + E = at.get_potential_energy() + energies = at.calc.results["energies"] + energies_var = at.calc.results["energy_var"] + forces_var = np.var(at.calc.results["forces_comm"], axis=0) + assert np.allclose(E, np.mean(energies)) + assert np.allclose(energies_var, np.var(energies)) + assert forces_var.shape == at.calc.results["forces"].shape + + +def test_calculator_from_model(fitting_configs, trained_committee): + # test single model + test_calculator_forces( + fitting_configs, + trained_model=MACECalculator(models=trained_committee.models[0], device="cpu"), + ) + + # test committee model + test_calculator_committee( + fitting_configs, + trained_committee=MACECalculator(models=trained_committee.models, device="cpu"), + ) + + +def test_calculator_dipole(fitting_configs, trained_dipole_model): + at = fitting_configs[2].copy() + at.calc = trained_dipole_model + + dip = at.get_dipole_moment() + + assert len(dip) == 3 + + +def test_calculator_energy_dipole(fitting_configs, trained_energy_dipole_model): + at = fitting_configs[2].copy() + at.calc = trained_energy_dipole_model + + grads = gradient_test(at) + dip = at.get_dipole_moment() + + assert np.allclose(grads[0], grads[1]) + assert len(dip) == 3 + + +def test_calculator_descriptor(fitting_configs, trained_equivariant_model): + at = fitting_configs[2].copy() + at_rotated = fitting_configs[2].copy() + at_rotated.rotate(90, "x") + calc = trained_equivariant_model + + desc_invariant = calc.get_descriptors(at, invariants_only=True) + desc_invariant_rotated = calc.get_descriptors(at_rotated, invariants_only=True) + desc_invariant_single_layer = calc.get_descriptors( + at, invariants_only=True, num_layers=1 + ) + desc_invariant_single_layer_rotated = calc.get_descriptors( + at_rotated, invariants_only=True, num_layers=1 + ) + desc = calc.get_descriptors(at, invariants_only=False) + desc_single_layer = calc.get_descriptors(at, invariants_only=False, num_layers=1) + desc_rotated = calc.get_descriptors(at_rotated, invariants_only=False) + desc_rotated_single_layer = calc.get_descriptors( + at_rotated, invariants_only=False, num_layers=1 + ) + + assert desc_invariant.shape[0] == 3 + assert desc_invariant.shape[1] == 32 + assert desc_invariant_single_layer.shape[0] == 3 + assert desc_invariant_single_layer.shape[1] == 16 + assert desc.shape[0] == 3 + assert desc.shape[1] == 80 + assert desc_single_layer.shape[0] == 3 + assert desc_single_layer.shape[1] == 16 * 4 + assert desc_rotated_single_layer.shape[0] == 3 + assert desc_rotated_single_layer.shape[1] == 16 * 4 + + np.testing.assert_allclose(desc_invariant, desc_invariant_rotated, atol=1e-6) + np.testing.assert_allclose( + desc_invariant_single_layer, desc_invariant[:, :16], atol=1e-6 + ) + np.testing.assert_allclose( + desc_invariant_single_layer_rotated, desc_invariant[:, :16], atol=1e-6 + ) + np.testing.assert_allclose( + desc_single_layer[:, :16], desc_rotated_single_layer[:, :16], atol=1e-6 + ) + assert not np.allclose( + desc_single_layer[:, 16:], desc_rotated_single_layer[:, 16:], atol=1e-6 + ) + assert not np.allclose(desc, desc_rotated, atol=1e-6) + + +@pytest.mark.skipif(not CUET_AVAILABLE, reason="cuequivariance not installed") +def test_calculator_descriptor_cueq(fitting_configs, trained_equivariant_model_cueq): + at = fitting_configs[2].copy() + at_rotated = fitting_configs[2].copy() + at_rotated.rotate(90, "x") + calc = trained_equivariant_model_cueq + + desc_invariant = calc.get_descriptors(at, invariants_only=True) + desc_invariant_rotated = calc.get_descriptors(at_rotated, invariants_only=True) + desc_invariant_single_layer = calc.get_descriptors( + at, invariants_only=True, num_layers=1 + ) + desc_invariant_single_layer_rotated = calc.get_descriptors( + at_rotated, invariants_only=True, num_layers=1 + ) + desc = calc.get_descriptors(at, invariants_only=False) + desc_single_layer = calc.get_descriptors(at, invariants_only=False, num_layers=1) + desc_rotated = calc.get_descriptors(at_rotated, invariants_only=False) + desc_rotated_single_layer = calc.get_descriptors( + at_rotated, invariants_only=False, num_layers=1 + ) + + assert desc_invariant.shape[0] == 3 + assert desc_invariant.shape[1] == 32 + assert desc_invariant_single_layer.shape[0] == 3 + assert desc_invariant_single_layer.shape[1] == 16 + assert desc.shape[0] == 3 + assert desc.shape[1] == 80 + assert desc_single_layer.shape[0] == 3 + assert desc_single_layer.shape[1] == 16 * 4 + assert desc_rotated_single_layer.shape[0] == 3 + assert desc_rotated_single_layer.shape[1] == 16 * 4 + + np.testing.assert_allclose(desc_invariant, desc_invariant_rotated, atol=1e-6) + np.testing.assert_allclose( + desc_invariant_single_layer, desc_invariant[:, :16], atol=1e-6 + ) + np.testing.assert_allclose( + desc_invariant_single_layer_rotated, desc_invariant[:, :16], atol=1e-6 + ) + np.testing.assert_allclose( + desc_single_layer[:, :16], desc_rotated_single_layer[:, :16], atol=1e-6 + ) + assert not np.allclose( + desc_single_layer[:, 16:], desc_rotated_single_layer[:, 16:], atol=1e-6 + ) + assert not np.allclose(desc, desc_rotated, atol=1e-6) + + +def test_mace_mp(capsys: pytest.CaptureFixture): + mp_mace = mace_mp() + assert isinstance(mp_mace, MACECalculator) + assert mp_mace.model_type == "MACE" + assert len(mp_mace.models) == 1 + assert isinstance(mp_mace.models[0], ScaleShiftMACE) + + _, stderr = capsys.readouterr() + assert stderr == "" + + +def test_mace_off(): + mace_off_model = mace_off(model="small", device="cpu") + assert isinstance(mace_off_model, MACECalculator) + assert mace_off_model.model_type == "MACE" + assert len(mace_off_model.models) == 1 + assert isinstance(mace_off_model.models[0], ScaleShiftMACE) + + atoms = build.molecule("H2O") + atoms.calc = mace_off_model + + E = atoms.get_potential_energy() + + assert np.allclose(E, -2081.116128586803, atol=1e-9) + + +@pytest.mark.skipif(not CUET_AVAILABLE, reason="cuequivariance not installed") +def test_mace_off_cueq(model="medium", device="cpu"): + mace_off_model = mace_off(model=model, device=device, enable_cueq=True) + assert isinstance(mace_off_model, MACECalculator) + assert mace_off_model.model_type == "MACE" + assert len(mace_off_model.models) == 1 + assert isinstance(mace_off_model.models[0], ScaleShiftMACE) + + atoms = build.molecule("H2O") + atoms.calc = mace_off_model + + E = atoms.get_potential_energy() + + assert np.allclose(E, -2081.116128586803, atol=1e-9) + + +def test_mace_mp_stresses(model="medium", device="cpu"): + atoms = build.bulk("Al", "fcc", a=4.05, cubic=True) + atoms = atoms.repeat((2, 2, 2)) + mace_mp_model = mace_mp(model=model, device=device, compute_atomic_stresses=True) + atoms.set_calculator(mace_mp_model) + stress = atoms.get_stress() + stresses = atoms.get_stresses() + assert stress.shape == (6,) + assert stresses.shape == (32, 6) + assert np.allclose(stress, stresses.sum(axis=0), atol=1e-6) diff --git a/mace-bench/3rdparty/mace/tests/test_cg.py b/mace-bench/3rdparty/mace/tests/test_cg.py index c6465fc..36b119b 100644 --- a/mace-bench/3rdparty/mace/tests/test_cg.py +++ b/mace-bench/3rdparty/mace/tests/test_cg.py @@ -1,12 +1,12 @@ -from e3nn import o3 - -from mace.tools import cg - - -def test_U_matrix(): - irreps_in = o3.Irreps("1x0e + 1x1o + 1x2e") - irreps_out = o3.Irreps("1x0e + 1x1o") - u_matrix = cg.U_matrix_real( - irreps_in=irreps_in, irreps_out=irreps_out, correlation=3 - )[-1] - assert u_matrix.shape == (3, 9, 9, 9, 21) +from e3nn import o3 + +from mace.tools import cg + + +def test_U_matrix(): + irreps_in = o3.Irreps("1x0e + 1x1o + 1x2e") + irreps_out = o3.Irreps("1x0e + 1x1o") + u_matrix = cg.U_matrix_real( + irreps_in=irreps_in, irreps_out=irreps_out, correlation=3 + )[-1] + assert u_matrix.shape == (3, 9, 9, 9, 21) diff --git a/mace-bench/3rdparty/mace/tests/test_compile.py b/mace-bench/3rdparty/mace/tests/test_compile.py index 9869441..d7d585e 100644 --- a/mace-bench/3rdparty/mace/tests/test_compile.py +++ b/mace-bench/3rdparty/mace/tests/test_compile.py @@ -1,154 +1,154 @@ -import os -from functools import wraps -from typing import Callable - -import numpy as np -import pytest -import torch -import torch.nn.functional as F -from e3nn import o3 -from torch.testing import assert_close - -from mace import data, modules, tools -from mace.tools import compile as mace_compile -from mace.tools import torch_geometric - -table = tools.AtomicNumberTable([6]) -atomic_energies = np.array([1.0], dtype=float) -cutoff = 5.0 - - -def create_mace(device: str, seed: int = 1702): - torch_geometric.seed_everything(seed) - - model_config = { - "r_max": cutoff, - "num_bessel": 8, - "num_polynomial_cutoff": 6, - "max_ell": 3, - "interaction_cls": modules.interaction_classes[ - "RealAgnosticResidualInteractionBlock" - ], - "interaction_cls_first": modules.interaction_classes[ - "RealAgnosticResidualInteractionBlock" - ], - "num_interactions": 2, - "num_elements": 1, - "hidden_irreps": o3.Irreps("128x0e + 128x1o"), - "MLP_irreps": o3.Irreps("16x0e"), - "gate": F.silu, - "atomic_energies": atomic_energies, - "avg_num_neighbors": 8, - "atomic_numbers": table.zs, - "correlation": 3, - "radial_type": "bessel", - "atomic_inter_scale": 1.0, - "atomic_inter_shift": 0.0, - } - model = modules.ScaleShiftMACE(**model_config) - return model.to(device) - - -def create_batch(device: str): - from ase import build - - size = 2 - atoms = build.bulk("C", "diamond", a=3.567, cubic=True) - atoms_list = [atoms.repeat((size, size, size))] - print("Number of atoms", len(atoms_list[0])) - - configs = [data.config_from_atoms(atoms) for atoms in atoms_list] - data_loader = torch_geometric.dataloader.DataLoader( - dataset=[ - data.AtomicData.from_config(config, z_table=table, cutoff=cutoff) - for config in configs - ], - batch_size=1, - shuffle=False, - drop_last=False, - ) - batch = next(iter(data_loader)) - batch = batch.to(device) - batch = batch.to_dict() - return batch - - -def time_func(func: Callable): - @wraps(func) - def wrapper(*args, **kwargs): - torch._inductor.cudagraph_mark_step_begin() # pylint: disable=W0212 - outputs = func(*args, **kwargs) - torch.cuda.synchronize() - return outputs - - return wrapper - - -@pytest.fixture(params=[torch.float32, torch.float64], ids=["fp32", "fp64"]) -def default_dtype(request): - with tools.torch_tools.default_dtype(request.param): - yield torch.get_default_dtype() - - -# skip if on windows -@pytest.mark.skipif(os.name == "nt", reason="Not supported on Windows") -@pytest.mark.parametrize("device", ["cpu", "cuda"]) -def test_mace(device, default_dtype): # pylint: disable=W0621 - print(f"using default dtype = {default_dtype}") - if device == "cuda" and not torch.cuda.is_available(): - pytest.skip(reason="cuda is not available") - - model_defaults = create_mace(device) - tmp_model = mace_compile.prepare(create_mace)(device) - model_compiled = torch.compile(tmp_model, mode="default") - - batch = create_batch(device) - output1 = model_defaults(batch, training=True) - output2 = model_compiled(batch, training=True) - assert_close(output1["energy"], output2["energy"]) - assert_close(output1["forces"], output2["forces"]) - - -@pytest.mark.skipif(os.name == "nt", reason="Not supported on Windows") -@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda is not available") -def test_eager_benchmark(benchmark, default_dtype): # pylint: disable=W0621 - print(f"using default dtype = {default_dtype}") - batch = create_batch("cuda") - model = create_mace("cuda") - model = time_func(model) - benchmark(model, batch, training=True) - - -@pytest.mark.skipif(os.name == "nt", reason="Not supported on Windows") -@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda is not available") -@pytest.mark.parametrize("compile_mode", ["default", "reduce-overhead", "max-autotune"]) -@pytest.mark.parametrize("enable_amp", [False, True], ids=["fp32", "mixed"]) -def test_compile_benchmark(benchmark, compile_mode, enable_amp): - if enable_amp: - pytest.skip(reason="autocast compiler assertion aten.slice_scatter.default") - - with tools.torch_tools.default_dtype(torch.float32): - batch = create_batch("cuda") - torch.compiler.reset() - model = mace_compile.prepare(create_mace)("cuda") - model = torch.compile(model, mode=compile_mode) - model = time_func(model) - - with torch.autocast("cuda", enabled=enable_amp): - benchmark(model, batch, training=True) - - -@pytest.mark.skipif(os.name == "nt", reason="Not supported on Windows") -@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda is not available") -def test_graph_breaks(): - import torch._dynamo as dynamo - - batch = create_batch("cuda") - model = mace_compile.prepare(create_mace)("cuda") - explanation = dynamo.explain(model)(batch, training=False) - - # these clutter the output but might be useful for investigating graph breaks - explanation.ops_per_graph = None - explanation.out_guards = None - print(explanation) - assert explanation.graph_break_count == 0 +import os +from functools import wraps +from typing import Callable + +import numpy as np +import pytest +import torch +import torch.nn.functional as F +from e3nn import o3 +from torch.testing import assert_close + +from mace import data, modules, tools +from mace.tools import compile as mace_compile +from mace.tools import torch_geometric + +table = tools.AtomicNumberTable([6]) +atomic_energies = np.array([1.0], dtype=float) +cutoff = 5.0 + + +def create_mace(device: str, seed: int = 1702): + torch_geometric.seed_everything(seed) + + model_config = { + "r_max": cutoff, + "num_bessel": 8, + "num_polynomial_cutoff": 6, + "max_ell": 3, + "interaction_cls": modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + "interaction_cls_first": modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + "num_interactions": 2, + "num_elements": 1, + "hidden_irreps": o3.Irreps("128x0e + 128x1o"), + "MLP_irreps": o3.Irreps("16x0e"), + "gate": F.silu, + "atomic_energies": atomic_energies, + "avg_num_neighbors": 8, + "atomic_numbers": table.zs, + "correlation": 3, + "radial_type": "bessel", + "atomic_inter_scale": 1.0, + "atomic_inter_shift": 0.0, + } + model = modules.ScaleShiftMACE(**model_config) + return model.to(device) + + +def create_batch(device: str): + from ase import build + + size = 2 + atoms = build.bulk("C", "diamond", a=3.567, cubic=True) + atoms_list = [atoms.repeat((size, size, size))] + print("Number of atoms", len(atoms_list[0])) + + configs = [data.config_from_atoms(atoms) for atoms in atoms_list] + data_loader = torch_geometric.dataloader.DataLoader( + dataset=[ + data.AtomicData.from_config(config, z_table=table, cutoff=cutoff) + for config in configs + ], + batch_size=1, + shuffle=False, + drop_last=False, + ) + batch = next(iter(data_loader)) + batch = batch.to(device) + batch = batch.to_dict() + return batch + + +def time_func(func: Callable): + @wraps(func) + def wrapper(*args, **kwargs): + torch._inductor.cudagraph_mark_step_begin() # pylint: disable=W0212 + outputs = func(*args, **kwargs) + torch.cuda.synchronize() + return outputs + + return wrapper + + +@pytest.fixture(params=[torch.float32, torch.float64], ids=["fp32", "fp64"]) +def default_dtype(request): + with tools.torch_tools.default_dtype(request.param): + yield torch.get_default_dtype() + + +# skip if on windows +@pytest.mark.skipif(os.name == "nt", reason="Not supported on Windows") +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +def test_mace(device, default_dtype): # pylint: disable=W0621 + print(f"using default dtype = {default_dtype}") + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip(reason="cuda is not available") + + model_defaults = create_mace(device) + tmp_model = mace_compile.prepare(create_mace)(device) + model_compiled = torch.compile(tmp_model, mode="default") + + batch = create_batch(device) + output1 = model_defaults(batch, training=True) + output2 = model_compiled(batch, training=True) + assert_close(output1["energy"], output2["energy"]) + assert_close(output1["forces"], output2["forces"]) + + +@pytest.mark.skipif(os.name == "nt", reason="Not supported on Windows") +@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda is not available") +def test_eager_benchmark(benchmark, default_dtype): # pylint: disable=W0621 + print(f"using default dtype = {default_dtype}") + batch = create_batch("cuda") + model = create_mace("cuda") + model = time_func(model) + benchmark(model, batch, training=True) + + +@pytest.mark.skipif(os.name == "nt", reason="Not supported on Windows") +@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda is not available") +@pytest.mark.parametrize("compile_mode", ["default", "reduce-overhead", "max-autotune"]) +@pytest.mark.parametrize("enable_amp", [False, True], ids=["fp32", "mixed"]) +def test_compile_benchmark(benchmark, compile_mode, enable_amp): + if enable_amp: + pytest.skip(reason="autocast compiler assertion aten.slice_scatter.default") + + with tools.torch_tools.default_dtype(torch.float32): + batch = create_batch("cuda") + torch.compiler.reset() + model = mace_compile.prepare(create_mace)("cuda") + model = torch.compile(model, mode=compile_mode) + model = time_func(model) + + with torch.autocast("cuda", enabled=enable_amp): + benchmark(model, batch, training=True) + + +@pytest.mark.skipif(os.name == "nt", reason="Not supported on Windows") +@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda is not available") +def test_graph_breaks(): + import torch._dynamo as dynamo + + batch = create_batch("cuda") + model = mace_compile.prepare(create_mace)("cuda") + explanation = dynamo.explain(model)(batch, training=False) + + # these clutter the output but might be useful for investigating graph breaks + explanation.ops_per_graph = None + explanation.out_guards = None + print(explanation) + assert explanation.graph_break_count == 0 diff --git a/mace-bench/3rdparty/mace/tests/test_cueq.py b/mace-bench/3rdparty/mace/tests/test_cueq.py index d76b25f..480a817 100644 --- a/mace-bench/3rdparty/mace/tests/test_cueq.py +++ b/mace-bench/3rdparty/mace/tests/test_cueq.py @@ -1,181 +1,181 @@ -# pylint: disable=wrong-import-position -import os -from copy import deepcopy -from typing import Any, Dict - -os.environ["TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD"] = "1" - -import pytest -import torch -import torch.nn.functional as F -from e3nn import o3 - -from mace import data, modules, tools -from mace.cli.convert_cueq_e3nn import run as run_cueq_to_e3nn -from mace.cli.convert_e3nn_cueq import run as run_e3nn_to_cueq -from mace.tools import torch_geometric - -try: - import cuequivariance as cue # pylint: disable=unused-import - - CUET_AVAILABLE = True -except ImportError: - CUET_AVAILABLE = False - -CUDA_AVAILABLE = torch.cuda.is_available() - - -@pytest.mark.skipif(not CUET_AVAILABLE, reason="cuequivariance not installed") -class TestCueq: - @pytest.fixture - def model_config(self, interaction_cls_first, hidden_irreps) -> Dict[str, Any]: - table = tools.AtomicNumberTable([6]) - return { - "r_max": 5.0, - "num_bessel": 8, - "num_polynomial_cutoff": 6, - "max_ell": 3, - "interaction_cls": modules.interaction_classes[ - "RealAgnosticResidualInteractionBlock" - ], - "interaction_cls_first": interaction_cls_first, - "num_interactions": 2, - "num_elements": 1, - "hidden_irreps": hidden_irreps, - "MLP_irreps": o3.Irreps("16x0e"), - "gate": F.silu, - "atomic_energies": torch.tensor([1.0]), - "avg_num_neighbors": 8, - "atomic_numbers": table.zs, - "correlation": 3, - "radial_type": "bessel", - "atomic_inter_scale": 1.0, - "atomic_inter_shift": 0.0, - } - - @pytest.fixture - def batch(self, device: str, default_dtype: torch.dtype) -> Dict[str, torch.Tensor]: - from ase import build - - torch.set_default_dtype(default_dtype) - - table = tools.AtomicNumberTable([6]) - - atoms = build.bulk("C", "diamond", a=3.567, cubic=True) - import numpy as np - - displacement = np.random.uniform(-0.1, 0.1, size=atoms.positions.shape) - atoms.positions += displacement - atoms_list = [atoms.repeat((2, 2, 2))] - - configs = [data.config_from_atoms(atoms) for atoms in atoms_list] - data_loader = torch_geometric.dataloader.DataLoader( - dataset=[ - data.AtomicData.from_config(config, z_table=table, cutoff=5.0) - for config in configs - ], - batch_size=1, - shuffle=False, - drop_last=False, - ) - batch = next(iter(data_loader)) - return batch.to(device).to_dict() - - @pytest.mark.parametrize( - "device", - ["cpu"] + (["cuda"] if CUDA_AVAILABLE else []), - ) - @pytest.mark.parametrize( - "interaction_cls_first", - [ - modules.interaction_classes["RealAgnosticResidualInteractionBlock"], - modules.interaction_classes["RealAgnosticInteractionBlock"], - modules.interaction_classes["RealAgnosticDensityInteractionBlock"], - ], - ) - @pytest.mark.parametrize( - "hidden_irreps", - [ - o3.Irreps("32x0e + 32x1o"), - o3.Irreps("32x0e + 32x1o + 32x2e"), - o3.Irreps("32x0e"), - ], - ) - @pytest.mark.parametrize("default_dtype", [torch.float32, torch.float64]) - def test_bidirectional_conversion( - self, - model_config: Dict[str, Any], - batch: Dict[str, torch.Tensor], - device: str, - default_dtype: torch.dtype, - ): - if device == "cuda" and not CUDA_AVAILABLE: - pytest.skip("CUDA not available") - torch.manual_seed(42) - - # Create original E3nn model - model_e3nn = modules.ScaleShiftMACE(**model_config).to(device) - - # Convert E3nn to CuEq - model_cueq = run_e3nn_to_cueq(model_e3nn).to(device) - - # Convert CuEq back to E3nn - model_e3nn_back = run_cueq_to_e3nn(model_cueq).to(device) - - # Test forward pass equivalence - out_e3nn = model_e3nn(deepcopy(batch), training=True, compute_stress=True) - out_cueq = model_cueq(deepcopy(batch), training=True, compute_stress=True) - out_e3nn_back = model_e3nn_back( - deepcopy(batch), training=True, compute_stress=True - ) - - # Check outputs match for both conversions - torch.testing.assert_close(out_e3nn["energy"], out_cueq["energy"]) - torch.testing.assert_close(out_cueq["energy"], out_e3nn_back["energy"]) - torch.testing.assert_close(out_e3nn["forces"], out_cueq["forces"]) - torch.testing.assert_close(out_cueq["forces"], out_e3nn_back["forces"]) - torch.testing.assert_close(out_e3nn["stress"], out_cueq["stress"]) - torch.testing.assert_close(out_cueq["stress"], out_e3nn_back["stress"]) - - # Test backward pass equivalence - loss_e3nn = out_e3nn["energy"].sum() - loss_cueq = out_cueq["energy"].sum() - loss_e3nn_back = out_e3nn_back["energy"].sum() - - loss_e3nn.backward() - loss_cueq.backward() - loss_e3nn_back.backward() - - # Compare gradients for all conversions - tol = 1e-4 if default_dtype == torch.float32 else 1e-7 - - def print_gradient_diff(name1, p1, name2, p2, conv_type): - if p1.grad is not None and p1.grad.shape == p2.grad.shape: - if name1.split(".", 2)[:2] == name2.split(".", 2)[:2]: - error = torch.abs(p1.grad - p2.grad) - print( - f"{conv_type} - Parameter {name1}/{name2}, Max error: {error.max()}" - ) - torch.testing.assert_close(p1.grad, p2.grad, atol=tol, rtol=tol) - - # E3nn to CuEq gradients - for (name_e3nn, p_e3nn), (name_cueq, p_cueq) in zip( - model_e3nn.named_parameters(), model_cueq.named_parameters() - ): - print_gradient_diff(name_e3nn, p_e3nn, name_cueq, p_cueq, "E3nn->CuEq") - - # CuEq to E3nn gradients - for (name_cueq, p_cueq), (name_e3nn_back, p_e3nn_back) in zip( - model_cueq.named_parameters(), model_e3nn_back.named_parameters() - ): - print_gradient_diff( - name_cueq, p_cueq, name_e3nn_back, p_e3nn_back, "CuEq->E3nn" - ) - - # Full circle comparison (E3nn -> E3nn) - for (name_e3nn, p_e3nn), (name_e3nn_back, p_e3nn_back) in zip( - model_e3nn.named_parameters(), model_e3nn_back.named_parameters() - ): - print_gradient_diff( - name_e3nn, p_e3nn, name_e3nn_back, p_e3nn_back, "Full circle" - ) +# pylint: disable=wrong-import-position +import os +from copy import deepcopy +from typing import Any, Dict + +os.environ["TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD"] = "1" + +import pytest +import torch +import torch.nn.functional as F +from e3nn import o3 + +from mace import data, modules, tools +from mace.cli.convert_cueq_e3nn import run as run_cueq_to_e3nn +from mace.cli.convert_e3nn_cueq import run as run_e3nn_to_cueq +from mace.tools import torch_geometric + +try: + import cuequivariance as cue # pylint: disable=unused-import + + CUET_AVAILABLE = True +except ImportError: + CUET_AVAILABLE = False + +CUDA_AVAILABLE = torch.cuda.is_available() + + +@pytest.mark.skipif(not CUET_AVAILABLE, reason="cuequivariance not installed") +class TestCueq: + @pytest.fixture + def model_config(self, interaction_cls_first, hidden_irreps) -> Dict[str, Any]: + table = tools.AtomicNumberTable([6]) + return { + "r_max": 5.0, + "num_bessel": 8, + "num_polynomial_cutoff": 6, + "max_ell": 3, + "interaction_cls": modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + "interaction_cls_first": interaction_cls_first, + "num_interactions": 2, + "num_elements": 1, + "hidden_irreps": hidden_irreps, + "MLP_irreps": o3.Irreps("16x0e"), + "gate": F.silu, + "atomic_energies": torch.tensor([1.0]), + "avg_num_neighbors": 8, + "atomic_numbers": table.zs, + "correlation": 3, + "radial_type": "bessel", + "atomic_inter_scale": 1.0, + "atomic_inter_shift": 0.0, + } + + @pytest.fixture + def batch(self, device: str, default_dtype: torch.dtype) -> Dict[str, torch.Tensor]: + from ase import build + + torch.set_default_dtype(default_dtype) + + table = tools.AtomicNumberTable([6]) + + atoms = build.bulk("C", "diamond", a=3.567, cubic=True) + import numpy as np + + displacement = np.random.uniform(-0.1, 0.1, size=atoms.positions.shape) + atoms.positions += displacement + atoms_list = [atoms.repeat((2, 2, 2))] + + configs = [data.config_from_atoms(atoms) for atoms in atoms_list] + data_loader = torch_geometric.dataloader.DataLoader( + dataset=[ + data.AtomicData.from_config(config, z_table=table, cutoff=5.0) + for config in configs + ], + batch_size=1, + shuffle=False, + drop_last=False, + ) + batch = next(iter(data_loader)) + return batch.to(device).to_dict() + + @pytest.mark.parametrize( + "device", + ["cpu"] + (["cuda"] if CUDA_AVAILABLE else []), + ) + @pytest.mark.parametrize( + "interaction_cls_first", + [ + modules.interaction_classes["RealAgnosticResidualInteractionBlock"], + modules.interaction_classes["RealAgnosticInteractionBlock"], + modules.interaction_classes["RealAgnosticDensityInteractionBlock"], + ], + ) + @pytest.mark.parametrize( + "hidden_irreps", + [ + o3.Irreps("32x0e + 32x1o"), + o3.Irreps("32x0e + 32x1o + 32x2e"), + o3.Irreps("32x0e"), + ], + ) + @pytest.mark.parametrize("default_dtype", [torch.float32, torch.float64]) + def test_bidirectional_conversion( + self, + model_config: Dict[str, Any], + batch: Dict[str, torch.Tensor], + device: str, + default_dtype: torch.dtype, + ): + if device == "cuda" and not CUDA_AVAILABLE: + pytest.skip("CUDA not available") + torch.manual_seed(42) + + # Create original E3nn model + model_e3nn = modules.ScaleShiftMACE(**model_config).to(device) + + # Convert E3nn to CuEq + model_cueq = run_e3nn_to_cueq(model_e3nn).to(device) + + # Convert CuEq back to E3nn + model_e3nn_back = run_cueq_to_e3nn(model_cueq).to(device) + + # Test forward pass equivalence + out_e3nn = model_e3nn(deepcopy(batch), training=True, compute_stress=True) + out_cueq = model_cueq(deepcopy(batch), training=True, compute_stress=True) + out_e3nn_back = model_e3nn_back( + deepcopy(batch), training=True, compute_stress=True + ) + + # Check outputs match for both conversions + torch.testing.assert_close(out_e3nn["energy"], out_cueq["energy"]) + torch.testing.assert_close(out_cueq["energy"], out_e3nn_back["energy"]) + torch.testing.assert_close(out_e3nn["forces"], out_cueq["forces"]) + torch.testing.assert_close(out_cueq["forces"], out_e3nn_back["forces"]) + torch.testing.assert_close(out_e3nn["stress"], out_cueq["stress"]) + torch.testing.assert_close(out_cueq["stress"], out_e3nn_back["stress"]) + + # Test backward pass equivalence + loss_e3nn = out_e3nn["energy"].sum() + loss_cueq = out_cueq["energy"].sum() + loss_e3nn_back = out_e3nn_back["energy"].sum() + + loss_e3nn.backward() + loss_cueq.backward() + loss_e3nn_back.backward() + + # Compare gradients for all conversions + tol = 1e-4 if default_dtype == torch.float32 else 1e-7 + + def print_gradient_diff(name1, p1, name2, p2, conv_type): + if p1.grad is not None and p1.grad.shape == p2.grad.shape: + if name1.split(".", 2)[:2] == name2.split(".", 2)[:2]: + error = torch.abs(p1.grad - p2.grad) + print( + f"{conv_type} - Parameter {name1}/{name2}, Max error: {error.max()}" + ) + torch.testing.assert_close(p1.grad, p2.grad, atol=tol, rtol=tol) + + # E3nn to CuEq gradients + for (name_e3nn, p_e3nn), (name_cueq, p_cueq) in zip( + model_e3nn.named_parameters(), model_cueq.named_parameters() + ): + print_gradient_diff(name_e3nn, p_e3nn, name_cueq, p_cueq, "E3nn->CuEq") + + # CuEq to E3nn gradients + for (name_cueq, p_cueq), (name_e3nn_back, p_e3nn_back) in zip( + model_cueq.named_parameters(), model_e3nn_back.named_parameters() + ): + print_gradient_diff( + name_cueq, p_cueq, name_e3nn_back, p_e3nn_back, "CuEq->E3nn" + ) + + # Full circle comparison (E3nn -> E3nn) + for (name_e3nn, p_e3nn), (name_e3nn_back, p_e3nn_back) in zip( + model_e3nn.named_parameters(), model_e3nn_back.named_parameters() + ): + print_gradient_diff( + name_e3nn, p_e3nn, name_e3nn_back, p_e3nn_back, "Full circle" + ) diff --git a/mace-bench/3rdparty/mace/tests/test_data.py b/mace-bench/3rdparty/mace/tests/test_data.py index 41180e8..6710ecd 100644 --- a/mace-bench/3rdparty/mace/tests/test_data.py +++ b/mace-bench/3rdparty/mace/tests/test_data.py @@ -1,213 +1,213 @@ -from copy import deepcopy -from pathlib import Path - -import ase.build -import h5py -import numpy as np -import torch - -from mace.data import ( - AtomicData, - Configuration, - HDF5Dataset, - config_from_atoms, - get_neighborhood, - save_configurations_as_HDF5, -) -from mace.tools import AtomicNumberTable, torch_geometric - -mace_path = Path(__file__).parent.parent - - -class TestAtomicData: - config = Configuration( - atomic_numbers=np.array([8, 1, 1]), - positions=np.array( - [ - [0.0, -2.0, 0.0], - [1.0, 0.0, 0.0], - [0.0, 1.0, 0.0], - ] - ), - properties={ - "forces": np.array( - [ - [0.0, -1.3, 0.0], - [1.0, 0.2, 0.0], - [0.0, 1.1, 0.3], - ] - ), - "energy": -1.5, - }, - property_weights={ - "forces": 1.0, - "energy": 1.0, - }, - ) - config_2 = deepcopy(config) - config_2.positions = config.positions + 0.01 - - table = AtomicNumberTable([1, 8]) - - def test_atomic_data(self): - data = AtomicData.from_config(self.config, z_table=self.table, cutoff=3.0) - - assert data.edge_index.shape == (2, 4) - assert data.forces.shape == (3, 3) - assert data.node_attrs.shape == (3, 2) - - def test_data_loader(self): - data1 = AtomicData.from_config(self.config, z_table=self.table, cutoff=3.0) - data2 = AtomicData.from_config(self.config, z_table=self.table, cutoff=3.0) - - data_loader = torch_geometric.dataloader.DataLoader( - dataset=[data1, data2], - batch_size=2, - shuffle=True, - drop_last=False, - ) - - for batch in data_loader: - assert batch.batch.shape == (6,) - assert batch.edge_index.shape == (2, 8) - assert batch.shifts.shape == (8, 3) - assert batch.positions.shape == (6, 3) - assert batch.node_attrs.shape == (6, 2) - assert batch.energy.shape == (2,) - assert batch.forces.shape == (6, 3) - - def test_to_atomic_data_dict(self): - data1 = AtomicData.from_config(self.config, z_table=self.table, cutoff=3.0) - data2 = AtomicData.from_config(self.config, z_table=self.table, cutoff=3.0) - - data_loader = torch_geometric.dataloader.DataLoader( - dataset=[data1, data2], - batch_size=2, - shuffle=True, - drop_last=False, - ) - for batch in data_loader: - batch_dict = batch.to_dict() - assert batch_dict["batch"].shape == (6,) - assert batch_dict["edge_index"].shape == (2, 8) - assert batch_dict["shifts"].shape == (8, 3) - assert batch_dict["positions"].shape == (6, 3) - assert batch_dict["node_attrs"].shape == (6, 2) - assert batch_dict["energy"].shape == (2,) - assert batch_dict["forces"].shape == (6, 3) - - def test_hdf5_dataloader(self): - datasets = [self.config, self.config_2] * 5 - # get path of the mace package - with h5py.File(str(mace_path) + "test.h5", "w") as f: - save_configurations_as_HDF5(datasets, 0, f) - train_dataset = HDF5Dataset( - str(mace_path) + "test.h5", z_table=self.table, r_max=3.0 - ) - train_loader = torch_geometric.dataloader.DataLoader( - dataset=train_dataset, - batch_size=2, - shuffle=False, - drop_last=False, - ) - batch_count = 0 - for batch in train_loader: - batch_count += 1 - assert batch.batch.shape == (6,) - assert batch.edge_index.shape == (2, 8) - assert batch.shifts.shape == (8, 3) - assert batch.positions.shape == (6, 3) - assert batch.node_attrs.shape == (6, 2) - assert batch.energy.shape == (2,) - assert batch.forces.shape == (6, 3) - print(batch_count, len(train_loader), len(train_dataset)) - assert batch_count == len(train_loader) == len(train_dataset) / 2 - train_loader_direct = torch_geometric.dataloader.DataLoader( - dataset=[ - AtomicData.from_config(config, z_table=self.table, cutoff=3.0) - for config in datasets - ], - batch_size=2, - shuffle=False, - drop_last=False, - ) - for batch_direct, batch in zip(train_loader_direct, train_loader): - assert torch.all(batch_direct.edge_index == batch.edge_index) - assert torch.all(batch_direct.shifts == batch.shifts) - assert torch.all(batch_direct.positions == batch.positions) - assert torch.all(batch_direct.node_attrs == batch.node_attrs) - assert torch.all(batch_direct.energy == batch.energy) - assert torch.all(batch_direct.forces == batch.forces) - - -class TestNeighborhood: - def test_basic(self): - positions = np.array( - [ - [-1.0, 0.0, 0.0], - [+0.0, 0.0, 0.0], - [+1.0, 0.0, 0.0], - ] - ) - - indices, shifts, unit_shifts, _ = get_neighborhood(positions, cutoff=1.5) - assert indices.shape == (2, 4) - assert shifts.shape == (4, 3) - assert unit_shifts.shape == (4, 3) - - def test_signs(self): - positions = np.array( - [ - [+0.5, 0.5, 0.0], - [+1.0, 1.0, 0.0], - ] - ) - - cell = np.array([[2.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]) - edge_index, shifts, unit_shifts, _ = get_neighborhood( - positions, cutoff=3.5, pbc=(True, False, False), cell=cell - ) - num_edges = 10 - assert edge_index.shape == (2, num_edges) - assert shifts.shape == (num_edges, 3) - assert unit_shifts.shape == (num_edges, 3) - - -# Based on mir-group/nequip -def test_periodic_edge(): - atoms = ase.build.bulk("Cu", "fcc") - dist = np.linalg.norm(atoms.cell[0]).item() - config = config_from_atoms(atoms) - edge_index, shifts, _, _ = get_neighborhood( - config.positions, cutoff=1.05 * dist, pbc=(True, True, True), cell=config.cell - ) - sender, receiver = edge_index - vectors = ( - config.positions[receiver] - config.positions[sender] + shifts - ) # [n_edges, 3] - assert vectors.shape == (12, 3) # 12 neighbors in close-packed bulk - assert np.allclose( - np.linalg.norm(vectors, axis=-1), - dist, - ) - - -def test_half_periodic(): - atoms = ase.build.fcc111("Al", size=(3, 3, 1), vacuum=0.0) - assert all(atoms.pbc == (True, True, False)) - config = config_from_atoms(atoms) # first shell dist is 2.864A - edge_index, shifts, _, _ = get_neighborhood( - config.positions, cutoff=2.9, pbc=(True, True, False), cell=config.cell - ) - sender, receiver = edge_index - vectors = ( - config.positions[receiver] - config.positions[sender] + shifts - ) # [n_edges, 3] - # Check number of neighbors: - _, neighbor_count = np.unique(edge_index[0], return_counts=True) - assert (neighbor_count == 6).all() # 6 neighbors - # Check not periodic in z - assert np.allclose( - vectors[:, 2], - np.zeros(vectors.shape[0]), - ) +from copy import deepcopy +from pathlib import Path + +import ase.build +import h5py +import numpy as np +import torch + +from mace.data import ( + AtomicData, + Configuration, + HDF5Dataset, + config_from_atoms, + get_neighborhood, + save_configurations_as_HDF5, +) +from mace.tools import AtomicNumberTable, torch_geometric + +mace_path = Path(__file__).parent.parent + + +class TestAtomicData: + config = Configuration( + atomic_numbers=np.array([8, 1, 1]), + positions=np.array( + [ + [0.0, -2.0, 0.0], + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + ] + ), + properties={ + "forces": np.array( + [ + [0.0, -1.3, 0.0], + [1.0, 0.2, 0.0], + [0.0, 1.1, 0.3], + ] + ), + "energy": -1.5, + }, + property_weights={ + "forces": 1.0, + "energy": 1.0, + }, + ) + config_2 = deepcopy(config) + config_2.positions = config.positions + 0.01 + + table = AtomicNumberTable([1, 8]) + + def test_atomic_data(self): + data = AtomicData.from_config(self.config, z_table=self.table, cutoff=3.0) + + assert data.edge_index.shape == (2, 4) + assert data.forces.shape == (3, 3) + assert data.node_attrs.shape == (3, 2) + + def test_data_loader(self): + data1 = AtomicData.from_config(self.config, z_table=self.table, cutoff=3.0) + data2 = AtomicData.from_config(self.config, z_table=self.table, cutoff=3.0) + + data_loader = torch_geometric.dataloader.DataLoader( + dataset=[data1, data2], + batch_size=2, + shuffle=True, + drop_last=False, + ) + + for batch in data_loader: + assert batch.batch.shape == (6,) + assert batch.edge_index.shape == (2, 8) + assert batch.shifts.shape == (8, 3) + assert batch.positions.shape == (6, 3) + assert batch.node_attrs.shape == (6, 2) + assert batch.energy.shape == (2,) + assert batch.forces.shape == (6, 3) + + def test_to_atomic_data_dict(self): + data1 = AtomicData.from_config(self.config, z_table=self.table, cutoff=3.0) + data2 = AtomicData.from_config(self.config, z_table=self.table, cutoff=3.0) + + data_loader = torch_geometric.dataloader.DataLoader( + dataset=[data1, data2], + batch_size=2, + shuffle=True, + drop_last=False, + ) + for batch in data_loader: + batch_dict = batch.to_dict() + assert batch_dict["batch"].shape == (6,) + assert batch_dict["edge_index"].shape == (2, 8) + assert batch_dict["shifts"].shape == (8, 3) + assert batch_dict["positions"].shape == (6, 3) + assert batch_dict["node_attrs"].shape == (6, 2) + assert batch_dict["energy"].shape == (2,) + assert batch_dict["forces"].shape == (6, 3) + + def test_hdf5_dataloader(self): + datasets = [self.config, self.config_2] * 5 + # get path of the mace package + with h5py.File(str(mace_path) + "test.h5", "w") as f: + save_configurations_as_HDF5(datasets, 0, f) + train_dataset = HDF5Dataset( + str(mace_path) + "test.h5", z_table=self.table, r_max=3.0 + ) + train_loader = torch_geometric.dataloader.DataLoader( + dataset=train_dataset, + batch_size=2, + shuffle=False, + drop_last=False, + ) + batch_count = 0 + for batch in train_loader: + batch_count += 1 + assert batch.batch.shape == (6,) + assert batch.edge_index.shape == (2, 8) + assert batch.shifts.shape == (8, 3) + assert batch.positions.shape == (6, 3) + assert batch.node_attrs.shape == (6, 2) + assert batch.energy.shape == (2,) + assert batch.forces.shape == (6, 3) + print(batch_count, len(train_loader), len(train_dataset)) + assert batch_count == len(train_loader) == len(train_dataset) / 2 + train_loader_direct = torch_geometric.dataloader.DataLoader( + dataset=[ + AtomicData.from_config(config, z_table=self.table, cutoff=3.0) + for config in datasets + ], + batch_size=2, + shuffle=False, + drop_last=False, + ) + for batch_direct, batch in zip(train_loader_direct, train_loader): + assert torch.all(batch_direct.edge_index == batch.edge_index) + assert torch.all(batch_direct.shifts == batch.shifts) + assert torch.all(batch_direct.positions == batch.positions) + assert torch.all(batch_direct.node_attrs == batch.node_attrs) + assert torch.all(batch_direct.energy == batch.energy) + assert torch.all(batch_direct.forces == batch.forces) + + +class TestNeighborhood: + def test_basic(self): + positions = np.array( + [ + [-1.0, 0.0, 0.0], + [+0.0, 0.0, 0.0], + [+1.0, 0.0, 0.0], + ] + ) + + indices, shifts, unit_shifts, _ = get_neighborhood(positions, cutoff=1.5) + assert indices.shape == (2, 4) + assert shifts.shape == (4, 3) + assert unit_shifts.shape == (4, 3) + + def test_signs(self): + positions = np.array( + [ + [+0.5, 0.5, 0.0], + [+1.0, 1.0, 0.0], + ] + ) + + cell = np.array([[2.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]) + edge_index, shifts, unit_shifts, _ = get_neighborhood( + positions, cutoff=3.5, pbc=(True, False, False), cell=cell + ) + num_edges = 10 + assert edge_index.shape == (2, num_edges) + assert shifts.shape == (num_edges, 3) + assert unit_shifts.shape == (num_edges, 3) + + +# Based on mir-group/nequip +def test_periodic_edge(): + atoms = ase.build.bulk("Cu", "fcc") + dist = np.linalg.norm(atoms.cell[0]).item() + config = config_from_atoms(atoms) + edge_index, shifts, _, _ = get_neighborhood( + config.positions, cutoff=1.05 * dist, pbc=(True, True, True), cell=config.cell + ) + sender, receiver = edge_index + vectors = ( + config.positions[receiver] - config.positions[sender] + shifts + ) # [n_edges, 3] + assert vectors.shape == (12, 3) # 12 neighbors in close-packed bulk + assert np.allclose( + np.linalg.norm(vectors, axis=-1), + dist, + ) + + +def test_half_periodic(): + atoms = ase.build.fcc111("Al", size=(3, 3, 1), vacuum=0.0) + assert all(atoms.pbc == (True, True, False)) + config = config_from_atoms(atoms) # first shell dist is 2.864A + edge_index, shifts, _, _ = get_neighborhood( + config.positions, cutoff=2.9, pbc=(True, True, False), cell=config.cell + ) + sender, receiver = edge_index + vectors = ( + config.positions[receiver] - config.positions[sender] + shifts + ) # [n_edges, 3] + # Check number of neighbors: + _, neighbor_count = np.unique(edge_index[0], return_counts=True) + assert (neighbor_count == 6).all() # 6 neighbors + # Check not periodic in z + assert np.allclose( + vectors[:, 2], + np.zeros(vectors.shape[0]), + ) diff --git a/mace-bench/3rdparty/mace/tests/test_finetuning_select.py b/mace-bench/3rdparty/mace/tests/test_finetuning_select.py index d8d9701..a58c8b3 100644 --- a/mace-bench/3rdparty/mace/tests/test_finetuning_select.py +++ b/mace-bench/3rdparty/mace/tests/test_finetuning_select.py @@ -1,164 +1,164 @@ -import ase.io as aio -import numpy as np -import pytest -from ase import Atoms -from ase.build import molecule - -from mace.cli.fine_tuning_select import ( - FilteringType, - SelectionSettings, - SubselectType, - _filter_pretraining_data, - _load_descriptors, - _maybe_save_descriptors, - filter_atoms, - select_samples, -) - - -@pytest.fixture(name="train_atoms_fixture") -def train_atoms(): - return [ - molecule("H2O"), - molecule("CH4"), - Atoms("Fe2O3"), - Atoms("C"), - Atoms("FeON"), - Atoms("Fe"), - ] - - -@pytest.fixture(name="train_atom_descriptors_fixture") -def train_atom_descriptors(train_atoms_fixture): - return [ - {x: np.zeros(5) + i for x in atoms.symbols} - for i, atoms in enumerate(train_atoms_fixture) - ] - - -@pytest.mark.parametrize( - "filtering_type, passes_filter, element_sublist", - [ - (FilteringType.NONE, [True] * 6, []), - (FilteringType.NONE, [True] * 6, ["C", "U", "Anything really"]), - ( - FilteringType.COMBINATIONS, - [False, False, True, False, False, True], - ["O", "Fe"], - ), - ( - FilteringType.INCLUSIVE, - [False, False, True, False, True, False], - ["O", "Fe"], - ), - ( - FilteringType.EXCLUSIVE, - [False, False, True, False, False, False], - ["O", "Fe"], - ), - ], -) -def test_filter_data( - train_atoms_fixture, filtering_type, passes_filter, element_sublist -): - filtered, _, passes = _filter_pretraining_data( - train_atoms_fixture, filtering_type, element_sublist - ) - assert passes == passes_filter - assert len(filtered) == sum(passes_filter) - - -@pytest.mark.parametrize( - "passes_filter", [[True] * 6, [False, True, False, True, False, True]] -) -def test_load_descriptors( - train_atoms_fixture, train_atom_descriptors_fixture, passes_filter, tmp_path -): - for i, atoms in enumerate(train_atoms_fixture): - atoms.info["mace_descriptors"] = train_atom_descriptors_fixture[i] - save_path = tmp_path / "test.xyz" - _maybe_save_descriptors(train_atoms_fixture, save_path.as_posix()) - assert all(not "mace_descriptors" in atoms.info for atoms in train_atoms_fixture) - filtered_atoms = [ - x for x, passes in zip(train_atoms_fixture, passes_filter) if passes - ] - descriptors_path = save_path.as_posix().replace(".xyz", "_descriptors.npy") - - _load_descriptors( - filtered_atoms, - passes_filter, - descriptors_path=descriptors_path, - calc=None, - full_data_length=len(train_atoms_fixture), - ) - expected_descriptors = [ - train_atom_descriptors_fixture[i] - for i, passes in enumerate(passes_filter) - if passes - ] - for i, atoms in enumerate(filtered_atoms): - assert "mace_descriptors" in atoms.info - for key, value in expected_descriptors[i].items(): - assert np.allclose(atoms.info["mace_descriptors"][key], value) - - -def test_select_samples_random(train_atoms_fixture, tmp_path): - input_file_path = tmp_path / "input.xyz" - aio.write(input_file_path, train_atoms_fixture, format="extxyz") - output_file_path = tmp_path / "output.xyz" - - settings = SelectionSettings( - configs_pt=input_file_path.as_posix(), - output=output_file_path.as_posix(), - num_samples=2, - subselect=SubselectType.RANDOM, - filtering_type=FilteringType.NONE, - ) - select_samples(settings) - - # Check if output file is created - assert output_file_path.exists() - combined_output_file_path = tmp_path / "output_combined.xyz" - assert combined_output_file_path.exists() - - output_atoms = aio.read(output_file_path, index=":") - assert isinstance(output_atoms, list) - assert len(output_atoms) == 2 - - combined_output_atoms = aio.read(combined_output_file_path, index=":") - assert isinstance(combined_output_atoms, list) - assert ( - len(combined_output_atoms) == 2 - ) # combined same as output since no FT data provided - - -def test_select_samples_ft_provided(train_atoms_fixture, tmp_path): - input_file_path = tmp_path / "input.xyz" - aio.write(input_file_path, train_atoms_fixture, format="extxyz") - output_file_path = tmp_path / "output.xyz" - ft_file_path = tmp_path / "ft_data.xyz" - ft_data = [Atoms("FeO")] - aio.write(ft_file_path.as_posix(), ft_data, format="extxyz") - - settings = SelectionSettings( - configs_pt=input_file_path.as_posix(), - output=output_file_path.as_posix(), - num_samples=2, - subselect=SubselectType.RANDOM, - configs_ft=ft_file_path.as_posix(), - ) - select_samples(settings) - - # Check if output file is created - assert output_file_path.exists() - combined_output_file_path = tmp_path / "output_combined.xyz" - assert combined_output_file_path.exists() - - output_atoms = aio.read(output_file_path, index=":") - assert isinstance(output_atoms, list) - assert len(output_atoms) == 2 - assert all(filter_atoms(x, ["Fe", "O"]) for x in output_atoms) - - combined_atoms = aio.read(combined_output_file_path, index=":") - assert isinstance(combined_atoms, list) - assert len(combined_atoms) == len(output_atoms) + len(ft_data) +import ase.io as aio +import numpy as np +import pytest +from ase import Atoms +from ase.build import molecule + +from mace.cli.fine_tuning_select import ( + FilteringType, + SelectionSettings, + SubselectType, + _filter_pretraining_data, + _load_descriptors, + _maybe_save_descriptors, + filter_atoms, + select_samples, +) + + +@pytest.fixture(name="train_atoms_fixture") +def train_atoms(): + return [ + molecule("H2O"), + molecule("CH4"), + Atoms("Fe2O3"), + Atoms("C"), + Atoms("FeON"), + Atoms("Fe"), + ] + + +@pytest.fixture(name="train_atom_descriptors_fixture") +def train_atom_descriptors(train_atoms_fixture): + return [ + {x: np.zeros(5) + i for x in atoms.symbols} + for i, atoms in enumerate(train_atoms_fixture) + ] + + +@pytest.mark.parametrize( + "filtering_type, passes_filter, element_sublist", + [ + (FilteringType.NONE, [True] * 6, []), + (FilteringType.NONE, [True] * 6, ["C", "U", "Anything really"]), + ( + FilteringType.COMBINATIONS, + [False, False, True, False, False, True], + ["O", "Fe"], + ), + ( + FilteringType.INCLUSIVE, + [False, False, True, False, True, False], + ["O", "Fe"], + ), + ( + FilteringType.EXCLUSIVE, + [False, False, True, False, False, False], + ["O", "Fe"], + ), + ], +) +def test_filter_data( + train_atoms_fixture, filtering_type, passes_filter, element_sublist +): + filtered, _, passes = _filter_pretraining_data( + train_atoms_fixture, filtering_type, element_sublist + ) + assert passes == passes_filter + assert len(filtered) == sum(passes_filter) + + +@pytest.mark.parametrize( + "passes_filter", [[True] * 6, [False, True, False, True, False, True]] +) +def test_load_descriptors( + train_atoms_fixture, train_atom_descriptors_fixture, passes_filter, tmp_path +): + for i, atoms in enumerate(train_atoms_fixture): + atoms.info["mace_descriptors"] = train_atom_descriptors_fixture[i] + save_path = tmp_path / "test.xyz" + _maybe_save_descriptors(train_atoms_fixture, save_path.as_posix()) + assert all(not "mace_descriptors" in atoms.info for atoms in train_atoms_fixture) + filtered_atoms = [ + x for x, passes in zip(train_atoms_fixture, passes_filter) if passes + ] + descriptors_path = save_path.as_posix().replace(".xyz", "_descriptors.npy") + + _load_descriptors( + filtered_atoms, + passes_filter, + descriptors_path=descriptors_path, + calc=None, + full_data_length=len(train_atoms_fixture), + ) + expected_descriptors = [ + train_atom_descriptors_fixture[i] + for i, passes in enumerate(passes_filter) + if passes + ] + for i, atoms in enumerate(filtered_atoms): + assert "mace_descriptors" in atoms.info + for key, value in expected_descriptors[i].items(): + assert np.allclose(atoms.info["mace_descriptors"][key], value) + + +def test_select_samples_random(train_atoms_fixture, tmp_path): + input_file_path = tmp_path / "input.xyz" + aio.write(input_file_path, train_atoms_fixture, format="extxyz") + output_file_path = tmp_path / "output.xyz" + + settings = SelectionSettings( + configs_pt=input_file_path.as_posix(), + output=output_file_path.as_posix(), + num_samples=2, + subselect=SubselectType.RANDOM, + filtering_type=FilteringType.NONE, + ) + select_samples(settings) + + # Check if output file is created + assert output_file_path.exists() + combined_output_file_path = tmp_path / "output_combined.xyz" + assert combined_output_file_path.exists() + + output_atoms = aio.read(output_file_path, index=":") + assert isinstance(output_atoms, list) + assert len(output_atoms) == 2 + + combined_output_atoms = aio.read(combined_output_file_path, index=":") + assert isinstance(combined_output_atoms, list) + assert ( + len(combined_output_atoms) == 2 + ) # combined same as output since no FT data provided + + +def test_select_samples_ft_provided(train_atoms_fixture, tmp_path): + input_file_path = tmp_path / "input.xyz" + aio.write(input_file_path, train_atoms_fixture, format="extxyz") + output_file_path = tmp_path / "output.xyz" + ft_file_path = tmp_path / "ft_data.xyz" + ft_data = [Atoms("FeO")] + aio.write(ft_file_path.as_posix(), ft_data, format="extxyz") + + settings = SelectionSettings( + configs_pt=input_file_path.as_posix(), + output=output_file_path.as_posix(), + num_samples=2, + subselect=SubselectType.RANDOM, + configs_ft=ft_file_path.as_posix(), + ) + select_samples(settings) + + # Check if output file is created + assert output_file_path.exists() + combined_output_file_path = tmp_path / "output_combined.xyz" + assert combined_output_file_path.exists() + + output_atoms = aio.read(output_file_path, index=":") + assert isinstance(output_atoms, list) + assert len(output_atoms) == 2 + assert all(filter_atoms(x, ["Fe", "O"]) for x in output_atoms) + + combined_atoms = aio.read(combined_output_file_path, index=":") + assert isinstance(combined_atoms, list) + assert len(combined_atoms) == len(output_atoms) + len(ft_data) diff --git a/mace-bench/3rdparty/mace/tests/test_foundations.py b/mace-bench/3rdparty/mace/tests/test_foundations.py index cb19a3e..c864183 100644 --- a/mace-bench/3rdparty/mace/tests/test_foundations.py +++ b/mace-bench/3rdparty/mace/tests/test_foundations.py @@ -1,512 +1,512 @@ -from pathlib import Path - -import numpy as np -import pytest -import torch -import torch.nn.functional -from ase.build import molecule -from e3nn import o3 -from e3nn.util import jit -from scipy.spatial.transform import Rotation as R - -from mace import data, modules, tools -from mace.calculators import mace_mp, mace_off -from mace.tools import torch_geometric -from mace.tools.finetuning_utils import load_foundations_elements -from mace.tools.scripts_utils import extract_config_mace_model, remove_pt_head -from mace.tools.utils import AtomicNumberTable - -MODEL_PATH = ( - Path(__file__).parent.parent - / "mace" - / "calculators" - / "foundations_models" - / "2023-12-03-mace-mp.model" -) - -torch.set_default_dtype(torch.float64) - -@pytest.skip("Problem with the float type", allow_module_level=True) -def test_foundations(): - # Create MACE model - config = data.Configuration( - atomic_numbers=molecule("H2COH").numbers, - positions=molecule("H2COH").positions, - properties={ - "forces": molecule("H2COH").positions, - "energy": -1.5, - "charges": molecule("H2COH").numbers, - "dipole": np.array([-1.5, 1.5, 2.0]), - }, - property_weights={ - "forces": 1.0, - "energy": 1.0, - "charges": 1.0, - "dipole": 1.0, - }, - ) - - # Created the rotated environment - rot = R.from_euler("z", 60, degrees=True).as_matrix() - positions_rotated = np.array(rot @ config.positions.T).T - config_rotated = data.Configuration( - atomic_numbers=molecule("H2COH").numbers, - positions=positions_rotated, - properties={ - "forces": molecule("H2COH").positions, - "energy": -1.5, - "charges": molecule("H2COH").numbers, - "dipole": np.array([-1.5, 1.5, 2.0]), - }, - property_weights={ - "forces": 1.0, - "energy": 1.0, - "charges": 1.0, - "dipole": 1.0, - }, - ) - table = tools.AtomicNumberTable([1, 6, 8]) - atomic_energies = np.array([0.0, 0.0, 0.0], dtype=float) - model_config = dict( - r_max=6, - num_bessel=10, - num_polynomial_cutoff=5, - max_ell=3, - interaction_cls=modules.interaction_classes[ - "RealAgnosticResidualInteractionBlock" - ], - interaction_cls_first=modules.interaction_classes[ - "RealAgnosticResidualInteractionBlock" - ], - num_interactions=2, - num_elements=3, - hidden_irreps=o3.Irreps("128x0e + 128x1o"), - MLP_irreps=o3.Irreps("16x0e"), - gate=torch.nn.functional.silu, - atomic_energies=atomic_energies, - avg_num_neighbors=3, - atomic_numbers=table.zs, - correlation=3, - radial_type="bessel", - atomic_inter_scale=0.1, - atomic_inter_shift=0.0, - ) - model = modules.ScaleShiftMACE(**model_config) - calc_foundation = mace_mp(model="medium", device="cpu", default_dtype="float64") - model_loaded = load_foundations_elements( - model, - calc_foundation.models[0], - table=table, - load_readout=True, - use_shift=False, - max_L=1, - ) - atomic_data = data.AtomicData.from_config(config, z_table=table, cutoff=6.0) - atomic_data2 = data.AtomicData.from_config( - config_rotated, z_table=table, cutoff=6.0 - ) - - data_loader = torch_geometric.dataloader.DataLoader( - dataset=[atomic_data, atomic_data2], - batch_size=2, - shuffle=True, - drop_last=False, - ) - batch = next(iter(data_loader)) - forces_loaded = model_loaded(batch.to_dict())["forces"] - forces = model(batch.to_dict())["forces"] - assert torch.allclose(forces, forces_loaded) - - -def test_multi_reference(): - config_multi = data.Configuration( - atomic_numbers=molecule("H2COH").numbers, - positions=molecule("H2COH").positions, - properties={ - "forces": molecule("H2COH").positions, - "energy": -1.5, - "charges": molecule("H2COH").numbers, - "dipole": np.array([-1.5, 1.5, 2.0]), - }, - property_weights={ - "forces": 1.0, - "energy": 1.0, - "charges": 1.0, - "dipole": 1.0, - }, - head="MP2", - ) - table_multi = tools.AtomicNumberTable([1, 6, 8]) - atomic_energies_multi = np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], dtype=float) - table = tools.AtomicNumberTable([1, 6, 8]) - - - # Create MACE model - model_config = dict( - r_max=6, - num_bessel=10, - num_polynomial_cutoff=5, - max_ell=3, - interaction_cls=modules.interaction_classes[ - "RealAgnosticResidualInteractionBlock" - ], - interaction_cls_first=modules.interaction_classes[ - "RealAgnosticResidualInteractionBlock" - ], - num_interactions=2, - num_elements=3, - hidden_irreps=o3.Irreps("128x0e + 128x1o"), - MLP_irreps=o3.Irreps("16x0e"), - gate=torch.nn.functional.silu, - atomic_energies=atomic_energies_multi, - avg_num_neighbors=61, - atomic_numbers=table.zs, - correlation=3, - radial_type="bessel", - atomic_inter_scale=[1.0, 1.0], - atomic_inter_shift=[0.0, 0.0], - heads=["MP2", "DFT"], - ) - model = modules.ScaleShiftMACE(**model_config) - calc_foundation = mace_mp(model="medium", device="cpu", default_dtype="float64") - model_loaded = load_foundations_elements( - model, - calc_foundation.models[0], - table=table, - load_readout=True, - use_shift=False, - max_L=1, - ) - atomic_data = data.AtomicData.from_config( - config_multi, z_table=table_multi, cutoff=6.0, heads=["MP2", "DFT"] - ) - data_loader = torch_geometric.dataloader.DataLoader( - dataset=[atomic_data, atomic_data], - batch_size=2, - shuffle=True, - drop_last=False, - ) - batch = next(iter(data_loader)) - forces_loaded = model_loaded(batch.to_dict())["forces"] - calc_foundation = mace_mp(model="medium", device="cpu", default_dtype="float64") - atoms = molecule("H2COH") - atoms.info["head"] = "MP2" - atoms.calc = calc_foundation - forces = atoms.get_forces() - assert np.allclose( - forces, forces_loaded.detach().numpy()[:5, :], atol=1e-5, rtol=1e-5 - ) - - -@pytest.mark.parametrize( - "calc", - [ - mace_mp(device="cpu", default_dtype="float64"), - mace_mp(model="small", device="cpu", default_dtype="float64"), - mace_mp(model="medium", device="cpu", default_dtype="float64"), - mace_mp(model="large", device="cpu", default_dtype="float64"), - mace_mp(model=MODEL_PATH, device="cpu", default_dtype="float64"), - mace_off(model="small", device="cpu", default_dtype="float64"), - mace_off(model="medium", device="cpu", default_dtype="float64"), - mace_off(model="large", device="cpu", default_dtype="float64"), - mace_off(model=MODEL_PATH, device="cpu", default_dtype="float64"), - ], -) -def test_compile_foundation(calc): - model = calc.models[0] - atoms = molecule("CH4") - atoms.positions += np.random.randn(*atoms.positions.shape) * 0.1 - batch = calc._atoms_to_batch(atoms) # pylint: disable=protected-access - output_1 = model(batch.to_dict()) - model_compiled = jit.compile(model) - output = model_compiled(batch.to_dict()) - for key in output_1.keys(): - if isinstance(output_1[key], torch.Tensor): - assert torch.allclose(output_1[key], output[key], atol=1e-5) - - -@pytest.mark.parametrize( - "model", - [ - mace_mp(model="small", device="cpu", default_dtype="float64").models[0], - mace_mp(model="medium", device="cpu", default_dtype="float64").models[0], - mace_mp(model="large", device="cpu", default_dtype="float64").models[0], - mace_mp(model=MODEL_PATH, device="cpu", default_dtype="float64").models[0], - mace_off(model="small", device="cpu", default_dtype="float64").models[0], - mace_off(model="medium", device="cpu", default_dtype="float64").models[0], - mace_off(model="large", device="cpu", default_dtype="float64").models[0], - mace_off(model=MODEL_PATH, device="cpu", default_dtype="float64").models[0], - ], -) -def test_extract_config(model): - assert isinstance(model, modules.ScaleShiftMACE) - config = data.Configuration( - atomic_numbers=molecule("H2COH").numbers, - positions=molecule("H2COH").positions, - properties={ - "forces": molecule("H2COH").positions, - "energy": -1.5, - "charges": molecule("H2COH").numbers, - "dipole": np.array([-1.5, 1.5, 2.0]), - }, - property_weights={ - "forces": 1.0, - "energy": 1.0, - "charges": 1.0, - "dipole": 1.0, - }, - ) - model_copy = modules.ScaleShiftMACE(**extract_config_mace_model(model)) - model_copy.load_state_dict(model.state_dict()) - z_table = AtomicNumberTable([int(z) for z in model.atomic_numbers]) - atomic_data = data.AtomicData.from_config(config, z_table=z_table, cutoff=6.0) - data_loader = torch_geometric.dataloader.DataLoader( - dataset=[atomic_data, atomic_data], - batch_size=2, - shuffle=True, - drop_last=False, - ) - batch = next(iter(data_loader)) - output = model(batch.to_dict()) - output_copy = model_copy(batch.to_dict()) - # assert all items of the output dicts are equal - for key in output.keys(): - if isinstance(output[key], torch.Tensor): - assert torch.allclose(output[key], output_copy[key], atol=1e-5) - - -def test_remove_pt_head(): - # Set up test data - torch.manual_seed(42) - atomic_energies_pt_head = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=float) - z_table = AtomicNumberTable([1, 8]) # H and O - - # Create multihead model - model_config = { - "r_max": 5.0, - "num_bessel": 8, - "num_polynomial_cutoff": 5, - "max_ell": 2, - "interaction_cls": modules.interaction_classes[ - "RealAgnosticResidualInteractionBlock" - ], - "interaction_cls_first": modules.interaction_classes[ - "RealAgnosticResidualInteractionBlock" - ], - "num_interactions": 2, - "num_elements": len(z_table), - "hidden_irreps": o3.Irreps("32x0e + 32x1o"), - "MLP_irreps": o3.Irreps("16x0e"), - "gate": torch.nn.functional.silu, - "atomic_energies": atomic_energies_pt_head, - "avg_num_neighbors": 8, - "atomic_numbers": z_table.zs, - "correlation": 3, - "heads": ["pt_head", "DFT"], - "atomic_inter_scale": [1.0, 1.0], - "atomic_inter_shift": [0.0, 0.1], - } - - model = modules.ScaleShiftMACE(**model_config) - - # Create test molecule - mol = molecule("H2O") - config_pt_head = data.Configuration( - atomic_numbers=mol.numbers, - positions=mol.positions, - properties={"energy": 1.0, "forces": np.random.randn(len(mol), 3)}, - property_weights={"forces": 1.0, "energy": 1.0}, - head="DFT", - ) - atomic_data = data.AtomicData.from_config( - config_pt_head, z_table=z_table, cutoff=5.0, heads=["pt_head", "DFT"] - ) - dataloader = torch_geometric.dataloader.DataLoader( - dataset=[atomic_data], batch_size=1, shuffle=False - ) - batch = next(iter(dataloader)) - # Test original mode - output_orig = model(batch.to_dict()) - - # Convert to single head model - new_model = remove_pt_head(model, head_to_keep="DFT") - - # Basic structure tests - assert len(new_model.heads) == 1 - assert new_model.heads[0] == "DFT" - assert new_model.atomic_energies_fn.atomic_energies.shape[0] == 1 - assert len(torch.atleast_1d(new_model.scale_shift.scale)) == 1 - assert len(torch.atleast_1d(new_model.scale_shift.shift)) == 1 - - # Test output consistency - atomic_data = data.AtomicData.from_config( - config_pt_head, z_table=z_table, cutoff=5.0, heads=["DFT"] - ) - dataloader = torch_geometric.dataloader.DataLoader( - dataset=[atomic_data], batch_size=1, shuffle=False - ) - batch = next(iter(dataloader)) - output_new = new_model(batch.to_dict()) - torch.testing.assert_close( - output_orig["energy"], output_new["energy"], rtol=1e-5, atol=1e-5 - ) - torch.testing.assert_close( - output_orig["forces"], output_new["forces"], rtol=1e-5, atol=1e-5 - ) - - -def test_remove_pt_head_multihead(): - # Set up test data - torch.manual_seed(42) - atomic_energies_pt_head = np.array( - [ - [1.0, 2.0], # H energies for each head - [3.0, 4.0], # O energies for each head - ] - * 2 - ) - z_table = AtomicNumberTable([1, 8]) # H and O - - # Create multihead model - model_config = { - "r_max": 5.0, - "num_bessel": 8, - "num_polynomial_cutoff": 5, - "max_ell": 2, - "interaction_cls": modules.interaction_classes[ - "RealAgnosticResidualInteractionBlock" - ], - "interaction_cls_first": modules.interaction_classes[ - "RealAgnosticResidualInteractionBlock" - ], - "num_interactions": 2, - "num_elements": len(z_table), - "hidden_irreps": o3.Irreps("32x0e + 32x1o"), - "MLP_irreps": o3.Irreps("16x0e"), - "gate": torch.nn.functional.silu, - "atomic_energies": atomic_energies_pt_head, - "avg_num_neighbors": 8, - "atomic_numbers": z_table.zs, - "correlation": 3, - "heads": ["pt_head", "DFT", "MP2", "CCSD"], - "atomic_inter_scale": [1.0, 1.0, 1.0, 1.0], - "atomic_inter_shift": [0.0, 0.1, 0.2, 0.3], - } - - model = modules.ScaleShiftMACE(**model_config) - - # Create test configurations for each head - mol = molecule("H2O") - configs = {} - atomic_datas = {} - dataloaders = {} - original_outputs = {} - - # First get outputs from original model for each head - for head in model.heads: - config_pt_head = data.Configuration( - atomic_numbers=mol.numbers, - positions=mol.positions, - properties={"energy": 1.0, "forces": np.random.randn(len(mol), 3)}, - property_weights={"forces": 1.0, "energy": 1.0}, - head=head, - ) - configs[head] = config_pt_head - - atomic_data = data.AtomicData.from_config( - config_pt_head, z_table=z_table, cutoff=5.0, heads=model.heads - ) - atomic_datas[head] = atomic_data - - dataloader = torch_geometric.dataloader.DataLoader( - dataset=[atomic_data], batch_size=1, shuffle=False - ) - dataloaders[head] = dataloader - - batch = next(iter(dataloader)) - output = model(batch.to_dict()) - original_outputs[head] = output - - # Now test each head separately - for i, head in enumerate(model.heads): - # Convert to single head model - new_model = remove_pt_head(model, head_to_keep=head) - - # Basic structure tests - assert len(new_model.heads) == 1, f"Failed for head {head}" - assert new_model.heads[0] == head, f"Failed for head {head}" - assert ( - new_model.atomic_energies_fn.atomic_energies.shape[0] == 1 - ), f"Failed for head {head}" - assert ( - len(torch.atleast_1d(new_model.scale_shift.scale)) == 1 - ), f"Failed for head {head}" - assert ( - len(torch.atleast_1d(new_model.scale_shift.shift)) == 1 - ), f"Failed for head {head}" - - # Verify scale and shift values - assert torch.allclose( - new_model.scale_shift.scale, model.scale_shift.scale[i : i + 1] - ), f"Failed for head {head}" - assert torch.allclose( - new_model.scale_shift.shift, model.scale_shift.shift[i : i + 1] - ), f"Failed for head {head}" - - # Test output consistency - single_head_data = data.AtomicData.from_config( - configs[head], z_table=z_table, cutoff=5.0, heads=[head] - ) - single_head_loader = torch_geometric.dataloader.DataLoader( - dataset=[single_head_data], batch_size=1, shuffle=False - ) - batch = next(iter(single_head_loader)) - new_output = new_model(batch.to_dict()) - - # Compare outputs - print( - original_outputs[head]["energy"], - new_output["energy"], - ) - torch.testing.assert_close( - original_outputs[head]["energy"], - new_output["energy"], - rtol=1e-5, - atol=1e-5, - msg=f"Energy mismatch for head {head}", - ) - torch.testing.assert_close( - original_outputs[head]["forces"], - new_output["forces"], - rtol=1e-5, - atol=1e-5, - msg=f"Forces mismatch for head {head}", - ) - - # Test error cases - with pytest.raises(ValueError, match="Head non_existent not found in model"): - remove_pt_head(model, head_to_keep="non_existent") - - # Test default behavior (first non-PT head) - default_model = remove_pt_head(model) - assert default_model.heads[0] == "DFT" - - # Additional test: check if each model's computation graph is independent - models = {head: remove_pt_head(model, head_to_keep=head) for head in model.heads} - results = {} - - for head, head_model in models.items(): - single_head_data = data.AtomicData.from_config( - configs[head], z_table=z_table, cutoff=5.0, heads=[head] - ) - single_head_loader = torch_geometric.dataloader.DataLoader( - dataset=[single_head_data], batch_size=1, shuffle=False - ) - batch = next(iter(single_head_loader)) - results[head] = head_model(batch.to_dict()) - - # Verify each model produces different outputs - energies = torch.stack([results[head]["energy"] for head in model.heads]) - assert not torch.allclose( - energies[0], energies[1], rtol=1e-3 - ), "Different heads should produce different outputs" +from pathlib import Path + +import numpy as np +import pytest +import torch +import torch.nn.functional +from ase.build import molecule +from e3nn import o3 +from e3nn.util import jit +from scipy.spatial.transform import Rotation as R + +from mace import data, modules, tools +from mace.calculators import mace_mp, mace_off +from mace.tools import torch_geometric +from mace.tools.finetuning_utils import load_foundations_elements +from mace.tools.scripts_utils import extract_config_mace_model, remove_pt_head +from mace.tools.utils import AtomicNumberTable + +MODEL_PATH = ( + Path(__file__).parent.parent + / "mace" + / "calculators" + / "foundations_models" + / "2023-12-03-mace-mp.model" +) + +torch.set_default_dtype(torch.float64) + +@pytest.skip("Problem with the float type", allow_module_level=True) +def test_foundations(): + # Create MACE model + config = data.Configuration( + atomic_numbers=molecule("H2COH").numbers, + positions=molecule("H2COH").positions, + properties={ + "forces": molecule("H2COH").positions, + "energy": -1.5, + "charges": molecule("H2COH").numbers, + "dipole": np.array([-1.5, 1.5, 2.0]), + }, + property_weights={ + "forces": 1.0, + "energy": 1.0, + "charges": 1.0, + "dipole": 1.0, + }, + ) + + # Created the rotated environment + rot = R.from_euler("z", 60, degrees=True).as_matrix() + positions_rotated = np.array(rot @ config.positions.T).T + config_rotated = data.Configuration( + atomic_numbers=molecule("H2COH").numbers, + positions=positions_rotated, + properties={ + "forces": molecule("H2COH").positions, + "energy": -1.5, + "charges": molecule("H2COH").numbers, + "dipole": np.array([-1.5, 1.5, 2.0]), + }, + property_weights={ + "forces": 1.0, + "energy": 1.0, + "charges": 1.0, + "dipole": 1.0, + }, + ) + table = tools.AtomicNumberTable([1, 6, 8]) + atomic_energies = np.array([0.0, 0.0, 0.0], dtype=float) + model_config = dict( + r_max=6, + num_bessel=10, + num_polynomial_cutoff=5, + max_ell=3, + interaction_cls=modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + interaction_cls_first=modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + num_interactions=2, + num_elements=3, + hidden_irreps=o3.Irreps("128x0e + 128x1o"), + MLP_irreps=o3.Irreps("16x0e"), + gate=torch.nn.functional.silu, + atomic_energies=atomic_energies, + avg_num_neighbors=3, + atomic_numbers=table.zs, + correlation=3, + radial_type="bessel", + atomic_inter_scale=0.1, + atomic_inter_shift=0.0, + ) + model = modules.ScaleShiftMACE(**model_config) + calc_foundation = mace_mp(model="medium", device="cpu", default_dtype="float64") + model_loaded = load_foundations_elements( + model, + calc_foundation.models[0], + table=table, + load_readout=True, + use_shift=False, + max_L=1, + ) + atomic_data = data.AtomicData.from_config(config, z_table=table, cutoff=6.0) + atomic_data2 = data.AtomicData.from_config( + config_rotated, z_table=table, cutoff=6.0 + ) + + data_loader = torch_geometric.dataloader.DataLoader( + dataset=[atomic_data, atomic_data2], + batch_size=2, + shuffle=True, + drop_last=False, + ) + batch = next(iter(data_loader)) + forces_loaded = model_loaded(batch.to_dict())["forces"] + forces = model(batch.to_dict())["forces"] + assert torch.allclose(forces, forces_loaded) + + +def test_multi_reference(): + config_multi = data.Configuration( + atomic_numbers=molecule("H2COH").numbers, + positions=molecule("H2COH").positions, + properties={ + "forces": molecule("H2COH").positions, + "energy": -1.5, + "charges": molecule("H2COH").numbers, + "dipole": np.array([-1.5, 1.5, 2.0]), + }, + property_weights={ + "forces": 1.0, + "energy": 1.0, + "charges": 1.0, + "dipole": 1.0, + }, + head="MP2", + ) + table_multi = tools.AtomicNumberTable([1, 6, 8]) + atomic_energies_multi = np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], dtype=float) + table = tools.AtomicNumberTable([1, 6, 8]) + + + # Create MACE model + model_config = dict( + r_max=6, + num_bessel=10, + num_polynomial_cutoff=5, + max_ell=3, + interaction_cls=modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + interaction_cls_first=modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + num_interactions=2, + num_elements=3, + hidden_irreps=o3.Irreps("128x0e + 128x1o"), + MLP_irreps=o3.Irreps("16x0e"), + gate=torch.nn.functional.silu, + atomic_energies=atomic_energies_multi, + avg_num_neighbors=61, + atomic_numbers=table.zs, + correlation=3, + radial_type="bessel", + atomic_inter_scale=[1.0, 1.0], + atomic_inter_shift=[0.0, 0.0], + heads=["MP2", "DFT"], + ) + model = modules.ScaleShiftMACE(**model_config) + calc_foundation = mace_mp(model="medium", device="cpu", default_dtype="float64") + model_loaded = load_foundations_elements( + model, + calc_foundation.models[0], + table=table, + load_readout=True, + use_shift=False, + max_L=1, + ) + atomic_data = data.AtomicData.from_config( + config_multi, z_table=table_multi, cutoff=6.0, heads=["MP2", "DFT"] + ) + data_loader = torch_geometric.dataloader.DataLoader( + dataset=[atomic_data, atomic_data], + batch_size=2, + shuffle=True, + drop_last=False, + ) + batch = next(iter(data_loader)) + forces_loaded = model_loaded(batch.to_dict())["forces"] + calc_foundation = mace_mp(model="medium", device="cpu", default_dtype="float64") + atoms = molecule("H2COH") + atoms.info["head"] = "MP2" + atoms.calc = calc_foundation + forces = atoms.get_forces() + assert np.allclose( + forces, forces_loaded.detach().numpy()[:5, :], atol=1e-5, rtol=1e-5 + ) + + +@pytest.mark.parametrize( + "calc", + [ + mace_mp(device="cpu", default_dtype="float64"), + mace_mp(model="small", device="cpu", default_dtype="float64"), + mace_mp(model="medium", device="cpu", default_dtype="float64"), + mace_mp(model="large", device="cpu", default_dtype="float64"), + mace_mp(model=MODEL_PATH, device="cpu", default_dtype="float64"), + mace_off(model="small", device="cpu", default_dtype="float64"), + mace_off(model="medium", device="cpu", default_dtype="float64"), + mace_off(model="large", device="cpu", default_dtype="float64"), + mace_off(model=MODEL_PATH, device="cpu", default_dtype="float64"), + ], +) +def test_compile_foundation(calc): + model = calc.models[0] + atoms = molecule("CH4") + atoms.positions += np.random.randn(*atoms.positions.shape) * 0.1 + batch = calc._atoms_to_batch(atoms) # pylint: disable=protected-access + output_1 = model(batch.to_dict()) + model_compiled = jit.compile(model) + output = model_compiled(batch.to_dict()) + for key in output_1.keys(): + if isinstance(output_1[key], torch.Tensor): + assert torch.allclose(output_1[key], output[key], atol=1e-5) + + +@pytest.mark.parametrize( + "model", + [ + mace_mp(model="small", device="cpu", default_dtype="float64").models[0], + mace_mp(model="medium", device="cpu", default_dtype="float64").models[0], + mace_mp(model="large", device="cpu", default_dtype="float64").models[0], + mace_mp(model=MODEL_PATH, device="cpu", default_dtype="float64").models[0], + mace_off(model="small", device="cpu", default_dtype="float64").models[0], + mace_off(model="medium", device="cpu", default_dtype="float64").models[0], + mace_off(model="large", device="cpu", default_dtype="float64").models[0], + mace_off(model=MODEL_PATH, device="cpu", default_dtype="float64").models[0], + ], +) +def test_extract_config(model): + assert isinstance(model, modules.ScaleShiftMACE) + config = data.Configuration( + atomic_numbers=molecule("H2COH").numbers, + positions=molecule("H2COH").positions, + properties={ + "forces": molecule("H2COH").positions, + "energy": -1.5, + "charges": molecule("H2COH").numbers, + "dipole": np.array([-1.5, 1.5, 2.0]), + }, + property_weights={ + "forces": 1.0, + "energy": 1.0, + "charges": 1.0, + "dipole": 1.0, + }, + ) + model_copy = modules.ScaleShiftMACE(**extract_config_mace_model(model)) + model_copy.load_state_dict(model.state_dict()) + z_table = AtomicNumberTable([int(z) for z in model.atomic_numbers]) + atomic_data = data.AtomicData.from_config(config, z_table=z_table, cutoff=6.0) + data_loader = torch_geometric.dataloader.DataLoader( + dataset=[atomic_data, atomic_data], + batch_size=2, + shuffle=True, + drop_last=False, + ) + batch = next(iter(data_loader)) + output = model(batch.to_dict()) + output_copy = model_copy(batch.to_dict()) + # assert all items of the output dicts are equal + for key in output.keys(): + if isinstance(output[key], torch.Tensor): + assert torch.allclose(output[key], output_copy[key], atol=1e-5) + + +def test_remove_pt_head(): + # Set up test data + torch.manual_seed(42) + atomic_energies_pt_head = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=float) + z_table = AtomicNumberTable([1, 8]) # H and O + + # Create multihead model + model_config = { + "r_max": 5.0, + "num_bessel": 8, + "num_polynomial_cutoff": 5, + "max_ell": 2, + "interaction_cls": modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + "interaction_cls_first": modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + "num_interactions": 2, + "num_elements": len(z_table), + "hidden_irreps": o3.Irreps("32x0e + 32x1o"), + "MLP_irreps": o3.Irreps("16x0e"), + "gate": torch.nn.functional.silu, + "atomic_energies": atomic_energies_pt_head, + "avg_num_neighbors": 8, + "atomic_numbers": z_table.zs, + "correlation": 3, + "heads": ["pt_head", "DFT"], + "atomic_inter_scale": [1.0, 1.0], + "atomic_inter_shift": [0.0, 0.1], + } + + model = modules.ScaleShiftMACE(**model_config) + + # Create test molecule + mol = molecule("H2O") + config_pt_head = data.Configuration( + atomic_numbers=mol.numbers, + positions=mol.positions, + properties={"energy": 1.0, "forces": np.random.randn(len(mol), 3)}, + property_weights={"forces": 1.0, "energy": 1.0}, + head="DFT", + ) + atomic_data = data.AtomicData.from_config( + config_pt_head, z_table=z_table, cutoff=5.0, heads=["pt_head", "DFT"] + ) + dataloader = torch_geometric.dataloader.DataLoader( + dataset=[atomic_data], batch_size=1, shuffle=False + ) + batch = next(iter(dataloader)) + # Test original mode + output_orig = model(batch.to_dict()) + + # Convert to single head model + new_model = remove_pt_head(model, head_to_keep="DFT") + + # Basic structure tests + assert len(new_model.heads) == 1 + assert new_model.heads[0] == "DFT" + assert new_model.atomic_energies_fn.atomic_energies.shape[0] == 1 + assert len(torch.atleast_1d(new_model.scale_shift.scale)) == 1 + assert len(torch.atleast_1d(new_model.scale_shift.shift)) == 1 + + # Test output consistency + atomic_data = data.AtomicData.from_config( + config_pt_head, z_table=z_table, cutoff=5.0, heads=["DFT"] + ) + dataloader = torch_geometric.dataloader.DataLoader( + dataset=[atomic_data], batch_size=1, shuffle=False + ) + batch = next(iter(dataloader)) + output_new = new_model(batch.to_dict()) + torch.testing.assert_close( + output_orig["energy"], output_new["energy"], rtol=1e-5, atol=1e-5 + ) + torch.testing.assert_close( + output_orig["forces"], output_new["forces"], rtol=1e-5, atol=1e-5 + ) + + +def test_remove_pt_head_multihead(): + # Set up test data + torch.manual_seed(42) + atomic_energies_pt_head = np.array( + [ + [1.0, 2.0], # H energies for each head + [3.0, 4.0], # O energies for each head + ] + * 2 + ) + z_table = AtomicNumberTable([1, 8]) # H and O + + # Create multihead model + model_config = { + "r_max": 5.0, + "num_bessel": 8, + "num_polynomial_cutoff": 5, + "max_ell": 2, + "interaction_cls": modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + "interaction_cls_first": modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + "num_interactions": 2, + "num_elements": len(z_table), + "hidden_irreps": o3.Irreps("32x0e + 32x1o"), + "MLP_irreps": o3.Irreps("16x0e"), + "gate": torch.nn.functional.silu, + "atomic_energies": atomic_energies_pt_head, + "avg_num_neighbors": 8, + "atomic_numbers": z_table.zs, + "correlation": 3, + "heads": ["pt_head", "DFT", "MP2", "CCSD"], + "atomic_inter_scale": [1.0, 1.0, 1.0, 1.0], + "atomic_inter_shift": [0.0, 0.1, 0.2, 0.3], + } + + model = modules.ScaleShiftMACE(**model_config) + + # Create test configurations for each head + mol = molecule("H2O") + configs = {} + atomic_datas = {} + dataloaders = {} + original_outputs = {} + + # First get outputs from original model for each head + for head in model.heads: + config_pt_head = data.Configuration( + atomic_numbers=mol.numbers, + positions=mol.positions, + properties={"energy": 1.0, "forces": np.random.randn(len(mol), 3)}, + property_weights={"forces": 1.0, "energy": 1.0}, + head=head, + ) + configs[head] = config_pt_head + + atomic_data = data.AtomicData.from_config( + config_pt_head, z_table=z_table, cutoff=5.0, heads=model.heads + ) + atomic_datas[head] = atomic_data + + dataloader = torch_geometric.dataloader.DataLoader( + dataset=[atomic_data], batch_size=1, shuffle=False + ) + dataloaders[head] = dataloader + + batch = next(iter(dataloader)) + output = model(batch.to_dict()) + original_outputs[head] = output + + # Now test each head separately + for i, head in enumerate(model.heads): + # Convert to single head model + new_model = remove_pt_head(model, head_to_keep=head) + + # Basic structure tests + assert len(new_model.heads) == 1, f"Failed for head {head}" + assert new_model.heads[0] == head, f"Failed for head {head}" + assert ( + new_model.atomic_energies_fn.atomic_energies.shape[0] == 1 + ), f"Failed for head {head}" + assert ( + len(torch.atleast_1d(new_model.scale_shift.scale)) == 1 + ), f"Failed for head {head}" + assert ( + len(torch.atleast_1d(new_model.scale_shift.shift)) == 1 + ), f"Failed for head {head}" + + # Verify scale and shift values + assert torch.allclose( + new_model.scale_shift.scale, model.scale_shift.scale[i : i + 1] + ), f"Failed for head {head}" + assert torch.allclose( + new_model.scale_shift.shift, model.scale_shift.shift[i : i + 1] + ), f"Failed for head {head}" + + # Test output consistency + single_head_data = data.AtomicData.from_config( + configs[head], z_table=z_table, cutoff=5.0, heads=[head] + ) + single_head_loader = torch_geometric.dataloader.DataLoader( + dataset=[single_head_data], batch_size=1, shuffle=False + ) + batch = next(iter(single_head_loader)) + new_output = new_model(batch.to_dict()) + + # Compare outputs + print( + original_outputs[head]["energy"], + new_output["energy"], + ) + torch.testing.assert_close( + original_outputs[head]["energy"], + new_output["energy"], + rtol=1e-5, + atol=1e-5, + msg=f"Energy mismatch for head {head}", + ) + torch.testing.assert_close( + original_outputs[head]["forces"], + new_output["forces"], + rtol=1e-5, + atol=1e-5, + msg=f"Forces mismatch for head {head}", + ) + + # Test error cases + with pytest.raises(ValueError, match="Head non_existent not found in model"): + remove_pt_head(model, head_to_keep="non_existent") + + # Test default behavior (first non-PT head) + default_model = remove_pt_head(model) + assert default_model.heads[0] == "DFT" + + # Additional test: check if each model's computation graph is independent + models = {head: remove_pt_head(model, head_to_keep=head) for head in model.heads} + results = {} + + for head, head_model in models.items(): + single_head_data = data.AtomicData.from_config( + configs[head], z_table=z_table, cutoff=5.0, heads=[head] + ) + single_head_loader = torch_geometric.dataloader.DataLoader( + dataset=[single_head_data], batch_size=1, shuffle=False + ) + batch = next(iter(single_head_loader)) + results[head] = head_model(batch.to_dict()) + + # Verify each model produces different outputs + energies = torch.stack([results[head]["energy"] for head in model.heads]) + assert not torch.allclose( + energies[0], energies[1], rtol=1e-3 + ), "Different heads should produce different outputs" diff --git a/mace-bench/3rdparty/mace/tests/test_hessian.py b/mace-bench/3rdparty/mace/tests/test_hessian.py index 5e23e82..5345733 100644 --- a/mace-bench/3rdparty/mace/tests/test_hessian.py +++ b/mace-bench/3rdparty/mace/tests/test_hessian.py @@ -1,54 +1,54 @@ -import numpy as np -import pytest -from ase.build import fcc111 - -from mace.calculators import mace_mp - - -@pytest.fixture(name="setup_calculator_") -def setup_calculator(): - calc = mace_mp( - model="medium", dispersion=False, default_dtype="float64", device="cpu" - ) - return calc - - -@pytest.fixture(name="setup_structure_") -def setup_structure(setup_calculator_): - initial = fcc111("Pt", size=(4, 4, 1), vacuum=10.0, orthogonal=True) - initial.calc = setup_calculator_ - return initial - - -def test_potential_energy_and_hessian(setup_structure_): - initial = setup_structure_ - h_autograd = initial.calc.get_hessian(atoms=initial) - assert h_autograd.shape == (len(initial) * 3, len(initial), 3) - - -def test_finite_difference_hessian(setup_structure_): - initial = setup_structure_ - indicies = list(range(len(initial))) - delta, ndim = 1e-4, 3 - hessian = np.zeros((len(indicies) * ndim, len(indicies) * ndim)) - atoms_h = initial.copy() - for i, index in enumerate(indicies): - for j in range(ndim): - atoms_i = atoms_h.copy() - atoms_i.positions[index, j] += delta - atoms_i.calc = initial.calc - forces_i = atoms_i.get_forces() - - atoms_j = atoms_h.copy() - atoms_j.positions[index, j] -= delta - atoms_j.calc = initial.calc - forces_j = atoms_j.get_forces() - - hessian[:, i * ndim + j] = -(forces_i - forces_j)[indicies].flatten() / ( - 2 * delta - ) - - hessian = hessian.reshape((-1, len(initial), 3)) - h_autograd = initial.calc.get_hessian(atoms=initial) - is_close = np.allclose(h_autograd, hessian, atol=1e-6) - assert is_close +import numpy as np +import pytest +from ase.build import fcc111 + +from mace.calculators import mace_mp + + +@pytest.fixture(name="setup_calculator_") +def setup_calculator(): + calc = mace_mp( + model="medium", dispersion=False, default_dtype="float64", device="cpu" + ) + return calc + + +@pytest.fixture(name="setup_structure_") +def setup_structure(setup_calculator_): + initial = fcc111("Pt", size=(4, 4, 1), vacuum=10.0, orthogonal=True) + initial.calc = setup_calculator_ + return initial + + +def test_potential_energy_and_hessian(setup_structure_): + initial = setup_structure_ + h_autograd = initial.calc.get_hessian(atoms=initial) + assert h_autograd.shape == (len(initial) * 3, len(initial), 3) + + +def test_finite_difference_hessian(setup_structure_): + initial = setup_structure_ + indicies = list(range(len(initial))) + delta, ndim = 1e-4, 3 + hessian = np.zeros((len(indicies) * ndim, len(indicies) * ndim)) + atoms_h = initial.copy() + for i, index in enumerate(indicies): + for j in range(ndim): + atoms_i = atoms_h.copy() + atoms_i.positions[index, j] += delta + atoms_i.calc = initial.calc + forces_i = atoms_i.get_forces() + + atoms_j = atoms_h.copy() + atoms_j.positions[index, j] -= delta + atoms_j.calc = initial.calc + forces_j = atoms_j.get_forces() + + hessian[:, i * ndim + j] = -(forces_i - forces_j)[indicies].flatten() / ( + 2 * delta + ) + + hessian = hessian.reshape((-1, len(initial), 3)) + h_autograd = initial.calc.get_hessian(atoms=initial) + is_close = np.allclose(h_autograd, hessian, atol=1e-6) + assert is_close diff --git a/mace-bench/3rdparty/mace/tests/test_lmdb_database.py b/mace-bench/3rdparty/mace/tests/test_lmdb_database.py index 197661a..0c7043a 100644 --- a/mace-bench/3rdparty/mace/tests/test_lmdb_database.py +++ b/mace-bench/3rdparty/mace/tests/test_lmdb_database.py @@ -1,134 +1,134 @@ -import os -import tempfile - -import numpy as np -import torch -from ase.build import molecule -from ase.calculators.singlepoint import SinglePointCalculator - -from mace.data.lmdb_dataset import LMDBDataset -from mace.tools import AtomicNumberTable, torch_geometric -from mace.tools.fairchem_dataset.lmdb_dataset_tools import LMDBDatabase - - -def test_lmdb_dataset(): - """Test the LMDBDataset by creating a fake database and verifying batch creation.""" - # Set default dtype to match typical MACE usage - torch.set_default_dtype(torch.float64) - - # Set random seed for reproducibility - np.random.seed(42) - - # Create temporary directories for the databases - with tempfile.TemporaryDirectory() as tmpdir: - # Create 3 folders for databases - db_paths = [] - for i in range(3): - folder_path = os.path.join(tmpdir, f"folder_{i}") - os.makedirs(folder_path, exist_ok=True) - - # Create LMDB database files in each folder (2 per folder) - for j in range(2): - db_path = os.path.join(folder_path, f"data_{j}.aselmdb") - db = LMDBDatabase(db_path, readonly=False) - - # Add 2 configurations to each database - for _ in range(2): - # Create a water molecule using ASE's build functionality - atoms = molecule("H2O") - - # Apply small random displacements to the positions - displacement = np.random.rand(*atoms.positions.shape) * 0.1 - atoms.positions += displacement - - # Set cell and PBC - atoms.set_cell(np.eye(3) * 5.0) - atoms.set_pbc(True) - - # Add random energy, forces, and stress - energy = np.random.uniform( - -15.0, -5.0 - ) # Random energy between -15 and -5 eV - forces = ( - np.random.randn(*atoms.positions.shape) * 0.5 - ) # Random forces - stress = np.random.randn(6) * 0.2 # Random stress in Voigt notation - - # Add calculator to atoms with results - calc = SinglePointCalculator( - atoms, energy=energy, forces=forces, stress=stress - ) - atoms.calc = calc - - # Store in database - db.write(atoms) - - db.close() - - # Add folder path to our list - db_paths.append(folder_path) - - # Create the dataset using paths joined with colons - paths_str = ":".join(db_paths) - z_table = AtomicNumberTable([1, 8]) # H and O - dataset = LMDBDataset(file_path=paths_str, r_max=5.0, z_table=z_table) - - # Check dataset size (3 folders * 2 files * 2 configs = 12 entries) - assert len(dataset) == 12 - - # Test retrieving a single item - item = dataset[0] - print(item) - assert item.positions.shape == (3, 3) # 3 atoms, 3 coordinates - assert hasattr(item, "energy") - assert hasattr(item, "forces") - assert hasattr(item, "stress") - - # Create a dataloader - dataloader = torch_geometric.dataloader.DataLoader( - dataset=dataset, batch_size=4, shuffle=False, drop_last=False - ) - - # Get a batch and validate it - batch = next(iter(dataloader)) - - # Verify batch properties - should have 12 atoms (4 configs * 3 atoms per water) - assert batch.positions.shape == (12, 3) # 12 atoms, 3 coordinates - assert batch.energy.shape[0] == 4 # 4 energies (one per config) - assert batch.forces.shape == (12, 3) # Forces for each atom - print(batch.stress.shape) - assert batch.stress.shape == (4, 3, 3) # Stress for each config - - # Check batch has required attributes for MACE model processing - assert hasattr(batch, "batch") # Batch indices - assert batch.batch.shape[0] == 12 # One index per atom - assert hasattr(batch, "ptr") # Pointer for batch processing - assert batch.ptr.shape[0] == 5 # One pointer per config + 1 - - # Check that batch indices are correctly assigned - # First 3 atoms should be from config 0, next 3 from config 1, etc. - expected_batch = torch.tensor([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]) - assert torch.all(batch.batch == expected_batch) - - # Check ptr correctly points to start of each configuration - assert batch.ptr.tolist() == [0, 3, 6, 9, 12] - - # Create a batch dictionary that can be passed to a MACE model - batch_dict = batch.to_dict() - assert "positions" in batch_dict - assert "energy" in batch_dict - assert "forces" in batch_dict - assert "stress" in batch_dict - assert "batch" in batch_dict - assert "ptr" in batch_dict - - # Verify additional properties required by MACE - assert hasattr(batch, "edge_index") # Connectivity information - assert hasattr(batch, "shifts") # For periodic boundary conditions - assert hasattr(batch, "cell") # Unit cell information - - # Test that a full batch can be processed (without errors) - all_batches = list(dataloader) - assert ( - len(all_batches) == 3 - ) # Should have 3 batches (12 configs with batch size 4) +import os +import tempfile + +import numpy as np +import torch +from ase.build import molecule +from ase.calculators.singlepoint import SinglePointCalculator + +from mace.data.lmdb_dataset import LMDBDataset +from mace.tools import AtomicNumberTable, torch_geometric +from mace.tools.fairchem_dataset.lmdb_dataset_tools import LMDBDatabase + + +def test_lmdb_dataset(): + """Test the LMDBDataset by creating a fake database and verifying batch creation.""" + # Set default dtype to match typical MACE usage + torch.set_default_dtype(torch.float64) + + # Set random seed for reproducibility + np.random.seed(42) + + # Create temporary directories for the databases + with tempfile.TemporaryDirectory() as tmpdir: + # Create 3 folders for databases + db_paths = [] + for i in range(3): + folder_path = os.path.join(tmpdir, f"folder_{i}") + os.makedirs(folder_path, exist_ok=True) + + # Create LMDB database files in each folder (2 per folder) + for j in range(2): + db_path = os.path.join(folder_path, f"data_{j}.aselmdb") + db = LMDBDatabase(db_path, readonly=False) + + # Add 2 configurations to each database + for _ in range(2): + # Create a water molecule using ASE's build functionality + atoms = molecule("H2O") + + # Apply small random displacements to the positions + displacement = np.random.rand(*atoms.positions.shape) * 0.1 + atoms.positions += displacement + + # Set cell and PBC + atoms.set_cell(np.eye(3) * 5.0) + atoms.set_pbc(True) + + # Add random energy, forces, and stress + energy = np.random.uniform( + -15.0, -5.0 + ) # Random energy between -15 and -5 eV + forces = ( + np.random.randn(*atoms.positions.shape) * 0.5 + ) # Random forces + stress = np.random.randn(6) * 0.2 # Random stress in Voigt notation + + # Add calculator to atoms with results + calc = SinglePointCalculator( + atoms, energy=energy, forces=forces, stress=stress + ) + atoms.calc = calc + + # Store in database + db.write(atoms) + + db.close() + + # Add folder path to our list + db_paths.append(folder_path) + + # Create the dataset using paths joined with colons + paths_str = ":".join(db_paths) + z_table = AtomicNumberTable([1, 8]) # H and O + dataset = LMDBDataset(file_path=paths_str, r_max=5.0, z_table=z_table) + + # Check dataset size (3 folders * 2 files * 2 configs = 12 entries) + assert len(dataset) == 12 + + # Test retrieving a single item + item = dataset[0] + print(item) + assert item.positions.shape == (3, 3) # 3 atoms, 3 coordinates + assert hasattr(item, "energy") + assert hasattr(item, "forces") + assert hasattr(item, "stress") + + # Create a dataloader + dataloader = torch_geometric.dataloader.DataLoader( + dataset=dataset, batch_size=4, shuffle=False, drop_last=False + ) + + # Get a batch and validate it + batch = next(iter(dataloader)) + + # Verify batch properties - should have 12 atoms (4 configs * 3 atoms per water) + assert batch.positions.shape == (12, 3) # 12 atoms, 3 coordinates + assert batch.energy.shape[0] == 4 # 4 energies (one per config) + assert batch.forces.shape == (12, 3) # Forces for each atom + print(batch.stress.shape) + assert batch.stress.shape == (4, 3, 3) # Stress for each config + + # Check batch has required attributes for MACE model processing + assert hasattr(batch, "batch") # Batch indices + assert batch.batch.shape[0] == 12 # One index per atom + assert hasattr(batch, "ptr") # Pointer for batch processing + assert batch.ptr.shape[0] == 5 # One pointer per config + 1 + + # Check that batch indices are correctly assigned + # First 3 atoms should be from config 0, next 3 from config 1, etc. + expected_batch = torch.tensor([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]) + assert torch.all(batch.batch == expected_batch) + + # Check ptr correctly points to start of each configuration + assert batch.ptr.tolist() == [0, 3, 6, 9, 12] + + # Create a batch dictionary that can be passed to a MACE model + batch_dict = batch.to_dict() + assert "positions" in batch_dict + assert "energy" in batch_dict + assert "forces" in batch_dict + assert "stress" in batch_dict + assert "batch" in batch_dict + assert "ptr" in batch_dict + + # Verify additional properties required by MACE + assert hasattr(batch, "edge_index") # Connectivity information + assert hasattr(batch, "shifts") # For periodic boundary conditions + assert hasattr(batch, "cell") # Unit cell information + + # Test that a full batch can be processed (without errors) + all_batches = list(dataloader) + assert ( + len(all_batches) == 3 + ) # Should have 3 batches (12 configs with batch size 4) diff --git a/mace-bench/3rdparty/mace/tests/test_models.py b/mace-bench/3rdparty/mace/tests/test_models.py index 9c1d2a0..40ff48c 100644 --- a/mace-bench/3rdparty/mace/tests/test_models.py +++ b/mace-bench/3rdparty/mace/tests/test_models.py @@ -1,374 +1,374 @@ -import numpy as np -import torch -import torch.nn.functional -from ase import build -from e3nn import o3 -from e3nn.util import jit -from scipy.spatial.transform import Rotation as R - -from mace import data, modules, tools -from mace.tools import torch_geometric - -torch.set_default_dtype(torch.float64) -config = data.Configuration( - atomic_numbers=np.array([8, 1, 1]), - positions=np.array( - [ - [0.0, -2.0, 0.0], - [1.0, 0.0, 0.0], - [0.0, 1.0, 0.0], - ] - ), - properties={ - "forces": np.array( - [ - [0.0, -1.3, 0.0], - [1.0, 0.2, 0.0], - [0.0, 1.1, 0.3], - ] - ), - "energy": -1.5, - "charges": np.array([-2.0, 1.0, 1.0]), - "dipole": np.array([-1.5, 1.5, 2.0]), - }, - property_weights={ - "forces": 1.0, - "energy": 1.0, - "charges": 1.0, - "dipole": 1.0, - }, -) -# Created the rotated environment -rot = R.from_euler("z", 60, degrees=True).as_matrix() -positions_rotated = np.array(rot @ config.positions.T).T -config_rotated = data.Configuration( - atomic_numbers=np.array([8, 1, 1]), - positions=positions_rotated, - properties={ - "forces": np.array( - [ - [0.0, -1.3, 0.0], - [1.0, 0.2, 0.0], - [0.0, 1.1, 0.3], - ] - ), - "energy": -1.5, - "charges": np.array([-2.0, 1.0, 1.0]), - "dipole": np.array([-1.5, 1.5, 2.0]), - }, - property_weights={ - "forces": 1.0, - "energy": 1.0, - "charges": 1.0, - "dipole": 1.0, - }, -) -table = tools.AtomicNumberTable([1, 8]) -atomic_energies = np.array([1.0, 3.0], dtype=float) - - -def test_mace(): - # Create MACE model - model_config = dict( - r_max=5, - num_bessel=8, - num_polynomial_cutoff=6, - max_ell=2, - interaction_cls=modules.interaction_classes[ - "RealAgnosticResidualInteractionBlock" - ], - interaction_cls_first=modules.interaction_classes[ - "RealAgnosticResidualInteractionBlock" - ], - num_interactions=5, - num_elements=2, - hidden_irreps=o3.Irreps("32x0e + 32x1o"), - MLP_irreps=o3.Irreps("16x0e"), - gate=torch.nn.functional.silu, - atomic_energies=atomic_energies, - avg_num_neighbors=8, - atomic_numbers=table.zs, - correlation=3, - radial_type="bessel", - ) - model = modules.MACE(**model_config) - model_compiled = jit.compile(model) - - atomic_data = data.AtomicData.from_config(config, z_table=table, cutoff=3.0) - atomic_data2 = data.AtomicData.from_config( - config_rotated, z_table=table, cutoff=3.0 - ) - - data_loader = torch_geometric.dataloader.DataLoader( - dataset=[atomic_data, atomic_data2], - batch_size=2, - shuffle=True, - drop_last=False, - ) - batch = next(iter(data_loader)) - output1 = model(batch.to_dict(), training=True) - output2 = model_compiled(batch.to_dict(), training=True) - assert torch.allclose(output1["energy"][0], output2["energy"][0]) - assert torch.allclose(output2["energy"][0], output2["energy"][1]) - - -def test_dipole_mace(): - # create dipole MACE model - model_config = dict( - r_max=5, - num_bessel=8, - num_polynomial_cutoff=5, - max_ell=2, - interaction_cls=modules.interaction_classes[ - "RealAgnosticResidualInteractionBlock" - ], - interaction_cls_first=modules.interaction_classes[ - "RealAgnosticResidualInteractionBlock" - ], - num_interactions=2, - num_elements=2, - hidden_irreps=o3.Irreps("16x0e + 16x1o + 16x2e"), - MLP_irreps=o3.Irreps("16x0e"), - gate=torch.nn.functional.silu, - atomic_energies=None, - avg_num_neighbors=3, - atomic_numbers=table.zs, - correlation=3, - radial_type="gaussian", - ) - model = modules.AtomicDipolesMACE(**model_config) - - atomic_data = data.AtomicData.from_config(config, z_table=table, cutoff=3.0) - atomic_data2 = data.AtomicData.from_config( - config_rotated, z_table=table, cutoff=3.0 - ) - - data_loader = torch_geometric.dataloader.DataLoader( - dataset=[atomic_data, atomic_data2], - batch_size=2, - shuffle=False, - drop_last=False, - ) - batch = next(iter(data_loader)) - output = model( - batch, - training=True, - ) - # sanity check of dipoles being the right shape - assert output["dipole"][0].unsqueeze(0).shape == atomic_data.dipole.shape - # test equivariance of output dipoles - assert np.allclose( - np.array(rot @ output["dipole"][0].detach().numpy()), - output["dipole"][1].detach().numpy(), - ) - - -def test_energy_dipole_mace(): - # create dipole MACE model - model_config = dict( - r_max=5, - num_bessel=8, - num_polynomial_cutoff=5, - max_ell=2, - interaction_cls=modules.interaction_classes[ - "RealAgnosticResidualInteractionBlock" - ], - interaction_cls_first=modules.interaction_classes[ - "RealAgnosticResidualInteractionBlock" - ], - num_interactions=2, - num_elements=2, - hidden_irreps=o3.Irreps("16x0e + 16x1o + 16x2e"), - MLP_irreps=o3.Irreps("16x0e"), - gate=torch.nn.functional.silu, - atomic_energies=atomic_energies, - avg_num_neighbors=3, - atomic_numbers=table.zs, - correlation=3, - ) - model = modules.EnergyDipolesMACE(**model_config) - - atomic_data = data.AtomicData.from_config(config, z_table=table, cutoff=3.0) - atomic_data2 = data.AtomicData.from_config( - config_rotated, z_table=table, cutoff=3.0 - ) - - data_loader = torch_geometric.dataloader.DataLoader( - dataset=[atomic_data, atomic_data2], - batch_size=2, - shuffle=False, - drop_last=False, - ) - batch = next(iter(data_loader)) - output = model( - batch, - training=True, - ) - # sanity check of dipoles being the right shape - assert output["dipole"][0].unsqueeze(0).shape == atomic_data.dipole.shape - # test energy is invariant - assert torch.allclose(output["energy"][0], output["energy"][1]) - # test equivariance of output dipoles - assert np.allclose( - np.array(rot @ output["dipole"][0].detach().numpy()), - output["dipole"][1].detach().numpy(), - ) - - -def test_mace_multi_reference(): - atomic_energies_multi = np.array([[1.0, 3.0], [0.0, 0.0]], dtype=float) - model_config = dict( - r_max=5, - num_bessel=8, - num_polynomial_cutoff=6, - max_ell=3, - interaction_cls=modules.interaction_classes[ - "RealAgnosticResidualInteractionBlock" - ], - interaction_cls_first=modules.interaction_classes[ - "RealAgnosticResidualInteractionBlock" - ], - num_interactions=2, - num_elements=2, - hidden_irreps=o3.Irreps("96x0e + 96x1o"), - MLP_irreps=o3.Irreps("16x0e"), - gate=torch.nn.functional.silu, - atomic_energies=atomic_energies_multi, - avg_num_neighbors=8, - atomic_numbers=table.zs, - distance_transform=True, - pair_repulsion=True, - correlation=3, - heads=["Default", "dft"], - # radial_type="chebyshev", - atomic_inter_scale=[1.0, 1.0], - atomic_inter_shift=[0.0, 0.1], - ) - model = modules.ScaleShiftMACE(**model_config) - model_compiled = jit.compile(model) - config.head = "Default" - config_rotated.head = "dft" - atomic_data = data.AtomicData.from_config( - config, z_table=table, cutoff=3.0, heads=["Default", "dft"] - ) - atomic_data2 = data.AtomicData.from_config( - config_rotated, z_table=table, cutoff=3.0, heads=["Default", "dft"] - ) - - data_loader = torch_geometric.dataloader.DataLoader( - dataset=[atomic_data, atomic_data2], - batch_size=2, - shuffle=True, - drop_last=False, - ) - batch = next(iter(data_loader)) - output1 = model(batch.to_dict(), training=True) - output2 = model_compiled(batch.to_dict(), training=True) - assert torch.allclose(output1["energy"][0], output2["energy"][0]) - assert output2["energy"].shape[0] == 2 - - -def test_atomic_virials_stresses(): - """ - Test that atomic virials and stresses sum to the total virials and stress. - """ - # Set default dtype for reproducibility - torch.set_default_dtype(torch.float64) - - # Create a periodic cell with ASE - atoms = build.bulk("Si", "diamond", a=5.43) - # Apply strain to ensure non-zero stress - strain_tensor = np.eye(3) * 1.02 # 2% strain - atoms.set_cell(np.dot(atoms.get_cell(), strain_tensor), scale_atoms=True) - - # Add forces and energy for completeness - atoms.arrays["REF_forces"] = np.random.normal(0, 0.1, size=atoms.positions.shape) - atoms.info["REF_energy"] = np.random.normal(0, 1) - atoms.info["REF_stress"] = np.random.normal(0, 0.1, size=6) - - # Setup MACE model configuration - stress_z_table = tools.AtomicNumberTable([14]) # Silicon - stress_atomic_energies = np.array([0.0]) - - model_config = dict( - r_max=5.0, - num_bessel=8, - num_polynomial_cutoff=6, - max_ell=2, - interaction_cls=modules.interaction_classes[ - "RealAgnosticResidualInteractionBlock" - ], - interaction_cls_first=modules.interaction_classes[ - "RealAgnosticResidualInteractionBlock" - ], - num_interactions=3, - num_elements=1, - hidden_irreps=o3.Irreps("32x0e + 32x1o"), - MLP_irreps=o3.Irreps("16x0e"), - gate=torch.nn.functional.silu, - atomic_energies=stress_atomic_energies, - avg_num_neighbors=4.0, - atomic_numbers=table.zs, - correlation=3, - atomic_inter_scale=1.0, - atomic_inter_shift=0.0, - ) - - # Create the model - model = modules.ScaleShiftMACE(**model_config) - - # Create atomic data - atomic_data = data.AtomicData.from_config( - data.config_from_atoms( - atoms, key_specification=data.KeySpecification.from_defaults() - ), - z_table=stress_z_table, - cutoff=5.0, - ) - - data_loader = torch_geometric.dataloader.DataLoader( - dataset=[atomic_data], - batch_size=2, - shuffle=True, - drop_last=False, - ) - batch = next(iter(data_loader)) - batch_dict = batch.to_dict() - - # Run the model with compute_atomic_stresses=True - output = model( - batch_dict, - compute_force=True, - compute_virials=True, - compute_stress=True, - compute_atomic_stresses=True, - ) - - # Get total virials/stress and atomic virials/stresses - total_virials = output["virials"] - atomic_virials = output["atomic_virials"] - total_stress = output["stress"] - atomic_stresses = output["atomic_stresses"] - - # Test that atomic values are not None - assert atomic_virials is not None, "Atomic virials were not computed" - assert atomic_stresses is not None, "Atomic stresses were not computed" - - # Test shape of atomic values - assert atomic_virials.shape[0] == len(atoms), "Wrong shape for atomic virials" - assert atomic_virials.shape[1:] == (3, 3), "Atomic virials should be 3x3 matrices" - assert atomic_stresses.shape[0] == len(atoms), "Wrong shape for atomic stresses" - assert atomic_stresses.shape[1:] == (3, 3), "Atomic stresses should be 3x3 matrices" - - # Compute sum of atomic values - summed_atomic_virials = torch.sum(atomic_virials, dim=0) - summed_atomic_stresses = torch.sum(atomic_stresses, dim=0) - - # Test that sums match total values - assert torch.allclose( - summed_atomic_virials, total_virials.squeeze(0), atol=1e-6 - ), f"Sum of atomic virials {summed_atomic_virials} does not match total virials {total_virials.squeeze(0)}" - - assert torch.allclose( - summed_atomic_stresses, total_stress.squeeze(0), atol=1e-6 - ), f"Sum of atomic stresses (normalized by volume) {summed_atomic_stresses} does not match total stress {total_stress.squeeze(0)}" +import numpy as np +import torch +import torch.nn.functional +from ase import build +from e3nn import o3 +from e3nn.util import jit +from scipy.spatial.transform import Rotation as R + +from mace import data, modules, tools +from mace.tools import torch_geometric + +torch.set_default_dtype(torch.float64) +config = data.Configuration( + atomic_numbers=np.array([8, 1, 1]), + positions=np.array( + [ + [0.0, -2.0, 0.0], + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + ] + ), + properties={ + "forces": np.array( + [ + [0.0, -1.3, 0.0], + [1.0, 0.2, 0.0], + [0.0, 1.1, 0.3], + ] + ), + "energy": -1.5, + "charges": np.array([-2.0, 1.0, 1.0]), + "dipole": np.array([-1.5, 1.5, 2.0]), + }, + property_weights={ + "forces": 1.0, + "energy": 1.0, + "charges": 1.0, + "dipole": 1.0, + }, +) +# Created the rotated environment +rot = R.from_euler("z", 60, degrees=True).as_matrix() +positions_rotated = np.array(rot @ config.positions.T).T +config_rotated = data.Configuration( + atomic_numbers=np.array([8, 1, 1]), + positions=positions_rotated, + properties={ + "forces": np.array( + [ + [0.0, -1.3, 0.0], + [1.0, 0.2, 0.0], + [0.0, 1.1, 0.3], + ] + ), + "energy": -1.5, + "charges": np.array([-2.0, 1.0, 1.0]), + "dipole": np.array([-1.5, 1.5, 2.0]), + }, + property_weights={ + "forces": 1.0, + "energy": 1.0, + "charges": 1.0, + "dipole": 1.0, + }, +) +table = tools.AtomicNumberTable([1, 8]) +atomic_energies = np.array([1.0, 3.0], dtype=float) + + +def test_mace(): + # Create MACE model + model_config = dict( + r_max=5, + num_bessel=8, + num_polynomial_cutoff=6, + max_ell=2, + interaction_cls=modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + interaction_cls_first=modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + num_interactions=5, + num_elements=2, + hidden_irreps=o3.Irreps("32x0e + 32x1o"), + MLP_irreps=o3.Irreps("16x0e"), + gate=torch.nn.functional.silu, + atomic_energies=atomic_energies, + avg_num_neighbors=8, + atomic_numbers=table.zs, + correlation=3, + radial_type="bessel", + ) + model = modules.MACE(**model_config) + model_compiled = jit.compile(model) + + atomic_data = data.AtomicData.from_config(config, z_table=table, cutoff=3.0) + atomic_data2 = data.AtomicData.from_config( + config_rotated, z_table=table, cutoff=3.0 + ) + + data_loader = torch_geometric.dataloader.DataLoader( + dataset=[atomic_data, atomic_data2], + batch_size=2, + shuffle=True, + drop_last=False, + ) + batch = next(iter(data_loader)) + output1 = model(batch.to_dict(), training=True) + output2 = model_compiled(batch.to_dict(), training=True) + assert torch.allclose(output1["energy"][0], output2["energy"][0]) + assert torch.allclose(output2["energy"][0], output2["energy"][1]) + + +def test_dipole_mace(): + # create dipole MACE model + model_config = dict( + r_max=5, + num_bessel=8, + num_polynomial_cutoff=5, + max_ell=2, + interaction_cls=modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + interaction_cls_first=modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + num_interactions=2, + num_elements=2, + hidden_irreps=o3.Irreps("16x0e + 16x1o + 16x2e"), + MLP_irreps=o3.Irreps("16x0e"), + gate=torch.nn.functional.silu, + atomic_energies=None, + avg_num_neighbors=3, + atomic_numbers=table.zs, + correlation=3, + radial_type="gaussian", + ) + model = modules.AtomicDipolesMACE(**model_config) + + atomic_data = data.AtomicData.from_config(config, z_table=table, cutoff=3.0) + atomic_data2 = data.AtomicData.from_config( + config_rotated, z_table=table, cutoff=3.0 + ) + + data_loader = torch_geometric.dataloader.DataLoader( + dataset=[atomic_data, atomic_data2], + batch_size=2, + shuffle=False, + drop_last=False, + ) + batch = next(iter(data_loader)) + output = model( + batch, + training=True, + ) + # sanity check of dipoles being the right shape + assert output["dipole"][0].unsqueeze(0).shape == atomic_data.dipole.shape + # test equivariance of output dipoles + assert np.allclose( + np.array(rot @ output["dipole"][0].detach().numpy()), + output["dipole"][1].detach().numpy(), + ) + + +def test_energy_dipole_mace(): + # create dipole MACE model + model_config = dict( + r_max=5, + num_bessel=8, + num_polynomial_cutoff=5, + max_ell=2, + interaction_cls=modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + interaction_cls_first=modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + num_interactions=2, + num_elements=2, + hidden_irreps=o3.Irreps("16x0e + 16x1o + 16x2e"), + MLP_irreps=o3.Irreps("16x0e"), + gate=torch.nn.functional.silu, + atomic_energies=atomic_energies, + avg_num_neighbors=3, + atomic_numbers=table.zs, + correlation=3, + ) + model = modules.EnergyDipolesMACE(**model_config) + + atomic_data = data.AtomicData.from_config(config, z_table=table, cutoff=3.0) + atomic_data2 = data.AtomicData.from_config( + config_rotated, z_table=table, cutoff=3.0 + ) + + data_loader = torch_geometric.dataloader.DataLoader( + dataset=[atomic_data, atomic_data2], + batch_size=2, + shuffle=False, + drop_last=False, + ) + batch = next(iter(data_loader)) + output = model( + batch, + training=True, + ) + # sanity check of dipoles being the right shape + assert output["dipole"][0].unsqueeze(0).shape == atomic_data.dipole.shape + # test energy is invariant + assert torch.allclose(output["energy"][0], output["energy"][1]) + # test equivariance of output dipoles + assert np.allclose( + np.array(rot @ output["dipole"][0].detach().numpy()), + output["dipole"][1].detach().numpy(), + ) + + +def test_mace_multi_reference(): + atomic_energies_multi = np.array([[1.0, 3.0], [0.0, 0.0]], dtype=float) + model_config = dict( + r_max=5, + num_bessel=8, + num_polynomial_cutoff=6, + max_ell=3, + interaction_cls=modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + interaction_cls_first=modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + num_interactions=2, + num_elements=2, + hidden_irreps=o3.Irreps("96x0e + 96x1o"), + MLP_irreps=o3.Irreps("16x0e"), + gate=torch.nn.functional.silu, + atomic_energies=atomic_energies_multi, + avg_num_neighbors=8, + atomic_numbers=table.zs, + distance_transform=True, + pair_repulsion=True, + correlation=3, + heads=["Default", "dft"], + # radial_type="chebyshev", + atomic_inter_scale=[1.0, 1.0], + atomic_inter_shift=[0.0, 0.1], + ) + model = modules.ScaleShiftMACE(**model_config) + model_compiled = jit.compile(model) + config.head = "Default" + config_rotated.head = "dft" + atomic_data = data.AtomicData.from_config( + config, z_table=table, cutoff=3.0, heads=["Default", "dft"] + ) + atomic_data2 = data.AtomicData.from_config( + config_rotated, z_table=table, cutoff=3.0, heads=["Default", "dft"] + ) + + data_loader = torch_geometric.dataloader.DataLoader( + dataset=[atomic_data, atomic_data2], + batch_size=2, + shuffle=True, + drop_last=False, + ) + batch = next(iter(data_loader)) + output1 = model(batch.to_dict(), training=True) + output2 = model_compiled(batch.to_dict(), training=True) + assert torch.allclose(output1["energy"][0], output2["energy"][0]) + assert output2["energy"].shape[0] == 2 + + +def test_atomic_virials_stresses(): + """ + Test that atomic virials and stresses sum to the total virials and stress. + """ + # Set default dtype for reproducibility + torch.set_default_dtype(torch.float64) + + # Create a periodic cell with ASE + atoms = build.bulk("Si", "diamond", a=5.43) + # Apply strain to ensure non-zero stress + strain_tensor = np.eye(3) * 1.02 # 2% strain + atoms.set_cell(np.dot(atoms.get_cell(), strain_tensor), scale_atoms=True) + + # Add forces and energy for completeness + atoms.arrays["REF_forces"] = np.random.normal(0, 0.1, size=atoms.positions.shape) + atoms.info["REF_energy"] = np.random.normal(0, 1) + atoms.info["REF_stress"] = np.random.normal(0, 0.1, size=6) + + # Setup MACE model configuration + stress_z_table = tools.AtomicNumberTable([14]) # Silicon + stress_atomic_energies = np.array([0.0]) + + model_config = dict( + r_max=5.0, + num_bessel=8, + num_polynomial_cutoff=6, + max_ell=2, + interaction_cls=modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + interaction_cls_first=modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + num_interactions=3, + num_elements=1, + hidden_irreps=o3.Irreps("32x0e + 32x1o"), + MLP_irreps=o3.Irreps("16x0e"), + gate=torch.nn.functional.silu, + atomic_energies=stress_atomic_energies, + avg_num_neighbors=4.0, + atomic_numbers=table.zs, + correlation=3, + atomic_inter_scale=1.0, + atomic_inter_shift=0.0, + ) + + # Create the model + model = modules.ScaleShiftMACE(**model_config) + + # Create atomic data + atomic_data = data.AtomicData.from_config( + data.config_from_atoms( + atoms, key_specification=data.KeySpecification.from_defaults() + ), + z_table=stress_z_table, + cutoff=5.0, + ) + + data_loader = torch_geometric.dataloader.DataLoader( + dataset=[atomic_data], + batch_size=2, + shuffle=True, + drop_last=False, + ) + batch = next(iter(data_loader)) + batch_dict = batch.to_dict() + + # Run the model with compute_atomic_stresses=True + output = model( + batch_dict, + compute_force=True, + compute_virials=True, + compute_stress=True, + compute_atomic_stresses=True, + ) + + # Get total virials/stress and atomic virials/stresses + total_virials = output["virials"] + atomic_virials = output["atomic_virials"] + total_stress = output["stress"] + atomic_stresses = output["atomic_stresses"] + + # Test that atomic values are not None + assert atomic_virials is not None, "Atomic virials were not computed" + assert atomic_stresses is not None, "Atomic stresses were not computed" + + # Test shape of atomic values + assert atomic_virials.shape[0] == len(atoms), "Wrong shape for atomic virials" + assert atomic_virials.shape[1:] == (3, 3), "Atomic virials should be 3x3 matrices" + assert atomic_stresses.shape[0] == len(atoms), "Wrong shape for atomic stresses" + assert atomic_stresses.shape[1:] == (3, 3), "Atomic stresses should be 3x3 matrices" + + # Compute sum of atomic values + summed_atomic_virials = torch.sum(atomic_virials, dim=0) + summed_atomic_stresses = torch.sum(atomic_stresses, dim=0) + + # Test that sums match total values + assert torch.allclose( + summed_atomic_virials, total_virials.squeeze(0), atol=1e-6 + ), f"Sum of atomic virials {summed_atomic_virials} does not match total virials {total_virials.squeeze(0)}" + + assert torch.allclose( + summed_atomic_stresses, total_stress.squeeze(0), atol=1e-6 + ), f"Sum of atomic stresses (normalized by volume) {summed_atomic_stresses} does not match total stress {total_stress.squeeze(0)}" diff --git a/mace-bench/3rdparty/mace/tests/test_modules.py b/mace-bench/3rdparty/mace/tests/test_modules.py index 57ddc32..6afcccf 100644 --- a/mace-bench/3rdparty/mace/tests/test_modules.py +++ b/mace-bench/3rdparty/mace/tests/test_modules.py @@ -1,268 +1,268 @@ -import numpy as np -import pytest -import torch -import torch.nn.functional -from e3nn import o3 - -from mace.data import AtomicData, Configuration -from mace.modules import ( - AtomicEnergiesBlock, - BesselBasis, - PolynomialCutoff, - SymmetricContraction, - WeightedEnergyForcesLoss, - WeightedHuberEnergyForcesStressLoss, - compute_mean_rms_energy_forces, - compute_statistics, -) -from mace.tools import AtomicNumberTable, scatter, to_numpy, torch_geometric -from mace.tools.scripts_utils import dict_to_array - - -@pytest.fixture(name="config") -def _config(): - return Configuration( - atomic_numbers=np.array([8, 1, 1]), - positions=np.array( - [ - [0.0, -2.0, 0.0], - [1.0, 0.0, 0.0], - [0.0, 1.0, 0.0], - ] - ), - properties={ - "forces": np.array( - [ - [0.0, -1.3, 0.0], - [1.0, 0.2, 0.0], - [0.0, 1.1, 0.3], - ] - ), - "energy": -1.5, - "stress": np.array([1.0, 0.0, 0.5, 0.0, -1.0, 0.0]), - }, - property_weights={ - "forces": 1.0, - "energy": 1.0, - "stress": 1.0, - }, - ) - - -@pytest.fixture(name="table") -def _table(): - return AtomicNumberTable([1, 8]) - - -@pytest.fixture(name="config1") -def _config1(): - return Configuration( - atomic_numbers=np.array([8, 1, 1]), - positions=np.array( - [ - [0.0, -2.0, 0.0], - [1.0, 0.0, 0.0], - [0.0, 1.0, 0.0], - ] - ), - properties={ - "forces": np.array( - [ - [0.0, -1.3, 0.0], - [1.0, 0.2, 0.0], - [0.0, 1.1, 0.3], - ] - ), - "energy": -1.5, - }, - property_weights={ - "forces": 1.0, - "energy": 1.0, - }, - head="DFT", - ) - - -@pytest.fixture(name="config2") -def _config2(): - return Configuration( - atomic_numbers=np.array([8, 1, 1]), - positions=np.array( - [ - [0.1, -1.9, 0.1], - [1.1, 0.1, 0.1], - [0.1, 1.1, 0.1], - ] - ), - properties={ - "forces": np.array( - [ - [0.1, -1.2, 0.1], - [1.1, 0.3, 0.1], - [0.1, 1.2, 0.4], - ] - ), - "energy": -1.4, - }, - property_weights={ - "forces": 1.0, - "energy": 1.0, - }, - head="MP2", - ) - - -@pytest.fixture(name="atomic_data") -def _atomic_data(config1, config2, table): - atomic_data1 = AtomicData.from_config( - config1, z_table=table, cutoff=3.0, heads=["DFT", "MP2"] - ) - atomic_data2 = AtomicData.from_config( - config2, z_table=table, cutoff=3.0, heads=["DFT", "MP2"] - ) - return [atomic_data1, atomic_data2] - - -@pytest.fixture(name="data_loader") -def _data_loader(atomic_data): - return torch_geometric.dataloader.DataLoader( - dataset=atomic_data, - batch_size=2, - shuffle=False, - drop_last=False, - ) - - -@pytest.fixture(name="atomic_energies") -def _atomic_energies(): - atomic_energies_dict = { - "DFT": np.array([0.0, 0.0]), - "MP2": np.array([0.1, 0.1]), - } - return dict_to_array(atomic_energies_dict, ["DFT", "MP2"]) - - -@pytest.fixture(autouse=True) -def _set_torch_default_dtype(): - torch.set_default_dtype(torch.float64) - - -def test_weighted_loss(config, table): - loss1 = WeightedEnergyForcesLoss(energy_weight=1, forces_weight=10) - loss2 = WeightedHuberEnergyForcesStressLoss(energy_weight=1, forces_weight=10) - data = AtomicData.from_config(config, z_table=table, cutoff=3.0) - data_loader = torch_geometric.dataloader.DataLoader( - dataset=[data, data], - batch_size=2, - shuffle=True, - drop_last=False, - ) - batch = next(iter(data_loader)) - pred = { - "energy": batch.energy, - "forces": batch.forces, - "stress": batch.stress, - } - out1 = loss1(batch, pred) - assert out1 == 0.0 - out2 = loss2(batch, pred) - assert out2 == 0.0 - - -def test_symmetric_contraction(): - operation = SymmetricContraction( - irreps_in=o3.Irreps("16x0e + 16x1o + 16x2e"), - irreps_out=o3.Irreps("16x0e + 16x1o"), - correlation=3, - num_elements=2, - ) - torch.manual_seed(123) - features = torch.randn(30, 16, 9) - one_hots = torch.nn.functional.one_hot(torch.arange(0, 30) % 2).to( - torch.get_default_dtype() - ) - out = operation(features, one_hots) - assert out.shape == (30, 64) - assert operation.contractions[0].weights_max.shape == (2, 11, 16) - - -def test_bessel_basis(): - d = torch.linspace(start=0.5, end=5.5, steps=10) - bessel_basis = BesselBasis(r_max=6.0, num_basis=5) - output = bessel_basis(d.unsqueeze(-1)) - assert output.shape == (10, 5) - - -def test_polynomial_cutoff(): - d = torch.linspace(start=0.5, end=5.5, steps=10) - cutoff_fn = PolynomialCutoff(r_max=5.0) - output = cutoff_fn(d) - assert output.shape == (10,) - - -def test_atomic_energies(config, table): - energies_block = AtomicEnergiesBlock(atomic_energies=np.array([1.0, 3.0])) - data = AtomicData.from_config(config, z_table=table, cutoff=3.0) - data_loader = torch_geometric.dataloader.DataLoader( - dataset=[data, data], - batch_size=2, - shuffle=True, - drop_last=False, - ) - batch = next(iter(data_loader)) - energies = energies_block(batch.node_attrs).squeeze(-1) - out = scatter.scatter_sum(src=energies, index=batch.batch, dim=-1, reduce="sum") - out = to_numpy(out) - assert np.allclose(out, np.array([5.0, 5.0])) - - -def test_atomic_energies_multireference(config, table): - energies_block = AtomicEnergiesBlock( - atomic_energies=np.array([[1.0, 3.0], [2.0, 4.0]]) - ) - config.head = "MP2" - data = AtomicData.from_config( - config, z_table=table, cutoff=3.0, heads=["DFT", "MP2"] - ) - data_loader = torch_geometric.dataloader.DataLoader( - dataset=[data, data], - batch_size=2, - shuffle=True, - drop_last=False, - ) - batch = next(iter(data_loader)) - num_atoms_arange = torch.arange(batch["positions"].shape[0]) - node_heads = ( - batch["head"][batch["batch"]] - if "head" in batch - else torch.zeros_like(batch["batch"]) - ) - energies = energies_block(batch.node_attrs).squeeze(-1) - energies = energies[num_atoms_arange, node_heads] - out = scatter.scatter_sum(src=energies, index=batch.batch, dim=-1, reduce="sum") - out = to_numpy(out) - assert np.allclose(out, np.array([8.0, 8.0])) - - -def test_compute_mean_rms_energy_forces_multi_head(data_loader, atomic_energies): - mean, rms = compute_mean_rms_energy_forces(data_loader, atomic_energies) - assert isinstance(mean, np.ndarray) - assert isinstance(rms, np.ndarray) - assert mean.shape == (2,) - assert rms.shape == (2,) - assert np.all(rms >= 0) - assert rms[0] != rms[1] - - -def test_compute_statistics(data_loader, atomic_energies): - avg_num_neighbors, mean, std = compute_statistics(data_loader, atomic_energies) - assert isinstance(avg_num_neighbors, float) - assert isinstance(mean, np.ndarray) - assert isinstance(std, np.ndarray) - assert mean.shape == (2,) - assert std.shape == (2,) - assert avg_num_neighbors > 0 - assert np.all(mean != 0) - assert np.all(std > 0) - assert mean[0] != mean[1] - assert std[0] != std[1] +import numpy as np +import pytest +import torch +import torch.nn.functional +from e3nn import o3 + +from mace.data import AtomicData, Configuration +from mace.modules import ( + AtomicEnergiesBlock, + BesselBasis, + PolynomialCutoff, + SymmetricContraction, + WeightedEnergyForcesLoss, + WeightedHuberEnergyForcesStressLoss, + compute_mean_rms_energy_forces, + compute_statistics, +) +from mace.tools import AtomicNumberTable, scatter, to_numpy, torch_geometric +from mace.tools.scripts_utils import dict_to_array + + +@pytest.fixture(name="config") +def _config(): + return Configuration( + atomic_numbers=np.array([8, 1, 1]), + positions=np.array( + [ + [0.0, -2.0, 0.0], + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + ] + ), + properties={ + "forces": np.array( + [ + [0.0, -1.3, 0.0], + [1.0, 0.2, 0.0], + [0.0, 1.1, 0.3], + ] + ), + "energy": -1.5, + "stress": np.array([1.0, 0.0, 0.5, 0.0, -1.0, 0.0]), + }, + property_weights={ + "forces": 1.0, + "energy": 1.0, + "stress": 1.0, + }, + ) + + +@pytest.fixture(name="table") +def _table(): + return AtomicNumberTable([1, 8]) + + +@pytest.fixture(name="config1") +def _config1(): + return Configuration( + atomic_numbers=np.array([8, 1, 1]), + positions=np.array( + [ + [0.0, -2.0, 0.0], + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + ] + ), + properties={ + "forces": np.array( + [ + [0.0, -1.3, 0.0], + [1.0, 0.2, 0.0], + [0.0, 1.1, 0.3], + ] + ), + "energy": -1.5, + }, + property_weights={ + "forces": 1.0, + "energy": 1.0, + }, + head="DFT", + ) + + +@pytest.fixture(name="config2") +def _config2(): + return Configuration( + atomic_numbers=np.array([8, 1, 1]), + positions=np.array( + [ + [0.1, -1.9, 0.1], + [1.1, 0.1, 0.1], + [0.1, 1.1, 0.1], + ] + ), + properties={ + "forces": np.array( + [ + [0.1, -1.2, 0.1], + [1.1, 0.3, 0.1], + [0.1, 1.2, 0.4], + ] + ), + "energy": -1.4, + }, + property_weights={ + "forces": 1.0, + "energy": 1.0, + }, + head="MP2", + ) + + +@pytest.fixture(name="atomic_data") +def _atomic_data(config1, config2, table): + atomic_data1 = AtomicData.from_config( + config1, z_table=table, cutoff=3.0, heads=["DFT", "MP2"] + ) + atomic_data2 = AtomicData.from_config( + config2, z_table=table, cutoff=3.0, heads=["DFT", "MP2"] + ) + return [atomic_data1, atomic_data2] + + +@pytest.fixture(name="data_loader") +def _data_loader(atomic_data): + return torch_geometric.dataloader.DataLoader( + dataset=atomic_data, + batch_size=2, + shuffle=False, + drop_last=False, + ) + + +@pytest.fixture(name="atomic_energies") +def _atomic_energies(): + atomic_energies_dict = { + "DFT": np.array([0.0, 0.0]), + "MP2": np.array([0.1, 0.1]), + } + return dict_to_array(atomic_energies_dict, ["DFT", "MP2"]) + + +@pytest.fixture(autouse=True) +def _set_torch_default_dtype(): + torch.set_default_dtype(torch.float64) + + +def test_weighted_loss(config, table): + loss1 = WeightedEnergyForcesLoss(energy_weight=1, forces_weight=10) + loss2 = WeightedHuberEnergyForcesStressLoss(energy_weight=1, forces_weight=10) + data = AtomicData.from_config(config, z_table=table, cutoff=3.0) + data_loader = torch_geometric.dataloader.DataLoader( + dataset=[data, data], + batch_size=2, + shuffle=True, + drop_last=False, + ) + batch = next(iter(data_loader)) + pred = { + "energy": batch.energy, + "forces": batch.forces, + "stress": batch.stress, + } + out1 = loss1(batch, pred) + assert out1 == 0.0 + out2 = loss2(batch, pred) + assert out2 == 0.0 + + +def test_symmetric_contraction(): + operation = SymmetricContraction( + irreps_in=o3.Irreps("16x0e + 16x1o + 16x2e"), + irreps_out=o3.Irreps("16x0e + 16x1o"), + correlation=3, + num_elements=2, + ) + torch.manual_seed(123) + features = torch.randn(30, 16, 9) + one_hots = torch.nn.functional.one_hot(torch.arange(0, 30) % 2).to( + torch.get_default_dtype() + ) + out = operation(features, one_hots) + assert out.shape == (30, 64) + assert operation.contractions[0].weights_max.shape == (2, 11, 16) + + +def test_bessel_basis(): + d = torch.linspace(start=0.5, end=5.5, steps=10) + bessel_basis = BesselBasis(r_max=6.0, num_basis=5) + output = bessel_basis(d.unsqueeze(-1)) + assert output.shape == (10, 5) + + +def test_polynomial_cutoff(): + d = torch.linspace(start=0.5, end=5.5, steps=10) + cutoff_fn = PolynomialCutoff(r_max=5.0) + output = cutoff_fn(d) + assert output.shape == (10,) + + +def test_atomic_energies(config, table): + energies_block = AtomicEnergiesBlock(atomic_energies=np.array([1.0, 3.0])) + data = AtomicData.from_config(config, z_table=table, cutoff=3.0) + data_loader = torch_geometric.dataloader.DataLoader( + dataset=[data, data], + batch_size=2, + shuffle=True, + drop_last=False, + ) + batch = next(iter(data_loader)) + energies = energies_block(batch.node_attrs).squeeze(-1) + out = scatter.scatter_sum(src=energies, index=batch.batch, dim=-1, reduce="sum") + out = to_numpy(out) + assert np.allclose(out, np.array([5.0, 5.0])) + + +def test_atomic_energies_multireference(config, table): + energies_block = AtomicEnergiesBlock( + atomic_energies=np.array([[1.0, 3.0], [2.0, 4.0]]) + ) + config.head = "MP2" + data = AtomicData.from_config( + config, z_table=table, cutoff=3.0, heads=["DFT", "MP2"] + ) + data_loader = torch_geometric.dataloader.DataLoader( + dataset=[data, data], + batch_size=2, + shuffle=True, + drop_last=False, + ) + batch = next(iter(data_loader)) + num_atoms_arange = torch.arange(batch["positions"].shape[0]) + node_heads = ( + batch["head"][batch["batch"]] + if "head" in batch + else torch.zeros_like(batch["batch"]) + ) + energies = energies_block(batch.node_attrs).squeeze(-1) + energies = energies[num_atoms_arange, node_heads] + out = scatter.scatter_sum(src=energies, index=batch.batch, dim=-1, reduce="sum") + out = to_numpy(out) + assert np.allclose(out, np.array([8.0, 8.0])) + + +def test_compute_mean_rms_energy_forces_multi_head(data_loader, atomic_energies): + mean, rms = compute_mean_rms_energy_forces(data_loader, atomic_energies) + assert isinstance(mean, np.ndarray) + assert isinstance(rms, np.ndarray) + assert mean.shape == (2,) + assert rms.shape == (2,) + assert np.all(rms >= 0) + assert rms[0] != rms[1] + + +def test_compute_statistics(data_loader, atomic_energies): + avg_num_neighbors, mean, std = compute_statistics(data_loader, atomic_energies) + assert isinstance(avg_num_neighbors, float) + assert isinstance(mean, np.ndarray) + assert isinstance(std, np.ndarray) + assert mean.shape == (2,) + assert std.shape == (2,) + assert avg_num_neighbors > 0 + assert np.all(mean != 0) + assert np.all(std > 0) + assert mean[0] != mean[1] + assert std[0] != std[1] diff --git a/mace-bench/3rdparty/mace/tests/test_multifiles.py b/mace-bench/3rdparty/mace/tests/test_multifiles.py index fb99518..16eacc2 100644 --- a/mace-bench/3rdparty/mace/tests/test_multifiles.py +++ b/mace-bench/3rdparty/mace/tests/test_multifiles.py @@ -1,1029 +1,1029 @@ -import json -import os -import shutil -import subprocess -import sys -import tempfile -import zlib -from pathlib import Path - -import lmdb -import numpy as np -import orjson -import pytest -import torch -import yaml -from ase.atoms import Atoms -from ase.calculators.singlepoint import SinglePointCalculator - -from mace.calculators import MACECalculator - - -def create_test_atoms(num_atoms=5, seed=42): - """Create random atoms for testing purposes with energy, forces, and stress.""" - # Set random seed for reproducibility - rng = np.random.RandomState(seed) - - # Create random positions - positions = rng.rand(num_atoms, 3) * 5.0 - - # Create random atomic numbers (H, C, N, O) - atomic_numbers = rng.choice([1, 6, 7, 8], size=num_atoms) - - # Create atoms object - atoms = Atoms( - numbers=atomic_numbers, - positions=positions, - cell=np.eye(3) * 10.0, # 10 Å periodic box - pbc=True, - ) - - # Add random energy, forces and stress - energy = float(rng.uniform(-15.0, -5.0)) - forces = rng.rand(num_atoms, 3) * 0.5 - 0.25 # Forces between -0.25 and 0.25 eV/Å - stress = rng.rand(6) * 0.2 - 0.1 # Stress tensor in Voigt notation - - # Add calculator to atoms with results - calc = SinglePointCalculator(atoms, energy=energy, forces=forces, stress=stress) - atoms.calc = calc - - # Mark isolated atoms with config_type - if num_atoms == 1: - atoms.info["config_type"] = "IsolatedAtom" - - return atoms - - -def create_xyz_file(atoms_list, filename): - """Write a list of atoms to an xyz file.""" - from ase.io import write - - write(filename, atoms_list, format="extxyz") - return filename - - -def create_e0s_file(e0s_dict, filename): - """Create an E0s JSON file with isolated atom energies.""" - # Convert keys to integers since MACE expects atomic numbers as integers - e0s_dict_int_keys = {int(k): v for k, v in e0s_dict.items()} - - with open(filename, "w", encoding="utf-8") as f: - json.dump(e0s_dict_int_keys, f) - return filename - - -def create_h5_dataset(xyz_file, output_dir, e0s_file=None, r_max=5.0, seed=42): - """ - Run MACE's preprocess_data.py script to convert an xyz file to h5 format. - - Args: - xyz_file: Path to the input xyz file - output_dir: Directory to store the preprocessed h5 files - e0s_file: Path to the E0s file with isolated atom energies - r_max: Cutoff radius - seed: Random seed - - Returns: - The output directory containing the h5 files - """ - # Make sure output directory exists - os.makedirs(output_dir, exist_ok=True) - - # Find the path to the preprocess_data.py script - preprocess_script = ( - Path(__file__).parent.parent / "mace" / "cli" / "preprocess_data.py" - ) - - # Set up command to run preprocess_data.py - cmd = [ - sys.executable, - str(preprocess_script), - f"--train_file={xyz_file}", - f"--r_max={r_max}", - f"--h5_prefix={output_dir}/", - f"--seed={seed}", - "--compute_statistics", # Generate statistics file - "--num_process=2", # Create 2 files for testing sharded loading - ] - - # Add E0s file if provided - if e0s_file: - cmd.append(f"--E0s={e0s_file}") - - # Set up environment - env = os.environ.copy() - env["PYTHONPATH"] = ( - str(Path(__file__).parent.parent) + ":" + env.get("PYTHONPATH", "") - ) - - # Run the script - print(f"Running preprocess command: {' '.join(cmd)}") - try: - process = subprocess.run( - cmd, env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True - ) - # Print output for debugging - print("Preprocess stdout:", process.stdout.decode()) - print("Preprocess stderr:", process.stderr.decode()) - except subprocess.CalledProcessError as e: - print("Preprocess failed with error:", e) - print("Stdout:", e.stdout.decode() if e.stdout else "") - print("Stderr:", e.stderr.decode() if e.stderr else "") - raise - - return output_dir - - -def create_lmdb_dataset(atoms_list, folder_path, head_name="Default"): - """Create an LMDB dataset from a list of atoms objects that MACE can read.""" - # Create the folder if it doesn't exist - os.makedirs(folder_path, exist_ok=True) - - # Create the LMDB database file - db_path = os.path.join(folder_path, "data.aselmdb") - - # Initialize LMDB environment - env = lmdb.open( - db_path, - map_size=1099511627776, # 1TB - subdir=False, - meminit=False, - map_async=True, - ) - - # Open a transaction - with env.begin(write=True) as txn: - # Store metadata - metadata = {"format_version": 1} - txn.put( - "metadata".encode("ascii"), - zlib.compress(orjson.dumps(metadata, option=orjson.OPT_SERIALIZE_NUMPY)), - ) - - # Store nextid - nextid = len(atoms_list) + 1 - txn.put( - "nextid".encode("ascii"), - zlib.compress(orjson.dumps(nextid, option=orjson.OPT_SERIALIZE_NUMPY)), - ) - - # Store deleted_ids (empty) - txn.put( - "deleted_ids".encode("ascii"), - zlib.compress(orjson.dumps([], option=orjson.OPT_SERIALIZE_NUMPY)), - ) - - # Store each atom - for i, atoms in enumerate(atoms_list): - id_num = i + 1 # Start from 1 - - # Convert atoms to dictionary - positions = atoms.get_positions() - cell = atoms.get_cell() - - # Create a dictionary with all necessary fields - dct = { - "numbers": atoms.get_atomic_numbers().tolist(), - "positions": positions.tolist(), - "cell": cell.tolist(), - "pbc": atoms.get_pbc().tolist(), - "ctime": 0.0, # Creation time - "mtime": 0.0, # Modification time - "user": "test", - "energy": atoms.calc.results["energy"], - "forces": atoms.calc.results["forces"].tolist(), - "stress": atoms.calc.results["stress"].tolist(), - "key_value_pairs": { - "config_type": atoms.info.get("config_type", "Default"), - "head": head_name, - }, - } - - # Store the atom in LMDB - txn.put( - f"{id_num}".encode("ascii"), - zlib.compress(orjson.dumps(dct, option=orjson.OPT_SERIALIZE_NUMPY)), - ) - - # Close the environment - env.close() - - return folder_path - - -@pytest.mark.slow -def test_multifile_training(): - """Test training with multiple file formats per head""" - # Create temporary directory - temp_dir = tempfile.mkdtemp() - try: - # Set up file paths - xyz_file1 = os.path.join(temp_dir, "data1.xyz") - xyz_file2 = os.path.join(temp_dir, "data2.xyz") - iso_atoms_file = os.path.join(temp_dir, "isolated_atoms.xyz") - h5_folder = os.path.join(temp_dir, "h5_data") - lmdb_folder1 = os.path.join( - temp_dir, "lmdb_data1_lmdb" - ) # Add _lmdb suffix for LMDB recognition - lmdb_folder2 = os.path.join( - temp_dir, "lmdb_data2_lmdb" - ) # Add _lmdb suffix for LMDB recognition - - config_path = os.path.join(temp_dir, "config.yaml") - results_dir = os.path.join(temp_dir, "results") - checkpoints_dir = os.path.join(temp_dir, "checkpoints") - model_dir = os.path.join(temp_dir, "models") - e0s_file = os.path.join(temp_dir, "e0s.json") - - # Create directories - os.makedirs(results_dir, exist_ok=True) - os.makedirs(checkpoints_dir, exist_ok=True) - os.makedirs(model_dir, exist_ok=True) - - # Set atomic numbers for z_table - z_table_elements = [1, 6, 7, 8] # H, C, N, O - - # Create test data for each format - rng = np.random.RandomState(42) - seeds = rng.randint(0, 10000, size=5) - - # Create isolated atoms for E0s (one of each element) - isolated_atoms = [] - e0s_dict = {} - for z in z_table_elements: - # Create isolated atom - atom = Atoms( - numbers=[z], positions=[[0, 0, 0]], cell=np.eye(3) * 10.0, pbc=True - ) - energy = float(rng.uniform(-5.0, -1.0)) # Random reference energy - forces = np.zeros((1, 3)) - stress = np.zeros(6) - calc = SinglePointCalculator( - atom, energy=energy, forces=forces, stress=stress - ) - atom.calc = calc - atom.info["config_type"] = "IsolatedAtom" - atom.info["REF_energy"] = energy # Make sure energy is in the right place - isolated_atoms.append(atom) - e0s_dict[str(z)] = energy # Store energy for E0s file - - # Create E0s file - create_e0s_file(e0s_dict, e0s_file) - - # Create isolated atoms xyz file - create_xyz_file(isolated_atoms, iso_atoms_file) - - # Create 10 atoms for each dataset - xyz_atoms1 = [ - create_test_atoms(num_atoms=5, seed=seeds[0] + i) for i in range(10) - ] - xyz_atoms2 = [ - create_test_atoms(num_atoms=5, seed=seeds[1] + i) for i in range(10) - ] - - # Create h5 data directly - first convert the xyz file to a format with REF_ keys - for atom in xyz_atoms1: - atom.info["REF_energy"] = atom.calc.results["energy"] - atom.arrays["REF_forces"] = atom.calc.results["forces"] - atom.info["REF_stress"] = atom.calc.results["stress"] - - for atom in xyz_atoms2: - atom.info["REF_energy"] = atom.calc.results["energy"] - atom.arrays["REF_forces"] = atom.calc.results["forces"] - atom.info["REF_stress"] = atom.calc.results["stress"] - - # Save isolated atoms to xyz files first, then create the h5 datasets - create_xyz_file(xyz_atoms1, xyz_file1) - create_xyz_file(xyz_atoms2, xyz_file2) - - # Create h5 data from xyz file, using both isolated atoms and real data - all_atoms_for_h5 = isolated_atoms + xyz_atoms2 - all_atoms_xyz = os.path.join(temp_dir, "all_atoms_for_h5.xyz") - create_xyz_file(all_atoms_for_h5, all_atoms_xyz) - create_h5_dataset(all_atoms_xyz, h5_folder) - - # Create LMDB datasets - lmdb_atoms1 = [ - create_test_atoms(num_atoms=5, seed=seeds[3] + i) for i in range(10) - ] - lmdb_atoms2 = [ - create_test_atoms(num_atoms=5, seed=seeds[4] + i) for i in range(10) - ] - create_lmdb_dataset(lmdb_atoms1, lmdb_folder1, head_name="head1") - create_lmdb_dataset(lmdb_atoms2, lmdb_folder2, head_name="head2") - - # Create config.yaml for training with proper format specification - config = { - "name": "multifile_test", - "seed": 42, - "model": "MACE", - "hidden_irreps": "32x0e", - "r_max": 5.0, - "batch_size": 5, - "max_num_epochs": 2, - "patience": 5, - "device": "cpu", - "energy_weight": 1.0, - "forces_weight": 10.0, - "loss": "weighted", - "optimizer": "adam", - "default_dtype": "float64", - "lr": 0.01, - "swa": False, - "work_dir": temp_dir, - "results_dir": results_dir, - "checkpoints_dir": checkpoints_dir, - "model_dir": model_dir, - "E0s": e0s_file, - "atomic_numbers": str(z_table_elements), - "heads": { - "head1": { - "train_file": [lmdb_folder1, xyz_file1], - "valid_file": xyz_file1, - "energy_key": "REF_energy", - "forces_key": "REF_forces", - "stress_key": "REF_stress", - }, - "head2": { - "train_file": [h5_folder + "/train", xyz_file2], - "valid_file": xyz_file2, - "energy_key": "REF_energy", - "forces_key": "REF_forces", - "stress_key": "REF_stress", - }, - }, - } - - # Write config file - with open(config_path, "w", encoding="utf-8") as f: - yaml.dump(config, f) - - # Import the modified run_train from our local module - run_train_script = ( - Path(__file__).parent.parent / "mace" / "cli" / "run_train.py" - ) - - # Run training with subprocess - cmd = [sys.executable, str(run_train_script), f"--config={config_path}"] - - # Set environment to add the current path to PYTHONPATH - env = os.environ.copy() - env["PYTHONPATH"] = ( - str(Path(__file__).parent.parent) + ":" + env.get("PYTHONPATH", "") - ) - - # Run the process - process = subprocess.run( - cmd, - env=env, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - check=False, # Don't raise exception on non-zero exit, we'll check manually - ) - - # Print output for debugging - print("\n" + "=" * 40 + " STDOUT " + "=" * 40) - print(process.stdout.decode()) - print("\n" + "=" * 40 + " STDERR " + "=" * 40) - print(process.stderr.decode()) - - # Check that process completed successfully - assert ( - process.returncode == 0 - ), f"Training failed with error: {process.stderr.decode()}" - - # Check that model was created - model_path = os.path.join(model_dir, "multifile_test.model") - assert os.path.exists(model_path), f"Model was not created at {model_path}" - - # Try to load and run the model - model = torch.load(model_path, map_location="cpu") - assert model is not None, "Failed to load model" - - # Create a calculator - calc = MACECalculator(model_paths=model_path, device="cpu", head="head1") - - # Run prediction on a test atom - test_atom = create_test_atoms(num_atoms=5, seed=99999) - test_atom.calc = calc - energy = test_atom.get_potential_energy() - forces = test_atom.get_forces() - - # Assert we got sensible outputs - assert np.isfinite(energy), "Model produced non-finite energy" - assert np.all(np.isfinite(forces)), "Model produced non-finite forces" - - finally: - # Clean up - shutil.rmtree(temp_dir) - - -@pytest.mark.slow -def test_multiple_xyz_per_head(): - """Test training with multiple XYZ files per head for train, valid and test sets""" - # Create temporary directory - temp_dir = tempfile.mkdtemp() - try: - # Set up file paths - create multiple xyz files for each dataset - train_xyz_files = [ - os.path.join(temp_dir, f"train_data{i}.xyz") for i in range(1, 4) - ] # 3 train files - valid_xyz_files = [ - os.path.join(temp_dir, f"valid_data{i}.xyz") for i in range(1, 3) - ] # 2 valid files - test_xyz_files = [ - os.path.join(temp_dir, f"test_data{i}.xyz") for i in range(1, 3) - ] # 2 test files - - iso_atoms_file = os.path.join(temp_dir, "isolated_atoms.xyz") - - config_path = os.path.join(temp_dir, "config.yaml") - results_dir = os.path.join(temp_dir, "results") - checkpoints_dir = os.path.join(temp_dir, "checkpoints") - model_dir = os.path.join(temp_dir, "models") - e0s_file = os.path.join(temp_dir, "e0s.json") - - # Create directories - os.makedirs(results_dir, exist_ok=True) - os.makedirs(checkpoints_dir, exist_ok=True) - os.makedirs(model_dir, exist_ok=True) - - # Set atomic numbers for z_table - z_table_elements = [1, 6, 7, 8] # H, C, N, O - - # Create test data for each format - rng = np.random.RandomState(42) - seeds = rng.randint(0, 10000, size=10) # More seeds for multiple files - - # Create isolated atoms for E0s (one of each element) - isolated_atoms = [] - e0s_dict = {} - for z in z_table_elements: - # Create isolated atom - atom = Atoms( - numbers=[z], positions=[[0, 0, 0]], cell=np.eye(3) * 10.0, pbc=True - ) - energy = float(rng.uniform(-5.0, -1.0)) # Random reference energy - forces = np.zeros((1, 3)) - stress = np.zeros(6) - calc = SinglePointCalculator( - atom, energy=energy, forces=forces, stress=stress - ) - atom.calc = calc - atom.info["config_type"] = "IsolatedAtom" - isolated_atoms.append(atom) - e0s_dict[str(z)] = energy # Store energy for E0s file - - # Create E0s file - create_e0s_file(e0s_dict, e0s_file) - - # Create isolated atoms xyz file - create_xyz_file(isolated_atoms, iso_atoms_file) - - # Create atoms for each train dataset - use different seeds for variety - train_datasets = [] - for i, file in enumerate(train_xyz_files): - # Create atoms with different seeds - atoms = [ - create_test_atoms(num_atoms=5, seed=seeds[i] + j) for j in range(5) - ] - create_xyz_file(atoms, file) - train_datasets.append(atoms) - - # Create atoms for validation datasets - valid_datasets = [] - for i, file in enumerate(valid_xyz_files): - atoms = [ - create_test_atoms(num_atoms=5, seed=seeds[i + 3] + j) for j in range(3) - ] - create_xyz_file(atoms, file) - valid_datasets.append(atoms) - - # Create atoms for test datasets - test_datasets = [] - for i, file in enumerate(test_xyz_files): - atoms = [ - create_test_atoms(num_atoms=5, seed=seeds[i + 5] + j) for j in range(3) - ] - create_xyz_file(atoms, file) - test_datasets.append(atoms) - - # Create config.yaml for training with multiple xyz files per dataset - config = { - "name": "multi_xyz_test", - "seed": 42, - "model": "MACE", - "hidden_irreps": "32x0e", - "r_max": 5.0, - "batch_size": 5, - "max_num_epochs": 2, - "patience": 5, - "device": "cpu", - "energy_weight": 1.0, - "forces_weight": 10.0, - "loss": "weighted", - "optimizer": "adam", - "default_dtype": "float64", - "lr": 0.01, - "swa": False, - "work_dir": temp_dir, - "results_dir": results_dir, - "checkpoints_dir": checkpoints_dir, - "model_dir": model_dir, - "E0s": e0s_file, - "atomic_numbers": str(z_table_elements), - "heads": { - "multi_xyz_head": { - # Using lists of multiple xyz files for each dataset - "train_file": train_xyz_files, - "valid_file": valid_xyz_files, - "test_file": test_xyz_files, - "energy_key": "energy", - "forces_key": "forces", - "stress_key": "stress", - }, - }, - } - - # Write config file - with open(config_path, "w", encoding="utf-8") as f: - yaml.dump(config, f) - - # Import the modified run_train from our local module - run_train_script = ( - Path(__file__).parent.parent / "mace" / "cli" / "run_train.py" - ) - - # Run training with subprocess - cmd = [sys.executable, str(run_train_script), f"--config={config_path}"] - - # Set environment to add the current path to PYTHONPATH - env = os.environ.copy() - env["PYTHONPATH"] = ( - str(Path(__file__).parent.parent) + ":" + env.get("PYTHONPATH", "") - ) - - # Run the process - process = subprocess.run( - cmd, - env=env, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - check=False, - ) - - # Print output for debugging - print("\n" + "=" * 40 + " STDOUT " + "=" * 40) - print(process.stdout.decode()) - print("\n" + "=" * 40 + " STDERR " + "=" * 40) - print(process.stderr.decode()) - - # Check that process completed successfully - assert ( - process.returncode == 0 - ), f"Training failed with error: {process.stderr.decode()}" - - # Check that model was created - model_path = os.path.join(model_dir, "multi_xyz_test.model") - assert os.path.exists(model_path), f"Model was not created at {model_path}" - - # Try to load and run the model - model = torch.load(model_path, map_location="cpu") - assert model is not None, "Failed to load model" - - # Create a calculator - calc = MACECalculator( - model_paths=model_path, device="cpu", head="multi_xyz_head" - ) - - # Run prediction on a test atom - test_atom = create_test_atoms(num_atoms=5, seed=99999) - test_atom.calc = calc - energy = test_atom.get_potential_energy() - forces = test_atom.get_forces() - - # Assert we got sensible outputs - assert np.isfinite(energy), "Model produced non-finite energy" - assert np.all(np.isfinite(forces)), "Model produced non-finite forces" - - finally: - # Clean up - shutil.rmtree(temp_dir) - - -@pytest.mark.slow -def test_single_xyz_per_head(): - """Test training with multiple XYZ files per head for train, valid and test sets""" - # Create temporary directory - temp_dir = tempfile.mkdtemp() - try: - # Set up file paths - create multiple xyz files for each dataset - train_xyz_files = [ - os.path.join(temp_dir, f"train_data{i}.xyz") for i in range(1, 2) - ] # 3 train files - valid_xyz_files = [ - os.path.join(temp_dir, f"valid_data{i}.xyz") for i in range(1, 2) - ] # 2 valid files - test_xyz_files = [ - os.path.join(temp_dir, f"test_data{i}.xyz") for i in range(1, 2) - ] # 2 test files - - iso_atoms_file = os.path.join(temp_dir, "isolated_atoms.xyz") - - config_path = os.path.join(temp_dir, "config.yaml") - results_dir = os.path.join(temp_dir, "results") - checkpoints_dir = os.path.join(temp_dir, "checkpoints") - model_dir = os.path.join(temp_dir, "models") - e0s_file = os.path.join(temp_dir, "e0s.json") - - # Create directories - os.makedirs(results_dir, exist_ok=True) - os.makedirs(checkpoints_dir, exist_ok=True) - os.makedirs(model_dir, exist_ok=True) - - # Set atomic numbers for z_table - z_table_elements = [1, 6, 7, 8] # H, C, N, O - - # Create test data for each format - rng = np.random.RandomState(42) - seeds = rng.randint(0, 10000, size=10) # More seeds for multiple files - - # Create isolated atoms for E0s (one of each element) - isolated_atoms = [] - e0s_dict = {} - for z in z_table_elements: - # Create isolated atom - atom = Atoms( - numbers=[z], positions=[[0, 0, 0]], cell=np.eye(3) * 10.0, pbc=True - ) - energy = float(rng.uniform(-5.0, -1.0)) # Random reference energy - forces = np.zeros((1, 3)) - stress = np.zeros(6) - calc = SinglePointCalculator( - atom, energy=energy, forces=forces, stress=stress - ) - atom.calc = calc - atom.info["config_type"] = "IsolatedAtom" - isolated_atoms.append(atom) - e0s_dict[str(z)] = energy # Store energy for E0s file - - # Create E0s file - create_e0s_file(e0s_dict, e0s_file) - - # Create isolated atoms xyz file - create_xyz_file(isolated_atoms, iso_atoms_file) - - # Create atoms for each train dataset - use different seeds for variety - train_datasets = [] - for i, file in enumerate(train_xyz_files): - # Create atoms with different seeds - atoms = [ - create_test_atoms(num_atoms=5, seed=seeds[i] + j) for j in range(5) - ] - create_xyz_file(atoms, file) - train_datasets.append(atoms) - - # Create atoms for validation datasets - valid_datasets = [] - for i, file in enumerate(valid_xyz_files): - atoms = [ - create_test_atoms(num_atoms=5, seed=seeds[i + 3] + j) for j in range(3) - ] - create_xyz_file(atoms, file) - valid_datasets.append(atoms) - - # Create atoms for test datasets - test_datasets = [] - for i, file in enumerate(test_xyz_files): - atoms = [ - create_test_atoms(num_atoms=5, seed=seeds[i + 5] + j) for j in range(3) - ] - create_xyz_file(atoms, file) - test_datasets.append(atoms) - - # Create config.yaml for training with multiple xyz files per dataset - config = { - "name": "multi_xyz_test", - "seed": 42, - "model": "MACE", - "hidden_irreps": "32x0e", - "r_max": 5.0, - "batch_size": 5, - "max_num_epochs": 2, - "patience": 5, - "device": "cpu", - "energy_weight": 1.0, - "forces_weight": 10.0, - "loss": "weighted", - "optimizer": "adam", - "default_dtype": "float64", - "lr": 0.01, - "swa": False, - "work_dir": temp_dir, - "results_dir": results_dir, - "checkpoints_dir": checkpoints_dir, - "model_dir": model_dir, - "E0s": e0s_file, - "atomic_numbers": str(z_table_elements), - "heads": { - "multi_xyz_head": { - # Using lists of multiple xyz files for each dataset - "train_file": train_xyz_files, - "valid_file": valid_xyz_files, - "test_file": test_xyz_files, - "energy_key": "energy", - "forces_key": "forces", - "stress_key": "stress", - }, - }, - } - - # Write config file - with open(config_path, "w", encoding="utf-8") as f: - yaml.dump(config, f) - - # Import the modified run_train from our local module - run_train_script = ( - Path(__file__).parent.parent / "mace" / "cli" / "run_train.py" - ) - - # Run training with subprocess - cmd = [sys.executable, str(run_train_script), f"--config={config_path}"] - - # Set environment to add the current path to PYTHONPATH - env = os.environ.copy() - env["PYTHONPATH"] = ( - str(Path(__file__).parent.parent) + ":" + env.get("PYTHONPATH", "") - ) - - # Run the process - process = subprocess.run( - cmd, - env=env, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - check=False, - ) - - # Print output for debugging - print("\n" + "=" * 40 + " STDOUT " + "=" * 40) - print(process.stdout.decode()) - print("\n" + "=" * 40 + " STDERR " + "=" * 40) - print(process.stderr.decode()) - - # Check that process completed successfully - assert ( - process.returncode == 0 - ), f"Training failed with error: {process.stderr.decode()}" - - # Check that model was created - model_path = os.path.join(model_dir, "multi_xyz_test.model") - assert os.path.exists(model_path), f"Model was not created at {model_path}" - - # Try to load and run the model - model = torch.load(model_path, map_location="cpu") - assert model is not None, "Failed to load model" - - # Create a calculator - calc = MACECalculator( - model_paths=model_path, device="cpu", head="multi_xyz_head" - ) - - # Run prediction on a test atom - test_atom = create_test_atoms(num_atoms=5, seed=99999) - test_atom.calc = calc - energy = test_atom.get_potential_energy() - forces = test_atom.get_forces() - - # Assert we got sensible outputs - assert np.isfinite(energy), "Model produced non-finite energy" - assert np.all(np.isfinite(forces)), "Model produced non-finite forces" - - finally: - # Clean up - shutil.rmtree(temp_dir) - - -@pytest.mark.slow -def test_multihead_finetuning_different_formats(): - """Test multihead finetuning with different file formats for each head.""" - # Create temporary directory - temp_dir = tempfile.mkdtemp() - try: - # Set up file paths - xyz_file = os.path.join(temp_dir, "finetuning_xyz.xyz") - h5_folder = os.path.join(temp_dir, "h5_data") - iso_atoms_file = os.path.join(temp_dir, "isolated_atoms.xyz") - - config_path = os.path.join(temp_dir, "config.yaml") - results_dir = os.path.join(temp_dir, "results") - checkpoints_dir = os.path.join(temp_dir, "checkpoints") - model_dir = os.path.join(temp_dir, "models") - e0s_file = os.path.join(temp_dir, "e0s.json") - - # Create directories - os.makedirs(results_dir, exist_ok=True) - os.makedirs(checkpoints_dir, exist_ok=True) - os.makedirs(model_dir, exist_ok=True) - - # Set atomic numbers for z_table - z_table_elements = [1, 6, 7, 8] # H, C, N, O - - # Create test data with different seeds - rng = np.random.RandomState(42) - seeds = rng.randint(0, 10000, size=3) - - # Create isolated atoms for E0s (one of each element) - isolated_atoms = [] - e0s_dict = {} - for z in z_table_elements: - atom = Atoms( - numbers=[z], positions=[[0, 0, 0]], cell=np.eye(3) * 10.0, pbc=True - ) - energy = float(rng.uniform(-5.0, -1.0)) - forces = np.zeros((1, 3)) - stress = np.zeros(6) - calc = SinglePointCalculator( - atom, energy=energy, forces=forces, stress=stress - ) - atom.calc = calc - atom.info["config_type"] = "IsolatedAtom" - atom.info["REF_energy"] = energy # Make sure energy is in the right place - atom.arrays["REF_forces"] = forces - atom.info["REF_stress"] = stress - isolated_atoms.append(atom) - e0s_dict[str(z)] = energy - - # Create E0s file - create_e0s_file(e0s_dict, e0s_file) - - # Create isolated atoms xyz file - create_xyz_file(isolated_atoms, iso_atoms_file) - - # Create XYZ data for xyz_head - xyz_atoms = [ - create_test_atoms(num_atoms=5, seed=seeds[0] + i) for i in range(30) - ] - # Add REF_ properties - for atom in xyz_atoms: - atom.info["REF_energy"] = atom.calc.results["energy"] - atom.arrays["REF_forces"] = atom.calc.results["forces"] - atom.info["REF_stress"] = atom.calc.results["stress"] - atom.info["head"] = "xyz_head" # Assign head - create_xyz_file(xyz_atoms, xyz_file) - - # Create H5 data for h5_head - h5_atoms = [ - create_test_atoms(num_atoms=5, seed=seeds[1] + i) for i in range(30) - ] - # Add REF_ properties - for atom in h5_atoms: - atom.info["REF_energy"] = atom.calc.results["energy"] - atom.arrays["REF_forces"] = atom.calc.results["forces"] - atom.info["REF_stress"] = atom.calc.results["stress"] - atom.info["head"] = "h5_head" # Assign head - - h5_atoms_xyz = os.path.join(temp_dir, "h5_atoms.xyz") - create_xyz_file(h5_atoms, h5_atoms_xyz) - # Include isolated atoms for E0s in the h5 dataset - all_atoms_for_h5 = h5_atoms + isolated_atoms - all_atoms_h5_xyz = os.path.join(temp_dir, "all_atoms_for_h5.xyz") - create_xyz_file(all_atoms_for_h5, all_atoms_h5_xyz) - create_h5_dataset(all_atoms_h5_xyz, h5_folder) - - # Create config.yaml for multihead finetuning - heads = { - "xyz_head": { - "train_file": xyz_file, - "valid_fraction": 0.2, - "energy_key": "REF_energy", - "forces_key": "REF_forces", - "stress_key": "REF_stress", - "E0s": e0s_file, - }, - "h5_head": { - "train_file": os.path.join(h5_folder, "train"), - "valid_file": os.path.join(h5_folder, "val"), - "energy_key": "REF_energy", - "forces_key": "REF_forces", - "stress_key": "REF_stress", - "E0s": e0s_file, - }, - } - - yaml_str = "heads:\n" - for key, value in heads.items(): - yaml_str += f" {key}:\n" - for sub_key, sub_value in value.items(): - yaml_str += f" {sub_key}: {sub_value}\n" - - with open(config_path, "w", encoding="utf-8") as f: - f.write(yaml_str) - - # Now perform multihead finetuning - finetuning_params = { - "name": "multihead_finetuned", - "config": config_path, - "foundation_model": "small", # Use the small foundation model - "energy_weight": 1.0, - "forces_weight": 10.0, - "model": "MACE", - "hidden_irreps": "128x0e", # Match foundation model - "r_max": 5.0, - "batch_size": 2, - "max_num_epochs": 2, # Just do a quick finetuning for test - "device": "cpu", - "seed": 42, - "loss": "weighted", - "default_dtype": "float64", - "checkpoints_dir": checkpoints_dir, - "model_dir": model_dir, - "results_dir": results_dir, - "atomic_numbers": "[" + ",".join(map(str, z_table_elements)) + "]", - "multiheads_finetuning": True, - "filter_type_pt": "combinations", - "subselect_pt": "random", - "num_samples_pt": 10, # Small number for testing - "force_mh_ft_lr": True, # Force using specified learning rate - } - - # Run finetuning - run_train_script = ( - Path(__file__).parent.parent / "mace" / "cli" / "run_train.py" - ) - env = os.environ.copy() - env["PYTHONPATH"] = ( - str(Path(__file__).parent.parent) + ":" + env.get("PYTHONPATH", "") - ) - - cmd = [sys.executable, str(run_train_script)] - for k, v in finetuning_params.items(): - if v is None: - cmd.append(f"--{k}") - else: - cmd.append(f"--{k}={v}") - - # Run the process - process = subprocess.run( - cmd, - env=env, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - check=False, - ) - - # Print output for debugging - print("\n" + "=" * 40 + " STDOUT " + "=" * 40) - print(process.stdout.decode()) - print("\n" + "=" * 40 + " STDERR " + "=" * 40) - print(process.stderr.decode()) - - # Check that process completed successfully - assert ( - process.returncode == 0 - ), f"Finetuning failed with error: {process.stderr.decode()}" - - # Check that model was created - model_path = os.path.join(model_dir, "multihead_finetuned.model") - assert os.path.exists(model_path), f"Model was not created at {model_path}" - - # Load model and verify it has the expected heads - model = torch.load(model_path, map_location="cpu") - assert hasattr(model, "heads"), "Model does not have heads attribute" - assert set(["xyz_head", "h5_head", "pt_head"]).issubset( - set(model.heads) - ), "Expected heads not found in model" - - # Try to run the model with both heads - # For xyz_head - calc_xyz = MACECalculator( - model_paths=model_path, - device="cpu", - head="xyz_head", - default_dtype="float64", - ) - test_atom = create_test_atoms(num_atoms=5, seed=99999) - test_atom.calc = calc_xyz - energy_xyz = test_atom.get_potential_energy() - forces_xyz = test_atom.get_forces() - - # For h5_head - calc_h5 = MACECalculator( - model_paths=model_path, - device="cpu", - head="h5_head", - default_dtype="float64", - ) - test_atom.calc = calc_h5 - energy_h5 = test_atom.get_potential_energy() - forces_h5 = test_atom.get_forces() - - # Verify results - assert np.isfinite(energy_xyz), "xyz_head produced non-finite energy" - assert np.all(np.isfinite(forces_xyz)), "xyz_head produced non-finite forces" - assert np.isfinite(energy_h5), "h5_head produced non-finite energy" - assert np.all(np.isfinite(forces_h5)), "h5_head produced non-finite forces" - - finally: - # Clean up - shutil.rmtree(temp_dir) +import json +import os +import shutil +import subprocess +import sys +import tempfile +import zlib +from pathlib import Path + +import lmdb +import numpy as np +import orjson +import pytest +import torch +import yaml +from ase.atoms import Atoms +from ase.calculators.singlepoint import SinglePointCalculator + +from mace.calculators import MACECalculator + + +def create_test_atoms(num_atoms=5, seed=42): + """Create random atoms for testing purposes with energy, forces, and stress.""" + # Set random seed for reproducibility + rng = np.random.RandomState(seed) + + # Create random positions + positions = rng.rand(num_atoms, 3) * 5.0 + + # Create random atomic numbers (H, C, N, O) + atomic_numbers = rng.choice([1, 6, 7, 8], size=num_atoms) + + # Create atoms object + atoms = Atoms( + numbers=atomic_numbers, + positions=positions, + cell=np.eye(3) * 10.0, # 10 Å periodic box + pbc=True, + ) + + # Add random energy, forces and stress + energy = float(rng.uniform(-15.0, -5.0)) + forces = rng.rand(num_atoms, 3) * 0.5 - 0.25 # Forces between -0.25 and 0.25 eV/Å + stress = rng.rand(6) * 0.2 - 0.1 # Stress tensor in Voigt notation + + # Add calculator to atoms with results + calc = SinglePointCalculator(atoms, energy=energy, forces=forces, stress=stress) + atoms.calc = calc + + # Mark isolated atoms with config_type + if num_atoms == 1: + atoms.info["config_type"] = "IsolatedAtom" + + return atoms + + +def create_xyz_file(atoms_list, filename): + """Write a list of atoms to an xyz file.""" + from ase.io import write + + write(filename, atoms_list, format="extxyz") + return filename + + +def create_e0s_file(e0s_dict, filename): + """Create an E0s JSON file with isolated atom energies.""" + # Convert keys to integers since MACE expects atomic numbers as integers + e0s_dict_int_keys = {int(k): v for k, v in e0s_dict.items()} + + with open(filename, "w", encoding="utf-8") as f: + json.dump(e0s_dict_int_keys, f) + return filename + + +def create_h5_dataset(xyz_file, output_dir, e0s_file=None, r_max=5.0, seed=42): + """ + Run MACE's preprocess_data.py script to convert an xyz file to h5 format. + + Args: + xyz_file: Path to the input xyz file + output_dir: Directory to store the preprocessed h5 files + e0s_file: Path to the E0s file with isolated atom energies + r_max: Cutoff radius + seed: Random seed + + Returns: + The output directory containing the h5 files + """ + # Make sure output directory exists + os.makedirs(output_dir, exist_ok=True) + + # Find the path to the preprocess_data.py script + preprocess_script = ( + Path(__file__).parent.parent / "mace" / "cli" / "preprocess_data.py" + ) + + # Set up command to run preprocess_data.py + cmd = [ + sys.executable, + str(preprocess_script), + f"--train_file={xyz_file}", + f"--r_max={r_max}", + f"--h5_prefix={output_dir}/", + f"--seed={seed}", + "--compute_statistics", # Generate statistics file + "--num_process=2", # Create 2 files for testing sharded loading + ] + + # Add E0s file if provided + if e0s_file: + cmd.append(f"--E0s={e0s_file}") + + # Set up environment + env = os.environ.copy() + env["PYTHONPATH"] = ( + str(Path(__file__).parent.parent) + ":" + env.get("PYTHONPATH", "") + ) + + # Run the script + print(f"Running preprocess command: {' '.join(cmd)}") + try: + process = subprocess.run( + cmd, env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True + ) + # Print output for debugging + print("Preprocess stdout:", process.stdout.decode()) + print("Preprocess stderr:", process.stderr.decode()) + except subprocess.CalledProcessError as e: + print("Preprocess failed with error:", e) + print("Stdout:", e.stdout.decode() if e.stdout else "") + print("Stderr:", e.stderr.decode() if e.stderr else "") + raise + + return output_dir + + +def create_lmdb_dataset(atoms_list, folder_path, head_name="Default"): + """Create an LMDB dataset from a list of atoms objects that MACE can read.""" + # Create the folder if it doesn't exist + os.makedirs(folder_path, exist_ok=True) + + # Create the LMDB database file + db_path = os.path.join(folder_path, "data.aselmdb") + + # Initialize LMDB environment + env = lmdb.open( + db_path, + map_size=1099511627776, # 1TB + subdir=False, + meminit=False, + map_async=True, + ) + + # Open a transaction + with env.begin(write=True) as txn: + # Store metadata + metadata = {"format_version": 1} + txn.put( + "metadata".encode("ascii"), + zlib.compress(orjson.dumps(metadata, option=orjson.OPT_SERIALIZE_NUMPY)), + ) + + # Store nextid + nextid = len(atoms_list) + 1 + txn.put( + "nextid".encode("ascii"), + zlib.compress(orjson.dumps(nextid, option=orjson.OPT_SERIALIZE_NUMPY)), + ) + + # Store deleted_ids (empty) + txn.put( + "deleted_ids".encode("ascii"), + zlib.compress(orjson.dumps([], option=orjson.OPT_SERIALIZE_NUMPY)), + ) + + # Store each atom + for i, atoms in enumerate(atoms_list): + id_num = i + 1 # Start from 1 + + # Convert atoms to dictionary + positions = atoms.get_positions() + cell = atoms.get_cell() + + # Create a dictionary with all necessary fields + dct = { + "numbers": atoms.get_atomic_numbers().tolist(), + "positions": positions.tolist(), + "cell": cell.tolist(), + "pbc": atoms.get_pbc().tolist(), + "ctime": 0.0, # Creation time + "mtime": 0.0, # Modification time + "user": "test", + "energy": atoms.calc.results["energy"], + "forces": atoms.calc.results["forces"].tolist(), + "stress": atoms.calc.results["stress"].tolist(), + "key_value_pairs": { + "config_type": atoms.info.get("config_type", "Default"), + "head": head_name, + }, + } + + # Store the atom in LMDB + txn.put( + f"{id_num}".encode("ascii"), + zlib.compress(orjson.dumps(dct, option=orjson.OPT_SERIALIZE_NUMPY)), + ) + + # Close the environment + env.close() + + return folder_path + + +@pytest.mark.slow +def test_multifile_training(): + """Test training with multiple file formats per head""" + # Create temporary directory + temp_dir = tempfile.mkdtemp() + try: + # Set up file paths + xyz_file1 = os.path.join(temp_dir, "data1.xyz") + xyz_file2 = os.path.join(temp_dir, "data2.xyz") + iso_atoms_file = os.path.join(temp_dir, "isolated_atoms.xyz") + h5_folder = os.path.join(temp_dir, "h5_data") + lmdb_folder1 = os.path.join( + temp_dir, "lmdb_data1_lmdb" + ) # Add _lmdb suffix for LMDB recognition + lmdb_folder2 = os.path.join( + temp_dir, "lmdb_data2_lmdb" + ) # Add _lmdb suffix for LMDB recognition + + config_path = os.path.join(temp_dir, "config.yaml") + results_dir = os.path.join(temp_dir, "results") + checkpoints_dir = os.path.join(temp_dir, "checkpoints") + model_dir = os.path.join(temp_dir, "models") + e0s_file = os.path.join(temp_dir, "e0s.json") + + # Create directories + os.makedirs(results_dir, exist_ok=True) + os.makedirs(checkpoints_dir, exist_ok=True) + os.makedirs(model_dir, exist_ok=True) + + # Set atomic numbers for z_table + z_table_elements = [1, 6, 7, 8] # H, C, N, O + + # Create test data for each format + rng = np.random.RandomState(42) + seeds = rng.randint(0, 10000, size=5) + + # Create isolated atoms for E0s (one of each element) + isolated_atoms = [] + e0s_dict = {} + for z in z_table_elements: + # Create isolated atom + atom = Atoms( + numbers=[z], positions=[[0, 0, 0]], cell=np.eye(3) * 10.0, pbc=True + ) + energy = float(rng.uniform(-5.0, -1.0)) # Random reference energy + forces = np.zeros((1, 3)) + stress = np.zeros(6) + calc = SinglePointCalculator( + atom, energy=energy, forces=forces, stress=stress + ) + atom.calc = calc + atom.info["config_type"] = "IsolatedAtom" + atom.info["REF_energy"] = energy # Make sure energy is in the right place + isolated_atoms.append(atom) + e0s_dict[str(z)] = energy # Store energy for E0s file + + # Create E0s file + create_e0s_file(e0s_dict, e0s_file) + + # Create isolated atoms xyz file + create_xyz_file(isolated_atoms, iso_atoms_file) + + # Create 10 atoms for each dataset + xyz_atoms1 = [ + create_test_atoms(num_atoms=5, seed=seeds[0] + i) for i in range(10) + ] + xyz_atoms2 = [ + create_test_atoms(num_atoms=5, seed=seeds[1] + i) for i in range(10) + ] + + # Create h5 data directly - first convert the xyz file to a format with REF_ keys + for atom in xyz_atoms1: + atom.info["REF_energy"] = atom.calc.results["energy"] + atom.arrays["REF_forces"] = atom.calc.results["forces"] + atom.info["REF_stress"] = atom.calc.results["stress"] + + for atom in xyz_atoms2: + atom.info["REF_energy"] = atom.calc.results["energy"] + atom.arrays["REF_forces"] = atom.calc.results["forces"] + atom.info["REF_stress"] = atom.calc.results["stress"] + + # Save isolated atoms to xyz files first, then create the h5 datasets + create_xyz_file(xyz_atoms1, xyz_file1) + create_xyz_file(xyz_atoms2, xyz_file2) + + # Create h5 data from xyz file, using both isolated atoms and real data + all_atoms_for_h5 = isolated_atoms + xyz_atoms2 + all_atoms_xyz = os.path.join(temp_dir, "all_atoms_for_h5.xyz") + create_xyz_file(all_atoms_for_h5, all_atoms_xyz) + create_h5_dataset(all_atoms_xyz, h5_folder) + + # Create LMDB datasets + lmdb_atoms1 = [ + create_test_atoms(num_atoms=5, seed=seeds[3] + i) for i in range(10) + ] + lmdb_atoms2 = [ + create_test_atoms(num_atoms=5, seed=seeds[4] + i) for i in range(10) + ] + create_lmdb_dataset(lmdb_atoms1, lmdb_folder1, head_name="head1") + create_lmdb_dataset(lmdb_atoms2, lmdb_folder2, head_name="head2") + + # Create config.yaml for training with proper format specification + config = { + "name": "multifile_test", + "seed": 42, + "model": "MACE", + "hidden_irreps": "32x0e", + "r_max": 5.0, + "batch_size": 5, + "max_num_epochs": 2, + "patience": 5, + "device": "cpu", + "energy_weight": 1.0, + "forces_weight": 10.0, + "loss": "weighted", + "optimizer": "adam", + "default_dtype": "float64", + "lr": 0.01, + "swa": False, + "work_dir": temp_dir, + "results_dir": results_dir, + "checkpoints_dir": checkpoints_dir, + "model_dir": model_dir, + "E0s": e0s_file, + "atomic_numbers": str(z_table_elements), + "heads": { + "head1": { + "train_file": [lmdb_folder1, xyz_file1], + "valid_file": xyz_file1, + "energy_key": "REF_energy", + "forces_key": "REF_forces", + "stress_key": "REF_stress", + }, + "head2": { + "train_file": [h5_folder + "/train", xyz_file2], + "valid_file": xyz_file2, + "energy_key": "REF_energy", + "forces_key": "REF_forces", + "stress_key": "REF_stress", + }, + }, + } + + # Write config file + with open(config_path, "w", encoding="utf-8") as f: + yaml.dump(config, f) + + # Import the modified run_train from our local module + run_train_script = ( + Path(__file__).parent.parent / "mace" / "cli" / "run_train.py" + ) + + # Run training with subprocess + cmd = [sys.executable, str(run_train_script), f"--config={config_path}"] + + # Set environment to add the current path to PYTHONPATH + env = os.environ.copy() + env["PYTHONPATH"] = ( + str(Path(__file__).parent.parent) + ":" + env.get("PYTHONPATH", "") + ) + + # Run the process + process = subprocess.run( + cmd, + env=env, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + check=False, # Don't raise exception on non-zero exit, we'll check manually + ) + + # Print output for debugging + print("\n" + "=" * 40 + " STDOUT " + "=" * 40) + print(process.stdout.decode()) + print("\n" + "=" * 40 + " STDERR " + "=" * 40) + print(process.stderr.decode()) + + # Check that process completed successfully + assert ( + process.returncode == 0 + ), f"Training failed with error: {process.stderr.decode()}" + + # Check that model was created + model_path = os.path.join(model_dir, "multifile_test.model") + assert os.path.exists(model_path), f"Model was not created at {model_path}" + + # Try to load and run the model + model = torch.load(model_path, map_location="cpu") + assert model is not None, "Failed to load model" + + # Create a calculator + calc = MACECalculator(model_paths=model_path, device="cpu", head="head1") + + # Run prediction on a test atom + test_atom = create_test_atoms(num_atoms=5, seed=99999) + test_atom.calc = calc + energy = test_atom.get_potential_energy() + forces = test_atom.get_forces() + + # Assert we got sensible outputs + assert np.isfinite(energy), "Model produced non-finite energy" + assert np.all(np.isfinite(forces)), "Model produced non-finite forces" + + finally: + # Clean up + shutil.rmtree(temp_dir) + + +@pytest.mark.slow +def test_multiple_xyz_per_head(): + """Test training with multiple XYZ files per head for train, valid and test sets""" + # Create temporary directory + temp_dir = tempfile.mkdtemp() + try: + # Set up file paths - create multiple xyz files for each dataset + train_xyz_files = [ + os.path.join(temp_dir, f"train_data{i}.xyz") for i in range(1, 4) + ] # 3 train files + valid_xyz_files = [ + os.path.join(temp_dir, f"valid_data{i}.xyz") for i in range(1, 3) + ] # 2 valid files + test_xyz_files = [ + os.path.join(temp_dir, f"test_data{i}.xyz") for i in range(1, 3) + ] # 2 test files + + iso_atoms_file = os.path.join(temp_dir, "isolated_atoms.xyz") + + config_path = os.path.join(temp_dir, "config.yaml") + results_dir = os.path.join(temp_dir, "results") + checkpoints_dir = os.path.join(temp_dir, "checkpoints") + model_dir = os.path.join(temp_dir, "models") + e0s_file = os.path.join(temp_dir, "e0s.json") + + # Create directories + os.makedirs(results_dir, exist_ok=True) + os.makedirs(checkpoints_dir, exist_ok=True) + os.makedirs(model_dir, exist_ok=True) + + # Set atomic numbers for z_table + z_table_elements = [1, 6, 7, 8] # H, C, N, O + + # Create test data for each format + rng = np.random.RandomState(42) + seeds = rng.randint(0, 10000, size=10) # More seeds for multiple files + + # Create isolated atoms for E0s (one of each element) + isolated_atoms = [] + e0s_dict = {} + for z in z_table_elements: + # Create isolated atom + atom = Atoms( + numbers=[z], positions=[[0, 0, 0]], cell=np.eye(3) * 10.0, pbc=True + ) + energy = float(rng.uniform(-5.0, -1.0)) # Random reference energy + forces = np.zeros((1, 3)) + stress = np.zeros(6) + calc = SinglePointCalculator( + atom, energy=energy, forces=forces, stress=stress + ) + atom.calc = calc + atom.info["config_type"] = "IsolatedAtom" + isolated_atoms.append(atom) + e0s_dict[str(z)] = energy # Store energy for E0s file + + # Create E0s file + create_e0s_file(e0s_dict, e0s_file) + + # Create isolated atoms xyz file + create_xyz_file(isolated_atoms, iso_atoms_file) + + # Create atoms for each train dataset - use different seeds for variety + train_datasets = [] + for i, file in enumerate(train_xyz_files): + # Create atoms with different seeds + atoms = [ + create_test_atoms(num_atoms=5, seed=seeds[i] + j) for j in range(5) + ] + create_xyz_file(atoms, file) + train_datasets.append(atoms) + + # Create atoms for validation datasets + valid_datasets = [] + for i, file in enumerate(valid_xyz_files): + atoms = [ + create_test_atoms(num_atoms=5, seed=seeds[i + 3] + j) for j in range(3) + ] + create_xyz_file(atoms, file) + valid_datasets.append(atoms) + + # Create atoms for test datasets + test_datasets = [] + for i, file in enumerate(test_xyz_files): + atoms = [ + create_test_atoms(num_atoms=5, seed=seeds[i + 5] + j) for j in range(3) + ] + create_xyz_file(atoms, file) + test_datasets.append(atoms) + + # Create config.yaml for training with multiple xyz files per dataset + config = { + "name": "multi_xyz_test", + "seed": 42, + "model": "MACE", + "hidden_irreps": "32x0e", + "r_max": 5.0, + "batch_size": 5, + "max_num_epochs": 2, + "patience": 5, + "device": "cpu", + "energy_weight": 1.0, + "forces_weight": 10.0, + "loss": "weighted", + "optimizer": "adam", + "default_dtype": "float64", + "lr": 0.01, + "swa": False, + "work_dir": temp_dir, + "results_dir": results_dir, + "checkpoints_dir": checkpoints_dir, + "model_dir": model_dir, + "E0s": e0s_file, + "atomic_numbers": str(z_table_elements), + "heads": { + "multi_xyz_head": { + # Using lists of multiple xyz files for each dataset + "train_file": train_xyz_files, + "valid_file": valid_xyz_files, + "test_file": test_xyz_files, + "energy_key": "energy", + "forces_key": "forces", + "stress_key": "stress", + }, + }, + } + + # Write config file + with open(config_path, "w", encoding="utf-8") as f: + yaml.dump(config, f) + + # Import the modified run_train from our local module + run_train_script = ( + Path(__file__).parent.parent / "mace" / "cli" / "run_train.py" + ) + + # Run training with subprocess + cmd = [sys.executable, str(run_train_script), f"--config={config_path}"] + + # Set environment to add the current path to PYTHONPATH + env = os.environ.copy() + env["PYTHONPATH"] = ( + str(Path(__file__).parent.parent) + ":" + env.get("PYTHONPATH", "") + ) + + # Run the process + process = subprocess.run( + cmd, + env=env, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + check=False, + ) + + # Print output for debugging + print("\n" + "=" * 40 + " STDOUT " + "=" * 40) + print(process.stdout.decode()) + print("\n" + "=" * 40 + " STDERR " + "=" * 40) + print(process.stderr.decode()) + + # Check that process completed successfully + assert ( + process.returncode == 0 + ), f"Training failed with error: {process.stderr.decode()}" + + # Check that model was created + model_path = os.path.join(model_dir, "multi_xyz_test.model") + assert os.path.exists(model_path), f"Model was not created at {model_path}" + + # Try to load and run the model + model = torch.load(model_path, map_location="cpu") + assert model is not None, "Failed to load model" + + # Create a calculator + calc = MACECalculator( + model_paths=model_path, device="cpu", head="multi_xyz_head" + ) + + # Run prediction on a test atom + test_atom = create_test_atoms(num_atoms=5, seed=99999) + test_atom.calc = calc + energy = test_atom.get_potential_energy() + forces = test_atom.get_forces() + + # Assert we got sensible outputs + assert np.isfinite(energy), "Model produced non-finite energy" + assert np.all(np.isfinite(forces)), "Model produced non-finite forces" + + finally: + # Clean up + shutil.rmtree(temp_dir) + + +@pytest.mark.slow +def test_single_xyz_per_head(): + """Test training with multiple XYZ files per head for train, valid and test sets""" + # Create temporary directory + temp_dir = tempfile.mkdtemp() + try: + # Set up file paths - create multiple xyz files for each dataset + train_xyz_files = [ + os.path.join(temp_dir, f"train_data{i}.xyz") for i in range(1, 2) + ] # 3 train files + valid_xyz_files = [ + os.path.join(temp_dir, f"valid_data{i}.xyz") for i in range(1, 2) + ] # 2 valid files + test_xyz_files = [ + os.path.join(temp_dir, f"test_data{i}.xyz") for i in range(1, 2) + ] # 2 test files + + iso_atoms_file = os.path.join(temp_dir, "isolated_atoms.xyz") + + config_path = os.path.join(temp_dir, "config.yaml") + results_dir = os.path.join(temp_dir, "results") + checkpoints_dir = os.path.join(temp_dir, "checkpoints") + model_dir = os.path.join(temp_dir, "models") + e0s_file = os.path.join(temp_dir, "e0s.json") + + # Create directories + os.makedirs(results_dir, exist_ok=True) + os.makedirs(checkpoints_dir, exist_ok=True) + os.makedirs(model_dir, exist_ok=True) + + # Set atomic numbers for z_table + z_table_elements = [1, 6, 7, 8] # H, C, N, O + + # Create test data for each format + rng = np.random.RandomState(42) + seeds = rng.randint(0, 10000, size=10) # More seeds for multiple files + + # Create isolated atoms for E0s (one of each element) + isolated_atoms = [] + e0s_dict = {} + for z in z_table_elements: + # Create isolated atom + atom = Atoms( + numbers=[z], positions=[[0, 0, 0]], cell=np.eye(3) * 10.0, pbc=True + ) + energy = float(rng.uniform(-5.0, -1.0)) # Random reference energy + forces = np.zeros((1, 3)) + stress = np.zeros(6) + calc = SinglePointCalculator( + atom, energy=energy, forces=forces, stress=stress + ) + atom.calc = calc + atom.info["config_type"] = "IsolatedAtom" + isolated_atoms.append(atom) + e0s_dict[str(z)] = energy # Store energy for E0s file + + # Create E0s file + create_e0s_file(e0s_dict, e0s_file) + + # Create isolated atoms xyz file + create_xyz_file(isolated_atoms, iso_atoms_file) + + # Create atoms for each train dataset - use different seeds for variety + train_datasets = [] + for i, file in enumerate(train_xyz_files): + # Create atoms with different seeds + atoms = [ + create_test_atoms(num_atoms=5, seed=seeds[i] + j) for j in range(5) + ] + create_xyz_file(atoms, file) + train_datasets.append(atoms) + + # Create atoms for validation datasets + valid_datasets = [] + for i, file in enumerate(valid_xyz_files): + atoms = [ + create_test_atoms(num_atoms=5, seed=seeds[i + 3] + j) for j in range(3) + ] + create_xyz_file(atoms, file) + valid_datasets.append(atoms) + + # Create atoms for test datasets + test_datasets = [] + for i, file in enumerate(test_xyz_files): + atoms = [ + create_test_atoms(num_atoms=5, seed=seeds[i + 5] + j) for j in range(3) + ] + create_xyz_file(atoms, file) + test_datasets.append(atoms) + + # Create config.yaml for training with multiple xyz files per dataset + config = { + "name": "multi_xyz_test", + "seed": 42, + "model": "MACE", + "hidden_irreps": "32x0e", + "r_max": 5.0, + "batch_size": 5, + "max_num_epochs": 2, + "patience": 5, + "device": "cpu", + "energy_weight": 1.0, + "forces_weight": 10.0, + "loss": "weighted", + "optimizer": "adam", + "default_dtype": "float64", + "lr": 0.01, + "swa": False, + "work_dir": temp_dir, + "results_dir": results_dir, + "checkpoints_dir": checkpoints_dir, + "model_dir": model_dir, + "E0s": e0s_file, + "atomic_numbers": str(z_table_elements), + "heads": { + "multi_xyz_head": { + # Using lists of multiple xyz files for each dataset + "train_file": train_xyz_files, + "valid_file": valid_xyz_files, + "test_file": test_xyz_files, + "energy_key": "energy", + "forces_key": "forces", + "stress_key": "stress", + }, + }, + } + + # Write config file + with open(config_path, "w", encoding="utf-8") as f: + yaml.dump(config, f) + + # Import the modified run_train from our local module + run_train_script = ( + Path(__file__).parent.parent / "mace" / "cli" / "run_train.py" + ) + + # Run training with subprocess + cmd = [sys.executable, str(run_train_script), f"--config={config_path}"] + + # Set environment to add the current path to PYTHONPATH + env = os.environ.copy() + env["PYTHONPATH"] = ( + str(Path(__file__).parent.parent) + ":" + env.get("PYTHONPATH", "") + ) + + # Run the process + process = subprocess.run( + cmd, + env=env, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + check=False, + ) + + # Print output for debugging + print("\n" + "=" * 40 + " STDOUT " + "=" * 40) + print(process.stdout.decode()) + print("\n" + "=" * 40 + " STDERR " + "=" * 40) + print(process.stderr.decode()) + + # Check that process completed successfully + assert ( + process.returncode == 0 + ), f"Training failed with error: {process.stderr.decode()}" + + # Check that model was created + model_path = os.path.join(model_dir, "multi_xyz_test.model") + assert os.path.exists(model_path), f"Model was not created at {model_path}" + + # Try to load and run the model + model = torch.load(model_path, map_location="cpu") + assert model is not None, "Failed to load model" + + # Create a calculator + calc = MACECalculator( + model_paths=model_path, device="cpu", head="multi_xyz_head" + ) + + # Run prediction on a test atom + test_atom = create_test_atoms(num_atoms=5, seed=99999) + test_atom.calc = calc + energy = test_atom.get_potential_energy() + forces = test_atom.get_forces() + + # Assert we got sensible outputs + assert np.isfinite(energy), "Model produced non-finite energy" + assert np.all(np.isfinite(forces)), "Model produced non-finite forces" + + finally: + # Clean up + shutil.rmtree(temp_dir) + + +@pytest.mark.slow +def test_multihead_finetuning_different_formats(): + """Test multihead finetuning with different file formats for each head.""" + # Create temporary directory + temp_dir = tempfile.mkdtemp() + try: + # Set up file paths + xyz_file = os.path.join(temp_dir, "finetuning_xyz.xyz") + h5_folder = os.path.join(temp_dir, "h5_data") + iso_atoms_file = os.path.join(temp_dir, "isolated_atoms.xyz") + + config_path = os.path.join(temp_dir, "config.yaml") + results_dir = os.path.join(temp_dir, "results") + checkpoints_dir = os.path.join(temp_dir, "checkpoints") + model_dir = os.path.join(temp_dir, "models") + e0s_file = os.path.join(temp_dir, "e0s.json") + + # Create directories + os.makedirs(results_dir, exist_ok=True) + os.makedirs(checkpoints_dir, exist_ok=True) + os.makedirs(model_dir, exist_ok=True) + + # Set atomic numbers for z_table + z_table_elements = [1, 6, 7, 8] # H, C, N, O + + # Create test data with different seeds + rng = np.random.RandomState(42) + seeds = rng.randint(0, 10000, size=3) + + # Create isolated atoms for E0s (one of each element) + isolated_atoms = [] + e0s_dict = {} + for z in z_table_elements: + atom = Atoms( + numbers=[z], positions=[[0, 0, 0]], cell=np.eye(3) * 10.0, pbc=True + ) + energy = float(rng.uniform(-5.0, -1.0)) + forces = np.zeros((1, 3)) + stress = np.zeros(6) + calc = SinglePointCalculator( + atom, energy=energy, forces=forces, stress=stress + ) + atom.calc = calc + atom.info["config_type"] = "IsolatedAtom" + atom.info["REF_energy"] = energy # Make sure energy is in the right place + atom.arrays["REF_forces"] = forces + atom.info["REF_stress"] = stress + isolated_atoms.append(atom) + e0s_dict[str(z)] = energy + + # Create E0s file + create_e0s_file(e0s_dict, e0s_file) + + # Create isolated atoms xyz file + create_xyz_file(isolated_atoms, iso_atoms_file) + + # Create XYZ data for xyz_head + xyz_atoms = [ + create_test_atoms(num_atoms=5, seed=seeds[0] + i) for i in range(30) + ] + # Add REF_ properties + for atom in xyz_atoms: + atom.info["REF_energy"] = atom.calc.results["energy"] + atom.arrays["REF_forces"] = atom.calc.results["forces"] + atom.info["REF_stress"] = atom.calc.results["stress"] + atom.info["head"] = "xyz_head" # Assign head + create_xyz_file(xyz_atoms, xyz_file) + + # Create H5 data for h5_head + h5_atoms = [ + create_test_atoms(num_atoms=5, seed=seeds[1] + i) for i in range(30) + ] + # Add REF_ properties + for atom in h5_atoms: + atom.info["REF_energy"] = atom.calc.results["energy"] + atom.arrays["REF_forces"] = atom.calc.results["forces"] + atom.info["REF_stress"] = atom.calc.results["stress"] + atom.info["head"] = "h5_head" # Assign head + + h5_atoms_xyz = os.path.join(temp_dir, "h5_atoms.xyz") + create_xyz_file(h5_atoms, h5_atoms_xyz) + # Include isolated atoms for E0s in the h5 dataset + all_atoms_for_h5 = h5_atoms + isolated_atoms + all_atoms_h5_xyz = os.path.join(temp_dir, "all_atoms_for_h5.xyz") + create_xyz_file(all_atoms_for_h5, all_atoms_h5_xyz) + create_h5_dataset(all_atoms_h5_xyz, h5_folder) + + # Create config.yaml for multihead finetuning + heads = { + "xyz_head": { + "train_file": xyz_file, + "valid_fraction": 0.2, + "energy_key": "REF_energy", + "forces_key": "REF_forces", + "stress_key": "REF_stress", + "E0s": e0s_file, + }, + "h5_head": { + "train_file": os.path.join(h5_folder, "train"), + "valid_file": os.path.join(h5_folder, "val"), + "energy_key": "REF_energy", + "forces_key": "REF_forces", + "stress_key": "REF_stress", + "E0s": e0s_file, + }, + } + + yaml_str = "heads:\n" + for key, value in heads.items(): + yaml_str += f" {key}:\n" + for sub_key, sub_value in value.items(): + yaml_str += f" {sub_key}: {sub_value}\n" + + with open(config_path, "w", encoding="utf-8") as f: + f.write(yaml_str) + + # Now perform multihead finetuning + finetuning_params = { + "name": "multihead_finetuned", + "config": config_path, + "foundation_model": "small", # Use the small foundation model + "energy_weight": 1.0, + "forces_weight": 10.0, + "model": "MACE", + "hidden_irreps": "128x0e", # Match foundation model + "r_max": 5.0, + "batch_size": 2, + "max_num_epochs": 2, # Just do a quick finetuning for test + "device": "cpu", + "seed": 42, + "loss": "weighted", + "default_dtype": "float64", + "checkpoints_dir": checkpoints_dir, + "model_dir": model_dir, + "results_dir": results_dir, + "atomic_numbers": "[" + ",".join(map(str, z_table_elements)) + "]", + "multiheads_finetuning": True, + "filter_type_pt": "combinations", + "subselect_pt": "random", + "num_samples_pt": 10, # Small number for testing + "force_mh_ft_lr": True, # Force using specified learning rate + } + + # Run finetuning + run_train_script = ( + Path(__file__).parent.parent / "mace" / "cli" / "run_train.py" + ) + env = os.environ.copy() + env["PYTHONPATH"] = ( + str(Path(__file__).parent.parent) + ":" + env.get("PYTHONPATH", "") + ) + + cmd = [sys.executable, str(run_train_script)] + for k, v in finetuning_params.items(): + if v is None: + cmd.append(f"--{k}") + else: + cmd.append(f"--{k}={v}") + + # Run the process + process = subprocess.run( + cmd, + env=env, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + check=False, + ) + + # Print output for debugging + print("\n" + "=" * 40 + " STDOUT " + "=" * 40) + print(process.stdout.decode()) + print("\n" + "=" * 40 + " STDERR " + "=" * 40) + print(process.stderr.decode()) + + # Check that process completed successfully + assert ( + process.returncode == 0 + ), f"Finetuning failed with error: {process.stderr.decode()}" + + # Check that model was created + model_path = os.path.join(model_dir, "multihead_finetuned.model") + assert os.path.exists(model_path), f"Model was not created at {model_path}" + + # Load model and verify it has the expected heads + model = torch.load(model_path, map_location="cpu") + assert hasattr(model, "heads"), "Model does not have heads attribute" + assert set(["xyz_head", "h5_head", "pt_head"]).issubset( + set(model.heads) + ), "Expected heads not found in model" + + # Try to run the model with both heads + # For xyz_head + calc_xyz = MACECalculator( + model_paths=model_path, + device="cpu", + head="xyz_head", + default_dtype="float64", + ) + test_atom = create_test_atoms(num_atoms=5, seed=99999) + test_atom.calc = calc_xyz + energy_xyz = test_atom.get_potential_energy() + forces_xyz = test_atom.get_forces() + + # For h5_head + calc_h5 = MACECalculator( + model_paths=model_path, + device="cpu", + head="h5_head", + default_dtype="float64", + ) + test_atom.calc = calc_h5 + energy_h5 = test_atom.get_potential_energy() + forces_h5 = test_atom.get_forces() + + # Verify results + assert np.isfinite(energy_xyz), "xyz_head produced non-finite energy" + assert np.all(np.isfinite(forces_xyz)), "xyz_head produced non-finite forces" + assert np.isfinite(energy_h5), "h5_head produced non-finite energy" + assert np.all(np.isfinite(forces_h5)), "h5_head produced non-finite forces" + + finally: + # Clean up + shutil.rmtree(temp_dir) diff --git a/mace-bench/3rdparty/mace/tests/test_preprocess.py b/mace-bench/3rdparty/mace/tests/test_preprocess.py index f976ff1..5b070e5 100644 --- a/mace-bench/3rdparty/mace/tests/test_preprocess.py +++ b/mace-bench/3rdparty/mace/tests/test_preprocess.py @@ -1,206 +1,206 @@ -import os -import subprocess -import sys -from pathlib import Path - -import ase.io -import numpy as np -import pytest -import yaml -from ase.atoms import Atoms - -pytest_mace_dir = Path(__file__).parent.parent -preprocess_data = Path(__file__).parent.parent / "mace" / "cli" / "preprocess_data.py" - - -@pytest.fixture(name="sample_configs") -def fixture_sample_configs(): - water = Atoms( - numbers=[8, 1, 1], - positions=[[0, -2.0, 0], [1, 0, 0], [0, 1, 0]], - cell=[4] * 3, - pbc=[True] * 3, - ) - configs = [ - Atoms(numbers=[8], positions=[[0, 0, 0]], cell=[6] * 3), - Atoms(numbers=[1], positions=[[0, 0, 0]], cell=[6] * 3), - ] - configs[0].info["REF_energy"] = 0.0 - configs[0].info["config_type"] = "IsolatedAtom" - configs[1].info["REF_energy"] = 0.0 - configs[1].info["config_type"] = "IsolatedAtom" - - np.random.seed(5) - for _ in range(10): - c = water.copy() - c.positions += np.random.normal(0.1, size=c.positions.shape) - c.info["REF_energy"] = np.random.normal(0.1) - c.new_array("REF_forces", np.random.normal(0.1, size=c.positions.shape)) - c.info["REF_stress"] = np.random.normal(0.1, size=6) - configs.append(c) - - return configs - - -def test_preprocess_data(tmp_path, sample_configs): - ase.io.write(tmp_path / "sample.xyz", sample_configs) - - preprocess_params = { - "train_file": tmp_path / "sample.xyz", - "r_max": 5.0, - "config_type_weights": "{'Default':1.0}", - "num_process": 2, - "valid_fraction": 0.1, - "h5_prefix": tmp_path / "preprocessed_", - "compute_statistics": None, - "seed": 42, - "energy_key": "REF_energy", - "forces_key": "REF_forces", - "stress_key": "REF_stress", - } - - run_env = os.environ.copy() - sys.path.insert(0, str(Path(__file__).parent.parent)) - run_env["PYTHONPATH"] = ":".join(sys.path) - print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) - - cmd = ( - sys.executable - + " " - + str(preprocess_data) - + " " - + " ".join( - [ - (f"--{k}={v}" if v is not None else f"--{k}") - for k, v in preprocess_params.items() - ] - ) - ) - - p = subprocess.run(cmd.split(), env=run_env, check=True) - assert p.returncode == 0 - - # Check if the output files are created - assert (tmp_path / "preprocessed_train").is_dir() - assert (tmp_path / "preprocessed_val").is_dir() - assert (tmp_path / "preprocessed_statistics.json").is_file() - - # Check if the correct number of files are created - train_files = list((tmp_path / "preprocessed_train").glob("*.h5")) - val_files = list((tmp_path / "preprocessed_val").glob("*.h5")) - assert len(train_files) == preprocess_params["num_process"] - assert len(val_files) == preprocess_params["num_process"] - - # Example of checking statistics file content: - import json - - with open(tmp_path / "preprocessed_statistics.json", "r", encoding="utf-8") as f: - statistics = json.load(f) - assert "atomic_energies" in statistics - assert "avg_num_neighbors" in statistics - assert "mean" in statistics - assert "std" in statistics - assert "atomic_numbers" in statistics - assert "r_max" in statistics - - # Example of checking H5 file content: - import h5py - - with h5py.File(train_files[0], "r") as f: - assert "config_batch_0" in f - config = f["config_batch_0"]["config_0"] - assert "atomic_numbers" in config - assert "positions" in config - assert "energy" in config["properties"] - assert "forces" in config["properties"] - - original_energies = [ - config.info["REF_energy"] - for config in sample_configs[2:] - if "REF_energy" in config.info - ] - original_forces = [ - config.arrays["REF_forces"] - for config in sample_configs[2:] - if "REF_forces" in config.arrays - ] - - h5_energies = [] - h5_forces = [] - - for train_file in train_files: - with h5py.File(train_file, "r") as f: - for _, batch in f.items(): - for config_key in batch.keys(): - config = batch[config_key] - assert "atomic_numbers" in config - assert "positions" in config - assert "energy" in config["properties"] - assert "forces" in config["properties"] - - h5_energies.append(config["properties"]["energy"][()]) - h5_forces.append(config["properties"]["forces"][()]) - - for val_file in val_files: - with h5py.File(val_file, "r") as f: - for _, batch in f.items(): - for config_key in batch.keys(): - config = batch[config_key] - h5_energies.append(config["properties"]["energy"][()]) - h5_forces.append(config["properties"]["forces"][()]) - - print("Original energies", original_energies) - print("H5 energies", h5_energies) - print("Original forces", original_forces) - print("H5 forces", h5_forces) - original_energies.sort() - h5_energies.sort() - original_forces = np.concatenate(original_forces).flatten() - h5_forces = np.concatenate(h5_forces).flatten() - original_forces.sort() - h5_forces.sort() - - # Compare energies and forces - np.testing.assert_allclose(original_energies, h5_energies, rtol=1e-5, atol=1e-8) - np.testing.assert_allclose(original_forces, h5_forces, rtol=1e-5, atol=1e-8) - - print("All checks passed successfully!") - - -def test_preprocess_config(tmp_path, sample_configs): - ase.io.write(tmp_path / "sample.xyz", sample_configs) - - preprocess_params = { - "train_file": str(tmp_path / "sample.xyz"), - "r_max": 5.0, - "config_type_weights": "{'Default':1.0}", - "num_process": 2, - "valid_fraction": 0.1, - "h5_prefix": str(tmp_path / "preprocessed_"), - "compute_statistics": None, - "seed": 42, - "energy_key": "REF_energy", - "forces_key": "REF_forces", - "stress_key": "REF_stress", - } - filename = tmp_path / "config.yaml" - with open(filename, "w", encoding="utf-8") as file: - yaml.dump(preprocess_params, file) - - run_env = os.environ.copy() - sys.path.insert(0, str(Path(__file__).parent.parent)) - run_env["PYTHONPATH"] = ":".join(sys.path) - print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) - - cmd = ( - sys.executable - + " " - + str(preprocess_data) - + " " - + "--config" - + " " - + str(filename) - ) - - p = subprocess.run(cmd.split(), env=run_env, check=True) - assert p.returncode == 0 +import os +import subprocess +import sys +from pathlib import Path + +import ase.io +import numpy as np +import pytest +import yaml +from ase.atoms import Atoms + +pytest_mace_dir = Path(__file__).parent.parent +preprocess_data = Path(__file__).parent.parent / "mace" / "cli" / "preprocess_data.py" + + +@pytest.fixture(name="sample_configs") +def fixture_sample_configs(): + water = Atoms( + numbers=[8, 1, 1], + positions=[[0, -2.0, 0], [1, 0, 0], [0, 1, 0]], + cell=[4] * 3, + pbc=[True] * 3, + ) + configs = [ + Atoms(numbers=[8], positions=[[0, 0, 0]], cell=[6] * 3), + Atoms(numbers=[1], positions=[[0, 0, 0]], cell=[6] * 3), + ] + configs[0].info["REF_energy"] = 0.0 + configs[0].info["config_type"] = "IsolatedAtom" + configs[1].info["REF_energy"] = 0.0 + configs[1].info["config_type"] = "IsolatedAtom" + + np.random.seed(5) + for _ in range(10): + c = water.copy() + c.positions += np.random.normal(0.1, size=c.positions.shape) + c.info["REF_energy"] = np.random.normal(0.1) + c.new_array("REF_forces", np.random.normal(0.1, size=c.positions.shape)) + c.info["REF_stress"] = np.random.normal(0.1, size=6) + configs.append(c) + + return configs + + +def test_preprocess_data(tmp_path, sample_configs): + ase.io.write(tmp_path / "sample.xyz", sample_configs) + + preprocess_params = { + "train_file": tmp_path / "sample.xyz", + "r_max": 5.0, + "config_type_weights": "{'Default':1.0}", + "num_process": 2, + "valid_fraction": 0.1, + "h5_prefix": tmp_path / "preprocessed_", + "compute_statistics": None, + "seed": 42, + "energy_key": "REF_energy", + "forces_key": "REF_forces", + "stress_key": "REF_stress", + } + + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) + + cmd = ( + sys.executable + + " " + + str(preprocess_data) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in preprocess_params.items() + ] + ) + ) + + p = subprocess.run(cmd.split(), env=run_env, check=True) + assert p.returncode == 0 + + # Check if the output files are created + assert (tmp_path / "preprocessed_train").is_dir() + assert (tmp_path / "preprocessed_val").is_dir() + assert (tmp_path / "preprocessed_statistics.json").is_file() + + # Check if the correct number of files are created + train_files = list((tmp_path / "preprocessed_train").glob("*.h5")) + val_files = list((tmp_path / "preprocessed_val").glob("*.h5")) + assert len(train_files) == preprocess_params["num_process"] + assert len(val_files) == preprocess_params["num_process"] + + # Example of checking statistics file content: + import json + + with open(tmp_path / "preprocessed_statistics.json", "r", encoding="utf-8") as f: + statistics = json.load(f) + assert "atomic_energies" in statistics + assert "avg_num_neighbors" in statistics + assert "mean" in statistics + assert "std" in statistics + assert "atomic_numbers" in statistics + assert "r_max" in statistics + + # Example of checking H5 file content: + import h5py + + with h5py.File(train_files[0], "r") as f: + assert "config_batch_0" in f + config = f["config_batch_0"]["config_0"] + assert "atomic_numbers" in config + assert "positions" in config + assert "energy" in config["properties"] + assert "forces" in config["properties"] + + original_energies = [ + config.info["REF_energy"] + for config in sample_configs[2:] + if "REF_energy" in config.info + ] + original_forces = [ + config.arrays["REF_forces"] + for config in sample_configs[2:] + if "REF_forces" in config.arrays + ] + + h5_energies = [] + h5_forces = [] + + for train_file in train_files: + with h5py.File(train_file, "r") as f: + for _, batch in f.items(): + for config_key in batch.keys(): + config = batch[config_key] + assert "atomic_numbers" in config + assert "positions" in config + assert "energy" in config["properties"] + assert "forces" in config["properties"] + + h5_energies.append(config["properties"]["energy"][()]) + h5_forces.append(config["properties"]["forces"][()]) + + for val_file in val_files: + with h5py.File(val_file, "r") as f: + for _, batch in f.items(): + for config_key in batch.keys(): + config = batch[config_key] + h5_energies.append(config["properties"]["energy"][()]) + h5_forces.append(config["properties"]["forces"][()]) + + print("Original energies", original_energies) + print("H5 energies", h5_energies) + print("Original forces", original_forces) + print("H5 forces", h5_forces) + original_energies.sort() + h5_energies.sort() + original_forces = np.concatenate(original_forces).flatten() + h5_forces = np.concatenate(h5_forces).flatten() + original_forces.sort() + h5_forces.sort() + + # Compare energies and forces + np.testing.assert_allclose(original_energies, h5_energies, rtol=1e-5, atol=1e-8) + np.testing.assert_allclose(original_forces, h5_forces, rtol=1e-5, atol=1e-8) + + print("All checks passed successfully!") + + +def test_preprocess_config(tmp_path, sample_configs): + ase.io.write(tmp_path / "sample.xyz", sample_configs) + + preprocess_params = { + "train_file": str(tmp_path / "sample.xyz"), + "r_max": 5.0, + "config_type_weights": "{'Default':1.0}", + "num_process": 2, + "valid_fraction": 0.1, + "h5_prefix": str(tmp_path / "preprocessed_"), + "compute_statistics": None, + "seed": 42, + "energy_key": "REF_energy", + "forces_key": "REF_forces", + "stress_key": "REF_stress", + } + filename = tmp_path / "config.yaml" + with open(filename, "w", encoding="utf-8") as file: + yaml.dump(preprocess_params, file) + + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) + + cmd = ( + sys.executable + + " " + + str(preprocess_data) + + " " + + "--config" + + " " + + str(filename) + ) + + p = subprocess.run(cmd.split(), env=run_env, check=True) + assert p.returncode == 0 diff --git a/mace-bench/3rdparty/mace/tests/test_run_train.py b/mace-bench/3rdparty/mace/tests/test_run_train.py index 921d12e..ddb8496 100644 --- a/mace-bench/3rdparty/mace/tests/test_run_train.py +++ b/mace-bench/3rdparty/mace/tests/test_run_train.py @@ -1,1458 +1,1458 @@ -import json -import os -import subprocess -import sys -from pathlib import Path - -import ase.io -import numpy as np -import pytest -import torch -from ase.atoms import Atoms - -from mace.calculators import MACECalculator, mace_mp - -try: - import cuequivariance as cue # pylint: disable=unused-import - - CUET_AVAILABLE = True -except ImportError: - CUET_AVAILABLE = False - -run_train = Path(__file__).parent.parent / "mace" / "cli" / "run_train.py" - - -@pytest.fixture(name="fitting_configs") -def fixture_fitting_configs(): - water = Atoms( - numbers=[8, 1, 1], - positions=[[0, -2.0, 0], [1, 0, 0], [0, 1, 0]], - cell=[4] * 3, - pbc=[True] * 3, - ) - fit_configs = [ - Atoms(numbers=[8], positions=[[0, 0, 0]], cell=[6] * 3), - Atoms(numbers=[1], positions=[[0, 0, 0]], cell=[6] * 3), - ] - fit_configs[0].info["REF_energy"] = 0.0 - fit_configs[0].info["config_type"] = "IsolatedAtom" - fit_configs[1].info["REF_energy"] = 0.0 - fit_configs[1].info["config_type"] = "IsolatedAtom" - - np.random.seed(5) - for _ in range(20): - c = water.copy() - c.positions += np.random.normal(0.1, size=c.positions.shape) - c.info["REF_energy"] = np.random.normal(0.1) - print(c.info["REF_energy"]) - c.new_array("REF_forces", np.random.normal(0.1, size=c.positions.shape)) - c.info["REF_stress"] = np.random.normal(0.1, size=6) - fit_configs.append(c) - - return fit_configs - - -@pytest.fixture(name="pretraining_configs") -def fixture_pretraining_configs(): - configs = [] - for _ in range(10): - atoms = Atoms( - numbers=[8, 1, 1], - positions=np.random.rand(3, 3) * 3, - cell=[5, 5, 5], - pbc=[True] * 3, - ) - atoms.info["REF_energy"] = np.random.normal(0, 1) - atoms.arrays["REF_forces"] = np.random.normal(0, 1, size=(3, 3)) - atoms.info["REF_stress"] = np.random.normal(0, 1, size=6) - configs.append(atoms) - configs.append( - Atoms(numbers=[8], positions=[[0, 0, 0]], cell=[6] * 3, pbc=[True] * 3), - ) - configs.append( - Atoms(numbers=[1], positions=[[0, 0, 0]], cell=[6] * 3, pbc=[True] * 3) - ) - configs[-2].info["REF_energy"] = -2.0 - configs[-2].info["config_type"] = "IsolatedAtom" - configs[-1].info["REF_energy"] = -4.0 - configs[-1].info["config_type"] = "IsolatedAtom" - return configs - - -_mace_params = { - "name": "MACE", - "valid_fraction": 0.05, - "energy_weight": 1.0, - "forces_weight": 10.0, - "stress_weight": 1.0, - "model": "MACE", - "hidden_irreps": "128x0e", - "r_max": 3.5, - "batch_size": 5, - "max_num_epochs": 10, - "swa": None, - "start_swa": 5, - "ema": None, - "ema_decay": 0.99, - "amsgrad": None, - "restart_latest": None, - "device": "cpu", - "seed": 5, - "loss": "stress", - "energy_key": "REF_energy", - "forces_key": "REF_forces", - "stress_key": "REF_stress", - "eval_interval": 2, -} - - -def test_run_train(tmp_path, fitting_configs): - ase.io.write(tmp_path / "fit.xyz", fitting_configs) - - mace_params = _mace_params.copy() - mace_params["checkpoints_dir"] = str(tmp_path) - mace_params["model_dir"] = str(tmp_path) - mace_params["train_file"] = tmp_path / "fit.xyz" - - # make sure run_train.py is using the mace that is currently being tested - run_env = os.environ.copy() - sys.path.insert(0, str(Path(__file__).parent.parent)) - run_env["PYTHONPATH"] = ":".join(sys.path) - print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) - - cmd = ( - sys.executable - + " " - + str(run_train) - + " " - + " ".join( - [ - (f"--{k}={v}" if v is not None else f"--{k}") - for k, v in mace_params.items() - ] - ) - ) - - p = subprocess.run(cmd.split(), env=run_env, check=True) - assert p.returncode == 0 - - calc = MACECalculator(model_paths=tmp_path / "MACE.model", device="cpu") - - Es = [] - for at in fitting_configs: - at.calc = calc - Es.append(at.get_potential_energy()) - - print("Es", Es) - # from a run on 04/06/2024 on stress_bugfix 967f0bfb6490086599da247874b24595d149caa7 - ref_Es = [ - 0.0, - 0.0, - -0.039181344585828524, - -0.0915223395136733, - -0.14953484236456582, - -0.06662480820063998, - -0.09983737353050133, - 0.12477442296789745, - -0.06486086271762856, - -0.1460607988519944, - 0.12886334908465508, - -0.14000990081920373, - -0.05319886578958313, - 0.07780520158391, - -0.08895480281886901, - -0.15474719614734422, - 0.007756765146527644, - -0.044879267197498685, - -0.036065736712447574, - -0.24413743841886623, - -0.0838104612106429, - -0.14751978636626545, - ] - - assert np.allclose(Es, ref_Es) - - -def test_run_train_missing_data(tmp_path, fitting_configs): - del fitting_configs[5].info["REF_energy"] - del fitting_configs[6].arrays["REF_forces"] - del fitting_configs[7].info["REF_stress"] - - ase.io.write(tmp_path / "fit.xyz", fitting_configs) - - mace_params = _mace_params.copy() - mace_params["checkpoints_dir"] = str(tmp_path) - mace_params["model_dir"] = str(tmp_path) - mace_params["train_file"] = tmp_path / "fit.xyz" - - # make sure run_train.py is using the mace that is currently being tested - run_env = os.environ.copy() - sys.path.insert(0, str(Path(__file__).parent.parent)) - run_env["PYTHONPATH"] = ":".join(sys.path) - print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) - - cmd = ( - sys.executable - + " " - + str(run_train) - + " " - + " ".join( - [ - (f"--{k}={v}" if v is not None else f"--{k}") - for k, v in mace_params.items() - ] - ) - ) - - p = subprocess.run(cmd.split(), env=run_env, check=True) - assert p.returncode == 0 - - calc = MACECalculator(model_paths=tmp_path / "MACE.model", device="cpu") - - Es = [] - for at in fitting_configs: - at.calc = calc - Es.append(at.get_potential_energy()) - - print("Es", Es) - # from a run on 04/06/2024 on stress_bugfix 967f0bfb6490086599da247874b24595d149caa7 - ref_Es = [ - 0.0, - 0.0, - -0.05464025113696155, - -0.11272131295940478, - 0.039200919331076826, - -0.07517990972827505, - -0.13504202474582666, - 0.0292022872055344, - -0.06541099574579018, - -0.1497824717832886, - 0.19397709360828813, - -0.13587609467143014, - -0.05242956276828463, - -0.0504862057364953, - -0.07095795959430119, - -0.2463753796753703, - -0.002031543147676121, - -0.03864918790300681, - -0.13680153117705554, - -0.23418951968636786, - -0.11790833839379238, - -0.14930562311066484, - ] - assert np.allclose(Es, ref_Es) - - -def test_run_train_no_stress(tmp_path, fitting_configs): - del fitting_configs[5].info["REF_energy"] - del fitting_configs[6].arrays["REF_forces"] - del fitting_configs[7].info["REF_stress"] - - ase.io.write(tmp_path / "fit.xyz", fitting_configs) - - mace_params = _mace_params.copy() - mace_params["checkpoints_dir"] = str(tmp_path) - mace_params["model_dir"] = str(tmp_path) - mace_params["train_file"] = tmp_path / "fit.xyz" - mace_params["loss"] = "weighted" - - # make sure run_train.py is using the mace that is currently being tested - run_env = os.environ.copy() - sys.path.insert(0, str(Path(__file__).parent.parent)) - run_env["PYTHONPATH"] = ":".join(sys.path) - print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) - - cmd = ( - sys.executable - + " " - + str(run_train) - + " " - + " ".join( - [ - (f"--{k}={v}" if v is not None else f"--{k}") - for k, v in mace_params.items() - ] - ) - ) - - p = subprocess.run(cmd.split(), env=run_env, check=True) - assert p.returncode == 0 - - calc = MACECalculator(model_paths=tmp_path / "MACE.model", device="cpu") - - Es = [] - for at in fitting_configs: - at.calc = calc - Es.append(at.get_potential_energy()) - - print("Es", Es) - # from a run on 28/03/2023 on main 88d49f9ed6925dec07d1777043a36e1fe4872ff3 - ref_Es = [ - 0.0, - 0.0, - -0.05450093218377135, - -0.11235475232750518, - 0.03914558031854152, - -0.07500839914816063, - -0.13469160624431492, - 0.029384214243251838, - -0.06521819204166135, - -0.14944896282001804, - 0.19413948083049481, - -0.13543541860473626, - -0.05235495076237124, - -0.049556206595684105, - -0.07080758913030646, - -0.24571898386301153, - -0.002070636306950905, - -0.03863113401320783, - -0.13620291339913712, - -0.23383074855679695, - -0.11776449630199368, - -0.1489441490225184, - ] - assert np.allclose(Es, ref_Es) - - -def test_run_train_multihead(tmp_path, fitting_configs): - fitting_configs_dft = [] - fitting_configs_mp2 = [] - fitting_configs_ccd = [] - for _, c in enumerate(fitting_configs): - c_dft = c.copy() - c_dft.info["head"] = "DFT" - fitting_configs_dft.append(c_dft) - - c_mp2 = c.copy() - c_mp2.info["head"] = "MP2" - fitting_configs_mp2.append(c_mp2) - - c_ccd = c.copy() - c_ccd.info["head"] = "CCD" - fitting_configs_ccd.append(c_ccd) - ase.io.write(tmp_path / "fit_multihead_dft.xyz", fitting_configs_dft) - ase.io.write(tmp_path / "fit_multihead_mp2.xyz", fitting_configs_mp2) - ase.io.write(tmp_path / "fit_multihead_ccd.xyz", fitting_configs_ccd) - - heads = { - "DFT": {"train_file": f"{str(tmp_path)}/fit_multihead_dft.xyz"}, - "MP2": {"train_file": f"{str(tmp_path)}/fit_multihead_mp2.xyz"}, - "CCD": {"train_file": f"{str(tmp_path)}/fit_multihead_ccd.xyz"}, - } - yaml_str = "heads:\n" - for key, value in heads.items(): - yaml_str += f" {key}:\n" - for sub_key, sub_value in value.items(): - yaml_str += f" {sub_key}: {sub_value}\n" - filename = tmp_path / "config.yaml" - with open(filename, "w", encoding="utf-8") as file: - file.write(yaml_str) - - mace_params = _mace_params.copy() - mace_params["valid_fraction"] = 0.1 - mace_params["checkpoints_dir"] = str(tmp_path) - mace_params["model_dir"] = str(tmp_path) - mace_params["loss"] = "weighted" - mace_params["hidden_irreps"] = "128x0e" - mace_params["r_max"] = 6.0 - mace_params["default_dtype"] = "float64" - mace_params["num_radial_basis"] = 10 - mace_params["interaction_first"] = "RealAgnosticResidualInteractionBlock" - mace_params["config"] = tmp_path / "config.yaml" - mace_params["batch_size"] = 2 - mace_params["num_samples_pt"] = 50 - mace_params["subselect_pt"] = "random" - # make sure run_train.py is using the mace that is currently being tested - run_env = os.environ.copy() - sys.path.insert(0, str(Path(__file__).parent.parent)) - run_env["PYTHONPATH"] = ":".join(sys.path) - print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) - - cmd = ( - sys.executable - + " " - + str(run_train) - + " " - + " ".join( - [ - (f"--{k}={v}" if v is not None else f"--{k}") - for k, v in mace_params.items() - ] - ) - ) - - p = subprocess.run(cmd.split(), env=run_env, check=True) - assert p.returncode == 0 - - calc = MACECalculator( - model_paths=tmp_path / "MACE.model", - device="cpu", - default_dtype="float64", - head="CCD", - ) - - Es = [] - for at in fitting_configs: - at.calc = calc - Es.append(at.get_potential_energy()) - - print("Es", Es) - # from a run on 02/09/2024 on develop branch - ref_Es = [ - 0.0, - 0.0, - 0.10637113905361611, - -0.012499594026624754, - 0.08983077108171753, - 0.21071322543112597, - -0.028921849222784398, - -0.02423359575741567, - 0.022923252188079057, - -0.02048334610058991, - 0.4349711162741364, - -0.04455577015569887, - -0.09765806785570091, - 0.16013134616829822, - 0.0758442928017698, - -0.05931856557011721, - 0.33964473532953265, - 0.134338442158641, - 0.18024119757783053, - -0.18914740992058765, - -0.06503477155294624, - 0.03436649147415213, - ] - assert np.allclose(Es, ref_Es) - - -def test_run_train_foundation(tmp_path, fitting_configs): - ase.io.write(tmp_path / "fit.xyz", fitting_configs) - - mace_params = _mace_params.copy() - mace_params["checkpoints_dir"] = str(tmp_path) - mace_params["model_dir"] = str(tmp_path) - mace_params["train_file"] = tmp_path / "fit.xyz" - mace_params["loss"] = "weighted" - mace_params["foundation_model"] = "small" - mace_params["hidden_irreps"] = "128x0e" - mace_params["r_max"] = 6.0 - mace_params["default_dtype"] = "float64" - mace_params["num_radial_basis"] = 10 - mace_params["interaction_first"] = "RealAgnosticResidualInteractionBlock" - mace_params["multiheads_finetuning"] = False - - run_env = os.environ.copy() - sys.path.insert(0, str(Path(__file__).parent.parent)) - run_env["PYTHONPATH"] = ":".join(sys.path) - - cmd = ( - sys.executable - + " " - + str(run_train) - + " " - + " ".join( - [ - (f"--{k}={v}" if v is not None else f"--{k}") - for k, v in mace_params.items() - ] - ) - ) - - p = subprocess.run(cmd.split(), env=run_env, check=True) - assert p.returncode == 0 - - calc = MACECalculator( - model_paths=tmp_path / "MACE.model", device="cpu", default_dtype="float64" - ) - - Es = [] - for at in fitting_configs: - at.calc = calc - Es.append(at.get_potential_energy()) - - print("Es", Es) - # from a run on 28/03/2023 on repulsion a63434aaab70c84ee016e13e4aca8d57297a0f26 - ref_Es = [ - 1.6780993938446045, - 0.8916864395141602, - 0.7290308475494385, - 0.6194742918014526, - 0.6697757840156555, - 0.7025266289710999, - 0.5818213224411011, - 0.7897703647613525, - 0.6558921337127686, - 0.5071806907653809, - 3.581131935119629, - 0.691562294960022, - 0.6257331967353821, - 0.9560437202453613, - 0.7716934680938721, - 0.6730310916900635, - 0.8297463655471802, - 0.8053972721099854, - 0.8337507247924805, - 0.4107491970062256, - 0.6019601821899414, - 0.7301387786865234, - ] - assert np.allclose(Es, ref_Es) - - -def test_run_train_foundation_multihead(tmp_path, fitting_configs): - fitting_configs_dft = [] - fitting_configs_mp2 = [] - atomic_numbers = np.unique( - np.concatenate([at.numbers for at in fitting_configs]) - ).tolist() - for i, c in enumerate(fitting_configs): - if i in (0, 1): - c_dft = c.copy() - c_dft.info["head"] = "DFT" - fitting_configs_dft.append(c_dft) - fitting_configs_dft.append(c) - c_mp2 = c.copy() - c_mp2.info["head"] = "MP2" - fitting_configs_mp2.append(c_mp2) - elif i % 2 == 0: - c.info["head"] = "DFT" - fitting_configs_dft.append(c) - else: - c.info["head"] = "MP2" - fitting_configs_mp2.append(c) - ase.io.write(tmp_path / "fit_multihead_dft.xyz", fitting_configs_dft) - ase.io.write(tmp_path / "fit_multihead_mp2.xyz", fitting_configs_mp2) - heads = { - "DFT": {"train_file": f"{str(tmp_path)}/fit_multihead_dft.xyz"}, - "MP2": {"train_file": f"{str(tmp_path)}/fit_multihead_mp2.xyz"}, - } - yaml_str = "heads:\n" - for key, value in heads.items(): - yaml_str += f" {key}:\n" - for sub_key, sub_value in value.items(): - yaml_str += f" {sub_key}: {sub_value}\n" - filename = tmp_path / "config.yaml" - with open(filename, "w", encoding="utf-8") as file: - file.write(yaml_str) - mace_params = _mace_params.copy() - mace_params["valid_fraction"] = 0.1 - mace_params["checkpoints_dir"] = str(tmp_path) - mace_params["model_dir"] = str(tmp_path) - mace_params["config"] = tmp_path / "config.yaml" - mace_params["loss"] = "weighted" - mace_params["foundation_model"] = "small" - mace_params["hidden_irreps"] = "128x0e" - mace_params["r_max"] = 6.0 - mace_params["default_dtype"] = "float64" - mace_params["num_radial_basis"] = 10 - mace_params["interaction_first"] = "RealAgnosticResidualInteractionBlock" - mace_params["batch_size"] = 2 - mace_params["valid_batch_size"] = 1 - mace_params["num_samples_pt"] = 50 - mace_params["subselect_pt"] = "random" - mace_params["atomic_numbers"] = "[" + ",".join(map(str, atomic_numbers)) + "]" - mace_params["filter_type_pt"] = "combinations" - mace_params["force_mh_ft_lr"] = True - # make sure run_train.py is using the mace that is currently being tested - run_env = os.environ.copy() - sys.path.insert(0, str(Path(__file__).parent.parent)) - run_env["PYTHONPATH"] = ":".join(sys.path) - print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) - - cmd = ( - sys.executable - + " " - + str(run_train) - + " " - + " ".join( - [ - (f"--{k}={v}" if v is not None else f"--{k}") - for k, v in mace_params.items() - ] - ) - ) - - try: - completed_process = subprocess.run( - cmd.split(), env=run_env, capture_output=True, text=True, check=True - ) - # Process executed successfully - print(completed_process.stdout) - except subprocess.CalledProcessError as e: - # Process failed with non-zero exit code - print(f"Command failed with exit code {e.returncode}") - print(f"STDOUT: {e.stdout}") - print(f"STDERR: {e.stderr}") - raise e - assert completed_process.returncode == 0 - - Es = [] - for at in fitting_configs: - config_head = at.info.get("head", "MP2") - calc = MACECalculator( - model_paths=tmp_path / "MACE.model", - device="cpu", - default_dtype="float64", - head=config_head, - ) - at.calc = calc - Es.append(at.get_potential_energy()) - - print("Es", Es) - # from a run on 20/08/2024 on commit - ref_Es = [ - 1.654685616493225, - 0.44693732261657715, - 0.8741313815116882, - 0.569085955619812, - 0.7161882519721985, - 0.8654778599739075, - 0.8722733855247498, - 0.49582308530807495, - 0.814422607421875, - 0.7027317881584167, - 0.7196993827819824, - 0.517953097820282, - 0.8631765246391296, - 0.4679797887802124, - 0.8163984417915344, - 0.4252359867095947, - 1.0861445665359497, - 0.6829671263694763, - 0.7136879563331604, - 0.5160345435142517, - 0.7002358436584473, - 0.5574042201042175, - ] - assert np.allclose(Es, ref_Es, atol=1e-1) - - -def test_run_train_foundation_multihead_json(tmp_path, fitting_configs): - fitting_configs_dft = [] - fitting_configs_mp2 = [] - atomic_numbers = np.unique( - np.concatenate([at.numbers for at in fitting_configs]) - ).tolist() - for i, c in enumerate(fitting_configs): - - if i in (0, 1): - continue # skip isolated atoms, as energies specified by json files below - if i % 2 == 0: - c.info["head"] = "DFT" - fitting_configs_dft.append(c) - else: - c.info["head"] = "MP2" - fitting_configs_mp2.append(c) - ase.io.write(tmp_path / "fit_multihead_dft.xyz", fitting_configs_dft) - ase.io.write(tmp_path / "fit_multihead_mp2.xyz", fitting_configs_mp2) - - # write E0s to json files - E0s = {1: 0.0, 8: 0.0} - with open(tmp_path / "fit_multihead_dft.json", "w", encoding="utf-8") as f: - json.dump(E0s, f) - with open(tmp_path / "fit_multihead_mp2.json", "w", encoding="utf-8") as f: - json.dump(E0s, f) - - heads = { - "DFT": { - "train_file": f"{str(tmp_path)}/fit_multihead_dft.xyz", - "E0s": f"{str(tmp_path)}/fit_multihead_dft.json", - }, - "MP2": { - "train_file": f"{str(tmp_path)}/fit_multihead_mp2.xyz", - "E0s": f"{str(tmp_path)}/fit_multihead_mp2.json", - }, - } - yaml_str = "heads:\n" - for key, value in heads.items(): - yaml_str += f" {key}:\n" - for sub_key, sub_value in value.items(): - yaml_str += f" {sub_key}: {sub_value}\n" - filename = tmp_path / "config.yaml" - with open(filename, "w", encoding="utf-8") as file: - file.write(yaml_str) - mace_params = _mace_params.copy() - mace_params["valid_fraction"] = 0.1 - mace_params["checkpoints_dir"] = str(tmp_path) - mace_params["model_dir"] = str(tmp_path) - mace_params["config"] = tmp_path / "config.yaml" - mace_params["loss"] = "weighted" - mace_params["foundation_model"] = "small" - mace_params["hidden_irreps"] = "128x0e" - mace_params["r_max"] = 6.0 - mace_params["default_dtype"] = "float64" - mace_params["num_radial_basis"] = 10 - mace_params["interaction_first"] = "RealAgnosticResidualInteractionBlock" - mace_params["batch_size"] = 2 - mace_params["valid_batch_size"] = 1 - mace_params["num_samples_pt"] = 50 - mace_params["subselect_pt"] = "random" - mace_params["atomic_numbers"] = "[" + ",".join(map(str, atomic_numbers)) + "]" - mace_params["filter_type_pt"] = "combinations" - mace_params["force_mh_ft_lr"] = True - # make sure run_train.py is using the mace that is currently being tested - run_env = os.environ.copy() - sys.path.insert(0, str(Path(__file__).parent.parent)) - run_env["PYTHONPATH"] = ":".join(sys.path) - print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) - - cmd = ( - sys.executable - + " " - + str(run_train) - + " " - + " ".join( - [ - (f"--{k}={v}" if v is not None else f"--{k}") - for k, v in mace_params.items() - ] - ) - ) - - try: - completed_process = subprocess.run( - cmd.split(), env=run_env, capture_output=True, text=True, check=True - ) - # Process executed successfully - print(completed_process.stdout) - except subprocess.CalledProcessError as e: - # Process failed with non-zero exit code - print(f"Command failed with exit code {e.returncode}") - print(f"STDOUT: {e.stdout}") - print(f"STDERR: {e.stderr}") - raise e - assert completed_process.returncode == 0 - - Es = [] - for at in fitting_configs: - config_head = at.info.get("head", "MP2") - calc = MACECalculator( - model_paths=tmp_path / "MACE.model", - device="cpu", - default_dtype="float64", - head=config_head, - ) - at.calc = calc - Es.append(at.get_potential_energy()) - - print("Es", Es) - # from a run on 20/08/2024 on commit - ref_Es = [ - 1.654685616493225, - 0.44693732261657715, - 0.8741313815116882, - 0.569085955619812, - 0.7161882519721985, - 0.8654778599739075, - 0.8722733855247498, - 0.49582308530807495, - 0.814422607421875, - 0.7027317881584167, - 0.7196993827819824, - 0.517953097820282, - 0.8631765246391296, - 0.4679797887802124, - 0.8163984417915344, - 0.4252359867095947, - 1.0861445665359497, - 0.6829671263694763, - 0.7136879563331604, - 0.5160345435142517, - 0.7002358436584473, - 0.5574042201042175, - ] - assert np.allclose(Es, ref_Es, atol=1e-1) - - -def test_run_train_multihead_replay_custum_finetuning( - tmp_path, fitting_configs, pretraining_configs -): - ase.io.write(tmp_path / "pretrain.xyz", pretraining_configs) - - foundation_params = { - "name": "foundation", - "train_file": os.path.join(tmp_path, "pretrain.xyz"), - "valid_fraction": 0.2, - "energy_weight": 1.0, - "forces_weight": 10.0, - "stress_weight": 1.0, - "model": "MACE", - "hidden_irreps": "32x0e", - "r_max": 5.0, - "batch_size": 2, - "max_num_epochs": 5, - "swa": None, - "start_swa": 3, - "device": "cpu", - "seed": 42, - "loss": "weighted", - "energy_key": "REF_energy", - "forces_key": "REF_forces", - "stress_key": "REF_stress", - "default_dtype": "float64", - "checkpoints_dir": str(tmp_path), - "model_dir": str(tmp_path), - } - - run_env = os.environ.copy() - sys.path.insert(0, str(Path(__file__).parent.parent)) - run_env["PYTHONPATH"] = ":".join(sys.path) - - cmd = [sys.executable, str(run_train)] - for k, v in foundation_params.items(): - if v is None: - cmd.append(f"--{k}") - else: - cmd.append(f"--{k}={v}") - - p = subprocess.run(cmd, env=run_env, check=True) - assert p.returncode == 0 - - # Step 3: Create finetuning set - fitting_configs_dft = [] - fitting_configs_mp2 = [] - for i, c in enumerate(fitting_configs): - if i in (0, 1): - c_dft = c.copy() - c_dft.info["head"] = "DFT" - fitting_configs_dft.append(c_dft) - fitting_configs_dft.append(c) - c_mp2 = c.copy() - c_mp2.info["head"] = "MP2" - fitting_configs_mp2.append(c_mp2) - elif i % 2 == 0: - c.info["head"] = "DFT" - fitting_configs_dft.append(c) - else: - c.info["head"] = "MP2" - fitting_configs_mp2.append(c) - ase.io.write(tmp_path / "fit_multihead_dft.xyz", fitting_configs_dft) - ase.io.write(tmp_path / "fit_multihead_mp2.xyz", fitting_configs_mp2) - - # Step 4: Finetune the pretrained model with multihead replay - heads = { - "DFT": {"train_file": f"{str(tmp_path)}/fit_multihead_dft.xyz"}, - "MP2": {"train_file": f"{str(tmp_path)}/fit_multihead_mp2.xyz"}, - } - yaml_str = "heads:\n" - for key, value in heads.items(): - yaml_str += f" {key}:\n" - for sub_key, sub_value in value.items(): - yaml_str += f" {sub_key}: {sub_value}\n" - filename = tmp_path / "config.yaml" - with open(filename, "w", encoding="utf-8") as file: - file.write(yaml_str) - - finetuning_params = { - "name": "finetuned", - "valid_fraction": 0.1, - "energy_weight": 1.0, - "forces_weight": 10.0, - "stress_weight": 1.0, - "model": "MACE", - "hidden_irreps": "32x0e", - "r_max": 5.0, - "batch_size": 2, - "max_num_epochs": 5, - "device": "cpu", - "seed": 42, - "loss": "weighted", - "default_dtype": "float64", - "checkpoints_dir": str(tmp_path), - "model_dir": str(tmp_path), - "foundation_model": os.path.join(tmp_path, "foundation.model"), - "config": os.path.join(tmp_path, "config.yaml"), - "pt_train_file": os.path.join(tmp_path, "pretrain.xyz"), - "num_samples_pt": 3, - "subselect_pt": "random", - "force_mh_ft_lr": True, - } - - cmd = [sys.executable, str(run_train)] - for k, v in finetuning_params.items(): - if v is None: - cmd.append(f"--{k}") - else: - cmd.append(f"--{k}={v}") - - p = subprocess.run(cmd, env=run_env, check=True) - assert p.returncode == 0 - - # Load and test the finetuned model - calc = MACECalculator( - model_paths=tmp_path / "finetuned.model", - device="cpu", - default_dtype="float64", - head="pt_head", - ) - - Es = [] - for at in fitting_configs: - at.calc = calc - Es.append(at.get_potential_energy()) - - print("Energies:", Es) - - # Add some basic checks - assert len(Es) == len(fitting_configs) - assert all(isinstance(E, float) for E in Es) - assert len(set(Es)) > 1 # Ens - - -@pytest.mark.skipif(not CUET_AVAILABLE, reason="cuequivariance not installed") -def test_run_train_cueq(tmp_path, fitting_configs): - torch.set_default_dtype(torch.float64) - ase.io.write(tmp_path / "fit.xyz", fitting_configs) - - mace_params = _mace_params.copy() - mace_params["checkpoints_dir"] = str(tmp_path) - mace_params["model_dir"] = str(tmp_path) - mace_params["train_file"] = tmp_path / "fit.xyz" - mace_params["enable_cueq"] = True - mace_params["default_dtype"] = "float64" - - # make sure run_train.py is using the mace that is currently being tested - run_env = os.environ.copy() - sys.path.insert(0, str(Path(__file__).parent.parent)) - run_env["PYTHONPATH"] = ":".join(sys.path) - print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) - - cmd = ( - sys.executable - + " " - + str(run_train) - + " " - + " ".join( - [ - (f"--{k}={v}" if v is not None else f"--{k}") - for k, v in mace_params.items() - ] - ) - ) - - try: - completed_process = subprocess.run( - cmd.split(), env=run_env, capture_output=True, text=True, check=True - ) - # Process executed successfully - print(completed_process.stdout) - except subprocess.CalledProcessError as e: - # Process failed with non-zero exit code - print(f"Command failed with exit code {e.returncode}") - print(f"STDOUT: {e.stdout}") - print(f"STDERR: {e.stderr}") - raise e - assert completed_process.returncode == 0 - - calc = MACECalculator(model_paths=tmp_path / "MACE.model", device="cuda") - Es = [] - for at in fitting_configs[2:]: - at.calc = calc - Es.append(at.get_potential_energy()) - - calc = MACECalculator( - model_paths=tmp_path / "MACE.model", device="cpu", enable_cueq=True - ) - Es_cueq = [] - for at in fitting_configs[2:]: - at.calc = calc - Es_cueq.append(at.get_potential_energy()) - - # from a run on 04/06/2024 on stress_bugfix 967f0bfb6490086599da247874b24595d149caa7 - ref_Es = [ - -0.039181344585828524, - -0.0915223395136733, - -0.14953484236456582, - -0.06662480820063998, - -0.09983737353050133, - 0.12477442296789745, - -0.06486086271762856, - -0.1460607988519944, - 0.12886334908465508, - -0.14000990081920373, - -0.05319886578958313, - 0.07780520158391, - -0.08895480281886901, - -0.15474719614734422, - 0.007756765146527644, - -0.044879267197498685, - -0.036065736712447574, - -0.24413743841886623, - -0.0838104612106429, - -0.14751978636626545, - ] - - assert np.allclose(Es, ref_Es) - assert np.allclose(ref_Es, Es_cueq) - - -@pytest.mark.skipif(not CUET_AVAILABLE, reason="cuequivariance not installed") -def test_run_train_foundation_multihead_json_cueq(tmp_path, fitting_configs): - fitting_configs_dft = [] - fitting_configs_mp2 = [] - atomic_numbers = np.unique( - np.concatenate([at.numbers for at in fitting_configs]) - ).tolist() - for i, c in enumerate(fitting_configs): - - if i in (0, 1): - continue # skip isolated atoms, as energies specified by json files below - if i % 2 == 0: - c.info["head"] = "DFT" - fitting_configs_dft.append(c) - else: - c.info["head"] = "MP2" - fitting_configs_mp2.append(c) - ase.io.write(tmp_path / "fit_multihead_dft.xyz", fitting_configs_dft) - ase.io.write(tmp_path / "fit_multihead_mp2.xyz", fitting_configs_mp2) - - # write E0s to json files - E0s = {1: 0.0, 8: 0.0} - with open(tmp_path / "fit_multihead_dft.json", "w", encoding="utf-8") as f: - json.dump(E0s, f) - with open(tmp_path / "fit_multihead_mp2.json", "w", encoding="utf-8") as f: - json.dump(E0s, f) - - heads = { - "DFT": { - "train_file": f"{str(tmp_path)}/fit_multihead_dft.xyz", - "E0s": f"{str(tmp_path)}/fit_multihead_dft.json", - }, - "MP2": { - "train_file": f"{str(tmp_path)}/fit_multihead_mp2.xyz", - "E0s": f"{str(tmp_path)}/fit_multihead_mp2.json", - }, - } - yaml_str = "heads:\n" - for key, value in heads.items(): - yaml_str += f" {key}:\n" - for sub_key, sub_value in value.items(): - yaml_str += f" {sub_key}: {sub_value}\n" - filename = tmp_path / "config.yaml" - with open(filename, "w", encoding="utf-8") as file: - file.write(yaml_str) - mace_params = _mace_params.copy() - mace_params["valid_fraction"] = 0.1 - mace_params["checkpoints_dir"] = str(tmp_path) - mace_params["model_dir"] = str(tmp_path) - mace_params["config"] = tmp_path / "config.yaml" - mace_params["loss"] = "weighted" - mace_params["foundation_model"] = "small" - mace_params["hidden_irreps"] = "128x0e" - mace_params["r_max"] = 6.0 - mace_params["default_dtype"] = "float64" - mace_params["num_radial_basis"] = 10 - mace_params["interaction_first"] = "RealAgnosticResidualInteractionBlock" - mace_params["batch_size"] = 2 - mace_params["valid_batch_size"] = 1 - mace_params["num_samples_pt"] = 50 - mace_params["subselect_pt"] = "random" - mace_params["enable_cueq"] = True - mace_params["atomic_numbers"] = "[" + ",".join(map(str, atomic_numbers)) + "]" - mace_params["filter_type_pt"] = "combinations" - mace_params["device"] = "cuda" - mace_params["force_mh_ft_lr"] = True - # make sure run_train.py is using the mace that is currently being tested - run_env = os.environ.copy() - sys.path.insert(0, str(Path(__file__).parent.parent)) - run_env["PYTHONPATH"] = ":".join(sys.path) - print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) - - cmd = ( - sys.executable - + " " - + str(run_train) - + " " - + " ".join( - [ - (f"--{k}={v}" if v is not None else f"--{k}") - for k, v in mace_params.items() - ] - ) - ) - - try: - completed_process = subprocess.run( - cmd.split(), env=run_env, capture_output=True, text=True, check=True - ) - # Process executed successfully - print(completed_process.stdout) - except subprocess.CalledProcessError as e: - # Process failed with non-zero exit code - print(f"Command failed with exit code {e.returncode}") - print(f"STDOUT: {e.stdout}") - print(f"STDERR: {e.stderr}") - raise e - assert completed_process.returncode == 0 - - calc = MACECalculator( - model_paths=tmp_path / "MACE.model", - device="cuda", - default_dtype="float64", - head="DFT", - ) - - Es = [] - for at in fitting_configs: - at.calc = calc - Es.append(at.get_potential_energy()) - - print("Es", Es) - # from a run on 20/08/2024 on commit - ref_Es = [ - 1.654685616493225, - 0.44693732261657715, - 0.8741313815116882, - 0.569085955619812, - 0.7161882519721985, - 0.8654778599739075, - 0.8722733855247498, - 0.49582308530807495, - 0.814422607421875, - 0.7027317881584167, - 0.7196993827819824, - 0.517953097820282, - 0.8631765246391296, - 0.4679797887802124, - 0.8163984417915344, - 0.4252359867095947, - 1.0861445665359497, - 0.6829671263694763, - 0.7136879563331604, - 0.5160345435142517, - 0.7002358436584473, - 0.5574042201042175, - ] - assert np.allclose(Es, ref_Es, atol=1e-1) - - -def test_run_train_lbfgs(tmp_path, fitting_configs): - ase.io.write(tmp_path / "fit.xyz", fitting_configs) - - mace_params = _mace_params.copy() - mace_params["checkpoints_dir"] = str(tmp_path) - mace_params["model_dir"] = str(tmp_path) - mace_params["train_file"] = tmp_path / "fit.xyz" - mace_params["lbfgs"] = None - mace_params["max_num_epochs"] = 2 - - # make sure run_train.py is using the mace that is currently being tested - run_env = os.environ.copy() - sys.path.insert(0, str(Path(__file__).parent.parent)) - run_env["PYTHONPATH"] = ":".join(sys.path) - print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) - - cmd = ( - sys.executable - + " " - + str(run_train) - + " " - + " ".join( - [ - (f"--{k}={v}" if v is not None else f"--{k}") - for k, v in mace_params.items() - ] - ) - ) - - p = subprocess.run(cmd.split(), env=run_env, check=True) - assert p.returncode == 0 - - calc = MACECalculator(model_paths=tmp_path / "MACE.model", device="cpu") - - Es = [] - for at in fitting_configs: - at.calc = calc - Es.append(at.get_potential_energy()) - - print("Es", Es) - # from a run on 14/03/2025 - ref_Es = [ - 0.0, - 0.0, - -0.1874197850340979, - -0.25991775038059006, - 0.18263492399322268, - -0.15026829765490662, - -0.2403061362015996, - 0.1689257170630718, - -0.2095568077455055, - -0.2957758160829075, - -0.0035370913684985364, - -0.2195416610745775, - -0.25405549447739517, - -0.06201390990366806, - -0.13332219494388334, - -0.19633181702040337, - 0.013014932630445699, - -0.08808335967147174, - -0.06664444189210728, - -0.4230467426992034, - -0.2348250569553676, - -0.17593904833220647, - ] - assert np.allclose(Es, ref_Es, atol=1e-2) - - -def test_run_train_foundation_elements(tmp_path, fitting_configs): - - ase.io.write(tmp_path / "fit.xyz", fitting_configs) - - base_params = { - "name": "MACE", - "checkpoints_dir": str(tmp_path), - "model_dir": str(tmp_path), - "train_file": tmp_path / "fit.xyz", - "loss": "weighted", - "foundation_model": "small", - "hidden_irreps": "128x0e", - "r_max": 6.0, - "default_dtype": "float64", - "max_num_epochs": 5, - "num_radial_basis": 10, - "interaction_first": "RealAgnosticResidualInteractionBlock", - "multiheads_finetuning": False, - } - - # Run environment setup - run_env = os.environ.copy() - sys.path.insert(0, str(Path(__file__).parent.parent)) - run_env["PYTHONPATH"] = ":".join(sys.path) - - # First run: without foundation_model_elements (default behavior) - mace_params = base_params.copy() - cmd = ( - sys.executable - + " " - + str(run_train) - + " " - + " ".join( - [ - (f"--{k}={v}" if v is not None else f"--{k}") - for k, v in mace_params.items() - ] - ) - ) - p = subprocess.run(cmd.split(), env=run_env, check=True) - assert p.returncode == 0 - - # Load model and check elements - model_filtered = torch.load(tmp_path / "MACE.model", map_location="cpu") - filtered_elements = set(int(z) for z in model_filtered.atomic_numbers) - assert filtered_elements == {1, 8} # Only H and O should be present - - # Second run: with foundation_model_elements - mace_params = base_params.copy() - mace_params["name"] = "MACE_all_elements" - mace_params["foundation_model_elements"] = True # Flag-only argument - cmd = ( - sys.executable - + " " - + str(run_train) - + " " - + " ".join( - [ - (f"--{k}={v}" if v is not None else f"--{k}") - for k, v in mace_params.items() - ] - ) - ) - p = subprocess.run(cmd.split(), env=run_env, check=True) - assert p.returncode == 0 - - # Load model and check elements - model_all = torch.load(tmp_path / "MACE_all_elements.model", map_location="cpu") - all_elements = set(int(z) for z in model_all.atomic_numbers) - - # Get elements from foundation model for comparison - calc = mace_mp(model="small", device="cpu") - foundation_elements = set(int(z) for z in calc.models[0].atomic_numbers) - - # Check that all foundation model elements are preserved - assert all_elements == foundation_elements - assert len(all_elements) > len(filtered_elements) - - # Check that both models can make predictions - at = fitting_configs[2].copy() - - # Test filtered model - calc_filtered = MACECalculator( - model_paths=tmp_path / "MACE.model", device="cpu", default_dtype="float64" - ) - at.calc = calc_filtered - e1 = at.get_potential_energy() - - # Test all-elements model - calc_all = MACECalculator( - model_paths=tmp_path / "MACE_all_elements.model", - device="cpu", - default_dtype="float64", - ) - at.calc = calc_all - e2 = at.get_potential_energy() - - # Energies should be different since the models are trained differently, - # but both should give reasonable results - assert np.isfinite(e1) - assert np.isfinite(e2) - - -def test_run_train_foundation_elements_multihead(tmp_path, fitting_configs): - fitting_configs_dft = [] - fitting_configs_mp2 = [] - atomic_numbers = np.unique( - np.concatenate([at.numbers for at in fitting_configs]) - ).tolist() - for i, c in enumerate(fitting_configs): - if i in (0, 1): - c_dft = c.copy() - c_dft.info["head"] = "DFT" - fitting_configs_dft.append(c_dft) - c_mp2 = c.copy() - c_mp2.info["head"] = "MP2" - fitting_configs_mp2.append(c_mp2) - if i % 2 == 0: - c_copy = c.copy() - c_copy.info["head"] = "DFT" - fitting_configs_dft.append(c_copy) - else: - c_copy = c.copy() - c_copy.info["head"] = "MP2" - fitting_configs_mp2.append(c_copy) - - ase.io.write(tmp_path / "fit_dft.xyz", fitting_configs_dft) - ase.io.write(tmp_path / "fit_mp2.xyz", fitting_configs_mp2) - - # Create multihead configuration - heads = { - "DFT": {"train_file": f"{str(tmp_path)}/fit_dft.xyz"}, - "MP2": {"train_file": f"{str(tmp_path)}/fit_mp2.xyz"}, - } - yaml_str = "heads:\n" - for key, value in heads.items(): - yaml_str += f" {key}:\n" - for sub_key, sub_value in value.items(): - yaml_str += f" {sub_key}: {sub_value}\n" - config_file = tmp_path / "config.yaml" - with open(config_file, "w", encoding="utf-8") as file: - file.write(yaml_str) - - base_params = { - "name": "MACE", - "checkpoints_dir": str(tmp_path), - "model_dir": str(tmp_path), - "config": str(config_file), - "loss": "weighted", - "foundation_model": "small", - "hidden_irreps": "128x0e", - "r_max": 6.0, - "default_dtype": "float64", - "max_num_epochs": 5, - "num_radial_basis": 10, - "interaction_first": "RealAgnosticResidualInteractionBlock", - "force_mh_ft_lr": True, - "batch_size": 1, - "num_samples_pt": 50, - "subselect_pt": "random", - "atomic_numbers": "[" + ",".join(map(str, atomic_numbers)) + "]", - "filter_type_pt": "combinations", - "valid_fraction": 0.1, - "valid_batch_size": 1, - } - - # Run environment setup - run_env = os.environ.copy() - sys.path.insert(0, str(Path(__file__).parent.parent)) - run_env["PYTHONPATH"] = ":".join(sys.path) - - # First run: without foundation_model_elements (default behavior) - mace_params = base_params.copy() - cmd = ( - sys.executable - + " " - + str(run_train) - + " " - + " ".join( - [ - (f"--{k}={v}" if v is not None else f"--{k}") - for k, v in mace_params.items() - ] - ) - ) - try: - completed_process = subprocess.run( - cmd.split(), env=run_env, capture_output=True, text=True, check=True - ) - # Process executed successfully - print(completed_process.stdout) - except subprocess.CalledProcessError as e: - # Process failed with non-zero exit code - print(f"Command failed with exit code {e.returncode}") - print(f"STDOUT: {e.stdout}") - print(f"STDERR: {e.stderr}") - raise e - assert completed_process.returncode == 0 - - # Load model and check elements - model_filtered = torch.load(tmp_path / "MACE.model", map_location="cpu") - filtered_elements = set(int(z) for z in model_filtered.atomic_numbers) - assert filtered_elements == {1, 8} # Only H and O should be present - assert len(model_filtered.heads) == 3 # pt_head + DFT + MP2 - - # Second run: with foundation_model_elements - mace_params = base_params.copy() - mace_params["name"] = "MACE_all_elements" - mace_params["foundation_model_elements"] = True - cmd = ( - sys.executable - + " " - + str(run_train) - + " " - + " ".join( - [ - (f"--{k}={v}" if v is not None else f"--{k}") - for k, v in mace_params.items() - ] - ) - ) - p = subprocess.run(cmd.split(), env=run_env, check=True) - assert p.returncode == 0 - - # Load model and check elements - model_all = torch.load(tmp_path / "MACE_all_elements.model", map_location="cpu") - all_elements = set(int(z) for z in model_all.atomic_numbers) - - # Get elements from foundation model for comparison - calc = mace_mp(model="small", device="cpu") - foundation_elements = set(int(z) for z in calc.models[0].atomic_numbers) - - # Check that all foundation model elements are preserved - assert all_elements == foundation_elements - assert len(all_elements) > len(filtered_elements) - assert len(model_all.heads) == 3 # pt_head + DFT + MP2 - - # Check that both models can make predictions - at = fitting_configs_dft[2].copy() - - # Test filtered model - calc_filtered = MACECalculator( - model_paths=tmp_path / "MACE.model", - device="cpu", - default_dtype="float64", - head="DFT", - ) - at.calc = calc_filtered - e1 = at.get_potential_energy() - - # Test all-elements model - calc_all = MACECalculator( - model_paths=tmp_path / "MACE_all_elements.model", - device="cpu", - default_dtype="float64", - head="DFT", - ) - at.calc = calc_all - e2 = at.get_potential_energy() - - assert np.isfinite(e1) - assert np.isfinite(e2) +import json +import os +import subprocess +import sys +from pathlib import Path + +import ase.io +import numpy as np +import pytest +import torch +from ase.atoms import Atoms + +from mace.calculators import MACECalculator, mace_mp + +try: + import cuequivariance as cue # pylint: disable=unused-import + + CUET_AVAILABLE = True +except ImportError: + CUET_AVAILABLE = False + +run_train = Path(__file__).parent.parent / "mace" / "cli" / "run_train.py" + + +@pytest.fixture(name="fitting_configs") +def fixture_fitting_configs(): + water = Atoms( + numbers=[8, 1, 1], + positions=[[0, -2.0, 0], [1, 0, 0], [0, 1, 0]], + cell=[4] * 3, + pbc=[True] * 3, + ) + fit_configs = [ + Atoms(numbers=[8], positions=[[0, 0, 0]], cell=[6] * 3), + Atoms(numbers=[1], positions=[[0, 0, 0]], cell=[6] * 3), + ] + fit_configs[0].info["REF_energy"] = 0.0 + fit_configs[0].info["config_type"] = "IsolatedAtom" + fit_configs[1].info["REF_energy"] = 0.0 + fit_configs[1].info["config_type"] = "IsolatedAtom" + + np.random.seed(5) + for _ in range(20): + c = water.copy() + c.positions += np.random.normal(0.1, size=c.positions.shape) + c.info["REF_energy"] = np.random.normal(0.1) + print(c.info["REF_energy"]) + c.new_array("REF_forces", np.random.normal(0.1, size=c.positions.shape)) + c.info["REF_stress"] = np.random.normal(0.1, size=6) + fit_configs.append(c) + + return fit_configs + + +@pytest.fixture(name="pretraining_configs") +def fixture_pretraining_configs(): + configs = [] + for _ in range(10): + atoms = Atoms( + numbers=[8, 1, 1], + positions=np.random.rand(3, 3) * 3, + cell=[5, 5, 5], + pbc=[True] * 3, + ) + atoms.info["REF_energy"] = np.random.normal(0, 1) + atoms.arrays["REF_forces"] = np.random.normal(0, 1, size=(3, 3)) + atoms.info["REF_stress"] = np.random.normal(0, 1, size=6) + configs.append(atoms) + configs.append( + Atoms(numbers=[8], positions=[[0, 0, 0]], cell=[6] * 3, pbc=[True] * 3), + ) + configs.append( + Atoms(numbers=[1], positions=[[0, 0, 0]], cell=[6] * 3, pbc=[True] * 3) + ) + configs[-2].info["REF_energy"] = -2.0 + configs[-2].info["config_type"] = "IsolatedAtom" + configs[-1].info["REF_energy"] = -4.0 + configs[-1].info["config_type"] = "IsolatedAtom" + return configs + + +_mace_params = { + "name": "MACE", + "valid_fraction": 0.05, + "energy_weight": 1.0, + "forces_weight": 10.0, + "stress_weight": 1.0, + "model": "MACE", + "hidden_irreps": "128x0e", + "r_max": 3.5, + "batch_size": 5, + "max_num_epochs": 10, + "swa": None, + "start_swa": 5, + "ema": None, + "ema_decay": 0.99, + "amsgrad": None, + "restart_latest": None, + "device": "cpu", + "seed": 5, + "loss": "stress", + "energy_key": "REF_energy", + "forces_key": "REF_forces", + "stress_key": "REF_stress", + "eval_interval": 2, +} + + +def test_run_train(tmp_path, fitting_configs): + ase.io.write(tmp_path / "fit.xyz", fitting_configs) + + mace_params = _mace_params.copy() + mace_params["checkpoints_dir"] = str(tmp_path) + mace_params["model_dir"] = str(tmp_path) + mace_params["train_file"] = tmp_path / "fit.xyz" + + # make sure run_train.py is using the mace that is currently being tested + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) + + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + + p = subprocess.run(cmd.split(), env=run_env, check=True) + assert p.returncode == 0 + + calc = MACECalculator(model_paths=tmp_path / "MACE.model", device="cpu") + + Es = [] + for at in fitting_configs: + at.calc = calc + Es.append(at.get_potential_energy()) + + print("Es", Es) + # from a run on 04/06/2024 on stress_bugfix 967f0bfb6490086599da247874b24595d149caa7 + ref_Es = [ + 0.0, + 0.0, + -0.039181344585828524, + -0.0915223395136733, + -0.14953484236456582, + -0.06662480820063998, + -0.09983737353050133, + 0.12477442296789745, + -0.06486086271762856, + -0.1460607988519944, + 0.12886334908465508, + -0.14000990081920373, + -0.05319886578958313, + 0.07780520158391, + -0.08895480281886901, + -0.15474719614734422, + 0.007756765146527644, + -0.044879267197498685, + -0.036065736712447574, + -0.24413743841886623, + -0.0838104612106429, + -0.14751978636626545, + ] + + assert np.allclose(Es, ref_Es) + + +def test_run_train_missing_data(tmp_path, fitting_configs): + del fitting_configs[5].info["REF_energy"] + del fitting_configs[6].arrays["REF_forces"] + del fitting_configs[7].info["REF_stress"] + + ase.io.write(tmp_path / "fit.xyz", fitting_configs) + + mace_params = _mace_params.copy() + mace_params["checkpoints_dir"] = str(tmp_path) + mace_params["model_dir"] = str(tmp_path) + mace_params["train_file"] = tmp_path / "fit.xyz" + + # make sure run_train.py is using the mace that is currently being tested + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) + + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + + p = subprocess.run(cmd.split(), env=run_env, check=True) + assert p.returncode == 0 + + calc = MACECalculator(model_paths=tmp_path / "MACE.model", device="cpu") + + Es = [] + for at in fitting_configs: + at.calc = calc + Es.append(at.get_potential_energy()) + + print("Es", Es) + # from a run on 04/06/2024 on stress_bugfix 967f0bfb6490086599da247874b24595d149caa7 + ref_Es = [ + 0.0, + 0.0, + -0.05464025113696155, + -0.11272131295940478, + 0.039200919331076826, + -0.07517990972827505, + -0.13504202474582666, + 0.0292022872055344, + -0.06541099574579018, + -0.1497824717832886, + 0.19397709360828813, + -0.13587609467143014, + -0.05242956276828463, + -0.0504862057364953, + -0.07095795959430119, + -0.2463753796753703, + -0.002031543147676121, + -0.03864918790300681, + -0.13680153117705554, + -0.23418951968636786, + -0.11790833839379238, + -0.14930562311066484, + ] + assert np.allclose(Es, ref_Es) + + +def test_run_train_no_stress(tmp_path, fitting_configs): + del fitting_configs[5].info["REF_energy"] + del fitting_configs[6].arrays["REF_forces"] + del fitting_configs[7].info["REF_stress"] + + ase.io.write(tmp_path / "fit.xyz", fitting_configs) + + mace_params = _mace_params.copy() + mace_params["checkpoints_dir"] = str(tmp_path) + mace_params["model_dir"] = str(tmp_path) + mace_params["train_file"] = tmp_path / "fit.xyz" + mace_params["loss"] = "weighted" + + # make sure run_train.py is using the mace that is currently being tested + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) + + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + + p = subprocess.run(cmd.split(), env=run_env, check=True) + assert p.returncode == 0 + + calc = MACECalculator(model_paths=tmp_path / "MACE.model", device="cpu") + + Es = [] + for at in fitting_configs: + at.calc = calc + Es.append(at.get_potential_energy()) + + print("Es", Es) + # from a run on 28/03/2023 on main 88d49f9ed6925dec07d1777043a36e1fe4872ff3 + ref_Es = [ + 0.0, + 0.0, + -0.05450093218377135, + -0.11235475232750518, + 0.03914558031854152, + -0.07500839914816063, + -0.13469160624431492, + 0.029384214243251838, + -0.06521819204166135, + -0.14944896282001804, + 0.19413948083049481, + -0.13543541860473626, + -0.05235495076237124, + -0.049556206595684105, + -0.07080758913030646, + -0.24571898386301153, + -0.002070636306950905, + -0.03863113401320783, + -0.13620291339913712, + -0.23383074855679695, + -0.11776449630199368, + -0.1489441490225184, + ] + assert np.allclose(Es, ref_Es) + + +def test_run_train_multihead(tmp_path, fitting_configs): + fitting_configs_dft = [] + fitting_configs_mp2 = [] + fitting_configs_ccd = [] + for _, c in enumerate(fitting_configs): + c_dft = c.copy() + c_dft.info["head"] = "DFT" + fitting_configs_dft.append(c_dft) + + c_mp2 = c.copy() + c_mp2.info["head"] = "MP2" + fitting_configs_mp2.append(c_mp2) + + c_ccd = c.copy() + c_ccd.info["head"] = "CCD" + fitting_configs_ccd.append(c_ccd) + ase.io.write(tmp_path / "fit_multihead_dft.xyz", fitting_configs_dft) + ase.io.write(tmp_path / "fit_multihead_mp2.xyz", fitting_configs_mp2) + ase.io.write(tmp_path / "fit_multihead_ccd.xyz", fitting_configs_ccd) + + heads = { + "DFT": {"train_file": f"{str(tmp_path)}/fit_multihead_dft.xyz"}, + "MP2": {"train_file": f"{str(tmp_path)}/fit_multihead_mp2.xyz"}, + "CCD": {"train_file": f"{str(tmp_path)}/fit_multihead_ccd.xyz"}, + } + yaml_str = "heads:\n" + for key, value in heads.items(): + yaml_str += f" {key}:\n" + for sub_key, sub_value in value.items(): + yaml_str += f" {sub_key}: {sub_value}\n" + filename = tmp_path / "config.yaml" + with open(filename, "w", encoding="utf-8") as file: + file.write(yaml_str) + + mace_params = _mace_params.copy() + mace_params["valid_fraction"] = 0.1 + mace_params["checkpoints_dir"] = str(tmp_path) + mace_params["model_dir"] = str(tmp_path) + mace_params["loss"] = "weighted" + mace_params["hidden_irreps"] = "128x0e" + mace_params["r_max"] = 6.0 + mace_params["default_dtype"] = "float64" + mace_params["num_radial_basis"] = 10 + mace_params["interaction_first"] = "RealAgnosticResidualInteractionBlock" + mace_params["config"] = tmp_path / "config.yaml" + mace_params["batch_size"] = 2 + mace_params["num_samples_pt"] = 50 + mace_params["subselect_pt"] = "random" + # make sure run_train.py is using the mace that is currently being tested + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) + + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + + p = subprocess.run(cmd.split(), env=run_env, check=True) + assert p.returncode == 0 + + calc = MACECalculator( + model_paths=tmp_path / "MACE.model", + device="cpu", + default_dtype="float64", + head="CCD", + ) + + Es = [] + for at in fitting_configs: + at.calc = calc + Es.append(at.get_potential_energy()) + + print("Es", Es) + # from a run on 02/09/2024 on develop branch + ref_Es = [ + 0.0, + 0.0, + 0.10637113905361611, + -0.012499594026624754, + 0.08983077108171753, + 0.21071322543112597, + -0.028921849222784398, + -0.02423359575741567, + 0.022923252188079057, + -0.02048334610058991, + 0.4349711162741364, + -0.04455577015569887, + -0.09765806785570091, + 0.16013134616829822, + 0.0758442928017698, + -0.05931856557011721, + 0.33964473532953265, + 0.134338442158641, + 0.18024119757783053, + -0.18914740992058765, + -0.06503477155294624, + 0.03436649147415213, + ] + assert np.allclose(Es, ref_Es) + + +def test_run_train_foundation(tmp_path, fitting_configs): + ase.io.write(tmp_path / "fit.xyz", fitting_configs) + + mace_params = _mace_params.copy() + mace_params["checkpoints_dir"] = str(tmp_path) + mace_params["model_dir"] = str(tmp_path) + mace_params["train_file"] = tmp_path / "fit.xyz" + mace_params["loss"] = "weighted" + mace_params["foundation_model"] = "small" + mace_params["hidden_irreps"] = "128x0e" + mace_params["r_max"] = 6.0 + mace_params["default_dtype"] = "float64" + mace_params["num_radial_basis"] = 10 + mace_params["interaction_first"] = "RealAgnosticResidualInteractionBlock" + mace_params["multiheads_finetuning"] = False + + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + + p = subprocess.run(cmd.split(), env=run_env, check=True) + assert p.returncode == 0 + + calc = MACECalculator( + model_paths=tmp_path / "MACE.model", device="cpu", default_dtype="float64" + ) + + Es = [] + for at in fitting_configs: + at.calc = calc + Es.append(at.get_potential_energy()) + + print("Es", Es) + # from a run on 28/03/2023 on repulsion a63434aaab70c84ee016e13e4aca8d57297a0f26 + ref_Es = [ + 1.6780993938446045, + 0.8916864395141602, + 0.7290308475494385, + 0.6194742918014526, + 0.6697757840156555, + 0.7025266289710999, + 0.5818213224411011, + 0.7897703647613525, + 0.6558921337127686, + 0.5071806907653809, + 3.581131935119629, + 0.691562294960022, + 0.6257331967353821, + 0.9560437202453613, + 0.7716934680938721, + 0.6730310916900635, + 0.8297463655471802, + 0.8053972721099854, + 0.8337507247924805, + 0.4107491970062256, + 0.6019601821899414, + 0.7301387786865234, + ] + assert np.allclose(Es, ref_Es) + + +def test_run_train_foundation_multihead(tmp_path, fitting_configs): + fitting_configs_dft = [] + fitting_configs_mp2 = [] + atomic_numbers = np.unique( + np.concatenate([at.numbers for at in fitting_configs]) + ).tolist() + for i, c in enumerate(fitting_configs): + if i in (0, 1): + c_dft = c.copy() + c_dft.info["head"] = "DFT" + fitting_configs_dft.append(c_dft) + fitting_configs_dft.append(c) + c_mp2 = c.copy() + c_mp2.info["head"] = "MP2" + fitting_configs_mp2.append(c_mp2) + elif i % 2 == 0: + c.info["head"] = "DFT" + fitting_configs_dft.append(c) + else: + c.info["head"] = "MP2" + fitting_configs_mp2.append(c) + ase.io.write(tmp_path / "fit_multihead_dft.xyz", fitting_configs_dft) + ase.io.write(tmp_path / "fit_multihead_mp2.xyz", fitting_configs_mp2) + heads = { + "DFT": {"train_file": f"{str(tmp_path)}/fit_multihead_dft.xyz"}, + "MP2": {"train_file": f"{str(tmp_path)}/fit_multihead_mp2.xyz"}, + } + yaml_str = "heads:\n" + for key, value in heads.items(): + yaml_str += f" {key}:\n" + for sub_key, sub_value in value.items(): + yaml_str += f" {sub_key}: {sub_value}\n" + filename = tmp_path / "config.yaml" + with open(filename, "w", encoding="utf-8") as file: + file.write(yaml_str) + mace_params = _mace_params.copy() + mace_params["valid_fraction"] = 0.1 + mace_params["checkpoints_dir"] = str(tmp_path) + mace_params["model_dir"] = str(tmp_path) + mace_params["config"] = tmp_path / "config.yaml" + mace_params["loss"] = "weighted" + mace_params["foundation_model"] = "small" + mace_params["hidden_irreps"] = "128x0e" + mace_params["r_max"] = 6.0 + mace_params["default_dtype"] = "float64" + mace_params["num_radial_basis"] = 10 + mace_params["interaction_first"] = "RealAgnosticResidualInteractionBlock" + mace_params["batch_size"] = 2 + mace_params["valid_batch_size"] = 1 + mace_params["num_samples_pt"] = 50 + mace_params["subselect_pt"] = "random" + mace_params["atomic_numbers"] = "[" + ",".join(map(str, atomic_numbers)) + "]" + mace_params["filter_type_pt"] = "combinations" + mace_params["force_mh_ft_lr"] = True + # make sure run_train.py is using the mace that is currently being tested + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) + + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + + try: + completed_process = subprocess.run( + cmd.split(), env=run_env, capture_output=True, text=True, check=True + ) + # Process executed successfully + print(completed_process.stdout) + except subprocess.CalledProcessError as e: + # Process failed with non-zero exit code + print(f"Command failed with exit code {e.returncode}") + print(f"STDOUT: {e.stdout}") + print(f"STDERR: {e.stderr}") + raise e + assert completed_process.returncode == 0 + + Es = [] + for at in fitting_configs: + config_head = at.info.get("head", "MP2") + calc = MACECalculator( + model_paths=tmp_path / "MACE.model", + device="cpu", + default_dtype="float64", + head=config_head, + ) + at.calc = calc + Es.append(at.get_potential_energy()) + + print("Es", Es) + # from a run on 20/08/2024 on commit + ref_Es = [ + 1.654685616493225, + 0.44693732261657715, + 0.8741313815116882, + 0.569085955619812, + 0.7161882519721985, + 0.8654778599739075, + 0.8722733855247498, + 0.49582308530807495, + 0.814422607421875, + 0.7027317881584167, + 0.7196993827819824, + 0.517953097820282, + 0.8631765246391296, + 0.4679797887802124, + 0.8163984417915344, + 0.4252359867095947, + 1.0861445665359497, + 0.6829671263694763, + 0.7136879563331604, + 0.5160345435142517, + 0.7002358436584473, + 0.5574042201042175, + ] + assert np.allclose(Es, ref_Es, atol=1e-1) + + +def test_run_train_foundation_multihead_json(tmp_path, fitting_configs): + fitting_configs_dft = [] + fitting_configs_mp2 = [] + atomic_numbers = np.unique( + np.concatenate([at.numbers for at in fitting_configs]) + ).tolist() + for i, c in enumerate(fitting_configs): + + if i in (0, 1): + continue # skip isolated atoms, as energies specified by json files below + if i % 2 == 0: + c.info["head"] = "DFT" + fitting_configs_dft.append(c) + else: + c.info["head"] = "MP2" + fitting_configs_mp2.append(c) + ase.io.write(tmp_path / "fit_multihead_dft.xyz", fitting_configs_dft) + ase.io.write(tmp_path / "fit_multihead_mp2.xyz", fitting_configs_mp2) + + # write E0s to json files + E0s = {1: 0.0, 8: 0.0} + with open(tmp_path / "fit_multihead_dft.json", "w", encoding="utf-8") as f: + json.dump(E0s, f) + with open(tmp_path / "fit_multihead_mp2.json", "w", encoding="utf-8") as f: + json.dump(E0s, f) + + heads = { + "DFT": { + "train_file": f"{str(tmp_path)}/fit_multihead_dft.xyz", + "E0s": f"{str(tmp_path)}/fit_multihead_dft.json", + }, + "MP2": { + "train_file": f"{str(tmp_path)}/fit_multihead_mp2.xyz", + "E0s": f"{str(tmp_path)}/fit_multihead_mp2.json", + }, + } + yaml_str = "heads:\n" + for key, value in heads.items(): + yaml_str += f" {key}:\n" + for sub_key, sub_value in value.items(): + yaml_str += f" {sub_key}: {sub_value}\n" + filename = tmp_path / "config.yaml" + with open(filename, "w", encoding="utf-8") as file: + file.write(yaml_str) + mace_params = _mace_params.copy() + mace_params["valid_fraction"] = 0.1 + mace_params["checkpoints_dir"] = str(tmp_path) + mace_params["model_dir"] = str(tmp_path) + mace_params["config"] = tmp_path / "config.yaml" + mace_params["loss"] = "weighted" + mace_params["foundation_model"] = "small" + mace_params["hidden_irreps"] = "128x0e" + mace_params["r_max"] = 6.0 + mace_params["default_dtype"] = "float64" + mace_params["num_radial_basis"] = 10 + mace_params["interaction_first"] = "RealAgnosticResidualInteractionBlock" + mace_params["batch_size"] = 2 + mace_params["valid_batch_size"] = 1 + mace_params["num_samples_pt"] = 50 + mace_params["subselect_pt"] = "random" + mace_params["atomic_numbers"] = "[" + ",".join(map(str, atomic_numbers)) + "]" + mace_params["filter_type_pt"] = "combinations" + mace_params["force_mh_ft_lr"] = True + # make sure run_train.py is using the mace that is currently being tested + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) + + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + + try: + completed_process = subprocess.run( + cmd.split(), env=run_env, capture_output=True, text=True, check=True + ) + # Process executed successfully + print(completed_process.stdout) + except subprocess.CalledProcessError as e: + # Process failed with non-zero exit code + print(f"Command failed with exit code {e.returncode}") + print(f"STDOUT: {e.stdout}") + print(f"STDERR: {e.stderr}") + raise e + assert completed_process.returncode == 0 + + Es = [] + for at in fitting_configs: + config_head = at.info.get("head", "MP2") + calc = MACECalculator( + model_paths=tmp_path / "MACE.model", + device="cpu", + default_dtype="float64", + head=config_head, + ) + at.calc = calc + Es.append(at.get_potential_energy()) + + print("Es", Es) + # from a run on 20/08/2024 on commit + ref_Es = [ + 1.654685616493225, + 0.44693732261657715, + 0.8741313815116882, + 0.569085955619812, + 0.7161882519721985, + 0.8654778599739075, + 0.8722733855247498, + 0.49582308530807495, + 0.814422607421875, + 0.7027317881584167, + 0.7196993827819824, + 0.517953097820282, + 0.8631765246391296, + 0.4679797887802124, + 0.8163984417915344, + 0.4252359867095947, + 1.0861445665359497, + 0.6829671263694763, + 0.7136879563331604, + 0.5160345435142517, + 0.7002358436584473, + 0.5574042201042175, + ] + assert np.allclose(Es, ref_Es, atol=1e-1) + + +def test_run_train_multihead_replay_custum_finetuning( + tmp_path, fitting_configs, pretraining_configs +): + ase.io.write(tmp_path / "pretrain.xyz", pretraining_configs) + + foundation_params = { + "name": "foundation", + "train_file": os.path.join(tmp_path, "pretrain.xyz"), + "valid_fraction": 0.2, + "energy_weight": 1.0, + "forces_weight": 10.0, + "stress_weight": 1.0, + "model": "MACE", + "hidden_irreps": "32x0e", + "r_max": 5.0, + "batch_size": 2, + "max_num_epochs": 5, + "swa": None, + "start_swa": 3, + "device": "cpu", + "seed": 42, + "loss": "weighted", + "energy_key": "REF_energy", + "forces_key": "REF_forces", + "stress_key": "REF_stress", + "default_dtype": "float64", + "checkpoints_dir": str(tmp_path), + "model_dir": str(tmp_path), + } + + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + + cmd = [sys.executable, str(run_train)] + for k, v in foundation_params.items(): + if v is None: + cmd.append(f"--{k}") + else: + cmd.append(f"--{k}={v}") + + p = subprocess.run(cmd, env=run_env, check=True) + assert p.returncode == 0 + + # Step 3: Create finetuning set + fitting_configs_dft = [] + fitting_configs_mp2 = [] + for i, c in enumerate(fitting_configs): + if i in (0, 1): + c_dft = c.copy() + c_dft.info["head"] = "DFT" + fitting_configs_dft.append(c_dft) + fitting_configs_dft.append(c) + c_mp2 = c.copy() + c_mp2.info["head"] = "MP2" + fitting_configs_mp2.append(c_mp2) + elif i % 2 == 0: + c.info["head"] = "DFT" + fitting_configs_dft.append(c) + else: + c.info["head"] = "MP2" + fitting_configs_mp2.append(c) + ase.io.write(tmp_path / "fit_multihead_dft.xyz", fitting_configs_dft) + ase.io.write(tmp_path / "fit_multihead_mp2.xyz", fitting_configs_mp2) + + # Step 4: Finetune the pretrained model with multihead replay + heads = { + "DFT": {"train_file": f"{str(tmp_path)}/fit_multihead_dft.xyz"}, + "MP2": {"train_file": f"{str(tmp_path)}/fit_multihead_mp2.xyz"}, + } + yaml_str = "heads:\n" + for key, value in heads.items(): + yaml_str += f" {key}:\n" + for sub_key, sub_value in value.items(): + yaml_str += f" {sub_key}: {sub_value}\n" + filename = tmp_path / "config.yaml" + with open(filename, "w", encoding="utf-8") as file: + file.write(yaml_str) + + finetuning_params = { + "name": "finetuned", + "valid_fraction": 0.1, + "energy_weight": 1.0, + "forces_weight": 10.0, + "stress_weight": 1.0, + "model": "MACE", + "hidden_irreps": "32x0e", + "r_max": 5.0, + "batch_size": 2, + "max_num_epochs": 5, + "device": "cpu", + "seed": 42, + "loss": "weighted", + "default_dtype": "float64", + "checkpoints_dir": str(tmp_path), + "model_dir": str(tmp_path), + "foundation_model": os.path.join(tmp_path, "foundation.model"), + "config": os.path.join(tmp_path, "config.yaml"), + "pt_train_file": os.path.join(tmp_path, "pretrain.xyz"), + "num_samples_pt": 3, + "subselect_pt": "random", + "force_mh_ft_lr": True, + } + + cmd = [sys.executable, str(run_train)] + for k, v in finetuning_params.items(): + if v is None: + cmd.append(f"--{k}") + else: + cmd.append(f"--{k}={v}") + + p = subprocess.run(cmd, env=run_env, check=True) + assert p.returncode == 0 + + # Load and test the finetuned model + calc = MACECalculator( + model_paths=tmp_path / "finetuned.model", + device="cpu", + default_dtype="float64", + head="pt_head", + ) + + Es = [] + for at in fitting_configs: + at.calc = calc + Es.append(at.get_potential_energy()) + + print("Energies:", Es) + + # Add some basic checks + assert len(Es) == len(fitting_configs) + assert all(isinstance(E, float) for E in Es) + assert len(set(Es)) > 1 # Ens + + +@pytest.mark.skipif(not CUET_AVAILABLE, reason="cuequivariance not installed") +def test_run_train_cueq(tmp_path, fitting_configs): + torch.set_default_dtype(torch.float64) + ase.io.write(tmp_path / "fit.xyz", fitting_configs) + + mace_params = _mace_params.copy() + mace_params["checkpoints_dir"] = str(tmp_path) + mace_params["model_dir"] = str(tmp_path) + mace_params["train_file"] = tmp_path / "fit.xyz" + mace_params["enable_cueq"] = True + mace_params["default_dtype"] = "float64" + + # make sure run_train.py is using the mace that is currently being tested + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) + + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + + try: + completed_process = subprocess.run( + cmd.split(), env=run_env, capture_output=True, text=True, check=True + ) + # Process executed successfully + print(completed_process.stdout) + except subprocess.CalledProcessError as e: + # Process failed with non-zero exit code + print(f"Command failed with exit code {e.returncode}") + print(f"STDOUT: {e.stdout}") + print(f"STDERR: {e.stderr}") + raise e + assert completed_process.returncode == 0 + + calc = MACECalculator(model_paths=tmp_path / "MACE.model", device="cuda") + Es = [] + for at in fitting_configs[2:]: + at.calc = calc + Es.append(at.get_potential_energy()) + + calc = MACECalculator( + model_paths=tmp_path / "MACE.model", device="cpu", enable_cueq=True + ) + Es_cueq = [] + for at in fitting_configs[2:]: + at.calc = calc + Es_cueq.append(at.get_potential_energy()) + + # from a run on 04/06/2024 on stress_bugfix 967f0bfb6490086599da247874b24595d149caa7 + ref_Es = [ + -0.039181344585828524, + -0.0915223395136733, + -0.14953484236456582, + -0.06662480820063998, + -0.09983737353050133, + 0.12477442296789745, + -0.06486086271762856, + -0.1460607988519944, + 0.12886334908465508, + -0.14000990081920373, + -0.05319886578958313, + 0.07780520158391, + -0.08895480281886901, + -0.15474719614734422, + 0.007756765146527644, + -0.044879267197498685, + -0.036065736712447574, + -0.24413743841886623, + -0.0838104612106429, + -0.14751978636626545, + ] + + assert np.allclose(Es, ref_Es) + assert np.allclose(ref_Es, Es_cueq) + + +@pytest.mark.skipif(not CUET_AVAILABLE, reason="cuequivariance not installed") +def test_run_train_foundation_multihead_json_cueq(tmp_path, fitting_configs): + fitting_configs_dft = [] + fitting_configs_mp2 = [] + atomic_numbers = np.unique( + np.concatenate([at.numbers for at in fitting_configs]) + ).tolist() + for i, c in enumerate(fitting_configs): + + if i in (0, 1): + continue # skip isolated atoms, as energies specified by json files below + if i % 2 == 0: + c.info["head"] = "DFT" + fitting_configs_dft.append(c) + else: + c.info["head"] = "MP2" + fitting_configs_mp2.append(c) + ase.io.write(tmp_path / "fit_multihead_dft.xyz", fitting_configs_dft) + ase.io.write(tmp_path / "fit_multihead_mp2.xyz", fitting_configs_mp2) + + # write E0s to json files + E0s = {1: 0.0, 8: 0.0} + with open(tmp_path / "fit_multihead_dft.json", "w", encoding="utf-8") as f: + json.dump(E0s, f) + with open(tmp_path / "fit_multihead_mp2.json", "w", encoding="utf-8") as f: + json.dump(E0s, f) + + heads = { + "DFT": { + "train_file": f"{str(tmp_path)}/fit_multihead_dft.xyz", + "E0s": f"{str(tmp_path)}/fit_multihead_dft.json", + }, + "MP2": { + "train_file": f"{str(tmp_path)}/fit_multihead_mp2.xyz", + "E0s": f"{str(tmp_path)}/fit_multihead_mp2.json", + }, + } + yaml_str = "heads:\n" + for key, value in heads.items(): + yaml_str += f" {key}:\n" + for sub_key, sub_value in value.items(): + yaml_str += f" {sub_key}: {sub_value}\n" + filename = tmp_path / "config.yaml" + with open(filename, "w", encoding="utf-8") as file: + file.write(yaml_str) + mace_params = _mace_params.copy() + mace_params["valid_fraction"] = 0.1 + mace_params["checkpoints_dir"] = str(tmp_path) + mace_params["model_dir"] = str(tmp_path) + mace_params["config"] = tmp_path / "config.yaml" + mace_params["loss"] = "weighted" + mace_params["foundation_model"] = "small" + mace_params["hidden_irreps"] = "128x0e" + mace_params["r_max"] = 6.0 + mace_params["default_dtype"] = "float64" + mace_params["num_radial_basis"] = 10 + mace_params["interaction_first"] = "RealAgnosticResidualInteractionBlock" + mace_params["batch_size"] = 2 + mace_params["valid_batch_size"] = 1 + mace_params["num_samples_pt"] = 50 + mace_params["subselect_pt"] = "random" + mace_params["enable_cueq"] = True + mace_params["atomic_numbers"] = "[" + ",".join(map(str, atomic_numbers)) + "]" + mace_params["filter_type_pt"] = "combinations" + mace_params["device"] = "cuda" + mace_params["force_mh_ft_lr"] = True + # make sure run_train.py is using the mace that is currently being tested + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) + + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + + try: + completed_process = subprocess.run( + cmd.split(), env=run_env, capture_output=True, text=True, check=True + ) + # Process executed successfully + print(completed_process.stdout) + except subprocess.CalledProcessError as e: + # Process failed with non-zero exit code + print(f"Command failed with exit code {e.returncode}") + print(f"STDOUT: {e.stdout}") + print(f"STDERR: {e.stderr}") + raise e + assert completed_process.returncode == 0 + + calc = MACECalculator( + model_paths=tmp_path / "MACE.model", + device="cuda", + default_dtype="float64", + head="DFT", + ) + + Es = [] + for at in fitting_configs: + at.calc = calc + Es.append(at.get_potential_energy()) + + print("Es", Es) + # from a run on 20/08/2024 on commit + ref_Es = [ + 1.654685616493225, + 0.44693732261657715, + 0.8741313815116882, + 0.569085955619812, + 0.7161882519721985, + 0.8654778599739075, + 0.8722733855247498, + 0.49582308530807495, + 0.814422607421875, + 0.7027317881584167, + 0.7196993827819824, + 0.517953097820282, + 0.8631765246391296, + 0.4679797887802124, + 0.8163984417915344, + 0.4252359867095947, + 1.0861445665359497, + 0.6829671263694763, + 0.7136879563331604, + 0.5160345435142517, + 0.7002358436584473, + 0.5574042201042175, + ] + assert np.allclose(Es, ref_Es, atol=1e-1) + + +def test_run_train_lbfgs(tmp_path, fitting_configs): + ase.io.write(tmp_path / "fit.xyz", fitting_configs) + + mace_params = _mace_params.copy() + mace_params["checkpoints_dir"] = str(tmp_path) + mace_params["model_dir"] = str(tmp_path) + mace_params["train_file"] = tmp_path / "fit.xyz" + mace_params["lbfgs"] = None + mace_params["max_num_epochs"] = 2 + + # make sure run_train.py is using the mace that is currently being tested + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) + + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + + p = subprocess.run(cmd.split(), env=run_env, check=True) + assert p.returncode == 0 + + calc = MACECalculator(model_paths=tmp_path / "MACE.model", device="cpu") + + Es = [] + for at in fitting_configs: + at.calc = calc + Es.append(at.get_potential_energy()) + + print("Es", Es) + # from a run on 14/03/2025 + ref_Es = [ + 0.0, + 0.0, + -0.1874197850340979, + -0.25991775038059006, + 0.18263492399322268, + -0.15026829765490662, + -0.2403061362015996, + 0.1689257170630718, + -0.2095568077455055, + -0.2957758160829075, + -0.0035370913684985364, + -0.2195416610745775, + -0.25405549447739517, + -0.06201390990366806, + -0.13332219494388334, + -0.19633181702040337, + 0.013014932630445699, + -0.08808335967147174, + -0.06664444189210728, + -0.4230467426992034, + -0.2348250569553676, + -0.17593904833220647, + ] + assert np.allclose(Es, ref_Es, atol=1e-2) + + +def test_run_train_foundation_elements(tmp_path, fitting_configs): + + ase.io.write(tmp_path / "fit.xyz", fitting_configs) + + base_params = { + "name": "MACE", + "checkpoints_dir": str(tmp_path), + "model_dir": str(tmp_path), + "train_file": tmp_path / "fit.xyz", + "loss": "weighted", + "foundation_model": "small", + "hidden_irreps": "128x0e", + "r_max": 6.0, + "default_dtype": "float64", + "max_num_epochs": 5, + "num_radial_basis": 10, + "interaction_first": "RealAgnosticResidualInteractionBlock", + "multiheads_finetuning": False, + } + + # Run environment setup + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + + # First run: without foundation_model_elements (default behavior) + mace_params = base_params.copy() + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + p = subprocess.run(cmd.split(), env=run_env, check=True) + assert p.returncode == 0 + + # Load model and check elements + model_filtered = torch.load(tmp_path / "MACE.model", map_location="cpu") + filtered_elements = set(int(z) for z in model_filtered.atomic_numbers) + assert filtered_elements == {1, 8} # Only H and O should be present + + # Second run: with foundation_model_elements + mace_params = base_params.copy() + mace_params["name"] = "MACE_all_elements" + mace_params["foundation_model_elements"] = True # Flag-only argument + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + p = subprocess.run(cmd.split(), env=run_env, check=True) + assert p.returncode == 0 + + # Load model and check elements + model_all = torch.load(tmp_path / "MACE_all_elements.model", map_location="cpu") + all_elements = set(int(z) for z in model_all.atomic_numbers) + + # Get elements from foundation model for comparison + calc = mace_mp(model="small", device="cpu") + foundation_elements = set(int(z) for z in calc.models[0].atomic_numbers) + + # Check that all foundation model elements are preserved + assert all_elements == foundation_elements + assert len(all_elements) > len(filtered_elements) + + # Check that both models can make predictions + at = fitting_configs[2].copy() + + # Test filtered model + calc_filtered = MACECalculator( + model_paths=tmp_path / "MACE.model", device="cpu", default_dtype="float64" + ) + at.calc = calc_filtered + e1 = at.get_potential_energy() + + # Test all-elements model + calc_all = MACECalculator( + model_paths=tmp_path / "MACE_all_elements.model", + device="cpu", + default_dtype="float64", + ) + at.calc = calc_all + e2 = at.get_potential_energy() + + # Energies should be different since the models are trained differently, + # but both should give reasonable results + assert np.isfinite(e1) + assert np.isfinite(e2) + + +def test_run_train_foundation_elements_multihead(tmp_path, fitting_configs): + fitting_configs_dft = [] + fitting_configs_mp2 = [] + atomic_numbers = np.unique( + np.concatenate([at.numbers for at in fitting_configs]) + ).tolist() + for i, c in enumerate(fitting_configs): + if i in (0, 1): + c_dft = c.copy() + c_dft.info["head"] = "DFT" + fitting_configs_dft.append(c_dft) + c_mp2 = c.copy() + c_mp2.info["head"] = "MP2" + fitting_configs_mp2.append(c_mp2) + if i % 2 == 0: + c_copy = c.copy() + c_copy.info["head"] = "DFT" + fitting_configs_dft.append(c_copy) + else: + c_copy = c.copy() + c_copy.info["head"] = "MP2" + fitting_configs_mp2.append(c_copy) + + ase.io.write(tmp_path / "fit_dft.xyz", fitting_configs_dft) + ase.io.write(tmp_path / "fit_mp2.xyz", fitting_configs_mp2) + + # Create multihead configuration + heads = { + "DFT": {"train_file": f"{str(tmp_path)}/fit_dft.xyz"}, + "MP2": {"train_file": f"{str(tmp_path)}/fit_mp2.xyz"}, + } + yaml_str = "heads:\n" + for key, value in heads.items(): + yaml_str += f" {key}:\n" + for sub_key, sub_value in value.items(): + yaml_str += f" {sub_key}: {sub_value}\n" + config_file = tmp_path / "config.yaml" + with open(config_file, "w", encoding="utf-8") as file: + file.write(yaml_str) + + base_params = { + "name": "MACE", + "checkpoints_dir": str(tmp_path), + "model_dir": str(tmp_path), + "config": str(config_file), + "loss": "weighted", + "foundation_model": "small", + "hidden_irreps": "128x0e", + "r_max": 6.0, + "default_dtype": "float64", + "max_num_epochs": 5, + "num_radial_basis": 10, + "interaction_first": "RealAgnosticResidualInteractionBlock", + "force_mh_ft_lr": True, + "batch_size": 1, + "num_samples_pt": 50, + "subselect_pt": "random", + "atomic_numbers": "[" + ",".join(map(str, atomic_numbers)) + "]", + "filter_type_pt": "combinations", + "valid_fraction": 0.1, + "valid_batch_size": 1, + } + + # Run environment setup + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + + # First run: without foundation_model_elements (default behavior) + mace_params = base_params.copy() + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + try: + completed_process = subprocess.run( + cmd.split(), env=run_env, capture_output=True, text=True, check=True + ) + # Process executed successfully + print(completed_process.stdout) + except subprocess.CalledProcessError as e: + # Process failed with non-zero exit code + print(f"Command failed with exit code {e.returncode}") + print(f"STDOUT: {e.stdout}") + print(f"STDERR: {e.stderr}") + raise e + assert completed_process.returncode == 0 + + # Load model and check elements + model_filtered = torch.load(tmp_path / "MACE.model", map_location="cpu") + filtered_elements = set(int(z) for z in model_filtered.atomic_numbers) + assert filtered_elements == {1, 8} # Only H and O should be present + assert len(model_filtered.heads) == 3 # pt_head + DFT + MP2 + + # Second run: with foundation_model_elements + mace_params = base_params.copy() + mace_params["name"] = "MACE_all_elements" + mace_params["foundation_model_elements"] = True + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + p = subprocess.run(cmd.split(), env=run_env, check=True) + assert p.returncode == 0 + + # Load model and check elements + model_all = torch.load(tmp_path / "MACE_all_elements.model", map_location="cpu") + all_elements = set(int(z) for z in model_all.atomic_numbers) + + # Get elements from foundation model for comparison + calc = mace_mp(model="small", device="cpu") + foundation_elements = set(int(z) for z in calc.models[0].atomic_numbers) + + # Check that all foundation model elements are preserved + assert all_elements == foundation_elements + assert len(all_elements) > len(filtered_elements) + assert len(model_all.heads) == 3 # pt_head + DFT + MP2 + + # Check that both models can make predictions + at = fitting_configs_dft[2].copy() + + # Test filtered model + calc_filtered = MACECalculator( + model_paths=tmp_path / "MACE.model", + device="cpu", + default_dtype="float64", + head="DFT", + ) + at.calc = calc_filtered + e1 = at.get_potential_energy() + + # Test all-elements model + calc_all = MACECalculator( + model_paths=tmp_path / "MACE_all_elements.model", + device="cpu", + default_dtype="float64", + head="DFT", + ) + at.calc = calc_all + e2 = at.get_potential_energy() + + assert np.isfinite(e1) + assert np.isfinite(e2) diff --git a/mace-bench/3rdparty/mace/tests/test_run_train_allkeys.py b/mace-bench/3rdparty/mace/tests/test_run_train_allkeys.py index 1c10217..1d59190 100644 --- a/mace-bench/3rdparty/mace/tests/test_run_train_allkeys.py +++ b/mace-bench/3rdparty/mace/tests/test_run_train_allkeys.py @@ -1,468 +1,468 @@ -import os -import subprocess -import sys -from copy import deepcopy -from pathlib import Path - -import ase.io -import numpy as np -import pytest -from ase.atoms import Atoms - -from mace.calculators.mace import MACECalculator -from mace.cli.run_train import run as run_mace_train -from mace.data.utils import KeySpecification -from mace.tools import build_default_arg_parser - -run_train = Path(__file__).parent.parent / "mace" / "cli" / "run_train.py" - - -_mace_params = { - "name": "MACE", - "valid_fraction": 0.05, - "energy_weight": 1.0, - "forces_weight": 10.0, - "stress_weight": 1.0, - "model": "MACE", - "hidden_irreps": "128x0e", - "max_num_epochs": 10, - "swa": None, - "start_swa": 5, - "ema": None, - "ema_decay": 0.99, - "amsgrad": None, - "device": "cpu", - "seed": 5, - "loss": "weighted", - "energy_key": "REF_energy", - "forces_key": "REF_forces", - "stress_key": "REF_stress", - "interaction_first": "RealAgnosticResidualInteractionBlock", - "batch_size": 1, - "valid_batch_size": 1, - "num_samples_pt": 50, - "subselect_pt": "random", - "eval_interval": 2, - "num_radial_basis": 10, - "r_max": 6.0, - "default_dtype": "float64", -} - - -def configs_numbered_keys(): - np.random.seed(0) - water = Atoms( - numbers=[8, 1, 1], - positions=[[0, -2.0, 0], [1, 0, 0], [0, 1, 0]], - cell=[4] * 3, - pbc=[True] * 3, - ) - - energies = list(np.random.normal(0.1, size=15)) - forces = list(np.random.normal(0.1, size=(15, 3, 3))) - - trial_configs_lists = [] - # some keys present, some not - keys_to_use = ( - ["REF_energy"] - + ["2_energy"] * 2 - + ["3_energy"] * 3 - + ["4_energy"] * 4 - + ["5_energy"] * 5 - ) - - force_keys_to_use = ( - ["REF_forces"] - + ["2_forces"] * 2 - + ["3_forces"] * 3 - + ["4_forces"] * 4 - + ["5_forces"] * 5 - ) - - for ind in range(15): - c = deepcopy(water) - c.info[keys_to_use[ind]] = energies[ind] - c.arrays[force_keys_to_use[ind]] = forces[ind] - c.positions += np.random.normal(0.1, size=(3, 3)) - trial_configs_lists.append(c) - - return trial_configs_lists - - -def trial_yamls_and_and_expected(): - yamls = {} - command_line_kwargs = {"energy_key": "2_energy", "forces_key": "2_forces"} - - yamls["no_heads"] = {} - - yamls["one_head_no_dicts"] = { - "heads": { - "Default": { - "energy_key": "3_energy", - } - } - } - - yamls["one_head_with_dicts"] = { - "heads": { - "Default": { - "info_keys": { - "energy": "3_energy", - }, - "arrays_keys": { - "forces": "3_forces", - }, - } - } - } - - yamls["two_heads_no_dicts"] = { - "heads": { - "dft": { - "train_file": "fit_multihead_dft.xyz", - "energy_key": "3_energy", - }, - "mp2": { - "train_file": "fit_multihead_mp2.xyz", - "energy_key": "4_energy", - }, - } - } - - yamls["two_heads_mixed"] = { - "heads": { - "dft": { - "train_file": "fit_multihead_dft.xyz", - "info_keys": { - "energy": "3_energy", - }, - "arrays_keys": { - "forces": "3_forces", - }, - "forces_key": "4_forces", - }, - "mp2": { - "train_file": "fit_multihead_mp2.xyz", - "energy_key": "4_energy", - }, - } - } - all_arg_sets = { - "with_command_line": { - key: {**command_line_kwargs, **value} for key, value in yamls.items() - }, - "without_command_line": yamls, - } - - all_expected_outputs = { - "with_command_line": { - "no_heads": [ - 1.0037831178668188, - 1.0183291323603265, - 1.0120784084221528, - 0.9935695881012243, - 1.0021641561865526, - 0.9999135609205868, - 0.9809440616323108, - 1.0025784765050076, - 1.0017901145495376, - 1.0136913185404515, - 1.006798563238269, - 1.0187758397828384, - 1.0180201540775071, - 1.0132368725061702, - 0.9998734173248169, - ], - "one_head_no_dicts": [ - 1.0028437510688613, - 1.0514693378041775, - 1.059933403321331, - 1.034719940573569, - 1.0438040675561824, - 1.019719477728329, - 0.9841759692947915, - 1.0435266573857496, - 1.0339501989779065, - 1.0501795448530264, - 1.0402594216704781, - 1.0604998765679152, - 1.0633411200246015, - 1.0539071190201297, - 1.0393496428177804, - ], - "one_head_with_dicts": [ - 0.8638341551096959, - 1.0078341354784144, - 1.0149701178418595, - 0.9945723048460148, - 1.0184158011731292, - 0.9992135295205004, - 0.8943420783639198, - 1.0327920054084088, - 0.9905731198078909, - 0.9838325204450648, - 1.0018725575620482, - 1.007263052421034, - 1.0335213929231966, - 1.0033503312511205, - 1.0174433894759563, - ], - "two_heads_no_dicts": [ - 0.9836377578288774, - 1.0196844186291318, - 1.0151628222871238, - 0.957307281711648, - 0.985574141310865, - 0.9629670134047853, - 0.9242583185138095, - 0.9807770070311039, - 0.9973679440479541, - 1.0221127246963275, - 1.0031807967874216, - 1.0358701219543687, - 1.0434208761164758, - 1.0235606028124515, - 0.9797494630655053, - ], - "two_heads_mixed": [ - 0.8664108574741868, - 0.9907166576278023, - 1.0051969372365164, - 0.978702477000018, - 1.025500166764692, - 0.9940095566375018, - 0.9034029726954119, - 1.0391739502744488, - 0.9717327061183668, - 0.972292103670355, - 1.0012510461663253, - 0.9978051155885286, - 1.0378611651753475, - 1.0003207628186224, - 1.0209509292189651, - ], - }, - "without_command_line": { - "no_heads": [ - 0.9352605307451007, - 0.991084559389268, - 0.9940350095024881, - 0.9953849198103668, - 0.9954705498032904, - 0.9964815693808411, - 0.9663142667436776, - 0.9947223808739147, - 0.9897776682803257, - 0.989027769690667, - 0.9910280920241263, - 0.992067980667518, - 0.9917276132506404, - 0.9902848752169671, - 0.9928585982942544, - ], - "one_head_no_dicts": [ - 0.9425342207393083, - 1.0149788456087416, - 1.0249228965652788, - 1.0247924743285792, - 1.02732103964481, - 1.0168852937950326, - 0.9771283495170653, - 1.0261776335561517, - 1.0130461033368028, - 1.0162619153561783, - 1.019995179866916, - 1.0209512298344965, - 1.0219971755636952, - 1.0195791901659124, - 1.0234662527729408, - ], - "one_head_with_dicts": [ - 0.8638341551096959, - 1.0078341354784144, - 1.0149701178418595, - 0.9945723048460148, - 1.0184158011731292, - 0.9992135295205004, - 0.8943420783639198, - 1.0327920054084088, - 0.9905731198078909, - 0.9838325204450648, - 1.0018725575620482, - 1.007263052421034, - 1.0335213929231966, - 1.0033503312511205, - 1.0174433894759563, - ], - "two_heads_no_dicts": [ - 0.9933763730233168, - 0.9986480398559268, - 1.0042486164355315, - 1.0025568793877726, - 1.0032598081704625, - 0.9926714183717912, - 0.9920385249670881, - 1.0020278841030676, - 1.0012474150830537, - 1.0039289677261019, - 1.0022718878661814, - 1.003586385624809, - 1.003436450009097, - 1.003805673887942, - 1.001450261102316, - ], - "two_heads_mixed": [ - 0.8781767864616707, - 0.9843563603794138, - 1.0145197579049248, - 0.9835060778675391, - 1.0419060462994596, - 0.9917393978520056, - 0.9091521032773944, - 1.0605463095070453, - 0.9685381713826684, - 0.9866493058823766, - 1.00305061187164, - 1.0051273128414386, - 1.037964258398104, - 1.0106663924241408, - 1.0274351814133602, - ], - }, - } - - list_of_all = [] - for key, value in all_arg_sets.items(): - for key2, value2 in value.items(): - list_of_all.append( - (value2, (key, key2), np.asarray(all_expected_outputs[key][key2])) - ) - - return list_of_all - - -def dict_to_yaml_str(data, indent=0): - yaml_str = "" - for key, value in data.items(): - yaml_str += " " * indent + str(key) + ":" - if isinstance(value, dict): - yaml_str += "\n" + dict_to_yaml_str(value, indent + 2) - else: - yaml_str += " " + str(value) + "\n" - return yaml_str - - -_trial_yamls_and_and_expected = trial_yamls_and_and_expected() - - -@pytest.mark.parametrize( - "yaml_contents, name, expected_value", _trial_yamls_and_and_expected -) -def test_key_specification_methods(tmp_path, yaml_contents, name, expected_value): - fitting_configs = configs_numbered_keys() - - ase.io.write(tmp_path / "fit_multihead_dft.xyz", fitting_configs) - ase.io.write(tmp_path / "fit_multihead_mp2.xyz", fitting_configs) - ase.io.write(tmp_path / "duplicated_fit_multihead_dft.xyz", fitting_configs) - - mace_params = _mace_params.copy() - mace_params["valid_fraction"] = 0.1 - mace_params["checkpoints_dir"] = str(tmp_path) - mace_params["model_dir"] = str(tmp_path) - mace_params["train_file"] = "fit_multihead_dft.xyz" - mace_params["E0s"] = "{1:0.0,8:1.0}" - mace_params["valid_file"] = "duplicated_fit_multihead_dft.xyz" - del mace_params["valid_fraction"] - mace_params["max_num_epochs"] = 1 # many tests to do - del mace_params["energy_key"] - del mace_params["forces_key"] - del mace_params["stress_key"] - - mace_params["name"] = "MACE_" - - filename = tmp_path / "config.yaml" - with open(filename, "w", encoding="utf-8") as file: - file.write(dict_to_yaml_str(yaml_contents)) - if len(yaml_contents) > 0: - mace_params["config"] = str(tmp_path / "config.yaml") - - run_env = os.environ.copy() - sys.path.insert(0, str(Path(__file__).parent.parent)) - run_env["PYTHONPATH"] = ":".join(sys.path) - print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) - - cmd = ( - sys.executable - + " " - + str(run_train) - + " " - + " ".join( - [ - (f"--{k}={v}" if v is not None else f"--{k}") - for k, v in mace_params.items() - ] - ) - ) - - p = subprocess.run(cmd.split(), env=run_env, cwd=tmp_path, check=True) - assert p.returncode == 0 - - if "heads" in yaml_contents: - headname = list(yaml_contents["heads"].keys())[0] - else: - headname = "Default" - - calc = MACECalculator( - tmp_path / "MACE_.model", device="cpu", default_dtype="float64", head=headname - ) - - Es = [] - for at in fitting_configs: - at.calc = calc - Es.append(at.get_potential_energy()) - - print(name) - print("Es", Es) - - assert np.allclose( - np.asarray(Es), expected_value, rtol=1e-8, atol=1e-8 - ), f"Expected {expected_value} but got {Es} with error {np.max(np.abs(Es - expected_value))}" - - -def test_multihead_finetuning_does_not_modify_default_keyspec(tmp_path): - fitting_configs = configs_numbered_keys() - ase.io.write(tmp_path / "fit_multihead_dft.xyz", fitting_configs) - - args = build_default_arg_parser().parse_args( - [ - "--name", - "_MACE_", - "--train_file", - str(tmp_path / "fit_multihead_dft.xyz"), - "--foundation_model", - "small", - "--device", - "cpu", - "--E0s", - "{1:0.0,8:1.0}", - "--energy_key", - "2_energy", - "--dry_run", - ] - ) - default_key_spec = KeySpecification.from_defaults() - default_key_spec.info_keys["energy"] = "2_energy" - run_mace_train(args) - assert args.key_specification == default_key_spec - -# for creating values -def make_output(): - outputs = {} - for yaml_contents, name, expected_value in _trial_yamls_and_and_expected: - if name[0] not in outputs: - outputs[name[0]] = {} - expected = test_key_specification_methods( - Path("."), yaml_contents, name, expected_value, debug_test=False - ) - outputs[name[0]][name[1]] = expected - print(outputs) +import os +import subprocess +import sys +from copy import deepcopy +from pathlib import Path + +import ase.io +import numpy as np +import pytest +from ase.atoms import Atoms + +from mace.calculators.mace import MACECalculator +from mace.cli.run_train import run as run_mace_train +from mace.data.utils import KeySpecification +from mace.tools import build_default_arg_parser + +run_train = Path(__file__).parent.parent / "mace" / "cli" / "run_train.py" + + +_mace_params = { + "name": "MACE", + "valid_fraction": 0.05, + "energy_weight": 1.0, + "forces_weight": 10.0, + "stress_weight": 1.0, + "model": "MACE", + "hidden_irreps": "128x0e", + "max_num_epochs": 10, + "swa": None, + "start_swa": 5, + "ema": None, + "ema_decay": 0.99, + "amsgrad": None, + "device": "cpu", + "seed": 5, + "loss": "weighted", + "energy_key": "REF_energy", + "forces_key": "REF_forces", + "stress_key": "REF_stress", + "interaction_first": "RealAgnosticResidualInteractionBlock", + "batch_size": 1, + "valid_batch_size": 1, + "num_samples_pt": 50, + "subselect_pt": "random", + "eval_interval": 2, + "num_radial_basis": 10, + "r_max": 6.0, + "default_dtype": "float64", +} + + +def configs_numbered_keys(): + np.random.seed(0) + water = Atoms( + numbers=[8, 1, 1], + positions=[[0, -2.0, 0], [1, 0, 0], [0, 1, 0]], + cell=[4] * 3, + pbc=[True] * 3, + ) + + energies = list(np.random.normal(0.1, size=15)) + forces = list(np.random.normal(0.1, size=(15, 3, 3))) + + trial_configs_lists = [] + # some keys present, some not + keys_to_use = ( + ["REF_energy"] + + ["2_energy"] * 2 + + ["3_energy"] * 3 + + ["4_energy"] * 4 + + ["5_energy"] * 5 + ) + + force_keys_to_use = ( + ["REF_forces"] + + ["2_forces"] * 2 + + ["3_forces"] * 3 + + ["4_forces"] * 4 + + ["5_forces"] * 5 + ) + + for ind in range(15): + c = deepcopy(water) + c.info[keys_to_use[ind]] = energies[ind] + c.arrays[force_keys_to_use[ind]] = forces[ind] + c.positions += np.random.normal(0.1, size=(3, 3)) + trial_configs_lists.append(c) + + return trial_configs_lists + + +def trial_yamls_and_and_expected(): + yamls = {} + command_line_kwargs = {"energy_key": "2_energy", "forces_key": "2_forces"} + + yamls["no_heads"] = {} + + yamls["one_head_no_dicts"] = { + "heads": { + "Default": { + "energy_key": "3_energy", + } + } + } + + yamls["one_head_with_dicts"] = { + "heads": { + "Default": { + "info_keys": { + "energy": "3_energy", + }, + "arrays_keys": { + "forces": "3_forces", + }, + } + } + } + + yamls["two_heads_no_dicts"] = { + "heads": { + "dft": { + "train_file": "fit_multihead_dft.xyz", + "energy_key": "3_energy", + }, + "mp2": { + "train_file": "fit_multihead_mp2.xyz", + "energy_key": "4_energy", + }, + } + } + + yamls["two_heads_mixed"] = { + "heads": { + "dft": { + "train_file": "fit_multihead_dft.xyz", + "info_keys": { + "energy": "3_energy", + }, + "arrays_keys": { + "forces": "3_forces", + }, + "forces_key": "4_forces", + }, + "mp2": { + "train_file": "fit_multihead_mp2.xyz", + "energy_key": "4_energy", + }, + } + } + all_arg_sets = { + "with_command_line": { + key: {**command_line_kwargs, **value} for key, value in yamls.items() + }, + "without_command_line": yamls, + } + + all_expected_outputs = { + "with_command_line": { + "no_heads": [ + 1.0037831178668188, + 1.0183291323603265, + 1.0120784084221528, + 0.9935695881012243, + 1.0021641561865526, + 0.9999135609205868, + 0.9809440616323108, + 1.0025784765050076, + 1.0017901145495376, + 1.0136913185404515, + 1.006798563238269, + 1.0187758397828384, + 1.0180201540775071, + 1.0132368725061702, + 0.9998734173248169, + ], + "one_head_no_dicts": [ + 1.0028437510688613, + 1.0514693378041775, + 1.059933403321331, + 1.034719940573569, + 1.0438040675561824, + 1.019719477728329, + 0.9841759692947915, + 1.0435266573857496, + 1.0339501989779065, + 1.0501795448530264, + 1.0402594216704781, + 1.0604998765679152, + 1.0633411200246015, + 1.0539071190201297, + 1.0393496428177804, + ], + "one_head_with_dicts": [ + 0.8638341551096959, + 1.0078341354784144, + 1.0149701178418595, + 0.9945723048460148, + 1.0184158011731292, + 0.9992135295205004, + 0.8943420783639198, + 1.0327920054084088, + 0.9905731198078909, + 0.9838325204450648, + 1.0018725575620482, + 1.007263052421034, + 1.0335213929231966, + 1.0033503312511205, + 1.0174433894759563, + ], + "two_heads_no_dicts": [ + 0.9836377578288774, + 1.0196844186291318, + 1.0151628222871238, + 0.957307281711648, + 0.985574141310865, + 0.9629670134047853, + 0.9242583185138095, + 0.9807770070311039, + 0.9973679440479541, + 1.0221127246963275, + 1.0031807967874216, + 1.0358701219543687, + 1.0434208761164758, + 1.0235606028124515, + 0.9797494630655053, + ], + "two_heads_mixed": [ + 0.8664108574741868, + 0.9907166576278023, + 1.0051969372365164, + 0.978702477000018, + 1.025500166764692, + 0.9940095566375018, + 0.9034029726954119, + 1.0391739502744488, + 0.9717327061183668, + 0.972292103670355, + 1.0012510461663253, + 0.9978051155885286, + 1.0378611651753475, + 1.0003207628186224, + 1.0209509292189651, + ], + }, + "without_command_line": { + "no_heads": [ + 0.9352605307451007, + 0.991084559389268, + 0.9940350095024881, + 0.9953849198103668, + 0.9954705498032904, + 0.9964815693808411, + 0.9663142667436776, + 0.9947223808739147, + 0.9897776682803257, + 0.989027769690667, + 0.9910280920241263, + 0.992067980667518, + 0.9917276132506404, + 0.9902848752169671, + 0.9928585982942544, + ], + "one_head_no_dicts": [ + 0.9425342207393083, + 1.0149788456087416, + 1.0249228965652788, + 1.0247924743285792, + 1.02732103964481, + 1.0168852937950326, + 0.9771283495170653, + 1.0261776335561517, + 1.0130461033368028, + 1.0162619153561783, + 1.019995179866916, + 1.0209512298344965, + 1.0219971755636952, + 1.0195791901659124, + 1.0234662527729408, + ], + "one_head_with_dicts": [ + 0.8638341551096959, + 1.0078341354784144, + 1.0149701178418595, + 0.9945723048460148, + 1.0184158011731292, + 0.9992135295205004, + 0.8943420783639198, + 1.0327920054084088, + 0.9905731198078909, + 0.9838325204450648, + 1.0018725575620482, + 1.007263052421034, + 1.0335213929231966, + 1.0033503312511205, + 1.0174433894759563, + ], + "two_heads_no_dicts": [ + 0.9933763730233168, + 0.9986480398559268, + 1.0042486164355315, + 1.0025568793877726, + 1.0032598081704625, + 0.9926714183717912, + 0.9920385249670881, + 1.0020278841030676, + 1.0012474150830537, + 1.0039289677261019, + 1.0022718878661814, + 1.003586385624809, + 1.003436450009097, + 1.003805673887942, + 1.001450261102316, + ], + "two_heads_mixed": [ + 0.8781767864616707, + 0.9843563603794138, + 1.0145197579049248, + 0.9835060778675391, + 1.0419060462994596, + 0.9917393978520056, + 0.9091521032773944, + 1.0605463095070453, + 0.9685381713826684, + 0.9866493058823766, + 1.00305061187164, + 1.0051273128414386, + 1.037964258398104, + 1.0106663924241408, + 1.0274351814133602, + ], + }, + } + + list_of_all = [] + for key, value in all_arg_sets.items(): + for key2, value2 in value.items(): + list_of_all.append( + (value2, (key, key2), np.asarray(all_expected_outputs[key][key2])) + ) + + return list_of_all + + +def dict_to_yaml_str(data, indent=0): + yaml_str = "" + for key, value in data.items(): + yaml_str += " " * indent + str(key) + ":" + if isinstance(value, dict): + yaml_str += "\n" + dict_to_yaml_str(value, indent + 2) + else: + yaml_str += " " + str(value) + "\n" + return yaml_str + + +_trial_yamls_and_and_expected = trial_yamls_and_and_expected() + + +@pytest.mark.parametrize( + "yaml_contents, name, expected_value", _trial_yamls_and_and_expected +) +def test_key_specification_methods(tmp_path, yaml_contents, name, expected_value): + fitting_configs = configs_numbered_keys() + + ase.io.write(tmp_path / "fit_multihead_dft.xyz", fitting_configs) + ase.io.write(tmp_path / "fit_multihead_mp2.xyz", fitting_configs) + ase.io.write(tmp_path / "duplicated_fit_multihead_dft.xyz", fitting_configs) + + mace_params = _mace_params.copy() + mace_params["valid_fraction"] = 0.1 + mace_params["checkpoints_dir"] = str(tmp_path) + mace_params["model_dir"] = str(tmp_path) + mace_params["train_file"] = "fit_multihead_dft.xyz" + mace_params["E0s"] = "{1:0.0,8:1.0}" + mace_params["valid_file"] = "duplicated_fit_multihead_dft.xyz" + del mace_params["valid_fraction"] + mace_params["max_num_epochs"] = 1 # many tests to do + del mace_params["energy_key"] + del mace_params["forces_key"] + del mace_params["stress_key"] + + mace_params["name"] = "MACE_" + + filename = tmp_path / "config.yaml" + with open(filename, "w", encoding="utf-8") as file: + file.write(dict_to_yaml_str(yaml_contents)) + if len(yaml_contents) > 0: + mace_params["config"] = str(tmp_path / "config.yaml") + + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + print("DEBUG subprocess PYTHONPATH", run_env["PYTHONPATH"]) + + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + + p = subprocess.run(cmd.split(), env=run_env, cwd=tmp_path, check=True) + assert p.returncode == 0 + + if "heads" in yaml_contents: + headname = list(yaml_contents["heads"].keys())[0] + else: + headname = "Default" + + calc = MACECalculator( + tmp_path / "MACE_.model", device="cpu", default_dtype="float64", head=headname + ) + + Es = [] + for at in fitting_configs: + at.calc = calc + Es.append(at.get_potential_energy()) + + print(name) + print("Es", Es) + + assert np.allclose( + np.asarray(Es), expected_value, rtol=1e-8, atol=1e-8 + ), f"Expected {expected_value} but got {Es} with error {np.max(np.abs(Es - expected_value))}" + + +def test_multihead_finetuning_does_not_modify_default_keyspec(tmp_path): + fitting_configs = configs_numbered_keys() + ase.io.write(tmp_path / "fit_multihead_dft.xyz", fitting_configs) + + args = build_default_arg_parser().parse_args( + [ + "--name", + "_MACE_", + "--train_file", + str(tmp_path / "fit_multihead_dft.xyz"), + "--foundation_model", + "small", + "--device", + "cpu", + "--E0s", + "{1:0.0,8:1.0}", + "--energy_key", + "2_energy", + "--dry_run", + ] + ) + default_key_spec = KeySpecification.from_defaults() + default_key_spec.info_keys["energy"] = "2_energy" + run_mace_train(args) + assert args.key_specification == default_key_spec + +# for creating values +def make_output(): + outputs = {} + for yaml_contents, name, expected_value in _trial_yamls_and_and_expected: + if name[0] not in outputs: + outputs[name[0]] = {} + expected = test_key_specification_methods( + Path("."), yaml_contents, name, expected_value, debug_test=False + ) + outputs[name[0]][name[1]] = expected + print(outputs) diff --git a/mace-bench/3rdparty/mace/tests/test_schedulefree.py b/mace-bench/3rdparty/mace/tests/test_schedulefree.py index d84163c..00b2075 100644 --- a/mace-bench/3rdparty/mace/tests/test_schedulefree.py +++ b/mace-bench/3rdparty/mace/tests/test_schedulefree.py @@ -1,127 +1,127 @@ -import tempfile -from unittest.mock import MagicMock - -import numpy as np -import pytest -import torch -import torch.nn.functional as F -from e3nn import o3 - -from mace import data, modules, tools -from mace.tools import scripts_utils, torch_geometric - -try: - import schedulefree -except ImportError: - pytest.skip( - "Skipping schedulefree tests due to ImportError", allow_module_level=True - ) - -torch.set_default_dtype(torch.float64) - -table = tools.AtomicNumberTable([6]) -atomic_energies = np.array([1.0], dtype=float) -cutoff = 5.0 - - -def create_mace(device: str, seed: int = 1702): - torch_geometric.seed_everything(seed) - - model_config = { - "r_max": cutoff, - "num_bessel": 8, - "num_polynomial_cutoff": 6, - "max_ell": 3, - "interaction_cls": modules.interaction_classes[ - "RealAgnosticResidualInteractionBlock" - ], - "interaction_cls_first": modules.interaction_classes[ - "RealAgnosticResidualInteractionBlock" - ], - "num_interactions": 2, - "num_elements": 1, - "hidden_irreps": o3.Irreps("8x0e + 8x1o"), - "MLP_irreps": o3.Irreps("16x0e"), - "gate": F.silu, - "atomic_energies": atomic_energies, - "avg_num_neighbors": 8, - "atomic_numbers": table.zs, - "correlation": 3, - "radial_type": "bessel", - } - model = modules.MACE(**model_config) - return model.to(device) - - -def create_batch(device: str): - from ase import build - - size = 2 - atoms = build.bulk("C", "diamond", a=3.567, cubic=True) - atoms_list = [atoms.repeat((size, size, size))] - print("Number of atoms", len(atoms_list[0])) - - configs = [data.config_from_atoms(atoms) for atoms in atoms_list] - data_loader = torch_geometric.dataloader.DataLoader( - dataset=[ - data.AtomicData.from_config(config, z_table=table, cutoff=cutoff) - for config in configs - ], - batch_size=1, - shuffle=False, - drop_last=False, - ) - batch = next(iter(data_loader)) - batch = batch.to(device) - batch = batch.to_dict() - return batch - - -def do_optimization_step( - model, - optimizer, - device, -): - batch = create_batch(device) - model.train() - optimizer.train() - optimizer.zero_grad() - output = model(batch, training=True, compute_force=False) - loss = output["energy"].mean() - loss.backward() - optimizer.step() - model.eval() - optimizer.eval() - - -@pytest.mark.parametrize("device", ["cpu", "cuda"]) -def test_can_load_checkpoint(device): - model = create_mace(device) - optimizer = schedulefree.adamw_schedulefree.AdamWScheduleFree(model.parameters()) - args = MagicMock() - args.optimizer = "schedulefree" - args.scheduler = "ExponentialLR" - args.lr_scheduler_gamma = 0.9 - lr_scheduler = scripts_utils.LRScheduler(optimizer, args) - with tempfile.TemporaryDirectory() as d: - checkpoint_handler = tools.CheckpointHandler( - directory=d, keep=False, tag="schedulefree" - ) - for _ in range(10): - do_optimization_step(model, optimizer, device) - batch = create_batch(device) - output = model(batch) - energy = output["energy"].detach().cpu().numpy() - - state = tools.CheckpointState( - model=model, optimizer=optimizer, lr_scheduler=lr_scheduler - ) - checkpoint_handler.save(state, epochs=0, keep_last=False) - checkpoint_handler.load_latest( - state=tools.CheckpointState(model, optimizer, lr_scheduler), - swa=False, - ) - batch = create_batch(device) - output = model(batch) - new_energy = output["energy"].detach().cpu().numpy() - assert np.allclose(energy, new_energy, atol=1e-9) +import tempfile +from unittest.mock import MagicMock + +import numpy as np +import pytest +import torch +import torch.nn.functional as F +from e3nn import o3 + +from mace import data, modules, tools +from mace.tools import scripts_utils, torch_geometric + +try: + import schedulefree +except ImportError: + pytest.skip( + "Skipping schedulefree tests due to ImportError", allow_module_level=True + ) + +torch.set_default_dtype(torch.float64) + +table = tools.AtomicNumberTable([6]) +atomic_energies = np.array([1.0], dtype=float) +cutoff = 5.0 + + +def create_mace(device: str, seed: int = 1702): + torch_geometric.seed_everything(seed) + + model_config = { + "r_max": cutoff, + "num_bessel": 8, + "num_polynomial_cutoff": 6, + "max_ell": 3, + "interaction_cls": modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + "interaction_cls_first": modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + "num_interactions": 2, + "num_elements": 1, + "hidden_irreps": o3.Irreps("8x0e + 8x1o"), + "MLP_irreps": o3.Irreps("16x0e"), + "gate": F.silu, + "atomic_energies": atomic_energies, + "avg_num_neighbors": 8, + "atomic_numbers": table.zs, + "correlation": 3, + "radial_type": "bessel", + } + model = modules.MACE(**model_config) + return model.to(device) + + +def create_batch(device: str): + from ase import build + + size = 2 + atoms = build.bulk("C", "diamond", a=3.567, cubic=True) + atoms_list = [atoms.repeat((size, size, size))] + print("Number of atoms", len(atoms_list[0])) + + configs = [data.config_from_atoms(atoms) for atoms in atoms_list] + data_loader = torch_geometric.dataloader.DataLoader( + dataset=[ + data.AtomicData.from_config(config, z_table=table, cutoff=cutoff) + for config in configs + ], + batch_size=1, + shuffle=False, + drop_last=False, + ) + batch = next(iter(data_loader)) + batch = batch.to(device) + batch = batch.to_dict() + return batch + + +def do_optimization_step( + model, + optimizer, + device, +): + batch = create_batch(device) + model.train() + optimizer.train() + optimizer.zero_grad() + output = model(batch, training=True, compute_force=False) + loss = output["energy"].mean() + loss.backward() + optimizer.step() + model.eval() + optimizer.eval() + + +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +def test_can_load_checkpoint(device): + model = create_mace(device) + optimizer = schedulefree.adamw_schedulefree.AdamWScheduleFree(model.parameters()) + args = MagicMock() + args.optimizer = "schedulefree" + args.scheduler = "ExponentialLR" + args.lr_scheduler_gamma = 0.9 + lr_scheduler = scripts_utils.LRScheduler(optimizer, args) + with tempfile.TemporaryDirectory() as d: + checkpoint_handler = tools.CheckpointHandler( + directory=d, keep=False, tag="schedulefree" + ) + for _ in range(10): + do_optimization_step(model, optimizer, device) + batch = create_batch(device) + output = model(batch) + energy = output["energy"].detach().cpu().numpy() + + state = tools.CheckpointState( + model=model, optimizer=optimizer, lr_scheduler=lr_scheduler + ) + checkpoint_handler.save(state, epochs=0, keep_last=False) + checkpoint_handler.load_latest( + state=tools.CheckpointState(model, optimizer, lr_scheduler), + swa=False, + ) + batch = create_batch(device) + output = model(batch) + new_energy = output["energy"].detach().cpu().numpy() + assert np.allclose(energy, new_energy, atol=1e-9) diff --git a/mace-bench/3rdparty/mace/tests/test_tools.py b/mace-bench/3rdparty/mace/tests/test_tools.py index 50c1ee8..227a1bf 100644 --- a/mace-bench/3rdparty/mace/tests/test_tools.py +++ b/mace-bench/3rdparty/mace/tests/test_tools.py @@ -1,48 +1,48 @@ -import tempfile - -import numpy as np -import torch -import torch.nn.functional -from torch import nn, optim - -from mace.tools import ( - AtomicNumberTable, - CheckpointHandler, - CheckpointState, - atomic_numbers_to_indices, -) - - -def test_atomic_number_table(): - table = AtomicNumberTable(zs=[1, 8]) - array = np.array([8, 8, 1]) - indices = atomic_numbers_to_indices(array, z_table=table) - expected = np.array([1, 1, 0], dtype=int) - assert np.allclose(expected, indices) - - -class MyModel(nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(3, 4) - - def forward(self, x): - return torch.nn.functional.relu(self.linear(x)) - - -def test_save_load(): - model = MyModel() - initial_lr = 0.001 - optimizer = optim.SGD(model.parameters(), lr=initial_lr, momentum=0.9) - scheduler = optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.99) - - with tempfile.TemporaryDirectory() as directory: - handler = CheckpointHandler(directory=directory, tag="test", keep=True) - handler.save(state=CheckpointState(model, optimizer, scheduler), epochs=50) - - optimizer.step() - scheduler.step() - assert not np.isclose(optimizer.param_groups[0]["lr"], initial_lr) - - handler.load_latest(state=CheckpointState(model, optimizer, scheduler)) - assert np.isclose(optimizer.param_groups[0]["lr"], initial_lr) +import tempfile + +import numpy as np +import torch +import torch.nn.functional +from torch import nn, optim + +from mace.tools import ( + AtomicNumberTable, + CheckpointHandler, + CheckpointState, + atomic_numbers_to_indices, +) + + +def test_atomic_number_table(): + table = AtomicNumberTable(zs=[1, 8]) + array = np.array([8, 8, 1]) + indices = atomic_numbers_to_indices(array, z_table=table) + expected = np.array([1, 1, 0], dtype=int) + assert np.allclose(expected, indices) + + +class MyModel(nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(3, 4) + + def forward(self, x): + return torch.nn.functional.relu(self.linear(x)) + + +def test_save_load(): + model = MyModel() + initial_lr = 0.001 + optimizer = optim.SGD(model.parameters(), lr=initial_lr, momentum=0.9) + scheduler = optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.99) + + with tempfile.TemporaryDirectory() as directory: + handler = CheckpointHandler(directory=directory, tag="test", keep=True) + handler.save(state=CheckpointState(model, optimizer, scheduler), epochs=50) + + optimizer.step() + scheduler.step() + assert not np.isclose(optimizer.param_groups[0]["lr"], initial_lr) + + handler.load_latest(state=CheckpointState(model, optimizer, scheduler)) + assert np.isclose(optimizer.param_groups[0]["lr"], initial_lr) diff --git a/mace-bench/reproduce/init_7net.sh b/mace-bench/reproduce/init_7net.sh index 70aa632..22f19f1 100644 --- a/mace-bench/reproduce/init_7net.sh +++ b/mace-bench/reproduce/init_7net.sh @@ -1,11 +1,11 @@ -#!/bin/bash - -pip install torch_scatter==2.1.2+pt24cu121 -f https://pytorch-geometric.com/whl/torch-2.4.0+cu121.html -pip install torch_sparse==0.6.18+pt24cu121 -f https://pytorch-geometric.com/whl/torch-2.4.0+cu121.html -pip install torch_spline_conv==1.2.2+pt24cu121 -f https://pytorch-geometric.com/whl/torch-2.4.0+cu121.html -pip install -r requirements.txt -pip install -e 3rdparty/SevenNet -pip install -e . -pip install ase==3.23.0 -pip install ninja + + +pip install torch_scatter==2.1.2+pt24cu121 -f https://pytorch-geometric.com/whl/torch-2.4.0+cu121.html +pip install torch_sparse==0.6.18+pt24cu121 -f https://pytorch-geometric.com/whl/torch-2.4.0+cu121.html +pip install torch_spline_conv==1.2.2+pt24cu121 -f https://pytorch-geometric.com/whl/torch-2.4.0+cu121.html +pip install -r requirements.txt +pip install -e 3rdparty/SevenNet +pip install -e . +pip install ase==3.23.0 +pip install ninja pip install rdkit==2024.3.5 \ No newline at end of file diff --git a/mace-bench/reproduce/init_mace.sh b/mace-bench/reproduce/init_mace.sh index d7e9cba..b491a65 100644 --- a/mace-bench/reproduce/init_mace.sh +++ b/mace-bench/reproduce/init_mace.sh @@ -1,14 +1,12 @@ -#!/bin/bash - -pip install torch_scatter==2.1.2+pt24cu121 -f https://pytorch-geometric.com/whl/torch-2.4.0+cu121.html -pip install torch_sparse==0.6.18+pt24cu121 -f https://pytorch-geometric.com/whl/torch-2.4.0+cu121.html -pip install torch_spline_conv==1.2.2+pt24cu121 -f https://pytorch-geometric.com/whl/torch-2.4.0+cu121.html -pip install -r requirements.txt -pip install -e 3rdparty/mace -pip install -e . -pip install e3nn==0.4.4 -pip install ase==3.23.0 -pip install ninja - -# for python_CSP -pip install rdkit-pypi +pip install torch_scatter==2.1.2+pt24cu121 -f https://pytorch-geometric.com/whl/torch-2.4.0+cu121.html +pip install torch_sparse==0.6.18+pt24cu121 -f https://pytorch-geometric.com/whl/torch-2.4.0+cu121.html +pip install torch_spline_conv==1.2.2+pt24cu121 -f https://pytorch-geometric.com/whl/torch-2.4.0+cu121.html +pip install -r requirements.txt +pip install -e 3rdparty/mace +pip install -e . +pip install e3nn==0.4.4 +pip install ase==3.23.0 +pip install ninja + +# for python_CSP +pip install rdkit-pypi diff --git a/mace-bench/reproduce/mace_opt_new.py b/mace-bench/reproduce/mace_opt_new.py index 5b17949..4a6d9d7 100644 --- a/mace-bench/reproduce/mace_opt_new.py +++ b/mace-bench/reproduce/mace_opt_new.py @@ -1,300 +1,300 @@ -""" -Copyright (c) 2025 Ma Zhaojia - -This source code is licensed under the MIT license found in the -LICENSE file in the root directory of this source tree. -""" - -import os -os.environ['OMP_NUM_THREADS'] = '1' -os.environ['MKL_NUM_THREADS'] = '1' -os.environ['OPENBLAS_NUM_THREADS'] = '1' -import sys -# sys.path.append('/home/jiangj1group/zcxzcx1/volatile/mace') -from mace.calculators import mace_off, mace_mp -from ase.io import read, write -from ase.optimize import BFGS,LBFGS,FIRE,GPMin,MDMin, QuasiNewton -from ase.filters import UnitCellFilter, ExpCellFilter, FrechetCellFilter -import re -import io -from contextlib import redirect_stdout -import os -import pandas as pd -from joblib import Parallel, delayed -import json -import torch -import numpy as np -import random -import argparse -import time -import pathlib -import logging -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', force=True) -##################################################################### -os.environ['PYTHONHASHSEED'] = '1' -torch.manual_seed(1) -np.random.seed(1) -random.seed(1) -torch.cuda.manual_seed(1) -torch.cuda.manual_seed_all(1) -##################################################################### -# n_jobs=32 -# # n_jobs=2 -# path = './' -# molecule_single = 64 -# target_folder = "/data_raw/" -##################################################################### - -def calculate_density(crystal): - # 计算总质量,ASE 中的 get_masses 方法返回一个数组,包含了所有原子的质量 - total_mass = sum(crystal.get_masses()) # 转换为克 - - # 获取体积,ASE 的 get_volume 方法返回晶胞的体积,单位是 Å^3 - # 1 Å^3 = 1e-24 cm^3 - volume = crystal.get_volume() # 转换为立方厘米 - - # 计算密度,质量除以体积 - density = total_mass / (volume*10**-24)/(6.022140857*10**23) # 单位是 g/cm^3 - return density - -def run_calculation_one(path,file,target_folder,molecule_single,idx): - # os.environ['OMP_NUM_THREADS'] = '1' - # os.environ['MKL_NUM_THREADS'] = '1' - # os.environ['OPENBLAS_NUM_THREADS'] = '1' - if reproduce: - print("Reproducing deterministic results.") - torch.use_deterministic_algorithms(True) - os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" - np.set_printoptions(precision=17, suppress=False) - torch.set_printoptions(precision=17, sci_mode=False, linewidth=200) - if multithread and (not reproduce): - print("Using OMP and MKL multithreads will introduce non-deterministic results.") - else: - os.environ['OMP_NUM_THREADS'] = '1' - os.environ['MKL_NUM_THREADS'] = '1' - os.environ['OPENBLAS_NUM_THREADS'] = '1' - os.environ["CUDA_VISIBLE_DEVICES"]=str((idx%n_gpus)+gpu_offset) - - with io.StringIO() as buf, redirect_stdout(buf): - crystal = read(path+target_folder+file) - if molecule_single < 0: - molecule_single = int(file.split('_')[-1].split('.')[0]) - molecule_count = len(crystal.get_atomic_numbers())/molecule_single - calc = mace_off(model=model_path,dispersion=True, device='cuda', enable_cueq=cueq) - crystal.calc = calc - if filter1 == "UnitCellFilter": - sf = UnitCellFilter(crystal,scalar_pressure=0.0006) - elif filter1 == "FrechetCellFilter": - sf = FrechetCellFilter(crystal,scalar_pressure=0.0006) - else: - raise ValueError(f"Unrecognized filter type '{filter1}'. " - "Supported types are 'UnitCellFilter' and 'FrechetCellFilter'.") - if optimizer_type1 == "BFGS": - if use_cuda_eigh: - optimizer = BFGS(sf, use_cuda_eigh=True) - else: - optimizer = BFGS(sf) - elif optimizer_type1 == "LBFGS": - optimizer = LBFGS(sf) - elif optimizer_type1 == "QuasiNewton": - optimizer = QuasiNewton(sf) - else: - raise ValueError(f"Unrecognized optimizer type '{optimizer_type1}'. " - "Supported types are 'BFGS' and 'LBFGS'.") - - if use_nsys or use_torch_profiler : # warmup for profiling - optimizer.run(fmax=0.01,steps=100) - if use_torch_profiler: - profiler = torch.profiler.profile( - activities=[ - torch.profiler.ProfilerActivity.CPU, - torch.profiler.ProfilerActivity.CUDA - ], - # schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2), - on_trace_ready=torch.profiler.tensorboard_trace_handler('./log'), - with_stack=True - ) - profiler.start() - - start_time1 = time.time() - optimizer.run(fmax=0.01,steps=max_steps) - end_time1 = time.time() - - if use_torch_profiler: - profiler.stop() - - crystal.write(path+'cif_result_press/'+file[:-4]+"_press.cif") - output_1 = buf.getvalue() - # step_used_1 = float(re.split("\\s+", output_1.split('\n')[-2])[1][:]) - step_used_1 = optimizer.nsteps - if use_nsys or use_torch_profiler : - step_used_1 = step_used_1 - 100 - total_time1 = end_time1 - start_time1 - avg_time1 = total_time1 / step_used_1 if step_used_1 != 0 else 0 - - crystal = read(path+'cif_result_press/'+file[:-4]+"_press.cif") - crystal.calc = calc - if filter2 == "UnitCellFilter": - sf = UnitCellFilter(crystal) - elif filter2 == "FrechetCellFilter": - sf = FrechetCellFilter(crystal) - else: - raise ValueError(f"Unrecognized filter type '{filter2}'. " - "Supported types are 'UnitCellFilter' and 'FrechetCellFilter'.") - if optimizer_type2 == "BFGS": - if use_cuda_eigh: - optimizer = BFGS(sf, use_cuda_eigh=True) - else: - optimizer = BFGS(sf) - elif optimizer_type2 == "LBFGS": - optimizer = LBFGS(sf) - elif optimizer_type2 == "QuasiNewton": - optimizer = QuasiNewton(sf) - else: - raise ValueError(f"Unrecognized optimizer type '{optimizer_type2}'. " - "Supported types are 'BFGS' and 'LBFGS'.") - if use_torch_profiler: - profiler = torch.profiler.profile( - activities=[ - torch.profiler.ProfilerActivity.CPU, - torch.profiler.ProfilerActivity.CUDA - ], - # schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2), - on_trace_ready=torch.profiler.tensorboard_trace_handler('./log'), - with_stack=True - ) - profiler.start() - - start_time2 = time.time() - optimizer.run(fmax=0.01,steps=max_steps) - end_time2 = time.time() - - if use_torch_profiler: - profiler.stop() - - density = calculate_density(crystal) - crystal.write(path+'cif_result_final/'+file[:-4]+"_opt.cif") - output_2 = buf.getvalue() - energy = float(re.split("\\s+", output_2.split('\n')[-2])[3][:]) - # step_used_2 = float(re.split("\\s+", output_2.split('\n')[-2])[1][:]) - step_used_2 = optimizer.nsteps - energy_per_mol = energy / molecule_count * 96.485 - total_time2 = end_time2 - start_time2 - avg_time2 = total_time2 / step_used_2 if step_used_2 != 0 else 0 - - new_row = { - 'name': file[:-4], 'density': density, 'energy_kj': energy_per_mol, - 'step_used_1': step_used_1, 'step_used_2': step_used_2, - 'total_time1_s': total_time1, 'avg_time1_s': avg_time1, - 'total_time2_s': total_time2, 'avg_time2_s': avg_time2 - } - - print(f'output_2: {output_2}') - with open(path+'json_result/'+file[:-4]+".json", 'w') as json_file: - json.dump(new_row, json_file, indent=4) - return new_row - - -def already_have_calculation_one(path, file, target_folder, molecule_single, idx): - logging.info(f"reading on structure {file}") - print(f"reading on structure {file}") - with open(path + 'json_result/' + file[:-4] + ".json", 'r') as file: - old_row = json.load(file) - return old_row - -def run(): - df = pd.DataFrame(columns=['name', 'density', 'energy_kj', 'step_used_1', 'step_used_2', 'total_time1_s', 'avg_time1_s', 'total_time2_s', 'avg_time2_s']) - for root, dirs, files in os.walk(path + target_folder): - old_row = Parallel(n_jobs=n_jobs)( - delayed(already_have_calculation_one)(path, file, target_folder, molecule_single, idx) for idx, file in - enumerate(files) if os.path.exists(path + 'json_result/' + file[:-4] + ".json")) - - filtered_files = [file for file in files if not os.path.exists(path + 'json_result/' + file[:-4] + ".json")] - new_row = Parallel(n_jobs=n_jobs)( - delayed(run_calculation_one)(path, file, target_folder, molecule_single, idx) for idx, file in - enumerate(filtered_files)) - # show the length of new_row - print(f'new_row length: {len(new_row)}') - print(f'root: {root}\ndirs: {dirs}\nfiles: {files}') - for row in new_row: - df = pd.concat([df, pd.DataFrame([row])], ignore_index=True, axis=0) - for row in old_row: - df = pd.concat([df, pd.DataFrame([row])], ignore_index=True, axis=0) - df.to_csv(path + '/result.csv') - -if __name__ == '__main__': - parser = argparse.ArgumentParser(description="Run parallel calculations on molecular crystals.") - parser.add_argument("--n_jobs", type=int, default=32, help="Number of parallel jobs to run (default: 32)") - parser.add_argument("--target_folder", type=str, required=True, help="Path to the target folder containing input files") - parser.add_argument("--path", type=str, default='./', help="Base path for the project (default: './')") - parser.add_argument("--molecule_single", type=int, default=-1, help="Number of atoms per molecule (default: 64)") - parser.add_argument("--n_gpus", type=int, default=2, help="Number of GPUs to use (default: 2)") - parser.add_argument("--cueq", action='store_true', help="Whether to use cuEquivariance Library (default: False)") - parser.add_argument("--max_steps", type=int, default=3000, help="Number of max steps to run the optimization (default: 3000)") - parser.add_argument("--use_torch_profiler", action='store_true', help="Whether to use torch profiler (default: False)") - parser.add_argument("--use_nsys", action='store_true', help="Whether to use nsys profiler (default: False)") - parser.add_argument("--model", type=str, default="small", help="Model to use for the calculation (default: 'small')") - parser.add_argument("--optimizer", type=str, default="BFGS", help="Optimizer to use for the calculation (default: 'BFGS')") - parser.add_argument("--use_cuda_eigh", action='store_true', help="Whether to use CUDA for eigh (default: False)") - parser.add_argument("--gpu_offset", type=int, default=0, help="GPU offset to use for the calculation (default: 0)") - parser.add_argument("--multithread", action='store_true', help="Whether to use multithread (default: False)") - parser.add_argument("--reproduce", action='store_true', help="Whether to reproduce deterministic results (default: False)") - parser.add_argument("--filter1", type=str, default="UnitCellFilter", help="1st filter to use for the calculation (default: 'UnitCellFilter')") - parser.add_argument("--filter2", type=str, default="UnitCellFilter", help="2nd filter to use for the calculation (default: 'UnitCellFilter')") - parser.add_argument("--optimizer1", type=str, default="BFGS", help="1st optimizer to use for the calculation (default: 'BFGS')") - parser.add_argument("--optimizer2", type=str, default="BFGS", help="2nd optimizer to use for the calculation (default: 'BFGS')") - - args = parser.parse_args() - - n_jobs = args.n_jobs - target_folder = args.target_folder - path = args.path - molecule_single = args.molecule_single - n_gpus = args.n_gpus - cueq = args.cueq - max_steps = args.max_steps - use_torch_profiler = args.use_torch_profiler - use_nsys = args.use_nsys - model_path = args.model - optimizer_type = args.optimizer - use_cuda_eigh = args.use_cuda_eigh - gpu_offset = args.gpu_offset - multithread = args.multithread - reproduce = args.reproduce - filter1 = args.filter1 - filter2 = args.filter2 - optimizer_type1 = args.optimizer1 - optimizer_type2 = args.optimizer2 - - - try: - os.mkdir("./cif_result_press") - os.mkdir("./cif_result_final") - except: - pass - try: - os.mkdir("./json_result") - except: - pass - - start_time_all = time.time() - - - iter = 0 - while iter < 100: - iter += 1 - try: - run() - break - except Exception as e: - print(f"Error occurred: {e}") - print("Retrying...") - time.sleep(10) - - end_time_all = time.time() - total_time_all = end_time_all - start_time_all - print('dataset,total_time_all_s,attempts') - print(f"{pathlib.Path(target_folder).name},{total_time_all},{iter}") - with open(path + 'timing.csv', 'w') as f: - f.write('dataset,total_time_all_s,attempts\n') +""" +Copyright (c) 2025 Ma Zhaojia + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +import os +os.environ['OMP_NUM_THREADS'] = '1' +os.environ['MKL_NUM_THREADS'] = '1' +os.environ['OPENBLAS_NUM_THREADS'] = '1' +import sys +# sys.path.append('/home/jiangj1group/zcxzcx1/volatile/mace') +from mace.calculators import mace_off, mace_mp +from ase.io import read, write +from ase.optimize import BFGS,LBFGS,FIRE,GPMin,MDMin, QuasiNewton +from ase.filters import UnitCellFilter, ExpCellFilter, FrechetCellFilter +import re +import io +from contextlib import redirect_stdout +import os +import pandas as pd +from joblib import Parallel, delayed +import json +import torch +import numpy as np +import random +import argparse +import time +import pathlib +import logging +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', force=True) +##################################################################### +os.environ['PYTHONHASHSEED'] = '1' +torch.manual_seed(1) +np.random.seed(1) +random.seed(1) +torch.cuda.manual_seed(1) +torch.cuda.manual_seed_all(1) +##################################################################### +# n_jobs=32 +# # n_jobs=2 +# path = './' +# molecule_single = 64 +# target_folder = "/data_raw/" +##################################################################### + +def calculate_density(crystal): + # 计算总质量,ASE 中的 get_masses 方法返回一个数组,包含了所有原子的质量 + total_mass = sum(crystal.get_masses()) # 转换为克 + + # 获取体积,ASE 的 get_volume 方法返回晶胞的体积,单位是 Å^3 + # 1 Å^3 = 1e-24 cm^3 + volume = crystal.get_volume() # 转换为立方厘米 + + # 计算密度,质量除以体积 + density = total_mass / (volume*10**-24)/(6.022140857*10**23) # 单位是 g/cm^3 + return density + +def run_calculation_one(path,file,target_folder,molecule_single,idx): + # os.environ['OMP_NUM_THREADS'] = '1' + # os.environ['MKL_NUM_THREADS'] = '1' + # os.environ['OPENBLAS_NUM_THREADS'] = '1' + if reproduce: + print("Reproducing deterministic results.") + torch.use_deterministic_algorithms(True) + os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" + np.set_printoptions(precision=17, suppress=False) + torch.set_printoptions(precision=17, sci_mode=False, linewidth=200) + if multithread and (not reproduce): + print("Using OMP and MKL multithreads will introduce non-deterministic results.") + else: + os.environ['OMP_NUM_THREADS'] = '1' + os.environ['MKL_NUM_THREADS'] = '1' + os.environ['OPENBLAS_NUM_THREADS'] = '1' + os.environ["CUDA_VISIBLE_DEVICES"]=str((idx%n_gpus)+gpu_offset) + + with io.StringIO() as buf, redirect_stdout(buf): + crystal = read(path+target_folder+file) + if molecule_single < 0: + molecule_single = int(file.split('_')[-1].split('.')[0]) + molecule_count = len(crystal.get_atomic_numbers())/molecule_single + calc = mace_off(model=model_path,dispersion=True, device='cuda', enable_cueq=cueq) + crystal.calc = calc + if filter1 == "UnitCellFilter": + sf = UnitCellFilter(crystal,scalar_pressure=0.0006) + elif filter1 == "FrechetCellFilter": + sf = FrechetCellFilter(crystal,scalar_pressure=0.0006) + else: + raise ValueError(f"Unrecognized filter type '{filter1}'. " + "Supported types are 'UnitCellFilter' and 'FrechetCellFilter'.") + if optimizer_type1 == "BFGS": + if use_cuda_eigh: + optimizer = BFGS(sf, use_cuda_eigh=True) + else: + optimizer = BFGS(sf) + elif optimizer_type1 == "LBFGS": + optimizer = LBFGS(sf) + elif optimizer_type1 == "QuasiNewton": + optimizer = QuasiNewton(sf) + else: + raise ValueError(f"Unrecognized optimizer type '{optimizer_type1}'. " + "Supported types are 'BFGS' and 'LBFGS'.") + + if use_nsys or use_torch_profiler : # warmup for profiling + optimizer.run(fmax=0.01,steps=100) + if use_torch_profiler: + profiler = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA + ], + # schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2), + on_trace_ready=torch.profiler.tensorboard_trace_handler('./log'), + with_stack=True + ) + profiler.start() + + start_time1 = time.time() + optimizer.run(fmax=0.01,steps=max_steps) + end_time1 = time.time() + + if use_torch_profiler: + profiler.stop() + + crystal.write(path+'cif_result_press/'+file[:-4]+"_press.cif") + output_1 = buf.getvalue() + # step_used_1 = float(re.split("\\s+", output_1.split('\n')[-2])[1][:]) + step_used_1 = optimizer.nsteps + if use_nsys or use_torch_profiler : + step_used_1 = step_used_1 - 100 + total_time1 = end_time1 - start_time1 + avg_time1 = total_time1 / step_used_1 if step_used_1 != 0 else 0 + + crystal = read(path+'cif_result_press/'+file[:-4]+"_press.cif") + crystal.calc = calc + if filter2 == "UnitCellFilter": + sf = UnitCellFilter(crystal) + elif filter2 == "FrechetCellFilter": + sf = FrechetCellFilter(crystal) + else: + raise ValueError(f"Unrecognized filter type '{filter2}'. " + "Supported types are 'UnitCellFilter' and 'FrechetCellFilter'.") + if optimizer_type2 == "BFGS": + if use_cuda_eigh: + optimizer = BFGS(sf, use_cuda_eigh=True) + else: + optimizer = BFGS(sf) + elif optimizer_type2 == "LBFGS": + optimizer = LBFGS(sf) + elif optimizer_type2 == "QuasiNewton": + optimizer = QuasiNewton(sf) + else: + raise ValueError(f"Unrecognized optimizer type '{optimizer_type2}'. " + "Supported types are 'BFGS' and 'LBFGS'.") + if use_torch_profiler: + profiler = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA + ], + # schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2), + on_trace_ready=torch.profiler.tensorboard_trace_handler('./log'), + with_stack=True + ) + profiler.start() + + start_time2 = time.time() + optimizer.run(fmax=0.01,steps=max_steps) + end_time2 = time.time() + + if use_torch_profiler: + profiler.stop() + + density = calculate_density(crystal) + crystal.write(path+'cif_result_final/'+file[:-4]+"_opt.cif") + output_2 = buf.getvalue() + energy = float(re.split("\\s+", output_2.split('\n')[-2])[3][:]) + # step_used_2 = float(re.split("\\s+", output_2.split('\n')[-2])[1][:]) + step_used_2 = optimizer.nsteps + energy_per_mol = energy / molecule_count * 96.485 + total_time2 = end_time2 - start_time2 + avg_time2 = total_time2 / step_used_2 if step_used_2 != 0 else 0 + + new_row = { + 'name': file[:-4], 'density': density, 'energy_kj': energy_per_mol, + 'step_used_1': step_used_1, 'step_used_2': step_used_2, + 'total_time1_s': total_time1, 'avg_time1_s': avg_time1, + 'total_time2_s': total_time2, 'avg_time2_s': avg_time2 + } + + print(f'output_2: {output_2}') + with open(path+'json_result/'+file[:-4]+".json", 'w') as json_file: + json.dump(new_row, json_file, indent=4) + return new_row + + +def already_have_calculation_one(path, file, target_folder, molecule_single, idx): + logging.info(f"reading on structure {file}") + print(f"reading on structure {file}") + with open(path + 'json_result/' + file[:-4] + ".json", 'r') as file: + old_row = json.load(file) + return old_row + +def run(): + df = pd.DataFrame(columns=['name', 'density', 'energy_kj', 'step_used_1', 'step_used_2', 'total_time1_s', 'avg_time1_s', 'total_time2_s', 'avg_time2_s']) + for root, dirs, files in os.walk(path + target_folder): + old_row = Parallel(n_jobs=n_jobs)( + delayed(already_have_calculation_one)(path, file, target_folder, molecule_single, idx) for idx, file in + enumerate(files) if os.path.exists(path + 'json_result/' + file[:-4] + ".json")) + + filtered_files = [file for file in files if not os.path.exists(path + 'json_result/' + file[:-4] + ".json")] + new_row = Parallel(n_jobs=n_jobs)( + delayed(run_calculation_one)(path, file, target_folder, molecule_single, idx) for idx, file in + enumerate(filtered_files)) + # show the length of new_row + print(f'new_row length: {len(new_row)}') + print(f'root: {root}\ndirs: {dirs}\nfiles: {files}') + for row in new_row: + df = pd.concat([df, pd.DataFrame([row])], ignore_index=True, axis=0) + for row in old_row: + df = pd.concat([df, pd.DataFrame([row])], ignore_index=True, axis=0) + df.to_csv(path + '/result.csv') + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description="Run parallel calculations on molecular crystals.") + parser.add_argument("--n_jobs", type=int, default=32, help="Number of parallel jobs to run (default: 32)") + parser.add_argument("--target_folder", type=str, required=True, help="Path to the target folder containing input files") + parser.add_argument("--path", type=str, default='./', help="Base path for the project (default: './')") + parser.add_argument("--molecule_single", type=int, default=-1, help="Number of atoms per molecule (default: 64)") + parser.add_argument("--n_gpus", type=int, default=2, help="Number of GPUs to use (default: 2)") + parser.add_argument("--cueq", action='store_true', help="Whether to use cuEquivariance Library (default: False)") + parser.add_argument("--max_steps", type=int, default=3000, help="Number of max steps to run the optimization (default: 3000)") + parser.add_argument("--use_torch_profiler", action='store_true', help="Whether to use torch profiler (default: False)") + parser.add_argument("--use_nsys", action='store_true', help="Whether to use nsys profiler (default: False)") + parser.add_argument("--model", type=str, default="small", help="Model to use for the calculation (default: 'small')") + parser.add_argument("--optimizer", type=str, default="BFGS", help="Optimizer to use for the calculation (default: 'BFGS')") + parser.add_argument("--use_cuda_eigh", action='store_true', help="Whether to use CUDA for eigh (default: False)") + parser.add_argument("--gpu_offset", type=int, default=0, help="GPU offset to use for the calculation (default: 0)") + parser.add_argument("--multithread", action='store_true', help="Whether to use multithread (default: False)") + parser.add_argument("--reproduce", action='store_true', help="Whether to reproduce deterministic results (default: False)") + parser.add_argument("--filter1", type=str, default="UnitCellFilter", help="1st filter to use for the calculation (default: 'UnitCellFilter')") + parser.add_argument("--filter2", type=str, default="UnitCellFilter", help="2nd filter to use for the calculation (default: 'UnitCellFilter')") + parser.add_argument("--optimizer1", type=str, default="BFGS", help="1st optimizer to use for the calculation (default: 'BFGS')") + parser.add_argument("--optimizer2", type=str, default="BFGS", help="2nd optimizer to use for the calculation (default: 'BFGS')") + + args = parser.parse_args() + + n_jobs = args.n_jobs + target_folder = args.target_folder + path = args.path + molecule_single = args.molecule_single + n_gpus = args.n_gpus + cueq = args.cueq + max_steps = args.max_steps + use_torch_profiler = args.use_torch_profiler + use_nsys = args.use_nsys + model_path = args.model + optimizer_type = args.optimizer + use_cuda_eigh = args.use_cuda_eigh + gpu_offset = args.gpu_offset + multithread = args.multithread + reproduce = args.reproduce + filter1 = args.filter1 + filter2 = args.filter2 + optimizer_type1 = args.optimizer1 + optimizer_type2 = args.optimizer2 + + + try: + os.mkdir("./cif_result_press") + os.mkdir("./cif_result_final") + except: + pass + try: + os.mkdir("./json_result") + except: + pass + + start_time_all = time.time() + + + iter = 0 + while iter < 100: + iter += 1 + try: + run() + break + except Exception as e: + print(f"Error occurred: {e}") + print("Retrying...") + time.sleep(10) + + end_time_all = time.time() + total_time_all = end_time_all - start_time_all + print('dataset,total_time_all_s,attempts') + print(f"{pathlib.Path(target_folder).name},{total_time_all},{iter}") + with open(path + 'timing.csv', 'w') as f: + f.write('dataset,total_time_all_s,attempts\n') f.write(f"{pathlib.Path(target_folder).name},{total_time_all},{iter}\n") \ No newline at end of file diff --git a/mace-bench/reproduce/mace_opt_origin.py b/mace-bench/reproduce/mace_opt_origin.py index f6796e8..c059fb1 100644 --- a/mace-bench/reproduce/mace_opt_origin.py +++ b/mace-bench/reproduce/mace_opt_origin.py @@ -1,297 +1,297 @@ -""" -Copyright (c) 2025 Ma Zhaojia - -This source code is licensed under the MIT license found in the -LICENSE file in the root directory of this source tree. -""" - -import os -import sys -# sys.path.append('/home/jiangj1group/zcxzcx1/volatile/mace') -from mace.calculators import mace_off, mace_mp -from ase.io import read, write -from ase.optimize import BFGS,LBFGS,FIRE,GPMin,MDMin, QuasiNewton -from ase.filters import UnitCellFilter, ExpCellFilter, FrechetCellFilter -import re -import io -from contextlib import redirect_stdout -import os -import pandas as pd -from joblib import Parallel, delayed -import json -import torch -import numpy as np -import random -import argparse -import time -import pathlib -import logging -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', force=True) -##################################################################### -os.environ['PYTHONHASHSEED'] = '1' -torch.manual_seed(1) -np.random.seed(1) -random.seed(1) -torch.cuda.manual_seed(1) -torch.cuda.manual_seed_all(1) -##################################################################### -# n_jobs=32 -# # n_jobs=2 -# path = './' -# molecule_single = 64 -# target_folder = "/data_raw/" -##################################################################### - -def calculate_density(crystal): - # 计算总质量,ASE 中的 get_masses 方法返回一个数组,包含了所有原子的质量 - total_mass = sum(crystal.get_masses()) # 转换为克 - - # 获取体积,ASE 的 get_volume 方法返回晶胞的体积,单位是 Å^3 - # 1 Å^3 = 1e-24 cm^3 - volume = crystal.get_volume() # 转换为立方厘米 - - # 计算密度,质量除以体积 - density = total_mass / (volume*10**-24)/(6.022140857*10**23) # 单位是 g/cm^3 - return density - -def run_calculation_one(path,file,target_folder,molecule_single,idx): - # os.environ['OMP_NUM_THREADS'] = '1' - # os.environ['MKL_NUM_THREADS'] = '1' - # os.environ['OPENBLAS_NUM_THREADS'] = '1' - if reproduce: - print("Reproducing deterministic results.") - torch.use_deterministic_algorithms(True) - os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" - np.set_printoptions(precision=17, suppress=False) - torch.set_printoptions(precision=17, sci_mode=False, linewidth=200) - if multithread and (not reproduce): - print("Using OMP and MKL multithreads will introduce non-deterministic results.") - else: - os.environ['OMP_NUM_THREADS'] = '1' - os.environ['MKL_NUM_THREADS'] = '1' - os.environ['OPENBLAS_NUM_THREADS'] = '1' - os.environ["CUDA_VISIBLE_DEVICES"]=str((idx%n_gpus)+gpu_offset) - - with io.StringIO() as buf, redirect_stdout(buf): - crystal = read(path+target_folder+file) - if molecule_single < 0: - molecule_single = int(file.split('_')[-1].split('.')[0]) - molecule_count = len(crystal.get_atomic_numbers())/molecule_single - calc = mace_off(model=model_path,dispersion=True, device='cuda', enable_cueq=cueq) - crystal.calc = calc - if filter1 == "UnitCellFilter": - sf = UnitCellFilter(crystal,scalar_pressure=0.0006) - elif filter1 == "FrechetCellFilter": - sf = FrechetCellFilter(crystal,scalar_pressure=0.0006) - else: - raise ValueError(f"Unrecognized filter type '{filter1}'. " - "Supported types are 'UnitCellFilter' and 'FrechetCellFilter'.") - if optimizer_type1 == "BFGS": - if use_cuda_eigh: - optimizer = BFGS(sf, use_cuda_eigh=True) - else: - optimizer = BFGS(sf) - elif optimizer_type1 == "LBFGS": - optimizer = LBFGS(sf) - elif optimizer_type1 == "QuasiNewton": - optimizer = QuasiNewton(sf) - else: - raise ValueError(f"Unrecognized optimizer type '{optimizer_type1}'. " - "Supported types are 'BFGS' and 'LBFGS'.") - - if use_nsys or use_torch_profiler : # warmup for profiling - optimizer.run(fmax=0.01,steps=100) - if use_torch_profiler: - profiler = torch.profiler.profile( - activities=[ - torch.profiler.ProfilerActivity.CPU, - torch.profiler.ProfilerActivity.CUDA - ], - # schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2), - on_trace_ready=torch.profiler.tensorboard_trace_handler('./log'), - with_stack=True - ) - profiler.start() - - start_time1 = time.time() - optimizer.run(fmax=0.01,steps=max_steps) - end_time1 = time.time() - - if use_torch_profiler: - profiler.stop() - - crystal.write(path+'cif_result_press/'+file[:-4]+"_press.cif") - output_1 = buf.getvalue() - # step_used_1 = float(re.split("\\s+", output_1.split('\n')[-2])[1][:]) - step_used_1 = optimizer.nsteps - if use_nsys or use_torch_profiler : - step_used_1 = step_used_1 - 100 - total_time1 = end_time1 - start_time1 - avg_time1 = total_time1 / step_used_1 if step_used_1 != 0 else 0 - - crystal = read(path+'cif_result_press/'+file[:-4]+"_press.cif") - crystal.calc = calc - if filter2 == "UnitCellFilter": - sf = UnitCellFilter(crystal) - elif filter2 == "FrechetCellFilter": - sf = FrechetCellFilter(crystal) - else: - raise ValueError(f"Unrecognized filter type '{filter2}'. " - "Supported types are 'UnitCellFilter' and 'FrechetCellFilter'.") - if optimizer_type2 == "BFGS": - if use_cuda_eigh: - optimizer = BFGS(sf, use_cuda_eigh=True) - else: - optimizer = BFGS(sf) - elif optimizer_type2 == "LBFGS": - optimizer = LBFGS(sf) - elif optimizer_type2 == "QuasiNewton": - optimizer = QuasiNewton(sf) - else: - raise ValueError(f"Unrecognized optimizer type '{optimizer_type2}'. " - "Supported types are 'BFGS' and 'LBFGS'.") - if use_torch_profiler: - profiler = torch.profiler.profile( - activities=[ - torch.profiler.ProfilerActivity.CPU, - torch.profiler.ProfilerActivity.CUDA - ], - # schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2), - on_trace_ready=torch.profiler.tensorboard_trace_handler('./log'), - with_stack=True - ) - profiler.start() - - start_time2 = time.time() - optimizer.run(fmax=0.01,steps=max_steps) - end_time2 = time.time() - - if use_torch_profiler: - profiler.stop() - - density = calculate_density(crystal) - crystal.write(path+'cif_result_final/'+file[:-4]+"_opt.cif") - output_2 = buf.getvalue() - energy = float(re.split("\\s+", output_2.split('\n')[-2])[3][:]) - # step_used_2 = float(re.split("\\s+", output_2.split('\n')[-2])[1][:]) - step_used_2 = optimizer.nsteps - energy_per_mol = energy / molecule_count * 96.485 - total_time2 = end_time2 - start_time2 - avg_time2 = total_time2 / step_used_2 if step_used_2 != 0 else 0 - - new_row = { - 'name': file[:-4], 'density': density, 'energy_kj': energy_per_mol, - 'step_used_1': step_used_1, 'step_used_2': step_used_2, - 'total_time1_s': total_time1, 'avg_time1_s': avg_time1, - 'total_time2_s': total_time2, 'avg_time2_s': avg_time2 - } - - print(f'output_2: {output_2}') - with open(path+'json_result/'+file[:-4]+".json", 'w') as json_file: - json.dump(new_row, json_file, indent=4) - return new_row - - -def already_have_calculation_one(path, file, target_folder, molecule_single, idx): - logging.info(f"reading on structure {file}") - print(f"reading on structure {file}") - with open(path + 'json_result/' + file[:-4] + ".json", 'r') as file: - old_row = json.load(file) - return old_row - -def run(): - df = pd.DataFrame(columns=['name', 'density', 'energy_kj', 'step_used_1', 'step_used_2', 'total_time1_s', 'avg_time1_s', 'total_time2_s', 'avg_time2_s']) - for root, dirs, files in os.walk(path + target_folder): - old_row = Parallel(n_jobs=n_jobs)( - delayed(already_have_calculation_one)(path, file, target_folder, molecule_single, idx) for idx, file in - enumerate(files) if os.path.exists(path + 'json_result/' + file[:-4] + ".json")) - - filtered_files = [file for file in files if not os.path.exists(path + 'json_result/' + file[:-4] + ".json")] - new_row = Parallel(n_jobs=n_jobs)( - delayed(run_calculation_one)(path, file, target_folder, molecule_single, idx) for idx, file in - enumerate(filtered_files)) - # show the length of new_row - print(f'new_row length: {len(new_row)}') - print(f'root: {root}\ndirs: {dirs}\nfiles: {files}') - for row in new_row: - df = pd.concat([df, pd.DataFrame([row])], ignore_index=True, axis=0) - for row in old_row: - df = pd.concat([df, pd.DataFrame([row])], ignore_index=True, axis=0) - df.to_csv(path + '/result.csv') - -if __name__ == '__main__': - parser = argparse.ArgumentParser(description="Run parallel calculations on molecular crystals.") - parser.add_argument("--n_jobs", type=int, default=32, help="Number of parallel jobs to run (default: 32)") - parser.add_argument("--target_folder", type=str, required=True, help="Path to the target folder containing input files") - parser.add_argument("--path", type=str, default='./', help="Base path for the project (default: './')") - parser.add_argument("--molecule_single", type=int, default=-1, help="Number of atoms per molecule (default: 64)") - parser.add_argument("--n_gpus", type=int, default=2, help="Number of GPUs to use (default: 2)") - parser.add_argument("--cueq", action='store_true', help="Whether to use cuEquivariance Library (default: False)") - parser.add_argument("--max_steps", type=int, default=3000, help="Number of max steps to run the optimization (default: 3000)") - parser.add_argument("--use_torch_profiler", action='store_true', help="Whether to use torch profiler (default: False)") - parser.add_argument("--use_nsys", action='store_true', help="Whether to use nsys profiler (default: False)") - parser.add_argument("--model", type=str, default="small", help="Model to use for the calculation (default: 'small')") - parser.add_argument("--optimizer", type=str, default="BFGS", help="Optimizer to use for the calculation (default: 'BFGS')") - parser.add_argument("--use_cuda_eigh", action='store_true', help="Whether to use CUDA for eigh (default: False)") - parser.add_argument("--gpu_offset", type=int, default=0, help="GPU offset to use for the calculation (default: 0)") - parser.add_argument("--multithread", action='store_true', help="Whether to use multithread (default: False)") - parser.add_argument("--reproduce", action='store_true', help="Whether to reproduce deterministic results (default: False)") - parser.add_argument("--filter1", type=str, default="UnitCellFilter", help="1st filter to use for the calculation (default: 'UnitCellFilter')") - parser.add_argument("--filter2", type=str, default="UnitCellFilter", help="2nd filter to use for the calculation (default: 'UnitCellFilter')") - parser.add_argument("--optimizer1", type=str, default="BFGS", help="1st optimizer to use for the calculation (default: 'BFGS')") - parser.add_argument("--optimizer2", type=str, default="BFGS", help="2nd optimizer to use for the calculation (default: 'BFGS')") - - args = parser.parse_args() - - n_jobs = args.n_jobs - target_folder = args.target_folder - path = args.path - molecule_single = args.molecule_single - n_gpus = args.n_gpus - cueq = args.cueq - max_steps = args.max_steps - use_torch_profiler = args.use_torch_profiler - use_nsys = args.use_nsys - model_path = args.model - optimizer_type = args.optimizer - use_cuda_eigh = args.use_cuda_eigh - gpu_offset = args.gpu_offset - multithread = args.multithread - reproduce = args.reproduce - filter1 = args.filter1 - filter2 = args.filter2 - optimizer_type1 = args.optimizer1 - optimizer_type2 = args.optimizer2 - - - try: - os.mkdir("./cif_result_press") - os.mkdir("./cif_result_final") - except: - pass - try: - os.mkdir("./json_result") - except: - pass - - start_time_all = time.time() - - - iter = 0 - while iter < 100: - iter += 1 - try: - run() - break - except Exception as e: - print(f"Error occurred: {e}") - print("Retrying...") - time.sleep(10) - - end_time_all = time.time() - total_time_all = end_time_all - start_time_all - print('dataset,total_time_all_s,attempts') - print(f"{pathlib.Path(target_folder).name},{total_time_all},{iter}") - with open(path + 'timing.csv', 'w') as f: - f.write('dataset,total_time_all_s,attempts\n') +""" +Copyright (c) 2025 Ma Zhaojia + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +import os +import sys +# sys.path.append('/home/jiangj1group/zcxzcx1/volatile/mace') +from mace.calculators import mace_off, mace_mp +from ase.io import read, write +from ase.optimize import BFGS,LBFGS,FIRE,GPMin,MDMin, QuasiNewton +from ase.filters import UnitCellFilter, ExpCellFilter, FrechetCellFilter +import re +import io +from contextlib import redirect_stdout +import os +import pandas as pd +from joblib import Parallel, delayed +import json +import torch +import numpy as np +import random +import argparse +import time +import pathlib +import logging +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', force=True) +##################################################################### +os.environ['PYTHONHASHSEED'] = '1' +torch.manual_seed(1) +np.random.seed(1) +random.seed(1) +torch.cuda.manual_seed(1) +torch.cuda.manual_seed_all(1) +##################################################################### +# n_jobs=32 +# # n_jobs=2 +# path = './' +# molecule_single = 64 +# target_folder = "/data_raw/" +##################################################################### + +def calculate_density(crystal): + # 计算总质量,ASE 中的 get_masses 方法返回一个数组,包含了所有原子的质量 + total_mass = sum(crystal.get_masses()) # 转换为克 + + # 获取体积,ASE 的 get_volume 方法返回晶胞的体积,单位是 Å^3 + # 1 Å^3 = 1e-24 cm^3 + volume = crystal.get_volume() # 转换为立方厘米 + + # 计算密度,质量除以体积 + density = total_mass / (volume*10**-24)/(6.022140857*10**23) # 单位是 g/cm^3 + return density + +def run_calculation_one(path,file,target_folder,molecule_single,idx): + # os.environ['OMP_NUM_THREADS'] = '1' + # os.environ['MKL_NUM_THREADS'] = '1' + # os.environ['OPENBLAS_NUM_THREADS'] = '1' + if reproduce: + print("Reproducing deterministic results.") + torch.use_deterministic_algorithms(True) + os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" + np.set_printoptions(precision=17, suppress=False) + torch.set_printoptions(precision=17, sci_mode=False, linewidth=200) + if multithread and (not reproduce): + print("Using OMP and MKL multithreads will introduce non-deterministic results.") + else: + os.environ['OMP_NUM_THREADS'] = '1' + os.environ['MKL_NUM_THREADS'] = '1' + os.environ['OPENBLAS_NUM_THREADS'] = '1' + os.environ["CUDA_VISIBLE_DEVICES"]=str((idx%n_gpus)+gpu_offset) + + with io.StringIO() as buf, redirect_stdout(buf): + crystal = read(path+target_folder+file) + if molecule_single < 0: + molecule_single = int(file.split('_')[-1].split('.')[0]) + molecule_count = len(crystal.get_atomic_numbers())/molecule_single + calc = mace_off(model=model_path,dispersion=True, device='cuda', enable_cueq=cueq) + crystal.calc = calc + if filter1 == "UnitCellFilter": + sf = UnitCellFilter(crystal,scalar_pressure=0.0006) + elif filter1 == "FrechetCellFilter": + sf = FrechetCellFilter(crystal,scalar_pressure=0.0006) + else: + raise ValueError(f"Unrecognized filter type '{filter1}'. " + "Supported types are 'UnitCellFilter' and 'FrechetCellFilter'.") + if optimizer_type1 == "BFGS": + if use_cuda_eigh: + optimizer = BFGS(sf, use_cuda_eigh=True) + else: + optimizer = BFGS(sf) + elif optimizer_type1 == "LBFGS": + optimizer = LBFGS(sf) + elif optimizer_type1 == "QuasiNewton": + optimizer = QuasiNewton(sf) + else: + raise ValueError(f"Unrecognized optimizer type '{optimizer_type1}'. " + "Supported types are 'BFGS' and 'LBFGS'.") + + if use_nsys or use_torch_profiler : # warmup for profiling + optimizer.run(fmax=0.01,steps=100) + if use_torch_profiler: + profiler = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA + ], + # schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2), + on_trace_ready=torch.profiler.tensorboard_trace_handler('./log'), + with_stack=True + ) + profiler.start() + + start_time1 = time.time() + optimizer.run(fmax=0.01,steps=max_steps) + end_time1 = time.time() + + if use_torch_profiler: + profiler.stop() + + crystal.write(path+'cif_result_press/'+file[:-4]+"_press.cif") + output_1 = buf.getvalue() + # step_used_1 = float(re.split("\\s+", output_1.split('\n')[-2])[1][:]) + step_used_1 = optimizer.nsteps + if use_nsys or use_torch_profiler : + step_used_1 = step_used_1 - 100 + total_time1 = end_time1 - start_time1 + avg_time1 = total_time1 / step_used_1 if step_used_1 != 0 else 0 + + crystal = read(path+'cif_result_press/'+file[:-4]+"_press.cif") + crystal.calc = calc + if filter2 == "UnitCellFilter": + sf = UnitCellFilter(crystal) + elif filter2 == "FrechetCellFilter": + sf = FrechetCellFilter(crystal) + else: + raise ValueError(f"Unrecognized filter type '{filter2}'. " + "Supported types are 'UnitCellFilter' and 'FrechetCellFilter'.") + if optimizer_type2 == "BFGS": + if use_cuda_eigh: + optimizer = BFGS(sf, use_cuda_eigh=True) + else: + optimizer = BFGS(sf) + elif optimizer_type2 == "LBFGS": + optimizer = LBFGS(sf) + elif optimizer_type2 == "QuasiNewton": + optimizer = QuasiNewton(sf) + else: + raise ValueError(f"Unrecognized optimizer type '{optimizer_type2}'. " + "Supported types are 'BFGS' and 'LBFGS'.") + if use_torch_profiler: + profiler = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA + ], + # schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2), + on_trace_ready=torch.profiler.tensorboard_trace_handler('./log'), + with_stack=True + ) + profiler.start() + + start_time2 = time.time() + optimizer.run(fmax=0.01,steps=max_steps) + end_time2 = time.time() + + if use_torch_profiler: + profiler.stop() + + density = calculate_density(crystal) + crystal.write(path+'cif_result_final/'+file[:-4]+"_opt.cif") + output_2 = buf.getvalue() + energy = float(re.split("\\s+", output_2.split('\n')[-2])[3][:]) + # step_used_2 = float(re.split("\\s+", output_2.split('\n')[-2])[1][:]) + step_used_2 = optimizer.nsteps + energy_per_mol = energy / molecule_count * 96.485 + total_time2 = end_time2 - start_time2 + avg_time2 = total_time2 / step_used_2 if step_used_2 != 0 else 0 + + new_row = { + 'name': file[:-4], 'density': density, 'energy_kj': energy_per_mol, + 'step_used_1': step_used_1, 'step_used_2': step_used_2, + 'total_time1_s': total_time1, 'avg_time1_s': avg_time1, + 'total_time2_s': total_time2, 'avg_time2_s': avg_time2 + } + + print(f'output_2: {output_2}') + with open(path+'json_result/'+file[:-4]+".json", 'w') as json_file: + json.dump(new_row, json_file, indent=4) + return new_row + + +def already_have_calculation_one(path, file, target_folder, molecule_single, idx): + logging.info(f"reading on structure {file}") + print(f"reading on structure {file}") + with open(path + 'json_result/' + file[:-4] + ".json", 'r') as file: + old_row = json.load(file) + return old_row + +def run(): + df = pd.DataFrame(columns=['name', 'density', 'energy_kj', 'step_used_1', 'step_used_2', 'total_time1_s', 'avg_time1_s', 'total_time2_s', 'avg_time2_s']) + for root, dirs, files in os.walk(path + target_folder): + old_row = Parallel(n_jobs=n_jobs)( + delayed(already_have_calculation_one)(path, file, target_folder, molecule_single, idx) for idx, file in + enumerate(files) if os.path.exists(path + 'json_result/' + file[:-4] + ".json")) + + filtered_files = [file for file in files if not os.path.exists(path + 'json_result/' + file[:-4] + ".json")] + new_row = Parallel(n_jobs=n_jobs)( + delayed(run_calculation_one)(path, file, target_folder, molecule_single, idx) for idx, file in + enumerate(filtered_files)) + # show the length of new_row + print(f'new_row length: {len(new_row)}') + print(f'root: {root}\ndirs: {dirs}\nfiles: {files}') + for row in new_row: + df = pd.concat([df, pd.DataFrame([row])], ignore_index=True, axis=0) + for row in old_row: + df = pd.concat([df, pd.DataFrame([row])], ignore_index=True, axis=0) + df.to_csv(path + '/result.csv') + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description="Run parallel calculations on molecular crystals.") + parser.add_argument("--n_jobs", type=int, default=32, help="Number of parallel jobs to run (default: 32)") + parser.add_argument("--target_folder", type=str, required=True, help="Path to the target folder containing input files") + parser.add_argument("--path", type=str, default='./', help="Base path for the project (default: './')") + parser.add_argument("--molecule_single", type=int, default=-1, help="Number of atoms per molecule (default: 64)") + parser.add_argument("--n_gpus", type=int, default=2, help="Number of GPUs to use (default: 2)") + parser.add_argument("--cueq", action='store_true', help="Whether to use cuEquivariance Library (default: False)") + parser.add_argument("--max_steps", type=int, default=3000, help="Number of max steps to run the optimization (default: 3000)") + parser.add_argument("--use_torch_profiler", action='store_true', help="Whether to use torch profiler (default: False)") + parser.add_argument("--use_nsys", action='store_true', help="Whether to use nsys profiler (default: False)") + parser.add_argument("--model", type=str, default="small", help="Model to use for the calculation (default: 'small')") + parser.add_argument("--optimizer", type=str, default="BFGS", help="Optimizer to use for the calculation (default: 'BFGS')") + parser.add_argument("--use_cuda_eigh", action='store_true', help="Whether to use CUDA for eigh (default: False)") + parser.add_argument("--gpu_offset", type=int, default=0, help="GPU offset to use for the calculation (default: 0)") + parser.add_argument("--multithread", action='store_true', help="Whether to use multithread (default: False)") + parser.add_argument("--reproduce", action='store_true', help="Whether to reproduce deterministic results (default: False)") + parser.add_argument("--filter1", type=str, default="UnitCellFilter", help="1st filter to use for the calculation (default: 'UnitCellFilter')") + parser.add_argument("--filter2", type=str, default="UnitCellFilter", help="2nd filter to use for the calculation (default: 'UnitCellFilter')") + parser.add_argument("--optimizer1", type=str, default="BFGS", help="1st optimizer to use for the calculation (default: 'BFGS')") + parser.add_argument("--optimizer2", type=str, default="BFGS", help="2nd optimizer to use for the calculation (default: 'BFGS')") + + args = parser.parse_args() + + n_jobs = args.n_jobs + target_folder = args.target_folder + path = args.path + molecule_single = args.molecule_single + n_gpus = args.n_gpus + cueq = args.cueq + max_steps = args.max_steps + use_torch_profiler = args.use_torch_profiler + use_nsys = args.use_nsys + model_path = args.model + optimizer_type = args.optimizer + use_cuda_eigh = args.use_cuda_eigh + gpu_offset = args.gpu_offset + multithread = args.multithread + reproduce = args.reproduce + filter1 = args.filter1 + filter2 = args.filter2 + optimizer_type1 = args.optimizer1 + optimizer_type2 = args.optimizer2 + + + try: + os.mkdir("./cif_result_press") + os.mkdir("./cif_result_final") + except: + pass + try: + os.mkdir("./json_result") + except: + pass + + start_time_all = time.time() + + + iter = 0 + while iter < 100: + iter += 1 + try: + run() + break + except Exception as e: + print(f"Error occurred: {e}") + print("Retrying...") + time.sleep(10) + + end_time_all = time.time() + total_time_all = end_time_all - start_time_all + print('dataset,total_time_all_s,attempts') + print(f"{pathlib.Path(target_folder).name},{total_time_all},{iter}") + with open(path + 'timing.csv', 'w') as f: + f.write('dataset,total_time_all_s,attempts\n') f.write(f"{pathlib.Path(target_folder).name},{total_time_all},{iter}\n") \ No newline at end of file diff --git a/mace-bench/reproduce/perf_v2_base/run_mace.sh b/mace-bench/reproduce/perf_v2_base/run_mace.sh index b886fe1..c9e81c8 100644 --- a/mace-bench/reproduce/perf_v2_base/run_mace.sh +++ b/mace-bench/reproduce/perf_v2_base/run_mace.sh @@ -1,5 +1,4 @@ -#!/bin/bash - -python ../mace_opt_new.py --n_jobs 64 --molecule_single 46 \ - --target_folder ../../data/perf_v2/ --model small --n_gpus 4 --gpu_offset 0 \ + +python ../mace_opt_new.py --n_jobs 64 --molecule_single 46 \ + --target_folder ../../data/perf_v2/ --model small --n_gpus 4 --gpu_offset 0 \ --optimizer1 QuasiNewton --filter1 UnitCellFilter --filter2 UnitCellFilter \ No newline at end of file diff --git a/mace-bench/reproduce/perf_v2_batch/opt.sh b/mace-bench/reproduce/perf_v2_batch/opt.sh index 51e8f97..9c8c172 100644 --- a/mace-bench/reproduce/perf_v2_batch/opt.sh +++ b/mace-bench/reproduce/perf_v2_batch/opt.sh @@ -1,6 +1,5 @@ -#!/bin/bash - -rm -r *_result_* - -python ../../scripts/mace_opt_batch.py --target_folder "../../data/perf_v2" --molecule_single 46 --gpu_offset 0 --n_gpus 4 --num_workers 40 --batch_size 0 \ + +rm -r *_result_* + +python ../../scripts/mace_opt_batch.py --target_folder "../../data/perf_v2" --molecule_single 46 --gpu_offset 0 --n_gpus 4 --num_workers 40 --batch_size 0 \ --max_steps 6000 --filter1 UnitCellFilter --filter2 UnitCellFilter --optimizer1 BFGSFusedLS --optimizer2 BFGS --num_threads 2 --cueq true --use_ordered_files true \ No newline at end of file diff --git a/mace-bench/reproduce/subtest.sh b/mace-bench/reproduce/subtest.sh index f9c74c7..8aaa930 100644 --- a/mace-bench/reproduce/subtest.sh +++ b/mace-bench/reproduce/subtest.sh @@ -1,25 +1,25 @@ -#!/bin/bash - -top_dir=$(pwd) - -natoms_nw_bs=( - "92 48 25" - "184 40 12" - "368 40 5" -) - -for config in "${natoms_nw_bs[@]}"; do - read natoms nw bs <<< "$config" - - dir="$top_dir/subtest_BATCH_${natoms}_g4_j${nw}_bs${bs}_cueq_cupbc" - mkdir -p "$dir" - cd "$dir" || continue - - pwd - python ../../scripts/mace_opt_batch.py \ - --target_folder "../../data/perf_v2_sorted/perf_v2_${natoms}" \ - --molecule_single 46 --gpu_offset 0 --n_gpus 4 --num_workers ${nw} --batch_size ${bs} \ - --max_steps 6000 --filter1 UnitCellFilter --filter2 UnitCellFilter \ - --optimizer1 BFGSFusedLS --optimizer2 BFGS --num_threads 2 \ - --use_ordered_files true --cueq true > opt.log 2>&1 + + +top_dir=$(pwd) + +natoms_nw_bs=( + "92 48 25" + "184 40 12" + "368 40 5" +) + +for config in "${natoms_nw_bs[@]}"; do + read natoms nw bs <<< "$config" + + dir="$top_dir/subtest_BATCH_${natoms}_g4_j${nw}_bs${bs}_cueq_cupbc" + mkdir -p "$dir" + cd "$dir" || continue + + pwd + python ../../scripts/mace_opt_batch.py \ + --target_folder "../../data/perf_v2_sorted/perf_v2_${natoms}" \ + --molecule_single 46 --gpu_offset 0 --n_gpus 4 --num_workers ${nw} --batch_size ${bs} \ + --max_steps 6000 --filter1 UnitCellFilter --filter2 UnitCellFilter \ + --optimizer1 BFGSFusedLS --optimizer2 BFGS --num_threads 2 \ + --use_ordered_files true --cueq true > opt.log 2>&1 done \ No newline at end of file diff --git a/mace-bench/reproduce/subtest_baseline.sh b/mace-bench/reproduce/subtest_baseline.sh index 64761d4..08e4c49 100644 --- a/mace-bench/reproduce/subtest_baseline.sh +++ b/mace-bench/reproduce/subtest_baseline.sh @@ -1,24 +1,24 @@ -#!/bin/bash - -top_dir=$(pwd) - -natoms_nw_bs=( - "92 64" - "184 64" - "368 64" -) - -for config in "${natoms_nw_bs[@]}"; do - read natoms nw <<< "$config" - - dir="$top_dir/subtest_BASE_${natoms}_g4_j${nw}" - mkdir -p "$dir" - cd "$dir" || continue - - pwd - - python ../mace_opt_new.py --n_jobs ${nw} --molecule_single 46 \ - --target_folder ../../data/perf_v2_sorted/perf_v2_${natoms}/ --model small --n_gpus 4 \ - --gpu_offset 0 --optimizer1 QuasiNewton --filter1 UnitCellFilter \ - --filter2 UnitCellFilter --max_steps 3000 > opt.log 2>&1 + + +top_dir=$(pwd) + +natoms_nw_bs=( + "92 64" + "184 64" + "368 64" +) + +for config in "${natoms_nw_bs[@]}"; do + read natoms nw <<< "$config" + + dir="$top_dir/subtest_BASE_${natoms}_g4_j${nw}" + mkdir -p "$dir" + cd "$dir" || continue + + pwd + + python ../mace_opt_new.py --n_jobs ${nw} --molecule_single 46 \ + --target_folder ../../data/perf_v2_sorted/perf_v2_${natoms}/ --model small --n_gpus 4 \ + --gpu_offset 0 --optimizer1 QuasiNewton --filter1 UnitCellFilter \ + --filter2 UnitCellFilter --max_steps 3000 > opt.log 2>&1 done \ No newline at end of file diff --git a/mace-bench/requirements.txt b/mace-bench/requirements.txt index aa57e95..135b56b 100644 --- a/mace-bench/requirements.txt +++ b/mace-bench/requirements.txt @@ -1,137 +1,137 @@ ---extra-index-url https://download.pytorch.org/whl/cu121 -absl-py==2.1.0 -aiohappyeyeballs==2.4.4 -aiohttp==3.11.11 -aiosignal==1.3.2 -annotated-types==0.7.0 -antlr4-python3-runtime==4.9.3 -# -e git+https://gitlab.com/ase/ase.git@72c50c76bac2396c7d58385b231c65bd07458279#egg=ase&subdirectory=../../../3rdparty/ase -async-timeout==5.0.1 -attrs==24.3.0 -certifi==2024.8.30 -cfgv==3.4.0 -charset-normalizer==3.4.0 -click==8.1.8 -cloudpickle==3.1.0 -ConfigArgParse==1.7 -contourpy==1.3.1 -coverage==7.6.9 -cuequivariance==0.4.0 -cuequivariance-ops-torch-cu12==0.4.0 -cuequivariance-ops-cu12==0.4.0 -cuequivariance-torch==0.4.0 -cycler==0.12.1 -distlib==0.3.9 -docker-pycreds==0.4.0 -e3nn==0.4.4 -exceptiongroup==1.2.2 -# -e git+https://github.com/mazhaojia123/fairchem.git@f50db9d5b29debdfb265d9c3fad394f18e16cab8#egg=fairchem_core&subdirectory=../../../3rdparty/fairchem/packages/fairchem-core -filelock==3.13.1 -fonttools==4.55.1 -frozenlist==1.5.0 -fsspec==2024.2.0 -gitdb==4.0.11 -GitPython==3.1.43 -grpcio==1.68.1 -h5py==3.12.1 -hydra-core==1.3.2 -identify==2.6.3 -idna==3.10 -iniconfig==2.0.0 -Jinja2==3.1.3 -joblib==1.4.2 -kiwisolver==1.4.7 -latexcodec==3.0.0 -lightning-utilities==0.11.9 -llvmlite==0.43.0 -lmdb==1.5.1 -# -e git+https://github.com/mazhaojia123/mace.git@edd6b479f4974d0b8162712872ad2eed1aa2fb75#egg=mace_torch&subdirectory=../../../3rdparty/mace -Markdown==3.7 -MarkupSafe==2.1.5 -matplotlib==3.9.3 -matscipy==1.1.1 -monty==2024.10.21 -mpmath==1.3.0 -multidict==6.1.0 -networkx==3.2.1 -nodeenv==1.9.1 -numba==0.60.0 -numpy==1.26.4 -nvidia-cublas-cu12==12.1.3.1 -nvidia-cuda-cupti-cu12==12.1.105 -nvidia-cuda-nvrtc-cu12==12.1.105 -nvidia-cuda-runtime-cu12==12.1.105 -nvidia-cudnn-cu12==9.1.0.70 -nvidia-cufft-cu12==11.0.2.54 -nvidia-curand-cu12==10.3.2.106 -nvidia-cusolver-cu12==11.4.5.107 -nvidia-cusparse-cu12==12.1.0.106 -nvidia-nccl-cu12==2.20.5 -nvidia-nvjitlink-cu12==12.1.105 -nvidia-nvtx-cu12==12.1.105 -omegaconf==2.3.0 -opt-einsum-fx==0.1.4 -opt_einsum==3.4.0 -orjson==3.10.12 -packaging==24.2 -palettable==3.3.3 -pandas==2.2.3 -pillow==11.0.0 -platformdirs==4.3.6 -plotly==5.24.1 -pluggy==1.5.0 -pre_commit==4.0.1 -prettytable==3.12.0 -propcache==0.2.1 -protobuf==5.29.2 -psutil==6.1.1 -pybtex==0.24.0 -pydantic==2.10.4 -pydantic_core==2.27.2 -pymatgen==2024.11.13 -pyparsing==3.2.0 -pytest==8.3.4 -pytest-cov==6.0.0 -python-dateutil==2.9.0.post0 -python-hostlist==2.0.0 -pytz==2024.2 -PyYAML==6.0.2 -requests==2.32.3 -ruamel.yaml==0.18.6 -ruamel.yaml.clib==0.2.12 -ruff==0.5.1 -scipy==1.14.1 -sentry-sdk==2.19.2 -setproctitle==1.3.4 -six==1.16.0 -smmap==5.0.1 -spglib==2.5.0 -submitit==1.5.2 -sympy==1.13.1 -syrupy==4.8.0 -tabulate==0.9.0 -tenacity==9.0.0 -tensorboard==2.18.0 -tensorboard-data-server==0.7.2 -tomli==2.2.1 -torch==2.4.1+cu121 -# ./torch-2.4.1+cu121-cp310-cp310-linux_x86_64.whl -torch-dftd==0.5.1 -torch-ema==0.3 -torch-geometric==2.6.1 -# torch_scatter==2.1.2+pt24cu121 -# torch_sparse==0.6.18+pt24cu121 -# torch_spline_conv==1.2.2+pt24cu121 -torchmetrics==1.6.0 -tqdm==4.67.1 -triton==3.0.0 -typing_extensions==4.12.2 -tzdata==2024.2 -uncertainties==3.2.2 -urllib3==2.2.3 -virtualenv==20.28.0 -wandb==0.19.1 -wcwidth==0.2.13 -Werkzeug==3.1.3 -yarl==1.18.3 +--extra-index-url https://download.pytorch.org/whl/cu121 +absl-py==2.1.0 +aiohappyeyeballs==2.4.4 +aiohttp==3.11.11 +aiosignal==1.3.2 +annotated-types==0.7.0 +antlr4-python3-runtime==4.9.3 +# -e git+https://gitlab.com/ase/ase.git@72c50c76bac2396c7d58385b231c65bd07458279#egg=ase&subdirectory=../../../3rdparty/ase +async-timeout==5.0.1 +attrs==24.3.0 +certifi==2024.8.30 +cfgv==3.4.0 +charset-normalizer==3.4.0 +click==8.1.8 +cloudpickle==3.1.0 +ConfigArgParse==1.7 +contourpy==1.3.1 +coverage==7.6.9 +cuequivariance==0.4.0 +cuequivariance-ops-torch-cu12==0.4.0 +cuequivariance-ops-cu12==0.4.0 +cuequivariance-torch==0.4.0 +cycler==0.12.1 +distlib==0.3.9 +docker-pycreds==0.4.0 +e3nn==0.4.4 +exceptiongroup==1.2.2 +# -e git+https://github.com/mazhaojia123/fairchem.git@f50db9d5b29debdfb265d9c3fad394f18e16cab8#egg=fairchem_core&subdirectory=../../../3rdparty/fairchem/packages/fairchem-core +filelock==3.13.1 +fonttools==4.55.1 +frozenlist==1.5.0 +fsspec==2024.2.0 +gitdb==4.0.11 +GitPython==3.1.43 +grpcio==1.68.1 +h5py==3.12.1 +hydra-core==1.3.2 +identify==2.6.3 +idna==3.10 +iniconfig==2.0.0 +Jinja2==3.1.3 +joblib==1.4.2 +kiwisolver==1.4.7 +latexcodec==3.0.0 +lightning-utilities==0.11.9 +llvmlite==0.43.0 +lmdb==1.5.1 +# -e git+https://github.com/mazhaojia123/mace.git@edd6b479f4974d0b8162712872ad2eed1aa2fb75#egg=mace_torch&subdirectory=../../../3rdparty/mace +Markdown==3.7 +MarkupSafe==2.1.5 +matplotlib==3.9.3 +matscipy==1.1.1 +monty==2024.10.21 +mpmath==1.3.0 +multidict==6.1.0 +networkx==3.2.1 +nodeenv==1.9.1 +numba==0.60.0 +numpy==1.26.4 +nvidia-cublas-cu12==12.1.3.1 +nvidia-cuda-cupti-cu12==12.1.105 +nvidia-cuda-nvrtc-cu12==12.1.105 +nvidia-cuda-runtime-cu12==12.1.105 +nvidia-cudnn-cu12==9.1.0.70 +nvidia-cufft-cu12==11.0.2.54 +nvidia-curand-cu12==10.3.2.106 +nvidia-cusolver-cu12==11.4.5.107 +nvidia-cusparse-cu12==12.1.0.106 +nvidia-nccl-cu12==2.20.5 +nvidia-nvjitlink-cu12==12.1.105 +nvidia-nvtx-cu12==12.1.105 +omegaconf==2.3.0 +opt-einsum-fx==0.1.4 +opt_einsum==3.4.0 +orjson==3.10.12 +packaging==24.2 +palettable==3.3.3 +pandas==2.2.3 +pillow==11.0.0 +platformdirs==4.3.6 +plotly==5.24.1 +pluggy==1.5.0 +pre_commit==4.0.1 +prettytable==3.12.0 +propcache==0.2.1 +protobuf==5.29.2 +psutil==6.1.1 +pybtex==0.24.0 +pydantic==2.10.4 +pydantic_core==2.27.2 +pymatgen==2024.11.13 +pyparsing==3.2.0 +pytest==8.3.4 +pytest-cov==6.0.0 +python-dateutil==2.9.0.post0 +python-hostlist==2.0.0 +pytz==2024.2 +PyYAML==6.0.2 +requests==2.32.3 +ruamel.yaml==0.18.6 +ruamel.yaml.clib==0.2.12 +ruff==0.5.1 +scipy==1.14.1 +sentry-sdk==2.19.2 +setproctitle==1.3.4 +six==1.16.0 +smmap==5.0.1 +spglib==2.5.0 +submitit==1.5.2 +sympy==1.13.1 +syrupy==4.8.0 +tabulate==0.9.0 +tenacity==9.0.0 +tensorboard==2.18.0 +tensorboard-data-server==0.7.2 +tomli==2.2.1 +torch==2.4.1+cu121 +# ./torch-2.4.1+cu121-cp310-cp310-linux_x86_64.whl +torch-dftd==0.5.1 +torch-ema==0.3 +torch-geometric==2.6.1 +# torch_scatter==2.1.2+pt24cu121 +# torch_sparse==0.6.18+pt24cu121 +# torch_spline_conv==1.2.2+pt24cu121 +torchmetrics==1.6.0 +tqdm==4.67.1 +triton==3.0.0 +typing_extensions==4.12.2 +tzdata==2024.2 +uncertainties==3.2.2 +urllib3==2.2.3 +virtualenv==20.28.0 +wandb==0.19.1 +wcwidth==0.2.13 +Werkzeug==3.1.3 +yarl==1.18.3 torch-tb-profiler==0.4.3 \ No newline at end of file diff --git a/mace-bench/scripts/mace_opt_batch.py b/mace-bench/scripts/mace_opt_batch.py index e5c53fa..6d6ab10 100644 --- a/mace-bench/scripts/mace_opt_batch.py +++ b/mace-bench/scripts/mace_opt_batch.py @@ -1,112 +1,112 @@ -""" -Copyright (c) 2025 Ma Zhaojia - -This source code is licensed under the MIT license found in the -LICENSE file in the root directory of this source tree. -""" - -import os -import argparse - -parser = argparse.ArgumentParser(description="Run batch optimization on molecular crystals.") -parser.add_argument("--target_folder", type=str, required=True, help="Target folder containing crystal files") -parser.add_argument("--num_workers", type=int, default=4, help="Number of workers to distribute the files to") -parser.add_argument("--n_gpus", type=int, default=1, help="Number of GPUs to use for the optimization") -parser.add_argument("--gpu_offset", type=int, default=0, help="Offset for GPU numbering") -parser.add_argument("--batch_size", type=int, default=4, help="Number of files to process in a single batch") -parser.add_argument("--run_baseline", type=bool, default=False, help="Run baseline optimization using LBFGS from ase.optimize") -parser.add_argument("--max_steps", type=int, default=100, help="Number of max steps to run the optimization (default: 100)") -parser.add_argument("--filter1", type=str, default="UnitCellFilter", - choices=[None, "UnitCellFilter"], - help="Type of cell filter to use in first optimization") -parser.add_argument("--filter2", type=str, default="UnitCellFilter", - choices=[None, "UnitCellFilter"], - help="Type of cell filter to use in second optimization") -parser.add_argument("--optimizer1", type=str, default="BFGS", - choices=["LBFGS", "QuasiNewton", "BFGS", "BFGSLineSearch", "BFGSFusedLS"], - help="First optimizer to use (default: BFGS)") -parser.add_argument("--optimizer2", type=str, default="BFGS", - choices=["LBFGS", "QuasiNewton", "BFGS", "BFGSLineSearch", "BFGSFusedLS"], - help="Second optimizer to use (default: LBFGS)") -parser.add_argument("--skip_second_stage", type=bool, default=False, help="Skip the second optimization stage") -parser.add_argument("--scalar_pressure", type=float, default=0.0006, - help="Scalar pressure for cell optimization (default: 0.0006)") -parser.add_argument("--compile_mode", type=str, default=None, - choices=[None, "default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"], - help="Compile mode for MACE calculator") -parser.add_argument("--profile", type=str, default="False", - help="Enable profiling. Set to 'True' for basic profiling or provide a JSON string with profiler config options for wait, warmup, active, and repeat") -parser.add_argument("--num_threads", type=int, default=16, help="Number of cpu threads per process to use while running the optimization") -parser.add_argument("--bind_cores", type=str, default=None, - help=("Specify a comma-separated list of core ranges (e.g., '0-15,16-31,...') for each worker. The number of ranges must equal --num_workers.")) -parser.add_argument("--cueq", type=bool, default=False, help="Whether to use cuEquivariance Library (default: False)") -parser.add_argument("--molecule_single", type=int, default=64, help="Number of atoms per molecule (default: 64)") -parser.add_argument("--output_path", type=str, default="./", help="Absolute path for output files") -parser.add_argument("--model", type=str, default="mace", choices=["mace", "chgnet", "sevennet"], help="Model to use for optimization") -parser.add_argument("--use_ordered_files", type=bool, default=False, - help="Whether to sort files by atomic number in descending order before optimization") -args = parser.parse_args() - -os.environ['OMP_NUM_THREADS'] = str(args.num_threads) -os.environ['MKL_NUM_THREADS'] = str(args.num_threads) - -import pathlib -import logging -from batchopt import Scheduler, ensure_directory, run_baseline, count_atoms_cif -logging.basicConfig( - level=logging.WARNING, - format='%(asctime)s - %(process)d - %(levelname)s - %(message)s', - datefmt='%H:%M:%S', - force=True -) - -if __name__ == '__main__': - target_folder = pathlib.Path(args.target_folder) - files = [str(file) for file in target_folder.glob("*.cif")] - devices = [f"cuda:{i}" for i in range(args.gpu_offset, args.gpu_offset + args.n_gpus)] - - logging.info("Starting batch optimization.") - logging.info(f"Use devices: {devices}") - logging.info(f"files: {files}") - - output_path = args.output_path - if not os.path.isabs(output_path): - output_path = os.path.abspath(output_path) - logging.info(f"Output path: {output_path}") - - for output_dir in ["cif_result_press", "cif_result_final", "json_result_press", "json_result_final", "worker_results", "log"]: - dir_path = os.path.join(output_path, output_dir) - ensure_directory(dir_path) - - import time - start_time = time.perf_counter() - - use_ordered_files = args.use_ordered_files - if use_ordered_files: - logging.info(f"Use ordered files.") - if files[0].endswith("cif"): - files = sorted(files, key=count_atoms_cif, reverse=True) - else: - logging.error(f"No support for the file type in {target_folder}.") - end_time = time.perf_counter() - logging.info(f"atomic sorting time: {end_time - start_time:.4f} seconds.") - - if args.run_baseline: - run_baseline(files, args.num_workers, devices, args.max_steps, - args.filter1, args.filter2, args.skip_second_stage, - args.scalar_pressure, args.optimizer1, args.optimizer2, - output_path=output_path) - else: - scheduler = Scheduler(files=files, num_workers=args.num_workers, devices=devices, - batch_size=args.batch_size, max_steps=args.max_steps, - filter1=args.filter1, filter2=args.filter2, - skip_second_stage=args.skip_second_stage, - scalar_pressure=args.scalar_pressure, optimizer1=args.optimizer1, optimizer2=args.optimizer2, - compile_mode=args.compile_mode, profile=args.profile, - num_threads=args.num_threads, bind_cores=args.bind_cores, - cueq=args.cueq, molecule_single=args.molecule_single, - output_path=output_path, model=args.model) - scheduler.run() - - logging.info("Batch optimization completed.") - +""" +Copyright (c) 2025 Ma Zhaojia + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +import os +import argparse + +parser = argparse.ArgumentParser(description="Run batch optimization on molecular crystals.") +parser.add_argument("--target_folder", type=str, required=True, help="Target folder containing crystal files") +parser.add_argument("--num_workers", type=int, default=4, help="Number of workers to distribute the files to") +parser.add_argument("--n_gpus", type=int, default=1, help="Number of GPUs to use for the optimization") +parser.add_argument("--gpu_offset", type=int, default=0, help="Offset for GPU numbering") +parser.add_argument("--batch_size", type=int, default=4, help="Number of files to process in a single batch") +parser.add_argument("--run_baseline", type=bool, default=False, help="Run baseline optimization using LBFGS from ase.optimize") +parser.add_argument("--max_steps", type=int, default=100, help="Number of max steps to run the optimization (default: 100)") +parser.add_argument("--filter1", type=str, default="UnitCellFilter", + choices=[None, "UnitCellFilter"], + help="Type of cell filter to use in first optimization") +parser.add_argument("--filter2", type=str, default="UnitCellFilter", + choices=[None, "UnitCellFilter"], + help="Type of cell filter to use in second optimization") +parser.add_argument("--optimizer1", type=str, default="BFGS", + choices=["LBFGS", "QuasiNewton", "BFGS", "BFGSLineSearch", "BFGSFusedLS"], + help="First optimizer to use (default: BFGS)") +parser.add_argument("--optimizer2", type=str, default="BFGS", + choices=["LBFGS", "QuasiNewton", "BFGS", "BFGSLineSearch", "BFGSFusedLS"], + help="Second optimizer to use (default: LBFGS)") +parser.add_argument("--skip_second_stage", type=bool, default=False, help="Skip the second optimization stage") +parser.add_argument("--scalar_pressure", type=float, default=0.0006, + help="Scalar pressure for cell optimization (default: 0.0006)") +parser.add_argument("--compile_mode", type=str, default=None, + choices=[None, "default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"], + help="Compile mode for MACE calculator") +parser.add_argument("--profile", type=str, default="False", + help="Enable profiling. Set to 'True' for basic profiling or provide a JSON string with profiler config options for wait, warmup, active, and repeat") +parser.add_argument("--num_threads", type=int, default=16, help="Number of cpu threads per process to use while running the optimization") +parser.add_argument("--bind_cores", type=str, default=None, + help=("Specify a comma-separated list of core ranges (e.g., '0-15,16-31,...') for each worker. The number of ranges must equal --num_workers.")) +parser.add_argument("--cueq", type=bool, default=False, help="Whether to use cuEquivariance Library (default: False)") +parser.add_argument("--molecule_single", type=int, default=64, help="Number of atoms per molecule (default: 64)") +parser.add_argument("--output_path", type=str, default="./", help="Absolute path for output files") +parser.add_argument("--model", type=str, default="mace", choices=["mace", "chgnet", "sevennet"], help="Model to use for optimization") +parser.add_argument("--use_ordered_files", type=bool, default=False, + help="Whether to sort files by atomic number in descending order before optimization") +args = parser.parse_args() + +os.environ['OMP_NUM_THREADS'] = str(args.num_threads) +os.environ['MKL_NUM_THREADS'] = str(args.num_threads) + +import pathlib +import logging +from batchopt import Scheduler, ensure_directory, run_baseline, count_atoms_cif +logging.basicConfig( + level=logging.WARNING, + format='%(asctime)s - %(process)d - %(levelname)s - %(message)s', + datefmt='%H:%M:%S', + force=True +) + +if __name__ == '__main__': + target_folder = pathlib.Path(args.target_folder) + files = [str(file) for file in target_folder.glob("*.cif")] + devices = [f"cuda:{i}" for i in range(args.gpu_offset, args.gpu_offset + args.n_gpus)] + + logging.info("Starting batch optimization.") + logging.info(f"Use devices: {devices}") + logging.info(f"files: {files}") + + output_path = args.output_path + if not os.path.isabs(output_path): + output_path = os.path.abspath(output_path) + logging.info(f"Output path: {output_path}") + + for output_dir in ["cif_result_press", "cif_result_final", "json_result_press", "json_result_final", "worker_results", "log"]: + dir_path = os.path.join(output_path, output_dir) + ensure_directory(dir_path) + + import time + start_time = time.perf_counter() + + use_ordered_files = args.use_ordered_files + if use_ordered_files: + logging.info(f"Use ordered files.") + if files[0].endswith("cif"): + files = sorted(files, key=count_atoms_cif, reverse=True) + else: + logging.error(f"No support for the file type in {target_folder}.") + end_time = time.perf_counter() + logging.info(f"atomic sorting time: {end_time - start_time:.4f} seconds.") + + if args.run_baseline: + run_baseline(files, args.num_workers, devices, args.max_steps, + args.filter1, args.filter2, args.skip_second_stage, + args.scalar_pressure, args.optimizer1, args.optimizer2, + output_path=output_path) + else: + scheduler = Scheduler(files=files, num_workers=args.num_workers, devices=devices, + batch_size=args.batch_size, max_steps=args.max_steps, + filter1=args.filter1, filter2=args.filter2, + skip_second_stage=args.skip_second_stage, + scalar_pressure=args.scalar_pressure, optimizer1=args.optimizer1, optimizer2=args.optimizer2, + compile_mode=args.compile_mode, profile=args.profile, + num_threads=args.num_threads, bind_cores=args.bind_cores, + cueq=args.cueq, molecule_single=args.molecule_single, + output_path=output_path, model=args.model) + scheduler.run() + + logging.info("Batch optimization completed.") + diff --git a/mace-bench/setup.py b/mace-bench/setup.py index b2de1fc..f7e1680 100644 --- a/mace-bench/setup.py +++ b/mace-bench/setup.py @@ -1,23 +1,23 @@ -from setuptools import setup, find_packages - -setup( - name='BOMLIP-CSP', - version='0.1', - author='Chengxi Zhao, Zhaojia Ma, Dingrui Fan', - author_email='chengxi_zhao@ustc.edu.cn, zhaojia_ma@foxmail.com', - description='Integrating machine learning interatomic potentials with batched optimization for crystal structure prediction', - url='https://github.com/pic-ai-robotic-chemistry/BOMLIP-CSP', - license='MIT', - classifiers=[ - 'Development Status :: 3 - Alpha', - 'Intended Audience :: Science/Research', - 'License :: OSI Approved :: MIT License', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.10', - 'Topic :: Scientific/Engineering :: Chemistry', - 'Topic :: Scientific/Engineering :: Physics', - ], - python_requires='>=3.10', - package_dir={'': 'src'}, - packages=find_packages('src'), +from setuptools import setup, find_packages + +setup( + name='BOMLIP-CSP', + version='0.1', + author='Chengxi Zhao, Zhaojia Ma, Dingrui Fan', + author_email='chengxi_zhao@ustc.edu.cn, zhaojia_ma@foxmail.com', + description='Integrating machine learning interatomic potentials with batched optimization for crystal structure prediction', + url='https://github.com/pic-ai-robotic-chemistry/BOMLIP-CSP', + license='MIT', + classifiers=[ + 'Development Status :: 3 - Alpha', + 'Intended Audience :: Science/Research', + 'License :: OSI Approved :: MIT License', + 'Programming Language :: Python :: 3', + 'Programming Language :: Python :: 3.10', + 'Topic :: Scientific/Engineering :: Chemistry', + 'Topic :: Scientific/Engineering :: Physics', + ], + python_requires='>=3.10', + package_dir={'': 'src'}, + packages=find_packages('src'), ) \ No newline at end of file diff --git a/mace-bench/src/BOMLIP_CSP.egg-info/PKG-INFO b/mace-bench/src/BOMLIP_CSP.egg-info/PKG-INFO index 7be7ce7..d9cc8e1 100644 --- a/mace-bench/src/BOMLIP_CSP.egg-info/PKG-INFO +++ b/mace-bench/src/BOMLIP_CSP.egg-info/PKG-INFO @@ -1,23 +1,23 @@ -Metadata-Version: 2.4 -Name: BOMLIP-CSP -Version: 0.1 -Summary: Integrating machine learning interatomic potentials with batched optimization for crystal structure prediction -Home-page: https://github.com/pic-ai-robotic-chemistry/BOMLIP-CSP -Author: Chengxi Zhao, Zhaojia Ma, Dingrui Fan -Author-email: chengxi_zhao@ustc.edu.cn, zhaojia_ma@foxmail.com -License: MIT -Classifier: Development Status :: 3 - Alpha -Classifier: Intended Audience :: Science/Research -Classifier: License :: OSI Approved :: MIT License -Classifier: Programming Language :: Python :: 3 -Classifier: Programming Language :: Python :: 3.10 -Classifier: Topic :: Scientific/Engineering :: Chemistry -Classifier: Topic :: Scientific/Engineering :: Physics -Requires-Python: >=3.10 -Dynamic: author -Dynamic: author-email -Dynamic: classifier -Dynamic: home-page -Dynamic: license -Dynamic: requires-python -Dynamic: summary +Metadata-Version: 2.4 +Name: BOMLIP-CSP +Version: 0.1 +Summary: Integrating machine learning interatomic potentials with batched optimization for crystal structure prediction +Home-page: https://github.com/pic-ai-robotic-chemistry/BOMLIP-CSP +Author: Chengxi Zhao, Zhaojia Ma, Dingrui Fan +Author-email: chengxi_zhao@ustc.edu.cn, zhaojia_ma@foxmail.com +License: MIT +Classifier: Development Status :: 3 - Alpha +Classifier: Intended Audience :: Science/Research +Classifier: License :: OSI Approved :: MIT License +Classifier: Programming Language :: Python :: 3 +Classifier: Programming Language :: Python :: 3.10 +Classifier: Topic :: Scientific/Engineering :: Chemistry +Classifier: Topic :: Scientific/Engineering :: Physics +Requires-Python: >=3.10 +Dynamic: author +Dynamic: author-email +Dynamic: classifier +Dynamic: home-page +Dynamic: license +Dynamic: requires-python +Dynamic: summary diff --git a/mace-bench/src/BOMLIP_CSP.egg-info/SOURCES.txt b/mace-bench/src/BOMLIP_CSP.egg-info/SOURCES.txt index 0335aed..11e6ddc 100644 --- a/mace-bench/src/BOMLIP_CSP.egg-info/SOURCES.txt +++ b/mace-bench/src/BOMLIP_CSP.egg-info/SOURCES.txt @@ -1,20 +1,20 @@ -setup.py -src/BOMLIP_CSP.egg-info/PKG-INFO -src/BOMLIP_CSP.egg-info/SOURCES.txt -src/BOMLIP_CSP.egg-info/dependency_links.txt -src/BOMLIP_CSP.egg-info/top_level.txt -src/batchopt/__init__.py -src/batchopt/atoms_to_graphs.py -src/batchopt/baseline.py -src/batchopt/pbc_graph.py -src/batchopt/pbc_graph_legacy.py -src/batchopt/relaxengine.py -src/batchopt/utils.py -src/batchopt/extensions/__init__.py -src/batchopt/extensions/cuda_ops/__init__.py -src/batchopt/relaxation/__init__.py -src/batchopt/relaxation/ase_utils.py -src/batchopt/relaxation/optimizable.py -src/batchopt/relaxation/optimizers/__init__.py -src/batchopt/relaxation/optimizers/bfgs_torch.py +setup.py +src/BOMLIP_CSP.egg-info/PKG-INFO +src/BOMLIP_CSP.egg-info/SOURCES.txt +src/BOMLIP_CSP.egg-info/dependency_links.txt +src/BOMLIP_CSP.egg-info/top_level.txt +src/batchopt/__init__.py +src/batchopt/atoms_to_graphs.py +src/batchopt/baseline.py +src/batchopt/pbc_graph.py +src/batchopt/pbc_graph_legacy.py +src/batchopt/relaxengine.py +src/batchopt/utils.py +src/batchopt/extensions/__init__.py +src/batchopt/extensions/cuda_ops/__init__.py +src/batchopt/relaxation/__init__.py +src/batchopt/relaxation/ase_utils.py +src/batchopt/relaxation/optimizable.py +src/batchopt/relaxation/optimizers/__init__.py +src/batchopt/relaxation/optimizers/bfgs_torch.py src/batchopt/relaxation/optimizers/bfgsfusedls.py \ No newline at end of file diff --git a/mace-bench/src/BOMLIP_CSP.egg-info/dependency_links.txt b/mace-bench/src/BOMLIP_CSP.egg-info/dependency_links.txt index d3f5a12..8b13789 100644 --- a/mace-bench/src/BOMLIP_CSP.egg-info/dependency_links.txt +++ b/mace-bench/src/BOMLIP_CSP.egg-info/dependency_links.txt @@ -1 +1 @@ - + diff --git a/mace-bench/src/BOMLIP_CSP.egg-info/top_level.txt b/mace-bench/src/BOMLIP_CSP.egg-info/top_level.txt index 0961265..fe9f299 100644 --- a/mace-bench/src/BOMLIP_CSP.egg-info/top_level.txt +++ b/mace-bench/src/BOMLIP_CSP.egg-info/top_level.txt @@ -1 +1 @@ -batchopt +batchopt diff --git a/mace-bench/src/batchopt/__init__.py b/mace-bench/src/batchopt/__init__.py index c51e983..5d823fa 100644 --- a/mace-bench/src/batchopt/__init__.py +++ b/mace-bench/src/batchopt/__init__.py @@ -1,30 +1,30 @@ -""" -Copyright (c) 2025 Ma Zhaojia - -This source code is licensed under the MIT license found in the -LICENSE file in the root directory of this source tree. -""" - -from .relaxengine import Scheduler, Worker -from .baseline import ensure_directory, run_baseline -from .utils import count_atoms_cif -from .pbc_graph import radius_graph_pbc_cuda - -try: - from . import extensions - _extensions_available = True -except ImportError as e: - import warnings - warnings.warn(f"Extensions not available: {e}. Falling back to PyTorch implementations.") - extensions = None - _extensions_available = False - -__all__ = [ - "Scheduler", - "ensure_directory", - "run_baseline", - "count_atoms_cif", - "Worker", - "extensions", - "radius_graph_pbc_cuda", +""" +Copyright (c) 2025 Ma Zhaojia + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +from .relaxengine import Scheduler, Worker +from .baseline import ensure_directory, run_baseline +from .utils import count_atoms_cif +from .pbc_graph import radius_graph_pbc_cuda + +try: + from . import extensions + _extensions_available = True +except ImportError as e: + import warnings + warnings.warn(f"Extensions not available: {e}. Falling back to PyTorch implementations.") + extensions = None + _extensions_available = False + +__all__ = [ + "Scheduler", + "ensure_directory", + "run_baseline", + "count_atoms_cif", + "Worker", + "extensions", + "radius_graph_pbc_cuda", ] \ No newline at end of file diff --git a/mace-bench/src/batchopt/__pycache__/__init__.cpython-310.pyc b/mace-bench/src/batchopt/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index dbd2e9aa6a41d35935da53ac0620a105f60c740c..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 894 zcmYjP&uiN-6qcM{j^i|GM;YuS*ku_k8HHWO7=w~R=~@a2lxYW{NJ^|$wu~&VO-}nW z24k>|U3%DY|BSCY^)DEV9VR8Los2)f_r0h0-e*a7yAg2f`;Y77R{+3I@BF`f^p$&y zZw@%XfI}WAm_5=*JRv)5$oiB}CBo;kl1GM*a?-@F#2q#pZOBzMfL>8#frC|aw%>{C< zl1yZU;Hb>FFvunX9Uo5r(P*w6k7VWp(d)zg@yTS2=2E%6YYrK$E#lG$X0=%&J$IRR zU91se7~Mek69+-np0GslQVCNv&$PJ^#*uy3zBGc|)~I4rW@JVSp=2fou<9{4geBDK zv>;5*9kpvHmt{c~hUN*$XN<6tQx}NDm30NB&Wh=c>k3X*55~6?%G|K@9hHjClz5I- zkB8`hD%WR$W|UnZtI_Flstrq!Omihtky&bewc){_W!y~xrtYEPq3NOKA@tC4Ft}H> z2;rI$f;&bidL^=jld;+Ab8+}j+Fd2wD6Le*x*%?QIcs6-iq)n%-X(X^$gPH|bC~AZ z*s(F%R8UmmB{iATQ&e?t`DvU$Lau??_CtF3eKbmRDn=8!WXIbhx*c=a)BJwLBG|9v3r6bh9TT? P?+&cPJ`CRnf8yp}zm@*g diff --git a/mace-bench/src/batchopt/__pycache__/atoms_to_graphs.cpython-310.pyc b/mace-bench/src/batchopt/__pycache__/atoms_to_graphs.cpython-310.pyc deleted file mode 100644 index 45c13be10aa89b2fdb92573a28e27aab703058af..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 10226 zcmeHNOLN>-cE-EWV6*uUMM-`M_V~fJIW4~`<2VzSIh5#8My53qr;@>FfNwm;F z>jEvQSzu}s*^8-)T&c$kHE5@th_FXGcGgCHiBUv+{E(mhaI#tZ22r>a5DrClZ@u zjn5_4c%V4vn08ZY&F}w}wV4b?Jlx&O>@~k-e;Q=ox_veF+qM@oJIqqs>vqE^^s*pr zTh`8AnA&MF}4KmYJL($SO&7@&EU>3h<(a0Y^1?Pmh{sN<_$0b z6lvKV*7s?|*+;|2$HeeGWCF>N!E}Ymj#^jhYF!h5dR-T9gT2n)V5{uh=enb@HTE2P z{&UUI*$eE&&n1Y!OY9=vHP%9%&0a<;vu;pN_9lCay~>NfuQcOBKO6EuG~E>hz(gC*4!CrT z@R+1`8?C{aXV5}zH{pHa)CPD-gQa%T>4AX|7cje;Pbx+N+xsAhxrzzcJQ!d?L0oZ> zNZG?wNW=BfPQv}YwwT?3hkXz7Fc@;MGnmLlxpW8!D6$|6ct4B-AzZYsIPgF{!y7W) zNqB1C3n5N1gti~p++*PobL2uM1JB>X%vg9gWJ50!3t6`_?{*mPmNWKzA1yQc`j0l( z#mdqkn+!GZLQdU-`=W}4Y34zu>`su~3xZgXv#AjV91M7Z1ws6xQQN-erCBrr#ga`_ z2lj>7c-lQWK;>42G3~REcCnRM8nz>==4MyhW$y-9TgJg~22rm#G1*sI|F~(z{0+XF zekdr?mi4{+uGlVn4Qy(CXn*S65BtNu9S{2*9B8n4x|`T^a5oKc(%TaX)wWW18{F(h ziAQuchgs6?jyD}V*ob!_k|Zr4lY&#&e`5cH{0vDqhy%Vmve!CE5`ox_0D@d9jH*T~ zEWvJZd%A;d`*P6rh7mR?wYR(|4c5VWv8R)6QDj4jfKQM`A{wySaWNquoH%9Ns^+&H zchc?I<6>t5oas}>t!jSTaVOoLJ?>a;PaS$%dHU!R5qRpZ&r*2yL;|+^?R|$x1-v>Z zWrNcvGRyWIJ}kLD?tpQh#_aGu4A+T0_v}s>lj^xEOp?$b5NdBX71G)*`@GY#@1g+a z$zFTARcgVh9L?; z%%tBy^x1qPrQ_ZXMi8L1va6&sAb#t1G9WYT%^cvnUNpqY{e(l@2Q+mkTQUyeUFb|f z2AV?Hls4(VQm0LsK3U9D&ir-~m&}hRU61dv?WME&qBMsa8T4mo2q)+NO)q|@7Ee{j z->Jp#)Z$dFa`!v6_Dq z@)UF2otXckhGrznlo3|2&;>WCeR;@5`>`>3yfO~beu3Iw+PO`d9fJy&t1%nr zW$}4X;tP~4QnF0RIZDn@a+Z=6N>(X3kEAtMXlW32Pwgo>iT0wW=p?97&#xu^5?+sQ zT-t*}f2r^7?|Dft^pNv|@5A0vgmq~!%J!1j-Mo3dEbkzY-Mf@>|5C@x{5@FQOTv1( zS>je9-(Xajt{cW-=DL4?5mK6gF>lDn|15cbb%TPraj@AQcPey2@^RfugSKem(^iik z+Ik)TV(SSH$$&fzC2nTGL>|b0CTG&2lFLW(fs!jdu%f3im1#_WsNT{J)C286KQMNs zgIcC%MlL_V-(~5SFCLg+n|WI~tQ|>Ay&xUbGxM+xX4Nl9M=}F@d)5~U%B{S1I0v>` zR4z+*UO~G?Uc-1bP^h7W3f9wGLvMN;U_ZTe)Ki(>B|OB~zjd7l5MD$mCj-fGyh0IL zDTAp^@?1uM>i{sy@Go7Tcqk$i4&Z$t3PRwRZtwr%XR+4-@&SJ*><h(>%3i!Y435K{kkF}Fp? zKv`4h_Yk=Ut-6p3r#6;@S|tX&O+9N9QJ}L{2}o1qC$ldaJ`UMfX!slGcTyCV+-lKo zH}C`?c7RqXY4jD<0=t*>MZs=M*}r&#-6wl~i_sI!_s1ZYE4LKVW-vbb1u{X#1?fC# zHmOHONa+DGydW`!sSj1sN{8wZ!~nVwDDaUA=Bre4B#ScrNPz%=(LW$Y*GNNcKcZtU zb{t5h%EJDmCwLyIC~X=19lVMqKx|khb_-Gv1>W7D&;`p2T|hI1b~g;}6>6v7ueQHX z0kQEPP-l%clm9MdRi2;~U{hytyxDFU5~uuk(BR}QpQV#=j5pN$OSDV@T1sYZQC?MF z#Xf7w{`ZdW@~J(2TGaRR4fm)XI)aqNSzG-jOhp-o- zm7eK6aEELPawgNz)5vt#6&%Gw>j<_&eBl5d&M}=C%m81g9K7i@7+DS2DM4di&}a%8 zX10K%sM2A~O=dpS59+zX>O-0T5e|LbU4*q-;(v;fmqE+Qt-~`kqAcao;Yx1d@Xw*e z*?jIx>6UuXIG8_Z<_+{)&6`Niq1Aa%Z(xk(k%p2Be*#CP_Z)wi%f08B#pWK$2MhTE zYFRelLPKU=dOK~~ zn@PW)01yvZS`tE!&o-|MpoQENP!Qe%?sI5%L>kiq!cS;-7VHL`Zi0v)JtZzX1DjBx zz@gWKHGmPW(CIdnwl5>0wPE^XOUl|4(6vB&CU-bvoZKz(=#(J(-5EEF1|%aj#+nFO zFR|Wr!s~#u&3b3TX+sJ1coyt6jX8<_XBCb$gB%H7R9nUD&Ik@-mJ$CD)s%q#==TZw zvtj3A0h$Z%idqoE#5KlrkR7^r@^Kf3x9KPSL4>OfPHG#25FMI)DRdk748*i!gs)K$Nmgu9vW(yAJ|R5Q22 z`x``}^?qr-b>RsV!bO~L1tz287bYogVKIc|F-q5ArZb6 zD>Sh|f@x+Uoa0hV3Zs0B;I6Q!Nfp4>guQ(4-NGmfMC4btp!1cc!jZ*WM1w_rY$j6? z7#2FY#|U`$ywn50SS*mDn6e?|H)z&BMN(KHR8bj06uR(;VXO&gE(6OYYDh6+p%EyG zZOM3Rb?Q*_TU1w{aEJelItU@TM!8QZ*`_ucokGVFG6InpjGSgVh!7c)BTe%!xAFt( zKwbcMDd|u`j;vD`v!d2w_JuWFk7E*Tf`U(a2Dp0_H4@{EC9y75bJ})aH|>V*ZzA?4K19J_&TI`4wkybfGYyUV%|WfjjWBy~h?drNSSM zIe+v6fr(4<5BNX`3$H$xGa07N%Hi3^7sWS%-k#ReQF;yzfB^@m&Go?pw ztE-^R8=|L@%^`4ZP&AOixIKXTk|S(>sKX4NgF7f4HbJ3=o(sK2^xDX4F@oL0C2}7r zZa7@dO%V>ipR4>(ne(q z%e!bqwjauMbJOx(KBrM#1){O+?a4-*1n&( zo2Kg@-2VfRf-&r4W&SYo_{^nuX`Lt6;?r%~Z&fFidQLs88L65G>oET_xMzrwE-1=!FwlzzL7wcPJ6~h%VHVQ_B!9 z!xHm1sfG}RGQ!h{@c<_YP;nMTlbMsnxhOs`899ZYnwG{O`)g^gQVTRS^jTif81z+#1OAOKmB6R4~(j{=$%If1f@+zL?< zIf1u|oB&-#P9U!$N2ZlupyWA9o<{;FiCDp3q2vWBP`9_XP)fVYQ!20sR*k)W;JSt7 zy8VO=VKdyK;ktM5gOaMI=DIBLU6V=gsHMrN)vHQ2F9FqF?Jg|22Bi#JWSmJTL<&c<>b-C5{hQ@!`xL)3kg< z4VI`3c;t+5>!rg+Xih1C4Qf-l_9ITejdE5_(-EZ?*BnFB?O_&1sgUGiVR~xQw%74F z*-~d=4AWJJXbD!%;nV=+n>jGFH8a~Kc2B^5<_}9gd<5AnaXw9I)EdeJ(Ss|{P)Gg(a*WL zzUi8BJQwEsmRpc*6WX?0l7 zcK4U6f-lu44N^EzRpm1!Qg@)gJqKl;*rfOXIVzEv`-;0j=17gy@2l=&p;FMvJeeg8 z=q(X_Q)w=YzOz^2!7dB7dfd6>H=XvC_M6U^D?J)*-3=VK=f$tfdOPr(54_i%)ganp z!@ybhqOom%+zS#XiHFRmjvo{1KsOA08YPrC!-!Dk@E&zOxbgAfu+xoU+zBH5VZVQ4 zWo=_~&FKaq%#~vfi(~GPfKi{vY}bjqKzY0iXOy<=@4*C0QxyiI9@%JW!urr-UKr9) zL_UnSw!pf= z2%>HbYAPHV;i1AB!0n@7UGBwwdb#h7z%Rffpyksy1}xqNV_hEX@?IQuRyIG(#&6Nc z?_ExqfB6>Zp%)K0G!hyH5p4~2MHRec$Y|$i7dUiEf~Tyc*=j>Ieq{5|S^u}!zIo-f zhb!-}9CZBxREN)t(7*{&pR1apr0N8Y;sG4YC55Zo8qu#RT;DdJ3#oEX9oyU{h3qhmi)k^nJ4IgNGrUYn-{_M{n&+hhZCpxq zGILvF8>CDsVYOcy!+}m`Qad$C^?^1nr=_$Et5u>UpsWF<50rJFtdFZeS$z(r0ds7e z!!wR*X=Qts%%&P?NdB|YI++{Kf{cyHT$+97(R(HfaKsl5G~lg~+LOh!fwHCbw6>^> z=U}C!@qC(NlT=CP(s{B-%JRGPO@6!(?SSkJl>IsKT2kJ{)R6N2?P+;o)>68dF2VOq zv;b5n>pz~B^%Ru>^eE?7UQdcryB3D&o|oWB5|_`;+2YzLl?hEXEn>(FZ* zaN(dIj65F1kuyx-us9N8oGy#|4)}I!N&ygsD?C!)`u+(#qOo#w^;+kX8=E(-zrWU5 zUHb%ZZ2u;d&5q$`~S2Q)L(BM^qBu3q95uFu*f7o1(T3Fx;aj zW&rU^K)m&396+5U+qrz%jPln!t75mYd?>RZV-9l$m)P`Bqz=5 zQy4mJP(SVuLdqcw9I3l`2ILxq0cSJ7ugKHrPQc=b)v#9tTLu0~9CZT#%^Mr*zh-qD zH!>jxlyy5kL_5?oV3WcSs=!iS=!?R&&9#p79Z|vNi2x-m9KaAI^exIo5&a631j{7S zP|^TGXTyk<<%-*J5D5!C2Sbv`Gdc(Xm_=@f1)K^ql1gd5$X-P*9Gp`ClTDSF&(>ZR z(?(m=lH0+clTbg72oQl`7$U7(=|Y^~s75>ml)2R@&C!Ukft>NoB6*+fLN{bT!^v|` zP44)imt+FDHIK*rq_Y$7UMJ(mU3{8604co4r7A$9ltZW4Y_r$VZW*s_8IH4R4u@wBVWY#nFo@s-;_6Ni$VV z*Hle49$SX?A1hx{ORB9}P|L4oX=Ugc`gt4$YD3HDaSp!L4_5K1HOvys$kveG(`)FK zdPtwGV5;U1b=y$c2C(MKWE&dz%tAph)S)D+fJu-GsgMC-bk;It=GsJu=xlQ%g;e{2 zx~J@^d)l79XYA$n%)NZq!uWpQ;yH}ot56`a6sNOF#2n9XGo2xMVv)iF{SOz$WuD)* z((;x<>_4lRGLI{1Wl}hR)H*G$Dxc@aRp{Gk730De*FJ@W+JNNko~|m9wX7tYJ1~y# zXHX=+GZ?oQ6;l4%AeCu6kt5?8sgCQgfBk6x^C$MNJ+nXVTLU@kr1p)BOV8~7$7gn* zP3xeoMz$|0K^&W7i=fZZPd5K|Y{66Kj0(<9%x(k3w3@1GMxE{WR&pwLMcWji{1k0; z6{_DUsq*>KZ?&;H);MHa_taF|hYXBEeNUZNmXv*cU)|UC^ZVw$v7ftF_yjcdg+>(M zP2GUlgsHG}2BC*EYU)C3T@h*^Vd|lUnLLKMgk3|r>sY-Dm9UTCX8A50SDx5UMi&4>JBLT7bv29uo0GR&M{`DvERH3|M1FEd zs5?TzxQDS;i)`qOv?glg@2*k<8=%wrbC(^*3XQa2Y?h6rsBPYeT0<+%F;=T{~m^upF>qQEzMC)&HB+Y46Ulx0oM>2wI}BuFPeagYBpv} zvCw{e?(uwq^k0A^yp2Rp4c|yj-$WWVaV4pwmTxKA zN~)>t+dx;bPS9?xXJDH+JssDu^XH0R$NEdHTR;4lvy~4IXxtw#?{3ufHh;ePOYd3e z`GYXujYG%T9>j%LWccMVy1knhG0eQpVu%PBEagAhv7m*8_3r6FJB?plufv49e z@Rp(7f~WW`gpr#VjWlZaHLU*?l_uAAwOv$NOb5z1LS>Z!4;x21ea0%s7I3R$#I>Nx zY-}E(N7^5LR@$&qlPeH@^cS7mJ$*$Rqhn{R6`$b>H@W?}_UZ;m)juvppJ3}$s)4IV z=ud#LtCdo#Q99hgc1JsIjCG`3)M9yo12&t2TW&9??`E^DQlFKLSI!0ryO;Tms` zjdJ01)`8g~>|Fq>7bm{JuJpy3x$+I31M7`)31nMm)@n+2X+eW^3%5XG8RAo0(dd6c za)~eTN%8EK(Q6d(+en*lU(q^6y>Tyxq1%e+7#uQLb-kpI(Yuc))S*oN5LV5>)3xyILyukr=H$ZzmEuX2xF4H+>fPb%>b=fysrM~*1I}oG)Ge-a z19#q5zR}A&LqzT zER?>T5%b>H?THQya*IX6E*f`Na8J#u(1IH;; zd<&C(=w_nQg9A`tG^y-fU|5hsx1+NEp1gyrtJ6BrqPj=OQ0~(I_oOshFo^NJ19G86y<5?ZbA-x5B9|9Pf*UI#*CYXO5VtL^Ll7VM3#ZKAYt< zm2{klT1H+6hXhuI0s4di2NyL4iV8Rlp|V^0^RnB{e4`q!3PlKdNjhW)q8ieEkyA!1 z@<>hDs0p57oBJ_&EiCe87-m>hAjhJ;?kB>;Aq#~y$$7d68);Pz(krwr2N@~hsKg11 z_>c$``?iV`nC5pf2#< zPyD$!+Yiz>Q=SgO2fhVymhYMPL!d7RZBB*W>wFMyUKDyb6?*7bWmYR{a7{0qg)J9p%CgStwK9q@XW$0(Jep^OAkp4RTbDPBKbUKNZa0u=b zxsAN}PNQEyuzz@Ze~_o-ei|N5{%nDY$UP`zyO6~9hX-tsXTjFX7Zd&t$)dsif=2gM zj^slSnf(@S3=hQO#d@kz*9D;BA28UijcnaPj^UyPT1Ad}bX`a0JF8+D4b(Oox{2U^ z=Tus{{?0Tc-8NhkWPWTIuI{2wjWy7+fLf@bJNl>SBK0eyWmu+ZSja`rkFM21ZPeD` zfgP^lAh1E*Iw0xO+xcI^Zs-j_UIjdf$*7*UwNL8FFt5o>~Rw|l&=Qe!+-9 z))DLI2Z@JrV?J}tTCf(aQ`TwJI%R2hr25i>PqQn2XPY^#jZi<^tn07Zp}C;n^qNb$ z=~=oH2D-Vv?zoN_+QCvbySm{7df<0i)7G25W$Q?~PSf@R+tRz9Wivh8u=Q7Ot`4*H zbszb<<55BO*3Fgcx9?on(Iwi7Je~P|s9O%Ro1xFPb$=Z-_oRf(wwJP>)0WodC}(<} zADW@#dqG`^3af9we%)Ak`TEK$H*dcbsaL)2+o)B#;k3HUUNOBbGq~DnAuAbYfzz&; zp^4`hv)lfbZ3MQvZn(bR3F>N;3YumZ+APWicB^fBq0#jHeTjgfW$IP4({ViNcFe$9 z`RxuywC38c*xNxg`TEFY(yUa8E_4U9GgGgKReK)EOaxEtaO|NMiTXwUVqLmnYig=QKuaJKKD;O?+ zN9m<{X)f{9v;>--mZqd$hHHI^E1U9;+|xp3Q{@WRpnu9EwWn~kukA?vgr!=VmDo}H zNv=@Njx6#LJ4!FhGd%k&^^@65p`P+c?&Z)=I?V9gLz(9`aSk3S9B1RMIxY3`T()r5 zK2&&qT|pW(cplFj2WO@JO6e7NlD&gok;@BK`ijI8%;woaI)iip`kB4WaA>8+a->>L z`}19VKBtV&)i^h4*;`OPC_>h@TDIXp0q;kd;d%v){E(%ki&BH-t!EY_0h$U?tRHV^mRPK6-3ohs*oNhmizqmTJv_G;*HrJ+PT?IM)40 z^}FF?i7lewi#z8x{I-3*Z9dpA{Y}RN*R&Tq%-_TbI^Wq2H+;`nx$}CQzh-;Qjq?F( zo?kP==7!%vm9?hPVrFN^y2mM{RE%%V`^J^S&{fxe50Gl@#W&+}l?_zB(p1X>J@ zl)4g2xA0c|Yv?JLHlO;Ag3Xx|i4;A)mn1Gf_o2ea{OGYY8TPVLKy>qUY zMoETB2KtvK{evcFcckF5rE!fz*X|uht=yAqv71(6R4Z>KAIk4T#wO`LMuC%%wNh27 zU)X^*hl$N1(rGIL-Ou8j<QSje{N@-h z@uZbSk7bmOA*aH}c+tvP`9~@*OiDeGp1=tKUq$OG&sb?**-?A7QL2j6INH`Q>hX9j ztOBjYzI2MxxIr>_3rwJHd0&}WH$G9L@`P1FjY;ee-X`pqs5!wW_~fi)rBGT%*%ahX z;%raVu-3h4kM?G~KV=mtg?5E$=@Xev9~zI4e^};VJTpSd%mFF011YmY%G?1dbC7~J z9;DnlBqg57p;Z`2IU=MSIUwceK*~`e<=6o!$HHkV#i#o-cwa)JYSKelQtvo^C-6Io-#nzuin$l86!z*=nC3dIZ%&Gz7Iycc0-gRF=romK4MiEAqq0Bm zofc)7xhTVPRQ5-`Gnm~RPxI6KOiSfSp0LKSS7-Uz{t>Hyw|D{X`lFoou!s@VF`{Gq z6tC|nec%f0+*4>@;!oj?O&YHD&hc@64m!7h{4%fcg(+x26%tARdy8m)oG(HPWUQRh zTSEE-U$iRV&x1e7m-u<`PlM0!r+4H}6o#Jq^SpwT(z_t|s^FIezbN>N81E@7jWc|j zU)VfDYlM+2y=VCG{#kxuN0udi0^G9TPJ(+zaP#0U68DLG?^FDKg5QtuE;;^Cw$gVs zNN++)ogc?cD}p-#j?Pn^p9FVFaP#13W>4XK$-T=M>t)uws2ekjO!uOMB&Fo{MN(SywQ_ zZ}4wmPS35YxP8ZOKUP;)S8>CwHQ5AS)i+kZu&+7m>ox;0((zXBHi25~U2n3r#MRF##7%;T z``Bl|0tH5h68Ot#1A($}s!;=ZtgtFoOx?BZ4)DTbC6YTL3nlSoZ-v_(`?JNBx&Qpd z<(U_e{oB--${r`8q-TbHJBSi7j3`iQbR9;s;!f+f_uXsZ1}4AZyH-@&fAI~kyZRD4 zHSC_zZOX2KM5zrk0LEnO2&J-S2)Gf1#wB#z+3dE1ZkrvYb}Dfv1!p;1QC>`fz%4*QlxQ+P2-s;9vlCPYSjBN8pp*9u z!WcoM0{KKL<``u{j35B2v7=-NVDB@gQjB>8ua+;=V z0O3Z3wi(>Dt)cWx*Be$(#yBfV1RG`t0sy3KG*E%Mqm;>7kcQFP?M}EIO%NnGh$N_A zv<=PX-TL#9W&z1ZNoHD3H(=K=F7`ALTF1Sejj}>7jljWhWhY9DeZ~Bus)y|qdv0`W zW;ksVux|1RnMQDr4OcYAkWn>W-#BH&OB2mJsqLVwI4m#Tt{k`b6~JJP61Bfn+3jy0U-{?u}#6qk+g@-SmG>6G=5MSJb?s$@uPOf$o)EO(njf789uRWnL|!h0Zt--S0##|e zG~x={S0GWr=#;@X`(ODq^`k+8wo7BnD(uIY<^g+x3=p!7VH(Q8FSTE1PyA3C*%QgH zVNaxRo9?$K1ZF*8Po$n;PwcBTuqWbL`|SzX=&&bXQIG72EKop^u)MG*GC=>ZFknyQ zcouhdY+Jw%2T!&|8JygMWpT&`frWuKB~g#84cHDJOTwNg#nuMwi85^N*q+D(&y7(U z_5|^SB&s|OEHVa67+V*V8dw)!+nylX_&$4L;($Fd&L{TS6O$MP*%OljgHIj6 z;8VYmJwdj|1fM=21vZ9|G9#qS9*{Eo8`%?cLdx6$DX=kwlp{jQ(F0PB!kU=sPm?79 zZg#{SjXBs2$3BuilGhb$%&L5l=EttV;^^s+GQ;&hg+&qL+1?3Y;aRJ|PXOPk{W;`T zfN~3j@Owa{{UiK1w2D&4Cj8}e5Nd@-B8R#P{TIdmaK45E*4HNfp z7US0W*^$*jR?KmJ)~aIMuwP(-)cI53&w)P)JNg`W*wo-Ff+u}H$xn(=5L!Mr8qXeD zj@L|JXRJ8YKP%dv4d(?)A47T%r7sA3A;%X4N^cG1F6@&_d@PrCEk@}JUqkLGjCo9; z^d;C*u(bv#on63dF9tm^N@o`-;~664E!)G;>=Gp}6B(gq_8cY28`Gr9!w(^Fa?b!! z%fSMe7F-AKRC;mX zy+A#F6Qq9nN$8v1pmIXq>?I-sD!xqIO(HZt_6iZYOWCVLZWAGh&0Zry5SzVDgU-8@eYwE!Q4$MxJ$$((k9{&@riVZ+#|w>1Vln2T_RgV-Xn6K$TpD&M7~Re6X_B8 zLn7ZJ@_iydAo4zuhaimtfxSQ=Fj!vf5%OSN0N{xHMrDMA_CQ)`FHZnF_5tkTrPll<@td)SjBM%O~Rm2h~B9KRvA@!kh zMGD~+<@tx?`Ok9ot{mJzU)o)ny@mRS0JP+=2yY1dOoLGrd6C0oc}GHdiKlqBCZSFl zc!s>ADmR#Yo3cv%*XLad?9(T-j0K zKPvSnNCMhXuamq)alJ_ctX52V(fAa$4r|1WlC zas;mK{seHIjQe)R%JJHcf{|3w|12+||0!^DxLc<%BIF{PGgzq*K3}Xbb^yL3 z4f9NkRT}I^J#Rh(70T-PyyU>wW9nopA!uiY6xTxlCD0MB&?zKC!NR!2Hk%tdnPPAd zA*dMYZ9fP_C-43*o|Sj?kx8^9DqjuY=RmhY+PlI+ewTufWM9E^loge(i!cN{M5tzC zboa?wB|;E-)ASILr0^vbnQYC0y8}^2L`Gnm>1%fQo^5-gm*D{m?RY8!zZLZWmyM-| zzKD1Rh+JA|2)~zu2+nU&_oBUh-)whWdqH38x_2qJE&N|Po!XXzIlQ;w(3&t?IKp7N zn+!Kah>k^jdOzy){u~T$FuAMxaA3`CR2`grF|iHX?dTn5Zwdcg*RpI6Z7EtMd~~!S zgOTCBL|n+$*BLe>Xv0X~Ur}rmY-oq#tP3=BeK;A>e&~SHZSqD13wq$1YdW)C^S+qk zf)2kGTq-8H2N!fZY{r`>PywQ*Y1|RQA5+Kqr#BWMj~EP8z};SDFH#1%{e)+)epLAA z#04Cs$!X<#FrzcEYsD}FmZhrV%1ty~EPL!+O9aSz+zEVxE}T6q1~ zw@BU!5pjY4ggAOn$Bs#HRnn>>^ed)TIOrT0`^ph%OdqIWjC_%{)yNA^7I_;wa5IYQ zk}l4uF!WUzG=;`^JhkBg8~R=5o|rvwaW#r@sc>hs%}$ivOE=2HQ954J?OGT~8Ey&q zAMutjn%xji8>~%$fenhsX%55G9x})Y_sn2L7ip0Mp*voHqJN@mozPKoDvc|)2+sk$ z9O?x82Wecr72)SdtMEm{{)yeB29Lrn{2yN`yG6t-E2^f};Lk|H-+&yIzFDMe@KwNP z0$)gSH;b#gmO$T%kog(JN%qfwewBSn!*~-U%82<=fS2536u|!?FT%-7tHN*ERH@P6 zlf;vL&;dCh?V}iZR}PZjmfn{U5mH0sh$!X^V)VF?Llw8EJPnJFqMXEW?Jh;^5Xt`# z&t8He_dG!{j(d5uKzc_e@O?WzXGtc2tpK|tg=T9I&6*(fWDKIje13+EMj7rekLHAk zvq6wrU`qJIk#yoc7;QvU6#-+`_2B1?gYVY;v$z+6Ji5xO8kX-gwk0dGe;{GXjT?0> zj)Bvu6M=AX{$MLSO&OF6{x9*Mn&ND5;6>ivB8Or008a39o05AOsoDv6ZQvJrBp zJ0v=Ww|K|(L)TebT7w@N?(1O47-p-*2xvw5{r^CTaC}E?L?FawqzX~(Q0GhJ6=f+Z z8b4gr4IMXNX(}tkWrM$}unh9z7a+aaq7wdCKu&-9i_DUV&<$-!V^XwP^1LMkO}Z}r z%o9IZflABQuHU%&=B-uZ`t7%jx30c%CywW*#9tsTKO?+tjPNI%&9Kh!=N%TM@HV0U zvcz@@ZSp=9+0&Esl~~XV^o|cmBQ+_XDl2J)PA&c7Z%x+7PpD`~?Z1_RBIkdlEWj5j Y=f9A?r`#BF$#Y6lk&9m__sYBf17^8TM*si- diff --git a/mace-bench/src/batchopt/__pycache__/relaxengine.cpython-310.pyc b/mace-bench/src/batchopt/__pycache__/relaxengine.cpython-310.pyc deleted file mode 100644 index 778b24d1991e2b241a7d3482846fa64926fd60d5..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 27428 zcmdUY36LDudEQJ<&$+X+v-gE#0OA@P3t)K)0#A^j7>O&I07a`MdOXAPEv^#JD02EDmzLX+i~KnW5Mk?;Fo z&$R#{DdnUpJ2h{*-+lb=fB*ac@4amd4rU_w`-gw|^2)Qn6N!9*FTH;$yxfPc`xDVf z#Ew`IJ8B!%Xw9&UD4%21m=(jXSv6~MD_%=j3Gt6tleLtUlIKJtMb~43y%@I?3GK^itCjY8`Wyr za}@G0A2|K+nP7nK>2);VU1#8#fA^YKsa2Zgb5-YovbVSrO!fFYQLlKXoN84R>&U;MnmbRI(MY@CH42s5T0?mS9i>of*y~ltMOhMNN45ww125rgcN(C(UBTC6OwXKd%fy0ge}>-{tFav#3# zcffHXRut^Ru%loXh8=r0VwtvS$MGAt6Lu262|Hz{@td?W_5glUcGe!mZ`#R#xbyb# zv(d{@Yk(61zFB+B9>?#XJpp2$3esm5R~+`XYM#s#@Yi_WJ!xCuW6H|eM0ru+=tv_Akh<7eRx_=9k>eh%)SpNE_Chv5!+`PJdg5r5=rRNdu| z)Z_9#%6Ej0`J-^hn+Xrx?=pCo=4#&L#(Ojm@uoJXZKKOW{2@EC9#xBeWOK#`Cqn93 zJ8}kmy!}*+Em(1Pt$eZGyl3m|Ia_aIma1Raw|YE)lF9HW7twX(Mo#4B$39BKwf-Uk}%-rBlX(t*sm zs^Xx9Ap1CGLZ#+Bq*Ox%$!cSH8F>ZCi)B^k!!VL4*D0N=pvH?0R5C~{uC149AiUyO zV7g9qNlhXrHAQEd&J3MdI@{=Mr?Z33PC5lTyXfqua|@k4aL~J*GSpuB?xS;v&S5w& zN7QZjdF{TVD~+0S6dX_;U9#2DwGFV*dg;`e$BvQ>969II7gvtDYVoLy7DnQzOc<;f zU^R1V8$qg6s(>+;N+*$@dk>t%WFmUAPbQYQmNCpX5^=-GMzc{fYD81ljBBZQKAJI_ zBb|9V*Uoy;K(_6_G#>bsDfm5j@S~d$Jg!Dth8OjX%OL7i<9tlzNZ8AfR?IgZj68L$ zWj2R>6EQIz-fy(xt%Qe3dO6xk;x~cc6edsJjCkh8?cwB!Zzk-Ro;+A;k=wLV@2V9qh*h1s5I_*C)V(0{o^o|vla!i6#31enc_qqgYe*i% zFinHFD%Y2tVp`23hPoAwl@Kr2N``d^qRL9^#S+}lN+K<0jY6qhK!X_Gg(kFfJM%dP z=V^Rh5?sc}VW`Y#3d5Cy%c0Bg?adj}j5cR`hHWdAFLJ49|B-zT;49;oLr8PA5@|)f z$Z8ZKj&X(KXmA`cYE~=Oe1{jaV}9%t5!-wzb~*ZdzGZqQMlyaivYEIV!N|p*jM#BA zVn!~g!iV8HTw z5rSti)qLP##}iSih9M^PoO8U3j#Dpyp;W-&Abo*#6v_~-AXimXVQpo@?F1AadC;9Z z`*PGG|DkVl_Ncqz9u4^}pFO8`6nux@=I~LkvDQG;jib!ARIwMD+q-`4tiib!&x6Gx zyt%C_7c8TuOjZc1dZJJq3=(VZx>u7TL z7aHDa3X8%z6?sXJeB|Alk)|KIwCG?BXw+o_sJjtbObMw6$*|yHFl=qeSEH36x)dbC zCIy*p^@FIhh0cTBCA6c|X=-my;LS1y&!J5B4mgHkl4}`8)G#tp^O|9dM7OWoH#(1-NxMzBN!@~OJ%ODEn>LT;yPw>z)Q}nEX{KQ$~L!MfXxMTP= z&qq|zGd81FBk(uu_zY%P=)c=GG23Ex!jF;~DKTFV2&7m_A=%<>VdFgYT>kc5RE=7O!5^j1kJ(| zT%rq(i?tTwyxmX#?umjTzbOu=hw!6L!@=slUiB(i{<&}9QXF4XV6N&6W8BLa@d{>z ztGIv#iKoeO-5^%BZ6V#_pr*E5hel)ysz(vqf{cNA6Viz_(icURc57H-m(Fc;Xw1sA ziL7E$7nOKHtfLh)Nkda|FS!z#MxKZShV zyWnI&#d#xTq`<;WGW1v?2KH`%UOS&$)OaQfU##ZW%v9=)Ji?p1-inetWh^F;fBP?s zhY@^*e6ghmc?cPuZ65FPPNN;l8hK(A4t_v zH~x|lLyQ7)Is28(ZETUu`B=kBca&?f!Its>6!ftilQ5M5vXt0%ETvy8B zt~Tsc^*F=gb#jm(q4zRDQm1ye&b*Sy(xM4Muoz%bh^@wLY|O84f#}9g*jOMK+m62^!ZE^uoI3@1o;k_6^zqf%E# z>1#4|(bQ}{-Km=5&;`-8twTIk>g0mWrmQNC`v&5=+^=Q9RY$-@v%*`k!?~7j4 z;0^0J7+B1YVl{?lBEz-QFD1!Jz4Rv36MVX0fN>6N zJP@)}rpnNt>{2A$?8bM1sY1cq97djl_JEh;UaCz>A!AK^Ai?-xt#}d{H@k4Q7;9$g zPgm?p`G^aZ+nyswFbUPhk>$1ZyK7Dj+Q2#LUc-YVyZy)!y_33o(S7=G9sIE3 zx%gedMjm>e_`0`rK&+>dK;V|2AM&N!yOV0OcFaJEp!yvF&X>=X-;tCAkIcsLFUwStr{PWO@kQml=dNX^7nFdA_BjpxMi~ZDQPPk!oa|q&aYDuA4!*1|!5Fl=GmACzrz{AA?9^Uy%H zOCHuGzhqvHepmFu^H94OtCLjHVv~~d2hcA$G7%rj+&dxmLvkAho0)ncx&doWYY5B) z8*-7zhCVQK;Ui!h_!WsP>>c-tpGFVuL+?#XUuQOF_={f5`vX|x9%&6X_j$AaFjc}= zjOWK%Bi^>n?bMpKM*aNe4sWMFsuefb`eUmest(e}(erNK=Ko3fTsRok&kgH0hMq3i z6aFyzchcMC5B18S?V2-Px9(`A3Um2ln2Qrq_n1HC$EG5!amjVcAGfDJVcIkO`P#GI zZp1gK@Ia~CUW&CQ(839O`(@+#<1J{2H*c|bglTt%D>MBH?R@NAJ-)k1XUM15?-tT? z*Y5^B^P7EYk;vvAUzA_pMb;jU(UVZ1Q4VNzT4nF?_Fje>=>$xG;^eJmSqQ87~+cALyfs+x?lFQU%iX_SP2aGu-d#?c<)Fx&uA64L!B3(^G|> zp1Sk8o??Bu7lCDBb4u&}yY)G8I|_EA}fCtW}N9{cvEuG-=Xl4$E$}T@rgqT|{_u zm-cTg${x`!l$RiX6|mQ&iHC~*!fUa)`K4y2piNWQ!drQ_ZHp|4_*iP1b5H;U(TmOT z8)kig;Z z0G%JC^Fwrgn9e8Yyh!IIIzK|^AHiwPK2nFMjJ+d40d1=cun4MgvCtfM*K4)1+UW0} zVl2qQM!W_QS{k6W7To4FC%f~8cE+khLPT!e(VHtC?b&~dzJm%e<--ZGm@i9OK@O@o zPXGJpOwtkK-w8f(AFs~SVWH{)I#g4t2|6F9(|}{8NYk>#;Eq-{yXpv@f*2;1-Z-o= z-~ye8B>=;3Wge_7dXLKnDo8k&Jld2@Dj7x3u0F`v)Y%8|iwZMMt@AliZNReNike7q zqadcmK=OK0q6GsIUGL0b39dHEUXapRG%f}yY;AIUVOsTWzGBEAxUVl)m)L*xgZ9f1Kd!={Un?**`}pnKuzE&1%+)k zlQMEJPB3JKh6<>i^mfkRhVJ#mo2d*Gx}t=lf;Sa25rbAeNngtlPCQq#nbj zlHVO)&m^G|d&c-P$&n0wMTFi_a!(I}!THhBk($xDORMf5Q!}-YO#=gi_H4$@U>3 zTA%i&kiZoaX_B;`UdT(?DT=zHo=z_`?>f_7Ogl{=o3ovB>&t`#IfVl*5kJtp!aC_( z0*->#y5ig-8`S=5>mzCHb_;Mk?8T{5N6}GyjmvS-OG1Ga79CxHMs5)?d-TGe#VdS; zN6Iu8rdWaR%TdilWSVe>2-Eyy&X5$|NQvgPY^-^p(+G|O79O->?xbdD_snU=BD&k= z(Bt8%9?B5vyP4}p=sX9fdHelPV}`45VWs>uadGWsKwD-tq^Pg2=3u()=zBE7mvvza z*Jrbeaor+e{?tUOe?~`|Ev@`ld=+Mp?lk2~Z15NYFeQx~rf8c1+HF>!jh9pZRyRld*sXJJOV|2up_I^G|ePx}X9N<+J zwF+3TRUAO|({#H0tUF{})9%Dq^r)!_SecnRks3+Ori!WK#t<0HROEx%2W9OeW%T`# z!tTS@t;6|d0A@}SKr({M2qf#k%j}#z1Rz;PAZ8J30I^1#?2=)R+GEc`t6>c~xz!;| zOX4k~R^CoS-#sOS+I**A#-7wtDv^v33ITTvSvyQIK!ISWV7)I0qK5KYh;Qlvp@qn^ z5j3Oy#}OqNpTgIDKgJj{=WiW(=Am7ic_PEqygvlA8Q+IPq}Yhy(nbZB26QcvXJkzB zF<&I=Z9;Hlw0sSyPt&#I&{cgJ`ZL6pyl4>{v6EMcmz#tStz(mb`#~Em^dw!UU1%8MQ>3taEl9EdgX(~yOnvq0`NmLaE75Cht@aXv85s&%JXI3=hb+T{cp zLu%lH4sKCjsw{J_bhxlCHf`HkD#IY<7WZ!DaNiOHY%gl95nucHBm!EIPY3B7qH~ze z3Y;KI5h<*a3Z~jm9Vr+RDytdfJ0WOh37}n(B#9S2@8At;=2f>*ht9b{R9&Wm24VKs z7oEVgA!3RUX$}3IGw)t_P&DCM{1a9vCYZFgX=B0!@rQto1@s9Rs2fDoOUUPFF+Evm8-jArL^|L@Fc}AcqZ<(XToO2stS=U%3wL{iBDavWLgD-%)M!N5 zQ8~`CF{uHDT^+Kaz*Y@=(9Tu!wc%D8EP$vI;8zejVh>$0TbX*&8`XYx{z?ogbC}uf zVWfN38}~C*D~poXYz@?pO4tO#KITmzj3JXqGf=0xbIhJX$d7oF1mB|l(@-o9)TtsH zvu6=}C-LT0amX@$N}i~Ct9adS{gZ$l5oz3ou{ z$1bBkpV|&3ezG;>4{h$Sw_lCgJ78$YdplRbjF8K2dnZ)VP}JJHt{PWktzkb2#i{`R zFZ{BfS>20~+YQzEh(EBp53Db;x*y8#3FII1MyK?7 zxEjIRo)^=kgA3O#?C*?R;hV0}I7WCJJ=Oq*)i_i~ZY$rNnJ6Wem=pe@yVOKD^%t?>#WL@Wq(M8xN zFuuKE?eRUhWqhan$s3F>RM7qu#&^b_xNdxRe1q}b7mn}#t>cSc{8@jZGr~}VV}!Tq zk@aUUbAD}qe+*;0y|tq|y8ev64Wm0V8Nmo|gS(xh?C+4V4o4a}|LKMOz5&7Hz(R96 zlt*Mw4HZRK@g@>&GpQAFbEnU~Y*i6IT@pti@$QSX5zc1gbAD;QP`@Z(73Gb^3lfpP=)@bf_j1g_)IrLS@l89{$e1<>;#a z$|SN}{x?4TI-TF@_{r+|-|_m>7%cVMc<8G?Kua*}r%DilrQU~nHb>GgCzz8g`w2c# zCdOVw$WGKEu2zdtqozJZhs&Z?kR=({fHi*`A4Vk&%jur^5E9S;dQk)KVA?;R!(mge z(BZ&{mP&n=PZT+V98_xIkrxtV5NiTlNTYs_MojcT?hJk-I^weI9~=R;JT3 zP>wY0IJAGHBr7AdDuZSvbyeoCtIK?EUFN%WneVU5d{061VL|hOv2+~Yc}9 zw*^q23#1CRlmIb64*+?~m@oju0;rMNDnNFjr$UHHT~=Fe!v^&0MiP?TwM-mR9N=1+ z7$rFUC4iDxczj|b&@WK;2F+O1d?Opr#u5k}fv%~YmzD&fdyAT1OI$MsO*7S(CKU(Z z61lT>S;ED5EW%{jSyR7d5vuiU02ZUxH!eFTWl!{#+?G|%uTjD@hnkr#%e+oY6Gwk7 z;_Fg4{}9AWJBp(f5a9sJ*FqmaMG!x5v|=*yVPiV-9_+0yq-E!1ENJd(sQZvukbohq zN>5@@*DNq-i-7$V`295pi8If5W=xcW&C#%$eG~pHVl4F4nyi5n_!Pb_l?%;#2$SA4 z#A+;{DEB0NAm##Z0&xlgF5Up#mp9W(Y$k1>JMGvk4jRDz9-`aCAL_E}{8Ioj3rI>s zu2hD$)ibaT)5Wecv++LI5aR-kN^Ao1>}7e}q>We!5f1SW0FfflC_HKS|AP7>)_>-G z;9H9;PZgkJg07vG1&v@jQK&>B`H1>n#xe^8Els!M!tBLUSly`V5@G?859|*zvvAbO zE?b#WwcOm0V>>|tFgbunt-J>Q05TVrh9kHF1>!m*dG_#S;k7N~+=U~FouHUd&(K%+ z_OIZ{N_X>1butqDdH?7K<0Yx#7FGS%Ne{b`5wP|`?mU9R31Z4L$;8IKmdotc2&8Ro zO1V)+eTx?ftV1k+1eVa%w%P`9@@mF2R%4zC0Hi>VAW@p`Iu7nxZ+p zRrEWm{shSZcn#NoIg~hpxkg+wUnmnURW8A#iuv){K7oZ_o`Iw<;_g1wG|6|jT9bt6s`7MlCc;OwTHZPI4$w`%UbPOscN zUT?REM0r&G1$*^)w;4Chpr;*&`r2``Q~-SL=;4aLgaU)aSvggztrjMDOP3jn7CSoHz=-%E$kP|cMKd?Mab zPKpL&B^L&YHWmQfwgkIkc#6uJB0=1UZ(6_*5D?n7od!8gSpe6QK95^Nw2o@Ylj5JN zf%5&a=7VTy>uPVKVmV0a?mxgo#-!7-ITY&B2?cH<2EqpwS!;SCgL6!i?^dIdZ7^ zAig4V{5%E)s+AZHPC>%J#(g&wWN{nNYv4?Aw#9~&1Q)ndxTd&3jq;#7K7HYK>P4Io z@e|NsC*%$W#|ivD&lR*HVbspdV`om}AD!Bqjm z>v)pk)S36_dn|Ye6E`ocD}A2@8Sb36?c;Q#oM#8w2<{~j0$y+cUNABuCJUER-v=t{ zf>|I%$xdAMyeKBAk{Sk~ z0)2$ZALQD8x?N}0&#|!Q>3ob1r>5G3vw*|tiS;#LW*qftdVV{dIXWdevPBp6^%P&< zPDfPH#lAJh49I{yStF{f4iKmtb;PmJL>PT1f!`?3bw zY={v04>9GNaIAFLA{^mKv(vGwt*rrvA0jBE*niH5NqCW8sjE!;FX{9Pq>3k#zmp8= zSp=G3A992P8hriHdpXTwvahCwWDf{jf)33H<02LVd#%60UN6@hznw?+)$9-tlj3-Y zNz_FGtjD}I5tE(WAamp~(#)65(fps9qs`fnCEbXT_nj8u_6;@HjEO@4Lwl4*?J!Mo z){+yJr{b~MiH=3E-7#j#v`mnGY+?D+-}{*#`rMOWIkCL-&R2i+!@vK#4??#U0~E!| zQ00M}3S6iL<@xDGb-m_!!ogx3LkKZjjAl*w^) z#>|=MwODuS*Ap*D;w&2$s*OH@p1VJOPvPr+7aVYgC^$oSJf>xkGy6QsdVpfL-pks+ zTnm(9J9q?69D*~NCi?AsR2|^E9h=07034*fwGQ5MVFKrzApA2Pc?k}hM=wlV$RaO@ z?>KO4sJrbrkg{ShL02XgFptsf$Sd=7xZb_;AK+A8fM=zE_7)SN{_wZ%UDk@kSMFW@ z&%emOKe>1LR{8tFy~{|0zu&u8?PMoXovJwpU$L{kj~+4s@4*w!lvL^{7&jb2BP0g~ z%c_E9nKRZJT!P^eNF+UZFp9h*M{eK{Ak~&B5Delz3Y>W<=~TbKA>!_x69Ll!`|~}7 zRNw*teM&TCG{-^`?Mt{w1zP)$@(JIOS;YMuMjS(v#y5j~6$TBrGr%f?_pB_U)z+XF zUB&ht9Y2UOOaL)+@kE?l=#Pt2y8g%kRxW7SjnTsB*-aw-t4?1 zEPPMml9Uuc^?PBX;U+&1Yt^vNGf`rbfqf^_KEQw{*(u5C{W&p2W+Mhn}s5r3R^I2;?u+e3slh&6Cn>~tdtPw8;P zn5=)?6T2GJBvby>X2u?84ebd4@grDFr_duatR23y7^`u7xAo>KWkX#g2U!k#(h5T9 zTig9ve>;ScX@6#n=Kk@S$axdC8MA%>W`PM7qy@D$I;K?&szA~@U}yq zVBKTuSVJFnHH5LvQm*?K-I~I0o9;7|bpQ3e<_$9(#}If&!<6o|wwlJ9vZtl*wuLs# zof2!t-`S1=`_0_+!X3Oay)DAme9}OXh#@|-<*95(}QvCkycINpLW+=K64JGmUyeSo7>`@DUd`>*0? zRmW$)e+!=V*v8dUc-||0YvK^V=cN{yRXF7Fy+Vr!!w8LTq2dG5hx`2lojxSp9=IIs z&@G-hvZzlH--Gxb!uK$G^8m~6N2hQGZ0?aYiE5jlc3%puNEw5r%T<{dLY5a*-h-kE6__gnPrY|Ue)qs4DA z%MJClI*!nUbNXBCAO9wE{RPAQgPz$A8o}5XL6tY4kUqM=v9ORXV(Rn@u?u(mHyTy? zZyA{mrvQbuX;}E^LE)pvh0ksS?>KaG-XZzmm`*#7HYe)kDWGsY-(za?Ci6X}e)1d5 zckiG-C-WhLb{`S?1qbj)PDar7pXqV~YzL2kgT&NtdWZd8_MCmxK8E$;_?7raqxD6E z&P(Xu^v(S*OIvRBkG9+5?|Lp;I#Pe9H-|iQ4!2z~k;4PX;dbO;szDF4Rc0``+A(~O z<2!HPaV3FxbBK4R#JjDhBvb8ui;^b1TW^wE(z^`@n-Z)Yu^j$Bl>dNtyT9*pw2t+( zb*p##bB2E__{!vEqjg)mZ1;(->ctU5K|;0PS;>ZH0VYMS^xXL&;oAcdt+53BFJ0oL%kk`8wY!dZ6Q;DlTO zz*t5lyR|8l5fB=$!y{JE_W?hG>!+XvsS#CxZwLZ$a9PG;1vkcR^t9r!vchH`iUYS( z-jU8<$4hGvsXZy2!0n-~*9j@$?_Kbq@|-(%Kq$QxDSc~;G>C=uI)?-P;hlkYCnLS- z=~aoNC^WSbfh9dAx^mr#a>-wBnYy1kIrsL{Mcf|2VyP$U7Dx6$3b`(CEhzD0tn^d!nqCL?FAfKI5v0ecqbDSK}8da!B~8QLI7_G zkBs%vp(v{4@Sc!qUs<;vJ9ey7R@ZExyY4NDJ1v%la3GMz?j=e=IrkN@yC0mhyHi}d zr`1hMdw2VG$Zx(tf%A0B*>Y**HyE)qfMS$!=NjmVS4@Vb5^GbzbroC2sXLn=7bW3! zlc3Ys1x%{W3rX=T?$T}dz;$y{=!`| zn&j4DVrVH2YH0q|ligW`+wQsePRPqC@>Doj;@V zWjM|F9FfvmjdQvrVjGU8aIITz}7x~ose zesv?v-RZOHzaZgtwWFSnN-R+&N$t`a6R5RveYilt=6(%PHKOYls=sakE4!;*$C-%w z`-okP1_>}lT%uX|60}T}{C6N7OKgsXw>pK(l|ChVT?MI`V)uPz1ddcNPE5dRq^I)$ zOibV&Biwi#_FxbfUD_56nqnWm5A_qA;s~!Eqgo!PYnT)#bWL8es(ux*LWAC|^HAw^ zfR@34Q|DdOCEieri;Xq^FNuY=80*GUN6DlV&4t#xssriOAtK+W7-}` zUyYml9Tcwq2c6%C)6CPJr{Q5-XmQ%O%CJ9zb6rnS&&-bfW&wb&q9+6ZcwJ8IK36=? zs5@~VpAK|hg*ZDRrt6Y;oKhnE(^;v4IxQ^#1nSD0m7cdK~(JoBZhSN#GV zYCF{*!zsqBF`#X*ZV1*+%4CtjP}?gkB^Yj0?XXdy7jA7MPExZ{=yS=J!R&fi!imCl zy^2oKQ*8ndZTPLvqjoR|_E2*%s{V%c3z@+PT7^k8q^Y321Svf%ogsJ?~VEQ$df8$eYcV~8f`+agc4`cH2b{1-sBQ-`xOc6iF02;cQyy>lnet=^b z?)r#e1g}AKHE&kZB+#4%Z&?OWU<}&5V{02Y=;M_7Yi7Qm4s9>m{6%OePgj$x2T2$- z$b7T~j0*$~T!YbT&}zr5#^Xh`239aC-L;CJ{jHL8#_8;mR`F?t&acs7->Yc`{1u-Z zI>Jz}Pr?%&GB$lFG|ks)hz_C9irtATw($=FN?d9I0bi^_EYY^SPa)8Xqt*@A8q~BA zc8Rt@>}S@TZoq?vyBx5t9vtnB=EX^thUXZBb_-TKSy>h%zdt(EIvfvnN3{f=;Jq?--9p+vvYuAgMR)l;$;&s>+p z<$ZwE&VP9TB&0gd5uaxZhqja)q*yiSb}|XMAd;ONHUz_uy)c%Y&|^^KUXZWu-!V>Q z7*J)q*6tkjFVLU3=i9y{EXvAhSy`?)z=ia}t*o(jmxf$MsvztTyfWZ`+{!u&`3isL z^=-caM&N!QWH<9TU^Te3IAu)1!W>@Sh6~t@H}gajcTUoeMYp|fJd*;v26k(K%K%CP zB(eZvV2RGr8V#TgkuP>c7ei>Svk@dG-x8i2~h> zoCwDuD2MoJ#70SY*1uuCi8pu~K>L&WdM=4;<0XuF!uEbGXR>7G#=7Q_FRbZdYF=s! z2+^Ck;Q)4c;+WckZ?{ypGlOup zHTRbpSo7IG1gEkzsg2}}n+|{kA3-g6y{^=_+eYBZ+n+YpJd0jL&Cny=PwMqdl0YZi zf52J*ht4n!YhvwFuVqs0@P7YXdPMpb*aDV?E345fJ$a{I&!)0*TtFIQo*ZA`7~ae$ zQ|M3J)4*}X(Q%Y?4W{0>@j6Z@re=Uk*L6#vUI|=zZ32{qUI6OdAj1C5teK1MxMt$k z2euk(?}s8E%6&ER!N>{GrLC_dsy!&=6X)1!t5o)aO?(8?NH7G#EX%3BOe@j$NvU} C-2-bv}Cb`%6LsB(NEb(Q5~>{R!A{rbJvAMdL-Z#HWd{Qma$ zZ-+B2%lZcnPX8J(xCc%C*|scZQH$BkNo;Bx=upQ%m%0WLN(}UKffZUY zuu7{2)@aSZI;|VnpbZ0?v}s_Awg9?GJ8jdpjcb#nlg`i?gL}zrdWN1ccqu8Rb9Bz& z<)o6%(|LnelC$YKdd}e0LM)LjiD!pp(X3|YRq#qi*m3)+5qt^`HPOhgt+Dn({QhI~l zFtkqcae9;9H24h2-J-VupJivhw&-m($L0b4pn8_Q!_M&^@|z15JJ0X11!#->CjSxt zgkReFwC^lf53GKLU0@f#_ULEKd2V&z-T7;6B_C~zcraA{)u`(~RV+ zO8Q~HA186Bxa`$x8^c)oG9QbG`%%ug55pvmcqTdX#~I_oS3~YUUE4S!`~4i~ew^V% z?aA88>iY9lA5?;@q4^@umCs_qBbAG7KkozQ$tp^4-mASqM_cdOy0jTfU<9lc<(cAN zsWi;O0T)22vQULl5=sdJ@3ja=h?GYnXJaEmo`q`Ib#!y%#k195AVO_aIc!FoJ+N{&|2^pshfYe}RtKQwXG~qpXR=++F97 z_onW|-Lv;CVoiy-KOuY0zP0BvvTut;0fJPTHyKP%{)Jd1`)QbbFe zc>DQCEsa&2$lhpMwtZ6Abt1893s;XMw*>waxkz}6XA*#+wyY_GGVg}YM;d(E9)hN!u${;tvy$8*S z1$(Scy@|8c#NBPR46iUJyghZzF3lYj`xbvvWUF1AX5yF~p3H)q{=q43B&&jCr{LSS z3mQmviah}~!>Ygwr-s}64a}21I%sX(_~&jqC6elT|;%Kd(PUq5ame%>9e<-> zHwW+&XMu5N@#mQwk4CvrkVzqlhc}A30DRIq!5?IDFBb*baz%0 z*pAfIATZf42t)}tR0>BBXH0vLSEHe>#xhWlsdJ%;jCZ)M1AnLj&=mx_7Rxw;0L&t; z-8fU)>nC}rbnO>mGUls7fLM7LO5B?G0M%EI_oQ9Wqf1B1lRO)A%SAd4l2|IyMSgvV3@FCRA>ML=7UZ;?yAeW}R8+uiCTlRh*99a!L?=9l&QH z7SG#_gITf&BXJkR9*IwIxoWQZ!2s&PP?uE>2aaI3f~&E}A{Q8rewcQTBlPz$P}UyA z=DrOvX3OQB_Xm)0G5&TjM(0iRl_sSv7b-oZ%-=a!Wv69Gvt&}<@)(Sa$G21oyqSQn zx1R+ru@;67+vX9Hb+9f}MQV(j3d3&?K8SW5f^Ro_F^p2^L$o(0Eiha7Z7znoAta*&X9M9e|j_D+kU2@ftR9#Ls}K>pXhYUDo9wV0i># z;f7M_I_iQ-0#^w>8^Mc7S70Ho;uMM(7b44bwLi|_0h1>XhjEH=oWz@e!fT+Qk}a8M z0yBelatXz5ya@#gqCjd_Li7^UCNOE}+Tqg%Gzj!hUB_dY*yd7Tz7qIKD5AVPf=3N< zVts)i+#bQ@3Up6^!tz<&pN?yV(ycrMYzsrk33 zOi32-N5j(1V(##9CE7p}9qeYHqYH0KqS8Y1EHa^+n`&%Q0zMqTq+}CJB|e8KxP2vb zX~Sej;{&I%k#R2{xhl&kJR=i+7vC`Q*U%BGQm#Pm`Trlj6^_1+$zB!g;X9(v!7BnN VZe`xNLJmr{bKo4dj??<@{{S6znvVbg diff --git a/mace-bench/src/batchopt/atoms_to_graphs.py b/mace-bench/src/batchopt/atoms_to_graphs.py index ace2a73..d50fc76 100644 --- a/mace-bench/src/batchopt/atoms_to_graphs.py +++ b/mace-bench/src/batchopt/atoms_to_graphs.py @@ -1,309 +1,309 @@ -""" -Copyright (c) Meta, Inc. and its affiliates. - -This source code is licensed under the MIT license found in the -LICENSE file in the root directory of this source tree. -""" - -from __future__ import annotations - -from typing import TYPE_CHECKING - -import ase.db.sqlite -import ase.io.trajectory -import numpy as np -import torch -from ase.geometry import wrap_positions -from torch_geometric.data import Data - -from batchopt.utils import collate - -if TYPE_CHECKING: - from collections.abc import Sequence - -try: - from pymatgen.io.ase import AseAtomsAdaptor -except ImportError: - AseAtomsAdaptor = None - -from tqdm import tqdm - - -class AtomsToGraphs: - """A class to help convert periodic atomic structures to graphs. - - The AtomsToGraphs class takes in periodic atomic structures in form of ASE atoms objects and converts - them into graph representations for use in PyTorch. The primary purpose of this class is to determine the - nearest neighbors within some radius around each individual atom, taking into account PBC, and set the - pair index and distance between atom pairs appropriately. Lastly, atomic properties and the graph information - are put into a PyTorch geometric data object for use with PyTorch. - - Args: - max_neigh (int): Maximum number of neighbors to consider. - radius (int or float): Cutoff radius in Angstroms to search for neighbors. - r_energy (bool): Return the energy with other properties. Default is False, so the energy will not be returned. - r_forces (bool): Return the forces with other properties. Default is False, so the forces will not be returned. - r_stress (bool): Return the stress with other properties. Default is False, so the stress will not be returned. - r_distances (bool): Return the distances with other properties. - Default is False, so the distances will not be returned. - r_edges (bool): Return interatomic edges with other properties. Default is True, so edges will be returned. - r_fixed (bool): Return a binary vector with flags for fixed (1) vs free (0) atoms. - Default is True, so the fixed indices will be returned. - r_pbc (bool): Return the periodic boundary conditions with other properties. - Default is False, so the periodic boundary conditions will not be returned. - r_data_keys (sequence of str, optional): Return values corresponding to given keys in atoms.info data with other - properties. Default is None, so no data will be returned as properties. - - Attributes: - max_neigh (int): Maximum number of neighbors to consider. - radius (int or float): Cutoff radius in Angstoms to search for neighbors. - r_energy (bool): Return the energy with other properties. Default is False, so the energy will not be returned. - r_forces (bool): Return the forces with other properties. Default is False, so the forces will not be returned. - r_stress (bool): Return the stress with other properties. Default is False, so the stress will not be returned. - r_distances (bool): Return the distances with other properties. - Default is False, so the distances will not be returned. - r_edges (bool): Return interatomic edges with other properties. Default is True, so edges will be returned. - r_fixed (bool): Return a binary vector with flags for fixed (1) vs free (0) atoms. - Default is True, so the fixed indices will be returned. - r_pbc (bool): Return the periodic boundary conditions with other properties. - Default is False, so the periodic boundary conditions will not be returned. - r_data_keys (sequence of str, optional): Return values corresponding to given keys in atoms.info data with other - properties. Default is None, so no data will be returned as properties. - """ - - def __init__( - self, - max_neigh: int = 200, - radius: int = 6, - r_energy: bool = False, - r_forces: bool = False, - r_distances: bool = False, - r_edges: bool = True, - r_fixed: bool = True, - r_pbc: bool = False, - r_stress: bool = False, - r_data_keys: Sequence[str] | None = None, - ) -> None: - self.max_neigh = max_neigh - self.radius = radius - self.r_energy = r_energy - self.r_forces = r_forces - self.r_stress = r_stress - self.r_distances = r_distances - self.r_fixed = r_fixed - self.r_edges = r_edges - self.r_pbc = r_pbc - self.r_data_keys = r_data_keys - - def _get_neighbors_pymatgen(self, atoms: ase.Atoms): - """Preforms nearest neighbor search and returns edge index, distances, - and cell offsets""" - if AseAtomsAdaptor is None: - raise RuntimeError( - "Unable to import pymatgen.io.ase.AseAtomsAdaptor. Make sure pymatgen is properly installed." - ) - - struct = AseAtomsAdaptor.get_structure(atoms) - _c_index, _n_index, _offsets, n_distance = struct.get_neighbor_list( - r=self.radius, numerical_tol=0, exclude_self=True - ) - _nonmax_idx = [] - for i in range(len(atoms)): - idx_i = (_c_index == i).nonzero()[0] - # sort neighbors by distance, remove edges larger than max_neighbors - idx_sorted = np.argsort(n_distance[idx_i])[: self.max_neigh] - _nonmax_idx.append(idx_i[idx_sorted]) - _nonmax_idx = np.concatenate(_nonmax_idx) - - _c_index = _c_index[_nonmax_idx] - _n_index = _n_index[_nonmax_idx] - n_distance = n_distance[_nonmax_idx] - _offsets = _offsets[_nonmax_idx] - - return _c_index, _n_index, n_distance, _offsets - - def _reshape_features(self, c_index, n_index, n_distance, offsets): - """Stack center and neighbor index and reshapes distances, - takes in np.arrays and returns torch tensors""" - edge_index = torch.LongTensor(np.vstack((n_index, c_index))) - edge_distances = torch.FloatTensor(n_distance) - cell_offsets = torch.LongTensor(offsets) - - # remove distances smaller than a tolerance ~ 0. The small tolerance is - # needed to correct for pymatgen's neighbor_list returning self atoms - # in a few edge cases. - nonzero = torch.where(edge_distances >= 1e-8)[0] - edge_index = edge_index[:, nonzero] - edge_distances = edge_distances[nonzero] - cell_offsets = cell_offsets[nonzero] - - return edge_index, edge_distances, cell_offsets - - def get_edge_distance_vec( - self, - pos, - edge_index, - cell, - cell_offsets, - ): - row, col = edge_index - distance_vectors = pos[row] - pos[col] - - # correct for pbc - cell = torch.repeat_interleave(cell, edge_index.shape[1], dim=0) - offsets = cell_offsets.float().view(-1, 1, 3).bmm(cell.float()).view(-1, 3) - distance_vectors += offsets - - return distance_vectors - - def convert(self, atoms: ase.Atoms, sid=None): - """Convert a single atomic structure to a graph. - - Args: - atoms (ase.atoms.Atoms): An ASE atoms object. - - sid (uniquely identifying object): An identifier that can be used to track the structure in downstream - tasks. Common sids used in OCP datasets include unique strings or integers. - - Returns: - data (torch_geometric.data.Data): A torch geometic data object with positions, atomic_numbers, tags, - and optionally, energy, forces, distances, edges, and periodic boundary conditions. - Optional properties can included by setting r_property=True when constructing the class. - """ - - # set the atomic numbers, positions, and cell - positions = np.array(atoms.get_positions(), copy=True) - pbc = np.array(atoms.pbc, copy=True) - cell = np.array(atoms.get_cell(complete=True), copy=True) - # TODO: change this back &&& ^^^ - # positions = wrap_positions(positions, cell, pbc=pbc, eps=0) - - atomic_numbers = torch.tensor(atoms.get_atomic_numbers(), dtype=torch.uint8) - positions = torch.from_numpy(positions).float() - cell = torch.from_numpy(cell).view(1, 3, 3).float() - natoms = positions.shape[0] - - # initialized to torch.zeros(natoms) if tags missing. - # https://wiki.fysik.dtu.dk/ase/_modules/ase/atoms.html#Atoms.get_tags - tags = torch.tensor(atoms.get_tags(), dtype=torch.int) - - # put the minimum data in torch geometric data object - data = Data( - cell=cell, - pos=positions, - atomic_numbers=atomic_numbers, - natoms=natoms, - tags=tags, - ) - - # Optionally add a systemid (sid) to the object - if sid is not None: - data.sid = sid - - # optionally include other properties - if self.r_edges: - # run internal functions to get padded indices and distances - atoms_copy = atoms.copy() - atoms_copy.set_positions(positions) - split_idx_dist = self._get_neighbors_pymatgen(atoms_copy) - edge_index, edge_distances, cell_offsets = self._reshape_features( - *split_idx_dist - ) - - data.edge_index = edge_index - data.cell_offsets = cell_offsets - data.edge_distance_vec = self.get_edge_distance_vec( - positions, edge_index, cell, cell_offsets - ) - - del atoms_copy - if self.r_energy: - energy = atoms.get_potential_energy(apply_constraint=False) - data.energy = energy - if self.r_forces: - forces = torch.tensor( - atoms.get_forces(apply_constraint=False), dtype=torch.float32 - ) - data.forces = forces - if self.r_stress: - stress = torch.tensor( - atoms.get_stress(apply_constraint=False, voigt=False), - dtype=torch.float32, - ) - data.stress = stress - if self.r_distances and self.r_edges: - data.distances = edge_distances - if self.r_fixed: - fixed_idx = torch.zeros(natoms, dtype=torch.int) - if hasattr(atoms, "constraints"): - from ase.constraints import FixAtoms - - for constraint in atoms.constraints: - if isinstance(constraint, FixAtoms): - fixed_idx[constraint.index] = 1 - data.fixed = fixed_idx - if self.r_pbc: - data.pbc = torch.tensor(atoms.pbc, dtype=torch.bool) - if self.r_data_keys is not None: - for data_key in self.r_data_keys: - data[data_key] = ( - atoms.info[data_key] - if isinstance(atoms.info[data_key], (int, float, str)) - else torch.tensor(atoms.info[data_key]) - ) - - return data - - def convert_all( - self, - atoms_collection, - processed_file_path: str | None = None, - collate_and_save=False, - disable_tqdm=False, - ): - """Convert all atoms objects in a list or in an ase.db to graphs. - - Args: - atoms_collection (list of ase.atoms.Atoms or ase.db.sqlite.SQLite3Database): - Either a list of ASE atoms objects or an ASE database. - processed_file_path (str): - A string of the path to where the processed file will be written. Default is None. - collate_and_save (bool): A boolean to collate and save or not. Default is False, so will not write a file. - - Returns: - data_list (list of torch_geometric.data.Data): - A list of torch geometric data objects containing molecular graph info and properties. - """ - - # list for all data - data_list = [] - if isinstance(atoms_collection, list): - atoms_iter = atoms_collection - elif isinstance(atoms_collection, ase.db.sqlite.SQLite3Database): - atoms_iter = atoms_collection.select() - elif isinstance( - atoms_collection, - (ase.io.trajectory.SlicedTrajectory, ase.io.trajectory.TrajectoryReader), - ): - atoms_iter = atoms_collection - else: - raise NotImplementedError - - for atoms in tqdm( - atoms_iter, - desc="converting ASE atoms collection to graphs", - total=len(atoms_collection), - unit=" systems", - disable=disable_tqdm, - ): - # check if atoms is an ASE Atoms object this for the ase.db case - data = self.convert( - atoms if isinstance(atoms, ase.atoms.Atoms) else atoms.toatoms() - ) - data_list.append(data) - - if collate_and_save: - data, slices = collate(data_list) - torch.save((data, slices), processed_file_path) - - return data_list +""" +Copyright (c) Meta, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import ase.db.sqlite +import ase.io.trajectory +import numpy as np +import torch +from ase.geometry import wrap_positions +from torch_geometric.data import Data + +from batchopt.utils import collate + +if TYPE_CHECKING: + from collections.abc import Sequence + +try: + from pymatgen.io.ase import AseAtomsAdaptor +except ImportError: + AseAtomsAdaptor = None + +from tqdm import tqdm + + +class AtomsToGraphs: + """A class to help convert periodic atomic structures to graphs. + + The AtomsToGraphs class takes in periodic atomic structures in form of ASE atoms objects and converts + them into graph representations for use in PyTorch. The primary purpose of this class is to determine the + nearest neighbors within some radius around each individual atom, taking into account PBC, and set the + pair index and distance between atom pairs appropriately. Lastly, atomic properties and the graph information + are put into a PyTorch geometric data object for use with PyTorch. + + Args: + max_neigh (int): Maximum number of neighbors to consider. + radius (int or float): Cutoff radius in Angstroms to search for neighbors. + r_energy (bool): Return the energy with other properties. Default is False, so the energy will not be returned. + r_forces (bool): Return the forces with other properties. Default is False, so the forces will not be returned. + r_stress (bool): Return the stress with other properties. Default is False, so the stress will not be returned. + r_distances (bool): Return the distances with other properties. + Default is False, so the distances will not be returned. + r_edges (bool): Return interatomic edges with other properties. Default is True, so edges will be returned. + r_fixed (bool): Return a binary vector with flags for fixed (1) vs free (0) atoms. + Default is True, so the fixed indices will be returned. + r_pbc (bool): Return the periodic boundary conditions with other properties. + Default is False, so the periodic boundary conditions will not be returned. + r_data_keys (sequence of str, optional): Return values corresponding to given keys in atoms.info data with other + properties. Default is None, so no data will be returned as properties. + + Attributes: + max_neigh (int): Maximum number of neighbors to consider. + radius (int or float): Cutoff radius in Angstoms to search for neighbors. + r_energy (bool): Return the energy with other properties. Default is False, so the energy will not be returned. + r_forces (bool): Return the forces with other properties. Default is False, so the forces will not be returned. + r_stress (bool): Return the stress with other properties. Default is False, so the stress will not be returned. + r_distances (bool): Return the distances with other properties. + Default is False, so the distances will not be returned. + r_edges (bool): Return interatomic edges with other properties. Default is True, so edges will be returned. + r_fixed (bool): Return a binary vector with flags for fixed (1) vs free (0) atoms. + Default is True, so the fixed indices will be returned. + r_pbc (bool): Return the periodic boundary conditions with other properties. + Default is False, so the periodic boundary conditions will not be returned. + r_data_keys (sequence of str, optional): Return values corresponding to given keys in atoms.info data with other + properties. Default is None, so no data will be returned as properties. + """ + + def __init__( + self, + max_neigh: int = 200, + radius: int = 6, + r_energy: bool = False, + r_forces: bool = False, + r_distances: bool = False, + r_edges: bool = True, + r_fixed: bool = True, + r_pbc: bool = False, + r_stress: bool = False, + r_data_keys: Sequence[str] | None = None, + ) -> None: + self.max_neigh = max_neigh + self.radius = radius + self.r_energy = r_energy + self.r_forces = r_forces + self.r_stress = r_stress + self.r_distances = r_distances + self.r_fixed = r_fixed + self.r_edges = r_edges + self.r_pbc = r_pbc + self.r_data_keys = r_data_keys + + def _get_neighbors_pymatgen(self, atoms: ase.Atoms): + """Preforms nearest neighbor search and returns edge index, distances, + and cell offsets""" + if AseAtomsAdaptor is None: + raise RuntimeError( + "Unable to import pymatgen.io.ase.AseAtomsAdaptor. Make sure pymatgen is properly installed." + ) + + struct = AseAtomsAdaptor.get_structure(atoms) + _c_index, _n_index, _offsets, n_distance = struct.get_neighbor_list( + r=self.radius, numerical_tol=0, exclude_self=True + ) + _nonmax_idx = [] + for i in range(len(atoms)): + idx_i = (_c_index == i).nonzero()[0] + # sort neighbors by distance, remove edges larger than max_neighbors + idx_sorted = np.argsort(n_distance[idx_i])[: self.max_neigh] + _nonmax_idx.append(idx_i[idx_sorted]) + _nonmax_idx = np.concatenate(_nonmax_idx) + + _c_index = _c_index[_nonmax_idx] + _n_index = _n_index[_nonmax_idx] + n_distance = n_distance[_nonmax_idx] + _offsets = _offsets[_nonmax_idx] + + return _c_index, _n_index, n_distance, _offsets + + def _reshape_features(self, c_index, n_index, n_distance, offsets): + """Stack center and neighbor index and reshapes distances, + takes in np.arrays and returns torch tensors""" + edge_index = torch.LongTensor(np.vstack((n_index, c_index))) + edge_distances = torch.FloatTensor(n_distance) + cell_offsets = torch.LongTensor(offsets) + + # remove distances smaller than a tolerance ~ 0. The small tolerance is + # needed to correct for pymatgen's neighbor_list returning self atoms + # in a few edge cases. + nonzero = torch.where(edge_distances >= 1e-8)[0] + edge_index = edge_index[:, nonzero] + edge_distances = edge_distances[nonzero] + cell_offsets = cell_offsets[nonzero] + + return edge_index, edge_distances, cell_offsets + + def get_edge_distance_vec( + self, + pos, + edge_index, + cell, + cell_offsets, + ): + row, col = edge_index + distance_vectors = pos[row] - pos[col] + + # correct for pbc + cell = torch.repeat_interleave(cell, edge_index.shape[1], dim=0) + offsets = cell_offsets.float().view(-1, 1, 3).bmm(cell.float()).view(-1, 3) + distance_vectors += offsets + + return distance_vectors + + def convert(self, atoms: ase.Atoms, sid=None): + """Convert a single atomic structure to a graph. + + Args: + atoms (ase.atoms.Atoms): An ASE atoms object. + + sid (uniquely identifying object): An identifier that can be used to track the structure in downstream + tasks. Common sids used in OCP datasets include unique strings or integers. + + Returns: + data (torch_geometric.data.Data): A torch geometic data object with positions, atomic_numbers, tags, + and optionally, energy, forces, distances, edges, and periodic boundary conditions. + Optional properties can included by setting r_property=True when constructing the class. + """ + + # set the atomic numbers, positions, and cell + positions = np.array(atoms.get_positions(), copy=True) + pbc = np.array(atoms.pbc, copy=True) + cell = np.array(atoms.get_cell(complete=True), copy=True) + # TODO: change this back &&& ^^^ + # positions = wrap_positions(positions, cell, pbc=pbc, eps=0) + + atomic_numbers = torch.tensor(atoms.get_atomic_numbers(), dtype=torch.uint8) + positions = torch.from_numpy(positions).float() + cell = torch.from_numpy(cell).view(1, 3, 3).float() + natoms = positions.shape[0] + + # initialized to torch.zeros(natoms) if tags missing. + # https://wiki.fysik.dtu.dk/ase/_modules/ase/atoms.html#Atoms.get_tags + tags = torch.tensor(atoms.get_tags(), dtype=torch.int) + + # put the minimum data in torch geometric data object + data = Data( + cell=cell, + pos=positions, + atomic_numbers=atomic_numbers, + natoms=natoms, + tags=tags, + ) + + # Optionally add a systemid (sid) to the object + if sid is not None: + data.sid = sid + + # optionally include other properties + if self.r_edges: + # run internal functions to get padded indices and distances + atoms_copy = atoms.copy() + atoms_copy.set_positions(positions) + split_idx_dist = self._get_neighbors_pymatgen(atoms_copy) + edge_index, edge_distances, cell_offsets = self._reshape_features( + *split_idx_dist + ) + + data.edge_index = edge_index + data.cell_offsets = cell_offsets + data.edge_distance_vec = self.get_edge_distance_vec( + positions, edge_index, cell, cell_offsets + ) + + del atoms_copy + if self.r_energy: + energy = atoms.get_potential_energy(apply_constraint=False) + data.energy = energy + if self.r_forces: + forces = torch.tensor( + atoms.get_forces(apply_constraint=False), dtype=torch.float32 + ) + data.forces = forces + if self.r_stress: + stress = torch.tensor( + atoms.get_stress(apply_constraint=False, voigt=False), + dtype=torch.float32, + ) + data.stress = stress + if self.r_distances and self.r_edges: + data.distances = edge_distances + if self.r_fixed: + fixed_idx = torch.zeros(natoms, dtype=torch.int) + if hasattr(atoms, "constraints"): + from ase.constraints import FixAtoms + + for constraint in atoms.constraints: + if isinstance(constraint, FixAtoms): + fixed_idx[constraint.index] = 1 + data.fixed = fixed_idx + if self.r_pbc: + data.pbc = torch.tensor(atoms.pbc, dtype=torch.bool) + if self.r_data_keys is not None: + for data_key in self.r_data_keys: + data[data_key] = ( + atoms.info[data_key] + if isinstance(atoms.info[data_key], (int, float, str)) + else torch.tensor(atoms.info[data_key]) + ) + + return data + + def convert_all( + self, + atoms_collection, + processed_file_path: str | None = None, + collate_and_save=False, + disable_tqdm=False, + ): + """Convert all atoms objects in a list or in an ase.db to graphs. + + Args: + atoms_collection (list of ase.atoms.Atoms or ase.db.sqlite.SQLite3Database): + Either a list of ASE atoms objects or an ASE database. + processed_file_path (str): + A string of the path to where the processed file will be written. Default is None. + collate_and_save (bool): A boolean to collate and save or not. Default is False, so will not write a file. + + Returns: + data_list (list of torch_geometric.data.Data): + A list of torch geometric data objects containing molecular graph info and properties. + """ + + # list for all data + data_list = [] + if isinstance(atoms_collection, list): + atoms_iter = atoms_collection + elif isinstance(atoms_collection, ase.db.sqlite.SQLite3Database): + atoms_iter = atoms_collection.select() + elif isinstance( + atoms_collection, + (ase.io.trajectory.SlicedTrajectory, ase.io.trajectory.TrajectoryReader), + ): + atoms_iter = atoms_collection + else: + raise NotImplementedError + + for atoms in tqdm( + atoms_iter, + desc="converting ASE atoms collection to graphs", + total=len(atoms_collection), + unit=" systems", + disable=disable_tqdm, + ): + # check if atoms is an ASE Atoms object this for the ase.db case + data = self.convert( + atoms if isinstance(atoms, ase.atoms.Atoms) else atoms.toatoms() + ) + data_list.append(data) + + if collate_and_save: + data, slices = collate(data_list) + torch.save((data, slices), processed_file_path) + + return data_list diff --git a/mace-bench/src/batchopt/baseline.py b/mace-bench/src/batchopt/baseline.py index 552660b..d23595d 100644 --- a/mace-bench/src/batchopt/baseline.py +++ b/mace-bench/src/batchopt/baseline.py @@ -1,171 +1,171 @@ -""" -Copyright (c) 2025 {Chengxi Zhao, Zhaojia Ma, Dingrui Fan} - -This source code is licensed under the MIT license found in the -LICENSE file in the root directory of this source tree. -""" - -from ase.io import read -import logging -from joblib import Parallel, delayed -from ase.optimize import LBFGS as ASE_LBFGS -from ase.optimize import QuasiNewton as ASE_QuasiNewton -from ase.optimize import BFGS as ASE_BFGS -import time -import csv -import os -try: - from mace.calculators import mace_off -except ImportError: - logging.warning("Failed to import MACE modules") - -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') - - -def ensure_directory(directory): - """Create directory if it doesn't exist.""" - if not os.path.exists(directory): - os.makedirs(directory) - logging.info(f"Created directory: {directory}") - -def baseline_task(file, device, max_steps, filter1=None, filter2=None, skip_second_stage=False, scalar_pressure=0.0006, first_optimizer="LBFGS", second_optimizer="LBFGS"): - """ - Runs the baseline optimization using LBFGS from ase.optimize. - """ - os.environ["CUDA_VISIBLE_DEVICES"] = device.split(":")[-1] - logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') - logging.info(f"Starting baseline optimization for file {file} on device {device}.") - - - start_time = time.perf_counter() - - crystal = read(file) - # calc = mace_off(model="small", device=device) - calc = mace_off(model="small", device="cuda") - crystal.calc = calc - - first_optimizer_class ={ - "LBFGS": ASE_LBFGS, - "QuasiNewton": ASE_QuasiNewton, - "BFGS": ASE_BFGS - }.get(first_optimizer, ASE_LBFGS) - - # First optimization stage - if filter1 == "UnitCellFilter": - from ase.filters import UnitCellFilter - atoms_with_filter = UnitCellFilter(crystal, scalar_pressure=scalar_pressure) - first_optimizer_instance = first_optimizer_class(atoms_with_filter) - elif filter1 == "FrechetCellFilter": - from ase.filters import FrechetCellFilter - atoms_with_filter = FrechetCellFilter(crystal, scalar_pressure=scalar_pressure) - first_optimizer_instance = first_optimizer_class(atoms_with_filter) - else: - first_optimizer_instance = first_optimizer_class(crystal) - - start_time1 = time.perf_counter() - first_optimizer_instance.run(fmax=0.01, steps=max_steps) - end_time1 = time.perf_counter() - - # Save intermediate result - output_dir_press = "./cif_result_press" - output_file_press = os.path.join(output_dir_press, os.path.basename(file).replace(".cif", "_press.cif")) - crystal.write(output_file_press) - - elapsed_time1 = end_time1 - start_time1 - steps1 = first_optimizer_instance.nsteps - - if skip_second_stage: - - ret_result = { - "file": file, - "stage1_time": elapsed_time1, - "stage1_steps": steps1, - "stage2_time": 0.0, - "stage2_steps": 0, - "total_time": elapsed_time1, - "total_steps": steps1 - } - else: - # Second optimization stage - crystal = read(output_file_press) - crystal.calc = calc - - second_optimizer_class = { - "LBFGS": ASE_LBFGS, - "QuasiNewton": ASE_QuasiNewton, - "BFGS": ASE_BFGS - }.get(second_optimizer, ASE_LBFGS) - - if filter2 == "UnitCellFilter": - from ase.filters import UnitCellFilter - atoms_with_filter2 = UnitCellFilter(crystal) - second_optimizer_instance = second_optimizer_class(atoms_with_filter2) - elif filter2 == "FrechetCellFilter": - from ase.filters import FrechetCellFilter - atoms_with_filter2 = FrechetCellFilter(crystal) - second_optimizer_instance = second_optimizer_class(atoms_with_filter2) - else: - second_optimizer_instance = second_optimizer_class(crystal) - - start_time2 = time.perf_counter() - second_optimizer_instance.run(fmax=0.01, steps=max_steps) - end_time2 = time.perf_counter() - - # Save final result - output_dir_final = "./cif_result_final" - output_file_final = os.path.join(output_dir_final, os.path.basename(file).replace(".cif", "_opt.cif")) - crystal.write(output_file_final) - - # Collect metrics - elapsed_time2 = end_time2 - start_time2 - total_time = elapsed_time1 + elapsed_time2 - steps2 = second_optimizer_instance.nsteps - - ret_result = { - "file": file, - "stage1_time": elapsed_time1, - "stage1_steps": steps1, - "stage2_time": elapsed_time2, - "stage2_steps": steps2, - "total_time": total_time, - "total_steps": steps1 + steps2 - } - - logging.info(f"Baseline optimization completed for file {file}.") - return ret_result - -def run_baseline(files, num_workers, devices, max_steps, - filter1=None, filter2=None, skip_second_stage=False, scalar_pressure=0.0006, - optimizer1=None, optimizer2=None): - """ - Runs the baseline optimization using LBFGS from ase.optimize. - """ - logging.info(f"Starting baseline optimization with {num_workers} workers.") - - start_time = time.perf_counter() - results = Parallel(n_jobs=num_workers)( - delayed(baseline_task)(file, devices[i % len(devices)], max_steps, filter1, filter2, skip_second_stage, scalar_pressure, optimizer1, optimizer2) - for i, file in enumerate(files) - ) - end_time = time.perf_counter() - - csv_file = "results_baseline.csv" - with open(csv_file, mode='w', newline='') as file: - writer = csv.DictWriter(file, fieldnames=["file", "stage1_time", "stage1_steps", "stage2_time", "stage2_steps", "total_time", "total_steps"]) - writer.writeheader() - for result in results: - writer.writerow(result) - - logging.info(f"Baseline optimization completed in {end_time - start_time:.2f} seconds.") - final_elapsed_time = end_time - start_time - summary_csv_file = "summary_baseline.csv" - with open(summary_csv_file, mode='w', newline='') as file: - writer = csv.DictWriter(file, fieldnames=["elapsed_time", "num_workers", "batch_size"]) - writer.writeheader() - writer.writerow({ - "elapsed_time": final_elapsed_time, - "num_workers": num_workers, - "batch_size": 1 - }) - +""" +Copyright (c) 2025 {Chengxi Zhao, Zhaojia Ma, Dingrui Fan} + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +from ase.io import read +import logging +from joblib import Parallel, delayed +from ase.optimize import LBFGS as ASE_LBFGS +from ase.optimize import QuasiNewton as ASE_QuasiNewton +from ase.optimize import BFGS as ASE_BFGS +import time +import csv +import os +try: + from mace.calculators import mace_off +except ImportError: + logging.warning("Failed to import MACE modules") + +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') + + +def ensure_directory(directory): + """Create directory if it doesn't exist.""" + if not os.path.exists(directory): + os.makedirs(directory) + logging.info(f"Created directory: {directory}") + +def baseline_task(file, device, max_steps, filter1=None, filter2=None, skip_second_stage=False, scalar_pressure=0.0006, first_optimizer="LBFGS", second_optimizer="LBFGS"): + """ + Runs the baseline optimization using LBFGS from ase.optimize. + """ + os.environ["CUDA_VISIBLE_DEVICES"] = device.split(":")[-1] + logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') + logging.info(f"Starting baseline optimization for file {file} on device {device}.") + + + start_time = time.perf_counter() + + crystal = read(file) + # calc = mace_off(model="small", device=device) + calc = mace_off(model="small", device="cuda") + crystal.calc = calc + + first_optimizer_class ={ + "LBFGS": ASE_LBFGS, + "QuasiNewton": ASE_QuasiNewton, + "BFGS": ASE_BFGS + }.get(first_optimizer, ASE_LBFGS) + + # First optimization stage + if filter1 == "UnitCellFilter": + from ase.filters import UnitCellFilter + atoms_with_filter = UnitCellFilter(crystal, scalar_pressure=scalar_pressure) + first_optimizer_instance = first_optimizer_class(atoms_with_filter) + elif filter1 == "FrechetCellFilter": + from ase.filters import FrechetCellFilter + atoms_with_filter = FrechetCellFilter(crystal, scalar_pressure=scalar_pressure) + first_optimizer_instance = first_optimizer_class(atoms_with_filter) + else: + first_optimizer_instance = first_optimizer_class(crystal) + + start_time1 = time.perf_counter() + first_optimizer_instance.run(fmax=0.01, steps=max_steps) + end_time1 = time.perf_counter() + + # Save intermediate result + output_dir_press = "./cif_result_press" + output_file_press = os.path.join(output_dir_press, os.path.basename(file).replace(".cif", "_press.cif")) + crystal.write(output_file_press) + + elapsed_time1 = end_time1 - start_time1 + steps1 = first_optimizer_instance.nsteps + + if skip_second_stage: + + ret_result = { + "file": file, + "stage1_time": elapsed_time1, + "stage1_steps": steps1, + "stage2_time": 0.0, + "stage2_steps": 0, + "total_time": elapsed_time1, + "total_steps": steps1 + } + else: + # Second optimization stage + crystal = read(output_file_press) + crystal.calc = calc + + second_optimizer_class = { + "LBFGS": ASE_LBFGS, + "QuasiNewton": ASE_QuasiNewton, + "BFGS": ASE_BFGS + }.get(second_optimizer, ASE_LBFGS) + + if filter2 == "UnitCellFilter": + from ase.filters import UnitCellFilter + atoms_with_filter2 = UnitCellFilter(crystal) + second_optimizer_instance = second_optimizer_class(atoms_with_filter2) + elif filter2 == "FrechetCellFilter": + from ase.filters import FrechetCellFilter + atoms_with_filter2 = FrechetCellFilter(crystal) + second_optimizer_instance = second_optimizer_class(atoms_with_filter2) + else: + second_optimizer_instance = second_optimizer_class(crystal) + + start_time2 = time.perf_counter() + second_optimizer_instance.run(fmax=0.01, steps=max_steps) + end_time2 = time.perf_counter() + + # Save final result + output_dir_final = "./cif_result_final" + output_file_final = os.path.join(output_dir_final, os.path.basename(file).replace(".cif", "_opt.cif")) + crystal.write(output_file_final) + + # Collect metrics + elapsed_time2 = end_time2 - start_time2 + total_time = elapsed_time1 + elapsed_time2 + steps2 = second_optimizer_instance.nsteps + + ret_result = { + "file": file, + "stage1_time": elapsed_time1, + "stage1_steps": steps1, + "stage2_time": elapsed_time2, + "stage2_steps": steps2, + "total_time": total_time, + "total_steps": steps1 + steps2 + } + + logging.info(f"Baseline optimization completed for file {file}.") + return ret_result + +def run_baseline(files, num_workers, devices, max_steps, + filter1=None, filter2=None, skip_second_stage=False, scalar_pressure=0.0006, + optimizer1=None, optimizer2=None): + """ + Runs the baseline optimization using LBFGS from ase.optimize. + """ + logging.info(f"Starting baseline optimization with {num_workers} workers.") + + start_time = time.perf_counter() + results = Parallel(n_jobs=num_workers)( + delayed(baseline_task)(file, devices[i % len(devices)], max_steps, filter1, filter2, skip_second_stage, scalar_pressure, optimizer1, optimizer2) + for i, file in enumerate(files) + ) + end_time = time.perf_counter() + + csv_file = "results_baseline.csv" + with open(csv_file, mode='w', newline='') as file: + writer = csv.DictWriter(file, fieldnames=["file", "stage1_time", "stage1_steps", "stage2_time", "stage2_steps", "total_time", "total_steps"]) + writer.writeheader() + for result in results: + writer.writerow(result) + + logging.info(f"Baseline optimization completed in {end_time - start_time:.2f} seconds.") + final_elapsed_time = end_time - start_time + summary_csv_file = "summary_baseline.csv" + with open(summary_csv_file, mode='w', newline='') as file: + writer = csv.DictWriter(file, fieldnames=["elapsed_time", "num_workers", "batch_size"]) + writer.writeheader() + writer.writerow({ + "elapsed_time": final_elapsed_time, + "num_workers": num_workers, + "batch_size": 1 + }) + logging.info(f"Summary results written to {summary_csv_file}.") \ No newline at end of file diff --git a/mace-bench/src/batchopt/extensions/__init__.py b/mace-bench/src/batchopt/extensions/__init__.py index 7006541..2e0a0f2 100644 --- a/mace-bench/src/batchopt/extensions/__init__.py +++ b/mace-bench/src/batchopt/extensions/__init__.py @@ -1,12 +1,12 @@ -""" -Copyright (c) 2025 Ma Zhaojia - -This source code is licensed under the MIT license found in the -LICENSE file in the root directory of this source tree. - -BatchOpt Extensions - C++ and CUDA implementations for performance-critical operations. - -This module provides optimized implementations of common operations using -torch.utils.cpp_extension for JIT compilation. -""" - +""" +Copyright (c) 2025 Ma Zhaojia + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. + +BatchOpt Extensions - C++ and CUDA implementations for performance-critical operations. + +This module provides optimized implementations of common operations using +torch.utils.cpp_extension for JIT compilation. +""" + diff --git a/mace-bench/src/batchopt/extensions/__pycache__/__init__.cpython-310.pyc b/mace-bench/src/batchopt/extensions/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 05baaef36837c42cefd96f5e646b6c6ffacbd2f4..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 558 zcmZ8fu}&N@5cSdIWQl*6GAG0?LZXYJB%3Qh0U{ltps|eY!-Liy%l1Ybe?&o}(|#JG$q9>?mzA z-uMb1%>oV&4v*ke!K&XOU%bMwCdL8d z7?3}XyXDzx32W^6?u0=ms$elYQz>tN)|~l|rDo?EDZlAz{CTRdysONH6btM_x3>qH zt9QS?e+3MwcfrM~>l9v7hU7ATf{w=Rn~W8W?tvKH=5XPuqd<0GK`GO3w61VQ#enzR z)Bij*V6kHqh{nmEgv3wSH%NDh1Q$0o1t$uFtYZr}gOE&HG=n+eA--t`A6h^E{ UHEHuS|FrG%^Y{mpk#8>k0xXlX)c^nh diff --git a/mace-bench/src/batchopt/extensions/cuda_ops/__init__.py b/mace-bench/src/batchopt/extensions/cuda_ops/__init__.py index 1e6bf7a..5ce1bd1 100644 --- a/mace-bench/src/batchopt/extensions/cuda_ops/__init__.py +++ b/mace-bench/src/batchopt/extensions/cuda_ops/__init__.py @@ -1,91 +1,91 @@ -""" -Copyright (c) 2025 Ma Zhaojia - -This source code is licensed under the MIT license found in the -LICENSE file in the root directory of this source tree. - -CUDA Extension wrapper for vector addition and PBC graph operations. -""" -import torch -from torch.utils.cpp_extension import load -import os - -def load_cuda_extension(): - """Load the CUDA extension for vector addition.""" - # Check if CUDA is available - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is not available. Cannot load CUDA extension.") - - # Get the directory of this file - current_dir = os.path.dirname(os.path.abspath(__file__)) - - # Path to the CUDA source file - cuda_file = os.path.join(current_dir, "vector_add.cu") - - # Load the extension - return load( - name="vector_add_cuda", - sources=[cuda_file], - verbose=True, - extra_cflags=['-O3'], - extra_cuda_cflags=['-O3', '--use_fast_math'], - ) - -def load_pbc_graph_cuda_extension(): - """Load the CUDA extension for PBC graph operations.""" - # Check if CUDA is available - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is not available. Cannot load CUDA extension.") - - # Get the directory of this file - current_dir = os.path.dirname(os.path.abspath(__file__)) - - # Path to the CUDA source file - cuda_file = os.path.join(current_dir, "pbc_graph.cu") - - # Load the extension - return load( - name="pbc_graph_cuda", - sources=[cuda_file], - verbose=True, - extra_cflags=['-O3'], - extra_cuda_cflags=['-O3', '--use_fast_math'], - ) - -# Global variable to store loaded extension -_cuda_extension = None -_pbc_graph_cuda_extension = None - -def get_cuda_extension(): - """Get or load the CUDA extension.""" - global _cuda_extension - if _cuda_extension is None: - _cuda_extension = load_cuda_extension() - return _cuda_extension - -def get_pbc_graph_cuda_extension(): - """Get or load the PBC graph CUDA extension.""" - global _pbc_graph_cuda_extension - if _pbc_graph_cuda_extension is None: - _pbc_graph_cuda_extension = load_pbc_graph_cuda_extension() - return _pbc_graph_cuda_extension - -def vector_add_cuda(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: - """ - Perform vector addition using CUDA implementation. - - Args: - a: First input tensor (must be on CUDA device) - b: Second input tensor (must be on CUDA device) - - Returns: - Result tensor of element-wise addition - """ - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is not available.") - - if not (a.is_cuda and b.is_cuda): - raise ValueError("CUDA implementation requires CUDA tensors. Use .cuda() to move tensors to GPU.") - - extension = get_cuda_extension() - return extension.vector_add(a.float(), b.float()) +""" +Copyright (c) 2025 Ma Zhaojia + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. + +CUDA Extension wrapper for vector addition and PBC graph operations. +""" +import torch +from torch.utils.cpp_extension import load +import os + +def load_cuda_extension(): + """Load the CUDA extension for vector addition.""" + # Check if CUDA is available + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is not available. Cannot load CUDA extension.") + + # Get the directory of this file + current_dir = os.path.dirname(os.path.abspath(__file__)) + + # Path to the CUDA source file + cuda_file = os.path.join(current_dir, "vector_add.cu") + + # Load the extension + return load( + name="vector_add_cuda", + sources=[cuda_file], + verbose=True, + extra_cflags=['-O3'], + extra_cuda_cflags=['-O3', '--use_fast_math'], + ) + +def load_pbc_graph_cuda_extension(): + """Load the CUDA extension for PBC graph operations.""" + # Check if CUDA is available + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is not available. Cannot load CUDA extension.") + + # Get the directory of this file + current_dir = os.path.dirname(os.path.abspath(__file__)) + + # Path to the CUDA source file + cuda_file = os.path.join(current_dir, "pbc_graph.cu") + + # Load the extension + return load( + name="pbc_graph_cuda", + sources=[cuda_file], + verbose=True, + extra_cflags=['-O3'], + extra_cuda_cflags=['-O3', '--use_fast_math'], + ) + +# Global variable to store loaded extension +_cuda_extension = None +_pbc_graph_cuda_extension = None + +def get_cuda_extension(): + """Get or load the CUDA extension.""" + global _cuda_extension + if _cuda_extension is None: + _cuda_extension = load_cuda_extension() + return _cuda_extension + +def get_pbc_graph_cuda_extension(): + """Get or load the PBC graph CUDA extension.""" + global _pbc_graph_cuda_extension + if _pbc_graph_cuda_extension is None: + _pbc_graph_cuda_extension = load_pbc_graph_cuda_extension() + return _pbc_graph_cuda_extension + +def vector_add_cuda(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + """ + Perform vector addition using CUDA implementation. + + Args: + a: First input tensor (must be on CUDA device) + b: Second input tensor (must be on CUDA device) + + Returns: + Result tensor of element-wise addition + """ + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is not available.") + + if not (a.is_cuda and b.is_cuda): + raise ValueError("CUDA implementation requires CUDA tensors. Use .cuda() to move tensors to GPU.") + + extension = get_cuda_extension() + return extension.vector_add(a.float(), b.float()) diff --git a/mace-bench/src/batchopt/extensions/cuda_ops/__pycache__/__init__.cpython-310.pyc b/mace-bench/src/batchopt/extensions/cuda_ops/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 358ed12c90314b7bbb2b36f457f9a66592fece28..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 2301 zcmb7F&2QT_6c;7gk`*WE*Sd9ESMCd_fyLfB?2-*bTQ4gLv{}%kJLn?NB5lc#MUAA= z#c(fculpa8V-7p+pV48Lop$amJ8X|sYv*&o1n`OPk&o|_Kl0_}z=H3$-w%`ZCCmC# zGsnLsFn3|d1PHZ4iy~@gC`9mgGJE2L4yt8t=u($9URq&;HmL{yCbfwN!>2xN0pFql zU4nl=+YF7D=(3?Jnhu+Em0o)3hHYwZTb;|(KY~6l_Cz{L6uupI@SV@^e1#tq943U1 zQxXI_Nh-1ArHC1hIb|5wER9(%8O3EznZPPx`0>M?H)1^GAjfI07lKC*`x~3v8+e#z zplf7UaIP>-1&b9IdzcSl=fa8-jCF&c|Ma_ic;jaUmQtSMXMz+37#nhdca0(@l%`6r z5iqxPzmG?-C&3)H5iOV9;4jFgwb?9Pa6qiDJC17_7t6W2-YMT8&e&F^V)6Ud2JfVXsQ%q6)369phtf=z{9SOhp|bbEU1 z0aF<6rRh%bg;hDguDw3(+m9+AHXVj z)j93eBG2decO$weJtiyXdiuWh^dhzU;Dz{5laGN^h*W4$H3d_p$YY(;H=hHO1`M5J zc^Al8AevcYSFFy^*p*jRgW3k1nlG&9_RQY*X4bxcV9k)G2O8F$=|2I6Z;J_ln#mcw zEM=OH>LR6+B4ZP{at1NGM$M3WVkFniOrJ#7@%O2a3J|;~71p2(^}anRK{Q|(FxaG& zvR!x@I&-zbI^Je6H?P6}(K81;F{wJ{#1kgVY;GK04OY9p_AHfb{u&tF>2>4J88n@4 z*76fkW5IqZ;Wd(V5Or{}i=V=&aP26$-NB0E3EyR}BbS5D{DM7O&bOah^`t#xG<{Yh)=ZQr&_T!hn-^$r*DPO9r6%{)R#P1@=@FN zQPaML%t~ju@}h|HIEsuz;SGb*-BP8Q?8Zg0xa|gO#bte_hOWB1Z|RxA0jy`HW1Bms o%I5rI-YhPv|ILKcGGpH&(2*J)tN;K2 diff --git a/mace-bench/src/batchopt/pbc_graph.py b/mace-bench/src/batchopt/pbc_graph.py index d9cb744..998b047 100644 --- a/mace-bench/src/batchopt/pbc_graph.py +++ b/mace-bench/src/batchopt/pbc_graph.py @@ -1,158 +1,158 @@ -""" -Copyright (c) 2025 Ma Zhaojia - -This source code is licensed under the MIT license found in the -LICENSE file in the root directory of this source tree. - -CUDA-accelerated PBC graph operations for atomic systems. -""" -import torch -from typing import Optional, List -from .pbc_graph_legacy import get_max_neighbors_mask -from .extensions.cuda_ops import get_pbc_graph_cuda_extension - -def radius_graph_pbc_cuda( - data, - radius, - max_num_neighbors_threshold, - enforce_max_neighbors_strictly: bool = False, - pbc=None, - dtype=torch.float64, -): - """ - Memory-efficient CUDA-accelerated version of radius_graph_pbc. - - This implementation follows the memory-efficient approach with triple loops - but accelerates the distance computation using CUDA kernels. - """ - if pbc is None: - pbc = [True, True, True] - - device = data.pos.device - batch_size = len(data.natoms) - - # Handle PBC settings - if hasattr(data, "pbc"): - data.pbc = torch.atleast_2d(data.pbc) - for i in range(3): - if not torch.any(data.pbc[:, i]).item(): - pbc[i] = False - elif torch.all(data.pbc[:, i]).item(): - pbc[i] = True - else: - raise RuntimeError( - "Different structures in the batch have different PBC configurations." - ) - - # position of the atoms - atom_pos = data.pos - - # Before computing the pairwise distances between atoms, first create a list of atom indices to compare for the entire batch - num_atoms_per_image = data.natoms - num_atoms_per_image_sqr = (num_atoms_per_image**2).long() - - # index offset between images - index_offset = torch.cumsum(num_atoms_per_image, dim=0) - num_atoms_per_image - - index_offset_expand = torch.repeat_interleave(index_offset, num_atoms_per_image_sqr) - num_atoms_per_image_expand = torch.repeat_interleave( - num_atoms_per_image, num_atoms_per_image_sqr - ) - - # Compute atom pair indices - num_atom_pairs = torch.sum(num_atoms_per_image_sqr) - index_sqr_offset = ( - torch.cumsum(num_atoms_per_image_sqr, dim=0) - num_atoms_per_image_sqr - ) - index_sqr_offset = torch.repeat_interleave( - index_sqr_offset, num_atoms_per_image_sqr - ) - atom_count_sqr = torch.arange(num_atom_pairs, device=device) - index_sqr_offset - - # Compute the indices for the pairs of atoms (using division and mod) - index1 = ( - torch.div(atom_count_sqr, num_atoms_per_image_expand, rounding_mode="floor") - ) + index_offset_expand - index2 = (atom_count_sqr % num_atoms_per_image_expand) + index_offset_expand - # Get the positions for each atom - pos1 = torch.index_select(atom_pos, 0, index1) - pos2 = torch.index_select(atom_pos, 0, index2) - - # Calculate required number of unit cells in each direction for PBC - cross_a2a3 = torch.cross(data.cell[:, 1], data.cell[:, 2], dim=-1) - cell_vol = torch.sum(data.cell[:, 0] * cross_a2a3, dim=-1, keepdim=True) - - if pbc[0]: - inv_min_dist_a1 = torch.norm(cross_a2a3 / cell_vol, p=2, dim=-1) - rep_a1 = torch.ceil(radius * inv_min_dist_a1) - else: - rep_a1 = data.cell.new_zeros(1) - - if pbc[1]: - cross_a3a1 = torch.cross(data.cell[:, 2], data.cell[:, 0], dim=-1) - inv_min_dist_a2 = torch.norm(cross_a3a1 / cell_vol, p=2, dim=-1) - rep_a2 = torch.ceil(radius * inv_min_dist_a2) - else: - rep_a2 = data.cell.new_zeros(1) - - if pbc[2]: - cross_a1a2 = torch.cross(data.cell[:, 0], data.cell[:, 1], dim=-1) - inv_min_dist_a3 = torch.norm(cross_a1a2 / cell_vol, p=2, dim=-1) - rep_a3 = torch.ceil(radius * inv_min_dist_a3) - else: - rep_a3 = data.cell.new_zeros(1) - - # Take the max over all images for uniformity - max_rep = [int(2*rep_a1.max().item()), int(2*rep_a2.max().item()), int(2*rep_a3.max().item())] - - # Pre-transpose data_cell for efficiency - data_cell = torch.transpose(data.cell, 1, 2) - - # Use CUDA kernel for the triple loop computation - # try: - pbc_graph_cuda = get_pbc_graph_cuda_extension() - - # Call the CUDA implementation - valid_pair_indices, unit_cell, atom_distance_sqr = pbc_graph_cuda.pbc_distance_cuda( - pos1, pos2, data_cell, - num_atoms_per_image_sqr, batch_size, max_rep, float(radius), device - ) - - # Map back to original index1 and index2 - if len(valid_pair_indices) > 0: - index1 = index1.index_select(0, valid_pair_indices.long()) - index2 = index2.index_select(0, valid_pair_indices.long()) - else: - index1 = torch.empty(0, dtype=torch.long, device=device) - index2 = torch.empty(0, dtype=torch.long, device=device) - unit_cell = torch.empty(0, 3, dtype=dtype, device=device) - atom_distance_sqr = torch.empty(0, dtype=dtype, device=device) - - # Sort index1 in ascending order and rearrange other arrays correspondingly - if len(index1) > 0: - sort_indices = torch.argsort(index1) - index1 = index1[sort_indices] - index2 = index2[sort_indices] - unit_cell = unit_cell[sort_indices] - atom_distance_sqr = atom_distance_sqr[sort_indices] - - mask_num_neighbors, num_neighbors_image = get_max_neighbors_mask( - natoms=data.natoms, - index=index1, - atom_distance=atom_distance_sqr, - max_num_neighbors_threshold=max_num_neighbors_threshold, - enforce_max_strictly=enforce_max_neighbors_strictly, - ) - - if not torch.all(mask_num_neighbors): - # Mask out the atoms to ensure each atom has at most max_num_neighbors_threshold neighbors - index1 = torch.masked_select(index1, mask_num_neighbors) - index2 = torch.masked_select(index2, mask_num_neighbors) - unit_cell = torch.masked_select( - unit_cell.view(-1, 3), mask_num_neighbors.view(-1, 1).expand(-1, 3) - ) - unit_cell = unit_cell.view(-1, 3) - - edge_index = torch.stack((index2, index1)) - +""" +Copyright (c) 2025 Ma Zhaojia + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. + +CUDA-accelerated PBC graph operations for atomic systems. +""" +import torch +from typing import Optional, List +from .pbc_graph_legacy import get_max_neighbors_mask +from .extensions.cuda_ops import get_pbc_graph_cuda_extension + +def radius_graph_pbc_cuda( + data, + radius, + max_num_neighbors_threshold, + enforce_max_neighbors_strictly: bool = False, + pbc=None, + dtype=torch.float64, +): + """ + Memory-efficient CUDA-accelerated version of radius_graph_pbc. + + This implementation follows the memory-efficient approach with triple loops + but accelerates the distance computation using CUDA kernels. + """ + if pbc is None: + pbc = [True, True, True] + + device = data.pos.device + batch_size = len(data.natoms) + + # Handle PBC settings + if hasattr(data, "pbc"): + data.pbc = torch.atleast_2d(data.pbc) + for i in range(3): + if not torch.any(data.pbc[:, i]).item(): + pbc[i] = False + elif torch.all(data.pbc[:, i]).item(): + pbc[i] = True + else: + raise RuntimeError( + "Different structures in the batch have different PBC configurations." + ) + + # position of the atoms + atom_pos = data.pos + + # Before computing the pairwise distances between atoms, first create a list of atom indices to compare for the entire batch + num_atoms_per_image = data.natoms + num_atoms_per_image_sqr = (num_atoms_per_image**2).long() + + # index offset between images + index_offset = torch.cumsum(num_atoms_per_image, dim=0) - num_atoms_per_image + + index_offset_expand = torch.repeat_interleave(index_offset, num_atoms_per_image_sqr) + num_atoms_per_image_expand = torch.repeat_interleave( + num_atoms_per_image, num_atoms_per_image_sqr + ) + + # Compute atom pair indices + num_atom_pairs = torch.sum(num_atoms_per_image_sqr) + index_sqr_offset = ( + torch.cumsum(num_atoms_per_image_sqr, dim=0) - num_atoms_per_image_sqr + ) + index_sqr_offset = torch.repeat_interleave( + index_sqr_offset, num_atoms_per_image_sqr + ) + atom_count_sqr = torch.arange(num_atom_pairs, device=device) - index_sqr_offset + + # Compute the indices for the pairs of atoms (using division and mod) + index1 = ( + torch.div(atom_count_sqr, num_atoms_per_image_expand, rounding_mode="floor") + ) + index_offset_expand + index2 = (atom_count_sqr % num_atoms_per_image_expand) + index_offset_expand + # Get the positions for each atom + pos1 = torch.index_select(atom_pos, 0, index1) + pos2 = torch.index_select(atom_pos, 0, index2) + + # Calculate required number of unit cells in each direction for PBC + cross_a2a3 = torch.cross(data.cell[:, 1], data.cell[:, 2], dim=-1) + cell_vol = torch.sum(data.cell[:, 0] * cross_a2a3, dim=-1, keepdim=True) + + if pbc[0]: + inv_min_dist_a1 = torch.norm(cross_a2a3 / cell_vol, p=2, dim=-1) + rep_a1 = torch.ceil(radius * inv_min_dist_a1) + else: + rep_a1 = data.cell.new_zeros(1) + + if pbc[1]: + cross_a3a1 = torch.cross(data.cell[:, 2], data.cell[:, 0], dim=-1) + inv_min_dist_a2 = torch.norm(cross_a3a1 / cell_vol, p=2, dim=-1) + rep_a2 = torch.ceil(radius * inv_min_dist_a2) + else: + rep_a2 = data.cell.new_zeros(1) + + if pbc[2]: + cross_a1a2 = torch.cross(data.cell[:, 0], data.cell[:, 1], dim=-1) + inv_min_dist_a3 = torch.norm(cross_a1a2 / cell_vol, p=2, dim=-1) + rep_a3 = torch.ceil(radius * inv_min_dist_a3) + else: + rep_a3 = data.cell.new_zeros(1) + + # Take the max over all images for uniformity + max_rep = [int(2*rep_a1.max().item()), int(2*rep_a2.max().item()), int(2*rep_a3.max().item())] + + # Pre-transpose data_cell for efficiency + data_cell = torch.transpose(data.cell, 1, 2) + + # Use CUDA kernel for the triple loop computation + # try: + pbc_graph_cuda = get_pbc_graph_cuda_extension() + + # Call the CUDA implementation + valid_pair_indices, unit_cell, atom_distance_sqr = pbc_graph_cuda.pbc_distance_cuda( + pos1, pos2, data_cell, + num_atoms_per_image_sqr, batch_size, max_rep, float(radius), device + ) + + # Map back to original index1 and index2 + if len(valid_pair_indices) > 0: + index1 = index1.index_select(0, valid_pair_indices.long()) + index2 = index2.index_select(0, valid_pair_indices.long()) + else: + index1 = torch.empty(0, dtype=torch.long, device=device) + index2 = torch.empty(0, dtype=torch.long, device=device) + unit_cell = torch.empty(0, 3, dtype=dtype, device=device) + atom_distance_sqr = torch.empty(0, dtype=dtype, device=device) + + # Sort index1 in ascending order and rearrange other arrays correspondingly + if len(index1) > 0: + sort_indices = torch.argsort(index1) + index1 = index1[sort_indices] + index2 = index2[sort_indices] + unit_cell = unit_cell[sort_indices] + atom_distance_sqr = atom_distance_sqr[sort_indices] + + mask_num_neighbors, num_neighbors_image = get_max_neighbors_mask( + natoms=data.natoms, + index=index1, + atom_distance=atom_distance_sqr, + max_num_neighbors_threshold=max_num_neighbors_threshold, + enforce_max_strictly=enforce_max_neighbors_strictly, + ) + + if not torch.all(mask_num_neighbors): + # Mask out the atoms to ensure each atom has at most max_num_neighbors_threshold neighbors + index1 = torch.masked_select(index1, mask_num_neighbors) + index2 = torch.masked_select(index2, mask_num_neighbors) + unit_cell = torch.masked_select( + unit_cell.view(-1, 3), mask_num_neighbors.view(-1, 1).expand(-1, 3) + ) + unit_cell = unit_cell.view(-1, 3) + + edge_index = torch.stack((index2, index1)) + return edge_index, unit_cell, num_neighbors_image \ No newline at end of file diff --git a/mace-bench/src/batchopt/pbc_graph_legacy.py b/mace-bench/src/batchopt/pbc_graph_legacy.py index 484ed0e..ca75c53 100644 --- a/mace-bench/src/batchopt/pbc_graph_legacy.py +++ b/mace-bench/src/batchopt/pbc_graph_legacy.py @@ -1,563 +1,563 @@ -""" -Copyright (c) Meta, Inc. and its affiliates. - -This source code is licensed under the MIT license found in the -LICENSE file in the root directory of this source tree. -""" - -from __future__ import annotations - -from typing import TYPE_CHECKING, Any - -import numpy as np -import torch -import torch.nn as nn -import torch_geometric -from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas -from matplotlib.figure import Figure -from torch_geometric.data import Data -from torch_geometric.utils import remove_self_loops -from torch_scatter import scatter, segment_coo, segment_csr - - -if TYPE_CHECKING: - from collections.abc import Mapping - - from torch.nn.modules.module import _IncompatibleKeys - - -DEFAULT_ENV_VARS = { - # Expandable segments is a new cuda feature that helps with memory fragmentation during frequent allocations (ie: in the case of variable batch sizes). - # see https://pytorch.org/docs/stable/notes/cuda.html. - "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True", -} - - -def get_pbc_distances( - pos, - edge_index, - cell, - cell_offsets, - neighbors, - return_offsets: bool = False, - return_distance_vec: bool = False, -): - row, col = edge_index - - distance_vectors = pos[row] - pos[col] - - # correct for pbc - neighbors = neighbors.to(cell.device) - cell = torch.repeat_interleave(cell, neighbors, dim=0) - offsets = cell_offsets.float().view(-1, 1, 3).bmm(cell.float()).view(-1, 3) - distance_vectors += offsets - - # compute distances - distances = distance_vectors.norm(dim=-1) - - # redundancy: remove zero distances - nonzero_idx = torch.arange(len(distances), device=distances.device)[distances != 0] - edge_index = edge_index[:, nonzero_idx] - distances = distances[nonzero_idx] - - out = { - "edge_index": edge_index, - "distances": distances, - } - - if return_distance_vec: - out["distance_vec"] = distance_vectors[nonzero_idx] - - if return_offsets: - out["offsets"] = offsets[nonzero_idx] - - return out - -def radius_graph_pbc_mem_effi( - data, - radius, - max_num_neighbors_threshold, - enforce_max_neighbors_strictly: bool = False, - pbc=None, - dtype=torch.float64, -): - if pbc is None: - pbc = [True, True, True] - device = data.pos.device - batch_size = len(data.natoms) - - if hasattr(data, "pbc"): - data.pbc = torch.atleast_2d(data.pbc) - for i in range(3): - if not torch.any(data.pbc[:, i]).item(): - pbc[i] = False - elif torch.all(data.pbc[:, i]).item(): - pbc[i] = True - else: - raise RuntimeError( - "Different structures in the batch have different PBC configurations. This is not currently supported." - ) - - # position of the atoms - atom_pos = data.pos - - # Before computing the pairwise distances between atoms, first create a list of atom indices to compare for the entire batch - num_atoms_per_image = data.natoms - num_atoms_per_image_sqr = (num_atoms_per_image**2).long() - - # index offset between images - index_offset = torch.cumsum(num_atoms_per_image, dim=0) - num_atoms_per_image - - index_offset_expand = torch.repeat_interleave(index_offset, num_atoms_per_image_sqr) - num_atoms_per_image_expand = torch.repeat_interleave( - num_atoms_per_image, num_atoms_per_image_sqr - ) - - # Compute a tensor containing sequences of numbers that range from 0 to num_atoms_per_image_sqr for each image - # that is used to compute indices for the pairs of atoms. This is a very convoluted way to implement - # the following (but 10x faster since it removes the for loop) - # for batch_idx in range(batch_size): - # batch_count = torch.cat([batch_count, torch.arange(num_atoms_per_image_sqr[batch_idx], device=device)], dim=0) - num_atom_pairs = torch.sum(num_atoms_per_image_sqr) - index_sqr_offset = ( - torch.cumsum(num_atoms_per_image_sqr, dim=0) - num_atoms_per_image_sqr - ) - index_sqr_offset = torch.repeat_interleave( - index_sqr_offset, num_atoms_per_image_sqr - ) - atom_count_sqr = torch.arange(num_atom_pairs, device=device) - index_sqr_offset - - # Compute the indices for the pairs of atoms (using division and mod) - # If the systems get too large this apporach could run into numerical precision issues - index1 = ( - torch.div(atom_count_sqr, num_atoms_per_image_expand, rounding_mode="floor") - ) + index_offset_expand - index2 = (atom_count_sqr % num_atoms_per_image_expand) + index_offset_expand - # Get the positions for each atom - pos1 = torch.index_select(atom_pos, 0, index1) - pos2 = torch.index_select(atom_pos, 0, index2) - - # Calculate required number of unit cells in each direction. - # Smallest distance between planes separated by a1 is - # 1 / ||(a2 x a3) / V||_2, since a2 x a3 is the area of the plane. - # Note that the unit cell volume V = a1 * (a2 x a3) and that - # (a2 x a3) / V is also the reciprocal primitive vector - # (crystallographer's definition). - - cross_a2a3 = torch.cross(data.cell[:, 1], data.cell[:, 2], dim=-1) - cell_vol = torch.sum(data.cell[:, 0] * cross_a2a3, dim=-1, keepdim=True) - - if pbc[0]: - inv_min_dist_a1 = torch.norm(cross_a2a3 / cell_vol, p=2, dim=-1) - rep_a1 = torch.ceil(radius * inv_min_dist_a1) - else: - rep_a1 = data.cell.new_zeros(1) - - if pbc[1]: - cross_a3a1 = torch.cross(data.cell[:, 2], data.cell[:, 0], dim=-1) - inv_min_dist_a2 = torch.norm(cross_a3a1 / cell_vol, p=2, dim=-1) - rep_a2 = torch.ceil(radius * inv_min_dist_a2) - else: - rep_a2 = data.cell.new_zeros(1) - - if pbc[2]: - cross_a1a2 = torch.cross(data.cell[:, 0], data.cell[:, 1], dim=-1) - inv_min_dist_a3 = torch.norm(cross_a1a2 / cell_vol, p=2, dim=-1) - rep_a3 = torch.ceil(radius * inv_min_dist_a3) - else: - rep_a3 = data.cell.new_zeros(1) - - # Take the max over all images for uniformity. This is essentially padding. - # Note that this can significantly increase the number of computed distances - # if the required repetitions are very different between images - # (which they usually are). Changing this to sparse (scatter) operations - # might be worth the effort if this function becomes a bottleneck. - max_rep = [int(2*rep_a1.max().item()), int(2*rep_a2.max().item()), int(2*rep_a3.max().item())] - - # Memory-efficient implementation: iterate over unit cell offsets instead of expanding all at once - # This reduces memory usage by avoiding the creation of large tensor products - all_index1 = [] - all_index2 = [] - all_unit_cell = [] - all_atom_distance_sqr = [] - - # Pre-transpose data_cell for efficiency - data_cell = torch.transpose(data.cell, 1, 2) - - # Iterate over each unit cell offset combination - for i in range(-max_rep[0], max_rep[0] + 1): - for j in range(-max_rep[1], max_rep[1] + 1): - for k in range(-max_rep[2], max_rep[2] + 1): - # Create unit cell offset - unit_cell_offset = torch.tensor([i, j, k], device=device, dtype=dtype) - - # Compute the x, y, z positional offsets for this specific cell in each image - # unit_cell_offset_batch = unit_cell_offset.view(3, 1).expand(3, batch_size) - unit_cell_offset_batch = unit_cell_offset.view(1,3,1).expand(batch_size, -1, -1) - pbc_offsets = torch.bmm(data_cell, unit_cell_offset_batch).squeeze(-1) - pbc_offsets_per_atom = torch.repeat_interleave( - pbc_offsets, num_atoms_per_image_sqr, dim=0 - ) - - # Apply PBC offsets to the second atom positions - pos2_offset = pos2 + pbc_offsets_per_atom - - # Compute the squared distance between atoms - atom_distance_sqr = torch.sum((pos1 - pos2_offset) ** 2, dim=1) - - # Remove pairs that are too far apart - mask_within_radius = torch.le(atom_distance_sqr, radius * radius) - # Remove pairs with the same atoms (distance = 0.0) - mask_not_same = torch.gt(atom_distance_sqr, 0.0001) - mask = torch.logical_and(mask_within_radius, mask_not_same) - - # Only keep valid pairs for this unit cell offset - if torch.any(mask): - valid_index1 = torch.masked_select(index1, mask) - valid_index2 = torch.masked_select(index2, mask) - valid_distances = torch.masked_select(atom_distance_sqr, mask) - valid_unit_cell = unit_cell_offset.unsqueeze(0).repeat(valid_index1.shape[0], 1) - - all_index1.append(valid_index1) - all_index2.append(valid_index2) - all_unit_cell.append(valid_unit_cell) - all_atom_distance_sqr.append(valid_distances) - - # Concatenate all results - if len(all_index1) > 0: - index1 = torch.cat(all_index1) - index2 = torch.cat(all_index2) - unit_cell = torch.cat(all_unit_cell) - atom_distance_sqr = torch.cat(all_atom_distance_sqr) - - # Sort index1 in ascending order and rearrange other arrays correspondingly - sort_indices = torch.argsort(index1) - index1 = index1[sort_indices] - index2 = index2[sort_indices] - unit_cell = unit_cell[sort_indices] - atom_distance_sqr = atom_distance_sqr[sort_indices] - - else: - # No valid pairs found - index1 = torch.empty(0, dtype=torch.long, device=device) - index2 = torch.empty(0, dtype=torch.long, device=device) - unit_cell = torch.empty(0, 3, dtype=dtype, device=device) - atom_distance_sqr = torch.empty(0, dtype=dtype, device=device) - - mask_num_neighbors, num_neighbors_image = get_max_neighbors_mask( - natoms=data.natoms, - index=index1, - atom_distance=atom_distance_sqr, - max_num_neighbors_threshold=max_num_neighbors_threshold, - enforce_max_strictly=enforce_max_neighbors_strictly, - ) - - if not torch.all(mask_num_neighbors): - # Mask out the atoms to ensure each atom has at most max_num_neighbors_threshold neighbors - index1 = torch.masked_select(index1, mask_num_neighbors) - index2 = torch.masked_select(index2, mask_num_neighbors) - unit_cell = torch.masked_select( - unit_cell.view(-1, 3), mask_num_neighbors.view(-1, 1).expand(-1, 3) - ) - unit_cell = unit_cell.view(-1, 3) - - edge_index = torch.stack((index2, index1)) - - return edge_index, unit_cell, num_neighbors_image - - -def radius_graph_pbc( - data, - radius, - max_num_neighbors_threshold, - enforce_max_neighbors_strictly: bool = False, - pbc=None, - dtype=torch.float64, -): - if pbc is None: - pbc = [True, True, True] - device = data.pos.device - batch_size = len(data.natoms) - - if hasattr(data, "pbc"): - data.pbc = torch.atleast_2d(data.pbc) - for i in range(3): - if not torch.any(data.pbc[:, i]).item(): - pbc[i] = False - elif torch.all(data.pbc[:, i]).item(): - pbc[i] = True - else: - raise RuntimeError( - "Different structures in the batch have different PBC configurations. This is not currently supported." - ) - - # position of the atoms - atom_pos = data.pos - - # Before computing the pairwise distances between atoms, first create a list of atom indices to compare for the entire batch - num_atoms_per_image = data.natoms - num_atoms_per_image_sqr = (num_atoms_per_image**2).long() - - # index offset between images - index_offset = torch.cumsum(num_atoms_per_image, dim=0) - num_atoms_per_image - - index_offset_expand = torch.repeat_interleave(index_offset, num_atoms_per_image_sqr) - num_atoms_per_image_expand = torch.repeat_interleave( - num_atoms_per_image, num_atoms_per_image_sqr - ) - - # Compute a tensor containing sequences of numbers that range from 0 to num_atoms_per_image_sqr for each image - # that is used to compute indices for the pairs of atoms. This is a very convoluted way to implement - # the following (but 10x faster since it removes the for loop) - # for batch_idx in range(batch_size): - # batch_count = torch.cat([batch_count, torch.arange(num_atoms_per_image_sqr[batch_idx], device=device)], dim=0) - num_atom_pairs = torch.sum(num_atoms_per_image_sqr) - index_sqr_offset = ( - torch.cumsum(num_atoms_per_image_sqr, dim=0) - num_atoms_per_image_sqr - ) - index_sqr_offset = torch.repeat_interleave( - index_sqr_offset, num_atoms_per_image_sqr - ) - atom_count_sqr = torch.arange(num_atom_pairs, device=device) - index_sqr_offset - - # Compute the indices for the pairs of atoms (using division and mod) - # If the systems get too large this apporach could run into numerical precision issues - index1 = ( - torch.div(atom_count_sqr, num_atoms_per_image_expand, rounding_mode="floor") - ) + index_offset_expand - index2 = (atom_count_sqr % num_atoms_per_image_expand) + index_offset_expand - # Get the positions for each atom - pos1 = torch.index_select(atom_pos, 0, index1) - pos2 = torch.index_select(atom_pos, 0, index2) - - # Calculate required number of unit cells in each direction. - # Smallest distance between planes separated by a1 is - # 1 / ||(a2 x a3) / V||_2, since a2 x a3 is the area of the plane. - # Note that the unit cell volume V = a1 * (a2 x a3) and that - # (a2 x a3) / V is also the reciprocal primitive vector - # (crystallographer's definition). - - cross_a2a3 = torch.cross(data.cell[:, 1], data.cell[:, 2], dim=-1) - cell_vol = torch.sum(data.cell[:, 0] * cross_a2a3, dim=-1, keepdim=True) - - if pbc[0]: - inv_min_dist_a1 = torch.norm(cross_a2a3 / cell_vol, p=2, dim=-1) - rep_a1 = torch.ceil(radius * inv_min_dist_a1) - else: - rep_a1 = data.cell.new_zeros(1) - - if pbc[1]: - cross_a3a1 = torch.cross(data.cell[:, 2], data.cell[:, 0], dim=-1) - inv_min_dist_a2 = torch.norm(cross_a3a1 / cell_vol, p=2, dim=-1) - rep_a2 = torch.ceil(radius * inv_min_dist_a2) - else: - rep_a2 = data.cell.new_zeros(1) - - if pbc[2]: - cross_a1a2 = torch.cross(data.cell[:, 0], data.cell[:, 1], dim=-1) - inv_min_dist_a3 = torch.norm(cross_a1a2 / cell_vol, p=2, dim=-1) - rep_a3 = torch.ceil(radius * inv_min_dist_a3) - else: - rep_a3 = data.cell.new_zeros(1) - - # Take the max over all images for uniformity. This is essentially padding. - # Note that this can significantly increase the number of computed distances - # if the required repetitions are very different between images - # (which they usually are). Changing this to sparse (scatter) operations - # might be worth the effort if this function becomes a bottleneck. - max_rep = [2*rep_a1.max(), 2*rep_a2.max(), 2*rep_a3.max()] - # max_rep = [rep_a1.max(), rep_a2.max(), rep_a3.max()] - # max_rep = [torch.tensor(1, device=device)] * 3 - # logging.info(f"&&& max_rep: {max_rep}") - - # Tensor of unit cells - cells_per_dim = [ - torch.arange(-rep.item(), rep.item() + 1, device=device, dtype=dtype) - for rep in max_rep - ] - unit_cell = torch.cartesian_prod(*cells_per_dim) - num_cells = len(unit_cell) - unit_cell_per_atom = unit_cell.view(1, num_cells, 3).repeat(len(index2), 1, 1) - unit_cell = torch.transpose(unit_cell, 0, 1) - unit_cell_batch = unit_cell.view(1, 3, num_cells).expand(batch_size, -1, -1) - - # Compute the x, y, z positional offsets for each cell in each image - # data_cell = torch.transpose(data.cell, 1, 2) - data_cell = torch.transpose(data.cell, 1, 2) - pbc_offsets = torch.bmm(data_cell, unit_cell_batch) - pbc_offsets_per_atom = torch.repeat_interleave( - pbc_offsets, num_atoms_per_image_sqr, dim=0 - ) - - # Expand the positions and indices for the 9 cells - pos1 = pos1.view(-1, 3, 1).expand(-1, -1, num_cells) - pos2 = pos2.view(-1, 3, 1).expand(-1, -1, num_cells) - index1 = index1.view(-1, 1).repeat(1, num_cells).view(-1) - index2 = index2.view(-1, 1).repeat(1, num_cells).view(-1) - # Add the PBC offsets for the second atom - pos2 = pos2 + pbc_offsets_per_atom - - # Compute the squared distance between atoms - atom_distance_sqr = torch.sum((pos1 - pos2) ** 2, dim=1) - atom_distance_sqr = atom_distance_sqr.view(-1) - - # Remove pairs that are too far apart - mask_within_radius = torch.le(atom_distance_sqr, radius * radius) - # Remove pairs with the same atoms (distance = 0.0) - mask_not_same = torch.gt(atom_distance_sqr, 0.0001) - mask = torch.logical_and(mask_within_radius, mask_not_same) - index1 = torch.masked_select(index1, mask) - index2 = torch.masked_select(index2, mask) - unit_cell = torch.masked_select( - unit_cell_per_atom.view(-1, 3), mask.view(-1, 1).expand(-1, 3) - ) - unit_cell = unit_cell.view(-1, 3) - atom_distance_sqr = torch.masked_select(atom_distance_sqr, mask) - - mask_num_neighbors, num_neighbors_image = get_max_neighbors_mask( - natoms=data.natoms, - index=index1, - atom_distance=atom_distance_sqr, - max_num_neighbors_threshold=max_num_neighbors_threshold, - enforce_max_strictly=enforce_max_neighbors_strictly, - ) - - if not torch.all(mask_num_neighbors): - # Mask out the atoms to ensure each atom has at most max_num_neighbors_threshold neighbors - index1 = torch.masked_select(index1, mask_num_neighbors) - index2 = torch.masked_select(index2, mask_num_neighbors) - unit_cell = torch.masked_select( - unit_cell.view(-1, 3), mask_num_neighbors.view(-1, 1).expand(-1, 3) - ) - unit_cell = unit_cell.view(-1, 3) - - edge_index = torch.stack((index2, index1)) - - return edge_index, unit_cell, num_neighbors_image - - -@torch.compiler.disable -def get_max_neighbors_mask( - natoms, - index, - atom_distance, - max_num_neighbors_threshold, - degeneracy_tolerance: float = 0.01, - enforce_max_strictly: bool = False, -): - """ - Give a mask that filters out edges so that each atom has at most - `max_num_neighbors_threshold` neighbors. - Assumes that `index` is sorted. - - Enforcing the max strictly can force the arbitrary choice between - degenerate edges. This can lead to undesired behaviors; for - example, bulk formation energies which are not invariant to - unit cell choice. - - A degeneracy tolerance can help prevent sudden changes in edge - existence from small changes in atom position, for example, - rounding errors, slab relaxation, temperature, etc. - """ - - device = natoms.device - num_atoms = natoms.sum() - - # Get number of neighbors - # segment_coo assumes sorted index - ones = index.new_ones(1).expand_as(index) - num_neighbors = segment_coo(ones, index, dim_size=num_atoms) - max_num_neighbors = num_neighbors.max() - num_neighbors_thresholded = num_neighbors.clamp(max=max_num_neighbors_threshold) - - # Get number of (thresholded) neighbors per image - image_indptr = torch.zeros(natoms.shape[0] + 1, device=device, dtype=torch.long) - image_indptr[1:] = torch.cumsum(natoms, dim=0) - num_neighbors_image = segment_csr(num_neighbors_thresholded, image_indptr) - - # If max_num_neighbors is below the threshold, return early - if ( - max_num_neighbors <= max_num_neighbors_threshold - or max_num_neighbors_threshold <= 0 - ): - mask_num_neighbors = torch.tensor([True], dtype=bool, device=device).expand_as( - index - ) - return mask_num_neighbors, num_neighbors_image - - # Create a tensor of size [num_atoms, max_num_neighbors] to sort the distances of the neighbors. - # Fill with infinity so we can easily remove unused distances later. - distance_sort = torch.full([num_atoms * max_num_neighbors], np.inf, device=device) - - # Create an index map to map distances from atom_distance to distance_sort - # index_sort_map assumes index to be sorted - index_neighbor_offset = torch.cumsum(num_neighbors, dim=0) - num_neighbors - index_neighbor_offset_expand = torch.repeat_interleave( - index_neighbor_offset, num_neighbors - ) - index_sort_map = ( - index * max_num_neighbors - + torch.arange(len(index), device=device) - - index_neighbor_offset_expand - ) - distance_sort.index_copy_(0, index_sort_map, atom_distance) - distance_sort = distance_sort.view(num_atoms, max_num_neighbors) - - # Sort neighboring atoms based on distance - distance_sort, index_sort = torch.sort(distance_sort, dim=1) - - # Select the max_num_neighbors_threshold neighbors that are closest - if enforce_max_strictly: - distance_sort = distance_sort[:, :max_num_neighbors_threshold] - index_sort = index_sort[:, :max_num_neighbors_threshold] - max_num_included = max_num_neighbors_threshold - - else: - effective_cutoff = ( - distance_sort[:, max_num_neighbors_threshold] + degeneracy_tolerance - ) - is_included = torch.le(distance_sort.T, effective_cutoff) - - # Set all undesired edges to infinite length to be removed later - distance_sort[~is_included.T] = np.inf - - # Subselect tensors for efficiency - num_included_per_atom = torch.sum(is_included, dim=0) - max_num_included = torch.max(num_included_per_atom) - distance_sort = distance_sort[:, :max_num_included] - index_sort = index_sort[:, :max_num_included] - - # Recompute the number of neighbors - num_neighbors_thresholded = num_neighbors.clamp(max=num_included_per_atom) - - num_neighbors_image = segment_csr(num_neighbors_thresholded, image_indptr) - - # Offset index_sort so that it indexes into index - index_sort = index_sort + index_neighbor_offset.view(-1, 1).expand( - -1, max_num_included - ) - # Remove "unused pairs" with infinite distances - mask_finite = torch.isfinite(distance_sort) - index_sort = torch.masked_select(index_sort, mask_finite) - - # At this point index_sort contains the index into index of the - # closest max_num_neighbors_threshold neighbors per atom - # Create a mask to remove all pairs not in index_sort - mask_num_neighbors = torch.zeros(len(index), device=device, dtype=bool) - mask_num_neighbors.index_fill_(0, index_sort, True) - - return mask_num_neighbors, num_neighbors_image - - -def get_pruned_edge_idx( - edge_index, num_atoms: int, max_neigh: float = 1e9 -) -> torch.Tensor: - assert num_atoms is not None # TODO: Shouldn't be necessary - - # removes neighbors > max_neigh - # assumes neighbors are sorted in increasing distance - _nonmax_idx_list = [] - for i in range(num_atoms): - idx_i = torch.arange(len(edge_index[1]))[(edge_index[1] == i)][:max_neigh] - _nonmax_idx_list.append(idx_i) - return torch.cat(_nonmax_idx_list) +""" +Copyright (c) Meta, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import numpy as np +import torch +import torch.nn as nn +import torch_geometric +from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas +from matplotlib.figure import Figure +from torch_geometric.data import Data +from torch_geometric.utils import remove_self_loops +from torch_scatter import scatter, segment_coo, segment_csr + + +if TYPE_CHECKING: + from collections.abc import Mapping + + from torch.nn.modules.module import _IncompatibleKeys + + +DEFAULT_ENV_VARS = { + # Expandable segments is a new cuda feature that helps with memory fragmentation during frequent allocations (ie: in the case of variable batch sizes). + # see https://pytorch.org/docs/stable/notes/cuda.html. + "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True", +} + + +def get_pbc_distances( + pos, + edge_index, + cell, + cell_offsets, + neighbors, + return_offsets: bool = False, + return_distance_vec: bool = False, +): + row, col = edge_index + + distance_vectors = pos[row] - pos[col] + + # correct for pbc + neighbors = neighbors.to(cell.device) + cell = torch.repeat_interleave(cell, neighbors, dim=0) + offsets = cell_offsets.float().view(-1, 1, 3).bmm(cell.float()).view(-1, 3) + distance_vectors += offsets + + # compute distances + distances = distance_vectors.norm(dim=-1) + + # redundancy: remove zero distances + nonzero_idx = torch.arange(len(distances), device=distances.device)[distances != 0] + edge_index = edge_index[:, nonzero_idx] + distances = distances[nonzero_idx] + + out = { + "edge_index": edge_index, + "distances": distances, + } + + if return_distance_vec: + out["distance_vec"] = distance_vectors[nonzero_idx] + + if return_offsets: + out["offsets"] = offsets[nonzero_idx] + + return out + +def radius_graph_pbc_mem_effi( + data, + radius, + max_num_neighbors_threshold, + enforce_max_neighbors_strictly: bool = False, + pbc=None, + dtype=torch.float64, +): + if pbc is None: + pbc = [True, True, True] + device = data.pos.device + batch_size = len(data.natoms) + + if hasattr(data, "pbc"): + data.pbc = torch.atleast_2d(data.pbc) + for i in range(3): + if not torch.any(data.pbc[:, i]).item(): + pbc[i] = False + elif torch.all(data.pbc[:, i]).item(): + pbc[i] = True + else: + raise RuntimeError( + "Different structures in the batch have different PBC configurations. This is not currently supported." + ) + + # position of the atoms + atom_pos = data.pos + + # Before computing the pairwise distances between atoms, first create a list of atom indices to compare for the entire batch + num_atoms_per_image = data.natoms + num_atoms_per_image_sqr = (num_atoms_per_image**2).long() + + # index offset between images + index_offset = torch.cumsum(num_atoms_per_image, dim=0) - num_atoms_per_image + + index_offset_expand = torch.repeat_interleave(index_offset, num_atoms_per_image_sqr) + num_atoms_per_image_expand = torch.repeat_interleave( + num_atoms_per_image, num_atoms_per_image_sqr + ) + + # Compute a tensor containing sequences of numbers that range from 0 to num_atoms_per_image_sqr for each image + # that is used to compute indices for the pairs of atoms. This is a very convoluted way to implement + # the following (but 10x faster since it removes the for loop) + # for batch_idx in range(batch_size): + # batch_count = torch.cat([batch_count, torch.arange(num_atoms_per_image_sqr[batch_idx], device=device)], dim=0) + num_atom_pairs = torch.sum(num_atoms_per_image_sqr) + index_sqr_offset = ( + torch.cumsum(num_atoms_per_image_sqr, dim=0) - num_atoms_per_image_sqr + ) + index_sqr_offset = torch.repeat_interleave( + index_sqr_offset, num_atoms_per_image_sqr + ) + atom_count_sqr = torch.arange(num_atom_pairs, device=device) - index_sqr_offset + + # Compute the indices for the pairs of atoms (using division and mod) + # If the systems get too large this apporach could run into numerical precision issues + index1 = ( + torch.div(atom_count_sqr, num_atoms_per_image_expand, rounding_mode="floor") + ) + index_offset_expand + index2 = (atom_count_sqr % num_atoms_per_image_expand) + index_offset_expand + # Get the positions for each atom + pos1 = torch.index_select(atom_pos, 0, index1) + pos2 = torch.index_select(atom_pos, 0, index2) + + # Calculate required number of unit cells in each direction. + # Smallest distance between planes separated by a1 is + # 1 / ||(a2 x a3) / V||_2, since a2 x a3 is the area of the plane. + # Note that the unit cell volume V = a1 * (a2 x a3) and that + # (a2 x a3) / V is also the reciprocal primitive vector + # (crystallographer's definition). + + cross_a2a3 = torch.cross(data.cell[:, 1], data.cell[:, 2], dim=-1) + cell_vol = torch.sum(data.cell[:, 0] * cross_a2a3, dim=-1, keepdim=True) + + if pbc[0]: + inv_min_dist_a1 = torch.norm(cross_a2a3 / cell_vol, p=2, dim=-1) + rep_a1 = torch.ceil(radius * inv_min_dist_a1) + else: + rep_a1 = data.cell.new_zeros(1) + + if pbc[1]: + cross_a3a1 = torch.cross(data.cell[:, 2], data.cell[:, 0], dim=-1) + inv_min_dist_a2 = torch.norm(cross_a3a1 / cell_vol, p=2, dim=-1) + rep_a2 = torch.ceil(radius * inv_min_dist_a2) + else: + rep_a2 = data.cell.new_zeros(1) + + if pbc[2]: + cross_a1a2 = torch.cross(data.cell[:, 0], data.cell[:, 1], dim=-1) + inv_min_dist_a3 = torch.norm(cross_a1a2 / cell_vol, p=2, dim=-1) + rep_a3 = torch.ceil(radius * inv_min_dist_a3) + else: + rep_a3 = data.cell.new_zeros(1) + + # Take the max over all images for uniformity. This is essentially padding. + # Note that this can significantly increase the number of computed distances + # if the required repetitions are very different between images + # (which they usually are). Changing this to sparse (scatter) operations + # might be worth the effort if this function becomes a bottleneck. + max_rep = [int(2*rep_a1.max().item()), int(2*rep_a2.max().item()), int(2*rep_a3.max().item())] + + # Memory-efficient implementation: iterate over unit cell offsets instead of expanding all at once + # This reduces memory usage by avoiding the creation of large tensor products + all_index1 = [] + all_index2 = [] + all_unit_cell = [] + all_atom_distance_sqr = [] + + # Pre-transpose data_cell for efficiency + data_cell = torch.transpose(data.cell, 1, 2) + + # Iterate over each unit cell offset combination + for i in range(-max_rep[0], max_rep[0] + 1): + for j in range(-max_rep[1], max_rep[1] + 1): + for k in range(-max_rep[2], max_rep[2] + 1): + # Create unit cell offset + unit_cell_offset = torch.tensor([i, j, k], device=device, dtype=dtype) + + # Compute the x, y, z positional offsets for this specific cell in each image + # unit_cell_offset_batch = unit_cell_offset.view(3, 1).expand(3, batch_size) + unit_cell_offset_batch = unit_cell_offset.view(1,3,1).expand(batch_size, -1, -1) + pbc_offsets = torch.bmm(data_cell, unit_cell_offset_batch).squeeze(-1) + pbc_offsets_per_atom = torch.repeat_interleave( + pbc_offsets, num_atoms_per_image_sqr, dim=0 + ) + + # Apply PBC offsets to the second atom positions + pos2_offset = pos2 + pbc_offsets_per_atom + + # Compute the squared distance between atoms + atom_distance_sqr = torch.sum((pos1 - pos2_offset) ** 2, dim=1) + + # Remove pairs that are too far apart + mask_within_radius = torch.le(atom_distance_sqr, radius * radius) + # Remove pairs with the same atoms (distance = 0.0) + mask_not_same = torch.gt(atom_distance_sqr, 0.0001) + mask = torch.logical_and(mask_within_radius, mask_not_same) + + # Only keep valid pairs for this unit cell offset + if torch.any(mask): + valid_index1 = torch.masked_select(index1, mask) + valid_index2 = torch.masked_select(index2, mask) + valid_distances = torch.masked_select(atom_distance_sqr, mask) + valid_unit_cell = unit_cell_offset.unsqueeze(0).repeat(valid_index1.shape[0], 1) + + all_index1.append(valid_index1) + all_index2.append(valid_index2) + all_unit_cell.append(valid_unit_cell) + all_atom_distance_sqr.append(valid_distances) + + # Concatenate all results + if len(all_index1) > 0: + index1 = torch.cat(all_index1) + index2 = torch.cat(all_index2) + unit_cell = torch.cat(all_unit_cell) + atom_distance_sqr = torch.cat(all_atom_distance_sqr) + + # Sort index1 in ascending order and rearrange other arrays correspondingly + sort_indices = torch.argsort(index1) + index1 = index1[sort_indices] + index2 = index2[sort_indices] + unit_cell = unit_cell[sort_indices] + atom_distance_sqr = atom_distance_sqr[sort_indices] + + else: + # No valid pairs found + index1 = torch.empty(0, dtype=torch.long, device=device) + index2 = torch.empty(0, dtype=torch.long, device=device) + unit_cell = torch.empty(0, 3, dtype=dtype, device=device) + atom_distance_sqr = torch.empty(0, dtype=dtype, device=device) + + mask_num_neighbors, num_neighbors_image = get_max_neighbors_mask( + natoms=data.natoms, + index=index1, + atom_distance=atom_distance_sqr, + max_num_neighbors_threshold=max_num_neighbors_threshold, + enforce_max_strictly=enforce_max_neighbors_strictly, + ) + + if not torch.all(mask_num_neighbors): + # Mask out the atoms to ensure each atom has at most max_num_neighbors_threshold neighbors + index1 = torch.masked_select(index1, mask_num_neighbors) + index2 = torch.masked_select(index2, mask_num_neighbors) + unit_cell = torch.masked_select( + unit_cell.view(-1, 3), mask_num_neighbors.view(-1, 1).expand(-1, 3) + ) + unit_cell = unit_cell.view(-1, 3) + + edge_index = torch.stack((index2, index1)) + + return edge_index, unit_cell, num_neighbors_image + + +def radius_graph_pbc( + data, + radius, + max_num_neighbors_threshold, + enforce_max_neighbors_strictly: bool = False, + pbc=None, + dtype=torch.float64, +): + if pbc is None: + pbc = [True, True, True] + device = data.pos.device + batch_size = len(data.natoms) + + if hasattr(data, "pbc"): + data.pbc = torch.atleast_2d(data.pbc) + for i in range(3): + if not torch.any(data.pbc[:, i]).item(): + pbc[i] = False + elif torch.all(data.pbc[:, i]).item(): + pbc[i] = True + else: + raise RuntimeError( + "Different structures in the batch have different PBC configurations. This is not currently supported." + ) + + # position of the atoms + atom_pos = data.pos + + # Before computing the pairwise distances between atoms, first create a list of atom indices to compare for the entire batch + num_atoms_per_image = data.natoms + num_atoms_per_image_sqr = (num_atoms_per_image**2).long() + + # index offset between images + index_offset = torch.cumsum(num_atoms_per_image, dim=0) - num_atoms_per_image + + index_offset_expand = torch.repeat_interleave(index_offset, num_atoms_per_image_sqr) + num_atoms_per_image_expand = torch.repeat_interleave( + num_atoms_per_image, num_atoms_per_image_sqr + ) + + # Compute a tensor containing sequences of numbers that range from 0 to num_atoms_per_image_sqr for each image + # that is used to compute indices for the pairs of atoms. This is a very convoluted way to implement + # the following (but 10x faster since it removes the for loop) + # for batch_idx in range(batch_size): + # batch_count = torch.cat([batch_count, torch.arange(num_atoms_per_image_sqr[batch_idx], device=device)], dim=0) + num_atom_pairs = torch.sum(num_atoms_per_image_sqr) + index_sqr_offset = ( + torch.cumsum(num_atoms_per_image_sqr, dim=0) - num_atoms_per_image_sqr + ) + index_sqr_offset = torch.repeat_interleave( + index_sqr_offset, num_atoms_per_image_sqr + ) + atom_count_sqr = torch.arange(num_atom_pairs, device=device) - index_sqr_offset + + # Compute the indices for the pairs of atoms (using division and mod) + # If the systems get too large this apporach could run into numerical precision issues + index1 = ( + torch.div(atom_count_sqr, num_atoms_per_image_expand, rounding_mode="floor") + ) + index_offset_expand + index2 = (atom_count_sqr % num_atoms_per_image_expand) + index_offset_expand + # Get the positions for each atom + pos1 = torch.index_select(atom_pos, 0, index1) + pos2 = torch.index_select(atom_pos, 0, index2) + + # Calculate required number of unit cells in each direction. + # Smallest distance between planes separated by a1 is + # 1 / ||(a2 x a3) / V||_2, since a2 x a3 is the area of the plane. + # Note that the unit cell volume V = a1 * (a2 x a3) and that + # (a2 x a3) / V is also the reciprocal primitive vector + # (crystallographer's definition). + + cross_a2a3 = torch.cross(data.cell[:, 1], data.cell[:, 2], dim=-1) + cell_vol = torch.sum(data.cell[:, 0] * cross_a2a3, dim=-1, keepdim=True) + + if pbc[0]: + inv_min_dist_a1 = torch.norm(cross_a2a3 / cell_vol, p=2, dim=-1) + rep_a1 = torch.ceil(radius * inv_min_dist_a1) + else: + rep_a1 = data.cell.new_zeros(1) + + if pbc[1]: + cross_a3a1 = torch.cross(data.cell[:, 2], data.cell[:, 0], dim=-1) + inv_min_dist_a2 = torch.norm(cross_a3a1 / cell_vol, p=2, dim=-1) + rep_a2 = torch.ceil(radius * inv_min_dist_a2) + else: + rep_a2 = data.cell.new_zeros(1) + + if pbc[2]: + cross_a1a2 = torch.cross(data.cell[:, 0], data.cell[:, 1], dim=-1) + inv_min_dist_a3 = torch.norm(cross_a1a2 / cell_vol, p=2, dim=-1) + rep_a3 = torch.ceil(radius * inv_min_dist_a3) + else: + rep_a3 = data.cell.new_zeros(1) + + # Take the max over all images for uniformity. This is essentially padding. + # Note that this can significantly increase the number of computed distances + # if the required repetitions are very different between images + # (which they usually are). Changing this to sparse (scatter) operations + # might be worth the effort if this function becomes a bottleneck. + max_rep = [2*rep_a1.max(), 2*rep_a2.max(), 2*rep_a3.max()] + # max_rep = [rep_a1.max(), rep_a2.max(), rep_a3.max()] + # max_rep = [torch.tensor(1, device=device)] * 3 + # logging.info(f"&&& max_rep: {max_rep}") + + # Tensor of unit cells + cells_per_dim = [ + torch.arange(-rep.item(), rep.item() + 1, device=device, dtype=dtype) + for rep in max_rep + ] + unit_cell = torch.cartesian_prod(*cells_per_dim) + num_cells = len(unit_cell) + unit_cell_per_atom = unit_cell.view(1, num_cells, 3).repeat(len(index2), 1, 1) + unit_cell = torch.transpose(unit_cell, 0, 1) + unit_cell_batch = unit_cell.view(1, 3, num_cells).expand(batch_size, -1, -1) + + # Compute the x, y, z positional offsets for each cell in each image + # data_cell = torch.transpose(data.cell, 1, 2) + data_cell = torch.transpose(data.cell, 1, 2) + pbc_offsets = torch.bmm(data_cell, unit_cell_batch) + pbc_offsets_per_atom = torch.repeat_interleave( + pbc_offsets, num_atoms_per_image_sqr, dim=0 + ) + + # Expand the positions and indices for the 9 cells + pos1 = pos1.view(-1, 3, 1).expand(-1, -1, num_cells) + pos2 = pos2.view(-1, 3, 1).expand(-1, -1, num_cells) + index1 = index1.view(-1, 1).repeat(1, num_cells).view(-1) + index2 = index2.view(-1, 1).repeat(1, num_cells).view(-1) + # Add the PBC offsets for the second atom + pos2 = pos2 + pbc_offsets_per_atom + + # Compute the squared distance between atoms + atom_distance_sqr = torch.sum((pos1 - pos2) ** 2, dim=1) + atom_distance_sqr = atom_distance_sqr.view(-1) + + # Remove pairs that are too far apart + mask_within_radius = torch.le(atom_distance_sqr, radius * radius) + # Remove pairs with the same atoms (distance = 0.0) + mask_not_same = torch.gt(atom_distance_sqr, 0.0001) + mask = torch.logical_and(mask_within_radius, mask_not_same) + index1 = torch.masked_select(index1, mask) + index2 = torch.masked_select(index2, mask) + unit_cell = torch.masked_select( + unit_cell_per_atom.view(-1, 3), mask.view(-1, 1).expand(-1, 3) + ) + unit_cell = unit_cell.view(-1, 3) + atom_distance_sqr = torch.masked_select(atom_distance_sqr, mask) + + mask_num_neighbors, num_neighbors_image = get_max_neighbors_mask( + natoms=data.natoms, + index=index1, + atom_distance=atom_distance_sqr, + max_num_neighbors_threshold=max_num_neighbors_threshold, + enforce_max_strictly=enforce_max_neighbors_strictly, + ) + + if not torch.all(mask_num_neighbors): + # Mask out the atoms to ensure each atom has at most max_num_neighbors_threshold neighbors + index1 = torch.masked_select(index1, mask_num_neighbors) + index2 = torch.masked_select(index2, mask_num_neighbors) + unit_cell = torch.masked_select( + unit_cell.view(-1, 3), mask_num_neighbors.view(-1, 1).expand(-1, 3) + ) + unit_cell = unit_cell.view(-1, 3) + + edge_index = torch.stack((index2, index1)) + + return edge_index, unit_cell, num_neighbors_image + + +@torch.compiler.disable +def get_max_neighbors_mask( + natoms, + index, + atom_distance, + max_num_neighbors_threshold, + degeneracy_tolerance: float = 0.01, + enforce_max_strictly: bool = False, +): + """ + Give a mask that filters out edges so that each atom has at most + `max_num_neighbors_threshold` neighbors. + Assumes that `index` is sorted. + + Enforcing the max strictly can force the arbitrary choice between + degenerate edges. This can lead to undesired behaviors; for + example, bulk formation energies which are not invariant to + unit cell choice. + + A degeneracy tolerance can help prevent sudden changes in edge + existence from small changes in atom position, for example, + rounding errors, slab relaxation, temperature, etc. + """ + + device = natoms.device + num_atoms = natoms.sum() + + # Get number of neighbors + # segment_coo assumes sorted index + ones = index.new_ones(1).expand_as(index) + num_neighbors = segment_coo(ones, index, dim_size=num_atoms) + max_num_neighbors = num_neighbors.max() + num_neighbors_thresholded = num_neighbors.clamp(max=max_num_neighbors_threshold) + + # Get number of (thresholded) neighbors per image + image_indptr = torch.zeros(natoms.shape[0] + 1, device=device, dtype=torch.long) + image_indptr[1:] = torch.cumsum(natoms, dim=0) + num_neighbors_image = segment_csr(num_neighbors_thresholded, image_indptr) + + # If max_num_neighbors is below the threshold, return early + if ( + max_num_neighbors <= max_num_neighbors_threshold + or max_num_neighbors_threshold <= 0 + ): + mask_num_neighbors = torch.tensor([True], dtype=bool, device=device).expand_as( + index + ) + return mask_num_neighbors, num_neighbors_image + + # Create a tensor of size [num_atoms, max_num_neighbors] to sort the distances of the neighbors. + # Fill with infinity so we can easily remove unused distances later. + distance_sort = torch.full([num_atoms * max_num_neighbors], np.inf, device=device) + + # Create an index map to map distances from atom_distance to distance_sort + # index_sort_map assumes index to be sorted + index_neighbor_offset = torch.cumsum(num_neighbors, dim=0) - num_neighbors + index_neighbor_offset_expand = torch.repeat_interleave( + index_neighbor_offset, num_neighbors + ) + index_sort_map = ( + index * max_num_neighbors + + torch.arange(len(index), device=device) + - index_neighbor_offset_expand + ) + distance_sort.index_copy_(0, index_sort_map, atom_distance) + distance_sort = distance_sort.view(num_atoms, max_num_neighbors) + + # Sort neighboring atoms based on distance + distance_sort, index_sort = torch.sort(distance_sort, dim=1) + + # Select the max_num_neighbors_threshold neighbors that are closest + if enforce_max_strictly: + distance_sort = distance_sort[:, :max_num_neighbors_threshold] + index_sort = index_sort[:, :max_num_neighbors_threshold] + max_num_included = max_num_neighbors_threshold + + else: + effective_cutoff = ( + distance_sort[:, max_num_neighbors_threshold] + degeneracy_tolerance + ) + is_included = torch.le(distance_sort.T, effective_cutoff) + + # Set all undesired edges to infinite length to be removed later + distance_sort[~is_included.T] = np.inf + + # Subselect tensors for efficiency + num_included_per_atom = torch.sum(is_included, dim=0) + max_num_included = torch.max(num_included_per_atom) + distance_sort = distance_sort[:, :max_num_included] + index_sort = index_sort[:, :max_num_included] + + # Recompute the number of neighbors + num_neighbors_thresholded = num_neighbors.clamp(max=num_included_per_atom) + + num_neighbors_image = segment_csr(num_neighbors_thresholded, image_indptr) + + # Offset index_sort so that it indexes into index + index_sort = index_sort + index_neighbor_offset.view(-1, 1).expand( + -1, max_num_included + ) + # Remove "unused pairs" with infinite distances + mask_finite = torch.isfinite(distance_sort) + index_sort = torch.masked_select(index_sort, mask_finite) + + # At this point index_sort contains the index into index of the + # closest max_num_neighbors_threshold neighbors per atom + # Create a mask to remove all pairs not in index_sort + mask_num_neighbors = torch.zeros(len(index), device=device, dtype=bool) + mask_num_neighbors.index_fill_(0, index_sort, True) + + return mask_num_neighbors, num_neighbors_image + + +def get_pruned_edge_idx( + edge_index, num_atoms: int, max_neigh: float = 1e9 +) -> torch.Tensor: + assert num_atoms is not None # TODO: Shouldn't be necessary + + # removes neighbors > max_neigh + # assumes neighbors are sorted in increasing distance + _nonmax_idx_list = [] + for i in range(num_atoms): + idx_i = torch.arange(len(edge_index[1]))[(edge_index[1] == i)][:max_neigh] + _nonmax_idx_list.append(idx_i) + return torch.cat(_nonmax_idx_list) diff --git a/mace-bench/src/batchopt/relaxation/__init__.py b/mace-bench/src/batchopt/relaxation/__init__.py index 05c842e..b2facb4 100644 --- a/mace-bench/src/batchopt/relaxation/__init__.py +++ b/mace-bench/src/batchopt/relaxation/__init__.py @@ -1,11 +1,11 @@ -""" -Copyright (c) Meta, Inc. and its affiliates. - -This source code is licensed under the MIT license found in the -LICENSE file in the root directory of this source tree. -""" - -from __future__ import annotations -from .optimizable import OptimizableBatch, OptimizableUnitCellBatch - -__all__ = ["ml_relax", "OptimizableBatch", "OptimizableUnitCellBatch"] +""" +Copyright (c) Meta, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +from __future__ import annotations +from .optimizable import OptimizableBatch, OptimizableUnitCellBatch + +__all__ = ["ml_relax", "OptimizableBatch", "OptimizableUnitCellBatch"] diff --git a/mace-bench/src/batchopt/relaxation/__pycache__/__init__.cpython-310.pyc b/mace-bench/src/batchopt/relaxation/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 97873291234c625f16591881a2a0d01c6b6f37cc..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 542 zcmYjOv2NQi5G7?fHk1U&*zv{)NUa|zilRma15H}MX$K7@Xz~;?ph$tFfZCsvrO42+ zKclr%ej!sS6@_!d;qD#Z-MynbKQBnKU%#K!cS^{gF?nrSN zj1oDMxtXkBs^8vcsu6wm))4 z0yl=SC*sHNncT~_O1T!RmLC*%51OY);M>l-ha|4*_E9;@*4r=t@EuyAs^CSn8`n;A ls%f)HUd5PBt24Ih_QUdB>tt{6{U|vslgr83;)0%Ke*t9jp}_zE diff --git a/mace-bench/src/batchopt/relaxation/__pycache__/ase_utils.cpython-310.pyc b/mace-bench/src/batchopt/relaxation/__pycache__/ase_utils.cpython-310.pyc deleted file mode 100644 index 00c49d5df97e8cdf23fbb1800cb92170912d122d..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 2889 zcmai0&5s*N74PaV+dZ~tlAVOj2dLNugi+#IEfRtdg$a8kI#!>r}NSQq>0&zne>=ldiCm6y|3SU z)x6aTEO`F%x96k(0QFCty!yB>dC#(lJOm=K`c^^`J0rQ>w+Yft=HzbQHME;~x!?B< z?Paxo4c7gvo;Ugp!>eUMzSLhbbUkb4t$vGG%op#1TqA4e%l&1TyGa1=(x2=nq`%Ve z8qG#4xtv@{-b^k%_4?-$Z^!CfjX=8akE!Sy-7MlRjf#Y(O44XB zNV7CjOm>5yH%cXyd@N!{W1cVy(=3fyAz4DlMZyGCBS!CS_D(<1A7sFM^F?KNY;lSpg2;-r48`FECXdqL-qg;L1oz zn6BK5$}%m6+kziVdXtg?uhskY_C~mVcVqqi&8?qq0mnlPD1KoFWHYwS!Ls!zi%}mJ zuu%U=dT=7-{WwzbsN=jig6D+;dIchuKl|gq z|MB>i_I7c*sKXXr1MzXDnX4?Z=(JfvuU4JB^GAy~&?k?w< zc33I3H^_LTv@e(%i{j-{mP+-ZvGPHTXgztBfoQ-7rV{xMz~#`|v-U_rE?H9->6tyZ zKDL!J! zwPveTdrpW|1+xoP3#pIpQ~^}kQuKES+d41ynNT#MNu(lb@`x~wB%>zmK|tZTEr#-j zp>SrTZoIPNbmchrum-sYR7X%Y2J@w0$bwp2h|mRwpkzjo0+B4EyNpIj0urbmD&v2k zwo5k$wBW~Q$WygwR3sQ3ru(cwf%p{YUp^b#0F*#3Dc{`#L__z${rH?9@`j-(U=!+^ z7ZN;e$;$%A!^Mezm}#TF2$Rap<- z{*65)Q@bL{s-W#ED3vGVp)D>~_KYZd<^Z6qxjiQ{cWxB{UI9-n<<0!cJ|uG+i{KCT zM>asl5qSD*Ds5n3UGAXDt3V7zk2N5eL z&}QuJJDTi^ufPdrc=^k<5zpCL9z7mKd@qfFirMv2a7@`+IZ-2CgzG!o3;r%E;?bHE z@tVOVFV&i0S#)5U<{E%G97B_q-EyLXn;6TO=jE**f;NdBT(urwHycC6LxYX(&5Q$x z1gYEGC#?IZ_+Z);$GH7Co;|A@S_NP5r z!Zq>S5m!K?b{?{ccJ`yJV~aL2e1HUiauI^BA}2Ug{e75A6!R_LwZB*)7mrqL|BJ_O z{eP>W)4O>jA9kExuh;Q)Z2?l~x(H+4I0d`E2OB^b}8Si@b{4Oc|Pkcl^uaR~>sTc8)B zhJz1-ZcC=Zv*w|FxN2FEbe2ty1l~I091h^Q)o&D`oJdqAt|9Y#ICOD9wc5h~%YMtm zFFbBG;s>~n7KwLJ$VF`?98g}EVewrY(0yT0(0)C|Uxx+x9t_R%0cn$qq)wW|cm2=X zwKnmIYd6U%yml=hHoO5@0j~eD-tcjLv`YM=z-ij9Eq(-tbeyf8Xdw|Sl;ah?aB~GI z<0}T$I4o9?x`4w)7&_;5EesPLhoKI_a4?1|n}wmkiUQEXpNSNR6exrV{;&va6ef?_ zfs!<_*H?h*o<&@CrD^;n{;n)?vTP(yK`BkT^r|VUaLAxb;7>^x?=SNQA}~b8yH~Jh r8jFc%{U-dBjWc!&JEwdL2G?pfn-I4;eE4DvZMRO`BfEw%T>1RJ77ZD> diff --git a/mace-bench/src/batchopt/relaxation/__pycache__/optimizable.cpython-310.pyc b/mace-bench/src/batchopt/relaxation/__pycache__/optimizable.cpython-310.pyc deleted file mode 100644 index 824d09ee31525297386353934b8828c591fcb4e6..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 22681 zcmdsfYj9lGec#^4E_N4-2SJbo!H2FWQsPSFQX(bAVrZEZ9};5;rbSAQR#rBPJr}@& z`vA|qAc@VgwnL@%#HO3n<8jj@9l(#qOd7Xo;-=}ik7=iM+vJ0rq#t^lNn5+Ms-|x2 zXqrq}OYHCWKYRB9fSP2+U(yBk?78Rt_@Dp#+~&wgHi6Hte)aak)YU}dTYTvMr15Yv zk+6cjb|UdoLM4g`WhuL6)$O8fG49lyx?6N5?$(m^R52y-WNoOPE~e|5Vy2!gX6wVn zVM$Nba>WtUOVvh;qw+ge9K-KWZM?pvxJAmOYZLXY#jO(0)F$iOirXZft!=OGDDJQl z&nDEc%DtXYxtn%zs+C6lk=oAsbaC2Bd??WvcN6}u6=fy(=Z~%KR-_ljQq{^I@heZ_rJb5i>HTYWVrIyF7? zcR#-SIxS?;!q(aY_5H>D^#jENef|DcU+u|G?E`%y-Bx?B{!sCuz8b%4YG56a4%9l`Sv{}`ShSE-AM z{L$8*WKTDjR&;fKA@mMb^4<%6SU%*PZ&V6ixuLvj7q3woXuWZs0LopT-Fuet2CAG;jvb&_>I6<-f~0v+6xza?}hW1;$&~GiFB{p zU`F=&^QX_wTs-TcL)4WtPdA&Pr>feogiXETHRn*W*Gs5q|Ai?l;16SZy{6**AGw@>VolZZw)EF~H+krCeF?RcT2#mwX+rFzKSy zohjG-V5wa3^KLY9>6Hs-OQ)YZd;0z7XP%AHr)%XPc)6@2=Tu`Q$~@~gd|k#sXHb3D zg-%m5&G39Y2w$Dmx~Y*peYIJg4@*y$!e(i1xmGJ3yMC+`U_t?hGZa+HVd%5q=xjNx zEU*|_t_Ky4zyD+5;UrG{TRBAhMA5>Tx0QwcwWIOUndN$Y<#f4LS+2>zZaY!BRBB*O zrBakFmFi8kTw^>}DqUGF*G$gsgx-qQB)Id+k%ea6KTo+P3M*>|rBI9l@g-5hsD_@u8KZ3_ z8BNA%XO8vSi*o%>t^FHXk1O@66tD?nOWF^0x!M4D)oazIXyV+d^DmxW@aw&Dn>4nu z5GOhOf!+$<6BP6amNRebz4)2YV+ecOX=0O1WA? zbLN!RWuA)UGX%${xnYr8@oTkclyF(Clp4$RSziZ{v(#iTTWO^OKYSH1{aQ49>iOq! zZcfcSd-mcRx4!e$)j#`JfA8riB_Q(Hot(7kwY`~U!|$XX)%)0m6EdF*b&pYKSIJKvn^`)}*y$;I3WV~5FyypABeAk+u ztc@RJvmQRD^n7q!Vtg{ocH9$$V4eo^j&7cZwLFd^R0ER227AlI%Q#V$&Oos7K=q+u zSy}OFW8Pb8RvTg87`q*9vKhyL*MpEqx^GVCpi|VKXVY6XXg18$UX<k;&4U_-p^j4Jr6s?iP%DpG`Vj_)7!=sZs2OLe zGtOvhZ|+3epemKik;}xsL&)<7nqwBe)Y4h;-iQ^p!Q3A4Dkc% zL5Yte{*XE-@hynw)x#2>P_wF{4yyus+Nu;87W6a;*2PzI>T&f1(zmI3^`!bCq;FRX z>M3;`u^lK;Rg3C`dM`>$p~TbbeMsMl*hzH?v1xT#omOWM+a;~l)H(Gm%I-$`bLu?O z_o%vhzj_|Ay(rO8P4xkF0VVFsxS0>CC3Qu;tUiQiPiggvDx%!|>Q(jO*AvBkDo{VB zN=SJ?@$`n#cpnmehUl>MFHU(9@vZ^`g3sU?CozddQGC ztqF!v22YJ(uBq#Hw#&IHG#Ul4Y2YhoMPRP zc4BcPcwC`{Mf<9yA8#jCN7o?F@Og1;k-s_R-a>nec4;Jw$<$lV_Ka2`NS{UBqYM$rOF4KMB;p-8dl4-YHJYM1RM>tkz@=K5T?l99rsR| zhyX09y*IDROADSVhh>2Z7Uc(wq7UkEssZ>4LZKgoae#f`UlQ&{d*JY^S3y$*+y-iQ z!QptP{cgfQu|beH?p49=>t)}&R;|@MW_kwM%CV3M^vFQDrv0*70X6GD)IiGuR>!?- z3qB+b$RVJ503IBYV7Ue<2hUOh1O;r4IS*#un0Xfzk4v0XF&;*j2z7g97Y!U|mO2Gt zCO&97=yrw!s;XD3#1kO-D_+aj&4IzPXAxzX)1_|}5Is~i!vS3mua7ERXMj zUM}H~U-P}PQrMfI4r0cMfiL3}aa6$Rh9W{026`MkU4!dn4jvk(Y~rb9Eo^}~br`@v zJL^y7aUwP#Ff4WFoDwFFI6FiUcYVW9{~br9qU8Sa$c~$@nOR7 zh>a^<(MiT}+{^EX?64VuWL<@5t52dGeTu;|49+n)gCH6yRg6R_U<&bSc?oCAFLC># z5g8_Mw?JPs6h~Y^cpO`lX@+wpIg3%QQ&pby0e0~KgXb8WXW{9R@*yqOA!DH1SlBCZ z!-q5nE5Bo_tLHQFZjr1%yg5lQLGx7T4a6wEZ$O1j~R-VVN43S>}3oQ8n1ZX{gBRZH= zn8ctExfHGScFZSARTkzOkQ;a4p_mGIa2tAxIgqOO0M<1fa|U`3GlrBuR|Y!ECnJ|5 zFmO4JEWL&K$y)OoR0@uM6;H(ppkTm@E(x|4l|XScPBUtOb-S!TgnD1YH`t9J=j5y$ zSX0-|*$gKwH)&$7-Fjr>hTaoG2R4?Q%r7_gEh;mymDt#eiQCo;Bw9IXA3!pWFlVuT zgv^TYeal9Wwp%+kjKiGUd-fT0`%GYCt=0#h^@%SbT9D5JegLllJRB_fmFnCIG2yj^ zDl%epd<2L})L$BBda2n^LR(B3X`+CUpTywwJcF*5MTHITheHbjE1uSP1TaHAhipK^2RWE05Jba34pqQy3EHtu#|TcW zLtq7$eG8QW3e;IU2O*_(--dw>mftYaE>E@}jdVtu&|@Kva!6jvy=jF>W!sOE6xek5!+m&|y1>rZCxMV!yElxzukeO}cbAywHB77u zV>gDSRo2eLsQAw}O8TApS&R6<>-=ps;oud!ww&uKrI* zbA+xwhnTA$N0?kpu_V4j%DH8KB2h~7c`>t^#qaPMbF5NE(sBc7xyEP05vIpFid*gt z)JKb>ZT@!KYkUMxj!J_Lm+oCdWp1IbQo6AqrL!pgrFSU}6V)yIh7*o0j;j$h`g)@6 zoPquWb+_Opmj2Y$t!hGTRa@5_HTeq;HH$KqeKMA9 z7|Yk*Wi0W?hfx}{M`%H71LM&xhpmJV3X!(`rZq*q3ZiY3wZQcf>R;V3IwNkLe;#A1#=7XZY60&(yqV1EbTjQEV_A1zt|PuD!q z{A=kO7Pzbg@MvN!@w4_7Ab%+DEGnhlDzt*P?Z_@XO0}A=))Ud}2_j2MG*1@*w}LUJ z60O4icg;!L4!)zkt&qevfN{=5fm}lXKO9-8QEvq zPQc4aREUyL#Ws9AeH2gnhZzj6N|Zf8?T_sCg?P7j|2cQNThsKpQ(tEXCm9H}`BBCw zUG34W#5|WHatjG}MuKJ7mkqud>Tg>MC$G6~kAn{x#7?_prU@ zc@PQ!7Du?;reA|l|v?ZXgFZ=g{Kk{Gf| zfBm3Mbyh|)XyNePa zKx2s$)D#by-66^%DQfK)lnZ0=FUY42{t%q-Gn}QY%NfQlF(B`+e*poInw)JZxo$m1 zWIbXH-(@Zc5lYVHf^3IP2qN}IiEyA&l&`rLiz34PcB|S|3H2Q8vU%#f3YKGlX>V0)1vR}3f~3LWI?bblQa3TDBIJZM0R5-8fqv} z-wL|bpFyc!d4r#28${NXZ8*syvi>hI_A-K))BZf-RL2CD*Td=vj3Z=vjP`YzdqcJ* z2WFePtOg!W%d z%V1U8mfjbOOJFvf%die87n?w1v=gQlhp?@CVSr8p(qZ<4eZvJjJF6KQ0N1$q>DDg7 zat9v+xfzDg(EEbEg+8IOkW>3N5f{er4MgOle;PmflL+!zy~;QxG`+;&Bm$W3tJlGz zeUS-YV(_;aTw(B81o@;n1h<(WeXzeMBehCZb$$KnX9?tHY#-;+kD~af>&l@3(ZSK! zM9k~e*%T|hpTtD4g0+gZRrZ>#9FT)th_xhn=!$ts;Or@Yud7m&nd8`MQiT1X){}!P zX?Q2t+eCvDEWIRFTxKLN00tD8_S~h`W9!rUWkxXT(DxHHcm12N z%CbM7;*`u<&T9r@hVMfmB>UvIIEVcaCxqdFp#0*;z=ei04DA9?R7r zN3xOyEt9qk1ojw2!Ky-Fwr?^cto$nHMC_QgTMzwt=0T5>oUt(&K4rtaI5!?v?IQw0 zA-sdx?}~;;1Gc{@iu`KDqx(!i@&6>s8oqI}!Dl#lDf>mVt`|_aHNEL*r4X2AwmX1E zhERvBfS?ihdkPwXL=JXj5jyc8_aj(;0d=61&|ORdZR{rc5X=p8Br!dy0m>88$0?)* zw!Wqu<<^Glxf>2VVQ2^)y=lQy<_tY+7RTB^NWt+dN73&1>K2+M%9@U)|yT(ed=stkUo=-aS{BOI4Gk+^bH zW-%r-i9KUX@PdF=AINWs8p$p!n5a;0|$bA%Tui+NeR{ zFPqXsfiwA?y`7Ab(!zSoC+Kk#(iCl>Res)wr(LL53fw{cTj(xI!jL}i>#OXaL@P?w z=-fDu_M0`B&dEy98U<-$L3oWQg%Q`6Ylg6WoY}N5aXRo|5tid0GI1C#h7i_7S~p!; z)yR6GI0o9*Z>(C695N(d0%2dqdIS%mttm(clc0I@RvD#K04`-3DA|1{2h^OkrmWF- zrrb?nw%gReG2(%A@D*U?@g;g7r$#nq6+j}0!fF+WD9T;4;GB20SzD%SP!~^dm}FXI zf1z6YO*CdG)IUZnaAehmSPnOC&YTgSW{SqjsnEM0H?M!9KduES7#gKNl5NCRpq_1ex9 z!cFwqgZ17x`Khm7-}863ow>7T{i+%fHge_(C}q?3;RHc~fojuPMcQhn4L%XS*|xi8 zbr9^ZK8uZm*E}})xOl+i8fhtyRG?T2&}86qP&YcOBZ72C*Ah93$WE;IhCUJQLR~(m zN5MOLPOfFvhaJ-6Qp|{2DnE8_BuQ4^F!A3&tY@5pcY0()=VZf|@eIq5sgE3JFhx?1 z^(M0OzsLOFXYlVC2~^Uf1yG zYQ_de+8|+E?utp!*Iq=Ez;>U*kEYycsFLWaHsdek7ziB@t!)Blu7l~uH?!vv_O<6Kh#6z-%kJwaHiP#P?i%O42nfe;z zmcGWHR9|<8m?_OjF*Y4nS&=H|0M{ZXO0QDr=lTjUe%sZb$B$;*^v0^U?Q`dDrwkd8 zx|_6eyED3ffb!@^Sc1V;;=%lDmau22{1P1VZ`qent=_n#6&Qg39m{4-x!+|uK}a%5 z!AeqBP}AlU|Bup_naO`=b4Ww*S zEzK{VQ(1VXWHa^)xx3k^Y}Ve2-&}Si>)J`g?#m|mYmeJj?*C@@r<0=@>92zR20r8| z@UeowjxctpraLv69~dZw2||p{)UFSqX>M3OAx5~}YR~H<#gsp!_Nx0{hiFqwD^J}I zXX1?7ryjs>Ry`>t532m@x#FX>>QPRL^) zheD`04&rb;>OaFT!T;(su$FL~t(6yeCu&9TWUvelLs(&6>LQ^a`X1D>^?3{9LKPhl z>DZ|}cKf(6s>V{nqeJuyiSe+Uu( zB^181Exv&ADjkAfi+QJE!N_xE-UGX{OW+t;>-jC-ADLKKQMyS}I4-3aof=*3#n&ok z=c6%7KV?nxND#nx(jXndVuH517^^}w9)B1oR5*l+{lLz+CPIM{LNGOJ;F^Y^o6EIF zAS`ESl!&$ZG(ONX7}O3f_~5V(P`Fm4g=v)pBaI6p5V2FxFC9ltHXLrJR&x-HM?^60 zltTT{#)y<0LdnzOgFM=B!!eZ>;T`_N;TDMQui}^OXTk{@>L=l*HfQq&<}G_Nabu`G zBx6`a?^_oqSGOq$&WqbsR*l@UkW2659pc$M)y|@woeEygW}K6lQ)ybfly{l=Kmezr zV2^lF@8y#kRb#Ki4F)Ceqf1?Wy!^}1J)DMnF0N>ldA(0yi1#%WC zB0=HJm>WUZ!zTC+c+VUX=`MgEQG*w<-&_u27u(JSKLJ7l;yGxGc(V>}$sv49;k4!Z zy*o54akk(+Q}8a8RlSLJpDlP7@ee+n0vob?4Vdw(jYD#g1cl!Z0Y;aDL)|Xe$c2TK zpy0jeUqub?iKh;EM~^+>9enJmr=H6HWplq4Q}m$o?_b18y64RQgY?{{4j$K;JYoYM zmnHbT(f`hzbGga{9MIs!Rfrl?wp{}j1HKXjH?VPi2xSl~xU1&`(a^qUeD~m53Asi< zj2`T&D(0}jBN$efYP@`kPH;nEpbNGHNCuiD)(v3(@HesvkT*-4Qvg+Nd#4)Z-Tk9M*UZ$lY2=u9W4Vsb; z$>2~8vSq_4)VZeu$L(d@7T(aJ+;_n)K_dmXCs9gMe~W=AF_iea;GHqR zkfBV~*A6g9clQT7;|6oxuU0iQ<*_OX6NHy<4ZggvrockHR8ZCO{J~cdfr+Q$A2h}D zuRVM$KiJ+nGPQB!7tFO=ueoe=Dr~sP%g%JM6S&8k5?BUD1lW{`ZCL0&y3dOQasM4v z+93}*$%0_qI<|q!!j8*>PXAm4qoF@+lx1FJSzpCPqZt?haN&^PW7c+HC)pPin;o`+ zjl>@rzv%MJ6+Tdo%R#4$u7hnO5K~D7;Ks`wxN^nqn;pDIz!o76cAWW}=I$)6Xkk#8 zts+{$rAJ~4TwQEX-`jv9CXjQ!VXY3zVw^kw;@Ne(wvLzBxNE0i;vszCk(AmhNVR{$A6u-V60{B&V**a5Azr|tHDCtcJ9zBS zu{^J>9DMRn-c+Eb?;MQ`9mou*n@~t_RCUm+3hZak&4f}@8|)~&1Vx7g1{ACwRdX|@ z-M2U>$B2QBzv6cW8(AAntmzZAa(z~npKjY7ZNcQ=Rhd{A5J{tZQ=`v)q*Fid!AL>b z!-$(A6#fi`5oP&`24h3epGJaF$=4YZh2KA7>@y6q4E`wt+Bq}@ged!Rxwh=fdlZb~ zP{ig2^G`BA)exEzZ80}3;6jaK1=ay~d^5|90G`7>?p6IU%YT=_pRy9f;Eo{h9X`5O zf$wl7nAjS3cV?-@FPW#|&J!_je$*ILPO_cNpoIh0Y_YGu@9}X^(ztC!nNsH>M#%`o zA_h}T-^pFLfJo|zSp4n%rv7X+M;96BnH&ho)L6cgb)m1cv+xzcT^qQ{jM5$h%_TiY z=s&_z^z~}p-=_iEd`}0IQ5c7K1fQnx^Bz3V;}HBC;Nma#6M>iv)iV7esmeo6 z+nR^knK0S<4Qe%s=4EAn5Z5Yty!MBH&6np%3=Qc4uI*gc*rZ3$^|&D|m17H2ajY{cpktYJpW;I{o2r+OaI>k? zaZ@f7%}_O4?%Q_Ky)%&>aQPqHE(4yOTO$K|MjSDOuk0Cps^UF^84Y^^0Tr$baCe~7 z14T8bbzi(Q!WH!~KEPezDX$wRgWpC|(0z$g$=vnD()pTJHT_Q*++iRmex9+Oj)zl? zhE2PI+Ws1+Oa$uURv2`o`|k;reU%0#DONc&Ad~zGB1WAqZc4Bi-T;7Ns+c5S(6E}K zL_L|<0gdt;Rqb$Y>YB41T1XbNtHWY2wt!3><@^-2JqYI-O$&zc%nr-5+ zj-X82{wW*YHE?o*$C#8^0}_U204;QKM)gauyNp4yAEnPon~R=m9pa)70^f$^kEM=+ z!RKGVJZFsUnIt1WuAy+~V89I6Yuw*aeqeSe$$O-MQ54Ib96*g8l=f}gXPF`RN>npM zx=@7DXbcs~vc@$-GQzTN`ui*^Ee-Cjs0BaET_wP!c{vQ2hhggN)X)^LlKoEF$-X_E z+-|)KP($4Bxmj(d(H(GWBLI@VcjwwR#<8*X%CFX2Z;mjvkf7{|V5 zcO_PbMfc~f!J!{^FmY+a?`T_Isz8XKyTa-i?PGi`2ue6oC#g`!ltfEu=u@O!6B4$1 zEsF~STXMqjn+dv1m^`>OpiOa^8)|)uai3cN(T zhrMHmynZo+Sb46haYv<7R5z~~X_`5`Hal{L&{}e;or^7v_-gdSNW<%yyn*Wd_yq1n zxX#;UZ;6S;xrgIghjcrDPSo3ne}o+02c!O#a|FwN=xqH8V#j#;Dc=O zLI|LnSP$5c5KVJEug4M}0u#BdbKU)jRdZcGu#Nx*V$HYkP!{ZpE*CnYOSdvCDVV?U zn3<7ffT3Pi_t%+xm%-N={2GGTPXG52k8%P~9iKD3ZWUSicesi~`4jH6HE^qAkSH4b zG3dkAb1WF}b--fvyraI2fn$L{9&j|_-?zn&`+|U-Xh1xLIqc2hR5`_#6rm&yI5&aYBZ4}?J_7KKAvJ2{3$-%9r z%g*YSwS@i@mAvrR>h|)KjB#;7-$wf~ig2q+;&P^hb50V@(s7HZMRaGg<@W$53sH0z zCvRjhnr-cjN`~7}c1M>KCxsN_?MP5ICzYJEMjc$xwWr}_ zmby#t9B_)lPfQL?Te~6az*oacX*~4I>tGfEJ$$1;2Q|Ww z`$?40|AhhfbN!sYiL^1GU$L(PGxfjwXAm=W?ChEb;(L#0A98UV_ zZO15E{{{#1DF#2yU>vgo&$i0AuyZ|RyVnwA{(kC&2*Z;{hfy4jBn1Hn83Jgn31m?rceU^d5MEN4Q0?6-Y z#o=J71cNsje3$`=rnyZlM{SLdB-mPr6@k@Qs<~Epv4;`V z$)>=rE+k$8p;bcc{|M!_61*Qpgs(NUGd+<#j=9$OpV)J7CmKtqMz`kn+v(ih^oa>a zvo_R62k|olzkxh{^}`GvVIV7fh%s4*0%J!QurK`>g8UOQNxt3DycY(e#9-)EzubWN zv{?&e#!zqIB@uz#Hp6Z76{>f2lC7i|3^5>37~JqB#$?@IK`ff4yJ~z52#zs*aq&9G zHBb~V<_OW^#P2+Tu9Ppf6o z list[Atoms]: - """Convert a data batch to ase Atoms - - Args: - batch: data batch - results: dictionary with predicted result tensors that will be added to a SinglePointCalculator. If no results - are given no calculator will be added to the atoms objects. - wrap_pos: wrap positions back into the cell. - eps: Small number to prevent slightly negative coordinates from being wrapped. - - Returns: - list of Atoms - """ - n_systems = batch.natoms.shape[0] - natoms = batch.natoms.tolist() - numbers = torch.split(batch.atomic_numbers, natoms) - fixed = torch.split(batch.fixed.to(torch.bool), natoms) - if results is not None: - results = { - key: val.view(ASE_PROP_RESHAPE.get(key, -1)).tolist() - if len(val) == len(batch) - else [v.cpu().detach().numpy() for v in torch.split(val, natoms)] - for key, val in results.items() - } - - positions = torch.split(batch.pos, natoms) - tags = torch.split(batch.tags, natoms) - cells = batch.cell - - atoms_objects = [] - for idx in range(n_systems): - pos = positions[idx].cpu().detach().numpy() - cell = cells[idx].cpu().detach().numpy() - - # TODO take pbc from data - # TODO: &&& ^^^ change this back !!! - # if wrap_pos: - # pos = wrap_positions(pos, cell, pbc=[True, True, True], eps=eps) - - atoms = Atoms( - numbers=numbers[idx].tolist(), - cell=cell, - positions=pos, - tags=tags[idx].tolist(), - constraint=FixAtoms(mask=fixed[idx].tolist()), - pbc=[True, True, True], - ) - - if results is not None: - calc = SinglePointCalculator( - atoms=atoms, **{key: val[idx] for key, val in results.items()} - ) - atoms.set_calculator(calc) - - atoms_objects.append(atoms) - - return atoms_objects - +""" +Copyright (c) Meta, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. + + + +Utilities to interface OCP models/trainers with the Atomic Simulation +Environment (ASE) +""" + +from __future__ import annotations + +from types import MappingProxyType +from typing import TYPE_CHECKING + +import torch +from ase import Atoms +from ase.calculators.singlepoint import SinglePointCalculator +from ase.constraints import FixAtoms + +if TYPE_CHECKING: + from torch_geometric.data import Batch + + +# system level model predictions have different shapes than expected by ASE +ASE_PROP_RESHAPE = MappingProxyType( + {"stress": (-1, 3, 3), "dielectric_tensor": (-1, 3, 3)} +) + + +def batch_to_atoms( + batch: Batch, + results: dict[str, torch.Tensor] | None = None, + wrap_pos: bool = True, + eps: float = 1e-7, +) -> list[Atoms]: + """Convert a data batch to ase Atoms + + Args: + batch: data batch + results: dictionary with predicted result tensors that will be added to a SinglePointCalculator. If no results + are given no calculator will be added to the atoms objects. + wrap_pos: wrap positions back into the cell. + eps: Small number to prevent slightly negative coordinates from being wrapped. + + Returns: + list of Atoms + """ + n_systems = batch.natoms.shape[0] + natoms = batch.natoms.tolist() + numbers = torch.split(batch.atomic_numbers, natoms) + fixed = torch.split(batch.fixed.to(torch.bool), natoms) + if results is not None: + results = { + key: val.view(ASE_PROP_RESHAPE.get(key, -1)).tolist() + if len(val) == len(batch) + else [v.cpu().detach().numpy() for v in torch.split(val, natoms)] + for key, val in results.items() + } + + positions = torch.split(batch.pos, natoms) + tags = torch.split(batch.tags, natoms) + cells = batch.cell + + atoms_objects = [] + for idx in range(n_systems): + pos = positions[idx].cpu().detach().numpy() + cell = cells[idx].cpu().detach().numpy() + + # TODO take pbc from data + # TODO: &&& ^^^ change this back !!! + # if wrap_pos: + # pos = wrap_positions(pos, cell, pbc=[True, True, True], eps=eps) + + atoms = Atoms( + numbers=numbers[idx].tolist(), + cell=cell, + positions=pos, + tags=tags[idx].tolist(), + constraint=FixAtoms(mask=fixed[idx].tolist()), + pbc=[True, True, True], + ) + + if results is not None: + calc = SinglePointCalculator( + atoms=atoms, **{key: val[idx] for key, val in results.items()} + ) + atoms.set_calculator(calc) + + atoms_objects.append(atoms) + + return atoms_objects + diff --git a/mace-bench/src/batchopt/relaxation/optimizable.py b/mace-bench/src/batchopt/relaxation/optimizable.py index 31c8469..3dc9623 100644 --- a/mace-bench/src/batchopt/relaxation/optimizable.py +++ b/mace-bench/src/batchopt/relaxation/optimizable.py @@ -1,791 +1,791 @@ -""" -Copyright (c) Meta, Inc. and its affiliates. -Copyright (c) 2025 Ma Zhaojia - -This source code is licensed under the MIT license found in the -LICENSE file in the root directory of this source tree. - -Modified from original Meta implementation. -""" - -from __future__ import annotations - -from functools import cached_property -from types import SimpleNamespace -from typing import TYPE_CHECKING, ClassVar, Any, Generator -import numpy as np -import torch -import logging -from ase.calculators.calculator import PropertyNotImplementedError -from ase.stress import voigt_6_to_full_3x3_stress -from torch_scatter import scatter - -from batchopt.relaxation.ase_utils import batch_to_atoms - - -# Define dummy classes for when imports fail -class _DummyCalculator: - pass - - -try: - from mace.calculators import MACECalculator -except ImportError: - logging.warning("Unable to import MACECalculator.") - MACECalculator = _DummyCalculator - -try: - from chgnet.model.dynamics import CHGNetCalculator -except ImportError: - logging.warning("Unable to import CHGNetCalculator.") - CHGNetCalculator = _DummyCalculator - -try: - from sevenn.calculator import ( - SevenNetCalculator, - SevenNetD3Calculator, - D3Calculator, - ) -except ImportError: - logging.warning("Unable to import SevenNetCalculator.") - SevenNetCalculator = _DummyCalculator - SevenNetD3Calculator = _DummyCalculator - D3Calculator = _DummyCalculator - -try: - from fairchem.core import pretrained_mlip, FAIRChemCalculator -except ImportError: - logging.warning("Unable to import FAIRChemCalculator.") - FAIRChemCalculator = _DummyCalculator - - -# this can be removed after pinning ASE dependency >= 3.23 -try: - from ase.optimize.optimize import Optimizable -except ImportError: - - class Optimizable: - pass - - -if TYPE_CHECKING: - from collections.abc import Sequence - - from ase import Atoms - from numpy.typing import NDArray - from torch_geometric.data import Batch - - -ALL_CHANGES: set[str] = { - "pos", - "atomic_numbers", - "cell", - "pbc", -} - - -# @torch.compile -def compare_batches( - batch1: Batch | None, - batch2: Batch, - tol: float = 1e-6, - excluded_properties: set[str] | None = None, -) -> list[str]: - """Compare properties between two batches - - Args: - batch1: atoms batch - batch2: atoms batch - tol: tolerance used to compare equility of floating point properties - excluded_properties: list of properties to exclude from comparison - - Returns: - list of system changes, property names that are differente between batch1 and batch2 - """ - system_changes = [] - - if batch1 is None: - system_changes = ALL_CHANGES - else: - properties_to_check = set(ALL_CHANGES) - if excluded_properties: - properties_to_check -= set(excluded_properties) - - # Check properties that aren't - for prop in ALL_CHANGES: - if prop in properties_to_check: - properties_to_check.remove(prop) - if not torch.allclose( - getattr(batch1, prop), getattr(batch2, prop), atol=tol - ): - system_changes.append(prop) - - return system_changes - - -class OptimizableBatch(Optimizable): - """A Batch version of ase Optimizable Atoms - - This class can be used with ML relaxations in fairchem.core.relaxations.ml_relaxation - or in ase relaxations classes, i.e. ase.optimize.lbfgs - """ - - ignored_changes: ClassVar[set[str]] = set() - - def __init__( - self, - batch: Batch, - trainer: Any, # Any calculator type (MACECalculator | CHGNetCalculator | SevenNetCalculator | FAIRChemCalculator) - transform: torch.nn.Module | None = None, - mask_converged: bool = True, - numpy: bool = False, - masked_eps: float = 1e-8, - compute_stress: bool = False, - use_fast_predict: bool = True, - dtype: torch.dtype = torch.float64, - ): - """Initialize Optimizable Batch - - Args: - batch: A batch of atoms graph data - model: An instance of a BaseTrainer derived class - transform: graph transform - mask_converged: if true will mask systems in batch that are already converged - numpy: whether to cast results to numpy arrays - masked_eps: masking systems that are converged when using ASE optimizers results in divisions by zero - from zero differences in masked positions at future steps, we add a small number to prevent this. - compute_stress: whether to compute stress during prediction - use_fast_predict: use fast prediction method when available - dtype: data type for tensor operations (torch.float32 or torch.float64) - """ - self.batch = batch.to(trainer.device) - self.trainer = trainer - self.transform = transform - self.numpy = numpy - self.mask_converged = mask_converged - self._cached_batch = None - self._update_mask = None - self.torch_results = {} - self.results = {} - self._eps = masked_eps - self.dtype = dtype - - self.otf_graph = True # trainer._unwrapped_model.otf_graph - if not self.otf_graph and "edge_index" not in self.batch: - self.update_graph() - - self.batch.pos = self.batch.pos.to(dtype=self.dtype) - self.batch.cell = self.batch.cell.to(dtype=self.dtype) - self.compute_stress = compute_stress - self.use_fast_predict = use_fast_predict - - # Determine calculator type once during initialization for efficiency - self._calculator_type = self._determine_calculator_type() - logging.info( - f"OptimizableBatch initialized with calculator type: {self._calculator_type}" - ) - - def _determine_calculator_type(self) -> str: - """Determine the type of calculator to avoid repeated isinstance checks.""" - # Check against actual imported classes, not dummy classes - trainer_class_name = type(self.trainer).__name__ - trainer_module = type(self.trainer).__module__ - - if ( - "mace" in trainer_module.lower() - or trainer_class_name == "MACECalculator" - ): - return "mace" - elif ( - "chgnet" in trainer_module.lower() - or trainer_class_name == "CHGNetCalculator" - ): - return "chgnet" - elif "sevenn" in trainer_module.lower() or trainer_class_name in [ - "SevenNetCalculator", - "SevenNetD3Calculator", - "D3Calculator", - ]: - return "sevennet" - elif ( - "fairchem" in trainer_module.lower() - or trainer_class_name == "FAIRChemCalculator" - ): - return "fairchem" - else: - return "default" - - @property - def device(self): - return self.trainer.device - - @property - def batch_indices(self): - """Get the batch indices specifying which position/force corresponds to which batch.""" - return self.batch.batch - - @property - def converged_mask(self): - if self._update_mask is not None: - return torch.logical_not(self._update_mask) - return None - - @property - def update_mask(self): - if self._update_mask is None: - return torch.ones(len(self.batch), dtype=bool) - return self._update_mask - - @property - def converge_indices_list(self): - return torch.where(~self.update_mask)[0].tolist() - - @property - def elem_per_group(self): - # This return value actually represents the number of elements - # in a group within a batch. Each group corresponds to batch_indices. - # It will count the number of CELL elements in each group. - return torch.bincount(self.batch_indices) - - @property - def batch_size(self): - return len(torch.unique(self.batch_indices)) - - def check_state(self, batch: Batch, tol: float = 1e-12) -> bool: - """Check for any system changes since last calculation.""" - return compare_batches( - self._cached_batch, - batch, - tol=tol, - excluded_properties=set(self.ignored_changes), - ) - - def _predict(self) -> None: - """Run prediction if batch has any changes.""" - # TODO: Currently, the batch inference interfaces of various models are not unified and are poorly implemented. - system_changes = self.check_state(self.batch) - if len(system_changes) > 0: - if self._calculator_type == "mace": - # FIXME: &&& - # for key, val in self.batch.to_dict().items(): - # print(f'&&& key: {key}, val: {val}') - # self.torch_results = self.trainer.predict_debug(atoms_list, self.batch, compute_stress=self.compute_stress) - # self.torch_results = self.trainer.predict(self.config_batch) - if self.use_fast_predict: - self.torch_results = self.trainer.fast_predict( - self.batch, compute_stress=self.compute_stress - ) - self.batch.pos = self.batch.pos.to(self.dtype) - self.batch.cell = self.batch.cell.to(self.dtype) - else: - atoms_list = batch_to_atoms( - self.batch, results=None, wrap_pos=False, eps=1e-17 - ) - self.torch_results = self.trainer.predict( - atoms_list, compute_stress=self.compute_stress - ) - elif self._calculator_type == "fairchem": - # TODO: FAIRChemCalculator does not support batch prediction yet - atoms_list = batch_to_atoms( - self.batch, results=None, wrap_pos=False, eps=1e-17 - ) - self.torch_results = self.trainer.predict(atoms_list=atoms_list) - elif self._calculator_type == "chgnet": - atoms_list = batch_to_atoms( - self.batch, results=None, wrap_pos=False, eps=1e-17 - ) - model_prediction = self.trainer.predict( - atoms_list=atoms_list, task="efs" - ) - results = { - "energy": torch.tensor( - [pred["e"].item() for pred in model_prediction], - device=self.device, - dtype=self.dtype, - ), - "forces": torch.vstack( - [ - torch.from_numpy(pred["f"]).to( - device=self.device, dtype=self.dtype - ) - for pred in model_prediction - ] - ), - "stress": torch.vstack( - [ - torch.from_numpy(pred["s"]).to( - device=self.device, dtype=self.dtype - ) - for pred in model_prediction - ] - ).view(-1, 3, 3), - } - self.torch_results = results - elif self._calculator_type == "sevennet": - atoms_list = batch_to_atoms( - self.batch, results=None, wrap_pos=False, eps=1e-17 - ) - self.torch_results = self.trainer.predict(atoms_list=atoms_list) - else: # default case - self.torch_results = self.trainer.predict( - self.batch, per_image=False, disable_tqdm=True - ) - # save only subset of props in simple namespace instead of cloning the whole batch to save memory - changes = ALL_CHANGES - set(self.ignored_changes) - self._cached_batch = SimpleNamespace( - **{prop: self.batch[prop].clone() for prop in changes} - ) - - def get_property( - self, name, no_numpy: bool = False - ) -> torch.Tensor | NDArray: - """Get a predicted property by name.""" - self._predict() - if self.numpy: - self.results = { - key: pred.item() if pred.numel() == 1 else pred.cpu().numpy() - for key, pred in self.torch_results.items() - } - else: - self.results = self.torch_results - - if name not in self.results: - raise PropertyNotImplementedError( - f"{name} not present in this calculation" - ) - - return ( - self.results[name] - if no_numpy is False - else self.torch_results[name] - ) - - def get_positions(self) -> torch.Tensor | NDArray: - """Get the batch positions""" - pos = self.batch.pos.clone() - if self.numpy: - if self.mask_converged: - pos[~self.update_mask[self.batch.batch]] = self._eps - pos = pos.cpu().numpy() - - return pos - - def set_positions(self, positions: torch.Tensor | NDArray) -> None: - """Set the atom positions in the batch.""" - if isinstance(positions, np.ndarray): - positions = torch.tensor( - positions, dtype=self.dtype, device=self.device - ) - else: - positions = positions.to(dtype=self.dtype, device=self.device) - - if self.mask_converged and self._update_mask is not None: - mask = self.update_mask[self.batch.batch] - self.batch.pos[mask] = positions[mask] - else: - self.batch.pos = positions - - if not self.otf_graph: - self.update_graph() - - def get_forces( - self, apply_constraint: bool = False, no_numpy: bool = False - ) -> torch.Tensor | NDArray: - """Get predicted batch forces.""" - forces = self.get_property("forces", no_numpy=no_numpy) - if apply_constraint: - fixed_idx = torch.where(self.batch.fixed == 1)[0] - if isinstance(forces, np.ndarray): - fixed_idx = fixed_idx.tolist() - forces[fixed_idx] = 0.0 - return forces.view(-1, 3) - - def get_potential_energy(self, **kwargs) -> torch.Tensor | NDArray: - """Get predicted energy as the sum of all batch energies.""" - # ASE 3.22.1 expects a check for force_consistent calculations - if kwargs.get("force_consistent", False) is True: - raise PropertyNotImplementedError( - "force_consistent calculations are not implemented" - ) - if ( - len(self.batch) == 1 - ): # unfortunately batch size 1 returns a float, not a tensor - return self.get_property("energy") - return self.get_property("energy").sum() - - def get_potential_energies(self) -> torch.Tensor | NDArray: - """Get the predicted energy for each system in batch.""" - return self.get_property("energy") - - def get_cells(self) -> torch.Tensor: - """Get batch crystallographic cells.""" - return self.batch.cell - - def set_cells( - self, cells: torch.Tensor | NDArray, scale_atoms=False - ) -> None: - """Set batch cells.""" - assert self.batch.cell.shape == cells.shape, "Cell shape mismatch" - if isinstance(cells, np.ndarray): - cells = torch.tensor(cells, dtype=self.dtype, device=self.device) - cells = cells.to(dtype=self.dtype, device=self.device) - if scale_atoms: - from ase.geometry.cell import complete_cell - - # M = torch.linalg.solve( - # self.batch.cell.view(-1, 3, 3), - # cells.view(-1, 3, 3), - # ) - # TODO: need to implement a sparse version. - # tmp_pos = torch.matmul(self.batch.pos, M.reshape(-1,3)) - for i in range(self.batch_size): - if not self.update_mask[i]: - continue - M = np.linalg.solve( - complete_cell(self.batch.cell[i].cpu().detach().numpy()), - complete_cell(cells[i].cpu().detach().numpy()), - ) - pos_update_mask = self.batch.batch == i - self.batch.pos[pos_update_mask] = torch.matmul( - self.batch.pos[pos_update_mask], - torch.from_numpy(M).to(self.device).reshape(-1, 3), - ) - self.batch.cell[self.update_mask] = cells[self.update_mask] - - def get_volumes(self) -> torch.Tensor: - """Get a tensor of volumes for each cell in batch""" - cells = self.get_cells() - return torch.linalg.det(cells) - - def iterimages(self) -> Generator[Batch, None, None]: - # XXX document purpose of iterimages - this is just needed to work with ASE optimizers - yield self.batch - - def get_max_forces( - self, forces: torch.Tensor | None = None, apply_constraint: bool = False - ) -> torch.Tensor: - """Get the maximum forces per structure in batch""" - if forces is None: - forces = self.get_forces( - apply_constraint=apply_constraint, no_numpy=True - ) - return scatter( - (forces**2).sum(axis=1).sqrt(), self.batch_indices, reduce="max" - ) - - def converged( - self, - forces: torch.Tensor | NDArray | None, - fmax: float, - max_forces: torch.Tensor | None = None, - f_upper_limit: float = 1e20, - ) -> bool: - """Check if norm of all predicted forces are below fmax""" - if forces is not None: - if isinstance(forces, np.ndarray): - forces = torch.tensor( - forces, device=self.device, dtype=self.dtype - ) - max_forces = self.get_max_forces(forces) - elif max_forces is None: - max_forces = self.get_max_forces() - - # Update mask is True for forces that are greater than fmax AND less than f_upper_limit - update_mask = torch.logical_and( - max_forces.ge(fmax), max_forces.le(f_upper_limit) - ) - # update cached mask - if self.mask_converged: - if self._update_mask is None: - self._update_mask = update_mask - else: - # some models can have random noise in their predictions, so the mask is updated by - # keeping all previously converged structures masked even if new force predictions - # push it slightly above threshold - self._update_mask = torch.logical_and( - self._update_mask, update_mask - ) - update_mask = self._update_mask - - return not torch.any(update_mask).item() - - def get_atoms_list(self) -> list[Atoms]: - """Get ase Atoms objects corresponding to the batch""" - self._predict() # in case no predictions have been run - return batch_to_atoms(self.batch, results=self.torch_results) - - def update_graph(self): - """Update the graph if model does not use otf_graph.""" - graph = self.trainer._unwrapped_model.generate_graph(self.batch) - self.batch.edge_index = graph.edge_index - self.batch.cell_offsets = graph.cell_offsets - self.batch.neighbors = graph.neighbors - if self.transform is not None: - self.batch = self.transform(self.batch) - - def __len__(self) -> int: - # TODO: this might be changed in ASE to be 3 * len(self.atoms) - return len(self.batch.pos) - - -class OptimizableUnitCellBatch(OptimizableBatch): - """Modify the supercell and the atom positions in relaxations. - - Based on ase UnitCellFilter to work on data batches - """ - - def __init__( - self, - batch: Batch, - trainer: Any, # Any calculator type (MACECalculator | CHGNetCalculator | SevenNetD3Calculator | FAIRChemCalculator) - transform: torch.nn.Module | None = None, - numpy: bool = False, - mask_converged: bool = True, - mask: Sequence[bool] | None = None, - cell_factor: float | torch.Tensor | None = None, - hydrostatic_strain: bool = False, - constant_volume: bool = False, - scalar_pressure: float = 0.0, - masked_eps: float = 1e-8, - use_fast_predict: bool = True, - dtype: torch.dtype = torch.float64, - ): - """Create a filter that returns the forces and unit cell stresses together, for simultaneous optimization. - - For full details see: - E. B. Tadmor, G. S. Smith, N. Bernstein, and E. Kaxiras, - Phys. Rev. B 59, 235 (1999) - - Args: - batch: A batch of atoms graph data - model: An instance of a BaseTrainer derived class - transform: graph transform - numpy: whether to cast results to numpy arrays - mask_converged: if true will mask systems in batch that are already converged - mask: a boolean mask specifying which strain components are allowed to relax - cell_factor: - Factor by which deformation gradient is multiplied to put - it on the same scale as the positions when assembling - the combined position/cell vector. The stress contribution to - the forces is scaled down by the same factor. This can be thought - of as a very simple preconditioner. Default is number of atoms - which gives approximately the correct scaling. - hydrostatic_strain: - Constrain the cell by only allowing hydrostatic deformation. - The virial tensor is replaced by np.diag([np.trace(virial)]*3). - constant_volume: - Project out the diagonal elements of the virial tensor to allow - relaxations at constant volume, e.g. for mapping out an - energy-volume curve. Note: this only approximately conserves - the volume and breaks energy/force consistency so can only be - used with optimizers that do require a line minimisation - (e.g. FIRE). - scalar_pressure: - Applied pressure to use for enthalpy pV term. As above, this - breaks energy/force consistency. - masked_eps: masking systems that are converged when using ASE optimizers results in divisions by zero - from zero differences in masked positions at future steps, we add a small number to prevent this. - dtype: data type for tensor operations (torch.float32 or torch.float64) - """ - super().__init__( - batch=batch, - trainer=trainer, - transform=transform, - numpy=numpy, - mask_converged=mask_converged, - masked_eps=masked_eps, - compute_stress=True, - use_fast_predict=use_fast_predict, - dtype=dtype, - ) - - self.orig_cells = self.get_cells().clone() - self.stress = None - - if mask is None: - # mask = torch.eye(3, device=self.device) - mask = torch.ones(6, device=self.device) - - # TODO make sure mask is on GPU - if mask.shape == (6,): - self.mask = torch.tensor( - voigt_6_to_full_3x3_stress(mask.detach().cpu()), - device=self.device, - ) - elif mask.shape == (3, 3): - self.mask = mask - else: - raise ValueError("shape of mask should be (3,3) or (6,)") - - if isinstance(cell_factor, float): - cell_factor = cell_factor * torch.ones( - (3 * len(batch), 1), requires_grad=False - ) - if cell_factor is None: - cell_factor = self.batch.natoms.repeat_interleave(3).unsqueeze( - dim=1 - ) - - self.hydrostatic_strain = hydrostatic_strain - self.constant_volume = constant_volume - self.pressure = scalar_pressure * torch.eye(3, device=self.device) - self.cell_factor = cell_factor - self.stress = None - self._batch_trace = torch.vmap(torch.trace) - self._batch_diag = torch.vmap( - lambda x: x * torch.eye(3, device=x.device) - ) - - @cached_property - def batch_indices(self): - """Get the batch indices specifying which position/force corresponds to which batch. - - We augment this to specify the batch indices for augmented positions and forces. - """ - augmented_batch = torch.repeat_interleave( - torch.arange( - len(self.batch), - dtype=self.batch.batch.dtype, - device=self.device, - ), - 3, - ) - return torch.cat([self.batch.batch, augmented_batch]) - - def deform_grad(self): - """Get the cell deformation matrix""" - return torch.transpose( - torch.linalg.solve(self.orig_cells, self.get_cells()), 1, 2 - ) - - def get_positions(self): - """Get positions and cell deformation gradient.""" - cur_deform_grad = self.deform_grad() - natoms = self.batch.num_nodes - pos = torch.zeros( - (natoms + 3 * len(self.get_cells()), 3), - dtype=self.batch.pos.dtype, - device=self.device, - ) - - # Augmented positions are the self.atoms.positions but without the applied deformation gradient - pos[:natoms] = torch.linalg.solve( - cur_deform_grad[self.batch.batch, :, :], - self.batch.pos.view(-1, 3, 1), - ).view(-1, 3) - # cell DOFs are the deformation gradient times a scaling factor - pos[natoms:] = self.cell_factor * cur_deform_grad.view(-1, 3) - return pos.cpu().numpy() if self.numpy else pos - - def set_positions(self, positions: torch.Tensor | NDArray) -> None: - """Set positions and cell. - - positions has shape (natoms + ncells * 3, 3). - the first natoms rows are the positions of the atoms, the last nsystems * three rows are the deformation tensor - for each cell. - """ - if isinstance(positions, np.ndarray): - positions = torch.tensor( - positions, dtype=self.dtype, device=self.device - ) - else: - positions = positions.to(dtype=self.dtype, device=self.device) - - natoms = self.batch.num_nodes - new_atom_positions = positions[:natoms] - new_deform_grad = (positions[natoms:] / self.cell_factor).view(-1, 3, 3) - - # TODO check that in fact symmetry is preserved setting cells and positions - # Set the new cell from the original cell and the new deformation gradient. Both current and final structures - # should preserve symmetry. - new_cells = torch.bmm( - self.orig_cells, torch.transpose(new_deform_grad, 1, 2) - ) - self.set_cells(new_cells) - - # Set the positions from the ones passed in (which are without the deformation gradient applied) and the new - # deformation gradient. This should also preserve symmetry - new_atom_positions = torch.bmm( - new_atom_positions.view(-1, 1, 3), - torch.transpose( - new_deform_grad[self.batch.batch, :, :].view(-1, 3, 3), 1, 2 - ), - ) - super().set_positions(new_atom_positions.view(-1, 3)) - - def get_potential_energy(self, **kwargs): - """ - returns potential energy including enthalpy PV term. - """ - atoms_energy = super().get_potential_energy(**kwargs) - return atoms_energy + self.pressure[0, 0] * self.get_volumes().sum() - - def get_forces( - self, apply_constraint: bool = False, no_numpy: bool = False - ) -> torch.Tensor | NDArray: - """Get forces and unit cell stress.""" - stress = self.get_property("stress", no_numpy=True).view(-1, 3, 3) - atom_forces = self.get_property("forces", no_numpy=True) - - if apply_constraint: - fixed_idx = torch.where(self.batch.fixed == 1)[0] - atom_forces[fixed_idx] = 0.0 - - volumes = self.get_volumes().view(-1, 1, 1) - # virial = -volumes * stress + self.pressure.view(-1, 3, 3) - virial = -volumes * (stress + self.pressure.view(-1, 3, 3)) - # print(f'&&& virial0: {virial}') - cur_deform_grad = self.deform_grad() - atom_forces = torch.bmm( - atom_forces.view(-1, 1, 3), - cur_deform_grad[self.batch.batch, :, :].view(-1, 3, 3), - ) - virial = torch.linalg.solve( - cur_deform_grad, torch.transpose(virial, dim0=1, dim1=2) - ) - virial = torch.transpose(virial, dim0=1, dim1=2) - - # print(f'&&& virial1: {virial}') - - # TODO this does not work yet! maybe _batch_trace gives an issue - if self.hydrostatic_strain: - virial = self._batch_diag(self._batch_trace(virial) / 3.0) - - # Zero out components corresponding to fixed lattice elements - if (self.mask != 1.0).any(): - virial *= self.mask.view(-1, 3, 3) - - if self.constant_volume: - virial[:, range(3), range(3)] -= ( - self._batch_trace(virial).view(3, -1) / 3.0 - ) - - natoms = self.batch.num_nodes - augmented_forces = torch.zeros( - (natoms + 3 * len(self.get_cells()), 3), - device=self.device, - dtype=atom_forces.dtype, - ) - # print(f'&&& atom_forces: {atom_forces}') - # print(f'&&& virial2: {virial}') - augmented_forces[:natoms] = atom_forces.view(-1, 3) - augmented_forces[natoms:] = virial.view(-1, 3) / self.cell_factor - - self.stress = -virial.view(-1, 9) / volumes.view(-1, 1) - - if self.numpy and not no_numpy: - augmented_forces = augmented_forces.cpu().numpy() - - # print(f'&&& augmented_forces: {augmented_forces}') - - return augmented_forces - - def __len__(self): - return len(self.batch.pos) + 3 * len(self.batch) - - def get_potential_energies(self) -> torch.Tensor: - """Get the predicted energy for each system in batch.""" - return ( - self.get_property("energy").view(-1) - + self.pressure[0, 0] * self.get_volumes() - ) +""" +Copyright (c) Meta, Inc. and its affiliates. +Copyright (c) 2025 Ma Zhaojia + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. + +Modified from original Meta implementation. +""" + +from __future__ import annotations + +from functools import cached_property +from types import SimpleNamespace +from typing import TYPE_CHECKING, ClassVar, Any, Generator +import numpy as np +import torch +import logging +from ase.calculators.calculator import PropertyNotImplementedError +from ase.stress import voigt_6_to_full_3x3_stress +from torch_scatter import scatter + +from batchopt.relaxation.ase_utils import batch_to_atoms + + +# Define dummy classes for when imports fail +class _DummyCalculator: + pass + + +try: + from mace.calculators import MACECalculator +except ImportError: + logging.warning("Unable to import MACECalculator.") + MACECalculator = _DummyCalculator + +try: + from chgnet.model.dynamics import CHGNetCalculator +except ImportError: + logging.warning("Unable to import CHGNetCalculator.") + CHGNetCalculator = _DummyCalculator + +try: + from sevenn.calculator import ( + SevenNetCalculator, + SevenNetD3Calculator, + D3Calculator, + ) +except ImportError: + logging.warning("Unable to import SevenNetCalculator.") + SevenNetCalculator = _DummyCalculator + SevenNetD3Calculator = _DummyCalculator + D3Calculator = _DummyCalculator + +try: + from fairchem.core import pretrained_mlip, FAIRChemCalculator +except ImportError: + logging.warning("Unable to import FAIRChemCalculator.") + FAIRChemCalculator = _DummyCalculator + + +# this can be removed after pinning ASE dependency >= 3.23 +try: + from ase.optimize.optimize import Optimizable +except ImportError: + + class Optimizable: + pass + + +if TYPE_CHECKING: + from collections.abc import Sequence + + from ase import Atoms + from numpy.typing import NDArray + from torch_geometric.data import Batch + + +ALL_CHANGES: set[str] = { + "pos", + "atomic_numbers", + "cell", + "pbc", +} + + +# @torch.compile +def compare_batches( + batch1: Batch | None, + batch2: Batch, + tol: float = 1e-6, + excluded_properties: set[str] | None = None, +) -> list[str]: + """Compare properties between two batches + + Args: + batch1: atoms batch + batch2: atoms batch + tol: tolerance used to compare equility of floating point properties + excluded_properties: list of properties to exclude from comparison + + Returns: + list of system changes, property names that are differente between batch1 and batch2 + """ + system_changes = [] + + if batch1 is None: + system_changes = ALL_CHANGES + else: + properties_to_check = set(ALL_CHANGES) + if excluded_properties: + properties_to_check -= set(excluded_properties) + + # Check properties that aren't + for prop in ALL_CHANGES: + if prop in properties_to_check: + properties_to_check.remove(prop) + if not torch.allclose( + getattr(batch1, prop), getattr(batch2, prop), atol=tol + ): + system_changes.append(prop) + + return system_changes + + +class OptimizableBatch(Optimizable): + """A Batch version of ase Optimizable Atoms + + This class can be used with ML relaxations in fairchem.core.relaxations.ml_relaxation + or in ase relaxations classes, i.e. ase.optimize.lbfgs + """ + + ignored_changes: ClassVar[set[str]] = set() + + def __init__( + self, + batch: Batch, + trainer: Any, # Any calculator type (MACECalculator | CHGNetCalculator | SevenNetCalculator | FAIRChemCalculator) + transform: torch.nn.Module | None = None, + mask_converged: bool = True, + numpy: bool = False, + masked_eps: float = 1e-8, + compute_stress: bool = False, + use_fast_predict: bool = True, + dtype: torch.dtype = torch.float64, + ): + """Initialize Optimizable Batch + + Args: + batch: A batch of atoms graph data + model: An instance of a BaseTrainer derived class + transform: graph transform + mask_converged: if true will mask systems in batch that are already converged + numpy: whether to cast results to numpy arrays + masked_eps: masking systems that are converged when using ASE optimizers results in divisions by zero + from zero differences in masked positions at future steps, we add a small number to prevent this. + compute_stress: whether to compute stress during prediction + use_fast_predict: use fast prediction method when available + dtype: data type for tensor operations (torch.float32 or torch.float64) + """ + self.batch = batch.to(trainer.device) + self.trainer = trainer + self.transform = transform + self.numpy = numpy + self.mask_converged = mask_converged + self._cached_batch = None + self._update_mask = None + self.torch_results = {} + self.results = {} + self._eps = masked_eps + self.dtype = dtype + + self.otf_graph = True # trainer._unwrapped_model.otf_graph + if not self.otf_graph and "edge_index" not in self.batch: + self.update_graph() + + self.batch.pos = self.batch.pos.to(dtype=self.dtype) + self.batch.cell = self.batch.cell.to(dtype=self.dtype) + self.compute_stress = compute_stress + self.use_fast_predict = use_fast_predict + + # Determine calculator type once during initialization for efficiency + self._calculator_type = self._determine_calculator_type() + logging.info( + f"OptimizableBatch initialized with calculator type: {self._calculator_type}" + ) + + def _determine_calculator_type(self) -> str: + """Determine the type of calculator to avoid repeated isinstance checks.""" + # Check against actual imported classes, not dummy classes + trainer_class_name = type(self.trainer).__name__ + trainer_module = type(self.trainer).__module__ + + if ( + "mace" in trainer_module.lower() + or trainer_class_name == "MACECalculator" + ): + return "mace" + elif ( + "chgnet" in trainer_module.lower() + or trainer_class_name == "CHGNetCalculator" + ): + return "chgnet" + elif "sevenn" in trainer_module.lower() or trainer_class_name in [ + "SevenNetCalculator", + "SevenNetD3Calculator", + "D3Calculator", + ]: + return "sevennet" + elif ( + "fairchem" in trainer_module.lower() + or trainer_class_name == "FAIRChemCalculator" + ): + return "fairchem" + else: + return "default" + + @property + def device(self): + return self.trainer.device + + @property + def batch_indices(self): + """Get the batch indices specifying which position/force corresponds to which batch.""" + return self.batch.batch + + @property + def converged_mask(self): + if self._update_mask is not None: + return torch.logical_not(self._update_mask) + return None + + @property + def update_mask(self): + if self._update_mask is None: + return torch.ones(len(self.batch), dtype=bool) + return self._update_mask + + @property + def converge_indices_list(self): + return torch.where(~self.update_mask)[0].tolist() + + @property + def elem_per_group(self): + # This return value actually represents the number of elements + # in a group within a batch. Each group corresponds to batch_indices. + # It will count the number of CELL elements in each group. + return torch.bincount(self.batch_indices) + + @property + def batch_size(self): + return len(torch.unique(self.batch_indices)) + + def check_state(self, batch: Batch, tol: float = 1e-12) -> bool: + """Check for any system changes since last calculation.""" + return compare_batches( + self._cached_batch, + batch, + tol=tol, + excluded_properties=set(self.ignored_changes), + ) + + def _predict(self) -> None: + """Run prediction if batch has any changes.""" + # TODO: Currently, the batch inference interfaces of various models are not unified and are poorly implemented. + system_changes = self.check_state(self.batch) + if len(system_changes) > 0: + if self._calculator_type == "mace": + # FIXME: &&& + # for key, val in self.batch.to_dict().items(): + # print(f'&&& key: {key}, val: {val}') + # self.torch_results = self.trainer.predict_debug(atoms_list, self.batch, compute_stress=self.compute_stress) + # self.torch_results = self.trainer.predict(self.config_batch) + if self.use_fast_predict: + self.torch_results = self.trainer.fast_predict( + self.batch, compute_stress=self.compute_stress + ) + self.batch.pos = self.batch.pos.to(self.dtype) + self.batch.cell = self.batch.cell.to(self.dtype) + else: + atoms_list = batch_to_atoms( + self.batch, results=None, wrap_pos=False, eps=1e-17 + ) + self.torch_results = self.trainer.predict( + atoms_list, compute_stress=self.compute_stress + ) + elif self._calculator_type == "fairchem": + # TODO: FAIRChemCalculator does not support batch prediction yet + atoms_list = batch_to_atoms( + self.batch, results=None, wrap_pos=False, eps=1e-17 + ) + self.torch_results = self.trainer.predict(atoms_list=atoms_list) + elif self._calculator_type == "chgnet": + atoms_list = batch_to_atoms( + self.batch, results=None, wrap_pos=False, eps=1e-17 + ) + model_prediction = self.trainer.predict( + atoms_list=atoms_list, task="efs" + ) + results = { + "energy": torch.tensor( + [pred["e"].item() for pred in model_prediction], + device=self.device, + dtype=self.dtype, + ), + "forces": torch.vstack( + [ + torch.from_numpy(pred["f"]).to( + device=self.device, dtype=self.dtype + ) + for pred in model_prediction + ] + ), + "stress": torch.vstack( + [ + torch.from_numpy(pred["s"]).to( + device=self.device, dtype=self.dtype + ) + for pred in model_prediction + ] + ).view(-1, 3, 3), + } + self.torch_results = results + elif self._calculator_type == "sevennet": + atoms_list = batch_to_atoms( + self.batch, results=None, wrap_pos=False, eps=1e-17 + ) + self.torch_results = self.trainer.predict(atoms_list=atoms_list) + else: # default case + self.torch_results = self.trainer.predict( + self.batch, per_image=False, disable_tqdm=True + ) + # save only subset of props in simple namespace instead of cloning the whole batch to save memory + changes = ALL_CHANGES - set(self.ignored_changes) + self._cached_batch = SimpleNamespace( + **{prop: self.batch[prop].clone() for prop in changes} + ) + + def get_property( + self, name, no_numpy: bool = False + ) -> torch.Tensor | NDArray: + """Get a predicted property by name.""" + self._predict() + if self.numpy: + self.results = { + key: pred.item() if pred.numel() == 1 else pred.cpu().numpy() + for key, pred in self.torch_results.items() + } + else: + self.results = self.torch_results + + if name not in self.results: + raise PropertyNotImplementedError( + f"{name} not present in this calculation" + ) + + return ( + self.results[name] + if no_numpy is False + else self.torch_results[name] + ) + + def get_positions(self) -> torch.Tensor | NDArray: + """Get the batch positions""" + pos = self.batch.pos.clone() + if self.numpy: + if self.mask_converged: + pos[~self.update_mask[self.batch.batch]] = self._eps + pos = pos.cpu().numpy() + + return pos + + def set_positions(self, positions: torch.Tensor | NDArray) -> None: + """Set the atom positions in the batch.""" + if isinstance(positions, np.ndarray): + positions = torch.tensor( + positions, dtype=self.dtype, device=self.device + ) + else: + positions = positions.to(dtype=self.dtype, device=self.device) + + if self.mask_converged and self._update_mask is not None: + mask = self.update_mask[self.batch.batch] + self.batch.pos[mask] = positions[mask] + else: + self.batch.pos = positions + + if not self.otf_graph: + self.update_graph() + + def get_forces( + self, apply_constraint: bool = False, no_numpy: bool = False + ) -> torch.Tensor | NDArray: + """Get predicted batch forces.""" + forces = self.get_property("forces", no_numpy=no_numpy) + if apply_constraint: + fixed_idx = torch.where(self.batch.fixed == 1)[0] + if isinstance(forces, np.ndarray): + fixed_idx = fixed_idx.tolist() + forces[fixed_idx] = 0.0 + return forces.view(-1, 3) + + def get_potential_energy(self, **kwargs) -> torch.Tensor | NDArray: + """Get predicted energy as the sum of all batch energies.""" + # ASE 3.22.1 expects a check for force_consistent calculations + if kwargs.get("force_consistent", False) is True: + raise PropertyNotImplementedError( + "force_consistent calculations are not implemented" + ) + if ( + len(self.batch) == 1 + ): # unfortunately batch size 1 returns a float, not a tensor + return self.get_property("energy") + return self.get_property("energy").sum() + + def get_potential_energies(self) -> torch.Tensor | NDArray: + """Get the predicted energy for each system in batch.""" + return self.get_property("energy") + + def get_cells(self) -> torch.Tensor: + """Get batch crystallographic cells.""" + return self.batch.cell + + def set_cells( + self, cells: torch.Tensor | NDArray, scale_atoms=False + ) -> None: + """Set batch cells.""" + assert self.batch.cell.shape == cells.shape, "Cell shape mismatch" + if isinstance(cells, np.ndarray): + cells = torch.tensor(cells, dtype=self.dtype, device=self.device) + cells = cells.to(dtype=self.dtype, device=self.device) + if scale_atoms: + from ase.geometry.cell import complete_cell + + # M = torch.linalg.solve( + # self.batch.cell.view(-1, 3, 3), + # cells.view(-1, 3, 3), + # ) + # TODO: need to implement a sparse version. + # tmp_pos = torch.matmul(self.batch.pos, M.reshape(-1,3)) + for i in range(self.batch_size): + if not self.update_mask[i]: + continue + M = np.linalg.solve( + complete_cell(self.batch.cell[i].cpu().detach().numpy()), + complete_cell(cells[i].cpu().detach().numpy()), + ) + pos_update_mask = self.batch.batch == i + self.batch.pos[pos_update_mask] = torch.matmul( + self.batch.pos[pos_update_mask], + torch.from_numpy(M).to(self.device).reshape(-1, 3), + ) + self.batch.cell[self.update_mask] = cells[self.update_mask] + + def get_volumes(self) -> torch.Tensor: + """Get a tensor of volumes for each cell in batch""" + cells = self.get_cells() + return torch.linalg.det(cells) + + def iterimages(self) -> Generator[Batch, None, None]: + # XXX document purpose of iterimages - this is just needed to work with ASE optimizers + yield self.batch + + def get_max_forces( + self, forces: torch.Tensor | None = None, apply_constraint: bool = False + ) -> torch.Tensor: + """Get the maximum forces per structure in batch""" + if forces is None: + forces = self.get_forces( + apply_constraint=apply_constraint, no_numpy=True + ) + return scatter( + (forces**2).sum(axis=1).sqrt(), self.batch_indices, reduce="max" + ) + + def converged( + self, + forces: torch.Tensor | NDArray | None, + fmax: float, + max_forces: torch.Tensor | None = None, + f_upper_limit: float = 1e20, + ) -> bool: + """Check if norm of all predicted forces are below fmax""" + if forces is not None: + if isinstance(forces, np.ndarray): + forces = torch.tensor( + forces, device=self.device, dtype=self.dtype + ) + max_forces = self.get_max_forces(forces) + elif max_forces is None: + max_forces = self.get_max_forces() + + # Update mask is True for forces that are greater than fmax AND less than f_upper_limit + update_mask = torch.logical_and( + max_forces.ge(fmax), max_forces.le(f_upper_limit) + ) + # update cached mask + if self.mask_converged: + if self._update_mask is None: + self._update_mask = update_mask + else: + # some models can have random noise in their predictions, so the mask is updated by + # keeping all previously converged structures masked even if new force predictions + # push it slightly above threshold + self._update_mask = torch.logical_and( + self._update_mask, update_mask + ) + update_mask = self._update_mask + + return not torch.any(update_mask).item() + + def get_atoms_list(self) -> list[Atoms]: + """Get ase Atoms objects corresponding to the batch""" + self._predict() # in case no predictions have been run + return batch_to_atoms(self.batch, results=self.torch_results) + + def update_graph(self): + """Update the graph if model does not use otf_graph.""" + graph = self.trainer._unwrapped_model.generate_graph(self.batch) + self.batch.edge_index = graph.edge_index + self.batch.cell_offsets = graph.cell_offsets + self.batch.neighbors = graph.neighbors + if self.transform is not None: + self.batch = self.transform(self.batch) + + def __len__(self) -> int: + # TODO: this might be changed in ASE to be 3 * len(self.atoms) + return len(self.batch.pos) + + +class OptimizableUnitCellBatch(OptimizableBatch): + """Modify the supercell and the atom positions in relaxations. + + Based on ase UnitCellFilter to work on data batches + """ + + def __init__( + self, + batch: Batch, + trainer: Any, # Any calculator type (MACECalculator | CHGNetCalculator | SevenNetD3Calculator | FAIRChemCalculator) + transform: torch.nn.Module | None = None, + numpy: bool = False, + mask_converged: bool = True, + mask: Sequence[bool] | None = None, + cell_factor: float | torch.Tensor | None = None, + hydrostatic_strain: bool = False, + constant_volume: bool = False, + scalar_pressure: float = 0.0, + masked_eps: float = 1e-8, + use_fast_predict: bool = True, + dtype: torch.dtype = torch.float64, + ): + """Create a filter that returns the forces and unit cell stresses together, for simultaneous optimization. + + For full details see: + E. B. Tadmor, G. S. Smith, N. Bernstein, and E. Kaxiras, + Phys. Rev. B 59, 235 (1999) + + Args: + batch: A batch of atoms graph data + model: An instance of a BaseTrainer derived class + transform: graph transform + numpy: whether to cast results to numpy arrays + mask_converged: if true will mask systems in batch that are already converged + mask: a boolean mask specifying which strain components are allowed to relax + cell_factor: + Factor by which deformation gradient is multiplied to put + it on the same scale as the positions when assembling + the combined position/cell vector. The stress contribution to + the forces is scaled down by the same factor. This can be thought + of as a very simple preconditioner. Default is number of atoms + which gives approximately the correct scaling. + hydrostatic_strain: + Constrain the cell by only allowing hydrostatic deformation. + The virial tensor is replaced by np.diag([np.trace(virial)]*3). + constant_volume: + Project out the diagonal elements of the virial tensor to allow + relaxations at constant volume, e.g. for mapping out an + energy-volume curve. Note: this only approximately conserves + the volume and breaks energy/force consistency so can only be + used with optimizers that do require a line minimisation + (e.g. FIRE). + scalar_pressure: + Applied pressure to use for enthalpy pV term. As above, this + breaks energy/force consistency. + masked_eps: masking systems that are converged when using ASE optimizers results in divisions by zero + from zero differences in masked positions at future steps, we add a small number to prevent this. + dtype: data type for tensor operations (torch.float32 or torch.float64) + """ + super().__init__( + batch=batch, + trainer=trainer, + transform=transform, + numpy=numpy, + mask_converged=mask_converged, + masked_eps=masked_eps, + compute_stress=True, + use_fast_predict=use_fast_predict, + dtype=dtype, + ) + + self.orig_cells = self.get_cells().clone() + self.stress = None + + if mask is None: + # mask = torch.eye(3, device=self.device) + mask = torch.ones(6, device=self.device) + + # TODO make sure mask is on GPU + if mask.shape == (6,): + self.mask = torch.tensor( + voigt_6_to_full_3x3_stress(mask.detach().cpu()), + device=self.device, + ) + elif mask.shape == (3, 3): + self.mask = mask + else: + raise ValueError("shape of mask should be (3,3) or (6,)") + + if isinstance(cell_factor, float): + cell_factor = cell_factor * torch.ones( + (3 * len(batch), 1), requires_grad=False + ) + if cell_factor is None: + cell_factor = self.batch.natoms.repeat_interleave(3).unsqueeze( + dim=1 + ) + + self.hydrostatic_strain = hydrostatic_strain + self.constant_volume = constant_volume + self.pressure = scalar_pressure * torch.eye(3, device=self.device) + self.cell_factor = cell_factor + self.stress = None + self._batch_trace = torch.vmap(torch.trace) + self._batch_diag = torch.vmap( + lambda x: x * torch.eye(3, device=x.device) + ) + + @cached_property + def batch_indices(self): + """Get the batch indices specifying which position/force corresponds to which batch. + + We augment this to specify the batch indices for augmented positions and forces. + """ + augmented_batch = torch.repeat_interleave( + torch.arange( + len(self.batch), + dtype=self.batch.batch.dtype, + device=self.device, + ), + 3, + ) + return torch.cat([self.batch.batch, augmented_batch]) + + def deform_grad(self): + """Get the cell deformation matrix""" + return torch.transpose( + torch.linalg.solve(self.orig_cells, self.get_cells()), 1, 2 + ) + + def get_positions(self): + """Get positions and cell deformation gradient.""" + cur_deform_grad = self.deform_grad() + natoms = self.batch.num_nodes + pos = torch.zeros( + (natoms + 3 * len(self.get_cells()), 3), + dtype=self.batch.pos.dtype, + device=self.device, + ) + + # Augmented positions are the self.atoms.positions but without the applied deformation gradient + pos[:natoms] = torch.linalg.solve( + cur_deform_grad[self.batch.batch, :, :], + self.batch.pos.view(-1, 3, 1), + ).view(-1, 3) + # cell DOFs are the deformation gradient times a scaling factor + pos[natoms:] = self.cell_factor * cur_deform_grad.view(-1, 3) + return pos.cpu().numpy() if self.numpy else pos + + def set_positions(self, positions: torch.Tensor | NDArray) -> None: + """Set positions and cell. + + positions has shape (natoms + ncells * 3, 3). + the first natoms rows are the positions of the atoms, the last nsystems * three rows are the deformation tensor + for each cell. + """ + if isinstance(positions, np.ndarray): + positions = torch.tensor( + positions, dtype=self.dtype, device=self.device + ) + else: + positions = positions.to(dtype=self.dtype, device=self.device) + + natoms = self.batch.num_nodes + new_atom_positions = positions[:natoms] + new_deform_grad = (positions[natoms:] / self.cell_factor).view(-1, 3, 3) + + # TODO check that in fact symmetry is preserved setting cells and positions + # Set the new cell from the original cell and the new deformation gradient. Both current and final structures + # should preserve symmetry. + new_cells = torch.bmm( + self.orig_cells, torch.transpose(new_deform_grad, 1, 2) + ) + self.set_cells(new_cells) + + # Set the positions from the ones passed in (which are without the deformation gradient applied) and the new + # deformation gradient. This should also preserve symmetry + new_atom_positions = torch.bmm( + new_atom_positions.view(-1, 1, 3), + torch.transpose( + new_deform_grad[self.batch.batch, :, :].view(-1, 3, 3), 1, 2 + ), + ) + super().set_positions(new_atom_positions.view(-1, 3)) + + def get_potential_energy(self, **kwargs): + """ + returns potential energy including enthalpy PV term. + """ + atoms_energy = super().get_potential_energy(**kwargs) + return atoms_energy + self.pressure[0, 0] * self.get_volumes().sum() + + def get_forces( + self, apply_constraint: bool = False, no_numpy: bool = False + ) -> torch.Tensor | NDArray: + """Get forces and unit cell stress.""" + stress = self.get_property("stress", no_numpy=True).view(-1, 3, 3) + atom_forces = self.get_property("forces", no_numpy=True) + + if apply_constraint: + fixed_idx = torch.where(self.batch.fixed == 1)[0] + atom_forces[fixed_idx] = 0.0 + + volumes = self.get_volumes().view(-1, 1, 1) + # virial = -volumes * stress + self.pressure.view(-1, 3, 3) + virial = -volumes * (stress + self.pressure.view(-1, 3, 3)) + # print(f'&&& virial0: {virial}') + cur_deform_grad = self.deform_grad() + atom_forces = torch.bmm( + atom_forces.view(-1, 1, 3), + cur_deform_grad[self.batch.batch, :, :].view(-1, 3, 3), + ) + virial = torch.linalg.solve( + cur_deform_grad, torch.transpose(virial, dim0=1, dim1=2) + ) + virial = torch.transpose(virial, dim0=1, dim1=2) + + # print(f'&&& virial1: {virial}') + + # TODO this does not work yet! maybe _batch_trace gives an issue + if self.hydrostatic_strain: + virial = self._batch_diag(self._batch_trace(virial) / 3.0) + + # Zero out components corresponding to fixed lattice elements + if (self.mask != 1.0).any(): + virial *= self.mask.view(-1, 3, 3) + + if self.constant_volume: + virial[:, range(3), range(3)] -= ( + self._batch_trace(virial).view(3, -1) / 3.0 + ) + + natoms = self.batch.num_nodes + augmented_forces = torch.zeros( + (natoms + 3 * len(self.get_cells()), 3), + device=self.device, + dtype=atom_forces.dtype, + ) + # print(f'&&& atom_forces: {atom_forces}') + # print(f'&&& virial2: {virial}') + augmented_forces[:natoms] = atom_forces.view(-1, 3) + augmented_forces[natoms:] = virial.view(-1, 3) / self.cell_factor + + self.stress = -virial.view(-1, 9) / volumes.view(-1, 1) + + if self.numpy and not no_numpy: + augmented_forces = augmented_forces.cpu().numpy() + + # print(f'&&& augmented_forces: {augmented_forces}') + + return augmented_forces + + def __len__(self): + return len(self.batch.pos) + 3 * len(self.batch) + + def get_potential_energies(self) -> torch.Tensor: + """Get the predicted energy for each system in batch.""" + return ( + self.get_property("energy").view(-1) + + self.pressure[0, 0] * self.get_volumes() + ) diff --git a/mace-bench/src/batchopt/relaxation/optimizers/__init__.py b/mace-bench/src/batchopt/relaxation/optimizers/__init__.py index 47cba81..71d1fff 100644 --- a/mace-bench/src/batchopt/relaxation/optimizers/__init__.py +++ b/mace-bench/src/batchopt/relaxation/optimizers/__init__.py @@ -1,13 +1,13 @@ -""" -Copyright (c) Meta, Inc. and its affiliates. - -This source code is licensed under the MIT license found in the -LICENSE file in the root directory of this source tree. -""" - -from __future__ import annotations - -from .bfgs_torch import BFGS -from .bfgsfusedls import BFGSFusedLS - +""" +Copyright (c) Meta, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +from __future__ import annotations + +from .bfgs_torch import BFGS +from .bfgsfusedls import BFGSFusedLS + __all__ = ["BFGS", "BFGSFusedLS"] \ No newline at end of file diff --git a/mace-bench/src/batchopt/relaxation/optimizers/__pycache__/__init__.cpython-310.pyc b/mace-bench/src/batchopt/relaxation/optimizers/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 55f3f1d40668d6ad2146482b50fb484eccfadd16..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 540 zcmYjPy>8nu5T+FQCyMh9H`YL7Eg3o#MNuOu3?xATWzs^>BxN$7cmhcQwO^?4kTEag zwNqZ9Q%My?a>Mby?|66o4#dmLl09aKCx{PeY552W?ZLyxXjA9_ky2 zEo1G-fkO-+w39rb+XJ@o=G&Cuq+FG+N%%55T`+4gu%OZU*SM6#72j^Z*D06I+abQ> zu69Mtxk+>n%lT>`XHUJDAiU6N>}~

V87>xXqWcu_K3bL3j@$;B5RbICPTea%6|Lk@{@iIRBp{XKXr zmZX6~Prs)7@4x^5`@e2<3k6rh=jXq?zneX$Y5zurqaOo>%bLcbKR||SHH|Z_w^>K8 z>5TG5+o&1pY1T|U&9>FCYc|>}ZnrZXr{*y2O^s)`b64ZefnLjUV_U0c2S0VMbbI@v z+1QOM=jzqU3+G>Wx$-T)Qrq>rt)}m~*LRyyCF=G?J*d>XJgA`9Zq|b^3V5X-@<3GL z-JtTVuV0^5S9ZFnuQWqyaKHKWD_6I+uU2-NZM>`MDx%wsE4(R!dfXNJmF^CD&S!~5 z5Nx;)X#G_tbAA|hV?SYKt7i@VG=GNYc>b;h^%Z#Wt_Jlj@e-aEf0nQC6L(q7=4E~oH5tCj zPvPnC(|irjEI-4~;_32r{sla9{2c!hf8wrL%Om$Be+sz*Uxu2i(tP9Eo7~HDw z&%ghv67S9*UcM%myL0iq8zjH5&?++>|8^7yJ!$Qllo2JiYcqUx7eAXhcn%7*p&Qelb%tKk59Mh$;Iu5-LFRkf4kL%X>wHMMXJ69} z*oU+qS_4ni7-bD*Gqzebm1!qN4mZZsnnkV4&3guVv@^kfmH<2 zdGQ1_sAZ@etwofPJcnOImhbAe?);`?J;B!5VCmRG>ZHu+hfJI7LwV}IAk)x>vp{n%;lzJ$+~oGf5`VQ@gpT51j-)jqT0O zYO;wja#5E{GHcqH!Pm7`K5^E_=d%~7R8S?0N_&;K7%wHpfs(3KI?ypsW@ytGT)n+_ zU7<=&Ax1v2fExv6hvt@e3ae6kB+ejLT~-?)s#L42hO{Prmy)@{vq9MJ1j3I4xh#TS z;Ky)70DU{~Z(+mI4Q+-TcmiPYd%Ym!0>M%fWJu!6NTenFun|b39faapR8(DMe{u=? z8~cJrOy6Jbws{J-)2f>FmKO#e)EwG(mCH@`o(hsRrP1WKYfi_HZhClInv_dTVs756 z(}{LOx8o^rBW!yavPA1hOjcoKRx*lgg_T$lzf~izPz9KBR$0xXN7$;{czyaIgCoyT zCUggY6Kf+tjM`KJbzQ-}kr?6xfZfCvSz}@fFcWMuV1^rj6gJGL_bl}uwHP-w@u#Xs z1_%N3Niy8#nR{kLwT}2k0&~1+;z^@qhfeG!&HeN^;yN|U`bQ&>yk!q;Sv!CcZ zJIZmF=k97rW|SXm1cGCBYGR0^f(kwhvSF_LS2aldTO|`09-5`_=vm@oYiU%P&9gjO zAt~O^!+EkpH-@ILrgD-^+yjGYd!Nv}-buCYf1Ip0m*lodFIH==pW;1YQcGKdGgrx6 zD-=wgjVc`l0OATE4tz6%<#2ePmY_#fC8UfHJGP1Yx(2 z+kN(8BM5`rJ#ooLhtG&|>33GOsY&<4BTzo=`GA!2Si!_)w2Id$QSSA|EI0Eu@io*n zss1XGsv#~?o>V1%k1AGWrrm8cpbu#_LnuyZ5gvToBjt!!sVCu)EK)c)qTMVu=vE*a z0ax@xo}8AaXaTij17D+ai7DY;jla~OGj;mnTL*9 zWDYC+2ep~YRO8xZV}IuE`*qmBjA_~ zp^h{Nkjh3VQ4na6I?S=j6w5#!;6KW60QAr#{uw){%N~FX4h;$lcQrsd^~w>mIAB8y zM@ZBrmMYmXMS5GYf!Cn=3_?JmRU9oP&Y1n^9Me!Ljh1Q5lx3vgnz)d~Ofon1HH~05 zaZBhqI+0|@?5wIOY1qEZMeUDpyfyJ92o9YKTD128LJ58Um|*Fmnyn0ubdrEPUY*LX zc+d%OASFA^9sLwDoSM%IMx|!h`)*?!;@#MguX1;Hv4)hMXNqZXydA~D@ z8Q!CW`)l<4jHX!C1kHjFY4B#JdRDPRY2Iprd)br4&3Il)BG040*rbGTIHiIc$jMyV zAAla+l*XW$;;~4j0&GJnX}6o9-$q~vaP=WsEDirgB#ZUF5J4Ej-&L<%w2vdC=!Q5) zh!G8`Oso0Jj5heibHU4x>&3M=g&7O?_HI08kS} zYAeMa4Q}Ct+wdiOSJY{+S193>EaJ8BQ1^ZOBAU(xMwb;319h{m|loc>-f_mZj6|VEG}dSU?;ip0o1`DKS>J3N3)N0&4@1ODKc#8c+l$ zgcNgtFvweDEe7JC2f!I9W5GMD^i1-qMwA1-VQxtS(HEPK`wiZ&2jUv00Y;N=%r{%Q zVB%hktEa9|7X;G96%w|W4vA^ZHGs}C+YcjbPcXoVHQg4aNm631(C*@d5YY}J(59nU zxKe6AL2}J&wr$pf1~PGAPuBXQ%Hi5 zM=K;P>QsulrF1oxhPq;sr%}&p$x>2CbEmXda0x@RPtT_F`SqmUX{Dc3{Tle?c~CN*V}2+t!F58-l#c#yREJ^Qq&-XkPL`%1DA3%NlIv> zlNNa6IC6w|;n(<0kb{_-qcFiGHV3|^;|aWAr}aD-Ak~~xb$O&zU)Ia0G5>3jQ(ln5 z?Sl0abDgmi7XVK?^?2)#aqj`!V&Yyw}wErzQ&%d-fpxnVxmTlrBT$%C_;RQDi=?XV#WHz+ycDDaCc z`NczQBv@c0SJXH~u#qKotaKHx$S3)b9Jm`wL))3Ki~ViBWw987?BQEWlz^J&tc4xYAejzW;wJG zV}lNqCzVqFf+sD*moUp&ZR$x6>EquQ^pVQ)czyh@@X>`RddX!TLKoL+y zYH^ElA5fw|;x^^T*9C>;>dJ5Tbnye~`6rb8ITHN4&`1xJ2~`oZODFju4{)dhEkH|33)T_j z^t2UDzW#qiAgG6VI4D0##o`KP`aGXrpI=wOS+e~F20>wb#vlkG!60xtDI*T1$30jd zq8@p#x=ewEJcVB3CMEln+(A+`Q$C<7mZ`oF6vmP$BVy?H>cu5Y>=K2{&^eYI6z@fkr`##56e9hIgKuV^{}jh?CtDazVc(_ zyTEY*$k6-5Fs@ad^gk0)ar(zej&gZQRIWg|B9f|cZEFj9q%9L4QKB?PxoVOA_@GMy z&y%j_b-KLYrhMM>_WFK1Z4qnqQkjxhDS3~Qw<#eyD!xq#g%|Ocl>CGeXp z@|vSNMor%^@`lc|Su^vd{91Vnzg8`@5zdG4EaZf1kqtXV_6mqkUxvVXu+dmgO-DqA_)yk@yJ5*Z8-F@utdvcEybNSU` zW35_@#V)Q^Te((aQ@c-a_*6{7kkCIl?KXl)jA)< z&YeBIxU{^OTdCI2R?2d!(QtE4Rh3I_LtV`^R#0;=imS@yBe7Sw_Y1loFV^b~x9C5j@X4+H|WM)pqext^8ooEv+tu{76$ZIG&%z5Dn!NRyOM;_TXpE z2T#-~z9$%lE~;@=u$e?<-Ac?O{FnSh380h(S2(DK57A%}IsUIZ{AY1^}rj(SnR zm=^<#d-3LBK`cS7PqL4MmjFyU8emFbT3|-tn4?wD@0zxjonVb|Cp?m$a3Ul5Nq0(W zPdgZUE$KwBbG9?m*Q~%@js|p#Eh!!AdgT~_q->;~w|wjAV%7D-PZ!n3X44NBdC|&# zSe2XQqI>3|8e=uX4?E?jKt@y!*@ZbjK6K1#9A#>fV2WUxU>r&>TM*bkQBcw^3V4{IE>2<=f>tCQ&W7 z7LvZzD%Vz2mQ%dlN@az&UDWfMx(}^?=p0>bY?O~~6x#u{Mygc43#hvW%6hbU)m?4W z3#XUQ2j!Q__0sCmmMR^UsbRy9s&cJ(MQ-!a?*5lm>nQK?3du;Vh0E;Q1#Df_Efl!N zEuMEc9oBF4H^EXPZ1Gt}{-%(R>FtRj%5-dMeckB9d7ajb zHM?Ju@XRfiq28>vt?QZ>mVCUQk4Qez&)bqu_Ve0W%8SoxxYCIexMXW-$%UA^ywA1{lE!F?G2vP^E#x9Dgd!m%BDnnH zay=(7{^Z&e2@Y53gs)RR;gyXny>bJe=XJrZSGfIWE-o14k}bEW+};iG$40QZP%l5d zFzttg*7>n^Sv7cx*30TJdhnB#vRi02T2*1uegvJZ7QrOsjz-J%%?&b)vOgA301N`| zT`L!HDc{WY;bOB{t~-i|rQSer55XG=juYGs;D=PPUMc%#tz1{HXX(8JKSuBtKA-BX zfvl|{M4vQ&A*oKV4(PB3R#J0Bc*K+3YCq(I6t!{mS^ zzFzgss&mD+H;S$G0!H#JrsSdq;C@bI3E&LAEzaN4=W$Cc`OBCoJ=?uC89jy6lo`X_ ziRlU6sge6Oqqp}B->6%!y+o?q`=pTIR~XVKAv*y(v<`6u^2(;B#yow?z%2u@Ujqh$ zu{zxuy4Yzy#MaUT9WR&$?q(oo$+;r5M^1_Ze zu5C?k;}&%yQo_13Z0T4NEVXe*>~&praAT03(x95WAT*L2)U zZEbdK7hn?KIkIciZWs#p&FzF{lKx;jr{fu$hSPmGa+&>UHPJ3*m<{QdrG z=%&Svvs(Mk;Y&uz9X~6)kCH+(n6TFo`gdmbfKOD#<1L7!T(6^}MsB^3<1X2j-qBPhj zx*MCdfWngIs<*M&YF$JPhX03Mtgh7q7`jk#NXG$n@1{C_Bo;9a*8+1Bm7|&!(KFjL6Ud@DK`V9VzXJSUi zjD-wPbz9FGyL3AW>Bvlg?k7-&9u1L_GCR?&6vm*v&jb<_v4QYicph${xHZs3tkV18!&y zloeDGB@@9Y5!A39>pI(^CaA@Jq@bA0YOO^{$E1A-?H_aFPWT!nm;D}l#_iuqc4aTf zVU(Fbg+i6l1nyPE=~9y!5sK`jvTyR*&f#|D0GSgW1J7yLW1+HRffub@dCww?CJ`5kq31v$TJ5@nwikBzgg6iC;fWTitR?~jbjE`FH0EgD|#FU zMKrw3=s4(kM-S6Hm+naj#Y`;SL>VJMmT@A>ok$2}Om%dM&0uA^EX_bCFj-wU16sLY z1O)fBsJ~z-p&e%VY9Pv~97_##5kJ1sYD0^eve6oGyA=q!6@?Dp%-bINt2jvE4 zCy171=iZ!JZLH^tEs!=?ENVq)e+%OQsV$(pfDjHSB^RnT2!!V5m=fwqLDUa{6CW1X#VGk98z5uIMk4mN9xz`ZCwDc zL6~B?Vc4%)mTlO2#z^UR>IMi{K#~j)B^#tE_{$mxQEnI{au({u0Bs1WhwwL}@6!`E zte4x_n@_bwrrG;2P2wfuCW*EP6%;8A5~yvcbSQzKDeRvD2SiPjgy*QJm_zJ;Arc&1 z8Jz^2@SRM_NjXfBF#<1%jJ|$l2LV?WIorS0|zl_k#{nK;FX{>0B9^P?wj zQU5gGq!}>ZNX;-LW-JlXn&ClGMHyr$EXf*LlG=}c7wp@w>j<8!pCFLMnE_ULED*Iv z?%d$w74cv%*T{=EucwR*_{!)Jtq9*9;V0AvqWH=Y&0yWGwgv7X2SWKD^9M9NQ;kD% zw?L^w+teftNYGd{c?38DC2zJ9U6rd;B;ZEf0#fEx1G?*4)sNFMybiC?$e8RTXi=;WW! zK4mz`>$J9-PU?E-!yDi)oTIs}sTc7ZT7$_7zlIla(oW{O)ne(K`Y1FA1H9#1>kEFc$u8{ZU*0s<%&o@~+}g(- zcK2`H+282{)LUR*Ny6@vfadZV*^xsqkS4i*&_K%EufjriN1m1EvJXhjLsA;Wmf7p{ zlI&V)FD#O^h{-&#jP)hs*R*mQMff#fb*rf2#u2fKUWCUW1G@$+&rkzkP(}^6J+=&C zGMCGpfnoYk`SGKt7TU=J?>w+^z$qMf>w!lOEUSaKkLoo9hX@u34glmMq_^jm3v`G) z-B7T-6zau|GJG8|F|fk;Qf?_;XFmm-U*Nq^<)U-dk3C&=S787zme&1L*Q{OGC~v^? zQu!I<(Bap9Ah!r^Pr?EnLTu&$JxwIJ0y8KTgBqMHZhL-_=5*Ni8_g!|{zW%e#m0yU zG513~S0cOaeaHkzPg@W|zDuR(A7X!>j8Hvuv`<)%^myOAQV$u!oZ=_~z;JZw~}ItF-A9{Y14@C_V)(lx{Qih(If_CxpM$^Srp+*Lae94cL8nqg1*kVoGthp^nS#Eiw1@B)OzCbu= zLSmv+ccG7-ZYVnjfmib0oKo#x_Sl0i^h9O|- z;lDQG)E~n)Qf5pneg@j+bnAv~Wm((Ie49^=XXd|aJd@oY)9o8s)TCv7v}X1@tY?j} zZ7gi-6DEW=!)D+7#j|#LSYX4v<%i2-*_KTTE2ff8$5js5#lziw|lchjfdyuo? z8q~i8=uMxZ?2sz(X=Z=ql#e?=e27I)m%EZvCk&W!%Se)3)WF8hFe4(@d;7Ks&7A!%7cqmd7B?I(tRr(_vPcbf&~bKkdwj zi+Fmax$oa8>*{|}6*KPl!PdpdCPkf*yvaa5T>Hb}NH!^|0SiHcv zb0UsIh%A^c0tg%fkfM~JD@r3$N@o}a&0eYP&@>G<5@1YVTp&!@m|1wq4=rCjb>X7e zSHW>}7**7&SLllqF^0a%E1=fC<+T!hKs6XH9XQ72$5<(; zPqC6!Dz?gggq|08o7<4y406BXi&&POchyXp$#}qHT{K20%KG@X0J5kBB zsnw`7KsV;#I}dbi2bq&9!{U2CccybJ%DQj!cvrw9{RC6J+)gv`@1TMD&jdoU{uA3<&DH8LKY}nKNrc6Se0+!` z5F`8;9ndfj!$be`M2vkX4BX5_i4;t!WV=}_W@hm#XaeTfo*Y7xK@TNfwB9FzME~ST zAw9hl5-kLttZnML46!g^uMDrio$$QA1ObicFI%bp15&V15UC+Na|f&nVYIp4Gdlkf zD}R~bD+HssF7gzlyUp@%b6u3xXcUI9V2&U;WD%(~B4?rnMCdNoauK}P%10s?Ma)4v z$169h($nWT9wl^Z^g8vPOBW?xjNWG+2}*f#NtV#m1NtJS#ezxtQV;*iJT;({?8^Gu^OBQ5E~A8L!$*>Xw8U;vqr z>zT!9hC6=EP#^RR$AZ~<(lI+zn7z@N^wOPaC)AnpGMyPO+nJTCJQ2l(-kS6#VBj4G zS)rR$j2-x$*|s}$JR-)%3BMQSWSF!K^?h$wX95Idmp3+{O>3Pw@P)Z+dVdu#g1`i| z%dy1-wac3vc{=M&sh6b2T(`zt9Vf#USTRBp9D7>pww!@!oI5fb>bwZfO z_v&U{uj`jT-ko7oW*A#HU|t{*n?-nAT*?wGQ;0pnYIb=uTf4Wl!Fl7!KAbJe*z<}7 zg`0d#QVYo?^>2V)0fh_HXyHx1apqw$nZ3e`>xULETzK@tE5vu-ddH~?OAC7xNr~!{ zJE2SgjTw-fan_H~Rtp)wj=&?~Vd1Xm8}53Da_opiazGx1Fsk2U+3y1&4q_8gK?5{v$XMTmAo2;CKiFzh5v z3Gx4R-Ire=xmZDqFya(?JUL+`;Q2}zht0!)F&Kv}Jd26OV;K_H00`{|8WD}i)~i_y z(-BxoVBZy;+BZ9U%t5?~uwlWbB}2$iV{RFa}TaLXtbeNI2xe zlDo&DIKYu$^hDkJ90~@^gK^79XB<+$(YfGOsK-2*IvmnRm?;rzU?4EbgsBqL5@88K zrPu=;NJ%g0q~W?sIhjt{gVV~3cgB#;B0Y}u1k#hTGK-!Z(1j^yEEp5wN0d%7xS%FL zUr1-tpmLMmln4=Gf&~pifE3KHFvN(-E8;L@B855ryp!ep>Si!c%qQZcK-sY7HT#KV zCl0gPSZAVxId#T6=}xvY*_rAj;kb)oJ_(u2Gy)^o*B2qTPGg6@h&|3=kH6@QOFyRi zD(X(I>#aY*?-ZVV)giUvj+yG)nCavU(@}3~LPPMzx1C9ODqIHrzAF9d_%*!*p10=7c9Xnh;sfnQRkixa?{<^R6Q%ZngQ&!@8ey+WD50V@6lFAFLF*yqZfqiU=A+@W!*GRlQ)p%f^EQD zng2IXYf|?7dq@E0TZKKsEAtiZ{rAArZ||4(?2eO<^csXqhBQ2!dBJuSTC-rVxV z^SLAEE*?QdvqYGVK6sc9f@ou5OYO;Sv$IQQFP=ShF86Tm{`+#r+L>-~>3GoDP|4lB zl5DrmPh5QTT(7LxU3c(a@m#lRcj%zFTSt8j(?26zv8&&_+*5S0wwLiR(CgX@U=2ayH9rFHsU+g3 zA6 z>DgE9&=iV3G+Wc0$= zzf8ssEgBX(5NCJe$z2=QPF=g1{n$u>cv#u?^-528=$z#umLaA`eL#5liDG(Kn(rB%SV# zg=i9Q3~ZZ3GaJlANbWfDW-!Cy`Eg8)CNq|L6POw5l7J%xlFt-Ie%2fBOxA5-1LKHw zK`V?tDFO_LWc`=C1Vs^$#F&UNuw(vz*pKj!+z(5AzP}&Qq5WW*`+@yOyi9LDBG?bz zi*#AgIM&RU8nBO8|DZfrKWcT?FDYRZrtE;>O$@CXeo3AqAh1GK|B{Gpf6vvs2E@Gp zBW|$+*P!P1ae^%D{4ICE^so=u1?`XmIe2_JFAn1+P zVRr(9zs-JuIg-WyJHeZs`$l~DCxdW8}xq8iy+e3M(l006RwYSZ=Rvv z0lN|Gf+#m(Gop{vMg&NdvC&%*ZxX^b&Xr>cB?x<=l?a59_>!>jNsOJ;7ER+8L?Z|nGrT*f#qcJhRJ+r zUO#LF%#M9bfZ;I*_NCj#F4SO6#9tb_z_geKQ;aZ>iDM7iWzZ&QMdHX-cFK>|k&=^vV$;UeORGA!5(t^r(3I_ynk1_&;641}*EI0U8E zxQ16+BD6vQg=n?1p%=Xl^(0JtPJ0XbLK>mcKw#0^;BSLN`y@sSFa%su6Kqod6RCx0 zV0jQz?A`t_!k;GdrSw+d^p>7)|v^tz0y)>c-7=y_t zI%3TM8DYtiAnKK2Wc$%a;L0$ST#V1L;5PxVcN_rk^A+*#%czx9eGgBLaWS-aQcQ}$ zf|1ong%${%OR)Wpx*gPSqlJ2bfE%IyBLHse-&{$i09W>}u+_gJ_zZ!Z->)$xF10^q z>c0|vmEhY1H3Hr;x$rXgq4mhhXhc_^1^7AK_EruF8xo|EL=51ocFwj9OHyaG_kbL{ z$4(_;QoDqTBOh#uc9L$iyO9{U(Zq>nIB-1CiN<$c;Q$LqW6_i|#)vDFMw}$*T~@sb z4R7uKBL1R(z4bmMJcOn~rANH6yobU-RY+sVF#;KviQd|+L21YhyOAxMMh%I2j!Gmn z=Yxncw1Aw@jfajqRt5d6MHv8FfGcb$hqtJPa){u?jN#ON5c3(ayAw-!<#BQJ);tmt zTpMF>EqZ3}h!4JQbc1Sj^;%Sc*MR#!aC*Vxb`b@>UQh!&y`0nIC_YK#<2dtL`|#lI zJjj=D`yQ{M5SRYj2x2{5NDsyd%%X4R*l*)_lcVo;t;K`cLQpMIPUn}P&Jw5WWE1E-ts`T1nn9ezGptt^xneJ>Z9MM4}oAQ~A^ z`IJP(w`Xp*@Da~G^#q>AX`aUSk$`tjsDd~oNG-bF{h?CJ1;&6cLz)w_zsl05W~ZayS=#Ig6^AY}biFMs(UBi7j`N^N-u za2F&3^y+4Wf+3uu*Do0YA_vi9|NX#j>hdzA2fhSM(F&orw~K;;~MTMq8O zgFEr37ZWu4j^N*dI4`}D#?YlPfZk_65-h~L^wvKE@nVSnrlzukmoNfy&?h}ceXnbJ zc;$^mV%=!Xy-j--BqJ!3S25BMP1D{OUaC0eTBwE4;_E_AyfM6TaR-EK(~S^))O|o5 zzFYx`g0IP<4@eHu2kf&z5A@M{&Kz2wM(a?w^%Fp@UTa(gc;kXzSNY;=x7RSrBi(Wr z0j|ArzF?chcO2gde4&O4(WF{51&`PCVAW6lSZ_V*BbVQgehtnG{;)Bm;XCTg!ao!T zO3#96%bbDIOov$yqOuJo5ih4PeA7#9C!R1nyE=27d2bim!@>=wKZjIOQu9cqWTo$q zBDe>+U_Tzg9H!Wx`qJeWQQDis4E%c;m=L}Nl^M@ZXl@2U$q{c1!M(PZLCCMgYk2uZ z%#p8WV7yuOiqYg3V_}Y*-wa+>oLQT|cM?9l-T0k?EqKU7rY5}IK@T{G7klGfei1Bd z8ct@tvu9|&)0ppIe`IfZqX&CLQ-Qa~o0S*p_CwO#b4~A`&R$52!};7$&yRK3Z*TpB zF>apkC%wI}JC8Z>X%VP8cX&CTOJ|=q1K(lViHpxLjh&hC?(p^n^XkoV+XBhtDc$01 zus5?%96#RM0|exrA5Oy>kG_FJnXO$2?ViS2B(Y0#h{Q3TrvE#zea;!jD7<$@XLo&<%a^O<507qesvw~$>{NUABj%xLXDNlD1Aqb)i*Txq96G0v7*U-zQkahl?qf&TMe?joO0N5#t zcY$!egP&2rd!e^UcWQw@;%lt^b%JjYGy&T4JJ&!PbBq6l1p(*ndj_G`ts@94zr|7q*+s^h2z=C5)Qx z{dbWV!;|si0A6SF4g7Bt;H0Mhh~Ucv;@STeM~aKzKDbBu$Vx9!2@-H4P#n*PD*OqA zD!HyAe7_qVT;h>^8VTE+hs~a56ljIdz^xB49FF~=KZ{m;{3T4wOd+%x`XX)rCWZCb zt1-Oln$V5j69{59+8Z#nz$Gi_8kMv?&?IffsP^?4usUdS?OL%Oy7dI`L3sC(DMbW9(rF;KXoX)O8%5Uz37)&f>z zjBvfJ*L3wYY6iYW+YGFlg*KC0Z9B-+GQe$~X*)r-mQ`b%b}q=*@c-8W+4q*)zG|piGSd~|O0(_DG2z!k z!QGW!r%p4=@{P1bTt)eIGhUZYRA2RZ51o(1ekTgWa_C`1EWEn!u6iBb_64@%a4+_; zkYA_$?tV-Zt`JAtKzyynxCSn=nvvDBM%K()Sv#A_I{0O?Iew1M^P~LOJ*#H&S^msD zt!DA#`~uke%jX80+78efN>;dA(A`2w%}{;9h%l%uuQZ}II2kEU&Ue?Jt9ab&#y+paVdcj9n;{siW&!Dy zW#TlN!mcmUqj|0pG=&I7v(u=wX)-sa5tyhF#iCb_d%~Yr<3ISLKcweJmp{1mh1`$7 z-~GE0@{`LSr2IQ`Wr)0fMH=;P@A^Y)exV&U7Ug*O_)uN736a4VnF+kRQS5i6wbBl~ zSnBl)c%PS+*Y2)*(vIRTddM7zYp=Ue9CoF-9ENRK=taIelnNosBh1BWE^i8HVDoZj zr+Oq4ZawU*G#k?7&3arfq}x`Ia0~A%bn@>3Nm!z}TA~B8!~oP26VOO3Kr^ucEv^CD z3T6~^6wE4^Q!vkIo`Ql!PP3Jg4B(i;%ZUSgJjnu1BssvzB;UQp4g5^f(m$p(r?>{` zuotfj(j8$FASu0XO7o7_jHP|Y6G5*lZI86Tm$vY`z86;y$^0(Hi^DYV2!W#njuDsz zkos$p)WvzJH}LxBrQLPg%_x=`bqwyB%yvSz*CmxvbGN;9X?N)0B58B~cC+rw{0ij8 zck5oe9f@aX>JtPi1Wo}|3(}1I_KKLMnx}0=e41LHF#NtIUO>~s<%QKS@D~EFzv_jp zrU$InJKSN8RwX^|C*cRK?wjd=f zW19_hlueW^&bHZ3t&!+yO`b;C$Fu=tq?H*!F1H*=DwQBltz2Sl!`7oVpU`Y}gQPVR z7tp4)iW}^jcJrk{CN6D_p{ATH8xrW0oyKIzuyAyIdxW&CmBAjMTdkvbkKvug`wZUW+}X^(L)vfxZw2p3yr=e1b96wvKS6sMaVsrf zBu-`a)uVh-yY)R#c3SbLuxGAQkLTiAPZR$TpW)dJ_6@e5LxbXI&CwcDT2e@gnEP*W zesiFza9EuoI0~NPM97+C>6~2H2xxY0@xdcJ56G7*?QZa}7Usm3~6cVc$u~WRWVnVboMET4sbF zt$Oea$}G%N4?9(P@hl4BIjS+q@k&d0orW)SO1<2u+4rTTyoR&`FAjQbK{_LgX?wE+ zW9vt9JT*#CEkprr--Depqgx_|t?mV~la*?z2ubdz~E0=$QcG;Ds!*=$~Pkep_2%oN@i0^G(utSd9(# zO?X3;5K}@kPN*KnfgWgGA}?ypXagI5)WkOvn`#zIS)adxbx3Emw=&!$djq?oC)p!p zORU7fHxE6bashfr-y*C7eM^`JqmO9aH@L;EHB&U=Y@CCw%O&}3eS-~4g|svS3$><2 z0Y(*~fFjr>4S11e;5MlEdR*d}4QOZD*3WTA%~niG+d6FQ*w?g9j#|;49gLxOj(WeX zq1RY4w#~%*yX!HxPID(E#NFC&AqroCElrA4qB50WC&s2UaHYJ{Tea`>R$~wJ9>fdLnDv@6oO4xH3DeS0+LH80zWkiB#Yq~!D$bQtX?i1bAIL%*dhln|%OF%|rpH35Q zP$^AgDNP8`*c35EZyBJ>T?BjZ#-SQEl9`NHR)tBBQb zt>Q&6oe1aKo@n^h@pMzKV@avM9w{R%uF{@fCh#hO>jb_GAPsLh8lDjkV{w5RzCsNa z3~MKrRx@gLR>U<_RVUJedV~5b5!mqv;!PT&bAh)Z-a?_~)O*4uQzk9K8a10f-Sg?U zG5qpjJw}>Di`1_o^uPLi%0IDM$J-C<>Y0w!+so6yfN zlj)Cay@WQK>G~9)3DcJYbl5z8Gwc*AJTiaUKlDV2s*@Is9>U2yLYP@ZE)Y&$+oC`X zS3+{=TNJdwbScM9ZPqsRL0AUOFc!rxU`(nXne=#WP{kzz7YQg=Je|5?_<^|TwcU~X zP_7v2Mdn|!yZc;&xPt2M;Em`mFblhXz`B#S-b0Roo;&1_EPOmUrq;5!@uH^VuA_R| z6rEBuir5r!A?h-UP;(cd>S|JTo=UCw4AvE&CqSV_)kuY_OBKINKxxopHTn|jejl$= zqgiruaCQgIE)56jMBwP506U9Uo!#H#Bx9T$Z5|;QUCKZLARN~~ogtCR%}q8iRX=NZ zqPF-^oM|}?Hn8Gs3l3%*;-6Ss1=JR)4+YR*OPSjgWXA}dx5?3ETa&G+gvLxFOx9Y{ zoC0k6l3SoJ+?e|2;DYj-rV5Q`I2_&k+IMwc+|<=Iu(WA?s}4s=dopPg4 zesAOL)J9aj$HZy1&K|nnV-eLJ6(2rcI#zzF_voYwK@J4j5%h@r<4cJDDwWE0WLZAxzqN3+ z+Aq$%Jr~SzckZ>hugoopIVg}gOW+d(K1qNggIY!z&n1^qeRo2!hRD6+1wOKMIhBRL zwZe}?9$jTA>a5W#GJGy~)}!NQvvx?yweee*2{-!0g&CGEc6 zw=cXjf8n!@MP+yH>*CkYT_|g`*AU6<#Re?%-LAOwH|X*aSzZ0zX{|rE&y9#)XIE4* z8l%kOQ#UF4wCdLhs8y34?6DynpzbGB|LXw%@0or&nyGqLW{@^(fT_}i%MZmJG)aee zPaIZ;T$ad1JwA>!xkIUGX}ogdt7#x%)_dHOg=XY>w{Zcb46gV^>SDg~)oWJ;*<)#R z!aK;(#$ttTSW*v>LBiDCCYR1AT_-5+Q=+3XM}l&5qC%syklow3Tlrk&h~%MB8MjJx z%S=!B&fRZ_S$iTNGJA#2nlK*>1w16tu#UT zJVsQDsdFNZs9UDY?g*W@OO=o?Y>Q}tJO^EA5G5v2BFpg? zKz9BF_0b;yP?Uq%MK2+Cu?+{1E_o`dK}<7>c+8(o*gDgh{a<>%Os$WoZvo+ejFr9clIOw-HFqvc}mZT|Db;6IE~HOQ@CCt>cY$>ebLgNI4JN+VICIJ z-0_<#?=IQ;9(PRoNVj?Aj*kP0$qnmMxLQyswF#F@VYqU_2F#n4#uTtG25$!9HH^5= zDE-KLbszT&norfHkp_jg_jktW8fWvr z?D_()=ENAyS|%_~fG+uBk^rTZu3s02@L8-*rMcM*H5?{DmNQKlE9~e_9R+S_D-oVz zDrrmSn%V}1Rj7E9z$pT!0YJAht|aC{oe^Qk4kPPI>fyRF=ej{if8hWxxbCf<*B-WH zTpV28b;V0Gnsi!S5X2S2C{PsNCGZ0RKP2#&lL~v0>L~e2ekmen=4iMsWJrPb{tEx; z%sKPUdB=1zxpzvI8daU7y{+`whIN?fUCmC#Bh zmiSyNr7E-6^ik9*$hf9PP}x7Sp5k{1(Egu9m~<|tY`cWCjTX_-7Bn&w1@)T+q+HPd E04(_1mH+?% diff --git a/mace-bench/src/batchopt/relaxation/optimizers/__pycache__/lbfgs_torch.cpython-310.pyc b/mace-bench/src/batchopt/relaxation/optimizers/__pycache__/lbfgs_torch.cpython-310.pyc deleted file mode 100644 index 57128a45333701e99b6ae1b2f563b728722f0ad6..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 5951 zcmZ`-&2Jn>cJJ!1>FN0pNz)SbVQ<^kdNpgCl(UX^&1hw>Wl45y(@sR}!V@cv=2Q)* zM?KTS>K;j)G=WVhAVv@Y36fJ528O3y<*+&Al0)t}HoN+T@1G8;1 z8e56gw(xBy_P}X78s{YLpwh1Bv72~$iQqr`E3FyOPugjdp_>qAK6u9bKJURwCad2p%I3O*Kj^lI?Iu(KRQLmtF= z7DU}{oWxNsvSr_2@5Nb=rDNFGkm1``531_`%vuASziW%5x#VW+jKYPeyq>h$qoTBHoSi zPOn3(kAG}T-XZO@F%U+ZK|>~IZL4Zlt*Tvhs&2J{KaYQxzrxS)^ADW1%`foh9vE$h zKhM8~Zm;6MyW z0)!0G@|qB6HGmoBZ|LUvu2_tMBR0H*bH12|k~_ zv%Ypi%^&cg=Z2b#^suNZgXnⅈ;4=NgCzKEzed~Jj_*vM}tv3+*D4KjCzr3X3<>{ zcE?E)<}&Ildm~K~<whd?dxcl(Kj~?8k3V;BQ^X2gMt-CXcQ(OD=a{>umP>Y0kk<*`}TPQW2bNd z-J$|mDLg=r)2^z9zJ|5D4tkSY4~c8z#H@K@c1qLFX*iEP75Pa~kssy@g}2MHPkz_n z_AdJqR#d^~8E!(UPV|Y1`v(V~AIZ&Zbv7g96Z%+%rya*0==xqShQ{MzCmC}AFAN57 zw?unobU*@6;ZYu)V1J<2)gX!FrpWSulnjT8a%|lWqC6dt&qhg--h;*EDfr$MCs}AK zTn*O7gAJG;P9Yu{(QXKgQ;Ci6-O^%CY(21>)nFY#pc5sXae~0X%S+(c!8pm|Q4#}n zFNVozM<15wYzP>-2qMnQdr=%2rCF?fefh*O+A3EOrmz?#!RAApp}5^J|Lqm1JsGFhc2)hjK~EJ zx6G(xs4gBpxkbANlmz$C0dK}#N5(AB>$7bK9{rzhQ^<3n;zSH5Z>=9~(&RzbOUDV< zN*aky+$AwT?MAIdRXb*e@)R6Q&J%csz#@P$-@mM?Gr`C+v~ZRd>c`)v>@IGj9Bm+y zJXWzxv2A6^UaCxaS()9-%EEQ!c`UYkWoIJk$|kXU%CYiUT9_TCH$5`sYrsDKY^9eD z#L6I=^rEyMM;LX)Ya^MGUdqXa7cm{}3-<(=|1~V|SXCtQ^al-~abn#!|H1a@&)K0KZi65oI zLzE4FYIIRJaQgwR*WNRB=>(5EI~L~kegkywL$+h*&FLv#*<;K=95wG4@^AI}9P#cO zh0VQ(Rz|pa+0`oxJ5K&g;p{T*+_7a>@2?it&x|ik;xc@XFf`}>mnQe8i@e6`55Q^q z?4CjQ*kfl5q69aO8)GK_9HY~mjg9P!!U9*5Fw--o1g1-cbwGXW)}U*A-`KH!^PT0V zPu!#V^n2(tZTZi$`(RX_o}cZXo1Q<=M0wZ$A*%O`+wP8AxYG-?w)L;TLBh9B?^N>V zFs>FAy#Cae=6Qq9@%e}5j#qfo=kfn7zOaY)W6!^XuHFx(FOW=b*FQIRDm&HuMU8z4 z*{Zs0foi{iw!(|R&i@6IXZ0%NUpTi&diDMZEg{=Hv$ozcl_%~aj)dtQRT)K64D+l* z#m?ae=4QC~hjX*+5(b+{l(*-0SZ?%L-vnc;V0Zu&iU(+O`b%?*l9~0mx~wByJ{T!> z?-9x@%0ew$Sy3jG8K=rWibC3NaUFB{va%^6J!0}z9H-q^E@PzpD-`iM>0osACzxmC z6Ru&J=l0$G$#bOx$?zPypg;s8%NPw zOX9~XKWeosr<0&uliwi(8IEib_%4B$2vDF_&fpGeK%E?9K#1n1$ipLv zt4fk?Zc;UA$HQ)_s!-isBxAvKcvsIJ&Jlw-T8S(qoy%8p$7qh8d<-ILH4 z@`x^8bqH5>XUrp26N6C>%ZNHXL03;RAxXScXOcOp8r2+OHbz5pUsa^YP>p2D0lPtl zJ^`&yo=fHE+@W097pk6?U)u0RFw&Ay)tR}-cL=#gN7&>hauvgNMSDy)QO?~c8H-XS zzeAMT1uWS2hnQs}07@~w=`)`p%^q6+uP-pq z^pS`>*8CMw*e?=(!Sk4Jm)v}|#1`11wE(W3=`owWK{=_;tgJ&};qXJ@!NEUlh-DlD z3JtJv9SnBN+@#98Z}n};Fb^#qhxPodj>%J0;eB@x*`=_il|Ag9dWh8ap1Eisda=Tp zRypF}TLyU!X#GWlYUMR;*Klf!#{*@-W-x}f_!uQxP2R#q5J^%wkaWS)x*A$Nhqj{p zp|}_3=s#1b3kv0Z^ul`CDheL*H2-^0WuyR;ohBPvDBpCPME)U`PgoRHcof~W zMg0=78`qs(y>I5)mQat7T^Unn509DQkdC;GTH}Y*>GjGusG0@aHIc(j)GFu@ahJLNe1C!4XN^7N;=%+s z4|R^qaUQSm%C3bxeQ^h!UB6aT&*J{zid*QhN0v=l`-*`cv&0e4Mf~7mbSLY*ai^9iCaMS7O zq7IL29PQPyT8X>`qHr%b--udT};LVqp#jV0e6 z37vl>8KQPtDKT2!eE;wM@o)a?UnW;u=SuJNDfX9{R+%vHTcA=E{S(8RZ>R>X&U(5s zo9IO2BB_fZFH_4=y2 zTQF^N4ru^^!!MnTV$q4bPa%Pln)WNoH(HZy50Mp7nt}rP4EhLdzxB_p9dpNm|Jd>z zVzX1|ntiKDHqc^g2!;W8lqp*aQzu_#O%Os~qLtcsgx?qh@_BYg{vI$br;MMCl04m@ z_A7&!^MMR$hpkEfpJ-_3f8alx1Mm?i>#VtNLp0`3?f;$B%E#)=-tr`MOR`3wPM`s> zzP5IwwOICz6jsXSk(iZ@D}~>(d(}8Tv@76tEUSPsaG%1^f>@T zOdS@t2%$}pM!;%A(dQh|n(}DuQyWQjMi?qT3{fbL@snGq>S4Gwj*^3wN{Eo!2}8L; z{38M*0@~EcfV7a0S(X+gRSvQWK##}*0MD?Q>7WHY_E-OJ?{)8UzwSHDH#K!@Uj7wm z={nsp{Stz6DeX%tKvb=hCJB^BzonMFJKBf5m&6+fWfR;UWt$|ElW!8x-rUf6D?IpF q^(Zsz-xc+-PrJSHu9OnFN_Bijda#Xp#e)~n9~CKe6gYLZVErE@m)?8; diff --git a/mace-bench/src/batchopt/relaxation/optimizers/__pycache__/linesearch_torch.cpython-310.pyc b/mace-bench/src/batchopt/relaxation/optimizers/__pycache__/linesearch_torch.cpython-310.pyc deleted file mode 100644 index 595a73fe601e540bfd2e4aee3e9349c7ef6b4ad7..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 10077 zcmb_iYiu0Xb)GvjJ3IT}a=GL#mk&v6$w|yw*29VIM%BxaB~`6T$s!Xcgq`hb?~q(_ zxm?~ECK7R{Nld3TT-$MqHf?OQArlm?fg))8qiFvWMgA2<3b;QKVA3{5i&jtr1nI99 z{UP-`XZA%&Zd|0Zm^0_T&pr3tbIv_?*N2AE2A)?Q{iEfJ*9_xdC^7h0NSwv%6Oe`! zO=BVciUq-GbHVhi1=~w3I2R3RO6w&uW1QXg=!R%kNkn3=+4VWCrbtURlQw-A#GYqrRAl}SI ziMjcwug%{!G0;0^l#s#Awih|l`&^^$-5x@&5&$zpk$odrZ$@@d^H)_0Sy3Wb$4sKc zmZ#SJDA8yIPksy^zxFvVN-j3nYn+yxHW%xWxm1trdd)}adQ(>Ypgm(nw(m6;Rffi) zPTw(98K3|D?9t`*HSg$JZELxZ|(E>$RZ1yxtCu zDz90)Q41RDt)oab)*2wqKiUM>edcvg1C|JIvdr{{0Z+}b#m{{FYx-|Zs*c&1bIA7^f%J3RTb9puv4TPI#U_Rp{1Jb$`Q z9PK`NAU0lJ>8}CNS*(?zGk`*;08O0+w6p_gYjcOz+u0>aNNUkdg+g0B+|qV96u3Jn z4VhR;?wUL4T_el|8EtM;yKFGDlS4@sy*OhAa)zY{hb75XUXmP*EKzPj59=YFTeKi0 z_H1V&KRI(@er7asuykvUR{X8qh@2aNs4H3 z)SB((+Hsb_7)(&DG|5V%)`UEjkoFG(ICd;2j>uVAkrlRCu+vrnU(PonKg9w0>Gx2l z&E0*HKzbsu9|C z^s@NUBuJ&G(?EEDcLdXQ--uqj!6Q;dpVW@Mcc*pBxnZ65NcWirLY9LEoJCU&NydJ$ z!7k`Q-WL|6mi7wmp#<+IN4vn2*dv`s_HmFSNjj-YO$Q};7mmr{du6TN;5L27467xW00c{IproA>81a`W8M zSOMN|i`lunUq_@Fj`2<{U^Fur-8R*$+LZQ$5f-HtmO#B37WHsAE)!u%=fkotgcH20 z$5PnaJ4HRF^Li8-2HIt%gm)ZodDjgm@tuH$b-rv)7&5tPs{hiH;TUwuq#hYF#*J_a z)@W*5^kz}kkl&R_={{wcMs-pb_mocPlKKJHnCjG+YGGxl1!in*3+avz31?+a}9=9L3HGA~@bO_F|_wr-R-f9=|pYqv?_ zBm23hug%WvXJw)uK@RIX)=aUM=_?)9Md^m$X!${{Rrg|r=*;srjhxPQ9icvWZyzS? zF#<;k93yZX0Csn?-SpH)3Aafp#^#fDiXuyJmE@xm_VZ*^Z;du7LXjiZEtUCZKe3 zpz#l7f#_>lF`ZnNgL=c9SKQ}^!c3SAN5ZiXG=-z#a99Y7VJRF6a~MnLEYmWsQ{30P zu%zSQ=DXl=9z1?mj&eVi`XTBTSB3u%_?A%eLrE%!D6!O!L20o}cuJSXj8&n2EQ?&q zMg%n8;eG|amd>L5j)Wpb>v*L}1&weGn-vt*Nl>1IszoiLNg#bF0QaHz6FQSIXdKFu zP^!I^n_wxLWJUKB=po69^QKtQT?QXN2)CKXI~lljd8|qS>VFKh!popl+08tkWIhV; zH4n~?5*I+rmp~s}!Lk8WJ*{(~kY;gH4@+84K}t{PZ_ ze*%V8W-O-ey?fRewbXkso8tc@#dO5nH>)l}0VpyLTjOjVPE_XRuU3v+oh}~qiG+V_^vrDs=<}W>cxpJ{``c&oER=$&6KOT2Bkn?yqr_ibM>G>;{yLsL2 zI)it!r#e+TL&w>jI_eie{snab@D^S??{$wdnA>@zV{^DQLq_k??8Ot6Ygaxw`wXxO zcRN?PdhPtD&d<){d*$M_^K)}s)4i%p-2C~u`O4hW^Oxo>Tsr^ER;gD?%sSO&$o=^- z2cyinE3=M=*>mUBB@m>32_Uj5`__Rk|LEDj*#EC*a6(`Mc!s#ah1 z0(SmkGU6}3TWyD8QP*qQR<>9UOcqkE|WeD#EeYKB7EJmeGw$Fr@i;Q`xAOH`*-?LOqwkPT*=`#rc~us2o!JGw2V zpp#(7;2OieekBf@;CO3lm$a3}!MSWqq2@T+z@Z%qO#BQzv@PQH+muRe3uv31&IQ(v zO*U%^0j5#3A`D{zju~ZR7j_S}5q-kKoY%t{N)2lZmSYHg=R5dtXRHy@y|AdH0ShrD ztUpJQwqgng>7y|*8dJ>m#y}bBk~Rhg{Y@A~eo>Ex#g@xDeYDK2FnW=F3uwI7pXw}m z3D9q8_Atc%3qPFyl^?eHZjT?S0e%op`~d&MK0lJ+htSE6JsHKEQ?#PIQvLjR{;1WN zKVz)*E#`o!#|CEoWW0`mkP4pt57?vn8D}4u@Wf(?cmM~saSZLr1lgmj59ETS!3S^w zoEYFjLAp%mChTismF({(II)K_tu)C4&pCt-Hrhgn2-f<2_ysFS_WG|_Ztmq9$-_7L zeCzf5XMv-W@GtG|dgvlq@Yd^r`M~@aAJ{w1{qOp{qt8s-4|+*=di2<9CEK2O!&FZ~ zDsg<~UCLua_MHdrQipbDyt?=8Qd$$|eRio0zL{#`fh*FLON=l^fbStmnVyCHnC&}ofycAlx8Lu3abJmg!5>LlUGxS@$U2)XSfbrF6r z)|JL$P2lLnViKN3W|r;oQmpsfvmL}ZJI`JNA5U-N$YH3@5O|&d8O9k`{VL%_0(AmM z2=Gbr1;VZqr~=HCVoTPfj5Pu+0_z0Y1U3k;oii5^(pzfBhh4U;w%=HSQ3{s5pcV-P zftPA)Yrxu(*ocIp<7pc~U%(B9i}#wW8aj7zwD$iTNZxcEm;?Mi%myn9li}i>hGB8; zmSHOLX4*_YuuhpxkPUNFKwxeob;=#wc;#DmBfSYvE*b{uK%;*d9Lgm#dmC=o=6;=nGBERUfcFrAn(K92r?<7fc^ z<|7rN`-s7jm7Fne-^HUgfwRc4P^nK-Q7VoPaCA1QJb$^?_SCOX#Tf$B9G=MLS}a8$ zp(f6N?_)|8i5pRdEHnbTOTn@q4RyTBIDPMulhFMREjf8eX)Id_3zWuLoIHiCi2-Kr zb1+-wq_tyVWSYt6VsO;XRN z2z-LTGX$O^KyIU|5ul|}G^K%Qa=@6P;EH%o)AWg27fWYmO_9FiR$SM#8N1~En(+lU zmrZlsI;!n?hyiEu`lLI0H*UCl!2KQFxVgA-%lB{Gp#Dap0Wa_qbj`yIyzjn@xJteI zB+}Y^nGfPS;$;ZME?^ZoE9Q;X z7Ky4lM_D~iN9mKzh9A_|*V?DQO~OSj%Es2={@!3-^e&~2oNTVwYfb+&kZ!}7OdN`& zpky?}S7f;IlDO6LqY>q`y;^`Y&Ol9%&O^mLHL8ovT7Zb8I!uJIZg4hRjg3t&O2-v_ z+=WC5Rcqm7Gu-Uou+jLqqM$n(^;-l;fy7heNwOr7D8AF*H94q&yx1?c%I`P953m(( zOLV&Dx`4u!!+2?2e;Oh}o*>k%iGX<)*8!|{VRtPO15&KC+%II|wc9#BVRXwWb<)@< zpmn0t`ca+emeOqvt`Fi~pF(K4+pB~8WGDZePWJMXOYkr77VwVZ9g|{-v{n*cV-e{R z-f?O6^2!M5PXwg!lT69l>B1w7*T8*iPi6GpedhjV+c++W) z$iZ{n%|361)8W2wzn(_)2eZ>9M27ZpRBjyDe!@1cBkoe(MnuYAK~R2}J=2u|N=Nwc=NbyIDqd6qFFXaxG2c~J$)diXX9<0 zQ_wf~{&q}-dav~g%w++$Vj`>zP*DUGeJ;cth(~8W(78!h^b~sfcNpHvw&<4O8kp4`(b&k}sNIxx$JFlt_kakeDkA$^ee9X^b87}baL!<|)NN_Oq%P3HDK;t06C zpLl}rdP0h)46Jx@r-akW#9IbV8^yPb--i<*jC)ta|0b^t`S{xd z)aSY@1Dd}?C%#x-su*n}7^t_7dVN`nS7Y#E(bxxw?|ot? zi9=pnrF)c81-IAgL}e>=qQbYuCn|C9kD_P&pcG#svBuw;zPAGk1o_LJtTbActtYrY z{?megr1HS0-@Y;z)Pha_MCCpkMAmPP6dn03Y=dm_0%P0LH|H=vUqJnFQcRQmijVbWP)Nl@VkVuP5ByOzenJ=36NhB zIkk4%Ye@v(n$6fN*&9jYxL%az|C``8X4Q*58H&e6N!%$TvbKaj7qk}FV+Zz86wG`m zwnbNIlxGS2EVo@QJliU*ecBKhA3oph8U!{0=BQsR#X7^_-e-^BkP%%)HOl)%4l$lo2f(x2Qtu0iCHnNmL6JF*M)+#0^A{QJ;ZMjoe-mXCqr?3e9NYfSrg$BG zSw0fspAS~fN|X0)O_CRX$8{*eH-=nSz%69sZr?57z=b??oQs z2Xm>vPBjziH}O%g5_pXOd(UqZ#%@|4utW1Y3f}HIG;obgOjP9VD9XnPi_hNVMX5g^ ya2+5@v~MDy%cm;F5nU!XkDEMbjL1D1JNBnBj^WeIF0%_KS%hXS;|ugFRQ?Bn7uOj8 diff --git a/mace-bench/src/batchopt/relaxation/optimizers/bfgs_torch.py b/mace-bench/src/batchopt/relaxation/optimizers/bfgs_torch.py index 2f8fb89..62409eb 100644 --- a/mace-bench/src/batchopt/relaxation/optimizers/bfgs_torch.py +++ b/mace-bench/src/batchopt/relaxation/optimizers/bfgs_torch.py @@ -1,286 +1,286 @@ - -""" -Copyright (c) 2025 Ma Zhaojia - -This source code is licensed under the MIT license found in the -LICENSE file in the root directory of this source tree. -""" - -from __future__ import annotations - -import logging -import torch -from torch_scatter import scatter - -from ..optimizable import OptimizableBatch - -class BFGS: - def __init__( - self, - optimizable_batch: OptimizableBatch, - maxstep: float = 0.2, - alpha: float = 70.0, - early_stop = False, - ) -> None: - """ - Args: - """ - self.optimizable = optimizable_batch - self.maxstep = maxstep - self.alpha = alpha - # self.H0 = 1.0 / self.alpha - self.trajectories = None - self.device=self.optimizable.device - - self.fmax = None - self.steps = None - - self.initialize() - self.early_stop = early_stop - - - def initialize(self): - # initial hessian - self.H0 = [ - torch.eye(3 * size, device=self.optimizable.device, dtype=torch.float64) * self.alpha - for size in self.optimizable.elem_per_group - ] - - self.H = [None] * self.optimizable.batch_size - self.pos0 = torch.zeros_like(self.optimizable.get_positions().reshape(-1), device=self.device, dtype=torch.float64) - self.forces0 = torch.zeros_like(self.pos0, device=self.device, dtype=torch.float64) - - def restart_from_earlystop(self, restart_indices, old_batch_indices): - H_new = [] - pos0_new = torch.zeros_like(self.optimizable.get_positions().reshape(-1), device=self.device, dtype=torch.float64) - forces0_new = torch.zeros_like(pos0_new, device=self.device, dtype=torch.float64) - - # collect the preserved historical data by old_batch_indices - for i, idx in enumerate(restart_indices): - mask_old = (idx==old_batch_indices.repeat_interleave(3)) - mask = (i==self.optimizable.batch_indices.repeat_interleave(3)) - H_new.append(self.H[idx]) - pos0_new[mask] = self.pos0[mask_old] - forces0_new[mask] = self.forces0[mask_old] - - # append new info for the new batch - for i in range(len(H_new), self.optimizable.batch_size): - H_new.append(None) - - self.H = H_new - self.pos0 = pos0_new - self.forces0 = forces0_new - - - def run(self, fmax, maxstep, is_restart_earlystop=False, restart_indices=None, old_batch_indices=None): - logging.info("Enter bfgs's main program.") - self.fmax = fmax - self.max_iter = maxstep - - if is_restart_earlystop: - self.restart_from_earlystop(restart_indices, old_batch_indices) - - iteration = 0 - max_forces = self.optimizable.get_max_forces(apply_constraint=True) - logging.info("Step Fmax(eV/A)") - - while iteration < self.max_iter and not self.optimizable.converged( - forces=None, fmax=self.fmax, max_forces=max_forces, f_upper_limit=1e25, - ): - if self.early_stop and iteration > 0: - converge_indices = self.optimizable.converge_indices_list - if len(converge_indices) > 0: - logging.info(f"Early stopping at iteration {iteration}") - break - - logging.info( - f"{iteration} " + " ".join(f"{x:18.15g}" for x in max_forces.tolist()) - ) - - self.step() - max_forces = self.optimizable.get_max_forces(apply_constraint=True) - iteration += 1 - - logging.info( - f"{iteration} " + " ".join(f"{x:18.15g}" for x in max_forces.tolist()) - ) - - # GPU memory usage as per nvidia-smi seems to gradually build up as - # batches are processed. This releases unoccupied cached memory. - torch.cuda.empty_cache() - - # set predicted values to batch - for name, value in self.optimizable.results.items(): - setattr(self.optimizable.batch, name, value) - - self.nsteps = iteration - - if self.early_stop: - converge_indices_list = self.optimizable.converge_indices_list - return converge_indices_list - else: - return self.optimizable.converged( - forces=None, fmax=self.fmax, max_forces=max_forces - ) - - - def step(self): - forces = self.optimizable.get_forces(apply_constraint=True).to( - dtype=torch.float64 - ) - pos = self.optimizable.get_positions().to(dtype=torch.float64) - dpos, steplengths = self.prepare_step(pos, forces) - dpos = self.determine_step(dpos, steplengths) - self.optimizable.set_positions(pos+dpos) - - - def prepare_step(self, pos, forces): - forces = forces.reshape(-1) - pos = pos.view(-1) - self.update(pos, forces, self.pos0, self.forces0) - - dpos_list = [] - cur_indices = self.optimizable.batch_indices.repeat_interleave(3) - # 预初始化结果列表 - dpos_list = [None] * len(self.H) - - # 分离计算任务:仅对需要计算的H矩阵创建流 - calc_indices = [i for i, need_update in enumerate(self.optimizable.update_mask) if need_update] - streams = [torch.cuda.Stream() for _ in calc_indices] - - # 并行执行实际计算 - for i, stream in zip(calc_indices, streams): - with torch.cuda.stream(stream): - omega, V = torch.linalg.eigh(self.H[i]) - dpos_list[i] = (V @ (forces[cur_indices==i].t() @ V / torch.abs(omega)).t()) - - # 同步所有计算流 - torch.cuda.current_stream().synchronize() - - # 在主线程处理零张量 - for i in range(len(self.H)): - if not self.optimizable.update_mask[i]: - dpos_list[i] = torch.zeros_like(forces[cur_indices==i]) - - # 同步所有流 - for stream in streams: - stream.synchronize() - - # dpos = torch.vstack(dpos_list) - dpos = torch.zeros_like(forces) - for i in torch.unique(cur_indices): - mask = (cur_indices == i) - dpos[mask] = dpos_list[i] - dpos = dpos.reshape(-1, 3) - - steplengths = (dpos ** 2).sum(dim=-1).sqrt() - self.pos0 = pos - self.forces0 = forces - - return dpos, steplengths - - - def determine_step(self, dpos, steplengths): - longest_steps = scatter( - steplengths, self.optimizable.batch_indices, reduce="max" - ) - longest_steps = longest_steps[self.optimizable.batch_indices] - maxstep = longest_steps.new_tensor(self.maxstep) - scale = (longest_steps).reciprocal() * torch.min(longest_steps, maxstep) - dpos *= scale.unsqueeze(1) - return dpos - - def update(self, pos, forces, pos0, forces0): - if self.H is None: - self.H = self.H0 - return - dpos = pos - pos0 - dforces = forces - forces0 - batch_indices_flatten = self.optimizable.batch_indices.repeat_interleave(3) - dg = torch.zeros_like(dforces) - all_size = self.optimizable.elem_per_group - - for i in range(self.optimizable.batch_size): - if self.H[i] is None: - continue - mask = (i==batch_indices_flatten) - if torch.abs(dpos[mask]).max() < 1e-7: - continue - - dg[mask] = self.H[i] @ dpos[mask] - - a = self._batched_dot_1d(dforces, dpos) - b = self._batched_dot_1d(dpos, dg) - - for i in range(self.optimizable.batch_size): - if self.H[i] is None: - self.H[i] = torch.eye(3*all_size[i], device=self.device, dtype=torch.float64) * self.alpha - continue - mask = (i==batch_indices_flatten) - if not self.optimizable.update_mask[i]: - continue - if torch.abs(dpos[mask]).max() < 1e-7: - continue - - outer_force = torch.outer(dforces[mask], dforces[mask]) - outer_dg = torch.outer(dg[mask], dg[mask]) - self.H[i] -= outer_force / a[i] + outer_dg / b[i] - - - - def update_parallel(self, pos, forces, pos0, forces0): - if self.H is None: - self.H = self.H0 - return - - dpos = pos - pos0 - - if torch.abs(dpos).max() < 1e-7: - return - - dforces = forces - forces0 - cur_indices = self.optimizable.batch_indices.repeat_interleave(3) - a = self._batched_dot_1d(dforces, dpos) - # DONE: There is a bug using hstack. - # dg = torch.hstack([self.H[i] @ dpos[cur_indices == i] for i in range(len(self.H))]) - # DONE: parallel this part - # dg_list = [self.H[i] @ dpos[cur_indices == i] for i in range(len(self.H))] - dg_list = [None] * len(self.H) - streams = [torch.cuda.Stream() for _ in dg_list] - for i, stream in zip(range(len(dg_list)), streams): - with torch.cuda.stream(stream): - dg_list[i] = self.H[i] @ dpos[cur_indices == i] - - torch.cuda.current_stream().synchronize() - for stream in streams: - stream.synchronize() - - dg = torch.zeros_like(dforces) - for i in torch.unique(cur_indices): - mask = (cur_indices == i) - dg[mask] = dg_list[i] - b = self._batched_dot_1d(dpos, dg) - - # DONE: parallel this part - for i, stream in zip(range(len(self.H)), streams): - if not self.optimizable.update_mask[i]: - continue - with torch.cuda.stream(stream): - outer_force = torch.outer(dforces[cur_indices==i], dforces[cur_indices==i]) - outer_dg = torch.outer(dg[cur_indices==i], dg[cur_indices==i]) - self.H[i] -= outer_force / a[i] + outer_dg / b[i] - - torch.cuda.current_stream().synchronize() - for stream in streams: - stream.synchronize() - - - def _batched_dot_2d(self, x: torch.Tensor, y: torch.Tensor): - return scatter( - (x * y).sum(dim=-1), self.optimizable.batch_indices, reduce="sum" - ) - - def _batched_dot_1d(self, x: torch.Tensor, y: torch.Tensor): - return scatter( - (x * y), self.optimizable.batch_indices.repeat_interleave(3), reduce="sum" + +""" +Copyright (c) 2025 Ma Zhaojia + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +from __future__ import annotations + +import logging +import torch +from torch_scatter import scatter + +from ..optimizable import OptimizableBatch + +class BFGS: + def __init__( + self, + optimizable_batch: OptimizableBatch, + maxstep: float = 0.2, + alpha: float = 70.0, + early_stop = False, + ) -> None: + """ + Args: + """ + self.optimizable = optimizable_batch + self.maxstep = maxstep + self.alpha = alpha + # self.H0 = 1.0 / self.alpha + self.trajectories = None + self.device=self.optimizable.device + + self.fmax = None + self.steps = None + + self.initialize() + self.early_stop = early_stop + + + def initialize(self): + # initial hessian + self.H0 = [ + torch.eye(3 * size, device=self.optimizable.device, dtype=torch.float64) * self.alpha + for size in self.optimizable.elem_per_group + ] + + self.H = [None] * self.optimizable.batch_size + self.pos0 = torch.zeros_like(self.optimizable.get_positions().reshape(-1), device=self.device, dtype=torch.float64) + self.forces0 = torch.zeros_like(self.pos0, device=self.device, dtype=torch.float64) + + def restart_from_earlystop(self, restart_indices, old_batch_indices): + H_new = [] + pos0_new = torch.zeros_like(self.optimizable.get_positions().reshape(-1), device=self.device, dtype=torch.float64) + forces0_new = torch.zeros_like(pos0_new, device=self.device, dtype=torch.float64) + + # collect the preserved historical data by old_batch_indices + for i, idx in enumerate(restart_indices): + mask_old = (idx==old_batch_indices.repeat_interleave(3)) + mask = (i==self.optimizable.batch_indices.repeat_interleave(3)) + H_new.append(self.H[idx]) + pos0_new[mask] = self.pos0[mask_old] + forces0_new[mask] = self.forces0[mask_old] + + # append new info for the new batch + for i in range(len(H_new), self.optimizable.batch_size): + H_new.append(None) + + self.H = H_new + self.pos0 = pos0_new + self.forces0 = forces0_new + + + def run(self, fmax, maxstep, is_restart_earlystop=False, restart_indices=None, old_batch_indices=None): + logging.info("Enter bfgs's main program.") + self.fmax = fmax + self.max_iter = maxstep + + if is_restart_earlystop: + self.restart_from_earlystop(restart_indices, old_batch_indices) + + iteration = 0 + max_forces = self.optimizable.get_max_forces(apply_constraint=True) + logging.info("Step Fmax(eV/A)") + + while iteration < self.max_iter and not self.optimizable.converged( + forces=None, fmax=self.fmax, max_forces=max_forces, f_upper_limit=1e25, + ): + if self.early_stop and iteration > 0: + converge_indices = self.optimizable.converge_indices_list + if len(converge_indices) > 0: + logging.info(f"Early stopping at iteration {iteration}") + break + + logging.info( + f"{iteration} " + " ".join(f"{x:18.15g}" for x in max_forces.tolist()) + ) + + self.step() + max_forces = self.optimizable.get_max_forces(apply_constraint=True) + iteration += 1 + + logging.info( + f"{iteration} " + " ".join(f"{x:18.15g}" for x in max_forces.tolist()) + ) + + # GPU memory usage as per nvidia-smi seems to gradually build up as + # batches are processed. This releases unoccupied cached memory. + torch.cuda.empty_cache() + + # set predicted values to batch + for name, value in self.optimizable.results.items(): + setattr(self.optimizable.batch, name, value) + + self.nsteps = iteration + + if self.early_stop: + converge_indices_list = self.optimizable.converge_indices_list + return converge_indices_list + else: + return self.optimizable.converged( + forces=None, fmax=self.fmax, max_forces=max_forces + ) + + + def step(self): + forces = self.optimizable.get_forces(apply_constraint=True).to( + dtype=torch.float64 + ) + pos = self.optimizable.get_positions().to(dtype=torch.float64) + dpos, steplengths = self.prepare_step(pos, forces) + dpos = self.determine_step(dpos, steplengths) + self.optimizable.set_positions(pos+dpos) + + + def prepare_step(self, pos, forces): + forces = forces.reshape(-1) + pos = pos.view(-1) + self.update(pos, forces, self.pos0, self.forces0) + + dpos_list = [] + cur_indices = self.optimizable.batch_indices.repeat_interleave(3) + # 预初始化结果列表 + dpos_list = [None] * len(self.H) + + # 分离计算任务:仅对需要计算的H矩阵创建流 + calc_indices = [i for i, need_update in enumerate(self.optimizable.update_mask) if need_update] + streams = [torch.cuda.Stream() for _ in calc_indices] + + # 并行执行实际计算 + for i, stream in zip(calc_indices, streams): + with torch.cuda.stream(stream): + omega, V = torch.linalg.eigh(self.H[i]) + dpos_list[i] = (V @ (forces[cur_indices==i].t() @ V / torch.abs(omega)).t()) + + # 同步所有计算流 + torch.cuda.current_stream().synchronize() + + # 在主线程处理零张量 + for i in range(len(self.H)): + if not self.optimizable.update_mask[i]: + dpos_list[i] = torch.zeros_like(forces[cur_indices==i]) + + # 同步所有流 + for stream in streams: + stream.synchronize() + + # dpos = torch.vstack(dpos_list) + dpos = torch.zeros_like(forces) + for i in torch.unique(cur_indices): + mask = (cur_indices == i) + dpos[mask] = dpos_list[i] + dpos = dpos.reshape(-1, 3) + + steplengths = (dpos ** 2).sum(dim=-1).sqrt() + self.pos0 = pos + self.forces0 = forces + + return dpos, steplengths + + + def determine_step(self, dpos, steplengths): + longest_steps = scatter( + steplengths, self.optimizable.batch_indices, reduce="max" + ) + longest_steps = longest_steps[self.optimizable.batch_indices] + maxstep = longest_steps.new_tensor(self.maxstep) + scale = (longest_steps).reciprocal() * torch.min(longest_steps, maxstep) + dpos *= scale.unsqueeze(1) + return dpos + + def update(self, pos, forces, pos0, forces0): + if self.H is None: + self.H = self.H0 + return + dpos = pos - pos0 + dforces = forces - forces0 + batch_indices_flatten = self.optimizable.batch_indices.repeat_interleave(3) + dg = torch.zeros_like(dforces) + all_size = self.optimizable.elem_per_group + + for i in range(self.optimizable.batch_size): + if self.H[i] is None: + continue + mask = (i==batch_indices_flatten) + if torch.abs(dpos[mask]).max() < 1e-7: + continue + + dg[mask] = self.H[i] @ dpos[mask] + + a = self._batched_dot_1d(dforces, dpos) + b = self._batched_dot_1d(dpos, dg) + + for i in range(self.optimizable.batch_size): + if self.H[i] is None: + self.H[i] = torch.eye(3*all_size[i], device=self.device, dtype=torch.float64) * self.alpha + continue + mask = (i==batch_indices_flatten) + if not self.optimizable.update_mask[i]: + continue + if torch.abs(dpos[mask]).max() < 1e-7: + continue + + outer_force = torch.outer(dforces[mask], dforces[mask]) + outer_dg = torch.outer(dg[mask], dg[mask]) + self.H[i] -= outer_force / a[i] + outer_dg / b[i] + + + + def update_parallel(self, pos, forces, pos0, forces0): + if self.H is None: + self.H = self.H0 + return + + dpos = pos - pos0 + + if torch.abs(dpos).max() < 1e-7: + return + + dforces = forces - forces0 + cur_indices = self.optimizable.batch_indices.repeat_interleave(3) + a = self._batched_dot_1d(dforces, dpos) + # DONE: There is a bug using hstack. + # dg = torch.hstack([self.H[i] @ dpos[cur_indices == i] for i in range(len(self.H))]) + # DONE: parallel this part + # dg_list = [self.H[i] @ dpos[cur_indices == i] for i in range(len(self.H))] + dg_list = [None] * len(self.H) + streams = [torch.cuda.Stream() for _ in dg_list] + for i, stream in zip(range(len(dg_list)), streams): + with torch.cuda.stream(stream): + dg_list[i] = self.H[i] @ dpos[cur_indices == i] + + torch.cuda.current_stream().synchronize() + for stream in streams: + stream.synchronize() + + dg = torch.zeros_like(dforces) + for i in torch.unique(cur_indices): + mask = (cur_indices == i) + dg[mask] = dg_list[i] + b = self._batched_dot_1d(dpos, dg) + + # DONE: parallel this part + for i, stream in zip(range(len(self.H)), streams): + if not self.optimizable.update_mask[i]: + continue + with torch.cuda.stream(stream): + outer_force = torch.outer(dforces[cur_indices==i], dforces[cur_indices==i]) + outer_dg = torch.outer(dg[cur_indices==i], dg[cur_indices==i]) + self.H[i] -= outer_force / a[i] + outer_dg / b[i] + + torch.cuda.current_stream().synchronize() + for stream in streams: + stream.synchronize() + + + def _batched_dot_2d(self, x: torch.Tensor, y: torch.Tensor): + return scatter( + (x * y).sum(dim=-1), self.optimizable.batch_indices, reduce="sum" + ) + + def _batched_dot_1d(self, x: torch.Tensor, y: torch.Tensor): + return scatter( + (x * y), self.optimizable.batch_indices.repeat_interleave(3), reduce="sum" ) \ No newline at end of file diff --git a/mace-bench/src/batchopt/relaxation/optimizers/bfgsfusedls.py b/mace-bench/src/batchopt/relaxation/optimizers/bfgsfusedls.py index 522de6d..9dc221b 100644 --- a/mace-bench/src/batchopt/relaxation/optimizers/bfgsfusedls.py +++ b/mace-bench/src/batchopt/relaxation/optimizers/bfgsfusedls.py @@ -1,993 +1,993 @@ -""" -Copyright (c) 2025 Ma Zhaojia - -This source code is licensed under the MIT license found in the -LICENSE file in the root directory of this source tree. -""" - -from __future__ import annotations -import logging -import torch -from torch_scatter import scatter -# from .linesearch_torch import LineSearchBatch -from ..optimizable import OptimizableBatch -from torch.profiler import profile, record_function, ProfilerActivity, schedule, tensorboard_trace_handler -from datetime import datetime -import os -import math -import gc - -class BFGSFusedLS: - """ - Port of BFGSLineSearch from bfgslinesearch.py, adapted to PyTorch - and batched operations, mirroring lbfgs_torch.py structure. - """ - def __init__( - self, - optimizable_batch: OptimizableBatch, - maxstep: float = 0.2, - c1: float = 0.23, - c2: float = 0.46, - alpha: float = 10.0, - stpmax: float = 50.0, - device = 'cpu', - early_stop: bool = False, - use_profiler: bool = False, - profiler_log_dir: str = './log', - profiler_schedule_config: dict = None, - dtype: torch.dtype = torch.float64, - ): - self.optimizable = optimizable_batch - self.maxstep = maxstep - self.c1 = c1 - self.c2 = c2 - self.alpha = alpha - self.stpmax = stpmax - self.nsteps = 0 - self.device = device - self.force_calls = 0 - self.early_stop = early_stop - self.use_profiler = use_profiler - self.profiler_log_dir = profiler_log_dir - self.profiler_schedule_config = profiler_schedule_config or {"wait": 48, "warmup": 1, "active": 1, "repeat": 8} - self.dtype = dtype - - self.converge_indices_list = None - - # The information from the previous round is useful for the current round's calculations. - ## These variables need to be update accroding to new input when eary stop is triggered. - self.Hs = None - self.r0 = None - self.g0 = None - self.p_list = [None] * self.optimizable.batch_size - self.no_update_list = [False] * self.optimizable.batch_size - self.ls_completed = [True] * self.optimizable.batch_size - self.ls_batch = LineSearchBatch(self.optimizable.batch_indices, device="cpu", dtype=self.dtype) - ## need to be recalculate when early stop is triggered - self.forces = None - self.energies = None - - def restart_from_earlystop(self, restart_indices, old_batch_indices): - Hs_new = [] - r0_new = torch.zeros_like(self.optimizable.get_positions().reshape(-1), device=self.device) - g0_new = torch.zeros_like(r0_new, device=self.device) - p_list_new = [] - no_update_list_new = [] - ls_completed_new = [] - - # collect the preserved historical info by old_indices - for i, idx in enumerate(restart_indices): - mask_old = (idx==old_batch_indices.repeat_interleave(3)) - mask = (i==self.optimizable.batch_indices.repeat_interleave(3)) - Hs_new.append(self.Hs[idx]) - p_list_new.append(self.p_list[idx]) - no_update_list_new.append(self.no_update_list[idx]) - ls_completed_new.append(self.ls_completed[idx]) - r0_new[mask] = self.r0[mask_old] - g0_new[mask] = self.g0[mask_old] - - # append new info for new element in batch - for i in range(len(Hs_new), self.optimizable.batch_size): - # Hs_new.append(torch.eye(3 * self.optimizable.elem_per_group[i], device=self.device, dtype=torch.float64)) - Hs_new.append(None) - p_list_new.append(None) - no_update_list_new.append(False) - ls_completed_new.append(True) - - self.Hs = Hs_new - self.r0 = r0_new - self.g0 = g0_new - self.p_list = p_list_new - self.no_update_list = no_update_list_new - self.ls_completed = ls_completed_new - self.forces = None - self.energies = None - self.ls_batch.restart_from_earlystop(restart_indices=restart_indices, batch_indices_new=self.optimizable.batch_indices) - - def step(self): - optimizable = self.optimizable - if self.forces is None: - self.forces = optimizable.get_forces().to(self.device) - r = optimizable.get_positions().reshape(-1).to(self.device) - g = -self.forces.reshape(-1) / self.alpha - p0_list = self.p_list - self.update(r, g, self.r0, self.g0, p0_list) - if self.energies is None: - self.energies = self.func(r) - - for i in range(self.optimizable.batch_size): - if self.ls_completed[i]: - p = -torch.matmul(self.Hs[i], g[i==self.optimizable.batch_indices.repeat_interleave(3)]) - - # Implement scaling for numerical stability with simpler calculation - p_size = torch.sqrt((p**2).sum()) - min_size = torch.sqrt(self.optimizable.elem_per_group[i] * 1e-10) - if p_size <= min_size: - p = p * (min_size / p_size) - - self.p_list[i] = p - - # ls_batch = LineSearchBatch(self.optimizable.batch_indices, device="cpu") - continue_search = [not elem for elem in self.ls_completed] - self.alpha_k_list, self.e_list, self.e0_list, self.no_update_list, self.ls_completed = self.ls_batch._linesearch_batch( - self.func, self.fprime, r, self.p_list, g, self.energies, None, - maxstep=self.maxstep, c1=self.c1, c2=self.c2, stpmax=self.stpmax, continue_search=continue_search - ) - - # reset device for linesearch result - for i in range(self.optimizable.batch_size): - if self.ls_completed[i]: - self.alpha_k_list[i] = self.alpha_k_list[i].to(self.device) - self.p_list[i] = self.p_list[i].to(self.device) - - dr_tensor = torch.zeros_like(r) - - - for i in range(self.optimizable.batch_size): - # if check_cache: - # mask = (i == self.optimizable.batch_indices.repeat_interleave(3)) - # dr_tensor_all[mask] = self.alpha_k_list[i].to(self.device) * self.p_list[i].to(self.device) - - if not self.ls_completed[i]: - continue - if self.alpha_k_list[i] is None: - raise RuntimeError("LineSearch failed!") - - mask = (i == self.optimizable.batch_indices.repeat_interleave(3)) - dr_tensor[mask] = self.alpha_k_list[i] * self.p_list[i] - - # if check_cache: - # cached_pos = optimizable.get_positions().reshape(-1).to(self.device) - # update_pos = r + dr_tensor_all - # assert torch.allclose(update_pos, cached_pos), "dr_tensor_cached should be equal to dr_tensor" - - - # TODO: get_forces/get_potential_energies will trigger compare_batch which is time-consuming - forces_cache = optimizable.get_forces() - energies_cache = self.optimizable.get_potential_energies() / self.alpha - - # update self.forces - for i in range(self.optimizable.batch_size): - if not self.ls_completed[i]: - continue - mask = (i == self.optimizable.batch_indices) - self.forces[mask] = forces_cache[mask] - self.energies[i] = energies_cache[i] - - optimizable.set_positions((r + dr_tensor).reshape(-1, 3)) - - self.r0 = r - self.g0 = g - - # @torch.compile - def update(self, r, g, r0, g0, p0_list): - all_sizes = self.optimizable.elem_per_group - - if self.Hs is None: - self.Hs = [ - torch.eye(3 * sz, device=self.device, dtype=self.dtype) - for sz in all_sizes - ] - return - - dr = r - r0 - dg = g - g0 - - for i in range(self.optimizable.batch_size): - if self.Hs[i] is None: - self.Hs[i] = torch.eye(3 * all_sizes[i], device=self.optimizable.device, dtype=self.dtype) - continue - if not self.ls_completed[i]: - continue - if self.no_update_list[i] is True: - print('skip update') - continue - - cur_mask = (i == self.optimizable.batch_indices.repeat_interleave(3)) - cur_g = g[cur_mask] - cur_p0 = p0_list[i] - cur_g0 = g0[cur_mask] - cur_dg = dg[cur_mask] - cur_dr = dr[cur_mask] - - if not (((self.alpha_k_list[i] or 0) > 0 and - abs(torch.dot(cur_g, cur_p0)) - abs(torch.dot(cur_g0, cur_p0)) < 0) or False): - continue - - try: - rhok = 1.0 / (torch.dot(cur_dg, cur_dr)) - except: - rhok = 1000.0 - print("Divide-by-zero encountered: rhok assumed large") - if torch.isinf(rhok): - rhok = 1000.0 - print("Divide-by-zero encountered: rhok assumed large") - I = torch.eye(all_sizes[i]*3, device=self.device, dtype=self.dtype) - A1 = I - cur_dr[:, None] * cur_dg[None, :] * rhok - A2 = I - cur_dg[:, None] * cur_dr[None, :] * rhok - self.Hs[i] = (torch.matmul(A1, torch.matmul(self.Hs[i], A2)) + - rhok * cur_dr[:, None] * cur_dr[None, :]) - - - # def update(self, r, g, r0, g0, p0_list): - # self.Is = [ - # torch.eye(sz * 3, dtype=torch.float64, device=self.device) - # for sz in self.optimizable.elem_per_group - # ] - - # # TODO: BFGS for loop 是不是在被打断之后需要重建这个 self.Hs? - # # TODO: 并且我们保存的上一次的r,g,r0,g0也被丢弃了 - # if self.Hs is None: - # self.Hs = [ - # torch.eye(3 * sz, device=self.optimizable.device, dtype=torch.float64) - # for sz in self.optimizable.elem_per_group - # ] - # return - # else: - # dr = r - r0 - # dg = g - g0 - - # for i in range(self.optimizable.batch_size): - # if not self.ls_completed[i]: - # continue - # cur_mask = (i==self.optimizable.batch_indices.repeat_interleave(3)) - # cur_g = g[cur_mask] - # cur_p0 = p0_list[i] - # cur_g0 = g0[cur_mask] - # cur_dg = dg[cur_mask] - # cur_dr = dr[cur_mask] - - # if not (((self.alpha_k_list[i] or 0) > 0 and - # abs(torch.dot(cur_g, cur_p0)) - abs(torch.dot(cur_g0, cur_p0)) < 0) or False): - # break - - # if self.no_update_list[i] is True: - # print('skip update') - # break - - # try: - # rhok = 1.0 / (torch.dot(cur_dg, cur_dr)) - # except: - # rhok = 1000.0 - # print("Divide-by-zero encountered: rhok assumed large") - # if torch.isinf(rhok): - # rhok = 1000.0 - # print("Divide-by-zero encountered: rhok assumed large") - # A1 = self.Is[i] - cur_dr[:, None] * cur_dg[None, :] * rhok - # A2 = self.Is[i] - cur_dg[:, None] * cur_dr[None, :] * rhok - # self.Hs[i] = (torch.matmul(A1, torch.matmul(self.Hs[i], A2)) + - # rhok * cur_dr[:, None] * cur_dr[None, :]) - - - - def func(self, x): - self.optimizable.set_positions(x.reshape(-1, 3).to(self.device)) - return self.optimizable.get_potential_energies() / self.alpha - - def fprime(self, x): - self.optimizable.set_positions(x.reshape(-1, 3).to(self.device)) - - self.force_calls += 1 - forces = self.optimizable.get_forces().reshape(-1) - return - forces / self.alpha - - def run(self, fmax, maxstep, is_restart_earlystop=False, restart_indices=None, old_batch_indices=None): - logging.info("Enter bfgsfusedlinesearch's main program.") - self.fmax = fmax - self.max_iter = maxstep - - if is_restart_earlystop: - self.restart_from_earlystop(restart_indices, old_batch_indices) - - iteration = 0 - max_forces = self.optimizable.get_max_forces(apply_constraint=True) - logging.info("Step Fmax(eV/A)") - - # Run with profiler if enabled - if self.use_profiler: - activities = [ProfilerActivity.CPU] - if torch.cuda.is_available(): - activities.append(ProfilerActivity.CUDA) - - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - pid = os.getpid() - with torch.profiler.profile( - activities=activities, - schedule=torch.profiler.schedule( - wait=self.profiler_schedule_config["wait"], - warmup=self.profiler_schedule_config["warmup"], - active=self.profiler_schedule_config["active"], - repeat=self.profiler_schedule_config["repeat"] - ), - on_trace_ready=tensorboard_trace_handler(self.profiler_log_dir, worker_name=f"BFGSLS_{pid}"), - with_stack=True, - profile_memory=True, - ) as prof: - # Main optimization loop with profiling - while iteration < self.max_iter and not self.optimizable.converged( - forces=None, fmax=self.fmax, max_forces=max_forces, f_upper_limit=1e25, - ): - if self.early_stop and iteration > 0: - self.converge_indices_list = self.optimizable.converge_indices_list - if len(self.converge_indices_list) > 0: - logging.info(f"Early stopping at iteration {iteration}") - break - - logging.info( - f"{iteration} " + " ".join(f"{x:18.15g}" for x in max_forces.tolist()) - ) - - self.step() - max_forces = self.optimizable.get_max_forces(apply_constraint=True, forces=self.forces) - iteration += 1 - - # Step the profiler in each iteration - prof.step() - - else: - # Original optimization loop without profiling - while iteration < self.max_iter and not self.optimizable.converged( - forces=None, fmax=self.fmax, max_forces=max_forces, f_upper_limit=1e25, - ): - if self.early_stop and iteration > 0: - self.converge_indices_list = self.optimizable.converge_indices_list - if len(self.converge_indices_list) > 0: - logging.info(f"Early stopping at iteration {iteration}") - break - - logging.info( - f"{iteration} " + " ".join(f"{x:18.15g}" for x in max_forces.tolist()) - ) - - self.step() - max_forces = self.optimizable.get_max_forces(apply_constraint=True, forces=self.forces) - iteration += 1 - - logging.info( - f"{iteration} " + " ".join(f"{x:18.15g}" for x in max_forces.tolist()) - ) - - # GPU memory usage as per nvidia-smi seems to gradually build up as - # batches are processed. This releases unoccupied cached memory. - torch.cuda.empty_cache() - gc.collect() - - # set predicted values to batch - for name, value in self.optimizable.results.items(): - setattr(self.optimizable.batch, name, value) - - self.nsteps = iteration - - if self.early_stop: - self.converge_indices_list = self.optimizable.converge_indices_list - return self.converge_indices_list - else: - return self.optimizable.converged( - forces=None, fmax=self.fmax, max_forces=max_forces - ) - - def _batched_dot_2d(self, x: torch.Tensor, y: torch.Tensor): - return scatter( - (x * y).sum(dim=-1), self.optimizable.batch_indices, reduce="sum" - ) - - def _batched_dot_1d(self, x: torch.Tensor, y: torch.Tensor): - return scatter( - (x * y), self.optimizable.batch_indices.repeat_interleave(3), reduce="sum" - ) - -# flake8: noqa -import math -import torch -import logging - -pymin = min -pymax = max - - -class LineSearch: - def __init__(self, xtol=1e-14, device='cpu', dtype=torch.float64): - self.xtol = xtol - self.task = 'START' - self.device = device - self.dtype = dtype - self.isave = torch.zeros(2, dtype=torch.int64, device=self.device) - self.dsave = torch.zeros(13, dtype=self.dtype, device=self.device) - self.fc = 0 - self.gc = 0 - self.case = 0 - self.old_stp = 0 - - def initialize(self, xk, pk, gfk, old_fval, old_old_fval, - maxstep=.2, c1=.23, c2=0.46, xtrapl=1.1, xtrapu=4., - stpmax=50., stpmin=1e-8): - # Scalar parameters can stay as Python scalars - self.stpmin = stpmin - self.stpmax = stpmax - self.xtrapl = xtrapl - self.xtrapu = xtrapu - self.maxstep = maxstep - - # Move tensors to the device - self.pk = pk.to(self.device) - xk = xk.to(self.device) - gfk = gfk.to(self.device) - - phi0 = old_fval - - - # This dot product needs tensors - derphi0 = torch.dot(gfk, self.pk).item() - - # Use Python math for scalar calculations - self.dim = len(pk) - self.gms = math.sqrt(self.dim) * maxstep - - alpha1 = 1.0 - self.no_update = False - self.gradient = True - - self.steps = [] - return alpha1, phi0, derphi0 - - def prologue(self, fval, gval, pk_tensor, alpha1): - phi0 = fval - derphi0 = torch.dot(gval, pk_tensor) - self.old_stp = alpha1 - # TODO: self.no_update == True: break is needed to reimplemented. - - return phi0, derphi0 - - def epilogue(self): - pass - - def _line_search(self, func, myfprime, xk, pk, gfk, old_fval, old_old_fval, - maxstep=.2, c1=.23, c2=0.46, xtrapl=1.1, xtrapu=4., - stpmax=50., stpmin=1e-8, args=()): - self.stpmin = stpmin - self.pk = pk.to(self.device) - self.stpmax = stpmax - self.xtrapl = xtrapl - self.xtrapu = xtrapu - self.maxstep = maxstep - - xk = xk.to(self.device) - - # Convert inputs to torch tensors if they're not already - if not isinstance(old_fval, torch.Tensor): - phi0 = torch.tensor(old_fval, dtype=self.dtype, device=self.device) - else: - phi0 = old_fval.to(self.device) - - # Ensure pk and gfk are torch tensors - pk_tensor = torch.tensor(pk, dtype=self.dtype, device=self.device) if not isinstance(pk, torch.Tensor) else pk.to(self.device) - gfk_tensor = torch.tensor(gfk, dtype=self.dtype, device=self.device) if not isinstance(gfk, torch.Tensor) else gfk.to(self.device) - - derphi0 = torch.dot(gfk_tensor, pk_tensor) - self.dim = len(pk) - self.gms = torch.sqrt(torch.tensor(self.dim, dtype=self.dtype, device=self.device)) * maxstep - alpha1 = 1. - self.no_update = False - - if isinstance(myfprime, tuple): - fprime = myfprime[0] - gradient = False - else: - fprime = myfprime - newargs = args - gradient = True - - fval = phi0 - gval = gfk_tensor - self.steps = [] - - while True: - stp = self.step(alpha1, phi0, derphi0, c1, c2, - self.xtol, - self.isave, self.dsave) - - if self.task[:2] == 'FG': - alpha1 = stp - - # Get function value and gradient - x_new = xk + stp * pk_tensor - fval = func(x_new).to(self.device) - self.fc += 1 - - gval = fprime(x_new).to(self.device) - if gradient: - self.gc += 1 - else: - self.fc += len(xk) + 1 - - phi0 = fval - derphi0 = torch.dot(gval, pk_tensor) - self.old_stp = alpha1 - - if self.no_update == True: - break - else: - break - - if self.task[:5] == 'ERROR' or self.task[1:4] == 'WARN': - stp = None # failed - - return stp, fval.item(), old_fval.item() if isinstance(old_fval, torch.Tensor) else old_fval, self.no_update - - def step(self, stp, f, g, c1, c2, xtol, isave, dsave): - if self.task[:5] == 'START': - # Check the input arguments for errors. - if stp < self.stpmin: - self.task = 'ERROR: STP .LT. minstep' - if stp > self.stpmax: - self.task = 'ERROR: STP .GT. maxstep' - if g >= 0: - self.task = 'ERROR: INITIAL G >= 0' - if c1 < 0: - self.task = 'ERROR: c1 .LT. 0' - if c2 < 0: - self.task = 'ERROR: c2 .LT. 0' - if xtol < 0: - self.task = 'ERROR: XTOL .LT. 0' - if self.stpmin < 0: - self.task = 'ERROR: minstep .LT. 0' - if self.stpmax < self.stpmin: - self.task = 'ERROR: maxstep .LT. minstep' - if self.task[:5] == 'ERROR': - return stp - - # Initialize local variables. - self.bracket = False - stage = 1 - finit = f - ginit = g - gtest = c1 * ginit - width = self.stpmax - self.stpmin - width1 = width / .5 - - # The variables stx, fx, gx contain the values of the step, - # function, and derivative at the best step. - # The variables sty, fy, gy contain the values of the step, - # function, and derivative at sty. - # The variables stp, f, g contain the values of the step, - # function, and derivative at stp. - stx = 0.0 - fx = finit - gx = ginit - sty = 0.0 - fy = finit - gy = ginit - stmin = 0.0 - stmax = stp + self.xtrapu * stp - self.task = 'FG' - self.save((stage, ginit, gtest, gx, - gy, finit, fx, fy, stx, sty, - stmin, stmax, width, width1)) - stp = self.determine_step(stp) - return stp - else: - if self.isave[0] == 1: - self.bracket = True - else: - self.bracket = False - stage = self.isave[1] - (ginit, gtest, gx, gy, finit, fx, fy, stx, sty, stmin, stmax, - width, width1) = self.dsave - - # If psi(stp) <= 0 and f'(stp) >= 0 for some step, then the - # algorithm enters the second stage. - ftest = finit + stp * gtest - if stage == 1 and f < ftest and g >= 0.: - stage = 2 - - # Test for warnings. - if self.bracket and (stp <= stmin or stp >= stmax): - self.task = 'WARNING: ROUNDING ERRORS PREVENT PROGRESS' - if self.bracket and stmax - stmin <= self.xtol * stmax: - self.task = 'WARNING: XTOL TEST SATISFIED' - if stp == self.stpmax and f <= ftest and g <= gtest: - self.task = 'WARNING: STP = maxstep' - if stp == self.stpmin and (f > ftest or g >= gtest): - self.task = 'WARNING: STP = minstep' - - # Test for convergence. - # if f <= ftest and abs(g) <= c2 * (- ginit): - # self.task = 'CONVERGENCE' - if (f < ftest or math.isclose(f, ftest, rel_tol=1e-6, abs_tol=1e-5)) and (abs(g) < c2 * (- ginit) or math.isclose(abs(g), c2 * (- ginit), rel_tol=1e-6, abs_tol=1e-5)): - self.task = 'CONVERGENCE' - - # Test for termination. - if self.task[:4] == 'WARN' or self.task[:4] == 'CONV': - self.save((stage, ginit, gtest, gx, - gy, finit, fx, fy, stx, sty, - stmin, stmax, width, width1)) - return stp - - stx, sty, stp, gx, fx, gy, fy = self.update(stx, fx, gx, sty, - fy, gy, stp, f, g, - stmin, stmax) - - # Decide if a bisection step is needed. - if self.bracket: - if abs(sty - stx) >= .66 * width1: - stp = stx + .5 * (sty - stx) - width1 = width - width = abs(sty - stx) - - # Set the minimum and maximum steps allowed for stp. - if self.bracket: - stmin = min(stx, sty) - stmax = max(stx, sty) - else: - stmin = stp + self.xtrapl * (stp - stx) - stmax = stp + self.xtrapu * (stp - stx) - - # Force the step to be within the bounds maxstep and minstep. - stp = max(stp, self.stpmin) - stp = min(stp, self.stpmax) - - if (stx == stp and stp == self.stpmax and stmin > self.stpmax): - self.no_update = True - - # If further progress is not possible, let stp be the best - # point obtained during the search. - if (self.bracket and stp < stmin or stp >= stmax) \ - or (self.bracket and stmax - stmin < self.xtol * stmax): - stp = stx - - # Obtain another function and derivative. - self.task = 'FG' - self.save((stage, ginit, gtest, gx, - gy, finit, fx, fy, stx, sty, - stmin, stmax, width, width1)) - return stp - - def update(self, stx, fx, gx, sty, fy, gy, stp, fp, gp, - stpmin, stpmax): - sign = gp * (gx / abs(gx)) - - # First case: A higher function value. The minimum is bracketed. - # If the cubic step is closer to stx than the quadratic step, the - # cubic step is taken, otherwise the average of the cubic and - # quadratic steps is taken. - if fp > fx: # case1 - self.case = 1 - theta = 3. * (fx - fp) / (stp - stx) + gx + gp - s = max(max(abs(theta), abs(gx)), abs(gp)) - gamma = s * math.sqrt((theta / s) ** 2. - (gx / s) * (gp / s)) - if stp < stx: - gamma = -gamma - p = (gamma - gx) + theta - q = ((gamma - gx) + gamma) + gp - r = p / q - stpc = stx + r * (stp - stx) - stpq = stx + ((gx / ((fx - fp) / (stp - stx) + gx)) / 2.) \ - * (stp - stx) - if (abs(stpc - stx) < abs(stpq - stx)): - stpf = stpc - else: - stpf = stpc + (stpq - stpc) / 2. - - self.bracket = True - - # Second case: A lower function value and derivatives of opposite - # sign. The minimum is bracketed. If the cubic step is farther from - # stp than the secant step, the cubic step is taken, otherwise the - # secant step is taken. - elif sign < 0: # case2 - self.case = 2 - theta = 3. * (fx - fp) / (stp - stx) + gx + gp - s = max(max(abs(theta), abs(gx)), abs(gp)) - gamma = s * math.sqrt((theta / s) ** 2 - (gx / s) * (gp / s)) - if stp > stx: - gamma = -gamma - p = (gamma - gp) + theta - q = ((gamma - gp) + gamma) + gx - r = p / q - stpc = stp + r * (stx - stp) - stpq = stp + (gp / (gp - gx)) * (stx - stp) - if (abs(stpc - stp) > abs(stpq - stp)): - stpf = stpc - else: - stpf = stpq - self.bracket = True - - # Third case: A lower function value, derivatives of the same sign, - # and the magnitude of the derivative decreases. - elif abs(gp) < abs(gx): # case3 - self.case = 3 - # The cubic step is computed only if the cubic tends to infinity - # in the direction of the step or if the minimum of the cubic - # is beyond stp. Otherwise the cubic step is defined to be the - # secant step. - theta = 3. * (fx - fp) / (stp - stx) + gx + gp - s = max(max(abs(theta), abs(gx)), abs(gp)) - - # The case gamma = 0 only arises if the cubic does not tend - # to infinity in the direction of the step. - gamma = s * math.sqrt(max(0., (theta / s) ** 2 - (gx / s) * (gp / s))) - if stp > stx: - gamma = -gamma - p = (gamma - gp) + theta - q = (gamma + (gx - gp)) + gamma - r = p / q - if r < 0. and gamma != 0: - stpc = stp + r * (stx - stp) - elif stp > stx: - stpc = stpmax - else: - stpc = stpmin - stpq = stp + (gp / (gp - gx)) * (stx - stp) - - if self.bracket: - # A minimizer has been bracketed. If the cubic step is - # closer to stp than the secant step, the cubic step is - # taken, otherwise the secant step is taken. - if abs(stpc - stp) < abs(stpq - stp): - stpf = stpc - else: - stpf = stpq - if stp > stx: - stpf = min(stp + .66 * (sty - stp), stpf) - else: - stpf = max(stp + .66 * (sty - stp), stpf) - else: - # A minimizer has not been bracketed. If the cubic step is - # farther from stp than the secant step, the cubic step is - # taken, otherwise the secant step is taken. - if abs(stpc - stp) > abs(stpq - stp): - stpf = stpc - else: - stpf = stpq - stpf = min(stpmax, stpf) - stpf = max(stpmin, stpf) - - # Fourth case: A lower function value, derivatives of the same sign, - # and the magnitude of the derivative does not decrease. If the - # minimum is not bracketed, the step is either minstep or maxstep, - # otherwise the cubic step is taken. - else: # case4 - self.case = 4 - if self.bracket: - theta = 3. * (fp - fy) / (sty - stp) + gy + gp - s = max(max(abs(theta), abs(gy)), abs(gp)) - gamma = s * math.sqrt((theta / s) ** 2 - (gy / s) * (gp / s)) - if stp > sty: - gamma = -gamma - p = (gamma - gp) + theta - q = ((gamma - gp) + gamma) + gy - r = p / q - stpc = stp + r * (sty - stp) - stpf = stpc - elif stp > stx: - stpf = stpmax - else: - stpf = stpmin - - # Update the interval which contains a minimizer. - if fp > fx: - sty = stp - fy = fp - gy = gp - else: - if sign < 0: - sty = stx - fy = fx - gy = gx - stx = stp - fx = fp - gx = gp - - # Compute the new step. - stp = self.determine_step(stpf) - - return stx, sty, stp, gx, fx, gy, fy - - def determine_step(self, stp): - dr = stp - self.old_stp - x = torch.reshape(self.pk.to(self.device), (-1, 3)) - steplengths = ((dr * x)**2).sum(1)**0.5 - maxsteplength = max(steplengths) - if maxsteplength >= self.maxstep: - dr *= self.maxstep / maxsteplength - stp = self.old_stp + dr - return stp - - def save(self, data): - if self.bracket: - self.isave[0] = 1 - else: - self.isave[0] = 0 - self.isave[1] = data[0] - self.dsave = data[1:] - -class LineSearchBatch: - - def __init__(self, batch_indices, device='cpu', dtype=torch.float64): - self.device = device - self.dtype = dtype - self.batch_indices = batch_indices.to(self.device) - self.batch_indices_flatten = self.batch_indices.repeat_interleave(3).to(self.device) - self.batch_size = len(torch.unique(batch_indices)) - self.linesearch_list = [LineSearch(device=self.device, dtype=self.dtype) for _ in range(self.batch_size)] - self.steps = [1.] * self.batch_size - self.phi0_values = [None] * self.batch_size - self.derphi0_values = [None] * self.batch_size - - def restart_from_earlystop(self, restart_indices, batch_indices_new): - self.batch_indices = batch_indices_new.to(self.device) - self.batch_indices_flatten = self.batch_indices.repeat_interleave(3).to(self.device) - self.batch_size = len(torch.unique(batch_indices_new)) - - linesearch_list_new = [] - steps_new = [] - phi0_values_new = [] - derphi0_values_new = [] - - for i, idx in enumerate(restart_indices): - linesearch_list_new.append(self.linesearch_list[idx]) - steps_new.append(self.steps[idx]) - phi0_values_new.append(self.phi0_values[idx]) - derphi0_values_new.append(self.derphi0_values[idx]) - - for i in range(len(restart_indices), self.batch_size): - linesearch_list_new.append(LineSearch(device=self.device)) - steps_new.append(1.) - phi0_values_new.append(None) - derphi0_values_new.append(None) - - self.linesearch_list = linesearch_list_new - self.steps = steps_new - self.phi0_values = phi0_values_new - self.derphi0_values = derphi0_values_new - - - - def _linesearch_batch(self, func, myfprime, xk, pk, gfk, old_fval, old_old_fval, - maxstep=.2, c1=.23, c2=0.46, xtrapl=1.1, xtrapu=4., - stpmax=50., stpmin=1e-8, continue_search=None, max_iter=15): - if continue_search is None: - self.linesearch_list = [LineSearch(device=self.device) for _ in range(self.batch_size)] - else: - assert len(continue_search) == self.batch_size - for i in range(len(continue_search)): - if not continue_search[i]: - self.linesearch_list[i] = LineSearch(device=self.device) - - if isinstance(xk, torch.Tensor): - xk = xk.to(self.device) - for i in range(len(pk)): - pk[i] = pk[i].to(self.device) - if isinstance(gfk, torch.Tensor): - gfk = gfk.to(self.device) - if isinstance(old_fval, torch.Tensor): - old_fval = old_fval.to(self.device) - if isinstance(old_old_fval, torch.Tensor): - old_old_fval = old_old_fval.to(self.device) - - - # results for each batch element - alpha_results = [] - e_result = [] - e0_result = [] - no_update_result = [] - - # Initialize step sizes and line search state for each batch element - completed = [False] * self.batch_size - - # Initialize iteration counter - iter_count = 0 - - # Initialize all line searches using the initialize method - for i in range(self.batch_size): - if continue_search[i]: - continue - - ls = self.linesearch_list[i] - mask = (i == self.batch_indices_flatten) - - # Use the initialize method to set up line search parameters - alpha1, phi0, derphi0 = ls.initialize( - xk[mask], pk[i], gfk[mask], old_fval[i], old_old_fval, - maxstep, c1, c2, xtrapl, xtrapu, stpmax, stpmin - ) - - # Store the initialization values - self.steps[i] = alpha1 - self.phi0_values[i] = phi0 - self.derphi0_values[i] = derphi0 - - # Main optimization loop - while True: - # 1. step forward - # logging.info(f"step's input: alpha1: {torch.tensor([step.item() if isinstance(step, torch.Tensor) else step for step in self.steps])}") - for i in range(self.batch_size): - if completed[i]: - continue - ls = self.linesearch_list[i] - if ls.fc > max_iter: - completed[i] = True - logging.warning(f"LineSearchBatch[{i}] reached max_iter: {max_iter}") - continue - stp = ls.step(self.steps[i], self.phi0_values[i], self.derphi0_values[i], - c1, c2, ls.xtol, ls.isave, ls.dsave) - if ls.task[:2] == 'FG': - self.steps[i] = stp - else: - completed[i] = True - - # 2. calculate new function value and gradient - x_new_batch = torch.zeros_like(xk) - for i in range(self.batch_size): - mask = (i == self.batch_indices_flatten) - x_new_batch[mask] = xk[mask] + self.steps[i] * pk[i] - f_batch = func(x_new_batch).to(self.device) - g_batch = myfprime(x_new_batch).to(self.device) - - # 3. update function value and gradient - for i in range(self.batch_size): - ls = self.linesearch_list[i] - mask = (i == self.batch_indices_flatten) - if ls.task[:2] == 'FG': - # Update function value and gradient - f_val = f_batch[i:i+1] - g_val = g_batch[mask] - ls.fc += 1 - phi0, derphi0 = ls.prologue(f_val, g_val, pk[i], self.steps[i]) - # logging.info(f"phi0, derphi0: {phi0}, {derphi0}") - self.phi0_values[i] = phi0 - self.derphi0_values[i] = derphi0 # TODO: why we put the derphi0 here instead of set it inside the LineSearch class? - if ls.no_update: - completed[i] = True - else: - completed[i] = True - - iter_count += 1 - logging.info(f"LineSearchBatch iter: {iter_count}: alpha: {torch.tensor([step.item() if isinstance(step, torch.Tensor) else step for step in self.steps])}") - if any(completed): - break - - # 4. set a linesearch upper limit - # if iter_count > max_iter: - # for i in range(self.batch_size): - # completed[i] = True - # logging.warning(f"LineSearchBatch reached max_iter: {max_iter}") - # break - - # Collect results - for i in range(self.batch_size): - ls = self.linesearch_list[i] - if ls.task[:5] == 'ERROR' or ls.task[1:4] == 'WARN': - stp = torch.tensor(1., device=self.device) - else: - stp = self.steps[i] if isinstance(self.steps[i], torch.Tensor) else torch.tensor(self.steps[i], device=self.device) - - alpha_results.append(stp) - e_result.append(self.phi0_values[i].item() if self.phi0_values[i] is not None else None) - e0_result.append(old_fval[i].item() if isinstance(old_fval[i], torch.Tensor) else old_fval[i]) - no_update_result.append(ls.no_update) - - logging.info(f"LineSearchBatch finished in {iter_count} iterations. \ - LineSearch Status: {[stat for stat in completed]}") - - return alpha_results, e_result, e0_result, no_update_result, completed +""" +Copyright (c) 2025 Ma Zhaojia + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +from __future__ import annotations +import logging +import torch +from torch_scatter import scatter +# from .linesearch_torch import LineSearchBatch +from ..optimizable import OptimizableBatch +from torch.profiler import profile, record_function, ProfilerActivity, schedule, tensorboard_trace_handler +from datetime import datetime +import os +import math +import gc + +class BFGSFusedLS: + """ + Port of BFGSLineSearch from bfgslinesearch.py, adapted to PyTorch + and batched operations, mirroring lbfgs_torch.py structure. + """ + def __init__( + self, + optimizable_batch: OptimizableBatch, + maxstep: float = 0.2, + c1: float = 0.23, + c2: float = 0.46, + alpha: float = 10.0, + stpmax: float = 50.0, + device = 'cpu', + early_stop: bool = False, + use_profiler: bool = False, + profiler_log_dir: str = './log', + profiler_schedule_config: dict = None, + dtype: torch.dtype = torch.float64, + ): + self.optimizable = optimizable_batch + self.maxstep = maxstep + self.c1 = c1 + self.c2 = c2 + self.alpha = alpha + self.stpmax = stpmax + self.nsteps = 0 + self.device = device + self.force_calls = 0 + self.early_stop = early_stop + self.use_profiler = use_profiler + self.profiler_log_dir = profiler_log_dir + self.profiler_schedule_config = profiler_schedule_config or {"wait": 48, "warmup": 1, "active": 1, "repeat": 8} + self.dtype = dtype + + self.converge_indices_list = None + + # The information from the previous round is useful for the current round's calculations. + ## These variables need to be update accroding to new input when eary stop is triggered. + self.Hs = None + self.r0 = None + self.g0 = None + self.p_list = [None] * self.optimizable.batch_size + self.no_update_list = [False] * self.optimizable.batch_size + self.ls_completed = [True] * self.optimizable.batch_size + self.ls_batch = LineSearchBatch(self.optimizable.batch_indices, device="cpu", dtype=self.dtype) + ## need to be recalculate when early stop is triggered + self.forces = None + self.energies = None + + def restart_from_earlystop(self, restart_indices, old_batch_indices): + Hs_new = [] + r0_new = torch.zeros_like(self.optimizable.get_positions().reshape(-1), device=self.device) + g0_new = torch.zeros_like(r0_new, device=self.device) + p_list_new = [] + no_update_list_new = [] + ls_completed_new = [] + + # collect the preserved historical info by old_indices + for i, idx in enumerate(restart_indices): + mask_old = (idx==old_batch_indices.repeat_interleave(3)) + mask = (i==self.optimizable.batch_indices.repeat_interleave(3)) + Hs_new.append(self.Hs[idx]) + p_list_new.append(self.p_list[idx]) + no_update_list_new.append(self.no_update_list[idx]) + ls_completed_new.append(self.ls_completed[idx]) + r0_new[mask] = self.r0[mask_old] + g0_new[mask] = self.g0[mask_old] + + # append new info for new element in batch + for i in range(len(Hs_new), self.optimizable.batch_size): + # Hs_new.append(torch.eye(3 * self.optimizable.elem_per_group[i], device=self.device, dtype=torch.float64)) + Hs_new.append(None) + p_list_new.append(None) + no_update_list_new.append(False) + ls_completed_new.append(True) + + self.Hs = Hs_new + self.r0 = r0_new + self.g0 = g0_new + self.p_list = p_list_new + self.no_update_list = no_update_list_new + self.ls_completed = ls_completed_new + self.forces = None + self.energies = None + self.ls_batch.restart_from_earlystop(restart_indices=restart_indices, batch_indices_new=self.optimizable.batch_indices) + + def step(self): + optimizable = self.optimizable + if self.forces is None: + self.forces = optimizable.get_forces().to(self.device) + r = optimizable.get_positions().reshape(-1).to(self.device) + g = -self.forces.reshape(-1) / self.alpha + p0_list = self.p_list + self.update(r, g, self.r0, self.g0, p0_list) + if self.energies is None: + self.energies = self.func(r) + + for i in range(self.optimizable.batch_size): + if self.ls_completed[i]: + p = -torch.matmul(self.Hs[i], g[i==self.optimizable.batch_indices.repeat_interleave(3)]) + + # Implement scaling for numerical stability with simpler calculation + p_size = torch.sqrt((p**2).sum()) + min_size = torch.sqrt(self.optimizable.elem_per_group[i] * 1e-10) + if p_size <= min_size: + p = p * (min_size / p_size) + + self.p_list[i] = p + + # ls_batch = LineSearchBatch(self.optimizable.batch_indices, device="cpu") + continue_search = [not elem for elem in self.ls_completed] + self.alpha_k_list, self.e_list, self.e0_list, self.no_update_list, self.ls_completed = self.ls_batch._linesearch_batch( + self.func, self.fprime, r, self.p_list, g, self.energies, None, + maxstep=self.maxstep, c1=self.c1, c2=self.c2, stpmax=self.stpmax, continue_search=continue_search + ) + + # reset device for linesearch result + for i in range(self.optimizable.batch_size): + if self.ls_completed[i]: + self.alpha_k_list[i] = self.alpha_k_list[i].to(self.device) + self.p_list[i] = self.p_list[i].to(self.device) + + dr_tensor = torch.zeros_like(r) + + + for i in range(self.optimizable.batch_size): + # if check_cache: + # mask = (i == self.optimizable.batch_indices.repeat_interleave(3)) + # dr_tensor_all[mask] = self.alpha_k_list[i].to(self.device) * self.p_list[i].to(self.device) + + if not self.ls_completed[i]: + continue + if self.alpha_k_list[i] is None: + raise RuntimeError("LineSearch failed!") + + mask = (i == self.optimizable.batch_indices.repeat_interleave(3)) + dr_tensor[mask] = self.alpha_k_list[i] * self.p_list[i] + + # if check_cache: + # cached_pos = optimizable.get_positions().reshape(-1).to(self.device) + # update_pos = r + dr_tensor_all + # assert torch.allclose(update_pos, cached_pos), "dr_tensor_cached should be equal to dr_tensor" + + + # TODO: get_forces/get_potential_energies will trigger compare_batch which is time-consuming + forces_cache = optimizable.get_forces() + energies_cache = self.optimizable.get_potential_energies() / self.alpha + + # update self.forces + for i in range(self.optimizable.batch_size): + if not self.ls_completed[i]: + continue + mask = (i == self.optimizable.batch_indices) + self.forces[mask] = forces_cache[mask] + self.energies[i] = energies_cache[i] + + optimizable.set_positions((r + dr_tensor).reshape(-1, 3)) + + self.r0 = r + self.g0 = g + + # @torch.compile + def update(self, r, g, r0, g0, p0_list): + all_sizes = self.optimizable.elem_per_group + + if self.Hs is None: + self.Hs = [ + torch.eye(3 * sz, device=self.device, dtype=self.dtype) + for sz in all_sizes + ] + return + + dr = r - r0 + dg = g - g0 + + for i in range(self.optimizable.batch_size): + if self.Hs[i] is None: + self.Hs[i] = torch.eye(3 * all_sizes[i], device=self.optimizable.device, dtype=self.dtype) + continue + if not self.ls_completed[i]: + continue + if self.no_update_list[i] is True: + print('skip update') + continue + + cur_mask = (i == self.optimizable.batch_indices.repeat_interleave(3)) + cur_g = g[cur_mask] + cur_p0 = p0_list[i] + cur_g0 = g0[cur_mask] + cur_dg = dg[cur_mask] + cur_dr = dr[cur_mask] + + if not (((self.alpha_k_list[i] or 0) > 0 and + abs(torch.dot(cur_g, cur_p0)) - abs(torch.dot(cur_g0, cur_p0)) < 0) or False): + continue + + try: + rhok = 1.0 / (torch.dot(cur_dg, cur_dr)) + except: + rhok = 1000.0 + print("Divide-by-zero encountered: rhok assumed large") + if torch.isinf(rhok): + rhok = 1000.0 + print("Divide-by-zero encountered: rhok assumed large") + I = torch.eye(all_sizes[i]*3, device=self.device, dtype=self.dtype) + A1 = I - cur_dr[:, None] * cur_dg[None, :] * rhok + A2 = I - cur_dg[:, None] * cur_dr[None, :] * rhok + self.Hs[i] = (torch.matmul(A1, torch.matmul(self.Hs[i], A2)) + + rhok * cur_dr[:, None] * cur_dr[None, :]) + + + # def update(self, r, g, r0, g0, p0_list): + # self.Is = [ + # torch.eye(sz * 3, dtype=torch.float64, device=self.device) + # for sz in self.optimizable.elem_per_group + # ] + + # # TODO: BFGS for loop 是不是在被打断之后需要重建这个 self.Hs? + # # TODO: 并且我们保存的上一次的r,g,r0,g0也被丢弃了 + # if self.Hs is None: + # self.Hs = [ + # torch.eye(3 * sz, device=self.optimizable.device, dtype=torch.float64) + # for sz in self.optimizable.elem_per_group + # ] + # return + # else: + # dr = r - r0 + # dg = g - g0 + + # for i in range(self.optimizable.batch_size): + # if not self.ls_completed[i]: + # continue + # cur_mask = (i==self.optimizable.batch_indices.repeat_interleave(3)) + # cur_g = g[cur_mask] + # cur_p0 = p0_list[i] + # cur_g0 = g0[cur_mask] + # cur_dg = dg[cur_mask] + # cur_dr = dr[cur_mask] + + # if not (((self.alpha_k_list[i] or 0) > 0 and + # abs(torch.dot(cur_g, cur_p0)) - abs(torch.dot(cur_g0, cur_p0)) < 0) or False): + # break + + # if self.no_update_list[i] is True: + # print('skip update') + # break + + # try: + # rhok = 1.0 / (torch.dot(cur_dg, cur_dr)) + # except: + # rhok = 1000.0 + # print("Divide-by-zero encountered: rhok assumed large") + # if torch.isinf(rhok): + # rhok = 1000.0 + # print("Divide-by-zero encountered: rhok assumed large") + # A1 = self.Is[i] - cur_dr[:, None] * cur_dg[None, :] * rhok + # A2 = self.Is[i] - cur_dg[:, None] * cur_dr[None, :] * rhok + # self.Hs[i] = (torch.matmul(A1, torch.matmul(self.Hs[i], A2)) + + # rhok * cur_dr[:, None] * cur_dr[None, :]) + + + + def func(self, x): + self.optimizable.set_positions(x.reshape(-1, 3).to(self.device)) + return self.optimizable.get_potential_energies() / self.alpha + + def fprime(self, x): + self.optimizable.set_positions(x.reshape(-1, 3).to(self.device)) + + self.force_calls += 1 + forces = self.optimizable.get_forces().reshape(-1) + return - forces / self.alpha + + def run(self, fmax, maxstep, is_restart_earlystop=False, restart_indices=None, old_batch_indices=None): + logging.info("Enter bfgsfusedlinesearch's main program.") + self.fmax = fmax + self.max_iter = maxstep + + if is_restart_earlystop: + self.restart_from_earlystop(restart_indices, old_batch_indices) + + iteration = 0 + max_forces = self.optimizable.get_max_forces(apply_constraint=True) + logging.info("Step Fmax(eV/A)") + + # Run with profiler if enabled + if self.use_profiler: + activities = [ProfilerActivity.CPU] + if torch.cuda.is_available(): + activities.append(ProfilerActivity.CUDA) + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + pid = os.getpid() + with torch.profiler.profile( + activities=activities, + schedule=torch.profiler.schedule( + wait=self.profiler_schedule_config["wait"], + warmup=self.profiler_schedule_config["warmup"], + active=self.profiler_schedule_config["active"], + repeat=self.profiler_schedule_config["repeat"] + ), + on_trace_ready=tensorboard_trace_handler(self.profiler_log_dir, worker_name=f"BFGSLS_{pid}"), + with_stack=True, + profile_memory=True, + ) as prof: + # Main optimization loop with profiling + while iteration < self.max_iter and not self.optimizable.converged( + forces=None, fmax=self.fmax, max_forces=max_forces, f_upper_limit=1e25, + ): + if self.early_stop and iteration > 0: + self.converge_indices_list = self.optimizable.converge_indices_list + if len(self.converge_indices_list) > 0: + logging.info(f"Early stopping at iteration {iteration}") + break + + logging.info( + f"{iteration} " + " ".join(f"{x:18.15g}" for x in max_forces.tolist()) + ) + + self.step() + max_forces = self.optimizable.get_max_forces(apply_constraint=True, forces=self.forces) + iteration += 1 + + # Step the profiler in each iteration + prof.step() + + else: + # Original optimization loop without profiling + while iteration < self.max_iter and not self.optimizable.converged( + forces=None, fmax=self.fmax, max_forces=max_forces, f_upper_limit=1e25, + ): + if self.early_stop and iteration > 0: + self.converge_indices_list = self.optimizable.converge_indices_list + if len(self.converge_indices_list) > 0: + logging.info(f"Early stopping at iteration {iteration}") + break + + logging.info( + f"{iteration} " + " ".join(f"{x:18.15g}" for x in max_forces.tolist()) + ) + + self.step() + max_forces = self.optimizable.get_max_forces(apply_constraint=True, forces=self.forces) + iteration += 1 + + logging.info( + f"{iteration} " + " ".join(f"{x:18.15g}" for x in max_forces.tolist()) + ) + + # GPU memory usage as per nvidia-smi seems to gradually build up as + # batches are processed. This releases unoccupied cached memory. + torch.cuda.empty_cache() + gc.collect() + + # set predicted values to batch + for name, value in self.optimizable.results.items(): + setattr(self.optimizable.batch, name, value) + + self.nsteps = iteration + + if self.early_stop: + self.converge_indices_list = self.optimizable.converge_indices_list + return self.converge_indices_list + else: + return self.optimizable.converged( + forces=None, fmax=self.fmax, max_forces=max_forces + ) + + def _batched_dot_2d(self, x: torch.Tensor, y: torch.Tensor): + return scatter( + (x * y).sum(dim=-1), self.optimizable.batch_indices, reduce="sum" + ) + + def _batched_dot_1d(self, x: torch.Tensor, y: torch.Tensor): + return scatter( + (x * y), self.optimizable.batch_indices.repeat_interleave(3), reduce="sum" + ) + +# flake8: noqa +import math +import torch +import logging + +pymin = min +pymax = max + + +class LineSearch: + def __init__(self, xtol=1e-14, device='cpu', dtype=torch.float64): + self.xtol = xtol + self.task = 'START' + self.device = device + self.dtype = dtype + self.isave = torch.zeros(2, dtype=torch.int64, device=self.device) + self.dsave = torch.zeros(13, dtype=self.dtype, device=self.device) + self.fc = 0 + self.gc = 0 + self.case = 0 + self.old_stp = 0 + + def initialize(self, xk, pk, gfk, old_fval, old_old_fval, + maxstep=.2, c1=.23, c2=0.46, xtrapl=1.1, xtrapu=4., + stpmax=50., stpmin=1e-8): + # Scalar parameters can stay as Python scalars + self.stpmin = stpmin + self.stpmax = stpmax + self.xtrapl = xtrapl + self.xtrapu = xtrapu + self.maxstep = maxstep + + # Move tensors to the device + self.pk = pk.to(self.device) + xk = xk.to(self.device) + gfk = gfk.to(self.device) + + phi0 = old_fval + + + # This dot product needs tensors + derphi0 = torch.dot(gfk, self.pk).item() + + # Use Python math for scalar calculations + self.dim = len(pk) + self.gms = math.sqrt(self.dim) * maxstep + + alpha1 = 1.0 + self.no_update = False + self.gradient = True + + self.steps = [] + return alpha1, phi0, derphi0 + + def prologue(self, fval, gval, pk_tensor, alpha1): + phi0 = fval + derphi0 = torch.dot(gval, pk_tensor) + self.old_stp = alpha1 + # TODO: self.no_update == True: break is needed to reimplemented. + + return phi0, derphi0 + + def epilogue(self): + pass + + def _line_search(self, func, myfprime, xk, pk, gfk, old_fval, old_old_fval, + maxstep=.2, c1=.23, c2=0.46, xtrapl=1.1, xtrapu=4., + stpmax=50., stpmin=1e-8, args=()): + self.stpmin = stpmin + self.pk = pk.to(self.device) + self.stpmax = stpmax + self.xtrapl = xtrapl + self.xtrapu = xtrapu + self.maxstep = maxstep + + xk = xk.to(self.device) + + # Convert inputs to torch tensors if they're not already + if not isinstance(old_fval, torch.Tensor): + phi0 = torch.tensor(old_fval, dtype=self.dtype, device=self.device) + else: + phi0 = old_fval.to(self.device) + + # Ensure pk and gfk are torch tensors + pk_tensor = torch.tensor(pk, dtype=self.dtype, device=self.device) if not isinstance(pk, torch.Tensor) else pk.to(self.device) + gfk_tensor = torch.tensor(gfk, dtype=self.dtype, device=self.device) if not isinstance(gfk, torch.Tensor) else gfk.to(self.device) + + derphi0 = torch.dot(gfk_tensor, pk_tensor) + self.dim = len(pk) + self.gms = torch.sqrt(torch.tensor(self.dim, dtype=self.dtype, device=self.device)) * maxstep + alpha1 = 1. + self.no_update = False + + if isinstance(myfprime, tuple): + fprime = myfprime[0] + gradient = False + else: + fprime = myfprime + newargs = args + gradient = True + + fval = phi0 + gval = gfk_tensor + self.steps = [] + + while True: + stp = self.step(alpha1, phi0, derphi0, c1, c2, + self.xtol, + self.isave, self.dsave) + + if self.task[:2] == 'FG': + alpha1 = stp + + # Get function value and gradient + x_new = xk + stp * pk_tensor + fval = func(x_new).to(self.device) + self.fc += 1 + + gval = fprime(x_new).to(self.device) + if gradient: + self.gc += 1 + else: + self.fc += len(xk) + 1 + + phi0 = fval + derphi0 = torch.dot(gval, pk_tensor) + self.old_stp = alpha1 + + if self.no_update == True: + break + else: + break + + if self.task[:5] == 'ERROR' or self.task[1:4] == 'WARN': + stp = None # failed + + return stp, fval.item(), old_fval.item() if isinstance(old_fval, torch.Tensor) else old_fval, self.no_update + + def step(self, stp, f, g, c1, c2, xtol, isave, dsave): + if self.task[:5] == 'START': + # Check the input arguments for errors. + if stp < self.stpmin: + self.task = 'ERROR: STP .LT. minstep' + if stp > self.stpmax: + self.task = 'ERROR: STP .GT. maxstep' + if g >= 0: + self.task = 'ERROR: INITIAL G >= 0' + if c1 < 0: + self.task = 'ERROR: c1 .LT. 0' + if c2 < 0: + self.task = 'ERROR: c2 .LT. 0' + if xtol < 0: + self.task = 'ERROR: XTOL .LT. 0' + if self.stpmin < 0: + self.task = 'ERROR: minstep .LT. 0' + if self.stpmax < self.stpmin: + self.task = 'ERROR: maxstep .LT. minstep' + if self.task[:5] == 'ERROR': + return stp + + # Initialize local variables. + self.bracket = False + stage = 1 + finit = f + ginit = g + gtest = c1 * ginit + width = self.stpmax - self.stpmin + width1 = width / .5 + + # The variables stx, fx, gx contain the values of the step, + # function, and derivative at the best step. + # The variables sty, fy, gy contain the values of the step, + # function, and derivative at sty. + # The variables stp, f, g contain the values of the step, + # function, and derivative at stp. + stx = 0.0 + fx = finit + gx = ginit + sty = 0.0 + fy = finit + gy = ginit + stmin = 0.0 + stmax = stp + self.xtrapu * stp + self.task = 'FG' + self.save((stage, ginit, gtest, gx, + gy, finit, fx, fy, stx, sty, + stmin, stmax, width, width1)) + stp = self.determine_step(stp) + return stp + else: + if self.isave[0] == 1: + self.bracket = True + else: + self.bracket = False + stage = self.isave[1] + (ginit, gtest, gx, gy, finit, fx, fy, stx, sty, stmin, stmax, + width, width1) = self.dsave + + # If psi(stp) <= 0 and f'(stp) >= 0 for some step, then the + # algorithm enters the second stage. + ftest = finit + stp * gtest + if stage == 1 and f < ftest and g >= 0.: + stage = 2 + + # Test for warnings. + if self.bracket and (stp <= stmin or stp >= stmax): + self.task = 'WARNING: ROUNDING ERRORS PREVENT PROGRESS' + if self.bracket and stmax - stmin <= self.xtol * stmax: + self.task = 'WARNING: XTOL TEST SATISFIED' + if stp == self.stpmax and f <= ftest and g <= gtest: + self.task = 'WARNING: STP = maxstep' + if stp == self.stpmin and (f > ftest or g >= gtest): + self.task = 'WARNING: STP = minstep' + + # Test for convergence. + # if f <= ftest and abs(g) <= c2 * (- ginit): + # self.task = 'CONVERGENCE' + if (f < ftest or math.isclose(f, ftest, rel_tol=1e-6, abs_tol=1e-5)) and (abs(g) < c2 * (- ginit) or math.isclose(abs(g), c2 * (- ginit), rel_tol=1e-6, abs_tol=1e-5)): + self.task = 'CONVERGENCE' + + # Test for termination. + if self.task[:4] == 'WARN' or self.task[:4] == 'CONV': + self.save((stage, ginit, gtest, gx, + gy, finit, fx, fy, stx, sty, + stmin, stmax, width, width1)) + return stp + + stx, sty, stp, gx, fx, gy, fy = self.update(stx, fx, gx, sty, + fy, gy, stp, f, g, + stmin, stmax) + + # Decide if a bisection step is needed. + if self.bracket: + if abs(sty - stx) >= .66 * width1: + stp = stx + .5 * (sty - stx) + width1 = width + width = abs(sty - stx) + + # Set the minimum and maximum steps allowed for stp. + if self.bracket: + stmin = min(stx, sty) + stmax = max(stx, sty) + else: + stmin = stp + self.xtrapl * (stp - stx) + stmax = stp + self.xtrapu * (stp - stx) + + # Force the step to be within the bounds maxstep and minstep. + stp = max(stp, self.stpmin) + stp = min(stp, self.stpmax) + + if (stx == stp and stp == self.stpmax and stmin > self.stpmax): + self.no_update = True + + # If further progress is not possible, let stp be the best + # point obtained during the search. + if (self.bracket and stp < stmin or stp >= stmax) \ + or (self.bracket and stmax - stmin < self.xtol * stmax): + stp = stx + + # Obtain another function and derivative. + self.task = 'FG' + self.save((stage, ginit, gtest, gx, + gy, finit, fx, fy, stx, sty, + stmin, stmax, width, width1)) + return stp + + def update(self, stx, fx, gx, sty, fy, gy, stp, fp, gp, + stpmin, stpmax): + sign = gp * (gx / abs(gx)) + + # First case: A higher function value. The minimum is bracketed. + # If the cubic step is closer to stx than the quadratic step, the + # cubic step is taken, otherwise the average of the cubic and + # quadratic steps is taken. + if fp > fx: # case1 + self.case = 1 + theta = 3. * (fx - fp) / (stp - stx) + gx + gp + s = max(max(abs(theta), abs(gx)), abs(gp)) + gamma = s * math.sqrt((theta / s) ** 2. - (gx / s) * (gp / s)) + if stp < stx: + gamma = -gamma + p = (gamma - gx) + theta + q = ((gamma - gx) + gamma) + gp + r = p / q + stpc = stx + r * (stp - stx) + stpq = stx + ((gx / ((fx - fp) / (stp - stx) + gx)) / 2.) \ + * (stp - stx) + if (abs(stpc - stx) < abs(stpq - stx)): + stpf = stpc + else: + stpf = stpc + (stpq - stpc) / 2. + + self.bracket = True + + # Second case: A lower function value and derivatives of opposite + # sign. The minimum is bracketed. If the cubic step is farther from + # stp than the secant step, the cubic step is taken, otherwise the + # secant step is taken. + elif sign < 0: # case2 + self.case = 2 + theta = 3. * (fx - fp) / (stp - stx) + gx + gp + s = max(max(abs(theta), abs(gx)), abs(gp)) + gamma = s * math.sqrt((theta / s) ** 2 - (gx / s) * (gp / s)) + if stp > stx: + gamma = -gamma + p = (gamma - gp) + theta + q = ((gamma - gp) + gamma) + gx + r = p / q + stpc = stp + r * (stx - stp) + stpq = stp + (gp / (gp - gx)) * (stx - stp) + if (abs(stpc - stp) > abs(stpq - stp)): + stpf = stpc + else: + stpf = stpq + self.bracket = True + + # Third case: A lower function value, derivatives of the same sign, + # and the magnitude of the derivative decreases. + elif abs(gp) < abs(gx): # case3 + self.case = 3 + # The cubic step is computed only if the cubic tends to infinity + # in the direction of the step or if the minimum of the cubic + # is beyond stp. Otherwise the cubic step is defined to be the + # secant step. + theta = 3. * (fx - fp) / (stp - stx) + gx + gp + s = max(max(abs(theta), abs(gx)), abs(gp)) + + # The case gamma = 0 only arises if the cubic does not tend + # to infinity in the direction of the step. + gamma = s * math.sqrt(max(0., (theta / s) ** 2 - (gx / s) * (gp / s))) + if stp > stx: + gamma = -gamma + p = (gamma - gp) + theta + q = (gamma + (gx - gp)) + gamma + r = p / q + if r < 0. and gamma != 0: + stpc = stp + r * (stx - stp) + elif stp > stx: + stpc = stpmax + else: + stpc = stpmin + stpq = stp + (gp / (gp - gx)) * (stx - stp) + + if self.bracket: + # A minimizer has been bracketed. If the cubic step is + # closer to stp than the secant step, the cubic step is + # taken, otherwise the secant step is taken. + if abs(stpc - stp) < abs(stpq - stp): + stpf = stpc + else: + stpf = stpq + if stp > stx: + stpf = min(stp + .66 * (sty - stp), stpf) + else: + stpf = max(stp + .66 * (sty - stp), stpf) + else: + # A minimizer has not been bracketed. If the cubic step is + # farther from stp than the secant step, the cubic step is + # taken, otherwise the secant step is taken. + if abs(stpc - stp) > abs(stpq - stp): + stpf = stpc + else: + stpf = stpq + stpf = min(stpmax, stpf) + stpf = max(stpmin, stpf) + + # Fourth case: A lower function value, derivatives of the same sign, + # and the magnitude of the derivative does not decrease. If the + # minimum is not bracketed, the step is either minstep or maxstep, + # otherwise the cubic step is taken. + else: # case4 + self.case = 4 + if self.bracket: + theta = 3. * (fp - fy) / (sty - stp) + gy + gp + s = max(max(abs(theta), abs(gy)), abs(gp)) + gamma = s * math.sqrt((theta / s) ** 2 - (gy / s) * (gp / s)) + if stp > sty: + gamma = -gamma + p = (gamma - gp) + theta + q = ((gamma - gp) + gamma) + gy + r = p / q + stpc = stp + r * (sty - stp) + stpf = stpc + elif stp > stx: + stpf = stpmax + else: + stpf = stpmin + + # Update the interval which contains a minimizer. + if fp > fx: + sty = stp + fy = fp + gy = gp + else: + if sign < 0: + sty = stx + fy = fx + gy = gx + stx = stp + fx = fp + gx = gp + + # Compute the new step. + stp = self.determine_step(stpf) + + return stx, sty, stp, gx, fx, gy, fy + + def determine_step(self, stp): + dr = stp - self.old_stp + x = torch.reshape(self.pk.to(self.device), (-1, 3)) + steplengths = ((dr * x)**2).sum(1)**0.5 + maxsteplength = max(steplengths) + if maxsteplength >= self.maxstep: + dr *= self.maxstep / maxsteplength + stp = self.old_stp + dr + return stp + + def save(self, data): + if self.bracket: + self.isave[0] = 1 + else: + self.isave[0] = 0 + self.isave[1] = data[0] + self.dsave = data[1:] + +class LineSearchBatch: + + def __init__(self, batch_indices, device='cpu', dtype=torch.float64): + self.device = device + self.dtype = dtype + self.batch_indices = batch_indices.to(self.device) + self.batch_indices_flatten = self.batch_indices.repeat_interleave(3).to(self.device) + self.batch_size = len(torch.unique(batch_indices)) + self.linesearch_list = [LineSearch(device=self.device, dtype=self.dtype) for _ in range(self.batch_size)] + self.steps = [1.] * self.batch_size + self.phi0_values = [None] * self.batch_size + self.derphi0_values = [None] * self.batch_size + + def restart_from_earlystop(self, restart_indices, batch_indices_new): + self.batch_indices = batch_indices_new.to(self.device) + self.batch_indices_flatten = self.batch_indices.repeat_interleave(3).to(self.device) + self.batch_size = len(torch.unique(batch_indices_new)) + + linesearch_list_new = [] + steps_new = [] + phi0_values_new = [] + derphi0_values_new = [] + + for i, idx in enumerate(restart_indices): + linesearch_list_new.append(self.linesearch_list[idx]) + steps_new.append(self.steps[idx]) + phi0_values_new.append(self.phi0_values[idx]) + derphi0_values_new.append(self.derphi0_values[idx]) + + for i in range(len(restart_indices), self.batch_size): + linesearch_list_new.append(LineSearch(device=self.device)) + steps_new.append(1.) + phi0_values_new.append(None) + derphi0_values_new.append(None) + + self.linesearch_list = linesearch_list_new + self.steps = steps_new + self.phi0_values = phi0_values_new + self.derphi0_values = derphi0_values_new + + + + def _linesearch_batch(self, func, myfprime, xk, pk, gfk, old_fval, old_old_fval, + maxstep=.2, c1=.23, c2=0.46, xtrapl=1.1, xtrapu=4., + stpmax=50., stpmin=1e-8, continue_search=None, max_iter=15): + if continue_search is None: + self.linesearch_list = [LineSearch(device=self.device) for _ in range(self.batch_size)] + else: + assert len(continue_search) == self.batch_size + for i in range(len(continue_search)): + if not continue_search[i]: + self.linesearch_list[i] = LineSearch(device=self.device) + + if isinstance(xk, torch.Tensor): + xk = xk.to(self.device) + for i in range(len(pk)): + pk[i] = pk[i].to(self.device) + if isinstance(gfk, torch.Tensor): + gfk = gfk.to(self.device) + if isinstance(old_fval, torch.Tensor): + old_fval = old_fval.to(self.device) + if isinstance(old_old_fval, torch.Tensor): + old_old_fval = old_old_fval.to(self.device) + + + # results for each batch element + alpha_results = [] + e_result = [] + e0_result = [] + no_update_result = [] + + # Initialize step sizes and line search state for each batch element + completed = [False] * self.batch_size + + # Initialize iteration counter + iter_count = 0 + + # Initialize all line searches using the initialize method + for i in range(self.batch_size): + if continue_search[i]: + continue + + ls = self.linesearch_list[i] + mask = (i == self.batch_indices_flatten) + + # Use the initialize method to set up line search parameters + alpha1, phi0, derphi0 = ls.initialize( + xk[mask], pk[i], gfk[mask], old_fval[i], old_old_fval, + maxstep, c1, c2, xtrapl, xtrapu, stpmax, stpmin + ) + + # Store the initialization values + self.steps[i] = alpha1 + self.phi0_values[i] = phi0 + self.derphi0_values[i] = derphi0 + + # Main optimization loop + while True: + # 1. step forward + # logging.info(f"step's input: alpha1: {torch.tensor([step.item() if isinstance(step, torch.Tensor) else step for step in self.steps])}") + for i in range(self.batch_size): + if completed[i]: + continue + ls = self.linesearch_list[i] + if ls.fc > max_iter: + completed[i] = True + logging.warning(f"LineSearchBatch[{i}] reached max_iter: {max_iter}") + continue + stp = ls.step(self.steps[i], self.phi0_values[i], self.derphi0_values[i], + c1, c2, ls.xtol, ls.isave, ls.dsave) + if ls.task[:2] == 'FG': + self.steps[i] = stp + else: + completed[i] = True + + # 2. calculate new function value and gradient + x_new_batch = torch.zeros_like(xk) + for i in range(self.batch_size): + mask = (i == self.batch_indices_flatten) + x_new_batch[mask] = xk[mask] + self.steps[i] * pk[i] + f_batch = func(x_new_batch).to(self.device) + g_batch = myfprime(x_new_batch).to(self.device) + + # 3. update function value and gradient + for i in range(self.batch_size): + ls = self.linesearch_list[i] + mask = (i == self.batch_indices_flatten) + if ls.task[:2] == 'FG': + # Update function value and gradient + f_val = f_batch[i:i+1] + g_val = g_batch[mask] + ls.fc += 1 + phi0, derphi0 = ls.prologue(f_val, g_val, pk[i], self.steps[i]) + # logging.info(f"phi0, derphi0: {phi0}, {derphi0}") + self.phi0_values[i] = phi0 + self.derphi0_values[i] = derphi0 # TODO: why we put the derphi0 here instead of set it inside the LineSearch class? + if ls.no_update: + completed[i] = True + else: + completed[i] = True + + iter_count += 1 + logging.info(f"LineSearchBatch iter: {iter_count}: alpha: {torch.tensor([step.item() if isinstance(step, torch.Tensor) else step for step in self.steps])}") + if any(completed): + break + + # 4. set a linesearch upper limit + # if iter_count > max_iter: + # for i in range(self.batch_size): + # completed[i] = True + # logging.warning(f"LineSearchBatch reached max_iter: {max_iter}") + # break + + # Collect results + for i in range(self.batch_size): + ls = self.linesearch_list[i] + if ls.task[:5] == 'ERROR' or ls.task[1:4] == 'WARN': + stp = torch.tensor(1., device=self.device) + else: + stp = self.steps[i] if isinstance(self.steps[i], torch.Tensor) else torch.tensor(self.steps[i], device=self.device) + + alpha_results.append(stp) + e_result.append(self.phi0_values[i].item() if self.phi0_values[i] is not None else None) + e0_result.append(old_fval[i].item() if isinstance(old_fval[i], torch.Tensor) else old_fval[i]) + no_update_result.append(ls.no_update) + + logging.info(f"LineSearchBatch finished in {iter_count} iterations. \ + LineSearch Status: {[stat for stat in completed]}") + + return alpha_results, e_result, e0_result, no_update_result, completed diff --git a/mace-bench/src/batchopt/relaxengine.py b/mace-bench/src/batchopt/relaxengine.py index 09c3dae..617145f 100644 --- a/mace-bench/src/batchopt/relaxengine.py +++ b/mace-bench/src/batchopt/relaxengine.py @@ -1,1433 +1,1433 @@ -""" -Copyright (c) 2025 {Chengxi Zhao, Zhaojia Ma, Dingrui Fan} - -This source code is licensed under the MIT license found in the -LICENSE file in the root directory of this source tree. -""" - -from ase.io import read - -# from ase.optimize import ASE_LBFGS -import torch -from torch.multiprocessing import Process, set_start_method -from batchopt.atoms_to_graphs import AtomsToGraphs -from batchopt.utils import data_list_collater -from batchopt.relaxation.optimizers import ( - BFGS, - BFGSFusedLS, -) -from batchopt.relaxation import OptimizableBatch, OptimizableUnitCellBatch -import logging -import time -import csv -from multiprocessing import Queue -import os -import psutil -import multiprocessing -import json -import subprocess - -try: - from chgnet.model.dynamics import CHGNetCalculator -except ImportError: - logging.warning("Failed to import CHGNet modules") - -try: - from sevenn.calculator import SevenNetCalculator, SevenNetD3Calculator -except ImportError: - logging.warning("Failed to import SevenNet modules") - -try: - from fairchem.core import pretrained_mlip, FAIRChemCalculator -except ImportError: - logging.warning("Failed to import FAIRChem modules") - -try: - from mace.calculators import mace_off -except ImportError: - logging.warning("Failed to import MACE modules") - -import threading -from .utils import count_atoms_cif -from collections import deque - - -class Scheduler: - """ - Scheduler distributes relaxation tasks to workers. - """ - - def __init__( - self, - files, - num_workers, - devices, - batch_size, - max_steps, - filter1, - filter2, - optimizer1, - optimizer2, - skip_second_stage, - scalar_pressure, - compile_mode, - profile, - num_threads, - bind_cores, - cueq, - molecule_single, - output_path, - model, - ): - - self.files = files - self.num_workers = num_workers - self.devices = devices - self.batch_size = batch_size - self.max_steps = max_steps - self.filter1 = filter1 - self.filter2 = filter2 - self.optimizer1 = optimizer1 - self.optimizer2 = optimizer2 - self.skip_second_stage = skip_second_stage - self.scalar_pressure = scalar_pressure - self.compile_mode = compile_mode - self.profile = profile - self.num_threads = num_threads - self.cueq = cueq - self.molecule_single = molecule_single - self.output_path = ( - output_path - if os.path.isabs(output_path) - else os.path.abspath(output_path) - ) - self.model = model - - try: - set_start_method("spawn") - except RuntimeError: - logging.warning( - "set_start_method('spawn') failed, trying 'forkserver' instead." - ) - - if bind_cores is not None: - self.cpu_mask = self._parse_bind_cores(bind_cores) - else: - self.cpu_mask = None - - def _parse_bind_cores(self, bind_cores): - # Expect custom_bind_str to be like "0-15,16-31,..." - ranges = bind_cores.split(",") - if len(ranges) != self.num_workers: - return None - binding = [] - for r in ranges: - try: - start_str, end_str = r.split("-") - start = int(start_str) - end = int(end_str) - except ValueError: - logging.error("Custom binding format should be 'start-end'.") - return None - - binding.append(set(range(start, end + 1))) - return binding - - def _get_physical_logical_core_mapping(self): - """Get the mapping between logical cores and their physical core IDs.""" - try: - # This information is available in Linux systems - mapping = {} - logical_cores = psutil.cpu_count(logical=True) - - for i in range(logical_cores): - try: - # Read core_id from /sys/devices/system/cpu/cpu{i}/topology/core_id - with open( - f"/sys/devices/system/cpu/cpu{i}/topology/core_id" - ) as f: - core_id = int(f.read().strip()) - # Read physical_package_id (socket) for more complete information - with open( - f"/sys/devices/system/cpu/cpu{i}/topology/physical_package_id" - ) as f: - package_id = int(f.read().strip()) - mapping[i] = (package_id, core_id) - except (FileNotFoundError, ValueError, IOError): - mapping[i] = None - return mapping - except Exception as e: - logging.error(f"Failed to get core mapping: {e}") - return {} - - def _get_physical_core_mask(self): - # Get the number of physical and logical cores - physical_cores = psutil.cpu_count(logical=False) - logical_cores = psutil.cpu_count(logical=True) - - if physical_cores is None or physical_cores < 1: - # Fallback to multiprocessing if psutil fails - logical_cores = multiprocessing.cpu_count() - physical_cores = logical_cores // 2 - if physical_cores < 1: - physical_cores = 1 - print(f"Using estimated physical cores: {physical_cores}") - - # Get the mapping between logical and physical cores - core_mapping = self._get_physical_logical_core_mapping() - - # Create a CPU mask that includes all physical cores (first core of each physical core) - physical_core_mask = set() - if core_mapping: - # Group by physical core ID - cores_by_physical = {} - for logical_id, physical_info in core_mapping.items(): - if physical_info is not None: - package_id, core_id = physical_info - key = (package_id, core_id) - if key not in cores_by_physical: - cores_by_physical[key] = [] - cores_by_physical[key].append(logical_id) - - # Select one logical core from each physical core - for physical_cores_list in cores_by_physical.values(): - physical_core_mask.add( - physical_cores_list[0] - ) # First logical core of each physical core - else: - # If mapping fails, use a simple assumption (may not be accurate on all systems) - threads_per_core = logical_cores // physical_cores - physical_core_mask = set(range(0, logical_cores, threads_per_core)) - - return physical_core_mask - - def worker_task( - self, files, device, batch_size, result_queue, physical_cores - ): - if physical_cores is not None: - try: - # Bind the current process to physical cores - pid = os.getpid() - os.sched_setaffinity(pid, physical_cores) - logging.info(f"bind to physical_core_ids: {physical_cores}") - - # Verify the affinity was set correctly - current_affinity = os.sched_getaffinity(pid) - logging.info( - f"Process bound to {len(current_affinity)} cores: {sorted(current_affinity)}" - ) - - except AttributeError: - logging.error( - "sched_setaffinity not supported on this platform" - ) - except Exception as e: - logging.error(f"Failed to bind to physical cores: {e}") - - # pass the number of processes on each worker - nproc = self.num_workers // len(self.devices) - - worker = Worker( - files, - device, - batch_size, - self.max_steps, - self.filter1, - self.filter2, - self.optimizer1, - self.optimizer2, - self.skip_second_stage, - self.scalar_pressure, - self.compile_mode, - self.profile, - self.cueq, - self.molecule_single, - self.output_path, - self.model, - nproc, - ) - # results = worker.run() - results = worker.continuous_run() - result_queue.put(results) - - def _terminate_processes(self, processes): - """Helper method to terminate all processes.""" - for i, p in processes: - if p.is_alive(): - logging.info(f"Terminating process {p.pid}") - p.terminate() - p.join(timeout=3) # Wait for up to 3 seconds - if p.is_alive(): - logging.warning( - f"Process {p.pid} did not terminate, killing it" - ) - p.kill() - p.join() - - # create a thread to conduct "nvidia-smi" - @staticmethod - def _monitor_memory(interval=2, gpu_index=1): - try: - while True: - result = subprocess.check_output( - [ - "nvidia-smi", - "--query-gpu=memory.used,memory.total", - "--format=csv,nounits,noheader", - ] - ).decode("utf-8") - - lines = result.strip().split("\n") - used, total = map(int, lines[gpu_index].split(",")) - logging.info( - f"[nvidia-smi] Memory-Usage on GPU {gpu_index}: {used}MiB / {total}MiB" - ) - - time.sleep(interval) - except KeyboardInterrupt: - logging.info("Monitor interrupted.") - - except Exception as e: - logging.error(f"Unexpected error when monitor memory: {str(e)}") - - def run(self): - logging.info(f"Starting Scheduler with {self.num_workers} workers.") - processes = [] - result_queue = Queue() - start_time = time.perf_counter() - - if self.cpu_mask is not None: - physical_cores_per_worker = self.cpu_mask - logging.info( - f"Use customed cores binding. Physical cores per worker: {physical_cores_per_worker}" - ) - else: - # all_physical_cores = self._get_physical_core_mask() - # num_per_worker = len(all_physical_cores) // self.num_workers - # physical_cores_per_worker = [ - # list(all_physical_cores)[i:i + num_per_worker] for i in range(0, len(all_physical_cores), num_per_worker) - # ] - # logging.info(f"Physical cores per worker: {physical_cores_per_worker}") - physical_cores_per_worker = [None] * self.num_workers - - try: - # Start all worker processes - for i in range(self.num_workers): - files_for_worker = self.files[i :: self.num_workers] - device = self.devices[i % len(self.devices)] - logging.info( - f"Starting worker {i} with {len(files_for_worker)} files on device {device}." - ) - p = Process( - target=self.worker_task, - args=( - files_for_worker, - device, - self.batch_size, - result_queue, - physical_cores_per_worker[i], - ), - ) - p.start() - processes.append((i, p)) - - # monitor gpu memory usage to figure out what makes the differences of footprint among batches - # in each iteration. - use_memory_monitor = False - if use_memory_monitor: - monitor_proc = Process( - target=Scheduler._monitor_memory, args=() - ) - monitor_proc.start() - - # Monitor processes and collect results - csv_paths = [] - completed_processes = 0 - while completed_processes < self.num_workers: - for i, p in processes: - if not p.is_alive() and p.exitcode != 0: - if p.exitcode == -11 or p.exitcode == 1: - # Restart the process if exit code is -11 or -1 - logging.warning( - f"Worker process {p.pid} exited with code {p.exitcode}. Restarting worker {i}." - ) - files_for_worker = self.files[i :: self.num_workers] - device = self.devices[i % len(self.devices)] - new_process = Process( - target=self.worker_task, - args=( - files_for_worker, - device, - self.batch_size, - result_queue, - physical_cores_per_worker[i], - ), - ) - new_process.start() - processes[i] = ( - i, - new_process, - ) # Replace the old process with the new one - else: - # Raise an error for other exit codes - raise RuntimeError( - f"Worker process {p.pid} failed with exit code {p.exitcode}" - ) - - # Try to get result from queue with timeout - try: - result = result_queue.get(timeout=10) - csv_paths.append(result) - completed_processes += 1 - except Exception as e: - continue - - # terminate monitor - if use_memory_monitor: - monitor_proc.terminate() - monitor_proc.join() - - # Process results and create final CSV - merged_results = [] - for csv_path in csv_paths: - try: - with open(csv_path, mode="r") as f: - reader = csv.DictReader(f) - merged_results.extend(list(reader)) - except Exception as e: - logging.error(f"Error processing {csv_path}: {str(e)}") - - except Exception as e: - # Log the error and elapsed time - end_time = time.perf_counter() - elapsed_time = end_time - start_time - logging.error( - f"Error occurred after running for {elapsed_time:.2f} seconds: {str(e)}" - ) - - # Create error log file - error_log = f"scheduler_error_{int(time.time())}.log" - with open(error_log, "w") as f: - f.write(f"Error occurred after {elapsed_time:.2f} seconds\n") - f.write(f"Error message: {str(e)}\n") - f.write(f"Number of workers: {self.num_workers}\n") - f.write(f"Batch size: {self.batch_size}\n") - - # Terminate all processes - self._terminate_processes(processes) - raise # Re-raise the exception after cleanup - - finally: - end_time = time.perf_counter() - elapsed_time = end_time - start_time - - # Write final results if we have any - if "merged_results" in locals() and merged_results: - csv_file = os.path.join( - self.output_path, "results_scheduler.csv" - ) - with open(csv_file, mode="w", newline="") as file: - writer = csv.DictWriter( - file, - fieldnames=[ - "file", - "stage1_steps", - "stage1_time", - "stage1_energy", - "stage1_density", - "stage2_steps", - "stage2_time", - "stage2_energy", - "stage2_density", - "total_steps", - "total_time", - ], - ) - writer.writeheader() - for row in merged_results: - try: - processed_row = { - "file": row["file"], - "stage1_steps": int(row["stage1_steps"]), - "stage1_time": float(row["stage1_time"]), - "stage1_energy": float(row["stage1_energy"]), - "stage1_density": float(row["stage1_density"]), - "stage2_steps": int(row["stage2_steps"]), - "stage2_time": float(row["stage2_time"]), - "stage2_energy": float(row["stage2_energy"]), - "stage2_density": float(row["stage2_density"]), - "total_steps": int(row["total_steps"]), - "total_time": float(row["total_time"]), - } - writer.writerow(processed_row) - except (KeyError, ValueError) as e: - logging.error( - f"Invalid data format in row {row}: {str(e)}" - ) - - # Write summary - summary_csv_file = os.path.join( - self.output_path, "summary_scheduler.csv" - ) - with open(summary_csv_file, mode="w", newline="") as file: - writer = csv.DictWriter( - file, - fieldnames=["elapsed_time", "num_workers", "batch_size"], - ) - writer.writeheader() - writer.writerow( - { - "elapsed_time": elapsed_time, - "num_workers": self.num_workers, - "batch_size": self.batch_size, - } - ) - - logging.info(f"Scheduler completed in {elapsed_time:.2f} seconds.") - - def run_debug(self): - logging.info("Starting Scheduler in debug mode (sequential execution).") - - def worker_task(files, device, batch_size): - worker = Worker( - files, device, batch_size, self.max_steps, self.filter1 - ) - worker.run() - - for i in range(self.num_workers): - files_for_worker = self.files[i :: self.num_workers] - device = self.devices[i % len(self.devices)] - logging.info( - f"Running worker {i} with {len(files_for_worker)} files on device {device}." - ) - worker_task(files_for_worker, device, self.batch_size) - - logging.info("All workers have completed their tasks in debug mode.") - - -class Worker: - """ - Worker is single process that runs a batch of optimization tasks. - """ - - def __init__( - self, - files, - device, - batch_size, - max_steps, - filter1, - filter2, - optimizer1, - optimizer2, - skip_second_stage, - scalar_pressure, - compile_mode, - profile, - cueq, - molecule_single, - output_path, - model, - nproc, - ): - self.files = files - self.device = device - self.batch_size = batch_size - self.max_steps = max_steps - self.filter1 = filter1 - self.filter2 = filter2 - self.optimizer1 = optimizer1 - self.optimizer2 = optimizer2 - self.skip_second_stage = skip_second_stage # Store skip_second_stage - self.scalar_pressure = scalar_pressure - self.compile_mode = compile_mode - self.profile = profile - self.cueq = cueq - self.molecule_single = molecule_single - self.output_path = ( - output_path - if os.path.isabs(output_path) - else os.path.abspath(output_path) - ) - self.model = model - self.nproc = nproc - - # Parse profiler options if provided - self.use_profiler = False - self.profiler_schedule_config = { - "wait": 48, - "warmup": 1, - "active": 1, - "repeat": 1, - } - self.profiler_log_dir = None - - if self.profile and self.profile != "False": - self.use_profiler = True - # Create directory for profiler output - self.profiler_log_dir = os.path.join(self.output_path, "log") - os.makedirs(self.profiler_log_dir, exist_ok=True) - if self.profile != "True": - try: - # Try to parse profile as a JSON string with schedule config - profile_config = json.loads(self.profile) - if isinstance(profile_config, dict): - for key in ["wait", "warmup", "active", "repeat"]: - if key in profile_config and isinstance( - profile_config[key], int - ): - self.profiler_schedule_config[key] = ( - profile_config[key] - ) - except json.JSONDecodeError: - logging.warning( - f"Could not parse profile config: {self.profile}, using defaults" - ) - - # For monitor thread - self.stop_event = threading.Event() - - def run(self): - logging.info( - f"Worker started on device {self.device} with {len(self.files)} files." - ) - a2g = AtomsToGraphs(r_edges=False, r_pbc=True) - # model = torch.load("/home/mazhaojia/.cache/mace/MACE-OFF23_small.model", map_location=self.device) - # z_table = utils.AtomicNumberTable([int(z) for z in model.atomic_numbers]) - calculator = mace_off(model="small", device=self.device) - - results = [] - - for batch_files in self._batch_files(self.files, self.batch_size): - logging.info(f"Processing batch with {len(batch_files)} files.") - start_time = time.perf_counter() - - atoms_list = [] - for file in batch_files: - atoms = read(file) - atoms_list.append(atoms) - gbatch = data_list_collater( - [a2g.convert(atoms) for atoms in atoms_list] - ) - - gbatch = gbatch.to(self.device) - if self.filter1 == "UnitCellFilter": - from batchopt.relaxation import OptimizableUnitCellBatch - - obatch = OptimizableUnitCellBatch( - gbatch, - trainer=calculator, - numpy=False, - scalar_pressure=self.scalar_pressure, - ) - else: - obatch = OptimizableBatch( - gbatch, trainer=calculator, numpy=False - ) - - # First optimization stage - if self.optimizer1 == "LBFGS": - batch_optimizer1 = LBFGS( - obatch, damping=1.0, alpha=70.0, maxstep=0.2 - ) - elif self.optimizer1 == "BFGS": - batch_optimizer1 = BFGS(obatch, alpha=70.0, maxstep=0.2) - elif self.optimizer1 == "BFGSLineSearch": - batch_optimizer1 = BFGSLineSearch(obatch, device=self.device) - elif self.optimizer1 == "BFGSFusedLS": - batch_optimizer1 = BFGSFusedLS(obatch, device=self.device) - else: - raise ValueError(f"Unknown optimizer: {self.optimizer1}") - - start_time1 = time.perf_counter() - batch_optimizer1.run(0.01, self.max_steps) - end_time1 = time.perf_counter() - elapsed_time1 = end_time1 - start_time1 - - # Save intermediate results - atoms_list = obatch.get_atoms_list() - for atoms, file_path in zip(atoms_list, batch_files): - file_name = file_path.split("/")[-1] - output_file = os.path.join( - self.output_path, - "cif_result_press", - file_name.replace(".cif", "_press.cif"), - ) - atoms.write(output_file) - - # Capture maximum force after first optimization stage - max_force1 = obatch.get_max_forces(apply_constraint=True) - - steps1 = batch_optimizer1.nsteps - - if self.skip_second_stage: - # If skipping second stage, set metrics to zero - for file, force in zip(batch_files, max_force1): - results.append( - { - "file": file, - "stage1_time": elapsed_time1, - "stage1_steps": steps1, - "stage2_time": 0.0, - "stage2_steps": 0, - "total_time": elapsed_time1, - "total_steps": steps1, - "force1": force.item(), - "force2": 0.0, - } - ) - continue - - # Only proceed with second stage if not skipping - # Reload intermediate structures for second stage - atoms_list = [] - for file_path in batch_files: - file_name = file_path.split("/")[-1] - press_file = os.path.join( - self.output_path, - "cif_result_press", - file_name.replace(".cif", "_press.cif"), - ) - atoms = read(press_file) - atoms_list.append(atoms) - - # Rebuild batch from optimized structures - gbatch = data_list_collater( - [a2g.convert(atoms) for atoms in atoms_list] - ) - gbatch = gbatch.to(self.device) - - # Second optimization stage - if self.filter2 == "UnitCellFilter": - obatch2 = OptimizableUnitCellBatch( - gbatch, trainer=calculator, numpy=False, scalar_pressure=0.0 - ) - else: - obatch2 = OptimizableBatch( - gbatch, trainer=calculator, numpy=False - ) - - if self.optimizer2 == "LBFGS": - batch_optimizer2 = LBFGS( - obatch2, damping=1.0, alpha=70.0, maxstep=0.2 - ) - elif self.optimizer2 == "BFGS": - batch_optimizer2 = BFGS(obatch2, alpha=70.0, maxstep=0.2) - elif self.optimizer2 == "BFGSLineSearch": - batch_optimizer2 = BFGSLineSearch(obatch2, device=self.device) - elif self.optimizer2 == "BFGSFusedLS": - batch_optimizer2 = BFGSFusedLS(obatch2, device=self.device) - else: - raise ValueError(f"Unknown optimizer: {self.optimizer2}") - start_time2 = time.perf_counter() - batch_optimizer2.run(0.01, self.max_steps) - end_time2 = time.perf_counter() - elapsed_time2 = end_time2 - start_time2 - - # Save final results - atoms_list = obatch2.get_atoms_list() - for atoms, file_path in zip(atoms_list, batch_files): - file_name = file_path.split("/")[-1] - output_file = os.path.join( - self.output_path, - "cif_result_final", - file_name.replace(".cif", "_opt.cif"), - ) - atoms.write(output_file) - - # Capture maximum force after second optimization stage - max_force2 = obatch2.get_max_forces(apply_constraint=True) - - steps2 = batch_optimizer2.nsteps - - for file, f1, f2 in zip(batch_files, max_force1, max_force2): - results.append( - { - "file": file, - "stage1_time": elapsed_time1, - "stage1_steps": steps1, - "stage2_time": elapsed_time2, - "stage2_steps": steps2, - "total_time": elapsed_time1 + elapsed_time2, - "total_steps": steps1 + steps2, - "force1": f1.item(), - "force2": f2.item(), - } - ) - - return results - - def _batch_files(self, files, batch_size): - for i in range(0, len(files), batch_size): - yield files[i : i + batch_size] - - @staticmethod - def _torch_memory_monitor(interval=2, device=None, stop_event=None): - try: - # explicitly CUDA initialization - torch.cuda._lazy_init() - while not stop_event.is_set(): - allocated = torch.cuda.memory_allocated(device=device) - reserved = torch.cuda.memory_reserved(device=device) - logging.info( - f"[torch] Allocated Memory: {allocated / 1024**2:.2f} MiB" - ) - logging.info( - f"[torch] Reserved Memory: {reserved / 1024**2:.2f} MiB" - ) - time.sleep(interval) - except Exception as e: - logging.error(f"Unexpected error when monitor memory: {str(e)}") - - def continuous_run(self): - """ - Execute a continuous run of the batching optimization process. - """ - logging.info("Starting continuous_run with two rounds of optimization.") - - # torch memory monitor api - use_torch_memory_monitor = False - if use_torch_memory_monitor: - memory_monitor = threading.Thread( - target=Worker._torch_memory_monitor, - args=(2, self.device, self.stop_event), - ) - memory_monitor.start() - - # First round of optimization - try: - logging.info("Starting first round of optimization.") - results_round1, new_atoms_files = self.continuous_batching( - atoms_path=self.files, - result_path_prefix=os.path.join( - self.output_path, "cif_result_press/" - ), - fmax=0.01, - maxstep=self.max_steps, - use_filter=self.filter1, - optimizer=self.optimizer1, - scalar_pressure=self.scalar_pressure, - dtype=torch.float64, - ) - logging.info( - f"Completed first round of optimization. Results: {len(results_round1)}" - ) - except KeyboardInterrupt as e: - if use_torch_memory_monitor: - self.stop_event.set() - memory_monitor.join() - logging.error(f"Error during first round of optimization: {e}") - raise - except Exception as e: - logging.error(f"Error during first round of optimization: {e}") - raise - - if self.skip_second_stage: - logging.info("Skipping second round of optimization.") - return results_round1 - - # Second round of optimization without pressure - try: - logging.info("Starting second round of optimization.") - results_round2, _ = self.continuous_batching( - atoms_path=new_atoms_files, - result_path_prefix=os.path.join( - self.output_path, "cif_result_final/" - ), - fmax=0.01, - maxstep=self.max_steps, - # maxstep=3000, - use_filter=self.filter2, - optimizer=self.optimizer2, - scalar_pressure=0.0, - dtype=torch.float64, - ) - logging.info( - f"Completed second round of optimization. Results: {len(results_round2)}" - ) - except KeyboardInterrupt as e: - if use_torch_memory_monitor: - self.stop_event.set() - memory_monitor.join() - logging.error(f"Error during second round of optimization: {e}") - raise - except Exception as e: - logging.error(f"Error during second round of optimization: {e}") - raise - - if use_torch_memory_monitor: - self.stop_event.set() - memory_monitor.join() - - return self._save_results_to_csv(results_round1, results_round2) - - def _save_results_to_csv(self, results_round1, results_round2): - """Helper method to save results to CSV file and return the path.""" - combined_results = [] - results_map = {} - - # Process first round results - for result in results_round1: - file_name = result["file"] - results_map[file_name] = { - "file": file_name, - "stage1_steps": result["steps"], - "stage1_time": result["runtime"], - "stage1_energy": result["energy"], - "stage1_density": result["density"], - "stage2_steps": 0, - "stage2_time": 0.0, - "stage2_energy": 0.0, - "stage2_density": 0, - "total_steps": result["steps"], - "total_time": result["runtime"], - } - - # Process second round results - for result in results_round2: - file_name = result["file"] - if file_name in results_map: - results_map[file_name].update( - { - "stage2_steps": result["steps"], - "stage2_time": result["runtime"], - "stage2_energy": result["energy"], - "stage2_density": result["density"], - "total_steps": results_map[file_name]["stage1_steps"] - + result["steps"], - "total_time": results_map[file_name]["stage1_time"] - + result["runtime"], - } - ) - else: - results_map[file_name] = { - "file": file_name, - "stage1_steps": 0, - "stage1_time": 0.0, - "stage1_energy": 0.0, - "stage1_density": 0, - "stage2_steps": result["steps"], - "stage2_time": result["runtime"], - "stage2_energy": result["energy"], - "stage2_density": result["density"], - "total_steps": result["steps"], - "total_time": result["runtime"], - } - - # Convert map to list - combined_results = list(results_map.values()) - - logging.info( - f"Combined results from both rounds. Total results: {len(combined_results)}" - ) - - worker_id = os.getpid() - timestamp = int(time.time()) - csv_filename = f"worker_{worker_id}_{timestamp}.csv" - csv_path = os.path.join( - self.output_path, "worker_results", csv_filename - ) - os.makedirs(os.path.dirname(csv_path), exist_ok=True) - - with open(csv_path, mode="w", newline="") as csvfile: - fieldnames = [ - "file", - "stage1_steps", - "stage1_time", - "stage1_energy", - "stage1_density", - "stage2_steps", - "stage2_time", - "stage2_energy", - "stage2_density", - "total_steps", - "total_time", - ] - writer = csv.DictWriter(csvfile, fieldnames=fieldnames) - writer.writeheader() - for result in combined_results: - writer.writerow(result) - - return csv_path - - def _get_density(self, crystal): - # 计算总质量,ASE 中的 get_masses 方法返回一个数组,包含了所有原子的质量 - total_mass = sum(crystal.get_masses()) # 转换为克 - - # 获取体积,ASE 的 get_volume 方法返回晶胞的体积,单位是 Å^3 - # 1 Å^3 = 1e-24 cm^3 - volume = crystal.get_volume() # 转换为立方厘米 - - # 计算密度,质量除以体积 - density = ( - total_mass / (volume * 10**-24) / (6.022140857 * 10**23) - ) # 单位是 g/cm^3 - return density - - @staticmethod - def select_factor(history: deque): - # TODO: when history is mix of different size, the smaller `values` should be selected. - boundaries = [0, 50, 100, 200, 400, 800] - values = [0.4, 0.8, 0.9, 0.6, 0.5, 0.4] - factor_result = [] - for graph_size in history: - for i in range(len(boundaries) - 1): - if boundaries[i] <= graph_size < boundaries[i + 1]: - factor_result.append(values[i]) - break - if len(factor_result) == 0: - return 0.4 - else: - return min(factor_result) - - def continuous_batching( - self, - atoms_path, - result_path_prefix, - fmax, - maxstep, - use_filter, - optimizer, - scalar_pressure, - dtype=torch.float64, - ): - """ - Performs continuous batched optimization of atomic structures. - - This method implements a continuous batching strategy for optimizing multiple atomic structures, - where converged structures are replaced with new ones to maintain batch efficiency. - - Parameters - ---------- - atoms_path : list - List of file paths to atomic structure files to be optimized - result_path_prefix : str - Prefix for output file paths where optimized structures will be saved - fmax : float, optional - Maximum force criterion for convergence, by default 0.01 - maxstep : int, optional - Maximum number of optimization steps per batch, by default 3000 - use_filter : str, optional - Filter to be used for optimization, by default "UnitCellFilter" - optimizer : str, optional - Optimizer to be used for optimization, by default "LBFGS" - scalar_pressure : float, optional - Scalar pressure to be applied, by default 0.0 - - Returns - ------- - None - The optimized structures are saved to disk - - Notes - ----- - The method: - - Processes structures in batches of predefined size - - Uses MACE neural network potential for energy/force calculations - - Employs LBFGS optimization with unit cell relaxation - - Dynamically replaces converged structures with new ones in the batch - - Tracks convergence and optimization steps for each structure - """ - # Load saved structures - result = [] - optimized_atoms_paths = [] - - json_dir = result_path_prefix.replace("cif", "json") - - remove_list = [] - # TODO: Why we read all CIF here? - for pre_cif in atoms_path: - cif_path = os.path.join(result_path_prefix, pre_cif.split("/")[-1]) - json_path = os.path.join( - json_dir, pre_cif.split("/")[-1].replace(".cif", ".json") - ) - if ( - os.path.exists(cif_path) - and os.path.exists(json_path) - and os.path.getsize(cif_path) > 0 - and os.path.getsize(json_path) > 0 - ): - with open(json_path, "r") as f: - result_data = json.load(f) - result.append(result_data) - optimized_atoms_paths.append(cif_path) - remove_list.append(pre_cif) - logging.info(f"File {cif_path} already exists, loaded.") - # else: - # try: - # read(pre_cif) - # except Exception as e: - # logging.info(f"Failed to read {pre_cif}: {e}") - # remove_list.append(pre_cif) - for i in remove_list: - atoms_path.remove(i) - - if self.batch_size > 0: - # Initialize variables - room_in_batch = self.batch_size - indices_to_process = 0 - cur_batch_path = atoms_path[ - indices_to_process : indices_to_process + room_in_batch - ] - if len(cur_batch_path) == 0: - logging.info("No structures to process.") - return result, optimized_atoms_paths - room_in_batch -= len(cur_batch_path) - indices_to_process += len(cur_batch_path) - cur_atoms_list = [read(path) for path in cur_batch_path] - a2g = AtomsToGraphs(r_edges=False, r_pbc=True) - gbatch = data_list_collater( - [a2g.convert(read(path)) for path in cur_batch_path] - ) - else: - # Set Maximum Number of atoms per batch - history = deque(maxlen=10) - history.append(1000) - max_bnatoms = 24080 - safe_factor = self.select_factor(history) - - indices_to_process = 0 - bnatoms = 0 - cur_batch_path = [] - graphs_list = [] - a2g = AtomsToGraphs(r_edges=False, r_pbc=True) - - while indices_to_process < len(atoms_path): - graph_natoms = count_atoms_cif(atoms_path[indices_to_process]) - if ( - bnatoms + graph_natoms - > max_bnatoms * safe_factor // self.nproc - ): - break - graph = a2g.convert(read(atoms_path[indices_to_process])) - bnatoms += graph_natoms - cur_batch_path.append(atoms_path[indices_to_process]) - graphs_list.append(graph) - indices_to_process += 1 - history.append(graph_natoms) - safe_factor = self.select_factor(history) - if len(graphs_list) == 0: - logging.info("No structures to process.") - return result, optimized_atoms_paths - gbatch = data_list_collater(graphs_list) - logging.info(f"current batch size: {len(cur_batch_path)}") - - total_natoms = sum([graph.natoms for graph in graphs_list]) - logging.info(f"total_natoms: {total_natoms}") - - gbatch = gbatch.to(self.device) - batch_optimizer = None - - # Initial calculator - if self.model == "mace": - if dtype == torch.float32: - calculator = mace_off( - model="small", - device=self.device, - enable_cueq=self.cueq, - default_dtype="float32", - ) - else: - calculator = mace_off( - model="small", device=self.device, enable_cueq=self.cueq - ) - elif self.model == "chgnet": - calculator = CHGNetCalculator( - use_device=self.device, enable_cueq=self.cueq - ) - elif self.model == "sevennet": - # calculator = SevenNetCalculator(device=self.device, enable_cueq=self.cueq) - calculator = SevenNetD3Calculator( - device=self.device, - enable_cueq=self.cueq, - batch_size=self.batch_size, - ) - # calculator = SevenNetCalculator('7net-mf-ompa', modal='mpa', device=self.device) - # calculator = MACECalculator(model_paths="/home/mazhaojia/.cache/mace/MACE-OFF23_small.model", device=self.device, compile_mode=self.compile_mode) - if use_filter == "UnitCellFilter": - obatch = OptimizableUnitCellBatch( - gbatch, - trainer=calculator, - numpy=False, - scalar_pressure=scalar_pressure, - ) - else: - obatch = OptimizableBatch(gbatch, trainer=calculator, numpy=False) - - orig_cells = obatch.orig_cells.clone() - - converged_atoms_count = 0 - converge_indices = [] - all_indices = [] - cur_batch_steps = [0] * len(cur_batch_path) - cur_batch_times = [time.perf_counter()] * len( - cur_batch_path - ) # Track start times - - while converged_atoms_count < len(atoms_path): - # Update batch - if len(all_indices) > 0: - if self.batch_size > 0: - room_in_batch += len(all_indices) - new_batch_path = atoms_path[ - indices_to_process : indices_to_process + room_in_batch - ] - logging.info(f"new_batch_path: {new_batch_path}") - room_in_batch -= len(new_batch_path) - indices_to_process += len(new_batch_path) - - optimized_atoms_new = [] - cur_batch_path_new = [] - cur_batch_steps_new = [] - cur_batch_times_new = [] - orig_cells_new = torch.zeros( - [self.batch_size - room_in_batch, 3, 3], - device=self.device, - ) - cell_offset = 0 - - restart_indices = [] - old_batch_indices = obatch.batch_indices - for i in range(len(optimized_atoms)): - if i in all_indices: - continue - else: - restart_indices.append(i) - optimized_atoms_new.append(optimized_atoms[i]) - cur_batch_path_new.append(cur_batch_path[i]) - cur_batch_steps_new.append(cur_batch_steps[i]) - cur_batch_times_new.append(cur_batch_times[i]) - - orig_cells_new[cell_offset] = orig_cells[i] - cell_offset += 1 - - for new_path in new_batch_path: - optimized_atoms_new.append(read(new_path)) - cur_batch_path_new.append(new_path) - cur_batch_steps_new.append(0) - cur_batch_times_new.append(time.perf_counter()) - - # Update the batch with new structures - optimized_atoms = optimized_atoms_new - cur_batch_path = cur_batch_path_new - cur_batch_steps = cur_batch_steps_new - cur_batch_times = cur_batch_times_new - else: - bnatoms = 0 - optimized_atoms_new = [] - cur_batch_path_new = [] - cur_batch_steps_new = [] - cur_batch_times_new = [] - - restart_indices = [] - old_batch_indices = obatch.batch_indices - for i in range(len(optimized_atoms)): - if i in all_indices: - continue - restart_indices.append(i) - optimized_atoms_new.append(optimized_atoms[i]) - cur_batch_path_new.append(cur_batch_path[i]) - cur_batch_steps_new.append(cur_batch_steps[i]) - cur_batch_times_new.append(cur_batch_times[i]) - bnatoms += a2g.convert(read(cur_batch_path[i])).natoms - - while indices_to_process < len(atoms_path): - new_path = atoms_path[indices_to_process] - graph_natoms = count_atoms_cif(new_path) - if ( - bnatoms + graph_natoms - > max_bnatoms * safe_factor // self.nproc - ): - break - bnatoms += graph_natoms - optimized_atoms_new.append(read(new_path)) - cur_batch_path_new.append(new_path) - cur_batch_steps_new.append(0) - cur_batch_times_new.append(time.perf_counter()) - indices_to_process += 1 - history.append(graph_natoms) - safe_factor = self.select_factor(history) - - orig_cells_new = torch.zeros( - [len(optimized_atoms_new), 3, 3], device=self.device - ) - cell_offset = 0 - for i in range(len(optimized_atoms)): - if i in all_indices: - continue - orig_cells_new[cell_offset] = orig_cells[i] - cell_offset += 1 - - # Update the batch with new structures - optimized_atoms = optimized_atoms_new - cur_batch_path = cur_batch_path_new - cur_batch_steps = cur_batch_steps_new - cur_batch_times = cur_batch_times_new - - logging.info(f"current batch size: {len(optimized_atoms)}") - - graphs_list = [a2g.convert(atoms) for atoms in optimized_atoms] - total_natoms = sum([graph.natoms for graph in graphs_list]) - logging.info(f"total_natoms: {total_natoms}") - logging.info(f"cur_batch_path to processing: {cur_batch_path}") - gbatch = data_list_collater(graphs_list) - gbatch = gbatch.to(self.device) - if self.model == "sevennet": - # calculator = SevenNetCalculator('7net-mf-ompa', modal='mpa', device=self.device) - calculator = SevenNetD3Calculator( - device=self.device, - enable_cueq=self.cueq, - batch_size=self.batch_size, - ) - if use_filter == "UnitCellFilter": - obatch = OptimizableUnitCellBatch( - gbatch, - trainer=calculator, - numpy=False, - scalar_pressure=scalar_pressure, - ) - else: - obatch = OptimizableBatch( - gbatch, trainer=calculator, numpy=False - ) - for i in range(cell_offset): - obatch.orig_cells[i] = orig_cells_new[i] - orig_cells = obatch.orig_cells.clone() - - # Optimize the current batch - if optimizer == "LBFGS": - batch_optimizer = LBFGS( - obatch, - damping=1.0, - alpha=70.0, - maxstep=0.2, - early_stop=True, - ) - elif optimizer == "BFGS": - if len(all_indices) > 0: - logging.info(f"Restarting with indices: {restart_indices}") - batch_optimizer.optimizable = obatch - else: - batch_optimizer = BFGS( - obatch, alpha=70.0, maxstep=0.2, early_stop=True - ) - elif optimizer == "BFGSLineSearch": - batch_optimizer = BFGSLineSearch( - obatch, - device=self.device, - early_stop=True, - use_profiler=self.use_profiler, - profiler_log_dir=self.profiler_log_dir, - profiler_schedule_config=self.profiler_schedule_config, - ) - elif optimizer == "BFGSFusedLS": - if len(all_indices) > 0: - logging.info(f"Restarting with indices: {restart_indices}") - batch_optimizer.optimizable = obatch - else: - batch_optimizer = BFGSFusedLS( - obatch, - device=self.device, - early_stop=True, - use_profiler=self.use_profiler, - profiler_log_dir=self.profiler_log_dir, - profiler_schedule_config=self.profiler_schedule_config, - ) - else: - raise ValueError(f"Unknown optimizer: {optimizer}") - - # 动态计算剩余可用步数(基于当前批次最大已执行步数) - current_max_steps = max(cur_batch_steps) if cur_batch_steps else 0 - remaining_steps = max( - maxstep - current_max_steps, 1 - ) # 保证至少运行1步 - - # 执行优化并获取收敛的索引 - if (optimizer == "BFGSFusedLS" or optimizer == "BFGS") and len( - all_indices - ) > 0: - converge_indices = batch_optimizer.run( - fmax, - remaining_steps, - is_restart_earlystop=True, - restart_indices=restart_indices, - old_batch_indices=old_batch_indices, - ) - else: - converge_indices = batch_optimizer.run(fmax, remaining_steps) - - # Print energies of all structures - # logging.info(f"Final energies of all structures: {batch_optimizer.energies}") - energies_list = ( - batch_optimizer.optimizable.get_potential_energies().tolist() - ) - logging.info(f"Final energies of all structures: {energies_list}") - - # 更新所有结构的累计步数 - cur_batch_steps = [ - steps + batch_optimizer.nsteps for steps in cur_batch_steps - ] - - # 找出超过最大步数的结构索引 - over_maxstep_indices = [ - i - for i, steps in enumerate(cur_batch_steps) - if steps >= maxstep - 1 - ] - - # 合并收敛和超限的索引(去重) - all_indices = list(set(converge_indices + over_maxstep_indices)) - - # Get optimized atoms - optimized_atoms = obatch.get_atoms_list() - converged_atoms_count += len(all_indices) - - end_time = time.perf_counter() - # 处理所有需要退出的结构(包括收敛和超限) - for idx in all_indices: - runtime = end_time - cur_batch_times[idx] - - energy_per_mol = ( - energies_list[idx] - / ( - len(optimized_atoms[idx].get_atomic_numbers()) - / self.molecule_single - ) - * 96.485 - ) - density = self._get_density(optimized_atoms[idx]) - - # Save results - result_data = { - "file": cur_batch_path[idx].split("/")[-1].split(".")[0], - "steps": cur_batch_steps[idx], - "runtime": runtime, - "energy": energy_per_mol, - "density": density, - } - result.append(result_data) - - # Save optimized structure - # converged_atoms_path = os.path.join(result_path_prefix, cur_batch_path[idx].split('/')[-1].replace('.cif', '.traj')) - converged_atoms_path = os.path.join( - result_path_prefix, cur_batch_path[idx].split("/")[-1] - ) - optimized_atoms[idx].write(converged_atoms_path) - optimized_atoms_paths.append(converged_atoms_path) - - # write a json file to store reslt_data - os.makedirs(json_dir, exist_ok=True) - # json_path = os.path.join(json_dir, cur_batch_path[idx].split('/')[-1]+'.json') - json_path = os.path.join( - json_dir, - cur_batch_path[idx].split("/")[-1].replace(".cif", ".json"), - ) - with open(json_path, "w") as f: - json.dump(result_data, f) - - logging.info(f"cur_batch_path: {cur_batch_path}") - logging.info(f"cur_batch_steps: {cur_batch_steps}") - logging.info(f"all_indices: {all_indices}") - logging.info(f"length of optimized_atoms: {len(optimized_atoms)}") - - return result, optimized_atoms_paths +""" +Copyright (c) 2025 {Chengxi Zhao, Zhaojia Ma, Dingrui Fan} + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +from ase.io import read + +# from ase.optimize import ASE_LBFGS +import torch +from torch.multiprocessing import Process, set_start_method +from batchopt.atoms_to_graphs import AtomsToGraphs +from batchopt.utils import data_list_collater +from batchopt.relaxation.optimizers import ( + BFGS, + BFGSFusedLS, +) +from batchopt.relaxation import OptimizableBatch, OptimizableUnitCellBatch +import logging +import time +import csv +from multiprocessing import Queue +import os +import psutil +import multiprocessing +import json +import subprocess + +try: + from chgnet.model.dynamics import CHGNetCalculator +except ImportError: + logging.warning("Failed to import CHGNet modules") + +try: + from sevenn.calculator import SevenNetCalculator, SevenNetD3Calculator +except ImportError: + logging.warning("Failed to import SevenNet modules") + +try: + from fairchem.core import pretrained_mlip, FAIRChemCalculator +except ImportError: + logging.warning("Failed to import FAIRChem modules") + +try: + from mace.calculators import mace_off +except ImportError: + logging.warning("Failed to import MACE modules") + +import threading +from .utils import count_atoms_cif +from collections import deque + + +class Scheduler: + """ + Scheduler distributes relaxation tasks to workers. + """ + + def __init__( + self, + files, + num_workers, + devices, + batch_size, + max_steps, + filter1, + filter2, + optimizer1, + optimizer2, + skip_second_stage, + scalar_pressure, + compile_mode, + profile, + num_threads, + bind_cores, + cueq, + molecule_single, + output_path, + model, + ): + + self.files = files + self.num_workers = num_workers + self.devices = devices + self.batch_size = batch_size + self.max_steps = max_steps + self.filter1 = filter1 + self.filter2 = filter2 + self.optimizer1 = optimizer1 + self.optimizer2 = optimizer2 + self.skip_second_stage = skip_second_stage + self.scalar_pressure = scalar_pressure + self.compile_mode = compile_mode + self.profile = profile + self.num_threads = num_threads + self.cueq = cueq + self.molecule_single = molecule_single + self.output_path = ( + output_path + if os.path.isabs(output_path) + else os.path.abspath(output_path) + ) + self.model = model + + try: + set_start_method("spawn") + except RuntimeError: + logging.warning( + "set_start_method('spawn') failed, trying 'forkserver' instead." + ) + + if bind_cores is not None: + self.cpu_mask = self._parse_bind_cores(bind_cores) + else: + self.cpu_mask = None + + def _parse_bind_cores(self, bind_cores): + # Expect custom_bind_str to be like "0-15,16-31,..." + ranges = bind_cores.split(",") + if len(ranges) != self.num_workers: + return None + binding = [] + for r in ranges: + try: + start_str, end_str = r.split("-") + start = int(start_str) + end = int(end_str) + except ValueError: + logging.error("Custom binding format should be 'start-end'.") + return None + + binding.append(set(range(start, end + 1))) + return binding + + def _get_physical_logical_core_mapping(self): + """Get the mapping between logical cores and their physical core IDs.""" + try: + # This information is available in Linux systems + mapping = {} + logical_cores = psutil.cpu_count(logical=True) + + for i in range(logical_cores): + try: + # Read core_id from /sys/devices/system/cpu/cpu{i}/topology/core_id + with open( + f"/sys/devices/system/cpu/cpu{i}/topology/core_id" + ) as f: + core_id = int(f.read().strip()) + # Read physical_package_id (socket) for more complete information + with open( + f"/sys/devices/system/cpu/cpu{i}/topology/physical_package_id" + ) as f: + package_id = int(f.read().strip()) + mapping[i] = (package_id, core_id) + except (FileNotFoundError, ValueError, IOError): + mapping[i] = None + return mapping + except Exception as e: + logging.error(f"Failed to get core mapping: {e}") + return {} + + def _get_physical_core_mask(self): + # Get the number of physical and logical cores + physical_cores = psutil.cpu_count(logical=False) + logical_cores = psutil.cpu_count(logical=True) + + if physical_cores is None or physical_cores < 1: + # Fallback to multiprocessing if psutil fails + logical_cores = multiprocessing.cpu_count() + physical_cores = logical_cores // 2 + if physical_cores < 1: + physical_cores = 1 + print(f"Using estimated physical cores: {physical_cores}") + + # Get the mapping between logical and physical cores + core_mapping = self._get_physical_logical_core_mapping() + + # Create a CPU mask that includes all physical cores (first core of each physical core) + physical_core_mask = set() + if core_mapping: + # Group by physical core ID + cores_by_physical = {} + for logical_id, physical_info in core_mapping.items(): + if physical_info is not None: + package_id, core_id = physical_info + key = (package_id, core_id) + if key not in cores_by_physical: + cores_by_physical[key] = [] + cores_by_physical[key].append(logical_id) + + # Select one logical core from each physical core + for physical_cores_list in cores_by_physical.values(): + physical_core_mask.add( + physical_cores_list[0] + ) # First logical core of each physical core + else: + # If mapping fails, use a simple assumption (may not be accurate on all systems) + threads_per_core = logical_cores // physical_cores + physical_core_mask = set(range(0, logical_cores, threads_per_core)) + + return physical_core_mask + + def worker_task( + self, files, device, batch_size, result_queue, physical_cores + ): + if physical_cores is not None: + try: + # Bind the current process to physical cores + pid = os.getpid() + os.sched_setaffinity(pid, physical_cores) + logging.info(f"bind to physical_core_ids: {physical_cores}") + + # Verify the affinity was set correctly + current_affinity = os.sched_getaffinity(pid) + logging.info( + f"Process bound to {len(current_affinity)} cores: {sorted(current_affinity)}" + ) + + except AttributeError: + logging.error( + "sched_setaffinity not supported on this platform" + ) + except Exception as e: + logging.error(f"Failed to bind to physical cores: {e}") + + # pass the number of processes on each worker + nproc = self.num_workers // len(self.devices) + + worker = Worker( + files, + device, + batch_size, + self.max_steps, + self.filter1, + self.filter2, + self.optimizer1, + self.optimizer2, + self.skip_second_stage, + self.scalar_pressure, + self.compile_mode, + self.profile, + self.cueq, + self.molecule_single, + self.output_path, + self.model, + nproc, + ) + # results = worker.run() + results = worker.continuous_run() + result_queue.put(results) + + def _terminate_processes(self, processes): + """Helper method to terminate all processes.""" + for i, p in processes: + if p.is_alive(): + logging.info(f"Terminating process {p.pid}") + p.terminate() + p.join(timeout=3) # Wait for up to 3 seconds + if p.is_alive(): + logging.warning( + f"Process {p.pid} did not terminate, killing it" + ) + p.kill() + p.join() + + # create a thread to conduct "nvidia-smi" + @staticmethod + def _monitor_memory(interval=2, gpu_index=1): + try: + while True: + result = subprocess.check_output( + [ + "nvidia-smi", + "--query-gpu=memory.used,memory.total", + "--format=csv,nounits,noheader", + ] + ).decode("utf-8") + + lines = result.strip().split("\n") + used, total = map(int, lines[gpu_index].split(",")) + logging.info( + f"[nvidia-smi] Memory-Usage on GPU {gpu_index}: {used}MiB / {total}MiB" + ) + + time.sleep(interval) + except KeyboardInterrupt: + logging.info("Monitor interrupted.") + + except Exception as e: + logging.error(f"Unexpected error when monitor memory: {str(e)}") + + def run(self): + logging.info(f"Starting Scheduler with {self.num_workers} workers.") + processes = [] + result_queue = Queue() + start_time = time.perf_counter() + + if self.cpu_mask is not None: + physical_cores_per_worker = self.cpu_mask + logging.info( + f"Use customed cores binding. Physical cores per worker: {physical_cores_per_worker}" + ) + else: + # all_physical_cores = self._get_physical_core_mask() + # num_per_worker = len(all_physical_cores) // self.num_workers + # physical_cores_per_worker = [ + # list(all_physical_cores)[i:i + num_per_worker] for i in range(0, len(all_physical_cores), num_per_worker) + # ] + # logging.info(f"Physical cores per worker: {physical_cores_per_worker}") + physical_cores_per_worker = [None] * self.num_workers + + try: + # Start all worker processes + for i in range(self.num_workers): + files_for_worker = self.files[i :: self.num_workers] + device = self.devices[i % len(self.devices)] + logging.info( + f"Starting worker {i} with {len(files_for_worker)} files on device {device}." + ) + p = Process( + target=self.worker_task, + args=( + files_for_worker, + device, + self.batch_size, + result_queue, + physical_cores_per_worker[i], + ), + ) + p.start() + processes.append((i, p)) + + # monitor gpu memory usage to figure out what makes the differences of footprint among batches + # in each iteration. + use_memory_monitor = False + if use_memory_monitor: + monitor_proc = Process( + target=Scheduler._monitor_memory, args=() + ) + monitor_proc.start() + + # Monitor processes and collect results + csv_paths = [] + completed_processes = 0 + while completed_processes < self.num_workers: + for i, p in processes: + if not p.is_alive() and p.exitcode != 0: + if p.exitcode == -11 or p.exitcode == 1: + # Restart the process if exit code is -11 or -1 + logging.warning( + f"Worker process {p.pid} exited with code {p.exitcode}. Restarting worker {i}." + ) + files_for_worker = self.files[i :: self.num_workers] + device = self.devices[i % len(self.devices)] + new_process = Process( + target=self.worker_task, + args=( + files_for_worker, + device, + self.batch_size, + result_queue, + physical_cores_per_worker[i], + ), + ) + new_process.start() + processes[i] = ( + i, + new_process, + ) # Replace the old process with the new one + else: + # Raise an error for other exit codes + raise RuntimeError( + f"Worker process {p.pid} failed with exit code {p.exitcode}" + ) + + # Try to get result from queue with timeout + try: + result = result_queue.get(timeout=10) + csv_paths.append(result) + completed_processes += 1 + except Exception as e: + continue + + # terminate monitor + if use_memory_monitor: + monitor_proc.terminate() + monitor_proc.join() + + # Process results and create final CSV + merged_results = [] + for csv_path in csv_paths: + try: + with open(csv_path, mode="r") as f: + reader = csv.DictReader(f) + merged_results.extend(list(reader)) + except Exception as e: + logging.error(f"Error processing {csv_path}: {str(e)}") + + except Exception as e: + # Log the error and elapsed time + end_time = time.perf_counter() + elapsed_time = end_time - start_time + logging.error( + f"Error occurred after running for {elapsed_time:.2f} seconds: {str(e)}" + ) + + # Create error log file + error_log = f"scheduler_error_{int(time.time())}.log" + with open(error_log, "w") as f: + f.write(f"Error occurred after {elapsed_time:.2f} seconds\n") + f.write(f"Error message: {str(e)}\n") + f.write(f"Number of workers: {self.num_workers}\n") + f.write(f"Batch size: {self.batch_size}\n") + + # Terminate all processes + self._terminate_processes(processes) + raise # Re-raise the exception after cleanup + + finally: + end_time = time.perf_counter() + elapsed_time = end_time - start_time + + # Write final results if we have any + if "merged_results" in locals() and merged_results: + csv_file = os.path.join( + self.output_path, "results_scheduler.csv" + ) + with open(csv_file, mode="w", newline="") as file: + writer = csv.DictWriter( + file, + fieldnames=[ + "file", + "stage1_steps", + "stage1_time", + "stage1_energy", + "stage1_density", + "stage2_steps", + "stage2_time", + "stage2_energy", + "stage2_density", + "total_steps", + "total_time", + ], + ) + writer.writeheader() + for row in merged_results: + try: + processed_row = { + "file": row["file"], + "stage1_steps": int(row["stage1_steps"]), + "stage1_time": float(row["stage1_time"]), + "stage1_energy": float(row["stage1_energy"]), + "stage1_density": float(row["stage1_density"]), + "stage2_steps": int(row["stage2_steps"]), + "stage2_time": float(row["stage2_time"]), + "stage2_energy": float(row["stage2_energy"]), + "stage2_density": float(row["stage2_density"]), + "total_steps": int(row["total_steps"]), + "total_time": float(row["total_time"]), + } + writer.writerow(processed_row) + except (KeyError, ValueError) as e: + logging.error( + f"Invalid data format in row {row}: {str(e)}" + ) + + # Write summary + summary_csv_file = os.path.join( + self.output_path, "summary_scheduler.csv" + ) + with open(summary_csv_file, mode="w", newline="") as file: + writer = csv.DictWriter( + file, + fieldnames=["elapsed_time", "num_workers", "batch_size"], + ) + writer.writeheader() + writer.writerow( + { + "elapsed_time": elapsed_time, + "num_workers": self.num_workers, + "batch_size": self.batch_size, + } + ) + + logging.info(f"Scheduler completed in {elapsed_time:.2f} seconds.") + + def run_debug(self): + logging.info("Starting Scheduler in debug mode (sequential execution).") + + def worker_task(files, device, batch_size): + worker = Worker( + files, device, batch_size, self.max_steps, self.filter1 + ) + worker.run() + + for i in range(self.num_workers): + files_for_worker = self.files[i :: self.num_workers] + device = self.devices[i % len(self.devices)] + logging.info( + f"Running worker {i} with {len(files_for_worker)} files on device {device}." + ) + worker_task(files_for_worker, device, self.batch_size) + + logging.info("All workers have completed their tasks in debug mode.") + + +class Worker: + """ + Worker is single process that runs a batch of optimization tasks. + """ + + def __init__( + self, + files, + device, + batch_size, + max_steps, + filter1, + filter2, + optimizer1, + optimizer2, + skip_second_stage, + scalar_pressure, + compile_mode, + profile, + cueq, + molecule_single, + output_path, + model, + nproc, + ): + self.files = files + self.device = device + self.batch_size = batch_size + self.max_steps = max_steps + self.filter1 = filter1 + self.filter2 = filter2 + self.optimizer1 = optimizer1 + self.optimizer2 = optimizer2 + self.skip_second_stage = skip_second_stage # Store skip_second_stage + self.scalar_pressure = scalar_pressure + self.compile_mode = compile_mode + self.profile = profile + self.cueq = cueq + self.molecule_single = molecule_single + self.output_path = ( + output_path + if os.path.isabs(output_path) + else os.path.abspath(output_path) + ) + self.model = model + self.nproc = nproc + + # Parse profiler options if provided + self.use_profiler = False + self.profiler_schedule_config = { + "wait": 48, + "warmup": 1, + "active": 1, + "repeat": 1, + } + self.profiler_log_dir = None + + if self.profile and self.profile != "False": + self.use_profiler = True + # Create directory for profiler output + self.profiler_log_dir = os.path.join(self.output_path, "log") + os.makedirs(self.profiler_log_dir, exist_ok=True) + if self.profile != "True": + try: + # Try to parse profile as a JSON string with schedule config + profile_config = json.loads(self.profile) + if isinstance(profile_config, dict): + for key in ["wait", "warmup", "active", "repeat"]: + if key in profile_config and isinstance( + profile_config[key], int + ): + self.profiler_schedule_config[key] = ( + profile_config[key] + ) + except json.JSONDecodeError: + logging.warning( + f"Could not parse profile config: {self.profile}, using defaults" + ) + + # For monitor thread + self.stop_event = threading.Event() + + def run(self): + logging.info( + f"Worker started on device {self.device} with {len(self.files)} files." + ) + a2g = AtomsToGraphs(r_edges=False, r_pbc=True) + # model = torch.load("/home/mazhaojia/.cache/mace/MACE-OFF23_small.model", map_location=self.device) + # z_table = utils.AtomicNumberTable([int(z) for z in model.atomic_numbers]) + calculator = mace_off(model="small", device=self.device) + + results = [] + + for batch_files in self._batch_files(self.files, self.batch_size): + logging.info(f"Processing batch with {len(batch_files)} files.") + start_time = time.perf_counter() + + atoms_list = [] + for file in batch_files: + atoms = read(file) + atoms_list.append(atoms) + gbatch = data_list_collater( + [a2g.convert(atoms) for atoms in atoms_list] + ) + + gbatch = gbatch.to(self.device) + if self.filter1 == "UnitCellFilter": + from batchopt.relaxation import OptimizableUnitCellBatch + + obatch = OptimizableUnitCellBatch( + gbatch, + trainer=calculator, + numpy=False, + scalar_pressure=self.scalar_pressure, + ) + else: + obatch = OptimizableBatch( + gbatch, trainer=calculator, numpy=False + ) + + # First optimization stage + if self.optimizer1 == "LBFGS": + batch_optimizer1 = LBFGS( + obatch, damping=1.0, alpha=70.0, maxstep=0.2 + ) + elif self.optimizer1 == "BFGS": + batch_optimizer1 = BFGS(obatch, alpha=70.0, maxstep=0.2) + elif self.optimizer1 == "BFGSLineSearch": + batch_optimizer1 = BFGSLineSearch(obatch, device=self.device) + elif self.optimizer1 == "BFGSFusedLS": + batch_optimizer1 = BFGSFusedLS(obatch, device=self.device) + else: + raise ValueError(f"Unknown optimizer: {self.optimizer1}") + + start_time1 = time.perf_counter() + batch_optimizer1.run(0.01, self.max_steps) + end_time1 = time.perf_counter() + elapsed_time1 = end_time1 - start_time1 + + # Save intermediate results + atoms_list = obatch.get_atoms_list() + for atoms, file_path in zip(atoms_list, batch_files): + file_name = file_path.split("/")[-1] + output_file = os.path.join( + self.output_path, + "cif_result_press", + file_name.replace(".cif", "_press.cif"), + ) + atoms.write(output_file) + + # Capture maximum force after first optimization stage + max_force1 = obatch.get_max_forces(apply_constraint=True) + + steps1 = batch_optimizer1.nsteps + + if self.skip_second_stage: + # If skipping second stage, set metrics to zero + for file, force in zip(batch_files, max_force1): + results.append( + { + "file": file, + "stage1_time": elapsed_time1, + "stage1_steps": steps1, + "stage2_time": 0.0, + "stage2_steps": 0, + "total_time": elapsed_time1, + "total_steps": steps1, + "force1": force.item(), + "force2": 0.0, + } + ) + continue + + # Only proceed with second stage if not skipping + # Reload intermediate structures for second stage + atoms_list = [] + for file_path in batch_files: + file_name = file_path.split("/")[-1] + press_file = os.path.join( + self.output_path, + "cif_result_press", + file_name.replace(".cif", "_press.cif"), + ) + atoms = read(press_file) + atoms_list.append(atoms) + + # Rebuild batch from optimized structures + gbatch = data_list_collater( + [a2g.convert(atoms) for atoms in atoms_list] + ) + gbatch = gbatch.to(self.device) + + # Second optimization stage + if self.filter2 == "UnitCellFilter": + obatch2 = OptimizableUnitCellBatch( + gbatch, trainer=calculator, numpy=False, scalar_pressure=0.0 + ) + else: + obatch2 = OptimizableBatch( + gbatch, trainer=calculator, numpy=False + ) + + if self.optimizer2 == "LBFGS": + batch_optimizer2 = LBFGS( + obatch2, damping=1.0, alpha=70.0, maxstep=0.2 + ) + elif self.optimizer2 == "BFGS": + batch_optimizer2 = BFGS(obatch2, alpha=70.0, maxstep=0.2) + elif self.optimizer2 == "BFGSLineSearch": + batch_optimizer2 = BFGSLineSearch(obatch2, device=self.device) + elif self.optimizer2 == "BFGSFusedLS": + batch_optimizer2 = BFGSFusedLS(obatch2, device=self.device) + else: + raise ValueError(f"Unknown optimizer: {self.optimizer2}") + start_time2 = time.perf_counter() + batch_optimizer2.run(0.01, self.max_steps) + end_time2 = time.perf_counter() + elapsed_time2 = end_time2 - start_time2 + + # Save final results + atoms_list = obatch2.get_atoms_list() + for atoms, file_path in zip(atoms_list, batch_files): + file_name = file_path.split("/")[-1] + output_file = os.path.join( + self.output_path, + "cif_result_final", + file_name.replace(".cif", "_opt.cif"), + ) + atoms.write(output_file) + + # Capture maximum force after second optimization stage + max_force2 = obatch2.get_max_forces(apply_constraint=True) + + steps2 = batch_optimizer2.nsteps + + for file, f1, f2 in zip(batch_files, max_force1, max_force2): + results.append( + { + "file": file, + "stage1_time": elapsed_time1, + "stage1_steps": steps1, + "stage2_time": elapsed_time2, + "stage2_steps": steps2, + "total_time": elapsed_time1 + elapsed_time2, + "total_steps": steps1 + steps2, + "force1": f1.item(), + "force2": f2.item(), + } + ) + + return results + + def _batch_files(self, files, batch_size): + for i in range(0, len(files), batch_size): + yield files[i : i + batch_size] + + @staticmethod + def _torch_memory_monitor(interval=2, device=None, stop_event=None): + try: + # explicitly CUDA initialization + torch.cuda._lazy_init() + while not stop_event.is_set(): + allocated = torch.cuda.memory_allocated(device=device) + reserved = torch.cuda.memory_reserved(device=device) + logging.info( + f"[torch] Allocated Memory: {allocated / 1024**2:.2f} MiB" + ) + logging.info( + f"[torch] Reserved Memory: {reserved / 1024**2:.2f} MiB" + ) + time.sleep(interval) + except Exception as e: + logging.error(f"Unexpected error when monitor memory: {str(e)}") + + def continuous_run(self): + """ + Execute a continuous run of the batching optimization process. + """ + logging.info("Starting continuous_run with two rounds of optimization.") + + # torch memory monitor api + use_torch_memory_monitor = False + if use_torch_memory_monitor: + memory_monitor = threading.Thread( + target=Worker._torch_memory_monitor, + args=(2, self.device, self.stop_event), + ) + memory_monitor.start() + + # First round of optimization + try: + logging.info("Starting first round of optimization.") + results_round1, new_atoms_files = self.continuous_batching( + atoms_path=self.files, + result_path_prefix=os.path.join( + self.output_path, "cif_result_press/" + ), + fmax=0.01, + maxstep=self.max_steps, + use_filter=self.filter1, + optimizer=self.optimizer1, + scalar_pressure=self.scalar_pressure, + dtype=torch.float64, + ) + logging.info( + f"Completed first round of optimization. Results: {len(results_round1)}" + ) + except KeyboardInterrupt as e: + if use_torch_memory_monitor: + self.stop_event.set() + memory_monitor.join() + logging.error(f"Error during first round of optimization: {e}") + raise + except Exception as e: + logging.error(f"Error during first round of optimization: {e}") + raise + + if self.skip_second_stage: + logging.info("Skipping second round of optimization.") + return results_round1 + + # Second round of optimization without pressure + try: + logging.info("Starting second round of optimization.") + results_round2, _ = self.continuous_batching( + atoms_path=new_atoms_files, + result_path_prefix=os.path.join( + self.output_path, "cif_result_final/" + ), + fmax=0.01, + maxstep=self.max_steps, + # maxstep=3000, + use_filter=self.filter2, + optimizer=self.optimizer2, + scalar_pressure=0.0, + dtype=torch.float64, + ) + logging.info( + f"Completed second round of optimization. Results: {len(results_round2)}" + ) + except KeyboardInterrupt as e: + if use_torch_memory_monitor: + self.stop_event.set() + memory_monitor.join() + logging.error(f"Error during second round of optimization: {e}") + raise + except Exception as e: + logging.error(f"Error during second round of optimization: {e}") + raise + + if use_torch_memory_monitor: + self.stop_event.set() + memory_monitor.join() + + return self._save_results_to_csv(results_round1, results_round2) + + def _save_results_to_csv(self, results_round1, results_round2): + """Helper method to save results to CSV file and return the path.""" + combined_results = [] + results_map = {} + + # Process first round results + for result in results_round1: + file_name = result["file"] + results_map[file_name] = { + "file": file_name, + "stage1_steps": result["steps"], + "stage1_time": result["runtime"], + "stage1_energy": result["energy"], + "stage1_density": result["density"], + "stage2_steps": 0, + "stage2_time": 0.0, + "stage2_energy": 0.0, + "stage2_density": 0, + "total_steps": result["steps"], + "total_time": result["runtime"], + } + + # Process second round results + for result in results_round2: + file_name = result["file"] + if file_name in results_map: + results_map[file_name].update( + { + "stage2_steps": result["steps"], + "stage2_time": result["runtime"], + "stage2_energy": result["energy"], + "stage2_density": result["density"], + "total_steps": results_map[file_name]["stage1_steps"] + + result["steps"], + "total_time": results_map[file_name]["stage1_time"] + + result["runtime"], + } + ) + else: + results_map[file_name] = { + "file": file_name, + "stage1_steps": 0, + "stage1_time": 0.0, + "stage1_energy": 0.0, + "stage1_density": 0, + "stage2_steps": result["steps"], + "stage2_time": result["runtime"], + "stage2_energy": result["energy"], + "stage2_density": result["density"], + "total_steps": result["steps"], + "total_time": result["runtime"], + } + + # Convert map to list + combined_results = list(results_map.values()) + + logging.info( + f"Combined results from both rounds. Total results: {len(combined_results)}" + ) + + worker_id = os.getpid() + timestamp = int(time.time()) + csv_filename = f"worker_{worker_id}_{timestamp}.csv" + csv_path = os.path.join( + self.output_path, "worker_results", csv_filename + ) + os.makedirs(os.path.dirname(csv_path), exist_ok=True) + + with open(csv_path, mode="w", newline="") as csvfile: + fieldnames = [ + "file", + "stage1_steps", + "stage1_time", + "stage1_energy", + "stage1_density", + "stage2_steps", + "stage2_time", + "stage2_energy", + "stage2_density", + "total_steps", + "total_time", + ] + writer = csv.DictWriter(csvfile, fieldnames=fieldnames) + writer.writeheader() + for result in combined_results: + writer.writerow(result) + + return csv_path + + def _get_density(self, crystal): + # 计算总质量,ASE 中的 get_masses 方法返回一个数组,包含了所有原子的质量 + total_mass = sum(crystal.get_masses()) # 转换为克 + + # 获取体积,ASE 的 get_volume 方法返回晶胞的体积,单位是 Å^3 + # 1 Å^3 = 1e-24 cm^3 + volume = crystal.get_volume() # 转换为立方厘米 + + # 计算密度,质量除以体积 + density = ( + total_mass / (volume * 10**-24) / (6.022140857 * 10**23) + ) # 单位是 g/cm^3 + return density + + @staticmethod + def select_factor(history: deque): + # TODO: when history is mix of different size, the smaller `values` should be selected. + boundaries = [0, 50, 100, 200, 400, 800] + values = [0.4, 0.8, 0.9, 0.6, 0.5, 0.4] + factor_result = [] + for graph_size in history: + for i in range(len(boundaries) - 1): + if boundaries[i] <= graph_size < boundaries[i + 1]: + factor_result.append(values[i]) + break + if len(factor_result) == 0: + return 0.4 + else: + return min(factor_result) + + def continuous_batching( + self, + atoms_path, + result_path_prefix, + fmax, + maxstep, + use_filter, + optimizer, + scalar_pressure, + dtype=torch.float64, + ): + """ + Performs continuous batched optimization of atomic structures. + + This method implements a continuous batching strategy for optimizing multiple atomic structures, + where converged structures are replaced with new ones to maintain batch efficiency. + + Parameters + ---------- + atoms_path : list + List of file paths to atomic structure files to be optimized + result_path_prefix : str + Prefix for output file paths where optimized structures will be saved + fmax : float, optional + Maximum force criterion for convergence, by default 0.01 + maxstep : int, optional + Maximum number of optimization steps per batch, by default 3000 + use_filter : str, optional + Filter to be used for optimization, by default "UnitCellFilter" + optimizer : str, optional + Optimizer to be used for optimization, by default "LBFGS" + scalar_pressure : float, optional + Scalar pressure to be applied, by default 0.0 + + Returns + ------- + None + The optimized structures are saved to disk + + Notes + ----- + The method: + - Processes structures in batches of predefined size + - Uses MACE neural network potential for energy/force calculations + - Employs LBFGS optimization with unit cell relaxation + - Dynamically replaces converged structures with new ones in the batch + - Tracks convergence and optimization steps for each structure + """ + # Load saved structures + result = [] + optimized_atoms_paths = [] + + json_dir = result_path_prefix.replace("cif", "json") + + remove_list = [] + # TODO: Why we read all CIF here? + for pre_cif in atoms_path: + cif_path = os.path.join(result_path_prefix, pre_cif.split("/")[-1]) + json_path = os.path.join( + json_dir, pre_cif.split("/")[-1].replace(".cif", ".json") + ) + if ( + os.path.exists(cif_path) + and os.path.exists(json_path) + and os.path.getsize(cif_path) > 0 + and os.path.getsize(json_path) > 0 + ): + with open(json_path, "r") as f: + result_data = json.load(f) + result.append(result_data) + optimized_atoms_paths.append(cif_path) + remove_list.append(pre_cif) + logging.info(f"File {cif_path} already exists, loaded.") + # else: + # try: + # read(pre_cif) + # except Exception as e: + # logging.info(f"Failed to read {pre_cif}: {e}") + # remove_list.append(pre_cif) + for i in remove_list: + atoms_path.remove(i) + + if self.batch_size > 0: + # Initialize variables + room_in_batch = self.batch_size + indices_to_process = 0 + cur_batch_path = atoms_path[ + indices_to_process : indices_to_process + room_in_batch + ] + if len(cur_batch_path) == 0: + logging.info("No structures to process.") + return result, optimized_atoms_paths + room_in_batch -= len(cur_batch_path) + indices_to_process += len(cur_batch_path) + cur_atoms_list = [read(path) for path in cur_batch_path] + a2g = AtomsToGraphs(r_edges=False, r_pbc=True) + gbatch = data_list_collater( + [a2g.convert(read(path)) for path in cur_batch_path] + ) + else: + # Set Maximum Number of atoms per batch + history = deque(maxlen=10) + history.append(1000) + max_bnatoms = 24080 + safe_factor = self.select_factor(history) + + indices_to_process = 0 + bnatoms = 0 + cur_batch_path = [] + graphs_list = [] + a2g = AtomsToGraphs(r_edges=False, r_pbc=True) + + while indices_to_process < len(atoms_path): + graph_natoms = count_atoms_cif(atoms_path[indices_to_process]) + if ( + bnatoms + graph_natoms + > max_bnatoms * safe_factor // self.nproc + ): + break + graph = a2g.convert(read(atoms_path[indices_to_process])) + bnatoms += graph_natoms + cur_batch_path.append(atoms_path[indices_to_process]) + graphs_list.append(graph) + indices_to_process += 1 + history.append(graph_natoms) + safe_factor = self.select_factor(history) + if len(graphs_list) == 0: + logging.info("No structures to process.") + return result, optimized_atoms_paths + gbatch = data_list_collater(graphs_list) + logging.info(f"current batch size: {len(cur_batch_path)}") + + total_natoms = sum([graph.natoms for graph in graphs_list]) + logging.info(f"total_natoms: {total_natoms}") + + gbatch = gbatch.to(self.device) + batch_optimizer = None + + # Initial calculator + if self.model == "mace": + if dtype == torch.float32: + calculator = mace_off( + model="small", + device=self.device, + enable_cueq=self.cueq, + default_dtype="float32", + ) + else: + calculator = mace_off( + model="small", device=self.device, enable_cueq=self.cueq + ) + elif self.model == "chgnet": + calculator = CHGNetCalculator( + use_device=self.device, enable_cueq=self.cueq + ) + elif self.model == "sevennet": + # calculator = SevenNetCalculator(device=self.device, enable_cueq=self.cueq) + calculator = SevenNetD3Calculator( + device=self.device, + enable_cueq=self.cueq, + batch_size=self.batch_size, + ) + # calculator = SevenNetCalculator('7net-mf-ompa', modal='mpa', device=self.device) + # calculator = MACECalculator(model_paths="/home/mazhaojia/.cache/mace/MACE-OFF23_small.model", device=self.device, compile_mode=self.compile_mode) + if use_filter == "UnitCellFilter": + obatch = OptimizableUnitCellBatch( + gbatch, + trainer=calculator, + numpy=False, + scalar_pressure=scalar_pressure, + ) + else: + obatch = OptimizableBatch(gbatch, trainer=calculator, numpy=False) + + orig_cells = obatch.orig_cells.clone() + + converged_atoms_count = 0 + converge_indices = [] + all_indices = [] + cur_batch_steps = [0] * len(cur_batch_path) + cur_batch_times = [time.perf_counter()] * len( + cur_batch_path + ) # Track start times + + while converged_atoms_count < len(atoms_path): + # Update batch + if len(all_indices) > 0: + if self.batch_size > 0: + room_in_batch += len(all_indices) + new_batch_path = atoms_path[ + indices_to_process : indices_to_process + room_in_batch + ] + logging.info(f"new_batch_path: {new_batch_path}") + room_in_batch -= len(new_batch_path) + indices_to_process += len(new_batch_path) + + optimized_atoms_new = [] + cur_batch_path_new = [] + cur_batch_steps_new = [] + cur_batch_times_new = [] + orig_cells_new = torch.zeros( + [self.batch_size - room_in_batch, 3, 3], + device=self.device, + ) + cell_offset = 0 + + restart_indices = [] + old_batch_indices = obatch.batch_indices + for i in range(len(optimized_atoms)): + if i in all_indices: + continue + else: + restart_indices.append(i) + optimized_atoms_new.append(optimized_atoms[i]) + cur_batch_path_new.append(cur_batch_path[i]) + cur_batch_steps_new.append(cur_batch_steps[i]) + cur_batch_times_new.append(cur_batch_times[i]) + + orig_cells_new[cell_offset] = orig_cells[i] + cell_offset += 1 + + for new_path in new_batch_path: + optimized_atoms_new.append(read(new_path)) + cur_batch_path_new.append(new_path) + cur_batch_steps_new.append(0) + cur_batch_times_new.append(time.perf_counter()) + + # Update the batch with new structures + optimized_atoms = optimized_atoms_new + cur_batch_path = cur_batch_path_new + cur_batch_steps = cur_batch_steps_new + cur_batch_times = cur_batch_times_new + else: + bnatoms = 0 + optimized_atoms_new = [] + cur_batch_path_new = [] + cur_batch_steps_new = [] + cur_batch_times_new = [] + + restart_indices = [] + old_batch_indices = obatch.batch_indices + for i in range(len(optimized_atoms)): + if i in all_indices: + continue + restart_indices.append(i) + optimized_atoms_new.append(optimized_atoms[i]) + cur_batch_path_new.append(cur_batch_path[i]) + cur_batch_steps_new.append(cur_batch_steps[i]) + cur_batch_times_new.append(cur_batch_times[i]) + bnatoms += a2g.convert(read(cur_batch_path[i])).natoms + + while indices_to_process < len(atoms_path): + new_path = atoms_path[indices_to_process] + graph_natoms = count_atoms_cif(new_path) + if ( + bnatoms + graph_natoms + > max_bnatoms * safe_factor // self.nproc + ): + break + bnatoms += graph_natoms + optimized_atoms_new.append(read(new_path)) + cur_batch_path_new.append(new_path) + cur_batch_steps_new.append(0) + cur_batch_times_new.append(time.perf_counter()) + indices_to_process += 1 + history.append(graph_natoms) + safe_factor = self.select_factor(history) + + orig_cells_new = torch.zeros( + [len(optimized_atoms_new), 3, 3], device=self.device + ) + cell_offset = 0 + for i in range(len(optimized_atoms)): + if i in all_indices: + continue + orig_cells_new[cell_offset] = orig_cells[i] + cell_offset += 1 + + # Update the batch with new structures + optimized_atoms = optimized_atoms_new + cur_batch_path = cur_batch_path_new + cur_batch_steps = cur_batch_steps_new + cur_batch_times = cur_batch_times_new + + logging.info(f"current batch size: {len(optimized_atoms)}") + + graphs_list = [a2g.convert(atoms) for atoms in optimized_atoms] + total_natoms = sum([graph.natoms for graph in graphs_list]) + logging.info(f"total_natoms: {total_natoms}") + logging.info(f"cur_batch_path to processing: {cur_batch_path}") + gbatch = data_list_collater(graphs_list) + gbatch = gbatch.to(self.device) + if self.model == "sevennet": + # calculator = SevenNetCalculator('7net-mf-ompa', modal='mpa', device=self.device) + calculator = SevenNetD3Calculator( + device=self.device, + enable_cueq=self.cueq, + batch_size=self.batch_size, + ) + if use_filter == "UnitCellFilter": + obatch = OptimizableUnitCellBatch( + gbatch, + trainer=calculator, + numpy=False, + scalar_pressure=scalar_pressure, + ) + else: + obatch = OptimizableBatch( + gbatch, trainer=calculator, numpy=False + ) + for i in range(cell_offset): + obatch.orig_cells[i] = orig_cells_new[i] + orig_cells = obatch.orig_cells.clone() + + # Optimize the current batch + if optimizer == "LBFGS": + batch_optimizer = LBFGS( + obatch, + damping=1.0, + alpha=70.0, + maxstep=0.2, + early_stop=True, + ) + elif optimizer == "BFGS": + if len(all_indices) > 0: + logging.info(f"Restarting with indices: {restart_indices}") + batch_optimizer.optimizable = obatch + else: + batch_optimizer = BFGS( + obatch, alpha=70.0, maxstep=0.2, early_stop=True + ) + elif optimizer == "BFGSLineSearch": + batch_optimizer = BFGSLineSearch( + obatch, + device=self.device, + early_stop=True, + use_profiler=self.use_profiler, + profiler_log_dir=self.profiler_log_dir, + profiler_schedule_config=self.profiler_schedule_config, + ) + elif optimizer == "BFGSFusedLS": + if len(all_indices) > 0: + logging.info(f"Restarting with indices: {restart_indices}") + batch_optimizer.optimizable = obatch + else: + batch_optimizer = BFGSFusedLS( + obatch, + device=self.device, + early_stop=True, + use_profiler=self.use_profiler, + profiler_log_dir=self.profiler_log_dir, + profiler_schedule_config=self.profiler_schedule_config, + ) + else: + raise ValueError(f"Unknown optimizer: {optimizer}") + + # 动态计算剩余可用步数(基于当前批次最大已执行步数) + current_max_steps = max(cur_batch_steps) if cur_batch_steps else 0 + remaining_steps = max( + maxstep - current_max_steps, 1 + ) # 保证至少运行1步 + + # 执行优化并获取收敛的索引 + if (optimizer == "BFGSFusedLS" or optimizer == "BFGS") and len( + all_indices + ) > 0: + converge_indices = batch_optimizer.run( + fmax, + remaining_steps, + is_restart_earlystop=True, + restart_indices=restart_indices, + old_batch_indices=old_batch_indices, + ) + else: + converge_indices = batch_optimizer.run(fmax, remaining_steps) + + # Print energies of all structures + # logging.info(f"Final energies of all structures: {batch_optimizer.energies}") + energies_list = ( + batch_optimizer.optimizable.get_potential_energies().tolist() + ) + logging.info(f"Final energies of all structures: {energies_list}") + + # 更新所有结构的累计步数 + cur_batch_steps = [ + steps + batch_optimizer.nsteps for steps in cur_batch_steps + ] + + # 找出超过最大步数的结构索引 + over_maxstep_indices = [ + i + for i, steps in enumerate(cur_batch_steps) + if steps >= maxstep - 1 + ] + + # 合并收敛和超限的索引(去重) + all_indices = list(set(converge_indices + over_maxstep_indices)) + + # Get optimized atoms + optimized_atoms = obatch.get_atoms_list() + converged_atoms_count += len(all_indices) + + end_time = time.perf_counter() + # 处理所有需要退出的结构(包括收敛和超限) + for idx in all_indices: + runtime = end_time - cur_batch_times[idx] + + energy_per_mol = ( + energies_list[idx] + / ( + len(optimized_atoms[idx].get_atomic_numbers()) + / self.molecule_single + ) + * 96.485 + ) + density = self._get_density(optimized_atoms[idx]) + + # Save results + result_data = { + "file": cur_batch_path[idx].split("/")[-1].split(".")[0], + "steps": cur_batch_steps[idx], + "runtime": runtime, + "energy": energy_per_mol, + "density": density, + } + result.append(result_data) + + # Save optimized structure + # converged_atoms_path = os.path.join(result_path_prefix, cur_batch_path[idx].split('/')[-1].replace('.cif', '.traj')) + converged_atoms_path = os.path.join( + result_path_prefix, cur_batch_path[idx].split("/")[-1] + ) + optimized_atoms[idx].write(converged_atoms_path) + optimized_atoms_paths.append(converged_atoms_path) + + # write a json file to store reslt_data + os.makedirs(json_dir, exist_ok=True) + # json_path = os.path.join(json_dir, cur_batch_path[idx].split('/')[-1]+'.json') + json_path = os.path.join( + json_dir, + cur_batch_path[idx].split("/")[-1].replace(".cif", ".json"), + ) + with open(json_path, "w") as f: + json.dump(result_data, f) + + logging.info(f"cur_batch_path: {cur_batch_path}") + logging.info(f"cur_batch_steps: {cur_batch_steps}") + logging.info(f"all_indices: {all_indices}") + logging.info(f"length of optimized_atoms: {len(optimized_atoms)}") + + return result, optimized_atoms_paths diff --git a/mace-bench/src/batchopt/utils.py b/mace-bench/src/batchopt/utils.py index 7b899f4..ff70707 100644 --- a/mace-bench/src/batchopt/utils.py +++ b/mace-bench/src/batchopt/utils.py @@ -1,121 +1,121 @@ -""" -Copyright (c) Meta, Inc. and its affiliates. - -This source code is licensed under the MIT license found in the -LICENSE file in the root directory of this source tree. -""" - -import ast -import collections -import copy -import datetime -import errno -import functools -import importlib -import itertools -import json -import logging -import os -import subprocess -import sys -import time -from bisect import bisect -from contextlib import contextmanager -from dataclasses import dataclass -from functools import wraps -from itertools import product -from pathlib import Path -from typing import TYPE_CHECKING, Any -from uuid import uuid4 - -import numpy as np -import torch -import torch.nn as nn -import torch_geometric -import yaml -from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas -from matplotlib.figure import Figure -from torch_geometric.data import Data -from torch_geometric.utils import remove_self_loops -from torch_scatter import scatter, segment_coo, segment_csr - -from torch_geometric.data.data import BaseData -from torch_geometric.data import Batch - -# sort files by atomic number in descending order -def count_atoms_cif(file): - in_atom_site = False - natoms = 0 - with open(file, 'r') as f: - while line := f.readline(): - if line.lower().startswith("loop_"): - in_atom_site = False - continue - # if line.lower().startswith("_atom_site_"): - if "_atom_site_" in line.lower(): - in_atom_site = True - continue - if in_atom_site: - if line.startswith("_"): - in_atom_site = False - continue - elif line: - natoms += 1 - return natoms - -# Override the collation method in `pytorch_geometric.data.InMemoryDataset` -def collate(data_list): - keys = data_list[0].keys - data = data_list[0].__class__() - - for key in keys: - data[key] = [] - slices = {key: [0] for key in keys} - - for item, key in product(data_list, keys): - data[key].append(item[key]) - if torch.is_tensor(item[key]): - s = slices[key][-1] + item[key].size(item.__cat_dim__(key, item[key])) - elif isinstance(item[key], (int, float)): - s = slices[key][-1] + 1 - else: - raise ValueError("Unsupported attribute type") - slices[key].append(s) - - if hasattr(data_list[0], "__num_nodes__"): - data.__num_nodes__ = [] - for item in data_list: - data.__num_nodes__.append(item.num_nodes) - - for key in keys: - if torch.is_tensor(data_list[0][key]): - data[key] = torch.cat( - data[key], dim=data.__cat_dim__(key, data_list[0][key]) - ) - else: - data[key] = torch.tensor(data[key]) - slices[key] = torch.tensor(slices[key], dtype=torch.long) - - return data, slices - -def data_list_collater( - data_list: list[BaseData], otf_graph: bool = False, to_dict: bool = False -) -> BaseData | dict[str, torch.Tensor]: - batch = Batch.from_data_list(data_list) - - if not otf_graph: - try: - n_neighbors = [] - for _, data in enumerate(data_list): - n_index = data.edge_index[1, :] - n_neighbors.append(n_index.shape[0]) - batch.neighbors = torch.tensor(n_neighbors) - except (NotImplementedError, TypeError): - logging.warning( - "LMDB does not contain edge index information, set otf_graph=True" - ) - - if to_dict: - batch = dict(batch.items()) - +""" +Copyright (c) Meta, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +import ast +import collections +import copy +import datetime +import errno +import functools +import importlib +import itertools +import json +import logging +import os +import subprocess +import sys +import time +from bisect import bisect +from contextlib import contextmanager +from dataclasses import dataclass +from functools import wraps +from itertools import product +from pathlib import Path +from typing import TYPE_CHECKING, Any +from uuid import uuid4 + +import numpy as np +import torch +import torch.nn as nn +import torch_geometric +import yaml +from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas +from matplotlib.figure import Figure +from torch_geometric.data import Data +from torch_geometric.utils import remove_self_loops +from torch_scatter import scatter, segment_coo, segment_csr + +from torch_geometric.data.data import BaseData +from torch_geometric.data import Batch + +# sort files by atomic number in descending order +def count_atoms_cif(file): + in_atom_site = False + natoms = 0 + with open(file, 'r') as f: + while line := f.readline(): + if line.lower().startswith("loop_"): + in_atom_site = False + continue + # if line.lower().startswith("_atom_site_"): + if "_atom_site_" in line.lower(): + in_atom_site = True + continue + if in_atom_site: + if line.startswith("_"): + in_atom_site = False + continue + elif line: + natoms += 1 + return natoms + +# Override the collation method in `pytorch_geometric.data.InMemoryDataset` +def collate(data_list): + keys = data_list[0].keys + data = data_list[0].__class__() + + for key in keys: + data[key] = [] + slices = {key: [0] for key in keys} + + for item, key in product(data_list, keys): + data[key].append(item[key]) + if torch.is_tensor(item[key]): + s = slices[key][-1] + item[key].size(item.__cat_dim__(key, item[key])) + elif isinstance(item[key], (int, float)): + s = slices[key][-1] + 1 + else: + raise ValueError("Unsupported attribute type") + slices[key].append(s) + + if hasattr(data_list[0], "__num_nodes__"): + data.__num_nodes__ = [] + for item in data_list: + data.__num_nodes__.append(item.num_nodes) + + for key in keys: + if torch.is_tensor(data_list[0][key]): + data[key] = torch.cat( + data[key], dim=data.__cat_dim__(key, data_list[0][key]) + ) + else: + data[key] = torch.tensor(data[key]) + slices[key] = torch.tensor(slices[key], dtype=torch.long) + + return data, slices + +def data_list_collater( + data_list: list[BaseData], otf_graph: bool = False, to_dict: bool = False +) -> BaseData | dict[str, torch.Tensor]: + batch = Batch.from_data_list(data_list) + + if not otf_graph: + try: + n_neighbors = [] + for _, data in enumerate(data_list): + n_index = data.edge_index[1, :] + n_neighbors.append(n_index.shape[0]) + batch.neighbors = torch.tensor(n_neighbors) + except (NotImplementedError, TypeError): + logging.warning( + "LMDB does not contain edge index information, set otf_graph=True" + ) + + if to_dict: + batch = dict(batch.items()) + return batch \ No newline at end of file diff --git a/mace-bench/util/env.sh b/mace-bench/util/env.sh index 00b7806..9a47c6a 100644 --- a/mace-bench/util/env.sh +++ b/mace-bench/util/env.sh @@ -1,7 +1,6 @@ -#!/bin/bash - -export CUDA_HOME=/usr/local/cuda -export PATH=$CUDA_HOME/bin:$PATH -export LD_LIBRARY_PATH=$CUDA_HOME/lib64:$LD_LIBRARY_PATH -export LIBRARY_PATH=$CUDA_HOME/lib64:$LIBRARY_PATH + +export CUDA_HOME=/usr/local/cuda +export PATH=$CUDA_HOME/bin:$PATH +export LD_LIBRARY_PATH=$CUDA_HOME/lib64:$LD_LIBRARY_PATH +export LIBRARY_PATH=$CUDA_HOME/lib64:$LIBRARY_PATH export CPATH=$CUDA_HOME/include:$CPATH \ No newline at end of file diff --git a/mace-bench/util/mps_clean.sh b/mace-bench/util/mps_clean.sh index 2a1354b..15ef558 100644 --- a/mace-bench/util/mps_clean.sh +++ b/mace-bench/util/mps_clean.sh @@ -1,11 +1,10 @@ -#!/bin/bash - -echo quit | nvidia-cuda-mps-control -nvidia-smi -i 0 -c DEFAULT -nvidia-smi -i 1 -c DEFAULT -nvidia-smi -i 2 -c DEFAULT -nvidia-smi -i 3 -c DEFAULT -nvidia-smi -i 4 -c DEFAULT -nvidia-smi -i 5 -c DEFAULT -nvidia-smi -i 6 -c DEFAULT + +echo quit | nvidia-cuda-mps-control +nvidia-smi -i 0 -c DEFAULT +nvidia-smi -i 1 -c DEFAULT +nvidia-smi -i 2 -c DEFAULT +nvidia-smi -i 3 -c DEFAULT +nvidia-smi -i 4 -c DEFAULT +nvidia-smi -i 5 -c DEFAULT +nvidia-smi -i 6 -c DEFAULT nvidia-smi -i 7 -c DEFAULT \ No newline at end of file diff --git a/mace-bench/util/mps_start.sh b/mace-bench/util/mps_start.sh index 46b0966..9304a85 100644 --- a/mace-bench/util/mps_start.sh +++ b/mace-bench/util/mps_start.sh @@ -1,10 +1,10 @@ -#!/bin/bash -nvidia-smi -i 0 -c EXCLUSIVE_PROCESS # Set GPU 0 to exclusive mode. -nvidia-smi -i 1 -c EXCLUSIVE_PROCESS # Set GPU 1 to exclusive mode. -nvidia-smi -i 2 -c EXCLUSIVE_PROCESS # Set GPU 2 to exclusive mode. -nvidia-smi -i 3 -c EXCLUSIVE_PROCESS # Set GPU 3 to exclusive mode. -nvidia-smi -i 4 -c EXCLUSIVE_PROCESS # Set GPU 4 to exclusive mode. -nvidia-smi -i 5 -c EXCLUSIVE_PROCESS # Set GPU 5 to exclusive mode. -nvidia-smi -i 6 -c EXCLUSIVE_PROCESS # Set GPU 6 to exclusive mode. -nvidia-smi -i 7 -c EXCLUSIVE_PROCESS # Set GPU 7 to exclusive mode. + +nvidia-smi -i 0 -c EXCLUSIVE_PROCESS # Set GPU 0 to exclusive mode. +nvidia-smi -i 1 -c EXCLUSIVE_PROCESS # Set GPU 1 to exclusive mode. +nvidia-smi -i 2 -c EXCLUSIVE_PROCESS # Set GPU 2 to exclusive mode. +nvidia-smi -i 3 -c EXCLUSIVE_PROCESS # Set GPU 3 to exclusive mode. +nvidia-smi -i 4 -c EXCLUSIVE_PROCESS # Set GPU 4 to exclusive mode. +nvidia-smi -i 5 -c EXCLUSIVE_PROCESS # Set GPU 5 to exclusive mode. +nvidia-smi -i 6 -c EXCLUSIVE_PROCESS # Set GPU 6 to exclusive mode. +nvidia-smi -i 7 -c EXCLUSIVE_PROCESS # Set GPU 7 to exclusive mode. nvidia-cuda-mps-control -d # Start the daemon. \ No newline at end of file diff --git a/main.py b/main.py index 093d67c..187c4df 100644 --- a/main.py +++ b/main.py @@ -1,95 +1,104 @@ -from basic_function import format_parser -from basic_function import packaged_function -from basic_function import conformer_search -import time -import argparse -import os -import itertools - - -if __name__ == '__main__': - - time_start = time.time() - - # initiate configuration - ############################################################################################## - parser = argparse.ArgumentParser() - parser.add_argument('--path', type=str, default="./", help='Path to process') - parser.add_argument('--smiles', type=str, required=True, help='SMILES string of the molecules, split by . if multiple molecules are used') - parser.add_argument('--generate_conformers', type=int, default=20, help='Number of conformers to generate. When it is <=0, only load existing conformers to generate structures') - parser.add_argument('--use_conformers', type=int, default=4, help='Number of conformers used to generate structure. When it is <=0, no structure generation would be done') - parser.add_argument('--molecule_num_in_cell', type=str, nargs='+', default=['1'], help='number of molecules in a unit cell, split by comma for multiple molecules, and split by space for multiple packings') - parser.add_argument('--num_generation', type=int, nargs='+', default=[100], help='number of structures to generate, split by space for multiple packings') - parser.add_argument('--space_group_list', type=str, nargs='+', default=["2,14"], help='Space group list for structure generation, spilt by comma to add mutiple groups, split by space for multiple packings') - parser.add_argument('--add_name', type=str, nargs='+', default=["CRYSTAL"], help='Add name for the generated structures, split by space for multiple packings') - parser.add_argument('--max_workers', type=int, default=8, help='Maximum number of workers for parallel processing') - args = parser.parse_args() - - target_folder = args.path - smiles_list = args.smiles.split('.') - generate_conformers = args.generate_conformers - use_conformers = args.use_conformers - molecule_num_in_cell = [list(map(int, num.split(','))) for num in args.molecule_num_in_cell] - num_generation = args.num_generation - space_group_list = [list(map(int, group.split(','))) for group in args.space_group_list] - add_name = args.add_name - max_workers = args.max_workers - - num_molecules = len(smiles_list) - num_packings = max(len(molecule_num_in_cell), len(space_group_list)) - - for i in range(len(molecule_num_in_cell)): - if len(molecule_num_in_cell[i]) < num_molecules: - molecule_num_in_cell[i].extend([1] * (num_molecules - len(molecule_num_in_cell[i]))) - elif len(molecule_num_in_cell[i]) > num_molecules: - molecule_num_in_cell[i] = molecule_num_in_cell[i][:num_molecules] - - while len(molecule_num_in_cell) < num_packings: - molecule_num_in_cell.append(molecule_num_in_cell[-1]) - - while len(space_group_list) < num_packings: - space_group_list.append(space_group_list[-1]) - - while len(add_name) < num_packings: - add_name.append(add_name[-1]) - - while len(num_generation) < num_packings: - num_generation.append(num_generation[-1]) - - - # step1: conformer search - ############################################################################################## - molecule_data = [] - for i in range(num_molecules): - molecule_folder = os.path.join(target_folder, f"molecule_{i+1}") - molecule_data.append([]) - if generate_conformers > 0: - conformer_search.conformer_search(smiles_list[i], molecule_folder, num_conformers=generate_conformers, max_attempts=10000, rms_thresh=0.1) - with open(os.path.join(molecule_folder, "info.txt"), "w") as smiles_file: - smiles_file.write(f"SMILES: {smiles_list[i]}") - for j in range(use_conformers): - temp_path = os.path.join(molecule_folder, "conformers", f"conformer_{j}.xyz") - if not os.path.exists(temp_path): - break - molecule_data[i].append(format_parser.read_xyz_file(temp_path)) - if len(molecule_data[i]) <= 0: - print(f"No conformer loaded for molecule_{i+1}. Check configurations!") - break - - idx_data = [list(range(len(item))) for item in molecule_data] - combinations = list(itertools.product(*idx_data)) - - - # step2: structure generation - ############################################################################################## - for i in range(num_packings): - for combination in combinations: - molecule_list = [] - for j in range(num_molecules): - for cnt in range(molecule_num_in_cell[i][j]): - molecule_list.append(molecule_data[j][combination[j]]) - c_name = "".join(map(str, combination)) - packaged_function.CSP_generater_parallel(molecule_list, target_folder, need_structure=num_generation[i], space_group_list=space_group_list[i],add_name=f"{add_name[i]}_C{c_name}", max_workers=max_workers,start_seed=1) - - time_end=time.time() - print('time cost',time_end-time_start,'s') +from basic_function import format_parser +from basic_function import packaged_function +from basic_function import conformer_search +import time +import argparse +import os +import itertools + + +if __name__ == '__main__': + + time_start = time.time() + + # initiate configuration + ############################################################################################## + parser = argparse.ArgumentParser() + parser.add_argument('--path', type=str, default="./", help='Path to process') + parser.add_argument('--smiles', type=str, required=True, help='SMILES string of the molecules, split by . if multiple molecules are used') + parser.add_argument('--generate_conformers', type=int, default=20, help='Number of conformers to generate. When it is <=0, only load existing conformers to generate structures') + parser.add_argument('--use_conformers', type=int, default=4, help='Number of conformers used to generate structure. When it is <=0, no structure generation would be done') + parser.add_argument('--molecule_num_in_cell', type=str, nargs='+', default=['1'], help='number of molecules in a unit cell, split by comma for multiple molecules, and split by space for multiple packings') + parser.add_argument('--num_generation', type=int, nargs='+', default=[100], help='number of structures to generate, split by space for multiple packings') + parser.add_argument('--space_group_list', type=str, nargs='+', default=["2,14"], help='Space group list for structure generation, spilt by comma to add mutiple groups, split by space for multiple packings') + parser.add_argument('--add_name', type=str, nargs='+', default=["CRYSTAL"], help='Add name for the generated structures, split by space for multiple packings') + parser.add_argument('--max_workers', type=int, default=8, help='Maximum number of workers for parallel processing') + parser.add_argument('--mode', type=str, default=8, choices=['all', 'conformer_only', 'structure_only'], help='choose the jobs to do') + args = parser.parse_args() + + target_folder = args.path + smiles_list = args.smiles.split('.') + generate_conformers = args.generate_conformers + use_conformers = args.use_conformers + molecule_num_in_cell = [list(map(int, num.split(','))) for num in args.molecule_num_in_cell] + num_generation = args.num_generation + space_group_list = [list(map(int, group.split(','))) for group in args.space_group_list] + add_name = args.add_name + max_workers = args.max_workers + mode = args.mode + + num_molecules = len(smiles_list) + num_packings = max(len(molecule_num_in_cell), len(space_group_list)) + + for i in range(len(molecule_num_in_cell)): + if len(molecule_num_in_cell[i]) < num_molecules: + molecule_num_in_cell[i].extend([1] * (num_molecules - len(molecule_num_in_cell[i]))) + elif len(molecule_num_in_cell[i]) > num_molecules: + molecule_num_in_cell[i] = molecule_num_in_cell[i][:num_molecules] + + while len(molecule_num_in_cell) < num_packings: + molecule_num_in_cell.append(molecule_num_in_cell[-1]) + + while len(space_group_list) < num_packings: + space_group_list.append(space_group_list[-1]) + + while len(add_name) < num_packings: + add_name.append(add_name[-1]) + + while len(num_generation) < num_packings: + num_generation.append(num_generation[-1]) + + + # step1: conformer search + ############################################################################################## + molecule_data = [] + for i in range(num_molecules): + molecule_folder = os.path.join(target_folder, f"molecule_{i+1}") + molecule_data.append([]) + if generate_conformers > 0 and mode != "structure_only": + conformer_search.conformer_search(smiles_list[i], molecule_folder, num_conformers=generate_conformers, max_attempts=10000, rms_thresh=0.1) + with open(os.path.join(molecule_folder, "info.txt"), "w") as smiles_file: + smiles_file.write(f"SMILES: {smiles_list[i]}") + file_num = len(os.listdir(os.path.join(molecule_folder, "conformers"))) + cnt = 0 + for j in range(file_num): + if cnt >= use_conformers: + break + temp_path = os.path.join(molecule_folder, "conformers", f"conformer_{j}.xyz") + if not os.path.exists(temp_path): + break + molecule_data[i].append(format_parser.read_xyz_file(temp_path)) + cnt += 1 + + if len(molecule_data[i]) <= 0: + print(f"No conformer loaded for molecule_{i+1}. Check configurations!") + break + + idx_data = [list(range(len(item))) for item in molecule_data] + combinations = list(itertools.product(*idx_data)) + + + # step2: structure generation + ############################################################################################## + if mode != "conformer_only": + for i in range(num_packings): + for combination in combinations: + molecule_list = [] + for j in range(num_molecules): + for cnt in range(molecule_num_in_cell[i][j]): + molecule_list.append(molecule_data[j][combination[j]]) + c_name = "".join(map(str, combination)) + packaged_function.CSP_generater_parallel(molecule_list, target_folder, need_structure=num_generation[i], space_group_list=space_group_list[i],add_name=f"{add_name[i]}_C{c_name}", max_workers=max_workers,start_seed=1) + + time_end=time.time() + print('time cost',time_end-time_start,'s') diff --git a/post_process/check_match.py b/post_process/check_match.py index 5f27d30..a4b8953 100644 --- a/post_process/check_match.py +++ b/post_process/check_match.py @@ -1,154 +1,154 @@ -import warnings -warnings.filterwarnings("ignore", category=DeprecationWarning) -warnings.filterwarnings("ignore", category=UserWarning) - -from ccdc.crystal import PackingSimilarity -from ccdc.io import CrystalReader -import glob -import os -import sys -import random -import pandas as pd -from multiprocessing import Pool, cpu_count, TimeoutError as mpTimeoutError # import TimeoutError -import argparse - -# --- Global Configuration --- -REPORT_TARGET = 15 -LARGE_CONFORMER_DIFF = True - -# --- Worker Process Initializer --- -def init_worker(ref_path, engine_settings): - """ - Initializes a worker process. - This function loads the reference structure and creates the similarity engine - once per process, storing them in global variables for that process. - """ - # print(f"Worker process {os.getpid()} initializing...") - # sys.stdout.flush() - global worker_ref_crystal - global worker_similarity_engine - worker_ref_crystal = CrystalReader(ref_path)[0] - worker_similarity_engine = PackingSimilarity() - worker_similarity_engine.settings.allow_molecular_differences = engine_settings['allow_molecular_differences'] - worker_similarity_engine.settings.distance_tolerance = engine_settings['distance_tolerance'] - worker_similarity_engine.settings.angle_tolerance = engine_settings['angle_tolerance'] - worker_similarity_engine.settings.packing_shell_size = engine_settings['packing_shell_size'] - worker_similarity_engine.settings.ignore_hydrogen_positions = engine_settings['ignore_hydrogen_positions'] - worker_similarity_engine.settings.ignore_bond_counts = engine_settings['ignore_bond_counts'] - worker_similarity_engine.settings.ignore_hydrogen_counts = engine_settings['ignore_hydrogen_counts'] - -# --- Single Task Processing Function --- -def process_single_cif(csp_file_path): - """ - Compares a single candidate structure against the reference structure - loaded in the worker's global scope. - Returns a tuple indicating the result type ('matched' or 'failed') and the file path. - """ - global worker_ref_crystal - global worker_similarity_engine - try: - try_structure = CrystalReader(csp_file_path)[0] - h = worker_similarity_engine.compare(try_structure, worker_ref_crystal) - if h.nmatched_molecules >= REPORT_TARGET: - print(f"MATCH: {os.path.basename(csp_file_path)} | Matched Molecules: {h.nmatched_molecules}, RMSD: {h.rmsd:.3f}") - sys.stdout.flush() - return ("matched", csp_file_path) - except Exception as e: - if not LARGE_CONFORMER_DIFF: - print(f"FAIL: {os.path.basename(csp_file_path)} | Reason: {e}") - sys.stdout.flush() - return ("failed", csp_file_path) - return None - -# --- Main Execution Block --- -if __name__ == '__main__': - - parser = argparse.ArgumentParser() - parser.add_argument('--path', type=str, default="./", help='Path to process') - parser.add_argument('--ref_path', type=str, default="../refs", help='Path to find reference structrues') - parser.add_argument('--workers', type=int, default=80, help='Max worker number limit') - parser.add_argument('--timeout', type=int, default=20, help='Timeout for each task in seconds') - - args = parser.parse_args() - base_path = args.path - refs_dir = args.ref_path - PROCESS_NUM = min(args.workers, cpu_count()) # Use the specified number of workers or the max available - TIMEOUT_SECONDS = args.timeout # Set the timeout for each task - all_results = [] - - print(f"Starting checking match using up to {PROCESS_NUM} processes with a {TIMEOUT_SECONDS}s timeout per task...") - - folders_to_process = [] - csp_dir = os.path.join(base_path, "cif_result_final") - if os.path.exists(csp_dir) and os.path.exists(refs_dir): - folders_to_process.append((csp_dir, refs_dir)) - - for csp_dir, refs_dir in folders_to_process: - for ref_filename in os.listdir(refs_dir): - if not ref_filename.endswith(".cif"): - continue - - ref_full_path = os.path.join(refs_dir, ref_filename) - print(f"\n--- Processing Reference File: {ref_full_path} ---") - - csp_files = glob.glob(os.path.join(csp_dir, '*.cif')) - random.shuffle(csp_files) - - if not csp_files: - print("No candidate .cif files found, skipping.") - continue - - engine_settings = { - 'allow_molecular_differences': False, - 'distance_tolerance': 0.2, - 'angle_tolerance': 20, - 'packing_shell_size': 15, - 'ignore_hydrogen_positions': True, - 'ignore_bond_counts': True, - 'ignore_hydrogen_counts': True - } - - with Pool(processes=PROCESS_NUM, initializer=init_worker, initargs=(ref_full_path, engine_settings)) as pool: - - async_results = [] - for f in csp_files: - res = pool.apply_async(process_single_cif, args=(f,)) - async_results.append(res) - - results_list = [] - for i, res_obj in enumerate(async_results): - try: - result = res_obj.get(timeout=TIMEOUT_SECONDS) - results_list.append(result) - except mpTimeoutError: - timed_out_file = csp_files[i] - print(f"TIMEOUT: {timed_out_file} | Task exceeded {TIMEOUT_SECONDS}s limit.") - sys.stdout.flush() - results_list.append(("failed", timed_out_file)) - - matched_structures = [] - failed_structures = [] - for res in results_list: - if res: - status, path = res - if status == "matched": - matched_structures.append(os.path.basename(path)) - elif status == "failed": - failed_structures.append(os.path.basename(path)) - - all_results.append({ - "ref_name": ref_filename, - "matched_count": len(matched_structures), - "matched_structures": ";".join(matched_structures), - "failed_count": len(failed_structures), - "failed_structures": ";".join(failed_structures) - }) - print(f"--- Finished {ref_filename}. Matched: {len(matched_structures)}, Failed: {len(failed_structures)} ---") - - if all_results: - df = pd.DataFrame(all_results) - output_filename = "match_results.csv" - df.to_csv(output_filename, index=False) - print(f"\nAll processing finished. Results saved to {output_filename}") - else: +import warnings +warnings.filterwarnings("ignore", category=DeprecationWarning) +warnings.filterwarnings("ignore", category=UserWarning) + +from ccdc.crystal import PackingSimilarity +from ccdc.io import CrystalReader +import glob +import os +import sys +import random +import pandas as pd +from multiprocessing import Pool, cpu_count, TimeoutError as mpTimeoutError # import TimeoutError +import argparse + +# --- Global Configuration --- +REPORT_TARGET = 15 +LARGE_CONFORMER_DIFF = True + +# --- Worker Process Initializer --- +def init_worker(ref_path, engine_settings): + """ + Initializes a worker process. + This function loads the reference structure and creates the similarity engine + once per process, storing them in global variables for that process. + """ + # print(f"Worker process {os.getpid()} initializing...") + # sys.stdout.flush() + global worker_ref_crystal + global worker_similarity_engine + worker_ref_crystal = CrystalReader(ref_path)[0] + worker_similarity_engine = PackingSimilarity() + worker_similarity_engine.settings.allow_molecular_differences = engine_settings['allow_molecular_differences'] + worker_similarity_engine.settings.distance_tolerance = engine_settings['distance_tolerance'] + worker_similarity_engine.settings.angle_tolerance = engine_settings['angle_tolerance'] + worker_similarity_engine.settings.packing_shell_size = engine_settings['packing_shell_size'] + worker_similarity_engine.settings.ignore_hydrogen_positions = engine_settings['ignore_hydrogen_positions'] + worker_similarity_engine.settings.ignore_bond_counts = engine_settings['ignore_bond_counts'] + worker_similarity_engine.settings.ignore_hydrogen_counts = engine_settings['ignore_hydrogen_counts'] + +# --- Single Task Processing Function --- +def process_single_cif(csp_file_path): + """ + Compares a single candidate structure against the reference structure + loaded in the worker's global scope. + Returns a tuple indicating the result type ('matched' or 'failed') and the file path. + """ + global worker_ref_crystal + global worker_similarity_engine + try: + try_structure = CrystalReader(csp_file_path)[0] + h = worker_similarity_engine.compare(try_structure, worker_ref_crystal) + if h.nmatched_molecules >= REPORT_TARGET: + print(f"MATCH: {os.path.basename(csp_file_path)} | Matched Molecules: {h.nmatched_molecules}, RMSD: {h.rmsd:.3f}") + sys.stdout.flush() + return ("matched", csp_file_path) + except Exception as e: + if not LARGE_CONFORMER_DIFF: + print(f"FAIL: {os.path.basename(csp_file_path)} | Reason: {e}") + sys.stdout.flush() + return ("failed", csp_file_path) + return None + +# --- Main Execution Block --- +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + parser.add_argument('--path', type=str, default="./", help='Path to process') + parser.add_argument('--ref_path', type=str, default="../refs", help='Path to find reference structrues') + parser.add_argument('--workers', type=int, default=80, help='Max worker number limit') + parser.add_argument('--timeout', type=int, default=20, help='Timeout for each task in seconds') + + args = parser.parse_args() + base_path = args.path + refs_dir = args.ref_path + PROCESS_NUM = min(args.workers, cpu_count()) # Use the specified number of workers or the max available + TIMEOUT_SECONDS = args.timeout # Set the timeout for each task + all_results = [] + + print(f"Starting checking match using up to {PROCESS_NUM} processes with a {TIMEOUT_SECONDS}s timeout per task...") + + folders_to_process = [] + csp_dir = os.path.join(base_path, "cif_result_final") + if os.path.exists(csp_dir) and os.path.exists(refs_dir): + folders_to_process.append((csp_dir, refs_dir)) + + for csp_dir, refs_dir in folders_to_process: + for ref_filename in os.listdir(refs_dir): + if not ref_filename.endswith(".cif"): + continue + + ref_full_path = os.path.join(refs_dir, ref_filename) + print(f"\n--- Processing Reference File: {ref_full_path} ---") + + csp_files = glob.glob(os.path.join(csp_dir, '*.cif')) + random.shuffle(csp_files) + + if not csp_files: + print("No candidate .cif files found, skipping.") + continue + + engine_settings = { + 'allow_molecular_differences': False, + 'distance_tolerance': 0.2, + 'angle_tolerance': 20, + 'packing_shell_size': 15, + 'ignore_hydrogen_positions': True, + 'ignore_bond_counts': True, + 'ignore_hydrogen_counts': True + } + + with Pool(processes=PROCESS_NUM, initializer=init_worker, initargs=(ref_full_path, engine_settings)) as pool: + + async_results = [] + for f in csp_files: + res = pool.apply_async(process_single_cif, args=(f,)) + async_results.append(res) + + results_list = [] + for i, res_obj in enumerate(async_results): + try: + result = res_obj.get(timeout=TIMEOUT_SECONDS) + results_list.append(result) + except mpTimeoutError: + timed_out_file = csp_files[i] + print(f"TIMEOUT: {timed_out_file} | Task exceeded {TIMEOUT_SECONDS}s limit.") + sys.stdout.flush() + results_list.append(("failed", timed_out_file)) + + matched_structures = [] + failed_structures = [] + for res in results_list: + if res: + status, path = res + if status == "matched": + matched_structures.append(os.path.basename(path)) + elif status == "failed": + failed_structures.append(os.path.basename(path)) + + all_results.append({ + "ref_name": ref_filename, + "matched_count": len(matched_structures), + "matched_structures": ";".join(matched_structures), + "failed_count": len(failed_structures), + "failed_structures": ";".join(failed_structures) + }) + print(f"--- Finished {ref_filename}. Matched: {len(matched_structures)}, Failed: {len(failed_structures)} ---") + + if all_results: + df = pd.DataFrame(all_results) + output_filename = "match_results.csv" + df.to_csv(output_filename, index=False) + print(f"\nAll processing finished. Results saved to {output_filename}") + else: print("\nNo valid data processed. No output file generated.") \ No newline at end of file diff --git a/post_process/clean_table.py b/post_process/clean_table.py index d3cfa3e..2099ecc 100644 --- a/post_process/clean_table.py +++ b/post_process/clean_table.py @@ -1,49 +1,49 @@ -import os -import sys -import pandas as pd - - -target_folder = os.getcwd() - - -if __name__ == '__main__': - print(f"Cleaning table in folder: {target_folder}") - error_folder = os.path.join(target_folder, "error_result") - if not os.path.exists(error_folder): - os.makedirs(error_folder) - cif_folder = os.path.join(target_folder, "cif_result_final") - csv_file = os.path.join(target_folder, "results_scheduler.csv") - - if not os.path.exists(csv_file): - print(f"CSV file not found in {target_folder}, skipping...") - sys.exit(0) - if not os.path.exists(cif_folder): - print(f"CIF folder not found in {target_folder}, skipping...") - sys.exit(0) - - all_crystals = pd.read_csv(csv_file) - cifs = os.listdir(cif_folder) - invalid_list = [] - for i, item in all_crystals.iterrows(): - if int(item['stage2_steps']) > 2990 or (item['stage1_steps'] >= 2999 and item["stage2_steps"] == 0) or abs(float(item['stage2_energy'])) > 1e15 or item['file']+".cif" not in cifs: - invalid_list.append(i) - - if invalid_list: - invalid_df = all_crystals.iloc[invalid_list] - all_crystals.drop(index=invalid_list, inplace=True) - - all_crystals.sort_values(by=["stage2_energy"], ascending=True, inplace=True) - all_crystals.reset_index(drop=True, inplace=True) - - min_energy = all_crystals['stage2_energy'][0] - all_crystals['relative_energy'] = all_crystals['stage2_energy'] - min_energy - - for cif in cifs: - if cif[:-4] not in all_crystals['file'].values: - # move cif to error folder - src_path = os.path.join(cif_folder, cif) - dest_path = os.path.join(error_folder, cif) - os.rename(src_path, dest_path) - - # Save the cleaned DataFrame back to CSV +import os +import sys +import pandas as pd + + +target_folder = os.getcwd() + + +if __name__ == '__main__': + print(f"Cleaning table in folder: {target_folder}") + error_folder = os.path.join(target_folder, "error_result") + if not os.path.exists(error_folder): + os.makedirs(error_folder) + cif_folder = os.path.join(target_folder, "cif_result_final") + csv_file = os.path.join(target_folder, "results_scheduler.csv") + + if not os.path.exists(csv_file): + print(f"CSV file not found in {target_folder}, skipping...") + sys.exit(0) + if not os.path.exists(cif_folder): + print(f"CIF folder not found in {target_folder}, skipping...") + sys.exit(0) + + all_crystals = pd.read_csv(csv_file) + cifs = os.listdir(cif_folder) + invalid_list = [] + for i, item in all_crystals.iterrows(): + if int(item['stage2_steps']) > 2990 or (item['stage1_steps'] >= 2999 and item["stage2_steps"] == 0) or abs(float(item['stage2_energy'])) > 1e15 or item['file']+".cif" not in cifs: + invalid_list.append(i) + + if invalid_list: + invalid_df = all_crystals.iloc[invalid_list] + all_crystals.drop(index=invalid_list, inplace=True) + + all_crystals.sort_values(by=["stage2_energy"], ascending=True, inplace=True) + all_crystals.reset_index(drop=True, inplace=True) + + min_energy = all_crystals['stage2_energy'][0] + all_crystals['relative_energy'] = all_crystals['stage2_energy'] - min_energy + + for cif in cifs: + if cif[:-4] not in all_crystals['file'].values: + # move cif to error folder + src_path = os.path.join(cif_folder, cif) + dest_path = os.path.join(error_folder, cif) + os.rename(src_path, dest_path) + + # Save the cleaned DataFrame back to CSV all_crystals.to_csv(csv_file, index=False) \ No newline at end of file diff --git a/post_process/duplicate_remove.py b/post_process/duplicate_remove.py index a8da373..a594adb 100644 --- a/post_process/duplicate_remove.py +++ b/post_process/duplicate_remove.py @@ -1,247 +1,247 @@ -import sys -import os -import glob -import pandas as pd -import warnings -from pathlib import Path -from tqdm import tqdm -import multiprocessing as mp -from ccdc.crystal import PackingSimilarity -from ccdc.io import CrystalReader -import shutil -import argparse - -warnings.filterwarnings("ignore", category=DeprecationWarning) - -######################### -# Global Settings -######################### - -# The number of matching molecules required in a packing shell to consider two crystals similar. -molecule_shell_size = 15 - -# Initialize the packing similarity engine from the CCDC API. -similarity_engine = PackingSimilarity() - -# Configure the similarity engine settings. -similarity_engine.settings.allow_molecular_differences = False -similarity_engine.settings.distance_tolerance = 0.2 -similarity_engine.settings.angle_tolerance = 20 -similarity_engine.settings.packing_shell_size = molecule_shell_size - -# Settings to make the comparison less strict regarding hydrogen atoms and bond counts. -similarity_engine.settings.ignore_hydrogen_positions = True -similarity_engine.settings.ignore_bond_counts = True -similarity_engine.settings.ignore_hydrogen_counts = True - -# Pre-filtering thresholds to avoid expensive comparisons for vastly different structures. -ENERGY_DIFF = 5.0 -DENSITY_DIFF = 0.2 - - -def compare_sim_pack_pair(args): - """ - Compares two crystal structures to determine if they are identical based on packing similarity. - Accepts a tuple (candidate_fp, ufp) containing file paths to the two structures. - Returns True if the structures are considered the same, False otherwise. - """ - cand_fp, ufp = args - try: - c1 = CrystalReader(cand_fp)[0] - c2 = CrystalReader(ufp)[0] - h = similarity_engine.compare(c1, c2) - # Structures are deemed identical if the number of matched molecules meets the global shell size. - return h.nmatched_molecules >= molecule_shell_size - except Exception: - # Return False if any error occurs during file reading or comparison. - return False - -def find_top_unique(struct_list, target_count, n_workers): - """ - Identifies a target number of unique structures from a list, sorted by energy. - It iterates through structures from lowest to highest energy and uses parallel processing - to compare each candidate against the list of already confirmed unique structures. - - Args: - struct_list (list): A list of tuples, where each tuple is (density, energy, filepath). - target_count (int): The desired number of unique structures to find. - n_workers (int): The number of worker processes for parallel comparison. - - Returns: - tuple: A tuple containing: - - A list of file paths for the unique structures found. - - A dictionary mapping each unique structure's path to a list of its duplicate paths. - """ - unique_list = [] - - # A dictionary to store each unique structure and its corresponding duplicates. - # Format: {unique_fp: [dup_fp1, dup_fp2, ...]} - duplicate_map = {} - - # Initialize the multiprocessing pool. - pool = mp.Pool(processes=n_workers) - try: - # Iterate through the main structure list, which is pre-sorted by energy. - for dens, ene, cand_fp in struct_list: - # Stop if the target count of unique structures has been reached. - if len(unique_list) >= target_count: - break - - # Create comparison tasks only for unique structures within the density and energy thresholds. - tasks = [(cand_fp, ufp) for dens2, ene2, ufp in unique_list if abs(dens2 - dens) < DENSITY_DIFF and abs(ene2 - ene) < ENERGY_DIFF] - is_dup = False - - if tasks: - # Use tqdm for a progress bar during parallel comparison. - results_iterator = tqdm(pool.imap(compare_sim_pack_pair, tasks), - total=len(tasks), - desc=f"Comparing {os.path.basename(cand_fp)}", - leave=False) - # Convert iterator to a list to process results. - results_list = list(results_iterator) - - # Pair the boolean results with the original tasks to identify which comparison was successful. - for result, task in zip(results_list, tasks): - if result: - is_dup = True - # The second element of the task tuple is the path of the matched unique structure. - matched_unique_fp = task[1] - - # Record the current candidate as a duplicate of the matched unique structure. - duplicate_map[matched_unique_fp].append(cand_fp) - - # Once a duplicate is found, no more comparisons are needed for this candidate. - break - - if not is_dup: - # If the candidate is not a duplicate of any existing unique structure, add it to the list. - unique_list.append((dens, ene, cand_fp)) - - # Create a new entry in the duplicate map for this new unique structure. - duplicate_map[cand_fp] = [] - - # Print real-time progress. - print(f">>> Unique count: {len(unique_list)} (added {os.path.basename(cand_fp)})") - - finally: - # Ensure the multiprocessing pool is properly closed. - pool.close() - pool.join() - - # Return the list of unique file paths and the map of duplicates. - return [fp for _, _, fp in unique_list], duplicate_map - - -def process_folder(folder_path): - """ - Main workflow function to process a given folder. It reads structural data, - finds unique structures, and organizes the results into folders and a CSV file. - """ - base_path = folder_path - - # Define output directories for unique structures and their duplicates. - t_folder = os.path.join(base_path, f"unique_{TARGET_UNIQUE_COUNT}") - d_folder = os.path.join(base_path, f"unique_{TARGET_UNIQUE_COUNT}_duplicates") - - # Load the summary data from the CSV file. - df = pd.read_csv(os.path.join(base_path, "results_scheduler.csv")) - cif_folder = os.path.join(base_path, "cif_result_final") - ######################### - - # Create a map from a simplified filename to its full file path for quick lookup. - file_paths = glob.glob(os.path.join(cif_folder, "*.cif")) - file_map = {} - for fp in file_paths: - base = Path(fp).stem - # Standardize the name by removing "_opt" suffix if it exists. - if base.endswith("_opt"): - key = base[:-4] - else: - key = base - file_map[key] = fp - - # Build the list of candidate structures, filtering by the energy threshold. - struct_list = [] - for _, row in df.iterrows(): - name = row["file"] - dens = row["stage2_density"] - ene = row["relative_energy"] - if ene > ENERGY_THRESHOLD: - continue - fp = file_map.get(name) - if fp: - struct_list.append((dens, ene, fp)) - else: - print(f"Warning: no CIF for {name}") - - # Sort the list of structures by energy in ascending order. - struct_list.sort(key=lambda x: x[1]) - - # Call the main function to find unique structures and the duplicate map. - unique_paths, duplicate_map = find_top_unique( - struct_list, - target_count=TARGET_UNIQUE_COUNT, - n_workers=WORKER_COUNT - ) - - # Prepare output directories, removing them first if they already exist. - if os.path.exists(t_folder): - shutil.rmtree(t_folder) - os.makedirs(t_folder) - - if os.path.exists(d_folder): - shutil.rmtree(d_folder) - os.makedirs(d_folder) - - print(f"\nFound {len(unique_paths)} unique structures " - f"(energy <= {ENERGY_THRESHOLD}, target {TARGET_UNIQUE_COUNT}).") - - # Copy the unique structure files to the target folder. - for p in unique_paths: - shutil.copy(p, t_folder) - - # Helper function to get a clean filename without the "_opt" suffix. - get_clean_name = lambda p: Path(p).stem.replace('_opt', '') - - # Create a new DataFrame containing only the data for the unique structures. - unique_names = [get_clean_name(p) for p in unique_paths] - unique_df = df[df['file'].isin(unique_names)].copy() - unique_df['duplicates'] = '' # Add a new column to store duplicate names. - - # Populate the 'duplicates' column and copy duplicate files to their folder. - for unique_fp, duplicates_list in duplicate_map.items(): - unique_name = get_clean_name(unique_fp) - if unique_name in unique_df['file'].values: - du_name = [get_clean_name(p) for p in duplicates_list] - # Add the list of duplicate names as a comma-separated string. - unique_df.loc[unique_df['file'] == unique_name, 'duplicates'] = ', '.join(du_name) - # Copy duplicate files to the duplicates folder. - for p in duplicates_list: - shutil.copy(p, d_folder) - - # Save the final DataFrame with unique structures and their duplicates to a new CSV file. - unique_csv_path = os.path.join(base_path, 'unique_structures.csv') - unique_df.to_csv(unique_csv_path, index=False) - -if __name__ == '__main__': - # Required for freezing the application when creating executables with multiprocessing. - mp.freeze_support() - - # Set up command-line argument parsing. - parser = argparse.ArgumentParser() - parser.add_argument('--path', type=str, default="./", help='Path to process') - parser.add_argument('--energy', type=float, default=30, help='Energy threshold for filtering structures') - parser.add_argument('--count', type=int, default=200, help='Target number of unique structures to find') - parser.add_argument('--workers', type=int, default=80, help='Max worker number limit') - args = parser.parse_args() - - # Set global variables from command-line arguments. - target_folder = args.path - ENERGY_THRESHOLD = args.energy - TARGET_UNIQUE_COUNT = args.count - WORKER_COUNT = args.workers - - # Start the processing if the specified folder exists. - if os.path.exists(target_folder): - print(f"Processing folder: {target_folder}") +import sys +import os +import glob +import pandas as pd +import warnings +from pathlib import Path +from tqdm import tqdm +import multiprocessing as mp +from ccdc.crystal import PackingSimilarity +from ccdc.io import CrystalReader +import shutil +import argparse + +warnings.filterwarnings("ignore", category=DeprecationWarning) + +######################### +# Global Settings +######################### + +# The number of matching molecules required in a packing shell to consider two crystals similar. +molecule_shell_size = 15 + +# Initialize the packing similarity engine from the CCDC API. +similarity_engine = PackingSimilarity() + +# Configure the similarity engine settings. +similarity_engine.settings.allow_molecular_differences = False +similarity_engine.settings.distance_tolerance = 0.2 +similarity_engine.settings.angle_tolerance = 20 +similarity_engine.settings.packing_shell_size = molecule_shell_size + +# Settings to make the comparison less strict regarding hydrogen atoms and bond counts. +similarity_engine.settings.ignore_hydrogen_positions = True +similarity_engine.settings.ignore_bond_counts = True +similarity_engine.settings.ignore_hydrogen_counts = True + +# Pre-filtering thresholds to avoid expensive comparisons for vastly different structures. +ENERGY_DIFF = 5.0 +DENSITY_DIFF = 0.2 + + +def compare_sim_pack_pair(args): + """ + Compares two crystal structures to determine if they are identical based on packing similarity. + Accepts a tuple (candidate_fp, ufp) containing file paths to the two structures. + Returns True if the structures are considered the same, False otherwise. + """ + cand_fp, ufp = args + try: + c1 = CrystalReader(cand_fp)[0] + c2 = CrystalReader(ufp)[0] + h = similarity_engine.compare(c1, c2) + # Structures are deemed identical if the number of matched molecules meets the global shell size. + return h.nmatched_molecules >= molecule_shell_size + except Exception: + # Return False if any error occurs during file reading or comparison. + return False + +def find_top_unique(struct_list, target_count, n_workers): + """ + Identifies a target number of unique structures from a list, sorted by energy. + It iterates through structures from lowest to highest energy and uses parallel processing + to compare each candidate against the list of already confirmed unique structures. + + Args: + struct_list (list): A list of tuples, where each tuple is (density, energy, filepath). + target_count (int): The desired number of unique structures to find. + n_workers (int): The number of worker processes for parallel comparison. + + Returns: + tuple: A tuple containing: + - A list of file paths for the unique structures found. + - A dictionary mapping each unique structure's path to a list of its duplicate paths. + """ + unique_list = [] + + # A dictionary to store each unique structure and its corresponding duplicates. + # Format: {unique_fp: [dup_fp1, dup_fp2, ...]} + duplicate_map = {} + + # Initialize the multiprocessing pool. + pool = mp.Pool(processes=n_workers) + try: + # Iterate through the main structure list, which is pre-sorted by energy. + for dens, ene, cand_fp in struct_list: + # Stop if the target count of unique structures has been reached. + if len(unique_list) >= target_count: + break + + # Create comparison tasks only for unique structures within the density and energy thresholds. + tasks = [(cand_fp, ufp) for dens2, ene2, ufp in unique_list if abs(dens2 - dens) < DENSITY_DIFF and abs(ene2 - ene) < ENERGY_DIFF] + is_dup = False + + if tasks: + # Use tqdm for a progress bar during parallel comparison. + results_iterator = tqdm(pool.imap(compare_sim_pack_pair, tasks), + total=len(tasks), + desc=f"Comparing {os.path.basename(cand_fp)}", + leave=False) + # Convert iterator to a list to process results. + results_list = list(results_iterator) + + # Pair the boolean results with the original tasks to identify which comparison was successful. + for result, task in zip(results_list, tasks): + if result: + is_dup = True + # The second element of the task tuple is the path of the matched unique structure. + matched_unique_fp = task[1] + + # Record the current candidate as a duplicate of the matched unique structure. + duplicate_map[matched_unique_fp].append(cand_fp) + + # Once a duplicate is found, no more comparisons are needed for this candidate. + break + + if not is_dup: + # If the candidate is not a duplicate of any existing unique structure, add it to the list. + unique_list.append((dens, ene, cand_fp)) + + # Create a new entry in the duplicate map for this new unique structure. + duplicate_map[cand_fp] = [] + + # Print real-time progress. + print(f">>> Unique count: {len(unique_list)} (added {os.path.basename(cand_fp)})") + + finally: + # Ensure the multiprocessing pool is properly closed. + pool.close() + pool.join() + + # Return the list of unique file paths and the map of duplicates. + return [fp for _, _, fp in unique_list], duplicate_map + + +def process_folder(folder_path): + """ + Main workflow function to process a given folder. It reads structural data, + finds unique structures, and organizes the results into folders and a CSV file. + """ + base_path = folder_path + + # Define output directories for unique structures and their duplicates. + t_folder = os.path.join(base_path, f"unique_{TARGET_UNIQUE_COUNT}") + d_folder = os.path.join(base_path, f"unique_{TARGET_UNIQUE_COUNT}_duplicates") + + # Load the summary data from the CSV file. + df = pd.read_csv(os.path.join(base_path, "results_scheduler.csv")) + cif_folder = os.path.join(base_path, "cif_result_final") + ######################### + + # Create a map from a simplified filename to its full file path for quick lookup. + file_paths = glob.glob(os.path.join(cif_folder, "*.cif")) + file_map = {} + for fp in file_paths: + base = Path(fp).stem + # Standardize the name by removing "_opt" suffix if it exists. + if base.endswith("_opt"): + key = base[:-4] + else: + key = base + file_map[key] = fp + + # Build the list of candidate structures, filtering by the energy threshold. + struct_list = [] + for _, row in df.iterrows(): + name = row["file"] + dens = row["stage2_density"] + ene = row["relative_energy"] + if ene > ENERGY_THRESHOLD: + continue + fp = file_map.get(name) + if fp: + struct_list.append((dens, ene, fp)) + else: + print(f"Warning: no CIF for {name}") + + # Sort the list of structures by energy in ascending order. + struct_list.sort(key=lambda x: x[1]) + + # Call the main function to find unique structures and the duplicate map. + unique_paths, duplicate_map = find_top_unique( + struct_list, + target_count=TARGET_UNIQUE_COUNT, + n_workers=WORKER_COUNT + ) + + # Prepare output directories, removing them first if they already exist. + if os.path.exists(t_folder): + shutil.rmtree(t_folder) + os.makedirs(t_folder) + + if os.path.exists(d_folder): + shutil.rmtree(d_folder) + os.makedirs(d_folder) + + print(f"\nFound {len(unique_paths)} unique structures " + f"(energy <= {ENERGY_THRESHOLD}, target {TARGET_UNIQUE_COUNT}).") + + # Copy the unique structure files to the target folder. + for p in unique_paths: + shutil.copy(p, t_folder) + + # Helper function to get a clean filename without the "_opt" suffix. + get_clean_name = lambda p: Path(p).stem.replace('_opt', '') + + # Create a new DataFrame containing only the data for the unique structures. + unique_names = [get_clean_name(p) for p in unique_paths] + unique_df = df[df['file'].isin(unique_names)].copy() + unique_df['duplicates'] = '' # Add a new column to store duplicate names. + + # Populate the 'duplicates' column and copy duplicate files to their folder. + for unique_fp, duplicates_list in duplicate_map.items(): + unique_name = get_clean_name(unique_fp) + if unique_name in unique_df['file'].values: + du_name = [get_clean_name(p) for p in duplicates_list] + # Add the list of duplicate names as a comma-separated string. + unique_df.loc[unique_df['file'] == unique_name, 'duplicates'] = ', '.join(du_name) + # Copy duplicate files to the duplicates folder. + for p in duplicates_list: + shutil.copy(p, d_folder) + + # Save the final DataFrame with unique structures and their duplicates to a new CSV file. + unique_csv_path = os.path.join(base_path, 'unique_structures.csv') + unique_df.to_csv(unique_csv_path, index=False) + +if __name__ == '__main__': + # Required for freezing the application when creating executables with multiprocessing. + mp.freeze_support() + + # Set up command-line argument parsing. + parser = argparse.ArgumentParser() + parser.add_argument('--path', type=str, default="./", help='Path to process') + parser.add_argument('--energy', type=float, default=30, help='Energy threshold for filtering structures') + parser.add_argument('--count', type=int, default=200, help='Target number of unique structures to find') + parser.add_argument('--workers', type=int, default=80, help='Max worker number limit') + args = parser.parse_args() + + # Set global variables from command-line arguments. + target_folder = args.path + ENERGY_THRESHOLD = args.energy + TARGET_UNIQUE_COUNT = args.count + WORKER_COUNT = args.workers + + # Start the processing if the specified folder exists. + if os.path.exists(target_folder): + print(f"Processing folder: {target_folder}") process_folder(target_folder) \ No newline at end of file diff --git a/post_process/run_remove.sh b/post_process/run_remove.sh index 8fc6777..43f555c 100644 --- a/post_process/run_remove.sh +++ b/post_process/run_remove.sh @@ -1,41 +1,40 @@ -#!/bin/sh -l -#An example for GPU job. -#SBATCH -D ./ -#SBATCH --export=ALL -#SBATCH -J csp_test -#SBATCH -o job-%j.log -#SBATCH -e job-%j.err -#SBATCH -p GPU-8A100 -#SBATCH -N 1 -n 8 -#SBATCH --gres=gpu:1 -#SBATCH --qos=gpu_8a100 -#SBATCH -t 3-00:00:00 - - -echo ========================================================= -echo SLURM job: submitted date = `date` -date_start=`date +%s` - -echo ========================================================= -echo Job output begins -echo ----------------- -echo - - -python duplicate_remove.py - - -echo -echo --------------- -echo Job output ends -date_end=`date +%s` -seconds=$((date_end-date_start)) -minutes=$((seconds/60)) -seconds=$((seconds-60*minutes)) -hours=$((minutes/60)) -minutes=$((minutes-60*hours)) -echo ========================================================= -echo SLURM job: finished date = `date` -echo Total run time : $hours Hours $minutes Minutes $seconds Seconds -echo ========================================================= - +#An example for GPU job. +#SBATCH -D ./ +#SBATCH --export=ALL +#SBATCH -J csp_test +#SBATCH -o job-%j.log +#SBATCH -e job-%j.err +#SBATCH -p GPU-8A100 +#SBATCH -N 1 -n 8 +#SBATCH --gres=gpu:1 +#SBATCH --qos=gpu_8a100 +#SBATCH -t 3-00:00:00 + + +echo ========================================================= +echo SLURM job: submitted date = `date` +date_start=`date +%s` + +echo ========================================================= +echo Job output begins +echo ----------------- +echo + + +python duplicate_remove.py + + +echo +echo --------------- +echo Job output ends +date_end=`date +%s` +seconds=$((date_end-date_start)) +minutes=$((seconds/60)) +seconds=$((seconds-60*minutes)) +hours=$((minutes/60)) +minutes=$((minutes-60*hours)) +echo ========================================================= +echo SLURM job: finished date = `date` +echo Total run time : $hours Hours $minutes Minutes $seconds Seconds +echo ========================================================= + -- GitLab